In [1]:
import os
import sys
sys.path.append("../")
import numpy as np


import torch
import torch.nn as nn
from torch.utils.data import Dataset

from tqdm.notebook import tqdm

import args
import downargs
from foundation.models.FOCALModules import FOCAL
from foundation.models.Backbone import DeepSense
from classifier import SleepStageClassifier
import datetime

# Seed every RNG this notebook can draw from so a Restart & Run All is repeatable.
# NumPy was previously left unseeded even though the notebook imports and uses it.
np.random.seed(args.SEED)
torch.manual_seed(args.SEED)
torch.cuda.manual_seed(args.SEED)
torch.cuda.manual_seed_all(args.SEED)
# Request deterministic cuDNN kernels and turn off its autotuner.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

from foundation.data.Dataset import MESAPairDataset
from foundation.data.Augmentaion import init_augmenter

  return torch._C._cuda_getDeviceCount() > 0


In [2]:
class MESAPairDataset(Dataset):
    """Dataset that serves one file from ``file_path`` per sample.

    Each file is read with ``np.load`` and indexed by the two modality keys,
    the subject key and the stage key. Raw stage values are collapsed before
    being returned: 1 and 2 become 1, 3 and 4 become 2, 5 becomes 3; any other
    value is passed through unchanged.

    NOTE(review): this notebook-local class shadows the ``MESAPairDataset``
    imported from ``foundation.data.Dataset`` earlier in the notebook --
    confirm which definition is meant to be used.

    Args:
        file_path: Directory whose entries are the per-sample files.
        modalities: Two keys; index 0 and index 1 are read from every file.
        subject_idx: Key under which the subject id is stored.
        stage: Key under which the raw stage label is stored.
    """

    def __init__(self, file_path, modalities=('ecg', 'hr'), subject_idx='subject_idx', stage='stage'):
        super().__init__()
        self.root_dir = file_path
        # os.listdir returns entries in arbitrary order; sorting keeps the
        # index -> file mapping stable across machines and runs.
        self.files = sorted(os.listdir(file_path))
        # Copy into a list so the (now immutable) default is never shared or mutated.
        self.modalities = list(modalities)
        self.subject_idx = subject_idx
        self.stage = stage

    def __len__(self):
        """Number of files found in the data directory."""
        return len(self.files)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``[modality_1, modality_2, subject_id, sleep_stage]``.

        The two modalities are float tensors; subject id and stage are long tensors.
        """
        data = np.load(os.path.join(self.root_dir, self.files[idx]))
        try:
            # Keep the tensors local: storing them on ``self`` made every call
            # overwrite state on the dataset object for no benefit.
            modality_1 = torch.tensor(data[self.modalities[0]], dtype=torch.float)
            modality_2 = torch.tensor(data[self.modalities[1]], dtype=torch.float)
            subject_id = torch.tensor(data[self.subject_idx], dtype=torch.long)
            stage = data[self.stage]
        finally:
            # Archive-style results of np.load hold an open file handle; plain
            # arrays have no ``close`` and are left alone.
            closer = getattr(data, "close", None)
            if closer is not None:
                closer()

        # Collapse the raw stage codes into the reduced label set.
        if stage in [1, 2]:
            stage = 1
        elif stage in [3, 4]:
            stage = 2
        elif stage == 5:
            stage = 3

        sleep_stage = torch.tensor(stage, dtype=torch.long)

        return [modality_1, modality_2, subject_id, sleep_stage]

In [3]:
def get_accuracy_from_train_process(logit_arr, true_label):
    """Return the fraction of rows whose arg-max over dim 1 equals the label.

    Args:
        logit_arr: Tensor of per-class scores; the arg-max is taken along dim 1.
        true_label: Tensor of reference labels; its first dimension is the sample count.

    Returns:
        A Python float: number of matches divided by ``true_label.size(0)``.
    """
    hits = (logit_arr.argmax(dim=1) == true_label).sum()
    sample_count = true_label.shape[0]
    return hits.item() / sample_count


def get_acc_loss_from_dataloader(model, downstream_model, dataloder, loss_fn, device):
    """Evaluate the backbone + downstream head over a dataloader.

    Args:
        model: Backbone called with four inputs plus ``proj_head=True, class_head=False``;
            it must return two feature tensors. It is switched to eval mode and left there.
        downstream_model: Head that maps the two feature tensors to class logits.
            It is evaluated in eval mode and restored to its previous mode afterwards.
        dataloder: Iterable yielding ``(ecg, hr, <ignored>, sleep_stage)`` batches.
        loss_fn: Callable ``loss_fn(prediction, sleep_stage)`` returning a scalar tensor.
        device: Accepted for interface compatibility; currently unused because no
            tensor is moved to a device here (To-Do carried over from the original).

    Returns:
        ``(accuracy, mean_loss)``: correct predictions over total samples, and the
        summed batch losses divided by the number of batches.

    Raises:
        ValueError: If the dataloader yields no samples.
    """
    model.eval()
    # The head must also be in eval mode for a fair measurement; remember its
    # previous mode so the caller's training loop is not silently affected.
    downstream_was_training = downstream_model.training
    downstream_model.eval()

    # The augmenters do not depend on the batch, so build them once.
    aug_1 = init_augmenter("NoAugmenter", None)
    aug_2 = init_augmenter("NoAugmenter", None)

    total_correct = 0
    total_samples = 0
    total_loss = 0
    num_batches = 0

    try:
        # No gradients are needed for evaluation; skipping autograd saves memory.
        with torch.no_grad():
            for ecg, hr, _, sleep_stage in dataloder:
                aug_1_modal_1 = aug_1(ecg)
                aug_2_modal_1 = aug_2(ecg)

                aug_1_modal_2 = aug_1(hr)
                aug_2_modal_2 = aug_2(hr)

                mod_feature1, mod_feature2 = model(
                    aug_1_modal_1, aug_1_modal_2, aug_2_modal_1, aug_2_modal_2,
                    proj_head=True, class_head=False,
                )
                prediction = downstream_model(mod_feature1, mod_feature2)

                loss = loss_fn(prediction, sleep_stage)

                total_loss += loss.item()
                total_correct += torch.sum(torch.argmax(prediction, dim=1) == sleep_stage).item()
                total_samples += sleep_stage.size(0)
                num_batches += 1
    finally:
        downstream_model.train(downstream_was_training)

    if num_batches == 0 or total_samples == 0:
        raise ValueError("dataloader yielded no samples; cannot compute accuracy/loss")

    return total_correct / total_samples, total_loss / num_batches

In [4]:
# Limit CUDA to the device with index 8; the variable is read when CUDA is initialised.
os.environ['CUDA_VISIBLE_DEVICES'] = '8'
# Use the first visible GPU when there is one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device("cpu")
print(device)

cpu


In [5]:
# Checkpoint and log directories come from the trainer section of the project config.
modelpath, log_path = (
    args.trainer_config[key] for key in ("model_save_dir", "log_save_dir")
)

# Load the saved foundation-model checkpoint, mapping its tensors onto the CPU.
model_name = 'SSL_focal_model_ep_0.pth'
model_ckpt = torch.load(os.path.join(modelpath, model_name), map_location='cpu')

In [6]:
# Overwrite the config sections on `args` with the ones stored in the checkpoint,
# so the model is rebuilt with the settings it was trained with.
for _section in ("trainer_config", "focal_config", "data_config"):
    setattr(args, _section, model_ckpt[_section])

In [7]:
# Build the DeepSense backbone, wrap it in FOCAL, and create the downstream classifier.
# The `.to(...)` device moves are commented out, so no model is moved to a device here.
backbone = DeepSense(args)#.to(args.focal_config["device"])
focal_model = FOCAL(args, backbone)#.to(args.focal_config["device"])
downstream_model = SleepStageClassifier(downargs)

** Finished Initializing DeepSense Backbone **


In [8]:
# Pull the FOCAL weights out of the loaded checkpoint (the next cell repeats this assignment).
focal_model_checkpoint = model_ckpt["focal_state_dict"]
# focal_model_checkpoint

In [9]:
focal_model_checkpoint = model_ckpt["focal_state_dict"]
# strict=False tolerates key mismatches, which would otherwise leave layers at
# their random initialisation without any sign. Report them explicitly.
_load_report = focal_model.load_state_dict(focal_model_checkpoint, strict=False)
if _load_report.missing_keys or _load_report.unexpected_keys:
    print(
        "[warning] FOCAL checkpoint did not match the model exactly: "
        f"missing_keys={list(_load_report.missing_keys)}, "
        f"unexpected_keys={list(_load_report.unexpected_keys)}"
    )

# Only the downstream classifier's parameters are handed to the optimizer.
downstream_loss_fn = nn.CrossEntropyLoss()
downstream_optimizer = torch.optim.Adam(downstream_model.parameters(), lr=0.001)

  from .autonotebook import tqdm as notebook_tqdm


In [10]:
# Display the data configuration currently held on `args` (restored from the checkpoint above).
args.data_config#['train_data_dir'],

{'train_data_dir': '/NFS/Users/moonsh/data/mesa/preproc/pair_train',
 'val_data_dir': '/NFS/Users/moonsh/data/mesa/preproc/pair_val',
 'test_data_dir': '/NFS/Users/moonsh/data/mesa/preproc/pair_test',
 'modalities': ['ecg', 'hr'],
 'label_key': 'stage',
 'subject_key': 'subject_idx',
 'augmentation': ['GaussianNoise', 'NoAugmenter'],
 'augmenter_config': {'GaussianNoise': {'max_noise_std': 0.1},
  'AmplitudeScale': {'amplitude_scale': 0.3}}}

In [11]:
# Keyword arguments shared by the train and validation datasets.
# To-Do: edit data_config
_dataset_kwargs = dict(
    modalities=downargs.data_config['modalities'],
    subject_idx=downargs.data_config['subject_key'],
    stage=downargs.data_config['label_key'],
)
_batch_size = downargs.trainer_config['batch_size']

# Training split: shuffled, full batch size.
train_dataset = MESAPairDataset(file_path=downargs.data_config['train_data_dir'], **_dataset_kwargs)
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=_batch_size,
    shuffle=True,
    num_workers=4,
)

# Validation split: fixed order, one third of the training batch size.
val_dataset = MESAPairDataset(file_path=downargs.data_config['val_data_dir'], **_dataset_kwargs)
val_loader = torch.utils.data.DataLoader(
    val_dataset,
    batch_size=_batch_size // 3,
    shuffle=False,
    num_workers=2,
)

In [12]:
# Display the checkpoint template dict; `downstream` fills in its fields before saving.
downargs.model_save_format

{'train_acc': None,
 'val_acc': None,
 'train_loss': None,
 'val_loss': None,
 'epoch': None,
 'lr': None,
 'model_path': None,
 'model_state_dict': None,
 'batch_size': None}

In [13]:
def downstream(model, downstream_model, train_loader, val_lodaer, optimizer, loss_fn, downargs, device):
    """Train ``downstream_model`` on features from a frozen ``model`` and checkpoint the best epoch.

    Args:
        model: Backbone called with four inputs plus ``proj_head=True, class_head=False``;
            must return two feature tensors. All of its parameters are frozen.
        downstream_model: Classifier head taking the two feature tensors; the only
            module whose parameters are updated.
        train_loader: Iterable of ``(ecg, hr, <ignored>, sleep_stage)`` training batches.
        val_lodaer: Validation loader, passed to ``get_acc_loss_from_dataloader``.
        optimizer: Optimizer stepping the downstream model's parameters.
        loss_fn: Callable ``loss_fn(prediction, sleep_stage)`` returning a scalar tensor.
        downargs: Config module providing ``model_save_format`` and
            ``downstream_config`` (keys ``model_save_dir``, ``epoch``, ``val_freq``).
        device: Accepted for interface compatibility; currently unused because no
            tensor or module is moved to a device here.

    Returns:
        ``(model_save_format, (plot_train_loss, plot_train_acc, plot_val_loss, plot_val_acc))``.
        ``model_save_format`` is ``downargs.model_save_format`` itself, updated in place
        with the fields of the best checkpoint.

    Raises:
        ValueError: If ``train_loader`` yields no batches.
    """
    # Freeze the backbone and unfreeze the head once; nothing re-enables gradients
    # later, so repeating this on every batch was redundant.
    for param in downstream_model.parameters():
        param.requires_grad = True
    for param in model.parameters():
        param.requires_grad = False

    best_acc = 0

    plot_train_loss = []
    plot_val_loss = []
    plot_val_acc = []
    plot_train_acc = []

    # Same dict object as on `downargs` (mutated in place, as before).
    model_save_format = downargs.model_save_format
    # Record the learning rate the optimizer actually uses, not a config value
    # that may disagree with how the optimizer was constructed.
    model_save_format["lr"] = optimizer.param_groups[0]["lr"]

    # One directory per run, so checkpoints from different runs never overwrite
    # each other. Previously this directory was created but never written to.
    start_time = datetime.datetime.now().strftime("%Y%m%d_%H%M")
    modelPATH = os.path.join(downargs.downstream_config["model_save_dir"], start_time)
    os.makedirs(modelPATH, exist_ok=True)

    # The augmenters do not depend on the batch, so build them once.
    aug_1 = init_augmenter("NoAugmenter", None)
    aug_2 = init_augmenter("NoAugmenter", None)

    for ep in range(downargs.downstream_config["epoch"]): # TO-DO: adding tqdm
        epoch_logits = []
        epoch_labels = []
        train_loss = 0
        num_batches = 0

        # NOTE(review): the frozen backbone is put in train mode, as in the original.
        # If it contains dropout or batch-norm layers, eval mode may be what is
        # wanted for a frozen feature extractor -- confirm.
        model.train()
        downstream_model.train()

        for ecg, hr, _, sleep_stage in train_loader:
            aug_1_modal_1 = aug_1(ecg)
            aug_2_modal_1 = aug_2(ecg)

            aug_1_modal_2 = aug_1(hr)
            aug_2_modal_2 = aug_2(hr)

            # Use the optimizer that was passed in rather than a notebook global.
            optimizer.zero_grad()
            mod_feature1, mod_feature2 = model(
                aug_1_modal_1, aug_1_modal_2, aug_2_modal_1, aug_2_modal_2,
                proj_head=True, class_head=False,
            )
            prediction = downstream_model(mod_feature1, mod_feature2)

            loss = loss_fn(prediction, sleep_stage)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            # Flatten to (samples, classes) / (samples,) explicitly. `.squeeze()`
            # dropped the batch dimension whenever a batch held a single sample.
            epoch_logits.append(prediction.detach().cpu().reshape(-1, prediction.shape[-1]))
            epoch_labels.append(sleep_stage.detach().cpu().reshape(-1))
            num_batches += 1

        if num_batches == 0:
            raise ValueError("train_loader yielded no batches; nothing to train on")

        model.eval()
        train_loss /= num_batches
        train_acc = get_accuracy_from_train_process(torch.cat(epoch_logits), torch.cat(epoch_labels))

        plot_train_loss.append(train_loss)
        plot_train_acc.append(train_acc)
        # Report the epoch-mean loss (the value that is also plotted), not the last batch's loss.
        print(f'Epoch: {ep}, Batch: {num_batches}, TrainLoss: {train_loss}, TrainAcc: {train_acc}')

        if ep % downargs.downstream_config['val_freq'] == 0:
            val_acc, val_loss = get_acc_loss_from_dataloader(model, downstream_model, val_lodaer, loss_fn, device)
            print(f'(Validation) Epoch: {ep},  ValLoss: {val_loss}, ValAcc: {val_acc}')
            plot_val_acc.append(val_acc)
            plot_val_loss.append(val_loss)

            if val_acc > best_acc:
                print("--------"*15)
                best_acc = val_acc
                MODELPATH = os.path.join(modelPATH, f'FM_based_classfier_{ep}.pth')
                model_save_format["epoch"] = ep
                model_save_format["state_dict"] = model.state_dict()
                # The head is the only module that is trained, so its weights must be
                # saved. They are cloned so the returned dict keeps the best-epoch
                # values instead of aliasing tensors that later epochs keep updating.
                model_save_format["downstream_state_dict"] = {
                    name: tensor.detach().cpu().clone()
                    for name, tensor in downstream_model.state_dict().items()
                }
                # NOTE(review): the template's "model_state_dict" key is never filled;
                # confirm whether it should replace "state_dict".
                model_save_format["model_path"] = MODELPATH
                model_save_format["train_acc"] = train_acc
                model_save_format["train_loss"] = train_loss
                model_save_format["val_acc"] = val_acc
                model_save_format["val_loss"] = val_loss

                torch.save(model_save_format, MODELPATH)
                print("Best Model Saved!")
                print("--------"*15)

    print("Finished Training")
    print(f'Best Validation Accuracy: {best_acc}')

    return model_save_format, (plot_train_loss, plot_train_acc, plot_val_loss, plot_val_acc)

In [14]:
# Keep the results: the curves are needed for plotting. Leaving the call as a bare
# expression would discard them and echo the whole checkpoint dict as cell output.
best_checkpoint, (plot_train_loss, plot_train_acc, plot_val_loss, plot_val_acc) = downstream(
    focal_model,
    downstream_model,
    train_loader,
    val_loader,
    downstream_optimizer,
    downstream_loss_fn,
    downargs,
    device,
)

Loading NoAugmenter augmenter...
Loading NoAugmenter augmenter...
Epoch: 0, Batch: 1, TrainLoss: 1.3724225759506226, TrainAcc: 0.2822677925211098
Loading NoAugmenter augmenter...
Loading NoAugmenter augmenter...
(Validation) Epoch: 0,  ValLoss: 1.9178822040557861, ValAcc: 0.5028248587570622
------------------------------------------------------------------------------------------------------------------------
Best Model Saved!
------------------------------------------------------------------------------------------------------------------------
Loading NoAugmenter augmenter...
Loading NoAugmenter augmenter...
Epoch: 1, Batch: 1, TrainLoss: 1.0241804122924805, TrainAcc: 0.5681544028950543
Loading NoAugmenter augmenter...
Loading NoAugmenter augmenter...
(Validation) Epoch: 1,  ValLoss: 1.649709939956665, ValAcc: 0.5988700564971752
------------------------------------------------------------------------------------------------------------------------
Best Model Saved!
------------------

KeyboardInterrupt: 