In [1]:
import matplotlib.pyplot as plt

from __future__ import print_function, division
import jax.numpy as np
from jax import grad, jit, vmap
from jax import random
import numpy as onp
import time as time_time
# JAX is a package, by the same author as the paper, that speeds up
# linear-algebra-heavy operations.

In [2]:
# Generate randomness

def keygen(key, nkeys):
  """Split a JAX PRNG key into one carry-over key and a stream of subkeys.

  Args:
    key : the random.PRNGKey to split
    nkeys : number of subkeys the returned generator will yield
  Returns:
    2-tuple (key to keep for later splits, generator over the nkeys subkeys)
  """
  split_keys = random.split(key, nkeys+1)

  def _subkey_stream():
    # Everything after the first split key is handed out one at a time.
    yield from split_keys[1:]

  return split_keys[0], _subkey_stream()

In [3]:
# Generate random parameters 

def random_esn_params(key, u, n, m, tau=1.0, dt=0.1, g=1.0):
  """Draw a random set of echo-state-network parameters.

  Arguments:
    key: random.PRNGKey used as the source of randomness
    u: dim of the input
    n: dim of the hidden state
    m: dim of the output
    tau: "neuronal" time constant
    dt: time between Euler integration updates
    g: scaling of the recurrent matrix in the reservoir

  Returns:
    A dictionary with the initial activation 'a0' (n,), the matrices
    'wI' (n,u), 'wR' (n,n), 'wO' (m,n), 'wF' (n,m) and the scalar
    'dt_over_tau' = dt / tau.
  """
  _, subkeys = keygen(key, 5)

  def scaled_normal(shape, scale):
    # Each call consumes the next subkey, so the call order below fixes
    # which subkey feeds which parameter.
    return random.normal(next(subkeys), shape) * scale

  params = {}
  params['a0'] = scaled_normal((n,), 0.25)
  params['wI'] = scaled_normal((n,u), 1.0 / np.sqrt(u))
  params['wR'] = scaled_normal((n,n), g / np.sqrt(n))
  params['wO'] = scaled_normal((m,n), 1.0 / np.sqrt(n))
  params['wF'] = scaled_normal((n,m), 1.0)  # Feedback factor, keep at 1 for now.
  params['dt_over_tau'] = dt / tau
  return params


def new_force_params(n, alpha=1.0):
  """Build the initial RLS state (inverse correlation matrices).

  Arguments:
    n: dim of hidden state
    alpha: scale applied to the identity matrices used as the initial state

  Returns:
    A dictionary with
      'P': (n, n) identity times alpha,
      'P_recurr': (n, n, n) array holding n copies of the identity
        (one per leading index), times alpha.
  """
  # n stacked identity matrices: entry [i, j, k] is 1 when j == k, else 0.
  stacked_identity = onp.tile(onp.eye(n), (n, 1, 1))

  rls_state = {}
  rls_state['P'] = np.eye(n) * alpha
  rls_state['P_recurr'] = np.array(stacked_identity)*alpha
  return rls_state


# Run one iteration of "forward pass". 
# TODO: Implement in Keras model
def esn(x, a, h, z, wI, wR, wF, wO, dtdivtau):
  """Advance the continuous-time echo-state network by one Euler step.

    The integrated dynamics are

      da/dt = -a + wI x + wR h

    The feedback term `wF z` is currently disabled, so `wF` and `z` are
    accepted but not used.

    Arguments:
      x: ndarray of input to ESN
      a: ndarray of activations (pre nonlinearity) from prev time step
      h: ndarray of hidden states from prev time step
      z: ndarray of output from prev time step (unused)
      wI: ndarray, input matrix, shape (n, u)
      wR: ndarray, recurrent matrix, shape (n, n)
      wF: ndarray, feedback matrix, shape (n, m) (unused)
      wO: ndarray, output matrix, shape (m, n)
      dtdivtau: dt / tau

    Returns:
      3-tuple (a, h, z): the new activations, the new hidden state tanh(a)
      and the new output wO h.
  """
  input_drive = np.dot(wI, x)
  recurrent_drive = np.dot(wR, h)

  a_next = a + dtdivtau * (-a + input_drive + recurrent_drive)
  h_next = np.tanh(a_next)

  return a_next, h_next, np.dot(wO, h_next)

# Calculate weight updates
# TODO: Implement in Keras model
def rls(h, z, f, wO, P):
  """Perform one recursive-least-squares update of the readout weights.

    Arguments:
      h: ndarray of hidden state at current time step
      z: ndarray of output at current time step
      f: ndarray of targets at current time step
      wO: ndarray of output weights, shape (m, n)
      P: ndarray, inverse correlation matrix, shape (n,n)

    Returns:
      A 4-tuple (new wO, new P, dw, dP) where dw is the amount subtracted
      from wO and dP the amount subtracted from P.
  """
  # See paper for meanings of P and w.
  # Rank-one downdate of the inverse correlation matrix.
  Ph = np.expand_dims(np.dot(P, h), axis=1)
  gain = 1.0/(1.0+np.dot(h.T, Ph))
  dP = np.dot(gain*Ph, np.transpose(Ph))
  P_new = P - dP

  # The weight change uses the already-updated P together with the error.
  err = np.atleast_2d(z-f)
  Ph_new = np.expand_dims(np.dot(P_new, h), axis=1)
  dw = np.dot(Ph_new, err).T

  return wO - dw, P_new, dw, dP


