## Load libraries

In [None]:
%load_ext autoreload
%autoreload 2
import gc
import numpy as np
import os
import pickle
import random
import warnings
from math import sqrt
from datetime import datetime
from tqdm import tqdm
import ml_insights as mli
import json
import xml.etree.ElementTree as ET
import matplotlib.pyplot as plt
import matplotlib as mpl
from matplotlib.colors import LinearSegmentedColormap
import joblib
import h5py
# Data processing
import pandas as pd
import neurokit2 as nk
from joblib import Parallel, delayed

# Machine learning
from sklearnex import patch_sklearn
patch_sklearn()  

from sklearn.metrics import (
    average_precision_score, roc_curve, 
    roc_auc_score, accuracy_score, f1_score
)
from sklearn.model_selection import GroupKFold, train_test_split
from sklearn.utils import shuffle
from sklearn.calibration import IsotonicRegression
from sklearn.metrics import brier_score_loss
from sklearn.linear_model import LinearRegression
from sklearn.metrics import confusion_matrix,roc_curve, auc
from scipy.stats import chi2
from scipy.signal import resample

# Deep learning
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau

# ML
import xgboost as xgb
from sklearn.ensemble import RandomForestClassifier
from sklearn.svm import SVC
from sklearn.linear_model import LogisticRegression

# Visualization
import shap
import optuna
from BorutaShap import BorutaShap
# Custom imports
from model_code import *
from model.blocks import FinalModel
from team_code import *
from helper_code import *
from plot_model import *
from delong import *

# ---- Session-wide settings ------------------------------------------------
SEED = 1
warnings.filterwarnings("ignore")  # NOTE: silences *every* warning for the rest of the session
DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'
shap.initjs()  # initialise SHAP's notebook (JavaScript) visualisation support

# ---- Search / training budget ---------------------------------------------
N_TRIAL, N_FOLD, N_EPOCH = 50, 5, 100  # random-search trials, CV folds, max epochs per fold

# ---- Training hyper-parameter search space --------------------------------
# Bounds for values sampled per trial in the tuning cell; the sampled LR
# "step size" and "gamma" become ReduceLROnPlateau's patience and factor there.
BATCH_SIZE_LIST_GLOBAL = [8, 16]
LR_INITIAL_LO, LR_INITIAL_HI = 5e-4, 5e-3
LR_STEP_SIZE_LO, LR_STEP_SIZE_HI = 2, 3
LR_STEP_GAMMA_LO, LR_STEP_GAMMA_HI = 0.05, 0.3
EARLY_STOP_PATIENCE_LO, EARLY_STOP_PATIENCE_HI = 3, 4

# ---- Architecture search space (choices passed to FinalModel) -------------
BLOCK_SIZE_GLOBAL = [12, 16, 24]
BLOCK_DEPTH_GLOBAL = [2, 3]
BLOCK_LAYERS_GLOBAL = [3, 4]
HIDDEN_SIZE_GLOBAL = [32, 64, 128]
KERNEL_NUM_GLOBAL = [5, 7, 9]
# Batch collation for the ECG DataLoaders.
def collate_fn(batch):
    """Collate dataset samples into float32 tensors ``(inputs, targets, age, gender)``.

    Each sample is read positionally: element 0 -> model input, element 1 ->
    target vector, element 3 -> age, element 4 -> gender. Element 2 is ignored.
    ``inputs`` and ``targets`` are stacked along a new leading batch axis;
    ``age`` and ``gender`` are returned with shape ``(batch, 1)``.
    """
    def _float32(value):
        return torch.tensor(value, dtype=torch.float32)

    inputs = torch.stack([_float32(sample[0]) for sample in batch])
    targets = torch.stack([_float32(sample[1]) for sample in batch])
    age, gender = (
        torch.tensor([[sample[pos]] for sample in batch], dtype=torch.float32)
        for pos in (3, 4)
    )
    return inputs, targets, age, gender

## Youden index
def youden(y_true, y_score):
    """Return the ROC threshold that maximises Youden's J statistic (TPR - FPR).

    Thresholds come from ``roc_curve(y_true, y_score)``; on ties the first
    maximising threshold is returned (``np.argmax`` semantics).
    """
    false_pos, true_pos, cutoffs = roc_curve(y_true, y_score)
    j_statistic = true_pos - false_pos
    best = np.argmax(j_statistic)
    return cutoffs[best]

# Function to train the model for one epoch (batch-by-batch, no prefetch).
def train_model(train_loader, model, criterion, optimizer, scheduler, device=DEVICE):
    """Run one training epoch and return the mean per-batch loss.

    Parameters
    ----------
    train_loader : iterable with ``len()``
        Yields ``(inputs, labels, age, gender)`` batches (as built by ``collate_fn``).
    model : torch.nn.Module
        Called as ``model(inputs, age, gender)``; its output is passed to ``criterion``.
    criterion : callable
        Loss applied to ``(outputs, labels[:, 1:2])``.
    optimizer : torch.optim.Optimizer
        Stepped once per batch.
    scheduler : object
        Not used inside this function; accepted for call-site compatibility
        (the caller steps it on validation AUROC).
    device : str or torch.device
        Device every batch tensor is moved to.

    Returns
    -------
    float
        Mean of ``loss.item()`` over all batches.

    Raises
    ------
    ValueError
        If ``train_loader`` yields no batches.
    """
    n_batch = len(train_loader)
    if n_batch == 0:
        raise ValueError("train_loader has no batches; cannot compute a mean training loss")

    model.train()
    running_loss = 0.0

    # cuDNN autotuning speeds up fixed-shape convolutions but makes algorithm
    # selection non-deterministic, overriding seed_everything(). Enable it for this
    # epoch only and restore whatever value was in effect before, even on error.
    prev_benchmark = torch.backends.cudnn.benchmark
    torch.backends.cudnn.benchmark = True
    try:
        for inputs, labels, age, gender in tqdm(train_loader, total=n_batch, leave=False):
            inputs = inputs.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)
            age = age.to(device, non_blocking=True)
            gender = gender.to(device, non_blocking=True)

            optimizer.zero_grad(set_to_none=True)  # cheaper than writing zeros
            outputs = model(inputs, age, gender)
            # Only label column 1 is the training target; keep it 2-D (batch, 1).
            target = labels[:, 1].unsqueeze(1)
            loss = criterion(outputs, target)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            # Drop per-batch device tensors before the next batch is fetched.
            del outputs, loss, target
    finally:
        torch.backends.cudnn.benchmark = prev_benchmark

    return running_loss / n_batch

def evaluate_model(valid_loader, model, device=DEVICE):
    """Run ``model`` over ``valid_loader`` without gradients and gather its outputs.

    The model is put in eval mode and ``model.return_hidden`` is set to True for
    the duration of the call, so ``model(x, age, gender)`` is expected to return
    ``(logits, hidden)``. The flag is reset to False afterwards, even if an error
    is raised.

    Parameters
    ----------
    valid_loader : iterable with ``len()``
        Yields ``(X, labels, age, gender)`` batches.
    model : torch.nn.Module
    device : str or torch.device
        Device the inputs are moved to.

    Returns
    -------
    y_true : np.ndarray, shape (n, 1)
        Column 1 of the label tensor.
    y_pred : np.ndarray, shape (n, 1)
        Sigmoid of the model's first output.
    hidden_states : np.ndarray
        Per-batch hidden outputs concatenated along axis 0.
    All three are ``None`` when the loader yields no batches.
    """
    model.eval()
    model.return_hidden = True  # ask the model to also return its hidden state
    true_parts, pred_parts, hidden_parts = [], [], []
    n_batch = len(valid_loader)
    pbar = tqdm(enumerate(valid_loader), total=n_batch, leave=False)

    try:
        with torch.no_grad():
            for i_batch, (X_batch, y_true_batch, age_batch, gender_batch) in pbar:
                X_batch = X_batch.to(device)
                age_batch = age_batch.to(device)
                gender_batch = gender_batch.to(device)

                y_pred_batch, hidden_batch = model(X_batch, age_batch, gender_batch)

                # Collect per-batch arrays and concatenate once at the end;
                # growing an array with np.r_ inside the loop is quadratic.
                pred_parts.append(torch.sigmoid(y_pred_batch).cpu().numpy().reshape((-1, 1)))
                true_parts.append(y_true_batch[:, 1].cpu().numpy().reshape((-1, 1)))
                hidden_parts.append(hidden_batch.cpu().numpy())

                pbar.set_description(f'Evaluating ... {1 + i_batch}/{n_batch}')
    finally:
        model.return_hidden = False  # reset hidden state return

    if not pred_parts:
        return None, None, None
    return (
        np.concatenate(true_parts, axis=0),
        np.concatenate(pred_parts, axis=0),
        np.concatenate(hidden_parts, axis=0),
    )

def loguniform(low, high, size=None):
    """Draw samples whose logarithm is uniform between log(low) and log(high).

    Uses NumPy's global RNG via ``np.random.uniform``; ``size`` follows that
    function (``None`` returns a scalar).
    """
    log_low, log_high = np.log(low), np.log(high)
    exponents = np.random.uniform(log_low, log_high, size)
    return np.exp(exponents)

def calc_unreliability(y_true, y_prob, g=10):
    """
    Compute a Hosmer-Lemeshow-type calibration statistic and its p-value.

    Predictions are split into (up to) ``g`` quantile bins of ``y_prob``. In each
    bin the observed event count O is compared with the expected count
    E = n * mean(p), and the statistic is sum((O - E)^2 / (n * p * (1 - p))).

    Parameters
    ----------
    y_true : array-like
        Binary outcomes (0/1). Any shape; it is flattened, so the (n, 1) arrays
        returned by ``evaluate_model`` can be passed directly.
    y_prob : array-like
        Predicted probabilities, with the same number of elements as ``y_true``.
    g : int, default 10
        Requested number of quantile bins. Duplicate quantile edges are dropped,
        so fewer bins may actually be used.

    Returns
    -------
    hl_stat : float
    p_value : float
        Upper-tail chi-square probability with (n_bins - 2) degrees of freedom,
        or NaN when fewer than 3 non-empty bins remain.

    Raises
    ------
    ValueError
        If ``y_true`` and ``y_prob`` have different numbers of elements.
    """
    y_true = np.asarray(y_true, dtype=float).ravel()
    y_prob = np.asarray(y_prob, dtype=float).ravel()
    if y_true.size != y_prob.size:
        raise ValueError(
            f"y_true has {y_true.size} elements but y_prob has {y_prob.size}"
        )

    # Clip away exact 0/1 so the per-bin variance n*p*(1-p) is never zero.
    y_prob = np.clip(y_prob, 1e-12, 1 - 1e-12)

    df = pd.DataFrame({'y': y_true, 'p': y_prob})
    df['bin'] = pd.qcut(df['p'], q=g, duplicates='drop')  # drop duplicate edges

    # observed=True keeps only bins that actually occur (and avoids pandas'
    # categorical-groupby FutureWarning).
    bin_data = df.groupby('bin', observed=True, dropna=True).agg(
        y_sum=('y', 'sum'),
        y_count=('y', 'count'),
        p_mean=('p', 'mean'),
    )
    bin_data = bin_data[bin_data['y_count'] > 0]

    O = bin_data['y_sum']     # observed events
    N = bin_data['y_count']   # bin size
    P = bin_data['p_mean']    # mean predicted probability in bin
    E = N * P                 # expected events

    hl_stat = float(np.sum((O - E) ** 2 / (N * P * (1 - P))))

    # Degrees of freedom: number of non-empty bins minus 2.
    dof = len(bin_data) - 2
    # chi2.sf(x) == 1 - chi2.cdf(x) but without cancellation for large statistics.
    p_value = float(chi2.sf(hl_stat, dof)) if dof > 0 else np.nan

    return hl_stat, p_value

def seed_everything(seed):
    """Seed every RNG used in this notebook and request deterministic cuDNN.

    Seeds Python's ``random``, NumPy's global RNG and torch (CPU, current CUDA
    device and all CUDA devices), turns cuDNN deterministic mode on and
    autotuning off, and exports ``PYTHONHASHSEED``.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic, cudnn.benchmark = True, False

    os.environ['PYTHONHASHSEED'] = str(seed)

def create_objective_function(model_type, model_class, param_grid_fn, features, model_name):
    """Build an Optuna objective that scores a classifier by patient-grouped CV ROC AUC.

    Parameters
    ----------
    model_type : str
        Tag used in the output file name; ``'xgb'`` additionally passes
        ``verbose=False`` to ``fit``.
    model_class : callable
        Estimator constructor, called as ``model_class(**param_grid)``; the
        estimator must provide ``fit`` and ``predict_proba``.
    param_grid_fn : callable
        ``param_grid_fn(trial) -> dict`` of constructor kwargs.
    features : list of str
        Columns of ``df_train`` used as model inputs.
    model_name : str
        Second tag used in the output file name.

    Returns
    -------
    callable
        ``objective(trial, df_train, cv=N_FOLD) -> float`` returning the mean
        validation ROC AUC over ``cv`` folds. ``df_train`` must have ``'hid'``
        (patient id) and ``'label'`` columns. Whenever a trial beats the
        study's best value so far (or no trial has completed yet), its
        parameters are written to ``best_params_<model_type>_<model_name>.json``.
    """
    params_path = f'best_params_{model_type}_{model_name}.json'

    def _save_params(param_grid):
        """Write the trial's parameter dict to ``params_path`` as indented JSON."""
        with open(params_path, 'w') as f:
            json.dump(param_grid, f, indent=4)

    def objective(trial, df_train, cv=N_FOLD):
        param_grid = param_grid_fn(trial)
        h = df_train['hid'].unique()
        groups = np.arange(len(h))
        h_shuffled, groups_shuffled = shuffle(h, groups, random_state=SEED)
        group_kf = GroupKFold(n_splits=cv)
        cv_scores = np.empty(cv)
        for idx, (train_h, val_h) in enumerate(group_kf.split(h_shuffled, groups=groups_shuffled)):
            # Build each patient mask once and reuse it for X and y.
            train_mask = df_train['hid'].isin(h_shuffled[train_h])
            valid_mask = df_train['hid'].isin(h_shuffled[val_h])
            X_train = df_train.loc[train_mask, features]
            X_valid = df_train.loc[valid_mask, features]
            y_train = df_train.loc[train_mask, 'label'].values
            y_valid = df_train.loc[valid_mask, 'label'].values

            model = model_class(**param_grid)
            if model_type == 'xgb':
                model.fit(X_train, y_train, verbose=False)
            else:
                model.fit(X_train, y_train)
            cv_scores[idx] = roc_auc_score(y_valid, model.predict_proba(X_valid)[:, 1])
        mean_cv_score = np.mean(cv_scores)

        try:
            is_best = mean_cv_score > trial.study.best_value
        except ValueError:
            # Optuna raises ValueError when no trial has completed yet,
            # so this trial is the best one so far.
            is_best = True
        if is_best:
            _save_params(param_grid)
        return mean_cv_score

    return objective

def xgb_params(trial):
    """Sample an XGBoost configuration from an Optuna ``trial``.

    Returns a kwargs dict for the estimator constructor (used as
    ``model_class(**params)`` by ``create_objective_function``). Uses
    ``suggest_float(..., log=True / step=...)`` and keyword ``step`` for
    ``suggest_int``, the non-deprecated equivalents of ``suggest_loguniform``,
    ``suggest_discrete_uniform`` and positional steps.

    NOTE(review): ``tweedie_variance_power`` only affects the ``reg:tweedie``
    objective, so with ``binary:logistic`` it is an inert search dimension; it is
    kept so existing studies and saved best-params files stay comparable.
    ``tree_method='gpu_hist'`` / ``predictor='gpu_predictor'`` are deprecated in
    XGBoost >= 2.0 in favour of ``tree_method='hist', device='cuda'``.
    """
    return {
        'objective': 'binary:logistic',
        'tree_method': 'gpu_hist',
        'predictor': 'gpu_predictor',
        'tweedie_variance_power': trial.suggest_float('tweedie_variance_power', 1.0, 2.0, step=0.1),
        'max_depth': trial.suggest_int('max_depth', 3, 8),
        'n_estimators': trial.suggest_int('n_estimators', 200, 1600, step=200),
        'eta': trial.suggest_float('eta', 0.005, 0.05),
        'subsample': trial.suggest_float('subsample', 0.5, 0.9),
        'colsample_bytree': trial.suggest_float('colsample_bytree', 0.5, 0.9),
        'colsample_bylevel': trial.suggest_float('colsample_bylevel', 0.5, 0.9),
        'min_child_weight': trial.suggest_float('min_child_weight', 1, 1e3, log=True),
        'reg_lambda': trial.suggest_float('reg_lambda', 1, 1e3, log=True),
        'reg_alpha': trial.suggest_float('reg_alpha', 0.1, 1e3, log=True),
        'gamma': trial.suggest_float('gamma', 0.01, 1e2, log=True),
        'random_state': SEED,
        'enable_categorical': True
    }

