In [None]:
"""
This notebook tests out different model input parameters, eventually converging to the set
that is most useful.
"""

In [None]:
%load_ext autoreload
%autoreload 2
import os, sys

import numpy as np

from matplotlib import pyplot as plt
from matplotlib_settings import set_plot_settings, reset_plot_settings

# Set the plot settings
set_plot_settings()

# import global variables
# NOTE(review): star import. MODEL_INPUT_DIR and MODEL_OUTPUT_DIR are read right below,
# and later cells use N_SPLIT and T_DF_MOTION without defining them; presumably all four
# come from utils_motor_global -- confirm.
from utils_motor_global import *

# Root directories used by the cells below: scalogram-matrix inputs, motion inputs,
# and the directory the parameter-sweep results are written to.
ROOT_SCALOMAT_DIR = f'{MODEL_INPUT_DIR}/scalogram_matrix'
ROOT_MOTION_DIR = f'{MODEL_INPUT_DIR}/motion'
ROOT_SAVE_DIR = f'{MODEL_OUTPUT_DIR}/param_sweep'


In [None]:
""" load common channels """
common_chs = np.load(f'{ROOT_SCALOMAT_DIR}/common_good_channels.npy')
M1_chs = np.load(f'{ROOT_SCALOMAT_DIR}/M1_channels.npy')
S1_chs = np.load(f'{ROOT_SCALOMAT_DIR}/S1_channels.npy')

# Positions inside common_chs of the channels that also appear in the M1 / S1 lists.
# np.flatnonzero always returns an integer array, so an empty match still gives a valid
# (empty) index for common_chs[...]; np.array([]) built from an empty list would be
# float64 and cannot be used as an index.
# NOTE(review): assumes the three arrays are 1-D lists of channel identifiers -- confirm.
M1_ch_idxs = np.flatnonzero(np.isin(common_chs, M1_chs))
S1_ch_idxs = np.flatnonzero(np.isin(common_chs, S1_chs))

In [None]:
# Rather than using PRESS, this notebook used the coefficient of determination (r2)
# as the metric for choosing the best PLS model and the correlation coefficient (r)
# for choosing the best Ridge model

In [None]:
# This notebook was originally run with T_DF_MOTION = 10, but for the paper, this parameter
# was later modified to 20.
# T_DF_MATRIX = 2 # this param was set during model input generation
# T_DF_MOTION = 10 # decimation factor for motion data. 20: 5Hz, 10: 10Hz, 5: 20Hz

In [None]:
""" Full List of Sweepable Parameters """ 
# TAU_START = [-1, -0.75, -0.5, -0.25]
# TAU_END = [0, 0.25, 0.5, 0.75, 1]
# TAU_DF = [20, 10, 5, 2, 1] # decimation factor for model input matrix time resolution
# SEL_INCLUDE_LFS = [True, False]
# SEL_CHANNELS = ['all', 'M1', 'S1']
# SEL_VELOCITY = [True, False]
# DIMENSION = ['x', 'y', 'z']
# MODEL = ['pls', 'ridge']

# not part of this notebook. SEL_INPUT was part of the playground notebook
# SEL_INPUT = ['linear', 'quadratic', 'composite'] # both quadratic and composite are significantly worse than linear

In [None]:
"""
From the playground, it was determined that:
-Band merged spectrogram performs better than full frequency spectrogram
-Normalization vs Z-scoring has negligible effect
-PLS has worse 'r' than Ridge, but the time domain waveform looks more appealing
-Linear model performs better than composite or quadratic model
"""

In [None]:
from utils_motor_model import prepare_model_data, sweep_model, build_model
from utils_motor_model import plot_decoded_y, compute_and_plot_model_coeff_contributions

In [None]:
""" 1st pass. Determine the meaningful tau window, post-feature """
import itertools, pickle

# Sweeps TAU_ENDS with every other setting held fixed. For each configuration the model
# hyperparameter is swept, the model is rebuilt with the best value, and the result
# (one pickle + two figures) is written under {ROOT_SAVE_DIR}/1st_pass/.
# NOTE(review): N_SPLIT and T_DF_MOTION are not defined in this cell; presumably they
# come from `from utils_motor_global import *` -- confirm.

# input matrix param permutation
TAU_STARTS = [-1] # [-1, -0.75, -0.5, -0.25]
TAU_ENDS = [0, 0.25, 0.5, 0.75, 1]
TAU_DFS = [10] # [20, 10, 5, 2, 1] # decimation factor for model input matrix time resolution. 10 gives 100 msec
SEL_INCLUDE_LFS = [False] # [True, False]
SEL_CHANNELS = ['all'] # ['all', 'M1', 'S1']

