#### Importing modules

In [1]:
%load_ext autoreload
%autoreload 2

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from statistics import mode
import chiCa
from spks import *

%matplotlib widget

#### Workaround for having modified the kilosort clusters with Phy

clu = Clusters(folder = kilosort_path, get_waveforms=False)
raw_data = RawRecording([fast_binary_path], return_preprocessed = False)
clu.extract_waveforms(data = raw_data, save_folder_path = clu.folder) # no filtering if using filtered binary

clu2 = Clusters(folder = kilosort_path, get_waveforms=True, load_template_features=True, get_metrics=True)


In [2]:
binary_path = Path('/home/data/GRB006/20240429_174359/ephys_g0/ephys_g0_imec0/ephys_g0_t0.imec0.ap.bin')
fast_binary_path = Path('/scratch/GRB/temp_bin/ephys_g0_t0.imec0.ap.bin')
kilosort_path = Path('/home/data/GRB006/20240429_174359/kilosort2.5/imec0/')

#### Loading sync data

In [3]:
from spks.sync import load_ni_sync_data,interp1d
sessionpath = Path('/home/data/GRB006/20240429_174359/')
sync_port = 0  # this is where the SMA of the probe is connected

(nionsets, nioffsets), (nisync, nimeta), (apsyncdata) = load_ni_sync_data(sessionpath=sessionpath)
aponsets = apsyncdata[0]['file0_sync_onsets'][6]  # this should be the same for you, its where the sync is on the probe

# Map NI samples onto the AP (probe) clock using the shared sync pulse.
# Built once and reused: the mapping does not depend on the channel being converted.
ni_to_ap = interp1d(nionsets[sync_port], aponsets, fill_value='extrapolate')

# Dictionary with the digital events that were connected to the breakout box,
# re-expressed in AP samples.
# NOTE(review): casting to uint64 wraps any negative extrapolated value — confirm no
# event precedes the first sync pulse.
corrected_onsets = {k: ni_to_ap(onsets).astype('uint64') for k, onsets in nionsets.items()}

# if you need analog channels those are in "nisync";
# nitime gives the AP-clock sample of every row of nisync
nitime = ni_to_ap(np.arange(len(nisync)))

# everything is in samples, use this sampling rate
srate = apsyncdata[0]['sampling_rate']

# corrected_onsets[1] are the frame samples, [2] are the trial start samples
frame_rate = mode(1/(np.diff(corrected_onsets[1])/srate))
trial_start_times = corrected_onsets[2][:-1]/srate  # seconds; the last trial-start event is dropped


#### Loading behavioral data

In [4]:
behavior_data = chiCa.load_trialdata('/home/data/GRB006/20240429_174359/chipmunk/GRB006_20240429_174359_chipmunk_DemonstratorAudiTask.mat')

# stim_period_times = np.zeros(shape=(len(behavior_data),1))
# relative_stim_period_times = np.zeros(shape=(len(behavior_data),1))

# for trial, _ in behavior_data.PlayStimulus.items(): #unpacking index as trial and omitting the data
#     stim_period_times[trial] = behavior_data.trial_start_time[trial] + behavior_data.PlayStimulus[trial][0]
#     relative_stim_period_times[trial] = behavior_data.PlayStimulus[trial][0]


# stim_period_times = stim_period_times.flatten()
# relative_stim_period_times = relative_stim_period_times.flatten()



#### Get stimulus onsets aligned to nidaq time

Note: the result is not exactly aligned to hybrid time, because each onset is the trial start (measured on the NIDAQ clock) plus the stimulus timestamps (measured on the Bpod clock).

In [5]:
# Stimulus event times of every trial: trial start (seconds, from the sync-corrected
# onsets) plus the per-trial stimulus timestamps stored in the behavior file.
stim_onsets = trial_start_times + behavior_data.stimulus_event_timestamps

# First stimulus event of each trial. A NaN first entry is kept as NaN here,
# so no special-casing is needed before the filter below.
first_stim_onsets = np.array([timestamps[0] for timestamps in stim_onsets], dtype=float)

