In [1]:
import numpy as np
import matplotlib.pyplot as plt
from os import path, listdir
from os.path import join

In [2]:
# Project-local plotting helper; the instance `pp` is used further down
# (`pp.legend(...)`) when drawing the 3D PCA figure.
from plot_props import PlotProps
pp = PlotProps()

In [3]:
class RNNGain:
    """Rate network with tanh nonlinearity, integrated with explicit Euler steps.

    State update (see `update`):
        x_new = x + (dt / τ) * (J·tanh(x) - x + Bc·uc[ind] + Bs·us[ind] + cx)
    Readout (see `activity`):
        z = wo·tanh(x) + cz

    Parameters
    ----------
    J : array
        Recurrent weights applied to tanh(x).
    Bc, Bs : array
        Input weights multiplying `uc[ind]` and `us[ind]` respectively.
    cx : array or scalar
        Additive bias in the state update.
    cz : scalar
        Additive bias in the readout.
    wo : array
        Readout weights applied to tanh(x).
    xo : array
        Initial state, used as `xs[0]` by `simulate`.
    us, uc : sequence
        Input time series, indexed by the integer step index.
    τ : scalar
        Time constant of the dynamics (same units as `dt`).
    dt : scalar
        Integration step.
    numneur : int
        State dimension; sets the width of the array returned by `simulate`.
    """

    def __init__(self, J, Bc, Bs, cx, cz, wo, xo, us, uc, τ, dt, numneur):
        self.J = J
        self.Bc = Bc
        self.Bs = Bs
        self.cx = cx
        self.cz = cz
        self.wo = wo
        self.xo = xo

        self.numneur = numneur
        self.dt = dt
        self.τ = τ
        self.us = us
        self.uc = uc

    def nonlin(self, x):
        """Element-wise firing-rate nonlinearity (tanh)."""
        return np.tanh(x)

    def update(self, x, ind):
        """Return the state after one Euler step from `x`, driven by the inputs at index `ind`."""
        # Use the instance's own weights (self.J), not a same-named global,
        # so the object behaves according to what was passed to the constructor.
        drive = (self.J.dot(self.nonlin(x))
                 - x
                 + self.Bc.dot(self.uc[ind])
                 + self.Bs.dot(self.us[ind])
                 + self.cx)
        return x + (self.dt / self.τ) * drive

    def activity(self, x):
        """Return the scalar readout for state `x`."""
        # self.cz (not a global `cz`) is the readout bias of this instance.
        return self.wo.dot(self.nonlin(x)) + self.cz

    def simulate(self, numsteps):
        """Integrate for `numsteps` steps starting from `xo`.

        Step `i` (for i >= 1) is computed from state `i-1` using the inputs at
        index `i`, so `us` and `uc` must be indexable up to `numsteps - 1`.
        `numsteps` must be at least 1 because row 0 holds the initial state.

        Returns
        -------
        xs : ndarray, shape (numsteps, numneur)
            State trajectory.
        z : ndarray, shape (numsteps,)
            Readout at every step.
        """
        xs = np.zeros([numsteps, self.numneur])
        z = np.zeros(numsteps)

        xs[0] = self.xo
        z[0] = self.activity(self.xo)

        for i in range(1, numsteps):
            xs[i] = self.update(xs[i-1], i)
            z[i] = self.activity(xs[i])

        return xs, z
            

In [4]:
# Value substituted into the `..._a_%d` file names read and written below.
a = 5

# Folder the result files are read from, and folder reserved for figures.
source = path.join('neural_timing', 'results', 'tonic')
dest = path.join('neural_timing', 'figures')

In [5]:
# Read the parameter table and the time axis from the results folder.
params = np.loadtxt(join(source, 'prms_a_%d' % a))
tr_time = np.loadtxt(join(source, 'ts_set'))

# Column layout of `params`: [xo, wo, Bc, Bs, J (all middle columns), cx, cz].
xo, wo, Bc, Bs = (params[:, col] for col in range(4))
J = params[:, 4:-2]
cx = params[:, -2]
# Only the first row's entry of the last column is used as the readout bias.
cz = params[0, -1]

In [6]:
# Noise-free input traces live in a sub-folder of the results directory.
source_nf = join(source, 'noisefree')
cntxt_in_high, cntxt_in_low, set_in = (
    np.loadtxt(join(source_nf, stem % a))
    for stem in ('context_inputs_high_a_%d',
                 'context_inputs_low_a_%d',
                 'set_inputs_a_%d')
)

