In [None]:
import numpy as np
import pandas as pd
import os
import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
from swdb_2018_neuropixels.ephys_nwb_adapter import NWB_adapter    

In [None]:
# Root directory holding the ephys manifest CSV and the NWB files loaded below.
drive_path = '/data/dynamic-brain-workshop/visual_coding_neuropixels'

In [None]:
# Read the experiment manifest, keep the multi-probe sessions, and open one of
# them (chosen by position) through the NWB adapter.
manifest_file = os.path.join(drive_path, 'ephys_manifest.csv')
expt_info_df = pd.read_csv(manifest_file)
multi_probe_expt_info = expt_info_df.loc[expt_info_df['experiment_type'] == 'multi_probe']
multi_probe_example = 1  # positional row of multi_probe_expt_info to load
multi_probe_filename = multi_probe_expt_info['nwb_filename'].iloc[multi_probe_example]
nwb_file = os.path.join(drive_path, multi_probe_filename)
data_set = NWB_adapter(nwb_file)

In [None]:
from downsampling_module import downsample_images

In [None]:
# Natural-scene images, loaded from a .npy file resolved relative to the working directory.
nat_scenes = np.load('natural_scenes.npy')

In [None]:
# Downsample every scene to 25 x 25 and rescale the pixel values by 1/255.
ds_nat_scenes = np.array(downsample_images(nat_scenes, 25, 25)) * (1.0 / 255.0)

In [None]:
# Sanity check: print the shape of one downsampled scene and display another.
print(ds_nat_scenes[1].shape)
plt.imshow(ds_nat_scenes[7], cmap='gray')
plt.colorbar()

In [None]:
from downsampling_module import flatten_images

In [None]:
# Presumably one flattened vector per downsampled scene (confirm against flatten_images);
# get_stim_time_array indexes this with the frame id taken from the stimulus table.
flattened_image_list = flatten_images(np.array(ds_nat_scenes))

In [None]:
# Natural-scenes stimulus table. Later cells read its `start` column and its
# first three columns positionally (0: start, 1: end, 2: presumably frame id -- confirm).
stim_table = data_set.get_stimulus_table('natural_scenes')

In [None]:
def get_frame_at_time(time, stim_table):
    """Return the stimulus frame being presented at ``time``.

    Parameters
    ----------
    time : float
        Query time, in the same units as ``stim_table.start``.
    stim_table : pandas.DataFrame
        Presentation table sorted by its ``start`` column. The returned value
        is read from the third column (presumably the frame/image id --
        confirm against the stimulus-table schema).

    Returns
    -------
    The third-column value of the last row whose ``start`` is <= ``time``.
    Times after the last presentation map to the last row.

    Raises
    ------
    ValueError
        If ``time`` precedes the first presentation (or the table is empty).
    """
    starts = stim_table.start.values
    # side='right' makes a time equal to a start map to *that* row. The default
    # side='left' mapped it to the previous row, and mapped the very first start
    # to index -1, i.e. silently to the last presentation.
    idx = np.searchsorted(starts, time, side='right') - 1
    if idx < 0:
        raise ValueError(
            "time {} precedes the first stimulus presentation".format(time))
    return stim_table.iloc[idx].values[2]

In [None]:
def get_stim_time_array(stim_table, tns_start, tns_end, bin_len, flattened_image_list):
    """Sample the presented image on an evenly spaced time grid.

    ``floor((tns_end - tns_start) / bin_len)`` sample times are spread with
    ``np.linspace`` over the closed interval [tns_start, tns_end]. For each
    sample time the frame shown is looked up with ``get_frame_at_time``, and
    the corresponding entry of ``flattened_image_list`` is collected.

    NOTE(review): because linspace includes both endpoints, the actual spacing
    is (tns_end - tns_start) / (n - 1), not exactly ``bin_len``.

    Returns
    -------
    (stim_array, time_array)
        ``stim_array`` is a list with one flattened image per sample time;
        ``time_array`` is the array of sample times.
    """
    n_samples = int(np.floor((tns_end - tns_start) / bin_len))
    time_array = np.linspace(tns_start, tns_end, n_samples)
    frame_ids = [get_frame_at_time(t, stim_table) for t in time_array]
    stim_array = [flattened_image_list[int(frame)] for frame in frame_ids]
    return (stim_array, time_array)

