In [None]:
import torch
import numpy as np
from plotly import graph_objects as go
from plotly import subplots
from plotly import colors
from torch.profiler import profile, record_function, ProfilerActivity
from IPython.display import display

from time_series_prediction import settings, ode_problems, echo_state_network, kalman

In [None]:
# Uncomment the next line to run the notebook on the GPU instead.
# settings.switch_device('cuda')
# Displays True/False as the cell output: whether the configured device is the CPU.
# NOTE(review): this is a comparison (`==`), not an assignment — it does not
# change `settings.device`; confirm that a read-only check is what is intended.
settings.device == torch.device('cpu')

In [None]:
# --- Experiment configuration ---------------------------------------------
source = 'rossler'   # name of the ODE problem handed to generate_data
n_train = 3000       # number of leading samples used for training
n_test = 2000        # number of trailing samples held out for testing
resolution = 10      # NOTE(review): presumably samples per unit time (n_train/resolution is used as a time below) — confirm
noise = 5e-2         # noise level handed to generate_data (also reused for the EnKF observation noise)

y_full, u_full, t_out = ode_problems.generate_data(source, n_train+n_test, resolution, noise)

# Chronological split: the first n_train rows train, the remainder tests.
y_train = y_full[:n_train, :]
u_train = u_full[:n_train, :]

y_test = y_full[n_train:, :]
u_test = u_full[n_train:, :]

# Left panel: every state against t_out. Right panel: state 1 vs. state 2,
# coloured by split.
fig = subplots.make_subplots(rows=1, cols=2, column_widths=[3, 1])
for i in range(y_full.shape[1]):
    fig.add_scatter(x=t_out, y=y_full[:, i].cpu(), row=1, col=1, name=f'state {i+1}')
# `.cpu()` is applied here as well (it was only applied in the loop above) so
# that all traces are handled the same way when the data lives on the GPU.
fig.add_scatter(x=y_full[:n_train, 0].cpu(), y=y_full[:n_train, 1].cpu(), row=1, col=2, name='train')
fig.add_scatter(x=y_full[n_train:, 0].cpu(), y=y_full[n_train:, 1].cpu(), row=1, col=2, name='test')
# Vertical marker at the train/test boundary, expressed on the t_out axis.
fig.add_vline(x=n_train/resolution, row=1, col=1)

fig

In [None]:
# torch.manual_seed(5)
# Build an echo state network with no external inputs and one output per
# observed state, then fit it on the training split.
esn = echo_state_network.ESN(
    n_inputs=0,
    n_outputs=y_train.shape[1],
    n_neurons=1000,
    spectral_radius=0.99,
    f_activation=torch.tanh,
)
x_train = esn.train(
    u_train.to(settings.device),
    y_train.to(settings.device),
    n_discard=500,
    k_l2=1e-1,
)

# Start the prediction from the last training state and last training output.
x_init = x_train[-1, :]
y_init = y_train[-1, :]
x_test_esn, y_test_esn = esn.predict(
    u_test.to(settings.device),
    x_init,
    y_init.to(settings.device),
)

# Training-period values followed by the predicted ones, as NumPy arrays.
x_esn = np.concatenate([x_train.cpu(), x_test_esn])
y_esn = np.concatenate([(x_train @ esn.output_weights).cpu(), y_test_esn])

# Set to `esn.n_outputs + 1` to append a panel showing every entry of x_esn.
rows = esn.n_outputs
fig = subplots.make_subplots(rows=rows, cols=1, shared_xaxes=True)
for i in range(esn.n_outputs):
    fig.add_scatter(y=y_full[:, i].cpu(), name='system data', row=i+1, col=1, line_color='black')
    fig.add_scatter(y=y_esn[:, i], name='ESN', row=i+1, col=1, line_color='red')
    fig.add_vline(x=n_train, row=i+1, col=1)
if rows > esn.n_outputs:
    # Draw into the extra (last) row. This used to be hard-coded to row 3,
    # which is an output panel whenever the system has three or more states.
    for x_neuron in x_esn.T:
        fig.add_scatter(y=x_neuron, showlegend=False, row=rows, col=1)
fig.update_layout(hovermode='x')
fig.update_xaxes(title_text='n')
fig

In [None]:
# 3D phase portrait: full system trajectory vs. the ESN prediction after the
# training period. Only drawn when there are at least three states.
if y_full.shape[1] > 2:
    fig = go.Figure()
    # `.cpu()` keeps this consistent with the other plotting cells, which move
    # y_full to host memory before handing it to plotly.
    fig.add_scatter3d(
        x=y_full[:, 0].cpu(),
        y=y_full[:, 1].cpu(),
        z=y_full[:, 2].cpu(),
        name='system data',
        line_color='black',
        opacity=0.4,
        mode='lines',
    )
    # y_esn is already a NumPy array; rows from n_train onwards are predictions.
    fig.add_scatter3d(
        x=y_esn[n_train:, 0],
        y=y_esn[n_train:, 1],
        z=y_esn[n_train:, 2],
        name='ESN',
        line_color='red',
        mode='lines',
    )
    fig.update_layout(height=800)
    display(fig)

