# Network State Index -- API demo

see: https://github.com/yzerlaut/Network_State_Index

We demonstrate the use of the API on the following publicly available dataset:

## the "Visual Coding – Neuropixels" dataset from the Allen Observatory

All details about this dataset and instructions for analysis are available at:

https://allensdk.readthedocs.io/en/latest/visual_coding_neuropixels.html

## Dataset download

I made a [custom script](https://github.com/yzerlaut/Network_State_Index/blob/main/demo/download_Allen_Visual-Coding_dataset.py) to download exclusively the part of the dataset of interest here (V1 probes).
You can run the script as:
```
python demo/download_Allen_Visual-Coding_dataset.py
```

In [1]:
# some general python / scientific-python modules
import os, shutil
import numpy as np
import pandas as pd
import matplotlib.pylab as plt
# The bare 'seaborn' style sheet was renamed 'seaborn-v0_8' in matplotlib 3.6
# (the old name was later removed), so pick whichever name this install provides.
plt.style.use('seaborn-v0_8' if 'seaborn-v0_8' in plt.style.available else 'seaborn')

# the Network State Index API, see: https://github.com/yzerlaut/Network_State_Index
# install it with: "pip install git+https://github.com/yzerlaut/Network_State_Index"
import nsi

In [2]:
# Download and load the data with the "allensdk" API
# get the "allensdk" api with: "pip install allensdk"
from allensdk.brain_observatory.ecephys.ecephys_project_cache import EcephysProjectCache

# Define a local cache directory for the data: by default ~/Downloads/ecephys_cache_dir
# Ensure that you have a "Downloads" directory in your home directory
# (/!\ it may be named differently on non-English systems), or update the path below.
data_directory = os.path.join(os.path.expanduser('~'), 'Downloads', 'ecephys_cache_dir')
manifest_path = os.path.join(data_directory, "manifest.json")
cache = EcephysProjectCache.from_warehouse(manifest=manifest_path)
all_sessions = cache.get_session_table() # table listing all sessions

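### We restrict the analysis to:

    - wild type / wild type strain
    - male
    - recordings including V1 ("VISp")
    - "Functional Connectivity" dataset (session type "functional_connectivity"; the "Brain Observatory 1.1" alternative is left commented out in the code below)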

In [3]:
# Filter the sessions according to the above criteria.
# Every criterion is an index-aligned boolean Series, so they can all be combined with "&"
# (the previous version mixed a plain Python list into the chain).
sessions = all_sessions[
    (all_sessions.sex == 'M')
    # literal substring match; missing genotypes count as "no match"
    & all_sessions.full_genotype.str.contains('wt/wt', regex=False, na=False)
    # session type: replace by 'brain_observatory_1.1' to select the other session type
    & (all_sessions.session_type == 'functional_connectivity')
    # keep the sessions whose list of recorded structures includes V1 ("VISp")
    & all_sessions.ecephys_structure_acronyms.apply(lambda acronyms: 'VISp' in acronyms)
]
separator = 30*'--'
print(separator+'\n--> Number of sessions with the desired characteristics: ' + str(len(sessions))+'\n'+separator)
# sessions.head() # uncomment to see how they look

------------------------------------------------------------
--> Number of sessions with the desired characteristics: 11
------------------------------------------------------------


## Loading, formatting and preprocessing the data

In [16]:
import time

class Data:
    """
    An object to load, format and process the data of one recording session.

    Three loading modes are available (see "__init__"):
        - "demo": loads the sample file provided in the repo
        - "reduced": loads a "reduced_data/Allen_FC_session<N>.npy" file
        - otherwise: fetches the session through the Allen SDK "cache" object

    The "compute_*" / "validate_NSI" methods store their results as attributes
    named after the processed quantity (e.g. "pLFP_NSI", "t_pLFP_vNSI").
    """

    def __init__(self,
                 session_index=0,
                 demo=True, demo_filename='allen_demo_sample.npy',
                 reduced=False,
                 t0=115*60,
                 duration=20*60, # 20 min by default
                 init=()):
        """
        Load data according to the index of the "sessions" table defined above.
        To troubleshoot, there is a "demo" option that loads the sample provided in the repo.

        A subset can be loaded using the "t0" and "duration" args (only used in the
        Allen SDK mode; the "demo" and "reduced" files carry their own values).

        Parameters
        ----------
        session_index : row position in the "sessions" table (0-based)
        demo : if True, load "demo_filename" and ignore the other loading options
        demo_filename : path of the demo file (a pickled dict saved with np.save)
        reduced : if True (and demo is False), load the reduced file of that session
        t0, duration : requested start time and duration; they are clipped to the
            extent of the running-speed timestamps of the session
        init : names of the "compute_<name>" methods to run right after loading,
            e.g. ['pop_act'] (the default runs none)
        """

        if demo:
            # -------- using the stored demo data if "demo" mode ----- #
            # NOTE: allow_pickle=True unpickles the file content, only load trusted files
            DEMO = np.load(demo_filename, allow_pickle=True).item()
            for key in DEMO:
                setattr(self, key, DEMO[key]) # sets LFP
            self.t_LFP = np.arange(len(self.LFP))/self.lfp_sampling_rate+self.t0
        elif reduced:
            # NOTE: allow_pickle=True unpickles the file content, only load trusted files
            rdata = np.load('reduced_data/Allen_FC_session%i.npy' % (session_index+1), allow_pickle=True).item()
            for key in rdata:
                setattr(self, key, rdata[key]) # sets LFP, pLFP, pop_act, running_speed
            # rebuild the time axis of each signal from its length and sampling rate
            for key in ['LFP', 'pLFP', 'pop_act']:
                setattr(self, 't_%s'%key, np.arange(len(getattr(self,key)))/getattr(self,'%s_sampling_rate'%key)+self.t0)
        else:
            # -------------------------------------------------------- #
            # -- using the Allen SDK to retrieve and cache the data -- #
            # -------------------------------------------------------- #

            print('loading session #%i [...]' % (1+session_index))
            tic = time.time()
            # we load a single session
            session = cache.get_session_data(sessions.index.values[session_index])

            # use the running timestamps to set start and duration in the data object
            self.t0 = np.max([t0, session.running_speed.start_time.values[0]])
            self.duration = np.min([duration, session.running_speed.end_time.values[-1]-self.t0])

            # let's fetch the running speed (time = center of each running-speed interval)
            cond = (session.running_speed.end_time.values>self.t0) &\
                (session.running_speed.start_time.values<(self.t0+self.duration))
            self.t_running_speed = .5*(session.running_speed.start_time.values[cond]+\
                                       session.running_speed.end_time.values[cond])
            self.running_speed = session.running_speed.velocity[cond]

            # let's fetch the isolated single units in V1
            V1_units = session.units[session.units.ecephys_structure_acronym == 'VISp'] # V1==VISp
            self.V1_RASTER = []
            for i in V1_units.index:
                cond = (session.spike_times[i]>=self.t0) & (session.spike_times[i]<(self.t0+self.duration))
                self.V1_RASTER.append(session.spike_times[i][cond])

            # let's fetch the V1 probe --> always on "probeC"
            probe_id = session.probes[session.probes.description == 'probeC'].index.values[0]

            # -- let's fetch the lfp data for that probe and that session --
            # let's fetch all the channels falling into the V1 domain
            self.V1_channel_ids = session.channels[(session.channels.probe_id == probe_id) & \
                          (session.channels.ecephys_structure_acronym.isin(['VISp']))].index.values

            # limit LFP to desired times and channels
            # N.B. "get_lfp" returns a subset of all channels above
            self.lfp_slice_V1 = session.get_lfp(probe_id).sel(time=slice(self.t0,
                                                                         self.t0+self.duration),
                                                              channel=slice(np.min(self.V1_channel_ids),
                                                                            np.max(self.V1_channel_ids)))
            self.Nchannels_V1 = len(self.lfp_slice_V1.channel) # store number of channels with LFP in V1
            self.lfp_sampling_rate = session.probes.lfp_sampling_rate[probe_id] # keeping track of sampling rate
            print('data successfully loaded in %.1fs' % (time.time()-tic))

        # run the requested processing steps, e.g. 'pop_act' -> self.compute_pop_act()
        for key in init:
            getattr(self, 'compute_%s' % key)()


    def update_t0_duration(self, t0, duration):
        """
        Return (t0, duration), replacing any None by the value stored on the object.
        """
        t0 = t0 if (t0 is not None) else self.t0
        duration = duration if (duration is not None) else self.duration
        return t0, duration


    def compute_pop_act(self,
                        pop_act_bin=5e-3,
                        pop_act_smoothing=20e-3):
        """
        We bin spikes to compute population activity.

        The spikes of all units in "V1_RASTER" are counted in bins of width
        "pop_act_bin", normalized by (number of units * bin width) and smoothed
        with "nsi.gaussian_filter1d" (width: pop_act_smoothing/pop_act_bin, in bins).

        Sets: "t_pop_act" (bin centers), "pop_act", "pop_act_sampling_rate".
        """
        print(' - computing pop_act from raster [...]')
        t_pop_act = self.t0+np.arange(int(self.duration/pop_act_bin)+1)*pop_act_bin # bin edges
        pop_act = np.zeros(len(t_pop_act)-1)

        for spikes in self.V1_RASTER:
            pop_act += np.histogram(spikes, bins=t_pop_act)[0]
        pop_act /= (len(self.V1_RASTER)*pop_act_bin)

        self.t_pop_act = .5*(t_pop_act[1:]+t_pop_act[:-1])
        self.pop_act = nsi.gaussian_filter1d(pop_act,
                                             int(pop_act_smoothing/pop_act_bin)) # filter from scipy
        self.pop_act_sampling_rate = 1./pop_act_bin
        print(' - - > done !')


    def compute_NSI(self, quantity='pLFP',
                    low_freqs = np.linspace(2, 5, 5),
                    p0_percentile=1.,
                    alpha=2.87,
                    with_subquantities=True):
        """
        ------------------------------
            HERE we use the NSI API
        ------------------------------

        Compute the Network State Index of "quantity" (the name of a signal stored
        on the object, with its "<quantity>_sampling_rate").

        "p0_percentile" is a percentile in [0, 100]: the corresponding value of the
        signal is stored as "<quantity>_0" and passed as "p0" to "nsi.compute_NSI".

        Sets "<quantity>_NSI" and, if "with_subquantities", also
        "<quantity>_low_freq_env" and "<quantity>_sliding_mean".
        """
        print(' - computing NSI for "%s" [...]' % quantity)
        signal = getattr(self, quantity)
        # np.percentile expects q in [0, 100]: pass the percentile as is
        # (dividing by 100 would pick the 0.01th percentile instead of the 1st)
        p0 = np.percentile(signal, p0_percentile)
        setattr(self, '%s_0' % quantity, p0)
        sampling_rate = getattr(self, '%s_sampling_rate' % quantity)

        if with_subquantities:
            lfe, sm, NSI = nsi.compute_NSI(signal,
                                           sampling_rate,
                                           low_freqs = low_freqs,
                                           p0=p0,
                                           alpha=alpha,
                                           with_subquantities=True)
            setattr(self, '%s_low_freq_env' % quantity, lfe)
            setattr(self, '%s_sliding_mean' % quantity, sm)
            setattr(self, '%s_NSI' % quantity, NSI)

        else:
            setattr(self, '%s_NSI' % quantity, nsi.compute_NSI(signal,
                                                              sampling_rate,
                                                              low_freqs = low_freqs,
                                                              p0=p0,
                                                              alpha=alpha))
        print(' - - > done !')

    def validate_NSI(self, quantity='pLFP',
                     Tstate=200e-3,
                     var_tolerance_threshold=None):
        """
        ------------------------------
            HERE we use the NSI API
        ------------------------------

        Run "nsi.validate_NSI" on "<quantity>_NSI" (requires "compute_NSI" first).

        Sets "i_<quantity>_vNSI" (the output of "nsi.validate_NSI", used as an index),
        "t_<quantity>_vNSI" and "<quantity>_vNSI" (times and NSI values at that index).
        """
        print(' - validating NSI for "%s" [...]' % quantity)

        if var_tolerance_threshold is None:
            # by default, the ~noise level stored as "<quantity>_0" by "compute_NSI"
            var_tolerance_threshold = getattr(self, '%s_0' % quantity)

        t = getattr(self, 't_%s' % quantity)
        NSI = getattr(self, '%s_NSI' % quantity)
        vNSI = nsi.validate_NSI(t,
                                NSI,
                                Tstate=Tstate,
                                var_tolerance_threshold=var_tolerance_threshold)

        setattr(self, 'i_%s_vNSI' % quantity, vNSI)
        setattr(self, 't_%s_vNSI' % quantity, t[vNSI])
        setattr(self, '%s_vNSI' % quantity, NSI[vNSI])
        print(' - - > done !')

    def plot(self, quantity,
             t0=None, duration=None,
             ax=None,
             subsampling=1,
             color='k', ms=0,
             lw=1):
        """
        Plot a stored quantity over [t0, t0+duration].

        quantity as a string (e.g. "pLFP" or "running_speed")
        t0, duration: default to the values stored on the object
        ax: axis to draw on, a new figure is created if None

        Returns (fig, ax), with fig=None when "ax" was provided.
        On failure, the error is printed and (None, None) is returned.
        """

        t0, duration = self.update_t0_duration(t0, duration)

        try:
            # derived quantities share the time axis of their parent signal,
            # e.g. "pLFP_NSI" is plotted against "t_pLFP"
            t = getattr(self, 't_'+quantity.replace('_NSI','').replace('_low_freq_env','').replace('_sliding_mean',''))
            signal = getattr(self, quantity)
            # the figure is created only once the data is known to exist
            if ax is None:
                fig, ax = plt.subplots(1, figsize=(8,3))
            else:
                fig = None
            cond = (t>t0) & (t<(t0+duration))
            ax.plot(t[cond][::subsampling], signal[cond][::subsampling], color=color, lw=lw, ms=ms, marker='o')
            return fig, ax
        except Exception as be: # best-effort plot, but let KeyboardInterrupt/SystemExit through
            print(be)
            print('%s not a recognized attribute to plot' % quantity)
            return None, None
        
# a tool to resample a trace onto another time axis (e.g. to compare signals
# sampled at different rates)
def resample_trace(old_t, old_data, new_t):
    """
    Return "old_data" (sampled at times "old_t") evaluated at the times "new_t".

    Uses nearest-neighbour interpolation; times of "new_t" falling outside the
    range of "old_t" are extrapolated.
    """
    # imported here so that the function does not depend on an import
    # performed in a later cell of the notebook
    from scipy.interpolate import interp1d
    func = interp1d(old_t, old_data, kind='nearest', fill_value="extrapolate")
    return func(new_t)


## Channel selection

Many channels are available (~20 in V1); how should we deal with them?

--> simple solution: we pick a single channel, the one with the highest delta envelope in its pLFP. This is a reasonable guess for a channel with a good physiological signal.

In [None]:
# load the first session through the Allen SDK (no demo file)
# and compute the population activity right after loading
data = Data(session_index=0, demo=False, 
            init=['pop_act'])

In [None]:
def find_channel_with_highest_delta(data,
                                    pLFP_band=(40,140),
                                    delta_band=(3,6),
                                    pLFP_subsampling=5):
    """
    Select the channel of "data.lfp_slice_V1" whose pLFP has the highest mean
    low-frequency ("delta_band") envelope.

    For each channel:
        - the LFP is scaled by 1e3 (presumably V -> mV, to be confirmed against the data)
        - the pLFP is the frequency envelope of the LFP over 40 frequencies
          spanning "pLFP_band" (also scaled by 1e3)
        - the delta envelope is the frequency envelope of the subsampled pLFP
          over 5 frequencies spanning "delta_band"

    Returns (pLFP, LFP, channel_id) of the selected channel, where "channel_id" is
    the position of the channel in "data.lfp_slice_V1.channel".
    Returns (None, None, None) if no channel has a positive mean delta envelope.
    """

    channel_mean_delta, channel_id, final_pLFP, final_LFP = 0, None, None, None

    # the frequency grids are identical for all channels
    pLFP_freqs = np.linspace(pLFP_band[0], pLFP_band[1], 40)
    delta_freqs = np.linspace(delta_band[0], delta_band[1], 5)

    for c in range(len(data.lfp_slice_V1.channel.values)):
        # first compute pLFP
        LFP = 1e3*np.array(data.lfp_slice_V1.sel(channel=data.lfp_slice_V1.channel[c]))
        pLFP = 1e3*nsi.compute_freq_envelope(LFP, data.lfp_sampling_rate,
                                             pLFP_freqs)

        # then compute low freq envelope of pLFP (subsampled)
        lf_env = nsi.compute_freq_envelope(pLFP[::pLFP_subsampling],
                                           data.lfp_sampling_rate/pLFP_subsampling,
                                           delta_freqs)

        # keep track of the best channel so far
        mean_delta = np.mean(lf_env)
        if mean_delta>channel_mean_delta:
            channel_mean_delta = mean_delta
            final_pLFP = pLFP
            final_LFP = LFP
            channel_id = c

    # callers unpack the result as: pLFP, LFP, channel_id
    return final_pLFP, final_LFP, channel_id
            
# run the channel selection on the loaded session; the result is unpacked
# into the pLFP, the LFP and the index of the selected channel
pLFP, LFP, channel_id = find_channel_with_highest_delta(data)

In [None]:
# Plot, for a short window, the LFP and pLFP of all V1 channels
# together with the population activity.
t0, duration = data.t0+120, 5
lfp_slice = data.lfp_slice_V1.sel(time=slice(t0,t0+duration))

band = [40,140]
cmap = plt.cm.copper
channel_variance = []  # not used in this cell, kept in case other cells rely on the name

n_channels = len(lfp_slice.channel.values)
# one color per channel; max(., 1) avoids a division by zero with a single channel
channel_colors = [cmap(1-i/max(n_channels-1, 1)) for i in range(n_channels)]

fig, AX = plt.subplots(3, figsize=(10,8))
for i, channel_color in enumerate(channel_colors):
    channel_lfp = lfp_slice.sel(channel=lfp_slice.channel[i])
    AX[0].plot(lfp_slice.time, 1e3*channel_lfp,
               lw=0.3, color=channel_color)
    # N.B. a dedicated name is used here so that the "pLFP" of the selected
    # channel (returned by "find_channel_with_highest_delta") is not overwritten
    channel_pLFP = nsi.compute_freq_envelope(1e3*np.array(channel_lfp),
                                             data.lfp_sampling_rate,
                                             np.linspace(band[0], band[1], 40))

    AX[1].plot(lfp_slice.time, 1e3*channel_pLFP,
               lw=0.3, color=channel_color)

data.plot('pop_act', ax=AX[2], t0=t0, duration=duration, color='g')
AX[2].set_ylabel('pop. rate (Hz)')
AX[0].set_ylabel('LFP (mV)')
AX[1].set_ylabel('pLFP (uV)  %sHz' % band)
AX[1].annotate('selected channel ID: %i' % channel_id,
               (0,1), xycoords='axes fraction', va='top')
# color-coded legend: channel indices written side by side at the top of the first panel
for i, channel_color in enumerate(channel_colors):
    AX[0].annotate((i+1)*'      '+'                 %i' % i, (0,1), xycoords='axes fraction', va='top',
                   color=channel_color)
AX[0].annotate('channel ID:\n(depth-ordered)', (0,1), xycoords='axes fraction', va='top');

In [4]:
### Now we loop over all sessions to get the reduced data with those properties

In [None]:
 
def save_reduced_data(data, session_index, LFP, pLFP, channel_id):
    """
    Save the reduced version of a session as a dictionary in
    "demo/reduced_data/Allen_FC_session<session_index+1>.npy".

    The dictionary stores the start time and duration, the population activity,
    the LFP and pLFP of the selected channel (both with the LFP sampling rate),
    the selected channel index and the running speed.

    The output folder is created if it does not exist.
    """
    new_data = {'t0':data.t0, 'duration':data.duration,
                'pop_act':data.pop_act, 'pop_act_sampling_rate':data.pop_act_sampling_rate,
                'selected_channel_id':channel_id,
                'LFP':LFP, 'LFP_sampling_rate':data.lfp_sampling_rate,
                'pLFP':pLFP, 'pLFP_sampling_rate':data.lfp_sampling_rate,
                't_running_speed':data.t_running_speed, 'running_speed':data.running_speed,
               }
    # NOTE(review): the "reduced" mode of "Data" loads from "reduced_data/..." (no "demo/"
    # prefix) -- confirm from which working directory each of them is meant to be run
    filename = 'demo/reduced_data/Allen_FC_session%i.npy' % (session_index+1)
    # np.save does not create missing folders
    os.makedirs(os.path.dirname(filename), exist_ok=True)
    np.save(filename, new_data)

    
# Loop over all the selected sessions: load each one through the Allen SDK,
# select its channel and save the reduced data to disk.
for session_index in range(len(sessions)):
    
    # full load of the session, the population activity is computed at init
    data = Data(session_index=session_index, demo=False, 
                init=['pop_act'])
    pLFP, LFP, channel_id = find_channel_with_highest_delta(data)
    save_reduced_data(data, session_index, LFP, pLFP, channel_id)
    

In [None]:
# NOTE(review): "Data0" is not defined in the visible part of this notebook, and "Data"
# has no "compute_raster" method for the 'raster' init step -- this line looks like a
# leftover from an earlier version, confirm before running
data0 = Data0(0, t0=110*60, duration=20*60, init=['raster', 'pop_act'])


# plot the running speed of the currently loaded "data" object (not of "data0")
fig, ax = data.plot('running_speed')
ax.set_ylabel('speed (cm/s)')

In [59]:
# The bare 'seaborn' style sheet was renamed 'seaborn-v0_8' in matplotlib 3.6
# (the old name was later removed), so pick whichever name this install provides.
plt.style.use('seaborn-v0_8' if 'seaborn-v0_8' in plt.style.available else 'seaborn')

def plot_sample_data(data,
                     time_points=(100, 1000),
                     duration=2., 
                     pop_act_bin=5e-3):
    """
    Plot sample epochs of the data: one column per element of "time_points",
    each showing the window [t0, t0+duration].

    Rows (top to bottom): raster of the V1 units (only if "data" has a
    "V1_RASTER" attribute), population rate, LFP, pLFP, NSI of the pLFP
    and running speed. All the columns share the same y-limits per row.

    "pop_act_bin" is currently not used.
    """

    has_raster = hasattr(data, 'V1_RASTER')
    nplots = 6 if has_raster else 5
    # squeeze=False: AX_full is always a 2D (nplots, len(time_points)) array,
    # so that a single time point is handled like any other case
    fig, AX_full = plt.subplots(nplots, len(time_points),
                                figsize=(3*len(time_points), 7), squeeze=False)

    # the 'units' label is dropped when there is no raster row
    labels = ['units', 'rate (Hz)', 'LFP (mV)', 'pLFP (uV)', 'NSI (uV)', 'run. speed\n(cm/s)'][-nplots:]

    YLIMS = [[np.inf, -np.inf] for _ in range(nplots)]
    for col, (t0, AX) in enumerate(zip(time_points, AX_full.T)):

        # raster plot
        if has_raster:
            for i, spikes in enumerate(data.V1_RASTER):
                cond = (spikes>t0) & (spikes<(t0+duration))
                AX[0].plot(spikes[cond], i+0*spikes[cond], 'ko', ms=1.5)

        # pop act. plot
        data.plot('pop_act', t0=t0, duration=duration, ax=AX[nplots-5])

        # LFP plot
        data.plot('LFP', t0=t0, duration=duration, ax=AX[nplots-4])

        # pLFP plot
        data.plot('pLFP', t0=t0, duration=duration, ax=AX[nplots-3], color=plt.cm.tab10(5))

        # NSI plot
        data.plot('pLFP_NSI', t0=t0, duration=duration, ax=AX[nplots-2], color=plt.cm.tab10(5))

        # speed plot
        data.plot('running_speed', t0=t0, duration=duration, ax=AX[nplots-1])

        # labelling axes (first column only) and tracking the widest y-limits of each row
        for j, (label, ax) in enumerate(zip(labels, AX)):
            if col == 0:
                ax.set_ylabel(label)
            ax.set_xticks([])
            ax.set_xlim([t0,t0+duration])
            ymin, ymax = ax.get_ylim()
            YLIMS[j] = [np.min([ymin, YLIMS[j][0]]),
                        np.max([ymax, YLIMS[j][1]])]
        AX[0].set_title('$t_0$=%.1fs' % (t0), size='small')

    # setting the same limits across columns
    for AX in AX_full.T:
        for j, ax in enumerate(AX):
            if np.all(np.isfinite(YLIMS[j])):
                ax.set_ylim(YLIMS[j])

    # 200ms time scale bar, drawn on the second row of each column
    for t0, AX in zip(time_points, AX_full.T):
        AX[1].plot([t0,t0+0.2], .9*YLIMS[1][1]*np.ones(2), 'k-', lw=1)
        AX[1].annotate('200ms', (t0, .92*YLIMS[1][1]))

In [None]:
# load the reduced data of the first session and show three sample epochs
# (start times passed as "time_points")
data = Data(session_index=0, demo=False, reduced=True)
plot_sample_data(data, 
                 time_points=[6605.5, 6602.3, 7245]) # session 1 -- functional_connectivity

In [56]:
from scipy.optimize import curve_fit
from scipy.interpolate import interp1d 

def get_accuracy(data,
                 rate_tolerance=2,
                 with_fig=True):
    """
    Compare the NSI computed from the pLFP to the NSI computed from the
    population activity, at the validated pLFP-NSI samples ("data.i_pLFP_vNSI").

    Requires "compute_NSI" for both 'pLFP' and 'pop_act' and
    "validate_NSI" for 'pLFP' to have been run on "data".

    A line is fitted on the samples where both NSI values have the same sign.
    The accuracy is the percentage of validated samples whose rate NSI lies
    within "rate_tolerance" of that line.

    Returns (fig, ax, accuracy) if "with_fig", else accuracy.
    """

    # bring the rate NSI onto the time axis of the pLFP
    pop_act_NSI_resampled = resample_trace(data.t_pop_act,
                                           data.pop_act_NSI, data.t_pLFP)

    x = data.pLFP_NSI[data.i_pLFP_vNSI]
    y = pop_act_NSI_resampled[data.i_pLFP_vNSI]
    # the fit only uses the samples where both NSI have the same (non-zero) sign
    cond = ((x>0) & (y>0)) | ((x<0) & (y<0))

    lin = np.polyfit(x[cond], y[cond], 1)

    accuracy_cond = np.abs(y-np.polyval(lin, x))<rate_tolerance

    accuracy = 100*np.sum(accuracy_cond)/len(y)
    if not with_fig:
        return accuracy

    fig, ax = plt.subplots(1, figsize=(2,2))
    ax.set_title('accuracy=%.1f%%'%accuracy)
    x_fit = np.linspace(x.min(), x.max())
    y_fit = np.polyval(lin, x_fit)
    ax.plot(x, y, 'o', lw=1, ms=1)
    # tolerance band around the fitted line
    ax.fill_between(x_fit, y_fit-rate_tolerance, y_fit+rate_tolerance, color='g', alpha=.3)
    ax.plot(x_fit, y_fit, 'k-')
    # raw strings: "\," and "\mu" are mathtext commands, not Python escape sequences
    ax.set_ylabel(r'NSI$_{\,rate}$ (Hz)')
    ax.set_xlabel(r'NSI$_{\,pLFP}$ ($\mu$V)')
    return fig, ax, accuracy

In [None]:
# load the reduced data of the sixth session (session_index is 0-based)
data = Data(session_index=5, demo=False, reduced=True)
            
# NSI of the pLFP and of the population activity, each with its own parameters
data.compute_NSI(quantity='pLFP',
                 alpha=2.8,
                 p0_percentile=1)
data.compute_NSI(quantity='pop_act',
                 alpha=2,
                 p0_percentile=0)
# the validation is only run on the pLFP-based NSI
data.validate_NSI(quantity='pLFP')

# compare both NSI at the validated samples (also draws the summary figure)
get_accuracy(data)
#ax.set_ylabel('speed (cm/s)')