In [1]:
import os, sys
sys.path.append('../')
sys.path.append('../../')
import numpy as np
import h5py
import pickle
from data_utils import getSeizureTimes

# File markers for EEG clips

In [2]:
# Directory whose entries are listed and opened as .h5 files when the marker files are built.
RESAMPLE_DIR = '/media/nvme_data/siyitang/TUH_eeg_seq_v1.5.2/resampled_signal/'
#RESAMPLE_DIR = '/home/siyitang/data/TUH_v1.5.2/TUH_eeg_seq_v1.5.2/resampled_signal'
# Clip length; multiplied by FREQUENCY to obtain the clip length in samples
# (presumably seconds, with FREQUENCY in Hz -- TODO confirm).
CLIP_LEN = 12
# Used in variable-length mode as the minimum length (times FREQUENCY) of a
# trailing short clip, and as part of the variable-length marker file name.
TIME_STEP_SIZE = 1
# Offset between consecutive clip starts, in the same unit as CLIP_LEN.
# Equal to CLIP_LEN, so consecutive clip windows do not overlap.
STRIDE = CLIP_LEN
# Samples per unit of CLIP_LEN / STRIDE (sampling rate of the resampled signal -- TODO confirm).
FREQUENCY = 200

In [3]:
# Map each split name to the file names listed in "<split>Set_seizureDetect_files.txt".
# Only the first comma-separated field of each line is used, reduced to its
# last '/'-separated component.
FILES_TO_CONSIDER = {}
for split in ('train', 'dev', 'test'):
    list_path = split+'Set_seizureDetect_files.txt'
    with open(list_path, 'r') as list_file:
        rows = list_file.readlines()
    FILES_TO_CONSIDER[split] = [
        row.strip('\n').split(',')[0].split('/')[-1] for row in rows
    ]

# Report how many files each split contains.
for split_name in ('train', 'dev', 'test'):
    print(len(FILES_TO_CONSIDER[split_name]))

5179
928
1192


In [4]:
# Sanity check: show the first five training file names.
FILES_TO_CONSIDER['train'][:5]

['00008295_s009_t007.edf',
 '00010591_s001_t000.edf',
 '00010489_s008_t007.edf',
 '00012262_s007_t001.edf',
 '00004294_s001_t000.edf']

In [5]:
RAW_DATA_DIR = "/media/nvme_data/TUH/v1.5.2/edf/"
#RAW_DATA_DIR = "/data/crypt/eegdbs/temple/tuh_eeg_seizure/v1.5.2/edf/"

# Recursively collect the full path of every EDF recording under RAW_DATA_DIR.
# Match on the file extension rather than on a substring, so that a side file
# such as "x.edf.bak" is not mistaken for a recording.
edf_files = [
    os.path.join(path, name)
    for path, subdirs, files in os.walk(RAW_DATA_DIR)
    for name in files
    if name.endswith(".edf")
]

In [6]:
# When True, marker files are written under "variable_length/" (with the time
# step size in the file name) and a trailing clip shorter than CLIP_LEN may be kept.
VARIABLE_LENGTH = False

In [7]:
# Build one marker file per split. Every line has the form
# "<edf file name>,<clip index>,<is_seizure>", and the lines of each split are
# shuffled with NumPy's global RNG (seeded once here).
np.random.seed(123)

# os.listdir returns entries in arbitrary order; sort them so that the seeded
# shuffle below produces the same marker files on every machine.
resampled_files = sorted(os.listdir(RESAMPLE_DIR))

# Index the raw EDF paths by base name once, instead of scanning the whole
# edf_files list for every resampled recording.
edf_paths_by_name = {}
for edf_path in edf_files:
    edf_paths_by_name.setdefault(os.path.basename(edf_path), []).append(edf_path)

# Clip length and stride expressed in samples.
clip_samples = CLIP_LEN * FREQUENCY
stride_samples = STRIDE * FREQUENCY

