# (6m) By time

TODO:
- Make sure the number of ground-truth units is the true number of units within 50 microns that spike at least once.

## Setup

Activate the `spikeinterf..` environment before running this notebook.

In [9]:
# SETUP PACKAGES 
%load_ext autoreload
%autoreload 2
import os
import numpy as np
from matplotlib import pyplot as plt
from matplotlib.gridspec import GridSpec
import spikeinterface as si
from spikeinterface import comparison

# PROJECT ROOT
# The working directory is switched to the project root before the project-local
# `src` imports below (presumably so that the `src` package resolves - confirm).
PROJ_PATH = "/gpfs/bbp.cscs.ch/project/proj85/home/laquitai/preprint_2023"
os.chdir(PROJ_PATH)
from src.nodes.utils import get_config
from src.nodes.postpro.accuracy import get_sorting_accuracies

# DATASET PATHS
# get_config(...).values() is unpacked into two items; only the first is used here.

# Buccino 2020: Kilosort3 output and ground-truth sorting
data_conf, _ = get_config("buccino_2020", "2020").values()
bucci_sorting_conf = data_conf["sorting"]
BUCCI_KS3_SORTING_PATH = bucci_sorting_conf["sorters"]["kilosort3"]["output"]
BUCCI_GT_SORTING_PATH = bucci_sorting_conf["simulation"]["ground_truth"]["output"]

# Silico neuropixels (concatenated): preprocessed traces, ground truth and Kilosort output
data_conf_marques, _ = get_config("silico_neuropixels", "concatenated").values()
marques_sorting_conf = data_conf_marques["sorting"]
PREP_PATH_M = data_conf_marques["preprocessing"]["output"]["trace_file_path"]
GT_SORTING_PATH_marques = marques_sorting_conf["simulation"]["ground_truth"]["output"]
KS_SORTING_PATH_marques = marques_sorting_conf["sorters"]["kilosort"]["output"]

# FIGURE CONSTANTS
COLOR_VIVO = (0.7, 0.7, 0.7)
COLOR_SILI = (0.84, 0.27, 0.2)
COLOR_STIM = (0.6, 0.75, 0.1)
BOX_ASPECT = 1  # square fig
FIG_SIZE = (1, 1)

# Matplotlib defaults, applied in a single call (same keys, values and order as before)
plt.rcParams.update(
    {
        "figure.figsize": (2, 1),
        "font.family": "Arial",
        "font.size": 6,
        "lines.linewidth": 0.2,
        "axes.linewidth": 0.5,
        "axes.spines.top": False,
        "axes.spines.right": False,
        "xtick.major.width": 0.3,
        "xtick.minor.size": 0.1,
        "xtick.major.size": 1.5,
        "ytick.major.size": 1.5,
        "ytick.major.width": 0.3,
    }
)

# Keyword-argument bundles for legend and savefig calls
legend_cfg = {"frameon": False, "handletextpad": 0.1}
savefig_cfg = {"transparent": True}

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
2024-05-08 18:31:01,130 - root - utils.py - get_config - INFO - Reading experiment config.
2024-05-08 18:31:01,147 - root - utils.py - get_config - INFO - Reading experiment config. - done
2024-05-08 18:31:01,149 - root - utils.py - get_config - INFO - Reading experiment config.
2024-05-08 18:31:01,179 - root - utils.py - get_config - INFO - Reading experiment config. - done


In [12]:
# Total duration of the preprocessed recording, converted from get_total_duration() to minutes
duration_min = si.load_extractor(PREP_PATH_M).get_total_duration() / 60
print("duration:", duration_min, "min")

duration: 34.299982500000006 min




### (6m) Cumulative period

In [15]:
# Cumulative windows: each one starts at frame 0 and ends 10, 20, 30 or 40 min later.
SAMPLING_FREQ = 40000  # samples per second used to convert minutes into frame indices
FRAMES_PER_MIN = 60 * SAMPLING_FREQ
CUMULATIVE_ENDS_MIN = (10, 20, 30, 40)
# NOTE(review): the duration cell reports ~34.3 min, so the 40-min window cannot hold
# more data than the whole recording - confirm this window is intended.

# full
SortingTrue = si.load_extractor(GT_SORTING_PATH_marques)
SortingKS = si.load_extractor(KS_SORTING_PATH_marques)

