# LIF_FORCE_sinewave.py
# -*- coding: utf-8 -*-
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import math
np.random.seed(seed=0)
# ---- Neuron / synapse parameters (times in seconds; voltages presumably in mV) ----
N = 2000  # Number of neurons
dt = 5e-5  # Integration time step (s)
tref = 2e-3  # Refractory period (s)
tm = 1e-2  # Membrane time constant (s)
vreset = -65  # Reset voltage
vthr = -40  # Spike threshold
vpeak = 30  # Value written into the recorded voltage trace at a spike, just before reset
td = 2e-2  # Synaptic decay time (s)
tr = 2e-3  # Synaptic rise time (s); 0 selects a single-exponential synapse
alpha = dt*0.1  # Initial RLS learning rate (Pinv starts as alpha*I): too large is unstable, too small learns poorly
Pinv = np.eye(N)*alpha  # Running estimate of the inverse correlation matrix of the rates, used by RLS
p = 0.1  # Connection probability of the static recurrent weights
# ---- Training schedule and target dynamics: a single sine wave ----
T = 15  # Simulation time (s)
imin = round(5/dt)  # RLS updates start after this time step (t = 5 s)
icrit = round(10/dt)  # RLS updates stop at this time step (t = 10 s)
step = 50  # Apply one RLS weight update every `step` time steps
nt = round(T/dt)  # Total number of simulation time steps
Q = 10  # Scale of the encoder (output feedback) weights E
G = 0.04  # Scale of the static random recurrent weights OMEGA
freq = 5  # Target frequency (Hz)
zx = np.sin(2*math.pi*np.arange(nt)*dt*freq)  # Target signal, one sample per time step
k = 1  # Number of output units
# ---- Network state ----
IPSC = np.zeros(N)  # Post-synaptic current into each neuron
h = np.zeros(N)  # Auxiliary variable of the double-exponential current filter
r = np.zeros((N,1))  # Filtered firing rates, kept as an (N, 1) column for the decoder
hr = np.zeros(N)  # Auxiliary variable of the double-exponential rate filter
JD = np.zeros(N)  # Summed recurrent weights of the neurons that spiked on a step
tspike = np.zeros((4*nt,2))  # Spike log: rows of (neuron index, spike time)
ns = 0  # Number of spikes logged so far
z = np.zeros(k)  # Decoder output (approximant)
v = vreset + np.random.rand(N)*(vpeak-vreset)  # Random initial voltages, uniform in [vreset, vpeak)
OMEGA = G*(np.random.randn(N,N))*(np.random.rand(N,N)<p)/(math.sqrt(N)*p)  # Fixed sparse random recurrent weights
BPhi = np.zeros(N)  # Decoder weights, learned with FORCE (RLS)
# Explicitly make the mean of the nonzero weights in every row of OMEGA zero.
for i in range(N):
    QS = np.flatnonzero(OMEGA[i,:])
    # A row without connections has nothing to balance; without this guard
    # the mean would be 0/0 (RuntimeWarning, NaN).
    if QS.size:
        OMEGA[i,QS] = OMEGA[i,QS] - np.sum(OMEGA[i,QS], axis=0)/len(QS)
