In [None]:
""" import settings """
%load_ext autoreload
%autoreload 2

import os, sys

# Root of the BISC software checkout that holds the project-local helper modules
# imported below. Defaults to the original workstation path; set the
# BISC_SOFTWARE_DIR environment variable to run the notebook on another machine.
BISC_ROOT = os.environ.get('BISC_SOFTWARE_DIR', 'C:\\Users\\kraus\\Desktop\\BISC_software')

# Append only when missing so re-running this cell does not keep growing sys.path.
for _module_dir in (BISC_ROOT, os.path.join(BISC_ROOT, 'measurements', 'pig3')):
    if _module_dir not in sys.path:
        sys.path.append(_module_dir)

from utils_mp import read_recdata
from utils_analysis import recover_pulses, get_pulse_edges, find_matching_recording

import numpy as np
import scipy
import scipy.signal

from matplotlib import pyplot as plt
from matplotlib_settings import set_plot_settings, reset_plot_settings

# Set the plot settings
# (set_plot_settings comes from the project-local matplotlib_settings module; what it
# changes is not visible from this notebook)
set_plot_settings()

# import global variables
# NOTE(review): the star import hides where names come from. The upper-case constants
# used in later cells (e.g. DATA_DIR, RAW_DATA_DIR, FS, TS, NCH, STIM_LABELS) are not
# defined anywhere in this notebook, so they presumably come from utils_ssep_global --
# confirm, and consider importing them explicitly.
from utils_ssep_global import *

In [None]:
""" Load bad channel index """ 
# folder holding the saved bad-channel indices
bad_ch_idx_dir = f'{DATA_DIR}/1_bad_channels'
bad_ch_idx_path = f"{bad_ch_idx_dir}/bad_ch_idx.npy"
bad_chs = np.load(bad_ch_idx_path)

In [None]:
""" define segment length """
# SSEP technical params and segmentization params
# Segments cover almost a full period for now; they are chopped up further down the
# pipeline -> window changed to [-250, 250] msec
# tic, toc = -1*((1/STIM_RATE)-5e-3), (1/STIM_RATE)-5e-3 
tic, toc = FULL_T0, FULL_T1

# window edges converted to sample counts; the segment length is their difference
i_tic = int(tic*FS)
i_toc = int(toc*FS)
len_seg = i_toc - i_tic

# sync delay converted to a sample count
i_sync_delay = int(T_SYNC_DELAY*FS)

# local field potential BPF
# I ended up not using filtered data down the pipeline.
# The biggest reason is that low-pass filtering could not be applied to the data prior to
# segmentization because of saturation issues. Note that motion artifacts and vet techs
# stepping on the cables can kick the channels in and out of saturation.
# It is also impractical to apply reasonable high-pass filtering to segmentized data.
# For example, a 1-Hz cut-off can't be applied when the segment itself is shorter than a second.
# Instead of a BPF, I relied on a combination of linear detrending and low-pass filtering.

# band = [1, 300]
# Wn = [e / FS* 2 for e in band]
# lfp_filter_coeff = scipy.signal.iirfilter(N=4, Wn=Wn, analog=False, btype='bandpass', 
#                                           ftype='butter', output='sos')

In [None]:
# save to this directory
save_dir = f'{DATA_DIR}/2_segmentized'
# exist_ok=True keeps the cell safe to re-run and avoids the check-then-create race
os.makedirs(save_dir, exist_ok=True)

In [None]:
""" manually pick a stimulation site for single file inspection """
stim_site = 0 # pick from 0, 1, 2, 3, 4

# site label, plus a file-name-friendly form of it (spaces -> underscores, lower case)
stim_label = STIM_LABELS[stim_site]
fn_label = stim_label.replace(" ", "_").lower()
print(stim_label)

In [None]:
""" find all recording files corresponding to the given peripheral stim site """
en_static_electrode = 1
criteria = {
    'en_static_electrode': en_static_electrode,
    'vga_gain': VGA_GAIN,
    'stim location': stim_label,
}

# Some sites have more files than others because (I think) some recordings were taken
# without turning on the sync pulse. The number of valid recordings comes out more or
# less even across all sites.
rec_fnames, infos = find_matching_recording(f'{RAW_DATA_DIR}/{LOG_FNAME}', criteria)
print(len(rec_fnames))

