
# IoT-based Multimodal Pipeline for Early Mastitis Detection
This notebook provides a robust, leak-safe and energy-aware pipeline:
- **Tabular model** on clinical ground truth (CSV)
- **Imaging model** (frozen EfficientNet features + LR) on image labels
- **Cross-modal bridge** (tab→image embeddings) enabling fusion even when cohorts are not perfectly aligned
- **Fail-safe fusion** with proper clinical evaluation


In [None]:
# ===== 1) Configuration & Paths =====
# Purpose: centralise environment detection (Colab vs local), define base paths for data and images,
# and set global runtime options (seeds, batch size, GPU toggle). Keeping this in one place makes the
# rest of the notebook easier to audit and reproduce.

import os, random, json, re, glob, math, shutil, time, warnings
import numpy as np
import pandas as pd

# Detect whether we are running inside Google Colab.
# If true, we mount Drive to access project data stored in MyDrive.
# Any exception raised by the import (not only ImportError) is treated as "not in Colab".
try:
    from google.colab import drive  # type: ignore
    IN_COLAB = True
except Exception:
    IN_COLAB = False

if IN_COLAB:
    # Mount Google Drive for data I/O (figures, CSVs, image folders).
    drive.mount('/content/drive', force_remount=False)
    BASE_DRIVE = "/content/drive/MyDrive"
else:
    # Fallback to the current working directory when running locally.
    BASE_DRIVE = os.getcwd()

# Define project-level directories and file paths.
# Expected layout: <BASE_DRIVE>/Mastitis_illness_cow/datasets/{<csv>, images/, labels/}.
# Note: do not rename these without updating downstream cells that rely on them.
PROJECT_DIR = os.path.join(BASE_DRIVE, "Mastitis_illness_cow", "datasets")
TABULAR_CSV_PATH = os.path.join(PROJECT_DIR, "clinical_mastitis_cows_version1.csv")
IMAGE_DIR = os.path.join(PROJECT_DIR, "images")
LABEL_DIR = os.path.join(PROJECT_DIR, "labels")

# Runtime knobs for the imaging branch.
# USE_GPU_FOR_IMAGE_MODEL requests GPU use for the image model; it can only take effect
# when a CUDA device exists. NOTE(review): confirm the imaging cell actually reads it.
USE_GPU_FOR_IMAGE_MODEL = True
BATCH_SIZE_IMAGE = 32

# Global reproducibility seed. Keep fixed for consistent splits and initialisations.
# Seeds Python's `random` and NumPy's legacy global RNG; PyTorch is seeded via seed_all_torch().
SEED = 42
random.seed(SEED); np.random.seed(SEED)

def seed_all_torch(seed=42):
    """
    Seed PyTorch's CPU and CUDA generators and force deterministic cuDNN kernels.

    Best-effort by design: if PyTorch cannot be imported (or any of these calls
    fails) the function returns quietly, so the notebook still runs without torch.
    """
    try:
        import torch
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        cudnn = torch.backends.cudnn
        # Deterministic algorithms and no autotuning -> repeatable results across runs.
        cudnn.deterministic, cudnn.benchmark = True, False
    except Exception:
        return None

# Apply torch seeding (if PyTorch is present); does not affect environments without torch.
seed_all_torch(SEED)

# Quick visibility checks so that path issues are caught early during runtime.
# These only print True/False; nothing here stops execution when a path is missing.
print("IN_COLAB:", IN_COLAB)
print("BASE_DRIVE:", BASE_DRIVE)
print("PROJECT_DIR:", PROJECT_DIR)
print("CSV exists:", os.path.exists(TABULAR_CSV_PATH))
print("IMAGE_DIR exists:", os.path.exists(IMAGE_DIR))
print("LABEL_DIR exists:", os.path.exists(LABEL_DIR))


In [None]:
# ===== 2) ADAPTIVE TASK: RISK_NEXT (visits) → RISK_WITHIN (days) → fallback D_proximity (visit-level) =====
# Purpose: derive a robust visit-level target (risk_next) using a progressive strategy.
# 1) Prefer a “next-K visits” definition if it yields enough positives across visits and cows.
# 2) Otherwise try a “within-H days” definition (time-based risk window).
# 3) If neither yields sufficient positives, fall back to a proximity-based proxy (last K visits before onset).

import os, re, glob, numpy as np, pandas as pd
from sklearn.model_selection import train_test_split

# --- Utilities ---------------------------------------------------------------
def digits_only(x: str) -> str:
    """
    Return the decimal digits of ``x`` (stringified) concatenated in order.

    Missing values (anything ``pd.isna`` flags) map to the sentinel string "nan".
    """
    if pd.isna(x):
        return "nan"
    return "".join(ch for ch in str(x) if ch.isdecimal())

def strip_leading_zeros(x: str) -> str:
    """
    Drop leading zeros from a digit string ("00123" -> "123").

    An all-zero string collapses to "0"; "", "nan" and None map to the sentinel "nan".
    """
    if x is None or x in ("nan", ""):
        return "nan"
    return x.lstrip("0") or "0"

def resolve_tabular_path(filename_default="clinical_mastitis_cows_version1.csv") -> str:
    """
    Locate the tabular CSV and return its path.

    Search order:
    1) the global TABULAR_CSV_PATH, if defined (used as-is, independent of
       `filename_default`),
    2) /mnt/data/<filename_default>,
    3) <current working directory>/<filename_default>,
    4) recursive search for <filename_default> under the global PROJECT_DIR
       (skipped when PROJECT_DIR is not defined, e.g. the config cell was not run).

    Raises:
        FileNotFoundError: if no candidate exists.
    """
    g = globals()
    candidates = [
        g.get("TABULAR_CSV_PATH"),
        os.path.join('/mnt/data', filename_default),
        os.path.join(os.getcwd(), filename_default),
    ]
    for c in candidates:
        if c and os.path.exists(c):
            print("Using tabular CSV at:", c)
            return c
    # Guard PROJECT_DIR so a missing configuration raises the documented
    # FileNotFoundError below rather than an unrelated NameError.
    project_dir = g.get("PROJECT_DIR")
    if project_dir:
        pattern = os.path.join(project_dir, "**", filename_default)
        for p in glob.glob(pattern, recursive=True):
            if os.path.exists(p):  # glob can list dangling symlinks
                print("Found tabular CSV via recursive search at:", p)
                return p
    raise FileNotFoundError("Cannot locate tabular CSV. Check PROJECT_DIR/paths.")

# --- Hyperparameters and minimum thresholds ---------------------------------
K_LIST   = [3, 5, 7, 10, 14, 21, 30]   # lookahead in number of future visits
H_LIST   = [3, 5, 7, 10, 14, 21, 30]   # lookahead horizon in days
MIN_POS_VISITS = 50                    # minimum acceptable count of positive visits
MIN_POS_COWS   = 30                    # minimum acceptable count of cows with ≥1 positive visit

# --- Load tabular data -------------------------------------------------------
TABULAR_CSV_RESOLVED = resolve_tabular_path()
tab = pd.read_csv(TABULAR_CSV_RESOLVED)
print("[Tabular] Columns:", list(tab.columns))

# Key columns (auto-detected based on common names); the first name present in each list wins.
COW_ID_COL = next((c for c in ["Cow_ID","cow_id","CowID","animal_id","Animal_ID","subject_id","id","ID"] if c in tab.columns), None)
TIME_COL   = next((c for c in ["Day","day","time","Time","Days","days"] if c in tab.columns), None)
RAW_TARGET = next((c for c in ["class1","Class1","Label","label","mastitis","status","target","class","disease","outcome","y"] if c in tab.columns), None)
if COW_ID_COL is None: raise KeyError("Cow ID column not found.")
if TIME_COL   is None: raise KeyError("Time column (e.g., 'Day') not found.")
if RAW_TARGET is None: raise KeyError("Binary target column (e.g., 'class1') not found.")

# Normalise IDs (digits only, no leading zeros); unparseable IDs become the string "nan".
tab["Cow_ID_norm"]  = tab[COW_ID_COL].astype(str).map(digits_only).map(strip_leading_zeros)
tab["Cow_ID_match"] = tab["Cow_ID_norm"]
KEY = "Cow_ID_match"
print(
    "[CowID] Tabular Cow_ID normalised. Unique (non-nan):",
    tab["Cow_ID_norm"].replace("nan", np.nan).dropna().nunique()
)

# Cast time and label to numbers. Unparseable entries would otherwise vanish silently
# (time -> NaN, label -> 0 i.e. "negative"); report them so a text-valued or mis-detected
# column (e.g. "yes"/"no") is noticed instead of quietly yielding zero positives.
_time_num = pd.to_numeric(tab[TIME_COL], errors="coerce")
_n_bad_time = int((_time_num.isna() & tab[TIME_COL].notna()).sum())
if _n_bad_time:
    print(f"[WARN] {_n_bad_time} non-numeric value(s) in '{TIME_COL}' coerced to NaN.")
tab[TIME_COL] = _time_num

_y_num = pd.to_numeric(tab[RAW_TARGET], errors="coerce")
_n_bad_y = int((_y_num.isna() & tab[RAW_TARGET].notna()).sum())
if _n_bad_y:
    print(f"[WARN] {_n_bad_y} non-numeric value(s) in '{RAW_TARGET}' coerced to 0 (negative).")
tab[RAW_TARGET] = _y_num.fillna(0).astype(int)
# The label constructors below only treat the value 1 as positive.
_non_binary = sorted(int(v) for v in set(tab[RAW_TARGET].unique()) - {0, 1})
if _non_binary:
    print(f"[WARN] '{RAW_TARGET}' has non-binary values {_non_binary}; only 1 counts as positive downstream.")

# Drop rows without a usable cow ID and order visits per cow by time (NaN times last).
tab = tab[tab[KEY].ne("nan")].copy().sort_values([KEY, TIME_COL]).reset_index(drop=True)

# ---------- Visit-level label constructors ----------------------------------
def build_risk_nextK_visits(df, K=3, key=KEY, tcol=TIME_COL, ycol=RAW_TARGET):
    """
    Visit-based risk label (``risk_next``).

    Within each cow, visits are ordered by ``tcol`` (missing times last). A visit that
    is not itself positive gets 1 if any of the next K visits has ``ycol == 1`` and 0
    otherwise, so a cow's final visit is always 0. Visits already positive at time t are
    not valid inputs and are removed from the result.

    Returns a new DataFrame with all original columns plus an int ``risk_next`` column.
    """
    labelled_cows = []
    for _, cow_rows in df.groupby(key, sort=False):
        cow_rows = cow_rows.sort_values(tcol, na_position="last").reset_index(drop=True)
        is_pos = cow_rows[ycol].values.astype(int) == 1
        labels = np.full(len(cow_rows), -1, dtype=int)  # -1 = positive at t, dropped below
        for t in np.flatnonzero(~is_pos):
            # Slicing clamps at the end, so the last visits just see a shorter window.
            labels[t] = int(is_pos[t + 1 : t + 1 + K].any())
        out_rows = cow_rows.copy()
        out_rows["risk_next"] = labels
        labelled_cows.append(out_rows)
    out = pd.concat(labelled_cows, ignore_index=True)
    out = out.loc[out["risk_next"] != -1].copy()
    out["risk_next"] = out["risk_next"].astype(int)
    return out

def build_risk_withinH_days(df, H=7, key=KEY, tcol=TIME_COL, ycol=RAW_TARGET):
    """
    Time-based risk label (``risk_next``).

    Within each cow (visits ordered by ``tcol``, missing times last), a visit that is
    not itself positive gets 1 when some visit strictly later in time, by at most H
    units of ``tcol`` (days), has ``ycol == 1``; otherwise 0. Visits with a missing
    time never match. Visits already positive at time t are removed from the result.

    Returns a new DataFrame with all original columns plus an int ``risk_next`` column.
    """
    labelled_cows = []
    for _, cow_rows in df.groupby(key, sort=False):
        cow_rows = cow_rows.sort_values(tcol, na_position="last").reset_index(drop=True)
        y = cow_rows[ycol].values.astype(int)
        days = cow_rows[tcol].values.astype(float)
        is_pos = y == 1
        labels = np.zeros(len(cow_rows), dtype=int)
        labels[is_pos] = -1  # positive at t -> not a valid input, dropped below
        for t in np.flatnonzero(~is_pos):
            t0 = days[t]
            in_window = (days > t0) & (days - t0 <= H)  # strictly later, at most H ahead
            labels[t] = int((in_window & is_pos).any())
        out_rows = cow_rows.copy()
        out_rows["risk_next"] = labels
        labelled_cows.append(out_rows)
    out = pd.concat(labelled_cows, ignore_index=True)
    out = out.loc[out["risk_next"] != -1].copy()
    out["risk_next"] = out["risk_next"].astype(int)
    return out

def build_proximity_visit_level(df, K=3, key=KEY, tcol=TIME_COL, ycol=RAW_TARGET):
    """
    Fallback proxy label (``risk_next``) based on proximity to the first onset.

    Within each cow (visits ordered by ``tcol``, missing times last), the first visit
    with ``ycol == 1`` is the onset. The onset visit and the up-to K-1 visits right
    before it (i.e. at most K visits in total, for K >= 1) get 1; every other visit
    gets 0, including any later visits that are themselves positive. Cows with no
    positive visit are all 0.

    Unlike the other constructors no visit is removed, so the (positive) onset visit
    stays in the output with label 1.

    Returns a new DataFrame with all original columns plus an int ``risk_next`` column.
    """
    labelled_cows = []
    for _, cow_rows in df.groupby(key, sort=False):
        cow_rows = cow_rows.sort_values(tcol, na_position="last").reset_index(drop=True)
        y = cow_rows[ycol].values.astype(int)
        labels = np.zeros(len(y), dtype=int)
        positives = np.flatnonzero(y == 1)
        if positives.size:
            onset = int(positives[0])
            labels[max(0, onset - (K - 1)) : onset + 1] = 1  # window ends at (and includes) onset
        out_rows = cow_rows.copy()
        out_rows["risk_next"] = labels
        labelled_cows.append(out_rows)
    out = pd.concat(labelled_cows, ignore_index=True)
    out["risk_next"] = out["risk_next"].astype(int)
    return out

# ---------- Adaptive search for the best label definition --------------------
# Strategies are tried in order (next-K visits -> within-H days -> proximity fallback).
# Within a strategy, candidates follow K_LIST / H_LIST order (ascending here), and the
# first one meeting BOTH MIN_POS_VISITS and MIN_POS_COWS is kept.
# `chosen` holds (mode name, parameter, labelled DataFrame).
chosen = None

# 1) Next-K visits
for K in K_LIST:
    cand = build_risk_nextK_visits(tab, K=K, key=KEY, tcol=TIME_COL, ycol=RAW_TARGET)
    # pos_v: positive visits; pos_c: sum of per-cow max label (cows with >=1 positive visit).
    pos_v = int((cand["risk_next"] == 1).sum())
    pos_c = int(cand.groupby(KEY)["risk_next"].max().sum())
    print(f"[TRY] RISK_NEXT@{K}visits | visits pos={pos_v} | cows pos={pos_c}")
    if pos_v >= MIN_POS_VISITS and pos_c >= MIN_POS_COWS:
        chosen = ("RISK_NEXT_visits", K, cand); break

# 2) Within-H days
if chosen is None:
    for H in H_LIST:
        cand = build_risk_withinH_days(tab, H=H, key=KEY, tcol=TIME_COL, ycol=RAW_TARGET)
        pos_v = int((cand["risk_next"] == 1).sum())
        pos_c = int(cand.groupby(KEY)["risk_next"].max().sum())
        print(f"[TRY] RISK_WITHIN@{H}days | visits pos={pos_v} | cows pos={pos_c}")
        if pos_v >= MIN_POS_VISITS and pos_c >= MIN_POS_COWS:
            chosen = ("RISK_WITHIN_days", H, cand); break

# 3) Fallback proximity (visit-level)
# Accepted unconditionally (no threshold check), so `chosen` is never None afterwards.
if chosen is None:
    K_fallback = 3
    cand = build_proximity_visit_level(tab, K=K_fallback, key=KEY, tcol=TIME_COL, ycol=RAW_TARGET)
    pos_v = int((cand["risk_next"] == 1).sum())
    pos_c = int(cand.groupby(KEY)["risk_next"].max().sum())
    print(f"[FALLBACK] PROXIMITY@{K_fallback}vis | visits pos={pos_v} | cows pos={pos_c}")
    chosen = ("PROXIMITY_visits", K_fallback, cand)

TASK_MODE, HYPER, df_risk = chosen
print(
    f"[CHOSEN] {TASK_MODE} param={HYPER} | visits pos={int((df_risk['risk_next']==1).sum())} "
    f"| cows pos={int(df_risk.groupby(KEY)['risk_next'].max().sum())} | N={len(df_risk)}"
)

# ---------- Leak-safe split by cow ------------------------------------------
# We split by cow so the same animal never appears across train/val/test,
# preventing identity leakage and overly optimistic metrics.
def _stratify_ok(labels) -> bool:
    """True when `labels` has >= 2 classes and every class has >= 2 members (needed by sklearn's stratify)."""
    _, counts = np.unique(labels, return_counts=True)
    return len(counts) >= 2 and int(counts.min()) >= 2


# One label per cow: 1 if any of its visits has risk_next == 1.
cow_any = df_risk.groupby(KEY)["risk_next"].max().astype(int)
all_cows = np.array(sorted(cow_any.index.astype(str)))
y_cows   = cow_any.reindex(all_cows).values

# Outer split: 80% train+val / 20% test. Stratify only when sklearn can; a class with a
# single cow would otherwise make train_test_split raise ValueError.
if _stratify_ok(y_cows):
    tr_c, te_c = train_test_split(all_cows, test_size=0.20, stratify=y_cows, random_state=42)