# Drop trials whose first timestamp is NaN.
# NOTE(review): after this filter, position i no longer corresponds to trial i.
first_stim_onsets = first_stim_onsets[~np.isnan(first_stim_onsets)]
# first_stim_onsets.shape

Verifying that it worked...

In [None]:
# Sanity check: draw the stimulus events of the first few trials, one row per trial.
trials_to_show = stim_onsets[:5]

# squeeze=False keeps `axs` two-dimensional even when a single trial is shown,
# so the indexing below works for any number of rows.
fig, axs = plt.subplots(len(trials_to_show), 1, figsize=(5,5), squeeze=False)
for trial_ax, trial_timestamps in zip(axs[:, 0], trials_to_show):
    trial_ax.vlines(trial_timestamps, 0, 1)
    trial_ax.get_yaxis().set_visible(False)
axs[-1, 0].set_xlabel('Stimulus timestamps (s)')
fig.tight_layout()

#### Temp ignore

In [None]:
# Binarise the first column of nisync with a fixed threshold.
# NOTE(review): presumably column 0 is the stimulus analog channel and 10000 is in raw
# acquisition units — confirm against the recording setup.
stim_analog = nisync[:, 0]
threshold = 10000
stim_digital = (stim_analog > threshold).astype(int)

# nitime is in AP samples, so dividing by srate gives seconds.
t = nitime/srate
idx = t<20 # keep only the first 20 seconds of the session
fig, axs = plt.subplots(1,2, figsize=(14,8))
# Left panel: thresholded (digital) signal, zoomed on a one-second window.
axs[0].plot(t[idx],stim_digital[idx],'k')
axs[0].set_xlim([15.1, 16.1])

# Right panel: raw analog trace over the same window, for comparison.
axs[1].plot(t[idx],nisync[idx,0],'k')
axs[1].set_xlim([15.1, 16.1])


In [None]:
def get_boxcar_signal(stim_digital, min_event_gap=30):  # gap in samples
    """Merge closely spaced pulses of a binary signal into continuous boxcars.

    Parameters
    ----------
    stim_digital : array-like of 0/1 (or bool)
        Binarised signal, one value per sample.
    min_event_gap : int, optional
        Two consecutive pulses are merged when the number of samples between the
        end of the first and the start of the second is not larger than this value.

    Returns
    -------
    numpy.ndarray
        Array with the same shape and dtype as ``stim_digital`` that is 1 from the
        start of each consolidated event up to (not including) its end, 0 elsewhere.
        A signal with no pulses (or an empty signal) yields all zeros.
    """
    stim_digital = np.asarray(stim_digital)
    boxcar_signal = np.zeros_like(stim_digital)

    # Work on a signed copy: np.diff on booleans is an XOR and on unsigned integers
    # wraps around, so neither would produce the -1 that marks a falling edge.
    transitions = np.diff(stim_digital.astype(np.int64))
    start_indices = np.flatnonzero(transitions == 1) + 1
    end_indices = np.flatnonzero(transitions == -1) + 1

    # A pulse already high at the first sample / still high at the last sample has
    # no rising / falling edge inside the array, so add the missing boundary.
    if stim_digital.size and stim_digital[0] == 1:
        start_indices = np.insert(start_indices, 0, 0)
    if stim_digital.size and stim_digital[-1] == 1:
        end_indices = np.append(end_indices, len(stim_digital))

    # Nothing to consolidate.
    if start_indices.size == 0:
        return boxcar_signal

    # A new consolidated event begins wherever the silence since the previous
    # pulse exceeds min_event_gap; the previous event ends at that same boundary.
    is_new_event = (start_indices[1:] - end_indices[:-1]) > min_event_gap
    consolidated_starts = np.concatenate((start_indices[:1], start_indices[1:][is_new_event]))
    consolidated_ends = np.concatenate((end_indices[:-1][is_new_event], end_indices[-1:]))

    for start, end in zip(consolidated_starts, consolidated_ends):
        boxcar_signal[start:end] = 1

    return boxcar_signal

