In [1]:
import os
import sys
sys.path.append("../")
import numpy as np


import torch
from torch.utils.data import Dataset

from tqdm.notebook import tqdm

import args
import downargs
from foundation.models.FOCALModules import FOCAL
import datetime

# Seed every RNG the pipeline may draw from (Python, NumPy, torch CPU/CUDA) so
# Restart & Run All reproduces the same results.
import random

random.seed(args.SEED)
np.random.seed(args.SEED)
torch.manual_seed(args.SEED)
torch.cuda.manual_seed(args.SEED)
torch.cuda.manual_seed_all(args.SEED)
# Trade cuDNN autotuning speed for deterministic kernel selection.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
from data.Augmentaion import init_augmenter

In [2]:
# Expose only the first GPU to this process.
# NOTE(review): CUDA_VISIBLE_DEVICES only takes effect if CUDA has not been
# initialised yet in this kernel; the safest place for it is before `import torch`.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

# Prefer CUDA, then Apple MPS, then CPU. Falling back to "mps" unconditionally
# produced a device that fails on the first `.to(device)` on non-Apple machines.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
elif torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')
print(device)

mps


In [4]:
# Checkpoint and log locations, both taken from the trainer config.
# NOTE(review): `modelpath` is handed to torch.load in a later cell, so despite the
# "model_save_dir" key it presumably names a checkpoint file — confirm.
modelpath, log_path = (
    args.trainer_config["model_save_dir"],
    args.trainer_config["log_save_dir"],
)

In [None]:
# Handle on the training-config module. This is an alias, not a copy: FM_args and
# args are the same module object, so later `args.*` assignments show through FM_args.
FM_args = args

In [None]:
# Load the foundation-model checkpoint and adopt the configs it was trained with.
# map_location="cpu": a checkpoint saved from CUDA otherwise cannot be loaded on a
# CPU/MPS-only machine (the device recorded in this notebook is "mps"). Only the
# config dicts are read here, so keeping tensors on CPU is harmless.
# NOTE(review): torch.load unpickles arbitrary objects — only load trusted checkpoints.
# NOTE(review): `modelpath` comes from trainer_config["model_save_dir"]; confirm it
# names a checkpoint file rather than a directory.
model_ckpt = torch.load(modelpath, map_location="cpu")
args.trainer_config = model_ckpt["trainer_config"]
args.model_config = model_ckpt["model_config"]
args.data_config = model_ckpt["data_config"]

# Alias, not a copy: FM_args and args are the same (now updated) module object.
FM_args = args

In [None]:
# Record the foundation model's configs in the shared save template, which
# `downstream` later copies into every checkpoint.
# Fix: bare `model_config` / `data_config` are not defined by any cell in this
# notebook (NameError on Restart & Run All). Use the configs the previous cell
# loaded onto `args` from the checkpoint.
model_save_format = args.model_save_format
model_save_format["model_config"] = args.model_config
model_save_format["data_config"] = args.data_config
model_save_format["focal_config"] = args.focal_config

In [None]:
class MESADataset(Dataset):
    """Segment-level dataset with one numpy archive per sample.

    Each file in ``file_path`` is opened with ``np.load`` and indexed by key to
    get two modality arrays, a subject id and a sleep-stage label.

    Args:
        file_path: Directory containing one numpy file per segment.
        modalities: Keys of the two modality arrays, in return order
            (default ``('ecg', 'hr')``).
        subject_idx: Key of the subject-id entry.
        stage: Key of the sleep-stage entry.

    ``__getitem__`` returns ``[modality_1, modality_2, subject_id, sleep_stage]``.
    The modalities are float32 tensors; the id and label are int64 tensors. Raw
    stages are remapped 1-2 -> 1, 3-4 -> 2, 5 -> 3; any other value passes
    through unchanged.
    """

    def __init__(self, file_path, modalities=('ecg', 'hr'), subject_idx='subject_idx', stage='stage'):
        super().__init__()
        self.root_dir = file_path
        # os.listdir order is arbitrary; sort so index -> file is reproducible
        # (the notebook seeds every RNG for the same reason).
        self.files = sorted(os.listdir(file_path))
        # Immutable default + private copy: no shared mutable default argument.
        self.modalities = list(modalities)
        self.subject_idx = subject_idx
        self.stage = stage

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        path = os.path.join(self.root_dir, self.files[idx])
        data = np.load(path)  # one numpy file per sample (segment)
        try:
            # Locals, not attributes: per-sample tensors must not live on the shared dataset object.
            modality_1 = torch.tensor(data[self.modalities[0]], dtype=torch.float)
            modality_2 = torch.tensor(data[self.modalities[1]], dtype=torch.float)
            subject_id = torch.tensor(data[self.subject_idx], dtype=torch.long)
            stage = data[self.stage]
        finally:
            # For .npz files np.load returns an NpzFile that keeps the file open
            # until closed. A plain ndarray has no close(), so this is skipped.
            close = getattr(data, "close", None)
            if close is not None:
                close()

        # Collapse raw stage codes; values not listed here are kept as-is.
        # NOTE(review): an unfinished binary (num_outputs == 2) variant was sketched here earlier.
        if stage in (1, 2):
            stage = 1
        elif stage in (3, 4):
            stage = 2
        elif stage == 5:
            stage = 3
        sleep_stage = torch.tensor(stage, dtype=torch.long)

        return [modality_1, modality_2, subject_id, sleep_stage]