def update_recurr_P(h, z, f, P_recurr, wR):
  """Perform one RLS update of the recurrent weights.

    P_recurr holds one (n, n) inverse correlation matrix per leading index i;
    each of them receives its own rank-one downdate and drives row i of the
    weight change.

    Arguments:
      h: ndarray of hidden state at current time step, shape (n,)
      z: ndarray of output at current time step
      f: ndarray of targets at current time step
      P_recurr: ndarray of stacked inverse correlation matrices, shape (n,n,n)
      wR: ndarray of recurrent weights, shape (n, n)

    Returns:
      A 4-tuple (new wR, new P_recurr, dwR, dP_recurr).

    Raises:
      AssertionError: if the error z - f is not a single scalar (1-d output).
  """
  h_col = np.expand_dims(h,axis = 1)

  # Row i of Ph is P_recurr[i] @ h.
  Ph = np.dot(P_recurr, h_col)[:,:,0]

  # One scalar h^T P_i h per i, shaped (n, 1, 1) so it broadcasts below.
  hPh = np.expand_dims(np.dot(Ph, h_col),axis = 2)

  # Outer product of each row of Ph with itself, normalised per i.
  outer = np.expand_dims(Ph, axis = 2) * np.expand_dims(Ph, axis = 1)
  dP_recurr = outer / (1+hPh)

  P_recurr -= dP_recurr

  err = np.atleast_2d(z-f)
  assert err.shape == (1,1)

  # The weight change uses the already-updated P_recurr scaled by the error.
  dwR = err*np.dot(P_recurr, h_col)[:,:,0]

  return wR - dwR, P_recurr, dwR, dP_recurr
  #h_mask = wR_mask * h
  #dwR = np.diagonal(np.dot(P_recurr, h_mask), axis1=0, axis2=2) # missing error term



In [4]:
# This shows the beginning to end pipeline
# This should all change in the Keras implementation
# For a first pass, ignore the use of JAX and try to get it working
# with no optimizations (i.e. only using numpy/tensorflow)

def esn_run_and_train_jax(params, fparams, x_t, f_t=None, do_train=False):
  """Step the echo-state network once per row of x_t, optionally training it.

    The time loop is a plain Python for loop over x_t.

    Arguments:
      params: dict of ESN params; its 'wO', 'wR' and 'a0' entries are
        overwritten in place with the final values
      fparams: dict of RLS params; only read/written when do_train is True
      x_t: ndarray of input time series, shape (t, u)
      f_t: ndarray of target time series, shape (t, m)
      do_train: Should the network be trained on this run?

    Returns:
      4-tuple of params, fparams, h_t, z_t, after running ESN and potentially
        updating the readout and recurrent weights.
  """
  # Starting state and weights.
  a = params['a0']
  h = np.tanh(a)
  wO, wI, wR, wF = params['wO'], params['wI'], params['wR'], params['wF']
  z = np.dot(wO, h)

  # RLS state is only needed (and only looked up) when training.
  P = fparams['P'] if do_train else None
  P_recurr = fparams['P_recurr'] if do_train else None

  dtdivtau = params['dt_over_tau']

  hidden_history = []
  output_history = []
  for tidx, x in enumerate(x_t):
    a, h, z = esn(x, a, h, z, wI, wR, wF, wO, dtdivtau)

    if do_train:
      target = f_t[tidx]
      # Readout weights first, then recurrent weights; the returned
      # increments are not needed here.
      wO, P, _, _ = rls(h, z, target, wO, P)
      wR, P_recurr, _, _ = update_recurr_P(h, z, target, P_recurr, wR)

    hidden_history.append(h)
    output_history.append(z)

  if do_train:
    fparams['P'] = P
    fparams['P_recurr'] = P_recurr

  # Persist the (possibly trained) weights and the final activation so that a
  # later call continues from where this one stopped.
  params['wO'] = wO
  params['wR'] = wR
  params['a0'] = a

  return params, fparams, np.array(hidden_history), np.array(output_history)


def esn_run_jax(params, x_t):
  """Run the echo-state network forward without training.

    Arguments:
      params: dict of ESN params
      x_t: ndarray of input with shape (t,u)

    Returns:
      2-tuple of ndarrays with first dim time, the hidden state and the outputs.
  """
  run_result = esn_run_and_train_jax(params, None, x_t, None, do_train=False)
  # Only the two time series are of interest; the param dicts are dropped.
  _, _, h_t, z_t = run_result
  return h_t, z_t

esn_run_jax_jit = jit(esn_run_jax)


def esn_train_jax(params, fparams, x_t, f_t):
  """Run the echo-state network forward while training it.

    Arguments:
      params: dict of ESN params
      fparams: dict of RLS params
      x_t: ndarray of inputs with shape (t,u)
      f_t: ndarray of targets with shape (t,m)

    Returns:
      4-tuple of updated params, fparams, and also ndarrays with first dim
        time, the hidden state and the outputs.
  """
  trained = esn_run_and_train_jax(params, fparams, x_t, f_t=f_t, do_train=True)
  return trained

esn_train_jax_jit = jit(esn_train_jax)

In [5]:
# Basic parameters of the Echostate networks
# NOTE: `np` is still jax.numpy in this cell, so `time` and `x_t` are JAX arrays.
key = random.PRNGKey(0)

T = 30              # total time
u = 1               # number of inputs (didn't bother to set up zero, just put in zeros)
n = 400            # size of the reservoir in the ESN
tau = 1          # neuron time constant
dt = tau / 10.0     # Euler integration step
time = np.arange(0, T, dt) # all time
ntime = time.shape[0]      # the number of time steps
alpha = 1e0  # NOTE(review): not referenced by any later visible cell - presumably the RLS `alpha`; confirm
m = 1        # number of outputs (passed as `m` to random_esn_params)

x_t = np.zeros((ntime,u)) # Just a stand-in if folks want a real input later



In [6]:
# Generate some target data by running an ESN, and just grabbing hidden 
# dimensions as the targets of the FORCE trained network.

g = 1.8  # Recurrent scaling of the data ESN, gives how wild the dynamics are.

