In [1]:
import os
use_gpu = '5' # kimbg
os.environ["CUDA_VISIBLE_DEVICES"] = use_gpu 
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"

import sys
from glob import glob
from tqdm import tqdm
import time
import numpy as np
import pandas as pd
# pd.set_option('display.max_columns', None)  # 최대 100개 열 표시x
# pd.set_option('display.max_colwidth', None)  # 열 너비 제한 해제

import matplotlib.pyplot as plt
from IPython.display import display
from pathlib import Path
import torch
import torch.nn as nn
import torch.amp as amp
import torch.optim.lr_scheduler as lr_scheduler
from torch.utils.data import DataLoader, Dataset, WeightedRandomSampler

In [2]:
from model_architecture import SynAI
from processing import utils, dataset, transforms
from trainer import train_engine, losses

### Configuration Setup

In [None]:
class CFG:
    ###--- Genernal Settings ---###
    seed = 0 ; utils.SetSeedEverything(seed, fully_deterministic=True)
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    use_amp = torch.cuda.is_available()
    amp_scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
    num_workers = 4
    papermill = False
    model_auto_remove = False


    ###--- Main Config. ---###
    CurStep = 3  # {3: ECG-labeled prediction}
    model_save_root_path = '/mnt/home/bgk/macai-model-experimental/_kimbg_code/SynAI_v2/outputs/checkpoint'
    model_dir_name = f'CL_step{CurStep}-TRN-MacAI_v1_2-0822'
    model_save_path = os.path.join(model_save_root_path, model_dir_name)
    os.makedirs(model_save_path, exist_ok=True)


    ###--- Dataset Load ---###
    dataset_fraction = 1 # 0~1 if False: fully use
    interest_windows = [14]  # [3,7,14,30,90,180,365]
    add_status = False

    target_cols = {'input':'bpf0540_or3-npy_path',}
    subset_cols = [target_cols['input']] + ['kfold']

    # if CurStep == 4:
    #     pass    
    if CurStep == 3:
        # target_class_list = ['AFIB_AFL-keyword_v2', 'CIA-keyword_v2', 'AA-keyword_v2', 'VP-keyword_v2', 'BBB-keyword_v2']
        target_class_list = ['AFIB_AFL-keyword_v2', 'CIA-keyword_v2']
        target_cols['label'] = []
        for window in interest_windows:
            for target_class in target_class_list:
                target_cols['label'] += [f'ECG_event_{window}d_{target_class}_onset']
        if add_status: 
            target_cols['label'] += target_class_list[:]
        subset_cols += target_cols['label']
        dataset_df_path = '/home/Datasets_processed/EKG_Latest/processed_metadata_v6-sv2_subset_2_holdin.csv'
        test_dataset_df_path = '/home/Datasets_processed/EKG_Latest/processed_metadata_v6-sv2_subset_2_holdout.csv'
        # public_trn_dataset_df_path = '/home/Datasets_processed/BenchmarkSets/MIMIC_IV_ECG/preprocessed_v6-trn.csv'
        public_trn_dataset_df_path = None
        public_val_dataset_df_path = '/home/Datasets_processed/BenchmarkSets/MIMIC_IV_ECG/preprocessed_v6-val.csv'
        # public_val_dataset_df_path = None
        # snuh_trn_dataset_df_path = '/home/Datasets_processed/EKG_Latest/SNUH/metadata_v4-ms_v2-trn.csv'
        snuh_trn_dataset_df_path = None
        snuh_val_dataset_df_path = '/home/Datasets_processed/EKG_Latest/SNUH/metadata_v4-ms_v2-val.csv'
        target_fold_list = [0]  # [0,2,4,6,8]
        n_folds, trn_ratio, val_ratio = 10, 9, 1
        weighted_random_sampler = 'pid_log_weight'  # pid_log_weight, pid_weight
        if weighted_random_sampler:
            subset_cols.append(weighted_random_sampler)
        model_auto_remove = True

    if CurStep == 2:
        target_cols['label'] = ['AFIB_AFL-keyword_v2', 'CIA-keyword_v2']
        subset_cols += target_cols['label']
        dataset_df_path = '/home/Datasets_processed/EKG_Latest/processed_metadata_v6-subset_1.csv'
        test_dataset_df_path = None
        # public_trn_dataset_df_path = '/home/Datasets_processed/BenchmarkSets/MIMIC_IV_ECG/preprocessed_v6-subset_1_trn.csv'
        public_trn_dataset_df_path = None
        # public_val_dataset_df_path = '/home/Datasets_processed/BenchmarkSets/MIMIC_IV_ECG/preprocessed_v6-subset_1_val.csv'
        public_val_dataset_df_path = None
        target_fold_list = [0]  # [0,2,4,6,8]
        n_folds, trn_ratio, val_ratio = 10, 9, 1
        weighted_random_sampler = 'pid_log_weight'  # pid_log_weight, pid_weight
        if weighted_random_sampler:
            subset_cols.append(weighted_random_sampler)
        model_auto_remove = True
        
    filter_col = target_cols['label'][-1]
    print('filter_col : ', filter_col)
    print('subset_cols : ', subset_cols)
    def df_read_and_filtering(df_path, filter_col, subset_cols, prefix = None):
        df = pd.read_csv(df_path)
        print(f'Original {prefix} dataframe shape : {df.shape}')
        con = (df[filter_col].isin([0,1]))
        df = df[con].reset_index(drop=True)[subset_cols].copy()
        print(f'Filtered {prefix} dataframe shape : {df.shape}') 
        return df
    
    dataset_df = df_read_and_filtering(dataset_df_path, filter_col, subset_cols, prefix = 'EUMC-trn-set')
    if test_dataset_df_path:
        test_dataset_df = df_read_and_filtering(test_dataset_df_path, filter_col, subset_cols, prefix = 'EUMC-tst-set')
    if public_trn_dataset_df_path:
        public_trn_dataset_df = df_read_and_filtering(public_trn_dataset_df_path, filter_col, subset_cols, prefix = 'BIDMC-trn-set')
        dataset_df = pd.concat([dataset_df, public_trn_dataset_df], axis=0).reset_index(drop=True)
    if public_val_dataset_df_path:
        public_val_dataset_df = df_read_and_filtering(public_val_dataset_df_path, filter_col, subset_cols, prefix = 'BIDMC-tst-set')
    if snuh_trn_dataset_df_path:
        snuh_trn_dataset_df = df_read_and_filtering(snuh_trn_dataset_df_path, filter_col, subset_cols, prefix = 'SNUH-trn-set')
        dataset_df = pd.concat([dataset_df, snuh_trn_dataset_df], axis=0).reset_index(drop=True)
    if snuh_val_dataset_df_path:
        snuh_val_dataset_df = df_read_and_filtering(snuh_val_dataset_df_path, filter_col, subset_cols, prefix = 'SNUH-tst-set')
    print(f'*** FINAL TRAIN dataframe shape *** : {dataset_df.shape}')


    ###--- Data Input Config. ---###
    input_config = {
        'seq_length':2560, 'in_channels':12, 'num_classes':len(target_cols['label']),
        'target_lead':'12lead'  # I, II, III, aVR, aVL, aVF, V1, V2, V3, V4, V5, V6, limb, precordial, 12lead
    }


    ###--- Model Architecture ---###
    model_architecture = SynAI.Build_Model_250526
    model_info = {
        #--- meta ---#
        'mode':'finetuning',  # finetuning, pretraining
        'weights_init':'SSL_transfer',  # scratch, SSL_transfer, DST_transfer

        #--- SynAI config ---#
        'name':'MAE_1D_250409_v3',
        'config':{
            'embed_dim':768,  # 384 768
            'patch_size':32, 
            'seq_length':input_config['seq_length'],
            'in_channels':input_config['in_channels'], 
            'encoder':'vit_encoder',
            'merge_mode':'projection',  # linear_projection avg add
            #--- for self-supervised learning ---#
            'decoder_depth':2, 
            'decoder_num_heads':8,
            'stft_loss_ratio':0,
        }
    }
    if model_info['mode'] == 'finetuning':
        model_info['config'].update({'num_classes':input_config['num_classes']})
    if 'transfer' in model_info['weights_init']:
        model_info['prev_model_path'] = os.path.join(
            '/home/bgk/macai-model-experimental/checkpoint',
            'Step1-TRN_VAL-0526/Best_loss-Ep_550-Lo_0.0020.pth'
        )


    ###--- Modeling Options ---###
    batch_size = 512
    epoch = 30
    valid_interval = 1 # 10
    early_stopping_patience = 5
    early_stopping_mode = 'max' # 'min' -> loss, 'max' -> main_metric
    start_lr, final_lr = 2e-4, 1e-5
    start_factor, warmup = 1, 0.05
    scheduler_lr = True # True
    clip_grad_norm = 1.0
    early_stopping = utils.EarlyStopping
    

    ###--- Loss Functions & Metrics ---###
    pos_weight = None
    if pos_weight:
        train_labels = torch.tensor(dataset_df[target_cols['label']].values)
        class_counts_pos = (train_labels == 1).sum(dim=0)
        class_counts_neg = (train_labels == 0).sum(dim=0)
        class_counts_pos = class_counts_pos.float() + 1e-8  # epsilon
        class_counts_neg = class_counts_neg.float() + 1e-8  # epsilon
        pos_weight = torch.tensor(class_counts_neg / class_counts_pos, device=device)
        pos_weight = pos_weight / pos_weight.mean()
    print(f"계산된 pos_weight: {pos_weight}")

    task_weights = []
    pri_task_w, aux_task_w = 1, 0.1
    for iw in interest_windows:
        if iw == 14:
            task_weights += [pri_task_w,]*2
        else:
            task_weights += [aux_task_w,]*2
    if add_status:
        task_weights += [aux_task_w,]*2
    task_weights = torch.tensor(task_weights, device=device)
    print(f"계산된 task_weights: {task_weights}")

    loss_fn = losses.AuxFocalSmoothBCELoss(pos_weight=pos_weight, task_weights=task_weights, gamma=0, smoothing=0.1)


    #--- Transforms ---#
    transforms_config = {}
    transforms_config['NormalizeECG'] = {'method':"tanh", 'scope':"lead-wise", 'scale':1}
    transforms_config['RandomLeadSwapping'] = {'swap_ratio':0.5, 'swap_pairs':None, 'p':0.25}    
    # transforms_config['ApplyGaussianNoise(add_per_lead)'] = {'noise_level': 0.1, 'mode': 'add', 'per_lead': True, 'p':0.5}
    # transforms_config['ApplyGaussianNoise(mul_per_lead)'] = {'noise_level': 0.05, 'mode': 'mul', 'per_lead': True, 'p':0.5}
    # transforms_config['ApplySinusoidalNoise(baseline_drift)'] = {
    #     'frequency_range': (0.1, 1), 'amplitude_range': (0.01, 0.1), 'mode':'add', 'per_lead': False, 'p':0.5}
    # transforms_config['ApplySinusoidalNoise(EMG)'] = {
    #     'frequency_range': (30, 100), 'amplitude_range': (0.01, 0.1), 'mode':'add', 'per_lead': True, 'p':0.5}
    # transforms_config['ApplySinusoidalNoise(powerline_noise)'] = {
    #     'frequency_range': (50,60), 'amplitude_range': (0.01, 0.1), 'mode':'add', 'per_lead': False, 'p':0.5}
    transforms_config['AddTimeWarp'] = {'warp_factor_range':(0.75, 1.25), 'padding_mode':"repeat", 'p':0.25}
    transforms_config['RandomLeadMasking'] = {'mask_ratio':0.5, 'mask_leads':None, 'p':0.25}
    transforms_list = [
        getattr(transforms, key.split('(')[0])(**params)
        for key, params in transforms_config.items()
    ]
    transform_pipeline_dict = {}
    transform_pipeline_dict['train'] = transforms.TransformPipeline(transforms_list)
    transform_pipeline_dict['val'] = transforms.TransformPipeline(transforms_list[:1])  # only scaling


    ###--- check gpu ---###
    if torch.cuda.is_available(): print(f"CUDA is available. GPU No.{use_gpu}")
    else: print("CUDA is not available.")


