In [None]:
""" import settings """
%load_ext autoreload
%autoreload 2

import numpy as np
from matplotlib import pyplot as plt
from matplotlib_settings import set_plot_settings, reset_plot_settings

# Set the plot settings
set_plot_settings()

# import global variables
from utils_ssep_global import *

In [None]:
""" load data, define static params """
# Bad channels: indices saved by the bad-channel step (1_bad_channels)
bad_ch_idx_dir = f'{DATA_DIR}/1_bad_channels'
bad_chs = np.load(f"{bad_ch_idx_dir}/bad_ch_idx.npy")

# Downsampled data: sampling frequency / period rescaled by the factor RS,
# plus the time axis saved alongside the downsampled segments
fs = FS / RS
Ts = TS * RS
ds_data_dir = f'{DATA_DIR}/5_downsample'
t = np.load(f"{ds_data_dir}/t_{RS}.npy")

# Sample indices lying strictly inside each analysis window (open interval on t)
sep_idxs = np.nonzero((t > SEP_T0) & (t < SEP_T1))[0]                 # SSEP window
baseline_idxs = np.nonzero((t > BASELINE_T0) & (t < BASELINE_T1))[0]  # baseline window

Choose a stimulation site (`stim_site`)

In [None]:
# Stimulation site to analyse (index into STIM_LABELS — TODO confirm ordering):
#   0 -> Median Nerve
#   1 -> Snout Lateral    2 -> Snout Superior
#   3 -> Snout Medial     4 -> Snout Inferior
stim_site = 3  # valid choices: 0..4

In [None]:
""" load segmentized recording """
# Each channel's average is built from exactly this many trials.
min_trial = 100

# File stem for the chosen site: label lower-cased, spaces -> underscores
fn_label = STIM_LABELS[stim_site].lower().replace(" ", "_")
seg_path = f"{ds_data_dir}/{fn_label}_ds_{RS}_segs.npy"
vtc_path = f"{ds_data_dir}/{fn_label}_valid_trial_count.npy"

ds_segs = np.load(seg_path)  # downsampled, segmented recording
vtc = np.load(vtc_path)      # valid-trial count, one entry per channel

In [None]:
""" channel recording averaged over a fixed number of trials """
# For every usable channel, average the first `min_trial` NaN-free trials (so all
# channels share the same trial count), then subtract the baseline-window mean.
# NOTE(review): assumes ds_segs is (channel, trial, time) — TODO confirm.
n_chs, n_samples = ds_segs.shape[0], ds_segs.shape[2]
ch_means = np.full((n_chs, n_samples), np.nan)  # dimension: ch*time; skipped channels stay NaN

# channels with insufficient valid trials will be discarded
disqualified_chs = np.where(vtc < min_trial)[0]

# Boolean mask over channel indices: True for channels to skip.
# Precomputed once instead of an O(len) `ch in ndarray` scan per channel.
ch_idxs = np.arange(n_chs)
skip_chs = np.isin(ch_idxs, bad_chs) | np.isin(ch_idxs, disqualified_chs)

for ch, ch_data in enumerate(ds_segs):
    if skip_chs[ch]:
        continue

    # a trial is usable only if none of its samples is NaN (saturated)
    valid_trials = ch_data[~np.isnan(ch_data).any(axis=1)]
    assert len(valid_trials) >= min_trial, (
        f"channel {ch}: only {len(valid_trials)} NaN-free trials "
        f"(need {min_trial}; vtc reports {vtc[ch]})"
    )
    ch_mean = valid_trials[:min_trial].mean(axis=0)

    # apply baseline correction
    ch_means[ch] = ch_mean - np.mean(ch_mean[baseline_idxs])

In [None]:
""" Channel recordings will be arranged according to their "peak" SSEP voltage """
# Extrema inside the SSEP window. Skipped channels (all-NaN rows of ch_means)
# yield NaN here; numpy may emit an "All-NaN slice" RuntimeWarning for them.
maxs = np.nanmax(ch_means[:, sep_idxs], axis=1)
mins = np.nanmin(ch_means[:, sep_idxs], axis=1)

# peak-to-peak amplitude
p2pks = maxs - mins
max_p2pk_ch = np.nanargmax(p2pks)  # channel with maximum peak-to-peak
# channel index -> (row, col) in a 16-column layout
# NOTE(review): 16 is presumably the electrode-grid width — confirm against the array map
N_GRID_COLS = 16
max_p2pk_row, max_p2pk_col = divmod(max_p2pk_ch, N_GRID_COLS)
print(max_p2pk_row, max_p2pk_col)

# Signed peak amplitude: whichever extremum has the larger magnitude
# (the negative one on ties); NaN for channels without data.
# Sized from the data (not the global NCH) so every row is a real channel.
peak_vals = np.where(maxs > -mins, maxs, mins)
peak_vals[np.isnan(maxs) | np.isnan(mins)] = np.nan
pks = np.column_stack((peak_vals, np.arange(len(peak_vals))))  # columns: [peak, channel index]

# Sort by peak, most positive first, dropping channels without a peak.
# Mask on the peak column only: passing np.where's (rows, cols) tuple to
# np.delete also removed row 0, since the column indices were read as rows.
peak_sort = pks[~np.isnan(pks[:, 0])]
peak_sort_idxs = np.argsort(peak_sort[:, 0])[::-1]  # reverse -> descending
peak_sort = peak_sort[peak_sort_idxs]

Generate the waterfall plot (channels ordered by signed SSEP peak, largest first)

In [None]:
from utils_ssep_plot import plot_waterfall

t0, t1 = -50e-3, 100e-3  # plot range (same units as t)

sort_chs = peak_sort[:, 1].astype(int)  # channel indices in plotting order
voff = 10  # vertical offset between channels

# Scale ADC codes to voltage; the trailing /1e-6 presumably expresses it in µV
# (TODO confirm units of FS_ADC / GAIN).
plot_data = ch_means*FS_ADC/MAX_ADC_CODE/GAIN/1e-6

# Count the channels actually drawn (one per entry of sort_chs), rather than a
# hard-coded 256 minus low-trial channels, which also counted bad channels.
num_good_ch = len(sort_chs)
title_str = f'N = {min_trial}, nch = {num_good_ch}\n'
if stim_site == 0:
    title_str += 'Median Nerve'
else:
    title_str += STIM_LABELS[stim_site]

plt.close('all')
figsize = (4.5, 12)
fig, ax = plt.subplots(figsize=figsize)
plot_waterfall(ax, t, plot_data, sort_chs, t0, t1, voff, sel_detrend=True,
               title_str=title_str)

# Optional export; off by default so Restart & Run All has no filesystem side effects.
save_fig = False
if save_fig:
    import os
    save_dir = './figures/ssep/waterfall'
    save_fn = f'{fn_label}_waterfall'
    os.makedirs(save_dir, exist_ok=True)
    fig.savefig(f"{save_dir}/{save_fn}.svg", bbox_inches='tight')
    fig.savefig(f"{save_dir}/{save_fn}.png", bbox_inches='tight', dpi=1200)