# motor feature permutation
SEL_VELOCITY = [True] # [True, False]
DIMENSIONS = ['y'] # ['x', 'y', 'z']

# model selector
MODELS = ['pls'] # ['pls', 'ridge']

# value handed to compute_and_plot_model_coeff_contributions as q_cut
q_cut = 0.01

matrix_perms = list(itertools.product(TAU_STARTS, TAU_ENDS, TAU_DFS,
                                     SEL_INCLUDE_LFS, SEL_CHANNELS))
motion_perms = list(itertools.product(SEL_VELOCITY, DIMENSIONS))

for matrix_params in matrix_perms:
    tau_start, tau_end, tau_df, sel_include_lfs, sel_channels = matrix_params

    # channel labels for the selected channel group (used by the coefficient plot)
    if sel_channels == 'all':  chs = common_chs
    elif sel_channels == 'M1': chs = common_chs[M1_ch_idxs]
    elif sel_channels == 'S1': chs = common_chs[S1_ch_idxs]
    else: raise ValueError(f'unknown channel selection: {sel_channels!r}')

    # band labels for the coefficient plot; 'LFS' is only listed when it is included
    if sel_include_lfs: band_strs = ['LMP', 'LFS', 'β ', 'Low γ ', 'High γ']
    else:               band_strs = ['LMP', 'β ', 'Low γ ', 'High γ']

    X, (wrist_pos_xs, wrist_pos_ys, _), (wrist_vel_xs, wrist_vel_ys, _), _, taus = \
        prepare_model_data(ROOT_SCALOMAT_DIR, ROOT_MOTION_DIR,
                        tau_start, tau_end, tau_df, T_DF_MOTION,
                        sel_include_lfs, sel_channels, M1_ch_idxs, S1_ch_idxs)

    for motion_params in motion_perms:
        sel_velocity, dimension = motion_params

        # pick the regression target; only the x and y components are unpacked above
        if sel_velocity:
            if   dimension == 'x': y = wrist_vel_xs
            elif dimension == 'y': y = wrist_vel_ys
            else: raise ValueError(f"unsupported dimension: {dimension!r} (only 'x' and 'y' are loaded)")
        else:
            if   dimension == 'x': y = wrist_pos_xs
            elif dimension == 'y': y = wrist_pos_ys
            else: raise ValueError(f"unsupported dimension: {dimension!r} (only 'x' and 'y' are loaded)")

        for sel_model in MODELS:
            if sel_model == 'pls':     hparam_range = np.arange(3, 10)
            elif sel_model == 'ridge': hparam_range = np.logspace(3, 6, (6-3)*2 + 1)
            else: raise ValueError(f'unknown model: {sel_model!r}')

            sweep_result = sweep_model(N_SPLIT, X, y, sel_model, hparam_range)

            # find the optimum hparam: one metric array per swept hparam value
            presses = [np.array(sweep_result[hparam]['presses']) for hparam in hparam_range]
            r2s = [np.array(sweep_result[hparam]['r2s']) for hparam in hparam_range]
            mses = [np.array(sweep_result[hparam]['mses']) for hparam in hparam_range]
            rs = [np.array(sweep_result[hparam]['rs']) for hparam in hparam_range]
            press_avgs = [np.mean(press) for press in presses]
            r2_avgs = [np.mean(r2) for r2 in r2s]
            mse_avgs = [np.mean(mse) for mse in mses]
            r_avgs = [np.mean(r) for r in rs]

            # PLS is selected by the best mean r2, Ridge by the best mean r
            if sel_model == 'pls': opt_hparam = hparam_range[np.argmax(r2_avgs)]
            else:                  opt_hparam = hparam_range[np.argmax(r_avgs)]

            # build model with optimum hparam
            y_preds, test_idxs, coefs, intercepts, (opt_presses, opt_r2s, opt_mses, opt_rs) = \
                build_model(N_SPLIT, X, y, sel_model, opt_hparam)

            # save results
            motion_tag = 'vel' if sel_velocity else 'pos'
            save_data_dir = f'{ROOT_SAVE_DIR}/1st_pass/{dimension}_{motion_tag}'
            save_img_dir = f'{save_data_dir}_img'
            os.makedirs(save_data_dir, exist_ok=True)
            os.makedirs(save_img_dir, exist_ok=True)

            fn_substr = f'{sel_model}_{int(opt_hparam)}_tau_win_{int(tau_start*1000)}'
            fn_substr += f'_{int(tau_end*1000)}_{int(10*tau_df)}'

            opt_result = {
                'y': y,
                'y_preds': y_preds,
                'coefs': coefs,
                'intercepts': intercepts,
                'presses': opt_presses,
                'r2s': opt_r2s,
                'mses': opt_mses,
                'rs': opt_rs
            }

            with open(f'{save_data_dir}/{fn_substr}.pkl', 'wb') as f:
                pickle.dump(opt_result, f)

            # save plots
            save_img_fn0, save_img_fn1 = f'y_pred_{fn_substr}', f'coefs_{fn_substr}'

            if sel_velocity: title_str = f'Wrist {dimension}-velocity vs. CV'
            else:            title_str = f'Wrist {dimension}-position vs. CV'
            # Report the r of the model built with opt_hparam. `rs` holds the r values of
            # every swept hparam, so averaging it would not describe the chosen model.
            opt_r_avg = np.mean(opt_rs)
            if sel_model == 'pls':
                title_str += f'\nPLS Component: #{opt_hparam}, r: {opt_r_avg:.2f}'
            else:
                # ':g' keeps half-decade exponents (e.g. 1e3.5) instead of truncating them
                title_str += f'\nRidge alpha: 1e{np.log10(opt_hparam):g}, r: {opt_r_avg:.2f}'

            fig, ax = plt.subplots(N_SPLIT, 1, figsize=(12, 8), sharex=True) # , sharey=True)
            # plot_decoded_y returns the title string that is reused for the next figure
            title_str = plot_decoded_y(fig, ax, N_SPLIT, y, y_preds, test_idxs, opt_rs,
                                       title_str)
            fig.savefig(f'{save_img_dir}/{save_img_fn0}.png', bbox_inches='tight')
            plt.close(fig)

            fig, ax = plt.subplots(1, 3, figsize=(11, 5),
                                   gridspec_kw={'width_ratios': [4, 3, 3]})
            fig.subplots_adjust(wspace=0.3)
            compute_and_plot_model_coeff_contributions(fig, ax, coefs, N_SPLIT, chs, taus,
                                                       band_strs, q_cut=q_cut, title_str=title_str)
            fig.tight_layout()
            fig.savefig(f'{save_img_dir}/{save_img_fn1}.png', bbox_inches='tight')
            plt.close(fig)


