<a href="https://colab.research.google.com/github/sdevries0/ISMI_group13/blob/main/NCDE_testing.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# import jax
# from jax.lib import xla_bridge
# print(xla_bridge.get_backend().platform)

In [None]:
!pip install diffrax
!pip install optax

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [None]:
import jax

# Report which XLA backend (e.g. cpu / gpu / tpu) JAX will run on.
# jax.default_backend() is the public API; jax.lib.xla_bridge is a deprecated internal module.
print(jax.default_backend())

cpu


In [None]:
from math import pi
import numpy as np
#from utils import plot_system, SinusoidalControlPath
import matplotlib.pyplot as plt
import diffrax as dfx
import equinox as eqx  # https://github.com/patrick-kidger/equinox
import jax.nn as jnn
import jax.numpy as jnp
import jax.random as jrandom
import jax.scipy as jsp
import matplotlib as mpl
mpl.rcParams.update(mpl.rcParamsDefault)
import matplotlib.pyplot as plt
import optax  # https://github.com/deepmind/optax
import seaborn as sns
import pandas as pd
# from equinox.custom_types import Array
import math
from typing import Callable
import seaborn as sns

In [None]:
# Control path u(t) = [t, sin(phase_0 + frequency_0 * t), ..., sin(phase_{C-1} + frequency_{C-1} * t)].
# The time channel t comes first, so the first column of a ControlTerm vector field acts as a drift term.
class MultiControlPath(dfx.AbstractPath):
    """Sinusoidal multi-channel control path with a leading time channel.

    Fields:
        C: number of sinusoidal channels.
        phase: per-channel phases (only the first C entries are used).
        frequency: per-channel frequencies (only the first C entries are used).
    """

    C: int
    phase: jnp.ndarray
    frequency: jnp.ndarray

    def __init__(self, phase, frequency, C=2):
        self.C = C
        self.phase = phase
        self.frequency = frequency

    def evaluate(self, t0, t1=None, left=True):
        """Return u(t0), or the increment u(t1) - u(t0) when t1 is given.

        `left` belongs to diffrax's AbstractPath interface; the path is
        continuous, so it is ignored.
        """
        del left
        if t1 is not None:
            return self.evaluate(t1) - self.evaluate(t0)
        # One vectorized sin over the C channels instead of C scalar ops stacked
        # by a Python loop: identical values, smaller traced graph.
        phases = jnp.asarray(self.phase)[: self.C]
        frequencies = jnp.asarray(self.frequency)[: self.C]
        controls_at_t = jnp.sin(phases + frequencies * t0)
        return jnp.append(t0, controls_at_t)

In [None]:
class CDE():
    """Ground-truth controlled differential equation dx = f_state(t, x) du(t).

    The control u is a MultiControlPath built from the phases and frequencies
    passed to __call__; the states are read out through f_obs.
    """

    f_state : Callable
    f_obs : Callable
    dt0 : float

    def __init__(self, f_state, f_obs = lambda x: x, dt0 = 0.1):
        """
        params:
            f_state: vector field (t, x, args) -> matrix; function dom_state -> dom_state x dom_ctrl
            f_obs: readout applied to every saved state (identity, i.e. complete
                   observability, by default); function dom_state -> dom_obs
            dt0: step size of the fixed-step Tsit5 solver (default 0.1)
        """
        self.f_state = f_state
        self.f_obs = f_obs
        self.dt0 = dt0

    def __call__(self, ts, phase, frequency, init):
        """
        Solve the CDE from ts[0] to ts[-1] and return the observations at ts.

        params:
            ts: time points; the solve runs from ts[0] to ts[-1] and saves at every ts
            phase: phases of the sinusoidal control channels
            frequency: frequencies of the sinusoidal control channels (its length sets C)
            init: initial state x(ts[0])
        returns:
            f_obs applied to the state at each time point in ts, stacked along axis 0
        """
        control = MultiControlPath(phase, frequency, frequency.shape[0])
        term = dfx.ControlTerm(self.f_state, control).to_ode()

        sol = dfx.diffeqsolve(
            term,
            dfx.Tsit5(),
            ts[0],
            ts[-1],
            self.dt0,
            y0=init,
            # For adaptive stepping: dfx.PIDController(rtol=1e-3, atol=1e-6)
            stepsize_controller=dfx.ConstantStepSize(),
            saveat=dfx.SaveAt(ts=ts),
        )

        # Only the observations are returned (not the hidden states or the control).
        return jax.vmap(self.f_obs)(sol.ys)