else:
    print("[WARN] Per-cow labels cannot be stratified (single class or a class with <2 cows). Using non-stratified split.")
    tr_c, te_c = train_test_split(all_cows, test_size=0.20, random_state=42)

# Inner split: 75% train / 25% val of the train+val cows (60/20/20 overall).
mask_tv = np.isin(all_cows, tr_c)
tv_cows = all_cows[mask_tv]
tv_y    = cow_any.reindex(tv_cows).values
if _stratify_ok(tv_y):
    tr_cows, val_cows = train_test_split(tv_cows, test_size=0.25, stratify=tv_y, random_state=42)
else:
    tr_cows, val_cows = train_test_split(tv_cows, test_size=0.25, random_state=42)

# Publish the held-out cows under the name Cell 2.5 checks for (`test_cows`). Without
# this alias Cell 2.5 finds no complete split and rebuilds (and overwrites) a different one.
test_cows = te_c

print(f"[Split] Train cows: {len(tr_cows)} | Val cows: {len(val_cows)} | Test cows: {len(te_c)}")
print(f"[READY] TASK_MODE='{TASK_MODE}' | label='risk_next' | hyper={HYPER}")


In [None]:
# ===== 2.5) Visit-level feature engineering (leak-safe) & cow-aligned splits =====
# Purpose: build leakage-safe visit-level features, strictly separated per split (train/val/test),
# while keeping cows disjoint across splits. We also include robust key/target de-duplication.

import numpy as np, pandas as pd
from sklearn.model_selection import train_test_split

# --------- Context and fallback split ----------------------------------------
# Explicit check instead of `assert` (asserts are stripped under `python -O`).
if 'tab' not in globals():
    raise NameError("Run Cell 2 first: the DataFrame 'tab' is missing.")

KEY = "Cow_ID_match"
YCOL_CANDIDATES = ["risk_next", "early", "class1", "Label", "label"]
# NOTE(review): the target is looked up in `tab`. In Cell 2, `risk_next` is only written to
# per-cow copies that are collected into `df_risk`, never to `tab`. Unless another cell adds
# it, YCOL therefore resolves to the raw clinical label (e.g. 'class1'), and this cell models
# same-visit status rather than Cell 2's adaptive risk label. Confirm which target is intended.
YCOL = next((c for c in YCOL_CANDIDATES if c in tab.columns), None)
if YCOL is None:
    raise KeyError(f"No target column found in 'tab'. Expected one of: {YCOL_CANDIDATES}")

if KEY not in tab.columns:
    raise KeyError(f"Key {KEY} missing in 'tab'; please verify Cell 2.")

# --------- Robust utilities: de-duplicate target and key ---------------------
def coerce_and_dedup_target(df: pd.DataFrame, ycol: str) -> pd.DataFrame:
    """
    Guarantee that ``df`` has exactly one integer target column named ``ycol``.

    - No such column: add one filled with 0 (in place) and return ``df``.
    - One column: cast it to int in place (unparseable/missing -> 0) and return ``df``.
    - Several columns sharing the name (e.g. after merges): take their row-wise max
      of numeric casts (missing -> 0), drop the duplicates, and append the merged
      column at the end of a new frame.
    """
    is_target = df.columns == ycol
    n_target_cols = int(is_target.sum())
    if n_target_cols == 0:
        df[ycol] = 0
    elif n_target_cols == 1:
        df[ycol] = pd.to_numeric(df[ycol], errors="coerce").fillna(0).astype(int)
    else:
        numeric = df.loc[:, is_target].apply(pd.to_numeric, errors="coerce").fillna(0)
        merged = numeric.max(axis=1).astype(int)
        df = df.drop(columns=ycol, errors="ignore")
        df[ycol] = merged
    return df

def coerce_and_dedup_key(df: pd.DataFrame, key: str) -> pd.DataFrame:
    """
    Guarantee that ``df`` has exactly one ``key`` column, stored as strings.

    - No ``key`` column: raise KeyError.
    - One column: cast it to str in place and return ``df``.
    - Several columns sharing the name: per row keep the first value that is not
      missing (after string conversion "nan"/"None" count as missing), drop the
      duplicates and append the merged column at the end of a new frame.
    """
    is_key = df.columns == key
    n_key_cols = int(is_key.sum())
    if n_key_cols == 0:
        raise KeyError(f"Key {key} absent after preprocessing.")
    if n_key_cols == 1:
        df[key] = df[key].astype(str)
        return df
    as_text = df.loc[:, is_key].astype(str)
    blanks_as_nan = as_text.replace({"nan": np.nan, "None": np.nan})
    first_valid = blanks_as_nan.bfill(axis=1).iloc[:, 0].astype(str)
    df = df.drop(columns=key, errors="ignore")
    df[key] = first_valid
    return df

# Initial de-duplication on 'tab'
tab = coerce_and_dedup_key(tab, KEY)
tab = coerce_and_dedup_target(tab, YCOL)


def _stratify_ok(labels) -> bool:
    """True when `labels` has >= 2 classes and every class has >= 2 members (needed by sklearn's stratify)."""
    _, counts = np.unique(labels, return_counts=True)
    return len(counts) >= 2 and int(counts.min()) >= 2


# Cell 2 stores its held-out cows as `te_c`, not `test_cows`. Adopt them here so Cell 2's
# cow split is reused instead of being silently rebuilt below on a different target
# (which would also overwrite tr_cows/val_cows).
if "test_cows" not in globals() and all(k in globals() for k in ["tr_cows", "val_cows", "te_c"]):
    test_cows = te_c
    print("[2.5] Reusing cow splits from Cell 2 (te_c -> test_cows).")

# If tr_cows/val_cows/test_cows do NOT exist, reconstruct them now (leak-safe group split)
if not all(k in globals() for k in ["tr_cows","val_cows","test_cows"]):
    print("[2.5 Fallback] Rebuilding cow-based splits…")
    cow_y = tab.groupby(KEY)[YCOL].max().astype(int)
    all_cows = np.array(sorted(cow_y.index.astype(str)))
    all_y = cow_y.reindex(all_cows).values
    # Stratify only when sklearn can (>= 2 classes, each with >= 2 cows); a class with a
    # single cow would otherwise make train_test_split raise ValueError.
    if _stratify_ok(all_y):
        tr_all, te_all = train_test_split(
            all_cows, test_size=0.20,
            stratify=all_y, random_state=42
        )
    else:
        tr_all, te_all = train_test_split(all_cows, test_size=0.20, random_state=42, shuffle=True)
    tv_labels = cow_y.reindex(tr_all).values
    if _stratify_ok(tv_labels):
        tr_cows, val_cows = train_test_split(
            tr_all, test_size=0.25, stratify=tv_labels, random_state=42
        )
    else:
        tr_cows, val_cows = train_test_split(tr_all, test_size=0.25, random_state=42, shuffle=True)
    test_cows = te_all
    print(f"[2.5 Fallback] Train cows: {len(tr_cows)} | Val cows: {len(val_cows)} | Test cows: {len(test_cows)}")

# --------- Visit-ordered base copy ------------------------------------------
# Order visits within each cow: by the literal column "Day" when present (not TIME_COL),
# otherwise by a synthetic per-cow counter. The `.assign(...)` argument is evaluated on the
# *pre-sort* `base`, so `_visit_idx` follows the row order inherited from `tab` (aligned
# back by index label).
base = tab.copy()
if "Day" in base.columns:
    base = base.sort_values([KEY, "Day"]).reset_index(drop=True)
else:
    base = (
        base.sort_values([KEY])
            .assign(_visit_idx = base.groupby(KEY).cumcount())
            .sort_values([KEY, "_visit_idx"])
            .reset_index(drop=True)
    )

# --------- Select numeric columns for feature engineering --------------------
# Every numeric column not listed in `exclude_cols` becomes a raw feature and gets temporal
# derivatives later. NOTE(review): "Day" (and `_visit_idx` when just created) are numeric and
# not excluded, so they are treated as features too — confirm this is intended.
# Names in `exclude_cols` must match the CSV headers exactly (including the spelling of
# "Previous_Mastits_status"); names not present in `base` are silently ignored.
exclude_cols = {KEY, YCOL, "Cow_ID_norm", "onset_day", "Breed", "Previous_Mastits_status"}
num_cols_all = (
    base.drop(columns=[c for c in exclude_cols if c in base.columns], errors="ignore")
        .select_dtypes(include=[np.number])
        .columns.tolist()
)
if len(num_cols_all) == 0:
    raise RuntimeError("No numeric columns available for feature engineering.")

# --------- Utility functions -------------------------------------------------
def split_by_cows(df, cows):
    """
    Return the rows of ``df`` whose KEY value (compared as a string) appears in
    ``cows``; the result is re-indexed from 0.
    """
    wanted = {str(cow) for cow in cows}
    keep = df[KEY].astype(str).isin(wanted)
    return df.loc[keep].reset_index(drop=True)

def add_time_features(df: pd.DataFrame, num_cols) -> pd.DataFrame:
    """
    Build simple per-cow temporal features (leak-safe):
    - lag1 for each numeric variable,
    - rolling means over 3 and 5 visits,
    - first differences for raw and rolling means,
    - per-cow expanding z-score (normalising by expanding mean/std).
    Missing values introduced by lags/rolling are filled with 0 for robustness.

    Each feature for a visit is computed only from that visit and earlier visits of the
    same cow (shift / trailing rolling / expanding windows grouped by KEY), so nothing
    flows from future visits or from other cows.

    Args:
        df: visit-level frame containing KEY and every column in `num_cols`.
        num_cols: numeric column names to derive features from.

    Returns:
        A new DataFrame (the input is not modified), re-indexed and ordered by KEY then
        "Day" (NaN days last) or `_visit_idx` when "Day" is absent, with these extra
        columns per variable c: c_lag1, c_r3_mean, c_r5_mean, c_d1, c_r3_d1, c_r5_d1,
        c_z_cow. NaNs in all newly created columns are replaced by 0.
    """
    d = df.copy()
    d = coerce_and_dedup_key(d, KEY)
    if "Day" in d.columns:
        d = d.sort_values([KEY, "Day"]).reset_index(drop=True)
    else:
        if "_visit_idx" not in d.columns:
            # Fall back to the incoming row order as the visit order.
            d["_visit_idx"] = d.groupby(KEY).cumcount()
        d = d.sort_values([KEY, "_visit_idx"]).reset_index(drop=True)

    for c in num_cols:
        grp = d.groupby(KEY)[c]
        d[f"{c}_lag1"]    = grp.shift(1)

        # groupby-rolling returns a (KEY, row) MultiIndex; dropping level 0 realigns the result
        # with d's unique, freshly reset row index. min_periods=1 lets early visits average
        # over whatever history exists (current visit included).
        r3 = grp.rolling(3, min_periods=1).mean().reset_index(level=0, drop=True)
        r5 = grp.rolling(5, min_periods=1).mean().reset_index(level=0, drop=True)
        d[f"{c}_r3_mean"] = r3
        d[f"{c}_r5_mean"] = r5

        d[f"{c}_d1"]      = grp.diff(1)
        d[f"{c}_r3_d1"]   = d[f"{c}_r3_mean"].groupby(d[KEY]).diff(1)
        d[f"{c}_r5_d1"]   = d[f"{c}_r5_mean"].groupby(d[KEY]).diff(1)

        # Expanding z-score of the current value vs. this and all earlier visits of the cow.
        # A std of 0 is mapped to NaN to avoid division by zero; the first visit is NaN too
        # (single-sample std is undefined). Both end up as 0 after the fillna below.
        exp_mean = grp.expanding().mean().reset_index(level=0, drop=True)
        exp_std  = grp.expanding().std().reset_index(level=0, drop=True).replace(0, np.nan)
        z = (d[c] - exp_mean) / exp_std
        d[f"{c}_z_cow"] = z.replace([np.inf, -np.inf], np.nan)

    # Every column not present in the input (including `_visit_idx` if created here).
    fe_cols = [c for c in d.columns if c not in df.columns]
    d[fe_cols] = d[fe_cols].fillna(0)
    # Safety: ensure the key column is clean and unique after feature creation.
    d = coerce_and_dedup_key(d, KEY)
    return d

def take_last_k(df, k=6):
    """
    Return only each cow's k most recent visits, re-indexed from 0.

    Recency is ranked on 'Day' when present; otherwise on '_visit_idx', which is
    created from the current row order when missing. Ties are ranked in order of
    appearance (rank method "first"); rows whose ordering value is NaN get no rank
    and are dropped. The input frame is not modified.
    """
    d = coerce_and_dedup_key(df.copy(), KEY)
    if "Day" in d.columns:
        order_col = "Day"
    else:
        order_col = "_visit_idx"
        if order_col not in d.columns:
            d[order_col] = d.groupby(KEY).cumcount()
    d["_rank_last"] = d.groupby(KEY)[order_col].rank(method="first", ascending=False)
    recent = d.loc[d["_rank_last"] <= k].drop(columns=["_rank_last"])
    return coerce_and_dedup_key(recent.reset_index(drop=True), KEY)

def drop_degenerate(train_df, val_df, test_df, key, ycol):
    """
    Remove degenerate feature columns, judged on the TRAIN split only.

    A regular column is kept only if it has a numeric dtype (booleans excluded), is not
    all-NaN on train, and has more than one distinct non-NaN value on train. 'Day',
    '_visit_idx' and 'Cow_ID_norm' are always kept when present. The same selection is
    applied to val/test (columns absent there are skipped). The key and the target are
    always kept, exactly once, as the last two columns: ``[..., ycol, key]``.

    Returns:
        (train, val, test, dropped): the filtered frames and the sorted list of train
        columns that were removed.
    """
    # Ensure exactly one key/target column each.
    train_df = coerce_and_dedup_key(coerce_and_dedup_target(train_df, ycol), key)
    val_df   = coerce_and_dedup_key(coerce_and_dedup_target(val_df,   ycol), key)
    test_df  = coerce_and_dedup_key(coerce_and_dedup_target(test_df,  ycol), key)

    is_num = pd.api.types.is_numeric_dtype
    is_bool = pd.api.types.is_bool_dtype
    always_keep = {"Day", "_visit_idx", "Cow_ID_norm"}

    features = []
    for c in train_df.columns:
        if c in (key, ycol):
            # Appended exactly once below; listing them here too would select them twice
            # and create duplicate key/target columns.
            continue
        if c in always_keep:
            features.append(c)
            continue
        col = train_df[c]
        # Test the real dtype, not its name: a "float"/"int" prefix match would drop
        # unsigned and nullable (Int64/Float64) numeric columns.
        if not is_num(col) or is_bool(col):
            continue
        if col.isna().all() or col.nunique(dropna=True) <= 1:  # all-NaN or zero variance
            continue
        features.append(c)

    def _select(frame):
        """Apply the train-derived selection to one split, with target and key last."""
        cols = [c for c in features if c in frame.columns] + [ycol, key]
        out = frame[cols].copy()
        # Final normalisation for safety (int target, str key).
        return coerce_and_dedup_key(coerce_and_dedup_target(out, ycol), key)

    tr2, va2, te2 = _select(train_df), _select(val_df), _select(test_df)
    dropped = sorted(set(train_df.columns) - set(tr2.columns))
    return tr2, va2, te2, dropped

def safe_target_series(df: pd.DataFrame, ycol: str) -> pd.Series:
    """
    Return the target as one integer Series (unparseable/missing -> 0).

    If ``df[ycol]`` yields several columns (duplicate names), their row-wise max of
    numeric casts is used.
    """
    target = df[ycol]
    if not isinstance(target, pd.DataFrame):
        return pd.to_numeric(target, errors="coerce").fillna(0).astype(int)
    merged = target.apply(pd.to_numeric, errors="coerce").fillna(0).max(axis=1)
    return merged.astype(int)

def count_pos_visits(df, ycol):
    """Return how many rows of ``df`` have a cleaned target (via safe_target_series) equal to 1."""
    return int(safe_target_series(df, ycol).eq(1).sum())

def count_pos_cows(df, key, ycol):
    """
    Return the sum over cows (``key`` groups) of the per-cow maximum cleaned target,
    i.e. the number of cows with at least one positive visit for a 0/1 target.

    Works on a copy; ``df`` itself is left untouched.
    """
    clean = coerce_and_dedup_key(df.copy(), key)
    per_cow_max = safe_target_series(clean, ycol).groupby(clean[key]).max()
    return int(per_cow_max.sum())

# --------- Split-specific (no-leak) feature engineering ---------------------
# Rows are split by cow first and features are then built separately per split, so no
# split's features can depend on another split's rows.
train_raw = split_by_cows(base, tr_cows)
val_raw   = split_by_cows(base, val_cows)
test_raw  = split_by_cows(base, test_cows)

# De-duplicate key/target BEFORE feature engineering.
train_raw = coerce_and_dedup_key(coerce_and_dedup_target(train_raw, YCOL), KEY)
val_raw   = coerce_and_dedup_key(coerce_and_dedup_target(val_raw,   YCOL), KEY)
test_raw  = coerce_and_dedup_key(coerce_and_dedup_target(test_raw,  YCOL), KEY)

train_fe = add_time_features(train_raw, num_cols_all)
val_fe   = add_time_features(val_raw,   num_cols_all)
test_fe  = add_time_features(test_raw,  num_cols_all)

# --------- Keep only the last K visits per cow ------------------------------
# Features were computed on each cow's full history above, so lags/rolling windows of the
# retained visits still draw on the earlier visits that are dropped here.
V_LAST = 6
train_sel = take_last_k(train_fe, V_LAST)
val_sel   = take_last_k(val_fe,   V_LAST)
test_sel  = take_last_k(test_fe,  V_LAST)

# De-duplicate key/target AGAIN (post FE/filters)
train_sel = coerce_and_dedup_key(coerce_and_dedup_target(train_sel, YCOL), KEY)
val_sel   = coerce_and_dedup_key(coerce_and_dedup_target(val_sel,   YCOL), KEY)
test_sel  = coerce_and_dedup_key(coerce_and_dedup_target(test_sel,  YCOL), KEY)

