In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

class WilsonCowanParams:
    """Per-region Wilson-Cowan constants with randomly jittered weights and drives.

    r, tau_E and tau_I are fixed. The four coupling weights get unit-variance
    Gaussian jitter around their nominal values, and the drives P and Q get jitter
    with standard deviation 0.1 around 1. Six draws are taken from NumPy's global
    RNG, in the order w_ee, w_ei, w_ie, w_ii, P, Q.
    """

    def __init__(self):
        # Noise-free constants.
        self.r = 0.2
        self.tau_E = 1.0
        self.tau_I = 2.0
        # Coupling weights: nominal value + N(0, 1).
        for attr, nominal in (('w_ee', 8), ('w_ei', 4), ('w_ie', 10), ('w_ii', 3)):
            setattr(self, attr, nominal + np.random.randn())
        # External drives: 1 + N(0, 0.1**2).
        for attr in ('P', 'Q'):
            setattr(self, attr, 1 + 0.1 * np.random.randn())

class Region:
    """State of one region in the coupled Wilson-Cowan network.

    E and I hold one value per time point of np.arange(0, T + dt, dt). They start
    at E[0] = 0.2 and I[0] = 0.3, with the remaining entries zero until simulated.
    Each region draws its own randomised WilsonCowanParams.
    """

    def __init__(self, region_id, T, dt):
        n_points = np.arange(0, T + dt, dt).size
        self.region_id = region_id
        self.E, self.I = np.zeros(n_points), np.zeros(n_points)
        # Initial conditions.
        self.E[0], self.I[0] = 0.2, 0.3
        self.params = WilsonCowanParams()

        
def sig(x):
    """Logistic sigmoid 1 / (1 + exp(-x)), applied element-wise to NumPy inputs."""
    return np.reciprocal(1 + np.exp(-x))

def dedt(t, E, I, E_ext, region : Region, C):
    """Excitatory rate equation for one region of the coupled Wilson-Cowan network.

    dE/dt = (-E + (1 - r*E) * sig(w_ee*E + G * C[id] . E_ext - w_ei*I + P)) / tau_E

    Parameters
    ----------
    t : unused; kept so the signature mirrors a time-dependent ODE right-hand side.
    E, I : this region's current excitatory / inhibitory activity.
    E_ext : excitatory activity of every region at the current step.
    region : Region whose `params` and `region_id` are used.
    C : connectivity matrix; row `region.region_id` weights E_ext.

    Reads the module-level coupling strength G.
    """
    params = region.params
    # Network input: connectivity-weighted sum of all regions' excitatory activity.
    network_input = np.dot(C[region.region_id], E_ext)
    # `Se` was never defined (NameError); the file's sigmoid is `sig`. The debug print
    # that ran here was removed: it fired once per region per time step.
    drive = params.w_ee*E + G * network_input - params.w_ei*I + params.P
    return (-E + (1 - params.r*E) * sig(drive)) / params.tau_E

def didt(t, E, I, region):
    """Inhibitory rate equation for one region of the coupled Wilson-Cowan network.

    dI/dt = (-I + (1 - r*I) * sig(w_ie*E - w_ii*I + Q)) / tau_I

    `t` is unused. The parameters come from `region.params`.
    """
    params = region.params
    # `Si` was never defined (NameError); use the file's sigmoid `sig`.
    drive = params.w_ie*E - params.w_ii*I + params.Q
    return (-I + (1 - params.r*I) * sig(drive)) / params.tau_I

def plot_results(regions, t):
    """Plot E and I activity over time for every region, one subplot per region.

    Parameters
    ----------
    regions : iterable of Region
        Regions whose `E`, `I` and `region_id` are plotted.
    t : array-like
        Time points matching the length of each region's E and I arrays.
    """
    regions = list(regions)
    n = len(regions)
    if n == 0:
        return
    # Size a near-square grid from the number of regions. The previous fixed 10x10 grid
    # raised IndexError beyond 100 regions. 100 regions still gives a 10x10 grid of
    # 6-inch panels, as before.
    ncols = int(np.ceil(np.sqrt(n)))
    nrows = int(np.ceil(n / ncols))
    fig, axs = plt.subplots(nrows, ncols, figsize=(6 * ncols, 6 * nrows), squeeze=False)
    axs = axs.ravel()

    for ax, region in zip(axs, regions):
        ax.plot(t, region.E, label='E')
        ax.plot(t, region.I, label='I')
        ax.set_title(f'Region {region.region_id}')
        ax.set_xlabel('Time')
        ax.set_ylabel('Activity')
        ax.legend()

    # Hide the unused panels when n does not fill the grid exactly.
    for ax in axs[n:]:
        ax.set_visible(False)

    plt.tight_layout()
    plt.show()

# Main simulation
num_regions = 100
T = 100
dt = 0.01
G = 2.