# keep only the file name (the text after the last "/"), updating the list in place
rec_fnames[:] = [rec_fname.split("/")[-1] for rec_fname in rec_fnames]

In [None]:
""" Segmentize all SSEP recordings and save them """

# Search criteria shared by every stimulation site; the site label is filled in per loop.
en_static_electrode = 1
criteria = {
    'en_static_electrode': en_static_electrode,
    'vga_gain': VGA_GAIN,
}

# Segment window in samples, relative to each stimulation sample. These values do not
# change inside the loops, so compute them once.
i_tic, i_toc = int(tic*FS), int(toc*FS)

for stim_site in range(N_SITES):
    # find all recording files corresponding to the given peripheral stim site
    stim_label = STIM_LABELS[stim_site]
    criteria['stim location'] = stim_label

    rec_fnames, infos = find_matching_recording(f'{RAW_DATA_DIR}/{LOG_FNAME}', criteria)
    print(f'Identified {len(rec_fnames)} files for {stim_label} stimulation')

    # keep only the file name (the text after the last "/")
    rec_fnames = [rec_fname.split("/")[-1] for rec_fname in rec_fnames]

    # initialize
    valid_trial_count = np.zeros((NCH,)) # number of valid (non-saturated) trials per channel
    seg_lists = [[] for _ in range(NCH)]  # per-channel list of unfiltered segments

    for rec_fname in rec_fnames:
        values, pulses, ts, nr, nc = read_recdata(f'{RAW_DATA_DIR}/{rec_fname}', use_tetrode=False)

        # NOTE(review): `recovered` is never used below -- the edges are taken from the raw
        # `pulses`. The call is kept exactly as before; confirm whether get_pulse_edges
        # should receive `recovered` instead, or whether this call can be dropped.
        recovered = recover_pulses(pulses, STIM_RATE, STIM_PULSE_WIDTH, always_on=True, dt=TS)

        # from raw pulse data, get index of rising edges
        sync_idxs, _ = get_pulse_edges(pulses)

        # Explicit emptiness check: `not sync_idxs` raises ValueError when the edges come
        # back as a numpy array with more than one element.
        if sync_idxs is None or len(sync_idxs) == 0:
            print(f'file: {rec_fname}. no sync pulse detected')
            continue

        # from sync pulse timings, retrieve stimulation timings
        stim_idxs = np.array(sync_idxs) - i_sync_delay
        if stim_idxs[0] < 0: stim_idxs = stim_idxs[1:]

        # segmentize: cut one window per stimulation out of every channel
        for stim_idx in stim_idxs:
            i0, i1 = stim_idx + i_tic, stim_idx + i_toc

            if i0 < 0: continue # first stim recording is incomplete
            if i1 > values.shape[0]: break # last stim recording is incomplete

            for ch in range(NCH):
                raw_seg = values[i0:i1, ch]

                # saturated if any sample is at or beyond the LOW_CUT / HIGH_CUT thresholds
                is_saturated = np.any((raw_seg <= LOW_CUT) | (raw_seg >= HIGH_CUT))

                if is_saturated:
                    # NaN placeholder of the same length, so every channel keeps one
                    # entry per trial and the final array stays rectangular
                    seg_lists[ch].append(np.full(raw_seg.shape, np.nan))
                else:
                    valid_trial_count[ch] += 1
                    # scipy.signal.detrend removes a linear trend by default
                    seg_lists[ch].append(scipy.signal.detrend(raw_seg))

    # convert to array and save (one entry per channel, each holding that channel's trials)
    raw_segs = np.array(seg_lists)
    fn_label = stim_label.replace(" ", "_").lower()
    np.save(f'{save_dir}/{fn_label}_valid_trial_count.npy', valid_trial_count)
    np.save(f'{save_dir}/{fn_label}_raw_segs.npy', raw_segs)

# time axis of one segment, relative to the stimulation sample
# Sample j of a segment sits int(tic*FS) + j samples away from the stimulation index
# (segments are sliced as values[stim_idx + int(tic*FS) : stim_idx + int(toc*FS)]), so
# build the axis from that offset. Centering on the mean instead labels the stimulation
# sample as TS/2 rather than 0 for an even-length window, and is wrong whenever the
# window is not symmetric around the stimulus.
# Assumes TS is the sampling period matching FS -- TODO confirm.
t = (np.arange(len_seg) + int(tic*FS))*TS
np.save(f'{save_dir}/t.npy', t)