In [None]:
%load_ext autoreload
%autoreload 2
# NOTE(review): os and sys are not referenced in the visible cells
import os, sys

import numpy as np

from matplotlib import pyplot as plt
# NOTE(review): reset_plot_settings is not called in the visible cells
from matplotlib_settings import set_plot_settings, reset_plot_settings

# Set the plot settings
set_plot_settings()

# import global variables
# NOTE(review): this star import presumably provides the directory constants
# (MODEL_INPUT_DIR, MODEL_OUTPUT_DIR, MOTION_DIR, BAND_MU_SIGMA_DIR,
# SPECT_DATA_DIR) and NCH / T_DF_MOTION used below; importing them by name
# would make their origin explicit — confirm against utils_motor_global.
from utils_motor_global import *
from utils_motor_plot  import draw_CS_boundary

from utils_motor_sigproc import normalize_band
from utils_motor_misc import load_session_spect_data

# NOTE(review): ROOT_SCALOMAT_DIR is not referenced in the visible cells
ROOT_SCALOMAT_DIR = f'{MODEL_INPUT_DIR}/scalogram_matrix'
# Uncomment to enable the commented-out plt.savefig(...) calls in the plot cells
# SAVE_IMG_DIR = f'{MODEL_OUTPUT_DIR}/decoder_and_band_plot'

Load data from sessions 10, 12, and 13

In [None]:
def fetch_motion_data(key):
    """Load the observed and decoded wrist y-velocity traces of one session.

    Parameters
    ----------
    key : str
        Session identifier such as '010'; used both as the sub-directory
        under ``MODEL_OUTPUT_DIR/decoded`` and as the file-name suffix.

    Returns
    -------
    tuple of np.ndarray
        ``(y_true, y_pred)`` loaded from ``vel_y_observed_<key>.npy`` and
        ``vel_y_predicted_<key>.npy`` respectively.
    """
    load_dir = f'{MODEL_OUTPUT_DIR}/decoded/{key}'
    observed_path = f'{load_dir}/vel_y_observed_{key}.npy'
    predicted_path = f'{load_dir}/vel_y_predicted_{key}.npy'
    return np.load(observed_path), np.load(predicted_path)

# Observed / decoded wrist y-velocity, one (y_true, y_pred) pair per session.
(y_true_10, y_pred_10), (y_true_12, y_pred_12), (y_true_13, y_pred_13) = [
    fetch_motion_data(session) for session in ('010', '012', '013')
]

# motion time scale of each session
motion_t_10, motion_t_12, motion_t_13 = [
    np.load(f'{MOTION_DIR}/010/model_t_010.npy'),
    np.load(f'{MOTION_DIR}/012/model_t_012.npy'),
    np.load(f'{MOTION_DIR}/013/model_t_013.npy'),
]

In [None]:
""" channel band mean and standard deviation """
spect_mu  = np.load(f'{BAND_MU_SIGMA_DIR}/spect_mu.npy')
lmp_mu    = np.load(f'{BAND_MU_SIGMA_DIR}/lmp_mu.npy')
beta_mu   = np.load(f'{BAND_MU_SIGMA_DIR}/beta_mu.npy')
hga_mu    = np.load(f'{BAND_MU_SIGMA_DIR}/hga_mu.npy')

spect_sigma  = np.load(f'{BAND_MU_SIGMA_DIR}/spect_sigma.npy')
lmp_sigma    = np.load(f'{BAND_MU_SIGMA_DIR}/lmp_sigma.npy')
beta_sigma   = np.load(f'{BAND_MU_SIGMA_DIR}/beta_sigma.npy')
hga_sigma    = np.load(f'{BAND_MU_SIGMA_DIR}/hga_sigma.npy')

In [None]:
""" load session data """
# spectrogram and LMP
good_chs_10, spect_10, lmp_data_10, spect_t_10 = load_session_spect_data(
    f'{SPECT_DATA_DIR}/010', '010')
good_chs_12, spect_12, lmp_data_12, spect_t_12 = load_session_spect_data(
    f'{SPECT_DATA_DIR}/012', '012')
good_chs_13, spect_13, lmp_data_13, spect_t_13 = load_session_spect_data(
    f'{SPECT_DATA_DIR}/013', '013')