### For papermill

In [3]:
# # param_dict would be defined from papermill

# if CFG.papermill:
#     for k, v in param_dict.items():
#         if k in ['batch_size', 'start_lr']:
#             setattr(CFG, f'{k}', v)  # 'k'라는 속성을 동적으로 추가
#             print(k, CFG.__dict__[f'{k}'])
            
#         elif k == 'model_dir_name':
#             setattr(CFG, f'{k}', v)  # 'k'라는 속성을 동적으로 추가
#             CFG.model_save_path = os.path.join(CFG.model_save_root_path, CFG.model_dir_name)
#             os.makedirs(CFG.model_save_path, exist_ok=True)
#             print(k, CFG.__dict__[f'{k}'])

#         elif k == 'transforms_config':
#             setattr(CFG, f'{k}', v)  # 'k'라는 속성을 동적으로 추가
#             transforms_list = [
#                 getattr(_kimbg_transforms, key.split('(')[0])(**params)
#                 for key, params in CFG.transforms_config.items()
#             ]
#             CFG.transform_pipeline = _kimbg_transforms.TransformPipeline(transforms_list)
#             print('transforms_config', CFG.transforms_config)

#         elif k == 'interest_windows':
#             setattr(CFG, f'{k}', v)  # 'k'라는 속성을 동적으로 추가
#             CFG.interest_windows = [v]
#             CFG.label_col = [f'label_ECG_AFIB_AFL_within_{i}d' for i in CFG.interest_windows]
#             CFG.target_cols = {
#                 'input':'bpf0540_or3-npy_path', 'label':CFG.label_col
#             }
#             CFG.dataset_df_path = "/home/Datasets_processed/raw_ECGs/meta_log_v8-labelled_fold-S2.csv"
#             CFG.dataset_df = pd.read_csv(CFG.dataset_df_path)
#             print(f'Original DataFrame shape : {CFG.dataset_df.shape}')
#             con1 = (CFG.dataset_df[CFG.label_col[-1]]!=-1)
#             CFG.dataset_df = CFG.dataset_df[con1].reset_index(drop=True)
#             print(f'Filtered DataFrame shape (after removing no_label) : {CFG.dataset_df.shape}')