for split in ['train', 'dev', 'test']:
    if VARIABLE_LENGTH:
        # The output directory is not created anywhere else, so make sure it exists.
        os.makedirs("variable_length", exist_ok=True)
        filemarker = os.path.join("variable_length",
                split+"_cliplen"+str(CLIP_LEN)+"_stride"+str(STRIDE)+"_timestep"+str(TIME_STEP_SIZE)+".txt")
    else:
        filemarker = split+"_cliplen"+str(CLIP_LEN)+"_stride"+str(STRIDE)+".txt"

    # Set for constant-time membership tests.
    files_in_split = set(FILES_TO_CONSIDER[split])

    write_str = []
    for h5_fn in resampled_files:
        edf_fn = h5_fn.split('.h5')[0]+'.edf'
        if edf_fn not in files_in_split:
            continue

        matches = edf_paths_by_name.get(edf_fn, [])
        if len(matches) != 1:
            print("{} found {} times!".format(edf_fn, len(matches)))
            print(matches)
        if not matches:
            # Fail with an explicit message instead of an IndexError on matches[0].
            raise FileNotFoundError(
                "No raw EDF file named {} was found in edf_files".format(edf_fn))
        # With several matches, the first one found is used (a warning was printed above).
        edf_fn_full = matches[0]
        seizure_times = getSeizureTimes(edf_fn_full.split('.edf')[0])

        # Convert the seizure intervals to sample indices once per recording.
        seizure_spans = [
            (int(t[0] * FREQUENCY), int(t[1] * FREQUENCY)) for t in seizure_times
        ]

        # Only the signal length is needed, so read the dataset's shape
        # instead of loading the whole signal into memory.
        h5_fn_full = os.path.join(RESAMPLE_DIR, h5_fn)
        with h5py.File(h5_fn_full, 'r') as hf:
            num_samples = hf["resampled_signal"].shape[-1]

        # Variable-length mode allows one extra, possibly shorter, trailing clip.
        num_clips = (num_samples - clip_samples) // stride_samples + 1
        if VARIABLE_LENGTH:
            num_clips += 1

        for i in range(num_clips):
            start_window = i * stride_samples
            end_window = min(start_window + clip_samples, num_samples)

            # Only include the last short clip if it is at least one time step
            # (TIME_STEP_SIZE * FREQUENCY samples) long.
            if (VARIABLE_LENGTH and i == num_clips - 1
                    and (end_window - start_window) < (TIME_STEP_SIZE * FREQUENCY)):
                break

            # A clip is labelled 1 when its window touches any seizure interval;
            # both interval bounds are treated as inclusive.
            is_seizure = int(any(
                end_window >= start_t and start_window <= end_t
                for start_t, end_t in seizure_spans
            ))
            write_str.append(edf_fn + ',' + str(i) + ',' + str(is_seizure) + '\n')

    np.random.shuffle(write_str)
    with open(filemarker, 'w') as f:
        f.writelines(write_str)

# Get seizure/non-seizure balanced train set

In [8]:
# Load the lines of the (unbalanced) training marker file.
# The split name is hard-coded as "train": relying on a leftover `split` loop
# variable would select whichever split an earlier loop finished on.
if VARIABLE_LENGTH:
    train_filemarker = os.path.join("variable_length",
                "train_cliplen"+str(CLIP_LEN)+"_stride"+str(STRIDE)+"_timestep"+str(TIME_STEP_SIZE)+".txt")
else:
    train_filemarker = "train_cliplen"+str(CLIP_LEN)+"_stride"+str(STRIDE)+".txt"

with open(train_filemarker, 'r') as f:
    train_str = f.readlines()

In [9]:
# Split the training marker lines into seizure (label 1) and non-seizure clips.
# Each entry is (file name, clip index kept as a string, integer label).
sz_tuples, nonsz_tuples = [], []
for marker_line in train_str:
    edf_name, clip_idx, label_text = marker_line.strip('\n').split(',')
    label = int(label_text)
    bucket = sz_tuples if label == 1 else nonsz_tuples
    bucket.append((edf_name, clip_idx, label))

print(len(sz_tuples))
print(len(nonsz_tuples))

13646
183000


### Keep all seizure clips and undersample the non-seizure clips to the same count

In [10]:
np.random.seed(123)

# Shuffle a copy instead of nonsz_tuples itself: an in-place shuffle would make
# this cell select a different subset every time it is re-run, despite the seed.
# On a fresh run the permutation is the same as shuffling in place.
nonsz_shuffled = list(nonsz_tuples)
np.random.shuffle(nonsz_shuffled)

# Keep as many non-seizure clips as there are seizure clips.
nonsz_tuple_small = nonsz_shuffled[:len(sz_tuples)]

len(nonsz_tuple_small)

13646

In [11]:
# Combine all seizure clips with the undersampled non-seizure clips and mix them.
# No seed is set in this cell, so the resulting order depends on the state of
# NumPy's global RNG at the time the cell runs.
balanced_files = [*sz_tuples, *nonsz_tuple_small]
np.random.shuffle(balanced_files)

In [12]:
# Preview the first five (file name, clip index, label) entries.
balanced_files[:5]

[('00007584_s004_t000.edf', '70', 1),
 ('00002991_s004_t006.edf', '2', 0),
 ('00007032_s008_t000.edf', '146', 1),
 ('00012940_s001_t003.edf', '173', 1),
 ('00010455_s003_t001.edf', '45', 1)]

In [13]:
# Write the balanced training set, one "<file name>,<clip index>,<label>" line per clip.
marker_stem = "train_cliplen"+str(CLIP_LEN)+"_stride"+str(STRIDE)
if VARIABLE_LENGTH:
    balanced_train_filemarker = os.path.join(
        "variable_length", marker_stem+"_timestep"+str(TIME_STEP_SIZE)+"_balanced.txt")
else:
    balanced_train_filemarker = marker_stem+"_balanced.txt"

with open(balanced_train_filemarker, "w") as f:
    for tup in balanced_files:
        f.write(",".join((tup[0], str(tup[1]), str(tup[2]))) + "\n")

# Get `pos_weight` to weigh the loss function

In [417]:
from data.dataloader import load_dataset
import torch

