In [2]:
%load_ext autoreload
%autoreload 2

import numpy as np
import csv
import os
from tqdm.auto import tqdm
import gc

print("Importing spike interface, this may take a while...")
import spikeinterface.full as si
import docker
print("Done...")
import scipy
from multiprocessing import cpu_count

import probeinterface as pi
from probeinterface.plotting import plot_probe_group, plot_probe

# Load probe
probe = pi.read_prb('mcs_256_30_8iR_ITO.prb')

import matplotlib.pyplot as plt
import matplotlib
matplotlib.use('qt5agg')

from utils import *         # Local file containing all the functions that we need
import params               # Parameters file. You should tune it for your own experiment

Importing spike interface, this may take a while...
Done...


### Cell 1 : Open files

Open all recordings before filtering + sanity check

In [3]:
"""
    Variables
    
    DO NOT CHANGE VALUES HERE UNLESS DEBUG/SPECIFIC USE
    
    You will find here all variables used in this notebook cell. They should always refere to your 'params.py' file
    except if you want to manually change some variable only for this run (i.e. debugging). You may have to add those
    variable into the function you want to adapt as only the minimal amount of var are currently given to functions as inputs.
"""

#Link to the actual raw files from the recording listed in the input_file
recording_directory = params.recording_directory

#Loading raw recording files names
recording_names = params.recording_names

# number of triggers samples acquired per second
fs         = params.fs

Nchannels  = params.nb_channels                #256 for standard MEA, 17 for MEA1 Polychrome

"""
    Processing
"""

recording_names = [rec.replace('.raw','') for rec in recording_names]
rec_it = recording_names[:]+['end']
print('Number of recordings: {}\n'.format(len(recording_names)))

#getting onset for next prints
onsets = {}
onsets = recording_onsets(recording_names, path = recording_directory)

#Opening files
print('\nCheck that recordings lengths are consistent with recording names\n') 


for i in range(len(rec_it)-1):    
    print("{} minutes\t--->\t{} : {} -> OK".format(int((onsets[rec_it[i+1]]-onsets[rec_it[i]])/params.fs/60), i, rec_it[i]))


"""Output :

Var :
recordings_names : Ordered list of stimuli names played during experiment
"""   

print('\n\t\t\t------ End Of Cell ------')

Number of recordings: 6


Check that recordings lengths are consistent with recording names

30 minutes	--->	0 : checkerboard -> OK
11 minutes	--->	1 : chirp -> OK
7 minutes	--->	2 : drifting_gratings -> OK
50 minutes	--->	3 : white_noise_1d -> OK
37 minutes	--->	4 : moving_bars -> OK
119 minutes	--->	5 : perturbed_moving_bar -> OK

			------ End Of Cell ------


### Cell 2 : Extract the triggers

#### <center>REQUIRES CELL 1 RUN</center>

Extract triggers from either the visual or holo trigger channel. Automatic detection of Holographic recording. Check that the detection is perform on the right files. Perform triggers sanity checks for visual stimumi. You can plot them later on cell 4. Can take up to more than 1h to run all recordings depending on your experiment length.

In [4]:
"""
    Variable
    
    You will find here all variables used in this notebook cell. They should always refere to your 'params.py' file
    except if you want to manually change some variable only for this run (i.e. debugging). You may have to add those
    variable into the function you want to adapt as only the minimal amount of var are currently given to functions as inputs.
"""

#name of your experiment for saving the triggers
exp = params.exp

# select MEA (3=2p room) (4=MEA1 Polycrhome)
MEA = params.MEA                       

#the optimal threshhold for detecting stimuli onsets varies with the rig
threshold  = params.threshold           

Nchannels  = params.nb_channels                #256 for standard MEA, 17 for MEA1 Polychrome

# number of triggers samples acquired per second
fs         = params.fs

#The folder in which you want your triggers to be saved 
triggers_directory = params.triggers_directory

#Channel recording triggers in case of holographic stimuli
holo_channel_id   = params.holo_channel_id

#Channel recording triggers in case of visual stimuli
visual_channel_id = params.visual_channel_id 

"""
    Inputs
"""

#you can decide here to extract the triggers only for some recordings. List their indexes here (starting from 0).
select_rec = []    # do only measurement N, put [] or the complet list to call all of them


"""
    Processing
"""