In [None]:
class Dataloader():
    """Pre-generates a fixed synthetic dataset from a CDE and serves it in sequential batches."""

    def __init__(self, key, nr_batch, N, C, sd, dataset_size, ts, cde=None):
        """
        params:
            key: PRNG key for the frequencies, phases and initial states
            nr_batch: batch size returned by sample_observations
            N: state dimension
            C: number of sinusoidal control channels
            sd: standard deviation of the Gaussian initial states
            dataset_size: number of simulated trajectories; must be a multiple of nr_batch
            ts: time points shared by every trajectory
            cde: callable (ts, phase, frequency, init) -> observations used to simulate
                 the data. Defaults to the notebook-global `system` (previous behaviour).
        """
        assert dataset_size % nr_batch == 0, "dataset_size must be a multiple of nr_batch"
        if cde is None:
            cde = system
        self.dataset_size = dataset_size
        self.nr_batch = nr_batch
        # One independent subkey per random quantity. The parent key must not be
        # used again after splitting: its random stream overlaps with its subkeys.
        freq_key, phase_key, init_key = jrandom.split(key, 3)
        self.frequency = jrandom.uniform(freq_key, shape=(dataset_size, C), minval=0.0, maxval=1.0)
        self.phase = jrandom.normal(phase_key, shape=(dataset_size, C))
        init = sd * jrandom.normal(init_key, shape=(dataset_size, N))
        # Simulate all trajectories at once; ts is shared, everything else is per-sample.
        self.dataset = jax.vmap(cde, in_axes=(None, 0, 0, 0))(ts, self.phase, self.frequency, init)

    def sample_observations(self, epoch):
        """Return the (frequency, phase, observations) batch for `epoch`.

        Batches are taken in order and wrap around, so the loader cycles through
        the whole dataset every dataset_size // nr_batch epochs. (For random
        batches, store a key and draw indices with jrandom.randint(jrandom.fold_in(key, epoch), ...).)
        """
        start = (epoch * self.nr_batch) % self.dataset_size
        end = start + self.nr_batch
        return self.frequency[start:end], self.phase[start:end], self.dataset[start:end]

In [None]:
class NeuralCDE(eqx.Module):
    """Neural CDE whose initial state is inferred from the first t0 observations."""

    f_state : Callable
    f_init : Callable
    f_obs : Callable
    t0: int  # split index into ts; used for slicing, so it should be a Python int

    def __init__(self, f_state, f_obs = lambda x: x, f_init = None, t0 = 0, **kwargs):
        """
        params:
            f_state: vector field (t, x, args) -> dom_state x dom_ctrl
            f_obs: readout applied to every saved state (identity by default);
                   function dom_state -> dom_obs. Need not be linear.
            f_init: initial-state model, callable (ts[:t0], phase, frequency, obs) -> dom_state.
                    Required before the model is called.
            t0: index of the first time point predicted by the NCDE; ts[:t0] and the
                accompanying observations are handed to f_init.

        Each term can be either a fixed function (CDE) or a parameterized nonlinear function (neural CDE).
        """
        super().__init__(**kwargs)
        self.f_state = f_state
        self.f_obs = f_obs
        self.f_init = f_init
        self.t0 = t0

    def __call__(self, ts, phase, frequency, obs):
        """
        Predict the observations at ts[t0:].

        params:
            ts: all time points of the trajectory
            phase, frequency: parameters of the sinusoidal control channels
            obs: observations passed to f_init together with ts[:t0]
                 (presumably the observations at ts[:t0] -- confirm against caller)
        returns:
            f_obs applied to the solved state at every time point in ts[t0:]
        """
        if self.f_init is None:
            raise ValueError("NeuralCDE needs an f_init model to infer the initial state")

        control = MultiControlPath(phase, frequency, frequency.shape[0])
        term = dfx.ControlTerm(self.f_state, control).to_ode()

        # Initial state inferred from the warm-up window.
        y0 = self.f_init(ts[:self.t0], phase, frequency, obs)

        # Step size is a local constant on purpose: a float module field would be
        # picked up as a trainable leaf by eqx.filter(model, eqx.is_array_like).
        dt0 = 0.1
        sol = dfx.diffeqsolve(
            term,
            dfx.Tsit5(),
            ts[self.t0],
            ts[-1],
            dt0,
            y0,
            # For adaptive stepping: dfx.PIDController(rtol=1e-3, atol=1e-6)
            stepsize_controller=dfx.ConstantStepSize(),
            saveat=dfx.SaveAt(ts=ts[self.t0:]),
        )

        return jax.vmap(self.f_obs)(sol.ys)