In [None]:
# 1st pass best TAU_ENDS (If r is tied, smaller range wins)
# x-pos. PLS: 0.25, Ridge: 0.25
# x-vel. PLS: 0.50, Ridge: 0.50
# y-pos. PLS: 0.25, Ridge: 0
# y-vel. PLS: 0.25, Ridge: 0.25
# choose value that is larger than the optimum (allow headroom)
# result of x isn't relevant unless r becomes good enough for publication 
""" 2nd pass. Determine the meaningful tau window, pre-feature """
import itertools, pickle

# Sweeps TAU_STARTS with every other setting held fixed. For each configuration the model
# hyperparameter is swept, the model is rebuilt with the best value, and the result
# (one pickle + two figures) is written under {ROOT_SAVE_DIR}/2nd_pass/.
# NOTE(review): N_SPLIT and T_DF_MOTION are not defined in this cell; presumably they
# come from `from utils_motor_global import *` -- confirm.

# input matrix param permutation
TAU_STARTS = [-1, -0.75, -0.5, -0.25]
TAU_ENDS = [0.5]
TAU_DFS = [10] # [20, 10, 5, 2, 1] # decimation factor for model input matrix time resolution. 10 gives 100 msec
SEL_INCLUDE_LFS = [False] # [True, False]
SEL_CHANNELS = ['all'] # ['all', 'M1', 'S1']

# motor feature permutation
SEL_VELOCITY = [True] # [True, False]
DIMENSIONS = ['y'] # ['x', 'y', 'z']

# model selector
MODELS = ['pls'] # ['pls', 'ridge']

# value handed to compute_and_plot_model_coeff_contributions as q_cut
q_cut = 0.01

matrix_perms = list(itertools.product(TAU_STARTS, TAU_ENDS, TAU_DFS,
                                     SEL_INCLUDE_LFS, SEL_CHANNELS))
motion_perms = list(itertools.product(SEL_VELOCITY, DIMENSIONS))

