In [1]:
import os
import sys
sys.path.append('./')

import args
import argparse
import logging

import torch
from models.AdversarialModel import AdversarialModel
from models.FOCALModules import FOCAL
from models.loss import FOCALLoss
from data.EfficientDataset import MESAPairDataset
from data.Augmentaion import init_augmenter

import datetime
import importlib

  return torch._C._cuda_getDeviceCount() > 0


In [2]:
# Show the base configuration dict defined in the local `args` module
# (data directories, modalities, label/subject keys, device, log dir).
print(f"base_config: \n {args.base_config}")
# print(f"focal_config: \n {args.focal_config} \n")

base_config: 
 {'train_data_dir': '/data8/jungmin/uot_class/MIE1517_DL/FM_for_bio_signal/src/foundation/dataset/pair_small', 'val_data_dir': '/data8/jungmin/uot_class/MIE1517_DL/FM_for_bio_signal/src/foundation/dataset/pair', 'test_data_dir': '/NFS/Users/moonsh/data/mesa/preproc/pair_test', 'modalities': ['ecg', 'hr'], 'label_key': 'stage', 'subject_key': 'subject_idx', 'train_num_subjects': 100, 'test_num_subjects': 50, 'device': device(type='cpu'), 'log_save_dir': '/data8/jungmin/uot_class/MIE1517_DL/FM_for_bio_signal/logs'}


### DataLoader

In [3]:
# Build the training dataset from the base config and wrap it in a small,
# shuffled DataLoader (batch of 4, one worker) for the sanity checks below.
base_cfg = args.base_config
train_dataset = MESAPairDataset(
    file_path=base_cfg['train_data_dir'],
    modalities=base_cfg['modalities'],
    subject_idx=base_cfg['subject_key'],
    stage=base_cfg['label_key'],
)
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=4,
    shuffle=True,
    num_workers=1,
)

In [4]:
# Number of samples in the training dataset.
print(len(train_dataset))

# Peek at the first batch only: print its index and the shape of each of the
# four tensors the loader yields (modality 1, modality 2, subject, sleep stage).
for batch_idx, batch in enumerate(train_loader):
    raw_modal_1, raw_modal_2, subj, sleep = batch
    print(batch_idx)
    for tensor in (raw_modal_1, raw_modal_2, subj, sleep):
        print(tensor.shape)
    break

14
0
torch.Size([4, 7680])
torch.Size([4, 30])
torch.Size([4])
torch.Size([4])


### Augmentation

In [5]:
# Show the data/augmentation configuration dict defined in the `args` module.
print(f"data_config: \n {args.data_config} \n")

data_config: 
 {'modalities': ['ecg', 'hr'], 'label_key': 'stage', 'augmentation': ['GaussianNoise', 'AmplitudeScale'], 'augmenter_config': {'GaussianNoise': {'max_noise_std': 0.1}, 'AmplitudeScale': {'amplitude_scale': 0.5}}, 'num_classes': None} 



In [6]:
# Resolve the two configured augmentations and their (optional) settings,
# then instantiate them in order. A missing settings entry falls back to {}.
augmentation_names = args.data_config['augmentation']
augmenter_settings = args.data_config['augmenter_config']

aug_1_name, aug_2_name = augmentation_names[0], augmentation_names[1]
aug_1_config = augmenter_settings.get(aug_1_name, {})
aug_2_config = augmenter_settings.get(aug_2_name, {})

augmenter_1 = init_augmenter(aug_1_name, aug_1_config)
augmenter_2 = init_augmenter(aug_2_name, aug_2_config)

Loading GaussianNoise augmenter...
Loading AmplitudeScale augmenter...


In [7]:
# Shape of the first-modality batch still bound from the DataLoader check above.
raw_modal_1.shape

torch.Size([4, 7680])

In [8]:
# The augmenter keeps the input shape: (B, seq) -> (B, seq).
augmenter_1(raw_modal_1).shape

torch.Size([4, 7680])

In [9]:
# Run both augmenters over every batch and print the resulting shapes.
# After the loop, the aug_* names hold the views of the LAST batch; later
# cells reuse them as model inputs.
for raw_modal_1, raw_modal_2, subj, sleep in train_loader:
    aug_1_modal_1, aug_1_modal_2 = augmenter_1(raw_modal_1), augmenter_1(raw_modal_2)
    aug_2_modal_1, aug_2_modal_2 = augmenter_2(raw_modal_1), augmenter_2(raw_modal_2)

    for view in (aug_1_modal_1, aug_1_modal_2, aug_2_modal_1, aug_2_modal_2):
        print(view.shape)

