In [None]:
""" import settings """
# re-import edited project modules automatically before each cell executes
%load_ext autoreload
%autoreload 2

import os
import numpy as np
import scipy

from matplotlib import pyplot as plt
from matplotlib_settings import set_plot_settings, reset_plot_settings

# Set the plot settings (project helper imported from matplotlib_settings)
set_plot_settings()

# import global variables
# NOTE(review): the cells below read DATA_DIR, SEP_T0, SEP_T1, BASELINE_T0,
# BASELINE_T1, N_SITES, STIM_LABELS, STABLE_PROP_THRESH, NCH and N_CAR_CH from
# this module; an explicit import list would make those dependencies visible.
from utils_ssep_global import *

In [None]:
""" load data, define static params """
# load the indices of channels flagged as bad in the previous stage
bad_ch_idx_dir = f'{DATA_DIR}/1_bad_channels'
bad_chs = np.load(f"{bad_ch_idx_dir}/bad_ch_idx.npy")

# load the time axis shared by all segmentized data
seg_data_dir = f'{DATA_DIR}/2_segmentized'
t = np.load(f"{seg_data_dir}/t.npy")

# SSEP index: samples strictly inside the open window (SEP_T0, SEP_T1)
sep_idxs = np.where(np.logical_and(t > SEP_T0, t < SEP_T1))[0]
# Baseline index: samples strictly inside the open window (BASELINE_T0, BASELINE_T1)
baseline_idxs = np.where(np.logical_and(t > BASELINE_T0, t < BASELINE_T1))[0]

# An empty SSEP window would make every per-trial RMS below NaN without any error.
assert sep_idxs.size > 0, f"SSEP window ({SEP_T0}, {SEP_T1}) selects no samples of t"

# save to this directory; exist_ok avoids the separate exists-check (and its race)
save_dir = f'{DATA_DIR}/3_car'
os.makedirs(save_dir, exist_ok=True)

Select the channels used to compute the common average reference (CAR)

In [None]:
""" First, exclude unstable channels """
# fetch, for every stim site, the number of valid trials of each channel
valid_trial_counts = []
for stim_site in range(N_SITES):
    fn_label = STIM_LABELS[stim_site].replace(" ", "_").lower()
    vtc = np.load(f"{seg_data_dir}/{fn_label}_valid_trial_count.npy")
    valid_trial_counts.append(vtc)

# Channels with not enough valid trials at any stim site. Exclusive with bad_chs.
# A channel is unstable when its count is below STABLE_PROP_THRESH times the
# largest per-channel count of that stim site.
unstable_chs = []
flagged = set()  # O(1) membership; unstable_chs keeps discovery order
for vtc in valid_trial_counts:
    n_total = np.max(vtc)
    for ch, count in enumerate(vtc):
        if ch in flagged or ch in bad_chs:
            continue
        if count < n_total * STABLE_PROP_THRESH:
            unstable_chs.append(ch)
            flagged.add(ch)

# dtype=int keeps an empty result usable as an index array (np.array([]) is float64)
unstable_chs = np.array(unstable_chs, dtype=int)

# plot
# fig, ax = plt.subplots(figsize=(3,3))
# z = np.zeros((NCH,))
# z[unstable_chs] = 1
# ax.imshow(z.reshape(16, -1))
# ax.set_title('unstable channels')

In [None]:
""" for each SSEP, fetch channel RMS """
# Channels that are never candidates for the reference: every trial gets NaN.
# Built once here instead of being re-checked for every trial.
excluded_chs = set(np.ravel(bad_chs).tolist()) | set(np.ravel(unstable_chs).tolist())

