In [None]:
%load_ext autoreload
%autoreload 2
import os, sys

import numpy as np

from matplotlib import pyplot as plt
from matplotlib_settings import set_plot_settings, reset_plot_settings

# Set the plot settings
set_plot_settings()

from scipy.signal import iirfilter, sosfiltfilt
from scipy.signal import decimate


# import global variables
from utils_motor_global import *
from utils_motor_sigproc import interpolate_good_channels

# create BPF
Wn = [BPF_LOW/(FS/2), BPF_HIGH/(FS/2)]
sos = iirfilter(N=N_BPF, Wn=Wn, btype='band', ftype='butter', output='sos')

# N_PC = 10 # number of PCs to calculate
ROOT_SAVE_DIR = f'{REC_DIR}/2_BPF_DS'

In [None]:
for session in GOOD_SESSIONS:

    # Session keys are prefixed with the session number zero-padded to width 3
    # (same format as the motion directory name below).
    keys = [key for key in SESSION_KEYS if key.startswith(f'{session:03}')]
    for key in keys:
        print(f'processing session: {key}')

        """ load BISC recording """
        load_dir = f'{REC_DIR}/1_impute/{key}'
        # if not os.path.exists(load_dir):
            # print(f'Session {key}: No overlap between motion and BISC recording')
            # continue
        rec_data = np.load(f'{load_dir}/recording_session_{key}.npy')
        t        = np.load(f'{load_dir}/t_session_{key}.npy')

        """ load motion data """
        motion_dir = f'{MOTION_DIR}/{session:03}'
        motion_t  = np.load(f'{motion_dir}/pos_t_session_{key}.npy')

        """ identify good channels, by taking non-sat ratio in segment that overlaps with good behavior """
        # NaN marks a saturated sample; rows are time samples, columns are channels.
        nonsaturated = ~np.isnan(rec_data)
        # nonsat ratio for each channel, before truncating (for plotting only)
        full_nonsat_ratio = np.sum(nonsaturated, axis=0)/nonsaturated.shape[0]

        # In retrospect, these parameters should have been:
        # t0 = max(T_SETTLE, motion_t[0] - T_BPF_PAD - T_SCALO_PAD)
        # t1 = min(t[-1], motion_t[-1] + T_BPF_PAD + T_SCALO_PAD)
        t0 = max(t[0], motion_t[0] - T_SCALO_PAD)
        t1 = min(t[-1], motion_t[-1])
        if t0 > t1:
            # Without overlap the index lookups below either raise an opaque
            # IndexError or yield a degenerate (<= 1 sample) window.
            raise ValueError(
                f'Session {key}: no overlap between motion [{motion_t[0]}, {motion_t[-1]}] '
                f'and BISC recording [{t[0]}, {t[-1]}]')
        i0 = np.where(t >= t0)[0][0]
        i1 = np.where(t >= t1)[0][0] + 1

        # nonsat ratio for each channel, after truncating to match the good behavior segs.
        effective_nonsat_ratio = np.sum(nonsaturated[i0:i1,:], axis=0)/(i1-i0)
        good_channels = np.where(effective_nonsat_ratio > GOODCH_CUT)[0]

        """ interpolate good channels """
        gch_data = interpolate_good_channels(rec_data, good_channels)
        # check that all points are non-NaN
        # (np.alltrue was removed in NumPy 2.0; this form works on every version)
        assert not np.isnan(gch_data).any(), f'Session {key}: NaNs remain after interpolation'

        """ band pass filter """
        # zero-phase (forward-backward) filtering along the time axis
        filt_data = sosfiltfilt(sos, gch_data, axis=0)

        """ down sample """
        t_RS = t[::RS]
        # cascaded decimation, as recommended by scipy doc
        filt_data_RS = decimate(filt_data, RS1, axis=0)
        filt_data_RS = decimate(filt_data_RS, RS2, axis=0)
        # decimate keeps every q-th sample starting at index 0, so t[::RS] only
        # lines up with the output when RS == RS1 * RS2 -- at least check the lengths.
        assert len(t_RS) == filt_data_RS.shape[0], \
            f'Session {key}: {len(t_RS)} timestamps vs {filt_data_RS.shape[0]} decimated samples'

        """ save data """
        save_data_dir = f'{ROOT_SAVE_DIR}/{key}'
        save_img_dir = f'{ROOT_SAVE_DIR}_imgs/{key}'
        os.makedirs(save_data_dir, exist_ok=True)
        os.makedirs(save_img_dir, exist_ok=True)

        # good channels
        np.save(f'{save_data_dir}/nonsat_ratio_channel_{key}.npy', effective_nonsat_ratio)
        np.save(f'{save_data_dir}/good_channels_{key}.npy', good_channels)

        # BPF & Downsampled
        np.save(f'{save_data_dir}/t_DS_session_{key}.npy', t_RS)
        np.save(f'{save_data_dir}/recording_DS_session_{key}.npy', filt_data_RS)

        """ save plot: saturated ratio. full vs. effective segment """
        plt.close('all')
        fig, ax = plt.subplots(1, 2)
        vmin = GOODCH_CUT
        # channels are laid out as a 16-row grid for display
        ax[0].imshow(full_nonsat_ratio.reshape(16, -1), vmin=vmin, vmax=1.0)
        ax[1].imshow(effective_nonsat_ratio.reshape(16, -1), vmin=vmin, vmax=1.0)

        ax[0].set_title('Full Segment')
        ax[1].set_title('Effective Segment')
        fig.text(0.5, 0.15, f'vmin = {vmin}, vmax = 1.0', ha='center')

        fig.savefig(f'{save_img_dir}/saturation_map.png', bbox_inches='tight')
        # release the figure inside the loop instead of leaving it open
        plt.close(fig)

        """ save plot: waveform for every channel """
        # plt.close('all')
        # fig, ax = plt.subplots(2, 1, sharex=True)

        # for ch_idx, ch in enumerate(good_channels):
        #     ax[0].clear()
        #     ax[1].clear()

        #     ax[0].plot(t, rec_data[:, ch])
        #     ax[1].plot(t_RS, filt_data_RS[:, ch_idx])

        #     ax[0].set_title(f'Session {key}, Channel {ch} \n \
        #     Pre-Interpolation Non-sat. Ratio: {effective_nonsat_ratio[ch]*100:.2f}%')

        #     ax[0].set_ylabel('Before BPF, DS')
        #     ax[1].set_ylabel('After BPF, DS')
        #     ax[1].set_xlabel('Time (sec)')

        #     ax[1].axvline(x=motion_t[0], color='k')
        #     ax[1].axvline(x=motion_t[-1], color='k')
        #     ax[1].axvline(x=t[0] + T_BPF_PAD, color='r')

        #     plt.savefig(f'{save_img_dir}/channel_{ch}.png', bbox_inches='tight')

In [None]:
""" write assertion check """
# Regression check: a file written by the v3 preprocessing must match the v2 one.
key = SESSION_KEYS[0]
# fn = f'good_channels_{key}.npy'
# fn = f'nonsat_ratio_channel_{key}.npy'
# fn = f't_DS_session_{key}.npy'
fn = f'recording_DS_session_{key}.npy'
dir0 = f'./recording_preprocessed_v2/2_BPF_DS/{key}'
dir1 = f'./recording_preprocessed_v3/2_BPF_DS/{key}'

X0 = np.load(f'{dir0}/{fn}')
X1 = np.load(f'{dir1}/{fn}')

# np.allclose only returns a bool (a mismatch never fails the run) and silently
# broadcasts mismatched shapes. assert_allclose raises on both; rtol/atol are set
# to np.allclose's defaults so the comparison tolerance is unchanged.
np.testing.assert_allclose(X0, X1, rtol=1e-05, atol=1e-08, equal_nan=True, err_msg=fn)
print(f'{fn}: v2 and v3 outputs match')