def rf_params(trial):
    """Sample a random-forest configuration from an Optuna ``trial``.

    Returns constructor kwargs (e.g. for ``RandomForestClassifier``). The
    ``suggest_int`` step is passed by keyword because positional ``step`` is
    deprecated in recent Optuna releases.
    """
    return {
        'n_estimators': trial.suggest_int('n_estimators', 100, 1000, step=100),
        'max_depth': trial.suggest_int('max_depth', 3, 20),
        'min_samples_split': trial.suggest_int('min_samples_split', 2, 20),
        'min_samples_leaf': trial.suggest_int('min_samples_leaf', 1, 20),
        'max_features': trial.suggest_categorical('max_features', ['sqrt', 'log2', None]),
        'bootstrap': trial.suggest_categorical('bootstrap', [True, False]),
        'random_state': SEED
    }

def svm_params(trial):
    """Sample an SVC configuration from an Optuna ``trial``.

    ``C`` is drawn log-uniformly via ``suggest_float(..., log=True)``, the
    replacement for the deprecated ``suggest_loguniform``. ``probability=True``
    is required because the CV objective calls ``predict_proba``.
    """
    return {
        'C': trial.suggest_float('C', 1e-3, 1e3, log=True),
        'kernel': trial.suggest_categorical('kernel', ['linear', 'rbf', 'poly', 'sigmoid']),
        'gamma': trial.suggest_categorical('gamma', ['scale', 'auto']),
        'probability': True,
        'random_state': SEED
    }

def logreg_params(trial):
    """Sample a logistic-regression configuration from an Optuna ``trial``.

    The "no penalty" option is ``None``: scikit-learn deprecated the string
    ``'none'`` in 1.2 and rejects it from 1.4 on. ``C`` is ignored by
    scikit-learn when ``penalty`` is ``None``.

    NOTE(review): changing the categorical choices means an existing persisted
    study that used ``'none'`` cannot be resumed with this function; start a new
    study.
    """
    return {
        'C': trial.suggest_float('C', 1e-4, 1e2, log=True),
        'penalty': trial.suggest_categorical('penalty', ['l2', None]),
        'solver': trial.suggest_categorical('solver', ['lbfgs', 'saga']),
        'max_iter': 1000,
        'random_state': SEED
    }
def compute_saliency_maps(input_data, input_age, input_gender, model, device=None):
    """Gradient saliency: |d(sum of model outputs) / d(input_data)| per element.

    Parameters
    ----------
    input_data, input_age, input_gender : array-like or torch.Tensor
        Model inputs; converted to float32 tensors on ``device``. The caller's
        objects are not modified.
    model : torch.nn.Module
        Called as ``model(x, age, gender)``; it is switched to eval mode.
    device : str or torch.device, optional
        Defaults to the notebook-wide ``DEVICE``.

    Returns
    -------
    np.ndarray
        Absolute input gradients, same shape as ``input_data``.

    Notes
    -----
    ``torch.autograd.grad`` returns only the input gradient, so the model
    parameters' ``.grad`` buffers are not accumulated into, unlike ``backward()``.
    Gradients are enabled locally, so this also works inside a ``no_grad`` block.
    """
    if device is None:
        device = DEVICE
    # detach().clone() gives a fresh leaf tensor even if the caller passed a tensor.
    input_tensor = torch.as_tensor(input_data, dtype=torch.float32, device=device).detach().clone()
    input_tensor.requires_grad_()
    input_age_tensor = torch.as_tensor(input_age, dtype=torch.float32, device=device)
    input_gender_tensor = torch.as_tensor(input_gender, dtype=torch.float32, device=device)

    model.eval()
    with torch.enable_grad():
        output = model(input_tensor, input_age_tensor, input_gender_tensor)
        (input_grad,) = torch.autograd.grad(
            output, input_tensor, grad_outputs=torch.ones_like(output)
        )
    return input_grad.abs().cpu().numpy()

def extract_date(filename):
    """
    Extract an acquisition date from an ECG file name, formatted "YYYY/MM/DD".

    The directory part and the file extension are ignored. Recognised patterns,
    tried in order:
      1. ``..._YYYYMMDD``           e.g. ``12345_20100511.xml``
      2. ``<id>_YYYY_MM_DD[_...]``  e.g. ``74079672_2010_05_11_07_20_37.xml``
      3. fallback: the first ``'_'``-separated part that is an 8-digit date.
    Candidates that are not real calendar dates are rejected, so e.g. an
    8-digit patient id is not mistaken for a date.

    Returns
    -------
    str or None
        The date string, or None when no valid date is found.
    """
    stem, _ = os.path.splitext(os.path.basename(str(filename)))
    parts = stem.split('_')

    def _valid(year, month, day):
        """True if the three digit strings form a real calendar date."""
        try:
            datetime(int(year), int(month), int(day))
        except ValueError:
            return False
        return True

    def _from_yyyymmdd(token):
        """Format an 8-digit YYYYMMDD token, or return None if it is not a valid date."""
        if len(token) == 8 and token.isdigit() and _valid(token[:4], token[4:6], token[6:8]):
            return f"{token[:4]}/{token[4:6]}/{token[6:8]}"
        return None

    # Case 1: ..._YYYYMMDD (date is the last part)
    if len(parts) >= 2:
        date = _from_yyyymmdd(parts[-1])
        if date is not None:
            return date

    # Case 2: <id>_YYYY_MM_DD_...
    if len(parts) >= 4 and all(p.isdigit() for p in parts[1:4]) and _valid(*parts[1:4]):
        return f"{parts[1]}/{parts[2]}/{parts[3]}"

    # Fallback: any part that is a valid 8-digit date
    for p in parts:
        date = _from_yyyymmdd(p)
        if date is not None:
            return date
    return None

## Train Test split

In [None]:
df = pd.read_csv('data_labels_M_N.csv')

TEST_FRACTION = 0.10

# Hold out ~10% of patients ('hid') as the test set, so that no patient
# contributes ECGs to both the develop and the test data.
unique_hids = df['hid'].unique()
np.random.seed(SEED)
np.random.shuffle(unique_hids)

n_test_hids = int(np.ceil(len(unique_hids) * TEST_FRACTION))
test_hids = set(unique_hids[:n_test_hids])
develop_hids = set(unique_hids[n_test_hids:])

test_df = df[df['hid'].isin(test_hids)].copy()
develop_df = df[df['hid'].isin(develop_hids)].copy()

# Sanity check: develop and test patient sets are disjoint.
assert not (set(test_df['hid']) & set(develop_df['hid'])), "Overlap in 'hid' between test and develop sets!"

# Split the develop set into N_FOLD folds grouped by 'hid' (all ECGs of a
# patient land in the same fold). Later cells load them back from
# fold_<i>_filenames.npy.
gkf = GroupKFold(n_splits=N_FOLD)
folds = [develop_df.iloc[val_idx] for _, val_idx in gkf.split(develop_df, groups=develop_df['hid'])]

# Validate every fold against the test set before writing any file.
for i, fold in enumerate(folds, start=1):
    overlap = set(fold['hid']) & test_hids
    assert not overlap, f"Fold {i} and test set have overlapping hids: {overlap}"

# Save filenames for each fold and for the test set.
for i, fold in enumerate(folds, start=1):
    np.save(f'fold_{i}_filenames.npy', fold['filename'].to_numpy())
np.save('test_filenames.npy', test_df['filename'].to_numpy())

## Hyperparameter tuning

In [None]:
seed_everything(SEED)

# ---------------------------------------------------------------------------
# Study directory: model/<yymmdd>_<NN>, NN = (highest NN already used today) + 1
# ---------------------------------------------------------------------------
study_date = datetime.now().strftime('%y%m%d')
study_folder_prefix = f'{study_date}'
study_num = 0
os.makedirs('model', exist_ok=True)
for folder in sorted(os.listdir('model'), reverse=True):
    if not folder.startswith(study_folder_prefix):
        continue
    ch_idx = len(study_folder_prefix) + 1
    suffix = folder[ch_idx:ch_idx + 2]
    # Skip same-day entries that don't follow <prefix>_<NN> instead of crashing in int().
    if suffix.isdigit():
        study_num = int(suffix) + 1
        break
STUDY_DIR = f'model/{study_folder_prefix}_{study_num:02d}'
os.makedirs(STUDY_DIR, exist_ok=True)

search_space_train = {
    'BATCH_SIZE_LIST'         : BATCH_SIZE_LIST_GLOBAL,
    'LR_INITIAL_LO'           : LR_INITIAL_LO,
    'LR_INITIAL_HI'           : LR_INITIAL_HI,
    'LR_STEP_SIZE_LO'         : LR_STEP_SIZE_LO,
    'LR_STEP_SIZE_HI'         : LR_STEP_SIZE_HI,
    'LR_STEP_GAMMA_LO'        : LR_STEP_GAMMA_LO,
    'LR_STEP_GAMMA_HI'        : LR_STEP_GAMMA_HI,
    'EARLY_STOP_PATIENCE_LO'  : EARLY_STOP_PATIENCE_LO,
    'EARLY_STOP_PATIENCE_HI'  : EARLY_STOP_PATIENCE_HI
}
search_space_arch = {
    'BLOCK_SIZE'   : BLOCK_SIZE_GLOBAL,
    'BLOCK_DEPTH'  : BLOCK_DEPTH_GLOBAL,
    'BLOCK_LAYERS' : BLOCK_LAYERS_GLOBAL,
    'HIDDEN_SIZE'  : HIDDEN_SIZE_GLOBAL,
    'KERNEL_NUM'   : KERNEL_NUM_GLOBAL,
}
SEARCH_SPACE_ARCH_FILENAME  = f'{STUDY_DIR}/search_space_arch.pkl'
SEARCH_SPACE_TRAIN_FILENAME = f'{STUDY_DIR}/search_space_train.pkl'
with open(SEARCH_SPACE_ARCH_FILENAME, 'wb') as f:
    pickle.dump(search_space_arch, f)
with open(SEARCH_SPACE_TRAIN_FILENAME, 'wb') as f:
    pickle.dump(search_space_train, f)

fold_filenames = [np.load(f'fold_{i}_filenames.npy', allow_pickle=True) for i in range(1, N_FOLD + 1)]

for i_trial in range(N_TRIAL):
    # Re-read the search space every trial, so it can be edited on disk while the study runs.
    with open(SEARCH_SPACE_ARCH_FILENAME, 'rb') as f:
        search_space_arch = pickle.load(f)
    with open(SEARCH_SPACE_TRAIN_FILENAME, 'rb') as f:
        search_space_train = pickle.load(f)

    # ---- Sample one configuration -------------------------------------------
    # np.random.randint excludes `high`; the +1 makes the *_HI bounds inclusive.
    # Without it LR_STEP_SIZE and EARLY_STOP_PAT could only take their *_LO value.
    BATCH_SIZE     = int(np.random.choice(search_space_train['BATCH_SIZE_LIST'], size=1)[0])
    LR_INITIAL     = float(loguniform(low=search_space_train['LR_INITIAL_LO'], high=search_space_train['LR_INITIAL_HI'], size=1)[0])
    LR_STEP_SIZE   = int(np.random.randint(low=search_space_train['LR_STEP_SIZE_LO'], high=search_space_train['LR_STEP_SIZE_HI'] + 1, size=1)[0])
    LR_STEP_GAMMA  = float(np.random.uniform(low=search_space_train['LR_STEP_GAMMA_LO'], high=search_space_train['LR_STEP_GAMMA_HI'], size=1)[0])
    EARLY_STOP_PAT = int(np.random.randint(low=search_space_train['EARLY_STOP_PATIENCE_LO'], high=search_space_train['EARLY_STOP_PATIENCE_HI'] + 1, size=1)[0])

    BLOCK_SIZE     = int(np.random.choice(search_space_arch['BLOCK_SIZE'], size=1)[0])
    BLOCK_DEPTH    = int(np.random.choice(search_space_arch['BLOCK_DEPTH'], size=1)[0])
    BLOCK_LAYERS   = int(np.random.choice(search_space_arch['BLOCK_LAYERS'], size=1)[0])
    HIDDEN_SIZE    = int(np.random.choice(search_space_arch['HIDDEN_SIZE'], size=1)[0])
    KERNEL_NUM     = int(np.random.choice(search_space_arch['KERNEL_NUM'], size=1)[0])

    TRIAL_FOLDER = (
        f'batch={BATCH_SIZE}_'
        f'lr={LR_INITIAL:.5f}_step={LR_STEP_SIZE}_gam={LR_STEP_GAMMA:.3f}_pat={EARLY_STOP_PAT}_'
        f'block_size={BLOCK_SIZE}_block_depth={BLOCK_DEPTH}_hidden_size={HIDDEN_SIZE}_'
        f'block_layers={BLOCK_LAYERS}_kernel_num={KERNEL_NUM}'
    )
    TRIAL_DIR = f'{STUDY_DIR}/{TRIAL_FOLDER}'
    os.makedirs(TRIAL_DIR, exist_ok=True)

    hparams_train = {
        'N_FOLD'         : N_FOLD,
        'N_EPOCH'        : N_EPOCH,
        'BATCH_SIZE'     : BATCH_SIZE,
        'LR_INITIAL'     : LR_INITIAL,
        'LR_STEP_SIZE'   : LR_STEP_SIZE,
        'LR_STEP_GAMMA'  : LR_STEP_GAMMA,
        'EARLY_STOP_PAT' : EARLY_STOP_PAT,
    }
    hparams_arch = {
        'BLOCK_SIZE'   : BLOCK_SIZE,
        'BLOCK_DEPTH'  : BLOCK_DEPTH,
        'BLOCK_LAYERS' : BLOCK_LAYERS,
        'HIDDEN_SIZE'  : HIDDEN_SIZE,
        'KERNEL_NUM'   : KERNEL_NUM,
    }
    print(f'RANDOM SEARCH TRIAL {1 + i_trial}/{N_TRIAL}')
    HPARAMS_TRAIN_FILENAME = f'{TRIAL_DIR}/hparams_train.pkl'
    HPARAMS_ARCH_FILENAME  = f'{TRIAL_DIR}/hparams_arch.pkl'
    with open(HPARAMS_TRAIN_FILENAME, 'wb') as f:
        pickle.dump(hparams_train, f)
    with open(HPARAMS_ARCH_FILENAME, 'wb') as f:
        pickle.dump(hparams_arch, f)

    model = FinalModel(block_size=BLOCK_SIZE, block_depth=BLOCK_DEPTH, block_layers=BLOCK_LAYERS,
                       hidden_size=HIDDEN_SIZE, kernel_num=KERNEL_NUM).to(DEVICE)
    init_weights_path = f'{TRIAL_DIR}/initial_weights.pth'
    torch.save(model.state_dict(), init_weights_path)

    # ---- Cross-validation to determine early-stopping points ----------------
    best_val_auroc_list = []
    early_stop_epochs = []
    for i_fold in range(N_FOLD):
        train_filenames = np.concatenate(fold_filenames[:i_fold] + fold_filenames[i_fold + 1:])
        valid_filenames = fold_filenames[i_fold]
        train = dataset(header_files=train_filenames)
        train.num_leads = 12
        train.sample = True
        valid = dataset(header_files=valid_filenames)
        valid.num_leads = 12
        valid.sample = False
        valid.files.reset_index(drop=True, inplace=True)
        train_loader = DataLoader(train, collate_fn=collate_fn, batch_size=BATCH_SIZE, shuffle=True)
        valid_loader = DataLoader(valid, collate_fn=collate_fn, batch_size=BATCH_SIZE, shuffle=False)

        model.load_state_dict(torch.load(init_weights_path, weights_only=True))
        weight_cache = f'{TRIAL_DIR}/weights_fold_{i_fold + 1}.pth'
        criterion = nn.BCEWithLogitsLoss()
        optimizer = Adam(model.parameters(), lr=LR_INITIAL, weight_decay=1e-4)
        scheduler = ReduceLROnPlateau(optimizer, patience=LR_STEP_SIZE, factor=LR_STEP_GAMMA, min_lr=1e-5, mode='max')
        best_val_auroc = 0.0
        early_stop_count = 0
        best_epoch = 0
        for epoch in range(N_EPOCH):
            print(f'FOLD {i_fold + 1} - EPOCH {1 + epoch}/{N_EPOCH}')
            train_loss = train_model(train_loader, model, criterion, optimizer, scheduler)
            y_true_valid, y_pred_valid, _ = evaluate_model(valid_loader, model)
            valid_auroc = roc_auc_score(y_true_valid, y_pred_valid)
            print(f'train_loss: {train_loss:.4f}  val_auroc: {valid_auroc:.4f}')
            scheduler.step(valid_auroc)
            if valid_auroc > best_val_auroc:
                print(f'>> val_auroc increased from {best_val_auroc:.4f} to {valid_auroc:.4f}>> Saving weights to [{weight_cache}]')
                torch.save(model.state_dict(), weight_cache)
                best_val_auroc = valid_auroc
                early_stop_count = 0
                best_epoch = epoch + 1  # 1-based epoch
            else:
                early_stop_count += 1
            if early_stop_count >= EARLY_STOP_PAT:
                break
        best_val_auroc_list.append(best_val_auroc)
        early_stop_epochs.append(best_epoch if best_epoch > 0 else epoch + 1)
        del train, valid, train_loader, valid_loader, criterion, optimizer, scheduler

    val_auroc_mean = np.mean(best_val_auroc_list)

    print('*' * 100)
    for i in range(N_FOLD):
        print(f'Fold #{i + 1}  : {best_val_auroc_list[i]:.4f} (early stop at epoch {early_stop_epochs[i]})')
    print(f'Mean : {val_auroc_mean:.4f}')
    print('*' * 100)

    # ---- Fit the final model on the whole develop set -----------------------
    develop_filenames = np.concatenate(fold_filenames)
    develop = dataset(header_files=develop_filenames)
    develop.num_leads = 12
    develop.sample = True
    # Shuffle like the CV training loaders; the concatenated folds are otherwise
    # presented in fold order every epoch.
    develop_loader = DataLoader(develop, collate_fn=collate_fn, batch_size=BATCH_SIZE, shuffle=True)
    avg_early_stop_epoch = int(np.round(np.mean(early_stop_epochs)))
    # Scale by (N_FOLD-1)/N_FOLD (4/5 for 5 folds), presumably so the number of
    # optimizer steps matches CV: the develop set is N_FOLD/(N_FOLD-1) times larger
    # than each CV training split. Train at least one epoch so weights.pth is never
    # just the initial weights.
    n_final_epochs = max(1, int(avg_early_stop_epoch * (N_FOLD - 1) / N_FOLD))
    print(f"Training final model on develop set for {n_final_epochs} epochs "
          f"({N_FOLD - 1}/{N_FOLD} of the average early stopping epoch {avg_early_stop_epoch})")
    model.load_state_dict(torch.load(init_weights_path, weights_only=True))
    criterion = nn.BCEWithLogitsLoss()
    # Same weight decay as in the CV runs whose early-stopping points are reused here.
    optimizer = Adam(model.parameters(), lr=LR_INITIAL, weight_decay=1e-4)
    scheduler = ReduceLROnPlateau(optimizer, patience=LR_STEP_SIZE, factor=LR_STEP_GAMMA, min_lr=1e-5, mode='max')
    for epoch in range(n_final_epochs):
        print(f'DEVELOP FINAL - EPOCH {1 + epoch}/{n_final_epochs}')
        train_model(develop_loader, model, criterion, optimizer, scheduler)
    torch.save(model.state_dict(), f'{TRIAL_DIR}/weights.pth')
    del develop_filenames, develop, develop_loader, criterion, optimizer, scheduler

    # Prefix the trial folder with its mean CV AUROC.
    TRIAL_FOLDER_NEW = f'auc={val_auroc_mean:.4f}_{TRIAL_FOLDER}'
    TRIAL_DIR_NEW = f'{STUDY_DIR}/{TRIAL_FOLDER_NEW}'
    os.rename(TRIAL_DIR, TRIAL_DIR_NEW)
    gc.collect()