In [7]:
perf_time = 500
# `tr_time` is parsed from text and rescaled by 1000 before the comparison, so
# match with a tolerance rather than exact float equality, and fail with a
# clear message if the requested value is not present at all.
perftime_hits = np.where(np.isclose(tr_time * 1000, perf_time))[0]
if perftime_hits.size == 0:
    raise ValueError('perf_time=%s not found in tr_time*1000' % perf_time)
perftime_ind = perftime_hits[0]

# Context amplitudes to sweep, and the matching gain values
# (amplitude offset by .1, scaled by 5).
cntxt_amp = np.arange(.1, 1.125, .025)
gains = (cntxt_amp - .1)*5
print(gains)

[0.    0.125 0.25  0.375 0.5   0.625 0.75  0.875 1.    1.125 1.25  1.375
 1.5   1.625 1.75  1.875 2.    2.125 2.25  2.375 2.5   2.625 2.75  2.875
 3.    3.125 3.25  3.375 3.5   3.625 3.75  3.875 4.    4.125 4.25  4.375
 4.5   4.625 4.75  4.875 5.   ]


In [8]:
# Folder for the variable-gain results; the set input column selected by
# `perftime_ind` is written there as text.
dest_res = join(source, 'var_gain')
np.savetxt(
    fname=join(dest_res, 'set_in_t_%d_a_%d' % (perf_time, a)),
    X=set_in[:, perftime_ind],
)


In [9]:
# Binary mask of the samples where the low-context input is positive.
cntxt_pos = cntxt_in_low[:, perftime_ind] > 0.

# One column per amplitude: a 0/1 step (1 where the mask is set) scaled by that amplitude.
cntxt = np.outer(np.where(cntxt_pos, 1., 0.), cntxt_amp)

np.savetxt(join(dest_res, 'cntxt_var_gain'), cntxt)
np.savetxt(join(dest_res, 'gains'), gains)

In [10]:
# Set-input trace for the selected column; report how many samples are positive.
set_in_sm = set_in[:, perftime_ind]
print(np.count_nonzero(set_in_sm > 0.))

41


In [11]:
# Simulate the network once per context gain and record, for each run, how many
# steps after the end of the set input the readout first exceeds the threshold.
NUM_STEPS = 3300          # length of every simulation, also the "no crossing" fallback
READOUT_THRESHOLD = 1.    # readout level that marks the end of the produced interval

est_time = []
xs_list = []
z_list = []
# Last sample at which the set input is still positive.
set_ind = np.where(set_in_sm>0.)[0][-1]

# These do not depend on the gain, so they are bound once outside the loop.
Bs_sm = Bs
Bc_sm = Bc
us_sm = set_in_sm
for gain_ind, gain in enumerate(gains):
    uc_sm = cntxt[:, gain_ind]
    rnn = RNNGain(J=J, Bc=Bc_sm, Bs=Bs_sm,
                  cx=cx, cz=cz,
                  wo=wo, xo=xo,
                  us=us_sm,
                  uc=uc_sm,
                  τ=10,
                  dt=1,
                  # State dimension follows the loaded initial state instead of a literal.
                  numneur=xo.shape[0])
    xs, z = rnn.simulate(NUM_STEPS)
    xs_list.append(xs)
    z_list.append(z)

    # First threshold crossing at or after set_ind. If the readout never crosses,
    # fall back to NUM_STEPS: the stored estimate is then saturated, not a real crossing.
    crossings = np.where(z[set_ind:] > READOUT_THRESHOLD)[0]
    end_ind = crossings[0] + set_ind if crossings.size else NUM_STEPS

    print(end_ind, set_ind)
    est_time.append(end_ind - set_ind)
    

704 619
749 619
794 619
843 619
893 619
946 619
1001 619
1058 619
1115 619
1173 619
1232 619
1292 619
1356 619
1425 619
1502 619
1594 619
1710 619
1867 619
2114 619
2642 619
3300 619
3300 619
3300 619
3300 619
3300 619
3300 619
3300 619
3300 619
3300 619
3300 619
3300 619
3300 619
3300 619
3300 619
3300 619
3300 619
3300 619
3300 619
3300 619
3300 619
3300 619


