In [8]:
import pandas as pd
import numpy as np

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

from sklearn.preprocessing import LabelEncoder, StandardScaler
from sklearn.metrics import (
    roc_auc_score, average_precision_score,
    accuracy_score, f1_score, precision_score, recall_score,
    mean_squared_error, mean_absolute_error
)

# ================================
# 0. 설정
# ================================
RANDOM_STATE = 42
np.random.seed(RANDOM_STATE)
torch.manual_seed(RANDOM_STATE)

DATA_PATH = "./../cohort/cohort_ver151_reorder_col.csv"

MAX_SEQ_LEN = 128
BATCH_SIZE = 64
EPOCHS = 20
LR = 1e-3

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("[INFO] device:", device)

# ================================
# 1. 데이터 로딩
# ================================
df = pd.read_csv(DATA_PATH, low_memory=False)
print("[LOAD] 데이터 로딩 완료")
print(" - shape :", df.shape)

# ================================
# 2. 기본 전처리
#    - race 인코딩
#    - time_to_next_clip / log1p
#    - delay_label 생성 (75% 기준)
# ================================
if "race" in df.columns:
    le_race = LabelEncoder()
    df["race_enc"] = le_race.fit_transform(df["race"].astype(str))
    print("[PREP] race -> race_enc 인코딩 완료")
else:
    print("[WARN] race 컬럼 없음, race_enc 미생성")

if "target_time_to_next" not in df.columns:
    raise ValueError("target_time_to_next 컬럼이 없습니다.")

clip_value = df["target_time_to_next"].quantile(0.995)
df["time_to_next_clip"] = df["target_time_to_next"].clip(upper=clip_value)
df["time_to_next_log1p"] = np.log1p(df["time_to_next_clip"])

if "delay_label" not in df.columns:
    delay_thr = df["time_to_next_clip"].quantile(0.75)
    df["delay_label"] = (df["time_to_next_clip"] > delay_thr).astype(int)
    print(f"[PREP] delay_label 생성 완료 (75% 기준: {delay_thr:.2f})")

# ================================
# 3. hadm_id 기준 Train / Val / Test 분할
# ================================
def split_by_hadm(df_in, random_state=42, train_ratio=0.7, val_ratio=0.15):
    if "hadm_id" not in df_in.columns:
        raise ValueError("hadm_id 컬럼이 없습니다.")

    hadm_ids = df_in["hadm_id"].unique()
    rng = np.random.RandomState(random_state)
    rng.shuffle(hadm_ids)

    n = len(hadm_ids)
    n_train = int(n * train_ratio)
    n_val = int(n * val_ratio)

    hadm_train = hadm_ids[:n_train]
    hadm_val   = hadm_ids[n_train:n_train + n_val]
    hadm_test  = hadm_ids[n_train + n_val:]

    df_train = df_in[df_in["hadm_id"].isin(hadm_train)].copy()
    df_val   = df_in[df_in["hadm_id"].isin(hadm_val)].copy()
    df_test  = df_in[df_in["hadm_id"].isin(hadm_test)].copy()

    print("[SPLIT] hadm_id 기준 분할 완료")
    print(" - 전체 hadm_id:", n)
    print(" - train hadm_id:", len(hadm_train), ", rows:", len(df_train))
    print(" - val   hadm_id:", len(hadm_val),   ", rows:", len(df_val))
    print(" - test  hadm_id:", len(hadm_test),  ", rows:", len(df_test))

    return df_train, df_val, df_test

df_train, df_val, df_test = split_by_hadm(df, random_state=RANDOM_STATE)

