In [1]:
#I 2025-11-20 07:35:15,051] Trial 93 finished with value: 0.7909400798921143 and parameters: {'backbone_lr': 0.00011984014753946137, 'head_lr': 5.693955800220594e-05, 'weight_decay': 1.7206999596641163e-08, 'optimizer_name': 'adamw', 'loss_name': 'softmargin', 'activation_name': 'gelu', 'head_hidden_dim': 512, 'head_dropout': 0.3929098112015572, 'tune_epochs': 16}. Best is trial 72 with value: 0.7978573596844559.

In [None]:
# domain_adapt_from_checkpoint.py
# EfficientNetV2-S + MLP head, ASL loss, MixUp/CutMix, TTA, 4-combo sampler.

import os, json, random, warnings, re
from pathlib import Path
from datetime import datetime

import numpy as np
import pandas as pd
import cv2
import torch
from torch import nn
from torch.utils.data import DataLoader, WeightedRandomSampler
from torchvision.models import efficientnet_v2_s, EfficientNet_V2_S_Weights

from sklearn.model_selection import GroupShuffleSplit
from sklearn.metrics import (
    roc_auc_score,
    roc_curve,
    average_precision_score,
    multilabel_confusion_matrix,  # <--- ADDED
)

import albumentations as A
from albumentations.pytorch import ToTensorV2

import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages

# ========================== PATHS ============================================
# CSV files listing image paths (in the *_PATH_COL column) plus label columns.
TRAIN_CSV       = "/scratch/ssriva94/Capstone/new/low-quality-image.fixed_paths.csv"
TRAIN_PATH_COL  = "New_path"

TEST_A_CSV      = "/scratch/ssriva94/Capstone/new/high-quality-image.with_pid.csv"
TEST_A_PATH_COL = "New_path"

# NOTE(review): TEST_B_CSV is the same file as TRAIN_CSV, so any evaluation on
# it would include rows the model was trained on -- confirm this is intended.
TEST_B_CSV      = "/scratch/ssriva94/Capstone/new/low-quality-image.fixed_paths.csv"
TEST_B_PATH_COL = "New_path"

CHECKPOINT_PATH = "/scratch/ssriva94/Capstone/new/m-epoch_FL_run3.pth.tar"  # dense checkpoint; strict=False

# ============================ LABELS / CONFIG ================================
# Output classes, in the order used for the classifier outputs and CSV columns.
CHEXPERT_LABELS = [
    "No Finding","Enlarged Cardiomediastinum","Cardiomegaly","Lung Opacity",
    "Lung Lesion","Edema","Consolidation","Pneumonia","Atelectasis",
    "Pneumothorax","Pleural Effusion","Pleural Other","Fracture","Support Devices",
]

# Each run writes into a timestamped sub-directory of this root.
OUTPUT_ROOT = "./runs_adapt"

# Higher resolution
TRAIN_CROP_SIZE = 320   # side of the random-resized training crop
VAL_RESIZE      = 384   # eval images are resized to this square first ...
VAL_CROP_SIZE   = 320   # ... then centre-cropped to this side

BATCH_SIZE  = 12
VAL_SPLIT   = 0.20      # fraction held out by the patient-grouped split
SEED        = 1337
NUM_WORKERS = 4         # upper bound; run() caps it at os.cpu_count()
READ_GRAYSCALE = True   # read_img loads one channel and replicates it to three
THRESH      = 0.5       # probability cut-off for accuracy and confusion matrices

INIT_FROM_IMAGENET           = True
OVERLAY_CHECKPOINT_ON_TOP    = True      # will mostly be ignored (different arch), but kept
REINIT_CLASSIFIER_AFTER_LOAD = True

# ---- Hyperparams (from your good trial) -------------------------------------
# NOTE(review): BACKBONE_LR is smaller than MIN_LR below; check how the
# WarmupCosine schedule treats a base LR that is already under its floor.
BACKBONE_LR   = 1.8425919726599874e-07
HEAD_LR       = 0.0011007522070230675
WEIGHT_DECAY  = 3.69534470108604e-07

# NOTE(review): in the visible code LOSS_NAME is only logged and saved to
# config.json; run() always instantiates AsymmetricLossMultiLabel.
LOSS_NAME       = "asl"       # <--- ASL
ACTIVATION_NAME = "gelu"
HEAD_HIDDEN_DIM = 128
HEAD_DROPOUT    = 0.5585538051607142

EPOCHS = 35
WARMUP_EPOCHS      = 2
MIN_LR             = 1e-6
EARLYSTOP_PATIENCE = 12     # epochs without val macro-AUROC improvement before stopping

# ---- MixUp / CutMix ---------------------------------------------------------
USE_MIXUP      = True
MIXUP_ALPHA    = 0.2   # Beta(alpha, alpha) parameter for the MixUp mixing weight
CUTMIX_ALPHA   = 0.2   # Beta(alpha, alpha) parameter for the CutMix area ratio
MIXUP_PROB     = 0.7   # prob to apply mixup/cutmix each batch
CUTMIX_PROB    = 0.5   # inside that, prob to choose cutmix vs mixup

# ---- Test-time augmentation (TTA) -------------------------------------------
USE_TTA = True   # simple horizontal flip TTA

# ---- Attribute metadata -----------------------------------------------------
# Table holding per-image acquisition attributes (lighting / height), used to
# build the weighted sampler over the four lighting x height combinations.
SUBSET_SOURCE_FILE       = "/scratch/ssriva94/Capstone/new/low-quality-image.csv"
SUBSET_LIGHTING_COL      = "Lighting Source"
SUBSET_HEIGHT_COL        = "Height Variation"
# Column names tried, in order, when locating the image-path column.
SUBSET_PATH_COL_CANDIDATES = ["New_path","Path","path","filepath","file_path"]

# all 4 combos for eval
SUBSET_DEFINITIONS = [
    {"name": "subset_white_furtherout",
     SUBSET_LIGHTING_COL: "White Light", SUBSET_HEIGHT_COL: "Further Out"},
    {"name": "subset_white_closerin",
     SUBSET_LIGHTING_COL: "White Light", SUBSET_HEIGHT_COL: "Closer In"},
    {"name": "subset_yellow_furtherout",
     SUBSET_LIGHTING_COL: "Yellow Light", SUBSET_HEIGHT_COL: "Further Out"},
    {"name": "subset_yellow_closerin",
     SUBSET_LIGHTING_COL: "Yellow Light", SUBSET_HEIGHT_COL: "Closer In"},
]

# NOTE(review): this silences every warning from every library for the whole
# process, including deprecation and numerical warnings.
warnings.filterwarnings("ignore")
os.makedirs(OUTPUT_ROOT, exist_ok=True)
# Seed the Python, NumPy and PyTorch generators.
random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED)

# The script is GPU-only: fail fast instead of silently falling back to CPU.
if not torch.cuda.is_available():
    raise RuntimeError("CUDA requested but not available")
device = torch.device("cuda")
AMP = True   # enables torch.cuda.amp autocast / GradScaler in training and eval

