In [None]:
import brainpy as bp
import brainpy.math.jax as bm
import matplotlib.pyplot as plt

# Select the JAX backend for `bp.math`. The model below builds its state
# (`bm.Variable`) and its loops (`bm.easy_scan`) from `brainpy.math.jax`.
bp.math.use_backend('jax')

class EchoStateNet(bp.DynamicalSystem):
  r"""Continuous-time echo state network whose readout is trained online by
  recursive least squares (FORCE-style learning).

  One Euler step of size ``dt`` (see :meth:`update`) computes::

      dh/dt = -h + x @ w_ir + r @ w_rr + o @ w_or
      h    <- h + dt / tau * dh/dt
      r     = tanh(h)
      o     = r @ w_ro

  Only the readout ``w_ro`` is learned (see :meth:`rls`). The input (``w_ir``),
  recurrent (``w_rr``) and output-feedback (``w_or``) weights stay fixed at
  their random initial values.

  Parameters
  ----------
  num_input, num_hidden, num_output : int
      Sizes of the input, hidden and output layers.
  tau : float
      Time constant of the hidden units.
  dt : float
      Euler integration step.
  g : float
      Gain multiplying the random recurrent weights ``w_rr``.
  alpha : float
      Initial diagonal value of the RLS inverse correlation matrix
      (``P = alpha * I``).
  """

  def __init__(self, num_input, num_hidden, num_output,
               tau=1.0, dt=0.1, g=1.8, alpha=1.0, **kwargs):
    super(EchoStateNet, self).__init__(**kwargs)

    # parameters
    self.tau = tau
    self.dt = dt
    self.alpha = alpha

    # fixed random weights, scaled by 1/sqrt(fan-in) (w_or is left unscaled)
    self.w_ir = bm.random.normal(size=(num_input, num_hidden)) / bm.sqrt(num_input)
    self.w_rr = g * bm.random.normal(size=(num_hidden, num_hidden)) / bm.sqrt(num_hidden)
    self.w_or = bm.random.normal(size=(num_output, num_hidden))
    # trained readout weights: a Variable because `rls` modifies them in place
    w_ro = bm.random.normal(size=(num_hidden, num_output)) / bm.sqrt(num_hidden)
    self.w_ro = bm.Variable(w_ro)

    # state variables
    self.h = bm.Variable(bm.random.normal(size=num_hidden) * 0.25)  # hidden
    # The firing rate is rewritten on every step (`self.r.value = ...` in
    # `update`) and is passed as scan state in `simulate`/`train`. It must
    # therefore be a Variable like the rest of the state. A plain array is
    # not collected by `self.vars()` (the `dyn_vars` used by `train`), so its
    # per-step update would not be tracked by the scan.
    self.r = bm.Variable(bm.tanh(self.h))  # firing rate
    self.o = bm.Variable(bm.dot(self.r, w_ro))  # output unit
    self.P = bm.Variable(bm.eye(num_hidden) * self.alpha)  # inverse correlation matrix

  def update(self, x, **kwargs):
    """Advance the network by one Euler step driven by the input ``x``.

    ``x`` is multiplied by ``w_ir``, so its last dimension must be
    ``num_input``. Updates ``h``, ``r`` and ``o`` in place.
    """
    dhdt = -self.h + bm.dot(x, self.w_ir)
    dhdt += bm.dot(self.r, self.w_rr)
    dhdt += bm.dot(self.o, self.w_or)  # feedback of the current output
    self.h += self.dt / self.tau * dhdt
    self.r.value = bm.tanh(self.h)
    self.o.value = bm.dot(self.r, self.w_ro)

  def rls(self, target):
    """One recursive-least-squares update of ``P`` and the readout ``w_ro``.

    With ``k = P r`` and ``c = 1 / (1 + r^T P r)``::

        P    <- P - c * k k^T
        e     = o - target
        w_ro <- w_ro - (c * k) e^T

    ``c * k`` equals the updated ``P`` applied to ``r``, so this is the
    standard RLS step that uses the new ``P``.
    """
    # update inverse correlation matrix
    k = bm.expand_dims(bm.dot(self.P, self.r), axis=1)  # (num_hidden, 1)
    hPh = bm.dot(self.r, k)  # (1,)  == r^T P r
    c = 1.0 / (1.0 + hPh)  # (1,)
    self.P -= bm.dot(k * c, k.T)  # (num_hidden, num_hidden)
    # update the output weights, using the error of the current output
    e = bm.atleast_2d(self.o - target)  # (1, num_output)
    dw = bm.dot(-c * k, e)  # (num_hidden, num_output)
    self.w_ro += dw

  def simulate(self, xs):
    """Run the network over the input sequence ``xs`` without learning.

    Scans `update` over the leading axis of ``xs``. Returns the per-step
    histories of the firing rate ``r`` and the output ``o``, as stacked by
    ``easy_scan``.
    """
    f = bm.easy_scan(self.update, dyn_vars=[self.h, self.r, self.o], out_vars=[self.r, self.o])
    return f(xs)

  def train(self, xs, targets):
    """Run the network over ``xs``, applying one RLS update after every step.

    ``xs`` and ``targets`` are scanned together along their leading axis.
    Returns the per-step histories of ``r`` and ``o``, as in `simulate`.
    NOTE(review): the scan function is rebuilt on every call, so repeated
    calls may recompile.
    """
    def _f(x):
      inp, target = x  # avoid shadowing the builtin `input`
      self.update(inp)
      self.rls(target)

    f = bm.easy_scan(_f, dyn_vars=self.vars(), out_vars=[self.r, self.o])
    return f([xs, targets])

In [None]:
# Build a fresh 1-input / 500-unit / 20-output network and run it open-loop,
# i.e. before any readout training, to record its rates (rs) and outputs (ys).
esn = EchoStateNet(
    num_input=1,
    num_hidden=500,
    num_output=20,
    dt=dt,
    g=1.5,
)
rs, ys = esn.simulate(xs)

In [None]:
# Number of passes of RLS training over the whole input sequence. Each call to
# `train` continues from the network state and weights left by the previous one.
NUM_TRAIN_EPOCHS = 10

esn = EchoStateNet(num_input=1, num_hidden=500,
                   num_output=20, dt=dt, g=1.5, alpha=1.)
for epoch in range(NUM_TRAIN_EPOCHS):
    rs, ys = esn.train(xs=xs, targets=targets)  # one pass over the sequence