# ================================
# 4. Feature Engineering
# ================================
def add_features(df_in: pd.DataFrame) -> pd.DataFrame:
    df_fe = df_in.copy()

    # HR/BP 비율
    if "current_heart_rate" in df_fe.columns and "current_mean_bp" in df_fe.columns:
        df_fe["hr_bp_ratio"] = df_fe["current_heart_rate"] / (df_fe["current_mean_bp"].abs() + 1.0)
    else:
        df_fe["hr_bp_ratio"] = 0.0

    # 마지막 이벤트 시간 / 전체 시간 비율
    if "time_since_last" in df_fe.columns and "time_since_start_min" in df_fe.columns:
        df_fe["time_last_ratio"] = df_fe["time_since_last"] / (df_fe["time_since_start_min"].abs() + 1.0)
    else:
        df_fe["time_last_ratio"] = 0.0

    # pathway 진행도
    if "pathway_stage" in df_fe.columns and "prefix_len" in df_fe.columns:
        df_fe["event_progress"] = df_fe["pathway_stage"] / (df_fe["prefix_len"] + 1.0)
    else:
        df_fe["event_progress"] = 0.0

    # 글로벌 median 기준 delay risk
    if "time_since_last" in df.columns:
        global_median = df["time_since_last"].median()
        df_fe["risk_delay"] = (df_fe["time_since_last"] > global_median).astype(int)
    else:
        df_fe["risk_delay"] = 0

    # STEMI 누적 위험도
    has_stemi = "stemi_flag" in df_fe.columns
    has_cum_stemi = "cum_stemi_cnt" in df_fe.columns
    if has_stemi and has_cum_stemi:
        df_fe["risk_stemi"] = df_fe["stemi_flag"] * df_fe["cum_stemi_cnt"]
    elif has_stemi:
        df_fe["risk_stemi"] = df_fe["stemi_flag"]
    else:
        df_fe["risk_stemi"] = 0

    # Troponin 이상 여부
    if "last_trop" in df_fe.columns:
        df_fe["trop_abnormal"] = (df_fe["last_trop"] > 0.04).astype(int)
    else:
        df_fe["trop_abnormal"] = 0

    return df_fe

df_train = add_features(df_train)
df_val   = add_features(df_val)
df_test  = add_features(df_test)

# ================================
# 5. Transformer 입력 feature 정의
# ================================
candidate_cols = [
    "age", "gender", "race_enc",
    "arrival_transport",
    "prefix_len", "current_event_id",
    "time_since_start_min", "time_since_ed", "time_since_last",
    "is_night",
    "cum_ecg_cnt", "cum_trop_cnt",
    "stemi_flag", "trop_pos_flag",
    "last_trop", "run_max_trop", "trop_trend",
    "pci_status",
    "current_heart_rate", "current_mean_bp",
    "hr_bp_ratio", "time_last_ratio", "event_progress",
    "risk_delay", "risk_stemi", "trop_abnormal",
]

feature_cols = [c for c in candidate_cols if c in df_train.columns]
print("[INFO] Transformer input feature 수:", len(feature_cols))
print("[INFO] Feature 예시:", feature_cols[:10])

num_next_classes = int(df["target_next_evt"].max())
print("[INFO] target_next_evt num_class:", num_next_classes)

# ================================
# 6. Feature / 타깃 정규화 (성능 향상을 위해 추가)
# ================================
# 6-1) 입력 feature 표준화
scaler_X = StandardScaler()
scaler_X.fit(df_train[feature_cols].values)

df_train[feature_cols] = scaler_X.transform(df_train[feature_cols].values)
df_val[feature_cols]   = scaler_X.transform(df_val[feature_cols].values)
df_test[feature_cols]  = scaler_X.transform(df_test[feature_cols].values)

print("[PREP] Feature StandardScaler 적용 완료")

# 6-2) 회귀 타깃 표준화 (time_to_next_log1p, remain_los)
if "target_remain_los" not in df_train.columns:
    raise ValueError("target_remain_los 컬럼이 없습니다.")

ttn_mean = df_train["time_to_next_log1p"].mean()
ttn_std  = df_train["time_to_next_log1p"].std() + 1e-8

los_mean = df_train["target_remain_los"].mean()
los_std  = df_train["target_remain_los"].std() + 1e-8

