# Test Attractor Dynamics with LIF Neurons

In [1]:
# Imports and funcs
%matplotlib widget
from brian2 import *
import numpy as np
import scipy.sparse as sp
import excitation_schedule as es
import brian_weight_submatrix as bws
import diagonal_sums as ds
import time

In [2]:
start_scope()
# The following line suppresses a warning about order of executions in the abstract code: "v_post = clip(v_post + w, v_gaba, 0)"
# As far as I can tell, the warning is due to the inability of the OOE checker to deal with the "clip()" function, and is ok in this case
BrianLogger.suppress_hierarchy('brian2.codegen')

seed(seed=1)

################## INDEPENDENT PARAMETERS ################
# Node parameters
t_mem = 10*ms
t_adapt_e = 100*ms
t_adapt_i = 100*ms
t_refract_e = 6.3*ms       # 6.3 ± 1.7 Raastad 2003
t_refract_std_e = 1.7*ms
t_refract_i = 5*ms # jittered with std=1 below
v_rest = -70*mV
v_adapt_step_e = 6*mV # 4
v_adapt_step_i = 5*mV # 4
v_thresh = -50*mV
v_reset = -65*mV
v_gaba = -80*mV

# Network parameters
net_size = 1.0*cm # circumference of toroidal network - brain width: mouse=1cm, human=15cm
nxe = 100 #100
nye = 100 #100
nxi = 40 #40
nyi = 40 #40
p0ee = 0.8
lradee = 30 # in units of neuron spacing
p0ei = 0.4
lradei = 2 # in units of neuron spacing
p0ie = 0.3
lradie = 5 # [5] in units of neuron spacing
wbase_ee = 0.0
wrange_ee = 0.5
wbase_ei = 3.0
wrange_ei = 6.0
wbase_ie = -4.0
wrange_ie = -5.0
delay_ee = True
dendritic_delay_min = 2.0 * ms # 3ms from Jarsky NatNeuro 2005
dendritic_delay_range = 2.0 * ms
conduction_velocity = 5.0 * meter / second

# STDP parameters
t_stdppre = 20.0 * ms
t_stdppost = 20.0 * ms
dApre = 0.16 * mV # 0.2
STDP_neg_pos_ratio = 1.05 # 1.05 Greater than 1 favors decay of weights
wmax = 5.0 * mV # Need about 20mV to trigger a spike
STDP_delayed = True

# Stimulus parameters
attractor_size = 500
attractor_type = 'random' # 'random' or 'circular'
stim_freq = 8.0 # Hz
stim_duty = 0.25 # 0.25
stim_time = 1100 * ms # 1100
stim_dt = 1 * ms
stim_rate = 200 * Hz # 200
stim_ramp_on = 0.1
stim_ramp_off = 0.1
restim_delay = 0 * ms # 200
restim_time = 10 * ms # 10
restim_runtime = 140 * ms # 200
restim_fraction = 0.25 # 0.2
restim_density = 1.0 # 1.0

attractor_N = 3
attractor_cycles = 1
disjoint_attractors = False

# Timing parameters
defaultclock.dt = 0.01*ms

################## CALCULATED PARAMETERS ################
# Network geometry
Ne = nxe * nye
Ni = nxi * nyi
dxe = net_size / nxe
dye = net_size / nye
dxi = net_size / nxi
dyi = net_size / nyi

# STDP
dApost = -dApre * t_stdppre / t_stdppost * STDP_neg_pos_ratio

########################## ATTRACTOR ########################
# Set attractor nodes
attractor_nodes = []
attractor_index = []
if disjoint_attractors:
    all_anodes = np.random.choice(Ne,(attractor_N,attractor_size),replace=False)
for ia in range(attractor_N):
    if disjoint_attractors:
        anodes = all_anodes[ia,:]
    else:
        anodes = np.random.choice(Ne,attractor_size,replace=False)
    if attractor_type == 'circular':
        anodes = np.sort(anodes)
    attractor_nodes.append(anodes)
    attractor_index.append({e:i for i,e in enumerate(anodes)}) # lookup dictionary for getting attractor index from node number