In [None]:
bin_len = 0.001  # bin width, in stim_table time units (presumably seconds)
num_stim_rows = 200  # number of presentations (stim_table rows) to analyse
# The window runs from the start of the first presentation to the end of
# presentation number `num_stim_rows`, i.e. row index num_stim_rows - 1.
# Indexing row `num_stim_rows` itself would include one extra presentation.
tns_start = stim_table.iloc[0].values[0]
tns_end = stim_table.iloc[num_stim_rows - 1].values[1]
print(tns_start)
print(tns_end)

In [None]:
# Sample the stimulus on the analysis window, then echo the window edges.
stim_array, time_array = get_stim_time_array(
    stim_table, tns_start, tns_end, bin_len, flattened_image_list)
print(tns_start, tns_end, sep='\n')

In [None]:
# Shapes of the sampled stimulus (samples x pixels) and of the time grid.
print(np.shape(stim_array), np.shape(time_array), sep='\n')


In [None]:
# Re-anchor the analysis window on the sampled stimulus grid so that spike
# binning below uses exactly the same start/end as time_array.
tns_start = time_array[0]
tns_end = time_array[-1]
print(tns_start, tns_end, time_array.shape, sep='\n')


In [None]:
def bin_spikes(data_set,bin_len,t_start,t_final,probes=None,regions=None):
    """Count spikes per unit in fixed-width time bins.

    Parameters
    ----------
    data_set : object
        Must expose ``probe_list``, ``unit_df`` (with ``probe``, ``structure``
        and ``unit_id`` columns) and ``spike_times[probe][unit_id]``.
    bin_len : float
        Bin width, in the same units as the spike times.
    t_start, t_final : float
        ``T = floor((t_final - t_start) / bin_len)`` bins are laid out starting
        at ``t_start``; bin j covers [t_start + j*bin_len, t_start + (j+1)*bin_len).
        Spikes outside these bins are ignored.
    probes, regions : iterable or None
        Keep only units on one of these probes AND in one of these structures.
        ``None`` means every probe in ``data_set.probe_list`` / every structure
        present in ``data_set.unit_df``. An empty iterable selects no units.

    Returns
    -------
    (binned_spikes, cell_table)
        ``binned_spikes`` is a float array of shape (N, T). ``binned_spikes[i, j]``
        is the spike count of the i-th row of ``cell_table`` in bin j.
    """
    unit_df = data_set.unit_df
    if probes is None:
        probes = data_set.probe_list
    if regions is None:
        regions = unit_df.structure.unique()

    # Same selection as OR-ing (probe == p) & (structure == r) over all pairs.
    # It also yields an empty table, instead of unit_df[False] raising KeyError,
    # when probes/regions is empty.
    use_cells = unit_df.probe.isin(list(probes)) & unit_df.structure.isin(list(regions))
    cell_table = unit_df[use_cells]

    N = len(cell_table)                            # number of units
    T = int(np.floor((t_final - t_start) / bin_len))  # number of time bins
    binned_spikes = np.zeros((N, T))

    for i, (probe, unit_id) in enumerate(zip(cell_table['probe'], cell_table['unit_id'])):
        spikes = np.asarray(data_set.spike_times[probe][unit_id], dtype=float)
        bins = np.floor((spikes - t_start) / bin_len).astype(int)
        bins = bins[(bins >= 0) & (bins < T)]
        # bincount over in-range bin indices == one increment per spike.
        binned_spikes[i] = np.bincount(bins, minlength=T)
    return (binned_spikes, cell_table)

In [None]:
# Bin spikes of all VISp units (on every probe) over the window set from time_array.
(binned_spikes, cell_table) = bin_spikes(data_set,bin_len,tns_start,tns_end,regions=['VISp'])

In [None]:
# (number of units, number of time bins)
binned_spikes.shape

In [None]:
# Keep only the first 10 units (presumably to keep the GLM design matrix small).
reduced_binned_spikes = binned_spikes[:10,:]

