# STDP learning
<br>
STDP learning of the connections from the 10x10x8 complex cells to each digit (classification) neuron.

In [1]:
from brian2 import *
%matplotlib inline
import numpy as np
import pytime

# Training data, read from NumPy files in the current working directory.
# t_train is used further down as the per-sample label vector
# (x_train[t_train==0,:] selects the samples of one class).
# NOTE(review): presumably MNIST-style 28x28 images flattened to one row per
# sample with pixel values in 0..255 -- confirm against the data files.
x_train = np.load('x_train.npy')
t_train = np.load('t_train.npy')

# Side length of the square neuron grids; the input layers hold n**2 units.
n = 28

# The eight preferred orientations in radians (0, 30, 45, 60, 90, 120, 135 and
# 150 degrees).  The connection conditions further down compare orientations
# with exact float equality (e.g. 'ori_post==1/6*pi'), so these values must
# keep being computed in exactly this form.
angles = np.array([0, 1/6, 1/4, 1/3, 1/2, 2/3, 3/4, 5/6])*np.pi

def norm_weight(W):
    """Scale a weight array so that all of its entries sum to one.

    Parameters
    ----------
    W : array_like
        Weights of any dimensionality, e.g. a 1-D synaptic weight vector or
        a 2-D weight matrix.

    Returns
    -------
    new_W
        ``W`` divided by the sum of all of its entries.  If that sum is zero
        the division follows NumPy semantics (``nan``/``inf`` entries).
    """
    # np.sum reduces over every axis at once.  The previous ``sum(sum(W))``
    # only worked for 1-D input when a star import had replaced the builtin
    # ``sum`` with NumPy's; with the builtin it raised TypeError.
    tot_exc = np.sum(W)
    new_W = W/tot_exc

    return new_W

def visualize_Hypercolumn_8ori(mfr, n):
    """Draw the eight orientation maps of a hypercolumn in a 2x4 grid.

    ``mfr`` is read as eight consecutive chunks of ``n**2`` values, one chunk
    per orientation.  Each chunk is reshaped to ``n`` x ``n`` and shown as a
    grayscale image without axis ticks, titled with its orientation.
    """
    labels = ('0°', '30°', '45°', '60°', '90°', '120°', '135°', '150°')
    chunk = n**2

    for pos, label in enumerate(labels, start=1):
        # Values belonging to the (pos-1)-th orientation map.
        first = chunk*(pos - 1)
        image = asarray(mfr[first:first + chunk]).reshape(n, n)

        plt.subplot(2, 4, pos)
        plt.imshow(image, cmap='gray')
        plt.xticks([])
        plt.yticks([])
        plt.title(label)

# Integer grid coordinates of an n x n sheet: X holds the column index and Y
# the row index of every position.
[X,Y] = np.meshgrid(range(n), range(n))

# Time constants used by the equation strings below: taum for the membrane
# potential v, taui for the decay of gi and taue for the decay of ge.
# NOTE: these names (and dge/dge_c) are referenced from inside Brian 2 code
# strings, so renaming them requires editing those strings as well.
taum = 10*ms
taui = 50*ms
taue = 100*ms

# Base size of the jump applied to ge/gi per presynaptic spike; the on_pre
# statements add or subtract these values, optionally scaled by a fraction.
# dge is used by the projections into the LGN and simple-cell layers, dge_c
# by the projections into the complex-cell layer.
dge = 80*mV
dge_c = 500*mV

# Input units: a per-unit firing rate plus a fixed grid position.
eqsPOISSON='''
rates : Hz
x : 1 (constant)
y : 1 (constant)
'''

# LGN units: v relaxes towards ge+gi with time constant taum (held during the
# refractory period), while ge and gi decay exponentially towards zero.
eqsLGN='''
dv/dt  = (ge+gi-v)/taum : volt (unless refractory)
dge/dt = -ge/taue : volt
dgi/dt = -gi/taui : volt
x : 1 (constant)
y : 1 (constant)
'''

# V1 units: same dynamics as the LGN units plus a fixed preferred orientation.
eqsV1='''
dv/dt  = (ge+gi-v)/taum : volt (unless refractory)
dge/dt = -ge/taue : volt
dgi/dt = -gi/taui : volt
x : 1 (constant)
y : 1 (constant)
ori : 1 (constant)
'''

