# Speed up chance calculation for score matrix

author: steeve.laquitaine@epfl.ch  
last modified: 13-02-2024

**Method**:

* **delta_time (Δ𝑡)** = 1.3 ms: the time windows before and after the spike timestamp of a ground truth. When a the timestamp of a sorted unit falls within this time window, they coincide and the sorted timestamp is a hit.
* **chance level score**: see paper
* **dark (missed) units**: ground truth units with sorting accuracy below the chance agreement score (their best match with a sorted unit produce an agreement score below chance).
* **false positive units**: units which timestamps never hit the timestamps of the ground truth units wihin 50 microns of the probe: they never fall withing the delta_time window.


### Setup

Create or activate env `spikeinterf...`

In [2]:
%load_ext autoreload
%autoreload 2
import os 
from matplotlib import pyplot as plt
import numpy as np
from collections import Counter
import pandas as pd
import seaborn as sns 
import spikeinterface as si
from spikeinterface import comparison
import copy
from concurrent.futures import ProcessPoolExecutor
proj_path = "/gpfs/bbp.cscs.ch/project/proj85/home/laquitai/preprint_2023/"
os.chdir(proj_path)

from src.nodes.postpro.cell_matching import get_SpikeInterface_matching_object
from src.nodes.utils import get_config
from src.nodes.postpro.feateng import (add_firing_rates)
from src.nodes.analysis.failures import accuracy as acc
from src.nodes.metrics.metrics import get_firing_rate

# PARAMETERS
REC_DURATION = 600 # 10 minutes recording
DET = 0.8
CHANCE_THRESH = 0.1

# DATASETS

# NPX
# Synthetic
cfg_nb, _ = get_config("buccino_2020", "2020").values()
GT_nb_10m = cfg_nb["sorting"]["simulation"]["ground_truth"]["10m"]["output"]
KS4_nb = cfg_nb["sorting"]["sorters"]["kilosort3"]["10m"]["output"]
REC_nb = cfg_nb["probe_wiring"]["full"]["output"]

# biophy spont
cfg_ns, _ = get_config("silico_neuropixels", "concatenated").values()
KS4_ns = cfg_ns["sorting"]["sorters"]["kilosort3"]["output"]
GT_ns_10m = cfg_ns["sorting"]["simulation"]["ground_truth"]["10m"]["output"]
REC_ns = cfg_ns["probe_wiring"]["full"]["output"]

# biophy evoked
cfg_ne, _ = get_config("silico_neuropixels", "stimulus").values()
KS4_ne = cfg_ne["sorting"]["sorters"]["kilosort3"]["output"]
GT_ne_10m = cfg_ne["sorting"]["simulation"]["ground_truth"]["10m"]["output"]
REC_ne = cfg_ne["probe_wiring"]["full"]["output"]

# FIGURE SETTINGS
COLOR_VIVO = (0.7, 0.7, 0.7)
COLOR_SILI = (0.84, 0.27, 0.2)
COLOR_STIM = (0.6, 0.75, 0.1)
BOX_ASPECT = 1                  # square fig
FIG_SIZE = (1,1)
plt.rcParams['figure.figsize'] = (2,1)
plt.rcParams["font.family"] = "Arial"
plt.rcParams["font.size"] = 6
plt.rcParams['lines.linewidth'] = 0.2
plt.rcParams['axes.linewidth'] = 0.5
plt.rcParams['axes.spines.top'] = False
plt.rcParams['axes.spines.right'] = False
plt.rcParams['xtick.major.width'] = 0.3
plt.rcParams['xtick.minor.size'] = 0.1
plt.rcParams['xtick.major.size'] = 1.5
plt.rcParams['ytick.major.size'] = 1.5
plt.rcParams['ytick.major.width'] = 0.3
legend_cfg = {"frameon": False, "handletextpad": 0.1}
savefig_cfg = {"transparent":True}
# print(plt.rcParams.keys())

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
2024-08-26 17:17:19,237 - root - utils.py - get_config - INFO - Reading experiment config.
2024-08-26 17:17:19,282 - root - utils.py - get_config - INFO - Reading experiment config. - done
2024-08-26 17:17:19,287 - root - utils.py - get_config - INFO - Reading experiment config.
2024-08-26 17:17:19,348 - root - utils.py - get_config - INFO - Reading experiment config. - done
2024-08-26 17:17:19,350 - root - utils.py - get_config - INFO - Reading experiment config.
2024-08-26 17:17:19,432 - root - utils.py - get_config - INFO - Reading experiment config. - done


### Chance score

