# Baseline Model

## Import packages

In [2]:
import os
import numpy as np
import pandas as pd
import random

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

from tqdm import tqdm
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, f1_score

## CONFIG

#### Data Directory

In [3]:
DATA_DIR = '/workspace/Competition/PSG/data/final'
label_dir = os.path.join(DATA_DIR, 'train_labels.csv')
train_dir = os.path.join(DATA_DIR, 'train')
test_dir = os.path.join(DATA_DIR, 'test')
result_dir = os.path.join('/workspace/Competition/PSG/data/results')
norm_dir = os.path.join(DATA_DIR, 'norm.npy')

#### Set Seed

In [4]:
RANDOM_SEED = 2022

torch.manual_seed(RANDOM_SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
np.random.seed(RANDOM_SEED)
random.seed(RANDOM_SEED)

#### Set Device

In [5]:
# GPU config
os.environ["CUDA_VISIBLE_DEVICES"]="2"
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

#### Hyperparameters

In [12]:
EPOCHS = 20
BATCH_SIZE = 32
LEARNING_RATE = 0.0003
EARLY_STOPPING_PATIENCE = 10

## Define Dataset

In [13]:
class EEG_Single_Dataset(Dataset):
    def __init__(self, datapath, labeldf, normpath):
        self.df = labeldf
        self.label_encoding = {'W':0, 'N1':1, 'N2':2, 'N3':3, 'R':4}
        self.data_path = datapath
        self.file_ids = self.df['rec_id']
        self.labels = self.df['stage']
        self.normparams = np.load(normpath).astype('float32')
        self.mean = self.normparams[0]
        self.std = self.normparams[1]
    
    def __len__(self):
        return len(self.file_ids)
    
    def __getitem__(self,index):
        npypath = os.path.join(self.data_path, self.file_ids[index])
        x = torch.from_numpy(np.load(npypath).astype('float32'))
        x = (x-self.mean)/self.std
        subx = x[-30*128:].view(1,-1)
        label = self.labels[index]
        y = self.label_encoding[label]
        
        return subx,y

## Define Model

In [7]:
class DOUBLE_CNN(nn.Module):
    def __init__(self):
        super(DOUBLE_CNN, self).__init__()
        
        self.small_cnn = nn.Sequential(
            nn.Conv1d(in_channels=1, out_channels=64, kernel_size=int(128/2), stride = int(128/16)),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=8, stride=8),
            nn.Dropout(p=0.3),
            nn.Conv1d(in_channels=64, out_channels=128, kernel_size=4),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Conv1d(in_channels=128, out_channels=128, kernel_size=4),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Conv1d(in_channels=128, out_channels=128, kernel_size=4),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=4, stride=4),
        )

        self.large_cnn = nn.Sequential(
            nn.Conv1d(in_channels=1, out_channels=64, kernel_size=128*4, stride=int(128/2)),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=4, stride=4),
            nn.Dropout(p=0.3),
            nn.Conv1d(in_channels = 64, out_channels=128, kernel_size=3),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Conv1d(in_channels = 128, out_channels = 128, kernel_size=3),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=2, stride=2),
        )
        
        self.fc = nn.Sequential(
            nn.Linear((12+4)*128,1024),
            nn.ReLU(),
            nn.Linear(1024,128),
            nn.ReLU(),
            nn.Linear(128,64),
            nn.ReLU(),
            nn.Linear(64,5)
        )

    def forward(self, x):
        xs = self.small_cnn(x)
        xl = self.large_cnn(x)
        xs = xs.flatten(1,2)
        xl = xl.flatten(1,2)
        xcat = torch.cat((xs,xl),1)
        out = self.fc(xcat)
        return out

## Utils
#### EarlyStopper