# The seed comes from NumPy's global RNG, which is not seeded in the visible
# cells, so the targets differ from run to run; the printed seed is what makes
# a given run reproducible.
data_seed = onp.random.randint(0, 10000000)
print("Data seed: %d" % (data_seed))
key = random.PRNGKey(data_seed)
data_params = random_esn_params(key, u, n, m, g=g)
h_t, z_t = esn_run_jax_jit(data_params, x_t)

# Targets are the first m hidden units of the data-generating network.
f_t = h_t[:,0:m] # This will be the training data for the trained ESN


Data seed: 8717853


In [7]:
# Initial parameters of the network that will be trained. `g` is rebound here
# (it was 1.8 for the data-generating network above).
g = 1.5
# As with the data seed, this is drawn from NumPy's global RNG and printed so
# the run can be reproduced.
params_seed = onp.random.randint(0, 10000000)
print("Params seed %d" %(params_seed))
key = random.PRNGKey(params_seed)
init_params = random_esn_params(key, u, n, m, g=g)

Params seed 3836463


In [8]:
import numpy as np
import tensorflow as tf
import tensorflow.keras as keras
from tensorflow.keras import backend, activations

In [9]:
# Convert the JAX parameters to TensorFlow tensors. The JAX code computes
# np.dot(W, x) while the Keras cells below compute backend.dot(x, kernel), so
# every matrix is transposed on the way over.
wI = tf.transpose(tf.convert_to_tensor(init_params['wI']))
wR = tf.transpose(tf.convert_to_tensor(init_params['wR']))
wF = tf.transpose(tf.convert_to_tensor(init_params['wF']))
wO = tf.transpose(tf.convert_to_tensor(init_params['wO']))

# Initial activation with a leading axis of size 1 added in front.
a0 = tf.convert_to_tensor(tf.expand_dims(init_params['a0'], axis = 0))

# NOTE: this rebinds f_t from the JAX array to a float32 TensorFlow tensor.
f_t =  tf.convert_to_tensor(f_t,dtype=tf.float32) 
input2 = tf.convert_to_tensor(x_t)
ntraining = 20  # used as the number of epochs in the training cell below

In [10]:
class FORCELayer(keras.layers.AbstractRNNCell):
    """Base RNN cell that owns the input, recurrent, feedback and output kernels.

    The class only creates and stores the weights; it does not define `call`,
    which is left to subclasses. Kernels are stored so that they are applied as
    `backend.dot(x, kernel)` in the subclasses, i.e. with shapes

        input_kernel     : (input_dim, units)
        recurrent_kernel : (units, units)
        feedback_kernel  : (output_size, units)
        output_kernel    : (units, output_size)

    Arguments:
        units: number of recurrent units
        output_size: dimension of the cell output
        activation: activation name or callable, resolved with activations.get
        seed: seed for the generator that seeds the random initializers; a
            non-deterministic generator is used when None
        g: scale of the recurrent kernel (stddev is g / sqrt(units))
        input_kernel_trainable, recurrent_kernel_trainable,
        output_kernel_trainable, feedback_kernel_trainable: `trainable` flag
            given to the corresponding weight
        p_recurr: the randomly initialised recurrent kernel is passed through
            Dropout(1 - p_recurr) and multiplied by p_recurr
    """

    def __init__(self, units, output_size, activation, seed = None, g = 1.5, 
                 input_kernel_trainable = False, recurrent_kernel_trainable = False, 
                 output_kernel_trainable = True, feedback_kernel_trainable = False, p_recurr = 1, **kwargs):

        self.units = units
        self._output_size = output_size
        self.activation = activations.get(activation)

        self.seed_gen = (tf.random.Generator.from_non_deterministic_state()
                         if seed is None
                         else tf.random.Generator.from_seed(seed))

        self._g = g

        self._input_kernel_trainable = input_kernel_trainable
        self._recurrent_kernel_trainable = recurrent_kernel_trainable
        self._feedback_kernel_trainable = feedback_kernel_trainable
        self._output_kernel_trainable = output_kernel_trainable
        self._p_recurr = p_recurr

        super().__init__(**kwargs)

    @property
    def state_size(self):
        """Sizes of the three state tensors carried between steps."""
        return [self.units, self.units, self.output_size]

    @property
    def output_size(self):
        """Dimension of the cell output."""
        return self._output_size

    def _draw_seed(self):
        """Draw one int64 seed from the layer's generator."""
        return self.seed_gen.uniform([1], minval=None, dtype=tf.dtypes.int64)[0]

    def _constant_weight(self, name, shape, value, trainable):
        """Register a weight that starts out equal to `value`."""
        return self.add_weight(shape=shape,
                               initializer=keras.initializers.constant(value),
                               trainable=trainable,
                               name=name)

    def initialize_input_kernel(self, input_shape, input_kernel = None):
        """Create `self.input_kernel`.

        `input_shape` is the input dimension (an int). When `input_kernel` is
        None the values are drawn from N(0, 1/input_shape).
        """
        kernel_shape = (input_shape, self.units)
        if input_kernel is None:
            sampler = keras.initializers.RandomNormal(mean=0.,
                                                      stddev=1/input_shape**0.5,
                                                      seed=self._draw_seed())
            input_kernel = sampler(shape = kernel_shape)

        self.input_kernel = self._constant_weight('input_kernel', kernel_shape, input_kernel,
                                                  self._input_kernel_trainable)

    def initialize_recurrent_kernel(self, recurrent_kernel = None):
        """Create `self.recurrent_kernel`.

        When `recurrent_kernel` is None the values are drawn from
        N(0, g^2/units), passed through Dropout(1 - p_recurr) in training mode
        and multiplied by p_recurr.
        """
        kernel_shape = (self.units, self.units)
        if recurrent_kernel is None:
            sampler = keras.initializers.RandomNormal(mean=0.,
                                                      stddev=self._g/self.units**0.5,
                                                      seed=self._draw_seed())
            sparsifier = keras.layers.Dropout(1-self._p_recurr)
            recurrent_kernel = self._p_recurr*sparsifier(sampler(shape = kernel_shape),
                                                         training = True)

        self.recurrent_kernel = self._constant_weight('recurrent_kernel', kernel_shape, recurrent_kernel,
                                                      self._recurrent_kernel_trainable)

    def initialize_feedback_kernel(self, feedback_kernel = None):
        """Create `self.feedback_kernel` (drawn from N(0, 1) when not given)."""
        kernel_shape = (self.output_size, self.units)
        if feedback_kernel is None:
            sampler = keras.initializers.RandomNormal(mean=0.,
                                                      stddev=1,
                                                      seed=self._draw_seed())
            feedback_kernel = sampler(shape = kernel_shape)

        self.feedback_kernel = self._constant_weight('feedback_kernel', kernel_shape, feedback_kernel,
                                                     self._feedback_kernel_trainable)

    def initialize_output_kernel(self, output_kernel = None):
        """Create `self.output_kernel` (drawn from N(0, 1/units) when not given)."""
        kernel_shape = (self.units, self.output_size)
        if output_kernel is None:
            sampler = keras.initializers.RandomNormal(mean=0.,
                                                      stddev=1/self.units**0.5,
                                                      seed=self._draw_seed())
            output_kernel = sampler(shape = kernel_shape)

        self.output_kernel = self._constant_weight('output_kernel', kernel_shape, output_kernel,
                                                   self._output_kernel_trainable)

    def build(self, input_shape):
        """Create all four kernels randomly; the input dim is input_shape[-1]."""
        # The creation order fixes the order in which seeds are drawn.
        self.initialize_input_kernel(input_shape[-1])
        self.initialize_recurrent_kernel()
        self.initialize_feedback_kernel()
        self.initialize_output_kernel()

        self.built = True

    @classmethod
    def from_weights(cls, weights, **kwargs):
        """Build an already-built cell from user-supplied kernels.

        Arguments:
            weights: 4-sequence (input_kernel, recurrent_kernel,
                feedback_kernel, output_kernel) with the shapes listed in the
                class docstring
            **kwargs: forwarded to the constructor; must not contain
                'p_recurr'

        Raises:
            AssertionError: if the kernel shapes disagree or 'p_recurr' is
                passed.
        """
        input_kernel, recurrent_kernel, feedback_kernel, output_kernel = weights

        input_shape, input_units = input_kernel.shape
        recurrent_units1, recurrent_units2 = recurrent_kernel.shape
        feedback_output_size, feedback_units = feedback_kernel.shape
        output_units, output_size = output_kernel.shape

        # Every kernel must agree on the number of units.
        units = input_units
        unit_dims = np.array([input_units, recurrent_units1, recurrent_units2,
                              feedback_units, output_units])
        assert np.all(unit_dims == units)

        assert feedback_output_size == output_size
        assert 'p_recurr' not in kwargs.keys(), 'p_recurr not supported in this method'

        self = cls(units=units, output_size=output_size, p_recurr = None, **kwargs)

        self.initialize_input_kernel(input_shape, input_kernel)
        self.initialize_recurrent_kernel(recurrent_kernel)
        self.initialize_feedback_kernel(feedback_kernel)
        self.initialize_output_kernel(output_kernel)

        self.built = True

        return self

