# Deep Sleep 2.0 - Model Training

## Main Switches

In [None]:
# MODEL:
# Select an existing model configuration (see the 'models' folder and paper) or
# define a custom configuration using a JSON configuration file.
MODEL_NAME = 'model_2'  # available: {'model_0'; 'model_1'; 'model_2'; 'model_3'}

# MODE:
# Select between model training (True) or model inference (False).
TRAIN_MODE = True  # available options: {True; False}

# CHECKPOINT:
# Resume from the last available training checkpoint?
LOAD_CHECKPOINT = False  # available options: {True; False}

## Modules

In [None]:
import sys
import os
import random
import numpy as np
import pickle
import json
import csv
import time
import importlib
import matplotlib.pyplot as plt
from itertools import zip_longest
from sklearn.model_selection import train_test_split
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.dataset import random_split
from torchsummary import summary
from losses import CustomBCELoss, CustomBCEWithLogitsLoss
from tqdm import tqdm
from score2018 import Challenge2018Score
from utils import preprocess, get_record, folders_to_records_txt

# Report whether PyTorch can use a CUDA GPU ("GPU is available." / "GPU is not available.").
print('GPU is' if torch.cuda.is_available() else 'GPU is not', 'available.')

## Paths

In [None]:
# Folder layout: <root>/data/{training,test} for the records and
# <root>/models/<MODEL_NAME> for the configuration, checkpoints and results.
ROOT_PATH = './'
DATA_PATH = os.path.join(ROOT_PATH, 'data')
TRAIN_DATA_PATH, TEST_DATA_PATH = (os.path.join(DATA_PATH, subfolder)
                                   for subfolder in ('training', 'test'))
MODEL_PATH = os.path.join(ROOT_PATH, 'models', MODEL_NAME)

## Hyperparameters and Reproducibility

In [None]:
# Load the model configuration and hyperparameters: a JSON document stored in
# 'hyperparameters.txt' inside the selected model folder.
HYPERPARAMETERS_FILE = os.path.join(MODEL_PATH, 'hyperparameters.txt')
if not os.path.isfile(HYPERPARAMETERS_FILE):
    # explicit raise rather than `assert`, which is skipped under `python -O`
    raise FileNotFoundError("File hyperparameters.txt does not exist in: " + MODEL_PATH + os.path.sep)

with open(HYPERPARAMETERS_FILE) as f:
    hyperparameters = json.load(f)
print('\033[1m',"Hyperparameters specific to the selected model:",'\033[0m', \
      *hyperparameters.items(), sep='\n')

ARCHITECTURE_NAME = hyperparameters['ARCHITECTURE_NAME']
SEED              = hyperparameters['SEED']
MAX_NUM_EPOCHS    = hyperparameters['MAX_NUM_EPOCHS']
CHANNELS          = hyperparameters['CHANNELS']
LEARNING_RATE     = hyperparameters['LEARNING_RATE']
DECAY_RATE        = hyperparameters['DECAY_RATE']
STOP_STRIP        = hyperparameters['STOP_STRIP']
DEVICE            = hyperparameters['DEVICE']
BATCH_SIZE        = hyperparameters['BATCH_SIZE']
NUM_WORKERS       = hyperparameters['NUM_WORKERS']
PIN_MEMORY        = hyperparameters['PIN_MEMORY']
LINEAR            = hyperparameters['LINEAR']
Z_NORM            = hyperparameters['Z_NORM']
TRANSFORMS        = hyperparameters['TRANSFORMS']

# In inference mode, override the device/loader settings from the file with
# what is available locally (enables analysis/inference on CPU-only devices).
if not TRAIN_MODE:
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    BATCH_SIZE = 2 if torch.cuda.is_available() else 1
    NUM_WORKERS = 2 if torch.cuda.is_available() else 0

# Seed the random number generators for reproducibility.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)  # every visible GPU, not only the current one
# NOTE: the hash seed of the running interpreter is fixed at start-up; setting
# this variable only affects Python processes started afterwards.
os.environ["PYTHONHASHSEED"] = str(SEED)

## Data Availability & Preprocessing

In [None]:
# Scan DATA_PATH (including its subfolders) for records and write the
# respective record paths to 'RECORDS.txt'.
folders_to_records_txt(DATA_PATH)

# Bring all recording/arousal signals to the same length of 8 million samples
# (2^23 = 8,388,608) by zero-padding and centering them; apply Z-score
# normalization when Z_NORM is set.
preprocess(DATA_PATH, Z_NORM)

