In [None]:
%load_ext autoreload
%autoreload 2
import os, sys

import numpy as np

from matplotlib import pyplot as plt
from matplotlib_settings import set_plot_settings, reset_plot_settings

# Set the plot settings
set_plot_settings()

# import global variables
from utils_motor_global import *
from utils_motor_sigproc import get_mt_ch_psd, get_pc

ROOT_SAVE_DIR = f'{REC_DIR}/3_HD_remove_{N_PC_REMOVE}_PCs' # output root for this stage

# Butterworth band-pass (second-order sections) used to isolate the heartbeat band
# before extracting heartbeat PCs. fs = FS/RS is the sampling rate the filter is
# designed for (presumably the rate after down-sampling by RS — confirm); both
# cut-offs are normalized by its Nyquist frequency fs/2.
from scipy.signal import iirfilter, sosfiltfilt
fs = FS/RS
Wn = [cutoff/(fs/2) for cutoff in (HB_BPF_LOW, HB_BPF_HIGH)]
sos_hb = iirfilter(
    N=HB_BPF_N,
    Wn=Wn,
    analog=False,
    btype='band',
    ftype='butter',
    output='sos',
)

# number of PCs shown in the (optional, currently commented-out) PSD plots
N_PC_PLOT = 10

from scipy.signal.windows import dpss
from sklearn.utils.extmath import randomized_svd