# Generate the boxcar signal
boxcar_signal = get_boxcar_signal(stim_digital)

# Plot the boxcar and the analog signals side by side
t = nitime/srate
idx = t < 300  # First 300 seconds of the session
fig, axs = plt.subplots(1, 2, figsize=(14, 8))

# Consolidated (boxcar) digital signal in the specified range
axs[0].plot(t[idx], boxcar_signal[idx], 'k')
# axs[0].set_xlim([15, 16.2])
axs[0].set_ylim([-0.5, 1.5])
axs[0].set_title('Boxcar Digital Signal')  # this panel shows the consolidated signal, not the raw thresholded one
axs[0].set_xlabel('Time (s)')
axs[0].set_ylabel('Signal')

# Analog signal in the specified range
axs[1].plot(t[idx], nisync[idx, 0], 'k')
# axs[1].set_xlim([15, 16.2])
axs[1].set_title('Analog Signal')
axs[1].set_xlabel('Time (s)')
axs[1].set_ylabel('Signal')

plt.show()

In [None]:
t = nitime/srate
t = ~np.isnan(t)
boxcar_signal[t].shape[0]

In [None]:
stim_event_idx = np.where(boxcar_signal == 1)[0]

In [None]:
stim_event_idx

In [None]:
# Trial boundaries in AP samples. Defined here so this cell does not depend on a
# variable that is only created further down the notebook.
trial_start_samples = corrected_onsets[2][:-1]

# stim_event_idx indexes rows of nisync (NI samples), while the trial boundaries are
# on the AP clock; nitime holds the AP-clock sample of every nisync row, so convert
# before comparing.
stim_event_samples = nitime[stim_event_idx]

trial = 0
start_idx = trial_start_samples[trial]
end_idx = trial_start_samples[trial + 1]
# NI-sample indices of the stimulus events that fall inside the selected trial
stim_event_idx[(stim_event_samples >= start_idx) & (stim_event_samples < end_idx)]

In [None]:
stim_event_idx

In [None]:
# Print the first few values of trial_start_samples and stim_event_idx
print("First 10 trial_start_samples:", trial_start_samples[:10])
print("First 10 stim_event_idx:", stim_event_idx[:10])

#### Trying to align stuff

In [None]:
video_alignment_data = chiCa.align_behavioral_video('/home/data/GRB006/20240429_174359/chipmunk/GRB006_20240429_174359_chipmunk_DemonstratorAudiTask_BackStereoView_00000000.camlog') #this has the trial start frames

In [None]:
# so I have trial start frames for behavior in trial_start_frames
# trial_start_frames = video_alignment_data['trial_starts']

# I have trial start frames for npx in corrected_onsets[2]
trial_start_samples = corrected_onsets[2][:-1]

In [None]:
sc = np.load('/home/data/GRB006/20240429_174359/kilosort2.5/imec0/spike_clusters.npy') # cluster that each spike was assigned to
st = np.load('/home/data/GRB006/20240429_174359/kilosort2.5/imec0/spike_times.npy') # time of every recorded spike (presumably in samples, since it is divided by srate elsewhere in this notebook)

# One cluster label per spike, so both arrays must line up.
assert sc.shape == st.shape

In [None]:
import labcams

logdata, comments = labcams.parse_cam_log('/home/data/GRB006/20240429_174359/chipmunk/GRB006_20240429_174359_chipmunk_DemonstratorAudiTask_BackStereoView_00000000.camlog')

# (logdata.timestamp.values[-1]-logdata.timestamp.values[0])/60

bpodtime = interp1d(nionsets[sync_port],corrected_onsets[2],fill_value='extrapolate')(np.arange(len(nisync)))
bpodtime

#### PSTHs?

In [None]:
# sc = np.load('/home/data/GRB006/20240429_174359/kilosort2.5/imec0/spike_clusters.npy')
# st = np.load('/home/data/GRB006/20240429_174359/kilosort2.5/imec0/spike_times.npy')

