# Kilosort 4

Run Kilosort 4 spike sorting on SpikeGLX Neuropixels recordings from Google Colab, with the data accessed through Google Drive.

In [None]:
# Mount Google Drive so the recordings and helper folders under /content/drive are reachable.
from google.colab import drive
drive.mount(mountpoint='/content/drive')

Mounted at /content/drive


***
## Import and Install Libraries

In [None]:
import os
import re
import sys
import warnings
import argparse
import numpy as np
import pandas as pd
from pathlib import Path
import matplotlib.pyplot as plt
from collections import defaultdict
from matplotlib import gridspec, rcParams
warnings.simplefilter("ignore")
# install kilosort
!pip install kilosort
from kilosort import run_kilosort

sys.path.append('/content/drive/MyDrive/Columbia/Salzman/Projects/Spike-Sorting/Kilosort4-SGLX')



***
## Find Recording Folders

In [None]:
def read_recording_folder(root, monkey, date, session_num=0):
	"""Locate a SpikeGLX run folder and the per-probe ``imec<N>`` folders inside it.

	Parameters
	----------
	root : str
		Directory that contains SpikeGLX run folders.
	monkey, date : str
		Used to build the run folder name ``{monkey}_{date}_g{session_num}``.
	session_num : int, optional
		The ``g`` index in the run folder name. Defaults to 0.

	Returns
	-------
	sglx_folder : str
		Path of the run folder.
	imec_folder_dict : dict[str, str]
		Maps ``'imec<N>'`` to the path of the (sub)folder whose name ends in that
		suffix, ordered by N numerically.

	Raises
	------
	SystemExit
		If the run folder does not exist or contains no ``imec<N>`` folder.
	"""
	sglx_folder = os.path.join(root, f'{monkey}_{date}_g{session_num}')
	print(f'SpikeGLX folder: {sglx_folder}')
	if not os.path.exists(sglx_folder):
		sys.exit(f'{sglx_folder} does not exist')

	# Collect every directory (at any depth) whose name ends in imec<N>.
	# Raw string avoids the invalid '\d' escape; \d+ also matches imec10 and above.
	imec_pattern = re.compile(r'imec\d+$')
	imec_folder_dict = {}
	# Distinct loop names: reusing `root` here used to clobber the parameter.
	for parent, dirs, _files in os.walk(sglx_folder):
		for folder in dirs:
			match = imec_pattern.search(folder)
			if match:
				imec_folder_dict[match.group()] = os.path.join(parent, folder)

	if not imec_folder_dict:
		# Show what the run folder does contain, to help diagnose naming problems.
		try:
			print(os.listdir(sglx_folder))
		except OSError:
			print(f'{sglx_folder} missing')
		sys.exit(f'No imec folders found in {sglx_folder}')

	# Order probes numerically (imec2 before imec10), not lexicographically by path.
	imec_folder_dict = dict(sorted(imec_folder_dict.items(), key=lambda item: int(item[0][len('imec'):])))
	return sglx_folder, imec_folder_dict

# Recording to process: <root>/<monkey>_<date>_g<session_num>.
root = '/content/drive/Othercomputers/Ephys/E:'  # Drive path mirroring the E: drive
# root = 'E:'  # alternative root when running on the machine that holds E:
monkey, date, session_num = 'gandalf', '20240126', 0
sglx_folder, imec_folder_dict = read_recording_folder(root, monkey, date, session_num)

imec_folder_dict

SpikeGLX folder: /content/drive/Othercomputers/Ephys/E:/gandalf_20240126_g0


{'imec0': '/content/drive/Othercomputers/Ephys/E:/gandalf_20240126_g0/gandalf_20240126_g0_imec0',
 'imec1': '/content/drive/Othercomputers/Ephys/E:/gandalf_20240126_g0/gandalf_20240126_g0_imec1'}

***
## Run CatGT

In [None]:
import time
import subprocess

# Run CatGT on the recording's AP streams (-ap flag) before sorting.
run_catgt = True
cat_prb_fld = '0:3'  # probe range passed to CatGT's -prb option
ks_path = '/content/drive/MyDrive/Columbia/Salzman/Projects/Spike-Sorting/Kilosort4-SGLX'
catgt_path = os.path.join(ks_path, 'CatGT-win')