In [3]:
# calculate chance score for a 1 Hz sorted unit (spikes/secs)
fr_gt = 1
fr_s = 1
delta_time = 1.3  # in ms
rec_dur = 600  # recording duration

# calculate chance agreement score
# - chance probability of hits
# - chance score
p_chance_hit = acc.get_p_chance_hit(1 / 1000, 1.3)
chance_acc = acc.get_unit_chance_agreement_score(fr_gt, fr_s, 600, p_chance_hit)

0.0012999992676670962

### Define custom functions

In [4]:
def classify_sorted_unit_biases(agreem_mx, det):
    """_summary_

    Args:
        agreem_mx (_type_): _description_

    Returns:
        (dict):
        - "match" (pd.DataFrame): N ground truth unit indices, N best-match sorted units

    """

    # note: with this approach (BEST matching approach), the same sorted unit can be paired with more than one true unit
    # we only keep the pairings with highest agreement scores
    # true-sorted unit pairing
    pairing = agreem_mx.T.idxmax(axis=1)
    pairing = pairing.to_frame()
    pairing.columns = ["sorted"]

    # add agreement score
    accuracy = agreem_mx.T.max(axis=1)
    pairing["accuracy"] = accuracy

    # check if the only sorted unit paired with this true unit
    sorted_ids = agreem_mx.index

    df = copy.copy(pairing.iloc[0, :].to_frame().T)
    false_positives = []

    # else keep the pairing with highest agreement score
    # loop over all sorted single unit units
    for ix in range(len(sorted_ids)):
        # case the sorted unit was paired with a ground truth unit
        if any(pairing["sorted"] == sorted_ids[ix]):
            sorted_pairings = pairing[pairing["sorted"] == sorted_ids[ix]].sort_values(
                by="accuracy", ascending=False
            )
            # take max pairing (first row)
            df = pd.concat([df, sorted_pairings.iloc[0, :].to_frame().T])
        else:
            # case the sorted unit was paired with none of the ground truth units
            false_positives.append(sorted_ids[ix])

    df = df[1:]
    df["sorted"] = df["sorted"].astype(int)

    # count biases
    n_good = sum(df["accuracy"] >= det)
    n_poor = sum((df["accuracy"] >= CHANCE_THRESH) & (df["accuracy"] < det))
    n_below_chance = sum((df["accuracy"] > 0) & (df["accuracy"] < CHANCE_THRESH))
    n_false_pos = len(false_positives)

    # sanity check
    # assert n_good + n_poor + n_below_chance + n_false_pos == len(
    #     sorted_ids
    # ), "They must match"
    return {
        "n_good": n_good,
        "n_poor": n_poor,
        "n_below_chance": n_below_chance,
        "n_false_pos": n_false_pos,
        "match": pairing,
    }

    """_summary_

    Some scores can be 0, if they were the best found

    Args:
        scores (pd.DataFrame): _description_

    Returns:
        _type_: _description_
    """
    # the same sorted unit can be paired with more than one true unit
    # we keep the pairings with highest agreement scores
    # true-sorted unit pairing
    match = scores.idxmax(axis=0)
    match = match.to_frame()
    match.columns = ["sorted"]
    match.index.name = "true"

    # record agreement score
    match["accuracy"] = scores.max(axis=0)
    return match


def count_match(match: pd.DataFrame, sorted: int):
    """count the number of true units
    one sorted unit matches

    Args:
        match (pd.DataFrame):
        - index: true units id
        - columns:
        - - "sorted": best-match sorted unit id
        - - "accuracy": best-match agreement score
        (true unit sorting accuracy)
    """
    return sum(match["sorted"] == sorted)


def get_best_matched_true_unit(match, s_id):
    """get true unit matchs

    Args:
        match (_type_): _description_
        s_id (_type_): _description_

    Returns:
        _type_: _description_
    """
    units = match[match["sorted"] == s_id].index.values
    if len(units) == 1:
        units = units[0]
    return units


def get_all_matched_true_unit(s_id, scores):
    """get all true unit matchs

    Args:
        match (_type_): _description_
        s_id (_type_): _description_

    Returns:
        _type_: _description_
    """
    return np.array(scores.columns[scores.loc[s_id, :] > 0])


def get_chance_score(gt_id, s_id, Sorting_ns, SortingTrue_ns, duration, delta_time):
    fr_s = get_firing_rate(s_id, Sorting_ns, duration)
    fr_gt = get_firing_rate(gt_id, SortingTrue_ns, duration)
    p_chance_hit = acc.get_p_chance_hit(min(fr_gt, fr_s) / 1000, delta_time)
    return acc.get_unit_chance_agreement_score(fr_gt, fr_s, duration, p_chance_hit)