In [8]:
class LossEarlyStopper():
    def __init__(self, patience: int)-> None:
        self.patience = patience
        self.patience_counter = 0
        self.min_loss = np.Inf
        self.stop = False
        self.savel_model = False
        
    def check_early_stopping(self, loss: float)-> None:
        if loss > self.min_loss:
            self.patience_counter +=1
            msg = f"Early stopping counter {self.patience_counter}/{self.patience}"
            
            if self.patience_counter == self.patience:
                self.stop=True
            
        else:
            self.patience_counter = 0
            self.save_model = True
            msg = f"Validation loss decreased {self.min_loss} - > {loss}"
            self.min_loss = loss
        print(msg)

#### Trainer

In [9]:
class Trainer():
    def __init__(self, model, optimizer, loss, metrics, device):
        self.model = model
        self.optimizer = optimizer
        self.loss = loss
        self.metric_fn = metrics
        self.device = device
        
    def train_epoch(self, dataloader, epoch_index):
        self.model.train()
        train_total_loss = 0
        target_list = []
        pred_list = []
        
        for batch_index, (x,y) in enumerate(dataloader):
            x,y = x.to(self.device), y.to(self.device)
            y_pred = model(x)
            loss = self.loss(y_pred,y)
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            
            train_total_loss += loss.item()
            pred_list.extend(y_pred.argmax(dim=1).cpu().tolist())
            target_list.extend(y.cpu().tolist())
        self.train_mean_loss = train_total_loss / (batch_index+1)
        train_score, f1 = self.metric_fn(y_pred=pred_list, y_answer=target_list)
        msg = f"Epoch {epoch_index}, Train loss: {self.train_mean_loss}, Acc:{train_score}, F1-Macro: {f1}"
        print(msg)
    
    def validate_epoch(self, dataloader, epoch_index):
        val_total_loss = 0
        target_list = []
        pred_list = []
        
        with torch.no_grad():
            for batch_index, (x, y) in enumerate(dataloader):
                x = x.to(self.device)
                y = y.to(self.device)
                y_pred = self.model(x)
                loss = self.loss(y_pred, y)
                
                val_total_loss += loss.item()
                target_list.extend(y.cpu().tolist())
                pred_list.extend(y_pred.argmax(dim=1).cpu().tolist())
        self.val_mean_loss = val_total_loss / (batch_index+1)
        val_score, f1 = self.metric_fn(y_pred = pred_list, y_answer = target_list)
        msg = f"Epoch {epoch_index}, Val loss: {self.val_mean_loss}, Acc:{val_score}, F1-Macro: {f1}"
        print(msg)

#### Metrics

In [10]:
def get_metric_fn(y_pred, y_answer):
    assert len(y_pred) == len(y_answer), 'The size of prediction and answer are not the same.'
    accuracy = accuracy_score(y_answer, y_pred)
    f1 = f1_score(y_answer, y_pred, average='macro')
    return accuracy, f1

## Train Model

#### Set Dataset & Dataloader 

In [11]:
# Load label dataframe
entiredf = pd.read_csv(label_dir)
traindf, valdf = train_test_split(entiredf, test_size=0.2)
traindf = traindf.reset_index(drop=True)
valdf = valdf.reset_index(drop=True)


train_dataset = EEG_Single_Dataset(datapath=train_dir, labeldf=traindf, normpath=norm_dir)
val_dataset = EEG_Single_Dataset(datapath=train_dir, labeldf=valdf, normpath=norm_dir)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

print('Train set samples: ', len(train_dataset), 'Val set samples: ', len(val_dataset))

NameError: name 'BATCH_SIZE' is not defined

#### Set Model and trainer

In [12]:
model = DOUBLE_CNN().to(DEVICE)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
loss_fn = nn.CrossEntropyLoss()
early_stopper = LossEarlyStopper(patience=EARLY_STOPPING_PATIENCE)
metrics = get_metric_fn

trainer = Trainer(model, optimizer, loss_fn, get_metric_fn, DEVICE)

In [13]:
model