# Classification units: driven by ge only (no gi term, no position).
eqsClass='''
dv/dt  = (ge-v)/taum : volt (unless refractory)
dge/dt = -ge/taue : volt
'''

################################################### Retina ################################################################

# Number of input units: one per position of the n x n grid.
nL1 = n**2

# Threshold-only group (no reset, no differential equations): a unit emits a
# spike in a time step whenever a uniform random draw is below rates*dt.
L1 = NeuronGroup(nL1,
                 eqsPOISSON,
                 threshold='rand()<rates*dt')

# Store every unit's grid position (x = column index, y = row index); the
# synapse conditions of the next layer read these as x_pre / y_pre.
L1.x = X.flatten()
L1.y = Y.flatten()

################################################### LGN 2A ################################################################

# Number of LGN units: one per position of the n x n grid.
nL2a = n**2

# Leaky units that spike when v exceeds 0.3 V, are then reset to 0 V and stay
# refractory for 5 ms; integrated with the forward Euler method.
L2a = NeuronGroup(nL2a,
                 eqsLGN,
                 threshold = 'v>0.3*volt',
                 reset = 'v=0*volt',
                 refractory = 5*ms,
                 method = 'euler')

# Same grid layout as the input layer (x = column index, y = row index).
L2a.x = X.flatten()
L2a.y = Y.flatten()

# Centre: every input unit excites the LGN unit with the same index.
e1a = Synapses(L1, L2a, on_pre='ge += dge')
e1a.connect(condition = 'i==j')

# Surround: every input unit lowers gi of the other LGN units whose grid
# distance is below 2.  On the integer grid these are the up to eight direct
# neighbours, each receiving 1/8 of dge.
i1a = Synapses(L1, L2a, on_pre='gi -= 1/8*dge')
i1a.connect(condition = 'sqrt((x_pre-x_post)**2+(y_pre-y_post)**2)<2 and i!=j')

################################################ V1 simple A ##############################################################

# Simple cells: eight orientation maps of nL2a cells each, stored one map
# after the other.
nL3a = nL2a*8

L3a = NeuronGroup(nL3a,
                 eqsV1,
                 threshold='v>0.3*volt',
                 reset='v=0*volt',
                 refractory=5*ms,
                 method='euler')

# Preferred orientation: the first nL2a cells get angles[0], the next nL2a
# cells get angles[1], and so on.
L3a.ori = np.repeat(angles, nL2a)

# Grid coordinates, repeated once per orientation map.  np.tile with reps
# (1,8) returns a 2-D array with a single row, hence the flatten() when
# assigning.  (The former stand-alone `X3.flatten()` / `Y3.flatten()`
# statements discarded their result and have been removed.)
X3 = np.tile(X.flatten(), (1,8))
L3a.x = X3.flatten()

Y3 = np.tile(Y.flatten(), (1,8))
L3a.y = Y3.flatten()

# Excitatory LGN -> simple-cell condition: inside a 5x5 window around the
# target position (|dx|<3 and |dy|<3), pick the LGN units lying on a line
# through the target position; which line is selected by ori_post.
# The orientation tests use exact float equality, so they only match because
# `angles` is built from the same expressions (k*pi with the same fractions).
ce = 'abs(y_pre-y_post)<3 and abs(x_pre-x_post)<3 and ('
ce += 'ori_post==0 and x_post==x_pre or '
ce += 'ori_post==1/6*pi and abs(1.5*(x_post-x_pre)-(y_post-y_pre))<1 or '
ce += 'ori_post==1/4*pi and abs(y_post-y_pre-x_post+x_pre)==0 or '
ce += 'ori_post==1/3*pi and abs((x_post-x_pre)-1.5*(y_post-y_pre))<1 or '
ce += 'ori_post==1/2*pi and y_post==y_pre or '
ce += 'ori_post==2/3*pi and abs((x_post-x_pre)+1.5*(y_post-y_pre))<1 or '
ce += 'ori_post==3/4*pi and abs(y_post-y_pre+x_post-x_pre)==0 or '
ce += 'ori_post==5/6*pi and abs(1.5*(x_post-x_pre)+(y_post-y_pre))<1)'