# Load and normalise connectivity matrix.
# Read as float: integer fibre counts would make the in-place division below fail
# with a casting error.
C = pd.read_csv('./fmri/DTI_fiber_consensus_HCP.csv', header=None).to_numpy(dtype=float)[:num_regions, :num_regions]
np.fill_diagonal(C, 0)
# Row-normalise so each region's incoming weights sum to 1. Rows with no connections
# stay all-zero instead of becoming NaN (0 / 0).
row_sums = C.sum(axis=1, keepdims=True)
C = np.divide(C, row_sums, out=np.zeros_like(C), where=row_sums != 0)
# Simulate only as many regions as the connectivity matrix provides.
num_regions = C.shape[0]

# Run simulation (explicit Euler). E_ext snapshots every region's E at step i before
# any region is updated, so all regions advance synchronously.
t = np.arange(0, T + dt, dt)
regions = [Region(i, T, dt) for i in range(num_regions)]

for i in range(len(t) - 1):
    E_ext = np.array([region.E[i] for region in regions])

    for curr in regions:
        curr.E[i + 1] = curr.E[i] + dt * dedt(t[i], curr.E[i], curr.I[i], E_ext, curr, C)
        curr.I[i + 1] = curr.I[i] + dt * didt(t[i], curr.E[i], curr.I[i], curr)

# Plot results
plot_results(regions, t)

In [None]:
import numpy as np
import matplotlib.pyplot as plt
np.random.seed(42)

plt.rc('text', usetex=True)
plt.rc('font', family='serif')
plt.rc('animation', html='jshtml')

%matplotlib inline

# Constants of the single E/I Wilson-Cowan pair used by dedt/didt below:
# r, then the E and I time constants.
r, tau_E, tau_I = .2, 1., 2.

# External drives to the E and I populations.
P, Q = 0.2, .5

# Coupling weights. w_ee and w_ei enter the E equation (scaling E and I);
# w_ie and w_ii enter the I equation (scaling E and I).
w_ee, w_ei, w_ie, w_ii = 10, 12, 9, 3

def sig(x):
    """Logistic sigmoid 1 / (1 + exp(-x)); element-wise on NumPy arrays."""
    denominator = 1 + np.exp(-x)
    return 1 / denominator

def tanh(x):
    """Hyperbolic tangent written via the logistic sigmoid: tanh(x) = 2*sig(2x) - 1."""
    doubled = 2 * x
    return 2 * sig(doubled) - 1

def dedt(t, E, I, P):
    """Excitatory rate equation of a single E/I Wilson-Cowan pair.

    dE/dt = (-E + (1 - r*E) * sig(w_ee*E - w_ei*I + P)) / tau_E,
    using the module-level r, w_ee, w_ei and tau_E. `t` is unused.
    """
    drive = w_ee*E - w_ei*I + P
    return (-E + (1 - r*E)*sig(drive)) / tau_E

def didt(t, E, I, Q):
    """Inhibitory rate equation of a single E/I Wilson-Cowan pair.

    dI/dt = (-I + (1 - r*I) * sig(w_ie*E - w_ii*I + Q)) / tau_I,
    using the module-level r, w_ie, w_ii and tau_I. `t` is unused.

    Returns a scalar, like `dedt`. It previously wrapped the value in a 1-element
    array, so assigning it into `I[i + 1]` relied on an implicit size-1
    array-to-scalar conversion.
    """
    return (-I + (1 - r*I)*sig(w_ie*E - w_ii*I + Q)) / tau_I

T = 6
dt = .01
t = np.arange(0, T + dt, dt)
E = np.zeros(len(t))
I = np.zeros(len(t))

# Explicit Euler integration starting from E(0) = I(0) = 0.
for i in range(len(t) - 1):
    E[i + 1] = E[i] + dt * dedt(t[i], E[i], I[i], P)
    I[i + 1] = I[i] + dt * didt(t[i], E[i], I[i], Q)

# E += np.random.normal(0, 0.01, E.shape)
# Plot against the time vector. The x-axis is labelled 'Time', but plotting the bare
# arrays put the step index (0 .. len(t)-1) on it.
plt.plot(t, E, label='Excitatory')
plt.plot(t, I, label='Inhibitory')
plt.legend()

# np.save('E_synthetic_kdot5_Pdot2.npy', E)
plt.xlabel('Time')
plt.ylabel('Activity');
# plt.savefig('varying_k.pdf', dpi=600)

In [None]:
import matplotlib.pyplot as plt
import numpy as np
from scipy.stats import multivariate_normal
from ipywidgets import interact
import ipywidgets as widgets
from scipy.stats import norm

# Square evaluation grid on [-4, 4]^2 with N points per axis. pos[..., 0] holds the
# X coordinate and pos[..., 1] the Y coordinate, as scipy's pdf expects.
N = 200
X, Y = np.meshgrid(np.linspace(-4, 4, N), np.linspace(-4, 4, N))
pos = np.stack((X, Y), axis=-1)