# if sc.shape != st.shape:
#     raise ValueError(f"The shapes of sc {sc.shape} and st {st.shape} are not the same")

# df = pd.DataFrame({'spike_samples': st, 'spike_clusters': sc})
# df['spike_times'] = df['spike_samples'] / srate
# cluster_spike_times = df.groupby('spike_clusters').agg(spike_samples=('spike_samples', list), spike_times=('spike_times', list)).reset_index()

# cluster_spike_times

In [None]:
# sp_samples = np.array(cluster_spike_times.spike_samples[7])
# start = trial_start_samples[0]
# end = trial_start_samples[1]

# filtered_sp_samples = sp_samples[(sp_samples >= start) & (sp_samples <= end)]

# plt.figure()
# plt.eventplot(filtered_sp_samples, lineoffsets=0.5, colors='black')


In [None]:
# binary_path = Path('/home/data/GRB006/20240429_174359/ephys_g0/ephys_g0_imec0/ephys_g0_t0.imec0.ap.bin')
# fast_binary_path = Path('/scratch/GRB/temp_bin/ephys_g0_t0.imec0.ap.bin')
kilosort_path = Path('/home/data/GRB006/20240429_174359/kilosort2.5/imec0/')

clu = Clusters(folder = kilosort_path, get_waveforms=False, get_metrics=True, load_template_features=True)

# Single-unit criteria, evaluated once and reused for both outputs below.
# NOTE(review): the thresholds (50, 0.1, 0.1, 0.6, 0.1) are unitless here — confirm
# the units of each metric against the spks documentation.
mask = ((np.abs(clu.cluster_info.trough_amplitude - clu.cluster_info.peak_amplitude) > 50)
            & (clu.cluster_info.amplitude_cutoff < 0.1)
            & (clu.cluster_info.isi_contamination < 0.1)
            & (clu.cluster_info.presence_ratio >= 0.6)
            & (clu.cluster_info.spike_duration > 0.1))

# ---------- this gets the row indices ---------- #
single_unit_idx = np.where(mask)[0]

# ---------- and this get the cluster_id values ---------- #
single_unit_ids = clu.cluster_info[mask].cluster_id.values

#### Code from tutorial
https://github.com/jcouto/cshl_spks/blob/main/tutorials/tutorial_plot_psths.ipynb

In [None]:
sc = np.load('/home/data/GRB006/20240429_174359/kilosort2.5/imec0/spike_clusters.npy')
ss = np.load('/home/data/GRB006/20240429_174359/kilosort2.5/imec0/spike_times.npy')

In [None]:
np.max(st[selection])

In [None]:
st = ss/srate  # spike times in seconds
# trial_start_times = trial_start_samples/srate

# lets plot a population PSTH of good units
# lets do this only for good units
selection = np.isin(sc,single_unit_ids)

binsize = 0.01 # lets use a 10ms binsize
edges = np.arange(0,np.max(st[selection]),binsize)

pop_rate,_ = np.histogram(st[selection],edges)
pop_rate = pop_rate/binsize  # counts per bin -> spikes per second
pop_rate_time = edges[:-1]+np.diff(edges[:2])/2  # bin centres

tpre = 0.5
tpost = 1

# Take a fixed number of bins per trial. Selecting bins with a boolean time mask
# can return one bin more or less because of floating-point bin centres, and fewer
# bins near the edges of the recording, which makes np.stack fail.
n_window_bins = int(round((tpre + tpost)/binsize))
window_starts = np.searchsorted(pop_rate_time, first_stim_onsets - tpre, side='left')
# Keep only onsets whose whole window lies inside the recording.
complete_window = (first_stim_onsets - tpre >= 0) & (window_starts + n_window_bins <= len(pop_rate))
psth = pop_rate[window_starts[complete_window][:, None] + np.arange(n_window_bins)]

