In [17]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.collections import LineCollection
import matplotlib as mlp
import mne
from scipy.io import loadmat

import os
from pathlib import Path
import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning)

# Get

In [18]:
# paths
arr_dir = '/home/SharedFiles/SharedNotebooks/EEG/Outputs/seg_arr'

In [19]:
# retrieve all paths
all_preictals, all_interictals, patients = [], [], []
for patient in os.listdir(arr_dir):
    if patient.startswith('SNUCH'):
        patient_pth = os.path.join(arr_dir, patient)
        patients.append(patient)
        all_preictals.append([])
        all_interictals.append([])
        for ictalType in os.listdir(patient_pth):
            ictalType_pth = os.path.join(patient_pth, ictalType) 
            if os.path.isdir(ictalType_pth):
                if ictalType == 'preictals':
                    for preictal in os.listdir(ictalType_pth):
                        if not preictal.startswith('._'):
                            preictal_pth = os.path.join(ictalType_pth, preictal)
                            all_preictals[-1].append(preictal_pth)
                if ictalType == 'interictals':
                    for interictal in os.listdir(ictalType_pth):
                        if not interictal.startswith('._'):
                            interictal_pth = os.path.join(ictalType_pth, interictal)
                            all_interictals[-1].append(interictal_pth)

                            
print('all_preictals of n patients; n =', len(all_preictals))
print('all_interictals of n patients; n =', len(all_interictals))

all_preictals of n patients; n = 11
all_interictals of n patients; n = 11


In [20]:
len(all_preictals[0])

540

In [21]:
patients

['SNUCH01',
 'SNUCH02',
 'SNUCH03',
 'SNUCH04',
 'SNUCH05',
 'SNUCH06',
 'SNUCH07',
 'SNUCH08',
 'SNUCH09',
 'SNUCH10',
 'SNUCH11']

# Preprocess

### Averaging
https://mne.tools/dev/auto_tutorials/evoked/10_evoked_overview.html
https://mne.tools/stable/generated/mne.grand_average.html
https://mne.tools/0.17/auto_tutorials/plot_epoching_and_averaging.html

# Learn

In [22]:
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

class NN(nn.Module):
    def __init__(self,NFILT=256,NOUT=2):
        super(NN,self).__init__()
        self.conv0 = nn.Conv2d(1,NFILT,kernel_size=(200,3),padding=(0,1),bias=False)
        self.bn0 = nn.BatchNorm2d(NFILT)
        self.gru = nn.GRU(input_size=NFILT,hidden_size=128,num_layers=1,batch_first=True,bidirectional=False)
        self.fc1 = nn.Linear(128,NOUT)

    def forward(self, x):
        x = F.relu(self.bn0(self.conv0(x)))
        x = x.squeeze().permute(0,2,1)
        x,_ = self.gru(x)
        x = F.dropout(x,p=0.5,training=self.training)
        x = self.fc1(x)
        return x

In [23]:
import torch
import numpy as np
from scipy.special import softmax,expit
from sklearn.metrics import f1_score,confusion_matrix,cohen_kappa_score,roc_curve,roc_auc_score,average_precision_score


class Statistics(object):
    def __init__(self):
        self.target = []
        self.logits = []

    def reset(self):
        self.target = []
        self.logits = []

    @staticmethod
    def idx2onehot(idx_array):
        y = np.zeros((idx_array.shape[0], idx_array.max() + 1))
        y[np.arange(y.shape[0]), idx_array] = 1
        return y

    @staticmethod
    def F1(conf):
        x0 = np.sum(conf, 0)
        x1 = np.sum(conf, 1)
        dg = np.diag(conf)
        f1 = 2 * dg / (x0 + x1)
        return f1

    @staticmethod
    def Kappa(conf):
        x0 = np.sum(conf, 0)
        x1 = np.sum(conf, 1)
        N = np.sum(np.sum(conf))
        ef = np.sum(x0 * x1 / N)
        dg = np.sum(np.diag(conf))
        K = (dg - ef) / (N - ef)
        return K

    def append(self,target,logits):
        self.logits.append(logits.data.cpu().numpy())
        self.target.append(target.data.cpu().numpy())

    @staticmethod
    def random_auprc(target):
        y_chance = np.zeros((target.max()+1,))
        for i in range(target.max()+1):
            y_chance[i] = len(target[target==i]) / len(target)

        return y_chance

    def evaluate(self):
        self.logits = np.concatenate(self.logits)
        self.target = np.concatenate(self.target).astype('int32')

        self.probs = softmax(self.logits,axis=1)
        self.argmax = np.argmax(self.probs,axis=1)

        CONF = np.array(confusion_matrix(y_true=self.target,y_pred=self.argmax))
        F1 = Statistics.F1(CONF)
        KPS = Statistics.Kappa(CONF)
        AUROC = roc_auc_score(y_true=Statistics.idx2onehot(self.target),y_score=self.probs,average=None)
        AUPRC = average_precision_score(y_true=Statistics.idx2onehot(self.target),y_score=self.probs,average=None)
        AUPRC_chance = self.random_auprc(self.target)

        print(CONF)
        print(F1)
        print(KPS)
        print(AUROC,np.mean(AUROC))
        print(AUPRC,np.mean(AUPRC))
        print(AUPRC_chance)

        self.reset()

