In [None]:
import os
import sys
sys.path.append("../")
import numpy as np


import torch
import torch.nn as nn
from torch.utils.data import Dataset

from tqdm.notebook import tqdm

import downargs as args
from foundation.models.FOCALModules import FOCAL
from foundation.models.Backbone import DeepSense
from classifier import SleepStageClassifier
import datetime

# Seed every RNG this notebook may draw from so that runs are reproducible.
torch.manual_seed(args.SEED)
torch.cuda.manual_seed(args.SEED)
torch.cuda.manual_seed_all(args.SEED)
# NumPy is imported above but was never seeded; seed its global RNG too in case
# the dataset or augmenters draw from it.
np.random.seed(args.SEED)
# Prefer deterministic cuDNN kernels over autotuned (possibly nondeterministic) ones.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

from foundation.data.Dataset import MESAPairDataset
from foundation.data.Augmentaion import init_augmenter

In [None]:
# Pick the best available accelerator: CUDA, then Apple MPS, then CPU.
# (Falling back to "mps" unconditionally breaks on machines that have neither.)
if torch.cuda.is_available():
    device = torch.device('cuda:0')
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(device)

In [None]:
# Two augmentation branches, both built from the "NoAugmenter" config
# (presumably an identity transform — confirm in foundation.data.Augmentaion).
aug_1, aug_2 = (init_augmenter("NoAugmenter", None).to(device) for _ in range(2))

In [None]:
def get_accuracy_from_train_process(logit_arr, true_label):
    """Return the fraction of rows whose arg-max logit matches the true label.

    Args:
        logit_arr: tensor of logits with class scores along dim 1.
        true_label: tensor of integer class labels, expected one per row of
            ``logit_arr``.

    Returns:
        Accuracy as a Python float (number correct / ``true_label.size(0)``).
    """
    n_correct = (logit_arr.argmax(dim=1) == true_label).sum().item()
    return n_correct / true_label.size(0)


def get_acc_loss_from_dataloader(model, downstream_model, dataloder, loss_fn, device):
    """Evaluate the encoder + downstream classifier over a dataloader.

    Each batch is unpacked as ``(ecg, hr, _, sleep_stage)``. Both modalities go
    through the module-level ``aug_1`` / ``aug_2`` augmenters before the encoder,
    mirroring the training loop in this notebook.

    Args:
        model: encoder called as ``model(a1_m1, a1_m2, a2_m1, a2_m2,
            proj_head=True, class_head=False)`` and returning two feature tensors.
        downstream_model: classifier mapping the two feature tensors to logits.
        dataloder: iterable of batches (misspelt name kept for caller compatibility).
        loss_fn: callable ``loss_fn(logits, labels)`` returning a scalar tensor.
        device: device the batch tensors are moved to.

    Returns:
        ``(accuracy, mean_loss)``: accuracy over all samples, and the mean of
        the per-batch losses.

    Raises:
        ValueError: if the dataloader yields no samples.
    """
    model.eval()
    # Evaluate the classifier in eval mode too (dropout / batch-norm off), but
    # restore its previous mode afterwards so a surrounding training loop that
    # never calls .train() again is not silently switched to eval mode.
    was_training = downstream_model.training
    downstream_model.eval()

    total_correct = 0
    total_samples = 0
    total_loss = 0.0
    num_batches = 0

    try:
        # No gradients are needed for evaluation; avoid building autograd graphs.
        with torch.no_grad():
            for ecg, hr, _, sleep_stage in dataloder:
                ecg = ecg.to(device)
                hr = hr.to(device)

                # Same augmentation call order as training (matters if augmenters are random).
                aug_1_modal_1 = aug_1(ecg)
                aug_2_modal_1 = aug_2(ecg)
                aug_1_modal_2 = aug_1(hr)
                aug_2_modal_2 = aug_2(hr)

                sleep_stage = sleep_stage.to(device)

                mod_feature1, mod_feature2 = model(aug_1_modal_1, aug_1_modal_2,
                                                   aug_2_modal_1, aug_2_modal_2,
                                                   proj_head=True, class_head=False)

                prediction = downstream_model(mod_feature1, mod_feature2)
                loss = loss_fn(prediction, sleep_stage)

                total_loss += loss.item()
                total_correct += (prediction.argmax(dim=1) == sleep_stage).sum().item()
                total_samples += sleep_stage.size(0)
                num_batches += 1
    finally:
        downstream_model.train(was_training)

    if num_batches == 0 or total_samples == 0:
        raise ValueError("dataloader yielded no samples to evaluate")

    return total_correct / total_samples, total_loss / num_batches

