In [1]:
import os
import sys
sys.path.append('./')

import args
import argparse
import logging

import torch
from models.AdversarialModel import AdversarialModel
from models.FOCALModules import FOCAL
from models.loss import FOCALLoss
from data.EfficientDataset import MESAPairDataset
from data.Augmentaion import init_augmenter

import datetime
import importlib

  return torch._C._cuda_getDeviceCount() > 0


In [2]:
print(f"base_config: \n {args.base_config}")
# print(f"focal_config: \n {args.focal_config} \n")

base_config: 
 {'train_data_dir': '/data8/jungmin/uot_class/MIE1517_DL/FM_for_bio_signal/src/foundation/dataset/pair_small', 'val_data_dir': '/data8/jungmin/uot_class/MIE1517_DL/FM_for_bio_signal/src/foundation/dataset/pair', 'test_data_dir': '/NFS/Users/moonsh/data/mesa/preproc/pair_test', 'modalities': ['ecg', 'hr'], 'label_key': 'stage', 'subject_key': 'subject_idx', 'train_num_subjects': 100, 'test_num_subjects': 50, 'device': device(type='cpu'), 'log_save_dir': '/data8/jungmin/uot_class/MIE1517_DL/FM_for_bio_signal/logs'}


### DataLoader

In [3]:
train_dataset = MESAPairDataset(file_path=args.base_config['train_data_dir'],
                                    modalities=args.base_config['modalities'],
                                    subject_idx=args.base_config['subject_key'],
                                    stage=args.base_config['label_key'])
train_loader = torch.utils.data.DataLoader(train_dataset, 
                                            batch_size=4,
                                            shuffle=True,
                                            num_workers=1)

In [4]:
print(train_dataset.__len__())

for i , (raw_modal_1, raw_modal_2, subj, sleep) in enumerate(train_loader):
    print(i)
    print(raw_modal_1.shape)
    print(raw_modal_2.shape)
    print(subj.shape)
    print(sleep.shape)
    
    break

14
0
torch.Size([4, 7680])
torch.Size([4, 30])
torch.Size([4])
torch.Size([4])


### Augmentation

In [5]:
print(f"data_config: \n {args.data_config} \n")

data_config: 
 {'modalities': ['ecg', 'hr'], 'label_key': 'stage', 'augmentation': ['GaussianNoise', 'AmplitudeScale'], 'augmenter_config': {'GaussianNoise': {'max_noise_std': 0.1}, 'AmplitudeScale': {'amplitude_scale': 0.5}}, 'num_classes': None} 



In [6]:
aug_1_name = args.data_config['augmentation'][0]
aug_1_config = args.data_config['augmenter_config'].get(aug_1_name, {})
aug_2_name = args.data_config['augmentation'][1]
aug_2_config = args.data_config['augmenter_config'].get(aug_2_name, {})
    
augmenter_1 = init_augmenter(aug_1_name, aug_1_config)
augmenter_2 = init_augmenter(aug_2_name, aug_2_config)

Loading GaussianNoise augmenter...
Loading AmplitudeScale augmenter...


In [7]:
raw_modal_1.shape

torch.Size([4, 7680])

In [8]:
# It changes the shape of input: (B, seq) -> (B
augmenter_1(raw_modal_1).shape

torch.Size([4, 7680])

In [9]:
for i , (raw_modal_1, raw_modal_2, subj, sleep) in enumerate(train_loader):
    aug_1_modal_1 = augmenter_1(raw_modal_1)
    aug_1_modal_2 = augmenter_1(raw_modal_2)
    aug_2_modal_1 = augmenter_2(raw_modal_1)
    aug_2_modal_2 = augmenter_2(raw_modal_2)
    print(aug_1_modal_1.shape)
    print(aug_1_modal_2.shape)
    print(aug_2_modal_1.shape)
    print(aug_2_modal_2.shape)

torch.Size([4, 7680])
torch.Size([4, 30])
torch.Size([4, 7680])
torch.Size([4, 30])
torch.Size([4, 7680])
torch.Size([4, 30])
torch.Size([4, 7680])
torch.Size([4, 30])
torch.Size([4, 7680])
torch.Size([4, 30])
torch.Size([4, 7680])
torch.Size([4, 30])
torch.Size([2, 7680])
torch.Size([2, 30])
torch.Size([2, 7680])
torch.Size([2, 30])