def gauss2d(mean_a, mean_b, cov_aa, cov_bb, cov_ab=0.8):
    """Contour-plot a 2-D Gaussian density on the module-level grid (X, Y, pos).

    Parameters
    ----------
    mean_a, mean_b : means of the two coordinates.
    cov_aa, cov_bb : variances of the two coordinates.
    cov_ab : off-diagonal covariance. It was hard-coded to 0.8 and is now a parameter
        with the same default. The matrix [[cov_aa, cov_ab], [cov_ab, cov_bb]] must be a
        valid covariance (cov_ab**2 < cov_aa * cov_bb), otherwise scipy raises. With
        the fixed 0.8, small variances such as 0.1 and 0.1 could never be plotted.
    """
    rv = multivariate_normal([mean_a, mean_b], [[cov_aa, cov_ab], [cov_ab, cov_bb]])
    Z = rv.pdf(pos)
    plt.contour(X, Y, Z)
    plt.show()

def gauss1d(mean, std):
    """Plot the N(mean, std**2) density on N evenly spaced points over [-4, 4]."""
    grid = np.linspace(-4, 4, N)
    plt.plot(grid, norm.pdf(grid, mean, std))
    plt.show()
    
# interact(gauss2d, mean_a = widgets.FloatSlider(value=0, min=0.1, max=5, step=0.1),
#     mean_b = widgets.FloatSlider(value=0, min=0.1, max=5, step=0.1),
#     cov_aa = widgets.FloatSlider(value=1, min=0.1, max=5, step=0.1),
#     cov_bb = widgets.FloatSlider(value=2, min=0.1, max=5, step=0.1))

# The std slider used to start at 0, below its own minimum of 0.1 (std = 0 is a
# degenerate density). Start it inside the valid range.
interact(gauss1d,
         mean=widgets.FloatSlider(value=0, min=0, max=5, step=0.1),
         std=widgets.FloatSlider(value=1, min=.1, max=3, step=0.1));

In [None]:
# Slider for the current excitatory activity E_t passed to gauss1d_neuronal.
Et_slider = widgets.FloatSlider(min=-5, max=5, step=0.1, value=0)

def gauss1d_neuronal(Et, std):
    """Plot a Gaussian centred on a one-step Euler prediction from Et.

    The mean is Et + dt * dedt(0, Et, I=0, P=.2). It uses whichever `dedt` and `dt`
    are defined globally when the function is called. Prints Et and the predicted
    mean, then plots N(mean, std**2) on N points over [-4, 4].
    """
    grid = np.linspace(-4, 4, N)
    predicted = Et + dt*dedt(0, Et, I=0, P=.2)
    print(Et, predicted)
    plt.plot(grid, norm.pdf(grid, predicted, std))
    plt.show()

def gauss2d_neuronal(mean_a, mean_b, cov_aa, cov_bb):
    """Contour-plot a 2-D Gaussian with off-diagonal covariance 0.8.

    This was a verbatim copy of `gauss2d` (defined in the previous cell). It now
    delegates to it so the two cannot drift apart.
    """
    gauss2d(mean_a, mean_b, cov_aa, cov_bb)

# A bare float (std=0.5) makes `interact` build a slider over [-0.5, 1.5], so std
# could reach values <= 0 where norm.pdf returns NaN. Give it an explicit positive range.
interact(gauss1d_neuronal, Et=Et_slider,
         std=widgets.FloatSlider(value=0.5, min=0.05, max=1.5, step=0.05));

In [None]:
# import numpy as np
import matplotlib.pyplot as plt
import numpy as np
from ipywidgets import interact
import ipywidgets as widgets

# Reduced (excitatory-only) model. P is the drive, sigma_n the observation-noise level,
# k the sigmoid gain, and tau_E the time constant.
P, sigma_n, k, tau_E = .2, 0.05, 5, 1

def sig(x):
    """Logistic sigmoid, computed element-wise as 1 / (1 + exp(-x))."""
    return np.divide(1, 1 + np.exp(-x))

def dedt(t, E, I, P):
    """Reduced excitatory rate: dE/dt = (-E + sig(k*E + k*P)) / tau_E.

    `t` and `I` are unused; they keep the same call signature as the full E/I model.
    `k` and `tau_E` are module-level constants.
    """
    activation = sig(k*E + k*P)
    return (-E + activation) / tau_E

def ddedt(E):
    """Derivative of the reduced dedt with respect to E: (k*s*(1 - s) - 1) / tau_E.

    Here s = sig(k*E + k*P), using the module-level k, P and tau_E.
    """
    s = sig(k*E + k*P)
    return (k * s * (1 - s) - 1) / tau_E

def j_func(E):
    """Jacobian of the one-step map Etp(E) = E + dedt(E), i.e. 1 + d(dedt)/dE."""
    return ddedt(E) + 1

def Etp(E):
    """One-step map E -> E + dedt(E) (unit step size), using the module-level P."""
    return dedt(0, E, 0, P) + E

def linearised(x0):
    """First-order Taylor expansion of Etp about x0, evaluated on the module-level grid X.

    Returns Etp(x0) + J(x0) * (X - x0), where J = j_func = dEtp/dE.
    """
    # The Jacobian must be taken at the expansion point. Evaluating it on the whole
    # grid (as before) produced a curve rather than the tangent line at x0.
    J = j_func(x0)
    return Etp(x0) + J*(X - x0)

