In [14]:
%load_ext autoreload
%autoreload 2
from utils import FusionDataset, TrajectoryPreservingSampler, generate, plot_1d_statistic_over_time
from models import Forward, Posterior, Prior, Decoder
from train_2 import run
import os, pickle
import pandas as pd 
import torch 
from torch.utils.data import Subset
os.environ["KERAS_BACKEND"] = "torch"
import keras

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [22]:
# Target device for the models and tensors created in the cells below (pinned to CPU).
device = torch.device(type="cpu")

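## Input parameter: magnetic field (B)

Load the B-field values from `b-field.csv` and flatten them into a 1-D `float32` tensor on `device`.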

In [23]:
b_field = pd.read_csv('b-field.csv', delimiter=',', index_col=0)
b_field_values = b_field.values.flatten()  # Convert to a 1D Numpy array
b_field = torch.tensor(b_field_values, dtype=torch.float32).to(device)


In [None]:
# Seed Python's `random`, NumPy and the Keras backend (torch) RNGs before any weights
# are initialised. This makes weight init, and any later sampling that draws from these
# generators, reproducible under Restart & Run All. Change SEED to get a different run.
SEED = 0
keras.utils.set_random_seed(SEED)

# Instantiate models on the chosen device
forward_t = Forward().to(device)
forward_tplus1 = Forward().to(device)
prior = Prior().to(device)
posterior = Posterior().to(device)
decoder = Decoder().to(device)

# Instantiate optimizer (Keras AdamW with its default hyper-parameters)
opt = keras.optimizers.AdamW()

# Get data: one TrajectoryPreservingSampler per split.
batch_size = 8
train_loader = TrajectoryPreservingSampler(FusionDataset(data_dir="./data/train", max_instances=30), batch_size=batch_size)
val_loader = TrajectoryPreservingSampler(FusionDataset(data_dir="./data/val", max_instances=10), batch_size=batch_size)
# NOTE(review): despite its name, `test_ds` is a sampler like the loaders above, and it uses
# its own batch size (16) rather than `batch_size` — confirm that is intentional.
test_ds = TrajectoryPreservingSampler(FusionDataset(data_dir="./data/test"), batch_size=16)

# Run training. All artefacts for this run go under save_dir. Create it up front so
# later writes into it (the figures / pickle in the evaluation step, and anything
# `run` saves there) cannot fail on a missing directory.
save_dir = "./results/basic0"
os.makedirs(save_dir, exist_ok=True)

# `batch_size` evenly spaced points on [0, 500], shape (batch_size, 1).
# Created directly on `device` and marked requires_grad as the *last* step, so it is a
# leaf tensor. Gradients w.r.t. it can then be taken and accumulate in `x_tensor.grad`.
# The original `requires_grad` -> `.unsqueeze` -> `.to` chain produced a non-leaf tensor.
# NOTE(review): the grid size is tied to the training batch size — confirm that coupling.
x_tensor = torch.linspace(0, 500, steps=batch_size, device=device).unsqueeze(1).requires_grad_(True)

# NOTE(review): the trailing positional 10 is presumably the number of epochs — confirm in train_2.run.
run(train_loader, val_loader, forward_t, forward_tplus1, prior, posterior, decoder, opt, b_field, x_tensor, save_dir, 10)

# Evaluate by generating several trajectories from a single test starting point.
# NOTE(review): `next(iter(test_ds))` is the first item a fresh iterator yields. It is
# only a random pick if TrajectoryPreservingSampler shuffles — confirm.
num_generated = 6
trajectory = next(iter(test_ds))
trajectory_hats = []
for i in range(num_generated):
    # Detach immediately so the stored samples do not keep their autograd graphs alive.
    trajectory_hat = keras.ops.concatenate(generate(trajectory, forward_t, prior, decoder)).detach()
    trajectory_hats.append(trajectory_hat)
    # Save each trajectory as a figure.
    # TODO: the third argument is a placeholder label; replace it with the real name of
    # the plotted quantity (index 0).
    # NOTE(review): figures are never closed. If `fig` is a pyplot figure, closing it
    # after saving would avoid accumulating open figures.
    fig = plot_1d_statistic_over_time(trajectory_hat.cpu(), 0, "I don't know what this variable is")
    fig.savefig(f"{save_dir}/gen_{i}")
# Save all samples as one stacked tensor (num_generated along dim 0). Move it to the CPU
# and keep it free of autograd state, so the pickle can be loaded anywhere.
trajectory_hats = keras.ops.stack(trajectory_hats).cpu()
with open(f"{save_dir}/generated_trajectories.pkl", "wb") as file:
    pickle.dump(trajectory_hats, file)