### Backbone

In [10]:
import importlib
from models.Backbone import DeepSense
importlib.reload(args)

<module 'args' from '/data8/jungmin/uot_class/MIE1517_DL/FM_for_bio_signal/src/foundation/args.py'>

In [11]:
backbone_model = DeepSense(args)
# dims = [1, 16, 32]

[1, 16, 32]
ecg extractor is initialized.
hr extractor is initialized.
ecg recurrent layer is initialized.
hr recurrent layer is initialized.
** Finished Initializing DeepSense Backbone **


In [12]:
aug_1_modal_1.shape

torch.Size([2, 7680])

In [13]:
enc_mod_features_1 = backbone_model(aug_1_modal_1, aug_1_modal_2)
enc_mod_features_2 = backbone_model(aug_2_modal_1, aug_2_modal_2)

mod1 cnn feature shape: torch.Size([2, 64, 280]) mod2 cnn feature shape: torch.Size([2, 64, 30])
mod1 rnn feature shape: torch.Size([2, 17920]) mod2 rnn feature shape: torch.Size([2, 1920])
mod1 cnn feature shape: torch.Size([2, 64, 280]) mod2 cnn feature shape: torch.Size([2, 64, 30])
mod1 rnn feature shape: torch.Size([2, 17920]) mod2 rnn feature shape: torch.Size([2, 1920])


In [14]:
enc_mod_features_1

