In [None]:
from starccato_flow.data.ccsn_data import CCSNData
from starccato_flow.data.toy_data import ToyData
from starccato_flow.training.trainer_vae import VAETrainer
from starccato_flow.training.trainer_cvae import ConditionalVAETrainer

In [None]:
# Styling keyword arguments shared by both signal-distribution figures.
distribution_plot_style = {
    "background": "black",
    "font_family": "sans-serif",
    "font_name": "Avenir",
}

# CCSN dataset (constructed with noise enabled) and its distribution figure.
ccsn_dataset = CCSNData(noise=True, curriculum=False)
ccsn_dataset.plot_signal_distribution(
    fname="plots/ccsn_signal_distribution.svg",
    **distribution_plot_style,
)

# Toy dataset (constructed with noise disabled) and its distribution figure.
toy_dataset = ToyData(noise=False, curriculum=False)
toy_dataset.plot_signal_distribution(
    fname="plots/toy_signal_distribution.svg",
    **distribution_plot_style,
)

In [None]:
# Settings for the conditional VAE trainer, gathered in one place so they are
# easy to tweak; they are forwarded unchanged as keyword arguments.
cvae_trainer_config = dict(
    toy=False,
    num_epochs=256,
    noise=False,
    curriculum=False,
    validation_split=0.1,
    noise_realizations=1,
)

trainer = ConditionalVAETrainer(**cvae_trainer_config)

In [None]:
# Start training on the `trainer` object built in the previous cell.
# NOTE(review): presumably this runs for the configured number of epochs and
# can take a long time — confirm before using Restart & Run All.
trainer.train()

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from starccato_flow.utils.defaults import DEVICE

import os

# Conditioning sanity check: decode the SAME latent noise vectors under several
# different parameter sets and plot the resulting waveforms on a shared y-axis.
# Rows of the figure are parameter sets, columns are noise samples.
print("Testing CVAE parameter conditioning...")
print("=" * 60)

# Define 3 different parameter sets
if trainer.param_dim == 1:
    param_sets = [
        np.array([0.02]),   # Low beta
        np.array([0.10]),   # Medium beta
        np.array([0.18])    # High beta
    ]
    param_labels = ['β=0.02', 'β=0.10', 'β=0.18']
elif trainer.param_dim == 4:
    param_sets = [
        np.array([0.02, 6.0, 3000.0, 0.10]),   # Low values
        np.array([0.10, 10.0, 6000.0, 0.15]),  # Medium values
        np.array([0.18, 14.0, 9000.0, 0.20])   # High values
    ]
    param_labels = ['Low', 'Medium', 'High']
else:
    # No presets exist for this dimensionality. Every set is all zeros, so the
    # rows of the figure are expected to be identical and the comparison is
    # not informative; warn instead of failing silently.
    print(f"Warning: no preset parameter sets for param_dim={trainer.param_dim}; "
          "using zeros for every set")
    param_sets = [np.zeros(trainer.param_dim) for _ in range(3)]
    param_labels = ['Set 1', 'Set 2', 'Set 3']

# Normalize parameters
param_sets_norm = [trainer.training_dataset.normalize_parameters(p) for p in param_sets]

# Generate noise samples ONCE (will be reused for all parameter sets).
# No seed is set in this cell, so the figure may differ between runs.
num_samples_per_set = 3
z_samples = torch.randn(num_samples_per_set, trainer.z_dim).to(DEVICE)
print(f"\nUsing {num_samples_per_set} shared noise samples across all parameter sets")

trainer.cvae.eval()

# Grid dimensions follow the data rather than being hard-coded.
num_param_sets = len(param_sets)
# squeeze=False keeps `axes` two-dimensional for any grid size.
fig, axes = plt.subplots(num_param_sets, num_samples_per_set, figsize=(15, 10), squeeze=False)