for rec in range(len(recording_names)):
    if select_rec:
        if rec not in select_rec: continue
    
    print('\n-------------   Processing recording {} out of {}   -------------\n'.format(rec+1,len(recording_names)))

    # Creating all files path
    input_file    = os.path.join(recording_directory,recording_names[rec]+'.raw')
    trigger_file  = os.path.join(triggers_directory,'{}_{}_triggers.pkl'.format(exp,recording_names[rec]))
    data_file     = os.path.join(triggers_directory,'{}_{}_triggers_data.pkl'.format(exp,recording_names[rec]))
    
    print('The triggers are extracted from the sorting file:\t{}\nand the results will be saved at:\t\t\t{}'.format(recording_names[rec]+'.raw',trigger_file))
    if os.path.exists(data_file):
        if (str(input('Trigers already extracted previously. Write again files files? Type Y to do so :\n')) != 'Y') : continue
        
    if is_holographic_rec(input_file): 
        #in this case the stimulus was holograpic
        print(r" /!\ HOLOGRAPHIC Recording /!\ ")
        channel_id   = holo_channel_id
        trigger_type = 'holo'
        onsets_file  = os.path.join(triggers_directory,'{}_{}_laser_onsets.npy'.format(exp,recording_names[rec]))
        offsets_file  = os.path.join(triggers_directory,'{}_{}_laser_offsets.npy'.format(exp,recording_names[rec]))
    else: 
        #in this other case the stimulus was visual
        print(r" /!\ VISUAL Recording /!\ ")
        channel_id   = visual_channel_id        
        trigger_type = 'visual'
        
    #Processing of data calling utils functions
    print("Loading Data...")
    channel_id   = visual_channel_id        
    trigger_type = 'visual'      

    data, t_tot    = load_data(input_file, channel_id = channel_id )  #MANUALLY CHANGE HERE IF THE CHANNEL IS 
                                                                     #AUTHOMATICALLY MISDETECTED. IF SO IT SHOULD 
                                                                    #BE BECAUSE OF ALIASING OR BAD TRIGGER QUALITY
    indices        = detect_onsets(data,threshold)
    indices_errors = run_minimal_sanity_check(indices, stim_type = trigger_type)
    
    #Saving data using utils function save_obj
    save_obj({'indices':indices,'duration':t_tot,'trigger_type':trigger_type,'indice_errors':indices_errors}, trigger_file )
    save_obj(data,data_file)
    

    if trigger_type == 'holo':
        save_obj(indices, onsets_file)
    
        offsets = detect_offsets(data)
        save_obj(offsets, offsets_file)    
        
"""
    Output
    
    Saved in triggers_directory :

{experience_name}_{link_file_name}_triggers.pkl (dict) : 
    keys 'indices' --> detected triggers indices, 
         'duration' --> the stimuli duration, 
         'trigger_type' --> the detection visual or holo stimuli, 
         'indice_errors' --> triggers violating sanity check 
         
{experience_name}_{link_file_name}_triggers_data.pkl (numpy array) : raw signal recorded on the trigger channel
"""

print('\n\t\t\t------ End Of Cell ------')


-------------   Processing recording 1 out of 6   -------------

The triggers are extracted from the sorting file:	checkerboard.raw
and the results will be saved at:			/media/guiglaz/DATA1/20230824_tbt_1/Analysis/triggers/20230824_tbt_1_checkerboard_triggers.pkl
Trigers already extracted previously. Write again files files? Type Y to do so :


-------------   Processing recording 2 out of 6   -------------

The triggers are extracted from the sorting file:	chirp.raw
and the results will be saved at:			/media/guiglaz/DATA1/20230824_tbt_1/Analysis/triggers/20230824_tbt_1_chirp_triggers.pkl
Trigers already extracted previously. Write again files files? Type Y to do so :


-------------   Processing recording 3 out of 6   -------------

The triggers are extracted from the sorting file:	drifting_gratings.raw
and the results will be saved at:			/media/guiglaz/DATA1/20230824_tbt_1/Analysis/triggers/20230824_tbt_1_drifting_gratings_triggers.pkl
Trigers already extracted previously. Write agai

### CELL 3 : Plots triggers for sanity check

#### <center>REQUIRES CELL 1 RUN & CELL 2 RUN AT LEAST ONCE FOR THIS EXPERIMENT </center>


Plots the raw trigger signal with the detected triggers and the errors detected. Independently, plots also the detected triggers, should be a perfect diagonal. Third, plots the number of time points gap to the most common trigger duration (ie theoretical_time_per_frame +- ploted value).

#### <center>/!\/!\/!\ Caution on memory leaks /!\/!\/!\ </center> (if you know a solution please let me know)

In [6]:
"""
    Variable
    
    You will find here all variables used in this notebook cell. They should always refere to your 'params.py' file
    except if you want to manually change some variable only for this run (i.e. debugging). You may have to add those
    variable into the function you want to adapt as only the minimal amount of var are currently given to functions as inputs.
"""

#Experiment name
exp = params.exp

# Optimal threshhold for detecting stimuli onsets varies with the rig
threshold  = params.threshold

# Directory where plots will be saved
output_directory = params.output_directory


"""
    Inputs
"""

#Set True if you want the plots to be saved
save = False

#Define your x-axis ploting window in a tuple (x-min,x-max). Set False to plot the complete data
ploting_range = False


"""
    Ploting
"""

print(*['{} : {}'.format(i,recording_name) for i, recording_name in enumerate(recording_names)], sep="\n")
recordings = [int(rec_id) for rec_id in input("\nSelect recording : ").split()]


