# Test Attractor Dynamics with LIF Neurons

In [1]:
# Imports and funcs
%matplotlib widget
from brian2 import *
import numpy as np
import scipy.sparse as sp
import excitation_schedule as es
import brian_weight_submatrix as bws
import diagonal_sums as ds
import time
import datetime
import pickle

In [2]:
start_scope()
# The following line suppresses a warning about order of executions in the abstract code: "v_post = clip(v_post + w, v_gaba, 0)"
# As far as I can tell, the warning is due to the inability of the OOE checker to deal with the "clip()" function, and is ok in this case
BrianLogger.suppress_hierarchy('brian2.codegen')

################## INDEPENDENT PARAMETERS ################
# Node parameters
t_mem = 10*ms
t_adapt_e = 100*ms
t_adapt_i = 100*ms
t_refract_e = 6.3*ms       # 6.3 ± 1.7 Raastad 2003
t_refract_std_e = 1.7*ms
t_refract_i = 5*ms # jittered with std=1 below
v_rest = -70*mV
v_adapt_step_e = 6*mV # 4
v_adapt_step_i = 5*mV # 4
v_thresh = -50*mV
v_reset = -65*mV
v_gaba = -80*mV

# Network parameters
net_size = 15.0*cm # circumference of toroidal network - brain width: mouse=1cm, human=15cm
nxe = 100 #100
nye = 100 #100
nxi = 40 #40
nyi = 40 #40
p0ee = 0.8
lradee = 30 # in units of neuron spacing
p0ei = 0.4
lradei = 2 # in units of neuron spacing
p0ie = 0.3
lradie = 5 # [5] in units of neuron spacing
wbase_ee = 0.0
wrange_ee = 0.5
wbase_ei = 3.0
wrange_ei = 6.0
wbase_ie = -4.0
wrange_ie = -5.0
delay_ee = True
dendritic_delay_min = 2.0 * ms # 3ms from Jarsky NatNeuro 2005
dendritic_delay_range = 2.0 * ms
conduction_velocity = 5.0 * meter / second

# STDP parameters
t_stdppre = 20.0 * ms
t_stdppost = 20.0 * ms
dApre = 0.10 * mV # 0.3
STDP_neg_pos_ratio = 1.05 # 1.05 Greater than 1 favors decay of weights
wmax = 5.0 * mV # Need about 20mV to trigger a spike

# Stimulus parameters
attractor_size = 500
attractor_type = 'random' # 'random' or 'circular'
stim_freq = 8.0 # Hz
stim_duty = 0.25 # 0.25
stim_time = 1100 * ms # 1100
stim_dt = 1 * ms
stim_rate = 200 * Hz # 200
stim_ramp_on = 0.1
stim_ramp_off = 0.1
restim_delay = 0 * ms # 200
restim_time = 10 * ms # 10
restim_runtime = 140 * ms # 200
restim_fraction = 0.25 # 0.2
restim_density = 1.0 # 1.0

# Timing parameters
defaultclock.dt = 0.01*ms

################## CALCULATED PARAMETERS ################
# Network geometry
Ne = nxe * nye
Ni = nxi * nyi
dxe = net_size / nxe
dye = net_size / nye
dxi = net_size / nxi
dyi = net_size / nyi

# STDP
dApost = -dApre * t_stdppre / t_stdppost * STDP_neg_pos_ratio

########################## MODEL ########################
# E nodes decay towards v_rest-v_adapt, and v_adapt decays towards 0
eqs_e = '''
dv/dt = (v_rest-v-v_adapt)/t_mem : volt (unless refractory)
dv_adapt/dt = (-v_adapt)/t_adapt_e : volt
x : meter
y : meter
net_size : meter
refract : second
'''
reset_e = '''
v = v_reset
v_adapt += v_adapt_step_e
'''