def add_target_std(df_in):
    df_out = df_in.copy()
    df_out["time_to_next_std"] = (df_out["time_to_next_log1p"] - ttn_mean) / ttn_std
    df_out["remain_los_std"]   = (df_out["target_remain_los"] - los_mean) / los_std
    return df_out

df_train = add_target_std(df_train)
df_val   = add_target_std(df_val)
df_test  = add_target_std(df_test)

print("[PREP] 회귀 타깃 표준화 완료 (time_to_next_std, remain_los_std)")

# ================================
# 7. hadm_id 단위 시퀀스 생성 (5 task)
# ================================
def build_sequences_multi(df_split: pd.DataFrame, feature_cols, max_seq_len: int):
    X_seqs = []
    M_seqs = []
    y_mort_list  = []
    y_next_list  = []
    y_ttn_list   = []
    y_los_list   = []
    y_delay_list = []

    sort_key = "time_since_start_min" if "time_since_start_min" in df_split.columns else "prefix_len"

    for _, g in df_split.groupby("hadm_id"):
        g = g.sort_values(sort_key)

        seq_X = g[feature_cols].values.astype(np.float32)

        y_mort  = g["target_mortality"].values.astype(np.int64)
        y_next  = (g["target_next_evt"].values - 1).astype(np.int64)  # 1~K → 0~K-1
        y_ttn   = g["time_to_next_std"].values.astype(np.float32)
        y_los   = g["remain_los_std"].values.astype(np.float32)
        y_delay = g["delay_label"].values.astype(np.int64)

        T = len(seq_X)
        T_use = min(T, max_seq_len)

        seq_X = seq_X[:T_use]
        y_mort = y_mort[:T_use]
        y_next = y_next[:T_use]
        y_ttn  = y_ttn[:T_use]
        y_los  = y_los[:T_use]
        y_delay = y_delay[:T_use]

        pad_len = max_seq_len - T_use

        if pad_len > 0:
            pad_X = np.zeros((pad_len, seq_X.shape[1]), dtype=np.float32)
            seq_X = np.concatenate([seq_X, pad_X], axis=0)

            pad_int = np.zeros((pad_len,), dtype=np.int64)
            pad_float = np.zeros((pad_len,), dtype=np.float32)

            y_mort  = np.concatenate([y_mort, pad_int], axis=0)
            y_next  = np.concatenate([y_next, pad_int], axis=0)
            y_ttn   = np.concatenate([y_ttn,  pad_float], axis=0)
            y_los   = np.concatenate([y_los,  pad_float], axis=0)
            y_delay = np.concatenate([y_delay, pad_int], axis=0)

        mask = np.zeros((max_seq_len,), dtype=np.float32)
        mask[:T_use] = 1.0

        X_seqs.append(seq_X)
        M_seqs.append(mask)
        y_mort_list.append(y_mort)
        y_next_list.append(y_next)
        y_ttn_list.append(y_ttn)
        y_los_list.append(y_los)
        y_delay_list.append(y_delay)

    X_seqs = np.stack(X_seqs, axis=0)
    M_seqs = np.stack(M_seqs, axis=0)
    y_mort_arr  = np.stack(y_mort_list, axis=0)
    y_next_arr  = np.stack(y_next_list, axis=0)
    y_ttn_arr   = np.stack(y_ttn_list, axis=0)
    y_los_arr   = np.stack(y_los_list, axis=0)
    y_delay_arr = np.stack(y_delay_list, axis=0)

    return X_seqs, M_seqs, y_mort_arr, y_next_arr, y_ttn_arr, y_los_arr, y_delay_arr

X_train_np, M_train_np, y_mort_tr, y_next_tr, y_ttn_tr, y_los_tr, y_delay_tr = build_sequences_multi(df_train, feature_cols, MAX_SEQ_LEN)
X_val_np,   M_val_np,   y_mort_va, y_next_va, y_ttn_va, y_los_va, y_delay_va = build_sequences_multi(df_val,   feature_cols, MAX_SEQ_LEN)
X_test_np,  M_test_np,  y_mort_te, y_next_te, y_ttn_te, y_los_te, y_delay_te = build_sequences_multi(df_test,  feature_cols, MAX_SEQ_LEN)

