In [None]:
from itertools import product

import torch
import plotly.io as pio
import plotly.express as px
import pandas as pd
import numpy as np
from dask.distributed import Client

from time_series_prediction import settings, utils

# Make the white Plotly theme the default template for figures created in this notebook.
pio.templates.default = 'plotly_white'

In [None]:
# Noise level handed to both the ESN sweep (SweepESN) and the EnKF models (SweepEnKF)
# further down. The alternative value previously tried here was 0.
NOISE = 2.0e-5

In [None]:
# Uncomment to run on the GPU instead:
# settings.switch_device('cuda')

# Display whether the active device is the CPU (this is a check, not an assignment).
on_cpu = settings.device == torch.device('cpu')
on_cpu

In [None]:
# Sweep grid: every combination of k_l2 (5 log-spaced values from 1e-2 to 1e1)
# and source system, one dict per simulation.
params = [
    ('k_l2', 10**np.linspace(-2, 1, 5)),
    ('source', ['rossler', 'lorenz']),
]
keys, vals = zip(*params)
param_dicts = [dict(zip(keys, combo)) for combo in product(*vals)]
print(f'{len(param_dicts)} simulations')

In [None]:
# Local Dask cluster for the sweep: 10 workers with a single thread each.
client = Client(
    threads_per_worker=1,
    n_workers=10,
)

In [None]:
n_ensemble = 30

# Hyper-parameters common to every run. Presumably SweepESN combines these with each
# entry of `param_dicts` (which supplies k_l2 and source) -- confirm in utils.SweepESN.
kwargs_default = dict(
    spectral_radius=0.99,
    n_neurons=1000,
    leaking_rate=0.1,
)

sweep = utils.SweepESN(param_dicts, kwargs_default, n_ensemble, noise=NOISE)

In [None]:
# Dispatch the sweep's jobs to the Dask client; results are collected with get_results() in the next cell.
sweep.submit_jobs(client)

In [None]:
# Gather the sweep outputs. `res` is fed straight into pd.DataFrame below
# (presumably a list of per-run records -- confirm in utils.SweepESN.get_results).
res = sweep.get_results()

In [None]:
# Release the Dask workers once the results have been gathered.
# NOTE(review): if any cell between Client(...) and here raises, the cluster is never shut
# down; consider `with Client(...) as client:` around submit/get_results.
client.shutdown()

In [None]:
# One row per entry of the sweep results; the whole frame is displayed for inspection.
df = pd.DataFrame(data=res)
df

In [None]:
# Unpack the per-run Lyapunov error arrays into one column per Lyapunov time and add
# Mean/Max/Min summaries of each. Produces, for every i, the columns
#   'MSE Lyapunov i', 'MAE Lyapunov i',
#   'Mean MSE Lyapunov i', 'Mean MAE Lyapunov i',
#   'Max MSE Lyapunov i',  'Max MAE Lyapunov i',
#   'Min MSE Lyapunov i',  'Min MAE Lyapunov i'
# in that order. None entries (missing results) stay None in every derived column.
# NOTE(review): assumes each non-None entry is 2-D with at least N_LYAPUNOV_TIMES
# columns (indexed as x[:, i]); what axis 0 represents is not visible here.
N_LYAPUNOV_TIMES = 10
ERROR_METRICS = ('MSE', 'MAE')
# (column-name prefix, reduction over the whole per-time slice)
SUMMARY_STATS = (
    ('Mean', lambda arr: arr.mean()),
    ('Max', lambda arr: arr.max()),
    ('Min', lambda arr: arr.min()),
)


def _skip_none(fn):
    """Return a wrapper that applies ``fn`` but passes ``None`` through unchanged."""
    return lambda x: fn(x) if x is not None else x


for i in range(N_LYAPUNOV_TIMES):
    err = f'Lyapunov {i}'
    for metric in ERROR_METRICS:
        # Column i of each run's error array. The lambda closes over `i`, which is safe
        # because Series.apply evaluates it immediately within this iteration.
        df[f'{metric} {err}'] = df[f'{metric} Lyapunov'].apply(_skip_none(lambda x: x[:, i]))
    for stat_name, stat in SUMMARY_STATS:
        for metric in ERROR_METRICS:
            df[f'{stat_name} {metric} {err}'] = df[f'{metric} {err}'].apply(_skip_none(stat))

In [None]:
from plotly.subplots import make_subplots
from plotly import colors

