In [None]:
import numpy as np
import torch
from tqdm import tqdm

%matplotlib inline
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from matplotlib import cm
import seaborn
import sys
seaborn.set()

import warnings
warnings.filterwarnings("ignore")

import reservoir
import reckernel
import kuramoto
%load_ext autoreload
%autoreload 2
# Pick the compute device once: CUDA when torch reports it as available,
# otherwise the CPU. Later cells move their tensors with `.to(device)`.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda") if use_cuda else torch.device("cpu")

# Data generation

In [None]:
# Kuramoto-Sivashinsky simulation settings.
L = 22 / (2 * np.pi)  # length
N = 100  # spatial discretization size passed to kuramoto.KS (presumably the number of grid points; confirm)
dt = 0.25  # time discretization step
N_train = 10000  # number of snapshots kept for training
N_test = 5000  # number of snapshots kept for testing
N_init = 1000  # remove the initial points
tend = (N_train + N_test + N_init) * dt  # total simulated time covering all three segments

dns = kuramoto.KS(L=L, N=N, dt=dt, tend=tend)
dns.simulate()

# Drop the first N_init snapshots and rescale by 1/sqrt(N).
# The original cell referenced `ninittransients`, `N_data_train` and
# `N_data_test`, none of which are defined in this notebook, so it raised a
# NameError on a fresh kernel; the constants defined above are used instead.
u = dns.uu[N_init:] / np.sqrt(N)

# Split along time: the first N_train rows, the next N_test rows, and any remainder.
u_train, u_test, _ = np.split(u, [N_train, N_train + N_test], axis=0)

In [None]:
# Show only the first N_plot snapshots so the spatio-temporal pattern is readable.
N_plot = 1000
u_plot = u[:N_plot, :]

plt.figure()
# Plot the truncated window; the original plotted the full `u` and left
# `u_plot` unused.
plt.imshow(u_plot.T)
plt.colorbar()
# Turn the seaborn grid off explicitly. The former `grid(b=None)` relied on the
# `b` keyword, which Matplotlib renamed to `visible` and no longer accepts.
plt.grid(False)
plt.xlabel(r"$n, \quad t={:}n$".format(dt))
plt.ylabel(r"$x$")
plt.title("Kuramoto-Sivashinsky time series");

# Recurrent Kernel

In [None]:
# Move the train / test snapshots onto the compute device.
u_train_t = torch.from_numpy(u_train).to(device)
u_test_t = torch.from_numpy(u_test).to(device)
input_len, input_dim = u_train_t.shape

# Training targets: for each time step, the next `pred_horizon_range` snapshots
# written side by side along the feature axis.
# NOTE: torch.roll is circular, so the last `pred_horizon_range` rows of
# `out_train` contain snapshots wrapped around from the start of the series.
pred_horizon_range = 5
out_train = torch.zeros(input_len, input_dim * pred_horizon_range, device=device)
for pred_horizon in range(1, pred_horizon_range + 1):
    cols = slice((pred_horizon - 1) * input_dim, pred_horizon * input_dim)
    out_train[:, cols] = torch.roll(u_train_t, -pred_horizon, dims=0)

# Recurrent-kernel hyperparameters.
input_scale = 0.01
res_scale = 0.8
n_iter = 50

model = reckernel.RK(
    function='arcsin',
    res_scale=res_scale,
    input_scale=input_scale,
    n_iter=n_iter,
)
# Kernel computed from the training series, then the regularised readout fit.
K = model.forward(u_train_t).to(device)
output_w = model.train(K, out_train, alpha=1e-5)

# Recursive prediction on the test series: n_rec passed to rec_pred as the recursion count.
n_rec = 20
Ktest, diag_res_train, diag_res_test = model.forward_test(u_train_t, u_test_t)
test_rec_pred = model.rec_pred(
    Ktest, u_train_t, u_test_t, output_w, n_rec, diag_res_train, diag_res_test
)

In [None]:
# NOTE(review): `test_len` and `len_of_interest` are not assigned in any cell
# visible in this notebook, so this cell fails on a fresh kernel unless they are
# defined beforehand. Their intended values cannot be inferred from here —
# define them explicitly (e.g. in the training cell) before running.
n_rec_pred = test_len - len_of_interest + 1

# Ground truth: for each test time step, the next `new_pred_horizon` snapshots
# side by side. torch.roll is circular, so the last rows wrap around to the
# start of the test series.
new_pred_horizon = pred_horizon_range * (n_rec+1)
out_test = torch.zeros(test_len, input_dim * new_pred_horizon).to(device)
for pred_horizon in range(1, new_pred_horizon+1):
    out_test[:, (pred_horizon-1)*input_dim:pred_horizon*input_dim] = torch.roll(u_test_t[:test_len, :], -pred_horizon, dims=0)

