In [10]:
#import allensdk
import numpy as np
import pandas as pd
#import pprint as pp
#from allensdk.brain_observatory.ecephys.ecephys_project_cache import EcephysProjectCache

from allensdk.brain_observatory.ecephys.ecephys_session import EcephysSession

import scipy
from scipy.fftpack import fft
from scipy import stats, sparse
from scipy.stats import pearsonr
from scipy.cluster.hierarchy import dendrogram, linkage
from scipy.spatial.distance import pdist

# from sklearn.cluster import AgglomerativeClustering
# from sklearn.metrics import pairwise_distances

# from sklearn.preprocessing import scale 
# from sklearn import model_selection
# from sklearn.decomposition import PCA
# from sklearn.linear_model import LinearRegression
# from sklearn.cross_decomposition import PLSRegression, PLSSVD
# from sklearn.metrics import mean_squared_error
# from sklearn.linear_model import LassoCV, Lasso
# from sklearn.model_selection import KFold


#import statsmodels.api as sm#

import matplotlib.pyplot as plt

In [11]:
###################################################
# Elastic net reduced-rank regression

def _flip_signs(v, w):
    """Fix the SVD sign ambiguity: make the largest-|.| entry of every column of
    v positive and apply the same per-column sign to w (w @ v.T is unchanged)."""
    pos = np.argmax(np.abs(v), axis=0)
    flips = np.sign(v[pos, np.arange(v.shape[1])])
    return v * flips, w * flips


def elastic_rrr(X, Y, rank=2, lambdau=1, alpha=0.5, max_iter = 100, verbose=0,
                sparsity='row-wise', tol=1e-6):
    """Elastic-net penalised reduced-rank regression of Y on X.

    Fits Y ~= X @ w @ v.T with orthonormal columns in v, alternating an
    elastic-net regression for w (via glmnet) with an orthogonal Procrustes
    update for v.

    Parameters
    ----------
    X : ndarray, shape (n, p)
        Predictors. glmnet is called with ``standardize=False, intr=False``,
        so no intercept is fitted and X is used as given.
    Y : ndarray, shape (n, q)
        Responses, same number of rows as X.
    rank : int
        Number of latent components.
    lambdau : float
        Penalty strength. In the pure-ridge case (alpha == 0) the ridge
        parameter is ``lambdau * n``.
    alpha : float
        Elastic-net mixing parameter passed to glmnet; alpha == 0 is pure
        ridge and is solved analytically without glmnet.
    max_iter : int
        Maximum number of alternating updates (must be >= 1 when alpha != 0).
    verbose : int
        If > 0, print a convergence / non-convergence message.
    sparsity : str
        'row-wise' fits all components jointly with glmnet's "mgaussian"
        family; any other value fits each column of w separately. Ignored
        when rank == 1.
    tol : float
        Convergence threshold on the change of the relative residual loss
        between consecutive iterations (default 1e-6, as before).

    Returns
    -------
    (w, v) : tuple of ndarrays, shapes (p, rank) and (q, rank)
        Column signs are fixed so the largest-magnitude entry of each column
        of v is positive. If the penalty drives w to all zeros, v is zeros.

    Raises
    ------
    ValueError
        If alpha != 0 and max_iter < 1.

    Notes
    -----
    For alpha != 0 this calls the global ``glmnet`` function, which must be
    imported first (``import glmnet_python; from glmnet import glmnet``).
    """
    # Pure ridge: analytic solution. Compute the full ridge coefficients B,
    # then keep the top `rank` right singular directions of the fit X @ B.
    if alpha == 0:
        U, s, Vt = np.linalg.svd(X, full_matrices=False)
        B = Vt.T @ np.diag(s / (s**2 + lambdau * X.shape[0])) @ U.T @ Y
        _, _, Vt = np.linalg.svd(X @ B, full_matrices=False)
        v = Vt.T[:, :rank]
        w = B @ v
        v, w = _flip_signs(v, w)
        return (w, v)

    if max_iter < 1:
        raise ValueError('max_iter must be >= 1 when alpha != 0')

    # Initialise v with the PLS direction(s): top right singular vectors of X.T @ Y.
    _, _, Vt = np.linalg.svd(X.T @ Y, full_matrices=False)
    v = Vt[:rank, :].T

    loss = np.zeros(max_iter)

    for it in range(max_iter):
        # w-step: elastic-net regression of the projected responses Y @ v on X.
        if rank == 1:
            w = glmnet(x = X.copy(), y = (Y @ v).copy(), alpha = alpha, lambdau = np.array([lambdau]),
                       standardize = False, intr = False)['beta']
        elif sparsity == 'row-wise':
            # Joint multi-response fit; 'beta' is a sequence of per-response
            # coefficient arrays, stacked column-wise into w.
            w = glmnet(x = X.copy(), y = (Y @ v).copy(), alpha = alpha, lambdau = np.array([lambdau]),
                       family = "mgaussian", standardize = False, intr = False,
                       standardize_resp = False)['beta']
            w = np.concatenate(w, axis=1)
        else:
            # Independent elastic-net fit for each component.
            w = np.concatenate(
                [glmnet(x = X.copy(), y = (Y @ v[:, i]).copy(), alpha = alpha,
                        lambdau = np.array([lambdau]), standardize = False,
                        intr = False, standardize_resp = False)['beta']
                 for i in range(rank)],
                axis=1)

        # Everything penalised away: no direction left to estimate.
        if np.all(w == 0):
            v = v * 0
            return (w, v)

        # v-step (orthogonal Procrustes): the orthonormal v maximising
        # trace(v.T @ Y.T @ X @ w) is a @ b from the thin SVD of Y.T @ X @ w.
        a, _, b = np.linalg.svd(Y.T @ X @ w, full_matrices=False)
        v = a @ b
        v, w = _flip_signs(v, w)

        # Relative residual sum of squares of the current fit.
        loss[it] = np.sum((Y - X @ w @ v.T)**2) / np.sum(Y**2)

        if it > 0 and np.abs(loss[it] - loss[it - 1]) < tol:
            if verbose > 0:
                # `it` is 0-based, so it + 1 updates have been performed.
                print('Converged in {} iteration(s)'.format(it + 1))
            break
        if (it == max_iter - 1) and (verbose > 0):
            print('Did not converge. Losses: ', loss)

    return (w, v)

