In [1]:
%load_ext autoreload
%autoreload 2
from utils import dataloader, generate, plot_1d_statistic_over_time
from models import Forward, Posterior, Prior, Decoder, reparameterize
from train import run, val_step
import os, pickle
os.environ["KERAS_BACKEND"] = "torch"
import keras

# Train

In [None]:
# Instantiate models. All five are passed to run(); the evaluation and
# generation cells below only pass forward_t, prior and decoder onward.
forward_t = Forward()
forward_tplus1 = Forward()
prior = Prior()
posterior = Posterior()
decoder = Decoder()

# Hyperparameters and paths
LEARNING_RATE = 1e-4
BATCH_SIZE = 16
# Passed positionally as run()'s last argument; presumably the number of
# epochs — confirm against train.run.
NUM_EPOCHS = 100
DATA_DIR = "./data"
save_dir = "./results/basic0"

# Instantiate optimizer
opt = keras.optimizers.AdamW(LEARNING_RATE)

# Get data (train/val/test splits live in sibling subdirectories of DATA_DIR)
train_loader = dataloader(data_dir=f"{DATA_DIR}/train", batch_size=BATCH_SIZE)
val_loader = dataloader(data_dir=f"{DATA_DIR}/val", batch_size=BATCH_SIZE)
test_loader = dataloader(data_dir=f"{DATA_DIR}/test", batch_size=BATCH_SIZE)

# Later cells load weights from, and write figures/pickles into, save_dir.
# Make sure it exists (harmless if run() also creates it).
os.makedirs(save_dir, exist_ok=True)

# Run training
run(train_loader, val_loader, forward_t, forward_tplus1, prior, posterior, decoder, opt, save_dir, NUM_EPOCHS)

# Evaluate

In [None]:
# Load trained weights into the models instantiated in the training cell.
# Only the three networks passed to val_step/generate are restored;
# forward_tplus1 and posterior are not used past training.
forward_t.load_weights(f"{save_dir}/forward_t.weights.h5")
prior.load_weights(f"{save_dir}/prior.weights.h5")
decoder.load_weights(f"{save_dir}/decoder.weights.h5")

# Evaluate by averaging val_step's result over all test batches.
# Batches are counted explicitly instead of reusing the loop index. That index
# is unbound (or stale from another cell) when the loader yields nothing.
# NOTE(review): this is a mean of per-batch values, so a short final batch is
# weighted like a full one. It assumes val_step returns a per-batch mean — confirm.
test_loss = 0
n_test_batches = 0
for x_t, x_tplus1 in test_loader:
    test_loss += val_step(x_t, x_tplus1, forward_t, prior, decoder)
    n_test_batches += 1
if n_test_batches == 0:
    raise ValueError("test_loader yielded no batches; cannot compute test loss")
test_loss /= n_test_batches
print("Test reconstruction nll:", test_loss)

# Generate
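Run the evaluation cell above first. It loads the trained weights from `save_dir` into `forward_t`, `prior`, and `decoder`, which this section uses to generate trajectories.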

In [None]:
# Evaluate by generating several trajectories from one starting trajectory
# taken from the test set (get_trajectory presumably picks it at random — confirm).
N_GENERATIONS = 6
trajectory = test_loader.dataset.get_trajectory()
print(len(trajectory))

generated = []
for gen_idx in range(N_GENERATIONS):
    # generate(...) yields a sequence of tensors; join them (keras default
    # axis=0) into a single generated trajectory.
    trajectory_hat = keras.ops.concatenate(generate(trajectory, forward_t, prior, decoder))
    generated.append(trajectory_hat)
    # Save trajectory as figure.
    # TODO: replace the placeholder label with the real name of statistic index 0.
    fig = plot_1d_statistic_over_time(trajectory_hat.detach().cpu(), 0, "I don't know what this variable is")
    fig.savefig(f"{save_dir}/gen_{gen_idx}")

# Stack all generations into one tensor. trajectory_hats is left exactly as
# keras.ops.stack returns it, for any downstream use.
trajectory_hats = keras.ops.stack(generated)

# Pickle a detached CPU copy. A pickled CUDA tensor cannot be unpickled on a
# machine without CUDA, and detaching drops any autograd history.
# Only ever load this file from a trusted source (pickle can execute code).
with open(f"{save_dir}/generated_trajectories.pkl", "wb") as file:
    pickle.dump(trajectory_hats.detach().cpu(), file)