In [None]:
!pip install -r requirements.txt

In [1]:
import os
from glob import glob
import numpy as np
from tqdm import tqdm
import random
import pandas as pd

#torch
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.nn.parameter import Parameter
from torch.utils.data import Dataset,DataLoader

#mne
from mne import Epochs, pick_types, find_events
from mne.io import concatenate_raws, read_raw_edf

#signal preprocess
import scipy.signal as ssig    

from sklearn.metrics import f1_score, classification_report
from sklearn.model_selection import train_test_split

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
seed = 42

In [2]:
# seed 고정 함수 및 seed 고정
def seed_everything(seed: int = 42):
    random.seed(seed)
    np.random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = True

seed_everything(seed)

In [3]:
# train_path = '/home/maic-player/dataset/1_Train+Val'
# test_path = '/home/maic-player/dataset/2_Test'
train_path = '/home/maic-player/FINAL_SUBMISSION/npy_60000/train'
valid_path = '/home/maic-player/FINAL_SUBMISSION/npy_60000/valid'
test_path = '/home/maic-player/FINAL_SUBMISSION/npy_60000/test'
save_path = '/home/maic-player/FINAL_SUBMISSION/results'
save_name = 'deepsleepnet_60000_focal.pth'
csv_name = 'deepsleepnet_60000_focal.csv'

In [4]:
train_path_list = glob(os.path.join(train_path, '*.npy'))
valid_path_list = glob(os.path.join(valid_path, '*.npy'))
test_path_list = glob(os.path.join(test_path, '*.npy'))

In [5]:
class EDFDataLoader(Dataset):
    def __init__(self, mode):
        self.mode = mode

        if self.mode == 'train':
            self.dataset = train_path_list
            
        elif self.mode == 'valid':
            self.dataset = valid_path_list
                
        elif self.mode == 'test':
            self.dataset = test_path_list
        
    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        data_np = np.load(self.dataset[idx])
        if 'Normal' in self.dataset[idx]:
            label = 0
        else:
            label = 1

        if (self.mode == 'train') or (self.mode == 'valid'):
            return data_np, torch.tensor(label)
        elif self.mode == 'test':
            return data_np, self.dataset[idx]

In [6]:
train_dataset = EDFDataLoader('train')
valid_dataset = EDFDataLoader('valid')

train_loader = DataLoader(dataset=train_dataset, shuffle=True, batch_size=16)
valid_loader = DataLoader(dataset=valid_dataset, shuffle=False, batch_size=32)

In [7]:
class BiLSTM(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers):
        super(BiLSTM, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True, dropout=0.5, bidirectional=True)

    def forward(self, x):
        # set initial hidden and cell states
        h0 = torch.zeros(self.num_layers*2, x.size(0), self.hidden_size).cuda() #RuntimeError: Input and hidden tensors are not at the same device
        c0 = torch.zeros(self.num_layers*2, x.size(0), self.hidden_size).cuda()

        # forward propagate LSTM
        out, _ = self.lstm(x, (h0, c0))
        return out

