In [None]:
""" import settings """
%load_ext autoreload
%autoreload 2

# ensure we can see bisc python files if not installed into system
import os, sys
sys.path.append('C:\\Users\\kraus\\Desktop\\BISC_software')
sys.path.append('C:\\Users\\kraus\\Desktop\\BISC_software\\measurements\\pig3')

from utils_mp import read_recdata
from utils_analysis import find_matching_recording

import numpy as np
from matplotlib import pyplot as plt
from matplotlib_settings import set_plot_settings, reset_plot_settings

# Set the plot settings
set_plot_settings()

# import global variables
from utils_ssep_global import *

In [None]:
""" find all baseline recording files """
# NOTE(review): log_dir / log_fname are not used in this cell -- the log is read
# from RAW_DATA_DIR / LOG_FNAME instead. They are kept only in case another
# cell still refers to them.
log_dir = "./recording_raw"
log_fname = "record_log_20231017_cleaned.csv"

# criteria handed to find_matching_recording (presumably matched against the
# columns of the recording log)
criteria = {
    'stim location': 'Baseline',
    'en_static_electrode': 1,
    'vga_gain': VGA_GAIN,
}

rec_fnames, infos = find_matching_recording(f'{RAW_DATA_DIR}/{LOG_FNAME}', criteria)
print(len(rec_fnames))
# Fail fast: with no matching files the later cells still run but end up with
# an all-NaN RMS and flag every channel as bad, which is easy to miss.
assert len(rec_fnames) > 0, f'no recordings in {LOG_FNAME} match {criteria}'

# keep only the file name; unlike split("/"), os.path.basename also handles
# '\\' separators on Windows (and is identical to it on POSIX)
rec_fnames = [os.path.basename(rec_fname) for rec_fname in rec_fnames]

Go through the baseline recordings and identify faulty (dead, noisy, or saturated) channels.

In [None]:
""" step 1. load all baseline recordings """
# Drop the first T_SETTLE*FS samples of every file (presumably the front-end
# settling period; assumes T_SETTLE in seconds and FS in Hz -- confirm).
i_settle = int(T_SETTLE*FS)

# raw_data[ch] collects one 1-D segment per recording file for channel ch
raw_data = [[] for _ in range(NCH)]
for fname in rec_fnames:
    values, _pulses, _ts, _nr, _nc = read_recdata(f'{RAW_DATA_DIR}/{fname}', use_tetrode=False)
    settled = values[i_settle:]
    for ch, ch_segments in enumerate(raw_data):
        ch_segments.append(settled[:, ch])

# every file must give the same number of samples per channel so the
# segments can be stacked into a rectangular array later on
for ch_segments in raw_data:
    assert len({len(seg) for seg in ch_segments}) <= 1

In [None]:
""" step 2. filter """
# Only a low-pass (LFP) filter is applied; the earlier BPF / HPF variants are deprecated.
from scipy.signal import iirfilter, sosfiltfilt

lpf_data = [[] for _ in range(NCH)]

# 8th-order Butterworth low-pass, built as second-order sections for numerical
# stability; Wn is the cut-off normalised to the Nyquist frequency FS/2.
f_cut0 = 300  # LFP cut-off (same units as FS, presumably Hz)
lpf_sos = iirfilter(N=8, Wn=f_cut0/(FS/2), btype='low', ftype='butter', output='sos')

for ch, ch_data in enumerate(raw_data):
    ch_lpf = []
    for y_raw in ch_data:
        # "flatten out" saturated values. Clip a copy: the segments stored in
        # raw_data are slices (views) of the loaded recordings, and clipping
        # them in place would silently change raw_data for every later cell.
        y = y_raw.copy()
        y[y <= LOW_CUT] = LOW_CUT
        y[y >= HIGH_CUT] = HIGH_CUT

        if np.std(y) == 0:
            # constant after clipping (e.g. fully saturated): emit zeros instead of filtering
            ch_lpf.append(np.zeros(y.shape))
        else:
            # sosfiltfilt runs the filter forward and backward -> zero phase shift
            ch_lpf.append(sosfiltfilt(lpf_sos, y))

    lpf_data[ch] = np.array(ch_lpf)