class FORCEModel(keras.Model):
    """Keras model that trains a FORCE cell with recursive least squares (RLS).

    The cell is wrapped in a stateful `keras.layers.RNN`. Training does not use
    automatic differentiation: `train_step` walks through the sequence one time
    step at a time, computes RLS "pseudogradients" and applies them with plain
    SGD at learning rate 1, so every update is `variable -= pseudogradient`.

    Arguments:
        force_layer: the RNN cell (a FORCELayer subclass instance)
        alpha_P: gain of the identity matrices the P matrices start from
        return_sequences: forwarded to keras.layers.RNN
    """

    def __init__(self, force_layer, alpha_P=1.,return_sequences=True):
        super().__init__()
        self.alpha_P = alpha_P
        self.force_layer = keras.layers.RNN(force_layer, 
                                            stateful=True, 
                                            return_state=True, 
                                            return_sequences=return_sequences)

        self.units = force_layer.units 

        self.original_force_layer = force_layer

    def build(self, input_shape):
        """Build the model, then create the P matrices and locate the trainables."""
        super().build(input_shape)
        self.initialize_P()
        self.initialize_train_idx()

    def initialize_P(self):
        """Create the RLS state: `P_output` and, if needed, `P_GG`.

        `P_output` is a (units, units) identity times alpha_P. `P_GG` is only
        created when the recurrent kernel is trainable; it is a
        (units, units, units) stack of identities times alpha_P in which the
        rows and columns selected by the cell's non-trainable mask are zeroed.
        """
        self.P_output = self.add_weight(name='P_output', shape=(self.units, self.units), 
                                 initializer=keras.initializers.Identity(
                                              gain=self.alpha_P), trainable=True)

        if self.original_force_layer.recurrent_kernel.trainable:

            identity_3d = np.zeros((self.units, self.units, self.units))
            idx = np.arange(self.units)
            identity_3d[:, idx, idx] = self.alpha_P 

            # Only some cells define a non-trainable mask; treat a missing
            # attribute like "no mask" instead of failing with AttributeError.
            mask = getattr(self.original_force_layer,
                           'recurrent_nontrainable_boolean_mask', None)

            if mask is not None:
                # The mask is transposed before its True positions are looked up.
                I,J = np.nonzero(tf.transpose(mask).numpy()==True)
                identity_3d[I,:,J]=0
                identity_3d[I,J,:]=0

            self.P_GG = self.add_weight(name='P_GG', shape=(self.units, self.units, self.units), 
                                    initializer=keras.initializers.constant(identity_3d), 
                                    trainable=True)

    def initialize_train_idx(self):
        """Record where each trained variable sits in `trainable_variables`.

        An index stays None when no trainable variable with that name exists.
        """
        self._output_kernel_idx = None
        self._recurrent_kernel_idx = None
        self._P_output_idx = None
        self._P_GG_idx = None
        for idx, variable in enumerate(self.trainable_variables):
          trainable_name = variable.name
              
          if 'output_kernel' in trainable_name:
            self._output_kernel_idx = idx
          elif 'P_output' in trainable_name:
            self._P_output_idx = idx
          elif 'P_GG' in trainable_name:
            self._P_GG_idx = idx
          elif 'recurrent_kernel' in trainable_name:
            self._recurrent_kernel_idx = idx

    def call(self, x, training=False,   **kwargs):
        """Run the wrapped RNN.

        In training mode the full RNN result (output followed by the states)
        is returned and the RNN state advances. Otherwise only the output is
        returned and, if the RNN already had states, they are restored
        afterwards so that inference does not disturb training.
        """
        if training:
            return self.force_layer_call(x, training, **kwargs)
        else:
            initialization = all(v is None for v in self.force_layer.states)
            
            if not initialization:
              original_state = [i.numpy() for i in self.force_layer.states]
            output = self.force_layer_call(x, training, **kwargs)[0]

            if not initialization:
              self.force_layer.reset_states(states = original_state)
            return output

    def force_layer_call(self, x, training, **kwargs):
        """Call the wrapped RNN layer; `training` is accepted but not forwarded."""
        return self.force_layer(x, **kwargs) 

    def train_step(self, data):
        """Run one pass over the sequence, updating the weights at every step.

        Arguments:
            data: pair (x, y); the loop runs over axis 1 of x and the targets
                are read as y[:, i, :]

        Returns:
            dict mapping metric names to their current values.
        """
        x, y = data

        if self.run_eagerly:
          self.hidden_activation = []

        # The set of trainable variables does not change during the loop.
        trainable_vars = self.trainable_variables

        for i in range(x.shape[1]):
          z, _, h, _ = self(x[:,i:i+1,:], training=True)

          if self.force_layer.return_sequences:
            z = z[:,0,:]

          if self._output_kernel_idx is not None:
            self.update_output_kernel(self.P_output, h, z, y[:,i,:], 
                                      trainable_vars[self._P_output_idx], 
                                      trainable_vars[self._output_kernel_idx])
          
          if self._recurrent_kernel_idx is not None:
            self.update_recurrent_kernel(self.P_GG, h, z, y[:,i,:],
                                         trainable_vars[self._P_GG_idx],
                                         trainable_vars[self._recurrent_kernel_idx])
          
          # Update metrics (includes the metric that tracks the loss)
          self.compiled_metrics.update_state(y[:,i,:], z)

          if self.run_eagerly:
            # Keep the hidden state of every step for later inspection.
            self.hidden_activation.append(h.numpy()[0])

        # Return a dict mapping metric names to current value
        return {m.name: m.result() for m in self.metrics}

    def update_output_kernel(self, P_output, h, z, y, trainable_vars_P_output, trainable_vars_output_kernel):
        """Apply one RLS update to P_output, then to the output kernel."""
        # Compute pseudogradients
        dP = self.pseudogradient_P(P_output, h)
        # Update weights
        self.optimizer.apply_gradients(zip([dP], [trainable_vars_P_output]))

        # Computed after P has been updated, so it uses the new P.
        dwO = self.pseudogradient_wO(P_output, h, z, y)
        self.optimizer.apply_gradients(zip([dwO], [trainable_vars_output_kernel]))

    def update_recurrent_kernel(self, P_Gx, h, z, y, trainable_vars_P_Gx, trainable_vars_recurrent_kernel):
        """Apply one RLS update to P_Gx, then to the recurrent kernel."""
        # Compute pseudogradients
        dP_Gx = self.pseudogradient_P_Gx(P_Gx, h)
        # Update weights
        self.optimizer.apply_gradients(zip([dP_Gx], [trainable_vars_P_Gx]))

        # Computed after P_Gx has been updated, so it uses the new P_Gx.
        dwR = self.pseudogradient_wR(P_Gx, h, z, y)
        self.optimizer.apply_gradients(zip([dwR], [trainable_vars_recurrent_kernel]))

    def pseudogradient_P(self, P, h):
        """Return the amount to subtract from P (the rls() update of P)."""
        # This not a real gradient (does not use gradient.tape())
        # Computes the actual update
        # Example array shapes
        # h : 1 x 500
        # P : 500 x 500 
        # k : 500 x 1 
        # hPht : 1 x 1
        # dP : 500 x 500 

        k = backend.dot(P, tf.transpose(h))
        hPht = backend.dot(h, k)
        c = 1./(1.+hPht)
        dP = backend.dot(c*k, tf.transpose(k))
        return  dP 

    def pseudogradient_wO(self, P, h, z, y):
        """Return the amount to subtract from the output kernel."""
        # z : 1 x 20 
        # y : 1 x 20
        # e : 1 x 20
        # dwO : 500 x 20  

        e = z-y
        Ph = backend.dot(P, tf.transpose(h))
        dwO = backend.dot(Ph, e)

        return  dwO

    def pseudogradient_wR(self, P_Gx, h, z, y):
        """Return the amount to subtract from the recurrent kernel.

        Raises:
            AssertionError: if the error z - y is not of shape (1, 1).
        """
        e = z - y 
        assert e.shape == (1,1), 'Output must only have 1 dimension'
        Ph = backend.dot(P_Gx, tf.transpose(h))[:,:,0]

        dwR = Ph*e ### only valid for 1-d output

        # Transposed to match the layout of the Keras recurrent kernel.
        return tf.transpose(dwR) 

    def pseudogradient_P_Gx(self, P_Gx, h):
        """Return the amount to subtract from the 3-d P_Gx stack."""
        Ph = backend.dot(P_Gx, tf.transpose(h))[:,:,0]
        hPh = tf.expand_dims(backend.dot(Ph, tf.transpose(h)),axis = 2)
        dP_Gx = tf.expand_dims(Ph, axis = 2) * tf.expand_dims(Ph, axis = 1)/(1+hPh)
        return dP_Gx

    def compile(self, metrics, **kwargs):
        """Compile with the fixed optimizer (SGD, learning rate 1) and 'mae' loss."""
        super().compile(optimizer=keras.optimizers.SGD(learning_rate=1), loss = 'mae', metrics=metrics,   **kwargs)

    def fit(self, x, y=None, epochs = 1, verbose = 'auto', **kwargs):
        """Validate the shapes, add a batch axis if needed and train.

        x and y must be rank 2 (a batch axis of size 1 is added in front) or
        rank 3 with a leading dimension of 1, and must agree on dimension 1
        (time steps).

        Raises:
            ValueError: if y is missing or any of the shape rules is violated.
        """
        if y is None:
            raise ValueError('y must be provided')

        if len(x.shape) < 2 or len(x.shape) > 3:
            raise ValueError('Shape of x is invalid')

        if len(y.shape) < 2 or len(y.shape) > 3:
            raise ValueError('Shape of y is invalid')
        
        if len(x.shape) == 2:
            x = tf.expand_dims(x, axis = 0)
        
        if len(y.shape) == 2:
            y = tf.expand_dims(y, axis = 0)
        
        if x.shape[0] != 1:
            raise ValueError("Dim 0 of x must be 1")

        if y.shape[0] != 1:
            raise ValueError("Dim 0 of y must be 1")
        
        if x.shape[1] != y.shape[1]: 
            raise ValueError('Timestep dimension of inputs must match')     

        return super().fit(x = x, y = y, epochs = epochs, batch_size = 1, verbose = verbose, **kwargs)

    def predict(self, x, **kwargs):
        """Run the model in inference mode and return the first batch entry.

        x must be rank 2 (a batch axis of size 1 is added in front) or rank 3
        with a leading dimension of 1. Extra keyword arguments are accepted
        but not used.

        Raises:
            ValueError: if the shape rules are violated.
        """
        if len(x.shape) < 2 or len(x.shape) > 3:
            raise ValueError('Shape of x is invalid')

        if len(x.shape) == 3 and x.shape[0] != 1:
            raise ValueError('Dim 0 must be 1')

        if len(x.shape) == 2:
            x = tf.expand_dims(x, axis = 0)

        return self(x, training = False)[0]

