In [None]:
from pathlib import Path
from typing import List, Optional, Tuple

import corner
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch.utils.data import DataLoader

In [None]:
from starccato_flow.data.ccsn_snr_data import CCSNSNRData
from starccato_flow.data.toy_data import ToyData
from starccato_flow.training.trainer import Trainer
from starccato_flow.training.trainer_flow_matching import FlowMatchingTrainer

from starccato_flow.utils.defaults import TEN_KPC
from starccato_flow.plotting.plotting import plot_reconstruction_distribution, plot_candidate_signal, create_latent_morph_gif

In [None]:
from starccato_flow.utils.defaults import DEVICE

In [None]:
# Every figure in this notebook is written under plots/. Create the directory up
# front: it may not exist in a fresh checkout, and writing would then fail unless
# the plotting helper happens to create it itself.
Path("plots").mkdir(parents=True, exist_ok=True)

# Shared styling for the signal-distribution overview figures.
SIGNAL_PLOT_STYLE = dict(background="black", font_family="sans-serif", font_name="Avenir")

ccsn_dataset = CCSNSNRData(noise=True, curriculum=False)
ccsn_dataset.plot_signal_distribution(**SIGNAL_PLOT_STYLE, fname="plots/ccsn_signal_distribution.svg")

toy_dataset = ToyData(noise=False, curriculum=False)
toy_dataset.plot_signal_distribution(**SIGNAL_PLOT_STYLE, fname="plots/toy_signal_distribution.svg")

In [None]:
# Flow-matching run on the toy dataset, without added noise and without an SNR curriculum.
FLOW_TRAINER_CONFIG = dict(
    toy=True,
    noise=False,
    curriculum=False,
    num_epochs=512,
    start_snr=200,
    end_snr=10,
    validation_split=0.1,
    noise_realizations=1,
)
trainer = FlowMatchingTrainer(**FLOW_TRAINER_CONFIG)

# Optional sanity check (disabled): inspect one candidate signal before training.
# trainer.plot_candidate_signal(snr=30, index=60, background="black")

In [None]:
# Train the trainer constructed above (num_epochs=512 in that cell), so this is
# likely the slowest cell in the notebook.
trainer.train()

In [None]:
# Show the trainer's post-training results; what is rendered is defined by
# FlowMatchingTrainer.display_results (not visible here).
trainer.display_results()

In [None]:
# Optional (disabled): plot a single noisy validation signal at SNR 20.
# trainer.val_loader.dataset.update_snr(20)
# signal, noisy_signal, params = trainer.val_loader.dataset.__getitem__(9)
# plot_candidate_signal(
#     noisy_signal=noisy_signal/TEN_KPC,
#     signal=signal/TEN_KPC,
#     max_value=trainer.validation_dataset.max_strain,
#     fname="plots/detected_signal.svg",
#     background="black",
# )

# Corner plot for the sample at CORNER_INDEX (index semantics are defined by
# trainer.plot_corner — presumably a validation-set index; confirm there).
CORNER_INDEX = 150
trainer.plot_corner(fname="plots/corner_plot.svg", index=CORNER_INDEX)

In [None]:
def _collect_validation_batch(val_dataset, n_samples, signal_length):
    """Stack the first `n_samples` (params, noisy signal) pairs of `val_dataset` onto DEVICE.

    Returns:
        (true_params, noisy_signals) with shapes (n_samples, num_params) and
        (n_samples, signal_length); num_params is read from the last dim of item 0's params.
    """
    # Index 0 is read once up front (as before) just to discover the parameter count.
    _, _, first_params = val_dataset[0]
    num_params = first_params.shape[-1]

    true_params = torch.zeros(n_samples, num_params, device=DEVICE)
    noisy_signals = torch.zeros(n_samples, signal_length, device=DEVICE)
    for i in range(n_samples):
        _, noisy_signal, params = val_dataset[i]
        true_params[i] = params.squeeze()
        noisy_signals[i] = noisy_signal.squeeze()
    return true_params, noisy_signals


def _flow_snapshots(flow, x0, conditioning, time_steps):
    """Integrate `x0` with `flow.step` between consecutive `time_steps`.

    Returns a list of detached CPU copies of the state, one per entry of `time_steps`
    (the first being `x0` itself).
    """
    x = x0
    snapshots = [x.detach().cpu().clone()]
    # Pure sampling: no gradients are needed, so don't build an autograd graph
    # that would otherwise chain through every integration step.
    with torch.no_grad():
        for t_start, t_end in zip(time_steps[:-1], time_steps[1:]):
            x = flow.step(x, t_start, t_end, conditioning)
            snapshots.append(x.detach().cpu().clone())
    return snapshots