## Custom Dataset

In [None]:
class CustomDataset(Dataset):
    """Custom dataset for multichannel PSG recordings and sleep arousal annotations.

    Records are loaded lazily from disk by ``get_record`` on each access. Each
    item is a ``(recording, arousal, idx)`` triple. ``idx`` lets callers map a
    batch element back to its record name via :meth:`get_record_name`.

    Args:
        data_path: folder that the entries of ``records`` are joined onto.
        records: list of record paths relative to ``data_path``.
        channels: indices of the PSG channels to load (default: ``list(range(13))``).
        z_norm: selects the ``normalization`` mode passed to ``get_record``
            (2 if True, else 1). Presumably this chooses Z-score-normalized
            versus raw signals; confirm in ``utils.get_record``.
        transforms: optional callable ``(recording, arousal) -> (recording, arousal)``.
    """

    def __init__(self, data_path, records, channels = None, z_norm = False, transforms = None):
        # None instead of a list literal default: a default list would be one
        # object shared (and mutable) across every instance.
        if channels is None:
            channels = list(range(13))
        self.data_path = data_path
        self.records = records
        self.channels = channels
        self.n_channels = len(channels)
        self.normalization = 2 if z_norm else 1
        self.transforms = transforms

    def __len__(self):
        return len(self.records)

    def __getitem__(self, idx):
        file_path = os.path.join(self.data_path, self.records[idx])
        recording, arousal = get_record(file_path, self.channels, self.normalization)

        if self.transforms is not None:
            recording, arousal = self.transforms(recording, arousal)

        return recording, arousal, idx

    def get_record_name(self, idx):
        """Return the record path (relative to ``data_path``) at position ``idx``."""
        return self.records[idx]
    

## Custom Transforms

In [None]:
class MagScale(object):
    """Multiply all PSG channels by one shared factor drawn uniformly from [low, high).

    The arousal signal is passed through untouched.
    """

    def __init__(self, low = 0.8, high = 1.25):
        self.low, self.high = low, high

    def __call__(self, recording, arousal):
        span = self.high - self.low
        factor = torch.rand(1) * span + self.low
        return factor * recording, arousal

    
class MagScaleRandCh(object):
    """Rescale the magnitude of every PSG channel with its own random scale factor.

    ``n_channels`` independent factors are drawn uniformly from [low, high).
    Each one multiplies the corresponding row of ``recording``. ``arousal`` is
    returned unchanged.

    Args:
        n_channels: number of factors to draw. It must equal the number of rows
            of ``recording`` (or be 1) for the factors to broadcast.
        low, high: bounds of the uniform scale distribution.
    """

    def __init__(self, n_channels = 13, low = 0.8, high = 1.25):
        self.n_channels = n_channels
        self.low = low
        self.high = high

    def __call__(self, recording, arousal):
        # shape (n_channels, 1) so each factor broadcasts along its channel's time axis
        scales = self.low + torch.rand(self.n_channels).view(-1,1)*(self.high - self.low)
        recording = scales*recording

        return recording, arousal

    
class RandShuffle(object):
    """Randomly permute groups of related PSG channels while keeping the others in place.

    Channel indices are hard-coded for the 13-channel layout:
    - 0-5: F3-M2, F4-M1, C3-M2, C4-M1, O1-M2, O2-M1 (permuted among themselves)
    - 6-7: E1-M2, Chin (fixed)
    - 8-9: ABD, Chest (permuted among themselves)
    - 10-12: Airflow, SaO2, ECG (fixed)
    """

    def __init__(self):
        self.r2 = torch.LongTensor([6, 7])
        self.r4 = torch.LongTensor([10, 11, 12])

    def __call__(self, recording, arousal):
        eeg_order = torch.randperm(6)
        abd_chest_order = torch.randperm(2) + 8
        order = torch.cat([eeg_order, self.r2, abd_chest_order, self.r4]).long()
        shuffled = torch.index_select(recording, 0, order)
        return shuffled, arousal

    
class AddRandGaussian2All(object):
    """Add zero-mean Gaussian noise to every PSG channel.

    With ``z_norm`` the noise standard deviation is a flat 0.1. Otherwise it is
    10% of each channel's own standard deviation (computed along dim 1).
    """

    def __init__(self, z_norm = True):
        self.z_norm = z_norm

    def __call__(self, recording, arousal):
        std_dev = 0.1 if self.z_norm else 0.1 * recording.std(dim = 1, keepdim = True)
        noise = torch.randn(recording.shape)
        return recording + std_dev * noise, arousal
    
    