In [11]:
class EchoStateNetwork(FORCELayer):
    """FORCE cell implementing the continuous-time echo-state dynamics.

    Arguments:
        dtdivtau: Euler step divided by the neuronal time constant
        hscale: stddev of the random initial activation
        initial_a: optional fixed initial activation; when given it is used
            instead of a random draw
        **kwargs: forwarded to FORCELayer
    """

    def __init__(self, dtdivtau, hscale = 0.25, initial_a = None, **kwargs):
        self.dtdivtau = dtdivtau
        self.hscale = hscale
        self._initial_a = initial_a
        super().__init__(**kwargs)

    def call(self, inputs, states):
        """Implements the forward step (i.e., the esn() function), with feedback.

        `states` is (activation, hidden state, output) from the previous step;
        returns the new output and the new three-element state list.
        """
        a_prev, h_prev, z_prev = states

        from_input = backend.dot(inputs, self.input_kernel)
        from_recurrence = backend.dot(h_prev, self.recurrent_kernel)
        from_feedback = backend.dot(z_prev, self.feedback_kernel)

        # One Euler step of da/dt = -a + input + recurrence + feedback.
        a_next = a_prev + self.dtdivtau * (-a_prev + from_input + from_recurrence + from_feedback)
        h_next = self.activation(a_next)
        z_next = backend.dot(h_next, self.output_kernel)

        return z_next, [a_next, h_next, z_next]

    def get_initial_state(self, inputs=None, batch_size=None, dtype=None):
        """Return the initial (activation, hidden state, output) triple.

        The activation is `initial_a` when one was given, otherwise a draw
        from N(0, hscale^2) of shape (batch_size, units).
        """
        if self._initial_a is None:
            sampler = keras.initializers.RandomNormal(mean=0.,
                                                      stddev=self.hscale,
                                                      seed=self.seed_gen.uniform([1],
                                                                                 minval=None,
                                                                                 dtype=tf.dtypes.int64)[0])
            init_a = sampler((batch_size, self.units))
        else:
            init_a = self._initial_a

        init_h = self.activation(init_a)
        init_out = backend.dot(init_h, self.output_kernel)

        return (init_a, init_h, init_out)