In [None]:
# (10 units -- or fewer if fewer were selected, number of time bins)
reduced_binned_spikes.shape

In [None]:
import keras
from keras import backend as K
from keras.models import Model
from keras.layers import Input
from keras import Sequential
from keras.layers import Dense, Lambda
from keras.regularizers import Regularizer
from keras.callbacks import ModelCheckpoint

def GLM_network_fit(stimulus,spikes,d_stim, d_spk,bin_len,f='exp',priors=None,L1=None,neurons=None):
    """Fit one Poisson GLM per neuron with stimulus and spike-history filters.

    For each fitted neuron i the model is a single Keras Dense unit followed
    by ``rate = link(X @ w + b) * bin_len``. It is trained with Poisson loss
    (Adam, lr=0.5, 5 epochs) on the lagged design matrix from
    ``construct_GLM_mat`` to predict ``spikes[i, max(d_stim, d_spk):]``.

    Parameters
    ----------
    stimulus : ndarray, shape (M, T)
        Flattened stimulus, one column per time bin.
    spikes : ndarray, shape (N, T)
        Binned spike counts.
    d_stim, d_spk : int
        Number of stimulus / spike-history lags.
    bin_len : float
        Bin width; scales the predicted rate.
    f : str
        Link nonlinearity; only ``'exp'`` is supported.
    priors, L1 : unused
        Accepted for backward compatibility; currently ignored.
    neurons : iterable of int or None
        Indices of the neurons to fit. ``None`` fits all N neurons; pass
        e.g. ``[0]`` for a quick single-neuron fit.

    Returns
    -------
    (F, W, b)
        F : (N, M, d_stim) stimulus filters.
        W : (N, N, d_spk) spike-history filters (row i, presynaptic neuron j).
        b : (N,) biases.
        Rows of neurons that were not fitted are NaN rather than
        uninitialised memory.

    Raises
    ------
    ValueError
        If ``f`` is not a supported link name.
    """
    N = spikes.shape[0]
    print("N", N)
    M = stimulus.shape[0]
    print("M", M)
    F = np.full((N,M,d_stim), np.nan)  # stimulus filters
    W = np.full((N,N,d_spk), np.nan)   # spike train filters
    b = np.full((N,), np.nan)          # biases
    fs = {'exp':K.exp}
    if f not in fs:
        raise ValueError("unsupported link function {!r}; expected one of {}".format(f, sorted(fs)))
    link = fs[f]
    d_max = max(d_stim,d_spk)
    n_rows = spikes.shape[1] - d_max
    # Only the first T - d_max rows of the design matrix correspond to a target
    # y[t] = spikes[i, t + d_max]; trim anything beyond that so X and y align.
    Xdsn = construct_GLM_mat(stimulus,spikes, d_stim, d_spk)[:n_rows]
    if neurons is None:
        neurons = range(N)
    for i in neurons:
        y = spikes[i,d_max:]
        # construct GLM model and return fit
        model = Sequential()
        model.add(Dense(1,input_dim = Xdsn.shape[1],use_bias=True))
        model.add(Lambda(lambda x: link(x)*bin_len))
        model.compile(loss = 'poisson',optimizer = keras.optimizers.adam(lr=5e-1))
        model.fit(x=Xdsn,y=y,epochs=5,verbose=1)
        kernel, bias = model.get_weights()
        # Kernel layout follows the design matrix: M*d_stim stimulus weights,
        # then N*d_spk spike-history weights.
        F[i,:,:] = kernel[:M*d_stim].reshape((M,d_stim))
        W[i,:,:] = kernel[M*d_stim:].reshape((N,d_spk))
        b[i] = bias[0]  # bias has shape (1,); store the scalar
    return (F,W,b)

In [None]:
#kernel_regularizer=SparseGroupLasso(M*d_stim,d_spk,lgroup=0))

