In [16]:
%load_ext autoreload
%autoreload 2
from utils import dataloader, FusionDataset, generate, plot_1d_statistic_over_time
from models import Forward, Posterior, Prior, Decoder
from train import run
import os, pickle
os.environ["KERAS_BACKEND"] = "torch"
import keras

In [None]:
# Training hyperparameters, defined once so the train and val loaders cannot
# silently end up with different batch sizes.
LEARNING_RATE = 1e-4
BATCH_SIZE = 16
# NOTE(review): passed as the last positional argument of `train.run`;
# presumably the number of training epochs — confirm against its signature.
NUM_EPOCHS = 100

# Instantiate models (two separate Forward instances for steps t and t+1)
forward_t = Forward()
forward_tplus1 = Forward()
prior = Prior()
posterior = Posterior()
decoder = Decoder()

# Instantiate optimizer
opt = keras.optimizers.AdamW(learning_rate=LEARNING_RATE)

# Get data: batched loaders for train/val, an indexable dataset for test
train_loader = dataloader(data_dir="./data/train", batch_size=BATCH_SIZE)
val_loader = dataloader(data_dir="./data/val", batch_size=BATCH_SIZE)
test_ds = FusionDataset(data_dir="./data/test")

# Run training; `run` is expected to write the trained models into `save_dir`,
# which is where they are loaded from afterwards.
save_dir = "./results/basic0"
run(train_loader, val_loader, forward_t, forward_tplus1, prior, posterior, decoder, opt, save_dir, NUM_EPOCHS)

# Load trained models.
# Keras 3's `load_model` accepts only native `.keras` files (or legacy `.h5`),
# so an extension-less path like f"{save_dir}/forward_t" raises ValueError.
# Prefer a suffixed file when one exists, and fall back to the bare path.
# NOTE(review): confirm which suffix `train.run` actually saves with.
def _load_trained_model(directory, name):
    """Load the model saved as `name` inside `directory`."""
    base_path = f"{directory}/{name}"
    for candidate in (f"{base_path}.keras", f"{base_path}.h5"):
        if os.path.exists(candidate):
            return keras.models.load_model(candidate)
    return keras.models.load_model(base_path)


forward_t = _load_trained_model(save_dir, "forward_t")
prior = _load_trained_model(save_dir, "prior")
decoder = _load_trained_model(save_dir, "decoder")

# Evaluate by generating multiple trajectories from one random starting point.
# The same `trajectory` seeds every generation.
NUM_GENERATIONS = 6
trajectory = test_ds.get_trajectory()
trajectory_hats = []
for i in range(NUM_GENERATIONS):
    # `generate` returns something `keras.ops.concatenate` can join
    # (presumably a list of per-step tensors — confirm).
    trajectory_hat = keras.ops.concatenate(generate(trajectory, forward_t, prior, decoder))
    trajectory_hats.append(trajectory_hat)
    # Save trajectory as figure. `.detach().cpu()` is torch-specific, so this
    # relies on KERAS_BACKEND == "torch".
    # TODO: replace the placeholder label once the variable plotted at index 0 is identified.
    fig = plot_1d_statistic_over_time(trajectory_hat.detach().cpu(), 0, "I don't know what this variable is")
    fig.savefig(f"{save_dir}/gen_{i}")

# Save trajectories as tensors, stacked along a new leading axis of length NUM_GENERATIONS.
# NOTE(review): this pickles a backend (torch) tensor, possibly resident on GPU,
# so unpickling needs torch (and CUDA if so). keras.ops.convert_to_numpy would
# produce a portable array if that matters to consumers.
trajectory_hats = keras.ops.stack(trajectory_hats)
with open(f"{save_dir}/generated_trajectories.pkl", "wb") as file:
    pickle.dump(trajectory_hats, file)