print("[SEQ] Train:", X_train_np.shape, y_mort_tr.shape)
print("[SEQ] Val  :", X_val_np.shape,   y_mort_va.shape)
print("[SEQ] Test :", X_test_np.shape,  y_mort_te.shape)

# ================================
# 8. Dataset / DataLoader
# ================================
class PPMSeqDataset(Dataset):
    def __init__(self, X, M, y_mort, y_next, y_ttn, y_los, y_delay):
        self.X = X
        self.M = M
        self.y_mort = y_mort
        self.y_next = y_next
        self.y_ttn = y_ttn
        self.y_los = y_los
        self.y_delay = y_delay

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        return (
            torch.from_numpy(self.X[idx]),
            torch.from_numpy(self.M[idx]),
            torch.from_numpy(self.y_mort[idx]),
            torch.from_numpy(self.y_next[idx]),
            torch.from_numpy(self.y_ttn[idx]),
            torch.from_numpy(self.y_los[idx]),
            torch.from_numpy(self.y_delay[idx]),
        )

train_dataset = PPMSeqDataset(X_train_np, M_train_np, y_mort_tr, y_next_tr, y_ttn_tr, y_los_tr, y_delay_tr)
val_dataset   = PPMSeqDataset(X_val_np,   M_val_np,   y_mort_va, y_next_va, y_ttn_va, y_los_va, y_delay_va)
test_dataset  = PPMSeqDataset(X_test_np,  M_test_np,  y_mort_te, y_next_te, y_ttn_te, y_los_te, y_delay_te)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader   = DataLoader(val_dataset,   batch_size=BATCH_SIZE, shuffle=False)
test_loader  = DataLoader(test_dataset,  batch_size=BATCH_SIZE, shuffle=False)

# ================================
# 9. Transformer 모델 정의 (5-task head, 모델 용량 살짝 증가)
# ================================
class PPMTransformer(nn.Module):
    def __init__(self, input_dim, num_next_classes, d_model=128, nhead=4, num_layers=3, dim_ff=256, max_len=MAX_SEQ_LEN):
        super().__init__()
        self.input_proj = nn.Linear(input_dim, d_model)
        self.pos_embedding = nn.Embedding(max_len, d_model)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_ff,
            batch_first=True,
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        # Task-specific heads
        self.mort_head  = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.ReLU(),
            nn.Linear(d_model, 1)
        )
        self.delay_head = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.ReLU(),
            nn.Linear(d_model, 1)
        )
        self.next_head  = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.ReLU(),
            nn.Linear(d_model, num_next_classes)
        )
        self.ttn_head   = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.ReLU(),
            nn.Linear(d_model, 1)
        )
        self.los_head   = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.ReLU(),
            nn.Linear(d_model, 1)
        )

    def forward(self, x, mask):
        B, T, _ = x.size()
        h = self.input_proj(x)

        positions = torch.arange(T, device=x.device).unsqueeze(0).expand(B, T)
        pos_emb = self.pos_embedding(positions)
        h = h + pos_emb

        src_key_padding_mask = (mask == 0)
        h_enc = self.encoder(h, src_key_padding_mask=src_key_padding_mask)

        mort_logits  = self.mort_head(h_enc).squeeze(-1)
        delay_logits = self.delay_head(h_enc).squeeze(-1)
        next_logits  = self.next_head(h_enc)
        ttn_pred     = self.ttn_head(h_enc).squeeze(-1)
        los_pred     = self.los_head(h_enc).squeeze(-1)

        return {
            "mort_logits": mort_logits,
            "delay_logits": delay_logits,
            "next_logits": next_logits,
            "ttn_pred": ttn_pred,
            "los_pred": los_pred,
        }

