In [1]:
import json
import math
import os
import time
from typing import Any, TypeAlias

import numpy as np
import pandas as pd
import torch
from numpy.typing import NDArray
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt

import burst_detector as bd

In [2]:
# Pipeline configuration. Every later cell reads its settings from `params`,
# and the whole dict is also handed to the `bd.*` helpers.
# NOTE(review): `data_filepath` is an absolute, machine-specific path while
# `KS_folder` is relative to the notebook's working directory — confirm both
# point at the same recording before running on another machine.
params: dict[str, Any] = {
    # --- Input locations and raw recording layout ---
    "data_filepath": "C:/Users/Harris_Lab/Projects/burst-detector/data/rec_bank0_dense_g0/KS2.5/catgt_rec_bank0_dense_g0/rec_bank0_dense_g0_imec0/rec_bank0_dense_g0_tcat.imec0.ap.bin",
    "KS_folder": "../data/rec_bank0_dense_g0/KS2.5/catgt_rec_bank0_dense_g0/rec_bank0_dense_g0_imec0/imec0_ks2/",
    "n_chan": 385,  # number of columns the flat binary file is reshaped into
    "dtype": "int16",  # dtype used to memory-map the binary file
    # --- Cluster filtering and waveform extraction ---
    "min_spikes": 100,  # clusters must have MORE than this many spikes
    "good_lbls": ["good", "mua"],  # accepted values of the `group` label
    "pre_samples": 20,
    "post_samples": 62,
    "max_spikes": 1000,
    # --- Autoencoder similarity ---
    "sim_type": "ae",
    "ae_pre": 10,
    "ae_post": 30,
    "ae_chan": 8,
    "ae_shft": False,
    "ae_epochs": 25,
    "sim_thresh": 0.4,  # similarity above this passes the first stage
    # --- Cross-correlogram metric ---
    "window_size": 0.025,
    "xcorr_bin_width": 0.0005,
    "max_window": 0.25,
    "min_xcorr_rate": 1200,
    "fs": 30000,
}

# Derived from `fs` rather than repeating the literal 30000, so the tolerance
# stays at 10 samples if the sampling rate is ever changed.
params["overlap_tol"] = 10 / params["fs"]

params.update(
    {
        # --- Refractory penalty and final merge decision ---
        "ref_pen_bin_width": 1,
        "max_viol": 0.25,
        "xcorr_coeff": 0.5,  # weight of the cross-correlation term
        "ref_pen_coeff": 1,  # weight of the refractory penalty term
        "final_thresh": 0.6,
        "max_dist": 10,
    }
)

In [3]:
# Ensure the output folders used by the automerge pipeline exist:
# <KS_folder>/automerge and <KS_folder>/automerge/merges.
automerge_dir: str = os.path.join(params["KS_folder"], "automerge")
for out_dir in (automerge_dir, os.path.join(automerge_dir, "merges")):
    # exist_ok=True makes re-running this cell a no-op.
    os.makedirs(out_dir, exist_ok=True)

In [4]:
# Load sorting and recording info.
# Annotations use np.float64 instead of np.float_: the latter was removed in
# NumPy 2.0 and module-level annotations are evaluated at run time.
print("Loading files...")
ks_dir: str = params["KS_folder"]

# Spike times and cluster assignments are flattened to 1-D arrays.
# NOTE(review): the float annotation on `times` is carried over from the
# original cell; the real dtype is whatever spike_times.npy stores.
times: NDArray[np.float64] = np.load(os.path.join(ks_dir, "spike_times.npy")).flatten()
clusters: NDArray[np.int_] = np.load(
    os.path.join(ks_dir, "spike_clusters.npy")
).flatten()

# Tab-separated table of per-cluster labels; later cells read its
# `cluster_id` and `group` columns.
cl_labels: pd.DataFrame = pd.read_csv(
    os.path.join(ks_dir, "cluster_group.tsv"), sep="\t"
)
channel_pos: NDArray[np.float64] = np.load(
    os.path.join(ks_dir, "channel_positions.npy")
)

# Compute useful cluster info.
# Cluster ids index arrays directly, so the id space is 0..max(clusters).
n_clust: int = int(clusters.max()) + 1
counts: dict[int, int] = bd.spikes_per_cluster(clusters)
times_multi: list[NDArray[np.float64]] = bd.find_times_multi(
    times, clusters, np.arange(n_clust)
)

# Load the ephys recording.
# The file is memory-mapped read-only and viewed as (n_samples, n_chan).
# reshape(-1, ...) lets NumPy derive the row count with integer arithmetic
# instead of going through a float division.
rawData = np.memmap(params["data_filepath"], dtype=params["dtype"], mode="r")
data: NDArray[np.int16] = rawData.reshape(-1, params["n_chan"])

