In [None]:
%load_ext autoreload
%autoreload 2
import os, sys

import numpy as np

from matplotlib import pyplot as plt
from matplotlib_settings import set_plot_settings, reset_plot_settings

# Set the plot settings
set_plot_settings()

# import global variables
from utils_motor_global import *
from utils_motor_model import prepare_model_data, sweep_model, build_model
from utils_motor_plot  import draw_CS_boundary

# from utils_motor_sigproc import get_mt_ch_psd, normalize_spect

# Input locations, both rooted at MODEL_INPUT_DIR (not defined in this notebook;
# presumably provided by the star import from utils_motor_global).
ROOT_SCALOMAT_DIR = f'{MODEL_INPUT_DIR}/scalogram_matrix'
ROOT_MOTION_DIR = f'{MODEL_INPUT_DIR}/motion'

# Output root for decoded velocities; the save cell creates one sub-directory per session key under it.
ROOT_SAVE_DIR = f'{MODEL_OUTPUT_DIR}/decoded'
# Image output root, disabled together with the commented-out figure-saving code that refers to it.
# SAVE_IMG_DIR = f'{MODEL_OUTPUT_DIR}/decoded_img'

Part 1. Build Decoder

In [None]:
""" After the sweep, it has been decided that these params will be used """
# All sweepable params (forwarded to prepare_model_data / sweep_model / build_model in later cells)
SEL_CHANNELS = 'all'  # only 'all' is handled by the coefficient-reshape cell; the M1/S1 variants are commented out there
SEL_INCLUDE_LFS = False  # passed straight through to prepare_model_data
SEL_MODEL = 'pls' # model type: 'pls' or 'ridge' (the two values the hyperparameter-grid cell recognises)

# Band labels used as bar-plot categories; their count is also the size of the band axis
# when the model coefficients are reshaped. Trailing spaces are part of the labels.
band_strs = ['LMP', 'β ', 'Low γ ', 'High γ']

In [None]:
""" load common channels """
# Array stored next to the scalogram matrices; used as `good_chs` (channel axis length)
# in the coefficient-reshape cell when SEL_CHANNELS == 'all'.
common_chs = np.load(f'{ROOT_SCALOMAT_DIR}/common_good_channels.npy')

In [None]:
# Load the stored spectral mean / sigma arrays from BAND_MU_SIGMA_DIR.
# NOTE(review): neither array is referenced in the cells visible here — confirm they are
# still needed before keeping this load.
spect_mu  = np.load(f'{BAND_MU_SIGMA_DIR}/spect_mu.npy')
spect_sigma  = np.load(f'{BAND_MU_SIGMA_DIR}/spect_sigma.npy')

In [None]:
# Build the model inputs and wrist-velocity targets.
#   X                    : model input handed to sweep_model / build_model
#   (_, _, _)            : first returned triple, discarded
#   wrist_vel_{x,y,z}s   : per-axis targets selected via DIMENSION
#   ssids                : per-sample session index, matched against positions in SESSION_KEYS when saving
#   taus                 : its length sizes the third axis when the coefficients are reshaped
# NOTE(review): the meaning of the two trailing None arguments is not visible here —
# confirm against prepare_model_data.
(X,
 (_, _, _),
 (wrist_vel_xs, wrist_vel_ys, wrist_vel_zs),
 ssids,
 taus) = prepare_model_data(
    ROOT_SCALOMAT_DIR, ROOT_MOTION_DIR,
    TAU_START, TAU_END, TAU_DF, T_DF_MOTION,
    SEL_INCLUDE_LFS, SEL_CHANNELS, None, None,
)

Part 2. Plot Decoder for X, Y, Z

In [None]:
""" manually repeat for DIMENSION = 'x', 'y', 'z' """ 
# One of 'x', 'y', 'z': selects the target / OPT_HPARAM pair in the next cell and is
# embedded in the names of the saved .npy files.
DIMENSION = 'y'

In [None]:
# Pick the regression target and the previously chosen optimum hyperparameter for the
# selected axis. The x axis uses the absolute value of the velocity; y and z use the
# signed velocity.
# Unsupported values raise instead of silently reusing `y` / `OPT_HPARAM` left over
# from an earlier run of this cell.
if DIMENSION == 'x':
    y, OPT_HPARAM = np.abs(wrist_vel_xs), X_OPT_HPARAM