In [None]:
def construct_GLM_mat(flat_stimulus, binned_spikes, d_stim, d_spk):
    """Build the lagged design matrix for a stimulus + spike-history GLM.

    Row t describes the history preceding time bin ``t + d_max``, where
    ``d_max = max(d_stim, d_spk)``. It contains:

    * the stimulus over bins [t + d_max - d_stim, t + d_max), most recent bin
      first, flattened stimulus-dimension-major (M * d_stim entries);
    * the spikes of every neuron over bins [t + d_max - d_spk, t + d_max),
      most recent bin first, flattened neuron-major (N * d_spk entries).

    Parameters
    ----------
    flat_stimulus : ndarray, shape (M, T)
    binned_spikes : ndarray, shape (N, T)
    d_stim, d_spk : int
        Number of stimulus / spike-history lags.

    Returns
    -------
    ndarray, shape (max(T - d_max, 0), M*d_stim + N*d_spk)

    Raises
    ------
    ValueError
        If the stimulus and spike arrays cover different numbers of time bins.
    """
    N, T = binned_spikes.shape       # N neurons, T time bins
    M, T_stim = flat_stimulus.shape  # M stimulus dimensions
    if T_stim != T:
        raise ValueError(
            "stimulus has {} time bins but spikes have {}".format(T_stim, T))
    d_max = max(d_stim, d_spk)
    # Exactly one row per bin that has a full history; the previous version
    # allocated T - d_stim rows and left the surplus uninitialised.
    n_rows = max(T - d_max, 0)
    X_dsn = np.empty((n_rows, M*d_stim + N*d_spk))
    for t in range(n_rows):
        end = t + d_max
        X_dsn[t, :M*d_stim] = np.fliplr(flat_stimulus[:, end-d_stim:end]).ravel()  # stimulus inputs
        X_dsn[t, M*d_stim:] = np.fliplr(binned_spikes[:, end-d_spk:end]).ravel()    # spike inputs
    return X_dsn


In [None]:
from keras.regularizers import Regularizer
from keras import backend as K

class SparseGroupLasso(Regularizer):
    """Group-lasso penalty on the spike-history part of a Dense kernel.

    The first ``size_stim`` kernel entries (the stimulus weights) are not
    penalised. The remaining entries are reshaped into rows of length
    ``d_spike``, one group per row. When the kernel is laid out as
    ``[stimulus weights, N * d_spike history weights]``, each row is one
    presynaptic neuron. The penalty is::

        lgroup * sqrt(d_spike) * sum over groups of ||group||_2

    Despite the class name, no element-wise (L1) term is included.

    # Arguments
        size_stim: Int; number of leading (stimulus) kernel entries to skip.
        d_spike: Int; group length (spike-history lags per presynaptic neuron).
        lgroup: Float; group regularization factor.
    """

    def __init__(self, size_stim, d_spike, lgroup = 1.):
        self.lgroup = K.cast_to_floatx(lgroup)
        self.d_spike = d_spike
        self.size_stim = size_stim

    def __call__(self, x):
        groups = K.reshape(x[self.size_stim:], (-1, self.d_spike))
        # NOTE(review): the derivative of sqrt is unbounded at 0, so a group that
        # reaches exactly zero can yield inf/NaN gradients; consider adding K.epsilon().
        group_norms = K.sqrt(K.sum(K.square(groups), axis=1))
        # sqrt(group size) weighting, as in the standard group-lasso penalty.
        return self.lgroup * np.sqrt(self.d_spike) * K.sum(group_norms)

    def get_config(self):
        """Return the constructor arguments so Keras can serialise/restore this regularizer."""
        return {'size_stim': self.size_stim,
                'd_spike': self.d_spike,
                'lgroup': float(self.lgroup)}


In [None]:
# Fit the GLM using 20-bin stimulus and spike-history filters; the stimulus is
# transposed to (pixels, time bins) to match the spike matrix layout.
F, W, b = GLM_network_fit(np.array(stim_array).T, reduced_binned_spikes,
                          d_stim=20, d_spk=20, bin_len=bin_len)

In [None]:
import h5py

In [None]:
# NOTE(review): 'weights.hdf5' is only written by the ModelCheckpoint callback,
# which is commented out in GLM_network_fit, so on a fresh kernel this file may
# not exist (or may be stale). The handle is never closed either; prefer
# `with h5py.File('weights.hdf5', 'r') as hf:`.
hf = h5py.File('weights.hdf5', 'r')

