In [None]:
import numpy as np
import torch
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from typing import List, Optional, Tuple

import corner


In [None]:
from starccato_flow.data.toy_data import ToyData
from starccato_flow.data.ccsn_data import CCSNData
from starccato_flow.data.ccsn_snr_data import CCSNSNRData
from starccato_flow.training.trainer import Trainer

from starccato_flow.plotting.plotting import plot_reconstruction_distribution

In [None]:
from starccato_flow.utils.defaults import DEVICE

### Dataset

In [None]:
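# Disabled toy-data alternative, kept for reference; nothing in the visible cells reads these variables.
# train_dataset = ToyData(num_signals=1684, signal_length=256)
# validation_dataset = ToyData(num_signals=round(1684 * 0.1), signal_length=256)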

### Dataset Plots

In [None]:
# Build the CCSN dataset with noise and curriculum both enabled, then plot
# its signal distribution. `fname` is passed as an SVG path under plots/.
# NOTE(review): presumably the figure is written to that path and the plots/
# directory must already exist — confirm against plot_signal_distribution.
ccsn_dataset = CCSNData(noise=True, curriculum=True)
ccsn_dataset.plot_signal_distribution(
    background="white",
    font_family="sans-serif",
    font_name="Avenir",
    fname="plots/ccsn_signal_distribution.svg",
)

In [None]:
# Plot a grid of signals from the CCSN dataset, using the same styling as the
# distribution plot; `fname` is passed as an SVG path under plots/.
ccsn_dataset.plot_signal_grid(
    background="white",
    font_family="sans-serif",
    font_name="Avenir",
    fname="plots/ccsn_signal_grid.svg",
)

In [None]:
# Flag forwarded to Trainer(toy=...) in the training cells below.
toy = False

### Train VAE + Flow

In [None]:
# Train with noise_realizations=3. The intent is to present each signal under
# three noise draws (about 3x the training data), which is expected to improve
# validation NLL — NOTE(review): confirm against Trainer's implementation.
trainer = Trainer(toy=toy, noise=True, curriculum=True, noise_realizations=3)
trainer.train()

In [None]:
# Train again without passing noise_realizations (the Trainer default applies).
# NOTE(review): this rebinds `trainer`, so every later cell (results, plots,
# save_models) uses this instance and not the noise_realizations=3 run from
# the previous cell. Under Restart & Run All both trainings execute; confirm
# which run is meant to be kept.
trainer = Trainer(
    toy=toy,
    noise=True,
    curriculum=True,
)
trainer.train()

### Display Results

In [None]:
# Show the results for the most recently assigned `trainer`; what is displayed
# is defined by Trainer.display_results.
trainer.display_results()

In [None]:
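# Disabled latent-morph plot, kept for reference.
# NOTE: `plot_latent_morph_up_and_down` is not imported in the visible cells of
# this notebook; import it before re-enabling this call.
# plot_latent_morph_up_and_down(
#     trainer.vae,
#     signal_1=ccsn_dataset.__getitem__(800)[0],
#     signal_2=ccsn_dataset.__getitem__(600)[0],
#     max_value=trainer.training_dataset.max_strain,
#     train_dataset=CCSNData(),
#     steps=1
# )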

In [None]:
# Plot the distribution of signals generated by the trained model, styled the
# same way as the dataset plots above.
generated_plot_style = {
    "background": "white",
    "font_family": "sans-serif",
    "font_name": "Avenir",
}
trainer.plot_generated_signal_distribution(**generated_plot_style)

In [None]:
# Position within the validation sampler's index list (not a raw dataset index).
index = 1

# Translate that position into the underlying dataset index, then fetch the
# sample. Each item unpacks into (signal, noisy_signal, params); `params` is
# not used in this cell.
val_idx = trainer.validation_sampler.indices[index]
# Use subscription rather than calling the __getitem__ dunder directly.
signal, noisy_signal, params = trainer.val_loader.dataset[val_idx]

# Draw 1000 reconstructions of the noisy input from the VAE and plot them
# against the true signal.
# NOTE(review): max_value is read from trainer.validation_dataset while the
# sample comes from trainer.val_loader.dataset — presumably the same dataset;
# confirm so the strain scaling matches.
plot_reconstruction_distribution(
    vae=trainer.vae,
    noisy_signal=noisy_signal,
    true_signal=signal,
    max_value=trainer.validation_dataset.max_strain,
    num_samples=1000,
    background="white",
    font_family="sans-serif",
    font_name="Avenir",
)

In [None]:
# Save the models held by the current `trainer`.
# NOTE(review): the output location is decided inside Trainer.save_models and
# is not visible here — confirm it before relying on the saved files.
trainer.save_models()

In [None]:
# Corner plot for one sample.
# NOTE(review): whether `index` refers to a dataset index or a position in the
# validation split is decided inside Trainer.plot_corner — confirm.
corner_sample_index = 100
trainer.plot_corner(index=corner_sample_index)