E = (2*np.random.rand(N)-1)*Q  # Encoder (feedback) weights, uniform in [-Q, Q)
# ---- Recording arrays (one row per time step) ----
RECB = np.zeros((nt, 10))  # Decoder weights BPhi[:10]
REC2 = np.zeros((nt,20))  # Filtered rates r[:20]
REC = np.zeros((nt,10))  # Voltages of neurons 0-9
current = np.zeros(nt)  # Decoder output z
tlast = np.zeros(N)  # Time of each neuron's last spike, used for the refractory period
# Constant bias current. With no other input the voltage relaxes toward the input
# current, so BIAS = vthr holds neurons right at threshold; lower values reduce firing.
BIAS = vthr
#################
## Simulation ###
#################
# Each step: integrate the LIF voltages, log spikes, pass spikes through the synapse
# filters, decode z, apply a FORCE (RLS) update inside the training window, reset spikers.
for i in tqdm(range(nt)):
    I = IPSC + E*z + BIAS  # Neuronal current: recurrent + decoded-output feedback + bias
    # Leaky integration; a neuron within tref of its last spike is held (dv = 0).
    dv = ((dt*i) > (tlast + tref))*(-v + I) / tm
    v = v + dt*dv
    # Threshold crossings. The mask is reused below: setting spikers to vpeak keeps them
    # above vthr and leaves other neurons unchanged, so it stays valid until the reset.
    spiked = v >= vthr
    index = np.flatnonzero(spiked)  # Indices of the neurons that spiked
    len_idx = len(index)
    if len_idx > 0:
        JD = np.sum(OMEGA[:, index], axis=1)  # Extra recurrent current caused by these spikes
        if ns + len_idx > len(tspike):
            # Grow the spike log instead of failing: writing past the end of a
            # fixed-size buffer raises a broadcast ValueError.
            tspike = np.vstack((tspike, np.zeros((max(len(tspike), len_idx), 2))))
        tspike[ns:ns+len_idx, 0] = index
        tspike[ns:ns+len_idx, 1] = dt*i
        ns = ns + len_idx  # Total number of spikes so far
    tlast = tlast + (dt*i - tlast)*spiked  # Last spike time, used for the refractory period
    if tr == 0:
        # Single-exponential synapse (decay time td)
        IPSC = IPSC*math.exp(-dt/td) + JD*(len_idx > 0)/td
        r = r[:, 0]*math.exp(-dt/td) + spiked/td
    else:
        # Double-exponential synapse (rise time tr, decay time td); h and hr are the
        # auxiliary variables feeding IPSC and r respectively.
        IPSC = IPSC*math.exp(-dt/tr) + h*dt
        h = h*math.exp(-dt/td) + JD*(len_idx > 0)/(tr*td)
        r = r[:, 0]*math.exp(-dt/tr) + hr*dt
        hr = hr*math.exp(-dt/td) + spiked/(tr*td)
    r = np.expand_dims(r, 1)  # (N,) -> (N, 1)
    # Decode and compare with the target.
    z = BPhi.T @ r  # Approximant, a length-1 array
    err = z - zx[i]
    # Recursive least squares (FORCE) every `step` steps while imin < i < icrit.
    if i % step == 1 and imin < i < icrit:
        cd = Pinv @ r
        BPhi = BPhi - (cd @ err.T)
        Pinv = Pinv - (cd @ cd.T) / (1.0 + r.T @ cd)
    v = v + (vpeak - v)*spiked  # Draw spikes at vpeak in the recorded trace
    REC[i] = v[:10]  # Record the voltages of neurons 0-9
    v = v + (vreset - v)*spiked  # Hard reset to vreset (no spike-time interpolation)
    # z is a length-1 array; storing it into a scalar slot directly is deprecated
    # (NumPy >= 1.25), so extract the element first.
    current[i] = z.item()
    RECB[i, :] = BPhi[:10]
    REC2[i, :] = r[:20, 0]
#################
#### results ####
#################
# Spike statistics. The firing rate only counts spikes logged after training has
# ended (t > icrit*dt), averaged over all neurons and the remaining simulated time.
TotNumSpikes = ns
t_end_training = dt*icrit
post_training = tspike[:, 1] > t_end_training
M = tspike[post_training]
AverageRate = len(M)/(N*(T - t_end_training))
print("\n")
print("Total number of spikes : ", TotNumSpikes)
print("Average firing rate(Hz): ", AverageRate)
# Voltage traces of neurons 0-4 over the first and the last `step_range` time steps.
# Each trace is scaled and shifted by its neuron index so the five traces stack vertically.
step_range = 20000
for title, start, fname in (("Pre-Learning", 0, "LIF_pre.png"),
                            ("Post Learning", nt - step_range, "LIF_post.png")):
    window = slice(start, start + step_range)
    t_axis = np.arange(start, start + step_range)*dt
    plt.figure(figsize=(6, 6))
    for j in range(5):
        plt.plot(t_axis, REC[window, j]/(50-vreset) + j)
    plt.title(title)
    plt.xlabel('Time (s)')
    plt.ylabel('Neuron Index')
    plt.savefig(fname)
# Decoded output vs. target over the last 3 s of the simulation.
plt.figure(figsize=(12, 6))
time_axis = np.arange(nt)*dt
plt.plot(time_axis, current, label="Decoded output")
plt.plot(time_axis, zx, label="Target")
plt.xlim(T - 3, T)  # (12, 15) for T = 15; follows T if the simulation length changes
plt.legend()  # The two lines are labelled; show which is which
plt.title('Decoded output')
plt.xlabel('Time (s)')
plt.ylabel('current')
plt.savefig("LIF_post_out.png")
# Eigenvalue spectra of the effective recurrent weights: the static OMEGA alone
# (pre-learning) vs. OMEGA plus the learned rank-one feedback term E BPhi^T
# (post-learning). Only eigenvalues are plotted, so eigvals is used to avoid
# computing eigenvectors of two N x N matrices. Z and Z2 hold the eigenvalues.
Z = np.linalg.eigvals(OMEGA + np.outer(E, BPhi))
Z2 = np.linalg.eigvals(OMEGA)
plt.figure(figsize=(6, 5))
plt.title('Weight eigenvalues')
plt.scatter(Z2.real, Z2.imag, c='r', s=5, label='Pre-Learning')
plt.scatter(Z.real, Z.imag, c='k', s=5, label='Post-Learning')
plt.legend()
plt.xlabel('Real')
plt.ylabel('Imaginary')
plt.savefig("LIF_weight_eigenvalues.png")