In [None]:
def _snapshot_state_dict(module):
    """Return a detached copy of ``module.state_dict()`` that later training cannot mutate."""
    return {k: v.detach().clone() for k, v in module.state_dict().items()}


def _train_one_epoch(model, downstream_model, train_loader, optimizer, loss_fn, device):
    """Train ``downstream_model`` for one epoch on features from the frozen ``model``.

    Returns:
        ``(mean_batch_loss, train_accuracy, num_batches)``.

    Raises:
        ValueError: if ``train_loader`` yields no batches.
    """
    # The encoder is frozen: keep it in eval mode so batch-norm running stats are
    # not updated and dropout is off. Only the classifier is in training mode.
    model.eval()
    downstream_model.train()

    train_loss = 0.0
    num_batches = 0
    logits_per_batch = []
    labels_per_batch = []

    for ecg, hr, _, sleep_stage in train_loader:
        ecg = ecg.to(device)
        hr = hr.to(device)

        aug_1_modal_1 = aug_1(ecg)
        aug_2_modal_1 = aug_2(ecg)

        aug_1_modal_2 = aug_1(hr)
        aug_2_modal_2 = aug_2(hr)

        sleep_stage = sleep_stage.to(device)

        optimizer.zero_grad()
        with torch.no_grad():
            mod_feature1, mod_feature2 = model(aug_1_modal_1, aug_1_modal_2, aug_2_modal_1, aug_2_modal_2,
                                               proj_head=True, class_head=False)

        prediction = downstream_model(mod_feature1, mod_feature2)
        loss = loss_fn(prediction, sleep_stage)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        num_batches += 1
        # Keep the batch dimension (no squeeze) so a final batch of size 1
        # still concatenates correctly.
        logits_per_batch.append(prediction.detach().cpu())
        labels_per_batch.append(sleep_stage.detach().cpu())

    if num_batches == 0:
        raise ValueError("train_loader yielded no batches")

    train_acc = get_accuracy_from_train_process(torch.cat(logits_per_batch), torch.cat(labels_per_batch))
    return train_loss / num_batches, train_acc, num_batches