plt.close('all')
gc.collect()
plot_idx = 0

for rec in recordings:
    plot_idx+=1
    # Loading data from pickle files created in cell 3
    data    = np.array(load_obj(os.path.normpath(os.path.join(params.triggers_directory,'{}_{}_triggers_data.pkl'.format(exp,recording_names[rec])))))
    extracted = load_obj(os.path.normpath(os.path.join(params.triggers_directory,'{}_{}_triggers.pkl'.format(exp,recording_names[rec]))))
    err = extracted['indice_errors']
    indices = extracted['indices']
    rec_type = extracted["trigger_type"]
    
    # If ploting range is a tuple, reduce the plot to indices between both values of the tuple
    if ploting_range :
        indices = indices[np.logical_and(indices > ploting_range[0], indices < ploting_range[1])]
        data    = data[np.logical_and(np.array(range(len(data))) > ploting_range[0], np.array(range(len(data))) < ploting_range[1])]
        err     = err[np.logical_and(err > ploting_range[0], err < ploting_range[1])]
    
    plt.figure("Trigger sanity check {}".format(plot_idx))
    
    # Top plot with raw trigger signal, threshold of detection, detected triggers and wrong triggers
    plt.subplot(2,1,1)
    plt.title('{}\n{} channel'.format(recording_names[rec],rec_type))
    
    plt.plot(np.linspace(0,len(data)/fs,len(data) ),data)
    plt.plot(indices/fs,data[indices],'.',markersize=2,zorder=10)

    plt.axhline(threshold, color='green')
    plt.scatter(err/fs,data[err], color='red', marker='x',zorder = 15)
    
    # Bottom left plot of triggers indices. Shoule be a perfect diagonal
    plt.subplot(2,2,3)
    plt.plot(indices)
    plt.title('Detected indices')
    
    # Bottom right plot of relative error gap between detect time of frame and mean frame time
    plt.subplot(2,2,4)
    plt.plot(np.diff(np.diff(indices)))
    try :
        plt.title('Duration {} +- error'.format(np.round(np.mean(np.diff(indices)))))
    except :
        plt.title('Duration {} +- error'.format("NOT COMPUTED"))
                  
    plt.tight_layout()
    plt.show(block = False)
    
    # Saving plot if needed
    if save:
        fig_name = os.path.join(output_directory,r'{}_{}.png'.format(recording_names[rec],link_names[rec]))
        plt.savefig(fig_name)


"""
    Output
    
    if save == True
    
    {recording_file_name}_{link_file_name}.png : Plots for a given recording file

"""

print('\n\t\t\t------ End Of Cell ------')

0 : checkerboard
1 : chirp
2 : drifting_gratings
3 : white_noise_1d
4 : moving_bars
5 : perturbed_moving_bar

Select recording : 0

			------ End Of Cell ------


### Cell 4 : Preprocess all recordings

In [5]:
exp = params.exp
fs = params.fs

#The folder in which you want your triggers to be saved 
triggers_directory = params.triggers_directory

recordings = {}
for recording_name in recording_names:
    recordings[recording_name] = {}
    
    extracted = load_obj(os.path.normpath(os.path.join(params.triggers_directory,'{}_{}_triggers.pkl'.format(exp,recording_name))))
    trigger_type = extracted['trigger_type']
    
    # Open file
    recordings[recording_name]['raw'] = si.read_binary(os.path.join(recording_directory,recording_name+'.raw'), sampling_frequency=fs, num_chan=Nchannels, dtype='uint16')
    print(f"{recording_name} \n Opened", end=' ')
    print(recordings[recording_name]['raw'])
    #Set probe to .prb file in the folder
    recordings[recording_name]['raw'] = recordings[recording_name]['raw'].set_probegroup(probe)
    print("--> Probe attached",end=' ')
    print(recordings[recording_name]['raw'])
    # Filter recordings
    recordings[recording_name]['si_filtered'] = si.bandpass_filter(recordings[recording_name]['raw'], dtype="float32")
    print("--> Filtered bandpass",end=' ')
    print(recordings[recording_name]['si_filtered'])
    # Remove median for all recordings
    recordings[recording_name]['si_filtered_medianremoved'] = si.common_reference(recordings[recording_name]['si_filtered'])
    print("--> Median removed", end=' ')
    print(recordings[recording_name]['si_filtered_medianremoved'])
    
    if trigger_type == 'holo':       
        #Onsets and offsets loading
        stim_onsets = load_obj(os.path.join(triggers_directory,'{}_{}_laser_onsets.npy'.format(exp,recording_name)))
        stim_offsets = load_obj(os.path.join(triggers_directory,'{}_{}_laser_offsets.npy'.format(exp,recording_name)))

        #Computing times of laser bims
        times = np.sort(np.concatenate((stim_onsets, stim_offsets)))
    
        #Removing artefacts
        recordings[recording_name]['si_cleaned_zeros'] = si.remove_artifacts(recordings[recording_name]['si_filtered_medianremoved'], 
                                                               list_triggers=[list(times)], 
                                                               ms_before=5, 
                                                               ms_after=5, 
                                                               mode='zeros')
        print("--> laser on and off times set to 0") 
    print('\n')
    
    