# load common data
key = '013'
spect_f     = np.load(f'{SPECT_DATA_DIR}/{key}/spect_f_{key}.npy'   ) 
beta_fidxs  = np.load(f'{SPECT_DATA_DIR}/{key}/beta_fidxs_{key}.npy') 
hga_fidxs   = np.load(f'{SPECT_DATA_DIR}/{key}/hga_fidxs_{key}.npy' ) 

In [None]:
""" extract band from spectrogram """
spect_beta_10 = np.sum(spect_10[:,:,beta_fidxs], axis=2)
spect_beta_12 = np.sum(spect_12[:,:,beta_fidxs], axis=2)
spect_beta_13 = np.sum(spect_13[:,:,beta_fidxs], axis=2)
spect_hga_10  = np.sum(spect_10[:,:, hga_fidxs], axis=2)
spect_hga_12  = np.sum(spect_12[:,:, hga_fidxs], axis=2)
spect_hga_13  = np.sum(spect_13[:,:, hga_fidxs], axis=2)

In [None]:
""" z-score bands """
zs_lmp_10  = normalize_band(good_chs_10, lmp_data_10, lmp_mu, lmp_sigma, sel_zscore=True)
zs_lmp_12  = normalize_band(good_chs_12, lmp_data_12, lmp_mu, lmp_sigma, sel_zscore=True)
zs_lmp_13  = normalize_band(good_chs_13, lmp_data_13, lmp_mu, lmp_sigma, sel_zscore=True)

zs_hga_10  = normalize_band(good_chs_10, spect_hga_10, hga_mu, hga_sigma, sel_zscore=True)
zs_hga_12  = normalize_band(good_chs_12, spect_hga_12, hga_mu, hga_sigma, sel_zscore=True)
zs_hga_13  = normalize_band(good_chs_13, spect_hga_13, hga_mu, hga_sigma, sel_zscore=True)

zs_beta_10 = normalize_band(good_chs_10, spect_beta_10, beta_mu, beta_sigma, sel_zscore=True)
zs_beta_12 = normalize_band(good_chs_12, spect_beta_12, beta_mu, beta_sigma, sel_zscore=True)
zs_beta_13 = normalize_band(good_chs_13, spect_beta_13, beta_mu, beta_sigma, sel_zscore=True)

In [None]:
""" find common channels and create subset. do not truncate yet """
overlap_chs = np.intersect1d(good_chs_10, good_chs_12)
overlap_chs = np.intersect1d(overlap_chs, good_chs_13)

ch_idxs_10 = np.array([ch_idx for ch_idx, ch in enumerate(good_chs_10) if ch in overlap_chs])
ch_idxs_12 = np.array([ch_idx for ch_idx, ch in enumerate(good_chs_12) if ch in overlap_chs])
ch_idxs_13 = np.array([ch_idx for ch_idx, ch in enumerate(good_chs_13) if ch in overlap_chs])

In [None]:
""" truncate in time """
# motion (wrist velocity) data time range is a subset of spectrogram time scale
# truncate the spectrogram in time domain to make them equal

# step 1. identify index
idx0_10 = np.where(spect_t_10 >= motion_t_10[0])[0][0]
idx1_10 = np.where(spect_t_10 <= motion_t_10[-1])[0][-1] + 1

idx0_12 = np.where(spect_t_12 >= motion_t_12[0])[0][0]
idx1_12 = np.where(spect_t_12 <= motion_t_12[-1])[0][-1] + 1

idx0_13 = np.where(spect_t_13 >= motion_t_13[0])[0][0]
idx1_13 = np.where(spect_t_13 <= motion_t_13[-1])[0][-1] + 1

assert len(spect_t_10[idx0_10:idx1_10]) == len(motion_t_10)
assert len(spect_t_12[idx0_12:idx1_12]) == len(motion_t_12)
assert len(spect_t_13[idx0_13:idx1_13]) == len(motion_t_13)

# step 2. truncate
zs_lmp_10 = zs_lmp_10[:, idx0_10:idx1_10]
zs_hga_10 = zs_hga_10[:, idx0_10:idx1_10]
zs_beta_10 = zs_beta_10[:, idx0_10:idx1_10]

