<a href="https://colab.research.google.com/github/sdevries0/ISMI_group13/blob/main/Kopie_van_NSDE.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Neural SDE

In [18]:
!pip install diffrax
!pip install optax

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [19]:
import os

# A bare Python assignment does not reach the XLA backend: the setting has to
# live in the process environment. Keep this cell above the first `import jax`
# (presumably the value is read when the backend initialises - confirm).
TF_CPP_MIN_LOG_LEVEL = 0
os.environ["TF_CPP_MIN_LOG_LEVEL"] = str(TF_CPP_MIN_LOG_LEVEL)

In [20]:
from typing import Union
from math import pi
import diffrax as dfx
import equinox as eqx  # https://github.com/patrick-kidger/equinox
import jax
import jax.nn as jnn
import jax.numpy as jnp
import jax.random as jrandom
import matplotlib.pyplot as plt
import optax  # https://github.com/deepmind/optax
from typing import Callable

## Data

In [21]:
# Firing-rate nonlinearity shared by the data-generating system and the models.
# Alternatives that were tried and can be swapped in:
#   sigmoid: r = lambda x: 1/(1+jnp.exp(-x))
#   ReLU:    r = lambda x: (x>0) * x
def r(x):
    """Return the tanh firing rate of `x` (applied elementwise)."""
    return jnp.tanh(x)

In [22]:
# Define control path with multiple different functions. t is added to the resulting array.
class MultiControlPath(dfx.AbstractPath):
    """Control path `[t, sin(phase_i + frequency_i * t) for i < C]` with an optional noise channel.

    Attributes:
        C: number of sinusoidal control channels.
        phase: per-channel phases, indexed as `phase[i]`.
        frequency: per-channel frequencies, indexed as `frequency[i]`.
        noise: whether one extra noise channel is appended to the path.
        key: PRNG key from which the noise channel is derived.
    """
    C: int
    phase: jnp.ndarray
    frequency: jnp.ndarray
    noise: bool
    key: jnp.ndarray

    def __init__(self, phase, frequency, key, C = 2):
        """
        params:
            phase: phases of the sinusoids
            frequency: frequencies of the sinusoids
            key: PRNG key for the noise channel, or None for a noise-free path
            C: number of sinusoidal channels
        """
        self.C = C
        self.phase = phase
        self.frequency = frequency
        if key is None:
            self.noise = False
            # Placeholder so that the field always holds a key; never used for sampling.
            self.key = jrandom.PRNGKey(0)
        else:
            self.noise = True
            self.key = key

    def evaluate(self, t0, t1=None, left=True):
        """Value of the path at `t0`, or the increment between `t0` and `t1` when `t1` is given."""
        del left
        if t1 is not None:
            return self.evaluate(t1) - self.evaluate(t0)
        # One sinusoid per control channel, evaluated at t0.
        controls_at_t = jnp.array([jnp.sin(self.phase[i] + self.frequency[i] * t0) for i in range(self.C)])
        if self.noise:
            # Derive the sample from this path's own key. The previous code read a
            # module-level `key`, so every path shared the same noise realisation.
            # NOTE(review): fold_in expects integer data; a float t0 is presumably
            # cast to uint32, making the noise piecewise constant in t - confirm.
            new_key = jax.random.fold_in(self.key, t0)
            dw = jrandom.normal(new_key, shape=(1,))
            return jnp.append(jnp.append(t0, controls_at_t), dw)
        return jnp.append(t0, controls_at_t)