def downstream(model, downstream_model, train_loader, val_lodaer, optimizer, loss_fn, downargs, device, model_idx):
    """Linear-probe training of ``downstream_model`` on top of the frozen encoder ``model``.

    Every ``downstream_config['val_freq']`` epochs the pair is evaluated on
    ``val_lodaer``. Whenever validation accuracy improves, a checkpoint is written to
    ``<model_save_dir>/<SUBJECT_ID>/FM_based_classfier_<model_idx>.pth``.

    Args:
        model: encoder; its parameters are frozen (``requires_grad = False``).
        downstream_model: classifier whose parameters are optimised.
        train_loader / val_lodaer: batch iterables yielding ``(ecg, hr, _, sleep_stage)``.
        optimizer: optimiser over ``downstream_model``'s parameters.
        loss_fn: callable ``loss_fn(logits, labels)``.
        downargs: config object providing ``model_save_format`` (template dict,
            not modified), ``downstream_config`` and ``SUBJECT_ID``.
        device: device batches are moved to.
        model_idx: identifier used in the checkpoint file name.

    Returns:
        ``(checkpoint_dict, (train_loss, train_acc, val_loss, val_acc))``.
        ``checkpoint_dict`` describes the best validation epoch. ``"state_dict"``
        holds the encoder weights and ``"downstream_state_dict"`` the classifier
        weights. The second element holds the per-epoch / per-validation curves
        as lists.
    """
    # Freeze the encoder once; only the classifier is optimised.
    for param in model.parameters():
        param.requires_grad = False
    for param in downstream_model.parameters():
        param.requires_grad = True

    best_acc = float("-inf")  # guarantees the first validation produces a checkpoint

    plot_train_loss = []
    plot_val_loss = []
    plot_val_acc = []
    plot_train_acc = []

    # Copy the template: this function runs once per pretrained checkpoint, so
    # mutating the shared config dict would leak state between runs.
    model_save_format = dict(downargs.model_save_format)
    model_save_format["lr"] = downargs.downstream_config["lr"]

    modelPATH = os.path.join(downargs.downstream_config["model_save_dir"], downargs.SUBJECT_ID)
    os.makedirs(modelPATH, exist_ok=True)

    for ep in tqdm(range(downargs.downstream_config["epoch"])):
        train_loss, train_acc, num_batches = _train_one_epoch(
            model, downstream_model, train_loader, optimizer, loss_fn, device)

        plot_train_loss.append(train_loss)
        plot_train_acc.append(train_acc)

        print(f'Epoch: {ep}, Batch: {num_batches}, TrainLoss: {train_loss}, TrainAcc: {train_acc}')

        if ep % downargs.downstream_config['val_freq'] == 0:

            val_acc, val_loss = get_acc_loss_from_dataloader(model, downstream_model, val_lodaer, loss_fn, device)
            print(f'(Validation) Epoch: {ep},  ValLoss: {val_loss}, ValAcc: {val_acc}')

            plot_val_acc.append(val_acc)
            plot_val_loss.append(val_loss)

            if val_acc > best_acc:
                print("--------"*15)
                best_acc = val_acc

                MODELPATH = os.path.join(modelPATH, f'FM_based_classfier_{model_idx}.pth')
                model_save_format["epoch"] = ep
                # Snapshot (clone) so later epochs do not overwrite the "best" weights.
                model_save_format["state_dict"] = _snapshot_state_dict(model)
                model_save_format["downstream_state_dict"] = _snapshot_state_dict(downstream_model)
                model_save_format["model_path"] = MODELPATH
                model_save_format["train_acc"] = train_acc
                model_save_format["train_loss"] = train_loss
                model_save_format["val_acc"] = val_acc
                model_save_format["val_loss"] = val_loss

                torch.save(model_save_format, MODELPATH)
                print("Best Model Saved!")
                print("--------"*15)

    print("Finished Training")
    print(f'Best Validation Accuracy: {best_acc}')

    return model_save_format, (plot_train_loss, plot_train_acc, plot_val_loss, plot_val_acc)

In [None]:
# NOTE: the per-checkpoint loop below overwrites args.trainer_config with each
# checkpoint's config, so re-running this cell after that loop reads different paths.
modelpath = args.trainer_config["model_save_dir"]
log_path = args.trainer_config["log_save_dir"]

# Regular files only, in a stable order (os.listdir order is arbitrary), so that
# sub-directories are not mistaken for checkpoints and runs are reproducible.
model_list = sorted(
    name for name in os.listdir(modelpath)
    if os.path.isfile(os.path.join(modelpath, name))
)

In [None]:
def _make_mesa_loader(split_dir_key, batch_size, shuffle, num_workers):
    """Build the MESAPairDataset for one split (``args.data_config[split_dir_key]``) and its DataLoader."""
    dataset = MESAPairDataset(file_path=args.data_config[split_dir_key],
                              modalities=args.data_config['modalities'],
                              subject_idx=args.data_config['subject_key'],
                              stage=args.data_config['label_key'])
    loader = torch.utils.data.DataLoader(dataset,
                                         batch_size=batch_size,
                                         shuffle=shuffle,
                                         num_workers=num_workers)
    return dataset, loader