In [None]:
# MLP vector field used inside the initialization NCDE's ControlTerms.
class InitFunc(eqx.Module):
    """Vector field f(t, x, args) -> matrix of shape (state_size, ctrl_size), produced by an MLP."""

    mlp: eqx.nn.MLP
    ctrl_size: int
    state_size: int

    def __init__(self, state_size, ctrl_size, width_size, depth, *, key, **kwargs):
        super().__init__(**kwargs)
        self.state_size = state_size
        self.ctrl_size = ctrl_size
        # The MLP emits a flat vector that __call__ reshapes into the matrix;
        # the tanh output activation keeps every entry in (-1, 1).
        self.mlp = eqx.nn.MLP(
            in_size=state_size,
            out_size=state_size * ctrl_size,
            width_size=width_size,
            depth=depth,
            activation=jnn.softplus,
            final_activation=jnn.tanh,
            key=key,
        )

    def __call__(self, t, x, args):
        # t and args are part of the diffrax vector-field signature but unused here.
        flat = self.mlp(x)
        return flat.reshape((self.state_size, self.ctrl_size))

In [None]:
class InitNeuralCDE(eqx.Module):
    """Infers an initial state by integrating a learned CDE over a warm-up window.

    The system is driven by two controls: the sinusoidal input path (with its
    time channel) and a linear interpolation of the observations. Integration
    starts from a learned vector y0; the state at the last time point is returned.
    """

    f_control_state : InitFunc
    f_obs_state: InitFunc
    y0 : jnp.ndarray  # learned starting state, shape (state_size,)
    f_obs : Callable

    def __init__(self, state_size, ctrl_size, M, width_size, depth, key = None, **kwargs):
        """
        params:
            state_size: dimension of the inferred state
            ctrl_size: number of channels of the input control path (time channel included)
            M: number of observation channels
            width_size, depth: width and depth of the MLPs in both vector fields
            key: PRNG key for the MLP weights and the learned y0 (required)
        """
        super().__init__(**kwargs)
        if key is None:
            raise TypeError("InitNeuralCDE requires a PRNG `key`")
        control_key, obs_key, init_key = jrandom.split(key, 3)
        self.f_control_state = InitFunc(state_size, ctrl_size, width_size, depth, key=control_key)
        self.f_obs_state = InitFunc(state_size, M, width_size, depth, key=obs_key)
        self.f_obs = lambda x: x
        # Sized from the constructor argument, not from a notebook-global N.
        self.y0 = jrandom.normal(init_key, shape=(state_size,))

    def __call__(self, ts, phase, frequency, obs):
        """
        Return the inferred state at ts[-1].

        params:
            ts: warm-up time points
            phase, frequency: parameters of the sinusoidal control channels
            obs: observations at ts (leading axis matches ts, M channels)
        """
        control = MultiControlPath(phase, frequency, frequency.shape[0])
        observation_control = dfx.LinearInterpolation(ts, obs)
        input_term = dfx.ControlTerm(self.f_control_state, control).to_ode()
        obs_term = dfx.ControlTerm(self.f_obs_state, observation_control).to_ode()
        system = dfx.MultiTerm(input_term, obs_term)

        dt0 = 0.1
        sol = dfx.diffeqsolve(
            system,
            dfx.Tsit5(),
            ts[0],
            ts[-1],
            dt0,
            self.y0,
            # For adaptive stepping: dfx.PIDController(rtol=1e-3, atol=1e-6)
            stepsize_controller=dfx.ConstantStepSize(),
            # Only the final state is used, so save just t1 instead of every ts.
            saveat=dfx.SaveAt(t1=True),
        )

        return jax.vmap(self.f_obs)(sol.ys[-1])

