In [2]:
import os
import sys
sys.path.append("../src")
import numpy as np


import torch
from torch.utils.data import Dataset

from datetime import datetime
from tqdm import tqdm

import args

# Seed PyTorch's global RNG from the project-level args.SEED; the call returns
# the torch Generator object that is echoed as this cell's output.
torch.manual_seed(args.SEED) # set the random seed

<torch._C.Generator at 0x1169a0890>

In [10]:
# Configuration objects come from the project-local `args` module.
# `model_save_format` is the dict that `pretrain` fills in and writes with torch.save.
data_config = args.data_config
model_config = args.pretrain_config
model_save_format = args.model_save_format

# Expose only the first CUDA device to this process.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

# Pick the best backend that is actually usable: CUDA, then Apple MPS, then CPU.
# (Previously "mps" was selected whenever CUDA was missing, even on machines
# without MPS support, which only fails later when tensors are moved.)
if torch.cuda.is_available():
    device = torch.device('cuda')
elif torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')
print(device)

mps


In [11]:
class MESADataset(Dataset):
    """Dataset serving one sample per file found in a directory.

    Every directory entry is treated as a sample file readable by ``np.load``
    that can be indexed by key (e.g. an ``.npz`` archive).

    Args:
        file_path: Directory whose entries are the sample files.
        modalities: Keys of the signal arrays; the first two are loaded.
        subject_idx: Key under which the subject identifier is stored.
        stage: Key under which the label (sleep stage) is stored.
    """

    def __init__(self, file_path, modalities=('ecg', 'hr'), subject_idx='subject_idx', stage='stage'):
        super(MESADataset, self).__init__()
        self.root_dir = file_path
        # os.listdir returns entries in arbitrary order; sort so that the
        # index -> file mapping is identical across runs and machines.
        self.files = sorted(os.listdir(file_path))
        # Copy so instances never share the default or the caller's sequence.
        self.modalities = list(modalities)
        self.subject_idx = subject_idx
        self.stage = stage

    def __len__(self):
        """Number of sample files in the directory."""
        return len(self.files)

    def __getitem__(self, idx):
        """Load sample ``idx``.

        Returns:
            list: ``[modality_1, modality_2, subject_id, sleep_stage]`` where the
            two modalities are float tensors and the last two are long tensors.
        """
        data = np.load(os.path.join(self.root_dir, self.files[idx]))  # one file per sample (segment)
        try:
            # Use locals rather than instance attributes: a sample is not dataset state.
            modality_1 = torch.tensor(data[self.modalities[0]], dtype=torch.float)
            modality_2 = torch.tensor(data[self.modalities[1]], dtype=torch.float)
            subject_id = torch.tensor(data[self.subject_idx], dtype=torch.long)
            sleep_stage = torch.tensor(data[self.stage], dtype=torch.long)
        finally:
            # .npz archives hold an open file handle until closed; plain arrays
            # returned by np.load have no close() and are left untouched.
            close = getattr(data, 'close', None)
            if close is not None:
                close()

        return [modality_1, modality_2, subject_id, sleep_stage]

In [13]:
def get_accuracy_from_train_process(logit_arr, true_label):
    """Return the fraction of rows whose arg-max class equals the label.

    Args:
        logit_arr: Tensor of per-class scores; the class axis is dim 1.
        true_label: Tensor of ground-truth class indices; its first dimension
            is used as the sample count.

    Returns:
        float: number of matches divided by ``true_label.size(0)``.
    """
    hits = (torch.argmax(logit_arr, dim=1) == true_label).sum()
    return hits.item() / true_label.size(0)


def get_acc_loss_from_dataloader(model, dataloder, device, criterion):
    """Evaluate ``model`` on every batch of ``dataloder``.

    The model is switched to eval mode and is left in eval mode on return, so
    callers that continue training must call ``model.train()`` afterwards.

    Args:
        model: Callable as ``model(ecg, hr)`` returning per-class scores (class axis = dim 1).
        dataloder: Iterable of ``(ecg, hr, _, sleep_stage)`` batches.
        device: Device the inputs and labels are moved to.
        criterion: Loss callable taking ``(output, sleep_stage)``.

    Returns:
        tuple: ``(accuracy, loss)`` - fraction of correctly classified samples
        and the mean loss per sample.

    Raises:
        ZeroDivisionError: If the loader yields no samples.
    """
    model.eval()

    # A mean-reduced criterion returns a per-batch average, which must be
    # re-weighted by the batch size before dividing by the total sample count.
    # A sum-reduced criterion already returns the batch total.
    loss_is_summed = getattr(criterion, 'reduction', 'mean') == 'sum'

    total_correct = 0
    total_samples = 0
    total_loss = 0

    # No gradients are needed for evaluation; skipping them saves memory and time.
    with torch.no_grad():
        for ecg, hr, _, sleep_stage in dataloder:
            ecg = ecg.to(device)
            hr = hr.to(device)
            sleep_stage = sleep_stage.to(device)

            output = model(ecg, hr)
            loss = criterion(output, sleep_stage)

            batch_size = sleep_stage.size(0)
            total_loss += loss.item() if loss_is_summed else loss.item() * batch_size
            total_correct += torch.sum(torch.argmax(output, dim=1) == sleep_stage).item()
            total_samples += batch_size

    return total_correct / total_samples, total_loss / total_samples

