In [None]:
""" import settings """
%load_ext autoreload
%autoreload 2

import numpy as np
from matplotlib import pyplot as plt
from matplotlib_settings import set_plot_settings, reset_plot_settings

# Set the plot settings
set_plot_settings()

# import global variables
# NOTE(review): the upper-case names used in later cells that no visible cell defines
# (DATA_DIR, FS, TS, RS, NCH, STIM_LABELS, SEP_T0/SEP_T1, BASELINE_T0/BASELINE_T1)
# presumably come from this star import — confirm against utils_ssep_global.
from utils_ssep_global import *

In [None]:
""" load data, define static params """
# directories of the intermediate results read by this notebook
bad_ch_idx_dir = f'{DATA_DIR}/1_bad_channels'
ds_data_dir = f'{DATA_DIR}/5_downsample'

# presumably the sampling rate / sampling period after downsampling by RS — TODO confirm
fs = FS/RS
Ts = TS*RS

# indices of the bad channels
bad_chs = np.load(f"{bad_ch_idx_dir}/bad_ch_idx.npy")

# time vector of the downsampled data
t = np.load(f"{ds_data_dir}/t_{RS}.npy")

# SSEP index: samples strictly between SEP_T0 and SEP_T1
sep_idxs = np.nonzero((t > SEP_T0) & (t < SEP_T1))[0]

# Baseline index: samples strictly between BASELINE_T0 and BASELINE_T1
baseline_idxs = np.nonzero((t > BASELINE_T0) & (t < BASELINE_T1))[0]

In [None]:
# Stimulation-site indices (used to index STIM_LABELS), listed in plotting order.
stim_sites = [0, 1, 4, 2, 3] # re-order for plotting
# Derived rather than hard-coded: data_means is allocated with n_sites rows and
# filled by enumerating stim_sites, so the two must always agree.
n_sites = len(stim_sites) # total number of stimulation locations
# Channels with fewer valid trials are discarded; exactly this many trials are averaged.
min_trial = 100

Load trial-averaged channel recordings

In [None]:
""" for all stim sites, channel recordings averaged over a fixed number of trials """
def _sum_leading_valid_trials(ch_data, n_trials):
    """Add up, in order, the first `n_trials` NaN-free trials of one channel.

    `ch_data` is iterated along its first axis (one trial per step) and any
    trial containing a NaN is skipped. Returns the running sum together with
    the number of trials that went into it, so the caller can verify that
    enough valid trials were found.
    """
    total = np.zeros((ch_data.shape[1],))
    n_used = 0
    for trial_data in ch_data:
        if np.isnan(trial_data).any():
            continue

        total += trial_data
        n_used += 1
        if n_used == n_trials:
            break
    return total, n_used


data_means = np.zeros((n_sites, NCH, len(t))) # dimension: site*ch*time

for idx, stim_site in enumerate(stim_sites):
    # file names are built from the lower-cased, underscore-joined site label
    fn_label = STIM_LABELS[stim_site].replace(" ", "_").lower()
    ds_segs = np.load(f"{ds_data_dir}/{fn_label}_ds_{RS}_segs.npy")
    vtc = np.load(f"{ds_data_dir}/{fn_label}_valid_trial_count.npy")

    # channels whose valid-trial count is below min_trial are left out
    disqualified_ch = np.nonzero(vtc < min_trial)[0]

    # rows of skipped channels stay NaN; dimension: ch*time
    ch_means = np.full((ds_segs.shape[0], ds_segs.shape[2]), np.nan)

    # ch_data: trial*time
    for ch, ch_data in enumerate(ds_segs):
        usable = not (ch in bad_chs or ch in disqualified_ch)
        if usable:
            trial_sum, n_used = _sum_leading_valid_trials(ch_data, min_trial)
            assert n_used == min_trial
            ch_mean = trial_sum/min_trial

            # baseline correction: remove the mean over the baseline samples
            ch_means[ch] = ch_mean - ch_mean[baseline_idxs].mean()
    data_means[idx] = ch_means

In [None]:
""" normalize """
sep_means = data_means[:,:,sep_idxs]
# per-site scale: largest absolute value inside the SSEP window, NaN entries ignored
site_peaks = np.nanmax(np.abs(sep_means), axis=(1, 2), keepdims=True)
# in-place division, broadcast over the channel and time axes of every site
data_means /= site_peaks

In [None]:
""" time points of a few example frames to plot """
frame_ts = np.array([-10e-3, -5e-3, 0, 5e-3, 10e-3, 15e-3, 20e-3, 25e-3,
                     30e-3, 35e-3, 40e-3, 45e-3])

# for every requested time, take the last sample lying strictly before it
frame_idxs = np.array([np.nonzero(t < frame_t)[0][-1] for frame_t in frame_ts])

In [None]:
""" plot and save a few example frames """
from utils_ssep_plot import plot_spatial_cbar, plot_spatial_ruler, plot_spatial_compass

# panel titles in plotting order; site 0 gets a hand-written label
title_strs = ['Median Nerve' if stim_site == 0 else STIM_LABELS[stim_site]
              for stim_site in stim_sites]

save_dir = './figures/ssep/spatial_map'