## Inference

In [None]:
# NOTE(review): the tuning cell saves trials under model/<yymmdd>_<NN>/auc=...;
# this path has no study sub-folder. Confirm it points at the intended trial.
trial_dir = "model/auc=0.8772_batch=64_lr=0.00086_step=2_gam=0.164_pat=4_block_size=16_block_depth=3_hidden_size=32_block_layers=3_kernel_num=5"
TRIAL_DIR                   = trial_dir
TRIAL_FOLDER                = os.path.basename(TRIAL_DIR)
STUDY_DIR                   = os.path.dirname(TRIAL_DIR)
HPARAMS_ARCH_FILENAME       = f'{TRIAL_DIR}/hparams_arch.pkl'
HPARAMS_TRAIN_FILENAME      = f'{TRIAL_DIR}/hparams_train.pkl'

# These pickles are written locally by the tuning cell. pickle.load can execute
# arbitrary code, so never point these paths at untrusted files.
with open(HPARAMS_ARCH_FILENAME, 'rb') as f:
    hparams_arch = pickle.load(f)
BLOCK_SIZE      = int(hparams_arch['BLOCK_SIZE'])
BLOCK_DEPTH     = int(hparams_arch['BLOCK_DEPTH'])
BLOCK_LAYERS    = int(hparams_arch['BLOCK_LAYERS'])
HIDDEN_SIZE     = int(hparams_arch['HIDDEN_SIZE'])
KERNEL_NUM      = int(hparams_arch['KERNEL_NUM'])

with open(HPARAMS_TRAIN_FILENAME, 'rb') as f:
    hparams_train = pickle.load(f)
N_FOLD          = int(hparams_train['N_FOLD'])
N_EPOCH         = int(hparams_train['N_EPOCH'])
BATCH_SIZE      = int(hparams_train['BATCH_SIZE'])
LR_INITIAL      = hparams_train['LR_INITIAL']
LR_STEP_SIZE    = hparams_train['LR_STEP_SIZE']
LR_STEP_GAMMA   = hparams_train['LR_STEP_GAMMA']
EARLY_STOP_PAT  = hparams_train['EARLY_STOP_PAT']


def _make_eval_loader(filenames):
    """Build a non-sampling 12-lead dataset over ``filenames`` and an ordered DataLoader on it."""
    ds = dataset(header_files=filenames)
    ds.num_leads = 12
    ds.sample = False
    ds.files.reset_index(drop=True, inplace=True)
    return ds, DataLoader(ds, collate_fn=collate_fn, batch_size=BATCH_SIZE, shuffle=False)


test_filenames = np.load('test_filenames.npy', allow_pickle=True)
test, test_loader = _make_eval_loader(test_filenames)

fold_filenames = [np.load(f'fold_{i}_filenames.npy', allow_pickle=True) for i in range(1, N_FOLD + 1)]
develop_filenames = np.concatenate(fold_filenames)
develop, develop_loader = _make_eval_loader(develop_filenames)

In [None]:
torch.set_num_threads(os.cpu_count())

model = FinalModel(
    block_size=BLOCK_SIZE,
    block_depth=BLOCK_DEPTH,
    block_layers=BLOCK_LAYERS,
    hidden_size=HIDDEN_SIZE,
    kernel_num=KERNEL_NUM
).to(DEVICE)
weight_cache = f'{TRIAL_DIR}/weights.pth'
model.load_state_dict(torch.load(weight_cache, weights_only=True))

# Pre-computed CNN test-set probabilities (a local, trusted pickle).
# NOTE(review): assumed to be in test_loader order (shuffle=False); the length
# check below is the only guard against a mismatched file.
with open('y_test_proba_cnn.pkl', 'rb') as f:
    y_pred_test = pickle.load(f)

DEVICE = next(model.parameters()).device
SAMPLING_RATE = 500
BEAT_LEN = 400
TOP_K = 1000

# ----------------------------------------------------------------
# 1. Grab raw signals, ages and genders in one pass over the test loader.
#    The model is not run here; predictions come from y_pred_test.
# ----------------------------------------------------------------
ecgs, ages, genders = [], [], []

model.eval()
for X, _, a, g in tqdm(test_loader, desc='Collecting signals'):
    ecgs.extend(X.cpu().numpy())
    ages.extend(a.cpu().numpy())
    genders.extend(g.cpu().numpy())
ecgs = np.asarray(ecgs)
ages = np.asarray(ages)
genders = np.asarray(genders)

preds = np.asarray(y_pred_test).ravel()
if len(preds) != len(ecgs):
    raise ValueError(
        f"y_test_proba_cnn.pkl has {len(preds)} predictions but the test loader "
        f"produced {len(ecgs)} ECGs; the cached predictions do not match this test set"
    )

# ----------------------------------------------------------------
# 2. Patient-level median waveform & ROI  (Lead-II reference)
# ----------------------------------------------------------------
def patient_median_and_roi(signal_12xN, *, sampling_rate=SAMPLING_RATE, beat_len=BEAT_LEN):
    """
    Build a per-lead median heartbeat for one ECG, using R-peaks from Lead II.

    Parameters
    ----------
    signal_12xN : np.ndarray, shape (n_leads, N)
        Multi-lead ECG (12 leads in this notebook). Row 1 is used as the R-peak
        reference, i.e. Lead II.
    sampling_rate : int
        Sampling frequency in Hz, passed to neurokit2.
    beat_len : int
        Number of samples each beat is resampled to.

    Returns
    -------
    median_wave : np.ndarray, shape (n_leads, beat_len), float32
        Per-lead median of the resampled beats. A lead's row is NaN when it
        could not be epoched or resampled. All rows are NaN when R-peak
        detection fails or finds fewer than 3 peaks.
    roi_indices_per_lead : list of length n_leads
        For each lead, a list holding the ``"Index"`` column of every kept epoch
        from ``nk.epochs_create``. NOTE(review): presumably the epoch's sample
        positions in the original signal; confirm against neurokit2. The list is
        empty for leads that failed.
    """
    signal_12xN = np.asarray(signal_12xN)
    n_leads = signal_12xN.shape[0]

    def _all_failed():
        """All-NaN waves and empty ROI lists for every lead."""
        return np.full((n_leads, beat_len), np.nan, np.float32), [[] for _ in range(n_leads)]

    # --- 1) detect R-peaks once on Lead II (row 1) -----------------------
    try:
        _, rpeaks = nk.ecg_peaks(signal_12xN[1], sampling_rate=sampling_rate)
        r_idx = rpeaks["ECG_R_Peaks"]
    except Exception:  # best-effort: noisy ECGs are expected to fail here
        return _all_failed()

    if len(r_idx) < 3:
        return _all_failed()

    # --- 2) per-lead epochs around each R-peak, resample, take the median --
    median_wave = np.full((n_leads, beat_len), np.nan, np.float32)
    roi_indices_perlead = []

    for lead_idx, lead in enumerate(signal_12xN):
        try:
            epochs = nk.epochs_create(
                lead, events=r_idx,
                sampling_rate=sampling_rate,
                epochs_start=-0.3, epochs_end=0.5
            )
        except Exception:
            roi_indices_perlead.append([])
            continue

        # Drop the first and last epochs (presumably because they may be
        # truncated at the signal edges).
        keys = list(epochs.keys())[1:-1]
        if not keys:
            roi_indices_perlead.append([])
            continue

        roi_indices_perlead.append([epochs[k]["Index"] for k in keys])

        try:
            beats_rs = resample(
                np.stack([epochs[k]["Signal"] for k in keys]),
                beat_len, axis=1
            )
            median_wave[lead_idx] = np.nanmedian(beats_rs, axis=0)
        except Exception:
            median_wave[lead_idx] = np.nan  # keep this lead's row as "unavailable"

    return median_wave, roi_indices_perlead

