# Vec File Based Analysis

This notebook allows you to perform a simple "classic" analysis (raster and PSTH) of your recording, using your vec file as the reference for sequences. The last column of your vec file must contain the trigger sequence key. This key should have the following shape: "1" + "any number you want" + "two-digit repetition number". Repetition numbers don't need to be sorted: they will be analysed and stacked in the raster in their order of appearance in the vec file. Note that repetition "00" of each sequence is used as the timing reference (sequence duration) for the PSTH and the plots, so it must be present in the vec file.

In [None]:
import gc
gc.collect()

%load_ext autoreload
%autoreload 2

import os
import sys
from tqdm.auto import tqdm
import numpy as np
%matplotlib inline
import matplotlib.pyplot as plt
import matplotlib
matplotlib.use('agg')
import params
from utils import *
from matplotlib.gridspec import GridSpec, GridSpecFromSubplotSpec

import math
from scipy.io import loadmat

# For clustering
from sklearn.decomposition import PCA
from sklearn.cluster import AgglomerativeClustering
import scipy as sc
from sklearn.decomposition import SparsePCA
from collections import defaultdict

# Automatic raw files detection
def find_files(path):
    """
Return the base names (extension stripped) of every ".mcd" file found directly in `path`,
sorted in alphabetical order. Sub-folders and files with another extension are ignored.
    """
    mcd_names = []
    for entry in os.listdir(path):
        base_name, extension = os.path.splitext(entry)
        if extension != ".mcd":
            continue
        # A folder whose name ends with ".mcd" is not a recording file
        if os.path.isfile(os.path.join(path, entry)):
            mcd_names.append(base_name)
    mcd_names.sort()
    return mcd_names


def split_spikes_between_triggers(spike_train,triggers):
    """
Cut a spike train into the windows delimited by each pair of consecutive triggers.

Returns a list with one entry per window (len(triggers) - 1 entries). Each entry holds the
spikes t such that trigger[i] <= t < trigger[i + 1]. Spikes and triggers must share the same
unit (both in sampling points or both in seconds).
    """
    windows = []
    for window_start, window_stop in zip(triggers[:-1], triggers[1:]):
        # Start is inclusive, stop is exclusive, so a spike never falls in two windows
        inside = (spike_train >= window_start) & (spike_train < window_stop)
        windows.append(spike_train[inside])
    return windows

def get_sequences_triggers(triggers, vec):
    """
Group the triggers by sequence, using the key stored in the last column of the vec.
Triggers sharing the same key belong to the same sequence.

The n-th trigger is paired with the n-th vec row; if the two lengths differ, the extra
entries of the longer one are ignored.

Returns a dict {sequence key (str): [triggers]}. Keys appear in their order of first
appearance in the vec, which relies on dict insertion order (guaranteed from python 3.7).
    """
    # Keys go through int first so that e.g. 1500.0 becomes "1500"
    sequence_keys = vec[:, -1].astype(int).astype(str)

    sequences = {}
    for sequence_key, trigger in zip(sequence_keys, triggers):
        sequences.setdefault(sequence_key, []).append(trigger)

    return sequences
   
    
def get_spikes_sequences(spike_times, trig_seq):
    """
Extract, for every sequence, the spikes fired while this sequence was played.

A sequence starts at its first trigger and ends one mean inter-trigger interval after its
last trigger. Spike times are returned as they are in `spike_times` (they are NOT shifted
to the beginning of the sequence).

Returns a dict {sequence key: array of spike times}, with the same keys as `trig_seq`.
    """
    spikes_per_sequence = {}
    for sequence_key, trig_list in trig_seq.items():
        # The last trigger only marks the onset of the last frame: extend by one mean interval
        mean_interval = np.mean(np.diff(np.array(trig_list)))
        sequence_window = [trig_list[0], trig_list[-1] + mean_interval]
        spikes_per_sequence[sequence_key] = split_spikes_between_triggers(spike_times, sequence_window)[0]
    return spikes_per_sequence
    
    
def spikeseq2raster(spikesequences, trig_seq):
    """
Build one raster per sequence type from a dictionnary of spikes splitted per repetition.

The sequence type is the key without its first character and without its last 2 digits
(the repetition number). Repetitions are stacked in the order of the keys of `spikesequences`,
so the repetition number does not tell when the repetition has been played.
Spike times of each repetition are shifted so that 0 is the first trigger of this repetition.

Returns a dict {sequence type key: [one array of spike times per repetition]}.
    """
    rasters = {}
    for sequence_key, spikes in spikesequences.items():
        sequence_type = sequence_key[1:-2]
        sequence_onset = trig_seq[sequence_key][0]
        rasters.setdefault(sequence_type, []).append(spikes - sequence_onset)
    return rasters