elif DIMENSION == 'y':
    y, OPT_HPARAM = wrist_vel_ys, Y_OPT_HPARAM
elif DIMENSION == 'z':
    y, OPT_HPARAM = wrist_vel_zs, Z_OPT_HPARAM
else:
    raise ValueError(f"DIMENSION must be 'x', 'y' or 'z', got {DIMENSION!r}")

# Hyperparameter grid swept for the selected model type:
#   'pls'   -> integers 1..9
#   'ridge' -> 7 log-spaced values from 1e3 to 1e6 (two per decade, endpoints included)
if SEL_MODEL == 'pls':
    hparam_range = np.arange(1, 10)
elif SEL_MODEL == 'ridge':
    hparam_range = np.logspace(3, 6, (6-3)*2 + 1)
else:
    raise ValueError(f"SEL_MODEL must be 'pls' or 'ridge', got {SEL_MODEL!r}")

In [None]:
# Run the hyperparameter sweep, then pull each metric out of the result.
sweep_result = sweep_model(N_SPLIT, X, y, SEL_MODEL, hparam_range)

# For each metric key, build a list (in hparam_range order) holding one array per hyperparameter.
# Keys: PRESS, r2 (coef. of determination), mean squared error, correlation coef.
presses, r2s, mses, rs = (
    [np.array(sweep_result[hparam][metric_key]) for hparam in hparam_range]
    for metric_key in ('presses', 'r2s', 'mses', 'rs')
)

In [None]:
""" convert to array """
# Mean of each metric per hyperparameter (Python lists, in hparam_range order).
# Metrics: PRESS, r2 (coef. of determination), mean squared error, correlation coef.
press_avgs = list(map(np.mean, presses))
r2_avgs = list(map(np.mean, r2s))
mse_avgs = list(map(np.mean, mses))
r_avgs = list(map(np.mean, rs))

# Spread of each metric per hyperparameter, using np.std's default (ddof=0).
# NOTE(review): these are later divided by sqrt(N_SPLIT) to draw standard-error bars;
# confirm that ddof=0 rather than ddof=1 is intended for that purpose.
press_stds = list(map(np.std, presses))
r2_stds = list(map(np.std, r2s))
mse_stds = list(map(np.std, mses))
r_stds = list(map(np.std, rs))

In [None]:
""" plot sweep result """
# Two stacked panels sharing the hyperparameter axis: PRESS on top, correlation coefficient below.
fig, ax = plt.subplots(2, 1, figsize=(4, 3), sharex=True)

x_range = hparam_range
xlabel_str = '# of PLS components'  # hard-coded label, independent of SEL_MODEL

# One entry per panel: (axes, means, stds, y-label, y-tick positions).
# The tick positions are literal values, not derived from the data.
panel_specs = [
    (ax[0], press_avgs, press_stds, 'PRESS', [400, 500, 600]),
    (ax[1], r_avgs, r_stds, 'Corr. Coeff.', [0.4, 0.5, 0.6]),
]
for panel, metric_avgs, metric_stds, ylabel_str, ytick_vals in panel_specs:
    # Hollow markers with error bars of std / sqrt(N_SPLIT), then the connecting line on top.
    panel.errorbar(x_range, metric_avgs, yerr=metric_stds/np.sqrt(N_SPLIT), fmt='o',
                   capsize=3, markersize=6, markerfacecolor='none', markeredgecolor='darkblue',
                   color='darkblue')
    panel.plot(x_range, metric_avgs, 'o-', markersize=6)
    panel.set_ylabel(ylabel_str)
    panel.set_yticks(ytick_vals)

# The x axis is shared, so ticks set on the top panel apply to both; only the bottom panel is labelled.
ax[0].set_xticks(hparam_range)
ax[-1].set_xlabel(xlabel_str)

for panel in ax:
    panel.grid(True)