def get_score(scores, s_id, gt_id):
    try:
        return scores.loc[s_id, gt_id]
    except:
        from ipdb import set_trace

        set_trace()


def is_score_at_chance(
    scores, s_id, gt_id, Sorting_ns, SortingTrue_ns, duration, delta_time
):
    score = get_score(scores, s_id, gt_id)
    chance = get_chance_score(
        gt_id, s_id, Sorting_ns, SortingTrue_ns, duration, delta_time
    )
    return score <= chance


def are_all_scores_at_chance(
    s_id, gt_id, scores, Sorting_ns, SortingTrue_ns, duration, delta_time
):
    """check whether a sorted unit agreement scores with the many
    ground truth units it matches are all at chance (it is a false
    positive)

    Args:
        s_id (_type_): sorted unit id
        gt_id (_type_): ground truth unit id
        scores (_type_): _description_
        match (_type_): _description_
        Sorting_ns (_type_): _description_
        SortingTrue_ns (_type_): _description_
        duration (_type_): _description_
        delta_time (_type_): _description_

    Returns:
        _type_: _description_
    """

    # make an 1D array
    gt_id = make_1darray(gt_id)

    # start counting scores below chance
    d = 0
    for _, g_i in enumerate(gt_id):
        d += is_score_at_chance(
            scores, s_id, g_i, Sorting_ns, SortingTrue_ns, duration, delta_time
        )
    return d == len(gt_id)


def make_1darray(gt_id):
    if isinstance(gt_id, np.ndarray):
        return gt_id
    else:
        return np.array([gt_id])


def is_oversplitter_1(
    scores, match, s_id, gt_id, Sorting_ns, SortingTrue_ns, duration, delta_time
):
    """check when sorted unit matches only one ground truth unit

    Args:
        scores (_type_): _description_
        match (_type_): _description_
        s_id (_type_): _description_
        gt_id (_type_): _description_
        Sorting_ns (_type_): _description_
        SortingTrue_ns (_type_): _description_
        duration (_type_): _description_
        delta_time (_type_): _description_

    Returns:
        _type_: _description_
    """
    # find all sorted units that match this ground truth unit
    sorteds = match[match.index == gt_id]["sorted"].values.tolist()

    d = 0
    # loop over sorted unit match
    for _, s_i in enumerate(sorteds):

        # check for other sorted unit match
        # than the target match s_id
        if not s_i == s_id:
            d += is_score_at_chance(
                scores,
                s_i,
                gt_id,
                Sorting_ns,
                SortingTrue_ns,
                duration,
                delta_time,
            )
    return d > 0


def is_oversplitter_2(
    s_id, gt_id, scores, match, Sorting_ns, SortingTrue_ns, duration, delta_time
):
    """check when sorted unit matches many ground truth units

    Args:
        s_id (_type_): _description_
        scores (_type_): _description_
        match (_type_): _description_
        Sorting_ns (_type_): _description_
        SortingTrue_ns (_type_): _description_
        duration (_type_): _description_
        delta_time (_type_): _description_

    Returns:
        _type_: _description_
    """
    d = 0
    for _, g_i in enumerate(gt_id):
        d += is_oversplitter_1(
            scores,
            match,
            s_id,
            g_i,
            Sorting_ns,
            SortingTrue_ns,
            duration,
            delta_time,
        )
    return d > 0


def is_poor(det, scores, gt_id, s_id, Sorting_ns, SortingTrue_ns, duration, delta_time):
    """is poor (score between chance and threshold for "good" (80%))

    Args:
        scores (_type_): _description_
        gt_id (_type_): _description_
        s_id (_type_): _description_
        Sorting_ns (_type_): _description_
        SortingTrue_ns (_type_): _description_
        duration (_type_): _description_
        delta_time (_type_): _description_

    Returns:
        _type_: _description_
    """
    chance = get_chance_score(
        gt_id, s_id, Sorting_ns, SortingTrue_ns, duration, delta_time
    )
    score = get_score(scores, s_id, gt_id)
    return (score > chance) & (score < det)


def is_poor_2(
    s_id, gt_id, scores, Sorting_ns, SortingTrue_ns, duration, delta_time, det
):
    """is poor when sorted unit matches many ground truth units
    (all its scores with the ground truth are below the "good" threshold
    and it has at least one score with a ground truth above chance

    Args:
        s_id (_type_): _description_
        match (_type_): _description_
        Sorting_ns (_type_): _description_
        SortingTrue_ns (_type_): _description_
        duration (_type_): _description_
        delta_time (_type_): _description_
        det (_type_): _description_

    Returns:
        _type_: _description_
    """
    chance_all = []
    score_all = []

    for _, g_i in enumerate(gt_id):
        chance_all.append(
            get_chance_score(
                g_i, s_id, Sorting_ns, SortingTrue_ns, duration, delta_time
            )
        )
        score_all.append(get_score(scores, s_id, g_i))

    # all scores are below DET
    test_1 = all(np.array(score_all) < det)

    # at least one score is above chance
    test_2 = any(np.array(score_all) > np.array(chance_all))
    return test_1 & test_2


