In [1]:
# run_m3h_cancers.py
import os, json
import numpy as np
import pandas as pd

import torch
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils import prune
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import roc_auc_score, average_precision_score

from M3H import M3H, train_with_early_stopping

# ----------------------------
# Config (edit these)
# ----------------------------
# Input table loaded with pd.read_csv further down; point this at your own file.
CANCER_CSV = "data/blood_protein_cancers_clean.csv"
# Name of the identifier column; adjust to match the table.
ID_COL = "eid"
# One binary target column per cancer type; their order fixes the column order
# of the label matrix and of the saved predictions.
CANCER_LABELS = [
    "breast_cancer",
    "prostate_cancer",
    "lung_cancer",
    "colorectal_cancer",
    "bladder_cancer",
    "pancreatic_cancer",
    "liver_cancer",
]

# Every artefact (weights, predictions, metrics) is written below this folder.
SAVE_DIR = "output/m3h_cancers"
os.makedirs(SAVE_DIR, exist_ok=True)

# Seed NumPy and PyTorch global generators so reruns are repeatable.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# Optimisation settings.
BATCH_SIZE = 128
MAX_EPOCHS = 200
PATIENCE = 20
LR = 3e-4
WD = 1e-4

# Fractions of the full cohort held out for validation and for testing.
VAL_FRAC = 0.15
TEST_FRAC = 0.15

# ----------------------------
# Data utils
# ----------------------------
def split(df, ycols):
    """Partition ``df`` into train / validation / test frames.

    Rows with a missing value in any of ``ycols`` are discarded first. Both
    splits are stratified on a 0/1 flag that marks rows whose label columns
    sum to more than zero, and both use ``SEED`` as ``random_state``.

    Args:
        df: DataFrame containing the columns named in ``ycols``.
        ycols: list of label column names.

    Returns:
        Tuple ``(train, val, test)`` of DataFrames. ``test`` holds
        ``TEST_FRAC`` of the labelled rows; ``val`` is taken from the
        remainder at a rate chosen so it is ``VAL_FRAC`` of the labelled rows.
    """
    def _any_positive(frame):
        # 1 where the row's label columns sum to more than zero, else 0.
        return frame[ycols].sum(axis=1).gt(0).astype(int)

    labelled = df.dropna(subset=ycols)

    remainder, test = train_test_split(
        labelled,
        test_size=TEST_FRAC,
        random_state=SEED,
        stratify=_any_positive(labelled),
    )

    # VAL_FRAC is expressed relative to the full cohort, so rescale it to the
    # portion that is left once the test rows have been removed.
    val_share = VAL_FRAC / (1 - TEST_FRAC)
    train, val = train_test_split(
        remainder,
        test_size=val_share,
        random_state=SEED,
        stratify=_any_positive(remainder),
    )
    return train, val, test

def matrices(df, ycols, id_col, feature_prefixes = ["blood_", "olink_"]):
    """Extract float32 feature and label arrays from ``df``.

    Args:
        df: DataFrame holding feature and label columns.
        ycols: list of label column names; defines the column order of ``y``.
        id_col: accepted for call-site symmetry but not used here.
        feature_prefixes: a column is treated as a feature when its name
            starts with any of these prefixes.

    Returns:
        ``(X, y)``: ``X`` contains the selected feature columns in their
        DataFrame order and ``y`` the ``ycols`` columns, both cast to float32.
    """
    # str.startswith accepts a tuple and matches if any entry is a prefix.
    prefixes = tuple(feature_prefixes)
    selected = [name for name in df.columns if name.startswith(prefixes)]

    features = df[selected].astype(np.float32).to_numpy()
    targets = df[ycols].astype(np.float32).to_numpy()
    return features, targets

class TabularSet(Dataset):
    """Map-style dataset pairing each feature row with its label row."""

    def __init__(self, X, y):
        """Wrap two NumPy arrays as tensors.

        ``torch.from_numpy`` is used, so the tensors share memory with the
        arrays that were passed in.
        """
        self.X = torch.from_numpy(X)
        self.y = torch.from_numpy(y)

    def __len__(self):
        """Return the number of rows in the feature tensor."""
        return len(self.X)

    def __getitem__(self, i):
        """Return the ``(features, labels)`` pair stored at position ``i``."""
        features = self.X[i]
        labels = self.y[i]
        return features, labels

# ----------------------------
# Main
# ----------------------------
# def main():
#     df = pd.read_csv(CANCER_CSV)
#     tr_df, va_df, te_df = split(df, CANCER_LABELS)

#     Xtr, Ytr = matrices(tr_df, CANCER_LABELS, ID_COL)
#     Xva, Yva = matrices(va_df, CANCER_LABELS, ID_COL)
#     Xte, Yte = matrices(te_df, CANCER_LABELS, ID_COL)