In [None]:
for session in GOOD_SESSIONS:
    # keys of this session: those starting with its zero-padded 3-digit number
    keys = [key for key in SESSION_KEYS if key.startswith(f'{session:03}')]

    for key in keys:
        load_dir = f'{REC_DIR}/2_BPF_DS/{key}'
        if not os.path.exists(load_dir):
            # no band-passed / down-sampled data for this key: skip it
            continue

        """ save directory """
        save_data_dir = f'{ROOT_SAVE_DIR}/{key}'
        save_img_dir = f'{ROOT_SAVE_DIR}_imgs/{key}'
        os.makedirs(save_data_dir, exist_ok=True)
        os.makedirs(save_img_dir, exist_ok=True)

        """ load data """
        good_channels = np.load(f'{load_dir}/good_channels_{key}.npy')
        t             = np.load(f'{load_dir}/t_DS_session_{key}.npy')
        rec_data      = np.load(f'{load_dir}/recording_DS_session_{key}.npy') # time * ch

        """ truncate initial data points affected by BPF artifact """
        t0 = t[0] + T_BPF_PAD
        idx_trunc = np.where(t >= t0)[0][0]

        t = t[idx_trunc:]
        rec_data = rec_data[idx_trunc:, :]

        """ segmentize """
        # recordings are partitioned into T_PC (s) long segments, for two reasons:
        # 1. scipy dpss doesn't work (seems like a bug) for large datasets
        # 2. heartbeat strength varies in time and space. If PCs are calculated over the
        #    whole recording, they can become dominated by short bursts with strong HB
        #    whereas portions with weak HB become completely neglected
        # The segment count is rounded to the nearest integer, so a trailing remainder
        # shorter than ~T_PC/2 is merged into the last segment (segments are therefore at
        # most ~1.5*T_PC long). At least one segment is always used so that a recording
        # shorter than T_PC/2 is still processed instead of being saved unchanged.
        n_seg = max(1, int(np.round((t[-1] - t[0])/T_PC)))

        seg_start_idxs = [int(np.where(t >= t[0] + i_seg*T_PC)[0][0]) for i_seg in range(n_seg)]
        seg_start_idxs.append(len(t)) # end sentinel: the last segment runs to the end

        """ initialize arrays to contain HB PCs, HB removed data """
        pc_removed = np.zeros_like(rec_data)
        pc_removed_data = np.copy(rec_data)

        """ iterate over segments, remove HB """
        for i_seg, (idx0, idx1) in enumerate(zip(seg_start_idxs[:-1], seg_start_idxs[1:])):
            # fetch segment
            t_seg = t[idx0:idx1]
            rec_data_seg = rec_data[idx0:idx1, :]

            """ multi-taper params """
            len_win = t_seg[-1] - t_seg[0] # (s)
            assert len_win <= 1.5*T_PC
            NW = len_win*W_MT_PC # time-halfbandwidth product; common choices are 2.5, 3, 3.5, 4
            K = int(2*NW - 1) # Number of tapers
            wt = np.ones(K)/K # equal weight for every taper

            """ params dependent on multi-taper params """
            n = len(t_seg)
            half_n = int(np.ceil(n/2))
            freq = np.fft.fftfreq(n, d = 1/fs)
            half_freq = freq[:half_n] # non-negative FFT frequencies

            # DPSS tapers, shape (K, n)
            dpss_tapers = dpss(n, NW, K)

            """ band-pass filter """
            filt_hb_data = sosfiltfilt(sos_hb, rec_data_seg, axis=0) # time * ch

            """ SVD """
            # (ch * N_PC), (N_PC,), (N_PC * time)
            # Fixed random_state so reruns (and comparisons between pipeline versions) are
            # reproducible: older scikit-learn versions used 0 as the default seed, newer
            # ones default to None (a different random projection on every run).
            hb_u, hb_s, hb_v = randomized_svd(filt_hb_data.T, n_components=N_PC, random_state=0)

            """ get PSD of PCs """
            # band (exclusive bounds) in which heartbeat power is integrated
            f_idx0 = np.where(half_freq > HB_FREQ_LOW)[0][0]
            f_idx1 = np.where(half_freq < HB_FREQ_HIGH)[0][-1]

            hb_pc_psds = np.zeros((N_PC, half_n))
            for pc_idx, pc in enumerate(hb_v):
                # time-series PC
                assert len(pc) == dpss_tapers.shape[1]

                # PSD of PC. hb_s[pc_idx]*pc is the time series component that is spatially
                # distributed to channels. hb_u is excluded from computation (its
                # columns have unit lengths)
                hb_pc_psds[pc_idx, :] = get_mt_ch_psd(hb_s[pc_idx]*pc, dpss_tapers, wt)

            # integrate the heartbeat power of every PC, which represents hemodynamics strength
            hb_tones = np.sum(hb_pc_psds[:, f_idx0:f_idx1 + 1], axis=1)

            """ rank PCs in terms of hemodynamics strength """
            # Descending heartbeat power. A stable argsort lists every PC exactly once even
            # when two PCs have equal power (a list.index() lookup of sorted values returns
            # the same PC twice for ties, removing it twice and skipping another PC).
            pc_order = np.argsort(-hb_tones, kind='stable')

            """ remove PCs """
            # Only the N_PC_REMOVE strongest PCs are reconstructed, instead of materialising
            # all N_PC rank-1 reconstructions (an N_PC * time * ch array per segment).
            # get_pc is assumed to return the (ch * time) reconstruction of one PC.
            hb_removed_seg = np.zeros(rec_data_seg.shape) # time * ch
            for pc_idx in pc_order[:N_PC_REMOVE]:
                hb_removed_seg += get_pc(int(pc_idx), hb_u, hb_s, hb_v).T
            pc_removed[idx0:idx1, :] = hb_removed_seg
            pc_removed_data[idx0:idx1, :] -= pc_removed[idx0:idx1, :]

            """ Plot: PSD of PCs before and after HD removal """
            # """ PC of recording, after removing hemodynamics """
            # u, s, v = randomized_svd(pc_removed_data[idx0:idx1,:].T, n_components=N_PC)

            # post_pcs = np.zeros((N_PC, filt_hb_data.shape[1], filt_hb_data.shape[0]))
            # for pc_idx in range(N_PC):
            #     post_pcs[pc_idx,:,:] = get_pc(pc_idx, u, s, v)
            # post_pcs = np.transpose(post_pcs, (0, 2, 1))
            # post_pc_psds = np.zeros_like(hb_pc_psds)

            # for pc_idx, pc in enumerate(v):
            #     post_pc_psds[pc_idx, :] = get_mt_ch_psd(s[pc_idx]*pc, dpss_tapers, wt)

            # """ plot: PC PSDs """
            # plt.close('all')
            # fig, ax = plt.subplots(1, 1)

            # legend_strs = []
            # for pc_idx in pc_order[:N_PC_PLOT]:
            #     legend_strs.append(str(pc_idx))
            #     ax.loglog(half_freq[1:], hb_pc_psds[pc_idx, 1:])
            # ax.set_xlim((0.5, 300))
            # ax.legend(legend_strs, title='PC', loc=(1.05, 0))
            # ax.set_title(f'Session {session:003}. Segment: {int(t_seg[0])}-{int(t_seg[-1])} sec. PC')
            # ax.set_xlabel('Frequency (Hz)')
            # ax.set_ylabel('PSD (a.u.)')

            # plt.savefig(f'{save_img_dir}/PC_PSD_seg_{i_seg}.png', bbox_inches='tight')

            # """ PC of recording, after removing hemodynamics """
            # plt.close('all')
            # fig, ax = plt.subplots(1, 1)
            # legend_strs = []
            # for pc_idx in range(N_PC_PLOT):
            #     legend_strs.append(str(pc_idx))
            #     ax.loglog(half_freq[1:], post_pc_psds[pc_idx, 1:])
            # ax.set_xlim((0.5, 300))
            # ax.legend(legend_strs, title='PC', loc=(1.05, 0))
            # ax.set_title(f'Session {session:003}. Segment: {int(t_seg[0])}-{int(t_seg[-1])} sec. HB Removed')
            # ax.set_xlabel('Frequency (Hz)')
            # ax.set_ylabel('PSD (a.u.)')

            # plt.savefig(f'{save_img_dir}/Recording_PSD_seg_{i_seg}.png', bbox_inches='tight')

        """ save pc removed recording """
        np.save(f'{save_data_dir}/good_channels_{key}.npy', good_channels) # carry over
        np.save(f'{save_data_dir}/t_HB_removed_session_{key}.npy', t) # truncated
        np.save(f'{save_data_dir}/HB_PC_removed_session_{key}.npy', pc_removed)
        np.save(f'{save_data_dir}/recording_HB_removed_session_{key}.npy', pc_removed_data)

        # """ plot: channel waveforms, before and after PC removal """
        # # load motion data. plot together with recording data
        # motion_dir = f'{MOTION_DIR}/{session:03}'
        # motion_t  = np.load(f'{motion_dir}/vel_t_session_{key}.npy')

        # wrist_vel_x   = np.load(f'{motion_dir}/wrist_vel_x_session_{key}.npy')
        # wrist_vel_y   = np.load(f'{motion_dir}/wrist_vel_y_session_{key}.npy')
        # wrist_vel_z   = np.load(f'{motion_dir}/wrist_vel_z_session_{key}.npy')

        # wrist_vel_x = (wrist_vel_x - np.mean(wrist_vel_x))/np.std(wrist_vel_x)
        # wrist_vel_y = (wrist_vel_y - np.mean(wrist_vel_y))/np.std(wrist_vel_y)
        # wrist_vel_z = (wrist_vel_z - np.mean(wrist_vel_z))/np.std(wrist_vel_z)

        # plt.close('all')
        # fig, ax = plt.subplots(4, 1, figsize=(6, 8), sharex=True)

        # y_offset = 5

        # for ch_idx, ch in enumerate(good_channels):
        #     for ii in range(4):
        #         ax[ii].clear()
        #     ax[0].set_title(f'Session {key}, Channel {ch}')

        #     ax[0].plot(t, rec_data[:, ch_idx])
        #     ax[1].plot(t, pc_removed[:, ch_idx])
        #     ax[2].plot(t, pc_removed_data[:, ch_idx])

        #     ax[0].set_ylabel('Raw')
        #     ax[1].set_ylabel(f'First {N_PC_REMOVE} PCs\n BPF {HB_BPF_LOW:d}-{HB_BPF_HIGH:d} Hz')
        #     ax[2].set_ylabel(f'PC Removed')

        #     ax[-1].plot(motion_t, wrist_vel_x + y_offset)
        #     ax[-1].plot(motion_t, wrist_vel_y)
        #     ax[-1].plot(motion_t, wrist_vel_z - y_offset)
        #     ax[-1].set_ylabel(f'Wrist Vel.')
        #     ax[-1].legend(['x', 'y', 'z'], fontsize=8)
        #     ax[-1].set_xlabel('Time (s)')

        #     for ii in range(4):
        #         for idx in seg_start_idxs[1:-1]:
        #             ax[ii].axvline(x=t[idx], color='k')
        #         ax[ii].grid(True)

        #     plt.savefig(f'{save_img_dir}/channel_{ch}.png', bbox_inches='tight')

In [None]:
""" write assertion check """
# Regression check: every output saved for the first session key by this (v3)
# pipeline must match the corresponding v2 output.
# np.testing.assert_allclose raises instead of only displaying a bool, and (unlike
# np.allclose) fails on shape mismatches instead of silently broadcasting. The
# tolerances below are np.allclose's defaults, so the numeric criterion is unchanged.
key = SESSION_KEYS[0]
dir0 = f'./recording_preprocessed_v2/3_SVD_remove_5_PCs/{key}'
dir1 = f'./recording_preprocessed_v3/3_HD_remove_5_PCs/{key}'

fns = [
    f'good_channels_{key}.npy',
    f'HB_PC_removed_session_{key}.npy',
    f'recording_HB_removed_session_{key}.npy',
    f't_HB_removed_session_{key}.npy',
]

for fn in fns:
    X0 = np.load(f'{dir0}/{fn}')
    X1 = np.load(f'{dir1}/{fn}')
    np.testing.assert_allclose(X0, X1, rtol=1e-5, atol=1e-8, equal_nan=True, err_msg=fn)

print(f'all {len(fns)} outputs match for {key}')