In [None]:
import numpy as np
import torch
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from typing import List, Optional, Tuple

import corner

In [None]:
from starccato_flow.data.ccsn_snr_data import CCSNSNRData
from starccato_flow.data.toy_data import ToyData
from starccato_flow.training.trainer import Trainer
from starccato_flow.training.trainer_flow_matching import FlowMatchingTrainer

from starccato_flow.utils.defaults import TEN_KPC
from starccato_flow.plotting.plotting import plot_reconstruction_distribution, plot_candidate_signal, create_latent_morph_gif

In [None]:
from starccato_flow.utils.defaults import DEVICE

In [None]:
# Shared look for both signal-distribution figures, so the two plots cannot
# drift apart in styling.
distribution_plot_style = dict(
    background="black",
    font_family="sans-serif",
    font_name="Avenir",
)

# CCSN dataset, built with noise enabled and no curriculum.
ccsn_dataset = CCSNSNRData(noise=True, curriculum=False)
ccsn_dataset.plot_signal_distribution(
    fname="plots/ccsn_signal_distribution.svg",
    **distribution_plot_style,
)

# Toy dataset, built noise-free for comparison.
toy_dataset = ToyData(noise=False, curriculum=False)
toy_dataset.plot_signal_distribution(
    fname="plots/toy_signal_distribution.svg",
    **distribution_plot_style,
)

In [None]:
# Flow-matching trainer configuration, gathered in one mapping so every knob
# can be scanned (and tweaked) in a single place before re-running.
trainer_config = dict(
    toy=False,
    start_snr=200,
    end_snr=10,
    noise=False,
    validation_split=0.1,
    curriculum=True,
    noise_realizations=1,
)
trainer = FlowMatchingTrainer(**trainer_config)

# Optional sanity check (disabled): plot one candidate signal at a chosen SNR.
# trainer.plot_candidate_signal(
#     snr=30,
#     index=60,
#     background="black"
# )

In [None]:
# Run training on the trainer configured above.
# NOTE(review): presumably the slowest cell in the notebook; everything below
# relies on `trainer` having been trained first.
trainer.train()

In [None]:
# Show the trainer's own results summary.
# NOTE(review): what is displayed is defined inside the trainer class, not in
# this notebook -- check the trainer implementation for details.
trainer.display_results()

In [None]:
# Disabled snippet: set the validation dataset to SNR 20, fetch item 9 and
# plot its noisy signal against the clean one (both scaled by TEN_KPC).
# trainer.val_loader.dataset.update_snr(20)
# signal, noisy_signal, params = trainer.val_loader.dataset.__getitem__(9)
# plot_candidate_signal(noisy_signal=noisy_signal/TEN_KPC, signal=signal/TEN_KPC, max_value=trainer.validation_dataset.max_strain, fname="plots/detected_signal.svg", background="black")

# Corner plot for example index 100; `fname` is the output path handed to the
# trainer's plotting method.
# NOTE(review): presumably the index refers to the validation set -- confirm.
trainer.plot_corner(index=100, fname="plots/corner_plot.svg")

In [None]:
# Sampling - Parameter Estimation from Test Signals
#
# Draw Gaussian noise in the 2-D parameter space and push it through the
# trained flow, conditioned on noisy validation signals. One scatter panel is
# drawn per integration step so the transport from noise (t = 0) towards the
# parameter estimates (t = 1) can be followed by eye.

# Seed torch's global RNG so the initial noise drawn below -- and therefore
# every panel of the figure -- is identical on each Restart & Run All.
SAMPLING_SEED = 0
torch.manual_seed(SAMPLING_SEED)

# Use validation dataset from trainer
val_dataset = trainer.val_loader.dataset

# Get true parameters and noisy signals from the validation set (at most 300)
n_test_samples = min(300, len(val_dataset))
true_params = torch.zeros(n_test_samples, 2, device=DEVICE)
test_signals = torch.zeros(n_test_samples, trainer.y_length, device=DEVICE)

for i in range(n_test_samples):
    # Each item unpacks to (clean signal, noisy signal, parameters); only the
    # noisy signal is used as conditioning input here.
    clean_signal, noisy_signal, params = val_dataset[i]
    true_params[i] = params.squeeze()
    test_signals[i] = noisy_signal.squeeze()

# Start from noise in parameter space
x = torch.randn(n_test_samples, 2, device=DEVICE)
n_steps = 8
fig, axes = plt.subplots(1, n_steps + 1, figsize=(30, 4), sharex=True, sharey=True)
time_steps = torch.linspace(0, 1.0, n_steps + 1)

# The true parameters never change between panels: move them to the CPU once.
true_params_cpu = true_params.cpu()


