In [None]:
""" import settings """
%load_ext autoreload
%autoreload 2

import os
import numpy as np
import scipy

from matplotlib import pyplot as plt
from matplotlib_settings import set_plot_settings, reset_plot_settings

# Set the plot settings
# (set_plot_settings comes from the project-local matplotlib_settings module;
# presumably it configures matplotlib defaults -- confirm in that module.)
set_plot_settings()

# import global variables
# The cells below use names that are not defined anywhere in this notebook and
# are therefore expected to be provided by this star import: DATA_DIR,
# SEP_T0/SEP_T1, BASELINE_T0/BASELINE_T1, DS_T0/DS_T1, FS, RS, RS1, RS2,
# N_SITES, NCH and STIM_LABELS.
from utils_ssep_global import *

In [None]:
""" load data, define static params """
def window_idxs(time, t_start, t_stop):
    """Return the indices of `time` that lie strictly inside (t_start, t_stop).

    Both bounds are exclusive, i.e. a sample is selected when
    ``t_start < time[i] < t_stop``.
    """
    return np.where(np.logical_and(time > t_start, time < t_stop))[0]


# Bad-channel indices, read from the 1_bad_channels directory.
bad_ch_idx_dir = f'{DATA_DIR}/1_bad_channels'
bad_chs = np.load(f"{bad_ch_idx_dir}/bad_ch_idx.npy")

# Time axis stored alongside the segmentized data.
seg_data_dir = f'{DATA_DIR}/2_segmentized'
t = np.load(f"{seg_data_dir}/t.npy")

# Directories of the two preceding pipeline outputs. Only the denoised
# segments (dn_data_dir) are read further down in this notebook.
car_data_dir = f'{DATA_DIR}/3_car'
dn_data_dir = f'{DATA_DIR}/4_denoise'

# All three index arrays refer to positions in the full-length time axis `t`
# loaded above.
# SSEP index
sep_idxs = window_idxs(t, SEP_T0, SEP_T1)
# Baseline index
baseline_idxs = window_idxs(t, BASELINE_T0, BASELINE_T1)
# Downsampled index. Use a smaller window from this point of the pipeline
ds_idxs = window_idxs(t, DS_T0, DS_T1)

# Output directory for this step. exist_ok=True makes the cell safe to re-run
# and avoids the check-then-create race of testing os.path.exists first.
save_dir = f'{DATA_DIR}/5_downsample'
os.makedirs(save_dir, exist_ok=True)

Downsample all data sets

In [None]:
# Build the windowed time axis from the file on disk instead of slicing `t`
# in place: `t = t[ds_idxs]` overwrote its own input, so running this cell a
# second time re-indexed the already-trimmed array with indices that were
# computed for the full-length one.
# After this cell `t` covers only the downsample window; sep_idxs and
# baseline_idxs still refer to positions in the full-length axis.
t = np.load(f"{seg_data_dir}/t.npy")[ds_idxs]

# Time stamps of the samples kept after downsampling: every RS-th sample,
# starting at the first one in the window.
t_RS = t[::RS]
print(f'Downsampling to {FS/RS:.1f} Hz')

In [None]:
from tqdm.notebook import tqdm
from scipy.signal import decimate


def two_stage_decimate(segs, q1, q2):
    """Decimate `segs` along its last axis by q1 and then by q2.

    Both stages use scipy.signal.decimate with zero_phase=True and its default
    filter settings. `segs` must not contain NaN: the filters would spread a
    single NaN over the whole row.

    Parameters
    ----------
    segs : ndarray
        Data with time on the last axis; every leading index is filtered
        independently.
    q1, q2 : int
        Downsampling factors of the first and second stage.

    Returns
    -------
    ndarray
        The decimated data, same leading shape as `segs`.
    """
    first_stage = decimate(segs, q=q1, zero_phase=True)
    return decimate(first_stage, q=q2, zero_phase=True)


# The time axis (t[::RS]), the preallocated segment length and the output file
# names all use RS, while the data itself is reduced by RS1 and then RS2.
# Fail early with a clear message if the constants disagree.
if RS1 * RS2 != RS:
    raise ValueError(f'RS ({RS}) must equal RS1 * RS2 ({RS1} * {RS2})')

for stim_site in range(N_SITES):
    fn_label = STIM_LABELS[stim_site].replace(" ", "_").lower()

    # dimensions: ch * trial * time, restricted to the downsample window
    full_segs = np.load(f"{dn_data_dir}/{fn_label}_dn_segs.npy")[:, :, ds_idxs]

    num_trials = full_segs.shape[1]
    ds_seg_len = int(np.ceil(full_segs.shape[2] / RS))

    # Trials that contain NaN are not decimated and stay all-NaN in the output.
    ds_segs = np.full((NCH, num_trials, ds_seg_len), np.nan)
    # Recount, per channel, the trials that were actually decimated.
    valid_trial_count = np.zeros((NCH,))

    for ch, ch_data in enumerate(full_segs):
        # A trial is valid when none of its samples is NaN.
        valid = ~np.isnan(ch_data).any(axis=1)
        if valid.any():
            # Decimate all valid trials of this channel in one call instead of
            # designing and applying the filters once per trial.
            ds_segs[ch, valid] = two_stage_decimate(ch_data[valid], RS1, RS2)
        valid_trial_count[ch] = valid.sum()

    np.save(f'{save_dir}/{fn_label}_ds_{RS}_segs.npy', ds_segs)
    np.save(f'{save_dir}/{fn_label}_valid_trial_count.npy', valid_trial_count)

np.save(f'{save_dir}/t_{RS}.npy', t_RS)