In [None]:
import numpy as np
import torch
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from typing import List, Optional, Tuple


In [None]:
from starccato_flow.data.toy_data import ToyData
from starccato_flow.data.ccsn_data import CCSNData
from starccato_flow.training.trainer import Trainer

from starccato_flow.plotting.plotting import plot_reconstruction_distribution

In [None]:
from starccato_flow.utils.defaults import DEVICE

### Dataset

In [None]:
# train_dataset = ToyData(num_signals=1684, signal_length=256)
# validation_dataset = ToyData(num_signals=round(1684 * 0.1), signal_length=256)

### Dataset Plots

In [None]:
# Build the CCSN dataset with the noise and curriculum options switched on,
# then plot its signal distribution to an SVG file.
ccsn_dataset = CCSNData(
    noise=True,
    curriculum=True,
)
ccsn_dataset.plot_signal_distribution(
    background="white",
    font_family="sans-serif",
    font_name="Avenir",
    fname="plots/ccsn_signal_distribution.svg",
)

In [None]:
# Plot a grid of dataset signals with the same styling as the distribution plot.
ccsn_dataset.plot_signal_grid(
    background="white",
    font_family="sans-serif",
    font_name="Avenir",
    fname="plots/ccsn_signal_grid.svg",
)

In [None]:
# Forwarded to Trainer(toy=...) below. Presumably False selects the CCSN data
# rather than the toy dataset -- confirm against Trainer.
toy = False

### Train VAE + Flow

In [None]:
# Configure the trainer with the same noise/curriculum options used for the
# dataset plots above, then run the training loop.
vae_trainer = Trainer(
    toy=toy,
    noise=True,
    curriculum=True,
)
vae_trainer.train()

### Display Results

In [None]:
# Show the trainer's post-training results; what is displayed is defined by
# Trainer.display_results (not visible in this notebook).
vae_trainer.display_results()

In [None]:
# plot_latent_morph_up_and_down(
#     vae_trainer.vae,
#     signal_1=ccsn_dataset.__getitem__(800)[0],
#     signal_2=ccsn_dataset.__getitem__(600)[0],
#     max_value=vae_trainer.training_dataset.max_strain,
#     train_dataset=CCSNData(),
#     steps=1
# )

In [None]:
# Plot the distribution of generated signals with the notebook's usual styling.
vae_trainer.plot_generated_signal_distribution(background="white", font_family="sans-serif", font_name="Avenir")

In [None]:
# Plot the reconstruction distribution for dataset index 1 using 1000 samples.
vae_trainer.plot_reconstruction_distribution(num_samples=1000, index=1)

In [None]:
# Persist the trained models through the trainer; the destination is decided
# inside Trainer.save_models (not visible in this notebook).
vae_trainer.save_models()

In [None]:
import torch
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

from starccato_flow.data.ccsn_data import CCSNData
from starccato_flow.utils.defaults import DEVICE

from nflows.distributions.normal import StandardNormal
from nflows.transforms import CompositeTransform, ReversePermutation, MaskedAffineAutoregressiveTransform
from nflows.flows import Flow
import torch.optim as optim