# --------- Drop degenerate columns ------------------------------------------
# Degeneracy is judged on TRAIN only; val/test receive the same column selection.
train_df, val_df, test_df, dropped_cols = drop_degenerate(train_sel, val_sel, test_sel, KEY, YCOL)

# --------- Diagnostics -------------------------------------------------------
print(f"[FE-visit] rows — TRAIN {train_df.shape} | VAL {val_df.shape} | TEST {test_df.shape}")
print(f"[FE-visit] visits+ ({YCOL}) — TR {count_pos_visits(train_df,YCOL)} | VA {count_pos_visits(val_df,YCOL)} | TE {count_pos_visits(test_df,YCOL)}")
print(f"[FE-visit] cows+ (max-per-cow over visits) — TR {count_pos_cows(train_df,KEY,YCOL)} | VA {count_pos_cows(val_df,KEY,YCOL)} | TE {count_pos_cows(test_df,KEY,YCOL)}")
# Feature count excludes the key, the target and the bookkeeping columns.
feat_cnt = len([c for c in train_df.columns if c not in {KEY,YCOL,'Day','_visit_idx','Cow_ID_norm'}])
print(f"[READY] KEY='{KEY}' | YCOL='{YCOL}' | Num features={feat_cnt}")
if dropped_cols:
    print(f"[NOTE] Dropped degenerate columns: {dropped_cols[:10]}{' ...' if len(dropped_cols)>10 else ''}")


In [None]:
# ===== 3) Image Index Builder (YOLO-style or flat) — robust cow-id parsing =====
# Builds df_images with: [stem, image_path, label, Cow_ID_raw, Cow_ID_norm, Cow_ID_match]
# ID heuristics (priority order):
#   1) FLIR#### in the filename stem (FLIR#### / FLIR-#### / FLIR_####)
#   2) #### immediately before `_jpg` in the stem
#   3) parent folder name ending with ####
#   4) first 3–6 digit sequence in the stem (excluding any `.rf.*` suffix)
# Fallback: use all digits from the stem and keep the last 4 as a guess.

import os, re, glob, pandas as pd, numpy as np

def discover_images_labels(image_dir, label_dir, image_exts=(".jpg", ".jpeg", ".png", ".bmp")):
    """
    Scan the image and label directories recursively.

    Image extensions are matched case-insensitively, so exports such as ``.JPG`` or
    ``.PNG`` are not silently skipped on case-sensitive file systems.

    Args:
      image_dir: root folder searched recursively for image files.
      label_dir: root folder searched recursively for ``*.txt`` / ``*.csv`` label files.
      image_exts: accepted image extensions (compared lower-cased). Defaults to the
        formats previously hard-coded.

    Returns:
      images: sorted, de-duplicated list of image file paths
      label_txts: YOLO .txt annotation files (if any)
      label_csvs: optional CSV label files (if any)
    """
    wanted = tuple(ext.lower() for ext in image_exts)
    candidates = glob.glob(os.path.join(image_dir, "**", "*"), recursive=True)
    images = sorted({p for p in candidates if p.lower().endswith(wanted) and os.path.isfile(p)})
    label_txts = glob.glob(os.path.join(label_dir, "**", "*.txt"), recursive=True)
    label_csvs = glob.glob(os.path.join(label_dir, "**", "*.csv"), recursive=True)
    return images, label_txts, label_csvs

def stem_of(path):
    """Return the last path component of ``path`` with only its final extension removed."""
    name = os.path.basename(path)
    stem, _ext = os.path.splitext(name)
    return stem

def parse_yolo_label_file(txt_path):
    """
    Collapse a YOLO annotation file into a binary image label.

    The first token of every non-blank line is read as a class ID. The result is 1
    when any class ID differs from 0, else 0 (an empty file is 0). Any failure while
    reading or parsing (missing file, non-numeric token, ...) also yields 0.
    """
    try:
        with open(txt_path, "r") as fh:
            class_ids = [int(float(line.split()[0])) for line in map(str.strip, fh) if line]
    except Exception:
        return 0
    return int(any(cid != 0 for cid in class_ids))

# ---- Cow ID parsing helpers ----
def digits_only(x: str) -> str:
    """
    Return the decimal digits of ``x`` (stringified) concatenated in order.

    Missing values (anything ``pd.isna`` flags) map to the sentinel string "nan".
    """
    if pd.isna(x):
        return "nan"
    return "".join(ch for ch in str(x) if ch.isdecimal())

def strip_leading_zeros(x: str) -> str:
    """
    Drop leading zeros from a digit string ("0042" -> "42").

    An all-zero string collapses to "0"; "", "nan" and None map to the sentinel "nan".
    """
    if x is None or x in ("nan", ""):
        return "nan"
    return x.lstrip("0") or "0"

def infer_cow_id_from_path(path: str, stem: str) -> str:
    """
    Guess the cow ID (a digit string) for an image; the first matching rule wins.

      1) "FLIR", optionally followed by "_" or "-", then 3–6 digits, anywhere in
         ``stem`` (case-insensitive);
      2) 3–6 digits immediately followed by "_jpg" at a word boundary in ``stem``;
      3) 3–6 digits closing the name of the image's parent folder;
      4) the first run of 3–6 digits in ``stem`` before its first ".rf" (Roboflow suffix).
    Fallback: join every digit of that ".rf"-free stem and return the last four when
    there are at least three digits; otherwise return the string "nan".
    """
    # Rules 1–2 inspect the whole stem, case-insensitively.
    stem_rules = (r'FLIR[_-]?(\d{3,6})', r'(\d{3,6})(?=_jpg\b)')
    for rule in stem_rules:
        found = re.search(rule, stem, re.IGNORECASE)
        if found:
            return found.group(1)

    # Rule 3: digits at the end of the parent folder name.
    folder = os.path.basename(os.path.dirname(path))
    found = re.search(r'(\d{3,6})$', folder)
    if found:
        return found.group(1)

    # Rule 4: ignore everything from the first ".rf" onwards.
    head = stem.split(".rf")[0]
    found = re.search(r'(\d{3,6})', head)
    if found:
        return found.group(1)

    # Last resort: scattered digits, keep the trailing four.
    digits = re.sub(r"\D", "", head)
    return digits[-4:] if len(digits) >= 3 else "nan"

# ---- Scan -------------------------------------------------------------------
images, yolo_txts, label_csvs = discover_images_labels(IMAGE_DIR, LABEL_DIR)

# Map stem -> label
# Labels are matched to images purely by filename stem, so identical stems in different
# folders collide (the entry written last wins).
label_map = {}

# 1) YOLO .txt labels
# Binary per file: 1 if any annotation class id is non-zero (see parse_yolo_label_file).
for p in yolo_txts:
    st = stem_of(p)
    label_map[st] = parse_yolo_label_file(p)

# 2) Optional CSV labels (if present)
# Processed after the YOLO files, so on a stem clash the CSV value overwrites the YOLO one.
# The label column must be named "label"/"labels" (case-insensitive); the name column is the
# first column whose lower-cased name contains file/image/name/stem/path. Any error (e.g. a
# non-integer label) aborts the rest of that CSV with a warning, but rows already read stay.
for csvp in label_csvs:
    try:
        d = pd.read_csv(csvp)
        cols_lower = {c.lower(): c for c in d.columns}
        labcol = cols_lower.get("label") or cols_lower.get("labels")
        namecol = None
        for c in d.columns:
            if any(k in c.lower() for k in ["file","image","name","stem","path"]):
                namecol = c; break
        if labcol and namecol:
            for _, r in d.iterrows():
                st = stem_of(str(r[namecol]))
                label_map[st] = int(r[labcol])
    except Exception as e:
        print("[Labels][WARN] Could not parse CSV:", csvp, "err:", e)

rows = []
for p in images:
    st = stem_of(p)
    lab = label_map.get(st, None)
    if lab is None:
        # Skip images without labels (we only index labelled samples)
        continue
    cow_raw = infer_cow_id_from_path(p, st)
    rows.append({"stem": st, "image_path": p, "label": int(lab), "Cow_ID_raw": cow_raw})

# Build the index with an explicit schema so downstream cells see the documented columns
# even when no labelled image was found.
df_images = pd.DataFrame(rows, columns=["stem", "image_path", "label", "Cow_ID_raw"])

if len(df_images) == 0:
    print("[Images][WARN] No labelled images found. Check folders or labels.")
    df_images["Cow_ID_norm"] = pd.Series(dtype=object)
    df_images["Cow_ID_match"] = pd.Series(dtype=object)
else:
    # Normalise IDs consistently with the tabular side (digits → strip leading zeros).
    df_images["Cow_ID_norm"] = df_images["Cow_ID_raw"].map(digits_only).map(strip_leading_zeros)
    df_images["Cow_ID_match"] = df_images["Cow_ID_norm"]

    print(f"[Images] df_images shape: {df_images.shape}")
    print(df_images.head(10))

    # ---- Diagnostics: check plausible ID lengths (3–6) ----------------------
    lengths = df_images["Cow_ID_match"].replace("nan", np.nan).dropna().map(len)
    print("[Images][Diag] Cow_ID_match length stats:", lengths.describe().to_dict())

    # ---- Diagnostics: overlap with the tabular set ----------------------------
    # Prefer an explicit alignment frame/key if some cell defined one; otherwise fall back
    # to Cell 2's `tab`/`KEY`, which the visible pipeline actually creates (previously the
    # check silently never ran because `df`/`COW_ID_FOR_ALIGNMENT` are not defined there).
    if 'df' in globals() and 'COW_ID_FOR_ALIGNMENT' in globals():
        align_df, align_key = df, COW_ID_FOR_ALIGNMENT
    elif 'tab' in globals() and 'KEY' in globals():
        align_df, align_key = tab, KEY
    else:
        align_df, align_key = None, None

    if align_df is not None and align_key in align_df.columns:
        tab_cows = set(align_df[align_key].astype(str).unique())
        img_cows = set(df_images["Cow_ID_match"].astype(str).unique())
        inter = tab_cows & img_cows
        # Class distribution for images within the overlap
        pos_in_overlap = df_images[df_images["Cow_ID_match"].isin(inter)]["label"].sum()
        print(f"[Overlap] Cows in TAB: {len(tab_cows)} | in IMG: {len(img_cows)} | intersection: {len(inter)} | image-positives in ∩: {pos_in_overlap}")
    else:
        print("[Overlap] Skipped: no tabular frame/key available for alignment.")


In [None]:
# ===== 4) Imaging model — EfficientNet frozen + Augment + TTA + cow-stratified image-only split =====
# Purpose: train an image-only branch using a frozen EfficientNet as a feature extractor.
# We build a cow-stratified TRAIN/VAL/TEST split on images, apply data augmentation for TRAIN,
# optionally oversample the minority class, and use Test-Time Augmentation (TTA) at evaluation.
# Outputs include image-level metrics and per-cow aggregated probabilities.

import torch, torchvision
import torchvision.transforms as T
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from PIL import Image
import numpy as np, pandas as pd
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score, average_precision_score
from sklearn.model_selection import StratifiedShuffleSplit

SEED = 42
rng = np.random.RandomState(SEED)
# Seed torch as well: the TRAIN augmentations, the WeightedRandomSampler and the TTA jitter
# all draw from torch's global RNG (DataLoader workers derive their seeds from it), so without
# this the image branch is not reproducible even though NumPy is seeded.
torch.manual_seed(SEED)
IMG_SIZE = 224
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
PIN_MEM = torch.cuda.is_available()

# --------- Image split configuration -----------------------------------------
IMG_TRAIN_FRAC = 0.60
IMG_VAL_FRAC   = 0.20   # the remaining proportion goes to TEST
K_VIEWS_TRAIN  = 5      # number of augmentation views per TRAIN image
USE_OVERSAMPLING = True # oversample the minority class on TRAIN
TTA_N_VIEWS    = 8      # number of TTA views for VAL/TEST (0 disables TTA)

# Fail fast on fractions that would make the two-stage StratifiedShuffleSplit below divide by
# zero (TRAIN == 1) or request an empty/negative TEST share.
if not (0.0 < IMG_TRAIN_FRAC < 1.0 and IMG_VAL_FRAC > 0.0
        and IMG_TRAIN_FRAC + IMG_VAL_FRAC < 1.0):
    raise ValueError(
        f"Invalid image split fractions: TRAIN={IMG_TRAIN_FRAC}, VAL={IMG_VAL_FRAC} "
        "(need 0 < TRAIN < 1, VAL > 0 and TRAIN + VAL < 1)."
    )

# --------- Transforms ---------------------------------------------------------
# Standard ImageNet mean/std, shared by the TRAIN and evaluation pipelines.
_IMAGENET_NORM = dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

# TRAIN: mild crop / flip / affine / blur jitter, then tensor conversion and normalisation.
train_tf = T.Compose([
    T.RandomResizedCrop(IMG_SIZE, scale=(0.90, 1.00), ratio=(0.98, 1.02)),
    T.RandomHorizontalFlip(p=0.5),
    T.RandomAffine(degrees=7, translate=(0.03, 0.03), scale=(0.98, 1.02)),
    T.GaussianBlur(kernel_size=3, sigma=(0.1, 0.8)),
    T.ToTensor(),
    T.Normalize(**_IMAGENET_NORM),
])

# VAL/TEST: deterministic square resize only.
eval_tf = T.Compose([
    T.Resize((IMG_SIZE, IMG_SIZE)),
    T.ToTensor(),
    T.Normalize(**_IMAGENET_NORM),
])

# --------- Backbone feature extractor ----------------------------------------
# timm supplies the pretrained EfficientNet used below as the frozen feature extractor;
# fail early with an actionable message when it is not installed.
try:
    import timm
except Exception as e:
    raise RuntimeError("Please install timm: pip install timm") from e

class EffNetFeats(nn.Module):
    """
    Pretrained timm EfficientNet used purely as a feature extractor.

    The model is created without a classification head (``num_classes=0``) and with global
    average pooling, so ``forward`` returns pooled features. Every backbone parameter has
    ``requires_grad`` switched off.
    """

    def __init__(self, model_name="efficientnet_b0"):
        super().__init__()
        net = timm.create_model(model_name, pretrained=True, num_classes=0, global_pool="avg")
        for param in net.parameters():
            param.requires_grad = False
        self.backbone = net

    def forward(self, x):
        return self.backbone(x)

# One shared extractor on the selected device, in eval mode (BatchNorm uses running stats,
# Dropout is disabled) since the backbone is never trained in this cell.
feat_net = EffNetFeats("efficientnet_b0").to(device).eval()

# --------- Dataset ------------------------------------------------------------
class ImageDatasetK(Dataset):
    """
    Dataset exposing ``k_views`` copies of every underlying image.

    Index ``idx`` maps to image ``idx % n_images``, so each image (and its label) appears
    ``k_views`` times; with a random ``transform`` each copy is a different augmented view.

    Args:
        df_rows: DataFrame with ``image_path`` and ``label`` columns.
        transform: callable applied to the RGB PIL image.
        k_views: views per image (values below 1 are treated as 1).
    """
    def __init__(self, df_rows, transform, k_views=1):
        self.paths  = df_rows["image_path"].tolist()
        self.labels = df_rows["label"].astype(int).tolist()
        self.tf = transform
        self.k = max(1, int(k_views))

    def __len__(self):
        return len(self.paths) * self.k

    def __getitem__(self, idx):
        i = idx % len(self.paths)
        # Open through a context manager so the file handle is released immediately;
        # convert() returns an independent in-memory copy that outlives the `with`.
        with Image.open(self.paths[i]) as im:
            rgb = im.convert("RGB")
        return self.tf(rgb), self.labels[i]

def extract_features(dloader):
    """
    Embed every batch of ``dloader`` with the frozen backbone (no gradient tracking).

    Returns:
        (X, y): ``X`` stacks the backbone outputs of all batches row-wise and ``y`` concatenates
        the batch labels. For an empty loader ``X`` has shape ``(0, num_features)`` and ``y`` is
        an empty array. Used for the TRAIN/VAL/TEST loaders alike.
    """
    feat_chunks, label_chunks = [], []
    with torch.no_grad():
        for images, labels in dloader:
            emb = feat_net(images.to(device))
            feat_chunks.append(emb.cpu().numpy())
            label_chunks.append(np.array(labels))
    if not feat_chunks:
        return np.zeros((0, feat_net.backbone.num_features)), np.array([])
    return np.vstack(feat_chunks), np.concatenate(label_chunks)

# --------- TTA predictor ------------------------------------------------------
def predict_with_tta(paths, clf, n_views=TTA_N_VIEWS):
    """
    Positive-class probabilities for the images at ``paths``, optionally with Test-Time Augmentation.

    With ``n_views <= 0`` features are extracted once with the deterministic ``eval_tf``.
    Otherwise features are recomputed ``n_views`` times under a light random flip/affine jitter
    and the classifier's probabilities are averaged across the views.

    Args:
        paths: image file paths; the output follows this order.
        clf: fitted classifier exposing ``predict_proba``.
        n_views: number of TTA views (0 disables TTA).

    Returns:
        1-D float array with one probability per path (empty when ``paths`` is empty).
    """
    paths = list(paths)
    if not paths:
        # predict_proba rejects 0-sample input; an empty split simply has no predictions.
        return np.empty(0, dtype=float)
    frame = pd.DataFrame({"image_path": paths, "label": [0] * len(paths)})

    def _probs(transform):
        # One full pass over the images with `transform`, then classifier scores.
        ds = ImageDatasetK(frame, transform, k_views=1)
        dl = DataLoader(ds, batch_size=64, shuffle=False, num_workers=2, pin_memory=PIN_MEM)
        X, _ = extract_features(dl)
        return clf.predict_proba(X)[:, 1]

    if n_views <= 0:
        return _probs(eval_tf)
    # Light, label-preserving augmentations for the evaluation views.
    aug_eval = T.Compose([
        T.Resize((IMG_SIZE, IMG_SIZE)),
        T.RandomHorizontalFlip(p=0.5),
        T.RandomAffine(degrees=3, translate=(0.01, 0.01), scale=(0.995,1.005)),
        T.ToTensor(),
        T.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225]),
    ])
    return np.mean(np.vstack([_probs(aug_eval) for _ in range(n_views)]), axis=0)

