In [None]:
from itertools import product

import torch
import plotly.io as pio
import plotly.express as px
import pandas as pd
import numpy as np
from dask.distributed import Client

from time_series_prediction import settings, utils

# Make 'plotly_white' the default template for every plotly figure created in this notebook.
pio.templates.default = 'plotly_white'

In [None]:
# Uncomment to run the sweep on the GPU instead of the default device.
# settings.switch_device('cuda')
# Display the device currently configured in `settings`. This only reports the
# device; use `settings.switch_device(...)` above to change it.
settings.device

In [None]:
# Hyper-parameter grid for the ESN sweep: (parameter name, candidate values).
params = [
    ('leaking_rate', [0.05, 0.1, 0.3, 0.5, 0.7, 1]),
    ('k_l2', 10**np.linspace(1, 3, 5)),  # 5 log-spaced values from 10 to 1000
    ('n_neurons', [400, 800, 1600, 2400, 3200]),
    ('spectral_radius', [1, 2, 3, 4]),
]
# Alternative, smaller grid kept for reference:
# params = [
#     ('k_l2', 10**torch.linspace(-0.5, -2.5, 5)),
#     ('source', ['rossler', 'lorenz'])
# ]

# Cartesian product of all candidate values -> one kwargs dict per simulation
# (the last parameter varies fastest, as in itertools.product).
keys, vals = zip(*params)
param_dicts = [dict(zip(keys, combo)) for combo in product(*vals)]
print(f'{len(param_dicts)} simulations')

In [None]:
# Local dask cluster used to run the sweep: 5 workers, one thread each.
N_WORKERS = 5
THREADS_PER_WORKER = 1
client = Client(threads_per_worker=THREADS_PER_WORKER, n_workers=N_WORKERS)

In [None]:
# Ensemble size forwarded to SweepESN (presumably the number of reservoirs
# simulated per parameter combination — TODO confirm).
n_ensemble = 15

# Keyword arguments shared by every run; currently none.
kwargs_default = {}

# Build the sweep over every dict in `param_dicts` on the 'jallon' source, with noise=0.0.
sweep = utils.SweepESN(param_dicts, kwargs_default, n_ensemble, noise=0.0, source='jallon')

In [None]:
# Submit the sweep's jobs to the dask `client`; results are collected separately
# via `sweep.get_results()` (presumably this call does not wait for completion — TODO confirm).
sweep.submit_jobs(client)

In [None]:
# Gather the sweep results (presumably blocks until all submitted jobs finish — TODO confirm).
res = sweep.get_results()

In [None]:
# Shut down the dask cluster. Assumes `res` already holds fully materialised
# results, so nothing below needs the workers any more.
client.shutdown()

In [None]:
# Tabulate the sweep results (one row per record returned by get_results).
df = pd.DataFrame(data=res)
df

In [None]:
# Persist the results table so the plots below can be regenerated without rerunning the sweep.
RESULTS_PATH = 'esn_sweep_jallon.pkl'
df.to_pickle(RESULTS_PATH)

In [None]:
# Row with the lowest 'Min MAE' (ascending sort, NaNs placed last).
runs_by_min_mae = df.sort_values(by='Min MAE')
runs_by_min_mae.iloc[0]

In [None]:
# Reload the saved figure belonging to the run with the lowest 'Min MAE'.
best_run = df.sort_values(by='Min MAE').iloc[0]
fig = pio.read_json(best_run['_fig_path'])
fig.update_layout(height=800, template='plotly_white')

In [None]:
# Fix leaking_rate, n_neurons and spectral_radius so that, among the swept
# parameters, only k_l2 varies across the displayed rows.
dfq = df.query("leaking_rate==0.05 & n_neurons==3200 & spectral_radius==1")
dfq

In [None]:
# Reload the saved figure of one run with leaking_rate=0.1, n_neurons=3200, spectral_radius=1.
# NOTE(review): `.iloc[2]` selects the third matching row by position, so which
# run is shown depends on the row order of `df`; filtering on k_l2 explicitly
# would make the choice robust.
selected_runs = df.query("leaking_rate==0.1 & n_neurons==3200 & spectral_radius==1")
fig = pio.read_json(selected_runs.iloc[2]['_fig_path'])
fig.update_layout(height=1200, template='plotly_white')

In [None]:
# 'Mean MAE' vs leaking rate: one facet row per k_l2 value, one colour per
# reservoir size, one dash style per spectral radius. Error bars run from
# 'Min MAE' up to 'Max MAE'.
x = 'leaking_rate'
color = 'n_neurons'
y = 'Mean MAE'
facet_row = 'k_l2'

dfq = df  # apply a .query(...) here to restrict the plotted configurations

fig = px.line(
    dfq,
    x=x,
    y=y,
    color=color,
    facet_row=facet_row,
    line_dash='spectral_radius',
    # facet_col='spectral_radius',
    # Asymmetric error bars: up to the max, down to the min of each configuration.
    error_y=dfq['Max MAE'] - dfq[y],
    error_y_minus=dfq[y] - dfq['Min MAE'],
    # log_x=True,
    # log_y=True,
    # One colour per distinct value that is actually plotted (dfq, not df, so the
    # palette stays in sync if the filter above is enabled). Sampling stops at 0.9,
    # presumably to avoid the light end of Plasma on a white background.
    color_discrete_sequence=px.colors.sample_colorscale(
        px.colors.get_colorscale('Plasma'),
        np.linspace(0, 0.9, dfq[color].nunique()),
    ),
)
fig.update_traces(mode='markers+lines')
fig.update_layout(height=800)
# plotly express labels facet rows as "k_l2=<full float repr>"; round the value
# (rather than truncating the string) to keep the labels short and correct.
fig.for_each_annotation(
    lambda ann: ann.update(text=f"{facet_row}={float(ann.text.split('=', 1)[1]):.4g}"),
    lambda ann: ann.text.startswith(f'{facet_row}='),
)
# Label the y axes with the metric that is actually plotted (previously 'median_err').
fig.update_yaxes(title_text=y)