#     scaler = StandardScaler().fit(Xtr)
#     Xtr = scaler.transform(Xtr).astype(np.float32)
#     Xva = scaler.transform(Xva).astype(np.float32)
#     Xte = scaler.transform(Xte).astype(np.float32)

#     dl_tr = DataLoader(TabularSet(Xtr, Ytr), batch_size=BATCH_SIZE, shuffle=True)
#     dl_va = DataLoader(TabularSet(Xva, Yva), batch_size=BATCH_SIZE, shuffle=False)
#     dl_te = DataLoader(TabularSet(Xte, Yte), batch_size=BATCH_SIZE, shuffle=False)

#     # --- Model: y1_bins = 7 tasks (one per cancer)
#     model = M3H(
#         input_dim=Xtr.shape[1],
#         y1_bins=7,              # 7 parallel heads/tasks
#         alpha=1.0,              # task attention mixing strength
#         hidden_dim=128,         # project feature subsets -> hidden_dim
#         hidden_layers=2,        # per-head depth
#         feature_indices_per_head=[list(range(Xtr.shape[1])) for _ in range(7)],
#         prune_amount=0.0        # turn off pruning unless you want it
#     )

#     opt = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=WD)

#     # --- Use your built-in trainer (early stopping + best-state restore)
#     train_with_early_stopping(
#         model=model,
#         train_loader=dl_tr,
#         val_loader=dl_va,
#         optimizer=opt,
#         num_epochs=MAX_EPOCHS,
#         patience=PATIENCE,
#         l1=1e-4,               # L1 on head params (supported by your trainer)
#     )

#     # --- Evaluate (AUROC / AP per task)
#     model.eval()
#     probs, targs = [], []
#     with torch.no_grad():
#         for xb, yb in dl_te:
#             p = model(xb)              # [B,7] probabilities from your forward
#             probs.append(p.cpu().numpy())
#             targs.append(yb.cpu().numpy())
#     P = np.vstack(probs); Y = np.vstack(targs)

#     metrics = {}
#     aurocs, aps = [], []
#     for j, name in enumerate(CANCER_LABELS):
#         if len(np.unique(Y[:, j])) < 2:
#             continue
#         au = roc_auc_score(Y[:, j], P[:, j])
#         ap = average_precision_score(Y[:, j], P[:, j])
#         metrics[f"{name}_AUROC"] = float(au)
#         metrics[f"{name}_AP"]    = float(ap)
#         aurocs.append(au); aps.append(ap)
#     if aurocs:
#         metrics["macro_AUROC"] = float(np.mean(aurocs))
#         metrics["macro_AP"]    = float(np.mean(aps))

#     # --- Save
#     torch.save(model.state_dict(), os.path.join(SAVE_DIR, "m3h_7cancers.pt"))
#     np.save(os.path.join(SAVE_DIR, "test_probs.npy"), P)
#     np.save(os.path.join(SAVE_DIR, "test_targets.npy"), Y)
#     with open(os.path.join(SAVE_DIR, "metrics.json"), "w") as f: json.dump(metrics, f, indent=2)
#     with open(os.path.join(SAVE_DIR, "labels.json"), "w") as f: json.dump(CANCER_LABELS, f, indent=2)

#     print(json.dumps(metrics, indent=2))

# if __name__ == "__main__":
#     main()


In [2]:
# Load the cohort table and carve out train / validation / test partitions.
df = pd.read_csv(CANCER_CSV)
tr_df, va_df, te_df = split(df, CANCER_LABELS)

Xtr, Ytr = matrices(tr_df, CANCER_LABELS, ID_COL)
Xva, Yva = matrices(va_df, CANCER_LABELS, ID_COL)
Xte, Yte = matrices(te_df, CANCER_LABELS, ID_COL)

if Xtr.shape[1] == 0:
    raise ValueError(
        "No feature columns matched the expected prefixes; check the column names in "
        + CANCER_CSV
    )

# Report feature missingness up front: NaNs in the inputs propagate through the
# network and are the most likely reason for a NaN training loss.
print(f"Missing feature values in train split: {int(np.isnan(Xtr).sum())} of {Xtr.size}")

# StandardScaler skips NaNs when fitting and leaves them in place when
# transforming, so missing values must be filled explicitly afterwards.
scaler = StandardScaler().fit(Xtr)


def _standardize(X):
    """Scale ``X`` with train statistics, then fill NaNs with 0.

    In standardized space 0 is the training mean, so this is mean imputation
    using statistics from the training split only (no leakage from val/test).
    Columns that are entirely missing in the training split also end up as 0.
    """
    return np.nan_to_num(scaler.transform(X), nan=0.0).astype(np.float32)


