In [None]:
# Notebook setup: auto-reload edited local modules, import dependencies,
# apply the shared matplotlib style, and set the figure output directory.
%load_ext autoreload
%autoreload 2
import os, sys

import numpy as np

from matplotlib import pyplot as plt
from matplotlib_settings import set_plot_settings, reset_plot_settings

# Set the plot settings (project helper from matplotlib_settings)
set_plot_settings()

# import global variables
# NOTE(review): star import — presumably supplies the constants used later but not
# defined in this notebook (REC_DIR, BAND_MU_SIGMA_DIR, MODEL_INPUT_DIR, FS_ADC,
# MAX_ADC_CODE, GAIN, HB_BPF_LOW, HB_BPF_HIGH); confirm in utils_motor_global.
from utils_motor_global import *

from utils_motor_sigproc import get_mt_ch_psd, normalize_spect

# Directory (relative to the working directory) that the figures are saved into.
SAVE_IMG_DIR = './paper_figure_fodder/thesis'

Channel Spectrogram 

In [None]:
# Normalization statistics later handed to normalize_spect for z-scoring.
spect_mu, spect_sigma = (
    np.load(f'{BAND_MU_SIGMA_DIR}/spect_mu.npy'),
    np.load(f'{BAND_MU_SIGMA_DIR}/spect_sigma.npy'),
)

In [None]:
key = '010'  # session ID; every file loaded below is selected by this key

In [None]:
""" load recording before PC removal """
# Pre-PC-removal stage (directory name suggests band-pass filtered + downsampled).
rec_load_dir = f'{REC_DIR}/2_BPF_DS/{key}'
DS_t, DS_rec_data, DS_good_ch = (
    np.load(f'{rec_load_dir}/t_DS_session_{key}.npy'),
    np.load(f'{rec_load_dir}/recording_DS_session_{key}.npy').T,  # -> channel * time
    np.load(f'{rec_load_dir}/good_channels_{key}.npy'),
)

In [None]:
""" load recording after PC removal """
# Post-PC-removal stage of the same session.
rec_load_dir = f'{REC_DIR}/3_HD_remove_5_PCs/{key}'
t_HB_removed, rec_data_HB_removed, good_ch_HB_removed = (
    np.load(f'{rec_load_dir}/t_HB_removed_session_{key}.npy'),
    np.load(f'{rec_load_dir}/recording_HB_removed_session_{key}.npy').T,  # -> channel * time
    np.load(f'{rec_load_dir}/good_channels_{key}.npy'),
)

In [None]:
""" truncate data before PC removal """
# Cut the pre-removal recording to the span of the post-removal recording so the
# two can be compared sample-by-sample. The endpoints are found by exact timestamp
# match, which assumes both recordings share the same time grid.
idx0 = np.where(DS_t == t_HB_removed[0])[0][0]
idx1 = np.where(DS_t == t_HB_removed[-1])[0][0] + 1  # exclusive end

DS_t = DS_t[idx0:idx1]
DS_rec_data = DS_rec_data[:, idx0:idx1]

assert DS_rec_data.shape == rec_data_HB_removed.shape

In [None]:
""" load spectrogram data """
# recording data: good channels, spectrogram, lmp, and the spectrogram's
# frequency / time axes
spect_load_dir = f'{REC_DIR}/6_truncate/{key}'
good_chs_10, spect_10, lmp_data_10, spect_f_10, spect_t_10 = (
    np.load(f'{spect_load_dir}/good_channels_{key}.npy'),
    np.load(f'{spect_load_dir}/spect_{key}.npy'),
    np.load(f'{spect_load_dir}/lmp_{key}.npy'),
    np.load(f'{spect_load_dir}/spect_f_{key}.npy'),
    np.load(f'{spect_load_dir}/spect_t_{key}.npy'),
)

In [None]:
""" load motion data """
# Model target (normalized wrist y-velocity) and its time stamps.
motion_dir = f'{MODEL_INPUT_DIR}/motion/{key}'
y_10, motion_t_10 = (
    np.load(f'{motion_dir}/norm_wrist_vel_y_{key}.npy'),
    np.load(f'{motion_dir}/model_t_{key}.npy'),
)