fig1 = plt.figure(figsize=(10,10))
plt.imshow(psth,aspect='auto',extent=[-tpre,tpost,0,len(psth)],cmap = 'RdBu_r',clim = [0,2500])
plt.colorbar(label='Population rate (Hz)')
plt.xlabel('time from first stim event')
plt.ylabel('Number of trials')

In [None]:
clus = sc[selection]
selected_spike_times = st[selection]  # index once instead of once per unit
# separate the spikes from each unit
timestamps = [selected_spike_times[clus == uclu] for uclu in np.unique(clus)]

# trig_ts[unit][trial]: spike times of that unit relative to the first stimulus onset.
# NOTE(review): this uses the tpre/tpost currently in the namespace (set by an earlier
# cell), not the display window assigned below.
trig_ts = [[sp[(sp>=(o-tpre)) & (sp<(o+tpost))] - o for o in first_stim_onsets]
           for sp in timestamps]

from ipywidgets import interact, IntSlider

fig2, ax = plt.subplots(1, 1, figsize=(8, 5))

# display window of the raster (seconds around the first stimulus onset)
tpre = 0.1
tpost = 1

@interact(iunit=IntSlider(min=0, max=len(trig_ts)-1, step=1, value=0))
def g(iunit):
    """Redraw the raster of unit ``iunit`` (an index into ``trig_ts``) on ``ax``."""
    # Draw on this cell's own axes rather than on pyplot's "current" figure, which
    # may be a different figure by the time the slider is moved.
    ax.clear()
    for i, trial_spikes in enumerate(trig_ts[iunit]):
        ax.vlines(trial_spikes, i, i+1, color = 'k')
    ax.set_xlim([-tpre,tpost])
    ax.set_ylabel('Trial #')
    ax.set_xlabel('Time from first stim onset (s)')
    fig2.canvas.draw_idle()

In [None]:
#interesting units to plot
#excitation: 67, 28, 29, 32, 105, 136, 151, 154, 209, 216
#inhibition: 4, 82, 84, 67, 82, 99, 106, 162, 166, 167

In [None]:
import matplotlib.pyplot as plt
from ipywidgets import IntSlider, Button, HBox, VBox
from IPython.display import display

def plot_individual_neurons_interactively(trig_ts, tpre=0.1, tpost=1):
    """Browse per-unit spike rasters with a slider and Previous/Next buttons.

    Parameters
    ----------
    trig_ts : sequence
        ``trig_ts[unit][trial]`` holds the spike times of one unit in one trial,
        relative to the alignment event.
    tpre, tpost : float
        The x-axis is limited to ``[-tpre, tpost]``.
    """
    fig2, ax = plt.subplots(1, 1, figsize=(8, 5))

    def draw_raster(unit_index):
        # One row of vertical ticks per trial.
        ax.clear()
        for trial_number, spike_times in enumerate(trig_ts[unit_index]):
            ax.vlines(spike_times, trial_number, trial_number + 1, color='k')
        ax.set_xlim([-tpre, tpost])
        ax.set_ylabel('Trial #')
        ax.set_xlabel('Time from first stim onset (s)')
        fig2.canvas.draw_idle()  # refresh without blocking

    unit_slider = IntSlider(min=0, max=len(trig_ts) - 1, step=1, value=0)
    next_button = Button(description="Next")
    prev_button = Button(description="Previous")

    def step_slider(delta):
        # Move the slider by `delta`, staying inside its range.
        target = unit_slider.value + delta
        unit_slider.value = max(unit_slider.min, min(target, unit_slider.max))

    next_button.on_click(lambda _button: step_slider(1))
    prev_button.on_click(lambda _button: step_slider(-1))

    # Any change of the slider value (drag or button) redraws the raster.
    unit_slider.observe(lambda change: draw_raster(change['new']), names='value')

    display(VBox([HBox([prev_button, next_button]), unit_slider]))

    # Show the first unit straight away.
    draw_raster(unit_slider.value)

plot_individual_neurons_interactively(trig_ts)

