In [None]:
%load_ext autoreload
%autoreload 2
import os, sys

import numpy as np

from matplotlib import pyplot as plt
from matplotlib_settings import set_plot_settings, reset_plot_settings

# Set the plot settings (applies the settings defined in matplotlib_settings)
set_plot_settings()

# import global variables
# NOTE(review): star import. Names used in this notebook but not defined here
# (UTILS_DIR, REC_DIR, RAW_REC_DIR, MOTION_DIR, SESSION_KEYS, GOOD_SESSIONS, FS, TS,
# LOW_CUT, HIGH_CUT, NCH, T_SETTLE, T_BPF_PAD, T_SCALO_PAD, GOODCH_CUT, USE_TETRODE)
# presumably come from utils_motor_global -- verify there; consider explicit imports.
from utils_motor_global import *

# Add UTILS_DIR to the import path so the utility modules below can be imported
# (presumably utils_mp / utils_impute live there -- confirm).
sys.path.append(UTILS_DIR)
from utils_mp import read_recdata

from utils_impute import interpolate_pkt_loss_mp
from utils_impute import impute_td_mp
# multi-processing space-domain imputing does NOT reduce runtime compared to single core
from utils_impute import impute_sd # impute_sd_mp

# Root output folder for imputed recordings; one sub-folder per session key.
ROOT_SAVE_DIR = f'{REC_DIR}/1_impute_new'

## Read raw recordings, impute lost/saturated samples, and save the results

In [None]:
# Opt-in debug output: per-channel before/after waveform plots, used to choose
# thresholds for later processing steps. Off by default because it writes one
# PNG per channel for every session key.
PLOT_CHANNEL_WAVEFORMS = False


def is_saturated(x):
    """Return a boolean mask of samples at or beyond the clipping limits.

    A sample is flagged when it is <= LOW_CUT or >= HIGH_CUT.
    """
    return np.logical_or(x <= LOW_CUT, x >= HIGH_CUT)


def nonsat_ratio(saturated):
    """Return the per-channel fraction of samples NOT flagged in `saturated`.

    `saturated` is a (n_samples, n_channels) boolean mask (samples along axis 0);
    the result has one entry per channel.
    """
    n_samples = saturated.shape[0]
    return (n_samples - np.sum(saturated, axis=0)) / n_samples