Loading files...


In [5]:
# Flag the clusters that are eligible for merging. A cluster is "good" when it
# (a) actually occurs in `clusters`, (b) has more than `min_spikes` spikes, and
# (c) carries one of the accepted labels in cluster_group.tsv.
#
# Labels are looked up through a dict built once, instead of filtering the
# DataFrame per cluster. A cluster with no row in the label table is treated
# as not good; the previous `.loc[...].item()` lookup raised ValueError for it.
labels_by_id: dict[int, Any] = cl_labels.set_index("cluster_id")["group"].to_dict()

cl_good: NDArray[np.bool_] = np.zeros(n_clust, dtype=bool)
unique: NDArray[np.int_] = np.unique(clusters)
present_ids: set[int] = set(unique.tolist())  # O(1) membership checks

for cl in range(n_clust):
    if (
        cl in present_ids
        and counts[cl] > params["min_spikes"]
        and labels_by_id.get(cl) in params["good_lbls"]
    ):
        cl_good[cl] = True

In [6]:
# Load cached per-cluster mean/std waveforms from the KS folder, or compute
# them from the raw recording and cache them if either file cannot be read.
# Annotations use np.float64 because np.float_ was removed in NumPy 2.0.
mean_wf_path: str = os.path.join(params["KS_folder"], "mean_waveforms.npy")
std_wf_path: str = os.path.join(params["KS_folder"], "std_waveforms.npy")

# `spikes` stays None when the cached waveforms are used; otherwise it maps
# cluster id -> extracted snippets for every good cluster.
spikes: dict[int, NDArray[np.int_]] | None = None
mean_wf: NDArray[np.float64]
std_wf: NDArray[np.float64]
try:
    mean_wf = np.load(mean_wf_path)
    std_wf = np.load(std_wf_path)
except OSError:
    print(
        "mean_waveforms.npy doesn't exist, calculating mean waveforms on the fly..."
    )
    # Shape: (cluster, channel, sample). Rows of clusters that are not good
    # are left as zeros.
    mean_wf = np.zeros(
        (n_clust, params["n_chan"], params["pre_samples"] + params["post_samples"])
    )
    std_wf = np.zeros_like(mean_wf)
    spikes = {}
    for i in range(n_clust):
        if not cl_good[i]:
            continue
        spikes[i] = bd.extract_spikes(
            data,
            times_multi,
            i,
            n_chan=params["n_chan"],
            pre_samples=params["pre_samples"],
            post_samples=params["post_samples"],
            max_spikes=params["max_spikes"],
        )
        # NaN-aware statistics over the spike axis (axis 0).
        mean_wf[i, :, :] = np.nanmean(spikes[i], axis=0)
        std_wf[i, :, :] = np.nanstd(spikes[i], axis=0)
    # Cache the results so the next run takes the fast path above.
    np.save(mean_wf_path, mean_wf)
    np.save(std_wf_path, std_wf)

# Peak channel of each cluster = channel with the largest peak-to-peak
# amplitude of the mean waveform.
peak_chans: NDArray[np.int_] = np.argmax(np.max(mean_wf, 2) - np.min(mean_wf, 2), 1)

In [7]:
print("Extracting spike snippets to train autoencoder...")

# Folder handed to the extractor via `ext_params["spk_fld"]`.
spk_fld: str = os.path.join(params["KS_folder"], "automerge", "spikes")

# Cluster information bundle expected by bd.generate_train_data.
ci: dict[str, Any] = dict(
    times=times,
    times_multi=times_multi,
    clusters=clusters,
    counts=counts,
    labels=cl_labels,
    mean_wf=mean_wf,
)

# Snippet-extraction settings, taken from the autoencoder ("ae_*") parameters.
ext_params: dict[str, Any] = dict(
    spk_fld=spk_fld,
    pre_samples=params["ae_pre"],
    post_samples=params["ae_post"],
    num_chan=params["ae_chan"],
    for_shft=params["ae_shft"],
)

# Returns the snippets and, alongside them, an array of cluster ids.
spk_snips: torch.Tensor
cl_ids: NDArray[np.int_]
spk_snips, cl_ids = bd.generate_train_data(data, ci, channel_pos, ext_params, params)

Extracting spike snippets to train autoencoder...


In [9]:
# Training is disabled in this run; a previously trained autoencoder is loaded
# from disk instead. The original training code is kept below for reference —
# it writes the weights to <KS_folder>/automerge/ae.pt.
#
# print("Training autoencoder...")
# net, spk_data = bd.train_ae(
#     spk_snips,
#     cl_ids,
#     do_shft=params["ae_shft"],
#     num_epochs=params["ae_epochs"],
# )
# torch.save(
#     net.state_dict(),
#     os.path.join(params["KS_folder"], "automerge", "ae.pt"),
# )
# print(
#     "Autoencoder saved in "
#     + str(os.path.join(params["KS_folder"], "automerge", "ae.pt"))
# )