# Figure export (disabled). The path separator is written as '/' here; the earlier './'
# form produced a '<dir>./file' path.
# save_img_dir = f'{ROOT_SAVE_DIR}_img/hparam_sweep'
# os.makedirs(save_img_dir, exist_ok=True)
# plt.savefig(f"{save_img_dir}/hparam_sweep_{DIMENSION}.svg", bbox_inches='tight')
# plt.savefig(f"{save_img_dir}/hparam_sweep_{DIMENSION}.png", bbox_inches='tight', dpi=1200)

In [None]:
""" select optimum hyperparameter """
# The optimum is the hyperparameter with the lowest mean PRESS; it must agree with the
# stored constant chosen for this DIMENSION, otherwise the constant is out of date.
opt_idx = np.argmin(press_avgs)
assert hparam_range[opt_idx] == OPT_HPARAM, (
    f'sweep optimum {hparam_range[opt_idx]} does not match OPT_HPARAM={OPT_HPARAM} '
    f'for DIMENSION={DIMENSION!r}, SEL_MODEL={SEL_MODEL!r}'
)

In [None]:
# Fit the final model with the selected hyperparameter.
# Returns per-split predictions, test indices, coefficients and intercepts, plus a
# tuple of score collections (PRESS, r2, MSE, r).
(y_preds,
 test_idxs,
 coefs,
 intercepts,
 (opt_presses, opt_r2s, opt_mses, opt_rs)) = build_model(N_SPLIT, X, y, SEL_MODEL, OPT_HPARAM)

In [None]:
# Join the per-split predictions along the last axis into one array.
# NOTE(review): the save cell indexes y_preds_flat with positions taken from `ssids`
# (original sample order). That is only aligned if build_model's test folds are
# contiguous and in order; `test_idxs` is not used to reorder here — confirm.
y_preds_flat = np.concatenate(y_preds, axis=-1)
# Report mean correlation and its standard error (np.std default ddof=0, divided by sqrt(N_SPLIT)).
print(f'correlation coef: {np.mean(opt_rs):.2f} ± {np.std(opt_rs)/np.sqrt(N_SPLIT):.2f} (mean ± SE)')

In [None]:
""" save """
# For every session, write the observed target, the decoded prediction and the matching
# time axis to ROOT_SAVE_DIR/<session key>/ as .npy files.
for idx, key in enumerate(SESSION_KEYS):
    save_data_dir = f'{ROOT_SAVE_DIR}/{key}'
    # save_img_dir = f'{SAVE_IMG_DIR}/{key}'
    # exist_ok=True replaces the exists()-then-makedirs() pair: no race between the
    # check and the creation, and re-running the cell is harmless.
    os.makedirs(save_data_dir, exist_ok=True)
    # os.makedirs(save_img_dir, exist_ok=True)

    # save data
    # Samples of this session are those whose ssids entry equals the session's position
    # in SESSION_KEYS.
    # NOTE(review): indexing y_preds_flat with these positions assumes the concatenated
    # predictions are in the same sample order as y — confirm against build_model.
    session_idxs = np.where(ssids == idx)[0]
    session_y = y[session_idxs]
    session_y_pred = y_preds_flat[session_idxs]

    # Time axis: one sample every T_STEP_SCALO*T_DF_MOTION, starting at 0.
    t = np.arange(0, len(session_y_pred))*T_STEP_SCALO*T_DF_MOTION

    np.save(f'{save_data_dir}/vel_{DIMENSION}_observed_{key}.npy', session_y)
    np.save(f'{save_data_dir}/vel_{DIMENSION}_predicted_{key}.npy', session_y_pred)
    # The time file name does not include DIMENSION, so each run overwrites it.
    np.save(f'{save_data_dir}/t_{key}.npy', t)

    # save plot (disabled)
    # fig, ax = plt.subplots(figsize=(min(int(t[-1]*0.3), 15), 2))
    # ax.plot(t, session_y)
    # ax.plot(t, session_y_pred)

    # ax.grid(True)
    # ax.set_yticks([0])
    # ax.set_xlim(t[0], t[-1])
    # ax.set_xlabel('Time (s)')
    # ax.set_ylabel(f'Wrist {DIMENSION}-vel.\n(normalized)')

    # plt.savefig(f"{save_img_dir}/vel_{DIMENSION}_decoded.svg", bbox_inches='tight')
    # plt.savefig(f"{save_img_dir}/vel_{DIMENSION}_decoded.png", bbox_inches='tight', dpi=1200)
    # plt.close(fig)