input_dim = len(feature_cols)
model = PPMTransformer(input_dim=input_dim, num_next_classes=num_next_classes).to(device)

# ================================
# 10. loss 함수 및 optimizer
#    - 사망/지연: pos_weight BCEWithLogits
#    - 회귀: MSE (표준화된 타깃 기준)
#    - 태스크별 loss 가중치 강화
# ================================
mort_pos_ratio = df_train["target_mortality"].mean()
delay_pos_ratio = df_train["delay_label"].mean()

pos_weight_mort = (1.0 - mort_pos_ratio) / max(mort_pos_ratio, 1e-6)
pos_weight_delay = (1.0 - delay_pos_ratio) / max(delay_pos_ratio, 1e-6)

print(f"[INFO] mort_pos_ratio={mort_pos_ratio:.4f}, pos_weight_mort={pos_weight_mort:.2f}")
print(f"[INFO] delay_pos_ratio={delay_pos_ratio:.4f}, pos_weight_delay={pos_weight_delay:.2f}")

bce_mort = nn.BCEWithLogitsLoss(
    pos_weight=torch.tensor([pos_weight_mort], device=device),
    reduction="none",
)
bce_delay = nn.BCEWithLogitsLoss(
    pos_weight=torch.tensor([pos_weight_delay], device=device),
    reduction="none",
)
ce_loss  = nn.CrossEntropyLoss(reduction="none")
mse_loss = nn.MSELoss(reduction="none")

optimizer = torch.optim.Adam(model.parameters(), lr=LR)

# ================================
# 11. metric 함수들
# ================================
def compute_metrics_binary(y_true, y_proba, threshold=0.5):
    y_pred = (y_proba >= threshold).astype(int)
    if len(np.unique(y_true)) > 1:
        auc = roc_auc_score(y_true, y_proba)
    else:
        auc = np.nan
    ap  = average_precision_score(y_true, y_proba)
    acc = accuracy_score(y_true, y_pred)
    f1  = f1_score(y_true, y_pred)
    prec = precision_score(y_true, y_pred)
    rec  = recall_score(y_true, y_pred)
    return dict(AUC=auc, AP=ap, ACC=acc, PREC=prec, REC=rec, F1=f1)

def compute_metrics_reg(y_true, y_pred):
    mse  = mean_squared_error(y_true, y_pred)
    rmse = mse ** 0.5
    mae  = mean_absolute_error(y_true, y_pred)
    return dict(RMSE=rmse, MAE=mae)

def compute_metrics_multiclass(y_true, y_pred):
    acc = accuracy_score(y_true, y_pred)
    macro_f1 = f1_score(y_true, y_pred, average="macro")
    macro_prec = precision_score(y_true, y_pred, average="macro")
    macro_rec  = recall_score(y_true, y_pred, average="macro")
    return dict(ACC=acc, PREC=macro_prec, REC=macro_rec, F1=macro_f1)

