# Check true spike detection status

author: steeve.laquitaine@epfl.ch


Useful:  
* SpikeInterface: "For all pairs of GT unit and tested unit we first count how many events are matched within a delta_time tolerance (0.4 ms by default)." (see ref 1)

Virtual env is `env_kilosort_silico`

Before running this notebook, extract the templates from a Kilosort 3 (KS3) run by running `python3.9 -m src.pipes.postpro.univ_temp` in the terminal. See usage in the `univ_temp` module.


In [1]:
# SETUP PACKAGES 
%load_ext autoreload
%autoreload 2
import os 
import numpy as np
import pandas as pd
from spikeinterface import comparison

# SET PROJECT PATH
PROJ_PATH = "/gpfs/bbp.cscs.ch/project/proj68/home/laquitai/spike-sorting"
os.chdir(PROJ_PATH)


# SETUP PROJECT PACKAGE
from src.nodes.utils import get_config
from src.nodes.truth.silico import ground_truth
from src.nodes.io.silico import sorting
from src.nodes.postpro import spike_detection, metrics


# SET PARAMETERS
EXPERIMENT = "silico_neuropixels"
SIMULATION_DATE = "2023_02_19"
SAMPLING_FREQ = 10000 # sample/sec


# EXAMPLE UNIT PAIR MATCHED
TRUE_UNIT = 19690
SORTED_UNIT = 255

# MATCHING PARAMETERS
MATCH_WIND_MS = 0.4


# SET CONFIG
data_conf, param_conf = get_config(EXPERIMENT, SIMULATION_DATE).values()

# SET PATHS 
CELL_MATCHING_PATH = data_conf["postprocessing"]["cell_matching"]


2023-05-25 21:29:55,371 - root - utils.py - get_config - INFO - Reading experiment config.
2023-05-25 21:29:55,394 - root - utils.py - get_config - INFO - Reading experiment config. - done


## Get the true/sorted spike hits for an example true unit

In [2]:
# Convert the matching tolerance from milliseconds to timepoints.
# int() truncates, which is also how SpikeInterface turns its delta_time into frames.
# NOTE(review): SpikeInterface derives its window from the extractors' sampling
# frequency, so SAMPLING_FREQ should equal the recording's rate — confirm.
match_wind = int(MATCH_WIND_MS * SAMPLING_FREQ / 1000)

# Precomputed extractors: ground truth first, then the sorter output.
Truth = ground_truth.load(data_conf)
Sorting = sorting.load(data_conf)

# Look for hits between every spike of TRUE_UNIT and all sorted spike timestamps.
out = spike_detection.match_a_true_unit_spikes_to_all_sorted_spikes(
    true_unit_id=TRUE_UNIT,
    Truth=Truth,
    Sorting=Sorting,
    match_wind=match_wind,
)
out

2023-05-25 21:29:59,495 - root - ground_truth.py - load - INFO - loading already processed ground truth SortingExtractor ...
2023-05-25 21:29:59,508 - root - ground_truth.py - load - INFO - loading already processed true sorting - done in 0.0


{'sorted_ttps_hits': {1516374: [],
  2553188: [],
  2700057: [2700054],
  2754658: [2754661],
  4813134: []},
 'all_sorted_ttps': array([    468,     571,     802, ..., 5499805, 5499815, 5499853]),
 'unit_labels_for_sorted_ttps': array([297., 298., 219., ..., 184., 283.,  99.]),
 'sorted_unit_hits': {1516374: [],
  2553188: [],
  2700057: [255],
  2754658: [298],
  4813134: []}}

## Check spike detection status for the example true unit

In [3]:
# Table with one `detected` flag per spike of the chosen true unit.
is_spike_detected = spike_detection.get_true_unit_spikes_detection_status(
    true_unit_id=TRUE_UNIT,
    data_conf=data_conf,
)
is_spike_detected

2023-05-25 21:30:01,315 - root - ground_truth.py - load - INFO - loading already processed ground truth SortingExtractor ...
2023-05-25 21:30:01,325 - root - ground_truth.py - load - INFO - loading already processed true sorting - done in 0.0


  dict([(k, pd.Series(v)) for k, v in hits_dict.items()])


Unnamed: 0_level_0,detected
events,Unnamed: 1_level_1
1516374,False
2553188,False
2700057,True
2754658,True
4813134,False


## Check my agreement score vs. SpikeInterface's for a true/sorted match

In [4]:
# Hit counts derived from the matching result of the previous section.
hit_count = metrics.get_hit_counts_for_a_true_units(out)

# Event counts, keyed by unit id: one entry for the true unit, one for the sorted unit.
event_counts_truth = {
    TRUE_UNIT: metrics.get_event_count_truth(unit_id=TRUE_UNIT, Truth=Truth),
}
event_counts_sorting = {
    SORTED_UNIT: metrics.get_event_count_sorting(unit_id=SORTED_UNIT, Sorting=Sorting),
}

# Agreement score for the example (true, sorted) unit pair.
agreement_score = metrics.get_agreement_score(
    TRUE_UNIT,
    SORTED_UNIT,
    hit_count,
    event_counts_truth,
    event_counts_sorting,
)

  dict([(k, pd.Series(v)) for k, v in hits_dict.items()])


In [5]:
# Check that SpikeInterface produces the same agreement score.
# Pass the matching tolerance explicitly (delta_time is in ms) so SpikeInterface
# uses the same window as this notebook instead of its implicit default.
MatchingObject = comparison.compare_sorter_to_ground_truth(
    Truth, Sorting, exhaustive_gt=True, delta_time=MATCH_WIND_MS
)

In [6]:
# Sanity check: both agreement scores are floats computed by independent code
# paths, so compare them with a tolerance rather than exact equality.
si_agreement_score = MatchingObject.agreement_scores.loc[TRUE_UNIT, SORTED_UNIT]
assert np.isclose(si_agreement_score, agreement_score), (
    f"Your agreement score ({agreement_score}) differs from "
    f"SpikeInterface's ({si_agreement_score})"
)

## References

(1) https://spikeinterface.readthedocs.io/en/0.96.1/module_comparison.html#compare-the-output-of-multiple-spike-sorters