class InjectRandGaussian(object):
    """Overwrite one randomly chosen PSG channel with standard-normal noise.

    ``recording`` is modified in place (and also returned); ``arousal`` is untouched.
    """

    def __init__(self, n_channels = 13):
        self.n_channels = n_channels

    def __call__(self, recording, arousal):
        channel = torch.randint(0, self.n_channels, (1,)).long()
        noise = torch.normal(mean = 0, std = 1, size = (1, recording.shape[1]))
        recording[channel] = noise
        return recording, arousal
    

class TimeScale(object):
    """Stretch/shrink the recording and arousal signals with a random time scale while
       maintaining the original lengths and shapes.

    The scale factor is drawn uniformly from ``[1 - interval/2, 1 + interval/2)``.
    Both signals are resampled with ``F.interpolate`` (default 'nearest' mode).
    They are then center-cropped (when stretched) or symmetrically zero-padded
    (when shrunk) back to the original number of samples.
    NOTE(review): zero is used as the padding value for the arousal signal too;
    confirm this matches the padding applied by ``preprocess``.

    Args:
        interval: width of the interval the time-scale factor is drawn from.
        n_channels: kept for backward compatibility; the channel count is now
            taken from the recording itself.
    """

    def __init__(self, interval, n_channels = 13):
        self.interval = interval
        self.n_channels = n_channels

    @staticmethod
    def _fit_length(x, length):
        """Center-crop or symmetrically zero-pad the last dimension of ``x`` to ``length``."""
        diff = x.shape[-1] - length
        if diff > 0:
            start = diff // 2
            return x[..., start:start + length]
        if diff < 0:
            left = (-diff) // 2
            return F.pad(x, (left, -diff - left))
        return x

    def __call__(self, recording, arousal):
        length = recording.shape[-1]
        # F.interpolate expects a Python float scale factor, not a 1-element tensor
        scale = float(1 + self.interval*(torch.rand(1) - 0.5))

        # interpolate works on (batch, channels, length) tensors
        scaled_recording = F.interpolate(recording.unsqueeze(0), scale_factor = scale,
                                         recompute_scale_factor = True).squeeze(0)
        scaled_arousal = F.interpolate(arousal.reshape((1,1,-1)), scale_factor = scale,
                                       recompute_scale_factor = True)

        recording = self._fit_length(scaled_recording, length)
        arousal = self._fit_length(scaled_arousal, length).reshape(arousal.shape)

        return recording, arousal

    
class ToTensor(object):
    """Wrap the recording and the arousal signal in ``torch.Tensor`` objects."""

    def __call__(self, recording, arousal):
        return tuple(map(torch.Tensor, (recording, arousal)))
    
    
class Compose:
    """Chain several ``(recording, arousal)`` transforms, applied in list order."""

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, recording, arousal):
        pair = (recording, arousal)
        for step in self.transforms:
            pair = step(*pair)
        out_recording, out_arousal = pair
        return out_recording, out_arousal
    

## Dataset Split

In [None]:
# Read the record lists written by folders_to_records_txt(): the training
# folder's list is required; the test folder's list is optional.
TRAIN_RECORDS_TXT = os.path.join(TRAIN_DATA_PATH, 'RECORDS.txt')
if not os.path.isfile(TRAIN_RECORDS_TXT):
    # explicit raise rather than `assert`, which is skipped under `python -O`
    raise FileNotFoundError("File RECORDS.txt does not exist in: " + TRAIN_DATA_PATH + os.path.sep)

with open(TRAIN_RECORDS_TXT, 'r') as f:
    records_in_train_folder = f.read().splitlines()
try:
    with open(os.path.join(TEST_DATA_PATH, 'RECORDS.txt'), 'r') as f:
        records_in_test_folder = f.read().splitlines()
except FileNotFoundError:
    records_in_test_folder = None

# an empty RECORDS.txt in the test folder is treated like a missing one
if records_in_test_folder:
    print("PhysioNet test dataset available and will be used for testing (100%).")
else:
    print("PhysioNet test dataset not available.")

train_tot = len(records_in_train_folder)