lpf_data = np.array(lpf_data)  # channel * files * time

In [None]:
def compute_overall_std_dev(sigmas):
    """Pool per-recording standard deviations into one overall value.

    The sigmas are combined as the root of their mean square. For segments of
    equal length this is the pooled within-segment standard deviation. NaN
    entries are skipped, so callers can use NaN to mark unusable segments.

    Parameters
    ----------
    sigmas : array_like of float
        Per-recording standard deviations; NaN entries are ignored.

    Returns
    -------
    float
        sqrt(mean(sigma**2)) over the non-NaN entries, or NaN when there are
        no non-NaN entries (all-NaN or empty input).
    """
    sigmas = np.asarray(sigmas, dtype=float)
    # np.alltrue was removed in NumPy 2.0; selecting the valid entries also
    # covers the all-NaN and the empty case with a single check.
    valid = sigmas[~np.isnan(sigmas)]
    if valid.size == 0:
        return np.nan
    return np.sqrt(np.mean(valid**2))

In [None]:
""" Step 3. Compute Channel RMS """
# Spread of every low-passed segment, one value per (channel, file).
# np.std removes the segment mean, so this is the RMS about the mean.
# A std of exactly 0 (e.g. the zero-filled, fully saturated segments from
# step 2) is stored as NaN so it is skipped when pooling below.
n_files = lpf_data.shape[1]
lpf_sigmas = np.zeros((NCH, n_files))
for ch_idx, ch_segments in enumerate(lpf_data):
    for file_idx, segment in enumerate(ch_segments):
        seg_std = np.std(segment)
        lpf_sigmas[ch_idx, file_idx] = np.nan if seg_std == 0 else seg_std

# pool the per-file values into one number per channel (NaN if every file was NaN)
S = np.array([compute_overall_std_dev(ch_sigmas) for ch_sigmas in lpf_sigmas])

In [None]:
""" plot RMS map """
# Channels are shown on a 16-row grid.
# NOTE(review): presumably this mirrors the physical electrode layout -- confirm.
fig, ax = plt.subplots(1, 1, figsize=(3, 3))
fig.suptitle('Channel RMS')

im0 = ax.imshow(S.reshape(16, -1))
# add a colour scale so absolute RMS levels can be read off the map
cbar = fig.colorbar(im0, ax=ax)

In [None]:
""" Apply Heuristic Thresholding """
# NOTE(review): thresh0 / thresh1 are not read anywhere in this cell; the
# cut-offs actually applied are DEAD_THRESH and NOISY_THRESH.
thresh0, thresh1 = 2.5, 3

# Flag channels relative to the mean RMS across the array (NaNs ignored):
#   dead      : S below mean / DEAD_THRESH
#   noisy     : S above mean * NOISY_THRESH
#   saturated : S is NaN (no usable segment); NaN compares False in both tests above
dead_channels = S < np.nanmean(S)/DEAD_THRESH
noisy_channels = S > np.nanmean(S)*NOISY_THRESH
saturated_channels = np.isnan(S)

# one panel per flag, each reshaped to a 16-row grid
panels = (('Dead', dead_channels),
          ('Noisy', noisy_channels),
          ('Saturated', saturated_channels))

fig, ax = plt.subplots(1, 3)
for panel_ax, (title, flags) in zip(ax, panels):
    panel_ax.imshow(flags.reshape(16, -1))
    panel_ax.set_title(title)
    panel_ax.set_xticks([])
    panel_ax.set_yticks([])

In [None]:
""" Step 4. save bad channel idx """
# a channel is bad if it was flagged dead, noisy or saturated
bad_ch_mask = dead_channels | noisy_channels | saturated_channels
bad_ch_idx = np.flatnonzero(bad_ch_mask)
print(f'bad channels: {bad_ch_idx}')

save_dir = f'{DATA_DIR}/1_bad_channels'
# exist_ok avoids the exists()/makedirs() check-then-act race and keeps re-runs idempotent
os.makedirs(save_dir, exist_ok=True)

np.save(f'{save_dir}/bad_ch_idx.npy', bad_ch_idx)