zs_lmp_12 = zs_lmp_12[:, idx0_12:idx1_12]
zs_hga_12 = zs_hga_12[:, idx0_12:idx1_12]
zs_beta_12 = zs_beta_12[:, idx0_12:idx1_12]

zs_lmp_13 = zs_lmp_13[:, idx0_13:idx1_13]
zs_hga_13 = zs_hga_13[:, idx0_13:idx1_13]
zs_beta_13 = zs_beta_13[:, idx0_13:idx1_13]

In [None]:
""" create copy before truncating channels (for spatiotemporal progression plots later)"""
fch_zs_lmp_10 = np.copy(zs_lmp_10)
fch_zs_hga_10 = np.copy(zs_hga_10)
fch_zs_beta_10 = np.copy(zs_beta_10)

In [None]:
""" truncate channels (keep only common channels) """
zs_lmp_10 = zs_lmp_10[ch_idxs_10, :]
zs_hga_10 = zs_hga_10[ch_idxs_10, :]
zs_beta_10 = zs_beta_10[ch_idxs_10, :]

zs_lmp_12 = zs_lmp_12[ch_idxs_12, :]
zs_hga_12 = zs_hga_12[ch_idxs_12, :]
zs_beta_12 = zs_beta_12[ch_idxs_12, :]

zs_lmp_13 = zs_lmp_13[ch_idxs_13, :]
zs_hga_13 = zs_hga_13[ch_idxs_13, :]
zs_beta_13 = zs_beta_13[ch_idxs_13, :]

Merge data across sessions 10, 12, and 13

In [None]:
""" merge motor features, bands """
merged_y_true = np.concatenate((y_true_10, y_true_12, y_true_13))
merged_y_pred = np.concatenate((y_pred_10, y_pred_12, y_pred_13))

merged_zs_lmp = np.concatenate((zs_lmp_10, zs_lmp_12, zs_lmp_13), axis=-1)
merged_zs_hga = np.concatenate((zs_hga_10, zs_hga_12, zs_hga_13), axis=-1)
merged_zs_beta = np.concatenate((zs_beta_10, zs_beta_12, zs_beta_13), axis=-1)

In [None]:
""" merge time """
motion_t_10 = motion_t_10 - motion_t_10[0]
motion_t_12 = motion_t_12 - motion_t_12[0] + motion_t_10[-1] + motion_t_10[1]
motion_t_13 = motion_t_13 - motion_t_13[0] + motion_t_12[-1] + motion_t_10[1]

merged_t = np.concatenate((motion_t_10, motion_t_12))
merged_t = np.concatenate((merged_t, motion_t_13))

Plot the z-scored bands

In [None]:
""" decide on time segment to plot """
t_end = 60 # choose length of segment to plot
t_end_idx = np.where(merged_t > t_end)[0][0]

merged_zs_lmp   = merged_zs_lmp  [:,:t_end_idx]
merged_zs_hga   = merged_zs_hga  [:,:t_end_idx]
merged_zs_beta  = merged_zs_beta [:,:t_end_idx]

In [None]:
""" decide on quantile to apply for colormap range """
lmp_qcut = 0.01
vmin = np.quantile(merged_zs_lmp, lmp_qcut)
vmax = np.quantile(merged_zs_lmp, 1-lmp_qcut)
abs_vmax = max(-vmin, vmax)
lmp_vmin = -1*abs_vmax
lmp_vmax = abs_vmax

hga_qcut = 0.01
hga_vmin = np.quantile(merged_zs_hga, hga_qcut)
hga_vmax = np.quantile(merged_zs_hga, 1-hga_qcut)

beta_qcut = 0.01
beta_vmin = np.quantile(merged_zs_beta, beta_qcut)
beta_vmax = np.quantile(merged_zs_beta, 1-beta_qcut)

In [None]:
""" plot two bands """
# step 1. plot
plt.close('all')
fig, ax = plt.subplots(2, 1, figsize=(7, 2.2), sharex=True)

t0, t1 = 0, t_end
extent = [t0, t1, len(ch_idxs_10), 0]
im0 = ax[0].imshow(merged_zs_lmp, aspect='auto', vmin=lmp_vmin, vmax=lmp_vmax, 
             extent=extent, cmap='bwr', interpolation='none')
