In [None]:
"""
This notebook z-scores each channel's recording using the per-session mean and the all-session standard deviation.
Note: this step has no effect on modeling and could have been skipped altogether,
because each channel's bands are z-scored separately later in the processing pipeline.
"""

In [None]:
# re-import edited helper modules automatically before each cell runs
%load_ext autoreload
%autoreload 2
import os, sys

import numpy as np

from matplotlib import pyplot as plt
from matplotlib_settings import set_plot_settings, reset_plot_settings

# Set the plot settings
set_plot_settings()

# import global variables
# NOTE(review): this star import is expected to provide UTILS_DIR, REC_DIR, N_PC_REMOVE,
# GOOD_SESSIONS, SESSION_KEYS and NCH -- later cells use them without defining them.
from utils_motor_global import *

# UTILS_DIR is appended before the next two imports (presumably they live there),
# so this must stay after the star import and before `from utils_impute ...`.
sys.path.append(UTILS_DIR)
from utils_impute import recover_rec_chs
from utils_motor_sigproc import compute_overall_std_dev

# Output root of this notebook; per-session results are written to {ROOT_SAVE_DIR}/<key>/
ROOT_SAVE_DIR = f'{REC_DIR}/4_zscore'

In [None]:
""" Step 1. recover Channels """
# For every session key that has output from the `3_HD_remove_{N_PC_REMOVE}_PCs` stage,
# recover (impute) channels with `recover_rec_chs` and save the result.
#   input : {REC_DIR}/3_HD_remove_{N_PC_REMOVE}_PCs/<key>/  recording is time * channel
#   output: {ROOT_SAVE_DIR}/<key>/                          recording is channel * time
for session in GOOD_SESSIONS:
    # keys start with the zero-padded session number (assumes `session` is an int)
    keys = [key for key in SESSION_KEYS if key.startswith(f'{session:003}')]

    for key in keys:
        load_dir = f'{REC_DIR}/3_HD_remove_{N_PC_REMOVE}_PCs/{key}'
        if not os.path.exists(load_dir):
            # previous stage produced nothing for this key
            continue

        print(f'Processing session {key}..')

        save_data_dir = f'{ROOT_SAVE_DIR}/{key}'
        save_img_dir = f'{ROOT_SAVE_DIR}_imgs/{key}'
        # exist_ok replaces the racy "check, then create" pattern
        os.makedirs(save_data_dir, exist_ok=True)
        os.makedirs(save_img_dir, exist_ok=True)

        """ load data """
        good_chs        = np.load(f'{load_dir}/good_channels_{key}.npy')
        t               = np.load(f'{load_dir}/t_HB_removed_session_{key}.npy')
        rec_data        = np.load(f'{load_dir}/recording_HB_removed_session_{key}.npy') # time * channel

        """ recover channels """
        # recov_data: (channel * time). transposed, finally.
        # recov_chs lists the channels that recov_data's rows correspond to.
        recov_data, recov_chs, _ = recover_rec_chs(rec_data.T, good_chs)

        """ save """
        # carry over the time axis unchanged
        np.save(f'{save_data_dir}/t_session_{key}.npy', t)

        # recovered channel list replaces the original good-channel list downstream
        np.save(f'{save_data_dir}/good_channels_{key}.npy', recov_chs)
        np.save(f'{save_data_dir}/recording_session_{key}.npy', recov_data)

        """ save plot: imputed channels """
        # Disabled: uncomment to write a before/after good-channel map to save_img_dir.
        # fig, ax = plt.subplots(1, 2)

        # good_chs_padded, recov_chs_padded = np.zeros((256,)), np.zeros((256,))
        # good_chs_padded[good_chs] = 1
        # recov_chs_padded[recov_chs] = 1

        # ax[0].imshow(good_chs_padded.reshape(16, -1), vmin=0, vmax=1)
        # ax[1].imshow(recov_chs_padded.reshape(16, -1), vmin=0, vmax=1)

        # ax[0].set_title(f'Before: {len(good_chs)}')
        # ax[1].set_title(f'After: {len(recov_chs)}')

        # plt.savefig(f'{save_img_dir}/good_channel_map.png', bbox_inches='tight')
        # plt.close(fig)