# One ground-truth comparison per window. For each, keep the row-wise maximum of the
# agreement-score table, sorted from highest to lowest.
acc_by_end_min = {}
for end_min in CUMULATIVE_ENDS_MIN:
    end_frame = end_min * FRAMES_PER_MIN
    comp = comparison.compare_sorter_to_ground_truth(
        SortingTrue.frame_slice(start_frame=0, end_frame=end_frame),
        SortingKS.frame_slice(start_frame=0, end_frame=end_frame),
        exhaustive_gt=True,
    )
    acc_by_end_min[end_min] = (
        comp.agreement_scores.max(axis=1).sort_values(ascending=False).values
    )

# Names read by the summary cell below
acc_ks_10m = acc_by_end_min[10]
acc_ks_20m = acc_by_end_min[20]
acc_ks_30m = acc_by_end_min[30]
acc_ks_40m = acc_by_end_min[40]



In [16]:
# Fraction of scores at or above 0.8, one line per cumulative window (10, 20, 30, 40 min)
for window_scores in (acc_ks_10m, acc_ks_20m, acc_ks_30m, acc_ks_40m):
    n_above = sum(window_scores >= 0.8)
    print(n_above / window_scores.shape[0])

0.0792507204610951
0.08213256484149856
0.04755043227665706
0.00792507204610951


### Contiguous periods

In [4]:
# Contiguous, non-overlapping 10-min windows: [0, 10), [10, 20), [20, 30), [30, 40) min.
SAMPLING_FREQ = 40000  # samples per second used to convert minutes into frame indices
FRAMES_PER_MIN = 60 * SAMPLING_FREQ
CONTIGUOUS_WINDOWS_MIN = ((0, 10), (10, 20), (20, 30), (30, 40))
# NOTE(review): the duration cell reports ~34.3 min, so the last window holds only
# about 4.3 min of data and is not directly comparable to the others - confirm.

# full
SortingTrue = si.load_extractor(GT_SORTING_PATH_marques)
SortingKS = si.load_extractor(KS_SORTING_PATH_marques)

# One ground-truth comparison per window. For each, keep the row-wise maximum of the
# agreement-score table, sorted from highest to lowest.
acc_by_window = {}
for start_min, end_min in CONTIGUOUS_WINDOWS_MIN:
    # frame_slice keeps start_frame <= frame < end_frame, so a window may start exactly
    # where the previous one ended: no frame is shared and none is skipped. Adding 1 to
    # start_frame would silently drop the boundary frame.
    bounds = dict(
        start_frame=start_min * FRAMES_PER_MIN, end_frame=end_min * FRAMES_PER_MIN
    )
    comp = comparison.compare_sorter_to_ground_truth(
        SortingTrue.frame_slice(**bounds),
        SortingKS.frame_slice(**bounds),
        exhaustive_gt=True,
    )
    acc_by_window[(start_min, end_min)] = (
        comp.agreement_scores.max(axis=1).sort_values(ascending=False).values
    )

# Names read by the summary cell below (the suffix is the window's end, in minutes)
acc_ks_10m = acc_by_window[(0, 10)]
acc_ks_20m = acc_by_window[(10, 20)]
acc_ks_30m = acc_by_window[(20, 30)]
acc_ks_40m = acc_by_window[(30, 40)]



In [5]:
# Fraction of scores at or above 0.8, one line per contiguous window
for window_scores in (acc_ks_10m, acc_ks_20m, acc_ks_30m, acc_ks_40m):
    n_above = sum(window_scores >= 0.8)
    print(n_above / window_scores.shape[0])

0.0792507204610951
0.15273775216138327
0.12319884726224783
0.07564841498559077


### Overlapping over simulations

In [13]:
# 10-min windows shifted by 5 min relative to the contiguous ones:
# [5, 15), [15, 25) and [25, 35) min.
SAMPLING_FREQ = 40000  # samples per second used to convert minutes into frame indices
FRAMES_PER_MIN = 60 * SAMPLING_FREQ
SHIFTED_WINDOWS_MIN = ((5, 15), (15, 25), (25, 35))
# NOTE(review): the duration cell reports ~34.3 min, so the last window is slightly
# shorter than 10 min - confirm this is acceptable.

# full
SortingTrue = si.load_extractor(GT_SORTING_PATH_marques)
SortingKS = si.load_extractor(KS_SORTING_PATH_marques)

# One ground-truth comparison per window. For each, keep the row-wise maximum of the
# agreement-score table, sorted from highest to lowest.
acc_by_window = {}
for start_min, end_min in SHIFTED_WINDOWS_MIN:
    # frame_slice keeps start_frame <= frame < end_frame, so a window may start exactly
    # where the previous one ended: no frame is shared and none is skipped. Adding 1 to
    # start_frame would silently drop the boundary frame.
    bounds = dict(
        start_frame=start_min * FRAMES_PER_MIN, end_frame=end_min * FRAMES_PER_MIN
    )
    comp = comparison.compare_sorter_to_ground_truth(
        SortingTrue.frame_slice(**bounds),
        SortingKS.frame_slice(**bounds),
        exhaustive_gt=True,
    )
    acc_by_window[(start_min, end_min)] = (
        comp.agreement_scores.max(axis=1).sort_values(ascending=False).values
    )