In [23]:
class CDE:
    """Controlled differential equation with a readout, driven by sinusoidal controls."""

    f_state : Callable
    f_obs : Callable

    def __init__(self, f_state, f_obs = lambda x: x):
        """
        params:
            f_state: vector field; function dom_state -> dom_state x dom_ctrl
            f_obs: readout applied to every saved state (identity, i.e. complete observability, by default)
        """
        self.f_obs = f_obs
        self.f_state = f_state

    def __call__(self, ts, phase, frequency, init):
        """
        Solves the CDE under a noise-free sinusoidal control and reads out the states.

        params:
            ts: time points at which the solution is saved; integration runs from ts[0] to ts[-1]
            phase: phases used for control
            frequency: frequencies used for control (its length sets the number of control channels)
            init: initial state of the CDE

        returns:
            (phase, frequency, init, hidden states at ts, observations at ts)
        """
        # Noise-free control: passing None as key disables the noise channel.
        n_controls = frequency.shape[0]
        path = MultiControlPath(phase, frequency, None, n_controls)
        term = dfx.ControlTerm(self.f_state, path).to_ode()

        # Adaptive-step Tsit5 solve, starting from an initial step of 0.01.
        solution = dfx.diffeqsolve(
            term,
            dfx.Tsit5(),
            ts[0],
            ts[-1],
            0.01,
            y0=init,
            stepsize_controller=dfx.PIDController(rtol=1e-3, atol=1e-6),
            saveat=dfx.SaveAt(ts=ts),
        )

        states = solution.ys
        observations = jax.vmap(self.f_obs)(states)
        return phase, frequency, init, states, observations

In [24]:
# We can use a CDE as a data generator
def dataloader(system, ts, nr_batch, keys, N=1, C=1, sd=1.0):
    """Sample a batch of initial states and sinusoid parameters and run `system` on each sample.

    params:
        system: callable `(ts, phase, frequency, init)`; vmapped over the leading
            axis of its last three arguments
        ts: time points, shared by all samples
        nr_batch: number of samples in the batch
        keys: at least three PRNG keys; the last three are used for the initial
            states, the frequencies and the phases (in that order)
        N: size of each initial state
        C: number of control channels
        sd: scale applied to the standard-normal initial states

    returns:
        the batched output of `system`

    raises:
        ValueError: if fewer than three keys are supplied
    """
    # Out-of-range indexing into a JAX array is clamped silently, so indexing
    # keys[3] on three keys used to reuse one key for both frequencies and phases.
    if len(keys) < 3:
        raise ValueError("dataloader needs at least three PRNG keys")
    # Taking the last three keeps the choice unchanged for callers passing four keys.
    init_key, frequency_key, phase_key = keys[-3], keys[-2], keys[-1]

    #Sample initial states
    init = sd*jrandom.normal(init_key, shape=(nr_batch, N))

    #Sample frequencies and phases
    frequency = jrandom.uniform(frequency_key, shape=(nr_batch, C), minval = 0.0, maxval = 3.0)
    phase = jrandom.normal(phase_key, shape=(nr_batch, C))

    #Generate data from the CDE
    return jax.vmap(system, in_axes=(None, 0, 0, 0))(ts, phase, frequency, init)

# NSDE + NCDE

In [25]:
def lipswish(x):
    """Return the SiLU activation of `x` scaled by 0.909."""
    activated = jnn.silu(x)
    return 0.909 * activated