# =========================== UTILS ===========================================
# Lower-cased string spellings recognised for each label state.
MAP_TRUE  = {"1","1.0","true","t","yes","y","positive","pos"}
MAP_FALSE = {"0","0.0","false","f","no","n","negative","neg"}
MAP_UNC   = {"-1","-1.0","uncertain","u","na","nan",""}

def normalize_label(val):
    """Map a raw label cell to a binary float target.

    Missing values, uncertain markers and anything unparseable map to 0.0.
    Numeric values (and numeric strings) map to 1.0 when >= 0.5, else 0.0.

    Args:
        val: raw cell value (number, string, None/NaN).

    Returns:
        1.0 for a positive label, 0.0 otherwise.
    """
    if pd.isna(val):
        return 0.0
    if isinstance(val, (int, float, np.number)):
        return 1.0 if float(val) >= 0.5 else 0.0
    # Strip whitespace and stray quote characters before the set lookups.
    s = str(val).strip().strip('"').strip("'").lower()
    if s in MAP_TRUE:
        return 1.0
    if s in MAP_FALSE or s in MAP_UNC:
        return 0.0
    try:
        return 1.0 if float(s) >= 0.5 else 0.0
    except ValueError:
        # Not a number: treat as negative. Catch only the parse failure so
        # KeyboardInterrupt / SystemExit are no longer swallowed here.
        return 0.0

# Matches a whole path component of the form "patient<digits>" (any case).
PATIENT_SEG_RE = re.compile(r"^patient(\d+)$", re.IGNORECASE)
def pid_from_path_strict_or_fallback(p: str) -> str:
    """Derive a patient id from an image path.

    The first path component that is exactly ``patient<digits>`` wins and is
    returned lower-cased. When no component matches, the id falls back to
    ``auto_<file stem, lower-cased>``.
    """
    path = Path(str(p))
    hits = (PATIENT_SEG_RE.match(component) for component in path.parts)
    found = next((hit for hit in hits if hit), None)
    if found is None:
        return f"auto_{path.stem.lower()}"
    return f"patient{found.group(1)}".lower()

def ensure_patient_id(df: pd.DataFrame, path_col: str) -> pd.DataFrame:
    """Return a copy of ``df`` with a lower-cased string ``patient_id`` column.

    Rows whose ``patient_id`` is absent are filled from the image path via
    ``pid_from_path_strict_or_fallback``. "Absent" covers a missing column,
    NaN/None cells, blank strings and the textual placeholders ``nan`` /
    ``none`` / ``<na>``. Previously NaN cells were stringified to "nan" and
    never back-filled, which merged all such rows into one pseudo-patient.

    Args:
        df: input table; not modified.
        path_col: name of the column holding image paths.

    Returns:
        A new DataFrame with ``patient_id`` populated for every row.
    """
    df = df.copy()
    if "patient_id" not in df.columns:
        df["patient_id"] = ""
    pid = df["patient_id"]
    as_text = pid.astype(str).str.strip().str.lower()
    mask = pid.isna() | as_text.isin({"", "nan", "none", "<na>"})
    if mask.any():
        # Cast to object first so string ids can be written into a column that
        # pandas may have inferred as numeric.
        df["patient_id"] = pid.astype(object)
        df.loc[mask, "patient_id"] = df.loc[mask, path_col].apply(pid_from_path_strict_or_fallback)
    df["patient_id"] = df["patient_id"].astype(str).str.lower()
    return df

def read_img(p):
    """Load the image at path ``p`` as a 3-channel array.

    With READ_GRAYSCALE set, the file is decoded as a single channel and that
    channel is replicated three times along the last axis. Otherwise it is
    decoded in colour and converted from OpenCV's BGR order to RGB.

    Raises:
        FileNotFoundError: when OpenCV cannot read the file.
    """
    flag = cv2.IMREAD_GRAYSCALE if READ_GRAYSCALE else cv2.IMREAD_COLOR
    img = cv2.imread(p, flag)
    if img is None:
        raise FileNotFoundError(f"Missing image: {p}")
    if READ_GRAYSCALE:
        return np.stack([img, img, img], axis=-1)
    return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

class CSVImageDataset(torch.utils.data.Dataset):
    """Image dataset backed by a DataFrame of file paths and label columns.

    Rows whose file does not exist on disk are dropped at construction time
    (the first few are printed). Labels are normalised with
    ``normalize_label`` when every label column is present and
    ``expect_labels`` is true; otherwise each item gets an all-zero target.
    """

    def __init__(self, df, path_col, label_cols, a_transform, expect_labels=True):
        frame = df.copy()
        have_all_labels = all(col in frame.columns for col in label_cols)
        self.has_labels = expect_labels and have_all_labels
        if self.has_labels:
            for col in label_cols:
                frame[col] = frame[col].apply(normalize_label).astype("float32")

        # Keep only rows whose image file is actually present.
        on_disk = frame[path_col].astype(str).apply(lambda x: os.path.exists(str(x)))
        missing = int((~on_disk).sum())
        if missing:
            print(f"[Dataset] dropping {missing} rows with missing files. Examples:")
            examples = frame.loc[~on_disk, path_col].astype(str).head(10).tolist()
            for bad_path in examples:
                print("  -", bad_path)

        self.df = frame.loc[on_disk].reset_index(drop=True)
        self.path_col = path_col
        self.label_cols = label_cols
        self.tf = a_transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        record = self.df.iloc[idx]
        image = read_img(str(record[self.path_col]))
        # The transform is called albumentations-style: keyword in, dict out.
        image = self.tf(image=image)["image"]
        if not self.has_labels:
            return image, torch.zeros(len(self.label_cols), dtype=torch.float32)
        values = record[self.label_cols].values.astype("float32")
        return image, torch.tensor(values, dtype=torch.float32)

def multilabel_acc(probs, target, thr=THRESH):
    """Mean per-sample fraction of labels predicted correctly.

    Probabilities are binarised at ``thr``, compared element-wise with
    ``target``, averaged over dim 1 and then over the batch. Returns a
    Python float.
    """
    hits = ((probs >= thr).float() == target).float()
    per_sample = hits.mean(dim=1)
    return per_sample.mean().item()

def _clean_state_dict(sd: dict):
    out = {}
    for k, v in sd.items():
        kk = k
        for pref in ("module.","densenet121.","efficientnet_v2_s.","model.","backbone."):
            if kk.startswith(pref): kk = kk[len(pref):]
        out[kk] = v
    return out