# --------- Build an image-only, cow-stratified split -------------------------
if 'df_images' not in globals() or len(df_images)==0:
    print("No df_images available. Skipping image model.")
else:
    # Cow key used for the split; stringified so IDs compare as text.
    # NOTE(review): missing IDs become the string "nan" here and are therefore grouped as one
    # pseudo-cow (all such images land in the same split) — confirm this is intended.
    KEY = "Cow_ID_match" if "Cow_ID_match" in df_images.columns else "Cow_ID_norm"
    dfi = df_images.copy()
    dfi[KEY] = dfi[KEY].astype(str)

    # One row per cow; a cow is positive if any of its images is positive.
    cows = dfi.groupby(KEY)['label'].max().reset_index()
    y_cow = cows['label'].values.astype(int)
    C = cows[KEY].values.astype(str)

    # Cow-level stratified split: TRAIN / (VAL+TEST). Splitting cows rather than images keeps
    # every image of a cow inside a single split.
    sss1 = StratifiedShuffleSplit(n_splits=1, test_size=(1.0-IMG_TRAIN_FRAC), random_state=SEED)
    tr_idx, tmp_idx = next(sss1.split(C, y_cow))
    C_tr, C_tmp = C[tr_idx], C[tmp_idx]; y_tr, y_tmp = y_cow[tr_idx], y_cow[tmp_idx]

    # Split (VAL / TEST) from the temporary pool
    test_frac_rel = (1.0 - IMG_TRAIN_FRAC - IMG_VAL_FRAC) / (1.0 - IMG_TRAIN_FRAC)
    sss2 = StratifiedShuffleSplit(n_splits=1, test_size=test_frac_rel, random_state=SEED)
    va_idx_rel, te_idx_rel = next(sss2.split(C_tmp, y_tmp))
    C_va, C_te = C_tmp[va_idx_rel], C_tmp[te_idx_rel]

    tr_img = dfi[dfi[KEY].isin(set(C_tr))].reset_index(drop=True)
    va_img = dfi[dfi[KEY].isin(set(C_va))].reset_index(drop=True)
    te_img = dfi[dfi[KEY].isin(set(C_te))].reset_index(drop=True)

    print(f"[Imaging|Split image-only] COWS — TRAIN: {len(C_tr)} | VAL: {len(C_va)} | TEST: {len(C_te)}")
    print(f"[Imaging|Split image-only] IMAGES — TRAIN: {len(tr_img)} | VAL: {len(va_img)} | TEST: {len(te_img)}")
    print(f"[Imaging|Class balance] cows TRAIN pos={y_tr.sum()}/{len(y_tr)} | VAL pos={y_tmp[va_idx_rel].sum()}/{len(va_idx_rel)} | TEST pos={y_tmp[te_idx_rel].sum()}/{len(te_idx_rel)}")

    # --------- Dataset + (optional) oversampling ------------------------------
    train_ds = ImageDatasetK(tr_img, transform=train_tf, k_views=K_VIEWS_TRAIN)
    val_ds   = ImageDatasetK(va_img, transform=eval_tf,   k_views=1)
    test_ds  = ImageDatasetK(te_img, transform=eval_tf,   k_views=1)

    # Minority-class oversampling on TRAIN. Dataset item i is base image i % n_base, so the
    # per-item labels are exactly the base labels tiled train_ds.k times.
    y_train_base = tr_img["label"].astype(int).values
    class_counts = np.bincount(y_train_base) if y_train_base.size else np.array([0,0])
    sampler = None
    if USE_OVERSAMPLING and class_counts.size==2 and class_counts.min()>0:
        class_weights = 1.0 / class_counts
        sample_weights = class_weights[np.tile(y_train_base, train_ds.k)]
        sampler = WeightedRandomSampler(weights=torch.from_numpy(sample_weights).float(),
                                        num_samples=len(train_ds), replacement=True)

    # Dataloaders
    tr_dl = DataLoader(train_ds, batch_size=32, shuffle=(sampler is None), sampler=sampler,
                       num_workers=2, pin_memory=PIN_MEM)
    va_dl = DataLoader(val_ds, batch_size=64, shuffle=False, num_workers=2, pin_memory=PIN_MEM)
    te_dl = DataLoader(test_ds, batch_size=64, shuffle=False, num_workers=2, pin_memory=PIN_MEM) if len(te_img)>0 else None

    # --------- LOG: how many effective samples does TRAIN see? ----------------
    print(f"[Imaging|Train loader] items={len(train_ds)}  (base_imgs={len(tr_img)} × K_VIEWS={K_VIEWS_TRAIN})"
          + (f"  | oversampling=ON" if sampler is not None else "  | oversampling=OFF"))

    # --------- Feature extraction --------------------------------------------
    def extract_all(dl):
        """Thin wrapper around extract_features: (features, labels) for one loader."""
        X,y = extract_features(dl); return X,y
    Xtr, ytr = extract_all(tr_dl)
    Xva, yva = extract_all(va_dl)
    if te_dl is not None:
        Xte, yte = extract_all(te_dl)
    else:
        Xte, yte = np.zeros((0, Xtr.shape[1])) if Xtr.size else np.array([]), np.array([])

    # MixUp in feature space: appends n_new convex combinations of random TRAIN pairs. The new
    # label is the mix-weighted label thresholded at 0.5 (a hard label, not a soft target).
    def feature_mixup(X, y, alpha=0.4, n_new=None, rng=rng):
        if X.shape[0] < 2: return X, y
        if n_new is None: n_new = X.shape[0] // 2
        i1 = rng.randint(0, X.shape[0], n_new); i2 = rng.randint(0, X.shape[0], n_new)
        lam = rng.beta(alpha, alpha, size=n_new)[:,None]
        Xn = lam*X[i1] + (1-lam)*X[i2]
        yn = ((lam[:,0]*y[i1] + (1-lam[:,0])*y[i2]) >= 0.5).astype(int)
        return np.vstack([X, Xn]), np.concatenate([y, yn])

    if Xtr.shape[0] > 0 and len(np.unique(ytr))>=2:
        Xtr_aug, ytr_aug = feature_mixup(Xtr, ytr, alpha=0.4, n_new=Xtr.shape[0]//2)
        # NOTE(review): class_weight='balanced' re-weights classes on top of the sampler's
        # oversampling — confirm that this double balancing is intended.
        clf = LogisticRegression(max_iter=4000, class_weight='balanced', solver='lbfgs')
        clf.fit(Xtr_aug, ytr_aug)

        # --- Predictions with/without TTA ------------------------------------
        # With TTA, features are recomputed from the raw images. Without TTA, the eval_tf
        # features extracted above (same transform, same row order) are reused rather than
        # recomputed. Empty splits yield empty arrays (predict_proba rejects 0 samples).
        def _split_probs(img_rows, X_eval):
            if len(img_rows) == 0:
                return np.array([])
            if TTA_N_VIEWS > 0:
                return predict_with_tta(img_rows['image_path'].tolist(), clf, n_views=TTA_N_VIEWS)
            return clf.predict_proba(X_eval)[:,1]

        p_val_img  = _split_probs(va_img, Xva)
        p_test_img = _split_probs(te_img, Xte)

        if len(np.unique(yva))==2:
            print(f"[Imaging] VAL image-level — AUROC={roc_auc_score(yva,p_val_img):.4f} | AUPRC={average_precision_score(yva,p_val_img):.4f} | N={len(yva)}")
        if len(yte)>0 and len(np.unique(yte))==2:
            print(f"[Imaging] TEST image-level — AUROC={roc_auc_score(yte,p_test_img):.4f} | AUPRC={average_precision_score(yte,p_test_img):.4f} | N={len(yte)}")
    else:
        print("[Imaging][WARN] Not enough training images or only one class in training. Skipping classifier.")
        p_val_img = np.array([]); p_test_img = np.array([])

    # --------- Per-cow aggregation (always valid under this split) -----------
    def agg_per_cow(df_rows: pd.DataFrame, probs: np.ndarray, cow_col: str, target_col="label") -> pd.DataFrame:
        """
        Aggregate image-level probabilities to cow level.

        Returns a DataFrame with [cow_col, y, p, n]: y is the max label over the cow's images,
        p the mean probability and n the number of images that have a prediction.
        When ``probs`` is empty (e.g. the classifier was skipped) every cow is still listed with
        its label, p = NaN and n = 0, instead of failing on a length mismatch. A non-empty
        ``probs`` whose length differs from ``df_rows`` still raises.
        """
        tmp = df_rows.copy()
        probs = np.asarray(probs, dtype=float)
        if probs.size == 0 and len(tmp) > 0:
            probs = np.full(len(tmp), np.nan)
        tmp["proba"] = probs
        return tmp.groupby(cow_col).agg(y=(target_col,"max"), p=("proba","mean"), n=("proba","count")).reset_index()

    val_img_cow  = agg_per_cow(va_img, p_val_img,  cow_col=KEY, target_col="label").rename(columns={"p":"p_img"})
    test_img_cow = agg_per_cow(te_img, p_test_img, cow_col=KEY, target_col="label").rename(columns={"p":"p_img"}) if len(te_img)>0 else None

    print(f"[Imaging] Output per-cow — VAL cows: {len(val_img_cow)} | TEST cows: {0 if test_img_cow is None else len(test_img_cow)}")
    print("Images available for training:", bool(len(p_val_img)>0))


In [None]:
# =======================
# Cell 5 — v11 (robust past-only features + LR ⊕ HGB + tuned pooling)
# Objective: increase TEST AUPRC without saturating. Pre-event is applied only in the pooling step.
# =======================
import numpy as np, pandas as pd, warnings, math
from sklearn.compose import ColumnTransformer
from sklearn.pipeline import Pipeline
from sklearn.impute import SimpleImputer
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import HistGradientBoostingClassifier
from sklearn.metrics import roc_auc_score, average_precision_score, brier_score_loss
warnings.filterwarnings("ignore", category=UserWarning)

# Ensure prior cells built the split-specific DataFrames.
# Validate the inputs produced by earlier cells. Explicit raises are used instead of `assert`,
# which is stripped when Python runs with -O.
if not all(nm in globals() for nm in ('train_df', 'val_df', 'test_df')):
    raise RuntimeError("Missing train/val/test.")
KEY = 'Cow_ID_match'
for nm, d in [('train',train_df),('val',val_df),('test',test_df)]:
    if KEY not in d.columns: raise KeyError(f"{nm}_df is missing '{KEY}'")
    if 'class1' not in d.columns: raise KeyError(f"{nm}_df: column 'class1' (0/1) is required.")
    if 'Temperature' not in d.columns: raise KeyError(f"{nm}_df: column 'Temperature' is required.")

SEED = 42
# Candidate time-like columns, in priority order, used to order each cow's visits.
TIME_COLS = ['Day','visit_time','datetime','VisitDate','time']

# ---------- Time helpers ------------------------------------------------------
def _order_series(df):
    """
    Per-row ordering key used to put each cow's visits in chronological order.

    Uses the first column of ``TIME_COLS`` present in ``df``; string columns are parsed with
    ``pd.to_datetime(errors='coerce')`` (unparseable entries become NaT). Without any time-like
    column, falls back to the per-cow visit index (``cumcount`` in the current row order).
    """
    for c in TIME_COLS:
        if c not in df.columns:
            continue
        s = df[c]
        if pd.api.types.is_string_dtype(s):
            try:
                s = pd.to_datetime(s, errors='coerce')
            except (TypeError, ValueError, OverflowError):
                # Keep the raw values if parsing fails outright; errors='coerce' already maps
                # individual bad entries to NaT.
                pass
        return s
    return df.groupby(KEY).cumcount()

def _sort(df):
    """
    Return a copy of ``df`` sorted by cow (``KEY``) and then by visit order.

    The ordering key from ``_order_series`` is kept in an ``_ord_`` column (the visible callers
    drop it), and the result gets a fresh RangeIndex.
    """
    with_order = df.assign(_ord_=_order_series(df))
    return with_order.sort_values(by=[KEY, '_ord_']).reset_index(drop=True)

def _align(df, s, dtype=float):
    """Return a Series aligned to df.index with a guaranteed dtype."""
    return pd.Series(s, index=df.index, dtype=dtype)

# ---------- Robust past-only features ----------------------------------------
def z_past_strict(df, col):
    """
    Per-cow z-score of ``col`` against *past-only* expanding statistics.

    For visit t of a cow: z_t = (x_t - mean_{<t}) / std_{<t}. The past mean needs >= 2 and the
    past std >= 3 earlier observations (NaN otherwise); a zero past std gives NaN, and z is
    clipped to [-5, 5].

    Returns:
        float Series indexed like the frame sorted by cow and visit order (fresh RangeIndex).
        All-NaN when ``col`` is absent.
    """
    df = _sort(df)
    if col not in df.columns:
        return pd.Series(np.nan, index=df.index, dtype=float)
    # Group the numeric-coerced values (not the raw column) so the statistics are computed on
    # exactly the values that enter the z-score.
    x = pd.to_numeric(df[col], errors='coerce')
    g = x.groupby(df[KEY])
    m = g.expanding(min_periods=2).mean().reset_index(level=0, drop=True)
    s = g.expanding(min_periods=3).std(ddof=1).reset_index(level=0, drop=True)
    # Shift by one visit so that the statistics at t only use observations before t.
    mu_prev = m.groupby(df[KEY]).shift(1)
    sd_prev = s.groupby(df[KEY]).shift(1).replace(0, np.nan)
    z = ((x - mu_prev) / sd_prev).clip(lower=-5, upper=5)
    return _align(df, z, float)

def rolling_median_dev_z(df, col, win=3):
    """
    Past-only z-score of the deviation from a per-cow rolling median.

    dev_t = x_t - median(x over the ``win`` visits before t; needs >= 2 of them), then
    ``z_past_strict`` standardises dev against its own past-only statistics.

    Returns:
        float Series indexed like the frame sorted by cow and visit order.
        All-NaN when ``col`` is absent (matching ``z_past_strict``).
    """
    df = _sort(df).drop(columns=['_ord_'])
    if col not in df.columns:
        return pd.Series(np.nan, index=df.index, dtype=float)
    # Roll over the numeric-coerced values so the median and the deviation use the same data.
    x = pd.to_numeric(df[col], errors='coerce')
    med = x.groupby(df[KEY]).rolling(win, min_periods=2).median()
    med.index = med.index.droplevel(0)
    # The median ending at t-1 covers visits t-win .. t-1, i.e. strictly the past.
    med_prev = med.groupby(df[KEY]).shift(1)
    df['__dev__'] = x - med_prev
    return z_past_strict(df, '__dev__')

def slope_last3_prev(df, col):
    """
    Past-only least-squares slope (per visit step) of ``col`` over a cow's three previous visits.

    The value at visit t is the slope fitted to visits t-3, t-2, t-1, using only the finite
    observations in that window (at least two required). Earlier visits, and windows with fewer
    than two finite values, get NaN. Centering on the *finite* time points keeps the estimate
    unbiased when one value is missing.

    Returns:
        float Series indexed like the frame sorted by cow and visit order.
    """
    df = _sort(df)
    x = pd.to_numeric(df[col], errors='coerce').to_numpy(dtype=float)
    out = np.full(len(df), np.nan, dtype=float)
    t_all = np.arange(3, dtype=float)
    # `indices` gives positional row arrays per cow in the sorted frame.
    for idx in df.groupby(KEY).indices.values():
        vals = x[idx]
        sl = np.full(len(vals), np.nan)
        # The window of visits j-3..j-1 is stored at position j (pure past for visit j).
        for j in range(3, len(vals)):
            y = vals[j-3:j]
            ok = np.isfinite(y)
            if ok.sum() < 2:
                continue
            t, yy = t_all[ok], y[ok]
            tc = t - t.mean()
            sl[j] = np.dot(tc, yy - yy.mean()) / (np.dot(tc, tc) + 1e-9)
        out[idx] = sl
    return _align(df, out, float)

def difflag(df, col, k):
    """
    Per-cow k-lag difference x_t - x_{t-k} of the numeric-coerced ``col``.

    Only the current and the k-th previous visit are used (no future values); a cow's first
    ``k`` visits get NaN. Returns a float Series indexed like the frame sorted by cow/visit order.
    """
    df = _sort(df)
    x = pd.to_numeric(df[col], errors='coerce')
    # Shift the coerced values (not the raw column) so both operands are numeric.
    prev = x.groupby(df[KEY]).shift(k)
    return _align(df, x - prev, float)

def seasonal_feats(df):
    """
    Cyclic (sin, cos) encoding of the previous visit's calendar position.

    - With a 'Day' column: day-of-year of the previous visit, angle = 2*pi*doy/365.25.
    - Else with 'Months after giving birth': previous value mod 12, angle = 2*pi*m/12.
    - Else: all-NaN.

    A cow's first visit (no previous visit) and unparseable values stay NaN, so the median
    imputer in the preprocessing step handles them instead of them being encoded as a spurious
    angle of 0 (sin=0, cos=1).

    Returns:
        (season_sin, season_cos) Series indexed like the frame sorted by cow and visit order.
    """
    df = _sort(df)
    if 'Day' in df.columns:
        day = pd.to_datetime(df['Day'], errors='coerce')
        prev = day.dt.dayofyear.groupby(df[KEY]).shift(1)
        ang = 2*np.pi*(prev/365.25)
    elif 'Months after giving birth' in df.columns:
        prev = pd.to_numeric(df['Months after giving birth'], errors='coerce').groupby(df[KEY]).shift(1)
        ang = 2*np.pi*((prev % 12)/12.0)
    else:
        ang = np.nan
    sinv, cosv = np.sin(ang), np.cos(ang)
    return _align(df, sinv, float).rename('season_sin'), _align(df, cosv, float).rename('season_cos')

def build_feats(df):
    """
    Add past-only Temperature features, seasonality and (if present) a months-after-birth z-score.

    The frame is sorted by cow and visit order and given a fresh RangeIndex *before* any feature
    is computed. Every helper returns a Series indexed by exactly that sorted order, so sorting
    up front makes the assignments below land on the right rows. Assigning those Series to an
    unsorted frame would align them by index label and attach features to the wrong visits.

    Returns:
        A new DataFrame, sorted by cow/visit order, with the feature columns added.
    """
    df = _sort(df).drop(columns=['_ord_'])
    df['Temperature_z']      = z_past_strict(df, 'Temperature')
    df['Temp_meddev3_z']     = rolling_median_dev_z(df, 'Temperature', win=3)
    df['Temp_slope3']        = slope_last3_prev(df, 'Temperature')
    df['Temp_diff1']         = difflag(df, 'Temperature', 1)
    df['Temp_diff2']         = difflag(df, 'Temperature', 2)
    ssin, scos               = seasonal_feats(df)
    df['season_sin'], df['season_cos'] = ssin, scos
    if 'Months after giving birth' in df.columns:
        df['Months_after_birth_z'] = z_past_strict(df, 'Months after giving birth')
    return df

# Apply feature construction to each split (no leakage between splits).
# Features are built split by split, so no statistics cross the train/val/test boundary.
train_df = build_feats(train_df)
val_df   = build_feats(val_df)
test_df  = build_feats(test_df)

# ---------- Target: risk_h1 (pre-event ONLY in pooling) ----------------------
def add_risk_h1(df):
    """
    Attach the next-visit target ``risk_h1``.

    After sorting by cow and visit order, ``risk_h1`` at visit t is ``class1`` at the same cow's
    visit t+1; a cow's last visit (no successor) gets 0. Returns a new sorted DataFrame with a
    fresh RangeIndex and without the internal ordering column.
    """
    ordered = _sort(df)
    next_class = ordered.groupby(KEY)['class1'].shift(-1)
    ordered['risk_h1'] = next_class.fillna(0).astype(int)
    return ordered.drop(columns=['_ord_'])
# Replace each split with its risk_h1-labelled, cow/visit-sorted version; risk_h1 is the
# visit-level training target from here on.
train_df = add_risk_h1(train_df)
val_df   = add_risk_h1(val_df)
test_df  = add_risk_h1(test_df)
YCOL = 'risk_h1'

# ---------- Whitelist: keep features that exist and have data in all splits ---
cand = ['Temperature_z','Temp_meddev3_z','Temp_slope3','Temp_diff1','Temp_diff2','season_sin','season_cos','Months_after_birth_z']

def _ok(c):
    """True if feature ``c`` is a column of every split and has >= 1 non-null value in each."""
    splits = (train_df, val_df, test_df)
    present = all(c in s.columns for s in splits)
    return present and all(s[c].notna().sum() > 0 for s in splits)

whitelist = list(filter(_ok, cand))
if len(whitelist) == 0:
    raise RuntimeError("Empty whitelist: please verify generation of z/diff/slope features.")
print(f"[v11] whitelist: {whitelist} | Target={YCOL}")

# ---------- Preprocess --------------------------------------------------------
# Median imputation + standardisation of the whitelisted features, fitted on TRAIN only and
# later applied transform-only to every split by mat(); other columns are dropped.
pre = ColumnTransformer([
    ("num", Pipeline([
        ("imp", SimpleImputer(strategy="median")),
        ("sc",  StandardScaler())
    ]), whitelist)
], remainder='drop', verbose_feature_names_out=True)
pre.fit(train_df[whitelist])

def mat(df):
    """
    Turn a split into model inputs with the TRAIN-fitted preprocessor.

    Returns:
        (keys, X, y): cow keys as strings, the transformed whitelist features, and the
        integer ``YCOL`` target, all in ``df`` row order.
    """
    keys = df[KEY].astype(str).values
    features = pre.transform(df[whitelist])
    labels = df[YCOL].astype(int).values
    return keys, features, labels

# Visit-level design matrices (keys, X, y) for each split.
# NOTE: these names rebind the image-branch Xtr/ytr/Xva/yva/Xte/yte from the imaging cell.
Kv_tr, Xtr, ytr = mat(train_df)
Kv_va, Xva, yva = mat(val_df)
Kv_te, Xte, yte = mat(test_df)

# ---------- Models: LR-EN + HGB ---------------------------------------------
# Inverse-prevalence sample weights: each class carries half of the total TRAIN weight.
# (The 1e-6 floor only guards the positive side; a single-class TRAIN split fails here or in
# the fits below regardless.)
pos_rate = max(1e-6, float((ytr==1).mean()))
w_pos = 0.5/pos_rate; w_neg = 0.5/(1.0-pos_rate)
w_tr  = np.where(ytr==1, w_pos, w_neg)

# Elastic-net logistic regression (saga is the solver that supports penalty='elasticnet').
lr = LogisticRegression(max_iter=4000, solver='saga', penalty='elasticnet', l1_ratio=0.35, C=1.0, random_state=SEED)
lr.fit(Xtr, ytr, sample_weight=w_tr)

# Histogram gradient boosting on the same features and sample weights.
hgb = HistGradientBoostingClassifier(
    learning_rate=0.15, max_leaf_nodes=31, min_samples_leaf=25,
    l2_regularization=0.0, max_depth=None, random_state=SEED
)
hgb.fit(Xtr, ytr, sample_weight=w_tr)

def proba(clf, X):
    """Positive-class scores: ``predict_proba(X)[:, 1]`` when available, else ``decision_function(X)``."""
    if not hasattr(clf, "predict_proba"):
        return clf.decision_function(X)
    return clf.predict_proba(X)[:, 1]

# Visit-level positive-class scores of both models on VAL and TEST (rows follow mat() order).
pva_lr,  pte_lr  = proba(lr,  Xva), proba(lr,  Xte)
pva_hgb, pte_hgb = proba(hgb, Xva), proba(hgb, Xte)

# ---------- Pre-event pooling + hyperparameter tuning on VAL (AUPRC) ---------
def pooling_scores(p_visit, keys, tau, r, topk, jitter=0.006, seed=SEED):
    """
    Pool visit-level probabilities into one score per cow.

    Steps:
      1. temperature scaling: divide the logit by ``tau`` and map back through the sigmoid;
      2. optional Gaussian jitter (std ``jitter``, fixed ``seed``) to break ties, re-clipped;
      3. pre-event exclusion: the last row of each cow (in ``keys`` order) is dropped;
      4. per-cow positive excess over the cow's median, summarised as the average of the
         power mean (exponent ``r``) and the mean of the ``topk`` largest excesses.

    Returns:
        Series indexed by cow key; cows whose only row was dropped are absent.
    """
    eps = 1e-9
    clipped = np.clip(p_visit, eps, 1 - eps)
    logits = np.log(clipped / (1 - clipped)) / tau
    scaled = 1 / (1 + np.exp(-logits))
    if jitter > 0:
        noise = np.random.default_rng(seed).normal(0.0, jitter, size=scaled.shape)
        scaled = np.clip(scaled + noise, eps, 1 - eps)

    visits = pd.DataFrame({'k': keys, 'pt': scaled})
    final_rows = visits.groupby('k').tail(1).index
    visits = visits.loc[~visits.index.isin(final_rows)]

    per_cow_median = visits.groupby('k')['pt'].transform('median')
    excess = (visits['pt'] - per_cow_median).clip(lower=0)
    grouped = excess.groupby(visits['k'])

    def _power_mean(s):
        vals = s.values
        if vals.size == 0:
            return np.nan
        return ((vals ** r).mean()) ** (1.0 / r)

    def _top_k_mean(s):
        vals = np.sort(s.values)
        if vals.size == 0:
            return np.nan
        n_top = min(topk, vals.size)
        return float(vals[-n_top:].mean())

    return (grouped.apply(_power_mean) + grouped.apply(_top_k_mean)) / 2.0

def ranknorm(x):
    """
    Rank-normalise scores to [0, 1]: (rank - 1) / (n - 1), with 0 for a single element.

    Tied values share their average rank. A double argsort would instead give ties distinct,
    arbitrary ranks (the default sort is not stable), e.g. for the many cows whose pooled score
    is the 0.0 fill, which injects noise into the rank ensemble. NaNs rank last.
    """
    x = np.asarray(x, dtype=float)
    ranks = pd.Series(x).rank(method='average', na_option='bottom').to_numpy() - 1.0
    return ranks / max(len(x) - 1, 1)

# Cow-level labels for VAL/TEST (max over visits)
def cow_label(df):
    """
    Cow-level label: max of ``class1`` over each cow's visits, as int.

    Cows are keyed by ``KEY`` cast to ``str`` so the index matches the string keys produced by
    ``mat`` (``Kv_*``). With a non-string key column, a plain groupby would yield non-string
    labels, and ``reindex`` on the string keys would return all-NaN targets.
    """
    return df.groupby(df[KEY].astype(str))['class1'].max().astype(int)
# Sorted cow keys per split and their cow-level labels (keys are strings, as returned by mat()).
Kva = sorted(set(Kv_va), key=str); yva_cow = cow_label(val_df).reindex(Kva).values
Kte = sorted(set(Kv_te), key=str); yte_cow = cow_label(test_df).reindex(Kte).values

taus  = [2.0, 2.3, 2.6]
rs    = [0.7, 0.8, 0.9]
topks = [2, 3, 4]
best = None

# Grid search of the pooling hyperparameters on VAL, maximising cow-level AUPRC of the
# rank ensemble. NOTE: VAL is also used for Platt calibration afterwards, so VAL metrics are
# optimistic; TEST is the held-out estimate.
for tau in taus:
    for r in rs:
        for k in topks:
            # Pool for each model
            va_lr  = pooling_scores(pva_lr,  Kv_va, tau, r, k).reindex(Kva).fillna(0.0).values
            va_hgb = pooling_scores(pva_hgb, Kv_va, tau, r, k).reindex(Kva).fillna(0.0).values
            # Robust rank-ensemble
            va_ens = (ranknorm(va_lr) + ranknorm(va_hgb))/2.0
            try:
                ap = average_precision_score(yva_cow, va_ens)
            except ValueError:
                # Ill-defined for this VAL set (e.g. invalid labels): never select it.
                ap = -np.inf
            if (best is None) or (ap > best[0]):
                best = (ap, tau, r, k, va_lr, va_hgb, va_ens)

# va_ens_b is the ensemble already computed for the winning setting.
ap_best, TAU_B, R_B, K_B, va_lr_b, va_hgb_b, va_ens_b = best

# Compute on TEST with best hyperparameters
te_lr_b  = pooling_scores(pte_lr,  Kv_te, TAU_B, R_B, K_B).reindex(Kte).fillna(0.0).values
te_hgb_b = pooling_scores(pte_hgb, Kv_te, TAU_B, R_B, K_B).reindex(Kte).fillna(0.0).values
te_ens_b = (ranknorm(te_lr_b) + ranknorm(te_hgb_b))/2.0

# ---------- Platt calibration on VAL (ensemble) ------------------------------
# Platt scaling: a 1-D logistic regression fitted on VAL maps the rank-ensemble score to a
# probability, then is applied unchanged to TEST.
# NOTE(review): VAL also selected the pooling hyperparameters, so calibrated VAL metrics are
# optimistic.
from sklearn.linear_model import LogisticRegression as LRCal
cal = LRCal(max_iter=1000, random_state=SEED).fit(va_ens_b.reshape(-1,1), yva_cow.astype(int))
pva_c = cal.predict_proba(va_ens_b.reshape(-1,1))[:,1]
pte_c = cal.predict_proba(te_ens_b.reshape(-1,1))[:,1]

# ---------- Metrics ----------------------------------------------------------
def metr(name, y, p):
    """
    Summary metrics for one split.

    Probabilities are clipped to [1e-9, 1-1e-9] first. AUROC is NaN when ``roc_auc_score``
    reports it as undefined (ValueError, e.g. only one class present in ``y``).

    Returns:
        dict with keys name, AUROC, AUPRC, Brier, N.
    """
    p = np.clip(p, 1e-9, 1-1e-9)
    try:
        auc = roc_auc_score(y, p)
    except ValueError:
        auc = np.nan
    ap = average_precision_score(y, p)
    br = brier_score_loss(y, p)
    return dict(name=name, AUROC=auc, AUPRC=ap, Brier=br, N=len(y))

# One row per split: cow-level labels vs the calibrated ensemble scores.
res = pd.DataFrame([
    metr(f"VAL v11 (tau={TAU_B}, r={R_B}, K={K_B})", yva_cow, pva_c),
    metr(f"TEST v11 (tau={TAU_B}, r={R_B}, K={K_B})", yte_cow, pte_c),
])
print("\n=== Summary v11 (robust features + LR⊕HGB + tuned pooling) ===")
print(res[["name","AUROC","AUPRC","Brier","N"]].to_string(index=False))

# ---------- Audit ------------------------------------------------------------
# Log the configuration actually used, for reproducibility of the reported numbers.
print(f"\n[v11 Audit] whitelist={whitelist}")
print(f"[v11 Audit] rows (train/val/test): {len(train_df)}/{len(val_df)}/{len(test_df)}")
print(f"[v11 Audit] pooling*  tau={TAU_B}, r={R_B}, topK={K_B}  (optimised on VAL for AUPRC)")


In [None]:
# =======================
# Cell 6 — v6.8 (YOLO .txt, GPU/AMP, multimodal-ready, bootstrap & figures/tables)
# - Imaging pipeline with robust fallback (images-only) and tabular fusion when available
# - Save figures & tables to PROJECT_DIR/figures_and_tables (overwrite guaranteed)
# =======================
import os, re, glob, time, json, warnings, random, sys
warnings.filterwarnings("ignore", category=UserWarning)
import numpy as np
import pandas as pd
from tqdm.auto import tqdm

# Torch / Vision
import torch, torchvision
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

# ML
from sklearn.linear_model import LogisticRegression, LogisticRegression as LRCal
from sklearn.metrics import (average_precision_score, roc_auc_score, brier_score_loss,
                             roc_curve, precision_recall_curve, confusion_matrix)
from sklearn.model_selection import StratifiedShuffleSplit

# Plot
import matplotlib.pyplot as plt

# ===== CONFIG =====
SEED    = 42
DEBUG   = True

# Fixed PATHS (per your setup). When Cell 1 has not defined PROJECT_DIR, mirror its logic:
# Google Drive under Colab, the current working directory elsewhere.
if 'PROJECT_DIR' not in globals():
    _base_dir = "/content/drive/MyDrive" if 'google.colab' in sys.modules else os.getcwd()
    PROJECT_DIR = os.path.join(_base_dir, "Mastitis_illness_cow", "datasets")
IMAGE_DIR = os.path.join(PROJECT_DIR, "images")
LABEL_DIR = os.path.join(PROJECT_DIR, "labels")

# Quick outputs go to Colab's local disk. Outside Colab '/content' usually does not exist and
# the filesystem root is often not writable, so use a folder in the working directory instead.
SAVE_DIR   = ("/content/mastitis_outputs" if 'google.colab' in sys.modules
              else os.path.join(os.getcwd(), "mastitis_outputs"))
FIGDIR     = os.path.join(PROJECT_DIR, "figures_and_tables")  # for the paper
os.makedirs(SAVE_DIR, exist_ok=True)
os.makedirs(FIGDIR, exist_ok=True)

# YOLO: which class_id means “mastitis positive”
POSITIVE_CLASS_ID = 1         # <-- change if needed (e.g., 0)
IMG_EXTS = {".png",".jpg",".jpeg",".bmp",".tif",".tiff"}

print("[MOUNT] IN_COLAB:", 'google.colab' in sys.modules)
print("[PATHS] PROJECT_DIR:", PROJECT_DIR)
print("[PATHS] IMAGE_DIR exists:", os.path.isdir(IMAGE_DIR), "| LABEL_DIR exists:", os.path.isdir(LABEL_DIR))

# ===== ENV =====
# Seed the Python, NumPy and torch RNGs, then choose the compute device.
random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED)
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"[ENV] torch={torch.__version__} | torchvision={torchvision.__version__} | device={DEVICE}", flush=True)

