In [None]:
""" import settings """
%load_ext autoreload
%autoreload 2

import os

import numpy as np
from matplotlib import pyplot as plt

from matplotlib_settings import set_plot_settings, reset_plot_settings

# Set the plot settings
set_plot_settings()

# import global variables
from utils_ssep_global import *

In [None]:
""" load data, define static params """
# indices of channels flagged as bad by the bad-channel step
bad_ch_idx_dir = f'{DATA_DIR}/1_bad_channels'
bad_chs = np.load(f"{bad_ch_idx_dir}/bad_ch_idx.npy")

# sampling frequency / period rescaled by the resampling factor RS
fs = FS/RS
Ts = TS*RS

# time axis of the downsampled segments
ds_data_dir = f'{DATA_DIR}/5_downsample'
t = np.load(f"{ds_data_dir}/t_{RS}.npy")

# sample indices lying strictly inside each analysis window
sep_idxs = np.nonzero((t > SEP_T0) & (t < SEP_T1))[0]                # SSEP window
baseline_idxs = np.nonzero((t > BASELINE_T0) & (t < BASELINE_T1))[0]  # baseline window

Choose the stimulation site to analyze by setting `stim_site` below.

In [None]:
# Stimulation site to analyze; used as an index into STIM_LABELS.
#   0 -> Median Nerve
#   1 -> Snout Lateral
#   2 -> Snout Superior
#   3 -> Snout Medial
#   4 -> Snout Inferior
stim_site = 4

In [None]:
""" load segmentized recording """
min_trial = 100 # number of trials to average over

# load data
fn_label = STIM_LABELS[stim_site].replace(" ", "_").lower()
ds_segs = np.load(f"{ds_data_dir}/{fn_label}_ds_{RS}_segs.npy")
vtc = np.load(f"{ds_data_dir}/{fn_label}_valid_trial_count.npy")

In [None]:
""" channel recording averaged over a fixed number of trials """
# Average the first `min_trial` NaN-free trials of every usable channel,
# baseline-correct the mean, and keep the across-trial std.
n_chs, n_times = ds_segs.shape[0], ds_segs.shape[2]
ch_means = np.full((n_chs, n_times), np.nan)  # dimension: ch*time
ch_stds  = np.full((n_chs, n_times), np.nan)  # dimension: ch*time

# channels with insufficient valid trials will be discarded
disqualified_chs = np.where(vtc < min_trial)[0]

# an empty baseline window would silently turn every baseline-corrected mean into NaN
assert len(baseline_idxs) > 0, "baseline window [BASELINE_T0, BASELINE_T1] contains no samples of t"

# channels excluded from averaging; one set lookup instead of scanning two arrays per channel
skipped_chs = set(np.ravel(bad_chs).tolist()) | set(disqualified_chs.tolist())

for ch, ch_data in enumerate(ds_segs):
    if ch in skipped_chs:
        continue

    # ch_data is (trial, time): keep the first `min_trial` trials without any NaN sample,
    # preserving their original order
    nan_free = ~np.isnan(ch_data).any(axis=1)
    trial_datas = ch_data[nan_free][:min_trial]
    assert len(trial_datas) == min_trial, (
        f"channel {ch}: only {len(trial_datas)} NaN-free trials available, need {min_trial}"
    )

    ch_mean = trial_datas.mean(axis=0)
    # baseline correction: subtract the mean over the baseline window
    ch_means[ch] = ch_mean - ch_mean[baseline_idxs].mean()
    # across-trial std is unaffected by the constant baseline offset
    ch_stds[ch] = trial_datas.std(axis=0)

In [None]:
""" convert to micro-volts """
def _to_microvolts(codes):
    """Scale an array (presumably raw ADC codes) to micro-volts.

    Evaluated left to right exactly as codes*FS_ADC/MAX_ADC_CODE/GAIN/1e-6.
    """
    return codes*FS_ADC/MAX_ADC_CODE/GAIN/1e-6

data_means = _to_microvolts(ch_means)
data_stds = _to_microvolts(ch_stds)

In [None]:
from utils_ssep_plot import plot_spatiotemporal

# set True to write the figure to disk (SVG + 1200-dpi PNG)
SAVE_FIG = False

plt.close('all')
# 16x16 grid of subplots sharing x and y axes
fig, ax = plt.subplots(16, 16, figsize=(6, 6), sharex=True, sharey=True)

# adjust spacing between subplots (default: 0.2 i.e. 20% of plot)
fig.subplots_adjust(wspace=0.05, hspace=0)

# site 0 gets an explicit title; every other site uses its STIM_LABELS entry
title_str = 'Median Nerve' if stim_site == 0 else STIM_LABELS[stim_site]
t0, t1 = 0, 50e-3  # plotting range, in the units of `t` (presumably seconds)
plot_spatiotemporal(fig, ax, t, data_means, data_stds, disqualified_chs, t0, t1, title_str)

if SAVE_FIG:
    save_dir = './figures/ssep/16x16_temporal'
    save_fn = f'{fn_label}_16x16_temporal'
    os.makedirs(save_dir, exist_ok=True)
    # join directory and file name with a single "/" (no stray "." before it)
    fig.savefig(f"{save_dir}/{save_fn}.svg", bbox_inches='tight')
    fig.savefig(f"{save_dir}/{save_fn}.png", bbox_inches='tight', dpi=1200)