all_attractor_nodes = np.concatenate(attractor_nodes)

########################## STIMULUS ########################
# Set stimulus
nstim = int((stim_time+restim_delay)/stim_dt)
ex = np.zeros((attractor_N*nstim,attractor_N*attractor_size))
nex1 = int(stim_time/stim_dt)
for iex in range(attractor_N):
    ex1 = es.excitation_schedule(attractor_size,'circular',stim_dt/ms,stim_time/ms,stim_duty,stim_freq,ramp_on=stim_ramp_on,ramp_off=stim_ramp_off)
    ex[iex*nstim:iex*nstim+nex1,iex*attractor_size:(iex+1)*attractor_size] = ex1
stimulus_teach = TimedArray(ex*stim_rate,dt=stim_dt)

# Set overlapping restim to teach connections
nrestim = int(restim_runtime/stim_dt)
restim = np.zeros((attractor_cycles*attractor_N*nrestim,attractor_N*attractor_size))
for iex in range(attractor_cycles*attractor_N):
    restim_nodes = np.random.choice(int(attractor_size * restim_fraction),int(attractor_size * restim_fraction * restim_density))
    restim[iex*nrestim:iex*nrestim+int(restim_time/stim_dt),(iex%attractor_N)*attractor_size+restim_nodes] = 1.
stimulus_connect = TimedArray(restim*stim_rate,dt=stim_dt)

# Test by stimulating just one attractor
atest = 0
testim = np.zeros((attractor_N*nrestim,attractor_N*attractor_size))
testim_nodes = np.random.choice(int(attractor_size * restim_fraction),int(attractor_size * restim_fraction * restim_density))
testim[:int(restim_time/stim_dt),atest*attractor_size+testim_nodes] = 1.
stimulus_test = TimedArray(testim*stim_rate,dt=stim_dt)

########################## MODEL ########################
# E nodes decay towards v_rest-v_adapt, and v_adapt decays towards 0
eqs_e = '''
dv/dt = (v_rest-v-v_adapt)/t_mem : volt (unless refractory)
dv_adapt/dt = (-v_adapt)/t_adapt_e : volt
x : meter
y : meter
net_size : meter
refract : second
'''
reset_e = '''
v = v_reset
v_adapt += v_adapt_step_e
'''

# I nodes decay towards v_rest-v_adapt, and v_adapt decays towards 0
eqs_i = '''
dv/dt = (v_rest-v-v_adapt)/t_mem : volt (unless refractory)
dv_adapt/dt = (-v_adapt)/t_adapt_i : volt
x : meter
y : meter
net_size : meter
refract : second
'''
reset_i = '''
v = v_reset
v_adapt += v_adapt_step_i
'''

# STDP Synapse equations
STDP_eqs = '''
w : volt
dApre/dt = -Apre / t_stdppre : volt (event-driven)
dApost/dt = -Apost / t_stdppost : volt (event-driven)
'''
STDP_onpre = '''
v_post += w
Apre += dApre
'''
STDP_onpost = '''
Apost += dApost
'''
if STDP_delayed:        # Accumulate prospective changes in dw to apply later
    STDP_eqs = STDP_eqs + 'dw : volt\n'
    STDP_onpre = STDP_onpre + 'dw += Apost\n'
    STDP_onpost = STDP_onpost + 'dw += Apre\n'
else:                   # Apply weight changes immediately
    STDP_onpre = STDP_onpre + 'w = clip(w + Apost, 0, wmax)\n'
    STDP_onpost = STDP_onpost + 'w = clip(w + Apre, 0, wmax)\n'
    
# Lorentz connection probability (mod (%) stuff makes it toroidal)
pLorentz = 'p0 / (1 + (((x_pre-x_post + 1.5*net_size_post) % net_size_post - 0.5*net_size_post)**2 \
                     + ((y_pre-y_post + 1.5*net_size_post) % net_size_post - 0.5*net_size_post)**2) / (lrad)**2)'

# Calculated conduction delay (mod (%) stuff makes it toroidal)
conduction_delay = 'dendritic_delay_min + rand() * dendritic_delay_range \
                        + sqrt( (((x_pre-x_post + 1.5*net_size_post) % net_size_post - 0.5*net_size_post)**2 \
                        + ((y_pre-y_post + 1.5*net_size_post) % net_size_post - 0.5*net_size_post)**2) ) / conduction_velocity'