if not records_in_test_folder:

    # Default case: the complete PhysioNet test dataset is not available.
    # The PhysioNet training dataset is split in training (60%), validation (15%),
    # and test (25%) sets; the test records are then read from the training folder.

    TEST_DATA_PATH = TRAIN_DATA_PATH

    train_split = int(0.6*train_tot)
    val_split   = int(0.15*train_tot)
    test_split  = train_tot - train_split - val_split

    train_records, tmp = train_test_split(records_in_train_folder,
                                          test_size = val_split + test_split,
                                          random_state = SEED)
    val_records, test_records = train_test_split(tmp, test_size = test_split,
                                                 random_state = SEED)

    print("The PhysioNet training dataset has been split in",
          "training ({:2.0%}),".format(train_split/train_tot),
          "validation ({:2.0%}),".format(val_split/train_tot),
          "and test ({:2.0%}) sets.".format(test_split/train_tot))

else:

    # Exceptional case: the complete PhysioNet test dataset is available (incl. labels!).
    # The PhysioNet training dataset is split in training (80%) and validation (20%)
    # sets; the complete PhysioNet test dataset is used for testing purposes.

    train_split = int(0.8*train_tot)
    val_split = train_tot - train_split

    train_records, val_records = train_test_split(records_in_train_folder,
                                                  test_size = val_split,
                                                  random_state = SEED)
    test_records = records_in_test_folder

    print("The PhysioNet training dataset has been split in",
          "training ({:2.0%})".format(train_split/train_tot),
          "and validation ({:2.0%}) sets.".format(val_split/train_tot))

# Record the splits, one column per set (shorter columns are padded with empty cells).
if TRAIN_MODE:
    # newline='' as required by the csv module; otherwise every row is followed
    # by an extra blank line on Windows
    with open(os.path.join(MODEL_PATH, 'records.csv'), 'w', newline = '') as f:
        writer = csv.writer(f)
        writer.writerow(['train_records', 'val_records', 'test_records'])
        writer.writerows(zip_longest(train_records, val_records, test_records))
        

## Dataloader

In [None]:
# Build the training-time augmentation pipeline from the TRANSFORMS list of the
# model configuration. Transforms run in the listed order, after the
# numpy -> tensor conversion. Validation/test data only get the conversion.
n_channels = len(CHANNELS)

# NOTE(review): TimeScale has no default stretch interval. It is read from an
# optional 'TIME_SCALE_INTERVAL' hyperparameter; the 0.1 fallback is a placeholder.
TIME_SCALE_INTERVAL = hyperparameters.get('TIME_SCALE_INTERVAL', 0.1)

transform_factories = {
    'MagScale'           : lambda: MagScale(),
    'MagScaleRandCh'     : lambda: MagScaleRandCh(n_channels = n_channels),
    'TimeScale'          : lambda: TimeScale(TIME_SCALE_INTERVAL, n_channels = n_channels),
    # RandShuffle hard-codes channel indices 0-12, so it needs all 13 channels loaded
    'RandShuffle'        : lambda: RandShuffle(),
    # noise level depends on whether the signals were Z-score normalized
    'AddRandGaussian2All': lambda: AddRandGaussian2All(z_norm = Z_NORM),
    'InjectRandGaussian' : lambda: InjectRandGaussian(n_channels = n_channels),
}

train_transforms_list = [ToTensor()]
for transform in TRANSFORMS:
    if transform in transform_factories:
        train_transforms_list.append(transform_factories[transform]())
    else:
        print("Warning: unknown transform '{}' ignored.".format(transform))

train_transforms = Compose(train_transforms_list)
val_test_transforms = ToTensor()

train_dataset = CustomDataset(TRAIN_DATA_PATH, train_records, CHANNELS,
                              Z_NORM, train_transforms)
val_dataset   = CustomDataset(TRAIN_DATA_PATH, val_records, CHANNELS,
                              Z_NORM, val_test_transforms)
test_dataset  = CustomDataset(TEST_DATA_PATH, test_records, CHANNELS,
                              Z_NORM, val_test_transforms)

print("Length of the train dataset is:", len(train_dataset))
print("Length of the val dataset is:", len(val_dataset))
print("Length of the test dataset is:", len(test_dataset))

train_loader = DataLoader(train_dataset, BATCH_SIZE, shuffle = True,
                          num_workers = NUM_WORKERS, pin_memory = PIN_MEMORY)