In [26]:
#RNN that models the states of neurons given input. Used as state equation for an NCDE
class NeuralSystem(eqx.Module):
    """Vector field of a rate network with a learned noise component.

    Attributes:
        J: (N, N) recurrent weights.
        B: (N, C) input weights.
        b: (N,) bias.
        tau: time constant dividing the whole vector field.
        N: number of neurons.
        C: number of control inputs.
        noise_component: MLP mapping the state to the N * noise_size noise columns.
        noise_size: number of noise channels.
    """

    J: jnp.ndarray
    B: jnp.ndarray
    b: jnp.ndarray
    tau: float
    N: int
    C: int
    noise_component: eqx.nn.MLP
    noise_size: int

    def __init__(self, keys, N, C, noise_size, tau):
        """
        params:
            keys: four PRNG keys, used for J, B, b and the noise MLP respectively
            N: number of neurons
            C: number of control inputs
            noise_size: number of noise channels
            tau: time constant
        """
        super().__init__()
        self.J = jrandom.normal(keys[0], shape=(N,N))
        self.B = jrandom.normal(keys[1], shape=(N,C))
        self.b = jrandom.normal(keys[2], shape=(N,))
        self.tau = tau
        self.N = N
        self.C = C
        self.noise_size = noise_size
        self.noise_component = eqx.nn.MLP(in_size=N, out_size=N * noise_size, width_size=8, depth=1, activation=jnn.tanh, final_activation=jnn.tanh, key=keys[3],)

    def __call__(self, t, x, args):
        """Return the (N, 1 + C + noise_size) matrix that multiplies the control `[t, u, noise]`.

        Column 0 is the autonomous part `-x + J r(x) + b`, the next C columns are
        `B`, and the remaining columns come from the noise MLP; all divided by tau,
        i.e. tau * x' = -x + J r(x) + B u + b (+ noise).
        """
        autonomous = (-x + self.J @ r(x) + self.b).reshape(self.N, 1)
        # Use the module's own size; the previous code reshaped with a global `N`.
        noise = self.noise_component(x).reshape(self.N, self.noise_size)
        return jnp.concatenate((autonomous, self.B, noise), axis=1) / self.tau


In [27]:
class VectorField(eqx.Module):
    """State-dependent vector field `scale * mlp(y)` with tanh activations."""
    scale: Union[int, jnp.ndarray]
    mlp: eqx.nn.MLP

    def __init__(self, hidden_size, width_size, depth, scale, *, key, **kwargs):
        """
        params:
            hidden_size: size of the state (input and output size of the MLP)
            width_size: width of the MLP's hidden layers
            depth: number of hidden layers of the MLP
            scale: if truthy, draw a per-dimension scale uniformly from [0.9, 1.1); otherwise use 1
            key: PRNG key, split between the scale and the MLP
        """
        super().__init__(**kwargs)
        key_for_scale, key_for_mlp = jrandom.split(key)
        self.scale = (
            jrandom.uniform(key_for_scale, (hidden_size,), minval=0.9, maxval=1.1)
            if scale
            else 1
        )
        self.mlp = eqx.nn.MLP(
            in_size=hidden_size,
            out_size=hidden_size,
            width_size=width_size,
            depth=depth,
            activation=jnn.tanh,
            final_activation=jnn.tanh,
            key=key_for_mlp,
        )

    def __call__(self, t, y, args):
        """Evaluate the field at state `y`; `t` and `args` are ignored."""
        hidden = self.mlp(y)
        return self.scale * hidden


# class VectorField(eqx.Module):
#     scale: Union[int, jnp.ndarray]
#     mlp: eqx.nn.MLP

#     def __init__(self, hidden_size, control_size, noise_size, width_size, depth, scale, *, key, **kwargs):
#         super().__init__(**kwargs)
#         scale_key, mlp_key = jrandom.split(key)
#         self.mlp = eqx.nn.MLP(in_size=hidden_size, out_size=hidden_size*(control_size+1), width_size=width_size, depth=depth, activation=jnn.tanh, final_activation=jnn.tanh, key=mlp_key,)
#         self.noise_component = eqx.nn.MLP(in_size=hidden_size, out_size=hidden_size*noise_size, width_size=width_size, depth=depth, activation=jnn.tanh, final_activation=jnn.tanh, key=mlp_key,)
#         self.noise_size = noise_size
#         self.hidden_size = hidden_size
#         self.control_size = control_size


#     def __call__(self, t, x, args):
#         return jnp.concatenate((self.mlp(x).reshape(self.hidden_size,self.control_size), self.noise_component(x).reshape(self.hidden_size, self.noise_size)),axis=1)

