In [1]:
%load_ext autoreload
%autoreload 2
import os, sys

import numpy as np
from scipy.io import loadmat

from matplotlib import pyplot as plt
from matplotlib_settings import set_plot_settings, reset_plot_settings

# Apply the project's plot settings (helper imported from matplotlib_settings above).
set_plot_settings()

# Import global variables.
# NOTE(review): the cells below use GOOD_SESSIONS, MOTION_DIR and RAW_MOTION_DIR
# without importing them by name, so they presumably come from this star import --
# confirm, and consider importing them explicitly.
from utils_motor_global import *
from utils_motion import MOTION_IDX_DICT, MOTION_IDX_DICT_V3P0

# Load interpolation / smoothing parameters and the helpers they are passed to.
from utils_motion import T_SHORT, T_LONG, T_WIN, T_SPLINE
from utils_motion import interpolate_nans, moving_average, get_nanseg_idxs, spline_nans

# Helper used to locate the raw 'all_pos3d' .mat file of each session.
from utils_motor_misc import list_files_with_keyword_extension

In [2]:
# Output directory for the smoothing figures.
# exist_ok=True makes the cell safe to re-run and removes the separate
# "check, then create" step.
save_img_dir = f'./{MOTION_DIR}/smoothing_img'
os.makedirs(save_img_dir, exist_ok=True)

In [5]:
def _smooth_wrist_axis(t, pos_raw, n_interp, winsize):
    """Smooth one wrist coordinate and derive its sample-to-sample velocity.

    The same four steps are applied to each of the x, y and z traces:
      1. ``interpolate_nans`` with ``N=n_interp`` (short gaps),
      2. ``moving_average`` on the position,
      3. ``spline_nans`` on the NaN segments reported by ``get_nanseg_idxs``,
      4. ``np.diff`` followed by ``moving_average`` with ``nan_tol=0``.

    Args:
        t: time stamps handed to ``spline_nans``.
        pos_raw: raw 1-D position trace of a single axis.
        n_interp: value passed as ``N`` to ``interpolate_nans``.
        winsize: ``window_size`` used for both moving averages.

    Returns:
        Tuple ``(pos_smooth, vel)``. ``vel`` is one sample shorter than
        ``pos_smooth`` because it is built from ``np.diff``.
    """
    # step 1. interpolate SHORT gaps
    pos_interp = interpolate_nans(pos_raw, N=n_interp)

    # step 2. apply moving average filter
    pos_avg = moving_average(pos_interp, window_size=winsize)

    # step 3. fill-in remaining gaps using cubic spline
    # (original note: "within the range of raw data vmin, vmax" -- the clipping,
    # if any, happens inside spline_nans and is not visible here)
    seg_starts, seg_ends = get_nanseg_idxs(pos_avg)
    pos_smooth = spline_nans(t, pos_avg, seg_starts, seg_ends,
                             T_spline=T_SPLINE, T_limit=T_LONG)

    # step 4. derive velocity.
    # np.diff yields the per-sample position difference (it is not divided by dt).
    # NOTE(review): confirm downstream code expects these units.
    vel = moving_average(np.diff(pos_smooth), window_size=winsize, nan_tol=0)
    return pos_smooth, vel


for session in GOOD_SESSIONS:
    session_tag = f'{session:003}'  # zero-padded session id used in all paths

    # --- load raw wrist position. use full length raw data for smoothing ---
    load_dir = f'{RAW_MOTION_DIR}/{session_tag}'  # raw data directory
    save_dir = f'{MOTION_DIR}/{session_tag}'      # post-processed data directory
    assert os.path.exists(load_dir), f'raw data directory not found: {load_dir}'

    # position data (contains x,y,z coordinates of hand, wrist, elbow)
    pos_files = list_files_with_keyword_extension(load_dir, 'all_pos3d', '.mat')
    assert len(pos_files) == 1, \
        f'expected exactly one all_pos3d .mat file in {load_dir}, found {len(pos_files)}'
    all_pos_fname = pos_files[0]

    all_pos3d = loadmat(f'{load_dir}/{all_pos_fname}')['all_pos3d']
    # first index 1 is used for the wrist; second index selects x / y / z
    wrist_x_raw = all_pos3d[1, 0, :]
    wrist_y_raw = all_pos3d[1, 1, :]
    wrist_z_raw = all_pos3d[1, 2, :]

    # --- load time. use the sync corrected version ---
    raw_t = loadmat(f'{load_dir}/rec_{session_tag}_timestamp.mat')
    raw_t = np.squeeze(raw_t['timestamps_rec'])

    t = np.load(f'{save_dir}/motion_t_{session_tag}.npy')
    # the corrected time axis must span the same duration as the raw one
    assert np.abs((raw_t[-1] - raw_t[0]) - (t[-1] - t[0])) < 1e-6, \
        f'session {session_tag}: corrected and raw time axes span different durations'
    # NOTE(review): only the duration is checked; len(t) is assumed to match the
    # length of the position traces -- confirm.

    # --- apply smoothing ---
    dt = np.mean(np.diff(t))
    n_interp = int(np.round(T_SHORT/dt))   # T_SHORT expressed in samples
    winsize = int(np.round(T_WIN/dt/2))    # half of T_WIN expressed in samples

    wrist_x3, wrist_vel_x = _smooth_wrist_axis(t, wrist_x_raw, n_interp, winsize)
    wrist_y3, wrist_vel_y = _smooth_wrist_axis(t, wrist_y_raw, n_interp, winsize)
    wrist_z3, wrist_vel_z = _smooth_wrist_axis(t, wrist_z_raw, n_interp, winsize)

    # End of Smoothing

    # --- save full length motor features ---
    raw_by_axis = (wrist_x_raw, wrist_y_raw, wrist_z_raw)
    pos_by_axis = (wrist_x3, wrist_y3, wrist_z3)
    vel_by_axis = (wrist_vel_x, wrist_vel_y, wrist_vel_z)

    for axis_name, raw_arr, pos_arr, vel_arr in zip('xyz', raw_by_axis, pos_by_axis, vel_by_axis):
        np.save(f'{save_dir}/wrist_{axis_name}_raw.npy', raw_arr)   # raw position
        np.save(f'{save_dir}/wrist_pos_{axis_name}.npy', pos_arr)   # smoothed position
        np.save(f'{save_dir}/wrist_vel_{axis_name}.npy', vel_arr)   # velocity

    keys = [key for key in MOTION_IDX_DICT.keys() if key.startswith(f'{session:003}')]

    # --- save motor features for good behaving segments ---
    for key in keys:
        idx0, idx1 = MOTION_IDX_DICT[key]
        if idx1 == -1:
            # -1 is replaced by len(t), so the segment runs through the last sample.
            # (original note: +1 should have been added to idx1.. but the impact
            # should be negligible)
            idx1 = len(t)

        pos_seg = slice(idx0, idx1)
        vel_seg = slice(idx0, idx1 - 1)  # velocity is one sample shorter than position

        for axis_name, pos_arr, vel_arr in zip('xyz', pos_by_axis, vel_by_axis):
            np.save(f'{save_dir}/wrist_pos_{axis_name}_session_{key}.npy', pos_arr[pos_seg])
            np.save(f'{save_dir}/wrist_vel_{axis_name}_session_{key}.npy', vel_arr[vel_seg])

        np.save(f'{save_dir}/pos_t_session_{key}.npy', t[pos_seg])
        # each velocity sample is stamped with the time of the earlier of its two positions
        np.save(f'{save_dir}/vel_t_session_{key}.npy', t[vel_seg])

        # --- load KEW's smoothed data (used for comparison); currently disabled ---
        # wrist_x_v3p0 = loadmat(f'{load_dir}/smooth_wristx_rec{key}.mat')
        # wrist_y_v3p0 = loadmat(f'{load_dir}/smooth_wristy_rec{key}.mat')
        # wrist_z_v3p0 = loadmat(f'{load_dir}/smooth_wristz_rec{key}.mat')

        # wrist_x_v3p0 = np.squeeze(wrist_x_v3p0['smooth_wristx'])
        # wrist_y_v3p0 = np.squeeze(wrist_y_v3p0['smooth_wristy'])
        # wrist_z_v3p0 = np.squeeze(wrist_z_v3p0['smooth_wristz'])

        # # motion_t_v3p0 = np.load(f'{save_dir}/motion_t_v3p0_session_{key}.npy')

        # v3p0_idx0, v3p0_idx1 = MOTION_IDX_DICT_V3P0[key]
        # if v3p0_idx1 == -1: v3p0_idx1 = len(t)
        # motion_t_v3p0 = t[v3p0_idx0:v3p0_idx1]

        # --- save plot, position. good segment (plotting code is commented out) ---
        # plt.close('all')
        # fig, ax = plt.subplots(3, 3, sharex=True, figsize=(12, 9))
        
        # for r in range(3):
        #     ax[r, 0].sharey(ax[r, 1])
        #     ax[r, 1].sharey(ax[r, 2])
        #     ax[r, 1].set_yticks([])
        #     ax[r, 2].set_yticks([])
        
        # ax[0,0].plot(t[idx0:idx1], wrist_x_raw[idx0:idx1])
        # ax[1,0].plot(t[idx0:idx1], wrist_y_raw[idx0:idx1])
        # ax[2,0].plot(t[idx0:idx1], wrist_z_raw[idx0:idx1])
        
        # ax[0,1].plot(motion_t_v3p0, wrist_x_v3p0)
        # ax[1,1].plot(motion_t_v3p0, wrist_y_v3p0)
        # ax[2,1].plot(motion_t_v3p0, wrist_z_v3p0)
        
        # ax[0,2].plot(t[idx0:idx1], wrist_x3[idx0:idx1])
        # ax[1,2].plot(t[idx0:idx1], wrist_y3[idx0:idx1])
        # ax[2,2].plot(t[idx0:idx1], wrist_z3[idx0:idx1])
        
        # ax[0,0].set_title('Raw')
        # ax[0,1].set_title('v3.0')
        # ax[0,2].set_title('v3.1')
        
        # ax[0,0].set_ylabel('x-pos')
        # ax[1,0].set_ylabel('y-pos')
        # ax[2,0].set_ylabel('z-pos')
        
        # ax[2,0].set_xlabel('Time (sec)')
        # ax[2,1].set_xlabel('Time (sec)')
        # ax[2,2].set_xlabel('Time (sec)')
        
        # fig.suptitle(f'Session {key}. Wrist Position')
        # fig.savefig(f'{save_img_dir}/smoothed_wrist_session_{key}.png', bbox_inches='tight')
        # plt.close(fig)

        # """ save plot, position. full session """
        # plt.close('all')
        # fig, ax = plt.subplots(3, 2, sharex=True, figsize=(12, 9))
        
        # for r in range(3):
        #     ax[r, 0].sharey(ax[r, 1])
        #     ax[r, 1].set_yticks([])
        
        # ax[0,0].plot(t, wrist_x_raw)
        # ax[1,0].plot(t, wrist_y_raw)
        # ax[2,0].plot(t, wrist_z_raw)
        
        # ax[0,1].plot(t, wrist_x3)
        # ax[1,1].plot(t, wrist_y3)
        # ax[2,1].plot(t, wrist_z3)
        
        # ax[0,0].set_title('Raw')
        # ax[0,1].set_title('v3.1')
        
        # ax[0,0].set_ylabel('x-pos')
        # ax[1,0].set_ylabel('y-pos')
        # ax[2,0].set_ylabel('z-pos')
        
        # ax[2,0].set_xlabel('Time (sec)')
        # ax[2,1].set_xlabel('Time (sec)')
        
        # fig.suptitle(f'Session {session:003}. Wrist Position')
        # fig.savefig(f'{save_img_dir}/smoothed_wrist_full_session_{session:003}.png', bbox_inches='tight')
        # plt.close(fig)

        # """ save plot, velocity. good segment """
        # plt.close('all')
        # fig, axs = plt.subplots(6, 1, sharex=True, figsize=(9, 9))

        # axs[0].plot(t[idx0:idx1], wrist_x3[idx0:idx1])
        # axs[2].plot(t[idx0:idx1], wrist_y3[idx0:idx1])
        # axs[4].plot(t[idx0:idx1], wrist_z3[idx0:idx1])

        # axs[1].plot(t[idx0:idx1-1], wrist_vel_x[idx0:idx1-1])
        # axs[3].plot(t[idx0:idx1-1], wrist_vel_y[idx0:idx1-1])
        # axs[5].plot(t[idx0:idx1-1], wrist_vel_z[idx0:idx1-1])

        # axs[0].set_title(f'Session {key}. Wrist')

        # axs[0].set_ylabel('x-pos')
        # axs[2].set_ylabel('y-pos')
        # axs[4].set_ylabel('z-pos')

        # axs[1].set_ylabel('x-vel')
        # axs[3].set_ylabel('y-vel')
        # axs[5].set_ylabel('z-vel')

        # axs[-1].set_xlabel('Time (sec)')

        # for ax in axs:
        #     ax.set_yticks([])
        #     ax.grid(True)

        # fig.suptitle(f'Session {key}. Wrist Position')
        # fig.savefig(f'{save_img_dir}/velocity_wrist_session_{key}.png', bbox_inches='tight')
        # plt.close(fig)

In [10]:
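### TODO: write assertion checks. The disabled draft below compares each per-segment output file against its copy in ./motion_v3_postprocess_bak.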
# for session in GOOD_SESSIONS:
#     keys = [key for key in MOTION_IDX_DICT.keys() if key.startswith(f'{session:003}')]
# 
#     for key in keys:
#         dir0 = f'./motion_v3_postprocess_bak/{session:003}'
#         dir1 = f'./motion_v3_postprocess/{session:003}'
# 
#         # position t
#         pos_t0 = np.load(f'{dir0}/pos_t_v3p1_session_{key}.npy')
#         pos_t1 = np.load(f'{dir1}/pos_t_session_{key}.npy')
#         assert np.array_equal(pos_t0, pos_t1)
# 
#         # velocity t
#         vel_t0 = np.load(f'{dir0}/vel_t_session_{key}.npy')
#         vel_t1 = np.load(f'{dir1}/vel_t_session_{key}.npy')
#         assert np.array_equal(vel_t0, vel_t1)
# 
#         # positions
#         pos_x0 = np.load(f'{dir0}/wrist_x_v3p1_session_{key}.npy')
#         pos_y0 = np.load(f'{dir0}/wrist_y_v3p1_session_{key}.npy')
#         pos_z0 = np.load(f'{dir0}/wrist_z_v3p1_session_{key}.npy')
#         pos_x1 = np.load(f'{dir1}/wrist_pos_x_session_{key}.npy')
#         pos_y1 = np.load(f'{dir1}/wrist_pos_y_session_{key}.npy')
#         pos_z1 = np.load(f'{dir1}/wrist_pos_z_session_{key}.npy')
#         assert np.array_equal(pos_x0, pos_x1)
#         assert np.array_equal(pos_y0, pos_y1)
#         assert np.array_equal(pos_z0, pos_z1)
# 
#         # velocities
#         vel_x0 = np.load(f'{dir0}/wrist_vel_x_session_{key}.npy')
#         vel_y0 = np.load(f'{dir0}/wrist_vel_y_session_{key}.npy')
#         vel_z0 = np.load(f'{dir0}/wrist_vel_z_session_{key}.npy')
#         vel_x1 = np.load(f'{dir1}/wrist_vel_x_session_{key}.npy')
#         vel_y1 = np.load(f'{dir1}/wrist_vel_y_session_{key}.npy')
#         vel_z1 = np.load(f'{dir1}/wrist_vel_z_session_{key}.npy')
#         assert np.array_equal(vel_x0, vel_x1)
#         assert np.array_equal(vel_y0, vel_y1)
#         assert np.array_equal(vel_z0, vel_z1)