for session in GOOD_SESSIONS:
    session_id = f'{session:03}'
    keys = [key for key in SESSION_KEYS if key.startswith(session_id)]
    if not keys:
        continue

    # Load the raw recording once per session: every key of a session is cut
    # from the same (single) recording file in rec_dir.
    rec_dir = f'{RAW_REC_DIR}/{session_id}'
    rec_fnames = os.listdir(rec_dir)
    assert len(rec_fnames) == 1, f'expected exactly one recording in {rec_dir}, found {rec_fnames}'
    raw_rec_data, _, _, _, _ = read_recdata(file_path=f'{rec_dir}/{rec_fnames[0]}',
                                            use_tetrode=USE_TETRODE)
    raw_t = np.arange(raw_rec_data.shape[0])*TS

    for key in keys:
        print(f'processing session: {key}')

        # load motion timestamps of the good-behaving segment
        motion_t = np.load(f'{MOTION_DIR}/{session_id}/pos_t_session_{key}.npy')

        # check that BISC recording and good behavior segments overlap
        if raw_t[0] > motion_t[-1] or raw_t[-1] < motion_t[0]:
            print(f'Session {key}: no overlap between good behaving segment and BISC recording')
            continue

        # Truncate the recording to the good-behaving segment, padded on both sides
        # by T_BPF_PAD + T_SCALO_PAD, never starting before T_SETTLE.
        t0 = max(T_SETTLE, motion_t[0] - T_BPF_PAD - T_SCALO_PAD)
        t1 = min(raw_t[-1], motion_t[-1] + T_BPF_PAD + T_SCALO_PAD)
        i0, i1 = int(t0*FS), int(t1*FS)
        if i1 <= i0:
            # e.g. the whole overlap falls inside the initial settling period
            print(f'Session {key}: empty recording window after truncation ({t0:.3f}s - {t1:.3f}s)')
            continue

        # astype() returns a copy, so raw_rec_data stays intact for the next key
        rec_data = raw_rec_data[i0:i1, :].astype('float')
        t = raw_t[i0:i1]

        # Impute packet loss (packet loss points are filled with 1023).
        # N=11: 3.08kHz. FIXME: LOW_CUT and HIGH_CUT aren't used inside the function.
        plc_rec_data = interpolate_pkt_loss_mp(rec_data, 11, LOW_CUT, HIGH_CUT)
        plc_saturated = is_saturated(plc_rec_data)
        plc_nonsat_ratio = nonsat_ratio(plc_saturated)

        # Impute in time domain.
        imp_td_rec_data = impute_td_mp(plc_rec_data, 3, LOW_CUT, HIGH_CUT)
        del plc_rec_data  # large intermediate, not needed below
        imp_td_saturated = is_saturated(imp_td_rec_data)
        imp_td_nonsat_ratio = nonsat_ratio(imp_td_saturated)

        # Impute in space domain (channels are zero-meaned prior to imputing).
        # Returns the imputed recording data and the list of imputed channels.
        # NOTE(review): samples left as NaN are treated as still saturated -- this
        # assumes impute_sd marks un-imputable samples with NaN; confirm in utils_impute.
        imp_sd_rec_data, imp_sd_chs = impute_sd(imp_td_rec_data, imp_td_saturated, GOODCH_CUT)
        imp_sd_saturated = np.isnan(imp_sd_rec_data)
        imp_sd_nonsat_ratio = nonsat_ratio(imp_sd_saturated)

        # Output folders for imputed data and diagnostic plots.
        save_data_dir = f'{ROOT_SAVE_DIR}/{key}'
        save_img_dir = f'{ROOT_SAVE_DIR}_imgs/{key}'
        os.makedirs(save_data_dir, exist_ok=True)
        os.makedirs(save_img_dir, exist_ok=True)

        # Save imputed data, timestamps, imputed-channel list, and the saturation
        # mask after each imputing stage.
        np.save(f'{save_data_dir}/recording_session_{key}.npy', imp_sd_rec_data)
        np.save(f'{save_data_dir}/t_session_{key}.npy', t)
        np.save(f'{save_data_dir}/imp_sd_channels_session_{key}.npy', imp_sd_chs)
        np.save(f'{save_data_dir}/plc_saturated_session_{key}.npy', plc_saturated)
        np.save(f'{save_data_dir}/imp_td_saturated_{key}.npy', imp_td_saturated)
        np.save(f'{save_data_dir}/imp_sd_saturated_{key}.npy', imp_sd_saturated)

        # Plot: per-channel non-saturated ratio (16-row grid) before vs. after imputing.
        fig, ax = plt.subplots(1, 3)
        stages = (('Before Impute', plc_nonsat_ratio),
                  ('After TD Impute', imp_td_nonsat_ratio),
                  ('After SD Impute', imp_sd_nonsat_ratio))
        for stage_ax, (title, ratio) in zip(ax, stages):
            stage_ax.imshow(ratio.reshape(16, -1), vmin=0.95, vmax=1.0)
            stage_ax.set_title(title)
        fig.text(0.5, 0.15, 'vmin = 0.95, vmax = 1.0', ha='center')
        fig.savefig(f'{save_img_dir}/saturation_map.png', bbox_inches='tight')
        plt.close(fig)

        # Optional: waveform for every channel, truncated raw vs. fully imputed
        # (to figure out the appropriate threshold for next processing steps).
        if PLOT_CHANNEL_WAVEFORMS:
            fig, ax = plt.subplots(2, 1, sharex=True)
            for ch in range(NCH):
                ax[0].clear()
                ax[1].clear()
                ax[0].plot(t, rec_data[:, ch])
                ax[1].plot(t, imp_sd_rec_data[:, ch])
                ax[0].set_title(f'Session {session_id}, Channel {ch}\n'
                                f'Non-saturated Points: {plc_nonsat_ratio[ch]*100:.2f}% '
                                f'-> {imp_sd_nonsat_ratio[ch]*100:.2f}%')
                ax[0].set_ylabel('Before Impute')
                ax[1].set_ylabel('After Impute')
                ax[1].set_xlabel('Time (sec)')
                fig.savefig(f'{save_img_dir}/channel_{ch}.png', bbox_inches='tight')
            plt.close(fig)

In [26]:
# Regression check: compare a file written by this notebook against the same
# file written by a previous version of the pipeline.


def outputs_match(fname, ref_dir, new_dir):
    """Return True if `fname` holds the same array in `ref_dir` and `new_dir`.

    The two arrays must have identical shapes and be elementwise close
    (np.allclose with default tolerances); NaNs in matching positions count as
    equal. The shape check prevents np.allclose from silently broadcasting
    arrays of different shapes.
    """
    ref = np.load(os.path.join(ref_dir, fname))
    new = np.load(os.path.join(new_dir, fname))
    return ref.shape == new.shape and bool(np.allclose(ref, new, equal_nan=True))


# Example (uncomment to run; folders are the previous and the new output roots):
# key = SESSION_KEYS[0]
# # fn = f'imp_sd_saturated_{key}.npy'
# # fn = f'imp_sd_channels_session_{key}.npy'
# # fn = f't_session_{key}.npy'
# fn = f'recording_session_{key}.npy'
# outputs_match(fn,
#               f'./recording_preprocessed_v3/1_impute/{key}',
#               f'./recording_preprocessed_v3/1_impute_new/{key}')

True