def gauss_approx(x0):
    """Information-form Gaussian from linearising the measurement Etp about x0.

    With measurement z(x0), Jacobian J = j_func(x0) and noise precision Lambda:
        eta = J^T Lambda (z - Etp(x0) + J x0)
        lam = J^T Lambda J

    Returns (eta, lam).
    """
    J_x0 = j_func(x0)
    # sigma_n is a standard deviation (its value matches the noise std drawn in z),
    # so the precision is sigma_n**-2, not 1/sigma_n. It must also scale eta,
    # otherwise eta / lam would not give the right mean.
    meas_precision = sigma_n ** -2
    eta = J_x0.T * meas_precision * (z(x0) - Etp(x0) + J_x0*x0)
    lam = J_x0.T * meas_precision * J_x0

    return eta, lam

# 500-point evaluation grid on [-1, 1] used by linearised() and the plots below.
X = np.linspace(start=-1, stop=1, num=500)

def z(X):
    """Noisy observation of the one-step map: Etp(X) + N(0, sigma_n**2) per element.

    Draws from NumPy's global RNG, so every call returns fresh noise.
    """
    X = np.asarray(X)
    # Use sigma_n instead of a hard-coded 0.05, so the simulated noise always matches
    # the precision that gauss_approx assumes.
    return Etp(X) + np.random.normal(0, sigma_n, size=X.shape)

def linearised_with_slider(x0):
    """Plot noisy samples of Etp on the grid X, the tangent line at x0, and x0 itself."""
    plt.plot(X, z(X), label='noisy measurements')
    plt.plot(X, linearised(x0), label='linearisation')
    # The previous label ('axvline - full height') was a placeholder and was never shown,
    # because no legend was drawn.
    plt.axvline(x=x0, color='b', label='expansion point')
    plt.legend()
    plt.show()

# Evaluation grid and a default expansion point. The slider below starts at 0 and
# passes its own x0, so this x0 is not what the interactive plot uses.
X, x0 = np.linspace(-1, 1, 500), 0.5

interact(
    linearised_with_slider,
    x0=widgets.FloatSlider(value=0, min=-1, max=1, step=0.1),
)

# def gauss1d(mean, std):
#     z = norm.pdf(np.linalg.norm(-5, 5, 500), mean, std)
#     plt.plot(np.linalg.norm(-5, 5, 500), z)
#     plt.show()

# def inf_to_moments(eta, lam):
#     return (eta/lam, 1/lam)

# gauss1d(*inf_to_moments(*gauss_approx(0.5)))

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from ipywidgets import interact
import ipywidgets as widgets
from scipy.stats import norm
import torch

# Reduced excitatory model constants: time constant, drive, sigmoid gain.
tau_E, P, k = 1., .2, 5

def sigmoid(x):
    """Logistic sigmoid for torch tensors: 1 / (1 + exp(-x)), autograd-compatible."""
    denominator = torch.exp(-x) + 1
    return 1 / denominator
    
def _dedt(E):
    """Reduced excitatory rate (sigmoid(k*E + k*P) - E) / tau_E on torch tensors."""
    drive = k*E + k*P
    return (sigmoid(drive) - E) / tau_E

def meas_fn(Et):
    """One-step map Et -> Et + _dedt(Et) (unit step), used as the measurement function."""
    return _dedt(Et) + Et

def jac_fn(Et):
    """Derivative d meas_fn / d Et at Et, computed with autograd.

    Et must be a scalar tensor with requires_grad=True. This uses torch.autograd.grad
    instead of .backward(): .backward() accumulates into Et.grad, so a second call on
    the same tensor returned twice the derivative.
    """
    (grad,) = torch.autograd.grad(meas_fn(Et), Et)
    return grad

# Linearise the measurement function about x0 and compare the exact Gaussian measurement
# likelihood with the linearised factor in information form.
x0 = 0.2
x_pos = torch.as_tensor(x0).requires_grad_(True)
MEAS = meas_fn(x_pos)                      # h(x0)
meas_lam = torch.tensor([1/0.1])           # measurement precision
# pred_meas = 0.8

pred_meas = torch.linspace(0.5, 1.5, 10)   # candidate measurement values z

diff = MEAS - pred_meas

# Exact (unnormalised) likelihood exp(-0.5 * Lambda * (h(x0) - z)^2) for each z.
# `diff` is 1-D, so no transpose is needed (`.T` on non-2-D tensors is deprecated).
print(torch.exp(-0.5 * diff * meas_lam * diff).detach().numpy())

jac = jac_fn(x_pos)

# Information form of the factor linearised as h(x) ~ h(x0) + J (x - x0):
#   lam = J^T Lambda J,   eta = J^T Lambda (z - h(x0) + J x0)
# The residual is z - h(x0) = -diff. Using +diff flipped the sign of eta.
residual = pred_meas - MEAS
lam = jac.T * meas_lam * jac
eta = jac.T * meas_lam * (residual + jac * x_pos)