if run_catgt:
  sys.path.append(catgt_path)
  if not os.path.exists(catgt_path):
    # Without the CatGT folder there is nothing to run (and no working dir to run it in).
    print(f'CatGT Missing: {catgt_path}')
  else:
    print(f"Running CatGT on {sglx_folder}")
    # time how long it takes
    start_time = time.time()
    catgt_command = f"runit.bat -dir={root} -run={monkey}_{date} -prb_fld -g={session_num} -t=0 -ni -prb={cat_prb_fld} -ap"
    print(f"  Bash command: {catgt_command}")
    # runit.bat is a Windows batch wrapper. Check the exit status so a failed launch
    # (e.g. the .bat cannot run on a non-Windows runtime such as Colab) is reported
    # instead of being printed as a successful run.
    catgt_result = subprocess.run(catgt_command, cwd=catgt_path, shell=True)
    elapsed = time.time() - start_time
    if catgt_result.returncode == 0:
      print(f"  CatGT complete. Time elapsed: {elapsed:.2f} seconds")
    else:
      print(f"  CatGT FAILED (exit code {catgt_result.returncode}) after {elapsed:.2f} seconds")


Running CatGT on /content/drive/Othercomputers/Ephys/E:/gandalf_20240126_g0
  Bash command: runit.bat -dir=/content/drive/Othercomputers/Ephys/E: -run=gandalf_20240126 -prb_fld -g=0 -t=0 -ni -prb=0:3 -ap
  CatGT complete. Time elapsed: 0.02 seconds


***
## Run Kilosort

In [None]:
def run_kilosort4(settings):
	"""Run Kilosort 4 on one probe's binary file, then plot a summary of the results.

	Parameters
	----------
	settings : dict
		Must contain ``'filename'`` (path to the ``.ap.bin`` file), ``'results_dir'``
		(output folder), ``'do_CAR'`` (bool) and ``'probe_name'`` (channel-map path).
		``'n_channels'`` (number of channels in the binary) is optional; default 385.

	Returns
	-------
	tuple
		Whatever ``kilosort.run_kilosort`` returns (``ops, st, clu, ...``). It is kept
		as one tuple rather than unpacked into a fixed number of names, so a change in
		the library's return signature does not break this function.
	"""
	# Honour the caller's channel count instead of always assuming 385.
	kilosort_settings = {'n_chan_bin': settings.get('n_channels', 385)}
	results = run_kilosort(
		settings=kilosort_settings,
		probe_name=settings['probe_name'],
		filename=settings['filename'],
		results_dir=settings['results_dir'],
		do_CAR=settings['do_CAR'],
	)

	# Plotting is diagnostic only and runs after sorting has finished. A plotting
	# failure must not abort the caller's loop over the remaining probes. That
	# includes plot_results not being defined yet, because its cell comes later in
	# the notebook.
	try:
		plot_results(settings)
	except Exception as err:
		print(f"Plotting failed for {settings['results_dir']}: {err!r}")
	return results

# Sort every probe found above; each probe's results go into a *_ks4 folder next to its binary.
map_file = '/content/drive/MyDrive/Columbia/Salzman/Projects/Spike-Sorting/Kilosort4-SGLX/configFiles/neuropixels_NHP_channel_map_linear_v1.mat'
# map_file = "C:/Users/Milner/OneDrive/Desktop/Kilosort4-SGLX/configFiles/neuropixels_NHP_channel_map_linear_v1.mat"
n_channels = 385  # total channels stored in each .ap.bin file

for imec_num, imec_folder in imec_folder_dict.items():
	print(f'{imec_num} folder: {imec_folder}')
	data_dir = imec_folder
	# Shared prefix of the binary file and of the results folder for this probe.
	run_stem = f'{monkey}_{date}_g{session_num}_t0.{imec_num}'
	save_path = os.path.join(imec_folder, run_stem + '_ks4')
	bin_file = os.path.join(imec_folder, run_stem + '.ap.bin')
	if os.path.exists(bin_file):
		settings = {
			'filename': bin_file,
			'results_dir': save_path,
			'do_CAR': True,
			'n_channels': n_channels,
			'probe_name': map_file,
		}
		run_kilosort4(settings)
	else:
		print(f'{bin_file} does not exist')

imec0 folder: /content/drive/Othercomputers/Ephys/E:/gandalf_20240126_g0/gandalf_20240126_g0_imec0
Interpreting binary file as default dtype='int16'. If data was saved in a different format, specify `data_dtype`.
Using GPU for PyTorch computations. Specify `device` to change this.
using probe neuropixels_NHP_channel_map_linear_v1.mat
Preprocessing filters computed in  517.05s; total  517.92s

computing drift
Re-computing universal templates from data.


 97%|█████████▋| 3962/4078 [1:36:14<02:53,  1.50s/it]

***
## Prepare spike times for TPrime (convert to seconds)

In [None]:
from pathlib import Path

# Convert each probe's Kilosort spike times (sample indices) to seconds and write one
# value per line to spike_times_sec.txt, presumably as input for TPrime (see heading).
run_tprime = True
sglx_tools = os.path.join(ks_path, 'SpikeGLX_Datafile_Tools/Python/DemoReadSGLXData')
sys.path.append(sglx_tools)
from readSGLX import readMeta