# sep_rms[ch] collects one value per trial, concatenated across stim sites in
# STIM_LABELS order.
sep_rms = [[] for _ in range(NCH)]
for stim_site in range(N_SITES):
    fn_label = STIM_LABELS[stim_site].replace(" ", "_").lower()
    raw_segs = np.load(f"{seg_data_dir}/{fn_label}_raw_segs.npy")

    for ch, ch_data in enumerate(raw_segs):
        if ch in excluded_chs:
            sep_rms[ch].extend([np.nan] * len(ch_data))
            continue
        for trial_data in ch_data:
            if np.isnan(trial_data[0]):
                # a NaN first sample marks the trial as invalid
                sep_rms[ch].append(np.nan)
            else:
                # np.std == RMS of the mean-removed signal inside the SSEP window
                sep_rms[ch].append(np.std(trial_data[sep_idxs]))

sep_rms = np.array(sep_rms)
np.save(f'{save_dir}/sep_rms.npy', sep_rms)

In [None]:
""" find channels with lowest RMS """
import warnings

# To reuse the saved result instead of re-running the previous cell:
# sep_rms = np.load(f'{save_dir}/sep_rms.npy')

# Mean RMS per channel over all trials of all stim sites; NaN trials are ignored.
# Excluded (bad/unstable) channels are all-NaN rows and yield NaN, which would
# otherwise raise a "Mean of empty slice" RuntimeWarning for each of them.
with warnings.catch_warnings():
    warnings.simplefilter("ignore", category=RuntimeWarning)
    ch_rms = np.nanmean(sep_rms, axis=1)

# Only channels with a finite RMS may enter the common reference; otherwise
# argsort would silently append NaN (bad/unstable) channels at the end.
eligible_chs = np.flatnonzero(np.isfinite(ch_rms))
if eligible_chs.size < N_CAR_CH:
    warnings.warn(
        f"only {eligible_chs.size} channels have a finite RMS; "
        f"using all of them instead of N_CAR_CH={N_CAR_CH}"
    )
car_chs = eligible_chs[np.argsort(ch_rms[eligible_chs])][:N_CAR_CH]
np.save(f'{save_dir}/car_chs.npy', car_chs)

print(car_chs)

# plot
# fig, ax = plt.subplots(1, 2)
# z = np.zeros((NCH,))
# z[car_chs] = 1
# ax[0].imshow(ch_rms.reshape(16, -1))
# ax[1].imshow(z.reshape(16, -1))

# ax[0].set_title('RMS')
# ax[1].set_title('CAR Channels')

Re-reference all data to the common average of the selected CAR channels

In [None]:
""" Re-reference all data """
def rereference(segs, car_chs):
    """Subtract, trial by trial, the NaN-aware mean of ``car_chs`` from every channel.

    Parameters
    ----------
    segs : ndarray
        Segmented data indexed as (channel, trial, time).
    car_chs : array of int
        Channel indices averaged (ignoring NaN) to form the reference.

    Returns
    -------
    cmr_segs : ndarray
        New array with the same shape as ``segs``. Channel/trial pairs whose
        first sample is NaN are copied unchanged. ``segs`` itself is not modified.
    cars : ndarray
        Reference signal indexed as (trial, time).
    """
    cars = np.nanmean(segs[car_chs], axis=0)
    # a NaN first sample marks an invalid channel/trial; leave it as is
    valid = ~np.isnan(segs[:, :, 0])
    cmr_segs = np.where(valid[:, :, np.newaxis], segs - cars[np.newaxis], segs)
    return cmr_segs, cars


for stim_site in range(N_SITES):
    fn_label = STIM_LABELS[stim_site].replace(" ", "_").lower()
    raw_segs = np.load(f"{seg_data_dir}/{fn_label}_raw_segs.npy")

    # raw_segs stays untouched (no in-place writes through a transposed view)
    cmr_segs, cars = rereference(raw_segs, car_chs)
    # reference waveform averaged over trials, ignoring NaN trials
    mean_car = np.nanmean(cars, axis=0)

    np.save(f'{save_dir}/{fn_label}_mean_car.npy', mean_car)
    np.save(f'{save_dir}/{fn_label}_cmr_segs.npy', cmr_segs)