# ----------------------------------------------------------------
# 3. Group median waveforms (high-risk vs low-risk) and ROI indices
# ----------------------------------------------------------------
# Rank test ECGs by predicted risk once. The k highest form the high-risk group
# and the k lowest the low-risk group. k is capped at half the test set so the
# two groups never share an ECG.
risk_order = np.argsort(preds)
k_group = min(TOP_K, len(risk_order) // 2)
hi_idx = risk_order[len(risk_order) - k_group:]
lo_idx = risk_order[:k_group]
del preds
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()


def group_median_and_roi(indices):
    """Median heartbeat of a patient group and the per-patient ROI lists.

    Runs ``patient_median_and_roi`` on ``ecgs[i]`` for each ``i`` in ``indices``
    in parallel (joblib/loky). Returns the NaN-aware median over patients of the
    per-patient median waves, plus the list of per-patient ROI lists.

    Raises
    ------
    ValueError
        If ``indices`` is empty.
    """
    if len(indices) == 0:
        raise ValueError("group_median_and_roi needs at least one ECG index")
    results = Parallel(n_jobs=-1, backend="loky")(
        delayed(patient_median_and_roi)(ecgs[i]) for i in indices
    )
    patient_meds, patient_rois = zip(*results)
    patient_meds = np.stack(patient_meds)
    return np.nanmedian(patient_meds, axis=0), list(patient_rois)


# Failures propagate: the plots below need both group medians, so a swallowed
# error would only resurface later as an unrelated TypeError.
median_hi, roi_hi = group_median_and_roi(hi_idx)
median_lo, roi_lo = group_median_and_roi(lo_idx)

gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()

# ----------------------------------------------------------------
# 4. Plot median waveforms (4x3 grid, custom lead order and names)
#    Also overlay original top_K waveforms in brighter, slimmer lines
# ----------------------------------------------------------------
lead_names = ['I', 'II', 'III', 'aVR', 'aVL', 'aVF', 'V1', 'V2', 'V3', 'V4', 'V5', 'V6']
lead_indices = list(range(len(lead_names)))
# Manual horizontal offset (in samples) applied to the high-risk median before plotting.
HI_SHIFT_SAMPLES = 10

time_ms = np.arange(BEAT_LEN) * 1000 / SAMPLING_RATE  # x-axis in ms

if median_hi is None or median_lo is None:
    print('Skipping median-waveform plot: a group median is unavailable.')
else:
    n_cols = 3
    fig, axes = plt.subplots(4, n_cols, figsize=(14, 16), sharex=True)
    axes = axes.flatten()

    for plot_idx, lead_idx in enumerate(lead_indices):
        ax = axes[plot_idx]
        # NOTE(review): np.roll shifts the high-risk median right by
        # HI_SHIFT_SAMPLES and wraps the last samples around to the start. It is
        # a hand-picked visual alignment, not computed from the data.
        shifted_hi = np.roll(median_hi[lead_idx], HI_SHIFT_SAMPLES)
        ax.plot(time_ms, shifted_hi, color='red', label='High MACCE risk', linewidth=2.2, zorder=2)
        ax.plot(time_ms, median_lo[lead_idx], color='green', label='Low MACCE risk', linewidth=2.2, zorder=2)
        ax.set_ylabel("mV")
        ax.set_title(lead_names[lead_idx])
        ax.set_xlim([time_ms[0], time_ms[-1]])
        # Fit the y-range to both curves, ignoring NaN/inf samples (set_ylim rejects them).
        both = np.concatenate([median_hi[lead_idx], median_lo[lead_idx]])
        finite = both[np.isfinite(both)]
        if finite.size:
            ax.set_ylim([finite.min() - 5, finite.max() + 3])
        # Legend inside the plot area at the bottom center.
        ax.legend(
            loc='lower center',
            bbox_to_anchor=(0.5, 0.05),
            borderaxespad=0.,
            fontsize='small',
            frameon=True
        )

    # x-axis label only on the bottom row of the grid.
    for ax in axes[-n_cols:]:
        ax.set_xlabel('ms')
    # Hide any unused subplots.
    for ax in axes[len(lead_indices):]:
        ax.axis('off')
    # Leave extra space above and below the grid.
    plt.tight_layout(rect=[0, 0.03, 1, 0.95])
    plt.show()

# ----------------------------------------------------------------
# 5. SHAP deep explainer on high-risk ECGs
# ----------------------------------------------------------------
# Drop the test DataLoader reference. The signals, ages and genders it produced
# were already collected into `ecgs`, `ages` and `genders` earlier in this notebook.
del test_loader

class ModelWrapper(torch.nn.Module):
    """Wrap FinalModel so it accepts a single flat tensor (as SHAP DeepExplainer needs).

    Each input row is laid out as
    ``[ecg (n_leads * n_samples values) | age (1 value) | gender (remaining values)]``.
    The ECG part is reshaped to ``(batch, n_leads, n_samples)``. The wrapped model
    is called as ``base(ecg, age, gender)``, and its output is passed through a
    sigmoid, so explanations are of the predicted probability.

    Args:
        base: model called as ``base(ecg, age, gender)``; it is switched to eval mode.
        n_leads: number of ECG channels in the flat layout (default 12).
        n_samples: samples per ECG channel in the flat layout (default 5000).
    """

    def __init__(self, base, n_leads=12, n_samples=5000):
        super().__init__()
        self.base = base.eval()
        self.n_leads = n_leads
        self.n_samples = n_samples

    def forward(self, x):
        ecg_len = self.n_leads * self.n_samples
        # Split input into ECG, age, and gender components.
        ecg = x[:, :ecg_len]
        age = x[:, ecg_len:ecg_len + 1]
        gender = x[:, ecg_len + 1:]
        ecg = ecg.reshape(-1, self.n_leads, self.n_samples)
        out = self.base(ecg, age, gender)
        # torch.sigmoid replaces the deprecated F.sigmoid alias (same values).
        return torch.sigmoid(out)

# Use GPU for SHAP computation if available
# (a torch.device; the notebook-level DEVICE string makes the same CPU/CUDA choice)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Convert ages and genders to float32 tensors on `device`.
# NOTE(review): torch.cat(..., dim=1) below needs the indexed ages/genders to be
# 2-D (e.g. shape (N, 1)). Confirm the shape of `ages` and `genders`.
ages_torch = torch.tensor(ages, dtype=torch.float32, device=device)
genders_torch = torch.tensor(genders, dtype=torch.float32, device=device)

# Background set for DeepExplainer: the first 10 low-risk recordings (lo_idx[:10]).
# Each ECG is flattened and concatenated with its age and gender, giving the same
# flat row layout that ModelWrapper.forward splits apart.
background_ecgs = torch.tensor(ecgs[lo_idx[:10]], dtype=torch.float32, device=device)
background_ecgs_flat = background_ecgs.reshape(background_ecgs.shape[0], -1)
background_ages = ages_torch[lo_idx[:10]]
background_genders = genders_torch[lo_idx[:10]]
background = torch.cat([background_ecgs_flat, background_ages, background_genders], dim=1)

# Recordings explained per shap_values call, and the dtype of explained batches.
SHAP_BATCH = 1
DTYPE = torch.float32

def make_activations_safe(module):
    """Disable in-place computation in every descendant of *module*.

    Any sub-module (at any depth) whose ``inplace`` attribute is truthy gets
    ``inplace = False``. The root *module* itself is not modified. Called on
    the wrapped model before the SHAP explainer is built.
    """
    pending = list(module.children())
    while pending:
        sub = pending.pop()
        if getattr(sub, "inplace", False):
            sub.inplace = False
        pending.extend(sub.children())

# Wrap the trained model (flat input -> sigmoid output) and turn off in-place
# ops in all of its sub-modules before building the explainer.
wrapped_model = ModelWrapper(model).to(device)
make_activations_safe(wrapped_model)
# Ensure the wrapped model is float32 on `device`.
wrapped_model_gpu = wrapped_model.float().to(device)
torch.cuda.empty_cache()
gc.collect()

# Build the explainer against the low-risk background, then free the background tensors.
explainer = shap.DeepExplainer(wrapped_model_gpu, background)
del background, background_ecgs, background_ecgs_flat, background_ages, background_genders
torch.cuda.empty_cache()
gc.collect()

# Compute SHAP values batch-wise (on GPU) for the high-risk recordings, in hi_idx order.
shap_vals_chunks = []
for start in tqdm(range(0, len(hi_idx), SHAP_BATCH), desc="SHAP batches"):
    batch_idx = hi_idx[start:start+SHAP_BATCH]
    # Same flat [ecg | age | gender] layout as the background.
    batch_ecgs = torch.tensor(ecgs[batch_idx], dtype=DTYPE, device=device)
    batch_ages = ages_torch[batch_idx]
    batch_genders = genders_torch[batch_idx]
    batch_flat = batch_ecgs.reshape(batch_ecgs.shape[0], -1)
    batch = torch.cat([batch_flat, batch_ages, batch_genders], dim=1)
    torch.cuda.empty_cache()
    gc.collect()

    # Calculate SHAP values
    # NOTE(review): `[0]` then `[:12*5000]` assumes shap_values returns an array whose
    # first axis is the batch, so `[0]` is the single recording when SHAP_BATCH == 1.
    # It also assumes the ECG features come first. With SHAP_BATCH > 1, all but the
    # first recording of each batch would be dropped. SHAP versions that return a
    # per-output list would index differently. Confirm against the installed shap.
    shap_batch = explainer.shap_values(batch, check_additivity=False)[0]
    # Keep only the ECG attributions; the age/gender attributions are discarded.
    shap_batch = shap_batch.astype(np.float32)[:12*5000]
    shap_vals_chunks.append(shap_batch)

    # Clear memory
    del batch, batch_ecgs, batch_ages, batch_genders, batch_flat, shap_batch
    torch.cuda.empty_cache()
    gc.collect()

shap_vals = np.array(shap_vals_chunks)    # expected (K, 12*5000), K = len(hi_idx)
torch.cuda.empty_cache()
gc.collect()

# ----------------------------------------------------------------
# 6. Aggregate saliency using region-of-interest (ROI) SHAP values
#    (EXCLUDE first and last detected beat in both waveform and shap)
# ----------------------------------------------------------------
# For each high-risk sample and each lead, cut the SHAP trace at the ROI windows
# (intended to match those used for the median waveform). Resample each beat to
# BEAT_LEN and take the median across beats. Finally, take the median across patients.

# Saliency uses attribution magnitude, so the sign of each SHAP value is dropped.
shap_abs = np.abs(shap_vals)  # (K, 12*5000)
del shap_vals
gc.collect()

from scipy.signal import resample
def extract_roi_shap_per_lead(shap_12x5000, roi_indices_per_lead, beat_len=BEAT_LEN):
    """Median ROI-aligned saliency per lead for one recording.

    For each lead, every detected beat except the first and last is cut from
    that lead's SHAP trace, from ``roi.loc[-0.3]`` to ``roi.loc[0.5]`` inclusive.
    Each cut is resampled to ``beat_len`` samples, and the beats are combined
    with a NaN-aware median.

    Args:
        shap_12x5000: per-lead SHAP values, iterated row by row (the caller
            passes a (12, 5000) array).
        roi_indices_per_lead: one sequence of ROI objects per lead. Each ROI
            must support ``.loc[-0.3]`` and ``.loc[0.5]`` returning integer
            sample positions (presumably pandas Series keyed by beat-relative
            offset; confirm against the ROI builder).
        beat_len: number of samples each beat is resampled to.

    Returns:
        float32 array of shape (12, beat_len). Rows for leads with fewer than
        three detected beats stay NaN.
    """
    # Imported locally on purpose: a later notebook cell runs
    # `from sklearn.utils import resample`, which rebinds the global name
    # `resample` and would break this function if it were called afterwards.
    from scipy.signal import resample as fourier_resample

    beats_out = np.full((12, beat_len), np.nan, dtype=np.float32)
    for lead, (lead_shap, lead_rois) in enumerate(zip(shap_12x5000, roi_indices_per_lead)):
        # Need at least one beat left after dropping the first and last.
        if len(lead_rois) < 3:
            continue
        inner_rois = lead_rois[1:-1]
        # One resampled row per inner beat, stacked to (n_beats, beat_len).
        beats = np.vstack([
            fourier_resample(lead_shap[roi.loc[-0.3]:roi.loc[0.5] + 1], beat_len)
            for roi in inner_rois
        ])
        beats_out[lead] = np.nanmedian(beats, axis=0)
    return beats_out

# For each high-risk sample, extract ROI-aligned SHAP
saliency_hi = []
for i in range(len(hi_idx)):
    # SHAP for this sample: (12*5000,) -> (12, 5000). Rows of shap_abs follow
    # hi_idx order, because the SHAP chunks were appended in that order.
    sample_shap = shap_abs[i].reshape(12, 5000)
    # roi_hi[i]: one sequence of ROI objects per lead for this sample. Each ROI is
    # indexed with .loc[-0.3] / .loc[0.5] inside extract_roi_shap_per_lead, so
    # these are not (start, end) tuples. Presumably roi_hi is aligned with hi_idx;
    # confirm.
    med_sample = extract_roi_shap_per_lead(sample_shap, roi_hi[i], beat_len=BEAT_LEN)
    saliency_hi.append(med_sample)
saliency_hi = np.stack(saliency_hi)  # (K, 12, beat_len)

del shap_abs
gc.collect()
# Median across patients, ignoring NaN rows (leads with < 3 detected beats).
saliency_hi = np.nanmedian(saliency_hi, axis=0)        # (12, beat_len)

# ----------------------------------------------------------------
# 7. Plot saliency map OVER high-risk median waveform (4x3 grid, custom lead order and names)
# ----------------------------------------------------------------
# Desired order: I, II, III // aVR, aVL, aVF // V1, V2, V3 // V4, V5, V6
fig, axes = plt.subplots(4, 3, figsize=(14, 16), sharex=True)
axes = axes.flatten()
im = None
time_ms = np.arange(BEAT_LEN) * 1000 / SAMPLING_RATE  # x-axis in ms

# Simple white-to-blue colormap for |SHAP| saliency (white = low, blue = high).
white_blue_cmap = LinearSegmentedColormap.from_list(
    "white_blue", ["#ffffff", "#3b4cc0"]
)
saliency_cmap = white_blue_cmap

for plot_idx, lead_idx in enumerate(lead_indices):
    ax = axes[plot_idx]
    lead_wave = median_hi[lead_idx]
    # Fixed -5 / +3 padding around the high-risk median, matching the
    # median-waveform figure. The saliency strip is stretched over exactly this range.
    y_lo = lead_wave.min() - 5
    y_hi = lead_wave.max() + 3

    ax.set_ylabel("mV")
    # Title by lead_idx (not plot_idx), consistent with the median-waveform figure.
    ax.set_title(lead_names[lead_idx])
    ax.set_xlim([time_ms[0], time_ms[-1]])
    ax.set_ylim([y_lo, y_hi])
    ax.set_xlabel('ms')

    # Waveform (zorder=2) is drawn over the saliency strip (zorder=1).
    ax.plot(time_ms, lead_wave, color='red', zorder=2)
    # NOTE(review): by default each imshow autoscales its own colour range.
    # Colours are therefore comparable within a lead but not across leads, and
    # the colorbars below (built from the last image) are only qualitative.
    im = ax.imshow(
        saliency_hi[lead_idx][None, :], aspect='auto',
        cmap=saliency_cmap, interpolation='nearest',
        extent=[time_ms[0], time_ms[-1], y_lo, y_hi],
        alpha=0.85,  # slightly less transparent for more contrast
        origin='lower', zorder=1
    )

# Keep the 'ms' x-label only on the last subplot.
for ax in axes[:-1]:
    ax.set_xlabel('')

# Hide any unused subplots (if any)
for ax in axes[len(lead_indices):]:
    ax.axis('off')

# Colorbars on the rightmost column (axes 2, 5, 8, 11). Ticks carry qualitative
# labels (low / high salience) rather than SHAP magnitudes.
rightmost_indices = [2, 5, 8, 11]
for i in rightmost_indices:
    cbar = fig.colorbar(im, ax=axes[i], orientation='vertical',
                        fraction=0.04, pad=0.02)
    cbar.set_ticks([cbar.vmin, cbar.vmax])
    cbar.set_ticklabels(['Low\nsalience', 'High\nsalience'])
    cbar.ax.yaxis.set_ticks_position('right')
plt.tight_layout(rect=[0, 0.03, 1, 0.95])
plt.show()
del fig, axes, im, saliency_hi, time_ms
gc.collect()

In [None]:
seed_everything(SEED)
df = pd.read_csv('data_labels_M_N.csv')
df2 = pd.read_csv('ecg_labels_w_features_lab_asa_3mo_N.csv')
# Attach df2's 'andur' and 'final_asa' (stored as 'asa') columns, matched by filename.
df2_by_file = df2.set_index('filename')
df['andur'] = df['filename'].map(df2_by_file['andur'])
df['asa'] = df['filename'].map(df2_by_file['final_asa'])
df.to_csv('data_labels_gbm.csv', index=False)

# Run the trained CNN over every listed recording, without shuffling, and save
# its hidden-state features next to the labels.
# NOTE(review): `all` shadows Python's builtin all() for the rest of the notebook.
# The name is kept only in case later cells reference it; consider renaming.
all= dataset(header_files=df['filename'].to_list())
all.num_leads = 12
all.sample = False
all.files.reset_index(drop=True, inplace=True)
all_loader = DataLoader(all, collate_fn=collate_fn, batch_size=BATCH_SIZE, shuffle=False)
model = FinalModel(block_size =BLOCK_SIZE, block_depth =BLOCK_DEPTH, block_layers=BLOCK_LAYERS, hidden_size=HIDDEN_SIZE, kernel_num=KERNEL_NUM).to(DEVICE)
weight_cache = f'{TRIAL_DIR}/weights.pth'
# map_location lets a checkpoint saved on a GPU load on a CPU-only host.
model.load_state_dict(torch.load(weight_cache, map_location=DEVICE, weights_only=True))
y_true_test, y_pred_test, hidden_states = evaluate_model(all_loader, model)
hidden_states_df = pd.DataFrame(hidden_states, columns=[f'hidden_{i}' for i in range(HIDDEN_SIZE)])
# Column-wise concat aligns on index. df has a fresh RangeIndex (read_csv), and
# the loader does not shuffle, so rows line up, assuming evaluate_model returns
# outputs in loader order.
df = pd.concat([df, hidden_states_df], axis=1)
df.to_csv('data_labels_gbm_hidden.csv', index=False)

In [None]:
# Run on tensorflow
import tensorflow as tf
from tensorflow import keras
from keras import layers
from keras.preprocessing.text import Tokenizer
import json
import pandas as pd
import numpy as np
import pickle
# Re-declared so this TensorFlow cell can run on its own; same value as SEED at the top of the notebook.
SEED= 1

class TransformerBlock(layers.Layer):
    """Post-norm Transformer encoder block: self-attention, then a feed-forward network.

    Each sub-layer output is passed through dropout, added to its input
    (residual connection), and layer-normalized.

    Args:
        embed_dim: model width; also used as the attention ``key_dim``.
        num_heads: number of attention heads.
        ff_dim: hidden width of the feed-forward network.
        droprate: dropout rate after the attention and after the FFN.
        **kwargs: standard Keras layer arguments (``name``, ``trainable``,
            ``dtype``, ...), forwarded to ``layers.Layer``. This lets configs
            produced by ``get_config`` round-trip through ``from_config``.
    """

    def __init__(self, embed_dim, num_heads, ff_dim, droprate, **kwargs):
        super().__init__(**kwargs)
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.droprate = droprate
        self.ff_dim = ff_dim
        # Sub-layers are created in the same order as before, so weight order
        # (and loading of the saved .h5 weights) is unchanged.
        self.att = layers.MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim)
        self.ffn = keras.Sequential([layers.Dense(ff_dim, activation="relu"), layers.Dense(embed_dim),])
        self.layernorm1 = layers.LayerNormalization(epsilon=1e-6)
        self.layernorm2 = layers.LayerNormalization(epsilon=1e-6)
        self.dropout1 = layers.Dropout(droprate)
        self.dropout2 = layers.Dropout(droprate)

    def call(self, inputs, training=None):
        # Self-attention: query and value are both `inputs`.
        attn_output = self.att(inputs, inputs)
        attn_output = self.dropout1(attn_output, training=training)
        out1 = self.layernorm1(inputs + attn_output)
        # Position-wise feed-forward sub-layer.
        ffn_output = self.ffn(out1)
        ffn_output = self.dropout2(ffn_output, training=training)
        return self.layernorm2(out1 + ffn_output)

    def get_config(self):
        # Start from the base Layer config (name, trainable, dtype, ...) so the
        # layer's identity survives serialization, then add the constructor args.
        config = super().get_config()
        config.update({
            'embed_dim': self.embed_dim,
            'num_heads': self.num_heads,
            'ff_dim': self.ff_dim,
            'droprate': self.droprate,
        })
        return config

class TokenAndPositionEmbedding(layers.Layer):
    """Sum of learned token embeddings and learned absolute position embeddings.

    Args:
        maxlen: number of position embeddings (longest supported sequence).
        vocab_size: number of token embeddings.
        embed_dim: embedding width for both tables.
        **kwargs: standard Keras layer arguments (``name``, ``trainable``,
            ``dtype``, ...), forwarded to ``layers.Layer`` so the ``get_config``
            output round-trips through ``from_config``.
    """

    def __init__(self, maxlen, vocab_size, embed_dim, **kwargs):
        super().__init__(**kwargs)
        self.maxlen = maxlen
        self.vocab_size = vocab_size
        self.embed_dim = embed_dim
        # Same creation order as before, so weight order is unchanged.
        self.token_emb = layers.Embedding(input_dim=vocab_size, output_dim=embed_dim)
        self.pos_emb = layers.Embedding(input_dim=maxlen, output_dim=embed_dim)

    def call(self, x):
        # Positions 0..L-1 for the actual sequence length L (last axis of x);
        # the (L, embed_dim) position table broadcasts over the batch.
        seq_len = tf.shape(x)[-1]
        positions = tf.range(start=0, limit=seq_len, delta=1)
        positions = self.pos_emb(positions)
        x = self.token_emb(x)
        return x + positions

    def get_config(self):
        # Merge the base Layer config (name, trainable, dtype, ...) with the constructor args.
        config = super().get_config()
        config.update({
            'maxlen': self.maxlen,
            'vocab_size': self.vocab_size,
            'embed_dim': self.embed_dim,
        })
        return config