In [28]:
class ControlledVectorField(eqx.Module):
    """Matrix-valued vector field `scale * mlp(y)` of shape (hidden_size, noise_size)."""
    scale: Union[int, jnp.ndarray]
    mlp: eqx.nn.MLP
    noise_size: int
    hidden_size: int

    def __init__(
        self, noise_size, hidden_size, width_size, depth, scale, *, key, **kwargs
    ):
        """
        params:
            noise_size: number of columns of the output (channels of the driving control)
            hidden_size: size of the state (number of rows of the output)
            width_size: width of the MLP's hidden layers
            depth: number of hidden layers of the MLP
            scale: if truthy, draw an elementwise scale uniformly from [0.9, 1.1); otherwise use 1
            key: PRNG key, split between the scale and the MLP
        """
        super().__init__(**kwargs)
        key_for_scale, key_for_mlp = jrandom.split(key)
        out_shape = (hidden_size, noise_size)
        self.scale = (
            jrandom.uniform(key_for_scale, out_shape, minval=0.9, maxval=1.1)
            if scale
            else 1
        )
        self.mlp = eqx.nn.MLP(
            in_size=hidden_size,
            out_size=hidden_size * noise_size,
            width_size=width_size,
            depth=depth,
            activation=jnn.tanh,
            final_activation=jnn.tanh,
            key=key_for_mlp,
        )
        self.noise_size = noise_size
        self.hidden_size = hidden_size

    def __call__(self, t, y, args):
        """Evaluate the field at state `y`; `t` and `args` are ignored."""
        as_matrix = self.mlp(y).reshape(self.hidden_size, self.noise_size)
        return self.scale * as_matrix

In [29]:
#Observation class used as the readout in the NCDE
class Readout(eqx.Module):
    """Linear readout of the firing rates: `W @ r(x)`."""

    W: jnp.ndarray  # (M, N) readout weights

    def __init__(self, key, N, M):
        """
        params:
            key: PRNG key used to draw W from a standard normal
            N: number of neurons (columns of W)
            M: number of observations (rows of W)
        """
        super().__init__()
        self.W = jrandom.normal(key, shape=(M,N))

    def __call__(self, x):
        """Map a hidden state `x` to its observation."""
        rates = r(x)
        return jnp.matmul(self.W, rates)

In [30]:
class NeuralSDE(eqx.Module):
    """Generator: a NeuralSystem driven by a noisy sinusoidal control, followed by a readout."""
    initial: eqx.nn.MLP
    drift: NeuralSystem  # vector field applied to the control [t, u, noise]
    # diffusion: ControlledVectorField  # separate diffusion network, currently disabled
    readout: Readout
    V: int  # size of the latent noise that seeds the initial state
    C: int  # number of control channels
    noise_size: int  # number of noise channels

    def __init__(self, V, noise_size, width_size, depth, key, N, C, M, tau, **kwargs,):
        """
        params:
            V: size of the latent noise fed to the initial-state MLP
            noise_size: number of noise channels of the vector field
            width_size, depth: hidden width and depth of the initial-state MLP
            key: PRNG key, split into sub-keys for the components
            N: number of neurons (state size)
            C: number of control channels
            M: number of observations
            tau: time constant of the neural system
        """
        super().__init__(**kwargs)
        subkeys = jrandom.split(key, 7)
        # subkeys[5] is left for the disabled diffusion network.
        initial_key, drift_keys, readout_key = subkeys[0], subkeys[1:5], subkeys[6]

        self.V = V
        self.C = C
        self.noise_size = noise_size

        self.initial = eqx.nn.MLP(V, N, width_size, depth, key=initial_key)
        self.drift = NeuralSystem(drift_keys, N, C, noise_size, tau)
        self.readout = Readout(readout_key, N, M)

    def __call__(self, ts, phase, frequency, key):
        """Generate observations at `ts` for one control (phase, frequency) and one PRNG key.

        params:
            ts: time points at which observations are produced; integration runs from ts[0] to ts[-1]
            phase: phases of the sinusoidal control
            frequency: frequencies of the sinusoidal control
            key: PRNG key, split between the initial latent and the noise channel of the control
        """
        latent_key, path_key = jrandom.split(key, 2)

        # Random initial state: latent normal noise pushed through the MLP.
        latent = jrandom.normal(latent_key, (self.V,))
        start_state = self.initial(latent)

        # Passing a key switches on the noise channel of the control path.
        path = MultiControlPath(phase, frequency, path_key, self.C)
        term = dfx.ControlTerm(self.drift, path).to_ode()

        # max_steps caps the number of solver steps; the initial step size is 1.0.
        solution = dfx.diffeqsolve(
            term,
            dfx.ReversibleHeun(),
            ts[0],
            ts[-1],
            1.0,
            start_state,
            saveat=dfx.SaveAt(ts=ts),
            max_steps=128,
        )
        return jax.vmap(self.readout)(solution.ys)