In [None]:
# Sweep the L2 regularisation strength (colour) and the number of discarded
# samples (subplot row), retraining the ESN for every combination.
exps = np.linspace(-1, 2, 4)
# exps = np.array([-2.0, -0.5, 0])
k_l2 = 10**exps

# When True, one extra column per output shows the output weights against k_l2.
show_reg_path = False

# n_discards = [3750, 3700, 3500, 3000, 2500, 2000, 1000]
n_discards = [500, 750, 1000, 1250, 1500, 1750]

# Number of steps predicted after training in every run of the sweep.
n_forecast = 3000

# One colour per regularisation strength, sampled along the exponent range.
if len(k_l2) > 1:
    c = colors.sample_colorscale(colors.get_colorscale('Plotly3'), (exps-min(exps))/(max(exps)-min(exps)))
else:
    c = ['red']

fig = subplots.make_subplots(
    rows=len(n_discards),
    cols=esn.n_outputs*2 if show_reg_path else esn.n_outputs,
    shared_xaxes=True,
    column_widths=([4]*esn.n_outputs + [1]*esn.n_outputs) if show_reg_path else [1]*esn.n_outputs,
)

# Loop-invariant inputs: moved to the device once instead of on every run.
u_train_dev = u_train.to(settings.device)
y_train_dev = y_train.to(settings.device)
u_forecast = torch.zeros((n_forecast, 0), device=settings.device)  # zero-width input (the ESN has n_inputs=0)

for j, n_discard in enumerate(n_discards):
    weight_path = []  # output weights for each k_l2 at this n_discard
    for i, k in enumerate(k_l2):
        x_train = esn.train(
            u_train_dev,
            y_train_dev,
            n_discard=n_discard,
            k_l2=k,
        )
        x_init = x_train[-1, :]
        y_init = y_train[-1, :]
        x_test_esn, y_test_esn = esn.predict(
            u_forecast,
            x_init,
            y_init.to(settings.device),
        )
        y_esn = np.concatenate([(x_train @ esn.output_weights).cpu(), y_test_esn])
        # No `.squeeze()` here: it would drop the output axis for a
        # single-output network and break the `[:, coef, o]` indexing below.
        weight_path.append(esn.output_weights.cpu().numpy())
        for o in range(esn.n_outputs):
            fig.add_scatter(y=y_esn[:, o], name=f'Regularisation: {k:0.3f}', row=j+1, col=o+1, line_color=c[i], legendgroup=str(k), showlegend=(j==0 and o==0))

    for o in range(esn.n_outputs):
        fig.add_scatter(y=y_full[:, o].cpu(), name='system data', row=j+1, col=o+1, line_color='black', showlegend=j==0, legendgroup='system')
        # Blue markers: end of the discarded samples and end of the training data.
        fig.add_vline(x=n_discard, row=j+1, col=o+1, line_color='blue', line_width=2)
        fig.add_vline(x=n_train, row=j+1, col=o+1, line_color='blue', line_width=2)

    if show_reg_path:
        weight_path = np.array(weight_path)  # indexed below as (k_l2, coefficient, output)
        for coef in range(weight_path.shape[1]):
            for o in range(esn.n_outputs):
                fig.add_scatter(x=k_l2, y=weight_path[:, coef, o], row=j+1, col=esn.n_outputs+o+1, showlegend=False)
fig.update_layout(hovermode='x', height=1000)

for o in range(esn.n_outputs):
    fig.update_xaxes(col=o+1, title_text='n')
    if show_reg_path:
        fig.update_xaxes(col=esn.n_outputs+o+1, type='log')
fig


In [None]:
# Build the learned dynamics model, the noise models and the AD-EnKF wrapper,
# then train on the training split.
n_states = y_train.shape[1]
# n = kalman.NeuralNet(n_states, n_states, [3, 5, 10, 20]).to(settings.device)
n = kalman.EulerStepNet(n_states, n_states, [20, 40, 20], dt=1/resolution).to(settings.device)
# n = L96_ODE_Net_2(n_states)

# Observation noise starts from the noise level used to generate the data;
# process noise starts from 1.
obs_noise = kalman.ScalarNoise(
    torch.tensor([noise], device=settings.device),
    n_states,
).to(settings.device)
proc_noise = kalman.ScalarNoise(
    torch.tensor([1], device=settings.device),
    n_states,
).to(settings.device)

