In [1]:
import jax

from RCN import RCN

import matplotlib.pyplot as plt
import numpy as np
from copy import copy
import jax.numpy as jnp
from jax import random, grad, jit
from functools import partial

from dysts.datasets import load_dataset
from utils import train_test_split, compute_forecast_horizon

from dysts.flows import Lorenz
from readouts import *




In [3]:
dt = 1e-3
train_per = 0.7
lam_lorenz = 0.906

## Load and simulate an attractor

model = Lorenz()
model.dt = dt

t, x_tot = model.make_trajectory(8000, return_times=True)
x_dot_tot = jnp.array(model.rhs(x_tot, t)).T

x_train, x_test = train_test_split(x_tot,1000, train_percentage=train_per)
x_dot_train, x_dot_test = train_test_split(x_dot_tot, 1000, train_percentage=train_per)




key = random.PRNGKey(14)
#readout = LinearReadout(500, 1e-6)
readout = QuadraticReadout(500, reg_param=1e-6)
rcn = RCN(key=key, readout=readout, n_input=3, dt =dt, washout_steps=100)
rcn.train(x_train, x_dot_train)
y = rcn.predict_states()

print(f"MSE is {rcn.train_MSE()}")



d_mce = rcn.derivative_train_MSE(x_dot_train)
print(f"MSE on derivative is {d_mce}")

print("generating test")
y_test = rcn.generate(len(x_test))

TypeError: unhashable type: 'ArrayImpl'

In [3]:
from jax import jacobian
r_o = rcn.read_out

AttributeError: 'RCN' object has no attribute 'read_out'

In [17]:
X = jnp.arange(3*10, dtype=float).reshape(10,3)
x = X[-1]
x.shape

(3,)

In [18]:
W_l = jnp.ones((3,2))
W_nl = jnp.ones((3,2))*2
print(W_l.shape,
W_nl.shape)

(3, 2) (3, 2)


In [21]:
funz = lambda x: x@W_l + x**2 @ W_nl
y = funz(x)
y

Array([4792., 4792.], dtype=float64)

In [24]:
funz_jac = jax.jacfwd(funz)

  # Example input for computing Jacobian
jac = funz_jac(x)
print(jac.shape)

(2, 3)


(2, 3)


In [26]:
almox_perex_jac = jax.vmap(funz_jac, in_axes=0)

In [27]:
almox_perex_jac(X).shape

(10, 2, 3)

In [28]:
almox_perex_jac(X)[1]

Array([[13., 17., 21.],
       [13., 17., 21.]], dtype=float64)

In [418]:
X[0]

Array([0., 1., 2.], dtype=float64)

(10, 3, 10, 500)

In [29]:
import jax
import jax.numpy as jnp

X = jnp.arange(3 * 10, dtype=float).reshape(10, 3)
W = jnp.arange(3 * 2, dtype=float).reshape( 3,2)

def funz(x):
    return jnp.dot( x, W_l) + jnp.dot(x**2, W_nl)

funz_jac = jax.jacfwd(funz)

jac = jax.vmap(funz_jac, axis_name=1)(X)
jac = jnp.swapaxes(jac, 1, 2)  # Swap the last two dimensions
print(jac.shape)  # Output: (10, 3, 2)


(10, 3, 2)


In [30]:
jac

Array([[[  1.,   1.],
        [  5.,   5.],
        [  9.,   9.]],

       [[ 13.,  13.],
        [ 17.,  17.],
        [ 21.,  21.]],

       [[ 25.,  25.],
        [ 29.,  29.],
        [ 33.,  33.]],

       [[ 37.,  37.],
        [ 41.,  41.],
        [ 45.,  45.]],

       [[ 49.,  49.],
        [ 53.,  53.],
        [ 57.,  57.]],

       [[ 61.,  61.],
        [ 65.,  65.],
        [ 69.,  69.]],

       [[ 73.,  73.],
        [ 77.,  77.],
        [ 81.,  81.]],

       [[ 85.,  85.],
        [ 89.,  89.],
        [ 93.,  93.]],

       [[ 97.,  97.],
        [101., 101.],
        [105., 105.]],

       [[109., 109.],
        [113., 113.],
        [117., 117.]]], dtype=float64)

In [445]:
jac[0]

Array([[0., 1.],
       [2., 3.],
       [4., 5.]], dtype=float64)

In [32]:
import jax
import jax.numpy as jnp

X = jnp.arange(3*10, dtype=float).reshape(3, 10)
W = jnp.ones((2, 3))

def funz(x):
    return jnp.dot(W, (0.5 * x ** 2))

funz_jac = jax.jacrev(funz)

x = X[:,0]  # Example input for computing Jacobian
jac = funz_jac(x)
print(jac.shape)  # Output: (2, 3)


(2, 3)