# NOTE(review): absolute, machine-specific path; it looks like the same file
# the training code above would write under KS_folder — confirm before moving
# the notebook to another machine.
params["model_path"] = r"C:\Users\Harris_Lab\Projects\burst-detector\data\rec_bank0_dense_g0\KS2.5\catgt_rec_bank0_dense_g0\rec_bank0_dense_g0_imec0\imec0_ks2\automerge\ae.pt"

# Use the GPU when available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
net: bd.CN_AE = bd.CN_AE().to(device)

# map_location remaps the stored tensors onto `device`, so a checkpoint that
# was saved on a GPU can still be loaded on a CPU-only machine.
# NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
net.load_state_dict(torch.load(params["model_path"], map_location=device))
net.eval()  # inference mode: no training happens in this notebook

# Wrap the extracted snippets and their cluster ids for the similarity step.
spk_data: bd.SpikeDataset = bd.SpikeDataset(spk_snips, cl_ids)

In [10]:
# Calculate similarity using distances in the autoencoder latent space.
# Annotations use np.float64 because np.float_ was removed in NumPy 2.0.
spk_lat_peak: NDArray[np.float64]
lat_mean: NDArray[np.float64]
spk_lab: NDArray[np.int_]
sim, spk_lat_peak, lat_mean, spk_lab = bd.calc_ae_sim(
    mean_wf, net, peak_chans, spk_data, cl_good, do_shft=params["ae_shft"]
)

# Boolean mask of the entries of `sim` that exceed the similarity threshold;
# later cells pass it to the cross-correlation and refractory-penalty steps.
pass_ms = sim > params["sim_thresh"]

Calculating latent features...
Batch 1022/1022
LOSS: tensor(20.2725, device='cuda:0')


In [11]:
# Calculate a significance metric for cross-correlograms.
# The annotation now names `xgrams`, the variable that is actually assigned
# (it previously declared an unused `x_grams`), and uses np.float64 because
# np.float_ was removed in NumPy 2.0.
print("Calculating cross-correlation metric...")
xcorr_sig: NDArray[np.float64]
xgrams: NDArray
shfl_xgrams: NDArray
xcorr_sig, xgrams, shfl_xgrams = bd.calc_xcorr_metric(
    times_multi, n_clust, pass_ms, params
)

Calculating cross-correlation metric...


In [12]:
# Compute the penalty term that is subtracted from the merge score in the
# next step (presumably a refractory-period penalty, going by the helper's
# name — confirm against bd.calc_ref_p).
# Annotations use np.float64 because np.float_ was removed in NumPy 2.0.
ref_pen: NDArray[np.float64]
ref_per: NDArray[np.float64]
ref_pen, ref_per = bd.calc_ref_p(
    times_multi, clusters, n_clust, pass_ms, xcorr_sig, params
)

In [14]:
# Combine the three per-pair scores into the final merge metric:
#   sim + xcorr_coeff * xcorr_sig - ref_pen_coeff * ref_pen, clamped at 0.
# Only the upper triangle (c2 >= c1) is evaluated; each value is mirrored so
# the resulting matrix is symmetric.
# The annotation uses np.float64 because np.float_ was removed in NumPy 2.0.
final_metric: NDArray[np.float64] = np.zeros_like(sim)
for c1 in range(n_clust):
    for c2 in range(c1, n_clust):
        met: float = (
            sim[c1, c2]
            + params["xcorr_coeff"] * xcorr_sig[c1, c2]
            - params["ref_pen_coeff"] * ref_pen[c1, c2]
        )

        # Negative scores carry no merge evidence; clamp once, store twice.
        clamped = max(met, 0)
        final_metric[c1, c2] = clamped
        final_metric[c2, c1] = clamped

In [15]:
# Run the merge step on the final metric. Two mappings come back:
#   old2new — annotated as original cluster id -> merged cluster id
#   new2old — annotated as merged cluster id -> list of original cluster ids
old2new: dict[int, int]
new2old: dict[int, list[int]]
(old2new, new2old) = bd.merge_clusters(clusters, counts, mean_wf, final_metric, params)

In [16]:
# Display the merges: each new cluster id with the original ids it combines.
new2old

{457: [370, 367],
 458: [105, 389],
 459: [436, 350],
 460: [432, 337],
 461: [392, 438],
 462: [434, 451],
 463: [398, 165],
 464: [139, 148],
 465: [395, 440],
 466: [404, 223],
 467: [160, 60],
 468: [439, 134],
 469: [119, 120]}