class DeepSleepNet(nn.Module):

    def __init__(self, ch=24):
        super(DeepSleepNet, self).__init__()
        self.features_s = nn.Sequential(
            nn.Conv1d(ch, 64, 50, 6),
            nn.BatchNorm1d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool1d(kernel_size=8, stride=8),
            nn.Dropout(),
            nn.Conv1d(64, 128, 6),
            nn.BatchNorm1d(128),
            nn.Conv1d(128, 128, 6),
            nn.BatchNorm1d(128),
            nn.Conv1d(128, 128, 6),
            nn.BatchNorm1d(128),
            nn.ReLU(inplace=True),
            nn.MaxPool1d(kernel_size=2, stride=2),
        )
        self.features_l = nn.Sequential(
            nn.Conv1d(ch, 64, 400, 50),
            nn.BatchNorm1d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool1d(kernel_size=4, stride=4),
            nn.Dropout(),
            nn.Conv1d(64, 128, 8),
            nn.BatchNorm1d(128),
            nn.Conv1d(128, 128, 8),
            nn.BatchNorm1d(128),
            nn.Conv1d(128, 128, 8),
            nn.BatchNorm1d(128),
            nn.ReLU(inplace=True),
            nn.MaxPool1d(kernel_size=2, stride=2),
        )
        self.features_seq = nn.Sequential(
            BiLSTM(96640, 512, 2),
        )
        self.res = nn.Linear(96640, 1024)
        self.classifier = nn.Sequential(
            nn.Dropout(),
            nn.Linear(1024, 2),
        )

    def forward(self, x):
        x_s = self.features_s(x)
        x_l = self.features_l(x)
        x_s = x_s.flatten(1,2)
        x_l = x_l.flatten(1,2)
        x = torch.cat((x_s, x_l),1) # [bs, 7296]
        x_seq = x.unsqueeze(1)
        x_blstm = self.features_seq(x_seq) # [bs, 1, 1024]
        x_blstm = torch.squeeze(x_blstm, 1)
        x_res = self.res(x)
        x = torch.mul(x_res, x_blstm)
        x = self.classifier(x)
        return x

In [8]:
# Reference: https://github.com/clcarwin/focal_loss_pytorch/blob/master/focalloss.py

class FocalLoss(nn.Module):
    def __init__(self, gamma=0.8, alpha=0.7, size_average=True):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.alpha = alpha
        if isinstance(alpha,(float,int,int)): self.alpha = torch.Tensor([alpha,1-alpha])
        if isinstance(alpha,list): self.alpha = torch.Tensor(alpha)
        self.size_average = size_average

    def forward(self, input, target):
        if input.dim()>2:
            input = input.view(input.size(0),input.size(1),-1)  # N,C,H,W => N,C,H*W
            input = input.transpose(1,2)    # N,C,H*W => N,H*W,C
            input = input.contiguous().view(-1,input.size(2))   # N,H*W,C => N*H*W,C
        target = target.view(-1,1)

        logpt = F.log_softmax(input)
        logpt = logpt.gather(1,target)
        logpt = logpt.view(-1)
        pt = Variable(logpt.data.exp())

        if self.alpha is not None:
            if self.alpha.type()!=input.data.type():
                self.alpha = self.alpha.type_as(input.data)
            at = self.alpha.gather(0,target.data.view(-1))
            logpt = logpt * Variable(at)

        loss = -1 * (1-pt)**self.gamma * logpt
        if self.size_average: return loss.mean()
        else: return loss.sum()

In [9]:
model = DeepSleepNet()
model.to(device)

# criterion = torch.nn.CrossEntropyLoss()
criterion = FocalLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-4, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, 100)

In [None]:
# training
print('Start Training')
print('-'*30)

best_val_f1 = 0
# early_stop_cnt = 0
best_pred = []
for epoch in range(100):
    for idx, data in enumerate(tqdm(train_loader)):
        train_data, train_labels = data
        train_data, train_labels  = train_data.to(device).float(), train_labels.to(device)
        y_pred = model(train_data)