In [None]:
@eqx.filter_jit
def mse_loss(model, ts, phase, frequency, obs, t0):
    """Squared-error loss of `model` on the prediction window ts[t0:].

    The model receives obs[:, :t0] (to infer its initial state) and predicts
    obs[:, t0:]. Squared errors are summed over time points and observation
    channels, then averaged over the batch.

    params:
        model: callable (ts, phase, frequency, obs_prefix) -> predictions for ts[t0:]
        ts: time points shared by the batch
        phase, frequency: batched control parameters (leading axis = batch)
        obs: batched observations indexed (batch, time, channel)
        t0: split index (static under filter_jit because it is a Python int)
    """
    predictions = jax.vmap(model, in_axes=(None, 0, 0, 0))(ts, phase, frequency, obs[:, :t0])
    residuals = predictions - obs[:, t0:]
    # Sum of squares instead of norm(...) ** 2: the same value, but the gradient of
    # jnp.linalg.norm is NaN when a residual vector is exactly zero.
    return jnp.mean(jnp.sum(residuals ** 2, axis=(1, 2)), axis=0)

def increase_update_initial(updates, factor=10):
    """Return a copy of `updates` with every leaf of its f_init sub-tree multiplied by `factor`.

    Presumably meant to give the initial-state model a larger effective learning
    rate. eqx.tree_at is out-of-place, so `updates` itself is not modified.
    """
    get_initial_leaves = lambda u: jax.tree_util.tree_leaves(u.f_init)
    return eqx.tree_at(get_initial_leaves, updates, replace_fn=lambda x: x * factor)

@eqx.filter_jit
def make_step(ts, phase, frequency, model, obs, opt_state, grad_loss, optim, t0, block):
    """Run one optimisation step; return (loss, updated model, updated optimiser state).

    The update for the readout matrix model.f_obs.W is multiplied elementwise by
    `block`, so W entries where block == 0 receive a zero update.
    (To boost the f_init learning rate, apply increase_update_initial(updates)
    before eqx.apply_updates.)
    """
    loss, grads = grad_loss(model, ts, phase, frequency, obs, t0)
    updates, new_opt_state = optim.update(grads, opt_state)

    masked_W_update = updates.f_obs.W * block
    updates = eqx.tree_at(lambda tree: tree.f_obs.W, updates, replace=masked_W_update)

    new_model = eqx.apply_updates(model, updates)
    return loss, new_model, new_opt_state

In [None]:
# Firing-rate nonlinearity r(x), used by both the ground-truth system and the models.
# Alternatives that have been tried: sigmoid 1 / (1 + exp(-x)) and ReLU (x > 0) * x.
def r(x):
    """Elementwise tanh firing rate."""
    return jnp.tanh(x)

In [None]:
# Rate-RNN vector field for the NCDE: dx = (-x + J r(x) + b) dt + B du, packed into one
# matrix whose first column multiplies the time channel of the control path.
class Func(eqx.Module):
    """Learned vector field returning an (N, 1 + C) matrix.

    Fields:
        J: (N, N) recurrent weights, initialised to 0.01 * identity
        B: (N, C) input weights
        b: (N,) bias
        tau: time constant. NOTE(review): stored but not applied in __call__,
             which matches the ground-truth f_state used for data generation.
        N, C: number of neurons / control inputs
    """

    J: jnp.ndarray
    B: jnp.ndarray
    b: jnp.ndarray
    tau: float
    N: int
    C: int

    def __init__(self, keys, N, C, tau):
        """
        params:
            keys: a single PRNG key (parameter name kept for compatibility); split internally
            N: number of neurons
            C: number of control inputs
            tau: time constant
        """
        super().__init__()
        # Derive the subkeys from the key argument so initialisation follows the
        # caller's PRNG stream (splitting a notebook-global key ignored the argument).
        subkeys = jrandom.split(keys, 3)
        # J starts as a small identity rather than random;
        # for a random start use jrandom.normal(subkeys[0], shape=(N, N)).
        self.J = 0.01 * jnp.identity(N)
        self.B = jrandom.normal(subkeys[1], shape=(N, C))
        self.b = jrandom.normal(subkeys[2], shape=(N,))
        self.tau = tau
        self.N = N
        self.C = C

    def __call__(self, t, x, args):
        """Return [-x + J r(x) + b | B] as an (N, 1 + C) matrix (t and args unused)."""
        drift = -x + self.J @ r(x) + self.b
        return jnp.concatenate((drift.reshape(self.N, 1), self.B), axis=1)