def spikeseq2psth(raster, trig_seq, n_bin=40):
    """
Compute the psth (mean number of spikes per bin and per repetition) of each raster.

raster   : dict {sequence type key: [one array of spike times per repetition]}
trig_seq : dict {sequence key with repetition number: [triggers]}
n_bin    : number of bins of the psth, or "relative" to use sqrt(number of spikes in the raster)

The binning range goes from 0 to the duration of the repetition "00" of the sequence type
(key "0" for the empty sequence type), extended by one mean inter-trigger interval.

Returns a dict {sequence type key: np.array of spike counts averaged over repetitions}.
    """
    psth = {}
    for sequence_type, repetitions in raster.items():
        n_rep = len(repetitions)

        # Repetition used as time reference for this sequence type
        reference_key = '0' if sequence_type == '' else '1' + sequence_type + '00'
        reference_triggers = trig_seq[reference_key]
        seq_range = (0, reference_triggers[-1] - reference_triggers[0] + np.mean(np.diff(reference_triggers)))

        if n_bin == "relative":
            # Pool all repetitions, the number of bins depends on the total number of spikes
            pooled_spikes = []
            for repetition in repetitions:
                pooled_spikes.extend(list(repetition))
            relative_bins = max(1, int(np.sqrt(len(pooled_spikes))))
            psth[sequence_type] = np.histogram(np.array(pooled_spikes), bins=relative_bins, range=seq_range)[0] / n_rep
        else:
            # One histogram per repetition, then average over repetitions
            binned_spike_count = np.zeros((n_rep, n_bin))
            for row, repetition in enumerate(repetitions):
                binned_spike_count[row, :] = np.histogram(repetition, bins=n_bin, range=seq_range)[0]
            psth[sequence_type] = np.sum(binned_spike_count, axis=0) / n_rep

    return psth

def smooth(scalars: list[float], weight: float) -> list[float]:  # Weight between 0 and 1
    """
Smooth a 1D sequence (list or numpy array) with an exponential moving average before plotting.

scalars : values to smooth
weight  : smoothing factor between 0 and 1. 0 returns the values unchanged, values close to 1
          give more weight to the past and therefore a smoother curve.

Returns a list with the same length as `scalars`. An empty input returns an empty list.
    """
    # len() is used instead of truthiness as numpy arrays cannot be evaluated as booleans
    if len(scalars) == 0:
        return []

    last = scalars[0]  # The first value is its own anchor
    smoothed = []
    for point in scalars:
        last = last * weight + (1 - weight) * point  # Mix the previous smoothed value with the new point
        smoothed.append(last)

    return smoothed

def reshape_dict(original_dict):
    """
Swap the two first levels of keys of a nested dictionnary.
If you give {Cell1 : {key1: data, key2: data}, Cell2 : {key1: data, key2: data}}
you will get {key1 : {Cell1: data, Cell2: data}, key2 : {Cell1: data, Cell2: data}}

The data are not copied: both dictionnaries reference the same objects.
    """
    swapped = {}
    for outer_key, inner_dict in original_dict.items():
        for inner_key, data in inner_dict.items():
            swapped.setdefault(inner_key, {})[outer_key] = data
    return swapped

### Load recording data 

In [None]:
"""
    Variables
    
    DO NOT CHANGE VALUES HERE UNLESS DEBUG/SPECIFIC USE
    
    All the variables used in this part of the cell should always refer to your 'params.py' file
    unless you want to manually change them only for this run (i.e. debugging). 
    You may have to add those variable into the function you want to adapt as only the minimal 
    amount of var are currently given to functions as inputs.
"""

recording_names = params.recording_names

#Experiment name
exp = params.exp

#Analysis output directory
output_directory=params.output_directory

#Sampling rate of the mea
fs = params.fs  

#Trigger directory
triggers_directory= params.triggers_directory

#Vec files directory
vec_directory = os.path.join(params.root,'VEC_Files')

#List all vec files in vec files directory
available_vec = os.listdir(os.path.normpath(vec_directory))

"""
    Input
"""

cells_to_skip = []  ### LOAD HERE CELLS TO SKIP IF YOU WANT TO SELECT SPECIFICALY SOME OF THEM