{'ecg': tensor([[-0.0469,  0.0594, -0.0035, -0.1884,  0.0118,  0.1436, -0.0171,  0.0004,
          -0.0408,  0.1022,  0.0754, -0.0747, -0.0902,  0.0702, -0.1361, -0.0884,
          -0.0608,  0.0980,  0.0162, -0.0623,  0.0227,  0.0675,  0.2004, -0.0830,
          -0.0643,  0.0542,  0.0541, -0.0054,  0.0563, -0.0466, -0.1225, -0.0362,
          -0.0038, -0.0609, -0.0492, -0.0611, -0.0693, -0.0300,  0.0859, -0.0529,
           0.0483, -0.1190, -0.0650, -0.0939,  0.0668,  0.1367, -0.1430,  0.0392,
          -0.0744,  0.0925,  0.0746,  0.0222, -0.0917, -0.0927, -0.0793, -0.0003,
           0.0704, -0.0661,  0.0338,  0.0701,  0.0193, -0.0116, -0.1011,  0.0049],
         [-0.0101,  0.0699,  0.0118, -0.1676,  0.0237,  0.1079, -0.0151,  0.0360,
          -0.0587,  0.0968,  0.0721, -0.0561, -0.0806,  0.0266, -0.1426, -0.0640,
          -0.0928,  0.1159, -0.0355, -0.0153,  0.0098,  0.0931,  0.1730, -0.0837,
          -0.0983,  0.0318,  0.0434, -0.0042,  0.0514, -0.0955, -0.1056, -0.0503,
        

In [15]:
print(enc_mod_features_1['ecg'].shape)
print(enc_mod_features_1['hr'].shape)
print(enc_mod_features_2['ecg'].shape)
print(enc_mod_features_2['hr'].shape)

torch.Size([2, 64])
torch.Size([2, 64])
torch.Size([2, 64])
torch.Size([2, 64])


### Focal Model and Loss

In [16]:
from models.FOCALModules import FOCAL
from models.loss import FOCALLoss

In [17]:
backbone_model = DeepSense(args)
focal_model = FOCAL(args, backbone_model)

[1, 16, 32]
ecg extractor is initialized.
hr extractor is initialized.
ecg recurrent layer is initialized.
hr recurrent layer is initialized.
** Finished Initializing DeepSense Backbone **


In [18]:
enc_mod_features_1 = backbone_model(aug_1_modal_1, aug_1_modal_2)
enc_mod_features_2 = backbone_model(aug_2_modal_1, aug_2_modal_2)

mod1 cnn feature shape: torch.Size([2, 64, 280]) mod2 cnn feature shape: torch.Size([2, 64, 30])
mod1 rnn feature shape: torch.Size([2, 17920]) mod2 rnn feature shape: torch.Size([2, 1920])
mod1 cnn feature shape: torch.Size([2, 64, 280]) mod2 cnn feature shape: torch.Size([2, 64, 30])
mod1 rnn feature shape: torch.Size([2, 17920]) mod2 rnn feature shape: torch.Size([2, 1920])


In [19]:
aug_2_modal_2.shape

torch.Size([2, 30])

In [20]:
# proj_head = True
mod_features_1, mod_features_2 = focal_model(aug_1_modal_1, aug_1_modal_2, aug_2_modal_1, aug_2_modal_2, proj_head=True)

mod1 cnn feature shape: torch.Size([2, 64, 280]) mod2 cnn feature shape: torch.Size([2, 64, 30])
mod1 rnn feature shape: torch.Size([2, 17920]) mod2 rnn feature shape: torch.Size([2, 1920])
mod1 cnn feature shape: torch.Size([2, 64, 280]) mod2 cnn feature shape: torch.Size([2, 64, 30])
mod1 rnn feature shape: torch.Size([2, 17920]) mod2 rnn feature shape: torch.Size([2, 1920])


In [21]:
print(mod_features_1)
print(mod_features_1['ecg'].shape)
print(mod_features_1['hr'].shape)

{'ecg': tensor([[ 4.4685e-02, -4.9168e-02,  4.5327e-02,  4.2640e-02,  6.6214e-03,
         -1.1088e-01,  7.6225e-03,  5.8352e-02,  5.9951e-02,  9.9122e-02,
         -9.1643e-02,  1.1393e-01,  2.2038e-02,  2.1702e-02,  1.2288e-01,
          2.4977e-03, -3.9403e-02, -3.9489e-03, -6.5871e-02,  3.4900e-02,
         -5.4816e-02, -6.7001e-02,  3.2699e-02, -8.2026e-02, -1.0802e-01,
          4.6436e-03, -5.1613e-02, -9.6486e-02, -9.5172e-02, -3.7564e-02,
          1.7867e-02, -1.1037e-01, -1.7181e-02, -5.5518e-02, -2.6604e-02,
          2.8248e-02, -9.5827e-02, -7.7724e-02, -9.0723e-02,  1.5823e-02,
         -1.3332e-01, -1.5923e-01,  8.7660e-02, -2.0874e-02,  8.7037e-02,
          4.7225e-02, -5.0642e-02,  3.4934e-02,  5.0135e-02, -1.9298e-01,
         -9.1765e-02, -1.1545e-01, -1.7079e-02,  1.1389e-01,  3.7390e-02,
          1.2760e-01, -4.8292e-02,  6.0565e-02,  4.1838e-03,  3.4358e-02,
          8.7230e-02, -1.0784e-01, -4.1390e-02,  1.2189e-01],
        [ 5.1116e-02, -5.3966e-02,  2.2625

In [22]:
print(mod_features_2)
print(mod_features_2['ecg'].shape)
print(mod_features_2['hr'].shape)

{'ecg': tensor([[ 0.0371, -0.0468,  0.0518,  0.0490,  0.0114, -0.1230,  0.0031,  0.0635,
          0.0655,  0.1005, -0.0900,  0.1190,  0.0243,  0.0245,  0.1211,  0.0090,
         -0.0433, -0.0005, -0.0747,  0.0430, -0.0481, -0.0592,  0.0367, -0.0914,
         -0.1147,  0.0067, -0.0567, -0.0978, -0.0905, -0.0431,  0.0310, -0.1173,
         -0.0158, -0.0544, -0.0356,  0.0230, -0.0930, -0.0780, -0.0846,  0.0162,
         -0.1362, -0.1443,  0.0912, -0.0176,  0.0944,  0.0440, -0.0380,  0.0381,
          0.0506, -0.1917, -0.0929, -0.1129, -0.0240,  0.1162,  0.0391,  0.1196,
         -0.0566,  0.0551, -0.0010,  0.0361,  0.0852, -0.1069, -0.0434,  0.1198],
        [ 0.0484, -0.0547,  0.0299,  0.0495,  0.0006, -0.1061, -0.0130,  0.0322,
          0.0750,  0.1430, -0.0762,  0.0740,  0.0318,  0.0643,  0.1068,  0.0134,
         -0.0256, -0.0038, -0.0278,  0.0136, -0.0891, -0.0785,  0.0376, -0.0965,
         -0.0642,  0.0319, -0.0429, -0.1224, -0.1066, -0.0585,  0.0136, -0.1050,
         -0.0438, -

In [23]:
focal_loss_fn = FOCALLoss(args)
focal_loss_fn

FOCALLoss(
  (criterion): CrossEntropyLoss()
  (similarity_f): CosineSimilarity()
  (orthonal_loss_f): CosineEmbeddingLoss()
)

### Training

In [24]:
train_dataset = MESAPairDataset(file_path=args.base_config['train_data_dir'],
                                    modalities=args.base_config['modalities'],
                                    subject_idx=args.base_config['subject_key'],
                                    stage=args.base_config['label_key'])
train_loader = torch.utils.data.DataLoader(train_dataset, 
                                            batch_size=4,
                                            shuffle=True,
                                            num_workers=1)

In [25]:
if str(list(args.focal_config["backbone"].keys())[0]) == "DeepSense":
    backbone = DeepSense(args).to(args.focal_config["device"])

model = FOCAL(args, backbone)
optimizer = torch.optim.Adam(model.parameters(), lr=args.focal_config["lr"])

[1, 16, 32]
ecg extractor is initialized.
hr extractor is initialized.
ecg recurrent layer is initialized.
hr recurrent layer is initialized.
** Finished Initializing DeepSense Backbone **


  from .autonotebook import tqdm as notebook_tqdm


In [26]:
advs_model = AdversarialModel(args).to(args.subj_invariant_config["device"])
advs_optimizer = torch.optim.Adam(advs_model.parameters(), lr=args.subj_invariant_config['lr'])

In [27]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [28]:
trainer_config = args.trainer_config

aug_1_name = args.data_config['augmentation'][0]
aug_1_config = args.data_config['augmenter_config'].get(aug_1_name, {})
aug_2_name = args.data_config['augmentation'][1]
aug_2_config = args.data_config['augmenter_config'].get(aug_2_name, {})

aug_1 = init_augmenter(aug_1_name, aug_1_config)
aug_2 = init_augmenter(aug_2_name, aug_2_config)

model.train()
best_val_loss = float('inf')

for ep in range(trainer_config['epochs']):
    running_advs_train_loss = 0
    focal_train_loss = 0
    
    for raw_modal_1, raw_modal_2, subj_label, sleep_label in train_loader:
        raw_modal_1, raw_modal_2, subj_label, sleep_label = raw_modal_1.to(device), raw_modal_2.to(device), subj_label.to(device), sleep_label.to(device) # [B, 30], [B, 30*256], [B, 1]
        
        aug_1_modal_1 = aug_1(raw_modal_1)
        aug_2_modal_1 = aug_2(raw_modal_1)
        
        aug_1_modal_2 = aug_1(raw_modal_2)
        aug_2_modal_2 = aug_2(raw_modal_2)
        
        # For updating the only advs_model (classifier)
        for param in model.parameters():
            param.requires_grad = False
        for param in advs_model.parameters():
            param.requires_grad = True
            
        advs_optimizer.zero_grad()
        
        # Using Encoder for classify the subject
        enc_feature_1, enc_feature_2 = model(aug_1_modal_1, aug_1_modal_2, aug_2_modal_1, aug_2_modal_2, proj_head=True)
        # enc_feature1 -> dict // (example) enc_feature1['ecg'] & enc_feature1['hr'] from Augmentation 1
        # enc_feature2 -> dict // (example) enc_feature2['ecg'] & enc_feature2['hr'] from Augmentation 2
        
        
        # Predict the subject
        subj_pred = advs_model(enc_feature_1, enc_feature_2) 
        # or subj_pred = advs_model(enc_feature_1['ecg'], enc_feature_1['hr], enc_feature_2['ecg'], enc_feature_2['hr'])
        print(subj_pred.shape)
        print(subj_label.shape)
        
        advs_loss = advs_model.forward_adversarial_loss(subj_pred, subj_label)
        
        # To-do for calculating the accuracy
        # num_adversary_correct_train_preds += adversarial_loss_fn.get_number_of_correct_preds(x_t1_initial_subject_preds, y)
        # total_num_adversary_train_preds += len(x_t1_initial_subject_preds)
        
        advs_loss.backward()
        advs_optimizer.step()
        
        running_advs_train_loss += advs_loss.item()
        
        # For efficient memory management
        del enc_feature_1, enc_feature_2, subj_pred, advs_loss
        
        # For updating the only Focal model (SSL model)
        for param in model.parameters():
            param.requires_grad = True
        for param in advs_model.parameters():
            param.requires_grad = False
        
        optimizer.zero_grad()

        x1_represent, x2_represent = model(raw_modal_1, raw_modal_2)
        
        x1_embd, x2_embd = model.encoder(raw_modal_1, raw_modal_2)
        subj_pred = advs_model(x1_embd, x2_embd)
        subj_invariant_loss = advs_model.forward_subject_invariance_loss(subj_pred, subj_label, args.subj_invariant_config['adversarial_weighting_factor']) # DONE -> add subject_invariant function loss
        
        focal_loss = focal_loss_fn(x1_represent, x2_represent, subj_invariant_loss) # To-Do -> add regularization term about subject invariant
        focal_loss.backward()
        optimizer.step()
        
        focal_train_loss += focal_loss.item()
        
        # For efficient memory management
        del x1_represent, x2_represent, x1_embd, x2_embd, subj_pred, focal_loss
        torch.cuda.empty_cache()
        
    if ep % args.log_interval == 0:
        print(f"Epoch {ep} - Adversarial Loss: {running_advs_train_loss/ len(train_loader)}, \
            Focal Loss: {focal_train_loss/ len(train_loader)}")
        
        if ep % args.val_interval == 0:
            model.eval()
            advs_model.eval()
            
            advs_val_loss = 0
            focal_val_loss = 0
            
            for raw_modal_1, raw_modal_2, subj_label, sleep_label in val_loader:
                raw_modal_1, raw_modal_2, subj_label, sleep_label = raw_modal_1.to(device), raw_modal_2.to(device), subj_label.to(device), sleep_label.to(device)
                
                with torch.no_grad():
                    x1_represent, x2_represent = model(raw_modal_1, raw_modal_2)
                    x1_embd, x2_embd = model.encoder(raw_modal_1), model.encoder(raw_modal_2)
                    subj_pred = advs_model(x1_embd, x2_embd) # output -> sigmoid value
                    
                    advs_loss = advs_model.loss_fcn(subj_pred, subj_label)
                    focal_loss = focal_loss_fn(x1_represent, x2_represent, subj_pred, subj_label)
                    
                    advs_val_loss += advs_loss.item()
                    focal_val_loss += focal_loss.item()
                    
                    # For efficient memory management
                    del x1_represent, x2_represent, x1_embd, x2_embd, subj_pred, focal_loss
                    torch.cuda.empty_cache()
                    
            print("-----"*10)
            print(f"(Validation) Epoch{ep} - Adversarial Loss: {advs_val_loss/ len(val_loader)}, \
                Focal Loss: {focal_val_loss/ len(val_loader)}")                    
                            
            if focal_val_loss < best_val_loss:
                best_val_loss = focal_val_loss
                
                # To-do -> fix the save model format
                # torch.save(model.state_dict(), os.path.join(args.save_dir, 'focal_model.pth'))
                # torch.save(advs_model.state_dict(), os.path.join(args.save_dir, 'advs_model.pth'))
                print("************* Model Saved *************")
            print("-----"*10)

Loading GaussianNoise augmenter...
Loading AmplitudeScale augmenter...
mod1 cnn feature shape: torch.Size([4, 64, 280]) mod2 cnn feature shape: torch.Size([4, 64, 30])
mod1 rnn feature shape: torch.Size([4, 17920]) mod2 rnn feature shape: torch.Size([4, 1920])
mod1 cnn feature shape: torch.Size([4, 64, 280]) mod2 cnn feature shape: torch.Size([4, 64, 30])
mod1 rnn feature shape: torch.Size([4, 17920]) mod2 rnn feature shape: torch.Size([4, 1920])
torch.Size([4, 4])
torch.Size([4])
torch.Size([4, 4])


IndexError: too many indices for tensor of dimension 1