In [None]:
# Readout used by the NCDE: observations y = W r(x).
class Obs1(eqx.Module):
    """Nonlinear readout x -> W @ r(x); W is masked by `block` at initialisation."""

    W: jnp.ndarray  # (M, N) readout weights

    def __init__(self, key, N, M, block):
        super().__init__()
        unmasked = jrandom.normal(key, shape=(M, N))
        # Entries outside the allowed sparsity pattern start at zero.
        self.W = unmasked * block

    def __call__(self, x):
        rates = r(x)
        return self.W @ rates

In [None]:
def train_nn(key, N, C, M, system, time_points, nr_batch, tau, t0, block,
             max_epochs=100000, patience=2000, min_epochs=20000, lr=1e-3, sd=1,
             batches_per_dataset=50, log_every=1000):
  """Train a NeuralCDE on simulated data and record the parameter trajectory.

  params:
    key: PRNG key for model initialisation and dataset generation
    N, C, M: numbers of neurons, control inputs and observations
    system: ground-truth CDE. NOTE(review): unused here; Dataloader simulates
            with the notebook-global `system`.
    time_points: time grid of every trajectory
    nr_batch: batch size
    tau: time constant passed to Func
    t0: index of the first time point predicted by the NCDE
    block: 0/1 mask fixing the sparsity pattern of the readout W
    max_epochs: hard limit on the number of epochs
    patience: stop once the loss has not improved for this many epochs...
    min_epochs: ...but only after this epoch
    lr: Adam learning rate
    sd: standard deviation of the sampled initial states
    batches_per_dataset: dataset size in units of nr_batch
    log_every: print the loss every this many epochs
  returns:
    (losses, Js, Bs, bs, Ws): per-epoch losses and snapshots of J, B, b and W.
    The snapshot lists start with the initial values.
  """
  keys = jrandom.split(key, num=5)  # keys[0] is not used

  # Initial-state model: a neural CDE driven by the control path ([t, controls],
  # i.e. C + 1 channels) and by the observations of the warm-up window.
  ctrl_size = C + 1
  state_size = N
  width_size = 8
  depth = 1
  init_model = InitNeuralCDE(state_size, ctrl_size, M, width_size, depth, key=keys[1])

  f_state = Func(keys[2], N, C, tau)   # state equation
  f_obs = Obs1(keys[3], N, M, block)   # observation function
  model = NeuralCDE(f_state, f_obs, init_model, t0)

  optim = optax.adam(lr)
  grad_loss = eqx.filter_value_and_grad(mse_loss)
  opt_state = optim.init(eqx.filter(model, eqx.is_array_like))

  # Loss history and parameter snapshots (index 0 = initial values).
  losses = []
  Js = [model.f_state.J]
  Bs = [model.f_state.B]
  bs = [model.f_state.b]
  Ws = [model.f_obs.W]

  # Convergence bookkeeping.
  best_loss = jnp.inf
  epochs_since_best = 0
  loss = None  # defined up front so the error handler can report it even if epoch 0 fails

  dataloader = Dataloader(keys[4], nr_batch, N, C, sd, nr_batch * batches_per_dataset, time_points)

  for e in range(max_epochs):
    try:
      frequency, phase, obs = dataloader.sample_observations(e)
      loss, model, opt_state = make_step(time_points, phase, frequency, model, obs, opt_state, grad_loss, optim, t0, block)
      if (e + 1) % log_every == 0:
        print(r"Currently at epoch: {}. The loss is: {}".format(e+1,loss))

      losses.append(loss)
      Js.append(model.f_state.J)
      Bs.append(model.f_state.B)
      bs.append(model.f_state.b)
      Ws.append(model.f_obs.W)

      if loss < best_loss:
        best_loss = loss
        epochs_since_best = 0
      else:
        epochs_since_best += 1
        if epochs_since_best >= patience and e > min_epochs:
          print(r"The loss has converged on {} at epoch {}".format(best_loss, e))
          return losses, Js, Bs, bs, Ws

    except (Exception, KeyboardInterrupt) as err:
      # Deliberately best-effort: a failure (e.g. once the loss becomes very small)
      # or a manual interrupt ends training and returns the history so far.
      print(r"Training stopped by {}: {}".format(type(err).__name__, err))
      print(r"The final loss is {} at epoch {}".format(loss,e))
      # Truncate every list the same way so their lengths stay consistent.
      return losses[:-1], Js[:-1], Bs[:-1], bs[:-1], Ws[:-1]

  return losses, Js, Bs, bs, Ws
  # Notes:
  # david sussillo, randall beer, jack galant, john pillow
  #
  # A sufficient condition for a structure S to be identifiable
  # is that the matrix of second-order partial derivatives with respect to the
  # parameters, V_theta,theta, is positive definite for all theta in parameters. (On structural identifiability)
  #
  # kalman filter
  # stability and equilibrium points
  # The first 500 steps of gradient-based optimisation were performed on only the first 10 sample points
  # of each time series (so that approximately the interval [0, 1] was considered instead)