In [31]:
class NeuralCDE(eqx.Module):
    """Discriminator: a neural CDE driven by the (interpolated) observation path."""
    initial: eqx.nn.MLP
    drift: VectorField
    diffusion: ControlledVectorField
    readout: eqx.nn.Linear

    def __init__(self, width_size, depth, key, N, C, M, **kwargs):
        """
        params:
            width_size, depth: hidden width and depth of the MLPs
            key: PRNG key, split into four sub-keys for the components
            N: size of the hidden state
            C: number of control channels (not used by this class)
            M: number of observation channels
        """
        super().__init__(**kwargs)
        initial_key, vf_key, cvf_key, readout_key = jrandom.split(key, 4)

        self.initial = eqx.nn.MLP(M, N, width_size, depth, key=initial_key)
        self.drift = VectorField(N, width_size, depth, scale=False, key=vf_key)
        self.diffusion = ControlledVectorField(M, N, width_size, depth, scale=False, key=cvf_key)
        self.readout = eqx.nn.Linear(N, 1, key=readout_key)

    def __call__(self, ts, ys):
        """Score one sample path `ys` observed at times `ts`.

        returns:
            the readout of the hidden state at the first and at the last time point
        """
        # Interpolate data into a continuous path.
        ys = dfx.linear_interpolation(ts, ys, replace_nans_at_start=0.0, fill_forward_nans_at_end=True)
        init = ys[0]
        control = dfx.LinearInterpolation(ts, ys)
        # NOTE(review): only this controlled term is integrated. `self.drift` is
        # constructed in __init__ but never enters the solve (the previously built
        # `ODETerm`/`MultiTerm` locals were unused and have been removed). If the
        # drift is meant to contribute, pass
        # `dfx.MultiTerm(dfx.ODETerm(self.drift), diffusion)` to diffeqsolve - confirm.
        diffusion = dfx.ControlTerm(self.diffusion, control).to_ode()
        solver = dfx.ReversibleHeun()
        t0 = ts[0]
        t1 = ts[-1]
        dt0 = 1.0
        y0 = self.initial(init)
        saveat = dfx.SaveAt(ts=ts)
        sol = dfx.diffeqsolve(
            diffusion, solver, t0, t1, dt0, y0, saveat=saveat, max_steps=128
        )
        # Have the discriminator produce an output at both `t0` *and* `t1`.
        # The output at `t0` has only seen the initial point of a sample. This gives
        # additional supervision to the distribution learnt for the initial condition.
        # The output at `t1` has seen the entire path of a sample. This is needed to
        # actually learn the evolving trajectory.
        loc = jnp.array([0,-1])
        return jax.vmap(self.readout)(sol.ys[loc])

    @eqx.filter_jit
    def clip_weights(self):
        """Return a copy in which every `eqx.nn.Linear` weight is clipped to [-1/out_features, 1/out_features]."""
        # Treat Linear layers as leaves so that they can be replaced as a whole.
        leaves, treedef = jax.tree_util.tree_flatten(
            self, is_leaf=lambda x: isinstance(x, eqx.nn.Linear)
        )
        new_leaves = []
        for leaf in leaves:
            if isinstance(leaf, eqx.nn.Linear):
                lim = 1 / leaf.out_features
                leaf = eqx.tree_at(
                    lambda x: x.weight, leaf, leaf.weight.clip(-lim, lim)
                )
            new_leaves.append(leaf)
        return jax.tree_util.tree_unflatten(treedef, new_leaves)