# Evaluate the linearised factor at x = x0, including its x-independent constant:
#   exp(-0.5 x lam x + eta x - 0.5 Lambda (z - h(x0) + J x0)^2)
# At x = x0 this reduces to the exact likelihood printed above, which checks the algebra.
x_lin = x_pos.detach()
const = 0.5 * meas_lam * (residual + jac * x_lin) ** 2
print(torch.exp(-0.5 * x_lin * lam * x_lin + eta * x_lin - const).detach().numpy())

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from ipywidgets import interact
import ipywidgets as widgets
from scipy.stats import norm
import torch

# Sigmoid gain, drive and time constant of the reduced excitatory model.
k, P, tau_E = 5, .2, 1

def sigmoid(x):
    """Element-wise logistic function on torch tensors: 1 / (1 + e^{-x})."""
    e = torch.exp(-x)
    return 1 / (1 + e)
    
def _dedt(E):
    """Reduced excitatory rate (-E + sigmoid(k*E + k*P)) / tau_E, using module constants."""
    rate = -E + sigmoid(k*E + k*P)
    return rate / tau_E

def meas_fn(Et, Etp):
    """Absolute one-step residual h(Et, Etp) = |Etp - (Et + _dedt(Et))|.

    Inputs may be Python numbers, NumPy arrays or tensors; they are converted with
    torch.as_tensor (no copy for existing tensors, so autograd still flows).
    """
    Etp = torch.as_tensor(Etp)
    Et = torch.as_tensor(Et)
    predicted = Et + _dedt(Et)
    return torch.abs(Etp - predicted)

# Linearisation point (Et, Etp) as leaf tensors, so autograd gives dh/dEt and dh/dEtp.
Et = torch.tensor(0.2, requires_grad=True)
Etp = torch.tensor(0.8, requires_grad=True)

# Detached (2, 1) copy of the linearisation point. Building it from Python floats
# avoids torch.tensor() copy-constructing from grad-requiring tensors.
X0 = torch.tensor([[Et.item()], [Etp.item()]])

sigma_n_inv = torch.tensor([[1/0.3, 0.], [0, 1/0.2]])

# Et/Etp are already fresh leaf tensors. Re-wrapping them with torch.tensor(...) only
# triggered PyTorch's copy-construct warning, so that step is gone.
h = meas_fn(Et, Etp)
h.backward()

# Gradient of the scalar h with respect to (Et, Etp), shape (2,).
J = torch.stack((Et.grad, Etp.grad))
x_lin = torch.stack((Et, Etp)).detach()
h_val = h.detach()
# NOTE(review): h is a scalar measurement, so its noise precision would normally be 1x1
# and J a (1, 2) row, giving a 2x2 information matrix. With this 2x2 sigma_n_inv and a
# 1-D J, `lam` is a scalar quadratic form. Confirm the intent.
# J is 1-D, so `J @ M` equals `J.T @ M` without the deprecated 1-D `.T`.
eta = J @ sigma_n_inv * (J @ x_lin - h_val)
lam = J @ sigma_n_inv @ J

# Evaluation grid over (Et, Etp) for the exact measurement function.
Et_values = np.linspace(-2, 2, 100)
Etp_values = np.linspace(-2, 2, 100)
Et_grid, Etp_grid = np.meshgrid(Et_values, Etp_values)

# meas_fn uses torch ops, so evaluate it on tensor views of the NumPy grids (vectorised).
Et_tensor = torch.from_numpy(Et_grid)
Etp_tensor = torch.from_numpy(Etp_grid)
meas_values = meas_fn(Et_tensor, Etp_tensor).numpy()

fig, ax = plt.subplots(figsize=(8, 6))
contour_levels = np.linspace(meas_values.min(), meas_values.max(), 100)
cp = ax.contourf(Et_grid, Etp_grid, meas_values, levels=contour_levels, cmap='Greens')
# h = 0 along Etp = Et + dedt(Et), where the one-step prediction is exact.
ax.plot(Et_values,
        (torch.tensor(Et_values) + _dedt(torch.tensor(Et_values))).numpy(),
        label='Line of optimality, h(Et, Etp) = 0')
ax.set_xlabel('Et')
ax.set_ylabel('Etp')
cbar = fig.colorbar(cp)
cbar.set_label('h(Et, Etp)')
ax.set_title('h(Et, Etp) = |Etp - (Et + dedt(Et))|')
ax.legend()
plt.show()

def inf_to_moments(eta, lam):
    """Convert a Gaussian from information form (eta, lam) to moments (mean, cov).

    mean = lam^{-1} eta and cov = lam^{-1}.

    - Scalar or 1-D `lam` (Python/NumPy numbers, 0-d or 1-D arrays/tensors): element-wise
      division, exactly as before.
    - `lam` with ndim >= 2: treated as a precision matrix and inverted. Element-wise
      division is wrong for any non-diagonal matrix. torch tensors use torch.linalg.inv;
      NumPy arrays use np.linalg.inv. A singular `lam` raises that library's LinAlgError.
    """
    if getattr(lam, 'ndim', 0) < 2:
        return (eta/lam, 1/lam)
    if isinstance(lam, torch.Tensor):
        cov = torch.linalg.inv(lam)
    else:
        cov = np.linalg.inv(lam)
    return (cov @ eta, cov)