def eval_epoch(model, loader, threshold=0.5):
    model.eval()
    all_mort_y, all_mort_p = [], []
    all_delay_y, all_delay_p = [], []
    all_next_y, all_next_pred = [], []
    all_ttn_y_std, all_ttn_pred_std = [], []
    all_los_y_std, all_los_pred_std = [], []

    with torch.no_grad():
        for xb, mb, y_mort_b, y_next_b, y_ttn_b, y_los_b, y_delay_b in loader:
            xb = xb.to(device).float()
            mb = mb.to(device).float()
            y_mort_b = y_mort_b.to(device).float()
            y_next_b = y_next_b.to(device).long()
            y_ttn_b = y_ttn_b.to(device).float()
            y_los_b = y_los_b.to(device).float()
            y_delay_b = y_delay_b.to(device).float()

            out = model(xb, mb)
            mort_logits  = out["mort_logits"]
            delay_logits = out["delay_logits"]
            next_logits  = out["next_logits"]
            ttn_pred     = out["ttn_pred"]
            los_pred     = out["los_pred"]

            mask_flat = mb.view(-1) > 0
            if mask_flat.sum() == 0:
                continue

            mort_p = torch.sigmoid(mort_logits.view(-1)[mask_flat])
            delay_p = torch.sigmoid(delay_logits.view(-1)[mask_flat])
            next_pred = torch.argmax(next_logits.view(-1, num_next_classes)[mask_flat], dim=1)
            ttn_p_std = ttn_pred.view(-1)[mask_flat]
            los_p_std = los_pred.view(-1)[mask_flat]

            mort_y = y_mort_b.view(-1)[mask_flat]
            delay_y = y_delay_b.view(-1)[mask_flat]
            next_y = y_next_b.view(-1)[mask_flat]
            ttn_y_std = y_ttn_b.view(-1)[mask_flat]
            los_y_std = y_los_b.view(-1)[mask_flat]

            all_mort_y.append(mort_y.cpu().numpy())
            all_mort_p.append(mort_p.cpu().numpy())
            all_delay_y.append(delay_y.cpu().numpy())
            all_delay_p.append(delay_p.cpu().numpy())
            all_next_y.append(next_y.cpu().numpy())
            all_next_pred.append(next_pred.cpu().numpy())
            all_ttn_y_std.append(ttn_y_std.cpu().numpy())
            all_ttn_pred_std.append(ttn_p_std.cpu().numpy())
            all_los_y_std.append(los_y_std.cpu().numpy())
            all_los_pred_std.append(los_p_std.cpu().numpy())

    def concat_or_empty(lst):
        return np.concatenate(lst, axis=0) if len(lst) > 0 else np.array([])

    mort_y = concat_or_empty(all_mort_y)
    mort_p = concat_or_empty(all_mort_p)
    delay_y = concat_or_empty(all_delay_y)
    delay_p = concat_or_empty(all_delay_p)
    next_y = concat_or_empty(all_next_y)
    next_pred = concat_or_empty(all_next_pred)
    ttn_y_std = concat_or_empty(all_ttn_y_std)
    ttn_pred_std = concat_or_empty(all_ttn_pred_std)
    los_y_std = concat_or_empty(all_los_y_std)
    los_pred_std = concat_or_empty(all_los_pred_std)

    metrics = {}

    if len(mort_y) > 0:
        metrics["mortality"] = compute_metrics_binary(mort_y, mort_p, threshold)

    if len(delay_y) > 0:
        metrics["delay"] = compute_metrics_binary(delay_y, delay_p, threshold)

    if len(next_y) > 0:
        metrics["next_event"] = compute_metrics_multiclass(next_y, next_pred)

    # 회귀는 표준화 → log1p/일 단위로 역변환 후 RMSE 계산
    if len(ttn_y_std) > 0:
        y_true_log1p = ttn_y_std * ttn_std + ttn_mean
        y_pred_log1p = ttn_pred_std * ttn_std + ttn_mean
        metrics["time_to_next"] = compute_metrics_reg(y_true_log1p, y_pred_log1p)

    if len(los_y_std) > 0:
        y_true_los = los_y_std * los_std + los_mean
        y_pred_los = los_pred_std * los_std + los_mean
        metrics["remain_los"] = compute_metrics_reg(y_true_los, y_pred_los)

    return metrics

# ================================
# 12. Training Loop (loss 가중치 강화 + early-best by Val AUC)
# ================================
best_val_auc = -1.0
best_state = None