def is_overmerger_2(
    s_id, gt_id, scores, Sorting_ns, SortingTrue_ns, duration, delta_time
):
    """is an overmerger: the sorted unit matches at least two ground
    truths with above chance scores

    Args:
        s_id (_type_): _description_
        match (_type_): _description_
        Sorting_ns (_type_): _description_
        SortingTrue_ns (_type_): _description_
        duration (_type_): _description_
        delta_time (_type_): _description_

    Returns:
        _type_: _description_
    """
    # get there scores
    chance_all = []
    score_all = []

    for _, g_i in enumerate(gt_id):
        # get chance scores
        chance_all.append(
            get_chance_score(
                g_i, s_id, Sorting_ns, SortingTrue_ns, duration, delta_time
            )
        )
        # get scores
        score_all.append(get_score(scores, s_id, g_i))
    return sum(np.array(score_all) > np.array(chance_all)) > 1


def is_good_2(s_id, gt_id, scores):
    """is sorted unit good: it matches one ground truth with 80% and no
    other ground truth with

    Args:
        s_id (_type_): _description_
        scores (_type_): _description_
        match (_type_): _description_
        Sorting_ns (_type_): _description_
        SortingTrue_ns (_type_): _description_
        duration (_type_): _description_
        delta_time (_type_): _description_

    Returns:
        _type_: _description_
    """
    score_all = []
    for _, g_i in enumerate(gt_id):
        score_all.append(get_score(scores, s_id, g_i))
    return any(np.array(score_all) >= DET)


def set_df(df, sorted, true, quality, score):
    df.loc[sorted, "sorted"] = sorted
    df.loc[sorted, "true"] = true
    df.loc[sorted, "score"] = score
    qual = df.loc[sorted, "quality"]
    # record quality if nan (empty)
    # else append new quality
    if isinstance(qual, str):
        df.loc[sorted, "quality"] += quality
    else:
        df.loc[sorted, "quality"] = quality
    return df


def get_scores(
    SortingTrue,
    Sorting,
    delta_time: float,
):
    comp = comparison.compare_sorter_to_ground_truth(
        SortingTrue,
        Sorting,
        exhaustive_gt=True,
        delta_time=delta_time,
    )
    # return comp.agreement_scores.max(axis=1).sort_values(ascending=False).values
    return comp.agreement_scores

## False positives

In [6]:
# load SortingExtractors
SortingTrue_ns = si.load_extractor(GT_ns_10m)
Sorting_ns = si.load_extractor(KS4_ns)
Rec_ns = si.load_extractor(REC_ns)

# get scores (N sorted units rows x N true units columns)
scores_ns = get_scores(SortingTrue_ns, Sorting_ns, 1.3)
scores_ns = scores_ns.T

# curate (get single-unit only)
scores_ns = scores_ns.loc[
    Sorting_ns.unit_ids[Sorting_ns.get_property("KSLabel") == "good"], :
]

scores_by_exp = [scores_ns]



In [21]:
# compute sorting quality
out = classify_sorted_unit_biases(scores_ns, DET)

**Method**:

* For each true unit we find its best-match sorted unit
* calculate chance level for all units is pairs is consuming-
* all NaN scores are sorted units' will many ground truth matches.
* 15 minutes


In [7]:
scores = scores_ns
scores