In [None]:
def find_order(W_, W):
  """Estimate which true column each estimated column corresponds to.

  In every row, columns are matched by the rank of their absolute values
  (smallest |W_| entry <-> smallest |W| entry, and so on). Each estimated column
  is then assigned the true column it was matched to most often across rows;
  ties go to the smallest column index. The result need not be a permutation.

  params:
    W_: estimated matrix, 2-D
    W: true matrix, same shape
  returns:
    integer jnp array; entry j is the true column matched to column j of W_
  """
  n_rows, n_cols = W.shape
  votes = np.zeros((n_rows, n_cols), dtype=int)
  for row in range(n_rows):
    est_rank = np.asarray(jnp.argsort(jnp.abs(W_[row])))
    true_rank = np.asarray(jnp.argsort(jnp.abs(W[row])))
    # est_rank is a permutation, so this scatter writes every column exactly once.
    votes[row, est_rank] = true_rank

  final_order = np.zeros(n_cols, dtype=int)
  for col in range(n_cols):
    candidates, freq = np.unique(votes[:, col], return_counts=True)
    final_order[col] = candidates[np.argmax(freq)]
  return jnp.array(final_order)

def find_block_order(W_, W, groups):
  """Match estimated neurons to true neurons block by block, using find_order per block.

  params:
    W_: estimated readout, indexed (observation, neuron)
    W: true block-diagonal readout, same shape
    groups: integer array with one row per block; column 0 = neurons, column 1 = observations
  returns:
    concatenated column order in global neuron indices (also printed)
  """
  ends = jnp.cumsum(groups, axis=0)
  pieces = []
  col_start, row_start = 0, 0
  for g in range(ends.shape[0]):
    col_end, row_end = ends[g, 0], ends[g, 1]
    local_order = find_order(W_[row_start:row_end, col_start:col_end],
                             W[row_start:row_end, col_start:col_end])
    # Shift block-local indices to global neuron indices.
    pieces.append(local_order + col_start)
    col_start, row_start = col_end, row_end
  order = jnp.concatenate(pieces)
  print(order)
  return order

def moving_avg(losses, alpha=0.01):
    """Exponential moving average of a loss curve, for smoother plots.

    mean[0] = losses[0];  mean[i] = alpha * losses[i] + (1 - alpha) * mean[i - 1]

    The recursion uses the previous *average*; using the previous raw loss would
    only give a lagged copy of the curve, not a smoothed one.

    params:
        losses: sequence of scalar losses
        alpha: weight of the newest value (default 0.01)
    returns:
        float numpy array of the same length as `losses` (empty for empty input)
    """
    mean_data = np.zeros(len(losses))
    if len(losses) == 0:
        return mean_data
    mean_data[0] = losses[0]
    for i in range(1, len(losses)):
        mean_data[i] = alpha * losses[i] + (1 - alpha) * mean_data[i - 1]
    return mean_data