###################### CONSTRUCT NETWORK ####################
# Generate nodes
PGteach = PoissonGroup(attractor_N*attractor_size, rates='stimulus_teach(t,i)',name='PGteach')
PGconnect  = PoissonGroup(attractor_N*attractor_size, rates='stimulus_connect(t,i)',name='PGconnect')
PGtest  = PoissonGroup(attractor_N*attractor_size, rates='stimulus_test(t,i)',name='PGtest')
E = NeuronGroup(nxe * nye, eqs_e, threshold='v>v_thresh', reset=reset_e, refractory='refract', method='euler', name='E')
I = NeuronGroup(nxi * nyi, eqs_i, threshold='v>v_thresh', reset=reset_i, refractory='refract', method='euler', name='I')

# set neuron locations (zero indexed, so net is in quadrant 1)
E.x = '(i % nxe) * dxe'
E.y = '(i // nxe) * dye'
I.x = '(i % nxi) * dxi'
I.y = '(i // nxi) * dyi'
E.net_size = net_size
I.net_size = net_size

# Set refractory periods with jitter
E.refract = t_refract_e + t_refract_std_e * randn(len(E))
I.refract = t_refract_i * (1.0 + 0.2 * randn(len(I)))

# Connect Poisson Groups to E
SPGteach = Synapses(PGteach, E, on_pre='v_post += 50*mV') # 50 guarantees firing of post
SPGteach.connect(i=np.arange(attractor_N*attractor_size),j=all_attractor_nodes)
SPGconnect  = Synapses(PGconnect, E, on_pre='v_post += 50*mV') # 50 guarantees firing of post
SPGconnect.connect(i=np.arange(attractor_N*attractor_size),j=all_attractor_nodes)
SPGtest  = Synapses(PGtest, E, on_pre='v_post += 50*mV') # 50 guarantees firing of post
SPGtest.connect(i=np.arange(attractor_N*attractor_size),j=all_attractor_nodes)

# Connect EE, EI, IE
# No self connections in EE network
SEE = Synapses(E, E, STDP_eqs, on_pre=STDP_onpre, on_post=STDP_onpost, name='EE')
SEE.variables.add_constant('p0',p0ee)
SEE.variables.add_constant('lrad',lradee * net_size / nxe) # Convert from neuron spacing units to meters here
SEE.connect(condition = 'i != j', p=pLorentz)
SEE.w = '(wbase_ee + wrange_ee*rand())*mV'
if delay_ee == True:
    SEE.delay = conduction_delay

SEI = Synapses(E, I, 'w : volt', on_pre='v_post += w', name='EI')
SEI.variables.add_constant('p0',p0ei)
SEI.variables.add_constant('lrad',lradei * net_size / nxi) # Convert from neuron spacing units to meters here
SEI.connect(p=pLorentz)
SEI.w = '(wbase_ei + wrange_ei*rand())*mV'

SIE = Synapses(I, E, 'w : volt', on_pre='v_post = clip((v_post + w), v_gaba, 0)', name='IE')
SIE.variables.add_constant('p0',p0ie)
SIE.variables.add_constant('lrad',lradie * net_size / nxe) # Convert from neuron spacing units to meters here
SIE.connect(p=pLorentz)
SIE.w = '(wbase_ie + wrange_ie*rand())*mV'

########################## GO ########################
# Monitor stuff
MSE = SpikeMonitor(E,name='Espikemon')
# MSI = SpikeMonitor(I,name='Ispikemon')
# MVE = StateMonitor(E,'v',True,dt=max(defaultclock.dt,1.*ms),name='Estatemon')
# MVI = StateMonitor(I,'v',True,dt=max(defaultclock.dt,1.*ms),name='Istatemon')

# Initialize nodes at v_rest
E.v = v_rest
I.v = v_rest
store('state')

# Initialize records for each teach/test period
time_record = []
spike_record = []