# First pass: generate all signals; signals_scaled[row][col] holds one waveform.
signals_scaled = []
with torch.no_grad():
    for params_norm in param_sets_norm:
        # Use the SAME z_samples for all parameter sets
        params_tensor = (
            torch.tensor(params_norm, dtype=torch.float32)
            .unsqueeze(0)
            .repeat(num_samples_per_set, 1)
            .to(DEVICE)
        )

        generated = trainer.cvae.decoder(z_samples, params_tensor).cpu().numpy()

        row_signals = []
        for j in range(num_samples_per_set):
            # Flatten first so a possible singleton channel axis in the decoder
            # output does not end up as the "time" dimension when plotting.
            signal_denorm = trainer.training_dataset.denormalise_signals(generated[j].flatten())
            row_signals.append(signal_denorm / 1e-20)
        signals_scaled.append(row_signals)

# Calculate global y-limits so every panel is directly comparable
all_signals_array = np.array([sig for row_signals in signals_scaled for sig in row_signals])
y_min = all_signals_array.min()
y_max = all_signals_array.max()
y_margin = (y_max - y_min) * 0.1

# Second pass: plot with consistent y-limits (no model calls, so no no_grad needed)
for i, (params_raw, label) in enumerate(zip(param_sets, param_labels)):
    for j in range(num_samples_per_set):
        ax = axes[i, j]
        signal_scaled = signals_scaled[i][j]
        time = np.arange(len(signal_scaled)) / 16384

        ax.plot(time, signal_scaled, linewidth=0.8, color='#2c3e50')
        ax.set_ylabel('Strain (×10⁻²⁰)', fontsize=9)
        ax.set_ylim(y_min - y_margin, y_max + y_margin)
        ax.grid(True, alpha=0.3)
        ax.axhline(y=0, color='gray', linestyle='--', alpha=0.3)

        if i == 0:
            ax.set_title(f'Noise Sample {j+1}', fontsize=10)
        if i == num_param_sets - 1:
            ax.set_xlabel('Time (s)', fontsize=9)
        if j == 0:
            # Row label placed to the left of the first column
            if trainer.param_dim == 1:
                ax.text(-0.3, 0.5, label, transform=ax.transAxes,
                        fontsize=11, va='center', rotation=90, weight='bold')
            else:
                param_str = f"β={params_raw[0]:.2f}\nω={params_raw[1]:.1f}"
                ax.text(-0.35, 0.5, param_str, transform=ax.transAxes,
                        fontsize=9, va='center', rotation=90)

plt.suptitle('CVAE Generation Test: Same Noise, Different Parameters', fontsize=14, weight='bold')
plt.tight_layout()
# Make sure the output directory exists so savefig cannot fail on a fresh checkout.
os.makedirs('plots', exist_ok=True)
plt.savefig('plots/cvae_parameter_test.svg', bbox_inches='tight', dpi=150)
plt.show()

print("\n✓ If the model is learning parameters correctly:")
print("  - Each COLUMN should show consistent structure (same noise)")
print("  - Different ROWS should look different (different parameters)")
print("  - Signals in same column vary due to parameter changes, not noise")
print("\nIf all signals look too similar → model NOT learning parameters")
print("If columns show inconsistent structure → check noise reuse logic")

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Find the two training signals whose parameters are closest to a chosen
# target and plot them on a shared y-axis.

# ========== CHOOSE YOUR TARGET PARAMETER VALUES HERE ==========
target_beta = 0.05    # Beta value
target_omega = None    # Set to a value like 10.0 for omega, or None to ignore
# ================================================================

SAMPLING_RATE_HZ = 16384  # Assumed sampling rate used to build the time axis

# Get parameters from training dataset
# NOTE(review): the targets are compared directly against `parameters`, so this
# assumes that attribute holds raw (un-normalised) values — confirm.
params = trainer.training_dataset.parameters
num_params = params.shape[1]
param_names = trainer.training_dataset.parameter_names

print(f"Dataset has {num_params} parameter(s): {param_names}")

# ω₀ takes part in the match only for the 4-parameter layout (beta, omega, A, Ye)
# and only when a target value was supplied.
match_omega = num_params == 4 and target_omega is not None

if match_omega:
    print(f"\nSearching for signals near: β = {target_beta:.6f}, ω₀ = {target_omega:.6f}")
    # Euclidean distance over beta and omega only.
    # NOTE(review): the two axes are not rescaled, so the parameter with the
    # larger numeric range dominates the distance — confirm this is intended.
    differences = np.sqrt((params[:, 0] - target_beta)**2 + (params[:, 1] - target_omega)**2)