# Inhibitory LGN -> simple-cell condition: same window, but the selected LGN
# units lie on the two lines offset by one grid step on either side of the
# excitatory line.
ci = 'abs(y_pre-y_post)<3 and abs(x_pre-x_post)<3 and ('
ci += 'ori_post==0 and abs(x_post-x_pre)==1 or '
ci += 'ori_post==1/6*pi and (abs(1.5*(x_post-x_pre-1)-(y_post-y_pre))<1 or abs(1.5*(x_post-x_pre+1)-(y_post-y_pre))<1) or '
ci += 'ori_post==1/4*pi and abs(y_post-y_pre-x_post+x_pre)==1 or '
ci += 'ori_post==1/3*pi and (abs((x_post-x_pre)-1.5*(y_post-y_pre-1))<1 or abs((x_post-x_pre)-1.5*(y_post-y_pre+1))<1) or '
ci += 'ori_post==1/2*pi and abs(y_post-y_pre)==1 or '
ci += 'ori_post==2/3*pi and (abs((x_post-x_pre)+1.5*(y_post-y_pre-1))<1 or abs((x_post-x_pre)+1.5*(y_post-y_pre+1))<1) or '
ci += 'ori_post==3/4*pi and abs(y_post-y_pre+x_post-x_pre)==1 or '
ci += 'ori_post==5/6*pi and (abs(1.5*(x_post-x_pre-1)+(y_post-y_pre))<1 or abs(1.5*(x_post-x_pre+1)+(y_post-y_pre))<1))'

# Each selected excitatory input adds 1/5 of dge to ge ...
e2a = Synapses(L2a, L3a, on_pre='ge += 1/5*dge')
e2a.connect(condition = ce)

# ... and each selected inhibitory input subtracts 1/9 of dge from gi.
i2a = Synapses(L2a, L3a, on_pre='gi -= 1/9*dge')
i2a.connect(condition = ci)

############################################## V1 A complex A #############################################################

# Complex cells: eight orientation maps of naa x naa cells each.
naa = 10
nL4aa = naa**2*8
[X4aa,Y4aa] = np.meshgrid(range(naa), range(naa))

L4aa = NeuronGroup(nL4aa,
                 eqsV1,
                 threshold='v>0.3*volt',
                 reset='v=0*volt',
                 refractory=5*ms,
                 method='euler')

# Preferred orientation: one block of naa**2 cells per entry of `angles`.
L4aa.ori = np.repeat(angles, naa**2)

# Grid coordinates, repeated once per orientation map.  np.tile with reps
# (1,8) returns a 2-D array with a single row, hence the flatten() when
# assigning.  (The former stand-alone `X4.flatten()` / `Y4.flatten()`
# statements discarded their result and have been removed.)
X4 = np.tile(X4aa.flatten(), (1,8))
L4aa.x = X4.flatten()

Y4 = np.tile(Y4aa.flatten(), (1,8))
L4aa.y = Y4.flatten()

# All four projections share the same spatial pooling: a complex cell at
# (x, y) receives input from the simple cells within one grid step of
# (3*x, 3*y), i.e. a 3x3 patch.  They differ in the orientation relation
# between the simple and the complex cell.

# Same orientation: excitation of 1/9 dge_c per spike.
e3aa = Synapses(L3a, L4aa, on_pre='ge += 1/9*dge_c')
e3aa.connect(condition = 'ori_post==ori_pre and abs(3*x_post-x_pre)<=1 and abs(3*y_post-y_pre)<=1')

# Similar but different orientation (cos(2*delta) > 0.49, i.e. up to about
# 30 degrees apart): weaker excitation of 1/18 dge_c.
e3aa2 = Synapses(L3a, L4aa, on_pre='ge += 1/18*dge_c')
e3aa2.connect(condition = 'cos(2*(ori_post-ori_pre))>0.49 and ori_post!=ori_pre and abs(3*x_post-x_pre)<=1 and abs(3*y_post-y_pre)<=1')

