In [None]:
import numpy as np
import torch
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from typing import List, Optional, Tuple

import corner


In [None]:
from starccato_flow.data.toy_data import ToyData
from starccato_flow.data.ccsn_snr_data import CCSNSNRData
from starccato_flow.training.trainer import Trainer

from starccato_flow.plotting.plotting import plot_reconstruction_distribution, plot_candidate_signal

In [None]:
from starccato_flow.utils.defaults import DEVICE

In [None]:
# Build the CCSN dataset with noise enabled and curriculum disabled, then call
# its signal-distribution plot with the dark theme and the target path in `fname`.
ccsn_dataset = CCSNSNRData(
    noise=True,
    curriculum=False,
)
ccsn_dataset.plot_signal_distribution(
    background="black",
    font_family="sans-serif",
    font_name="Avenir",
    fname="plots/ccsn_signal_distribution.svg",
)

In [None]:
# Plot a grid of signals from the dataset, using the same dark theme and font
# as the distribution figure; the target path is passed via `fname`.
ccsn_dataset.plot_signal_grid(
    background="black",
    font_family="sans-serif",
    font_name="Avenir",
    fname="plots/ccsn_signal_grid.png",
)

In [None]:
# Configure the trainer on the real (non-toy) data with noise and curriculum
# enabled, an SNR range of 200 -> 10, and a 10% validation split.
trainer = Trainer(
    toy=False,
    noise=True,
    curriculum=True,
    start_snr=200,
    end_snr=10,
    validation_split=0.1,
    noise_realizations=1,
)

# Plot one candidate signal (index 60) at SNR 30 before training starts.
trainer.plot_candidate_signal(
    snr=30,
    index=60,
    background="black",
    fname="plots/ccsn_candidate_signal.png",
)

In [None]:
# Run the training routine of the Trainer configured above.
trainer.train()

In [None]:
# Show the trainer's results; what is displayed is decided by
# Trainer.display_results, which is not visible in this notebook.
trainer.display_results()

In [None]:
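# NOTE: disabled cell, kept for reference only. The `index` / `val_idx` lookup
# below is not passed to the plot call, which takes only a sample count and
# styling arguments.
# index = 9

# val_idx = trainer.validation_sampler.indices[index]
# signal, noisy_signal, params = trainer.val_loader.dataset.__getitem__(val_idx)

# trainer.plot_reconstruction_distribution(
#     num_samples=1000,
#     background="white",
#     font_family="sans-serif",
#     font_name="Avenir"
# )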

In [None]:
# Check 2: Does the flow produce different posteriors for different latents?
#
# For a few training signals: encode the noisy signal with the VAE, draw
# samples from the flow conditioned on the latent mean, and print the sample
# mean/std next to the true parameters for a visual comparison.
print("\n" + "=" * 60)
print("DIAGNOSTIC: Checking Flow Posterior Predictions")
print("=" * 60)

num_samples = 1000

# Defined in this cell so it runs on a fresh kernel (Restart & Run All); the
# name was previously read without being assigned anywhere above this cell.
# Edit the list to probe other entries of `trainer.training_dataset`.
test_indices = [0, 1, 2]

for idx in test_indices[:3]:  # Test 3 signals
    signal, noisy_signal, params = trainer.training_dataset[idx]

    # Encode to latent. The dataset item may be a numpy array or a tensor, so
    # both are handled; a leading batch dimension of 1 is added either way.
    with torch.no_grad():
        if isinstance(noisy_signal, np.ndarray):
            noisy_signal_tensor = torch.from_numpy(noisy_signal).unsqueeze(0).to(DEVICE)
        else:
            noisy_signal_tensor = noisy_signal.unsqueeze(0).to(DEVICE)

        # Only the second output of the VAE forward pass (the latent mean) is
        # used; it is flattened to a single row to serve as the flow context.
        _, mean, _ = trainer.vae(noisy_signal_tensor)
        z_latent = mean.view(1, -1)

        # Sample from flow conditioned on this latent
        samples = trainer.flow.sample(num_samples, context=z_latent).cpu().numpy()
        # Reverse log transform.
        # NOTE(review): assumes the flow was trained on log(params + 1e-8) -
        # confirm against the transform used during training.
        samples = np.exp(samples) - 1e-8

    # Convert params to numpy if needed
    if isinstance(params, torch.Tensor):
        params = params.cpu().numpy()

    print(f"\nSignal {idx}:")
    print(f"  True params: {params.flatten()}")
    print(f"  Predicted mean: {samples.mean(axis=0)}")
    print(f"  Predicted std:  {samples.std(axis=0)}")

print("\n" + "=" * 60)
print("If predicted means are all similar, flow is NOT conditioning properly!")
print("=" * 60)