In [2]:
from brian2 import *
from matrices import import_light, import_dog_inh, import_v1
%matplotlib inline
import numpy as np
import pytime

# Population sizes. L1 and L2 share one square grid, L3 is four times that
# size, and L4 has 10 units.
nL1 = 5**2
nL2 = nL1
nL3 = 4*nL1
nL4 = 10

# Side length of the square L1/L2 grid.
n = np.sqrt(nL1).astype(int)

# 1-based integer coordinates of every grid cell: X varies along columns,
# Y varies along rows.
grid_axis = np.arange(1, n + 1, dtype=int)
X, Y = np.meshgrid(grid_axis, grid_axis)

# Time constants referenced by name inside the equation strings, plus the
# simulated time passed to run().
taum = 10*ms        # divides (ge-v) in dv/dt
taue = 5*ms         # decay constant of ge
taui = taue         # same value as taue; not referenced in the visible code
duration = 500*ms

# Model equations passed to NeuronGroup. Every model carries a grid position
# (x, y); taum and taue are looked up by name when the model is built.

# Input units: only a firing rate and a position.
eqsPOISSON = (
    '\n'
    'rates : Hz\n'
    'x : 1 (constant)\n'
    'y : 1 (constant)\n'
)

# Leaky integrator v driven by an exponentially decaying input ge.
eqsCUBALIF = (
    '\n'
    'dv/dt  = (ge-v)/taum : volt (unless refractory)\n'
    'dge/dt = -ge/taue : volt\n'
    'x : 1 (constant)\n'
    'y : 1 (constant)\n'
)

# Same dynamics as eqsCUBALIF plus an orientation label per unit.
eqsV1 = eqsCUBALIF + 'ori : 1 (constant)\n'

# Class labels, in the order used for the classification read-out.
letters = list('CDHKNORSVZ')
# Index of the letter presented to the network in this run.
ind = 0
real_letter = letters[ind]

################################################## Layer 1 ################################################################
# Input layer: a unit spikes in a time step whenever rand() < rates*dt.

L1 = NeuronGroup(N=nL1,
                 model=eqsPOISSON,
                 threshold='rand()<rates*dt')

# Grid position of each unit, taken from the meshgrid in row-major order.
L1.x = X.ravel()
L1.y = Y.ravel()

# Firing rates: entry `ind` of the imported patterns scaled by 100 Hz.
# NOTE(review): presumably pat[ind] holds one intensity per L1 unit; confirm in matrices.import_light.
pat = import_light()
L1.rates = pat[ind]*100*Hz

################################################## Layer 2 ################################################################
# Integrate-and-fire layer on the same grid as L1. It receives one-to-one
# excitation from L1 and inhibition from the neighbouring L1 units.

L2 = NeuronGroup(N=nL2,
                 model=eqsCUBALIF,
                 method='euler',
                 threshold='v>0.3*volt',
                 reset='v=0*volt',
                 refractory=5*ms)

# Grid position of each unit, taken from the meshgrid in row-major order.
L2.x = X.ravel()
L2.y = Y.ravel()

# Excitation: L1 unit i drives only L2 unit i.
e1 = Synapses(L1, L2, on_pre='ge += 8*volt')
e1.connect(condition='i==j')

# Inhibition: every other L1 unit closer than 2 grid steps (Euclidean), which
# includes the diagonal neighbours.
i1 = Synapses(L1, L2, on_pre='ge -= 1*volt')
i1.connect(condition='sqrt((x_pre - x_post)**2 + (y_pre-y_post)**2) < 2 and i!=j')

################################################## Layer 3 ################################################################
# Four stacked copies of the L2 grid, one per orientation label (ori = 0..3).
# Each unit pools L2 activity along a 3-cell line through its own position.

L3 = NeuronGroup(nL3,
                 eqsV1,
                 threshold='v>0.3*volt',
                 reset='v=0*volt',
                 refractory=5*ms,
                 method='euler')

# Units are laid out block-wise: the first nL1 units have ori 0, the next nL1
# have ori 1, and so on.
L3.ori = np.repeat(np.array([0,1,2,3]), nL1)

# Repeat the grid coordinates once per orientation block so they match the
# ori layout. Tiling the flattened grid yields a 1-D array directly; the
# previous bare `X3.flatten()` / `Y3.flatten()` statements discarded their
# result and had no effect.
X3 = np.tile(X.flatten(), 4)
L3.x = X3

Y3 = np.tile(Y.flatten(), 4)
L3.y = Y3

e2 = Synapses(L2, L3, on_pre='ge += volt/3')
# ori 0: same x, y at most one step away (vertical line).
cond_ver = 'ori_post==0 and (y_pre-y_post)**2<=1 and x_pre==x_post or '
# ori 1: same y, x at most one step away (horizontal line).
cond_hor= 'ori_post==1 and (x_pre-x_post)**2<=1 and y_pre==y_post or '
# ori 2: the cell itself plus (x-1, y-1) and (x+1, y+1).
cond_TLBR = 'ori_post==2 and (x_pre==x_post-1 and y_pre==y_post-1 or x_pre==x_post and y_pre==y_post or x_pre==x_post+1 and y_pre==y_post+1) or '
# ori 3: the cell itself plus (x+1, y-1) and (x-1, y+1).
cond_TRBL = 'ori_post==3 and (x_pre==x_post+1 and y_pre==y_post-1 or x_pre==x_post and y_pre==y_post or x_pre==x_post-1 and y_pre==y_post+1)'
# `and` binds tighter than `or`, so the concatenation reads as
# (ver) or (hor) or (TLBR) or (TRBL).
cond = cond_ver + cond_hor + cond_TLBR + cond_TRBL
e2.connect(condition = cond)

################################################## Layer 4 ################################################################
# Read-out layer: nL4 integrate-and-fire units driven by weighted input from L3.

L4 = NeuronGroup(N=nL4,
                 model=eqsCUBALIF,
                 method='euler',
                 threshold='v>0.3*volt',
                 reset='v=0*volt',
                 refractory=5*ms)

# connect() is called without arguments, so no pair is filtered out; each
# synapse carries its own weight w.
e3 = Synapses(L3, L4, model='w : volt', on_pre='ge += w')
e3.connect()

# Weights come from disk, are flattened in row-major order, and are doubled.
# NOTE(review): assumes the flattened order matches the synapse order of e3
# (presumably a (nL3, nL4) array); confirm against how V1_weights.npy is written.
W_V1 = np.load('V1_weights.npy')
e3.w = W_V1.ravel()*volt*2

#################################################### Run ##################################################################

# Seed Brian's random number generation before running. The L1 input is
# stochastic (rand() in its threshold), so without a fixed seed the rates
# reported below change on every execution.
seed(0)

sp = SpikeMonitor(L4)
run(duration)
# Mean firing rate of each L4 unit: spike count divided by the simulated time.
mfr = sp.count/duration

print(f'Real letter: {real_letter}')
print('Classification')
# Pair each class label with its own rate. This replaces a hand-spaced header
# literal that duplicated `letters` and only lined up with one array repr.
for letter, rate in zip(letters, np.asarray(mfr/Hz)):
    print(f'  {letter}: {rate:g} Hz')
mfr

Real letter: C
Classification
  C   D   H    K   N   O   R   S   V    Z


array([66.,  0.,  0.,  0.,  0., 22.,  0.,  0.,  0.,  0.]) * hertz