# Dissimilar orientation (cos(2*delta) < -0.49, i.e. roughly 60 to 120
# degrees apart): ge is reduced by 1/18 dge_c.
# NOTE(review): this and the next projection decrement ge rather than gi,
# unlike the inhibitory synapses of the earlier layers -- confirm intended.
e3aa3 = Synapses(L3a, L4aa, on_pre='ge -= 1/18*dge_c')
e3aa3.connect(condition = 'cos(2*(ori_post-ori_pre))<-0.49 and abs(3*x_post-x_pre)<=1 and abs(3*y_post-y_pre)<=1')

# Orthogonal orientation (cos(2*delta) == -1): a further reduction of
# 1/9 dge_c, on top of the one applied by e3aa3 whose condition also matches.
e3aa4 = Synapses(L3a, L4aa, on_pre='ge -= 1/9*dge_c')
e3aa4.connect(condition = 'cos(2*(ori_post-ori_pre))==-1 and abs(3*x_post-x_pre)<=1 and abs(3*y_post-y_pre)<=1')

################################################### STDP ##################################################################

# Number of classification units in the output layer.
n_class = 1

# t_stdp: decay time constant of the presynaptic trace apre.
# dge_class: jump of ge in the output unit per presynaptic spike, scaled by
#            the synaptic weight w.
# Apre: amount added to the trace apre on every presynaptic spike.
t_stdp = 5*ms
dge_class = 500*mV
Apre = 0.00001

L5 = NeuronGroup(n_class,
                 eqsClass,
                 threshold='v>0.3*volt',
                 reset='v=0*volt',
                 refractory=5*ms,
                 method='euler')

# Plastic projection from the complex cells to the output layer.
# - on a presynaptic spike the target's ge grows by w*dge_class and the
#   synapse's trace apre grows by Apre;
# - on a postsynaptic spike w is increased by the current trace and clipped
#   to [0, 1].  There is no term that lowers w, so weights only grow here.
# NOTE(review): `linear` is not defined in this file; presumably it is the
# state updater object provided by the brian2 star import -- confirm.
e4 = Synapses(L4aa, L5,
              '''w : 1
                dapre/dt = -apre/t_stdp : 1 (event-driven)''',
              on_pre = '''ge += w*dge_class
                    apre += Apre''',
              on_post = '''w = clip(w + apre, 0, 1)''',
              method=linear)
# connect() without arguments creates every possible pre/post pair; all
# weights start at the same value 1/nL4aa.
e4.connect()
e4.w = 1/nL4aa

###########################################################################################################################
###########################################################################################################################
###########################################################################################################################

# The file never imports `time` explicitly (only `pytime`), so the name was
# only available if a star import happened to provide it.
import time

start = time.time()
duration = 500*ms   # presentation time of a single pattern
FR = 200*Hz         # firing rate assigned to a pixel of value 255

# Optional debugging monitors (disabled).
#spa = SpikeMonitor(L4aa[0])
#spb = SpikeMonitor(L4aa[635])
#sp2 = SpikeMonitor(L5)

#M4 = StateMonitor(L4aa, 'ge', record=[635])
#M5 = StateMonitor(L5, 'v', record=True)

# Train on the samples labelled 0 only.
x_train_a = x_train[t_train==0,:]
n_patterns = x_train_a.shape[0]
if n_patterns == 0:
    # Without this check the loop body never runs, M is never created and the
    # script only fails at the very end with a NameError.
    raise ValueError('no training samples with label 0 found')

for ii,pat in enumerate(x_train_a):

    # Record the weights only while the last pattern is presented.
    if ii==n_patterns-1:
        M = StateMonitor(e4, 'w', record=True)

    # Pixel values are scaled so that 255 maps to a rate of FR.
    L1.rates = pat*FR/255
    run(duration)

    # Renormalise after every pattern so the weights sum to one.
    e4.w = norm_weight(e4.w)

stop = time.time()
print(f'Elapsed time: {stop-start:.2f}')

# M.w holds the weights as recorded during the last presentation, i.e.
# before that presentation's final renormalisation.  np.save appends the
# '.npy' extension to the file name.
np.save('w0',asarray(M.w))

INFO       Cannot use compiled code, falling back to the numpy code generation target. Note that this will likely be slower than using compiled code. Set the code generation to numpy manually to avoid this message:
prefs.codegen.target = "numpy" [brian2.devices.device.codegen_fallback]


Elapsed time: 25832.97