val_loader   = DataLoader(val_dataset, BATCH_SIZE, shuffle = False,
                          num_workers = NUM_WORKERS, pin_memory = PIN_MEMORY)
test_loader  = DataLoader(test_dataset, BATCH_SIZE, shuffle = False,
                          num_workers = NUM_WORKERS, pin_memory = PIN_MEMORY)

print("Total number of train batches:", len(train_loader))
print("Total number of val batches:", len(val_loader))
print("Total number of test batches:", len(test_loader))

## Evaluation Function

In [None]:
def eval_fn(model, loader, DEVICE, comp_score = False):
    """Evaluate ``model`` on every batch of ``loader`` without gradient tracking.

    Args:
        model: network to evaluate; called as ``model(x, comp_score)``.
        loader: DataLoader yielding ``(x, y, idx)`` batches. Its dataset must
            provide ``get_record_name(idx)``.
        DEVICE: device the batches are moved to.
        comp_score: if True, ``CustomBCELoss`` is used and per-record
            AUROC/AUPRC are computed with the official PhysioNet scoring class.
            The model is called with ``comp_score=True``; presumably it then
            returns probabilities rather than logits (confirm in the
            architecture). If False, ``CustomBCEWithLogitsLoss`` is used, under
            CUDA autocast when a GPU is available.

    Returns:
        ``(mean_loss, scores)``: the loss averaged over batches (``nan`` for an
        empty loader) and the ``Challenge2018Score`` object, or ``None`` when
        ``comp_score`` is False.
    """
    # eval mode: dropout is disabled and batch-norm layers use their running
    # statistics instead of per-batch statistics
    model.eval()

    # use the official PhysioNet scoring function
    scores = Challenge2018Score() if comp_score else None

    num_batches = len(loader)
    if num_batches == 0:
        return float('nan'), scores

    # define loss function
    if comp_score:
        loss_fn = CustomBCELoss().to(DEVICE)
    else:
        loss_fn = CustomBCEWithLogitsLoss().to(DEVICE)

    # compute loss and, if applicable, score
    loss_epoch_sum = 0.0
    with torch.no_grad():
        for x, y, idx in loader:
            x = x.to(device = DEVICE)
            y = y.to(device = DEVICE)

            # mixed precision only for the plain loss pass; scoring runs in full precision
            with torch.cuda.amp.autocast(enabled = torch.cuda.is_available() \
                                         and not comp_score):

                y_hat = model(x, comp_score)
                loss = loss_fn(y_hat, y)

                # compute AUROC/AUPRC score for each record in batch
                if comp_score:
                    for i, single_idx in enumerate(idx):
                        record = loader.dataset.get_record_name(single_idx)
                        scores.score_record(y[i].view(-1).to('cpu'), \
                                            y_hat[i].view(-1).to('cpu'), record)
                        auroc = scores.record_auroc(record)
                        auprc = scores.record_auprc(record)
                        print('%-11s  AUROC: %8.6f,  AUPRC: %8.6f' % \
                              (record, auroc, auprc))

            loss_epoch_sum += loss.item()

    return loss_epoch_sum/num_batches, scores

## Training Function

In [None]:
def train_fn(model, loader, optimizer, loss_fn, scaler, DEVICE):
    """Train ``model`` for one epoch over ``loader`` with optional CUDA mixed precision.

    Args:
        model: network to train; called as ``model(x)``.
        loader: DataLoader yielding ``(x, y, idx)`` batches (``idx`` is ignored).
        optimizer: optimizer updating ``model``'s parameters.
        loss_fn: criterion applied as ``loss_fn(y_hat, y)``.
        scaler: gradient scaler used for the backward pass and the optimizer step.
        DEVICE: device the batches are moved to.

    Returns:
        The training loss averaged over all batches (``nan`` if ``loader`` is empty).
    """
    # training mode: dropout is active and batch-norm layers use/update batch statistics
    model.train()

    num_batches = len(loader)
    if num_batches == 0:
        return float('nan')

    loss_epoch_sum = 0.0

    with tqdm(total = num_batches) as pbar:
        # perf_counter is monotonic and high-resolution, unlike time.time()
        start_time = time.perf_counter()

        for batch_idx, (x, y, _) in enumerate(loader):

            # Step 0. move the tensors to the right device
            x = x.to(device = DEVICE)
            y = y.to(device = DEVICE)

            # time spent fetching the batch and moving it to the device
            prepare_time = time.perf_counter() - start_time

            # Step 1. clear gradients
            optimizer.zero_grad(set_to_none = True)

            with torch.cuda.amp.autocast(enabled = torch.cuda.is_available()):

                # Step 2. forward pass
                y_hat = model(x)

                # Step 3. loss calculation
                loss = loss_fn(y_hat, y)

            # Step 4. backward pass on the scaled loss
            scaler.scale(loss).backward()

            # Step 5. optimization (parameter update) and scale-factor update
            scaler.step(optimizer)
            scaler.update()

            # Step 6. timing and logging
            loss_epoch_sum += loss.item()

            # time spent in forward/backward/update, and its share of the iteration
            process_time = time.perf_counter() - start_time - prepare_time
            total_time = prepare_time + process_time
            compute_efficiency = process_time / total_time if total_time > 0 else float('nan')

            pbar.update(1)
            pbar.set_postfix({'Running loss' : loss_epoch_sum/(batch_idx + 1),
                              'Compute efficiency': compute_efficiency})
            start_time = time.perf_counter()

    return loss_epoch_sum/num_batches