In [None]:
# Static raster of one unit, drawn on the existing `ax`.
iunit = 0  # index into trig_ts of the unit to plot; set explicitly so the cell runs on a fresh kernel
# The loop variable is deliberately not called `ss`: at top level that would
# overwrite the spike-sample array of the same name.
for i, trial_spikes in enumerate(trig_ts[iunit]):
    ax.vlines(trial_spikes, i, i + 1, color='k')
ax.set_xlim([-tpre, tpost])
ax.set_ylabel('Trial #')
ax.set_xlabel('Time from first stim onset (s)')

In [None]:
trig_ts[151][37].shape

In [None]:
# NOTE(review): `trig_unit_rates_ori` and `has_opto` are not defined anywhere in the
# visible notebook, so this cell raises NameError as it stands. The opto / visual-stimulus
# comments suggest it was copied from another analysis — confirm before relying on it.
#
# Lets compute the psth, now it is a bit more tricky 
# because we don't have equal number of opto and no-opto trials
# we just have to do it per trial
nunits,ntrials,nstims,ntime = trig_unit_rates_ori.shape
psth_mean = np.zeros([nunits,2,nstims,ntime]) # one for opto and one for no-opto
psth_sterr = np.zeros([nunits,2,nstims,ntime]) 
for istim in range(nstims):
    optoidx = np.sort(has_opto[istim])
    for sel in [0,1]: # select between no_opto and opto trials
        # mean and standard error across the trials whose optoidx equals sel
        psth_mean[:,sel,istim] = trig_unit_rates_ori[:,optoidx==sel,istim].mean(axis = 1)
        psth_sterr[:,sel,istim] = trig_unit_rates_ori[:,optoidx==sel,istim].std(axis = 1)/np.sqrt(np.sum(optoidx==sel))

# Plot the trial-averaged PSTHs of one unit (one line per stimulus)
iunit = 21
plt.figure()
t = np.linspace(-tpre,tpost,psth_mean.shape[-1])
plt.plot(t,psth_mean[iunit,0].T,color = 'k',alpha = 0.5)
plt.plot(t,psth_mean[iunit,1].T,color = 'blue',alpha = 0.5)
# plot the visual stim dur (we are using the arduino, there is a 60ms (measured) lag)
plt.plot([0.060,0.060+1],np.array([1,1])*np.max(plt.ylim())*0.9,
         lw = 5,color='lightgray')
# second bar spans 0.5 s to 1.5 s (presumably the opto period — confirm)
plt.plot([0.5,0.5+1],np.array([1,1])*np.max(plt.ylim())*0.95,
         lw = 5,color='blue')
plt.xlabel('time (s)')
plt.ylabel('spks/sec');

#### Sandbox - ignore for now

In [None]:
# NOTE(review): `cluster_spike_times` is only created by commented-out code in this
# notebook, so this cell fails on a fresh kernel unless that code is restored.
# Takes the spike times stored at label 7 of cluster_spike_times.spike_times.
sp_times = np.array(cluster_spike_times.spike_times[7])
# sp_times[sp_times>300]

# Define the time window
start_time = 0  # Start of the time window
end_time = 10    # End of the time window

# Filter the spike times to include only those within the time window
filtered_sp_times = sp_times[(sp_times >= start_time) & (sp_times <= end_time)]

# One row of ticks showing the selected spikes
fig = plt.figure()
plt.eventplot(filtered_sp_times, lineoffsets=0.5, colors='black')


In [None]:
# Table of the clusters that pass every single-unit criterion.
unit_table = clu.cluster_info
waveform_amplitude = np.abs(unit_table.trough_amplitude - unit_table.peak_amplitude)

passes_all_criteria = (
    (waveform_amplitude > 50)
    & (unit_table.amplitude_cutoff < 0.1)
    & (unit_table.isi_contamination < 0.1)
    & (unit_table.presence_ratio >= 0.6)
    & (unit_table.spike_duration > 0.1)
)
single_units_df = unit_table[passes_all_criteria]