# =========================== MODEL HEAD ======================================
class MLPHead(nn.Module):
    """Classifier head: a single Linear layer or a one-hidden-layer MLP.

    With a positive ``hidden_dim`` the stack is
    Linear -> activation -> (Dropout if ``dropout`` > 0) -> Linear.
    Otherwise it is one Linear layer from ``in_features`` to ``out_features``.
    ``activation_name`` is matched case-insensitively against "gelu" and
    "leaky_relu"; any other value selects ReLU.
    """

    def __init__(self, in_features, out_features,
                 hidden_dim=None, dropout=0.0, activation_name="relu"):
        super().__init__()
        factories = {
            "gelu": nn.GELU,
            "leaky_relu": lambda: nn.LeakyReLU(inplace=True),
        }
        make_activation = factories.get(activation_name.lower(),
                                        lambda: nn.ReLU(inplace=True))
        activation = make_activation()

        if not (hidden_dim and hidden_dim > 0):
            stack = [nn.Linear(in_features, out_features)]
        else:
            stack = [nn.Linear(in_features, hidden_dim), activation]
            if dropout and dropout > 0:
                stack.append(nn.Dropout(dropout))
            stack.append(nn.Linear(hidden_dim, out_features))
        self.net = nn.Sequential(*stack)

    def forward(self, x):
        """Apply the head to a batch of feature vectors."""
        return self.net(x)

def build_model():
    """Build the EfficientNetV2-S classifier and move it to ``device``.

    Steps:
      1. Create the torchvision backbone, with ImageNet weights when
         INIT_FROM_IMAGENET is set (random init if that load fails).
      2. Replace the stock classifier with an ``MLPHead`` sized for
         CHEXPERT_LABELS.
      3. Optionally overlay tensors from CHECKPOINT_PATH. Only checkpoint
         entries whose name and shape both match the model are loaded.
         ``load_state_dict(strict=False)`` still raises on a shape mismatch,
         which previously aborted the overlay with a "continuing without it"
         message.
      4. Optionally re-create the head after the overlay.

    Returns:
        The model on ``device``.
    """
    print("\n[config] Building EfficientNetV2-S model...")
    if INIT_FROM_IMAGENET:
        try:
            model = efficientnet_v2_s(weights=EfficientNet_V2_S_Weights.IMAGENET1K_V1)
            print("[init] TorchVision EfficientNetV2-S ImageNet weights")
        except Exception as e:
            # Best effort: e.g. no network access for the weight download.
            print(f"[init] ImageNet load failed ({e}); using random init")
            model = efficientnet_v2_s(weights=None)
    else:
        model = efficientnet_v2_s(weights=None)
        print("[init] Random init")

    # Input width of the stock classifier's Linear layer.
    if isinstance(model.classifier, nn.Sequential):
        in_feats = model.classifier[-1].in_features
    else:
        in_feats = model.classifier.in_features

    def _fresh_head():
        # Single place that defines the head so both call sites stay in sync.
        return MLPHead(in_feats, len(CHEXPERT_LABELS),
                       hidden_dim=HEAD_HIDDEN_DIM,
                       dropout=HEAD_DROPOUT,
                       activation_name=ACTIVATION_NAME)

    model.classifier = _fresh_head()

    if OVERLAY_CHECKPOINT_ON_TOP and CHECKPOINT_PATH and os.path.isfile(CHECKPOINT_PATH):
        try:
            print(f"[ckpt] loading checkpoint from {CHECKPOINT_PATH}")
            # NOTE(review): torch.load unpickles the file; only use trusted checkpoints.
            ckpt = torch.load(CHECKPOINT_PATH, map_location="cpu")
            sd = ckpt.get("state_dict", ckpt.get("model_state_dict", ckpt))
            if not isinstance(sd, dict): sd = ckpt
            sd = _clean_state_dict(sd)

            # Keep only tensors the model can actually accept.
            target = model.state_dict()
            compatible = {
                k: v for k, v in sd.items()
                if k in target and torch.is_tensor(v) and v.shape == target[k].shape
            }
            unexpected = [k for k in sd if k not in compatible]
            missing, _ = model.load_state_dict(compatible, strict=False)

            print(f"[ckpt] overlay strict=False | missing={len(missing)} unexpected={len(unexpected)}")
            print(f"[ckpt] loaded {len(compatible)} of {len(sd)} checkpoint tensors")
            if missing:    print("        missing keys:", missing[:10])
            if unexpected: print("        unexpected keys:", unexpected[:10])
            if REINIT_CLASSIFIER_AFTER_LOAD:
                model.classifier = _fresh_head()
                print("[ckpt] reinitialized classifier head (MLPHead)")
        except Exception as e:
            # Best effort: an unreadable checkpoint must not stop training.
            print(f"[ckpt] failed to load checkpoint: {e} — continuing without it.")
    print("[config] Model classifier:\n", model.classifier)
    return model.to(device)

def split_param_groups(model):
    """Build two optimizer parameter groups: backbone first, head second.

    A parameter belongs to the head when its name contains "classifier".
    The backbone group uses BACKBONE_LR and the head group HEAD_LR; both use
    WEIGHT_DECAY.
    """
    named = list(model.named_parameters())
    head_params = [param for name, param in named if "classifier" in name]
    body_params = [param for name, param in named if "classifier" not in name]
    backbone_group = {"params": body_params, "lr": BACKBONE_LR, "weight_decay": WEIGHT_DECAY}
    head_group = {"params": head_params, "lr": HEAD_LR, "weight_decay": WEIGHT_DECAY}
    return [backbone_group, head_group]

class WarmupCosine(torch.optim.lr_scheduler._LRScheduler):
    """Linear warm-up followed by cosine decay, stepped once per epoch.

    For 1-based epoch ``e = last_epoch + 1`` and each group's base LR:
      * ``e <= warmup_epochs``: ``base * e / warmup_epochs``
      * afterwards: cosine interpolation from ``base`` down to a floor,
        reaching the floor at ``e == total_epochs`` and staying there.

    The floor is ``min(min_lr, base)``: a group whose base LR is already
    below ``min_lr`` is held at its base LR instead of being raised toward
    ``min_lr``. Progress is clamped to 1 so stepping past ``total_epochs``
    does not bring the LR back up.

    Args:
        opt: wrapped optimizer.
        warmup_epochs: number of linear warm-up epochs.
        total_epochs: epoch at which the floor is reached.
        min_lr: lower bound for the decayed learning rate.
        last_epoch: forwarded to the base scheduler.
    """

    def __init__(self, opt, warmup_epochs, total_epochs, min_lr=1e-6, last_epoch=-1):
        # These must be set before super().__init__, which calls get_lr().
        self.warmup_epochs = warmup_epochs
        self.total_epochs = total_epochs
        self.min_lr = min_lr
        super().__init__(opt, last_epoch)

    def get_lr(self):
        """Return the learning rate for every parameter group."""
        e = self.last_epoch + 1
        decay_span = max(1, self.total_epochs - self.warmup_epochs)
        lrs = []
        for base in self.base_lrs:
            if e <= self.warmup_epochs:
                lrs.append(base * e / max(1, self.warmup_epochs))
            else:
                t = min(1.0, (e - self.warmup_epochs) / decay_span)
                floor = min(self.min_lr, base)
                lrs.append(float(floor + 0.5 * (base - floor) * (1 + np.cos(np.pi * t))))
        return lrs