for epoch in range(1, EPOCHS + 1):
    model.train()
    total_loss = 0.0
    n_batches = 0

    for xb, mb, y_mort_b, y_next_b, y_ttn_b, y_los_b, y_delay_b in train_loader:
        xb = xb.to(device).float()
        mb = mb.to(device).float()
        y_mort_b = y_mort_b.to(device).float()
        y_next_b = y_next_b.to(device).long()
        y_ttn_b = y_ttn_b.to(device).float()
        y_los_b = y_los_b.to(device).float()
        y_delay_b = y_delay_b.to(device).float()

        optimizer.zero_grad()
        out = model(xb, mb)

        mort_logits  = out["mort_logits"]
        delay_logits = out["delay_logits"]
        next_logits  = out["next_logits"]
        ttn_pred     = out["ttn_pred"]
        los_pred     = out["los_pred"]

        mask_flat = mb.view(-1) > 0
        if mask_flat.sum() == 0:
            continue

        mort_logits_flat  = mort_logits.view(-1)[mask_flat]
        delay_logits_flat = delay_logits.view(-1)[mask_flat]
        next_logits_flat  = next_logits.view(-1, num_next_classes)[mask_flat]
        ttn_pred_flat     = ttn_pred.view(-1)[mask_flat]
        los_pred_flat     = los_pred.view(-1)[mask_flat]

        y_mort_flat  = y_mort_b.view(-1)[mask_flat]
        y_delay_flat = y_delay_b.view(-1)[mask_flat]
        y_next_flat  = y_next_b.view(-1)[mask_flat]
        y_ttn_flat   = y_ttn_b.view(-1)[mask_flat]
        y_los_flat   = y_los_b.view(-1)[mask_flat]

        loss_mort  = bce_mort(mort_logits_flat, y_mort_flat).mean()
        loss_delay = bce_delay(delay_logits_flat, y_delay_flat).mean()
        loss_next  = ce_loss(next_logits_flat, y_next_flat).mean()
        loss_ttn   = mse_loss(ttn_pred_flat, y_ttn_flat).mean()
        loss_los   = mse_loss(los_pred_flat, y_los_flat).mean()

        # 태스크별 가중치 (사망/다음 이벤트를 더 강하게, 회귀는 상대적으로 줄임)
        loss = (
            4.0 * loss_mort +
            2.0 * loss_delay +
            4.0 * loss_next +
            1.0 * loss_ttn +
            1.0 * loss_los
        )

        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        n_batches += 1

    avg_loss = total_loss / max(n_batches, 1)
    val_metrics = eval_epoch(model, val_loader, threshold=0.5)
    val_auc = val_metrics.get("mortality", {}).get("AUC", np.nan)

    print(f"[Epoch {epoch:02d}] TrainLoss={avg_loss:.4f} | "
          f"Val mortality AUC={val_auc:.4f}")

    if not np.isnan(val_auc) and val_auc > best_val_auc:
        best_val_auc = val_auc
        best_state = {k: v.cpu() for k, v in model.state_dict().items()}

if best_state is not None:
    model.load_state_dict({k: v.to(device) for k, v in best_state.items()})
    print(f"[INFO] Best Val mortality AUC: {best_val_auc:.4f} 모델 로드 완료")

# ================================
# 13. 최종 평가
# ================================
print("\n[FINAL EVAL]")

train_metrics = eval_epoch(model, train_loader)
val_metrics   = eval_epoch(model, val_loader)
test_metrics  = eval_epoch(model, test_loader)

print("Train:", train_metrics)
print("Val  :", val_metrics)
print("Test :", test_metrics)


[INFO] device: cpu
[LOAD] 데이터 로딩 완료
 - shape : (40817, 27)
[PREP] race -> race_enc 인코딩 완료
[PREP] delay_label 생성 완료 (75% 기준: 168.10)
[SPLIT] hadm_id 기준 분할 완료
 - 전체 hadm_id: 1929
 - train hadm_id: 1350 , rows: 28412
 - val   hadm_id: 289 , rows: 5723
 - test  hadm_id: 290 , rows: 6682