for matrix_params in matrix_perms:
    tau_start, tau_end, tau_df, sel_include_lfs, sel_channels = matrix_params

    # channel labels for the selected channel group (used by the coefficient plot)
    if sel_channels == 'all':  chs = common_chs
    elif sel_channels == 'M1': chs = common_chs[M1_ch_idxs]
    elif sel_channels == 'S1': chs = common_chs[S1_ch_idxs]
    else: raise ValueError(f'unknown channel selection: {sel_channels!r}')

    # band labels for the coefficient plot; 'LFS' is only listed when it is included
    if sel_include_lfs: band_strs = ['LMP', 'LFS', 'β ', 'Low γ ', 'High γ']
    else:               band_strs = ['LMP', 'β ', 'Low γ ', 'High γ']

    X, (wrist_pos_xs, wrist_pos_ys, _), (wrist_vel_xs, wrist_vel_ys, _), _, taus = \
        prepare_model_data(ROOT_SCALOMAT_DIR, ROOT_MOTION_DIR,
                        tau_start, tau_end, tau_df, T_DF_MOTION,
                        sel_include_lfs, sel_channels, M1_ch_idxs, S1_ch_idxs)

    for motion_params in motion_perms:
        sel_velocity, dimension = motion_params

        # pick the regression target; only the x and y components are unpacked above
        if sel_velocity:
            if   dimension == 'x': y = wrist_vel_xs
            elif dimension == 'y': y = wrist_vel_ys
            else: raise ValueError(f"unsupported dimension: {dimension!r} (only 'x' and 'y' are loaded)")
        else:
            if   dimension == 'x': y = wrist_pos_xs
            elif dimension == 'y': y = wrist_pos_ys
            else: raise ValueError(f"unsupported dimension: {dimension!r} (only 'x' and 'y' are loaded)")

        for sel_model in MODELS:
            if sel_model == 'pls':     hparam_range = np.arange(3, 10)
            elif sel_model == 'ridge': hparam_range = np.logspace(3, 6, (6-3)*2 + 1)
            else: raise ValueError(f'unknown model: {sel_model!r}')

            sweep_result = sweep_model(N_SPLIT, X, y, sel_model, hparam_range)

            # find the optimum hparam: one metric array per swept hparam value
            presses = [np.array(sweep_result[hparam]['presses']) for hparam in hparam_range]
            r2s = [np.array(sweep_result[hparam]['r2s']) for hparam in hparam_range]
            mses = [np.array(sweep_result[hparam]['mses']) for hparam in hparam_range]
            rs = [np.array(sweep_result[hparam]['rs']) for hparam in hparam_range]
            press_avgs = [np.mean(press) for press in presses]
            r2_avgs = [np.mean(r2) for r2 in r2s]
            mse_avgs = [np.mean(mse) for mse in mses]
            r_avgs = [np.mean(r) for r in rs]

            # PLS is selected by the best mean r2, Ridge by the best mean r
            if sel_model == 'pls': opt_hparam = hparam_range[np.argmax(r2_avgs)]
            else:                  opt_hparam = hparam_range[np.argmax(r_avgs)]

            # build model with optimum hparam
            y_preds, test_idxs, coefs, intercepts, (opt_presses, opt_r2s, opt_mses, opt_rs) = \
                build_model(N_SPLIT, X, y, sel_model, opt_hparam)

            # save results
            motion_tag = 'vel' if sel_velocity else 'pos'
            save_data_dir = f'{ROOT_SAVE_DIR}/2nd_pass/{dimension}_{motion_tag}'
            save_img_dir = f'{save_data_dir}_img'
            os.makedirs(save_data_dir, exist_ok=True)
            os.makedirs(save_img_dir, exist_ok=True)

            fn_substr = f'{sel_model}_{int(opt_hparam)}_tau_win_{int(tau_start*1000)}'
            fn_substr += f'_{int(tau_end*1000)}_{int(10*tau_df)}'

            opt_result = {
                'y': y,
                'y_preds': y_preds,
                'coefs': coefs,
                'intercepts': intercepts,
                'presses': opt_presses,
                'r2s': opt_r2s,
                'mses': opt_mses,
                'rs': opt_rs
            }

            with open(f'{save_data_dir}/{fn_substr}.pkl', 'wb') as f:
                pickle.dump(opt_result, f)

            # save plots
            save_img_fn0, save_img_fn1 = f'y_pred_{fn_substr}', f'coefs_{fn_substr}'

            if sel_velocity: title_str = f'Wrist {dimension}-velocity vs. CV'
            else:            title_str = f'Wrist {dimension}-position vs. CV'
            # Report the r of the model built with opt_hparam. `rs` holds the r values of
            # every swept hparam, so averaging it would not describe the chosen model.
            opt_r_avg = np.mean(opt_rs)
            if sel_model == 'pls':
                title_str += f'\nPLS Component: #{opt_hparam}, r: {opt_r_avg:.2f}'
            else:
                # ':g' keeps half-decade exponents (e.g. 1e3.5) instead of truncating them
                title_str += f'\nRidge alpha: 1e{np.log10(opt_hparam):g}, r: {opt_r_avg:.2f}'

            fig, ax = plt.subplots(N_SPLIT, 1, figsize=(12, 8), sharex=True) # , sharey=True)
            # plot_decoded_y returns the title string that is reused for the next figure
            title_str = plot_decoded_y(fig, ax, N_SPLIT, y, y_preds, test_idxs, opt_rs,
                                       title_str)
            fig.savefig(f'{save_img_dir}/{save_img_fn0}.png', bbox_inches='tight')
            plt.close(fig)

            fig, ax = plt.subplots(1, 3, figsize=(11, 5),
                                   gridspec_kw={'width_ratios': [4, 3, 3]})
            fig.subplots_adjust(wspace=0.3)
            compute_and_plot_model_coeff_contributions(fig, ax, coefs, N_SPLIT, chs, taus,
                                                       band_strs, q_cut=q_cut, title_str=title_str)
            fig.tight_layout()
            fig.savefig(f'{save_img_dir}/{save_img_fn1}.png', bbox_inches='tight')
            plt.close(fig)