def _abs_mse_curve(history, target_abs):
  """Per-snapshot mean squared error between |estimate| and a precomputed |target|.

  history: list of parameter snapshots, each with the same shape as target_abs.
  Returns a 1-D array with one value per snapshot (vectorized over all snapshots).
  """
  stacked = jnp.abs(jnp.stack(history))
  return jnp.mean((stacked - target_abs) ** 2, axis=tuple(range(1, stacked.ndim)))


def plot_figures(losses, Js, Bs, bs, Ws, J, B, b, W, groups):
  """Plot the smoothed training loss and sign-invariant parameter errors over training.

  The true parameters are first permuted into the estimated neurons' order
  (find_block_order on the final readout). Estimates and truth are compared in
  absolute value, so sign flips are not penalised.
  """
  order = find_block_order(Ws[-1], W, groups)
  fig, axes = plt.subplots(2, 3, figsize=(14, 10))

  axes[0,0].plot(moving_avg(losses))
  axes[0,0].set_title("Loss")
  axes[0,0].set(xlabel="Epoch", ylabel="Log loss")
  axes[0,0].set_yscale('log')

  # (axis, estimate history, |true parameter| in the estimated order, name)
  panels = [
    (axes[0,1], Js, jnp.abs(J[order][:, order]), "J"),
    (axes[1,0], Bs, jnp.abs(B[order]), "B"),
    (axes[1,1], bs, jnp.abs(b[order]), "b"),
    (axes[0,2], Ws, jnp.abs(W[:, order]), "W"),
  ]
  for ax, history, target_abs, name in panels:
    ax.plot(_abs_mse_curve(history, target_abs))
    ax.set_title("Mean squared error from estimated {0} to true {0}".format(name))
    ax.set(xlabel="Epoch", ylabel="Mean squared error")

  axes[1,2].axis("off")  # no sixth panel
  plt.show()

  # RMSPROP instead of ADAM, no momentum, moving target

In [None]:
# Define number of neurons, control inputs and observations.
# Each row of neuron_groups is one (neurons, observations) block of the readout W.
neuron_groups = jnp.array([[3,2],[2,2],[2,2]]) # N, M
N = jnp.sum(neuron_groups[:,0]).item() # neurons
C = 3 # control inputs
M = jnp.sum(neuron_groups[:,1]).item() # observations
key = jrandom.PRNGKey(0)
tau = 1 # time constant

def block_diag(a, b):
  """Place the 2-D arrays `a` and `b` on the diagonal of a zero matrix."""
  result = np.zeros((a.shape[0]+b.shape[0], a.shape[1]+b.shape[1]))
  result[:a.shape[0], :a.shape[1]] = a
  result[a.shape[0]:, a.shape[1]:] = b
  return jnp.array(result)

def sparse_normal(k, shape, p):
  """Standard-normal array whose entries are kept independently with probability p.

  Values and mask use two independent subkeys; drawing both from one key couples
  the mask to the values. (The true parameters therefore differ from runs made
  with the single-key version, even with the same seed.)
  """
  value_key, mask_key = jrandom.split(k)
  return jrandom.normal(value_key, shape=shape) * jrandom.bernoulli(mask_key, p=p, shape=shape)

keys = jrandom.split(key, num=4+neuron_groups.shape[0])
key = keys[0]
# Bernoulli masks induce sparsity; p is the probability of keeping an entry.
p = 1.0
J = sparse_normal(keys[1], (N, N), p)
B = sparse_normal(keys[2], (N, C), p)
b = jrandom.normal(keys[3], shape=(N,))
# Block-diagonal readout: the observations of group g only see the neurons of group g.
W = jnp.empty([0,0])
for i in range(neuron_groups.shape[0]):
  n_neurons, n_obs = (int(v) for v in neuron_groups[i])
  W = block_diag(W, sparse_normal(keys[4+i], (n_obs, n_neurons), p))
# Readout mask: 1 where W is non-zero, 0 elsewhere.
block = jnp.where(W==0, W, 1)

# State equation for the true CDE: [ -x + J r(x) + b | B ]
f_state = lambda t, x, args: jnp.concatenate(((-x+J@r(x) + b).reshape(N,1), B), axis = 1)

# Observation function for the true CDE
f_obs = lambda x : W@r(x)

# Define the ground-truth CDE
system = CDE(f_state, f_obs)