print('\n\t\t\t------ End Of Cell ------')

checkerboard 
 Opened BinaryRecordingExtractor: 256 channels - 1 segments - 20.0kHz - 1816.300s
  file_paths: ['/media/guiglaz/DATA1/20230824_tbt_1/RAW_Files/checkerboard.raw']
--> Probe attached ChannelSliceRecording: 252 channels - 1 segments - 20.0kHz - 1816.300s
--> Filtered bandpass BandpassFilterRecording: 252 channels - 1 segments - 20.0kHz - 1816.300s
--> Median removed CommonReferenceRecording: 252 channels - 1 segments - 20.0kHz - 1816.300s


chirp 
 Opened BinaryRecordingExtractor: 256 channels - 1 segments - 20.0kHz - 663.000s
  file_paths: ['/media/guiglaz/DATA1/20230824_tbt_1/RAW_Files/chirp.raw']
--> Probe attached ChannelSliceRecording: 252 channels - 1 segments - 20.0kHz - 663.000s
--> Filtered bandpass BandpassFilterRecording: 252 channels - 1 segments - 20.0kHz - 663.000s
--> Median removed CommonReferenceRecording: 252 channels - 1 segments - 20.0kHz - 663.000s


drifting_gratings 
 Opened BinaryRecordingExtractor: 256 channels - 1 segments - 20.0kHz - 423.500s
  fi

### Visualize artefacts

Helps you visualize the artefacts once removed. If it breaks with an IndexError, you may have too much or too few triggers. It may not be an issue for you if you accept to put more parts of the recording to 0. If it is you have to rewrite the files laser_onsets and laser_offsets in the trigger folder.

In [20]:
assert params.MEA == 3, "You have a visual recording only. You don't have laser artefacts in your recording. Proceed with the rest of the notebook"

#import pylab as plt
exp = params.exp

frames_path = params.frames_path
fs = params.fs

waveform_id = 1 #waveform patern number
elec = 136  #electrode number

print(*['{} : {}'.format(i,recording_name) for i, recording_name in enumerate(recordings.keys())], sep="\n")
recording_name = recording_names[int(input("\nSelect holographic recording : "))]


frames_folder_files = [f for f in os.listdir(frames_path) if os.path.isfile(os.path.join(frames_path, f))]
print(*['{} : {}'.format(i,frame_file) for i, frame_file in enumerate(frames_folder_files)], sep="\n")
frame_name = frames_folder_files[int(input(f"\nSelect the DH_frames file for the recording {recording_name} \n"))]

#Frames loading
frames = scipy.io.loadmat(os.path.join(frames_path, frame_name))
spot_order = np.array([i[0] for i in frames['OrderFrames']])

#Onsets and offsets loading
onsets_file  = os.path.join(triggers_directory,'{}_{}_laser_onsets.npy'.format(exp,recording_name))
offsets_file  = os.path.join(triggers_directory,'{}_{}_laser_offsets.npy'.format(exp,recording_name))

stim_onsets = load_obj(onsets_file)
stim_offsets = load_obj(offsets_file)

#Computing times of laser bims
times = np.concatenate((stim_onsets, stim_offsets))
labels = np.concatenate((spot_order, spot_order+max(spot_order)))
idx = np.argsort(times)
times = times[idx]
labels = labels[idx]

sorting = si.NumpySorting.from_times_labels(times, labels, sampling_frequency=fs)
sorting = sorting.save()


waveforms = {}
waveforms['si_filtered_medianremoved'] = si.extract_waveforms(recordings[recording_name]['si_filtered_medianremoved'], 
                             sorting, ms_before=10, ms_after=10, mode='memory',
     n_jobs=10, allow_unfiltered=True, chunk_memory="10M", overwrite=True, sparse=False)

waveforms['si_cleaned_zeros'] = si.extract_waveforms(recordings[recording_name]['si_cleaned_zeros'], 
                             sorting, ms_before=10, ms_after=10, mode='memory',
     n_jobs=10, allow_unfiltered=True, chunk_memory="10M", overwrite=True, sparse=False)




## Plotting waveform with the probe plot

fig, axes = plt.subplots(1,2, figsize=(13,9))
si.plot_unit_templates(waveforms['si_filtered_medianremoved'], 
                       unit_ids=[waveform_id], same_axis=True, ax=axes[0], plot_legend=False)
si.plot_unit_templates(waveforms['si_cleaned_zeros'], 
                       unit_ids=[waveform_id], same_axis=True, ax=axes[1], plot_legend=False)

si.plot_unit_probe_map(waveforms['si_cleaned_zeros'], unit_ids=[waveform_id], with_channel_ids=True)