[INFO] Transformer input feature 수: 26
[INFO] Feature 예시: ['age', 'gender', 'race_enc', 'arrival_transport', 'prefix_len', 'current_event_id', 'time_since_start_min', 'time_since_ed', 'time_since_last', 'is_night']
[INFO] target_next_evt num_class: 14
[PREP] Feature StandardScaler 적용 완료
[PREP] 회귀 타깃 표준화 완료 (time_to_next_std, remain_los_std)
[SEQ] Train: (1350, 128, 26) (1350, 128)
[SEQ] Val  : (289, 128, 26) (289, 128)
[SEQ] Test : (290, 128, 26) (290, 128)
[INFO] mort_pos_ratio=0.0810, pos_weight_mort=11.35
[INFO] delay_pos_ratio=0.2489, pos_weight_delay=3.02


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


[Epoch 01] TrainLoss=16.0110 | Val mortality AUC=0.7378


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


[Epoch 02] TrainLoss=13.2749 | Val mortality AUC=0.7795


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


[Epoch 03] TrainLoss=11.5749 | Val mortality AUC=0.9339


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


[Epoch 04] TrainLoss=9.8191 | Val mortality AUC=0.9249


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


[Epoch 05] TrainLoss=9.3752 | Val mortality AUC=0.9975


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


[Epoch 06] TrainLoss=7.3548 | Val mortality AUC=0.9837


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


[Epoch 07] TrainLoss=6.6432 | Val mortality AUC=0.9983


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


[Epoch 08] TrainLoss=6.2682 | Val mortality AUC=0.9908


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


[Epoch 09] TrainLoss=5.4926 | Val mortality AUC=0.9983


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


[Epoch 10] TrainLoss=5.0109 | Val mortality AUC=0.9954


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


[Epoch 11] TrainLoss=4.2461 | Val mortality AUC=0.9986


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


[Epoch 12] TrainLoss=3.6805 | Val mortality AUC=0.9983


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


[Epoch 13] TrainLoss=3.1199 | Val mortality AUC=0.9951


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


[Epoch 14] TrainLoss=2.6223 | Val mortality AUC=0.9988


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


[Epoch 15] TrainLoss=2.3058 | Val mortality AUC=0.9922


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


[Epoch 16] TrainLoss=3.6419 | Val mortality AUC=0.9918


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


[Epoch 17] TrainLoss=2.9653 | Val mortality AUC=0.9964


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


[Epoch 18] TrainLoss=2.4921 | Val mortality AUC=0.9349


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


[Epoch 19] TrainLoss=2.6160 | Val mortality AUC=0.9809


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


[Epoch 20] TrainLoss=2.5613 | Val mortality AUC=0.9827
[INFO] Best Val mortality AUC: 0.9988 모델 로드 완료

[FINAL EVAL]


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Train: {'mortality': {'AUC': np.float64(0.9994616825747086), 'AP': np.float64(0.9900258173858376), 'ACC': 0.9775726060348792, 'PREC': 0.7839127471029311, 'REC': 1.0, 'F1': 0.8788689338937715}, 'delay': {'AUC': np.float64(0.9883536413868453), 'AP': np.float64(0.9647584776118696), 'ACC': 0.9272701545862959, 'PREC': 0.7848764377633527, 'REC': 0.9763422581102139, 'F1': 0.8702020202020202}, 'next_event': {'ACC': 0.9475043333687078, 'PREC': 0.8370103029686334, 'REC': 0.8434321984148788, 'F1': 0.8344057096208176}, 'time_to_next': {'RMSE': 0.7709130914464908, 'MAE': 0.4610585985545264}, 'remain_los': {'RMSE': 0.9229646543185377, 'MAE': 0.6480720407922537}}
Val  : {'mortality': {'AUC': np.float64(0.9827373228065657), 'AP': np.float64(0.9136102533698841), 'ACC': 0.9578892189411148, 'PREC': 0.757754800590842, 'REC': 0.8694915254237288, 'F1': 0.8097868981846882}, 'delay': {'AUC': np.float64(0.9857378455803261), 'AP': np.float64(0.960423668884319), 'ACC': 0.9241656473877337, 'PREC': 0.7835616438356