torch.Size([4, 7680])
torch.Size([4, 30])
torch.Size([4, 7680])
torch.Size([4, 30])
torch.Size([4, 7680])
torch.Size([4, 30])
torch.Size([4, 7680])
torch.Size([4, 30])
torch.Size([4, 7680])
torch.Size([4, 30])
torch.Size([4, 7680])
torch.Size([4, 30])
torch.Size([2, 7680])
torch.Size([2, 30])
torch.Size([2, 7680])
torch.Size([2, 30])


### Backbone

In [29]:
# Re-import `args` so edits to args.py take effect without restarting the kernel.
# NOTE(review): this cell's execution count (29) is out of order with its
# neighbours, so the notebook was not run top to bottom; these imports would
# normally live in the first (imports) cell.
import importlib
from models.Backbone import DeepSense
importlib.reload(args)

<module 'args' from '/data8/jungmin/uot_class/MIE1517_DL/FM_for_bio_signal/src/foundation/args.py'>

In [11]:
# Build the DeepSense backbone from the `args` module. Construction prints the
# list [1, 16, 32] followed by per-modality initialisation messages.
backbone_model = DeepSense(args)
# dims = [1, 16, 32]  (presumably the conv channel dims -- TODO confirm in models/Backbone.py)

[1, 16, 32]
ecg extractor is initialized.
hr extractor is initialized.
ecg recurrent layer is initialized.
hr recurrent layer is initialized.
** Finished Initializing DeepSense Backbone **


In [12]:
# Batch dimension is 2 here because the augmentation loop above left the
# final (partial) batch bound to the aug_* names.
aug_1_modal_1.shape

torch.Size([2, 7680])

In [13]:
# Encode each augmented view with the backbone. Each call takes the two
# modality tensors of one view and returns a dict keyed by modality name
# ('ecg' / 'hr'), as inspected in the following cells.
enc_mod_features_1 = backbone_model(aug_1_modal_1, aug_1_modal_2)
enc_mod_features_2 = backbone_model(aug_2_modal_1, aug_2_modal_2)

mod1 cnn feature shape: torch.Size([2, 64, 280]) mod2 cnn feature shape: torch.Size([2, 64, 30])
mod1 rnn feature shape: torch.Size([2, 17920]) mod2 rnn feature shape: torch.Size([2, 1920])
mod1 cnn feature shape: torch.Size([2, 64, 280]) mod2 cnn feature shape: torch.Size([2, 64, 30])
mod1 rnn feature shape: torch.Size([2, 17920]) mod2 rnn feature shape: torch.Size([2, 1920])


In [14]:
# Display the per-modality feature dict produced for the first augmented view.
enc_mod_features_1

