This notebook runs a hyperparameter grid search over several reservoir types to find the best-performing settings for each.

In [1]:
gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
if gpu_info.find('failed') >= 0:
  print('Not connected to a GPU')
else:
  print(gpu_info)

Wed Nov 10 15:41:30 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 495.44       Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla P100-PCIE...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   32C    P0    26W / 250W |      0MiB / 16280MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [2]:
from psutil import virtual_memory

# `.total` is the machine's total physical memory in bytes (not the currently
# free amount, despite the "available" wording of the message); /1e9 -> decimal GB.
total_bytes = virtual_memory().total
ram_gb = total_bytes / 1e9
print('Your runtime has {:.1f} gigabytes of available RAM\n'.format(ram_gb))

Your runtime has 27.3 gigabytes of available RAM



In [3]:
!pip install --upgrade pip # To support manylinux2010 wheels
!pip install --upgrade "jax[cuda111]" -f https://storage.googleapis.com/jax-releases/jax_releases.html  # GPU
!pip install flax
!git clone https://github.com/GJBoth/jacho.git
%cd jacho
!sudo pip install . 

Collecting pip
  Downloading pip-21.3.1-py3-none-any.whl (1.7 MB)
[K     |████████████████████████████████| 1.7 MB 4.2 MB/s 
[?25hInstalling collected packages: pip
  Attempting uninstall: pip
    Found existing installation: pip 21.1.3
    Uninstalling pip-21.1.3:
      Successfully uninstalled pip-21.1.3
Successfully installed pip-21.3.1
Looking in links: https://storage.googleapis.com/jax-releases/jax_releases.html
Collecting jax[cuda111]
  Downloading jax-0.2.24.tar.gz (786 kB)
     |████████████████████████████████| 786 kB 4.3 MB/s            
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: jax
  Building wheel for jax (setup.py) ... [?25l[?25hdone
  Created wheel for jax: filename=jax-0.2.24-py3-none-any.whl size=903112 sha256=4bbf244d73aceeaf1094a01e7ec669c18ba659c5c388434b9672b5376a4a4c15
  Stored in directory: /root/.cache/pip/wheels/28/a9/0f/3497740c85f6e1de8f4d291fd2f77d046d66a87620143d0d0e
Successfully built jax
Install

In [4]:
from jacho.layers.reservoirs import RandomReservoir, StructuredTransform, FastStructuredTransform, SparseReservoir
from jacho.models.generic import GenericEchoState
from jacho.layers.output import Residual
from jacho.training.training import ridge
from jacho.data.KS import KS

from jax import random
import numpy as np
import jax.numpy as jnp
from jax import jit
from flax import linen as nn

import matplotlib.pyplot as plt

# Fixed PRNG key. grid_search passes this same key to every configuration's
# state and parameter initialisation, so all grid points presumably share the
# same underlying random draw (only the hyperparameter scalings differ).
key = random.PRNGKey(42)

from jax.interpreters import xla

import itertools

In [2]:
from google.colab import drive

# Mount Google Drive; a later cell changes into a folder under this mountpoint,
# presumably so the .npy grid-search results are written to Drive.
DRIVE_MOUNTPOINT = '/content/drive'
drive.mount(DRIVE_MOUNTPOINT)

Mounted at /content/drive


In [3]:
%cd ..
%cd drive/MyDrive/Colab_Notebooks

/
[Errno 2] No such file or directory: 'drive/MyDrive/Colab_Notebooks'
/


# Making data

In [5]:
# Setting up our dataset (similar to Jonathan's setup): simulate the
# Kuramoto-Sivashinsky system with jacho's KS solver.
L = 22 / (2 * np.pi)  # domain-size parameter; presumably the domain length is 2*pi*L = 22 -- TODO confirm in KS
N = 32  # presumably the number of spatial grid points (an integer count, not a step size)
dt = 0.25  # time discretization step
N_train = 10000  # time steps used for training
N_test = 1000  # time steps held out to score free-running predictions
N_init = 1000  # initial (transient) time steps, discarded
# End time chosen so that, presumably counting t = 0, the run yields
# N_init + N_train + N_test samples (checked in the next cell).
tend = (N_train + N_test + N_init - 1) * dt

# Seed the global NumPy RNG; presumably KS draws its initial condition from it.
np.random.seed(2)
dns = KS(L=L, N=N, dt=dt, tend=tend)
dns.simulate()

In [6]:
# Prepping train and test matrices.
# Inputs need to be shaped [time_steps, samples, spatial_points]; here
# [N_init + N_train + N_test, 1, n_points] before splitting.
u = np.expand_dims(dns.uu, axis=1)
# Drop the first N_init transient steps and scale by 1/sqrt(N).
_, u_train, u_test = np.split(u / np.sqrt(N),
                              [N_init, N_init + N_train], axis=0)

# The grid searches compare N_test-step predictions against u_test, so fail
# here with a clear message rather than deep inside a long-running search.
assert u_test.shape[0] == N_test, (
    f"expected {N_test} test steps, got {u_test.shape[0]} "
    f"(simulation produced {u.shape[0]} steps)")

# Running grid searches

In [None]:
# Setting up the echo-state configuration shared by every grid search below.
n_reservoir = 2000
n_out = u_train.shape[-1]
norm_factor = jnp.sqrt(n_out / n_reservoir)
output_layer_args = (norm_factor, )

def grid_search(reservoir_type, grid_params, alpha=1e-2, n_init=50, rng_key=None):
  """Fit and score one echo-state model per point of a hyperparameter grid.

  Every combination in the Cartesian product of ``grid_params`` values (in the
  dict's key order) is handed as the reservoir arguments to a
  ``GenericEchoState`` with ``n_reservoir`` units and a ``Residual`` output
  layer. Each model is trained with ``ridge`` on ``u_train`` and then run for
  ``N_test`` prediction steps starting from ``u_train[-1]``.

  Args:
    reservoir_type: reservoir layer class, e.g. ``RandomReservoir``.
    grid_params: mapping name -> 1-D array of candidate values. Presumably
      the key order must match the reservoir's positional arguments -- TODO
      confirm against each reservoir's signature.
    alpha: ridge penalty passed to ``ridge`` (default 1e-2, as before).
    n_init: value passed as ``n_init`` to ``ridge`` (default 50, as before).
    rng_key: PRNG key for state and parameter initialisation; defaults to the
      global ``key``.

  Returns:
    ``(params_list, errors_sum)``: the hyperparameter tuples, and for each a
    NumPy array of absolute prediction errors summed over the spatial axis,
    one value per test time step.

  Reads the globals n_reservoir, n_out, norm_factor, output_layer_args, key,
  u_train, u_test and N_test at call time.
  """
  if rng_key is None:
    rng_key = key
  n_predict_steps = N_test
  target = u_test.squeeze()  # [N_test, spatial_points]
  params_list = []
  errors_sum = []
  for x in itertools.product(*grid_params.values()):
    model = GenericEchoState(n_reservoir, reservoir_type, x, n_out, Residual,
                             output_layer_args)

    state = model.initialize_state(rng_key, n_reservoir)
    params = model.init(rng_key, state, u_train[0])  # initializing the parameters

    # Training - ridge currently runs the reservoir over u_train as well.
    end_of_train_state, params = ridge(model, params, state, u_train,
                                       renorm_factor=norm_factor,
                                       alpha=alpha, n_init=n_init)

    # Predict n_predict_steps steps, starting from the last training input.
    end_of_predict_state, (prediction, _) = model.apply(
        params, end_of_train_state, u_train[-1], n_predict_steps)
    # Move to host NumPy so the stored errors do not keep device buffers alive.
    prediction = np.asarray(prediction).squeeze()
    # Absolute error summed over the spatial axis -> one value per test step.
    errors_sum.append(np.sum(np.abs(target - prediction), axis=-1))
    params_list.append(x)

    # Clear JAX's cache of compiled computations (private API) and drop large
    # buffers before the next configuration, to bound memory over the grid.
    xla._xla_callable.cache_clear()
    del state, params, end_of_train_state, end_of_predict_state, prediction
  return params_list, errors_sum

In [None]:
# RandomReservoir sweep: 10 x 10 x 22 = 2200 configurations.
# NOTE(review): np.linspace(0.1, 1.1, 10) has a step of 1/9 (about 0.111);
# use 11 points if steps of 0.1 were intended.
grid_params = dict(
    input_scale=np.linspace(0.1, 1.1, 10),
    reservoir_scale=np.linspace(0.1, 1.1, 10),
    bias_scale=np.linspace(0.3, 4.5, 22),
)
params_list, sum_errors = grid_search(RandomReservoir, grid_params)
for out_path, arr in (("RR_gridsearch_error.npy", sum_errors),
                      ("RR_gridsearch_params.npy", params_list)):
  np.save(out_path, np.array(arr))

In [None]:
# StructuredTransform sweep: 10 x 10 x 22 = 2200 configurations.
# NOTE(review): np.linspace(0.1, 1.1, 10) has a step of 1/9 (about 0.111);
# use 11 points if steps of 0.1 were intended.
grid_params = dict(
    input_scale=np.linspace(0.1, 1.1, 10),
    reservoir_scale=np.linspace(0.1, 1.1, 10),
    bias_scale=np.linspace(0.3, 4.5, 22),
)
params_list, sum_errors = grid_search(StructuredTransform, grid_params)
for out_path, arr in (("ST_gridsearch_error.npy", sum_errors),
                      ("ST_gridsearch_params.npy", params_list)):
  np.save(out_path, np.array(arr))

In [None]:
# FastStructuredTransform sweep: 10 x 10 x 22 = 2200 configurations.
# NOTE(review): np.linspace(0.1, 1.1, 10) has a step of 1/9 (about 0.111);
# use 11 points if steps of 0.1 were intended.
grid_params = dict(
    input_scale=np.linspace(0.1, 1.1, 10),
    reservoir_scale=np.linspace(0.1, 1.1, 10),
    bias_scale=np.linspace(0.3, 4.5, 22),
)
params_list, sum_errors = grid_search(FastStructuredTransform, grid_params)
for out_path, arr in (("FST_gridsearch_error.npy", sum_errors),
                      ("FST_gridsearch_params.npy", params_list)):
  np.save(out_path, np.array(arr))

In [None]:
# SparseReservoir sweep with a single sparsity level (0.1): 1 x 10 x 10 x 22
# = 2200 configurations. Key order matters because each combination is handed
# to the reservoir as a tuple, presumably unpacked positionally, so
# 'sparsity_level' comes first -- TODO confirm against SparseReservoir's signature.
grid_params = {'sparsity_level': np.array([0.1]),
               'input_scale': np.linspace(0.1, 1.1, 10),
               'reservoir_scale': np.linspace(0.1, 1.1, 10),
               'bias_scale': np.linspace(0.3, 4.5, 22)}
params_list, sum_errors = grid_search(SparseReservoir, grid_params)
np.save("SparseR_gridsearch_error.npy", np.array(sum_errors))
np.save("SparseR_gridsearch_params.npy", np.array(params_list))

In [None]:
from jacho.recurrent_kernel import RecurrentKernel, erf_kernel, train

# Recurrent-kernel (RK) grid search. For each hyperparameter tuple the ridge
# penalty starts at RK_ALPHA_INIT and is multiplied by RK_ALPHA_GROWTH until the
# first RK_NAN_CHECK_STEPS predicted steps contain no NaNs (at most
# RK_MAX_ALPHA_TRIES attempts, so a diverging configuration cannot hang the loop).
# This cell deliberately does not reassign the ESN globals `norm_factor` /
# `output_layer_args`: grid_search reads them at call time.
RK_ALPHA_INIT = 1e-4
RK_ALPHA_GROWTH = 5
RK_MAX_ALPHA_TRIES = 15
RK_NAN_CHECK_STEPS = 500

key = random.PRNGKey(42)
rk_train = u_train.squeeze()  # [N_train, spatial_points]
rk_test = u_test.squeeze()    # [N_test, spatial_points]

grid_params = {'input_scale': np.linspace(0.2, 1.1, 10),
               'reservoir_scale': np.linspace(0.2, 1.1, 10),
               'bias_scale': np.linspace(0.5, 5, 19)}
params_list = []
errors_sum = []
for x in itertools.product(*grid_params.values()):
  model = RecurrentKernel(erf_kernel, 50, 0.11, x)
  # Initialise parameters for *this* model (a model must exist before init).
  params = model.init(key, rk_train)

  # forward pass to get kernel etc.
  # NOTE(review): the return value is discarded and flax `apply` is functional
  # (no `mutable=` here), so unless `train_kernel` has side effects this call
  # only costs time -- confirm and drop if so.
  model.apply(params, rk_train, method=model.train_kernel);

  for attempt in range(RK_MAX_ALPHA_TRIES):
    alpha = RK_ALPHA_INIT * RK_ALPHA_GROWTH ** attempt
    model_state = train(model, params, rk_train, alpha=alpha)
    prediction = model.apply(params, rk_train, model_state, length=N_test, method=model.predict)
    prediction = np.asarray(prediction).squeeze()
    check_error = np.sum(np.abs(rk_test[:RK_NAN_CHECK_STEPS] - prediction[:RK_NAN_CHECK_STEPS]))
    if not np.isnan(check_error):
      break
  # Absolute error summed over the spatial axis -> one value per test step.
  errors_sum.append(np.sum(np.abs(rk_test - prediction), axis=-1))
  # Record the alpha that actually produced `prediction`.
  params_list.append(x + (alpha,))

  # Clear JAX's cache of compiled computations (private API) and drop large buffers.
  xla._xla_callable.cache_clear()
  del params, model_state, prediction

np.save("RK_gridsearch_error.npy", np.array(errors_sum))
np.save("RK_gridsearch_params.npy", np.array(params_list))