In [None]:
def get_accuracy_from_train_process(logit_arr, true_label):
    """Return the fraction of rows whose argmax over dim 1 equals the label.

    Args:
        logit_arr: Tensor of per-class scores with classes along dim 1.
        true_label: Tensor of integer labels, one per row of ``logit_arr``.

    Returns:
        Accuracy as a Python float. An empty ``true_label`` raises
        ZeroDivisionError.
    """
    hits = (logit_arr.argmax(dim=1) == true_label).sum().item()
    n_rows = true_label.size(0)
    return hits / n_rows


def get_acc_loss_from_dataloader(model, dataloder, device, criterion):
    """Evaluate ``model`` over every batch of ``dataloder``.

    Switches the model to eval mode (and leaves it there) and runs under
    ``torch.no_grad()``, so no autograd graph is built. Each batch is unpacked
    as ``(ecg, hr, _, sleep_stage)``, and the model is called as
    ``model(ecg, hr, class_head=True, proj_head=False)``.

    Args:
        model: Module to evaluate.
        dataloder: Iterable of 4-tuples as above.
        device: Target device for the inputs and labels.
        criterion: Loss function ``criterion(output, sleep_stage)``.

    Returns:
        ``(accuracy, mean_loss)``. Accuracy is computed over all samples;
        mean_loss is the mean of the per-batch ``criterion`` values.

    Raises:
        ValueError: if the loader yields no samples.
    """
    model.eval()

    total_correct = 0
    total_samples = 0
    total_loss = 0.0
    n_batches = 0

    with torch.no_grad():
        for ecg, hr, _, sleep_stage in dataloder:
            ecg = ecg.to(device)
            hr = hr.to(device)
            sleep_stage = sleep_stage.to(device)

            output = model(ecg, hr, class_head=True, proj_head=False)
            loss = criterion(output, sleep_stage)

            total_loss += loss.item()
            total_correct += torch.sum(torch.argmax(output, dim=1) == sleep_stage).item()
            total_samples += sleep_stage.size(0)
            n_batches += 1

    # Previously an empty loader crashed with ZeroDivisionError / unbound `i`.
    if n_batches == 0 or total_samples == 0:
        raise ValueError("dataloader yielded no samples; cannot compute accuracy/loss")

    return total_correct / total_samples, total_loss / n_batches

In [None]:
def _train_one_epoch(model, train_loader, optimizer, loss_fn, device, aug_1, aug_2):
    """Run one optimisation pass over ``train_loader``.

    Returns:
        ``(mean per-batch loss, accuracy over all training samples, number of batches)``.

    Raises:
        ValueError: if ``train_loader`` yields no batches.
    """
    model.train()
    running_loss = 0.0
    n_batches = 0
    logit_chunks, label_chunks = [], []

    for ecg, hr, _, sleep_stage in train_loader:
        ecg = ecg.to(device)
        hr = hr.to(device)

        # Same call order as the original, in case an augmenter is stateful.
        aug_1_modal_1 = aug_1(ecg)
        aug_2_modal_1 = aug_2(ecg)
        aug_1_modal_2 = aug_1(hr)
        aug_2_modal_2 = aug_2(hr)

        sleep_stage = sleep_stage.to(device)

        optimizer.zero_grad()
        prediction = model(aug_1_modal_1, aug_1_modal_2, aug_2_modal_1, aug_2_modal_2, proj_head=True, class_head=True)
        loss = loss_fn(prediction, sleep_stage)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        n_batches += 1
        # Keep the batch dim. The old `.squeeze()` collapsed a size-1 batch, so
        # `list.extend` either failed or spilled class scores as separate rows.
        logit_chunks.append(prediction.detach().cpu())
        label_chunks.append(sleep_stage.detach().cpu().reshape(-1))

    if n_batches == 0:
        raise ValueError("train_loader yielded no batches")

    train_acc = get_accuracy_from_train_process(torch.cat(logit_chunks), torch.cat(label_chunks))
    return running_loss / n_batches, train_acc, n_batches