# Every state is observed directly (identity observation matrix).
# obs_matrix = torch.zeros((1, n_states))
# obs_matrix[0, 0] = 1
obs_matrix = torch.eye(n_states, device=settings.device)

kf = kalman.AD_EnKF(
    n, obs_matrix, obs_noise, proc_noise, n_particles=20,
    # init_state_distribution=torch.distributions.MultivariateNormal(
    #             torch.zeros(n_states), 50*torch.eye(n_states)
    # )
)

# Profiler activities for the (currently disabled) profiling block below.
# Compare the device *type*: `torch.device('cuda:0') == torch.device('cuda')`
# is False, so an equality test misses indexed CUDA devices.
activities = [ProfilerActivity.CPU]
if torch.device(settings.device).type == 'cuda':
    activities.append(ProfilerActivity.CUDA)

# with profile(activities=activities, record_shapes=True) as prof:
#     with record_function('training'):
kf.train(y_train.to(settings.device), 100, lr_decay=0.4, lr_hold=25)

In [None]:
# Second call to kf.train on the same training data, with 10 as the second
# argument (the cell above uses 100) and figure display switched off.
# NOTE(review): presumably this continues training from the current
# parameters rather than restarting — confirm against kalman.AD_EnKF.train.
kf.train(y_train.to(settings.device), 10, display_fig=False)

In [None]:
# Show the first entry of the process-noise parameter after `_softplus`.
# `.cpu()` is required before `.numpy()` because proc_noise was moved to
# settings.device, and Tensor.numpy() only works on host tensors.
proc_noise._softplus(proc_noise.param.detach()).cpu().numpy()[0]

In [None]:
# Run the trained EnKF over the training data, then predict forward from the
# final ensemble, once per setting of the third kf.predict argument.
with torch.no_grad():
    ll, x = kf.log_likelihood(y_train.to(settings.device))
    # NOTE(review): the third positional flag presumably toggles process noise
    # (the second call feeds the `*_no_noise` arrays) — confirm against
    # kalman.AD_EnKF.predict.
    # The results carry an `_enkf` suffix (like `x_test_esn`/`y_test_esn`) so
    # the held-out data in `y_test` is not overwritten by predictions.
    x_test_enkf, y_test_enkf = kf.predict(x[-1, :, :], n_test*4, True, False)
    x_test_no_noise, y_test_no_noise = kf.predict(x[-1, :, :], n_test*4, False, False)

# Host copies shared by the concatenations below; the first row of each piece
# is dropped, as before.
x_filtered = x.detach().cpu().numpy()[1:, :, :]
# Training observations repeated along axis 1, once per particle.
y_train_tiled = y_train[1:, None, :].repeat((1, kf.n_particles, 1)).detach().cpu().numpy()

x_enkf = np.concatenate([x_filtered, x_test_enkf[1:, :, :]])
y_enkf = np.concatenate([y_train_tiled, y_test_enkf[1:, :, :]])

x_enkf_no_noise = np.concatenate([x_filtered, x_test_no_noise[1:, :, :]])
y_enkf_no_noise = np.concatenate([y_train_tiled, y_test_no_noise[1:, :, :]])
# Ensemble mean over the particle axis.
x_enkf_mean = x_enkf_no_noise.mean(axis=1)

# One colour per particle.
c = colors.sample_colorscale(colors.get_colorscale('Plotly3'), np.arange(x.shape[1])/x.shape[1])

fig = subplots.make_subplots(rows=x.shape[2], cols=1, shared_xaxes=True)

# Per-state mean ± 3 standard deviations of the data, used as y-axis limits so
# diverging particles do not blow up the scale. `.cpu()` is needed before
# `.numpy()` when y_full lives on the GPU.
y_mean = y_full.mean(axis=0).cpu().numpy()
y_std = y_full.std(axis=0).cpu().numpy()
y_upper = y_mean + 3 * y_std
y_lower = y_mean - 3 * y_std

for j in range(x.shape[2]):
    for i in range(x.shape[1]):
        # fig.add_scatter(x=t_out, y=x_enkf[:, i, j], name="EnKF", row=j+1, col=1, legendgroup='EnKF', showlegend=(i==0 and j==0), line_color=c[i])
        fig.add_scatter(x=t_out, y=x_enkf_no_noise[:, i, j], name="EnKF (no noise)", row=j+1, col=1, legendgroup='EnKF (no noise)', showlegend=(i==0 and j==0), line_color=c[i])

    fig.add_scatter(x=t_out, y=x_enkf_mean[:, j], name="EnKF (mean)", row=j+1, col=1, legendgroup='EnKF (mean)', showlegend=j==0, line_color='green')

    fig.add_scatter(x=t_out, y=y_full[:, j].cpu(), name="system data", row=j+1, col=1, line_color='black', legendgroup='system', showlegend=j==0)
    fig.add_vline(x=n_train/resolution, row=j+1, col=1)
    fig.update_yaxes(row=j+1, range=[y_lower[j], y_upper[j]])
