In [4]:
%load_ext autoreload
%autoreload 2
from utils import dataloader, generate, plot_1d_statistic_over_time, plot_loss
from models import Forward, Posterior, Prior, Decoder
from train import run, val_step
import os, pickle, random, shutil
os.environ["KERAS_BACKEND"] = "torch"
import keras

# Split data
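Make sure all `.npz` data files are stored somewhere under `./data/`. The cell below shuffles them with a fixed seed and moves them into `./data/train` (60%), `./data/val` (20%) and `./data/test` (20%).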

In [12]:
# Collect every .npz file anywhere under ./data/, including files that a previous
# run of this cell already moved into train/, val/ or test/.
DATA_DIR = "./data/"
SPLIT_SEED = 42
split_dirs = {name: os.path.join(DATA_DIR, name) for name in ("train", "val", "test")}

files = [
    os.path.join(root, name)
    for root, _, names in os.walk(DATA_DIR)
    for name in names
    if name.endswith(".npz")
]

# Only files that are not yet in a split directory get assigned. Re-shuffling files
# that were already split would move some of them (e.g. from train/ into test/) on
# every re-run, leaking training data into evaluation.
split_paths = {os.path.normpath(d) for d in split_dirs.values()}
unsplit_files = [f for f in files if os.path.normpath(os.path.dirname(f)) not in split_paths]

# os.walk yields entries in arbitrary, filesystem-dependent order; sort before the
# seeded shuffle so the split is reproducible across machines and runs.
unsplit_files.sort()
random.seed(SPLIT_SEED)
random.shuffle(unsplit_files)

# 60% train / 20% val / 20% test of the files being assigned in this run.
n_files = len(unsplit_files)
train_files = unsplit_files[:int(0.6 * n_files)]
val_files = unsplit_files[int(0.6 * n_files):int(0.8 * n_files)]
test_files = unsplit_files[int(0.8 * n_files):]

for split_name, split_files in (("train", train_files), ("val", val_files), ("test", test_files)):
    dest_dir = split_dirs[split_name]
    os.makedirs(dest_dir, exist_ok=True)
    for f in split_files:
        # Never overwrite a same-named file that is already in the destination split.
        if not os.path.exists(os.path.join(dest_dir, os.path.basename(f))):
            shutil.move(f, dest_dir)

# Train

In [None]:
# Build the networks: two separate Forward instances, plus prior, posterior and decoder.
forward_t, forward_tplus1 = Forward(), Forward()
prior, posterior, decoder = Prior(), Posterior(), Decoder()

# AdamW optimizer, learning rate 1e-4.
opt = keras.optimizers.AdamW(1e-4)

# One loader per split directory created in "Split data", each with batch size 32.
train_loader, val_loader, test_loader = (
    dataloader(data_dir=f"./data/{split}", batch_size=32)
    for split in ("train", "val", "test")
)

# Train. The Evaluate/plot cells read weights and history.json from save_dir, so
# run() is expected to write them there. The trailing positional arguments (200, 10)
# are passed straight to train.run - TODO: name them once their meaning is confirmed.
save_dir = "./results/basic-minmax"
run(train_loader, val_loader, forward_t, forward_tplus1, prior, posterior, decoder, opt, save_dir, 200, 10)

# Plot the recorded loss curves.
plot_loss(f"{save_dir}/history.json");

In [None]:
def summarize_split(data_dir="./data/"):
    """Count the .npz files in each directory under ``data_dir``.

    Returns a dict mapping directory path -> number of .npz files, sorted by path.
    Raises ValueError if the same file name appears in more than one directory,
    which most likely means a sample ended up in more than one split.
    """
    dirs_by_name = {}
    for root, _, names in os.walk(data_dir):
        for name in names:
            if name.endswith(".npz"):
                dirs_by_name.setdefault(name, []).append(root)

    duplicates = {name: dirs for name, dirs in dirs_by_name.items() if len(dirs) > 1}
    if duplicates:
        raise ValueError(f"Files found in more than one directory: {duplicates}")

    counts = {}
    for dirs in dirs_by_name.values():
        counts[dirs[0]] = counts.get(dirs[0], 0) + 1
    return dict(sorted(counts.items()))


# Sanity-check the split instead of dumping every file path.
split_counts = summarize_split("./data/")
split_counts

# Evaluate

In [None]:
# Load trained weights into the model instances created in the Train cell
# (save_dir also comes from that cell).
forward_t.load_weights(f"{save_dir}/forward_t.weights.h5")
prior.load_weights(f"{save_dir}/prior.weights.h5")
decoder.load_weights(f"{save_dir}/decoder.weights.h5")

# Average val_step's loss over all test batches. The batch count is tracked
# explicitly so an empty loader gives a clear error instead of a NameError.
# NOTE(review): this is a mean over batches; if val_step returns a per-batch mean,
# a smaller final batch is weighted the same as full batches.
test_loss, n_batches = 0, 0
for x_t, x_tplus1 in test_loader:
    test_loss += val_step(x_t, x_tplus1, forward_t, prior, decoder)
    n_batches += 1
if n_batches == 0:
    raise ValueError("test_loader yielded no batches; check the test data directory")
test_loss /= n_batches
print("Test reconstruction loss:", test_loss)

# Generate
Run the Evaluate cell above first: it loads the trained weights into `forward_t`, `prior` and `decoder`, which are used here.

In [None]:
# Draw one ground-truth trajectory from the test set; every generated sample
# below starts from this same trajectory.
trajectory = test_loader.dataset.get_trajectory(pushforward=False)

trajectory_hats = []
for sample_idx in range(6):
    # generate() returns a sequence of tensors; join them into one trajectory.
    steps = generate(trajectory, forward_t, prior, decoder)
    joined = keras.ops.concatenate(steps)
    trajectory_hats.append(joined)

    # Plot channel 0 of the generated trajectory and save the figure to save_dir.
    fig = plot_1d_statistic_over_time(joined.detach().cpu(), 0, "I don't know what this variable is")
    fig.savefig(f"{save_dir}/gen_{sample_idx}")

# Stack all samples along a new leading axis and pickle them next to the weights.
trajectory_hats = keras.ops.stack(trajectory_hats)
with open(f"{save_dir}/generated_trajectories.pkl", "wb") as file:
    pickle.dump(trajectory_hats, file)