# Train on each attractor
PGconnect.active = False # only teacher for now
PGtest.active = False # only teacher for now
run(attractor_N * (stim_time + restim_delay) )
time_record.append(MSE.t/ms)
spike_record.append(1*MSE.i)

# Update weights and sequentially stimulate attractors
dw0 = 1.0 * SEE.dw # Multiply by 1.0 to force copy
restore('state')
SEE.w = clip(SEE.w + dw0,0,wmax)
store('state')
SEE.dw = 0
PGteach.active = False
PGconnect.active = True
w1 = bws.brian_weight_submatrix(SEE,all_attractor_nodes,all_attractor_nodes)
run(attractor_cycles * attractor_N * restim_runtime)
time_record.append(MSE.t/ms)
spike_record.append(1*MSE.i)

# Update weights and stimulate just one attractor
dw0 = 1.0 * SEE.dw # Multiply by 1.0 to force copy
restore('state')
SEE.w = clip(SEE.w + 0.25*dw0,0,wmax)
PGconnect.active = False
PGtest.active = True
w2 = bws.brian_weight_submatrix(SEE,all_attractor_nodes,all_attractor_nodes)
SEE.dw = 0
run(attractor_N * restim_runtime)
time_record.append(MSE.t/ms)
spike_record.append(1*MSE.i)


In [3]:
figure(figsize=(12,4))

period = (0,1,2)
for idx,per in enumerate(period):
    for ia in range(attractor_N):
        spikes_in_attractor = np.isin(spike_record[per],attractor_nodes[ia])
        z = spike_record[per][spikes_in_attractor]
        spikes = np.ndarray(z.shape)
        for node in attractor_index[ia]:
            spikes[z==node] = attractor_index[ia][node]
        time = time_record[per][spikes_in_attractor]
        subplot(1,3,idx+1)
        plot(time, spikes+ia*attractor_size, 'k.',markersize=2)
        xlabel('Time (ms)')
        ylabel('Neuron Index')

tight_layout()


Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

  subplot(1,3,idx+1)
  subplot(1,3,idx+1)
  subplot(1,3,idx+1)
  subplot(1,3,idx+1)
  subplot(1,3,idx+1)
  subplot(1,3,idx+1)


In [6]:
wN = np.zeros((attractor_N*attractor_size,attractor_N*attractor_size))
diags = []
for iin in range(attractor_N):
    for iout in range(attractor_N):
        submat = bws.brian_weight_submatrix(SEE,attractor_nodes[iin],attractor_nodes[iout])
        wN[iout*attractor_size:(iout+1)*attractor_size,iin*attractor_size:(iin+1)*attractor_size] = submat
        diags.append(ds.diagonal_sums(submat))
figure(figsize=(12,6))
subplot(121)
imshow(wN)
subplot(122)
plot(np.asarray(diags).T)
tight_layout()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [5]:
time_record

[array([2.2000e+00, 3.0000e+00, 3.0000e+00, ..., 3.2942e+03, 3.2956e+03,
        3.2978e+03]),
 array([4.000e-01, 4.000e-01, 6.000e-01, ..., 4.198e+02, 4.198e+02,
        4.198e+02]),
 array([2.000e-01, 2.000e-01, 2.000e-01, ..., 4.198e+02, 4.198e+02,
        4.198e+02])]

In [28]:
attractor_indices_sorted = np.argsort(attractor_nodes)
attractor_sorted = attractor_nodes[attractor_indices_sorted]
spikes_in_attractor = np.isin(MSE.i,attractor_nodes)
idx_in_sort = np.searchsorted(attractor_sorted,MSE.i[spikes_in_attractor])
time = MSE.t[spikes_in_attractor]/ms
spikes = attractor_indices_sorted[idx_in_sort]

figure(figsize=(12,12))
subplot(331)
plot(time, spikes, 'k.',markersize=2)

subplot(332)
wts = SEE.w/mV
hist(wts[wts>1], 20)

subplot(333)
imshow(bws.brian_weight_submatrix(SEE,attractor_nodes,attractor_nodes))