# ===== Tabular (optional, for real per-cow fusion) =====
# Tabular fusion is possible only if earlier cells built train/val/test frames that all carry
# the cow key and the class1 label; otherwise the cell runs image-only.
tab_ok = ('train_df' in globals()) and ('val_df' in globals()) and ('test_df' in globals())
KEY = 'Cow_ID_match'
if tab_ok:
    for nm, d in [('train',train_df),('val',val_df),('test',test_df)]:
        if KEY not in d.columns or 'class1' not in d.columns:
            tab_ok = False
            print(f"[WARN] {nm}_df is missing '{KEY}' or 'class1' → using image-only fallback.", flush=True)
            break

# ===== 1) Index images =====
print("[Index] Scanning images…", flush=True)
stem2path = {}
n_dup_stems = 0
for root, dirs, files in os.walk(IMAGE_DIR):
    # Sort in place so the traversal, and therefore which file wins on a duplicate stem, is
    # deterministic across filesystems.
    dirs.sort()
    for f in sorted(files):
        stem, ext = os.path.splitext(f)
        if ext.lower() in IMG_EXTS:
            if stem in stem2path:
                n_dup_stems += 1
            stem2path[stem] = os.path.join(root, f)
print(f"[Index] Indexed images: {len(stem2path)}")
if n_dup_stems:
    # Labels are matched to images by stem, so duplicates cannot all be used.
    print(f"[WARN] {n_dup_stems} image(s) share a filename stem with another image; "
          "only the last one in sorted walk order is kept.", flush=True)