# grid slot of each site's panel; the remaining (0, 2) slot carries the annotations
ax_rc = [(0, 0), (0, 1), (1, 0), (1, 1), (1, 2)]

for idx0 in frame_idxs:
    # only one figure is kept open at a time
    plt.close('all')
    fig, ax = plt.subplots(2, 3, figsize=(8, 6))
    plt.subplots_adjust(wspace=0.05, hspace=0.05)

    panel_images = []
    for idx, (r, c) in enumerate(ax_rc):
        panel = ax[r, c]
        panel.set_xticks([])
        panel.set_yticks([])
        panel.set_title(title_strs[idx])
        # channel vector of this site at sample idx0, laid out as a 16-row grid
        panel_images.append(
            panel.imshow(data_means[idx,:,idx0].reshape(16, -1), vmax=1, vmin=-1, cmap='bwr'))
    # keep the individual handles bound under their usual names
    im0, im1, im2, im3, im4 = panel_images

    fig.suptitle(f't = {t[idx0]/1e-3:.1f} ms')

    legend_ax = ax[0, 2]
    plot_spatial_cbar(fig, legend_ax)
    plot_spatial_ruler(legend_ax)
    plot_spatial_compass(legend_ax)

    # save_fn = f'ssep_spatial_map_{np.round(t[idx0]/1e-3)}ms'
    # plt.savefig(f"{save_dir}/{save_fn}.svg", bbox_inches='tight')
    # plt.savefig(f"{save_dir}/{save_fn}.png", bbox_inches='tight', dpi=1200)

Animate Frames into a Movie

In [None]:
""" define time range for generating movie """ 
# movie time window; samples strictly between ani_t0 and ani_t1 become frames
ani_t0, ani_t1 = -10e-3, 45e-3
# Whole-millisecond versions of the window limits. Round before casting:
# int() truncates toward zero, so a quotient that lands just below a whole
# number (as 0.3/0.1 == 2.9999999999999996 does) would come out 1 ms short.
fn_t0, fn_t1 = int(round(ani_t0/1e-3)), int(round(ani_t1/1e-3))
ani_idxs = np.where(np.logical_and(t > ani_t0, t < ani_t1))[0]

In [None]:
""" initialize the 2x3 plot """
from utils_ssep_plot import plot_spatial_cbar, plot_spatial_ruler, plot_spatial_compass

# panel titles in plotting order; site 0 gets a hand-written label
title_strs = ['Median Nerve' if stim_site == 0 else STIM_LABELS[stim_site]
              for stim_site in stim_sites]

plt.close('all')
fig, ax = plt.subplots(2, 3, figsize=(8, 6))
plt.subplots_adjust(wspace=0.05, hspace=0.05)

# the figure starts out showing the first sample of the animation range
idx0 = ani_idxs[0]

# grid slot of each site's panel; the remaining (0, 2) slot carries the annotations
ax_rc = [(0, 0), (0, 1), (1, 0), (1, 1), (1, 2)]
panel_images = []
for idx, (r, c) in enumerate(ax_rc):
    panel = ax[r, c]
    panel.set_xticks([])
    panel.set_yticks([])
    panel.set_title(title_strs[idx])
    # channel vector of this site at sample idx0, laid out as a 16-row grid
    panel_images.append(
        panel.imshow(data_means[idx,:,idx0].reshape(16, -1), vmax=1, vmin=-1, cmap='bwr'))
# individual handles, updated frame by frame when the movie is rendered
im0, im1, im2, im3, im4 = panel_images

fig.suptitle(f't = {t[idx0]/1e-3:.1f} ms')

legend_ax = ax[0, 2]
plot_spatial_cbar(fig, legend_ax)
plot_spatial_ruler(legend_ax)
plot_spatial_compass(legend_ax)

In [None]:
""" generate movie """
# FFMpegWriter relies on the external ffmpeg program; download it from here: https://ffmpeg.org/download.html
import matplotlib.animation as animation
from matplotlib.animation import FFMpegWriter
# directory the movie is written to ('.' is the current working directory);
# this rebinds the save_dir that the example-frame cell pointed at the figures folder
save_dir = '.'

def animate(idx):
    """Show time sample `idx` in the five site panels.

    Updates the figure title with the sample time in milliseconds and replaces
    the data of each image with that site's channel values at `idx`, reshaped
    to a 16-row grid. Returns the five image artists as a tuple.
    """
    fig.suptitle(f't = {t[idx]/1e-3:.1f} ms')

    images = (im0, im1, im2, im3, im4)
    # build every frame first, then push them to the images in panel order
    frames = [data_means[site_pos,:,idx].reshape(16, -1) for site_pos in range(len(images))]
    for image, frame in zip(images, frames):
        image.set_array(frame)

    return images

# blit=False: animate() also rewrites the figure suptitle, which is not among
# the artists it returns, so a blitted redraw would leave the title stale
# during live playback. One frame is produced per index in ani_idxs.
ani = animation.FuncAnimation(fig, animate, frames=ani_idxs, blit=False)

# plain string: the name contains no placeholders
save_fn = "bisc_porcine_ssep"
# encode the movie at 10 frames per second with ffmpeg
ani.save(f'{save_dir}/{save_fn}.mp4', writer=FFMpegWriter(fps=10))