im1 = ax[1].imshow(merged_zs_hga, aspect='auto', vmin=hga_vmin, vmax=hga_vmax, 
             extent=extent, cmap='viridis')

fig.text(0.08, 0.5, 'Channels', va='center', rotation='vertical')
ax[0].invert_yaxis()
ax[1].invert_yaxis()

# step 2. add labels
left_cbar = 0.92  # Adjust left position of the colorbars
bottom_cbar0 = 0.53  # Adjust bottom position of the colorbars
width_cbar = 0.02  # Adjust width of the colorbars
height_cbar = 0.35  # Adjust height of the colorbars

cbar0_ax = fig.add_axes([left_cbar, bottom_cbar0, width_cbar, height_cbar])
cbar0 = fig.colorbar(im0, cax=cbar0_ax)
cbar0.set_ticks([-1.5, 0, 1.5])

bottom_cbar1 = 0.12  # Adjust bottom position of the colorbars
cbar1_ax = fig.add_axes([left_cbar, bottom_cbar1, width_cbar, height_cbar])
cbar1 = fig.colorbar(im1, cax=cbar1_ax)
cbar1.set_ticks([-1.5, 0, 1.5, 3])

ax[1].set_xlabel('Time (s)')

ax[0].grid(True)
ax[1].grid(True)

ax[0].set_yticks([])
ax[1].set_yticks([])
ax[1].set_xticks([0, 10, 20, 30, 40, 50, 60])

plt.subplots_adjust(hspace=0.15)

# plt.savefig(f"{SAVE_IMG_DIR}/two_bands.svg", bbox_inches='tight')
# plt.savefig(f"{SAVE_IMG_DIR}/two_bands.png", bbox_inches='tight', dpi=1200)

In [None]:
""" plot three bands """
# step 1. plot
plt.close('all')
fig, ax = plt.subplots(4, 1, figsize=(7, 2.2*2), sharex=True)

ax[0].plot(merged_t[:len(merged_y_true)]*T_DF_MOTION, merged_y_true)
ax[0].plot(merged_t[:len(merged_y_true)]*T_DF_MOTION, merged_y_pred)
ax[0].grid(True)
ax[0].set_yticks([0])
ax[0].set_xlim((0, 60)) # set the rnage to (0, 100) if you want to re-create the figure in paper

ax[0].set_ylabel('Wrist y-vel.\n(normalized)', fontsize=10)
ax[0].legend(['Observed', 'Predicted'], loc=(1.02, 0.03), fontsize=10)
ax[0].set_xticks(np.arange(0, 61, 10))

t0, t1 = 0, t_end
extent = [t0, t1, len(ch_idxs_10), 0]

im1 = ax[1].imshow(merged_zs_lmp, aspect='auto', vmin=lmp_vmin, vmax=lmp_vmax, 
             extent=extent, cmap='bwr', interpolation='none')
im2 = ax[2].imshow(merged_zs_beta, aspect='auto', vmin=beta_vmin, vmax=beta_vmax, 
             extent=extent, cmap='viridis')
im3 = ax[3].imshow(merged_zs_hga, aspect='auto', vmin=hga_vmin, vmax=hga_vmax, 
             extent=extent, cmap='viridis')

ax[1].invert_yaxis()
ax[2].invert_yaxis()
ax[3].invert_yaxis()

# step 2. add labels
left_cbar = 0.92  # Adjust left position of the colorbars
width_cbar = 0.02  # Adjust width of the colorbars
height_cbar = 0.17  # Adjust height of the colorbars

bottom_cbar1 = 0.51  # Adjust bottom position of the colorbars
bottom_cbar2 = 0.31  # Adjust bottom position of the colorbars
bottom_cbar3 = 0.11  # Adjust bottom position of the colorbars

cbar1_ax = fig.add_axes([left_cbar, bottom_cbar1, width_cbar, height_cbar])
cbar2_ax = fig.add_axes([left_cbar, bottom_cbar2, width_cbar, height_cbar])
cbar3_ax = fig.add_axes([left_cbar, bottom_cbar3, width_cbar, height_cbar])