## Plotting waveform electrode wise  

unit_id = waveform_id

colors = ['k', 'r', 'b']

fig, axs = plt.subplots(4,1, figsize=(10,15))

ax = axs[0]
for i_key, key in enumerate(['si_filtered_medianremoved', 'si_cleaned_zeros']):
    wfs= waveforms[key].get_waveforms(unit_id=unit_id)
    for i_wf in range(wfs.shape[0]):
        ax.plot(wfs[i_wf,:,elec], color = colors[i_key],
                 label=key,
                alpha=0.1)
    ax.set_title("all waveforms")
        
ax = axs[1]
for i_key, key in enumerate(['si_filtered_medianremoved', 'si_cleaned_zeros']):
    wfs= waveforms[key].get_waveforms(unit_id=unit_id)
    ax.plot(np.median(wfs[:,:,elec],axis=0),
           color = colors[i_key])
    ax.set_title("medians")
        
for i_key, key in enumerate(['si_filtered_medianremoved', 'si_cleaned_zeros',]):
    ax = axs[i_key+2]

    wfs= waveforms[key].get_waveforms(unit_id=unit_id)
    for i_wf in range(wfs.shape[0]):
        ax.plot(wfs[i_wf,:,elec], color = colors[i_key],
                 label=key,
                alpha=0.2)
    ax.plot(np.median(wfs[:,:,elec],axis=0),
           color = colors[i_key])

    ax.set_title("waveforms and median")

plt.show(block=False)
print('\n\t\t\t------ End Of Cell ------')

AssertionError: You have a visual recording only. You don't have laser artefacts in your recording. Proceed with the rest of the notebook

### Select recordings to cluster

In [6]:
#from colorama import Fore, Style

skip_recording = []

recording_list = []
for rec_idx, recording_name in enumerate(recording_names):
    
    if rec_idx in skip_recording and (not print(Fore.RED+f"{rec_idx} : {recording_name} --> /! SKIPPED /! "+Style.RESET_ALL)):  continue
    
    print(f"{rec_idx} : {recording_name}",end=' ')
    if "si_cleaned_zeros" in recordings[recording_name].keys():
        recording_list += [recordings[recording_name]['si_cleaned_zeros']]
        print('--> holo')
    else:
        print("\t --> visual")
        recording_list += [recordings[recording_name]['si_filtered_medianremoved']]

multirecording = si.concatenate_recordings(recording_list)
multirecording = multirecording.set_probe(recording_list[0].get_probe())

0 : checkerboard 	 --> visual
1 : chirp 	 --> visual
2 : drifting_gratings 	 --> visual
3 : white_noise_1d 	 --> visual
4 : moving_bars 	 --> visual
5 : perturbed_moving_bar 	 --> visual


### RUN SORTING
Choose your sorter to run and go have a coffee with a long paper or go home

In [14]:
import pprint
pprint.pprint(si.get_sorter_params_description('kilosort2'))
print('')
default_params = si.get_default_sorter_params('kilosort2')
pprint.pprint(default_params)

