In [1]:
import numpy as np 
from scipy.integrate import odeint
import os, sys, warnings
from pathlib import Path
from os.path import dirname, realpath
# Put the `modules` folder that sits next to the parent of the current working
# directory at the front of the import path, so the project-local imports
# below this point can be resolved.
script_dir = Path(realpath('.')).parent
module_dir = str(script_dir)
sys.path.insert(0, f'{module_dir}/modules')
import utility as ut
import surrogate_nn as srnn
import matplotlib.pyplot as plt
import pandas as pd
import torch
from torch import nn
from torch.optim import lr_scheduler
from torch.utils.data import DataLoader, TensorDataset
warnings.filterwarnings('ignore')

In [2]:
# Folder holding the pre-generated L63 (presumably Lorenz-63 -- confirm)
# trajectory files, and the number of training columns to keep.
L63_data_path = '../data/L63-trajectories'
N = 20000

# Keep the first N columns of the training array and the first 1000 entries
# along the last axis of the test array; both are cast to float32.
train = np.load(f'{L63_data_path}/train.npy')[:, :N].astype(np.float32)
test = np.load(f'{L63_data_path}/test.npy')[:, :, :1000].astype(np.float32)

# One-step-ahead pairs: column k of `train` is an input, column k+1 its target.
current_states = torch.Tensor(train[:, :-1].T)
next_states = torch.Tensor(train[:, 1:].T)
train_dataset = TensorDataset(current_states, next_states)
# Batches of N/10 samples, served in storage order (shuffling is not enabled).
train_dataloader = DataLoader(train_dataset, batch_size=N // 10)

# Surrogate network from the project-local `surrogate_nn` module; the two
# arguments are presumably the state dimension and the hidden width -- confirm.
model = srnn.SurrogateModel_NN(3, 1024)
x = torch.rand(3)

In [None]:
# Fit the surrogate model on the training array; the progress lines below
# show the loss being reported every 100 epochs.
training_options = dict(epochs=10000, learning_rate=1e-4)
model.learn(train, **training_options)

epoch: 0    loss: 16005862.000000     time elapsed=0.0701
epoch: 100    loss: 1328855.625000     time elapsed=6.0199
epoch: 200    loss: 579925.687500     time elapsed=12.0289
epoch: 300    loss: 109313.976562     time elapsed=18.0863
epoch: 400    loss: 34549.503906     time elapsed=24.1609
epoch: 500    loss: 17422.062500     time elapsed=30.2298
epoch: 600    loss: 10629.001953     time elapsed=36.3187
epoch: 700    loss: 7022.848145     time elapsed=42.3969
epoch: 800    loss: 5028.617188     time elapsed=48.4887
epoch: 900    loss: 3828.250000     time elapsed=55.0033
epoch: 1000    loss: 3058.062988     time elapsed=61.8323
epoch: 1100    loss: 2347.989014     time elapsed=68.2926
epoch: 1200    loss: 1924.561768     time elapsed=74.3979
epoch: 1300    loss: 1603.679077     time elapsed=80.5894
epoch: 1400    loss: 1357.071167     time elapsed=86.7339
epoch: 1500    loss: 1179.967285     time elapsed=92.8014
epoch: 1600    loss: 1031.718994     time elapsed=98.8585
epoch: 1700   

In [None]:
# Evaluate the model on the first 500 entries of the test array with an error
# threshold of 0.05. The four return values are unpacked by position;
# `tau_f_*` are presumably forecast horizons under two error measures -- confirm
# against SurrogateModel_NN.compute_tau_f.
test_subset = test[:500]
tau_f_rmse, tau_f_se, rmse, se = model.compute_tau_f(test_subset, error_threshold=0.05)

In [None]:
# Density-normalised histogram of the `tau_f_se` values from the test evaluation.
plt.hist(tau_f_se, density=True)
plt.xlabel(r'$\tau_f$ (NN)')
# savefig does not create missing folders and would raise if the plots
# directory is absent (e.g. on a fresh checkout), so create it first.
os.makedirs('../data/plots', exist_ok=True)
plt.savefig('../data/plots/tau_f_NN.png')
# plt.title(f'training_data_size={N}')

In [None]:
# Re-evaluate the model on its own training data, cut into consecutive
# segments of `segment_length` time steps.
#
# `train` is stored as (variables, time). compute_tau_f is called elsewhere in
# this notebook with an array sliced as [trajectory, :, time], so each segment
# must be laid out as (variables, segment_length). Reshaping `train.T` directly
# to (-1, 3, 800) would instead interleave x, y, z values along the last axis.
# Splitting the time axis first and then moving the segment axis to the front
# keeps every variable's time series intact:
#   train_segments[k, d, t] == train[d, k * segment_length + t]
# As before, the number of training columns must be a multiple of segment_length.
segment_length = 800
train_segments = np.ascontiguousarray(
    train.reshape(train.shape[0], -1, segment_length).transpose(1, 0, 2)
)
tau_f_rmse, tau_f_se, rmse, se = model.compute_tau_f(train_segments)

plt.hist(tau_f_se)
plt.xlabel(r'$\tau_f$ (NN), for train data')
plt.title(f'training_data_size={N}')

In [None]:
# Mean of the `tau_f_se` values returned by the most recent compute_tau_f call.
tau_f_se.mean()

In [None]:
# Scratch inspection: dump every (input, target) batch the loader yields.
# After the loop, `batch`, `X` and `y` keep the values of the final batch.
for batch, (X, y) in enumerate(train_dataloader):
    batch_pair = (X, y)
    print(*batch_pair)

In [None]:
# Scratch inspection: show the training array transposed (one row per column
# of `train`). Note this displays the whole array, not a preview.
train.T

In [None]:
import torch.nn.functional as F

In [None]:
# Stand-alone affine layer mapping 3 inputs to 300 outputs (with bias), used
# for the scratch experiment in the next cell.
l = nn.Linear(in_features=3, out_features=300)

In [None]:
# Scratch experiment: push one random 3-vector through the linear layer `l`
# and apply an element-wise tanh to the result.
sample_input = torch.rand(3)
torch.tanh(l(sample_input))

In [None]:
# Scratch inspection: show the repr of the TensorDataset built from `train`.
train_dataset

In [None]:
# Walk through every batch but print only the second one (index 1).
# The loop is not cut short, so once it finishes `X` and `y` hold the LAST
# batch from the loader -- the cells below operate on those values.
for batch, (X, y) in enumerate(train_dataloader):
    if batch != 1:
        continue
    print(X, y)

In [None]:
# Network output for the input batch `X` left over from the preceding loop.
model.net(X)

In [None]:
# Network output when it is fed the target batch `y` (the successor states of
# `X` in the dataset) instead of the inputs.
model.net(y)

In [None]:
# Mean squared error between the network's prediction for `X` and the target
# batch `y`, using the default MSELoss settings.
mse_criterion = nn.MSELoss()
mse_criterion(model.net(X), y)

In [None]:
# Mean over the batch of the per-sample Euclidean distance between the
# network's prediction for `X` and the target `y`.
residual = model.net(X) - y
per_sample_distance = residual.pow(2).sum(dim=1).pow(0.5)
per_sample_distance.mean()

In [None]:
# Scratch arithmetic: square of 242.81 -- presumably a distance value read off
# an earlier output, squared for comparison with an MSE figure; confirm.
242.81**2

In [None]:
# List the attribute names available on `model.net.W`.
# The original call used `attrs`, a name that is neither defined nor imported
# in this notebook and therefore raises NameError on a fresh kernel; the
# builtin `dir` is used instead.
# NOTE(review): if `attrs` was meant to be a project helper, import it
# explicitly and restore the call.
dir(model.net.W)

In [None]:
# Small stand-alone network for experimentation: an affine layer with bias,
# a tanh non-linearity, then a bias-free linear read-out back to 3 outputs.
# Layers are left unnamed so the state-dict keys stay '0.*' and '2.*'.
hidden_width = 300
seq = nn.Sequential(
    nn.Linear(in_features=3, out_features=hidden_width, bias=True),
    nn.Tanh(),
    nn.Linear(in_features=hidden_width, out_features=3, bias=False),
)

In [None]:
# Print the name and raw values of every parameter of `seq` that is trainable
# (i.e. has requires_grad set); frozen parameters are skipped.
for name, param in seq.named_parameters():
    if not param.requires_grad:
        continue
    print(name, param.data)

In [None]:
# Sum of squared entries of the final layer's weight matrix in `seq`
# (the layer registered under index 2).
readout_weight = seq.state_dict()['2.weight']
readout_weight.pow(2).sum()

In [None]:
# Compare a free-running multi-step forecast of the model with the validation
# trajectory, one panel per state variable.
#
# NOTE: this rebinds the global `N` (the training-set size in the data-loading
# cell) to the number of forecast steps shown here.
fig, N, dt = plt.figure(figsize=(8, 8)), 500, 0.02
# Build the time axis from an integer range so it always has exactly N points;
# np.arange with a fractional step can come out one element longer or shorter
# than intended, which would break plotting against the `[:N]` slices below.
# Dividing by 1/0.91 presumably rescales time to Lyapunov units (0.91 being the
# largest Lyapunov exponent of the system) -- confirm.
t = np.arange(N) * dt / (1/0.91)
axs = [fig.add_subplot(3, 1, row) for row in (1, 2, 3)]
dims = ['x', 'y', 'z']
u = np.load('{}/validation.npy'.format(L63_data_path)).astype(np.float32)

# `idx` and `labels` are not used in this cell; they are kept so that any other
# cell relying on them still finds them.
idx = [[4, 14, 18], [13, 17, 46], [13, 2, 31], ]
labels = ['good', 'medium', 'bad']
# Start the forecast from the first column of the validation array and roll
# the model forward for N steps.
predicted = model.multistep_forecast(u[:, 0], N)
for i, ax in enumerate(axs):
    ax.plot(t, u[i, :N], label='truth')
    ax.plot(t, predicted[i, :N], label='network')
    ax.legend(loc='upper right')
    # Only the bottom panel carries the shared time label.
    if i == len(dims)-1:
        ax.set_xlabel('t')
    ax.set_ylabel(dims[i])