def _plot_marginal_grid(snapshots, true_params, times, background):
    """Draw a (num_params x len(snapshots)) grid of flowed-vs-true marginal histograms.

    Args:
        snapshots: list of CPU tensors of shape (n_samples, num_params), one per time.
        true_params: numpy array of shape (n_samples, num_params).
        times: list of floats, the integration time of each snapshot.
        background: "black" selects the dark palette; anything else the light one.

    Returns:
        The matplotlib Figure.
    """
    num_params = true_params.shape[-1]
    n_cols = len(snapshots)

    # The 4-parameter case uses the CCSN parameter names; any other count gets
    # generic labels (previously only 2 generic names existed -> IndexError for 3, 5, ...).
    if num_params == 4:
        param_names = ['β', 'log(ω₀)', 'log(A)', 'Yₑ']
    else:
        param_names = [f'Param {i + 1}' for i in range(num_params)]

    # Color scheme based on background
    if background == "black":
        estimated_color = '#3498db'  # Blue
        true_color = '#e74c3c'  # Red
        text_color = 'white'
        grid_color = 'gray'
        grid_alpha = 0.3
    else:
        estimated_color = '#2980b9'  # Darker blue
        true_color = '#c0392b'  # Darker red
        text_color = 'black'
        grid_color = 'lightgray'
        grid_alpha = 0.5

    # squeeze=False keeps `axes` 2-D even with a single row (one parameter) or a
    # single column (n_steps == 0). Width scales with the column count and equals
    # the previous fixed 30 in at the default 9 columns.
    fig, axes = plt.subplots(
        num_params, n_cols, figsize=(30 * n_cols / 9, 4 * num_params), squeeze=False
    )

    for param_idx in range(num_params):
        true_values = true_params[:, param_idx]
        for step_idx, snapshot in enumerate(snapshots):
            ax = axes[param_idx, step_idx]

            # Flowed samples vs. ground truth for this parameter dimension
            ax.hist(snapshot[:, param_idx].numpy(), bins=30, alpha=0.7,
                    color=estimated_color, density=True, label='Estimated', edgecolor='none')
            ax.hist(true_values, bins=30, alpha=0.5,
                    color=true_color, density=True, label='True', edgecolor='none')

            if param_idx == 0:
                ax.set_title(f't = {times[step_idx]:.2f}', fontsize=12, color=text_color)
            if step_idx == 0:
                ax.set_ylabel(f'{param_names[param_idx]}\nDensity', fontsize=10, color=text_color)
                ax.legend(fontsize=8, framealpha=0.9)
            if param_idx == num_params - 1:
                ax.set_xlabel('Value', fontsize=10, color=text_color)

            # Parameters are normalized (CCSN and toy alike), so one fixed window fits every panel.
            ax.set_xlim(-1.5, 1.5)

            ax.grid(True, alpha=grid_alpha, color=grid_color, linestyle='--', linewidth=0.5)
            for spine in ax.spines.values():
                spine.set_edgecolor(text_color)
                spine.set_linewidth(0.5)

    title_text = '1D Marginal Distributions: Velocity Field Evolution from Gaussian to Target'
    fig.suptitle(title_text, fontsize=16, y=1.00, color=text_color)
    return fig


def plot_velocity_field_evolution(trainer, n_test_samples=300, n_steps=8, background="white", fname="plots/flow_evolution_1d_marginals.png"):
    """
    Visualize the evolution of parameter distributions through the velocity field.

    Standard-normal samples in parameter space are integrated from t = 0 to t = 1
    with ``trainer.flow.step`` over ``n_steps`` equal increments, conditioned on the
    noisy validation signals. For every parameter and every intermediate time, the
    histogram of the flowed samples is overlaid on that of the true validation
    parameters. The figure is saved to ``fname`` and shown.

    Args:
        trainer: FlowMatchingTrainer instance (uses ``val_loader``, ``y_length`` and ``flow.step``)
        n_test_samples: Number of validation samples to use (capped at the validation set size)
        n_steps: Number of flow steps to visualize; the figure has ``n_steps + 1`` columns
        background: "white" or "black" for plot styling
        fname: Filename to save the plot; missing parent directories are created

    Raises:
        ValueError: If ``n_test_samples < 1``, ``n_steps < 0`` or the validation set is empty.
    """
    from starccato_flow.plotting.plotting import set_plot_style

    if n_test_samples < 1:
        raise ValueError(f"n_test_samples must be at least 1, got {n_test_samples}")
    if n_steps < 0:
        raise ValueError(f"n_steps must be non-negative, got {n_steps}")

    # Set plot style based on background
    set_plot_style(background=background, font_family="sans-serif", font_name="Avenir")

    val_dataset = trainer.val_loader.dataset
    if len(val_dataset) == 0:
        raise ValueError("validation dataset is empty; nothing to plot")
    n_test_samples = min(n_test_samples, len(val_dataset))

    true_params, test_signals = _collect_validation_batch(val_dataset, n_test_samples, trainer.y_length)
    num_params = true_params.shape[-1]
    print(f"Number of parameters: {num_params}")

    # Start from noise in parameter space (independent 1D standard Gaussians).
    x0 = torch.randn(n_test_samples, num_params, device=DEVICE)
    time_steps = torch.linspace(0, 1.0, n_steps + 1)
    snapshots = _flow_snapshots(trainer.flow, x0, test_signals, time_steps)

    fig = _plot_marginal_grid(snapshots, true_params.cpu().numpy(), time_steps.tolist(), background)

    # Only filesystem paths get a parent directory; file-like targets are passed through.
    if isinstance(fname, (str, Path)):
        Path(fname).parent.mkdir(parents=True, exist_ok=True)
    fig.tight_layout()
    fig.savefig(fname, dpi=150, bbox_inches='tight', facecolor=fig.get_facecolor())
    plt.show()

# Push 300 validation samples through 8 flow steps on a white background and
# save the resulting marginal-evolution figure.
plot_velocity_field_evolution(
    trainer,
    n_test_samples=300,
    n_steps=8,
    background="white",
    fname="plots/flow_evolution_1d_marginals.png",
)