In [None]:
# NOTE(review): Keras auto-numbers layer names per session, so 'dense_4' depends
# on how many Dense layers were built before saving; this path is hidden state.
hf['model_weights/dense_4/dense_4'].values()

In [None]:
# Spike-history filters of neuron 0: (N presynaptic neurons, d_spk lags).
W[0].shape

In [None]:
# L2 norm over lags of each presynaptic filter of neuron 0 (one value per neuron).
norms = np.linalg.norm(W[0], axis = 1)

In [None]:
# (N,)
norms.shape

In [None]:
# Distribution of per-presynaptic-neuron filter norms for neuron 0.
plt.hist(norms, bins = 50);

In [None]:
# Spike-history filter of neuron 0 with respect to presynaptic neuron 1 (one
# weight per lag). Named `spike_filter` to avoid shadowing the builtin `filter`.
spike_filter = W[0][1]
plt.plot(spike_filter)
print("norm", np.linalg.norm(spike_filter))



In [None]:
# All spike-history filters of neuron 0 as an (N, d_spk) matrix (same as W[0]).
filtermat = W[0,:,:]

In [None]:
# (N, d_spk)
filtermat.shape

In [None]:
# Weight at lag index 4, plotted across presynaptic neurons.
plt.plot(filtermat[:,4])

In [None]:
# Per-lag L2 norm taken across presynaptic neurons (axis=0), i.e. one value per lag.
plt.hist(np.linalg.norm(filtermat, axis = 0), bins = 10)

In [None]:
# def construct_GLM_mat(flat_stimulus, binned_spikes, i, d_stim, d_spk):
#     (N,T) = binned_spikes.shape # N is number of neurons, T is number of time bins
#     print("T",T)
#     (M,T) = flat_stimulus.shape # M is the size of a stimulus
#     X_dsn = np.empty((T-d_stim+1,M*d_stim+N*d_spk))
#     d_max = max(d_stim,d_spk)
#     y = np.empty((T-d_max+1,))
#     for t in range(T-d_max+1):
#         y[t] = binned_spikes[i,t+d_max-1]
#         X_dsn[t,:M*d_stim] = flat_stimulus[:,t+d_max-d_stim:t+d_max].reshape((1,-1))
#         X_dsn[t,M*d_stim:] = binned_spikes[:,t+d_max-d_spk:t+d_max].reshape((1,-1))
#     return (y, X_dsn)   

In [None]:
# import keras
# from keras import backend as K
# from keras.models import Model
# from keras.layers import Input
# from keras import Sequential
# from keras.layers import Dense, Lambda
# from keras.regularizers import Regularizer
# def GLM_network_fit(stimulus,spikes,d_stim, d_spk,bin_len,f='exp',priors=None,L1=None):
#     N = spikes.shape[0]
#     print("N", N)
#     M = stimulus.shape[0]
#     print("M", M)
#     F = np.empty((N,M,d_stim)) # stimulus filters
#     W = np.empty((N,N,d_spk))  # spike train filters
#     b = np.empty((N,)) # biases
#     fs = {'exp':K.exp}
#     for i in range(1):
#         [y, Xdsn] = construct_GLM_mat(np.array(stim_array), binned_spikes, i, d_stim, d_spk)
#         print("yshape",y.shape)
#         model = Sequential()
#         model.add(Dense(1,input_dim = Xdsn.shape[1],use_bias=True, kernel_regularizer=SparseGroupLasso(M*d_stim,d_spk,lgroup=1e-10)))
#         model.add(Lambda(lambda x: fs[f](x)*bin_len))
#         model.compile(loss = 'poisson',optimizer = keras.optimizers.adam(lr=5e-1))
#         model.fit(x=Xdsn,y=y,epochs=50, batch_size = 1000,  verbose=1)
#         p = model.get_weights()[0]
#         print("pshape", p.shape)
#         print("Mdstim", M*d_stim)
#         F[i,:,:] = p[:M*d_stim].reshape((M,d_stim))
#         W[i,:,:] = p[M*d_stim:].reshape((N,d_spk))
#         b[i] = model.get_weights()[1]
#     return (F,W,b)