In [None]:
# 2nd pass best TAU_STARTS (If r is tied, smaller range wins)
# x-pos. PLS: -0.50, Ridge: -0.50
# x-vel. PLS: -0.25, Ridge: -0.25
# y-pos. PLS: -0.50, Ridge: -0.50
# y-vel. PLS: -0.75, Ridge: -0.50
# choose value that is larger than the optimum (allow headroom)
# result of x isn't relevant unless r becomes good enough for publication
""" 3rd pass. Determine Time Lag Resolution """
import itertools, pickle

# Sweeps TAU_DFS with every other setting held fixed. For each configuration the model
# hyperparameter is swept, the model is rebuilt with the best value, and the result
# (one pickle + two figures) is written under {ROOT_SAVE_DIR}/3rd_pass/.
# NOTE(review): N_SPLIT and T_DF_MOTION are not defined in this cell; presumably they
# come from `from utils_motor_global import *` -- confirm.

# input matrix param permutation
TAU_STARTS = [-0.5] # [-0.75, -0.5]
TAU_ENDS = [0.5] # [0.25, 0.5]
TAU_DFS = [20, 10, 5, 2, 1] # decimation factor for model input matrix resolution. 10 gives 100 msec
SEL_INCLUDE_LFS = [False] # [True, False]
SEL_CHANNELS = ['all'] # ['all', 'M1', 'S1']

# motor feature permutation
SEL_VELOCITY = [True] # [True, False]
DIMENSIONS = ['y'] # ['x', 'y', 'z']

# model selector
MODELS = ['pls'] # ['pls', 'ridge']

# value handed to compute_and_plot_model_coeff_contributions as q_cut
q_cut = 0.01

matrix_perms = list(itertools.product(TAU_STARTS, TAU_ENDS, TAU_DFS,
                                     SEL_INCLUDE_LFS, SEL_CHANNELS))
motion_perms = list(itertools.product(SEL_VELOCITY, DIMENSIONS))