# Only a cell's final expression is auto-displayed. The bare `eta, lam` and
# `inf_to_moments(...)` expressions were never shown, so print them explicitly.
print(f'eta = {eta}, lam = {lam}')
mean, cov = inf_to_moments(eta, lam)
print(f'Moments: mean = {mean}, cov = {cov}')

print(f'True value: {meas_fn(Et, Etp)}')

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from ipywidgets import interact
import ipywidgets as widgets
from scipy.stats import norm
import torch
from matplotlib.colors import LinearSegmentedColormap

# LaTeX text rendering (needs a LaTeX installation), serif fonts, and JavaScript-HTML
# rendering for matplotlib animations.
plt.rcParams.update({
    'text.usetex': True,
    'font.family': 'serif',
    'animation.html': 'jshtml',
})

# Diverging colormap: red at the low end, white in the middle, blue at the high end.
colours = [(0.0, 'red'), (0.5, 'white'), (1.0, 'blue')]
cmap = LinearSegmentedColormap.from_list('RedWhiteBlue', colours)

# Reduced-model constants: sigmoid gain, drive, time constant.
k, P, tau_E = 5, .2, 1

# 1x1 measurement precision for a noise standard deviation of 0.05.
lmbda_in = torch.tensor([[0.05 ** -2]])

def sigmoid(x):
    """Logistic sigmoid of a torch tensor, 1 / (1 + exp(-x))."""
    return 1 / ((-x).exp() + 1)
    
def _dedt(E):
    """Reduced excitatory rate: (sigmoid(k*E + k*P) - E) / tau_E, with module constants."""
    numerator = sigmoid(k*E + k*P) - E
    return numerator / tau_E

def meas_fn(Et, Etp):
    """Measurement h(Et, Etp) = |Etp - (Et + _dedt(Et))|; inputs go through torch.as_tensor."""
    Etp = torch.as_tensor(Etp)
    Et = torch.as_tensor(Et)
    residual = Etp - (Et + _dedt(Et))
    return residual.abs()

# Linearisation point X0 = (Et, Etp) and the evaluation grid.
Et, Etp = 0.2, 0.8
Et_values = torch.linspace(-2, 2, 1000)
Etp_values = torch.linspace(-2, 2, 1000)

Et_tensor = torch.tensor(Et, requires_grad=True)
Etp_tensor = torch.tensor(Etp, requires_grad=True)

# Plain (non-grad) copy of the linearisation point, ordered (Et, Etp) like J's columns.
X0 = torch.tensor((Et, Etp))
h_X0 = meas_fn(Et_tensor, Etp_tensor)
h_X0.backward()
h_X0 = h_X0.detach()
# Jacobian of the scalar h as a (1, 2) row: [dh/dEt, dh/dEtp].
J = torch.stack((Et_tensor.grad, Etp_tensor.grad)).unsqueeze(0)

print(f'Using {Et} and {Etp} to calculate...')

# Information form of the factor linearised as h(X) ~ h(X0) + J (X - X0), with target 0:
#   eta = J^T Lambda (J X0 - h(X0)),   Lambda' = J^T Lambda J
eta = J.T @ lmbda_in * ((J @ X0) - h_X0)
lmbdap = J.T @ lmbda_in @ J

# 'xy' indexing gives grids of shape (len(Etp_values), len(Et_values)), the
# (rows = y, cols = x) layout contourf expects. X_flat's columns are then (Et, Etp),
# in the same order as X0 and J. The previous 'ij' grid with swapped columns was only
# correct because both ranges happen to be identical.
Et_grid, Etp_grid = torch.meshgrid(Et_values, Etp_values, indexing='xy')
X_flat = torch.stack((Et_grid.flatten(), Etp_grid.flatten()), dim=1)

h_approx_flat = h_X0 + J @ (X_flat - X0).T
meas_values = h_approx_flat.view(Et_grid.shape).numpy()

# Symmetric levels keep white at h = 0 and span the full data range. The fixed +/-3.8
# levels left the corner where h is about 3.82 unfilled.
vmax = float(np.abs(meas_values).max()) or 1.0
contour_levels = np.linspace(-vmax, vmax, 200)
plt.contourf(Et_values, Etp_values, meas_values, levels=contour_levels, cmap=cmap)
plt.plot(Et, Etp, 'r.')
plt.xlabel(r'$E_t$')
plt.ylabel(r'$E_{t+1}$')
plt.title(rf'$h(X) \approx h(X_0) + J(X - X_0), X_0 = [{Et}, {Etp}]$')
cbar = plt.colorbar()
# The closing parenthesis now sits inside the math expression.
cbar.set_label(r'$h(E_t, E_{t+1})$')

In [None]:
import numpy as np
from scipy.integrate import odeint
from scipy.optimize import fsolve
from fg.functions import dEdt, dIdt
import matplotlib.pyplot as plt
import torch