fig.update_layout(hovermode='x', height=800)
fig.update_xaxes(title_text='t')

# if np.abs(x_enkf).max() > 20 or np.any(np.isnan(x_enkf)):
#     fig.update_yaxes(range=[-20, 20])

fig

In [None]:
# Write the current `fig` to an HTML file, using a relative path.
# NOTE(review): `fig` is whichever figure was assigned last, so the exported
# content depends on which cell ran before this one; the file name is fixed to
# 'Rossler' regardless of the value of `source`.
fig.write_html('Rossler EnKF.html')

In [None]:
# 2D phase portrait (state 1 vs. state 2): the first three particles from
# index n_train onwards, drawn over the full system trajectory.
fig = go.Figure()
# fig.add_scatter(x=x_enkf_mean[n_train:, 0], y=x_enkf_mean[n_train:, 1], name='EnKF')
for particle in range(3):
    fig.add_scatter(
        x=x_enkf[n_train:, particle, 0],
        y=x_enkf[n_train:, particle, 1],
        name='EnKF particle',
    )
# `.cpu()` matches the other plotting cells, which move y_full to host memory
# before handing it to plotly.
fig.add_scatter(x=y_full[:, 0].cpu(), y=y_full[:, 1].cpu(), name='system', line_color='black')
# Clip both axes to the mean ± 3 std band computed from the data.
fig.update_xaxes(range=[y_lower[0], y_upper[0]])
fig.update_yaxes(range=[y_lower[1], y_upper[1]])

In [None]:
# Build `x_plot`: the ensemble-mean trajectory with rows blanked out (NaN)
# unless at least one state is finite and strictly inside the
# (y_lower, y_upper) band. The first n_train rows are always blanked.
is_finite = np.isfinite(x_enkf_mean)
in_band = (x_enkf_mean > y_lower) & (x_enkf_mean < y_upper)
i_close = np.any(in_band & is_finite, axis=1)

x_plot = np.array(x_enkf_mean)  # independent copy; x_enkf_mean stays untouched
x_plot[np.logical_not(i_close), :] = np.nan
x_plot[:n_train, :] = np.nan

In [None]:
# 3D phase portrait: every particle trajectory from index n_train onwards,
# drawn over the full system trajectory. Only drawn when there are at least
# three states.
if y_full.shape[1] > 2:

    # One colour per particle.
    c = colors.sample_colorscale(colors.get_colorscale('Plotly3'), np.arange(x.shape[1])/x.shape[1])

    fig = go.Figure()
    # `.cpu()` matches the other plotting cells, which move y_full to host
    # memory before handing it to plotly.
    fig.add_scatter3d(
        x=y_full[:, 0].cpu(),
        y=y_full[:, 1].cpu(),
        z=y_full[:, 2].cpu(),
        name='system data',
        line_color='black',
        opacity=0.4,
        mode='lines',
    )
    for i in range(x_enkf_no_noise.shape[1]):
        # All particles share one legend entry, shown for the first only.
        fig.add_scatter3d(
            x=x_enkf_no_noise[n_train:, i, 0],
            y=x_enkf_no_noise[n_train:, i, 1],
            z=x_enkf_no_noise[n_train:, i, 2],
            name='EnKF',
            mode='lines',
            legendgroup='EnKF',
            showlegend=i==0,
            line_color=c[i],
        )
    fig.update_layout(height=800)
    display(fig)

In [None]:
# 3D phase portrait: the masked ensemble-mean trajectory (`x_plot`) drawn over
# the full system trajectory. Only drawn when there are at least three states.
if y_full.shape[1] > 2:
    fig = go.Figure()
    # `.cpu()` matches the other plotting cells, which move y_full to host
    # memory before handing it to plotly.
    fig.add_scatter3d(
        x=y_full[:, 0].cpu(),
        y=y_full[:, 1].cpu(),
        z=y_full[:, 2].cpu(),
        name='system data',
        line_color='black',
        opacity=0.4,
        mode='lines',
    )
    # x_plot holds NaN in the rows that were masked out.
    fig.add_scatter3d(
        x=x_plot[:, 0],
        y=x_plot[:, 1],
        z=x_plot[:, 2],
        name='EnKF',
        line_color='red',
        mode='lines',
    )
    fig.update_layout(height=800)
    display(fig)