# Curated opname -> (p, o, a) code table (body part / operation / approach, per
# feature_name_map later in this notebook). Rows missing any code are dropped,
# and the first mapping per operation name is kept.
df_temp = pd.read_csv('transformer/icd10_mapping_hclee_corrected_manual_01.csv', usecols=['opname', 'p', 'o', 'a'], dtype=str)
df_temp = df_temp.dropna(subset=['p', 'o', 'a'])
df_temp = df_temp.drop_duplicates(subset=['opname'], keep='first')

df = pd.read_csv('data_labels_gbm.csv', dtype=str)
# Every row is tagged 'snuh'; the output filter on src below therefore keeps all rows.
df['src'] = 'snuh'
# Reproducible row shuffle, then a fresh 0..N-1 index.
df = df.sample(frac=1, random_state=SEED).reset_index(drop=True)

opname_set = set(df_temp['opname'])
mask = df['opname'].isin(opname_set)
df.loc[mask, 'opname_final'] = df.loc[mask, 'opname']

# NOTE(review): unless data_labels_gbm.csv already contains an 'opname_final'
# column, the assignment above plus this fillna leave opname_final == opname for
# every row. The mask only matters when pre-existing values are present.
df['opname_final'] = df['opname_final'].fillna(df['opname'])

# Operation-name strings to tokenize.
x_raw = np.copy(df['opname_final'].values.astype(str))
vocab_size = 4000  # only referenced by the commented-out tokenizer fitting below

# The input tokenizer was fitted once and pickled; the fitting code is kept for reference.
#t_x = Tokenizer(vocab_size)
#t_x.fit_on_texts(x_raw)
#with open('transformer/tokenizer_x.pkl', 'wb') as f:
#    pickle.dump(t_x, f)

# pickle.load can execute code from the file; only load trusted tokenizer files.
with open('transformer/tokenizer_x.pkl', 'rb') as f:
    t_x = pickle.load(f)
x_seq = t_x.texts_to_sequences(x_raw)
# Fixed input length for the trained transformer models (presumably their training-time length).
maxlen = 158
x_pad = keras.preprocessing.sequence.pad_sequences(x_seq, maxlen=maxlen)

# To ensure consistent predictions for the same input, run the models only on unique
# padded sequences. inverse_indices maps every row back to its unique sequence.
# NOTE(review): x_pad_tuples, unique_indices and unique_x_pad_tuples are not used in this cell.
x_pad_tuples = [tuple(row) for row in x_pad]
unique_x_pad, unique_indices, inverse_indices = np.unique(x_pad, axis=0, return_index=True, return_inverse=True)
unique_x_pad_tuples = [tuple(row) for row in unique_x_pad]

prefix = 'transformer_res'
# Custom layers needed to rebuild the saved Keras architectures.
custom_objects = {
    'TokenAndPositionEmbedding': TokenAndPositionEmbedding,
    'TransformerBlock': TransformerBlock
}
# Some NumPy releases return `inverse` from np.unique(..., axis=0) with an extra
# trailing dimension. Flatten it so it always selects whole prediction rows.
inverse_rows = np.asarray(inverse_indices).reshape(-1)


def _json_field(cfg, key):
    """Return cfg[key], decoding it first if it was stored as a JSON string."""
    value = cfg[key]
    return json.loads(value) if isinstance(value, str) else value


for target in ('p', 'o', 'a'):
    print(f'Processing {target}...')
    opath = f'transformer/{prefix}_{target}.csv'

    # Rebuild the output tokenizer (class id <-> code) from its saved JSON config.
    t_y = Tokenizer()
    with open(f'transformer/tokenizer_y_{target}.json') as f:
        tokenizer_config = json.load(f)['config']
    t_y.word_index = _json_field(tokenizer_config, 'word_index')
    # JSON object keys are strings; index_word must be keyed by int ids.
    t_y.index_word = {int(k): v for k, v in _json_field(tokenizer_config, 'index_word').items()}
    t_y.word_counts = _json_field(tokenizer_config, 'word_counts')
    t_y.document_count = tokenizer_config['document_count']

    with open(f'transformer/model_{target}.json') as f:
        model = keras.models.model_from_json(f.read(), custom_objects=custom_objects)
    model.load_weights(f'transformer/tuned_weights_{target}.h5')

    pred_unique = model.predict(unique_x_pad, verbose=1)
    pred = pred_unique[inverse_rows]

    # Class column k is decoded as token id k + 1 (presumably because Tokenizer ids start at 1).
    df['pred'] = pd.Series(t_y.sequences_to_texts(np.argmax(pred, axis=1)[..., None] + 1)).str.upper()
    df['conf'] = pred.max(axis=1)
    df[df['src'] == 'snuh'].drop(columns='src').to_csv(opath, index=False, encoding='utf-8-sig')
    print(f'Finished processing {target}\n')

# Read back the three prediction CSVs and build one (p, o, a) triple per unique
# opname_final (the first occurrence wins).
df = pd.read_csv(f'transformer/{prefix}_p.csv', dtype=str, usecols=['opname_final', 'pred']).rename(columns={'pred':'p'})
df = df.drop_duplicates(subset=['opname_final'], keep='first')
df_o = pd.read_csv(f'transformer/{prefix}_o.csv', dtype=str, usecols=['opname_final', 'pred']).rename(columns={'pred':'o'})
df_o = df_o.drop_duplicates(subset=['opname_final'], keep='first')
df['o'] = df['opname_final'].map(df_o.set_index('opname_final')['o'])
df_a = pd.read_csv(f'transformer/{prefix}_a.csv', dtype=str, usecols=['opname_final', 'pred']).rename(columns={'pred':'a'})
df_a = df_a.drop_duplicates(subset=['opname_final'], keep='first')
df['a'] = df['opname_final'].map(df_a.set_index('opname_final')['a'])

# NOTE(review): this reads and overwrites data_labels_gbm_hidden_icd.csv in place,
# so the file must already exist. The earlier cell writes data_labels_gbm_hidden.csv
# (no _icd suffix).
df_final = pd.read_csv('data_labels_gbm_hidden_icd.csv', dtype=str)

# Set opname_final to opname if opname is in df_temp['opname'] (even if not NA)
mask = df_final['opname'].isin(opname_set)
df_final.loc[mask, 'opname_final'] = df_final.loc[mask, 'opname']

df_final['opname_final'] = df_final['opname_final'].fillna(df_final['opname'])

# Codes for names in the curated table come from df_temp; all other names use the
# transformer predictions collected above.
mask2 = df_final['opname_final'].isin(opname_set)
for col in ['p','o','a']:
    df_final.loc[mask2, col] = df_final.loc[mask2, 'opname_final'].map(df_temp.set_index('opname')[col])
for col in ['p', 'o', 'a']:
    df_final.loc[~mask2, col] = df_final.loc[~mask2, 'opname_final'].map(df.set_index('opname_final')[col])

df_final.to_csv('data_labels_gbm_hidden_icd.csv', index=False, encoding='utf-8-sig')


In [None]:
df = pd.read_csv('data_labels_gbm_hidden_icd.csv')
# Output column -> tag under <RestingECGMeasurements> in each ECG XML file.
xml_fields = {
    'mach_ventricular_rate': 'VentricularRate',
    'mach_atrial_rate': 'AtrialRate',
    'mach_pr_interval': 'PRInterval',
    'mach_qrs_duration': 'QRSDuration',
    'mach_qt_interval': 'QTInterval',
    'mach_qtc_interval': 'QTCorrected',
    'mach_p_axis': 'PAxis',
    'mach_r_axis': 'RAxis',
    'mach_t_axis': 'TAxis'
}
XML_DIR = 'C:/rsrch/240801_ecg_mace/data/'  # machine-specific location of the raw XML files

for col in xml_fields:
    df[col] = None

# Best-effort extraction. Files that cannot be read or parsed keep None in every
# field, but they are recorded and reported instead of being skipped silently.
failed_files = []
for idx, row in tqdm(df.iterrows(), total=len(df)):
    try:
        xml_path = XML_DIR + row['filename']
        tree = ET.parse(xml_path)
    except Exception as e:  # deliberate best-effort: continue with the next file
        failed_files.append((row['filename'], repr(e)))
        continue
    root = tree.getroot()
    # Find the RestingECGMeasurements section
    measurements = root.find('RestingECGMeasurements')
    if measurements is None:
        continue
    for col, tag in xml_fields.items():
        elem = measurements.find(tag)
        df.at[idx, col] = elem.text if elem is not None else None

if failed_files:
    print(f"Could not read {len(failed_files)} XML file(s); first few: {failed_files[:5]}")
df.to_csv('data_labels_gbm_hidden_icd.csv', index=False)

In [55]:
HIDDEN_SIZE=32  # number of CNN hidden-state columns (hidden_0 .. hidden_31)
df = pd.read_csv('data_labels_gbm_hidden_icd.csv')
df_col =[]  # feature columns selected from model_name

labels = 'label'
# model_name selects features by substring match:
#   'gender' / 'asa' / 'age' / 'andur' -> include that column;
#   unless 'wo_cnn' is present: 'mach_' -> all machine-measurement columns,
#   otherwise hidden_0 .. hidden_{HIDDEN_SIZE-1};
#   'p_o_a' -> one-hot encoded p / o / a code columns.
model_name = 'age_gender_p_o_a' # wo_cnn_, mach_/ age,andur, asa, gender, p_o_a

df['gender'] = df['gender'].astype(float)
for x in ['gender','asa','age','andur']:
   if x in model_name:
      df_col.extend([x])
if 'wo_cnn' not in model_name:
   if 'mach_' in model_name:
      df_col.extend([col for col in df.columns if 'mach_' in col])
   else:
      df_col.extend([f'hidden_{i}' for i in range(HIDDEN_SIZE)])
if 'p_o_a' in model_name:
   onehot_cols = ['p', 'o', 'a']
   onehot_df = pd.get_dummies(df[onehot_cols], prefix=onehot_cols).astype(float)
   df = df.drop(columns=onehot_cols)
   df = pd.concat([df, onehot_df], axis=1)
   df = df[df_col + list(onehot_df.columns) + ['label', 'hid', 'filename']]
else:
   df = df[df_col + ['label', 'hid', 'filename']]
# Label value 3 is merged into the positive class 1.
df['label'] = df['label'].replace({3: 1})

# Development rows = union of the five saved CV fold filename lists; test rows = saved test list.
test_filenames = np.load('test_filenames.npy', allow_pickle=True)
fold_filenames = [np.load(f'fold_{i}_filenames.npy', allow_pickle=True) for i in range(1, 6)]
develop_filenames = np.concatenate(fold_filenames)
# NOTE(review): df_train / df_test are boolean-filtered slices; later column
# assignments on df_test may emit SettingWithCopyWarning (suppressed globally).
df_train = df[df['filename'].isin(develop_filenames)]
df_test = df[df['filename'].isin(test_filenames)]

In [None]:
# Hyper-parameter search: one Optuna study (maximize, 500 trials) per model family.
# The objective for each family is built from every column except the identifiers and the label.
feature_columns = df_train.drop(columns=['hid', 'label', 'filename']).columns
model_specs = [
    ('xgb', xgb.XGBClassifier, xgb_params),
    ('rf', RandomForestClassifier, rf_params),
    ('svm', SVC, svm_params),
    ('logreg', LogisticRegression, logreg_params),
]
models = {
    family: create_objective_function(family, estimator_cls, search_space, feature_columns, model_name)
    for family, estimator_cls, search_space in model_specs
}
for model_type, objective_func in models.items():
    print(f"...{model_type.upper()}...")
    study = optuna.create_study(direction='maximize')
    # optimize runs immediately, so the lambda sees this iteration's objective_func.
    study.optimize(lambda trial: objective_func(trial, df_train.drop(columns=['filename'])), n_trials=500)

In [None]:
model_type = 'xgb'
# Best hyper-parameters from the Optuna search for this model family and feature set.
with open(f"best_params_{model_type}_{model_name}.json", 'r') as f:
    params = json.load(f)

# Estimator class per model family. An unknown model_type fails here with a clear
# message instead of leaving base_model undefined.
_estimator_classes = {
    'xgb': xgb.XGBClassifier,
    'svm': SVC,
    'logreg': LogisticRegression,
    'rf': RandomForestClassifier,
}
if model_type not in _estimator_classes:
    raise ValueError(f"Unknown model_type {model_type!r}; expected one of {sorted(_estimator_classes)}")
base_model = _estimator_classes[model_type](**params)

# Prepare X_train and y_train
X_train_full = df_train.drop(['label', 'hid', 'filename'], axis=1)
# Series.ravel() is deprecated (pandas 2.2); to_numpy() returns the same 1-D array.
y_train = df_train['label'].to_numpy()

# --- Feature selection ---
# 'age_gender_p_o_a': run BorutaShap once and cache the selected columns in
# selected_features.npy; later runs reuse the cached list.
# Other model_names: start from all feature columns. If p/o/a one-hots are
# included, keep only the columns that were selected for 'age_gender_p_o_a'
# (requires selected_features.npy), then re-add 'asa' / 'andur' when requested.
if model_name == 'age_gender_p_o_a':
    if os.path.exists('selected_features.npy'):
        selected_features = np.load('selected_features.npy', allow_pickle=True).tolist()
        X_train = X_train_full[selected_features]
    else:
        boruta_model = BorutaShap(model=base_model, importance_measure='shap', classification=True)
        boruta_model.fit(X=X_train_full, y=y_train, n_trials=30, sample=False, train_or_test='test', verbose=False)
        # Get the selected features
        selected_features = boruta_model.Subset().columns.tolist()
        # Save the selected features
        np.save('selected_features.npy', np.array(selected_features))
        # Reduce X_train to selected features
        X_train = X_train_full[selected_features]
else:
    selected_features = X_train_full.columns.tolist()
    if 'p_o_a' in model_name:
        selected_features_orig = np.load('selected_features.npy', allow_pickle=True).tolist()
        # NOTE(review): this intersection also drops any column that was not in the
        # 'age_gender_p_o_a' selection (e.g. mach_* columns for a 'mach_..._p_o_a'
        # model). Confirm that only the one-hot columns were meant to be filtered.
        selected_features = [col for col in selected_features_orig if col in selected_features]
        if 'asa' in model_name:
            selected_features.extend(['asa'])
        if 'andur' in model_name:
            selected_features.extend(['andur'])
    X_train = X_train_full[selected_features]

# Re-instantiate the model to avoid any contamination from BorutaShap
# (a fresh, unfitted estimator built from the same params).
_estimator_classes = {
    'xgb': xgb.XGBClassifier,
    'svm': SVC,
    'logreg': LogisticRegression,
    'rf': RandomForestClassifier,
}
if model_type not in _estimator_classes:
    raise ValueError(f"Unknown model_type {model_type!r}; expected one of {sorted(_estimator_classes)}")
best_model = _estimator_classes[model_type](**params)

# Fit once and cache with joblib; later runs reload the cached fit.
# NOTE(review): the cache key ignores the selected feature list. Delete the file
# after re-running feature selection, or a stale model will be reused.
fitted_path = f"fitted_{model_type}_{model_name}.joblib"
if not os.path.exists(fitted_path):
    best_model.fit(X_train, y_train)
    joblib.dump(best_model, fitted_path)
else:
    best_model = joblib.load(fitted_path)
# Positive-class probabilities on the training rows (used below to fit the
# calibrators and the Youden threshold).
# NOTE(review): SVC needs probability=True in params for predict_proba.
y_train_proba = best_model.predict_proba(X_train)[:,1].ravel()

# Read anesthesia type and emop
# (plus new_dept) for the test rows from the full label file, matched by filename.
# NOTE(review): df_test is a boolean-filtered slice of df; adding columns may emit
# SettingWithCopyWarning (suppressed by the global warnings filter).
sg = pd.read_csv('data_labels_gbm_hidden_icd.csv')
df_test['anetype'] = df_test['filename'].map(sg.set_index('filename')['anetype'])
df_test['emop'] = df_test['filename'].map(sg.set_index('filename')['emop'])
df_test['new_dept'] = df_test['filename'].map(sg.set_index('filename')['new_dept'])
# Merge small departments: OG and UR -> 'OG+UR'; OL, OT, PS -> 'Others'.
df_test['new_dept'] = df_test['new_dept'].replace(['OG', 'UR'], 'OG+UR')
df_test['new_dept'] = df_test['new_dept'].replace(['OL', 'OT', 'PS', 'Others'], 'Others')