#Find rec triggers : list the available recordings and ask the user for the index of the one to analyse
print(*['{} : {}'.format(i,recording_name) for i, recording_name in enumerate(recording_names)], sep="\n")
recording_number = int(input("\nSelect recording : "))
rec = recording_names[recording_number]
print(f"\nSelected recording : {rec} \n")

#All the outputs of this notebook are saved in a folder dedicated to the selected recording
analysis_directory = os.path.normpath(os.path.join(output_directory,r'Vec_Analysis_rec_{}'.format(recording_number)))
if not os.path.isdir(analysis_directory):
    os.makedirs(analysis_directory)

#Find stim vec file : list the available vec files and ask the user for the index of the one to use
print(*['{} : {}'.format(i,vec_file) for i, vec_file in enumerate(available_vec)], sep="\n")
vec_number = int(input("\nSelect stimulus file : "))

"""
    Processing
"""

#Load vec file
vec = np.loadtxt(os.path.join(vec_directory,available_vec[vec_number]))[1:,:]  #Remove first line as it is not a trigger
print(f"\nSelected vec file : {available_vec[vec_number]}\n")
print(f"Vec file length : {vec.shape[0]}")

#load triggers
trig_data = load_obj(os.path.normpath(os.path.join(triggers_directory,'{}_{}_triggers.pkl'.format(exp,rec))))
trig_indices = trig_data['indices']
stim_onsets = trig_indices/fs      #Triggers converted from sampling points to seconds
print(f"Total triggers number : {len(stim_onsets)}")
print(f"Triggers type loaded : {trig_data['trigger_type']}")

#Vec lines and triggers are paired one by one in their order of appearance (see get_sequences_triggers).
#If the lengths differ the extra entries are silently ignored, so tell it to the user.
if vec.shape[0] != len(stim_onsets):
    print(f"WARNING : vec file length ({vec.shape[0]}) and triggers number ({len(stim_onsets)}) are different. "
          "Only the first entries of the longest one will be used, check that you selected the right vec file.")

#Load the spike trains of all the cells for the whole experiment
spike_trains=load_obj(os.path.join(output_directory, r'{}_fullexp_neurons_data.pkl'.format(exp)))

#Keep only the spikes of the selected recording, for the cells which are not skipped
cells=list(spike_trains.keys())
spike_times = {cell: spike_trains[cell][rec] for cell in cells if cell not in cells_to_skip}

#Report what has really been loaded (skipped cells are excluded from the count and from the list)
print('Total : {} neurons loaded \n\nClusters id :\n{}\n'.format(len(spike_times),list(spike_times.keys())))

### Compute raster and psth for all sequences types 

In [None]:
"""
    Input
"""

n_bin=40          # binning for the psth, default = 40
n_bin ='relative' # if n_bin = "relative", n_bin = sqrt(nb of spikes in the raster)



"""
    Variables
"""

dict_name = f'{rec}_dict.pkl'
dict_path = os.path.join(analysis_directory, dict_name)   # File where the result of this cell is saved/loaded



"""
    Processing
"""