DOUBLE_CNN(
  (small_cnn): Sequential(
    (0): Conv1d(1, 64, kernel_size=(64,), stride=(8,))
    (1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): MaxPool1d(kernel_size=8, stride=8, padding=0, dilation=1, ceil_mode=False)
    (4): Dropout(p=0.3, inplace=False)
    (5): Conv1d(64, 128, kernel_size=(4,), stride=(1,))
    (6): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (7): ReLU()
    (8): Conv1d(128, 128, kernel_size=(4,), stride=(1,))
    (9): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): ReLU()
    (11): Conv1d(128, 128, kernel_size=(4,), stride=(1,))
    (12): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (13): ReLU()
    (14): MaxPool1d(kernel_size=4, stride=4, padding=0, dilation=1, ceil_mode=False)
  )
  (large_cnn): Sequential(
    (0): Conv1d(1, 64, kernel_size=(512,), stride=(64,))
    (1): Batc

### Train

In [14]:
for epoch_index in tqdm(range(EPOCHS)):
    trainer.train_epoch(train_loader, epoch_index)
    trainer.validate_epoch(val_loader, epoch_index)
    
    early_stopper.check_early_stopping(loss = trainer.val_mean_loss)
    
    if early_stopper.stop:
        print('Early Stopped')
        break
    if early_stopper.save_model:
        check_point = {
            'model': model.state_dict(),
            'optimizer': optimizer.state_dict(),
        }
        torch.save(check_point, os.path.join(result_dir,'best.pt'))

  0%|          | 0/20 [00:00<?, ?it/s]

Epoch 0, Train loss: 1.002196481221617, Acc:0.5887037189064938, F1-Macro: 0.4200351549566871


  5%|▌         | 1/20 [00:13<04:22, 13.80s/it]

Epoch 0, Val loss: 0.9083360452204943, Acc:0.639527248850952, F1-Macro: 0.5119867344475857
Validation loss decreased inf - > 0.9083360452204943
Epoch 1, Train loss: 0.862658102994203, Acc:0.6543797717757163, F1-Macro: 0.5123175219949341


 10%|█         | 2/20 [00:27<04:08, 13.83s/it]

Epoch 1, Val loss: 0.8320806616296371, Acc:0.6680892974392646, F1-Macro: 0.5251614337258814
Validation loss decreased 0.9083360452204943 - > 0.8320806616296371
Epoch 2, Train loss: 0.8028361478189784, Acc:0.6889417945981446, F1-Macro: 0.546197073313224


 15%|█▌        | 3/20 [00:41<03:54, 13.79s/it]

Epoch 2, Val loss: 0.7837761209035913, Acc:0.7032173342087984, F1-Macro: 0.5588789763820738
Validation loss decreased 0.8320806616296371 - > 0.7837761209035913
Epoch 3, Train loss: 0.7640197228571874, Acc:0.7117642229701995, F1-Macro: 0.5688389049384436


 20%|██        | 4/20 [00:54<03:38, 13.68s/it]

Epoch 3, Val loss: 0.834229551255703, Acc:0.6723571897570584, F1-Macro: 0.5175885512263101
Early stopping counter 1/10
Epoch 4, Train loss: 0.7274075289723754, Acc:0.7268697151301207, F1-Macro: 0.5806826174322414


 25%|██▌       | 5/20 [01:08<03:25, 13.68s/it]

Epoch 4, Val loss: 0.7388613211611906, Acc:0.7242284963887065, F1-Macro: 0.5788078236168671
Validation loss decreased 0.7837761209035913 - > 0.7388613211611906
Epoch 5, Train loss: 0.6937615220628073, Acc:0.7413184467613496, F1-Macro: 0.5940317052540717


 30%|███       | 6/20 [01:22<03:12, 13.75s/it]

Epoch 5, Val loss: 0.712985510006547, Acc:0.7386736703873933, F1-Macro: 0.5955736544034782
Validation loss decreased 0.7388613211611906 - > 0.712985510006547
Epoch 6, Train loss: 0.678066751697245, Acc:0.7517445201543387, F1-Macro: 0.6025094809162413


 35%|███▌      | 7/20 [01:35<02:57, 13.66s/it]

Epoch 6, Val loss: 0.7181747819607457, Acc:0.7337491792514773, F1-Macro: 0.5899247223544847
Early stopping counter 1/10
Epoch 7, Train loss: 0.6484726528170228, Acc:0.7634020195386257, F1-Macro: 0.6115236961939304


 40%|████      | 8/20 [01:49<02:43, 13.61s/it]

Epoch 7, Val loss: 0.7342343532169858, Acc:0.7301378857518056, F1-Macro: 0.5796328433190732
Early stopping counter 2/10
Epoch 8, Train loss: 0.6325665832035184, Acc:0.7711189557507594, F1-Macro: 0.6183679814746394


 45%|████▌     | 9/20 [02:02<02:29, 13.57s/it]

Epoch 8, Val loss: 0.7091463586936394, Acc:0.7435981615233093, F1-Macro: 0.59651751781811
Validation loss decreased 0.712985510006547 - > 0.7091463586936394
Epoch 9, Train loss: 0.6292416716967355, Acc:0.7705442902881536, F1-Macro: 0.6189440132084688


 50%|█████     | 10/20 [02:16<02:15, 13.54s/it]

Epoch 9, Val loss: 0.6932106815899411, Acc:0.7531188443860801, F1-Macro: 0.6052176600638104
Validation loss decreased 0.7091463586936394 - > 0.6932106815899411
Epoch 10, Train loss: 0.6126158626999442, Acc:0.7762909449142107, F1-Macro: 0.6246104936590812


 55%|█████▌    | 11/20 [02:29<02:01, 13.51s/it]

Epoch 10, Val loss: 0.6729304976761341, Acc:0.7609980302035456, F1-Macro: 0.6104182648641363
Validation loss decreased 0.6932106815899411 - > 0.6729304976761341
Epoch 11, Train loss: 0.5997185035953372, Acc:0.7807240784828832, F1-Macro: 0.6272811830713397


 60%|██████    | 12/20 [02:43<01:47, 13.49s/it]

Epoch 11, Val loss: 0.6893165853495399, Acc:0.747866053841103, F1-Macro: 0.5974179460355438
Early stopping counter 1/10
Epoch 12, Train loss: 0.5907395635377078, Acc:0.7857318775141614, F1-Macro: 0.6336384018254048


 65%|██████▌   | 13/20 [02:56<01:34, 13.49s/it]

Epoch 12, Val loss: 0.6822538975005349, Acc:0.7570584372948129, F1-Macro: 0.6062708353935459
Early stopping counter 2/10
Epoch 13, Train loss: 0.5703487320756662, Acc:0.7927920531976028, F1-Macro: 0.6414372189042272


 70%|███████   | 14/20 [03:10<01:20, 13.48s/it]

Epoch 13, Val loss: 0.6810168915738662, Acc:0.7580433355219961, F1-Macro: 0.6107380054408618
Early stopping counter 3/10
Epoch 14, Train loss: 0.5477196632877109, Acc:0.796240045973237, F1-Macro: 0.6467599862473425


 75%|███████▌  | 15/20 [03:23<01:07, 13.48s/it]

Epoch 14, Val loss: 0.6907382154216369, Acc:0.7541037426132633, F1-Macro: 0.6131787156183954
Early stopping counter 4/10
Epoch 15, Train loss: 0.5424555755193465, Acc:0.8039569821853707, F1-Macro: 0.6558008807389455


 80%|████████  | 16/20 [03:37<00:53, 13.47s/it]

Epoch 15, Val loss: 0.7036121310666203, Acc:0.7580433355219961, F1-Macro: 0.6061423140970417
Early stopping counter 5/10
Epoch 16, Train loss: 0.5242615159884525, Acc:0.8066661193662261, F1-Macro: 0.6608876494440324


 85%|████████▌ | 17/20 [03:50<00:40, 13.51s/it]

Epoch 16, Val loss: 0.7094305884093046, Acc:0.7560735390676296, F1-Macro: 0.6053327977516718
Early stopping counter 6/10
Epoch 17, Train loss: 0.5300345015259865, Acc:0.8039569821853707, F1-Macro: 0.6605007739600006


 90%|█████████ | 18/20 [04:04<00:27, 13.57s/it]

Epoch 17, Val loss: 0.6816253044332067, Acc:0.7583716349310571, F1-Macro: 0.6071942300386772
Early stopping counter 7/10
Epoch 18, Train loss: 0.5058447112796187, Acc:0.8124948690583695, F1-Macro: 0.681008923192746


 95%|█████████▌| 19/20 [04:17<00:13, 13.55s/it]

Epoch 18, Val loss: 0.7148175472393632, Acc:0.7619829284307288, F1-Macro: 0.6174000456640542
Early stopping counter 8/10
Epoch 19, Train loss: 0.4807601214784963, Acc:0.8244807487070027, F1-Macro: 0.6950584675964901


100%|██████████| 20/20 [04:31<00:00, 13.58s/it]

Epoch 19, Val loss: 0.7311292433490356, Acc:0.7508207485226527, F1-Macro: 0.6153676784203238
Early stopping counter 9/10





## Inference

#### Define Test Dataset

In [21]:
class TestDataset(Dataset):
    def __init__(self, datapath, normpath):
        self.data_path = datapath
        self.npy_list = os.listdir(self.data_path)
        self.normparams = np.load(normpath).astype('float32')
        self.mean = self.normparams[0]
        self.std = self.normparams[1]
    
    def __len__(self):
        return len(self.npy_list)
    
    def __getitem__(self,index):
        filename = self.npy_list[index]
        npypath = os.path.join(self.data_path, filename)
        x = torch.from_numpy(np.load(npypath).astype('float32'))
        x = (x-self.mean)/self.std
        subx = x[-30*128:].reshape(1,-1)
        return subx, filename

#### Set Dataset and Dataloader for inference

In [22]:
test_dataset = TestDataset(test_dir, norm_dir)
test_loader = DataLoader(dataset=test_dataset, batch_size=BATCH_SIZE, shuffle=False)

#### Load Model

In [23]:
TRAINED_MODEL_PATH = os.path.join(result_dir, 'best.pt')
test_model = DOUBLE_CNN()
test_model.load_state_dict(torch.load(TRAINED_MODEL_PATH)['model'])

<All keys matched successfully>

#### Make predictions

In [24]:
file_list = []
pred_list = []

test_model.to(DEVICE)
test_model.eval()
with torch.no_grad():
    for batch_index, (x,y) in tqdm(enumerate(test_loader)):
        x = x.to(DEVICE)
        pred = test_model(x)
        
        file_list.extend(list(y))
        pred_list.extend(pred.argmax(dim=1).tolist())

123it [00:02, 54.55it/s]


#### Save predictions

In [25]:
# Make dataframe of predictions
results = pd.DataFrame({'rec_id':file_list, 'stage':pred_list})

# Change predictions to labels
label_decoding = {0:'W', 1:'N1', 2:'N2', 3:'N3', 4:'R'}
results = results.replace(label_decoding)

# Change order of predictios to match sample_submission.csv file
sampledf = pd.read_csv(os.path.join(DATA_DIR,'sample_submission.csv'))
sorter = list(sampledf['rec_id'])
results = results.set_index('rec_id')
results = results.loc[sorter].reset_index()

# Save predictions
results.to_csv(os.path.join(result_dir,'prediction.csv'),index=False)