In [2]:
import os

expt_id = '719161530'  # alternative session: '715093703'

# Directory containing the ecephys_session_<id>.nwb files. Set the
# ECEPHYS_DATA_DIR environment variable rather than editing a user-specific
# absolute path; the fallback is the original author's location.
DATA_DIR = os.environ.get('ECEPHYS_DATA_DIR', '/Users/Ram/Dropbox/VC_NP_sklearn_copy')
nwb_path = os.path.join(DATA_DIR, 'ecephys_session_' + expt_id + '.nwb')

# Unit quality-metric thresholds forwarded to the NWB reader via api_kwargs
# (presumably used to filter session.units — confirm in AllenSDK docs).
session = EcephysSession.from_nwb_path(nwb_path, api_kwargs={
        "amplitude_cutoff_maximum": 0.1,
        "presence_ratio_minimum": 0.9,
        "isi_violations_maximum": 0.5
    })

In [4]:
# Distinct brain-structure acronyms among this session's units.
session.units['ecephys_structure_acronym'].unique()

array(['APN', 'DG', 'CA1', 'VISam', 'TH', 'Eth', 'POL', 'LP', 'VISpm',
       'NOT', 'SUB', 'VISp', 'grey', 'VL', 'CA3', 'VISl', 'PO', 'VPM',
       'LGd', 'VISal', 'VISrl'], dtype=object)

In [13]:
# Stimulus table for the 'spontaneous' epochs (start/stop times and durations).
session.get_stimulus_table('spontaneous')

Unnamed: 0_level_0,start_time,stimulus_name,stop_time,duration,stimulus_condition_id
stimulus_presentation_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
0,29.830107,spontaneous,89.896827,60.06672,0.0
3646,1001.891772,spontaneous,1290.883097,288.991326,0.0
3797,1589.382401,spontaneous,1591.133857,1.751456,0.0
3998,2190.634543,spontaneous,2221.660447,31.025904,0.0
21999,2822.161967,spontaneous,2852.187037,30.02507,0.0
31000,3152.437777,spontaneous,3182.462857,30.02508,0.0
31201,3781.963503,spontaneous,4083.215117,301.251614,0.0
49202,4683.716567,spontaneous,4713.741627,30.02506,0.0
49431,5397.312443,spontaneous,5398.313257,1.000814,0.0
51352,5878.714467,spontaneous,5908.739537,30.02507,0.0


In [5]:
stim_name: str = 'drifting_gratings'  # stimulus whose presentations are analysed below


In [6]:
# Presentation table for the chosen stimulus; its index holds the
# stimulus_presentation_ids used further down.
stim_table = session.get_stimulus_table(stim_name)

In [7]:
# Bin width (s).
dt = 0.1
# Mean stimulus duration, truncated to millisecond precision.
stim_durn = int(1000. * np.mean(stim_table.duration.values)) / 1000