nbin = 800
real,timeline = np.histogram(time,bins=nbin,weights=np.cos(2.0*np.pi*spikes/500))
imag,timeline = np.histogram(time,bins=nbin,weights=np.sin(2.0*np.pi*spikes/500))
cplx = real + 1.j * imag

freq = 1000. * (angle(cplx[1:] / cplx[:-1])) / (2. * np.pi * (timeline[1]-timeline[0]))
from scipy.signal import savgol_filter
fhat = savgol_filter(freq, 21, 2) # window size 51, polynomial order 3

subplot(336)
plot(timeline[1:-1],freq)
plot(timeline[1:-1],fhat,'r')

subplot(339)
plot(real,imag)
scatter(real,imag,c=range(nbin),cmap='jet')

subplot(334)
imshow(MVE.v / mV,aspect='auto')
subplot(335)
plot(MSE.t/ms, MSE.i, 'ko')
plot(MSE.t[spikes_in_attractor]/ms, MSE.i[spikes_in_attractor], 'r+')
xlabel('Time (ms)')
ylabel('Neuron index')
subplot(337)
imshow(MVI.v / mV,aspect='auto')
subplot(338)
plot(MSI.t/ms, MSI.i, '.k')
xlabel('Time (ms)')
ylabel('Neuron index')
tight_layout()

TypeError: only integer scalar arrays can be converted to a scalar index

In [4]:
tmin = 50
tmax = tmin+100
sseg = spikes[time>tmin]
tseg = time[time>tmin]
sseg = sseg[tseg<tmax]
tseg = tseg[tseg<tmax]

segfreq = np.full(attractor_size,np.nan)
segphase = np.full(attractor_size,np.nan)
tstart = np.amin(tseg)
for inode in range(attractor_size):
    times = tseg[np.argwhere(sseg==inode)]
    if times.size > 1:
        segfreq[inode] = 1000 * (times.size-1) / (np.amax(times)-np.amin(times))
        segphase[inode] = (2.0e-3 * np.pi * (np.amin(times)-tstart) * segfreq[inode]) % (2*np.pi)