%matplotlib inline
plt.style.use('ggplot')

config = {
    'T': 10,
    'dt': 0.01,
    'k1': 16,
    'k2': 12,
    'k3': 15,
    'k4': 3,
    'P': 0.2,
    'Q': 0.2,
    'tau_E': 1.,
    'tau_I': 1.
}

# Model settings, with fallbacks for missing keys.
T = config.get('T', 6.)
dt = config.get('dt', 0.01)
k1 = config.get('k1', 10.)
k2 = config.get('k2', 12.)
k3 = config.get('k3', 9.)
k4 = config.get('k4', 3.)
P = config.get('P', 0.2)
Q = config.get('Q', 0.5)
# The dict spells the time constants 'tau_E'/'tau_I', but they were read as 'tauE'/'tauI'.
# The configured values were therefore silently ignored: tauI fell back to 2 instead of
# the configured 1. Accept either spelling and prefer 'tau_E'/'tau_I'.
tauE = config.get('tau_E', config.get('tauE', 1.))
tauI = config.get('tau_I', config.get('tauI', 2.))

time = torch.arange(0, T + dt, dt)

# Initial conditions.
E0, I0 = 0.4, 0.5
E, I = torch.zeros(len(time)), torch.zeros(len(time))
E[0], I[0] = E0, I0

# Explicit Euler: both derivatives are evaluated from the state at step i
# before either population is advanced.
for i in range(len(time) - 1):
    dE = dEdt(E[i], I[i], k1, k2, P, tauE)
    dI = dIdt(E[i], I[i], k3, k4, Q, tauI)
    E[i + 1] = E[i] + dt * dE
    I[i + 1] = I[i] + dt * dI


# Time courses of the excitatory and inhibitory populations.
fig, ax = plt.subplots(figsize=(10.3, 3))
ax.set_ylabel(r'$E, I$')
ax.set_xlabel(r'$t$')
ax.plot(time, E, '.-', label="excitatory")
ax.plot(time, I, '.-', label="inhibitory")
ax.legend();

In [None]:
from fg.simulation_config import simulate_wc
import numpy as np
import matplotlib.pyplot as plt


# Ground-truth simulation settings passed to simulate_wc.
config = dict(
    T=9,
    dt=0.01,
    k1=4.77,
    k2=4.907,
    k3=0.4,
    k4=3.6329,
    P=1.,
    Q=1.,
)

E, I = simulate_wc(config)
# The plotted trace is E - I, so label it as such (it was labelled 'GT_E') and show the
# legend. The label avoids '_': earlier cells may have enabled text.usetex, where a bare
# underscore is a LaTeX error.
plt.plot(E - I, label='E - I (ground truth)')
plt.legend();
# plt.plot(I, label='GT_I')



In [3]:
from fg.variables import Variable, Parameter
from fg.factors import DynamicsFactor, ObservationFactor, PriorFactor
from fg.functions import sig
from fg.simulation_config import simulate_wc
from fg.graph import Graph
from fg.gaussian import Gaussian
import torch
import matplotlib.pyplot as plt
import numpy as np