# ===== 2) Read YOLO labels (.txt one per image) =====
def parse_yolo_txt(txt_path, positive_class_id=None):
    """
    Return True if any annotation row of a YOLO label file has the positive class id.

    Typical row format: ``<class_id> <cx> <cy> <w> <h>`` (normalised). Blank lines and rows whose
    first token is not numeric are skipped.

    Args:
        txt_path: path to the .txt label file.
        positive_class_id: class id counted as positive; defaults to the global POSITIVE_CLASS_ID.

    Returns:
        bool. An unreadable file is reported with a warning and counts as having no positive row.
    """
    if positive_class_id is None:
        positive_class_id = POSITIVE_CLASS_ID
    try:
        with open(txt_path, "r") as f:
            for line in f:
                parts = line.split()
                if not parts:
                    continue
                try:
                    cls = int(float(parts[0]))
                except (ValueError, OverflowError):
                    continue
                if cls == positive_class_id:
                    return True
    except (OSError, UnicodeDecodeError) as e:
        # Best effort: keep going, but make the silent "negative" fallback visible.
        print(f"[WARN] Could not read label file {txt_path}: {e}", flush=True)
    return False

# Only top-level *.txt files in LABEL_DIR are read (no recursion), whereas images were indexed
# recursively above.
txt_files = sorted([p for p in glob.glob(os.path.join(LABEL_DIR, "*.txt")) if os.path.isfile(p)])
if not txt_files:
    raise RuntimeError(f"No YOLO .txt found in {LABEL_DIR}.")

# One (stem, has-positive-row) record per label file.
records = []
for p in tqdm(txt_files, desc="Parse YOLO labels", mininterval=0.1):
    stem = os.path.splitext(os.path.basename(p))[0]
    pos = parse_yolo_txt(p)
    records.append((stem, pos))

lab_img = pd.DataFrame(records, columns=["stem", "pos"])
lab_img['class1'] = lab_img['pos'].astype(int)
lab_img.drop(columns=['pos'], inplace=True)

# resolve image path for stem
# Label files without a matching image stem are dropped; images without a label file never
# enter lab_img.
lab_img['abs_path'] = lab_img['stem'].map(stem2path)
lab_img = lab_img[lab_img['abs_path'].notna()].reset_index(drop=True)
lab_img['filename'] = lab_img['abs_path'].apply(os.path.basename)
print(f"[Labels] Images with resolved labels: {len(lab_img)}")
if len(lab_img) == 0:
    raise RuntimeError("No images resolved from .txt files: ensure .txt stems match images/ filenames.")

# ===== 3) Try to extract Cow_ID from filename (customise as needed) =====
def extract_cow_from_filename(fname):
    """
    Best-effort cow ID from an image filename, normalised like the other cow IDs.

    CUSTOMISE this if you know how to map file → cow.
    Recognised patterns (case-insensitive, tried in this order): 'cow123' / 'COW_045', then
    'vacca-12'. Leading zeros are stripped ('045' -> '45', '000' -> '0'), matching the
    digits/strip-leading-zeros normalisation applied to Cow_ID_norm for the image table.

    Returns:
        The ID as a string, or None if no pattern matches.
    """
    base = os.path.basename(fname)
    # re.I already covers 'COW', so a separate upper-case pattern is unnecessary.
    for pat in (r'cow[_-]?(\d+)', r'vacca[_-]?(\d+)'):
        m = re.search(pat, base, re.I)
        if m:
            return m.group(1).lstrip('0') or '0'
    return None

lab_img[KEY] = lab_img['filename'].apply(extract_cow_from_filename)

# ===== 4) Decide split: per-cow (if mappable) or image-level fallback =====
use_tab_split = tab_ok and lab_img[KEY].notna().any()
if use_tab_split:
    # Reuse the tabular cow split so each image follows its cow into TRAIN/VAL/TEST.
    cows_tr = set(train_df[KEY].astype(str))
    cows_va = set(val_df[KEY].astype(str))
    cows_te = set(test_df[KEY].astype(str))
    lab_tr = lab_img[lab_img[KEY].isin(cows_tr)].copy()
    lab_va = lab_img[lab_img[KEY].isin(cows_va)].copy()
    lab_te = lab_img[lab_img[KEY].isin(cows_te)].copy()
    print(f"[Align] images per split (by cow): train={len(lab_tr)} | val={len(lab_va)} | test={len(lab_te)}", flush=True)
    if min(len(lab_tr), len(lab_va), len(lab_te)) == 0:
        print("[WARN] Few matches with cow_id → falling back to image-level (stratified) split.", flush=True)
        use_tab_split = False

if not use_tab_split:
    print("[Fallback] Image-level split (stratified on class1). No tabular fusion possible.", flush=True)
    # 70% TRAIN; the remaining 30% is halved into VAL / TEST, both stratified on class1.
    df_all = lab_img.copy()
    y_all = df_all['class1'].values
    sss1 = StratifiedShuffleSplit(n_splits=1, test_size=0.30, random_state=SEED)
    tr_idx, tm_idx = next(sss1.split(np.zeros(len(y_all)), y_all))
    df_tr = df_all.iloc[tr_idx].reset_index(drop=True)
    df_tm = df_all.iloc[tm_idx].reset_index(drop=True)
    sss2 = StratifiedShuffleSplit(n_splits=1, test_size=0.50, random_state=SEED)
    va_idx, te_idx = next(sss2.split(np.zeros(len(df_tm)), df_tm['class1'].values))
    lab_tr = df_tr
    lab_va = df_tm.iloc[va_idx].reset_index(drop=True)
    lab_te = df_tm.iloc[te_idx].reset_index(drop=True)
    # synthetic KEY = stem (1 image = 1 “cow”)
    # NOTE: if several images show the same animal, this split can place it in more than one
    # of TRAIN/VAL/TEST, so fallback metrics may be optimistic.
    # The loop variable is deliberately not named `df`: that global holds the notebook's
    # tabular DataFrame (read by earlier cells) and would be overwritten by this loop.
    for part in (lab_tr, lab_va, lab_te):
        part[KEY] = part['stem']

print(f"[IMG rows] train={len(lab_tr)} | val={len(lab_va)} | test={len(lab_te)}")

# ===== 5) Dataset & DataLoader =====
# PATCH: no ToTensor (we already read as Tensor). Convert dtype and normalise.
img_size = 224
# Steps for tensors decoded by torchvision.io (no ToTensor needed):
_tfm_steps = [
    transforms.ConvertImageDtype(torch.float32),    # integer tensor -> float32 scaled to [0, 1]
    transforms.Resize((img_size, img_size)),
    transforms.Grayscale(num_output_channels=3),    # luminance replicated to 3 channels
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.25, 0.25, 0.25]),
]
tfm = transforms.Compose(_tfm_steps)

class ImgDS(Dataset):
    """Map-style dataset over an image-label DataFrame.

    Each item is ``(image, label, key)``:
      * image: tensor read from the ``abs_path`` column, passed through the module-level ``tfm``;
      * label: ``int(row['class1'])``;
      * key:   ``str(row[KEY])`` (cow id, or the synthetic stem in the image-level fallback).
    """

    def __init__(self, df):
        # Reset the index so positional .iloc access lines up with __len__.
        self.df = df.reset_index(drop=True)

    def __len__(self):
        return len(self.df)

    @staticmethod
    def _to_three_channels(img):
        """Return the C x H x W image tensor with exactly 3 channels.

        ``transforms.Grayscale`` accepts only 1- or 3-channel tensors, so an alpha channel
        (gray+alpha or RGBA) is dropped first. Single-channel images are replicated to 3
        channels up front.
        NOTE(review): some torchvision versions return 1 channel from
        Grayscale(num_output_channels=3) on a 1-channel tensor, which then breaks the
        3-value Normalize. Replicating here avoids depending on that behaviour.
        """
        if img.shape[0] in (2, 4):          # gray+alpha / RGBA -> drop the alpha channel
            img = img[:-1]
        if img.shape[0] == 1:               # gray -> RGB
            img = img.expand(3, -1, -1)
        return img

    def __getitem__(self, i):
        r = self.df.iloc[i]
        img = torchvision.io.read_image(r['abs_path'])  # Tensor CxHxW uint8/uint16
        img = tfm(self._to_three_channels(img))          # -> float32 normalised, 3 channels
        return img, int(r['class1']), str(r[KEY])

# SPEED PATCH (GPU): enable cuDNN autotuning and faster float32 matmuls when CUDA is present.
from torch.cuda import amp  # NOTE(review): not referenced in this chunk; autocast below uses torch.amp

if torch.cuda.is_available():
    torch.backends.cudnn.benchmark = True
    try:
        torch.set_float32_matmul_precision("high")
    except Exception:
        # Best effort only; presumably guards torch builds without this API.
        pass
    device_name = torch.cuda.get_device_name(0)
    print(f"[GPU] {device_name} | cap={torch.cuda.get_device_capability(0)}", flush=True)

# DataLoader knobs
BATCH = 256        # lower to 192/128 if OOM
NUM_WORKERS = 6    # set 0 if you get I/O issues
PREFETCH = 4       # batches prefetched per worker

def make_loader(df, shuffle):
    """Build a DataLoader over ``ImgDS(df)`` using the module-level BATCH / NUM_WORKERS / PREFETCH.

    Worker-only options (``persistent_workers``, ``prefetch_factor``) are passed only when
    NUM_WORKERS > 0. Older torch releases reject a non-default ``prefetch_factor``
    (including None) together with ``num_workers=0``. Pinned host memory only speeds up
    host->GPU copies, so it is requested only when CUDA is available.
    """
    loader_kwargs = dict(
        batch_size=BATCH,
        shuffle=shuffle,
        num_workers=NUM_WORKERS,
        pin_memory=torch.cuda.is_available(),
    )
    if NUM_WORKERS > 0:
        loader_kwargs.update(persistent_workers=True, prefetch_factor=PREFETCH)
    return DataLoader(ImgDS(df), **loader_kwargs)

# Only the training loader shuffles; val/test keep row order so keys/scores stay aligned with lab_va/lab_te.
dl_tr, dl_va, dl_te = (
    make_loader(lab_tr, True),
    make_loader(lab_va, False),
    make_loader(lab_te, False),
)

# ===== 6) Backbone + embedding extraction =====
# Pretrained ResNet18 used as a frozen feature extractor: the classification head is
# replaced by Identity, so forward() returns the feat_dim-wide features fed to `fc`.
backbone = torchvision.models.resnet18(weights=torchvision.models.ResNet18_Weights.DEFAULT)
feat_dim = backbone.fc.in_features
backbone.fc = nn.Identity()
backbone.eval().to(DEVICE)
backbone.requires_grad_(False)
print(f"[Backbone] ResNet18 feat_dim={feat_dim}", flush=True)

# Mixed precision only on CUDA.
use_amp = torch.cuda.is_available()

@torch.no_grad()
def extract_embeddings(dloader, desc):
    """Run the frozen ``backbone`` over every batch of ``dloader``.

    Returns ``(X, y, k)``:
      * X: stacked embeddings, one row per image (``(0, feat_dim)`` float32 when empty);
      * y: stacked labels (``(0,)`` int32 when empty);
      * k: object array of the per-image keys.
    Prints the throughput under ``desc``. Uses fp16 autocast when ``use_amp`` is set.
    """
    feats, labels, keys = [], [], []
    t0 = time.time()
    for batch_imgs, batch_y, batch_k in tqdm(dloader, desc=desc, mininterval=0.1, leave=True):
        batch_imgs = batch_imgs.to(DEVICE, non_blocking=True)
        if not use_amp:
            out = backbone(batch_imgs).detach().cpu().numpy()
        else:
            # New autocast API
            with torch.amp.autocast("cuda", dtype=torch.float16):
                out = backbone(batch_imgs)
            out = out.float().detach().cpu().numpy()
            torch.cuda.synchronize()
        feats.append(out)
        labels.append(batch_y.numpy())
        keys.extend(batch_k)
    dt = time.time() - t0
    n = sum(f.shape[0] for f in feats) if feats else 0
    print(f"[TIMING] {desc}: {n} img in {dt:.2f}s → {n/max(dt,1e-9):.1f} img/s", flush=True)
    X = np.concatenate(feats, axis=0) if feats else np.zeros((0, feat_dim), dtype=np.float32)
    y = np.concatenate(labels, axis=0) if labels else np.zeros((0,), dtype=np.int32)
    k = np.array(keys, dtype=object)
    return X, y, k

Xtr_i, ytr_i, Ktr_i = extract_embeddings(dl_tr, "Emb TR")
Xva_i, yva_i, Kva_i = extract_embeddings(dl_va, "Emb VA")
Xte_i, yte_i, Kte_i = extract_embeddings(dl_te, "Emb TE")
print(f"[Emb] TR={Xtr_i.shape} VA={Xva_i.shape} TE={Xte_i.shape}", flush=True)

# ===== 7) LR on embeddings =====
# Class-balanced sample weights: positives and negatives each carry half of the total weight.
# The positive rate is clamped on BOTH sides. A train split with no negatives would otherwise
# divide by zero in w_neg; with the clamp, the single-class problem is left to fit() to report.
pos_rate = min(max(1e-6, float((ytr_i==1).mean())), 1.0 - 1e-6)
w_pos = 0.5/pos_rate; w_neg = 0.5/(1.0-pos_rate)
w_tr  = np.where(ytr_i==1, w_pos, w_neg)

clf_i = LogisticRegression(max_iter=3000, solver='lbfgs', C=1.0, n_jobs=-1)
clf_i.fit(Xtr_i, ytr_i, sample_weight=w_tr)
# Per-image probability of the positive class (column 1).
pva_img_v = clf_i.predict_proba(Xva_i)[:,1]
pte_img_v = clf_i.predict_proba(Xte_i)[:,1]

# ===== 8) Per-cow pooling when possible, otherwise per-image identity =====
def logistic_temp(p, tau):
    """Temperature-scale probabilities in logit space: ``sigmoid(logit(p) / tau)``.

    ``p`` is clipped to ``[1e-9, 1 - 1e-9]`` so the logit stays finite. ``tau > 1`` pulls
    values toward 0.5; ``tau < 1`` pushes them toward 0 / 1.
    """
    q = np.clip(p, 1e-9, 1-1e-9)
    scaled_logit = np.log(q/(1-q))/tau
    return 1/(1+np.exp(-scaled_logit))

def pre_event_mask(keys):
    """Boolean keep-mask over ``keys`` (one entry per row).

    In cow-split mode (``use_tab_split``) the last row of every key is dropped, where "last"
    means the last occurrence in the given order. Otherwise every row is kept.
    """
    keep = np.ones(len(keys), dtype=bool)
    if use_tab_split:
        last_rows = pd.DataFrame({'k': keys}).groupby('k').tail(1).index
        keep[last_rows.to_numpy()] = False
    return keep

def pooling_scores(p_visit, keys, tau, r, topk, jitter=0.006, seed=SEED):
    """Pool per-image probabilities into one score per key.

    Args:
        p_visit: per-image probabilities, row-aligned with ``keys``.
        keys: per-image group key (cow id in cow-split mode).
        tau: temperature for ``logistic_temp`` (logit divided by ``tau``).
        r: exponent of the power mean over within-key excesses.
        topk: number of largest excesses averaged by the top-k term.
        jitter: std of Gaussian noise added after temperature scaling (0 disables it).
        seed: seed of the jitter RNG; a fresh generator is built per call, so equal
            seeds give identical noise.

    Returns:
        pd.Series indexed by key. Without ``use_tab_split`` this is just ``p_visit``
        indexed by ``keys`` (no pooling). Otherwise each key scores
        ``(power_mean + topk_mean) / 2`` of the positive excess of its rows over the key's
        own median. An empty Series is returned when ``pre_event_mask`` leaves no rows.

    Note:
        The score measures spread above the key's median, not the absolute level, so a
        key with a single remaining row always scores 0. Keys whose rows were all
        removed by ``pre_event_mask`` are absent from the result.
    """
    if not use_tab_split:
        return pd.Series(p_visit, index=pd.Index(keys, name='k'))
    pt = logistic_temp(p_visit, tau)
    if jitter>0:
        # Small seeded noise, clipped back into (0, 1); presumably a tie-breaker — TODO confirm intent.
        rng = np.random.default_rng(seed)
        pt = np.clip(pt + rng.normal(0.0, jitter, size=pt.shape), 1e-9, 1-1e-9)
    dfp = pd.DataFrame({'k':keys, 'pt':pt})
    # Drop the rows flagged by pre_event_mask (the last row of each key).
    mask = pre_event_mask(keys); dfp = dfp[mask]
    if len(dfp)==0: return pd.Series(dtype=float)
    # Positive excess of each row over the median of its key.
    med = dfp.groupby('k')['pt'].transform('median')
    exc = (dfp['pt'] - med).clip(lower=0)
    def pmean(x, rr):
        # Power (generalised) mean with exponent rr; NaN for an empty group.
        xv = x.values
        return (((xv**rr).mean())**(1.0/rr)) if xv.size>0 else np.nan
    def topk_mean(x, kk):
        # Mean of the kk largest values (all of them when fewer than kk).
        xv = np.sort(x.values)
        if xv.size==0: return np.nan
        kk = min(kk, xv.size)
        return float(xv[-kk:].mean())
    pm = exc.groupby(dfp['k']).apply(lambda s: pmean(s, r))
    tk = exc.groupby(dfp['k']).apply(lambda s: topk_mean(s, topk))
    # Equal-weight average of the two pooled statistics.
    return (pm + tk)/2.0

