In [1]:
import os
import sys
sys.path.append('./')

import args
import argparse
import logging

import torch
from models.AdversarialModel import AdversarialModel
from models.FOCALModules import FOCAL
from models.loss import FOCALLoss
from data.EfficientDataset import MESAPairDataset
from data.Augmentaion import init_augmenter

import datetime
import importlib

  return torch._C._cuda_getDeviceCount() > 0


In [2]:
print(f"base_config: \n {args.base_config}")
# print(f"focal_config: \n {args.focal_config} \n")

base_config: 
 {'train_data_dir': '/data8/jungmin/uot_class/MIE1517_DL/FM_for_bio_signal/src/foundation/dataset/pair_small', 'valid_data_dir': '/data8/jungmin/uot_class/MIE1517_DL/FM_for_bio_signal/src/foundation/dataset/pair_small_valid', 'test_data_dir': '/NFS/Users/moonsh/data/mesa/preproc/pair_test', 'modalities': ['ecg', 'hr'], 'label_key': 'stage', 'subject_key': 'subject_idx', 'train_num_subjects': 100, 'test_num_subjects': 50, 'device': device(type='cpu'), 'log_save_dir': '/data8/jungmin/uot_class/MIE1517_DL/FM_for_bio_signal/logs'}


### DataLoader

In [3]:
train_dataset = MESAPairDataset(file_path=args.base_config['train_data_dir'],
                                    modalities=args.base_config['modalities'],
                                    subject_idx=args.base_config['subject_key'],
                                    stage=args.base_config['label_key'])
train_loader = torch.utils.data.DataLoader(train_dataset, 
                                            batch_size=14,
                                            shuffle=True,
                                            num_workers=1)

In [4]:
print(train_dataset.__len__())

for i , (raw_modal_1, raw_modal_2, subj, sleep) in enumerate(train_loader):
    print(i)
    print(raw_modal_1.shape)
    print(raw_modal_2.shape)
    print(subj.shape)
    print(sleep.shape)
    
    break

16
0
torch.Size([14, 7680])
torch.Size([14, 30])
torch.Size([14])
torch.Size([14])


### Augmentation

In [5]:
print(f"data_config: \n {args.data_config} \n")

data_config: 
 {'modalities': ['ecg', 'hr'], 'label_key': 'stage', 'augmentation': ['GaussianNoise', 'AmplitudeScale'], 'augmenter_config': {'GaussianNoise': {'max_noise_std': 0.1}, 'AmplitudeScale': {'amplitude_scale': 0.5}}, 'num_classes': None} 



In [6]:
aug_1_name = args.data_config['augmentation'][0]
aug_1_config = args.data_config['augmenter_config'].get(aug_1_name, {})
aug_2_name = args.data_config['augmentation'][1]
aug_2_config = args.data_config['augmenter_config'].get(aug_2_name, {})
    
augmenter_1 = init_augmenter(aug_1_name, aug_1_config)
augmenter_2 = init_augmenter(aug_2_name, aug_2_config)

Loading GaussianNoise augmenter...
Loading AmplitudeScale augmenter...


In [7]:
raw_modal_1.shape

torch.Size([14, 7680])

In [8]:
# It changes the shape of input: (B, seq) -> (B
augmenter_1(raw_modal_1).shape

torch.Size([14, 7680])

In [9]:
for i , (raw_modal_1, raw_modal_2, subj, sleep) in enumerate(train_loader):
    aug_1_modal_1 = augmenter_1(raw_modal_1)
    aug_1_modal_2 = augmenter_1(raw_modal_2)
    aug_2_modal_1 = augmenter_2(raw_modal_1)
    aug_2_modal_2 = augmenter_2(raw_modal_2)
    print(aug_1_modal_1.shape)
    print(aug_1_modal_2.shape)
    print(aug_2_modal_1.shape)
    print(aug_2_modal_2.shape)

torch.Size([14, 7680])
torch.Size([14, 30])
torch.Size([14, 7680])
torch.Size([14, 30])
torch.Size([2, 7680])
torch.Size([2, 30])
torch.Size([2, 7680])
torch.Size([2, 30])