In [12]:
class NoFeedbackESN(EchoStateNetwork):
    """Echo-state cell without the feedback term and without a feedback kernel.

    The recurrent kernel is trainable by default. The attribute
    `recurrent_nontrainable_boolean_mask` is None for randomly built cells and
    the supplied mask for cells created with `from_weights`.
    """

    def __init__(self, recurrent_kernel_trainable  = True, **kwargs):
        super().__init__(recurrent_kernel_trainable = recurrent_kernel_trainable, **kwargs)

    def call(self, inputs, states):
        """Implements the forward step (i.e., the esn() function), no feedback.

        The previous output in `states` is unpacked but not used.
        """
        a_prev, h_prev, _ = states

        from_input = backend.dot(inputs, self.input_kernel)
        from_recurrence = backend.dot(h_prev, self.recurrent_kernel)

        # One Euler step of da/dt = -a + input + recurrence.
        a_next = a_prev + self.dtdivtau * (-a_prev + from_input + from_recurrence)
        h_next = self.activation(a_next)
        z_next = backend.dot(h_next, self.output_kernel)

        return z_next, [a_next, h_next, z_next]

    def build(self, input_shape):
        """Create the input, recurrent and output kernels randomly (no mask)."""
        self.initialize_input_kernel(input_shape[-1])
        self.initialize_recurrent_kernel()
        self.initialize_output_kernel()

        self.recurrent_nontrainable_boolean_mask = None
        self.built = True

    @classmethod
    def from_weights(cls, weights, recurrent_nontrainable_boolean_mask, **kwargs):
        """Build an already-built cell from user-supplied kernels.

        Arguments:
            weights: 3-sequence (input_kernel, recurrent_kernel, output_kernel)
            recurrent_nontrainable_boolean_mask: boolean array shaped like the
                recurrent kernel; the kernel must be zero wherever it is True
            **kwargs: forwarded to the constructor; must not contain
                'p_recurr'

        Raises:
            AssertionError: if the shapes disagree, 'p_recurr' is passed, or
                the kernel is non-zero somewhere under the mask.
        """
        input_kernel, recurrent_kernel, output_kernel = weights

        input_shape, input_units = input_kernel.shape
        recurrent_units1, recurrent_units2 = recurrent_kernel.shape
        output_units, output_size = output_kernel.shape

        # Every kernel must agree on the number of units.
        units = input_units
        unit_dims = np.array([input_units, recurrent_units1, recurrent_units2,
                              output_units])
        assert np.all(unit_dims == units)

        assert 'p_recurr' not in kwargs.keys(), 'p_recurr not supported in this method'
        assert recurrent_kernel.shape == recurrent_nontrainable_boolean_mask.shape, "Boolean mask and recurrent kernel shape mis-match"

        masked_entries = tf.boolean_mask(recurrent_kernel, recurrent_nontrainable_boolean_mask)
        assert tf.math.count_nonzero(masked_entries).numpy() == 0, "Invalid boolean mask"  

        self = cls(units=units, output_size=output_size, p_recurr = None, **kwargs)

        self.recurrent_nontrainable_boolean_mask = tf.convert_to_tensor(recurrent_nontrainable_boolean_mask)

        self.initialize_input_kernel(input_shape, input_kernel)
        self.initialize_recurrent_kernel(recurrent_kernel)
        self.initialize_output_kernel(output_kernel)

        self.built = True

        return self