if __name__ == "__main__":
    sigma_obs = 1e-2
    sigma_dynamics = 1e-3
    sigma_prior = 2e0
    # NOTE(review): `iters` is not used yet; only a single message-passing sweep runs below.
    iters = 200

    # Seeded generator so the sampled ground-truth parameters are reproducible,
    # regardless of how much of the global NumPy RNG earlier cells have consumed.
    rng = np.random.default_rng(0)

    config = {
        'T': 0.04,
        'dt': 0.01,
        'k1': 3. + rng.normal(),
        'k2': 5. + rng.normal(),
        'k3': 4. + rng.normal(),
        'k4': 3. + rng.normal(),
        'P': 1.  + rng.normal(0, 0.1),
        'Q': 1.  + rng.normal(0, 0.1),
    }

    E, I = simulate_wc(config)
    X = E - I
    t = torch.arange(0, len(E), 1)

    factor_graph = Graph()

    # Parameters are created with a placeholder id of 0. The real ids ('p0', 'p1', ...)
    # are assigned when they are added to the graph below.
    param_dict = {
        'k1': Parameter(0, Gaussian(torch.tensor([[0.]]), torch.tensor([[sigma_prior ** 2.]])), factor_graph, []),
        # 'ks': Parameter(0, Gaussian(torch.tensor([[0.] * 4]).T, torch.diag(torch.tensor([sigma_prior ** 2.] * 4))), factor_graph, [], 4),
        'k2': Parameter(0, Gaussian(torch.tensor([[0.]]), torch.tensor([[sigma_prior ** 2.]])), factor_graph, []),
        'k3': Parameter(0, Gaussian(torch.tensor([[0.]]), torch.tensor([[sigma_prior ** 2.]])), factor_graph, []),
        'k4': Parameter(0, Gaussian(torch.tensor([[0.]]), torch.tensor([[sigma_prior ** 2.]])), factor_graph, []),
        'P':  Parameter(0, Gaussian(torch.tensor([[0.]]), torch.diag(torch.tensor([sigma_prior ** 2.]))), factor_graph, [], 1),
        'Q':  Parameter(0, Gaussian(torch.tensor([[0.]]), torch.diag(torch.tensor([sigma_prior ** 2.]))), factor_graph, [], 1)
    }

    # -- Construct FG -- #
    # Add our variable and observation factors at each time step
    for i in range(len(t)):
        factor_graph.var_nodes[f'o{i}'] = Variable(f'o{i}',
                                                   Gaussian(torch.tensor([[0.1, 0.1]]).T, torch.tensor([[0.2, 0.], [0., 0.2]])),
                                                   -1 if i == 0        else (f'o{i-1}', f'o{i}'),
                                                   -1 if i+1 == len(t) else (f'o{i}', f'o{i+1}'),
                                                   -1,
                                                   factor_graph,
                                                   2)

        factor_graph.factor_nodes[f'o{i}'] = ObservationFactor(f'o{i}', f'o{i}', torch.tensor([[X[i]]]), torch.tensor([[sigma_obs ** -2]]), factor_graph)

    # Add our parameters as additional variables to our factor graph
    for p_id, p in enumerate(param_dict.values()):
        p.id = f'p{p_id}'
        factor_graph.param_ids.append(p.id)
        factor_graph.var_nodes[p.id] = p

    # Connect a dynamics factor between each pair of consecutive time steps, and
    # connect every dynamics factor to all parameters.
    for i in range(len(t) - 1):
        dyn_id = (f'o{i}', f'o{i+1}')
        factor_graph.factor_nodes[dyn_id] = DynamicsFactor(f'o{i}', f'o{i+1}', torch.tensor([[sigma_dynamics ** -2]]), dyn_id, factor_graph)

        for p in param_dict.values():
            p.connected_factors.append(dyn_id)

    # Priors on the parameters.
    # NOTE(review): the vector passed to PriorFactor is 3 in every dimension, while the old
    # comment said "zero mean priors" and each Parameter's initial belief above has mean 0.
    # Confirm which prior mean is intended.
    for p_id, p in enumerate(param_dict.values()):
        factor_graph.factor_nodes[f'p{p_id}'] = PriorFactor(f'p{p_id}', p.id, torch.tensor([[3.] * p.num_vars]).T, torch.diag(torch.tensor([sigma_prior ** -2] * p.num_vars)),
                                                                  factor_graph)

    # Renamed from `iter`, which shadowed the builtin.
    iteration = 0
    print(f'Iteration {iteration}')
    for name, param in param_dict.items():
        print(name, param)

        # Fail loudly on a NaN belief. exit(0) is meant to end the interpreter, so in a
        # notebook it would stop the kernel rather than just this cell.
        if param.belief.eta.isnan().any():
            raise RuntimeError(f'NaN in the belief of parameter {name}')

    print('------')

    if iteration == 0:
        # Initialise messages from observation factors to variables
        # and prior factors to parameters (if learning params)
        factor_graph.update_all_observational_factors()

        # Now update messages from variables to factors
        # This should ensure all var to dynamics factor messages have non-zero precision
        for i in factor_graph.var_nodes:
            curr = factor_graph.var_nodes[i]
            curr.compute_and_send_messages()

        factor_graph.prune()

    # Forward sweep along the chain: each time-step variable sends its messages, then
    # the factor at its right_id (if any) does the same.
    # NOTE(review): the recorded run of this cell fails with a singular-matrix LinAlgError
    # inside fg. Each o{i} variable is 2-D, but its observation is the scalar X[i] = E - I.
    # Check that every dimension receives a message with non-zero precision.
    for i in range(len(t)):
        curr = factor_graph.var_nodes[f'o{i}']
        curr.compute_and_send_messages()

        if curr.right_id == -1:
            continue

        fac = factor_graph.factor_nodes[curr.right_id]
        fac.compute_and_send_messages()

Running simulation with: T = 0.04, dt = 0.01, k1 = 2.992554188416303, k2 = 5.140534962181106, k3 = 4.848443103889046, k4 = 1.3618461573732632, P = 1.041932515681223, Q = 0.9671668189310363, tauE = 1.0, tauI = 2.0
Iteration 0
k1 Parameter p0 [n = 1, mu=tensor([[0.]]), cov=tensor([[4.]])]
k2 Parameter p1 [n = 1, mu=tensor([[0.]]), cov=tensor([[4.]])]
k3 Parameter p2 [n = 1, mu=tensor([[0.]]), cov=tensor([[4.]])]
k4 Parameter p3 [n = 1, mu=tensor([[0.]]), cov=tensor([[4.]])]
P Parameter p4 [n = 1, mu=tensor([[0.]]), cov=tensor([[4.]])]
Q Parameter p5 [n = 1, mu=tensor([[0.]]), cov=tensor([[4.]])]
------


_LinAlgError: linalg.inv: The diagonal element 2 is zero, the inversion could not be completed because the input matrix is singular.