else:
    if num_params == 1:
        print(f"\nSearching for signals near: β = {target_beta:.6f}")
    elif num_params == 4:
        print("\nNote: You have 4 parameters but target_omega is None. Finding signals with similar beta only.")
    else:
        # Previously this case never computed `closest_indices` and the cell
        # failed with a NameError; fall back to matching on beta alone.
        print(f"\nSearching for signals near: β = {target_beta:.6f} (other parameters ignored)")
    differences = np.abs(params[:, 0] - target_beta)

closest_indices = np.argsort(differences)[:2]
if len(closest_indices) < 2:
    raise ValueError("Need at least two signals in the training dataset to compare")

idx1, idx2 = closest_indices[0], closest_indices[1]

# Get the signals (the middle element of each item is not needed here)
signal1, _, param1 = trainer.training_dataset[idx1]
signal2, _, param2 = trainer.training_dataset[idx2]

# Denormalize for plotting
signal1_denorm = trainer.training_dataset.denormalise_signals(signal1.cpu().numpy().flatten())
signal2_denorm = trainer.training_dataset.denormalise_signals(signal2.cpu().numpy().flatten())
param1_denorm = trainer.training_dataset.denormalize_parameters(param1.cpu().numpy().flatten())
param2_denorm = trainer.training_dataset.denormalize_parameters(param2.cpu().numpy().flatten())

# Report how far each selected signal is from the requested target
for signal_number, signal_index, signal_params in (
    (1, idx1, param1_denorm),
    (2, idx2, param2_denorm),
):
    print(f"\nSignal {signal_number} - Index: {signal_index}")
    print(f"  β = {signal_params[0]:.6f} (Δ = {abs(signal_params[0] - target_beta):.6f})")
    if match_omega:
        print(f"  ω₀ = {signal_params[1]:.6f} (Δ = {abs(signal_params[1] - target_omega):.6f})")

# Calculate global y-limits for consistent scaling
signal1_scaled = signal1_denorm / 1e-20
signal2_scaled = signal2_denorm / 1e-20
y_min = min(signal1_scaled.min(), signal2_scaled.min())
y_max = max(signal1_scaled.max(), signal2_scaled.max())
y_margin = (y_max - y_min) * 0.1  # Add 10% margin

# Plot the two signals
fig, axes = plt.subplots(2, 1, figsize=(12, 8))
time = np.arange(len(signal1_denorm)) / SAMPLING_RATE_HZ

# Build title strings
if match_omega:
    title1 = f'Signal {idx1}: β = {param1_denorm[0]:.6f}, ω₀ = {param1_denorm[1]:.6f}'
    title2 = f'Signal {idx2}: β = {param2_denorm[0]:.6f}, ω₀ = {param2_denorm[1]:.6f}'
    suptitle = f'Two Signals Closest to Target β = {target_beta:.6f}, ω₀ = {target_omega:.6f}'
else:
    title1 = f'Signal {idx1}: β = {param1_denorm[0]:.6f}'
    title2 = f'Signal {idx2}: β = {param2_denorm[0]:.6f}'
    suptitle = f'Two Signals Closest to Target β = {target_beta:.6f}'

# Both panels share the same styling; only data, title and colour differ.
for ax, scaled_signal, panel_title, line_color in (
    (axes[0], signal1_scaled, title1, '#3498db'),
    (axes[1], signal2_scaled, title2, '#e74c3c'),
):
    ax.plot(time, scaled_signal, linewidth=1, color=line_color)
    ax.set_ylabel('Strain (×10⁻²⁰)', fontsize=12)
    ax.set_title(panel_title, fontsize=14)
    ax.grid(True, alpha=0.3)
    ax.axhline(y=0, color='gray', linestyle='--', alpha=0.3)
    ax.set_ylim(y_min - y_margin, y_max + y_margin)

# Only the bottom panel carries the x-axis label
axes[1].set_xlabel('Time (s)', fontsize=12)

plt.suptitle(suptitle, fontsize=16, y=1.00)
plt.tight_layout()
plt.show()

In [None]:
import torch
import matplotlib.pyplot as plt
import numpy as np
from starccato_flow.utils.defaults import DEVICE