if run_tprime:
  for imec_num, imec_folder in imec_folder_dict.items():
    run_stem = f'{monkey}_{date}_g{session_num}_t0.{imec_num}'
    bin_file = os.path.join(imec_folder, f'{run_stem}.ap.bin')
    # Same results folder the Kilosort loop uses as results_dir for this probe.
    dest_folder_path = os.path.join(imec_folder, f'{run_stem}_ks4')
    spike_times_file = os.path.join(dest_folder_path, 'spike_times.npy')
    if not os.path.exists(spike_times_file):
      print(f"\t{spike_times_file} does not exist; skipping {imec_num}")
      continue
    # readMeta is given the .bin path as a pathlib.Path; it presumably reads the
    # sibling .meta file (SpikeGLX_Datafile_Tools) — confirm if file layout changes.
    meta = readMeta(Path(bin_file))
    print("\tConverting spike times to seconds...")
    sample_rate = float(meta['imSampRate'])
    print(f"\tSampling rate: {sample_rate:.2f}")
    # ravel() guards against an (n, 1) array; values keep full float precision.
    spike_times_sec = np.load(spike_times_file).ravel() / sample_rate
    spike_times_sec_file = os.path.join(dest_folder_path, "spike_times_sec.txt")
    print(f"\t\tGenerating {spike_times_sec_file}")
    with open(spike_times_sec_file, 'w') as f:
      f.writelines(f"{spike_time}\n" for spike_time in spike_times_sec)
    print("\t\tDone.")

In [None]:
def _load_ks4_results(results_dir):
	"""Load the Kilosort 4 output files used by plot_results from ``results_dir``.

	Returns a dict with the raw arrays plus a few derived per-template/per-cluster
	values (see the keys assembled at the end).
	"""
	ops = np.load(results_dir / 'ops.npy', allow_pickle=True).item()
	camps = pd.read_csv(results_dir / 'cluster_Amplitude.tsv', sep='\t')['Amplitude'].values
	contam_pct = pd.read_csv(results_dir / 'cluster_ContamPct.tsv', sep='\t')['ContamPct'].values
	chan_map = np.load(results_dir / 'channel_map.npy')
	templates = np.load(results_dir / 'templates.npy')
	st = np.load(results_dir / 'spike_times.npy').ravel()
	clu = np.load(results_dir / 'spike_clusters.npy').ravel()
	# Sampling rate from ops instead of a hard-coded 30 kHz (kept only as a fallback).
	fs = float(ops.get('fs', 30000))
	# Per template: index along its last (channel) axis with the most energy. This
	# index addresses templates[..., i]; chan_map[i] is the corresponding channel id.
	best_idx = (templates**2).sum(axis=1).argmax(axis=-1)
	# bincount keeps entry i == cluster i (zero-spike clusters included), so the rates
	# stay aligned with the per-cluster TSV values; np.unique counts would not.
	spike_counts = np.bincount(clu, minlength=len(camps))
	return {
		'ops': ops,
		'fs': fs,
		'camps': camps,
		'contam_pct': contam_pct,
		'chan_map': chan_map,
		'templates': templates,
		'st': st,
		'clu': clu,
		'best_idx': best_idx,
		'chan_best': chan_map[best_idx],
		'firing_rates': spike_counts * fs / st.max(),
	}