In [13]:
class secondFORCEModel(FORCEModel):
    """FORCEModel variant that uses a single 2-d P matrix when nothing is masked.

    When the cell has no non-trainable mask (or the mask has no True entry)
    all recurrent weights share one (units, units) `P_GG`; otherwise the
    (units, units, units) stack is built. The pseudogradient methods dispatch
    on the rank of `P_GG`.
    """

    def initialize_P(self):
        """Create `P_output` and, if the recurrent kernel is trainable, `P_GG`."""
        self.P_output = self.add_weight(name='P_output', shape=(self.units, self.units), 
                                        initializer=keras.initializers.Identity(gain=self.alpha_P), 
                                        trainable=True)

        if self.original_force_layer.recurrent_kernel.trainable:

            # Only some cells define a non-trainable mask; treat a missing
            # attribute like "no mask" instead of failing with AttributeError.
            bool_mask = getattr(self.original_force_layer,
                                'recurrent_nontrainable_boolean_mask', None)

            if bool_mask is None or tf.math.count_nonzero(bool_mask) == 0:
              # Nothing is masked: one shared 2-d matrix is enough.
              self.P_GG = self.add_weight(name='P_GG', shape=(self.units, self.units), 
                                          initializer=keras.initializers.Identity(gain=self.alpha_P), 
                                          trainable=True)

            else:
              print('bool mask count is NOT zero')
              identity_3d = np.zeros((self.units, self.units, self.units))
              idx = np.arange(self.units)

              identity_3d[:, idx, idx] = self.alpha_P 

              # The mask is transposed before its True positions are looked up;
              # the matching rows and columns of each matrix are zeroed.
              I,J = np.nonzero(tf.transpose(bool_mask).numpy()==True)
              identity_3d[I,:,J]=0
              identity_3d[I,J,:]=0

              self.P_GG = self.add_weight(name='P_GG', shape=(self.units, self.units, self.units), 
                                          initializer=keras.initializers.constant(identity_3d), 
                                          trainable=True)

    def pseudogradient_wR(self, P_Gx, h, z, y):
        """Return the amount to subtract from the recurrent kernel.

        Raises:
            AssertionError: if the error z - y is not of shape (1, 1).
        """
        e = z - y 
        assert e.shape == (1,1), 'Output must only have 1 dimension'

        if len(P_Gx.shape) == 2:
            # Shared P: the (units, 1) column is broadcast across all columns.
            dwR_inter = backend.dot(P_Gx, tf.transpose(h))*e
            return dwR_inter*tf.ones((P_Gx.shape))
        else:
            Ph = backend.dot(P_Gx, tf.transpose(h))[:,:,0]
            dwR = Ph*e ### only valid for 1-d output
            return tf.transpose(dwR) 

    def pseudogradient_P_Gx(self, P_Gx, h):
        """Return the amount to subtract from P_Gx (2-d or 3-d)."""
        if len(P_Gx.shape) == 2:
            # Shared P: same update as for the output P matrix.
            return self.pseudogradient_P(P_Gx,h)

        Ph = backend.dot(P_Gx, tf.transpose(h))[:,:,0]
        hPh = tf.expand_dims(backend.dot(Ph, tf.transpose(h)),axis = 2)
        dP_Gx = tf.expand_dims(Ph, axis = 2) * tf.expand_dims(Ph, axis = 1)/(1+hPh)
        return dP_Gx

        