In [None]:
# Sanity check: start / end times of recording, spectrogram and motion before alignment.
(t_HB_removed[0], t_HB_removed[-1],
 spect_t_10[0], spect_t_10[-1],
 motion_t_10[0], motion_t_10[-1])

In [None]:
""" truncate the recording and spectrogram to match motion (model) time window """
# Restrict the recordings and the spectrogram to the motion (model) time window.
# DS_t / DS_rec_data share t_HB_removed's time grid (they were cut to it in the
# earlier truncation cell), so they must be cut with the same indices; otherwise
# later slices such as DS_rec_data[:, rec_idx0:rec_idx1] are offset in time
# relative to rec_data_HB_removed.
rec_idx0 = np.where(t_HB_removed >= motion_t_10[0])[0][0]
rec_idx1 = np.where(t_HB_removed <= motion_t_10[-1])[0][-1] + 1  # exclusive end

spect_idx0 = np.where(spect_t_10 >= motion_t_10[0])[0][0]
spect_idx1 = np.where(spect_t_10 <= motion_t_10[-1])[0][-1] + 1  # exclusive end
assert spect_idx1 - spect_idx0 == len(motion_t_10)

t_HB_removed = t_HB_removed[rec_idx0:rec_idx1]
rec_data_HB_removed = rec_data_HB_removed[:, rec_idx0:rec_idx1]
DS_t = DS_t[rec_idx0:rec_idx1]
DS_rec_data = DS_rec_data[:, rec_idx0:rec_idx1]
assert DS_rec_data.shape == rec_data_HB_removed.shape

spect_t_10 = spect_t_10[spect_idx0:spect_idx1]
spect_10 = spect_10[:, spect_idx0:spect_idx1, :]

In [None]:
""" force the all times to start at zero """
# Shift each time axis in place so that it starts at 0 s.
# NOTE(review): each axis is referenced to its own first sample, so any sub-sample
# start offset between the streams is discarded.
np.subtract(t_HB_removed, t_HB_removed[0], out=t_HB_removed)
np.subtract(spect_t_10, spect_t_10[0], out=spect_t_10)
np.subtract(motion_t_10, motion_t_10[0], out=motion_t_10)

t_HB_removed[0], t_HB_removed[-1], spect_t_10[0], spect_t_10[-1], motion_t_10[0], motion_t_10[-1]

In [None]:
""" z-score spectrogram """
# z-score the session spectrogram with the spect_mu / spect_sigma statistics.
zs_spect_10 = normalize_spect(
    good_chs_10,
    spect_10,
    spect_mu,
    spect_sigma,
    sel_zscore=True,
)

In [None]:
from scipy.signal.windows import dpss

# Multi-taper parameters for the analysed / plotted window [t_start, t_end].
t_start, t_end = 0, 20 # seconds
Ts = t_HB_removed[1] - t_HB_removed[0]  # sample period (assumes uniform sampling)
fs = 1/Ts

rec_idx0 = np.where(t_HB_removed >= t_start)[0][0]
rec_idx1 = np.where(t_HB_removed <= t_end)[0][-1] + 1  # exclusive end
win_size = rec_idx1 - rec_idx0
n = win_size

# Window duration from the sample count rather than t_HB_removed[rec_idx1], which
# indexes one past the window and raises IndexError when t_end reaches the last sample.
len_win = win_size*Ts

W = 0.5 # Hz. frequency smoothing: [f0 - W , f0 + W]
NW = len_win*W # time-bandwidth product; common choices are 2.5, 3, 3.5, 4
K = max(1, int(2*NW - 1)) # Number of tapers (at least one, so dpss never gets K <= 0)
print(f'{NW=:.2f}, {K=}')

half_n = int(np.ceil(win_size/2))
freq = np.fft.fftfreq(win_size, d = 1/fs)
half_freq = freq[:half_n]  # non-negative FFT frequencies
fbin = half_freq[1]        # frequency spacing (fs / win_size)