def _scatter_panel(ax, estimates, t, show_legend):
    """Draw current estimates (default colour) and true parameters (red) on ``ax``."""
    estimates_cpu = estimates.detach().cpu()
    ax.scatter(estimates_cpu[:, 0], estimates_cpu[:, 1], s=10, alpha=0.6, label='Estimated')
    ax.scatter(true_params_cpu[:, 0], true_params_cpu[:, 1], s=10, alpha=0.3, label='True', color='red')
    ax.set_title(f't = {t:.2f}')
    if show_legend:
        ax.legend()


# Panel 0: the untouched noise. Axes are shared, so limits set here apply to
# every panel.
_scatter_panel(axes[0], x, time_steps[0], show_legend=True)
axes[0].set_xlim(-3.0, 3.0)
axes[0].set_ylim(-3.0, 3.0)

# Pure inference: disable autograd so no graph is accumulated across the
# integration steps.
# NOTE(review): if the flow network has dropout/batch-norm layers, make sure
# it is in eval mode before sampling.
with torch.no_grad():
    for i in range(n_steps):
        x = trainer.flow.step(x, time_steps[i], time_steps[i + 1], test_signals)
        # Legend only on the last panel to keep the intermediate ones uncluttered.
        _scatter_panel(axes[i + 1], x, time_steps[i + 1], show_legend=(i == n_steps - 1))

plt.tight_layout()
plt.show()

# Generate posterior samples for a single signal
#
# Condition the flow on ONE noisy validation signal and transport many
# independent noise draws through it; the resulting point cloud is the
# sampled posterior over the two parameters for that signal.

# Pick one true parameter and its signal from the validation set
single_idx = 2  # Choose an example
# Fail early with a clear message instead of silently slicing an empty tensor
# when the validation subset is smaller than expected.
assert 0 <= single_idx < n_test_samples, (
    f"single_idx={single_idx} is out of range for {n_test_samples} test samples"
)
single_true_param = true_params[single_idx:single_idx+1]
single_signal = test_signals[single_idx:single_idx+1]

# Seed torch's global RNG so the posterior cloud and the printed statistics
# are reproducible between runs.
POSTERIOR_SEED = 0
torch.manual_seed(POSTERIOR_SEED)

# Generate multiple posterior samples by starting from different noise realizations
n_posterior_samples = 10000
posterior_samples = torch.randn(n_posterior_samples, 2, device=DEVICE)

# Repeat the single signal so every sample is conditioned on the same data
repeated_signal = single_signal.repeat(n_posterior_samples, 1)

# Flow the samples to get posterior distribution
n_steps = 20
time_steps_posterior = torch.linspace(0, 1.0, n_steps + 1)

# Pure inference: without no_grad an autograd graph would be kept alive for
# all 20 steps over 10000 samples, wasting memory for no benefit.
with torch.no_grad():
    for i in range(n_steps):
        posterior_samples = trainer.flow.step(
            posterior_samples,
            time_steps_posterior[i],
            time_steps_posterior[i + 1],
            repeated_signal,
        )

# Move everything needed for plotting to the CPU once.
all_true_cpu = true_params.cpu()
posterior_cpu = posterior_samples.detach().cpu()
single_true_cpu = single_true_param.cpu()

# Visualize the posterior
fig, ax = plt.subplots(1, 1, figsize=(8, 8))
# The set of all validation parameters stands in for the prior.
ax.scatter(all_true_cpu[:, 0], all_true_cpu[:, 1],
           s=200, marker='X', label='"Prior"', color='green', alpha=0.3)
ax.scatter(posterior_cpu[:, 0], posterior_cpu[:, 1],
           s=20, alpha=0.4, label='Posterior samples', color='blue')
ax.scatter(single_true_cpu[:, 0], single_true_cpu[:, 1],
           s=200, marker='*', label='True parameter', color='red', edgecolors='black', linewidths=2)
ax.set_xlim(-3, 3)
ax.set_ylim(-3, 3)
ax.set_xlabel('Parameter 1')
ax.set_ylabel('Parameter 2')
ax.set_title('Posterior Distribution for Single Signal')
ax.legend()
ax.grid(True, alpha=0.3)
ax.set_aspect('equal')
plt.tight_layout()
plt.show()

# Numeric summary: compare the true parameter with the posterior moments.
print(f"True parameter: [{single_true_param[0, 0]:.3f}, {single_true_param[0, 1]:.3f}]")
print(f"Posterior mean: [{posterior_samples[:, 0].mean():.3f}, {posterior_samples[:, 1].mean():.3f}]")
print(f"Posterior std: [{posterior_samples[:, 0].std():.3f}, {posterior_samples[:, 1].std():.3f}]")