## Training Loop

In [None]:
# Create the model, optimizer, loss and AMP gradient scaler; optionally resume
# from the most recent checkpoint found in MODEL_PATH.

DeepSleepNet = getattr(importlib.import_module('architectures.' + ARCHITECTURE_NAME), 'DeepSleepNet')

model     = DeepSleepNet(in_channels = len(CHANNELS), linear = LINEAR).to(DEVICE)
optimizer = torch.optim.Adam(model.parameters(), lr = LEARNING_RATE, weight_decay = DECAY_RATE)
loss_fn   = CustomBCEWithLogitsLoss().to(DEVICE)
scaler    = torch.cuda.amp.GradScaler(enabled = torch.cuda.is_available())

# defaults for a fresh run; overwritten below when a checkpoint is restored
epoch0 = 0
train_loss_history = []
val_loss_history = []

if LOAD_CHECKPOINT:
    # checkpoints are named 'my_checkpoint_<epoch>.pth.tar'; pick the highest epoch
    chp_prefix, chp_suffix = 'my_checkpoint_', '.pth.tar'
    chp_numbers = [
        int(name[len(chp_prefix):-len(chp_suffix)])
        for name in os.listdir(MODEL_PATH)
        if name.startswith(chp_prefix) and name.endswith(chp_suffix)
        and name[len(chp_prefix):-len(chp_suffix)].isdigit()
    ]

    if not chp_numbers:
        print('No checkpoint file found.')
    else:
        chp_no = max(chp_numbers)
        last_checkpoint_path = os.path.join(MODEL_PATH, chp_prefix + str(chp_no) + chp_suffix)
        # NOTE: torch.load unpickles the file; only load checkpoints you trust.
        # Load errors (e.g. an architecture mismatch) now propagate instead of
        # being reported as a missing checkpoint.
        checkpoint = torch.load(last_checkpoint_path,
                                map_location = torch.device(DEVICE))

        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        # older checkpoints may not contain the AMP scaler state
        if 'scaler_state_dict' in checkpoint:
            scaler.load_state_dict(checkpoint['scaler_state_dict'])
        epoch0 = checkpoint['epoch'] + 1
        train_loss_history = checkpoint['train_loss_history']
        val_loss_history = checkpoint['val_loss_history']

        print('Model saved in my_checkpoint_' + str(chp_no) + '.pth.tar loaded.')
    