# y per cow (if tab) or per image (fallback)
if use_tab_split:
    # Cow-level labels come from the tabular splits: a cow is positive if any visit is.
    yva_cow = val_df.groupby(KEY)['class1'].max().astype(int)
    yte_cow = test_df.groupby(KEY)['class1'].max().astype(int)
    # Image keys (Kva_i / Kte_i) are str because ImgDS returns str(r[KEY]), while this cow
    # index keeps the tabular dtype. Align the pooled scores on the str form so that a
    # non-str cow id cannot silently miss every cow and collapse to the 0.0 fill.
    _va_cows_str = yva_cow.index.astype(str)
    _te_cows_str = yte_cow.index.astype(str)
    # Grid over temperature / power-mean exponent / top-K, selected by AP on VAL.
    taus, rs, topks = [2.0,2.6,3.0], [0.7,0.9], [2,3,4]
    best_img = None
    keys_va = Kva_i
    for (tau, r, k) in tqdm([(t,r,kk) for t in taus for r in rs for kk in topks], desc="Tune pooling (IMG)"):
        pva_img_c = pooling_scores(pva_img_v, keys_va, tau, r, k)
        # Cows without a pooled score (no surviving images) get 0.0.
        pva_img_c = pva_img_c.reindex(_va_cows_str).fillna(0.0).values
        ap = average_precision_score(yva_cow.values.astype(int), pva_img_c)
        if (best_img is None) or (ap > best_img[0]):
            best_img = (ap, tau, r, k, pva_img_c)
    ap_img, TAU_I, R_I, K_I, pva_img_c = best_img
    pte_img_c = pooling_scores(pte_img_v, Kte_i, TAU_I, R_I, K_I).reindex(_te_cows_str).fillna(0.0).values
    print(f"[Tune IMG] AP(VAL)={ap_img:.4f} tau={TAU_I}, r={R_I}, K={K_I}", flush=True)
else:
    # Fallback: every image is its own "cow", so scores are already one per key.
    yva_cow = pd.Series(yva_i, index=Kva_i)
    yte_cow = pd.Series(yte_i, index=Kte_i)
    pva_img_c = pd.Series(pva_img_v, index=Kva_i).reindex(yva_cow.index).fillna(0.0).values
    pte_img_c = pd.Series(pte_img_v, index=Kte_i).reindex(yte_cow.index).fillna(0.0).values

# Platt calibration for images: a 1-D logistic regression fitted on VAL scores only,
# then used to map both VAL and TEST scores to probabilities.
cal_img = LRCal(max_iter=1000, random_state=SEED)
cal_img.fit(pva_img_c.reshape(-1, 1), yva_cow.values.astype(int))
pva_img_cal, pte_img_cal = (
    cal_img.predict_proba(scores.reshape(-1, 1))[:, 1] for scores in (pva_img_c, pte_img_c)
)

# ===== 9) Fusion with tabular (only if we have cow mapping and pva_c/pte_c) =====
# Tabular scores are usable only when images were split by cow AND an earlier cell left
# pva_c / pte_c (tabular VAL / TEST scores) in the global namespace.
tab_ready = use_tab_split and ('pva_c' in globals()) and ('pte_c' in globals())
if not tab_ready:
    print("[Fusion] Images-only (missing cow mapping or pva_c/pte_c).", flush=True)
    # Placeholder all-zero tabular channel shaped like the calibrated image scores.
    pva_tab = np.zeros_like(pva_img_cal); pte_tab = np.zeros_like(pte_img_cal)
else:
    # NOTE(review): if pva_c/pte_c are plain arrays, pd.Series(..., index=yva_cow.index) aligns
    # them POSITIONALLY with yva_cow/yte_cow (assumes the earlier cell used the same cow order).
    # If they are pandas Series, the constructor aligns by label instead. The trailing
    # reindex on the same index is then a no-op. Confirm which type the earlier cell produces.
    pva_tab = pd.Series(pva_c, index=yva_cow.index).reindex(yva_cow.index).fillna(0.0).values
    pte_tab = pd.Series(pte_c, index=yte_cow.index).reindex(yte_cow.index).fillna(0.0).values

def ranknorm(x):
    """Map scores to normalised ranks in [0, 1] (lowest -> 0, highest -> 1).

    Ties get the average of the ranks they span. A constant input (e.g. the all-zero
    placeholder tabular channel) therefore maps to a constant 0.5, not an arbitrary
    0..1 ramp set by sort order. For distinct values the result is the 0-based rank
    divided by ``len(x) - 1``. NaNs rank last.
    """
    ranks = pd.Series(np.asarray(x, dtype=float)).rank(method="average", na_option="bottom").to_numpy()
    return (ranks - 1.0) / max(len(ranks) - 1, 1)

# Late fusion on ranks: score = w*rank(tab) + (1-w)*rank(img). Pick w by AP on VAL
# (strict '>' so the earliest, smallest w wins ties).
weights = [0.0, 0.25, 0.5, 0.75, 1.0]
best = None
for w in tqdm(weights, desc="Tune fusion weight", mininterval=0.1, leave=True):
    va_f = w * ranknorm(pva_tab) + (1 - w) * ranknorm(pva_img_cal)
    ap = average_precision_score(yva_cow.values.astype(int), va_f)
    improved = (best is None) or (ap > best[0])
    if improved:
        best = (ap, w, va_f)
ap_fuse, W, va_fused = best
te_fused = W * ranknorm(pte_tab) + (1 - W) * ranknorm(pte_img_cal)

# Platt-calibrate the fused rank score on VAL and apply the same mapping to TEST.
cal_f = LRCal(max_iter=1000, random_state=SEED)
cal_f.fit(va_fused.reshape(-1, 1), yva_cow.values.astype(int))
pva_f, pte_f = (cal_f.predict_proba(s.reshape(-1, 1))[:, 1] for s in (va_fused, te_fused))

# ===== 10) Metrics & utilities =====
def metr(name, y, p):
    """Return a summary row ``dict(name, AUROC, AUPRC, Brier, N)`` for labels ``y`` and scores ``p``.

    Scores are clipped to ``[1e-9, 1 - 1e-9]``. AUROC is NaN when it cannot be computed
    (``roc_auc_score`` raises ValueError, e.g. when ``y`` contains a single class).
    """
    p = np.clip(p, 1e-9, 1-1e-9)
    try:
        auc = roc_auc_score(y, p)
    except ValueError:
        # Only the expected "undefined AUROC" failure is tolerated; other errors propagate.
        auc = np.nan
    ap = average_precision_score(y, p)
    br = brier_score_loss(y, p)
    return dict(name=name, AUROC=float(auc) if auc==auc else np.nan, AUPRC=float(ap), Brier=float(br), N=int(len(y)))

# Summary table: image-only, tabular-only (when available) and fused scores on VAL and TEST.
rows = [
    metr("VAL IMG only", yva_cow.values, pva_img_cal),
    metr("TEST IMG only", yte_cow.values, pte_img_cal),
]
if tab_ready:
    rows += [
        metr("VAL TAB only", yva_cow.values, pva_tab),
        metr("TEST TAB only", yte_cow.values, pte_tab),
    ]
rows += [
    metr(f"VAL FUSION (w={W:.2f})", yva_cow.values, pva_f),
    metr(f"TEST FUSION (w={W:.2f})", yte_cow.values, pte_f),
]

summary_df = pd.DataFrame(rows)
print("\n=== Multimodal Summary ===", flush=True)
summary_cols = ["name", "AUROC", "AUPRC", "Brier", "N"]
print(summary_df[summary_cols].to_string(index=False), flush=True)

# ===== 11) Robustness section: Bootstrap on TEST (AUROC, AUPRC) =====
def bootstrap_metrics(y, p, n_boot=200, seed=SEED):
    """Non-parametric bootstrap of AUROC and AUPRC.

    Args:
        y: binary labels.
        p: scores, row-aligned with ``y``.
        n_boot: number of resamples (drawn with replacement, same size as ``y``).
        seed: seed of the resampling RNG.

    Returns:
        ``(auroc_stats, auprc_stats)``, each ``dict(mean, std, ci_lo, ci_hi)`` with a 95%
        percentile interval. Undefined resamples are skipped: AUROC when a resample
        has a single class, AUPRC when it has no positives. sklearn would otherwise
        report AP=0 there and bias the interval downward. All values are NaN when
        nothing is defined, including empty input.
    """
    rng = np.random.default_rng(seed)
    y = np.asarray(y, dtype=int); p = np.asarray(p, dtype=float)
    n = len(y)
    aucs = np.full(n_boot, np.nan); aps = np.full(n_boot, np.nan)
    if n > 0:  # rng.integers(0, 0) would raise on empty input
        for b in range(n_boot):
            idx = rng.integers(0, n, size=n)
            yy, pp = y[idx], p[idx]
            if len(np.unique(yy)) >= 2:
                aucs[b] = roc_auc_score(yy, pp)
            if (yy == 1).any():
                aps[b] = average_precision_score(yy, pp)

    def stat(x):
        x = x[np.isfinite(x)]
        if x.size == 0:
            return dict(mean=np.nan, std=np.nan, ci_lo=np.nan, ci_hi=np.nan)
        return dict(mean=float(np.mean(x)),
                    std=float(np.std(x, ddof=1) if x.size>1 else 0.0),
                    ci_lo=float(np.quantile(x, 0.025)),
                    ci_hi=float(np.quantile(x, 0.975)))
    return stat(aucs), stat(aps)

boot_results = []


def add_boot(name, y, p):
    """Bootstrap AUROC/AUPRC for one model and append a flat summary row to ``boot_results``.

    Row keys: ``model``, then ``AUROC_{mean,std,ci_lo,ci_hi}`` and ``AUPRC_{mean,std,ci_lo,ci_hi}``.
    """
    auc_s, ap_s = bootstrap_metrics(y, p, n_boot=200)
    row = {"model": name}
    for metric_name, summary in (("AUROC", auc_s), ("AUPRC", ap_s)):
        for stat_name in ("mean", "std", "ci_lo", "ci_hi"):
            row[f"{metric_name}_{stat_name}"] = summary[stat_name]
    boot_results.append(row)


add_boot("TEST IMG only", yte_cow.values, pte_img_cal)
if tab_ready:
    add_boot("TEST TAB only", yte_cow.values, pte_tab)
add_boot(f"TEST FUSION (w={W:.2f})", yte_cow.values, pte_f)

boot_df = pd.DataFrame(boot_results)
print("\n=== Bootstrap (TEST) ===")
print(boot_df.to_string(index=False))

# ===== 12) Threshold from VAL (max F1) + Confusion Matrix on TEST =====
def best_thresh_by_f1(y, p):
    """Return ``(threshold, f1)`` for the point on the PR curve with the highest F1.

    F1 is defined as 0 where precision + recall == 0. The division is masked there, so no
    0/0 RuntimeWarning is raised. The final PR point (precision=1, recall=0) has no
    threshold; if it were selected, 0.5 is returned instead.
    """
    prec, rec, thr = precision_recall_curve(y, p)
    denom = prec + rec
    f1 = np.divide(2 * prec * rec, denom, out=np.zeros_like(denom, dtype=float), where=denom > 0)
    ix = int(np.argmax(f1))
    t = float(thr[ix]) if ix < len(thr) else 0.5
    return t, float(f1[ix])

# choose final score: FUSION (if tab_ready) else IMG only
if tab_ready:
    pva_final, pte_final = pva_f, pte_f
else:
    pva_final, pte_final = pva_img_cal, pte_img_cal
th_opt, f1_val = best_thresh_by_f1(yva_cow.values, pva_final)
print(f"\n[Thresh] Best F1 on VAL: threshold={th_opt:.4f}, F1={f1_val:.4f}")

# Apply the VAL-selected threshold unchanged to TEST.
yte_pred = (pte_final >= th_opt).astype(int)
cm = confusion_matrix(yte_cow.values, yte_pred, labels=[0,1])
# labels=[0, 1] forces a 2x2 matrix, so ravel() always yields (tn, fp, fn, tp).
tn, fp, fn, tp = cm.ravel()
acc = (tp + tn) / np.sum(cm)
prec = tp / max(tp + fp, 1)
rec = tp / max(tp + fn, 1)
f1_te = 2 * prec * rec / max(prec + rec, 1e-9)
print(f"[ConfMat TEST] TP={tp} FP={fp} FN={fn} TN={tn} | Acc={acc:.3f} Prec={prec:.3f} Rec={rec:.3f} F1={f1_te:.3f}")

# ===== 13) Figures & Tables (save to FIGDIR, overwrite) =====
def savefig(path):
    """Write the current matplotlib figure to ``path`` (200 dpi, tight bbox), then close it."""
    plt.savefig(path, bbox_inches='tight', dpi=200)
    plt.close()

# ROC & PR plots for VAL/TEST (final score)
def plot_roc_pr(y, p, split_name):
    """Save ``roc_<split>.png`` and ``pr_<split>.png`` under FIGDIR for labels ``y`` / scores ``p``.

    With a single class in ``y`` the ROC panel shows the diagonal and "AUC=N/A".
    """
    file_tag = split_name.lower().replace(' ', '_')

    # --- ROC
    if len(np.unique(y)) <= 1:
        fpr, tpr, auc = np.array([0,1]), np.array([0,1]), np.nan
    else:
        fpr, tpr, _ = roc_curve(y, p)
        auc = roc_auc_score(y, p)
    roc_label = f"AUC={auc:.3f}" if auc == auc else "AUC=N/A"
    plt.figure()
    plt.plot(fpr, tpr, label=roc_label)
    plt.plot([0,1], [0,1], '--')
    plt.xlabel("FPR")
    plt.ylabel("TPR")
    plt.title(f"ROC — {split_name}")
    plt.legend(loc="lower right")
    savefig(os.path.join(FIGDIR, f"roc_{file_tag}.png"))

    # --- PR
    prec, rec, _ = precision_recall_curve(y, p)
    ap = average_precision_score(y, p)
    plt.figure()
    plt.plot(rec, prec, label=f"AP={ap:.3f}")
    plt.xlabel("Recall")
    plt.ylabel("Precision")
    plt.title(f"PR — {split_name}")
    plt.legend(loc="lower left")
    savefig(os.path.join(FIGDIR, f"pr_{file_tag}.png"))

# Confusion matrix plot for TEST
def plot_confmat(cm, split_name="TEST"):
    """Save a 2x2 confusion-matrix heatmap with per-cell counts to ``confusion_matrix_<split>.png``."""
    plt.figure()
    heat = plt.imshow(cm, interpolation='nearest')
    plt.title(f"Confusion Matrix — {split_name}")
    plt.colorbar(heat, fraction=0.046, pad=0.04)
    class_ticks = np.arange(2)
    plt.xticks(class_ticks, ['0','1'])
    plt.yticks(class_ticks, ['0','1'])
    # Annotate each cell with its integer count (row = true label, column = predicted).
    for (row, col), count in np.ndenumerate(cm):
        plt.text(col, row, format(count, 'd'), ha="center", va="center", fontsize=10)
    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    savefig(os.path.join(FIGDIR, f"confusion_matrix_{split_name.lower()}.png"))

# Plots (final score on VAL and TEST, confusion matrix on TEST)
for split_label, y_true, final_scores in (
    ("VAL_final", yva_cow.values, pva_final),
    ("TEST_final", yte_cow.values, pte_final),
):
    plot_roc_pr(y_true, final_scores, split_label)
plot_confmat(cm, "TEST")

# Tables (CSV) — overwrite
summary_df.to_csv(os.path.join(FIGDIR, "summary_multimodal.csv"), index=False)
boot_df.to_csv(os.path.join(FIGDIR, "bootstrap_test_metrics.csv"), index=False)
pd.DataFrame(
    {
        "threshold": [th_opt],
        "F1_VAL": [f1_val],
        "Acc_TEST": [acc],
        "Precision_TEST": [prec],
        "Recall_TEST": [rec],
        "F1_TEST": [f1_te],
        "TP": [tp],
        "FP": [fp],
        "FN": [fn],
        "TN": [tn],
    }
).to_csv(os.path.join(FIGDIR, "threshold_confmat_stats.csv"), index=False)

# ===== 14) Robust debug JSON (serialisable) =====
def pyify(obj):
    """Recursively convert ``obj`` into plain Python objects accepted by strict ``json.dump``.

    * numpy scalars -> Python scalars; numpy arrays, torch tensors, pandas Series/Index -> lists;
    * pandas DataFrame -> list of row dicts;
    * dict values are converted recursively, and numpy-scalar keys become Python scalars
      (``json`` rejects e.g. ``np.int64`` keys);
    * lists, tuples and sets -> lists (recursively);
    * non-finite floats (NaN, +/-inf) -> ``None``, so the output is valid JSON;
    * anything else is returned unchanged.
    torch is optional: without it, tensors are simply not special-cased.
    """
    import math
    import numpy as _np
    import pandas as _pd
    try:
        import torch as _torch
    except ImportError:
        _torch = None

    if isinstance(obj, _np.generic):
        return pyify(obj.item())
    if _torch is not None and isinstance(obj, _torch.Tensor):
        return pyify(obj.item() if obj.ndim == 0 else obj.detach().cpu().tolist())
    if isinstance(obj, _pd.DataFrame):
        return pyify(obj.to_dict(orient="records"))
    if isinstance(obj, (_pd.Series, _pd.Index)):
        return pyify(obj.tolist())
    if isinstance(obj, dict):
        return {(k.item() if isinstance(k, _np.generic) else k): pyify(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple, set, frozenset)):
        return [pyify(x) for x in obj]
    if isinstance(obj, _np.ndarray):
        return pyify(obj.tolist())
    if isinstance(obj, float) and not math.isfinite(obj):
        return None
    return obj

# Debug snapshot of the run (device, split sizes, tuned weight/threshold, metrics).
payload = {
    "device": str(DEVICE),
    "seed": int(SEED),
    "counts": {"train_img": int(len(lab_tr)), "val_img": int(len(lab_va)), "test_img": int(len(lab_te))},
    "use_tab_split": bool(use_tab_split),
    "fusion_weight": float(W),
    "labels_dir": str(LABEL_DIR),
    "image_dir": str(IMAGE_DIR),
    "metrics": rows,
    "threshold": float(th_opt),
    "f1_val": float(f1_val),
    "confmat": {"TP": int(tp), "FP": int(fp), "FN": int(fn), "TN": int(tn)},
}
with open(os.path.join(SAVE_DIR, "debug_multimodal.json"), "w") as fh:
    json.dump(pyify(payload), fh, indent=2)