for matrix_params in matrix_perms:
    tau_start, tau_end, tau_df, sel_include_lfs, sel_channels = matrix_params

    # channel labels for the selected channel group (used by the coefficient plot)
    if sel_channels == 'all':  chs = common_chs
    elif sel_channels == 'M1': chs = common_chs[M1_ch_idxs]
    elif sel_channels == 'S1': chs = common_chs[S1_ch_idxs]
    else: raise ValueError(f'unknown channel selection: {sel_channels!r}')

    # band labels for the coefficient plot; 'LFS' is only listed when it is included
    if sel_include_lfs: band_strs = ['LMP', 'LFS', 'β ', 'Low γ ', 'High γ']
    else:               band_strs = ['LMP', 'β ', 'Low γ ', 'High γ']

    X, (wrist_pos_xs, wrist_pos_ys, _), (wrist_vel_xs, wrist_vel_ys, _), _, taus = \
        prepare_model_data(ROOT_SCALOMAT_DIR, ROOT_MOTION_DIR,
                        tau_start, tau_end, tau_df, T_DF_MOTION,
                        sel_include_lfs, sel_channels, M1_ch_idxs, S1_ch_idxs)

    for motion_params in motion_perms:
        sel_velocity, dimension = motion_params

        # pick the regression target; only the x and y components are unpacked above
        if sel_velocity:
            if   dimension == 'x': y = wrist_vel_xs
            elif dimension == 'y': y = wrist_vel_ys
            else: raise ValueError(f"unsupported dimension: {dimension!r} (only 'x' and 'y' are loaded)")
        else:
            if   dimension == 'x': y = wrist_pos_xs
            elif dimension == 'y': y = wrist_pos_ys
            else: raise ValueError(f"unsupported dimension: {dimension!r} (only 'x' and 'y' are loaded)")

        for sel_model in MODELS:
            if sel_model == 'pls':     hparam_range = np.arange(3, 10)
            elif sel_model == 'ridge': hparam_range = np.logspace(3, 6, (6-3)*2 + 1)
            else: raise ValueError(f'unknown model: {sel_model!r}')

            sweep_result = sweep_model(N_SPLIT, X, y, sel_model, hparam_range)

            # find the optimum hparam: one metric array per swept hparam value
            presses = [np.array(sweep_result[hparam]['presses']) for hparam in hparam_range]
            r2s = [np.array(sweep_result[hparam]['r2s']) for hparam in hparam_range]
            mses = [np.array(sweep_result[hparam]['mses']) for hparam in hparam_range]
            rs = [np.array(sweep_result[hparam]['rs']) for hparam in hparam_range]
            press_avgs = [np.mean(press) for press in presses]
            r2_avgs = [np.mean(r2) for r2 in r2s]
            mse_avgs = [np.mean(mse) for mse in mses]
            r_avgs = [np.mean(r) for r in rs]

            # PLS is selected by the best mean r2, Ridge by the best mean r
            if sel_model == 'pls': opt_hparam = hparam_range[np.argmax(r2_avgs)]
            else:                  opt_hparam = hparam_range[np.argmax(r_avgs)]

            # build model with optimum hparam
            y_preds, test_idxs, coefs, intercepts, (opt_presses, opt_r2s, opt_mses, opt_rs) = \
                build_model(N_SPLIT, X, y, sel_model, opt_hparam)

            # save results
            motion_tag = 'vel' if sel_velocity else 'pos'
            save_data_dir = f'{ROOT_SAVE_DIR}/3rd_pass/{dimension}_{motion_tag}'
            save_img_dir = f'{save_data_dir}_img'
            os.makedirs(save_data_dir, exist_ok=True)
            os.makedirs(save_img_dir, exist_ok=True)

            fn_substr = f'{sel_model}_{int(opt_hparam)}_tau_win_{int(tau_start*1000)}'
            fn_substr += f'_{int(tau_end*1000)}_{int(10*tau_df)}'

            opt_result = {
                'y': y,
                'y_preds': y_preds,
                'coefs': coefs,
                'intercepts': intercepts,
                'presses': opt_presses,
                'r2s': opt_r2s,
                'mses': opt_mses,
                'rs': opt_rs
            }

            with open(f'{save_data_dir}/{fn_substr}.pkl', 'wb') as f:
                pickle.dump(opt_result, f)

            # save plots
            save_img_fn0, save_img_fn1 = f'y_pred_{fn_substr}', f'coefs_{fn_substr}'

            if sel_velocity: title_str = f'Wrist {dimension}-velocity vs. CV'
            else:            title_str = f'Wrist {dimension}-position vs. CV'
            # Report the r of the model built with opt_hparam. `rs` holds the r values of
            # every swept hparam, so averaging it would not describe the chosen model.
            opt_r_avg = np.mean(opt_rs)
            if sel_model == 'pls':
                title_str += f'\nPLS Component: #{opt_hparam}, r: {opt_r_avg:.2f}'
            else:
                # ':g' keeps half-decade exponents (e.g. 1e3.5) instead of truncating them
                title_str += f'\nRidge alpha: 1e{np.log10(opt_hparam):g}, r: {opt_r_avg:.2f}'

            fig, ax = plt.subplots(N_SPLIT, 1, figsize=(12, 8), sharex=True) # , sharey=True)
            # plot_decoded_y returns the title string that is reused for the next figure
            title_str = plot_decoded_y(fig, ax, N_SPLIT, y, y_preds, test_idxs, opt_rs,
                                       title_str)
            fig.savefig(f'{save_img_dir}/{save_img_fn0}.png', bbox_inches='tight')
            plt.close(fig)

            fig, ax = plt.subplots(1, 3, figsize=(11, 5),
                                   gridspec_kw={'width_ratios': [4, 3, 3]})
            fig.subplots_adjust(wspace=0.3)
            compute_and_plot_model_coeff_contributions(fig, ax, coefs, N_SPLIT, chs, taus,
                                                       band_strs, q_cut=q_cut, title_str=title_str)
            fig.tight_layout()
            fig.savefig(f'{save_img_dir}/{save_img_fn1}.png', bbox_inches='tight')
            plt.close(fig)


