In [399]:
import torch
from torch import nn, optim

import numpy as np
import pandas as pd
import seaborn as sns
from pylab import rcParams
import matplotlib.pyplot as plt
from matplotlib import rc
from pandas.plotting import register_matplotlib_converters
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import roc_auc_score

import os, random, copy, warnings
from pathlib import Path
warnings.filterwarnings("ignore", category=DeprecationWarning)

In [400]:
%config InlineBackend.figure_format='retina'
sns.set(style='whitegrid', palette='muted', font_scale=1.2)
rcParams['figure.figsize'] = 14, 10
register_matplotlib_converters()
RANDOM_SEED = 369
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)

<torch._C.Generator at 0x7ff18e789890>

In [401]:
# paths
arr_dir = '/home/SharedFiles/Projects/EEG/Data/seg_arr/seq05m_seg30s'

In [402]:
# retrieve all paths
all_preictals, all_interictals, patients = [], [], []
for patient in os.listdir(arr_dir):
    if patient.startswith('SNUCH'):
        patient_pth = os.path.join(arr_dir, patient)
        patients.append(patient)
        all_preictals.append([])
        all_interictals.append([])
        for ictalType in os.listdir(patient_pth):
            ictalType_pth = os.path.join(patient_pth, ictalType) 
            if os.path.isdir(ictalType_pth):
                if ictalType == 'preictals':
                    for preictal in os.listdir(ictalType_pth):
                        if not preictal.startswith('._'):
                            preictal_pth = os.path.join(ictalType_pth, preictal)
                            all_preictals[-1].append(preictal_pth)
                if ictalType == 'interictals':
                    for interictal in os.listdir(ictalType_pth):
                        if not interictal.startswith('._'):
                            interictal_pth = os.path.join(ictalType_pth, interictal)
                            all_interictals[-1].append(interictal_pth)

                            
print('all_preictals of n patients; n =', len(all_preictals))
print('all_interictals of n patients; n =', len(all_interictals))

all_preictals of n patients; n = 11
all_interictals of n patients; n = 11


In [403]:
len(all_preictals[0])

540

# Split

In [404]:
# train valid test split
def split_data(preictals, interictals, train_ratio):
    X_train, y_train, X_val, y_val, X_test, y_test = [], [], [], [], [], []
    random.shuffle(preictals)
    random.shuffle(interictals)
    
    preictals_train_size = int(len(preictals)*train_ratio)
    preictals_val_size = int((len(preictals)-preictals_train_size)/2)
    preictals_val_ind = preictals_train_size + preictals_val_size - 1
    X_train += [np.load(array) for array in preictals[:preictals_train_size]]
    X_val += [np.load(array) for array in preictals[preictals_train_size:preictals_val_ind]]
    X_test += [np.load(array) for array in preictals[preictals_val_ind:]]
    y_train += [1 for ictal in range(preictals_train_size)]
    y_val += [1 for ictal in range(preictals_val_size)]
    y_test += [1 for ictal in range(len(preictals[preictals_val_ind:]))]

    interictals_train_size = int(len(interictals)*train_ratio)
    interictals_val_size = int((len(interictals)-interictals_train_size)/2)
    interictals_val_ind = interictals_train_size + interictals_val_size - 1
    X_train += [np.load(array) for array in interictals[:interictals_train_size]]
    X_val += [np.load(array) for array in interictals[interictals_train_size:interictals_val_ind]]
    X_test += [np.load(array) for array in interictals[interictals_val_ind:]]
    y_train += [0 for ictal in range(interictals_train_size)]
    y_val += [0 for ictal in range(interictals_val_size)]
    y_test += [0 for ictal in range(len(interictals[interictals_val_ind:]))]
    
    return np.array(X_train), np.array(y_train), np.array(X_val), np.array(y_val), np.array(X_test), np.array(y_test)

In [405]:
X_train, y_train, X_val, y_val, X_test, y_test = split_data(all_preictals[0], all_interictals[0], 0.7)

In [406]:
# len(X_train) + len(X_val) + len(X_test) == len(y_train) + len(y_val) + len(y_test) == len(all_preictals[0]) + len(all_interictals[0])

In [407]:
#Scale

MIN = X_train.min()
MAX = X_train.max()

def MinMaxScale(array, min, max):
    return (array - min) / (max - min)

#MinMax 스케일링
X_train = MinMaxScale(X_train, MIN, MAX)
X_val = MinMaxScale(X_val, MIN, MAX)
X_test = MinMaxScale(X_test, MIN, MAX)

In [408]:
# scaler = StandardScaler() # 2D only

# X_train = scaler.fit_transform(X_train)
# X_val = scaler.fit_transform(X_val)
# X_test = scaler.fit_transform(X_test)

In [409]:
def make_Tensor(array):
    return torch.from_numpy(array).float()

X_train = make_Tensor(X_train)
y_train = make_Tensor(y_train)
X_val = make_Tensor(X_val)
y_val = make_Tensor(y_val)
X_test = make_Tensor(X_test)
y_test = make_Tensor(y_test)

print(X_train.shape, X_val.shape, X_test.shape)
print(y_train.shape, y_val.shape, y_test.shape)