#Check if computing is needed first : compute if nothing has been saved yet or if the user asks to overwrite
if not os.path.exists(dict_path) or (str(input('Trigers already extracted previously. Write again files files? Type Yes to do so :\n')) in ["Y", "y", "yes", "Yes"]):
    print("Splitting spikes into sequences for each cluster...")
    rec_dict = {}
    sorted_data = {}
    trig_seq  = get_sequences_triggers(stim_onsets, vec) # Create a dictionnary with sequence keys including repetion number
                                                         # as key and a list of all the triggers in this repetition seq
                                                         # {"rep_seq_key" : [float]}

    # Loop over the loaded cells only : the cells listed in cells_to_skip have no entry in spike_times
    for i in tqdm(list(spike_times.keys())):

        #Process
        spike_seq = get_spikes_sequences(spike_times[i], trig_seq) # Create a dictionnary for each cluster with sequence keys including repetion number
                                                                   # as key and a list of all the spikes in this repetition seq 
                                                                   # {"rep_seq_key" : [float]}

        raster    = spikeseq2raster(spike_seq, trig_seq)           # Create a dictionnary for each cluster with sequence types keys 
                                                                   # as key and a list of all the repetition of this sequence type (raster of the sequence type) 
                                                                   # {"seq_type_key" : [np.arrays]}

        psth      = spikeseq2psth(raster, trig_seq, n_bin=n_bin)   # Create a dictionnary for each cluster with sequence types keys 
                                                                   # as key and a binning of the raster 
                                                                   # {"seq_type_key" : np.array (len(n_bin))}


        #Cluster all data in a dict {Cell_id : {"seq_type_key": {'raster': [np.arrays]; 'psth':np.array (len(n_bin)) }  }   }
        rec_dict[i] = {}
        for key in raster.keys():
            if key=='':
                continue   # Triggers with the key "0" do not belong to a sequence type

            #Add the necessary info for plotting about the seq length and triggers taking the repetition number 0 of this seqence type
            ref_triggers = trig_seq['1'+key+'00']
            rec_dict[i][key] = {
                'raster'   : raster[key],
                'psth'     : psth[key],
                'triggers' : {
                    'start':ref_triggers[0],
                    'end':ref_triggers[-1],
                    'rng':(0, ref_triggers[-1]-ref_triggers[0] + np.mean(np.diff(ref_triggers))),
                },
            }

        sorted_data[i] = spike_seq

    # --- Saving ---
    # Kept as a comment inside the "if" block : a string written at the beginning of the line
    # would close the block and break the "else" below.
    print("Saving...")
    save_obj(rec_dict, dict_path)
    save_obj(trig_seq, os.path.join(analysis_directory, f'{rec}_Sorted_triggers.pkl'))
    save_obj(sorted_data, os.path.join(analysis_directory, f'{rec}_Sorted_Spikes.pkl'))

    # Free the memory. Only done here as sorted_data does not exist when the dictionnary is loaded from the disk
    del sorted_data
else:
    print(f"Dictionnary loaded from : \n{dict_path}")
    rec_dict = load_obj(dict_path)

print("----- Done -----")

### Plot Raster and PSTH for all conditions (cell x sequence types)

This can take a lot of time (about 10 s per condition). You may want to load the rec_dict in another notebook and do your own plotting there, which will be more efficient!

In [None]:
clusters_as_folder = True   #If true, creates a folder per cell and save there a plot of this cell for each sequence
                            #else, creates a folder per sequence and save there a plot of this seq for each cell

color ="#B85A8F"     #Plotting color

#Choose which level of the dictionnary gives the folders (element) and which one gives the files (scd_element)
if clusters_as_folder:
    dict_to_plot, element, scd_element = rec_dict, "Cell", "Sequence"
else:
    #Reverse dict in the case of saving per sequence
    dict_to_plot, element, scd_element = reshape_dict(rec_dict), "Sequence", "Cell"

for elt, scd_dict in tqdm(dict_to_plot.items()):
    #Saving folder for this cell or sequence
    item_directory = os.path.normpath(os.path.join(analysis_directory,f'{element}_{elt}'))
    if not os.path.isdir(item_directory):
        os.makedirs(item_directory)

    for scd_elt, data in scd_dict.items():
        #Figures already saved are not plotted again
        figure_path = os.path.join(item_directory,f'{scd_element}_{scd_elt}.png')
        if os.path.isfile(figure_path):
            continue

        # New figure : raster on top (3/4 of the height), psth below, sharing the time axis
        fig, axs = plt.subplots(nrows = 2,ncols = 1, sharex=True, gridspec_kw={'height_ratios': [3, 1]}, figsize=(10,10))
        ax_rast, ax_psth = axs

        #Plot the rasters : one line per repetition
        ax_rast.eventplot(data["raster"], color=color)
        ax_rast.set(title = "Raster plot", ylabel='N Repetitions')

        #Plot the psth
        rng  = data['triggers']['rng']
        y    = data["psth"]
        y    = y*(len(y)/(rng[1]-rng[0]))   #Spikes per bin divided by the bin duration gives a firing rate
        ysmooth = smooth(y,.4)

        x=np.linspace(rng[0],rng[1],len(y))
        ax_psth.fill_between(x, ysmooth,0,alpha=1, color=color)
        ax_psth.set(xlabel='Time in sec', ylabel='Firing rate (spikes/s)')
        ax_psth.set_ylim(bottom=-0.1, top=max(1, max(ysmooth)))   #Keep at least 1 spike/s of range for silent cells

        #Finish the plot and save
        plt.suptitle(f'{scd_element}_{scd_elt}')
        plt.subplots_adjust(wspace=0, hspace=0)
        plt.savefig(figure_path)

        #Release the figure to limit the memory used when plotting many conditions
        plt.close(fig)
        del fig
        gc.collect() #Just in case...