In [None]:
# 3rd pass best window, resolution (If r is tied, smaller range wins, and then coarser dt wins)
#               window (sec), dt (ms). corr coef
# x-pos. PLS:   [-0.50, 0.25], 100. r = 0.24
# x-pos. Ridge: [-0.50, 0.25], 100. r = 0.29
# x-vel. PLS:   [-0.50, 0.25], 100. r = 0.21
# x-vel. Ridge: [-0.50, 0.25], 100. r = 0.23
# y-pos. PLS:   [-0.25, 0.50], 100. r = 0.63
# y-pos. Ridge: [-0.50, 0.25], 100. r = 0.67
# y-vel. PLS:   [-0.50, 0.25], 20. r = 0.53
# y-vel. Ridge: [-0.50, 0.25], 50. r = 0.57
# choose value that is larger than the optimum (allow headroom)
# result of x isn't relevant unless r becomes good enough for publication

In [None]:
""" 4th pass. Test Channel Groups (Bonus) """
# TAU_ENDS = [0.25] gives better 'r', but time domain contribution plot doesn't look good
import itertools, pickle

# Sweeps SEL_CHANNELS with every other setting held fixed. For each configuration the
# model hyperparameter is swept, the model is rebuilt with the best value, and the result
# (one pickle + two figures) is written under {ROOT_SAVE_DIR}/4th_pass/, in a
# sub-directory per channel group.
# NOTE(review): N_SPLIT and T_DF_MOTION are not defined in this cell; presumably they
# come from `from utils_motor_global import *` -- confirm.

# input matrix param permutation
TAU_STARTS = [-0.5] # [-0.75, -0.5]
TAU_ENDS = [0.5] # [0.25, 0.5]
TAU_DFS = [5] # decimation factor for model input matrix resolution. 10 gives 100 msec
SEL_INCLUDE_LFS = [False] # [True, False]
SEL_CHANNELS = ['all', 'M1', 'S1']

# motor feature permutation
SEL_VELOCITY = [True] # [True, False]
DIMENSIONS = ['y'] # ['x', 'y', 'z']

# model selector
MODELS = ['pls'] # ['pls', 'ridge']

# value handed to compute_and_plot_model_coeff_contributions as q_cut
q_cut = 0.01

matrix_perms = list(itertools.product(TAU_STARTS, TAU_ENDS, TAU_DFS,
                                     SEL_INCLUDE_LFS, SEL_CHANNELS))
motion_perms = list(itertools.product(SEL_VELOCITY, DIMENSIONS))