In [None]:
""" write assertion check """
# Regression check: Step 1 output vs. a backup run kept in `<ROOT_SAVE_DIR>_bak`.
# (These paths equal the previously hard-coded './recording_preprocessed_v3/...' ones
# when REC_DIR == './recording_preprocessed_v3'.)
# Swap in one of the commented filename pairs to check the other saved arrays.
n_checked = 0
for key in SESSION_KEYS:
    # fn0 = f't_session_{key}.npy'
    # fn1 = f't_session_{key}.npy'

    # fn0 = f'sd_imp_good_channels_{key}.npy'
    # fn1 = f'good_channels_{key}.npy'

    fn0 = f'sd_imp_recording_session_{key}.npy'
    fn1 = f'recording_session_{key}.npy'

    path0 = f'{ROOT_SAVE_DIR}_bak/{key}/{fn0}'
    path1 = f'{ROOT_SAVE_DIR}/{key}/{fn1}'

    # Step 1 only writes keys of GOOD_SESSIONS whose input dir exists, so not every
    # key in SESSION_KEYS has output; skip instead of crashing in np.load.
    if not (os.path.exists(path0) and os.path.exists(path1)):
        print(f'Skipping {key}: missing {path0} or {path1}')
        continue

    X0 = np.load(path0)
    X1 = np.load(path1)

    # Same tolerances as np.allclose(X0, X1), but shapes must match exactly
    # (np.allclose broadcasts) and a failure names the offending key.
    np.testing.assert_allclose(X0, X1, rtol=1e-05, atol=1e-08, equal_nan=True,
                               err_msg=f'mismatch for session {key}')
    n_checked += 1

print(f'{n_checked} session(s) match the backup')

In [None]:
""" Step 2. compute per-session std dev. """
# goodch_tracker[ch] stays 1 only if ch is a good channel in every processed session key.
# Sized by NCH (not a literal 256) so it agrees with all_session_sigmas in Step 2.1.
goodch_tracker = np.ones(NCH, dtype=int)

session_sigmas  = []    # one (NCH,) row per session key; NaN where the channel is not good
session_lens    = []    # number of time samples per session key

for session in GOOD_SESSIONS:
    keys = [key for key in SESSION_KEYS if key.startswith(f'{session:003}')]

    for key in keys:
        load_dir = f'{ROOT_SAVE_DIR}/{key}'
        if not os.path.exists(load_dir):
            continue

        """ load data """
        good_chs        = np.load(f'{load_dir}/good_channels_{key}.npy')
        # channel * time; this code assumes row i belongs to channel good_chs[i]
        rec_data        = np.load(f'{load_dir}/recording_session_{key}.npy')

        """ remove bad channel """
        # vectorized equivalent of `if ch not in good_chs: goodch_tracker[ch] = 0`
        goodch_tracker[~np.isin(np.arange(NCH), good_chs)] = 0

        """ compute std dev. """
        sigmas = np.full((NCH, ), np.nan)
        sigmas[good_chs] = np.std(rec_data, axis=1)

        session_sigmas.append(sigmas)
        session_lens.append(rec_data.shape[1])

session_sigmas  = np.array(session_sigmas)   # (n_session_keys, NCH)
session_lens    = np.array(session_lens)

common_chs = np.where(goodch_tracker == 1)[0]

In [None]:
""" Step 2.1 compute overall channel std devs """
# Collapse each channel's per-session sigmas into a single all-session sigma using
# compute_overall_std_dev, which also receives the per-session lengths.
# Channels without a column in session_sigmas keep the NaN fill value.
all_session_sigmas = np.full((NCH,), np.nan)

sigmas_by_channel = np.transpose(session_sigmas)
for ch_idx in range(len(sigmas_by_channel)):
    all_session_sigmas[ch_idx] = compute_overall_std_dev(sigmas_by_channel[ch_idx], session_lens)

In [None]:
# X = np.full((256,), np.nan)
# X[common_chs] = all_session_sigmas[common_chs]
# plt.imshow(X.reshape(16, -1))