# Start the training procedure: each epoch trains, validates, logs the losses,
# and saves a checkpoint. Training ends at MAX_NUM_EPOCHS or on early stop.
if TRAIN_MODE:
    for epoch in range(epoch0, MAX_NUM_EPOCHS):
        np.random.seed(SEED + epoch)

        # stop as soon as any recorded loss has diverged to NaN
        if np.isnan(train_loss_history + val_loss_history).any():
            print("Early stop has been triggered (NaN loss).")
            break

        # patience-based early stop: the validation loss recorded STOP_STRIP + 1
        # epochs ago is still lower than every validation loss recorded since
        if len(val_loss_history) > STOP_STRIP and \
           val_loss_history[-(STOP_STRIP + 1)] < min(val_loss_history[-STOP_STRIP:]):
            print("Early stop has been triggered.")
            break

        # run one epoch training
        train_loss_epoch = train_fn(model, train_loader, optimizer, loss_fn, scaler, DEVICE)

        # check validation loss
        val_loss_epoch, _ = eval_fn(model, val_loader, DEVICE)

        # loss logging
        train_loss_history.append(train_loss_epoch)
        val_loss_history.append(val_loss_epoch)

        # loss printing
        print("\nEpoch: {} \t Mean training loss {:.6f}".format(epoch, train_loss_epoch))
        print("Epoch: {} \t Mean validation loss {:.6f}".format(epoch, val_loss_epoch))

        # save current epoch model checkpoint (incl. the AMP scaler state so a
        # resumed run continues with the same loss scale)
        checkpoint = {
            "epoch"                : epoch,
            "model_state_dict"     : model.state_dict(),
            "optimizer_state_dict" : optimizer.state_dict(),
            "scaler_state_dict"    : scaler.state_dict(),
            "train_loss_history"   : train_loss_history,
            "val_loss_history"     : val_loss_history,
        }
        torch.save(checkpoint, os.path.join(MODEL_PATH,
                                            'my_checkpoint_'+str(epoch)+'.pth.tar'))
        print(">Checkpoint {} saved<".format(epoch))

## Training & Validation Loss Histories

In [None]:
# Plot the per-epoch training and validation losses against the epoch index.
for history, line_style, legend_label in ((train_loss_history, '-rx', "Training loss"),
                                          (val_loss_history,   '-bx', "Validation loss")):
    plt.plot(range(len(history)), history, line_style, label = legend_label)
plt.xlabel("Epoch (#)")
plt.ylabel("Cross Entropy Loss")
plt.legend(loc = 'upper right', borderaxespad = 0.7, shadow = True)
plt.grid(linestyle = '--', linewidth = 0.1)

## Identify the Model with the Smallest Validation Loss

In [None]:
# Pick the epoch with the lowest validation loss and restore its checkpoint.
if not val_loss_history:
    raise RuntimeError("No validation loss history available: "
                       "train the model or load a checkpoint first.")
# nanargmin skips diverged (NaN) epochs; plain min() returns NaN whenever the
# first entry is NaN, which would select epoch 0
best_model_idx = int(np.nanargmin(val_loss_history))
best_model_path = os.path.join(MODEL_PATH, 'my_checkpoint_' + str(best_model_idx) + '.pth.tar')
checkpoint = torch.load(best_model_path, map_location = torch.device(DEVICE))
model.load_state_dict(checkpoint['model_state_dict'])
print("Best model has been loaded from my_checkpoint_" + str(best_model_idx) + ".pth.tar")

## Compute Test Loss and Gross AUROC/AUPRC

In [None]:
# Compute (or reload) the test loss and the PhysioNet scores of the best model.
test_loss_path = os.path.join(MODEL_PATH, 'test_loss_' + str(best_model_idx) + '.pickle')
test_score_path = os.path.join(MODEL_PATH, 'test_score_' + str(best_model_idx) + '.pickle')

# Reuse cached results only in inference mode. After (re)training, a cache file
# with the same epoch index may belong to an earlier training run.
cache_available = os.path.exists(test_loss_path) and os.path.exists(test_score_path)
if cache_available and not TRAIN_MODE:
    # NOTE: pickle.load can execute arbitrary code; only load caches written by this notebook
    with open(test_loss_path, 'rb') as f1, open(test_score_path, 'rb') as f2:
        test_loss = pickle.load(f1)
        test_score = pickle.load(f2)
else:
    test_loss, test_score = eval_fn(model, test_loader, DEVICE, comp_score = True)
    with open(test_loss_path, 'wb') as f1, open(test_score_path, 'wb') as f2:
        pickle.dump(test_loss, f1)
        pickle.dump(test_score, f2)

# NOTE(review): relies on the private `_record_auc` attribute of Challenge2018Score
print("Test results based on {:d} test cases.".format(len(test_score._record_auc)))
print("Cross-entropy loss = {:.6f}".format(test_loss))
print('Gross AUROC: %8.6f' % test_score.gross_auroc())
print('Gross AUPRC: %8.6f' % test_score.gross_auprc())

## Model Summary

In [None]:
# torchsummary builds its dummy input on `device` (default "cuda"), which fails
# on CPU-only machines; use the device type the model was placed on instead.
summary(model, input_size=(len(CHANNELS), 2**23),
        device='cuda' if str(DEVICE).startswith('cuda') else 'cpu')