In [14]:
%%time 
# Train a NoFeedbackESN whose self-connections are removed and frozen.
# NOTE: the timer starts before the set-up below, so the reported time per
# epoch also includes building the cell and the model.
start = time_time.time()

# Zero the diagonal of the recurrent kernel and mark every zero entry as
# non-trainable.
wR = wR.numpy()
np.fill_diagonal(wR, 0)
wR = tf.convert_to_tensor(wR)
bool_mask = wR == 0
#bool_mask = tf.ones((n,n)) != 1

# Fraction of recurrent weights that are masked.
print(tf.math.reduce_mean(tf.cast(bool_mask, tf.float32)))
print()
myesn2 = NoFeedbackESN.from_weights(weights = (wI, wR, wO), 
                                      dtdivtau=init_params['dt_over_tau'], 
                                      activation = 'tanh',
                                       recurrent_nontrainable_boolean_mask = bool_mask, 
                                       initial_a = a0)

# run_eagerly=True makes train_step record the hidden activations.
model2 = FORCEModel(myesn2, return_sequences=True)  
model2.compile(metrics=["mae"] , run_eagerly = True)

history2 = model2.fit(x=input2, y= f_t , epochs = ntraining)
end = time_time.time()
print('Seconds per epoch: ',f'{round(end-start,1)/ntraining}')

tf.Tensor(0.0025, shape=(), dtype=float32)

Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20
Epoch 18/20
Epoch 19/20
Epoch 20/20
Seconds per epoch:  76.115
CPU times: user 40min 2s, sys: 22.1 s, total: 40min 25s
Wall time: 25min 22s


In [15]:
# Diagonal of myesn2's recurrent kernel as it stands after the training cell above.
myesn2.recurrent_kernel.numpy().diagonal()

array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0.

In [16]:
# Diagonal of the wR tensor currently bound in the kernel.
wR.numpy().diagonal()

array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0.

In [17]:
%%time
start = time_time.time()
bool_mask = tf.ones((n,n)) != 1
myesn2 = NoFeedbackESN.from_weights(weights = (wI, wR, wO), 
                                      dtdivtau=init_params['dt_over_tau'], 
                                      activation = 'tanh',
                                       recurrent_nontrainable_boolean_mask = bool_mask, 
                                       initial_a = a0)

model2 = FORCEModel(myesn2, return_sequences=True)  
model2.compile(metrics=["mae"] , run_eagerly = True)

history2 = model2.fit(x=input2, y= f_t , epochs = ntraining)
end = time_time.time()
print('Seconds per epoch: ',f'{round(end-start,1)/ntraining}')

Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20
Epoch 18/20
Epoch 19/20
Epoch 20/20
Seconds per epoch:  71.02000000000001
CPU times: user 38min 12s, sys: 8.72 s, total: 38min 21s
Wall time: 23min 40s


In [18]:
# Diagonal of myesn2's recurrent kernel as it stands after the training cell above.
myesn2.recurrent_kernel.numpy().diagonal()

array([-0.01170432,  0.03379069, -0.03195122,  0.0445899 ,  0.0501068 ,
       -0.02229326,  0.00382762,  0.01983151,  0.12214473,  0.00242615,
        0.05157316, -0.02733548, -0.02111337,  0.04538266,  0.00200173,
       -0.02401695, -0.05029807, -0.05429257,  0.02734522, -0.01736236,
        0.01622583, -0.07090793,  0.05781898,  0.00932593, -0.03284219,
        0.07053553, -0.03637595, -0.02905682,  0.09198544,  0.00527502,
        0.03581332,  0.01031048, -0.08506119,  0.01068352,  0.02045452,
       -0.00258599,  0.01506071,  0.03377052,  0.01563763, -0.01687369,
       -0.000776  , -0.0128849 , -0.0094495 ,  0.07666625, -0.00844504,
        0.0603042 , -0.03435494, -0.09966249,  0.04538142,  0.02583141,
       -0.02017155, -0.05427522,  0.0235228 , -0.04752948,  0.00351077,
        0.02968894,  0.00098787, -0.07781325,  0.05562407, -0.08233822,
        0.03214607, -0.00844648,  0.06274369, -0.02784043,  0.05954359,
       -0.00526223,  0.07202754,  0.01271846,  0.01028233, -0.04

In [19]:
# fig, (ax1, ax2) = plt.subplots(1, 2, sharey = True, figsize=(24,12))


# ax1.plot(time , f_t  + 2*np.arange(0, f_t.shape[1]), 'g')
# ax1.plot(time , z_t_call_2 + 2*np.arange(0, z_t_call_2.shape[1]), 'r');
 
# ax1.set_xlim((0, T))
# plt.title('Target - f (green), Output - z (red)')
# ax1.set_xlabel('Time')
# ax1.set_ylabel('Dimension')

# ax2.plot(time, tf.math.abs(f_t-z_t_call_2) + 2*np.arange(0, z_t_call_2.shape[1]), 'r');
# ax2.set_xlim((0, T))
# plt.title('MAE')
# ax2.set_xlabel('Time')
# ax2.set_ylabel('Dimension')

In [20]:
# plt.figure(figsize=(12,8), dpi=80)
# plt.plot(history2.history['mae'])