Unnamed: 0,12165,16652,18371,19690,21040,24768,29248,30168,32331,37423,...,4213917,4215563,4216128,4217493,4221920,4223302,4225319,4228700,4229218,4229506
143,0.000000,0.0,0.000000,0.0,0.0,0.0,0.0,0.000000,0.0,0.0,...,0.000000,0.000000,0.000000,0.0,0.000000,0.0,0.0,0.000000,0.0,0.001927
201,0.000000,0.0,0.000000,0.0,0.0,0.0,0.0,0.000000,0.0,0.0,...,0.001315,0.000804,0.000649,0.0,0.001650,0.0,0.0,0.001612,0.0,0.000878
208,0.000000,0.0,0.000000,0.0,0.0,0.0,0.0,0.000000,0.0,0.0,...,0.001019,0.002314,0.000000,0.0,0.000391,0.0,0.0,0.000000,0.0,0.000000
215,0.000000,0.0,0.000000,0.0,0.0,0.0,0.0,0.000000,0.0,0.0,...,0.001202,0.000451,0.001184,0.0,0.000970,0.0,0.0,0.000391,0.0,0.000000
217,0.000000,0.0,0.000000,0.0,0.0,0.0,0.0,0.000000,0.0,0.0,...,0.000622,0.000282,0.000000,0.0,0.000453,0.0,0.0,0.000599,0.0,0.002375
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1596,0.000000,0.0,0.000000,0.0,0.0,0.0,0.0,0.000000,0.0,0.0,...,0.000000,0.000000,0.000000,0.0,0.000000,0.0,0.0,0.000000,0.0,0.000000
1612,0.000000,0.0,0.000000,0.0,0.0,0.0,0.0,0.000000,0.0,0.0,...,0.000964,0.001316,0.001433,0.0,0.000842,0.0,0.0,0.001188,0.0,0.000590
1613,0.000000,0.0,0.000000,0.0,0.0,0.0,0.0,0.000000,0.0,0.0,...,0.002236,0.000537,0.000000,0.0,0.001676,0.0,0.0,0.000000,0.0,0.000000
1632,0.000109,0.0,0.000109,0.0,0.0,0.0,0.0,0.000109,0.0,0.0,...,0.001555,0.002293,0.001065,0.0,0.001194,0.0,0.0,0.001063,0.0,0.000525


In [8]:
from scipy.stats import poisson
from joblib import Parallel, delayed
from tqdm import tqdm

# import warnings

# warnings.filterwarnings("ignore", category=DeprecationWarning)


def get_p_chance_hit(fr: float, delta_time: float):
    """derive the chance probability of hits
    (coincidences between two independent sorted and
    ground truth unit spike trains)

    We should use the firing rate of the less firing
    of the two. It determines the expected maximum
    possible number of coincidences.

    Args:
        fr (float): firing rate in spikes/ms
        delta_time (float): SpikeInterface delta_time interval in ms

    Returns:
        _type_: _description_
    """
    k = 0  # we want the probability of k=0 coincidences
    interval_ms = 2 * delta_time  # time interval in ms
    n_sp = interval_ms * fr  # expected nb of coincidences
    return 1.0 - poisson.pmf(k=k, mu=n_sp)


def get_unit_chance_agreement_score(
    fr_gt: float, fr_s: float, rec_dur: float, p_chance_hit: float
):
    """get unit chance scorey

    The chance scorey metrics should change with the ground truth firing rate.
    It is not the case with the current calculation.
    Intuition: the more a ground truth unit spikes within the duration of recording (say 600 secs),
    the more spikes will be missed when compared a sorting unit of a fixed firing rate.
    The increasing number of misses should decrease the value of the chance score metrics,
    which is currently not the case.

    Args:
        fr_gt (float): ground truth firing rate (spikes/secs)
        fr_s (float): sorted unit firing rate (spikes/secs)
        p_chance_hit (float): chance probability of hits
        rec_dur (float): recording duration
    """
    # nb of spikes
    n_gt = fr_gt * rec_dur
    n_s = fr_s * rec_dur

    # nb of hits, false positives, misses
    # - the smallers spike train min(n_gt, n_s) determines
    # the maximum possible number of hits
    n_h = p_chance_hit * min(n_gt, n_s)
    n_fp = n_s - n_h
    n_m = n_gt - n_h
    return n_h / (n_h + n_m + n_fp)


def get_chance_score(gt_id, s_id, Sorting_ns, SortingTrue_ns, duration, delta_time):
    fr_s = get_firing_rate(s_id, Sorting_ns, duration)
    fr_gt = get_firing_rate(gt_id, SortingTrue_ns, duration)
    p_chance_hit = get_p_chance_hit(min(fr_gt, fr_s) / 1000, delta_time)
    return get_unit_chance_agreement_score(fr_gt, fr_s, duration, p_chance_hit)


def get_firing_rate(unit_id: int, Sorting, rec_duration: float):
    """get a unit's firing rate

    Returns:
        (float): firing rate in spikes/secs
    """
    n_spikes = Sorting.count_num_spikes_per_unit()[unit_id]
    return n_spikes / rec_duration


# warning: all get_chance_score's underlying functions must be in the same module
chance_df = Parallel(n_jobs=-1)(
    delayed(get_chance_score)(i, j, Sorting_ns, SortingTrue_ns, 600, delta_time)
    for _, i in enumerate(tqdm(scores.columns))
    for j in scores.index
)

  1%|          | 9/1388 [1:01:48<162:45:51, 424.91s/it]