# Box plots of the per-run MAE at each Lyapunov time: one subplot row per value of
# `row_ax`, one colour-coded, side-by-side box per k_l2 value.
k_l2_vals = np.sort(df['k_l2'].unique())
c = colors.sample_colorscale(
    colors.get_colorscale('Plasma'),
    np.linspace(0, 0.9, len(k_l2_vals)),
)
row_ax = 'source'
row_vals = df[row_ax].unique()
fig = make_subplots(rows=len(row_vals), cols=1, shared_xaxes='all', shared_yaxes='all')
dfq = df  # swap in a filtered frame here to restrict the plot
for i_row, row_name in enumerate(row_vals):
    for i_l2, k_l2 in enumerate(k_l2_vals):
        df_l2 = dfq[(dfq['k_l2'] == k_l2) & (dfq[row_ax] == row_name)]
        # Drop missing (None) results: torch.stack cannot take them, and a group with
        # no results has nothing to plot.
        runs = [run_err for run_err in df_l2['MAE Lyapunov'] if run_err is not None]
        if not runs:
            continue
        lyap_err = torch.stack(runs)
        # x-coordinate of every error value: position j along the last axis is
        # labelled Lyapunov time j + 1 (broadcast over the leading axes).
        t_lyap = torch.arange(
            1, lyap_err.shape[-1] + 1, dtype=lyap_err.dtype, device=lyap_err.device
        ).expand_as(lyap_err)
        fig.add_box(
            x=t_lyap.flatten(),
            y=lyap_err.flatten(),
            name=f'k_l2={k_l2:.4f}',
            row=i_row+1, col=1,
            legendgroup=f'k_l2={k_l2:.4f}',
            offsetgroup=f'k_l2={k_l2:.4f}',
            marker_color=c[i_l2],
            showlegend=i_row==0,
        )
    # Label each row by the column it was split on (previously hard-coded as "n_neurons").
    fig.update_yaxes(title_text=f'MAE {row_ax}={row_name}', range=[0, 2], row=i_row+1)
fig.update_xaxes(title_text='Lyapunov time', row=len(row_vals), dtick=1)
fig.update_layout(boxmode='group', height=1000)

In [None]:
# Restrict to the Rossler runs for the closer look in the following cells.
dfq = df[df['source'] == 'rossler']
dfq

In [None]:
# Mean over axis 0 of the 'MAE Lyapunov' array of positional row 2 of the Rossler subset,
# first 5 entries (presumably the first 5 Lyapunov times).
# NOTE(review): row 2 is a magic positional index -- confirm which k_l2 it corresponds to.
dfq.iloc[2]['MAE Lyapunov'].mean(axis=0)[:5]

In [None]:
# Reload the figure JSON stored for positional row 2 of the Rossler subset and
# re-theme it for display (the stored figure may carry its own template).
fig_path = dfq.iloc[2]['_fig_path']
fig = pio.read_json(fig_path)
fig.update_layout(template='plotly_white', height=800)

In [None]:
# Hidden-layer widths; the Rossler model below is built with the same list.
hidden_layers = [6, 10, 6]

# EnKF sweep on Lorenz: the 'activation_function' setting swept over 8 identical 'ELU'
# values (presumably 8 repeat runs -- confirm in utils.SweepEnKF).
lorenz = utils.SweepEnKF(
    'lorenz', 'activation_function', ['ELU'] * 8,
    train_sweep=False,
    noise=NOISE,
    hidden_layers=hidden_layers,
    train_kwargs=dict(lr_hold=10, subseq_len=50),
)
lorenz.train(100)

In [None]:
# Plot the test output for Lorenz sweep entry 3 (argument meaning assumed -- see utils.SweepEnKF.plot_test).
lorenz.plot_test(3)

In [None]:
# EnKF sweep on Rossler, configured exactly like the Lorenz model (same hidden_layers,
# noise and training settings); 8 identical 'ELU' values for 'activation_function'.
rossler = utils.SweepEnKF(
    'rossler', 'activation_function', ['ELU'] * 8,
    train_sweep=False,
    noise=NOISE,
    hidden_layers=hidden_layers,
    train_kwargs=dict(lr_hold=10, subseq_len=50),
)
rossler.train(100)

In [None]:
# Plot the test output for Rossler sweep entry 1 (argument meaning assumed -- see utils.SweepEnKF.plot_test).
rossler.plot_test(1)

In [None]:
# Collect element 5 of enkf_err(...) for each of the 8 sweep entries (stored as mae_*;
# presumably the MAE curve -- confirm against utils.SweepEnKF.enkf_err), then show the
# mean over entries for the first 5 positions.
with torch.no_grad():
    mae_rossler = []
    for run_idx in range(8):
        mae_rossler.append(rossler.enkf_err(run_idx)[5])
    err_rossler = torch.stack(mae_rossler)
err_rossler.mean(axis=0)[:5]

In [None]:
# Same summary as for Rossler: element 5 of enkf_err(...) for each of the 8 Lorenz
# sweep entries, averaged over entries, first 5 positions shown.
with torch.no_grad():
    mae_lorenz = []
    for run_idx in range(8):
        mae_lorenz.append(lorenz.enkf_err(run_idx)[5])
    err_lorenz = torch.stack(mae_lorenz)
err_lorenz.mean(axis=0)[:5]