def _plot_summary(res, good_pct=10.):
	"""Draw the 3x3 summary figure: drift, early raster, histograms, rate vs. amplitude.

	Units with contamination below ``good_pct`` percent are labelled "good".
	"""
	ops, fs = res['ops'], res['fs']
	st, clu = res['st'], res['clu']
	camps, contam_pct = res['camps'], res['contam_pct']
	firing_rates = res['firing_rates']
	gray = .5 * np.ones(3)

	fig = plt.figure(figsize=(10,10), dpi=100)
	grid = gridspec.GridSpec(3, 3, figure=fig, hspace=0.5, wspace=0.5)

	# Drift per batch. Batch length in seconds = batch_size / fs; 60000 samples is
	# the fallback, matching the 2 s per batch previously assumed at 30 kHz.
	batch_s = ops.get('batch_size', 60000) / fs
	ax = fig.add_subplot(grid[0,0])
	ax.plot(np.arange(0, ops['Nbatches']) * batch_s, ops['dshift'])
	ax.set_xlabel('time (sec.)')
	ax.set_ylabel('drift (um)')

	# Raster of the first few seconds. A boolean mask also works when the recording is
	# shorter than that (indexing "the first spike after 5 s" raised IndexError).
	raster_s = 5
	early = st <= fs * raster_s
	ax = fig.add_subplot(grid[0,1:])
	ax.scatter(st[early] / fs, res['chan_best'][clu[early]], s=0.5, color='k', alpha=0.25)
	ax.set_xlim([0, raster_s])
	ax.set_ylim([res['chan_map'].max(), 0])
	ax.set_xlabel('time (sec.)')
	ax.set_ylabel('channel')
	ax.set_title('spikes from units')

	ax = fig.add_subplot(grid[1,0])
	ax.hist(firing_rates, 20, color=gray)
	ax.set_xlabel('firing rate (Hz)')
	ax.set_ylabel('# of units')

	ax = fig.add_subplot(grid[1,1])
	ax.hist(camps, 20, color=gray)
	ax.set_xlabel('amplitude')
	ax.set_ylabel('# of units')

	ax = fig.add_subplot(grid[1,2])
	nb = ax.hist(np.minimum(100, contam_pct), np.arange(0,105,5), color=gray)
	ax.plot([good_pct, good_pct], [0, nb[0].max()], 'k--')
	ax.set_xlabel('% contamination')
	ax.set_ylabel('# of units')
	ax.set_title(f'< {good_pct:g}% = good units')

	is_good = contam_pct < good_pct
	for k in range(2):
		ax = fig.add_subplot(grid[2,k])
		ax.scatter(firing_rates[~is_good], camps[~is_good], s=3, color='r', label='mua', alpha=0.25)
		ax.scatter(firing_rates[is_good], camps[is_good], s=3, color='b', label='good', alpha=0.25)
		ax.set_ylabel('amplitude (a.u.)')
		ax.set_xlabel('firing rate (Hz)')
		ax.legend()
		if k == 1:
			ax.set_xscale('log')
			ax.set_yscale('log')
			ax.set_title('loglog')
	plt.show()


def _plot_unit_templates(res, good_pct=10.):
	"""Plot up to 40 randomly chosen templates from each of the good and mua groups.

	Uses the same ``good_pct`` threshold as the summary figure, so both agree on
	which units are "good". Each panel's title is the unit's spike count.
	"""
	templates, clu = res['templates'], res['clu']
	best_idx, contam_pct = res['best_idx'], res['contam_pct']
	probe = res['ops']['probe']
	# x and y position of probe sites; indexed with the same template-channel window
	# as the waveform (assumes probe sites are ordered like the template channel axis).
	xc, yc = probe['xc'], probe['yc']
	nc = 16  # number of channels to show around the best channel
	n_rows, n_cols = 2, 20
	amp = 4  # vertical gain applied to the waveform traces
	rng = np.random.default_rng()

	groups = [
		('good', np.nonzero(contam_pct < good_pct)[0]),
		('mua', np.nonzero(contam_pct >= good_pct)[0]),
	]
	for label, units in groups:
		print(f'~~~~~~~~~~~~~~ {label} units ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~')
		print('title = number of spikes from each unit')
		if len(units) == 0:
			print('no units in this group')
			continue
		# Draw without replacement so a unit never appears twice.
		shown = rng.choice(units, size=min(n_rows * n_cols, len(units)), replace=False)
		fig = plt.figure(figsize=(12,3), dpi=150)
		grid = gridspec.GridSpec(n_rows, n_cols, figure=fig, hspace=0.25, wspace=0.5)
		for k, wi in enumerate(shown):
			wv = templates[wi]
			# Window by template-channel index (not the mapped channel id), so it lines
			# up with wv's channel axis and with xc/yc.
			cb = best_idx[wi]
			nsp = (clu == wi).sum()
			ax = fig.add_subplot(grid[k // n_cols, k % n_cols])
			ic0 = max(0, cb - nc//2)
			ic1 = min(wv.shape[-1], cb + nc//2)
			wv = wv[:, ic0:ic1]
			# Time axis centred on 0, scaled to span about 20 units of probe x-coordinate.
			t = np.arange(-wv.shape[0]//2, wv.shape[0]//2, 1, 'float32')
			t /= wv.shape[0] / 20
			for ii, (xi, yi) in enumerate(zip(xc[ic0:ic1], yc[ic0:ic1])):
				ax.plot(xi + t, yi + wv[:, ii]*amp, lw=0.5, color='k')
			ax.set_title(f'{nsp}', fontsize='small')
			ax.axis('off')
		plt.show()


def plot_results(settings):
	"""Plot summary diagnostics and example templates for one Kilosort 4 run.

	Parameters
	----------
	settings : dict
		Run settings. Outputs are read from ``settings['results_dir']``. If that key is
		absent, they are read from ``<settings['data_dir']>/kilosort4``, which was the
		previous location.
	"""
	if 'results_dir' in settings:
		results_dir = Path(settings['results_dir'])
	else:
		results_dir = Path(settings['data_dir']).joinpath('kilosort4')
	res = _load_ks4_results(results_dir)

	rcParams['axes.spines.top'] = False
	rcParams['axes.spines.right'] = False
	_plot_summary(res)
	_plot_unit_templates(res)