# ============================= LOSSES ========================================
class AsymmetricLossMultiLabel(nn.Module):
    """Asymmetric focal-style loss for multi-label classification.

    Positives and negatives get separate focusing exponents (``gamma_pos``,
    ``gamma_neg``). Negative probabilities are shifted by ``clip`` before the
    log, and targets are optionally smoothed toward 0.5.

    The computation is always done in float32. With half-precision logits
    (e.g. under autocast) ``eps=1e-8`` is not representable, so the clamp
    that protects ``log`` did nothing and a saturated sigmoid gave an
    infinite loss.

    Args:
        gamma_pos: focusing exponent for the positive term.
        gamma_neg: focusing exponent for the negative term.
        clip: margin added to the negative probability (capped at 1.0).
        eps: clamp bound keeping probabilities away from 0 and 1.
        label_smoothing: targets become ``t * (1 - ls) + 0.5 * ls``.
        reduction: "mean", "sum", or anything else for the per-element loss.
    """

    def __init__(self, gamma_pos=0.0, gamma_neg=2.0, clip=0.05,
                 eps=1e-8, label_smoothing=0.02, reduction="mean"):
        super().__init__()
        self.gamma_pos = gamma_pos; self.gamma_neg = gamma_neg
        self.clip = clip; self.eps = eps
        self.ls = label_smoothing; self.reduction = reduction

    def forward(self, logits, targets):
        """Return the loss for raw ``logits`` and (possibly soft) ``targets``."""
        # Promote to float32 so eps clamping and log are numerically safe.
        logits = logits.float()
        targets = targets.float()
        if self.ls > 0:
            targets = targets * (1 - self.ls) + 0.5 * self.ls
        x_sigmoid = torch.sigmoid(logits)
        xs_pos = x_sigmoid.clamp(self.eps, 1 - self.eps)
        xs_neg = (1.0 - x_sigmoid).clamp(self.eps, 1 - self.eps)
        if self.clip and self.clip > 0:
            # Probability shifting: easy negatives contribute nothing.
            xs_neg = torch.clamp(xs_neg + self.clip, max=1.0)
        los_pos = -targets * torch.log(xs_pos)
        los_neg = -(1.0 - targets) * torch.log(xs_neg)
        if self.gamma_pos > 0 or self.gamma_neg > 0:
            # Focusing weights are treated as constants for the gradient.
            with torch.no_grad():
                w_pos = (1.0 - xs_pos) ** self.gamma_pos
                w_neg = (1.0 - xs_neg) ** self.gamma_neg
            los_pos = los_pos * w_pos
            los_neg = los_neg * w_neg
        loss = los_pos + los_neg
        if self.reduction == "mean": return loss.mean()
        if self.reduction == "sum":  return loss.sum()
        return loss

# ============================= MIXUP / CUTMIX ================================
def rand_bbox(size, lam):
    """Sample a random CutMix box for a batch of shape ``size``.

    ``size`` is indexed as (B, C, H, W). The box covers roughly a
    ``1 - lam`` fraction of the image around a uniformly drawn centre and is
    clipped to the image bounds.

    Returns:
        ``(x1, y1, x2, y2)`` box corners.
    """
    height, width = size[2], size[3]
    ratio = np.sqrt(1. - lam)
    half_w = int(width * ratio) // 2
    half_h = int(height * ratio) // 2
    # Draw x before y to keep the random stream consumption order.
    centre_x = np.random.randint(width)
    centre_y = np.random.randint(height)
    left = np.clip(centre_x - half_w, 0, width)
    right = np.clip(centre_x + half_w, 0, width)
    top = np.clip(centre_y - half_h, 0, height)
    bottom = np.clip(centre_y + half_h, 0, height)
    return left, top, right, bottom

def apply_mixup_cutmix(x, y):
    """Randomly apply CutMix or MixUp to a batch, or leave it unchanged.

    The batch is returned untouched when USE_MIXUP is off, with probability
    ``1 - MIXUP_PROB``, or when MixUp is selected but MIXUP_ALPHA <= 0.
    Otherwise each sample is combined with a randomly permuted partner from
    the same batch, and the targets are mixed with the same weight.

    Returns:
        ``(inputs, targets)``, either the originals or the mixed versions.
    """
    if not USE_MIXUP or random.random() > MIXUP_PROB:
        return x, y

    partner = torch.randperm(x.size(0), device=x.device)
    pick_cutmix = (random.random() < CUTMIX_PROB) and CUTMIX_ALPHA > 0

    if not pick_cutmix:
        # MixUp: convex combination of whole images and of their targets.
        if MIXUP_ALPHA <= 0:
            return x, y
        lam = np.random.beta(MIXUP_ALPHA, MIXUP_ALPHA)
        blended_x = lam * x + (1 - lam) * x[partner]
        blended_y = lam * y + (1 - lam) * y[partner]
        return blended_x, blended_y

    # CutMix: paste a rectangle from the partner image into each sample.
    lam = np.random.beta(CUTMIX_ALPHA, CUTMIX_ALPHA)
    left, top, right, bottom = rand_bbox(x.size(), lam)
    patched = x.clone()
    patched[:, :, top:bottom, left:right] = x[partner, :, top:bottom, left:right]
    # Recompute the weight from the clipped box so it equals the kept pixel share.
    lam = 1 - ((right - left) * (bottom - top) / (x.size(-1) * x.size(-2)))
    mixed_targets = y * lam + y[partner] * (1. - lam)
    return patched, mixed_targets

# ============================ EVAL ===========================================
def forward_with_tta(model, xb):
    """Run the model, averaging logits with a horizontally flipped pass.

    With USE_TTA off this is a plain forward pass. With it on, the batch is
    also flipped along dim 3 and the two logit tensors are averaged.
    """
    plain = model(xb)
    if not USE_TTA:
        return plain
    mirrored = model(torch.flip(xb, dims=[3]))  # horizontal flip
    return 0.5 * (plain + mirrored)

