In [None]:
import torch
from einops import rearrange,repeat
from typing import Union, Tuple
import matplotlib.pyplot as plt
from tqdm import tqdm

In [None]:
# Half precision on the GPU (presumably to halve the memory of the dense N x N
# matrices) and full precision on the CPU. DTYPE mirrors the chosen default dtype.
DTYPE = torch.float16 if torch.cuda.is_available() else torch.float32
torch.set_default_dtype(DTYPE)
if DTYPE is torch.float16:
	torch.set_default_device("cuda")

## Constants

In [None]:
# ----------------------
# Reproducibility
# ----------------------

SEED = 0
torch.manual_seed(SEED)  # seeds the CPU and all CUDA generators

# ----------------------
# Network size and connectivity
# ----------------------

NUM_EXCITATORY = 8000
NUM_INHIBITORY = 2000
NUM_NEURONS = NUM_EXCITATORY + NUM_INHIBITORY
P_CONNECTION = 0.20

# Fraction of slow connections per block; all connections could be slow or fast.
# NOTE: get_connectivity_matrices uses these in POST_PRE order, e.g.
# RATIO_SLOW_EXCITATORY_INHIBITORY is used for I -> E synapses.
RATIO_SLOW_EXCITATORY_EXCITATORY = 0.5
RATIO_SLOW_INHIBITORY_INHIBITORY = 0.1
RATIO_SLOW_EXCITATORY_INHIBITORY = 0
RATIO_SLOW_INHIBITORY_EXCITATORY = 0

# Conductances: E -> E synapses are bimodal (potentiated / depressed),
# the remaining blocks use a single fixed value.
P_POTENTIATED = 0.2
J_POTENTIATED = 0.21
J_DEPRESSED = 0.03
J_E_TO_I = 0.08
J_I_TO_E = -0.18
J_I_TO_I = -0.18

# ----------------------
# Membrane
# ----------------------

REFRACTORY_TIME = 2
MEMBRANE_REFRACTORY_VEC = REFRACTORY_TIME*torch.ones([NUM_NEURONS])

MEMBRANE_CONST_EXCITATORY = 20
MEMBRANE_CONST_INHIBITORY = 10
MEMBRANE_CONST_VEC = torch.concat((MEMBRANE_CONST_EXCITATORY*torch.ones(NUM_EXCITATORY),
	MEMBRANE_CONST_INHIBITORY*torch.ones(NUM_INHIBITORY)))

THRESHOLD_POTENTIAL = 20
V_RESET_EXCITATORY = 15
V_RESET_INHIBITORY = 10
THRESHOLD_POTENTIAL_VECTOR = THRESHOLD_POTENTIAL*torch.ones([NUM_NEURONS])
V_RESET_VECT = torch.concat((V_RESET_EXCITATORY*torch.ones(NUM_EXCITATORY),
	V_RESET_INHIBITORY*torch.ones(NUM_INHIBITORY)))

# ----------------------
# Time grid
# ----------------------

DT = 0.1
T_MAX = 1000
# T_MAX/DT intervals need T_MAX/DT + 1 points for consecutive samples to be exactly
# DT apart (the forward-Euler step); round() guards against float error in T_MAX/DT.
# NOTE: with a float16 default dtype, values above 512 are only representable to
# 0.5, so this grid is coarse there; rebuild it in float32 where precision matters.
TIMELINE = torch.linspace(0,T_MAX,round(T_MAX/DT)+1)
T_IDX_MAX = len(TIMELINE)

# ----------------------
# Slow synaptic currents
# ----------------------

TAU_S_EXCITATORY = 10
TAU_S_INHIBITORY = 10
TAU_S_VEC = torch.concat((TAU_S_EXCITATORY*torch.ones(NUM_EXCITATORY),
	TAU_S_INHIBITORY*torch.ones(NUM_INHIBITORY)))

# ----------------------
# Short-term plasticity
# ----------------------

TAU_RECOVERY = 200

# ----------------------
# Long-term plasticity: drift of the latent variable
# ----------------------

ALPHA_PLASTICITY_RECOVERY = 0.0147
BETA_PLASTICITY_RECOVERY = 0.01
THETA_X_PLASTICITY_THRESHOLD = 0.4
U_SPIKING_COST = 0.45

# ----------------------
# Long-term plasticity: Hebbian term
# ----------------------

A_HEBBIAN = 0.25
B_HEBBIAN = 0.17
THETA_LTP = 17.5
THETA_LTD = 15.5

## Neuron setup

- We first define helpers that build the sparse connectivity matrices between neurons and initialize their synaptic conductances.

In [None]:
def random_sparse(m:int, n:int, p:float)->torch.Tensor:
	"""
	Create an m x n random 0/1 matrix.

	Every entry is 1 independently with probability p and 0 otherwise.

	Args:
		m: number of post-synaptic neurons
		n: number of pre-synaptic neurons
		p: probability of connection

	Returns:
		m x n matrix of zeros and ones in the default dtype
	"""
	connected = torch.rand(m, n) < p
	return connected.to(torch.get_default_dtype())

def sparse_square(n:int, p:float)->torch.Tensor:
	"""
	Create an n x n random 0/1 matrix with an all-zero diagonal.

	Each off-diagonal entry is 1 independently with probability p; the diagonal
	is forced to 0 so that no neuron connects to itself.

	Args:
		n: number of neurons
		p: probability of connection

	Returns:
		n x n matrix of zeros and ones in the default dtype
	"""
	connected = torch.rand(n, n) < p
	connected.fill_diagonal_(False)
	return connected.to(torch.get_default_dtype())

In [None]:
def slow_fast_pair(m:int,
				   n:int,
				   p_connection:float,
				   ratio_slow:float,)->Tuple[torch.Tensor, torch.Tensor]:
	"""
	Draw a pair of non-overlapping connectivity masks for slow and fast synapses.

	Each (post, pre) pair is connected with probability p_connection; each
	connection is then independently labelled slow with probability ratio_slow
	and fast otherwise. Hence P(connected) = p_connection and, on average, a
	fraction ratio_slow of all connections is slow.

	When m == n the diagonal is left unconnected (no self-connections).
	NOTE(review): this assumes a square block always means a population connecting
	to itself (E->E or I->I); it would also zero the diagonal of an E->I block if
	the two populations happened to have the same size.

	Args:
		m: number of post-synaptic neurons
		n: number of pre-synaptic neurons
		p_connection: probability that a (post, pre) pair is connected
		ratio_slow: probability that a connection is slow (fraction of slow connections)

	Returns:
		slow: m x n 0/1 mask of slow synapses (default dtype)
		fast: m x n 0/1 mask of fast synapses, disjoint from slow
	"""
	# Draw the connections first and split them afterwards. Drawing slow and fast
	# masks independently and then removing their overlap would give a total
	# connection probability of p*r + (1 - p*r)*p*(1 - r) < p and a slow fraction
	# above ratio_slow.
	connected = torch.rand(m, n) < p_connection
	if m == n:
		connected = connected & ~torch.eye(m, dtype=torch.bool)
	is_slow = torch.rand(m, n) < ratio_slow

	dtype = torch.get_default_dtype()
	slow = (connected & is_slow).to(dtype)
	fast = (connected & ~is_slow).to(dtype)
	return slow,fast

In [None]:
def initialize_j(u_mat:torch.Tensor,
				 p_pot:float,
				 j_p:float,
				 j_d:float)->Tuple[torch.Tensor, torch.Tensor]:
	"""
	Fill a connectivity mask with one of two conductance values and initialize
	the matching long-term-plasticity latent variable L.

	Each entry is independently potentiated with probability p_pot (conductance
	j_p, latent value 1) or depressed otherwise (conductance j_d, latent value 0).
	Both outputs are multiplied by u_mat, so unconnected entries stay 0.
	With p_pot=1 every connection gets j_p (get_connectivity_matrices uses this for
	the fixed, non-plastic blocks); a random draw is made either way.

	Args:
		u_mat: 0/1 connectivity matrix
		p_pot: probability for potentiated conductance
		j_p: conductance for a potentiated synapse
		j_d: conductance for a depressed synapse

	Returns:
		(j, l): conductance matrix and initial latent-variable matrix, both shaped like u_mat
	"""
	potentiated = torch.rand_like(u_mat) <= p_pot
	j = torch.full_like(u_mat, j_d)
	j[potentiated] = j_p
	l_init = potentiated.to(u_mat.dtype)
	return (j*u_mat,l_init*u_mat)

In [None]:
def get_connectivity_matrices():
	"""
	Build the full-network connectivity masks, conductances and plasticity state.

	All matrices are N x N with rows = post-synaptic and columns = pre-synaptic
	neurons, excitatory neurons first and inhibitory neurons after them.

	Returns:
		U_slow, U_fast: 0/1 masks of slow / fast synapses
		J_slow, J_fast: conductances of slow / fast synapses
		w_mat: 0/1 mask of the plastic E -> E synapses (slow or fast)
		l_mat: initial long-term plasticity latent variable (nonzero only on E -> E)
	"""
	def _assemble(ee, ie, ei, ii):
		# 2x2 block matrix: [[E<-E, E<-I], [I<-E, I<-I]]
		return torch.cat((torch.cat((ee, ie), 1), torch.cat((ei, ii), 1)))

	# Connectivity masks (random draws happen in this fixed order).
	u_ee_s, u_ee_f = slow_fast_pair(NUM_EXCITATORY, NUM_EXCITATORY, P_CONNECTION, RATIO_SLOW_EXCITATORY_EXCITATORY)
	u_ie_s, u_ie_f = slow_fast_pair(NUM_EXCITATORY, NUM_INHIBITORY, P_CONNECTION, RATIO_SLOW_EXCITATORY_INHIBITORY)
	u_ei_s, u_ei_f = slow_fast_pair(NUM_INHIBITORY, NUM_EXCITATORY, P_CONNECTION, RATIO_SLOW_INHIBITORY_EXCITATORY)
	u_ii_s, u_ii_f = slow_fast_pair(NUM_INHIBITORY, NUM_INHIBITORY, P_CONNECTION, RATIO_SLOW_INHIBITORY_INHIBITORY)

	U_fast = _assemble(u_ee_f, u_ie_f, u_ei_f, u_ii_f)
	U_slow = _assemble(u_ee_s, u_ie_s, u_ei_s, u_ii_s)

	# E -> E: bimodal plastic conductances plus their latent variable.
	j_ee_f, l_ee_f = initialize_j(u_ee_f, p_pot=P_POTENTIATED, j_p=J_POTENTIATED, j_d=J_DEPRESSED)
	j_ee_s, l_ee_s = initialize_j(u_ee_s, p_pot=P_POTENTIATED, j_p=J_POTENTIATED, j_d=J_DEPRESSED)

	# Other blocks: a single fixed conductance (p_pot=1). initialize_j still draws
	# random numbers here, so the call order is kept as is.
	j_ie_s, _ = initialize_j(u_ie_s, p_pot=1, j_p=J_I_TO_E, j_d=0)
	j_ie_f, _ = initialize_j(u_ie_f, p_pot=1, j_p=J_I_TO_E, j_d=0)
	j_ei_s, _ = initialize_j(u_ei_s, p_pot=1, j_p=J_E_TO_I, j_d=0)
	j_ei_f, _ = initialize_j(u_ei_f, p_pot=1, j_p=J_E_TO_I, j_d=0)
	j_ii_s, _ = initialize_j(u_ii_s, p_pot=1, j_p=J_I_TO_I, j_d=0)
	j_ii_f, _ = initialize_j(u_ii_f, p_pot=1, j_p=J_I_TO_I, j_d=0)

	J_fast = _assemble(j_ee_f, j_ie_f, j_ei_f, j_ii_f)
	J_slow = _assemble(j_ee_s, j_ie_s, j_ei_s, j_ii_s)

	# Only E -> E synapses are plastic; every other block of w / L is zero.
	zero_ie = torch.zeros_like(u_ie_f)
	zero_ei = torch.zeros_like(u_ei_f)
	zero_ii = torch.zeros_like(u_ii_f)
	w_mat = _assemble(u_ee_f + u_ee_s, zero_ie, zero_ei, zero_ii)
	l_mat = _assemble(l_ee_f + l_ee_s, zero_ie, zero_ei, zero_ii)

	return U_slow, U_fast, J_slow, J_fast, w_mat, l_mat

## Current dynamics

In [None]:
def sigma(v_t:torch.Tensor, v_th:torch.Tensor)->torch.Tensor:
	"""
	Spike indicator: 1 where the membrane potential has reached threshold, else 0.

	Args:
		v_t: membrane potentials at time t
		v_th: threshold potentials (may differ per neuron)

	Returns:
		tensor of 0./1. values in dtype DTYPE (1 = spiking)
	"""
	at_or_above = torch.ge(v_t, v_th)
	return at_or_above.to(DTYPE)

In [None]:
def dis_dt(i_s:torch.Tensor,
		   tau_s:Union[float, torch.Tensor],
		   dt:float,
		   j_slow:torch.Tensor,
		   x_sigma_v:torch.Tensor)->torch.Tensor:
	"""
	Forward-Euler increment for the slow synaptic currents.

	Integrates  tau_s di_s/dt = -i_s + J_slow (x * sigma(v)),  i.e. returns
		(dt / tau_s) * (-i_s + j_slow @ x_sigma_v).

	Args:
		i_s: slow currents at time t
		tau_s: time constant(s) of the slow currents (scalar or per-neuron vector)
		dt: time step for forward Euler
		j_slow: conductance matrix of the slow synapses (post x pre)
		x_sigma_v: presynaptic drive at time t (simulation_step passes x * sigma(v))

	Returns:
		increment of the slow currents between t and t+1 (to be added to i_s)
	"""
	return (dt/tau_s)*(-i_s + j_slow @ x_sigma_v)

## Short term plasticity

In [None]:
def dx_dt(x:torch.Tensor,
		  u:float,
		  tau_r:float,
		  x_sigma_v:torch.Tensor,
		  dt:float)->torch.Tensor:
	"""
	Forward-Euler increment for the short-term-plasticity synaptic resources.

	Resources recover towards 1 with time constant tau_r, and spiking neurons
	lose a fraction u of their available resources:
		dx = dt * (1 - x) / tau_r - u * x_sigma_v
	The consumption term is not multiplied by dt, so every step in which a
	neuron spikes removes u times its resources.

	Args:
		x: resources at time t
		u: fraction of resources consumed per spiking step (cost of spiking)
		tau_r: recovery time constant
		x_sigma_v: resources of the spiking neurons (simulation_step passes x * sigma(v))
		dt: time step for forward Euler

	Returns:
		increment of the synaptic resources between t and t+1
	"""
	return dt*((1-x)/tau_r)-u*x_sigma_v

## Long-term plasticity

In [None]:
def r_t(l:torch.Tensor,
		alpha:float,
		beta:float,
		theta_x:float)->torch.Tensor:
	"""
	Drift ("recovery") term of the long-term-plasticity latent variable.

	Below the threshold the latent variable drifts down at rate alpha; above it,
	it drifts up at rate beta:
		R(l) = -alpha * [l <= theta_x] + beta * [l >= theta_x]
	Both terms are active when l == theta_x exactly.
	NOTE(review): this term alone does not bound l, so above threshold l keeps
	growing at rate beta; confirm whether l should be clipped (e.g. to [0, 1]).

	Args:
		l: plasticity latent variable at time t (any shape)
		alpha: downward drift rate, applied where l <= theta_x
		beta: upward drift rate, applied where l >= theta_x
		theta_x: plasticity threshold

	Returns:
		tensor shaped like l holding the drift term at time t
	"""
	below = l <= theta_x
	above = l >= theta_x
	return -alpha*below + beta*above

In [None]:
def f_v(v:torch.Tensor, a:float, b:float, theta_ltp:float, theta_ltd:float,
		v_th:torch.Tensor, v_reset:torch.Tensor)->torch.Tensor:
	"""
	Hebbian step set by the (post-synaptic) membrane potential.

	Returns +a where theta_ltp <= v <= v_th (potentiation window), -b where
	v <= theta_ltd and v <= v_reset (depression window), and 0 elsewhere.
	NOTE(review): because both reset values in the constants cell are below
	theta_ltd, the condition v <= theta_ltd is redundant here; if the intended
	window is v_reset <= v <= theta_ltd, the reset comparison is flipped — confirm.

	Args:
		v: membrane potential at time t
		a: step up value for synchronous firing
		b: step down value for asynchronous firing
		theta_ltp: threshold potential for long-term potentiation
		theta_ltd: threshold potential for long-term depression
		v_th: threshold potential for action potential firing
		v_reset: reset potential

	Returns:
		tensor shaped like v with values in {a, -b, 0}
	"""
	potentiate = (v >= theta_ltp) & (v <= v_th)
	depress = (v <= theta_ltd) & (v <= v_reset)
	return a*potentiate - b*depress

In [None]:
def dl_dt(h_t:torch.Tensor, w:torch.Tensor, r_t:torch.Tensor, dt:float):
	"""
	Forward-Euler increment of the long-term-plasticity latent variable.

	The rate is the drift term plus the Hebbian term; entries outside the
	plastic-synapse mask w get a zero increment.

	Args:
		h_t: Hebbian term matrix
		w: masking matrix for excitatory-excitatory synapses
		r_t: drift (recovery) term matrix
		dt: time step for forward Euler

	Returns:
		latent-variable increment, shaped like the inputs
	"""
	total_rate = r_t + h_t
	masked_dt = dt * w
	return masked_dt * total_rate

In [None]:
def update_j(j:torch.Tensor, j_p:float, j_d:float, threshold:float,
			 l_old:torch.Tensor, l_new:torch.Tensor):
	"""
	Switch conductances whose plasticity latent variable crossed the threshold.

	An upward crossing (l_old <= threshold < l_new) sets the entry to j_p, a
	downward crossing (l_new < threshold <= l_old) sets it to j_d, and all other
	entries keep their value from j.

	Args:
		j: conductance matrix
		j_p: potentiated conductance
		j_d: depressed conductance
		threshold: latent variable threshold for long-term plasticity
		l_old: latent variable value at time t
		l_new: updated latent variable value

	Returns:
		updated conductance matrix
	"""
	crossed_up = (l_old <= threshold) & (l_new > threshold)
	crossed_down = (l_old >= threshold) & (l_new < threshold)
	return j.masked_fill(crossed_up, j_p).masked_fill(crossed_down, j_d)

## Membrane dynamics

In [None]:
def da_dt(v:torch.Tensor, tau_r:torch.Tensor, v_th:torch.Tensor, dt:float):
	"""
	Forward-Euler increment of the spiking hidden variable a.

	While v >= v_th, a grows at rate 1/tau_r; below threshold the increment is 0.

	Args:
		v: membrane potential at time t
		tau_r: membrane refractory time constant(s)
		v_th: membrane threshold potential(s)
		dt: time step for forward Euler

	Returns:
		increment of the spiking hidden variable
	"""
	above = v >= v_th
	rate = 1 / tau_r
	return rate * dt * above

def dv_dt(v:torch.Tensor,
		  i:torch.Tensor,
		  tau_m:torch.Tensor,
		  v_th:torch.Tensor,
		  dt:float):
	"""
	Forward-Euler increment of the membrane potential.

	Integrates dv/dt = -v / tau_m + i while v <= v_th. For neurons above
	threshold the increment is 0, so their potential is held until the reset
	applied in simulation_step.

	Args:
		v: membrane potential at time t
		i: total input current at time t
		tau_m: membrane time constant(s)
		v_th: membrane threshold potential(s)
		dt: time step for forward Euler

	Returns:
		increment of the membrane potential between t and t+1
	"""
	return (-v/tau_m+i)*dt*(v<=v_th)

## Putting everything together

In [None]:
U_slow,U_fast,J_slow,J_fast,W_mat,L_mat = get_connectivity_matrices()
NUM_NEURONS = NUM_EXCITATORY + NUM_INHIBITORY

# State histories: one row per neuron, one column per time step.
V_MEMBRANE_POTENTIALS = torch.zeros([NUM_NEURONS,T_IDX_MAX])
A_SPIKING_STATE_VAR = torch.zeros_like(V_MEMBRANE_POTENTIALS)
X_RESOURCE_STATE_VAR = torch.ones_like(V_MEMBRANE_POTENTIALS)
I_S_SLOW_CURRENTS = torch.zeros_like(V_MEMBRANE_POTENTIALS)
I_F_FAST_CURRENTS = torch.zeros_like(V_MEMBRANE_POTENTIALS)

# External drive per neuron: I_ext(t) = A * (1 + sin(t / F + phi)) >= 0.
# NOTE: FREQUENCIES divides time, so it acts as a time scale (period 2*pi*F in
# the units of TIMELINE) rather than as a frequency.
AMPLITUDES = 5*torch.rand([NUM_NEURONS])
AMPLITUDES = rearrange(AMPLITUDES,"n -> n 1")
PHASES = 10*torch.rand([NUM_NEURONS])
FREQUENCIES = 5*torch.rand([NUM_NEURONS])+5

# Build the phase in float32 on the same grid as TIMELINE (0 to T_MAX, T_IDX_MAX
# points): when DTYPE is float16, times above 512 are only representable to 0.5,
# which would make the drive piecewise constant. Broadcasting also avoids the
# intermediate (T, N) copy that repeat + transpose created.
t_grid = torch.linspace(0, T_MAX, T_IDX_MAX, dtype=torch.float32)
t_sin = t_grid[None, :] / FREQUENCIES.float()[:, None] + PHASES.float()[:, None]

I_EXT = (AMPLITUDES*torch.sin(t_sin)+AMPLITUDES).to(DTYPE)

In [None]:
# Summarize the plasticity mask instead of printing a dense N x N tensor:
# its shape and the fraction of entries that are plastic (E -> E) synapses.
W_mat.shape, W_mat.count_nonzero().item() / W_mat.numel()

In [None]:
# External drive of the first three neurons, on the simulation time axis.
fig, ax = plt.subplots()
for neuron in range(3):
	ax.plot(TIMELINE.cpu(), I_EXT[neuron].cpu(), label=f"neuron {neuron}")
ax.set(xlabel="time", ylabel="$I_e$", title="External input current (neurons 0-2)")
ax.legend()
plt.show()

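- Stored variables:
  - $V(t)$: membrane potential
  - $I_s(t)$: slow currents
  - $I_f(t)$: fast currents
  - $I_t(t)$: total currents (computed at every step, not currently stored)
  - $I_e(t)$: external currents (precomputed for all time steps)
  - $x(t)$: synaptic resources
  - $a(t)$: spiking hidden variable; it grows while $V \ge V_{th}$ and triggers the reset once it reaches 1
- Passed states (not saved):
  - $L$: square matrix of long-term plasticity latent variables
  - $R$: square matrix of drift (recovery) terms, fully determined by $L$
  - $H$: square Hebbian matrix, fully determined by $V(t)$ (outer product of $f(V)$ and $\sigma(V)$)
  - $J$: square conductance matrix; not stored because tracking it over time is expensive, although it would be useful for studying the learning dynamics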

In [None]:
def simulation_step(v:torch.Tensor,
					i_s:torch.Tensor,
					i_f:torch.Tensor,
					i_e:torch.Tensor,
					x:torch.Tensor,
					a:torch.Tensor,
					l:torch.Tensor,
					j_slow:torch.Tensor,
					j_fast:torch.Tensor,
					w:torch.Tensor,
					u_slow:torch.Tensor,
					u_fast:torch.Tensor):
	"""
	Advance the whole network by one forward-Euler step of size DT.

	Args:
		v: membrane potentials at time t
		i_s: slow currents at time t
		i_f: fast currents at time t
		i_e: external currents at time t
		x: synaptic resources at time t
		a: spiking hidden variable at time t
		l: long-term plasticity latent matrix at time t. It is updated IN PLACE to
			its value at t+1, so the caller's tensor carries the state forward.
		j_slow: conductance matrix of the slow synapses
		j_fast: conductance matrix of the fast synapses
		w: mask of the plastic (E -> E) synapses
		u_slow: connectivity mask of the slow synapses
		u_fast: connectivity mask of the fast synapses

	Returns:
		(v_new, i_s_new, i_f_new, x_new, a_new, j_slow_new, j_fast_new) at time t+1
	"""
	# --- synaptic currents ---
	sigma_v = sigma(v,THRESHOLD_POTENTIAL_VECTOR)
	# presynaptic drive: available resources of the neurons spiking at time t
	x_sigma_v = x*sigma_v

	i_s_new = i_s + dis_dt(i_s,TAU_S_VEC,DT,j_slow,x_sigma_v)
	# fast currents follow the presynaptic drive instantaneously
	i_f_new = j_fast @ x_sigma_v

	# explicit Euler: the membrane is driven by the currents at time t
	i_t = i_s + i_f + i_e

	# --- membrane potential and spiking hidden variable ---
	da = da_dt(v,MEMBRANE_REFRACTORY_VEC,THRESHOLD_POTENTIAL_VECTOR,DT)
	dv = dv_dt(v,i_t,MEMBRANE_CONST_VEC,THRESHOLD_POTENTIAL_VECTOR,DT)
	# once a reaches 1 the neuron is reset: v -> V_reset and a -> 0
	a_new = torch.zeros_like(a) + (a+da)*(a<1)
	v_new = V_RESET_VECT*(a>=1) + (v+dv)*(a<1)

	# --- short-term plasticity ---
	x_new = x + dx_dt(x,U_SPIKING_COST,TAU_RECOVERY,x_sigma_v,DT)

	# --- long-term plasticity ---
	r_mat = r_t(l=l,alpha=ALPHA_PLASTICITY_RECOVERY,beta=BETA_PLASTICITY_RECOVERY,
				theta_x=THETA_X_PLASTICITY_THRESHOLD)

	f_v_t = f_v(v,A_HEBBIAN,B_HEBBIAN,THETA_LTP,THETA_LTD,THRESHOLD_POTENTIAL_VECTOR,
				V_RESET_VECT)

	# H_ij = f(v_i) * sigma(v_j): rows follow f_v (post), columns follow sigma (pre)
	h_t = rearrange(f_v_t,"n -> n 1") @ rearrange(sigma_v,"n -> 1 n")

	l_new = l + dl_dt(h_t,w,r_mat,DT)

	# switch conductances whose latent variable crossed the threshold; masking by
	# u_fast / u_slow prevents creating new fast / slow connections
	j_fast_new = u_fast*update_j(j_fast,J_POTENTIATED,J_DEPRESSED,THETA_X_PLASTICITY_THRESHOLD,
								 l,l_new)

	j_slow_new = u_slow*update_j(j_slow,J_POTENTIATED,J_DEPRESSED,THETA_X_PLASTICITY_THRESHOLD,
								 l,l_new)

	# l_new is not part of the return tuple, so write it back into the caller's
	# tensor (after update_j has used the old value). Otherwise L would be reset
	# to its initial value every step and J could never switch.
	l.copy_(l_new)

	return v_new,i_s_new,i_f_new,x_new,a_new,j_slow_new,j_fast_new

In [None]:

for idx in tqdm(range(T_IDX_MAX-1)):
	# Network state at step idx (column views into the history buffers).
	v_t, i_s_t, i_f_t, i_e_t, x_t, a_t = (
		buf[:, idx] for buf in (V_MEMBRANE_POTENTIALS, I_S_SLOW_CURRENTS, I_F_FAST_CURRENTS,
								I_EXT, X_RESOURCE_STATE_VAR, A_SPIKING_STATE_VAR)
	)

	# One Euler step. J_slow / J_fast are rebound on every iteration, so
	# re-running this cell starts from the already-updated conductances.
	# NOTE(review): L_mat is passed unchanged on every iteration; it only evolves
	# if simulation_step writes its update back into it in place — confirm.
	(v_new, i_s_new, i_f_new, x_new, a_new, J_slow, J_fast) = simulation_step(
		v=v_t, i_s=i_s_t, i_f=i_f_t, i_e=i_e_t, x=x_t, a=a_t,
		l=L_mat, j_slow=J_slow, j_fast=J_fast, w=W_mat,
		u_slow=U_slow, u_fast=U_fast,
	)

	# Store the state at step idx + 1.
	V_MEMBRANE_POTENTIALS[:, idx + 1] = v_new
	A_SPIKING_STATE_VAR[:, idx + 1] = a_new
	X_RESOURCE_STATE_VAR[:, idx + 1] = x_new
	I_S_SLOW_CURRENTS[:, idx + 1] = i_s_new
	I_F_FAST_CURRENTS[:, idx + 1] = i_f_new

In [None]:
# membrane-potential trace of neuron 0 over all time steps
V_MEMBRANE_POTENTIALS[0, :]

In [None]:
fig, ax = plt.subplots()
ax.plot(TIMELINE.cpu(), V_MEMBRANE_POTENTIALS[0].cpu())
ax.set(xlim=(0, 300), xlabel="time", ylabel="$V$", title="Membrane potential of neuron 0")
plt.show()

In [None]:
fig, ax = plt.subplots()
ax.plot(TIMELINE.cpu(), I_F_FAST_CURRENTS[1].cpu())
ax.set(xlim=(0, 300), xlabel="time", ylabel="$I_f$", title="Fast synaptic current into neuron 1")
plt.show()

- The fast inhibitory currents are clearly visible.

In [None]:
# Plot against TIMELINE (not the sample index) so that xlim(0, 300) covers the
# same time window as the neighboring plots; against the index it would show
# only the first 300 steps (300 * DT time units).
fig, ax = plt.subplots()
ax.plot(TIMELINE.cpu(), I_S_SLOW_CURRENTS[1].cpu())
ax.set(xlim=(0, 300), xlabel="time", ylabel="$I_s$", title="Slow synaptic current into neuron 1")
plt.show()