In [None]:
# Measured interval against perf_time*gain, with the identity line for reference.
# NOTE(review): `[:-6]` drops only the last six gains, while the run recorded above
# shows more than six runs ending at the 3300-step fallback — confirm the cutoff.
fig, ax = plt.subplots()
ax.plot(perf_time*gains[:-6], est_time[:-6], 'o')
ax.plot(perf_time*gains, perf_time*gains, 'k--')
plt.show()

In [None]:
# Zoom on samples 100-150 of the set input trace.
plt.plot(set_in_sm)
plt.xlim(left=100, right=150)

## PCA

In [12]:
# Sanity check: the simulation loop appends one trajectory per entry of `gains`.
len(xs_list)

41

In [None]:
# Keep, for every run, only the window that starts at set_ind and lasts est_time steps.
xs_list_cut = [
    traj[set_ind:set_ind + duration]
    for traj, duration in zip(xs_list, est_time)
]

In [37]:
# Stack the windows of the first `num_gains` runs row-wise into one data matrix.
num_gains = 5
xs_sm = np.concatenate(xs_list_cut[:num_gains], axis=0)

In [23]:
# Remove the per-column mean so the decomposition below operates on centred data.
mean = np.mean(xs_sm, axis=0)
xs_sm_cent = np.subtract(xs_sm, mean)

In [24]:
# Scatter matrix of the centred data (no 1/(n-1) factor; that only rescales the
# eigenvalues and leaves the eigenvectors unchanged).
cov = xs_sm_cent.T.dot(xs_sm_cent)
# `cov` is symmetric by construction, so use the symmetric solver: it returns
# real eigenvalues/eigenvectors, whereas the general solver may return complex
# values in arbitrary order. eigh sorts ascending; reverse so the leading
# columns of `vec` are the directions of largest variance.
eig, vec = np.linalg.eigh(cov)
eig, vec = eig[::-1], vec[:, ::-1]

In [32]:
# Order eigenpairs from largest to smallest eigenvalue.
sort = np.flip(np.argsort(eig))
eig_sort = np.take(eig, sort)
vec_sort = np.take(vec, sort, axis=1)

In [79]:
import sklearn.linear_model

# Indices of the last gain that is <= 1 and the last gain that is <= 1.5.
ind1 = np.argwhere(np.asarray(gains) <= 1)[-1].squeeze()
ind2 = np.argwhere(np.asarray(gains) <= 1.5)[-1].squeeze()
x1, x2 = xs_list_cut[ind1], xs_list_cut[ind2]

# Design matrix: states of both runs stacked row-wise; label 1 for the first run, 0 for the second.
xs_lr = np.concatenate((x1, x2), axis=0)
y = np.concatenate((np.ones(len(x1)), np.zeros(len(x2))))

lr = sklearn.linear_model.LogisticRegression()
lr_fit = lr.fit(X=xs_lr, y=y)


8 12


In [95]:
from mpl_toolkits.mplot3d import Axes3D
%matplotlib tk
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
plt.ion()

ax.set_xlabel('PC1')
ax.set_ylabel('PC2')
ax.set_zlabel('PC3')

# Hide grid lines
ax.grid(False)

# Hide axes ticks
ax.set_xticks([])
ax.set_yticks([])
ax.set_zticks([])

for i, g in enumerate(gains[ind1-1:ind2+2]):
    proj = xs_list_cut[i].dot(vec[:, :3])
    print(proj.shape)
    x = np.real(proj[:, 0])
    y = np.real(proj[:, 1])
    z = np.real(proj[:, 2])
    ax.scatter(x, y, z, label='Gain=%.3f'%g)
#     ax.scatter(proj[:, 0], proj[:, 1], proj[:, 2])
proj_line = np.zeros([2,3])
proj_line[1] = wt.dot(vec_sort[:, 3])
ax.plo(proj_line[:, 0], proj_line[:, 1], proj_line[:, 2], s=500, color='k')

pp.legend(fontsize=10)
plt.show()
plt.ioff()


(85, 3)
(130, 3)
(175, 3)
(224, 3)
(274, 3)
(327, 3)
(382, 3)


IndexError: index 1 is out of bounds for axis 0 with size 1

In [88]:

# Fitted logistic-regression coefficients, one point per input feature.
plt.plot(np.transpose(lr_fit.coef_))
plt.show()