In [None]:
""" save common channels, channel std devs """
# Persist the Step 2 results; channel_sigmas.npy is reloaded by Step 3.
outputs = {
    f'{ROOT_SAVE_DIR}/common_good_channels.npy': common_chs,
    f'{ROOT_SAVE_DIR}/channel_sigmas.npy': all_session_sigmas,
}
for out_path, out_arr in outputs.items():
    np.save(out_path, out_arr)

In [None]:
""" step 3. z-score using per session mean and all session std dev. """
# > 300uV points can be labeled later by inverse-normalizing
# (the per-session mean is not saved here; recording_session_<key>.npy keeps the raw data)
ch_sigmas       = np.load(f'{ROOT_SAVE_DIR}/channel_sigmas.npy')   # (NCH,), saved after Step 2.1

for session in GOOD_SESSIONS:
    keys = [key for key in SESSION_KEYS if key.startswith(f'{session:003}')]

    for key in keys:
        load_dir = f'{ROOT_SAVE_DIR}/{key}'
        if not os.path.exists(load_dir):
            continue

        """ load data """
        good_chs        = np.load(f'{load_dir}/good_channels_{key}.npy')
        rec_data        = np.load(f'{load_dir}/recording_session_{key}.npy')   # channel * time

        # Row i of rec_data belongs to channel good_chs[i]. Pairing them with zip() would
        # silently leave rows un-normalized if the lengths ever disagreed, so check explicitly.
        if rec_data.shape[0] != len(good_chs):
            raise ValueError(
                f'{key}: {rec_data.shape[0]} recording rows but {len(good_chs)} good channels')

        """ z-score """
        # Done in place so rec_data keeps its on-disk dtype.
        # A channel whose pooled sigma is NaN becomes all-NaN.
        rec_data -= rec_data.mean(axis=1, keepdims=True)
        rec_data /= ch_sigmas[good_chs][:, np.newaxis]

        """ save """
        np.save(f'{load_dir}/zscored_recording_session_{key}.npy', rec_data)

In [None]:
""" write assertion check """
# Regression check: Step 3 output vs. a backup run kept in `<ROOT_SAVE_DIR>_bak`.
# (These paths equal the previously hard-coded './recording_preprocessed_v3/...' ones
# when REC_DIR == './recording_preprocessed_v3'.)
n_checked = 0
for key in SESSION_KEYS:
    fn = f'zscored_recording_session_{key}.npy'

    path0 = f'{ROOT_SAVE_DIR}_bak/{key}/{fn}'
    path1 = f'{ROOT_SAVE_DIR}/{key}/{fn}'

    # Only keys processed in Step 3 have output; skip the rest instead of crashing in np.load.
    if not (os.path.exists(path0) and os.path.exists(path1)):
        print(f'Skipping {key}: missing {path0} or {path1}')
        continue

    X0 = np.load(path0)
    X1 = np.load(path1)

    # Same tolerances as np.allclose(X0, X1), plus an exact shape check
    # and the offending key in the failure message.
    np.testing.assert_allclose(X0, X1, rtol=1e-05, atol=1e-08, equal_nan=True,
                               err_msg=f'mismatch for session {key}')
    n_checked += 1

print(f'{n_checked} session(s) match the backup')

In [None]:
""" sample plot zscore before vs after """
# Disabled: uncomment to plot one channel of one session before (top axis) and
# after (bottom axis) the Step 3 z-scoring. Both recordings are channel * time with
# rows in good_chs order, so the channel index `ch` is first mapped to its row `ch_idx`.
# session = 10
# keys = [key for key in SESSION_KEYS if key.startswith(f'{session:003}')]

# key = keys[0]
# load_dir = f'{ROOT_SAVE_DIR}/{key}'

# """ load data """
# good_chs        = np.load(f'{load_dir}/good_channels_{key}.npy')
# t               = np.load(f'{load_dir}/t_session_{key}.npy')
# rec_data        = np.load(f'{load_dir}/recording_session_{key}.npy') 
# zs_rec_data     = np.load(f'{load_dir}/zscored_recording_session_{key}.npy') 

# ch = 16*1 + 13
# assert ch in good_chs
# ch_idx = np.where(good_chs == ch)[0][0]

# fig, ax = plt.subplots(2, 1)

# ax[0].plot(t, rec_data[ch_idx, :])
# ax[1].plot(t, zs_rec_data[ch_idx, :])