# DPSS tapers plus their concentration ratios
dpss_tapers, dpss_eigen = dpss(n, NW, K, return_ratios=True)
wt = np.ones(K)/K # apply unity weight

In [None]:
""" plot motion, time domain trace, spectrogram of a single channel """
r, c = 3, 8
ch = 16*r + c # pick a channel; assumes 16 channels per grid row — TODO confirm

rec_ch_idx   = np.where(good_ch_HB_removed == ch)[0][0]
spect_ch_idx = np.where(good_chs_10 == ch)[0][0]

# motion params: samples inside [t_start, t_end]
motion_idx0 = np.where(motion_t_10 >= t_start)[0][0]
motion_idx1 = np.where(motion_t_10 <= t_end)[0][-1] + 1  # exclusive end

# rec params: the same channel before (DS_*) and after HD removal, scaled by
# FS_ADC/MAX_ADC_CODE/GAIN (presumably ADC codes -> volts; plotted below in μV)
assert np.array_equal(good_ch_HB_removed, DS_good_ch)  # same channel ordering
ch_data = rec_data_HB_removed[rec_ch_idx, rec_idx0:rec_idx1]
ch_data = ch_data*FS_ADC/MAX_ADC_CODE/GAIN
DS_ch_data = DS_rec_data[rec_ch_idx, rec_idx0:rec_idx1]
DS_ch_data = DS_ch_data*FS_ADC/MAX_ADC_CODE/GAIN

# spectrogram params
spect_idx0 = np.where(spect_t_10 >= t_start)[0][0]
spect_idx1 = np.where(spect_t_10 <= t_end)[0][-1] + 1  # exclusive end

ch_zs_spect = zs_spect_10[spect_ch_idx, spect_idx0:spect_idx1, :]
ch_zs_spect = ch_zs_spect[::16, :] # downsample in time (keep every 16th frame)

t0, t1 = t_start, t_end
f0, f1 = int(spect_f_10[0]), round(spect_f_10[-1])
extent = [t0, t1, f1, f0]  # y-axis is inverted below so f0 ends up at the bottom
q_cut = 0.01  # clip the color scale to the [q_cut, 1 - q_cut] quantiles
vmin = np.quantile(ch_zs_spect, q_cut)
vmax = np.quantile(ch_zs_spect, 1-q_cut)
cmap = 'viridis'

plt.close('all')
k = 0.9  # overall figure scale factor
fig, ax = plt.subplots(3, 1, sharex=True, figsize=(10*k, 6*k),
                       gridspec_kw={'height_ratios': [1.5, 1.5, 2]})
fig.subplots_adjust(hspace=0.3)

ax[0].plot(motion_t_10[motion_idx0:motion_idx1], y_10[motion_idx0:motion_idx1])
ax[1].plot(t_HB_removed[rec_idx0:rec_idx1], ch_data/1e-6)  # -> μV

im = ax[2].imshow(ch_zs_spect.T, aspect='auto', vmin=vmin, vmax=vmax, extent=extent, cmap=cmap,
                  interpolation='bilinear')
ax[2].invert_yaxis()

ax[0].set_yticks([0])
ax[2].set_xticks(np.arange(t_start, t_end + 1, 5))  # 5 s ticks across the plotted window

for axis in ax:
    axis.grid(True)

# Format r and c separately: f'({r, c})' would print the tuple as "((3, 8))".
ax[0].set_title(f'Channel ({r}, {c})\n')
fontsize = 14
ax[0].set_ylabel('Wrist y-vel.\n(normalized)\n', fontsize=fontsize)
ax[1].set_ylabel('Recording (μV)', fontsize=fontsize)
ax[2].set_ylabel('Frequency (Hz)', fontsize=fontsize)

ax[2].set_xlabel('Time (s)')

# Colorbar placed in its own axes (figure-fraction coordinates) right of the spectrogram
left_cbar = 0.92     # left edge
bottom_cbar0 = 0.11  # bottom edge
width_cbar = 0.015   # width
height_cbar = 0.26   # height