In [24]:
import copy
import scipy.signal as signal
import scipy.stats as stats
import scipy.io as sio
import tqdm

class Dataset:
    def __init__(self,path):
        self.path = path
        if self.path[-1] != '/':
            self.path += '/'
        self.files = np.concatenate([[self.path+'interictals/'+x for x in os.listdir(self.path+'interictals/')], [self.path+'preictals/'+x for x in os.listdir(self.path+'preictals/')]])
        self.N_interictal = 15000
        self.N_preictal = 5400
        self.targets=np.array([0 if x < self.N_interictal else 1 for x in range(len(self.files))])

    def __len__(self):
        return len(self.files)

    def __getitem__(self, item):
        data = np.load(self.files[item])
        print(data) #### TEST
        target = self.targets[item]
        data = np.expand_dims(data,axis=0)
        return data,target

    def split_random(self,N_valid):
        sample = np.random.choice(len(self), len(self), replace=False)
        self.files, self.targets = self.files[sample], self.targets[sample]
        train = copy.deepcopy(self)
        valid = copy.deepcopy(self)
        train.files, train.targets = train.files[N_valid:], train.targets[N_valid:]
        valid.files, valid.targets = valid.files[:N_valid], valid.targets[:N_valid]
        return train,valid

In [25]:
cnt_all_interictals = 0
cnt_all_preictals = 0
for interictals in all_interictals:
    for interictal in interictals:
        cnt_all_interictals += 1
        
for preictals in all_preictals:
    for preictal in preictals:
        cnt_all_preictals += 1
        
print(cnt_all_interictals, cnt_all_preictals)

12912 12914


* https://towardsdatascience.com/conv1d-and-conv2d-did-you-realize-that-conv1d-is-a-subclass-of-conv2d-8819675bec78
* https://discuss.pytorch.org/t/how-to-set-batch-size-correctly-when-using-multi-gpu-training/131262

In [26]:
import torch
from torch.utils.data import DataLoader

for patient in os.listdir(arr_dir):
    
    if patient.startswith('SNUCH01'): ##### TEST
        pat_interictals, pat_preictals = [],[]
        patient_pth = os.path.join(arr_dir, patient)
        print(patient_pth)
        for ictalType in os.listdir(patient_pth):
            ictalType_pth = os.path.join(patient_pth, ictalType) 
            if os.path.isdir(ictalType_pth):
                if ictalType == 'preictals':
                    for preictal in os.listdir(ictalType_pth):
                        if not preictal.startswith('._'):
                            preictal_pth = os.path.join(ictalType_pth, preictal)
                            pat_preictals.append(preictal_pth)
                if ictalType == 'interictals':
                    for interictal in os.listdir(ictalType_pth):
                        if not interictal.startswith('._'):
                            interictal_pth = os.path.join(ictalType_pth, interictal)
                            pat_interictals.append(interictal_pth)
    
    N = len(pat_interictals) + len(pat_preictals) 
    print('pat_interictals:', len(pat_interictals), '| pat_preictals:', len(pat_preictals), '| total:', N)
        
    train_dataset, valid_dataset = Dataset(patient_pth).split_random(int(0.3*N))
    print("Train:", len(train_dataset), '| valid:', len(valid_dataset))
    
    BATCHSIZE = 16
    NWORKERS = 2
    DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
    print('DEVICE:', DEVICE)
    TRAIN = DataLoader(dataset=train_dataset,
                       batch_size=BATCHSIZE,
                       shuffle=True,
                       drop_last=False,
                       num_workers=NWORKERS)

    VALID = DataLoader(dataset=valid_dataset,
                       batch_size=BATCHSIZE,
                       shuffle=True,
                       drop_last=False,
                       num_workers=NWORKERS)

    model = NN(2).to(DEVICE)
    optimizer = optim.Adam(model.parameters(),lr=1e-3,weight_decay=1e-4)
    loss = nn.CrossEntropyLoss()
    statistics = Statistics()
    N_epochs = 30
    epoch_loss={'train':np.zeros((N_epochs)),'val':np.zeros((N_epochs))}
    for epoch in range(N_epochs):
        model.train()
        for i,(x,t) in enumerate(TRAIN):
            optimizer.zero_grad()
            x = x.to(DEVICE).float()
            t = t.to(DEVICE).long()
            y = model(x)
            J = loss(input=y[:,-1,:],target=t)
            epoch_loss['train'][epoch] += J.data.cpu().numpy()
            J.backward()
            optimizer.step()

            if i%50==0:
                print('EPOCH:{}\tITER:{}\tLOSS:{}'.format(str(epoch).zfill(2),
                                                          str(i).zfill(5),
                                                          J.data.cpu().numpy()))
        epoch_loss['train'][epoch]/=len(TRAIN)

        # evaluate results for validation set
        model.eval()
        for i,(x,t) in enumerate(VALID):
            x = x.to(DEVICE).float()
            t = t.to(DEVICE).long()
            y = model(x)
            val_loss = loss(input=y[:,-1,:],target=t)
            epoch_loss['val'][epoch] += val_loss.data.cpu().numpy()
            statistics.append(target=t,logits=y[:,-1,:])
        epoch_loss['val'][epoch]/=len(VALID)
        statistics.evaluate()