# Losses

In [32]:
@eqx.filter_jit
def GAN_loss(generator, discriminator, ts_i, ys_i, phase, frequency, keys, step=0):
    """Mean difference between the discriminator scores of real and generated paths.

    params:
        generator: callable `(ts, phase, frequency, key)`, vmapped over the batch
        discriminator: callable `(ts, ys)`, vmapped over the batch
        ts_i: time points shared by the whole batch
        ys_i: batch of real observation paths
        phase, frequency: batched control parameters for the generator
        keys: one PRNG key per batch element
        step: unused
    """
    batched_generator = jax.vmap(generator, in_axes=(None, 0, 0, 0))
    batched_discriminator = jax.vmap(discriminator, in_axes=(None, 0))

    generated = batched_generator(ts_i, phase, frequency, keys)
    score_real = batched_discriminator(ts_i, ys_i)
    score_fake = batched_discriminator(ts_i, generated)
    return jnp.mean(score_real - score_fake)

@eqx.filter_grad
def grad_loss(g_d, ts_i, ys_i, phase, frequency, keys, step):
    """GAN loss as a function of the `(generator, discriminator)` pair `g_d`.

    Wrapped in `eqx.filter_grad`, so calling it yields the gradients with respect
    to `g_d` (the first argument) instead of the loss value.
    """
    gen, disc = g_d
    loss = GAN_loss(gen, disc, ts_i, ys_i, phase, frequency, keys, step)
    return loss

In [33]:
@eqx.filter_jit
def make_step(generator, discriminator, g_opt_state, d_opt_state, g_optim, d_optim, ts_i, ys_i, phase, frequency, keys, step):
    """Apply one optimiser update to both networks and clip the discriminator weights.

    returns:
        (generator, discriminator, g_opt_state, d_opt_state) after the update
    """
    # Gradients of the same loss with respect to both networks.
    grads = grad_loss((generator, discriminator), ts_i, ys_i, phase, frequency, keys, step)
    g_grad, d_grad = grads

    g_delta, next_g_state = g_optim.update(g_grad, g_opt_state)
    d_delta, next_d_state = d_optim.update(d_grad, d_opt_state)

    next_generator = eqx.apply_updates(generator, g_delta)
    next_discriminator = eqx.apply_updates(discriminator, d_delta).clip_weights()
    return next_generator, next_discriminator, next_g_state, next_d_state

# Main part