# I nodes decay towards v_rest-v_adapt, and v_adapt decays towards 0
eqs_i = '''
dv/dt = (v_rest-v-v_adapt)/t_mem : volt (unless refractory)
dv_adapt/dt = (-v_adapt)/t_adapt_i : volt
x : meter
y : meter
net_size : meter
refract : second
'''
reset_i = '''
v = v_reset
v_adapt += v_adapt_step_i
'''

# STDP Synapse equations
# Candidate changes accumulate in dw for implementation later
STDP_eqs = '''
w : volt
dApre/dt = -Apre / t_stdppre : volt (event-driven)
dApost/dt = -Apost / t_stdppost : volt (event-driven)
dw : volt
'''
STDP_onpre = '''
v_post += w
Apre += dApre
dw += Apost
'''
STDP_onpost = '''
Apost += dApost
dw += Apre
'''

# Lorentz connection probability (mod (%) stuff makes it toroidal)
pLorentz = 'p0 / (1 + (((x_pre-x_post + 1.5*net_size_post) % net_size_post - 0.5*net_size_post)**2 \
                     + ((y_pre-y_post + 1.5*net_size_post) % net_size_post - 0.5*net_size_post)**2) / (lrad)**2)'

# Calculated conduction delay (mod (%) stuff makes it toroidal)
conduction_delay = 'dendritic_delay_min + rand() * dendritic_delay_range \
                        + sqrt( (((x_pre-x_post + 1.5*net_size_post) % net_size_post - 0.5*net_size_post)**2 \
                        + ((y_pre-y_post + 1.5*net_size_post) % net_size_post - 0.5*net_size_post)**2) ) / conduction_velocity'