/home/SharedFiles/SharedNotebooks/EEG/Outputs/seg_arr/SNUCH01
pat_interictals: 538 | pat_preictals: 540 | total: 1078
Train: 755 | valid: 323
DEVICE: cuda
[[ 2.30000005e-05  2.49999994e-05  2.70000000e-05 ... -1.60000000e-05
  -1.80000006e-05 -2.20000002e-05]
 [-6.39999998e-05 -6.39999998e-05 -6.39999998e-05 ...  1.49999996e-05
   1.40000002e-05  1.29999999e-05]
 [-1.00999998e-04 -1.02999998e-04 -1.04999999e-04 ...  1.29999999e-05
   1.60000000e-05  1.99999995e-05]
 ...
 [-2.90000007e-05 -2.99999992e-05 -2.99999992e-05 ...  9.00000032e-06
   1.10000001e-05  1.20000004e-05]
 [-5.40000001e-05 -5.40000001e-05 -5.60000008e-05 ...  3.19999999e-05
   3.89999987e-05  4.60000010e-05]
 [-3.99999990e-05 -4.19999997e-05 -4.50000007e-05 ...  3.99999999e-06
   3.99999999e-06  4.99999987e-06]]
[[-1.89999992e-05 -1.89999992e-05 -1.60000000e-05 ...  1.40000002e-05
   9.99999975e-06  7.99999998e-06]
 [ 3.40000006e-05  3.19999999e-05  3.09999996e-05 ...  2.59999997e-05
   2.09999998e-05  1.80000006e-05]

RuntimeError: Calculated padded input size per channel: (21 x 6002). Kernel size: (200 x 3). Kernel size can't be greater than actual input size

In [None]:
break

In [None]:
# def epoch_raw(raw_path, secs=600):
#     raw_files_interictal = sorted([f for f in os.listdir(raw_path) if 'interictal' in f])
#     raw_files_preictal = sorted([f for f in os.listdir(raw_path) if 'preictal' in f])
#     raw_files = {'interictal': raw_files_interictal, 'preictal': raw_files_preictal}
#     pat_dir = '/home/SharedFiles/SharedNotebooks/EEG/Outputs/cnn_outputs'
#     train_dir = pat_dir+'/train/'
#     train_dir_interictal = train_dir+'/interictal/'
#     train_dir_preictal = train_dir+'/preictal/'
#     if not os.path.isdir(pat_dir):
#         os.mkdir(pat_dir)
#     if not os.path.isdir(train_dir):
#         os.mkdir(train_dir)
#     if not os.path.isdir(train_dir_interictal):
#         os.mkdir(train_dir_interictal)
#     if not os.path.isdir(train_dir_preictal):
#         os.mkdir(train_dir_preictal)
#     train_dirs = {'interictal': train_dir_interictal, 'preictal': train_dir_preictal}
    
#     for category in ['preictal']:
#         for _,f in enumerate(raw_files[category]):
#             search_key = "_segment_"       # search key string
#             try:
#                 matFile = mne.io.read_raw_fif(os.path.join(raw_path,f))
#             except OSError as error:
#                 print(error)
#             segment = dict(filter(lambda item: search_key in item[0], matFile.items()))
#             segment_list_values = list(segment.values())
#             x = np.array(segment_list_values[0][0][0][0], dtype=np.float)
#             Fs = np.array(segment_list_values[0][0][0][2], dtype=np.float)[0][0]
#             L = int(Fs*secs)
#             for electrode in range(x.shape[0]):
#                 for i in range(L,x.shape[1]+1,L):
#                     digit = f'_{electrode}_{int(i//L)}'
#                     save_filepath = train_dirs[category]+f[:-4]+digit+'.npy'
#                     spect,_,_,_ = plt.specgram(x[electrode, i-L:i], NFFT=1024, Fs=5000, noverlap=128)
#                     spect = stats.zscore(spect[:200,:],axis=1)
#                     try:
#                         np.save(save_filepath, spect)
#                     except OSError as e:
#                         print(e)
#             print(f"Finished file {f}")
#         print(f"Finished category {category}")   
#     return

In [None]:
fig=plt.figure()
plt.title("Epoch Loss")
plt.plot(epoch_loss['train'])
plt.plot(epoch_loss['val'])
plt.legend(['train', 'val'])
plt.show()

In [None]:
for i in range(len(raws)):
    patient = 'SNUCH' + str(i+1).zfill(2)
    pat_dir = os.path.join(sliced_raws_pth, patient)
    epoch_raw(pat_dir, 30)