# Index of the first time point predicted by the NCDE. The data before it is used to infer the initial condition.
t0 = 75

# Sample path
T = 500
time_points = jnp.linspace(0, 8*pi, T)
losses, Js, Bs, bs, Ws = train_nn(key, N, C, M, system, time_points, nr_batch = 64, tau=tau, t0=t0, block=block)
plot_figures(losses, Js, Bs, bs, Ws, J, B, b, W, neuron_groups)

# Ideas to try:
# small weights, uniform weights, ReLU initialization network, mean 0, identity J, 0.01 identity, 1/N J
# solve steps
# random batches

# lr: 5e-3, 1e-3, 5e-4
# increased init lr
# random vs batch
# identity vs 0.01 identity vs 0.01 random weights vs uniform 1/N^2
# initialization network relu vs tanh vs silu

Currently at epoch: 1000. The loss is: 505.3908996582031
Currently at epoch: 2000. The loss is: 199.19320678710938
Currently at epoch: 3000. The loss is: 166.9359893798828
Currently at epoch: 4000. The loss is: 155.99000549316406
Currently at epoch: 5000. The loss is: 136.4904022216797
Currently at epoch: 6000. The loss is: 71.41743469238281
Currently at epoch: 7000. The loss is: 52.659767150878906
Currently at epoch: 8000. The loss is: 41.815467834472656
Currently at epoch: 9000. The loss is: 27.86749267578125
Currently at epoch: 10000. The loss is: 20.534526824951172
Currently at epoch: 11000. The loss is: 17.237503051757812
Currently at epoch: 12000. The loss is: 14.525548934936523
Currently at epoch: 13000. The loss is: 12.671895980834961
Currently at epoch: 14000. The loss is: 13.350332260131836
Currently at epoch: 15000. The loss is: 8.616917610168457
Currently at epoch: 16000. The loss is: 8.680686950683594
Currently at epoch: 17000. The loss is: 7.372756004333496
Currently at e

In [None]:
# Final estimates (left column) vs true parameters permuted into the estimated
# neuron order (right column).
order = find_block_order(Ws[-1], W, neuron_groups)
fig, axes = plt.subplots(4, 2, figsize=(14, 14))
heatmap_rows = [
  ("J", Js[-1], J[order][:, order]),
  ("B", Bs[-1], B[order]),
  ("b", [bs[-1]], [b[order]]),
  ("W", Ws[-1], W[:, order]),
]
for row, (name, estimated, true) in enumerate(heatmap_rows):
  sns.heatmap(estimated, annot=True, ax=axes[row, 0])
  axes[row, 0].set_title("Estimated {}".format(name))
  sns.heatmap(true, annot=True, ax=axes[row, 1])
  axes[row, 1].set_title("True {} (reordered)".format(name))
plt.tight_layout()
plt.show()

In [None]:
# Mount Google Drive (Colab only) so that results can be saved under /content/drive.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
from pathlib import Path

# Save the training history and the true parameters. np.save pickles the dict,
# so load it back with np.load(path, allow_pickle=True).item().
save_path = Path('/content/drive/MyDrive/Thesis_Data/J_identity_0.01_fixed_batch.npy')
save_path.parent.mkdir(parents=True, exist_ok=True)  # avoid FileNotFoundError on a fresh Drive
dictionary = {'losses':losses, 'Js':Js, 'Bs':Bs, 'bs':bs, 'Ws':Ws, 'True_J':J,'True_B':B,'True_b':b,'True_W':W,
              'groups':neuron_groups}
np.save(save_path, dictionary)

In [None]:
# Reload a saved run and re-plot it. The call follows the current signature
# plot_figures(losses, Js, Bs, bs, Ws, J, B, b, W, groups); groups is taken from neuron_groups in the setup cell.
# read_dictionary = np.load('/content/drive/MyDrive/Thesis_Data/N_6_t0_50.npy',allow_pickle=True).item()
# plot_figures(read_dictionary['losses'], read_dictionary['Js'], read_dictionary['Bs'], read_dictionary['bs'], read_dictionary['Ws'], read_dictionary['True_J'], read_dictionary['True_B'], read_dictionary['True_b'], read_dictionary['True_W'], neuron_groups)