In [None]:
""" import settings """
# Reload imported modules automatically before each cell runs, so edits to the
# project helper modules take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

import numpy as np
from matplotlib import pyplot as plt
from matplotlib_settings import set_plot_settings, reset_plot_settings

# Set the plot settings
set_plot_settings()

# import global variables
# NOTE(review): this notebook uses DATA_DIR, FS, TS, RS, NCH, STIM_LABELS,
# SEP_T0, SEP_T1, BASELINE_T0 and BASELINE_T1 without defining them, so they
# presumably come from this star import — confirm, and consider importing them
# by name so their origin is visible.
from utils_ssep_global import *

In [None]:
""" load data, define static params """
# Indices of the channels flagged as bad (saved by an earlier step under 1_bad_channels)
bad_ch_idx_dir = f'{DATA_DIR}/1_bad_channels'
bad_chs = np.load(f"{bad_ch_idx_dir}/bad_ch_idx.npy")

# Quantities rescaled by RS, and the time axis saved with the downsampled data.
# NOTE(review): presumably FS/TS are the raw sampling frequency/period and RS
# the downsampling factor — confirm in utils_ssep_global.
fs = FS/RS
Ts = TS*RS
ds_data_dir = f'{DATA_DIR}/5_downsample'
t = np.load(f"{ds_data_dir}/t_{RS}.npy")

# SSEP index: samples strictly inside the open interval (SEP_T0, SEP_T1)
sep_idxs = np.nonzero((t > SEP_T0) & (t < SEP_T1))[0]

# Baseline index: samples strictly inside (BASELINE_T0, BASELINE_T1)
baseline_idxs = np.nonzero((t > BASELINE_T0) & (t < BASELINE_T1))[0]

Load trial-averaged channel recordings

In [None]:
# Stimulation sites in plotting order; each entry is an index into STIM_LABELS
# and therefore selects which site's files are loaded.
stim_sites = [0, 1, 4, 2, 3] # re-order for plotting

# Derived from stim_sites so the two cannot drift apart: the per-site result
# array is sized with n_sites but filled by iterating over stim_sites.
n_sites = len(stim_sites) # total number of stimulation locations

# Number of NaN-free trials averaged per channel; channels whose valid trial
# count is below this are discarded.
min_trial = 100

In [None]:
""" for all stim sites, channel recordings averaged over a fixed number of trials """
# Baseline-corrected trial average for every site and channel.
# Rows follow the order of stim_sites (i.e. plotting order), not the raw site index.
data_means = np.zeros((n_sites, NCH, len(t))) # dimension: site*ch*time

for idx, stim_site in enumerate(stim_sites):
    # load data
    fn_label = STIM_LABELS[stim_site].replace(" ", "_").lower()
    ds_segs = np.load(f"{ds_data_dir}/{fn_label}_ds_{RS}_segs.npy") # indexed below as ch*trial*time
    vtc = np.load(f"{ds_data_dir}/{fn_label}_valid_trial_count.npy")

    # channels with insufficient valid trials will be discarded
    disqualified_ch = np.where(vtc < min_trial)[0]

    # Build the skip mask once per site instead of testing array membership
    # for every channel inside the loop.
    ch_ids = np.arange(ds_segs.shape[0])
    skip_ch = np.isin(ch_ids, bad_chs) | np.isin(ch_ids, disqualified_ch)

    # Skipped channels keep their NaN row so downstream code can recognise them.
    ch_means = np.full((ds_segs.shape[0], ds_segs.shape[2]), np.nan) # dimension: ch*time

    for ch in ch_ids[~skip_ch]:
        ch_data = ds_segs[ch] # dimension: trial*time

        # Use the first `min_trial` trials (in stored order) with no NaN sample,
        # so every channel is averaged over the same number of trials.
        nan_free = ~np.isnan(ch_data).any(axis=1)
        used_trials = ch_data[nan_free][:min_trial]
        assert used_trials.shape[0] == min_trial, (
            f"{fn_label} ch {ch}: only {used_trials.shape[0]} NaN-free trials, "
            f"need {min_trial}"
        )

        # Accumulate in float64 regardless of the dtype stored on disk.
        ch_mean = used_trials.sum(axis=0, dtype=np.float64) / min_trial

        # apply baseline correction
        ch_means[ch] = ch_mean - np.mean(ch_mean[baseline_idxs])
    data_means[idx] = ch_means

Make Spatial Map of SSEP Peaks (Extrema)

In [None]:
""" find SSEP peak (extrema) for each stimulation site """
# Restrict every trace to the SSEP window.
sep_window = data_means[:, :, sep_idxs] # dimension: site*ch*time(window)

# Rows that are NaN throughout the window (e.g. skipped channels) are excluded
# from the reductions, which avoids numpy's "All-NaN slice encountered" warning;
# their extrema and peak simply stay NaN.
has_data = ~np.isnan(sep_window).all(axis=2) # dimension: site*ch

data_sep_maxs = np.full(has_data.shape, np.nan) # dimension: site*ch
data_sep_mins = np.full(has_data.shape, np.nan) # dimension: site*ch
data_sep_pks = np.full(has_data.shape, np.nan) # dimension: site*ch

# Boolean indexing flattens (site, ch) into one axis, so time becomes axis 1.
valid_maxs = np.nanmax(sep_window[has_data], axis=1)
valid_mins = np.nanmin(sep_window[has_data], axis=1)
data_sep_maxs[has_data] = valid_maxs
data_sep_mins[has_data] = valid_mins

# Signed peak: keep whichever extremum has the larger magnitude; on a tie the
# minimum is kept because the comparison is strict.
data_sep_pks[has_data] = np.where(valid_maxs > -valid_mins, valid_maxs, valid_mins)

In [None]:
""" plot spatial map """
from utils_ssep_plot import plot_spatial_peaks

# 2x3 grid of panels with almost no gap between neighbours
fig, ax = plt.subplots(nrows=2, ncols=3, figsize=(8, 6))
fig.subplots_adjust(wspace=0.05, hspace=0.05)

# Panel titles and (row, col) grid positions, five of each.
# NOTE(review): presumably paired positionally with the rows of data_sep_pks —
# confirm in plot_spatial_peaks.
title_strs = ['MN', 'SL', 'SI', 'SS', 'SM']
ax_rc = [(0, 0), (0, 1), (1, 0), (1, 1), (1, 2)] # assign ax to each stim site

plot_spatial_peaks(fig, ax, data_sep_pks, title_strs, ax_rc)

# Figure export, currently disabled:
# save_dir = './figures/ssep/spatial_map'
# save_fn = f'ssep_spatial_map'
# plt.savefig(f"{save_dir}/{save_fn}.svg", bbox_inches='tight')
# plt.savefig(f"{save_dir}/{save_fn}.png", bbox_inches='tight', dpi=1200)