print(f"\n[OK] Figures and tables saved to: {FIGDIR}")
print(f"[OK] Quick summaries in: {SAVE_DIR}")


In [None]:
# =======================
# Cell 6.2 — SIMULATED cow mapping & "true" fusion (ablation)
# - Generate a synthetic filename→Cow_ID_match mapping consistent with split and class
# - Perform per-cow fusion using pva_img_cal/pte_img_cal
# - Save figures/tables with suffix *_cowfusion_SIM (overwrite)
# =======================
import os, json, re, sys
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.linear_model import LogisticRegression as LRCal
from sklearn.metrics import (average_precision_score, roc_auc_score,
                             precision_recall_curve, confusion_matrix, roc_curve)

# --- context requirements: these names must already exist (produced by Cell 6 and the tabular cells)
need = ['lab_tr','lab_va','lab_te','pva_img_cal','pte_img_cal','Kva_i','Kte_i',
        'train_df','val_df','test_df']
for v in need:
    if v in globals():
        continue
    raise RuntimeError(f"[SIM WARN] Missing '{v}'. Run Cell 6 first and ensure tabular DFs are in memory.")

# Paths / key for the simulated mapping (directories are created if absent).
KEY = 'Cow_ID_match'
LABEL_DIR = os.path.join(PROJECT_DIR, "labels")
FIGDIR = os.path.join(PROJECT_DIR, "figures_and_tables")
for out_dir in (LABEL_DIR, FIGDIR):
    os.makedirs(out_dir, exist_ok=True)

print("\n[SIM NOTICE] Generating a SYNTHETIC filename→Cow_ID_match mapping for a sensitivity experiment.")
print("             DO NOT USE these numbers as 'true multimodality' in the final paper.\n")

# --- utilities
def cows_posneg(df):
    """Return lists of positive and negative cows in the tabular split (cow label = max over visits)."""
    cow_label = df.groupby(KEY)['class1'].max().astype(int)
    cow_ids = cow_label.index.astype(str)
    positives = cow_ids[(cow_label == 1).to_numpy()].tolist()
    negatives = cow_ids[(cow_label == 0).to_numpy()].tolist()
    return positives, negatives

def simulated_map_for_split(img_df, tab_df, seed=42):
    """
    Assign each image to a cow in the *same split* with the *same class*.
    Multiple images may map to the same cow (allowed).

    Returns a DataFrame with columns ``filename`` (basename of ``abs_path``) and
    ``cow_id_match`` (str), one row per image in ``img_df`` order. Draws are reproducible
    for a given ``seed``.
    """
    rng = np.random.default_rng(seed)
    pos_cows, neg_cows = cows_posneg(tab_df)
    records = []
    for path, label in zip(img_df['abs_path'], img_df['class1']):
        pool = pos_cows if int(label) == 1 else neg_cows
        if not pool:
            # No cow of this class in the split: use any cow of the split, or a fresh
            # synthetic id when the split has no cows at all (rare).
            pool = (pos_cows+neg_cows) if (pos_cows or neg_cows) else [f"SIM{rng.integers(100000)}"]
        records.append((os.path.basename(path), str(rng.choice(pool))))
    return pd.DataFrame(records, columns=['filename','cow_id_match'])

# --- build mapping for the three splits (distinct, fixed seeds per split)
sim_tr = simulated_map_for_split(lab_tr, train_df, seed=42)
sim_va = simulated_map_for_split(lab_va, val_df,   seed=43)
sim_te = simulated_map_for_split(lab_te, test_df,  seed=44)

sim_map = pd.concat([sim_tr, sim_va, sim_te], axis=0, ignore_index=True)
sim_map_path_csv  = os.path.join(LABEL_DIR, "filename_to_cow_SIMULATED.csv")
sim_map_path_json = os.path.join(LABEL_DIR, "filename_to_cow_SIMULATED.json")
sim_map.to_csv(sim_map_path_csv, index=False)
# Write through a context manager so the JSON file is flushed and closed deterministically,
# even if json.dump raises.
with open(sim_map_path_json, "w") as fh:
    json.dump({r['filename']: str(r['cow_id_match']) for _, r in sim_map.iterrows()},
              fh, indent=2)
print(f"[SIM] Saved synthetic mapping:\n      - {sim_map_path_csv}\n      - {sim_map_path_json}")

# --- apply mapping to image DFs
def apply_sim_map(df, sim_map):
    """Return a copy of ``df`` with a ``filename`` column (basename of ``abs_path``) and
    ``KEY`` filled from ``sim_map`` (NaN where the filename is not in the mapping)."""
    filename_to_cow = sim_map.set_index('filename')['cow_id_match']
    out = df.assign(filename=df['abs_path'].apply(os.path.basename))
    out[KEY] = out['filename'].map(filename_to_cow)
    return out

lab_tr_m, lab_va_m, lab_te_m = (apply_sim_map(split_df, sim_map) for split_df in (lab_tr, lab_va, lab_te))

n_mapped = [int(d[KEY].notna().sum()) for d in (lab_tr_m, lab_va_m, lab_te_m)]
print(f"[SIM] Images mapped (per cow): train={n_mapped[0]} | val={n_mapped[1]} | test={n_mapped[2]}")

# --- rebuild Series p(img) indexed by filename (from Cell 6)
# NOTE(review): requires pva_img_cal/pte_img_cal to be row-aligned with lab_va/lab_te. That
# holds in Cell 6's image-level fallback (one score per image, in loader order); in cow-split
# mode they hold one score per cow and pd.Series raises a length-mismatch ValueError here.
pva_img_series = pd.Series(pva_img_cal, index=lab_va['abs_path'].apply(os.path.basename))
pte_img_series = pd.Series(pte_img_cal, index=lab_te['abs_path'].apply(os.path.basename))

# --- per-cow pooling (TopK=2 mean)
def pool_by_cow(df_split_m, p_series, K=2):
    """Per-cow score = mean of the K highest finite image scores of that cow.

    ``p_series`` maps filename -> score; files missing from it (or non-finite) are ignored,
    and a cow with no usable score gets NaN. Returns a Series sorted by cow id.
    """
    files_per_cow = df_split_m.groupby(KEY)['filename'].apply(list)

    def _topk_mean(files):
        raw = (p_series.get(f, np.nan) for f in files)
        scores = np.array([v for v in raw if np.isfinite(v)], dtype=float)
        if scores.size == 0:
            return np.nan
        return float(np.sort(scores)[-min(K, scores.size):].mean())

    return pd.Series({cow: _topk_mean(files) for cow, files in files_per_cow.items()}).sort_index()

# target per cow from images (any positive makes the cow positive)
def cow_label_from_imgs(df_split_m):
    """Return the per-cow label (max of ``class1`` over the cow's images), sorted by cow id."""
    per_cow_max = df_split_m.groupby(KEY)['class1'].max()
    return per_cow_max.astype(int).sort_index()

# Cow-level labels and pooled image scores (both sorted by cow id, so they are aligned).
yva_cow_img = cow_label_from_imgs(lab_va_m)
yte_cow_img = cow_label_from_imgs(lab_te_m)
pva_cow_img = pool_by_cow(lab_va_m, pva_img_series, K=2)
pte_cow_img = pool_by_cow(lab_te_m, pte_img_series, K=2)

# --- Platt calibration on VAL at cow-level (fitted on VAL cows, applied to VAL and TEST)
cal_cow = LRCal(max_iter=1000, random_state=42)
cal_cow.fit(pva_cow_img.values.reshape(-1, 1), yva_cow_img.values)
pva_cow_cal, pte_cow_cal = (
    cal_cow.predict_proba(s.values.reshape(-1, 1))[:, 1] for s in (pva_cow_img, pte_cow_img)
)

# --- align with tabular for true per-cow fusion (if pva_c/pte_c available)
tab_ready = ('pva_c' in globals()) and ('pte_c' in globals())
if tab_ready:
    # pva_c/pte_c must be per-cow probabilities computed in Cell 5.
    # NOTE(review): assumed to be indexed by cow id. A plain array gets a 0..n-1 index and
    # cannot be aligned meaningfully; confirm against Cell 5.
    pv_tab = pd.Series(pva_c).copy()
    pt_tab = pd.Series(pte_c).copy()
    # Simulated cow ids are str (cows_posneg casts them), so compare on str. Otherwise
    # e.g. integer cow ids would miss every cow and silently collapse to the 0.0 fill.
    pv_tab.index = pv_tab.index.astype(str)
    pt_tab.index = pt_tab.index.astype(str)
    print(f"[SIM] Tabular scores matched: val={int(yva_cow_img.index.isin(pv_tab.index).sum())}/{len(yva_cow_img)}"
          f" | test={int(yte_cow_img.index.isin(pt_tab.index).sum())}/{len(yte_cow_img)}")
    pv_tab = pv_tab.reindex(yva_cow_img.index).fillna(0.0).values
    pt_tab = pt_tab.reindex(yte_cow_img.index).fillna(0.0).values
else:
    pv_tab = np.zeros_like(pva_cow_cal); pt_tab = np.zeros_like(pte_cow_cal)
    print("[SIM] pva_c/pte_c not present — fusion with tabular channel = 0 (placeholder).")

def ranknorm(x):
    """Map scores to normalised ranks in [0, 1] (lowest -> 0, highest -> 1).

    Ties get the average of the ranks they span. A constant input (e.g. the all-zero
    placeholder tabular channel) therefore maps to a constant 0.5, not an arbitrary
    0..1 ramp set by sort order. For distinct values the result is the 0-based rank
    divided by ``len(x) - 1``. NaNs rank last.
    """
    ranks = pd.Series(np.asarray(x, dtype=float)).rank(method="average", na_option="bottom").to_numpy()
    return (ranks - 1.0) / max(len(ranks) - 1, 1)

# tune fusion weight on VAL: score = w*rank(tab) + (1-w)*rank(img); strict '>' keeps the first best w
best = None
for w in [0.0, 0.25, 0.5, 0.75, 1.0]:
    va_f = w * ranknorm(pv_tab) + (1 - w) * ranknorm(pva_cow_cal)
    ap = average_precision_score(yva_cow_img.values, va_f)
    improved = (best is None) or (ap > best[0])
    if improved:
        best = (ap, w, va_f)
_, W_SIM, va_f = best
te_f = W_SIM * ranknorm(pt_tab) + (1 - W_SIM) * ranknorm(pte_cow_cal)

def metr(y, p):
    """Return ``(AUROC, AP)`` for labels ``y`` and scores ``p`` (scores clipped to [1e-9, 1-1e-9]).

    AUROC is NaN when ``roc_auc_score`` raises ValueError (e.g. ``y`` has a single class).
    """
    p = np.clip(p, 1e-9, 1-1e-9)
    try:
        auc = roc_auc_score(y, p)
    except ValueError:
        # Only the expected "undefined AUROC" failure is tolerated; other errors propagate.
        auc = np.nan
    ap = average_precision_score(y, p)
    return auc, ap

auc_va_img, ap_va_img = metr(yva_cow_img.values, pva_cow_cal)
auc_te_img, ap_te_img = metr(yte_cow_img.values, pte_cow_cal)
auc_va_fus, ap_va_fus = metr(yva_cow_img.values, va_f)
auc_te_fus, ap_te_fus = metr(yte_cow_img.values, te_f)

print("\n=== SIMULATED True Multimodal (per cow) — results FOR ABLATION ONLY ===")
for line_prefix, auc_v, ap_v in (
    ("VAL IMG(cow):  ", auc_va_img, ap_va_img),
    ("TEST IMG(cow): ", auc_te_img, ap_te_img),
    (f"VAL FUSION_SIM (w={W_SIM:.2f}):  ", auc_va_fus, ap_va_fus),
    (f"TEST FUSION_SIM (w={W_SIM:.2f}): ", auc_te_fus, ap_te_fus),
):
    print(f"{line_prefix}AUC={auc_v:.3f}  AP={ap_v:.3f}")

# --- threshold from VAL (max F1) + confusion matrix on TEST
def best_thresh_by_f1(y, p):
    """Return ``(threshold, f1)`` for the point on the PR curve with the highest F1.

    F1 is defined as 0 where precision + recall == 0. The division is masked there, so no
    0/0 RuntimeWarning is raised. The final PR point (precision=1, recall=0) has no
    threshold; if it were selected, 0.5 is returned instead.
    """
    prec, rec, thr = precision_recall_curve(y, p)
    denom = prec + rec
    f1 = np.divide(2 * prec * rec, denom, out=np.zeros_like(denom, dtype=float), where=denom > 0)
    ix = int(np.argmax(f1))
    t = float(thr[ix]) if ix < len(thr) else 0.5
    return t, float(f1[ix])

# Threshold selected on VAL (max F1), then applied unchanged to TEST.
th_opt_sim, f1_val_sim = best_thresh_by_f1(yva_cow_img.values, va_f)
yte_pred_sim = (te_f >= th_opt_sim).astype(int)
cm_sim = confusion_matrix(yte_cow_img.values, yte_pred_sim, labels=[0,1])
# labels=[0, 1] forces a 2x2 matrix, so ravel() always yields (tn, fp, fn, tp).
tn, fp, fn, tp = cm_sim.ravel()

print(f"[SIM ConfMat TEST] TP={tp} FP={fp} FN={fn} TN={tn} | thr={th_opt_sim:.3f}")

# --- figures/tables with suffix _cowfusion_SIM (overwrite)
def savefig(path):
    """Write the current matplotlib figure to ``path`` (200 dpi, tight bbox), then close it."""
    plt.savefig(path, bbox_inches='tight', dpi=200)
    plt.close()

def plot_roc_pr(y, p, split):
    """Save ROC and PR curves for one split under FIGDIR with the ``_cowfusion_SIM`` suffix.

    With a single class in ``y`` the ROC panel shows the diagonal and "AUC=N/A".
    """
    tag = split.lower()

    # --- ROC
    if len(np.unique(y)) <= 1:
        fpr, tpr, auc = np.array([0,1]), np.array([0,1]), np.nan
    else:
        fpr, tpr, _ = roc_curve(y, p)
        auc = roc_auc_score(y, p)
    plt.figure()
    plt.plot(fpr, tpr, label=f"AUC={auc:.3f}" if auc == auc else "AUC=N/A")
    plt.plot([0,1], [0,1], '--')
    plt.title(f"ROC — {split} (cow-fusion SIM)")
    plt.xlabel("FPR")
    plt.ylabel("TPR")
    plt.legend(loc="lower right")
    savefig(os.path.join(FIGDIR, f"roc_{tag}_cowfusion_SIM.png"))

    # --- PR
    prec, rec, _ = precision_recall_curve(y, p)
    ap = average_precision_score(y, p)
    plt.figure()
    plt.plot(rec, prec, label=f"AP={ap:.3f}")
    plt.title(f"PR — {split} (cow-fusion SIM)")
    plt.xlabel("Recall")
    plt.ylabel("Precision")
    plt.legend(loc="lower left")
    savefig(os.path.join(FIGDIR, f"pr_{tag}_cowfusion_SIM.png"))

def plot_confmat(cm, split="TEST"):
    """Save a 2x2 confusion-matrix heatmap with per-cell counts (``_cowfusion_SIM`` suffix)."""
    plt.figure()
    heat = plt.imshow(cm, interpolation='nearest')
    plt.title(f"Confusion Matrix — {split} (cow-fusion SIM)")
    plt.colorbar(heat, fraction=0.046, pad=0.04)
    class_ticks = np.arange(2)
    plt.xticks(class_ticks, ['0','1'])
    plt.yticks(class_ticks, ['0','1'])
    # Annotate each cell with its integer count (row = true, column = predicted).
    for (row, col), count in np.ndenumerate(cm):
        plt.text(col, row, format(count, 'd'), ha="center", va="center")
    plt.ylabel('True')
    plt.xlabel('Predicted')
    savefig(os.path.join(FIGDIR, f"confusion_matrix_{split.lower()}_cowfusion_SIM.png"))

# plots (VAL and TEST curves, TEST confusion matrix)
for split_tag, y_true, fused_scores in (
    ("VAL", yva_cow_img.values, va_f),
    ("TEST", yte_cow_img.values, te_f),
):
    plot_roc_pr(y_true, fused_scores, split_tag)
plot_confmat(cm_sim, "TEST")

# csv: cow-level summary metrics
pd.DataFrame(
    {
        "name": ["VAL IMG(cow)", "TEST IMG(cow)", f"VAL FUSION_SIM (w={W_SIM:.2f})", f"TEST FUSION_SIM (w={W_SIM:.2f})"],
        "AUROC": [auc_va_img, auc_te_img, auc_va_fus, auc_te_fus],
        "AUPRC": [ap_va_img, ap_te_img, ap_va_fus, ap_te_fus],
    }
).to_csv(os.path.join(FIGDIR, "summary_multimodal_cowfusion_SIM.csv"), index=False)

# csv: VAL threshold and TEST confusion counts
pd.DataFrame(
    {
        "threshold_VAL": [th_opt_sim],
        "F1_VAL": [f1_val_sim],
        "TP": [int(tp)],
        "FP": [int(fp)],
        "FN": [int(fn)],
        "TN": [int(tn)],
    }
).to_csv(os.path.join(FIGDIR, "threshold_confmat_stats_cowfusion_SIM.csv"), index=False)

print(f"\n[SIM DONE] Figures & tables saved to {FIGDIR} with suffix *_cowfusion_SIM (overwrite).")
print("           Reminder: these results are FOR ABLATION ONLY, not for true multimodal claims.\n")