In [34]:
def train_nn(key, N, C, M, system, time_points, nr_batch, tau):
  """Train a NeuralSDE generator against a NeuralCDE discriminator on data from `system`.

  params:
      key: PRNG key from which all randomness of the run is derived
      N: number of neurons
      C: number of control inputs
      M: number of observations
      system: data-generating system passed to `dataloader`
      time_points: time points of the sampled paths
      nr_batch: number of paths in the (single, fixed) training batch
      tau: time constant of the generator's neural system

  returns:
      (g_losses, d_losses, Js, Bs, bs, Ws). The two loss lists stay empty because
      `make_step` does not return loss values; Js, Bs, bs and Ws hold the
      generator's J, B, b and readout W before training and after every epoch.
  """
  # Give each network its own key; previously both were initialised from the
  # same `key`, which was also the key that had just been split.
  keys = jrandom.split(key, num=3)
  new_key, generator_key, discriminator_key = keys[0], keys[1], keys[2]

  # define model
  V = 2
  noise_size = 1
  generator = NeuralSDE(V, noise_size, 8, 1, generator_key, N, C, M, tau)
  discriminator = NeuralCDE(8, 1, discriminator_key, N, C, M)

  #Maximum number of epochs
  epoch = 1000

  #SD used to sample initial states
  sd = 1e-5

  generator_lr = 1e-3
  discriminator_lr = 5e-3

  # Both networks receive gradients of the same loss; the negative learning
  # rate makes the discriminator ascend it while the generator descends.
  g_optim = optax.rmsprop(generator_lr)
  d_optim = optax.rmsprop(-discriminator_lr)
  g_opt_state = g_optim.init(eqx.filter(generator, eqx.is_inexact_array))
  d_opt_state = d_optim.init(eqx.filter(discriminator, eqx.is_inexact_array))

  #Initialize lists to save intermediate losses and weight values
  # TODO: losses are not recorded and loss-based early stopping is disabled
  # until make_step also returns the loss values.
  g_losses = []
  d_losses = []
  Js = [generator.drift.J]
  Bs = [generator.drift.B]
  bs = [generator.drift.b]
  Ws = [generator.readout.W]

  # One key to carry forward plus four for the dataloader, so that none of the
  # indices it may read falls out of range.
  keys = jrandom.split(new_key, num=5)
  new_key = keys[0]

  #Generate data consisting of phases and frequencies for control, initial and hidden states and observations
  phase, frequency, _, _, obs = dataloader(system, time_points, nr_batch=nr_batch, keys=keys[1:], N=N, C=C, sd=sd)

  for e in range(epoch):
    # Advance the carried key so that every epoch draws fresh generator noise
    # (the carried key used to stay fixed, repeating the same keys each epoch).
    new_keys = jrandom.split(new_key, num=nr_batch+1)
    new_key = new_keys[0]
    generator, discriminator, g_opt_state, d_opt_state = make_step(generator, discriminator, g_opt_state, d_opt_state, g_optim, d_optim, time_points, obs, phase, frequency, new_keys[1:], e)

    if (e % 50) == 49:
      print(e)

    #Store intermediate weight values
    Js.append(generator.drift.J)
    Bs.append(generator.drift.B)
    bs.append(generator.drift.b)
    Ws.append(generator.readout.W)

  return g_losses, d_losses, Js, Bs, bs, Ws

In [None]:
#Define number of neurons, control inputs and observations
N = 2 #neurons
C = 1 #control
M = 1 #observations
key = jrandom.PRNGKey(0)
tau = 1 #time constant
# keys[0] is carried into training, keys[1:5] draw the weights and keys[5:8]
# draw the sparsity masks, so that no key is consumed twice.
keys = jrandom.split(key, num=8)
key = keys[0]

#Use Bernoulli matrices to induce sparsity
p = 1.0 #probability that a weight is kept (1.0 = dense)
J = jrandom.normal(keys[1], shape=(N,N)) * jrandom.bernoulli(keys[5], p=p, shape=(N,N))
B = jrandom.normal(keys[2], shape=(N,C)) * jrandom.bernoulli(keys[6], p=p, shape=(N,C))
b = jrandom.normal(keys[3], shape=(N,))
W = jrandom.normal(keys[4], shape=(M,N)) * jrandom.bernoulli(keys[7], p=p, shape=(M,N))

#State equation for the CDE
def f_state(t, x, args):
    """Return the (N, 1 + C) matrix whose first column is -x + J r(x) + b and whose other columns are B."""
    autonomous = (-x + J @ r(x) + b)[:, None]
    return jnp.concatenate((autonomous, B), axis=1)

#Observation function for the CDE
def f_obs(x):
    """Return the observation W r(x) of state x."""
    return W @ r(x)

#Define CDE
system = CDE(f_state, f_obs)

#Sample path
T = 100
time_points = jnp.linspace(0, 2*pi, T)

g_losses, d_losses, Js, Bs, bs, Ws = train_nn(key, N, C, M, system, time_points, nr_batch = 2, tau=tau)
# plot_figures(losses, Js, Bs, bs, Ws, J, B, b, W)

49