def downstream(model, model_name, train_loader, val_lodaer, optimizer, loss_fn, downstream_config, device):
    """Fine-tune ``model`` on sleep-stage labels and checkpoint the best validation epoch.

    Args:
        model: Module called as
            ``model(aug1_ecg, aug1_hr, aug2_ecg, aug2_hr, proj_head=True, class_head=True)``
            while training. It is evaluated through ``get_acc_loss_from_dataloader``.
        model_name: Not used by this function; kept for call-site compatibility.
        train_loader: Iterable of ``(ecg, hr, subject_id, sleep_stage)`` batches.
        val_lodaer: Validation loader with the same batch layout.
        optimizer: Optimizer over ``model``'s parameters.
        loss_fn: Classification loss ``loss_fn(prediction, sleep_stage)``.
        downstream_config: Dict read for ``"lr"``, ``"epoch"``, ``"model_save_dir"``
            and, optionally, ``"val_freq"``.
        device: Device for the model and batches.

    Returns:
        ``(save_record, (train_losses, train_accs, val_losses, val_accs))``.
        ``save_record`` is a copy of ``args.model_save_format`` with ``lr`` set.
        When some validation accuracy exceeded 0, it also holds the best epoch's
        metrics, checkpoint path and a snapshot of its ``state_dict``.

    Checkpoints go to ``<model_save_dir>/<run start timestamp>/FM_based_classfier_<epoch>.pth``.
    """
    import copy
    # The notebook binds the *module* `datetime`, so `datetime.now()` raised AttributeError.
    from datetime import datetime as _datetime

    model.to(device)
    best_acc = 0

    plot_train_loss = []
    plot_val_loss = []
    plot_val_acc = []
    plot_train_acc = []

    # Copy, so per-run fields (lr, epoch, state_dict, ...) don't leak into the
    # module-level template that other runs share.
    model_save_format = dict(args.model_save_format)
    model_save_format["lr"] = downstream_config["lr"]

    # Per-run directory. It was previously created but never used for saving.
    start_time = _datetime.now().strftime("%Y%m%d_%H%M%S")
    run_dir = os.path.join(downstream_config["model_save_dir"], start_time)
    os.makedirs(run_dir, exist_ok=True)

    # `args` is a module, so the old `args['val_freq']` raised TypeError.
    # NOTE(review): confirm where val_freq is meant to live; this checks
    # downstream_config first, then args.val_freq, then validates every epoch.
    if "val_freq" in downstream_config:
        val_freq = downstream_config["val_freq"]
    else:
        val_freq = getattr(args, "val_freq", 1)

    # Loop-invariant: build the augmenters once instead of once per batch.
    aug_1 = init_augmenter("NoAugmenter", None)
    aug_2 = init_augmenter("NoAugmenter", None)

    for ep in tqdm(range(downstream_config["epoch"])):
        train_loss, train_acc, n_batches = _train_one_epoch(
            model, train_loader, optimizer, loss_fn, device, aug_1, aug_2
        )
        model.eval()

        plot_train_loss.append(train_loss)
        plot_train_acc.append(train_acc)
        # Report the epoch-mean loss (previously the last batch's loss).
        print(f'Epoch: {ep}, Batch: {n_batches}, TrainLoss: {train_loss}, TrainAcc: {train_acc}')

        if ep % val_freq == 0:
            with torch.no_grad():
                val_acc, val_loss = get_acc_loss_from_dataloader(model, val_lodaer, device, loss_fn)
            print(f'(Validation) Epoch: {ep},  ValLoss: {val_loss}, ValAcc: {val_acc}')
            plot_val_acc.append(val_acc)
            plot_val_loss.append(val_loss)

            if val_acc > best_acc:
                print("--------" * 15)
                best_acc = val_acc
                ckpt_path = os.path.join(run_dir, f'FM_based_classfier_{ep}.pth')
                model_save_format["epoch"] = ep
                # state_dict() returns live references to the parameters, so later
                # epochs would silently overwrite the "best" weights; snapshot them.
                model_save_format["state_dict"] = copy.deepcopy(model.state_dict())
                model_save_format["model_path"] = ckpt_path
                model_save_format["train_acc"] = train_acc
                model_save_format["train_loss"] = train_loss
                model_save_format["val_acc"] = val_acc
                model_save_format["val_loss"] = val_loss

                torch.save(model_save_format, ckpt_path)
                print("Best Model Saved!")
                print("--------" * 15)

    print("Finished Training")
    print(f'Best Validation Accuracy: {best_acc}')

    return model_save_format, (plot_train_loss, plot_train_acc, plot_val_loss, plot_val_acc)