def eval_split(model, loader, split_name, out_dir, labels_present=True):
    """Run inference over ``loader`` and write prediction and metric files.

    Always writes ``{split_name}_preds.csv`` (one ``prob_<class>`` column per
    entry of CHEXPERT_LABELS). When ``labels_present`` is true it additionally
    writes, into ``out_dir``:
      * ``{split_name}_roc_per_class.pdf``: one ROC page per class
      * ``{split_name}_per_class_auroc.csv``: per-class and macro AUROC
      * ``{split_name}_per_class_prauc.csv``: per-class and macro PR-AUC
      * ``{split_name}_confusion_matrix.csv``: tn/fp/fn/tp per class at THRESH

    Classes whose ground truth holds a single value get NaN and are ignored
    by the macro averages (``np.nanmean``).

    Args:
        model: network to evaluate; switched to eval mode here.
        loader: iterable yielding ``(images, targets)`` batches.
        split_name: prefix for output file names and log lines.
        out_dir: existing directory that receives the files.
        labels_present: when false, only the probabilities are written.

    Returns:
        Macro AUROC as a float, or None when ``labels_present`` is false.
    """
    model.eval()
    all_probs, all_true = [], []
    # Inference only: no autograd graph, mixed precision if AMP is enabled.
    with torch.no_grad(), torch.cuda.amp.autocast(enabled=AMP):
        for xb, yb in loader:
            xb = xb.to(device, non_blocking=True)
            logits = forward_with_tta(model, xb)
            probs = torch.sigmoid(logits).cpu().numpy()
            all_probs.append(probs); all_true.append(yb.numpy())
    y_prob = np.concatenate(all_probs, axis=0)
    y_true = np.concatenate(all_true,  axis=0)

    # Save raw probabilities
    pd.DataFrame(y_prob, columns=[f"prob_{c}" for c in CHEXPERT_LABELS]) \
      .to_csv(os.path.join(out_dir, f"{split_name}_preds.csv"), index=False)

    if not labels_present:
        print(f"[{split_name}] no labels present. Wrote probabilities.")
        return None

    # ---------------- AUROC + ROC curves -------------------------------------
    per_class_auroc = []
    with PdfPages(os.path.join(out_dir, f"{split_name}_roc_per_class.pdf")) as pdf:
        for i, name in enumerate(CHEXPERT_LABELS):
            y = y_true[:, i]; p = y_prob[:, i]
            # AUROC is undefined with a single class: emit a placeholder page.
            if len(np.unique(y)) < 2:
                per_class_auroc.append(np.nan)
                plt.figure(figsize=(6,5))
                plt.text(0.5,0.5,"Insufficient positives/negatives",
                         ha="center",va="center"); plt.axis("off")
                plt.title(f"{name} - AUROC: NA"); pdf.savefig(); plt.close()
                continue
            au = roc_auc_score(y, p); per_class_auroc.append(au)
            fpr, tpr, _ = roc_curve(y, p)
            plt.figure(figsize=(6,5))
            plt.plot(fpr, tpr, label="ROC"); plt.plot([0,1],[0,1],"--")
            plt.xlabel("FPR"); plt.ylabel("TPR")
            plt.title(f"{name} - AUROC: {au:.3f}")
            # Close each figure after saving so figures do not accumulate.
            plt.legend(); plt.tight_layout(); pdf.savefig(); plt.close()
    macro_auroc = float(np.nanmean(per_class_auroc))
    df_auroc = pd.DataFrame({"class": CHEXPERT_LABELS,"auroc":per_class_auroc})
    # Append the macro average as a final "macro" row.
    df_auroc = pd.concat([df_auroc,
                          pd.DataFrame([{"class":"macro","auroc":macro_auroc}])],
                         ignore_index=True)
    df_auroc.to_csv(os.path.join(out_dir, f"{split_name}_per_class_auroc.csv"),
                    index=False)

    # ---------------- PR-AUC --------------------------------------------------
    per_class_ap = []
    for i in range(y_true.shape[1]):
        y = y_true[:, i]; p = y_prob[:, i]
        if len(np.unique(y)) < 2:
            per_class_ap.append(np.nan); continue
        per_class_ap.append(average_precision_score(y, p))
    macro_ap = float(np.nanmean(per_class_ap))
    df_ap = pd.DataFrame({"class": CHEXPERT_LABELS,"prauc":per_class_ap})
    df_ap = pd.concat([df_ap,
                       pd.DataFrame([{"class":"macro","prauc":macro_ap}])],
                      ignore_index=True)
    df_ap.to_csv(os.path.join(out_dir,f"{split_name}_per_class_prauc.csv"),
                 index=False)

    # ---------------- Confusion matrices (per class) -------------------------
    # thresholding at global THRESH
    y_pred = (y_prob >= THRESH).astype(int)
    mcm = multilabel_confusion_matrix(y_true, y_pred)  # shape: (num_classes, 2, 2)

    rows_cm = []
    for i, name in enumerate(CHEXPERT_LABELS):
        # Each 2x2 matrix flattens to [tn, fp, fn, tp].
        tn, fp, fn, tp = mcm[i].ravel()
        rows_cm.append({
            "class": name,
            "tn": int(tn),
            "fp": int(fp),
            "fn": int(fn),
            "tp": int(tp),
        })
    df_cm = pd.DataFrame(rows_cm)
    df_cm.to_csv(os.path.join(out_dir, f"{split_name}_confusion_matrix.csv"),
                 index=False)
    print(f"[{split_name}] Saved per-class confusion matrix CSV.")

    print(f"[{split_name}] macro AUROC: {macro_auroc:.4f} | macro PR-AUC: {macro_ap:.4f}")
    return macro_auroc

def read_table_any(path: str) -> pd.DataFrame:
    """Read a table from ``path``, choosing the reader by file extension.

    ``.xlsx`` / ``.xls`` (case-insensitive) go through ``pd.read_excel``;
    everything else is read as CSV.
    """
    lowered = str(path).lower()
    reader = pd.read_excel if lowered.endswith((".xlsx", ".xls")) else pd.read_csv
    return reader(path)

def detect_path_col(df: pd.DataFrame, candidates) -> str:
    """Return the first name in ``candidates`` that is a column of ``df``.

    Raises:
        KeyError: when none of the candidates is present.
    """
    matches = [name for name in candidates if name in df.columns]
    if not matches:
        raise KeyError(f"No path column found. Looked for: {candidates}")
    return matches[0]

def _norm_str_series(s: pd.Series) -> pd.Series:
    return s.astype(str).str.strip().str.lower()

