In [None]:
""" import settings """
%load_ext autoreload
%autoreload 2

import os
import numpy as np
import scipy

from matplotlib import pyplot as plt
from matplotlib_settings import set_plot_settings, reset_plot_settings

# Set the plot settings
# (set_plot_settings is imported from matplotlib_settings above and is called
# here only for its side effect — presumably it configures matplotlib's global
# style; TODO confirm against that module)
set_plot_settings()

# import global variables
# NOTE(review): the cells below use DATA_DIR, SEP_T0/SEP_T1,
# BASELINE_T0/BASELINE_T1, ARTIFACT_T0/ARTIFACT_T1, NCH, N_SITES, STIM_LABELS,
# N_PC, N_REMOVE and N_STITCH without defining them, so they presumably come
# from this star import — confirm against utils_ssep_global.
from utils_ssep_global import *

In [None]:
""" load data, define static params """
# load bad channel
# (index of channels flagged as bad; used below to null those channels out)
bad_ch_idx_dir = f'{DATA_DIR}/1_bad_channels'
bad_chs = np.load(f"{bad_ch_idx_dir}/bad_ch_idx.npy")

# load segmentized data
# (only the time axis `t` is needed here; it defines the index windows below)
seg_data_dir = f'{DATA_DIR}/2_segmentized'
t = np.load(f"{seg_data_dir}/t.npy")

# directory of the '3_car' step; the per-site `*_cmr_segs.npy` files are read
# from here in the processing loop further down
car_data_dir = f'{DATA_DIR}/3_car'

# Each window keeps the sample indices of `t` that lie strictly inside
# (T0, T1); both bounds are exclusive.
# SSEP index
sep_idxs = np.where(np.logical_and(t > SEP_T0, t < SEP_T1))[0]
# Baseline index
baseline_idxs = np.where(np.logical_and(t > BASELINE_T0, t < BASELINE_T1))[0]
# Artifact index
artifact_idxs = np.where(np.logical_and(t > ARTIFACT_T0, t < ARTIFACT_T1))[0]

# save to this directory
save_dir = f'{DATA_DIR}/4_denoise'
# exist_ok=True makes the cell safe to re-run and avoids the
# check-then-create race of `if not os.path.exists(...): os.makedirs(...)`
os.makedirs(save_dir, exist_ok=True)

Try first on a single dataset

In [None]:
def get_pc(n_pc, u, s, v):
    """Reconstruct the rank-1 matrix contributed by one SVD component.

    For a decomposition ``X ~= u @ diag(s) @ v`` this returns
    ``s[n_pc] * outer(u[:, n_pc], v[n_pc, :])``, i.e. the part of ``X``
    explained by component ``n_pc`` alone.

    Parameters
    ----------
    n_pc : int
        Index of the component to reconstruct.
    u : ndarray, shape (n_rows, k)
        Left singular vectors, one per column.
    s : sequence of length k
        Singular values.
    v : ndarray, shape (k, n_cols)
        Right singular vectors, one per row.

    Returns
    -------
    ndarray, shape (n_rows, n_cols)
        The rank-1 reconstruction for component ``n_pc``.

    Raises
    ------
    IndexError
        If ``n_pc`` is out of range for ``u``, ``s`` or ``v``.
    """
    # A singular-value matrix with a single non-zero diagonal entry turns
    # u @ S @ v into an outer product, so there is no need to allocate the
    # dense (k, k) matrix or run two full matrix multiplications.
    component = np.outer(u[:, n_pc] * s[n_pc], v[n_pc, :])

    # The dense formulation multiplied through a float64 matrix; keep the
    # same result dtype so callers see no difference.
    out_dtype = np.result_type(u.dtype, np.float64, v.dtype)
    return component.astype(out_dtype, copy=False)

In [None]:
""" try PC removal """
from sklearn.utils.extmath import randomized_svd
# from numpy.linalg import svd