################ LOOP OVER NETWORK PARAMETERS ##################
net_sizes = [1.0*cm, 15.0*cm]
time_threshold = 120.0 # ms 120 gives 100ms of oscillation (20-120) to measure frequency
gainres = 0.01 # 0.01 stopping criteria for threshold gain estimation
nreps = 5
for irep in range(nreps):
    seed(seed=irep)

    # Set attractor nodes
    attractor_nodes = np.random.choice(Ne,attractor_size,replace=False)
    if attractor_type == 'circular':
        attractor_nodes = np.sort(attractor_nodes)

    trainer_spike_record = []
    gain_record = []
    time_record = []
    spike_record = []
    weight_record = []
    for isize in range(len(net_sizes)):
        net_size = net_sizes[isize]

        ###################### CONSTRUCT NETWORK ####################
        # Generate nodes
        PGteach = PoissonGroup(attractor_size, rates='stimulus_teach(t,i)',name='PGteach')
        PGtest  = PoissonGroup(attractor_size, rates='stimulus_test(t,i)',name='PGtest')
        E = NeuronGroup(nxe * nye, eqs_e, threshold='v>v_thresh', reset=reset_e, refractory='refract', method='euler', name='E')
        I = NeuronGroup(nxi * nyi, eqs_i, threshold='v>v_thresh', reset=reset_i, refractory='refract', method='euler', name='I')

        # set neuron locations (zero indexed, so net is in quadrant 1)
        E.x = '(i % nxe) * dxe'
        E.y = '(i // nxe) * dye'
        I.x = '(i % nxi) * dxi'
        I.y = '(i // nxi) * dyi'
        E.net_size = net_size
        I.net_size = net_size

        # Set refractory periods with jitter
        E.refract = t_refract_e + t_refract_std_e * randn(len(E))
        I.refract = t_refract_i * (1.0 + 0.2 * randn(len(I)))

        # Connect Poisson Groups to E
        SPGteach = Synapses(PGteach, E, on_pre='v_post += 50*mV') # 50 guarantees firing of post
        SPGteach.connect(i=np.arange(attractor_size),j=attractor_nodes)
        SPGtest  = Synapses(PGtest, E, on_pre='v_post += 50*mV') # 50 guarantees firing of post
        SPGtest.connect(i=np.arange(attractor_size),j=attractor_nodes)

        # Connect EE, EI, IE
        # No self connections in EE network
        SEE = Synapses(E, E, STDP_eqs, on_pre=STDP_onpre, on_post=STDP_onpost, name='EE')
        SEE.variables.add_constant('p0',p0ee)
        SEE.variables.add_constant('lrad',lradee * net_size / nxe) # Convert from neuron spacing units to meters here
        SEE.connect(condition = 'i != j', p=pLorentz)
        SEE.w = '(wbase_ee + wrange_ee*rand())*mV'
        if delay_ee == True:
            SEE.delay = conduction_delay

        SEI = Synapses(E, I, 'w : volt', on_pre='v_post += w', name='EI')
        SEI.variables.add_constant('p0',p0ei)
        SEI.variables.add_constant('lrad',lradei * net_size / nxi) # Convert from neuron spacing units to meters here
        SEI.connect(p=pLorentz)
        SEI.w = '(wbase_ei + wrange_ei*rand())*mV'

        SIE = Synapses(I, E, 'w : volt', on_pre='v_post = clip((v_post + w), v_gaba, 0)', name='IE')
        SIE.variables.add_constant('p0',p0ie)
        SIE.variables.add_constant('lrad',lradie * net_size / nxe) # Convert from neuron spacing units to meters here
        SIE.connect(p=pLorentz)
        SIE.w = '(wbase_ie + wrange_ie*rand())*mV'

        # Initialize nodes at v_rest and store initial state
        E.v = v_rest
        I.v = v_rest

        # Monitor just E spikes
        MSE = SpikeMonitor(E,name='Espikemon')

        now = datetime.datetime.now()
        print('    ',now.strftime("%H:%M:%S"),'New network with size=',net_size/cm)
        store('initial_state')

        ################ LOOP OVER STIMULUS PARAMETERS ##################
        stim_duties = [0.125, 0.25, 0.5]
        stim_frequencies = list(range(1,10)) + list(range(10,21,2))
        # for test:
        # stim_duties = [0.125, 0.5]
        # stim_frequencies = [1, 20]

        for istim in range(len(stim_duties)*len(stim_frequencies)):
            stim_duty = stim_duties[istim // len(stim_frequencies)]
            stim_freq = stim_frequencies[istim % len(stim_frequencies)]

            ########################## STIMULUS ########################

            # Set stimulus
            ex = es.excitation_schedule(attractor_size,'circular',stim_dt/ms,stim_time/ms,stim_duty,stim_freq,ramp_on=stim_ramp_on,ramp_off=stim_ramp_off)
            ex = np.vstack((ex,np.zeros((int(restim_delay/stim_dt),attractor_size))))
            stimulus_teach = TimedArray(ex*stim_rate,dt=stim_dt)

            restim = np.zeros((int(restim_runtime/stim_dt),attractor_size))
            restim_nodes = np.random.choice(int(attractor_size * restim_fraction),int(attractor_size * restim_fraction * restim_density))
            restim[:int(restim_time/stim_dt),restim_nodes] = 1.
            stimulus_test = TimedArray(restim*stim_rate,dt=stim_dt)

            ########################## GO ########################
            # Run training
            restore('initial_state')
            PGteach.active = True # only teacher for now
            PGtest.active = False
            run(stim_time + restim_delay)
            dw0 = 1.0 * SEE.dw # Multiply by 1.0 to force copy
            trainer_spike_record.append(MSE.num_spikes)

            gainlo = 0.
            gainhi = 1.e6
            gaintest = 1.0
            gotgainhi = False
            done = False
            thi = np.nan
            shi = np.nan
            whi = np.nan
            now = datetime.datetime.now()
            print(now.strftime("%H:%M:%S"),'size=',net_size/cm,'stim_duty=',stim_duty,'stim_freq=',stim_freq,'trainer_spikes=',MSE.num_spikes,'\n  ', end = " ")
            while not done:
                print(gaintest, end = " ")
                restore('initial_state')
                PGteach.active = False
                PGtest.active = True # only test stim

                SEE.w = clip(SEE.w + gaintest*dw0,0,wmax)
                run(restim_runtime)
                if np.amax(MSE.t/ms)>time_threshold:
                    print('v', end = " ")
                    gainhi = gaintest
                    thi = MSE.t/ms
                    shi = 1*MSE.i
                    whi = bws.brian_weight_submatrix(SEE,attractor_nodes,attractor_nodes)
                    gotgainhi = True
                    gaintest = 0.5 * (gainhi + gainlo)
                else:
                    print('^', end = " ")
                    gainlo = gaintest
                    if gotgainhi == False:
                        gaintest *= 1.5
                    else:
                        gaintest = 0.5 * (gainhi + gainlo)
                if abs(2*(gainhi-gainlo)/(gainhi+gainlo)) < gainres:
                    done = True
                    print(gainhi, end = " ")
                    gain_record.append(gainhi)
                    time_record.append(thi)
                    spike_record.append(shi)
                    weight_record.append(whi)
                elif gaintest<0.05 or gaintest>20:
                    done = True
                    print('This one did not converge', end = " ")
                    gain_record.append(gaintest)
                    time_record.append(thi)
                    spike_record.append(shi)
                    weight_record.append(whi)

            # Run at twice threshold
            restore('initial_state')
            SEE.w = clip(SEE.w + 2.0*gainhi*dw0,0,wmax)
            run(restim_runtime)
            time_record.append(MSE.t/ms)
            spike_record.append(1*MSE.i)
            weight_record.append(bws.brian_weight_submatrix(SEE,attractor_nodes,attractor_nodes))
            print('Done')
    mylist = [net_sizes/cm, stim_duties, stim_frequencies, attractor_nodes, gain_record, time_record, spike_record, weight_record, trainer_spike_record]
    with open('frequency_dependence_data' + str(irep+1) + '.pkl', 'wb') as f:
        pickle.dump(mylist, f)

     14:37:03 New network with size= 1.0
14:37:58 size= 1.0 stim_duty= 0.125 stim_freq= 1 trainer_spikes= 4275 
   1.0 ^ 1.5 ^ 2.25 ^ 3.375 ^ 5.0625 ^ 7.59375 v 6.328125 ^ 6.9609375 ^ 7.27734375 v 7.119140625 ^ 7.1982421875 v 7.15869140625 v 7.15869140625 Done
14:40:22 size= 1.0 stim_duty= 0.125 stim_freq= 2 trainer_spikes= 4351 
   1.0 ^ 1.5 ^ 2.25 ^ 3.375 ^ 5.0625 v 4.21875 ^ 4.640625 v 4.4296875 v 4.32421875 ^ 4.376953125 v 4.3505859375 v 4.3505859375 Done
14:42:43 size= 1.0 stim_duty= 0.125 stim_freq= 3 trainer_spikes= 4349 
   1.0 ^ 1.5 ^ 2.25 ^ 3.375 ^ 5.0625 v 4.21875 v 3.796875 v 3.5859375 ^ 3.69140625 v 3.638671875 v 3.6123046875 v 3.6123046875 Done
14:45:28 size= 1.0 stim_duty= 0.125 stim_freq= 4 trainer_spikes= 4442 
   1.0 ^ 1.5 ^ 2.25 ^ 3.375 v 2.8125 ^ 3.09375 ^ 3.234375 v 3.1640625 v 3.12890625 ^ 3.146484375 ^ 3.1640625 Done
14:47:42 size= 1.0 stim_duty= 0.125 stim_freq= 5 trainer_spikes= 4383 
   1.0 ^ 1.5 ^ 2.25 ^ 3.375 v 2.8125 ^ 3.09375 ^ 3.234375 v 3.1640625 ^ 3.199