In [1]:
import numpy as np
import torch
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from typing import List, Optional, Tuple

import corner


In [2]:
from starccato_flow.data.toy_data import ToyData
from starccato_flow.data.ccsn_data import CCSNData
from starccato_flow.data.ccsn_snr_data import CCSNSNRData
from starccato_flow.training.trainer import Trainer

from starccato_flow.plotting.plotting import plot_reconstruction_distribution, plot_candidate_signal

  assert (


MPS device found


ImportError: cannot import name 'plot_candidate_signals' from 'starccato_flow.plotting.plotting' (/Users/tarineccleston/Desktop/starccato/starccato-flow/src/starccato_flow/plotting/plotting.py)

In [None]:
from starccato_flow.utils.defaults import DEVICE

### Dataset Plots

In [None]:
# Build the CCSN SNR dataset (noise on, curriculum off) and plot its signal
# distribution; `fname` points the figure at an SVG under plots/.
ccsn_dataset = CCSNSNRData(noise=True, curriculum=False)
ccsn_dataset.plot_signal_distribution(
    background="black",
    font_family="sans-serif",
    font_name="Avenir",
    fname="plots/ccsn_signal_distribution.svg",
)

In [None]:
# Set the dataset SNR to 100 before drawing an example signal.
ccsn_dataset.update_snr(100)
# ccsn_dataset.plot_signal_grid(background="white", font_family="sans-serif", font_name="Avenir", fname="plots/ccsn_signal_grid.svg")

# The original call left `signal=` and `noisy_signal=` empty (a SyntaxError).
# Take the clean/noisy pair from a single dataset item so both traces belong
# to the same example.
# NOTE(review): assumes ccsn_dataset[i] yields (signal, noisy_signal, params),
# as the trainer's datasets do elsewhere in this notebook — confirm.
index = 0
signal, noisy_signal = (
    # Accept either tensors or arrays and hand the plotter flat NumPy arrays.
    torch.as_tensor(item).detach().cpu().numpy().reshape(-1)
    for item in ccsn_dataset[index][:2]
)

plot_candidate_signal(
    signal=signal,
    noisy_signal=noisy_signal,
    fname="plots/ccsn_candidate_signal.svg",
)

In [None]:
# Passed to Trainer(toy=...) in the training cell below.
toy = False

### Train VAE + Flow

In [None]:
# Configure the trainer: curriculum over SNR from start_snr=200 to end_snr=10,
# with noise enabled and 10% of the data held out for validation.
trainer = Trainer(
    toy=toy,
    start_snr=200,
    end_snr=10,
    noise=True,
    validation_split=0.1,
    curriculum=True,
    noise_realizations=1,  # one noise realization per signal
)

# Preview one validation example at SNR 100 before training starts.
trainer.val_loader.dataset.update_snr(100)
index = 0

# The original passed the same clean array as both `signal` and
# `noisy_signal`, so the plot showed two identical traces. Take the matched
# clean/noisy pair from the dataset item instead.
# NOTE(review): assumes trainer.val_loader.dataset[i] yields
# (signal, noisy_signal, params) like trainer.validation_dataset[i] — confirm.
signal, noisy_signal = (
    # Accept either tensors or arrays and hand the plotter flat NumPy arrays.
    torch.as_tensor(item).detach().cpu().numpy().reshape(-1)
    for item in trainer.val_loader.dataset[index][:2]
)

plot_candidate_signal(
    signal=signal,
    noisy_signal=noisy_signal,
    fname="plots/ccsn_candidate_signal.svg",
)

trainer.train()

### Display Results

In [None]:
# Render the trainer's post-training results (what is shown is defined by
# Trainer.display_results; requires the training cell above to have run).
trainer.display_results()

In [None]:
# Plot the distribution of generated signals on a white background in Avenir.
trainer.plot_generated_signal_distribution(
    font_family="sans-serif",
    font_name="Avenir",
    background="white",
)

In [None]:
# Drop the validation set to SNR 10, then plot the reconstruction
# distribution (1000 samples) for a single example.
trainer.val_loader.dataset.update_snr(10)

index = 1000  # example to reconstruct

trainer.plot_reconstruction_distribution(
    index=index,
    num_samples=1000,
    font_family="sans-serif",
    font_name="Avenir",
    background="white",
)

In [None]:
# Persist the trained models via the trainer; no path is given here, so the
# destination is whatever Trainer.save_models chooses.
trainer.save_models()

In [None]:
# Set both datasets to SNR 8, then draw a corner plot for one training example.
trainer.validation_dataset.update_snr(8)
trainer.training_dataset.update_snr(8)
# Dataset items unpack as (signal, noisy_signal, params).
signal, noisy_signal, params = trainer.training_dataset[1]
# NOTE(review): the example comes from training_dataset[1] but index=160 is
# passed to plot_corner — confirm what `index` refers to and whether the two
# are meant to match.
trainer.plot_corner(signal, noisy_signal, params, index=160)

In [None]:
# Check 1: Are latent encodings different for different signals?
# Encodes a handful of validation examples with the VAE and prints the latent
# means next to the true parameters, followed by summary statistics.
import torch

num_test_signals = 5
test_indices = [10, 50, 100, 150, 200]

print("=" * 60)
print("DIAGNOSTIC: Checking Latent Encodings")
print("=" * 60)

latents, true_params = [], []

for idx in test_indices:
    signal, noisy_signal, params = trainer.validation_dataset[idx]

    # The dataset may hand back NumPy arrays or tensors; either way add a
    # batch axis and move to the compute device before encoding.
    with torch.no_grad():
        noisy_signal_tensor = (
            torch.from_numpy(noisy_signal)
            if isinstance(noisy_signal, np.ndarray)
            else noisy_signal
        ).unsqueeze(0).to(DEVICE)

        # The VAE call returns three values; the middle one is used as the
        # latent mean.
        _, mean, _ = trainer.vae(noisy_signal_tensor)
        mean = mean.view(-1).cpu().numpy()

    # Parameters are stored as flat NumPy arrays.
    if isinstance(params, torch.Tensor):
        params = params.cpu().numpy()
    true_params.append(params.flatten())

    latents.append(mean)

    print(f"\nSignal {idx}:")
    print(f"  True params: {params.flatten()}")
    print(f"  Latent (first 5 dims): {mean[:5]}")

# Spread of the latent means across the tested signals.
latents_array = np.array(latents)
print(f"\n{'=' * 60}")
print("Latent Statistics Across Signals:")
print(f"  Mean per dimension: {latents_array.mean(axis=0)[:5]}")
print(f"  Std per dimension:  {latents_array.std(axis=0)[:5]}")
print(f"  Are latents all the same? {np.allclose(latents_array[0], latents_array[1:], atol=1e-3)}")

# Spread of the true parameters across the same signals.
true_params_array = np.array(true_params)
print(f"\n{'=' * 60}")
print("True Parameter Statistics:")
print(f"  Mean: {true_params_array.mean(axis=0)}")
print(f"  Std:  {true_params_array.std(axis=0)}")
print(f"  Min:  {true_params_array.min(axis=0)}")
print(f"  Max:  {true_params_array.max(axis=0)}")
print("=" * 60)

In [None]:
# Check 2: Does the flow produce different posteriors for different latents?
# For a few examples, encode the noisy signal with the VAE, sample the flow
# conditioned on the latent mean, and compare the sample statistics with the
# true parameters.
print("\n" + "=" * 60)
print("DIAGNOSTIC: Checking Flow Posterior Predictions")
print("=" * 60)

num_samples = 1000

# NOTE(review): Check 1 and Check 3 read from trainer.validation_dataset while
# this cell reads trainer.training_dataset — confirm that is intended.
for idx in test_indices[:3]:  # Test 3 signals
    signal, noisy_signal, params = trainer.training_dataset[idx]

    # Encode to latent (dataset items may be NumPy arrays or tensors)
    with torch.no_grad():
        if isinstance(noisy_signal, np.ndarray):
            noisy_signal_tensor = torch.from_numpy(noisy_signal).unsqueeze(0).to(DEVICE)
        else:
            noisy_signal_tensor = noisy_signal.unsqueeze(0).to(DEVICE)

        _, mean, _ = trainer.vae(noisy_signal_tensor)
        z_latent = mean.view(1, -1)

        # Sample from flow conditioned on this latent
        samples = trainer.flow.sample(num_samples, context=z_latent).cpu().numpy()

    # Conditional samplers may return (context, num_samples, dim) rather than
    # (num_samples, dim). Collapse any leading axes so the axis=0 statistics
    # below are always taken across samples, never across a size-1 context axis.
    samples = samples.reshape(-1, samples.shape[-1])
    # Reverse log transform — assumes targets were trained as log(x + 1e-8);
    # TODO confirm against the training pipeline.
    samples = np.exp(samples) - 1e-8

    # Convert params to numpy if needed
    if isinstance(params, torch.Tensor):
        params = params.cpu().numpy()

    print(f"\nSignal {idx}:")
    print(f"  True params: {params.flatten()}")
    print(f"  Predicted mean: {samples.mean(axis=0)}")
    print(f"  Predicted std:  {samples.std(axis=0)}")

print("\n" + "=" * 60)
print("If predicted means are all similar, flow is NOT conditioning properly!")
print("=" * 60)

In [None]:
# Check 3: Detailed analysis of flow samples
# Draws 1000 flow samples for one validation example and reports whether the
# draws actually vary. The closing verdict is computed from the samples rather
# than printed unconditionally.
print("\n" + "=" * 60)
print("DETAILED DIAGNOSTIC: Flow Sample Analysis")
print("=" * 60)

idx = test_indices[0]  # Test first signal
signal, noisy_signal, params = trainer.validation_dataset[idx]

with torch.no_grad():
    if isinstance(noisy_signal, np.ndarray):
        noisy_signal_tensor = torch.from_numpy(noisy_signal).unsqueeze(0).to(DEVICE)
    else:
        noisy_signal_tensor = noisy_signal.unsqueeze(0).to(DEVICE)

    _, mean, _ = trainer.vae(noisy_signal_tensor)
    z_latent = mean.view(1, -1)

    # Sample from flow
    samples = trainer.flow.sample(1000, context=z_latent).cpu().numpy()

# Conditional samplers may return (context, num_samples, dim) rather than
# (num_samples, dim). Collapse any leading axes so that indexing and the
# axis=0 statistics below operate across samples; otherwise a size-1 context
# axis makes every std exactly zero and falsely suggests a deterministic flow.
samples = samples.reshape(-1, samples.shape[-1])

print(f"\nSignal {idx}:")
print(f"Samples shape: {samples.shape}")
print("First 5 samples (in log space):")
for i, sample in enumerate(samples[:5]):
    print(f"  Sample {i}: {sample}")

samples_identical = bool(np.allclose(samples[0], samples[1:], atol=1e-6))
print(f"\nAre all samples identical? {samples_identical}")
print(f"Std dev across samples: {samples.std(axis=0)}")
print(f"Min across samples: {samples.min(axis=0)}")
print(f"Max across samples: {samples.max(axis=0)}")

# After exp transform — assumes targets were trained as log(x + 1e-8);
# TODO confirm against the training pipeline.
samples_exp = np.exp(samples) - 1e-8
print("\nAfter exp transform:")
print(f"Std dev: {samples_exp.std(axis=0)}")
print(f"Mean: {samples_exp.mean(axis=0)}")

print("\n" + "=" * 60)
if samples_identical:
    print("ISSUE: Flow is deterministic - outputting same value every time!")
else:
    print("OK: Flow samples vary across draws - flow is not deterministic.")
print("=" * 60)

### Diagnostic: Check if Flow is Learning Conditional Distribution