single_units_df.cluster_id

In [None]:
def _plot_metric_histogram(ax, values, xlabel, color, n_bins,
                           threshold=None, exclude_below=True, n_excluded=None):
    """Draw one metric histogram; with a threshold, mark it and shade the excluded side."""
    ax.hist(values, bins=n_bins, alpha=0.5, color=color)
    ax.set_xlabel(xlabel)
    ax.set_ylabel('counts')
    if threshold is None:
        return
    ax.axvline(x=threshold, color=color, linestyle='--')
    if exclude_below:
        # shade from the left edge of the axis (but not below 0) up to the threshold
        x1, x2 = max(ax.get_xlim()[0], 0), threshold
    else:
        # shade from the threshold to the right edge of the axis
        x1, x2 = threshold, ax.get_xlim()[1]
    ax.fill_betweenx(y=ax.get_ylim(), x1=x1, x2=x2, color=color, alpha=0.1)
    ax.set_title(f'n = {n_excluded}')


def plot_cluster_info_histograms(clu):
    """Plot histograms of the cluster quality metrics with the exclusion regions shaded.

    Parameters
    ----------
    clu : object
        Must expose ``cluster_info`` with the columns ``trough_amplitude``,
        ``peak_amplitude``, ``isi_contamination``, ``spike_duration``,
        ``presence_ratio``, ``amplitude_cutoff`` and ``depth``.

    Each panel title reports how many clusters fail that panel's criterion. The
    count is the exact complement of the inclusion test (``> 50``, ``< 0.1``,
    ``> 0.1``, ``>= 0.6``, ``< 0.1``), so a cluster sitting exactly on a threshold,
    or with a NaN metric, is counted the same way the selection treats it.
    Depth is shown without a threshold because it is not used for filtering.
    """
    info = clu.cluster_info
    n_bins = 50
    spike_amplitude = np.abs(info.trough_amplitude - info.peak_amplitude)

    fig, axs = plt.subplots(3, 2, tight_layout=True, figsize=(10, 6))
    fig.suptitle(f'Total clusters: {len(info)}\n Shaded area indicates excluded clusters by metric', fontsize=14)

    # (axis, values, x label, colour, threshold, excluded side is below?, inclusion test)
    panels = [
        (axs[0, 0], spike_amplitude, 'spike amplitude', 'b', 50, True, spike_amplitude > 50),
        (axs[1, 0], info.isi_contamination, 'isi_contamination', 'g', 0.1, False, info.isi_contamination < 0.1),
        (axs[2, 0], info.spike_duration, 'spike_duration', 'r', 0.1, True, info.spike_duration > 0.1),
        (axs[0, 1], info.presence_ratio, 'presence_ratio', 'orange', 0.6, True, info.presence_ratio >= 0.6),
        (axs[1, 1], info.amplitude_cutoff, 'amplitude_cutoff', 'purple', 0.1, False, info.amplitude_cutoff < 0.1),
    ]
    for ax, values, xlabel, color, threshold, exclude_below, passes in panels:
        _plot_metric_histogram(ax, values, xlabel, color, n_bins,
                               threshold=threshold, exclude_below=exclude_below,
                               n_excluded=int(np.sum(~passes)))

    # depth - not used for filtering out clusters, but a helpful viz nonetheless
    _plot_metric_histogram(axs[2, 1], info.depth, 'depth', 'c', n_bins)


plot_cluster_info_histograms(clu)

In [None]:
# Keep the spikes that belong to single units. Spikes are matched by cluster *id*
# (single_unit_ids); single_unit_idx holds row positions in clu.cluster_info and
# only coincides with the ids when the ids happen to be 0..N-1.
# NOTE(review): assumes clu.spike_clusters stores cluster_id values — confirm in spks.
is_single_unit_spike = np.isin(clu.spike_clusters, single_unit_ids)
units_spiketimes = clu.spike_times[is_single_unit_spike]
units_spikeclusters = clu.spike_clusters[is_single_unit_spike]