def train_npe_with_vae(vae_trainer, num_epochs=20, batch_size=32, lr=1e-4, flow=None):
    """
    Train a conditional masked autoregressive flow to estimate p(params | latent).

    The VAE held by ``vae_trainer`` is used purely as a fixed encoder: every
    noisy signal is encoded under ``torch.no_grad()`` and the encoder *mean*
    is used as the conditioning context of the flow. Only the flow's
    parameters are optimised.

    Args:
        vae_trainer: Object exposing a ``vae`` attribute. Calling ``vae(x)``
            must return a 3-tuple whose second element is the latent mean.
        num_epochs: Number of passes over the CCSN data loader.
        batch_size: Mini-batch size for the data loader (incomplete final
            batches are dropped).
        lr: Adam learning rate.
        flow: Optional pre-built flow to (continue to) train. When ``None``
            a fresh 6-layer masked affine autoregressive flow is created.

    Returns:
        The trained flow, living on ``DEVICE`` in float32.

    Raises:
        ValueError: If the data loader yields no batches (dataset smaller
            than ``batch_size`` while ``drop_last=True``).
    """
    vae = vae_trainer.vae
    # eval() only switches layer behaviour (dropout/batch-norm); gradients are
    # blocked separately by the no_grad() context around the encoder call.
    vae.eval()

    latent_dim = 6  # context size; assumed to equal the VAE latent size -- TODO confirm
    param_dim = 6   # dimensionality of the target parameter vector
    num_layers = 6

    # Build a flow only when the caller did not supply one; previously the
    # `flow` argument was unconditionally overwritten and therefore ignored.
    if flow is None:
        base_dist = StandardNormal(shape=[param_dim])

        transforms = []
        for layer_idx in range(num_layers):
            # Reverse the feature order before every second autoregressive layer.
            if layer_idx % 2 == 0:
                transforms.append(ReversePermutation(features=param_dim))
            transforms.append(
                MaskedAffineAutoregressiveTransform(
                    features=param_dim,
                    hidden_features=128,
                    context_features=latent_dim,
                )
            )

        flow = Flow(CompositeTransform(transforms), base_dist)

    # Move to the device explicitly in float32 (MPS requires float32).
    flow = flow.to(DEVICE, dtype=torch.float32)
    flow.train()  # a caller-supplied flow may have been left in eval mode

    optimizer = optim.Adam(flow.parameters(), lr=lr)

    ccsn_loader = DataLoader(
        CCSNData(noise=True, curriculum=False),
        batch_size=batch_size,
        shuffle=True,
        drop_last=True,
    )

    num_batches = len(ccsn_loader)
    if num_batches == 0:
        raise ValueError(
            "CCSN data loader produced no batches; reduce batch_size "
            f"(got batch_size={batch_size})."
        )

    for epoch in range(num_epochs):
        total_loss = 0.0

        # The clean signal is not needed here: the flow only sees the
        # encoding of the noisy signal and the target parameters.
        for _clean_signal, noisy_signal, params in ccsn_loader:
            noisy_signal = noisy_signal.to(DEVICE).float()
            params = params.to(DEVICE).float()

            # Encode the noisy signal; keep only the latent mean as context.
            with torch.no_grad():
                _, mean, _log_var = vae(noisy_signal)

            # Flatten both tensors to [batch, features].
            params = params.reshape(params.size(0), -1)
            context = mean.reshape(mean.size(0), -1)

            optimizer.zero_grad(set_to_none=True)

            # Negative log-likelihood of the parameters given the encoder mean.
            log_prob = flow.log_prob(params, context=context)
            loss = -log_prob.mean()

            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        print(f"Epoch [{epoch+1}/{num_epochs}] | Flow NLL: {total_loss / num_batches:.4f}")

    return flow

# Fit the posterior estimator on top of the trained VAE encoder.
npe_flow = train_npe_with_vae(
    vae_trainer,
    num_epochs=50,
    batch_size=32,
    lr=1e-4,
)

In [None]:
import torch
import matplotlib.pyplot as plt

# Switch both the VAE and the flow to evaluation mode before sampling.
vae_trainer.vae.eval()
npe_flow.eval()

from torch.utils.data import DataLoader
from starccato_flow.data.ccsn_data import CCSNData
from starccato_flow.utils.defaults import DEVICE

index = 1100

# Fetch the example once instead of indexing the dataset three times; if the
# dataset draws noise per access (not visible here), separate look-ups would
# also return items from different realisations.
example = vae_trainer.training_dataset[index]
signal, noisy_signal, params = example[0], example[1], example[2]

# Ensure batch dimension [B, C, T]
if noisy_signal.dim() == 2:
    noisy_signal = noisy_signal.unsqueeze(0)

num_draws = 1000

with torch.no_grad():
    noisy_signal = noisy_signal.to(DEVICE).float()
    signal = signal.to(DEVICE).float()
    _, mean, log_var = vae_trainer.vae(noisy_signal)

    # Condition on the encoder mean: this is the context the flow is trained
    # with in train_npe_with_vae, so inference must use the same quantity.
    context = mean.reshape(mean.size(0), -1).to(DEVICE, dtype=torch.float32)

    # nflows draws `num_samples` PER context row and returns
    # [context_size, num_samples, features]. Pass the single context row
    # as-is (do not tile it), then drop the leading context dimension.
    samples = npe_flow.sample(num_draws, context=context)
    samples = samples.reshape(-1, samples.size(-1))  # -> [num_draws, param_dim]

samples_cpu = samples.detach().cpu()
true_params = params.detach().cpu() if torch.is_tensor(params) else params
true_params = true_params.flatten()  # collapse any leading singleton dims

print("True params:", true_params)
print("Mean predicted:", samples_cpu.mean(dim=0))
print("Std predicted:", samples_cpu.std(dim=0))

# Histogram of the posterior draws for the first parameter.
first_param_draws = samples_cpu[:, 0]
true_value = true_params[0].item()
predicted_mean = first_param_draws.mean().item()

fig, ax = plt.subplots(figsize=(10, 6))
ax.hist(first_param_draws.numpy(), bins=50, alpha=0.7, edgecolor='black')
ax.axvline(true_value, color='red', linestyle='--', linewidth=2, label=f'True value: {true_value:.3f}')
ax.axvline(predicted_mean, color='green', linestyle='--', linewidth=2, label=f'Predicted mean: {predicted_mean:.3f}')
ax.set_xlabel('Parameter 1 Value')
ax.set_ylabel('Frequency')
ax.set_title('Posterior Distribution of Parameter 1')
ax.legend()
ax.grid(True, alpha=0.3)
plt.show()