def remove_artifact_pc(dn_segs, n_components, n_remove, sel_stitch=False, n_stitch=1,
                       random_state=0):
    """Subtract the leading principal components from the artifact window.

    For every trial, the samples selected by the module-level
    ``artifact_idxs`` are zero-meaned per channel, decomposed with a
    randomized SVD, and the sum of the first ``n_remove`` components is
    subtracted from ``dn_segs`` inside that window.

    Parameters
    ----------
    dn_segs : ndarray, indexed as [channel, trial, time]
        Segments to de-noise. **Modified in place** and also returned.
    n_components : int
        Number of components requested from ``randomized_svd``.
    n_remove : int
        Number of leading components to subtract.
    sel_stitch : bool, default False
        If True, remove the step created at the first artifact sample by
        shifting every sample from that index to the end of the trial so
        that the mean of the ``n_stitch`` samples after the boundary
        matches the mean of the ``n_stitch`` samples before it.
    n_stitch : int, default 1
        Number of samples on each side of the boundary used for stitching.
    random_state : int, RandomState instance or None, default 0
        Forwarded to ``randomized_svd`` so that repeated runs give the same
        result. Pass None to use a fresh random projection on each call.

    Returns
    -------
    dn_segs : ndarray
        The same array object that was passed in, de-noised.
    mean_pc_removed : ndarray, shape (n_channels, len(artifact_idxs))
        Removed component averaged over trials.

    Raises
    ------
    ValueError
        If ``sel_stitch`` is True and ``n_stitch`` is smaller than 1.
    """
    if sel_stitch and n_stitch < 1:
        # An empty stitch window would yield NaN means and wipe out the trial.
        raise ValueError(f"n_stitch must be >= 1 when sel_stitch is True, got {n_stitch}")

    # Take the channel count from the data rather than from a global constant.
    n_ch, n_trial = dn_segs.shape[0], dn_segs.shape[1]
    sum_pc_removed = np.zeros((n_ch, len(artifact_idxs)))

    for trial_idx in range(n_trial):
        trial_data = np.copy(dn_segs[:, trial_idx, artifact_idxs])

        # apply zero-mean (per channel, ignoring NaN samples)
        trial_data -= np.nanmean(trial_data, axis=1, keepdims=True)

        # SVD; seeded so the result is reproducible
        u, s, v = randomized_svd(trial_data, n_components=n_components,
                                 random_state=random_state)

        # Sum of the first n_remove rank-1 components, computed directly as a
        # truncated reconstruction instead of building every component first.
        pc_removed = (u[:, :n_remove] * s[:n_remove]) @ v[:n_remove, :]

        sum_pc_removed += pc_removed
        dn_segs[:, trial_idx, artifact_idxs] -= pc_removed

        if sel_stitch:
            idx0 = artifact_idxs[0]
            # Clamp at 0: a negative start would wrap around to the end of the
            # trial and produce an empty (NaN-mean) window.
            pre_start = max(idx0 - n_stitch, 0)
            if pre_start < idx0:  # nothing to stitch to when the window starts at sample 0
                trial_view = dn_segs[:, trial_idx, :]  # basic indexing -> view, edits hit dn_segs
                mu_pre = np.nanmean(trial_view[:, pre_start:idx0], axis=1)
                mu_post = np.nanmean(trial_view[:, idx0:idx0 + n_stitch], axis=1)
                trial_view[:, idx0:] -= (mu_post - mu_pre)[:, np.newaxis]

    return dn_segs, sum_pc_removed/n_trial

Remove Artifact PCs

In [None]:
""" remove artifact PCs for all data """
for stim_site in range(N_SITES):
    fn_label = STIM_LABELS[stim_site].replace(" ", "_").lower()
    cmr_segs = np.load(f"{car_data_dir}/{fn_label}_cmr_segs.npy") # ch * trial *time

    dn_segs = np.copy(cmr_segs) # de-noised segments

    # Record up front which samples must be NaN in the output: samples that
    # are NaN in the input plus every sample of a bad channel.
    invalid_mask = np.isnan(dn_segs)
    invalid_mask[bad_chs] = True

    # Zero-fill them so the SVD inside remove_artifact_pc sees finite data.
    dn_segs[invalid_mask] = 0

    dn_segs, pc_removed = remove_artifact_pc(dn_segs, N_PC, N_REMOVE,
                                             sel_stitch=True, n_stitch=N_STITCH)

    # Restore NaN from the recorded mask. Testing `dn_segs == 0` is not
    # reliable here: PC removal and stitching change the zero-filled samples,
    # and genuine samples that happen to be exactly 0 would be lost.
    dn_segs[invalid_mask] = np.nan

    np.save(f'{save_dir}/{fn_label}_dn_segs.npy',    dn_segs)
    np.save(f'{save_dir}/{fn_label}_pc_removed.npy', pc_removed)