# Names read by the summary cell below: first, second and third shifted window
acc_ks_10m = acc_by_window[(5, 15)]
acc_ks_20m = acc_by_window[(15, 25)]
acc_ks_30m = acc_by_window[(25, 35)]

In [14]:
# Fraction of scores at or above 0.8, one line per shifted window
for window_scores in (acc_ks_10m, acc_ks_20m, acc_ks_30m):
    n_above = sum(window_scores >= 0.8)
    print(n_above / window_scores.shape[0])

0.13544668587896252
0.14481268011527376
0.0792507204610951


### First and second half

In [17]:
# Two 20-min windows: first half [0, 20) min and second half [20, 40) min.
SAMPLING_FREQ = 40000  # samples per second used to convert minutes into frame indices
FRAMES_PER_MIN = 60 * SAMPLING_FREQ
HALF_WINDOWS_MIN = ((0, 20), (20, 40))
# NOTE(review): the duration cell reports ~34.3 min, so the second window holds only
# about 14.3 min of data - confirm the two halves are meant to be unequal.

# full
SortingTrue = si.load_extractor(GT_SORTING_PATH_marques)
SortingKS = si.load_extractor(KS_SORTING_PATH_marques)

# One ground-truth comparison per window. For each, keep the row-wise maximum of the
# agreement-score table, sorted from highest to lowest.
acc_by_window = {}
for start_min, end_min in HALF_WINDOWS_MIN:
    # frame_slice keeps start_frame <= frame < end_frame, so the second half may start
    # exactly where the first one ended: no frame is shared and none is skipped. Adding 1
    # to start_frame would silently drop the boundary frame.
    bounds = dict(
        start_frame=start_min * FRAMES_PER_MIN, end_frame=end_min * FRAMES_PER_MIN
    )
    comp = comparison.compare_sorter_to_ground_truth(
        SortingTrue.frame_slice(**bounds),
        SortingKS.frame_slice(**bounds),
        exhaustive_gt=True,
    )
    acc_by_window[(start_min, end_min)] = (
        comp.agreement_scores.max(axis=1).sort_values(ascending=False).values
    )

# Legacy names read by the summary cell below: despite the suffixes, these hold the
# first-half and second-half scores respectively.
acc_ks_10m = acc_by_window[(0, 20)]
acc_ks_20m = acc_by_window[(20, 40)]



In [18]:
# Fraction of scores at or above 0.8: first half, then second half
for half_scores in (acc_ks_10m, acc_ks_20m):
    n_above = sum(half_scores >= 0.8)
    print(n_above / half_scores.shape[0])

0.08213256484149856
0.0792507204610951


### Full

In [38]:
# Whole-recording comparison: both sortings are loaded and restricted to frames [0, 40 min)
full_end_frame = 40 * 60 * 40000

SortingTrue = si.load_extractor(GT_SORTING_PATH_marques).frame_slice(
    start_frame=0, end_frame=full_end_frame
)
SortingKS = si.load_extractor(KS_SORTING_PATH_marques).frame_slice(
    start_frame=0, end_frame=full_end_frame
)

# Compare against ground truth, then keep the row-wise maximum agreement score,
# ordered from highest to lowest
comp = comparison.compare_sorter_to_ground_truth(
    SortingTrue, SortingKS, exhaustive_gt=True
)
best_scores = comp.agreement_scores.max(axis=1)
acc_ks_m = best_scores.sort_values(ascending=False).values



In [48]:
# Fraction of scores at or above 0.8 for the whole-recording comparison
n_above = sum(acc_ks_m >= 0.8)
print(n_above / acc_ks_m.shape[0])

0.00792507204610951


In [52]:
# check that SortingTrue has timestamps covering the duration of the recording
# First field of the last spike-vector entry (assumes the spike vector is ordered by
# sample index - TODO confirm)
print("last spike sample:", SortingTrue.to_spike_vector()[-1][0])
Rec = si.load_extractor(PREP_PATH_M)
# Ask the recording for its sample count instead of materialising every trace with
# get_traces() just to take its length. The printed value is the number of samples.
print("last recording sample:", Rec.get_num_samples())

last spike sample: 82319846
last recording sample: 82319958