TRAIN_BATCH_SIZE = args.trainer_config['batch_size']
# Validation/test use a quarter of the training batch size, clamped to >= 1:
# DataLoader rejects batch_size=0, which `// 4` yields for batch sizes below 4.
EVAL_BATCH_SIZE = max(1, TRAIN_BATCH_SIZE // 4)

train_dataset, train_loader = _make_mesa_loader('train_data_dir', TRAIN_BATCH_SIZE, shuffle=True, num_workers=4)
val_dataset, val_loader = _make_mesa_loader('val_data_dir', EVAL_BATCH_SIZE, shuffle=False, num_workers=2)
test_dataset, test_loader = _make_mesa_loader('test_data_dir', EVAL_BATCH_SIZE, shuffle=False, num_workers=2)

In [None]:
# For every pretrained FOCAL checkpoint: rebuild the encoder, train a sleep-stage
# classifier on its frozen features, save the training curves, then evaluate on test.
for model_name in sorted(model_list):
    ckpt_path = os.path.join(modelpath, model_name)
    name_parts = model_name.split("_")
    # Skip directories / stray files; names are assumed to look like
    # "<a>_<b>_<c>_<index>.<ext>" (index = 4th underscore-separated field).
    if not os.path.isfile(ckpt_path) or len(name_parts) < 4:
        print(f"Skipping {model_name}: not a recognised checkpoint file")
        continue
    model_index = name_parts[3].split(".")[0]

    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    model_ckpt = torch.load(ckpt_path, map_location=device)

    args.trainer_config = model_ckpt['trainer_config']
    args.focal_config = model_ckpt["focal_config"]
    args.data_config = model_ckpt["data_config"]

    args.downstream_config['embedding_dim'] = model_ckpt['focal_config']['embedding_dim']

    backbone = DeepSense(args).to(device)
    focal_model = FOCAL(args, backbone).to(device)
    # strict=False tolerates key mismatches silently; report them so a
    # mostly-uninitialised encoder is not probed by accident.
    load_result = focal_model.load_state_dict(model_ckpt["focal_state_dict"], strict=False)
    if load_result.missing_keys or load_result.unexpected_keys:
        print(f"[{model_name}] missing keys: {load_result.missing_keys}; "
              f"unexpected keys: {load_result.unexpected_keys}")

    downstream_model = SleepStageClassifier(args).to(device)

    downstream_loss_fn = nn.CrossEntropyLoss()
    downstream_optimizer = torch.optim.Adam(downstream_model.parameters(), lr=args.downstream_config['lr'])

    ckpt, logs = downstream(focal_model, downstream_model, train_loader, val_loader,
                            downstream_optimizer, downstream_loss_fn, args, device, model_index)

    # Write logs into the downstream log dir (args.trainer_config now comes from the
    # pretraining checkpoint, so its log_save_dir is not this run's directory).
    logPATH = os.path.join(args.downstream_config["log_save_dir"], args.SUBJECT_ID)
    os.makedirs(logPATH, exist_ok=True)

    # Curves can differ in length (validation every val_freq epochs), so store them
    # under separate keys instead of stacking them into one (ragged) array.
    plot_train_loss, plot_train_acc, plot_val_loss, plot_val_acc = logs
    np.savez(os.path.join(logPATH, f'FM_based_classfier_{model_index}.npz'),
             train_loss=np.asarray(plot_train_loss),
             train_acc=np.asarray(plot_train_acc),
             val_loss=np.asarray(plot_val_loss),
             val_acc=np.asarray(plot_val_acc))

    # Evaluate the classifier from the best validation epoch when the training
    # routine provides its weights; otherwise the last-epoch weights are used.
    if ckpt.get("downstream_state_dict") is not None:
        downstream_model.load_state_dict(ckpt["downstream_state_dict"])

    test_acc, test_loss = get_acc_loss_from_dataloader(focal_model, downstream_model, test_loader, downstream_loss_fn, device)
    test_acc = round(test_acc, 2)
    test_loss = round(test_loss, 2)

    LOGPATH = os.path.join(logPATH, f'FM_based_{model_index}_acc{test_acc}_loss{test_loss}.npz')
    np.savez(LOGPATH, test_acc=test_acc, test_loss=test_loss)