{'ecg': tensor([[-0.1241,  0.0844, -0.1616, -0.0966,  0.0134, -0.0635, -0.0689,  0.0298,
          -0.0890, -0.0152,  0.0303,  0.0630,  0.0059,  0.0851,  0.0213, -0.0429,
           0.0028, -0.0003, -0.0404, -0.0465, -0.0803, -0.0385,  0.0005, -0.0915,
          -0.1217,  0.0007, -0.0927,  0.0268,  0.1581,  0.0587,  0.0383,  0.0710,
           0.1060,  0.0111,  0.1551, -0.0105,  0.0744,  0.1144,  0.1411,  0.1285,
          -0.0831, -0.0204,  0.1206,  0.0123, -0.1289, -0.0165, -0.0503,  0.0140,
          -0.0649,  0.0073, -0.0044, -0.0010,  0.0933,  0.0865, -0.0045,  0.0095,
          -0.0885,  0.1091,  0.0897, -0.0932,  0.1393, -0.0972,  0.0073,  0.0770],
         [-0.1030,  0.0623, -0.1084, -0.1130,  0.0150, -0.0558, -0.0655, -0.0128,
          -0.1164, -0.0196,  0.0149,  0.0769,  0.0198,  0.0897,  0.0047, -0.0564,
           0.0142,  0.0122, -0.0319, -0.0310, -0.0421, -0.0547, -0.0215, -0.1030,
          -0.1109, -0.0423, -0.0943,  0.0035,  0.1794,  0.0762,  0.0308,  0.0461,
        

In [15]:
# Print the feature shape of every modality for both augmented views
# (view 1 first, 'ecg' before 'hr').
for view_features in (enc_mod_features_1, enc_mod_features_2):
    for modality in ('ecg', 'hr'):
        print(view_features[modality].shape)

torch.Size([2, 64])
torch.Size([2, 64])
torch.Size([2, 64])
torch.Size([2, 64])


### Focal Model and Loss

In [16]:
from models.FOCALModules import FOCAL
from models.loss import FOCALLoss

In [17]:
# Build a fresh DeepSense backbone and wrap it in FOCAL.
# Note: this rebinds `backbone_model`, replacing the instance from the Backbone section.
backbone_model = DeepSense(args)
focal_model = FOCAL(args, backbone_model)

[1, 16, 32]
ecg extractor is initialized.
hr extractor is initialized.
ecg recurrent layer is initialized.
hr recurrent layer is initialized.
** Finished Initializing DeepSense Backbone **


In [18]:
# Same backbone-only forward pass as in the Backbone section, now using the
# freshly constructed backbone; overwrites the earlier enc_mod_features_* dicts.
enc_mod_features_1 = backbone_model(aug_1_modal_1, aug_1_modal_2)
enc_mod_features_2 = backbone_model(aug_2_modal_1, aug_2_modal_2)

mod1 cnn feature shape: torch.Size([2, 64, 280]) mod2 cnn feature shape: torch.Size([2, 64, 30])
mod1 rnn feature shape: torch.Size([2, 17920]) mod2 rnn feature shape: torch.Size([2, 1920])
mod1 cnn feature shape: torch.Size([2, 64, 280]) mod2 cnn feature shape: torch.Size([2, 64, 30])
mod1 rnn feature shape: torch.Size([2, 17920]) mod2 rnn feature shape: torch.Size([2, 1920])


In [19]:
# Shape of the second-modality tensor under the second augmentation (last batch).
aug_2_modal_2.shape

torch.Size([2, 30])

In [20]:
# FOCAL forward over both augmented views (two modality tensors per view) with
# proj_head=True; returns one per-modality feature dict per view.
# NOTE(review): presumably proj_head=True routes features through a projection
# head -- confirm in models/FOCALModules.py.
mod_features_1, mod_features_2 = focal_model(aug_1_modal_1, aug_1_modal_2, aug_2_modal_1, aug_2_modal_2, proj_head=True)

mod1 cnn feature shape: torch.Size([2, 64, 280]) mod2 cnn feature shape: torch.Size([2, 64, 30])
mod1 rnn feature shape: torch.Size([2, 17920]) mod2 rnn feature shape: torch.Size([2, 1920])
mod1 cnn feature shape: torch.Size([2, 64, 280]) mod2 cnn feature shape: torch.Size([2, 64, 30])
mod1 rnn feature shape: torch.Size([2, 17920]) mod2 rnn feature shape: torch.Size([2, 1920])


In [21]:
# Dump the FOCAL features of the first view, then the shape of each modality entry.
print(mod_features_1)
print(mod_features_1['ecg'].shape)
print(mod_features_1['hr'].shape)

{'ecg': tensor([[ 5.9375e-02,  9.3346e-02, -5.9988e-02, -8.4248e-03,  2.9164e-02,
          1.5648e-01, -1.0696e-01, -1.0277e-02,  4.7275e-02,  7.8495e-02,
         -1.2658e-01,  1.3161e-01, -1.1611e-01, -1.4131e-01,  3.2412e-02,
         -1.6903e-02, -5.4150e-02, -8.1899e-02,  9.5554e-05, -7.0945e-02,
         -3.4094e-02,  4.6955e-02,  1.3384e-01,  1.2069e-01, -1.9006e-02,
          6.3173e-03, -2.9440e-02,  1.1398e-01, -1.3466e-01, -1.6825e-02,
         -1.0564e-01, -3.7422e-02,  5.6960e-02,  1.2609e-01,  2.6290e-03,
         -1.2372e-01, -8.2032e-02,  1.5359e-02,  4.9434e-02, -6.0125e-02,
          1.2757e-01,  7.0556e-02, -4.1827e-02,  2.9723e-02,  3.5180e-02,
         -8.6259e-02, -4.0477e-02,  1.2347e-01,  1.0928e-01, -9.8584e-02,
          2.9936e-02, -7.4971e-02,  1.5314e-02, -2.3503e-02, -1.1261e-01,
         -4.0138e-02, -5.3316e-02,  1.4977e-01, -5.0693e-02,  5.8202e-02,
          3.2393e-02, -6.2542e-02, -7.2803e-02, -4.1505e-02],
        [ 7.6348e-02,  1.0286e-01, -6.8874

In [22]:
# Dump the FOCAL features of the second view, then the shape of each modality entry.
print(mod_features_2)
print(mod_features_2['ecg'].shape)
print(mod_features_2['hr'].shape)

{'ecg': tensor([[ 0.0464,  0.0861, -0.0591, -0.0155,  0.0219,  0.1569, -0.0983, -0.0196,
          0.0447,  0.0720, -0.1324,  0.1300, -0.1093, -0.1331,  0.0272, -0.0097,
         -0.0601, -0.0931, -0.0134, -0.0793, -0.0246,  0.0520,  0.1455,  0.1034,
         -0.0395,  0.0099, -0.0339,  0.1163, -0.1373, -0.0222, -0.1134, -0.0464,
          0.0550,  0.1360,  0.0093, -0.1391, -0.0766,  0.0118,  0.0462, -0.0572,
          0.1289,  0.0721, -0.0339,  0.0215,  0.0445, -0.0769, -0.0340,  0.1261,
          0.1134, -0.1085,  0.0282, -0.0626,  0.0163, -0.0250, -0.1092, -0.0398,
         -0.0568,  0.1510, -0.0642,  0.0530,  0.0482, -0.0744, -0.0790, -0.0436],
        [ 0.0769,  0.0968, -0.0629,  0.0019,  0.0328,  0.1255, -0.0936,  0.0129,
          0.0284,  0.0567, -0.1017,  0.1567, -0.0982, -0.1598,  0.0486,  0.0070,
         -0.0288, -0.0875,  0.0107, -0.0941, -0.0397,  0.0823,  0.1168,  0.1139,
         -0.0090,  0.0207, -0.0179,  0.1101, -0.0954, -0.0127, -0.0626, -0.0294,
          0.0515,  

In [23]:
# Instantiate the FOCAL loss from `args`; the bare name on the last line makes
# the notebook display the module and its sub-modules.
focal_loss_fn = FOCALLoss(args)
focal_loss_fn

FOCALLoss(
  (criterion): CrossEntropyLoss()
  (similarity_f): CosineSimilarity()
  (orthonal_loss_f): CosineEmbeddingLoss()
)

### Training

In [24]:
# Rebuild the training dataset and loader for the training section, using the
# same settings as the DataLoader section (batch of 4, shuffled, one worker).
dataset_kwargs = {
    'file_path': args.base_config['train_data_dir'],
    'modalities': args.base_config['modalities'],
    'subject_idx': args.base_config['subject_key'],
    'stage': args.base_config['label_key'],
}
train_dataset = MESAPairDataset(**dataset_kwargs)

loader_kwargs = {'batch_size': 4, 'shuffle': True, 'num_workers': 1}
train_loader = torch.utils.data.DataLoader(train_dataset, **loader_kwargs)

In [25]:
# Resolve the backbone named first in the FOCAL config. Fail loudly for any
# unsupported name: previously `backbone` was simply left unbound in that case
# and the FOCAL(...) call below died with an unrelated NameError.
backbone_name = str(list(args.focal_config["backbone"].keys())[0])
if backbone_name == "DeepSense":
    backbone = DeepSense(args).to(args.focal_config["device"])
else:
    raise ValueError(f"Unsupported backbone in focal_config: {backbone_name!r}")

# Move the whole FOCAL model (not only the backbone) to the configured device,
# so any parameters FOCAL owns itself live on the same device as the backbone.
model = FOCAL(args, backbone).to(args.focal_config["device"])
optimizer = torch.optim.Adam(model.parameters(), lr=args.focal_config["lr"])

[1, 16, 32]
ecg extractor is initialized.
hr extractor is initialized.
ecg recurrent layer is initialized.
hr recurrent layer is initialized.
** Finished Initializing DeepSense Backbone **


  from .autonotebook import tqdm as notebook_tqdm


In [30]:
# Build the subject adversary on the device named in subj_invariant_config and
# give it its own Adam optimizer, separate from the FOCAL model's optimizer.
advs_model = AdversarialModel(args).to(args.subj_invariant_config["device"])
advs_optimizer = torch.optim.Adam(advs_model.parameters(), lr=args.subj_invariant_config['lr'])

In [None]:
def train_SA_Focal(train_loader, val_loader, model, advs_model,
                   optimizer, advs_optimizer, focal_loss_fn, device, args):
    """Train the FOCAL model together with the subject adversary (stub).

    The original definition had a signature but no body, which is a
    SyntaxError. Until the training loop prototyped in the following cell is
    moved in here, the function is an explicit stub.

    Args:
        train_loader: iterable of training batches.
        val_loader: iterable of validation batches.
        model: the FOCAL model to train.
        advs_model: the adversarial subject classifier.
        optimizer: optimizer for `model`.
        advs_optimizer: optimizer for `advs_model`.
        focal_loss_fn: loss callable for the FOCAL objective.
        device: device that batches are moved to.
        args: configuration module.

    Raises:
        NotImplementedError: always, until the loop is implemented here.
    """
    raise NotImplementedError(
        "train_SA_Focal is not implemented yet; see the training loop cell below."
    )

In [None]:
# Subject-adversarial FOCAL training loop.
#
# Each batch is processed in two phases:
#   1. adversary step - FOCAL is frozen and the adversary learns to predict the
#      subject from FOCAL's features;
#   2. FOCAL step     - the adversary is frozen and FOCAL is trained on the
#      FOCAL loss, which receives the subject-invariance loss as its third input.
#
# Fixes relative to the first draft of this cell:
#   * `device` and `val_loader` were never defined in the notebook;
#   * the loaders yield 4 tensors (modal 1, modal 2, subject, sleep stage), but
#     the loops unpacked only 3;
#   * `enc_modal_1/2` and `subj_labels` were undefined names;
#   * FOCAL was called with 2 raw inputs instead of the 4 augmented inputs used
#     everywhere else in this notebook;
#   * validation used a different loss API than training;
#   * the networks were never switched back to train mode after validation.
trainer_config = args.trainer_config

# Use the device the FOCAL backbone was placed on, and make sure both networks
# live there, so the features fed to the adversary are on its device.
device = args.focal_config["device"]
model.to(device)
advs_model.to(device)

# Validation loader: same construction as the training loader, but reading
# 'val_data_dir' and without shuffling.
val_dataset = MESAPairDataset(file_path=args.base_config['val_data_dir'],
                              modalities=args.base_config['modalities'],
                              subject_idx=args.base_config['subject_key'],
                              stage=args.base_config['label_key'])
val_loader = torch.utils.data.DataLoader(val_dataset,
                                         batch_size=4,
                                         shuffle=False,
                                         num_workers=1)

aug_1_name = args.data_config['augmentation'][0]
aug_1_config = args.data_config['augmenter_config'].get(aug_1_name, {})
aug_2_name = args.data_config['augmentation'][1]
aug_2_config = args.data_config['augmenter_config'].get(aug_2_name, {})

aug_1 = init_augmenter(aug_1_name, aug_1_config)
aug_2 = init_augmenter(aug_2_name, aug_2_config)

best_val_loss = float('inf')

for ep in range(trainer_config['epochs']):
    # Validation puts both networks in eval mode, so re-enable train mode at
    # the start of every epoch.
    model.train()
    advs_model.train()

    running_advs_train_loss = 0
    focal_train_loss = 0

    # The sleep-stage label is not used by the self-supervised objective.
    for raw_modal_1, raw_modal_2, subj_label, _sleep_stage in train_loader:
        # Shapes seen earlier in this notebook: [B, 7680], [B, 30], [B]
        raw_modal_1 = raw_modal_1.to(device)
        raw_modal_2 = raw_modal_2.to(device)
        subj_label = subj_label.to(device)

        aug_1_modal_1 = aug_1(raw_modal_1)
        aug_2_modal_1 = aug_2(raw_modal_1)

        aug_1_modal_2 = aug_1(raw_modal_2)
        aug_2_modal_2 = aug_2(raw_modal_2)

        # ---- Phase 1: update only the adversary (subject classifier) ----
        for param in model.parameters():
            param.requires_grad = False
        for param in advs_model.parameters():
            param.requires_grad = True

        advs_optimizer.zero_grad()

        # enc_feature_1 / enc_feature_2 are dicts keyed by modality
        # ('ecg', 'hr'), one per augmentation.
        enc_feature_1, enc_feature_2 = model(aug_1_modal_1, aug_1_modal_2, aug_2_modal_1, aug_2_modal_2, proj_head=True)

        # Predict the subject from the frozen FOCAL features.
        # NOTE(review): assumes AdversarialModel.forward takes the two feature
        # dicts; if it expects four tensors, pass
        # enc_feature_1['ecg'], enc_feature_1['hr'], enc_feature_2['ecg'], enc_feature_2['hr'].
        subj_preds = advs_model(enc_feature_1, enc_feature_2)

        advs_loss = advs_model.forward_adversarial_loss(subj_preds, subj_label)

        # TODO: track adversary accuracy (number of correct subject predictions).

        advs_loss.backward()
        advs_optimizer.step()

        running_advs_train_loss += advs_loss.item()

        # For efficient memory management
        del enc_feature_1, enc_feature_2, subj_preds, advs_loss

        # ---- Phase 2: update only the FOCAL model (SSL model) ----
        for param in model.parameters():
            param.requires_grad = True
        for param in advs_model.parameters():
            param.requires_grad = False

        optimizer.zero_grad()

        x1_represent, x2_represent = model(aug_1_modal_1, aug_1_modal_2, aug_2_modal_1, aug_2_modal_2, proj_head=True)

        # The adversary must see the same kind of features it was trained on in
        # phase 1, so reuse the FOCAL outputs instead of a separate encoder call.
        subj_pred = advs_model(x1_represent, x2_represent)
        # NOTE(review): `args.adversarial_weighting_factor` is read as a
        # module-level attribute; confirm it exists in args.py (it may belong
        # in one of the config dicts).
        subj_invariant_loss = advs_model.forward_subject_invariance_loss(subj_pred, subj_label, args.adversarial_weighting_factor)

        focal_loss = focal_loss_fn(x1_represent, x2_represent, subj_invariant_loss)
        focal_loss.backward()
        optimizer.step()

        focal_train_loss += focal_loss.item()

        # For efficient memory management
        del x1_represent, x2_represent, subj_pred, subj_invariant_loss, focal_loss
        torch.cuda.empty_cache()

    # NOTE(review): `args.log_interval` / `args.val_interval` are read as
    # module-level attributes; confirm they exist in args.py. Validation is
    # nested under the logging check, so it only runs on epochs that satisfy both.
    if ep % args.log_interval == 0:
        print(f"Epoch {ep} - Adversarial Loss: {running_advs_train_loss/ len(train_loader)}, \
            Focal Loss: {focal_train_loss/ len(train_loader)}")

        if ep % args.val_interval == 0:
            model.eval()
            advs_model.eval()

            advs_val_loss = 0
            focal_val_loss = 0

            for raw_modal_1, raw_modal_2, subj, _sleep_stage in val_loader:
                raw_modal_1, raw_modal_2, subj = raw_modal_1.to(device), raw_modal_2.to(device), subj.to(device)

                with torch.no_grad():
                    # FOCAL needs two augmented views per modality, exactly as in training.
                    x1_represent, x2_represent = model(aug_1(raw_modal_1), aug_1(raw_modal_2),
                                                       aug_2(raw_modal_1), aug_2(raw_modal_2),
                                                       proj_head=True)
                    subj_pred = advs_model(x1_represent, x2_represent)

                    # Same loss calls as in training so the reported numbers are comparable.
                    advs_loss = advs_model.forward_adversarial_loss(subj_pred, subj)
                    subj_invariant_loss = advs_model.forward_subject_invariance_loss(subj_pred, subj, args.adversarial_weighting_factor)
                    focal_loss = focal_loss_fn(x1_represent, x2_represent, subj_invariant_loss)

                    advs_val_loss += advs_loss.item()
                    focal_val_loss += focal_loss.item()

                    # For efficient memory management
                    del x1_represent, x2_represent, subj_pred, advs_loss, subj_invariant_loss, focal_loss
                    torch.cuda.empty_cache()

            print("-----"*10)
            print(f"(Validation) Epoch{ep} - Adversarial Loss: {advs_val_loss/ len(val_loader)}, \
                Focal Loss: {focal_val_loss/ len(val_loader)}")

            if focal_val_loss < best_val_loss:
                best_val_loss = focal_val_loss

                # TODO: fix the save-model format; until then nothing is written to disk.
                # torch.save(model.state_dict(), os.path.join(args.save_dir, 'focal_model.pth'))
                # torch.save(advs_model.state_dict(), os.path.join(args.save_dir, 'advs_model.pth'))
                print("************* Model Saved *************")
            print("-----"*10)