Xtr = _standardize(Xtr)
Xva = _standardize(Xva)
Xte = _standardize(Xte)

for split_name, X_split in (("train", Xtr), ("val", Xva), ("test", Xte)):
    if not np.isfinite(X_split).all():
        raise ValueError(f"Non-finite values remain in the {split_name} features after scaling")

dl_tr = DataLoader(TabularSet(Xtr, Ytr), batch_size=BATCH_SIZE, shuffle=True)
dl_va = DataLoader(TabularSet(Xva, Yva), batch_size=BATCH_SIZE, shuffle=False)
dl_te = DataLoader(TabularSet(Xte, Yte), batch_size=BATCH_SIZE, shuffle=False)

# --- Model: one task per cancer label, so the task count follows CANCER_LABELS
# instead of being hard-coded.
n_tasks = len(CANCER_LABELS)
n_features = Xtr.shape[1]
model = M3H(
    input_dim=n_features,
    y1_bins=n_tasks,        # parallel heads/tasks (per the original note)
    alpha=1.0,              # task attention mixing strength (per the original note)
    hidden_dim=128,         # project feature subsets -> hidden_dim
    hidden_layers=2,        # per-head depth
    # Every head is given the full set of feature indices.
    feature_indices_per_head=[list(range(n_features)) for _ in range(n_tasks)],
    prune_amount=0.0        # turn off pruning unless you want it
)

opt = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=WD)

In [3]:
# --- Fit the model with the trainer shipped in the M3H module.
# Per the original author's note the trainer performs early stopping and
# restores the best state; that behaviour lives in M3H and is not visible here.
train_with_early_stopping(
    model=model,
    optimizer=opt,
    train_loader=dl_tr,
    val_loader=dl_va,
    num_epochs=MAX_EPOCHS,   # upper bound on the number of epochs
    patience=PATIENCE,       # early-stopping patience passed to the trainer
    l1=1e-4,                 # L1 strength (described as applying to head params)
)

Epoch 1 | Train Loss: nan | Val Loss: nan
Epoch 2 | Train Loss: nan | Val Loss: nan
Epoch 3 | Train Loss: nan | Val Loss: nan


KeyboardInterrupt: 

In [None]:
# Training is done in the previous cell. It is intentionally NOT repeated here:
# calling the trainer a second time would keep optimising the same model with
# the same optimizer, so a Restart & Run All would train twice.

# --- Evaluate (AUROC / AP per task)
model.eval()
probs, targs = [], []
with torch.no_grad():
    for xb, yb in dl_te:
        # Expected to be one score per label, [B, len(CANCER_LABELS)] -- this
        # depends on M3H.forward, which is defined outside this notebook.
        p = model(xb)
        probs.append(p.cpu().numpy())
        targs.append(yb.cpu().numpy())
P = np.vstack(probs)
Y = np.vstack(targs)

# Fail with an actionable message instead of a generic scikit-learn error when
# the model emitted NaN/inf (e.g. after a run whose loss was NaN).
if not np.isfinite(P).all():
    raise ValueError(
        "Model produced non-finite test predictions; check the input features for "
        "missing values and the training loss before evaluating."
    )

metrics = {}
aurocs, aps = [], []
skipped = []
for j, name in enumerate(CANCER_LABELS):
    y_true, y_score = Y[:, j], P[:, j]
    # AUROC needs both classes present; skip labels with a single class.
    if len(np.unique(y_true)) < 2:
        skipped.append(name)
        continue
    au = roc_auc_score(y_true, y_score)
    ap = average_precision_score(y_true, y_score)
    metrics[f"{name}_AUROC"] = float(au)
    metrics[f"{name}_AP"]    = float(ap)
    aurocs.append(au)
    aps.append(ap)
if aurocs:
    # Macro averages cover only the labels that were actually scored.
    metrics["macro_AUROC"] = float(np.mean(aurocs))
    metrics["macro_AP"]    = float(np.mean(aps))
if skipped:
    print(f"Skipped labels with a single class in the test split: {skipped}")

# --- Save
torch.save(model.state_dict(), os.path.join(SAVE_DIR, "m3h_7cancers.pt"))
np.save(os.path.join(SAVE_DIR, "test_probs.npy"), P)
np.save(os.path.join(SAVE_DIR, "test_targets.npy"), Y)
with open(os.path.join(SAVE_DIR, "metrics.json"), "w") as f:
    json.dump(metrics, f, indent=2)
with open(os.path.join(SAVE_DIR, "labels.json"), "w") as f:
    json.dump(CANCER_LABELS, f, indent=2)

print(json.dumps(metrics, indent=2))