# Add opid to df_test for subgroup counting
if 'opid' not in df_test.columns:
    if 'opid' in sg.columns:
        df_test['opid'] = df_test['filename'].map(sg.set_index('filename')['opid'])
    else:
        # fallback: try to load from opid_icd_matching.csv if available.
        # On any failure opid is set to None, which the count helper reports as 0 unique opids.
        try:
            opid_map = pd.read_csv('opid_icd_matching.csv')
            df_test['opid'] = df_test['filename'].map(dict(zip(opid_map['filename'], opid_map['opid'])))
        except Exception:
            df_test['opid'] = None

# Prepare test sets for all, general (1), and regional (0) anesthesia.
# 'anetype' is dropped from every group once it has been used for filtering.
test_groups = {
    'all': df_test.drop(columns=['anetype']),
    'gen': df_test[df_test['anetype'] == 1].drop(columns=['anetype']),
    'reg': df_test[df_test['anetype'] == 0].drop(columns=['anetype'])
}
# Add emop subgroups
test_groups['emop1'] = df_test[df_test['emop'] == 1].drop(columns=['anetype'])
test_groups['emop0'] = df_test[df_test['emop'] == 0].drop(columns=['anetype'])

# Add gender subgroups only if 'gender' exists in df_test
if 'gender' in df_test.columns:
    test_groups['gender1'] = df_test[df_test['gender'] == 1].drop(columns=['anetype'])
    test_groups['gender0'] = df_test[df_test['gender'] == 0].drop(columns=['anetype'])

# Add elderly subgroups (age > 60 and age <= 60) only if 'age' exists in df_test
if 'age' in df_test.columns:
    test_groups['elderly1'] = df_test[df_test['age'] > 60].drop(columns=['anetype'])
    test_groups['elderly0'] = df_test[df_test['age'] <= 60].drop(columns=['anetype'])

# Add new_dept subgroups
new_dept_categories = ['GS', 'NS', 'OG+UR', 'OS', 'TS', 'Others']
for dept in new_dept_categories:
    test_groups[f'new_dept_{dept}'] = df_test[df_test['new_dept'] == dept].drop(columns=['anetype'])

# Predict probabilities for each group using only selected features
X_test = {k: v.drop(['label', 'hid','filename'], axis=1)[selected_features] for k, v in test_groups.items()}
# Series.ravel() is deprecated (pandas 2.2); to_numpy() returns the same 1-D array.
y_test = {k: v['label'].to_numpy() for k, v in test_groups.items()}
y_test_proba = {k: best_model.predict_proba(X_test[k])[:, 1].ravel() for k in test_groups}

# Helper: row count plus distinct hid / opid counts for one test subgroup.
def get_n_filename_hid_opid(df):
    """Return ``(rows, distinct hid values, distinct opid values)`` for *df*.

    Distinct counts treat NaN/None as a value (same as ``len(Series.unique())``).
    The hid count is 0 when the column is absent. The opid count is 0 when the
    column is absent or has no non-missing values.
    """
    row_count = len(df)
    hid_count = df['hid'].nunique(dropna=False) if 'hid' in df.columns else 0
    has_opid = 'opid' in df.columns and df['opid'].notna().any()
    opid_count = df['opid'].nunique(dropna=False) if has_opid else 0
    return row_count, hid_count, opid_count

# SHAP summary for the full test set (only if model_type is 'xgb')
if model_type == 'xgb':
    # Define mapping for feature names (raw column -> label shown in the SHAP plots);
    # unmapped columns keep their raw names.
    # NOTE(review): range(33) covers hidden_0..hidden_32, one more than
    # HIDDEN_SIZE (32); the extra entry is harmless but unused.
    feature_name_map = {
        'age': 'Age',
        'gender': 'Gender',
        'o_J': 'Operation (Inspection)',
        'o_Q': 'Operation (Repair)',
        'o_Y': 'Operation (Transplantation)',
        'p_WG': 'Body Part (Peritoneal Cavity)',
        'p_BK': 'Body Part (Right Lung)',
        'p_D5': 'Body Part (Esophagus)',
        'p_T0': 'Body Part (Right Kidney)',
        'a_3': 'Approach (Percutaneous)',
        **{f'hidden_{i}': f'Hidden unit {i}' for i in range(33)},
    }

    # Create a list of display names for the features, using the mapping where available
    display_feature_names = [
        feature_name_map.get(col, col) for col in X_test['all'].columns
    ]

    # Global SHAP summary over all test rows.
    explainer = shap.TreeExplainer(best_model)
    shap_values = explainer.shap_values(X_test['all'])
    shap.summary_plot(shap_values, X_test['all'], feature_names=display_feature_names, max_display =60)

    # --- Draw SHAP force plot for a random patient in X_test['all'] ---
    # Seeded local RNG, so the same row is chosen on every run.
    rng = random.Random(6)
    random_idx = rng.randint(0, X_test['all'].shape[0] - 1)
    patient_data = X_test['all'].iloc[random_idx:random_idx+1]
    patient_shap_values = explainer.shap_values(patient_data)
    print(f"SHAP force plot for random patient at index {random_idx}:")
    # Format feature values (1 decimal; non-numeric values as str) and SHAP values (2 decimals) for display
    formatted_feature_values = [f"{v:.1f}" if isinstance(v, (float, int)) else str(v)
                               for v in patient_data.values[0]]
    formatted_shap_values = [f"{v:.2f}" for v in patient_shap_values[0]]
    # Compose custom feature names with value and SHAP value
    custom_feature_names = [
        f"{name}\n({val}, SHAP={shapv})"  # {name}\n({val}, SHAP={shapv})
        for name, val, shapv in zip(display_feature_names, formatted_feature_values, formatted_shap_values)
    ]
    shap.initjs()
    # Use matplotlib force plot with custom feature names, adjust figsize and text position
    # NOTE(review): confirm that shap.force_plot(matplotlib=True) draws into this
    # figure rather than creating its own. If it creates its own, pass figsize= to
    # force_plot instead.
    fig = plt.figure(figsize=(max(8, len(custom_feature_names) * 0.7), 2.5))
    # Draw force plot
    force_plot = shap.force_plot(
        explainer.expected_value,
        patient_shap_values,
        #patient_data,
        feature_names=custom_feature_names,
        matplotlib=True,
        show=False
    )
    # Move x-tick labels (feature names) lower for readability
    plt.xticks(rotation=45, ha='right', fontsize=9)
    # Remove the 'base value' and 'f(x)' text annotations from the current axes.
    ax = plt.gca()
    texts_to_remove = []
    for txt in ax.texts:
        if txt.get_text() in ["base value", "f(x)"]:
            texts_to_remove.append(txt)
    for txt in texts_to_remove:
        txt.remove()
    plt.tight_layout()
    plt.show()

# Calibration before and after (only for 'all')
# Intercept/slope here come from an ordinary least-squares fit of the 0/1 outcome
# on the predicted probability.
# NOTE(review): the conventional calibration intercept/slope come from a logistic
# regression of the outcome on logit(p), so these values are not directly comparable.
cal_model = LinearRegression().fit(y_test_proba['all'].reshape(-1,1), y_test['all'])
cal_intercept, cal_slope = cal_model.intercept_, cal_model.coef_[0]
brier_before = brier_score_loss(y_test['all'], y_test_proba['all'])
U_before, p_before = calc_unreliability(y_test['all'], y_test_proba['all'])

# Spline calibration
# Fitted on the training-set probabilities and cached to calib_filename; applied to every test group.
# NOTE(review): best_model was fit on these same training rows, so the probabilities
# are in-sample and likely over-confident. Out-of-fold predictions are the usual
# choice for fitting a calibrator.
calib_filename = f"calib_{model_type}_{model_name}.pkl"
if os.path.exists(calib_filename):
    with open(calib_filename, "rb") as f:
        calib = pickle.load(f)
else:
    calib = mli.SplineCalib(unity_prior=False, unity_prior_weight=100, random_state=42, max_iter=500)
    calib.fit(y_train_proba, y_train)
    with open(calib_filename, "wb") as f:
        pickle.dump(calib, f)
y_test_proba_cal = {k: calib.calibrate(y_test_proba[k]) for k in test_groups}
cal_model_after = LinearRegression().fit(y_test_proba_cal['all'].reshape(-1,1), y_test['all'])
cal_intercept_after, cal_slope_after = cal_model_after.intercept_, cal_model_after.coef_[0]
brier_after = brier_score_loss(y_test['all'],y_test_proba_cal['all'])
U_after, p_after = calc_unreliability(y_test['all'], y_test_proba_cal['all'])

# Isotonic calibration (only for 'all')
# Also fitted on training-set probabilities and cached; out-of-range inputs are clipped.
iso_calib_filename = f"iso_{model_type}_{model_name}.pkl"
if os.path.exists(iso_calib_filename):
    with open(iso_calib_filename, "rb") as f:
        iso_calib = pickle.load(f)
else:
    iso_calib = IsotonicRegression(out_of_bounds='clip').fit(y_train_proba, y_train)
    with open(iso_calib_filename, "wb") as f:
        pickle.dump(iso_calib, f)
y_test_proba_iso = iso_calib.predict(y_test_proba['all'])
cal_model_iso = LinearRegression().fit(y_test_proba_iso.reshape(-1,1), y_test['all'])
cal_intercept_iso, cal_slope_iso = cal_model_iso.intercept_, cal_model_iso.coef_[0]
brier_iso = brier_score_loss(y_test['all'], y_test_proba_iso)
U_iso, p_iso = calc_unreliability(y_test['all'], y_test_proba_iso)

print(f"Before Calibration:\nCalibration Intercept: {cal_intercept:.4f}\nCalibration Slope: {cal_slope:.4f}\nBrier Score: {brier_before:.4f}\nUnreliability Index: {U_before:.4f}\nUnreliability p-value: {p_before:.4f}")
print(f"\nAfter Spline Calibration:\nCalibration Intercept: {cal_intercept_after:.4f}\nCalibration Slope: {cal_slope_after:.4f}\nBrier Score: {brier_after:.4f}\nUnreliability Index: {U_after:.4f}\nUnreliability p-value: {p_after:.4f}")
print(f"\nAfter Isotonic Regression:\nCalibration Intercept: {cal_intercept_iso:.4f}\nCalibration Slope: {cal_slope_iso:.4f}\nBrier Score: {brier_iso:.4f}\nUnreliability Index: {U_iso:.4f}\nUnreliability p-value: {p_iso:.4f}\n")

# Youden threshold and predictions for all groups.
# Always recomputed from the current training labels/probabilities, so it matches
# the current best_model (a cached value would be stale after a refit). The value
# is still written to youden_filename.
youden_filename = f"Youden_{model_type}_{model_name}.pkl"
Youden = youden(y_train, y_train_proba)
with open(youden_filename, "wb") as f:
    pickle.dump(Youden, f)
# NOTE(review): Youden is derived from the *uncalibrated* training probabilities
# but is applied below to the spline-*calibrated* test probabilities. The two are
# on different scales unless the calibration is near-identity around the threshold.
# Consider youden(y_train, calib.calibrate(y_train_proba)).
y_pred = {k: (y_test_proba_cal[k] > Youden).astype(int) for k in test_groups}

print(f"Youden: {Youden:.3f}")
# Print n filenames, n unique hids, and n unique opids for each group in the following outputs

# For 'all'
# NOTE(review): the same binary predictions y_pred['all'] (thresholded on the
# spline-calibrated probabilities) are passed to all three evaluations below.
n_filename, n_hid, n_opid = get_n_filename_hid_opid(test_groups['all'])
print(f"<Without Calibration> (n={n_filename}, unique hids={n_hid}, unique opids={n_opid})")
draw_model_evaluation_plots(y_test['all'], y_test_proba['all'], y_pred['all'])
print(f"<With Isotonic Calibration> (n={n_filename}, unique hids={n_hid}, unique opids={n_opid})")
draw_model_evaluation_plots(y_test['all'], y_test_proba_iso, y_pred['all'])
print(f"<With Spline Calibration> (n={n_filename}, unique hids={n_hid}, unique opids={n_opid})")
draw_model_evaluation_plots(y_test['all'], y_test_proba_cal['all'], y_pred['all'])

# Subgroup reporting helper.
def print_subgroup_with_counts(label, key):
    """Print a header with the subgroup's counts, then draw its evaluation plots.

    Reads the module-level ``test_groups``, ``y_test``, ``y_test_proba_cal``
    and ``y_pred`` dicts at ``key``. Plots use the spline-calibrated probabilities.
    """
    n_rows, n_hids, n_opids = get_n_filename_hid_opid(test_groups[key])
    print(f"<Subgroup: {label} (n={n_rows}, unique hids={n_hids}, unique opids={n_opids})>")
    draw_model_evaluation_plots(y_test[key], y_test_proba_cal[key], y_pred[key])

# Anesthesia-type and emergency-operation (emop) subgroups.
print_subgroup_with_counts("General Anesthesia", "gen")
print_subgroup_with_counts("Regional Anesthesia", "reg")
print_subgroup_with_counts("EMOP=1", "emop1")
print_subgroup_with_counts("EMOP=0", "emop0")

# Only print and plot gender subgroups if 'gender' in df_test
if 'gender' in df_test.columns:
    print_subgroup_with_counts("Gender=1", "gender1")
    print_subgroup_with_counts("Gender=0", "gender0")

# Only print and plot elderly subgroups if 'age' in df_test
if 'age' in df_test.columns:
    print_subgroup_with_counts("Elderly (age > 60)", "elderly1")
    print_subgroup_with_counts("Not Elderly (age <= 60)", "elderly0")

# --- Subgroup analysis for new_dept ---
from sklearn.utils import resample
from scipy.stats import ttest_ind, f_oneway
import itertools
import statsmodels.stats.multitest as smm

def bootstrap_auroc(y, y_proba, n_bootstrap=4000, random_state=42):
    """Return `n_bootstrap` bootstrap replicates of the AUROC of `y_proba` vs `y`.

    Each replicate resamples len(y) rows with replacement. Replicates where
    roc_auc_score raises ValueError (e.g. only one class drawn) are NaN.

    Args:
        y: binary labels (array-like; pandas Series are used positionally).
        y_proba: scores for the positive class, same length as `y`.
        n_bootstrap: number of resamples.
        random_state: seed for the resampling generator.

    Returns:
        1-D float ndarray of length `n_bootstrap`.
    """
    # Positional indexing: y[idx] on a pandas Series would be label-based.
    y = np.asarray(y)
    y_proba = np.asarray(y_proba)
    # A private RandomState gives exactly the draws that
    # np.random.seed(random_state) + np.random.choice produced, without
    # reseeding NumPy's global RNG as a side effect for the rest of the notebook.
    rng = np.random.RandomState(random_state)
    n = len(y)
    aucs = np.empty(n_bootstrap, dtype=float)
    for b in range(n_bootstrap):
        idx = rng.choice(n, n, replace=True)
        try:
            aucs[b] = roc_auc_score(y[idx], y_proba[idx])
        except ValueError:
            aucs[b] = np.nan
    return aucs

# Bootstrapped AUROC per new_dept subgroup. Subgroups with fewer than 10 rows
# or a single class get an all-NaN placeholder of the same length (4000).
auroc_bootstrap = {}
for dept in new_dept_categories:
    key = f'new_dept_{dept}'
    y, y_proba = y_test[key], y_test_proba_cal[key]
    if len(np.unique(y)) >= 2 and len(y) >= 10:
        auroc_bootstrap[dept] = bootstrap_auroc(y, y_proba, n_bootstrap=4000)
    else:
        print(f"Skipping {dept} (not enough samples or only one class present)")
        auroc_bootstrap[dept] = np.full(4000, np.nan)

# Evaluation plots for each new_dept subgroup that has at least 10 rows and
# both classes; the others are reported as skipped.
print("\n[NEW_DEPT SUBGROUP MODEL EVALUATION]")
for dept in new_dept_categories:
    key = f'new_dept_{dept}'
    y, y_proba, y_pred_dept = y_test[key], y_test_proba_cal[key], y_pred[key]
    n_filename, n_hid, n_opid = get_n_filename_hid_opid(test_groups[key])
    if len(np.unique(y)) >= 2 and len(y) >= 10:
        print(f"<Subgroup: new_dept={dept} (n={n_filename}, unique hids={n_hid}, unique opids={n_opid})>")
        draw_model_evaluation_plots(y, y_proba, y_pred_dept)
    else:
        print(f"Skipping {dept} (not enough samples or only one class present)")

