In [1]:
import torch
import torch.optim as optim
import numpy as np
from importlib import reload

In [2]:
import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec

In [3]:
# Project-local libraries. Both paths are resolved relative to the notebook's
# current working directory: the repo root and a sibling `spaths` checkout.
import sys, os
print(f"{sys.executable = }")
sys.path.extend(os.path.abspath(rel_path) for rel_path in ('../', '../../spaths'))
import spaths
from sf_nets import datasets
from src import systems, nets, training

sys.executable = '/Users/pzielinski/opt/anaconda3/envs/sf_nets/bin/python'


## Load simulation data & solver

In [4]:
# Simulation metadata (the keys are listed in the output below). Its 'solver' and 'dt'
# entries are recorded in `params` further down.
# NOTE(review): this reads from `rqp4_old/`, while the train/valid splits are loaded
# from `rqp4/` — confirm the solver/dt stored here match the data actually trained on.
# NOTE(review): torch.load unpickles arbitrary objects; only load trusted files.
simdat = torch.load("../data/rqp4_old/simdat.pt")
simdat.keys()

dict_keys(['seed', 'solver', 'dt', 'x0', 'tspan'])

In [5]:
# System definition from the project's `systems` module.
# `rqp4_sde` is handed to the trainer as params['sde'], and `rqp4_dat` supplies
# the 'name', 'ndim' and 'sdim' metadata used when building `params` / `net_arch`.
rqp4_sde = systems.rqp4_sde
rqp4_dat = systems.rqp4_dat

## Train

In [6]:
model_id = "simple_7"  # run identifier, forwarded to the trainer as params['model_id']

In [7]:
# Project dataset wrapper rooted at ../data/. In this part of the notebook it is
# only used to inspect a sample.
# NOTE(review): training below uses the splits loaded from ../data/rqp4/{train,valid}.pt,
# not this object — confirm both sources hold the same data.
train_ds = datasets.RQP4(root="../data/")

In [8]:
# Peek at one sample: per the output it is a pair of tensors (a 4-vector and a 4x4 matrix).
train_ds[0]

(tensor([ 0.2896, -0.1219, -0.1329, -0.6876]),
 tensor([[ 4.0528e-01, -3.8182e-03, -2.5143e-03, -8.7549e-03],
         [-3.8182e-03,  9.8282e-01,  6.3884e-01,  2.3917e+00],
         [-2.5143e-03,  6.3884e-01,  1.1035e+00,  2.6550e+00],
         [-8.7549e-03,  2.3917e+00,  2.6550e+00,  9.9840e+00]]))

In [50]:
# Use only the first `train_size` / `valid_size` samples of the pre-generated splits.
train_size = 4000
valid_size = 1300

# NOTE(review): torch.load unpickles arbitrary objects — only load trusted files.
train_full = torch.load("../data/rqp4/train.pt")
valid_full = torch.load("../data/rqp4/valid.pt")

# Subset does not bounds-check its indices, so a split shorter than requested
# would only surface as an IndexError partway through training; fail fast here.
assert len(train_full) >= train_size, (
    f"train.pt holds {len(train_full)} samples, fewer than train_size={train_size}"
)
assert len(valid_full) >= valid_size, (
    f"valid.pt holds {len(valid_full)} samples, fewer than valid_size={valid_size}"
)

train_dataset = torch.utils.data.Subset(train_full, list(range(train_size)))
valid_dataset = torch.utils.data.Subset(valid_full, list(range(valid_size)))

In [57]:
# Optimisation hyperparameters (all forwarded to the trainer through `params`).
batch_size, max_epochs, learning_rate = 16, 200, 1e-3

In [58]:
# Run configuration handed to training.Simple: the identifiers and hyperparameters
# defined above, plus system / simulation settings taken from `rqp4_dat` and `simdat`.
params = dict(
    model_id=model_id,
    batch_size=batch_size,
    max_epochs=max_epochs,
    burst_size=10_000,
    system=rqp4_dat['name'],
    sde=rqp4_sde,
    solver=simdat['solver'],
    burst_dt=simdat['dt'],
    train_nsam=train_size,
    valid_nsam=valid_size,
    learning_rate=learning_rate,
)

In [59]:
# Network architecture: input/latent widths come from the system metadata;
# `hidden_dimensions` presumably lists the hidden-layer sizes (check src.nets).
net_arch = dict(
    in_features=rqp4_dat['ndim'],
    latent_features=rqp4_dat['sdim'],
    hidden_dimensions=[8, 4],
)
params.update(net_arch=net_arch)

In [60]:
# Seed torch's global RNG right before building the model so that weight
# initialisation (and, presumably, batch shuffling inside `train`) is reproducible
# under Restart & Run All. Change SEED to obtain an independent run.
SEED = 0
torch.manual_seed(SEED)
simple = training.Simple(params)

In [61]:
# Fit the model on the training subset. The captured log below reports the
# reconstruction loss and the validation loss at epoch 1 and every 10 epochs after.
simple.train(train_dataset, valid_dataset)

epoch :   1/200, reconstruction loss = 1.89668, validation loss = 1.98741
epoch :  10/200, reconstruction loss = 1.14960, validation loss = 1.24190
epoch :  20/200, reconstruction loss = 1.10413, validation loss = 1.22341
epoch :  30/200, reconstruction loss = 1.08264, validation loss = 1.18597
epoch :  40/200, reconstruction loss = 1.06695, validation loss = 1.12815
epoch :  50/200, reconstruction loss = 1.02915, validation loss = 1.10361
epoch :  60/200, reconstruction loss = 0.97206, validation loss = 0.99467
epoch :  70/200, reconstruction loss = 0.93838, validation loss = 1.01775
epoch :  80/200, reconstruction loss = 0.91424, validation loss = 0.94544
epoch :  90/200, reconstruction loss = 0.89311, validation loss = 0.91199
epoch : 100/200, reconstruction loss = 0.87983, validation loss = 0.91611
epoch : 110/200, reconstruction loss = 0.87482, validation loss = 0.88941
epoch : 120/200, reconstruction loss = 0.87067, validation loss = 0.90268
epoch : 130/200, reconstruction loss =