In [None]:
# from utils_motor_model import plot_decoded_y

# plt.close('all')
# fig, ax = plt.subplots(N_SPLIT, 1, figsize=(12, 8),
#                        sharex=True) # , sharey=True)

# title_str = f'PLS Component: #{OPT_HPARAM}, r: {np.mean(opt_rs):.2f}'
# if DIMENSION == 'x':
#     title_str += f'\nWrist {DIMENSION}-speed (normalized) vs. CV'
# else:
#     title_str += f'\nWrist {DIMENSION}-velocity (normalized) vs. CV'

# plot_decoded_y(fig, ax, N_SPLIT, y, y_preds, test_idxs, opt_rs, title_str)

# ax[0].set_title(title_str)
# ax[-1].set_xlabel('Time (sec)')

# plt.savefig(f"{SAVE_IMG_DIR}./full_decoder_{DIMENSION}.svg", bbox_inches='tight')
# plt.savefig(f"{SAVE_IMG_DIR}./full_decoder_{DIMENSION}.png", bbox_inches='tight', dpi=1200)

Part 3. Plot Coefficient Contribution

In [None]:
# Resolve the channel list that defines the channel axis of the coefficient array.
# Only 'all' is supported; anything else raises instead of reusing a `good_chs`
# left over from an earlier run (or failing later with a NameError).
if SEL_CHANNELS == 'all':
    good_chs = common_chs
# elif SEL_CHANNELS == 'M1':  good_chs = M1_ch_idxs
# elif SEL_CHANNELS == 'S1':  good_chs = S1_ch_idxs
else:
    raise ValueError(f"unsupported SEL_CHANNELS {SEL_CHANNELS!r}; only 'all' is handled here")

# Stack the per-split coefficients, drop singleton axes, and unfold the flat feature axis
# into (split, channel, tau, band).
# NOTE(review): this assumes the feature columns produced by prepare_model_data are
# ordered channel-major, then tau, then band — confirm against that function.
coefs = np.squeeze(np.array(coefs))
coefs = coefs.reshape(N_SPLIT, len(good_chs), len(taus), len(band_strs))
# Coefficients averaged over splits: shape (channel, tau, band).
coef_avg = np.mean(coefs, axis=0)

In [None]:
""" relative weight, spectral"""
# Per-split, per-band weight: |coef| summed over axis 2 first, then over axis 1,
# leaving one row per split and one column per band.
w_f = np.abs(coefs).sum(axis=2).sum(axis=1)
# Normalise by the grand total over all splits and bands, then scale by 100*N_SPLIT so
# the split-averaged contributions add up to 100 %. Individual rows need not sum to 100.
w_f = w_f / w_f.sum() * 100 * N_SPLIT

# Mean contribution per band and its standard error across splits (std uses the default ddof=0).
w_f_avg = w_f.mean(axis=0)
w_f_stderr = w_f.std(axis=0) / np.sqrt(N_SPLIT)

In [None]:
# Bar chart of the mean per-band contribution with standard-error whiskers.
fig, ax = plt.subplots(figsize=(3, 2.5))

# Dashed reference: the share each band would get if all bands contributed equally.
equal_share_pct = 1/len(band_strs)*100
ax.axhline(y=equal_share_pct, linestyle='--', color='gray')

ax.bar(band_strs, w_f_avg)
# Whiskers only: the marker face and edge are made invisible.
whisker_style = dict(fmt='.', capsize=6, color='darkblue',
                     markerfacecolor='none', markeredgecolor='none')
ax.errorbar(band_strs, w_f_avg, yerr=w_f_stderr, **whisker_style)

ax.tick_params(axis='x', labelrotation=90, labelsize=16)
ax.set_ylabel('Contribution (%)')

# Figure export (disabled; SAVE_IMG_DIR is itself commented out in the setup cell).
# plt.savefig(f"{SAVE_IMG_DIR}/contributions/freq_contribution_{DIMENSION}.svg", bbox_inches='tight')
# plt.savefig(f"{SAVE_IMG_DIR}/contributions/freq_contribution_{DIMENSION}.png", bbox_inches='tight', dpi=1200)