In [320]:
## on gemini
# NOTE(review): this rebinds RAW_DATA_DIR, which an earlier cell set to
# ".../v1.5.2/edf/". Re-running that earlier cell after this one would walk a
# different root directory.
RAW_DATA_DIR = "/media/nvme_data/TUH/v1.5.2/"
# Passed to load_dataset as input_dir. Same location as RESAMPLE_DIR, written
# without the trailing slash.
PREPROC_DIR = "/media/nvme_data/siyitang/TUH_eeg_seq_v1.5.2/resampled_signal"

In [32]:
# Passed to load_dataset as use_fft; also selects the "_fft" suffix of the
# mean/std pickle files written at the end of the notebook.
USE_FFT = True

In [21]:
# Build the data loaders from the preprocessed signals. The result is indexed
# by split name below (dataloaders['train']).
dataloaders = load_dataset(
    input_dir=PREPROC_DIR,
    raw_data_dir=RAW_DATA_DIR,
    train_batch_size=64,
    test_batch_size=64,
    clip_len=CLIP_LEN,
    time_step_size=TIME_STEP_SIZE,
    stride=STRIDE,
    standardize=False,  # raw values are needed to compute the mean/std below
    num_workers=8,
    augmentation=True,
    use_fft=USE_FFT,
    balance_train=True,
)

In [30]:
# Iterate once over the training loader and keep, per batch, the inputs (first
# item), the labels (second item) and the file names (last item); the three
# items in between are ignored.
x_train, y_train, file_name_train = [], [], []
for x_batch, y_batch, _, _, _, batch_file_names in dataloaders['train']:
    x_train.append(x_batch)
    y_train.append(y_batch)
    file_name_train.extend(batch_file_names)

In [31]:
# Concatenate the per-batch input tensors along dim 0 and convert to a NumPy array.
# NOTE: x_train is rebound from a list of tensors to an array, so this cell
# cannot be re-run without first re-running the cell that collects the batches.
x_train = torch.cat(x_train, dim=0).data.cpu().numpy()
x_train.shape

(1152, 60, 19, 100)

In [23]:
# Concatenate the per-batch label tensors along dim 0 and convert to a NumPy array.
# NOTE: y_train is rebound from a list of tensors to an array, so this cell
# cannot be re-run without first re-running the cell that collects the batches.
y_train = torch.cat(y_train, dim=0).data.cpu().numpy()
y_train.shape

(1152, 60)

In [24]:
# Sum the labels over the last axis: one value per clip, non-zero when the clip
# has at least one non-zero label (assumes labels are 0/1 -- TODO confirm).
y_single = y_train.sum(axis=-1)
y_single.shape

(1152,)

In [25]:
# Boolean mask of clips whose label sum is non-zero.
pos_clip_idxs = (y_single != 0)
# Number of entries equal to 1 within those clips.
pos_clip_labels = y_train[pos_clip_idxs, :]
pos_timesteps = (pos_clip_labels == 1).sum()
pos_timesteps

7837

In [26]:
# Number of entries equal to 0 inside the clips selected by pos_clip_idxs.
pos_clip_neg_timesteps = (y_train[pos_clip_idxs, :] == 0).sum()
pos_clip_neg_timesteps

26723

In [27]:
# Boolean mask of clips whose label sum is zero.
neg_clip_idxs = (y_single == 0)
# Number of entries equal to 0 within those clips.
neg_clip_labels = y_train[neg_clip_idxs, :]
neg_clip_neg_timesteps = (neg_clip_labels == 0).sum()
neg_clip_neg_timesteps

34560

In [28]:
# Summarise the class balance over all time steps of the collected training clips.
total_neg_timesteps = neg_clip_neg_timesteps + pos_clip_neg_timesteps
print("Total positive time steps:", pos_timesteps)
print("Total negative time steps:", total_neg_timesteps)

Total positive time steps: 7837
Total negative time steps: 61283


In [29]:
# Ratio of negative to positive time steps.
neg_timesteps_all = neg_clip_neg_timesteps + pos_clip_neg_timesteps
pos_weight = neg_timesteps_all / pos_timesteps
pos_weight

7.8197014163583

# Compute mean & std of train set

In [33]:
# Shape of the training inputs that the statistics below are computed over.
x_train.shape

(1152, 60, 19, 100)

In [34]:
# Scalar mean and standard deviation taken over every element of x_train.
mean = x_train.mean()
std = x_train.std()
print("Mean: {:.3f}, Std: {:.3f}".format(mean, std))

Mean: 3.882, Std: 1.566


In [37]:
# Pickle the scalar mean and std to "./mean_cliplen<..>_stride<..>[_fft].pkl"
# and "./std_cliplen<..>_stride<..>[_fft].pkl"; the "_fft" part is present only
# when USE_FFT is set.
stats_stem = "_cliplen"+str(CLIP_LEN)+"_stride"+str(STRIDE)
stats_suffix = "_fft.pkl" if USE_FFT else ".pkl"
for stat_name, stat_value in (("mean", mean), ("std", std)):
    with open("./"+stat_name+stats_stem+stats_suffix, "wb") as pf:
        pickle.dump(stat_value, pf)