In [None]:
import torch
import numpy as np
from torch.utils.data import DataLoader
from vae.datasets import VolSurfaceDataSetDict, VolSurfaceExFeatsDataSet
from vae.cvae import CVAE
from vae.utils import *

In [None]:
# Seed RNGs through the project helper before any model or DataLoader is created
# (presumably covers torch/numpy/random — see vae.utils.set_seeds), then switch the
# default floating-point dtype to double so new tensors and parameters are float64.
set_seeds(0)
torch.set_default_dtype(torch.double)

In [None]:
num_epochs = 100  # training epochs for each model trained below
seq_len = 5  # number of consecutive surfaces per sample window
ctx_len = seq_len - 1  # surfaces used as conditioning context; the window's last one is the target

In [None]:
# Load the volatility surfaces and split them chronologically:
#   rows [0, train_end) -> train, [train_end, valid_end) -> validation, [valid_end, end) -> test.
# The NpzFile is used as a context manager so its file handle is closed once the
# "surface" array has been read into memory.
train_end, valid_end = 4000, 5000
with np.load("data/vol_surface_with_ret.npz") as data:
    vol_surf_data = data["surface"]
train_simple = DataLoader(VolSurfaceDataSetDict(vol_surf_data[:train_end], seq_len), shuffle=True, batch_size=64)
valid_simple = DataLoader(VolSurfaceDataSetDict(vol_surf_data[train_end:valid_end], seq_len), shuffle=True, batch_size=16)
test_simple = DataLoader(VolSurfaceDataSetDict(vol_surf_data[valid_end:], seq_len), shuffle=True, batch_size=16)

# Conv version: CVAE with convolutional surface layers (`use_dense_surface=False`)

In [None]:
# Convolutional variant: use_dense_surface=False, small hidden widths, 25-dim latent.
config = dict(
    seq_len=seq_len,
    feat_dim=(5, 5),
    latent_dim=25,
    device="cuda",
    kl_weight=1,
    surface_hidden=[5, 5, 5],
    ctx_len=ctx_len,
    ctx_surface_hidden=[5, 5, 5],
    ctx_embedding=100,
    use_dense_surface=False,
)
model = CVAE(config)
print(model)

In [None]:
# Train the conv model; its checkpoint is presumably written to
# model_dir/file_name, which is the path the following cells load.
train(
    model,
    train_simple,
    valid_simple,
    epochs=num_epochs,
    lr=1e-5,
    model_dir="test_spx/no_mem",
    file_name="conv2d_spx.pt",
)

In [None]:
# Evaluate the saved conv checkpoint.
# NOTE(review): `test_simple` (rows 5000 onward) is built in the data cell but never
# used; confirm whether `test` should receive (valid_simple, test_simple) instead.
test(
    model,
    train_simple,
    valid_simple,
    "test_spx/no_mem/conv2d_spx.pt",
)

In [None]:
# Condition the conv model on the first `ctx_len` surfaces of the test period,
# generate the following surface, and plot it against the actual surface at that
# step (argument order assumed to be (actual, generated) — confirm in vae.utils).
with torch.no_grad():  # inference only; no autograd graph is needed
    surf = model.get_surface_given_conditions({"surface": torch.from_numpy(vol_surf_data[5000:5000+ctx_len])})
# Reshape to the grid shape the model was configured with rather than a hard-coded (5, 5).
surf = surf.detach().cpu().numpy().reshape(config["feat_dim"])
plot_surface_separate(vol_surf_data[5000+ctx_len], surf)

In [None]:
# Reload the conv checkpoint and simulate 30 steps starting at index 5000
# (the start of the test split), then plot the simulated surfaces over time.
model_data = torch.load("test_spx/no_mem/conv2d_spx.pt")
sim = generate_surface_spx(
    vol_surf_data,
    None,
    model_data,
    start_time=5000,
    steps_to_sim=30,
    model_type=CVAE,
)
plot_surface_time_series(sim)

# Dense version: CVAE with fully connected surface layers (`use_dense_surface=True`)

In [None]:
# Dense variant: use_dense_surface=True, wider hidden layers, 100-dim latent.
config = dict(
    seq_len=seq_len,
    feat_dim=(5, 5),
    latent_dim=100,
    device="cuda",
    kl_weight=1,
    surface_hidden=[100, 200, 200],
    ctx_len=ctx_len,
    ctx_surface_hidden=[100, 200, 200],
    ctx_embedding=100,
    use_dense_surface=True,
)
model2 = CVAE(config)
print(model2)

In [None]:
# Train the dense model; its checkpoint is presumably written to
# model_dir/file_name, which is the path the following cells load.
train(
    model2,
    train_simple,
    valid_simple,
    epochs=num_epochs,
    lr=1e-5,
    model_dir="test_spx/no_mem",
    file_name="dense_spx.pt",
)

In [None]:
# Evaluate the saved dense checkpoint.
# NOTE(review): `test_simple` (rows 5000 onward) is built in the data cell but never
# used; confirm whether `test` should receive (valid_simple, test_simple) instead.
test(
    model2,
    train_simple,
    valid_simple,
    "test_spx/no_mem/dense_spx.pt",
)

In [None]:
# Condition the DENSE model (model2) on the first `ctx_len` surfaces of the test
# period, generate the following surface, and plot it against the actual surface at
# that step (argument order assumed to be (actual, generated) — confirm in vae.utils).
# Fix: this cell previously called `model` (the conv CVAE), so the dense section
# re-plotted the conv model's output instead of the dense model's.
with torch.no_grad():  # inference only; no autograd graph is needed
    surf = model2.get_surface_given_conditions({"surface": torch.from_numpy(vol_surf_data[5000:5000+ctx_len])})
# Reshape to the grid shape the model was configured with rather than a hard-coded (5, 5).
surf = surf.detach().cpu().numpy().reshape(config["feat_dim"])
plot_surface_separate(vol_surf_data[5000+ctx_len], surf)

In [None]:
# Reload the dense checkpoint and simulate 30 steps starting at index 5000
# (the start of the test split), then plot the simulated surfaces over time.
model_data = torch.load("test_spx/no_mem/dense_spx.pt")
sim = generate_surface_spx(
    vol_surf_data,
    None,
    model_data,
    start_time=5000,
    steps_to_sim=30,
    model_type=CVAE,
)
plot_surface_time_series(sim)