# Reshape predictions and truth to (start index, horizon step, space) and
# compare them. Prediction i is paired with truth row len_of_interest-1+i.
rec_pred = test_rec_pred.reshape(n_rec_pred, (n_rec+1)*pred_horizon_range, input_dim).to(device)
truth = out_test[len_of_interest-1:, :].reshape(n_rec_pred, (n_rec+1)*pred_horizon_range, input_dim).to(device)
diff = rec_pred - truth

# Mean squared error over space, for the first 500 start indices.
test = torch.mean(diff**2, dim=2)[:500, :]
plt.figure(figsize=(10,10))
plt.imshow(test.cpu(), aspect='auto', vmin=0, vmax=0.03)
# Explicitly disable the grid; `grid(b=None)` is rejected by current Matplotlib
# (the `b` keyword was renamed to `visible`).
plt.grid(False)
plt.colorbar();

In [None]:
# Kept in this cell so it can run on its own; ideally lives with the top-level imports.
from mpl_toolkits.axes_grid1 import make_axes_locatable

# Index of the recursive prediction to visualise.
origin = 75
# Move the prediction to the CPU: imshow cannot consume a CUDA tensor, and it is
# subtracted from the CPU-resident `true_plot` below.
pred_plot = test_rec_pred[origin, :].reshape((n_rec+1)*pred_horizon_range, input_dim).cpu()

# NOTE(review): `test_len` and `len_of_interest` are not assigned in any cell
# visible in this notebook; they must be defined before this cell runs.
new_pred_horizon = pred_horizon_range * (n_rec+1)
out_test = torch.zeros(test_len, input_dim * new_pred_horizon).to(device)
for pred_horizon in range(1, new_pred_horizon+1):
    out_test[:, (pred_horizon-1)*input_dim:pred_horizon*input_dim] = torch.roll(u_test_t[:test_len, :], -pred_horizon, dims=0)

# Use the same alignment as the error-map cell, which pairs prediction i with
# out_test[len_of_interest - 1 + i]; the previous `len_of_interest + origin`
# was shifted by one step relative to that.
true_plot = out_test[len_of_interest - 1 + origin, :].reshape(-1, input_dim).cpu()

_, axes  = plt.subplots(1, 3, figsize=(15, 6))

# One panel each for truth, prediction and their difference, all drawn the same way.
panels = [
    (true_plot, "True output"),
    (pred_plot, "Predicted output"),
    (true_plot - pred_plot, "Difference"),
]
for ax, (panel, title) in zip(axes, panels):
    plot = ax.imshow(panel.T, cmap='jet')
    # `grid(b=None)` is rejected by current Matplotlib; switch the grid off explicitly.
    ax.grid(False)
    ax.set_xlabel("Prediction time")
    ax.set_ylabel("Space")
    ax.set_title(title)
    # Attach a colorbar of matching height to the right of each panel.
    divider = make_axes_locatable(ax)
    cax = divider.append_axes("right", size="5%", pad=0.05)
    plt.colorbar(plot, cax=cax);

# Reservoir Computing / Structured Reservoir Computing

In [None]:
# Move the train / test snapshots onto the compute device.
u_train_t = torch.from_numpy(u_train).to(device)
u_test_t = torch.from_numpy(u_test).to(device)
input_len, input_dim = u_train_t.shape

# Training targets: for each time step, the next `pred_horizon_range` snapshots
# side by side. torch.roll is circular, so the last rows wrap around to the
# start of the series.
pred_horizon_range = 10
out_train = torch.zeros(input_len, input_dim * pred_horizon_range).to(device)
for pred_horizon in range(1, pred_horizon_range+1):
    out_train[:, (pred_horizon-1)*input_dim:pred_horizon*input_dim] = torch.roll(u_train_t, -pred_horizon, dims=0)

# Reservoir hyperparameters.
n_res = 1000
input_scale = 0.014
res_scale = 0.8
leak_rate = 0.6

# Random initial reservoir state, scaled by 1/sqrt(n_res). No seed is set here,
# so results vary between runs.
initial_state = torch.randn(n_res).to(device) / np.sqrt(n_res)
model = reservoir.ESN(input_dim, res_size=n_res, res_scale=res_scale, input_scale=input_scale, f='erf', leak_rate=leak_rate)
X = model.forward(u_train_t, initial_state=initial_state).to(device)