{'NT': 'Batch size (if None it is automatically computed)',
 'car': 'Enable or disable common reference',
 'chunk_duration': 'Chunk duration in s if float or with units if str (e.g. '
                   "'1s', '500ms') (when saving to binary) - default global",
 'chunk_memory': "Memory usage for each job (e.g. '100M', '1G') (when saving "
                 'to binary) - default global',
 'chunk_size': 'Number of samples per chunk (when saving ti binary) - default '
               'global',
 'detect_threshold': 'Threshold for spike detection',
 'freq_min': 'High-pass filter cutoff frequency',
 'keep_good_only': "If True only 'good' units are returned",
 'minFR': 'Minimum spike rate (Hz), if a cluster falls below this for too long '
          'it gets removed',
 'minfr_goodchannels': "Minimum firing rate on a 'good' channel",
 'nPCs': 'Number of PCA dimensions',
 'n_jobs': 'Number of jobs (when saving ti binary) - default -1 (all cores)',
 'nfilt_factor': 'Max number of clusters per good 

In [12]:
###############################
##### Circus 1 dockerised #####
###############################

clustering_folder = os.path.join(params.sorting_directory,r'Circus_1/clustering')
if not os.path.exists(clustering_folder): os.makedirs(clustering_folder)

waveforms_directory = os.path.join(params.sorting_directory,r'Circus_1/waveforms')
if not os.path.exists(waveforms_directory): os.makedirs(waveforms_directory)

default_params = si.get_default_sorter_params('spykingcircus')
custom_params = default_params.copy()
custom_params['num_workers'] = int((cpu_count()/2)-2)
custom_params['filter'] = False

sorting = si.run_sorter('spykingcircus',
    recording=test_recording,
    output_folder="spyking_circus",
    **custom_params,
    verbose=True,
    docker_image=True)

print(sorting)

Starting container
Installing spikeinterface==0.97.1 in spikeinterface/spyking-circus-base
Running spykingcircus sorter inside spikeinterface/spyking-circus-base
Stopping container
SpykingCircusSortingExtractor: 8 units - 1 segments - 20.0kHz


In [None]:
####################
##### Circus 1 #####
####################

clustering_folder = os.path.join(params.sorting_directory,r'Circus_1/clustering')
if not os.path.exists(clustering_folder): os.makedirs(clustering_folder)

waveforms_directory = os.path.join(params.sorting_directory,r'Circus_1/waveforms')
if not os.path.exists(waveforms_directory): os.makedirs(waveforms_directory)


# default_params = si.SpykingcircusSorter.default_params()
default_params = si.get_default_sorter_params('spykingcircus')
custom_params = default_params.copy()
custom_params['num_workers'] = int((cpu_count()/2)-2)
custom_params['filter'] = False


sorting = si.run_sorter('spykingcircus',
                        multirecording, 
                        clustering_folder,
                       **custom_params,
                        docker_image=True,
                       verbose=True)

print(f"Waveforms extraction to {waveforms_directory}")
w = si.extract_waveforms(multirecording, sorting, waveforms_directory, dtype='float32', 
                         chunk_memory="10M", overwrite=True, sparse=True, method='snr', threshold=1, n_jobs=int(cpu_count()/2))


print(sorting)

w = si.extract_waveforms(multirecording, sorting, waveforms_directory, dtype='float32', 
                         chunk_memory="10M", overwrite=True, sparse=True, method='snr', threshold=1, n_jobs=int(cpu_count()/2))



In [None]:
####################
##### Circus 2 #####
####################

clustering_folder = os.path.join(params.sorting_directory,r'Circus_2/clustering')
if not os.path.exists(clustering_folder): os.makedirs(clustering_folder)

waveforms_directory = os.path.join(params.sorting_directory,r'Circus_2/waveforms')
if not os.path.exists(waveforms_directory): os.makedirs(waveforms_directory)

nb_cpus = int(cpu_count()/2)

default_params = si.Spykingcircus2Sorter.default_params()
custom_params = default_params.copy()
custom_params['job_kwargs'] = {'n_jobs': nb_cpus, 'verbose': True}
custom_params['apply_preprocessing'] = False
custom_params

sorting = si.run_sorter('spykingcircus2',
                        multirecording, 
                        clustering_folder,
                       **custom_params,
                       verbose=True)

print(f"Waveforms extraction to {waveforms_directory}")
w = si.extract_waveforms(multirecording, sorting, waveforms_directory, dtype='float32', 
                         chunk_memory="10M", overwrite=True, sparse=True, method='snr', threshold=1, n_jobs=nb_cpus)


print(sorting)

w = si.extract_waveforms(multirecording, sorting, waveforms_directory, dtype='float32', 
                         chunk_memory="10M", overwrite=True, sparse=True, method='snr', threshold=1, n_jobs=nb_cpus)



### Sorting Comparision

Work in progress

In [None]:
# Then run 3 spike sorters and compare their outputs.
sortings = {
    "spykingcircus1"  : si.read_sorter_folder(os.path.join(params.sorting_directory,r'Circus_1/clustering')),
    "spykingcircus2"  : si.read_sorter_folder(os.path.join(params.sorting_directory,r'Circus_2/clustering')),
    "yass" : si.read_sorter_folder(os.path.join(params.sorting_directory,r'YASS/clustering')),
    "kilosort2"   : si.read_sorter_folder(os.path.join(params.sorting_directory,r'kilosort2/clustering'))
    }

# Compare multiple spike sorter outputs
mcmp = si.compare_multiple_sorters(
    sorting_list=[item[1] for item in sorting.items]
    name_list=[key for key in sortings.keys()],
    verbose=True,
)

# The multiple sorters comparison internally computes pairwise comparisons,
# that can be accessed as follows:
print(mcmp.comparisons[('spykingcircus1', 'spykingcircus2')].sorting1, mcmp.comparisons[('spykingcircus1', 'spykingcircus2')].sorting2)
print(mcmp.comparisons[('spykingcircus1', 'spykingcircus2')].get_matching())

print(mcmp.comparisons[('spykingcircus1', 'kilosort2')].sorting1, mcmp.comparisons[('spykingcircus1', 'kilosort2')].sorting2)
print(mcmp.comparisons[('spykingcircus1', 'kilosort2')].get_matching())

# The global multi comparison can be visualized with this graph
sw.plot_multicomp_graph(mcmp)

# Consensus-based method

agr_3 = mcmp.get_agreement_sorting(minimum_agreement_count=3)
print('Units in agreement for all three sorters: ', agr_3.get_unit_ids())

agr_2 = mcmp.get_agreement_sorting(minimum_agreement_count=2)
print('Units in agreement for at least two sorters: ', agr_2.get_unit_ids())

agr_all = mcmp.get_agreement_sorting()

# The unit index of the different sorters can also be retrieved from the
# agreement sorting object (:code:`agr_3`) property :code:`sorter_unit_ids`.

print(agr_3.get_property('unit_ids'))

print(agr_3.get_unit_ids())
# take one unit in agreement
unit_id0 = agr_3.get_unit_ids()[0]
sorter_unit_ids = agr_3.get_property('unit_ids')[0]
print(unit_id0, ':', sorter_unit_ids)

### Export to phy

In [None]:
"""
    Input
Give the path to your sorting folder (i.e. kilosort, circus1, circus2, ...)
"""

clustering_folder = os.path.join(params.sorting_directory,r'Circus_1/clustering')

waveforms_directory = os.path.join(params.sorting_directory,r'Circus_1/waveforms')

"""
    Params
"""

copy_binary = True   #Set to false for it to be much faster but you will lose spikes raw traces on the waveform view
                      #There is currently an error in the spikeinterface code with an uninitalized variable when exporting without raw
                      #It has been reported and should be fixed in the futur. 
                      #### Meanwhile, juste keep True here ####

phy_directory = params.phy_directory

nb_cpus = int(cpu_count()/2)  #This should be at maximum the number of real cores you have in your cpu (intel's cpu shows hyperthreaded number of cores)

"""
    Processing
"""

sorting = si.read_sorter_folder(clustering_folder)
print(sorting)
try:
    
    w = si.load_waveforms(waveforms_directory)
    print(f"Waveforms read from {waveforms_directory}")
except:
    print(f"Waveforms extraction to {waveforms_directory}")
    w = si.extract_waveforms(multirecording, sorting, waveforms_directory, dtype='float32', 
                         chunk_memory="10M", overwrite=True, sparse=True, method='snr', threshold=1, n_jobs=nb_cpus)

print(f"Exportion to phy format at {phy_directory}")
si.export_to_phy(w, phy_directory,
                 copy_binary = copy_binary,
                 compute_pc_features=False,
                 compute_amplitudes=True,
                 remove_if_exists=True,
                 verbose=True,
                 n_jobs=nb_cpus)

### Cell 6 : Creating all clusters rasters plots

#### <center>REQUIRES CELL 1 RUN</center>

Run after automatic sorting to help with manual sorting. Saves all automatic clusters' rasters on the repeated checherboard in the phy directory. You can run this several times during sorting to make new clusters rasters. Can take few sec per cluster...

In [None]:
"""
    Variable
    
    You will find here all variables used in this notebook cell. They should always refere to your 'params.py' file
    except if you want to manually change some variable only for this run (i.e. debugging). You may have to add those
    variable into the function you want to adapt as only the minimal amount of var are currently given to functions as inputs.
"""
#length of each sequence composed of half repeated sequence and half random sequence

nb_frames_by_sequence = params.nb_frames_by_sequence

# Name of your experiment
exp = params.exp

#Path to the folder with the phy output
phy_directory = params.phy_directory

#Path to raw recording files
recoding_directory = params.recording_directory

#Frequency of sampling of the mea
fs = params.fs

"""
    Input
"""
#Number of the checkerboard recording of choice (start from zero)
print(*['{} : {}'.format(i,recording_name) for i, recording_name in enumerate(recordings.keys())], sep="\n")
check_recording_number = int(input("\nSelect Checkerboard recording : "))
checkerboard_name = recording_names[check_recording_number]

#Checkerboard frequency in Hz

stimulus_frequency = int(input("\nEnter Checkerboard frequency of the recording '{}.raw' : ".format(checkerboard_name)))

"""
    Processing
"""
###################################
#### Loading phy clusters info ####
###################################



rec_onsets = recording_onsets(recording_names, path = recording_directory)  

# Get cells index and number
cluster_number , good_clusters = extract_cluster_groups(phy_directory)

# Extract the spike times from the spike sorting files. This can take a few minutes.
print('Spike extraction: ')
all_spike_times = extract_all_spike_times_from_phy(phy_directory)

# create a dictionary with another dictionary for each cluster
all_neurons_data = split_spikes_by_recording(all_spike_times, cluster_number, rec_onsets)


#############################
#### Making raster plots ####
#############################

checkerboard_name = recording_names[check_recording_number]

checkerboard_spikes = get_recording_spikes(checkerboard_name, all_neurons_data)

trig_data = load_obj(os.path.normpath(os.path.join(params.triggers_directory,'{}_{}_triggers.pkl'.format(exp,checkerboard_name))))
triggers = trig_data['indices']/fs
raster_data = {}

print('Building Raster plots : ')
for (cell_nb, spike_times) in tqdm(checkerboard_spikes.items()):
    # Align triggers and spike times
    aligned_triggers, aligned_spike_times = align_triggers_spikes(triggers, spike_times)
    
    # Get rasters on repeated sequence
    raster_data[cell_nb] = build_rasters(aligned_spike_times, aligned_triggers, stim_frequency = stimulus_frequency)


"""
    Saving
"""

#Save all clusters rasters plots    
fig_directory = os.path.normpath(os.path.join(phy_directory,r'clusters_rasters_{}'.format(check_recording_number)))
if not os.path.isdir(fig_directory): os.makedirs(fig_directory)

    
print("Saving rasters plots of all clusters :")
for cell_nb in tqdm(raster_data.keys()):
    fig, axs = plt.subplots(nrows = 2,ncols = 1, sharex=True, gridspec_kw={'height_ratios': [3, 1]}, figsize=(10,10))

    plt.suptitle(f'Cell {cell_nb}')

    ax_rast = axs[0]
    ax_rast.eventplot(raster_data[cell_nb]["spike_trains"])
    ax_rast.set(title = "Raster plot", ylabel='N Repetitions')

    ax_psth = axs[1]
    width = (raster_data[cell_nb]["repeated_sequences_times"][0][0]/600)
    ax_psth.bar(np.linspace(0,raster_data[cell_nb]["repeated_sequences_times"][0][0],int(nb_frames_by_sequence/2))+width/2, raster_data[cell_nb]["psth"], width=1.3*width)
    ax_psth.set(xlabel='Time in sec', ylabel='Firing rate (spikes/s)')

    plt.subplots_adjust(wspace=0, hspace=0)
    fig_file = os.path.normpath(os.path.join(fig_directory,f'Cluster_{cell_nb}.png'))
    plt.savefig(fig_file, dpi=fig.dpi)
    plt.clf()
    plt.close()
    
"""
    Output
    
    Save :
    
    ""{phy_directory}/clusters_rasters/Cluster_{Cluster_number}.png" for each found clusters in phy's files
"""    

print('\n\t\t\t------ End Of Cell ------')

In [None]:
out = f"phy template-gui {params.phy_directory}/params.py"

!env QTWEBENGINE_CHROMIUM_FLAGS="--single-process" {out}

# <center>/!\/!\/!\ Run this after sorting /!\/!\/!\ </center>

### Cell 7 : Extract data per neurons

#### <center>REQUIRES CELL 1 RUN</center>

Extract all data from phy numpy variables. Create&save a dictionnary containg spikes times in sec for each neuron splited by recording. Depending on your experiment, this can take severeal minutes.

In [32]:
"""
    Variable
    
    You will find here all variables used in this notebook cell. They should always refere to your 'params.py' file
    except if you want to manually change some variable only for this run (i.e. debugging). You may have to add those
    variable into the function you want to adapt as only the minimal amount of var are currently given to functions as inputs.
"""

# Name of your experiment
exp = params.exp

#Path to the folder with the phy output
phy_directory = params.phy_directory

#Path to where data should be saved
output_directory = params.output_directory

#Path to rax recording files
recoding_directory = params.recording_directory

#Frequency of sampling of the mea
fs = params.fs


"""
    Processing
"""

rec_onsets    = recording_onsets(recording_names, path = recording_directory)  

# Get cells index and number
cluster_number , good_clusters = extract_cluster_groups(phy_directory)
print("There are {} good clusters ({} clusters in total)\n".format(len(good_clusters), len(cluster_number)))


# Extract the spike times from the spike sorting files. This can take a few minutes.
print('Spike extraction: ')
all_spike_times = extract_all_spike_times_from_phy(phy_directory)

print('\n')
print('Spike division in recordings per neuron:')
# create a dictionary with another dictionary for each good cluster
good_data = split_spikes_by_recording(all_spike_times, good_clusters, rec_onsets)


# Save the spike data. This can take a few minutes.
good_data_file_name = os.path.join(output_directory,r'{}_fullexp_neurons_data.pkl'.format(exp))
save_obj(good_data,good_data_file_name)

"""
    Output
    
    data (dict) : key 'cluster_id' --> (dict) key 'recording_name' --> This neuron & this recording spikes times in sec

"""

print('\n\t\t\t------ End Of Cell ------')

There are 0 good clusters (151 clusters in total)

Spike extraction: 


  0%|          | 0/3933362 [00:00<?, ?it/s]



Spike division in recordings per neuron:


0it [00:00, ?it/s]


			------ End Of Cell ------


In [13]:
a = np.load(os.path.join(params.phy_directory, 'spike_clusters.npy'))
fold='/media/guiglaz/Guilhem_01/deby/20220722_VIP.Project_18betaG/sorting/recording_0/recording_0.GUI'
b = np.load(os.path.join(fold, 'spike_clusters.npy'))

In [34]:
a = np.load(os.path.join(params.phy_directory, 'spike_times.npy'))
fold='/media/guiglaz/Guilhem_01/deby/20220722_VIP.Project_18betaG/sorting/recording_0/recording_0.GUI'
b = np.load(os.path.join(fold, 'spike_times.npy'))

In [36]:
a.shape

(3933362, 1)