### Backbone

In [10]:
import importlib
from models.Backbone import DeepSense
importlib.reload(args)

<module 'args' from '/data8/jungmin/uot_class/MIE1517_DL/FM_for_bio_signal/src/foundation/args.py'>

In [11]:
backbone_model = DeepSense(args)
# dims = [1, 16, 32]

ecg extractor is initialized.
hr extractor is initialized.
ecg recurrent layer is initialized.
hr recurrent layer is initialized.
** Finished Initializing DeepSense Backbone **


In [12]:
aug_1_modal_1.shape

torch.Size([2, 7680])

In [13]:
enc_mod_features_1 = backbone_model(aug_1_modal_1, aug_1_modal_2)
enc_mod_features_2 = backbone_model(aug_2_modal_1, aug_2_modal_2)

In [14]:
enc_mod_features_1

{'ecg': tensor([[ 4.2017e-02, -1.9050e-03, -1.0544e-01, -2.3508e-02,  2.7437e-02,
           1.0870e-01, -8.6688e-02, -1.3929e-02, -9.3054e-02,  7.1771e-02,
          -2.2432e-02, -7.7172e-02, -6.7358e-02, -5.2277e-02,  1.1432e-02,
           2.6332e-02,  7.7238e-02, -1.2176e-02,  1.0183e-01,  6.7707e-02,
           5.6365e-03,  1.1129e-01,  4.2131e-02, -4.5158e-02, -2.7685e-02,
          -6.4567e-02,  9.3014e-03, -4.6144e-02,  1.2319e-01,  6.4648e-02,
          -4.3848e-02, -1.3296e-01, -2.0376e-02,  5.3034e-02, -7.3046e-02,
          -1.0714e-01,  3.9749e-02, -6.5112e-02, -4.5738e-03,  8.8621e-02,
           6.3777e-03,  4.8814e-02,  5.5052e-02,  5.4076e-02,  4.7503e-02,
           1.4531e-02, -2.8623e-02, -1.2066e-01,  1.1547e-01,  5.1368e-02,
           5.3672e-02, -3.4242e-03,  5.5235e-02,  9.0216e-02,  1.0376e-01,
           4.2634e-02, -3.9785e-02, -9.0280e-02,  2.9262e-02,  1.4152e-01,
          -1.0171e-01,  1.0605e-01, -2.3399e-02, -6.4366e-02],
         [ 4.6072e-02, -1.1256

In [15]:
print(enc_mod_features_1['ecg'].shape)
print(enc_mod_features_1['hr'].shape)
print(enc_mod_features_2['ecg'].shape)
print(enc_mod_features_2['hr'].shape)

torch.Size([2, 64])
torch.Size([2, 64])
torch.Size([2, 64])
torch.Size([2, 64])


### Focal Model and Loss

In [16]:
from models.FOCALModules import FOCAL
from models.loss import FOCALLoss

In [17]:
backbone_model = DeepSense(args)
focal_model = FOCAL(args, backbone_model)

ecg extractor is initialized.
hr extractor is initialized.
ecg recurrent layer is initialized.
hr recurrent layer is initialized.
** Finished Initializing DeepSense Backbone **


In [18]:
enc_mod_features_1 = backbone_model(aug_1_modal_1, aug_1_modal_2)
enc_mod_features_2 = backbone_model(aug_2_modal_1, aug_2_modal_2)

In [19]:
aug_2_modal_2.shape

torch.Size([2, 30])

In [20]:
# proj_head = True
mod_features_1, mod_features_2 = focal_model(aug_1_modal_1, aug_1_modal_2, aug_2_modal_1, aug_2_modal_2, proj_head=True)

In [21]:
# print(mod_features_1)
# print(mod_features_1['ecg'].shape)
# print(mod_features_1['hr'].shape)

### Training

In [22]:
train_dataset = MESAPairDataset(file_path=args.base_config['train_data_dir'],
                                    modalities=args.base_config['modalities'],
                                    subject_idx=args.base_config['subject_key'],
                                    stage=args.base_config['label_key'])
train_loader = torch.utils.data.DataLoader(train_dataset, 
                                            batch_size=16,
                                            shuffle=True,
                                            num_workers=1)

valid_dataset = MESAPairDataset(file_path=args.base_config['valid_data_dir'],
                                    modalities=args.base_config['modalities'],
                                    subject_idx=args.base_config['subject_key'],
                                    stage=args.base_config['label_key'])
valid_loader = torch.utils.data.DataLoader(valid_dataset, 
                                            batch_size=16,
                                            shuffle=True,
                                            num_workers=1)

In [23]:
if str(list(args.focal_config["backbone"].keys())[0]) == "DeepSense":
    backbone = DeepSense(args).to(args.focal_config["device"])

model = FOCAL(args, backbone)
optimizer = torch.optim.Adam(model.parameters(), lr=args.focal_config["lr"])
focal_loss_fn = FOCALLoss(args)

ecg extractor is initialized.
hr extractor is initialized.
ecg recurrent layer is initialized.
hr recurrent layer is initialized.
** Finished Initializing DeepSense Backbone **


  from .autonotebook import tqdm as notebook_tqdm


In [24]:
advs_model = AdversarialModel(args).to(args.subj_invariant_config["device"])
advs_optimizer = torch.optim.Adam(advs_model.parameters(), lr=args.subj_invariant_config['lr'])

In [25]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [26]:
trainer_config = args.trainer_config

aug_1_name = args.data_config['augmentation'][0]
aug_1_config = args.data_config['augmenter_config'].get(aug_1_name, {})
aug_2_name = args.data_config['augmentation'][1]
aug_2_config = args.data_config['augmenter_config'].get(aug_2_name, {})

aug_1 = init_augmenter(aug_1_name, aug_1_config)
aug_2 = init_augmenter(aug_2_name, aug_2_config)

model.train()
best_val_loss = float('inf')

for ep in range(trainer_config['epochs']):
    running_advs_train_loss = 0
    focal_train_loss = 0
    
    for raw_modal_1, raw_modal_2, subj_label, sleep_label in train_loader:
        raw_modal_1, raw_modal_2, subj_label, sleep_label = raw_modal_1.to(device), raw_modal_2.to(device), subj_label.to(device), sleep_label.to(device) # [B, 30], [B, 30*256], [B, 1]
        
        aug_1_modal_1 = aug_1(raw_modal_1)
        aug_2_modal_1 = aug_2(raw_modal_1)
        
        aug_1_modal_2 = aug_1(raw_modal_2)
        aug_2_modal_2 = aug_2(raw_modal_2)
        
        # For updating the only advs_model (classifier)
        for param in model.parameters():
            param.requires_grad = False
        for param in advs_model.parameters():
            param.requires_grad = True
            
        advs_optimizer.zero_grad()
        
        # Using Encoder for classify the subject
        enc_feature_1, enc_feature_2 = model(aug_1_modal_1, aug_1_modal_2, aug_2_modal_1, aug_2_modal_2, proj_head=True)
        # enc_feature1 -> dict // (example) enc_feature1['ecg'] & enc_feature1['hr'] from Augmentation 1
        # enc_feature2 -> dict // (example) enc_feature2['ecg'] & enc_feature2['hr'] from Augmentation 2
        
        
        # Predict the subject
        subj_pred = advs_model(enc_feature_1, enc_feature_2)     
        advs_loss = advs_model.forward_adversarial_loss(subj_pred, subj_label)
        
        # To-do for calculating the accuracy
        # num_adversary_correct_train_preds += adversarial_loss_fn.get_number_of_correct_preds(x_t1_initial_subject_preds, y)
        # total_num_adversary_train_preds += len(x_t1_initial_subject_preds)
        
        advs_loss.backward()
        advs_optimizer.step()
        
        running_advs_train_loss += advs_loss.item()
        
        # For efficient memory management
        del enc_feature_1, enc_feature_2, subj_pred, advs_loss
        
        # For updating the only Focal model (SSL model)
        for param in model.parameters():
            param.requires_grad = True
        for param in advs_model.parameters():
            param.requires_grad = False
        
        optimizer.zero_grad()

        enc_feature_1, enc_feature_2 = model(aug_1_modal_1, aug_1_modal_2, aug_2_modal_1, aug_2_modal_2, proj_head=True)
        
        subj_pred = advs_model(enc_feature_1, enc_feature_2) 
        subj_invariant_loss = advs_model.forward_subject_invariance_loss(subj_pred, subj_label, args.subj_invariant_config['adversarial_weighting_factor']) # DONE -> add subject_invariant function loss
        
        focal_loss = focal_loss_fn(enc_feature_1, enc_feature_2, subj_invariant_loss) # To-Do -> add regularization term about subject invariant
        focal_loss.backward()
        optimizer.step()
        
        focal_train_loss += focal_loss.item()
        
        # For efficient memory management
        del enc_feature_1, enc_feature_2, subj_pred, focal_loss
        torch.cuda.empty_cache()
        
    if ep % trainer_config['log_interval'] == 0:
        print(f"Epoch {ep} - Adversarial Loss: {running_advs_train_loss/ len(train_loader)}, \
            Focal Loss: {focal_train_loss/ len(train_loader)}")
        
        if ep % trainer_config['val_interval'] == 0:
            model.eval()
            advs_model.eval()
            
            advs_val_loss = 0
            focal_val_loss = 0
            
            for raw_modal_1, raw_modal_2, subj_label, sleep_label in valid_loader:
                raw_modal_1, raw_modal_2, subj_label, sleep_label = raw_modal_1.to(device), raw_modal_2.to(device), subj_label.to(device), sleep_label.to(device)
                
                aug_1_modal_1, aug_2_modal_1  = aug_1(raw_modal_1), aug_2(raw_modal_1)
                aug_1_modal_2, aug_2_modal_2  = aug_1(raw_modal_2), aug_2(raw_modal_2)
                
                with torch.no_grad():
                    # x1_represent, x2_represent = model(raw_modal_1, raw_modal_2)
                    enc_feature_1, enc_feature_2 = model(aug_1_modal_1, aug_1_modal_2, aug_2_modal_1, aug_2_modal_2, proj_head=True)
                    subj_pred = advs_model(enc_feature_1, enc_feature_2)
                    
                    advs_loss = advs_model.forward_adversarial_loss(subj_pred, subj_label)
                    focal_loss = focal_loss_fn(enc_feature_1, enc_feature_2, subj_invariant_loss) # To-Do -> add regularization term about subject invariant
                    
                    advs_val_loss += advs_loss.item()
                    focal_val_loss += focal_loss.item()
                    
                    # For efficient memory management
                    del enc_feature_1, enc_feature_2, subj_pred, focal_loss, advs_loss
                    torch.cuda.empty_cache()
                    
            print("-----"*10)
            print(f"(Validation) Epoch{ep} - Adversarial Loss: {advs_val_loss/ len(valid_loader)}, \
                Focal Loss: {focal_val_loss/ len(valid_loader)}")                    
                            
            if focal_val_loss < best_val_loss:
                best_val_loss = focal_val_loss
                
                # To-do -> fix the save model format
                # torch.save(model.state_dict(), os.path.join(args.save_dir, 'focal_model.pth'))
                # torch.save(advs_model.state_dict(), os.path.join(args.save_dir, 'advs_model.pth'))
                print("************* Model Saved *************")
            print("-----"*10)

Loading GaussianNoise augmenter...
Loading AmplitudeScale augmenter...
Epoch 0 - Adversarial Loss: 36.68265151977539,             Focal Loss: 0.15066389739513397
--------------------------------------------------
(Validation) Epoch0 - Adversarial Loss: 36.60572814941406,                 Focal Loss: 0.03700614720582962
************* Model Saved *************
--------------------------------------------------