cbar0_ax = fig.add_axes([left_cbar, bottom_cbar0, width_cbar, height_cbar])
cbar0 = fig.colorbar(im, ax=ax[2], cax=cbar0_ax)

# Join with a path separator: f"{SAVE_IMG_DIR}./..." wrote into a "thesis." directory.
os.makedirs(SAVE_IMG_DIR, exist_ok=True)
fig.savefig(os.path.join(SAVE_IMG_DIR, 'channel_spectrogram.svg'), bbox_inches='tight')
fig.savefig(os.path.join(SAVE_IMG_DIR, 'channel_spectrogram.png'), bbox_inches='tight', dpi=1200)

In [None]:
""" plot a single channel PSD"""
assert len(ch_data) == win_size  # the DPSS tapers were built for exactly win_size samples

# Multi-taper PSD of the same channel before (DS_) and after HD removal
ch_psd = get_mt_ch_psd(ch_data, dpss_tapers, wt)
DS_ch_psd = get_mt_ch_psd(DS_ch_data, dpss_tapers, wt)

# Normalize both spectra by the peak of the post-removal PSD inside the plotted band
f_lo, f_hi = 1, 300  # Hz
psd_idx0 = np.where(half_freq >= f_lo)[0][0]
psd_idx1 = np.where(half_freq <= f_hi)[0][-1] + 1  # exclusive end
norm_val = np.max(ch_psd[psd_idx0:psd_idx1])
ch_psd = ch_psd/norm_val
DS_ch_psd = DS_ch_psd/norm_val

plt.close('all')
fig, ax = plt.subplots(figsize=(4, 3))
ax.loglog(half_freq, DS_ch_psd)
ax.loglog(half_freq, ch_psd)

ax.set_xlim((f_lo, f_hi))
ax.set_ylim((1e-5, 2))
ax.set_xlabel('Frequency (Hz)')
ax.set_ylabel('PSD (a.u.)')

# Format r and c separately: f'({r, c})' would print the tuple as "((3, 8))".
ax.set_title(f'Channel ({r}, {c})')
ax.grid(True)

# Edges of the HB band-pass band (presumably the band targeted by the removal)
ax.axvline(x=HB_BPF_LOW, linestyle='--', color='k')
ax.axvline(x=HB_BPF_HIGH, linestyle='--', color='k')

ax.legend(['Before HD Removal', 'After HD Removal'], loc=(1.02, 0))

# Join with a path separator: f"{SAVE_IMG_DIR}./..." wrote into a "thesis." directory.
os.makedirs(SAVE_IMG_DIR, exist_ok=True)
fig.savefig(os.path.join(SAVE_IMG_DIR, 'channel_psd.svg'), bbox_inches='tight')
fig.savefig(os.path.join(SAVE_IMG_DIR, 'channel_psd.png'), bbox_inches='tight', dpi=1200)

In [None]:
# Zoomed time-domain comparison of the channel before / after HD removal.
zoom_len = 5  # seconds shown, ending at t_end (15-20 s for the default window)

plt.close('all')
fig, ax = plt.subplots(2, 1, sharex=True, sharey=True, figsize=(4, 3))
ax[0].plot(t_HB_removed[rec_idx0:rec_idx1], DS_ch_data/1e-6)  # -> μV
ax[1].plot(t_HB_removed[rec_idx0:rec_idx1], ch_data/1e-6)
ax[0].set_xlim((t_end - zoom_len, t_end))

fontsize = 12
ax[0].set_ylabel('Before HD\nRemoval (μV)', fontsize=fontsize)
ax[1].set_ylabel('After HD\nRemoval (μV)', fontsize=fontsize)
ax[1].set_xlabel('Time (s)')

# Join with a path separator: f"{SAVE_IMG_DIR}./..." wrote into a "thesis." directory.
os.makedirs(SAVE_IMG_DIR, exist_ok=True)
fig.savefig(os.path.join(SAVE_IMG_DIR, 'HD_effect_channel_td.svg'), bbox_inches='tight')
fig.savefig(os.path.join(SAVE_IMG_DIR, 'HD_effect_channel_td.png'), bbox_inches='tight', dpi=1200)