# Analysis window (s) — presumably relative to stimulus onset; TODO confirm.
start_time = 0.16
end_time = 1.16
# Times start_time, start_time + dt, ... strictly below end_time (same
# half-open range as np.arange). The count is computed explicitly with a tiny
# tolerance because np.arange with a float step can gain or lose an element
# when (end_time - start_time) / dt lands just off an integer.
n_bins = int(np.ceil((end_time - start_time) / dt - 1e-9))
bef = start_time + dt * np.arange(n_bins)
print(stim_durn, len(bef))

2.001 10


In [None]:
# NOTE(review): this cell used `cstr`, `stim_cond_id` and `bin_edges_full`
# without defining them anywhere, so it raised NameError on a fresh kernel.
# They are now set explicitly; the values are placeholders — TODO confirm.
cstr = 'VISp'  # structure of interest (results stored as `visp_units`)
visp_units = session.units[session.units.ecephys_structure_acronym == cstr].index.values
scid = np.unique(stim_table.stimulus_condition_id.values)

stim_cond_id = scid[0]  # TODO(review): placeholder — choose the condition to analyse
spids = stim_table[stim_table.stimulus_condition_id == stim_cond_id].index.values

# Edges 0, dt, 2*dt, ... up to the mean stimulus duration; presumably relative
# to each presentation's onset — TODO confirm against the intended analysis.
bin_edges_full = dt * np.arange(int(round(stim_durn / dt)) + 1)
tmp_binned_spt = session.presentationwise_spike_counts(
    bin_edges=bin_edges_full,
    stimulus_presentation_ids=spids,
    unit_ids=visp_units,
)

In [8]:
import glmnet_python

In [9]:
from glmnet import glmnet

In [12]:
# Units whose structure acronym is 'POL'.
session.units.loc[session.units['ecephys_structure_acronym'].eq('POL')]

Unnamed: 0_level_0,waveform_amplitude,waveform_repolarization_slope,nn_hit_rate,peak_channel_id,isi_violations,d_prime,waveform_halfwidth,waveform_spread,snr,amplitude_cutoff,...,probe_vertical_position,probe_id,ecephys_structure_id,channel_local_index,probe_horizontal_position,probe_description,location,probe_sampling_rate,probe_lfp_sampling_rate,probe_has_lfp_data
unit_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
950922701,450.21834,1.352459,0.97861,850249351,0.002504,6.411387,0.20603,50.0,5.150346,0.000426,...,440,729445650,1029.0,43,27,probeB,,29999.91888,1249.99662,True
950922531,236.11848,0.618716,0.266667,850249351,0.123768,2.285846,0.302178,80.0,2.901084,0.045209,...,440,729445650,1029.0,43,27,probeB,,29999.91888,1249.99662,True
950928995,183.55038,0.502436,0.027778,850249355,0.241501,1.976461,0.260972,80.0,2.135468,0.062463,...,460,729445650,1029.0,45,11,probeB,,29999.91888,1249.99662,True
950928972,219.10356,0.638081,0.804312,850249355,0.119628,3.190678,0.20603,70.0,2.456872,0.035652,...,460,729445650,1029.0,45,11,probeB,,29999.91888,1249.99662,True
950922872,311.65953,0.902597,0.900383,850249355,0.015899,4.311851,0.233501,60.0,3.571287,0.000328,...,460,729445650,1029.0,45,11,probeB,,29999.91888,1249.99662,True
950922833,166.25388,0.444245,0.203704,850249355,0.0,2.695338,0.247236,80.0,1.965187,0.053612,...,460,729445650,1029.0,45,11,probeB,,29999.91888,1249.99662,True
950922817,270.58785,0.718883,0.783784,850249355,0.069763,2.819402,0.274707,60.0,3.289254,0.004565,...,460,729445650,1029.0,45,11,probeB,,29999.91888,1249.99662,True
950929010,156.596115,0.406059,0.916667,850249353,0.050202,3.959245,0.219765,70.0,1.972336,0.02934,...,460,729445650,1029.0,44,43,probeB,,29999.91888,1249.99662,True
950922775,161.23458,0.462119,0.944444,850249353,0.025253,3.914908,0.219765,60.0,2.227342,0.004268,...,460,729445650,1029.0,44,43,probeB,,29999.91888,1249.99662,True
950922761,281.968635,0.873366,0.981735,850249353,0.00049,5.181204,0.233501,50.0,3.768645,0.000243,...,460,729445650,1029.0,44,43,probeB,,29999.91888,1249.99662,True