# One-way ANOVA across all new_dept subgroups (only those with valid AUROC).
# Single-class bootstrap resamples produce NaN AUROCs. Drop them before testing:
# one NaN makes f_oneway / ttest_ind return NaN, which would silently fall
# through the `anova_p < 0.05` check below as "not significant".
# NOTE(review): bootstrap replicates are not independent observations, so these
# p-values shrink as n_bootstrap grows; treat them as descriptive only.
auroc_valid = {
    dept: auroc_bootstrap[dept][~np.isnan(auroc_bootstrap[dept])]
    for dept in new_dept_categories
}
valid_depts = [dept for dept in new_dept_categories if auroc_valid[dept].size > 0]
anova_data = [auroc_valid[dept] for dept in valid_depts]
if len(anova_data) >= 2:
    anova_stat, anova_p = f_oneway(*anova_data)
else:
    # f_oneway needs at least two groups.
    anova_stat, anova_p = np.nan, np.nan
print("\n[NEW_DEPT SUBGROUP AUROC BOOTSTRAP]")
for dept in valid_depts:
    print(f"{dept}: mean AUROC={np.mean(auroc_valid[dept]):.3f} (n={len(auroc_valid[dept])})")
print(f"One-way ANOVA p-value: {anova_p:.4f}")

# If ANOVA p < 0.05, do pairwise Welch t-tests with Bonferroni correction
if not np.isfinite(anova_p):
    print("ANOVA p-value not available (fewer than two valid subgroups or degenerate data), no pairwise t-tests performed.")
elif anova_p < 0.05:
    print("ANOVA significant, performing pairwise t-tests (Bonferroni corrected):")
    pairs = list(itertools.combinations(valid_depts, 2))
    ttest_pvals = []
    ttest_results = []
    for d1, d2 in pairs:
        # Groups in valid_depts are non-empty and NaN-free by construction.
        t_stat, p_val = ttest_ind(auroc_valid[d1], auroc_valid[d2], equal_var=False)
        ttest_pvals.append(p_val)
        ttest_results.append((d1, d2, t_stat, p_val))
    # Bonferroni correction over all pairwise comparisons
    reject, pvals_corrected, _, _ = smm.multipletests(ttest_pvals, alpha=0.05, method='bonferroni')
    for (d1, d2, t_stat, p_val), p_corr, is_sig in zip(ttest_results, pvals_corrected, reject):
        print(f"{d1} vs {d2}: t={t_stat:.3f}, raw p={p_val:.4g}, corrected p={p_corr:.4g}, significant={is_sig}")
else:
    print("ANOVA not significant, no pairwise t-tests performed.")

# --- Additional: t-test between bootstrapped AUROC in other subgroup categories (anetype, emop, age, sex, ...) ---

def print_ttest_bootstrap_auroc(subgroup1, subgroup2, label1, label2, n_bootstrap=4000):
    """Print a Welch t-test between the bootstrapped AUROCs of two test subgroups.

    Reads labels from the module-level `y_test` and calibrated probabilities
    from `y_test_proba_cal`, both keyed by subgroup name. The comparison is
    skipped when either subgroup has fewer than 10 rows or only one class.

    Args:
        subgroup1, subgroup2: keys into `y_test` / `y_test_proba_cal`.
        label1, label2: display names used in the printed line.
        n_bootstrap: bootstrap replicates per subgroup (default 4000, as before).
    """
    y1, y2 = y_test[subgroup1], y_test[subgroup2]
    proba1, proba2 = y_test_proba_cal[subgroup1], y_test_proba_cal[subgroup2]
    if len(np.unique(y1)) < 2 or len(y1) < 10 or len(np.unique(y2)) < 2 or len(y2) < 10:
        print(f"Skipping {label1} vs {label2} (not enough samples or only one class present)")
        return
    aucs1 = bootstrap_auroc(y1, proba1, n_bootstrap=n_bootstrap)
    aucs2 = bootstrap_auroc(y2, proba2, n_bootstrap=n_bootstrap)
    # Single-class resamples yield NaN; a single NaN would make ttest_ind return NaN.
    aucs1 = aucs1[~np.isnan(aucs1)]
    aucs2 = aucs2[~np.isnan(aucs2)]
    if aucs1.size == 0 or aucs2.size == 0:
        print(f"Skipping {label1} vs {label2} (no valid bootstrap AUROC replicates)")
        return
    # NOTE(review): bootstrap replicates are not independent observations, so this
    # p-value depends on n_bootstrap; treat it as descriptive only.
    t_stat, p_val = ttest_ind(aucs1, aucs2, equal_var=False)
    print(f"AUROC bootstrap t-test: {label1} vs {label2}: t={t_stat:.3f}, p={p_val:.4g}, mean1={aucs1.mean():.3f}, mean2={aucs2.mean():.3f}")

print("\n[SUBGROUP AUROC BOOTSTRAP T-TESTS]")

# (key1, key2, label1, label2, always_run): anesthesia type and EMOP are always
# compared; gender and elderly only when both subgroup keys exist in y_test.
for _k1, _k2, _l1, _l2, _always in (
    ('gen', 'reg', 'General', 'Regional', True),
    ('emop1', 'emop0', 'EMOP=1', 'EMOP=0', True),
    ('gender1', 'gender0', 'Gender=1', 'Gender=0', False),
    ('elderly1', 'elderly0', 'Elderly (age>60)', 'Not Elderly (age<=60)', False),
):
    if _always or (_k1 in y_test and _k2 in y_test):
        print_ttest_bootstrap_auroc(_k1, _k2, _l1, _l2)

# Persist the threshold plus this model's test labels / calibrated probabilities.
for _out_path, _payload in (
    (f'Youden_{model_name}.pkl', Youden),
    (f'y_test_{model_type}_{model_name}.pkl', y_test['all']),
    (f'y_test_proba_{model_type}_{model_name}.pkl', y_test_proba_cal['all']),
):
    with open(_out_path, 'wb') as _fh:
        pickle.dump(_payload, _fh)

# Reference probabilities of the multimodal XGB model for paired comparison.
with open('y_test_proba_xgb_age_gender_p_o_a.pkl', 'rb') as f:
    y_test_proba_xgb_age_gender_p_o_a = pickle.load(f)

# NOTE(review): the widely used delong_roc_test implementation returns log10(p);
# if the project's `delong` module is that one, report 10**value instead.
_delong_val = delong_roc_test(y_test['all'], y_test_proba_xgb_age_gender_p_o_a, y_test_proba_cal['all'])[0][0]
print(f"\nDelong's P: {_delong_val:.3f}")
_auprc_val = auprc_test(y_test['all'], y_test_proba_xgb_age_gender_p_o_a, y_test_proba_cal['all'])
print(f"AUPRC (Paired T-test): {_auprc_val:.3f}\n")

## Multimodal GBM, Baseline GBM, ECG GBM, ASA, RCRI

In [None]:
df = pd.read_csv('data_labels_gbm_hidden_icd.csv')
df_catcol = ['gender','p','o','a'] #'gender','asa','p','o','a'
df_numcol = ['age'] #'age','andur'
df_numcol.extend([f'hidden_{i}' for i in range(HIDDEN_SIZE)])
labels = 'label'
model_name = 'age_gender_p_o_a'

# One-hot encode p / o / a, then keep the model inputs plus label, hid,
# filename and the ASA score.
onehot_cols = ['p', 'o', 'a']
onehot_df = pd.get_dummies(df[onehot_cols], prefix=onehot_cols).astype(float)
df = df.drop(columns=onehot_cols)
df = pd.concat([df, onehot_df], axis=1)
df['gender'] = df['gender'].astype(float)
df = df[['gender'] + list(onehot_df.columns) + df_numcol + ['label', 'hid', 'filename', 'asa']].copy()
# Label value 3 is merged into the positive class.
df['label'] = df['label'].replace({3: 1})

# Attach RCRI via filename -> opid -> rcri_score; keep only rows with an RCRI value.
rcri = pd.read_csv('rcri_missing_2.csv')
opid = pd.read_csv('opid_icd_matching.csv')
df['opid'] = df['filename'].map(dict(zip(opid['filename'], opid['opid'])))
df['rcri'] = df['opid'].map(dict(zip(rcri['opid'], rcri['rcri_score'])))
df = df[df['rcri'].notna()]

# Reuse the saved development (folds 1-5) / test filename split.
test_filenames = np.load('test_filenames.npy', allow_pickle=True)
fold_filenames = [np.load(f'fold_{i}_filenames.npy', allow_pickle=True) for i in range(1, 6)]
develop_filenames = np.concatenate(fold_filenames)
# drop() returns new frames, so the column assignments below do not act on a
# slice of `df` (in-place edits of a boolean-mask slice are chained assignment).
df_train = df[df['filename'].isin(develop_filenames)].drop(columns=['filename'])
df_test = df[df['filename'].isin(test_filenames)].drop(columns=['filename'])

# NOTE(review): if these pickles hold plain arrays they are assigned by position,
# so they must have been saved for exactly these test rows in this order. The
# RCRI filter above may drop rows relative to the set they were computed on —
# confirm the lengths and order match.
with open('y_test_proba_xgb_age_gender_p_o_a.pkl', 'rb') as f:
    df_test['y_test_proba_xgb_age_gender_p_o_a'] = pickle.load(f)
with open('y_test_proba_xgb_wo_cnn_age_gender_p_o_a.pkl', 'rb') as f:
    df_test['y_test_proba_xgb_wo_cnn_age_gender_p_o_a'] = pickle.load(f)

# NOTE(review): the widely used delong_roc_test implementation returns log10(p);
# if the project's `delong` module is that one, report 10**value instead.

#ASA: raw ASA score as the risk score, thresholded at its training-set Youden cut-off
print("<ASA>")
Youden_ASA = youden(df_train['label'], df_train['asa'])
print(f"Youden ASA: {Youden_ASA}")
df_test['asa_pred'] = (df_test['asa'] >= Youden_ASA).astype(int)
draw_model_evaluation_plots(df_test['label'].to_numpy(), df_test['asa'].to_numpy(), df_test['asa_pred'].to_numpy())
print(f"ASA (Delong's P): {delong_roc_test(df_test['label'], df_test['y_test_proba_xgb_age_gender_p_o_a'], df_test['asa'])[0][0]:.3f}")
print(f"ASA (AUPRC): {auprc_test(df_test['label'], df_test['y_test_proba_xgb_age_gender_p_o_a'], df_test['asa']):.3f}\n")
#RCRI: raw RCRI score as the risk score, thresholded at its training-set Youden cut-off
print("<RCRI>")
Youden_RCRI = youden(df_train['label'], df_train['rcri'])
print(f"Youden RCRI: {Youden_RCRI}")
df_test['rcri_pred'] = (df_test['rcri'] >= Youden_RCRI).astype(int)
draw_model_evaluation_plots(df_test['label'].to_numpy(), df_test['rcri'].to_numpy(), df_test['rcri_pred'].to_numpy())
print(f"RCRI (Delong's P): {delong_roc_test(df_test['label'], df_test['y_test_proba_xgb_age_gender_p_o_a'], df_test['rcri'])[0][0]:.3f}")
print(f"RCRI (AUPRC): {auprc_test(df_test['label'], df_test['y_test_proba_xgb_age_gender_p_o_a'], df_test['rcri']):.3f}\n")

# Open the figure that the ROC overlays drawn below are plotted into.
plt.figure(figsize=(8, 8))

def plot_roc_with_ci(y_true, y_score, color, label, n_bootstraps=1000, seed=42):
    """Plot a ROC curve, its AUROC CI in the legend, and a bootstrap TPR band.

    The legend CI is the "AUROC" entry returned by `plot_model.normal_ci`
    (NaN-NaN when absent). The shaded band is the 2.5-97.5 percentile range of
    TPR over `n_bootstraps` resamples, each interpolated on a 100-point FPR grid
    with TPR pinned to 0 at FPR = 0. Resamples containing a single class are
    skipped.

    Args:
        y_true: binary labels (pandas Series or NumPy array).
        y_score: positive-class scores aligned with `y_true`.
        color: matplotlib colour for both the curve and the band.
        label: legend prefix.
        n_bootstraps: number of bootstrap resamples for the band.
        seed: seed for the resampling RandomState.
    """
    from plot_model import normal_ci  # project helper, imported at call time

    def pick(values, rows):
        # Positional row selection for pandas objects and arrays alike.
        return values.iloc[rows] if hasattr(values, "iloc") else values[rows]

    fpr, tpr, _ = roc_curve(y_true, y_score)
    roc_auc = auc(fpr, tpr)
    # normal_ci also takes hard predictions; a fixed 0.5 cut-off builds them.
    score_arr = np.array(y_score)
    ci_metrics, _ = normal_ci(np.array(y_true), score_arr, (score_arr >= 0.5).astype(int))
    auroc_ci = ci_metrics.get("AUROC", (np.nan, np.nan))
    legend_text = f'{label} (AUC = {roc_auc:.3f}, 95% CI: {auroc_ci[0]:.3f}-{auroc_ci[1]:.3f})'
    plt.plot(fpr, tpr, color=color, lw=2, label=legend_text)

    rng = np.random.RandomState(seed)
    grid = np.linspace(0, 1, 100)
    n_rows = len(y_true)
    curves = []
    for _ in range(n_bootstraps):
        rows = rng.randint(0, n_rows, n_rows)
        truth = pick(y_true, rows)
        if len(np.unique(truth)) < 2:
            continue  # ROC needs both classes
        fpr_b, tpr_b, _ = roc_curve(truth, pick(y_score, rows))
        curve = np.interp(grid, fpr_b, tpr_b)
        curve[0] = 0.0
        curves.append(curve)
    if curves:
        stacked = np.array(curves)
        plt.fill_between(
            grid,
            np.percentile(stacked, 2.5, axis=0),
            np.percentile(stacked, 97.5, axis=0),
            color=color,
            alpha=0.2,
        )

# Overlay test-set ROC curves (each with a bootstrap band) on the figure opened
# just above: the two GBMs' stored probabilities and the raw RCRI / ASA scores.
# NOTE(review): plot_roc_with_ci builds hard predictions with a 0.5 cut-off for
# normal_ci; for raw RCRI/ASA scores that cut-off is not meaningful — confirm
# normal_ci's AUROC CI does not depend on those predictions.
# Multimodal GBM
plot_roc_with_ci(df_test['label'], df_test['y_test_proba_xgb_age_gender_p_o_a'], color='blue', label='Multimodal GBM')
# Baseline GBM
plot_roc_with_ci(df_test['label'], df_test['y_test_proba_xgb_wo_cnn_age_gender_p_o_a'], color='orange', label='Baseline GBM')
# RCRI
plot_roc_with_ci(df_test['label'], df_test['rcri'], color='red', label='RCRI')
# ASA
plot_roc_with_ci(df_test['label'], df_test['asa'], color='purple', label='ASA')

# Chance diagonal, axes, legend and grid for the combined ROC figure.
plt.plot([0, 1], [0, 1], 'k--', lw=1)
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.legend(loc='lower right')
plt.grid(True, alpha=0.3)
plt.show()

## Multimodal GBM, TnI

In [None]:
df = pd.read_csv('data_labels_gbm_hidden_icd.csv')
# .copy(): the filtered frame is modified below via column assignment.
df = df.loc[df['tni'].notna()].copy()

labels = 'label'
model_name = 'age_gender_p_o_a'  # wo_cnn_, mach_/ age,andur, asa, gender, p_o_a
df['gender'] = df['gender'].astype(float)
# Tabular covariates are switched on by substrings of model_name.
df_col = [x for x in ['gender', 'asa', 'age', 'andur'] if x in model_name]
if 'wo_cnn' not in model_name:
    if 'mach_' in model_name:
        df_col.extend([col for col in df.columns if 'mach_' in col])
    else:
        df_col.extend([f'hidden_{i}' for i in range(HIDDEN_SIZE)])
if 'p_o_a' in model_name:
    onehot_cols = ['p', 'o', 'a']
    onehot_df = pd.get_dummies(df[onehot_cols], prefix=onehot_cols).astype(float)
    df = df.drop(columns=onehot_cols)
    df = pd.concat([df, onehot_df], axis=1)
    df = df[df_col + list(onehot_df.columns) + ['label', 'hid', 'filename']].copy()
else:
    df = df[df_col + ['label', 'hid', 'filename']].copy()
# Label value 3 is merged into the positive class.
df['label'] = df['label'].replace({3: 1})

# Re-attach TnI (removed by the column selection above) via filename -> opid -> pre_troponin_I.
tni = pd.read_csv('opid_troponin_orig.csv')
opid = pd.read_csv('opid_icd_matching.csv')
df['opid'] = df['filename'].map(dict(zip(opid['filename'], opid['opid'])))
df['tni'] = df['opid'].map(dict(zip(tni['opid'], tni['pre_troponin_I'])))

# Train/test split grouped by `hid`, so no hid appears on both sides.
unique_hids = df['hid'].unique()
train_hids, test_hids = train_test_split(unique_hids, test_size=0.1, random_state=23, shuffle=True)