# Fit the linear readout and evaluate it on the training states.
output_w = model.train(X, out_train, alpha=1e-4)
pred_output = X @ output_w

# Reservoir states and readout predictions on the test series.
Xtest = model.forward(u_test_t, initial_state=initial_state).to(device)
pred_output_test = Xtest @ output_w

# Keep only the first n_rec_pred test states as starting points for recursion.
n_rec_pred = 2000
Xtest = Xtest[:n_rec_pred, :]

# `input_dim` is the one derived from `u_train_t.shape` above; the former
# hard-coded `input_dim = 100` would silently go stale if N were changed.
n_rec = 20
test_rec_pred = model.rec_pred(Xtest, output_w, n_rec, input_dim)
print(test_rec_pred.shape)

In [None]:
# Ground truth: for each test time step, the next `new_pred_horizon` snapshots
# side by side. torch.roll is circular, so the last rows wrap around to the
# start of the test series.
new_pred_horizon = pred_horizon_range * (n_rec+1)
# Size the buffer from the test series itself. It was previously allocated with
# `input_len` (the training length), which makes the assignment below fail
# whenever the train and test series have different lengths.
out_test = torch.zeros(u_test_t.shape[0], input_dim * new_pred_horizon).to(device)
for pred_horizon in range(1, new_pred_horizon+1):
    out_test[:, (pred_horizon-1)*input_dim:pred_horizon*input_dim] = torch.roll(u_test_t, -pred_horizon, dims=0)

# Reshape predictions and truth to (start index, horizon step, space) and compare.
rec_pred = test_rec_pred.reshape(n_rec_pred, (n_rec+1)*pred_horizon_range, input_dim).to(device)
truth = out_test[:n_rec_pred, :].reshape(n_rec_pred, (n_rec+1)*pred_horizon_range, input_dim).to(device)
diff = rec_pred - truth

# Mean squared error over space, for the first 500 start indices.
test = torch.mean(diff**2, dim=2)[:500, :]
plt.figure(figsize=(10,10))
# NOTE(review): the extent constants (200, 0.043) are unexplained in this
# notebook; confirm they match the plotted horizon and the intended time unit.
plt.imshow(test.cpu(), aspect='auto', vmin=0, vmax=0.03, extent=[0, 200*0.043, 500, 0])
# Explicitly disable the grid; `grid(b=None)` is rejected by current Matplotlib
# (the `b` keyword was renamed to `visible`).
plt.grid(False)
plt.colorbar();

In [None]:
# Kept in this cell so it can run on its own; ideally lives with the top-level imports.
from mpl_toolkits.axes_grid1 import make_axes_locatable

# Index of the recursive prediction to visualise.
origin = 220
# Move the prediction to the CPU: imshow cannot consume a CUDA tensor, and it is
# subtracted from the CPU-resident `true_plot` below.
pred_plot = test_rec_pred[origin, :].reshape((n_rec+1)*pred_horizon_range, input_dim).cpu()

# Ground truth: the next `new_pred_horizon` snapshots for every test time step.
new_pred_horizon = pred_horizon_range * (n_rec+1)
# Size the buffer from the test series itself. It was previously allocated with
# `input_len` (the training length), which makes the assignment below fail
# whenever the train and test series have different lengths.
out_test = torch.zeros(u_test_t.shape[0], input_dim * new_pred_horizon).to(device)
for pred_horizon in range(1, new_pred_horizon+1):
    out_test[:, (pred_horizon-1)*input_dim:pred_horizon*input_dim] = torch.roll(u_test_t, -pred_horizon, dims=0)

true_plot = out_test[origin, :].reshape(-1, input_dim).cpu()

_, axes  = plt.subplots(1, 3, figsize=(15, 6))

# One panel each for truth, prediction and their difference, all drawn the same way.
panels = [
    (true_plot, "True output"),
    (pred_plot, "Predicted output"),
    (true_plot - pred_plot, "Difference"),
]
for ax, (panel, title) in zip(axes, panels):
    plot = ax.imshow(panel.T, cmap='jet')
    # `grid(b=None)` is rejected by current Matplotlib; switch the grid off explicitly.
    ax.grid(False)
    ax.set_xlabel("Prediction time")
    ax.set_ylabel("Space")
    ax.set_title(title)
    # Attach a colorbar of matching height to the right of each panel.
    divider = make_axes_locatable(ax)
    cax = divider.append_axes("right", size="5%", pad=0.05)
    plt.colorbar(plot, cax=cax);