#             # Dataset Matched with Holter Label
#             CFG.holter_dataset_df_path = "/home/Datasets_processed/raw_ECGs/meta_log_v8-labelled_fold-S3.csv"
#             CFG.holter_label_col = [f'label_holter_AFIB_AFL_within_{i}d' for i in CFG.interest_windows]
#             CFG.holter_target_cols = {
#                 'input':'bpf0540_or3-npy_path', 'label':CFG.holter_label_col
#             }
#             CFG.holter_dataset_df = pd.read_csv(CFG.holter_dataset_df_path)
#             print(f'Original holter_dataset_df shape : {CFG.holter_dataset_df.shape}')
#             holter_con1 = (CFG.holter_dataset_df[CFG.holter_label_col[-1]]!=-1)
#             CFG.holter_dataset_df = CFG.holter_dataset_df[holter_con1].reset_index(drop=True)
#             print(f'Filtered holter_dataset_df shape (after removing no_label) : {CFG.holter_dataset_df.shape}')
            
#         else:
#             CFG.model_info[f'{k}'] = v
#             print(k, CFG.model_info[f'{k}'])


In [4]:
model = CFG.model_architecture(CFG.model_info).to(CFG.device)
total_params = sum(p.numel() for p in model.parameters())
print(f"Model Parameter Count: {total_params*1e-6:.2f}M")

model name : MAE_1D_250409_v3
seq_length 2560, in_channels 12, patch_size 32, embed_dim 768, token_len 80, 
model was loaded from '/home/bgk/macai-model-experimental/checkpoint/Step1-TRN_VAL-0526/Best_loss-Ep_550-Lo_0.0020.pth'
Model Parameter Count: 113.34M


### Dataset Load

In [5]:
fold_data_dict = {}