# ============================= MAIN ==========================================
def run():
    run_id = datetime.now().strftime("%Y%m%d-%H%M%S")
    out_dir = os.path.join(OUTPUT_ROOT, run_id); os.makedirs(out_dir, exist_ok=True)
    print("\n[run] New run id:", run_id)
    print("[run] Output directory:", out_dir)

    with open(os.path.join(out_dir,"config.json"),"w") as f:
        json.dump({
            "TRAIN_CSV":TRAIN_CSV,"TRAIN_PATH_COL":TRAIN_PATH_COL,
            "TEST_A_CSV":TEST_A_CSV,"TEST_A_PATH_COL":TEST_A_PATH_COL,
            "TEST_B_CSV":TEST_B_CSV,"TEST_B_PATH_COL":TEST_B_PATH_COL,
            "CHECKPOINT_PATH":CHECKPOINT_PATH,
            "BACKBONE_LR":BACKBONE_LR,"HEAD_LR":HEAD_LR,
            "WEIGHT_DECAY":WEIGHT_DECAY,"EPOCHS":EPOCHS,
            "LOSS_NAME":LOSS_NAME,"ACTIVATION_NAME":ACTIVATION_NAME,
            "HEAD_HIDDEN_DIM":HEAD_HIDDEN_DIM,"HEAD_DROPOUT":HEAD_DROPOUT,
            "SUBSET_SOURCE_FILE":SUBSET_SOURCE_FILE,
            "SUBSET_LIGHTING_COL":SUBSET_LIGHTING_COL,
            "SUBSET_HEIGHT_COL":SUBSET_HEIGHT_COL,
            "SUBSET_DEFINITIONS":SUBSET_DEFINITIONS,
            "USE_MIXUP":USE_MIXUP,"USE_TTA":USE_TTA,
            "TRAIN_CROP_SIZE":TRAIN_CROP_SIZE,"VAL_RESIZE":VAL_RESIZE,"VAL_CROP_SIZE":VAL_CROP_SIZE,
        }, f, indent=2)
    print("[run] Saved config.json")

    print("\n[data] Loading training CSV:", TRAIN_CSV)
    df = ensure_patient_id(pd.read_csv(TRAIN_CSV), TRAIN_PATH_COL)
    print("[data] Total rows (after patient_id ensure):", len(df))
    print("[data] Unique patients (total):", df["patient_id"].nunique())

    gss = GroupShuffleSplit(n_splits=1, test_size=VAL_SPLIT, random_state=SEED)
    idx = np.arange(len(df))
    tr_idx, va_idx = next(gss.split(idx, groups=df["patient_id"].values))
    tr_df = df.iloc[tr_idx].reset_index(drop=True)
    va_df = df.iloc[va_idx].reset_index(drop=True)
    print(f"[data] Train rows: {len(tr_df)}, Val rows: {len(va_df)}")
    print(f"[data] Unique patients (train): {tr_df['patient_id'].nunique()}")
    print(f"[data] Unique patients (val):   {va_df['patient_id'].nunique()}")

    # Augs
    train_transform = A.Compose([
        A.RandomResizedCrop(size=(TRAIN_CROP_SIZE, TRAIN_CROP_SIZE),
                            scale=(0.8,1.0)),
        A.HorizontalFlip(p=0.5),
        # mild photometric jitter to help with yellow/lighting variations
        A.RandomBrightnessContrast(
            brightness_limit=(0.08,0.18),
            contrast_limit=(0.02,0.08),
            p=0.3),
        A.ColorJitter(brightness=0.05, contrast=0.05,
                      saturation=0.05, hue=0.02, p=0.3),
        A.GaussNoise(var_limit=(5.0,20.0),mean=0.0,p=0.2),
        A.ToFloat(max_value=255.0),
        A.Normalize(mean=[0.5]*3,std=[0.5]*3),
        ToTensorV2(),
    ])
    eval_transform = A.Compose([
        A.Resize(height=VAL_RESIZE,width=VAL_RESIZE,
                 interpolation=cv2.INTER_AREA),
        A.CenterCrop(height=VAL_CROP_SIZE,width=VAL_CROP_SIZE),
        A.ToFloat(max_value=255.0),
        A.Normalize(mean=[0.5]*3,std=[0.5]*3),
        ToTensorV2(),
    ])

    num_workers = min(NUM_WORKERS, os.cpu_count() or 2)

    # ------------------ inverse-frequency sampler over 4 combos ---------------
    use_attr_sampler = False
    sampler = None
    tr_df_for_dataset = tr_df

    if SUBSET_SOURCE_FILE and os.path.isfile(SUBSET_SOURCE_FILE):
        try:
            print("\n[sampler] Building inverse-frequency sampler from",
                  SUBSET_SOURCE_FILE)
            attr_df = read_table_any(SUBSET_SOURCE_FILE)
            subset_path_col = detect_path_col(attr_df, SUBSET_PATH_COL_CANDIDATES)
            print(f"[sampler] Detected path column: {subset_path_col}")

            attr_df[subset_path_col] = attr_df[subset_path_col].astype(str)
            df[TRAIN_PATH_COL]       = df[TRAIN_PATH_COL].astype(str)

            attr_df["_ls"] = attr_df[SUBSET_LIGHTING_COL].astype(str).str.strip().str.lower()
            attr_df["_hv"] = attr_df[SUBSET_HEIGHT_COL].astype(str).str.strip().str.lower()

            df_attr_merged = df.merge(
                attr_df[[subset_path_col,"_ls","_hv"]],
                left_on=TRAIN_PATH_COL, right_on=subset_path_col, how="left"
            )
            tr_df_attr = df_attr_merged.iloc[tr_idx].reset_index(drop=True)

            ls = tr_df_attr["_ls"].fillna("").str.lower()
            hv = tr_df_attr["_hv"].fillna("").str.lower()

            combo_map = {
                ("white light","further out"): "white_furtherout",
                ("white light","closer in"):   "white_closerin",
                ("yellow light","further out"):"yellow_furtherout",
                ("yellow light","closer in"):  "yellow_closerin",
            }

            groups = []
            for l,h in zip(ls.values, hv.values):
                groups.append(combo_map.get((l,h),"other"))
            groups = pd.Series(groups, index=tr_df_attr.index, name="group")

            group_counts = groups.value_counts().sort_index()
            print("[sampler] Group counts in train:")
            for g,c in group_counts.items():
                print(f"          {g:16s} -> {c}")

            max_count = group_counts.max()
            max_ratio = 3.0
            group_weights = {}
            for g,c in group_counts.items():
                ratio = max_count / float(c)
                ratio = min(ratio, max_ratio)
                group_weights[g] = ratio
            print("[sampler] Group weight multipliers (capped):")
            for g,w in group_weights.items():
                print(f"          {g:16s} -> {w:.3f}")

            weights = np.array([group_weights[g] for g in groups],
                               dtype=np.float32)
            weights *= (len(weights) / weights.sum())
            print("[sampler] weight stats after norm: "
                  f"min={weights.min():.3f}, max={weights.max():.3f}, "
                  f"mean={weights.mean():.3f}")

            sampler = WeightedRandomSampler(
                weights=torch.from_numpy(weights.astype(np.float64)),
                num_samples=len(weights),
                replacement=True,
            )
            tr_df_for_dataset = tr_df_attr
            use_attr_sampler = True
            print("[sampler] Using inverse-frequency WeightedRandomSampler.")
        except Exception as e:
            print(f"[sampler] Failed to build sampler: {e}")
            print("[sampler] Falling back to standard shuffle.")
    else:
        print(f"[sampler] Attribute file missing or unset: {SUBSET_SOURCE_FILE}")

    ds_train = CSVImageDataset(tr_df_for_dataset, TRAIN_PATH_COL, CHEXPERT_LABELS,
                               train_transform, expect_labels=True)
    ds_val   = CSVImageDataset(va_df, TRAIN_PATH_COL, CHEXPERT_LABELS,
                               eval_transform, expect_labels=True)

    if use_attr_sampler and sampler is not None:
        dl_train = DataLoader(ds_train, batch_size=BATCH_SIZE, sampler=sampler,
                              num_workers=num_workers, pin_memory=True)
    else:
        dl_train = DataLoader(ds_train, batch_size=BATCH_SIZE, shuffle=True,
                              num_workers=num_workers, pin_memory=True)
    dl_val = DataLoader(ds_val, batch_size=BATCH_SIZE, shuffle=False,
                        num_workers=num_workers, pin_memory=True)
    print(f"[data] ds_train size={len(ds_train)}, ds_val size={len(ds_val)}")
    print(f"[data] BATCH_SIZE={BATCH_SIZE}, NUM_WORKERS={num_workers}")

    # ------------------ model / opt / loss -----------------------------------
    model = build_model()

    # prior-based bias init
    try:
        print("[init] computing class priors from train split...")
        priors = tr_df_for_dataset[CHEXPERT_LABELS].applymap(normalize_label)\
                                                   .mean(axis=0).values.astype(np.float32)
        prior_logits = np.log(np.clip(priors,1e-6,1-1e-6) /
                              np.clip(1-priors,1e-6,1-1e-6))
        last_linear = None
        for m in model.classifier.modules():
            if isinstance(m, nn.Linear):
                last_linear = m
        if last_linear is not None:
            with torch.no_grad():
                last_linear.bias.copy_(torch.from_numpy(prior_logits)
                                       .to(last_linear.bias.device))
            print("[init] set classifier bias from priors")
        else:
            print("[init] WARNING: no Linear layer found in classifier.")
    except Exception as e:
        print(f"[init] could not set classifier bias: {e}")

    optimizer = torch.optim.AdamW(split_param_groups(model))
    scheduler = WarmupCosine(optimizer, WARMUP_EPOCHS, EPOCHS, min_lr=MIN_LR)
    print("\n[opt] Using AdamW:")
    for i,g in enumerate(optimizer.param_groups):
        print(f"      group {i}: lr={g['lr']:.3e}, weight_decay={g['weight_decay']:.3e}")
    print(f"[sched] WarmupCosine: warmup={WARMUP_EPOCHS}, epochs={EPOCHS}, min_lr={MIN_LR}")

    print(f"\n[config] LOSS_NAME={LOSS_NAME}")
    criterion = AsymmetricLossMultiLabel()
    print("[loss] Using AsymmetricLossMultiLabel")

    scaler = torch.cuda.amp.GradScaler(enabled=AMP)

    # ------------------ training loop ----------------------------------------
    best_auroc, no_improve = -1.0, 0
    hist = {"epoch":[], "lr_head":[], "lr_backbone":[],
            "train_loss":[], "train_acc":[],
            "val_loss":[], "val_acc":[],
            "val_auroc":[], "val_prauc":[]}

    print("\n[train] Starting training loop...")
    for epoch in range(1, EPOCHS+1):
        model.train()
        tl_sum = ta_sum = 0.0; tn = 0
        for xb,yb in dl_train:
            xb = xb.to(device,non_blocking=True)
            yb = yb.to(device,non_blocking=True)

            # MixUp / CutMix
            xb_mix, yb_mix = apply_mixup_cutmix(xb, yb)

            optimizer.zero_grad(set_to_none=True)
            with torch.cuda.amp.autocast(enabled=AMP):
                logits = model(xb_mix)
                loss = criterion(logits, yb_mix)

            probs = torch.sigmoid(logits).detach().cpu()
            acc = multilabel_acc(probs, yb.detach().cpu(), thr=THRESH)  # acc vs original labels

            scaler.scale(loss).backward()
            scaler.step(optimizer); scaler.update()

            bs = xb.size(0); tl_sum += float(loss.item())*bs; ta_sum += acc*bs; tn += bs
        train_loss = tl_sum/max(1,tn); train_acc = ta_sum/max(1,tn)

        model.eval()
        vl_sum = va_sum = 0.0; vn = 0
        all_probs, all_true = [], []
        with torch.no_grad(), torch.cuda.amp.autocast(enabled=AMP):
            for xb,yb in dl_val:
                xb = xb.to(device,non_blocking=True)
                yb = yb.to(device,non_blocking=True)
                logits = forward_with_tta(model, xb)
                vloss = criterion(logits, yb)
                probs = torch.sigmoid(logits).detach().cpu()
                vacc = multilabel_acc(probs, yb.detach().cpu(), thr=THRESH)
                all_probs.append(probs.numpy()); all_true.append(yb.cpu().numpy())
                bs = xb.size(0); vl_sum += float(vloss.item())*bs; va_sum += vacc*bs; vn += bs
        val_loss = vl_sum/max(1,vn); val_acc = va_sum/max(1,vn)

        y_true = np.concatenate(all_true,axis=0)
        y_prob = np.concatenate(all_probs,axis=0)
        per_auroc, per_ap = [], []
        for i in range(y_true.shape[1]):
            y = y_true[:,i]; p = y_prob[:,i]
            if len(np.unique(y))<2:
                per_auroc.append(np.nan); per_ap.append(np.nan); continue
            per_auroc.append(roc_auc_score(y,p))
            per_ap.append(average_precision_score(y,p))
        val_macro_auroc = float(np.nanmean(per_auroc))
        val_macro_ap    = float(np.nanmean(per_ap))

        scheduler.step()
        lr_bb = optimizer.param_groups[0]["lr"]
        lr_hd = optimizer.param_groups[1]["lr"]

        hist["epoch"].append(epoch)
        hist["lr_backbone"].append(lr_bb); hist["lr_head"].append(lr_hd)
        hist["train_loss"].append(train_loss); hist["train_acc"].append(train_acc)
        hist["val_loss"].append(val_loss);     hist["val_acc"].append(val_acc)
        hist["val_auroc"].append(val_macro_auroc); hist["val_prauc"].append(val_macro_ap)

        print(f"[{epoch:02d}/{EPOCHS}] lr_bb={lr_bb:.2e} lr_hd={lr_hd:.2e} | "
              f"train_loss={train_loss:.4f} train_acc={train_acc:.4f} | "
              f"val_loss={val_loss:.4f} val_acc={val_acc:.4f} | "
              f"val_macro_AUROC={val_macro_auroc:.4f} | val_macro_PR-AUC={val_macro_ap:.4f}")

        if val_macro_auroc > best_auroc + 1e-6:
            best_auroc, no_improve = val_macro_auroc, 0
            torch.save({"state_dict":model.state_dict(),"epoch":epoch},
                       os.path.join(out_dir,"best.pt"))
            print(f"[model] Saved new best.pt (epoch={epoch}, macro_AUROC={best_auroc:.4f})")
        else:
            no_improve += 1
            print(f"[earlystop] no_improve={no_improve}/{EARLYSTOP_PATIENCE}")
            if no_improve >= EARLYSTOP_PATIENCE:
                print("[earlystop] stopping early.")
                break

    # --------- Final checkpoints ---------------------------------------------
    torch.save({"state_dict":model.state_dict(),"epoch":epoch},
           os.path.join(out_dir,"last.pt"))
    print(f"[model] Saved last.pt (epoch={epoch})")
    
    # --------- Reload BEST checkpoint and export that as .pth ----------------
    best_ckpt_path = os.path.join(out_dir, "best.pt")
    if os.path.isfile(best_ckpt_path):
        print(f"[export] Reloading best model from {best_ckpt_path} for export...")
        ckpt = torch.load(best_ckpt_path, map_location="cpu")
        state_dict = ckpt.get("state_dict", ckpt)
        model.load_state_dict(state_dict)
    else:
        print("[export] WARNING: best.pt not found, exporting last epoch instead.")
        state_dict = model.state_dict()
    
    export_path = os.path.join(out_dir, "model_export.pth")
    torch.save(model.state_dict(), export_path)
    print(f"[export] Saved pure state_dict BEST model to {export_path}")


    # --------- History & plots -----------------------------------------------
    hist_df = pd.DataFrame(hist)
    hist_df.to_csv(os.path.join(out_dir,"history.csv"),index=False)
    ep = np.arange(1, len(hist_df)+1)

    plt.figure(figsize=(7,5))
    plt.plot(ep,hist_df["train_loss"],label="train_loss")
    plt.plot(ep,hist_df["val_loss"],label="val_loss")
    plt.xlabel("Epoch"); plt.ylabel("Loss"); plt.title("Loss Curves"); plt.legend()
    plt.tight_layout(); plt.savefig(os.path.join(out_dir,"train_val_curves.png")); plt.close()
    print("[plot] Saved train_val_curves.png")

    plt.figure(figsize=(7,5))
    plt.plot(ep,hist_df["val_auroc"],label="val_macro_AUROC")
    plt.plot(ep,hist_df["val_prauc"],label="val_macro_PR-AUC")
    plt.xlabel("Epoch"); plt.ylabel("Metric")
    plt.title("Validation AUROC / PR-AUC"); plt.legend()
    plt.tight_layout(); plt.savefig(os.path.join(out_dir,"val_auroc_prauc.png")); plt.close()
    print("[plot] Saved val_auroc_prauc.png")

    print("\n[val] Per-class metrics on validation split")
    _ = eval_split(model, dl_val, "val", out_dir, labels_present=True)

    # ---- full test splits ---------------------------------------------------
    def maybe_eval(csv_path, path_col, name):
        """Evaluate the current model on one test CSV, if that CSV is available.

        Args:
            csv_path: Path to the CSV describing the test split. A falsy value
                or a path that is not an existing file causes the split to be
                skipped with a log message.
            path_col: Name of the CSV column holding the image paths.
            name: Human-readable split name used in log prefixes; its
                lower-cased form is passed to ``eval_split`` as the split name.

        Returns:
            None. Results are whatever ``eval_split`` produces for ``out_dir``.
        """
        # Guard clause: nothing to evaluate when the CSV is unset or missing.
        csv_available = bool(csv_path) and os.path.isfile(csv_path)
        if not csv_available:
            print(f"[{name}] skipped (missing file: {csv_path})")
            return

        print(f"\n[{name}] Evaluating on {csv_path}")

        # Load the table and run it through the same patient-id helper used elsewhere.
        test_df = ensure_patient_id(pd.read_csv(csv_path), path_col)
        test_ds = CSVImageDataset(
            test_df, path_col, CHEXPERT_LABELS, eval_transform, expect_labels=True
        )
        # Deterministic order: no shuffling for evaluation.
        test_dl = DataLoader(
            test_ds,
            batch_size=BATCH_SIZE,
            shuffle=False,
            num_workers=num_workers,
            pin_memory=True,
        )
        print(f"[{name}] Dataset size={len(test_ds)}")

        # The dataset reports whether labels were actually found; forward that flag.
        eval_split(
            model,
            test_dl,
            name.lower(),
            out_dir,
            labels_present=test_ds.has_labels,
        )

    maybe_eval(TEST_A_CSV, TEST_A_PATH_COL, "Test A")
    maybe_eval(TEST_B_CSV, TEST_B_PATH_COL, "Test B")

    # ---- subset evals for all 4 combos -------------------------------------
    print("\n[subset] Attribute-based subset evaluation...")
    if SUBSET_SOURCE_FILE and os.path.isfile(SUBSET_SOURCE_FILE):
        try:
            df_attr = read_table_any(SUBSET_SOURCE_FILE)
            print(f"[subset] Loaded file with {len(df_attr)} rows")
            missing_cols = [c for c in [SUBSET_LIGHTING_COL,SUBSET_HEIGHT_COL]
                            if c not in df_attr.columns]
            if missing_cols:
                print(f"[subset] missing columns: {missing_cols}")
            else:
                subset_path_col = detect_path_col(df_attr, SUBSET_PATH_COL_CANDIDATES)
                print(f"[subset] Detected path column: {subset_path_col}")
                ls_norm = _norm_str_series(df_attr[SUBSET_LIGHTING_COL])
                hv_norm = _norm_str_series(df_attr[SUBSET_HEIGHT_COL])
                for sd in SUBSET_DEFINITIONS:
                    name = sd["name"]
                    want_ls = sd[SUBSET_LIGHTING_COL].strip().lower()
                    want_hv = sd[SUBSET_HEIGHT_COL].strip().lower()
                    mask = ls_norm.eq(want_ls) & hv_norm.eq(want_hv)
                    sub = df_attr.loc[mask].copy()
                    print(f"\n[subset:{name}] target LS='{want_ls}', HV='{want_hv}', "
                          f"rows before existence check={len(sub)}")
                    if sub.empty:
                        print(f"[subset:{name}] subset empty, skipping."); continue
                    sub = ensure_patient_id(sub, subset_path_col)
                    exists_mask = sub[subset_path_col].astype(str).apply(
                        lambda x: os.path.exists(str(x)))
                    sub = sub.loc[exists_mask].reset_index(drop=True)
                    print(f"[subset:{name}] rows after existence check={len(sub)}")
                    if sub.empty:
                        print(f"[subset:{name}] no existing files, skipping."); continue
                    subset_csv_path = os.path.join(out_dir, f"{name}_rows.csv")
                    sub.to_csv(subset_csv_path,index=False)
                    print(f"[subset:{name}] saved rows to {subset_csv_path}")
                    ds = CSVImageDataset(sub, subset_path_col, CHEXPERT_LABELS,
                                         eval_transform, expect_labels=True)
                    dl = DataLoader(ds, batch_size=BATCH_SIZE, shuffle=False,
                                    num_workers=num_workers, pin_memory=True)
                    print(f"[subset:{name}] Evaluating subset (n={len(ds)})")
                    _ = eval_split(model, dl, name, out_dir, labels_present=ds.has_labels)
        except Exception as e:
            print(f"[subset] failed subset evaluation: {e}")
    else:
        print(f"[subset] attribute file missing or unset: {SUBSET_SOURCE_FILE}")

    print(f"\n[done] All outputs in: {out_dir}")

# Entry point: launch the full run() pipeline only when this code is executed
# as the main program (which is also the case when it is run as a notebook cell),
# not when it is imported as a module.
if __name__ == "__main__":
    run()



[run] New run id: 20251124-163525
[run] Output directory: ./runs_adapt/20251124-163525
[run] Saved config.json

[data] Loading training CSV: /scratch/ssriva94/Capstone/new/low-quality-image.fixed_paths.csv
[data] Total rows (after patient_id ensure): 1406
[data] Unique patients (total): 1366
[data] Train rows: 1130, Val rows: 276
[data] Unique patients (train): 1092
[data] Unique patients (val):   274

[sampler] Building inverse-frequency sampler from /scratch/ssriva94/Capstone/new/low-quality-image.csv
[sampler] Detected path column: New_path
[sampler] Group counts in train:
          other            -> 155
          white_closerin   -> 161
          white_furtherout -> 329
          yellow_closerin  -> 321
          yellow_furtherout -> 164
[sampler] Group weight multipliers (capped):
          other            -> 2.123
          white_closerin   -> 2.043
          white_furtherout -> 1.000
          yellow_closerin  -> 1.025
          yellow_furtherout -> 2.006
[sampler] weight st