df_train = df[df['hid'].isin(train_hids)].copy()
df_test = df[df['hid'].isin(test_hids)].copy()

test_filenames = np.array(df_test['filename'].unique())

# TnI-only training data and the test set are restricted to rows with a mapped
# TnI value; df_train (used by both GBMs) is not restricted by this filter.
df = df[df['tni'].notna()]
df_train_tni = df.loc[~df['filename'].isin(test_filenames)]
# .copy(): df_test is modified in place (drop) below.
df_test = df.loc[df['filename'].isin(test_filenames)].copy()

selected_features = np.load('selected_features.npy', allow_pickle=True)
with open(f"best_params_xgb_{model_name}.json", 'r') as f:
    params = json.load(f)

# Model 1: GBM-TnI (without tni)
best_model = xgb.XGBClassifier(**params)
X_train = df_train.drop(['label', 'hid', 'filename', 'tni', 'opid'], axis=1)
X_train = X_train[selected_features]
y_train = df_train['label'].to_numpy()
best_model.fit(X_train, y_train)
y_train_proba = best_model.predict_proba(X_train)[:,1].ravel()

# Model 2: GBM-TnI + tni as input
best_model_tni = xgb.XGBClassifier(**params)
X_train_tni = df_train.drop(['label', 'hid', 'filename', 'opid'], axis=1)  # keep tni
X_train_tni = X_train_tni[np.append(selected_features, 'tni')]
best_model_tni.fit(X_train_tni, y_train)
y_train_proba_tni = best_model_tni.predict_proba(X_train_tni)[:,1].ravel()

# Model 3: TnI only
y_train_tni = df_train_tni['label'].to_numpy()
df_train.drop(columns=['filename'], inplace=True)
df_test.drop(columns=['filename'], inplace=True)

print(f"Number of test set: {len(df_test)}")

# Prepare test sets
X_test = df_test.drop(['label', 'hid','tni','opid'], axis=1)
X_test_tni = df_test.drop(['label', 'hid','opid'], axis=1)  # keep tni
X_test = X_test[selected_features]
X_test_tni = X_test_tni[np.append(selected_features, 'tni')]
y_test = df_test['label'].to_numpy()
y_test_proba = best_model.predict_proba(X_test)[:,1].ravel()
y_test_proba_tni = best_model_tni.predict_proba(X_test_tni)[:,1].ravel()

# TnI only: spline-calibrate the raw troponin value into a probability
calib_tni = mli.SplineCalib(unity_prior=False, unity_prior_weight=100, random_state=42, max_iter=500)
calib_tni.fit(df_train_tni['tni'].to_numpy(), y_train_tni)
y_test_proba_cal_tni = calib_tni.calibrate(df_test['tni'].to_numpy())

# Calibrate both GBM models
calib = mli.SplineCalib(unity_prior=False, unity_prior_weight=100, random_state=42, max_iter=500)
calib.fit(y_train_proba, y_train)
y_test_proba_cal = calib.calibrate(y_test_proba)

calib_gbm_tni = mli.SplineCalib(unity_prior=False, unity_prior_weight=100, random_state=42, max_iter=500)
calib_gbm_tni.fit(y_train_proba_tni, y_train)
y_test_proba_cal_gbm_tni = calib_gbm_tni.calibrate(y_test_proba_tni)

# Youden thresholds must be on the same scale as the scores they are applied
# to. The test predictions below threshold *calibrated* probabilities, so each
# cut-off is derived from the calibrated training scores. Raw XGBoost
# probabilities and raw troponin values are on a different scale.
y_train_proba_cal = calib.calibrate(y_train_proba)
y_train_proba_cal_gbm_tni = calib_gbm_tni.calibrate(y_train_proba_tni)
y_train_proba_cal_tni = calib_tni.calibrate(df_train_tni['tni'].to_numpy())
Youden = youden(y_train, y_train_proba_cal)
Youden_gbm_tni = youden(y_train, y_train_proba_cal_gbm_tni)
Youden_tni = youden(y_train_tni, y_train_proba_cal_tni)

y_pred = (y_test_proba_cal > Youden).astype(int)
y_pred_gbm_tni = (y_test_proba_cal_gbm_tni > Youden_gbm_tni).astype(int)
y_pred_tni = (y_test_proba_cal_tni > Youden_tni).astype(int)

print(f"Youden (GBM-TS): {Youden:.3f}")
print(f"Youden (GBM-TS+TnI): {Youden_gbm_tni:.3f}")
print(f"Youden (TnI): {Youden_tni:.3f}")

print("<GBM-TS>")
draw_model_evaluation_plots(y_test, y_test_proba_cal, y_pred)
print("<GBM-TS+TnI>")
draw_model_evaluation_plots(y_test, y_test_proba_cal_gbm_tni, y_pred_gbm_tni)
print("<Troponin I>")
draw_model_evaluation_plots(y_test, y_test_proba_cal_tni, y_pred_tni, draw=False)

# --- Decision Curve Analysis (DCA) for GBM-TS, GBM-TS+TnI, and TnI only ---

# Net Benefit calculation function
def decision_curve_analysis(y_true, prob_model, thresholds=np.linspace(0.001, 0.5, 1000)):
    """Net benefit of treating cases whose predicted risk is >= each threshold.

    Net benefit at threshold p_t is TP/n - FP/n * p_t / (1 - p_t).

    Args:
        y_true: binary labels (array-like; pandas objects are used positionally).
        prob_model: predicted probabilities, same length as `y_true`.
        thresholds: threshold probabilities, each in [0, 1).

    Returns:
        (thresholds, net_benefit) as 1-D float ndarrays of equal length.

    Raises:
        ValueError: if the inputs are empty or their lengths differ.
    """
    # Converting to arrays makes the element-wise `&` below positional. Two pandas
    # Series with different indexes would otherwise be aligned by label.
    y_true = np.asarray(y_true).ravel()
    prob_model = np.asarray(prob_model, dtype=float).ravel()
    thresholds = np.asarray(thresholds, dtype=float)
    if y_true.shape != prob_model.shape:
        raise ValueError("y_true and prob_model must have the same length")
    n = len(y_true)
    if n == 0:
        raise ValueError("y_true is empty")
    positives = y_true == 1
    negatives = y_true == 0
    net_benefit = np.empty(thresholds.shape, dtype=float)
    for i, thresh in enumerate(thresholds):
        treated = prob_model >= thresh
        tp = np.count_nonzero(treated & positives)
        fp = np.count_nonzero(treated & negatives)
        net_benefit[i] = (tp / n) - (fp / n) * (thresh / (1 - thresh))
    return thresholds, net_benefit

# Treat All Net Benefit calculation (skip values below -0.005)
def treat_all_net_benefit(y_true, thresholds):
    """Net benefit of the 'treat everyone' strategy at each threshold.

    Equals prevalence - (1 - prevalence) * p_t / (1 - p_t). Values below
    -0.005 (or NaN) are returned as NaN so they are not drawn.
    """
    prev = np.mean(y_true)
    odds = thresholds / (1 - thresholds)
    benefit = prev - (1 - prev) * odds
    keep = benefit >= -0.005
    return np.where(keep, benefit, np.nan)

# DCA for GBM-TS, GBM-TS+TnI, and TnI only
# Net benefit is evaluated on the test set over 1000 thresholds in [0.001, 0.5].
thresholds = np.linspace(0.001, 0.5, 1000)
y_true = y_test

# Use calibrated probabilities
probs_gbm = y_test_proba_cal
probs_gbm_tni = y_test_proba_cal_gbm_tni
probs_tni = y_test_proba_cal_tni

# Calculate net benefit for each model plus the treat-all / treat-none references
thresholds, nb_gbm = decision_curve_analysis(y_true, probs_gbm, thresholds)
_, nb_gbm_tni = decision_curve_analysis(y_true, probs_gbm_tni, thresholds)
_, nb_tni = decision_curve_analysis(y_true, probs_tni, thresholds)
nb_treat_all = treat_all_net_benefit(y_true, thresholds)
nb_treat_none = np.zeros_like(thresholds)

# Plot DCA (plot order here is the legend order)
plt.figure(figsize=(8, 6))
plt.plot(thresholds, nb_gbm, color='#0044cc', linestyle='-', linewidth=2.5, label='GBM-TS')
plt.plot(thresholds, nb_gbm_tni, color='#009933', linestyle='--', linewidth=2.5, label='GBM-TS+TnI')
plt.plot(thresholds, nb_tni, color='#ff9900', linestyle='-.', linewidth=2.5, label='TnI only')
plt.plot(thresholds, nb_treat_all, color='black', linestyle='-', linewidth=2.5, label='Treat All')
plt.plot(thresholds, nb_treat_none, color='dimgray', linestyle=(0, (3, 3, 1, 3)), linewidth=2.5, label='Treat None')

plt.xlabel('Threshold Probability', fontsize=13)
plt.ylabel('Net Benefit', fontsize=13)
plt.title('Decision Curve Analysis', fontsize=15)
plt.legend(loc='lower right', fontsize=12, frameon=True, facecolor='white', edgecolor='black')
plt.grid(True, alpha=0.3)
plt.xlim([0, 0.5])
# Lower y-limit: just below the smallest finite net benefit (treat-all NaNs ignored).
plt.ylim(bottom=np.nanmin([nb_gbm, nb_gbm_tni, nb_tni, nb_treat_all])-0.01)
plt.tight_layout()
plt.show()

# --- End DCA ---

def bootstrap_roc_ci(y_true, y_score, n_bootstraps=1000, seed=42, fpr_grid=None):
    """Bootstrap percentile 95% intervals for a ROC curve and its AUROC.

    Each resample draws len(y_true) rows with replacement. Resamples that
    contain only one class are skipped.

    Args:
        y_true: binary labels (any array-like; used positionally).
        y_score: positive-class scores, same length as `y_true`.
        n_bootstraps: number of resamples to draw.
        seed: seed for the resampling RandomState.
        fpr_grid: FPR points at which TPR is interpolated
            (default: 100 evenly spaced points in [0, 1]).

    Returns:
        (fpr_grid, tpr_mean, tpr_lower, tpr_upper, auc_mean, auc_lower, auc_upper),
        where the bounds are the 2.5th / 97.5th percentiles over resamples.

    Raises:
        ValueError: if the inputs are empty or differ in length, or if no
            resample contains both classes.
    """
    # Positional indexing below requires arrays (lists / Series would fail or
    # index by label).
    y_true = np.asarray(y_true)
    y_score = np.asarray(y_score)
    if len(y_true) == 0 or len(y_true) != len(y_score):
        raise ValueError("y_true and y_score must be non-empty and of equal length")
    rng = np.random.RandomState(seed)
    bootstrapped_tprs = []
    bootstrapped_aucs = []
    if fpr_grid is None:
        fpr_grid = np.linspace(0, 1, 100)
    n = len(y_true)
    for _ in range(n_bootstraps):
        # bootstrap by sampling with replacement
        indices = rng.randint(0, n, n)
        if len(np.unique(y_true[indices])) < 2:
            # We need at least one positive and one negative sample for ROC
            continue
        fpr, tpr, _ = roc_curve(y_true[indices], y_score[indices])
        bootstrapped_aucs.append(auc(fpr, tpr))
        # Interpolate tpr at fpr_grid
        bootstrapped_tprs.append(np.interp(fpr_grid, fpr, tpr))
    if not bootstrapped_aucs:
        raise ValueError("no bootstrap resample contained both classes; cannot compute ROC CI")
    bootstrapped_tprs = np.array(bootstrapped_tprs)
    tpr_mean = np.mean(bootstrapped_tprs, axis=0)
    tpr_lower = np.percentile(bootstrapped_tprs, 2.5, axis=0)
    tpr_upper = np.percentile(bootstrapped_tprs, 97.5, axis=0)
    auc_mean = np.mean(bootstrapped_aucs)
    auc_lower = np.percentile(bootstrapped_aucs, 2.5)
    auc_upper = np.percentile(bootstrapped_aucs, 97.5)
    return fpr_grid, tpr_mean, tpr_lower, tpr_upper, auc_mean, auc_lower, auc_upper

# Compute ROC and bootstrap CIs for all three models
# Legend AUROCs are point estimates on the full test set; the drawn lines are
# the bootstrap-mean TPR curves, and the CIs / bands are bootstrap percentiles.
fpr_grid = np.linspace(0, 1, 100)
fpr_gbm, tpr_gbm, _ = roc_curve(y_test, y_test_proba_cal)
roc_auc_gbm = auc(fpr_gbm, tpr_gbm)
fpr_gbm_tni, tpr_gbm_tni, _ = roc_curve(y_test, y_test_proba_cal_gbm_tni)
roc_auc_gbm_tni = auc(fpr_gbm_tni, tpr_gbm_tni)
fpr_tni, tpr_tni, _ = roc_curve(y_test, y_test_proba_cal_tni)
roc_auc_tni = auc(fpr_tni, tpr_tni)

# Bootstrapping for 95% CI
fpr_grid, tpr_gbm_mean, tpr_gbm_lower, tpr_gbm_upper, auc_gbm_mean, auc_gbm_lower, auc_gbm_upper = bootstrap_roc_ci(
    np.array(y_test), np.array(y_test_proba_cal), n_bootstraps=1000, fpr_grid=fpr_grid
)
_, tpr_gbm_tni_mean, tpr_gbm_tni_lower, tpr_gbm_tni_upper, auc_gbm_tni_mean, auc_gbm_tni_lower, auc_gbm_tni_upper = bootstrap_roc_ci(
    np.array(y_test), np.array(y_test_proba_cal_gbm_tni), n_bootstraps=1000, fpr_grid=fpr_grid
)
_, tpr_tni_mean, tpr_tni_lower, tpr_tni_upper, auc_tni_mean, auc_tni_lower, auc_tni_upper = bootstrap_roc_ci(
    np.array(y_test), np.array(y_test_proba_cal_tni), n_bootstraps=1000, fpr_grid=fpr_grid
)

# One curve + shaded band per model (plot order is the legend order).
plt.figure(figsize=(8, 6))
plt.plot(
    fpr_grid, tpr_gbm_mean, color='blue', lw=2,
    label=f"GBM-TS (AUROC = {roc_auc_gbm:.3f}, 95% CI: {auc_gbm_lower:.3f}-{auc_gbm_upper:.3f})"
)
plt.fill_between(fpr_grid, tpr_gbm_lower, tpr_gbm_upper, color='blue', alpha=0.2)
plt.plot(
    fpr_grid, tpr_gbm_tni_mean, color='green', lw=2,
    label=f"GBM-TS+TnI (AUROC = {roc_auc_gbm_tni:.3f}, 95% CI: {auc_gbm_tni_lower:.3f}-{auc_gbm_tni_upper:.3f})"
)
plt.fill_between(fpr_grid, tpr_gbm_tni_lower, tpr_gbm_tni_upper, color='green', alpha=0.2)
plt.plot(
    fpr_grid, tpr_tni_mean, color='orange', lw=2,
    label=f"TnI (AUROC = {roc_auc_tni:.3f}, 95% CI: {auc_tni_lower:.3f}-{auc_tni_upper:.3f})"
)
plt.fill_between(fpr_grid, tpr_tni_lower, tpr_tni_upper, color='orange', alpha=0.2)
plt.plot([0, 1], [0, 1], color='gray', lw=1, linestyle='--')
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.legend(loc="lower right")
plt.grid(True)
plt.show()
# Pairwise model comparisons on the same test labels.
# NOTE(review): the widely used delong_roc_test implementation returns log10(p);
# if the project's `delong` module is that one, report 10**value instead.
print(f"Delong's P (GBM-TS vs GBM-TS+TnI): {delong_roc_test(y_test, y_test_proba_cal, y_test_proba_cal_gbm_tni)[0][0]:.3f}")
print(f"Delong's P (GBM-TS vs TnI): {delong_roc_test(y_test, y_test_proba_cal, y_test_proba_cal_tni)[0][0]:.3f}")
print(f"Delong's P (GBM-TS+TnI vs TnI): {delong_roc_test(y_test, y_test_proba_cal_gbm_tni, y_test_proba_cal_tni)[0][0]:.3f}")
print(f"AUPRC P (GBM-TS vs GBM-TS+TnI): {auprc_test(y_test, y_test_proba_cal, y_test_proba_cal_gbm_tni):.3f}")
print(f"AUPRC P (GBM-TS vs TnI): {auprc_test(y_test, y_test_proba_cal, y_test_proba_cal_tni):.3f}")
print(f"AUPRC P (GBM-TS+TnI vs TnI): {auprc_test(y_test, y_test_proba_cal_gbm_tni, y_test_proba_cal_tni):.3f}\n")