# AD Cox (Tensor Fusion + LSTM) — End-to-End Notebook

This notebook walks through the full pipeline:

1) Environment & common setup
2) Data processing (Step 5) — build patient sequences & save
3) Split & preprocessing (Step 6) — train-only fit → transform
4) Dataset & dataloader — variable-length padding & mask
5) Model assembly — 3 embeddings → Tensor Fusion → LSTM → Cox
6) Smoke test — single-batch forward & loss
7) Training loop — optimization, checkpoints, early stop
8) Evaluation — Val/Test C-index & visit-trimming checks
9) Baseline hazard & time-to-AD predictions (Breslow + median)
10) Subsequence / Landmark inference — residual-time predictions
11) Calibration / error — IPCW Brier @h & IBS
A) Utilities — I/O helpers

**One-line summary**: We first build per-patient visit sequences, running from the first MCI visit to the last visit before AD (MCI → AD-1). We then embed each visit (visual / demographics / time), fuse the embeddings with Tensor Fusion, pass the sequence through an LSTM, and train a Cox model (Efron ties). Finally, we estimate the baseline hazard on **train only** and turn risk scores into **time-to-AD** predictions.


## 1) Environment & Common Setup

- Import libraries and set device / seeds.
- Define key paths (data, outputs, checkpoints).
- (Windows/Jupyter) Set the multiprocessing start method to `spawn` for DataLoader safety. This is done in the model-assembly cell.
- No pipeline logic runs here. This cell only prepares global context used by later cells.


In [1]:
import os
import math
import random
from typing import List, Tuple, Optional

import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from PIL import Image

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from sklearn.impute import SimpleImputer

# Reproducibility
def set_seed(seed: int = 42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

set_seed(42)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)


Using device: cuda


In [2]:
# Helper: give embedder outputs an explicit time axis before per-visit Tensor Fusion.

import torch

def _ensure_3d_time(output: torch.Tensor) -> torch.Tensor:
    """
    Ensure output has an explicit time axis.
    If output is [B, D], convert to [B, 1, D].
    If already [B, T, D], return as-is.
    """
    if output.dim() == 2:
        return output.unsqueeze(1)
    return output



## 2) Data Processing — Build Patient Sequences & Save

**Goal**: Convert raw per-patient PKL tables into uniform sequence objects.

**What this cell does**
- Load the manifest CSV (`PTID`, `path`) to find each patient PKL.
- For each patient table:
  - Sort by `Years_bl`.
  - Identify the first **MCI** visit. Keep visits from it up to the visit just before the first **AD/Dementia** diagnosis (**MCI → (AD/Dementia)-1**). If the patient never converts, keep **MCI → last visit**.
  - Define `times` as Δt from MCI (monotonic, starting at 0).
  - Collect structured features using the `WHITELIST_STRUCT_COLS`.
  - Collect `img_paths` (optionally drop visits with missing images).
  - Set **targets**:
    - `event`: 1 if the patient eventually reaches AD/Dementia, else 0.
    - `t_event`: time-to-event if event=1, or last observed time (censoring) if event=0.
- Save all sequences to `seqs_all_raw.pkl`.
- Print a basic sanity report (count, lengths, event rate, monotonicity).

**Outputs**
- `AD_Cox_Files/seqs_all_raw.pkl`: the path written via `OUT_RAW` and read back in Step 6.
- Sanity report dictionary in the log.
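For a single patient table, the sequence-building rules above can be sketched as follows. This is a minimal illustration, not the code in `AD_Cox_Files.DataProcessors`. Several details are assumptions: the diagnosis column `DX`, its labels (`"MCI"`, `"AD"`, `"Dementia"`), the feature columns `AGE`/`MMSE`, and taking `t_event` as the time of the first AD visit measured from MCI.

```python
import numpy as np
import pandas as pd

def build_sequence_sketch(df, dx_col="DX", time_col="Years_bl",
                          feat_cols=("AGE", "MMSE"), ad_labels=("AD", "Dementia")):
    """Hypothetical sketch: one patient's visits -> MCI..(AD-1) sequence + Cox targets."""
    df = df.sort_values(time_col).reset_index(drop=True)    # chronological order
    is_mci = (df[dx_col] == "MCI").to_numpy()
    if not is_mci.any():
        return None                                         # never MCI: not eligible
    i0 = int(np.argmax(is_mci))                             # first MCI visit
    t0 = float(df.loc[i0, time_col])
    after = df.iloc[i0:]
    is_ad = after[dx_col].isin(ad_labels).to_numpy()
    if is_ad.any():
        k = int(np.argmax(is_ad))                           # first AD visit (k >= 1)
        visits = after.iloc[:k]                             # MCI -> (AD)-1
        event, t_event = 1, float(after.iloc[k][time_col]) - t0
    else:
        visits = after                                      # MCI -> last (censored)
        event, t_event = 0, float(after.iloc[-1][time_col]) - t0
    times = visits[time_col].to_numpy(dtype=float) - t0     # Δt from MCI, starts at 0
    struct = visits[list(feat_cols)].to_numpy(dtype=float)  # [T, F]
    return {"times": times, "struct": struct, "event": event, "t_event": t_event}

visits = pd.DataFrame({"Years_bl": [1.0, 0.0, 0.5, 2.0],
                       "DX": ["MCI", "CN", "MCI", "Dementia"],
                       "AGE": [70.0, 69.0, 69.5, 71.0], "MMSE": [27, 29, 28, 22]})
print(build_sequence_sketch(visits))
```

Returning `None` for patients without an MCI visit is one way to drop them. The real `build_patient_sequence` may handle this case differently.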


In [3]:
from AD_Cox_Files.DataProcessors import (
    WHITELIST_STRUCT_COLS,
    read_manifest,
    build_all_sequences,
    build_patient_sequence,
    PatientSeq,
)
from AD_Cox_Files.Utils import save_sequences, sanity_check_sequences

import pandas as pd
import numpy as np

# 1) Load manifest
manifest = pd.read_csv("./AD_Patient_Manifest.csv")

pid_col  = "PTID"
path_col = "path"

# 2) Normalize to a standard schema
mf = manifest[[pid_col, path_col]].copy()
mf.columns = ["pid", "pkl_path"]
mf["pid"] = mf["pid"].astype(str)
mf["pkl_path"] = mf["pkl_path"].astype(str)

# 3) Quick path existence check (optional)
missing = mf[~mf["pkl_path"].apply(os.path.exists)]
if len(missing):
    print(f"⚠️ {len(missing)} PKL paths not found. First 3:")
    print(missing.head(3))

IMG_PREFIX_REPLACE = ("/home/mason/ADNI_Dataset/", "../ADNI_Dataset/")

seqs_all_raw = build_all_sequences(
    manifest_df=mf,
    whitelist_cols=WHITELIST_STRUCT_COLS,
    require_images=True,
    img_prefix_replace=IMG_PREFIX_REPLACE,
)

print(f"[Step5] built raw sequences: {len(seqs_all_raw)} patients")
print("Sanity:", sanity_check_sequences(seqs_all_raw))

OUT_RAW = "AD_Cox_Files/seqs_all_raw.pkl"
save_sequences(OUT_RAW, seqs_all_raw)
print(f"[Step5] saved → {OUT_RAW}")




[Step5] built raw sequences: 161 patients
Sanity: {'num_patients': 161, 'avg_T': 4.39751552795031, 'min_T': 1, 'max_T': 6, 'event_rate': 0.45962732919254656, 'violations': 0}
[Step5] saved → AD_Cox_Files/seqs_all_raw.pkl


In [4]:
import os, json, numpy as np

from AD_Cox_Files.DataProcessors import StructPreprocessor, PatientSeq
from AD_Cox_Files.Utils import (
    load_sequences, save_sequences, sanity_check_sequences,
    ensure_dir, save_json, assert_disjoint
)

SEQ_RAW_PATH = "AD_Cox_Files/seqs_all_raw.pkl"
OUT_DIR      = "AD_Cox_Files/out_step6"
VAL_RATIO    = 0.10
TEST_RATIO   = 0.20
SEED         = 42

ensure_dir(OUT_DIR)


In [5]:
seqs_all_raw = load_sequences(SEQ_RAW_PATH)
print(f"[Step6] Loaded raw sequences: {len(seqs_all_raw)} patients")
print("Sanity(all_raw):", sanity_check_sequences(seqs_all_raw))


[Step6] Loaded raw sequences: 161 patients
Sanity(all_raw): {'num_patients': 161, 'avg_T': 4.39751552795031, 'min_T': 1, 'max_T': 6, 'event_rate': 0.45962732919254656, 'violations': 0}


## 3) Split & Preprocessing — Train-only fit → transform

**Goal**
Create patient-disjoint splits and apply **leak-free** preprocessing.

**What this cell does**
- **Patient-level split** into `train / val / test`, stratified by event status (disjoint `PTID`s; saved to `splits.json`).
- **Fit** `StructPreprocessor` **only on train**:
  - Median imputation + standard scaling.
  - Optional feature augmentation: first differences (Δ) and velocity over a short window (this run uses `add_deltas=True`, `add_velocity=False`).
- **Transform** each split (`train/val/test`) with the fitted preprocessor.
- **Save** artifacts for downstream steps:
  - `AD_Cox_Files/out_step6/struct_preproc.pkl`
  - `AD_Cox_Files/out_step6/seqs_train.pkl`
  - `AD_Cox_Files/out_step6/seqs_val.pkl`
  - `AD_Cox_Files/out_step6/seqs_test.pkl`
  - `AD_Cox_Files/out_step6/splits.json`

**Why train-only fit matters**
- Prevents leakage (val/test statistics never influence imputation/scaling).
- Keeps evaluation honest and comparable across runs.

**Quick checklist**
- [ ] Train/Val/Test are **patient-disjoint**.
- [ ] Preprocessor `.fit()` used **train only**; `.transform()` used on all splits.
- [ ] Output files exist in `out_step6/` and load without errors.
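A train-only fit → transform can be sketched with scikit-learn's `SimpleImputer` and `StandardScaler`. `StructPreprocSketch` is a hypothetical stand-in for `StructPreprocessor` and works on plain `[T, F]` NumPy arrays. Only the Δ option is shown, and zero deltas at the first visit are an assumption.

```python
import numpy as np
from sklearn.impute import SimpleImputer
from sklearn.preprocessing import StandardScaler

class StructPreprocSketch:
    """Hypothetical stand-in: median impute + standardize, optional first differences."""

    def __init__(self, add_deltas=True):
        self.add_deltas = add_deltas
        self.imputer = SimpleImputer(strategy="median")
        self.scaler = StandardScaler()

    def fit(self, seqs):
        # Stack visits from TRAIN patients only -> [sum_T, F]; val/test never seen here
        X = np.concatenate(seqs, axis=0)
        self.scaler.fit(self.imputer.fit_transform(X))
        return self

    def _one(self, x):
        z = self.scaler.transform(self.imputer.transform(x))
        if self.add_deltas:
            d = np.diff(z, axis=0, prepend=z[:1])   # Δ per visit; first visit gets 0
            z = np.concatenate([z, d], axis=1)      # [T, 2F]
        return z

    def transform(self, seqs):
        return [self._one(x) for x in seqs]

train = [np.array([[1.0, np.nan], [3.0, 4.0]]), np.array([[5.0, 6.0]])]
val = [np.array([[100.0, np.nan]])]
sp_sketch = StructPreprocSketch().fit(train)     # statistics from train only
out_tr, out_va = sp_sketch.transform(train), sp_sketch.transform(val)
```

The val patient's extreme value (100.0) and its missing entry are handled with train medians and train mean/std. Val statistics never shift the fitted transform.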


In [6]:
rng = np.random.RandomState(SEED)

pos_ids = [s.pid for s in seqs_all_raw if int(s.event) == 1]
neg_ids = [s.pid for s in seqs_all_raw if int(s.event) == 0]

def _split_ids(ids, val_ratio=VAL_RATIO, test_ratio=TEST_RATIO):
    ids = np.array(ids, dtype=str)
    rng.shuffle(ids)
    n = len(ids)
    n_test = int(round(n * test_ratio))
    n_val  = int(round(n * val_ratio))
    test_ids = ids[:n_test].tolist()
    val_ids  = ids[n_test:n_test+n_val].tolist()
    train_ids= ids[n_test+n_val:].tolist()
    return train_ids, val_ids, test_ids

tr_p, va_p, te_p = _split_ids(pos_ids)
tr_n, va_n, te_n = _split_ids(neg_ids)

train_ids = tr_p + tr_n
val_ids   = va_p + va_n
test_ids  = te_p + te_n

rng.shuffle(train_ids); rng.shuffle(val_ids); rng.shuffle(test_ids)
assert_disjoint(train_ids, val_ids, test_ids)

print("Split sizes:", len(train_ids), len(val_ids), len(test_ids))
save_json(os.path.join(OUT_DIR, "splits.json"), {
    "train_ids": train_ids, "val_ids": val_ids, "test_ids": test_ids
})


Split sizes: 113 16 32


In [7]:
idset_tr, idset_va, idset_te = set(train_ids), set(val_ids), set(test_ids)

train_raw = [s for s in seqs_all_raw if s.pid in idset_tr]
val_raw   = [s for s in seqs_all_raw if s.pid in idset_va]
test_raw  = [s for s in seqs_all_raw if s.pid in idset_te]

print("Patients (train/val/test):", len(train_raw), len(val_raw), len(test_raw))
print("Event rate (train/val/test):",
      np.mean([s.event for s in train_raw]),
      np.mean([s.event for s in val_raw]),
      np.mean([s.event for s in test_raw]))


Patients (train/val/test): 113 16 32
Event rate (train/val/test): 0.46017699115044247 0.4375 0.46875


In [8]:
sp = StructPreprocessor(add_deltas=True, add_velocity=False, vel_window=3)
sp.fit(train_raw)

train_proc = sp.transform(train_raw)
val_proc   = sp.transform(val_raw)
test_proc  = sp.transform(test_raw)


In [9]:
sp.save(os.path.join(OUT_DIR, "struct_preproc.pkl"))
save_sequences(os.path.join(OUT_DIR, "seqs_train.pkl"), train_proc)
save_sequences(os.path.join(OUT_DIR, "seqs_val.pkl"),   val_proc)
save_sequences(os.path.join(OUT_DIR, "seqs_test.pkl"),  test_proc)

print("[Saved] struct_preproc.pkl & seqs_{train,val,test}.pkl →", OUT_DIR)


[Saved] struct_preproc.pkl & seqs_{train,val,test}.pkl → AD_Cox_Files/out_step6


In [10]:
print("Sanity(train):", sanity_check_sequences(train_proc))
print("Sanity(val):  ", sanity_check_sequences(val_proc))
print("Sanity(test): ", sanity_check_sequences(test_proc))


Sanity(train): {'num_patients': 113, 'avg_T': 4.398230088495575, 'min_T': 1, 'max_T': 6, 'event_rate': 0.46017699115044247, 'violations': 0}
Sanity(val):   {'num_patients': 16, 'avg_T': 4.0625, 'min_T': 1, 'max_T': 6, 'event_rate': 0.4375, 'violations': 0}
Sanity(test):  {'num_patients': 32, 'avg_T': 4.5625, 'min_T': 1, 'max_T': 6, 'event_rate': 0.46875, 'violations': 0}


## 4) Dataset & Dataloader — Variable-Length Padding & Mask

**Goal**: Turn sequences into batches while preserving variable length.

**What this cell does**
- `ADSequenceDataset` returns (per patient):
  - `x_struct` : `[T, F]` (preprocessed demographics)
  - `x_time`   : `[T]`     (Δt since MCI, used as input time axis)
  - `x_img`    : `[T, 3, 224, 224]` or `None`
  - `event`    : `[]` (0/1)
  - `t_event`  : `[]` (time-to-event or censor time)
  - `pid`      : string
- `pad_collate` builds batch tensors:
  - `X_struct`: `[B, T_max, F]`
  - `X_time`  : `[B, T_max]`
  - `X_img`   : `[B, T_max, 3, 224, 224]` or `None`
  - `mask`    : `[B, T_max]` (True for valid timesteps)
  - `lengths` : `[B]`
  - `event`   : `[B]`
  - `t_event` : `[B]`
  - `pid`     : list of strings

**Note**
- Images use `Resize((224, 224)) → ToTensor → Normalize` (ImageNet mean/std).
- `mask` and `lengths` are used by the LSTM to ignore padding.
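An LSTM can skip padding using `lengths` (equivalently `mask.sum(1)`) together with `torch.nn.utils.rnn.pack_padded_sequence`. `LastStepLSTM` below is a hypothetical module, not the encoder inside `build_sequence_cox_model`. It returns the hidden state at each patient's last valid visit.

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence

class LastStepLSTM(nn.Module):
    """Hypothetical encoder: padded [B, T, D] + lengths -> last valid hidden state [B, H]."""

    def __init__(self, in_dim: int, hidden: int):
        super().__init__()
        self.lstm = nn.LSTM(in_dim, hidden, batch_first=True)

    def forward(self, x: torch.Tensor, lengths: torch.Tensor) -> torch.Tensor:
        # pack_padded_sequence expects lengths as a 1D CPU int64 tensor
        packed = pack_padded_sequence(x, lengths.cpu(), batch_first=True, enforce_sorted=False)
        _, (h_n, _) = self.lstm(packed)   # padded steps are never fed to the LSTM
        return h_n[-1]                    # [B, hidden], original batch order

torch.manual_seed(0)
enc = LastStepLSTM(3, 5).eval()
x = torch.randn(2, 4, 3)
lengths = torch.tensor([4, 2])            # e.g. mask.sum(dim=1)
h = enc(x, lengths)
```

With `enforce_sorted=False`, PyTorch sorts the batch internally and returns `h_n` in the original order, so `pad_collate` does not need to sort by length.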


In [11]:
import os
from torchvision import transforms
from AD_Cox_Files.Utils import load_sequences


SEQ_DIR = "AD_Cox_Files/out_step6"
train_seqs = load_sequences(os.path.join(SEQ_DIR, "seqs_train.pkl"))
val_seqs   = load_sequences(os.path.join(SEQ_DIR, "seqs_val.pkl"))
test_seqs  = load_sequences(os.path.join(SEQ_DIR, "seqs_test.pkl"))

img_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225])
])


In [12]:
def pad_collate(batch):
    import torch

    B = len(batch)
    lengths = [b["x_struct"].shape[0] for b in batch]
    T_max = max(lengths)
    n_feat = batch[0]["x_struct"].shape[1]   # avoid shadowing torch.nn.functional as F

    has_img = any(b["x_img"] is not None for b in batch)

    Xs   = torch.zeros(B, T_max, n_feat, dtype=torch.float32)
    Xt   = torch.zeros(B, T_max,     dtype=torch.float32)
    mask = torch.zeros(B, T_max,     dtype=torch.bool)
    events = torch.zeros(B,          dtype=torch.float32)
    pids = []

    Xi = None
    if has_img:
        Xi = torch.zeros(B, T_max, 3, 224, 224, dtype=torch.float32)

    for i, b in enumerate(batch):
        T = b["x_struct"].shape[0]
        Xs[i, :T] = b["x_struct"]
        Xt[i, :T] = b["x_time"]
        mask[i, :T] = True
        events[i] = b["event"]
        pids.append(b["pid"])
        if has_img and (b["x_img"] is not None):
            Xi[i, :T] = b["x_img"]

    t_event = torch.stack([
        b["t_event"] if isinstance(b["t_event"], torch.Tensor)
        else torch.tensor(float(b["t_event"]), dtype=torch.float32)
        for b in batch
    ]).float()  # [B]

    return {
        "X_struct": Xs, "X_time": Xt, "X_img": Xi,
        "mask": mask, "event": events,
        "t_event": t_event,                 # ★ used here: [B] time-to-event or censoring time
        "lengths": torch.tensor(lengths),
        "pid": pids,
    }


## 5) Model Assembly — 3 Embeddings → Tensor Fusion → LSTM → Cox

**Goal**: Build the multimodal sequence model.

**Components**
- **VisualEmbedder**: CNN backbone (frozen or trainable) → per-visit visual feature.
- **DemographicsEmbedder**: MLP for structured features.
- **TimeEmbedder**: Time2Vec-style time embedding using `X_time` (as `[B, T, 1]`).
- **TensorFusion**: Project each modality to a common dimension and combine them via Hadamard (element-wise) products. These products model multiplicative cross-modal interactions that plain concatenation does not represent explicitly.
- **LSTM**: Variable-length sequence encoder (`pack_padded_sequence` style).
- **Cox head**: Outputs `log_risk` per patient.
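The two components below sketch a Time2Vec-style time embedding and a Hadamard-style fusion. Both classes are hypothetical illustrations of the techniques named above, not the `TimeEmbedder` / `TensorFusion` implementations. In particular, which product terms the real `TensorFusion` keeps, and hence its output width, is an assumption. Time2Vec uses one linear component plus `out_dim - 1` periodic (sine) components.

```python
import torch
import torch.nn as nn

class Time2VecSketch(nn.Module):
    """Hypothetical Time2Vec: [w0*t + b0, sin(W t + b)] -> [..., out_dim]."""

    def __init__(self, out_dim: int):
        super().__init__()
        self.linear = nn.Linear(1, 1)              # trend (non-periodic) component
        self.periodic = nn.Linear(1, out_dim - 1)  # frequencies and phases

    def forward(self, t: torch.Tensor) -> torch.Tensor:  # t: [B, T, 1]
        return torch.cat([self.linear(t), torch.sin(self.periodic(t))], dim=-1)

class HadamardFusionSketch(nn.Module):
    """Hypothetical fusion: unimodal + pairwise + trimodal element-wise products."""

    def __init__(self, v_dim, d_dim, t_dim, proj_dim, dropout=0.1):
        super().__init__()
        self.pv = nn.Linear(v_dim, proj_dim)
        self.pd = nn.Linear(d_dim, proj_dim)
        self.pt = nn.Linear(t_dim, proj_dim)
        self.drop = nn.Dropout(dropout)

    def forward(self, v, d, t):  # each [B, T, *_dim]
        zv, zd, zt = self.pv(v), self.pd(d), self.pt(t)
        fused = torch.cat([zv, zd, zt, zv * zd, zv * zt, zd * zt, zv * zd * zt], dim=-1)
        return self.drop(fused)  # [B, T, 7 * proj_dim]

B, T = 2, 3
t2v = Time2VecSketch(16)
fusion_sketch = HadamardFusionSketch(256, 64, 16, 128)
z = fusion_sketch(torch.randn(B, T, 256), torch.randn(B, T, 64), t2v(torch.rand(B, T, 1)))
```

The original Tensor Fusion takes the outer product of each embedding with a constant 1 appended. The Hadamard variant keeps the same groups of terms (unimodal, bimodal, trimodal) but only their element-wise (diagonal) products, so its size grows linearly in `proj_dim` rather than cubically.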

**Loss**
- Cox partial likelihood with **Efron ties** (penalty λ configurable).
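The Efron-tie partial likelihood can be written as a negative log-likelihood in PyTorch. At each distinct event time with d tied events, Breslow would use the full risk-set sum d times. Efron instead subtracts a fraction l/d (l = 0..d-1) of the tied patients' risk from the risk-set sum. `efron_nll` is a hypothetical sketch. Averaging over the number of events is an assumption, and the λ penalty of `build_sequence_cox_model` is omitted because its exact form is not shown here.

```python
import torch

def efron_nll(times: torch.Tensor, events: torch.Tensor, log_risk: torch.Tensor) -> torch.Tensor:
    """Hypothetical Efron negative log partial likelihood, averaged over events.

    -log L = sum_j [ -sum_{i in D_j} eta_i
                     + sum_{l=0}^{d_j-1} log( sum_{R_j} e^eta - (l/d_j) sum_{D_j} e^eta ) ]
    """
    times, events = times.float(), events.float()
    c = log_risk.max().detach()              # shift: log sum e^eta = c + log sum e^(eta - c)
    exp_r = torch.exp(log_risk - c)
    nll = log_risk.new_zeros(())
    for t in torch.unique(times[events == 1]):            # distinct event times
        tied = (times == t) & (events == 1)               # D_j
        d = int(tied.sum())
        risk_sum = exp_r[times >= t].sum()                # sum over risk set R_j
        tie_sum = exp_r[tied].sum()                       # sum over D_j
        frac = torch.arange(d, dtype=exp_r.dtype, device=exp_r.device) / d   # l / d_j
        log_terms = torch.log(risk_sum - frac * tie_sum) + c
        nll = nll - log_risk[tied].sum() + log_terms.sum()
    return nll / events.sum().clamp(min=1.0)
```

With two tied events at t = 1, a third patient at risk, and all η = 0, Efron gives (log 3 + log 2) / 2. Breslow would give 2·log 3 / 2.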

**Forward signature**
- `model(imgs=X_img, demos=X_struct, times=X_time.unsqueeze(-1), lengths=lengths) → log_risk [B]`


In [13]:
import torch
from torch.utils.data import DataLoader
from torchvision import transforms

from AD_Cox_Files.ADSequenceDataset import ADSequenceDataset
from AD_Cox_Files.Utils import load_sequences

from AD_Cox_Files.VisualEmbedder import VisualEmbedder
from AD_Cox_Files.DemographicsEmbedder import DemographicsEmbedder
from AD_Cox_Files.TimeEmbedder import TimeEmbedder
from AD_Cox_Files.TensorFusion import TensorFusion
from AD_Cox_Files.CoxModel import build_sequence_cox_model
import torch.multiprocessing as mp
mp.set_start_method("spawn", force=True)



# 1) Data loading
SEQ_DIR = "AD_Cox_Files/out_step6"
train = load_sequences(f"{SEQ_DIR}/seqs_train.pkl")
val   = load_sequences(f"{SEQ_DIR}/seqs_val.pkl")

img_tfm = transforms.Compose([
    transforms.Resize((224,224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225]),
])

train_ds = ADSequenceDataset(train, img_transform=img_tfm, load_images=True)
val_ds   = ADSequenceDataset(val,   img_transform=img_tfm, load_images=True)

# DataLoader
train_loader = DataLoader(train_ds, batch_size=4, shuffle=True,
                          num_workers=0, collate_fn=pad_collate, pin_memory=False)
val_loader   = DataLoader(val_ds,   batch_size=8, shuffle=False,
                          num_workers=0, collate_fn=pad_collate, pin_memory=False)


# 2) Embeddings / Fusion / Model assembly
F_struct = train[0].struct.shape[1]   # demographics in-dim
v = VisualEmbedder(out_dim=256, train_backbone=False)  # VGG16 backbone
d = DemographicsEmbedder(in_dim=F_struct, out_dim=64)
t = TimeEmbedder(in_dim=1, out_dim=16)

fusion = TensorFusion(v_dim=256, d_dim=64, t_dim=16, proj_dim=128, dropout=0.1)

model, loss_fn = build_sequence_cox_model(
    visual_embedder=v, demo_embedder=d, time_embedder=t, fusion_module=fusion,
    lstm_hidden=64, lstm_layers=1, bidirectional=False, dropout=0.1,
    penalty=1e-4,  # Efron loss penalty λ
)

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

# 3) Smoke test (single batch)
batch = next(iter(train_loader))
log_risk = model(
    imgs=batch["X_img"].to(device),            # [B,T,3,224,224]
    demos=batch["X_struct"].to(device),        # [B,T,F]
    times=batch["X_time"].unsqueeze(-1).to(device),  # [B,T,1]
    lengths=batch["lengths"],
)
loss = loss_fn(batch["t_event"].to(device), batch["event"].to(device), log_risk)
print("log_risk:", tuple(log_risk.shape), "loss:", float(loss))


log_risk: (4,) loss: 1.2519538402557373


## 6) Smoke Test — Single-Batch Forward & Loss

**Goal**: Verify shapes and loss computation before training.

**What to expect**
- `log_risk` shape: `[B]`.
- `loss_fn(t_event, event, log_risk)` returns a scalar (0-dim) tensor; `float(loss)` converts it for printing.
- If this passes, the dataloader → model → loss pipeline is wired correctly.


In [14]:
import numpy as np
import torch, os

def c_index_global(times, events, risks):
    """
    times, events, risks: 1D tensors on CPU
      - times = t_last (B,)
      - events = {0,1} (B,)
      - risks = log-risk or risk (B,)
    """
    t = times.numpy()
    e = events.numpy().astype(bool)
    r = risks.numpy()
    num = 0.0; den = 0.0
    n = len(t)
    for i in range(n):
        if not e[i]:
            continue
        for j in range(n):
            if t[i] < t[j]:
                den += 1.0
                if   r[i] > r[j]: num += 1.0
                elif r[i] == r[j]: num += 0.5
    return (num/den) if den > 0 else float("nan")

@torch.no_grad()
def evaluate_epoch(model, loader, loss_fn, device):
    model.eval()
    all_t, all_e, all_r = [], [], []
    total_loss, total_n = 0.0, 0
    for batch in loader:
        imgs   = batch["X_img"].to(device) if batch["X_img"] is not None else None
        demos  = batch["X_struct"].to(device)
        times  = batch["X_time"].unsqueeze(-1).to(device)
        lens   = batch["lengths"]

        t_event = batch["t_event"].to(device)

        event  = batch["event"].to(device)

        log_risk = model(imgs=imgs, demos=demos, times=times, lengths=lens)
        loss = loss_fn(t_event, event, log_risk)

        B = t_event.shape[0]
        total_loss += float(loss) * B
        total_n    += B

        all_t.append(t_event.detach().cpu())
        all_e.append(event.detach().cpu())
        all_r.append(log_risk.detach().cpu())

    all_t = torch.cat(all_t, dim=0)
    all_e = torch.cat(all_e, dim=0)
    all_r = torch.cat(all_r, dim=0)

    c = c_index_global(all_t, all_e, all_r)
    return total_loss/total_n, c


def train_epoch(model, loader, loss_fn, optimizer, device, grad_clip=1.0, amp=True):
    model.train()
    use_amp = amp and torch.cuda.is_available()
    # torch.amp.* replaces the deprecated torch.cuda.amp.* API (PyTorch >= 2.3)
    scaler = torch.amp.GradScaler("cuda", enabled=use_amp)
    total_loss, total_n = 0.0, 0

    for batch in loader:
        imgs   = batch["X_img"].to(device) if batch["X_img"] is not None else None
        demos  = batch["X_struct"].to(device)
        times  = batch["X_time"].unsqueeze(-1).to(device)   # [B,T,1]
        lens   = batch["lengths"]
        t_event = batch["t_event"].to(device)
        event  = batch["event"].to(device)

        optimizer.zero_grad(set_to_none=True)
        with torch.amp.autocast("cuda", enabled=use_amp):
            log_risk = model(imgs=imgs, demos=demos, times=times, lengths=lens)  # [B]
            loss = loss_fn(t_event, event, log_risk)

        scaler.scale(loss).backward()
        if grad_clip is not None:
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
        scaler.step(optimizer)
        scaler.update()

        B = t_event.shape[0]
        total_loss += float(loss) * B
        total_n    += B

    return total_loss / max(total_n, 1)



## 7) Training Loop — Optimization / Checkpoints / Early Stop

**Goal**: Train the sequence model with validation monitoring.

**What this cell does**
- Define the optimizer (AdamW, `lr=3e-4`, `weight_decay=1e-4`) and the LR scheduler (`ReduceLROnPlateau` on val loss).
- Training epoch:
  - Forward with AMP (optional), compute Cox loss, backprop, grad-clip.
- Validation epoch:
  - Compute loss and **C-index** (Harrell's C over comparable pairs).
- Save:
  - `best_model.pt`: highest **val C-index**, with lower val loss as the tie-breaker.
  - `last_model.pt` at the end.
- Stop early after `PATIENCE` epochs without a new best checkpoint.
- Print per-epoch log: `train loss | val loss | val C`.
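The double loop in `c_index_global` can be vectorized with NumPy broadcasting. `harrell_c` is a hypothetical drop-in sketch that follows the same conventions as `c_index_global`: a pair (i, j) is comparable when patient i had the event and t_i < t_j, and tied risks count 0.5. It uses O(n²) memory, which is fine for cohorts of this size.

```python
import numpy as np

def harrell_c(times, events, risks):
    """Hypothetical vectorized Harrell's C (higher risk should mean earlier event)."""
    t = np.asarray(times, dtype=float)
    e = np.asarray(events).astype(bool)
    r = np.asarray(risks, dtype=float)
    # comparable[i, j]: i had the event and j was still event-free at t_i
    comparable = e[:, None] & (t[:, None] < t[None, :])
    concordant = comparable & (r[:, None] > r[None, :])
    tied = comparable & (r[:, None] == r[None, :])
    den = comparable.sum()
    return (concordant.sum() + 0.5 * tied.sum()) / den if den > 0 else float("nan")

print(harrell_c([1, 2, 3], [1, 1, 0], [3, 2, 1]))   # earliest event has highest risk
```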

**Tip**
- On Windows/Jupyter, set `num_workers=0` first; increase only if stable.


In [15]:
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=2)

EPOCHS = 10
BEST_METRIC = -1.0            # best val C-index so far (maximize)
BEST_VAL_LOSS = float("inf")  # val loss of the best checkpoint (tie-breaker)
PATIENCE = 5
pat = 0

CKPT_DIR = "AD_Cox_Files/checkpoints"
os.makedirs(CKPT_DIR, exist_ok=True)
BEST_PATH = os.path.join(CKPT_DIR, "best_model.pt")
LAST_PATH = os.path.join(CKPT_DIR, "last_model.pt")

for ep in range(1, EPOCHS+1):
    tr_loss = train_epoch(model, train_loader, loss_fn, optimizer, device, grad_clip=1.0, amp=True)
    va_loss, va_c = evaluate_epoch(model, val_loader, loss_fn, device)
    scheduler.step(va_loss)

    print(f"[{ep:02d}] train loss {tr_loss:.4f} | val loss {va_loss:.4f} | val C {va_c:.3f}")

    # Best by C-index; tie-breaker: lower val loss than the saved best.
    # (scheduler.best is already updated by scheduler.step(va_loss) above,
    #  so comparing against it would never trigger the tie-breaker.)
    is_best = (va_c > BEST_METRIC) or (np.isclose(va_c, BEST_METRIC) and va_loss < BEST_VAL_LOSS)
    if is_best:
        BEST_METRIC = va_c
        BEST_VAL_LOSS = va_loss
        pat = 0
        torch.save({"model": model.state_dict(),
                    "opt": optimizer.state_dict(),
                    "best_val_c": BEST_METRIC,
                    "val_loss": va_loss}, BEST_PATH)
    else:
        pat += 1
        if pat >= PATIENCE:
            print("Early stopping triggered.")
            break

torch.save({"model": model.state_dict(),
            "opt": optimizer.state_dict()}, LAST_PATH)

print(f"Best model @ {BEST_PATH}  (val C={BEST_METRIC:.3f})")




[01] train loss 0.5864 | val loss 0.8033 | val C 0.923
[02] train loss 0.3491 | val loss 0.7376 | val C 0.923
[03] train loss 0.2499 | val loss 0.6623 | val C 0.923
[04] train loss 0.2433 | val loss 0.6908 | val C 0.885
[05] train loss 0.2102 | val loss 0.6101 | val C 0.923
[06] train loss 0.1909 | val loss 0.5204 | val C 0.987
[07] train loss 0.1785 | val loss 0.5825 | val C 0.962
[08] train loss 0.1814 | val loss 0.5733 | val C 0.923
[09] train loss 0.1485 | val loss 0.5230 | val C 0.923
[10] train loss 0.1443 | val loss 0.4993 | val C 0.923
Best model @ AD_Cox_Files/checkpoints\best_model.pt  (val C=0.987)


## 8) Evaluation — Val/Test C-index & Visit-Trimming Checks

**Goal**: Confirm generalization and inspect robustness to visit coverage.

**What this cell does**
- Load `seqs_test.pkl`, evaluate loss/C-index on the held-out test cohort.
- Sensitivity tests:
  - Remove last k visits and re-evaluate (does performance collapse?).
  - First visit only (harder).
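- A test C close to the val C suggests the model generalizes. The val split is small (16 patients), though, so val C is noisy. The trimming checks address the concern that the model relies only on the very last visit before conversion.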
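The visit-trimming checks can be sketched as list transforms applied before building `ADSequenceDataset`. `SeqSketch` is a hypothetical stand-in for `PatientSeq`. Its `times` and `img_paths` field names are assumptions based on the Step 5 description. Labels (`event`, `t_event`) are left unchanged, so the model must predict the same outcome from fewer visits.

```python
from dataclasses import dataclass, replace
from typing import List

import numpy as np

@dataclass
class SeqSketch:
    """Hypothetical stand-in for PatientSeq (field names are assumptions)."""
    pid: str
    times: np.ndarray        # [T] years since MCI
    struct: np.ndarray       # [T, F] preprocessed features
    img_paths: List[str]     # length T
    event: int
    t_event: float

def _keep_first(s: SeqSketch, keep: int) -> SeqSketch:
    # Slice every per-visit field; targets stay as they were
    return replace(s, times=s.times[:keep], struct=s.struct[:keep], img_paths=s.img_paths[:keep])

def drop_last_k(seqs: List[SeqSketch], k: int, min_len: int = 1) -> List[SeqSketch]:
    """Remove the last k visits, always keeping at least `min_len` visits."""
    return [_keep_first(s, max(len(s.times) - k, min_len)) for s in seqs]

def first_visit_only(seqs: List[SeqSketch]) -> List[SeqSketch]:
    return [_keep_first(s, 1) for s in seqs]

seq = SeqSketch("p1", np.array([0.0, 0.5, 1.0]), np.zeros((3, 2)), ["a", "b", "c"], 1, 1.5)
trimmed = drop_last_k([seq], 2)[0]
```

The trimmed lists could then be wrapped in `ADSequenceDataset` and scored with `evaluate_epoch`, as in the test cell. Because of `min_len=1`, single-visit patients are left unchanged.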


In [16]:
from AD_Cox_Files.Utils import load_sequences
from AD_Cox_Files.ADSequenceDataset import ADSequenceDataset
from torch.utils.data import DataLoader
from torchvision import transforms

# test loader
test = load_sequences("AD_Cox_Files/out_step6/seqs_test.pkl")
img_tfm = transforms.Compose([
    transforms.Resize((224,224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225]),
])
test_ds = ADSequenceDataset(test, img_transform=img_tfm, load_images=True)
test_loader = DataLoader(test_ds, batch_size=8, shuffle=False, num_workers=0, collate_fn=pad_collate)

ckpt = torch.load(BEST_PATH, map_location=device)
model.load_state_dict(ckpt["model"])
test_loss, test_c = evaluate_epoch(model, test_loader, loss_fn, device)
print(f"[TEST] loss {test_loss:.4f} | C-index {test_c:.3f}")


[TEST] loss 0.4006 | C-index 0.962


In [17]:
import pandas as pd

@torch.no_grad()
def collect_preds(loader, split_name="test"):
    model.eval()
    rows = []
    for b in loader:
        imgs   = b["X_img"].to(device) if b["X_img"] is not None else None
        demos  = b["X_struct"].to(device)
        times  = b["X_time"].unsqueeze(-1).to(device)
        lens   = b["lengths"]
        t_last = b["t_event"]
        event  = b["event"]

        log_risk = model(imgs=imgs, demos=demos, times=times, lengths=lens).cpu()
        for i, pid in enumerate(b["pid"]):
            rows.append({
                "pid": pid,
                "t_last": float(t_last[i]),
                "event": int(event[i]),
                "log_risk": float(log_risk[i]),
                "split": split_name,
            })
    df = pd.DataFrame(rows)
    df["risk_rank"] = df["log_risk"].rank(ascending=False, method="dense")
    return df

df_test = collect_preds(test_loader, "test")
df_test.to_csv("AD_Cox_Files/test_preds.csv", index=False)
df_test.head()


,pid,t_last,event,log_risk,split,risk_rank
0,002_S_0729,1.0924,1,1.758306,test,4.0
1,002_S_1155,3.0089,0,-3.738484,test,30.0
2,005_S_0546,3.04997,0,-3.657618,test,25.0
3,006_S_1130,1.48939,1,0.894991,test,10.0
4,007_S_0041,1.49487,1,1.153793,test,8.0


In [18]:
import numpy as np
@torch.no_grad()
def collect_times_events_risks(loader):
    model.eval()  # disable dropout so risk scores are deterministic
    t_list, e_list, r_list = [], [], []
    for b in loader:
        imgs  = b["X_img"].to(device) if b["X_img"] is not None else None
        demos = b["X_struct"].to(device)
        times = b["X_time"].unsqueeze(-1).to(device)
        lens  = b["lengths"]
        t_last= b["t_event"].cpu().numpy()
        event = b["event"].cpu().numpy().astype(int)
        log_r = model(imgs=imgs, demos=demos, times=times, lengths=lens).cpu().numpy()
        t_list.append(t_last); e_list.append(event); r_list.append(log_r)
    t = np.concatenate(t_list); e = np.concatenate(e_list); r = np.concatenate(r_list)
    return t, e, r

train_loader_eval = DataLoader(train_ds, batch_size=8, shuffle=False, num_workers=0, collate_fn=pad_collate)
t_tr, e_tr, r_tr = collect_times_events_risks(train_loader_eval)

# Breslow baseline cumulative hazard H0(t)
order = np.argsort(t_tr)
t_sorted, e_sorted, r_sorted = t_tr[order], e_tr[order], r_tr[order]
unique_times = np.unique(t_sorted[e_sorted==1])
H0 = []
cum = 0.0
for u in unique_times:
    d_u = ( (t_sorted==u) & (e_sorted==1) ).sum()
    risk_set = np.exp(r_sorted[t_sorted >= u]).sum()
    cum += d_u / max(risk_set, 1e-12)
    H0.append(cum)
H0 = np.array(H0)

def S0_at(horizon):
    # step function: last H0 at time <= horizon
    if len(unique_times)==0: return 1.0
    idx = np.searchsorted(unique_times, horizon, side="right") - 1
    if idx < 0: return 1.0
    return float(np.exp(-H0[idx]))

# Compute per-patient survival probabilities at 1/2/3 years on the test set
t_te, e_te, r_te = collect_times_events_risks(test_loader)
horizons = [1.0, 2.0, 3.0]  # years from MCI
S0_vec = [S0_at(h) for h in horizons]
surv = {f"S@{h}y": [ (S0_vec[i] ** float(np.exp(r))) for r in r_te ] for i,h in enumerate(horizons)}
import pandas as pd
df_surv = pd.DataFrame({"t_last": t_te, "event": e_te, "log_risk": r_te, **surv})
df_surv.to_csv("AD_Cox_Files/test_survival_estimates.csv", index=False)
df_surv.head()


,t_last,event,log_risk,S@1.0y,S@2.0y,S@3.0y
0,1.0924,1,1.758306,0.563511,0.000116,1.379151e-11
1,3.0089,0,-3.738484,0.997651,0.963517,0.9025544
2,3.04997,0,-3.657618,0.997454,0.960506,0.8947941
3,1.48939,1,0.894991,0.785127,0.021858,2.627263e-05
4,1.49487,1,1.153793,0.730984,0.007066,1.165548e-06


In [19]:
import numpy as np
import torch

@torch.no_grad()
def collect_times_events_risks(model, loader, device):
    model.eval()
    T, E, R = [], [], []
    for b in loader:
        imgs  = b["X_img"].to(device) if b["X_img"] is not None else None
        demos = b["X_struct"].to(device)
        times = b["X_time"].unsqueeze(-1).to(device)
        lens  = b["lengths"]
        t_ev  = b["t_event"].cpu().numpy().astype(float)
        e     = b["event"].cpu().numpy().astype(int)
        r     = model(imgs=imgs, demos=demos, times=times, lengths=lens).cpu().numpy()  # log-risk
        T.append(t_ev); E.append(e); R.append(r)
    t = np.concatenate(T); e = np.concatenate(E); eta = np.concatenate(R)
    return t, e, eta

def fit_breslow_baseline(t, e, eta):
    """Estimate the Breslow baseline cumulative hazard H0(t)."""
    event_times = np.sort(np.unique(t[e == 1]))
    if event_times.size == 0:
        return np.array([]), np.array([]), np.array([])

    exp_eta = np.exp(eta)
    H0 = []
    cum = 0.0
    for u in event_times:
        d_u = np.sum((t == u) & (e == 1))
        risk_set = np.sum(exp_eta[t >= u])
        dh = d_u / max(risk_set, 1e-12)
        cum += dh
        H0.append(cum)
    H0 = np.asarray(H0)
    S0 = np.exp(-H0)
    return event_times, H0, S0

t_tr, e_tr, eta_tr = collect_times_events_risks(model, train_loader, device)
u_times, H0, S0 = fit_breslow_baseline(t_tr, e_tr, eta_tr)
print(f"Baseline fitted with {len(u_times)} unique event times.")


Baseline fitted with 47 unique event times.


In [20]:
import pandas as pd

def S0_at(h, u_times, H0):
    """Stepwise baseline survival S0(h) = exp(-H0(h)) at horizon h (right-continuous)."""

    if len(u_times) == 0:
        return 1.0
    idx = np.searchsorted(u_times, h, side="right") - 1
    if idx < 0:
        return 1.0
    return float(np.exp(-H0[idx]))

def predict_time_quantile(log_risk, q, u_times, H0):
    """
    Earliest event time t at which the individual survival S_i(t) = S0(t)^{exp(eta)} drops to or below q.
    With q = 0.5 this is the median time. If S_i never reaches q within the observed time grid, return np.nan.
    """

    if len(u_times) == 0:
        return np.nan
    gamma = float(np.exp(log_risk))
    # S_i(t_k) = [exp(-H0_k)]**gamma
    S_i = np.exp(-H0 * gamma)
    # index of the first step where S_i <= q
    idx = np.argmax(S_i <= q)  # first True position (returns 0 if all False)
    if S_i[idx] > q:           # if survival never drops below q
        return np.nan
    return float(u_times[idx])

@torch.no_grad()
def make_test_time_predictions(model, loader, device, u_times, H0, horizons=[1.0,2.0,3.0]):
    model.eval()
    rows = []
    S0_cache = {h: S0_at(h, u_times, H0) for h in horizons}
    for b in loader:
        imgs  = b["X_img"].to(device) if b["X_img"] is not None else None
        demos = b["X_struct"].to(device)
        times = b["X_time"].unsqueeze(-1).to(device)
        lens  = b["lengths"]
        pids  = b["pid"]
        log_r = model(imgs=imgs, demos=demos, times=times, lengths=lens).cpu().numpy()

        for i, pid in enumerate(pids):
            lr = float(log_r[i])
            gamma = float(np.exp(lr))
            # Predicted conversion times (median / 25% / 75%)
            t50 = predict_time_quantile(lr, 0.5,  u_times, H0)
            t25 = predict_time_quantile(lr, 0.75, u_times, H0)  # 75% survival → 25% converted (earlier time)
            t75 = predict_time_quantile(lr, 0.25, u_times, H0)  # 25% survival → 75% converted (later time)
            # Conversion probability within a given horizon (1 - S_i(h))
            probs = {}
            for h in horizons:
                S0h = S0_cache[h]
                Sih = S0h ** gamma
                probs[f"p_AD_by_{h:.1f}y"] = float(1.0 - Sih)
            rows.append({
                "pid": pid,
                "log_risk": lr,
                "risk": gamma,
                "t_median": t50,    # predicted conversion time (median)
                "t_25pct": t25,     # faster-side marker (25% converted; 75% survival)
                "t_75pct": t75,     # slower-side marker (75% converted; 25% survival)
                **probs
            })
    return pd.DataFrame(rows)

# --- Run: load the best model and then run test predictions ---
# (you can skip this if BEST_PATH has already been loaded)
# ckpt = torch.load(BEST_PATH, map_location=device)
# model.load_state_dict(ckpt["model"])

test = load_sequences("AD_Cox_Files/out_step6/seqs_test.pkl")
test_ds = ADSequenceDataset(test, img_transform=img_tfm, load_images=True)
test_loader = DataLoader(test_ds, batch_size=8, shuffle=False, num_workers=0, collate_fn=pad_collate)

df_pred = make_test_time_predictions(model, test_loader, device, u_times, H0, horizons=[1.0, 2.0, 3.0])
display(df_pred.head(10))
df_pred.to_csv("AD_Cox_Files/test_time_to_AD_predictions.csv", index=False)
print("Saved → AD_Cox_Files/test_time_to_AD_predictions.csv")


,pid,log_risk,risk,t_median,t_25pct,t_75pct,p_AD_by_1.0y,p_AD_by_2.0y,p_AD_by_3.0y
0,002_S_0729,1.758306,5.802601,1.03491,0.977413,1.49487,0.43649,0.999884,1.0
1,002_S_1155,-3.738484,0.02379,,3.08282,,0.002349,0.036483,0.097446
2,005_S_0546,-3.657618,0.025794,,3.04175,,0.002546,0.039494,0.105206
3,006_S_1130,0.894991,2.447315,1.50308,1.03491,1.67009,0.214873,0.978142,0.999974
4,007_S_0041,1.153793,3.170196,1.483917,0.999316,1.5332,0.269016,0.992934,0.999999
5,007_S_0293,-0.590058,0.554295,1.98768,1.67009,2.507872,0.053316,0.579337,0.90826
6,011_S_0241,1.830521,6.237134,1.01574,0.974675,1.483917,0.460181,0.999941,1.0
7,011_S_0362,-1.972143,0.139158,3.00342,2.04244,3.04175,0.013661,0.195387,0.451034
8,016_S_1121,1.352472,3.866971,1.47296,0.996578,1.5332,0.317668,0.997621,1.0
9,021_S_0273,-3.337314,0.035532,3.30459,3.02533,,0.003506,0.053996,0.141982


Saved → AD_Cox_Files/test_time_to_AD_predictions.csv
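Under proportional hazards, every patient's horizon probability is built from the same baseline value: \(1 - S_i(h) = 1 - S_0(h)^{\exp(\eta_i)}\). The sketch below is a quick consistency check on the table above. It backs out \(H_0(1\,\text{yr}) = -\ln S_0(1)\) from row 0 and then reproduces `p_AD_by_1.0y` for other rows from their `risk` alone. It uses only numbers printed in the table. `h0_from_row`, `prob_by_h` and `table_rows` are illustrative names and are not part of the notebook.

```python
import math

def h0_from_row(p_by_h: float, risk: float) -> float:
    """Invert p = 1 - exp(-H0(h) * risk) to recover the shared baseline H0(h)."""
    return -math.log(1.0 - p_by_h) / risk

def prob_by_h(H0_h: float, risk: float) -> float:
    """P(AD by h) = 1 - S0(h) ** risk, with S0(h) = exp(-H0(h))."""
    return 1.0 - math.exp(-H0_h) ** risk

# Back out H0(1 yr) from row 0 (002_S_0729): risk = exp(log_risk), p_AD_by_1.0y
H0_1y = h0_from_row(0.43649, 5.802601)

# Other rows should follow from their own risk and the same H0(1 yr)
table_rows = [("002_S_1155", 0.02379, 0.002349),
              ("005_S_0546", 0.025794, 0.002546),
              ("011_S_0241", 6.237134, 0.460181)]
for pid, risk, p_table in table_rows:
    print(f"{pid}: recomputed {prob_by_h(H0_1y, risk):.6f} | table {p_table:.6f}")
```

The recovered \(H_0(1\,\text{yr})\) is about 0.099, and the recomputed probabilities match the table to its printed precision. This is what the `S0h ** gamma` line in `make_test_time_predictions` relies on.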


## 9) Baseline Hazard & Time-to-AD Predictions (Breslow + Median)

**Goal**: Turn risk scores into **time-to-AD** predictions.

**Principle**
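- Compute the **Breslow baseline cumulative hazard** \(H_0(t)\) on the **train split only**:
  \[
  H_0(t_k) = \sum_{u \le t_k} \frac{d(u)}{\sum_{j : t_j \ge u} \exp(\eta_j)},
  \]
  where \(u\) runs over the distinct event times, \(d(u)\) is the number of conversions at \(u\), and the denominator sums \(\exp(\eta_j)\) over the risk set (patients still under observation, \(t_j \ge u\)).
- For each test patient \(i\) with log-risk \(\eta_i\), the predicted AD-free survival is:
  \[
  S_i(t) = \exp\big(-H_0(t)\, e^{\eta_i}\big).
  \]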
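- Predict the **median time-to-AD** as the earliest \(t\) with \(S_i(t) \le 0.5\), scanning a 0.01-yr grid up to a 3.5-yr horizon.
- Print one line per patient:
  - `pred: X yrs | true: Y yrs | event: {0/1} | error: |X−Y| → STATUS`
    - Converted (event = 1): **HIT** if \(|X-Y| \le 0.2\) yr, otherwise **MISS**. The status is **MISMATCH** if \(S_i\) never reaches 0.5 within the horizon.
    - Censored (event = 0): **OK** if \(X \ge\) the censoring time, otherwise **FAIL**. The status is **SURVIVED** if \(S_i\) never reaches 0.5 within the horizon.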
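Because \(S_i(t)\) decreases monotonically in \(H_0(t)\), the median, or any survival quantile \(q\), reduces to a single threshold on the shared baseline hazard. This is a short derivation from the two formulas above:
\[
S_i(t) \le q \iff H_0(t)\, e^{\eta_i} \ge -\ln q \iff H_0(t) \ge -\ln q \cdot e^{-\eta_i},
\]
so \(\hat t_i^{(q)} = \min\{t : H_0(t) \ge -\ln q \cdot e^{-\eta_i}\}\), with \(-\ln 0.5 = \ln 2 \approx 0.693\) for the median. If \(\ln 2 \cdot e^{-\eta_i}\) exceeds the largest \(H_0\) value reachable within the search horizon, there is no crossing and the patient is reported as "survive". Low-risk patients (very negative \(\eta_i\)) are the first to fall into this case.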

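**Why it matters**
- Converts the risk ranking into interpretable **time predictions**. The Cox partial likelihood only identifies relative risks, so absolute times need a separately estimated \(H_0\). Estimating \(H_0\) on the train split keeps the test set untouched.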
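The Breslow formula and the median rule can be checked by hand on a tiny cohort. The sketch below uses a hypothetical 5-patient toy dataset with all \(\eta_j = 0\). In that special case Breslow reduces to the Nelson-Aalen estimator, so the increments are just \(d(u)/n(u)\). `breslow_H0`, `median_time` and the `toy_*` names are illustrative and are not the notebook's functions. Toy names are used so the real `H0`/`uniq_times` variables are not overwritten. One difference from the cell: the median here is read off the step function at event times, whereas the cell scans a linearly interpolated \(H_0\), so the two can differ slightly.

```python
import numpy as np

def breslow_H0(times, events, log_risk):
    """Breslow H0 at each distinct event time, following the formula above."""
    times, events, log_risk = map(np.asarray, (times, events, log_risk))
    exp_eta = np.exp(log_risk)
    event_times = np.unique(times[events == 1])
    H0, cum = [], 0.0
    for u in event_times:
        d_u = np.sum((times == u) & (events == 1))   # conversions at time u
        cum += d_u / exp_eta[times >= u].sum()       # risk set: still observed at u (t_j >= u)
        H0.append(cum)
    return event_times, np.array(H0)

def median_time(event_times, H0, eta_i, threshold=0.5):
    """Earliest event time with S_i(t) = exp(-H0(t) e^eta_i) <= threshold (nan if none)."""
    S = np.exp(-H0 * np.exp(eta_i))
    hit = np.nonzero(S <= threshold)[0]
    return float(event_times[hit[0]]) if hit.size else float("nan")

# Hypothetical 5-patient cohort with all eta_j = 0 (Breslow then equals Nelson-Aalen)
toy_t = [1.0, 2.0, 2.0, 3.0, 4.0]
toy_e = [1, 1, 0, 1, 0]
toy_u, toy_H0 = breslow_H0(toy_t, toy_e, np.zeros(5))
print(toy_u, toy_H0)                                  # increments 1/5, 1/4, 1/2
print(median_time(toy_u, toy_H0, eta_i=0.0))          # S(2) ≈ 0.64, S(3) ≈ 0.39
print(median_time(toy_u, toy_H0, eta_i=np.log(2.0)))  # doubled hazard: S(2) ≈ 0.41
```

Doubling the hazard (\(\eta_i = \ln 2\)) moves the median from 3 to 2. A very negative \(\eta_i\) never crosses 0.5 and returns `nan`, which corresponds to the "survive" case.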


In [21]:
import numpy as np
import torch

# ---------------------------
# 1) Baseline hazard (Breslow)
# ---------------------------
def estimate_baseline_hazard(times, events, risks):
    """
    Breslow: H0(t_k) = Σ_{u<=t_k} d(u) / Σ_{j: t_j >= u} exp(risk_j)
    times: (N,), events: (N,), risks: (N,)   # risks = log-risk (η); exp(risks) = risk
    Returns: (uniq_times_sorted, H0_cumsum)  # ascending time grid
    """
    order = np.argsort(times)               # ↑ ascending
    t = times[order].astype(float)
    e = events[order].astype(int)
    eta = risks[order].astype(float)

    uniq = np.unique(t[e == 1])
    if uniq.size == 0:
        return np.array([]), np.array([])

    exp_eta = np.exp(eta)
    H0_vals = []
    cum = 0.0
    for u in uniq:
        d_u = np.sum((t == u) & (e == 1))
        denom = np.sum(exp_eta[t >= u])
        if denom > 0:
            cum += d_u / denom
        H0_vals.append(cum)
    return uniq, np.array(H0_vals, dtype=float)

def make_H0_func(uniq_times, H0_vals):
    """ Continuous interpolation for H0(u). (left=0, right=last-value extrapolation) """
    if len(uniq_times) == 0:
        return lambda u: 0.0
    def H0_func(u):
        return float(np.interp(u, uniq_times, H0_vals, left=0.0, right=H0_vals[-1]))
    return H0_func

# ---------------------------
# 2) Collect arrays (loader → numpy)
# ---------------------------
@torch.no_grad()
def collect_arrays(model, loader, device):
    model.eval()
    T, E, R, P = [], [], [], []
    for b in loader:
        imgs  = b["X_img"].to(device) if b["X_img"] is not None else None
        demos = b["X_struct"].to(device)
        times = b["X_time"].unsqueeze(-1).to(device)
        lens  = b["lengths"]
        t_ev  = b["t_event"].cpu().numpy().astype(float)
        e     = b["event"].cpu().numpy().astype(int)
        log_r = model(imgs=imgs, demos=demos, times=times, lengths=lens).cpu().numpy().reshape(-1)
        T.append(t_ev); E.append(e); R.append(log_r); P.extend(b["pid"])
    return np.concatenate(T), np.concatenate(E), np.concatenate(R), P

# ---------------------------
# 3) Baseline estimation on the train set
# ---------------------------

t_tr, e_tr, r_tr, _ = collect_arrays(model, train_loader, device)
uniq_times, H0_vals = estimate_baseline_hazard(t_tr, e_tr, r_tr)
H0_func = make_H0_func(uniq_times, H0_vals)
print(f"[Baseline] unique event times = {len(uniq_times)}")

# ---------------------------
# 4) Per-patient predictions & accuracy on the test set
# ---------------------------

t_te, e_te, r_te, pid_te = collect_arrays(model, test_loader, device)

print("\n📋 Predicted AD conversion time (Cox model, ±0.2 yr tolerance)\n")

tol = 0.2          # tolerance (years)
max_years = 3.5    # maximum search horizon
step = 0.01        # scan in 0.01-year increments
threshold = 0.5    # median (50%) threshold


hit_count = 0
total_count = 0

for i, (pid, risk_score, t_true, e_true) in enumerate(zip(pid_te, r_te, t_te, e_te)):
    found_time = None
    for u in np.arange(step, max_years + step/2, step):
        H_t = H0_func(u) * np.exp(risk_score)  # H_i(u) = H0(u) * exp(η_i)
        S_t = np.exp(-H_t)
        if S_t <= threshold:               # earliest u with S_i(u) <= 0.5, as stated in the markdown
            found_time = round(u, 2)
            break

    if found_time is not None:
        error = abs(found_time - t_true)
        if e_true == 1:
            is_hit = (error <= tol)
            status = "✅ HIT" if is_hit else "❌ MISS"
            if is_hit: hit_count += 1
        else:
            is_hit = (found_time >= t_true)
            status = "✅ OK (censored)" if is_hit else "❌ FAIL (censored)"
            if is_hit: hit_count += 1
        print(f"  - {pid:>10s} | pred: {found_time:5.2f} yrs | true: {t_true:5.2f} yrs | event: {int(e_true)} | error: {error:.2f} → {status}")

    else:
        if e_true == 0:
            status = "✅ SURVIVED (censored)"
            hit_count += 1
        else:
            status = "❌ MISMATCH (should convert)"
        print(f"  - {pid:>10s} | pred:  survive    | true: {t_true:5.2f} yrs | event: {int(e_true)} → {status}")

    total_count += 1

acc = hit_count / max(total_count, 1)
print(f"\n🎯 Overall Prediction Accuracy (±{tol} yr or survived OK): {hit_count} / {total_count} = {acc:.2%}")


[Baseline] unique event times = 47

📋 Predicted AD conversion time (Cox model, ±0.2 yr tolerance)

  - 002_S_0729 | pred:  1.03 yrs | true:  1.09 yrs | event: 1 | error: 0.06 → ✅ HIT
  - 002_S_1155 | pred:  survive    | true:  3.01 yrs | event: 0 → ✅ SURVIVED (censored)
  - 005_S_0546 | pred:  survive    | true:  3.05 yrs | event: 0 → ✅ SURVIVED (censored)
  - 006_S_1130 | pred:  1.51 yrs | true:  1.49 yrs | event: 1 | error: 0.02 → ✅ HIT
  - 007_S_0041 | pred:  1.48 yrs | true:  1.49 yrs | event: 1 | error: 0.01 → ✅ HIT
  - 007_S_0293 | pred:  1.98 yrs | true:  1.99 yrs | event: 1 | error: 0.01 → ✅ HIT
  - 011_S_0241 | pred:  1.02 yrs | true:  0.50 yrs | event: 1 | error: 0.52 → ❌ MISS
  - 011_S_0362 | pred:  3.01 yrs | true:  2.97 yrs | event: 1 | error: 0.04 → ✅ HIT
  - 016_S_1121 | pred:  1.39 yrs | true:  1.49 yrs | event: 1 | error: 0.10 → ✅ HIT
  - 021_S_0273 | pred:  3.18 yrs | true:  3.01 yrs | event: 0 | error: 0.17 → ✅ OK (censored)
  - 021_S_0424 | pred:  3.29 yrs | true:  
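The per-patient scan above loops in Python over roughly 350 grid points per patient. Below is a minimal vectorized sketch of the same grid scan, using the markdown's \(S_i(u) \le 0.5\) rule. It assumes `uniq_times`/`H0_vals` are the train-split Breslow arrays and the log-risks come as a 1-D array. It also packages the HIT / OK / MISS rules into one function. `median_times_vectorized`, `score_prediction` and the `toy_*` arrays are hypothetical names, and the toy baseline numbers are made up for illustration.

```python
import numpy as np

def median_times_vectorized(log_risks, uniq_times, H0_vals,
                            max_years=3.5, step=0.01, threshold=0.5):
    """Earliest grid time u with S_i(u) <= threshold for each patient (nan if none)."""
    grid = np.arange(step, max_years + step / 2, step)                       # (G,)
    H0_grid = np.interp(grid, uniq_times, H0_vals, left=0.0, right=H0_vals[-1])
    S = np.exp(-np.outer(np.exp(log_risks), H0_grid))                        # (N, G)
    crossed = S <= threshold
    first = crossed.argmax(axis=1)                                           # 0 when never crossed
    return np.where(crossed.any(axis=1), np.round(grid[first], 2), np.nan)

def score_prediction(pred, t_true, event, tol=0.2):
    """True if the prediction counts toward the accuracy numerator (same rules as the cell)."""
    if np.isnan(pred):                 # no crossing within the horizon ("survive")
        return event == 0              # SURVIVED if censored, MISMATCH if converted
    if event == 1:
        return abs(pred - t_true) <= tol   # HIT vs MISS
    return pred >= t_true                  # OK vs FAIL (censored)

# Hypothetical toy baseline and three patients
toy_times = np.array([1.0, 2.0, 3.0])
toy_H0 = np.array([0.2, 0.45, 0.95])
preds = median_times_vectorized(np.array([0.0, np.log(2.0), -5.0]), toy_times, toy_H0)
print(preds)
print([bool(score_prediction(p, t, e)) for p, t, e in zip(preds, [2.4, 1.0, 3.0], [1, 1, 0])])
```

In the notebook, the call would presumably be `median_times_vectorized(r_te, uniq_times, H0_vals)` followed by `score_prediction` over `zip(..., t_te, e_te)`. This replaces the nested loop with a single `(N, G)` array operation.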

## 10) Subsequence / Landmark Inference — Residual-Time Predictions

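**Goal**: Predict the **remaining (residual) time** to AD from a chosen start visit, the landmark \(s\).

**Principle**
- Re-zero the model's input time axis to the subset's **first visit**. That visit's original time \(s\) (Δt from the MCI baseline) is the landmark.
- Use the **conditional survival** from landmark \(s\):
  \[
  S(u \mid s)=\exp\!\big(-[H_0(s{+}u)-H_0(s)]\, e^{\eta}\big).
  \]
- Predict the median **residual time** \(u^{\ast}\) as the earliest \(u\) with \(S(u\mid s)\le 0.5\).
- Evaluate against the observed residual time (clamped at 0):
  - Event: HIT if \(|u^{\ast} - (t_{\text{event}}-s)| \le\) tolerance (0.2 yr), otherwise MISS.
  - Censored: OK if \(u^{\ast} \ge t_{\text{censor}}-s\), otherwise FAIL.
  - No crossing within the 3.5-yr horizon: SURVIVED if censored, MISMATCH if converted.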
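The conditional-survival formula is the ratio \(S(s+u)/S(s)\): being AD-free \(u\) more years given AD-free at \(s\). The sketch below checks that identity on a hypothetical toy baseline and finds the residual median by searching event times after \(s\). It uses the right-continuous Breslow step function throughout. The notebook cell instead evaluates \(H_0(s)\) by linear interpolation, so values can differ slightly when \(s\) falls between event times. All names (`H0_step`, `surv`, `cond_surv`, `residual_median`, `toy_*`) are illustrative.

```python
import numpy as np

# Hypothetical train-split Breslow output (toy numbers, not from the notebook)
toy_times = np.array([1.0, 2.0, 3.0])
toy_H0 = np.array([0.2, 0.45, 0.95])

def H0_step(t):
    """Right-continuous step function: H0 at the last event time <= t (0 before the first)."""
    k = np.searchsorted(toy_times, t, side="right")
    return 0.0 if k == 0 else float(toy_H0[k - 1])

def surv(t, eta):
    """Marginal survival S(t) = exp(-H0(t) e^eta)."""
    return float(np.exp(-H0_step(t) * np.exp(eta)))

def cond_surv(u, s, eta):
    """S(u | s) = exp(-[H0(s+u) - H0(s)] e^eta)."""
    return float(np.exp(-(H0_step(s + u) - H0_step(s)) * np.exp(eta)))

def residual_median(s, eta, threshold=0.5):
    """Earliest u = t_k - s over event times t_k > s with S(u | s) <= threshold (nan if none)."""
    for t_k in toy_times[toy_times > s]:
        if cond_surv(t_k - s, s, eta) <= threshold:
            return float(t_k - s)
    return float("nan")

s_demo, eta_demo = 1.5, 0.3
print(cond_surv(1.0, s_demo, eta_demo), surv(s_demo + 1.0, eta_demo) / surv(s_demo, eta_demo))
print(residual_median(s_demo, eta_demo))
```

The two values on the first line agree, and the residual median from \(s = 1.5\) is \(1.5\) (the event time 3.0 minus the landmark).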

**Notes**
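- Landmark evaluation is **harder** than full-sequence evaluation. Each subset sees fewer visits, and non-prefix subsets such as visits (1, 3) skip visits, producing input patterns unlike those seen in training (distribution shift).
- For fairness, prefer **prefix-only** subsets (1), (1,2), (1,2,3), …, which match training usage. The cell below enumerates *all* non-empty subsets (\(2^T - 1\) per patient with \(T\) visits).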
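Following the note above, here is a minimal sketch of a prefix-only generator that could be swapped into the loop in place of `ordered_subsets(T)`. `prefix_subsets` and `all_subsets` are hypothetical names; `all_subsets` mirrors what `ordered_subsets` yields. With prefixes, every subset starts at visit 0, so all of them share the same landmark \(s\) (the first visit's time).

```python
import itertools

def prefix_subsets(T):
    """Yield (0,), (0, 1), ..., (0, ..., T-1): visit prefixes that match training usage."""
    for r in range(1, T + 1):
        yield tuple(range(r))

def all_subsets(T):
    """Every non-empty, chronologically ordered visit subset (same order as ordered_subsets)."""
    for r in range(1, T + 1):
        yield from itertools.combinations(range(T), r)

print(list(prefix_subsets(3)))          # [(0,), (0, 1), (0, 1, 2)]
print(sum(1 for _ in all_subsets(6)))   # 63 = 2**6 - 1
```

A 6-visit patient therefore contributes 6 prefix evaluations instead of 63, so the accuracy is no longer dominated by the many non-prefix combinations.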


In [22]:
import numpy as np, torch, itertools, string

model.eval()

# --- Breslow H0 interpolation ---
def H0_func_factory(uniq_times, H0_vals):
    if len(uniq_times) == 0:
        return lambda x: 0.0
    def H0_at(x):
        return float(np.interp(x, uniq_times, H0_vals, left=0.0, right=H0_vals[-1]))
    return H0_at

H0_at = H0_func_factory(uniq_times, H0_vals)

# --- Conditional (landmark s) median residual-time prediction ---
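def median_residual_time_from_logrisk(log_risk, s, uniq_times, H0_vals, threshold=0.5, fallback=(3.5, 0.01)):
    """
    s : landmark time (first-visit time of the subsequence; Δt from the MCI baseline)
    u*: smallest u such that S(u|s) = exp(-[H0(s+u) - H0(s)] * exp(eta)) <= threshold
    Returns np.nan if S(u|s) never reaches the threshold.
    """
    if len(uniq_times) == 0:
        return np.nan
    gamma = float(np.exp(float(log_risk)))   # hazard ratio exp(eta)

    # H0 interpolated from the arrays passed in (not from a global helper)
    def H0_interp(x):
        return float(np.interp(x, uniq_times, H0_vals, left=0.0, right=H0_vals[-1]))
    H_s = H0_interp(s)

    # 1) Search the train event times after the landmark
    mask = (uniq_times > s)
    if np.any(mask):
        u_grid = uniq_times[mask] - s
        S_cond = np.exp(-(H0_vals[mask] - H_s) * gamma)
        idx = np.argmax(S_cond <= threshold)      # first crossing (0 if there is none)
        if S_cond[idx] <= threshold:
            return float(u_grid[idx])

    # 2) Optional fixed-step scan of the interpolated H0 up to max_years
    if fallback is not None:
        max_years, step = fallback
        for u in np.arange(step, max_years + step / 2, step):
            if np.exp(-(H0_interp(s + u) - H_s) * gamma) <= threshold:
                return float(u)
    return np.nan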

@torch.no_grad()
def predict_logrisk_for_subset(sample, idxs, device):
    idx = torch.tensor(list(idxs), dtype=torch.long)

    x_struct   = sample["x_struct"].index_select(0, idx)   # [t,F]
    x_time_abs = sample["x_time"].index_select(0, idx)     # [t]  (MCI 기준 Δt)
    x_img      = sample["x_img"].index_select(0, idx) if (sample["x_img"] is not None) else None

    # ★ landmark
    s = float(x_time_abs[0].item())
    # Model input is re-zeroed (first visit → 0)
    x_time_rel = x_time_abs - x_time_abs[0]

    demos = x_struct.unsqueeze(0).to(device)
    times = x_time_rel.unsqueeze(0).unsqueeze(-1).to(device)
    imgs  = x_img.unsqueeze(0).to(device) if x_img is not None else None
    lengths = torch.tensor([x_struct.shape[0]], dtype=torch.long)

    log_risk = model(imgs=imgs, demos=demos, times=times, lengths=lengths)
    return float(log_risk.squeeze(0).cpu().item()), s   # ★ also return the landmark s


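def ordered_subsets(T):
    """All non-empty visit-index subsets, chronologically ordered, grouped by size (2**T - 1 in total)."""
    for r in range(1, T+1):
        for comb in itertools.combinations(range(T), r):
            yield comb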

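# --- pid suffix creator: a, ..., z, aa, ..., zz, aaa, ... ---
# Unbounded, so patients with T >= 10 visits (2**T - 1 > 702 subsets) don't exhaust it
def letter_seq():
    for n in itertools.count(1):
        for tup in itertools.product(string.ascii_lowercase, repeat=n):
            yield "".join(tup)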

# ===== Run: landmark inference for all subsequences across all test patients =====

tol = 0.2
fallback = (3.5, 0.01)
hit = 0; tot = 0

print("\n📋 Predicted AD conversion time from LANDMARK (residual time)\n")

for i in range(len(test_ds)):
    sample  = test_ds[i]
    pid     = sample["pid"]
    T       = sample["x_struct"].shape[0]
    t_event = float(sample["t_event"].item() if torch.is_tensor(sample["t_event"]) else sample["t_event"])
    event   = int(sample["event"].item() if torch.is_tensor(sample["event"]) else sample["event"])

    suf = letter_seq()
    for idxs in ordered_subsets(T):
        log_risk, s = predict_logrisk_for_subset(sample, idxs, device)  # ★ note: the landmark s is returned too
        u_pred = median_residual_time_from_logrisk(log_risk, s, uniq_times, H0_vals,
                                                   threshold=0.5, fallback=fallback)

        u_true = max(t_event - s, 0.0)
        pid_ext = f"{pid}_{next(suf)}"

        if not (u_pred is None or np.isnan(u_pred)):
            err = abs(u_pred - u_true)
            if event == 1:
                ok = (err <= tol)
                status = "✅ HIT" if ok else "❌ MISS"
                if ok: hit += 1
            else:
                ok = (u_pred >= u_true)
                status = "✅ OK (censored)" if ok else "❌ FAIL (censored)"
                if ok: hit += 1

            print(f"  - {pid_ext:>10s} | pred:  {u_pred:6.2f} yrs | true:  {u_true:6.2f} yrs | event: {event} | error: {err:.2f} → {status}")
        else:
            if event == 0:
                status = "✅ SURVIVED (censored)"
                hit += 1
            else:
                status = "❌ MISMATCH (should convert)"
            print(f"  - {pid_ext:>10s} | pred:  survive    | true:  {u_true:6.2f} yrs | event: {event} → {status}")
        tot += 1

print(f"\n🎯 Overall Landmark Accuracy (±{tol} yr or survived OK): {hit} / {tot} = {hit/max(tot,1):.2%}")




📋 Predicted AD conversion time from LANDMARK (residual time)

  - 002_S_0729_a | pred:    1.02 yrs | true:    1.09 yrs | event: 1 | error: 0.08 → ✅ HIT
  - 002_S_0729_b | pred:    0.50 yrs | true:    0.53 yrs | event: 1 | error: 0.04 → ✅ HIT
  - 002_S_0729_c | pred:    1.03 yrs | true:    1.09 yrs | event: 1 | error: 0.06 → ✅ HIT
  - 002_S_1155_a | pred:    1.02 yrs | true:    3.01 yrs | event: 0 | error: 1.99 → ❌ FAIL (censored)
  - 002_S_1155_b | pred:    0.56 yrs | true:    2.55 yrs | event: 0 | error: 1.99 → ❌ FAIL (censored)
  - 002_S_1155_c | pred:    0.45 yrs | true:    1.98 yrs | event: 0 | error: 1.52 → ❌ FAIL (censored)
  - 002_S_1155_d | pred:    0.28 yrs | true:    1.38 yrs | event: 0 | error: 1.10 → ❌ FAIL (censored)
  - 002_S_1155_e | pred:    0.32 yrs | true:    0.82 yrs | event: 0 | error: 0.50 → ❌ FAIL (censored)
  - 002_S_1155_f | pred:    0.00 yrs | true:    0.00 yrs | event: 0 | error: 0.00 → ✅ OK (censored)
  - 002_S_1155_g | pred:    1.03 yrs | true:    3.01 yrs 