#         y_pred = torch.sigmoid(y_pred)
        loss = criterion(y_pred, train_labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
    # validation  
    model.eval()
    with torch.no_grad():    
        test_pred = []
        test_true = [] 
        for jdx, data in enumerate(tqdm(valid_loader)):
            test_data, test_labels = data
            test_data = test_data.to(device).float()
            test_labels = test_labels.to(device)
            y_pred = model(test_data)
            y_pred = torch.argmax(torch.softmax(y_pred, dim=1), dim=1)
            test_pred.append(y_pred.detach().cpu().numpy())
            test_true.append(test_labels.detach().cpu().numpy())

        test_true = np.concatenate(test_true)
        test_pred = np.concatenate(test_pred)
        val_f1 = f1_score(test_true, test_pred, average='macro')
        
        model.train()
        if best_val_f1 < val_f1:
            best_val_f1 = val_f1
#             early_stop_cnt = 0
            best_pred = test_pred
            torch.save(model.state_dict(), os.path.join(save_path, save_name))
#         else:
#             early_stop_cnt += 1
#             if early_stop_cnt == 5:
#                 print('Epoch=%s, BatchID=%s, Val_F1=%.4f, Best_Val_F1=%.4f'%(epoch, idx, val_f1, best_val_f1))
#                 sys.exit()

        print('Epoch=%s, BatchID=%s, Val_F1=%.4f, Best_Val_F1=%.4f'%(epoch, idx, val_f1, best_val_f1))
        print(classification_report(test_true, test_pred))
        
    scheduler.step()

Start Training
------------------------------


  logpt = F.log_softmax(input)
100%|██████████| 15/15 [00:34<00:00,  2.32s/it]
100%|██████████| 2/2 [00:07<00:00,  3.73s/it]


Epoch=0, BatchID=14, Val_F1=0.6528, Best_Val_F1=0.6528
              precision    recall  f1-score   support

           0       0.41      1.00      0.58        14
           1       1.00      0.57      0.72        46

    accuracy                           0.67        60
   macro avg       0.71      0.78      0.65        60
weighted avg       0.86      0.67      0.69        60



  logpt = F.log_softmax(input)
100%|██████████| 15/15 [00:24<00:00,  1.61s/it]
100%|██████████| 2/2 [00:07<00:00,  3.78s/it]
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Epoch=1, BatchID=14, Val_F1=0.4340, Best_Val_F1=0.6528
              precision    recall  f1-score   support

           0       0.00      0.00      0.00        14
           1       0.77      1.00      0.87        46

    accuracy                           0.77        60
   macro avg       0.38      0.50      0.43        60
weighted avg       0.59      0.77      0.67        60



  logpt = F.log_softmax(input)
100%|██████████| 15/15 [00:30<00:00,  2.02s/it]
100%|██████████| 2/2 [00:09<00:00,  4.97s/it]


Epoch=2, BatchID=14, Val_F1=0.6447, Best_Val_F1=0.6528
              precision    recall  f1-score   support

           0       0.42      0.57      0.48        14
           1       0.85      0.76      0.80        46

    accuracy                           0.72        60
   macro avg       0.64      0.67      0.64        60
weighted avg       0.75      0.72      0.73        60



  logpt = F.log_softmax(input)
100%|██████████| 15/15 [00:27<00:00,  1.81s/it]
100%|██████████| 2/2 [00:11<00:00,  5.53s/it]


Epoch=3, BatchID=14, Val_F1=0.6391, Best_Val_F1=0.6528
              precision    recall  f1-score   support

           0       0.40      0.71      0.51        14
           1       0.89      0.67      0.77        46

    accuracy                           0.68        60
   macro avg       0.64      0.69      0.64        60
weighted avg       0.77      0.68      0.71        60



  logpt = F.log_softmax(input)
100%|██████████| 15/15 [00:27<00:00,  1.82s/it]
100%|██████████| 2/2 [00:08<00:00,  4.39s/it]


Epoch=4, BatchID=14, Val_F1=0.6563, Best_Val_F1=0.6563
              precision    recall  f1-score   support

           0       0.50      0.43      0.46        14
           1       0.83      0.87      0.85        46

    accuracy                           0.77        60
   macro avg       0.67      0.65      0.66        60
weighted avg       0.76      0.77      0.76        60



100%|██████████| 2/2 [00:08<00:00,  4.43s/it]


Epoch=7, BatchID=14, Val_F1=0.4667, Best_Val_F1=0.6563
              precision    recall  f1-score   support

           0       0.30      1.00      0.47        14
           1       1.00      0.30      0.47        46

    accuracy                           0.47        60
   macro avg       0.65      0.65      0.47        60
weighted avg       0.84      0.47      0.47        60



  logpt = F.log_softmax(input)
100%|██████████| 15/15 [00:25<00:00,  1.73s/it]
100%|██████████| 2/2 [00:05<00:00,  2.84s/it]


Epoch=8, BatchID=14, Val_F1=0.6248, Best_Val_F1=0.6563
              precision    recall  f1-score   support

           0       0.57      0.29      0.38        14
           1       0.81      0.93      0.87        46

    accuracy                           0.78        60
   macro avg       0.69      0.61      0.62        60
weighted avg       0.76      0.78      0.75        60



  logpt = F.log_softmax(input)
100%|██████████| 15/15 [00:27<00:00,  1.86s/it]
100%|██████████| 2/2 [00:08<00:00,  4.08s/it]


Epoch=9, BatchID=14, Val_F1=0.6590, Best_Val_F1=0.6590
              precision    recall  f1-score   support

           0       0.47      0.50      0.48        14
           1       0.84      0.83      0.84        46

    accuracy                           0.75        60
   macro avg       0.66      0.66      0.66        60
weighted avg       0.76      0.75      0.75        60



  logpt = F.log_softmax(input)
100%|██████████| 15/15 [00:19<00:00,  1.32s/it]
100%|██████████| 2/2 [00:05<00:00,  2.82s/it]


Epoch=10, BatchID=14, Val_F1=0.5048, Best_Val_F1=0.6590
              precision    recall  f1-score   support

           0       1.00      0.07      0.13        14
           1       0.78      1.00      0.88        46

    accuracy                           0.78        60
   macro avg       0.89      0.54      0.50        60
weighted avg       0.83      0.78      0.70        60



  logpt = F.log_softmax(input)
100%|██████████| 15/15 [00:23<00:00,  1.55s/it]
100%|██████████| 2/2 [00:06<00:00,  3.13s/it]


Epoch=11, BatchID=14, Val_F1=0.6475, Best_Val_F1=0.6590
              precision    recall  f1-score   support

           0       0.41      0.93      0.57        14
           1       0.96      0.59      0.73        46

    accuracy                           0.67        60
   macro avg       0.69      0.76      0.65        60
weighted avg       0.83      0.67      0.69        60



  logpt = F.log_softmax(input)
100%|██████████| 15/15 [00:29<00:00,  1.94s/it]
100%|██████████| 2/2 [00:10<00:00,  5.03s/it]


Epoch=12, BatchID=14, Val_F1=0.6865, Best_Val_F1=0.6865
              precision    recall  f1-score   support

           0       0.47      0.64      0.55        14
           1       0.88      0.78      0.83        46

    accuracy                           0.75        60
   macro avg       0.68      0.71      0.69        60
weighted avg       0.78      0.75      0.76        60



  logpt = F.log_softmax(input)
100%|██████████| 15/15 [00:30<00:00,  2.05s/it]
100%|██████████| 2/2 [00:09<00:00,  4.63s/it]


Epoch=13, BatchID=14, Val_F1=0.7304, Best_Val_F1=0.7304
              precision    recall  f1-score   support

           0       0.50      0.86      0.63        14
           1       0.94      0.74      0.83        46

    accuracy                           0.77        60
   macro avg       0.72      0.80      0.73        60
weighted avg       0.84      0.77      0.78        60



  logpt = F.log_softmax(input)
100%|██████████| 15/15 [00:23<00:00,  1.55s/it]
100%|██████████| 2/2 [00:07<00:00,  3.62s/it]


Epoch=14, BatchID=14, Val_F1=0.5966, Best_Val_F1=0.7304
              precision    recall  f1-score   support

           0       0.44      0.29      0.35        14
           1       0.80      0.89      0.85        46

    accuracy                           0.75        60
   macro avg       0.62      0.59      0.60        60
weighted avg       0.72      0.75      0.73        60



  logpt = F.log_softmax(input)
100%|██████████| 15/15 [00:25<00:00,  1.71s/it]
100%|██████████| 2/2 [00:05<00:00,  2.64s/it]


Epoch=15, BatchID=14, Val_F1=0.7948, Best_Val_F1=0.7948
              precision    recall  f1-score   support

           0       0.60      0.86      0.71        14
           1       0.95      0.83      0.88        46

    accuracy                           0.83        60
   macro avg       0.77      0.84      0.79        60
weighted avg       0.87      0.83      0.84        60



  logpt = F.log_softmax(input)
100%|██████████| 15/15 [00:22<00:00,  1.49s/it]
100%|██████████| 2/2 [00:09<00:00,  4.94s/it]


Epoch=16, BatchID=14, Val_F1=0.6770, Best_Val_F1=0.7948
              precision    recall  f1-score   support

           0       0.43      0.93      0.59        14
           1       0.97      0.63      0.76        46

    accuracy                           0.70        60
   macro avg       0.70      0.78      0.68        60
weighted avg       0.84      0.70      0.72        60



  logpt = F.log_softmax(input)
100%|██████████| 15/15 [00:24<00:00,  1.66s/it]
100%|██████████| 2/2 [00:08<00:00,  4.49s/it]


Epoch=17, BatchID=14, Val_F1=0.6231, Best_Val_F1=0.7948
              precision    recall  f1-score   support

           0       1.00      0.21      0.35        14
           1       0.81      1.00      0.89        46

    accuracy                           0.82        60
   macro avg       0.90      0.61      0.62        60
weighted avg       0.85      0.82      0.77        60



  logpt = F.log_softmax(input)
100%|██████████| 15/15 [00:24<00:00,  1.64s/it]
100%|██████████| 2/2 [00:06<00:00,  3.46s/it]


Epoch=18, BatchID=14, Val_F1=0.7283, Best_Val_F1=0.7948
              precision    recall  f1-score   support

           0       0.53      0.71      0.61        14
           1       0.90      0.80      0.85        46

    accuracy                           0.78        60
   macro avg       0.71      0.76      0.73        60
weighted avg       0.81      0.78      0.79        60



  logpt = F.log_softmax(input)
100%|██████████| 15/15 [00:22<00:00,  1.48s/it]
100%|██████████| 2/2 [00:09<00:00,  4.88s/it]


Epoch=19, BatchID=14, Val_F1=0.7304, Best_Val_F1=0.7948
              precision    recall  f1-score   support

           0       0.50      0.86      0.63        14
           1       0.94      0.74      0.83        46

    accuracy                           0.77        60
   macro avg       0.72      0.80      0.73        60
weighted avg       0.84      0.77      0.78        60



  logpt = F.log_softmax(input)
100%|██████████| 15/15 [00:26<00:00,  1.74s/it]
100%|██████████| 2/2 [00:09<00:00,  4.63s/it]


Epoch=20, BatchID=14, Val_F1=0.7069, Best_Val_F1=0.7948
              precision    recall  f1-score   support

           0       0.48      0.79      0.59        14
           1       0.92      0.74      0.82        46

    accuracy                           0.75        60
   macro avg       0.70      0.76      0.71        60
weighted avg       0.82      0.75      0.77        60



  logpt = F.log_softmax(input)
100%|██████████| 15/15 [00:25<00:00,  1.73s/it]
100%|██████████| 2/2 [00:04<00:00,  2.34s/it]


Epoch=21, BatchID=14, Val_F1=0.6622, Best_Val_F1=0.7948
              precision    recall  f1-score   support

           0       0.42      0.93      0.58        14
           1       0.97      0.61      0.75        46

    accuracy                           0.68        60
   macro avg       0.69      0.77      0.66        60
weighted avg       0.84      0.68      0.71        60



  logpt = F.log_softmax(input)
100%|██████████| 15/15 [00:17<00:00,  1.20s/it]
100%|██████████| 2/2 [00:07<00:00,  3.83s/it]


Epoch=22, BatchID=14, Val_F1=0.7778, Best_Val_F1=0.7948
              precision    recall  f1-score   support

           0       0.62      0.71      0.67        14
           1       0.91      0.87      0.89        46

    accuracy                           0.83        60
   macro avg       0.77      0.79      0.78        60
weighted avg       0.84      0.83      0.84        60



  logpt = F.log_softmax(input)
100%|██████████| 15/15 [00:32<00:00,  2.14s/it]
100%|██████████| 2/2 [00:07<00:00,  3.98s/it]


Epoch=23, BatchID=14, Val_F1=0.7054, Best_Val_F1=0.7948
              precision    recall  f1-score   support

           0       0.58      0.50      0.54        14
           1       0.85      0.89      0.87        46

    accuracy                           0.80        60
   macro avg       0.72      0.70      0.71        60
weighted avg       0.79      0.80      0.79        60



  logpt = F.log_softmax(input)
100%|██████████| 15/15 [00:29<00:00,  1.98s/it]
100%|██████████| 2/2 [00:07<00:00,  3.63s/it]


Epoch=24, BatchID=14, Val_F1=0.7151, Best_Val_F1=0.7948
              precision    recall  f1-score   support

           0       0.48      0.86      0.62        14
           1       0.94      0.72      0.81        46

    accuracy                           0.75        60
   macro avg       0.71      0.79      0.72        60
weighted avg       0.83      0.75      0.77        60



  logpt = F.log_softmax(input)
100%|██████████| 15/15 [00:34<00:00,  2.32s/it]
100%|██████████| 2/2 [00:08<00:00,  4.12s/it]


Epoch=25, BatchID=14, Val_F1=0.7460, Best_Val_F1=0.7948
              precision    recall  f1-score   support

           0       0.52      0.86      0.65        14
           1       0.95      0.76      0.84        46

    accuracy                           0.78        60
   macro avg       0.73      0.81      0.75        60
weighted avg       0.85      0.78      0.80        60



  logpt = F.log_softmax(input)
100%|██████████| 15/15 [00:26<00:00,  1.75s/it]
100%|██████████| 2/2 [00:08<00:00,  4.03s/it]


Epoch=26, BatchID=14, Val_F1=0.6400, Best_Val_F1=0.7948
              precision    recall  f1-score   support

           0       0.67      0.29      0.40        14
           1       0.81      0.96      0.88        46

    accuracy                           0.80        60
   macro avg       0.74      0.62      0.64        60
weighted avg       0.78      0.80      0.77        60



  logpt = F.log_softmax(input)
100%|██████████| 15/15 [00:29<00:00,  2.00s/it]
100%|██████████| 2/2 [00:10<00:00,  5.23s/it]


Epoch=27, BatchID=14, Val_F1=0.6739, Best_Val_F1=0.7948
              precision    recall  f1-score   support

           0       0.50      0.50      0.50        14
           1       0.85      0.85      0.85        46

    accuracy                           0.77        60
   macro avg       0.67      0.67      0.67        60
weighted avg       0.77      0.77      0.77        60



  logpt = F.log_softmax(input)
100%|██████████| 15/15 [00:31<00:00,  2.07s/it]
100%|██████████| 2/2 [00:07<00:00,  3.86s/it]


Epoch=28, BatchID=14, Val_F1=0.6273, Best_Val_F1=0.7948
              precision    recall  f1-score   support

           0       0.43      0.43      0.43        14
           1       0.83      0.83      0.83        46

    accuracy                           0.73        60
   macro avg       0.63      0.63      0.63        60
weighted avg       0.73      0.73      0.73        60



  logpt = F.log_softmax(input)
100%|██████████| 15/15 [00:31<00:00,  2.09s/it]
100%|██████████| 2/2 [00:08<00:00,  4.12s/it]


Epoch=29, BatchID=14, Val_F1=0.7378, Best_Val_F1=0.7948
              precision    recall  f1-score   support

           0       0.52      0.79      0.63        14
           1       0.92      0.78      0.85        46

    accuracy                           0.78        60
   macro avg       0.72      0.78      0.74        60
weighted avg       0.83      0.78      0.80        60



  logpt = F.log_softmax(input)
100%|██████████| 15/15 [00:34<00:00,  2.28s/it]
100%|██████████| 2/2 [00:08<00:00,  4.42s/it]


Epoch=30, BatchID=14, Val_F1=0.7608, Best_Val_F1=0.7948
              precision    recall  f1-score   support

           0       0.59      0.71      0.65        14
           1       0.91      0.85      0.88        46

    accuracy                           0.82        60
   macro avg       0.75      0.78      0.76        60
weighted avg       0.83      0.82      0.82        60



  logpt = F.log_softmax(input)
100%|██████████| 15/15 [00:29<00:00,  1.95s/it]
100%|██████████| 2/2 [00:08<00:00,  4.38s/it]


Epoch=31, BatchID=14, Val_F1=0.7173, Best_Val_F1=0.7948
              precision    recall  f1-score   support

           0       0.53      0.64      0.58        14
           1       0.88      0.83      0.85        46

    accuracy                           0.78        60
   macro avg       0.71      0.73      0.72        60
weighted avg       0.80      0.78      0.79        60



  logpt = F.log_softmax(input)
100%|██████████| 15/15 [00:26<00:00,  1.79s/it]
100%|██████████| 2/2 [00:05<00:00,  2.73s/it]


Epoch=32, BatchID=14, Val_F1=0.6563, Best_Val_F1=0.7948
              precision    recall  f1-score   support

           0       0.50      0.43      0.46        14
           1       0.83      0.87      0.85        46

    accuracy                           0.77        60
   macro avg       0.67      0.65      0.66        60
weighted avg       0.76      0.77      0.76        60



  logpt = F.log_softmax(input)
100%|██████████| 15/15 [00:27<00:00,  1.85s/it]
100%|██████████| 2/2 [00:08<00:00,  4.03s/it]


Epoch=33, BatchID=14, Val_F1=0.6919, Best_Val_F1=0.7948
              precision    recall  f1-score   support

           0       0.46      0.79      0.58        14
           1       0.92      0.72      0.80        46

    accuracy                           0.73        60
   macro avg       0.69      0.75      0.69        60
weighted avg       0.81      0.73      0.75        60



  logpt = F.log_softmax(input)
100%|██████████| 15/15 [00:28<00:00,  1.92s/it]
100%|██████████| 2/2 [00:07<00:00,  3.70s/it]


Epoch=34, BatchID=14, Val_F1=0.6770, Best_Val_F1=0.7948
              precision    recall  f1-score   support

           0       0.43      0.93      0.59        14
           1       0.97      0.63      0.76        46

    accuracy                           0.70        60
   macro avg       0.70      0.78      0.68        60
weighted avg       0.84      0.70      0.72        60



  logpt = F.log_softmax(input)
100%|██████████| 15/15 [00:31<00:00,  2.07s/it]
100%|██████████| 2/2 [00:09<00:00,  4.73s/it]


Epoch=35, BatchID=14, Val_F1=0.7221, Best_Val_F1=0.7948
              precision    recall  f1-score   support

           0       0.64      0.50      0.56        14
           1       0.86      0.91      0.88        46

    accuracy                           0.82        60
   macro avg       0.75      0.71      0.72        60
weighted avg       0.81      0.82      0.81        60



  logpt = F.log_softmax(input)
100%|██████████| 15/15 [00:28<00:00,  1.90s/it]
100%|██████████| 2/2 [00:08<00:00,  4.06s/it]


Epoch=36, BatchID=14, Val_F1=0.7671, Best_Val_F1=0.7948
              precision    recall  f1-score   support

           0       0.64      0.64      0.64        14
           1       0.89      0.89      0.89        46

    accuracy                           0.83        60
   macro avg       0.77      0.77      0.77        60
weighted avg       0.83      0.83      0.83        60



  logpt = F.log_softmax(input)
100%|██████████| 15/15 [00:37<00:00,  2.52s/it]
100%|██████████| 2/2 [00:04<00:00,  2.18s/it]


Epoch=37, BatchID=14, Val_F1=0.7173, Best_Val_F1=0.7948
              precision    recall  f1-score   support

           0       0.53      0.64      0.58        14
           1       0.88      0.83      0.85        46

    accuracy                           0.78        60
   macro avg       0.71      0.73      0.72        60
weighted avg       0.80      0.78      0.79        60



  logpt = F.log_softmax(input)
100%|██████████| 15/15 [00:31<00:00,  2.10s/it]
100%|██████████| 2/2 [00:06<00:00,  3.08s/it]


Epoch=38, BatchID=14, Val_F1=0.6231, Best_Val_F1=0.7948
              precision    recall  f1-score   support

           0       1.00      0.21      0.35        14
           1       0.81      1.00      0.89        46

    accuracy                           0.82        60
   macro avg       0.90      0.61      0.62        60
weighted avg       0.85      0.82      0.77        60



  logpt = F.log_softmax(input)
100%|██████████| 15/15 [00:35<00:00,  2.38s/it]
100%|██████████| 2/2 [00:03<00:00,  1.98s/it]


Epoch=39, BatchID=14, Val_F1=0.6660, Best_Val_F1=0.7948
              precision    recall  f1-score   support

           0       0.62      0.36      0.45        14
           1       0.83      0.93      0.88        46

    accuracy                           0.80        60
   macro avg       0.73      0.65      0.67        60
weighted avg       0.78      0.80      0.78        60



  logpt = F.log_softmax(input)
100%|██████████| 15/15 [00:27<00:00,  1.84s/it]
100%|██████████| 2/2 [00:08<00:00,  4.28s/it]


Epoch=40, BatchID=14, Val_F1=0.6033, Best_Val_F1=0.7948
              precision    recall  f1-score   support

           0       0.37      0.93      0.53        14
           1       0.96      0.52      0.68        46

    accuracy                           0.62        60
   macro avg       0.67      0.73      0.60        60
weighted avg       0.82      0.62      0.64        60



  logpt = F.log_softmax(input)
100%|██████████| 15/15 [00:37<00:00,  2.26s/it]

In [None]:
test_dataset = EDFDataLoader('test')
test_loader = DataLoader(dataset=test_dataset, shuffle=False, batch_size=16)

model = DeepSleepNet()
model.to(device)
model.load_state_dict(torch.load(os.path.join(save_path, save_name)))

In [None]:
# testing
print('Start Testing')
print('-'*30)

model.eval()
with torch.no_grad():    
    test_pred = []
    data_path_lst = []
    for jdx, (test_data, data_path) in enumerate(test_loader):
        test_data = test_data.to(device).float()
        y_pred = model(test_data)
        y_pred = torch.argmax(torch.softmax(y_pred, dim=1), dim=1)
        test_pred.append(y_pred.detach().cpu().numpy())
        data_path_lst.append(data_path)

    test_pred = np.concatenate(test_pred)
    data_path_lst = np.concatenate(data_path_lst)
    
print('Done')

In [None]:
pred_info = {}

for pred, img_path in zip(test_pred, data_path_lst):
    file_name = img_path.split('/')[-1].replace('.edf', '')
    if pred == 1:
        pred_info[file_name] = 'OSA'
    else:
        pred_info[file_name] = 'Normal'
    

sort_pred = dict(sorted(pred_info.items()))
print(sort_pred.items())

submission = pd.DataFrame.from_dict([sort_pred]).T

submission.to_csv(os.path.join(save_path, csv_name), index=True, header=False)
submission.head()