assert units_spiketimes.shape == units_spikeclusters.shape

In [None]:
# CLUSTER_ID = 9
# good_spikes = good_spikes[good_clusters == CLUSTER_ID] # only get spikes from one cluster
# print(f'We will extract {len(good_spikes)} spikes for this cluster')

# # good_spikes = np.hstack((good_spikes + 30000, good_spikes + np.max(good_spikes) - 30000))
# print(good_spikes.shape)

In [None]:
# import spks.waveforms as spwaves

# SCRATCH_DIR = '/scratch/GRB' # fast disk to write memmap file
# nchannels = raw_data.shape[1]

# waves = spwaves.extract_memmapped_waveforms(raw_data, SCRATCH_DIR, good_spikes[:100])
# waves.shape # (nspikes, nsamples, nchannels)

In [None]:
# spike_times = read_phy_data(kilosort_path, srate=1, use_kilosort_results=True)['ts'] #ts = timestamps i think

In [None]:
clu.load_waveforms()

In [None]:
waves = load_dict_from_h5('/home/data/GRB006/20240429_174359/kilosort2.5/imec0/cluster_waveforms.hdf')

In [None]:
waves.keys()

In [None]:
units_waves = [waves[i] for i in single_unit_ids]

In [None]:
units_waves[0].shape

In [None]:
channel_idx = raw_data.metadata[0]['channel_idx']
np.mean(units_waves[0][:,:,383], axis=0)

In [None]:
from spks.phy_utils import read_phy_data
from spks.waveforms import extract_waveform_set

# spike_times = read_phy_data(kilosort_path, srate=1, use_kilosort_results=True)['ts'] #ts = timestamps i think
SCRATCH_DIR = '/scratch/GRB/temp_waves' # fast disk to write memmap file
nchannels = raw_data.shape[1]

# waves = extract_waveform_set(spike_times, raw_data, max_n_spikes=100)
# st = [spike_times[i] for i in single_unit_idx[single_unit_idx!=1084]]
waves = extract_waveform_set(units_spiketimes, raw_data, max_n_spikes=100, mmap_output=True, scratch_directory=SCRATCH_DIR)

In [None]:
from ipywidgets import interact, IntSlider
from spks.viz import plot_footprints

fig, ax = plt.subplots(1, 1, figsize=(12, 5))

@interact(clu_ids=IntSlider(min=0, max=len(units_waves)-1, step=1, value=0))
def g(clu_ids):
    ax.clear()

    channel_xy = raw_data.metadata[0]['coords']
    channel_idx = raw_data.metadata[0]['channel_idx']-1
    plot_footprints(np.mean(units_waves[clu_ids][:,:,channel_idx], axis=0), channel_xy, shade_color='k')

    return

In [None]:
# weird clusters: 48, 79, 81, 83, 90, 91, 113, 115... there's actually a bunch of them

In [None]:
single_unit_ids

In [None]:
plt.figure(figsize=[12,5])
# plot the cluster waveforms accross all channels and overlay with the standard deviation
plot_footprints(clu.cluster_waveforms_mean[clu.cluster_id == 9],clu.channel_positions,
                shade_data = clu.cluster_waveforms_std[clu.cluster_id == 9].squeeze(), color='r');
# plt.axis((720, 805., 1407, 1677));

In [None]:
# plot the principal channels for a set of clusters
# %matplotlib notebook
# plot some of the high amplitude clusters
clusters = clu.cluster_id[np.argsort(clu.trough_amplitude)[:5]]
for iclu,c in zip(clusters,['#d62728',
                            '#1f77b4',
                            '#ff7f0e',
                            '#2ca02c',
                            '#9467bd']):
    
    idx = clu.active_channels[clu.cluster_id == iclu][0]
    plot_footprints(clu.cluster_waveforms_mean[clu.cluster_id == iclu].squeeze()[:,idx],clu.channel_positions[idx,:],
                    color=c);
    
# plt.axis((718, 808, 1718, 2231));