# Latent-space interpolation: encode the two signals selected in the previous
# cell, average their latent means, decode the midpoint and plot all three.
# NOTE(review): `trainer` is created as a ConditionalVAETrainer and the
# generation cell uses `trainer.cvae` with a two-argument decoder, whereas this
# cell uses `trainer.vae` with single-argument encoder/decoder calls — confirm
# that the `vae` attribute exists on this trainer.
with torch.no_grad():
    # The denormalised arrays from the previous cell are normalised again here
    # and given two leading singleton axes (presumably batch and channel —
    # confirm against the encoder's expected input).
    encoder_inputs = []
    for raw_signal in (signal1_denorm, signal2_denorm):
        renormalised = trainer.training_dataset.normalise_signals(raw_signal)
        as_tensor = torch.tensor(renormalised, dtype=torch.float32)
        encoder_inputs.append(as_tensor.unsqueeze(0).unsqueeze(0).to(device=DEVICE))
    signal1_norm, signal2_norm = encoder_inputs

    # Encode to latent space; the encoder returns a (mean, log-variance) pair
    mu1, logvar1 = trainer.vae.encoder(signal1_norm)
    mu2, logvar2 = trainer.vae.encoder(signal2_norm)

    # Interpolate between the posterior means rather than between random samples
    z1, z2 = mu1, mu2

    print(f"Latent vector 1 shape: {z1.shape}")
    print(f"Latent vector 2 shape: {z2.shape}")

    # Midpoint of the two latent vectors
    z_interpolated = (z1 + z2) / 2.0

    print(f"Interpolated latent vector shape: {z_interpolated.shape}")
    latent_distance = torch.norm(z2 - z1).item()
    print(f"Distance between z1 and z2: {latent_distance:.6f}")

    # Decode the midpoint back into a (normalised) signal
    decoded_midpoint = trainer.vae.decoder(z_interpolated)
    interpolated_signal_norm = decoded_midpoint.cpu().numpy().flatten()

# Bring the decoded signal back to physical units
interpolated_signal = trainer.training_dataset.denormalise_signals(interpolated_signal_norm)

# One panel per signal: original 1, midpoint, original 2
fig, axes = plt.subplots(3, 1, figsize=(12, 10))
time = np.arange(len(signal1_denorm)) / 16384

# Express every trace in units of 1e-20 strain
signal1_scaled = signal1_denorm / 1e-20
signal2_scaled = signal2_denorm / 1e-20
interpolated_scaled = interpolated_signal / 1e-20

# Shared y-limits (with a 10% margin) so the panels are directly comparable
traces_for_limits = (signal1_scaled, signal2_scaled, interpolated_scaled)
y_min = min(trace.min() for trace in traces_for_limits)
y_max = max(trace.max() for trace in traces_for_limits)
y_margin = (y_max - y_min) * 0.1

# (data, colour, title) for each panel, top to bottom
panels = (
    (signal1_scaled, '#3498db', f'Original Signal {idx1}: β = {param1_denorm[0]:.6f}'),
    (interpolated_scaled, '#2ecc71', f'Interpolated Signal (Latent Space Midpoint)'),
    (signal2_scaled, '#e74c3c', f'Original Signal {idx2}: β = {param2_denorm[0]:.6f}'),
)
for ax, (trace, colour, heading) in zip(axes, panels):
    ax.plot(time, trace, linewidth=1, color=colour)
    ax.set_ylabel('Strain (×10⁻²⁰)', fontsize=12)
    ax.set_title(heading, fontsize=14)
    ax.grid(True, alpha=0.3)
    ax.axhline(y=0, color='gray', linestyle='--', alpha=0.3)
    ax.set_ylim(y_min - y_margin, y_max + y_margin)

# Only the bottom panel carries the x-axis label
axes[2].set_xlabel('Time (s)', fontsize=12)

plt.suptitle('Latent Space Interpolation Between Two Signals', fontsize=16, y=1.00)
plt.tight_layout()
plt.show()

# Summary statistics of the decoded midpoint, in physical units
print(f"\nInterpolated signal statistics:")
print(f"  Min: {interpolated_signal.min():.2e}")
print(f"  Max: {interpolated_signal.max():.2e}")
print(f"  Mean: {interpolated_signal.mean():.2e}")