In [None]:
import os
import random
import numpy as np
import pandas as pd

from sklearn.preprocessing import StandardScaler, LabelEncoder
import lightgbm as lgb

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader


# =========================================
# 0. Common settings
# =========================================
# Path of the cohort CSV that is loaded below.
DATA_PATH = "./../cohort/cohort_ver151_reorder_col.csv"
# Seed shared by set_seed(), the hadm_id split and the LightGBM model.
RANDOM_STATE = 42
# Maximum number of events kept per admission sequence (longer ones are cut).
MAX_LEN = 128
# Mini-batch size used by the DataLoaders.
BATCH_SIZE = 32
# Number of Transformer training epochs.
N_EPOCHS = 10


def set_seed(seed=42):
    """Seed Python's ``random``, NumPy and PyTorch with the same value.

    The CUDA generators are seeded as well when a CUDA device is available.

    Args:
        seed: Integer seed applied to every generator.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)


# Seed every RNG once, before any shuffling or model initialisation happens.
set_seed(RANDOM_STATE)
# Use CUDA when PyTorch reports it as available, otherwise run on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"[INFO] device: {device}")


# =========================================
# 1. Data loading and basic preprocessing
# =========================================
df = pd.read_csv(DATA_PATH, low_memory=False)
print("[LOAD] 데이터 로딩 완료")
print(" - shape :", df.shape)

# Columns the rest of the script reads unconditionally; fail early if one is missing.
required_cols = [
    "hadm_id",
    "target_mortality",
    "target_next_evt",
    "target_time_to_next",
    "target_remain_los",
]
for c in required_cols:
    if c not in df.columns:
        raise ValueError(f"필수 컬럼 {c} 이(가) 없습니다: {c}")

# Without a subject_id column, reuse hadm_id as the subject identifier.
if "subject_id" not in df.columns:
    df["subject_id"] = df["hadm_id"]

# race encoding (string-cast values -> integer labels)
if "race" in df.columns:
    le_race = LabelEncoder()
    df["race_enc"] = le_race.fit_transform(df["race"].astype(str))
    print("[PREP] race -> race_enc 인코딩 완료")
else:
    df["race_enc"] = 0
    print("[WARN] race 컬럼 없음, race_enc=0으로 채움")

# time_to_next: clip at the 99.5th percentile, then log1p
# NOTE(review): the quantile is taken over the full dataset, i.e. before the
# train/val/test split performed later in this file — confirm this is intended.
clip_value = df["target_time_to_next"].quantile(0.995)
df["time_to_next_clip"] = df["target_time_to_next"].clip(upper=clip_value)
df["time_to_next_log1p"] = np.log1p(df["time_to_next_clip"])

# delay_label (1 when the clipped time-to-next exceeds its 75th percentile)
if "delay_label" not in df.columns:
    delay_thr = df["time_to_next_clip"].quantile(0.75)
    df["delay_label"] = (df["time_to_next_clip"] > delay_thr).astype(int)
    print(f"[PREP] delay_label 생성 완료 (75% 기준: {delay_thr:.2f})")
else:
    print("[PREP] 기존 delay_label 사용")


# =========================================
# 2. Feature Engineering
# =========================================
def add_features(df_in: pd.DataFrame) -> pd.DataFrame:
    """Return a copy of ``df_in`` extended with derived feature columns.

    Every derived column is always created. When the source columns it needs
    are missing, it is filled with a constant (``0`` / ``0.0``) instead.

    Derived columns:
        hr_bp_ratio: ``current_heart_rate / (|current_mean_bp| + 1)``.
        time_last_ratio: ``time_since_last / (|time_since_start_min| + 1)``.
        event_progress: ``pathway_stage / (prefix_len + 1)``.
        risk_delay: 1 where ``time_since_last`` is above its median in ``df_in``.
        risk_stemi: ``stemi_flag * cum_stemi_cnt``, or ``stemi_flag`` alone
            when ``cum_stemi_cnt`` is missing.
        trop_abnormal: 1 where ``last_trop > 0.04``.

    Args:
        df_in: Input frame; it is not modified.

    Returns:
        A new DataFrame holding the input columns plus the derived ones.
    """
    df_fe = df_in.copy()

    if "current_heart_rate" in df_fe.columns and "current_mean_bp" in df_fe.columns:
        # The "+ 1.0" keeps the denominator away from zero.
        df_fe["hr_bp_ratio"] = df_fe["current_heart_rate"] / (df_fe["current_mean_bp"].abs() + 1.0)
    else:
        df_fe["hr_bp_ratio"] = 0.0

    if "time_since_last" in df_fe.columns and "time_since_start_min" in df_fe.columns:
        df_fe["time_last_ratio"] = df_fe["time_since_last"] / (df_fe["time_since_start_min"].abs() + 1.0)
    else:
        df_fe["time_last_ratio"] = 0.0

    if "pathway_stage" in df_fe.columns and "prefix_len" in df_fe.columns:
        df_fe["event_progress"] = df_fe["pathway_stage"] / (df_fe["prefix_len"] + 1.0)
    else:
        df_fe["event_progress"] = 0.0

    # The median comes from the frame that was passed in, not from a
    # module-level ``df``, so the function works on any frame.
    if "time_since_last" in df_in.columns:
        median_gap = df_in["time_since_last"].median()
        df_fe["risk_delay"] = (df_fe["time_since_last"] > median_gap).astype(int)
    else:
        df_fe["risk_delay"] = 0

    has_stemi = "stemi_flag" in df_fe.columns
    has_cum_stemi = "cum_stemi_cnt" in df_fe.columns
    if has_stemi and has_cum_stemi:
        df_fe["risk_stemi"] = df_fe["stemi_flag"] * df_fe["cum_stemi_cnt"]
    elif has_stemi:
        df_fe["risk_stemi"] = df_fe["stemi_flag"]
    else:
        df_fe["risk_stemi"] = 0

    if "last_trop" in df_fe.columns:
        df_fe["trop_abnormal"] = (df_fe["last_trop"] > 0.04).astype(int)
    else:
        df_fe["trop_abnormal"] = 0

    return df_fe


# Append the engineered columns to the full frame.
df = add_features(df)

# next_event labels shifted so the smallest label becomes 0
df["next_evt_label"] = df["target_next_evt"].astype(int)
min_next_label = df["next_evt_label"].min()
df["next_evt_label0"] = df["next_evt_label"] - min_next_label
# Number of output classes = largest shifted label + 1.
num_next_classes = df["next_evt_label0"].max() + 1

# Candidate model inputs; only those present in df are actually used.
candidate_cols = [
    "age", "gender", "race_enc",
    "arrival_transport",
    "prefix_len", "current_event_id",
    "time_since_start_min", "time_since_ed", "time_since_last",
    "is_night",
    "cum_ecg_cnt", "cum_trop_cnt", "cum_stemi_cnt",
    "stemi_flag", "trop_pos_flag",
    "last_trop", "run_max_trop", "trop_trend",
    "pci_status",
    "current_heart_rate", "current_mean_bp",
    "hr_bp_ratio", "time_last_ratio", "event_progress",
    "risk_delay", "risk_stemi", "trop_abnormal",
]

feature_cols = [c for c in candidate_cols if c in df.columns]
print("[INFO] 사용 feature 수:", len(feature_cols))
print("[INFO] Feature 예시:", feature_cols[:10])

# Task name -> dataframe column holding that task's target.
target_cols = {
    "mortality": "target_mortality",
    "delay": "delay_label",
    "next_event0": "next_evt_label0",
    "time_to_next_log1p": "time_to_next_log1p",
    "remain_los": "target_remain_los",
}

# Drop rows with a missing value in any used feature or target column.
df = df.dropna(subset=feature_cols + list(target_cols.values()))
print("[PREP] 결측 제거 후:", df.shape)


# =========================================
# 3. hadm_id 기반 Split
# =========================================
def split_by_hadm(df_in, random_state=42, train_ratio=0.7, val_ratio=0.15):
    """Split rows into train/val/test so that no ``hadm_id`` spans two splits.

    The unique ``hadm_id`` values are shuffled with a ``RandomState`` seeded by
    ``random_state``. The first ``int(n * train_ratio)`` ids go to train, the
    next ``int(n * val_ratio)`` to validation, and the remainder to test.
    A summary of id and row counts is printed.

    Args:
        df_in: Frame with a ``hadm_id`` column.
        random_state: Seed of the shuffle.
        train_ratio: Fraction of ids assigned to the training split.
        val_ratio: Fraction of ids assigned to the validation split.

    Returns:
        ``(df_train, df_val, df_test)`` — independent copies of the rows.
    """
    shuffled_ids = df_in["hadm_id"].unique()
    np.random.RandomState(random_state).shuffle(shuffled_ids)

    n = len(shuffled_ids)
    train_end = int(n * train_ratio)
    val_end = train_end + int(n * val_ratio)

    hadm_train = shuffled_ids[:train_end]
    hadm_val = shuffled_ids[train_end:val_end]
    hadm_test = shuffled_ids[val_end:]

    def rows_of(ids):
        # Rows whose admission belongs to the given id subset.
        return df_in[df_in["hadm_id"].isin(ids)].copy()

    df_train, df_val, df_test = (rows_of(ids) for ids in (hadm_train, hadm_val, hadm_test))

    print("[SPLIT] hadm_id 기준 분할 완료")
    print(" - 전체 hadm_id:", n)
    print(" - train hadm_id:", len(hadm_train), ", rows:", len(df_train))
    print(" - val   hadm_id:", len(hadm_val), ", rows:", len(df_val))
    print(" - test  hadm_id:", len(hadm_test), ", rows:", len(df_test))

    return df_train, df_val, df_test


# Admission-level split of the preprocessed frame.
df_train, df_val, df_test = split_by_hadm(df, random_state=RANDOM_STATE)


# =========================================
# 4. LightGBM: mortality prediction model
# =========================================
# Each training row (one event prefix) is one sample for the classifier.
X_train_lgbm = df_train[feature_cols].values
y_train_mort = df_train[target_cols["mortality"]].values

mort_lgbm = lgb.LGBMClassifier(
    objective="binary",
    n_estimators=400,
    learning_rate=0.05,
    num_leaves=31,
    subsample=0.8,
    colsample_bytree=0.8,
    class_weight="balanced",
    random_state=RANDOM_STATE,
)

print("[TRAIN] LGBM Mortality 학습")
mort_lgbm.fit(X_train_lgbm, y_train_mort)


# =========================================
# 5. Sequence data construction for the Transformer
# =========================================
# The scaler is fitted on the training rows only and reused for val/test.
scaler = StandardScaler()
# Fit on a plain ndarray (no column names) to avoid the feature-name warning
scaler.fit(df_train[feature_cols].values)


def build_sequences_with_ids(df_split, feature_cols, max_len=128, scaler=None):
    """Turn per-event rows into fixed-length, padded per-admission arrays.

    Rows are grouped by ``hadm_id``. Each group is ordered by
    ``time_since_start_min`` (or ``prefix_len`` when that column is absent),
    optionally scaled, truncated to ``max_len`` events and zero-padded.

    Args:
        df_split: Frame holding ``hadm_id``, the feature columns and the
            target columns named in the module-level ``target_cols``.
        feature_cols: Ordered list of feature column names.
        max_len: Number of time steps in every output sequence.
        scaler: Optional object with ``transform``; applied to the raw features.

    Returns:
        Tuple ``(X_seq, pad_mask, y_delay, y_next, y_ttn, y_los, hadm_ids)``.
        ``X_seq`` is float32 ``(n_seq, max_len, n_features)``. ``pad_mask`` is
        bool ``(n_seq, max_len)`` with ``True`` on padded steps. The four label
        arrays are ``(n_seq, max_len)``. ``hadm_ids`` lists the admission id
        of each sequence in order.
    """
    per_admission = list(df_split.groupby("hadm_id"))
    n_seq, n_feat = len(per_admission), len(feature_cols)

    X_seq = np.zeros((n_seq, max_len, n_feat), dtype=np.float32)
    pad_mask = np.ones((n_seq, max_len), dtype=bool)
    # One zero-initialised label buffer per task, keyed like ``target_cols``.
    labels = {
        "delay": np.zeros((n_seq, max_len), dtype=np.float32),
        "next_event0": np.zeros((n_seq, max_len), dtype=np.int64),
        "time_to_next_log1p": np.zeros((n_seq, max_len), dtype=np.float32),
        "remain_los": np.zeros((n_seq, max_len), dtype=np.float32),
    }
    hadm_ids = []

    for seq_idx, (hid, events) in enumerate(per_admission):
        # Order events chronologically, falling back to the prefix length.
        for order_col in ("time_since_start_min", "prefix_len"):
            if order_col in events.columns:
                events = events.sort_values(order_col)
                break
        events = events.reset_index(drop=True)

        feats = events[feature_cols].values.astype(float)
        if scaler is not None:
            feats = scaler.transform(feats)

        n_keep = min(len(events), max_len)
        X_seq[seq_idx, :n_keep, :] = feats[:n_keep]
        pad_mask[seq_idx, :n_keep] = False

        for task, buf in labels.items():
            buf[seq_idx, :n_keep] = events[target_cols[task]].values[:n_keep]

        hadm_ids.append(hid)

    return (
        X_seq,
        pad_mask,
        labels["delay"],
        labels["next_event0"],
        labels["time_to_next_log1p"],
        labels["remain_los"],
        np.array(hadm_ids),
    )


# Build padded sequences for each split with the train-fitted scaler.
# hadm_tr / hadm_va / hadm_te hold the admission id of every sequence, in order.
X_tr, mask_tr, y_tr_delay, y_tr_next, y_tr_ttn, y_tr_los, hadm_tr = build_sequences_with_ids(
    df_train, feature_cols, max_len=MAX_LEN, scaler=scaler
)
X_va, mask_va, y_va_delay, y_va_next, y_va_ttn, y_va_los, hadm_va = build_sequences_with_ids(
    df_val, feature_cols, max_len=MAX_LEN, scaler=scaler
)
X_te, mask_te, y_te_delay, y_te_next, y_te_ttn, y_te_los, hadm_te = build_sequences_with_ids(
    df_test, feature_cols, max_len=MAX_LEN, scaler=scaler
)

print(f"[SEQ] Train: {X_tr.shape}, Val: {X_va.shape}, Test: {X_te.shape}")


class PPMDataset(Dataset):
    """Dataset over pre-built, padded admission sequences.

    Each item is the tuple ``(X, mask, y_delay, y_next, y_ttn, y_los)`` for one
    sequence, converted from NumPy to torch tensors with ``torch.from_numpy``.
    """

    def __init__(self, X, mask, y_delay, y_next, y_ttn, y_los):
        """Store the arrays; all are indexed along their first axis."""
        self.X, self.mask = X, mask
        self.y_delay, self.y_next = y_delay, y_next
        self.y_ttn, self.y_los = y_ttn, y_los

    def __len__(self):
        """Return the number of sequences."""
        return len(self.X)

    def __getitem__(self, idx):
        """Return the six tensors belonging to sequence ``idx``."""
        arrays = (self.X, self.mask, self.y_delay, self.y_next, self.y_ttn, self.y_los)
        return tuple(torch.from_numpy(arr[idx]) for arr in arrays)


# Wrap each split's arrays in a Dataset.
train_ds = PPMDataset(X_tr, mask_tr, y_tr_delay, y_tr_next, y_tr_ttn, y_tr_los)
val_ds = PPMDataset(X_va, mask_va, y_va_delay, y_va_next, y_va_ttn, y_va_los)
test_ds = PPMDataset(X_te, mask_te, y_te_delay, y_te_next, y_te_ttn, y_te_los)

# Only the training loader shuffles; val/test keep dataset order so batch
# positions can be mapped back to hadm_va / hadm_te.
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False)


# =========================================
# 6. Multi-task Transformer 정의
# =========================================
class MultiTaskTransformer(nn.Module):
    """Transformer encoder with four per-time-step prediction heads.

    Inputs are projected to ``hidden_dim``, passed through a stack of
    ``nn.TransformerEncoderLayer`` blocks (``batch_first=True``) and fed to
    linear heads for delay (1 logit), next event (``num_next_classes`` logits),
    time to next event (1 value) and remaining length of stay (1 value).
    """

    def __init__(self, input_dim, hidden_dim, num_heads, num_layers, num_next_classes, dropout=0.1):
        """Build the input projection, the encoder stack and the task heads.

        Args:
            input_dim: Number of features per time step.
            hidden_dim: Model width (``d_model``); the feed-forward width is 4x this.
            num_heads: Number of attention heads.
            num_layers: Number of encoder layers.
            num_next_classes: Number of next-event classes.
            dropout: Dropout rate inside the encoder layers.
        """
        super().__init__()
        # Sub-modules are created in a fixed order so that seeded weight
        # initialisation stays reproducible.
        self.in_proj = nn.Linear(input_dim, hidden_dim)
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=hidden_dim,
                nhead=num_heads,
                dim_feedforward=hidden_dim * 4,
                dropout=dropout,
                batch_first=True,
            ),
            num_layers=num_layers,
        )

        self.delay_head = nn.Linear(hidden_dim, 1)
        self.next_head = nn.Linear(hidden_dim, num_next_classes)
        self.ttn_head = nn.Linear(hidden_dim, 1)
        self.los_head = nn.Linear(hidden_dim, 1)

    def forward(self, x, key_padding_mask):
        """Encode ``x`` and return the per-step outputs of every head.

        Args:
            x: Feature tensor whose last dimension is ``input_dim``.
            key_padding_mask: Passed to the encoder as ``src_key_padding_mask``.

        Returns:
            Dict with keys ``delay_logit``, ``next_event_logits``,
            ``time_to_next`` and ``remain_los``. The single-output heads have
            their trailing size-1 dimension squeezed.
        """
        encoded = self.encoder(self.in_proj(x), src_key_padding_mask=key_padding_mask)
        return {
            "delay_logit": self.delay_head(encoded).squeeze(-1),
            "next_event_logits": self.next_head(encoded),
            "time_to_next": self.ttn_head(encoded).squeeze(-1),
            "remain_los": self.los_head(encoded).squeeze(-1),
        }


# Instantiate the multi-task model and move it to the selected device.
model = MultiTaskTransformer(
    input_dim=len(feature_cols),
    hidden_dim=64,
    num_heads=4,
    num_layers=2,
    num_next_classes=num_next_classes,
    dropout=0.1,
).to(device)


# =========================================
# 7. Transformer training
# =========================================
# Weight positive delay labels by the negative/positive ratio of the training rows.
pos_ratio_delay = df_train[target_cols["delay"]].mean()
pos_weight_delay = (1 - pos_ratio_delay) / (pos_ratio_delay + 1e-8)
print(f"[INFO] delay_pos_ratio={pos_ratio_delay:.4f}, pos_weight_delay={pos_weight_delay:.2f}")

bce_delay = nn.BCEWithLogitsLoss(pos_weight=torch.tensor(pos_weight_delay, device=device))
ce_next = nn.CrossEntropyLoss()
mse_loss = nn.MSELoss()

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# NOTE(review): val_loader is built earlier in this file but is not used here;
# only the training loss is reported.
for epoch in range(1, N_EPOCHS + 1):
    model.train()
    train_loss = 0.0
    for X, mask, y_delay, y_next, y_ttn, y_los in train_loader:
        X = X.to(device)
        mask = mask.to(device)
        y_delay = y_delay.to(device)
        y_next = y_next.to(device)
        y_ttn = y_ttn.to(device)
        y_los = y_los.to(device)

        outputs = model(X, key_padding_mask=mask)

        # mask is True on padded steps; losses are computed on real steps only.
        valid = ~mask
        if valid.sum() == 0:
            continue

        delay_logit = outputs["delay_logit"][valid]
        loss_delay = bce_delay(delay_logit, y_delay[valid])

        next_logits = outputs["next_event_logits"][valid]
        loss_next = ce_next(next_logits, y_next[valid])

        ttn_pred = outputs["time_to_next"][valid]
        loss_ttn = mse_loss(ttn_pred, y_ttn[valid])

        los_pred = outputs["remain_los"][valid]
        loss_los = mse_loss(los_pred, y_los[valid])

        # The two regression losses are down-weighted by 0.5.
        loss = loss_delay + loss_next + 0.5 * loss_ttn + 0.5 * loss_los

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Accumulate weighted by batch size so the epoch value is a per-sequence mean.
        train_loss += loss.item() * X.size(0)

    train_loss /= len(train_ds)
    print(f"[Epoch {epoch:02d}] TrainLoss={train_loss:.4f}")


# =========================================
# 8. 온라인 예측용 헬퍼들
# =========================================
def event_name(event_id: int):
    """Return the display name of an event id.

    Ids 0-7 map to fixed names; any other id yields ``"EVT_<id>"``.
    """
    known = dict(enumerate((
        "ED_ARRIVAL",
        "ECG_TAKEN",
        "ECG_STEMI_FLAG",
        "TROP_TAKEN",
        "TROP_POSITIVE",
        "PCI_START",
        "PCI_END",
        "ICU_INTIME",
    )))
    if event_id in known:
        return known[event_id]
    return f"EVT_{event_id}"


def build_patient_sequence(df_all: pd.DataFrame,
                           hadm_id: int,
                           feature_cols: list,
                           scaler,
                           max_len: int = 128):
    """Build the scaled, padded feature sequence of one admission.

    Args:
        df_all: Frame with a ``hadm_id`` column and the feature columns.
        hadm_id: Admission to extract.
        feature_cols: Ordered feature column names.
        scaler: Object with ``transform``; applied to the raw feature matrix.
        max_len: Length of the returned sequence; extra events are dropped.

    Returns:
        ``(df_pat, seq, key_padding_mask)``: the admission's rows ordered by
        ``time_since_start_min`` (or ``prefix_len``) with a fresh index, the
        float32 ``(max_len, n_features)`` sequence, and a bool ``(max_len,)``
        mask that is ``True`` on padded steps. ``df_pat`` is not truncated.

    Raises:
        ValueError: If no row has the requested ``hadm_id``.
    """
    df_pat = df_all.loc[df_all["hadm_id"] == hadm_id].copy()
    if df_pat.empty:
        raise ValueError(f"hadm_id={hadm_id} 에 해당하는 row가 없습니다.")

    # Prefer chronological order, falling back to the prefix length.
    for order_col in ("time_since_start_min", "prefix_len"):
        if order_col in df_pat.columns:
            df_pat = df_pat.sort_values(order_col)
            break
    df_pat = df_pat.reset_index(drop=True)

    scaled = scaler.transform(df_pat[feature_cols].values.astype(float))
    n_rows, n_feat = scaled.shape
    n_keep = min(n_rows, max_len)

    seq = np.zeros((max_len, n_feat), dtype=np.float32)
    seq[:n_keep] = scaled[:n_keep]

    # True marks padding: every position at or after the last kept event.
    key_padding_mask = np.arange(max_len) >= n_keep

    return df_pat, seq, key_padding_mask


def make_prefix_inputs(seq_full: np.ndarray,
                       pad_mask_full: np.ndarray,
                       t: int,
                       device: torch.device):
    """Build batched model inputs that expose only the first ``t`` events.

    Steps from index ``t`` onward are zeroed and masked. Positions that
    ``pad_mask_full`` already marks as padding stay masked even when ``t`` is
    larger than the real sequence length.

    Args:
        seq_full: ``(max_len, n_features)`` padded feature sequence; not modified.
        pad_mask_full: ``(max_len,)`` bool mask, ``True`` on padding. ``None``
            is accepted and means "no padding information".
        t: Number of leading events to expose; must be at least 1.
        device: Device the returned tensors are moved to.

    Returns:
        ``(seq_tensor, mask_tensor)`` with shapes ``(1, max_len, n_features)``
        and ``(1, max_len)``.

    Raises:
        ValueError: If ``t`` is smaller than 1 (every position would be masked).
    """
    if t < 1:
        raise ValueError(f"t must be >= 1, got {t}")

    max_len, _ = seq_full.shape

    seq_prefix = seq_full.copy()
    seq_prefix[t:] = 0.0  # no-op when t >= max_len

    if pad_mask_full is None:
        key_padding_mask = np.ones((max_len,), dtype=bool)
        key_padding_mask[:t] = False
    else:
        # Start from the real padding mask (copied), then hide the future steps.
        key_padding_mask = np.array(pad_mask_full, dtype=bool)
        key_padding_mask[t:] = True

    seq_tensor = torch.from_numpy(seq_prefix).unsqueeze(0).to(device)
    mask_tensor = torch.from_numpy(key_padding_mask).unsqueeze(0).to(device)

    return seq_tensor, mask_tensor


def decode_transformer_outputs(outputs, step_index: int, min_next_label: int):
    """Convert the model outputs at one time step into plain Python values.

    Each tensor is reduced with ``.item()``, so ``outputs`` must hold a single
    sequence.

    Args:
        outputs: Dict with keys ``delay_logit``, ``next_event_logits``,
            ``time_to_next`` and ``remain_los``.
        step_index: Time step to read.
        min_next_label: Offset added to the predicted class index to recover
            the original event label.

    Returns:
        Dict with ``delay_prob`` (sigmoid of the logit), ``next_event_id``,
        ``next_event_prob`` (softmax probability of the predicted class),
        ``time_to_next_min`` (``expm1`` of the model output) and
        ``remain_los_days``.
    """
    def scalar_at(key):
        # Value of a per-step output at ``step_index``.
        return outputs[key][:, step_index].item()

    delay_prob = torch.sigmoid(outputs["delay_logit"][:, step_index]).item()

    class_probs = F.softmax(outputs["next_event_logits"][:, step_index, :], dim=-1)
    best_class = int(torch.argmax(class_probs, dim=-1).item())
    best_prob = float(torch.max(class_probs).item())

    # The time-to-next head works in log1p space; undo the transform.
    time_to_next_min = float(np.expm1(scalar_at("time_to_next")))
    remain_los_days = float(scalar_at("remain_los"))

    return {
        "delay_prob": delay_prob,
        "next_event_id": best_class + min_next_label,
        "next_event_prob": best_prob,
        "time_to_next_min": time_to_next_min,
        "remain_los_days": remain_los_days,
    }


def run_online_prediction_for_hadm(hadm_id: int):
    """Replay one admission event by event and tabulate truth vs. prediction.

    At step ``t`` the Transformer sees only the first ``t`` events and its
    output at position ``t - 1`` is decoded. The LightGBM mortality model is
    evaluated on the feature row of event ``t``. The table is printed and
    returned.

    Only the first ``MAX_LEN`` events are replayed, because later events are
    not part of the padded sequence the model receives.

    Relies on the module-level ``model``, ``mort_lgbm``, ``df``,
    ``feature_cols``, ``scaler``, ``target_cols``, ``min_next_label``,
    ``device`` and ``MAX_LEN``.

    Args:
        hadm_id: Admission to replay.

    Returns:
        DataFrame with one row per replayed step, holding the true values, the
        predictions and ``"true|pred"`` pair strings.

    Raises:
        ValueError: Propagated from ``build_patient_sequence`` when the
            admission has no rows.
    """
    model.eval()

    df_pat, seq_full, pad_mask_full = build_patient_sequence(
        df_all=df,
        hadm_id=hadm_id,
        feature_cols=feature_cols,
        scaler=scaler,
        max_len=MAX_LEN,
    )

    n_events = len(df_pat)
    print(f"\n[ONLINE PRED TABLE] hadm_id={hadm_id}, event 개수={n_events}")

    # seq_full holds at most MAX_LEN steps; indexing the model output beyond
    # that would raise an IndexError, so the replay stops there.
    n_steps = min(n_events, MAX_LEN)
    if n_steps < n_events:
        print(
            f"[WARN] hadm_id={hadm_id}: only the first {n_steps} of {n_events} "
            f"events fit in MAX_LEN={MAX_LEN}; the rest are skipped"
        )

    # The mortality model scores each row independently, so one batched call
    # replaces a predict_proba call per step.
    mort_probas = mort_lgbm.predict_proba(df_pat[feature_cols].values[:n_steps])[:, 1]

    has_time = "time_since_start_min" in df_pat.columns
    has_event_id = "current_event_id" in df_pat.columns
    has_prefix_len = "prefix_len" in df_pat.columns
    mortality_col = target_cols["mortality"]
    delay_col = target_cols["delay"]

    rows = []

    with torch.no_grad():
        for t in range(1, n_steps + 1):
            row = df_pat.iloc[t - 1]

            # ===== ground truth =====
            true_mort = int(row[mortality_col])
            true_delay = int(row[delay_col])
            true_next = int(row["target_next_evt"])
            true_ttn = float(row["target_time_to_next"])
            true_los = float(row["target_remain_los"])

            # ===== LGBM mortality prediction =====
            pred_mort = int(float(mort_probas[t - 1]) >= 0.5)

            # ===== Transformer prediction on the first t events =====
            seq_tensor, mask_tensor = make_prefix_inputs(
                seq_full=seq_full,
                pad_mask_full=pad_mask_full,
                t=t,
                device=device,
            )
            outputs = model(seq_tensor, key_padding_mask=mask_tensor)
            trans_res = decode_transformer_outputs(
                outputs,
                step_index=t - 1,
                min_next_label=min_next_label,
            )

            pred_delay = int(trans_res["delay_prob"] >= 0.5)
            pred_next = trans_res["next_event_id"]
            pred_ttn = float(trans_res["time_to_next_min"])
            pred_los = float(trans_res["remain_los_days"])

            ev_id = int(row["current_event_id"]) if has_event_id else -1
            time_since_start = float(row["time_since_start_min"]) if has_time else np.nan
            prefix_len_val = int(row["prefix_len"]) if has_prefix_len else t

            rows.append({
                "step": t,
                "prefix_len": prefix_len_val,
                "time_since_start_min": round(time_since_start, 1) if not np.isnan(time_since_start) else np.nan,
                "current_event_id": ev_id,
                "current_event_name": event_name(ev_id),

                # ground truth
                "true_mortality": true_mort,
                "true_delay": true_delay,
                "true_next_event": true_next,
                "true_time_to_next_min": true_ttn,
                "true_remain_los_days": true_los,

                # predictions
                "pred_mortality": pred_mort,
                "pred_delay": pred_delay,
                "pred_next_event": pred_next,
                "pred_time_to_next_min": round(pred_ttn, 2),
                "pred_remain_los_days": round(pred_los, 2),

                # pairs (true | predicted)
                "mortality_pair": f"{true_mort}|{pred_mort}",
                "delay_pair": f"{true_delay}|{pred_delay}",
                "next_event_pair": f"{true_next}|{pred_next}",
                "time_to_next_pair": f"{true_ttn:.2f}|{pred_ttn:.2f}",
                "remain_los_pair": f"{true_los:.2f}|{pred_los:.2f}",
            })

    log_df = pd.DataFrame(rows)
    print(log_df.to_string(index=False))
    return log_df

# =========================================
# 9. 테스트셋에서 "원본값|예측값" CSV 생성
# =========================================
def build_df_pred_from_test_with_pairs():
    """Predict every valid step of the test sequences and tabulate the results.

    For each non-padded step of each test sequence, one record is produced with
    the true and predicted delay, next event, time to next event (minutes,
    via ``expm1`` of the log1p values) and remaining length of stay, plus
    ``"true|pred"`` strings. Mortality is predicted once per admission by the
    LightGBM model from the admission's first event row.

    Relies on the module-level ``model``, ``mort_lgbm``, ``df``, ``df_test``,
    ``feature_cols``, ``test_loader``, ``hadm_te``, ``min_next_label`` and
    ``device``. ``test_loader`` must iterate in the same order as ``hadm_te``.

    Returns:
        DataFrame with one row per (admission, step).
    """
    model.eval()

    def hadm_lookup(value_col):
        # hadm_id -> value mapping taken from the full frame.
        pairs = df[["hadm_id", value_col]].drop_duplicates()
        return pairs.set_index("hadm_id")[value_col].to_dict()

    hadm_to_subj = hadm_lookup("subject_id")
    hadm_to_mort_true = hadm_lookup("target_mortality")

    # Predicted mortality per admission: LGBM on the earliest row of each hadm_id.
    first_events = df_test.copy()
    for order_col in ("time_since_start_min", "prefix_len"):
        if order_col in first_events.columns:
            first_events = first_events.sort_values(["hadm_id", order_col])
            break
    first_events = first_events.groupby("hadm_id").head(1).copy()
    first_proba = mort_lgbm.predict_proba(first_events[feature_cols].values)[:, 1]
    first_label = (first_proba >= 0.5).astype(int)
    hadm_to_mort_pred = dict(zip(first_events["hadm_id"].values, first_label))

    records = []
    seq_offset = 0  # index into hadm_te of the first sequence in the current batch

    with torch.no_grad():
        for X, mask, y_delay, y_next, y_ttn, y_los in test_loader:
            bs = X.size(0)

            X = X.to(device)
            mask = mask.to(device)
            outputs = model(X, key_padding_mask=mask)

            # Everything below is done in NumPy on the host.
            keep_np = ~mask.cpu().numpy()
            delay_true_np = y_delay.cpu().numpy()
            next_true_np = y_next.cpu().numpy()
            ttn_true_np = y_ttn.cpu().numpy()
            los_true_np = y_los.cpu().numpy()

            delay_hat_np = (torch.sigmoid(outputs["delay_logit"]) >= 0.5).long().cpu().numpy()
            next_hat_np = torch.argmax(outputs["next_event_logits"], dim=-1).cpu().numpy()
            ttn_hat_np = outputs["time_to_next"].cpu().numpy()
            los_hat_np = outputs["remain_los"].cpu().numpy()

            for b in range(bs):
                hadm_id = int(hadm_te[seq_offset + b])
                subject_id = int(hadm_to_subj[hadm_id])
                mort_true = int(hadm_to_mort_true[hadm_id])
                mort_pred = int(hadm_to_mort_pred[hadm_id])

                keep = keep_np[b]
                per_step = zip(
                    delay_true_np[b][keep],
                    delay_hat_np[b][keep],
                    next_true_np[b][keep],
                    next_hat_np[b][keep],
                    np.expm1(ttn_true_np[b][keep]),
                    np.expm1(ttn_hat_np[b][keep]),
                    los_true_np[b][keep],
                    los_hat_np[b][keep],
                )

                for step, (d_t, d_p, n_t, n_p, ttn_t, ttn_p, los_t, los_p) in enumerate(per_step, start=1):
                    td = int(d_t)
                    pd_ = int(d_p)

                    # Shift class indices back to the original event labels.
                    t_next_label = int(n_t) + min_next_label
                    p_next_label = int(n_p) + min_next_label

                    t_ttn = float(ttn_t)
                    p_ttn = float(ttn_p)

                    t_los = float(los_t)
                    p_los = float(los_p)

                    records.append({
                        "subject_id": subject_id,
                        "hadm_id": hadm_id,
                        "step": step,

                        # numeric true / predicted values
                        "target_mortality": mort_true,
                        "pred_mortality": mort_pred,

                        "target_delay": td,
                        "pred_delay": pd_,

                        "target_next_evt": t_next_label,
                        "pred_next_evt": p_next_label,

                        "target_time_to_next_min": t_ttn,
                        "pred_time_to_next_min": p_ttn,

                        "target_remain_los_days": t_los,
                        "pred_remain_los_days": p_los,

                        # "true|pred" strings
                        "mortality_pair": f"{mort_true}|{mort_pred}",
                        "delay_pair": f"{td}|{pd_}",
                        "next_event_pair": f"{t_next_label}|{p_next_label}",
                        "time_to_next_min_pair": f"{t_ttn:.2f}|{p_ttn:.2f}",
                        "remain_los_days_pair": f"{t_los:.2f}|{p_los:.2f}",
                    })

            seq_offset += bs

    df_pred = pd.DataFrame(records)
    print("[EVAL] df_pred 생성 완료:", df_pred.shape)
    return df_pred


# =========================================
# 10. Main execution
# =========================================
if __name__ == "__main__":
    # 1) Predict over the whole test set and write the "true|pred" CSV
    df_pred = build_df_pred_from_test_with_pairs()
    csv_path = "./ppm_test_pred_with_pairs.csv"
    df_pred.to_csv(csv_path, index=False, encoding="utf-8-sig")
    print(f"[SAVE] 테스트셋 예측 CSV 저장 완료: {csv_path}")

    # 2) Per-patient next-event prediction accuracy
    df_pred["correct_next"] = (df_pred["target_next_evt"] == df_pred["pred_next_evt"]).astype(int)
    patient_perf = (
        df_pred.groupby(["subject_id", "hadm_id"])["correct_next"]
        .mean()
        .reset_index()
        .rename(columns={"correct_next": "next_event_accuracy"})
    )

    # 3) Keep only patients whose accuracy is below 1.0
    filtered = patient_perf[patient_perf["next_event_accuracy"] < 1.0]

    if filtered.empty:
        print("[WARN] accuracy < 1.0 인 환자가 없습니다. 전체 중 최고 accuracy 환자 사용.")
        candidate_df = patient_perf.copy()
    else:
        candidate_df = filtered

    # 4) Among the candidates, pick the patient with the highest accuracy
    best_row = candidate_df.sort_values("next_event_accuracy", ascending=False).iloc[0]
    best_subject_id = int(best_row["subject_id"])
    best_hadm_id = int(best_row["hadm_id"])
    best_acc = float(best_row["next_event_accuracy"])

    print("\n[INFO] 선택된 환자 (accuracy 1이 아닌 환자 중 최고, 없으면 전체 최고)")
    print(f" subject_id: {best_subject_id}")
    print(f" hadm_id   : {best_hadm_id}")
    print(f" accuracy  : {best_acc:.4f}")

    # 5) Run the online prediction for the selected patient and save it as CSV
    log_df = run_online_prediction_for_hadm(best_hadm_id)
    save_path = f"./online_pred_best_nonperfect_{best_hadm_id}.csv"
    log_df.to_csv(save_path, index=False, encoding="utf-8-sig")
    print(f"[SAVE] 온라인 예측 로그 저장 완료: {save_path}")


[INFO] device: cpu
[LOAD] 데이터 로딩 완료
 - shape : (40817, 27)
[PREP] race -> race_enc 인코딩 완료
[PREP] delay_label 생성 완료 (75% 기준: 168.10)
[INFO] 사용 feature 수: 26
[INFO] Feature 예시: ['age', 'gender', 'race_enc', 'arrival_transport', 'prefix_len', 'current_event_id', 'time_since_start_min', 'time_since_ed', 'time_since_last', 'is_night']
[PREP] 결측 제거 후: (40817, 39)
[SPLIT] hadm_id 기준 분할 완료
 - 전체 hadm_id: 1929
 - train hadm_id: 1350 , rows: 28412
 - val   hadm_id: 289 , rows: 5723
 - test  hadm_id: 290 , rows: 6682
[TRAIN] LGBM Mortality 학습
[LightGBM] [Info] Number of positive: 2300, number of negative: 26112
[LightGBM] [Info] Auto-choosing row-wise multi-threading, the overhead of testing was 0.000992 seconds.
You can set `force_row_wise=true` to remove the overhead.
And if memory is not enough, you can set `force_col_wise=true`.
[LightGBM] [Info] Total Bins 2256
[LightGBM] [Info] Number of data points in the train set: 28412, number of used features: 21
[LightGBM] [Info] [binary:BoostFromScor



[SAVE] 테스트셋 예측 CSV 저장 완료: ./ppm_test_pred_with_pairs.csv

[INFO] 선택된 환자 (accuracy 1이 아닌 환자 중 최고, 없으면 전체 최고)
 subject_id: 17953273
 hadm_id   : 23590849
 accuracy  : 0.9412

[ONLINE PRED TABLE] hadm_id=23590849, event 개수=17
 step  prefix_len  time_since_start_min  current_event_id current_event_name  mortality_prob  delay_prob  next_event_id next_event_name  next_event_prob  time_to_next_min  remain_los_days
    1           1                   0.0                 1          ECG_TAKEN          0.0843      0.0111              3      TROP_TAKEN           0.9302              1.74             4.76
    2           2                   3.6                 3         TROP_TAKEN          0.1133      0.1251              3      TROP_TAKEN           0.5714              5.97             3.07
    3           3                   3.6                 3         TROP_TAKEN          0.0853      0.0210              3      TROP_TAKEN           0.8534              1.05             3.61
    4           4        