figure(figsize=(12,3))
subplot(141)
plot(tseg, sseg, 'k.',markersize=2)
subplot(142)
#plot(segfreq)
hist(segfreq[~np.isnan(segfreq)],20)
subplot(143)
#plot(segphase)
scatter(attractor_nodes%nye,attractor_nodes//nye,c=segphase,cmap='gist_rainbow')
colorbar()

nbin = 80
real,timeline = np.histogram(tseg,bins=nbin,weights=np.cos(2.0*np.pi*sseg/500))
imag,timeline = np.histogram(tseg,bins=nbin,weights=np.sin(2.0*np.pi*sseg/500))
cplx = real + 1.j * imag

freq = 1000. * (angle(cplx[1:] / cplx[:-1])) / (2. * np.pi * (timeline[1]-timeline[0]))
from scipy.signal import savgol_filter
fhat = savgol_filter(freq, 11, 2) # window size 51, polynomial order 3

subplot(144)
plot(timeline[1:-1],freq)
plot(timeline[1:-1],fhat,'r')

tight_layout()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [6]:
figure(figsize=(12,4))
subplot(131)
imshow(w1)
colorbar()
subplot(132)
imshow(w2)
colorbar()
subplot(133)
plot(ds.diagonal_sums(w1))
plot(ds.diagonal_sums(w2),'r')
tight_layout()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [60]:
scheduling_summary()

object,part of,Clock dt,when,order,active
Estatemon (StateMonitor),Estatemon (StateMonitor),1. ms (every 5 steps),start,0,yes
Istatemon (StateMonitor),Istatemon (StateMonitor),1. ms (every 5 steps),start,0,yes
E_stateupdater (StateUpdater),E (NeuronGroup),200. us (every step),groups,0,yes
I_stateupdater (StateUpdater),I (NeuronGroup),200. us (every step),groups,0,yes
E_thresholder (Thresholder),E (NeuronGroup),200. us (every step),thresholds,0,yes
I_thresholder (Thresholder),I (NeuronGroup),200. us (every step),thresholds,0,yes
PGteach_thresholder (Thresholder),PGteach (PoissonGroup),200. us (every step),thresholds,0,no
PGtest_thresholder (Thresholder),PGtest (PoissonGroup),200. us (every step),thresholds,0,yes
Espikemon (SpikeMonitor),Espikemon (SpikeMonitor),200. us (every step),thresholds,1,yes
Ispikemon (SpikeMonitor),Ispikemon (SpikeMonitor),200. us (every step),thresholds,1,yes


In [12]:
print(spike_record[0][:10],spike_record[1][:10],spike_record[0].shape,spike_record[1].shape)

[3226 5172 5591 8404 8501 7016 9897 8302 9239  427] [2181 2331 3422 5214 5363 6134 2589 6564 6863 7385] (9403,) (1816,)


In [9]:
figure(figsize=(12,4))
subplot(131)
idx = 0
targets = SEE.j[idx, :] 
plot(E.x[targets] / cm, E.y[targets] / cm, 'r.')

subplot(132)
targets = SEI.j[idx, :] 
plot(I.x[targets] / cm, I.y[targets] / cm, 'r.')

subplot(133)
targets = SIE.j[idx, :] 
plot(E.x[targets] / cm, E.y[targets] / cm, 'r.')

subplot(131)
idx = 5150
targets = SEE.j[idx, :] 
plot(E.x[targets] / cm, E.y[targets] / cm, '.')

subplot(132)
targets = SEI.j[idx, :] 
plot(I.x[targets] / cm, I.y[targets] / cm, '.')

subplot(133)
idx = 820
targets = SIE.j[idx, :] 
plot(E.x[targets] / cm, E.y[targets] / cm, '.')

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

  subplot(131)
  subplot(132)
  subplot(133)


[<matplotlib.lines.Line2D at 0x7f882002a0a0>]

In [4]:
print(np.mean(w1),np.mean(w2),w1.size,np.count_nonzero(w1),np.mean(w1)*w1.size/np.count_nonzero(w1))


0.23094016687739957 0.2412829279812967 250000 38810 1.4876331285583586


In [4]:
import pickle
time = MSE.t[spikes_in_attractor]/ms
spikes = attractor_indices_sorted[idx_in_sort]
mylist = [time,spikes]
with open('example_spikes2.pkl', 'wb') as f:
    pickle.dump(mylist, f)

In [10]:
print(defaultclock.t)

<defaultclock.t: 2.8 * second>


In [35]:
print(not STDP_immediate)

False


In [7]:
with open('params.csv', 'w') as f:
    f.write(s)

In [11]:
dir(SEE.delay)

['__add__',
 '__array__',
 '__array_prepare__',
 '__array_wrap__',
 '__class__',
 '__delattr__',
 '__dict__',
 '__dir__',
 '__div__',
 '__doc__',
 '__eq__',
 '__floordiv__',
 '__format__',
 '__ge__',
 '__getattribute__',
 '__getitem__',
 '__gt__',
 '__hash__',
 '__iadd__',
 '__idiv__',
 '__imul__',
 '__init__',
 '__init_subclass__',
 '__isub__',
 '__le__',
 '__len__',
 '__lt__',
 '__module__',
 '__mul__',
 '__ne__',
 '__neg__',
 '__new__',
 '__pos__',
 '__radd__',
 '__rdiv__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__rfloordiv__',
 '__rmul__',
 '__rsub__',
 '__rtruediv__',
 '__setattr__',
 '__setitem__',
 '__sizeof__',
 '__str__',
 '__sub__',
 '__subclasshook__',
 '__truediv__',
 '__weakref__',
 'dim',
 'dtype',
 'get_item',
 'get_subexpression_with_index_array',
 'get_with_expression',
 'get_with_index_array',
 'group',
 'group_name',
 'index_var',
 'index_var_name',
 'indexing',
 'name',
 'set_item',
 'set_with_expression',
 'set_with_expression_conditional',
 'set_with_inde

In [17]:
print(SEE.delay[323],SEE.j.shape,SEE.w.shape)

1.27121989 ms (19181504,) (19181504,)


In [17]:
print(segfreq[~np.isnan(segfreq)].T.shape)

(406,)


In [8]:
print(segfreq.shape)

(500,)