cbar1 = fig.colorbar(im1, cax=cbar1_ax)
cbar2 = fig.colorbar(im2, cax=cbar2_ax)
cbar3 = fig.colorbar(im3, cax=cbar3_ax)
cbar1.set_ticks([-1.5, 0, 1.5])
cbar2.set_ticks([-1, 0, 1.5, 3])
cbar3.set_ticks([-1.5, 0, 1.5, 3])

for r in [1, 2, 3]:
    ax[r].grid(True)
    ax[r].set_yticks([])

ax[1].set_ylabel('Channels\n(LMP)', fontsize=12)
ax[2].set_ylabel('Channels\n(β)', fontsize=12)
ax[3].set_ylabel('Channels\n(High γ)', fontsize=12)

ax[-1].set_xlabel('Time (s)')
ax[-1].set_xticks([0, 10, 20, 30, 40, 50, 60])

plt.subplots_adjust(hspace=0.2)

# plt.savefig(f"{SAVE_IMG_DIR}/three_bands.svg", bbox_inches='tight')
# plt.savefig(f"{SAVE_IMG_DIR}/three_bands.png", bbox_inches='tight', dpi=1200)

Plot the spatiotemporal progression using session 10 data

In [None]:
""" pad the band data to full 16x16 channels """
padded_zs_lmp   = np.zeros((NCH, len(motion_t_10)))
padded_zs_hga   = np.zeros((NCH, len(motion_t_10)))
padded_zs_beta  = np.zeros((NCH, len(motion_t_10)))

padded_zs_lmp[good_chs_10] = fch_zs_lmp_10
padded_zs_hga[good_chs_10] = fch_zs_hga_10
padded_zs_beta[good_chs_10] = fch_zs_beta_10

In [None]:
""" decide time range to plot """
t_start, t_end = 13.5, 14.5
t_range = np.arange(t_start, t_end, (t_end-t_start)/10) # 10: number of frames to plot
frame_idxs = []
for tt in t_range:
    frame_idxs.append(np.where(motion_t_10 > tt)[0][0])

In [None]:
""" plot LMP and High-gamma """
plt.close('all')
fig, ax = plt.subplots(3, 10, figsize=(20, 5), gridspec_kw={'height_ratios': [1, 2, 2]})

ax[0, 0].set_yticks([0])
ax[0, 0].set_ylabel('Wrist y-vel.\n(normalized)', fontsize=18)
ax[1, 0].set_ylabel('LMP', fontsize=24)
ax[2, 0].set_ylabel('High γ', fontsize=24)

for idx, frame_idx in enumerate(frame_idxs):
    lmp_frame = padded_zs_lmp[:, frame_idx].reshape(16, -1)
    hga_frame = padded_zs_hga[:, frame_idx].reshape(16, -1)

    ax[0, idx].axhline(y=0, color = 'silver', linewidth=2)
    # the following two lines only works because 10 is the first session
    ax[0, idx].plot(motion_t_10[:len(y_true_10)]*T_DF_MOTION, y_true_10, linewidth=4)
    ax[0, idx].plot(motion_t_10[:len(y_pred_10)]*T_DF_MOTION, y_pred_10, linewidth=2.5)
    ax[0, idx].set_xlim((t_start-0.1, t_end))
    ax[0, idx].axvline(x = t_range[idx], color='k', linestyle='--', linewidth=4)

    ax[1, idx].imshow(lmp_frame, vmin=lmp_vmin, vmax=lmp_vmax, cmap='bwr')
    ax[2, idx].imshow(hga_frame, vmin=hga_vmin, vmax=hga_vmax, cmap='viridis')

    draw_CS_boundary(ax[1, idx])
    draw_CS_boundary(ax[2, idx])

    ax[0, idx].set_title(f'{t_range[idx]:.1f}s', fontsize=24)
    
    if idx > 0:
        ax[0, idx].set_yticks([])
    ax[0, idx].set_xticks([])
    ax[1, idx].set_xticks([])
    ax[1, idx].set_yticks([])
    ax[2, idx].set_xticks([])
    ax[2, idx].set_yticks([])

plt.subplots_adjust(hspace=0.02, wspace=0.02)

# plt.savefig(f"{SAVE_IMG_DIR}/spatiotemporal_two_bands.svg", bbox_inches='tight')
# plt.savefig(f"{SAVE_IMG_DIR}/spatiotemporal_two_bands.png", bbox_inches='tight', dpi=1200)