# Network State Index -- API demo

see: https://github.com/yzerlaut/Network_State_Index

We demonstrate the use of the API on the following publicly available dataset:

## the "Visual Coding – Neuropixels" dataset from the Allen Observatory

All details about this dataset and instructions for analysis are available at:

https://allensdk.readthedocs.io/en/latest/visual_coding_neuropixels.html

## Dataset download

I made a [custom script](https://github.com/yzerlaut/Network_State_Index/blob/main/demo/download_Allen_Visual-Coding_dataset.py) to download exclusively the part of the dataset of interest here (V1 probes).
You can run the script as:
```
python demo/download_Allen_Visual-Coding_dataset.py
```

In [1]:
# some general python / scientific-python modules
import os, shutil
import numpy as np
import pandas as pd
import matplotlib.pylab as plt

# the Network State Index API, see: https://github.com/yzerlaut/Network_State_Index
# install it with: "pip install git+https://github.com/yzerlaut/Network_State_Index"
import nsi

In [2]:
# the data is downloaded and loaded through the "allensdk" API
# (install it with: "pip install allensdk")
from allensdk.brain_observatory.ecephys.ecephys_project_cache import EcephysProjectCache

# local cache repository for the data, by default: ~/Downloads/ecephys_cache_dir
# /!\ make sure that a "Downloads" folder exists in your home directory (its name
#     may differ on non-English systems), or change the path below
data_directory = os.path.expanduser(os.path.join('~', 'Downloads', 'ecephys_cache_dir'))
manifest_path = os.path.join(data_directory, "manifest.json")
cache = EcephysProjectCache.from_warehouse(manifest=manifest_path)
all_sessions = cache.get_session_table() # table listing all the sessions

### We restrict the analysis to:

- wild type / wild type strain
- male
- recordings including V1 ("VISp")
- the "Functional Connectivity" sessions (the "Brain Observatory 1.1" sessions can be selected instead in the filter below)

In [3]:
# let's filter the sessions according to the above criteria
# N.B. the 'functional_connectivity' session type is selected here; use
#      'brain_observatory_1.1' instead to restrict to the "Brain Observatory 1.1" sessions
sessions = all_sessions[(all_sessions.sex == 'M') &
                        (all_sessions.full_genotype.str.find('wt/wt') > -1) &
                        (all_sessions.session_type == 'functional_connectivity') &
                        # index-aligned mask: does the session include V1 ("VISp") ?
                        all_sessions.ecephys_structure_acronyms.apply(lambda acronyms: 'VISp' in acronyms)]
print('%s\n--> Number of sessions with the desired characteristics: %i\n%s' % (30*'--', len(sessions), 30*'--'))
# sessions.head() # uncomment to see how they look

------------------------------------------------------------
--> Number of sessions with the desired characteristics: 11------------------------------------------------------------


## Minimal demo

In [None]:
# minimal use of the API on a synthetic signal: 10 s of white noise sampled at 1 kHz
sampling_rate, new_dt = 1e3, 2e-3 # in Hz and s
LFP = np.random.randn(int(10*sampling_rate))
# "nsi.compute_pLFP" takes the signal AND its sampling rate, and returns (time, pLFP)
t_pLFP, pLFP = nsi.compute_pLFP(LFP, sampling_rate,
                                freqs=np.linspace(50, 300, 10),
                                new_dt=new_dt,
                                smoothing=42e-3)
# the pLFP is resampled at 1/new_dt
NSI = nsi.compute_NSI(pLFP, 1./new_dt,
                      low_freqs=np.linspace(2, 5, 5),
                      p0=np.percentile(pLFP, 1.), # 1st percentile of the pLFP
                      alpha=2.87)

## In a more structured way

In [None]:
# let's do things in a bit more structured way
import time

class Data:
    """
    an object to load, format and process the data of a single session
    (spikes, population activity, LFP, pLFP, NSI and running speed)

    relies on the "cache" (allensdk cache) and "sessions" (filtered session
    table) objects defined in the cells above
    """
    
    def __init__(self, session_index,
                 t0=120*60, 
                 duration=1*60, # 1 min by default
                 init=None):
        """
        loading data according to the index of the "sessions" above

        session_index: position of the session in the "sessions" table
        t0, duration: time window to load (in s), clipped to the time range
                      covered by the running-speed timestamps
        init: list of quantities computed at loading, among
              'raster', 'pop_act', 'pLFP', 'NSI' (None: nothing), e.g.
              init=['raster', 'pop_act', 'pLFP', 'NSI'] for a full init
              N.B. 'pop_act' needs 'raster' and 'NSI' needs 'pLFP'
        """
        init = [] if init is None else init
        print('loading session #%i [...]' % (1+session_index))
        tic = time.time()
        # we load a single session
        session = cache.get_session_data(sessions.index.values[session_index])

        # use the running timestamps to set start and duration in the data object
        self.t0 = np.max([t0, session.running_speed.start_time.values[0]])
        self.duration = np.min([duration, session.running_speed.end_time.values[-1]-self.t0])
        
        # let's fetch the running speed (time taken at the center of each interval)
        cond = (session.running_speed.end_time.values>self.t0) &\
            (session.running_speed.start_time.values<(self.t0+self.duration))
        self.t_running_speed = .5*(session.running_speed.start_time.values[cond]+\
                                   session.running_speed.end_time.values[cond])
        self.running_speed = session.running_speed.velocity[cond]

        if 'raster' in init:
            # let's fetch the isolated single units in V1
            V1_units = session.units[session.units.ecephys_structure_acronym == 'VISp'] # V1==VISp
            self.V1_RASTER = []
            for i in V1_units.index:
                cond = (session.spike_times[i]>=self.t0) & (session.spike_times[i]<(self.t0+self.duration))
                self.V1_RASTER.append(session.spike_times[i][cond])
        else:
            self.V1_RASTER = None
        
        # let's fetch the V1 probe --> always on "probeC"
        probe_id = session.probes[session.probes.description == 'probeC'].index.values[0]
        
        # -- let's fetch the lfp data for that probe and that session --
        # let's fetch the all the channels falling into V1 domain
        self.V1_channel_ids = session.channels[(session.channels.probe_id == probe_id) & \
                      (session.channels.ecephys_structure_acronym.isin(['VISp']))].index.values

        # limit LFP to desired times and channels
        # N.B. "get_lfp" returns a subset of all channels above
        self.lfp_slice_V1 = session.get_lfp(probe_id).sel(time=slice(self.t0,
                                                                     self.t0+self.duration),
                                                          channel=slice(np.min(self.V1_channel_ids), 
                                                                        np.max(self.V1_channel_ids)))
        self.Nchannels_V1 = len(self.lfp_slice_V1.channel) # store number of channels with LFP in V1
        self.lfp_sampling_rate = session.probes.lfp_sampling_rate[probe_id] # keeping track of sampling rate
        
        if 'pop_act' in init:
            # let's compute the population activity from the spikes
            self.compute_pop_act() # you can recall this functions with different bins/smoothing        
            
        if 'pLFP' in init:
            self.compute_pLFP(t0=self.t0, duration=self.duration) # on the full trace
            if 'NSI' in init:
                self.compute_NSI()
                
        print('data successfully loaded in %.1fs' % (time.time()-tic))

    def update_t0_duration(self, t0, duration):
        """
        return (t0, duration), each replaced by the loaded window value
        ("self.t0" / "self.duration") when given as None
        """
        t0 = t0 if (t0 is not None) else self.t0
        duration = duration if (duration is not None) else self.duration
        return t0, duration
        
    def compute_pop_act(self, 
                        pop_act_bin=5e-3,
                        pop_act_smoothing=20e-3):
        """
        we bin the spikes of all V1 units to compute the population activity

        pop_act_bin: bin width (in s)
        pop_act_smoothing: width (in s) of the gaussian smoothing, converted
                           to a (truncated) integer number of bins

        sets "t_pop_act" (bin centers, in s), "pop_act" (mean rate per unit,
        in Hz) and "pop_act_sampling_rate" (in Hz, i.e. 1/pop_act_bin)

        raises ValueError if no spikes were loaded ("raster" not in "init")
        or if there is no V1 unit
        """
        if not self.V1_RASTER:
            raise ValueError('no V1 spikes to compute the population activity (load the data with "raster" in "init")')

        bins = self.t0+np.arange(int(self.duration/pop_act_bin)+1)*pop_act_bin
        counts = np.zeros(len(bins)-1)
        for spikes in self.V1_RASTER:
            counts += np.histogram(spikes, bins=bins)[0]
        # spike count -> mean firing rate per unit (Hz)
        pop_act = counts/(len(self.V1_RASTER)*pop_act_bin)

        self.t_pop_act = .5*(bins[1:]+bins[:-1])
        # a rate in Hz, consistent with "pLFP_sampling_rate" (read by "compute_NSI")
        self.pop_act_sampling_rate = 1./pop_act_bin
        self.pop_act = nsi.gaussian_filter1d(pop_act, 
                                             int(pop_act_smoothing/pop_act_bin)) # filter from scipy
        
    def get_LFP(self, channelID,
                ground_channel_ID=0,
                t0=None, 
                duration=None):
        """
        fetch the LFP of one channel, or the average over several channels,
        optionally re-referenced to a "ground" channel

        channelID: an integer (channel position in "lfp_slice_V1") or a
                   range/list/tuple/numpy array of integers (averaged)
        ground_channel_ID: channel position subtracted from the signal
                           (None: no re-referencing)
        t0, duration: time window (in s), default to the loaded window

        sets and returns "t_LFP" (time in s) and "LFP" (LFP in mV)
        """
        t0, duration = self.update_t0_duration(t0, duration)
        lfp = self.lfp_slice_V1.sel(time=slice(t0, t0+duration))

        if np.ndim(channelID)>0: # a collection of channels (range, list, tuple, array)
            if len(channelID)==0:
                raise ValueError('"channelID" should contain at least one channel index')
            self.t_LFP = lfp.isel(channel=channelID[0]).time.values
            self.LFP = np.mean([1e3*lfp.isel(channel=i).values for i in channelID], axis=0)
        else:
            self.t_LFP = lfp.isel(channel=channelID).time.values
            self.LFP = 1e3*lfp.isel(channel=channelID).values
            
        if ground_channel_ID is not None:
            self.LFP = self.LFP - 1e3*lfp.isel(channel=ground_channel_ID).values
            
        return self.t_LFP, self.LFP
    
    def compute_pLFP(self, 
                     channelIDs=None, ground_channel_ID=0,
                     t0=None, duration=None,
                     freqs = np.linspace(72.8/1.83,72.8*1.83,20),
                     new_dt=2e-3,
                     smoothing=42e-3):
        """
        ------------------------------
            HERE we use the NSI API
        ------------------------------
        compute the pLFP of each channel in "channelIDs" (default: all the V1
        channels) with "nsi.compute_pLFP" and average them

        t0, duration: time window (in s), default to the full loaded window
            -> e.g. "data.compute_pLFP(t0=..., duration=3)" for a zoom
        ground_channel_ID: passed to "get_LFP"
        freqs, new_dt, smoothing: passed to "nsi.compute_pLFP"

        sets "t_pLFP" (in s), "pLFP" (in microvolts, uV: the LFP in mV is
        scaled by 1e3) and "pLFP_sampling_rate" (=1/new_dt)
        N.B. "LFP"/"t_LFP" are left set to the last channel (via "get_LFP")
        """
        print(' - computing pLFP [...]') 
        self.pLFP_sampling_rate = 1./new_dt
        
        t0, duration = self.update_t0_duration(t0, duration)
        
        if channelIDs is None:
            channelIDs = np.arange(self.Nchannels_V1) # using all channels
        if len(channelIDs)==0:
            raise ValueError('"channelIDs" should contain at least one channel index')

        t_pLFP, pLFP_mean = None, 0.
        for i in channelIDs:
            t, pLFP = nsi.compute_pLFP(self.get_LFP(channelID=i,
                                                    ground_channel_ID=ground_channel_ID,
                                                    t0=t0, duration=duration)[1],
                                       self.lfp_sampling_rate,
                                       freqs=freqs, 
                                       new_dt=new_dt,
                                       smoothing=smoothing)
            if t_pLFP is None:
                t_pLFP = t+t0 # time axis taken from the first channel
            # channel average, with the same mV -> uV scaling for every channel
            pLFP_mean = pLFP_mean + 1e3*pLFP/len(channelIDs)
        self.t_pLFP, self.pLFP = t_pLFP, pLFP_mean
        
        print(' - - > done !') 
        
    def compute_NSI(self, quantity='pLFP',
                    low_freqs = np.linspace(2, 5, 5),
                    p0_percentile=1.,
                    alpha=2.87,
                    with_subquantities=False):
        """
        compute the NSI of a stored quantity with "nsi.compute_NSI"

        quantity: name of the attribute to use (e.g. 'pLFP'); the attribute
                  '<quantity>_sampling_rate' must exist as well
        p0_percentile: percentile (in %, between 0 and 100) of the quantity
                       used as "p0", stored as '<quantity>_0'
        low_freqs, alpha: passed to "nsi.compute_NSI"
        with_subquantities: also store the low-frequency envelope and the
                            sliding mean as '<quantity>_low_freq_env' and
                            '<quantity>_sliding_mean'

        the NSI is stored as '<quantity>_NSI'
        """
        signal = getattr(self, quantity)
        # np.percentile expects a percentage in [0, 100], not a fraction
        p0 = np.percentile(signal, p0_percentile)
        setattr(self, '%s_0' % quantity, p0)
        sampling_rate = getattr(self, '%s_sampling_rate' % quantity)
        
        if with_subquantities:
            lfe, sm, NSI = nsi.compute_NSI(signal,
                                           sampling_rate,
                                           low_freqs = low_freqs,
                                           p0=p0,
                                           alpha=alpha,
                                           with_subquantities=True)
            setattr(self, '%s_low_freq_env' % quantity, lfe)
            setattr(self, '%s_sliding_mean' % quantity, sm)
        else:
            NSI = nsi.compute_NSI(signal,
                                  sampling_rate,
                                  low_freqs = low_freqs,
                                  p0=p0,
                                  alpha=alpha)
        setattr(self, '%s_NSI' % quantity, NSI)
        
    def plot(self, quantity, 
             t0=None, duration=None,
             ax=None,
             subsampling=1,
             color='k', 
             lw=1):
        """
        plot a stored quantity within [t0, t0+duration] (default: loaded window)

        quantity as a string (e.g. "pLFP" or "running_speed"); its time axis
        is the attribute 't_<quantity>' (with the '_NSI', '_low_freq_env' and
        '_sliding_mean' suffixes removed)

        returns (fig, ax), fig being None when an "ax" is provided;
        prints a message (and returns None) if the quantity does not exist
        """
        t0, duration = self.update_t0_duration(t0, duration)
        
        if hasattr(self, quantity):
            if ax is None:
                fig, ax =plt.subplots(1, figsize=(8,3))
            else:
                fig = None
            t = getattr(self, 't_'+quantity.replace('_NSI','').replace('_low_freq_env','').replace('_sliding_mean',''))
            signal = getattr(self, quantity)
            cond = (t>t0) & (t<(t0+duration))
            ax.plot(t[cond][::subsampling], signal[cond][::subsampling], color=color, lw=lw)
            return fig, ax
        else:
            print('%s not an attribute of data' % quantity)

In [None]:
# load 20 min of the first session, starting at t=110 min; the pLFP and its NSI
# are computed at loading because the "plot_sample_data" figures below show them
data = Data(0, t0=110*60, duration=20*60, init=['raster', 'pop_act', 'pLFP', 'NSI'])
fig, ax = data.plot('running_speed')
ax.set_ylabel('speed (cm/s)')

In [14]:
# the 'seaborn' style is named 'seaborn-v0_8' since matplotlib 3.6 (the old name was removed in 3.8)
plt.style.use('seaborn-v0_8' if 'seaborn-v0_8' in plt.style.available else 'seaborn')

def plot_sample_data(data,
                     time_points=(100, 1000),
                     duration=2., 
                     pop_act_bin=5e-3):
    """
    plot sample windows of the data: one column per t0 in "time_points"
    (window [t0, t0+duration], in s) and 6 rows:
        spike raster, population activity, LFP of all V1 channels, pLFP,
        NSI of the pLFP ("pLFP_NSI") and running speed
    each row shares the same y-limits across the columns, and a 200ms scale
    bar is drawn on the population-activity panels

    N.B. "pop_act_bin" is not used (kept for backward compatibility)
    """
    labels = ['units', 'rate (Hz)', 'LFP (mV)', 'pLFP (uV)', 'NSI (uV)', 'run. speed\n(cm/s)']
    # squeeze=False: always a (6, len(time_points)) array of axes, even for a single time point
    fig, AX_full = plt.subplots(len(labels), len(time_points),
                                figsize=(3*len(time_points), 7), squeeze=False)

    YLIMS = [[np.inf, -np.inf] for _ in labels]
    for col, (t0, AX) in enumerate(zip(time_points, AX_full.T)):
        
        # raster plot (only if the spikes were loaded)
        if data.V1_RASTER is not None:
            for i, spikes in enumerate(data.V1_RASTER):
                cond = (spikes>t0) & (spikes<(t0+duration))
                AX[0].plot(spikes[cond], i+0*spikes[cond], 'ko', ms=1.5)

        # pop act. plot
        data.plot('pop_act', t0=t0, duration=duration, ax=AX[1])
        
        # LFP plot, one color per channel (max(): no division by zero for a single channel)
        lfp_slice = data.lfp_slice_V1.sel(time=slice(t0,t0+duration))
        Nchannels = len(lfp_slice.channel.values)
        for i in range(Nchannels):
            AX[2].plot(lfp_slice.time, 1e3*lfp_slice.sel(channel=lfp_slice.channel[i]), 
                       lw=0.2, color=plt.cm.copper(1-i/max(Nchannels-1, 1)))
        
        # pLFP plot
        data.plot('pLFP', t0=t0, duration=duration, ax=AX[3], color=plt.cm.tab10(5))

        # NSI plot
        data.plot('pLFP_NSI', t0=t0, duration=duration, ax=AX[4], color=plt.cm.tab10(5))

        # speed plot
        data.plot('running_speed', t0=t0, duration=duration, ax=AX[5])

        # labelling axes and tracking the y-range of each row over all columns
        for j, (label, ax) in enumerate(zip(labels, AX)):
            if col==0:
                ax.set_ylabel(label)
            ax.set_xticks([])
            ax.set_xlim([t0,t0+duration])
            ymin, ymax = ax.get_ylim()
            YLIMS[j] = [min(ymin, YLIMS[j][0]), max(ymax, YLIMS[j][1])]
        AX[0].set_title('$t_0$=%.1fs' % (t0), size='small')
        
    # same y-limits for a given row over all columns
    for AX in AX_full.T:
        for ax, ylim in zip(AX, YLIMS):
            try:
                ax.set_ylim(ylim)
            except ValueError:
                pass # e.g. non-finite limits: keep the autoscaled ones

    # 200ms scale bar on the population activity panels
    for t0, AX in zip(time_points, AX_full.T):
        AX[1].plot([t0,t0+0.2], .9*YLIMS[1][1]*np.ones(2), 'k-', lw=1)
        AX[1].annotate('200ms', (t0, .92*YLIMS[1][1]))

In [None]:
# example windows of session 1, picked in the 'functional_connectivity' sessions
plot_sample_data(data, time_points=[6605.5, 6602.3, 7245])

In [None]:
# example windows picked in session 1 of the 'brain_observatory_1.1' sessions
# (only meaningful if that session type is selected in the session filter above)
plot_sample_data(data,
                 time_points=[7199.5, 7531, 7421.5])

In [29]:
# looking for a good grounding: a channel re-referenced to a candidate "ground" channel
# /!\ this overwrites "data.LFP" and "data.pLFP" with a 3s window only: re-run
#     "data.compute_pLFP(duration=data.duration)" and "data.compute_NSI()" before "plot_sample_data"
channelIDs, gcID = [0], 2
t_zoom, zoom_duration = 6605, 3
# "compute_pLFP" fetches the LFP through "get_LFP", so it also sets "data.LFP"/"data.t_LFP"
data.compute_pLFP(channelIDs=channelIDs, ground_channel_ID=gcID, t0=t_zoom, duration=zoom_duration)
fig, ax = plt.subplots(1, figsize=(8,4))
_ = data.plot('LFP', t0=t_zoom, duration=zoom_duration, ax=ax)
_ = data.plot('pLFP', t0=t_zoom, duration=zoom_duration, ax=ax.twinx(),
              color=plt.cm.tab10(3), lw=2)


array([198.   , 198.005, 198.01 , ..., 212.565, 212.57 , 212.575])