for matrix_params in matrix_perms:
    tau_start, tau_end, tau_df, sel_include_lfs, sel_channels = matrix_params

    # channel labels for the selected channel group (used by the coefficient plot)
    if sel_channels == 'all':  chs = common_chs
    elif sel_channels == 'M1': chs = common_chs[M1_ch_idxs]
    elif sel_channels == 'S1': chs = common_chs[S1_ch_idxs]
    else: raise ValueError(f'unknown channel selection: {sel_channels!r}')

    # band labels for the coefficient plot; 'LFS' is only listed when it is included
    if sel_include_lfs: band_strs = ['LMP', 'LFS', 'β ', 'Low γ ', 'High γ']
    else:               band_strs = ['LMP', 'β ', 'Low γ ', 'High γ']

    X, (wrist_pos_xs, wrist_pos_ys, _), (wrist_vel_xs, wrist_vel_ys, _), _, taus = \
        prepare_model_data(ROOT_SCALOMAT_DIR, ROOT_MOTION_DIR,
                        tau_start, tau_end, tau_df, T_DF_MOTION,
                        sel_include_lfs, sel_channels, M1_ch_idxs, S1_ch_idxs)

    for motion_params in motion_perms:
        sel_velocity, dimension = motion_params

        # pick the regression target; only the x and y components are unpacked above
        if sel_velocity:
            if   dimension == 'x': y = wrist_vel_xs
            elif dimension == 'y': y = wrist_vel_ys
            else: raise ValueError(f"unsupported dimension: {dimension!r} (only 'x' and 'y' are loaded)")
        else:
            if   dimension == 'x': y = wrist_pos_xs
            elif dimension == 'y': y = wrist_pos_ys
            else: raise ValueError(f"unsupported dimension: {dimension!r} (only 'x' and 'y' are loaded)")

        for sel_model in MODELS:
            if sel_model == 'pls':     hparam_range = np.arange(3, 10)
            elif sel_model == 'ridge': hparam_range = np.logspace(3, 6, (6-3)*2 + 1)
            else: raise ValueError(f'unknown model: {sel_model!r}')

            sweep_result = sweep_model(N_SPLIT, X, y, sel_model, hparam_range)

            # find the optimum hparam: one metric array per swept hparam value
            presses = [np.array(sweep_result[hparam]['presses']) for hparam in hparam_range]
            r2s = [np.array(sweep_result[hparam]['r2s']) for hparam in hparam_range]
            mses = [np.array(sweep_result[hparam]['mses']) for hparam in hparam_range]
            rs = [np.array(sweep_result[hparam]['rs']) for hparam in hparam_range]
            press_avgs = [np.mean(press) for press in presses]
            r2_avgs = [np.mean(r2) for r2 in r2s]
            mse_avgs = [np.mean(mse) for mse in mses]
            r_avgs = [np.mean(r) for r in rs]

            # PLS is selected by the best mean r2, Ridge by the best mean r
            if sel_model == 'pls': opt_hparam = hparam_range[np.argmax(r2_avgs)]
            else:                  opt_hparam = hparam_range[np.argmax(r_avgs)]

            # build model with optimum hparam
            y_preds, test_idxs, coefs, intercepts, (opt_presses, opt_r2s, opt_mses, opt_rs) = \
                build_model(N_SPLIT, X, y, sel_model, opt_hparam)

            # save results: one directory per target and channel group, with an
            # '_include_lfs' suffix when the LFS band is part of the input
            motion_tag = 'vel' if sel_velocity else 'pos'
            lfs_tag = '_include_lfs' if sel_include_lfs else ''
            save_data_dir = f'{ROOT_SAVE_DIR}/4th_pass/{dimension}_{motion_tag}/{sel_channels}{lfs_tag}'
            save_img_dir = f'{save_data_dir}_img'
            os.makedirs(save_data_dir, exist_ok=True)
            os.makedirs(save_img_dir, exist_ok=True)

            fn_substr = f'{sel_model}_{int(opt_hparam)}_tau_win_{int(tau_start*1000)}'
            fn_substr += f'_{int(tau_end*1000)}_{int(10*tau_df)}'

            opt_result = {
                'y': y,
                'y_preds': y_preds,
                'coefs': coefs,
                'intercepts': intercepts,
                'presses': opt_presses,
                'r2s': opt_r2s,
                'mses': opt_mses,
                'rs': opt_rs
            }

            with open(f'{save_data_dir}/{fn_substr}.pkl', 'wb') as f:
                pickle.dump(opt_result, f)

            # save plots
            save_img_fn0, save_img_fn1 = f'y_pred_{fn_substr}', f'coefs_{fn_substr}'

            if sel_velocity: title_str = f'Wrist {dimension}-velocity vs. CV'
            else:            title_str = f'Wrist {dimension}-position vs. CV'
            # Report the r of the model built with opt_hparam. `rs` holds the r values of
            # every swept hparam, so averaging it would not describe the chosen model.
            opt_r_avg = np.mean(opt_rs)
            if sel_model == 'pls':
                title_str += f'\nPLS Component: #{opt_hparam}, r: {opt_r_avg:.2f}'
            else:
                # ':g' keeps half-decade exponents (e.g. 1e3.5) instead of truncating them
                title_str += f'\nRidge alpha: 1e{np.log10(opt_hparam):g}, r: {opt_r_avg:.2f}'

            fig, ax = plt.subplots(N_SPLIT, 1, figsize=(12, 8), sharex=True) # , sharey=True)
            # plot_decoded_y returns the title string that is reused for the next figure
            title_str = plot_decoded_y(fig, ax, N_SPLIT, y, y_preds, test_idxs, opt_rs,
                                       title_str)
            fig.savefig(f'{save_img_dir}/{save_img_fn0}.png', bbox_inches='tight')
            plt.close(fig)

            fig, ax = plt.subplots(1, 3, figsize=(11, 5),
                                   gridspec_kw={'width_ratios': [4, 3, 3]})
            fig.subplots_adjust(wspace=0.3)
            compute_and_plot_model_coeff_contributions(fig, ax, coefs, N_SPLIT, chs, taus,
                                                       band_strs, q_cut=q_cut, title_str=title_str)
            fig.tight_layout()
            fig.savefig(f'{save_img_dir}/{save_img_fn1}.png', bbox_inches='tight')
            plt.close(fig)