In [14]:
def pretrain(model, model_name, train_loader, val_lodaer, optimizer, loss_fn, args, device):
    """Train ``model`` and checkpoint it whenever validation accuracy improves.

    Args:
        model: Module called as ``model(ecg, hr)`` returning per-class scores (class axis = dim 1).
        model_name: Label embedded in the checkpoint file name.
        train_loader: Iterable (with ``len``) of ``(ecg, hr, _, sleep_stage)`` batches.
        val_lodaer: Validation batches in the same layout.
        optimizer: Optimizer already bound to ``model``'s parameters.
        loss_fn: Loss callable taking ``(prediction, sleep_stage)``.
        args: Config mapping with keys ``"lr"``, ``"epoch"``, ``"val_freq"`` and
            ``"model_save_dir"``. Note this parameter shadows the ``args``
            module inside the function. ``"epoch"`` may be an int (number of
            epochs) or an iterable of epoch indices.
        device: Device the model and batches are moved to.

    Returns:
        The same ``model`` object with the weights of the last epoch; the best
        checkpoint exists only on disk.

    Side effects:
        Mutates the module-level ``model_save_format`` dict and writes
        checkpoints under ``args["model_save_dir"]``.
    """
    model.to(device)
    model.train()
    best_acc = 0

    # Per-epoch curves; collected for plotting but not returned by this function.
    plot_train_loss = []
    plot_val_loss = []
    plot_val_acc = []
    plot_train_acc = []

    model_save_format["lr"] = args["lr"]

    # Accept either an epoch count or an explicit iterable of epoch indices.
    n_epochs = args["epoch"]
    epochs = range(n_epochs) if isinstance(n_epochs, int) else n_epochs

    for ep in epochs:
        # Validation switches the model to eval mode, so training mode has to
        # be re-enabled at the start of every epoch.
        model.train()
        train_loss = 0
        train_correct = 0
        train_samples = 0
        for i, data in enumerate(train_loader):
            ecg, hr, _, sleep_stage = data
            ecg = ecg.to(device)
            hr = hr.to(device)
            sleep_stage = sleep_stage.to(device)

            optimizer.zero_grad()
            prediction = model(ecg, hr)
            loss = loss_fn(prediction, sleep_stage)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            # Keep running counts instead of storing every batch of logits.
            train_correct += torch.sum(torch.argmax(prediction.detach(), dim=1) == sleep_stage).item()
            train_samples += sleep_stage.size(0)

        train_loss /= len(train_loader)
        train_acc = train_correct / train_samples

        plot_train_loss.append(train_loss)
        plot_train_acc.append(train_acc)
        # Report the epoch-mean loss, i.e. the same value that is stored and checkpointed.
        print(f'Epoch: {ep}, Batch: {i}, TrainLoss: {train_loss}, TrainAcc: {train_acc}')

        if ep % args['val_freq'] == 0:
            val_acc, val_loss = get_acc_loss_from_dataloader(model, val_lodaer, device, loss_fn)
            print(f'(Validation) Epoch: {ep}, ValAcc: {val_acc}, ValLoss: {val_loss}')
            plot_val_acc.append(val_acc)
            plot_val_loss.append(val_loss)

            if val_acc > best_acc:
                print("--------"*15)
                best_acc = val_acc
                timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
                # torch.save does not create missing directories.
                os.makedirs(args["model_save_dir"], exist_ok=True)
                MODELPATH = os.path.join(args["model_save_dir"], f'pretrain_{model_name}_{ep}_{timestamp}.pth')
                model_save_format["epoch"] = ep
                model_save_format["state_dict"] = model.state_dict()
                model_save_format["model_path"] = MODELPATH
                model_save_format["train_acc"] = train_acc
                model_save_format["train_loss"] = train_loss
                model_save_format["val_acc"] = val_acc
                model_save_format["val_loss"] = val_loss

                torch.save(model_save_format, MODELPATH)
                print("Best Model Saved!")
                print("--------"*15)

    print("Finished Training")
    print(f'Best Validation Accuracy: {best_acc}')

    return model

In [None]:
# One DataLoader per split, all using the batch size from model_config.
# Only the training split is shuffled.
# NOTE(review): with num_workers > 0 on platforms whose multiprocessing start
# method is "spawn", a Dataset class defined inside the notebook may fail to
# unpickle in the worker processes - confirm on the target machine.
train_loader = torch.utils.data.DataLoader(
    MESADataset(data_config["train_data_dir"]),
    batch_size=model_config["batch_size"],
    shuffle=True,
    num_workers=4,
)
val_loader = torch.utils.data.DataLoader(
    MESADataset(data_config["val_data_dir"]),
    batch_size=model_config["batch_size"],
    shuffle=False,
    num_workers=4,
)
test_loader = torch.utils.data.DataLoader(
    MESADataset(data_config["test_data_dir"]),
    batch_size=model_config["batch_size"],
    shuffle=False,
    num_workers=4,
)

In [None]:
model = None # To-do List: instantiate the pre-training model here
if model is None:
    # Fail with an actionable message instead of an AttributeError on model.parameters().
    raise NotImplementedError("Assign a model (called as model(ecg, hr)) before running this cell.")

optimizer = torch.optim.Adam(model.parameters(), lr=model_config["lr"], weight_decay=model_config["weight_decay"])
loss_fn = torch.nn.CrossEntropyLoss()

# pretrain's second positional parameter is `model_name` (used in the checkpoint
# file name); it was missing from the call, which shifted every later argument.
pretrained_model = pretrain(model, type(model).__name__, train_loader, val_loader, optimizer, loss_fn, model_config, device)