In [None]:
import pickle as pkl

import numpy as np
import plotly.io as pio
import torch
from plotly import colors

from time_series_prediction import settings, utils

# Make plotly's white template the default for every figure in this notebook.
pio.templates.default = 'plotly_white'

In [None]:
# To run on the GPU instead, call settings.switch_device('cuda') in this cell.
# The expression below only displays whether the active device is the CPU;
# it does not change the device.
torch.device('cpu') == settings.device

In [None]:
# Hidden-layer widths passed to every SweepEnKF constructed below.
hidden_layers = [6, 10, 6]

In [None]:
# Sweep over the activation function. Each name is listed twice, so the
# sweep gets two models per activation (presumably independent repeats;
# confirm in utils.SweepEnKF).
lorenz_sweep_values = [act for act in ('ReLU', 'ELU', 'Tanh') for _ in range(2)]
lorenz = utils.SweepEnKF(
    'lorenz', 'activation_function', lorenz_sweep_values,
    train_sweep=False,
    noise=5e-2,
    hidden_layers=hidden_layers,
    train_kwargs=dict(lr_hold=10, subseq_len=50),
)

In [None]:
# Train the models of the Lorenz sweep. The 100 is forwarded to
# SweepEnKF.train (presumably an epoch/iteration count; confirm in utils).
lorenz.train(100)

In [None]:
# Give each activation function a fixed colour in the Lorenz training figure.
# The colours are three evenly spaced samples of the Plasma scale over
# [0, 0.9]; stopping short of 1.0 presumably avoids Plasma's pale-yellow end
# on the white template.
# Traces are matched on their legendgroup, which must be one of the keys below
# (an unknown group raises KeyError).
# `colors` and `np` are imported in the top-of-notebook import cell.
plasma_colors = colors.sample_colorscale(
    colors.get_colorscale('Plasma'),
    np.linspace(0, 0.9, 3),
)
color_by_group = {
    'activation_function=ReLU': plasma_colors[0],
    'activation_function=ELU': plasma_colors[1],
    'activation_function=Tanh': plasma_colors[2],
}
lorenz.train_fig.for_each_trace(
    lambda trace: trace.update(line_color=color_by_group[trace.legendgroup])
)

In [None]:
# Plot test trajectory 0. Label the x/y/z state panels, clear the x-axis titles
# in column 1, then put 't' on row 5 only. Row 5 being the bottom panel is
# assumed from the original layout; confirm against plot_test.
lorenz_test_fig = lorenz.plot_test(0)
for panel_row, state_name in enumerate(('x', 'y', 'z'), start=1):
    lorenz_test_fig.update_yaxes(row=panel_row, col=1, title_text=state_name)
lorenz_test_fig.update_xaxes(col=1, title_text='')
lorenz_test_fig.update_xaxes(col=1, row=5, title_text='t')
lorenz_test_fig

In [None]:

# Re-run the EnKF evaluation on test trajectory 0 twenty times. The repetition
# suggests enkf_err is stochastic (presumably ensemble sampling). Keep the last
# entry of its 6-tuple (called mae_lyap in this notebook), stack the runs, and
# display their mean over the run axis.
with torch.no_grad():
    err_lorenz = torch.stack(
        [lorenz.enkf_err(0)[5] for _ in range(20)]
    )
err_lorenz.mean(axis=0)

In [None]:
# Same activation-function sweep as the other systems: each activation is
# listed twice (presumably two repeats per setting; confirm in utils).
rossler = utils.SweepEnKF(
    'rossler',
    'activation_function',
    ['ReLU'] * 2 + ['ELU'] * 2 + ['Tanh'] * 2,
    train_sweep=False,
    noise=5e-2,
    hidden_layers=hidden_layers,
    train_kwargs=dict(lr_hold=10, subseq_len=50),
)


In [None]:
# Train the models of the Rossler sweep. The 100 is forwarded to
# SweepEnKF.train (presumably an epoch/iteration count; confirm in utils).
rossler.train(100)

In [None]:
# Colour the Rossler training curves per activation function. The colours are
# three evenly spaced Plasma samples over [0, 0.9], keyed by each trace's
# legendgroup ('activation_function=<name>').
rossler_palette = colors.sample_colorscale(
    colors.get_colorscale('Plasma'), np.linspace(0, 0.9, 3)
)
rossler_group_colors = dict(zip(
    (
        'activation_function=ReLU',
        'activation_function=ELU',
        'activation_function=Tanh',
    ),
    rossler_palette,
))