for fold_idx, target_fold in enumerate(CFG.target_fold_list):

    #--- 기본 fold 지정 ---#
    fold_dict = {
        'train_fold': [x % CFG.n_folds for x in range(target_fold, target_fold + CFG.trn_ratio)],
        'val_fold': [x % CFG.n_folds for x in range(
            target_fold + CFG.trn_ratio, target_fold + CFG.trn_ratio + CFG.val_ratio)],
    }

    #--- dataframe 생성 ---#
    df_dict = {}
    for split in ['train', 'val']:
        split_folds = fold_dict[f'{split}_fold']
        df  = CFG.dataset_df[CFG.dataset_df['kfold'].isin(split_folds)]
        if split == 'train' and (0 < CFG.dataset_fraction < 1):
            df = df.sample(frac=CFG.dataset_fraction, random_state=CFG.seed+831, ignore_index=True)
        df_dict[split] = df
    if CFG.test_dataset_df_path:
        df_dict['test'] = CFG.test_dataset_df
    if CFG.public_val_dataset_df_path:
        df_dict['pub_test'] = CFG.public_val_dataset_df
    if CFG.snuh_val_dataset_df_path:
        df_dict['snuh_test'] = CFG.snuh_val_dataset_df

    #--- dataset 생성 ---#
    dataset_dict = {}
    for split in df_dict.keys():
        # dataset_dict[split] = dataset.DatasetFromDataframe(
        dataset_dict[split] = dataset.DatasetFromDataframe_v2(
            df          = df_dict[split],
            input_col   = CFG.target_cols['input'],
            label_col   = CFG.target_cols['label'],
            target_len  = CFG.input_config['seq_length'],
            target_lead = CFG.input_config['target_lead'],
            transforms  = CFG.transform_pipeline_dict['train'] if split == 'train' else CFG.transform_pipeline_dict['val'],
        )

    #--- 샘플러 설정 ---#
    sampler = None
    if CFG.weighted_random_sampler:
        sampler = WeightedRandomSampler(
            weights     = df_dict['train'][CFG.weighted_random_sampler].values,
            num_samples = len(dataset_dict['train']),
            replacement = True
        )

    #--- dataloader 생성 ---#
    dataloader_dict = {}
    for split in df_dict.keys():
        dataloader_dict[split] = DataLoader(
            dataset            = dataset_dict[split],
            batch_size         = CFG.batch_size,
            shuffle            = (split == 'train') and not CFG.weighted_random_sampler,
            sampler            = sampler if split == 'train' else None,
            num_workers        = CFG.num_workers,
            pin_memory         = True,
            drop_last          = (split == 'train'),
            persistent_workers = (CFG.num_workers != 0),
            prefetch_factor    = (CFG.num_workers // 4) if CFG.num_workers != 0 else None,
        )

    # fold 결과 저장
    fold_data_dict[fold_idx] = {
        'fold_dict'      : fold_dict,
        'df_dict'        : df_dict,
        'dataset_dict'   : dataset_dict,
        'dataloader_dict': dataloader_dict,
    }

# 체크 로그
for split in df_dict.keys():
    print(f'{split} dataframe shape: {df_dict[split].shape}, dataloader length: {len(dataloader_dict[split])}')
print()
for split in df_dict.keys():
    sample = next(iter(dataloader_dict[split]))
    print(f'{split} sample path: {sample["input_path"][0]}')
    print(
        f'{sample["input"].shape}, dtype: {sample["input"].dtype}, '
        f'max: {torch.max(sample["input"]):.4f}, min: {torch.min(sample["input"]):.4f}, '
        f'mean: {torch.mean(sample["input"]):.4f}, std: {torch.std(sample["input"]):.4f}'
    )


train dataframe shape: (50026, 5), dataloader length: 97
val dataframe shape: (5453, 5), dataloader length: 11
test dataframe shape: (15190, 5), dataloader length: 30
pub_test dataframe shape: (20060, 5), dataloader length: 40
snuh_test dataframe shape: (31776, 5), dataloader length: 63



train sample path: /home/Datasets_processed/EKG_Latest/signal-bpf0540_or3/20170721201048_0664F020_0000_20170721192615_11324604.npy
torch.Size([512, 12, 2560]), dtype: torch.float32, max: 1.0000, min: -1.0000, mean: 0.0173, std: 0.4324
val sample path: /home/Datasets_processed/EKG_Latest/signal-bpf0540_or3/14898871_2021-11-17_2021111719451618_2021111719452386.npy
torch.Size([512, 12, 2560]), dtype: torch.float32, max: 1.0000, min: -1.0000, mean: 0.0203, std: 0.4602
test sample path: /home/Datasets_processed/EKG_Latest/signal-bpf0540_or3/10302870_2023-10-17_2023101716092374_2023101716095795.npy
torch.Size([512, 12, 2560]), dtype: torch.float32, max: 1.0000, min: -1.0000, mean: 0.0172, std: 0.4635
pub_test sample path: /home/Datasets_processed/BenchmarkSets/MIMIC_IV_ECG/npys-bpf0540_or3/10001186/10001186_45986118.npy
torch.Size([512, 12, 2560]), dtype: torch.float32, max: 1.0000, min: -1.0000, mean: 0.0188, std: 0.4738
snuh_test sample path: /home/Datasets_processed/EKG_Latest/SNUH/signal

### Training

In [6]:
for fold_idx, target_fold in enumerate(CFG.target_fold_list):
    
    #--- 경로 및 초기 설정 ---#
    model_fold_path = os.path.join(CFG.model_save_path, f'{target_fold}fold')
    os.makedirs(model_fold_path, exist_ok=True)
    

    #--- EarlyStopping 인스턴스 초기화 ---#
    early_stopping = CFG.early_stopping(patience = CFG.early_stopping_patience,
                                        delta    = 0,
                                        mode     = CFG.early_stopping_mode,
                                        verbose  = True
    )

    #--- Best 기록 초기화 ---#
    BEST_RECORD_DICT = {'record_df'               : pd.DataFrame(),
                        'epoch'                   : -1,
                        'best_loss'               : float('inf'),
                        'best_AUROC'              : -1,
                        'best_AUPRC'              : -1,
                        'best_loss_ckp_path'      : None,
                        'best_AUROC_ckp_path'     : None,
                        'best_AUPRC_ckp_path'     : None,
    }


    #--- 모델 구성 ---#
    model = CFG.model_architecture(CFG.model_info).to(CFG.device)


    #--- Optimizer & Scheduler ---#
    optimizer = torch.optim.AdamW(model.parameters(), lr=CFG.start_lr)
    scheduler_dict = {'lr': None}
    if CFG.scheduler_lr:
        steps_per_epoch = len(fold_data_dict[fold_idx]['dataloader_dict']['train'])
        num_training_steps = CFG.epoch * steps_per_epoch
        num_warmup_steps = int(CFG.epoch * CFG.warmup) * steps_per_epoch
        num_cosine_steps = num_training_steps - num_warmup_steps
        warmup_scheduler = lr_scheduler.LinearLR(
            optimizer, start_factor = CFG.start_factor, end_factor = 1.0, total_iters = num_warmup_steps)
        cosine_scheduler = lr_scheduler.CosineAnnealingLR(
            optimizer, T_max = num_cosine_steps, eta_min = CFG.final_lr)
        scheduler_dict['lr'] = lr_scheduler.SequentialLR(
            optimizer, schedulers = [warmup_scheduler, cosine_scheduler],
            milestones = [num_warmup_steps])
    start_epoch = 0
    print(f"Training will start from epoch {0}")


    # ------------------------------------------------------------------------------------ # 
    # ------------------------------------------------------------------------------------ # 
    # ------------------------------------------------------------------------------------ # 
    
    
    #--- Model Training Loop ---#
    for epoch in range(start_epoch, CFG.epoch):
        print(f'\n||| Current epoch: {epoch} |||')

        train_status = train_engine.train_one_epoch(
            model          = model, 
            data_loader    = fold_data_dict[fold_idx]['dataloader_dict']['train'], 
            loss_fn        = CFG.loss_fn, 
            device         = CFG.device, 
            amp_scaler     = CFG.amp_scaler, 
            optimizer      = optimizer, 
            clip_grad_norm = CFG.clip_grad_norm,
            scheduler_lr   = scheduler_dict['lr'], 
            verbose        = True,
        )
        train_metrics = utils.classification_metrics(
            y                = train_status['labels'], 
            y_pred           = train_status['logits'], 
            activation_fn    = True, 
            mode             = 'multilabel',
            threshold_method = 'youden'
        )
        train_metrics['avg']['loss'] = train_status['loss']
        monitor = f"| TRN-Avg. | loss : {train_metrics['avg']['loss']:.3f}, AUROC : {train_metrics['avg']['AUROC']:.2f}, AUPRC : {train_metrics['avg']['AUPRC']:.2f}"
        print(monitor)
        

        #--- Model Validation ---#
        if epoch % CFG.valid_interval == 0:
            valid_status = train_engine.evaluate_one_epoch(
                model          = model, 
                data_loader    = fold_data_dict[fold_idx]['dataloader_dict']['val'], 
                loss_fn        = CFG.loss_fn, 
                device         = CFG.device, 
                verbose        = True,
            )
            valid_metrics = utils.classification_metrics(
                y                = valid_status['labels'], 
                y_pred           = valid_status['logits'], 
                activation_fn    = True, 
                mode             = 'multilabel',
                threshold_method = 'youden'
            )
            valid_metrics['avg']['loss'] = valid_status['loss']
            monitor = f"| VAL-Avg. | loss : {valid_metrics['avg']['loss']:.3f}, AUROC : {valid_metrics['avg']['AUROC']:.3f}, AUPRC : {valid_metrics['avg']['AUPRC']:.3f}"
            print(monitor)


            #--- Record training log ---#
            BEST_RECORD_DICT['record_df'] = utils.history_recording(
                record_df       = BEST_RECORD_DICT['record_df'], 
                epoch           = epoch, 
                train_metrics   = train_metrics, 
                valid_metrics   = valid_metrics,
                save_path       = os.path.join(model_fold_path, 'learning_history.csv'), 
            )
            utils.learning_curve_recording(
                record_df         = BEST_RECORD_DICT['record_df'], 
                save_path         = os.path.join(model_fold_path, 'learning_curve.png'), 
                show              = False                
            )

            #--- Helper: 저장 함수 ---#
            def maybe_save_best(metric_key: str, is_better: bool):
                if is_better:
                    ckpt_path_key = f'best_{metric_key}_ckp_path'
                    old_ckpt = BEST_RECORD_DICT[ckpt_path_key]
                    if old_ckpt and os.path.exists(old_ckpt) and CFG.model_auto_remove:
                        os.remove(old_ckpt)
                    model_name = f"Best_{metric_key}-Ep_{epoch}-Lo_{valid_metrics['avg']['loss']:.3f}"
                    model_name += f"-M0_{valid_metrics['avg']['AUROC']:.3f}-M1_{valid_metrics['avg']['AUPRC']:.3f}"
                    new_ckpt = os.path.join(model_fold_path, model_name + '.pth')
                    utils.ModelSave(new_ckpt, epoch, model, optimizer, scheduler_dict['lr'])
                    BEST_RECORD_DICT[ckpt_path_key] = new_ckpt
                    BEST_RECORD_DICT[f'best_{metric_key}'] = valid_metrics['avg'][metric_key]

            # Check improvements
            improved = {
                'loss': valid_metrics['avg']['loss'] < BEST_RECORD_DICT['best_loss'],
                'AUROC': valid_metrics['avg']['AUROC'] > BEST_RECORD_DICT['best_AUROC'],
                'AUPRC': valid_metrics['avg']['AUPRC'] > BEST_RECORD_DICT['best_AUPRC'],
            }
            if any(improved.values()):
                print("*** Best Record! ***")
                for key, flag in improved.items():
                    if flag:
                        print(f"- Improved {key}: {BEST_RECORD_DICT.get(f'best_{key}', float('nan')):.4f} → {valid_metrics['avg'][key]:.4f}")
                if improved['loss']:
                    BEST_RECORD_DICT['best_loss'] = valid_status['loss']
                    maybe_save_best('loss', True)
                maybe_save_best('AUROC', improved['AUROC'])
                maybe_save_best('AUPRC', improved['AUPRC'])

            # Early stopping
            monitor_value = valid_metrics['avg']['loss'] if CFG.early_stopping_mode == 'min' else valid_metrics['avg']['AUROC']
            early_stopping(monitor_value)
            if early_stopping.early_stop:
                print("Early stopping triggered."); break
        
        torch.cuda.empty_cache()        
        # break

model name : MAE_1D_250409_v3
seq_length 2560, in_channels 12, patch_size 32, embed_dim 768, token_len 80, 
model was loaded from '/home/bgk/macai-model-experimental/checkpoint/Step1-TRN_VAL-0526/Best_loss-Ep_550-Lo_0.0020.pth'
Training will start from epoch 0

||| Current epoch: 0 |||


| TRN | batch loss : 0.4052, lr : 2.0e-04 : 100%|██████████| 97/97 [01:31<00:00,  1.07it/s]


| TRN-Avg. | loss : 0.480, AUROC : 0.65, AUPRC : 0.24


| VAL | batch loss: 0.5139, : 100%|██████████| 11/11 [00:09<00:00,  1.10it/s]


| VAL-Avg. | loss : 0.471, AUROC : 0.748, AUPRC : 0.406
*** Best Record! ***
- Improved loss: inf → 0.4712
- Improved AUROC: -1.0000 → 0.7477
- Improved AUPRC: -1.0000 → 0.4056
Saved at /home/bgk/macai-model-experimental/_kimbg_code/SynAI_v2/outputs/checkpoint/CL_step3-TRN-MacAI_v1_2-0822/0fold/Best_loss-Ep_0-Lo_0.471-M0_0.748-M1_0.406.pth
Saved at /home/bgk/macai-model-experimental/_kimbg_code/SynAI_v2/outputs/checkpoint/CL_step3-TRN-MacAI_v1_2-0822/0fold/Best_AUROC-Ep_0-Lo_0.471-M0_0.748-M1_0.406.pth
Saved at /home/bgk/macai-model-experimental/_kimbg_code/SynAI_v2/outputs/checkpoint/CL_step3-TRN-MacAI_v1_2-0822/0fold/Best_AUPRC-Ep_0-Lo_0.471-M0_0.748-M1_0.406.pth
[EarlyStopping] (Update) Best Score: 0.74775

||| Current epoch: 1 |||


| TRN | batch loss : 0.3952, lr : 2.0e-04 : 100%|██████████| 97/97 [02:01<00:00,  1.25s/it]


| TRN-Avg. | loss : 0.419, AUROC : 0.75, AUPRC : 0.35


| VAL | batch loss: 0.5085, : 100%|██████████| 11/11 [00:09<00:00,  1.12it/s]


| VAL-Avg. | loss : 0.462, AUROC : 0.754, AUPRC : 0.417
*** Best Record! ***
- Improved loss: 0.4712 → 0.4619
- Improved AUROC: 0.7477 → 0.7544
- Improved AUPRC: 0.4056 → 0.4170
Saved at /home/bgk/macai-model-experimental/_kimbg_code/SynAI_v2/outputs/checkpoint/CL_step3-TRN-MacAI_v1_2-0822/0fold/Best_loss-Ep_1-Lo_0.462-M0_0.754-M1_0.417.pth
Saved at /home/bgk/macai-model-experimental/_kimbg_code/SynAI_v2/outputs/checkpoint/CL_step3-TRN-MacAI_v1_2-0822/0fold/Best_AUROC-Ep_1-Lo_0.462-M0_0.754-M1_0.417.pth
Saved at /home/bgk/macai-model-experimental/_kimbg_code/SynAI_v2/outputs/checkpoint/CL_step3-TRN-MacAI_v1_2-0822/0fold/Best_AUPRC-Ep_1-Lo_0.462-M0_0.754-M1_0.417.pth
[EarlyStopping] (Update) Best Score: 0.75442

||| Current epoch: 2 |||


| TRN | batch loss : 0.4053, lr : 2.0e-04 : 100%|██████████| 97/97 [01:49<00:00,  1.13s/it]


| TRN-Avg. | loss : 0.412, AUROC : 0.77, AUPRC : 0.38


| VAL | batch loss: 0.5056, : 100%|██████████| 11/11 [00:08<00:00,  1.26it/s]


| VAL-Avg. | loss : 0.462, AUROC : 0.758, AUPRC : 0.421
*** Best Record! ***
- Improved loss: 0.4619 → 0.4617
- Improved AUROC: 0.7544 → 0.7581
- Improved AUPRC: 0.4170 → 0.4214
Saved at /home/bgk/macai-model-experimental/_kimbg_code/SynAI_v2/outputs/checkpoint/CL_step3-TRN-MacAI_v1_2-0822/0fold/Best_loss-Ep_2-Lo_0.462-M0_0.758-M1_0.421.pth
Saved at /home/bgk/macai-model-experimental/_kimbg_code/SynAI_v2/outputs/checkpoint/CL_step3-TRN-MacAI_v1_2-0822/0fold/Best_AUROC-Ep_2-Lo_0.462-M0_0.758-M1_0.421.pth
Saved at /home/bgk/macai-model-experimental/_kimbg_code/SynAI_v2/outputs/checkpoint/CL_step3-TRN-MacAI_v1_2-0822/0fold/Best_AUPRC-Ep_2-Lo_0.462-M0_0.758-M1_0.421.pth
[EarlyStopping] (Update) Best Score: 0.75810

||| Current epoch: 3 |||


| TRN | batch loss : 0.3902, lr : 2.0e-04 : 100%|██████████| 97/97 [01:26<00:00,  1.12it/s]


| TRN-Avg. | loss : 0.407, AUROC : 0.79, AUPRC : 0.42


| VAL | batch loss: 0.5086, : 100%|██████████| 11/11 [00:09<00:00,  1.14it/s]


| VAL-Avg. | loss : 0.463, AUROC : 0.754, AUPRC : 0.416
[EarlyStopping] (Patience) 1/5, Best: 0.75810, Current: 0.75375, Delta: 0.00435

||| Current epoch: 4 |||


| TRN | batch loss : 0.4045, lr : 1.9e-04 : 100%|██████████| 97/97 [01:28<00:00,  1.09it/s]


| TRN-Avg. | loss : 0.401, AUROC : 0.80, AUPRC : 0.47


| VAL | batch loss: 0.5146, : 100%|██████████| 11/11 [00:09<00:00,  1.19it/s]


| VAL-Avg. | loss : 0.475, AUROC : 0.755, AUPRC : 0.415
[EarlyStopping] (Patience) 2/5, Best: 0.75810, Current: 0.75522, Delta: 0.00288

||| Current epoch: 5 |||


| TRN | batch loss : 0.3954, lr : 1.9e-04 : 100%|██████████| 97/97 [01:29<00:00,  1.09it/s]


| TRN-Avg. | loss : 0.393, AUROC : 0.82, AUPRC : 0.51


| VAL | batch loss: 0.5579, : 100%|██████████| 11/11 [00:10<00:00,  1.03it/s]


| VAL-Avg. | loss : 0.492, AUROC : 0.740, AUPRC : 0.392
[EarlyStopping] (Patience) 3/5, Best: 0.75810, Current: 0.73977, Delta: 0.01832

||| Current epoch: 6 |||


| TRN | batch loss : 0.3751, lr : 1.8e-04 : 100%|██████████| 97/97 [01:27<00:00,  1.11it/s]


| TRN-Avg. | loss : 0.383, AUROC : 0.84, AUPRC : 0.56


| VAL | batch loss: 0.5311, : 100%|██████████| 11/11 [00:08<00:00,  1.26it/s]


| VAL-Avg. | loss : 0.491, AUROC : 0.736, AUPRC : 0.386
[EarlyStopping] (Patience) 4/5, Best: 0.75810, Current: 0.73625, Delta: 0.02184

||| Current epoch: 7 |||


| TRN | batch loss : 0.3635, lr : 1.7e-04 : 100%|██████████| 97/97 [01:26<00:00,  1.12it/s]


| TRN-Avg. | loss : 0.373, AUROC : 0.86, AUPRC : 0.60


| VAL | batch loss: 0.5371, : 100%|██████████| 11/11 [00:09<00:00,  1.10it/s]


| VAL-Avg. | loss : 0.496, AUROC : 0.737, AUPRC : 0.387
[EarlyStopping] (Patience) 5/5, Best: 0.75810, Current: 0.73711, Delta: 0.02099
[EarlyStop Triggered] Best Score: 0.75810
Early stopping triggered.


### Evaluate at Test Dataset

In [7]:
def load_models(model_paths, device, model_fn, model_info):
    models = []
    for path in sorted(model_paths):
        ckp = torch.load(path, map_location=device)
        model = model_fn(model_info).to(device)
        model.load_state_dict(ckp['model_state_dict'])
        models.append({'path': path, 'model': model})
        print(f"Loaded: {path}")
    return models

models_dict = {}
for fold_idx, target_fold in enumerate(CFG.target_fold_list):
    model_fold_path = os.path.join(CFG.model_save_path, f'{target_fold}fold')
    if CFG.model_auto_remove:
        model_paths = sorted(glob(os.path.join(model_fold_path, 'Be*.pth')))
    else:
        model_paths = sorted(glob(os.path.join(model_fold_path, '*loss*.pth')))[-1:]
    models = load_models(model_paths, CFG.device, CFG.model_architecture, CFG.model_info)
    models_dict[fold_idx] = models

  ckp = torch.load(path, map_location=device)


model name : MAE_1D_250409_v3
seq_length 2560, in_channels 12, patch_size 32, embed_dim 768, token_len 80, 
model was loaded from '/home/bgk/macai-model-experimental/checkpoint/Step1-TRN_VAL-0526/Best_loss-Ep_550-Lo_0.0020.pth'
Loaded: /home/bgk/macai-model-experimental/_kimbg_code/SynAI_v2/outputs/checkpoint/CL_step3-TRN-MacAI_v1_2-0822/0fold/Best_AUPRC-Ep_2-Lo_0.462-M0_0.758-M1_0.421.pth
model name : MAE_1D_250409_v3
seq_length 2560, in_channels 12, patch_size 32, embed_dim 768, token_len 80, 
model was loaded from '/home/bgk/macai-model-experimental/checkpoint/Step1-TRN_VAL-0526/Best_loss-Ep_550-Lo_0.0020.pth'
Loaded: /home/bgk/macai-model-experimental/_kimbg_code/SynAI_v2/outputs/checkpoint/CL_step3-TRN-MacAI_v1_2-0822/0fold/Best_AUROC-Ep_2-Lo_0.462-M0_0.758-M1_0.421.pth
model name : MAE_1D_250409_v3
seq_length 2560, in_channels 12, patch_size 32, embed_dim 768, token_len 80, 
model was loaded from '/home/bgk/macai-model-experimental/checkpoint/Step1-TRN_VAL-0526/Best_loss-Ep_550-L

In [8]:
def infer_and_collect_probs(models, dataloader, device):
    probs_list = []
    for model_dict in models:
        logits, all_labels = train_engine.only_inference(model_dict['model'], dataloader, device)
        probs = torch.sigmoid(logits).cpu().numpy()
        name = os.path.basename(model_dict['path']).split('-')[0].split('_')[-1]
        probs_list.append({'name': name, 'probs': probs})
    return probs_list, all_labels

#### Validation set

In [46]:
if 'val' in dataloader_dict.keys():
    target_dataloader = dataloader_dict['val']
    tmp_df = df_dict['val'].copy()

    models = []
    for k,v in models_dict.items():
        models += v
    probs_list, labels = infer_and_collect_probs(models, target_dataloader, CFG.device)
    ensemble_probs = np.mean([d['probs'] for d in probs_list], axis=0)
    prob_col = [i+'-prob' for i in CFG.target_cols['label']]
    tmp_df[prob_col] = ensemble_probs
    tmp_df.to_csv(os.path.join(CFG.model_save_path, 'EUMC-val-prob.csv'))

    metrics = utils.classification_metrics(
        y                = labels, 
        y_pred           = ensemble_probs, 
        activation_fn    = False, 
        mode             = 'multilabel',
        threshold_method = 'youden'
    )
    metrics_df = pd.DataFrame(metrics).T
    display(metrics_df)
    metrics_df.index = CFG.target_cols['label'] + ['avg']
    sub_idx = metrics_df.index[metrics_df.index.str.contains('14d')]
    display(pd.DataFrame(metrics_df.loc[sub_idx, :].mean()).round(4).T)

else:
    print('No val set')

100%|██████████| 11/11 [00:09<00:00,  1.21it/s]
100%|██████████| 11/11 [00:09<00:00,  1.11it/s]
100%|██████████| 11/11 [00:09<00:00,  1.16it/s]


Unnamed: 0,AUROC,AUPRC,opt_thr,Accu,Sens,Prec,Spec,NPV,F1,TN,FP,FN,TP,Total_N
0,0.766292,0.293886,0.152588,0.70695,0.699422,0.201108,0.707742,0.957237,0.312392,3492.0,1442.0,156.0,363.0,5453.0
1,0.749905,0.548998,0.26123,0.674858,0.732589,0.449416,0.652506,0.863055,0.557082,2565.0,1366.0,407.0,1115.0,5453.0
avg,0.758098,0.421442,,0.690904,0.716005,0.325262,0.680124,0.910146,0.434737,,,,,


Unnamed: 0,AUROC,AUPRC,opt_thr,Accu,Sens,Prec,Spec,NPV,F1,TN,FP,FN,TP,Total_N
0,0.7581,0.4214,0.2069,0.6909,0.716,0.3253,0.6801,0.9101,0.4347,3028.5,1404.0,281.5,739.0,5453.0


#### Holdout Set

In [47]:
if 'test' in dataloader_dict.keys():
    target_dataloader = dataloader_dict['test']
    tmp_df = df_dict['test'].copy()

    models = []
    for k,v in models_dict.items():
        models += v
    probs_list, labels = infer_and_collect_probs(models, target_dataloader, CFG.device)
    ensemble_probs = np.mean([d['probs'] for d in probs_list], axis=0)
    prob_col = [i+'-prob' for i in CFG.target_cols['label']]
    tmp_df[prob_col] = ensemble_probs
    tmp_df.to_csv(os.path.join(CFG.model_save_path, 'EUMC-test-prob.csv'))

    metrics = utils.classification_metrics(
        y                = labels, 
        y_pred           = ensemble_probs, 
        activation_fn    = False, 
        mode             = 'multilabel',
        threshold_method = 'youden'
    )
    metrics_df = pd.DataFrame(metrics).T
    display(metrics_df)
    metrics_df.index = CFG.target_cols['label'] + ['avg']
    sub_idx = metrics_df.index[metrics_df.index.str.contains('14d')]
    display(pd.DataFrame(metrics_df.loc[sub_idx, :].mean()).round(4).T)

else:
    print('No test set')

  0%|          | 0/30 [00:00<?, ?it/s]

100%|██████████| 30/30 [00:24<00:00,  1.23it/s]
100%|██████████| 30/30 [00:23<00:00,  1.28it/s]
100%|██████████| 30/30 [00:23<00:00,  1.26it/s]


Unnamed: 0,AUROC,AUPRC,opt_thr,Accu,Sens,Prec,Spec,NPV,F1,TN,FP,FN,TP,Total_N
0,0.772515,0.276545,0.160767,0.707834,0.700581,0.193186,0.708557,0.959608,0.302859,9788.0,4026.0,412.0,964.0,15190.0
1,0.738613,0.55314,0.350342,0.712377,0.593247,0.504237,0.761158,0.820464,0.545133,8203.0,2574.0,1795.0,2618.0,15190.0
avg,0.755564,0.414842,,0.710105,0.646914,0.348712,0.734857,0.890036,0.423996,,,,,


Unnamed: 0,AUROC,AUPRC,opt_thr,Accu,Sens,Prec,Spec,NPV,F1,TN,FP,FN,TP,Total_N
0,0.7556,0.4148,0.2556,0.7101,0.6469,0.3487,0.7349,0.89,0.424,8995.5,3300.0,1103.5,1791.0,15190.0


#### Pub Set

In [48]:
if 'pub_test' in dataloader_dict.keys():
    target_dataloader = dataloader_dict['pub_test']
    tmp_df = df_dict['pub_test'].copy()

    models = []
    for k,v in models_dict.items():
        models += v
    probs_list, labels = infer_and_collect_probs(models, target_dataloader, CFG.device)
    ensemble_probs = np.mean([d['probs'] for d in probs_list], axis=0)
    prob_col = [i+'-prob' for i in CFG.target_cols['label']]
    tmp_df[prob_col] = ensemble_probs
    tmp_df.to_csv(os.path.join(CFG.model_save_path, 'BID-test-prob.csv'))

    metrics = utils.classification_metrics(
        y                = labels, 
        y_pred           = ensemble_probs, 
        activation_fn    = False, 
        mode             = 'multilabel',
        threshold_method = 'youden'
    )
    metrics_df = pd.DataFrame(metrics).T
    display(metrics_df)
    metrics_df.index = CFG.target_cols['label'] + ['avg']
    sub_idx = metrics_df.index[metrics_df.index.str.contains('14d')]
    display(pd.DataFrame(metrics_df.loc[sub_idx, :].mean()).round(4).T)

else:
    print('No pub_test set')

  0%|          | 0/40 [00:00<?, ?it/s]

100%|██████████| 40/40 [00:36<00:00,  1.11it/s]
100%|██████████| 40/40 [00:30<00:00,  1.33it/s]
100%|██████████| 40/40 [00:31<00:00,  1.28it/s]


Unnamed: 0,AUROC,AUPRC,opt_thr,Accu,Sens,Prec,Spec,NPV,F1,TN,FP,FN,TP,Total_N
0,0.731436,0.262902,0.139771,0.651695,0.697822,0.195699,0.646001,0.945414,0.305674,11535.0,6321.0,666.0,1538.0,20060.0
1,0.723555,0.440061,0.306152,0.672682,0.655604,0.389847,0.67804,0.862534,0.488948,10353.0,4916.0,1650.0,3141.0,20060.0
avg,0.727495,0.351481,,0.662188,0.676713,0.292773,0.662021,0.903974,0.397311,,,,,


Unnamed: 0,AUROC,AUPRC,opt_thr,Accu,Sens,Prec,Spec,NPV,F1,TN,FP,FN,TP,Total_N
0,0.7275,0.3515,0.223,0.6622,0.6767,0.2928,0.662,0.904,0.3973,10944.0,5618.5,1158.0,2339.5,20060.0


#### SNUH Set

In [49]:
if 'snuh_test' in dataloader_dict.keys():
    target_dataloader = dataloader_dict['snuh_test']
    tmp_df = df_dict['snuh_test'].copy()

    models = []
    for k,v in models_dict.items():
        models += v
    probs_list, labels = infer_and_collect_probs(models, target_dataloader, CFG.device)
    ensemble_probs = np.mean([d['probs'] for d in probs_list], axis=0)
    prob_col = [i+'-prob' for i in CFG.target_cols['label']]
    tmp_df[prob_col] = ensemble_probs
    tmp_df.to_csv(os.path.join(CFG.model_save_path, 'SNUH-test-prob.csv'))

    metrics = utils.classification_metrics(
        y                = labels, 
        y_pred           = ensemble_probs, 
        activation_fn    = False, 
        mode             = 'multilabel',
        threshold_method = 'youden'
    )
    metrics_df = pd.DataFrame(metrics).T
    display(metrics_df)
    metrics_df.index = CFG.target_cols['label'] + ['avg']
    sub_idx = metrics_df.index[metrics_df.index.str.contains('14d')]
    display(pd.DataFrame(metrics_df.loc[sub_idx, :].mean()).round(4).T)

else:
    print('No snuh_test set')

  0%|          | 0/63 [00:00<?, ?it/s]

100%|██████████| 63/63 [00:46<00:00,  1.37it/s]
100%|██████████| 63/63 [00:44<00:00,  1.42it/s]
100%|██████████| 63/63 [00:48<00:00,  1.31it/s]


Unnamed: 0,AUROC,AUPRC,opt_thr,Accu,Sens,Prec,Spec,NPV,F1,TN,FP,FN,TP,Total_N
0,0.754971,0.310331,0.164551,0.69499,0.678884,0.225636,0.697084,0.943492,0.338701,19602.0,8518.0,1174.0,2482.0,31776.0
1,0.726095,0.500828,0.30249,0.653481,0.702816,0.415274,0.635299,0.852955,0.522071,14751.0,8468.0,2543.0,6014.0,31776.0
avg,0.740533,0.40558,,0.674235,0.69085,0.320455,0.666191,0.898224,0.430386,,,,,


Unnamed: 0,AUROC,AUPRC,opt_thr,Accu,Sens,Prec,Spec,NPV,F1,TN,FP,FN,TP,Total_N
0,0.7405,0.4056,0.2335,0.6742,0.6909,0.3205,0.6662,0.8982,0.4304,17176.5,8493.0,1858.5,4248.0,31776.0