torch.Size([754, 21, 6000]) torch.Size([162, 21, 6000]) torch.Size([162, 21, 6000])
torch.Size([754]) torch.Size([162]) torch.Size([162])


# Model (LSTM)

https://pseudo-lab.github.io/Tutorial-Book/chapters/time-series/Ch4-LSTM.html

In [410]:
class SeizurePredictorLSTM(nn.Module):
    def __init__(self, n_features, n_hidden, seg_len, n_layers):
        #print('n_features', n_features) ## TEST
        #print('seg_len', seg_len) ## TEST
        super(SeizurePredictorLSTM, self).__init__()
        self.n_hidden = n_hidden
        self.seg_len = seg_len
        self.n_layers = n_layers
        self.lstm = nn.LSTM(
            input_size=n_features,
            hidden_size=n_hidden,
            num_layers=n_layers
        )
        self.linear = nn.Linear(in_features=n_hidden, out_features=2)
    def reset_hidden_state(self):
        self.hidden = (
            torch.zeros(self.n_layers, self.seg_len, self.n_hidden),
            torch.zeros(self.n_layers, self.seg_len, self.n_hidden)
        )
    def forward(self, sequences):
        lstm_out, self.hidden = self.lstm(
            sequences.view(len(sequences), self.seg_len, -1),
            self.hidden
        )
        last_time_step = lstm_out.view(self.seg_len, len(sequences), self.n_hidden)[-1]
        y_pred = self.linear(last_time_step)
        return y_pred


# Train

In [458]:
# # TEST
# model = SeizurePredictorLSTM(
#     n_features=len(X_train[0]),
#     n_hidden=4,
#     n_layers=1,
#     seg_len = len(X_train[0][0])
# )
# model=model
# train_data = X_train
# train_labels = y_train
# val_data = X_val
# val_labels = y_val
# num_epochs=2
# verbose=1
# patience=50

def train_model(model, train_data, train_labels, val_data, val_labels, num_epochs, verbose, patience):

    loss_fn = torch.nn.L1Loss() #
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    train_hist = []
    val_hist = []
    val_auc = []
    for t in range(num_epochs):
        epoch_loss = 0
        for idx, ictal in enumerate(train_data): 
            model.reset_hidden_state() # ictal 별 hidden state reset

            # train loss
            ictal = torch.unsqueeze(ictal, 0)
            y_pred = model(ictal)
            loss = loss_fn(y_pred[0].float(), train_labels[idx].view(-1,1,1)) # 1개의 step에 대한 loss

            # update weights
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()

        train_hist.append(epoch_loss / len(train_data))

        if val_data is not None:
            with torch.no_grad():
                val_loss = 0
                val_preds = []
                for val_idx, val_ictal in enumerate(val_data):
                    model.reset_hidden_state() # ictal 별로 hidden state 초기화 

                    val_ictal = torch.unsqueeze(val_ictal, 0)
                    y_val_pred = model(val_ictal)
                    val_preds.append(y_val_pred[0].float())
                    val_step_loss = loss_fn(y_val_pred[0].float(), val_labels[val_idx].view(-1,1,1))

                    val_loss += val_step_loss
                    
                    

            val_hist.append(val_loss / len(val_data)) # val hist에 추가

            ## verbose 번째 마다 loss 출력 
            if t % verbose == 0:
                print(f'Epoch {t} train loss: {epoch_loss / len(train_data)} val loss: {val_loss / len(val_data)}')

                # print('val_labels shape', val_labels.cpu().detach().numpy().shape)
                # print('val_labels[0]', val_labels.cpu().detach().numpy()[0])
                # print('val_preds shape', np.array(val_preds).shape)
                # print('val_preds[0]', np.array(val_preds)[0])
                auc = roc_auc_score(val_labels.cpu().detach().numpy(), val_preds)#.cpu().detach().numpy())
                val_auc.append(auc)

                print("Validation AUC: {}".format(auc))
            ## patience 번째 마다 early stopping 여부 확인
            if (t % patience == 0) & (t != 0):
                ## loss가 커졌다면 early stop
                if val_hist[t - patience] < val_hist[t] :
                    print('\n Early Stopping')
                    break

        elif t % verbose == 0:
            print(f'Epoch {t} train loss: {epoch_loss / len(train_data)}')


    return model, train_hist, val_hist, val_auc

In [459]:
# val_auc

In [460]:
# val_labels.cpu().detach().numpy()

In [461]:
model = SeizurePredictorLSTM(
    n_features=len(X_train[0]),
    n_hidden=4,
    n_layers=1,
    seg_len = len(X_train[0][0])
)
model, train_hist, val_hist, val_auc= train_model(
    model,
    X_train,
    y_train,
    X_val,
    y_val,
    num_epochs=100,
    verbose=1,
    patience=50
)

  return F.l1_loss(input, target, reduction=self.reduction)


Epoch 0 train loss: 0.2654950601994912 val loss: 0.49999791383743286


ValueError: only one element tensors can be converted to Python scalars

In [None]:
plt.plot(train_hist, label="Training loss")
plt.plot(val_hist, label="Val loss")
plt.legend()

In [None]:
plt.plot(val_auc)