def _recolor_rossler_trace(trace):
    """Set a trace's line colour from its legend group."""
    trace.update(line_color=rossler_group_colors[trace.legendgroup])


rossler.train_fig.for_each_trace(_recolor_rossler_trace)

In [None]:
# Plot test trajectory 0 of the Rossler sweep, label the state panels, and
# keep the 't' axis title only on row 5 (assumed bottom panel; confirm).
rossler_test_fig = rossler.plot_test(0)
rossler_axis_titles = {1: 'x', 2: 'y', 3: 'z'}
for row_idx, axis_title in rossler_axis_titles.items():
    rossler_test_fig.update_yaxes(row=row_idx, col=1, title_text=axis_title)
(
    rossler_test_fig
    .update_xaxes(col=1, title_text='')
    .update_xaxes(col=1, row=5, title_text='t')
)

In [None]:

# Average the EnKF test error of the Rossler sweep over repeated runs on test
# trajectory 0. The repetition suggests enkf_err is stochastic (presumably
# ensemble sampling).
# This must evaluate `rossler`: calling `lorenz.enkf_err` here would just
# re-measure the Lorenz model under the Rossler label.
n_repeats = 20
err_rossler_runs = []
with torch.no_grad():
    for _ in range(n_repeats):
        # enkf_err returns a 6-tuple; only the last entry (called mae_lyap in
        # this notebook) is used.
        _, _, _, _, _, mae_lyap_r = rossler.enkf_err(0)
        err_rossler_runs.append(mae_lyap_r)
    err_rossler = torch.stack(err_rossler_runs)
err_rossler.mean(axis=0)

In [None]:
# Activation-function sweep for the 'vdp' system. Unlike the other two systems
# it passes additional_states=4 (extra model states; see utils.SweepEnKF).
vdp_activations = []
for act_name in ('ReLU', 'ELU', 'Tanh'):
    vdp_activations.extend([act_name, act_name])
vdp = utils.SweepEnKF(
    'vdp',
    'activation_function',
    vdp_activations,
    train_sweep=False,
    noise=5e-2,
    hidden_layers=hidden_layers,
    train_kwargs=dict(lr_hold=10, subseq_len=50),
    additional_states=4,
)

In [None]:
# Train the models of the 'vdp' sweep. The 100 is forwarded to
# SweepEnKF.train (presumably an epoch/iteration count; confirm in utils).
vdp.train(100)

In [None]:
# Colour the 'vdp' training curves per activation function, using the same
# Plasma palette as the other sweeps: three evenly spaced samples over [0, 0.9].
# Traces are matched on their legendgroup (an unknown group raises KeyError).
# `colors` and `np` come from the top-of-notebook import cell; re-importing
# them here was redundant.
vdp_plasma_colors = colors.sample_colorscale(
    colors.get_colorscale('Plasma'),
    np.linspace(0, 0.9, 3),
)
vdp_color_by_group = {
    'activation_function=ReLU': vdp_plasma_colors[0],
    'activation_function=ELU': vdp_plasma_colors[1],
    'activation_function=Tanh': vdp_plasma_colors[2],
}
vdp.train_fig.for_each_trace(
    lambda trace: trace.update(line_color=vdp_color_by_group[trace.legendgroup])
)

In [None]:
# Plot test trajectory 0 of the 'vdp' sweep, label the first four panels (LaTeX
# titles, rendered via MathJax), put 't' only on row 6 (assumed bottom panel;
# confirm), and save the figure as standalone HTML that loads MathJax from a CDN.
vdp_test_fig = vdp.plot_test(0)
vdp_panel_titles = [r'$\frac{dx}{dt}$', 'x', '$s_1$', '$s_2$']
for panel_row, panel_title in enumerate(vdp_panel_titles, start=1):
    vdp_test_fig.update_yaxes(row=panel_row, col=1, title_text=panel_title)
vdp_test_fig.update_xaxes(title_text='', col=1)
vdp_test_fig.update_xaxes(title_text='t', row=6, col=1)
vdp_test_fig.write_html('enkf_vdp_best.html', include_mathjax='cdn')