---
# Cross-Language Generalization of Parkinson’s Disease Speech Detection Using Limited Target Data Calibration

Yash N. Ganatra, Michael E. DeBakey High School for Health Professions, Houston, TX
---

## Project Notes

This notebook implements a speech-based Parkinson’s disease detection study focused on cross-language generalization under limited target data exposure.

The project methodology, preprocessing logic, model configuration, and evaluation pipeline were defined by the author. AI-based coding assistance was used during software development for implementation, debugging, and iterative refinement. All scientific objectives, dataset selection, and evaluation criteria were determined by the author.

## Project Context
This notebook contains the analysis and experiments supporting the paper:

“Cross-Language Generalization of Parkinson’s Disease Speech Detection Using Limited Target Data Calibration”

The notebook is intended for research and reproducibility only.


## Definition of the Datasets

---

## D1 – NeuroVoz

Primary language: Castilian Spanish,
Acronym used: ES

Description:
This dataset contains speech recordings labeled as Healthy Control or Parkinson’s Disease. It includes both sustained phonation (vowel) recordings and read or other speech tasks. Speaker and clip information is provided through CSV metadata files.

Counts:

Speakers

* Total speakers: 107
* Male speakers: 61
* Female speakers: 46
* Parkinson’s disease speakers: 52
* Healthy control speakers: 55

Audio clips

* Total audio clips used: 1692
* Vowel or sustained phonation clips: 837
* Reading or other speech clips: 855

---

## D2 – EWA-DB

Primary language: Slovak,
Acronym used: SK

Description:
This dataset includes speech recordings labeled as Healthy, Parkinson’s Disease, Alzheimer’s Disease, and Alzheimer–Parkinson’s Disease. Both sustained phonation and read or other speech tasks are available, with metadata stored in TSV files. Only Healthy Control and Parkinson’s Disease recordings were included in this study.

Counts:

Speakers

* Total speakers: 630
* Male speakers: 189
* Female speakers: 441
* Parkinson’s disease speakers: 92
* Healthy control speakers: 538

Audio clips

* Total audio clips used: 5670
* Vowel or sustained phonation clips: 630
* Reading or other speech clips: 5040

---

## D3 – UCI PD Speech (NOT USED FOR ANALYSIS)

Primary language: Turkish,
Acronym used: TK

Description:
This dataset consists of feature-based Excel files for Healthy Control and Parkinson’s Disease speakers. Raw audio recordings are not available. As a result, this dataset was not used for model training or evaluation.

---

## D4 – IPVS

Primary language: Italian,
Acronym used: IT

Description:
This dataset includes speech recordings from Young Healthy Control, Elderly Healthy Control, and Parkinson’s Disease speakers. Both sustained phonation and read or other speech tasks are available, with metadata provided in Excel files.

Counts:

Speakers

* Total speakers: 65
* Male speakers: 44
* Female speakers: 21
* Parkinson’s disease speakers: 28
* Healthy control speakers: 37

Audio clips

* Total audio clips used: 731
* Vowel or sustained phonation clips: 397
* Reading or other speech clips: 334

---

## D5 – MDVR-KCL

Primary language: English (UK),
Acronym used: EN, ENUK

Description:
This dataset contains speech recordings labeled as Healthy Control or Parkinson’s Disease. It includes read or other speech tasks only. No external metadata files are provided.

Preprocessing variants:

* Preprocessed_v1: Original 70/15/15 train, validation, and test speaker split
* Preprocessed_v2 (D5v2): Revised 50/20/30 split, adopted because the original split left too few test speakers for a reliable monolingual evaluation

Counts:

Speakers

* Total speakers: 37
* Male speakers: 13
* Female speakers: 24
* Parkinson’s disease speakers: 16
* Healthy control speakers: 21

Audio clips

* Total audio clips used: 73
* Vowel or sustained phonation clips: 0
* Reading or other speech clips: 73

---

## D6 – “Ah Sound” Dataset (Figshare)

Primary language: Not applicable (sustained “ah” sound only; recordings originate from the United States),
Acronym used: EN, ENUS_AH

Description:
This dataset consists only of sustained phonation (“ah” sound) recordings. Clips are labeled as Healthy Control or Parkinson’s Disease, with metadata provided in an Excel file.

Counts:

Speakers

* Total speakers: 81
* Male speakers: 37
* Female speakers: 44
* Parkinson’s disease speakers: 40
* Healthy control speakers: 41

Audio clips

* Total audio clips used: 81
* Vowel or sustained phonation clips: 81
* Reading or other speech clips: 0

---

## D7 – Multilingual Combined Dataset

Name: D7 Multilingual (combined training pool)

Description:
This dataset was created by merging D1, D4, D5v2, and D6 into a single training pool. It was designed to support multilingual training and evaluation of cross-language generalization.

Primary language: Multilingual by construction

Counts:

Speakers

* Total speakers: 290
* Male speakers: 155
* Female speakers: 135
* Parkinson’s disease speakers: 136
* Healthy control speakers: 154

Audio clips

* Total audio clips used: 2577
* Vowel or sustained phonation clips: 1315
* Reading or other speech clips: 1262
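
As a quick consistency check, the D7 totals can be reproduced by summing the per-dataset counts listed above. The sketch below is illustrative and not part of the project code; the dictionary layout and names are hypothetical, and it assumes that D5v2 contains the same speakers and clips as D5 (only the split differs).

```python
# Per-dataset counts copied from the dataset descriptions above.
# Tuple order: speakers, male, female, PD, HC, clips, vowel clips, reading/other clips
counts = {
    "D1":   (107, 61, 46, 52, 55, 1692, 837, 855),
    "D4":   (65, 44, 21, 28, 37, 731, 397, 334),
    "D5v2": (37, 13, 24, 16, 21, 73, 0, 73),   # assumed identical to the D5 counts
    "D6":   (81, 37, 44, 40, 41, 81, 81, 0),
}
fields = ["speakers", "male", "female", "pd", "hc", "clips", "vowel", "reading"]

# Column-wise sum across the four source datasets
d7 = {name: sum(row[i] for row in counts.values()) for i, name in enumerate(fields)}
print(d7)
```

The sums match the D7 counts reported in this section (290 speakers and 2577 clips).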

---

The following cell resolves a common Google Colab issue where the Google Drive mount directory (/content/drive) already contains leftover files from a previous session, which can cause mounting to fail. The cell first checks whether the mountpoint exists and prints a brief view of its current contents. It then safely unmounts any existing Drive connection using drive.flush_and_unmount(), which has no effect if Drive is not mounted. Next, it recreates the mountpoint folder if needed and removes only the local files and folders inside /content/drive to ensure the directory is empty, without deleting any data from Google Drive itself. After cleaning the mountpoint, it remounts Google Drive at /content/drive using force_remount=True. Finally, it verifies that the expected project directories (such as /content/drive/MyDrive/AI_PD_Project and the Datasets folder) are present and prints the top-level contents of the project folder to confirm that the Drive connection and paths are set up correctly.

In [None]:
# Drive remount cleanup: clear a dirty mountpoint and verify project folders
# Goal: fix the Colab error where the mount folder already has leftover files.
# Inputs: none (checks the local /content/drive mountpoint).
# Outputs: a fresh Drive mount at /content/drive and printed checks for key project paths.

import os, shutil

MOUNT_POINT = "/content/drive"

# Current mountpoint status (exists and first few items)
print("Exists?", os.path.exists(MOUNT_POINT))
if os.path.exists(MOUNT_POINT):
    print("Contents BEFORE:", os.listdir(MOUNT_POINT)[:50])

# Unmount any existing Drive mount (safe even if Drive is not mounted)
try:
    from google.colab import drive
    drive.flush_and_unmount()
    print("flush_and_unmount() called.")
except Exception as e:
    print("flush_and_unmount() not available or failed:", repr(e))

# Recreate the mountpoint folder to ensure it exists
os.makedirs(MOUNT_POINT, exist_ok=True)

# Clear leftover files or folders inside the mountpoint directory only
# Note: this deletes only items under /content/drive (the local mount folder), not Drive contents.
if os.listdir(MOUNT_POINT):
    print("Clearing mountpoint folder contents...")
    for name in os.listdir(MOUNT_POINT):
        p = os.path.join(MOUNT_POINT, name)
        if os.path.isdir(p) and not os.path.islink(p):
            shutil.rmtree(p)
        else:
            os.remove(p)

print("Contents AFTER clearing:", os.listdir(MOUNT_POINT))

# Mount Google Drive at the cleaned mountpoint
from google.colab import drive
drive.mount(MOUNT_POINT, force_remount=True)

# Quick sanity check: confirm expected project folders exist after mount
PROJECT_DIR = "/content/drive/MyDrive/AI_PD_Project"
DATASETS_DIR = f"{PROJECT_DIR}/Datasets"

print("PROJECT_DIR exists?", os.path.exists(PROJECT_DIR))
print("DATASETS_DIR exists?", os.path.exists(DATASETS_DIR))
if os.path.exists(PROJECT_DIR):
    print("Top-level:", os.listdir(PROJECT_DIR)[:50])

# Preprocessing of D1, D2, D4, D5 and D6 Datasets

The following cell runs the full **D1 (NeuroVoz, Castilian Spanish) preprocessing** in a single step and creates a standardized `preprocessed_v1` folder that is ready for model training. It mounts Google Drive if needed, installs WebRTC VAD and SciPy when they are missing, and applies fixed preprocessing settings: a 16 kHz sample rate, loudness normalization rules, voice activity detection settings, a maximum number of clips per speaker and task, and a speaker-level train/validation/test split of 70/15/15. Output folders are created at the start, and the temporary `_candidates` folder is cleared to prevent mixing files from different runs.

The cell then reads metadata from **two CSV files** (`metadata_hc.csv` and `metadata_pd.csv`), combines them into a single table, and checks for required columns (`ID`, `Group`, `Audio`). Age and sex fields are included only if they are present. Each entry is mapped to a speaker ID, a Healthy or Parkinson label, and an audio file path. Any entries with missing or invalid audio paths are skipped.
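
The metadata step described above can be sketched as follows. This is a minimal illustration, not the cell's actual code: the in-memory CSV rows are made-up stand-ins for `metadata_hc.csv` and `metadata_pd.csv`, and the variable names are hypothetical.

```python
import io
import pandas as pd

# Tiny in-memory stand-ins for the two metadata files (hypothetical rows)
hc_csv = io.StringIO("ID,Group,Audio,Age\n1,HC,audio/HC_A1_0001.wav,63\n")
pd_csv = io.StringIO("ID,Group,Audio,Age\n2,PD,audio/PD_A1_0002.wav,70\n")

# Combine both tables into one
meta = pd.concat([pd.read_csv(hc_csv), pd.read_csv(pd_csv)], ignore_index=True)

# Fail early if a required column is absent
required = ["ID", "Group", "Audio"]
missing = [c for c in required if c not in meta.columns]
if missing:
    raise ValueError(f"Missing required columns: {missing}")

# Optional columns are looked up case-insensitively and used only if present
lower = {str(c).strip().lower(): c for c in meta.columns}
age_col = lower.get("age")
sex_col = lower.get("sex")

# Map each row to a speaker ID and a binary label (1 = Parkinson, 0 = Healthy)
meta["speaker_id"] = meta["ID"].astype(str)
meta["label_num"] = meta["Group"].str.lower().str.contains("pd|parkinson").astype(int)

print(meta[["speaker_id", "Group", "label_num"]])
print("age column:", age_col, "| sex column:", sex_col)
```

Checking the required columns before any audio is read lets a misnamed or incomplete metadata file fail immediately rather than partway through a long preprocessing run.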

Task type is determined from the audio filename and path. Vowel tasks are identified by a task code made of a vowel letter followed by a digit from 1 to 3 (`A1` through `U3`), spontaneous speech is identified using keyword matches (for example, “spont”) in the task code or path, and all remaining files are treated as reading tasks.
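
A simplified sketch of this task inference is shown below. The function name, the example filenames, and the shortened keyword list are hypothetical; the cell itself uses a longer keyword list.

```python
import os
import re

SPONT_KEYWORDS = ("SPONT", "MONOLOG", "LIBRE")  # illustrative subset of keywords

def infer_task(audio_path: str) -> str:
    """Classify a file as vowel, spontaneous, or reading from its name and path."""
    stem = os.path.splitext(os.path.basename(audio_path))[0]
    parts = stem.split("_")
    # The task code is taken as the second underscore-separated field
    code = parts[1].upper() if len(parts) >= 3 else "UNKNOWN"
    if re.fullmatch(r"[AEIOU][1-3]", code):
        return "vowel"
    if any(k in audio_path.upper() for k in SPONT_KEYWORDS):
        return "spontaneous"
    return "reading"  # default for everything else

for p in ["audio/HC_A1_0001.wav", "audio/PD_ESPONTANEA_0002.wav", "audio/HC_FRASE_0003.wav"]:
    print(p, "->", infer_task(p))
```

Because reading is the default category, any file whose name matches neither rule is treated as a reading task, so unexpected naming patterns are worth checking in the printed task counts.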

For each valid source audio file, the cell generates **at most one clip per source file**. Audio is converted to mono, resampled to 16 kHz, and normalized toward a target loudness of -20 dBFS using RMS-based leveling with a -1 dBFS peak limit. Speech regions are detected using WebRTC VAD when it is available, with a fallback to an energy-based method otherwise. From these regions, the **single longest voiced segment** on the source timeline (after resampling) is selected. Clip extraction follows fixed rules:

* **Vowel**: one **2.0-second** clip centered within the longest voiced segment; if the segment is shorter than 2.0 seconds, it is zero-padded at the end.
* **Reading or spontaneous**: one **8.0-second** clip taken from the start of the longest voiced segment, or the full segment without padding if it is shorter than 8 seconds but at least 0.30 seconds long.

If no usable voiced segment is found, the file is skipped and a warning is recorded.
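
The clip extraction rules above reduce to simple sample arithmetic. The sketch below restates them on sample indices only, without audio; the function name and return format are hypothetical, while the constants mirror the D1 settings.

```python
SR = 16000           # sample rate used throughout preprocessing
VOWEL_SEC = 2.0      # fixed vowel clip length
OTHER_SEC = 8.0      # maximum reading/spontaneous clip length
MIN_KEEP_SEC = 0.30  # minimum length for a short reading/spontaneous clip

def clip_window(seg_start: int, seg_end: int, task_type: str):
    """Return (start, end, pad_samples) for one voiced segment, or None to skip."""
    seg_len = seg_end - seg_start
    if seg_len <= 0:
        return None
    if task_type == "vowel":
        target = int(round(SR * VOWEL_SEC))
        if seg_len >= target:
            # Center a fixed-length window inside the segment
            offset = max(0, seg_len // 2 - target // 2)
            return seg_start + offset, seg_start + offset + target, 0
        # Short vowel: keep the whole segment and pad the remainder with zeros
        return seg_start, seg_end, target - seg_len
    target = int(round(SR * OTHER_SEC))
    if seg_len >= target:
        # Long segment: take the first 8 seconds
        return seg_start, seg_start + target, 0
    if seg_len < int(SR * MIN_KEEP_SEC):
        return None  # too short to keep
    return seg_start, seg_end, 0  # keep true duration, no padding

print(clip_window(0, 48000, "vowel"))       # 3 s vowel segment
print(clip_window(0, 16000, "vowel"))       # 1 s vowel segment
print(clip_window(1000, 200000, "reading")) # long reading segment
print(clip_window(0, 3200, "reading"))      # 0.2 s reading segment
```

Vowel clips always come out at a fixed length, while reading and spontaneous clips can be shorter than 8 seconds, so downstream code should not assume a uniform duration for them.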

Each extracted clip is first saved as a temporary WAV file in `clips/_candidates/`, and a candidate table is built with details such as speaker, task type, label, duration, source path, and clip start and end times. After all candidates are collected, the cell limits the data to at most **8 clips per speaker per task**, removes unused candidate files to save space, and assigns each speaker to exactly one split (train, validation, or test), stratified by label so that class balance is kept across splits. The selected clips are then moved to `clips/train`, `clips/val`, and `clips/test` using standardized filenames.
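
The capping and speaker-level split can be illustrated with a small standalone sketch. It uses only the Python standard library (`random`) rather than the NumPy generator used in the cell, and the function names and toy data are hypothetical.

```python
import random
from collections import defaultdict

def cap_clips(clips, max_per_group, seed=1337):
    """Keep at most max_per_group clips for each (speaker_id, task) pair."""
    rng = random.Random(seed)
    groups = defaultdict(list)
    for clip in clips:
        groups[(clip["speaker_id"], clip["task"])].append(clip)
    kept = []
    for group in groups.values():
        if len(group) <= max_per_group:
            kept.extend(group)
        else:
            kept.extend(rng.sample(group, max_per_group))
    return kept

def split_speakers(speaker_labels, fracs=(0.70, 0.15, 0.15), seed=1337):
    """Assign each speaker to exactly one split, separately within each label."""
    rng = random.Random(seed)
    assignment = {}
    for label in sorted(set(speaker_labels.values())):
        spks = sorted(s for s, l in speaker_labels.items() if l == label)
        rng.shuffle(spks)
        n_train = int(round(fracs[0] * len(spks)))
        n_val = int(round(fracs[1] * len(spks)))
        for i, s in enumerate(spks):
            if i < n_train:
                assignment[s] = "train"
            elif i < n_train + n_val:
                assignment[s] = "val"
            else:
                assignment[s] = "test"
    return assignment

# Hypothetical toy data: 20 speakers (10 per label), 10 vowel clips each
speaker_labels = {f"S{i:02d}": i % 2 for i in range(20)}
clips = [{"speaker_id": s, "task": "vowl", "clip": k} for s in speaker_labels for k in range(10)]

kept = cap_clips(clips, max_per_group=8)
splits = split_speakers(speaker_labels)
print(len(kept), "clips kept")
print({sp: sum(v == sp for v in splits.values()) for sp in ("train", "val", "test")})
```

Splitting at the speaker level, rather than the clip level, is what prevents recordings from the same person from appearing in both training and test data.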

At the end, the cell writes all required outputs: `manifests/manifest_all.csv` and per-split manifest files, `logs/preprocess_warnings.csv`, `logs/dataset_summary.json`, and `config/run_config.json`. Final checks confirm that all audio files are mono at 16 kHz and that no speaker appears in more than one split.
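
A minimal sketch of such final checks is given below. It inspects manifest rows only, using the `speaker_id`, `split`, `sr_hz`, and `channels` columns; the function names and example rows are hypothetical, and the cell's own checks may instead re-read the audio files.

```python
def find_leaked_speakers(rows):
    """Return speakers that appear in more than one split (expected: empty list)."""
    seen = {}
    for row in rows:
        seen.setdefault(row["speaker_id"], set()).add(row["split"])
    return sorted(s for s, splits in seen.items() if len(splits) > 1)

def find_bad_format(rows, sr_hz=16000, channels=1):
    """Return rows whose recorded sample rate or channel count is unexpected."""
    return [r for r in rows if r["sr_hz"] != sr_hz or r["channels"] != channels]

# Hypothetical manifest rows
good = [
    {"speaker_id": "S01", "split": "train", "sr_hz": 16000, "channels": 1},
    {"speaker_id": "S01", "split": "train", "sr_hz": 16000, "channels": 1},
    {"speaker_id": "S02", "split": "test", "sr_hz": 16000, "channels": 1},
]
bad = good + [{"speaker_id": "S01", "split": "val", "sr_hz": 44100, "channels": 2}]

print("leaks (good):", find_leaked_speakers(good))
print("leaks (bad):", find_leaked_speakers(bad))
print("format issues (bad):", len(find_bad_format(bad)))
```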

In [None]:
# NeuroVoz (D1) preprocessing: 1 clip per source, standardized outputs
# Builds train/val/test clips and a single manifest from two metadata tables.
# Writes: clips/ (flat per split), manifests/manifest_all.csv, config/run_config.json, logs/*

# =========================
# D1 Preprocessing v1 (NeuroVoz) — SINGLE COMPLETE CELL (CONSISTENCY-UPDATED)
# Summary of behavior:
# - For each source audio file, create at most 1 candidate clip and write it immediately
# - Select the clip from the longest voiced segment on the source timeline (after resample)
# - Vowel tasks: 2.0 s clip (centered if long enough, else zero-padded)
# - Reading or spontaneous: 8.0 s clip if possible, else keep true duration (no padding)
# - After candidates: cap by (speaker_id, task), split speakers into train/val/test, then move kept clips into clips/<split>/
# =========================

import os
import re
import json
import math
import random
from typing import List, Tuple, Dict, Optional

import numpy as np
import pandas as pd
import soundfile as sf
from tqdm.auto import tqdm

# -------------------------
# Drive mount
# -------------------------
if not os.path.isdir("/content/drive/MyDrive"):
    from google.colab import drive  # type: ignore
    drive.mount("/content/drive")

# -------------------------
# Optional packages
# - webrtcvad: voice activity detection
# - scipy: higher-quality resampling
# -------------------------
def _d1_try_import_webrtcvad():
    try:
        import webrtcvad  # type: ignore
        return webrtcvad, True
    except Exception:
        return None, False

webrtcvad, HAVE_WEBRTCVAD = _d1_try_import_webrtcvad()
if not HAVE_WEBRTCVAD:
    !pip -q install webrtcvad
    webrtcvad, HAVE_WEBRTCVAD = _d1_try_import_webrtcvad()

try:
    from scipy import signal  # type: ignore
    HAVE_SCIPY = True
except Exception:
    HAVE_SCIPY = False
    !pip -q install scipy
    from scipy import signal  # type: ignore
    HAVE_SCIPY = True

# -------------------------
# Paths
# Inputs: two metadata CSVs + audio files referenced inside them
# Outputs: clips/, manifests/, config/, logs/
# -------------------------
D1_PROJECT_DIR  = "/content/drive/MyDrive/AI_PD_Project"
D1_DATASETS_DIR = f"{D1_PROJECT_DIR}/Datasets"
D1_DIR          = f"{D1_DATASETS_DIR}/D1-NeuroVoz-Castillan Spanish"

D1_METADATA_DIR = f"{D1_DIR}/data/metadata"
D1_META_HC      = f"{D1_METADATA_DIR}/metadata_hc.csv"
D1_META_PD      = f"{D1_METADATA_DIR}/metadata_pd.csv"

# Outputs
D1_OUT_ROOT     = f"{D1_DIR}/preprocessed_v1"
D1_CLIPS_DIR    = f"{D1_OUT_ROOT}/clips"
D1_CAND_DIR     = f"{D1_CLIPS_DIR}/_candidates"
D1_MANIFEST_DIR = f"{D1_OUT_ROOT}/manifests"
D1_CONFIG_DIR   = f"{D1_OUT_ROOT}/config"
D1_LOGS_DIR     = f"{D1_OUT_ROOT}/logs"

for p in [D1_OUT_ROOT, D1_CLIPS_DIR, D1_CAND_DIR, D1_MANIFEST_DIR, D1_CONFIG_DIR, D1_LOGS_DIR]:
    os.makedirs(p, exist_ok=True)
for sp in ["train", "val", "test"]:
    os.makedirs(os.path.join(D1_CLIPS_DIR, sp), exist_ok=True)

# Candidate workspace: cleared at start to avoid mixing with older runs
import shutil
shutil.rmtree(D1_CAND_DIR, ignore_errors=True)  # no-op if the folder is missing
os.makedirs(D1_CAND_DIR, exist_ok=True)

# -------------------------
# Quick input checks
# -------------------------
print("D1_DIR exists?", os.path.exists(D1_DIR))
print("D1_METADATA_DIR exists?", os.path.exists(D1_METADATA_DIR))
print("D1_META_HC exists?", os.path.exists(D1_META_HC))
print("D1_META_PD exists?", os.path.exists(D1_META_PD))
print("webrtcvad available?", HAVE_WEBRTCVAD)
print("scipy available?", HAVE_SCIPY)

if not os.path.exists(D1_DIR):
    raise FileNotFoundError(f"D1_DIR not found: {D1_DIR}")

if not os.path.exists(D1_METADATA_DIR):
    top = sorted(os.listdir(D1_DIR))[:50]
    raise FileNotFoundError(
        f"D1 metadata folder not found: {D1_METADATA_DIR}\n"
        f"Top-level contents of D1_DIR (first 50): {top}"
    )

if not os.path.exists(D1_META_HC) or not os.path.exists(D1_META_PD):
    md = sorted(os.listdir(D1_METADATA_DIR))
    raise FileNotFoundError(
        "D1 metadata not found at expected paths:\n"
        f"  {D1_META_HC}\n"
        f"  {D1_META_PD}\n"
        f"Files present in metadata dir: {md}"
    )

# -------------------------
# Audio + split settings
# -------------------------
D1_SR = 16000
D1_RANDOM_SEED = 1337
random.seed(D1_RANDOM_SEED)
np.random.seed(D1_RANDOM_SEED)

# Loudness: RMS target with a peak limiter (simple and stable)
D1_TARGET_RMS_DBFS = -20.0
D1_PEAK_LIMIT_DBFS = -1.0
D1_MIN_RMS_DBFS    = -60.0
D1_MAX_GAIN_DB     = 18.0

# Voice activity detection settings (segment finding on source timeline)
D1_VAD_MODE       = 2
D1_FRAME_MS       = 30
D1_PAD_SEC        = 0.25
D1_TRAIL_PAD_SEC  = 0.15
D1_MAX_GAP_MS     = 200
D1_MIN_KEEP_SEC   = 0.30

# Clip length rules
D1_VOWEL_SEC = 2.0
D1_OTHER_SEC = 8.0

# Cap per speaker and task (keeps datasets balanced)
D1_MAX_CLIPS_PER_SPK_PER_TASK = 8

# Speaker-level split (keeps a speaker in only one split)
D1_TRAIN_PCT, D1_VAL_PCT, D1_TEST_PCT = 0.70, 0.15, 0.15

# -------------------------
# Helper functions: naming, reading, resampling, normalization, labels
# -------------------------
def d1_safe(s: str) -> str:
    return re.sub(r"[^A-Za-z0-9_\-\.]+", "_", str(s))

def d1_db_to_lin(db: float) -> float:
    return 10.0 ** (db / 20.0)

def d1_rms_dbfs(y: np.ndarray) -> float:
    if y is None or len(y) == 0:
        return -120.0
    rms = float(np.sqrt(np.mean(y.astype(np.float64) ** 2) + 1e-12))
    return 20.0 * math.log10(max(rms, 1e-12))

def d1_peak_limit(y: np.ndarray, peak_dbfs: float) -> np.ndarray:
    if y is None or len(y) == 0:
        return y
    peak = float(np.max(np.abs(y)))
    lim = d1_db_to_lin(peak_dbfs)
    if peak > lim and peak > 0:
        y = y * (lim / peak)
    return np.clip(y, -1.0, 1.0).astype(np.float32)

def d1_norm_rms_then_peak(y: np.ndarray) -> np.ndarray:
    if y is None or len(y) == 0:
        return y
    cur = d1_rms_dbfs(y)
    if cur < D1_MIN_RMS_DBFS:
        return d1_peak_limit(y.astype(np.float32), D1_PEAK_LIMIT_DBFS)
    gain_db = float(D1_TARGET_RMS_DBFS - cur)
    gain_db = float(np.clip(gain_db, -60.0, D1_MAX_GAIN_DB))
    y2 = (y.astype(np.float32) * d1_db_to_lin(gain_db)).astype(np.float32)
    return d1_peak_limit(y2, D1_PEAK_LIMIT_DBFS)

def d1_read_mono(path: str) -> Tuple[np.ndarray, int]:
    x, sr = sf.read(path, always_2d=True)
    x = x.mean(axis=1).astype(np.float32)
    if not np.isfinite(x).all():
        x = np.nan_to_num(x, nan=0.0, posinf=0.0, neginf=0.0).astype(np.float32)
    return x, int(sr)

def d1_resample_to_16k(x: np.ndarray, sr_in: int) -> np.ndarray:
    if sr_in == D1_SR:
        return x.astype(np.float32, copy=False)
    g = math.gcd(sr_in, D1_SR)
    up = D1_SR // g
    down = sr_in // g
    # Polyphase resampling in float64, then a single cast back to float32
    y = signal.resample_poly(x.astype(np.float64), up, down)
    return y.astype(np.float32)

def d1_float_to_pcm16_bytes(y: np.ndarray) -> bytes:
    y = np.clip(y, -1.0, 1.0)
    return (y * 32767.0).astype(np.int16).tobytes()

# Task inference from filename patterns
def d1_task_code(audio_path: str) -> str:
    # The task code is the second underscore-separated field of the filename stem
    # (stems shaped like <GROUP>_<TASK>_<NUMBER>). The former regex fallback was
    # unreachable: any stem it could match already has at least three fields.
    stem = os.path.splitext(os.path.basename(audio_path))[0]
    parts = stem.split("_")
    if len(parts) >= 3:
        return str(parts[1]).upper()
    return "UNKNOWN"

def d1_is_vowel(code: str) -> bool:
    return re.fullmatch(r"[AEIOU][1-3]", (code or "").upper()) is not None

SPONT_KEYS = ["ESPONT", "SPONT", "FREE", "MONOLOG", "MONOLO", "LIBRE", "DIALOG", "CONVERS"]
def d1_is_spont(code: str, path: str) -> bool:
    cu = (code or "").upper()
    pu = (path or "").upper()
    return any((k in cu) or (k in pu) for k in SPONT_KEYS)

def d1_task_type(code: str, path: str) -> str:
    if d1_is_vowel(code):
        return "vowel"
    if d1_is_spont(code, path):
        return "spontaneous"
    return "reading"

def d1_task_short(tt: str) -> str:
    tt = (tt or "").lower()
    if tt == "vowel": return "vowl"
    if tt == "reading": return "read"
    if tt == "spontaneous": return "spont"
    return "unk"

def d1_label(group_val: str) -> Tuple[str, int]:
    g = str(group_val).strip().lower()
    if ("pd" in g) or ("parkinson" in g):
        return "Parkinson", 1
    return "Healthy", 0

def d1_pd_hc(label_str: str) -> str:
    return "PD" if str(label_str).lower().startswith("parkinson") else "HC"

def d1_write_wav(path: str, y: np.ndarray, sr: int):
    os.makedirs(os.path.dirname(path), exist_ok=True)
    sf.write(path, np.clip(y, -1.0, 1.0).astype(np.float32), sr, subtype="PCM_16")

# -------------------------
# Voice segments on SOURCE timeline (after resample)
# Returns list of (start_sample, end_sample) segments
# -------------------------
def d1_vad_segments_source_timeline(y16: np.ndarray, sr: int) -> Optional[List[Tuple[int, int]]]:
    if not HAVE_WEBRTCVAD:
        return None
    if sr != 16000:
        return None

    frame_len = int(sr * (D1_FRAME_MS / 1000.0))
    if frame_len <= 0 or len(y16) < frame_len:
        return []

    n = len(y16)
    n_frames = n // frame_len
    if n_frames == 0:
        return []

    pcm = d1_float_to_pcm16_bytes(y16[:n_frames * frame_len])
    vad = webrtcvad.Vad(int(D1_VAD_MODE))

    def frame_bytes(i: int) -> bytes:
        s = i * frame_len * 2
        e = s + frame_len * 2
        return pcm[s:e]

    flags = [vad.is_speech(frame_bytes(i), 16000) for i in range(n_frames)]

    segs = []
    i = 0
    while i < n_frames:
        if not flags[i]:
            i += 1
            continue
        s0 = i
        while i < n_frames and flags[i]:
            i += 1
        e0 = i
        segs.append((s0 * frame_len, e0 * frame_len))

    max_gap = int(sr * (D1_MAX_GAP_MS / 1000.0))
    merged = []
    for s, e in segs:
        if not merged:
            merged.append([s, e])
        else:
            ps, pe = merged[-1]
            if s - pe <= max_gap:
                merged[-1][1] = max(pe, e)
            else:
                merged.append([s, e])

    pad = int(round(sr * D1_PAD_SEC))
    trail = int(round(sr * D1_TRAIL_PAD_SEC))

    out = []
    for s, e in merged:
        s2 = max(0, s - pad)
        e2 = min(n, e + pad + trail)
        if e2 > s2:
            out.append((int(s2), int(e2)))
    return out

def d1_energy_segments_source_timeline(y16: np.ndarray, sr: int) -> List[Tuple[int, int]]:
    frame = int(sr * 0.02)
    hop = frame
    if frame <= 0 or len(y16) < frame:
        return []

    eng = []
    idx = []
    for i in range(0, len(y16) - frame + 1, hop):
        w = y16[i:i+frame]
        eng.append(float(np.mean(w*w)))
        idx.append(i)
    eng = np.array(eng, dtype=np.float32)
    thr = float(np.percentile(eng, 25)) * 2.5
    thr = max(thr, 1e-8)
    keep = eng > thr

    segs = []
    on = False
    s0 = 0
    for k, flag in enumerate(keep):
        if flag and not on:
            on = True
            s0 = idx[k]
        elif (not flag) and on:
            on = False
            segs.append((s0, idx[k] + frame))
    if on and idx:
        segs.append((s0, idx[-1] + frame))

    max_gap = int(sr * (D1_MAX_GAP_MS / 1000.0))
    merged = []
    for s, e in sorted(segs):
        if not merged:
            merged.append([s, e])
        else:
            ps, pe = merged[-1]
            if s - pe <= max_gap:
                merged[-1][1] = max(pe, e)
            else:
                merged.append([s, e])

    pad = int(round(sr * D1_PAD_SEC))
    trail = int(round(sr * D1_TRAIL_PAD_SEC))
    out = []
    for s, e in merged:
        s2 = max(0, s - pad)
        e2 = min(len(y16), e + pad + trail)
        if e2 > s2:
            out.append((int(s2), int(e2)))
    return out

# -------------------------
# Single-clip selection per source file
# - Picks the longest voiced segment (ties go to earliest)
# - Builds exactly one clip based on task type
# -------------------------
def d1_force_length(x: np.ndarray, n: int) -> np.ndarray:
    if x is None:
        return np.zeros((n,), dtype=np.float32)
    if len(x) == n:
        return x.astype(np.float32)
    if len(x) > n:
        return x[:n].astype(np.float32)
    y = np.zeros((n,), dtype=np.float32)
    y[:len(x)] = x.astype(np.float32)
    return y

def d1_pick_longest_segment(segs: List[Tuple[int, int]]) -> Optional[Tuple[int, int]]:
    if not segs:
        return None
    segs2 = sorted(segs, key=lambda t: (-(t[1]-t[0]), t[0]))
    return segs2[0]

def d1_make_single_clip_from_source_segments(
    y16: np.ndarray,
    sr: int,
    segs: List[Tuple[int, int]],
    task_type: str
) -> Optional[Dict]:
    """
    Returns a single clip dict:
      {audio, clip_start_sec, clip_end_sec, duration_sec}
    clip_start/end are SOURCE timeline (post-resample).
    """
    if not segs:
        return None

    best = d1_pick_longest_segment(segs)
    if best is None:
        return None
    s, e = best
    seg = y16[s:e].astype(np.float32)
    if len(seg) <= 0:
        return None

    if task_type == "vowel":
        L = int(round(sr * D1_VOWEL_SEC))
        if len(seg) >= L:
            mid = len(seg) // 2
            a0 = max(0, mid - L // 2)
            a1 = a0 + L
            audio = seg[a0:a1].astype(np.float32)
            st = float(s + a0) / sr
            en = float(s + a1) / sr
        else:
            audio = d1_force_length(seg, L)
            st = float(s) / sr
            en = float(s + L) / sr
        return {
            "audio": np.clip(audio, -1.0, 1.0).astype(np.float32),
            "clip_start_sec": float(st),
            "clip_end_sec": float(en),
            "duration_sec": float(len(audio) / sr),
        }

    L = int(round(sr * D1_OTHER_SEC))
    if len(seg) >= L:
        audio = seg[:L].astype(np.float32)
        st = float(s) / sr
        en = float(s + L) / sr
        return {
            "audio": np.clip(audio, -1.0, 1.0).astype(np.float32),
            "clip_start_sec": float(st),
            "clip_end_sec": float(en),
            "duration_sec": float(L / sr),
        }

    if len(seg) < int(sr * D1_MIN_KEEP_SEC):
        return None
    st = float(s) / sr
    en = float(s + len(seg)) / sr
    return {
        "audio": np.clip(seg, -1.0, 1.0).astype(np.float32),
        "clip_start_sec": float(st),
        "clip_end_sec": float(en),
        "duration_sec": float(len(seg) / sr),
    }

# -------------------------
# Read metadata and build the source table
# Inputs: metadata_hc.csv + metadata_pd.csv
# Output: one combined table with audio paths, labels, and tasks
# -------------------------
hc_df = pd.read_csv(D1_META_HC)
pd_df = pd.read_csv(D1_META_PD)
d1_all = pd.concat([hc_df, pd_df], ignore_index=True).copy()

req = ["ID", "Group", "Audio"]
miss = [c for c in req if c not in d1_all.columns]
if miss:
    raise ValueError(f"D1 metadata missing required columns {miss}. Found: {list(d1_all.columns)}")

# Optional columns: use if present, otherwise leave as NaN
col_age = None
col_sex = None
for c in d1_all.columns:
    cl = str(c).strip().lower()
    if col_age is None and cl in {"age", "edad"}:
        col_age = c
    if col_sex is None and cl in {"sex", "gender", "sexo"}:
        col_sex = c

use_cols = req + ([col_age] if col_age else []) + ([col_sex] if col_sex else [])
d1 = d1_all[use_cols].copy()

d1["speaker_id"] = d1["ID"].astype(str)
d1["audio_path"] = d1["Audio"].apply(lambda p: os.path.join(D1_DIR, str(p).replace("\\", "/")))

lbl = d1["Group"].apply(d1_label)
d1["label_str"] = lbl.apply(lambda t: t[0])
d1["label_num"] = lbl.apply(lambda t: t[1]).astype(int)

d1["age"] = d1[col_age] if col_age else np.nan
d1["sex"] = d1[col_sex] if col_sex else np.nan

d1["exists"] = d1["audio_path"].apply(os.path.exists)
missing_audio = d1.loc[~d1["exists"], ["speaker_id","Group","Audio","audio_path"]].copy()

print("\nRows total:", len(d1))
print("Rows existing audio:", int(d1["exists"].sum()))
print("Rows missing audio:", int((~d1["exists"]).sum()))

d1 = d1.loc[d1["exists"]].copy()

# Infer task type from filename or path (vowel vs reading vs spontaneous)
d1["task_code"] = d1["audio_path"].apply(d1_task_code)
d1["task_type"] = d1.apply(lambda r: d1_task_type(r["task_code"], r["audio_path"]), axis=1)
d1["task"] = d1["task_type"].apply(d1_task_short)

print("\nTask counts (rows):")
print(d1["task"].value_counts(dropna=False))
print("\nLabel counts (rows):")
print(d1["label_str"].value_counts(dropna=False))

if len(d1) == 0:
    raise RuntimeError("No usable D1 rows after existence filtering.")

# -------------------------
# Candidate creation: 1 candidate WAV per source (written immediately)
# Output: cand_df table + warnings_df log
# -------------------------
MANIFEST_COLS = [
    "split","dataset","task","speaker_id","sample_id",
    "label_str","label_num","age","sex",
    "speaker_key_rel",
    "clip_path","duration_sec","source_path",
    "clip_start_sec","clip_end_sec","sr_hz","channels",
    "clip_is_contiguous",
]

cand_rows: List[Dict] = []
warn_rows: List[Dict] = []

cand_counter = 0

pbar = tqdm(d1.itertuples(index=False), total=len(d1), desc="D1 preprocess (1 clip per source; write candidates)", dynamic_ncols=True)

for r in pbar:
    src = r.audio_path
    try:
        x, sr0 = d1_read_mono(src)
        y = d1_resample_to_16k(x, sr0)
        y = d1_norm_rms_then_peak(y)

        segs = d1_vad_segments_source_timeline(y, D1_SR)
        vad_used = "webrtcvad"
        if segs is None:
            segs = d1_energy_segments_source_timeline(y, D1_SR)
            vad_used = "energy"

        if not segs:
            warn_rows.append({
                "dataset":"D1","speaker_id":str(r.speaker_id),"source_path":src,
                "warning_type":"no_vad_segments","detail":f"vad_used={vad_used}"
            })
            continue

        clip = d1_make_single_clip_from_source_segments(y, D1_SR, segs, r.task_type)
        if clip is None:
            warn_rows.append({
                "dataset":"D1","speaker_id":str(r.speaker_id),"source_path":src,
                "warning_type":"no_clip_selected","detail":f"vad_used={vad_used}"
            })
            continue

        task5 = d1_task_short(r.task_type)

        cand_counter += 1
        cand_name = d1_safe(f"CAND_{cand_counter:08d}.wav")
        cand_path = os.path.join(D1_CAND_DIR, cand_name)

        # Write candidate clip immediately
        d1_write_wav(cand_path, clip["audio"], D1_SR)

        cand_rows.append({
            "dataset": "D1",
            "task": task5,
            "speaker_id": str(r.speaker_id),
            "sample_id": os.path.basename(src),
            "label_str": r.label_str,
            "label_num": int(r.label_num),
            "age": r.age if pd.notna(r.age) else np.nan,
            "sex": r.sex if pd.notna(r.sex) else np.nan,
            "speaker_key_rel": np.nan,      # true null for D1
            "clip_path_cand": cand_path,    # candidate on disk
            "duration_sec": float(clip["duration_sec"]),
            "source_path": src,
            "clip_start_sec": float(clip["clip_start_sec"]),
            "clip_end_sec": float(clip["clip_end_sec"]),
            "sr_hz": int(D1_SR),
            "channels": 1,
            "clip_is_contiguous": True,
            "__vad_used__": vad_used,
        })

    except Exception as e:
        warn_rows.append({
            "dataset":"D1","speaker_id":str(getattr(r,"speaker_id","")),"source_path":src,
            "warning_type":"preprocess_error","detail":repr(e)
        })

cand_df = pd.DataFrame(cand_rows)
warnings_df = pd.DataFrame(warn_rows)

print("\nD1 candidates written:", int(len(cand_df)))
if len(cand_df) == 0:
    raise RuntimeError("No D1 clips produced. Check VAD settings and audio paths.")

# -------------------------
# Cap: limit clips per (speaker_id, task)
# Unkept candidates are deleted to save space
# -------------------------
def d1_cap_manifest_keep_set(df: pd.DataFrame, max_k: int, seed: int) -> Tuple[pd.DataFrame, set]:
    rng = np.random.default_rng(seed)
    kept_idx = []
    for (spk, task), g in df.groupby(["speaker_id", "task"], sort=False):
        idxs = g.index.to_numpy()
        if len(idxs) <= max_k:
            kept_idx.extend(idxs.tolist())
        else:
            chosen = rng.choice(idxs, size=max_k, replace=False)
            kept_idx.extend(chosen.tolist())
    kept_idx = sorted(set(kept_idx))
    return df.loc[kept_idx].reset_index(drop=True), set(kept_idx)

cand_df_capped, keep_set = d1_cap_manifest_keep_set(cand_df, D1_MAX_CLIPS_PER_SPK_PER_TASK, D1_RANDOM_SEED)
print("D1 clips after cap:", int(len(cand_df_capped)))

to_delete = cand_df.loc[~cand_df.index.isin(list(keep_set)), "clip_path_cand"].tolist()
deleted = 0
for p in to_delete:
    try:
        if os.path.exists(p):
            os.remove(p)
            deleted += 1
    except Exception as e:
        warnings_df = pd.concat([warnings_df, pd.DataFrame([{
            "dataset": "D1",
            "speaker_id": "",
            "source_path": "",
            "warning_type": "candidate_delete_failed",
            "detail": f"{p} :: {repr(e)}"
        }])], ignore_index=True)

print("Deleted unkept candidates:", deleted)
cand_df = cand_df_capped

# -------------------------
# Speaker-level split (label-stratified)
# Output: a split assignment per speaker, merged onto cand_df
# -------------------------
def d1_split_speakers(df: pd.DataFrame, fracs=(0.70, 0.15, 0.15), seed=1337) -> pd.DataFrame:
    rng = np.random.default_rng(seed)
    spk_df = df[["speaker_id", "label_num", "label_str"]].drop_duplicates().copy()

    rows = []
    for lbl in [0, 1]:
        spks = spk_df[spk_df["label_num"] == lbl]["speaker_id"].tolist()
        rng.shuffle(spks)
        n = len(spks)
        n_train = int(round(fracs[0] * n))
        n_val = int(round(fracs[1] * n))

        train_spks = spks[:n_train]
        val_spks = spks[n_train:n_train + n_val]
        test_spks = spks[n_train + n_val:]

        rows += [{"speaker_id": s, "split": "train"} for s in train_spks]
        rows += [{"speaker_id": s, "split": "val"} for s in val_spks]
        rows += [{"speaker_id": s, "split": "test"} for s in test_spks]

    return pd.DataFrame(rows)

spk_split = d1_split_speakers(cand_df, (D1_TRAIN_PCT, D1_VAL_PCT, D1_TEST_PCT), D1_RANDOM_SEED)

cand_df = cand_df.drop(columns=["split"], errors="ignore")
cand_df = cand_df.merge(spk_split, on="speaker_id", how="left", validate="many_to_one")

if cand_df["split"].isna().any():
    ex = cand_df.loc[cand_df["split"].isna(),"speaker_id"].drop_duplicates().head(10).tolist()
    raise RuntimeError(f"Some speakers did not get split. Example: {ex}")

# -------------------------
# Finalize: move kept candidates into clips/<split>/ and build manifest_all
# -------------------------
import shutil

global_counter = 0
final_rows: List[Dict] = []

pbar2 = tqdm(cand_df.itertuples(index=False), total=len(cand_df), desc="D1 finalize (move kept)", dynamic_ncols=True)

for r in pbar2:
    global_counter += 1

    label_tag = d1_pd_hc(r.label_str)
    task5 = str(r.task)
    spk = str(r.speaker_id)

    out_name = d1_safe(f"D1_{label_tag}_{spk}_{task5}_{global_counter:06d}.wav")
    out_path = os.path.join(D1_CLIPS_DIR, r.split, out_name)

    cand_path = getattr(r, "clip_path_cand")
    if not os.path.exists(cand_path):
        warnings_df = pd.concat([warnings_df, pd.DataFrame([{
            "dataset":"D1","speaker_id":spk,"source_path":r.source_path,
            "warning_type":"missing_candidate_file","detail":cand_path
        }])], ignore_index=True)
        continue

    shutil.move(cand_path, out_path)

    final_rows.append({
        "split": r.split,
        "dataset": "D1",
        "task": task5,
        "speaker_id": spk,
        "sample_id": r.sample_id,
        "label_str": r.label_str,
        "label_num": int(r.label_num),
        "age": r.age if pd.notna(r.age) else np.nan,
        "sex": r.sex if pd.notna(r.sex) else np.nan,
        "speaker_key_rel": np.nan,
        "clip_path": out_path,
        "duration_sec": float(r.duration_sec),
        "source_path": r.source_path,
        "clip_start_sec": float(r.clip_start_sec),
        "clip_end_sec": float(r.clip_end_sec),
        "sr_hz": int(r.sr_hz),
        "channels": 1,
        "clip_is_contiguous": True,
    })

# Cleanup: remove any leftover candidate files (best effort)
try:
    if os.path.isdir(D1_CAND_DIR):
        leftovers = list(os.scandir(D1_CAND_DIR))
        for ent in leftovers:
            try:
                os.remove(ent.path)
            except Exception:
                pass
        try:
            os.rmdir(D1_CAND_DIR)
        except Exception:
            pass
except Exception:
    pass

manifest_df = pd.DataFrame(final_rows)

# Ensure standardized schema + column order
for c in MANIFEST_COLS:
    if c not in manifest_df.columns:
        manifest_df[c] = np.nan
manifest_df = manifest_df[MANIFEST_COLS].copy()

# -------------------------
# Save outputs
# - manifests/manifest_all.csv (+ per-split manifests)
# - logs/preprocess_warnings.csv and logs/dataset_summary.json
# - config/run_config.json
# -------------------------
manifest_all_path = os.path.join(D1_MANIFEST_DIR, "manifest_all.csv")
manifest_df.to_csv(manifest_all_path, index=False)

for sp in ["train","val","test"]:
    p = os.path.join(D1_MANIFEST_DIR, f"manifest_{sp}.csv")
    manifest_df.loc[manifest_df["split"] == sp].to_csv(p, index=False)

warnings_path = os.path.join(D1_LOGS_DIR, "preprocess_warnings.csv")
warnings_df.to_csv(warnings_path, index=False)

summary = {
    "dataset": "D1",
    "source_root": D1_DIR,
    "metadata_dir": D1_METADATA_DIR,
    "metadata_hc": D1_META_HC,
    "metadata_pd": D1_META_PD,
    "sr_hz": int(D1_SR),
    "webrtcvad_available": bool(HAVE_WEBRTCVAD),
    "scipy_available": bool(HAVE_SCIPY),
    "n_source_rows_used": int(len(d1)),
    "n_unique_speakers_source": int(d1["speaker_id"].nunique()) if len(d1) else 0,
    "label_counts_source_rows": d1["label_str"].value_counts(dropna=False).to_dict() if len(d1) else {},
    "n_clips_total": int(len(manifest_df)),
    "label_counts_clips": manifest_df["label_str"].value_counts(dropna=False).to_dict() if len(manifest_df) else {},
    "split_counts_clips": manifest_df["split"].value_counts(dropna=False).to_dict() if len(manifest_df) else {},
    "task_counts_clips": manifest_df["task"].value_counts(dropna=False).to_dict() if len(manifest_df) else {},
    "n_warnings": int(len(warnings_df)),
    "n_missing_audio_rows_in_metadata": int(len(missing_audio)),
    "notes": (
        "At most 1 clip per source file. Candidate written immediately to clips/_candidates, "
        "then cap/split and moved to clips/<split>. "
        "Selection uses the longest voiced segment on source timeline."
    ),
}
summary_path = os.path.join(D1_LOGS_DIR, "dataset_summary.json")
with open(summary_path, "w", encoding="utf-8") as f:
    json.dump(summary, f, indent=2)

run_cfg = {
    "dataset": "D1",
    "paths": {
        "dataset_dir": D1_DIR,
        "metadata_dir": D1_METADATA_DIR,
        "metadata_hc": D1_META_HC,
        "metadata_pd": D1_META_PD,
        "out_root": D1_OUT_ROOT,
        "clips_dir": D1_CLIPS_DIR,
        "candidate_dir": D1_CAND_DIR,
        "manifest_all": manifest_all_path,
        "warnings_csv": warnings_path,
        "summary_json": summary_path,
    },
    "folder_structure": "clips/<split>/ (flat) with temporary clips/_candidates during run",
    "filename_format": "D1_{HC|PD}_{speaker_id}_{task<=5}_{global_index:06d}.wav",
    "manifest_schema_order": MANIFEST_COLS,
    "cap_policy": {"groupby": ["speaker_id", "task"], "max_per_group": int(D1_MAX_CLIPS_PER_SPK_PER_TASK)},
    "clip_boundary_policy": "clip_start_sec/clip_end_sec are source-timeline (post-resample). clip_is_contiguous=True.",
    "one_clip_per_source_file": True,
    "one_clip_selection_rule": "Pick longest voiced segment; vowel=2s centered/padded, others=8s from start or true duration if shorter.",
    "notes": "Manual RMS gain + peak limiting; no pyloudnorm.",
}
run_cfg_path = os.path.join(D1_CONFIG_DIR, "run_config.json")
with open(run_cfg_path, "w", encoding="utf-8") as f:
    json.dump(run_cfg, f, indent=2)

print("\nDONE: D1 preprocessing")
print("Manifest:", manifest_all_path)
print("Warnings:", warnings_path)
print("Summary:", summary_path)
print("Run config:", run_cfg_path)
print("Clips written:", int(len(manifest_df)))
print("Missing-audio metadata rows (not processed):", int(len(missing_audio)))

print("\nSanity checks:")
print("Unique SR:", sorted(manifest_df["sr_hz"].unique().tolist()) if len(manifest_df) else [])
print("Unique channels:", sorted(manifest_df["channels"].unique().tolist()) if len(manifest_df) else [])

# Speaker split uniqueness check (speaker must appear in only one split)
spk_split_chk = manifest_df[["speaker_id", "split"]].drop_duplicates()
dup = spk_split_chk.groupby("speaker_id")["split"].nunique()
bad = dup[dup > 1]
print("Speakers appearing in multiple splits:", int(len(bad)))
if len(bad) > 0:
    print("Example bad speaker ids:", bad.index.tolist()[:10])
    raise RuntimeError("Speaker appears in more than one split.")

The following cell runs the full D2 (EWA-DB, Slovak) preprocessing in one place and is written so it can be safely rerun if Google Drive disconnects during execution. It reads the dataset metadata tables (`FILES.TSV` and `SPEAKERS.TSV`), filters the records to include only usable audio (Healthy or Parkinson diagnosis, inclusion criteria met, publish agreement true when present, and an available audio file), and then checks that each referenced audio file actually exists. Any missing files are recorded in `missing_paths.csv`.
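
As a rough illustration of the filtering step, the sketch below applies the same kind of row filters to a small made-up table. The column names follow `FILES.TSV` as used in the cell; the rows and the `is_true` helper are hypothetical and exist only for this example.

```python
import pandas as pd

# Hypothetical rows; only the column names mirror FILES.TSV.
toy = pd.DataFrame({
    "SPEAKER_CODE": ["S1", "S2", "S3", "S4"],
    "DIAGNOSIS": ["Healthy", "Parkinson", "Alzheimer", "Parkinson"],
    "INCLUSIVE_CRITERIA": ["true", "TRUE", "true", "false"],
    "AUDIOFILE": ["a.wav", "<not-available>", "c.wav", "d.wav"],
})

def is_true(x) -> bool:
    # Hypothetical helper: treat common true-ish strings as True.
    return str(x).strip().lower() in {"true", "1", "yes", "y", "t"}

keep = (
    toy["DIAGNOSIS"].str.strip().isin(["Healthy", "Parkinson"])   # HC and PD only
    & toy["INCLUSIVE_CRITERIA"].apply(is_true)                    # inclusion criteria met
    & (toy["AUDIOFILE"].str.strip() != "<not-available>")         # audio is listed
)
usable = toy.loc[keep].reset_index(drop=True)
print(usable["SPEAKER_CODE"].tolist())
```

In this toy table only `S1` passes all three filters. The real cell additionally checks that each listed file exists on disk.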

For each remaining source audio file, the cell creates **exactly one** processed clip. The audio is converted to mono, resampled to 16 kHz, and leveled to a consistent loudness using simple RMS-based gain (target -20 dBFS) with a peak limit (-1 dBFS). Speech regions are detected using voice activity detection (WebRTC VAD when available, otherwise an energy-based fallback), and the detected speech segments are stitched together into a single voiced signal. For reading-style clips, a light trim is applied to reduce leading silence. Clip selection follows the dataset rules: if `PICTURE == vokal`, a **2.0-second vowel clip** is produced (center-cropped if the voiced signal is longer, padded with silence if it is shorter); otherwise, a **reading-style clip** of up to **8.0 seconds** is created. If the voiced signal is longer than 8 seconds, five evenly spaced candidate windows are compared and the 8-second window whose first 0.8 seconds carries the most energy is kept, which reduces early silence; if it is shorter, the true duration is kept with no padding. The start and end times written to the manifest refer to the **voiced, stitched timeline**, not the original raw audio.
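
The clip-length rules can be sketched on synthetic arrays as follows. This is a simplified illustration, not the cell's implementation: `pick_clip` is a hypothetical name, and its reading branch simply crops from the start instead of searching for the best window.

```python
import numpy as np

SR = 16000  # sampling rate used after resampling

def pick_clip(voiced: np.ndarray, is_vowel: bool) -> np.ndarray:
    # Hypothetical simplification of the selection rules described above.
    if is_vowel:
        target = 2 * SR
        if len(voiced) >= target:
            start = (len(voiced) - target) // 2    # center crop to 2.0 s
            return voiced[start:start + target]
        out = np.zeros(target, dtype=np.float32)   # pad short vowels with silence
        out[:len(voiced)] = voiced
        return out
    return voiced[: 8 * SR]                        # reading: at most 8.0 s, never padded

short_vowel = np.ones(1 * SR, dtype=np.float32)    # 1.0 s of voiced signal
long_read = np.ones(10 * SR, dtype=np.float32)     # 10.0 s
short_read = np.ones(3 * SR, dtype=np.float32)     # 3.0 s

durations = [
    len(pick_clip(x, is_vowel)) / SR
    for x, is_vowel in [(short_vowel, True), (long_read, False), (short_read, False)]
]
print(durations)
```

The short vowel is padded up to 2.0 s, the long reading clip is cut to 8.0 s, and the short reading clip keeps its true 3.0 s duration.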

To support restarts, candidate clips are written immediately to `clips/_candidates/` using deterministic filenames that remain the same across reruns. If a candidate file already exists, that source is skipped instead of being regenerated. A local temporary cache is used by default, where files are first written under `/content/…` and then copied to Google Drive with retries, which helps if Drive becomes unstable. Progress is saved periodically to `logs/candidates_checkpoint.csv`, and any issues such as read errors, missing speech segments, Drive interruptions, or trimming events are logged in `logs/preprocess_warnings.csv`.
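
The resume behaviour rests on two ideas: filenames derived only from stable row fields, and a skip when the file is already present. A minimal sketch, assuming a hypothetical naming scheme and an empty placeholder file in place of the real WAV write:

```python
import os
import tempfile

def candidate_name(speaker_id: str, task: str, source_basename: str) -> str:
    # Hypothetical naming scheme: built only from stable row fields,
    # so a rerun maps the same source to the same candidate file.
    return f"CAND_{speaker_id}_{task}_{source_basename}.wav"

def process_once(out_dir: str, rows) -> int:
    written = 0
    for spk, task, base in rows:
        path = os.path.join(out_dir, candidate_name(spk, task, base))
        if os.path.exists(path):
            continue               # produced by an earlier run: skip
        with open(path, "wb") as f:
            f.write(b"")           # placeholder for the real WAV write
        written += 1
    return written

rows = [("S1", "read", "a"), ("S1", "vowl", "b"), ("S2", "read", "c")]
with tempfile.TemporaryDirectory() as tmp:
    first = process_once(tmp, rows)    # first pass writes every candidate
    second = process_once(tmp, rows)   # rerun finds them and writes nothing
print(first, second)
```

An existence check alone cannot tell a complete file from one cut off mid-write. Writing to a local path first and copying afterwards, as the cell does, narrows that window.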

After all candidate clips are created, the cell enforces a limit of **8 clips per (speaker, task)** and removes any extra candidate files. Speakers are then assigned to train, validation, and test splits at the **speaker level**, so that no speaker appears in more than one split. The 70/15/15 ratio is applied separately within each label (stratified by label) to keep the class balance similar across splits. The selected clips are moved into `clips/train`, `clips/val`, and `clips/test` with standardized filenames, and the final outputs are written to the run folder, including `manifests/manifest_all.csv`, `config/run_config.json`, and `logs/dataset_summary.json`, along with the warning and checkpoint logs.
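
The cap and the speaker-level split can be illustrated on a made-up clip table. The speaker IDs, counts, and the shuffle-then-`head` cap below are assumptions for this sketch; the cell itself may select kept clips differently.

```python
import numpy as np
import pandas as pd

rng = np.random.default_rng(1337)

# Hypothetical table: 20 speakers (10 per label), 12 clips each, one task.
clips = pd.DataFrame([
    {"speaker_id": f"S{i:02d}", "label_num": i % 2, "task": "read", "clip": k}
    for i in range(20) for k in range(12)
])

# Cap: shuffle rows, then keep at most 8 per (speaker, task).
capped = (
    clips.sample(frac=1.0, random_state=1337)
         .groupby(["speaker_id", "task"], sort=False)
         .head(8)
)

# Split speakers (not clips) 70/15/15 within each label.
split_of = {}
speakers = capped[["speaker_id", "label_num"]].drop_duplicates()
for _, grp in speakers.groupby("label_num"):
    ids = grp["speaker_id"].tolist()
    rng.shuffle(ids)
    n_train = int(round(0.70 * len(ids)))
    n_val = int(round(0.15 * len(ids)))
    for j, s in enumerate(ids):
        split_of[s] = "train" if j < n_train else ("val" if j < n_train + n_val else "test")

capped = capped.assign(split=capped["speaker_id"].map(split_of))
print(capped.groupby("split")["speaker_id"].nunique().to_dict())
```

Because the split is assigned per speaker and then mapped onto clips, every clip from a given speaker lands in the same split by construction.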

In [None]:
# D2 (EWA-DB) Preprocessing v1 — Resume-Safe, 1 Clip per Source, Stream Write
# Inputs: FILES.TSV + SPEAKERS.TSV + referenced audio files under the dataset folder
# Outputs: clips/<split>/ WAVs, manifests/manifest_all.csv, logs/preprocess_warnings.csv + dataset_summary.json (+ checkpoints)

# =========================
# D2 Preprocessing v1 (EWA-DB) — SINGLE COMPLETE CELL
# RESUME-SAFE + STREAM-WRITE + 1CLIP/SRC + DIAGNOSTICS FOR DRIVE DISCONNECT
# =========================
# What this cell does:
# - Reads metadata tables (FILES.TSV, SPEAKERS.TSV), filters usable rows, verifies audio paths
# - Creates exactly 1 processed candidate clip per source audio (resume-safe across restarts)
# - Caps clips per (speaker, task), splits speakers into train/val/test
# - Moves kept clips into final folders and writes manifests, logs, and config

import os
import re
import json
import math
import random
import shutil
import time
from pathlib import Path
from typing import List, Tuple, Dict, Optional

import numpy as np
import pandas as pd
import soundfile as sf
from tqdm.auto import tqdm

# -------------------------
# Drive mount and environment
# Inputs: Colab runtime + Google Drive
# Outputs: Drive mounted at /content/drive
# -------------------------
if not os.path.isdir("/content/drive/MyDrive"):
    from google.colab import drive  # type: ignore
    drive.mount("/content/drive")

# -------------------------
# Package checks (install if missing)
# Inputs: runtime environment
# Outputs: webrtcvad and (optional) scipy resampler available
# -------------------------
def _d2_try_import_webrtcvad():
    try:
        import webrtcvad  # type: ignore
        return webrtcvad, True
    except Exception:
        return None, False

webrtcvad, D2_HAVE_WEBRTCVAD = _d2_try_import_webrtcvad()
if not D2_HAVE_WEBRTCVAD:
    !pip -q install webrtcvad
    webrtcvad, D2_HAVE_WEBRTCVAD = _d2_try_import_webrtcvad()

try:
    from scipy.signal import resample_poly  # type: ignore
    D2_HAVE_SCIPY = True
except Exception:
    D2_HAVE_SCIPY = False

# -------------------------
# Dataset and output paths
# Inputs: dataset folder + TSV filenames
# Outputs: run folder with clips/, manifests/, config/, logs/ created
# -------------------------
D2_PROJECT_DIR  = "/content/drive/MyDrive/AI_PD_Project"
D2_DATASETS_DIR = f"{D2_PROJECT_DIR}/Datasets"
D2_EWA_DIR      = f"{D2_DATASETS_DIR}/D2-Slovak (EWA-DB)/EWA-DB"

D2_FILES_TSV    = f"{D2_EWA_DIR}/FILES.TSV"
D2_SPEAKERS_TSV = f"{D2_EWA_DIR}/SPEAKERS.TSV"
D2_README_TXT   = f"{D2_EWA_DIR}/README.TXT"

# Output root (Drive)
D2_OUT_ROOT     = f"{D2_EWA_DIR}/preprocessed_v1"

# Run isolation
# - Keep the same RUN_NAME to resume the same run after a disconnect
# - Change RUN_NAME to start a clean, separate run folder
D2_RUN_NAME     = "run_v1"  # change to "run_v2" if you want a completely fresh run folder

D2_RUN_ROOT     = f"{D2_OUT_ROOT}/runs/{D2_RUN_NAME}"
D2_CLIPS_DIR    = f"{D2_RUN_ROOT}/clips"
D2_MANIFEST_DIR = f"{D2_RUN_ROOT}/manifests"
D2_CONFIG_DIR   = f"{D2_RUN_ROOT}/config"
D2_LOGS_DIR     = f"{D2_RUN_ROOT}/logs"

# Candidate clips (temporary stage)
D2_CAND_DIR_DRIVE = f"{D2_CLIPS_DIR}/_candidates"

# Optional local cache to reduce Drive write failures
# - Writes candidate WAVs locally first, then copies to Drive
D2_USE_LOCAL_CAND_CACHE = True
D2_CAND_DIR_LOCAL = f"/content/d2_candidates_cache/{D2_RUN_NAME}"

# Reset is destructive: deletes the whole run folder
D2_RESET_RUN = False

# Create required folders
for p in [D2_RUN_ROOT, D2_CLIPS_DIR, D2_MANIFEST_DIR, D2_CONFIG_DIR, D2_LOGS_DIR]:
    os.makedirs(p, exist_ok=True)
for sp in ["train", "val", "test"]:
    os.makedirs(os.path.join(D2_CLIPS_DIR, sp), exist_ok=True)

os.makedirs(D2_CAND_DIR_DRIVE, exist_ok=True)
if D2_USE_LOCAL_CAND_CACHE:
    os.makedirs(D2_CAND_DIR_LOCAL, exist_ok=True)

# Reset behavior (explicit only)
if D2_RESET_RUN:
    print("D2_RESET_RUN=True -> deleting run folder:", D2_RUN_ROOT)
    shutil.rmtree(D2_RUN_ROOT, ignore_errors=True)
    for p in [D2_RUN_ROOT, D2_CLIPS_DIR, D2_MANIFEST_DIR, D2_CONFIG_DIR, D2_LOGS_DIR]:
        os.makedirs(p, exist_ok=True)
    for sp in ["train", "val", "test"]:
        os.makedirs(os.path.join(D2_CLIPS_DIR, sp), exist_ok=True)
    os.makedirs(D2_CAND_DIR_DRIVE, exist_ok=True)
    if D2_USE_LOCAL_CAND_CACHE:
        shutil.rmtree(D2_CAND_DIR_LOCAL, ignore_errors=True)
        os.makedirs(D2_CAND_DIR_LOCAL, exist_ok=True)

# -------------------------
# Quick checks (fail early if dataset is missing)
# Inputs: expected dataset files and folders
# Outputs: printed status; raises if critical files are missing
# -------------------------
print("D2_EWA_DIR exists?", os.path.exists(D2_EWA_DIR))
print("FILES.TSV exists?", os.path.exists(D2_FILES_TSV))
print("SPEAKERS.TSV exists?", os.path.exists(D2_SPEAKERS_TSV))
print("README.TXT exists?", os.path.exists(D2_README_TXT))
print("webrtcvad available?", D2_HAVE_WEBRTCVAD)
print("scipy available?", D2_HAVE_SCIPY)
print("Run root:", D2_RUN_ROOT)
print("Candidate dir (Drive):", D2_CAND_DIR_DRIVE)
print("Using local candidate cache?", D2_USE_LOCAL_CAND_CACHE)
if D2_USE_LOCAL_CAND_CACHE:
    print("Candidate dir (Local):", D2_CAND_DIR_LOCAL)

if not os.path.exists(D2_EWA_DIR):
    raise FileNotFoundError(f"D2_EWA_DIR not found: {D2_EWA_DIR}")
if not os.path.exists(D2_FILES_TSV) or not os.path.exists(D2_SPEAKERS_TSV):
    raise FileNotFoundError(f"Missing FILES.TSV or SPEAKERS.TSV under: {D2_EWA_DIR}")

# -------------------------
# Processing settings
# Inputs: constants (sampling rate, clip lengths, split fractions)
# Outputs: consistent preprocessing behavior across runs
# -------------------------
D2_SR = 16000
D2_RANDOM_SEED = 1337
random.seed(D2_RANDOM_SEED)
np.random.seed(D2_RANDOM_SEED)

# Loudness leveling (simple RMS target + peak limiting)
D2_TARGET_RMS_DBFS = -20.0
D2_PEAK_LIMIT_DBFS = -1.0
D2_MIN_RMS_DBFS    = -60.0
D2_MAX_GAIN_DB     = 18.0

# Voice activity detection (used only to build a voiced-only signal)
D2_VAD_MODE       = 2
D2_FRAME_MS       = 30
D2_MIN_SPEECH_MS  = 200
D2_MERGE_GAP_MS   = 200
D2_PAD_SEC        = 0.25
D2_MIN_KEEP_SEC   = 0.30

# Clip rules (fixed)
# - vowel: 2.0s, pad if short
# - reading: 8.0s if long enough else keep true duration
D2_VOWEL_SEC = 2.0
D2_READ_SEC  = 8.0

# Cap and split
D2_MAX_CLIPS_PER_SPK_TASK = 8
D2_TRAIN_PCT, D2_VAL_PCT, D2_TEST_PCT = 0.70, 0.15, 0.15

# Audio existence checking
D2_DO_EXISTENCE_CHECK = True
D2_EXIST_CHECK_UNIQUE_ONLY = True

# Robust writing (retries help with transient Drive errors)
D2_WRITE_RETRIES = 4
D2_WRITE_SLEEP   = 0.75

# Progress checkpoint (diagnostic)
D2_CHECKPOINT_EVERY = 500  # rows
D2_CAND_CHECKPOINT_CSV = os.path.join(D2_LOGS_DIR, "candidates_checkpoint.csv")

# -------------------------
# Manifest schema (fixed order)
# Inputs: processed clips + metadata
# Outputs: manifest_all.csv and per-split manifests in the same column order
# -------------------------
MANIFEST_COLS = [
    "split","dataset","task","speaker_id","sample_id",
    "label_str","label_num","age","sex","speaker_key_rel",
    "clip_path","duration_sec","source_path",
    "clip_start_sec","clip_end_sec","sr_hz","channels",
    "clip_is_contiguous",
]

# -------------------------
# Utility functions
# Inputs: strings/audio arrays
# Outputs: cleaned names, normalized audio, VAD segments, final clip selection
# -------------------------
def d2_safe(s: str) -> str:
    return re.sub(r"[^A-Za-z0-9_\-\.]+", "_", str(s))

def d2_db_to_lin(db: float) -> float:
    return 10.0 ** (db / 20.0)

def d2_rms_dbfs(y: np.ndarray) -> float:
    if y is None or len(y) == 0:
        return -120.0
    rms = float(np.sqrt(np.mean(y.astype(np.float64) ** 2) + 1e-12))
    return 20.0 * math.log10(max(rms, 1e-12))

def d2_peak_limit(y: np.ndarray, peak_dbfs: float) -> np.ndarray:
    if y is None or len(y) == 0:
        return y
    peak = float(np.max(np.abs(y)))
    lim = d2_db_to_lin(peak_dbfs)
    if peak > lim and peak > 0:
        y = y * (lim / peak)
    return np.clip(y, -1.0, 1.0).astype(np.float32)

def d2_norm_rms_then_peak(y: np.ndarray) -> np.ndarray:
    if y is None or len(y) == 0:
        return y
    cur = d2_rms_dbfs(y)
    if cur < D2_MIN_RMS_DBFS:
        return d2_peak_limit(y.astype(np.float32), D2_PEAK_LIMIT_DBFS)
    gain_db = float(D2_TARGET_RMS_DBFS - cur)
    gain_db = float(np.clip(gain_db, -60.0, D2_MAX_GAIN_DB))
    y2 = (y.astype(np.float32) * d2_db_to_lin(gain_db)).astype(np.float32)
    return d2_peak_limit(y2, D2_PEAK_LIMIT_DBFS)

def d2_read_mono(path: str) -> Tuple[np.ndarray, int]:
    x, sr = sf.read(path, always_2d=True)
    x = x.mean(axis=1).astype(np.float32)
    if not np.isfinite(x).all():
        x = np.nan_to_num(x, nan=0.0, posinf=0.0, neginf=0.0).astype(np.float32)
    return x, int(sr)

def d2_resample(y: np.ndarray, sr_in: int, sr_out: int) -> np.ndarray:
    if sr_in == sr_out:
        return y.astype(np.float32, copy=False)
    if D2_HAVE_SCIPY:
        g = math.gcd(sr_in, sr_out)
        up = sr_out // g
        down = sr_in // g
        return resample_poly(y.astype(np.float64), up, down).astype(np.float32, copy=False)
    n_new = int(round(len(y) * (sr_out / sr_in)))
    if n_new <= 1:
        return y[:1].astype(np.float32, copy=False)
    x_old = np.linspace(0.0, 1.0, num=len(y), endpoint=False)
    x_new = np.linspace(0.0, 1.0, num=n_new, endpoint=False)
    return np.interp(x_new, x_old, y).astype(np.float32, copy=False)

def d2_pcm16_bytes(y: np.ndarray) -> bytes:
    y = np.clip(y, -1.0, 1.0)
    return (y * 32767.0).astype(np.int16).tobytes()

def d2_task_type_from_picture(picture: str) -> str:
    return "vowel" if str(picture).strip().lower() == "vokal" else "reading"

def d2_task_short(tt: str) -> str:
    tt = (tt or "").lower()
    if tt == "vowel":
        return "vowl"
    if tt == "reading":
        return "read"
    return "unk"

def d2_label(diag: str) -> Tuple[str, int, str]:
    d = str(diag).strip().lower()
    if "parkinson" in d:
        return "Parkinson", 1, "PD"
    return "Healthy", 0, "HC"

def d2_safe_write_wav(path: str, y: np.ndarray, sr: int) -> None:
    # Writes WAV with retries to handle transient filesystem issues
    os.makedirs(os.path.dirname(path), exist_ok=True)
    last_err = None
    for attempt in range(1, D2_WRITE_RETRIES + 1):
        try:
            sf.write(path, np.clip(y, -1.0, 1.0).astype(np.float32), sr, subtype="PCM_16")
            return
        except Exception as e:
            last_err = e
            time.sleep(D2_WRITE_SLEEP * attempt)
    raise RuntimeError(f"Failed to write WAV: {path}. Last error: {repr(last_err)}")

def d2_to_boolish(x) -> bool:
    # Parses common “true-ish” values from TSV text
    if isinstance(x, bool):
        return x
    s = str(x).strip().lower()
    return s in {"true","1","yes","y","t"}

# --- Voiced-segment utilities (builds a voiced-only concatenated signal) ---
def d2_energy_segments(y: np.ndarray, sr: int) -> List[Tuple[int, int]]:
    frame = int(sr * 0.02)
    hop = frame
    if frame <= 0 or len(y) < frame:
        return []
    eng = []
    idx = []
    for i in range(0, len(y) - frame + 1, hop):
        w = y[i:i+frame]
        eng.append(float(np.mean(w*w)))
        idx.append(i)
    eng = np.array(eng, dtype=np.float32)
    thr = float(np.percentile(eng, 25)) * 2.5
    thr = max(thr, 1e-8)
    keep = eng > thr

    segs = []
    on = False
    s0 = 0
    for k, flag in enumerate(keep):
        if flag and not on:
            on = True
            s0 = idx[k]
        elif (not flag) and on:
            on = False
            segs.append((s0, idx[k] + frame))
    if on and idx:
        segs.append((s0, idx[-1] + frame))
    return segs

def d2_merge(segs: List[Tuple[int,int]], sr: int) -> List[Tuple[int,int]]:
    if not segs:
        return []
    gap = int(sr * (D2_MERGE_GAP_MS/1000.0))
    min_len = int(sr * (D2_MIN_SPEECH_MS/1000.0))
    segs = sorted(segs)
    merged = [list(segs[0])]
    for s,e in segs[1:]:
        if s - merged[-1][1] <= gap:
            merged[-1][1] = max(merged[-1][1], e)
        else:
            merged.append([s,e])
    out = []
    for s,e in merged:
        if (e-s) >= min_len:
            out.append((int(s), int(e)))
    return out

def d2_pad(segs: List[Tuple[int,int]], sr: int, n: int) -> List[Tuple[int,int]]:
    pad = int(sr * D2_PAD_SEC)
    out = []
    for s,e in segs:
        s2 = max(0, s-pad)
        e2 = min(n, e+pad)
        if e2 > s2:
            out.append((s2,e2))
    return out

def d2_webrtc_segments(y: np.ndarray, sr: int) -> Optional[List[Tuple[int,int]]]:
    # Uses webrtcvad only at 16 kHz; returns source segments used to build voiced-only signal
    if not D2_HAVE_WEBRTCVAD or sr != 16000:
        return None
    frame_len = int(sr * (D2_FRAME_MS/1000.0))
    if frame_len <= 0 or len(y) < frame_len:
        return []

    n_frames = int(math.ceil(len(y)/frame_len))
    pad_samp = n_frames*frame_len - len(y)
    if pad_samp > 0:
        y = np.concatenate([y, np.zeros(pad_samp, dtype=np.float32)], axis=0)
    pcm = d2_pcm16_bytes(y)

    vad = webrtcvad.Vad(int(D2_VAD_MODE))
    flags = []
    for i in range(n_frames):
        s = i*frame_len*2
        e = s + frame_len*2
        flags.append(vad.is_speech(pcm[s:e], 16000))

    segs = []
    on = False
    s0 = 0
    for i,f in enumerate(flags):
        if f and not on:
            on = True
            s0 = i
        elif (not f) and on:
            on = False
            segs.append((s0*frame_len, i*frame_len))
    if on:
        segs.append((s0*frame_len, n_frames*frame_len))

    n0 = len(y) - pad_samp if pad_samp > 0 else len(y)
    return [(max(0,s), min(n0,e)) for s,e in segs]

def d2_voiced_concat(y: np.ndarray, sr: int) -> Tuple[np.ndarray, str]:
    # Creates a voiced-only signal by concatenating voiced segments
    segs = d2_webrtc_segments(y, sr)
    if segs is None:
        segs = d2_energy_segments(y, sr)
        used = "energy"
    else:
        used = "webrtcvad"
    segs = d2_merge(segs, sr)
    segs = d2_pad(segs, sr, len(y))
    if not segs:
        return y.astype(np.float32, copy=False), used
    return np.concatenate([y[s:e] for s,e in segs], axis=0).astype(np.float32, copy=False), used

def d2_trim_leading_silence(y: np.ndarray, sr: int, max_trim_sec: float = 1.5, thr_dbfs: float = -45.0) -> Tuple[np.ndarray, int]:
    # Conservative leading trim to reduce silent start in reading clips
    if y is None or len(y) == 0:
        return y, 0
    max_trim = int(sr * max_trim_sec)
    n = min(len(y), max_trim)
    if n <= 0:
        return y, 0
    thr = d2_db_to_lin(thr_dbfs)
    idx = 0
    while idx < n and float(abs(y[idx])) < thr:
        idx += 1
    if idx >= n:
        return y, 0
    back = int(0.02 * sr)
    idx2 = max(0, idx - back)
    return y[idx2:], idx2

def d2_best_reading_window(voiced: np.ndarray, sr: int, target_sec: float = 8.0) -> Tuple[np.ndarray, int, int]:
    # Picks an 8s crop that tends to start with speech sooner
    target_len = int(round(sr * target_sec))
    n = len(voiced)
    if n <= target_len:
        return voiced.astype(np.float32, copy=False), 0, n
    offsets = sorted(set([0, int(0.25*(n-target_len)), int(0.5*(n-target_len)), int(0.75*(n-target_len)), n-target_len]))
    best = None
    best_score = -1.0
    for st in offsets:
        en = st + target_len
        w = voiced[st:en]
        k = int(0.8 * sr)
        k = min(k, len(w))
        if k <= 0:
            continue
        score = float(np.mean(w[:k].astype(np.float64)**2))
        if score > best_score:
            best_score = score
            best = (st, en)
    if best is None:
        st = (n - target_len)//2
        en = st + target_len
        return voiced[st:en].astype(np.float32, copy=False), st, en
    st, en = best
    return voiced[st:en].astype(np.float32, copy=False), st, en

def d2_choose_one_clip(voiced: np.ndarray, sr: int, is_vowel: bool) -> Tuple[np.ndarray, float, float]:
    """
    Exactly one clip per source:
      - Vowel: 2.0s center crop if longer; zero-pad if shorter
      - Reading: best 8.0s window if long enough; else keep true duration
    clip_start_sec/clip_end_sec are on the voiced-concat timeline.
    """
    if voiced is None or len(voiced) == 0:
        return np.zeros((0,), dtype=np.float32), 0.0, 0.0

    if is_vowel:
        target_len = int(round(sr * D2_VOWEL_SEC))
        n = len(voiced)
        if n >= target_len:
            start = (n - target_len)//2
            end = start + target_len
            clip = voiced[start:end].astype(np.float32, copy=False)
            return clip, float(start)/sr, float(end)/sr
        tmp = np.zeros((target_len,), dtype=np.float32)
        tmp[:n] = voiced.astype(np.float32, copy=False)
        return tmp, 0.0, float(target_len)/sr

    n = len(voiced)
    target_len = int(round(sr * D2_READ_SEC))
    if n >= target_len:
        clip, st, en = d2_best_reading_window(voiced, sr, D2_READ_SEC)
        return clip, float(st)/sr, float(en)/sr
    return voiced.astype(np.float32, copy=False), 0.0, float(n)/sr

def d2_candidate_filename(row) -> str:
    # Deterministic name so the same row produces the same candidate file after restart
    spk = d2_safe(str(row.speaker_id))
    task = d2_safe(str(row.task)[:5])
    base = d2_safe(Path(str(row.audio_path)).name)
    tag = "vowl" if str(row.task_type).lower() == "vowel" else "read"
    return f"D2CAND_{spk}_{tag}_{task}_{base}.wav"

def d2_drive_ok() -> bool:
    # Lightweight check used during long runs
    return os.path.isdir("/content/drive/MyDrive") and os.path.isdir(D2_EWA_DIR)

def d2_sync_one_local_to_drive(local_path: str, drive_path: str) -> None:
    # Copies local candidate to Drive with retries
    os.makedirs(os.path.dirname(drive_path), exist_ok=True)
    last_err = None
    for attempt in range(1, D2_WRITE_RETRIES + 1):
        try:
            shutil.copy2(local_path, drive_path)
            return
        except Exception as e:
            last_err = e
            time.sleep(D2_WRITE_SLEEP * attempt)
    raise RuntimeError(f"Failed to sync local->drive: {local_path} -> {drive_path}. Last error: {repr(last_err)}")

# -------------------------
# Load metadata + filter usable rows
# Inputs: FILES.TSV, SPEAKERS.TSV
# Outputs: filtered table with audio_path, labels, task type, speaker metadata
# -------------------------
print("\nReading TSV/README...")

files_df = pd.read_csv(D2_FILES_TSV, sep="\t", dtype=str, keep_default_na=False, na_values=[""])
speakers_df = pd.read_csv(D2_SPEAKERS_TSV, sep="\t", dtype=str, keep_default_na=False, na_values=[""])
_ = Path(D2_README_TXT).read_text(errors="ignore") if os.path.exists(D2_README_TXT) else ""

print("FILES.TSV rows:", len(files_df))
print("SPEAKERS.TSV rows:", len(speakers_df))

req_files = {"SPEAKER_CODE","DIAGNOSIS","INCLUSIVE_CRITERIA","AUDIOFILE","PICTURE"}
miss_f = sorted(list(req_files - set(files_df.columns)))
if miss_f:
    raise ValueError(f"FILES.TSV missing columns: {miss_f}")
if "SPEAKER_CODE" not in speakers_df.columns:
    raise ValueError("SPEAKERS.TSV missing column: SPEAKER_CODE")

d2 = files_df.copy()

# Canonical filters: diagnosis, inclusion, publish agreement (if present), audio availability
d2["diag_norm"] = d2["DIAGNOSIS"].astype(str).str.strip()
d2 = d2[d2["diag_norm"].isin(["Healthy","Parkinson"])].copy()

d2["inclusive_ok"] = d2["INCLUSIVE_CRITERIA"].apply(d2_to_boolish)
d2 = d2[d2["inclusive_ok"]].copy()

if "PUBLISH_AGREEMENT" in d2.columns:
    d2["publish_ok"] = d2["PUBLISH_AGREEMENT"].apply(d2_to_boolish)
    d2 = d2[d2["publish_ok"]].copy()

d2 = d2[d2["AUDIOFILE"].astype(str).str.strip() != "<not-available>"].copy()

print("\nRows after diag + inclusive + publish(if exists) + audio:", len(d2))
print("Counts by diag_norm:")
print(d2["diag_norm"].value_counts())

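def d2_make_audio_path(p: str) -> str:
    # Normalize separators first, then drop leading "./" prefixes and slashes so that
    # os.path.join keeps D2_EWA_DIR as the base (an absolute second argument would replace it).
    p = str(p).strip().replace("\\", "/")
    while p.startswith("./"):
        p = p[2:]
    return os.path.join(D2_EWA_DIR, p.lstrip("/"))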

d2["audio_path"] = d2["AUDIOFILE"].apply(d2_make_audio_path)
d2["speaker_id"] = d2["SPEAKER_CODE"].astype(str)

# Task mapping: vokal -> vowel clip rules, everything else -> reading clip rules
d2["task_type"] = d2["PICTURE"].apply(d2_task_type_from_picture)
d2["task"] = d2["task_type"].apply(d2_task_short)

# Labels
lab = d2["diag_norm"].apply(d2_label)
d2["label_str"] = lab.apply(lambda t: t[0])
d2["label_num"] = lab.apply(lambda t: t[1]).astype(int)
d2["hc_pd_tag"] = lab.apply(lambda t: t[2])

# Speaker metadata (optional columns only)
col_age = "AGE" if "AGE" in speakers_df.columns else None
col_sex = "SEX" if "SEX" in speakers_df.columns else None

# Column names are stripped and lowercased before lookup, so only normalized forms are listed
speaker_key_candidates = {"speaker_key_rel","speaker_key","speakerkeyrel","speakerkey"}
d2_spk_key_col = None
for c in speakers_df.columns:
    if str(c).strip().lower() in speaker_key_candidates:
        d2_spk_key_col = c
        break

spk_cols = ["SPEAKER_CODE"]
if col_age: spk_cols.append(col_age)
if col_sex: spk_cols.append(col_sex)
if d2_spk_key_col and d2_spk_key_col not in spk_cols:
    spk_cols.append(d2_spk_key_col)

d2 = d2.merge(speakers_df[spk_cols], on="SPEAKER_CODE", how="left", validate="many_to_one")
d2["age"] = d2[col_age] if col_age else np.nan
d2["sex"] = d2[col_sex] if col_sex else np.nan
d2["speaker_key_rel"] = d2[d2_spk_key_col] if d2_spk_key_col else np.nan

# -------------------------
# Audio existence check
# Inputs: filtered rows with audio_path
# Outputs: missing_paths.csv + filtered rows with existing audio only
# -------------------------
missing_audio_rows = pd.DataFrame(columns=["speaker_id","audio_path","AUDIOFILE","diag_norm","PICTURE"])

if D2_DO_EXISTENCE_CHECK:
    if D2_EXIST_CHECK_UNIQUE_ONLY:
        uniq = d2["audio_path"].dropna().astype(str).unique().tolist()
        exists_map = {}
        for p in tqdm(uniq, desc="D2 existence check (unique paths)", dynamic_ncols=True):
            exists_map[p] = os.path.exists(p)
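        # Cast to bool so the later "~" negation is a logical NOT even if the column became object dtype
        d2["exists"] = d2["audio_path"].map(exists_map).fillna(False).astype(bool)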
    else:
        d2["exists"] = [os.path.exists(p) for p in tqdm(d2["audio_path"].astype(str).tolist(),
                                                        desc="D2 existence check", dynamic_ncols=True)]
else:
    d2["exists"] = True

missing_audio_rows = d2.loc[~d2["exists"], ["speaker_id","audio_path","AUDIOFILE","diag_norm","PICTURE"]].copy()
print("\nAudio exists:", int(d2["exists"].sum()), "missing:", int((~d2["exists"]).sum()))

d2_missing_csv = os.path.join(D2_MANIFEST_DIR, "missing_paths.csv")
os.makedirs(os.path.dirname(d2_missing_csv), exist_ok=True)
missing_audio_rows.to_csv(d2_missing_csv, index=False)
print("Saved missing list:", d2_missing_csv)

d2 = d2.loc[d2["exists"]].copy()
if len(d2) == 0:
    raise RuntimeError("No usable D2 rows after filtering + existence checks.")

print("\nLabel counts (rows):")
print(d2["label_str"].value_counts(dropna=False))
print("\nTask counts (rows):")
print(d2["task"].value_counts(dropna=False))

# -------------------------
# Resume snapshot (before heavy work)
# Inputs: candidate folders
# Outputs: printed counts and an estimated skip rate
# -------------------------
d2_proc = d2.sort_values(["speaker_id","audio_path"]).reset_index(drop=True)

existing_drive = len(list(Path(D2_CAND_DIR_DRIVE).glob("D2CAND_*.wav")))
existing_local = len(list(Path(D2_CAND_DIR_LOCAL).glob("D2CAND_*.wav"))) if D2_USE_LOCAL_CAND_CACHE else 0

print("\n=== Resume diagnostics ===")
print("Rows to process (after filters + exists):", len(d2_proc))
print("Existing candidates on Drive:", existing_drive)
if D2_USE_LOCAL_CAND_CACHE:
    print("Existing candidates in Local cache:", existing_local)

sample_n = min(2000, len(d2_proc))
already = 0
for rr in d2_proc.head(sample_n).itertuples(index=False):
    fn = d2_candidate_filename(rr)
    drive_path = os.path.join(D2_CAND_DIR_DRIVE, fn)
    local_path = os.path.join(D2_CAND_DIR_LOCAL, fn) if D2_USE_LOCAL_CAND_CACHE else ""
    if os.path.exists(drive_path) or (D2_USE_LOCAL_CAND_CACHE and os.path.exists(local_path)):
        already += 1
print(f"Sample skip rate (first {sample_n} rows): {already}/{sample_n} already have candidate")

# -------------------------
# Candidate creation (resume-safe, 1 clip per source)
# Inputs: d2_proc rows + source audio files
# Outputs: clips/_candidates/*.wav (and checkpoint CSV)
# -------------------------
cand_rows: List[Dict] = []
warn_rows: List[Dict] = []

pbar = tqdm(d2_proc.itertuples(index=False), total=len(d2_proc),
            desc="D2 preprocess (resume-safe; 1 clip per source)", dynamic_ncols=True)

written_now = 0
skipped_existing = 0
bad_reads = 0

def d2_checkpoint_save(rows: List[Dict]) -> None:
    # Writes a current snapshot of candidate records (diagnostic only)
    if not rows:
        return
    try:
        ck = pd.DataFrame(rows)
        ck.to_csv(D2_CAND_CHECKPOINT_CSV, index=False)
    except Exception as e:
        warn_rows.append({
            "dataset":"D2","speaker_id":"","source_path":"",
            "warning_type":"checkpoint_write_failed","detail":repr(e)
        })

for i, r in enumerate(pbar, start=1):
    src = str(r.audio_path)

    # Deterministic candidate filenames enable safe restarts
    cand_fn = d2_candidate_filename(r)
    cand_drive_path = os.path.join(D2_CAND_DIR_DRIVE, cand_fn)
    cand_local_path = os.path.join(D2_CAND_DIR_LOCAL, cand_fn) if D2_USE_LOCAL_CAND_CACHE else None

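    # Resume-safe: skip processing if candidate already exists
    if os.path.exists(cand_drive_path) or (cand_local_path and os.path.exists(cand_local_path)):
        # A candidate found only in the local cache is synced so the Drive path recorded below is valid
        if cand_local_path and not os.path.exists(cand_drive_path):
            try:
                d2_sync_one_local_to_drive(cand_local_path, cand_drive_path)
            except Exception as e:
                warn_rows.append({
                    "dataset":"D2","speaker_id":str(r.speaker_id),"source_path":src,
                    "warning_type":"candidate_sync_failed","detail":repr(e)
                })
        skipped_existing += 1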
        # Still record a row so downstream cap/split can proceed
        cand_rows.append({
            "dataset": "D2",
            "task": str(r.task)[:5],
            "speaker_id": str(r.speaker_id),
            "sample_id": os.path.basename(src),
            "label_str": r.label_str,
            "label_num": int(r.label_num),
            "age": r.age if pd.notna(r.age) else np.nan,
            "sex": r.sex if pd.notna(r.sex) else np.nan,
            "speaker_key_rel": r.speaker_key_rel if pd.notna(r.speaker_key_rel) else np.nan,
            "clip_path_cand": cand_drive_path,
            "duration_sec": np.nan,
            "source_path": src,
            "clip_start_sec": np.nan,
            "clip_end_sec": np.nan,
            "sr_hz": int(D2_SR),
            "channels": 1,
            "clip_is_contiguous": True,
            "hc_pd_tag": str(r.hc_pd_tag).strip().upper(),
        })
        continue

    # Quick Drive availability note (does not stop the run)
    if not d2_drive_ok():
        warn_rows.append({
            "dataset":"D2","speaker_id":str(r.speaker_id),"source_path":src,
            "warning_type":"drive_not_available","detail":"Drive path not available during processing"
        })

    try:
        # Load -> resample -> normalize loudness
        y, sr0 = d2_read_mono(src)
        y = d2_resample(y, sr0, D2_SR)
        y = d2_norm_rms_then_peak(y)

        # Build a voiced-only signal (concatenated segments)
        voiced, vad_used = d2_voiced_concat(y, D2_SR)
        if len(voiced) < int(D2_SR * D2_MIN_KEEP_SEC):
            voiced = y
            warn_rows.append({
                "dataset":"D2","speaker_id":str(r.speaker_id),"source_path":src,
                "warning_type":"vad_too_short_fallback_original",
                "detail": f"vad_used={vad_used}"
            })

        is_vowel = (str(r.task_type).lower() == "vowel")

        # Reading: trim some leading silence before choosing the final window
        if not is_vowel:
            voiced, trimmed = d2_trim_leading_silence(voiced, D2_SR, max_trim_sec=1.5, thr_dbfs=-45.0)
            if trimmed > 0:
                warn_rows.append({
                    "dataset":"D2","speaker_id":str(r.speaker_id),"source_path":src,
                    "warning_type":"leading_silence_trim_applied",
                    "detail": f"trimmed_samples={trimmed}"
                })

        # Select exactly one final clip from the voiced-only signal
        clip_audio, st, en = d2_choose_one_clip(voiced, D2_SR, is_vowel)

        if clip_audio is None or len(clip_audio) == 0:
            warn_rows.append({
                "dataset":"D2","speaker_id":str(r.speaker_id),"source_path":src,
                "warning_type":"empty_clip_after_processing",
                "detail": f"is_vowel={is_vowel}, vad_used={vad_used}"
            })
            continue

        # Write candidate to local cache then sync to Drive (or write directly to Drive)
        if D2_USE_LOCAL_CAND_CACHE and cand_local_path:
            d2_safe_write_wav(cand_local_path, clip_audio, D2_SR)
            d2_sync_one_local_to_drive(cand_local_path, cand_drive_path)
        else:
            d2_safe_write_wav(cand_drive_path, clip_audio, D2_SR)

        written_now += 1

        # Record candidate metadata (used later for cap/split/final move)
        cand_rows.append({
            "dataset": "D2",
            "task": str(r.task)[:5],
            "speaker_id": str(r.speaker_id),
            "sample_id": os.path.basename(src),
            "label_str": r.label_str,
            "label_num": int(r.label_num),
            "age": r.age if pd.notna(r.age) else np.nan,
            "sex": r.sex if pd.notna(r.sex) else np.nan,
            "speaker_key_rel": r.speaker_key_rel if pd.notna(r.speaker_key_rel) else np.nan,
            "clip_path_cand": cand_drive_path,
            "duration_sec": float(len(clip_audio) / D2_SR),
            "source_path": src,
            "clip_start_sec": float(st),
            "clip_end_sec": float(en),
            "sr_hz": int(D2_SR),
            "channels": 1,
            "clip_is_contiguous": True,
            "hc_pd_tag": str(r.hc_pd_tag).strip().upper(),
        })

    except Exception as e:
        bad_reads += 1
        warn_rows.append({
            "dataset":"D2",
            "speaker_id": str(getattr(r, "speaker_id", "")),
            "source_path": src,
            "warning_type":"preprocess_error",
            "detail": repr(e)
        })

    # Periodic checkpoint to confirm progress and aid debugging
    if (i % D2_CHECKPOINT_EVERY) == 0:
        d2_checkpoint_save(cand_rows)
        pbar.set_postfix({
            "written_now": written_now,
            "skipped_exist": skipped_existing,
            "bad_reads": bad_reads
        })

# Final checkpoint (latest snapshot)
d2_checkpoint_save(cand_rows)

cand_df = pd.DataFrame(cand_rows)
warnings_df = pd.DataFrame(warn_rows)

print("\n=== Candidate stage summary ===")
print("Rows iterated:", len(d2_proc))
print("Candidates recorded in table:", len(cand_df))
print("Written now (this run):", written_now)
print("Skipped because candidate existed:", skipped_existing)
print("Bad reads:", bad_reads)
print("Checkpoint CSV (latest):", D2_CAND_CHECKPOINT_CSV)

if len(cand_df) == 0:
    if len(warnings_df):
        print("\nSample warnings (first 10):")
        print(warnings_df.head(10).to_string(index=False))
    raise RuntimeError("No D2 clip candidates produced. Check audio paths/VAD settings.")

# Keep only candidates that exist on Drive (required for downstream move)
cand_df["cand_exists"] = cand_df["clip_path_cand"].astype(str).apply(os.path.exists)
missing_cands = cand_df.loc[~cand_df["cand_exists"]].copy()
if len(missing_cands) > 0:
    missing_cands_csv = os.path.join(D2_MANIFEST_DIR, "missing_candidates.csv")
    missing_cands.to_csv(missing_cands_csv, index=False)
    print("\nMissing candidate files on Drive:", len(missing_cands), "saved:", missing_cands_csv)
cand_df = cand_df.loc[cand_df["cand_exists"]].drop(columns=["cand_exists"]).reset_index(drop=True)

# -------------------------
# Cap per speaker and task
# Inputs: cand_df (already-written candidates)
# Outputs: reduced cand_df; deletes unkept candidate WAV files
# -------------------------
def d2_cap_candidates(df: pd.DataFrame, max_k: int, seed: int) -> Tuple[pd.DataFrame, int]:
    if df.empty:
        return df.copy(), 0
    rng = np.random.default_rng(seed)
    keep_idx = []
    for (spk, task), g in df.groupby(["speaker_id", "task"], sort=False):
        idxs = g.index.to_numpy()
        if len(idxs) <= max_k:
            keep_idx.extend(idxs.tolist())
        else:
            chosen = rng.choice(idxs, size=max_k, replace=False)
            keep_idx.extend(chosen.tolist())

    keep_idx = sorted(set(keep_idx))
    keep_set = set(keep_idx)

    deleted = 0
    unkept_paths = df.loc[~df.index.isin(list(keep_set)), "clip_path_cand"].tolist()
    for p in unkept_paths:
        try:
            if os.path.exists(p):
                os.remove(p)
                deleted += 1
        except Exception as e:
            warnings_df_local = pd.DataFrame([{
                "dataset":"D2","speaker_id":"","source_path":"",
                "warning_type":"candidate_delete_failed","detail":f"{p} :: {repr(e)}"
            }])
            nonlocal_warnings.append(warnings_df_local)

    return df.loc[keep_idx].reset_index(drop=True), deleted

# Module-level list that d2_cap_candidates appends to; it must exist before the call below
nonlocal_warnings = []
cand_df, deleted = d2_cap_candidates(cand_df, D2_MAX_CLIPS_PER_SPK_TASK, D2_RANDOM_SEED)
if nonlocal_warnings:
    warnings_df = pd.concat([warnings_df] + nonlocal_warnings, ignore_index=True)

print("\nD2 candidates after cap:", int(len(cand_df)))
print("Deleted unkept candidates:", int(deleted))

# -------------------------
# Speaker-level split (stratified by label)
# Inputs: capped candidates
# Outputs: split assignment per speaker; merged into cand_df
# -------------------------
def d2_split_speakers(df: pd.DataFrame, fracs=(0.70, 0.15, 0.15), seed=1337) -> pd.DataFrame:
    rng = np.random.default_rng(seed)
    spk_df = df[["speaker_id", "label_num"]].drop_duplicates().copy()

    rows = []
    for lbl in [0, 1]:
        spks = spk_df[spk_df["label_num"] == lbl]["speaker_id"].tolist()
        rng.shuffle(spks)
        n = len(spks)
        n_train = int(round(fracs[0] * n))
        n_val   = int(round(fracs[1] * n))

        train_spks = spks[:n_train]
        val_spks   = spks[n_train:n_train + n_val]
        test_spks  = spks[n_train + n_val:]

        rows += [{"speaker_id": s, "split": "train"} for s in train_spks]
        rows += [{"speaker_id": s, "split": "val"} for s in val_spks]
        rows += [{"speaker_id": s, "split": "test"} for s in test_spks]

    return pd.DataFrame(rows)

spk_split = d2_split_speakers(cand_df, (D2_TRAIN_PCT, D2_VAL_PCT, D2_TEST_PCT), D2_RANDOM_SEED)
cand_df = cand_df.merge(spk_split, on="speaker_id", how="left", validate="many_to_one")

if cand_df["split"].isna().any():
    ex = cand_df.loc[cand_df["split"].isna(),"speaker_id"].drop_duplicates().head(10).tolist()
    raise RuntimeError(f"Some speakers did not get split. Example: {ex}")

print("\nSplit counts (clips):")
print(cand_df["split"].value_counts(dropna=False))
print("Label by split (clips):")
print(pd.crosstab(cand_df["split"], cand_df["label_str"]))

# -------------------------
# Finalize: move clips into clips/<split>/ and build manifest
# Inputs: kept candidates + split assignment
# Outputs: final WAVs + manifest rows
# -------------------------
global_clip_idx = 0
final_rows: List[Dict] = []

pbar2 = tqdm(cand_df.itertuples(index=False), total=len(cand_df),
             desc="D2 finalize (move kept)", dynamic_ncols=True)

for r in pbar2:
    global_clip_idx += 1

    tag = str(r.hc_pd_tag).strip().upper()
    if tag not in {"HC","PD"}:
        tag = "HC" if str(r.label_str).lower().startswith("healthy") else "PD"

    task5 = str(r.task)[:5]
    spk = str(r.speaker_id)

    out_name = d2_safe(f"D2_{tag}_{spk}_{task5}_{global_clip_idx:06d}.wav")
    out_path = os.path.join(D2_CLIPS_DIR, str(r.split), out_name)

    cand_path = str(r.clip_path_cand)
    if not os.path.exists(cand_path):
        warnings_df = pd.concat([warnings_df, pd.DataFrame([{
            "dataset":"D2","speaker_id":spk,"source_path":str(r.source_path),
            "warning_type":"missing_candidate_file","detail":cand_path
        }])], ignore_index=True)
        continue

    try:
        os.makedirs(os.path.dirname(out_path), exist_ok=True)
        shutil.move(cand_path, out_path)
    except Exception as e:
        warnings_df = pd.concat([warnings_df, pd.DataFrame([{
            "dataset":"D2","speaker_id":spk,"source_path":str(r.source_path),
            "warning_type":"candidate_move_failed",
            "detail": f"{cand_path} -> {out_path} :: {repr(e)}"
        }])], ignore_index=True)
        continue

    final_rows.append({
        "split": str(r.split),
        "dataset": "D2",
        "task": task5,
        "speaker_id": spk,
        "sample_id": str(r.sample_id),
        "label_str": str(r.label_str),
        "label_num": int(r.label_num),
        "age": r.age if pd.notna(r.age) else np.nan,
        "sex": r.sex if pd.notna(r.sex) else np.nan,
        "speaker_key_rel": r.speaker_key_rel if pd.notna(r.speaker_key_rel) else np.nan,
        "clip_path": out_path,
        "duration_sec": float(r.duration_sec) if pd.notna(r.duration_sec) else np.nan,
        "source_path": str(r.source_path),
        "clip_start_sec": float(r.clip_start_sec) if pd.notna(r.clip_start_sec) else np.nan,
        "clip_end_sec": float(r.clip_end_sec) if pd.notna(r.clip_end_sec) else np.nan,
        "sr_hz": int(r.sr_hz),
        "channels": 1,
        "clip_is_contiguous": True,
    })

# Candidate folder is intentionally left in place to support future resume behavior
# (no cleanup is performed here).

manifest_df = pd.DataFrame(final_rows)

# Ensure schema and column order
for c in MANIFEST_COLS:
    if c not in manifest_df.columns:
        manifest_df[c] = np.nan
manifest_df = manifest_df[MANIFEST_COLS].copy()

# -------------------------
# Save outputs: manifests, logs, config
# Inputs: manifest_df + warnings_df + run settings
# Outputs: manifest_all.csv, preprocess_warnings.csv, dataset_summary.json, run_config.json
# -------------------------
manifest_all_path = os.path.join(D2_MANIFEST_DIR, "manifest_all.csv")
warnings_path     = os.path.join(D2_LOGS_DIR, "preprocess_warnings.csv")
summary_path      = os.path.join(D2_LOGS_DIR, "dataset_summary.json")
run_cfg_path      = os.path.join(D2_CONFIG_DIR, "run_config.json")

os.makedirs(os.path.dirname(manifest_all_path), exist_ok=True)
os.makedirs(os.path.dirname(warnings_path), exist_ok=True)
os.makedirs(os.path.dirname(summary_path), exist_ok=True)
os.makedirs(os.path.dirname(run_cfg_path), exist_ok=True)

manifest_df.to_csv(manifest_all_path, index=False)
warnings_df.to_csv(warnings_path, index=False)

summary = {
    "dataset": "D2",
    "run_name": D2_RUN_NAME,
    "source_root": D2_EWA_DIR,
    "files_tsv": D2_FILES_TSV,
    "speakers_tsv": D2_SPEAKERS_TSV,
    "sr_hz": int(D2_SR),
    "webrtcvad_available": bool(D2_HAVE_WEBRTCVAD),
    "scipy_available": bool(D2_HAVE_SCIPY),
    "one_clip_per_source_file": True,
    "resume_safe_candidates": True,
    "local_candidate_cache_enabled": bool(D2_USE_LOCAL_CAND_CACHE),
    "paths": {
        "run_root": D2_RUN_ROOT,
        "candidate_dir_drive": D2_CAND_DIR_DRIVE,
        "candidate_dir_local": D2_CAND_DIR_LOCAL if D2_USE_LOCAL_CAND_CACHE else None,
        "manifest_all": manifest_all_path,
        "warnings_csv": warnings_path,
        "missing_paths_csv": d2_missing_csv,
        "candidates_checkpoint_csv": D2_CAND_CHECKPOINT_CSV,
    },
    "filtering": {
        "diagnosis_in": ["Healthy","Parkinson"],
        "inclusive_criteria_required": True,
        "publish_agreement_required_if_column_exists": ("PUBLISH_AGREEMENT" in files_df.columns),
        "audiofile_not_available_excluded": True,
        "picture_mapping_rule": {"vokal": "vowel", "else": "reading"},
        "existence_check_enabled": bool(D2_DO_EXISTENCE_CHECK),
        "existence_check_unique_only": bool(D2_EXIST_CHECK_UNIQUE_ONLY),
    },
    "counts": {
        "n_source_rows_after_filters": int(len(d2_proc)),
        "n_missing_audio_rows": int(len(missing_audio_rows)),
        "n_candidates_recorded": int(len(cand_df)),
        "n_candidates_written_this_run": int(written_now),
        "n_candidates_skipped_existing": int(skipped_existing),
        "n_candidates_deleted_by_cap": int(deleted),
        "n_clips_written": int(len(manifest_df)),
        "label_counts_rows": d2_proc["label_str"].value_counts(dropna=False).to_dict(),
        "task_counts_rows": d2_proc["task"].value_counts(dropna=False).to_dict(),
        "label_counts_clips": manifest_df["label_str"].value_counts(dropna=False).to_dict() if len(manifest_df) else {},
        "split_counts_clips": manifest_df["split"].value_counts(dropna=False).to_dict() if len(manifest_df) else {},
        "task_counts_clips": manifest_df["task"].value_counts(dropna=False).to_dict() if len(manifest_df) else {},
        "n_warnings": int(len(warnings_df)),
    },
    "notes": (
        "Exactly 1 candidate WAV is written per source file to clips/_candidates (resume-safe), then capped/split and moved to clips/<split>. "
        "clip_start_sec/clip_end_sec are on concatenated voiced timeline (not source-timeline). "
        "Vowel clips are 2.0s padded if short; reading clips are 8.0s if long enough else true duration without padding. "
        "Reading clips apply conservative leading-silence trim and best-window selection to reduce initial silence."
    ),
}
with open(summary_path, "w", encoding="utf-8") as f:
    json.dump(summary, f, indent=2)

run_cfg = {
    "dataset": "D2",
    "run_name": D2_RUN_NAME,
    "paths": {
        "dataset_dir": D2_EWA_DIR,
        "files_tsv": D2_FILES_TSV,
        "speakers_tsv": D2_SPEAKERS_TSV,
        "run_root": D2_RUN_ROOT,
        "clips_dir": D2_CLIPS_DIR,
        "candidate_dir_drive": D2_CAND_DIR_DRIVE,
        "candidate_dir_local": D2_CAND_DIR_LOCAL if D2_USE_LOCAL_CAND_CACHE else None,
        "manifest_all": manifest_all_path,
        "warnings_csv": warnings_path,
        "missing_paths_csv": d2_missing_csv,
        "candidates_checkpoint_csv": D2_CAND_CHECKPOINT_CSV,
        "summary_json": summary_path,
    },
    "standard_structure": {
        "clips": "clips/<split>/ (flat), with temporary clips/_candidates during run",
        "manifests": "manifests/manifest_all.csv",
        "config": "config/run_config.json",
        "logs": ["logs/preprocess_warnings.csv", "logs/dataset_summary.json", "logs/candidates_checkpoint.csv"],
    },
    "manifest_schema_order": MANIFEST_COLS,
    "filename_format": "D2_{HC|PD}_{speaker_id}_{task<=5}_{global_index:06d}.wav",
    "task_mapping": {"PICTURE == 'vokal'": "vowl", "else": "read"},
    "cap_policy": {"groupby": ["speaker_id", "task"], "max_per_group": int(D2_MAX_CLIPS_PER_SPK_TASK)},
    "normalization": {
        "method": "manual RMS gain + peak limiting",
        "target_rms_dbfs": float(D2_TARGET_RMS_DBFS),
        "peak_limit_dbfs": float(D2_PEAK_LIMIT_DBFS),
        "min_rms_dbfs": float(D2_MIN_RMS_DBFS),
        "max_gain_db": float(D2_MAX_GAIN_DB),
        "pyloudnorm": False,
    },
    "vad": {
        "method": "webrtcvad if available at 16 kHz else energy fallback; pad ±0.25 s",
        "mode": int(D2_VAD_MODE),
        "frame_ms": int(D2_FRAME_MS),
        "min_speech_ms": int(D2_MIN_SPEECH_MS),
        "merge_gap_ms": int(D2_MERGE_GAP_MS),
        "pad_sec": float(D2_PAD_SEC),
        "min_keep_sec": float(D2_MIN_KEEP_SEC),
    },
    "resume": {
        "resume_safe_candidates": True,
        "deterministic_candidate_naming": True,
        "skip_if_candidate_exists": True,
        "checkpoint_every_rows": int(D2_CHECKPOINT_EVERY),
        "local_candidate_cache": bool(D2_USE_LOCAL_CAND_CACHE),
        "reset_run": bool(D2_RESET_RUN),
    },
    "clip_rules": {
        "vowel": {"sec": float(D2_VOWEL_SEC), "pad_if_short": True, "selection": "center-crop if longer"},
        "reading": {"sec": float(D2_READ_SEC), "pad_if_short": False, "selection": "best-of-multiple windows if longer else keep full",
                    "leading_silence_trim": {"enabled": True, "max_trim_sec": 1.5, "thr_dbfs": -45.0}},
    },
    "seed": int(D2_RANDOM_SEED),
}
with open(run_cfg_path, "w", encoding="utf-8") as f:
    json.dump(run_cfg, f, indent=2)

print("\nDONE: D2 preprocessing")
print("Run root:", D2_RUN_ROOT)
print("Manifest:", manifest_all_path)
print("Warnings:", warnings_path)
print("Summary:", summary_path)
print("Run config:", run_cfg_path)
print("Clips written:", int(len(manifest_df)))
print("Missing-audio rows (not processed):", int(len(missing_audio_rows)))

print("\nSanity checks:")
print("Unique SR:", sorted(manifest_df["sr_hz"].dropna().unique().tolist()) if len(manifest_df) else [])
print("Unique channels:", sorted(manifest_df["channels"].dropna().unique().tolist()) if len(manifest_df) else [])

# Speaker split uniqueness check
spk_split_chk = manifest_df[["speaker_id", "split"]].drop_duplicates()
dup = spk_split_chk.groupby("speaker_id")["split"].nunique()
bad = dup[dup > 1]
print("Speakers appearing in multiple splits:", int(len(bad)))
if len(bad) > 0:
    print("Example bad speaker ids:", bad.index.tolist()[:10])
    raise RuntimeError("Speaker appears in more than one split.")

The following cell copies the full contents of the D2 preprocessing run folder (`runs/run_v1/`) into the main D2 `preprocessed_v1/` directory using `rsync`. The folder structure and all files (including `clips/`, `manifests/`, `config/`, and `logs/`) are kept intact, and the command prints each file as it is copied. The `-a` (archive) option copies directories recursively while preserving timestamps and permissions, and the `-v` option enables verbose output. The trailing slash on the source path makes `rsync` copy the folder's contents rather than the folder itself. This step is run after preprocessing finishes in the run folder; it copies the finalized outputs into the standard `preprocessed_v1` location used for training and evaluation and leaves the originals under `runs/run_v1/` in place.
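One way to confirm that the copy is complete is to check that every file under the run folder also exists at the same relative path in the destination. The sketch below is a minimal illustration of that check; the function name `files_missing_in_destination` is hypothetical and not part of the notebook, and it compares paths only, not file contents.

```python
import os
from typing import List

def files_missing_in_destination(src_root: str, dst_root: str) -> List[str]:
    """Return relative paths of files found under src_root but absent under dst_root."""
    missing = []
    for dirpath, _dirnames, filenames in os.walk(src_root):
        rel_dir = os.path.relpath(dirpath, src_root)
        for name in filenames:
            # normpath turns "./name" (files directly under src_root) into "name"
            rel = os.path.normpath(os.path.join(rel_dir, name))
            if not os.path.isfile(os.path.join(dst_root, rel)):
                missing.append(rel)
    return sorted(missing)
```

Called with the run folder as `src_root` and the `preprocessed_v1` folder as `dst_root`, an empty list would indicate that every file has a counterpart in the destination.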

In [None]:
!rsync -av \
  "/content/drive/MyDrive/AI_PD_Project/Datasets/D2-Slovak (EWA-DB)/EWA-DB/preprocessed_v1/runs/run_v1/" \
  "/content/drive/MyDrive/AI_PD_Project/Datasets/D2-Slovak (EWA-DB)/EWA-DB/preprocessed_v1/"

The following cell preprocesses Dataset D4 (IPVS, Italian) into the project’s standard `preprocessed_v1` format so that it matches the structure used for the other datasets. It mounts Google Drive if needed, installs the optional audio dependencies if they are missing (WebRTC VAD and SciPy), defines the input folders for young healthy, elderly healthy, and Parkinson’s disease speakers, and creates the output folders for clips, manifests, logs, and configuration files. At the start, it clears the temporary `_candidates` folder to avoid mixing files from earlier runs, prints whether each dataset folder and dependency is available, and stops if the dataset root folder is missing.

The cell then builds a clean list of usable source WAV files. Each filename is parsed using the IPVS naming convention to extract the task code, repetition number, and basic metadata such as sex and recording date. Only files that match known vowel or reading tasks are kept. Speakers are identified using the original folder structure (`speaker_key_rel`) and converted into a stable, unique `speaker_id` by hashing this value, which avoids conflicts when folder names are similar. This results in a table where each row represents one source file with its speaker, task type, and label (Healthy or Parkinson).
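The hashing step can be illustrated with a short sketch. This is only an assumption about how such an identifier might be derived: the choice of SHA-1, the 10-character digest length, the `D4_` prefix, and the function name `stable_speaker_id` are all hypothetical, and the folder names in the usage note are made up for illustration.

```python
import hashlib

def stable_speaker_id(speaker_key_rel: str, prefix: str = "D4", n_hex: int = 10) -> str:
    """Map a relative speaker folder key to a short, deterministic identifier."""
    # Normalize separators and whitespace so the same folder always yields the same id
    key = speaker_key_rel.strip().replace("\\", "/")
    digest = hashlib.sha1(key.encode("utf-8")).hexdigest()
    return f"{prefix}_{digest[:n_hex]}"
```

For example, `stable_speaker_id("PD/Speaker A")` returns the same string on every run and machine, while `stable_speaker_id("HEC/Speaker A")` yields a different identifier even though the folder names look alike.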

For each indexed source file, the cell cleans the audio and selects **at most one clip per source file**. Audio is loaded as mono and resampled to 16 kHz, then cleaned using DC offset removal and optional hum reduction (a high-pass filter plus notch filters at 50 Hz and its harmonics). Loudness is normalized using an RMS-based method with limited gain and peak limiting, without using pyloudnorm. Speech regions are detected using voice activity detection, and a single region is chosen by a fixed rule: **the longest voiced segment is selected**, with the earliest segment chosen if there is a tie. From this region, the final clip is created as follows: for vowel tasks, a 2-second clip is taken from the center of the segment and padded if the segment is shorter; for reading tasks, up to the first 8 seconds are used, or the full duration without padding if it is shorter but not extremely brief. Optional silence trimming is applied to reduce leading and trailing quiet sections. Each selected clip is immediately written as a temporary WAV file in `clips/_candidates/`, and a record is stored with timing information (`clip_start_sec`, `clip_end_sec`), duration, and metadata.
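The segment-selection and clip-length rules described above can be sketched as follows. This is a simplified illustration, not the cell's implementation: it assumes voiced segments are given as `(start, end)` sample indices, the names `pick_longest_segment` and `make_clip` are hypothetical, zero-padding at the end of short vowel clips is an assumption, and the minimum-duration check for very brief reading clips is omitted.

```python
from typing import List, Tuple

import numpy as np

def pick_longest_segment(segments: List[Tuple[int, int]]) -> Tuple[int, int]:
    """Return the longest (start, end) range; the earliest start wins ties."""
    if not segments:
        raise ValueError("no voiced segments")
    # Sort key: longer duration first, then earlier start
    return min(segments, key=lambda se: (-(se[1] - se[0]), se[0]))

def make_clip(y: np.ndarray, sr: int, is_vowel: bool,
              vowel_sec: float = 2.0, read_sec: float = 8.0) -> np.ndarray:
    """Apply the task-dependent length rule to one voiced segment."""
    if is_vowel:
        n = int(round(vowel_sec * sr))
        if len(y) >= n:
            start = (len(y) - n) // 2          # center crop
            return y[start:start + n]
        return np.pad(y, (0, n - len(y)))      # assumption: zero-pad at the end
    n = int(round(read_sec * sr))
    return y[:n]                               # first read_sec seconds, or all if shorter
```

With segments `[(0, 100), (200, 400), (500, 700)]`, the two 200-sample segments tie on length and the earlier one, `(200, 400)`, is selected.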

Once all candidate clips are created, the cell applies a cap, keeping no more than 8 clips per `(speaker_id, task)` group. Any extra candidate WAV files are deleted at this stage. The remaining clips are then split into training, validation, and test sets at the **speaker level** (70/15/15), stratified by label so that each split has a similar balance of Healthy and Parkinson speakers. The kept clips are moved into the final flat folder structure `clips/<split>/` using standardized filenames, and a single `manifest_all.csv` is written with a fixed column order used across all datasets. The cell also creates split-specific manifest files, a warnings log for skipped or problematic files, a dataset summary JSON with counts and key settings, and a run configuration JSON that records all preprocessing paths and rules. Final checks confirm that every clip listed in the manifests exists on disk and that no speaker appears in more than one split.
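The cap and the speaker-leakage check can be expressed compactly in pandas. The sketch below uses a seeded shuffle followed by `groupby(...).head(max_k)`, which is a different technique from a per-group random draw and may keep different rows; the function names are hypothetical, and the sketch only filters the table (it does not delete any WAV files).

```python
import pandas as pd

def cap_per_speaker_task(df: pd.DataFrame, max_k: int = 8, seed: int = 1337) -> pd.DataFrame:
    """Keep at most max_k randomly chosen rows per (speaker_id, task) group."""
    shuffled = df.sample(frac=1.0, random_state=seed)  # seeded shuffle of all rows
    capped = shuffled.groupby(["speaker_id", "task"], sort=False).head(max_k)
    return capped.sort_index().reset_index(drop=True)

def speakers_in_multiple_splits(df: pd.DataFrame) -> list:
    """Return speaker ids that appear in more than one split (should be empty)."""
    n_splits = df.groupby("speaker_id")["split"].nunique()
    return sorted(n_splits[n_splits > 1].index.tolist())
```

Because the split is assigned per speaker rather than per clip, `speakers_in_multiple_splits` is expected to return an empty list on a correctly built manifest.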

In [None]:
# ============================================================
# D4 (IPVS) Preprocessing v1 — One Clip Per Source File
# Inputs: raw IPVS WAV files across YHC, HEC, PD folders
# Outputs: clips/<split>/ WAVs (flat), manifests/manifest_all.csv (+ per-split),
#          config/run_config.json, logs/preprocess_warnings.csv, logs/dataset_summary.json
# Notes: Writes temporary candidates to clips/_candidates/ during processing, then moves kept clips.
# ============================================================

import os
import re
import json
import math
import random
import shutil
import hashlib
from typing import List, Tuple, Dict, Optional

import numpy as np
import pandas as pd
import soundfile as sf
from tqdm.auto import tqdm

# -------------------------
# Drive mount check
# Purpose: ensure dataset and output folders are reachable in Colab
# -------------------------
if not os.path.isdir("/content/drive/MyDrive"):
    from google.colab import drive  # type: ignore
    drive.mount("/content/drive")

# -------------------------
# Optional dependencies
# Purpose: enable VAD and signal processing (install if missing)
# Inputs: none
# Outputs: flags showing which optional libraries are available
# -------------------------
def _d4_try_import_webrtcvad():
    try:
        import webrtcvad  # type: ignore
        return webrtcvad, True
    except Exception:
        return None, False

webrtcvad, HAVE_WEBRTCVAD = _d4_try_import_webrtcvad()
if not HAVE_WEBRTCVAD:
    !pip -q install webrtcvad
    webrtcvad, HAVE_WEBRTCVAD = _d4_try_import_webrtcvad()

try:
    from scipy import signal  # type: ignore
    HAVE_SCIPY = True
except Exception:
    HAVE_SCIPY = False
    !pip -q install scipy
    from scipy import signal  # type: ignore
    HAVE_SCIPY = True

# -------------------------
# Paths and standardized outputs
# Inputs: dataset root folder
# Outputs: preprocessed_v1 folder tree (clips/manifests/config/logs)
# -------------------------
D4_PROJECT_DIR = "/content/drive/MyDrive/AI_PD_Project"
D4_DATASETS_DIR = f"{D4_PROJECT_DIR}/Datasets"
D4_IPVS_DIR = f"{D4_DATASETS_DIR}/D4-Italian (IPVS)"

D4_GROUP_DIRS = {
    "YHC": os.path.join(D4_IPVS_DIR, "15 Young Healthy Control"),
    "HEC": os.path.join(D4_IPVS_DIR, "22 Elderly Healthy Control"),
    "PD":  os.path.join(D4_IPVS_DIR, "28 People with Parkinson's disease"),
}

D4_OUT_ROOT = os.path.join(D4_IPVS_DIR, "preprocessed_v1")
D4_CLIPS_DIR = os.path.join(D4_OUT_ROOT, "clips")
D4_CAND_DIR  = os.path.join(D4_CLIPS_DIR, "_candidates")   # temporary staging during this run
D4_MANIFESTS_DIR = os.path.join(D4_OUT_ROOT, "manifests")
D4_CONFIG_DIR = os.path.join(D4_OUT_ROOT, "config")
D4_LOGS_DIR = os.path.join(D4_OUT_ROOT, "logs")

for p in [D4_OUT_ROOT, D4_CLIPS_DIR, D4_CAND_DIR, D4_MANIFESTS_DIR, D4_CONFIG_DIR, D4_LOGS_DIR]:
    os.makedirs(p, exist_ok=True)
for sp in ["train", "val", "test"]:
    os.makedirs(os.path.join(D4_CLIPS_DIR, sp), exist_ok=True)

# Purpose: avoid mixing candidate files from earlier runs
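if os.path.isdir(D4_CAND_DIR):
    try:
        shutil.rmtree(D4_CAND_DIR)
    except Exception as e:
        # Stale candidates may remain if removal fails, so report it instead of failing silently
        print("WARNING: could not clear candidate folder:", repr(e))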
os.makedirs(D4_CAND_DIR, exist_ok=True)

# -------------------------
# Fail-fast checks
# Purpose: confirm input folders and optional tools are available
# -------------------------
print("D4_IPVS_DIR exists?", os.path.isdir(D4_IPVS_DIR))
for grp, pth in D4_GROUP_DIRS.items():
    print(f"{grp} folder exists? {os.path.isdir(pth)} -> {pth}")
print("webrtcvad available?", HAVE_WEBRTCVAD)
print("scipy available?", HAVE_SCIPY)

if not os.path.isdir(D4_IPVS_DIR):
    raise FileNotFoundError(f"D4_IPVS_DIR not found: {D4_IPVS_DIR}")

# -------------------------
# Preprocessing constants
# Purpose: unify audio format and enforce the one-clip-per-source rule
# -------------------------
D4_TARGET_SR = 16000

# Purpose: stable RMS-based leveling + peak limiting (no external loudness library)
D4_TARGET_LUFS_LIKE = -23.0   # RMS-based approximation; leveling adds +7 dB, so the effective RMS target is -16 dBFS
D4_PEAK_CEIL = 0.95
D4_MAX_GAIN_DB = 18.0

# Purpose: VAD settings for voiced-region detection
D4_VAD_AGGRESSIVENESS = 2
D4_VAD_FRAME_MS = 30
D4_VAD_PAD_SEC = 0.25
D4_POST_VAD_TRAIL_PAD_SEC = 0.15
D4_MIN_KEEP_SEC = 0.30

# Purpose: task-dependent clip length
D4_VOWEL_SEC = 2.0
D4_READ_SEC = 8.0
D4_OPTION_A = True  # reading < 8s: keep true duration (no padding)

# Purpose: cap clips per (speaker, task) after candidates are created
D4_MAX_CLIPS_PER_SPK_PER_TASK = 8

# Purpose: speaker-level stratified split (label-balanced)
D4_RANDOM_SEED = 1337
D4_SPLIT_FRACS = (0.70, 0.15, 0.15)  # train/val/test

random.seed(D4_RANDOM_SEED)
np.random.seed(D4_RANDOM_SEED)

# Purpose: keep D4-specific hum reduction behavior
D4_HUM_REMOVE = True
D4_HUM_BASE_HZ = 50.0
D4_HUM_N_HARMONICS = 6
D4_HUM_Q = 35.0
D4_HP_CUTOFF_HZ = 70.0

# Purpose: optional trimming steps used inside segments/clips
D4_TRIM_SEGMENTS = True
D4_TRIM_CLIPS = True

# -------------------------
# Small helpers: safe names, labels, and task tags
# -------------------------
def d4_safe(s: str) -> str:
    # Purpose: make filenames filesystem-safe
    return re.sub(r"[^A-Za-z0-9_\-\.]+", "_", str(s))

def d4_pd_hc(label_str: str) -> str:
    # Purpose: short label tag for filenames
    return "PD" if str(label_str).lower().startswith("parkinson") else "HC"

def d4_task_short(task_type: str) -> str:
    # Purpose: map task types into short tokens used in filenames/manifest
    tt = (task_type or "").lower()
    if tt == "vowel":
        return "vowl"
    if tt == "reading":
        return "read"
    if tt == "spontaneous":
        return "spont"
    return "unk"

# -------------------------
# IPVS filename parsing + task mapping
# Inputs: filename stem tokens
# Outputs: parsed fields (task, rep, subject token, sex, recording date token)
# -------------------------
def d4_normalize_stem(stem: str) -> str:
    # Purpose: remove spaces and trailing dots for consistent parsing
    return re.sub(r"\s+", "", str(stem)).rstrip(".")

d4_ipvs_pattern = re.compile(
    r"""
    ^(?P<task>[A-Za-z]{1,3})
    (?P<rep>\d+)
    (?P<subject>.+?)
    (?P<birth_year>\d{2})
    (?P<sex>[MF])
    (?P<recdate>\d{10}|\d{12})$
    """,
    re.VERBOSE
)

def d4_parse_ipvs_basename(basename_no_ext: str) -> Optional[Dict[str, str]]:
    # Purpose: parse metadata encoded in the filename
    s = d4_normalize_stem(basename_no_ext)
    m = d4_ipvs_pattern.match(s)
    if not m:
        return None
    d = m.groupdict()
    d["task"] = str(d["task"]).upper()
    d["subject"] = str(d["subject"]).upper()
    d["recdate_fmt"] = "DDMMYYYYHHMM" if len(d["recdate"]) == 12 else "DDMMYYHHMM"
    return d

D4_VOWEL_TASKS = {"VA", "VE", "VI", "VO", "VU"}
D4_READ_TASKS  = {"PR", "FB", "D", "B"}

def d4_task_type(task: str) -> Optional[str]:
    # Purpose: collapse many IPVS tasks into the two modeled groups
    t = str(task).upper()
    if t in D4_VOWEL_TASKS:
        return "vowel"
    if t in D4_READ_TASKS:
        return "reading"
    return None

# -------------------------
# File indexing and speaker_id creation (keep dataset-specific logic)
# Inputs: group folders with nested subfolders
# Outputs: speaker_key_rel and stable speaker_id for splitting and tracking
# -------------------------
def d4_walk_wavs_only_raw(group_root: str) -> List[str]:
    # Purpose: list raw WAVs while skipping any generated folders
    wavs = []
    for root, dirs, files in os.walk(group_root):
        dirs[:] = [d for d in dirs if d not in {"preprocessed_v1", ".ipynb_checkpoints"}]
        for f in files:
            if f.lower().endswith(".wav"):
                wavs.append(os.path.join(root, f))
    return wavs

def d4_speaker_key_rel(group_root: str, wav_path: str, group: str) -> str:
    # Purpose: build a relative speaker key based on folder structure
    rel_dir = os.path.relpath(os.path.dirname(wav_path), group_root)
    parts = [p for p in rel_dir.split(os.sep) if p not in {"", "."}]

    if group in {"YHC", "HEC"}:
        if len(parts) >= 1:
            return f"{group}/{parts[0]}"
        return f"{group}/{os.path.basename(os.path.dirname(wav_path))}"

    if len(parts) >= 2:
        return f"{group}/{parts[0]}/{parts[1]}"
    if len(parts) == 1:
        return f"{group}/{parts[0]}"
    return f"{group}/{os.path.basename(os.path.dirname(wav_path))}"

def d4_speaker_name_from_key(speaker_key_rel: str) -> str:
    # Purpose: last component of the speaker key (the speaker folder name)
    return speaker_key_rel.split("/")[-1]

def d4_speaker_id_from_key(speaker_key_rel: str) -> str:
    # Purpose: stable ID that does not depend on absolute paths
    h = hashlib.md5(speaker_key_rel.encode("utf-8")).hexdigest()[:8]
    speaker_name = d4_speaker_name_from_key(speaker_key_rel)
    group = speaker_key_rel.split("/")[0]
    return f"D4_{group}__{speaker_name}__{h}"

def d4_build_index() -> Tuple[pd.DataFrame, List[str], int]:
    # Inputs: group roots
    # Outputs: table of usable source files (one row per source WAV)
    wav_files = []
    wav_files_with_roots = []

    for grp, grp_root in D4_GROUP_DIRS.items():
        if not os.path.isdir(grp_root):
            print("WARNING missing group folder:", grp_root)
            continue
        grp_wavs = d4_walk_wavs_only_raw(grp_root)
        wav_files.extend(grp_wavs)
        wav_files_with_roots.extend([(w, grp, grp_root) for w in grp_wavs])

    rows = []
    unmatched = []

    for wav_path, grp, grp_root in wav_files_with_roots:
        base = os.path.basename(wav_path)
        stem = d4_normalize_stem(base[:-4])

        parsed = d4_parse_ipvs_basename(stem)
        if parsed is None:
            unmatched.append(stem)
            continue

        task = parsed["task"]
        ttype = d4_task_type(task)
        if ttype is None:
            continue

        speaker_key = d4_speaker_key_rel(grp_root, wav_path, grp)
        speaker_id = d4_speaker_id_from_key(speaker_key)

        label_str = "Parkinson" if grp == "PD" else "Healthy"
        label_num = 1 if grp == "PD" else 0

        rows.append({
            "audio_path": wav_path,
            "group": grp,
            "speaker_key_rel": speaker_key,
            "speaker_id": speaker_id,
            "task": task,
            "rep": int(parsed["rep"]),
            "subject_token": parsed["subject"],
            "birth_year_2d": parsed["birth_year"],
            "sex": parsed["sex"],
            "recdate_raw": parsed["recdate"],
            "recdate_fmt": parsed["recdate_fmt"],
            "task_type": ttype,
            "label_str": label_str,
            "label_num": label_num,
            "dataset": "D4",
        })

    df = pd.DataFrame(rows)
    total_wavs = len(wav_files)

    print("\nD4 total wavs found:", total_wavs)
    print("D4 matched+used (after task filter):", len(df), "unmatched stems:", len(unmatched))
    if unmatched:
        print("First 15 unmatched stems:", unmatched[:15])

    return df, unmatched, total_wavs

# -------------------------
# Audio utilities (manual leveling, trimming, hum reduction, resampling)
# Purpose: produce clean mono 16 kHz audio for VAD and clip extraction
# -------------------------
def d4_peak_limit(x: np.ndarray, peak_ceiling: float = D4_PEAK_CEIL) -> np.ndarray:
    # Purpose: ensure headroom to avoid clipping
    if x is None or len(x) == 0:
        return x
    peak = float(np.max(np.abs(x)))
    if peak <= 0:
        return x.astype(np.float32)
    if peak > peak_ceiling:
        x = x * (peak_ceiling / peak)
    return x.astype(np.float32)

def d4_integrated_rms_db(x: np.ndarray) -> Optional[float]:
    # Purpose: RMS in dB for gain calculation
    if x is None or len(x) == 0:
        return None
    rms = float(np.sqrt(np.mean(x.astype(np.float64) ** 2)))
    if rms <= 1e-12:
        return None
    return float(20.0 * np.log10(rms))

def d4_remove_dc(x: np.ndarray) -> np.ndarray:
    # Purpose: remove constant offset that can confuse filters and VAD
    if x is None:
        # None has no .astype; return an empty array instead of raising
        return np.zeros((0,), dtype=np.float32)
    if len(x) == 0:
        return x.astype(np.float32)
    y = x.astype(np.float32)
    y = y - float(np.mean(y))
    return y.astype(np.float32)

def d4_soft_clip(x: np.ndarray) -> np.ndarray:
    # Purpose: gentle limiting when a large gain is applied
    if x is None:
        # None has no .astype; return an empty array instead of raising
        return np.zeros((0,), dtype=np.float32)
    if len(x) == 0:
        return x.astype(np.float32)
    y = np.tanh(1.25 * x.astype(np.float32)) / np.tanh(1.25)
    return y.astype(np.float32)

def d4_apply_level_target_manual(x: np.ndarray, sr: int) -> Tuple[np.ndarray, Dict[str, float]]:
    # Inputs: float32 audio
    # Outputs: leveled audio + small stats for debugging/logging
    info = {"lvl_in_db": float("nan"), "lvl_out_db": float("nan"), "gain_db": 0.0, "peak_in": 0.0, "peak_out": 0.0}
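    if x is None or len(x) == 0:
        # None has no .astype; return an empty array instead of raising
        empty = np.zeros((0,), dtype=np.float32) if x is None else x.astype(np.float32)
        return empty, info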

    info["peak_in"] = float(np.max(np.abs(x)))
    in_db = d4_integrated_rms_db(x)
    if in_db is None:
        y = d4_peak_limit(x, D4_PEAK_CEIL)
        info["peak_out"] = float(np.max(np.abs(y))) if len(y) else 0.0
        return y.astype(np.float32), info

    info["lvl_in_db"] = float(in_db)

    # Purpose: RMS-like target (shifted from the LUFS-like constant)
    target_rms_db = float(D4_TARGET_LUFS_LIKE + 7.0)
    gain_db = float(target_rms_db - in_db)
    gain_db = float(np.clip(gain_db, -60.0, D4_MAX_GAIN_DB))
    info["gain_db"] = float(gain_db)

    gain = float(10.0 ** (gain_db / 20.0))
    y = (x.astype(np.float32) * gain).astype(np.float32)

    if gain_db > 12.0:
        y = d4_soft_clip(y)

    y = d4_peak_limit(y, D4_PEAK_CEIL)
    y = np.clip(y, -1.0, 1.0).astype(np.float32)

    info["peak_out"] = float(np.max(np.abs(y))) if len(y) else 0.0
    out_db = d4_integrated_rms_db(y)
    if out_db is not None:
        info["lvl_out_db"] = float(out_db)

    return y.astype(np.float32), info

def d4_trim_silence_energy(
    x: np.ndarray,
    sr: int,
    frame_ms: int = 20,
    hop_ms: int = 10,
    db_margin: float = 25.0,
    min_keep_ms: int = 150
) -> np.ndarray:
    # Purpose: drop leading/trailing low-energy regions while keeping a minimum duration
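    if x is None:
        # None has no .astype; return an empty array instead of raising
        return np.zeros((0,), dtype=np.float32)
    if len(x) == 0:
        return x.astype(np.float32)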

    frame = int(sr * frame_ms / 1000.0)
    hop = int(sr * hop_ms / 1000.0)
    if frame <= 0 or hop <= 0 or len(x) < frame:
        return x.astype(np.float32)

    rms_db = []
    idx = 0
    while idx + frame <= len(x):
        w = x[idx:idx+frame]
        rms = float(np.sqrt(np.mean(w.astype(np.float64) ** 2)))
        db = -120.0 if rms <= 1e-12 else (20.0 * np.log10(rms))
        rms_db.append(db)
        idx += hop

    rms_db = np.array(rms_db, dtype=np.float32)
    if len(rms_db) == 0:
        return x.astype(np.float32)

    mx = float(np.max(rms_db))
    thr = mx - float(db_margin)
    keep = np.where(rms_db >= thr)[0]
    if len(keep) == 0:
        return x.astype(np.float32)

    start_f = int(keep[0])
    end_f = int(keep[-1])

    start = max(0, start_f * hop)
    end = min(len(x), end_f * hop + frame)

    min_keep = int(sr * (min_keep_ms / 1000.0))
    if end - start < min_keep:
        mid = (start + end) // 2
        start = max(0, mid - min_keep // 2)
        end = min(len(x), start + min_keep)

    return x[start:end].astype(np.float32)

def d4_force_length(x: np.ndarray, n: int) -> np.ndarray:
    # Purpose: pad or trim to an exact length (used for vowel clips)
    if x is None:
        return np.zeros((n,), dtype=np.float32)
    if len(x) == n:
        return x.astype(np.float32)
    if len(x) > n:
        return x[:n].astype(np.float32)
    y = np.zeros((n,), dtype=np.float32)
    y[:len(x)] = x.astype(np.float32)
    return y

def d4_float_to_pcm16(x: np.ndarray) -> bytes:
    # Purpose: VAD requires 16-bit PCM bytes
    x = np.clip(x, -1.0, 1.0)
    return (x * 32767.0).astype(np.int16).tobytes()

def d4_notch_filter(x: np.ndarray, sr: int, f0: float, q: float) -> np.ndarray:
    # Purpose: remove a narrowband tone (hum harmonic)
    if f0 <= 0 or f0 >= (sr / 2.0):
        return x.astype(np.float32)
    b, a = signal.iirnotch(w0=f0, Q=q, fs=sr)
    y = signal.filtfilt(b, a, x.astype(np.float64)).astype(np.float32)
    return y

def d4_highpass_filter(x: np.ndarray, sr: int, cutoff_hz: float) -> np.ndarray:
    # Purpose: reduce very low-frequency rumble
    if cutoff_hz <= 0:
        return x.astype(np.float32)
    cutoff_hz = min(cutoff_hz, sr / 2.0 - 1.0)
    b, a = signal.butter(2, cutoff_hz, btype="highpass", fs=sr)
    y = signal.filtfilt(b, a, x.astype(np.float64)).astype(np.float32)
    return y

def d4_hum_reduce(x: np.ndarray, sr: int) -> np.ndarray:
    # Purpose: remove DC, apply high-pass, then notch out mains hum harmonics
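    if x is None:
        # None has no .astype; return an empty array instead of raising
        return np.zeros((0,), dtype=np.float32)
    if len(x) == 0:
        return x.astype(np.float32)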
    y = x.astype(np.float32)

    y = d4_remove_dc(y)
    y = d4_highpass_filter(y, sr, D4_HP_CUTOFF_HZ)

    for k in range(1, D4_HUM_N_HARMONICS + 1):
        f0 = D4_HUM_BASE_HZ * k
        if f0 >= sr / 2.0:
            break
        y = d4_notch_filter(y, sr, f0, D4_HUM_Q)

    y = d4_peak_limit(y, D4_PEAK_CEIL)
    y = np.clip(y, -1.0, 1.0).astype(np.float32)
    return y

def d4_load_audio_mono_16k(path: str, target_sr: int = D4_TARGET_SR) -> Tuple[np.ndarray, int]:
    # Inputs: source WAV path
    # Outputs: mono float32 audio at 16 kHz
    x, sr = sf.read(path, always_2d=True)
    x = x.mean(axis=1).astype(np.float32)

    if not np.isfinite(x).all():
        x = np.nan_to_num(x, nan=0.0, posinf=0.0, neginf=0.0).astype(np.float32)

    if sr != target_sr:
        g = np.gcd(sr, target_sr)
        up = target_sr // g
        down = sr // g
        x = signal.resample_poly(x.astype(np.float64), up, down).astype(np.float32)
        sr = target_sr

    x = np.clip(x, -1.0, 1.0).astype(np.float32)
    return x.astype(np.float32), sr

# -------------------------
# VAD segments (source timeline)
# Inputs: leveled audio at 16 kHz
# Outputs: list of (start_sample, end_sample) in the source timeline
# -------------------------
def d4_vad_segments_source_timeline(
    x: np.ndarray,
    sr: int,
    vad_aggr: int = D4_VAD_AGGRESSIVENESS,
    frame_ms: int = D4_VAD_FRAME_MS
) -> List[Tuple[int, int]]:
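    if not HAVE_WEBRTCVAD:
        # No fallback detector in this cell: callers see an empty list
        # and log the file as "no_vad_segments".
        return []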

    assert sr == 16000, "VAD expects 16k audio."
    vad = webrtcvad.Vad(int(vad_aggr))

    frame_len = int(sr * (frame_ms / 1000.0))
    if frame_len <= 0:
        return []

    n = len(x)
    n_frames = n // frame_len
    if n_frames == 0:
        return []

    pcm = d4_float_to_pcm16(x[:n_frames * frame_len])

    def frame_bytes(i: int) -> bytes:
        start = i * frame_len * 2
        end = start + frame_len * 2
        return pcm[start:end]

    voiced = [vad.is_speech(frame_bytes(i), sr) for i in range(n_frames)]

    segs = []
    i = 0
    while i < n_frames:
        if not voiced[i]:
            i += 1
            continue
        s0 = i
        while i < n_frames and voiced[i]:
            i += 1
        e0 = i
        segs.append((s0 * frame_len, e0 * frame_len))

    # Purpose: merge short gaps between voiced segments
    max_gap = int(0.20 * sr)
    merged = []
    for s, e in segs:
        if not merged:
            merged.append([s, e])
        else:
            ps, pe = merged[-1]
            if s - pe <= max_gap:
                merged[-1][1] = e
            else:
                merged.append([s, e])

    # Purpose: pad segments slightly to avoid hard cuts
    pad = int(round(D4_VAD_PAD_SEC * sr))
    tail = int(round(D4_POST_VAD_TRAIL_PAD_SEC * sr))

    out = []
    for s, e in merged:
        s2 = max(0, s - pad)
        e2 = min(n, e + pad + tail)
        if e2 > s2:
            out.append((int(s2), int(e2)))

    return out

# -------------------------
# Single-clip selection per source file
# Rule: choose the longest voiced segment (tie-breaker: earliest)
# Output: one clip dict or None
# -------------------------
def d4_pick_longest_segment(segs: List[Tuple[int, int]]) -> Optional[Tuple[int, int]]:
    if not segs:
        return None
    segs2 = sorted(segs, key=lambda t: (-(t[1]-t[0]), t[0]))
    return segs2[0]

def d4_make_single_clip_from_source_segments(
    y: np.ndarray,
    sr: int,
    segs: List[Tuple[int, int]],
    task_type: str
) -> Optional[Dict]:
    # Inputs: full audio + VAD segments + task type
    # Outputs: {audio, start, end, duration} or None
    if not segs:
        return None

    best = d4_pick_longest_segment(segs)
    if best is None:
        return None
    s, e = best

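    seg = y[s:e].astype(np.float32)
    if D4_TRIM_SEGMENTS:
        # Note: trimming can drop leading samples and that offset is not tracked,
        # so the start/end times derived from `s` below are approximate.
        seg = d4_trim_silence_energy(seg, sr, db_margin=28.0)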
    if seg is None or len(seg) <= 0:
        return None

    if task_type == "vowel":
        # Purpose: fixed 2.0 s, centered in the chosen segment
        L = int(round(D4_VOWEL_SEC * sr))
        if len(seg) >= L:
            mid = len(seg) // 2
            a0 = max(0, mid - L // 2)
            a1 = a0 + L
            audio = seg[a0:a1].astype(np.float32)
            src_start = float(s + a0) / sr
            src_end = float(s + a1) / sr
        else:
            # Purpose: allow short vowel by padding to 2.0 s
            audio = d4_force_length(seg, L).astype(np.float32)
            src_start = float(s) / sr
            src_end = float(s + L) / sr  # padded end

        if D4_TRIM_CLIPS:
            t = d4_trim_silence_energy(audio, sr, db_margin=30.0)
            audio = d4_force_length(t, L)

        audio = d4_peak_limit(audio, D4_PEAK_CEIL)
        audio = np.clip(audio, -1.0, 1.0).astype(np.float32)
        return {"audio": audio, "start": float(src_start), "end": float(src_end), "duration": float(len(audio))/sr}

    # reading
    L = int(round(D4_READ_SEC * sr))
    if len(seg) >= L:
        # Purpose: enforce one-window rule (first 8 seconds only)
        audio = seg[:L].astype(np.float32)
        src_start = float(s) / sr
        src_end = float(s + L) / sr
        audio = d4_peak_limit(audio, D4_PEAK_CEIL)
        audio = np.clip(audio, -1.0, 1.0).astype(np.float32)
        return {"audio": audio, "start": float(src_start), "end": float(src_end), "duration": float(len(audio))/sr}

    # Purpose: keep shorter reading clips only if not too short
    if len(seg) < int(sr * D4_MIN_KEEP_SEC):
        return None

    audio = seg.astype(np.float32)
    audio = d4_peak_limit(audio, D4_PEAK_CEIL)
    audio = np.clip(audio, -1.0, 1.0).astype(np.float32)
    src_start = float(s) / sr
    src_end = float(s + len(audio)) / sr
    return {"audio": audio, "start": float(src_start), "end": float(src_end), "duration": float(len(audio))/sr}

# -------------------------
# WAV writer
# Inputs: clip audio
# Outputs: PCM_16 WAV file on Drive
# -------------------------
def d4_write_wav(path: str, x: np.ndarray, sr: int):
    os.makedirs(os.path.dirname(path), exist_ok=True)
    x = np.clip(x, -1.0, 1.0).astype(np.float32)
    sf.write(path, x, sr, subtype="PCM_16")

# -------------------------
# Build source index (one row per source WAV)
# Outputs: d4_index used for preprocessing loop
# -------------------------
d4_index, d4_unmatched, d4_total_wavs = d4_build_index()

if len(d4_index) == 0:
    raise RuntimeError("No usable D4 rows after indexing/parsing/task mapping.")

print("\nD4 index shape:", d4_index.shape)
print("Group counts:\n", d4_index["group"].value_counts(dropna=False), "\n")
print("Task_type counts:\n", d4_index["task_type"].value_counts(dropna=False), "\n")
print("Unique speakers:", int(d4_index["speaker_id"].nunique()))

# -------------------------
# Preprocess each source → write one candidate clip immediately
# Inputs: source WAV paths
# Outputs: candidate WAVs + candidate metadata table
# -------------------------
MANIFEST_COLS = [
    "split", "dataset", "task", "speaker_id", "sample_id",
    "label_str", "label_num", "age", "sex",
    "speaker_key_rel",
    "clip_path", "duration_sec", "source_path",
    "clip_start_sec", "clip_end_sec",
    "sr_hz", "channels",
    "clip_is_contiguous",
]

cand_rows: List[Dict] = []
warn_rows: List[Dict] = []
cand_counter = 0

pbar = tqdm(d4_index.itertuples(index=False), total=len(d4_index),
            desc="D4 preprocess (1 clip per source; write candidates)", dynamic_ncols=True)

for r in pbar:
    src = r.audio_path
    try:
        # Pipeline: load/resample → DC removal → hum reduction → leveling
        y, sr = d4_load_audio_mono_16k(src, D4_TARGET_SR)
        if sr != D4_TARGET_SR:
            raise RuntimeError(f"Unexpected SR after load: {sr}")

        y = d4_remove_dc(y)
        if D4_HUM_REMOVE:
            y = d4_hum_reduce(y, sr)
        y, _linfo = d4_apply_level_target_manual(y, sr)

        # Purpose: find voiced regions on the processed signal
        segs = d4_vad_segments_source_timeline(y, sr)
        if not segs:
            warn_rows.append({"dataset": "D4", "speaker_id": r.speaker_id, "source_path": src,
                              "warning_type": "no_vad_segments", "detail": ""})
            continue

        # Purpose: choose exactly one clip from the voiced regions
        clip = d4_make_single_clip_from_source_segments(y, sr, segs, r.task_type)
        if clip is None:
            warn_rows.append({"dataset": "D4", "speaker_id": r.speaker_id, "source_path": src,
                              "warning_type": "no_clip_selected", "detail": ""})
            continue

        task5 = d4_task_short(r.task_type)

        cand_counter += 1
        cand_name = d4_safe(f"CAND_{cand_counter:08d}.wav")
        cand_path = os.path.join(D4_CAND_DIR, cand_name)

        # Write candidate now (staging file)
        d4_write_wav(cand_path, clip["audio"], sr)

        cand_rows.append({
            "dataset": "D4",
            "task": task5,
            "speaker_id": r.speaker_id,
            "sample_id": os.path.basename(src),
            "label_str": r.label_str,
            "label_num": int(r.label_num),
            "age": np.nan,
            "sex": r.sex if pd.notna(r.sex) else np.nan,
            "speaker_key_rel": r.speaker_key_rel,
            "clip_path_cand": cand_path,
            "duration_sec": float(clip["duration"]),
            "source_path": src,
            "clip_start_sec": float(clip["start"]),
            "clip_end_sec": float(clip["end"]),
            "sr_hz": int(sr),
            "channels": 1,
            "clip_is_contiguous": True,
        })

    except Exception as e:
        warn_rows.append({"dataset": "D4", "speaker_id": getattr(r, "speaker_id", ""),
                          "source_path": src, "warning_type": "preprocess_error", "detail": repr(e)})

cand_df = pd.DataFrame(cand_rows)
warnings_df = pd.DataFrame(warn_rows)

if len(cand_df) == 0:
    raise RuntimeError("No D4 clips produced. Check indexing/parsing/VAD settings.")

print("\nD4 candidates written:", int(len(cand_df)))
print("Candidate dir:", D4_CAND_DIR)

# -------------------------
# Cap candidates per (speaker_id, task), then delete unkept candidate WAVs
# Inputs: candidate table + candidate files
# Outputs: capped candidate table, staging cleaned
# -------------------------
def d4_cap_manifest(df: pd.DataFrame, max_k: int, seed: int) -> Tuple[pd.DataFrame, set]:
    rng = np.random.default_rng(seed)
    kept_idx = []
    for (spk, tt), g in df.groupby(["speaker_id", "task"], sort=False):
        idxs = g.index.to_numpy()
        if len(idxs) <= max_k:
            kept_idx.extend(idxs.tolist())
        else:
            chosen = rng.choice(idxs, size=max_k, replace=False)
            kept_idx.extend(chosen.tolist())
    kept_idx = sorted(set(kept_idx))
    kept_set = set(kept_idx)
    return df.loc[kept_idx].reset_index(drop=True), kept_set

cand_df_capped, keep_set = d4_cap_manifest(cand_df, D4_MAX_CLIPS_PER_SPK_PER_TASK, D4_RANDOM_SEED)
print("D4 clips after cap:", int(len(cand_df_capped)))

to_delete = cand_df.loc[~cand_df.index.isin(list(keep_set)), "clip_path_cand"].tolist()
deleted = 0
for p in to_delete:
    try:
        if os.path.exists(p):
            os.remove(p)
            deleted += 1
    except Exception as e:
        warnings_df = pd.concat([warnings_df, pd.DataFrame([{
            "dataset": "D4",
            "speaker_id": "",
            "source_path": "",
            "warning_type": "candidate_delete_failed",
            "detail": f"{p} :: {repr(e)}"
        }])], ignore_index=True)

print("Deleted unkept candidates:", deleted)
cand_df = cand_df_capped

# -------------------------
# Speaker-level split (stratified by label_num)
# Inputs: unique speakers from kept candidates
# Outputs: split assignment merged back onto cand_df
# -------------------------
def d4_split_speakers(df: pd.DataFrame, fracs=(0.70, 0.15, 0.15), seed=1337) -> pd.DataFrame:
    assert abs(sum(fracs) - 1.0) < 1e-9
    rng = np.random.default_rng(seed)

    spk_df = df[["speaker_id", "label_num", "label_str"]].drop_duplicates().copy()
    split_rows = []

    for lbl in [0, 1]:
        # Sort first so the seeded shuffle does not depend on filesystem walk order
        spks = sorted(spk_df[spk_df["label_num"] == lbl]["speaker_id"].tolist())
        rng.shuffle(spks)
        n = len(spks)
        n_train = int(round(fracs[0] * n))
        n_val = int(round(fracs[1] * n))

        train_spks = spks[:n_train]
        val_spks = spks[n_train:n_train + n_val]
        test_spks = spks[n_train + n_val:]

        split_rows += [{"speaker_id": s, "split": "train"} for s in train_spks]
        split_rows += [{"speaker_id": s, "split": "val"} for s in val_spks]
        split_rows += [{"speaker_id": s, "split": "test"} for s in test_spks]

    return pd.DataFrame(split_rows)

spk_split = d4_split_speakers(cand_df, D4_SPLIT_FRACS, D4_RANDOM_SEED)

cand_df = cand_df.drop(columns=["split"], errors="ignore")
cand_df = cand_df.merge(spk_split, on="speaker_id", how="left", validate="many_to_one")

if cand_df["split"].isna().any():
    ex = cand_df.loc[cand_df["split"].isna(), "speaker_id"].drop_duplicates().head(10).tolist()
    raise RuntimeError(f"Some speakers did not get a split assignment. Example: {ex}")

# -------------------------
# Finalize: move kept candidates into clips/<split>/ and build manifest rows
# Inputs: kept candidate WAVs + split assignment
# Outputs: final WAVs + manifest rows pointing to final paths
# -------------------------
global_counter = 0
final_rows: List[Dict] = []

pbar2 = tqdm(cand_df.itertuples(index=False), total=len(cand_df),
             desc="D4 finalize (move kept)", dynamic_ncols=True)

for r in pbar2:
    global_counter += 1

    label_tag = d4_pd_hc(r.label_str)
    task5 = str(r.task)
    spk = str(r.speaker_id)

    out_name = d4_safe(f"D4_{label_tag}_{spk}_{task5}_{global_counter:06d}.wav")
    out_path = os.path.join(D4_CLIPS_DIR, r.split, out_name)

    cand_path = getattr(r, "clip_path_cand")
    if not os.path.exists(cand_path):
        warnings_df = pd.concat([warnings_df, pd.DataFrame([{
            "dataset": "D4",
            "speaker_id": spk,
            "source_path": r.source_path,
            "warning_type": "missing_candidate_file",
            "detail": cand_path
        }])], ignore_index=True)
        continue

    shutil.move(cand_path, out_path)

    final_rows.append({
        "split": r.split,
        "dataset": "D4",
        "task": task5,
        "speaker_id": spk,
        "sample_id": r.sample_id,
        "label_str": r.label_str,
        "label_num": int(r.label_num),
        "age": np.nan,
        "sex": r.sex if pd.notna(r.sex) else np.nan,
        "speaker_key_rel": r.speaker_key_rel,
        "clip_path": out_path,
        "duration_sec": float(r.duration_sec),
        "source_path": r.source_path,
        "clip_start_sec": float(r.clip_start_sec),
        "clip_end_sec": float(r.clip_end_sec),
        "sr_hz": int(r.sr_hz),
        "channels": 1,
        "clip_is_contiguous": True,
    })

# Purpose: remove temporary staging directory when done
try:
    if os.path.isdir(D4_CAND_DIR):
        leftovers = list(os.scandir(D4_CAND_DIR))
        for ent in leftovers:
            try:
                os.remove(ent.path)
            except Exception:
                pass
        try:
            os.rmdir(D4_CAND_DIR)
        except Exception:
            pass
except Exception:
    pass

manifest_df = pd.DataFrame(final_rows)

# Purpose: enforce canonical schema and column order
for c in MANIFEST_COLS:
    if c not in manifest_df.columns:
        manifest_df[c] = np.nan
manifest_df = manifest_df[MANIFEST_COLS].copy()

# -------------------------
# Save artifacts (manifests, warnings, summary, run config)
# Outputs: CSV/JSON files under manifests/, logs/, config/
# -------------------------
manifest_all_path = os.path.join(D4_MANIFESTS_DIR, "manifest_all.csv")
manifest_df.to_csv(manifest_all_path, index=False)

for sp in ["train", "val", "test"]:
    p = os.path.join(D4_MANIFESTS_DIR, f"manifest_{sp}.csv")
    manifest_df.loc[manifest_df["split"] == sp].to_csv(p, index=False)

warnings_path = os.path.join(D4_LOGS_DIR, "preprocess_warnings.csv")
warnings_df.to_csv(warnings_path, index=False)

summary = {
    "dataset": "D4",
    "source_root": D4_IPVS_DIR,
    "target_sr": int(D4_TARGET_SR),
    "webrtcvad_available": bool(HAVE_WEBRTCVAD),
    "scipy_available": bool(HAVE_SCIPY),
    "n_source_files_total_found": int(d4_total_wavs),
    "n_source_files_used_after_filter": int(len(d4_index)),
    "n_output_clips_total": int(len(manifest_df)),
    "n_speakers": int(manifest_df["speaker_id"].nunique()) if len(manifest_df) else 0,
    "label_counts_clips": manifest_df["label_str"].value_counts(dropna=False).to_dict() if len(manifest_df) else {},
    "split_counts_clips": manifest_df["split"].value_counts(dropna=False).to_dict() if len(manifest_df) else {},
    "task_counts_clips": manifest_df["task"].value_counts(dropna=False).to_dict() if len(manifest_df) else {},
    "n_warnings": int(len(warnings_df)),
    "notes": "At most 1 clip per source file. Candidates written to clips/_candidates, then capped/split and moved to clips/<split>.",
}
summary_path = os.path.join(D4_LOGS_DIR, "dataset_summary.json")
with open(summary_path, "w", encoding="utf-8") as f:
    json.dump(summary, f, indent=2)

run_cfg = {
    "dataset": "D4",
    "paths": {
        "dataset_dir": D4_IPVS_DIR,
        "out_root": D4_OUT_ROOT,
        "clips_dir": D4_CLIPS_DIR,
        "candidate_dir": D4_CAND_DIR,
        "manifest_all": manifest_all_path,
        "warnings_csv": warnings_path,
        "summary_json": summary_path,
    },
    "folder_structure": "clips/<split>/ (flat) with temporary clips/_candidates during run",
    "filename_format": "D4_{HC|PD}_{speaker_id}_{task<=5}_{global_index:06d}.wav",
    "manifest_schema_order": MANIFEST_COLS,
    "cap_policy": {"groupby": ["speaker_id", "task"], "max_per_group": int(D4_MAX_CLIPS_PER_SPK_PER_TASK)},
    "clip_boundary_policy": "clip_start_sec/clip_end_sec are source-timeline (post-resample). clip_is_contiguous=True.",
    "one_clip_per_source_file": True,
    "one_clip_selection_rule": "Longest voiced segment; vowel=2s centered/padded, reading=first 8s or true duration if shorter (>=min).",
    "audio_processing": {
        "sr_hz": int(D4_TARGET_SR),
        "hum_remove": bool(D4_HUM_REMOVE),
        "hum_base_hz": float(D4_HUM_BASE_HZ),
        "highpass_cutoff_hz": float(D4_HP_CUTOFF_HZ),
        "leveling": "manual RMS-like + peak limit (no pyloudnorm)",
        "peak_ceiling": float(D4_PEAK_CEIL),
        "max_gain_db": float(D4_MAX_GAIN_DB),
    },
}
run_cfg_path = os.path.join(D4_CONFIG_DIR, "run_config.json")
with open(run_cfg_path, "w", encoding="utf-8") as f:
    json.dump(run_cfg, f, indent=2)

print("\nDONE: D4 preprocessing")
print("Manifest:", manifest_all_path)
print("Warnings:", warnings_path)
print("Summary:", summary_path)
print("Run config:", run_cfg_path)
print("Clips written:", int(len(manifest_df)))

# -------------------------
# Quick sanity prints
# Purpose: confirm uniform audio format and clean speaker split
# -------------------------
print("\nSanity checks:")
print("Unique SR:", sorted(manifest_df["sr_hz"].unique().tolist()) if len(manifest_df) else [])
print("Unique channels:", sorted(manifest_df["channels"].unique().tolist()) if len(manifest_df) else [])
print("Speakers per split x label (by speaker):")
print(
    manifest_df[["speaker_id", "split", "label_str"]]
    .drop_duplicates()
    .groupby(["split", "label_str"])
    .size()
)

spk_split_chk = manifest_df[["speaker_id", "split"]].drop_duplicates()
dup = spk_split_chk.groupby("speaker_id")["split"].nunique()
bad = dup[dup > 1]
print("Speakers appearing in multiple splits:", int(len(bad)))
if len(bad) > 0:
    print("Example bad speaker ids:", bad.index.tolist()[:10])
    raise RuntimeError("Speaker appears in more than one split.")

The following cell preprocesses Dataset D5 (MDVR-KCL) into the standard `preprocessed_v1` folder structure used across the project. It scans the raw D5 WAV files and keeps only those that can be linked to a speaker ID and a label (Healthy or Parkinson); the task type is inferred from the folder name and recorded as `unk` when it cannot be determined. The cell then creates consistent clips along with a single combined manifest. A strict rule is enforced so that **no more than one final clip is produced from each source WAV file**. To make the process more reliable on Google Drive, each processed clip is written immediately to a temporary `_candidates` folder, and only the selected clips are later moved into the final `clips/train`, `clips/val`, and `clips/test` folders. As the cell header notes, this v1 pipeline is not used downstream; preprocessing v2 is used in later steps.
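
The staging idea described above can be illustrated with a small, self-contained sketch. It is not the notebook's implementation: the function name `stage_then_promote`, the byte payloads, and the temporary directory are all hypothetical stand-ins for the real WAV writes on Drive.

```python
import os
import shutil
import tempfile

def stage_then_promote(payloads, keep_names, staging_dir, final_dir):
    """Write every payload to staging_dir, then move only keep_names to final_dir."""
    os.makedirs(staging_dir, exist_ok=True)
    os.makedirs(final_dir, exist_ok=True)
    # Step 1: write each candidate as soon as it is produced
    for name, data in payloads.items():
        with open(os.path.join(staging_dir, name), "wb") as f:
            f.write(data)
    # Step 2: promote only the selected candidates
    for name in keep_names:
        shutil.move(os.path.join(staging_dir, name), os.path.join(final_dir, name))
    # Step 3: drop the staging folder together with any unselected candidates
    shutil.rmtree(staging_dir, ignore_errors=True)
    return sorted(os.listdir(final_dir))

# Hypothetical example using a throwaway directory instead of Google Drive
root = tempfile.mkdtemp()
staging = os.path.join(root, "clips", "_candidates")
final = os.path.join(root, "clips", "train")
final_files = stage_then_promote(
    {"a.wav": b"\x00", "b.wav": b"\x01", "c.wav": b"\x02"},
    keep_names=["a.wav", "c.wav"],
    staging_dir=staging,
    final_dir=final,
)
print(final_files)
```

Writing each candidate immediately means an interrupted run loses at most the clip in progress, while the final split folders only ever receive clips that survived selection.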

Audio processing follows a fixed and repeatable pipeline. Each source file is loaded, converted to mono, resampled to 16 kHz, and the first 40 seconds are removed to skip setup sounds or non-target audio. When SciPy is available, a 70 Hz high-pass filter is applied to reduce very low-frequency noise. Speech regions are then detected using voice activity detection, with WebRTC VAD used when available and an energy-based fallback otherwise. If no speech segments are detected, the trimmed signal is passed through unchanged. Before normalizing volume, a **loudness-based gate** is applied to the detected speech regions. Quieter segments are dropped under the assumption that they are more likely to be background speech or another speaker, while the main speaker is typically louder and closer to the microphone. The gate starts at the 65th percentile of segment loudness and relaxes in steps of 5 down to the 35th percentile until enough speech is available to build an 8-second clip; if even the loosest threshold does not yield 8 seconds, all detected segments are kept. The remaining segments are stitched together, then volume normalization is applied using an RMS target with peak limiting, without using pyloudnorm. From this stitched signal, the first full 8 seconds are taken as the final clip. If less than 8 seconds of usable audio remains, the file is skipped and a warning is recorded.
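
The relaxing loudness gate can be sketched in isolation as follows. This is a simplified illustration, not the cell's own function: the name `gate_segments_by_loudness` and the toy segment levels are assumptions, and per-segment RMS values are taken as given rather than computed from audio.

```python
import numpy as np

def gate_segments_by_loudness(seg_rms_db, seg_len_samples, need_samples,
                              start_pctl=65, min_pctl=35, step_pctl=5):
    """Return (kept_indices, chosen_percentile); percentile is None on fallback."""
    rms = np.asarray(seg_rms_db, dtype=np.float64)
    lens = np.asarray(seg_len_samples, dtype=np.int64)
    # Start strict (high percentile) and relax until enough audio is kept
    for p in range(start_pctl, min_pctl - 1, -step_pctl):
        thr = np.percentile(rms, p)
        keep = np.flatnonzero(rms >= thr)
        if lens[keep].sum() >= need_samples:
            return keep.tolist(), p
    # Fallback: no threshold yields enough audio, so keep every segment
    return list(range(len(rms))), None

# Toy example: five segments with levels in dBFS and lengths in samples at 16 kHz
levels = [-18.0, -35.0, -20.0, -40.0, -22.0]
lengths = [48000, 32000, 48000, 32000, 48000]  # 3 s, 2 s, 3 s, 2 s, 3 s
kept, pctl = gate_segments_by_loudness(levels, lengths, need_samples=8 * 16000)
print("kept segments:", kept, "chosen percentile:", pctl)
```

In this toy case the two quietest segments are the ones a background voice would plausibly produce, and the gate only relaxes as far as needed to collect 8 seconds from the louder ones.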

After all candidate clips are created, the cell applies a limit of up to 8 clips per speaker and task, randomly selecting clips when more are available, and deletes any unselected candidate WAV files. Speakers are then split into training, validation, and test sets at the speaker level (70/15/15), so that no speaker appears in more than one split. The split is drawn separately within the Healthy and Parkinson groups so that both labels stay proportionally represented in each split. The selected clips are moved into their final split folders using a standardized filename format, and a final `manifest_all.csv` is written with a consistent column order used throughout the project. The cell also generates `logs/preprocess_warnings.csv` to record any issues, `logs/dataset_summary.json` to summarize counts and rules, and `config/run_config.json` to document all settings and paths. Finally, the temporary `_candidates` folder is removed and sanity checks are run to confirm correct speaker splits and expected audio settings.
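
A minimal sketch of a label-stratified, speaker-level split is shown below. The function name `split_speakers` and the 20 + 20 speaker IDs are hypothetical; only the 70/15/15 proportions and the seed value mirror the configuration described above.

```python
import random
from collections import Counter, defaultdict

def split_speakers(speaker_labels, train_pct=0.70, val_pct=0.15, seed=1337):
    """Map each speaker_id to 'train', 'val', or 'test', stratified by label."""
    by_label = defaultdict(list)
    for spk, lab in sorted(speaker_labels.items()):
        by_label[lab].append(spk)
    rng = random.Random(seed)
    assignment = {}
    for lab in sorted(by_label):
        spks = by_label[lab]
        rng.shuffle(spks)
        n = len(spks)
        n_train = min(int(round(n * train_pct)), n)
        n_val = min(int(round(n * val_pct)), n - n_train)
        for i, spk in enumerate(spks):
            if i < n_train:
                assignment[spk] = "train"
            elif i < n_train + n_val:
                assignment[spk] = "val"
            else:
                assignment[spk] = "test"  # remainder goes to test
    return assignment

# Hypothetical cohort: 20 healthy and 20 Parkinson speakers
speakers = {f"HC{i:02d}": "Healthy" for i in range(20)}
speakers.update({f"PD{i:02d}": "Parkinson" for i in range(20)})
assignment = split_speakers(speakers)
counts = Counter((speakers[s], sp) for s, sp in assignment.items())
print(sorted(counts.items()))
```

Because the split is assigned per speaker and clips inherit their speaker's split, every clip from a given speaker ends up in the same partition. With very small groups the rounded counts can leave the validation or test partition empty, so the per-label counts are worth checking.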

In [None]:
# ============================================================
# D5 Preprocessing v1 (MDVR-KCL): Loudness-Gated VAD Stitching (1 clip per file)
# NOTE: not used downstream; preprocessing v2 is used in later cells.
# Inputs: raw D5 WAV files under the dataset folder
# Outputs: clips/<split>/ WAVs (flat), manifests/manifest_all.csv,
#          config/run_config.json, logs/preprocess_warnings.csv, logs/dataset_summary.json
# Notes: Writes temporary candidates to clips/_candidates/ during processing, then moves kept clips.
# ============================================================

import os
import re
import json
import math
import random
import time
import shutil
from pathlib import Path
from typing import List, Tuple, Dict, Optional

import numpy as np
import pandas as pd
import soundfile as sf
from tqdm.auto import tqdm

# -------------------------
# Drive mount check
# Purpose: ensure dataset and output folders are reachable in Colab
# -------------------------
if not os.path.isdir("/content/drive/MyDrive"):
    from google.colab import drive  # type: ignore
    drive.mount("/content/drive")

# -------------------------
# Optional dependencies
# Purpose: enable VAD and higher-quality resampling/filtering when available
# Inputs: none
# Outputs: flags showing which optional libraries are available
# -------------------------
def _d5_try_import_webrtcvad():
    try:
        import webrtcvad  # type: ignore
        return webrtcvad, True
    except Exception:
        return None, False

webrtcvad, D5_HAVE_WEBRTCVAD = _d5_try_import_webrtcvad()
if not D5_HAVE_WEBRTCVAD:
    try:
        import subprocess, sys
        subprocess.check_call([sys.executable, "-m", "pip", "-q", "install", "webrtcvad"])
    except Exception:
        pass
    webrtcvad, D5_HAVE_WEBRTCVAD = _d5_try_import_webrtcvad()

try:
    from scipy.signal import resample_poly, butter, filtfilt  # type: ignore
    D5_HAVE_SCIPY = True
except Exception:
    D5_HAVE_SCIPY = False

print("webrtcvad available?", D5_HAVE_WEBRTCVAD)
print("scipy available?", D5_HAVE_SCIPY)

# -------------------------
# Dataset paths and standardized outputs
# Inputs: raw dataset folder
# Outputs: preprocessed_v1 folder tree (clips/manifests/config/logs)
# -------------------------
D5_PROJECT_DIR  = "/content/drive/MyDrive/AI_PD_Project"
D5_DATASET_DIR  = f"{D5_PROJECT_DIR}/Datasets/D5-English (MDVR-KCL)"
D5_DIR          = f"{D5_DATASET_DIR}/26-29_09_2017_KCL"

# Outputs
D5_OUT_ROOT     = f"{D5_DATASET_DIR}/preprocessed_v1"
D5_CLIPS_DIR    = f"{D5_OUT_ROOT}/clips"
D5_CAND_DIR     = f"{D5_CLIPS_DIR}/_candidates"     # temporary staging during this run
D5_MANIFEST_DIR = f"{D5_OUT_ROOT}/manifests"
D5_CONFIG_DIR   = f"{D5_OUT_ROOT}/config"
D5_LOGS_DIR     = f"{D5_OUT_ROOT}/logs"

for p in [D5_OUT_ROOT, D5_CLIPS_DIR, D5_CAND_DIR, D5_MANIFEST_DIR, D5_CONFIG_DIR, D5_LOGS_DIR]:
    os.makedirs(p, exist_ok=True)
for sp in ["train", "val", "test"]:
    os.makedirs(os.path.join(D5_CLIPS_DIR, sp), exist_ok=True)

# Purpose: avoid mixing candidate files from earlier runs
if os.path.isdir(D5_CAND_DIR):
    try:
        shutil.rmtree(D5_CAND_DIR)
    except Exception:
        pass
os.makedirs(D5_CAND_DIR, exist_ok=True)

print("\nD5_DIR exists?", os.path.exists(D5_DIR))
if not os.path.exists(D5_DIR):
    raise FileNotFoundError(f"D5_DIR not found: {D5_DIR}")

# -------------------------
# Preprocessing configuration
# Purpose: keep audio format consistent and enforce single-clip policy
# -------------------------
D5_SR = 16000
D5_RANDOM_SEED = 1337
random.seed(D5_RANDOM_SEED)
np.random.seed(D5_RANDOM_SEED)

# Purpose: skip long lead-in content common in some recordings
D5_SKIP_SEC = 40.0

# Purpose: simple, stable loudness control without external loudness libraries
D5_TARGET_RMS_DBFS = -20.0
D5_PEAK_LIMIT_DBFS = -1.0
D5_MIN_RMS_DBFS    = -60.0
D5_MAX_GAIN_DB     = 18.0

# Purpose: reduce low-frequency rumble (only if SciPy is available)
D5_ENABLE_HIGHPASS = True
D5_HIGHPASS_HZ     = 70.0
D5_HIGHPASS_ORDER  = 4

# Purpose: speech segment detection settings
D5_VAD_MODE      = 3
D5_FRAME_MS      = 30
D5_MIN_SPEECH_MS = 200
D5_MERGE_GAP_MS  = 200
D5_PAD_SEC       = 0.25
D5_MIN_KEEP_SEC  = 0.30

# Purpose: enforce at most one 8.0-second clip per source file (shorter sources are dropped)
D5_CLIP_SEC = 8.0

# Purpose: cap per speaker and task after candidate generation
D5_MAX_CLIPS_PER_SPK_TASK = 8

# Purpose: speaker-level split for train/val/test
D5_TRAIN_PCT, D5_VAL_PCT, D5_TEST_PCT = 0.70, 0.15, 0.15

# Purpose: reliable file writes on Drive
D5_WRITE_RETRIES = 4
D5_WRITE_SLEEP   = 0.5

# Purpose: drop quiet segments (assumed other speaker / background) before normalization
D5_ENABLE_LOUDNESS_GATE = True
D5_GATE_START_PCTL = 65
D5_GATE_MIN_PCTL   = 35
D5_GATE_STEP_PCTL  = 5
D5_GATE_REQUIRE_SEC = D5_CLIP_SEC
D5_VAD_SCALE_TARGET_PEAK = 0.25  # only used for VAD stability on very quiet signals

# -------------------------
# Manifest schema (canonical order)
# Purpose: identical column names and ordering across datasets
# Output: manifest_all.csv follows this exact order
# -------------------------
MANIFEST_COLS = [
    "split","dataset","task","speaker_id","speaker_key_rel","sample_id","label_str","label_num","age","sex",
    "clip_path","duration_sec","source_path","clip_start_sec","clip_end_sec","clip_is_contiguous","sr_hz","channels",
]

# -------------------------
# Core audio helpers
# Purpose: read, resample, trim, normalize, and write WAV files safely
# -------------------------
def d5_safe(s: str) -> str:
    # Purpose: make file names filesystem-safe
    return re.sub(r"[^A-Za-z0-9_\-\.]+", "_", str(s))

def d5_db_to_lin(db: float) -> float:
    return 10.0 ** (db / 20.0)

def d5_rms_dbfs(y: np.ndarray) -> float:
    # Purpose: RMS level estimate in dBFS for gating and normalization
    if y is None or len(y) == 0:
        return -120.0
    rms = float(np.sqrt(np.mean(y.astype(np.float64) ** 2) + 1e-12))
    return 20.0 * math.log10(max(rms, 1e-12))

def d5_peak_limit(y: np.ndarray, peak_dbfs: float) -> np.ndarray:
    # Purpose: prevent clipping after gain
    if y is None or len(y) == 0:
        return y
    peak = float(np.max(np.abs(y)))
    lim = d5_db_to_lin(peak_dbfs)
    if peak > lim and peak > 0:
        y = y * (lim / peak)
    return np.clip(y, -1.0, 1.0).astype(np.float32)

def d5_norm_rms_then_peak(y: np.ndarray) -> np.ndarray:
    # Purpose: bring audio to a target RMS without boosting extremely quiet signals too much
    if y is None or len(y) == 0:
        return y
    cur = d5_rms_dbfs(y)
    if cur < D5_MIN_RMS_DBFS:
        return d5_peak_limit(y, D5_PEAK_LIMIT_DBFS)
    gain_db = float(D5_TARGET_RMS_DBFS - cur)
    gain_db = float(np.clip(gain_db, -60.0, D5_MAX_GAIN_DB))
    y2 = (y.astype(np.float32) * d5_db_to_lin(gain_db)).astype(np.float32)
    return d5_peak_limit(y2, D5_PEAK_LIMIT_DBFS)

def d5_read_mono(path: str) -> Tuple[np.ndarray, int]:
    # Purpose: read WAV and convert to mono float32
    x, sr = sf.read(path, always_2d=False)
    if isinstance(x, np.ndarray) and x.ndim == 2:
        x = x.mean(axis=1)
    x = np.asarray(x, dtype=np.float32)
    if not np.isfinite(x).all():
        x = np.nan_to_num(x, nan=0.0, posinf=0.0, neginf=0.0).astype(np.float32)
    return x, int(sr)

def d5_resample(y: np.ndarray, sr_in: int, sr_out: int) -> np.ndarray:
    # Purpose: resample to a common sample rate (16 kHz)
    if sr_in == sr_out:
        return y.astype(np.float32, copy=False)
    if D5_HAVE_SCIPY:
        g = math.gcd(sr_in, sr_out)
        up = sr_out // g
        down = sr_in // g
        return resample_poly(y.astype(np.float64), up, down).astype(np.float32, copy=False)
    n_new = int(round(len(y) * (sr_out / sr_in)))
    if n_new <= 1:
        return y[:1].astype(np.float32, copy=False)
    x_old = np.linspace(0.0, 1.0, num=len(y), endpoint=False)
    x_new = np.linspace(0.0, 1.0, num=n_new, endpoint=False)
    return np.interp(x_new, x_old, y).astype(np.float32, copy=False)

def d5_crop_skip(y: np.ndarray, sr: int, skip_sec: float) -> np.ndarray:
    # Purpose: drop first N seconds (often contains setup noise or other voices)
    k = int(round(sr * skip_sec))
    if k <= 0:
        return y
    if k >= len(y):
        return np.zeros((0,), dtype=np.float32)
    return y[k:].astype(np.float32, copy=False)

def d5_apply_highpass(y: np.ndarray, sr: int) -> np.ndarray:
    # Purpose: remove very low frequencies (only when SciPy is available)
    if not D5_ENABLE_HIGHPASS or not D5_HAVE_SCIPY:
        return y
    cutoff = float(D5_HIGHPASS_HZ)
    if cutoff <= 0:
        return y
    nyq = 0.5 * sr
    wn = min(0.99, cutoff / nyq)
    b, a = butter(int(D5_HIGHPASS_ORDER), wn, btype="highpass")
    # Filter in float64 for numerical stability, then cast back to float32
    y2 = filtfilt(b, a, y.astype(np.float64)).astype(np.float32)
    if not np.isfinite(y2).all():
        y2 = np.nan_to_num(y2, nan=0.0, posinf=0.0, neginf=0.0).astype(np.float32)
    return y2

def d5_pcm16_bytes(y: np.ndarray) -> bytes:
    # Purpose: VAD requires 16-bit PCM bytes
    y = np.clip(y, -1.0, 1.0)
    return (y * 32767.0).astype(np.int16).tobytes()

def d5_task_short(task: str) -> str:
    # Purpose: map noisy folder labels into a small task vocabulary
    t = (task or "").strip().lower()
    if t.startswith("read"):
        return "read"
    if t.startswith("spont"):
        return "spont"
    if t.startswith("vow") or t in ("ah", "a", "vowel"):
        return "vowl"
    return "unk"

def d5_label_to_tag(label_str: str) -> str:
    # Purpose: short tag for filenames
    return "PD" if str(label_str).strip().lower().startswith("parkinson") else "HC"

def d5_safe_write_wav(path: str, audio: np.ndarray, sr: int) -> None:
    # Purpose: retry writes to reduce intermittent Drive failures
    os.makedirs(os.path.dirname(path), exist_ok=True)
    last_err = None
    for attempt in range(1, D5_WRITE_RETRIES + 1):
        try:
            sf.write(path, np.clip(audio, -1.0, 1.0), sr, subtype="PCM_16")
            return
        except Exception as e:
            last_err = e
            time.sleep(D5_WRITE_SLEEP * attempt)
    raise RuntimeError(f"Failed to write WAV: {path}. Last error: {repr(last_err)}")

# -------------------------
# Segmentation + stitching helpers
# Purpose: find speech, optionally keep only louder segments, then stitch into one stream
# -------------------------
def d5_webrtc_segments(y: np.ndarray, sr: int) -> Optional[List[Tuple[int, int]]]:
    # Output: list of (start_sample, end_sample) for speech-like regions
    if not D5_HAVE_WEBRTCVAD:
        return None
    if sr not in (8000, 16000, 32000, 48000):
        return None
    frame_ms = int(D5_FRAME_MS)
    if frame_ms not in (10, 20, 30):
        frame_ms = 30
    frame_len = int(sr * (frame_ms / 1000.0))
    if frame_len <= 0 or len(y) < frame_len:
        return []
    n_frames = int(math.ceil(len(y) / frame_len))
    pad_samp = n_frames * frame_len - len(y)
    if pad_samp > 0:
        y = np.concatenate([y, np.zeros(pad_samp, dtype=np.float32)], axis=0)
    pcm = d5_pcm16_bytes(y)

    vad = webrtcvad.Vad(int(D5_VAD_MODE))
    flags = []
    for i in range(n_frames):
        b0 = i * frame_len * 2
        b1 = b0 + frame_len * 2
        flags.append(vad.is_speech(pcm[b0:b1], sr))

    segs = []
    on = False
    s0 = 0
    for i, f in enumerate(flags):
        if f and not on:
            on = True
            s0 = i
        elif (not f) and on:
            on = False
            segs.append((s0 * frame_len, i * frame_len))
    if on:
        segs.append((s0 * frame_len, n_frames * frame_len))

    n0 = len(y) - pad_samp if pad_samp > 0 else len(y)
    return [(max(0, s), min(n0, e)) for s, e in segs]

def d5_energy_segments(y: np.ndarray, sr: int) -> List[Tuple[int, int]]:
    # Purpose: fallback segmentation when VAD is unavailable
    frame = int(sr * 0.02)
    hop = frame
    if frame <= 0 or len(y) < frame:
        return []
    eng = []
    idx = []
    for i in range(0, len(y) - frame + 1, hop):
        w = y[i:i + frame]
        eng.append(float(np.mean(w * w)))
        idx.append(i)
    eng = np.array(eng, dtype=np.float32)
    thr = float(np.percentile(eng, 25)) * 2.5
    thr = max(thr, 1e-8)
    keep = eng > thr

    segs = []
    in_seg = False
    s0 = 0
    for k, flag in enumerate(keep):
        if flag and not in_seg:
            in_seg = True
            s0 = idx[k]
        elif (not flag) and in_seg:
            in_seg = False
            segs.append((s0, idx[k]))  # idx[k] is the start of the first non-speech frame
    if in_seg:
        segs.append((s0, idx[-1] + frame))
    return segs

def d5_merge_and_filter(segs: List[Tuple[int, int]], sr: int) -> List[Tuple[int, int]]:
    # Purpose: merge nearby segments and drop very short ones
    if not segs:
        return []
    gap = int(sr * (D5_MERGE_GAP_MS / 1000.0))
    min_len = int(sr * (D5_MIN_SPEECH_MS / 1000.0))
    segs = sorted(segs)
    merged = [list(segs[0])]
    for s, e in segs[1:]:
        if s - merged[-1][1] <= gap:
            merged[-1][1] = max(merged[-1][1], e)
        else:
            merged.append([s, e])
    out = []
    for s, e in merged:
        if (e - s) >= min_len:
            out.append((int(s), int(e)))
    return out

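def d5_pad_segs(segs: List[Tuple[int, int]], sr: int, n: int) -> List[Tuple[int, int]]:
    # Purpose: add a small buffer around detected speech, merging segments that
    # overlap after padding so the stitched audio never repeats samples
    pad = int(sr * D5_PAD_SEC)
    out: List[Tuple[int, int]] = []
    for s, e in sorted(segs):
        s2 = max(0, s - pad)
        e2 = min(n, e + pad)
        if e2 <= s2:
            continue
        if out and s2 <= out[-1][1]:
            out[-1] = (out[-1][0], max(out[-1][1], e2))
        else:
            out.append((s2, e2))
    return out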

def d5_scale_for_vad(y: np.ndarray) -> np.ndarray:
    # Purpose: improve VAD stability on extremely quiet recordings (does not change final audio)
    peak = float(np.max(np.abs(y))) if len(y) else 0.0
    if peak <= 0:
        return y
    if peak < 0.02:
        gain = D5_VAD_SCALE_TARGET_PEAK / peak
        gain = float(np.clip(gain, 1.0, 20.0))
        return (y.astype(np.float32) * gain).astype(np.float32)
    return y

def d5_loudness_gate_segments(
    y_pre_norm: np.ndarray,
    segs: List[Tuple[int, int]],
    sr: int,
    require_sec: float
) -> Tuple[List[Tuple[int, int]], Dict]:
    # Purpose: keep loud segments first, relax threshold until enough audio is available
    info = {
        "enabled": bool(D5_ENABLE_LOUDNESS_GATE),
        "start_pctl": int(D5_GATE_START_PCTL),
        "min_pctl": int(D5_GATE_MIN_PCTL),
        "step_pctl": int(D5_GATE_STEP_PCTL),
        "chosen_pctl": None,
        "kept_segments": 0,
        "total_kept_sec": 0.0,
        "fallback": False,
    }
    if (not D5_ENABLE_LOUDNESS_GATE) or (not segs):
        info["fallback"] = True
        return segs, info

    seg_rms_db = []
    for (s, e) in segs:
        seg_rms_db.append(d5_rms_dbfs(y_pre_norm[s:e]))
    seg_rms_db = np.asarray(seg_rms_db, dtype=np.float32)

    need = int(round(sr * require_sec))
    for p in range(int(D5_GATE_START_PCTL), int(D5_GATE_MIN_PCTL) - 1, -int(D5_GATE_STEP_PCTL)):
        thr = float(np.percentile(seg_rms_db, p))
        keep_mask = seg_rms_db >= thr
        kept = [segs[i] for i in range(len(segs)) if bool(keep_mask[i])]
        total = sum((e - s) for (s, e) in kept)
        if total >= need:
            info["chosen_pctl"] = int(p)
            info["kept_segments"] = int(len(kept))
            info["total_kept_sec"] = float(total / sr)
            return kept, info

    # Purpose: avoid dropping everything when audio is scarce
    info["fallback"] = True
    info["kept_segments"] = int(len(segs))
    info["total_kept_sec"] = float(sum((e - s) for (s, e) in segs) / sr)
    return segs, info

def d5_stitch_voiced(y_pre_norm: np.ndarray, sr: int) -> Tuple[np.ndarray, str, Dict]:
    # Inputs: pre-normalization audio
    # Outputs: stitched audio, method label, and gate stats for logging
    y_for_vad = d5_scale_for_vad(y_pre_norm)

    segs = d5_webrtc_segments(y_for_vad, sr)
    if segs is None:
        segs = d5_energy_segments(y_for_vad, sr)
        used = "energy"
    else:
        used = "webrtcvad"

    segs = d5_merge_and_filter(segs, sr)
    segs = d5_pad_segs(segs, sr, len(y_pre_norm))

    if not segs:
        return y_pre_norm.astype(np.float32, copy=False), used, {"enabled": bool(D5_ENABLE_LOUDNESS_GATE), "fallback": True}

    segs2, gate_info = d5_loudness_gate_segments(y_pre_norm, segs, sr, require_sec=D5_GATE_REQUIRE_SEC)
    if not segs2:
        segs2 = segs
        gate_info["fallback"] = True

    stitched = np.concatenate([y_pre_norm[s:e] for (s, e) in segs2], axis=0).astype(np.float32, copy=False)
    return stitched, used, gate_info

def d5_first_full_window(voiced: np.ndarray, sr: int, sec: float) -> Optional[np.ndarray]:
    # Purpose: enforce the single-clip rule (must have full 8.0 seconds)
    need = int(round(sr * sec))
    if len(voiced) < need:
        return None
    return voiced[:need].astype(np.float32, copy=False)

# -------------------------
# Dataset scan and label parsing
# Inputs: raw WAV file paths
# Outputs: base_df with speaker_id, label, task, and source paths
# -------------------------
def d5_guess_task_from_path(p: Path) -> str:
    parts = [x.lower() for x in p.parts]
    if "readtext" in parts:
        return "read"
    if "spontaneousdialogue" in parts or any("spont" in x for x in parts):
        return "spont"
    return "unk"

D5_ID_RE = re.compile(r"(?:^|_)(id)(\d+)", re.IGNORECASE)
def d5_guess_speaker_id(p: Path) -> Optional[str]:
    # Purpose: extract ID## pattern used by this dataset
    m = D5_ID_RE.search(p.stem)
    if not m:
        return None
    return f"ID{int(m.group(2)):02d}"

def d5_guess_group_hc_pd(p: Path) -> str:
    # Purpose: infer HC/PD from folder names or filename tokens
    parts = [x.lower() for x in p.parts]
    if "pd" in parts:
        return "PD"
    if "hc" in parts:
        return "HC"
    s = p.stem.lower()
    if re.search(r"(?:^|_)pd(?:_|$)", s):
        return "PD"
    if re.search(r"(?:^|_)hc(?:_|$)", s):
        return "HC"
    return "Unknown"

wav_paths = sorted(Path(D5_DIR).rglob("*.wav"))
print("\nD5 total WAV files found:", len(wav_paths))

rows = []
for p in tqdm(wav_paths, desc="D5 scan", dynamic_ncols=True):
    spk = d5_guess_speaker_id(p)
    grp = d5_guess_group_hc_pd(p)
    task = d5_guess_task_from_path(p)

    # Purpose: keep only files with a parseable speaker ID and HC/PD label (task may be "unk")
    if spk is None:
        continue
    if grp not in ("HC", "PD"):
        continue

    label_str = "Healthy" if grp == "HC" else "Parkinson"
    label_num = 0 if grp == "HC" else 1

    rows.append({
        "dataset": "D5",
        "speaker_id": spk,
        "audio_path": str(p),
        "task": task,
        "group_code": grp,
        "label_str": label_str,
        "label_num": int(label_num),
        "sample_id": p.name,
    })

base_df = pd.DataFrame(rows)
print("D5 usable rows after scan:", len(base_df))
if len(base_df) == 0:
    raise RuntimeError("No usable D5 WAV rows found after parsing (speaker_id + HC/PD label).")

print("Label counts (rows):")
print(base_df["label_str"].value_counts(dropna=False))
print("Task counts (rows):")
print(base_df["task"].value_counts(dropna=False))

# -------------------------
# Source processing → immediate candidate write (one candidate per source)
# Inputs: base_df rows
# Outputs: candidate WAVs in clips/_candidates/ and cand_df metadata table
# -------------------------
cand_meta: List[Dict] = []
warn_rows: List[Dict] = []

global_cand_counter = 0
proc_df = base_df.sort_values(["speaker_id", "task", "audio_path"]).reset_index(drop=True)

pbar = tqdm(
    proc_df.itertuples(index=False),
    total=len(proc_df),
    desc="D5 preprocess (1 clip per source; gate+stitch; write candidates)",
    dynamic_ncols=True
)

for r in pbar:
    try:
        # Pipeline: read → resample → skip lead-in → optional high-pass
        y, sr0 = d5_read_mono(r.audio_path)
        y = d5_resample(y, sr0, D5_SR)
        y = d5_crop_skip(y, D5_SR, D5_SKIP_SEC)

        if len(y) == 0:
            warn_rows.append({
                "dataset": "D5", "speaker_id": r.speaker_id, "source_path": r.audio_path,
                "warning_type": "skip_after_initial_trim_empty",
                "detail": f"skip_sec={D5_SKIP_SEC}",
            })
            continue

        if D5_ENABLE_HIGHPASS and D5_HAVE_SCIPY:
            y = d5_apply_highpass(y, D5_SR)

        # Key step: gate and stitch before normalization to avoid amplifying quiet background voices
        stitched, seg_used, gate_info = d5_stitch_voiced(y, D5_SR)

        if len(stitched) < int(D5_SR * D5_MIN_KEEP_SEC):
            warn_rows.append({
                "dataset": "D5", "speaker_id": r.speaker_id, "source_path": r.audio_path,
                "warning_type": "stitched_too_short",
                "detail": f"seg_used={seg_used}, stitched_sec={len(stitched)/D5_SR:.3f}",
            })
            continue

        # Normalize only after the final stitched stream is chosen
        stitched = d5_norm_rms_then_peak(stitched)

        # Single-clip rule: must have a full 8 seconds, otherwise drop this source
        one = d5_first_full_window(stitched, D5_SR, D5_CLIP_SEC)
        if one is None:
            warn_rows.append({
                "dataset": "D5", "speaker_id": r.speaker_id, "source_path": r.audio_path,
                "warning_type": "too_short_for_full_clip",
                "detail": f"stitched_sec={len(stitched)/D5_SR:.3f} < {D5_CLIP_SEC}; gate={gate_info}",
            })
            continue

        # Write candidate immediately (staging file)
        global_cand_counter += 1
        cand_name = d5_safe(f"CAND_{global_cand_counter:08d}.wav")
        cand_path = os.path.join(D5_CAND_DIR, cand_name)
        d5_safe_write_wav(cand_path, one, D5_SR)

        cand_meta.append({
            "dataset": "D5",
            "speaker_id": str(r.speaker_id),
            "task": d5_task_short(r.task),
            "label_str": str(r.label_str),
            "label_num": int(r.label_num),
            "sample_id": str(r.sample_id),
            "source_path": str(r.audio_path),
            "group_code": str(r.group_code),
            "duration_sec": float(len(one) / D5_SR),
            "clip_path_cand": cand_path,
            "segmentation_used": seg_used,
            "loudness_gate_info": json.dumps(gate_info),
        })

    except Exception as e:
        warn_rows.append({
            "dataset": "D5",
            "speaker_id": getattr(r, "speaker_id", ""),
            "source_path": getattr(r, "audio_path", ""),
            "warning_type": "preprocess_error",
            "detail": repr(e),
        })

print("\nD5 candidates written:", len(cand_meta))
if len(cand_meta) == 0:
    raise RuntimeError("No D5 candidates produced. Check input audio and parsing.")

cand_df = pd.DataFrame(cand_meta)

# -------------------------
# Cap candidates per speaker and task (then delete unkept candidates)
# Inputs: candidate table + candidate files
# Outputs: cap_df (kept candidates), deleted candidate files removed from staging
# -------------------------
rng = np.random.default_rng(D5_RANDOM_SEED)
keep_idx: List[int] = []

for (spk, task), g in cand_df.groupby(["speaker_id", "task"], sort=False):
    idx = g.index.to_numpy()
    if len(idx) <= D5_MAX_CLIPS_PER_SPK_TASK:
        keep_idx.extend(idx.tolist())
    else:
        chosen = rng.choice(idx, size=D5_MAX_CLIPS_PER_SPK_TASK, replace=False)
        keep_idx.extend(chosen.tolist())

keep_idx = sorted(set(keep_idx))
cap_df = cand_df.loc[keep_idx].reset_index(drop=True)

to_delete = cand_df.loc[~cand_df.index.isin(keep_idx), "clip_path_cand"].tolist()
deleted = 0
for p in to_delete:
    try:
        if os.path.exists(p):
            os.remove(p)
            deleted += 1
    except Exception as e:
        warn_rows.append({
            "dataset": "D5",
            "speaker_id": "",
            "source_path": "",
            "warning_type": "candidate_delete_failed",
            "detail": f"{p} :: {repr(e)}",
        })

print("D5 clips after cap:", int(len(cap_df)))
print("Deleted unkept candidates:", deleted)

# -------------------------
# Speaker-level split (stratified by label_str)
# Inputs: unique speakers from kept candidates
# Outputs: split assignment merged back onto cap_df
# -------------------------
spk_tbl = cap_df[["speaker_id", "label_str"]].drop_duplicates().copy()
rng_py = random.Random(D5_RANDOM_SEED)

split_rows = []
for lab_name, g in spk_tbl.groupby("label_str"):
    spks = g["speaker_id"].tolist()
    rng_py.shuffle(spks)
    n = len(spks)

    n_train = int(round(n * D5_TRAIN_PCT))
    n_val = int(round(n * D5_VAL_PCT))
    n_train = min(n_train, n)
    n_val = min(n_val, n - n_train)

    train = spks[:n_train]
    val = spks[n_train:n_train + n_val]
    test = spks[n_train + n_val:]

    split_rows += [{"speaker_id": s, "split": "train"} for s in train]
    split_rows += [{"speaker_id": s, "split": "val"} for s in val]
    split_rows += [{"speaker_id": s, "split": "test"} for s in test]

spk_split = pd.DataFrame(split_rows)
cap_df["speaker_id"] = cap_df["speaker_id"].astype(str)
spk_split["speaker_id"] = spk_split["speaker_id"].astype(str)
cap_df = cap_df.merge(spk_split, on="speaker_id", how="left", validate="many_to_one")

if cap_df["split"].isna().any():
    ex = cap_df.loc[cap_df["split"].isna(), "speaker_id"].drop_duplicates().head(10).tolist()
    raise RuntimeError(f"Some speakers did not get a split assignment. Example: {ex}")

print("\nSplit counts (capped candidates):")
print(cap_df["split"].value_counts())

# -------------------------
# Finalize: move candidates into clips/<split>/ and build manifest rows
# Inputs: kept candidate WAVs + split assignment
# Outputs: final WAVs in clips/<split>/ and manifest_df rows pointing to final paths
# -------------------------
manifest_rows: List[Dict] = []
global_final_idx = 0

pbar2 = tqdm(cap_df.itertuples(index=False), total=len(cap_df),
             desc="D5 finalize (move kept)", dynamic_ncols=True)

for r in pbar2:
    global_final_idx += 1
    tag = d5_label_to_tag(r.label_str)

    # Purpose: stable, readable filenames for downstream training and debugging
    out_name = d5_safe(f"D5_{tag}_{r.speaker_id}_{r.task}_{global_final_idx:06d}.wav")
    out_path = os.path.join(D5_CLIPS_DIR, r.split, out_name)

    cand_path = getattr(r, "clip_path_cand")
    if not os.path.exists(cand_path):
        warn_rows.append({
            "dataset": "D5",
            "speaker_id": r.speaker_id,
            "source_path": r.source_path,
            "warning_type": "missing_candidate_file",
            "detail": cand_path,
        })
        continue

    try:
        shutil.move(cand_path, out_path)
    except Exception as e:
        warn_rows.append({
            "dataset": "D5",
            "speaker_id": r.speaker_id,
            "source_path": r.source_path,
            "warning_type": "move_error",
            "detail": repr(e),
        })
        continue

    manifest_rows.append({
        "split": r.split,
        "dataset": "D5",
        "task": r.task,
        "speaker_id": r.speaker_id,
        "speaker_key_rel": np.nan,
        "sample_id": r.sample_id,
        "label_str": r.label_str,
        "label_num": int(r.label_num),
        "age": np.nan,
        "sex": np.nan,
        "clip_path": out_path,
        "duration_sec": float(r.duration_sec),
        "source_path": r.source_path,
        "clip_start_sec": np.nan,
        "clip_end_sec": np.nan,
        "clip_is_contiguous": False,  # stitched stream, not a single contiguous region in the source
        "sr_hz": int(D5_SR),
        "channels": 1,
    })

# Purpose: remove temporary staging directory (and any leftover candidates) when done
shutil.rmtree(D5_CAND_DIR, ignore_errors=True)

manifest_df = pd.DataFrame(manifest_rows)
warnings_df = pd.DataFrame(warn_rows)

# Purpose: enforce canonical schema and column order
for c in MANIFEST_COLS:
    if c not in manifest_df.columns:
        manifest_df[c] = np.nan
manifest_df = manifest_df[MANIFEST_COLS].copy()

# -------------------------
# Save required artifacts (manifest, warnings, summary, run config)
# Outputs: CSV/JSON files under manifests/, logs/, config/
# -------------------------
manifest_all_path = os.path.join(D5_MANIFEST_DIR, "manifest_all.csv")
warnings_path     = os.path.join(D5_LOGS_DIR, "preprocess_warnings.csv")
summary_path      = os.path.join(D5_LOGS_DIR, "dataset_summary.json")
run_cfg_path      = os.path.join(D5_CONFIG_DIR, "run_config.json")

manifest_df.to_csv(manifest_all_path, index=False)
warnings_df.to_csv(warnings_path, index=False)

summary = {
    "dataset": "D5",
    "source_root": D5_DIR,
    "sr_hz": int(D5_SR),
    "webrtcvad_available": bool(D5_HAVE_WEBRTCVAD),
    "scipy_available": bool(D5_HAVE_SCIPY),
    "policies": {
        "skip_initial_sec": float(D5_SKIP_SEC),
        "highpass": {"enabled": bool(D5_ENABLE_HIGHPASS and D5_HAVE_SCIPY), "hz": float(D5_HIGHPASS_HZ)},
        "segmentation": "webrtcvad if available else energy",
        "loudness_gate": {
            "enabled": bool(D5_ENABLE_LOUDNESS_GATE),
            "start_percentile": int(D5_GATE_START_PCTL),
            "min_percentile": int(D5_GATE_MIN_PCTL),
            "step_percentile": int(D5_GATE_STEP_PCTL),
            "require_sec": float(D5_GATE_REQUIRE_SEC),
            "note": "gate and stitch on pre-normalization audio; normalize after stitching",
        },
        "normalization": "manual RMS gain + peak limiting; no pyloudnorm",
        "clip": {
            "one_clip_per_source": True,
            "sec": float(D5_CLIP_SEC),
            "selection_rule": "first full 8 s from gated+stitched stream; drop if shorter",
        },
        "cap": {"grouping": ["speaker_id", "task"], "max_per_group": int(D5_MAX_CLIPS_PER_SPK_TASK)},
        "split": {"unit": "speaker", "train": D5_TRAIN_PCT, "val": D5_VAL_PCT, "test": D5_TEST_PCT, "stratified_by": "label_str"},
        "clip_is_contiguous": False,
    },
    "counts": {
        "n_raw_wavs_found": int(len(wav_paths)),
        "n_rows_after_scan": int(len(base_df)),
        "n_candidates_before_cap": int(len(cand_df)),
        "n_candidates_after_cap": int(len(cap_df)),
        "n_clips_written": int(len(manifest_df)),
        "split_counts_clips": manifest_df["split"].value_counts(dropna=False).to_dict() if len(manifest_df) else {},
        "label_counts_clips": manifest_df["label_str"].value_counts(dropna=False).to_dict() if len(manifest_df) else {},
        "task_counts_clips": manifest_df["task"].value_counts(dropna=False).to_dict() if len(manifest_df) else {},
        "n_warnings": int(len(warnings_df)),
    },
}

with open(summary_path, "w", encoding="utf-8") as f:
    json.dump(summary, f, indent=2)

run_cfg = {
    "dataset": "D5",
    "paths": {
        "dataset_dir": D5_DIR,
        "out_root": D5_OUT_ROOT,
        "clips_dir": D5_CLIPS_DIR,
        "candidate_dir": D5_CAND_DIR,
        "manifest_all": manifest_all_path,
        "warnings_csv": warnings_path,
        "summary_json": summary_path,
    },
    "audio": {
        "sr_hz": int(D5_SR),
        "skip_initial_sec": float(D5_SKIP_SEC),
        "highpass": {"enabled": bool(D5_ENABLE_HIGHPASS), "hz": float(D5_HIGHPASS_HZ), "order": int(D5_HIGHPASS_ORDER)},
        "vad": {"mode": int(D5_VAD_MODE), "frame_ms": int(D5_FRAME_MS), "min_speech_ms": int(D5_MIN_SPEECH_MS), "merge_gap_ms": int(D5_MERGE_GAP_MS), "pad_sec": float(D5_PAD_SEC)},
        "loudness_gate": {"enabled": bool(D5_ENABLE_LOUDNESS_GATE), "start_pctl": int(D5_GATE_START_PCTL), "min_pctl": int(D5_GATE_MIN_PCTL), "step_pctl": int(D5_GATE_STEP_PCTL)},
        "normalization": {"target_rms_dbfs": float(D5_TARGET_RMS_DBFS), "peak_limit_dbfs": float(D5_PEAK_LIMIT_DBFS), "min_rms_dbfs": float(D5_MIN_RMS_DBFS), "max_gain_db": float(D5_MAX_GAIN_DB)},
        "clip_sec": float(D5_CLIP_SEC),
        "one_per_source": True,
    },
    "split": {"train": D5_TRAIN_PCT, "val": D5_VAL_PCT, "test": D5_TEST_PCT, "unit": "speaker"},
    "cap": {"max_per_speaker_task": int(D5_MAX_CLIPS_PER_SPK_TASK)},
    "seed": int(D5_RANDOM_SEED),
}

with open(run_cfg_path, "w", encoding="utf-8") as f:
    json.dump(run_cfg, f, indent=2)

print("\nDONE: D5 preprocessing")
print("Manifest:", manifest_all_path)
print("Warnings:", warnings_path)
print("Summary:", summary_path)
print("Run config:", run_cfg_path)
print("Clips written:", int(len(manifest_df)))

# -------------------------
# Quick sanity prints
# Purpose: confirm uniform audio format and clean speaker split
# -------------------------
print("\nSanity checks:")
print("Unique SR:", sorted(manifest_df["sr_hz"].unique().tolist()) if len(manifest_df) else [])
print("Unique channels:", sorted(manifest_df["channels"].unique().tolist()) if len(manifest_df) else [])
print("Speakers per split x label (by speaker):")
if len(manifest_df):
    print(
        manifest_df[["speaker_id", "split", "label_str"]]
        .drop_duplicates()
        .groupby(["split", "label_str"])
        .size()
    )
spk_split_chk = manifest_df[["speaker_id", "split"]].drop_duplicates()
dup = spk_split_chk.groupby("speaker_id")["split"].nunique()
bad = dup[dup > 1]
print("Speakers appearing in multiple splits:", int(len(bad)))
if len(bad) > 0:
    print("Example bad speaker ids:", bad.index.tolist()[:10])
    raise RuntimeError("Speaker appears in more than one split.")

The following cell rebuilds Dataset D5 into a new `preprocessed_v2` folder by changing **only** the train, validation, and test split. The audio itself is not modified or reprocessed. All clips created earlier in `preprocessed_v1` are reused exactly as they are, with filenames kept exactly the same. Only the split assignment and folder location are updated. This makes it possible to test a new split setup while keeping the audio data identical.

The cell begins by mounting Google Drive if needed, defining the D5 input and output paths, and creating the full `preprocessed_v2` folder structure (`clips/train`, `clips/val`, `clips/test`, `manifests`, `logs`, and `config`). It checks that the original v1 manifest exists and prints the source and destination paths so it is clear which dataset version is being used.

Next, a new **speaker-level split policy** is defined using a 50% train, 20% validation, and 30% test split, with a fixed random seed so the assignment can be reproduced. Helper functions are prepared to copy audio clips with retries, move clips between split folders, and write the manifest, warnings, and JSON outputs through a temporary file that is renamed into place, so an interrupted run is less likely to leave partial files.
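
As a rough illustration of how a 50/20/30 policy maps onto whole speakers, the sketch below rounds the train and validation shares and gives the remainder to test. The name `approx_split_counts` is hypothetical, and the sketch leaves out the guardrails against empty splits that the cell itself applies.

```python
def approx_split_counts(n, train_pct=0.50, val_pct=0.20):
    """Round the train/val shares; the test split receives whatever is left."""
    n_train = min(int(round(n * train_pct)), n)
    n_val = min(int(round(n * val_pct)), n - n_train)
    n_test = n - n_train - n_val
    return n_train, n_val, n_test

print(approx_split_counts(10))  # (5, 2, 3)
print(approx_split_counts(20))  # (10, 4, 6)
print(approx_split_counts(40))  # (20, 8, 12)
```

Python's `round` uses round-half-to-even, so a group of 21 speakers gives `(10, 4, 7)` rather than 11 training speakers. Because test takes the remainder, the three counts always add up to `n`.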

The cell then loads the **v1 manifest**, verifies that all required columns are present, filters strictly to D5 entries, and converts any literal `"NaN"` strings into true missing values. Before continuing, it checks that every audio file listed in the v1 manifest actually exists on disk.

A new **speaker-level split** is created and stratified by label (Healthy versus Parkinson’s), so that each speaker appears in only one of the train, validation, or test splits. This new split information is merged back into the clip-level table, replacing the old split assignments from v1.
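
The idea of a label-stratified speaker split can be sketched with the standard library alone. This is a simplified illustration, not the cell's exact logic: `stratified_speaker_split`, the toy speaker ids, and the `HC`/`PD` label strings are all assumptions made for the example.

```python
import random
from collections import defaultdict

def stratified_speaker_split(speaker_labels, fractions=(0.5, 0.2, 0.3), seed=1337):
    """Assign each speaker to one split, shuffling within each label group."""
    by_label = defaultdict(list)
    for spk, lab in speaker_labels.items():
        by_label[lab].append(spk)

    rng = random.Random(seed)
    assignment = {}
    for lab in sorted(by_label):
        spks = sorted(by_label[lab])  # fixed order before shuffling, for repeatability
        rng.shuffle(spks)
        n_train = int(round(len(spks) * fractions[0]))
        n_val = int(round(len(spks) * fractions[1]))
        for i, spk in enumerate(spks):
            if i < n_train:
                assignment[spk] = "train"
            elif i < n_train + n_val:
                assignment[spk] = "val"
            else:
                assignment[spk] = "test"
    return assignment

# Hypothetical toy data: 10 healthy and 10 Parkinson's speakers
toy = {f"HC{i:02d}": "HC" for i in range(10)}
toy.update({f"PD{i:02d}": "PD" for i in range(10)})
split = stratified_speaker_split(toy)
```

Since the mapping is keyed by speaker id, a speaker cannot land in two splits, and shuffling inside each label group keeps the class ratio roughly equal across train, validation, and test.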

For each clip, the cell applies a fixed transfer priority. If the clip already exists in the correct v2 split folder with the same file size as the v1 source, it is reused. If a copy with the matching size exists in a different v2 split folder from a previous run, it is moved. Otherwise, it is copied from the v1 location. The aim is that **each filename exists in exactly one split folder** in v2; the final checks raise an error if a stale duplicate is still present.
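
The keep, move, or copy priority can be tried out on throwaway files. The sketch below is a reduced version under stated assumptions: `place_clip` is a hypothetical helper, file size is the only equality test, and a 64-byte dummy file stands in for a real clip.

```python
import shutil
import tempfile
from pathlib import Path

SPLITS = ("train", "val", "test")

def place_clip(src, clips_root, split, name):
    """Make clips_root/split/name match src by size and report the action taken."""
    src_size = src.stat().st_size
    dst = clips_root / split / name
    dst.parent.mkdir(parents=True, exist_ok=True)

    if dst.exists() and dst.stat().st_size == src_size:
        return "kept"
    for other in SPLITS:
        cand = clips_root / other / name
        if other != split and cand.exists() and cand.stat().st_size == src_size:
            cand.replace(dst)  # move, so the file does not stay in the old split
            return "moved"
    shutil.copy2(src, dst)
    return "copied"

with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    src = root / "v1" / "clip.wav"
    src.parent.mkdir(parents=True)
    src.write_bytes(b"\x00" * 64)  # stand-in for audio data
    clips = root / "v2" / "clips"
    actions = [
        place_clip(src, clips, "train", "clip.wav"),  # first run: nothing in v2 yet
        place_clip(src, clips, "train", "clip.wav"),  # rerun: already in place
        place_clip(src, clips, "test", "clip.wav"),   # split changed: relocate
    ]
    still_in_train = (clips / "train" / "clip.wav").exists()

print(actions)         # ['copied', 'kept', 'moved']
print(still_in_train)  # False
```

Comparing sizes is cheap on a mounted drive but does not prove the bytes are identical; a checksum would be stricter at the cost of reading every file.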

After all files are transferred, the manifest is rewritten so that `clip_path` points to the new v2 locations and `split` holds the new assignment, while all other metadata remains unchanged. The cell then writes the required outputs: a new `manifest_all.csv`, a warnings file for any transfer issues, a CSV listing the speaker split used, a dataset summary JSON with counts and split details, and a run configuration JSON that documents the rebuild policy.
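
Writing an output through a temporary file and then renaming it is a common way to avoid half-written artifacts. The following sketch shows the pattern for a JSON file; `write_json_atomically` and the placeholder payload are illustrative assumptions, and the sketch does not verify how rename behaves on a mounted Drive folder.

```python
import json
import os
import tempfile
from pathlib import Path

def write_json_atomically(path, obj):
    """Write to a sibling temp file, then swap it into place with one rename."""
    tmp = path.with_name(path.name + ".tmp")
    with open(tmp, "w", encoding="utf-8") as f:
        json.dump(obj, f, indent=2)
    os.replace(tmp, path)  # replaces any existing file at `path`

with tempfile.TemporaryDirectory() as d:
    out = Path(d) / "dataset_summary.json"
    write_json_atomically(out, {"dataset": "D5", "mode": "placeholder"})
    loaded = json.loads(out.read_text(encoding="utf-8"))
    leftover_tmp = (Path(d) / "dataset_summary.json.tmp").exists()

print(loaded["dataset"])  # D5
print(leftover_tmp)       # False
```

If the process stops during `json.dump`, only the `.tmp` file is affected and the previous version of the target, if any, is left as it was.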

Finally, the cell runs thorough checks to confirm that no speaker appears in more than one split, every manifest entry points to an existing file, and no filename appears in multiple split folders. A clear summary is printed showing clip counts, speaker counts by split and label, and any warnings that were recorded.
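
Two of these checks, speaker leakage and duplicate filenames, reduce to grouping and counting. The sketch below shows them on plain Python data; both function names and the toy rows are hypothetical, and one leak is planted on purpose so the checks have something to find.

```python
from collections import defaultdict

def speakers_in_multiple_splits(rows):
    """Return speaker ids that are associated with more than one split."""
    seen = defaultdict(set)
    for r in rows:
        seen[r["speaker_id"]].add(r["split"])
    return sorted(s for s, splits in seen.items() if len(splits) > 1)

def filenames_in_multiple_folders(listing):
    """`listing` maps a split name to the filenames found in that folder."""
    where = defaultdict(list)
    for split, names in listing.items():
        for name in names:
            where[name].append(split)
    return {n: s for n, s in where.items() if len(s) > 1}

rows = [
    {"speaker_id": "S01", "split": "train"},
    {"speaker_id": "S01", "split": "train"},
    {"speaker_id": "S02", "split": "val"},
    {"speaker_id": "S02", "split": "test"},  # deliberate leak for the demo
]
leaky = speakers_in_multiple_splits(rows)
dupes = filenames_in_multiple_folders(
    {"train": ["a.wav"], "val": ["b.wav"], "test": ["a.wav"]}
)

print(leaky)  # ['S02']
print(dupes)  # {'a.wav': ['train', 'test']}
```

In a real run either result being non-empty would be grounds to stop, since a speaker or clip shared between splits would inflate evaluation scores.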

In [None]:
# ============================================================
# D5 Split-Only Rebuild → preprocessed_v2 (50/20/30 speaker split)
# Inputs: preprocessed_v1/manifests/manifest_all.csv + existing v1 clips
# Outputs: preprocessed_v2/clips/<split>/ (same filenames), manifests/manifest_all.csv,
#          logs/preprocess_warnings.csv + dataset_summary.json, config/run_config.json
# Notes: No audio work. Only speaker split and file relocation.
# ============================================================

import os
import json
import time
import shutil
import random
from pathlib import Path
from typing import Dict, List, Tuple, Optional

import numpy as np
import pandas as pd
from tqdm.auto import tqdm

# -------------------------
# Drive mount check
# Purpose: ensure preprocessed_v1 and preprocessed_v2 paths are reachable in Colab
# -------------------------
if not os.path.isdir("/content/drive/MyDrive"):
    from google.colab import drive  # type: ignore
    drive.mount("/content/drive")

# -------------------------
# Dataset paths and output folders
# Inputs: v1 manifest and v1 clip files
# Outputs: v2 folder tree (clips/manifests/config/logs)
# -------------------------
D5_PROJECT_DIR   = "/content/drive/MyDrive/AI_PD_Project"
D5_DATASET_DIR   = f"{D5_PROJECT_DIR}/Datasets/D5-English (MDVR-KCL)"

D5_V1_ROOT       = Path(f"{D5_DATASET_DIR}/preprocessed_v1")
D5_V2_ROOT       = Path(f"{D5_DATASET_DIR}/preprocessed_v2")

D5_V1_MANIFEST   = D5_V1_ROOT / "manifests" / "manifest_all.csv"

D5_V2_CLIPS_DIR  = D5_V2_ROOT / "clips"
D5_V2_MAN_DIR    = D5_V2_ROOT / "manifests"
D5_V2_CFG_DIR    = D5_V2_ROOT / "config"
D5_V2_LOGS_DIR   = D5_V2_ROOT / "logs"

# Purpose: create v2 structure without touching v1
for p in [D5_V2_ROOT, D5_V2_CLIPS_DIR, D5_V2_MAN_DIR, D5_V2_CFG_DIR, D5_V2_LOGS_DIR]:
    p.mkdir(parents=True, exist_ok=True)
for sp in ["train", "val", "test"]:
    (D5_V2_CLIPS_DIR / sp).mkdir(parents=True, exist_ok=True)

# Purpose: fail early if v1 manifest is missing
if not D5_V1_MANIFEST.exists():
    raise FileNotFoundError(f"Missing D5 v1 manifest: {str(D5_V1_MANIFEST)}")

print("D5 v1 manifest:", str(D5_V1_MANIFEST))
print("D5 v2 out root:", str(D5_V2_ROOT))

# -------------------------
# Split configuration (speaker-level)
# Purpose: rebuild only the train/val/test speaker assignment
# -------------------------
D5_SEED = 1337
D5_TRAIN_PCT, D5_VAL_PCT, D5_TEST_PCT = 0.50, 0.20, 0.30

# Purpose: make Drive file transfer more reliable
COPY_RETRIES = 4
COPY_SLEEP   = 0.5

# -------------------------
# Canonical manifest column order
# Purpose: keep the manifest consistent across datasets/runs
# -------------------------
MANIFEST_COLS = [
    "split","dataset","task","speaker_id","speaker_key_rel","sample_id","label_str","label_num","age","sex",
    "clip_path","duration_sec","source_path","clip_start_sec","clip_end_sec","clip_is_contiguous","sr_hz","channels",
]

# -------------------------
# Safe writers and transfer helpers
# Purpose: avoid partial files and handle already-copied runs cleanly
# -------------------------
def atomic_write_text(path: Path, text: str):
    # Write to temp and replace to reduce risk of partial outputs
    tmp = path.with_suffix(path.suffix + ".tmp")
    with open(tmp, "w", encoding="utf-8") as f:
        f.write(text)
    os.replace(tmp, path)

def atomic_write_json(path: Path, obj: dict):
    atomic_write_text(path, json.dumps(obj, indent=2))

def atomic_write_csv(path: Path, df: pd.DataFrame):
    # NaN values stay blank in CSV (no literal "NaN")
    tmp = path.with_suffix(path.suffix + ".tmp")
    df.to_csv(tmp, index=False, na_rep="")
    os.replace(tmp, path)

def json_safe(obj):
    # Convert numpy types and tuple keys so JSON dumps cleanly
    if isinstance(obj, dict):
        out = {}
        for k, v in obj.items():
            if isinstance(k, tuple):
                k2 = " | ".join([str(x) for x in k])
            else:
                k2 = str(k) if not isinstance(k, (str, int, float, bool)) and k is not None else k
            out[k2] = json_safe(v)
        return out
    if isinstance(obj, (list, tuple)):
        return [json_safe(x) for x in obj]
    if isinstance(obj, (np.integer,)):
        return int(obj)
    if isinstance(obj, (np.floating,)):
        return float(obj)
    return obj

def split_counts(n: int) -> Tuple[int, int, int]:
    """Deterministic 50/20/30 with rounding and guardrails."""
    if n <= 0:
        return 0, 0, 0

    n_train = int(round(n * D5_TRAIN_PCT))
    n_val   = int(round(n * D5_VAL_PCT))

    n_train = min(max(n_train, 0), n)
    n_val   = min(max(n_val, 0), n - n_train)
    n_test  = n - n_train - n_val

    # Purpose: avoid empty splits when enough speakers exist
    if n >= 3:
        if n_val == 0:
            if n_train > 1:
                n_train -= 1; n_val += 1
            elif n_test > 1:
                n_test -= 1; n_val += 1
        if n_test == 0:
            if n_train > 1:
                n_train -= 1; n_test += 1
            elif n_val > 1:
                n_val -= 1; n_test += 1
        if n_train == 0:
            if n_test > 1:
                n_test -= 1; n_train += 1
            elif n_val > 1:
                n_val -= 1; n_train += 1

    n_train = min(n_train, n)
    n_val   = min(n_val, n - n_train)
    n_test  = n - n_train - n_val
    return n_train, n_val, n_test

def file_size(path: Path) -> Optional[int]:
    # Purpose: compare files without reading full contents
    try:
        return path.stat().st_size
    except Exception:
        return None

def retry_copy(src: Path, dst: Path):
    # Purpose: retry copy to reduce intermittent Drive failures
    dst.parent.mkdir(parents=True, exist_ok=True)
    last_err = None
    for attempt in range(1, COPY_RETRIES + 1):
        try:
            shutil.copy2(str(src), str(dst))
            return
        except Exception as e:
            last_err = e
            time.sleep(COPY_SLEEP * attempt)
    raise RuntimeError(f"Failed to copy: {str(src)} -> {str(dst)}. Last error: {repr(last_err)}")

def safe_transfer_clip(
    src_v1: Path,
    dst_v2: Path,
    fname: str,
    v2_clips_dir: Path,
) -> Tuple[str, Optional[str]]:
    """
    Goal: ensure dst_v2 has the correct clip with the same filename.

    Priority:
      1) If dst exists and size matches v1 → keep it
      2) Else if found in another v2 split and size matches v1 → move it
      3) Else → copy from v1

    Returns:
      (action, detail)
      action in {"SKIP_OK", "MOVED_FROM_OTHER_SPLIT", "COPIED_FROM_V1", "ERROR"}
    """
    src_size = file_size(src_v1)
    if src_size is None or (not src_v1.exists()):
        return "ERROR", f"v1 source missing or unreadable: {str(src_v1)}"

    # 1) already correct
    if dst_v2.exists():
        dst_size = file_size(dst_v2)
        if dst_size == src_size:
            return "SKIP_OK", None
        # wrong-sized dst is treated as stale; repair below

    # 2) check if the clip is in the wrong v2 split folder (leftover from earlier run)
    other_paths = []
    for sp in ["train", "val", "test"]:
        cand = v2_clips_dir / sp / fname
        if cand.exists():
            other_paths.append(cand)

    for cand in other_paths:
        cand_size = file_size(cand)
        if cand_size == src_size:
            dst_v2.parent.mkdir(parents=True, exist_ok=True)
            try:
                os.replace(str(cand), str(dst_v2))
                return "MOVED_FROM_OTHER_SPLIT", str(cand)
            except Exception as e:
                return "ERROR", f"move failed {str(cand)} -> {str(dst_v2)} :: {repr(e)}"

    # 3) copy from v1
    try:
        retry_copy(src_v1, dst_v2)
        if file_size(dst_v2) != src_size:
            return "ERROR", f"copy size mismatch {str(src_v1)} -> {str(dst_v2)}"
        return "COPIED_FROM_V1", None
    except Exception as e:
        return "ERROR", f"copy failed {str(src_v1)} -> {str(dst_v2)} :: {repr(e)}"

# -------------------------
# Load v1 manifest and validate clip files
# Inputs: v1 manifest_all.csv
# Outputs: m1 (v1 rows for D5), verified that v1 clip paths exist
# -------------------------
m1 = pd.read_csv(D5_V1_MANIFEST)

missing = [c for c in MANIFEST_COLS if c not in m1.columns]
if missing:
    raise RuntimeError(f"v1 manifest missing columns: {missing}")

m1 = m1[m1["dataset"].astype(str) == "D5"].copy()
if len(m1) == 0:
    raise RuntimeError("No D5 rows found in v1 manifest.")

# Purpose: keep missing values as real NaN, not the string "NaN"
for c in ["age", "sex", "duration_sec", "clip_start_sec", "clip_end_sec", "speaker_id", "task", "sample_id", "speaker_key_rel"]:
    if c in m1.columns:
        m1[c] = m1[c].replace("NaN", np.nan)

# Purpose: fail early if any v1 clip is missing
missing_files = [p for p in m1["clip_path"].astype(str).tolist() if not os.path.exists(p)]
if missing_files:
    print("Example missing clip_path (first 10):")
    for x in missing_files[:10]:
        print("  ", x)
    raise FileNotFoundError(f"Missing {len(missing_files)} v1 clip files. Fix before rebuilding v2.")

print("D5 v1 clips verified:", len(m1))

# -------------------------
# Build new speaker split (stratified by label_str)
# Inputs: unique speakers from v1
# Outputs: spk_split table: one split per speaker
# -------------------------
spk_tbl = (
    m1[["speaker_id", "label_str", "label_num"]]
    .drop_duplicates()
    .copy()
)

rng = random.Random(D5_SEED)
speaker_split_rows: List[Dict] = []

for lab, g in spk_tbl.groupby("label_str", sort=True):
    spks = g["speaker_id"].astype(str).tolist()
    rng.shuffle(spks)

    n = len(spks)
    n_train, n_val, n_test = split_counts(n)

    train_spks = spks[:n_train]
    val_spks   = spks[n_train:n_train + n_val]
    test_spks  = spks[n_train + n_val:]

    speaker_split_rows += [{"speaker_id": s, "split": "train", "label_str": lab} for s in train_spks]
    speaker_split_rows += [{"speaker_id": s, "split": "val",   "label_str": lab} for s in val_spks]
    speaker_split_rows += [{"speaker_id": s, "split": "test",  "label_str": lab} for s in test_spks]

spk_split = pd.DataFrame(speaker_split_rows)

# Purpose: ensure each speaker is assigned exactly once
if spk_split["speaker_id"].duplicated().any():
    dup = spk_split.loc[spk_split["speaker_id"].duplicated(), "speaker_id"].tolist()[:10]
    raise RuntimeError(f"Speaker appears twice in split mapping. Examples: {dup}")

print("\nNew speaker split (counts):")
print(spk_split.groupby(["split", "label_str"]).size())

# -------------------------
# Apply new split to every clip row
# Inputs: v1 clip-level rows + speaker split table
# Outputs: m2 (clip rows with fresh split values)
# -------------------------
m2 = m1.copy()

# Purpose: prevent split_x / split_y columns during merge
if "split" in m2.columns:
    m2 = m2.drop(columns=["split"])

m2["speaker_id"] = m2["speaker_id"].astype(str)
spk_split["speaker_id"] = spk_split["speaker_id"].astype(str)

m2 = m2.merge(spk_split[["speaker_id", "split"]], on="speaker_id", how="left", validate="many_to_one")

if m2["split"].isna().any():
    ex = m2.loc[m2["split"].isna(), "speaker_id"].drop_duplicates().head(10).tolist()
    raise RuntimeError(f"Some speakers did not receive a split. Examples: {ex}")

# -------------------------
# Transfer clips into v2 (same filenames, new split folders)
# Inputs: v1 clip_path + new split assignment
# Outputs: v2 clip files placed under clips/<split>/ with no duplicates across splits
# -------------------------
warn_rows: List[Dict] = []
actions = {"SKIP_OK": 0, "MOVED_FROM_OTHER_SPLIT": 0, "COPIED_FROM_V1": 0, "ERROR": 0}

# Purpose: avoid filename collisions in v2 split folders
fn_counts = m2["clip_path"].astype(str).map(lambda p: os.path.basename(p)).value_counts()
dupe_fns = fn_counts[fn_counts > 1]
if len(dupe_fns) > 0:
    ex = dupe_fns.head(10).to_dict()
    raise RuntimeError(f"Duplicate clip basenames found in v1 manifest (would collide in v2). Examples: {ex}")

pbar = tqdm(m2.itertuples(index=False), total=len(m2), desc="D5 v2: transfer clips", dynamic_ncols=True)
for r in pbar:
    src_v1 = Path(str(r.clip_path))
    fname  = os.path.basename(str(r.clip_path))
    dst_v2 = D5_V2_CLIPS_DIR / str(r.split) / fname

    action, detail = safe_transfer_clip(src_v1=src_v1, dst_v2=dst_v2, fname=fname, v2_clips_dir=D5_V2_CLIPS_DIR)
    actions[action] = actions.get(action, 0) + 1

    if action == "ERROR":
        warn_rows.append({
            "dataset": "D5",
            "speaker_id": str(r.speaker_id),
            "source_path": str(r.source_path),
            "warning_type": "transfer_failed",
            "detail": detail,
        })

# Purpose: update manifest paths to the v2 clip locations (same filename, new folder)
m2 = m2.copy()
m2["clip_path"] = [
    str(D5_V2_CLIPS_DIR / sp / os.path.basename(p))
    for sp, p in zip(m2["split"].astype(str).tolist(), m2["clip_path"].astype(str).tolist())
]

# -------------------------
# Write v2 artifacts (manifest, warnings, summary, config, speaker split)
# Outputs: files under preprocessed_v2/manifests, logs, config
# -------------------------
for c in MANIFEST_COLS:
    if c not in m2.columns:
        m2[c] = np.nan
m2 = m2[MANIFEST_COLS].copy()

# Purpose: keep true missing values (never literal "NaN")
for c in MANIFEST_COLS:
    if m2[c].dtype == object:
        m2[c] = m2[c].replace("NaN", np.nan)

v2_manifest_path = D5_V2_MAN_DIR / "manifest_all.csv"
v2_warn_path     = D5_V2_LOGS_DIR / "preprocess_warnings.csv"
v2_summary_path  = D5_V2_LOGS_DIR / "dataset_summary.json"
v2_cfg_path      = D5_V2_CFG_DIR / "run_config.json"
v2_spk_split_csv = D5_V2_LOGS_DIR / f"speaker_split_seed{D5_SEED}.csv"

atomic_write_csv(v2_manifest_path, m2)

warnings_df = pd.DataFrame(warn_rows)
atomic_write_csv(v2_warn_path, warnings_df)

# Purpose: write the speaker split through the same atomic helper as the other CSVs
atomic_write_csv(v2_spk_split_csv, spk_split.sort_values(["split", "label_str", "speaker_id"]))

speakers_per_split_x_label = spk_split.groupby(["split", "label_str"]).size().to_dict()

summary = {
    "dataset": "D5",
    "note": "Split-only rebuild: v2 contains v1 clips relocated into new split folders; no audio reprocessing. If a clip existed in the wrong v2 split, it was moved.",
    "source_preprocessed_root": str(D5_V1_ROOT),
    "out_root": str(D5_V2_ROOT),
    "seed": int(D5_SEED),
    "split_policy": {
        "unit": "speaker",
        "train": float(D5_TRAIN_PCT),
        "val": float(D5_VAL_PCT),
        "test": float(D5_TEST_PCT),
        "stratified_by": "label_str",
    },
    "transfer_actions": actions,
    "counts": {
        "n_clips_total": int(len(m2)),
        "split_counts_clips": m2["split"].value_counts(dropna=False).to_dict(),
        "label_counts_clips": m2["label_str"].value_counts(dropna=False).to_dict(),
        "task_counts_clips": m2["task"].value_counts(dropna=False).to_dict(),
        "n_speakers_total": int(spk_tbl["speaker_id"].nunique()),
        "speakers_per_split_x_label": json_safe(speakers_per_split_x_label),
        "n_warnings": int(len(warnings_df)),
    },
    "artifacts": {
        "manifest_all_csv": str(v2_manifest_path),
        "warnings_csv": str(v2_warn_path),
        "speaker_split_csv": str(v2_spk_split_csv),
    },
}
atomic_write_json(v2_summary_path, summary)

run_cfg = {
    "dataset": "D5",
    "mode": "split_only_rebuild",
    "seed": int(D5_SEED),
    "split": {
        "unit": "speaker",
        "train": float(D5_TRAIN_PCT),
        "val": float(D5_VAL_PCT),
        "test": float(D5_TEST_PCT),
        "stratified_by": "label_str",
    },
    "paths": {
        "source_preprocessed_v1": str(D5_V1_ROOT),
        "out_root": str(D5_V2_ROOT),
        "clips_dir": str(D5_V2_CLIPS_DIR),
        "manifest_all": str(v2_manifest_path),
        "warnings_csv": str(v2_warn_path),
        "summary_json": str(v2_summary_path),
        "speaker_split_csv": str(v2_spk_split_csv),
    },
    "transfer_policy": {
        "priority": [
            "keep dst if exists and same size as v1",
            "else move from other v2 split folder if matches size",
            "else copy from v1"
        ],
        "preserve_filename": True,
        "no_audio_processing": True,
    },
}
atomic_write_json(v2_cfg_path, run_cfg)

# -------------------------
# Sanity checks (speaker integrity + file existence + no stale duplicates)
# Inputs: m2 + v2 clips folders
# Outputs: raises error if anything is inconsistent
# -------------------------
# A) Each speaker must be in exactly one split
spk_chk = m2[["speaker_id", "split"]].drop_duplicates()
bad = spk_chk.groupby("speaker_id")["split"].nunique()
bad = bad[bad > 1]
if len(bad) > 0:
    raise RuntimeError(f"Speaker appears in multiple splits. Examples: {bad.index.tolist()[:10]}")

# B) Every manifest row must point to an existing v2 file
missing_v2 = [p for p in m2["clip_path"].astype(str).tolist() if not os.path.exists(p)]
if missing_v2:
    print("Example missing v2 clip_path (first 10):")
    for x in missing_v2[:10]:
        print("  ", x)
    raise FileNotFoundError(f"Missing {len(missing_v2)} v2 clip files after transfer. Check warnings/logs.")

# C) A filename must not exist in more than one v2 split folder
fn_to_splits: Dict[str, List[str]] = {}
for sp in ["train", "val", "test"]:
    sp_dir = D5_V2_CLIPS_DIR / sp
    for fn in os.listdir(sp_dir):
        if fn.lower().endswith(".wav"):
            fn_to_splits.setdefault(fn, []).append(sp)
dupe_across = {fn: sps for fn, sps in fn_to_splits.items() if len(sps) > 1}
if len(dupe_across) > 0:
    ex = list(dupe_across.items())[:10]
    raise RuntimeError(f"Found same filename in multiple v2 split folders (stale duplicates). Examples: {ex}")

print("\nDONE: D5 preprocessed_v2 split-only rebuild (50/20/30)")
print("Transfer actions:", actions)
print("v2 Manifest:", str(v2_manifest_path))
print("v2 Warnings:", str(v2_warn_path))
print("v2 Summary:", str(v2_summary_path))
print("v2 Run config:", str(v2_cfg_path))
print("v2 Speaker split:", str(v2_spk_split_csv))
print("\nClip counts (v2):")
print(m2["split"].value_counts())
print("\nSpeakers per split x label:")
print(spk_split.groupby(["split", "label_str"]).size())
if len(warnings_df):
    print("\nWarnings count:", len(warnings_df))
    print(warnings_df["warning_type"].value_counts().head(10))

The following cell preprocesses Dataset D6 (Ah Sound from Figshare) into a clean and standardized folder that can be used for model training. It mounts Google Drive if needed, installs and imports the required audio tools (including WebRTC VAD when available), defines the input locations for Healthy and Parkinson’s recordings along with the demographics Excel file, and creates the output folders under `preprocessed_v1` (clips, manifests, config, and logs). At the start, it clears the temporary `clips/_candidates` folder and performs basic checks so missing inputs are detected right away.

The cell then cleans and standardizes the audio in a consistent way. Each source WAV file is loaded, converted to mono if needed, resampled to 16 kHz, and normalized using an RMS-based gain with a peak limit to prevent clipping. Speech regions are detected using WebRTC VAD when available, or a simple energy-based fallback otherwise, and a small amount of padding is added around detected speech. A strict rule is applied so that exactly one clip is created from each original source file. This clip is produced by taking a centered segment of up to 2.0 seconds from the speech-focused audio; if the audio is shorter than 2.0 seconds, it is kept at its original length and not padded. If voice detection finds too little usable speech, the code falls back to using the original audio and records a warning.
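
Two of these steps are easy to show in isolation: an RMS gain capped by a peak limit, and the bounds of a centered window. The sketch below uses plain Python lists for clarity. The function names and the `-23.0` and `-1.0` dBFS defaults are assumptions chosen for illustration, not the notebook's settings; only the 16 kHz rate and the 2.0 second target come from the description above.

```python
import math

def rms_gain_with_peak_limit(samples, target_rms_dbfs=-23.0, peak_limit_dbfs=-1.0):
    """Scale samples toward a target RMS without letting the peak pass a limit."""
    if not samples:
        return []
    rms = math.sqrt(sum(x * x for x in samples) / len(samples))
    if rms == 0.0:
        return list(samples)  # silence: nothing to scale
    peak = max(abs(x) for x in samples)
    gain = (10 ** (target_rms_dbfs / 20.0)) / rms
    max_gain = (10 ** (peak_limit_dbfs / 20.0)) / peak
    return [x * min(gain, max_gain) for x in samples]  # peak limit wins on conflict

def centered_bounds(n_samples, sr=16000, target_sec=2.0):
    """Start and end sample indices of a centered window of at most target_sec."""
    want = int(round(target_sec * sr))
    if n_samples <= want:
        return 0, n_samples  # shorter input is kept whole, no padding
    start = (n_samples - want) // 2
    return start, start + want

print(centered_bounds(48000))  # (8000, 40000)
print(centered_bounds(16000))  # (0, 16000)
```

For a 3 second recording at 16 kHz the window drops half a second from each end, while a 1 second recording passes through unchanged. A very spiky signal ends up below the RMS target because the peak limit caps the gain first.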

As each file is processed, the clip is written immediately to a staging folder (`clips/_candidates`) to reduce the risk of data loss if the run is interrupted. At the same time, a metadata row is collected for each candidate, including label, speaker ID, sample ID, age, sex, source path, duration, and clip timing. Any errors or unusual cases, such as empty audio after processing, very short VAD output, or file read and write issues, are logged in a warnings table for later review.

After all candidate clips are written, the cell optionally applies a cap policy that limits the number of clips per speaker per task, using a fixed random seed for repeatability, and deletes the candidate WAV files that were not kept. It then splits speakers into training, validation, and test sets at the speaker level (70% train, 15% validation, 15% test), stratifying by label so that Healthy and Parkinson’s speakers are represented in each split. The selected clips are moved from the staging folder into their final locations (`clips/train`, `clips/val`, and `clips/test`) using a standardized filename format, and a final manifest is created with a consistent structure and updated clip paths.
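
A per-speaker, per-task cap with a fixed seed might look like the sketch below. It is a minimal illustration, not the notebook's implementation: `cap_per_group`, the toy rows, and the cap of 3 are assumptions for the example.

```python
import random
from collections import defaultdict

def cap_per_group(rows, max_per_group, seed=1337):
    """Keep at most max_per_group rows for each (speaker_id, task) pair."""
    groups = defaultdict(list)
    for r in rows:
        groups[(r["speaker_id"], r["task"])].append(r)

    rng = random.Random(seed)
    kept = []
    for key in sorted(groups):  # fixed group order keeps runs repeatable
        members = groups[key]
        if len(members) > max_per_group:
            members = rng.sample(members, max_per_group)
        kept.extend(members)
    return kept

# Hypothetical candidates: speaker S01 has 5 clips, speaker S02 has 2
rows = [{"speaker_id": "S01", "task": "vowl", "sample_id": i} for i in range(5)]
rows += [{"speaker_id": "S02", "task": "vowl", "sample_id": i} for i in range(2)]
kept = cap_per_group(rows, max_per_group=3)

print(len(kept))  # 5
```

Speaker S01 is cut from five clips to three and S02 keeps both of its clips, so no single speaker dominates the training data. Running the function again with the same seed selects the same clips.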

Finally, the cell writes the required outputs: `manifests/manifest_all.csv` for training and evaluation, `logs/preprocess_warnings.csv` with all recorded warnings, `logs/dataset_summary.json` summarizing counts and split details, and `config/run_config.json` describing the preprocessing settings and folder layout. Any remaining temporary files are cleaned up when possible, and a summary is printed showing where the outputs were saved and how many clips and warnings were generated.

In [None]:
# ============================================================
# D6 Preprocessing v1 — Ah Sound (Figshare), One Clip per Source, Standard Outputs
# Inputs: HC/PD WAV folders + demographics Excel
# Outputs: clips/<split>/ WAVs (flat), manifests/manifest_all.csv, config/run_config.json,
#          logs/preprocess_warnings.csv, logs/dataset_summary.json
# ============================================================
# Key rules:
# - Exactly 1 candidate clip per source WAV (no multi-clip splitting)
# - Write candidates immediately, then cap, then speaker split, then move kept clips
# ============================================================

import os
import re
import json
import math
import random
import time
import shutil
from pathlib import Path
from typing import List, Tuple, Dict, Optional

import numpy as np
import pandas as pd
import soundfile as sf
from tqdm.auto import tqdm

# -------------------------
# Drive mount check
# Purpose: ensure dataset paths are reachable in Colab
# -------------------------
if not os.path.isdir("/content/drive/MyDrive"):
    from google.colab import drive  # type: ignore
    drive.mount("/content/drive")

# -------------------------
# Optional dependencies
# Purpose: use WebRTC VAD when available; fall back if missing
# -------------------------
def _d6_try_import_webrtcvad():
    try:
        import webrtcvad  # type: ignore
        return webrtcvad, True
    except Exception:
        return None, False

webrtcvad, D6_HAVE_WEBRTCVAD = _d6_try_import_webrtcvad()
if not D6_HAVE_WEBRTCVAD:
    !pip -q install webrtcvad
    webrtcvad, D6_HAVE_WEBRTCVAD = _d6_try_import_webrtcvad()

try:
    from scipy.signal import resample_poly  # type: ignore
    D6_HAVE_SCIPY = True
except Exception:
    D6_HAVE_SCIPY = False

# -------------------------
# Dataset paths and output structure
# Inputs: raw HC/PD folders and Excel metadata
# Outputs: standardized preprocessed_v1 folder layout
# -------------------------
D6_DATASET_DIR = "/content/drive/MyDrive/AI_PD_Project/Datasets/D6-Ah Sound (Figshare)"
D6_HC_ROOT = os.path.join(D6_DATASET_DIR, "HC_AH", "HC_AH")
D6_PD_ROOT = os.path.join(D6_DATASET_DIR, "PD_AH", "PD_AH")
D6_META_XLSX = os.path.join(D6_DATASET_DIR, "Demographics_age_sex.xlsx")

# Outputs (standardized)
D6_OUT_ROOT     = os.path.join(D6_DATASET_DIR, "preprocessed_v1")
D6_CLIPS_DIR    = os.path.join(D6_OUT_ROOT, "clips")
D6_CAND_DIR     = os.path.join(D6_CLIPS_DIR, "_candidates")  # staging area during the run
D6_MANIFEST_DIR = os.path.join(D6_OUT_ROOT, "manifests")
D6_CONFIG_DIR   = os.path.join(D6_OUT_ROOT, "config")
D6_LOGS_DIR     = os.path.join(D6_OUT_ROOT, "logs")

for p in [D6_OUT_ROOT, D6_CLIPS_DIR, D6_MANIFEST_DIR, D6_CONFIG_DIR, D6_LOGS_DIR]:
    os.makedirs(p, exist_ok=True)
for sp in ["train", "val", "test"]:
    os.makedirs(os.path.join(D6_CLIPS_DIR, sp), exist_ok=True)

# Purpose: avoid mixing candidates from older runs
if os.path.isdir(D6_CAND_DIR):
    # Best-effort cleanup: errors are ignored, so stale files may remain if Drive refuses the delete
    shutil.rmtree(D6_CAND_DIR, ignore_errors=True)
os.makedirs(D6_CAND_DIR, exist_ok=True)

# -------------------------
# Fail-fast checks
# Purpose: stop early if inputs are missing
# -------------------------
print("D6_DATASET_DIR exists?", os.path.exists(D6_DATASET_DIR))
print("D6_HC_ROOT exists?", os.path.exists(D6_HC_ROOT))
print("D6_PD_ROOT exists?", os.path.exists(D6_PD_ROOT))
print("Demographics_age_sex.xlsx exists?", os.path.exists(D6_META_XLSX))
print("webrtcvad available?", D6_HAVE_WEBRTCVAD)
print("scipy available?", D6_HAVE_SCIPY)

if not os.path.exists(D6_DATASET_DIR):
    raise FileNotFoundError(f"D6_DATASET_DIR not found: {D6_DATASET_DIR}")
if not os.path.exists(D6_HC_ROOT) or not os.path.exists(D6_PD_ROOT):
    raise FileNotFoundError("Missing HC or PD root folders. Check HC_AH/HC_AH and PD_AH/PD_AH.")
if not os.path.exists(D6_META_XLSX):
    raise FileNotFoundError(f"Missing metadata Excel: {D6_META_XLSX}")

# -------------------------
# Processing configuration
# Purpose: standardize audio, clip length, splitting, and logging
# -------------------------
D6_SR = 16000
D6_RANDOM_SEED = 1337
random.seed(D6_RANDOM_SEED)
np.random.seed(D6_RANDOM_SEED)

D6_TASK5 = "vowl"         # task label stored in the manifest
D6_TARGET_SEC = 2.0       # target clip length; no padding if shorter

# Manual normalization (no external loudness dependency)
D6_TARGET_RMS_DBFS = -20.0
D6_PEAK_LIMIT_DBFS = -1.0
D6_MIN_RMS_DBFS    = -60.0
D6_MAX_GAIN_DB     = 18.0

# VAD settings (used only to find voiced regions)
D6_VAD_MODE       = 2
D6_FRAME_MS       = 30
D6_MIN_SPEECH_MS  = 200
D6_MERGE_GAP_MS   = 200
D6_PAD_SEC        = 0.25
D6_MIN_KEEP_SEC   = 0.30

# Cap policy (applied after candidates are created)
D6_MAX_CLIPS_PER_SPK_TASK = 8

# Speaker-level split ratios
D6_TRAIN_PCT, D6_VAL_PCT, D6_TEST_PCT = 0.70, 0.15, 0.15

# Reliable write retries (Drive can be flaky)
D6_WRITE_RETRIES = 4
D6_WRITE_SLEEP   = 0.5

# -------------------------
# Standard manifest schema
# Output: manifest_all.csv uses these columns in this order
# -------------------------
MANIFEST_COLS = [
    "split","dataset","task","speaker_id","sample_id",
    "label_str","label_num","age","sex","speaker_key_rel",
    "clip_path","duration_sec","source_path",
    "clip_start_sec","clip_end_sec","sr_hz","channels",
    "clip_is_contiguous",
]

# -------------------------
# Utility helpers (names, audio stats, I/O)
# Purpose: keep core loop readable and consistent
# -------------------------
def d6_safe(s: str) -> str:
    return re.sub(r"[^A-Za-z0-9_\-\.]+", "_", str(s))

def d6_db_to_lin(db: float) -> float:
    return 10.0 ** (db / 20.0)

def d6_rms_dbfs(y: np.ndarray) -> float:
    if y is None or len(y) == 0:
        return -120.0
    rms = float(np.sqrt(np.mean(y.astype(np.float64) ** 2) + 1e-12))
    return 20.0 * math.log10(max(rms, 1e-12))

def d6_peak_limit(y: np.ndarray, peak_dbfs: float) -> np.ndarray:
    # Purpose: avoid clipping after gain changes
    if y is None or len(y) == 0:
        return y
    peak = float(np.max(np.abs(y)))
    lim = d6_db_to_lin(peak_dbfs)
    if peak > lim and peak > 0:
        y = y * (lim / peak)
    return np.clip(y, -1.0, 1.0).astype(np.float32)

def d6_norm_rms_then_peak(y: np.ndarray) -> np.ndarray:
    # Purpose: normalize loudness while protecting peaks
    if y is None or len(y) == 0:
        return y
    cur = d6_rms_dbfs(y)
    if cur < D6_MIN_RMS_DBFS:
        return d6_peak_limit(y, D6_PEAK_LIMIT_DBFS)
    gain_db = float(D6_TARGET_RMS_DBFS - cur)
    gain_db = float(np.clip(gain_db, -60.0, D6_MAX_GAIN_DB))
    y2 = (y.astype(np.float32) * d6_db_to_lin(gain_db)).astype(np.float32)
    return d6_peak_limit(y2, D6_PEAK_LIMIT_DBFS)

def d6_read_mono(path: str) -> Tuple[np.ndarray, int]:
    # Purpose: load audio and ensure a clean mono float32 vector
    x, sr = sf.read(path, always_2d=False)
    if isinstance(x, np.ndarray) and x.ndim == 2:
        x = x.mean(axis=1)
    x = np.asarray(x, dtype=np.float32)
    if not np.isfinite(x).all():
        x = np.nan_to_num(x, nan=0.0, posinf=0.0, neginf=0.0).astype(np.float32)
    return x, int(sr)

def d6_resample(y: np.ndarray, sr_in: int, sr_out: int) -> np.ndarray:
    # Purpose: resample to 16 kHz for all clips
    if sr_in == sr_out:
        return y.astype(np.float32, copy=False)
    if D6_HAVE_SCIPY:
        g = math.gcd(sr_in, sr_out)
        up = sr_out // g
        down = sr_in // g
        return resample_poly(y.astype(np.float64), up, down).astype(np.float32, copy=False)
    n_new = int(round(len(y) * (sr_out / sr_in)))
    if n_new <= 1:
        return y[:1].astype(np.float32, copy=False)
    x_old = np.linspace(0.0, 1.0, num=len(y), endpoint=False)
    x_new = np.linspace(0.0, 1.0, num=n_new, endpoint=False)
    return np.interp(x_new, x_old, y).astype(np.float32, copy=False)

def d6_pcm16_bytes(y: np.ndarray) -> bytes:
    # Purpose: WebRTC VAD expects 16-bit PCM bytes
    y = np.clip(y, -1.0, 1.0)
    return (y * 32767.0).astype(np.int16).tobytes()

def d6_merge(segs: List[Tuple[int,int]], sr: int) -> List[Tuple[int,int]]:
    # Purpose: merge close speech segments and drop very short ones
    if not segs:
        return []
    gap = int(sr * (D6_MERGE_GAP_MS/1000.0))
    min_len = int(sr * (D6_MIN_SPEECH_MS/1000.0))
    segs = sorted(segs)
    merged = [list(segs[0])]
    for s,e in segs[1:]:
        if s - merged[-1][1] <= gap:
            merged[-1][1] = max(merged[-1][1], e)
        else:
            merged.append([s,e])
    out = []
    for s,e in merged:
        if (e-s) >= min_len:
            out.append((int(s), int(e)))
    return out

def d6_pad(segs: List[Tuple[int,int]], sr: int, n: int) -> List[Tuple[int,int]]:
    # Purpose: add small context around detected speech
    pad = int(sr * D6_PAD_SEC)
    out = []
    for s,e in segs:
        s2 = max(0, s-pad)
        e2 = min(n, e+pad)
        if e2 > s2:
            out.append((s2,e2))
    return out

def d6_webrtc_segments(y: np.ndarray, sr: int) -> Optional[List[Tuple[int,int]]]:
    # Purpose: get voiced segments using WebRTC VAD (when possible)
    if not D6_HAVE_WEBRTCVAD:
        return None
    if sr not in (8000, 16000, 32000, 48000):
        return None
    frame_ms = int(D6_FRAME_MS)
    if frame_ms not in (10, 20, 30):
        frame_ms = 30
    frame_len = int(sr * (frame_ms/1000.0))
    if frame_len <= 0 or len(y) < frame_len:
        return []

    n_frames = int(math.ceil(len(y)/frame_len))
    pad_samp = n_frames*frame_len - len(y)
    if pad_samp > 0:
        y = np.concatenate([y, np.zeros(pad_samp, dtype=np.float32)], axis=0)

    pcm = d6_pcm16_bytes(y)
    vad = webrtcvad.Vad(int(D6_VAD_MODE))
    flags = []
    for i in range(n_frames):
        s = i*frame_len*2
        e = s + frame_len*2
        flags.append(vad.is_speech(pcm[s:e], sr))

    segs = []
    on = False
    s0 = 0
    for i,f in enumerate(flags):
        if f and not on:
            on = True
            s0 = i
        elif (not f) and on:
            on = False
            segs.append((s0*frame_len, i*frame_len))
    if on:
        segs.append((s0*frame_len, n_frames*frame_len))

    n0 = len(y) - pad_samp if pad_samp > 0 else len(y)
    return [(max(0,s), min(n0,e)) for s,e in segs]

def d6_energy_segments(y: np.ndarray, sr: int) -> List[Tuple[int,int]]:
    # Purpose: fallback voiced detection using short-time energy
    frame = int(sr * 0.02)
    hop = frame
    if frame <= 0 or len(y) < frame:
        return []
    eng = []
    idx = []
    for i in range(0, len(y) - frame + 1, hop):
        w = y[i:i+frame]
        eng.append(float(np.mean(w*w)))
        idx.append(i)
    eng = np.array(eng, dtype=np.float32)
    thr = float(np.percentile(eng, 25)) * 2.5
    thr = max(thr, 1e-8)
    keep = eng > thr

    segs = []
    in_seg = False
    s0 = 0
    for k, flag in enumerate(keep):
        if flag and not in_seg:
            in_seg = True
            s0 = idx[k]
        elif (not flag) and in_seg:
            in_seg = False
            segs.append((s0, idx[k] + frame))
    if in_seg:
        segs.append((s0, idx[-1] + frame))
    return segs

def d6_voiced_concat(y: np.ndarray, sr: int) -> Tuple[np.ndarray, str]:
    # Purpose: keep voiced parts and stitch them together into one stream
    segs = d6_webrtc_segments(y, sr)
    if segs is None:
        segs = d6_energy_segments(y, sr)
        used = "energy"
    else:
        used = "webrtcvad"
    segs = d6_merge(segs, sr)
    segs = d6_pad(segs, sr, len(y))
    if not segs:
        return y.astype(np.float32, copy=False), used
    return np.concatenate([y[s:e] for s,e in segs], axis=0).astype(np.float32, copy=False), used

def d6_choose_clip_center(voiced: np.ndarray, sr: int, target_sec: float) -> Tuple[np.ndarray, float, float]:
    # Purpose: return ONE clip (<= target_sec), centered if longer, never padded
    if voiced is None or len(voiced) == 0:
        return np.zeros((0,), dtype=np.float32), 0.0, 0.0
    target_len = int(round(sr * target_sec))
    n = len(voiced)
    if target_len <= 0:
        return voiced.astype(np.float32, copy=False), 0.0, float(n)/sr
    if n >= target_len:
        start = (n - target_len) // 2
        end = start + target_len
        return voiced[start:end].astype(np.float32, copy=False), float(start)/sr, float(end)/sr
    return voiced.astype(np.float32, copy=False), 0.0, float(n)/sr

def d6_find_col(df: pd.DataFrame, candidates: List[str]) -> Optional[str]:
    # Purpose: locate metadata columns despite small naming differences
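    # str() guards against non-string column labels (e.g., numeric Excel headers)
    cols_lower = {str(c).lower(): c for c in df.columns}
    # Pass 1: exact, case-insensitive match
    for cand in candidates:
        if cand.lower() in cols_lower:
            return cols_lower[cand.lower()]
    # Pass 2: substring match (first column containing any candidate name)
    for c in df.columns:
        cl = str(c).lower()
        for cand in candidates:
            if cand.lower() in cl:
                return c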
    return None

def d6_list_wavs(root: str) -> List[str]:
    # Purpose: collect all WAVs under a root folder
    return sorted([str(p) for p in Path(root).rglob("*.wav")])

def d6_safe_write_wav(path: str, audio: np.ndarray, sr: int) -> None:
    # Purpose: reliable write with retries
    os.makedirs(os.path.dirname(path), exist_ok=True)
    last_err = None
    for attempt in range(1, D6_WRITE_RETRIES + 1):
        try:
            sf.write(path, np.clip(audio, -1.0, 1.0), sr, subtype="PCM_16")
            return
        except Exception as e:
            last_err = e
            time.sleep(D6_WRITE_SLEEP * attempt)
    raise RuntimeError(f"Failed to write WAV: {path}. Last error: {repr(last_err)}")

# -------------------------
# Build file index (one row per source WAV)
# Inputs: raw HC and PD WAV lists
# Output: d6_files table used for metadata merge and processing
# -------------------------
hc_wavs = d6_list_wavs(D6_HC_ROOT)
pd_wavs = d6_list_wavs(D6_PD_ROOT)

print("\nHC wav count:", len(hc_wavs))
print("PD wav count:", len(pd_wavs))

def d6_build_file_df(wavs: List[str], label_str: str) -> pd.DataFrame:
    # Purpose: attach labels and IDs based on file names
    rows = []
    for fp in wavs:
        base = os.path.basename(fp)
        stem = Path(fp).stem
        rows.append({
            "dataset": "D6",
            "task": D6_TASK5,
            "label_str": label_str,
            "label_num": 1 if label_str.lower() == "parkinson" else 0,
            "speaker_id": str(stem),
            "sample_id": str(base),
            "source_path": str(fp),
        })
    return pd.DataFrame(rows)

d6_files = pd.concat(
    [d6_build_file_df(hc_wavs, "Healthy"), d6_build_file_df(pd_wavs, "Parkinson")],
    ignore_index=True
)

print("D6 file rows:", len(d6_files))
print("Label counts (rows):")
print(d6_files["label_str"].value_counts(dropna=False))

if len(d6_files) == 0:
    raise RuntimeError("No D6 wav files found under HC/PD roots.")

# -------------------------
# Load demographics and merge onto file table
# Inputs: Demographics_age_sex.xlsx
# Output: d6 table with age and sex (when available)
# -------------------------
d6_demo = pd.read_excel(D6_META_XLSX)

col_sample = d6_find_col(d6_demo, ["Sample ID", "SampleID", "Sample_Id", "ID"])
col_age    = d6_find_col(d6_demo, ["Age"])
col_sex    = d6_find_col(d6_demo, ["Sex", "Gender"])

if col_sample is None:
    raise ValueError(f"Could not find a Sample ID column in {D6_META_XLSX}. Columns: {list(d6_demo.columns)}")

d6_demo[col_sample] = d6_demo[col_sample].astype(str).str.strip()
d6_demo["__sample_id_wav__"] = d6_demo[col_sample].apply(lambda x: x if x.lower().endswith(".wav") else f"{x}.wav")

keep_cols = ["__sample_id_wav__"]
if col_age: keep_cols.append(col_age)
if col_sex: keep_cols.append(col_sex)

d6 = d6_files.merge(
    d6_demo[keep_cols],
    left_on="sample_id",
    right_on="__sample_id_wav__",
    how="left"
).drop(columns=["__sample_id_wav__"], errors="ignore")

if col_age:
    d6 = d6.rename(columns={col_age: "age"})
else:
    d6["age"] = np.nan

if col_sex:
    d6 = d6.rename(columns={col_sex: "sex"})
else:
    d6["sex"] = np.nan

missing_both = d6["age"].isna() & d6["sex"].isna()
print("\nRows total:", len(d6))
print("Rows missing BOTH age and sex:", int(missing_both.sum()))

# -------------------------
# Create candidate clips (stream-write)
# Inputs: each source WAV path
# Outputs: one candidate WAV per source in clips/_candidates and a candidate table
# -------------------------
cand_rows: List[Dict] = []
warn_rows: List[Dict] = []
cand_counter = 0

pbar = tqdm(d6.itertuples(index=False), total=len(d6),
            desc="D6 preprocess (1 clip per source; write candidates)", dynamic_ncols=True)

for r in pbar:
    try:
        src = str(r.source_path)

        # Load -> resample -> normalize
        y, sr0 = d6_read_mono(src)
        y = d6_resample(y, sr0, D6_SR)
        y = d6_norm_rms_then_peak(y)

        # Extract voiced stream (or fall back to original if too short)
        voiced, vad_used = d6_voiced_concat(y, D6_SR)
        if len(voiced) < int(D6_SR * D6_MIN_KEEP_SEC):
            voiced = y
            warn_rows.append({
                "dataset":"D6",
                "speaker_id": str(r.speaker_id),
                "source_path": src,
                "warning_type":"vad_too_short_fallback_original",
                "detail": f"vad_used={vad_used}"
            })

        # One clip only: center-crop up to target seconds (no padding)
        clip, st, en = d6_choose_clip_center(voiced, D6_SR, D6_TARGET_SEC)

        if clip is None or len(clip) == 0:
            warn_rows.append({
                "dataset":"D6",
                "speaker_id": str(r.speaker_id),
                "source_path": src,
                "warning_type":"empty_after_processing",
                "detail": "clip length 0"
            })
            continue

        # Write candidate immediately
        cand_counter += 1
        cand_name = d6_safe(f"CAND_{cand_counter:08d}.wav")
        cand_path = os.path.join(D6_CAND_DIR, cand_name)
        d6_safe_write_wav(cand_path, clip.astype(np.float32, copy=False), D6_SR)

        cand_rows.append({
            "dataset":"D6",
            "task": D6_TASK5,
            "speaker_id": str(r.speaker_id),
            "sample_id": str(r.sample_id),
            "label_str": str(r.label_str),
            "label_num": int(r.label_num),
            "age": r.age if pd.notna(r.age) else np.nan,
            "sex": r.sex if pd.notna(r.sex) else np.nan,
            "speaker_key_rel": np.nan,
            "source_path": src,
            "duration_sec": float(len(clip)/D6_SR),
            # Offsets are measured within the stream passed to d6_choose_clip_center
            # (usually the VAD-stitched audio), not necessarily within the source WAV
            "clip_start_sec": float(st),
            "clip_end_sec": float(en),
            "sr_hz": int(D6_SR),
            "channels": 1,
            # Hard-coded flag; note that a clip can span several stitched VAD segments
            "clip_is_contiguous": True,
            "clip_path_cand": cand_path,
        })

    except Exception as e:
        warn_rows.append({
            "dataset":"D6",
            "speaker_id": getattr(r, "speaker_id", ""),
            "source_path": getattr(r, "source_path", ""),
            "warning_type":"preprocess_error",
            "detail": repr(e)
        })

cand_df = pd.DataFrame(cand_rows)
warnings_df = pd.DataFrame(warn_rows)

print("\nD6 candidates written:", len(cand_df))
print("Candidate dir:", D6_CAND_DIR)

if len(cand_df) == 0:
    if len(warnings_df):
        print("\nSample warnings (first 10):")
        print(warnings_df.head(10).to_string(index=False))
    raise RuntimeError("No D6 candidates produced. See warnings above (if any).")

# -------------------------
# Cap candidates per speaker/task and delete the rest
# Inputs: cand_df and cap limit
# Outputs: capped cand_df and fewer files in clips/_candidates
# -------------------------
def d6_apply_cap_and_delete(df: pd.DataFrame, max_k: int, seed: int) -> Tuple[pd.DataFrame, int]:
    # Returns (kept rows, number of candidate files deleted)
    if df.empty:
        return df.copy(), 0

    rng = np.random.default_rng(seed)
    keep_idx: List[int] = []
    for (spk, task), g in df.groupby(["speaker_id", "task"], sort=False):
        idx = g.index.to_numpy()
        if len(idx) <= max_k:
            keep_idx.extend(idx.tolist())
        else:
            chosen = rng.choice(idx, size=max_k, replace=False)
            keep_idx.extend(chosen.tolist())

    keep_idx = sorted(set(keep_idx))
    keep_set = set(keep_idx)

    # Purpose: remove unneeded candidate files to save space
    to_delete = df.loc[~df.index.isin(list(keep_set)), "clip_path_cand"].tolist()
    deleted = 0
    for p in to_delete:
        try:
            if os.path.exists(p):
                os.remove(p)
                deleted += 1
        except Exception as e:
            nonlocal_warn.append({
                "dataset":"D6",
                "speaker_id":"",
                "source_path":"",
                "warning_type":"candidate_delete_failed",
                "detail": f"{p} :: {repr(e)}"
            })

    return df.loc[keep_idx].reset_index(drop=True), deleted

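# Purpose: module-level sink for delete-failure warnings raised inside the cap function.
# The function looks this name up as a global at call time, so it must exist before the call below.
nonlocal_warn: List[Dict] = []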

cand_df, deleted_count = d6_apply_cap_and_delete(cand_df, D6_MAX_CLIPS_PER_SPK_TASK, D6_RANDOM_SEED)
if nonlocal_warn:
    warnings_df = pd.concat([warnings_df, pd.DataFrame(nonlocal_warn)], ignore_index=True)

print("D6 clips after cap:", len(cand_df))
print("Deleted unkept candidates:", int(deleted_count))

# -------------------------
# Speaker-level split (label-stratified)
# Inputs: speakers after cap
# Outputs: a split label for each speaker, then merged onto cand_df
# -------------------------
spk_tbl = cand_df[["speaker_id","label_str"]].drop_duplicates().copy()
rng = random.Random(D6_RANDOM_SEED)

split_rows = []
for lab_name, g in spk_tbl.groupby("label_str"):
    spks = g["speaker_id"].tolist()
    rng.shuffle(spks)
    n = len(spks)

    n_train = int(round(n * D6_TRAIN_PCT))
    n_val   = int(round(n * D6_VAL_PCT))
    n_train = min(n_train, n)
    n_val   = min(n_val, n - n_train)

    train = spks[:n_train]
    val   = spks[n_train:n_train+n_val]
    test  = spks[n_train+n_val:]

    split_rows += [{"speaker_id": s, "split":"train"} for s in train]
    split_rows += [{"speaker_id": s, "split":"val"} for s in val]
    split_rows += [{"speaker_id": s, "split":"test"} for s in test]

spk_split = pd.DataFrame(split_rows)
cand_df = cand_df.merge(spk_split, on="speaker_id", how="left")

# Purpose: ensure every kept speaker received a split
if "split" not in cand_df.columns or cand_df["split"].isna().any():
    ex = cand_df.loc[cand_df.get("split").isna() if "split" in cand_df.columns else [True]*len(cand_df), "speaker_id"] \
                .drop_duplicates().head(10).tolist()
    raise RuntimeError(f"Some speakers did not get split. Example: {ex}")

for sp in ["train","val","test"]:
    os.makedirs(os.path.join(D6_CLIPS_DIR, sp), exist_ok=True)

print("\nSplit counts (clips):")
print(cand_df["split"].value_counts(dropna=False))
print("Label by split (clips):")
print(pd.crosstab(cand_df["split"], cand_df["label_str"]))

# -------------------------
# Finalize clips and build manifest
# Inputs: kept candidate WAVs + split labels
# Outputs: clips/<split>/ WAVs and manifest rows
# -------------------------
final_rows: List[Dict] = []
global_idx = 0

pbarw = tqdm(cand_df.itertuples(index=False), total=len(cand_df),
             desc="D6 finalize (move kept candidates)", dynamic_ncols=True)

for r in pbarw:
    global_idx += 1

    # Purpose: name files consistently without relying on folder structure
    tag = "PD" if str(r.label_str).lower().startswith("parkinson") else "HC"
    spk = str(r.speaker_id)
    split = str(r.split)

    out_name = d6_safe(f"D6_{tag}_{spk}_{D6_TASK5}_{global_idx:06d}.wav")
    out_path = os.path.join(D6_CLIPS_DIR, split, out_name)

    cand_path = str(r.clip_path_cand)
    if not os.path.exists(cand_path):
        warnings_df = pd.concat([warnings_df, pd.DataFrame([{
            "dataset":"D6",
            "speaker_id": spk,
            "source_path": str(r.source_path),
            "warning_type":"missing_candidate_file",
            "detail": cand_path
        }])], ignore_index=True)
        continue

    try:
        shutil.move(cand_path, out_path)
    except Exception as e:
        warnings_df = pd.concat([warnings_df, pd.DataFrame([{
            "dataset":"D6",
            "speaker_id": spk,
            "source_path": str(r.source_path),
            "warning_type":"candidate_move_failed",
            "detail": f"{cand_path} -> {out_path} :: {repr(e)}"
        }])], ignore_index=True)
        continue

    final_rows.append({
        "split": split,
        "dataset": "D6",
        "task": D6_TASK5,
        "speaker_id": spk,
        "sample_id": str(r.sample_id),
        "label_str": str(r.label_str),
        "label_num": int(r.label_num),
        "age": r.age if pd.notna(r.age) else np.nan,
        "sex": r.sex if pd.notna(r.sex) else np.nan,
        "speaker_key_rel": np.nan,
        "clip_path": out_path,
        "duration_sec": float(r.duration_sec),
        "source_path": str(r.source_path),
        "clip_start_sec": float(r.clip_start_sec),
        "clip_end_sec": float(r.clip_end_sec),
        "sr_hz": int(D6_SR),
        "channels": 1,
        "clip_is_contiguous": True,
    })

manifest_df = pd.DataFrame(final_rows)

# Purpose: enforce schema and column order even if some fields are missing
for c in MANIFEST_COLS:
    if c not in manifest_df.columns:
        manifest_df[c] = np.nan
manifest_df = manifest_df[MANIFEST_COLS].copy()

# Purpose: remove staging folder when possible
try:
    if os.path.isdir(D6_CAND_DIR):
        leftovers = list(os.scandir(D6_CAND_DIR))
        for ent in leftovers:
            try:
                os.remove(ent.path)
            except Exception:
                pass
        try:
            os.rmdir(D6_CAND_DIR)
        except Exception:
            pass
except Exception:
    pass

# -------------------------
# Save standardized outputs
# Outputs: manifest_all.csv, preprocess_warnings.csv, dataset_summary.json, run_config.json
# -------------------------
manifest_all_path = os.path.join(D6_MANIFEST_DIR, "manifest_all.csv")
warnings_path     = os.path.join(D6_LOGS_DIR, "preprocess_warnings.csv")
summary_path      = os.path.join(D6_LOGS_DIR, "dataset_summary.json")
run_cfg_path      = os.path.join(D6_CONFIG_DIR, "run_config.json")

manifest_df.to_csv(manifest_all_path, index=False)
warnings_df.to_csv(warnings_path, index=False)

summary = {
    "dataset": "D6",
    "source_root": D6_DATASET_DIR,
    "hc_root": D6_HC_ROOT,
    "pd_root": D6_PD_ROOT,
    "metadata_xlsx": D6_META_XLSX,
    "sr_hz": int(D6_SR),
    "webrtcvad_available": bool(D6_HAVE_WEBRTCVAD),
    "scipy_available": bool(D6_HAVE_SCIPY),
    "task": D6_TASK5,
    "one_clip_per_source_file": True,
    "candidate_write_timing": "written immediately to clips/_candidates then moved to clips/<split>",
    "cap_policy": {
        "applied_stage": "after_candidates_before_splitting",
        "grouping": ["speaker_id", "task"],
        "max_per_group": int(D6_MAX_CLIPS_PER_SPK_TASK)
    },
    "splits": {"train": D6_TRAIN_PCT, "val": D6_VAL_PCT, "test": D6_TEST_PCT},
    "counts": {
        "n_source_files": int(len(d6_files)),
        "n_candidates_written_pre_cap": int(len(cand_rows)),
        "n_candidates_post_cap": int(len(cand_df)),
        "n_clips_written": int(len(manifest_df)),
        "split_counts_clips": manifest_df["split"].value_counts(dropna=False).to_dict() if len(manifest_df) else {},
        "label_counts_clips": manifest_df["label_str"].value_counts(dropna=False).to_dict() if len(manifest_df) else {},
        "n_warnings": int(len(warnings_df)),
        "deleted_unkept_candidates": int(deleted_count),
    }
}
with open(summary_path, "w", encoding="utf-8") as f:
    json.dump(summary, f, indent=2)

run_cfg = {
    "dataset": "D6",
    "paths": {
        "dataset_dir": D6_DATASET_DIR,
        "hc_root": D6_HC_ROOT,
        "pd_root": D6_PD_ROOT,
        "metadata_xlsx": D6_META_XLSX,
        "out_root": D6_OUT_ROOT,
        "clips_dir": D6_CLIPS_DIR,
        "candidate_dir": D6_CAND_DIR,
        "manifest_all": manifest_all_path,
        "warnings_csv": warnings_path,
        "summary_json": summary_path,
    },
    "standard_structure": {
        "clips": "clips/<split>/ (flat) with temporary clips/_candidates during run",
        "manifests": "manifests/manifest_all.csv",
        "config": "config/run_config.json",
        "logs": ["logs/preprocess_warnings.csv", "logs/dataset_summary.json"],
    },
    "filename_format": "D6_{HC|PD}_{speaker_id}_{task<=5}_{global_index:06d}.wav",
    "task_short": D6_TASK5,
    "one_clip_per_source_file": True,
    "clip_selection": {
        "method": "center-crop voiced-concatenated stream",
        "target_sec": float(D6_TARGET_SEC),
        "padding": False
    },
    "normalization": {
        "method": "manual RMS gain + peak limit",
        "target_rms_dbfs": float(D6_TARGET_RMS_DBFS),
        "peak_limit_dbfs": float(D6_PEAK_LIMIT_DBFS),
        "min_rms_dbfs": float(D6_MIN_RMS_DBFS),
        "max_gain_db": float(D6_MAX_GAIN_DB),
        "pyloudnorm": False
    },
    "vad": {
        "method": "webrtcvad if available else energy fallback",
        "mode": int(D6_VAD_MODE),
        "frame_ms": int(D6_FRAME_MS),
        "min_speech_ms": int(D6_MIN_SPEECH_MS),
        "merge_gap_ms": int(D6_MERGE_GAP_MS),
        "pad_sec": float(D6_PAD_SEC),
        "min_keep_sec": float(D6_MIN_KEEP_SEC)
    },
    "cap": {"max_per_speaker_task": int(D6_MAX_CLIPS_PER_SPK_TASK)},
    "seed": int(D6_RANDOM_SEED),
}
with open(run_cfg_path, "w", encoding="utf-8") as f:
    json.dump(run_cfg, f, indent=2)

print("\nDONE: D6 preprocessing")
print("Manifest:", manifest_all_path)
print("Warnings:", warnings_path)
print("Summary:", summary_path)
print("Run config:", run_cfg_path)
print("Clips written:", int(len(manifest_df)))
print("Warnings:", int(len(warnings_df)))
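
The label-stratified, speaker-level split used in the preprocessing cell above can be illustrated in isolation. The following is a minimal sketch using only the standard library; the function name `stratified_speaker_split` and the synthetic speaker IDs are hypothetical and exist only for illustration. It applies the same rounding rule as the cell (round the train and validation counts, give the remainder to test) and assigns each speaker to exactly one split, which is what prevents speaker leakage between training and evaluation.

```python
import random
from collections import Counter, defaultdict

def stratified_speaker_split(speaker_labels, train_pct=0.70, val_pct=0.15, seed=1337):
    """Assign each speaker to train/val/test, separately within each label."""
    rng = random.Random(seed)
    by_label = defaultdict(list)
    for spk, lab in sorted(speaker_labels.items()):
        by_label[lab].append(spk)

    assignment = {}
    for lab in sorted(by_label):
        spks = by_label[lab]
        rng.shuffle(spks)
        n = len(spks)
        n_train = min(int(round(n * train_pct)), n)
        n_val = min(int(round(n * val_pct)), n - n_train)
        for i, spk in enumerate(spks):
            if i < n_train:
                assignment[spk] = "train"
            elif i < n_train + n_val:
                assignment[spk] = "val"
            else:
                assignment[spk] = "test"  # remainder goes to test
    return assignment

# Synthetic speakers (hypothetical IDs): 20 healthy controls and 20 PD speakers
speakers = {f"HC{i:02d}": "Healthy" for i in range(20)}
speakers.update({f"PD{i:02d}": "Parkinson" for i in range(20)})

split = stratified_speaker_split(speakers)
split_counts = Counter(split.values())
print(dict(split_counts))
```

Because the assignment is keyed by speaker, every clip from a speaker inherits that speaker's split when the table is merged back onto the clip rows.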

# Training and Validation of D1, D2, D4, D5 and D6 Datasets

The following cell trains and validates the baseline model on Dataset D1 (NeuroVoz Spanish) using only the preprocessed audio clips listed in `manifests/manifest_all.csv`. It is written to be robust against failures: it checks inputs early, shows progress bars, and saves outputs into a new timestamped experiment folder so previous runs are never overwritten. The Wav2Vec2 backbone remains frozen, and only the small task-specific heads are trained.

The cell starts by importing all required libraries and setting the D1 output root path (`DX_OUT_ROOT`) along with the manifest location. An experiment record is then created using an experiment tag and a timestamp, and the full output folder structure is built as
`<DX_OUT_ROOT>/trainval_runs/exp_<tag>_<timestamp>/run_<dataset>_seed####/`.
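
As an illustration of the folder layout described above, the sketch below composes a run path from its parts. The helper name `build_run_dir`, the `/tmp` root, and the four-digit zero-padded seed format are assumptions made for this example; the notebook cell builds the path itself.

```python
import time
from pathlib import Path

def build_run_dir(out_root: str, tag: str, dataset: str, seed: int, stamp: str) -> Path:
    """Compose <out_root>/trainval_runs/exp_<tag>_<stamp>/run_<dataset>_seed<seed>."""
    exp_root = Path(out_root) / "trainval_runs" / f"exp_{tag}_{stamp}"
    # Assumption: seeds are rendered with at least four digits (seed####)
    return exp_root / f"run_{dataset}_seed{seed:04d}"

# A fresh timestamp per execution keeps older experiment folders intact
stamp = time.strftime("%Y%m%d_%H%M%S")
run_dir = build_run_dir("/tmp/preprocessed_v1", "frozen_LNDO", "D1", 1337, stamp)
print(run_dir)
```

Including the timestamp in the experiment folder name means that two executions with the same tag never write to the same location.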

Next, fixed training settings are defined to ensure repeatable results: up to 10 epochs, an effective batch size of 64 reached through gradient accumulation (per-device batches of 16), a learning rate of 1e-3, early stopping once validation AUROC fails to improve for two consecutive epochs, three fixed random seeds (1337, 2024, 7777), mixed precision on the GPU for speed, and a fixed probability threshold of 0.5 for threshold-based metrics. The main run settings and GPU information are printed.
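
The gradient accumulation arithmetic and the early stopping rule can be sketched without any training code. In the example below, `early_stop_epoch` is a hypothetical helper, the AUROC sequence is made up for illustration, and treating only a strict increase as an improvement is an assumption about the rule.

```python
def early_stop_epoch(val_aurocs, patience=2):
    """Return (best_epoch, best_auroc, last_epoch_run) for 1-indexed epochs."""
    best, best_epoch, bad = float("-inf"), 0, 0
    for epoch, auc in enumerate(val_aurocs, start=1):
        if auc > best:  # assumption: only a strict increase counts as improvement
            best, best_epoch, bad = auc, epoch, 0
        else:
            bad += 1
            if bad >= patience:
                return best_epoch, best, epoch  # stop after `patience` bad epochs
    return best_epoch, best, len(val_aurocs)

EFFECTIVE_BS, PER_DEVICE_BS = 64, 16
# Number of micro-batches whose gradients are summed before one optimizer step
GRAD_ACCUM = max(1, EFFECTIVE_BS // PER_DEVICE_BS)

# Illustrative AUROC values only, not results from the study
result = early_stop_epoch([0.70, 0.74, 0.73, 0.72, 0.80], patience=2)
print(GRAD_ACCUM, result)
```

In this made-up sequence, training stops at epoch 4 and epoch 2 is kept as the best checkpoint, so the later value of 0.80 is never reached. This is the usual trade-off of a short patience window.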

The cell then loads `manifest_all.csv`, checks that all required columns are present, keeps only rows marked as `train` or `val`, and infers the dataset name from the most common `dataset` value if that column exists. It prints the number of training and validation rows along with class counts for Healthy and Parkinson’s clips, and stops immediately if either split is empty. A check with a progress bar then confirms that all audio files listed in `clip_path` exist, failing early and reporting example missing paths if any are found.
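
The fail-fast manifest checks might look roughly like the sketch below. It uses the standard `csv` module to stay dependency-free, whereas the notebook works with pandas. The `REQUIRED` column set, the `load_trainval` helper, the `read` task value, and the inline demo manifest are assumptions for illustration. The file-existence check is omitted here because the demo paths do not exist.

```python
import csv
import io
from collections import Counter

# Assumption: a minimal subset of the manifest columns that training depends on
REQUIRED = {"split", "task", "label_num", "clip_path"}

def load_trainval(csv_text: str):
    """Validate columns, keep train/val rows, and count clips per (split, label)."""
    reader = csv.DictReader(io.StringIO(csv_text))
    missing = REQUIRED - set(reader.fieldnames or [])
    if missing:
        raise ValueError(f"Manifest is missing columns: {sorted(missing)}")
    rows = [r for r in reader if r["split"] in ("train", "val")]
    counts = Counter((r["split"], int(r["label_num"])) for r in rows)
    for split in ("train", "val"):
        if not any(s == split for s, _ in counts):
            raise RuntimeError(f"Split '{split}' is empty")
    return rows, counts

# Hypothetical manifest content for demonstration
demo = """split,task,label_num,clip_path
train,vowl,0,a.wav
train,read,1,b.wav
val,vowl,1,c.wav
test,vowl,0,d.wav
"""
rows, counts = load_trainval(demo)
print(len(rows), dict(counts))
```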

Each row is assigned to a task group, where `task == "vowl"` is treated as vowel speech and all other values are treated as non-vowel speech. A custom dataset loader reads each audio clip, enforces a 16 kHz sample rate, converts audio to mono if needed, and creates an attention mask. For vowel clips, trailing near-zero padding is detected using a small amplitude threshold and masked out so padded silence does not affect learning. For non-vowel clips, the attention mask contains all ones. A custom collator pads audio clips and attention masks within each batch to a common length.
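
The masking and padding logic can be sketched with NumPy. The helper names `vowel_attention_mask` and `collate` are hypothetical, and the handling of an all-quiet clip (kept fully unmasked) is an assumption; the threshold value matches the notebook's `TINY_THRESH` setting.

```python
import numpy as np

TINY_THRESH = 1e-4  # amplitude threshold for detecting trailing padding

def vowel_attention_mask(wave: np.ndarray, thresh: float = TINY_THRESH) -> np.ndarray:
    """Return 1 up to the last sample whose magnitude exceeds thresh, 0 after it."""
    mask = np.zeros(len(wave), dtype=np.int64)
    loud = np.flatnonzero(np.abs(wave) > thresh)
    if loud.size == 0:
        mask[:] = 1  # assumption: an all-quiet clip is left unmasked
        return mask
    mask[: loud[-1] + 1] = 1
    return mask

def collate(waves, masks):
    """Right-pad waveforms and masks with zeros to the longest clip in the batch."""
    n = max(len(w) for w in waves)
    x = np.zeros((len(waves), n), dtype=np.float32)
    m = np.zeros((len(waves), n), dtype=np.int64)
    for i, (w, k) in enumerate(zip(waves, masks)):
        x[i, : len(w)] = w
        m[i, : len(k)] = k
    return x, m

vowel = np.array([0.2, -0.3, 0.1, 0.0, 0.0], dtype=np.float32)  # two padded samples
speech = np.array([0.0, 0.5, 0.0], dtype=np.float32)            # non-vowel: all ones
x, m = collate(
    [vowel, speech],
    [vowel_attention_mask(vowel), np.ones(len(speech), dtype=np.int64)],
)
print(m.tolist())
```

Only trailing quiet samples are masked for vowel clips, so brief low-amplitude stretches inside the phonation remain visible to the model.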

The model uses a frozen Wav2Vec2 backbone with a two-head classifier on top. After feature extraction, a masked mean pooling step summarizes each clip. Two small pre-head blocks, each made of LayerNorm and Dropout, are used, one for vowel clips and one for non-vowel clips, followed by separate linear classifiers. To avoid issues under mixed precision, all head computations are forced to run in float32. Each batch routes data through the correct head based on the task group.
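
Masked mean pooling and per-clip head routing can be shown on toy arrays. The sketch below uses NumPy in place of PyTorch so it runs without a GPU stack, omits the LayerNorm and Dropout blocks, and assumes the mask has already been reduced to frame resolution. The function names and toy weights are hypothetical.

```python
import numpy as np

def masked_mean_pool(hidden: np.ndarray, frame_mask: np.ndarray) -> np.ndarray:
    """Average frames where mask == 1. hidden: (batch, frames, dim) -> (batch, dim)."""
    w = frame_mask[..., None].astype(hidden.dtype)
    denom = np.clip(w.sum(axis=1), 1.0, None)  # guard against an all-zero mask
    return (hidden * w).sum(axis=1) / denom

def route_logits(pooled, is_vowel, w_vowel, w_other):
    """Use the vowel head for vowel clips and the other head for all remaining clips."""
    logits_vowel = pooled @ w_vowel
    logits_other = pooled @ w_other
    return np.where(is_vowel, logits_vowel, logits_other)

hidden = np.array([
    [[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]],  # last frame is padding
    [[2.0, 0.0], [4.0, 0.0], [6.0, 0.0]],
])
frame_mask = np.array([[1, 1, 0], [1, 1, 1]])

pooled = masked_mean_pool(hidden, frame_mask)
logits = route_logits(
    pooled,
    is_vowel=np.array([True, False]),
    w_vowel=np.array([1.0, 0.0]),  # toy weights
    w_other=np.array([0.0, 1.0]),
)
print(pooled.tolist(), logits.tolist())
```

The padded frame with value 100 has no effect on the first clip's pooled vector, which is the point of masking before averaging.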

For each random seed, the cell creates a seed-specific run folder, builds data loaders, and runs a short warm-up by loading three training batches to catch disk or batching problems early. Training then proceeds epoch by epoch with validation after each epoch, tracking validation AUROC. Early stopping is applied when AUROC does not improve for two consecutive epochs. Only the best validation epoch is saved, including the best head weights (`best_heads.pt`), validation plots (`roc_curve.png` and `confusion_matrix.png` at threshold 0.5), and a `metrics.json` file containing the best AUROC and epoch, dataset sizes, label counts, run settings, and additional threshold-based metrics: accuracy, precision, recall (sensitivity), specificity, F1 score, MCC, and a two-sided Fisher exact test p-value.
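
The threshold-based metrics follow directly from the confusion matrix counts at a cutoff of 0.5. The sketch below computes them with the standard library only. The helper `threshold_metrics`, the toy labels and probabilities, and the choice to treat a probability exactly equal to the threshold as positive are assumptions. The Fisher exact test is left out; the notebook imports `scipy.stats.fisher_exact`, which is presumably what computes it.

```python
import math

def threshold_metrics(y_true, y_prob, thr=0.5):
    """Compute confusion counts and derived metrics at a fixed probability threshold."""
    y_pred = [1 if p >= thr else 0 for p in y_prob]  # assumption: >= thr is positive
    pairs = list(zip(y_true, y_pred))
    tp = sum(1 for t, p in pairs if t == 1 and p == 1)
    tn = sum(1 for t, p in pairs if t == 0 and p == 0)
    fp = sum(1 for t, p in pairs if t == 0 and p == 1)
    fn = sum(1 for t, p in pairs if t == 1 and p == 0)

    def safe(num, den):
        return num / den if den else 0.0  # undefined ratios are reported as 0.0

    precision = safe(tp, tp + fp)
    recall = safe(tp, tp + fn)  # sensitivity
    specificity = safe(tn, tn + fp)
    return {
        "tp": tp, "tn": tn, "fp": fp, "fn": fn,
        "accuracy": safe(tp + tn, len(pairs)),
        "precision": precision,
        "recall": recall,
        "specificity": specificity,
        "f1": safe(2 * precision * recall, precision + recall),
        "mcc": safe(tp * tn - fp * fn,
                    math.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))),
    }

# Toy labels and probabilities, not results from the study
metrics = threshold_metrics([1, 1, 1, 0, 0, 0], [0.9, 0.8, 0.4, 0.6, 0.2, 0.1])
print(metrics)
```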

After all three seeds complete, the cell prints the AUROC for each seed and computes the mean AUROC with a 95% confidence interval using a t-distribution with n = 3 seeds (two degrees of freedom). A single `summary_trainval.json` file is written for the experiment, and the same summary is appended as one line to the global history log at
`<DX_OUT_ROOT>/trainval_runs/history_index.jsonl`.
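
For three seeds, the confidence interval and the append-only history log can be sketched as follows. The AUROC values are illustrative only, the helper names and record fields are hypothetical, and the t critical value for two degrees of freedom is hard-coded as a constant to avoid a SciPy dependency in this example.

```python
import json
import math
import os
import statistics
import tempfile

T_CRIT_95_DF2 = 4.302653  # two-sided 95% critical value of Student's t, 2 degrees of freedom

def mean_ci95(values):
    """Mean and 95% t-interval for exactly three values (one per seed)."""
    if len(values) != 3:
        raise ValueError("this sketch hard-codes the critical value for n = 3")
    mean = statistics.mean(values)
    sem = statistics.stdev(values) / math.sqrt(len(values))  # sample SD, n - 1 denominator
    half = T_CRIT_95_DF2 * sem
    return mean, mean - half, mean + half

def append_jsonl(path, record):
    """Append one JSON object per line so earlier records are never rewritten."""
    with open(path, "a", encoding="utf-8") as f:
        f.write(json.dumps(record) + "\n")

aucs = [0.80, 0.82, 0.84]  # illustrative values, not study results
mean, lo, hi = mean_ci95(aucs)

with tempfile.TemporaryDirectory() as tmp:
    hist = os.path.join(tmp, "history_index.jsonl")
    append_jsonl(hist, {"tag": "demo_a", "mean_auroc": mean, "ci95": [lo, hi]})
    append_jsonl(hist, {"tag": "demo_b", "mean_auroc": mean, "ci95": [lo, hi]})
    with open(hist, encoding="utf-8") as f:
        n_lines = sum(1 for _ in f)

print(round(mean, 4), round(lo, 4), round(hi, 4), n_lines)
```

With only three seeds the critical value is large (about 4.30 rather than the 1.96 of a normal approximation), so the interval is wide even when the seeds agree closely.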

Finally, the cell prints the path to the experiment folder and unassigns the Colab runtime to cleanly stop the L4 GPU session.

In [None]:
# ============================================================
# D1 Train + Val — Frozen Wav2Vec2 + Two Task Heads, Full Run History
# Inputs: manifest_all.csv (train/val rows) and audio clips referenced by clip_path
# Outputs: per-seed run folders (metrics.json, roc_curve.png, confusion_matrix.png, best_heads.pt),
#          plus one experiment summary_trainval.json and one appended history_index.jsonl record
# ============================================================

import os, json, math, random, time, warnings
from pathlib import Path

import numpy as np
import pandas as pd
import soundfile as sf

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

from transformers import Wav2Vec2Model

from sklearn.metrics import roc_auc_score, roc_curve, confusion_matrix, matthews_corrcoef
from scipy.stats import fisher_exact
import matplotlib.pyplot as plt

from tqdm.auto import tqdm

# -------------------------
# Dataset root and manifest location
# Input: dataset output root created by preprocessing
# Output: MANIFEST_ALL path used for train/val tables
# -------------------------
DX_OUT_ROOT = "/content/drive/MyDrive/AI_PD_Project/Datasets/D1-NeuroVoz-Castillan Spanish/preprocessed_v1"
MANIFEST_ALL = f"{DX_OUT_ROOT}/manifests/manifest_all.csv"

# -------------------------
# Experiment folder naming (keeps older runs intact)
# Output: one EXP_ROOT folder for this execution
# -------------------------
EXPERIMENT_TAG = "frozen_LNDO"   # short label used in folder and summary files
RUN_STAMP = time.strftime("%Y%m%d_%H%M%S")  # unique timestamp per run
TRAINVAL_ROOT = Path(DX_OUT_ROOT) / "trainval_runs"
EXP_ROOT = TRAINVAL_ROOT / f"exp_{EXPERIMENT_TAG}_{RUN_STAMP}"
EXP_ROOT.mkdir(parents=True, exist_ok=True)

# -------------------------
# Training settings (same across seeds)
# Purpose: stable and comparable runs
# -------------------------
MAX_EPOCHS     = 10
EFFECTIVE_BS   = 64
PER_DEVICE_BS  = 16
GRAD_ACCUM     = max(1, EFFECTIVE_BS // PER_DEVICE_BS)

LR             = 1e-3
PATIENCE       = 2
SEEDS          = [1337, 2024, 7777]

BACKBONE_CKPT  = "facebook/wav2vec2-base"
SR_EXPECTED    = 16000

# Purpose: detect where trailing padding starts for vowel clips
TINY_THRESH    = 1e-4

# Purpose: small regularization in the trainable head path
DROPOUT_P      = 0.2

# Purpose: stable loading from Drive-backed files
NUM_WORKERS    = 0
PIN_MEMORY     = False

# Purpose: fixed operating point for confusion matrix and threshold metrics
VAL_THRESHOLD  = 0.5

# Purpose: speed on GPU while keeping head math stable (heads forced FP32 later)
USE_AMP        = True

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
try:
    torch.set_float32_matmul_precision("high")
except Exception:
    pass

warnings.filterwarnings("ignore", message="None of the inputs have requires_grad=True. Gradients will be None")
warnings.filterwarnings("ignore", message="Passing `gradient_checkpointing` to a config initialization is deprecated")

print("EXPERIMENT ROOT:", str(EXP_ROOT))
print("DEVICE:", DEVICE)
if DEVICE.type == "cuda":
    print("GPU:", torch.cuda.get_device_name(0))
print("PER_DEVICE_BS:", PER_DEVICE_BS, "| GRAD_ACCUM:", GRAD_ACCUM, "| EFFECTIVE_BS:", PER_DEVICE_BS * GRAD_ACCUM)
print("NUM_WORKERS:", NUM_WORKERS, "| PIN_MEMORY:", PIN_MEMORY)
print("VAL_THRESHOLD:", VAL_THRESHOLD)
print("USE_AMP:", bool(USE_AMP and DEVICE.type == "cuda"))

# -------------------------
# Load and filter manifest (train/val only)
# Inputs: manifest_all.csv
# Outputs: train_df, val_df, dataset_id, and basic counts printed
# -------------------------
if not os.path.exists(MANIFEST_ALL):
    raise FileNotFoundError(f"Missing manifest_all.csv: {MANIFEST_ALL}")

m = pd.read_csv(MANIFEST_ALL)
req_cols = {"split", "clip_path", "label_num", "task"}
missing = [c for c in sorted(req_cols) if c not in m.columns]
if missing:
    raise ValueError(f"Manifest missing required columns: {missing}. Found: {list(m.columns)}")

m = m[m["split"].isin(["train", "val"])].copy()
if len(m) == 0:
    raise RuntimeError("After filtering to split in {train,val}, manifest has 0 rows.")

# Purpose: pick the dominant dataset label if multiple are present
if "dataset" in m.columns and m["dataset"].notna().any():
    dataset_id = str(m["dataset"].value_counts(dropna=True).idxmax())
    m = m[m["dataset"].astype(str) == dataset_id].copy()
else:
    dataset_id = "DX"

# Purpose: keep a consistent set of columns even if some are missing
keep_cols = ["clip_path", "label_num", "task", "speaker_id", "duration_sec", "split"]
for c in keep_cols:
    if c not in m.columns:
        m[c] = np.nan
m = m[keep_cols].copy()

train_df = m[m["split"] == "train"].copy().reset_index(drop=True)
val_df   = m[m["split"] == "val"].copy().reset_index(drop=True)

print(f"\nDataset inferred: {dataset_id}")
print(f"Train rows: {len(train_df)} | Val rows: {len(val_df)}")
print("Train label counts:", train_df["label_num"].value_counts(dropna=False).to_dict())
print("Val label counts:",   val_df["label_num"].value_counts(dropna=False).to_dict())

if len(train_df) == 0 or len(val_df) == 0:
    raise RuntimeError("Train or Val split has 0 rows.")

# -------------------------
# Clip existence check (progress bar)
# Inputs: train_df/val_df clip_path
# Output: fail early with a few missing examples
# -------------------------
def _fail_fast_missing_paths(df: pd.DataFrame, name: str):
    missing_paths = []
    for p in tqdm(df["clip_path"].astype(str).tolist(), desc=f"Check {name} clip_path exists", dynamic_ncols=True):
        if not os.path.exists(p):
            missing_paths.append(p)
            if len(missing_paths) >= 10:
                break
    if missing_paths:
        raise FileNotFoundError(f"{name}: missing clip_path(s). Examples: {missing_paths}")

_fail_fast_missing_paths(train_df, "TRAIN")
_fail_fast_missing_paths(val_df, "VAL")

# -------------------------
# Task grouping (vowel vs other)
# Purpose: route each clip to the matching head
# Output: task_group column used by dataset and model
# -------------------------
def _task_group(task_val) -> str:
    return "vowel" if str(task_val) == "vowl" else "other"

train_df["task_group"] = train_df["task"].apply(_task_group)
val_df["task_group"]   = val_df["task"].apply(_task_group)

# -------------------------
# Dataset and collator (dynamic padding + custom mask)
# Inputs: one manifest row per sample
# Outputs: padded batches with input_values, attention_mask, labels, task_group
# -------------------------
class AudioManifestDataset(Dataset):
    def __init__(self, df: pd.DataFrame):
        self.df = df.reset_index(drop=True)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx: int):
        row = self.df.iloc[idx]
        clip_path = str(row["clip_path"])
        label = int(row["label_num"])
        task_group = str(row["task_group"])

        if not os.path.exists(clip_path):
            raise FileNotFoundError(f"Missing clip_path: {clip_path}")

        # Purpose: load audio, convert to mono, force float32
        y, sr = sf.read(clip_path, always_2d=False)
        if isinstance(y, np.ndarray) and y.ndim > 1:
            y = np.mean(y, axis=1)
        y = np.asarray(y, dtype=np.float32)

        # Purpose: enforce training sample rate
        if int(sr) != SR_EXPECTED:
            raise ValueError(f"Sample rate mismatch (expected {SR_EXPECTED}): {clip_path} has sr={sr}")

        # Purpose: hide trailing padding in vowel clips; other clips are fully visible
        attn = np.ones((len(y),), dtype=np.int64)
        if task_group == "vowel":
            k = -1
            for j in range(len(y) - 1, -1, -1):
                if abs(float(y[j])) > TINY_THRESH:
                    k = j
                    break
            if k >= 0:
                attn[:k+1] = 1
                attn[k+1:] = 0
            else:
                attn[:] = 1
        else:
            attn[:] = 1

        return {
            "input_values": torch.from_numpy(y),
            "attention_mask": torch.from_numpy(attn),
            "labels": torch.tensor(label, dtype=torch.long),
            "task_group": task_group,
        }

def collate_fn(batch):
    # Purpose: pad each batch to its own max length, and pad mask with zeros
    max_len = int(max(b["input_values"].numel() for b in batch))
    input_vals, attn_masks, labels, task_groups = [], [], [], []
    for b in batch:
        x = b["input_values"]
        a = b["attention_mask"]
        pad = max_len - x.numel()
        if pad > 0:
            x = torch.cat([x, torch.zeros(pad, dtype=x.dtype)], dim=0)
            a = torch.cat([a, torch.zeros(pad, dtype=a.dtype)], dim=0)
        input_vals.append(x)
        attn_masks.append(a)
        labels.append(b["labels"])
        task_groups.append(b["task_group"])
    return {
        "input_values": torch.stack(input_vals, dim=0),
        "attention_mask": torch.stack(attn_masks, dim=0),
        "labels": torch.stack(labels, dim=0),
        "task_group": task_groups,
    }

# -------------------------
# Model: frozen backbone + per-task LN/Dropout + per-task linear head
# Inputs: padded audio batch + attention_mask + task_group
# Outputs: loss and logits (PD probability comes from softmax(logits)[:,1])
# -------------------------
class Wav2Vec2TwoHeadClassifier(nn.Module):
    def __init__(self, ckpt: str, dropout_p: float):
        super().__init__()

        # Purpose: backbone weights are frozen (requires_grad=False); only the heads train
        self.backbone = Wav2Vec2Model.from_pretrained(ckpt)
        for p in self.backbone.parameters():
            p.requires_grad = False

        H = int(self.backbone.config.hidden_size)

        # Purpose: task-specific normalization and dropout before each head
        self.pre_vowel = nn.Sequential(nn.LayerNorm(H), nn.Dropout(float(dropout_p)))
        self.pre_other = nn.Sequential(nn.LayerNorm(H), nn.Dropout(float(dropout_p)))

        self.head_vowel = nn.Linear(H, 2)
        self.head_other = nn.Linear(H, 2)
        self.loss_fn = nn.CrossEntropyLoss()

    def masked_mean_pool(self, last_hidden, attn_mask_samples):
        # Purpose: pool only where attention_mask is 1 (after converting to feature time)
        feat_mask = self.backbone._get_feature_vector_attention_mask(last_hidden.shape[1], attn_mask_samples)
        feat_mask = feat_mask.to(last_hidden.device).unsqueeze(-1).type_as(last_hidden)
        masked = last_hidden * feat_mask
        denom = feat_mask.sum(dim=1).clamp(min=1.0)
        return masked.sum(dim=1) / denom

    def _heads_fp32(self, x_fp_any: torch.Tensor, head: nn.Module) -> torch.Tensor:
        # Purpose: keep head math in float32 even when autocast is enabled
        x = x_fp_any.float()
        if x.is_cuda:
            with torch.amp.autocast(device_type="cuda", enabled=False):
                return head(x)
        return head(x)

    def forward(self, input_values, attention_mask, labels, task_group):
        # Purpose: avoid gradients through the frozen backbone
        with torch.no_grad():
            out = self.backbone(input_values=input_values, attention_mask=attention_mask)
            last_hidden = out.last_hidden_state

        pooled = self.masked_mean_pool(last_hidden, attention_mask)  # [B,H]

        # Purpose: compute both routes in FP32 so routing is dtype-safe
        z_v = self.pre_vowel(pooled.float())
        z_o = self.pre_other(pooled.float())

        logits_v = self._heads_fp32(z_v, self.head_vowel)  # FP32
        logits_o = self._heads_fp32(z_o, self.head_other)  # FP32

        # Purpose: select the matching head per sample
        is_vowel = torch.tensor([tg == "vowel" for tg in task_group], device=pooled.device, dtype=torch.bool)
        logits = logits_o.clone()
        if is_vowel.any():
            logits[is_vowel] = logits_v[is_vowel]

        loss = self.loss_fn(logits, labels)
        return loss, logits

# -------------------------
# Validation metrics and figures
# Inputs: y_true and PD probability
# Outputs: AUROC, threshold metrics, ROC and confusion PNGs
# -------------------------
def compute_auc(y_true, y_prob):
    y_true = np.asarray(y_true, dtype=np.int64)
    y_prob = np.asarray(y_prob, dtype=np.float64)
    if len(np.unique(y_true)) < 2:
        return float("nan")
    return float(roc_auc_score(y_true, y_prob))

def compute_threshold_metrics(y_true, y_prob, thr=0.5):
    y_true = np.asarray(y_true, dtype=np.int64)
    y_prob = np.asarray(y_prob, dtype=np.float64)
    y_pred = (y_prob >= float(thr)).astype(np.int64)

    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
    TN, FP, FN, TP = int(cm[0, 0]), int(cm[0, 1]), int(cm[1, 0]), int(cm[1, 1])

    eps = 1e-12
    acc = (TP + TN) / max(1, (TP + TN + FP + FN))
    prec = TP / (TP + FP + eps)
    rec = TP / (TP + FN + eps)     # sensitivity
    f1 = 2 * prec * rec / (prec + rec + eps)
    spec = TN / (TN + FP + eps)

    try:
        mcc = float(matthews_corrcoef(y_true, y_pred)) if len(np.unique(y_pred)) > 1 else float("nan")
    except Exception:
        mcc = float("nan")

    # Purpose: quick association test on [[TN,FP],[FN,TP]]
    try:
        _, pval = fisher_exact([[TN, FP], [FN, TP]], alternative="two-sided")
        pval = float(pval)
    except Exception:
        pval = float("nan")

    return {
        "threshold": float(thr),
        "confusion_matrix": {"TN": TN, "FP": FP, "FN": FN, "TP": TP},
        "accuracy": float(acc),
        "precision": float(prec),
        "recall": float(rec),
        "f1_score": float(f1),
        "sensitivity": float(rec),
        "specificity": float(spec),
        "mcc": float(mcc),
        "p_value_fisher_two_sided": float(pval),
    }

def save_roc_curve_png(y_true, y_prob, out_png):
    fpr, tpr, _ = roc_curve(y_true, y_prob)
    plt.figure()
    plt.plot(fpr, tpr)
    plt.plot([0, 1], [0, 1], linestyle="--")
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    plt.title("ROC Curve (Val)")
    plt.tight_layout()
    plt.savefig(out_png, dpi=150)
    plt.close()

def save_confusion_png(y_true, y_prob, out_png, thr=0.5):
    y_pred = (np.asarray(y_prob) >= float(thr)).astype(int)
    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
    plt.figure()
    plt.imshow(cm)
    plt.xticks([0, 1], ["Healthy(0)", "PD(1)"])
    plt.yticks([0, 1], ["Healthy(0)", "PD(1)"])
    plt.xlabel("Predicted")
    plt.ylabel("True")
    plt.title(f"Confusion Matrix (Val, thr={thr:.2f})")
    for (i, j), v in np.ndenumerate(cm):
        plt.text(j, i, str(v), ha="center", va="center")
    plt.tight_layout()
    plt.savefig(out_png, dpi=150)
    plt.close()

# -------------------------
# Deterministic seeding
# Purpose: stabilize shuffling and initialization per seed
# -------------------------
def set_all_seeds(seed: int):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# -------------------------
# One seed run: train, validate, save best epoch artifacts
# Inputs: seed, train_df, val_df
# Outputs: run folder files and values used in the experiment summary
# -------------------------
def run_trainval_once(seed: int):
    set_all_seeds(seed)

    run_dir = EXP_ROOT / f"run_{dataset_id}_seed{seed}"
    run_dir.mkdir(parents=True, exist_ok=True)

    train_ds = AudioManifestDataset(train_df)
    val_ds   = AudioManifestDataset(val_df)

    train_loader = DataLoader(train_ds, batch_size=PER_DEVICE_BS, shuffle=True,
                              num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY, collate_fn=collate_fn)
    val_loader   = DataLoader(val_ds, batch_size=PER_DEVICE_BS, shuffle=False,
                              num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY, collate_fn=collate_fn)

    # Purpose: quick I/O and batching smoke test
    print(f"\n[seed={seed}] Warm-up: loading 3 train batches...")
    t0 = time.time()
    it = iter(train_loader)
    for i in range(3):
        _ = next(it)
        print(f"  loaded warmup batch {i+1}/3")
    print(f"[seed={seed}] Warm-up done in {time.time()-t0:.2f}s")

    model = Wav2Vec2TwoHeadClassifier(BACKBONE_CKPT, dropout_p=DROPOUT_P).to(DEVICE)

    # Purpose: only heads and pre-head blocks train
    trainable_params = (
        list(model.pre_vowel.parameters()) + list(model.pre_other.parameters()) +
        list(model.head_vowel.parameters()) + list(model.head_other.parameters())
    )
    opt = torch.optim.Adam(trainable_params, lr=LR)

    use_amp = bool(USE_AMP and DEVICE.type == "cuda")
    scaler = torch.amp.GradScaler("cuda", enabled=use_amp)

    best_auc = -1.0
    best_epoch = -1
    no_improve = 0
    best_state = None
    best_val_probs = None
    best_val_true = None
    best_thr_metrics = None

    for epoch in range(1, MAX_EPOCHS + 1):
        # Train
        model.train()
        train_losses = []
        opt.zero_grad(set_to_none=True)

        pbar = tqdm(train_loader, desc=f"[seed={seed}] Train epoch {epoch}", dynamic_ncols=True)
        step = 0
        for batch in pbar:
            step += 1
            input_values = batch["input_values"].to(DEVICE, non_blocking=False)
            attention_mask = batch["attention_mask"].to(DEVICE, non_blocking=False)
            labels = batch["labels"].to(DEVICE, non_blocking=False)
            task_group = batch["task_group"]

            with torch.amp.autocast(device_type="cuda", enabled=use_amp):
                loss, _ = model(input_values, attention_mask, labels, task_group)
                loss = loss / GRAD_ACCUM

            scaler.scale(loss).backward()
            train_losses.append(float(loss.detach().cpu().item()) * GRAD_ACCUM)

            if (step % GRAD_ACCUM) == 0:
                scaler.step(opt)
                scaler.update()
                opt.zero_grad(set_to_none=True)

        # Purpose: flush last partial accumulation if needed
        if (step % GRAD_ACCUM) != 0:
            scaler.step(opt)
            scaler.update()
            opt.zero_grad(set_to_none=True)

        avg_train_loss = float(np.mean(train_losses)) if train_losses else float("nan")

        # Validate
        model.eval()
        all_probs, all_true = [], []
        vpbar = tqdm(val_loader, desc=f"[seed={seed}] Val epoch {epoch}", dynamic_ncols=True)
        with torch.inference_mode():
            for batch in vpbar:
                input_values = batch["input_values"].to(DEVICE, non_blocking=False)
                attention_mask = batch["attention_mask"].to(DEVICE, non_blocking=False)
                labels = batch["labels"].to(DEVICE, non_blocking=False)
                task_group = batch["task_group"]

                with torch.amp.autocast(device_type="cuda", enabled=use_amp):
                    _, logits = model(input_values, attention_mask, labels, task_group)

                probs = torch.softmax(logits.float(), dim=-1)[:, 1].detach().cpu().numpy().astype(np.float64)
                all_probs.extend(probs.tolist())
                all_true.extend(labels.detach().cpu().numpy().astype(np.int64).tolist())

        val_auc = compute_auc(all_true, all_probs)
        print(f"seed={seed} | epoch {epoch:02d}/{MAX_EPOCHS} | train_loss={avg_train_loss:.5f} | val_AUROC={val_auc:.5f}")

        # Purpose: keep only the best epoch by VAL AUROC
        improved = (not math.isnan(val_auc)) and (val_auc > best_auc + 1e-12)
        if improved:
            best_auc = float(val_auc)
            best_epoch = int(epoch)
            no_improve = 0
            best_state = {
                "pre_vowel": {k: v.detach().cpu().clone() for k, v in model.pre_vowel.state_dict().items()},
                "pre_other": {k: v.detach().cpu().clone() for k, v in model.pre_other.state_dict().items()},
                "head_vowel": {k: v.detach().cpu().clone() for k, v in model.head_vowel.state_dict().items()},
                "head_other": {k: v.detach().cpu().clone() for k, v in model.head_other.state_dict().items()},
            }
            best_val_probs = list(all_probs)
            best_val_true  = list(all_true)
            best_thr_metrics = compute_threshold_metrics(best_val_true, best_val_probs, thr=VAL_THRESHOLD)
        else:
            no_improve += 1

        if no_improve >= PATIENCE:
            break

    if best_state is None or best_val_probs is None or best_val_true is None or best_thr_metrics is None:
        raise RuntimeError("No best epoch captured. Validation AUROC may be NaN due to single-class validation split.")

    # Save best epoch artifacts only
    best_heads_path = run_dir / "best_heads.pt"
    torch.save(best_state, str(best_heads_path))

    roc_png = run_dir / "roc_curve.png"
    cm_png  = run_dir / "confusion_matrix.png"
    ytrue_np = np.asarray(best_val_true, dtype=np.int64)
    yprob_np = np.asarray(best_val_probs, dtype=np.float64)

    save_roc_curve_png(ytrue_np, yprob_np, str(roc_png))
    save_confusion_png(ytrue_np, yprob_np, str(cm_png), thr=VAL_THRESHOLD)

    metrics = {
        "dataset": dataset_id,
        "seed": int(seed),
        "best_val_auroc": float(best_auc),
        "best_epoch": int(best_epoch),
        "n_train": int(len(train_df)),
        "n_val": int(len(val_df)),
        "label_counts_train": train_df["label_num"].value_counts(dropna=False).to_dict(),
        "label_counts_val": val_df["label_num"].value_counts(dropna=False).to_dict(),
        "experiment_tag": EXPERIMENT_TAG,
        "run_stamp": RUN_STAMP,
        "dropout_p": float(DROPOUT_P),
        "lr": float(LR),
        "effective_batch_size": int(PER_DEVICE_BS * GRAD_ACCUM),
        "per_device_batch_size": int(PER_DEVICE_BS),
        "grad_accum_steps": int(GRAD_ACCUM),
        "val_threshold": float(VAL_THRESHOLD),
        "backbone_ckpt": BACKBONE_CKPT,
        "threshold_metrics_best_epoch": best_thr_metrics,
    }
    metrics_path = run_dir / "metrics.json"
    with open(metrics_path, "w", encoding="utf-8") as f:
        json.dump(metrics, f, indent=2)

    print(f"[seed={seed}] WROTE:")
    print(" ", str(metrics_path))
    print(" ", str(roc_png))
    print(" ", str(cm_png))
    print(" ", str(best_heads_path))

    return float(best_auc), str(run_dir), best_thr_metrics

# -------------------------
# Run all seeds and write experiment-level summary
# Inputs: SEEDS list
# Outputs: summary_trainval.json and appended history_index.jsonl
# -------------------------
aucs = []
run_dirs = []
per_seed_metrics = []

for seed in SEEDS:
    a, rd, thrm = run_trainval_once(seed)
    aucs.append(a)
    run_dirs.append(rd)
    per_seed_metrics.append({"seed": int(seed), "best_val_auroc": float(a), "threshold_metrics_best_epoch": thrm})

t_crit = 4.302652729911275  # two-sided 95% t critical value for df=2; valid only when len(SEEDS) == 3
n = len(aucs)
mean_auc = float(np.mean(aucs))
std_auc = float(np.std(aucs, ddof=1)) if n > 1 else 0.0
half_width = float(t_crit * (std_auc / math.sqrt(n))) if n > 1 else 0.0

print("\nAUROC by seed:")
for s, a in zip(SEEDS, aucs):
    print(f"  seed {s}: {a:.6f}")
print(f"\nMean AUROC: {mean_auc:.6f}")
print(f"95% CI (t, n=3): [{mean_auc - half_width:.6f}, {mean_auc + half_width:.6f}]")

exp_summary = {
    "dataset": dataset_id,
    "dx_out_root": DX_OUT_ROOT,
    "experiment_tag": EXPERIMENT_TAG,
    "run_stamp": RUN_STAMP,
    "exp_root": str(EXP_ROOT),
    "run_dirs": run_dirs,
    "seeds": SEEDS,
    "aurocs": [float(x) for x in aucs],
    "mean_auroc": float(mean_auc),
    "t_crit_df2_95": float(t_crit),
    "half_width_95_ci": float(half_width),
    "ci95": [float(mean_auc - half_width), float(mean_auc + half_width)],
    "n_train": int(len(train_df)),
    "n_val": int(len(val_df)),
    "label_counts_train": train_df["label_num"].value_counts(dropna=False).to_dict(),
    "label_counts_val": val_df["label_num"].value_counts(dropna=False).to_dict(),
    "effective_batch_size": int(PER_DEVICE_BS * GRAD_ACCUM),
    "per_device_batch_size": int(PER_DEVICE_BS),
    "grad_accum_steps": int(GRAD_ACCUM),
    "val_threshold": float(VAL_THRESHOLD),
    "backbone_ckpt": BACKBONE_CKPT,
    "dropout_p": float(DROPOUT_P),
    "lr": float(LR),
    "per_seed_best_epoch_metrics": per_seed_metrics,
}

summary_path = EXP_ROOT / "summary_trainval.json"
with open(summary_path, "w", encoding="utf-8") as f:
    json.dump(exp_summary, f, indent=2)

# Output: append one line per experiment so older records remain intact
history_path = TRAINVAL_ROOT / "history_index.jsonl"
with open(history_path, "a", encoding="utf-8") as f:
    f.write(json.dumps(exp_summary) + "\n")

print("\nWROTE per-experiment summary:", str(summary_path))
print("APPENDED global history index:", str(history_path))
print("\nOpen this folder to access artifacts:", str(EXP_ROOT))

# -------------------------
# Stop runtime (release GPU)
# -------------------------
print("\nAll done. Unassigning the runtime to stop the L4 instance...")
try:
    from google.colab import runtime  # type: ignore
    print("Calling runtime.unassign() now.")
    runtime.unassign()
except Exception as e:
    print("Could not unassign runtime automatically. You can stop the runtime manually in Colab.")
    print("Reason:", repr(e))

The following cell trains and validates a Parkinson’s versus Healthy classifier on Dataset D2 (Slovak) using only the preprocessed dataset folder and its `manifests/manifest_all.csv`. It performs training and validation only, with no test step. The cell stops early if required files or inputs are missing, and it saves all outputs in a new timestamped experiment folder so previous results are not overwritten.

The cell starts by importing the required libraries and setting `DX_OUT_ROOT` to the D2 preprocessed dataset path, along with the path to `manifest_all.csv`. It then defines fixed run settings so that results are repeatable: up to 10 epochs, an effective batch size of 64 reached through gradient accumulation, a learning rate of 1e-3, early stopping after two epochs without validation AUROC improvement, three fixed random seeds (1337, 2024, 7777), a fixed decision threshold of 0.5 for threshold-based metrics, and mixed precision whenever a CUDA GPU is available (an L4 in this project). An output folder structure matching the other datasets is created at
`<DX_OUT_ROOT>/trainval_runs/exp_<tag>_<timestamp>/run_<dataset>_seed####/`, and a global history log file (`history_index.jsonl`) is prepared.
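The batch-size arithmetic and the folder layout described above can be sketched as follows. The helper names `accumulation_steps` and `run_dir_for`, the root `/path/to/out_root`, and the dataset id `"D2"` are assumptions for illustration; the per-device batch size of 16 is the value set in the cell. The sketch only builds paths and does not create any folders.

```python
import time
from pathlib import Path

def accumulation_steps(effective_bs: int, per_device_bs: int) -> int:
    """Number of micro-batches accumulated before one optimizer step."""
    return max(1, effective_bs // per_device_bs)

def run_dir_for(out_root: str, tag: str, stamp: str, dataset_id: str, seed: int) -> Path:
    """Build <out_root>/trainval_runs/exp_<tag>_<stamp>/run_<dataset>_seed<seed>."""
    return (Path(out_root) / "trainval_runs"
            / f"exp_{tag}_{stamp}" / f"run_{dataset_id}_seed{seed}")

steps = accumulation_steps(64, 16)           # effective 64 with 16 clips per step
stamp = time.strftime("%Y%m%d_%H%M%S")       # new stamp per execution keeps old runs intact
for seed in (1337, 2024, 7777):
    print(run_dir_for("/path/to/out_root", "frozen_LNDO", stamp, "D2", seed))
print("grad accumulation steps:", steps)
```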

Next, the cell loads the manifest file, checks that all required columns are present, and keeps only rows where the split is `train` or `val`. If available, the dataset name is inferred from the most common value in the `dataset` column. It prints the number of training and validation rows along with label counts. Each row is assigned to a simple task group, where `task == "vowl"` is treated as vowel speech and all other tasks are treated as non-vowel speech. Before training begins, a check with a progress bar confirms that all audio files listed in `clip_path` exist on disk.
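The filtering and task-grouping logic can be illustrated without pandas. The sketch below works on plain dictionaries; the helper `prepare_rows`, the file names, and the task name `"read"` are hypothetical, and only the `"vowl"` rule is taken from the notebook.

```python
from collections import Counter

REQUIRED_COLS = ("split", "clip_path", "label_num", "task")

def task_group(task_val) -> str:
    # Same rule as the notebook: only the exact token "vowl" maps to the vowel head.
    return "vowel" if str(task_val) == "vowl" else "other"

def prepare_rows(rows):
    """Validate columns, keep train/val rows, and attach a task_group field."""
    for row in rows:
        missing = [c for c in REQUIRED_COLS if c not in row]
        if missing:
            raise ValueError(f"Row missing required columns: {missing}")
    kept = [dict(row, task_group=task_group(row["task"]))
            for row in rows if row["split"] in ("train", "val")]
    if not kept:
        raise RuntimeError("No train/val rows found.")
    return kept

# Hypothetical toy rows; paths and the "read" task name are made up.
rows = [
    {"split": "train", "clip_path": "a.wav", "label_num": 1, "task": "vowl"},
    {"split": "train", "clip_path": "b.wav", "label_num": 0, "task": "read"},
    {"split": "val",   "clip_path": "c.wav", "label_num": 1, "task": "read"},
    {"split": "test",  "clip_path": "d.wav", "label_num": 0, "task": "vowl"},
]
kept = prepare_rows(rows)
split_counts = Counter(r["split"] for r in kept)
group_counts = Counter(r["task_group"] for r in kept)
print(dict(split_counts), dict(group_counts))
```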

A custom dataset class loads audio clips from disk, enforcing a 16 kHz sample rate and converting multi-channel audio to mono. An attention mask is created for each clip. For vowel clips, the end of real audio is detected using a small amplitude threshold, and trailing near-zero padding is masked out so the model does not learn from artificial silence. For other clips, the attention mask contains only ones. A custom collator then pads each batch to the length of the longest clip and pads the attention masks with zeros in the same way.
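A compact way to picture the masking and padding is the NumPy sketch below. It is a vectorized stand-in for the notebook's sample-by-sample backward scan, written under the same rule (mask everything after the last sample whose magnitude exceeds the threshold, and keep an all-ones mask if no sample does). The function names and toy waveforms are illustrative assumptions.

```python
import numpy as np

def trailing_silence_mask(y: np.ndarray, thresh: float = 1e-4) -> np.ndarray:
    """1 up to the last sample with |y| > thresh, 0 afterwards."""
    mask = np.ones(len(y), dtype=np.int64)
    loud = np.flatnonzero(np.abs(y) > thresh)
    if loud.size:                       # if nothing is above thresh, keep all ones
        mask[loud[-1] + 1:] = 0
    return mask

def pad_batch(waves, masks):
    """Right-pad waveforms and masks with zeros to the longest clip in the batch."""
    max_len = max(len(w) for w in waves)
    x = np.zeros((len(waves), max_len), dtype=np.float32)
    a = np.zeros((len(waves), max_len), dtype=np.int64)
    for i, (w, m) in enumerate(zip(waves, masks)):
        x[i, :len(w)] = w
        a[i, :len(m)] = m
    return x, a

# Toy clips (made-up samples): a vowel clip with near-zero tail, and a shorter non-vowel clip.
vowel = np.array([0.2, -0.3, 0.1, 0.0, 0.00001], dtype=np.float32)
other = np.array([0.5, 0.0, 0.0], dtype=np.float32)
masks = [trailing_silence_mask(vowel), np.ones(len(other), dtype=np.int64)]
x, a = pad_batch([vowel, other], masks)
print(a)
```

In this toy batch both rows end up with the same mask for different reasons: the vowel clip's tail is masked by the amplitude rule, while the shorter clip's tail is masked by batch padding.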

The model consists of a frozen Wav2Vec2 backbone with two small trainable heads, one for vowel clips and one for non-vowel clips. Each head uses LayerNorm, Dropout, and a Linear layer that outputs two classes. The backbone weights are not updated during training. When mixed precision is enabled, the head computations are forced to run in float32 for stability.
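The per-sample routing between the two heads can be sketched independently of the network itself. The NumPy example below assumes both heads have already produced logits for every clip and then selects one row per clip by task group; `route_logits`, `pd_probability`, and the toy logits are illustrative assumptions rather than the notebook's PyTorch code.

```python
import numpy as np

def route_logits(logits_vowel, logits_other, task_groups):
    """Pick the vowel-head row for vowel clips and the other-head row otherwise."""
    is_vowel = np.array([g == "vowel" for g in task_groups])
    return np.where(is_vowel[:, None], logits_vowel, logits_other)

def pd_probability(logits):
    """Softmax over the two classes, returning the probability of class 1 (PD)."""
    z = logits - logits.max(axis=1, keepdims=True)   # subtract row max for stability
    e = np.exp(z)
    return e[:, 1] / e.sum(axis=1)

# Made-up logits for a batch of three clips.
logits_vowel = np.array([[0.0, 2.0], [1.0, 1.0], [3.0, 0.0]])
logits_other = np.array([[1.0, 0.0], [0.0, 0.0], [0.0, 4.0]])
task_groups = ["vowel", "other", "vowel"]

routed = route_logits(logits_vowel, logits_other, task_groups)
probs = pd_probability(routed)
print(routed)
print(probs)
```

Computing both heads for the whole batch and selecting afterwards keeps the batch shape fixed, at the cost of some redundant computation in the small head layers.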

For each of the three random seeds, the cell creates a seed-specific run folder and data loaders, and runs a short warm-up by loading three training batches to catch disk or batching issues early. Training then proceeds epoch by epoch with validation at the end of each epoch. Validation AUROC is tracked and early stopping is applied after two epochs without improvement. Only the best validation epoch is kept for each seed, and the following files are saved in the seed’s run folder: `best_heads.pt` with the trained head weights and metadata, `roc_curve.png` and `confusion_matrix.png` showing validation results at threshold 0.5, and a `metrics.json` file with the best AUROC, best epoch, dataset sizes, run settings, and threshold-based metrics at 0.5, including accuracy, precision, recall (sensitivity), specificity, F1 score, MCC, and the Fisher exact test p-value.
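The early-stopping rule can be isolated from the training loop and checked on a list of per-epoch scores. The sketch below is a simplified simulation under the assumption that the same counter logic is used (reset on improvement, stop once the counter reaches the patience); the function name and the AUROC sequence are made up for illustration.

```python
import math

def simulate_early_stopping(val_aurocs, patience=2):
    """Return (best_epoch, best_auroc, last_epoch_run) for a list of per-epoch AUROCs."""
    best_auc, best_epoch, no_improve, last_epoch = -1.0, -1, 0, 0
    for epoch, auc in enumerate(val_aurocs, start=1):
        last_epoch = epoch
        # NaN scores (e.g. a single-class validation split) never count as improvement.
        if not math.isnan(auc) and auc > best_auc + 1e-12:
            best_auc, best_epoch, no_improve = auc, epoch, 0
        else:
            no_improve += 1
        if no_improve >= patience:
            break
    return best_epoch, best_auc, last_epoch

# Placeholder validation curve (not real results).
history = [0.70, 0.75, 0.74, 0.73, 0.90]
result = simulate_early_stopping(history, patience=2)
print(result)  # (2, 0.75, 4)
```

In this toy curve the later, higher score at epoch 5 is never reached, which shows the trade-off a short patience makes between compute and a possibly better late epoch.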

After all seeds finish, the cell prints the AUROC for each seed and computes the mean AUROC with a 95% confidence interval based on a t-distribution with n = 3 (2 degrees of freedom). A single `summary_trainval.json` file is written for the full experiment, and the same summary is appended as one line to the global `trainval_runs/history_index.jsonl` file.
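The append-only history log is a JSON Lines file: one JSON object per line, one line per experiment. The sketch below shows the append pattern in a temporary directory, together with a hypothetical `read_history` helper for loading the records back (the notebook itself only appends). The record contents are placeholders.

```python
import json
import tempfile
from pathlib import Path

def append_history(history_path: Path, summary: dict) -> None:
    """Append one experiment summary as a single JSON line."""
    history_path.parent.mkdir(parents=True, exist_ok=True)
    with open(history_path, "a", encoding="utf-8") as f:
        f.write(json.dumps(summary) + "\n")

def read_history(history_path: Path) -> list:
    """Load every non-empty line of the JSONL file as a dict (hypothetical helper)."""
    with open(history_path, "r", encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]

with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "trainval_runs" / "history_index.jsonl"
    # Placeholder records; the values are not results from the study.
    append_history(path, {"run_stamp": "run_a", "mean_auroc": 0.5})
    append_history(path, {"run_stamp": "run_b", "mean_auroc": 0.5})
    records = read_history(path)

print(len(records), [r["run_stamp"] for r in records])
```

Opening the file in append mode means earlier experiment records are never rewritten, which is what allows repeated executions to accumulate in one index.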

Finally, the cell prints the main output folder path and unassigns the Colab runtime to cleanly stop the L4 GPU session.

In [None]:
# ============================================================
# D2 Train + Val — Frozen Backbone, Two Heads, LN + Dropout, Full Run History
# Inputs: manifest_all.csv (train/val only) and the audio clips referenced by clip_path
# Outputs: per-seed run folders with metrics.json, roc_curve.png, confusion_matrix.png, best_heads.pt
#          plus one experiment summary_trainval.json and one appended history_index.jsonl record
# ============================================================

# -------------------------
# Imports and core libraries
# Purpose: audio I/O, model training, metrics, and plotting
# -------------------------
import os
import json
import math
import time
import random
from pathlib import Path
from typing import List, Dict, Any

import numpy as np
import pandas as pd
import soundfile as sf

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

from tqdm.auto import tqdm

import matplotlib.pyplot as plt
from sklearn.metrics import roc_auc_score, roc_curve, confusion_matrix, matthews_corrcoef
from scipy.stats import fisher_exact

from transformers import Wav2Vec2Model

# -------------------------
# D2 dataset root and manifest location
# Input: dataset output root created by preprocessing
# Output: MANIFEST_ALL path used for train/val tables
# -------------------------
D2_OUT_ROOT = "/content/drive/MyDrive/AI_PD_Project/Datasets/D2-Slovak (EWA-DB)/EWA-DB/preprocessed_v1"
DX_OUT_ROOT = D2_OUT_ROOT  # standardized variable name the rest of the code uses

MANIFEST_ALL = f"{DX_OUT_ROOT}/manifests/manifest_all.csv"

# -------------------------
# Run settings and training defaults
# Purpose: fixed settings for consistent comparisons across runs
# -------------------------
MAX_EPOCHS = 10
EFFECTIVE_BS = 64

PER_DEVICE_BS = 16
GRAD_ACCUM = max(1, EFFECTIVE_BS // PER_DEVICE_BS)

LR = 1e-3
EARLY_STOP_PATIENCE = 2

SEEDS = [1337, 2024, 7777]

# Purpose: detect where trailing padding starts in vowel clips
TINY_THRESHOLD = 1e-4  # mandatory

# Purpose: stable DataLoader behavior on Drive-backed storage
NUM_WORKERS = 0
PIN_MEMORY = False

BACKBONE_CKPT = "facebook/wav2vec2-base"

# Purpose: small regularization in the heads
HEAD_DROPOUT_P = 0.10

# Purpose: faster training on GPU while keeping heads stable (heads forced FP32 later)
USE_AMP = True

# Purpose: fixed operating point for confusion matrix and threshold metrics
VAL_THRESHOLD = 0.5

# -------------------------
# Output folders and history tracking
# Output: one EXP_ROOT folder per execution, plus a global JSONL index
# -------------------------
EXPERIMENT_TAG = "frozen_LNDO"   # keep same meaning as D1 tag
RUN_STAMP = time.strftime("%Y%m%d_%H%M%S")

TRAINVAL_ROOT = Path(DX_OUT_ROOT) / "trainval_runs"
EXP_ROOT = TRAINVAL_ROOT / f"exp_{EXPERIMENT_TAG}_{RUN_STAMP}"
EXP_ROOT.mkdir(parents=True, exist_ok=True)

HISTORY_INDEX_PATH = TRAINVAL_ROOT / "history_index.jsonl"

# -------------------------
# Device and environment printout
# Purpose: capture key runtime settings in the notebook output
# -------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
gpu_name = torch.cuda.get_device_name(0) if torch.cuda.is_available() else "CPU"

print("EXPERIMENT ROOT:", str(EXP_ROOT))
print("DEVICE:", device)
if device.type == "cuda":
    print("GPU:", gpu_name)
print("PER_DEVICE_BS:", PER_DEVICE_BS, "| GRAD_ACCUM:", GRAD_ACCUM, "| EFFECTIVE_BS:", PER_DEVICE_BS * GRAD_ACCUM)
print("NUM_WORKERS:", NUM_WORKERS, "| PIN_MEMORY:", PIN_MEMORY)
print("VAL_THRESHOLD:", VAL_THRESHOLD)
print("USE_AMP:", bool(USE_AMP and device.type == "cuda"))
print("")

# -------------------------
# Manifest existence check
# Input: MANIFEST_ALL
# Output: fail early with a clear error if missing
# -------------------------
if not os.path.exists(MANIFEST_ALL):
    raise FileNotFoundError(f"Missing manifest_all.csv: {MANIFEST_ALL}")

# -------------------------
# Build train/val tables from the manifest
# Inputs: manifest_all.csv
# Outputs: train_df and val_df, plus basic counts printed to output
# -------------------------
m = pd.read_csv(MANIFEST_ALL)

required_cols = ["split", "clip_path", "label_num", "task", "speaker_id", "duration_sec"]
for c in required_cols:
    if c not in m.columns:
        raise ValueError(f"manifest_all.csv missing required column: {c}")

# Purpose: train/val only in this cell
m = m[m["split"].isin(["train", "val"])].copy()
if len(m) == 0:
    raise RuntimeError("No rows with split in {'train','val'} found in manifest_all.csv")

# Purpose: pick the dominant dataset label if multiple datasets are present
if "dataset" in m.columns:
    dataset_id = str(m["dataset"].astype(str).value_counts().idxmax())
    m = m[m["dataset"].astype(str) == dataset_id].copy()
else:
    dataset_id = "DX"

train_df = m[m["split"] == "train"].copy().reset_index(drop=True)
val_df   = m[m["split"] == "val"].copy().reset_index(drop=True)

keep_cols = ["clip_path", "label_num", "task", "speaker_id", "duration_sec", "split"]
train_df = train_df[keep_cols].copy()
val_df   = val_df[keep_cols].copy()

def _label_counts(df: pd.DataFrame) -> Dict[int, int]:
    vc = df["label_num"].astype(int).value_counts().to_dict()
    return {0: int(vc.get(0, 0)), 1: int(vc.get(1, 0))}

print(f"Dataset inferred: {dataset_id}")
print(f"Train rows: {len(train_df)} | Val rows: {len(val_df)}")
print(f"Train label counts: {_label_counts(train_df)}")
print(f"Val label counts: {_label_counts(val_df)}")

# -------------------------
# Task grouping (vowel vs other)
# Purpose: route each clip through the matching head
# Output: task_group column used by dataset and model
# -------------------------
def make_task_group(task: Any) -> str:
    return "vowel" if str(task) == "vowl" else "other"

train_df["task_group"] = train_df["task"].apply(make_task_group)
val_df["task_group"]   = val_df["task"].apply(make_task_group)

# -------------------------
# Clip path existence check (progress bar)
# Inputs: train_df/val_df clip_path values
# Output: fail early with a few missing examples
# -------------------------
def check_paths_exist(df: pd.DataFrame, desc: str) -> None:
    paths = df["clip_path"].astype(str).tolist()
    missing = []
    for p in tqdm(paths, desc=desc, dynamic_ncols=True):
        if not os.path.exists(p):
            missing.append(p)
            if len(missing) >= 10:
                break
    if missing:
        raise FileNotFoundError(f"{desc}: missing clip_path(s), first examples:\n" + "\n".join(missing))

check_paths_exist(train_df, "Check TRAIN clip_path exists")
check_paths_exist(val_df, "Check VAL clip_path exists")

# -------------------------
# Dataset: read audio and build per-sample attention masks
# Inputs: one manifest row (clip_path, label_num, task_group)
# Output: tensors for input_values, attention_mask, labels, and task_group
# -------------------------
class AudioClipsDataset(Dataset):
    def __init__(self, df: pd.DataFrame):
        self.df = df.reset_index(drop=True)

    def __len__(self) -> int:
        return len(self.df)

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        row = self.df.iloc[idx]
        path = str(row["clip_path"])
        label = int(row["label_num"])
        tg = str(row["task_group"])

        # Purpose: load audio, convert to mono, force float32
        x, sr = sf.read(path, always_2d=True)
        if x.shape[1] != 1:
            x = x.mean(axis=1, keepdims=True)
        x = x[:, 0].astype(np.float32, copy=False)

        # Purpose: enforce training sample rate
        if int(sr) != 16000:
            raise RuntimeError(f"Sample rate is not 16kHz for {path}: got {sr}")

        # Purpose: hide trailing padding in vowel clips; keep other clips fully visible
        attn = np.ones((len(x),), dtype=np.int64)
        if tg == "vowel":
            # Indices of samples whose magnitude exceeds the threshold
            loud = np.flatnonzero(np.abs(x) > TINY_THRESHOLD)
            if loud.size > 0:
                # Mask everything after the last audible sample;
                # an all-silent clip keeps its all-ones mask
                attn[int(loud[-1]) + 1:] = 0

        return {
            "input_values": torch.tensor(x, dtype=torch.float32),
            "attention_mask": torch.tensor(attn, dtype=torch.long),
            "labels": torch.tensor(label, dtype=torch.long),
            "task_group": tg,
        }

# -------------------------
# Collator: dynamic padding and mask preservation
# Inputs: list of dataset items with variable length
# Output: padded batch tensors + task_group list
# -------------------------
def collate_batch(batch: List[Dict[str, Any]]) -> Dict[str, Any]:
    lengths = [b["input_values"].shape[0] for b in batch]
    max_len = int(max(lengths)) if lengths else 0

    B = len(batch)
    x = torch.zeros((B, max_len), dtype=torch.float32)
    msk = torch.zeros((B, max_len), dtype=torch.long)
    y = torch.zeros((B,), dtype=torch.long)
    tg = []

    for i, b in enumerate(batch):
        t = b["input_values"]
        a = b["attention_mask"]
        n = t.shape[0]
        x[i, :n] = t
        msk[i, :n] = a
        y[i] = b["labels"]
        tg.append(b["task_group"])

    return {"input_values": x, "attention_mask": msk, "labels": y, "task_group": tg}

# -------------------------
# Model: frozen Wav2Vec2 backbone + two small heads
# Inputs: audio batch + attention_mask + task_group
# Outputs: loss and logits, using the matching head per sample
# -------------------------
class FrozenW2V2TwoHead(nn.Module):
    def __init__(self, backbone_ckpt: str, dropout_p: float):
        super().__init__()

        # Purpose: load backbone once and freeze it
        self.backbone = Wav2Vec2Model.from_pretrained(backbone_ckpt)
        for p in self.backbone.parameters():
            p.requires_grad = False

        H = int(self.backbone.config.hidden_size)

        # Purpose: small trainable heads, one per task group
        self.head_vowel = nn.Sequential(
            nn.LayerNorm(H),
            nn.Dropout(dropout_p),
            nn.Linear(H, 2),
        )
        self.head_other = nn.Sequential(
            nn.LayerNorm(H),
            nn.Dropout(dropout_p),
            nn.Linear(H, 2),
        )

        self.loss_fn = nn.CrossEntropyLoss()

    def masked_mean_pool(self, last_hidden: torch.Tensor, attn_mask_samples: torch.Tensor) -> torch.Tensor:
        # Purpose: map the sample-space mask to feature-space mask for pooling
        B, T_feat, H = last_hidden.shape
        feat_mask = self.backbone._get_feature_vector_attention_mask(T_feat, attn_mask_samples)  # [B,T_feat] bool
        feat_mask_f = feat_mask.to(last_hidden.dtype)
        denom = feat_mask_f.sum(dim=1).clamp(min=1.0).unsqueeze(-1)
        pooled = (last_hidden * feat_mask_f.unsqueeze(-1)).sum(dim=1) / denom
        return pooled

    def _heads_fp32(self, pooled_any: torch.Tensor, head: nn.Module) -> torch.Tensor:
        # Purpose: run heads in float32 even when autocast is enabled
        pooled_fp32 = pooled_any.float()
        if pooled_fp32.is_cuda:
            with torch.amp.autocast(device_type="cuda", enabled=False):
                return head(pooled_fp32)
        return head(pooled_fp32)

    def forward(self, input_values: torch.Tensor, attention_mask: torch.Tensor, labels: torch.Tensor, task_group: List[str]):
        out = self.backbone(input_values=input_values, attention_mask=attention_mask)
        last_hidden = out.last_hidden_state  # [B, T_feat, H]

        pooled = self.masked_mean_pool(last_hidden, attention_mask)  # [B, H]

        # Purpose: fill logits by routing each sample to its head
        B = pooled.shape[0]
        logits = torch.zeros((B, 2), dtype=torch.float32, device=pooled.device)

        idx_v = [i for i, tg in enumerate(task_group) if tg == "vowel"]
        idx_o = [i for i, tg in enumerate(task_group) if tg != "vowel"]

        if idx_v:
            pv = pooled[idx_v]
            lv = self._heads_fp32(pv, self.head_vowel)   # FP32
            logits[idx_v] = lv

        if idx_o:
            po = pooled[idx_o]
            lo = self._heads_fp32(po, self.head_other)   # FP32
            logits[idx_o] = lo

        loss = self.loss_fn(logits, labels)
        return loss, logits

# -------------------------
# Metrics and plots (validation only)
# Inputs: y_true and PD probability
# Outputs: AUROC, threshold metrics, and PNG figures
# -------------------------
def compute_val_auroc(probs1: np.ndarray, y_true: np.ndarray) -> float:
    if len(np.unique(y_true)) < 2:
        return float("nan")
    return float(roc_auc_score(y_true, probs1))

def compute_threshold_metrics(y_true: np.ndarray, probs1: np.ndarray, thr: float = 0.5) -> Dict[str, Any]:
    y_true = y_true.astype(int)
    y_pred = (probs1 >= thr).astype(int)

    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
    if cm.shape != (2, 2):
        TN = FP = FN = TP = 0
    else:
        TN, FP, FN, TP = cm[0, 0], cm[0, 1], cm[1, 0], cm[1, 1]

    eps = 1e-12
    accuracy = (TP + TN) / max(1, (TP + TN + FP + FN))
    precision = TP / (TP + FP + eps)
    recall = TP / (TP + FN + eps)            # sensitivity
    f1 = 2 * precision * recall / (precision + recall + eps)
    specificity = TN / (TN + FP + eps)

    mcc = float(matthews_corrcoef(y_true, y_pred)) if len(np.unique(y_pred)) > 1 else float("nan")

    # Purpose: quick association test between predictions and truth
    try:
        _, p_value = fisher_exact([[TN, FP], [FN, TP]], alternative="two-sided")
        p_value = float(p_value)
    except Exception:
        p_value = float("nan")

    return {
        "threshold": float(thr),
        "confusion_matrix": {"TN": int(TN), "FP": int(FP), "FN": int(FN), "TP": int(TP)},
        "accuracy": float(accuracy),
        "precision": float(precision),
        "recall": float(recall),
        "f1_score": float(f1),
        "sensitivity": float(recall),
        "specificity": float(specificity),
        "mcc": float(mcc),
        "p_value_fisher_two_sided": float(p_value),
    }

def save_roc_curve(y_true: np.ndarray, probs1: np.ndarray, out_png: str) -> None:
    fpr, tpr, _ = roc_curve(y_true, probs1)
    plt.figure()
    plt.plot(fpr, tpr)
    plt.plot([0, 1], [0, 1], linestyle="--")
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    plt.title("ROC Curve (Val)")
    plt.tight_layout()
    plt.savefig(out_png, dpi=160)
    plt.close()

def save_confusion_matrix(y_true: np.ndarray, probs1: np.ndarray, out_png: str, thr: float = 0.5) -> None:
    y_pred = (probs1 >= thr).astype(int)
    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
    plt.figure()
    plt.imshow(cm, interpolation="nearest")
    plt.title(f"Confusion Matrix (Val) @ thr={thr:.2f}")
    plt.xlabel("Predicted")
    plt.ylabel("True")
    plt.xticks([0, 1], ["HC(0)", "PD(1)"])
    plt.yticks([0, 1], ["HC(0)", "PD(1)"])
    for (i, j), v in np.ndenumerate(cm):
        plt.text(j, i, str(v), ha="center", va="center")
    plt.tight_layout()
    plt.savefig(out_png, dpi=160)
    plt.close()

# -------------------------
# Deterministic seeding
# Purpose: stabilize shuffling and initialization per seed
# -------------------------
def set_all_seeds(seed: int) -> None:
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# -------------------------
# Warm-up loader
# Purpose: catch I/O and collation issues before training starts
# -------------------------
def warmup_loader(loader: DataLoader, n_batches: int = 3, seed: int = 0) -> None:
    t0 = time.time()
    print(f"[seed={seed}] Warm-up: loading {n_batches} train batches...")
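    # Purpose: zip stops at the shorter input, so a loader with fewer than
    # n_batches batches ends the warm-up quietly instead of raising StopIteration
    for i, _ in zip(range(n_batches), loader):
        print(f"  loaded warmup batch {i+1}/{n_batches}")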
    dt = time.time() - t0
    print(f"[seed={seed}] Warm-up done in {dt:.2f}s")

# -------------------------
# One seed run: train, validate, save best epoch
# Inputs: seed, train_df, val_df
# Outputs: run folder artifacts + metrics dict used for experiment summary
# -------------------------
def run_one_seed(seed: int) -> Dict[str, Any]:
    set_all_seeds(seed)

    run_dir_path = EXP_ROOT / f"run_{dataset_id}_seed{seed}"
    run_dir_path.mkdir(parents=True, exist_ok=True)
    run_dir = str(run_dir_path)  # keep downstream path handling unchanged

    train_ds = AudioClipsDataset(train_df)
    val_ds   = AudioClipsDataset(val_df)

    train_loader = DataLoader(
        train_ds,
        batch_size=PER_DEVICE_BS,
        shuffle=True,
        num_workers=NUM_WORKERS,
        pin_memory=PIN_MEMORY,
        collate_fn=collate_batch,
        drop_last=False,
    )
    val_loader = DataLoader(
        val_ds,
        batch_size=PER_DEVICE_BS,
        shuffle=False,
        num_workers=NUM_WORKERS,
        pin_memory=PIN_MEMORY,
        collate_fn=collate_batch,
        drop_last=False,
    )

    warmup_loader(train_loader, n_batches=3, seed=seed)

    model = FrozenW2V2TwoHead(BACKBONE_CKPT, HEAD_DROPOUT_P).to(device)

    # Purpose: only head weights update
    params = list(model.head_vowel.parameters()) + list(model.head_other.parameters())
    opt = torch.optim.Adam(params, lr=LR)

    use_amp = bool(USE_AMP and device.type == "cuda")
    scaler = torch.amp.GradScaler("cuda", enabled=use_amp)

    best_auroc = -1.0
    best_epoch = -1
    best_state = None
    no_improve = 0

    best_probs = None
    best_ytrue = None

    for epoch in range(1, MAX_EPOCHS + 1):
        # Train loop
        model.train()
        opt.zero_grad(set_to_none=True)

        running_loss = 0.0
        n_steps = 0

        pbar = tqdm(train_loader, desc=f"[seed={seed}] Train epoch {epoch}", dynamic_ncols=True)
        for step, batch in enumerate(pbar, start=1):
            x = batch["input_values"].to(device, non_blocking=True)
            msk = batch["attention_mask"].to(device, non_blocking=True)
            y = batch["labels"].to(device, non_blocking=True)
            tg = batch["task_group"]

            with torch.amp.autocast(device_type="cuda", enabled=use_amp):
                loss, _ = model(x, msk, y, tg)
                loss = loss / GRAD_ACCUM

            scaler.scale(loss).backward()

            if (step % GRAD_ACCUM) == 0:
                scaler.step(opt)
                scaler.update()
                opt.zero_grad(set_to_none=True)

            running_loss += float(loss.item()) * GRAD_ACCUM
            n_steps += 1

        # Purpose: flush last partial accumulation if needed
        if (n_steps % GRAD_ACCUM) != 0:
            scaler.step(opt)
            scaler.update()
            opt.zero_grad(set_to_none=True)

        avg_train_loss = running_loss / max(1, n_steps)

        # Validation loop
        model.eval()
        y_true_all = []
        p1_all = []

        pbarv = tqdm(val_loader, desc=f"[seed={seed}] Val epoch {epoch}", dynamic_ncols=True)
        with torch.no_grad():
            for batch in pbarv:
                x = batch["input_values"].to(device, non_blocking=True)
                msk = batch["attention_mask"].to(device, non_blocking=True)
                y = batch["labels"].to(device, non_blocking=True)
                tg = batch["task_group"]

                with torch.amp.autocast(device_type="cuda", enabled=use_amp):
                    _, logits = model(x, msk, y, tg)

                probs = torch.softmax(logits.float(), dim=-1)[:, 1]
                y_true_all.append(y.detach().cpu().numpy())
                p1_all.append(probs.detach().cpu().numpy())

        y_true = np.concatenate(y_true_all, axis=0).astype(int)
        p1 = np.concatenate(p1_all, axis=0).astype(np.float64)

        val_auroc = compute_val_auroc(p1, y_true)

        print(f"seed={seed} | epoch {epoch:02d}/{MAX_EPOCHS} | train_loss={avg_train_loss:.5f} | val_AUROC={val_auroc:.5f}")

        # Purpose: keep only the best epoch by VAL AUROC
        improved = (not math.isnan(val_auroc)) and (val_auroc > best_auroc + 1e-12)
        if improved:
            best_auroc = float(val_auroc)
            best_epoch = int(epoch)
            no_improve = 0

            best_state = {
                "head_vowel": {k: v.detach().cpu() for k, v in model.head_vowel.state_dict().items()},
                "head_other": {k: v.detach().cpu() for k, v in model.head_other.state_dict().items()},
                "backbone_ckpt": BACKBONE_CKPT,
                "hidden_size": int(model.backbone.config.hidden_size),
                "dropout_p": float(HEAD_DROPOUT_P),
                "dataset_id": dataset_id,
                "seed": int(seed),
                "val_threshold": float(VAL_THRESHOLD),
            }
            best_probs = p1.copy()
            best_ytrue = y_true.copy()
        else:
            no_improve += 1
            if no_improve >= EARLY_STOP_PATIENCE:
                break

    # Save best epoch artifacts only
    metrics_path = os.path.join(run_dir, "metrics.json")
    roc_png = os.path.join(run_dir, "roc_curve.png")
    cm_png = os.path.join(run_dir, "confusion_matrix.png")
    heads_path = os.path.join(run_dir, "best_heads.pt")

    if best_state is None or best_probs is None or best_ytrue is None:
        raise RuntimeError(f"[seed={seed}] No valid best_state captured; cannot write artifacts.")

    save_roc_curve(best_ytrue, best_probs, roc_png)
    save_confusion_matrix(best_ytrue, best_probs, cm_png, thr=VAL_THRESHOLD)
    torch.save(best_state, heads_path)

    thr_metrics = compute_threshold_metrics(best_ytrue, best_probs, thr=VAL_THRESHOLD)

    # Output: metrics.json (includes best AUROC and threshold metrics)
    metrics = {
        "dataset": dataset_id,
        "seed": int(seed),
        "best_val_auroc": float(best_auroc),
        "best_epoch": int(best_epoch),
        "max_epochs": int(MAX_EPOCHS),
        "early_stop_patience": int(EARLY_STOP_PATIENCE),
        "n_train": int(len(train_df)),
        "n_val": int(len(val_df)),
        "label_counts_train": _label_counts(train_df),
        "label_counts_val": _label_counts(val_df),
        "backbone_ckpt": BACKBONE_CKPT,
        "frozen_backbone": True,
        "heads": {
            "vowel": "LayerNorm + Dropout + Linear(768->2)",
            "other": "LayerNorm + Dropout + Linear(768->2)",
        },
        "threshold_metrics_best_epoch": thr_metrics,
        "padding_mask_rule": {
            "tiny_threshold": TINY_THRESHOLD,
            "vowel": "trailing abs(x) > tiny_threshold => mask=1 else 0",
            "other": "all ones",
            "batch_pad": "zero pad + mask zeros",
        },
        "batching": {
            "per_device_bs": int(PER_DEVICE_BS),
            "grad_accum": int(GRAD_ACCUM),
            "effective_bs": int(PER_DEVICE_BS * GRAD_ACCUM),
        },
        "optimizer": {"name": "Adam", "lr": float(LR)},
        "amp": bool(use_amp),
        "device": str(device),
        "gpu": str(gpu_name),
        "experiment_tag": EXPERIMENT_TAG,
        "run_stamp": RUN_STAMP,
        "val_threshold": float(VAL_THRESHOLD),
        "dropout_p": float(HEAD_DROPOUT_P),
        "lr_scalar": float(LR),
        "effective_batch_size": int(PER_DEVICE_BS * GRAD_ACCUM),
        "per_device_batch_size": int(PER_DEVICE_BS),
        "grad_accum_steps": int(GRAD_ACCUM),
        "paths": {
            "manifest_all": MANIFEST_ALL,
            "run_dir": run_dir,
            "roc_curve_png": roc_png,
            "confusion_matrix_png": cm_png,
            "best_heads_pt": heads_path,
        },
    }

    with open(metrics_path, "w", encoding="utf-8") as f:
        json.dump(metrics, f, indent=2)

    print(f"[seed={seed}] WROTE:\n  {metrics_path}\n  {roc_png}\n  {cm_png}\n  {heads_path}\n")

    return metrics

# -------------------------
# Run all seeds and write experiment summary
# Inputs: SEEDS list
# Outputs: summary_trainval.json and appended history_index.jsonl
# -------------------------
results = []
for sd in SEEDS:
    results.append(run_one_seed(sd))

aurocs = np.array([r["best_val_auroc"] for r in results], dtype=np.float64)

print("AUROC by seed:")
for r in results:
    print(f"  seed {r['seed']}: {r['best_val_auroc']:.6f}")

t_crit = 4.302652729911275  # two-sided 95% t critical value for df = n - 1 = 2; valid only when len(SEEDS) == 3
n = len(aurocs)
mean_auroc = float(np.mean(aurocs))
std_auc = float(np.std(aurocs, ddof=1)) if n > 1 else 0.0
half_width = float(t_crit * (std_auc / math.sqrt(n))) if n > 1 else 0.0
ci_low = float(mean_auroc - half_width)
ci_high = float(mean_auroc + half_width)

print("")
print(f"Mean AUROC: {mean_auroc:.6f}")
print(f"95% CI (t, n=3): [{ci_low:.6f}, {ci_high:.6f}]")
print("")

# Purpose: keep a compact per-seed record inside the experiment summary
per_seed_metrics = []
run_dirs = []
for r in results:
    per_seed_metrics.append({
        "seed": int(r["seed"]),
        "best_val_auroc": float(r["best_val_auroc"]),
        "threshold_metrics_best_epoch": r["threshold_metrics_best_epoch"],
    })
    run_dirs.append(r["paths"]["run_dir"])

# Output: one summary JSON for the whole experiment
exp_summary = {
    "dataset": dataset_id,
    "dx_out_root": DX_OUT_ROOT,
    "experiment_tag": EXPERIMENT_TAG,
    "run_stamp": RUN_STAMP,
    "exp_root": str(EXP_ROOT),
    "run_dirs": run_dirs,
    "seeds": SEEDS,
    "aurocs": [float(x) for x in aurocs.tolist()],  # note: keep spelling consistent with earlier summaries
    "mean_auroc": float(mean_auroc),
    "t_crit_df2_95": float(t_crit),
    "half_width_95_ci": float(half_width),
    "ci95": [float(ci_low), float(ci_high)],
    "n_train": int(len(train_df)),
    "n_val": int(len(val_df)),
    "label_counts_train": _label_counts(train_df),
    "label_counts_val": _label_counts(val_df),
    "effective_batch_size": int(PER_DEVICE_BS * GRAD_ACCUM),
    "per_device_batch_size": int(PER_DEVICE_BS),
    "grad_accum_steps": int(GRAD_ACCUM),
    "val_threshold": float(VAL_THRESHOLD),
    "backbone_ckpt": BACKBONE_CKPT,
    "dropout_p": float(HEAD_DROPOUT_P),
    "lr": float(LR),
    "per_seed_best_epoch_metrics": per_seed_metrics,
}

summary_path = Path(EXP_ROOT) / "summary_trainval.json"
with open(summary_path, "w", encoding="utf-8") as f:
    json.dump(exp_summary, f, indent=2)

# Output: append one line per experiment so older records remain intact
with open(HISTORY_INDEX_PATH, "a", encoding="utf-8") as f:
    f.write(json.dumps(exp_summary) + "\n")

print("WROTE per-experiment summary:", str(summary_path))
print("APPENDED global history index:", str(HISTORY_INDEX_PATH))
print("")
print("Open this folder to access artifacts:", str(EXP_ROOT))

# -------------------------
# Stop runtime (release GPU)
# -------------------------
print("\nAll done. Unassigning the runtime to stop the L4 instance...")
try:
    from google.colab import runtime  # type: ignore
    print("Calling runtime.unassign() now.")
    runtime.unassign()
except Exception as e:
    print("Could not unassign runtime automatically. You can stop the runtime manually in Colab.")
    print("Reason:", repr(e))

The following cell trains and validates the D4 model using only the training and validation splits from the prepared dataset folder (`DX_OUT_ROOT`) and its `manifests/manifest_all.csv`. To guard against failures partway through a run, it performs early checks, shows progress bars, and saves each run in its own folder so results are never overwritten.

The cell expects `DX_OUT_ROOT` to already be defined from the D4 preprocessing step. It creates a new experiment folder under `trainval_runs/` using a fixed tag (`frozen_LNDO`) and a timestamp, which keeps each run separate for later reference. Training settings are fixed for consistency across runs, including three random seeds (1337, 2024, 7777), up to 10 epochs, an effective batch size of 64 using gradient accumulation, a learning rate of 1e-3, early stopping with a patience of 2 epochs, and a required audio sample rate of 16 kHz.
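
The relationship between the per-device batch size, the accumulation count, and the points at which the optimizer updates can be sketched in plain Python. This is a minimal illustration, not the cell's training loop: `optimizer_step_points` is a hypothetical helper, and the final flush of a leftover partial accumulation is an assumption carried over from the D2 training cell.

```python
EFFECTIVE_BS = 64
PER_DEVICE_BS = 16
GRAD_ACCUM = max(1, EFFECTIVE_BS // PER_DEVICE_BS)


def optimizer_step_points(n_batches: int, grad_accum: int) -> list:
    """Return the 1-based batch indices after which an optimizer step fires.

    Hypothetical helper for illustration only.
    """
    steps = [s for s in range(1, n_batches + 1) if s % grad_accum == 0]
    # Assumption: leftover gradients at the end of an epoch are flushed
    if n_batches % grad_accum != 0:
        steps.append(n_batches)
    return steps


print("GRAD_ACCUM:", GRAD_ACCUM)
print("step after batches:", optimizer_step_points(10, GRAD_ACCUM))
```

With these settings, four batches of 16 clips contribute to each update. In a 10-batch epoch the last update would cover only two batches, so its gradient is based on 32 clips rather than 64.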

The manifest is then loaded and filtered to keep only rows marked as `train` or `val`. The dataset name is inferred from the most common value in the `dataset` column, falling back to `DX` when that column is missing or empty. The cell prints the size of each split and the label counts, and stops immediately if any audio file path listed in the manifest does not exist. Each clip is also assigned to a simple task group: the task code `vowl` is treated as vowel speech and every other value as non-vowel speech.
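
The filtering and grouping rules described above can be illustrated without pandas. The sketch below uses plain dictionaries in place of manifest rows; the rows, the `D9` label, and the helper names `infer_dataset_id` and `task_group` are hypothetical and serve only to show the logic.

```python
from collections import Counter


def infer_dataset_id(rows, default="DX"):
    """Most common non-missing 'dataset' value, or the default if none exists."""
    values = [str(r["dataset"]) for r in rows if r.get("dataset") is not None]
    if not values:
        return default
    return Counter(values).most_common(1)[0][0]


def task_group(task):
    """Only the exact manifest code 'vowl' maps to the vowel group."""
    return "vowel" if str(task) == "vowl" else "other"


# Hypothetical manifest rows, for illustration only
rows = [
    {"split": "train", "dataset": "D4", "task": "vowl"},
    {"split": "val", "dataset": "D4", "task": "read"},
    {"split": "test", "dataset": "D4", "task": "vowl"},
    {"split": "train", "dataset": "D9", "task": "vowel"},
]

kept = [r for r in rows if r["split"] in ("train", "val")]   # drop test rows
dataset_id = infer_dataset_id(kept)                           # dominant dataset
kept = [r for r in kept if str(r["dataset"]) == dataset_id]   # single dataset only
groups = [task_group(r["task"]) for r in kept]
print(dataset_id, groups)
```

Because the match is exact, a task value spelled `vowel` would be routed to the non-vowel group, so the manifest's task codes need to be consistent.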

A custom dataset and a batch collation step are used to load audio files, convert stereo audio to mono when needed, and confirm that the sample rate is exactly 16 kHz. An attention mask is created so that padded silence at the end of vowel clips, identified by near-zero values (threshold `TINY_THRESH = 1e-4`), is ignored during training. Each batch is padded to the length of the longest clip in that batch, and the attention masks are padded in the same way so the model does not learn from added padding.
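
The masking and padding rules can be shown on toy data with plain Python lists. This is a sketch under stated assumptions rather than the cell's dataset code: `trailing_padding_mask` and `pad_batch` are hypothetical names, and keeping an all-silent clip fully visible follows the behavior of the D2 cell.

```python
def trailing_padding_mask(samples, tiny=1e-4):
    """1 up to and including the last sample with |x| > tiny, 0 afterwards."""
    last_audible = None
    for i in range(len(samples) - 1, -1, -1):
        if abs(samples[i]) > tiny:
            last_audible = i
            break
    if last_audible is None:
        # Assumption: an all-silent clip stays fully visible, as in the D2 cell
        return [1] * len(samples)
    return [1] * (last_audible + 1) + [0] * (len(samples) - last_audible - 1)


def pad_batch(clips, masks):
    """Zero-pad clips and masks to the longest clip in the batch."""
    max_len = max(len(c) for c in clips)
    padded_clips = [list(c) + [0.0] * (max_len - len(c)) for c in clips]
    padded_masks = [list(m) + [0] * (max_len - len(m)) for m in masks]
    return padded_clips, padded_masks


vowel = [0.20, -0.30, 0.00001, 0.0]   # last two samples are near-zero padding
other = [0.05, 0.0]                   # non-vowel clips keep an all-ones mask
masks = [trailing_padding_mask(vowel), [1] * len(other)]
clips, masks = pad_batch([vowel, other], masks)
print(masks)  # [[1, 1, 0, 0], [1, 1, 0, 0]]
```

In the vowel clip the mask is zero over the trailing near-zero samples. In the shorter clip the mask is zero only over the positions added by batch padding.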

The model consists of a frozen Wav2Vec2 backbone (`facebook/wav2vec2-base`), downloaded using safetensors for compatibility, and two small trainable classifiers, one for vowel clips and one for other clips. Each classifier consists of a small pre-head block (LayerNorm followed by Dropout) and a linear layer. Only the pre-head blocks and linear heads are trained, while the backbone remains fixed.

For each random seed, the cell builds the training and validation data loaders and runs a short warm-up by loading three training batches to catch input/output or batching issues early. Training then proceeds epoch by epoch, with validation AUROC computed at the end of each epoch. Early stopping is applied when AUROC does not improve for two consecutive epochs. Only the best validation epoch for each seed is kept, and its outputs are saved under `exp_<tag>_<timestamp>/run_<dataset>_seed####/`.
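
The best-epoch selection and early-stopping rule can be sketched as a small function over a list of per-epoch validation AUROC values. The values in `history` are made up for illustration, `select_best_epoch` is a hypothetical name, and the improvement test (NaN never counts, margin of 1e-12) is assumed to match the D2 cell.

```python
import math


def select_best_epoch(val_aurocs, patience=2, min_delta=1e-12):
    """Return (best_epoch, best_auroc, last_epoch_run) under early stopping."""
    best_auroc, best_epoch, no_improve = -1.0, -1, 0
    last_epoch = 0
    for epoch, auroc in enumerate(val_aurocs, start=1):
        last_epoch = epoch
        if not math.isnan(auroc) and auroc > best_auroc + min_delta:
            best_auroc, best_epoch, no_improve = auroc, epoch, 0
        else:
            no_improve += 1
            if no_improve >= patience:
                break  # stop after `patience` consecutive non-improving epochs
    return best_epoch, best_auroc, last_epoch


history = [0.70, 0.75, 0.74, 0.73, 0.90]  # hypothetical AUROC per epoch
best_epoch, best_auroc, stopped_at = select_best_epoch(history)
print(best_epoch, best_auroc, stopped_at)  # 2 0.75 4
```

Training stops after epoch 4, so the higher value at epoch 5 is never reached. This is the usual trade-off of a short patience window.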

For each seed, the saved outputs include `best_heads.pt` with the trained head weights and pre-head blocks, `roc_curve.png` showing the validation ROC curve, `confusion_matrix.png` showing the validation confusion matrix at a threshold of 0.5, and a `metrics.json` file containing the best epoch, best validation AUROC, dataset counts, key hyperparameters, and additional threshold-based metrics at 0.5: accuracy, precision, recall (sensitivity), specificity, F1 score, MCC, and a Fisher exact test p-value.
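
How the threshold-based metrics follow from the four confusion-matrix counts can be shown in plain Python. This sketch makes two simplifying assumptions: undefined ratios return NaN instead of using a small epsilon in the denominator, and the Fisher exact test is left out because it needs SciPy. The labels and probabilities are made up, and `threshold_metrics` is a hypothetical name.

```python
import math


def safe_div(num, den):
    """Return num / den, or NaN when the denominator is zero."""
    return num / den if den else float("nan")


def threshold_metrics(y_true, probs, thr=0.5):
    """Confusion counts and derived metrics at a fixed probability threshold."""
    y_pred = [1 if p >= thr else 0 for p in probs]
    pairs = list(zip(y_true, y_pred))
    tp = sum(1 for t, p in pairs if t == 1 and p == 1)
    tn = sum(1 for t, p in pairs if t == 0 and p == 0)
    fp = sum(1 for t, p in pairs if t == 0 and p == 1)
    fn = sum(1 for t, p in pairs if t == 1 and p == 0)

    precision = safe_div(tp, tp + fp)
    recall = safe_div(tp, tp + fn)          # sensitivity
    specificity = safe_div(tn, tn + fp)
    f1 = safe_div(2 * precision * recall, precision + recall)
    mcc_den = math.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
    return {
        "TN": tn, "FP": fp, "FN": fn, "TP": tp,
        "accuracy": safe_div(tp + tn, len(pairs)),
        "precision": precision,
        "recall": recall,
        "specificity": specificity,
        "f1_score": f1,
        "mcc": safe_div(tp * tn - fp * fn, mcc_den),
    }


# Hypothetical labels (0 = HC, 1 = PD) and PD probabilities
result = threshold_metrics([0, 0, 0, 1, 1, 1], [0.1, 0.2, 0.7, 0.4, 0.8, 0.9])
print(result)
```

In this toy case one healthy control is flagged and one PD clip is missed. Precision, recall, specificity, and F1 therefore all equal 2/3, and MCC equals 1/3.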

After all three seeds finish, the cell prints the AUROC for each seed and the mean AUROC with a 95% confidence interval computed using a t-distribution with n=3 (two degrees of freedom). It then writes a `summary_trainval.json` file inside the experiment folder with the full summary across seeds, and appends a one-line record to `trainval_runs/history_index.jsonl` as a global log of experiments.
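
The interval is the mean plus or minus `t * s / sqrt(n)`, where `s` is the sample standard deviation (ddof = 1) and `t` is the two-sided 95% critical value for n - 1 = 2 degrees of freedom. A minimal standard-library sketch follows. The AUROC values are made up, `mean_ci95_three_seeds` is a hypothetical name, and the critical value is the constant used in the D2 summary code.

```python
import math
import statistics

T_CRIT_DF2_95 = 4.302652729911275  # two-sided 95% critical value of Student's t, df = 2


def mean_ci95_three_seeds(aurocs):
    """Mean and 95% t-interval for exactly three per-seed AUROC values."""
    n = len(aurocs)
    if n != 3:
        raise ValueError("the hard-coded critical value assumes exactly three seeds")
    mean = statistics.fmean(aurocs)
    sd = statistics.stdev(aurocs)               # sample standard deviation (ddof = 1)
    half_width = T_CRIT_DF2_95 * sd / math.sqrt(n)
    return mean, mean - half_width, mean + half_width


# Hypothetical per-seed AUROC values
mean, ci_low, ci_high = mean_ci95_three_seeds([0.80, 0.82, 0.84])
print(f"mean={mean:.4f}  95% CI=[{ci_low:.4f}, {ci_high:.4f}]")
```

With only two degrees of freedom the critical value is large, so the interval is wide even when the seeds agree closely. The bounds are not clipped to [0, 1].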

Finally, the cell prints the output folder locations and unassigns the Colab runtime to cleanly shut down the L4 GPU session.

In [None]:
# =========================
# D4 Train + Val — Frozen Backbone, Two Heads, Full Run History
# Inputs: manifest_all.csv (train/val only) and the audio clips referenced by clip_path
# Outputs: per-seed run folders with metrics.json, roc_curve.png, confusion_matrix.png, best_heads.pt
#          plus one experiment summary_trainval.json and one appended history_index.jsonl record
# =========================

import os, json, math, random, time, warnings
from pathlib import Path

import numpy as np
import pandas as pd
import soundfile as sf

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

from transformers import Wav2Vec2Model

from sklearn.metrics import (
    roc_auc_score, roc_curve,
    confusion_matrix, accuracy_score,
    precision_recall_fscore_support,
    matthews_corrcoef
)
import matplotlib.pyplot as plt
from tqdm.auto import tqdm

# -------------------------
# Dataset root from preprocessing runtime
# Purpose: reuse the dataset output root already set by preprocessing
# Input: DX_OUT_ROOT must exist in memory
# -------------------------
if "DX_OUT_ROOT" not in globals():
    raise RuntimeError("DX_OUT_ROOT is not defined in the runtime. Run the dataset preprocessing cell first.")
DX_OUT_ROOT = str(globals()["DX_OUT_ROOT"])
MANIFEST_ALL = f"{DX_OUT_ROOT}/manifests/manifest_all.csv"

# -------------------------
# Experiment folder naming (no overwrites)
# Purpose: keep results from every run, even when rerunning the notebook
# Output: EXP_ROOT holds all per-seed outputs for this execution
# -------------------------
EXPERIMENT_TAG = "frozen_LNDO"
RUN_STAMP = time.strftime("%Y%m%d_%H%M%S")

TRAINVAL_ROOT = Path(DX_OUT_ROOT) / "trainval_runs"
EXP_ROOT = TRAINVAL_ROOT / f"exp_{EXPERIMENT_TAG}_{RUN_STAMP}"
EXP_ROOT.mkdir(parents=True, exist_ok=True)

# -------------------------
# Training defaults
# Purpose: fixed settings for consistent comparison across datasets
# -------------------------
MAX_EPOCHS     = 10
EFFECTIVE_BS   = 64
PER_DEVICE_BS  = 16
GRAD_ACCUM     = max(1, EFFECTIVE_BS // PER_DEVICE_BS)

LR             = 1e-3
PATIENCE       = 2
SEEDS          = [1337, 2024, 7777]

BACKBONE_CKPT  = "facebook/wav2vec2-base"
SR_EXPECTED    = 16000
TINY_THRESH    = 1e-4

# Small pre-head block (trainable)
DROPOUT_P      = 0.2

# Drive-friendly loader defaults
NUM_WORKERS    = 0
PIN_MEMORY     = False

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
try:
    torch.set_float32_matmul_precision("high")
except Exception:
    pass

warnings.filterwarnings("ignore", message="None of the inputs have requires_grad=True. Gradients will be None")
warnings.filterwarnings("ignore", message="Passing `gradient_checkpointing` to a config initialization is deprecated")
warnings.filterwarnings("ignore", category=UserWarning, module="huggingface_hub")

print("DX_OUT_ROOT:", DX_OUT_ROOT)
print("MANIFEST_ALL:", MANIFEST_ALL)
print("DEVICE:", DEVICE)
if DEVICE.type == "cuda":
    print("GPU:", torch.cuda.get_device_name(0))
print("PER_DEVICE_BS:", PER_DEVICE_BS, "| GRAD_ACCUM:", GRAD_ACCUM, "| EFFECTIVE_BS:", PER_DEVICE_BS * GRAD_ACCUM)
print("NUM_WORKERS:", NUM_WORKERS, "| PIN_MEMORY:", PIN_MEMORY)
print("EXPERIMENT_TAG:", EXPERIMENT_TAG, "| RUN_STAMP:", RUN_STAMP)

# -------------------------
# Manifest load and split filtering
# Input: manifest_all.csv
# Output: train_df and val_df for this dataset only
# -------------------------
if not os.path.exists(MANIFEST_ALL):
    raise FileNotFoundError(f"Missing manifest_all.csv: {MANIFEST_ALL}")

m = pd.read_csv(MANIFEST_ALL)

# Purpose: fail early if required columns are missing
req_cols = {"split", "clip_path", "label_num", "task"}
missing = [c for c in sorted(req_cols) if c not in m.columns]
if missing:
    raise ValueError(f"Manifest missing required columns: {missing}. Found: {list(m.columns)}")

# Purpose: this cell is train/val only
m = m[m["split"].isin(["train", "val"])].copy()
if len(m) == 0:
    raise RuntimeError("After filtering to split in {train,val}, manifest has 0 rows.")

# Purpose: pick the dominant dataset label if the file contains multiple datasets
if "dataset" in m.columns and m["dataset"].notna().any():
    dataset_id = str(m["dataset"].value_counts(dropna=True).idxmax())
    m = m[m["dataset"].astype(str) == dataset_id].copy()
else:
    dataset_id = "DX"

# Purpose: keep a consistent set of columns even if some are missing
keep_cols = ["clip_path", "label_num", "task", "speaker_id", "duration_sec", "split"]
for c in keep_cols:
    if c not in m.columns:
        m[c] = np.nan
m = m[keep_cols].copy()

train_df = m[m["split"] == "train"].copy().reset_index(drop=True)
val_df   = m[m["split"] == "val"].copy().reset_index(drop=True)

print(f"\nDataset inferred: {dataset_id}")
print(f"Train rows: {len(train_df)} | Val rows: {len(val_df)}")
print("Train label counts:", train_df["label_num"].value_counts(dropna=False).to_dict())
print("Val label counts:",   val_df["label_num"].value_counts(dropna=False).to_dict())

# Purpose: avoid training on empty tables
if len(train_df) == 0 or len(val_df) == 0:
    raise RuntimeError("Train or Val split has 0 rows.")

# -------------------------
# Clip path existence check
# Input: clip_path values in train_df and val_df
# Output: raises with a few missing examples instead of failing mid-training
# -------------------------
def _fail_fast_missing_paths(df: pd.DataFrame, name: str):
    missing_paths = []
    for p in tqdm(df["clip_path"].astype(str).tolist(), desc=f"Check {name} clip_path exists", dynamic_ncols=True):
        if not os.path.exists(p):
            missing_paths.append(p)
            if len(missing_paths) >= 10:
                break
    if missing_paths:
        raise FileNotFoundError(f"{name}: missing clip_path(s). Examples: {missing_paths}")

_fail_fast_missing_paths(train_df, "TRAIN")
_fail_fast_missing_paths(val_df, "VAL")

# -------------------------
# Task grouping
# Purpose: route each clip to the vowel head or the other head
# Output: task_group column used by dataset and model
# -------------------------
def _task_group(task_val) -> str:
    return "vowel" if str(task_val) == "vowl" else "other"

train_df["task_group"] = train_df["task"].apply(_task_group)
val_df["task_group"]   = val_df["task"].apply(_task_group)

# -------------------------
# Dataset and batch collation
# Inputs: manifest rows (clip_path, label_num, task_group)
# Outputs: padded batches with attention_mask kept aligned with audio samples
# -------------------------
class AudioManifestDataset(Dataset):
    def __init__(self, df: pd.DataFrame):
        self.df = df.reset_index(drop=True)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx: int):
        row = self.df.iloc[idx]
        clip_path = str(row["clip_path"])
        label = int(row["label_num"])
        task_group = str(row["task_group"])

        if not os.path.exists(clip_path):
            raise FileNotFoundError(f"Missing clip_path: {clip_path}")

        # Purpose: load audio, convert stereo to mono, force float32
        y, sr = sf.read(clip_path, always_2d=False)
        if isinstance(y, np.ndarray) and y.ndim > 1:
            y = np.mean(y, axis=1)
        y = np.asarray(y, dtype=np.float32)

        # Purpose: enforce a single training sample rate
        if int(sr) != SR_EXPECTED:
            raise ValueError(f"Sample rate mismatch (expected {SR_EXPECTED}): {clip_path} has sr={sr}")

        # Purpose: build an attention mask in sample space
        attn = np.ones((len(y),), dtype=np.int64)

        # Purpose: for vowel clips, hide trailing zero padding so pooling ignores it
        if task_group == "vowel":
            k = -1
            for j in range(len(y) - 1, -1, -1):
                if abs(float(y[j])) > TINY_THRESH:
                    k = j
                    break
            if k >= 0:
                attn[:k+1] = 1
                attn[k+1:] = 0
            else:
                attn[:] = 1
        else:
            attn[:] = 1

        return {
            "input_values": torch.from_numpy(y),                 # float32 [T]
            "attention_mask": torch.from_numpy(attn),            # int64   [T]
            "labels": torch.tensor(label, dtype=torch.long),     # int64   []
            "task_group": task_group,                            # str
        }

def collate_fn(batch):
    # Purpose: pad all clips in the batch to the same length using zeros
    max_len = int(max(b["input_values"].numel() for b in batch))
    input_vals, attn_masks, labels, task_groups = [], [], [], []
    for b in batch:
        x = b["input_values"]
        a = b["attention_mask"]
        pad = max_len - x.numel()
        if pad > 0:
            x = torch.cat([x, torch.zeros(pad, dtype=x.dtype)], dim=0)
            a = torch.cat([a, torch.zeros(pad, dtype=a.dtype)], dim=0)
        input_vals.append(x)
        attn_masks.append(a)
        labels.append(b["labels"])
        task_groups.append(b["task_group"])
    return {
        "input_values": torch.stack(input_vals, dim=0),      # [B,T]
        "attention_mask": torch.stack(attn_masks, dim=0),    # [B,T]
        "labels": torch.stack(labels, dim=0),                # [B]
        "task_group": task_groups,                           # list[str]
    }

# -------------------------
# Model definition
# Inputs: backbone checkpoint and dropout probability
# Output: loss and logits (PD vs Healthy) chosen by task_group
# -------------------------
class Wav2Vec2TwoHeadClassifier(nn.Module):
    def __init__(self, ckpt: str, dropout_p: float):
        super().__init__()

        # Purpose: load weights via safetensors for safer, consistent loading
        self.backbone = Wav2Vec2Model.from_pretrained(
            ckpt,
            use_safetensors=True,      # require .safetensors weights rather than pickle-based .bin files
            local_files_only=False
        )
        for p in self.backbone.parameters():
            p.requires_grad = False

        H = int(self.backbone.config.hidden_size)

        # Purpose: small per-task feature cleanup before each head
        self.pre_vowel = nn.Sequential(nn.LayerNorm(H), nn.Dropout(float(dropout_p)))
        self.pre_other = nn.Sequential(nn.LayerNorm(H), nn.Dropout(float(dropout_p)))

        self.head_vowel = nn.Linear(H, 2)
        self.head_other = nn.Linear(H, 2)
        self.loss_fn = nn.CrossEntropyLoss()

    def masked_mean_pool(self, last_hidden, attn_mask_samples):
        # Purpose: convert sample-level mask to feature-level mask for pooling
        feat_mask = self.backbone._get_feature_vector_attention_mask(last_hidden.shape[1], attn_mask_samples)
        feat_mask = feat_mask.to(last_hidden.device).unsqueeze(-1).type_as(last_hidden)
        masked = last_hidden * feat_mask
        denom = feat_mask.sum(dim=1).clamp(min=1.0)
        return masked.sum(dim=1) / denom

    def forward(self, input_values, attention_mask, labels, task_group):
        # Purpose: backbone stays frozen and runs without gradients
        with torch.no_grad():
            out = self.backbone(input_values=input_values, attention_mask=attention_mask)
            last_hidden = out.last_hidden_state  # [B,T',H]

        pooled = self.masked_mean_pool(last_hidden, attention_mask)  # [B,H]
        pooled = pooled.float()  # keep heads stable and avoid dtype surprises

        z_v = self.pre_vowel(pooled)
        z_o = self.pre_other(pooled)

        logits_v = self.head_vowel(z_v)  # float32
        logits_o = self.head_other(z_o)  # float32

        # Purpose: choose the matching head per sample
        is_vowel = torch.tensor([tg == "vowel" for tg in task_group], device=pooled.device, dtype=torch.bool)
        logits = logits_o.clone()
        if is_vowel.any():
            logits[is_vowel] = logits_v[is_vowel]

        loss = self.loss_fn(logits, labels)
        return loss, logits

# -------------------------
# Metric and plot helpers (VAL)
# Inputs: y_true and predicted PD probability
# Outputs: AUROC, threshold metrics, and PNG plots
# -------------------------
def compute_auc(y_true, y_prob):
    y_true = np.asarray(y_true, dtype=np.int64)
    y_prob = np.asarray(y_prob, dtype=np.float64)
    if len(np.unique(y_true)) < 2:
        return float("nan")
    return float(roc_auc_score(y_true, y_prob))

def save_roc_curve_png(y_true, y_prob, out_png):
    fpr, tpr, _ = roc_curve(y_true, y_prob)
    plt.figure()
    plt.plot(fpr, tpr)
    plt.plot([0, 1], [0, 1], linestyle="--")
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    plt.title("ROC Curve (Val)")
    plt.tight_layout()
    plt.savefig(out_png, dpi=150)
    plt.close()

def save_confusion_png(y_true, y_prob, out_png):
    # Purpose: show thresholded performance at a fixed 0.5 cutoff
    y_pred = (np.asarray(y_prob) >= 0.5).astype(int)
    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
    plt.figure()
    plt.imshow(cm)
    plt.xticks([0, 1], ["Healthy(0)", "PD(1)"])
    plt.yticks([0, 1], ["Healthy(0)", "PD(1)"])
    plt.xlabel("Predicted")
    plt.ylabel("True")
    plt.title("Confusion Matrix (Val, thr=0.5)")
    for (i, j), v in np.ndenumerate(cm):
        plt.text(j, i, str(v), ha="center", va="center")
    plt.tight_layout()
    plt.savefig(out_png, dpi=150)
    plt.close()

def compute_threshold_metrics(y_true, y_prob, thr=0.5):
    # Purpose: extra “single-number” metrics at a fixed cutoff
    y_true = np.asarray(y_true, dtype=np.int64)
    y_prob = np.asarray(y_prob, dtype=np.float64)
    y_pred = (y_prob >= float(thr)).astype(np.int64)

    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
    tn, fp, fn, tp = (cm.ravel().tolist() if cm.size == 4 else [0, 0, 0, 0])

    acc = float(accuracy_score(y_true, y_pred))
    prec, rec, f1, _ = precision_recall_fscore_support(y_true, y_pred, average="binary", zero_division=0)
    mcc = float(matthews_corrcoef(y_true, y_pred)) if len(np.unique(y_true)) > 1 else float("nan")

    sensitivity = float(rec)
    specificity = float(tn / (tn + fp)) if (tn + fp) > 0 else float("nan")

    # Purpose: quick association test between prediction and truth
    p_value = float("nan")
    try:
        from scipy.stats import fisher_exact  # type: ignore
        _, p_value = fisher_exact([[tn, fp], [fn, tp]], alternative="two-sided")
        p_value = float(p_value)
    except Exception:
        p_value = float("nan")

    return {
        "thr": float(thr),
        "tn": int(tn), "fp": int(fp), "fn": int(fn), "tp": int(tp),
        "accuracy": float(acc),
        "precision": float(prec),
        "recall": float(rec),
        "f1": float(f1),
        "sensitivity": float(sensitivity),
        "specificity": float(specificity),
        "mcc": float(mcc),
        "p_value_fisher": float(p_value),
    }

# -------------------------
# Random seed control
# Purpose: stabilize shuffling and weight init per run
# -------------------------
def set_all_seeds(seed: int):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# -------------------------
# One seed run: train, validate, save best epoch
# Inputs: seed, train_df, val_df
# Outputs: artifacts in run_dir and a metrics dict for exp summary
# -------------------------
def run_trainval_once(seed: int):
    set_all_seeds(seed)

    run_dir = EXP_ROOT / f"run_{dataset_id}_seed{seed}"
    run_dir.mkdir(parents=True, exist_ok=True)

    # Purpose: build loaders from manifest tables
    train_ds = AudioManifestDataset(train_df)
    val_ds   = AudioManifestDataset(val_df)

    train_loader = DataLoader(
        train_ds,
        batch_size=PER_DEVICE_BS,
        shuffle=True,
        num_workers=NUM_WORKERS,
        pin_memory=PIN_MEMORY,
        collate_fn=collate_fn
    )
    val_loader = DataLoader(
        val_ds,
        batch_size=PER_DEVICE_BS,
        shuffle=False,
        num_workers=NUM_WORKERS,
        pin_memory=PIN_MEMORY,
        collate_fn=collate_fn
    )

    # Purpose: confirm file I/O and collation are working before training starts
    print(f"\n[seed={seed}] Warm-up: loading 3 train batches...")
    t0 = time.time()
    it = iter(train_loader)
    for i in range(3):
        _ = next(it)
        print(f"  loaded warmup batch {i+1}/3")
    print(f"[seed={seed}] Warm-up done in {time.time()-t0:.2f}s")

    model = Wav2Vec2TwoHeadClassifier(BACKBONE_CKPT, dropout_p=DROPOUT_P).to(DEVICE)

    # Purpose: train only the small pre-head blocks and heads
    trainable_params = (
        list(model.pre_vowel.parameters()) + list(model.pre_other.parameters()) +
        list(model.head_vowel.parameters()) + list(model.head_other.parameters())
    )
    opt = torch.optim.Adam(trainable_params, lr=LR)

    best_auc = -1.0
    best_epoch = -1
    no_improve = 0
    best_state = None
    best_val_probs = None
    best_val_true = None
    best_val_metrics_thr = None

    for epoch in range(1, MAX_EPOCHS + 1):
        # Train
        model.train()
        train_losses = []
        opt.zero_grad(set_to_none=True)

        pbar = tqdm(train_loader, desc=f"[seed={seed}] Train epoch {epoch}", dynamic_ncols=True)
        step = 0
        for batch in pbar:
            step += 1
            input_values = batch["input_values"].to(DEVICE, non_blocking=False)
            attention_mask = batch["attention_mask"].to(DEVICE, non_blocking=False)
            labels = batch["labels"].to(DEVICE, non_blocking=False)
            task_group = batch["task_group"]

            loss, _ = model(input_values, attention_mask, labels, task_group)
            loss = loss / GRAD_ACCUM
            loss.backward()

            train_losses.append(float(loss.detach().cpu().item()) * GRAD_ACCUM)

            # Purpose: simulate a larger batch size via accumulation
            if (step % GRAD_ACCUM) == 0:
                opt.step()
                opt.zero_grad(set_to_none=True)

        # Purpose: handle leftover steps at the end of the epoch
        if (step % GRAD_ACCUM) != 0:
            opt.step()
            opt.zero_grad(set_to_none=True)

        avg_train_loss = float(np.mean(train_losses)) if train_losses else float("nan")

        # Validate
        model.eval()
        all_probs, all_true = [], []
        vpbar = tqdm(val_loader, desc=f"[seed={seed}] Val epoch {epoch}", dynamic_ncols=True)
        with torch.inference_mode():
            for batch in vpbar:
                input_values = batch["input_values"].to(DEVICE, non_blocking=False)
                attention_mask = batch["attention_mask"].to(DEVICE, non_blocking=False)
                labels = batch["labels"].to(DEVICE, non_blocking=False)
                task_group = batch["task_group"]

                _, logits = model(input_values, attention_mask, labels, task_group)
                probs = torch.softmax(logits, dim=-1)[:, 1].detach().cpu().numpy().astype(np.float64)
                all_probs.extend(probs.tolist())
                all_true.extend(labels.detach().cpu().numpy().astype(np.int64).tolist())

        val_auc = compute_auc(all_true, all_probs)
        print(f"seed={seed} | epoch {epoch:02d}/{MAX_EPOCHS} | train_loss={avg_train_loss:.5f} | val_AUROC={val_auc:.5f}")

        # Purpose: keep only the best epoch by VAL AUROC
        improved = (not math.isnan(val_auc)) and (val_auc > best_auc + 1e-12)
        if improved:
            best_auc = float(val_auc)
            best_epoch = int(epoch)
            no_improve = 0
            best_state = {
                "pre_vowel": {k: v.detach().cpu().clone() for k, v in model.pre_vowel.state_dict().items()},
                "pre_other": {k: v.detach().cpu().clone() for k, v in model.pre_other.state_dict().items()},
                "head_vowel": {k: v.detach().cpu().clone() for k, v in model.head_vowel.state_dict().items()},
                "head_other": {k: v.detach().cpu().clone() for k, v in model.head_other.state_dict().items()},
            }
            best_val_probs = list(all_probs)
            best_val_true  = list(all_true)
            best_val_metrics_thr = compute_threshold_metrics(best_val_true, best_val_probs, thr=0.5)
        else:
            no_improve += 1

        # Purpose: stop early if AUROC does not improve for PATIENCE epochs
        if no_improve >= PATIENCE:
            break

    if best_state is None or best_val_probs is None or best_val_true is None:
        raise RuntimeError("No best epoch captured. Validation AUROC may be NaN due to single-class validation split.")

    # Save best epoch artifacts
    best_heads_path = run_dir / "best_heads.pt"
    torch.save(best_state, str(best_heads_path))

    roc_png = run_dir / "roc_curve.png"
    cm_png  = run_dir / "confusion_matrix.png"
    save_roc_curve_png(np.asarray(best_val_true, dtype=np.int64), np.asarray(best_val_probs, dtype=np.float64), str(roc_png))
    save_confusion_png(np.asarray(best_val_true, dtype=np.int64), np.asarray(best_val_probs, dtype=np.float64), str(cm_png))

    # Output: metrics.json for this seed
    metrics = {
        "dataset": dataset_id,
        "seed": int(seed),
        "best_val_auroc": float(best_auc),
        "best_epoch": int(best_epoch),
        "n_train": int(len(train_df)),
        "n_val": int(len(val_df)),
        "label_counts_train": train_df["label_num"].value_counts(dropna=False).to_dict(),
        "label_counts_val": val_df["label_num"].value_counts(dropna=False).to_dict(),
        "experiment_tag": EXPERIMENT_TAG,
        "run_stamp": RUN_STAMP,
        "dropout_p": float(DROPOUT_P),
        "lr": float(LR),
        "effective_batch_size": int(PER_DEVICE_BS * GRAD_ACCUM),
        "per_device_batch_size": int(PER_DEVICE_BS),
        "grad_accum_steps": int(GRAD_ACCUM),
        "backbone_ckpt": BACKBONE_CKPT,
        "thr_metrics_val_thr0p5": best_val_metrics_thr,
    }
    metrics_path = run_dir / "metrics.json"
    with open(metrics_path, "w", encoding="utf-8") as f:
        json.dump(metrics, f, indent=2)

    print(f"[seed={seed}] WROTE:")
    print(" ", str(metrics_path))
    print(" ", str(roc_png))
    print(" ", str(cm_png))
    print(" ", str(best_heads_path))

    return float(best_auc), str(run_dir), metrics

# -------------------------
# Run all seeds + write experiment summary
# Inputs: SEEDS list
# Outputs: summary_trainval.json and appended history_index.jsonl
# -------------------------
aucs = []
run_dirs = []
per_seed_metrics = []

for seed in SEEDS:
    a, rd, met = run_trainval_once(seed)
    aucs.append(float(a))
    run_dirs.append(str(rd))
    per_seed_metrics.append(met)

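# Purpose: mean and 95% CI across seeds (Student's t)
# Note: t_crit is hard-coded for df=2, so the interval is only valid when len(SEEDS) == 3
t_crit = 4.302652729911275  # two-sided 95% critical value, df=2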
n = len(aucs)
mean_auc = float(np.mean(aucs))
std_auc = float(np.std(aucs, ddof=1)) if n > 1 else 0.0
half_width = float(t_crit * (std_auc / math.sqrt(n))) if n > 1 else 0.0
ci95 = [float(mean_auc - half_width), float(mean_auc + half_width)]

print("\nAUROC by seed:")
for s, a in zip(SEEDS, aucs):
    print(f"  seed {s}: {a:.6f}")
print(f"\nMean AUROC: {mean_auc:.6f}")
print(f"95% CI (t, n=3): [{ci95[0]:.6f}, {ci95[1]:.6f}]")

exp_summary = {
    "dataset": dataset_id,
    "dx_out_root": DX_OUT_ROOT,
    "experiment_tag": EXPERIMENT_TAG,
    "run_stamp": RUN_STAMP,
    "exp_root": str(EXP_ROOT),
    "run_dirs": run_dirs,
    "seeds": SEEDS,
    "aurocs": [float(x) for x in aucs],
    "mean_auroc": float(mean_auc),
    "t_crit_df2_95": float(t_crit),
    "half_width_95_ci": float(half_width),
    "ci95": ci95,
    "n_train": int(len(train_df)),
    "n_val": int(len(val_df)),
    "label_counts_train": train_df["label_num"].value_counts(dropna=False).to_dict(),
    "label_counts_val": val_df["label_num"].value_counts(dropna=False).to_dict(),
    "effective_batch_size": int(PER_DEVICE_BS * GRAD_ACCUM),
    "per_device_batch_size": int(PER_DEVICE_BS),
    "grad_accum_steps": int(GRAD_ACCUM),
    "backbone_ckpt": BACKBONE_CKPT,
    "dropout_p": float(DROPOUT_P),
    "lr": float(LR),
    "per_seed_metrics": per_seed_metrics,
}

summary_path = EXP_ROOT / "summary_trainval.json"
with open(summary_path, "w", encoding="utf-8") as f:
    json.dump(exp_summary, f, indent=2)

# Output: append one line so older experiments remain intact
history_path = TRAINVAL_ROOT / "history_index.jsonl"
with open(history_path, "a", encoding="utf-8") as f:
    f.write(json.dumps(exp_summary) + "\n")

print("\nWROTE per-experiment summary:", str(summary_path))
print("APPENDED global history index:", str(history_path))
print("\nOpen this folder to access artifacts:", str(EXP_ROOT))

# -------------------------
# Stop runtime (release GPU)
# -------------------------
print("\nAll done. Unassigning the runtime to stop the L4 instance...")
try:
    from google.colab import runtime  # type: ignore
    print("Calling runtime.unassign() now.")
    runtime.unassign()
except Exception as e:
    print("Could not unassign runtime automatically. Error:", repr(e))

The following cell trains and validates the D5 model using the updated v2 speaker splits (50% train, 20% validation, 30% test), and records the results in a way that does not overwrite earlier runs. It reads `manifests/manifest_all.csv` from the D5 v2 output folder, keeps only rows marked as `train` or `val`, infers the dataset label from the most common `dataset` value in the manifest, prints basic counts for sanity checking, and stops immediately if any referenced audio files are missing.

An experiment folder is created under `trainval_runs/` using a fixed tag (`frozen_LNDO`) and a timestamp so that each run is stored separately. Within this folder, the cell runs three independent training jobs using different random seeds (1337, 2024, 7777). Training uses a frozen Wav2Vec2 backbone and learns only two small task-specific classifier heads, each preceded by a LayerNorm and Dropout block: one for vowel clips and one for other speech. The `task` column is converted into a simple `task_group`, where `vowl` is treated as vowel and all other values are treated as other. Audio files are loaded from disk, converted to mono if needed, checked to be 16 kHz, and paired with an attention mask so that trailing zero-padded silence in vowel clips is ignored. During batching, audio is padded with zeros and the attention mask is padded in the same way, so padded samples are excluded from pooling and do not influence the classifier heads.
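The masking and padding steps described above can be sketched in plain Python. This is a minimal illustration and not the cell's implementation: the helper names `trailing_silence_mask` and `pad_pair` and the sample values are hypothetical, and lists stand in for the NumPy arrays and tensors the cell uses. Only the threshold value `1e-4` is taken from the cell's `TINY_THRESH`.

```python
from typing import List, Tuple

TINY_THRESH = 1e-4  # same amplitude floor as the cell's TINY_THRESH


def trailing_silence_mask(samples: List[float], thresh: float = TINY_THRESH) -> List[int]:
    """Return a 0/1 mask that is 1 up to the last sample louder than `thresh`."""
    last_loud = -1
    for j in range(len(samples) - 1, -1, -1):
        if abs(samples[j]) > thresh:
            last_loud = j
            break
    if last_loud < 0:
        # Entirely silent clip: keep every sample so the mask is never all zeros.
        return [1] * len(samples)
    return [1] * (last_loud + 1) + [0] * (len(samples) - last_loud - 1)


def pad_pair(samples: List[float], mask: List[int], target_len: int) -> Tuple[List[float], List[int]]:
    """Right-pad the audio and its mask with zeros up to `target_len`."""
    pad = target_len - len(samples)
    return samples + [0.0] * pad, mask + [0] * pad


# Hypothetical vowel clip with two trailing zero samples.
vowel = [0.2, -0.1, 0.05, 0.0, 0.0]
mask = trailing_silence_mask(vowel)
padded_audio, padded_mask = pad_pair(vowel, mask, target_len=7)
print(mask)         # [1, 1, 1, 0, 0]
print(padded_mask)  # [1, 1, 1, 0, 0, 0, 0]
```

Because the mask is padded with zeros alongside the audio, batch padding and trailing silence are treated identically when features are pooled.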

For each seed, the cell builds the training and validation data loaders and runs a short warm-up by loading three batches to catch input or disk issues early. Training runs for up to 10 epochs with an effective batch size of 64 (per-device batches of 16 accumulated over 4 steps), mixed precision on the GPU, and early stopping when validation AUROC fails to improve for two consecutive epochs. Validation runs after every epoch, and only the epoch with the best validation AUROC is kept for each seed.
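The batch-size arithmetic and the early-stopping rule can be replayed without any training. The sketch below is illustrative only: the function names are hypothetical and the per-epoch AUROC values are made up to show the control flow, not taken from any run.

```python
import math
from typing import List, Tuple


def accumulation_steps(effective_bs: int, per_device_bs: int) -> int:
    """Micro-batches per optimizer step, computed like the cell's GRAD_ACCUM."""
    return max(1, effective_bs // per_device_bs)


def best_epoch_with_patience(val_aurocs: List[float], patience: int) -> Tuple[int, float, int]:
    """Replay early stopping; returns (best_epoch, best_auroc, last_epoch_run), epochs 1-based."""
    best_auc, best_epoch, no_improve = -1.0, -1, 0
    last_epoch = 0
    for epoch, auc in enumerate(val_aurocs, start=1):
        last_epoch = epoch
        if not math.isnan(auc) and auc > best_auc + 1e-12:
            best_auc, best_epoch, no_improve = auc, epoch, 0
        else:
            no_improve += 1
        if no_improve >= patience:
            break  # stop: no improvement for `patience` consecutive epochs
    return best_epoch, best_auc, last_epoch


print(accumulation_steps(64, 16))  # 4
# Hypothetical validation AUROCs, for illustration only.
print(best_epoch_with_patience([0.70, 0.74, 0.73, 0.72, 0.80], patience=2))  # (2, 0.74, 4)
```

In this made-up sequence, training stops after epoch 4 and never reaches the higher value at epoch 5, which is the usual trade-off of a short patience window.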

For each seed run, results are saved under `exp_<tag>_<timestamp>/run_D5_seed####/`. Saved files include `best_heads.pt` with the trained weights for the two task heads and their small LayerNorm and Dropout blocks, `roc_curve.png` showing the validation ROC curve, `confusion_matrix.png` showing the validation confusion matrix at a fixed threshold of 0.5, and `metrics.json` containing training settings, dataset sizes, the best epoch, the best validation AUROC, and additional threshold-based metrics at 0.5: accuracy, precision, recall (sensitivity), specificity, F1 score, MCC, and the p-value from Fisher's exact test.
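To make the threshold-based metrics concrete, the sketch below derives them from the four confusion-matrix counts using only the standard library. The function name `threshold_metrics` and the example labels and probabilities are hypothetical. Returning 0.0 for MCC when its denominator is zero is an assumption of this sketch. Fisher's exact test is omitted because it needs SciPy.

```python
import math
from typing import Dict, Sequence


def threshold_metrics(y_true: Sequence[int], y_prob: Sequence[float], thr: float = 0.5) -> Dict[str, float]:
    """Confusion counts and derived metrics at a fixed probability cutoff."""
    y_pred = [1 if p >= thr else 0 for p in y_prob]
    tp = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 1)
    tn = sum(1 for t, p in zip(y_true, y_pred) if t == 0 and p == 0)
    fp = sum(1 for t, p in zip(y_true, y_pred) if t == 0 and p == 1)
    fn = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 0)

    def ratio(num: float, den: float) -> float:
        return num / den if den > 0 else float("nan")

    precision = ratio(tp, tp + fp)
    sensitivity = ratio(tp, tp + fn)  # same quantity as recall
    specificity = ratio(tn, tn + fp)
    f1 = ratio(2 * tp, 2 * tp + fp + fn)
    mcc_den = math.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
    mcc = (tp * tn - fp * fn) / mcc_den if mcc_den > 0 else 0.0  # sketch assumption
    return {
        "tn": tn, "fp": fp, "fn": fn, "tp": tp,
        "accuracy": ratio(tp + tn, len(y_pred)),
        "precision": precision,
        "sensitivity": sensitivity,
        "specificity": specificity,
        "f1": f1,
        "mcc": mcc,
    }


# Hypothetical labels (1 = PD, 0 = Healthy) and predicted PD probabilities.
m = threshold_metrics([0, 0, 0, 1, 1, 1], [0.1, 0.6, 0.3, 0.8, 0.4, 0.9])
print(m["tn"], m["fp"], m["fn"], m["tp"])  # 2 1 1 2
```

All of these values depend on the 0.5 cutoff, whereas AUROC does not, which is why the cell selects the best epoch by AUROC and reports the thresholded metrics only as a supplement.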

After all three seeds complete, the cell prints the AUROC for each seed and the mean AUROC with a 95% confidence interval computed from a Student's t distribution with n = 3 (2 degrees of freedom). A single `summary_trainval.json` file is written to summarize the entire experiment, including paths, settings, and per-seed results, and the same summary is appended as one line to `trainval_runs/history_index.jsonl` so experiments are tracked over time. Finally, the Colab runtime is unassigned to stop the L4 GPU instance.
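The interval is the mean plus or minus t times s / sqrt(n), where s is the sample standard deviation across seeds. The sketch below works through that arithmetic. The function name `mean_ci95_three_seeds` and the AUROC values are hypothetical, and the critical value is the constant the cell hard-codes for 2 degrees of freedom.

```python
import math
import statistics
from typing import List, Tuple

# Two-sided 95% critical value of Student's t for df = 2, as hard-coded in the cell.
T_CRIT_DF2_95 = 4.302652729911275


def mean_ci95_three_seeds(aurocs: List[float]) -> Tuple[float, float, float]:
    """Return (mean, ci_low, ci_high) for exactly three per-seed AUROCs."""
    if len(aurocs) != 3:
        raise ValueError("T_CRIT_DF2_95 is only valid for exactly three values (df = 2).")
    mean = statistics.fmean(aurocs)
    sd = statistics.stdev(aurocs)  # sample standard deviation (ddof = 1)
    half_width = T_CRIT_DF2_95 * sd / math.sqrt(len(aurocs))
    return mean, mean - half_width, mean + half_width


# Hypothetical per-seed AUROCs, for illustration only.
mean_auc, ci_low, ci_high = mean_ci95_three_seeds([0.80, 0.82, 0.84])
print(f"mean={mean_auc:.4f} ci95=[{ci_low:.4f}, {ci_high:.4f}]")
```

With only three seeds the critical value is large, so even a small spread between seeds gives a wide interval. The interval is also not clipped to the [0, 1] range of AUROC.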

In [None]:
# =========================
# D5 Train + Val — v2 Splits, Frozen Backbone, Two Heads
# Inputs: manifest_all.csv (train/val only) and audio clips referenced by clip_path
# Outputs: per-seed run folders with metrics.json, roc_curve.png, confusion_matrix.png, best_heads.pt
#          plus one exp-level summary_trainval.json and one appended history_index.jsonl record
# =========================

import os, json, math, random, time, warnings
from pathlib import Path

import numpy as np
import pandas as pd
import soundfile as sf

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

from transformers import Wav2Vec2Model

from sklearn.metrics import roc_auc_score, roc_curve, confusion_matrix, matthews_corrcoef
from scipy.stats import fisher_exact
import matplotlib.pyplot as plt

from tqdm.auto import tqdm

# -------------------------
# Dataset root (D5 v2)
# Purpose: pick the preprocessed version that contains the v2 train/val splits
# Input: manifest_all.csv under this root
# -------------------------
# DX_OUT_ROOT = "/content/drive/MyDrive/AI_PD_Project/Datasets/D5-English (MDVR-KCL)/preprocessed_v1"  # previous v1 output folder, kept for reference
DX_OUT_ROOT = "/content/drive/MyDrive/AI_PD_Project/Datasets/D5-English (MDVR-KCL)/preprocessed_v2"    # v2 output folder with the new 50/20/30 speaker splits

MANIFEST_ALL = f"{DX_OUT_ROOT}/manifests/manifest_all.csv"

# -------------------------
# Experiment folder naming
# Purpose: keep a permanent run history without overwriting older experiments
# Output: EXP_ROOT contains all per-seed run folders for this execution
# -------------------------
EXPERIMENT_TAG = "frozen_LNDO"   # change to something meaningful for paper
RUN_STAMP = time.strftime("%Y%m%d_%H%M%S")  # unique timestamp
TRAINVAL_ROOT = Path(DX_OUT_ROOT) / "trainval_runs"
EXP_ROOT = TRAINVAL_ROOT / f"exp_{EXPERIMENT_TAG}_{RUN_STAMP}"
EXP_ROOT.mkdir(parents=True, exist_ok=True)

# -------------------------
# Training settings (fixed defaults)
# Purpose: keep training behavior comparable across datasets
# -------------------------
MAX_EPOCHS     = 10
EFFECTIVE_BS   = 64
PER_DEVICE_BS  = 16
GRAD_ACCUM     = max(1, EFFECTIVE_BS // PER_DEVICE_BS)

LR             = 1e-3
PATIENCE       = 2
SEEDS          = [1337, 2024, 7777]

BACKBONE_CKPT  = "facebook/wav2vec2-base"
SR_EXPECTED    = 16000
TINY_THRESH    = 1e-4

# Small pre-head blocks (trainable)
DROPOUT_P      = 0.2

# Data loader stability (Drive-friendly)
NUM_WORKERS    = 0
PIN_MEMORY     = False

# Fixed threshold used for confusion and threshold metrics (VAL only)
VAL_THRESHOLD  = 0.5

# Mixed precision (GPU only)
USE_AMP        = True

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
try:
    torch.set_float32_matmul_precision("high")
except Exception:
    pass

warnings.filterwarnings("ignore", message="None of the inputs have requires_grad=True. Gradients will be None")
warnings.filterwarnings("ignore", message="Passing `gradient_checkpointing` to a config initialization is deprecated")

print("EXPERIMENT ROOT:", str(EXP_ROOT))
print("DEVICE:", DEVICE)
if DEVICE.type == "cuda":
    print("GPU:", torch.cuda.get_device_name(0))
print("PER_DEVICE_BS:", PER_DEVICE_BS, "| GRAD_ACCUM:", GRAD_ACCUM, "| EFFECTIVE_BS:", PER_DEVICE_BS * GRAD_ACCUM)
print("NUM_WORKERS:", NUM_WORKERS, "| PIN_MEMORY:", PIN_MEMORY)
print("VAL_THRESHOLD:", VAL_THRESHOLD)
print("USE_AMP:", bool(USE_AMP and DEVICE.type == "cuda"))

# -------------------------
# Manifest read + split filtering
# Input: manifest_all.csv
# Output: train_df and val_df (train/val rows only) and dataset_id label
# -------------------------
if not os.path.exists(MANIFEST_ALL):
    raise FileNotFoundError(f"Missing manifest_all.csv: {MANIFEST_ALL}")

m = pd.read_csv(MANIFEST_ALL)

# Purpose: ensure the required columns exist before continuing
req_cols = {"split", "clip_path", "label_num", "task"}
missing = [c for c in sorted(req_cols) if c not in m.columns]
if missing:
    raise ValueError(f"Manifest missing required columns: {missing}. Found: {list(m.columns)}")

# Purpose: train/val only in this cell
m = m[m["split"].isin(["train", "val"])].copy()
if len(m) == 0:
    raise RuntimeError("After filtering to split in {train,val}, manifest has 0 rows.")

# Purpose: infer the dataset label when the manifest mixes datasets
if "dataset" in m.columns and m["dataset"].notna().any():
    dataset_id = str(m["dataset"].value_counts(dropna=True).idxmax())
    m = m[m["dataset"].astype(str) == dataset_id].copy()
else:
    dataset_id = "DX"

# Purpose: keep a consistent table shape across datasets
keep_cols = ["clip_path", "label_num", "task", "speaker_id", "duration_sec", "split"]
for c in keep_cols:
    if c not in m.columns:
        m[c] = np.nan
m = m[keep_cols].copy()

train_df = m[m["split"] == "train"].copy().reset_index(drop=True)
val_df   = m[m["split"] == "val"].copy().reset_index(drop=True)

print(f"\nDataset inferred: {dataset_id}")
print(f"Train rows: {len(train_df)} | Val rows: {len(val_df)}")
print("Train label counts:", train_df["label_num"].value_counts(dropna=False).to_dict())
print("Val label counts:",   val_df["label_num"].value_counts(dropna=False).to_dict())

# Purpose: avoid silent runs that cannot train or validate
if len(train_df) == 0 or len(val_df) == 0:
    raise RuntimeError("Train or Val split has 0 rows.")

# -------------------------
# Clip existence check
# Input: clip_path from train_df/val_df
# Output: raises early if any audio files are missing (shows a few examples)
# -------------------------
def _fail_fast_missing_paths(df: pd.DataFrame, name: str):
    missing_paths = []
    for p in tqdm(df["clip_path"].astype(str).tolist(), desc=f"Check {name} clip_path exists", dynamic_ncols=True):
        if not os.path.exists(p):
            missing_paths.append(p)
            if len(missing_paths) >= 10:
                break
    if missing_paths:
        raise FileNotFoundError(f"{name}: missing clip_path(s). Examples: {missing_paths}")

_fail_fast_missing_paths(train_df, "TRAIN")
_fail_fast_missing_paths(val_df, "VAL")

# -------------------------
# Task grouping
# Purpose: choose vowel head vs other head per clip
# Output: task_group column added to both tables
# -------------------------
def _task_group(task_val) -> str:
    return "vowel" if str(task_val) == "vowl" else "other"

train_df["task_group"] = train_df["task"].apply(_task_group)
val_df["task_group"]   = val_df["task"].apply(_task_group)

# -------------------------
# Dataset + collator (padding + attention masks)
# Inputs: train_df/val_df clip_path and task_group
# Outputs: batches with padded input_values and matching attention_mask
# -------------------------
class AudioManifestDataset(Dataset):
    def __init__(self, df: pd.DataFrame):
        self.df = df.reset_index(drop=True)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx: int):
        row = self.df.iloc[idx]
        clip_path = str(row["clip_path"])
        label = int(row["label_num"])
        task_group = str(row["task_group"])

        if not os.path.exists(clip_path):
            raise FileNotFoundError(f"Missing clip_path: {clip_path}")

        # Purpose: load audio and force mono float32
        y, sr = sf.read(clip_path, always_2d=False)
        if isinstance(y, np.ndarray) and y.ndim > 1:
            y = np.mean(y, axis=1)
        y = np.asarray(y, dtype=np.float32)

        # Purpose: keep a single sample rate across all training
        if int(sr) != SR_EXPECTED:
            raise ValueError(f"Sample rate mismatch (expected {SR_EXPECTED}): {clip_path} has sr={sr}")

        # Purpose: mask trailing zero padding for vowel clips during pooling
        attn = np.ones((len(y),), dtype=np.int64)
        if task_group == "vowel":
            k = -1
            for j in range(len(y) - 1, -1, -1):
                if abs(float(y[j])) > TINY_THRESH:
                    k = j
                    break
            if k >= 0:
                attn[:k+1] = 1
                attn[k+1:] = 0
            else:
                attn[:] = 1
        else:
            attn[:] = 1

        return {
            "input_values": torch.from_numpy(y),
            "attention_mask": torch.from_numpy(attn),
            "labels": torch.tensor(label, dtype=torch.long),
            "task_group": task_group,
        }

def collate_fn(batch):
    # Purpose: pad to the longest audio in the batch (pads input and mask with zeros)
    max_len = int(max(b["input_values"].numel() for b in batch))
    input_vals, attn_masks, labels, task_groups = [], [], [], []
    for b in batch:
        x = b["input_values"]
        a = b["attention_mask"]
        pad = max_len - x.numel()
        if pad > 0:
            x = torch.cat([x, torch.zeros(pad, dtype=x.dtype)], dim=0)
            a = torch.cat([a, torch.zeros(pad, dtype=a.dtype)], dim=0)
        input_vals.append(x)
        attn_masks.append(a)
        labels.append(b["labels"])
        task_groups.append(b["task_group"])
    return {
        "input_values": torch.stack(input_vals, dim=0),
        "attention_mask": torch.stack(attn_masks, dim=0),
        "labels": torch.stack(labels, dim=0),
        "task_group": task_groups,
    }

# -------------------------
# Model: frozen Wav2Vec2 + two task heads
# Inputs: backbone checkpoint and dropout probability
# Output: loss and logits (PD vs Healthy) using the head selected by task_group
# -------------------------
class Wav2Vec2TwoHeadClassifier(nn.Module):
    def __init__(self, ckpt: str, dropout_p: float):
        super().__init__()
        self.backbone = Wav2Vec2Model.from_pretrained(ckpt)
        for p in self.backbone.parameters():
            p.requires_grad = False

        H = int(self.backbone.config.hidden_size)

        # Purpose: small per-task feature cleanup before each head
        self.pre_vowel = nn.Sequential(nn.LayerNorm(H), nn.Dropout(float(dropout_p)))
        self.pre_other = nn.Sequential(nn.LayerNorm(H), nn.Dropout(float(dropout_p)))

        self.head_vowel = nn.Linear(H, 2)
        self.head_other = nn.Linear(H, 2)
        self.loss_fn = nn.CrossEntropyLoss()

    def masked_mean_pool(self, last_hidden, attn_mask_samples):
        # Purpose: pool frame features while ignoring masked samples
        feat_mask = self.backbone._get_feature_vector_attention_mask(last_hidden.shape[1], attn_mask_samples)
        feat_mask = feat_mask.to(last_hidden.device).unsqueeze(-1).type_as(last_hidden)
        masked = last_hidden * feat_mask
        denom = feat_mask.sum(dim=1).clamp(min=1.0)
        return masked.sum(dim=1) / denom

    def _heads_fp32(self, x_fp_any: torch.Tensor, head: nn.Module) -> torch.Tensor:
        # Purpose: keep heads in fp32 even when AMP is enabled
        x = x_fp_any.float()
        if x.is_cuda:
            with torch.amp.autocast(device_type="cuda", enabled=False):
                return head(x)
        return head(x)

    def forward(self, input_values, attention_mask, labels, task_group):
        # Purpose: backbone is frozen, so gradients only flow through heads
        with torch.no_grad():
            out = self.backbone(input_values=input_values, attention_mask=attention_mask)
            last_hidden = out.last_hidden_state

        pooled = self.masked_mean_pool(last_hidden, attention_mask)  # [B,H]

        z_v = self.pre_vowel(pooled.float())
        z_o = self.pre_other(pooled.float())

        logits_v = self._heads_fp32(z_v, self.head_vowel)
        logits_o = self._heads_fp32(z_o, self.head_other)

        # Purpose: apply the correct head per sample
        is_vowel = torch.tensor([tg == "vowel" for tg in task_group], device=pooled.device, dtype=torch.bool)
        logits = logits_o.clone()
        if is_vowel.any():
            logits[is_vowel] = logits_v[is_vowel]

        loss = self.loss_fn(logits, labels)
        return loss, logits

# -------------------------
# Metrics + plots (VAL only)
# Inputs: val labels and predicted PD probabilities
# Outputs: AUROC, threshold metrics, ROC curve PNG, confusion matrix PNG
# -------------------------
def compute_auc(y_true, y_prob):
    y_true = np.asarray(y_true, dtype=np.int64)
    y_prob = np.asarray(y_prob, dtype=np.float64)
    if len(np.unique(y_true)) < 2:
        return float("nan")
    return float(roc_auc_score(y_true, y_prob))

def compute_threshold_metrics(y_true, y_prob, thr=0.5):
    y_true = np.asarray(y_true, dtype=np.int64)
    y_prob = np.asarray(y_prob, dtype=np.float64)
    y_pred = (y_prob >= float(thr)).astype(np.int64)

    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
    TN, FP, FN, TP = int(cm[0, 0]), int(cm[0, 1]), int(cm[1, 0]), int(cm[1, 1])

    eps = 1e-12
    acc = (TP + TN) / max(1, (TP + TN + FP + FN))
    prec = TP / (TP + FP + eps)
    rec = TP / (TP + FN + eps)
    f1 = 2 * prec * rec / (prec + rec + eps)
    spec = TN / (TN + FP + eps)

    try:
        mcc = float(matthews_corrcoef(y_true, y_pred)) if len(np.unique(y_pred)) > 1 else float("nan")
    except Exception:
        mcc = float("nan")

    try:
        _, pval = fisher_exact([[TN, FP], [FN, TP]], alternative="two-sided")
        pval = float(pval)
    except Exception:
        pval = float("nan")

    return {
        "threshold": float(thr),
        "confusion_matrix": {"TN": TN, "FP": FP, "FN": FN, "TP": TP},
        "accuracy": float(acc),
        "precision": float(prec),
        "recall": float(rec),
        "f1_score": float(f1),
        "sensitivity": float(rec),
        "specificity": float(spec),
        "mcc": float(mcc),
        "p_value_fisher_two_sided": float(pval),
    }

def save_roc_curve_png(y_true, y_prob, out_png):
    fpr, tpr, _ = roc_curve(y_true, y_prob)
    plt.figure()
    plt.plot(fpr, tpr)
    plt.plot([0, 1], [0, 1], linestyle="--")
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    plt.title("ROC Curve (Val)")
    plt.tight_layout()
    plt.savefig(out_png, dpi=150)
    plt.close()

def save_confusion_png(y_true, y_prob, out_png, thr=0.5):
    y_pred = (np.asarray(y_prob) >= float(thr)).astype(int)
    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
    plt.figure()
    plt.imshow(cm)
    plt.xticks([0, 1], ["Healthy(0)", "PD(1)"])
    plt.yticks([0, 1], ["Healthy(0)", "PD(1)"])
    plt.xlabel("Predicted")
    plt.ylabel("True")
    plt.title(f"Confusion Matrix (Val, thr={thr:.2f})")
    for (i, j), v in np.ndenumerate(cm):
        plt.text(j, i, str(v), ha="center", va="center")
    plt.tight_layout()
    plt.savefig(out_png, dpi=150)
    plt.close()

# -------------------------
# Seed control
# Purpose: stabilize training order and results per seed
# -------------------------
def set_all_seeds(seed: int):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# -------------------------
# One seed run (train + val)
# Inputs: seed, train_df, val_df
# Outputs: best_heads.pt + plots + metrics.json in run_dir
# -------------------------
def run_trainval_once(seed: int):
    set_all_seeds(seed)

    # Purpose: keep D5 run folder names stable and easy to scan
    # Output: run_dir for this seed under EXP_ROOT
    # run_dir = EXP_ROOT / f"run_{dataset_id}_seed{seed}"
    # run_dir.mkdir(parents=True, exist_ok=True)
    run_dir = EXP_ROOT / f"run_D5_seed{seed}"  # modified to consider the new 50/20/30 splits path
    run_dir.mkdir(parents=True, exist_ok=True)  # modified to consider the new 50/20/30 splits path

    # Purpose: create loaders from the manifest tables
    train_ds = AudioManifestDataset(train_df)
    val_ds   = AudioManifestDataset(val_df)

    train_loader = DataLoader(train_ds, batch_size=PER_DEVICE_BS, shuffle=True,
                              num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY, collate_fn=collate_fn)
    val_loader   = DataLoader(val_ds, batch_size=PER_DEVICE_BS, shuffle=False,
                              num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY, collate_fn=collate_fn)

    # Warm-up
    # Purpose: ensure the training loader is working before epochs start
    print(f"\n[seed={seed}] Warm-up: loading 3 train batches...")
    t0 = time.time()
    # zip() stops cleanly if the loader yields fewer than 3 batches (avoids StopIteration on tiny splits)
    for i, _ in zip(range(3), train_loader):
        print(f"  loaded warmup batch {i+1}/3")
    print(f"[seed={seed}] Warm-up done in {time.time()-t0:.2f}s")

    model = Wav2Vec2TwoHeadClassifier(BACKBONE_CKPT, dropout_p=DROPOUT_P).to(DEVICE)

    # Purpose: only train pre-head blocks and heads (backbone stays frozen)
    trainable_params = (
        list(model.pre_vowel.parameters()) + list(model.pre_other.parameters()) +
        list(model.head_vowel.parameters()) + list(model.head_other.parameters())
    )
    opt = torch.optim.Adam(trainable_params, lr=LR)

    use_amp = bool(USE_AMP and DEVICE.type == "cuda")
    scaler = torch.amp.GradScaler("cuda", enabled=use_amp)

    # Purpose: track the best epoch by VAL AUROC
    best_auc = -1.0
    best_epoch = -1
    no_improve = 0
    best_state = None
    best_val_probs = None
    best_val_true = None
    best_thr_metrics = None

    for epoch in range(1, MAX_EPOCHS + 1):
        # Train
        model.train()
        train_losses = []
        opt.zero_grad(set_to_none=True)

        pbar = tqdm(train_loader, desc=f"[seed={seed}] Train epoch {epoch}", dynamic_ncols=True)
        step = 0
        for batch in pbar:
            step += 1
            input_values = batch["input_values"].to(DEVICE, non_blocking=False)
            attention_mask = batch["attention_mask"].to(DEVICE, non_blocking=False)
            labels = batch["labels"].to(DEVICE, non_blocking=False)
            task_group = batch["task_group"]

            with torch.amp.autocast(device_type="cuda", enabled=use_amp):
                loss, _ = model(input_values, attention_mask, labels, task_group)
                loss = loss / GRAD_ACCUM

            scaler.scale(loss).backward()
            train_losses.append(float(loss.detach().cpu().item()) * GRAD_ACCUM)

            # Purpose: update weights every GRAD_ACCUM steps to reach EFFECTIVE_BS
            if (step % GRAD_ACCUM) == 0:
                scaler.step(opt)
                scaler.update()
                opt.zero_grad(set_to_none=True)

        # Purpose: handle leftover steps when epoch length is not divisible by GRAD_ACCUM
        if (step % GRAD_ACCUM) != 0:
            scaler.step(opt)
            scaler.update()
            opt.zero_grad(set_to_none=True)

        avg_train_loss = float(np.mean(train_losses)) if train_losses else float("nan")

        # Validate
        model.eval()
        all_probs, all_true = [], []
        vpbar = tqdm(val_loader, desc=f"[seed={seed}] Val epoch {epoch}", dynamic_ncols=True)
        with torch.inference_mode():
            for batch in vpbar:
                input_values = batch["input_values"].to(DEVICE, non_blocking=False)
                attention_mask = batch["attention_mask"].to(DEVICE, non_blocking=False)
                labels = batch["labels"].to(DEVICE, non_blocking=False)
                task_group = batch["task_group"]

                with torch.amp.autocast(device_type="cuda", enabled=use_amp):
                    _, logits = model(input_values, attention_mask, labels, task_group)

                # Purpose: use PD probability from softmax(logits)
                probs = torch.softmax(logits.float(), dim=-1)[:, 1].detach().cpu().numpy().astype(np.float64)
                all_probs.extend(probs.tolist())
                all_true.extend(labels.detach().cpu().numpy().astype(np.int64).tolist())

        val_auc = compute_auc(all_true, all_probs)
        print(f"seed={seed} | epoch {epoch:02d}/{MAX_EPOCHS} | train_loss={avg_train_loss:.5f} | val_AUROC={val_auc:.5f}")

        # Purpose: store best weights and best-epoch VAL threshold metrics
        improved = (not math.isnan(val_auc)) and (val_auc > best_auc + 1e-12)
        if improved:
            best_auc = float(val_auc)
            best_epoch = int(epoch)
            no_improve = 0
            best_state = {
                "pre_vowel": {k: v.detach().cpu().clone() for k, v in model.pre_vowel.state_dict().items()},
                "pre_other": {k: v.detach().cpu().clone() for k, v in model.pre_other.state_dict().items()},
                "head_vowel": {k: v.detach().cpu().clone() for k, v in model.head_vowel.state_dict().items()},
                "head_other": {k: v.detach().cpu().clone() for k, v in model.head_other.state_dict().items()},
            }
            best_val_probs = list(all_probs)
            best_val_true  = list(all_true)
            best_thr_metrics = compute_threshold_metrics(best_val_true, best_val_probs, thr=VAL_THRESHOLD)
        else:
            no_improve += 1

        # Purpose: early stop when VAL AUROC does not improve for PATIENCE epochs
        if no_improve >= PATIENCE:
            break

    # Purpose: avoid writing partial artifacts if no valid best epoch exists
    if best_state is None or best_val_probs is None or best_val_true is None or best_thr_metrics is None:
        raise RuntimeError("No best epoch captured. Validation AUROC may be NaN due to single-class validation split.")

    # Save best epoch artifacts
    best_heads_path = run_dir / "best_heads.pt"
    torch.save(best_state, str(best_heads_path))

    roc_png = run_dir / "roc_curve.png"
    cm_png  = run_dir / "confusion_matrix.png"
    ytrue_np = np.asarray(best_val_true, dtype=np.int64)
    yprob_np = np.asarray(best_val_probs, dtype=np.float64)

    save_roc_curve_png(ytrue_np, yprob_np, str(roc_png))
    save_confusion_png(ytrue_np, yprob_np, str(cm_png), thr=VAL_THRESHOLD)

    # Output: metrics.json captures the best epoch and the extra threshold metrics
    metrics = {
        "dataset": dataset_id,
        "seed": int(seed),
        "best_val_auroc": float(best_auc),
        "best_epoch": int(best_epoch),
        "n_train": int(len(train_df)),
        "n_val": int(len(val_df)),
        "label_counts_train": train_df["label_num"].value_counts(dropna=False).to_dict(),
        "label_counts_val": val_df["label_num"].value_counts(dropna=False).to_dict(),
        "experiment_tag": EXPERIMENT_TAG,
        "run_stamp": RUN_STAMP,
        "dropout_p": float(DROPOUT_P),
        "lr": float(LR),
        "effective_batch_size": int(PER_DEVICE_BS * GRAD_ACCUM),
        "per_device_batch_size": int(PER_DEVICE_BS),
        "grad_accum_steps": int(GRAD_ACCUM),
        "val_threshold": float(VAL_THRESHOLD),
        "backbone_ckpt": BACKBONE_CKPT,
        "threshold_metrics_best_epoch": best_thr_metrics,
    }
    metrics_path = run_dir / "metrics.json"
    with open(metrics_path, "w", encoding="utf-8") as f:
        json.dump(metrics, f, indent=2)

    print(f"[seed={seed}] WROTE:")
    print(" ", str(metrics_path))
    print(" ", str(roc_png))
    print(" ", str(cm_png))
    print(" ", str(best_heads_path))

    return float(best_auc), str(run_dir), best_thr_metrics

# -------------------------
# Run all seeds + write exp summary + append global history
# Inputs: SEEDS list
# Outputs: exp_root/summary_trainval.json and one appended history_index.jsonl record
# -------------------------
aucs = []
run_dirs = []
per_seed_metrics = []

for seed in SEEDS:
    a, rd, thrm = run_trainval_once(seed)
    aucs.append(a)
    run_dirs.append(rd)
    per_seed_metrics.append({"seed": int(seed), "best_val_auroc": float(a), "threshold_metrics_best_epoch": thrm})

t_crit = 4.302652729911275  # two-sided 95% t critical value for df=2 (valid only when exactly 3 seeds are run)
n = len(aucs)
mean_auc = float(np.mean(aucs))
std_auc = float(np.std(aucs, ddof=1)) if n > 1 else 0.0
half_width = float(t_crit * (std_auc / math.sqrt(n))) if n > 1 else 0.0

print("\nAUROC by seed:")
for s, a in zip(SEEDS, aucs):
    print(f"  seed {s}: {a:.6f}")
print(f"\nMean AUROC: {mean_auc:.6f}")
print(f"95% CI (t, n=3): [{mean_auc - half_width:.6f}, {mean_auc + half_width:.6f}]")

exp_summary = {
    "dataset": dataset_id,
    "dx_out_root": DX_OUT_ROOT,
    "experiment_tag": EXPERIMENT_TAG,
    "run_stamp": RUN_STAMP,
    "exp_root": str(EXP_ROOT),
    "run_dirs": run_dirs,
    "seeds": SEEDS,
    "aurocs": [float(x) for x in aucs],
    "mean_auroc": float(mean_auc),
    "t_crit_df2_95": float(t_crit),
    "half_width_95_ci": float(half_width),
    "ci95": [float(mean_auc - half_width), float(mean_auc + half_width)],
    "n_train": int(len(train_df)),
    "n_val": int(len(val_df)),
    "label_counts_train": train_df["label_num"].value_counts(dropna=False).to_dict(),
    "label_counts_val": val_df["label_num"].value_counts(dropna=False).to_dict(),
    "effective_batch_size": int(PER_DEVICE_BS * GRAD_ACCUM),
    "per_device_batch_size": int(PER_DEVICE_BS),
    "grad_accum_steps": int(GRAD_ACCUM),
    "val_threshold": float(VAL_THRESHOLD),
    "backbone_ckpt": BACKBONE_CKPT,
    "dropout_p": float(DROPOUT_P),
    "lr": float(LR),
    "per_seed_best_epoch_metrics": per_seed_metrics,
}

summary_path = EXP_ROOT / "summary_trainval.json"
with open(summary_path, "w", encoding="utf-8") as f:
    json.dump(exp_summary, f, indent=2)

# Output: one-line append so older experiments stay intact
history_path = TRAINVAL_ROOT / "history_index.jsonl"
with open(history_path, "a", encoding="utf-8") as f:
    f.write(json.dumps(exp_summary) + "\n")

print("\nWROTE per-experiment summary:", str(summary_path))
print("APPENDED global history index:", str(history_path))
print("\nOpen this folder to access artifacts:", str(EXP_ROOT))

# -------------------------
# Stop runtime
# Purpose: release the GPU machine after outputs are written
# -------------------------
print("\nAll done. Unassigning the runtime to stop the L4 instance...")
try:
    from google.colab import runtime  # type: ignore
    print("Calling runtime.unassign() now.")
    runtime.unassign()
except Exception as e:
    print("Could not unassign runtime automatically. You can stop the runtime manually in Colab.")
    print("Reason:", repr(e))

The following cell trains and validates a two-head classifier for a single dataset (D6 in this instance), using the training and validation splits only. It uses a frozen Wav2Vec2 feature extractor and reads a single standardized manifest file (`manifests/manifest_all.csv`) from `DX_OUT_ROOT`. Only rows whose `split` value is `train` or `val` are used. The dataset name is inferred from the most common `dataset` value in the manifest, and row and label counts are printed to confirm that the splits look reasonable. Before training begins, the cell runs safety checks for common Colab problems, such as local files shadowing PyTorch or Transformers, and scans the manifest with a progress bar to confirm that all referenced audio files exist.

A new experiment folder is then created under `trainval_runs/` using a fixed tag (`frozen_LNDO`) and a timestamp. Within this folder, one subfolder is created per random seed, and a short summary record is later appended to a running history file so past experiments remain easy to track. Training settings are fixed across datasets to keep results comparable: three random seeds, up to 10 epochs, an effective batch size of 64 (per-device batches of 16 with 4 gradient accumulation steps), a learning rate of 1e-3, and early stopping with a patience of 2 epochs. Mixed precision is enabled on the GPU to improve speed, while the final head computations are kept in full precision to avoid numerical issues.
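As a rough illustration of the gradient accumulation schedule described above, the sketch below lists the batch indices at which an optimizer step would fire. The helper name `optimizer_step_indices` is hypothetical and is not part of the notebook; it only mirrors the "step every `GRAD_ACCUM` batches, then flush any leftover" logic under the stated settings.

```python
def optimizer_step_indices(n_batches: int, grad_accum: int) -> list[int]:
    """Return the 1-based batch indices after which an optimizer step would fire."""
    steps = [s for s in range(1, n_batches + 1) if s % grad_accum == 0]
    # Leftover micro-batches at the end of an epoch still trigger one final step.
    if n_batches % grad_accum != 0:
        steps.append(n_batches)
    return steps


effective_bs = 64
per_device_bs = 16
grad_accum = max(1, effective_bs // per_device_bs)

print(grad_accum)                              # 4
print(optimizer_step_indices(10, grad_accum))  # [4, 8, 10]
```

With 10 batches per epoch, the last update in this example would be built from only 2 micro-batches. Because each micro-batch loss is still divided by `GRAD_ACCUM`, that final update is based on a smaller accumulated gradient than the others.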

Data loading is set up to prevent the model from learning from padded silence. Each clip is assigned to one of two task groups using a simple rule: `task == "vowl"` is treated as vowel speech, and everything else is treated as other speech. Audio is loaded from disk, converted to mono if needed, and checked to ensure a 16 kHz sample rate. For vowel clips, the code detects trailing near-zero samples (absolute amplitude at or below 1e-4) caused by padding and builds an attention mask that excludes this padded region. When forming batches, shorter clips are padded with zeros, and their attention masks are padded in the same way so padded regions are ignored.
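The masking and padding idea described above can be sketched in NumPy as follows. This is a minimal illustration, not the notebook's exact code: the helper names `trailing_silence_mask` and `pad_pair` are hypothetical, and the default threshold simply mirrors `TINY_THRESH = 1e-4`.

```python
import numpy as np


def trailing_silence_mask(y: np.ndarray, thresh: float = 1e-4) -> np.ndarray:
    """Return 1 up to the last sample with |y| > thresh and 0 for the trailing tail."""
    mask = np.ones(len(y), dtype=np.int64)
    loud = np.flatnonzero(np.abs(y) > thresh)
    if loud.size > 0:
        mask[loud[-1] + 1:] = 0
    # If no sample exceeds the threshold, everything is left unmasked.
    return mask


def pad_pair(x: np.ndarray, m: np.ndarray, target_len: int):
    """Right-pad a waveform and its mask with zeros up to target_len."""
    pad = target_len - len(x)
    return np.pad(x, (0, pad)), np.pad(m, (0, pad))


clip = np.array([0.0, 0.3, -0.2, 0.00001, 0.0], dtype=np.float32)
mask = trailing_silence_mask(clip)
print(mask.tolist())          # [1, 1, 1, 0, 0]

padded_clip, padded_mask = pad_pair(clip, mask, target_len=7)
print(padded_mask.tolist())   # [1, 1, 1, 0, 0, 0, 0]
```

Only the trailing tail is masked. Leading or interior quiet samples, such as the first sample in the example, stay inside the attended region.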

The model uses a frozen Wav2Vec2 backbone as a feature extractor, with two separate classification heads on top, one for vowel clips and one for other speech. Each head includes a small task-specific preprocessing block made of LayerNorm and Dropout, followed by a linear classifier. Only the head-related layers are trained, and the backbone weights are never updated. During training, both heads are evaluated on the whole batch, and each sample's logits are then taken from the head that matches its task group. Loss is computed using standard cross-entropy, and optimization is performed with Adam on the head parameters only.
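The per-sample head selection can be illustrated with a small NumPy stand-in for the PyTorch tensors. The function name `route_logits` and the toy logit values are assumptions made for this sketch and do not appear in the notebook.

```python
import numpy as np


def route_logits(logits_vowel: np.ndarray, logits_other: np.ndarray, task_group: list) -> np.ndarray:
    """Pick, for each row, the logits of the head matching that sample's task group."""
    is_vowel = np.array([tg == "vowel" for tg in task_group], dtype=bool)
    # Broadcast the [B] boolean mask across the class dimension of the [B, 2] logits.
    return np.where(is_vowel[:, None], logits_vowel, logits_other)


# Toy logits: both heads produce an output for every sample in the batch.
logits_v = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
logits_o = -logits_v
groups = ["vowel", "other", "vowel"]

routed = route_logits(logits_v, logits_o, groups)
print(routed.tolist())  # [[1.0, 2.0], [-3.0, -4.0], [5.0, 6.0]]
```

Since the loss would be computed only on the routed logits, each head would receive gradient only from the samples of its own task group, even though both heads run on every sample.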

For each random seed, the cell builds the training and validation data loaders and runs a short warm-up by loading a few batches to catch dataset or padding issues early. Training then runs for up to 10 epochs with progress bars, using gradient accumulation to reach the target effective batch size. After each epoch, validation is performed and validation AUROC is computed. The best epoch is tracked based on AUROC, and early stopping is applied if AUROC does not improve for the specified number of consecutive epochs.
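The best-epoch tracking and early-stopping rule can be sketched in plain Python as below. The function name `best_epoch_with_patience` and the AUROC values are illustrative assumptions, not results from the study. NaN values are treated as "no improvement", which matches how the rule is described here.

```python
import math


def best_epoch_with_patience(val_aucs, patience=2, min_delta=1e-12):
    """Return (best_epoch, best_auc, last_epoch_run) for a sequence of per-epoch AUROCs."""
    best_auc, best_epoch, no_improve = -1.0, -1, 0
    last_epoch = 0
    for epoch, auc in enumerate(val_aucs, start=1):
        last_epoch = epoch
        if not math.isnan(auc) and auc > best_auc + min_delta:
            best_auc, best_epoch, no_improve = auc, epoch, 0
        else:
            no_improve += 1
        # Stop once AUROC has failed to improve for `patience` consecutive epochs.
        if no_improve >= patience:
            break
    return best_epoch, best_auc, last_epoch


# Made-up AUROC trajectory: the late jump at epoch 5 is never reached.
print(best_epoch_with_patience([0.70, 0.74, 0.73, 0.72, 0.80]))  # (2, 0.74, 4)
```

With a patience of 2, training in this example would stop after epoch 4 and keep the epoch-2 weights, so a short patience trades some risk of missing late improvements for a shorter run.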

At the end of each seed run, only the best-epoch results are saved in that seed’s run folder. These include `best_heads.pt` with the saved weights for the two heads and their preprocessing layers, `roc_curve.png` and `confusion_matrix.png` showing validation results (the confusion matrix uses a fixed threshold of 0.5), and a `metrics.json` file containing key settings, dataset sizes, the best AUROC, the best epoch, and threshold-based metrics at 0.5.

After all three seeds complete, the cell prints the AUROC for each seed along with the mean AUROC and a 95% confidence interval computed using a t distribution with 2 degrees of freedom (n = 3 seeds). A single `summary_trainval.json` file is written to the experiment folder with paths, settings, and per-seed results, and the same summary is appended to `trainval_runs/history_index.jsonl` so the experiment is logged. Finally, the Colab runtime is unassigned to stop the GPU instance.
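The confidence interval arithmetic can be written out as a short standard-library sketch. The function name `mean_ci95_three_seeds` and the three AUROC values are hypothetical and are used only to show the calculation. The critical value is the two-sided 95% t value for 2 degrees of freedom, so the helper deliberately rejects any other number of seeds.

```python
import math
import statistics

T_CRIT_DF2_95 = 4.302652729911275  # two-sided 95% t critical value for df = 2


def mean_ci95_three_seeds(aucs):
    """Return (mean, lower, upper) of a t-based 95% CI for exactly three AUROC values."""
    if len(aucs) != 3:
        raise ValueError("this critical value is only valid for n = 3 (df = 2)")
    mean = statistics.fmean(aucs)
    sd = statistics.stdev(aucs)  # sample standard deviation (ddof = 1)
    half_width = T_CRIT_DF2_95 * sd / math.sqrt(len(aucs))
    return mean, mean - half_width, mean + half_width


# Made-up per-seed AUROC values, for illustration only.
mean, lower, upper = mean_ci95_three_seeds([0.80, 0.82, 0.84])
print(f"mean={mean:.4f}  95% CI=[{lower:.4f}, {upper:.4f}]")
```

With only three seeds the critical value is large (about 4.30), so the interval is wide even for a small spread between seeds. The interval is also not clipped to the [0, 1] range of AUROC.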

In [None]:
# =========================
# D6 Train + Val — Frozen Backbone, Two Heads, Val AUROC Tracking
# Inputs: manifest_all.csv (train/val only) and audio clips referenced by clip_path
# Outputs: per-seed run folders with metrics.json, roc_curve.png, confusion_matrix.png, best_heads.pt
#          plus one exp-level summary_trainval.json and one appended history_index.jsonl record
# =========================

import os, json, math, random, time, warnings
from pathlib import Path

import numpy as np
import pandas as pd
import soundfile as sf

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

from transformers import Wav2Vec2Model

from sklearn.metrics import roc_auc_score, roc_curve, confusion_matrix, matthews_corrcoef
from scipy.stats import fisher_exact
import matplotlib.pyplot as plt

from tqdm.auto import tqdm

# -------------------------
# Import safety guard
# Purpose: avoid importing local files named torch/transformers by mistake
# -------------------------
if os.path.exists("/content/torch.py") or os.path.exists("/content/torch/__init__.py"):
    raise RuntimeError("Found local /content/torch.py or /content/torch/ that shadows PyTorch. Rename/remove it and restart runtime.")
if os.path.exists("/content/transformers.py") or os.path.exists("/content/transformers/__init__.py"):
    raise RuntimeError("Found local /content/transformers.py or /content/transformers/ that shadows Hugging Face Transformers. Rename/remove it and restart runtime.")

# -------------------------
# Paths (fallback + runtime override)
# Input: DX_OUT_ROOT may already exist from preprocessing
# Output: MANIFEST_ALL used for train/val rows only
# -------------------------
D6_OUT_ROOT_FALLBACK = "/content/drive/MyDrive/AI_PD_Project/Datasets/D6-Ah Sound (Figshare)/preprocessed_v1"
DX_OUT_ROOT = globals().get("DX_OUT_ROOT", D6_OUT_ROOT_FALLBACK)
MANIFEST_ALL = f"{DX_OUT_ROOT}/manifests/manifest_all.csv"

# -------------------------
# Experiment folders
# Output: EXP_ROOT for this run, plus global history_index.jsonl under trainval_runs/
# -------------------------
EXPERIMENT_TAG = "frozen_LNDO"
RUN_STAMP = time.strftime("%Y%m%d_%H%M%S")

TRAINVAL_ROOT = Path(DX_OUT_ROOT) / "trainval_runs"
TRAINVAL_ROOT.mkdir(parents=True, exist_ok=True)

EXP_ROOT = TRAINVAL_ROOT / f"exp_{EXPERIMENT_TAG}_{RUN_STAMP}"
EXP_ROOT.mkdir(parents=True, exist_ok=True)

HISTORY_INDEX_PATH = TRAINVAL_ROOT / "history_index.jsonl"

# -------------------------
# Training constants
# Purpose: keep settings stable across datasets and runs
# -------------------------
MAX_EPOCHS     = 10
EFFECTIVE_BS   = 64
PER_DEVICE_BS  = 16
GRAD_ACCUM     = max(1, EFFECTIVE_BS // PER_DEVICE_BS)

LR             = 1e-3
PATIENCE       = 2
SEEDS          = [1337, 2024, 7777]

BACKBONE_CKPT  = "facebook/wav2vec2-base"
SR_EXPECTED    = 16000
TINY_THRESH    = 1e-4

# Trainable blocks only (heads + small pre-head blocks)
DROPOUT_P      = 0.2

# Data loader defaults (Drive-friendly)
NUM_WORKERS    = 0
PIN_MEMORY     = False

# Fixed reporting threshold (used only for confusion/threshold metrics on VAL)
VAL_THRESHOLD  = 0.5

# Mixed precision (GPU only)
USE_AMP        = True

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
try:
    torch.set_float32_matmul_precision("high")
except Exception:
    pass

warnings.filterwarnings("ignore", message="None of the inputs have requires_grad=True. Gradients will be None")
warnings.filterwarnings("ignore", message="Passing `gradient_checkpointing` to a config initialization is deprecated")

print("DX_OUT_ROOT:", DX_OUT_ROOT)
print("MANIFEST_ALL:", MANIFEST_ALL)
print("EXPERIMENT ROOT:", str(EXP_ROOT))
print("DEVICE:", DEVICE)
if DEVICE.type == "cuda":
    print("GPU:", torch.cuda.get_device_name(0))
print("PER_DEVICE_BS:", PER_DEVICE_BS, "| GRAD_ACCUM:", GRAD_ACCUM, "| EFFECTIVE_BS:", PER_DEVICE_BS * GRAD_ACCUM)
print("NUM_WORKERS:", NUM_WORKERS, "| PIN_MEMORY:", PIN_MEMORY)
print("EXPERIMENT_TAG:", EXPERIMENT_TAG, "| RUN_STAMP:", RUN_STAMP)
print("VAL_THRESHOLD:", VAL_THRESHOLD)
print("USE_AMP:", bool(USE_AMP and DEVICE.type == "cuda"))

# -------------------------
# Manifest read + split filtering
# Input: manifest_all.csv
# Output: train_df and val_df with consistent columns and dataset_id inferred
# -------------------------
if not os.path.exists(MANIFEST_ALL):
    raise FileNotFoundError(f"Missing manifest_all.csv: {MANIFEST_ALL}")

m = pd.read_csv(MANIFEST_ALL)

# Purpose: enforce minimum required columns for training
req_cols = {"split", "clip_path", "label_num", "task"}
missing = [c for c in sorted(req_cols) if c not in m.columns]
if missing:
    raise ValueError(f"Manifest missing required columns: {missing}. Found: {list(m.columns)}")

# Purpose: train/val only in this cell
m = m[m["split"].isin(["train", "val"])].copy()
if len(m) == 0:
    raise RuntimeError("After filtering to split in {train,val}, manifest has 0 rows.")

# Purpose: select the active dataset id when multiple datasets are present
if "dataset" in m.columns and m["dataset"].notna().any():
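    # Drop NaN before casting to str; otherwise missing values become the literal string "nan" and can win the count
    dataset_id = str(m["dataset"].dropna().astype(str).value_counts().idxmax())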
    m = m[m["dataset"].astype(str) == dataset_id].copy()
else:
    dataset_id = "DX"

# Purpose: keep a consistent set of columns even if some are missing
keep_cols = ["clip_path", "label_num", "task", "speaker_id", "duration_sec", "split"]
for c in keep_cols:
    if c not in m.columns:
        m[c] = np.nan
m = m[keep_cols].copy()

train_df = m[m["split"] == "train"].copy().reset_index(drop=True)
val_df   = m[m["split"] == "val"].copy().reset_index(drop=True)

print(f"\nDataset inferred: {dataset_id}")
print(f"Train rows: {len(train_df)} | Val rows: {len(val_df)}")
print("Train label counts:", train_df["label_num"].value_counts(dropna=False).to_dict())
print("Val label counts:",   val_df["label_num"].value_counts(dropna=False).to_dict())

# Purpose: fail early if either split is empty
if len(train_df) == 0 or len(val_df) == 0:
    raise RuntimeError("Train or Val split has 0 rows.")

# -------------------------
# Clip existence check
# Input: clip_path from train_df/val_df
# Output: raises early if any audio files are missing (shows a few examples)
# -------------------------
def _fail_fast_missing_paths(df: pd.DataFrame, name: str):
    missing_paths = []
    for p in tqdm(df["clip_path"].astype(str).tolist(), desc=f"Check {name} clip_path exists", dynamic_ncols=True):
        if not os.path.exists(p):
            missing_paths.append(p)
            if len(missing_paths) >= 10:
                break
    if missing_paths:
        raise FileNotFoundError(f"{name}: missing clip_path(s). Examples: {missing_paths}")

_fail_fast_missing_paths(train_df, "TRAIN")
_fail_fast_missing_paths(val_df, "VAL")

# -------------------------
# Task grouping
# Purpose: choose vowel head vs other head per clip
# Output: task_group column added to both tables
# -------------------------
def _task_group(task_val) -> str:
    return "vowel" if str(task_val) == "vowl" else "other"

train_df["task_group"] = train_df["task"].apply(_task_group)
val_df["task_group"]   = val_df["task"].apply(_task_group)

# -------------------------
# Dataset + collator (padding + attention masks)
# Inputs: train_df/val_df clip_path and task_group
# Outputs: batches with padded input_values and matching attention_mask
# -------------------------
class AudioManifestDataset(Dataset):
    def __init__(self, df: pd.DataFrame):
        self.df = df.reset_index(drop=True)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx: int):
        row = self.df.iloc[idx]
        clip_path = str(row["clip_path"])
        label = int(row["label_num"])
        task_group = str(row["task_group"])

        if not os.path.exists(clip_path):
            raise FileNotFoundError(f"Missing clip_path: {clip_path}")

        # Purpose: load audio and force mono float32
        y, sr = sf.read(clip_path, always_2d=False)
        if isinstance(y, np.ndarray) and y.ndim > 1:
            y = np.mean(y, axis=1)
        y = np.asarray(y, dtype=np.float32)

        # Purpose: keep a single sample rate across all training
        if int(sr) != SR_EXPECTED:
            raise ValueError(f"Sample rate mismatch (expected {SR_EXPECTED}): {clip_path} has sr={sr}")

        # Purpose: ignore trailing zero padding for vowel clips during pooling
        attn = np.ones((len(y),), dtype=np.int64)
        if task_group == "vowel":
            # Find the last sample above the silence threshold; everything after it is masked out.
            # If no sample exceeds the threshold, the mask stays all ones.
            nonsilent = np.flatnonzero(np.abs(y) > TINY_THRESH)
            if nonsilent.size > 0:
                attn[int(nonsilent[-1]) + 1:] = 0

        return {
            "input_values": torch.from_numpy(y),
            "attention_mask": torch.from_numpy(attn),
            "labels": torch.tensor(label, dtype=torch.long),
            "task_group": task_group,
        }

def collate_fn(batch):
    # Purpose: pad to the longest audio in the batch (pads input and mask with zeros)
    max_len = int(max(b["input_values"].numel() for b in batch))
    input_vals, attn_masks, labels, task_groups = [], [], [], []
    for b in batch:
        x = b["input_values"]
        a = b["attention_mask"]
        pad = max_len - x.numel()
        if pad > 0:
            x = torch.cat([x, torch.zeros(pad, dtype=x.dtype)], dim=0)
            a = torch.cat([a, torch.zeros(pad, dtype=a.dtype)], dim=0)
        input_vals.append(x)
        attn_masks.append(a)
        labels.append(b["labels"])
        task_groups.append(b["task_group"])
    return {
        "input_values": torch.stack(input_vals, dim=0),
        "attention_mask": torch.stack(attn_masks, dim=0),
        "labels": torch.stack(labels, dim=0),
        "task_group": task_groups,
    }

# -------------------------
# Model: frozen Wav2Vec2 + two task heads
# Inputs: backbone checkpoint and dropout probability
# Output: loss and logits (PD vs Healthy) using the head selected by task_group
# -------------------------
class Wav2Vec2TwoHeadClassifier(nn.Module):
    def __init__(self, ckpt: str, dropout_p: float):
        super().__init__()
        self.backbone = Wav2Vec2Model.from_pretrained(ckpt, use_safetensors=True)
        for p in self.backbone.parameters():
            p.requires_grad = False

        H = int(self.backbone.config.hidden_size)

        # Purpose: small per-task feature cleanup before each head
        self.pre_vowel = nn.Sequential(nn.LayerNorm(H), nn.Dropout(float(dropout_p)))
        self.pre_other = nn.Sequential(nn.LayerNorm(H), nn.Dropout(float(dropout_p)))

        self.head_vowel = nn.Linear(H, 2)
        self.head_other = nn.Linear(H, 2)
        self.loss_fn = nn.CrossEntropyLoss()

    def masked_mean_pool(self, last_hidden, attn_mask_samples):
        # Purpose: pool frame features while ignoring masked samples
        feat_mask = self.backbone._get_feature_vector_attention_mask(last_hidden.shape[1], attn_mask_samples)
        feat_mask = feat_mask.to(last_hidden.device).unsqueeze(-1).type_as(last_hidden)
        masked = last_hidden * feat_mask
        denom = feat_mask.sum(dim=1).clamp(min=1.0)
        return masked.sum(dim=1) / denom

    def _heads_fp32(self, x_fp_any: torch.Tensor, head: nn.Module) -> torch.Tensor:
        # Purpose: keep heads in fp32 even when AMP is enabled
        x = x_fp_any.float()
        if x.is_cuda:
            with torch.amp.autocast(device_type="cuda", enabled=False):
                return head(x)
        return head(x)

    def forward(self, input_values, attention_mask, labels, task_group):
        # Purpose: backbone is frozen, so gradients only flow through heads
        with torch.no_grad():
            out = self.backbone(input_values=input_values, attention_mask=attention_mask)
            last_hidden = out.last_hidden_state

        pooled = self.masked_mean_pool(last_hidden, attention_mask)  # [B,H]

        z_v = self.pre_vowel(pooled.float())
        z_o = self.pre_other(pooled.float())

        logits_v = self._heads_fp32(z_v, self.head_vowel)
        logits_o = self._heads_fp32(z_o, self.head_other)

        # Purpose: apply the correct head per sample
        is_vowel = torch.tensor([tg == "vowel" for tg in task_group], device=pooled.device, dtype=torch.bool)
        logits = logits_o.clone()
        if is_vowel.any():
            logits[is_vowel] = logits_v[is_vowel]

        loss = self.loss_fn(logits, labels)
        return loss, logits

# -------------------------
# Metrics + plots (VAL only)
# Inputs: val labels and predicted PD probabilities
# Outputs: AUROC, threshold metrics, ROC curve PNG, confusion matrix PNG
# -------------------------
def compute_auc(y_true, y_prob):
    y_true = np.asarray(y_true, dtype=np.int64)
    y_prob = np.asarray(y_prob, dtype=np.float64)
    if len(np.unique(y_true)) < 2:
        return float("nan")
    return float(roc_auc_score(y_true, y_prob))

def compute_threshold_metrics(y_true, y_prob, thr=0.5):
    y_true = np.asarray(y_true, dtype=np.int64)
    y_prob = np.asarray(y_prob, dtype=np.float64)
    y_pred = (y_prob >= float(thr)).astype(np.int64)

    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
    TN, FP, FN, TP = int(cm[0, 0]), int(cm[0, 1]), int(cm[1, 0]), int(cm[1, 1])

    eps = 1e-12
    acc = (TP + TN) / max(1, (TP + TN + FP + FN))
    prec = TP / (TP + FP + eps)
    rec = TP / (TP + FN + eps)
    f1 = 2 * prec * rec / (prec + rec + eps)
    spec = TN / (TN + FP + eps)

    try:
        mcc = float(matthews_corrcoef(y_true, y_pred)) if len(np.unique(y_pred)) > 1 else float("nan")
    except Exception:
        mcc = float("nan")

    try:
        _, pval = fisher_exact([[TN, FP], [FN, TP]], alternative="two-sided")
        pval = float(pval)
    except Exception:
        pval = float("nan")

    return {
        "threshold": float(thr),
        "confusion_matrix": {"TN": TN, "FP": FP, "FN": FN, "TP": TP},
        "accuracy": float(acc),
        "precision": float(prec),
        "recall": float(rec),
        "f1_score": float(f1),
        "sensitivity": float(rec),
        "specificity": float(spec),
        "mcc": float(mcc),
        "p_value_fisher_two_sided": float(pval),
    }

def save_roc_curve_png(y_true, y_prob, out_png):
    fpr, tpr, _ = roc_curve(y_true, y_prob)
    plt.figure()
    plt.plot(fpr, tpr)
    plt.plot([0, 1], [0, 1], linestyle="--")
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    plt.title("ROC Curve (Val)")
    plt.tight_layout()
    plt.savefig(out_png, dpi=150)
    plt.close()

def save_confusion_png(y_true, y_prob, out_png, thr=0.5):
    y_pred = (np.asarray(y_prob) >= float(thr)).astype(int)
    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
    plt.figure()
    plt.imshow(cm)
    plt.xticks([0, 1], ["Healthy(0)", "PD(1)"])
    plt.yticks([0, 1], ["Healthy(0)", "PD(1)"])
    plt.xlabel("Predicted")
    plt.ylabel("True")
    plt.title(f"Confusion Matrix (Val, thr={thr:.2f})")
    for (i, j), v in np.ndenumerate(cm):
        plt.text(j, i, str(v), ha="center", va="center")
    plt.tight_layout()
    plt.savefig(out_png, dpi=150)
    plt.close()

# -------------------------
# Seed control
# Purpose: stabilize training order and results per seed
# -------------------------
def set_all_seeds(seed: int):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# -------------------------
# One seed run (train + val)
# Inputs: seed, train_df, val_df
# Outputs: best_heads.pt + plots + metrics.json in run_dir
# -------------------------
def run_trainval_once(seed: int):
    set_all_seeds(seed)

    run_dir = EXP_ROOT / f"run_{dataset_id}_seed{seed}"
    run_dir.mkdir(parents=True, exist_ok=True)

    train_ds = AudioManifestDataset(train_df)
    val_ds   = AudioManifestDataset(val_df)

    train_loader = DataLoader(train_ds, batch_size=PER_DEVICE_BS, shuffle=True,
                              num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY, collate_fn=collate_fn)
    val_loader   = DataLoader(val_ds, batch_size=PER_DEVICE_BS, shuffle=False,
                              num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY, collate_fn=collate_fn)

    # Warm-up
    # Purpose: ensure the dataloader works before training starts
    print(f"\n[seed={seed}] Warm-up: loading 3 train batches...")
    t0 = time.time()
    # zip stops early if the loader yields fewer than 3 batches, so a tiny split cannot raise StopIteration
    for i, _ in zip(range(3), train_loader):
        print(f"  loaded warmup batch {i+1}/3")
    print(f"[seed={seed}] Warm-up done in {time.time()-t0:.2f}s")

    model = Wav2Vec2TwoHeadClassifier(BACKBONE_CKPT, dropout_p=DROPOUT_P).to(DEVICE)

    # Purpose: only train pre-head blocks and heads (backbone stays frozen)
    trainable_params = (
        list(model.pre_vowel.parameters()) + list(model.pre_other.parameters()) +
        list(model.head_vowel.parameters()) + list(model.head_other.parameters())
    )
    opt = torch.optim.Adam(trainable_params, lr=LR)

    use_amp = bool(USE_AMP and DEVICE.type == "cuda")
    scaler = torch.amp.GradScaler("cuda", enabled=use_amp)

    # Purpose: keep the best epoch by VAL AUROC
    best_auc = -1.0
    best_epoch = -1
    no_improve = 0
    best_state = None
    best_val_probs = None
    best_val_true = None
    best_thr_metrics = None

    for epoch in range(1, MAX_EPOCHS + 1):
        # Train
        model.train()
        train_losses = []
        opt.zero_grad(set_to_none=True)

        pbar = tqdm(train_loader, desc=f"[seed={seed}] Train epoch {epoch}", dynamic_ncols=True)
        step = 0
        for batch in pbar:
            step += 1
            input_values = batch["input_values"].to(DEVICE, non_blocking=False)
            attention_mask = batch["attention_mask"].to(DEVICE, non_blocking=False)
            labels = batch["labels"].to(DEVICE, non_blocking=False)
            task_group = batch["task_group"]

            with torch.amp.autocast(device_type="cuda", enabled=use_amp):
                loss, _ = model(input_values, attention_mask, labels, task_group)
                loss = loss / GRAD_ACCUM

            scaler.scale(loss).backward()
            train_losses.append(float(loss.detach().cpu().item()) * GRAD_ACCUM)

            # Purpose: update weights every GRAD_ACCUM steps to reach EFFECTIVE_BS
            if (step % GRAD_ACCUM) == 0:
                scaler.step(opt)
                scaler.update()
                opt.zero_grad(set_to_none=True)

        # Purpose: handle leftover steps when epoch length is not divisible by GRAD_ACCUM
        if (step % GRAD_ACCUM) != 0:
            scaler.step(opt)
            scaler.update()
            opt.zero_grad(set_to_none=True)

        avg_train_loss = float(np.mean(train_losses)) if train_losses else float("nan")

        # Validate
        model.eval()
        all_probs, all_true = [], []
        vpbar = tqdm(val_loader, desc=f"[seed={seed}] Val epoch {epoch}", dynamic_ncols=True)
        with torch.no_grad():
            for batch in vpbar:
                input_values = batch["input_values"].to(DEVICE, non_blocking=False)
                attention_mask = batch["attention_mask"].to(DEVICE, non_blocking=False)
                labels = batch["labels"].to(DEVICE, non_blocking=False)
                task_group = batch["task_group"]

                with torch.amp.autocast(device_type="cuda", enabled=use_amp):
                    _, logits = model(input_values, attention_mask, labels, task_group)

                # Purpose: use PD probability from softmax(logits)
                probs = torch.softmax(logits.float(), dim=-1)[:, 1].detach().cpu().numpy().astype(np.float64)
                all_probs.extend(probs.tolist())
                all_true.extend(labels.detach().cpu().numpy().astype(np.int64).tolist())

        val_auc = compute_auc(all_true, all_probs)
        print(f"seed={seed} | epoch {epoch:02d}/{MAX_EPOCHS} | train_loss={avg_train_loss:.5f} | val_AUROC={val_auc:.5f}")

        # Purpose: store the best epoch by AUROC and reset early-stop counter
        improved = (not math.isnan(val_auc)) and (val_auc > best_auc + 1e-12)
        if improved:
            best_auc = float(val_auc)
            best_epoch = int(epoch)
            no_improve = 0
            best_state = {
                "pre_vowel": {k: v.detach().cpu().clone() for k, v in model.pre_vowel.state_dict().items()},
                "pre_other": {k: v.detach().cpu().clone() for k, v in model.pre_other.state_dict().items()},
                "head_vowel": {k: v.detach().cpu().clone() for k, v in model.head_vowel.state_dict().items()},
                "head_other": {k: v.detach().cpu().clone() for k, v in model.head_other.state_dict().items()},
            }
            best_val_probs = list(all_probs)
            best_val_true  = list(all_true)
            best_thr_metrics = compute_threshold_metrics(best_val_true, best_val_probs, thr=VAL_THRESHOLD)
        else:
            no_improve += 1

        # Purpose: stop when VAL AUROC has not improved for PATIENCE epochs
        if no_improve >= PATIENCE:
            break

    # Purpose: avoid writing partial artifacts if no valid best epoch exists
    if best_state is None or best_val_probs is None or best_val_true is None or best_thr_metrics is None:
        raise RuntimeError("No best epoch captured. Validation AUROC may be NaN due to single-class validation split.")

    # Save best epoch artifacts
    best_heads_path = run_dir / "best_heads.pt"
    torch.save(best_state, str(best_heads_path))

    roc_png = run_dir / "roc_curve.png"
    cm_png  = run_dir / "confusion_matrix.png"
    ytrue_np = np.asarray(best_val_true, dtype=np.int64)
    yprob_np = np.asarray(best_val_probs, dtype=np.float64)

    save_roc_curve_png(ytrue_np, yprob_np, str(roc_png))
    save_confusion_png(ytrue_np, yprob_np, str(cm_png), thr=VAL_THRESHOLD)

    metrics = {
        "dataset": dataset_id,
        "seed": int(seed),
        "best_val_auroc": float(best_auc),
        "best_epoch": int(best_epoch),
        "n_train": int(len(train_df)),
        "n_val": int(len(val_df)),
        "label_counts_train": train_df["label_num"].value_counts(dropna=False).to_dict(),
        "label_counts_val": val_df["label_num"].value_counts(dropna=False).to_dict(),
        "experiment_tag": EXPERIMENT_TAG,
        "run_stamp": RUN_STAMP,
        "dropout_p": float(DROPOUT_P),
        "lr": float(LR),
        "effective_batch_size": int(PER_DEVICE_BS * GRAD_ACCUM),
        "per_device_batch_size": int(PER_DEVICE_BS),
        "grad_accum_steps": int(GRAD_ACCUM),
        "val_threshold": float(VAL_THRESHOLD),
        "backbone_ckpt": BACKBONE_CKPT,
        "threshold_metrics_best_epoch": best_thr_metrics,
    }
    metrics_path = run_dir / "metrics.json"
    with open(metrics_path, "w", encoding="utf-8") as f:
        json.dump(metrics, f, indent=2)

    print(f"[seed={seed}] WROTE:")
    print(" ", str(metrics_path))
    print(" ", str(roc_png))
    print(" ", str(cm_png))
    print(" ", str(best_heads_path))

    return float(best_auc), str(run_dir), best_thr_metrics

# -------------------------
# Run all seeds + write exp summary
# Inputs: SEEDS list
# Outputs: exp_root/summary_trainval.json and one appended history record
# -------------------------
aucs = []
run_dirs = []
per_seed_metrics = []

for seed in SEEDS:
    a, rd, thrm = run_trainval_once(seed)
    aucs.append(a)
    run_dirs.append(rd)
    per_seed_metrics.append({"seed": int(seed), "best_val_auroc": float(a), "threshold_metrics_best_epoch": thrm})

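t_crit = 4.302652729911275  # two-sided 95% critical value of Student's t for df=2 (i.e., n=3 seeds)
n = len(aucs)
if n != 3:
    # Purpose: flag that the hard-coded critical value no longer matches the number of seeds
    print(f"WARNING: t_crit assumes n=3 seeds (df=2) but n={n}; the 95% CI below is not valid for this n.")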
mean_auc = float(np.mean(aucs))
std_auc = float(np.std(aucs, ddof=1)) if n > 1 else 0.0
half_width = float(t_crit * (std_auc / math.sqrt(n))) if n > 1 else 0.0

print("\nAUROC by seed:")
for s, a in zip(SEEDS, aucs):
    print(f"  seed {s}: {a:.6f}")
print(f"\nMean AUROC: {mean_auc:.6f}")
print(f"95% CI (t, n=3): [{mean_auc - half_width:.6f}, {mean_auc + half_width:.6f}]")

exp_summary = {
    "dataset": dataset_id,
    "dx_out_root": DX_OUT_ROOT,
    "experiment_tag": EXPERIMENT_TAG,
    "run_stamp": RUN_STAMP,
    "exp_root": str(EXP_ROOT),
    "run_dirs": run_dirs,
    "seeds": SEEDS,
    "aurocs": [float(x) for x in aucs],
    "mean_auroc": float(mean_auc),
    "t_crit_df2_95": float(t_crit),
    "half_width_95_ci": float(half_width),
    "ci95": [float(mean_auc - half_width), float(mean_auc + half_width)],
    "n_train": int(len(train_df)),
    "n_val": int(len(val_df)),
    "label_counts_train": train_df["label_num"].value_counts(dropna=False).to_dict(),
    "label_counts_val": val_df["label_num"].value_counts(dropna=False).to_dict(),
    "effective_batch_size": int(PER_DEVICE_BS * GRAD_ACCUM),
    "per_device_batch_size": int(PER_DEVICE_BS),
    "grad_accum_steps": int(GRAD_ACCUM),
    "val_threshold": float(VAL_THRESHOLD),
    "backbone_ckpt": BACKBONE_CKPT,
    "dropout_p": float(DROPOUT_P),
    "lr": float(LR),
    "per_seed_best_epoch_metrics": per_seed_metrics,
}

summary_path = EXP_ROOT / "summary_trainval.json"
with open(summary_path, "w", encoding="utf-8") as f:
    json.dump(exp_summary, f, indent=2)

with open(HISTORY_INDEX_PATH, "a", encoding="utf-8") as f:
    f.write(json.dumps(exp_summary) + "\n")

print("\nWROTE per-experiment summary:", str(summary_path))
print("APPENDED global history index:", str(HISTORY_INDEX_PATH))
print("\nOpen this folder to access artifacts:", str(EXP_ROOT))

# -------------------------
# Stop runtime
# Purpose: release the GPU machine after outputs are written
# -------------------------
print("\nAll done. Unassigning the runtime to stop the L4 instance...")
try:
    from google.colab import runtime  # type: ignore
    print("Calling runtime.unassign() now.")
    runtime.unassign()
except Exception as e:
    print("Could not unassign runtime automatically. You can stop the runtime manually in Colab.")
    print("Reason:", repr(e))

# Monolingual Testing of the D1, D2, D4, D5 and D6 Models on Their Own Datasets

The following cell evaluates the D1 (NeuroVoz Spanish) model on the D1 test split in a test-only setting, without any retraining. It reads `manifests/manifest_all.csv` from the D1 `preprocessed_v1` folder, using `DX_OUT_ROOT` if it is already defined at runtime and otherwise falling back to the default D1 path. Only the validation and test splits are used: validation is used to choose a decision threshold, and test is used to report final results. Two safety checks are enforced. First, the dataset identifier inferred from the manifest's `dataset` column must be `D1`; otherwise the cell stops to avoid evaluating the wrong dataset. Second, after locating the most recent training run, the cell confirms that `best_heads.pt` exists for all three seeds before continuing.

Before evaluation starts, the cell checks the runtime environment for local files that would shadow PyTorch or Transformers. It mounts Google Drive if needed, fixes the evaluation settings (a 16 kHz sample rate, three seeds, the batch size, and mixed precision when a GPU is available), and prints key paths and hardware details. The manifest schema is validated to ensure the required columns are present, basic counts for the test split are printed, and a fail-fast scan confirms that all audio files referenced by the validation and test rows exist on disk.

The cell then prepares inputs for inference. Each clip is assigned to a task group using a strict rule: `task == "vowl"` is treated as vowel speech, and everything else is treated as other speech. Sex values are normalized to M, F, or UNK. For D1, sex is encoded numerically in the manifest, so 0 is mapped to F and 1 is mapped to M; common text labels are also supported. A dataset loader reads each audio file, converts multichannel audio to mono by averaging the channels, enforces a 16 kHz sample rate (raising an error on a mismatch rather than resampling), and builds an attention mask. For vowel clips, the mask zeroes every sample after the last one whose absolute amplitude exceeds `TINY_THRESH` (1e-4), so trailing silence or zero padding is left out of the pooled features. A collator pads audio in each batch to the same length and pads the attention masks in the same way. A short warm-up step loads a few batches from both the validation and test splits to catch data issues early.
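
The trailing-silence rule for vowel clips can be sketched in a few lines of NumPy. This is an illustrative, vectorized variant rather than the cell's own loop: the function name `trailing_silence_mask` is hypothetical, and the default threshold simply mirrors the `TINY_THRESH = 1e-4` setting.

```python
import numpy as np

def trailing_silence_mask(y: np.ndarray, thresh: float = 1e-4) -> np.ndarray:
    """Return a 0/1 mask that zeroes samples after the last one with |y| > thresh.

    If no sample exceeds the threshold, the mask stays all ones (nothing is dropped).
    Leading silence is kept; only the tail is masked.
    """
    attn = np.ones(len(y), dtype=np.int64)
    loud = np.flatnonzero(np.abs(y) > thresh)  # indices of samples above the threshold
    if loud.size > 0:
        attn[loud[-1] + 1:] = 0
    return attn

# Toy waveform (made-up values): two audible samples followed by a near-silent tail.
y = np.array([0.0, 0.3, -0.2, 0.00001, 0.0], dtype=np.float32)
mask = trailing_silence_mask(y)
print(mask.tolist())  # [1, 1, 1, 0, 0]
```

An all-zero clip keeps a mask of ones, which matches the fallback branch in the dataset class and avoids an empty mask during pooling.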

The model architecture matches the D1 training setup. It uses a frozen Wav2Vec2 backbone with two separate heads, one for vowel clips and one for other speech. Each head includes a small preprocessing block made of LayerNorm and Dropout followed by a linear classifier, and each clip is routed to the correct head based on its task group. The cell automatically finds the most recent train and validation experiment that contains `best_heads.pt` for all three seeds and loads the saved head weights for each seed.
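
The per-clip routing between the two heads can be illustrated without any deep learning framework. The sketch below uses NumPy arrays in place of tensors; `route_logits` is a hypothetical helper name and the logit values are made up for the example.

```python
import numpy as np

def route_logits(logits_vowel, logits_other, task_groups):
    """Pick, per clip, the vowel-head logits if task_group == "vowel", else the other-head logits."""
    is_vowel = np.array([tg == "vowel" for tg in task_groups], dtype=bool)
    # Broadcast the [B] boolean mask over the class dimension of the [B, 2] logit arrays.
    return np.where(is_vowel[:, None], logits_vowel, logits_other)

# Made-up logits for a batch of three clips, one row per clip, columns = (Healthy, PD).
logits_vowel = np.array([[0.2, 1.5], [0.9, -0.3], [1.1, 0.4]])
logits_other = np.array([[-1.0, 0.7], [0.1, 0.6], [2.0, -2.0]])

routed = route_logits(logits_vowel, logits_other, ["vowel", "other", "vowel"])
print(routed.tolist())  # [[0.2, 1.5], [0.1, 0.6], [1.1, 0.4]]
```

Both heads are evaluated for every clip and the mask selects one result per row, which keeps the batch in a single tensor instead of splitting it by task.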

For each seed, evaluation is performed in two stages. First, a validation pass runs inference on the validation split, computes validation AUROC, and selects a validation-optimal threshold using Youden’s J statistic, which maximizes the difference between the true positive rate and the false positive rate. Second, a test pass runs inference on the test split and computes test AUROC and threshold-based metrics using the threshold selected on validation, not a fixed value of 0.5. If the validation split cannot produce a threshold, for example because it contains only one class, the seed falls back to a threshold of 0.5 and records a note explaining the reason.
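
As a rough illustration of threshold selection with Youden’s J, the sketch below sweeps the observed scores as candidate thresholds and keeps the one with the largest TPR minus FPR. It is a NumPy-only approximation under stated assumptions: the helper name `youden_j_threshold`, the choice of candidate thresholds, and the tie-breaking (first maximum wins) are illustrative and may differ from the cell's own logic, and the labels and scores are made up.

```python
import numpy as np

def youden_j_threshold(y_true, y_prob, fallback=0.5):
    """Return (threshold, note). Falls back when the split contains a single class."""
    y_true = np.asarray(y_true, dtype=np.int64)
    y_prob = np.asarray(y_prob, dtype=np.float64)
    n_pos = int((y_true == 1).sum())
    n_neg = int((y_true == 0).sum())
    if n_pos == 0 or n_neg == 0:
        return float(fallback), "single-class split; used fallback threshold"

    best_thr, best_j = float(fallback), -np.inf
    # Candidate thresholds are the observed scores; a clip is predicted PD when prob >= thr.
    for thr in np.unique(y_prob):
        pred = y_prob >= thr
        tpr = float((pred & (y_true == 1)).sum()) / n_pos
        fpr = float((pred & (y_true == 0)).sum()) / n_neg
        j = tpr - fpr
        if j > best_j:
            best_thr, best_j = float(thr), j
    return best_thr, None

# Made-up validation labels and PD probabilities.
val_true = [0, 0, 0, 0, 1, 1, 1, 1]
val_prob = [0.1, 0.2, 0.3, 0.75, 0.4, 0.7, 0.8, 0.9]
thr, note = youden_j_threshold(val_true, val_prob)
print(thr, note)  # 0.4 None
```

In this toy split the threshold 0.4 captures all four PD clips while admitting one of four healthy clips, giving J = 1.0 - 0.25 = 0.75, the largest value in the sweep.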

For each seed, the cell writes a full set of results under
`<DX_OUT_ROOT>/monolingual_test_runs/run_D1_seed<seed>/`.
This includes `metrics.json`, a ROC curve image, an overall confusion matrix at the selected threshold, and additional confusion matrices split by sex for M and F when available. Fairness is computed on the test split at the same threshold using the project’s H3 definition. This is based on the false negative rate for male and female Parkinson’s cases only, with ΔFNR defined as FNR(F) minus FNR(M). The absolute difference is also stored. Confusion counts by sex group, including UNK, are recorded as well.
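
The fairness quantity described above can be sketched as follows. The helper name `fnr_by_sex` and the toy labels are hypothetical; the sketch only assumes what the text states, namely that the false negative rate is computed over true Parkinson’s cases for each sex and that ΔFNR is FNR(F) minus FNR(M).

```python
import numpy as np

def fnr_by_sex(y_true, y_pred, sex):
    """False negative rate among true PD cases (label 1), computed separately for F and M."""
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    sex = np.asarray(sex)
    out = {}
    for group in ("F", "M"):
        pd_cases = (sex == group) & (y_true == 1)       # true PD clips in this group
        n_cases = int(pd_cases.sum())
        n_missed = int((pd_cases & (y_pred == 0)).sum())  # PD clips predicted Healthy
        out[group] = n_missed / n_cases if n_cases > 0 else float("nan")
    return out

# Made-up test labels, thresholded predictions, and normalized sex values.
y_true = [1, 1, 1, 1, 1, 1, 0, 0]
y_pred = [1, 0, 1, 1, 0, 1, 0, 1]
sex    = ["F", "F", "M", "M", "M", "M", "F", "M"]

fnr = fnr_by_sex(y_true, y_pred, sex)
delta_fnr = fnr["F"] - fnr["M"]
print(fnr, delta_fnr, abs(delta_fnr))  # {'F': 0.5, 'M': 0.25} 0.25 0.25
```

Healthy clips and clips with sex `UNK` do not enter either rate; a group with no PD cases yields NaN rather than a misleading zero.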

After all three seeds finish, the cell aggregates results across seeds and prints them. This includes the mean test AUROC with a 95% confidence interval computed from a t distribution with n = 3 seeds (2 degrees of freedom), the mean and standard deviation of the validation-selected thresholds, the mean and standard deviation of the threshold-based test metrics, and the mean and standard deviation of the fairness values. A combined `summary_test.json` file is written to `monolingual_test_runs/`, and the same summary is appended to `history_index.jsonl` so multiple runs can be tracked over time. Finally, the Colab runtime is unassigned to stop the GPU instance.
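
The three-seed confidence interval can be sketched with the standard library alone. The critical value is the two-sided 95% Student's t value for 2 degrees of freedom used elsewhere in this notebook; the function name `mean_ci95_three_seeds` is hypothetical and the AUROC values are made up for illustration, not results from the study.

```python
import math
import statistics

T_CRIT_DF2_95 = 4.302652729911275  # two-sided 95% critical value of Student's t, df = 2

def mean_ci95_three_seeds(values):
    """Return (mean, (low, high)) for exactly three per-seed values."""
    if len(values) != 3:
        raise ValueError("The critical value is hard-coded for n = 3 (df = 2).")
    mean = statistics.fmean(values)
    sd = statistics.stdev(values)                 # sample standard deviation (ddof = 1)
    half_width = T_CRIT_DF2_95 * sd / math.sqrt(3)
    return mean, (mean - half_width, mean + half_width)

aurocs = [0.80, 0.82, 0.84]  # made-up per-seed AUROCs
mean_auc, (ci_low, ci_high) = mean_ci95_three_seeds(aurocs)
print(f"mean={mean_auc:.4f}  95% CI=[{ci_low:.4f}, {ci_high:.4f}]")
```

With only three seeds the critical value is large (about 4.30 instead of 1.96 for a normal approximation), so even a small spread across seeds produces a wide interval.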

In [None]:
# =========================
# D1 Test Only — Val-Optimal Threshold + Fairness (Youden J)
# Inputs: manifest_all.csv (VAL for threshold, TEST for reporting), best_heads.pt from latest trainval run
# Outputs: per-seed metrics.json + ROC/confusion PNGs, plus monolingual_test_runs/summary_test.json and history_index.jsonl
# =========================

import os, json, math, random, time, warnings
from pathlib import Path

import numpy as np
import pandas as pd
import soundfile as sf

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

from transformers import Wav2Vec2Model

from sklearn.metrics import roc_auc_score, roc_curve, confusion_matrix, matthews_corrcoef
from scipy.stats import fisher_exact
import matplotlib.pyplot as plt

from tqdm.auto import tqdm

# -------------------------
# Runtime sanity checks
# Purpose: avoid accidental imports from local files that shadow installed libraries
# -------------------------
if os.path.exists("/content/torch.py") or os.path.exists("/content/torch/__init__.py"):
    raise RuntimeError("Found local /content/torch.py or /content/torch/ that shadows PyTorch. Rename/remove it and restart runtime.")
if os.path.exists("/content/transformers.py") or os.path.exists("/content/transformers/__init__.py"):
    raise RuntimeError("Found local /content/transformers.py or /content/transformers/ that shadows Hugging Face Transformers. Rename/remove it and restart runtime.")

# -------------------------
# Drive mount
# Purpose: ensure files under MyDrive can be read and written
# -------------------------
if not os.path.isdir("/content/drive/MyDrive"):
    from google.colab import drive  # type: ignore
    drive.mount("/content/drive")

# -------------------------
# Paths
# Inputs: DX_OUT_ROOT (preferred if already defined), otherwise D1 fallback
# Output: MANIFEST_ALL path used throughout the cell
# -------------------------
D1_OUT_ROOT_FALLBACK = "/content/drive/MyDrive/AI_PD_Project/Datasets/D1-NeuroVoz-Castillan Spanish/preprocessed_v1"
DX_OUT_ROOT = globals().get("DX_OUT_ROOT", D1_OUT_ROOT_FALLBACK)
MANIFEST_ALL = f"{DX_OUT_ROOT}/manifests/manifest_all.csv"

# -------------------------
# Fixed run settings
# Purpose: match trainval assumptions and keep reporting consistent across seeds
# -------------------------
SEEDS          = [1337, 2024, 7777]
BACKBONE_CKPT  = "facebook/wav2vec2-base"
SR_EXPECTED    = 16000
TINY_THRESH    = 1e-4

EFFECTIVE_BS   = 64
PER_DEVICE_BS  = 16
GRAD_ACCUM     = max(1, EFFECTIVE_BS // PER_DEVICE_BS)

DROPOUT_P      = 0.2

NUM_WORKERS    = 0
PIN_MEMORY     = False

USE_AMP        = True

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
try:
    torch.set_float32_matmul_precision("high")
except Exception:
    pass

warnings.filterwarnings("ignore", message="None of the inputs have requires_grad=True. Gradients will be None")
warnings.filterwarnings("ignore", message="Passing `gradient_checkpointing` to a config initialization is deprecated")

print("DX_OUT_ROOT:", DX_OUT_ROOT)
print("MANIFEST_ALL:", MANIFEST_ALL)
print("DEVICE:", DEVICE)
if DEVICE.type == "cuda":
    print("GPU:", torch.cuda.get_device_name(0))
print("PER_DEVICE_BS:", PER_DEVICE_BS, "| EFFECTIVE_BS:", PER_DEVICE_BS * GRAD_ACCUM)
print("NUM_WORKERS:", NUM_WORKERS, "| PIN_MEMORY:", PIN_MEMORY)
print("USE_AMP:", bool(USE_AMP and DEVICE.type == "cuda"))

# -------------------------
# Read manifest and build VAL/TEST tables
# Inputs: manifest_all.csv
# Outputs: val_df and test_df with a standardized set of columns
# -------------------------
if not os.path.exists(MANIFEST_ALL):
    raise FileNotFoundError(f"Missing manifest_all.csv: {MANIFEST_ALL}")

m_all = pd.read_csv(MANIFEST_ALL)

# Purpose: enforce the minimum schema needed for metrics + fairness reporting
req_cols = {"split", "clip_path", "label_num", "task", "sex", "age"}
missing = [c for c in sorted(req_cols) if c not in m_all.columns]
if missing:
    raise ValueError(f"Manifest missing required columns: {missing}. Found: {list(m_all.columns)}")

# Purpose: identify dataset id from the manifest, then filter to that dataset
if "dataset" in m_all.columns and m_all["dataset"].notna().any():
    dataset_id = str(m_all["dataset"].astype(str).value_counts(dropna=True).idxmax())
    m_all = m_all[m_all["dataset"].astype(str) == dataset_id].copy()
else:
    dataset_id = "DX"

# Guard A
# Purpose: prevent running D1 evaluation on the wrong dataset by accident
if dataset_id != "D1":
    raise RuntimeError(
        f"Expected dataset_id=='D1' but got {dataset_id!r}. "
        "This usually means DX_OUT_ROOT was inherited from a previous cell or the manifest is not D1. "
        f"DX_OUT_ROOT={DX_OUT_ROOT}"
    )

# Purpose: keep a consistent column set even if some optional columns are missing
keep_cols = ["clip_path", "label_num", "task", "speaker_id", "sex", "age", "duration_sec", "split"]
for c in keep_cols:
    if c not in m_all.columns:
        m_all[c] = np.nan
m_all = m_all[keep_cols].copy()

val_df  = m_all[m_all["split"].isin(["val"])].reset_index(drop=True)
test_df = m_all[m_all["split"].isin(["test"])].reset_index(drop=True)

print(f"\nDataset inferred: {dataset_id}")
print(f"VAL rows:  {len(val_df)}")
print(f"TEST rows: {len(test_df)}")
print("TEST label counts:", test_df["label_num"].value_counts(dropna=False).to_dict())
print("TEST sex counts (raw):", test_df["sex"].value_counts(dropna=False).to_dict())

# Purpose: fail early if required splits are missing
if len(val_df) == 0:
    raise RuntimeError("After filtering to split=='val', manifest has 0 rows (VAL is required for Youden-J threshold).")
if len(test_df) == 0:
    raise RuntimeError("After filtering to split=='test', manifest has 0 rows.")

# -------------------------
# Clip existence check
# Inputs: val_df/test_df clip_path
# Output: raises early if any required audio files are missing (shows a few examples)
# -------------------------
def _fail_fast_missing_paths(df: pd.DataFrame, name: str):
    missing_paths = []
    for p in tqdm(df["clip_path"].astype(str).tolist(), desc=f"Check {name} clip_path exists", dynamic_ncols=True):
        if not os.path.exists(p):
            missing_paths.append(p)
            if len(missing_paths) >= 10:
                break
    if missing_paths:
        raise FileNotFoundError(f"{name}: missing clip_path(s). Examples: {missing_paths}")

_fail_fast_missing_paths(val_df, "VAL")
_fail_fast_missing_paths(test_df, "TEST")

# -------------------------
# Task grouping
# Purpose: decide which head to use (vowel head vs other head)
# Output: task_group column in both VAL and TEST
# -------------------------
def _task_group(task_val) -> str:
    return "vowel" if str(task_val) == "vowl" else "other"

val_df["task_group"]  = val_df["task"].apply(_task_group)
test_df["task_group"] = test_df["task"].apply(_task_group)

# -------------------------
# Sex normalization (D1 encoding)
# Inputs: manifest sex column (0/1 in D1)
# Outputs: sex_norm in {M, F, UNK} for fairness and sex-split charts
# -------------------------
def normalize_sex(val) -> str:
    """
    Returns 'M', 'F', or 'UNK'

    D1 manifest encoding:
      0 -> F
      1 -> M

    Also handles common string encodings: M/F, Male/Female, etc.
    """
    if pd.isna(val):
        return "UNK"

    # numeric (float/int-like)
    try:
        fv = float(val)
        if np.isfinite(fv) and abs(fv - round(fv)) < 1e-9:
            iv = int(round(fv))
            if iv == 0:
                return "F"
            if iv == 1:
                return "M"
    except Exception:
        pass

    # string
    s = str(val).strip().lower()
    if s in {"m", "male", "man", "masc", "masculine"}:
        return "M"
    if s in {"f", "female", "woman", "fem", "feminine"}:
        return "F"

    return "UNK"

val_df["sex_norm"]  = val_df["sex"].apply(normalize_sex)
test_df["sex_norm"] = test_df["sex"].apply(normalize_sex)

print("VAL sex counts (normalized): ", val_df["sex_norm"].value_counts(dropna=False).to_dict())
print("TEST sex counts (normalized):", test_df["sex_norm"].value_counts(dropna=False).to_dict())
if (val_df["sex_norm"] == "UNK").any() or (test_df["sex_norm"] == "UNK").any():
    print("NOTE: Some 'sex' values could not be normalized to M/F and were counted as 'UNK' for fairness and sex charts.")

# -------------------------
# Audio dataset + padding collator
# Inputs: val_df/test_df
# Outputs: batches with input_values, attention_mask, labels, task_group, sex_norm
# -------------------------
class AudioManifestDataset(Dataset):
    def __init__(self, df: pd.DataFrame):
        self.df = df.reset_index(drop=True)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx: int):
        row = self.df.iloc[idx]
        clip_path = str(row["clip_path"])
        label = int(row["label_num"])
        task_group = str(row["task_group"])
        sex_norm = str(row["sex_norm"])

        if not os.path.exists(clip_path):
            raise FileNotFoundError(f"Missing clip_path: {clip_path}")

        # Purpose: load audio, force mono, enforce expected sample rate
        y, sr = sf.read(clip_path, always_2d=False)
        if isinstance(y, np.ndarray) and y.ndim > 1:
            y = np.mean(y, axis=1)
        y = np.asarray(y, dtype=np.float32)

        if int(sr) != SR_EXPECTED:
            raise ValueError(f"Sample rate mismatch (expected {SR_EXPECTED}): {clip_path} has sr={sr}")

        # Purpose: for vowels, ignore padded trailing silence when pooling
        attn = np.ones((len(y),), dtype=np.int64)
        if task_group == "vowel":
            k = -1
            for j in range(len(y) - 1, -1, -1):
                if abs(float(y[j])) > TINY_THRESH:
                    k = j
                    break
            if k >= 0:
                attn[:k+1] = 1
                attn[k+1:] = 0
            else:
                attn[:] = 1
        else:
            attn[:] = 1

        return {
            "input_values": torch.from_numpy(y),
            "attention_mask": torch.from_numpy(attn),
            "labels": torch.tensor(label, dtype=torch.long),
            "task_group": task_group,
            "sex_norm": sex_norm,
        }

def collate_fn(batch):
    # Purpose: pad variable-length audio to the max length in the batch
    max_len = int(max(b["input_values"].numel() for b in batch))
    input_vals, attn_masks, labels, task_groups, sex_norms = [], [], [], [], []
    for b in batch:
        x = b["input_values"]
        a = b["attention_mask"]
        pad = max_len - x.numel()
        if pad > 0:
            x = torch.cat([x, torch.zeros(pad, dtype=x.dtype)], dim=0)
            a = torch.cat([a, torch.zeros(pad, dtype=a.dtype)], dim=0)
        input_vals.append(x)
        attn_masks.append(a)
        labels.append(b["labels"])
        task_groups.append(b["task_group"])
        sex_norms.append(b["sex_norm"])
    return {
        "input_values": torch.stack(input_vals, dim=0),
        "attention_mask": torch.stack(attn_masks, dim=0),
        "labels": torch.stack(labels, dim=0),
        "task_group": task_groups,
        "sex_norm": sex_norms,
    }

# -------------------------
# Model (frozen backbone + two heads)
# Inputs: wav2vec2 checkpoint and dropout
# Output: logits for PD vs Healthy using the head chosen by task_group
# -------------------------
class Wav2Vec2TwoHeadClassifier(nn.Module):
    def __init__(self, ckpt: str, dropout_p: float):
        super().__init__()
        self.backbone = Wav2Vec2Model.from_pretrained(ckpt, use_safetensors=True)
        for p in self.backbone.parameters():
            p.requires_grad = False

        H = int(self.backbone.config.hidden_size)
        self.pre_vowel = nn.Sequential(nn.LayerNorm(H), nn.Dropout(float(dropout_p)))
        self.pre_other = nn.Sequential(nn.LayerNorm(H), nn.Dropout(float(dropout_p)))
        self.head_vowel = nn.Linear(H, 2)
        self.head_other = nn.Linear(H, 2)

    def masked_mean_pool(self, last_hidden, attn_mask_samples):
        # Purpose: pool frame features using a mask derived from the sample-level attention mask
        feat_mask = self.backbone._get_feature_vector_attention_mask(last_hidden.shape[1], attn_mask_samples)
        feat_mask = feat_mask.to(last_hidden.device).unsqueeze(-1).type_as(last_hidden)
        masked = last_hidden * feat_mask
        denom = feat_mask.sum(dim=1).clamp(min=1.0)
        return masked.sum(dim=1) / denom

    def _heads_fp32(self, x_fp_any: torch.Tensor, head: nn.Module) -> torch.Tensor:
        # Purpose: run small heads in fp32 even when AMP is enabled
        x = x_fp_any.float()
        if x.is_cuda:
            with torch.amp.autocast(device_type="cuda", enabled=False):
                return head(x)
        return head(x)

    def forward_logits(self, input_values, attention_mask, task_group):
        with torch.no_grad():
            out = self.backbone(input_values=input_values, attention_mask=attention_mask)
            last_hidden = out.last_hidden_state

        pooled = self.masked_mean_pool(last_hidden, attention_mask)  # [B,H]

        z_v = self.pre_vowel(pooled.float())
        z_o = self.pre_other(pooled.float())

        logits_v = self._heads_fp32(z_v, self.head_vowel)
        logits_o = self._heads_fp32(z_o, self.head_other)

        # Purpose: choose the correct head per sample based on task_group
        is_vowel = torch.tensor([tg == "vowel" for tg in task_group], device=pooled.device, dtype=torch.bool)
        logits = logits_o.clone()
        if is_vowel.any():
            logits[is_vowel] = logits_v[is_vowel]
        return logits

# -------------------------
# Metrics and plots
# Inputs: y_true and y_prob (PD probability)
# Outputs: AUROC, threshold metrics, ROC/confusion PNG files
# -------------------------
def compute_auc(y_true, y_prob):
    y_true = np.asarray(y_true, dtype=np.int64)
    y_prob = np.asarray(y_prob, dtype=np.float64)
    if len(np.unique(y_true)) < 2:
        return float("nan")
    return float(roc_auc_score(y_true, y_prob))

def compute_threshold_metrics(y_true, y_prob, thr=0.5):
    # Purpose: compute threshold-based metrics and confusion counts
    y_true = np.asarray(y_true, dtype=np.int64)
    y_prob = np.asarray(y_prob, dtype=np.float64)
    y_pred = (y_prob >= float(thr)).astype(np.int64)

    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
    TN, FP, FN, TP = int(cm[0, 0]), int(cm[0, 1]), int(cm[1, 0]), int(cm[1, 1])

    eps = 1e-12
    acc = (TP + TN) / max(1, (TP + TN + FP + FN))
    prec = TP / (TP + FP + eps)
    rec = TP / (TP + FN + eps)     # sensitivity
    f1 = 2 * prec * rec / (prec + rec + eps)
    spec = TN / (TN + FP + eps)

    try:
        mcc = float(matthews_corrcoef(y_true, y_pred)) if len(np.unique(y_pred)) > 1 else float("nan")
    except Exception:
        mcc = float("nan")

    try:
        _, pval = fisher_exact([[TN, FP], [FN, TP]], alternative="two-sided")
        pval = float(pval)
    except Exception:
        pval = float("nan")

    return {
        "threshold": float(thr),
        "confusion_matrix": {"TN": TN, "FP": FP, "FN": FN, "TP": TP},
        "accuracy": float(acc),
        "precision": float(prec),
        "recall": float(rec),
        "f1_score": float(f1),
        "sensitivity": float(rec),
        "specificity": float(spec),
        "mcc": float(mcc),
        "p_value_fisher_two_sided": float(pval),
    }

def save_roc_curve_png(y_true, y_prob, out_png, title_suffix="Test"):
    fpr, tpr, _ = roc_curve(y_true, y_prob)
    plt.figure()
    plt.plot(fpr, tpr)
    plt.plot([0, 1], [0, 1], linestyle="--")
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    plt.title(f"ROC Curve ({title_suffix})")
    plt.tight_layout()
    plt.savefig(out_png, dpi=150)
    plt.close()

def save_confusion_png(y_true, y_prob, out_png, thr=0.5, title_suffix="Test"):
    y_true = np.asarray(y_true, dtype=np.int64)
    y_prob = np.asarray(y_prob, dtype=np.float64)
    y_pred = (y_prob >= float(thr)).astype(int)
    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])

    plt.figure()
    plt.imshow(cm)
    plt.xticks([0, 1], ["Healthy(0)", "PD(1)"])
    plt.yticks([0, 1], ["Healthy(0)", "PD(1)"])
    plt.xlabel("Predicted")
    plt.ylabel("True")
    plt.title(f"Confusion Matrix ({title_suffix}, thr={thr:.2f})")
    for (i, j), v in np.ndenumerate(cm):
        plt.text(j, i, str(v), ha="center", va="center")
    plt.tight_layout()
    plt.savefig(out_png, dpi=150)
    plt.close()

# -------------------------
# Val-optimal threshold (Youden J)
# Input: VAL labels and probabilities
# Output: threshold that maximizes (TPR - FPR) on the VAL ROC curve
# -------------------------
def youden_j_optimal_threshold(y_true, y_prob):
    """
    Returns:
      thr_opt, j_opt, tpr_opt, fpr_opt
    If undefined (single-class), returns (nan, nan, nan, nan).
    """
    y_true = np.asarray(y_true, dtype=np.int64)
    y_prob = np.asarray(y_prob, dtype=np.float64)
    if len(np.unique(y_true)) < 2:
        return float("nan"), float("nan"), float("nan"), float("nan")

    fpr, tpr, thr = roc_curve(y_true, y_prob)
    j = tpr - fpr
    if len(j) == 0:
        return float("nan"), float("nan"), float("nan"), float("nan")

    best_idx = int(np.nanargmax(j))
    return float(thr[best_idx]), float(j[best_idx]), float(tpr[best_idx]), float(fpr[best_idx])

# -------------------------
# Fairness (H3): FNR by sex and ΔFNR
# Inputs: TEST labels, probabilities, sex_norm, and chosen threshold
# Outputs: per-group FNR and ΔFNR = FNR(F) - FNR(M)
# -------------------------
def compute_fnr_by_group_signed(y_true, y_prob, groups, thr=0.5):
    """
    FNR = FN/(FN+TP) computed on POSITIVE ground-truth samples only.
    Returns:
      - per-group: n_total, n_pos, tp, fn, fnr
      - delta_f_minus_m (signed): FNR(F) - FNR(M) when both defined
      - delta_abs_m_f: |FNR(F) - FNR(M)|
    """
    y_true = np.asarray(y_true, dtype=np.int64)
    y_prob = np.asarray(y_prob, dtype=np.float64)
    groups = np.asarray(list(groups), dtype=object)
    y_pred = (y_prob >= float(thr)).astype(np.int64)

    out = {}
    for g in sorted(set(groups.tolist())):
        mask_g = (groups == g)
        if mask_g.sum() == 0:
            continue

        pos_mask = mask_g & (y_true == 1)
        n_pos = int(pos_mask.sum())
        if n_pos == 0:
            out[g] = {"n_total": int(mask_g.sum()), "n_pos": 0, "tp": 0, "fn": 0, "fnr": float("nan")}
            continue

        tp = int(((y_pred == 1) & pos_mask).sum())
        fn = int(((y_pred == 0) & pos_mask).sum())
        fnr = float(fn / max(1, (fn + tp)))

        out[g] = {"n_total": int(mask_g.sum()), "n_pos": int(n_pos), "tp": int(tp), "fn": int(fn), "fnr": float(fnr)}

    fnr_m = out.get("M", {}).get("fnr", float("nan"))
    fnr_f = out.get("F", {}).get("fnr", float("nan"))

    if (not np.isnan(fnr_m)) and (not np.isnan(fnr_f)):
        delta_signed = float(fnr_f - fnr_m)         # H3: FNR(F) - FNR(M)
        delta_abs = float(abs(delta_signed))
    else:
        delta_signed = float("nan")
        delta_abs = float("nan")

    return out, delta_signed, delta_abs

# -------------------------
# Confusion counts by sex
# Inputs: TEST labels, probabilities, sex_norm, threshold
# Output: TN/FP/FN/TP counts per sex group
# -------------------------
def compute_confusion_counts(y_true, y_prob, thr=0.5):
    y_true = np.asarray(y_true, dtype=np.int64)
    y_prob = np.asarray(y_prob, dtype=np.float64)
    y_pred = (y_prob >= float(thr)).astype(np.int64)
    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
    TN, FP, FN, TP = int(cm[0, 0]), int(cm[0, 1]), int(cm[1, 0]), int(cm[1, 1])
    return {"TN": TN, "FP": FP, "FN": FN, "TP": TP}

def compute_confusion_by_group(y_true, y_prob, groups, thr=0.5):
    y_true = np.asarray(y_true, dtype=np.int64)
    y_prob = np.asarray(y_prob, dtype=np.float64)
    groups = np.asarray(list(groups), dtype=object)
    out = {}
    for g in sorted(set(groups.tolist())):
        mask = (groups == g)
        if int(mask.sum()) == 0:
            continue
        out[g] = {
            "n": int(mask.sum()),
            "confusion": compute_confusion_counts(y_true[mask], y_prob[mask], thr=thr),
        }
    return out

# -------------------------
# Reproducibility seeding
# Purpose: stabilize dataloader order and any stochastic ops
# -------------------------
def set_all_seeds(seed: int):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# -------------------------
# Pick latest trainval experiment with all required head checkpoints
# Inputs: trainval_runs/exp_* folders
# Output: chosen_exp path
# -------------------------
TRAINVAL_ROOT = Path(DX_OUT_ROOT) / "trainval_runs"
if not TRAINVAL_ROOT.exists():
    raise FileNotFoundError(f"Missing trainval_runs folder: {str(TRAINVAL_ROOT)}")

exp_dirs = sorted([p for p in TRAINVAL_ROOT.glob("exp_*") if p.is_dir()], key=lambda p: p.stat().st_mtime, reverse=True)
if not exp_dirs:
    raise FileNotFoundError(f"No exp_* folders found under: {str(TRAINVAL_ROOT)}")

def _has_all_seeds(exp_path: Path, dataset_id: str, seeds: list):
    for s in seeds:
        p = exp_path / f"run_{dataset_id}_seed{s}" / "best_heads.pt"
        if not p.exists():
            return False
    return True

chosen_exp = None
for ed in exp_dirs:
    if _has_all_seeds(ed, dataset_id, SEEDS):
        chosen_exp = ed
        break

if chosen_exp is None:
    sample = exp_dirs[0]
    raise FileNotFoundError(
        f"Could not find a recent trainval experiment with best_heads.pt for all {len(SEEDS)} seeds.\n"
        f"Expected under: {str(TRAINVAL_ROOT)}/exp_*/run_{dataset_id}_seedXXXX/best_heads.pt\n"
        f"Most recent exp checked: {str(sample)}"
    )

print("\nUsing Train+Val experiment folder:")
print(" ", str(chosen_exp))

# Guard B
# Purpose: confirm artifacts exist right after selecting the experiment folder
for s in SEEDS:
    p = chosen_exp / f"run_{dataset_id}_seed{s}" / "best_heads.pt"
    if not p.exists():
        raise RuntimeError(f"Trainval artifact missing after choosing exp. Missing: {str(p)}")

# -------------------------
# Output folder for this evaluation
# Output: per-seed run_... folders and aggregated summary files
# -------------------------
TEST_ROOT = Path(DX_OUT_ROOT) / "monolingual_test_runs"
TEST_ROOT.mkdir(parents=True, exist_ok=True)

# -------------------------
# Dataloaders for VAL and TEST
# Inputs: val_df/test_df
# Output: val_loader/test_loader used by inference
# -------------------------
val_ds = AudioManifestDataset(val_df)
val_loader = DataLoader(
    val_ds,
    batch_size=PER_DEVICE_BS,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=PIN_MEMORY,
    collate_fn=collate_fn,
)

test_ds = AudioManifestDataset(test_df)
test_loader = DataLoader(
    test_ds,
    batch_size=PER_DEVICE_BS,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=PIN_MEMORY,
    collate_fn=collate_fn,
)

# -------------------------
# Loader warm-up
# Purpose: catch dataloader issues early (empty batches, IO problems)
# -------------------------
print("\nWarm-up: loading up to 3 VAL + TEST batches...")
t0 = time.time()

def _warmup(loader, name):
    nb = len(loader)
    wb = min(3, nb)
    if wb == 0:
        raise RuntimeError(f"{name} DataLoader has 0 batches. Check df length and PER_DEVICE_BS.")
    it = iter(loader)
    for i in range(wb):
        _ = next(it)
        print(f"  loaded warmup {name} batch {i+1}/{wb}")

_warmup(val_loader, "VAL")
_warmup(test_loader, "TEST")
print(f"Warm-up done in {time.time()-t0:.2f}s")

# -------------------------
# Load trained heads into the model
# Inputs: best_heads.pt written by D1 trainval
# Output: model ready for inference
# -------------------------
def load_heads_into_model(model: Wav2Vec2TwoHeadClassifier, best_heads_path: Path):
    if not best_heads_path.exists():
        raise FileNotFoundError(f"Missing best_heads.pt: {str(best_heads_path)}")
    state = torch.load(str(best_heads_path), map_location="cpu")
    model.pre_vowel.load_state_dict(state["pre_vowel"], strict=True)
    model.pre_other.load_state_dict(state["pre_other"], strict=True)
    model.head_vowel.load_state_dict(state["head_vowel"], strict=True)
    model.head_other.load_state_dict(state["head_other"], strict=True)
    return model

# -------------------------
# Inference helper
# Inputs: loader and model
# Outputs: y_true, y_prob(PD), and sex_norm for each clip
# -------------------------
def _infer_probs(loader, model, desc):
    use_amp = bool(USE_AMP and DEVICE.type == "cuda")
    all_probs, all_true, all_sex = [], [], []

    pbar = tqdm(loader, desc=desc, dynamic_ncols=True)
    with torch.inference_mode():
        for batch in pbar:
            input_values = batch["input_values"].to(DEVICE, non_blocking=False)
            attention_mask = batch["attention_mask"].to(DEVICE, non_blocking=False)
            labels = batch["labels"].to(DEVICE, non_blocking=False)
            task_group = batch["task_group"]
            sex_norm = batch["sex_norm"]

            with torch.amp.autocast(device_type="cuda", enabled=use_amp):
                logits = model.forward_logits(input_values, attention_mask, task_group)

            probs = torch.softmax(logits.float(), dim=-1)[:, 1].detach().cpu().numpy().astype(np.float64)
            all_probs.extend(probs.tolist())
            all_true.extend(labels.detach().cpu().numpy().astype(np.int64).tolist())
            all_sex.extend(list(sex_norm))

    return np.asarray(all_true, dtype=np.int64), np.asarray(all_probs, dtype=np.float64), np.asarray(all_sex, dtype=object)

# -------------------------
# One seed run: VAL threshold, then TEST reporting
# Inputs: seed, val_loader, test_loader, chosen_exp/best_heads.pt
# Outputs: run_dir artifacts + metrics.json, and a compact console summary
# -------------------------
def run_test_once(seed: int):
    set_all_seeds(seed)

    run_dir = TEST_ROOT / f"run_{dataset_id}_seed{seed}"
    run_dir.mkdir(parents=True, exist_ok=True)

    best_heads_path = chosen_exp / f"run_{dataset_id}_seed{seed}" / "best_heads.pt"

    print(f"\n[seed={seed}] Loading model + heads from:")
    print(" ", str(best_heads_path))

    model = Wav2Vec2TwoHeadClassifier(BACKBONE_CKPT, dropout_p=DROPOUT_P).to(DEVICE)
    model = load_heads_into_model(model, best_heads_path)
    model.eval()

    # VAL: select threshold only
    yv_true, yv_prob, _ = _infer_probs(val_loader, model, desc=f"[seed={seed}] Val")
    val_auc = compute_auc(yv_true, yv_prob)
    val_thr, val_j, val_tpr, val_fpr = youden_j_optimal_threshold(yv_true, yv_prob)

    print(f"[seed={seed}] VAL Youden-J threshold: {val_thr:.6f}" if not np.isnan(val_thr) else f"[seed={seed}] VAL Youden-J threshold: nan")

    # TEST: report at fixed val-selected threshold
    yt_true, yt_prob, yt_sex = _infer_probs(test_loader, model, desc=f"[seed={seed}] Test")
    test_auc = compute_auc(yt_true, yt_prob)

    # Purpose: avoid crashing if VAL ROC cannot be formed (single-class VAL)
    if np.isnan(val_thr):
        thr_used = 0.5
        thr_note = "VAL Youden-J threshold was NaN (single-class VAL). Fallback thr_used=0.5 for this seed only."
    else:
        thr_used = float(val_thr)
        thr_note = None

    thr_metrics_test = compute_threshold_metrics(yt_true, yt_prob, thr=thr_used)

    fnr_by_sex, delta_f_minus_m, delta_abs = compute_fnr_by_group_signed(
        yt_true, yt_prob, yt_sex, thr=thr_used
    )

    confusion_by_sex = compute_confusion_by_group(yt_true, yt_prob, yt_sex, thr=thr_used)

    # Plots: overall and sex-split confusion matrices
    roc_png = run_dir / "roc_curve.png"
    cm_png  = run_dir / "confusion_matrix.png"
    save_roc_curve_png(yt_true, yt_prob, str(roc_png), title_suffix=f"Test (seed={seed})")
    save_confusion_png(yt_true, yt_prob, str(cm_png), thr=thr_used, title_suffix=f"Test (seed={seed})")

    cm_m_png = None
    cm_f_png = None
    mask_m = (yt_sex == "M")
    mask_f = (yt_sex == "F")

    if int(mask_m.sum()) > 0:
        cm_m_png = run_dir / "confusion_matrix_M.png"
        save_confusion_png(
            yt_true[mask_m], yt_prob[mask_m], str(cm_m_png),
            thr=thr_used, title_suffix=f"Test SEX=M (seed={seed})"
        )

    if int(mask_f.sum()) > 0:
        cm_f_png = run_dir / "confusion_matrix_F.png"
        save_confusion_png(
            yt_true[mask_f], yt_prob[mask_f], str(cm_f_png),
            thr=thr_used, title_suffix=f"Test SEX=F (seed={seed})"
        )

    # Output: one JSON per seed with the key values needed for later analysis and paper writing
    metrics = {
        "dataset": dataset_id,
        "seed": int(seed),

        "n_val": int(len(val_df)),
        "label_counts_val": val_df["label_num"].value_counts(dropna=False).to_dict(),
        "sex_counts_val_norm": val_df["sex_norm"].value_counts(dropna=False).to_dict(),
        "val_auroc": float(val_auc),

        "val_optimal_threshold": {
            "method": "Youden J (argmax(TPR - FPR) on VAL ROC)",
            "threshold": float(val_thr),
            "youden_j": float(val_j),
            "tpr_at_threshold": float(val_tpr),
            "fpr_at_threshold": float(val_fpr),
        },

        "n_test": int(len(test_df)),
        "label_counts_test": test_df["label_num"].value_counts(dropna=False).to_dict(),
        "sex_counts_test_norm": test_df["sex_norm"].value_counts(dropna=False).to_dict(),
        "test_auroc": float(test_auc),

        "test_threshold_used": float(thr_used),
        "test_threshold_note": thr_note,
        "threshold_metrics_test_at_val_opt": thr_metrics_test,

        "fairness_test_at_val_opt": {
            "definition": "H3: ΔFNR = FNR(F) - FNR(M), where FNR(sex)=FN/(FN+TP) computed on PD-only true labels at test_threshold_used.",
            "threshold_used": float(thr_used),
            "fnr_by_sex_norm": fnr_by_sex,
            "delta_fnr_F_minus_M": float(delta_f_minus_m),
            "delta_fnr_abs": float(delta_abs),
            "note": "If n_PD for a sex is 0, its FNR is NaN and ΔFNR is NaN.",
            "sex_normalization_note": "D1 mapping: 0->F, 1->M; otherwise UNK.",
        },

        "confusion_by_sex_norm_at_val_opt": confusion_by_sex,

        "artifacts": {
            "roc_curve_png": str(roc_png),
            "confusion_matrix_png": str(cm_png),
            "confusion_matrix_M_png": str(cm_m_png) if cm_m_png is not None else None,
            "confusion_matrix_F_png": str(cm_f_png) if cm_f_png is not None else None,
        },

        "dx_out_root": DX_OUT_ROOT,
        "trainval_experiment_used": str(chosen_exp),
        "best_heads_path": str(best_heads_path),
        "backbone_ckpt": BACKBONE_CKPT,
        "dropout_p": float(DROPOUT_P),
        "timestamp": time.strftime("%Y%m%d_%H%M%S"),
    }

    metrics_path = run_dir / "metrics.json"
    with open(metrics_path, "w", encoding="utf-8") as f:
        json.dump(metrics, f, indent=2)

    print(f"[seed={seed}] DONE | test_AUROC={test_auc:.6f}")

    def _fmt_fnr(dct, g):
        d = (dct or {}).get(g, None)
        if d is None:
            return "n/a"
        return f"fnr={d['fnr']:.6f} (n_PD={d['n_pos']}, fn={d['fn']}, tp={d['tp']})"

    print(f"[seed={seed}] TEST @ thr_used={thr_used:.6f}:")
    if thr_note is not None:
        print("  NOTE:", thr_note)

    print(f"[seed={seed}] FAIRNESS (H3) @ thr_used={thr_used:.6f}:")
    print("  M:", _fmt_fnr(fnr_by_sex, "M"))
    print("  F:", _fmt_fnr(fnr_by_sex, "F"))
    if fnr_by_sex is not None and "UNK" in fnr_by_sex:
        print("  UNK:", _fmt_fnr(fnr_by_sex, "UNK"))
    print("  ΔFNR (F-M):", f"{delta_f_minus_m:.6f}" if not np.isnan(delta_f_minus_m) else "nan")
    print("  |ΔFNR|:", f"{delta_abs:.6f}" if not np.isnan(delta_abs) else "nan")

    print(f"[seed={seed}] WROTE:")
    print(" ", str(metrics_path))
    print(" ", str(roc_png))
    print(" ", str(cm_png))
    if cm_m_png is not None:
        print(" ", str(cm_m_png))
    if cm_f_png is not None:
        print(" ", str(cm_f_png))

    return {
        "seed": int(seed),
        "val_thr": float(val_thr),
        "thr_used": float(thr_used),
        "thr_note": thr_note,
        "val_auc": float(val_auc),
        "test_auc": float(test_auc),
        "thr_metrics_test": thr_metrics_test,
        "fnr_by_sex": fnr_by_sex,
        "delta_signed": float(delta_f_minus_m),
        "delta_abs": float(delta_abs),
        "run_dir": str(run_dir),
    }

# -------------------------
# Run all seeds and aggregate results
# Outputs: printed summary + summary_test.json + history_index.jsonl
# -------------------------
results = []
for seed in SEEDS:
    results.append(run_test_once(seed))

# Purpose: report AUROC mean and 95% CI across seeds
aurocs = [r["test_auc"] for r in results]
t_crit = 4.302652729911275  # two-sided 95% Student-t critical value for df=2; valid only when len(SEEDS) == 3
n = len(aurocs)
mean_auc = float(np.mean(aurocs))
std_auc = float(np.std(aurocs, ddof=1)) if n > 1 else 0.0
half_width = float(t_crit * (std_auc / math.sqrt(n))) if n > 1 else 0.0
ci95 = [float(mean_auc - half_width), float(mean_auc + half_width)]

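def _mean_sd(vals):
    # Purpose: NaN-aware mean and sample SD (ddof=1); SD falls back to 0.0 with fewer than two non-NaN values
    vals = np.asarray(vals, dtype=np.float64)
    n_valid = int(np.sum(~np.isnan(vals)))
    sd = float(np.nanstd(vals, ddof=1)) if n_valid > 1 else 0.0
    return float(np.nanmean(vals)), sd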

# Purpose: summarize thresholds and threshold-based metrics across seeds
val_thrs = [float(r["val_thr"]) for r in results]
val_thr_mean, val_thr_sd = _mean_sd(val_thrs)

thr_used_list = [float(r["thr_used"]) for r in results]
thr_used_mean, thr_used_sd = _mean_sd(thr_used_list)

thr_list = [r["thr_metrics_test"] for r in results]
keys = ["accuracy","precision","recall","f1_score","sensitivity","specificity","mcc","p_value_fisher_two_sided"]
agg = {}
for k in keys:
    v = [float(tm.get(k, float("nan"))) for tm in thr_list]
    mu, sd = _mean_sd(v)
    agg[k] = {"mean": mu, "sd": sd, "values_by_seed": {str(s): float(tm.get(k, float("nan"))) for s, tm in zip(SEEDS, thr_list)}}
cm_by_seed = {str(s): thr_list[i]["confusion_matrix"] for i, s in enumerate(SEEDS)}

# Purpose: summarize fairness (H3) across seeds
fnr_by_seed = {str(r["seed"]): r["fnr_by_sex"] for r in results}
delta_signed_by_seed = {str(r["seed"]): float(r["delta_signed"]) for r in results}
delta_abs_by_seed = {str(r["seed"]): float(r["delta_abs"]) for r in results}

fnr_m_vals, fnr_f_vals, n_pd_m_vals, n_pd_f_vals = [], [], [], []
d_signed_vals, d_abs_vals = [], []
for r in results:
    d = r["fnr_by_sex"] or {}
    fnr_m_vals.append(float(d.get("M", {}).get("fnr", float("nan"))))
    fnr_f_vals.append(float(d.get("F", {}).get("fnr", float("nan"))))
    n_pd_m_vals.append(float(d.get("M", {}).get("n_pos", float("nan"))))
    n_pd_f_vals.append(float(d.get("F", {}).get("n_pos", float("nan"))))
    d_signed_vals.append(float(r["delta_signed"]))
    d_abs_vals.append(float(r["delta_abs"]))

fnr_m_mean, fnr_m_sd = _mean_sd(fnr_m_vals)
fnr_f_mean, fnr_f_sd = _mean_sd(fnr_f_vals)
d_signed_mean, d_signed_sd = _mean_sd(d_signed_vals)
d_abs_mean, d_abs_sd = _mean_sd(d_abs_vals)

print("\nTest AUROC by seed:")
for r in results:
    print(f"  seed {r['seed']}: {r['test_auc']:.6f}")
print(f"\nMean Test AUROC: {mean_auc:.6f}")
print(f"95% CI (t, n=3): [{ci95[0]:.6f}, {ci95[1]:.6f}]")

print("\nVAL Youden-J threshold by seed:")
for r in results:
    vthr = r["val_thr"]
    if np.isnan(vthr):
        print(f"  seed {r['seed']}: nan")
    else:
        print(f"  seed {r['seed']}: {vthr:.6f}")
print(f"VAL Youden-J threshold (mean ± SD): {val_thr_mean:.6f} ± {val_thr_sd:.6f}")

print("\nTEST threshold used (should equal VAL Youden-J threshold unless VAL threshold is NaN):")
for r in results:
    if r["thr_note"] is None:
        print(f"  seed {r['seed']}: {r['thr_used']:.6f}")
    else:
        print(f"  seed {r['seed']}: {r['thr_used']:.6f}  (NOTE: fallback used)")
print(f"TEST threshold used (mean ± SD): {thr_used_mean:.6f} ± {thr_used_sd:.6f}")

print("\nThreshold metrics on TEST @ VAL Youden-J threshold (mean ± SD across seeds):")
for k in ["accuracy","precision","sensitivity","specificity","f1_score","mcc"]:
    mu = agg[k]["mean"]
    sd = agg[k]["sd"]
    print(f"  {k}: {mu:.6f} ± {sd:.6f}")
print("  fisher_p_value_two_sided:", f"{agg['p_value_fisher_two_sided']['mean']:.6g} ± {agg['p_value_fisher_two_sided']['sd']:.6g}")

print("\nFAIRNESS (H3) on TEST @ VAL Youden-J threshold across seeds (mean ± SD):")
print(f"  FNR_M: {fnr_m_mean:.6f} ± {fnr_m_sd:.6f}")
print(f"  FNR_F: {fnr_f_mean:.6f} ± {fnr_f_sd:.6f}")
print(f"  ΔFNR (F-M): {d_signed_mean:.6f} ± {d_signed_sd:.6f}")
print(f"  |ΔFNR|: {d_abs_mean:.6f} ± {d_abs_sd:.6f}")
print("  Per-seed:", {
    str(SEEDS[i]): {
        "thr_used": thr_used_list[i],
        "FNR_M": fnr_m_vals[i],
        "n_PD_M": n_pd_m_vals[i],
        "FNR_F": fnr_f_vals[i],
        "n_PD_F": n_pd_f_vals[i],
        "delta_F_minus_M": d_signed_vals[i],
        "abs_delta": d_abs_vals[i],
    } for i in range(len(SEEDS))
})

# -------------------------
# Write aggregated summary files
# Inputs: per-seed results collected above
# Outputs: summary_test.json (overwrite) and history_index.jsonl (append)
# -------------------------
summary = {
    "dataset": dataset_id,
    "dx_out_root": DX_OUT_ROOT,
    "manifest_all": MANIFEST_ALL,
    "trainval_experiment_used": str(chosen_exp),
    "seeds": SEEDS,

    "n_val": int(len(val_df)),
    "label_counts_val": val_df["label_num"].value_counts(dropna=False).to_dict(),
    "sex_counts_val_norm": val_df["sex_norm"].value_counts(dropna=False).to_dict(),

    "n_test": int(len(test_df)),
    "label_counts_test": test_df["label_num"].value_counts(dropna=False).to_dict(),
    "sex_counts_test_norm": test_df["sex_norm"].value_counts(dropna=False).to_dict(),

    "val_optimal_threshold": {
        "method": "Youden J (argmax(TPR - FPR) on VAL ROC)",
        "thresholds_by_seed": {str(r["seed"]): float(r["val_thr"]) for r in results},
        "mean_sd": {"mean": float(val_thr_mean), "sd": float(val_thr_sd)},
    },

    "test_threshold_used": {
        "threshold_used_by_seed": {str(r["seed"]): float(r["thr_used"]) for r in results},
        "mean_sd": {"mean": float(thr_used_mean), "sd": float(thr_used_sd)},
        "notes_by_seed": {str(r["seed"]): (r["thr_note"] if r["thr_note"] is not None else "") for r in results},
    },

    "test_aurocs_by_seed": {str(r["seed"]): float(r["test_auc"]) for r in results},
    "mean_test_auroc": float(mean_auc),
    "t_crit_df2_95": float(t_crit),
    "half_width_95_ci": float(half_width),
    "ci95_test_auroc": ci95,

    "threshold_metrics_mean_sd_test_at_val_opt": agg,
    "confusion_matrix_by_seed_test_at_val_opt": cm_by_seed,

    "fairness_test_at_val_opt": {
        "definition": "H3: ΔFNR = FNR(F) - FNR(M), where FNR(sex)=FN/(FN+TP) computed on PD-only true labels at test_threshold_used.",
        "fnr_by_sex_norm_by_seed": fnr_by_seed,
        "delta_fnr_F_minus_M_by_seed": delta_signed_by_seed,
        "delta_fnr_abs_by_seed": delta_abs_by_seed,
        "fnr_M_mean_sd": {"mean": float(fnr_m_mean), "sd": float(fnr_m_sd), "values_by_seed": {str(s): float(v) for s, v in zip(SEEDS, fnr_m_vals)}},
        "fnr_F_mean_sd": {"mean": float(fnr_f_mean), "sd": float(fnr_f_sd), "values_by_seed": {str(s): float(v) for s, v in zip(SEEDS, fnr_f_vals)}},
        "delta_fnr_F_minus_M_mean_sd": {"mean": float(d_signed_mean), "sd": float(d_signed_sd), "values_by_seed": {str(s): float(v) for s, v in zip(SEEDS, d_signed_vals)}},
        "delta_fnr_abs_mean_sd": {"mean": float(d_abs_mean), "sd": float(d_abs_sd), "values_by_seed": {str(s): float(v) for s, v in zip(SEEDS, d_abs_vals)}},
        "denominators_PD_by_seed": {str(s): {"n_PD_M": float(n_pd_m_vals[i]), "n_PD_F": float(n_pd_f_vals[i])} for i, s in enumerate(SEEDS)},
        "sex_normalization_note": "D1 mapping: 0->F, 1->M; otherwise UNK. ΔFNR computed only when both M and F have defined FNR.",
    },

    "run_dirs": [r["run_dir"] for r in results],
    "timestamp": time.strftime("%Y%m%d_%H%M%S"),
}

summary_path = TEST_ROOT / "summary_test.json"
with open(summary_path, "w", encoding="utf-8") as f:
    json.dump(summary, f, indent=2)

history_path = TEST_ROOT / "history_index.jsonl"
with open(history_path, "a", encoding="utf-8") as f:
    f.write(json.dumps(summary) + "\n")

print("\nWROTE summary:", str(summary_path))
print("APPENDED history index:", str(history_path))
print("Open this folder to access artifacts:", str(TEST_ROOT))

# -------------------------
# Stop runtime
# Purpose: release the GPU machine after outputs are written
# -------------------------
print("\nAll done. Unassigning the runtime to stop the L4 instance...")
try:
    from google.colab import runtime  # type: ignore
    print("Calling runtime.unassign() now.")
    runtime.unassign()
except Exception as e:
    print("Could not unassign runtime automatically. You can stop the runtime manually in Colab.")
    print("Reason:", repr(e))

The following cell runs the final D2 (Slovak, EWA-DB) test evaluation without retraining the model, so it can be rerun safely. It reads the combined manifest file (`manifests/manifest_all.csv`) from the D2 `preprocessed_v1` folder and keeps only the validation and test rows. It checks that all required fields are present: split, audio path, label, task type, sex, and age. Before continuing, it performs quick fail-fast checks to confirm that neither split is empty and that every referenced audio file exists on disk, showing progress while the paths are checked.
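
As a rough illustration of the fail-fast checks described above, the following standard-library sketch validates a manifest given as CSV text. The column names follow the cell's stated inputs; the split labels `"val"` and `"test"`, the function name `load_val_test_rows`, and the demo rows are assumptions made for this example and are not taken from the actual cell.

```python
import csv
import io
from pathlib import Path

# Column names follow the cell's stated manifest inputs; the "val"/"test" split labels are assumed.
REQUIRED_COLUMNS = ["split", "clip_path", "label_num", "task", "sex", "age"]


def load_val_test_rows(manifest_text, path_exists=None):
    """Return (val_rows, test_rows) from CSV text after fail-fast checks."""
    if path_exists is None:
        path_exists = lambda p: Path(p).exists()

    reader = csv.DictReader(io.StringIO(manifest_text))
    missing = [c for c in REQUIRED_COLUMNS if c not in (reader.fieldnames or [])]
    if missing:
        raise ValueError(f"Manifest is missing required columns: {missing}")

    # Keep only validation and test rows; training rows are ignored here.
    rows = [r for r in reader if r["split"] in ("val", "test")]
    val_rows = [r for r in rows if r["split"] == "val"]
    test_rows = [r for r in rows if r["split"] == "test"]
    if not val_rows or not test_rows:
        raise ValueError("Both the val and test splits must be non-empty.")

    # Stop before any inference if a referenced audio file is missing.
    not_found = [r["clip_path"] for r in rows if not path_exists(r["clip_path"])]
    if not_found:
        raise FileNotFoundError(f"{len(not_found)} audio file(s) missing, e.g. {not_found[0]}")
    return val_rows, test_rows


# Hypothetical three-row manifest; the existence check is stubbed so the sketch runs anywhere.
demo_manifest = (
    "split,clip_path,label_num,task,sex,age\n"
    "train,a.wav,0,read,F,60\n"
    "val,b.wav,1,read,M,70\n"
    "test,c.wav,0,read,F,65\n"
)
val_rows, test_rows = load_val_test_rows(demo_manifest, path_exists=lambda p: True)
print(len(val_rows), len(test_rows))
```

Running every check before the model is loaded keeps a missing file or an empty split from surfacing only after minutes of GPU inference.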

The cell then prepares the data for inference. Each clip is assigned to one of two task groups using a strict rule: `task == "vowel"` is treated as vowel speech, and all other tasks are treated as other speech. Sex labels are normalized to M, F, or UNK using common text formats, without guessing numeric encodings. A dataset class loads each audio clip, converts stereo audio to mono if needed, enforces a 16 kHz sample rate, and builds an attention mask. For vowel clips, the attention mask zeroes out trailing near-silence so that padded or silent regions do not influence the pooled features. A custom collator pads the audio clips in each batch to a common length and pads the attention masks in the same way. A short warm-up step loads a few batches from both the validation and test splits to catch data-loading issues early.
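
The trailing-silence mask and the batch padding described above can be sketched with plain Python lists. This is a simplified stand-in for the NumPy and PyTorch code in the cell, not a copy of it: the function names and the `tiny_thresh` default of `1e-4` are assumptions, since the value of the cell's silence threshold is not shown in this section.

```python
def trailing_silence_mask(samples, tiny_thresh=1e-4):
    """Return 1 up to the last sample louder than tiny_thresh, and 0 after it.

    An all-silent clip keeps a full mask of ones so that pooling still has
    at least one valid position.
    """
    last = -1
    for j in range(len(samples) - 1, -1, -1):
        if abs(samples[j]) > tiny_thresh:
            last = j
            break
    if last < 0:
        return [1] * len(samples)
    return [1] * (last + 1) + [0] * (len(samples) - last - 1)


def pad_batch(clips, masks):
    """Right-pad clips with 0.0 and masks with 0 up to the longest clip in the batch."""
    max_len = max(len(c) for c in clips)
    padded_clips = [list(c) + [0.0] * (max_len - len(c)) for c in clips]
    padded_masks = [list(m) + [0] * (max_len - len(m)) for m in masks]
    return padded_clips, padded_masks


# Toy waveforms: clip_a ends in two silent samples, clip_b is shorter and fully voiced.
clip_a = [0.2, -0.3, 0.0, 0.0]
clip_b = [0.1, 0.5]
mask_a = trailing_silence_mask(clip_a)
mask_b = trailing_silence_mask(clip_b)
batch_clips, batch_masks = pad_batch([clip_a, clip_b], [mask_a, mask_b])
print(batch_masks)
```

Padding the mask with zeros means that positions added by the collator are treated the same way as trailing silence: both are excluded when frame features are averaged.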

The model used for evaluation matches the training setup. It consists of a frozen Wav2Vec2 backbone and two small classification heads, one for vowel clips and one for other speech, each built as LayerNorm followed by Dropout and a Linear layer. The cell automatically finds the most recent training and validation experiment folder that contains `best_heads.pt` for all three seeds (1337, 2024, 7777). For each seed, it loads the saved head weights and, when available, reads the backbone checkpoint name and dropout value from the checkpoint so the test run remains consistent with training.
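
The evaluation code earlier in this notebook pools frame features with a masked mean and then picks one head's logits per clip according to its task group. The following dependency-free sketch mirrors those two steps on plain lists. The function names and toy numbers are hypothetical, and the real model operates on PyTorch tensors rather than lists.

```python
def masked_mean_pool(frames, frame_mask):
    """Average the feature vectors whose mask entry is 1; masked frames are ignored."""
    dim = len(frames[0])
    total = [0.0] * dim
    count = 0
    for vec, keep in zip(frames, frame_mask):
        if keep:
            count += 1
            for d in range(dim):
                total[d] += vec[d]
    denom = max(count, 1)  # guard against an all-zero mask
    return [t / denom for t in total]


def route_logits(logits_vowel, logits_other, task_groups):
    """Pick the vowel-head logits for vowel clips and the other-head logits otherwise."""
    return [
        lv if tg == "vowel" else lo
        for lv, lo, tg in zip(logits_vowel, logits_other, task_groups)
    ]


# Toy example: the third frame is masked out, so it does not affect the pooled vector.
pooled = masked_mean_pool([[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]], [1, 1, 0])

# Two clips: the first is a vowel clip, the second is other speech.
chosen = route_logits(
    logits_vowel=[[0.1, 0.9], [0.3, 0.7]],
    logits_other=[[0.8, 0.2], [0.6, 0.4]],
    task_groups=["vowel", "other"],
)
print(pooled, chosen)
```

Computing both heads for every clip and selecting afterwards keeps the batch shape fixed, which is simpler than splitting a mixed batch by task group.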

For each seed, evaluation is performed in two stages. In the validation stage, inference is run on the validation split and a decision threshold is selected using Youden’s J statistic, which picks the threshold that maximizes the difference between the true positive rate and the false positive rate. In the test stage, inference is run on the test split and all threshold-based metrics are computed at the validation-selected threshold rather than a fixed value of 0.5, so the test set is never used to tune the threshold.
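
The two-stage procedure can be illustrated with a small pure-Python sketch. It scans each distinct validation probability as a candidate threshold, whereas the notebook's own helper takes its candidates from scikit-learn's `roc_curve`, so treat this as an approximation of the idea rather than the cell's implementation. All labels and probabilities below are made-up toy values, not study results.

```python
def youden_j_threshold(y_true, y_prob):
    """Return (threshold, J) maximizing TPR - FPR for the rule: predict PD when prob >= threshold."""
    n_pos = sum(1 for y in y_true if y == 1)
    n_neg = len(y_true) - n_pos
    if n_pos == 0 or n_neg == 0:
        return float("nan"), float("nan")  # ROC is undefined for a single-class split

    best_thr, best_j = float("nan"), float("-inf")
    for thr in sorted(set(y_prob), reverse=True):
        tp = sum(1 for y, p in zip(y_true, y_prob) if y == 1 and p >= thr)
        fp = sum(1 for y, p in zip(y_true, y_prob) if y == 0 and p >= thr)
        j = tp / n_pos - fp / n_neg
        if j > best_j:
            best_thr, best_j = thr, j
    return best_thr, best_j


def confusion_at(y_true, y_prob, thr):
    """Confusion counts at a fixed threshold."""
    pred = [1 if p >= thr else 0 for p in y_prob]
    return {
        "TN": sum(1 for y, q in zip(y_true, pred) if y == 0 and q == 0),
        "FP": sum(1 for y, q in zip(y_true, pred) if y == 0 and q == 1),
        "FN": sum(1 for y, q in zip(y_true, pred) if y == 1 and q == 0),
        "TP": sum(1 for y, q in zip(y_true, pred) if y == 1 and q == 1),
    }


# Stage 1 (toy validation data): pick the threshold.
val_true = [0, 0, 0, 0, 0, 1, 1, 1, 1]
val_prob = [0.10, 0.20, 0.25, 0.30, 0.60, 0.35, 0.70, 0.80, 0.90]
thr, j = youden_j_threshold(val_true, val_prob)

# Stage 2 (toy test data): apply the frozen threshold without re-tuning.
test_true = [1, 0, 1, 0]
test_prob = [0.90, 0.40, 0.30, 0.10]
test_cm = confusion_at(test_true, test_prob, thr)
print(thr, test_cm)
```

Because the threshold is frozen before the test labels are consulted, the test metrics reflect how the operating point transfers rather than how well it can be fitted.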

For each seed, results are saved under
`<DX_OUT_ROOT>/monolingual_test_runs/run_<DATASET>_seed<seed>/`.
Saved files include `metrics.json` with dataset counts, the validation-selected threshold, validation and test AUROC, test metrics at that threshold, and fairness results. A ROC curve image for the test split is saved, along with a confusion matrix for the full test set at the chosen threshold. Additional confusion matrices for the M and F subgroups are saved when those groups are present.

Fairness is computed on the test split at the same chosen threshold using the project’s H3 definition. This measures the false negative rate by sex on Parkinson’s-positive cases only, and computes ΔFNR as FNR(F) minus FNR(M), along with the absolute difference. If either sex has no Parkinson’s-positive cases in the test split, its FNR and ΔFNR are reported as NaN. The script also records confusion counts by sex group, including M, F, and UNK, for the full label set.
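
As a worked example of the H3 definition, the sketch below computes FNR per sex and the signed gap on a handful of invented predictions. The function names and the toy data are assumptions for illustration and are not results from the study.

```python
import math

def fnr_by_sex(y_true, y_pred, sexes):
    """FNR = FN / (FN + TP) per sex group, computed on true-PD (label 1) cases only."""
    out = {}
    for g in sorted(set(sexes)):
        fn = sum(1 for y, p, s in zip(y_true, y_pred, sexes) if s == g and y == 1 and p == 0)
        tp = sum(1 for y, p, s in zip(y_true, y_pred, sexes) if s == g and y == 1 and p == 1)
        out[g] = fn / (fn + tp) if (fn + tp) > 0 else float("nan")
    return out

def delta_fnr(fnr):
    """Signed H3 gap FNR(F) - FNR(M); NaN if either group has no PD cases."""
    f = fnr.get("F", float("nan"))
    m = fnr.get("M", float("nan"))
    if math.isnan(f) or math.isnan(m):
        return float("nan")
    return f - m

# Toy test-set predictions: 4 PD women, 4 PD men, 2 healthy controls
y_true = [1, 1, 1, 1, 1, 1, 1, 1, 0, 0]
y_pred = [1, 1, 1, 0, 1, 1, 0, 0, 0, 1]
sexes  = ["F", "F", "F", "F", "M", "M", "M", "M", "F", "M"]

fnr = fnr_by_sex(y_true, y_pred, sexes)
d = delta_fnr(fnr)
print(fnr, d, abs(d))
```

Here one of four PD women and two of four PD men are missed, so FNR(F) = 0.25, FNR(M) = 0.5, and ΔFNR = −0.25. A negative value means PD is missed less often in women than in men at this threshold; the healthy controls do not enter the calculation.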

After all three seeds finish, the cell prints and writes a combined summary. This includes the mean test AUROC with a 95% confidence interval based on a t distribution with n = 3 (2 degrees of freedom), as well as the mean and standard deviation of the validation-selected thresholds, the threshold-based test metrics, and the fairness statistics (FNR for M and F, ΔFNR, and its absolute value). A `summary_test.json` file is written and the same summary is appended to `history_index.jsonl` under `monolingual_test_runs/`. The cell then prints the output folder location and unassigns the Colab runtime to stop the GPU instance.
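
The confidence interval is the usual small-sample form, mean ± t × SD / √n, with the sample standard deviation (ddof = 1). A minimal standard-library sketch follows; the AUROC values are invented, and the helper name `mean_ci95_three_seeds` is hypothetical. The critical value is the constant for 2 degrees of freedom, so the helper deliberately refuses any other number of seeds.

```python
import math
import statistics

T_CRIT_DF2_95 = 4.302652729911275  # two-sided 95% critical value, df = 2 (n = 3)

def mean_ci95_three_seeds(values):
    """Mean and 95% t-interval for exactly three per-seed values."""
    if len(values) != 3:
        raise ValueError("T_CRIT_DF2_95 is only valid for n = 3")
    mean = statistics.fmean(values)
    sd = statistics.stdev(values)  # sample SD (ddof = 1)
    half_width = T_CRIT_DF2_95 * sd / math.sqrt(len(values))
    return mean, (mean - half_width, mean + half_width)

toy_aurocs = [0.80, 0.82, 0.84]  # made-up values, not results from the study
mean_auc, (lo, hi) = mean_ci95_three_seeds(toy_aurocs)
print(f"{mean_auc:.4f} [{lo:.4f}, {hi:.4f}]")
```

With only three seeds the critical value is large (about 4.30, versus 1.96 for a normal approximation), so even a small spread between seeds produces a wide interval.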

In [None]:
# D2 Test Only: VAL-Selected Threshold Evaluation (Youden J)
#
# What this cell does
# - Reads one manifest (VAL + TEST rows only)
# - For each seed:
#   1) Runs VAL inference to pick a threshold using Youden J (max TPR − FPR)
#   2) Runs TEST inference and reports metrics at that fixed VAL-selected threshold
# - Aggregates results across 3 seeds (AUROC + threshold metrics + fairness)
#
# Inputs
# - manifest_all.csv (must include split, clip_path, label_num, task, sex, age)
# - best_heads.pt for each seed from the most recent trainval experiment
#
# Outputs
# - Per-seed folder: metrics.json + ROC and confusion PNGs (overall + by sex when present)
# - Summary JSON (summary_test.json) + append-only history (history_index.jsonl)
# - Stops the Colab runtime at the end
#
# Notes
# - Threshold is selected on VAL only to avoid test-tuning
# - Fairness metric (H3): ΔFNR = FNR(F) − FNR(M) at the selected threshold
# =========================

import os, json, math, random, time, warnings
from pathlib import Path

import numpy as np
import pandas as pd
import soundfile as sf

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

from transformers import Wav2Vec2Model

from sklearn.metrics import roc_auc_score, roc_curve, confusion_matrix, matthews_corrcoef
from scipy.stats import fisher_exact
import matplotlib.pyplot as plt

from tqdm.auto import tqdm

# -------------------------
# Drive mount (Colab)
# -------------------------
if not os.path.isdir("/content/drive/MyDrive"):
    from google.colab import drive  # type: ignore
    drive.mount("/content/drive")

# -------------------------
# Paths: dataset root + manifest
# -------------------------
DX_OUT_ROOT = "/content/drive/MyDrive/AI_PD_Project/Datasets/D2-Slovak (EWA-DB)/EWA-DB/preprocessed_v1"
MANIFEST_ALL = f"{DX_OUT_ROOT}/manifests/manifest_all.csv"

# -------------------------
# Fixed run settings (match Train+Val patterns)
# -------------------------
SEEDS          = [1337, 2024, 7777]
BACKBONE_CKPT  = "facebook/wav2vec2-base"  # default fallback; best_heads.pt may override
SR_EXPECTED    = 16000
TINY_THRESH    = 1e-4

# Batch sizing (same pattern as Train+Val)
EFFECTIVE_BS   = 64
PER_DEVICE_BS  = 16
GRAD_ACCUM     = max(1, EFFECTIVE_BS // PER_DEVICE_BS)  # printed only

# Threshold selection rule
# - Youden J on VAL: maximize (TPR - FPR)
THRESHOLD_SELECTION = "youden_j_on_val"

# Drive stability defaults
NUM_WORKERS    = 0
PIN_MEMORY     = False

# AMP (safe on L4)
USE_AMP        = True

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
try:
    torch.set_float32_matmul_precision("high")
except Exception:
    pass

warnings.filterwarnings("ignore", message="None of the inputs have requires_grad=True. Gradients will be None")

print("DX_OUT_ROOT:", DX_OUT_ROOT)
print("MANIFEST_ALL:", MANIFEST_ALL)
print("DEVICE:", DEVICE)
if DEVICE.type == "cuda":
    print("GPU:", torch.cuda.get_device_name(0))
print("PER_DEVICE_BS:", PER_DEVICE_BS, "| EFFECTIVE_BS:", PER_DEVICE_BS * GRAD_ACCUM)
print("NUM_WORKERS:", NUM_WORKERS, "| PIN_MEMORY:", PIN_MEMORY)
print("THRESHOLD_SELECTION:", THRESHOLD_SELECTION)
print("USE_AMP:", bool(USE_AMP and DEVICE.type == "cuda"))

# -------------------------
# Load manifest + keep only VAL/TEST rows
# -------------------------
if not os.path.exists(MANIFEST_ALL):
    raise FileNotFoundError(f"Missing manifest_all.csv: {MANIFEST_ALL}")

m = pd.read_csv(MANIFEST_ALL)

# Required columns for metrics + fairness + schema checks
req_cols = {"split", "clip_path", "label_num", "task", "sex", "age"}
missing = [c for c in sorted(req_cols) if c not in m.columns]
if missing:
    raise ValueError(f"Manifest missing required columns: {missing}. Found: {list(m.columns)}")

# Keep only val/test for this script (val selects threshold, test reports results)
m = m[m["split"].isin(["val", "test"])].copy()
if len(m) == 0:
    raise RuntimeError("After filtering to split in {'val','test'}, manifest has 0 rows.")

# Infer dataset_id (most frequent non-null dataset value, if present)
if "dataset" in m.columns and m["dataset"].notna().any():
    dataset_id = str(m["dataset"].value_counts(dropna=True).idxmax())
    m = m[m["dataset"].astype(str) == dataset_id].copy()
else:
    dataset_id = "DX"

# Keep a standard subset of columns (create as NaN if missing)
keep_cols = ["clip_path", "label_num", "task", "speaker_id", "sex", "age", "duration_sec", "split"]
for c in keep_cols:
    if c not in m.columns:
        m[c] = np.nan
m = m[keep_cols].copy()

val_df  = m[m["split"] == "val"].reset_index(drop=True)
test_df = m[m["split"] == "test"].reset_index(drop=True)

print(f"\nDataset inferred: {dataset_id}")
print(f"Val rows:  {len(val_df)}")
print(f"Test rows: {len(test_df)}")
print("Test label counts:", test_df["label_num"].value_counts(dropna=False).to_dict())
print("Test sex counts (raw):", test_df["sex"].value_counts(dropna=False).to_dict())

if len(test_df) == 0:
    raise RuntimeError("Test split has 0 rows.")
if len(val_df) == 0:
    raise RuntimeError("Validation split has 0 rows. Cannot select a val-optimal threshold.")

# -------------------------
# Clip existence check (stop early if paths are broken)
# -------------------------
def _fail_fast_missing_paths(df: pd.DataFrame, name: str):
    missing_paths = []
    for p in tqdm(df["clip_path"].astype(str).tolist(), desc=f"Check {name} clip_path exists", dynamic_ncols=True):
        if not os.path.exists(p):
            missing_paths.append(p)
            if len(missing_paths) >= 10:
                break
    if missing_paths:
        raise FileNotFoundError(f"{name}: missing clip_path(s). Examples: {missing_paths}")

_fail_fast_missing_paths(val_df, "VAL")
_fail_fast_missing_paths(test_df, "TEST")

# -------------------------
# Task grouping (exact mapping used by the project)
# -------------------------
def _task_group(task_val) -> str:
    return "vowel" if str(task_val) == "vowl" else "other"

val_df["task_group"]  = val_df["task"].apply(_task_group)
test_df["task_group"] = test_df["task"].apply(_task_group)

# -------------------------
# Sex normalization for fairness + sex-split confusion charts
# -------------------------
def normalize_sex(val) -> str:
    """
    Returns 'M', 'F', or 'UNK'
    Handles common strings: M/F, Male/Female, etc.
    Numeric encodings are treated as UNK to avoid silent mis-mapping.
    """
    if pd.isna(val):
        return "UNK"
    s = str(val).strip().lower()

    if s in {"m", "male", "man", "masc", "masculine"}:
        return "M"
    if s in {"f", "female", "woman", "fem", "feminine"}:
        return "F"

    return "UNK"

val_df["sex_norm"]  = val_df["sex"].apply(normalize_sex)
test_df["sex_norm"] = test_df["sex"].apply(normalize_sex)
print("Test sex counts (normalized):", test_df["sex_norm"].value_counts(dropna=False).to_dict())
if (test_df["sex_norm"] == "UNK").any():
    print("NOTE: Some 'sex' values could not be normalized to M/F and were counted as 'UNK' for fairness and sex charts.")

# -------------------------
# Audio dataset + collator (pads to batch max length)
# -------------------------
class AudioManifestDataset(Dataset):
    def __init__(self, df: pd.DataFrame):
        self.df = df.reset_index(drop=True)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx: int):
        row = self.df.iloc[idx]
        clip_path = str(row["clip_path"])
        label = int(row["label_num"])
        task_group = str(row["task_group"])
        sex_norm = str(row["sex_norm"])

        if not os.path.exists(clip_path):
            raise FileNotFoundError(f"Missing clip_path: {clip_path}")

        y, sr = sf.read(clip_path, always_2d=False)
        if isinstance(y, np.ndarray) and y.ndim > 1:
            y = np.mean(y, axis=1)
        y = np.asarray(y, dtype=np.float32)

        if int(sr) != SR_EXPECTED:
            raise ValueError(f"Sample rate mismatch (expected {SR_EXPECTED}): {clip_path} has sr={sr}")

        # Attention mask: for vowel clips, zero out the trailing near-silent tail,
        # i.e. every sample after the last one with |amplitude| > TINY_THRESH.
        # Non-vowel clips and entirely near-silent vowel clips keep a full mask.
        attn = np.ones((len(y),), dtype=np.int64)
        if task_group == "vowel":
            loud = np.flatnonzero(np.abs(y) > TINY_THRESH)
            if loud.size > 0:
                attn[int(loud[-1]) + 1:] = 0

        return {
            "input_values": torch.from_numpy(y),
            "attention_mask": torch.from_numpy(attn),
            "labels": torch.tensor(label, dtype=torch.long),
            "task_group": task_group,
            "sex_norm": sex_norm,
        }

def collate_fn(batch):
    max_len = int(max(b["input_values"].numel() for b in batch))
    input_vals, attn_masks, labels, task_groups, sex_norms = [], [], [], [], []
    for b in batch:
        x = b["input_values"]
        a = b["attention_mask"]
        pad = max_len - x.numel()
        if pad > 0:
            x = torch.cat([x, torch.zeros(pad, dtype=x.dtype)], dim=0)
            a = torch.cat([a, torch.zeros(pad, dtype=a.dtype)], dim=0)
        input_vals.append(x)
        attn_masks.append(a)
        labels.append(b["labels"])
        task_groups.append(b["task_group"])
        sex_norms.append(b["sex_norm"])
    return {
        "input_values": torch.stack(input_vals, dim=0),
        "attention_mask": torch.stack(attn_masks, dim=0),
        "labels": torch.stack(labels, dim=0),
        "task_group": task_groups,
        "sex_norm": sex_norms,
    }

# -------------------------
# Model (test-time): frozen Wav2Vec2 + two task heads
# -------------------------
class FrozenW2V2TwoHeadTest(nn.Module):
    def __init__(self, ckpt: str, dropout_p: float):
        super().__init__()
        self.backbone = Wav2Vec2Model.from_pretrained(ckpt)
        for p in self.backbone.parameters():
            p.requires_grad = False

        H = int(self.backbone.config.hidden_size)
        self.head_vowel = nn.Sequential(
            nn.LayerNorm(H),
            nn.Dropout(float(dropout_p)),
            nn.Linear(H, 2),
        )
        self.head_other = nn.Sequential(
            nn.LayerNorm(H),
            nn.Dropout(float(dropout_p)),
            nn.Linear(H, 2),
        )

    def masked_mean_pool(self, last_hidden, attn_mask_samples):
        feat_mask = self.backbone._get_feature_vector_attention_mask(last_hidden.shape[1], attn_mask_samples)
        feat_mask = feat_mask.to(last_hidden.device).unsqueeze(-1).type_as(last_hidden)
        masked = last_hidden * feat_mask
        denom = feat_mask.sum(dim=1).clamp(min=1.0)
        return masked.sum(dim=1) / denom  # [B, H]

    def _heads_fp32(self, x_any: torch.Tensor, head: nn.Module) -> torch.Tensor:
        # Keep heads in fp32 even under AMP for numeric stability
        x = x_any.float()
        if x.is_cuda:
            with torch.amp.autocast(device_type="cuda", enabled=False):
                return head(x)
        return head(x)

    def forward_logits(self, input_values, attention_mask, task_group):
        with torch.no_grad():
            out = self.backbone(input_values=input_values, attention_mask=attention_mask)
            last_hidden = out.last_hidden_state

        pooled = self.masked_mean_pool(last_hidden, attention_mask)  # [B, H]

        logits_v = self._heads_fp32(pooled, self.head_vowel)
        logits_o = self._heads_fp32(pooled, self.head_other)

        # Route each sample to the matching head based on task_group
        is_vowel = torch.tensor([tg == "vowel" for tg in task_group], device=pooled.device, dtype=torch.bool)
        logits = logits_o.clone()
        if is_vowel.any():
            logits[is_vowel] = logits_v[is_vowel]
        return logits

# -------------------------
# Metrics + plotting helpers
# -------------------------
def compute_auc(y_true, y_prob):
    y_true = np.asarray(y_true, dtype=np.int64)
    y_prob = np.asarray(y_prob, dtype=np.float64)
    if len(np.unique(y_true)) < 2:
        return float("nan")
    return float(roc_auc_score(y_true, y_prob))

def compute_threshold_metrics(y_true, y_prob, thr):
    y_true = np.asarray(y_true, dtype=np.int64)
    y_prob = np.asarray(y_prob, dtype=np.float64)
    y_pred = (y_prob >= float(thr)).astype(np.int64)

    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
    TN, FP, FN, TP = int(cm[0, 0]), int(cm[0, 1]), int(cm[1, 0]), int(cm[1, 1])

    eps = 1e-12
    acc = (TP + TN) / max(1, (TP + TN + FP + FN))
    prec = TP / (TP + FP + eps)
    rec = TP / (TP + FN + eps)     # sensitivity
    f1 = 2 * prec * rec / (prec + rec + eps)
    spec = TN / (TN + FP + eps)

    try:
        mcc = float(matthews_corrcoef(y_true, y_pred)) if len(np.unique(y_pred)) > 1 else float("nan")
    except Exception:
        mcc = float("nan")

    try:
        _, pval = fisher_exact([[TN, FP], [FN, TP]], alternative="two-sided")
        pval = float(pval)
    except Exception:
        pval = float("nan")

    return {
        "threshold": float(thr),
        "confusion_matrix": {"TN": TN, "FP": FP, "FN": FN, "TP": TP},
        "accuracy": float(acc),
        "precision": float(prec),
        "recall": float(rec),
        "f1_score": float(f1),
        "sensitivity": float(rec),
        "specificity": float(spec),
        "mcc": float(mcc),
        "p_value_fisher_two_sided": float(pval),
    }

def save_roc_curve_png(y_true, y_prob, out_png, title_suffix="Set"):
    fpr, tpr, _ = roc_curve(y_true, y_prob)
    plt.figure()
    plt.plot(fpr, tpr)
    plt.plot([0, 1], [0, 1], linestyle="--")
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    plt.title(f"ROC Curve ({title_suffix})")
    plt.tight_layout()
    plt.savefig(out_png, dpi=150)
    plt.close()

def save_confusion_png(y_true, y_prob, out_png, thr, title_suffix="Set"):
    y_true = np.asarray(y_true, dtype=np.int64)
    y_prob = np.asarray(y_prob, dtype=np.float64)
    y_pred = (y_prob >= float(thr)).astype(int)
    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
    plt.figure()
    plt.imshow(cm)
    plt.xticks([0, 1], ["Healthy(0)", "PD(1)"])
    plt.yticks([0, 1], ["Healthy(0)", "PD(1)"])
    plt.xlabel("Predicted")
    plt.ylabel("True")
    plt.title(f"Confusion Matrix ({title_suffix}, thr={thr:.4f})")
    for (i, j), v in np.ndenumerate(cm):
        plt.text(j, i, str(v), ha="center", va="center")
    plt.tight_layout()
    plt.savefig(out_png, dpi=150)
    plt.close()

# -------------------------
# Fairness (H3): FNR by sex + ΔFNR = FNR(F) − FNR(M)
# -------------------------
def compute_fnr_by_group_signed(y_true, y_prob, groups, thr):
    """
    FNR = FN/(FN+TP) computed on PD-only true labels.
    Returns:
      - per group: counts + fnr
      - delta_signed: FNR(F) - FNR(M)
      - delta_abs: |FNR(F) - FNR(M)|
    """
    y_true = np.asarray(y_true, dtype=np.int64)
    y_prob = np.asarray(y_prob, dtype=np.float64)
    groups = np.asarray(list(groups), dtype=object)
    y_pred = (y_prob >= float(thr)).astype(np.int64)

    out = {}
    for g in sorted(set(groups.tolist())):
        mask_g = (groups == g)
        if int(mask_g.sum()) == 0:
            continue

        pos_mask = mask_g & (y_true == 1)
        n_pos = int(pos_mask.sum())
        if n_pos == 0:
            out[g] = {"n_total": int(mask_g.sum()), "n_pos": 0, "tp": 0, "fn": 0, "fnr": float("nan")}
            continue

        tp = int(((y_pred == 1) & pos_mask).sum())
        fn = int(((y_pred == 0) & pos_mask).sum())
        fnr = float(fn / max(1, (fn + tp)))

        out[g] = {"n_total": int(mask_g.sum()), "n_pos": int(n_pos), "tp": int(tp), "fn": int(fn), "fnr": float(fnr)}

    fnr_m = out.get("M", {}).get("fnr", float("nan"))
    fnr_f = out.get("F", {}).get("fnr", float("nan"))

    if (not np.isnan(fnr_m)) and (not np.isnan(fnr_f)):
        delta_signed = float(fnr_f - fnr_m)   # H3
        delta_abs = float(abs(delta_signed))
    else:
        delta_signed = float("nan")
        delta_abs = float("nan")

    return out, delta_signed, delta_abs

# -------------------------
# Confusion counts by group (all labels)
# -------------------------
def compute_confusion_counts(y_true, y_prob, thr):
    y_true = np.asarray(y_true, dtype=np.int64)
    y_prob = np.asarray(y_prob, dtype=np.float64)
    y_pred = (y_prob >= float(thr)).astype(np.int64)
    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
    TN, FP, FN, TP = int(cm[0, 0]), int(cm[0, 1]), int(cm[1, 0]), int(cm[1, 1])
    return {"TN": TN, "FP": FP, "FN": FN, "TP": TP}

def compute_confusion_by_group(y_true, y_prob, groups, thr):
    y_true = np.asarray(y_true, dtype=np.int64)
    y_prob = np.asarray(y_prob, dtype=np.float64)
    groups = np.asarray(list(groups), dtype=object)
    out = {}
    for g in sorted(set(groups.tolist())):
        mask = (groups == g)
        if int(mask.sum()) == 0:
            continue
        out[g] = {
            "n": int(mask.sum()),
            "confusion": compute_confusion_counts(y_true[mask], y_prob[mask], thr=thr),
        }
    return out

# -------------------------
# Reproducible seeding (per run)
# -------------------------
def set_all_seeds(seed: int):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# -------------------------
# Find the most recent trainval experiment with all 3 seeds present
# -------------------------
TRAINVAL_ROOT = Path(DX_OUT_ROOT) / "trainval_runs"
if not TRAINVAL_ROOT.exists():
    raise FileNotFoundError(f"Missing trainval_runs folder: {str(TRAINVAL_ROOT)}")

exp_dirs = sorted([p for p in TRAINVAL_ROOT.glob("exp_*") if p.is_dir()], key=lambda p: p.stat().st_mtime, reverse=True)
if not exp_dirs:
    raise FileNotFoundError(f"No exp_* folders found under: {str(TRAINVAL_ROOT)}")

def _has_all_seeds(exp_path: Path, dataset_id: str, seeds: list):
    for s in seeds:
        p = exp_path / f"run_{dataset_id}_seed{s}" / "best_heads.pt"
        if not p.exists():
            return False
    return True

chosen_exp = None
for ed in exp_dirs:
    if _has_all_seeds(ed, dataset_id, SEEDS):
        chosen_exp = ed
        break

if chosen_exp is None:
    sample = exp_dirs[0]
    raise FileNotFoundError(
        "Could not find a trainval experiment with all 3 best_heads.pt files.\n"
        f"Expected under: {str(TRAINVAL_ROOT)}/exp_*/run_{dataset_id}_seedXXXX/best_heads.pt\n"
        f"Most recent exp checked: {str(sample)}"
    )

print("\nUsing Train+Val experiment folder:")
print(" ", str(chosen_exp))

# -------------------------
# Output root for this evaluation
# -------------------------
TEST_ROOT = Path(DX_OUT_ROOT) / "monolingual_test_runs"
TEST_ROOT.mkdir(parents=True, exist_ok=True)

# -------------------------
# DataLoaders (VAL selects threshold, TEST reports metrics)
# -------------------------
val_ds = AudioManifestDataset(val_df)
val_loader = DataLoader(
    val_ds,
    batch_size=PER_DEVICE_BS,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=PIN_MEMORY,
    collate_fn=collate_fn,
)

test_ds = AudioManifestDataset(test_df)
test_loader = DataLoader(
    test_ds,
    batch_size=PER_DEVICE_BS,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=PIN_MEMORY,
    collate_fn=collate_fn,
)

# -------------------------
# Warm-up: catch loader issues early (no training)
# -------------------------
def warmup_loader(loader, name: str, max_batches: int = 3):
    print(f"\nWarm-up: loading up to {max_batches} {name} batches...")
    t0 = time.time()
    nb = len(loader)
    wb = min(max_batches, nb)
    if wb == 0:
        raise RuntimeError(f"{name} DataLoader has 0 batches. Check df length and PER_DEVICE_BS.")
    it = iter(loader)
    for i in range(wb):
        _ = next(it)
        print(f"  loaded warmup batch {i+1}/{wb}")
    print(f"Warm-up done in {time.time()-t0:.2f}s")

warmup_loader(val_loader, "VAL", max_batches=2)
warmup_loader(test_loader, "TEST", max_batches=2)

# -------------------------
# Load heads from best_heads.pt (supports current format + common wrappers)
# -------------------------
def load_heads_into_model(model: FrozenW2V2TwoHeadTest, best_heads_path: Path, state: dict):
    """
    Expected trainval format:
      state["head_vowel"] = state_dict for nn.Sequential(LN, Dropout, Linear)
      state["head_other"] = state_dict for nn.Sequential(LN, Dropout, Linear)

    Fallback: wrapper dict with a full model state_dict.
    """
    if not best_heads_path.exists():
        raise FileNotFoundError(f"Missing best_heads.pt: {str(best_heads_path)}")

    if not isinstance(state, dict):
        raise ValueError(f"Unexpected best_heads.pt type: {type(state)}")

    # Preferred (current trainval format)
    if "head_vowel" in state and "head_other" in state and isinstance(state["head_vowel"], dict) and isinstance(state["head_other"], dict):
        model.head_vowel.load_state_dict(state["head_vowel"], strict=True)
        model.head_other.load_state_dict(state["head_other"], strict=True)
        return model

    # Wrapper fallback
    for wrap_key in ["state_dict", "model_state_dict", "model"]:
        if wrap_key in state and isinstance(state[wrap_key], dict):
            sd = state[wrap_key]
            # Check for head weights before loading, so a checkpoint without
            # them is rejected without touching the model.
            has_head = any(k.startswith(("head_vowel.", "head_other.")) for k in sd.keys())
            if not has_head:
                raise KeyError(f"Checkpoint wrapper '{wrap_key}' did not contain head keys. First keys: {list(sd.keys())[:25]}")
            # strict=False: the checkpoint may omit the frozen backbone weights
            model.load_state_dict(sd, strict=False)
            return model

    raise KeyError(
        "best_heads.pt did not match expected formats.\n"
        f"Top-level keys: {list(state.keys())[:50]}"
    )

# -------------------------
# Inference helper (returns y_true, y_score, sex_norm)
# -------------------------
def run_inference(model: FrozenW2V2TwoHeadTest, loader: DataLoader, desc: str, use_amp: bool):
    all_probs, all_true, all_sex = [], [], []
    pbar = tqdm(loader, desc=desc, dynamic_ncols=True)
    with torch.inference_mode():
        for batch in pbar:
            input_values = batch["input_values"].to(DEVICE, non_blocking=False)
            attention_mask = batch["attention_mask"].to(DEVICE, non_blocking=False)
            labels = batch["labels"].to(DEVICE, non_blocking=False)
            task_group = batch["task_group"]
            sex_norm = batch["sex_norm"]

            with torch.amp.autocast(device_type="cuda", enabled=use_amp):
                logits = model.forward_logits(input_values, attention_mask, task_group)

            probs = torch.softmax(logits.float(), dim=-1)[:, 1].detach().cpu().numpy().astype(np.float64)
            all_probs.extend(probs.tolist())
            all_true.extend(labels.detach().cpu().numpy().astype(np.int64).tolist())
            all_sex.extend(list(sex_norm))

    return np.asarray(all_true, dtype=np.int64), np.asarray(all_probs, dtype=np.float64), np.asarray(all_sex, dtype=object)

# -------------------------
# VAL threshold selection (Youden J on ROC curve)
# -------------------------
def select_threshold_youden_j(y_true_val: np.ndarray, y_prob_val: np.ndarray) -> float:
    y_true_val = np.asarray(y_true_val, dtype=np.int64)
    y_prob_val = np.asarray(y_prob_val, dtype=np.float64)

    if len(np.unique(y_true_val)) < 2:
        # ROC curve is undefined with only one class
        return 0.5

    fpr, tpr, thresholds = roc_curve(y_true_val, y_prob_val)

    # thresholds may include inf; ignore non-finite values
    finite = np.isfinite(thresholds)
    fpr = fpr[finite]
    tpr = tpr[finite]
    thresholds = thresholds[finite]

    if thresholds.size == 0:
        return 0.5

    j = tpr - fpr
    best_idx = int(np.argmax(j))
    thr = float(thresholds[best_idx])

    # Probabilities are in [0,1]; clamp for safety
    thr = float(min(1.0, max(0.0, thr)))
    return thr

# -------------------------
# One seed: VAL threshold selection -> TEST evaluation
# -------------------------
def run_seed_once(seed: int):
    set_all_seeds(seed)

    run_dir = TEST_ROOT / f"run_{dataset_id}_seed{seed}"
    run_dir.mkdir(parents=True, exist_ok=True)

    best_heads_path = chosen_exp / f"run_{dataset_id}_seed{seed}" / "best_heads.pt"

    print(f"\n[seed={seed}] Loading model + heads from:")
    print(" ", str(best_heads_path))

    state = torch.load(str(best_heads_path), map_location="cpu")

    # Read config from checkpoint when available (keeps test aligned to trainval)
    ckpt_backbone = str(state.get("backbone_ckpt", BACKBONE_CKPT)) if isinstance(state, dict) else BACKBONE_CKPT
    ckpt_dropout  = float(state.get("dropout_p", 0.10)) if isinstance(state, dict) else 0.10  # trainval uses 0.10

    model = FrozenW2V2TwoHeadTest(ckpt_backbone, dropout_p=ckpt_dropout).to(DEVICE)
    model = load_heads_into_model(model, best_heads_path, state)
    model.eval()

    use_amp = bool(USE_AMP and DEVICE.type == "cuda")

    # ---- VAL inference (select threshold)
    yv, pv, sv = run_inference(model, val_loader, desc=f"[seed={seed}] Val (for threshold)", use_amp=use_amp)
    val_auc = compute_auc(yv, pv)
    val_thr = select_threshold_youden_j(yv, pv)

    val_thr_metrics = compute_threshold_metrics(yv, pv, thr=val_thr)

    print(f"[seed={seed}] VAL: auroc={val_auc:.6f} | selected_thr={val_thr:.6f} (Youden J)")

    # ---- TEST inference (fixed threshold from VAL)
    yt, pt, st = run_inference(model, test_loader, desc=f"[seed={seed}] Test", use_amp=use_amp)
    test_auc = compute_auc(yt, pt)
    test_thr_metrics = compute_threshold_metrics(yt, pt, thr=val_thr)

    # ---- fairness on TEST at val_thr
    fnr_by_sex, delta_f_minus_m, delta_abs = compute_fnr_by_group_signed(
        yt, pt, st, thr=val_thr
    )

    # ---- confusion by sex (all labels) on TEST at val_thr
    confusion_by_sex = compute_confusion_by_group(yt, pt, st, thr=val_thr)

    # ---- plots (TEST overall)
    roc_png = run_dir / "roc_curve.png"
    cm_png  = run_dir / "confusion_matrix.png"
    save_roc_curve_png(yt, pt, str(roc_png), title_suffix=f"Test (seed={seed})")
    save_confusion_png(yt, pt, str(cm_png), thr=val_thr, title_suffix=f"Test (seed={seed})")

    # ---- plots (TEST by sex) at val_thr
    cm_m_png = None
    cm_f_png = None

    mask_m = (st == "M")
    mask_f = (st == "F")

    if int(mask_m.sum()) > 0:
        cm_m_png = run_dir / "confusion_matrix_M.png"
        save_confusion_png(
            yt[mask_m], pt[mask_m], str(cm_m_png),
            thr=val_thr, title_suffix=f"Test SEX=M (seed={seed})"
        )

    if int(mask_f.sum()) > 0:
        cm_f_png = run_dir / "confusion_matrix_F.png"
        save_confusion_png(
            yt[mask_f], pt[mask_f], str(cm_f_png),
            thr=val_thr, title_suffix=f"Test SEX=F (seed={seed})"
        )

    metrics = {
        "dataset": dataset_id,
        "seed": int(seed),

        "n_val": int(len(val_df)),
        "label_counts_val": val_df["label_num"].value_counts(dropna=False).to_dict(),
        "sex_counts_val_norm": val_df["sex_norm"].value_counts(dropna=False).to_dict(),

        "n_test": int(len(test_df)),
        "label_counts_test": test_df["label_num"].value_counts(dropna=False).to_dict(),
        "sex_counts_test_norm": test_df["sex_norm"].value_counts(dropna=False).to_dict(),

        "threshold_selection": {
            "method": "youden_j_on_val",
            "selected_threshold": float(val_thr),
            "val_auroc": float(val_auc),
            "val_threshold_metrics": val_thr_metrics,
            "note": "Threshold selected on VAL only (Youden J). TEST metrics computed at this fixed threshold to avoid test-tuning.",
        },

        "test_auroc": float(test_auc),
        "threshold_metrics_test": test_thr_metrics,
        "test_threshold": float(val_thr),

        "fairness_test": {
            "definition": "H3: ΔFNR = FNR(F) - FNR(M), where FNR(sex)=FN/(FN+TP) computed on PD-only true labels at test_threshold.",
            "fnr_by_sex_norm": fnr_by_sex,
            "delta_fnr_F_minus_M": float(delta_f_minus_m),
            "delta_fnr_abs": float(delta_abs),
            "note": "If n_PD for a sex is 0, its FNR is NaN and ΔFNR is NaN.",
            "sex_normalization_note": "sex_norm in {M,F,UNK}. Values not mapped to M/F counted as UNK.",
        },

        "confusion_by_sex_norm": confusion_by_sex,

        "artifacts": {
            "roc_curve_png": str(roc_png),
            "confusion_matrix_png": str(cm_png),
            "confusion_matrix_M_png": str(cm_m_png) if cm_m_png is not None else None,
            "confusion_matrix_F_png": str(cm_f_png) if cm_f_png is not None else None,
        },

        "dx_out_root": DX_OUT_ROOT,
        "trainval_experiment_used": str(chosen_exp),
        "best_heads_path": str(best_heads_path),
        "backbone_ckpt_used": ckpt_backbone,
        "dropout_p_used": float(ckpt_dropout),
        "timestamp": time.strftime("%Y%m%d_%H%M%S"),
    }

    metrics_path = run_dir / "metrics.json"
    with open(metrics_path, "w", encoding="utf-8") as f:
        json.dump(metrics, f, indent=2)

    print(f"[seed={seed}] DONE | test_AUROC={test_auc:.6f} | thr(val-opt)={val_thr:.6f}")

    def _fmt_fnr(g):
        d = fnr_by_sex.get(g, None)
        if d is None:
            return "n/a"
        return f"fnr={d['fnr']:.6f} (n_PD={d['n_pos']}, fn={d['fn']}, tp={d['tp']})"

    print(f"[seed={seed}] FAIRNESS (H3) @ thr(val-opt)={val_thr:.6f}:")
    print("  M:", _fmt_fnr("M"))
    print("  F:", _fmt_fnr("F"))
    if "UNK" in fnr_by_sex:
        print("  UNK:", _fmt_fnr("UNK"))
    print("  ΔFNR (F-M):", f"{delta_f_minus_m:.6f}" if not np.isnan(delta_f_minus_m) else "nan")
    print("  |ΔFNR|:", f"{delta_abs:.6f}" if not np.isnan(delta_abs) else "nan")

    print(f"[seed={seed}] WROTE:")
    print(" ", str(metrics_path))
    print(" ", str(roc_png))
    print(" ", str(cm_png))
    if cm_m_png is not None:
        print(" ", str(cm_m_png))
    if cm_f_png is not None:
        print(" ", str(cm_f_png))

    return {
        "seed": int(seed),
        "val_selected_threshold": float(val_thr),
        "val_auc": float(val_auc),
        "test_auc": float(test_auc),
        "test_thr_metrics": test_thr_metrics,
        "fnr_by_sex": fnr_by_sex,
        "delta_signed": float(delta_f_minus_m),
        "delta_abs": float(delta_abs),
        "confusion_by_sex": confusion_by_sex,
        "run_dir": str(run_dir),
    }

# -------------------------
# Run all seeds + aggregate summary
# -------------------------
results = []
for seed in SEEDS:
    results.append(run_seed_once(seed))

# AUROC mean ± 95% CI (Student's t, n=3 seeds, df=2)
aurocs = [r["test_auc"] for r in results]
# Two-sided 95% critical value for df=2; valid only while SEEDS holds exactly 3 seeds
t_crit = 4.302652729911275
n = len(aurocs)
mean_auc = float(np.mean(aurocs))
std_auc = float(np.std(aurocs, ddof=1)) if n > 1 else 0.0
half_width = float(t_crit * (std_auc / math.sqrt(n))) if n > 1 else 0.0
ci95 = [float(mean_auc - half_width), float(mean_auc + half_width)]

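def _mean_sd(vals):
    # NaN-aware mean and sample SD (ddof=1); SD falls back to 0.0 with fewer than two finite values
    vals = np.asarray(vals, dtype=np.float64)
    n_valid = int(np.sum(~np.isnan(vals)))
    mean = float(np.nanmean(vals)) if n_valid > 0 else float("nan")
    sd = float(np.nanstd(vals, ddof=1)) if n_valid > 1 else 0.0
    return mean, sd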

# Aggregate threshold metrics at val-opt threshold
keys = ["accuracy","precision","recall","f1_score","sensitivity","specificity","mcc","p_value_fisher_two_sided"]
agg = {}
for k in keys:
    v = [float(r["test_thr_metrics"].get(k, float("nan"))) for r in results]
    mu, sd = _mean_sd(v)
    agg[k] = {
        "mean": mu,
        "sd": sd,
        "values_by_seed": {str(r["seed"]): float(r["test_thr_metrics"].get(k, float("nan"))) for r in results},
    }

cm_by_seed = {str(r["seed"]): r["test_thr_metrics"]["confusion_matrix"] for r in results}

# Aggregate thresholds selected on VAL
thr_vals = [float(r["val_selected_threshold"]) for r in results]
thr_mean, thr_sd = _mean_sd(thr_vals)

# FAIRNESS aggregation (H3) on TEST at val-opt threshold
fnr_m_vals, fnr_f_vals, n_pd_m_vals, n_pd_f_vals = [], [], [], []
d_signed_vals, d_abs_vals = [], []

fnr_by_seed = {}
delta_signed_by_seed = {}
delta_abs_by_seed = {}
confusion_by_sex_by_seed = {}
run_dirs = []

for r in results:
    s = str(r["seed"])
    fnr_by_seed[s] = r["fnr_by_sex"]
    delta_signed_by_seed[s] = float(r["delta_signed"])
    delta_abs_by_seed[s] = float(r["delta_abs"])
    confusion_by_sex_by_seed[s] = r["confusion_by_sex"]
    run_dirs.append(r["run_dir"])

    d = r["fnr_by_sex"]
    fnr_m_vals.append(float(d.get("M", {}).get("fnr", float("nan"))))
    fnr_f_vals.append(float(d.get("F", {}).get("fnr", float("nan"))))
    n_pd_m_vals.append(float(d.get("M", {}).get("n_pos", float("nan"))))
    n_pd_f_vals.append(float(d.get("F", {}).get("n_pos", float("nan"))))
    d_signed_vals.append(float(r["delta_signed"]))
    d_abs_vals.append(float(r["delta_abs"]))

fnr_m_mean, fnr_m_sd = _mean_sd(fnr_m_vals)
fnr_f_mean, fnr_f_sd = _mean_sd(fnr_f_vals)
d_signed_mean, d_signed_sd = _mean_sd(d_signed_vals)
d_abs_mean, d_abs_sd = _mean_sd(d_abs_vals)

print("\nTest AUROC by seed:")
for r in results:
    print(f"  seed {r['seed']}: {r['test_auc']:.6f}")
print(f"\nMean Test AUROC: {mean_auc:.6f}")
print(f"95% CI (t, n=3): [{ci95[0]:.6f}, {ci95[1]:.6f}]")

print("\nVAL-selected threshold (Youden J) across seeds (mean ± SD):")
print(f"  thr_val_opt: {thr_mean:.6f} ± {thr_sd:.6f}")
print("  values_by_seed:", {str(r["seed"]): float(r["val_selected_threshold"]) for r in results})

print("\nThreshold metrics at VAL-OPT threshold on TEST (mean ± SD across seeds):")
for k in ["accuracy","precision","sensitivity","specificity","f1_score","mcc"]:
    mu = agg[k]["mean"]
    sd = agg[k]["sd"]
    print(f"  {k}: {mu:.6f} ± {sd:.6f}")
print("  fisher_p_value_two_sided:", f"{agg['p_value_fisher_two_sided']['mean']:.6g} ± {agg['p_value_fisher_two_sided']['sd']:.6g}")

print("\nFAIRNESS (H3) on TEST @ VAL-OPT threshold (mean ± SD):")
print(f"  FNR_M: {fnr_m_mean:.6f} ± {fnr_m_sd:.6f}")
print(f"  FNR_F: {fnr_f_mean:.6f} ± {fnr_f_sd:.6f}")
print(f"  ΔFNR (F-M): {d_signed_mean:.6f} ± {d_signed_sd:.6f}")
print(f"  |ΔFNR|: {d_abs_mean:.6f} ± {d_abs_sd:.6f}")
print("  Per-seed:", {
    str(r["seed"]): {
        "thr_val_opt": float(r["val_selected_threshold"]),
        "FNR_M": fnr_m_vals[i],
        "n_PD_M": n_pd_m_vals[i],
        "FNR_F": fnr_f_vals[i],
        "n_PD_F": n_pd_f_vals[i],
        "delta_F_minus_M": d_signed_vals[i],
        "abs_delta": d_abs_vals[i],
    } for i, r in enumerate(results)
})

summary = {
    "dataset": dataset_id,
    "dx_out_root": DX_OUT_ROOT,
    "manifest_all": MANIFEST_ALL,
    "trainval_experiment_used": str(chosen_exp),
    "seeds": SEEDS,

    "n_val": int(len(val_df)),
    "label_counts_val": val_df["label_num"].value_counts(dropna=False).to_dict(),
    "sex_counts_val_norm": val_df["sex_norm"].value_counts(dropna=False).to_dict(),

    "n_test": int(len(test_df)),
    "label_counts_test": test_df["label_num"].value_counts(dropna=False).to_dict(),
    "sex_counts_test_norm": test_df["sex_norm"].value_counts(dropna=False).to_dict(),

    "threshold_selection": {
        "method": "youden_j_on_val",
        "val_selected_threshold_by_seed": {str(r["seed"]): float(r["val_selected_threshold"]) for r in results},
        "val_selected_threshold_mean_sd": {"mean": float(thr_mean), "sd": float(thr_sd)},
        "note": "Threshold selected on VAL only (Youden J). TEST metrics computed at this fixed threshold to avoid test-tuning.",
    },

    "test_aurocs_by_seed": {str(r["seed"]): float(r["test_auc"]) for r in results},
    "mean_test_auroc": float(mean_auc),
    "t_crit_df2_95": float(t_crit),
    "half_width_95_ci": float(half_width),
    "ci95_test_auroc": ci95,

    "threshold_metrics_mean_sd": agg,
    "confusion_matrix_by_seed": cm_by_seed,
    "run_dirs": run_dirs,

    "fairness_test": {
        "definition": "H3: ΔFNR = FNR(F) - FNR(M), where FNR(sex)=FN/(FN+TP) computed on PD-only true labels at test_threshold.",
        "fnr_by_sex_norm_by_seed": fnr_by_seed,
        "delta_fnr_F_minus_M_by_seed": delta_signed_by_seed,
        "delta_fnr_abs_by_seed": delta_abs_by_seed,
        "fnr_M_mean_sd": {"mean": float(fnr_m_mean), "sd": float(fnr_m_sd), "values_by_seed": {str(s): float(v) for s, v in zip(SEEDS, fnr_m_vals)}},
        "fnr_F_mean_sd": {"mean": float(fnr_f_mean), "sd": float(fnr_f_sd), "values_by_seed": {str(s): float(v) for s, v in zip(SEEDS, fnr_f_vals)}},
        "delta_fnr_F_minus_M_mean_sd": {"mean": float(d_signed_mean), "sd": float(d_signed_sd), "values_by_seed": {str(s): float(v) for s, v in zip(SEEDS, d_signed_vals)}},
        "delta_fnr_abs_mean_sd": {"mean": float(d_abs_mean), "sd": float(d_abs_sd), "values_by_seed": {str(s): float(v) for s, v in zip(SEEDS, d_abs_vals)}},
        "denominators_PD_by_seed": {str(s): {"n_PD_M": float(n_pd_m_vals[i]), "n_PD_F": float(n_pd_f_vals[i])} for i, s in enumerate(SEEDS)},
        "sex_normalization_note": "sex_norm in {M,F,UNK}. Values not mapped to M/F counted as UNK. ΔFNR computed only when both M and F have defined FNR.",
    },

    "confusion_by_sex_norm_by_seed": confusion_by_sex_by_seed,

    "timestamp": time.strftime("%Y%m%d_%H%M%S"),
}

summary_path = TEST_ROOT / "summary_test.json"
with open(summary_path, "w", encoding="utf-8") as f:
    json.dump(summary, f, indent=2)

history_path = TEST_ROOT / "history_index.jsonl"
with open(history_path, "a", encoding="utf-8") as f:
    f.write(json.dumps(summary) + "\n")

print("\nWROTE summary:", str(summary_path))
print("APPENDED history index:", str(history_path))
print("Open this folder to access artifacts:", str(TEST_ROOT))

# -------------------------
# Stop runtime (release L4)
# -------------------------
print("\nAll done. Unassigning the runtime to stop the L4 instance...")
try:
    from google.colab import runtime  # type: ignore
    print("Calling runtime.unassign() now.")
    runtime.unassign()
except Exception as e:
    print("Could not unassign runtime automatically. You can stop the runtime manually in Colab.")
    print("Reason:", repr(e))

The following cell runs the final D4 (Italian, IPVS) test evaluation with fixed seeds and fail-fast input checks, using the trained model heads from the most recent train and validation experiment. It reads `manifests/manifest_all.csv` from the D4 output folder and keeps only the validation split, which is used to choose a decision threshold, and the test split, which is used to report performance. It checks that the required columns (`split`, `clip_path`, `label_num`, `task`, `sex`) are present; the `sex` column is required because it drives the fairness checks and the sex-specific confusion matrices. Before any model work begins, the cell confirms that the manifest exists, that the validation and test splits are not empty, and that the audio files listed in `clip_path` are present on disk.

After setup, including fixed random seeds, batch settings, GPU mixed precision settings, and a required 16 kHz sample rate, the cell prepares the data for inference. Each clip is assigned to one of two task groups using a strict rule: `task == "vowl"` is treated as vowel speech and everything else is treated as other speech. Sex labels are normalized to M, F, or UNK by matching common text forms (for example `m`, `male`, `f`, `female`); numeric encodings are not mapped and are counted as UNK. A dataset class loads each audio file, converts it to mono, raises an error if the sample rate is not 16 kHz, and builds a sample-level attention mask. For vowel clips, the mask marks the trailing near-silent samples (absolute amplitude at or below `TINY_THRESH`) as padding so they do not affect pooling; for other clips, the mask is all ones. A collator zero-pads audio and attention masks in each batch to a common length. A short warm-up step loads up to two batches from both validation and test to catch data loading or file issues early, including edge cases with very small splits.
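
The masking and padding steps described above can be illustrated with a small sketch. It uses plain Python lists instead of NumPy arrays and PyTorch tensors so that it runs on its own; the function names `trailing_silence_mask` and `pad_batch` and the sample values are hypothetical and chosen only for illustration. The threshold value matches the `TINY_THRESH` setting in the cell.

```python
TINY_THRESH = 1e-4  # same value as the TINY_THRESH setting in the cell

def trailing_silence_mask(samples, tiny_thresh=TINY_THRESH):
    """Return a 0/1 mask that zeroes the near-silent tail of a clip."""
    last_loud = -1
    # Scan from the end for the last sample whose magnitude exceeds the threshold
    for j in range(len(samples) - 1, -1, -1):
        if abs(samples[j]) > tiny_thresh:
            last_loud = j
            break
    if last_loud < 0:
        # Whole clip is near-silent: keep every sample instead of masking all of it
        return [1] * len(samples)
    return [1] * (last_loud + 1) + [0] * (len(samples) - last_loud - 1)

def pad_batch(clips, masks):
    """Right-pad clips and masks with zeros up to the longest clip in the batch."""
    max_len = max(len(c) for c in clips)
    padded_clips = [list(c) + [0.0] * (max_len - len(c)) for c in clips]
    padded_masks = [list(m) + [0] * (max_len - len(m)) for m in masks]
    return padded_clips, padded_masks

# Toy clips (illustrative values only)
vowel_clip = [0.2, -0.3, 0.1, 0.00001, 0.0]  # last two samples are near-silent
other_clip = [0.05, 0.0, 0.0]                # non-vowel clips keep an all-ones mask

masks = [trailing_silence_mask(vowel_clip), [1] * len(other_clip)]
clips, masks = pad_batch([vowel_clip, other_clip], masks)
print(masks)
```

In this toy batch both masks end in zeros, but for different reasons: the vowel clip's zeros come from the silence rule, while the shorter clip's zeros come from batch padding. Either way, the masked positions are excluded when the model mean-pools its features.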

The model used for evaluation matches the training setup. It consists of a frozen Wav2Vec2 backbone (`facebook/wav2vec2-base`) with two linear classification heads, one for vowel clips and one for other clips, each preceded by LayerNorm and Dropout. The cell then finds the most recent train and validation experiment folder that contains the required `best_heads.pt` files for all three seeds. For each seed (1337, 2024, and 7777), it loads the saved head weights and runs inference on the validation split to compute a validation-optimal threshold using Youden’s J statistic, which maximizes true positive rate minus false positive rate. It then runs inference on the test split and applies that same threshold rather than a fixed value of 0.5.
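
As a worked illustration of the threshold step, the sketch below picks the Youden J threshold on a toy validation set and then applies it unchanged to toy test scores. It is a minimal pure-Python version written for this note (the cell itself relies on scikit-learn's `roc_curve`); the function name `youden_j_threshold` and all scores are hypothetical and are not results from the study.

```python
def youden_j_threshold(y_true, y_prob):
    """Return (threshold, J) maximizing J = TPR - FPR.

    A sample is predicted positive (PD) when its score is >= threshold.
    Returns (nan, nan) when only one class is present.
    """
    n_pos = sum(1 for y in y_true if y == 1)
    n_neg = len(y_true) - n_pos
    if n_pos == 0 or n_neg == 0:
        return float("nan"), float("nan")

    best_thr, best_j = None, float("-inf")
    # Every distinct score is a candidate threshold, scanned from high to low
    for thr in sorted(set(y_prob), reverse=True):
        tp = sum(1 for y, p in zip(y_true, y_prob) if y == 1 and p >= thr)
        fp = sum(1 for y, p in zip(y_true, y_prob) if y == 0 and p >= thr)
        j = tp / n_pos - fp / n_neg
        if j > best_j:  # strict comparison keeps the highest threshold on ties
            best_thr, best_j = thr, j
    return best_thr, best_j

# Toy validation split: 4 healthy (0) and 4 PD (1) clips, illustrative scores only
val_labels = [0, 0, 0, 0, 1, 1, 1, 1]
val_probs = [0.10, 0.20, 0.35, 0.60, 0.40, 0.55, 0.70, 0.90]
thr, j = youden_j_threshold(val_labels, val_probs)

# The threshold is then frozen and applied to the test scores
test_probs = [0.20, 0.45, 0.80]
test_preds = [int(p >= thr) for p in test_probs]
print(thr, j, test_preds)
```

Selecting the threshold on validation data and freezing it before touching the test split is what keeps the reported test metrics free of test-set tuning.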

For each seed, results are saved under
`<DX_OUT_ROOT>/monolingual_test_runs/run_<DATASET>_seed<seed>/`.
Saved outputs include a ROC curve image for the test split, a confusion matrix for the test split at the validation-selected threshold, additional confusion matrices split by M and F when those groups are present, and a `metrics.json` file containing dataset counts, the selected threshold and its details, AUROC values, threshold-based metrics, fairness results, and paths to the saved plots.

After all three seeds finish, the cell combines results across seeds and prints a short summary. This includes the mean test AUROC with a 95% confidence interval computed from a Student’s t distribution with n = 3 seeds (2 degrees of freedom), along with the mean and standard deviation of the validation-optimal thresholds, the threshold-based test metrics, and the H3 fairness measure. The fairness measure reports false negative rates by sex, computed on true Parkinson’s disease cases only, and defines ΔFNR as FNR(F) minus FNR(M). A combined `summary_test.json` file is written, the same summary is appended to `history_index.jsonl`, the output locations are printed, and the Colab runtime is unassigned to stop the GPU instance.
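
The two summary statistics above reduce to short formulas: the interval is mean ± t × SD / √n with the sample SD (ddof = 1), and FNR is FN / (FN + TP) on true PD rows. The sketch below works through both on toy inputs. The helper names (`mean_ci95_three_seeds`, `fnr`, `delta_fnr`) and every number in the example are hypothetical and for illustration only; the critical value is the df = 2 constant used in the notebook.

```python
import math
import statistics

T_CRIT_DF2_95 = 4.302652729911275  # two-sided 95% t critical value, df = 2 (n = 3)

def mean_ci95_three_seeds(values):
    """Mean and 95% CI bounds for exactly three per-seed values."""
    if len(values) != 3:
        raise ValueError("the df=2 critical value only applies to n=3 values")
    mean = statistics.fmean(values)
    sd = statistics.stdev(values)  # sample SD (ddof=1)
    half_width = T_CRIT_DF2_95 * sd / math.sqrt(len(values))
    return mean, mean - half_width, mean + half_width

def fnr(labels, preds):
    """FN / (FN + TP) on true PD rows only; NaN when there are no PD rows."""
    tp = sum(1 for y, p in zip(labels, preds) if y == 1 and p == 1)
    fn = sum(1 for y, p in zip(labels, preds) if y == 1 and p == 0)
    return fn / (fn + tp) if (fn + tp) > 0 else float("nan")

def delta_fnr(labels, preds, sexes):
    """Per-sex FNR and the signed gap FNR(F) - FNR(M)."""
    by_sex = {}
    for g in ("M", "F"):
        rows = [(y, p) for y, p, s in zip(labels, preds, sexes) if s == g]
        by_sex[g] = fnr([y for y, _ in rows], [p for _, p in rows])
    return by_sex, by_sex["F"] - by_sex["M"]

# Illustrative per-seed AUROCs (not results from the study)
aurocs = [0.80, 0.85, 0.90]
mean_auc, ci_low, ci_high = mean_ci95_three_seeds(aurocs)
print(f"mean={mean_auc:.3f}, CI=[{ci_low:.3f}, {ci_high:.3f}]")

# Illustrative predictions: three PD men, three PD women, two healthy controls
labels = [1, 1, 1, 1, 1, 1, 0, 0]
preds = [1, 1, 0, 1, 0, 0, 0, 1]
sexes = ["M", "M", "M", "F", "F", "F", "M", "F"]
by_sex, delta = delta_fnr(labels, preds, sexes)
print(by_sex, delta)
```

With only three seeds the t critical value is large, so the interval is wide even when the per-seed values are close. A positive ΔFNR means PD cases are missed more often among women than among men at the chosen threshold.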

In [None]:
# =========================
# D4 Monolingual TEST (VAL-OPT THRESHOLD via Youden J) + Fairness by Sex
# What this cell does:
# - Reads one manifest (VAL + TEST) for a single dataset
# - For each seed:
#   1) Runs inference on VAL to pick a threshold using Youden J (TPR − FPR)
#   2) Runs inference on TEST and reports metrics at that VAL-picked threshold
# - Also saves ROC and confusion matrix plots (overall + by sex when available)
# Inputs:
# - manifest_all.csv (must include VAL and TEST rows)
# - best_heads.pt for each seed from the most recent trainval experiment
# Outputs:
# - Per-seed run folder with metrics.json and plots
# - summary_test.json (aggregated across seeds) + history_index.jsonl
# =========================
#
# NOTE FOR D4:
# - Manifest sex is typically "M" and "F".
# - Sex is normalized to M/F/UNK. Numeric encodings are not auto-mapped.
# =========================

import os, json, math, random, time, warnings
from pathlib import Path

import numpy as np
import pandas as pd
import soundfile as sf

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

from transformers import Wav2Vec2Model

from sklearn.metrics import roc_auc_score, roc_curve, confusion_matrix, matthews_corrcoef
from scipy.stats import fisher_exact
import matplotlib.pyplot as plt

from tqdm.auto import tqdm

# -------------------------
# Safety check: avoid local files that override real libraries
# -------------------------
if os.path.exists("/content/torch.py") or os.path.exists("/content/torch/__init__.py"):
    raise RuntimeError("Found local /content/torch.py or /content/torch/ that shadows PyTorch. Rename/remove it and restart runtime.")
if os.path.exists("/content/transformers.py") or os.path.exists("/content/transformers/__init__.py"):
    raise RuntimeError("Found local /content/transformers.py or /content/transformers/ that shadows Hugging Face Transformers. Rename/remove it and restart runtime.")

# -------------------------
# Colab Drive: mount if missing
# -------------------------
if not os.path.isdir("/content/drive/MyDrive"):
    from google.colab import drive  # type: ignore
    drive.mount("/content/drive")

# -------------------------
# Paths: use runtime DX_OUT_ROOT if present, otherwise fallback to D4
# Inputs:
# - DX_OUT_ROOT/manifests/manifest_all.csv
# -------------------------
D4_OUT_ROOT_FALLBACK = "/content/drive/MyDrive/AI_PD_Project/Datasets/D4-Italian (IPVS)/preprocessed_v1"
DX_OUT_ROOT = globals().get("DX_OUT_ROOT", D4_OUT_ROOT_FALLBACK)
MANIFEST_ALL = f"{DX_OUT_ROOT}/manifests/manifest_all.csv"

# -------------------------
# Settings (kept consistent with trainval where relevant)
# -------------------------
SEEDS          = [1337, 2024, 7777]
BACKBONE_CKPT  = "facebook/wav2vec2-base"
SR_EXPECTED    = 16000
TINY_THRESH    = 1e-4

EFFECTIVE_BS   = 64
PER_DEVICE_BS  = 16
GRAD_ACCUM     = max(1, EFFECTIVE_BS // PER_DEVICE_BS)  # printed only

DROPOUT_P      = 0.2

NUM_WORKERS    = 0
PIN_MEMORY     = False

USE_AMP        = True

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
try:
    torch.set_float32_matmul_precision("high")
except Exception:
    pass

# Quiet down known, non-actionable warnings for this test cell
warnings.filterwarnings("ignore", message="None of the inputs have requires_grad=True. Gradients will be None")
warnings.filterwarnings("ignore", message="Passing `gradient_checkpointing` to a config initialization is deprecated")

print("DX_OUT_ROOT:", DX_OUT_ROOT)
print("MANIFEST_ALL:", MANIFEST_ALL)
print("DEVICE:", DEVICE)
if DEVICE.type == "cuda":
    print("GPU:", torch.cuda.get_device_name(0))
print("PER_DEVICE_BS:", PER_DEVICE_BS, "| EFFECTIVE_BS:", PER_DEVICE_BS * GRAD_ACCUM)
print("NUM_WORKERS:", NUM_WORKERS, "| PIN_MEMORY:", PIN_MEMORY)
print("USE_AMP:", bool(USE_AMP and DEVICE.type == "cuda"))

# -------------------------
# Load manifest and prepare VAL/TEST tables
# Inputs:
# - manifest_all.csv with required columns
# Outputs:
# - val_df, test_df filtered to split == val/test
# - dataset_id inferred from manifest (if available)
# -------------------------
if not os.path.exists(MANIFEST_ALL):
    raise FileNotFoundError(f"Missing manifest_all.csv: {MANIFEST_ALL}")

m_all = pd.read_csv(MANIFEST_ALL)

# Required for: splitting, audio loading, labels, fairness and sex plots
req_cols = {"split", "clip_path", "label_num", "task", "sex"}
missing = [c for c in sorted(req_cols) if c not in m_all.columns]
if missing:
    raise ValueError(f"Manifest missing required columns: {missing}. Found: {list(m_all.columns)}")

# Infer dataset id from the dominant value in 'dataset' (if present)
if "dataset" in m_all.columns and m_all["dataset"].notna().any():
    # Drop NaN before casting to str so a literal "nan" can never become the dominant dataset id
    dataset_id = str(m_all["dataset"].dropna().astype(str).value_counts().idxmax())
    m_all = m_all[m_all["dataset"].astype(str) == dataset_id].copy()
else:
    dataset_id = "DX"

# Keep a compact set of columns (adds placeholders if missing)
keep_cols = ["clip_path", "label_num", "task", "speaker_id", "sex", "duration_sec", "split"]
for c in keep_cols:
    if c not in m_all.columns:
        m_all[c] = np.nan
m_all = m_all[keep_cols].copy()

val_df  = m_all[m_all["split"].isin(["val"])].copy().reset_index(drop=True)
test_df = m_all[m_all["split"].isin(["test"])].copy().reset_index(drop=True)

print(f"\nDataset inferred: {dataset_id}")
print(f"VAL rows:  {len(val_df)}")
print(f"TEST rows: {len(test_df)}")

# VAL is required here because the threshold is chosen on VAL (Youden J)
if len(val_df) == 0:
    raise RuntimeError("After filtering to split=='val', manifest has 0 rows. VAL is required to compute VAL-opt threshold (Youden J).")
if len(test_df) == 0:
    raise RuntimeError("After filtering to split=='test', manifest has 0 rows.")

print("VAL label counts:",  val_df["label_num"].value_counts(dropna=False).to_dict())
print("TEST label counts:", test_df["label_num"].value_counts(dropna=False).to_dict())
print("VAL sex counts (raw):",  val_df["sex"].value_counts(dropna=False).to_dict())
print("TEST sex counts (raw):", test_df["sex"].value_counts(dropna=False).to_dict())

# -------------------------
# Fast failure: verify clip_path files exist (shows progress)
# Output: raises early if missing audio is found
# -------------------------
def _fail_fast_missing_paths(df: pd.DataFrame, name: str):
    missing_paths = []
    for p in tqdm(df["clip_path"].astype(str).tolist(), desc=f"Check {name} clip_path exists", dynamic_ncols=True):
        if not os.path.exists(p):
            missing_paths.append(p)
            if len(missing_paths) >= 10:
                break
    if missing_paths:
        raise FileNotFoundError(f"{name}: missing clip_path(s). Examples: {missing_paths}")

_fail_fast_missing_paths(val_df, "VAL")
_fail_fast_missing_paths(test_df, "TEST")

# -------------------------
# Task grouping (exact rule used by this project)
# Output: task_group in {"vowel","other"}
# -------------------------
def _task_group(task_val) -> str:
    return "vowel" if str(task_val) == "vowl" else "other"

val_df["task_group"]  = val_df["task"].apply(_task_group)
test_df["task_group"] = test_df["task"].apply(_task_group)

# -------------------------
# Sex normalization for fairness + sex-specific confusion charts
# Output: sex_norm in {"M","F","UNK"}
# -------------------------
def normalize_sex(val) -> str:
    """
    Returns 'M', 'F', or 'UNK'.
    Maps common text forms (M/F, Male/Female, etc).
    Numeric encodings are not auto-mapped to avoid silent mix-ups.
    """
    if pd.isna(val):
        return "UNK"
    s = str(val).strip().lower()
    if s in {"m", "male", "man", "masc", "masculine"}:
        return "M"
    if s in {"f", "female", "woman", "fem", "feminine"}:
        return "F"
    return "UNK"

val_df["sex_norm"]  = val_df["sex"].apply(normalize_sex)
test_df["sex_norm"] = test_df["sex"].apply(normalize_sex)

print("VAL sex counts (normalized):",  val_df["sex_norm"].value_counts(dropna=False).to_dict())
print("TEST sex counts (normalized):", test_df["sex_norm"].value_counts(dropna=False).to_dict())
if (val_df["sex_norm"] == "UNK").any() or (test_df["sex_norm"] == "UNK").any():
    print("NOTE: Some 'sex' values could not be normalized to M/F and were counted as 'UNK' for fairness and sex charts.")

# -------------------------
# Dataset + collator
# Input: df with clip_path, label_num, task_group, sex_norm
# Output: padded batches with attention_mask
# Notes:
# - For vowel clips, attention_mask trims trailing near-silence using TINY_THRESH
# - For other clips, attention_mask is all ones
# -------------------------
class AudioManifestDataset(Dataset):
    def __init__(self, df: pd.DataFrame):
        self.df = df.reset_index(drop=True)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx: int):
        row = self.df.iloc[idx]
        clip_path = str(row["clip_path"])
        label = int(row["label_num"])
        task_group = str(row["task_group"])
        sex_norm = str(row["sex_norm"])

        if not os.path.exists(clip_path):
            raise FileNotFoundError(f"Missing clip_path: {clip_path}")

        y, sr = sf.read(clip_path, always_2d=False)
        if isinstance(y, np.ndarray) and y.ndim > 1:
            y = np.mean(y, axis=1)
        y = np.asarray(y, dtype=np.float32)

        if int(sr) != SR_EXPECTED:
            raise ValueError(f"Sample rate mismatch (expected {SR_EXPECTED}): {clip_path} has sr={sr}")

        # attention_mask is in sample space, then converted inside Wav2Vec2 to feature-mask
        attn = np.ones((len(y),), dtype=np.int64)
        if task_group == "vowel":
            # mark trailing near-zero tail as padding so it does not affect pooling
            # (vectorized search for the last sample whose magnitude exceeds TINY_THRESH)
            loud_idx = np.flatnonzero(np.abs(y.astype(np.float64)) > TINY_THRESH)
            if loud_idx.size > 0:
                attn[int(loud_idx[-1]) + 1:] = 0
            # if no sample exceeds the threshold, the all-ones mask is kept

        return {
            "input_values": torch.from_numpy(y),
            "attention_mask": torch.from_numpy(attn),
            "labels": torch.tensor(label, dtype=torch.long),
            "task_group": task_group,
            "sex_norm": sex_norm,
        }

def collate_fn(batch):
    # Pads to the max length in the batch (zeros for audio and attention)
    max_len = int(max(b["input_values"].numel() for b in batch))
    input_vals, attn_masks, labels, task_groups, sex_norms = [], [], [], [], []
    for b in batch:
        x = b["input_values"]
        a = b["attention_mask"]
        pad = max_len - x.numel()
        if pad > 0:
            x = torch.cat([x, torch.zeros(pad, dtype=x.dtype)], dim=0)
            a = torch.cat([a, torch.zeros(pad, dtype=a.dtype)], dim=0)
        input_vals.append(x)
        attn_masks.append(a)
        labels.append(b["labels"])
        task_groups.append(b["task_group"])
        sex_norms.append(b["sex_norm"])
    return {
        "input_values": torch.stack(input_vals, dim=0),
        "attention_mask": torch.stack(attn_masks, dim=0),
        "labels": torch.stack(labels, dim=0),
        "task_group": task_groups,
        "sex_norm": sex_norms,
    }

# -------------------------
# Model: frozen wav2vec2 backbone + two task-specific heads
# Output: logits for PD vs Healthy, chosen per item by task_group
# -------------------------
class Wav2Vec2TwoHeadClassifier(nn.Module):
    def __init__(self, ckpt: str, dropout_p: float):
        super().__init__()
        self.backbone = Wav2Vec2Model.from_pretrained(ckpt, use_safetensors=True)
        for p in self.backbone.parameters():
            p.requires_grad = False

        H = int(self.backbone.config.hidden_size)
        self.pre_vowel = nn.Sequential(nn.LayerNorm(H), nn.Dropout(float(dropout_p)))
        self.pre_other = nn.Sequential(nn.LayerNorm(H), nn.Dropout(float(dropout_p)))
        self.head_vowel = nn.Linear(H, 2)
        self.head_other = nn.Linear(H, 2)

    def masked_mean_pool(self, last_hidden, attn_mask_samples):
        # Converts sample-level mask to feature-level mask, then mean-pools masked features
        feat_mask = self.backbone._get_feature_vector_attention_mask(last_hidden.shape[1], attn_mask_samples)
        feat_mask = feat_mask.to(last_hidden.device).unsqueeze(-1).type_as(last_hidden)
        masked = last_hidden * feat_mask
        denom = feat_mask.sum(dim=1).clamp(min=1.0)
        return masked.sum(dim=1) / denom

    def _heads_fp32(self, x_fp_any: torch.Tensor, head: nn.Module) -> torch.Tensor:
        # Keeps heads in float32 even when AMP is enabled
        x = x_fp_any.float()
        if x.is_cuda:
            with torch.amp.autocast(device_type="cuda", enabled=False):
                return head(x)
        return head(x)

    def forward_logits(self, input_values, attention_mask, task_group):
        # Backbone is frozen: inference-only for features
        with torch.no_grad():
            out = self.backbone(input_values=input_values, attention_mask=attention_mask)
            last_hidden = out.last_hidden_state

        pooled = self.masked_mean_pool(last_hidden, attention_mask)  # [B,H]

        z_v = self.pre_vowel(pooled.float())
        z_o = self.pre_other(pooled.float())

        logits_v = self._heads_fp32(z_v, self.head_vowel)
        logits_o = self._heads_fp32(z_o, self.head_other)

        # Pick the matching head for each item in the batch
        is_vowel = torch.tensor([tg == "vowel" for tg in task_group], device=pooled.device, dtype=torch.bool)
        logits = logits_o.clone()
        if is_vowel.any():
            logits[is_vowel] = logits_v[is_vowel]
        return logits

# -------------------------
# Metrics and plots
# Output:
# - AUROC (threshold-free)
# - Threshold metrics (confusion-based) at chosen threshold
# - ROC and confusion matrix PNGs
# -------------------------
def compute_auc(y_true, y_prob):
    y_true = np.asarray(y_true, dtype=np.int64)
    y_prob = np.asarray(y_prob, dtype=np.float64)
    if len(np.unique(y_true)) < 2:
        return float("nan")
    return float(roc_auc_score(y_true, y_prob))

def compute_threshold_metrics(y_true, y_prob, thr=0.5):
    y_true = np.asarray(y_true, dtype=np.int64)
    y_prob = np.asarray(y_prob, dtype=np.float64)
    y_pred = (y_prob >= float(thr)).astype(np.int64)

    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
    TN, FP, FN, TP = int(cm[0, 0]), int(cm[0, 1]), int(cm[1, 0]), int(cm[1, 1])

    eps = 1e-12
    acc = (TP + TN) / max(1, (TP + TN + FP + FN))
    prec = TP / (TP + FP + eps)
    rec = TP / (TP + FN + eps)     # sensitivity
    f1 = 2 * prec * rec / (prec + rec + eps)
    spec = TN / (TN + FP + eps)

    try:
        mcc = float(matthews_corrcoef(y_true, y_pred)) if len(np.unique(y_pred)) > 1 else float("nan")
    except Exception:
        mcc = float("nan")

    try:
        _, pval = fisher_exact([[TN, FP], [FN, TP]], alternative="two-sided")
        pval = float(pval)
    except Exception:
        pval = float("nan")

    return {
        "threshold": float(thr),
        "confusion_matrix": {"TN": TN, "FP": FP, "FN": FN, "TP": TP},
        "accuracy": float(acc),
        "precision": float(prec),
        "recall": float(rec),
        "f1_score": float(f1),
        "sensitivity": float(rec),
        "specificity": float(spec),
        "mcc": float(mcc),
        "p_value_fisher_two_sided": float(pval),
    }

def save_roc_curve_png(y_true, y_prob, out_png, title_suffix="Test"):
    fpr, tpr, _ = roc_curve(y_true, y_prob)
    plt.figure()
    plt.plot(fpr, tpr)
    plt.plot([0, 1], [0, 1], linestyle="--")
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    plt.title(f"ROC Curve ({title_suffix})")
    plt.tight_layout()
    plt.savefig(out_png, dpi=150)
    plt.close()

def save_confusion_png(y_true, y_prob, out_png, thr=0.5, title_suffix="Test"):
    y_true = np.asarray(y_true, dtype=np.int64)
    y_prob = np.asarray(y_prob, dtype=np.float64)
    y_pred = (y_prob >= float(thr)).astype(int)
    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])

    plt.figure()
    plt.imshow(cm)
    plt.xticks([0, 1], ["Healthy(0)", "PD(1)"])
    plt.yticks([0, 1], ["Healthy(0)", "PD(1)"])
    plt.xlabel("Predicted")
    plt.ylabel("True")
    plt.title(f"Confusion Matrix ({title_suffix}, thr={thr:.2f})")
    for (i, j), v in np.ndenumerate(cm):
        plt.text(j, i, str(v), ha="center", va="center")
    plt.tight_layout()
    plt.savefig(out_png, dpi=150)
    plt.close()

# -------------------------
# VAL-opt threshold (Youden J): maximize (TPR - FPR) on VAL
# Output: thr_opt plus details for logging
# -------------------------
def compute_val_opt_threshold_youden_j(y_true, y_prob):
    """
    Picks the ROC threshold that maximizes J = TPR - FPR on VAL.
    Returns: thr_opt, J_opt, tpr_opt, fpr_opt
    """
    y_true = np.asarray(y_true, dtype=np.int64)
    y_prob = np.asarray(y_prob, dtype=np.float64)

    if len(np.unique(y_true)) < 2:
        return float("nan"), float("nan"), float("nan"), float("nan")

    fpr, tpr, thr = roc_curve(y_true, y_prob)
    J = tpr - fpr
    idx = int(np.argmax(J))
    return float(thr[idx]), float(J[idx]), float(tpr[idx]), float(fpr[idx])

# -------------------------
# Fairness metric (H3): FNR by sex and ΔFNR = FNR(F) - FNR(M)
# Notes:
# - FNR is computed only on true PD samples (label == 1)
# - UNK is tracked but ΔFNR is only defined when both M and F exist
# -------------------------
def compute_fnr_by_group_signed(y_true, y_prob, groups, thr=0.5):
    """
    FNR = FN/(FN+TP) on true PD samples only.
    Returns:
    - per-group stats (n_total, n_pos, tp, fn, fnr)
    - delta_f_minus_m: FNR(F) - FNR(M) when both defined
    - delta_abs_m_f: |FNR(F) - FNR(M)|
    """
    y_true = np.asarray(y_true, dtype=np.int64)
    y_prob = np.asarray(y_prob, dtype=np.float64)
    groups = np.asarray(list(groups), dtype=object)
    y_pred = (y_prob >= float(thr)).astype(np.int64)

    out = {}
    for g in sorted(set(groups.tolist())):
        mask_g = (groups == g)
        if int(mask_g.sum()) == 0:
            continue

        pos_mask = mask_g & (y_true == 1)
        n_pos = int(pos_mask.sum())
        if n_pos == 0:
            out[g] = {"n_total": int(mask_g.sum()), "n_pos": 0, "tp": 0, "fn": 0, "fnr": float("nan")}
            continue

        tp = int(((y_pred == 1) & pos_mask).sum())
        fn = int(((y_pred == 0) & pos_mask).sum())
        fnr = float(fn / max(1, (fn + tp)))

        out[g] = {"n_total": int(mask_g.sum()), "n_pos": int(n_pos), "tp": int(tp), "fn": int(fn), "fnr": float(fnr)}

    fnr_m = out.get("M", {}).get("fnr", float("nan"))
    fnr_f = out.get("F", {}).get("fnr", float("nan"))

    if (not np.isnan(fnr_m)) and (not np.isnan(fnr_f)):
        delta_signed = float(fnr_f - fnr_m)   # H3
        delta_abs = float(abs(delta_signed))
    else:
        delta_signed = float("nan")
        delta_abs = float("nan")

    return out, delta_signed, delta_abs

# -------------------------
# Confusion counts by group (M/F/UNK), useful for deeper debugging
# -------------------------
def compute_confusion_counts(y_true, y_prob, thr=0.5):
    y_true = np.asarray(y_true, dtype=np.int64)
    y_prob = np.asarray(y_prob, dtype=np.float64)
    y_pred = (y_prob >= float(thr)).astype(np.int64)
    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
    TN, FP, FN, TP = int(cm[0, 0]), int(cm[0, 1]), int(cm[1, 0]), int(cm[1, 1])
    return {"TN": TN, "FP": FP, "FN": FN, "TP": TP}

def compute_confusion_by_group(y_true, y_prob, groups, thr=0.5):
    y_true = np.asarray(y_true, dtype=np.int64)
    y_prob = np.asarray(y_prob, dtype=np.float64)
    groups = np.asarray(list(groups), dtype=object)
    out = {}
    for g in sorted(set(groups.tolist())):
        mask = (groups == g)
        if int(mask.sum()) == 0:
            continue
        out[g] = {
            "n": int(mask.sum()),
            "confusion": compute_confusion_counts(y_true[mask], y_prob[mask], thr=thr),
        }
    return out

# -------------------------
# Seed control for repeatable inference
# -------------------------
def set_all_seeds(seed: int):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# -------------------------
# Locate the most recent trainval experiment with all 3 best_heads.pt files
# Input: DX_OUT_ROOT/trainval_runs/exp_*
# Output: chosen_exp (folder path)
# -------------------------
TRAINVAL_ROOT = Path(DX_OUT_ROOT) / "trainval_runs"
if not TRAINVAL_ROOT.exists():
    raise FileNotFoundError(f"Missing trainval_runs folder: {str(TRAINVAL_ROOT)}")

exp_dirs = sorted([p for p in TRAINVAL_ROOT.glob("exp_*") if p.is_dir()], key=lambda p: p.stat().st_mtime, reverse=True)
if not exp_dirs:
    raise FileNotFoundError(f"No exp_* folders found under: {str(TRAINVAL_ROOT)}")

def _has_all_seeds(exp_path: Path, dataset_id: str, seeds: list):
    for s in seeds:
        p = exp_path / f"run_{dataset_id}_seed{s}" / "best_heads.pt"
        if not p.exists():
            return False
    return True

chosen_exp = None
for ed in exp_dirs:
    if _has_all_seeds(ed, dataset_id, SEEDS):
        chosen_exp = ed
        break

if chosen_exp is None:
    sample = exp_dirs[0]
    raise FileNotFoundError(
        "Could not find a trainval experiment with all 3 best_heads.pt files.\n"
        f"Expected under: {str(TRAINVAL_ROOT)}/exp_*/run_{dataset_id}_seedXXXX/best_heads.pt\n"
        f"Most recent exp checked: {str(sample)}"
    )

print("\nUsing Train+Val experiment folder:")
print(" ", str(chosen_exp))

# -------------------------
# Output root for this test run
# Output: DX_OUT_ROOT/monolingual_test_runs/...
# -------------------------
TEST_ROOT = Path(DX_OUT_ROOT) / "monolingual_test_runs"
TEST_ROOT.mkdir(parents=True, exist_ok=True)

# -------------------------
# Build VAL and TEST loaders and do a small warm-up read
# Purpose: catches empty loaders and common I/O issues early
# -------------------------
val_ds = AudioManifestDataset(val_df)
val_loader = DataLoader(
    val_ds,
    batch_size=PER_DEVICE_BS,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=PIN_MEMORY,
    collate_fn=collate_fn,
)

test_ds = AudioManifestDataset(test_df)
test_loader = DataLoader(
    test_ds,
    batch_size=PER_DEVICE_BS,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=PIN_MEMORY,
    collate_fn=collate_fn,
)

print("\nWarm-up: loading up to 2 VAL batches and 2 TEST batches...")

def _warmup_loader(loader, name):
    t0 = time.time()
    num_batches = len(loader)
    warmup_batches = min(2, num_batches)
    if warmup_batches == 0:
        raise RuntimeError(f"{name} DataLoader has 0 batches. Check df length and PER_DEVICE_BS.")
    it = iter(loader)
    for i in range(warmup_batches):
        _ = next(it)
        print(f"  loaded {name} warmup batch {i+1}/{warmup_batches}")
    print(f"{name} warm-up done in {time.time()-t0:.2f}s")

_warmup_loader(val_loader, "VAL")
_warmup_loader(test_loader, "TEST")

# -------------------------
# Helpers: load heads and run inference
# Inputs:
# - best_heads.pt (trainval output)
# - loader (VAL or TEST)
# Outputs:
# - arrays: y_true, y_prob (PD probability), sex_norm
# -------------------------
def load_heads_into_model(model: Wav2Vec2TwoHeadClassifier, best_heads_path: Path):
    if not best_heads_path.exists():
        raise FileNotFoundError(f"Missing best_heads.pt: {str(best_heads_path)}")
    state = torch.load(str(best_heads_path), map_location="cpu")
    needed = ["pre_vowel", "pre_other", "head_vowel", "head_other"]
    missing = [k for k in needed if k not in state]
    if missing:
        raise KeyError(
            f"best_heads.pt missing keys {missing}. Found keys: {list(state.keys())}. "
            "This test code expects the trainval save format."
        )
    model.pre_vowel.load_state_dict(state["pre_vowel"], strict=True)
    model.pre_other.load_state_dict(state["pre_other"], strict=True)
    model.head_vowel.load_state_dict(state["head_vowel"], strict=True)
    model.head_other.load_state_dict(state["head_other"], strict=True)
    return model

def run_inference(loader, model, desc):
    # Returns PD probability (class 1) for each clip
    use_amp = bool(USE_AMP and DEVICE.type == "cuda")
    all_probs, all_true, all_sex = [], [], []

    pbar = tqdm(loader, desc=desc, dynamic_ncols=True)
    with torch.inference_mode():
        for batch in pbar:
            input_values = batch["input_values"].to(DEVICE, non_blocking=False)
            attention_mask = batch["attention_mask"].to(DEVICE, non_blocking=False)
            labels = batch["labels"].to(DEVICE, non_blocking=False)
            task_group = batch["task_group"]
            sex_norm = batch["sex_norm"]

            with torch.amp.autocast(device_type="cuda", enabled=use_amp):
                logits = model.forward_logits(input_values, attention_mask, task_group)

            probs = torch.softmax(logits.float(), dim=-1)[:, 1].detach().cpu().numpy().astype(np.float64)
            all_probs.extend(probs.tolist())
            all_true.extend(labels.detach().cpu().numpy().astype(np.int64).tolist())
            all_sex.extend(list(sex_norm))

    return np.asarray(all_true, dtype=np.int64), np.asarray(all_probs, dtype=np.float64), np.asarray(all_sex, dtype=object)

# -------------------------
# One seed run:
# 1) VAL inference -> choose threshold (Youden J)
# 2) TEST inference -> metrics at that threshold + plots + metrics.json
# Output: compact dict used later for aggregation
# -------------------------
def run_seed(seed: int):
    set_all_seeds(seed)

    run_dir = TEST_ROOT / f"run_{dataset_id}_seed{seed}"
    run_dir.mkdir(parents=True, exist_ok=True)

    best_heads_path = chosen_exp / f"run_{dataset_id}_seed{seed}" / "best_heads.pt"

    print(f"\n[seed={seed}] Loading model + heads from:")
    print(" ", str(best_heads_path))

    model = Wav2Vec2TwoHeadClassifier(BACKBONE_CKPT, dropout_p=DROPOUT_P).to(DEVICE)
    model = load_heads_into_model(model, best_heads_path)
    model.eval()

    # ----- VAL -> threshold via Youden J
    yv, pv, _ = run_inference(val_loader, model, desc=f"[seed={seed}] VAL (for thr)")
    val_auc = compute_auc(yv, pv)
    thr_opt, j_opt, tpr_opt, fpr_opt = compute_val_opt_threshold_youden_j(yv, pv)

    # Display format (keep tight + numeric)
    print(f"[seed={seed}] VAL AUROC: {val_auc:.6f}")
    print(f"[seed={seed}] VAL-opt threshold (Youden J): {thr_opt:.6f}")

    # ----- TEST -> metrics and plots at VAL-picked threshold
    yt, pt, st = run_inference(test_loader, model, desc=f"[seed={seed}] TEST")

    test_auc = compute_auc(yt, pt)  # threshold-free
    thr_metrics = compute_threshold_metrics(yt, pt, thr=thr_opt)

    fnr_by_sex, delta_f_minus_m, delta_abs = compute_fnr_by_group_signed(
        yt, pt, st, thr=thr_opt
    )

    confusion_by_sex = compute_confusion_by_group(yt, pt, st, thr=thr_opt)

    # ----- plots (overall)
    roc_png = run_dir / "roc_curve.png"
    cm_png  = run_dir / "confusion_matrix.png"
    save_roc_curve_png(yt, pt, str(roc_png), title_suffix=f"Test (seed={seed})")
    save_confusion_png(yt, pt, str(cm_png), thr=thr_opt, title_suffix=f"Test (seed={seed})")

    # ----- plots (by sex): separate confusion charts for M and F when present
    cm_m_png = None
    cm_f_png = None
    mask_m = (st == "M")
    mask_f = (st == "F")

    if int(mask_m.sum()) > 0:
        cm_m_png = run_dir / "confusion_matrix_M.png"
        save_confusion_png(
            yt[mask_m], pt[mask_m], str(cm_m_png),
            thr=thr_opt, title_suffix=f"Test SEX=M (seed={seed})"
        )

    if int(mask_f.sum()) > 0:
        cm_f_png = run_dir / "confusion_matrix_F.png"
        save_confusion_png(
            yt[mask_f], pt[mask_f], str(cm_f_png),
            thr=thr_opt, title_suffix=f"Test SEX=F (seed={seed})"
        )

    # Full per-seed record written to metrics.json
    metrics = {
        "dataset": dataset_id,
        "seed": int(seed),

        "n_val": int(len(val_df)),
        "label_counts_val": val_df["label_num"].value_counts(dropna=False).to_dict(),
        "sex_counts_val_norm": val_df["sex_norm"].value_counts(dropna=False).to_dict(),

        "n_test": int(len(test_df)),
        "label_counts_test": test_df["label_num"].value_counts(dropna=False).to_dict(),
        "sex_counts_test_norm": test_df["sex_norm"].value_counts(dropna=False).to_dict(),

        "val_auroc": float(val_auc),
        "val_opt_threshold_method": "Youden J (maximize TPR - FPR on VAL ROC curve)",
        "val_opt_threshold": float(thr_opt),
        "val_opt_details": {
            "youden_j": float(j_opt),
            "tpr_at_opt": float(tpr_opt),
            "fpr_at_opt": float(fpr_opt),
        },

        "test_auroc": float(test_auc),

        # TEST evaluated at VAL-opt threshold
        "threshold_metrics_test": thr_metrics,
        "test_threshold_used": float(thr_opt),

        "fairness_test": {
            "definition": "H3: ΔFNR = FNR(F) - FNR(M), where FNR(sex)=FN/(FN+TP) computed on PD-only true labels at test_threshold_used.",
            "fnr_by_sex_norm": fnr_by_sex,
            "delta_fnr_F_minus_M": float(delta_f_minus_m),
            "delta_fnr_abs": float(delta_abs),
            "note": "If n_PD for a sex is 0, its FNR is NaN and ΔFNR is NaN.",
            "sex_normalization_note": "sex_norm in {M,F,UNK}. Values not mapped to M/F counted as UNK.",
        },

        "confusion_by_sex_norm": confusion_by_sex,

        "artifacts": {
            "roc_curve_png": str(roc_png),
            "confusion_matrix_png": str(cm_png),
            "confusion_matrix_M_png": str(cm_m_png) if cm_m_png is not None else None,
            "confusion_matrix_F_png": str(cm_f_png) if cm_f_png is not None else None,
        },

        "dx_out_root": DX_OUT_ROOT,
        "trainval_experiment_used": str(chosen_exp),
        "best_heads_path": str(best_heads_path),
        "backbone_ckpt": BACKBONE_CKPT,
        "dropout_p": float(DROPOUT_P),
        "timestamp": time.strftime("%Y%m%d_%H%M%S"),
    }

    metrics_path = run_dir / "metrics.json"
    with open(metrics_path, "w", encoding="utf-8") as f:
        json.dump(metrics, f, indent=2)

    print(f"[seed={seed}] DONE | test_AUROC={test_auc:.6f}")

    # Small, readable fairness printout
    def _fmt_fnr(g):
        d = fnr_by_sex.get(g, None)
        if d is None:
            return "n/a"
        return f"fnr={d['fnr']:.6f} (n_PD={d['n_pos']}, fn={d['fn']}, tp={d['tp']})"

    print(f"[seed={seed}] TEST metrics @ VAL-opt thr={thr_opt:.6f}")
    print(f"[seed={seed}] FAIRNESS (H3) @ thr={thr_opt:.6f}:")
    print("  M:", _fmt_fnr("M"))
    print("  F:", _fmt_fnr("F"))
    if "UNK" in fnr_by_sex:
        print("  UNK:", _fmt_fnr("UNK"))
    print("  ΔFNR (F-M):", f"{delta_f_minus_m:.6f}" if not np.isnan(delta_f_minus_m) else "nan")
    print("  |ΔFNR|:", f"{delta_abs:.6f}" if not np.isnan(delta_abs) else "nan")

    print(f"[seed={seed}] WROTE:")
    print(" ", str(metrics_path))
    print(" ", str(roc_png))
    print(" ", str(cm_png))
    if cm_m_png is not None:
        print(" ", str(cm_m_png))
    if cm_f_png is not None:
        print(" ", str(cm_f_png))

    return {
        "seed": int(seed),
        "val_opt_threshold": float(thr_opt),
        "test_auroc": float(test_auc),
        "thr_metrics_test": thr_metrics,
        "fnr_by_sex_test": fnr_by_sex,
        "delta_f_minus_m_test": float(delta_f_minus_m),
        "delta_abs_test": float(delta_abs),
        "run_dir": str(run_dir),
    }

# -------------------------
# Run all seeds and aggregate results
# Outputs:
# - Printed per-seed thresholds and AUROC
# - summary_test.json and history_index.jsonl
# -------------------------
seed_results = []
for seed in SEEDS:
    seed_results.append(run_seed(seed))

# AUROC aggregation (threshold-free)
aurocs = [r["test_auroc"] for r in seed_results]
t_crit = 4.302652729911275  # two-sided 95% Student's t critical value for df=2; valid only for exactly 3 seeds
n = len(aurocs)
mean_auc = float(np.mean(aurocs))
std_auc = float(np.std(aurocs, ddof=1)) if n > 1 else 0.0
half_width = float(t_crit * (std_auc / math.sqrt(n))) if n > 1 else 0.0
ci95 = [float(mean_auc - half_width), float(mean_auc + half_width)]

def _mean_sd(vals):
    vals = np.asarray(vals, dtype=np.float64)
    return float(np.nanmean(vals)), float(np.nanstd(vals, ddof=1)) if np.sum(~np.isnan(vals)) > 1 else 0.0

# Threshold aggregation (VAL-opt)
thr_vals = [r["val_opt_threshold"] for r in seed_results]
thr_mean, thr_sd = _mean_sd(thr_vals)

# Threshold metrics on TEST @ per-seed VAL-opt threshold
thr_list = [r["thr_metrics_test"] for r in seed_results]
keys = ["accuracy","precision","recall","f1_score","sensitivity","specificity","mcc","p_value_fisher_two_sided"]

agg = {}
for k in keys:
    v = [float(tm.get(k, float("nan"))) for tm in thr_list]
    mu, sd = _mean_sd(v)
    agg[k] = {"mean": mu, "sd": sd, "values_by_seed": {str(r["seed"]): float(tm.get(k, float("nan"))) for r, tm in zip(seed_results, thr_list)}}

cm_by_seed = {str(r["seed"]): r["thr_metrics_test"]["confusion_matrix"] for r in seed_results}

# FAIRNESS aggregation (H3) on TEST @ per-seed VAL-opt threshold
fnr_by_seed = {str(r["seed"]): r["fnr_by_sex_test"] for r in seed_results}
delta_signed_by_seed = {str(r["seed"]): float(r["delta_f_minus_m_test"]) for r in seed_results}
delta_abs_by_seed = {str(r["seed"]): float(r["delta_abs_test"]) for r in seed_results}

fnr_m_vals, fnr_f_vals, n_pd_m_vals, n_pd_f_vals = [], [], [], []
d_signed_vals, d_abs_vals = [], []

for r in seed_results:
    s = str(r["seed"])
    d = fnr_by_seed.get(s, {})
    fnr_m_vals.append(float(d.get("M", {}).get("fnr", float("nan"))))
    fnr_f_vals.append(float(d.get("F", {}).get("fnr", float("nan"))))
    n_pd_m_vals.append(float(d.get("M", {}).get("n_pos", float("nan"))))
    n_pd_f_vals.append(float(d.get("F", {}).get("n_pos", float("nan"))))
    d_signed_vals.append(float(delta_signed_by_seed.get(s, float("nan"))))
    d_abs_vals.append(float(delta_abs_by_seed.get(s, float("nan"))))

fnr_m_mean, fnr_m_sd = _mean_sd(fnr_m_vals)
fnr_f_mean, fnr_f_sd = _mean_sd(fnr_f_vals)
d_signed_mean, d_signed_sd = _mean_sd(d_signed_vals)
d_abs_mean, d_abs_sd = _mean_sd(d_abs_vals)

print("\nVAL-opt thresholds (Youden J) by seed:")
for r in seed_results:
    print(f"  seed {r['seed']}: {r['val_opt_threshold']:.6f}")
print(f"  mean ± SD: {thr_mean:.6f} ± {thr_sd:.6f}")

print("\nTest AUROC by seed:")
for r in seed_results:
    print(f"  seed {r['seed']}: {r['test_auroc']:.6f}")
print(f"\nMean Test AUROC: {mean_auc:.6f}")
print(f"95% CI (t, n=3): [{ci95[0]:.6f}, {ci95[1]:.6f}]")

print("\nThreshold metrics on TEST @ VAL-opt threshold (mean ± SD across seeds):")
for k in ["accuracy","precision","sensitivity","specificity","f1_score","mcc"]:
    mu = agg[k]["mean"]
    sd = agg[k]["sd"]
    print(f"  {k}: {mu:.6f} ± {sd:.6f}")
print("  fisher_p_value_two_sided:", f"{agg['p_value_fisher_two_sided']['mean']:.6g} ± {agg['p_value_fisher_two_sided']['sd']:.6g}")

print("\nFAIRNESS (H3) on TEST @ VAL-opt threshold (mean ± SD across seeds):")
print(f"  FNR_M: {fnr_m_mean:.6f} ± {fnr_m_sd:.6f}")
print(f"  FNR_F: {fnr_f_mean:.6f} ± {fnr_f_sd:.6f}")
print(f"  ΔFNR (F-M): {d_signed_mean:.6f} ± {d_signed_sd:.6f}")
print(f"  |ΔFNR|: {d_abs_mean:.6f} ± {d_abs_sd:.6f}")
print("  Per-seed:", {
    str(r["seed"]): {
        "val_opt_thr": float(r["val_opt_threshold"]),
        "FNR_M": fnr_m_vals[i],
        "n_PD_M": n_pd_m_vals[i],
        "FNR_F": fnr_f_vals[i],
        "n_PD_F": n_pd_f_vals[i],
        "delta_F_minus_M": d_signed_vals[i],
        "abs_delta": d_abs_vals[i],
    } for i, r in enumerate(seed_results)
})

summary = {
    "dataset": dataset_id,
    "dx_out_root": DX_OUT_ROOT,
    "manifest_all": MANIFEST_ALL,
    "trainval_experiment_used": str(chosen_exp),
    "seeds": SEEDS,

    "n_val": int(len(val_df)),
    "label_counts_val": val_df["label_num"].value_counts(dropna=False).to_dict(),
    "sex_counts_val_norm": val_df["sex_norm"].value_counts(dropna=False).to_dict(),

    "n_test": int(len(test_df)),
    "label_counts_test": test_df["label_num"].value_counts(dropna=False).to_dict(),
    "sex_counts_test_norm": test_df["sex_norm"].value_counts(dropna=False).to_dict(),

    "val_opt_threshold_method": "Youden J (maximize TPR - FPR on VAL ROC curve)",
    "val_opt_threshold_by_seed": {str(r["seed"]): float(r["val_opt_threshold"]) for r in seed_results},
    "val_opt_threshold_mean_sd": {"mean": float(thr_mean), "sd": float(thr_sd)},

    "test_aurocs_by_seed": {str(r["seed"]): float(r["test_auroc"]) for r in seed_results},
    "mean_test_auroc": float(mean_auc),
    "t_crit_df2_95": float(t_crit),
    "half_width_95_ci": float(half_width),
    "ci95_test_auroc": ci95,

    "threshold_metrics_test_mean_sd": agg,
    "confusion_matrix_by_seed": cm_by_seed,
    "run_dirs": [r["run_dir"] for r in seed_results],

    "fairness_test": {
        "definition": "H3: ΔFNR = FNR(F) - FNR(M), where FNR(sex)=FN/(FN+TP) computed on PD-only true labels at per-seed VAL-opt threshold used on TEST.",
        "fnr_by_sex_norm_by_seed": fnr_by_seed,
        "delta_fnr_F_minus_M_by_seed": delta_signed_by_seed,
        "delta_fnr_abs_by_seed": delta_abs_by_seed,
        "fnr_M_mean_sd": {"mean": float(fnr_m_mean), "sd": float(fnr_m_sd), "values_by_seed": {str(s): float(v) for s, v in zip(SEEDS, fnr_m_vals)}},
        "fnr_F_mean_sd": {"mean": float(fnr_f_mean), "sd": float(fnr_f_sd), "values_by_seed": {str(s): float(v) for s, v in zip(SEEDS, fnr_f_vals)}},
        "delta_fnr_F_minus_M_mean_sd": {"mean": float(d_signed_mean), "sd": float(d_signed_sd), "values_by_seed": {str(s): float(v) for s, v in zip(SEEDS, d_signed_vals)}},
        "delta_fnr_abs_mean_sd": {"mean": float(d_abs_mean), "sd": float(d_abs_sd), "values_by_seed": {str(s): float(v) for s, v in zip(SEEDS, d_abs_vals)}},
        "denominators_PD_by_seed": {str(s): {"n_PD_M": float(n_pd_m_vals[i]), "n_PD_F": float(n_pd_f_vals[i])} for i, s in enumerate(SEEDS)},
        "sex_normalization_note": "sex_norm in {M,F,UNK}. Values not mapped to M/F counted as UNK. ΔFNR computed only when both M and F have defined FNR.",
    },

    "timestamp": time.strftime("%Y%m%d_%H%M%S"),
}

summary_path = TEST_ROOT / "summary_test.json"
with open(summary_path, "w", encoding="utf-8") as f:
    json.dump(summary, f, indent=2)

history_path = TEST_ROOT / "history_index.jsonl"
with open(history_path, "a", encoding="utf-8") as f:
    f.write(json.dumps(summary) + "\n")

print("\nWROTE summary:", str(summary_path))
print("APPENDED history index:", str(history_path))
print("Open this folder to access artifacts:", str(TEST_ROOT))

# -------------------------
# Runtime release (stop paid GPU)
# -------------------------
print("\nAll done. Unassigning the runtime to stop the L4 instance...")
try:
    from google.colab import runtime  # type: ignore
    print("Calling runtime.unassign() now.")
    runtime.unassign()
except Exception as e:
    print("Could not unassign runtime automatically. You can stop the runtime manually in Colab.")
    print("Reason:", repr(e))

The following cell evaluates the D5 (English, MDVR-KCL) model on the test split, using a decision threshold that is chosen only from the validation split. It reads `manifests/manifest_all.csv` from the D5 `preprocessed_v2` folder and keeps only the validation and test rows. It requires that the sex and age columns are present so fairness can be evaluated, and it checks that the dataset is clearly identified as D5 to avoid accidentally using files from another dataset. Before running any model code, it confirms that all referenced audio clip files exist on disk.

The cell then sets fixed run settings, including three random seeds, the Wav2Vec2 backbone, a 16 kHz sample rate, batch sizes, dropout, and mixed precision on the GPU. Input data is prepared for inference by assigning each clip to either a vowel group (`task == "vowl"`) or an other group. Sex values are normalized to M, F, or UNK for consistent reporting. A dataset and collator load audio from disk, check that the sample rate is correct, pad clips in each batch to a common length, and create an attention mask that excludes the batch padding and, for vowel clips only, any trailing near-silent samples (absolute amplitude at or below `TINY_THRESH = 1e-4`). A short warm-up step loads up to two batches from each split to catch data loading or file issues early, even when the splits are small.
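The masking and padding step described above can be sketched in plain NumPy. This is a minimal illustration, not the notebook's dataset class: the helper names `trailing_silence_mask` and `pad_batch` and the toy waveforms are assumptions made for this example, while the `1e-4` threshold is taken from the run settings.

```python
import numpy as np

TINY_THRESH = 1e-4  # same value as the notebook's run settings

def trailing_silence_mask(y, thr=TINY_THRESH):
    """Hypothetical helper: 1 up to the last sample with |y| > thr, 0 afterwards."""
    y = np.asarray(y, dtype=np.float64)
    mask = np.ones(len(y), dtype=np.int64)
    loud = np.flatnonzero(np.abs(y) > thr)
    if loud.size > 0:
        mask[loud[-1] + 1:] = 0  # an all-silent clip keeps a full mask
    return mask

def pad_batch(waves, masks):
    """Hypothetical helper: zero-pad waveforms and masks to the longest clip."""
    max_len = max(len(w) for w in waves)
    x = np.zeros((len(waves), max_len), dtype=np.float32)
    a = np.zeros((len(waves), max_len), dtype=np.int64)
    for i, (w, m) in enumerate(zip(waves, masks)):
        x[i, :len(w)] = w
        a[i, :len(m)] = m
    return x, a

# Toy data (made up): a vowel clip with a near-silent tail and a shorter "other" clip
vowel = np.array([0.2, -0.3, 0.1, 0.0, 0.00005], dtype=np.float32)
other = np.array([0.1, 0.0, 0.0], dtype=np.float32)

masks = [trailing_silence_mask(vowel), np.ones(len(other), dtype=np.int64)]
x, a = pad_batch([vowel, other], masks)
print(a)
```

In this toy batch the vowel clip loses its last two samples to the silence rule, and the shorter clip loses its last two positions to batch padding, so both mask rows end in zeros for different reasons.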

The model structure matches the training setup. It uses a frozen Wav2Vec2 backbone with two small classification heads, one for vowel clips and one for other clips, each preceded by LayerNorm and Dropout. The cell first finds the most recent Train+Val experiment folder that contains a `best_heads.pt` file for every seed (1337, 2024, and 7777). Then, for each seed, it reloads the saved head weights and runs inference in two stages.
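To make the task routing concrete, the sketch below applies two heads to already-pooled embeddings. It is an assumption-laden illustration: the class name `TwoHeadRouter`, the tiny hidden size, and the random inputs are invented here, and the frozen backbone and pooling are left out entirely. Only the layer layout (LayerNorm, Dropout, then a 2-way Linear per task group) follows the description above.

```python
import torch
import torch.nn as nn

class TwoHeadRouter(nn.Module):
    """Hypothetical stand-in for the two task-routed heads (backbone omitted)."""

    def __init__(self, hidden_size=768, dropout_p=0.2):
        super().__init__()
        self.pre_vowel = nn.Sequential(nn.LayerNorm(hidden_size), nn.Dropout(dropout_p))
        self.pre_other = nn.Sequential(nn.LayerNorm(hidden_size), nn.Dropout(dropout_p))
        self.head_vowel = nn.Linear(hidden_size, 2)
        self.head_other = nn.Linear(hidden_size, 2)

    def forward(self, pooled, task_group):
        # pooled: (batch, hidden) embeddings; task_group: list of "vowel" / "other"
        logits = pooled.new_zeros((pooled.shape[0], 2))
        is_vowel = torch.tensor(
            [g == "vowel" for g in task_group], dtype=torch.bool, device=pooled.device
        )
        if is_vowel.any():
            logits[is_vowel] = self.head_vowel(self.pre_vowel(pooled[is_vowel]))
        if (~is_vowel).any():
            logits[~is_vowel] = self.head_other(self.pre_other(pooled[~is_vowel]))
        return logits

torch.manual_seed(0)
router = TwoHeadRouter(hidden_size=8).eval()  # eval() turns Dropout into a no-op
pooled = torch.randn(3, 8)                    # made-up pooled embeddings
with torch.no_grad():
    logits = router(pooled, ["vowel", "other", "vowel"])
    pd_prob = torch.softmax(logits, dim=-1)[:, 1]  # class 1 = PD probability
print(logits.shape, pd_prob.shape)
```

Routing by a boolean mask keeps each clip's logits in its original batch position, which matters because labels and sex values are collected in the same order.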

In the validation stage, prediction scores are computed on the validation split and a validation-optimal threshold is selected using Youden’s J statistic, which maximizes the difference between the true positive rate and the false positive rate. In the test stage, this threshold is applied unchanged to the test split. The cell reports test AUROC (which does not depend on the threshold), threshold-based metrics (confusion matrix, accuracy, precision, sensitivity, specificity, F1 score, MCC, and the Fisher exact test p-value), and fairness metrics. Fairness follows the H3 definition: the false negative rate, FNR = FN / (FN + TP), is computed separately for males and females on Parkinson’s-positive cases only, and ΔFNR is defined as FNR(F) minus FNR(M).
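The two-stage logic can be illustrated with a small NumPy-only sketch. The functions `youden_threshold` and `fnr_at_threshold` are written for this example and are not the notebook's helpers. In particular, the tie-breaking rule (keep the highest threshold among equal J values) is an assumption, and all scores and labels are made up.

```python
import numpy as np

def youden_threshold(y_true, y_prob):
    """Pick the score threshold that maximizes J = TPR - FPR (illustrative)."""
    y_true = np.asarray(y_true, dtype=np.int64)
    y_prob = np.asarray(y_prob, dtype=np.float64)
    n_pos = int((y_true == 1).sum())
    n_neg = int((y_true == 0).sum())
    if n_pos == 0 or n_neg == 0:
        raise ValueError("Both classes are required to compute Youden's J.")
    best_thr, best_j = None, -np.inf
    for thr in np.unique(y_prob)[::-1]:  # candidate thresholds, highest first
        pred = y_prob >= thr
        tpr = (pred & (y_true == 1)).sum() / n_pos
        fpr = (pred & (y_true == 0)).sum() / n_neg
        if tpr - fpr > best_j:  # strict ">" keeps the higher threshold on ties
            best_thr, best_j = float(thr), float(tpr - fpr)
    return best_thr, best_j

def fnr_at_threshold(y_true, y_prob, thr):
    """FNR = FN / (FN + TP) on positive cases only; NaN when there are none."""
    y_true = np.asarray(y_true, dtype=np.int64)
    y_prob = np.asarray(y_prob, dtype=np.float64)
    pos = y_true == 1
    if pos.sum() == 0:
        return float("nan")
    return float((pos & (y_prob < thr)).sum() / pos.sum())

# Stage 1: threshold from validation scores only (made-up data)
y_val = [0, 0, 0, 0, 1, 1, 1, 1]
p_val = [0.10, 0.20, 0.35, 0.60, 0.30, 0.70, 0.80, 0.90]
thr, j = youden_threshold(y_val, p_val)

# Stage 2: apply that threshold to test scores and compare FNR by sex (made-up data)
y_test = np.array([1, 1, 1, 1, 0, 0])
p_test = np.array([0.80, 0.50, 0.90, 0.75, 0.20, 0.30])
sex = np.array(["M", "M", "F", "F", "M", "F"])
fnr_m = fnr_at_threshold(y_test[sex == "M"], p_test[sex == "M"], thr)
fnr_f = fnr_at_threshold(y_test[sex == "F"], p_test[sex == "F"], thr)
delta_fnr = fnr_f - fnr_m  # H3 sign convention: FNR(F) - FNR(M)
print(thr, j, fnr_m, fnr_f, delta_fnr)
```

Because the threshold never sees test labels, the test-side FNR values are an honest estimate of how the chosen operating point behaves for each group.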

For each seed, results are saved under
`<DX_OUT_ROOT>/monolingual_test_runs/run_D5_seed<seed>/`.
Saved outputs include a ROC curve image, an overall confusion matrix, additional confusion matrices for M and F when those groups are present in the test data, and a per-seed `metrics.json` file containing all settings, counts, thresholds, metrics, fairness results, and paths to the saved artifacts.

After all three seeds complete, the cell aggregates results across seeds. It reports the mean test AUROC with a 95% confidence interval based on a t distribution with n = 3 seeds (2 degrees of freedom), along with the mean and sample standard deviation of the validation-chosen thresholds, the threshold-based test metrics, and the fairness values. A combined `summary_test.json` file is written, the same summary is appended to `history_index.jsonl`, the output locations are printed, and the Colab runtime is unassigned to stop the GPU instance.
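As a worked example of the interval, the sketch below computes mean ± t·SD/√n for three values. The AUROC numbers are invented for illustration, and the function name `mean_ci95_three_seeds` is an assumption. The critical value is the two-sided 95% Student's t value for 2 degrees of freedom, so the function deliberately rejects any other number of seeds.

```python
import math
import statistics

T_CRIT_DF2_95 = 4.302652729911275  # two-sided 95% t critical value, df = 2

def mean_ci95_three_seeds(values):
    """Mean and 95% t-interval for exactly three per-seed values (illustrative)."""
    if len(values) != 3:
        raise ValueError("T_CRIT_DF2_95 only applies to n = 3 (df = 2).")
    mean = statistics.fmean(values)
    sd = statistics.stdev(values)  # sample SD (ddof = 1)
    half_width = T_CRIT_DF2_95 * sd / math.sqrt(len(values))
    return mean, half_width, (mean - half_width, mean + half_width)

aurocs = [0.80, 0.85, 0.90]  # made-up per-seed AUROCs, not study results
mean_auc, half_width, ci95 = mean_ci95_three_seeds(aurocs)
print(f"mean={mean_auc:.4f} half_width={half_width:.4f} ci95=[{ci95[0]:.4f}, {ci95[1]:.4f}]")
```

With only two degrees of freedom the critical value is large, so even a modest spread across seeds produces a wide interval. The interval is also not clipped to [0, 1], which is worth keeping in mind when reading it as an AUROC range.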

In [None]:
# =========================
# D5 Test Only — Val Optimal Threshold + Fairness
# Purpose: Evaluate the D5 model on VAL and TEST, pick a per seed VAL optimal threshold (Youden J),
#          then report TEST metrics and fairness at that threshold.
# Inputs:  manifest_all.csv (VAL and TEST rows), best_heads.pt from latest trainval exp (3 seeds)
# Outputs: Per seed run folder with metrics.json + plots, plus summary_test.json + history_index.jsonl
# =========================

import os, json, math, random, time, warnings
from pathlib import Path

import numpy as np
import pandas as pd
import soundfile as sf

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

from transformers import Wav2Vec2Model

from sklearn.metrics import roc_auc_score, roc_curve, confusion_matrix, matthews_corrcoef
from scipy.stats import fisher_exact
import matplotlib.pyplot as plt

from tqdm.auto import tqdm

# -------------------------
# Safety checks: avoid importing local files named torch/transformers
# -------------------------
if os.path.exists("/content/torch.py") or os.path.exists("/content/torch/__init__.py"):
    raise RuntimeError("Found local /content/torch.py or /content/torch/ that shadows PyTorch. Rename/remove it and restart runtime.")
if os.path.exists("/content/transformers.py") or os.path.exists("/content/transformers/__init__.py"):
    raise RuntimeError("Found local /content/transformers.py or /content/transformers/ that shadows Hugging Face Transformers. Rename/remove it and restart runtime.")

# -------------------------
# Drive access
# -------------------------
if not os.path.isdir("/content/drive/MyDrive"):
    from google.colab import drive  # type: ignore
    drive.mount("/content/drive")

# -------------------------
# Paths: output root and the single manifest used for VAL and TEST
# -------------------------
D5_OUT_ROOT_FALLBACK = "/content/drive/MyDrive/AI_PD_Project/Datasets/D5-English (MDVR-KCL)/preprocessed_v2"
# DX_OUT_ROOT = globals().get("DX_OUT_ROOT", D5_OUT_ROOT_FALLBACK)
# Hard-coded (not inherited from globals) so the new 50/20/30 splits path is always used
DX_OUT_ROOT = "/content/drive/MyDrive/AI_PD_Project/Datasets/D5-English (MDVR-KCL)/preprocessed_v2"
MANIFEST_ALL = f"{DX_OUT_ROOT}/manifests/manifest_all.csv"

# -------------------------
# Run settings: must match Train+Val behavior where relevant
# -------------------------
SEEDS          = [1337, 2024, 7777]
BACKBONE_CKPT  = "facebook/wav2vec2-base"
SR_EXPECTED    = 16000
TINY_THRESH    = 1e-4

# Batch sizing (used only for inference here; printed for traceability)
EFFECTIVE_BS   = 64
PER_DEVICE_BS  = 16
GRAD_ACCUM     = max(1, EFFECTIVE_BS // PER_DEVICE_BS)  # not used in test, but printed

# Must match Train+Val
DROPOUT_P      = 0.2

# Drive friendly defaults
NUM_WORKERS    = 0
PIN_MEMORY     = False

# AMP speeds up inference on GPU; heads are still computed in FP32
USE_AMP        = True

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
try:
    torch.set_float32_matmul_precision("high")
except Exception:
    pass

# Keep notebook output cleaner (no change to logic)
warnings.filterwarnings("ignore", message="None of the inputs have requires_grad=True. Gradients will be None")
warnings.filterwarnings("ignore", message="Passing `gradient_checkpointing` to a config initialization is deprecated")

print("DX_OUT_ROOT:", DX_OUT_ROOT)
print("MANIFEST_ALL:", MANIFEST_ALL)
print("DEVICE:", DEVICE)
if DEVICE.type == "cuda":
    print("GPU:", torch.cuda.get_device_name(0))
print("PER_DEVICE_BS:", PER_DEVICE_BS, "| EFFECTIVE_BS:", PER_DEVICE_BS * GRAD_ACCUM)
print("NUM_WORKERS:", NUM_WORKERS, "| PIN_MEMORY:", PIN_MEMORY)
print("USE_AMP:", bool(USE_AMP and DEVICE.type == "cuda"))

# -------------------------
# Load manifest, keep VAL and TEST only, and lock to dataset_id (guarded)
# -------------------------
if not os.path.exists(MANIFEST_ALL):
    raise FileNotFoundError(f"Missing manifest_all.csv: {MANIFEST_ALL}")

m = pd.read_csv(MANIFEST_ALL)

# Required columns for labels, task routing, and fairness reporting
req_cols = {"split", "clip_path", "label_num", "task", "sex", "age"}
missing = [c for c in sorted(req_cols) if c not in m.columns]
if missing:
    raise ValueError(f"Manifest missing required columns: {missing}. Found: {list(m.columns)}")

m = m[m["split"].isin(["val", "test"])].copy()
if len(m) == 0:
    raise RuntimeError("After filtering to split in {'val','test'}, manifest has 0 rows.")

# Prefer the manifest dataset column when present, otherwise treat as generic "DX"
if "dataset" in m.columns and m["dataset"].notna().any():
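    # Drop NaN before casting to str; otherwise missing values become the literal "nan" and can win the vote
    dataset_id = str(m["dataset"].dropna().astype(str).value_counts().idxmax())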
    m = m[m["dataset"].astype(str) == dataset_id].copy()
else:
    dataset_id = "DX"

# Guard: prevents accidental reuse of the wrong DX_OUT_ROOT or manifest
if dataset_id != "D5":
    raise RuntimeError(
        f"Expected dataset_id=='D5' but got {dataset_id!r}. "
        "This usually means DX_OUT_ROOT points at another dataset's folder or the manifest is not D5. "
        f"DX_OUT_ROOT={DX_OUT_ROOT}"
    )

# Keep only the columns this cell needs (missing ones are created as NaN)
keep_cols = ["clip_path", "label_num", "task", "speaker_id", "sex", "age", "duration_sec", "split"]
for c in keep_cols:
    if c not in m.columns:
        m[c] = np.nan
m = m[keep_cols].copy()

val_df  = m[m["split"] == "val"].reset_index(drop=True)
test_df = m[m["split"] == "test"].reset_index(drop=True)

print(f"\nDataset inferred: {dataset_id}")
print(f"Val rows:  {len(val_df)}")
print(f"Test rows: {len(test_df)}")
print("Val label counts:",  val_df["label_num"].value_counts(dropna=False).to_dict())
print("Test label counts:", test_df["label_num"].value_counts(dropna=False).to_dict())
print("Val sex counts (raw):",  val_df["sex"].value_counts(dropna=False).to_dict())
print("Test sex counts (raw):", test_df["sex"].value_counts(dropna=False).to_dict())

# VAL is required because the threshold is picked from VAL
if len(val_df) == 0:
    raise RuntimeError("Val split has 0 rows. VAL-opt threshold cannot be computed.")
if len(test_df) == 0:
    raise RuntimeError("Test split has 0 rows.")

# -------------------------
# Quick data integrity: ensure referenced audio clips exist (stop early if not)
# -------------------------
def _fail_fast_missing_paths(df: pd.DataFrame, name: str):
    missing_paths = []
    for p in tqdm(df["clip_path"].astype(str).tolist(), desc=f"Check {name} clip_path exists", dynamic_ncols=True):
        if not os.path.exists(p):
            missing_paths.append(p)
            if len(missing_paths) >= 10:
                break
    if missing_paths:
        raise FileNotFoundError(f"{name}: missing clip_path(s). Examples: {missing_paths}")

_fail_fast_missing_paths(val_df, "VAL")
_fail_fast_missing_paths(test_df, "TEST")

# -------------------------
# Task routing: map each clip into the head it should use
# -------------------------
def _task_group(task_val) -> str:
    return "vowel" if str(task_val) == "vowl" else "other"

val_df["task_group"]  = val_df["task"].apply(_task_group)
test_df["task_group"] = test_df["task"].apply(_task_group)

# -------------------------
# Sex normalization: standardize for fairness reporting and sex specific confusion charts
# -------------------------
def normalize_sex(val) -> str:
    """
    Returns 'M', 'F', or 'UNK'.
    Leaves unknown formats as UNK to avoid silent remapping mistakes.
    """
    if pd.isna(val):
        return "UNK"
    s = str(val).strip().lower()
    if s in {"m", "male", "man", "masc", "masculine"}:
        return "M"
    if s in {"f", "female", "woman", "fem", "feminine"}:
        return "F"
    return "UNK"

val_df["sex_norm"]  = val_df["sex"].apply(normalize_sex)
test_df["sex_norm"] = test_df["sex"].apply(normalize_sex)

print("Val sex counts (normalized):",  val_df["sex_norm"].value_counts(dropna=False).to_dict())
print("Test sex counts (normalized):", test_df["sex_norm"].value_counts(dropna=False).to_dict())
if (val_df["sex_norm"] == "UNK").any() or (test_df["sex_norm"] == "UNK").any():
    print("NOTE: Some 'sex' values could not be normalized to M/F and were counted as 'UNK' for fairness and sex charts.")

# -------------------------
# Data pipeline: read waveform + build attention mask + pad batches
# -------------------------
class AudioManifestDataset(Dataset):
    def __init__(self, df: pd.DataFrame):
        self.df = df.reset_index(drop=True)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx: int):
        row = self.df.iloc[idx]
        clip_path = str(row["clip_path"])
        label = int(row["label_num"])
        task_group = str(row["task_group"])
        sex_norm = str(row["sex_norm"])

        if not os.path.exists(clip_path):
            raise FileNotFoundError(f"Missing clip_path: {clip_path}")

        # Input: audio clip on disk. Output: mono float32 waveform at 16 kHz.
        y, sr = sf.read(clip_path, always_2d=False)
        if isinstance(y, np.ndarray) and y.ndim > 1:
            y = np.mean(y, axis=1)
        y = np.asarray(y, dtype=np.float32)

        if int(sr) != SR_EXPECTED:
            raise ValueError(f"Sample rate mismatch (expected {SR_EXPECTED}): {clip_path} has sr={sr}")

        # Attention mask: for vowel clips, ignore trailing near-zero samples (padding-like silence).
        # Same result as scanning backward for the last sample with |y| > TINY_THRESH, but vectorized.
        attn = np.ones((len(y),), dtype=np.int64)
        if task_group == "vowel":
            loud = np.flatnonzero(np.abs(y.astype(np.float64)) > TINY_THRESH)
            if loud.size > 0:
                attn[int(loud[-1]) + 1:] = 0  # all-silent clips keep a full mask

        return {
            "input_values": torch.from_numpy(y),
            "attention_mask": torch.from_numpy(attn),
            "labels": torch.tensor(label, dtype=torch.long),
            "task_group": task_group,
            "sex_norm": sex_norm,
        }

def collate_fn(batch):
    # Pads a batch to the longest waveform length and pads attention mask to match
    max_len = int(max(b["input_values"].numel() for b in batch))
    input_vals, attn_masks, labels, task_groups, sex_norms = [], [], [], [], []
    for b in batch:
        x = b["input_values"]
        a = b["attention_mask"]
        pad = max_len - x.numel()
        if pad > 0:
            x = torch.cat([x, torch.zeros(pad, dtype=x.dtype)], dim=0)
            a = torch.cat([a, torch.zeros(pad, dtype=a.dtype)], dim=0)
        input_vals.append(x)
        attn_masks.append(a)
        labels.append(b["labels"])
        task_groups.append(b["task_group"])
        sex_norms.append(b["sex_norm"])
    return {
        "input_values": torch.stack(input_vals, dim=0),
        "attention_mask": torch.stack(attn_masks, dim=0),
        "labels": torch.stack(labels, dim=0),
        "task_group": task_groups,
        "sex_norm": sex_norms,
    }

# -------------------------
# Model: frozen Wav2Vec2 backbone + 2 heads (task routed) like Train+Val
# -------------------------
class Wav2Vec2TwoHeadClassifier(nn.Module):
    def __init__(self, ckpt: str, dropout_p: float):
        super().__init__()
        # Input: padded waveforms + attention mask. Output: pooled embedding -> head logits.
        self.backbone = Wav2Vec2Model.from_pretrained(ckpt, use_safetensors=True)
        for p in self.backbone.parameters():
            p.requires_grad = False

        H = int(self.backbone.config.hidden_size)

        self.pre_vowel = nn.Sequential(nn.LayerNorm(H), nn.Dropout(float(dropout_p)))
        self.pre_other = nn.Sequential(nn.LayerNorm(H), nn.Dropout(float(dropout_p)))

        self.head_vowel = nn.Linear(H, 2)
        self.head_other = nn.Linear(H, 2)

    def masked_mean_pool(self, last_hidden, attn_mask_samples):
        # Converts sample mask -> feature mask (Wav2Vec2 downsampling) and mean pools
        feat_mask = self.backbone._get_feature_vector_attention_mask(last_hidden.shape[1], attn_mask_samples)
        feat_mask = feat_mask.to(last_hidden.device).unsqueeze(-1).type_as(last_hidden)
        masked = last_hidden * feat_mask
        denom = feat_mask.sum(dim=1).clamp(min=1.0)
        return masked.sum(dim=1) / denom

    def _heads_fp32(self, x_fp_any: torch.Tensor, head: nn.Module) -> torch.Tensor:
        # Ensures head logits are computed in FP32 even when AMP is enabled
        x = x_fp_any.float()
        if x.is_cuda:
            with torch.amp.autocast(device_type="cuda", enabled=False):
                return head(x)
        return head(x)

    def forward_logits(self, input_values, attention_mask, task_group):
        with torch.no_grad():
            out = self.backbone(input_values=input_values, attention_mask=attention_mask)
            last_hidden = out.last_hidden_state

        pooled = self.masked_mean_pool(last_hidden, attention_mask)  # [B,H]

        z_v = self.pre_vowel(pooled.float())
        z_o = self.pre_other(pooled.float())

        logits_v = self._heads_fp32(z_v, self.head_vowel)
        logits_o = self._heads_fp32(z_o, self.head_other)

        # Route each sample to the correct head based on task_group
        is_vowel = torch.tensor([tg == "vowel" for tg in task_group], device=pooled.device, dtype=torch.bool)
        logits = logits_o.clone()
        if is_vowel.any():
            logits[is_vowel] = logits_v[is_vowel]
        return logits

# -------------------------
# Metrics + plots: AUROC, threshold metrics, ROC and confusion charts
# -------------------------
def compute_auc(y_true, y_prob):
    # Output: AUROC (NaN if split has only one class)
    y_true = np.asarray(y_true, dtype=np.int64)
    y_prob = np.asarray(y_prob, dtype=np.float64)
    if len(np.unique(y_true)) < 2:
        return float("nan")
    return float(roc_auc_score(y_true, y_prob))

def compute_threshold_metrics(y_true, y_prob, thr=0.5):
    # Output: confusion matrix + accuracy/precision/recall/F1/specificity/MCC + Fisher p value
    y_true = np.asarray(y_true, dtype=np.int64)
    y_prob = np.asarray(y_prob, dtype=np.float64)
    y_pred = (y_prob >= float(thr)).astype(np.int64)

    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
    TN, FP, FN, TP = int(cm[0, 0]), int(cm[0, 1]), int(cm[1, 0]), int(cm[1, 1])

    eps = 1e-12
    acc = (TP + TN) / max(1, (TP + TN + FP + FN))
    prec = TP / (TP + FP + eps)
    rec = TP / (TP + FN + eps)     # sensitivity
    f1 = 2 * prec * rec / (prec + rec + eps)
    spec = TN / (TN + FP + eps)

    try:
        mcc = float(matthews_corrcoef(y_true, y_pred)) if len(np.unique(y_pred)) > 1 else float("nan")
    except Exception:
        mcc = float("nan")

    try:
        _, pval = fisher_exact([[TN, FP], [FN, TP]], alternative="two-sided")
        pval = float(pval)
    except Exception:
        pval = float("nan")

    return {
        "threshold": float(thr),
        "confusion_matrix": {"TN": TN, "FP": FP, "FN": FN, "TP": TP},
        "accuracy": float(acc),
        "precision": float(prec),
        "recall": float(rec),
        "f1_score": float(f1),
        "sensitivity": float(rec),
        "specificity": float(spec),
        "mcc": float(mcc),
        "p_value_fisher_two_sided": float(pval),
    }

def save_roc_curve_png(y_true, y_prob, out_png, title_suffix="Test"):
    # Output: ROC curve PNG (threshold free)
    fpr, tpr, _ = roc_curve(y_true, y_prob)
    plt.figure()
    plt.plot(fpr, tpr)
    plt.plot([0, 1], [0, 1], linestyle="--")
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    plt.title(f"ROC Curve ({title_suffix})")
    plt.tight_layout()
    plt.savefig(out_png, dpi=150)
    plt.close()

def save_confusion_png(y_true, y_prob, out_png, thr=0.5, title_suffix="Test"):
    # Output: confusion matrix PNG at a fixed threshold
    y_true = np.asarray(y_true, dtype=np.int64)
    y_prob = np.asarray(y_prob, dtype=np.float64)
    y_pred = (y_prob >= float(thr)).astype(int)
    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
    plt.figure()
    plt.imshow(cm)
    plt.xticks([0, 1], ["Healthy(0)", "PD(1)"])
    plt.yticks([0, 1], ["Healthy(0)", "PD(1)"])
    plt.xlabel("Predicted")
    plt.ylabel("True")
    plt.title(f"Confusion Matrix ({title_suffix}, thr={thr:.2f})")
    for (i, j), v in np.ndenumerate(cm):
        plt.text(j, i, str(v), ha="center", va="center")
    plt.tight_layout()
    plt.savefig(out_png, dpi=150)
    plt.close()

# -------------------------
# Val optimal threshold (Youden J): pick thr that maximizes TPR - FPR on VAL ROC
# -------------------------
def compute_val_opt_threshold_youden(y_true, y_prob):
    """
    Output: (thr_opt, j_opt, tpr_opt, fpr_opt)
    Fallback: if VAL has one class, returns thr_opt=0.5 and NaNs for others.
    """
    y_true = np.asarray(y_true, dtype=np.int64)
    y_prob = np.asarray(y_prob, dtype=np.float64)

    if len(np.unique(y_true)) < 2:
        return 0.5, float("nan"), float("nan"), float("nan")

    fpr, tpr, thr = roc_curve(y_true, y_prob)
    thr = np.asarray(thr, dtype=np.float64)
    fpr = np.asarray(fpr, dtype=np.float64)
    tpr = np.asarray(tpr, dtype=np.float64)

    finite = np.isfinite(thr)
    if finite.any():
        thr_f = thr[finite]
        fpr_f = fpr[finite]
        tpr_f = tpr[finite]
    else:
        thr_f, fpr_f, tpr_f = thr, fpr, tpr

    j = tpr_f - fpr_f
    idx = int(np.nanargmax(j)) if len(j) else 0

    thr_opt = float(thr_f[idx]) if len(thr_f) else 0.5
    j_opt   = float(j[idx]) if len(j) else float("nan")
    tpr_opt = float(tpr_f[idx]) if len(tpr_f) else float("nan")
    fpr_opt = float(fpr_f[idx]) if len(fpr_f) else float("nan")

    return thr_opt, j_opt, tpr_opt, fpr_opt

# -------------------------
# Fairness (H3): FNR by sex on PD only, plus ΔFNR = FNR(F) - FNR(M)
# -------------------------
def compute_fnr_by_group_signed(y_true, y_prob, groups, thr=0.5):
    """
    Output:
      - per group: n_total, n_pos (PD), tp, fn, fnr
      - delta_f_minus_m: FNR(F) - FNR(M) when both exist
      - delta_abs: absolute gap
    """
    y_true = np.asarray(y_true, dtype=np.int64)
    y_prob = np.asarray(y_prob, dtype=np.float64)
    groups = np.asarray(list(groups), dtype=object)

    y_pred = (y_prob >= float(thr)).astype(np.int64)

    out = {}
    for g in sorted(set(groups.tolist())):
        mask_g = (groups == g)
        if int(mask_g.sum()) == 0:
            continue

        pos_mask = mask_g & (y_true == 1)
        n_pos = int(pos_mask.sum())

        if n_pos == 0:
            out[g] = {"n_total": int(mask_g.sum()), "n_pos": 0, "tp": 0, "fn": 0, "fnr": float("nan")}
            continue

        tp = int(((y_pred == 1) & pos_mask).sum())
        fn = int(((y_pred == 0) & pos_mask).sum())
        fnr = float(fn / max(1, (fn + tp)))

        out[g] = {"n_total": int(mask_g.sum()), "n_pos": int(n_pos), "tp": int(tp), "fn": int(fn), "fnr": float(fnr)}

    fnr_m = out.get("M", {}).get("fnr", float("nan"))
    fnr_f = out.get("F", {}).get("fnr", float("nan"))
    if (not np.isnan(fnr_m)) and (not np.isnan(fnr_f)):
        delta_signed = float(fnr_f - fnr_m)   # H3
        delta_abs = float(abs(delta_signed))
    else:
        delta_signed = float("nan")
        delta_abs = float("nan")

    return out, delta_signed, delta_abs

# -------------------------
# Confusion breakdown by group: full TN/FP/FN/TP for each sex category
# -------------------------
def compute_confusion_counts(y_true, y_prob, thr=0.5):
    y_true = np.asarray(y_true, dtype=np.int64)
    y_prob = np.asarray(y_prob, dtype=np.float64)
    y_pred = (y_prob >= float(thr)).astype(np.int64)
    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
    TN, FP, FN, TP = int(cm[0, 0]), int(cm[0, 1]), int(cm[1, 0]), int(cm[1, 1])
    return {"TN": TN, "FP": FP, "FN": FN, "TP": TP}

def compute_confusion_by_group(y_true, y_prob, groups, thr=0.5):
    y_true = np.asarray(y_true, dtype=np.int64)
    y_prob = np.asarray(y_prob, dtype=np.float64)
    groups = np.asarray(list(groups), dtype=object)
    out = {}
    for g in sorted(set(groups.tolist())):
        mask = (groups == g)
        if int(mask.sum()) == 0:
            continue
        out[g] = {"n": int(mask.sum()), "confusion": compute_confusion_counts(y_true[mask], y_prob[mask], thr=thr)}
    return out

# -------------------------
# Reproducibility: set all RNG seeds per run
# -------------------------
def set_all_seeds(seed: int):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# -------------------------
# Find the most recent trainval experiment that has best_heads.pt for all 3 seeds
# -------------------------
TRAINVAL_ROOT = Path(DX_OUT_ROOT) / "trainval_runs"
if not TRAINVAL_ROOT.exists():
    raise FileNotFoundError(f"Missing trainval_runs folder: {str(TRAINVAL_ROOT)}")

exp_dirs = sorted([p for p in TRAINVAL_ROOT.glob("exp_*") if p.is_dir()], key=lambda p: p.stat().st_mtime, reverse=True)
if not exp_dirs:
    raise FileNotFoundError(f"No exp_* folders found under: {str(TRAINVAL_ROOT)}")

def _has_all_seeds(exp_path: Path, dataset_id: str, seeds: list):
    for s in seeds:
        p = exp_path / f"run_{dataset_id}_seed{s}" / "best_heads.pt"
        if not p.exists():
            return False
    return True

chosen_exp = None
for ed in exp_dirs:
    if _has_all_seeds(ed, dataset_id, SEEDS):
        chosen_exp = ed
        break

if chosen_exp is None:
    sample = exp_dirs[0]
    raise FileNotFoundError(
        "Could not find a trainval experiment with all 3 best_heads.pt files.\n"
        f"Expected under: {str(TRAINVAL_ROOT)}/exp_*/run_{dataset_id}_seedXXXX/best_heads.pt\n"
        f"Most recent exp checked: {str(sample)}"
    )

print("\nUsing Train+Val experiment folder:")
print(" ", str(chosen_exp))

# Guard: recheck expected artifacts right after choosing the exp folder
for s in SEEDS:
    p = chosen_exp / f"run_{dataset_id}_seed{s}" / "best_heads.pt"
    if not p.exists():
        raise RuntimeError(f"Trainval artifact missing after choosing exp. Missing: {str(p)}")

# -------------------------
# Output folder for test runs
# -------------------------
TEST_ROOT = Path(DX_OUT_ROOT) / "monolingual_test_runs"
TEST_ROOT.mkdir(parents=True, exist_ok=True)

# -------------------------
# Build DataLoaders and warm up a few batches to catch issues early
# -------------------------
def _num_batches(n_items: int, bs: int) -> int:
    if bs <= 0:
        return 0
    return int((n_items + bs - 1) // bs)

val_ds = AudioManifestDataset(val_df)
test_ds = AudioManifestDataset(test_df)

val_loader = DataLoader(
    val_ds,
    batch_size=PER_DEVICE_BS,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=PIN_MEMORY,
    collate_fn=collate_fn,
)

test_loader = DataLoader(
    test_ds,
    batch_size=PER_DEVICE_BS,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=PIN_MEMORY,
    collate_fn=collate_fn,
)

n_val_items = int(len(val_ds))
n_test_items = int(len(test_ds))
n_val_batches = _num_batches(n_val_items, int(PER_DEVICE_BS))
n_test_batches = _num_batches(n_test_items, int(PER_DEVICE_BS))

print(f"\nVAL  items: {n_val_items}  | batch_size: {PER_DEVICE_BS} | num_batches: {n_val_batches}")
print(f"TEST items: {n_test_items} | batch_size: {PER_DEVICE_BS} | num_batches: {n_test_batches}")

def _warmup(loader, name: str):
    # Loads up to 3 batches to verify reading, padding, and shapes
    n_items = len(loader.dataset)
    n_batches = _num_batches(int(n_items), int(PER_DEVICE_BS))
    n_warm = int(min(3, n_batches))
    print(f"\nWarm-up ({name}): loading {n_warm} batch(es)...")
    t0 = time.time()
    if n_warm == 0:
        raise RuntimeError(f"{name} DataLoader has 0 batches. Check that {name.lower()}_df is non-empty and paths exist.")
    it = iter(loader)
    for i in range(n_warm):
        _ = next(it)
        print(f"  loaded warmup batch {i+1}/{n_warm}")
    print(f"Warm-up ({name}) done in {time.time()-t0:.2f}s")

_warmup(val_loader, "VAL")
_warmup(test_loader, "TEST")

# -------------------------
# Load only the saved heads into a fresh model (backbone stays frozen)
# -------------------------
def load_heads_into_model(model: Wav2Vec2TwoHeadClassifier, best_heads_path: Path):
    if not best_heads_path.exists():
        raise FileNotFoundError(f"Missing best_heads.pt: {str(best_heads_path)}")
    state = torch.load(str(best_heads_path), map_location="cpu")

    needed = ["pre_vowel", "pre_other", "head_vowel", "head_other"]
    missing = [k for k in needed if k not in state]
    if missing:
        raise KeyError(
            f"best_heads.pt missing keys {missing}. Found keys: {list(state.keys())}. "
            "This test code expects the D5 trainval save format."
        )

    model.pre_vowel.load_state_dict(state["pre_vowel"], strict=True)
    model.pre_other.load_state_dict(state["pre_other"], strict=True)
    model.head_vowel.load_state_dict(state["head_vowel"], strict=True)
    model.head_other.load_state_dict(state["head_other"], strict=True)
    return model

# -------------------------
# Inference helper: run a loader and return labels, PD probabilities, and sex labels
# -------------------------
def infer_probs(model, loader, seed: int, split_name: str):
    use_amp = bool(USE_AMP and DEVICE.type == "cuda")
    all_probs, all_true, all_sex = [], [], []

    pbar = tqdm(loader, desc=f"[seed={seed}] {split_name}", dynamic_ncols=True)
    with torch.inference_mode():
        for batch in pbar:
            input_values = batch["input_values"].to(DEVICE, non_blocking=False)
            attention_mask = batch["attention_mask"].to(DEVICE, non_blocking=False)
            labels = batch["labels"].to(DEVICE, non_blocking=False)
            task_group = batch["task_group"]
            sex_norm = batch["sex_norm"]

            with torch.amp.autocast(device_type="cuda", enabled=use_amp):
                logits = model.forward_logits(input_values, attention_mask, task_group)

            # Output probability: PD class probability (class index 1)
            probs = torch.softmax(logits.float(), dim=-1)[:, 1].detach().cpu().numpy().astype(np.float64)
            all_probs.extend(probs.tolist())
            all_true.extend(labels.detach().cpu().numpy().astype(np.int64).tolist())
            all_sex.extend(list(sex_norm))

    return np.asarray(all_true, dtype=np.int64), np.asarray(all_probs, dtype=np.float64), np.asarray(all_sex, dtype=object)

# -------------------------
# Single seed evaluation:
# 1) infer on VAL -> choose threshold
# 2) infer on TEST -> report metrics and fairness at that threshold
# 3) write metrics.json and plots under run_<DATASET>_seedXXXX
# -------------------------
def run_test_once(seed: int):
    set_all_seeds(seed)

    run_dir = TEST_ROOT / f"run_{dataset_id}_seed{seed}"
    run_dir.mkdir(parents=True, exist_ok=True)

    best_heads_path = chosen_exp / f"run_{dataset_id}_seed{seed}" / "best_heads.pt"

    print(f"\n[seed={seed}] Loading model + heads from:")
    print(" ", str(best_heads_path))

    model = Wav2Vec2TwoHeadClassifier(BACKBONE_CKPT, dropout_p=DROPOUT_P).to(DEVICE)
    model = load_heads_into_model(model, best_heads_path)
    model.eval()

    # VAL: used only to pick a threshold (Youden J)
    y_val, p_val, _sex_val = infer_probs(model, val_loader, seed=seed, split_name="VAL")
    val_auc = compute_auc(y_val, p_val)

    val_thr_opt, val_j_opt, val_tpr_opt, val_fpr_opt = compute_val_opt_threshold_youden(y_val, p_val)

    print(f"[seed={seed}] VAL-opt threshold (Youden J): thr={val_thr_opt:.6f} | J={val_j_opt:.6f} | TPR={val_tpr_opt:.6f} | FPR={val_fpr_opt:.6f}")

    # TEST: final reporting split
    y_test, p_test, sex_test = infer_probs(model, test_loader, seed=seed, split_name="TEST")

    test_auc = compute_auc(y_test, p_test)
    thr_metrics = compute_threshold_metrics(y_test, p_test, thr=val_thr_opt)

    # Fairness on TEST using PD only denominators
    fnr_by_sex, delta_f_minus_m, delta_abs = compute_fnr_by_group_signed(
        y_test, p_test, sex_test, thr=val_thr_opt
    )

    # Confusion matrices per sex on TEST (all labels)
    confusion_by_sex = compute_confusion_by_group(y_test, p_test, sex_test, thr=val_thr_opt)

    # Plots (overall + sex specific confusion)
    roc_png = run_dir / "roc_curve.png"
    cm_png  = run_dir / "confusion_matrix.png"
    save_roc_curve_png(y_test, p_test, str(roc_png), title_suffix=f"Test (seed={seed})")
    save_confusion_png(y_test, p_test, str(cm_png), thr=val_thr_opt, title_suffix=f"Test (seed={seed})")

    cm_m_png = None
    cm_f_png = None
    mask_m = (sex_test == "M")
    mask_f = (sex_test == "F")

    if int(mask_m.sum()) > 0:
        cm_m_png = run_dir / "confusion_matrix_M.png"
        save_confusion_png(
            y_test[mask_m], p_test[mask_m], str(cm_m_png),
            thr=val_thr_opt, title_suffix=f"Test SEX=M (seed={seed})"
        )

    if int(mask_f.sum()) > 0:
        cm_f_png = run_dir / "confusion_matrix_F.png"
        save_confusion_png(
            y_test[mask_f], p_test[mask_f], str(cm_f_png),
            thr=val_thr_opt, title_suffix=f"Test SEX=F (seed={seed})"
        )

    # Save: per seed JSON metrics + artifact paths
    metrics = {
        "dataset": dataset_id,
        "seed": int(seed),

        "n_val": int(len(val_df)),
        "label_counts_val": val_df["label_num"].value_counts(dropna=False).to_dict(),
        "sex_counts_val_norm": val_df["sex_norm"].value_counts(dropna=False).to_dict(),

        "n_test": int(len(test_df)),
        "label_counts_test": test_df["label_num"].value_counts(dropna=False).to_dict(),
        "sex_counts_test_norm": test_df["sex_norm"].value_counts(dropna=False).to_dict(),

        "val_auroc": float(val_auc),
        "val_opt_threshold": {
            "method": "Youden J on VAL (maximize TPR - FPR)",
            "threshold": float(val_thr_opt),
            "youden_j": float(val_j_opt),
            "tpr_at_opt": float(val_tpr_opt),
            "fpr_at_opt": float(val_fpr_opt),
            "note": "If VAL has only one class, threshold defaults to 0.5 and J/TPR/FPR are NaN.",
        },

        "test_auroc": float(test_auc),
        "threshold_metrics_test_at_val_opt": thr_metrics,

        "fairness_test_at_val_opt": {
            "definition": "H3: ΔFNR = FNR(F) - FNR(M), where FNR(sex)=FN/(FN+TP) computed on PD-only true labels at VAL-opt threshold.",
            "fnr_by_sex_norm": fnr_by_sex,
            "delta_fnr_F_minus_M": float(delta_f_minus_m),
            "delta_fnr_abs": float(delta_abs),
            "sex_normalization_note": "sex_norm in {M,F,UNK}. Values not mapped to M/F counted as UNK. ΔFNR computed only when both M and F have defined FNR.",
        },

        "confusion_by_sex_norm_at_val_opt": confusion_by_sex,

        "artifacts": {
            "roc_curve_png": str(roc_png),
            "confusion_matrix_png": str(cm_png),
            "confusion_matrix_M_png": (str(cm_m_png) if cm_m_png is not None else None),
            "confusion_matrix_F_png": (str(cm_f_png) if cm_f_png is not None else None),
        },

        "dx_out_root": DX_OUT_ROOT,
        "trainval_experiment_used": str(chosen_exp),
        "best_heads_path": str(best_heads_path),
        "backbone_ckpt": BACKBONE_CKPT,
        "dropout_p": float(DROPOUT_P),

        "batching": {
            "per_device_bs": int(PER_DEVICE_BS),
            "effective_bs": int(PER_DEVICE_BS * GRAD_ACCUM),
            "num_val_batches": int(n_val_batches),
            "num_test_batches": int(n_test_batches),
            "n_val_items": int(n_val_items),
            "n_test_items": int(n_test_items),
        },
        "amp": bool(USE_AMP and DEVICE.type == "cuda"),
        "device": str(DEVICE),
        "gpu": (torch.cuda.get_device_name(0) if DEVICE.type == "cuda" else "CPU"),
        "timestamp": time.strftime("%Y%m%d_%H%M%S"),
    }

    metrics_path = run_dir / "metrics.json"
    with open(metrics_path, "w", encoding="utf-8") as f:
        json.dump(metrics, f, indent=2)

    # Console summary (kept as existing prints)
    def _fmt_fnr(g):
        d = fnr_by_sex.get(g, None)
        if d is None:
            return "n/a"
        return f"fnr={d['fnr']:.6f} (n_PD={d['n_pos']}, fn={d['fn']}, tp={d['tp']})"

    print(f"[seed={seed}] DONE | val_AUROC={val_auc:.6f} | test_AUROC={test_auc:.6f}")
    print(f"[seed={seed}] TEST METRICS @ VAL-OPT thr={val_thr_opt:.6f} written to metrics.json")
    print(f"[seed={seed}] FAIRNESS (H3) @ VAL-OPT thr={val_thr_opt:.6f}:")
    print("  M:", _fmt_fnr("M"))
    print("  F:", _fmt_fnr("F"))
    if "UNK" in fnr_by_sex:
        print("  UNK:", _fmt_fnr("UNK"))
    print("  ΔFNR (F-M):", f"{delta_f_minus_m:.6f}" if not np.isnan(delta_f_minus_m) else "nan")
    print("  |ΔFNR|:", f"{delta_abs:.6f}" if not np.isnan(delta_abs) else "nan")

    print(f"[seed={seed}] WROTE:")
    print(" ", str(metrics_path))
    print(" ", str(roc_png))
    print(" ", str(cm_png))
    if cm_m_png is not None:
        print(" ", str(cm_m_png))
    if cm_f_png is not None:
        print(" ", str(cm_f_png))

    return {
        "seed": int(seed),
        "val_auc": float(val_auc),
        "val_opt_thr": float(val_thr_opt),
        "val_youj": float(val_j_opt),
        "test_auc": float(test_auc),
        "thr_metrics_test": thr_metrics,
        "fnr_by_sex": fnr_by_sex,
        "delta_signed": float(delta_f_minus_m),
        "delta_abs": float(delta_abs),
        "confusion_by_sex": confusion_by_sex,
        "run_dir": str(run_dir),
    }

# -------------------------
# Run all seeds and summarize:
# - mean TEST AUROC with 95% CI
# - mean ± SD threshold metrics on TEST
# - mean ± SD fairness gaps on TEST
# -------------------------
results = []
for seed in SEEDS:
    results.append(run_test_once(seed))

test_aurocs = [r["test_auc"] for r in results]

t_crit = 4.302652729911275  # two-sided 95% Student t critical value for df = n - 1 = 2 (assumes exactly 3 seeds)
n = len(test_aurocs)
mean_auc = float(np.mean(test_aurocs))
std_auc = float(np.std(test_aurocs, ddof=1)) if n > 1 else 0.0
half_width = float(t_crit * (std_auc / math.sqrt(n))) if n > 1 else 0.0
ci95 = [float(mean_auc - half_width), float(mean_auc + half_width)]

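def _mean_sd_nan(vals):
    # Aggregates across seeds while ignoring NaNs. SD is the sample SD (ddof=1),
    # reported as 0.0 when fewer than two non-NaN values are available.
    vals = np.asarray(vals, dtype=np.float64)
    mean = float(np.nanmean(vals))
    sd = float(np.nanstd(vals, ddof=1)) if np.sum(~np.isnan(vals)) > 1 else 0.0
    return mean, sd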

# Aggregate threshold metrics (TEST @ VAL-opt thr)
keys = ["accuracy","precision","recall","f1_score","sensitivity","specificity","mcc","p_value_fisher_two_sided"]
agg = {}
for k in keys:
    v = [float(r["thr_metrics_test"].get(k, float("nan"))) for r in results]
    mu, sd = _mean_sd_nan(v)
    agg[k] = {
        "mean": mu,
        "sd": sd,
        "values_by_seed": {str(r["seed"]): float(r["thr_metrics_test"].get(k, float("nan"))) for r in results},
    }

cm_by_seed = {str(r["seed"]): r["thr_metrics_test"]["confusion_matrix"] for r in results}

# Aggregate thresholds picked on VAL (informational)
val_thr_vals = [float(r["val_opt_thr"]) for r in results]
val_thr_mean, val_thr_sd = _mean_sd_nan(val_thr_vals)

# Aggregate fairness (TEST @ VAL-opt thr)
fnr_m_vals, fnr_f_vals, n_pd_m_vals, n_pd_f_vals = [], [], [], []
d_signed_vals, d_abs_vals = [], []

for r in results:
    d = r["fnr_by_sex"]
    fnr_m_vals.append(float(d.get("M", {}).get("fnr", float("nan"))))
    fnr_f_vals.append(float(d.get("F", {}).get("fnr", float("nan"))))
    n_pd_m_vals.append(float(d.get("M", {}).get("n_pos", float("nan"))))
    n_pd_f_vals.append(float(d.get("F", {}).get("n_pos", float("nan"))))
    d_signed_vals.append(float(r["delta_signed"]))
    d_abs_vals.append(float(r["delta_abs"]))

fnr_m_mean, fnr_m_sd = _mean_sd_nan(fnr_m_vals)
fnr_f_mean, fnr_f_sd = _mean_sd_nan(fnr_f_vals)
d_signed_mean, d_signed_sd = _mean_sd_nan(d_signed_vals)
d_abs_mean, d_abs_sd = _mean_sd_nan(d_abs_vals)

print("\nTest AUROC by seed:")
for r in results:
    print(f"  seed {r['seed']}: {r['test_auc']:.6f}")
print(f"\nMean Test AUROC: {mean_auc:.6f}")
print(f"95% CI (t, n=3): [{ci95[0]:.6f}, {ci95[1]:.6f}]")

print("\nVAL-opt thresholds (Youden J) by seed:")
for r in results:
    print(f"  seed {r['seed']}: thr={r['val_opt_thr']:.6f}")
print(f"  mean ± SD: {val_thr_mean:.6f} ± {val_thr_sd:.6f}")

print("\nThreshold metrics on TEST @ VAL-OPT threshold (mean ± SD across seeds):")
for k in ["accuracy","precision","sensitivity","specificity","f1_score","mcc"]:
    mu = agg[k]["mean"]
    sd = agg[k]["sd"]
    print(f"  {k}: {mu:.6f} ± {sd:.6f}")
print("  fisher_p_value_two_sided:", f"{agg['p_value_fisher_two_sided']['mean']:.6g} ± {agg['p_value_fisher_two_sided']['sd']:.6g}")

print("\nFAIRNESS (H3) on TEST @ VAL-OPT threshold (mean ± SD):")
print(f"  FNR_M: {fnr_m_mean:.6f} ± {fnr_m_sd:.6f}")
print(f"  FNR_F: {fnr_f_mean:.6f} ± {fnr_f_sd:.6f}")
print(f"  ΔFNR (F-M): {d_signed_mean:.6f} ± {d_signed_sd:.6f}")
print(f"  |ΔFNR|: {d_abs_mean:.6f} ± {d_abs_sd:.6f}")
print("  Per-seed:", {
    str(r["seed"]): {
        "val_opt_thr": r["val_opt_thr"],
        "FNR_M": fnr_m_vals[i],
        "n_PD_M": n_pd_m_vals[i],
        "FNR_F": fnr_f_vals[i],
        "n_PD_F": n_pd_f_vals[i],
        "delta_F_minus_M": d_signed_vals[i],
        "abs_delta": d_abs_vals[i],
    } for i, r in enumerate(results)
})

# Summary JSON: single place to read all aggregated results + pointers to per-seed runs
summary = {
    "dataset": dataset_id,
    "dx_out_root": DX_OUT_ROOT,
    "manifest_all": MANIFEST_ALL,
    "trainval_experiment_used": str(chosen_exp),
    "seeds": SEEDS,

    "n_val": int(len(val_df)),
    "label_counts_val": val_df["label_num"].value_counts(dropna=False).to_dict(),
    "sex_counts_val_norm": val_df["sex_norm"].value_counts(dropna=False).to_dict(),

    "n_test": int(len(test_df)),
    "label_counts_test": test_df["label_num"].value_counts(dropna=False).to_dict(),
    "sex_counts_test_norm": test_df["sex_norm"].value_counts(dropna=False).to_dict(),

    "val_opt_threshold_by_seed": {str(r["seed"]): float(r["val_opt_thr"]) for r in results},
    "val_opt_threshold_mean_sd": {"mean": float(val_thr_mean), "sd": float(val_thr_sd)},
    "val_youden_j_by_seed": {str(r["seed"]): float(r["val_youj"]) for r in results},

    "test_aurocs_by_seed": {str(r["seed"]): float(r["test_auc"]) for r in results},
    "mean_test_auroc": float(mean_auc),
    "t_crit_df2_95": float(t_crit),
    "half_width_95_ci": float(half_width),
    "ci95_test_auroc": ci95,

    "threshold_metrics_test_at_val_opt_mean_sd": agg,
    "confusion_matrix_test_at_val_opt_by_seed": cm_by_seed,
    "run_dirs": [r["run_dir"] for r in results],

    "fairness_test_at_val_opt": {
        "definition": "H3: ΔFNR = FNR(F) - FNR(M), where FNR(sex)=FN/(FN+TP) computed on PD-only true labels at VAL-opt threshold.",
        "fnr_by_sex_norm_by_seed": {str(r["seed"]): r["fnr_by_sex"] for r in results},
        "delta_fnr_F_minus_M_by_seed": {str(r["seed"]): float(r["delta_signed"]) for r in results},
        "delta_fnr_abs_by_seed": {str(r["seed"]): float(r["delta_abs"]) for r in results},
        "fnr_M_mean_sd": {"mean": float(fnr_m_mean), "sd": float(fnr_m_sd), "values_by_seed": {str(r["seed"]): float(v) for r, v in zip(results, fnr_m_vals)}},
        "fnr_F_mean_sd": {"mean": float(fnr_f_mean), "sd": float(fnr_f_sd), "values_by_seed": {str(r["seed"]): float(v) for r, v in zip(results, fnr_f_vals)}},
        "delta_fnr_F_minus_M_mean_sd": {"mean": float(d_signed_mean), "sd": float(d_signed_sd), "values_by_seed": {str(r["seed"]): float(v) for r, v in zip(results, d_signed_vals)}},
        "delta_fnr_abs_mean_sd": {"mean": float(d_abs_mean), "sd": float(d_abs_sd), "values_by_seed": {str(r["seed"]): float(v) for r, v in zip(results, d_abs_vals)}},
        "denominators_PD_by_seed": {str(r["seed"]): {"n_PD_M": float(n_pd_m_vals[i]), "n_PD_F": float(n_pd_f_vals[i])} for i, r in enumerate(results)},
        "sex_normalization_note": "sex_norm in {M,F,UNK}. Values not mapped to M/F counted as UNK. ΔFNR computed only when both M and F have defined FNR.",
    },

    "batching": {
        "per_device_bs": int(PER_DEVICE_BS),
        "effective_bs": int(PER_DEVICE_BS * GRAD_ACCUM),
        "num_val_batches": int(n_val_batches),
        "num_test_batches": int(n_test_batches),
        "n_val_items": int(n_val_items),
        "n_test_items": int(n_test_items),
    },
    "device": str(DEVICE),
    "gpu": (torch.cuda.get_device_name(0) if DEVICE.type == "cuda" else "CPU"),
    "timestamp": time.strftime("%Y%m%d_%H%M%S"),
}

summary_path = TEST_ROOT / "summary_test.json"
with open(summary_path, "w", encoding="utf-8") as f:
    json.dump(summary, f, indent=2)

history_path = TEST_ROOT / "history_index.jsonl"
with open(history_path, "a", encoding="utf-8") as f:
    f.write(json.dumps(summary) + "\n")

print("\nWROTE summary:", str(summary_path))
print("APPENDED history index:", str(history_path))
print("Open this folder to access artifacts:", str(TEST_ROOT))

# -------------------------
# Stop the runtime to avoid leaving the GPU instance running
# -------------------------
print("\nAll done. Unassigning the runtime to stop the L4 instance...")
try:
    from google.colab import runtime  # type: ignore
    print("Calling runtime.unassign() now.")
    runtime.unassign()
except Exception as e:
    print("Could not unassign runtime automatically. You can stop the runtime manually in Colab.")
    print("Reason:", repr(e))

The following cell runs a monolingual test evaluation for D6 (AhSound) using the model heads from the most recent D6 training and validation run. It reads `manifest_all.csv`, keeps only the validation and test rows, and checks that the dataset is clearly D6 so the wrong dataset is not evaluated by mistake. It also confirms that the expected `best_heads.pt` files exist for all three seeds (1337, 2024, 7777) before any evaluation starts.

The cell then prepares the data and model in a consistent, repeatable way. It checks that all audio files listed in the manifest exist on disk, assigns each clip to a task group (“vowel” when `task == "vowl"`, otherwise “other”), and normalizes sex labels to M, F, or UNK so that sex-based results are reported consistently. Audio is loaded from disk, verified to have a 16 kHz sample rate, and formatted for use with the Wav2Vec2 backbone. For vowel clips, the attention mask is adjusted so that the model ignores trailing near-zero samples (padding-like silence). For non-vowel clips, the full signal is used.
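
The trailing-silence masking described above can be sketched as follows. This is a minimal NumPy illustration rather than the notebook's exact code: the function name `trailing_silence_mask` and the default threshold `1e-4` are assumptions made for the example, and the notebook's own threshold constant may differ.

```python
import numpy as np

def trailing_silence_mask(y: np.ndarray, tiny_thresh: float = 1e-4) -> np.ndarray:
    """Return a 0/1 sample mask that zeroes everything after the last sample above tiny_thresh.

    tiny_thresh is a hypothetical default chosen for this sketch.
    """
    mask = np.ones(len(y), dtype=np.int64)
    above = np.flatnonzero(np.abs(y) > tiny_thresh)
    if above.size > 0:
        mask[above[-1] + 1:] = 0  # mask out the trailing near-zero tail only
    return mask  # a clip with no sample above the threshold keeps a full mask

# Toy waveform: leading silence is kept, trailing silence is masked
wave = np.array([0.0, 0.2, -0.3, 0.0, 0.0], dtype=np.float32)
print(trailing_silence_mask(wave).tolist())
```

Only the tail is masked; near-zero samples before the last audible sample stay visible to the model, which matches the "trailing" wording in the description.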

For each seed, the cell runs two linked steps. First, a validation pass is used only to choose a decision threshold: inference is run on the D6 validation split, and the threshold that maximizes Youden's J (TPR minus FPR) on the validation ROC curve is selected. Second, a test pass is run on the D6 test split using that same threshold. During this step, the cell reports test AUROC, threshold-based metrics at the chosen threshold (including the confusion matrix, accuracy, precision, sensitivity, specificity, F1, MCC, and the two-sided Fisher exact test p-value), and fairness results using the H3 definition. Fairness is computed as the false negative rate (FNR) for males and females separately, along with ΔFNR, defined as FNR(F) minus FNR(M). Confusion counts by sex are also recorded.
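
The two linked steps can be illustrated on toy data. The sketch below is written in plain NumPy as a stand-in for the `roc_curve`-based selection used in the cell. The helper names (`youden_threshold`, `fnr_at`) and all scores, labels, and sex values are made up for the example and are not results from the study.

```python
import numpy as np


def youden_threshold(y_true, y_prob):
    """Hypothetical helper: pick the score threshold that maximizes TPR - FPR."""
    y_true = np.asarray(y_true, dtype=np.int64)
    y_prob = np.asarray(y_prob, dtype=np.float64)
    candidates = np.unique(y_prob)[::-1]  # unique scores, highest first
    n_pos = max(1, int((y_true == 1).sum()))
    n_neg = max(1, int((y_true == 0).sum()))
    j_vals = []
    for thr in candidates:
        pred = y_prob >= thr
        tpr = (pred & (y_true == 1)).sum() / n_pos
        fpr = (pred & (y_true == 0)).sum() / n_neg
        j_vals.append(tpr - fpr)
    return float(candidates[int(np.argmax(j_vals))])


def fnr_at(y_true, y_prob, thr):
    """Hypothetical helper: FNR = FN / (FN + TP), computed on true-PD rows only."""
    y_true = np.asarray(y_true, dtype=np.int64)
    pred = np.asarray(y_prob, dtype=np.float64) >= thr
    pos = y_true == 1
    if pos.sum() == 0:
        return float("nan")
    return float((~pred & pos).sum() / pos.sum())


# Step 1: choose the threshold on (toy) validation scores only.
val_true = np.array([0, 0, 0, 0, 1, 1, 1, 1])
val_prob = np.array([0.10, 0.20, 0.60, 0.65, 0.50, 0.70, 0.80, 0.90])
thr = youden_threshold(val_true, val_prob)

# Step 2: apply that fixed threshold to (toy) test scores and compare FNR by sex.
test_true = np.array([1, 1, 1, 1, 0, 1, 1, 1, 1, 0])
test_prob = np.array([0.90, 0.80, 0.75, 0.30, 0.20, 0.90, 0.60, 0.50, 0.40, 0.10])
test_sex = np.array(["M"] * 5 + ["F"] * 5)

fnr_m = fnr_at(test_true[test_sex == "M"], test_prob[test_sex == "M"], thr)
fnr_f = fnr_at(test_true[test_sex == "F"], test_prob[test_sex == "F"], thr)
delta_fnr = fnr_f - fnr_m  # H3 sign convention: FNR(F) minus FNR(M)
print(thr, fnr_m, fnr_f, delta_fnr)
```

The threshold is never tuned on test data. A positive ΔFNR means PD cases are missed more often for female speakers than for male speakers at that threshold.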

For each seed, the cell saves the ROC curve and confusion matrix plots, along with sex-specific confusion matrices when that sex group is present in the test split. It also writes a detailed `metrics.json` file under
`<DX_OUT_ROOT>/monolingual_test_runs/run_D6_seed<seed>/`.

After all three seeds complete, the cell combines results across seeds. It reports the mean test AUROC with a 95% confidence interval from a t distribution (n = 3, so 2 degrees of freedom), as well as the mean and standard deviation of the validation-chosen thresholds, test threshold metrics, and fairness values. The combined summary is saved to `monolingual_test_runs/summary_test.json`, the same record is appended to `monolingual_test_runs/history_index.jsonl`, the output locations are printed, and the Colab runtime is unassigned to stop the GPU instance.
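
As a worked example of the confidence interval, the sketch below derives the t critical value with `scipy.stats.t.ppf` instead of hard-coding it. The three AUROC values are placeholders chosen for illustration, not results from this study.

```python
import math

import numpy as np
from scipy.stats import t as student_t

# Placeholder per-seed AUROC values (illustrative only).
aurocs = [0.80, 0.85, 0.90]

n = len(aurocs)
mean_auc = float(np.mean(aurocs))
sd_auc = float(np.std(aurocs, ddof=1))  # sample standard deviation

# Two-sided 95% interval: 0.975 quantile of the t distribution with n - 1 degrees of freedom.
t_crit = float(student_t.ppf(0.975, df=n - 1))
half_width = t_crit * sd_auc / math.sqrt(n)
ci95 = (mean_auc - half_width, mean_auc + half_width)

print(f"t_crit={t_crit:.6f} mean={mean_auc:.4f} "
      f"CI95=[{ci95[0]:.4f}, {ci95[1]:.4f}]")
```

With only three seeds the critical value is large (about 4.30, compared with 1.96 for a normal approximation), so the interval is wide even when the seeds agree closely.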

In [None]:
# =========================
# D6 Monolingual Test: VAL threshold (Youden J) + TEST metrics + fairness
# - Computes a VAL-optimal threshold per seed (maximize TPR - FPR)
# - Evaluates TEST metrics and fairness at that VAL-opt threshold
# - Saves per-seed artifacts and a cross-seed summary under monolingual_test_runs/
# =========================

import os, json, math, random, time, warnings
from pathlib import Path

import numpy as np
import pandas as pd
import soundfile as sf

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

from transformers import Wav2Vec2Model

from sklearn.metrics import roc_auc_score, roc_curve, confusion_matrix, matthews_corrcoef
from scipy.stats import fisher_exact
import matplotlib.pyplot as plt

from tqdm.auto import tqdm

# -------------------------
# Import safety checks
# Inputs: /content directory
# Output: raises early if a local file/folder would override torch/transformers
# -------------------------
if os.path.exists("/content/torch.py") or os.path.exists("/content/torch/__init__.py"):
    raise RuntimeError("Found local /content/torch.py or /content/torch/ that shadows PyTorch. Rename/remove it and restart runtime.")
if os.path.exists("/content/transformers.py") or os.path.exists("/content/transformers/__init__.py"):
    raise RuntimeError("Found local /content/transformers.py or /content/transformers/ that shadows Hugging Face Transformers. Rename/remove it and restart runtime.")

# -------------------------
# Drive mount (Colab)
# Inputs: Google Drive
# Output: /content/drive mounted for reading inputs and writing results
# -------------------------
if not os.path.isdir("/content/drive/MyDrive"):
    from google.colab import drive  # type: ignore
    drive.mount("/content/drive")

# -------------------------
# Roots and manifest path
# Inputs: optional DX_OUT_ROOT global, else fallback
# Output: MANIFEST_ALL path for loading val/test rows
# -------------------------
D6_OUT_ROOT_FALLBACK = "/content/drive/MyDrive/AI_PD_Project/Datasets/D6-Ah Sound (Figshare)/preprocessed_v1"
DX_OUT_ROOT = globals().get("DX_OUT_ROOT", D6_OUT_ROOT_FALLBACK)
MANIFEST_ALL = f"{DX_OUT_ROOT}/manifests/manifest_all.csv"

# -------------------------
# Runtime settings (kept consistent with other test cells)
# Inputs: constants below
# Output: printed config and device selection
# -------------------------
SEEDS          = [1337, 2024, 7777]
BACKBONE_CKPT  = "facebook/wav2vec2-base"
SR_EXPECTED    = 16000
TINY_THRESH    = 1e-4

EFFECTIVE_BS   = 64
PER_DEVICE_BS  = 16
GRAD_ACCUM     = max(1, EFFECTIVE_BS // PER_DEVICE_BS)  # printed only

DROPOUT_P      = 0.2

NUM_WORKERS    = 0
PIN_MEMORY     = False

USE_AMP        = True

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
try:
    torch.set_float32_matmul_precision("high")
except Exception:
    pass

warnings.filterwarnings("ignore", message="None of the inputs have requires_grad=True. Gradients will be None")
warnings.filterwarnings("ignore", message="Passing `gradient_checkpointing` to a config initialization is deprecated")

print("DX_OUT_ROOT:", DX_OUT_ROOT)
print("MANIFEST_ALL:", MANIFEST_ALL)
print("DEVICE:", DEVICE)
if DEVICE.type == "cuda":
    print("GPU:", torch.cuda.get_device_name(0))
print("PER_DEVICE_BS:", PER_DEVICE_BS, "| EFFECTIVE_BS:", PER_DEVICE_BS * GRAD_ACCUM)
print("NUM_WORKERS:", NUM_WORKERS, "| PIN_MEMORY:", PIN_MEMORY)
print("USE_AMP:", bool(USE_AMP and DEVICE.type == "cuda"))

# -------------------------
# Manifest load: keep VAL and TEST only, then assert D6
# Inputs: manifest_all.csv
# Output: val_df and test_df with stable columns + basic counts printed
# -------------------------
if not os.path.exists(MANIFEST_ALL):
    raise FileNotFoundError(f"Missing manifest_all.csv: {MANIFEST_ALL}")

m = pd.read_csv(MANIFEST_ALL)

req_cols = {"split", "clip_path", "label_num", "task", "sex", "age"}
missing = [c for c in sorted(req_cols) if c not in m.columns]
if missing:
    raise ValueError(f"Manifest missing required columns: {missing}. Found: {list(m.columns)}")

m = m[m["split"].isin(["val", "test"])].copy()
if len(m) == 0:
    raise RuntimeError("After filtering to split in {'val','test'}, manifest has 0 rows.")

if "dataset" in m.columns and m["dataset"].notna().any():
    dataset_id = str(m["dataset"].astype(str).value_counts(dropna=True).idxmax())
    m = m[m["dataset"].astype(str) == dataset_id].copy()
else:
    dataset_id = "DX"

# Guard A: prevent accidental evaluation on a different dataset
if dataset_id != "D6":
    raise RuntimeError(
        f"Expected dataset_id=='D6' but got {dataset_id!r}. "
        "This usually means DX_OUT_ROOT was inherited from a previous cell or the manifest is not D6. "
        f"DX_OUT_ROOT={DX_OUT_ROOT}"
    )

keep_cols = ["clip_path", "label_num", "task", "speaker_id", "sex", "age", "duration_sec", "split"]
for c in keep_cols:
    if c not in m.columns:
        m[c] = np.nan
m = m[keep_cols].copy()

val_df  = m[m["split"].astype(str) == "val"].reset_index(drop=True)
test_df = m[m["split"].astype(str) == "test"].reset_index(drop=True)

print(f"\nDataset inferred: {dataset_id}")
print(f"Val rows:  {len(val_df)}")
print(f"Test rows: {len(test_df)}")
print("Val label counts:",  val_df["label_num"].value_counts(dropna=False).to_dict() if len(val_df) else {})
print("Test label counts:", test_df["label_num"].value_counts(dropna=False).to_dict())
print("Val sex counts (raw):",  val_df["sex"].value_counts(dropna=False).to_dict() if len(val_df) else {})
print("Test sex counts (raw):", test_df["sex"].value_counts(dropna=False).to_dict())

if len(val_df) == 0:
    raise RuntimeError("Val split has 0 rows. VAL is required to compute VAL-opt threshold (Youden J).")
if len(test_df) == 0:
    raise RuntimeError("Test split has 0 rows.")

# -------------------------
# Clip existence checks (fail early, show a few missing)
# Inputs: val_df/test_df clip_path
# Output: raises with example missing paths
# -------------------------
def _fail_fast_missing_paths(df: pd.DataFrame, name: str):
    missing_paths = []
    for p in tqdm(df["clip_path"].astype(str).tolist(), desc=f"Check {name} clip_path exists", dynamic_ncols=True):
        if not os.path.exists(p):
            missing_paths.append(p)
            if len(missing_paths) >= 10:
                break
    if missing_paths:
        raise FileNotFoundError(f"{name}: missing clip_path(s). Examples: {missing_paths}")

_fail_fast_missing_paths(val_df, "VAL")
_fail_fast_missing_paths(test_df, "TEST")

# -------------------------
# Task grouping: choose vowel head vs other head
# Inputs: task column values
# Output: task_group used during inference
# -------------------------
def _task_group(task_val) -> str:
    return "vowel" if str(task_val) == "vowl" else "other"

val_df["task_group"]  = val_df["task"].apply(_task_group)
test_df["task_group"] = test_df["task"].apply(_task_group)

# -------------------------
# Sex normalization for fairness and sex-specific plots
# Inputs: sex values from the manifest
# Output: sex_norm in {M, F, UNK}
# -------------------------
def normalize_sex(val) -> str:
    """
    Returns 'M', 'F', or 'UNK'
    Handles common strings (M/F, male/female, etc.)
    Numeric encodings are not guessed to avoid silent mis-mapping.
    """
    if pd.isna(val):
        return "UNK"
    s = str(val).strip().lower()

    if s in {"m", "male", "man", "masc", "masculine"}:
        return "M"
    if s in {"f", "female", "woman", "fem", "feminine"}:
        return "F"

    return "UNK"

val_df["sex_norm"]  = val_df["sex"].apply(normalize_sex)
test_df["sex_norm"] = test_df["sex"].apply(normalize_sex)

print("Val sex counts (normalized):",  val_df["sex_norm"].value_counts(dropna=False).to_dict())
print("Test sex counts (normalized):", test_df["sex_norm"].value_counts(dropna=False).to_dict())
if (val_df["sex_norm"] == "UNK").any() or (test_df["sex_norm"] == "UNK").any():
    print("NOTE: Some 'sex' values could not be normalized to M/F and were counted as 'UNK' for fairness and sex charts.")

# -------------------------
# Audio dataset + collator (padding + attention masks)
# Inputs: val_df/test_df and audio files on disk
# Output: batches with padded input_values and attention_mask
# -------------------------
class AudioManifestDataset(Dataset):
    def __init__(self, df: pd.DataFrame):
        self.df = df.reset_index(drop=True)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx: int):
        row = self.df.iloc[idx]
        clip_path = str(row["clip_path"])
        label = int(row["label_num"])
        task_group = str(row["task_group"])
        sex_norm = str(row["sex_norm"])

        if not os.path.exists(clip_path):
            raise FileNotFoundError(f"Missing clip_path: {clip_path}")

        y, sr = sf.read(clip_path, always_2d=False)
        if isinstance(y, np.ndarray) and y.ndim > 1:
            y = np.mean(y, axis=1)
        y = np.asarray(y, dtype=np.float32)

        if int(sr) != SR_EXPECTED:
            raise ValueError(f"Sample rate mismatch (expected {SR_EXPECTED}): {clip_path} has sr={sr}")

        # For vowel clips, mask out trailing near-zero padding: keep samples up to the
        # last one with |y| > TINY_THRESH. Non-vowel clips and all-silent vowel clips
        # keep a full mask of ones.
        attn = np.ones((len(y),), dtype=np.int64)
        if task_group == "vowel":
            loud = np.flatnonzero(np.abs(y) > TINY_THRESH)
            if loud.size > 0:
                attn[int(loud[-1]) + 1:] = 0

        return {
            "input_values": torch.from_numpy(y),
            "attention_mask": torch.from_numpy(attn),
            "labels": torch.tensor(label, dtype=torch.long),
            "task_group": task_group,
            "sex_norm": sex_norm,
        }

def collate_fn(batch):
    max_len = int(max(b["input_values"].numel() for b in batch))
    input_vals, attn_masks, labels, task_groups, sex_norms = [], [], [], [], []
    for b in batch:
        x = b["input_values"]
        a = b["attention_mask"]
        pad = max_len - x.numel()
        if pad > 0:
            x = torch.cat([x, torch.zeros(pad, dtype=x.dtype)], dim=0)
            a = torch.cat([a, torch.zeros(pad, dtype=a.dtype)], dim=0)
        input_vals.append(x)
        attn_masks.append(a)
        labels.append(b["labels"])
        task_groups.append(b["task_group"])
        sex_norms.append(b["sex_norm"])
    return {
        "input_values": torch.stack(input_vals, dim=0),
        "attention_mask": torch.stack(attn_masks, dim=0),
        "labels": torch.stack(labels, dim=0),
        "task_group": task_groups,
        "sex_norm": sex_norms,
    }

# -------------------------
# Model: frozen wav2vec2 + two heads (vowel vs other)
# Inputs: BACKBONE_CKPT and dropout
# Output: logits for PD vs healthy using the correct head per task_group
# -------------------------
class Wav2Vec2TwoHeadClassifier(nn.Module):
    def __init__(self, ckpt: str, dropout_p: float):
        super().__init__()
        self.backbone = Wav2Vec2Model.from_pretrained(ckpt, use_safetensors=True)
        for p in self.backbone.parameters():
            p.requires_grad = False

        H = int(self.backbone.config.hidden_size)
        self.pre_vowel = nn.Sequential(nn.LayerNorm(H), nn.Dropout(float(dropout_p)))
        self.pre_other = nn.Sequential(nn.LayerNorm(H), nn.Dropout(float(dropout_p)))
        self.head_vowel = nn.Linear(H, 2)
        self.head_other = nn.Linear(H, 2)

    def masked_mean_pool(self, last_hidden, attn_mask_samples):
        feat_mask = self.backbone._get_feature_vector_attention_mask(last_hidden.shape[1], attn_mask_samples)
        feat_mask = feat_mask.to(last_hidden.device).unsqueeze(-1).type_as(last_hidden)
        masked = last_hidden * feat_mask
        denom = feat_mask.sum(dim=1).clamp(min=1.0)
        return masked.sum(dim=1) / denom

    def _heads_fp32(self, x_fp_any: torch.Tensor, head: nn.Module) -> torch.Tensor:
        # Keep head math in fp32 even under AMP
        x = x_fp_any.float()
        if x.is_cuda:
            with torch.amp.autocast(device_type="cuda", enabled=False):
                return head(x)
        return head(x)

    def forward_logits(self, input_values, attention_mask, task_group):
        with torch.no_grad():
            out = self.backbone(input_values=input_values, attention_mask=attention_mask)
            last_hidden = out.last_hidden_state

        pooled = self.masked_mean_pool(last_hidden, attention_mask)  # [B,H]

        z_v = self.pre_vowel(pooled.float())
        z_o = self.pre_other(pooled.float())

        logits_v = self._heads_fp32(z_v, self.head_vowel)
        logits_o = self._heads_fp32(z_o, self.head_other)

        is_vowel = torch.tensor([tg == "vowel" for tg in task_group], device=pooled.device, dtype=torch.bool)
        logits = logits_o.clone()
        if is_vowel.any():
            logits[is_vowel] = logits_v[is_vowel]
        return logits

# -------------------------
# Metrics + plots
# Inputs: y_true and y_prob (PD probability), plus threshold when needed
# Output: metric dicts and PNG plots saved to each seed folder
# -------------------------
def compute_auc(y_true, y_prob):
    y_true = np.asarray(y_true, dtype=np.int64)
    y_prob = np.asarray(y_prob, dtype=np.float64)
    if len(np.unique(y_true)) < 2:
        return float("nan")
    return float(roc_auc_score(y_true, y_prob))

def compute_threshold_metrics(y_true, y_prob, thr=0.5):
    y_true = np.asarray(y_true, dtype=np.int64)
    y_prob = np.asarray(y_prob, dtype=np.float64)
    y_pred = (y_prob >= float(thr)).astype(np.int64)

    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
    TN, FP, FN, TP = int(cm[0, 0]), int(cm[0, 1]), int(cm[1, 0]), int(cm[1, 1])

    eps = 1e-12
    acc = (TP + TN) / max(1, (TP + TN + FP + FN))
    prec = TP / (TP + FP + eps)
    rec = TP / (TP + FN + eps)
    f1 = 2 * prec * rec / (prec + rec + eps)
    spec = TN / (TN + FP + eps)

    try:
        mcc = float(matthews_corrcoef(y_true, y_pred)) if len(np.unique(y_pred)) > 1 else float("nan")
    except Exception:
        mcc = float("nan")

    try:
        _, pval = fisher_exact([[TN, FP], [FN, TP]], alternative="two-sided")
        pval = float(pval)
    except Exception:
        pval = float("nan")

    return {
        "threshold": float(thr),
        "confusion_matrix": {"TN": TN, "FP": FP, "FN": FN, "TP": TP},
        "accuracy": float(acc),
        "precision": float(prec),
        "recall": float(rec),
        "f1_score": float(f1),
        "sensitivity": float(rec),
        "specificity": float(spec),
        "mcc": float(mcc),
        "p_value_fisher_two_sided": float(pval),
    }

def save_roc_curve_png(y_true, y_prob, out_png, title_suffix="Test"):
    fpr, tpr, _ = roc_curve(y_true, y_prob)
    plt.figure()
    plt.plot(fpr, tpr)
    plt.plot([0, 1], [0, 1], linestyle="--")
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    plt.title(f"ROC Curve ({title_suffix})")
    plt.tight_layout()
    plt.savefig(out_png, dpi=150)
    plt.close()

def save_confusion_png(y_true, y_prob, out_png, thr=0.5, title_suffix="Test"):
    y_true = np.asarray(y_true, dtype=np.int64)
    y_prob = np.asarray(y_prob, dtype=np.float64)
    y_pred = (y_prob >= float(thr)).astype(int)
    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])

    plt.figure()
    plt.imshow(cm)
    plt.xticks([0, 1], ["Healthy(0)", "PD(1)"])
    plt.yticks([0, 1], ["Healthy(0)", "PD(1)"])
    plt.xlabel("Predicted")
    plt.ylabel("True")
    plt.title(f"Confusion Matrix ({title_suffix}, thr={thr:.2f})")
    for (i, j), v in np.ndenumerate(cm):
        plt.text(j, i, str(v), ha="center", va="center")
    plt.tight_layout()
    plt.savefig(out_png, dpi=150)
    plt.close()

# -------------------------
# VAL-opt threshold (Youden J): choose threshold with max (TPR - FPR)
# Inputs: VAL labels and probabilities
# Output: best threshold and its ROC point
# -------------------------
def compute_val_opt_threshold_youden_j(y_true, y_prob):
    """
    Returns:
      thr_opt, youden_j_opt, fpr_opt, tpr_opt
    If ROC cannot be computed (single-class), returns (nan, nan, nan, nan).
    """
    y_true = np.asarray(y_true, dtype=np.int64)
    y_prob = np.asarray(y_prob, dtype=np.float64)

    if len(np.unique(y_true)) < 2:
        return float("nan"), float("nan"), float("nan"), float("nan")

    fpr, tpr, thr = roc_curve(y_true, y_prob)
    j = tpr - fpr
    if len(j) == 0:
        return float("nan"), float("nan"), float("nan"), float("nan")
    k = int(np.nanargmax(j))
    return float(thr[k]), float(j[k]), float(fpr[k]), float(tpr[k])

# -------------------------
# Fairness (H3): FNR per sex and ΔFNR = FNR(F) - FNR(M)
# Inputs: TEST labels/probabilities, sex_norm, threshold
# Output: per-sex FNR details + signed and absolute gap
# -------------------------
def compute_fnr_by_group_signed(y_true, y_prob, groups, thr=0.5):
    """
    FNR = FN/(FN+TP) computed only on PD ground-truth rows.
    """
    y_true = np.asarray(y_true, dtype=np.int64)
    y_prob = np.asarray(y_prob, dtype=np.float64)
    groups = np.asarray(list(groups), dtype=object)
    y_pred = (y_prob >= float(thr)).astype(np.int64)

    out = {}
    for g in sorted(set(groups.tolist())):
        mask_g = (groups == g)
        if mask_g.sum() == 0:
            continue

        pos_mask = mask_g & (y_true == 1)
        n_pos = int(pos_mask.sum())
        if n_pos == 0:
            out[g] = {"n_total": int(mask_g.sum()), "n_pos": 0, "tp": 0, "fn": 0, "fnr": float("nan")}
            continue

        tp = int(((y_pred == 1) & pos_mask).sum())
        fn = int(((y_pred == 0) & pos_mask).sum())
        fnr = float(fn / max(1, (fn + tp)))

        out[g] = {"n_total": int(mask_g.sum()), "n_pos": int(n_pos), "tp": int(tp), "fn": int(fn), "fnr": float(fnr)}

    fnr_m = out.get("M", {}).get("fnr", float("nan"))
    fnr_f = out.get("F", {}).get("fnr", float("nan"))

    if (not np.isnan(fnr_m)) and (not np.isnan(fnr_f)):
        delta_signed = float(fnr_f - fnr_m)
        delta_abs = float(abs(delta_signed))
    else:
        delta_signed = float("nan")
        delta_abs = float("nan")

    return out, delta_signed, delta_abs

# -------------------------
# Confusion counts by sex group (for reporting, not plotting)
# Inputs: TEST labels/probabilities, sex_norm, threshold
# Output: TN/FP/FN/TP per sex group
# -------------------------
def compute_confusion_counts(y_true, y_prob, thr=0.5):
    y_true = np.asarray(y_true, dtype=np.int64)
    y_prob = np.asarray(y_prob, dtype=np.float64)
    y_pred = (y_prob >= float(thr)).astype(np.int64)
    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
    TN, FP, FN, TP = int(cm[0, 0]), int(cm[0, 1]), int(cm[1, 0]), int(cm[1, 1])
    return {"TN": TN, "FP": FP, "FN": FN, "TP": TP}

def compute_confusion_by_group(y_true, y_prob, groups, thr=0.5):
    y_true = np.asarray(y_true, dtype=np.int64)
    y_prob = np.asarray(y_prob, dtype=np.float64)
    groups = np.asarray(list(groups), dtype=object)
    out = {}
    for g in sorted(set(groups.tolist())):
        mask = (groups == g)
        if int(mask.sum()) == 0:
            continue
        out[g] = {
            "n": int(mask.sum()),
            "confusion": compute_confusion_counts(y_true[mask], y_prob[mask], thr=thr),
        }
    return out

# -------------------------
# Seeding utilities
# Inputs: seed
# Output: stable randomness across numpy/torch
# -------------------------
def set_all_seeds(seed: int):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# -------------------------
# Select most recent trainval experiment with all seed checkpoints
# Inputs: trainval_runs folder
# Output: chosen_exp (source of best_heads.pt)
# -------------------------
TRAINVAL_ROOT = Path(DX_OUT_ROOT) / "trainval_runs"
if not TRAINVAL_ROOT.exists():
    raise FileNotFoundError(f"Missing trainval_runs folder: {str(TRAINVAL_ROOT)}")

exp_dirs = sorted([p for p in TRAINVAL_ROOT.glob("exp_*") if p.is_dir()], key=lambda p: p.stat().st_mtime, reverse=True)
if not exp_dirs:
    raise FileNotFoundError(f"No exp_* folders found under: {str(TRAINVAL_ROOT)}")

def _has_all_seeds(exp_path: Path, dataset_id: str, seeds: list):
    for s in seeds:
        p = exp_path / f"run_{dataset_id}_seed{s}" / "best_heads.pt"
        if not p.exists():
            return False
    return True

chosen_exp = None
for ed in exp_dirs:
    if _has_all_seeds(ed, dataset_id, SEEDS):
        chosen_exp = ed
        break

if chosen_exp is None:
    sample = exp_dirs[0]
    raise FileNotFoundError(
        "Could not find a recent trainval experiment with all 3 best_heads.pt files.\n"
        f"Expected under: {str(TRAINVAL_ROOT)}/exp_*/run_{dataset_id}_seedXXXX/best_heads.pt\n"
        f"Most recent exp checked: {str(sample)}"
    )

print("\nUsing Train+Val experiment folder:")
print(" ", str(chosen_exp))

# Guard B: re-check the three checkpoint files immediately after selection
for s in SEEDS:
    p = chosen_exp / f"run_{dataset_id}_seed{s}" / "best_heads.pt"
    if not p.exists():
        raise RuntimeError(f"Trainval artifact missing after choosing exp. Missing: {str(p)}")

# -------------------------
# Output root for this evaluation
# Inputs: DX_OUT_ROOT
# Output: per-seed folders and summary files under monolingual_test_runs/
# -------------------------
TEST_ROOT = Path(DX_OUT_ROOT) / "monolingual_test_runs"
TEST_ROOT.mkdir(parents=True, exist_ok=True)

# -------------------------
# DataLoaders (VAL for threshold, TEST for final metrics)
# Inputs: val_df and test_df
# Output: val_loader and test_loader
# -------------------------
val_loader = DataLoader(
    AudioManifestDataset(val_df),
    batch_size=PER_DEVICE_BS,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=PIN_MEMORY,
    collate_fn=collate_fn,
)

test_loader = DataLoader(
    AudioManifestDataset(test_df),
    batch_size=PER_DEVICE_BS,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=PIN_MEMORY,
    collate_fn=collate_fn,
)

# -------------------------
# Warm-up reads (quick sanity check that loaders work)
# Inputs: loaders
# Output: loads a few batches to surface issues early
# -------------------------
def _warmup(loader, name):
    print(f"\nWarm-up: loading up to 3 {name} batches...")
    t0 = time.time()
    num_batches = len(loader)
    warmup_batches = min(3, num_batches)
    if warmup_batches == 0:
        raise RuntimeError(f"{name} DataLoader has 0 batches. Check df length and PER_DEVICE_BS.")
    it = iter(loader)
    for i in range(warmup_batches):
        _ = next(it)
        print(f"  loaded warmup batch {i+1}/{warmup_batches}")
    print(f"Warm-up done in {time.time()-t0:.2f}s")

_warmup(val_loader, "VAL")
_warmup(test_loader, "TEST")

# -------------------------
# One-seed evaluation: pick VAL threshold, then score TEST at that threshold
# Inputs: seed, best_heads.pt, loaders
# Output: per-seed metrics.json and PNG plots under run_<DATASET>_seedXXXX/
# -------------------------
def load_heads_into_model(model: Wav2Vec2TwoHeadClassifier, best_heads_path: Path):
    if not best_heads_path.exists():
        raise FileNotFoundError(f"Missing best_heads.pt: {str(best_heads_path)}")
    state = torch.load(str(best_heads_path), map_location="cpu")
    model.pre_vowel.load_state_dict(state["pre_vowel"], strict=True)
    model.pre_other.load_state_dict(state["pre_other"], strict=True)
    model.head_vowel.load_state_dict(state["head_vowel"], strict=True)
    model.head_other.load_state_dict(state["head_other"], strict=True)
    return model

def _infer_probs(loader, model, seed, stage_name):
    use_amp = bool(USE_AMP and DEVICE.type == "cuda")
    all_probs, all_true, all_sex = [], [], []
    pbar = tqdm(loader, desc=f"[seed={seed}] {stage_name}", dynamic_ncols=True)
    with torch.inference_mode():
        for batch in pbar:
            input_values = batch["input_values"].to(DEVICE, non_blocking=False)
            attention_mask = batch["attention_mask"].to(DEVICE, non_blocking=False)
            labels = batch["labels"].to(DEVICE, non_blocking=False)
            task_group = batch["task_group"]
            sex_norm = batch["sex_norm"]

            # Use the active device type so CPU-only runs do not request CUDA autocast.
            with torch.amp.autocast(device_type=DEVICE.type, enabled=use_amp):
                logits = model.forward_logits(input_values, attention_mask, task_group)

            probs = torch.softmax(logits.float(), dim=-1)[:, 1].detach().cpu().numpy().astype(np.float64)
            all_probs.extend(probs.tolist())
            all_true.extend(labels.detach().cpu().numpy().astype(np.int64).tolist())
            all_sex.extend(list(sex_norm))
    return np.asarray(all_true, dtype=np.int64), np.asarray(all_probs, dtype=np.float64), np.asarray(all_sex, dtype=object)

def run_test_once(seed: int):
    set_all_seeds(seed)

    run_dir = TEST_ROOT / f"run_{dataset_id}_seed{seed}"
    run_dir.mkdir(parents=True, exist_ok=True)

    best_heads_path = chosen_exp / f"run_{dataset_id}_seed{seed}" / "best_heads.pt"

    print(f"\n[seed={seed}] Loading model + heads from:")
    print(" ", str(best_heads_path))

    model = Wav2Vec2TwoHeadClassifier(BACKBONE_CKPT, dropout_p=DROPOUT_P).to(DEVICE)
    model = load_heads_into_model(model, best_heads_path)
    model.eval()

    # VAL: compute per-seed threshold (Youden J)
    yv_true, yv_prob, _ = _infer_probs(val_loader, model, seed, "Val (for threshold)")
    val_thr_opt, val_j_opt, val_fpr_opt, val_tpr_opt = compute_val_opt_threshold_youden_j(yv_true, yv_prob)

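    if np.isnan(val_thr_opt):
        # A NaN threshold makes every `y_prob >= thr` comparison False, so all TEST
        # predictions below would be negative. Flag this instead of passing silently.
        warnings.warn(
            f"[seed={seed}] VAL-OPT threshold is NaN (VAL labels may be single-class); "
            "threshold-based TEST metrics will treat every clip as negative."
        )
        print(f"[seed={seed}] VAL-OPT threshold (Youden J): nan")
    else:
        print(f"[seed={seed}] VAL-OPT threshold (Youden J): {val_thr_opt:.6f}")
        print(f"[seed={seed}] Youden J stats on VAL: J={val_j_opt:.6f} | TPR={val_tpr_opt:.6f} | FPR={val_fpr_opt:.6f}")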

    # TEST: metrics at the VAL-opt threshold for this seed
    yt_true, yt_prob, yt_sex = _infer_probs(test_loader, model, seed, "Test")

    test_auc = compute_auc(yt_true, yt_prob)
    thr_use = float(val_thr_opt)
    thr_metrics = compute_threshold_metrics(yt_true, yt_prob, thr=thr_use)

    fnr_by_sex, delta_f_minus_m, delta_abs = compute_fnr_by_group_signed(
        yt_true, yt_prob, yt_sex, thr=thr_use
    )

    confusion_by_sex = compute_confusion_by_group(yt_true, yt_prob, yt_sex, thr=thr_use)

    # Plots: overall ROC + overall confusion
    roc_png = run_dir / "roc_curve.png"
    cm_png  = run_dir / "confusion_matrix.png"
    save_roc_curve_png(yt_true, yt_prob, str(roc_png), title_suffix=f"Test (seed={seed})")
    save_confusion_png(yt_true, yt_prob, str(cm_png), thr=thr_use, title_suffix=f"Test (seed={seed})")

    # Plots: sex-specific confusion (if that group exists)
    cm_m_png = None
    cm_f_png = None
    mask_m = (yt_sex == "M")
    mask_f = (yt_sex == "F")

    if int(mask_m.sum()) > 0:
        cm_m_png = run_dir / "confusion_matrix_M.png"
        save_confusion_png(
            yt_true[mask_m], yt_prob[mask_m], str(cm_m_png),
            thr=thr_use, title_suffix=f"Test SEX=M (seed={seed})"
        )

    if int(mask_f.sum()) > 0:
        cm_f_png = run_dir / "confusion_matrix_F.png"
        save_confusion_png(
            yt_true[mask_f], yt_prob[mask_f], str(cm_f_png),
            thr=thr_use, title_suffix=f"Test SEX=F (seed={seed})"
        )

    metrics = {
        "dataset": dataset_id,
        "seed": int(seed),

        "n_val": int(len(val_df)),
        "label_counts_val": val_df["label_num"].value_counts(dropna=False).to_dict(),
        "sex_counts_val_norm": val_df["sex_norm"].value_counts(dropna=False).to_dict(),

        "n_test": int(len(test_df)),
        "label_counts_test": test_df["label_num"].value_counts(dropna=False).to_dict(),
        "sex_counts_test_norm": test_df["sex_norm"].value_counts(dropna=False).to_dict(),

        "val_opt_threshold": float(val_thr_opt),
        "val_opt_threshold_method": "Youden J (maximize TPR - FPR) on VAL ROC",
        "val_youden_j": float(val_j_opt),
        "val_youden_tpr": float(val_tpr_opt),
        "val_youden_fpr": float(val_fpr_opt),

        "test_auroc": float(test_auc),
        "threshold_metrics_test_at_val_opt": thr_metrics,

        "fairness_test_at_val_opt": {
            "definition": "H3: ΔFNR = FNR(F) - FNR(M), where FNR(sex)=FN/(FN+TP) computed on PD-only true labels at val_opt_threshold.",
            "fnr_by_sex_norm": fnr_by_sex,
            "delta_fnr_F_minus_M": float(delta_f_minus_m),
            "delta_fnr_abs": float(delta_abs),
            "note": "If n_PD for a sex is 0, its FNR is NaN and ΔFNR is NaN.",
            "sex_normalization_note": "sex_norm in {M,F,UNK}. Values not mapped to M/F counted as UNK.",
        },

        "confusion_by_sex_norm_at_val_opt": confusion_by_sex,

        "artifacts": {
            "roc_curve_png": str(roc_png),
            "confusion_matrix_png": str(cm_png),
            "confusion_matrix_M_png": str(cm_m_png) if cm_m_png is not None else None,
            "confusion_matrix_F_png": str(cm_f_png) if cm_f_png is not None else None,
        },

        "dx_out_root": DX_OUT_ROOT,
        "trainval_experiment_used": str(chosen_exp),
        "best_heads_path": str(best_heads_path),
        "backbone_ckpt": BACKBONE_CKPT,
        "dropout_p": float(DROPOUT_P),
        "timestamp": time.strftime("%Y%m%d_%H%M%S"),
    }

    metrics_path = run_dir / "metrics.json"
    with open(metrics_path, "w", encoding="utf-8") as f:
        json.dump(metrics, f, indent=2)

    print(f"[seed={seed}] DONE | test_AUROC={test_auc:.6f}")

    def _fmt_fnr(g):
        d = fnr_by_sex.get(g, None)
        if d is None:
            return "n/a"
        return f"fnr={d['fnr']:.6f} (n_PD={d['n_pos']}, fn={d['fn']}, tp={d['tp']})"

    print(f"[seed={seed}] FAIRNESS (H3) on TEST @ VAL-OPT thr={thr_use:.6f}:")
    print("  M:", _fmt_fnr("M"))
    print("  F:", _fmt_fnr("F"))
    if "UNK" in fnr_by_sex:
        print("  UNK:", _fmt_fnr("UNK"))
    print("  ΔFNR (F-M):", f"{delta_f_minus_m:.6f}" if not np.isnan(delta_f_minus_m) else "nan")
    print("  |ΔFNR|:", f"{delta_abs:.6f}" if not np.isnan(delta_abs) else "nan")

    print(f"[seed={seed}] WROTE:")
    print(" ", str(metrics_path))
    print(" ", str(roc_png))
    print(" ", str(cm_png))
    if cm_m_png is not None:
        print(" ", str(cm_m_png))
    if cm_f_png is not None:
        print(" ", str(cm_f_png))

    return {
        "seed": int(seed),
        "val_opt_threshold": float(val_thr_opt),
        "test_auroc": float(test_auc),
        "threshold_metrics": thr_metrics,
        "fnr_by_sex": fnr_by_sex,
        "delta_signed": float(delta_f_minus_m),
        "delta_abs": float(delta_abs),
        "run_dir": str(run_dir),
    }

# -------------------------
# Cross-seed summary: AUROC CI, threshold stats, metrics mean±SD, fairness mean±SD
# Inputs: 3 per-seed results
# Output: summary_test.json and history_index.jsonl under monolingual_test_runs/
# -------------------------
results = []
for seed in SEEDS:
    results.append(run_test_once(seed))

aurocs = [r["test_auroc"] for r in results]
val_thrs = [r["val_opt_threshold"] for r in results]
thr_list = [r["threshold_metrics"] for r in results]
run_dirs = [r["run_dir"] for r in results]

t_crit = 4.302652729911275  # two-sided 95% t critical value for df=2; only valid when n=3 seeds
n = len(aurocs)
mean_auc = float(np.mean(aurocs))
std_auc = float(np.std(aurocs, ddof=1)) if n > 1 else 0.0
half_width = float(t_crit * (std_auc / math.sqrt(n))) if n > 1 else 0.0
ci95 = [float(mean_auc - half_width), float(mean_auc + half_width)]

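def _mean_sd(vals):
    # Mean and sample SD (ddof=1), ignoring NaNs; SD is 0.0 when fewer than two non-NaN values exist.
    vals = np.asarray(vals, dtype=np.float64)
    n_valid = int(np.sum(~np.isnan(vals)))
    sd = float(np.nanstd(vals, ddof=1)) if n_valid > 1 else 0.0
    return float(np.nanmean(vals)), sd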

keys = ["accuracy","precision","recall","f1_score","sensitivity","specificity","mcc","p_value_fisher_two_sided"]
agg = {}
for k in keys:
    v = [float(tm.get(k, float("nan"))) for tm in thr_list]
    mu, sd = _mean_sd(v)
    agg[k] = {
        "mean": mu,
        "sd": sd,
        "values_by_seed": {str(r["seed"]): float(r["threshold_metrics"].get(k, float("nan"))) for r in results},
    }
cm_by_seed = {str(r["seed"]): r["threshold_metrics"]["confusion_matrix"] for r in results}

fnr_by_seed = {str(r["seed"]): r["fnr_by_sex"] for r in results}
delta_signed_by_seed = {str(r["seed"]): float(r["delta_signed"]) for r in results}
delta_abs_by_seed = {str(r["seed"]): float(r["delta_abs"]) for r in results}

fnr_m_vals, fnr_f_vals, n_pd_m_vals, n_pd_f_vals = [], [], [], []
d_signed_vals, d_abs_vals = [], []
for r in results:
    d = r["fnr_by_sex"]
    fnr_m_vals.append(float(d.get("M", {}).get("fnr", float("nan"))))
    fnr_f_vals.append(float(d.get("F", {}).get("fnr", float("nan"))))
    n_pd_m_vals.append(float(d.get("M", {}).get("n_pos", float("nan"))))
    n_pd_f_vals.append(float(d.get("F", {}).get("n_pos", float("nan"))))
    d_signed_vals.append(float(r["delta_signed"]))
    d_abs_vals.append(float(r["delta_abs"]))

fnr_m_mean, fnr_m_sd = _mean_sd(fnr_m_vals)
fnr_f_mean, fnr_f_sd = _mean_sd(fnr_f_vals)
d_signed_mean, d_signed_sd = _mean_sd(d_signed_vals)
d_abs_mean, d_abs_sd = _mean_sd(d_abs_vals)

thr_mean, thr_sd = _mean_sd(val_thrs)

print("\nTest AUROC by seed:")
for r in results:
    print(f"  seed {r['seed']}: {r['test_auroc']:.6f}")

print(f"\nMean Test AUROC: {mean_auc:.6f}")
print(f"95% CI (t, n=3): [{ci95[0]:.6f}, {ci95[1]:.6f}]")

print("\nVAL-OPT thresholds (Youden J) by seed:")
for r in results:
    v = r["val_opt_threshold"]
    print(f"  seed {r['seed']}: {v:.6f}" if not np.isnan(v) else f"  seed {r['seed']}: nan")
print(f"  mean ± SD: {thr_mean:.6f} ± {thr_sd:.6f}" if not np.isnan(thr_mean) else "  mean ± SD: nan")

print("\nThreshold metrics on TEST @ VAL-OPT threshold (mean ± SD across seeds):")
for k in ["accuracy","precision","sensitivity","specificity","f1_score","mcc"]:
    mu = agg[k]["mean"]
    sd = agg[k]["sd"]
    print(f"  {k}: {mu:.6f} ± {sd:.6f}")
print("  fisher_p_value_two_sided:", f"{agg['p_value_fisher_two_sided']['mean']:.6g} ± {agg['p_value_fisher_two_sided']['sd']:.6g}")

print("\nFAIRNESS (H3) on TEST @ VAL-OPT threshold across seeds (mean ± SD):")
print(f"  FNR_M: {fnr_m_mean:.6f} ± {fnr_m_sd:.6f}")
print(f"  FNR_F: {fnr_f_mean:.6f} ± {fnr_f_sd:.6f}")
print(f"  ΔFNR (F-M): {d_signed_mean:.6f} ± {d_signed_sd:.6f}")
print(f"  |ΔFNR|: {d_abs_mean:.6f} ± {d_abs_sd:.6f}")
print("  Per-seed:", {
    str(r["seed"]): {
        "val_opt_threshold": r["val_opt_threshold"],
        "FNR_M": fnr_m_vals[i],
        "n_PD_M": n_pd_m_vals[i],
        "FNR_F": fnr_f_vals[i],
        "n_PD_F": n_pd_f_vals[i],
        "delta_F_minus_M": d_signed_vals[i],
        "abs_delta": d_abs_vals[i],
    } for i, r in enumerate(results)
})

summary = {
    "dataset": dataset_id,
    "dx_out_root": DX_OUT_ROOT,
    "manifest_all": MANIFEST_ALL,
    "trainval_experiment_used": str(chosen_exp),
    "seeds": SEEDS,

    "n_val": int(len(val_df)),
    "label_counts_val": val_df["label_num"].value_counts(dropna=False).to_dict(),
    "sex_counts_val_norm": val_df["sex_norm"].value_counts(dropna=False).to_dict(),

    "n_test": int(len(test_df)),
    "label_counts_test": test_df["label_num"].value_counts(dropna=False).to_dict(),
    "sex_counts_test_norm": test_df["sex_norm"].value_counts(dropna=False).to_dict(),

    "val_opt_thresholds_by_seed": {str(r["seed"]): float(r["val_opt_threshold"]) for r in results},
    "val_opt_threshold_mean_sd": {"mean": float(thr_mean), "sd": float(thr_sd)},

    "test_aurocs_by_seed": {str(r["seed"]): float(r["test_auroc"]) for r in results},
    "mean_test_auroc": float(mean_auc),
    "t_crit_df2_95": float(t_crit),
    "half_width_95_ci": float(half_width),
    "ci95_test_auroc": ci95,

    "threshold_metrics_test_at_val_opt_mean_sd": agg,
    "confusion_matrix_by_seed_at_val_opt": cm_by_seed,
    "run_dirs": run_dirs,

    "fairness_test_at_val_opt": {
        "definition": "H3: ΔFNR = FNR(F) - FNR(M), where FNR(sex)=FN/(FN+TP) computed on PD-only true labels at val_opt_threshold.",
        "fnr_by_sex_norm_by_seed": fnr_by_seed,
        "delta_fnr_F_minus_M_by_seed": delta_signed_by_seed,
        "delta_fnr_abs_by_seed": delta_abs_by_seed,
        "fnr_M_mean_sd": {"mean": float(fnr_m_mean), "sd": float(fnr_m_sd), "values_by_seed": {str(r['seed']): float(fnr_m_vals[i]) for i, r in enumerate(results)}},
        "fnr_F_mean_sd": {"mean": float(fnr_f_mean), "sd": float(fnr_f_sd), "values_by_seed": {str(r['seed']): float(fnr_f_vals[i]) for i, r in enumerate(results)}},
        "delta_fnr_F_minus_M_mean_sd": {"mean": float(d_signed_mean), "sd": float(d_signed_sd), "values_by_seed": {str(r['seed']): float(d_signed_vals[i]) for i, r in enumerate(results)}},
        "delta_fnr_abs_mean_sd": {"mean": float(d_abs_mean), "sd": float(d_abs_sd), "values_by_seed": {str(r['seed']): float(d_abs_vals[i]) for i, r in enumerate(results)}},
        "denominators_PD_by_seed": {str(r["seed"]): {"n_PD_M": float(n_pd_m_vals[i]), "n_PD_F": float(n_pd_f_vals[i])} for i, r in enumerate(results)},
        "sex_normalization_note": "sex_norm in {M,F,UNK}. Values not mapped to M/F counted as UNK. ΔFNR computed only when both M and F have defined FNR.",
    },

    "timestamp": time.strftime("%Y%m%d_%H%M%S"),
}

summary_path = TEST_ROOT / "summary_test.json"
with open(summary_path, "w", encoding="utf-8") as f:
    json.dump(summary, f, indent=2)

history_path = TEST_ROOT / "history_index.jsonl"
with open(history_path, "a", encoding="utf-8") as f:
    f.write(json.dumps(summary) + "\n")

print("\nWROTE summary:", str(summary_path))
print("APPENDED history index:", str(history_path))
print("Open this folder to access artifacts:", str(TEST_ROOT))

# -------------------------
# Stop runtime (Colab)
# Inputs: none
# Output: attempts to release the GPU instance
# -------------------------
print("\nAll done. Unassigning the runtime to stop the L4 instance...")
try:
    from google.colab import runtime  # type: ignore
    print("Calling runtime.unassign() now.")
    runtime.unassign()
except Exception as e:
    print("Could not unassign runtime automatically. You can stop the runtime manually in Colab.")
    print("Reason:", repr(e))

# Cross-Language Zero-Shot Transfer Tests

The following cell runs a cross-language zero-shot test from Spanish (D1) to Slovak (D2). It measures how well the classifier heads trained on the Spanish dataset (D1, NeuroVoz) perform on the Slovak dataset (D2, EWA-DB) without any training, fine-tuning, or calibration on D2. The evaluation uses only the D2 test split from `D2/manifests/manifest_all.csv`. The cell automatically finds the most recent completed D1 train-and-validation experiment that contains saved head checkpoints for all three random seeds (1337, 2024, 7777) and applies those D1-trained heads directly to D2 audio.

A single fixed decision threshold is used for all three seeds. This threshold is read from D1’s saved `monolingual_test_runs/summary_test.json` and is the mean of the validation-selected thresholds from D1. Because the summary file structure can differ across versions, the cell includes a schema-tolerant reader that searches several known key paths for the threshold and, if needed, falls back to averaging the per-seed threshold values. If no finite threshold can be found, the cell stops with a clear error message.
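
As a rough illustration of that lookup order, the sketch below first tries a list of nested key paths for a stored mean and, failing that, averages per-seed values. The helper names `dig` and `read_threshold` and the toy numbers are assumptions made for this example only; the reader in the cell checks more key paths than are shown here.

```python
import math

def dig(d, path):
    """Follow a list of keys through nested dicts; return None if any key is missing."""
    cur = d
    for key in path:
        if not isinstance(cur, dict) or key not in cur:
            return None
        cur = cur[key]
    return cur

def read_threshold(summary, mean_paths, by_seed_paths):
    """Hypothetical helper: prefer a stored mean, else average finite per-seed values."""
    for path in mean_paths:
        v = dig(summary, path)
        if isinstance(v, (int, float)) and math.isfinite(v):
            return float(v)
    for path in by_seed_paths:
        d = dig(summary, path)
        if isinstance(d, dict):
            vals = [float(x) for x in d.values()
                    if isinstance(x, (int, float)) and math.isfinite(x)]
            if vals:
                return sum(vals) / len(vals)
    raise ValueError("no finite threshold found in summary")

# Toy summary that only stores per-seed thresholds (values are made up).
toy = {"val_opt_thresholds_by_seed": {"1337": 0.40, "2024": 0.50, "7777": 0.60}}
thr = read_threshold(
    toy,
    mean_paths=[["val_opt_threshold_mean_sd", "mean"]],
    by_seed_paths=[["val_opt_thresholds_by_seed"]],
)
print(round(thr, 6))
```

Raising when nothing is found, rather than returning a default such as 0.5, keeps a schema mismatch from silently changing the operating point used on D2.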

Before running inference, the cell performs several safety checks to avoid silent errors. It mounts Google Drive if needed, confirms that local files are not overriding PyTorch or Transformers, prints key paths and the active compute device, verifies that the D2 manifest contains all required columns, confirms the dataset label is truly D2, filters strictly to the test split, and checks that all referenced audio files exist on disk. Each clip is assigned a task group, where `task == "vowl"` is treated as vowel and everything else is treated as other. Sex values in D2, such as male or female and common variants, are normalized to `sex_norm` values of M, F, or UNK so reporting is consistent.
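
The task grouping and sex normalization can be pictured with a few made-up manifest rows. This is a simplified sketch: the function names, the task name `reading`, and the rows are assumptions for illustration, and the version in the cell also handles missing values stored as NaN and numeric 0/1 encodings.

```python
def task_group(task):
    """Route a clip to the vowel head only when task is exactly 'vowl'."""
    return "vowel" if str(task) == "vowl" else "other"

def normalize_sex(value):
    """Map common string spellings to 'M' or 'F'; anything else becomes 'UNK'."""
    if value is None:
        return "UNK"
    s = str(value).strip().lower()
    if s in {"m", "male"}:
        return "M"
    if s in {"f", "female"}:
        return "F"
    return "UNK"

# Made-up manifest rows (not taken from D2).
rows = [
    {"task": "vowl", "sex": "Male"},
    {"task": "reading", "sex": " female "},
    {"task": "VOWL", "sex": None},
]
labeled = [(task_group(r["task"]), normalize_sex(r["sex"])) for r in rows]
print(labeled)
```

Note that the task comparison is exact and case-sensitive, so a row labeled `VOWL` would be routed to the other-speech head, while sex values are compared after trimming and lowercasing.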

For each seed, the cell loads the matching D1 head checkpoint (`best_heads.pt`) into a frozen Wav2Vec2 two-head classifier, with one head for vowel clips and one for other speech. It then runs inference on all D2 test clips to produce Parkinson’s probabilities. From these scores it computes the overall AUROC, which does not depend on a threshold, and fixed-threshold metrics at the D1-derived threshold: the confusion matrix, accuracy, precision, sensitivity (recall), specificity, F1 score, MCC, and the Fisher exact test p-value. Fairness is computed as ΔFNR, defined as FNR(F) minus FNR(M), where each FNR is measured on Parkinson’s-positive clips only. Confusion counts by sex are also recorded. For each seed, the cell saves a ROC curve, an overall confusion matrix, additional sex-specific confusion matrices when both M and F are present, and a detailed `metrics.json` file under
`<D2_OUT_ROOT>/Cross_Language_Zero_Shot_Runs/run_ES_to_SK_seed<seed>/`.
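
To make the fairness definition concrete, here is a small worked example of ΔFNR at a fixed threshold. The clips, probabilities, and the threshold of 0.5 are made up for illustration, and the helper names are hypothetical; a clip is predicted PD when its probability is at or above the threshold.

```python
def fnr(y_true, y_prob, thr):
    """FNR = FN / (FN + TP) over PD-positive clips; NaN when there are none."""
    fn = sum(1 for t, p in zip(y_true, y_prob) if t == 1 and p < thr)
    tp = sum(1 for t, p in zip(y_true, y_prob) if t == 1 and p >= thr)
    return fn / (fn + tp) if (fn + tp) > 0 else float("nan")

# Made-up clips: (true label, PD probability, normalized sex).
clips = [
    (1, 0.90, "F"), (1, 0.70, "F"), (1, 0.30, "F"), (1, 0.20, "F"), (0, 0.10, "F"),
    (1, 0.80, "M"), (1, 0.60, "M"), (1, 0.55, "M"), (1, 0.40, "M"), (0, 0.70, "M"),
]
thr = 0.5  # stand-in for the fixed D1-derived threshold

def fnr_for(sex):
    """FNR restricted to clips from one sex group."""
    sub = [(t, p) for t, p, s in clips if s == sex]
    return fnr([t for t, _ in sub], [p for _, p in sub], thr)

fnr_f, fnr_m = fnr_for("F"), fnr_for("M")
delta = fnr_f - fnr_m  # NaN if either group has no PD-positive clips
print(fnr_f, fnr_m, delta)
```

In this toy case two of four female PD clips and one of four male PD clips are missed, so FNR(F) = 0.5, FNR(M) = 0.25, and ΔFNR = 0.25. Healthy clips do not enter the calculation, which is why the healthy male clip scored 0.70 has no effect.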

After all three seeds finish, the cell combines results across seeds. It reports the mean AUROC with a 95% confidence interval based on a t distribution (n = 3, so 2 degrees of freedom), along with the mean and standard deviation of the fixed-threshold metrics and fairness values. A combined summary is written to `summary_zero_shot.json`, and the same record is appended to `history_index.jsonl`, both under `<D2_OUT_ROOT>/Cross_Language_Zero_Shot_Runs/`. Finally, the Colab runtime is unassigned to shut down the GPU instance.
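
The interval arithmetic is short enough to sketch with the standard library. The AUROC values below are made up, and the function name is hypothetical; the critical value is the two-sided 95% t value for 2 degrees of freedom, so the sketch refuses any input that does not have exactly three values.

```python
import math
import statistics

T_CRIT_DF2_95 = 4.302652729911275  # two-sided 95% t critical value for df = 2 (n = 3)

def t_ci95_three_seeds(values):
    """Mean and 95% t-interval for exactly three per-seed values."""
    if len(values) != 3:
        raise ValueError("this critical value is only valid for n = 3")
    mean = statistics.mean(values)
    sd = statistics.stdev(values)  # sample SD (ddof = 1)
    half_width = T_CRIT_DF2_95 * sd / math.sqrt(3)
    return mean, (mean - half_width, mean + half_width)

aurocs = [0.70, 0.72, 0.74]  # made-up per-seed AUROCs
mean_auc, (lo, hi) = t_ci95_three_seeds(aurocs)
print(f"mean={mean_auc:.4f}  95% CI=[{lo:.4f}, {hi:.4f}]")
```

With only three seeds the critical value is large, so even a small spread across seeds gives a wide interval. The interval is not clipped, so it can extend outside the [0, 1] range of AUROC.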

In [None]:
# =========================
# ES → SK Zero-shot (D1 → D2): Test-only transfer with fixed threshold
# - Loads D1-trained heads (most recent train+val experiment)
# - Runs inference on D2 TEST only (from manifest_all.csv)
# - Uses one fixed threshold read from D1 monolingual summary_test.json (mean VAL-opt)
# - Saves per-seed metrics + plots, plus an aggregated summary under D2
# =========================

import os, json, math, random, time, warnings
from pathlib import Path

import numpy as np
import pandas as pd
import soundfile as sf

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

from transformers import Wav2Vec2Model

from sklearn.metrics import roc_auc_score, roc_curve, confusion_matrix, matthews_corrcoef
from scipy.stats import fisher_exact
import matplotlib.pyplot as plt

from tqdm.auto import tqdm

# -------------------------
# Import safety: prevent local files from overriding torch/transformers
# Inputs: /content directory
# Output: raises early if a conflicting file/folder exists
# -------------------------
if os.path.exists("/content/torch.py") or os.path.exists("/content/torch/__init__.py"):
    raise RuntimeError("Found local /content/torch.py or /content/torch/ that shadows PyTorch. Rename/remove it and restart runtime.")
if os.path.exists("/content/transformers.py") or os.path.exists("/content/transformers/__init__.py"):
    raise RuntimeError("Found local /content/transformers.py or /content/transformers/ that shadows Hugging Face Transformers. Rename/remove it and restart runtime.")

# -------------------------
# Drive mount (Colab)
# Inputs: Google Drive
# Output: /content/drive mounted for reading inputs and writing results
# -------------------------
if not os.path.isdir("/content/drive/MyDrive"):
    from google.colab import drive  # type: ignore
    drive.mount("/content/drive")

# -------------------------
# Paths: target (D2) + source (D1)
# Inputs: optional globals D1_OUT_ROOT, D2_OUT_ROOT (if set earlier), else fallbacks
# Output: D2_MANIFEST_ALL and D1_MONO_SUMMARY paths
# -------------------------
D1_OUT_ROOT_FALLBACK = "/content/drive/MyDrive/AI_PD_Project/Datasets/D1-NeuroVoz-Castillan Spanish/preprocessed_v1"
D2_OUT_ROOT_FALLBACK = "/content/drive/MyDrive/AI_PD_Project/Datasets/D2-Slovak (EWA-DB)/EWA-DB/preprocessed_v1"

D1_OUT_ROOT = globals().get("D1_OUT_ROOT", D1_OUT_ROOT_FALLBACK)
D2_OUT_ROOT = globals().get("D2_OUT_ROOT", D2_OUT_ROOT_FALLBACK)

D2_MANIFEST_ALL = f"{D2_OUT_ROOT}/manifests/manifest_all.csv"

# D1 monolingual summary (used to read the fixed transfer threshold)
D1_MONO_SUMMARY = Path(D1_OUT_ROOT) / "monolingual_test_runs" / "summary_test.json"

# -------------------------
# Run settings (kept consistent with other test cells)
# Inputs: constants below
# Output: printed configuration + runtime behavior (batching, AMP, device)
# -------------------------
TRANSFER_TAG    = "ES_to_SK"          # D1 -> D2
SEEDS           = [1337, 2024, 7777]
BACKBONE_CKPT   = "facebook/wav2vec2-base"
SR_EXPECTED     = 16000
TINY_THRESH     = 1e-4

EFFECTIVE_BS    = 64
PER_DEVICE_BS   = 16
GRAD_ACCUM      = max(1, EFFECTIVE_BS // PER_DEVICE_BS)  # printed only

DROPOUT_P       = 0.2

NUM_WORKERS     = 0
PIN_MEMORY      = False

USE_AMP         = True

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
try:
    torch.set_float32_matmul_precision("high")
except Exception:
    pass

warnings.filterwarnings("ignore", message="None of the inputs have requires_grad=True. Gradients will be None")
warnings.filterwarnings("ignore", message="Passing `gradient_checkpointing` to a config initialization is deprecated")

print("D2_OUT_ROOT (target):", D2_OUT_ROOT)
print("D2_MANIFEST_ALL:", D2_MANIFEST_ALL)
print("D1_OUT_ROOT (source):", D1_OUT_ROOT)
print("D1_MONO_SUMMARY:", str(D1_MONO_SUMMARY))
print("DEVICE:", DEVICE)
if DEVICE.type == "cuda":
    print("GPU:", torch.cuda.get_device_name(0))
print("PER_DEVICE_BS:", PER_DEVICE_BS, "| EFFECTIVE_BS:", PER_DEVICE_BS * GRAD_ACCUM)
print("NUM_WORKERS:", NUM_WORKERS, "| PIN_MEMORY:", PIN_MEMORY)
print("USE_AMP:", bool(USE_AMP and DEVICE.type == "cuda"))

# -------------------------
# Fixed threshold: read from D1 monolingual summary_test.json (schema-robust)
# Inputs: D1_MONO_SUMMARY (JSON)
# Output: FIXED_THR (float) used for all three seeds on D2
# -------------------------
def _dig(d, path):
    cur = d
    for k in path:
        if not isinstance(cur, dict) or k not in cur:
            return None
        cur = cur[k]
    return cur

def load_fixed_threshold_from_summary(summary_dict: dict) -> float:
    """
    Returns a finite float threshold.
    Tries multiple key paths to stay compatible with older/newer summary schemas.

    Preferred:
      threshold_selection.val_selected_threshold_mean_sd.mean

    Alternatives:
      val_opt_threshold_mean_sd.mean
      val_selected_threshold_mean_sd.mean
      threshold_selection.val_opt_threshold_mean_sd.mean
      mean of val_opt_threshold_by_seed / val_selected_threshold_by_seed
    """
    candidate_paths = [
        ["threshold_selection", "val_selected_threshold_mean_sd", "mean"],
        ["val_selected_threshold_mean_sd", "mean"],
        ["threshold_selection", "val_opt_threshold_mean_sd", "mean"],
        ["val_opt_threshold_mean_sd", "mean"],
        ["val_optimal_threshold", "mean_sd", "mean"],
        ["val_optimal_threshold_mean_sd", "mean"],
        ["val_optimal_threshold", "mean"],
    ]

    for p in candidate_paths:
        v = _dig(summary_dict, p)
        if v is None:
            continue
        try:
            fv = float(v)
            if np.isfinite(fv):
                return fv
        except Exception:
            pass

    candidate_by_seed_paths = [
        ["threshold_selection", "val_selected_threshold_by_seed"],
        ["val_selected_threshold_by_seed"],
        ["threshold_selection", "val_opt_threshold_by_seed"],
        ["val_opt_thresholds_by_seed"],  # key written by the monolingual test cell's summary_test.json
        ["val_opt_threshold_by_seed"],
        ["val_optimal_threshold_by_seed"],
        ["val_opt_threshold", "by_seed"],
    ]
    for p in candidate_by_seed_paths:
        d = _dig(summary_dict, p)
        if isinstance(d, dict) and len(d) > 0:
            vals = []
            for _, vv in d.items():
                try:
                    vals.append(float(vv))
                except Exception:
                    continue
            if len(vals) > 0:
                fv = float(np.nanmean(np.asarray(vals, dtype=np.float64)))
                if np.isfinite(fv):
                    return fv

    return float("nan")

if not D1_MONO_SUMMARY.exists():
    raise FileNotFoundError(
        "Missing D1 monolingual summary_test.json.\n"
        f"Expected: {str(D1_MONO_SUMMARY)}\n"
        "Run the D1 monolingual test cell first (the one that writes summary_test.json)."
    )

with open(D1_MONO_SUMMARY, "r", encoding="utf-8") as f:
    d1_sum = json.load(f)

FIXED_THR = load_fixed_threshold_from_summary(d1_sum)

if not np.isfinite(FIXED_THR):
    raise RuntimeError(
        "Could not read a finite fixed threshold from D1 monolingual summary_test.json.\n"
        "Tried known key paths including:\n"
        "  threshold_selection.val_selected_threshold_mean_sd.mean (preferred)\n"
        "  val_opt_threshold_mean_sd.mean\n"
        "  and by-seed fallbacks.\n"
        f"File: {str(D1_MONO_SUMMARY)}\n"
        "Open that JSON and confirm where the mean VAL-opt threshold is stored."
    )

print(f"\nFixed transfer threshold (from D1 monolingual MEAN VAL-opt): {FIXED_THR:.6f}")

# -------------------------
# Target manifest: load D2 and keep TEST only
# Inputs: D2 manifest_all.csv
# Output: test_df with required columns + basic counts printed
# -------------------------
if not os.path.exists(D2_MANIFEST_ALL):
    raise FileNotFoundError(f"Missing D2 manifest_all.csv: {D2_MANIFEST_ALL}")

m_all = pd.read_csv(D2_MANIFEST_ALL)

# Require these fields for inference + fairness plots
req_cols = {"split", "clip_path", "label_num", "task", "sex"}
missing = [c for c in sorted(req_cols) if c not in m_all.columns]
if missing:
    raise ValueError(f"D2 manifest missing required columns: {missing}. Found: {list(m_all.columns)}")

# Ensure the manifest content is actually D2
if "dataset" in m_all.columns and m_all["dataset"].notna().any():
    dataset_id = str(m_all["dataset"].astype(str).value_counts(dropna=True).idxmax())
    m_all = m_all[m_all["dataset"].astype(str) == dataset_id].copy()
else:
    dataset_id = "DX"

if dataset_id != "D2":
    raise RuntimeError(
        f"Expected dataset_id=='D2' but got {dataset_id!r}. "
        "This usually means D2_OUT_ROOT is wrong or inherited from a previous cell.\n"
        f"D2_OUT_ROOT={D2_OUT_ROOT}"
    )

# Keep a stable column set (missing ones are filled with NaN)
keep_cols = ["clip_path", "label_num", "task", "speaker_id", "sex", "age", "duration_sec", "split"]
for c in keep_cols:
    if c not in m_all.columns:
        m_all[c] = np.nan
m_all = m_all[keep_cols].copy()

test_df = m_all[m_all["split"].isin(["test"])].copy().reset_index(drop=True)

print(f"\nTarget dataset inferred: {dataset_id}")
print(f"TEST rows: {len(test_df)}")
if len(test_df) == 0:
    raise RuntimeError("After filtering to split=='test', D2 manifest has 0 rows.")

print("TEST label counts:", test_df["label_num"].value_counts(dropna=False).to_dict())
print("TEST sex counts (raw):", test_df["sex"].value_counts(dropna=False).to_dict())

# -------------------------
# Path check: ensure target audio files exist
# Inputs: test_df.clip_path
# Output: raises early if any clip files are missing (shows a few examples)
# -------------------------
def _fail_fast_missing_paths(df: pd.DataFrame, name: str):
    missing_paths = []
    for p in tqdm(df["clip_path"].astype(str).tolist(), desc=f"Check {name} clip_path exists", dynamic_ncols=True):
        if not os.path.exists(p):
            missing_paths.append(p)
            if len(missing_paths) >= 10:
                break
    if missing_paths:
        raise FileNotFoundError(f"{name}: missing clip_path(s). Examples: {missing_paths}")

_fail_fast_missing_paths(test_df, "TEST")

# -------------------------
# Task grouping: map each row to vowel vs other
# Inputs: test_df.task
# Output: test_df.task_group used to pick the correct head during inference
# -------------------------
def _task_group(task_val) -> str:
    return "vowel" if str(task_val) == "vowl" else "other"

test_df["task_group"] = test_df["task"].apply(_task_group)

# -------------------------
# Sex normalization for fairness/charts (D2 uses male/female)
# Inputs: test_df.sex
# Output: test_df.sex_norm in {M,F,UNK}
# -------------------------
def normalize_sex_d2(val) -> str:
    """
    Returns 'M', 'F', or 'UNK'
    D2 expected encoding: 'male'/'female' (case-insensitive).
    Also handles M/F and 0/1 defensively.
    """
    if pd.isna(val):
        return "UNK"

    # Defensive numeric mapping (only used if sex is stored as 0/1)
    try:
        fv = float(val)
        if np.isfinite(fv) and abs(fv - round(fv)) < 1e-9:
            iv = int(round(fv))
            if iv == 0:
                return "F"
            if iv == 1:
                return "M"
    except Exception:
        pass

    s = str(val).strip().lower()
    if s in {"m", "male", "man", "masc", "masculine"}:
        return "M"
    if s in {"f", "female", "woman", "fem", "feminine"}:
        return "F"
    return "UNK"

test_df["sex_norm"] = test_df["sex"].apply(normalize_sex_d2)
print("TEST sex counts (normalized):", test_df["sex_norm"].value_counts(dropna=False).to_dict())
if (test_df["sex_norm"] == "UNK").any():
    print("NOTE: Some D2 'sex' values could not be normalized to M/F and were counted as 'UNK' for fairness and sex charts.")

# -------------------------
# Dataset + collator: read audio and build attention masks
# Inputs: test_df rows + audio files
# Output: DataLoader batches with padded input_values, attention_mask, labels, task_group, sex_norm
# -------------------------
class AudioManifestDataset(Dataset):
    def __init__(self, df: pd.DataFrame):
        self.df = df.reset_index(drop=True)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx: int):
        row = self.df.iloc[idx]
        clip_path = str(row["clip_path"])
        label = int(row["label_num"])
        task_group = str(row["task_group"])
        sex_norm = str(row["sex_norm"])

        if not os.path.exists(clip_path):
            raise FileNotFoundError(f"Missing clip_path: {clip_path}")

        y, sr = sf.read(clip_path, always_2d=False)
        if isinstance(y, np.ndarray) and y.ndim > 1:
            y = np.mean(y, axis=1)
        y = np.asarray(y, dtype=np.float32)

        if int(sr) != SR_EXPECTED:
            raise ValueError(f"Sample rate mismatch (expected {SR_EXPECTED}): {clip_path} has sr={sr}")

        # attention_mask mainly matters for vowel clips: mask the near-silent padded tail.
        # Samples after the last one with |y| > TINY_THRESH are set to 0; if no sample
        # exceeds the threshold, the mask stays all ones.
        attn = np.ones((len(y),), dtype=np.int64)
        if task_group == "vowel":
            loud = np.flatnonzero(np.abs(y).astype(np.float64) > TINY_THRESH)
            if loud.size > 0:
                attn[int(loud[-1]) + 1:] = 0

        return {
            "input_values": torch.from_numpy(y),
            "attention_mask": torch.from_numpy(attn),
            "labels": torch.tensor(label, dtype=torch.long),
            "task_group": task_group,
            "sex_norm": sex_norm,
        }

def collate_fn(batch):
    max_len = int(max(b["input_values"].numel() for b in batch))
    input_vals, attn_masks, labels, task_groups, sex_norms = [], [], [], [], []
    for b in batch:
        x = b["input_values"]
        a = b["attention_mask"]
        pad = max_len - x.numel()
        if pad > 0:
            x = torch.cat([x, torch.zeros(pad, dtype=x.dtype)], dim=0)
            a = torch.cat([a, torch.zeros(pad, dtype=a.dtype)], dim=0)
        input_vals.append(x)
        attn_masks.append(a)
        labels.append(b["labels"])
        task_groups.append(b["task_group"])
        sex_norms.append(b["sex_norm"])
    return {
        "input_values": torch.stack(input_vals, dim=0),
        "attention_mask": torch.stack(attn_masks, dim=0),
        "labels": torch.stack(labels, dim=0),
        "task_group": task_groups,
        "sex_norm": sex_norms,
    }

# -------------------------
# Model: frozen wav2vec2 backbone + two heads (vowel vs other)
# Inputs: backbone checkpoint name + dropout_p
# Output: logits for PD vs healthy for each clip, using the correct head per task_group
# -------------------------
class Wav2Vec2TwoHeadClassifier(nn.Module):
    def __init__(self, ckpt: str, dropout_p: float):
        super().__init__()
        self.backbone = Wav2Vec2Model.from_pretrained(ckpt, use_safetensors=True)
        for p in self.backbone.parameters():
            p.requires_grad = False

        H = int(self.backbone.config.hidden_size)
        self.pre_vowel = nn.Sequential(nn.LayerNorm(H), nn.Dropout(float(dropout_p)))
        self.pre_other = nn.Sequential(nn.LayerNorm(H), nn.Dropout(float(dropout_p)))
        self.head_vowel = nn.Linear(H, 2)
        self.head_other = nn.Linear(H, 2)

    def masked_mean_pool(self, last_hidden, attn_mask_samples):
        feat_mask = self.backbone._get_feature_vector_attention_mask(last_hidden.shape[1], attn_mask_samples)
        feat_mask = feat_mask.to(last_hidden.device).unsqueeze(-1).type_as(last_hidden)
        masked = last_hidden * feat_mask
        denom = feat_mask.sum(dim=1).clamp(min=1.0)
        return masked.sum(dim=1) / denom

    def _heads_fp32(self, x_fp_any: torch.Tensor, head: nn.Module) -> torch.Tensor:
        # Keep head math in fp32 even when AMP is enabled
        x = x_fp_any.float()
        if x.is_cuda:
            with torch.amp.autocast(device_type="cuda", enabled=False):
                return head(x)
        return head(x)

    def forward_logits(self, input_values, attention_mask, task_group):
        with torch.no_grad():
            out = self.backbone(input_values=input_values, attention_mask=attention_mask)
            last_hidden = out.last_hidden_state

        pooled = self.masked_mean_pool(last_hidden, attention_mask)  # [B,H]

        z_v = self.pre_vowel(pooled.float())
        z_o = self.pre_other(pooled.float())

        logits_v = self._heads_fp32(z_v, self.head_vowel)
        logits_o = self._heads_fp32(z_o, self.head_other)

        # Select the correct head per row based on task_group
        is_vowel = torch.tensor([tg == "vowel" for tg in task_group], device=pooled.device, dtype=torch.bool)
        logits = logits_o.clone()
        if is_vowel.any():
            logits[is_vowel] = logits_v[is_vowel]
        return logits

# -------------------------
# Metrics + plot helpers
# Inputs: y_true and y_prob (PD probability), plus a threshold for thresholded metrics
# Output: numeric metrics dict + PNG charts saved to disk
# -------------------------
def compute_auc(y_true, y_prob):
    y_true = np.asarray(y_true, dtype=np.int64)
    y_prob = np.asarray(y_prob, dtype=np.float64)
    if len(np.unique(y_true)) < 2:
        return float("nan")
    return float(roc_auc_score(y_true, y_prob))

def compute_threshold_metrics(y_true, y_prob, thr=0.5):
    y_true = np.asarray(y_true, dtype=np.int64)
    y_prob = np.asarray(y_prob, dtype=np.float64)
    y_pred = (y_prob >= float(thr)).astype(np.int64)

    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
    TN, FP, FN, TP = int(cm[0, 0]), int(cm[0, 1]), int(cm[1, 0]), int(cm[1, 1])

    eps = 1e-12
    acc = (TP + TN) / max(1, (TP + TN + FP + FN))
    prec = TP / (TP + FP + eps)
    rec = TP / (TP + FN + eps)     # sensitivity
    f1 = 2 * prec * rec / (prec + rec + eps)
    spec = TN / (TN + FP + eps)

    try:
        mcc = float(matthews_corrcoef(y_true, y_pred)) if len(np.unique(y_pred)) > 1 else float("nan")
    except Exception:
        mcc = float("nan")

    try:
        _, pval = fisher_exact([[TN, FP], [FN, TP]], alternative="two-sided")
        pval = float(pval)
    except Exception:
        pval = float("nan")

    return {
        "threshold": float(thr),
        "confusion_matrix": {"TN": TN, "FP": FP, "FN": FN, "TP": TP},
        "accuracy": float(acc),
        "precision": float(prec),
        "recall": float(rec),
        "f1_score": float(f1),
        "sensitivity": float(rec),
        "specificity": float(spec),
        "mcc": float(mcc),
        "p_value_fisher_two_sided": float(pval),
    }

def save_roc_curve_png(y_true, y_prob, out_png, title_suffix="Test"):
    fpr, tpr, _ = roc_curve(y_true, y_prob)
    plt.figure()
    plt.plot(fpr, tpr)
    plt.plot([0, 1], [0, 1], linestyle="--")
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    plt.title(f"ROC Curve ({title_suffix})")
    plt.tight_layout()
    plt.savefig(out_png, dpi=150)
    plt.close()

def save_confusion_png(y_true, y_prob, out_png, thr=0.5, title_suffix="Test"):
    y_true = np.asarray(y_true, dtype=np.int64)
    y_prob = np.asarray(y_prob, dtype=np.float64)
    y_pred = (y_prob >= float(thr)).astype(int)
    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])

    plt.figure()
    plt.imshow(cm)
    plt.xticks([0, 1], ["Healthy(0)", "PD(1)"])
    plt.yticks([0, 1], ["Healthy(0)", "PD(1)"])
    plt.xlabel("Predicted")
    plt.ylabel("True")
    plt.title(f"Confusion Matrix ({title_suffix}, thr={thr:.2f})")
    for (i, j), v in np.ndenumerate(cm):
        plt.text(j, i, str(v), ha="center", va="center")
    plt.tight_layout()
    plt.savefig(out_png, dpi=150)
    plt.close()

# -------------------------
# Fairness: FNR per sex and ΔFNR = FNR(F) - FNR(M)
# Inputs: y_true, y_prob, sex_norm array, threshold
# Output: per-group FNR details + signed and absolute gap
# -------------------------
def compute_fnr_by_group_signed(y_true, y_prob, groups, thr=0.5):
    y_true = np.asarray(y_true, dtype=np.int64)
    y_prob = np.asarray(y_prob, dtype=np.float64)
    groups = np.asarray(list(groups), dtype=object)
    y_pred = (y_prob >= float(thr)).astype(np.int64)

    out = {}
    for g in sorted(set(groups.tolist())):
        mask_g = (groups == g)
        if int(mask_g.sum()) == 0:
            continue

        pos_mask = mask_g & (y_true == 1)
        n_pos = int(pos_mask.sum())
        if n_pos == 0:
            out[g] = {"n_total": int(mask_g.sum()), "n_pos": 0, "tp": 0, "fn": 0, "fnr": float("nan")}
            continue

        tp = int(((y_pred == 1) & pos_mask).sum())
        fn = int(((y_pred == 0) & pos_mask).sum())
        fnr = float(fn / max(1, (fn + tp)))

        out[g] = {"n_total": int(mask_g.sum()), "n_pos": int(n_pos), "tp": int(tp), "fn": int(fn), "fnr": float(fnr)}

    fnr_m = out.get("M", {}).get("fnr", float("nan"))
    fnr_f = out.get("F", {}).get("fnr", float("nan"))

    if (not np.isnan(fnr_m)) and (not np.isnan(fnr_f)):
        delta_signed = float(fnr_f - fnr_m)
        delta_abs = float(abs(delta_signed))
    else:
        delta_signed = float("nan")
        delta_abs = float("nan")

    return out, delta_signed, delta_abs

def compute_confusion_counts(y_true, y_prob, thr=0.5):
    y_true = np.asarray(y_true, dtype=np.int64)
    y_prob = np.asarray(y_prob, dtype=np.float64)
    y_pred = (y_prob >= float(thr)).astype(np.int64)
    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
    TN, FP, FN, TP = int(cm[0, 0]), int(cm[0, 1]), int(cm[1, 0]), int(cm[1, 1])
    return {"TN": TN, "FP": FP, "FN": FN, "TP": TP}

def compute_confusion_by_group(y_true, y_prob, groups, thr=0.5):
    y_true = np.asarray(y_true, dtype=np.int64)
    y_prob = np.asarray(y_prob, dtype=np.float64)
    groups = np.asarray(list(groups), dtype=object)
    out = {}
    for g in sorted(set(groups.tolist())):
        mask = (groups == g)
        if int(mask.sum()) == 0:
            continue
        out[g] = {"n": int(mask.sum()), "confusion": compute_confusion_counts(y_true[mask], y_prob[mask], thr=thr)}
    return out

# -------------------------
# Reproducibility: set all RNG seeds
# Inputs: seed int
# Output: deterministic settings for random, numpy, torch
# -------------------------
def set_all_seeds(seed: int):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# -------------------------
# Source checkpoint selection: pick most recent D1 trainval exp with all seeds
# Inputs: D1 trainval_runs folder
# Output: chosen_exp folder used to load best_heads.pt per seed
# -------------------------
D1_TRAINVAL_ROOT = Path(D1_OUT_ROOT) / "trainval_runs"
if not D1_TRAINVAL_ROOT.exists():
    raise FileNotFoundError(f"Missing D1 trainval_runs folder: {str(D1_TRAINVAL_ROOT)}")

exp_dirs = sorted([p for p in D1_TRAINVAL_ROOT.glob("exp_*") if p.is_dir()], key=lambda p: p.stat().st_mtime, reverse=True)
if not exp_dirs:
    raise FileNotFoundError(f"No exp_* folders found under: {str(D1_TRAINVAL_ROOT)}")

def _has_all_seeds(exp_path: Path, seeds: list):
    for s in seeds:
        p = exp_path / f"run_D1_seed{s}" / "best_heads.pt"
        if not p.exists():
            return False
    return True

chosen_exp = None
for ed in exp_dirs:
    if _has_all_seeds(ed, SEEDS):
        chosen_exp = ed
        break

if chosen_exp is None:
    sample = exp_dirs[0]
    raise FileNotFoundError(
        f"Could not find a D1 trainval experiment with best_heads.pt for all {len(SEEDS)} seeds.\n"
        f"Expected: {str(D1_TRAINVAL_ROOT)}/exp_*/run_D1_seedXXXX/best_heads.pt\n"
        f"Most recent exp checked: {str(sample)}"
    )

print("\nUsing SOURCE (D1) Train+Val experiment folder:")
print(" ", str(chosen_exp))

# -------------------------
# Output folder: write only under D2
# Inputs: D2_OUT_ROOT
# Output: ZS_ROOT folder for per-seed runs + summary files
# -------------------------
ZS_ROOT = Path(D2_OUT_ROOT) / "Cross_Language_Zero_Shot_Runs"
ZS_ROOT.mkdir(parents=True, exist_ok=True)

# -------------------------
# DataLoader: D2 TEST only (plus a small warm-up read)
# Inputs: test_df and audio files
# Output: test_loader ready for inference
# -------------------------
test_ds = AudioManifestDataset(test_df)
test_loader = DataLoader(
    test_ds,
    batch_size=PER_DEVICE_BS,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=PIN_MEMORY,
    collate_fn=collate_fn,
)

print("\nWarm-up: loading up to 2 TEST batches...")
t0 = time.time()
nb = len(test_loader)
wb = min(2, nb)
if wb == 0:
    raise RuntimeError("TEST DataLoader has 0 batches. Check test_df length and PER_DEVICE_BS.")
it = iter(test_loader)
for i in range(wb):
    _ = next(it)
    print(f"  loaded TEST warmup batch {i+1}/{wb}")
print(f"Warm-up done in {time.time()-t0:.2f}s")

# -------------------------
# Checkpoint loader: load only the trained heads from best_heads.pt
# Inputs: model instance + best_heads.pt path
# Output: model with head weights restored for this seed
# -------------------------
def load_heads_into_model(model, best_heads_path: Path):
    if not best_heads_path.exists():
        raise FileNotFoundError(f"Missing best_heads.pt: {str(best_heads_path)}")
    state = torch.load(str(best_heads_path), map_location="cpu")
    needed = ["pre_vowel", "pre_other", "head_vowel", "head_other"]
    missing = [k for k in needed if k not in state]
    if missing:
        raise KeyError(
            f"best_heads.pt missing keys {missing}. Found keys: {list(state.keys())}. "
            "This zero-shot code expects the trainval save format."
        )
    model.pre_vowel.load_state_dict(state["pre_vowel"], strict=True)
    model.pre_other.load_state_dict(state["pre_other"], strict=True)
    model.head_vowel.load_state_dict(state["head_vowel"], strict=True)
    model.head_other.load_state_dict(state["head_other"], strict=True)
    return model

# -------------------------
# Inference: run model over D2 test_loader and collect probabilities
# Inputs: loader + model
# Output: arrays (y_true, y_prob, sex_norm) for metrics and plots
# -------------------------
def run_inference(loader, model, desc):
    use_amp = bool(USE_AMP and DEVICE.type == "cuda")
    all_probs, all_true, all_sex = [], [], []

    pbar = tqdm(loader, desc=desc, dynamic_ncols=True)
    with torch.inference_mode():
        for batch in pbar:
            input_values = batch["input_values"].to(DEVICE, non_blocking=False)
            attention_mask = batch["attention_mask"].to(DEVICE, non_blocking=False)
            labels = batch["labels"].to(DEVICE, non_blocking=False)
            task_group = batch["task_group"]
            sex_norm = batch["sex_norm"]

            # Use the active device type so CPU-only runs do not request CUDA autocast
            with torch.amp.autocast(device_type=DEVICE.type, enabled=use_amp):
                logits = model.forward_logits(input_values, attention_mask, task_group)

            probs = torch.softmax(logits.float(), dim=-1)[:, 1].detach().cpu().numpy().astype(np.float64)
            all_probs.extend(probs.tolist())
            all_true.extend(labels.detach().cpu().numpy().astype(np.int64).tolist())
            all_sex.extend(list(sex_norm))

    return np.asarray(all_true, dtype=np.int64), np.asarray(all_probs, dtype=np.float64), np.asarray(all_sex, dtype=object)

# -------------------------
# Per-seed run: load D1 heads, evaluate D2 at FIXED_THR, save metrics + PNGs
# Inputs: seed, chosen_exp, FIXED_THR, D2 test_loader
# Output: run_<TRANSFER_TAG>_seedXXXX folder with metrics.json + plots
# -------------------------
def run_seed(seed: int):
    set_all_seeds(seed)

    run_dir = ZS_ROOT / f"run_{TRANSFER_TAG}_seed{seed}"
    run_dir.mkdir(parents=True, exist_ok=True)

    best_heads_path = chosen_exp / f"run_D1_seed{seed}" / "best_heads.pt"

    print(f"\n[seed={seed}] Loading SOURCE heads from:")
    print(" ", str(best_heads_path))
    print(f"[seed={seed}] Evaluating TARGET D2 TEST @ fixed_thr={FIXED_THR:.6f}")

    model = Wav2Vec2TwoHeadClassifier(BACKBONE_CKPT, dropout_p=DROPOUT_P).to(DEVICE)
    model = load_heads_into_model(model, best_heads_path)
    model.eval()

    yt, pt, st = run_inference(test_loader, model, desc=f"[seed={seed}] D2 TEST (zero-shot)")

    test_auc = compute_auc(yt, pt)  # threshold-free
    thr_metrics = compute_threshold_metrics(yt, pt, thr=FIXED_THR)

    fnr_by_sex, delta_f_minus_m, delta_abs = compute_fnr_by_group_signed(yt, pt, st, thr=FIXED_THR)
    confusion_by_sex = compute_confusion_by_group(yt, pt, st, thr=FIXED_THR)

    # Plots: always save overall ROC + confusion matrix
    roc_png = run_dir / "roc_curve.png"
    cm_png  = run_dir / "confusion_matrix.png"
    save_roc_curve_png(yt, pt, str(roc_png), title_suffix=f"D2 Test (seed={seed})")
    save_confusion_png(yt, pt, str(cm_png), thr=FIXED_THR, title_suffix=f"D2 Test (seed={seed})")

    # Sex-specific confusion matrices (only if that sex exists in D2 TEST)
    cm_m_png = None
    cm_f_png = None
    mask_m = (st == "M")
    mask_f = (st == "F")

    if int(mask_m.sum()) > 0:
        cm_m_png = run_dir / "confusion_matrix_M.png"
        save_confusion_png(yt[mask_m], pt[mask_m], str(cm_m_png), thr=FIXED_THR, title_suffix=f"D2 Test SEX=M (seed={seed})")

    if int(mask_f.sum()) > 0:
        cm_f_png = run_dir / "confusion_matrix_F.png"
        save_confusion_png(yt[mask_f], pt[mask_f], str(cm_f_png), thr=FIXED_THR, title_suffix=f"D2 Test SEX=F (seed={seed})")

    metrics = {
        "transfer": {
            "tag": TRANSFER_TAG,
            "source_dataset": "D1",
            "target_dataset": "D2",
            "threshold_policy": "Fixed threshold = mean VAL-opt threshold from D1 monolingual test summary (applied to ALL seeds)",
            "fixed_threshold_value": float(FIXED_THR),
            "d1_mono_summary_path": str(D1_MONO_SUMMARY),
        },

        "target": {
            "dataset": "D2",
            "seed": int(seed),
            "n_test": int(len(test_df)),
            "label_counts_test": test_df["label_num"].value_counts(dropna=False).to_dict(),
            "sex_counts_test_norm": test_df["sex_norm"].value_counts(dropna=False).to_dict(),
        },

        "test_auroc": float(test_auc),

        "threshold_metrics_test": thr_metrics,
        "test_threshold_used": float(FIXED_THR),

        "fairness_test": {
            "definition": "H3: ΔFNR = FNR(F) - FNR(M), where FNR(sex)=FN/(FN+TP) computed on PD-only true labels at fixed transfer threshold.",
            "fnr_by_sex_norm": fnr_by_sex,
            "delta_fnr_F_minus_M": float(delta_f_minus_m),
            "delta_fnr_abs": float(delta_abs),
            "note": "If n_PD for a sex is 0, its FNR is NaN and ΔFNR is NaN.",
            "sex_normalization_note": "D2 mapping: male->M, female->F; otherwise UNK (case-insensitive).",
        },

        "confusion_by_sex_norm": confusion_by_sex,

        "artifacts": {
            "roc_curve_png": str(roc_png),
            "confusion_matrix_png": str(cm_png),
            "confusion_matrix_M_png": str(cm_m_png) if cm_m_png is not None else None,
            "confusion_matrix_F_png": str(cm_f_png) if cm_f_png is not None else None,
        },

        "paths": {
            "d2_out_root": str(D2_OUT_ROOT),
            "zero_shot_root": str(ZS_ROOT),
            "run_dir": str(run_dir),
            "source_trainval_experiment_used": str(chosen_exp),
            "source_best_heads_path": str(best_heads_path),
        },

        "backbone_ckpt": BACKBONE_CKPT,
        "dropout_p": float(DROPOUT_P),
        "timestamp": time.strftime("%Y%m%d_%H%M%S"),
    }

    metrics_path = run_dir / "metrics.json"
    with open(metrics_path, "w", encoding="utf-8") as f:
        json.dump(metrics, f, indent=2)

    print(f"[seed={seed}] DONE | test_AUROC={test_auc:.6f}")

    def _fmt_fnr(g):
        d = fnr_by_sex.get(g, None)
        if d is None:
            return "n/a"
        return f"fnr={d['fnr']:.6f} (n_PD={d['n_pos']}, fn={d['fn']}, tp={d['tp']})"

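    print(f"[seed={seed}] TEST metrics @ fixed_thr={FIXED_THR:.6f}")
    for _k in ["accuracy", "precision", "sensitivity", "specificity", "f1_score", "mcc"]:
        # Print the threshold-dependent metrics announced by the header above
        print(f"  {_k}: {float(thr_metrics.get(_k, float('nan'))):.6f}")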
    print(f"[seed={seed}] FAIRNESS (H3) @ fixed_thr={FIXED_THR:.6f}:")
    print("  M:", _fmt_fnr("M"))
    print("  F:", _fmt_fnr("F"))
    if "UNK" in fnr_by_sex:
        print("  UNK:", _fmt_fnr("UNK"))
    print("  ΔFNR (F-M):", f"{delta_f_minus_m:.6f}" if not np.isnan(delta_f_minus_m) else "nan")
    print("  |ΔFNR|:", f"{delta_abs:.6f}" if not np.isnan(delta_abs) else "nan")

    print(f"[seed={seed}] WROTE:")
    print(" ", str(metrics_path))
    print(" ", str(roc_png))
    print(" ", str(cm_png))
    if cm_m_png is not None:
        print(" ", str(cm_m_png))
    if cm_f_png is not None:
        print(" ", str(cm_f_png))

    return {
        "seed": int(seed),
        "test_auroc": float(test_auc),
        "thr_metrics_test": thr_metrics,
        "fnr_by_sex_test": fnr_by_sex,
        "delta_f_minus_m_test": float(delta_f_minus_m),
        "delta_abs_test": float(delta_abs),
        "run_dir": str(run_dir),
    }

# -------------------------
# Aggregate across 3 seeds: AUROC mean ± 95% CI, others mean ± SD
# Inputs: per-seed outputs from run_seed()
# Output: summary_zero_shot.json + history_index.jsonl under ZS_ROOT
# -------------------------
seed_results = []
for seed in SEEDS:
    seed_results.append(run_seed(seed))

aurocs = [r["test_auroc"] for r in seed_results]
t_crit = 4.302652729911275  # two-sided 95% t critical value for df=2 (valid only for exactly 3 seeds)
n = len(aurocs)
mean_auc = float(np.mean(aurocs))
std_auc = float(np.std(aurocs, ddof=1)) if n > 1 else 0.0
half_width = float(t_crit * (std_auc / math.sqrt(n))) if n > 1 else 0.0
ci95 = [float(mean_auc - half_width), float(mean_auc + half_width)]

def _mean_sd(vals):
    vals = np.asarray(vals, dtype=np.float64)
    return float(np.nanmean(vals)), float(np.nanstd(vals, ddof=1)) if np.sum(~np.isnan(vals)) > 1 else 0.0

thr_list = [r["thr_metrics_test"] for r in seed_results]
keys = ["accuracy","precision","recall","f1_score","sensitivity","specificity","mcc","p_value_fisher_two_sided"]

agg = {}
for k in keys:
    v = [float(tm.get(k, float("nan"))) for tm in thr_list]
    mu, sd = _mean_sd(v)
    agg[k] = {
        "mean": mu,
        "sd": sd,
        "values_by_seed": {str(r["seed"]): float(tm.get(k, float("nan"))) for r, tm in zip(seed_results, thr_list)}
    }

cm_by_seed = {str(r["seed"]): r["thr_metrics_test"]["confusion_matrix"] for r in seed_results}

# Fairness aggregation (H3): mean ± SD across seeds
fnr_by_seed = {str(r["seed"]): r["fnr_by_sex_test"] for r in seed_results}
delta_signed_by_seed = {str(r["seed"]): float(r["delta_f_minus_m_test"]) for r in seed_results}
delta_abs_by_seed = {str(r["seed"]): float(r["delta_abs_test"]) for r in seed_results}

fnr_m_vals, fnr_f_vals, n_pd_m_vals, n_pd_f_vals = [], [], [], []
d_signed_vals, d_abs_vals = [], []

for r in seed_results:
    s = str(r["seed"])
    d = fnr_by_seed.get(s, {})
    fnr_m_vals.append(float(d.get("M", {}).get("fnr", float("nan"))))
    fnr_f_vals.append(float(d.get("F", {}).get("fnr", float("nan"))))
    n_pd_m_vals.append(float(d.get("M", {}).get("n_pos", float("nan"))))
    n_pd_f_vals.append(float(d.get("F", {}).get("n_pos", float("nan"))))
    d_signed_vals.append(float(delta_signed_by_seed.get(s, float("nan"))))
    d_abs_vals.append(float(delta_abs_by_seed.get(s, float("nan"))))

fnr_m_mean, fnr_m_sd = _mean_sd(fnr_m_vals)
fnr_f_mean, fnr_f_sd = _mean_sd(fnr_f_vals)
d_signed_mean, d_signed_sd = _mean_sd(d_signed_vals)
d_abs_mean, d_abs_sd = _mean_sd(d_abs_vals)

print("\nZero-shot TEST AUROC by seed:")
for r in seed_results:
    print(f"  seed {r['seed']}: {r['test_auroc']:.6f}")
print(f"\nMean Zero-shot TEST AUROC: {mean_auc:.6f}")
print(f"95% CI (t, n=3): [{ci95[0]:.6f}, {ci95[1]:.6f}]")

print(f"\nFixed threshold used for ALL seeds (MEAN VAL-opt from D1 mono): {FIXED_THR:.6f}")

print("\nThreshold metrics on D2 TEST @ fixed threshold (mean ± SD across seeds):")
for k in ["accuracy","precision","sensitivity","specificity","f1_score","mcc"]:
    mu = agg[k]["mean"]
    sd = agg[k]["sd"]
    print(f"  {k}: {mu:.6f} ± {sd:.6f}")
print("  fisher_p_value_two_sided:", f"{agg['p_value_fisher_two_sided']['mean']:.6g} ± {agg['p_value_fisher_two_sided']['sd']:.6g}")

print("\nFAIRNESS (H3) on D2 TEST @ fixed threshold (mean ± SD across seeds):")
print(f"  FNR_M: {fnr_m_mean:.6f} ± {fnr_m_sd:.6f}")
print(f"  FNR_F: {fnr_f_mean:.6f} ± {fnr_f_sd:.6f}")
print(f"  ΔFNR (F-M): {d_signed_mean:.6f} ± {d_signed_sd:.6f}")
print(f"  |ΔFNR|: {d_abs_mean:.6f} ± {d_abs_sd:.6f}")

summary = {
    "transfer": {
        "tag": TRANSFER_TAG,
        "source_dataset": "D1",
        "target_dataset": "D2",
        "fixed_threshold_value": float(FIXED_THR),
        "fixed_threshold_source": {
            "path": str(D1_MONO_SUMMARY),
            "note": "Mean across seeds of D1 VAL-opt thresholds as recorded by the D1 monolingual test run.",
        },
        "source_trainval_experiment_used": str(chosen_exp),
    },

    "target": {
        "dx_out_root": str(D2_OUT_ROOT),
        "manifest_all": str(D2_MANIFEST_ALL),
        "n_test": int(len(test_df)),
        "label_counts_test": test_df["label_num"].value_counts(dropna=False).to_dict(),
        "sex_counts_test_norm": test_df["sex_norm"].value_counts(dropna=False).to_dict(),
    },

    "seeds": SEEDS,

    "test_aurocs_by_seed": {str(r["seed"]): float(r["test_auroc"]) for r in seed_results},
    "mean_test_auroc": float(mean_auc),
    "t_crit_df2_95": float(t_crit),
    "half_width_95_ci": float(half_width),
    "ci95_test_auroc": ci95,

    "threshold_metrics_test_mean_sd": agg,
    "confusion_matrix_by_seed": cm_by_seed,

    "fairness_test": {
        "definition": "H3: ΔFNR = FNR(F) - FNR(M), where FNR(sex)=FN/(FN+TP) computed on PD-only true labels at the fixed transfer threshold.",
        "fnr_by_sex_norm_by_seed": fnr_by_seed,
        "delta_fnr_F_minus_M_by_seed": delta_signed_by_seed,
        "delta_fnr_abs_by_seed": delta_abs_by_seed,
        "fnr_M_mean_sd": {"mean": float(fnr_m_mean), "sd": float(fnr_m_sd), "values_by_seed": {str(s): float(v) for s, v in zip(SEEDS, fnr_m_vals)}},
        "fnr_F_mean_sd": {"mean": float(fnr_f_mean), "sd": float(fnr_f_sd), "values_by_seed": {str(s): float(v) for s, v in zip(SEEDS, fnr_f_vals)}},
        "delta_fnr_F_minus_M_mean_sd": {"mean": float(d_signed_mean), "sd": float(d_signed_sd), "values_by_seed": {str(s): float(v) for s, v in zip(SEEDS, d_signed_vals)}},
        "delta_fnr_abs_mean_sd": {"mean": float(d_abs_mean), "sd": float(d_abs_sd), "values_by_seed": {str(s): float(v) for s, v in zip(SEEDS, d_abs_vals)}},
        "denominators_PD_by_seed": {str(s): {"n_PD_M": float(n_pd_m_vals[i]), "n_PD_F": float(n_pd_f_vals[i])} for i, s in enumerate(SEEDS)},
        "sex_normalization_note": "D2 mapping: male->M, female->F; otherwise UNK. ΔFNR computed only when both M and F have defined FNR.",
    },

    "run_dirs": [r["run_dir"] for r in seed_results],
    "timestamp": time.strftime("%Y%m%d_%H%M%S"),
}

summary_path = ZS_ROOT / "summary_zero_shot.json"
with open(summary_path, "w", encoding="utf-8") as f:
    json.dump(summary, f, indent=2)

history_path = ZS_ROOT / "history_index.jsonl"
with open(history_path, "a", encoding="utf-8") as f:
    f.write(json.dumps(summary) + "\n")

print("\nWROTE summary:", str(summary_path))
print("APPENDED history index:", str(history_path))
print("Open this folder to access artifacts:", str(ZS_ROOT))

# -------------------------
# Stop runtime (Colab)
# Inputs: none
# Output: attempts to release the GPU instance
# -------------------------
print("\nAll done. Unassigning the runtime to stop the L4 instance...")
try:
    from google.colab import runtime  # type: ignore
    print("Calling runtime.unassign() now.")
    runtime.unassign()
except Exception as e:
    print("Could not unassign runtime automatically. You can stop the runtime manually in Colab.")
    print("Reason:", repr(e))

The following cell runs a cross-language zero-shot test from Italian (D4) to Slovak (D2). It evaluates whether classifier heads trained on the Italian dataset (D4, IPVS) can detect Parkinson’s disease in the Slovak dataset (D2, EWA-DB) without any training or fine-tuning on D2. The evaluation uses only the D2 test split from `D2/manifests/manifest_all.csv`. The cell automatically finds the most recent completed D4 train+validation experiment that contains saved head checkpoints for all three seeds (1337, 2024, and 7777) and applies those D4-trained heads directly to D2 audio.

A single fixed decision threshold is used for all three seeds. This threshold is read from D4’s saved `monolingual_test_runs/summary_test.json` and is the mean of the validation-selected thresholds from D4. The code first looks for the preferred key `val_opt_threshold_mean_sd.mean`. If that is missing or not finite, it averages the per-seed values in `val_opt_threshold_by_seed`, and as a last resort it reads `threshold_selection.val_selected_threshold_mean_sd.mean`, which is used by an alternate summary format. If none of these yields a finite value, the cell raises an error. Because the threshold is never tuned on D2, the D2 evaluation stays independent of the target data.
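
As a minimal sketch of what applying a transferred threshold means in practice, the snippet below thresholds a handful of made-up probabilities with the same `>=` rule the notebook uses and derives sensitivity and specificity from the resulting counts. The function names `confusion_at_threshold` and `sensitivity_specificity`, the toy arrays, and the threshold value 0.40 are illustrative assumptions, not values from the study.

```python
def confusion_at_threshold(y_true, y_prob, thr):
    """Count TN/FP/FN/TP when a clip is called PD if y_prob >= thr."""
    counts = {"TN": 0, "FP": 0, "FN": 0, "TP": 0}
    for t, p in zip(y_true, y_prob):
        pred = 1 if p >= thr else 0
        if t == 1:
            counts["TP" if pred == 1 else "FN"] += 1
        else:
            counts["FP" if pred == 1 else "TN"] += 1
    return counts


def sensitivity_specificity(counts):
    """Return (sensitivity, specificity); NaN when a class is absent."""
    pos = counts["TP"] + counts["FN"]
    neg = counts["TN"] + counts["FP"]
    sens = counts["TP"] / pos if pos else float("nan")
    spec = counts["TN"] / neg if neg else float("nan")
    return sens, spec


# Hypothetical target-language scores (1 = PD, 0 = healthy control).
toy_true = [1, 1, 1, 0, 0, 0]
toy_prob = [0.90, 0.45, 0.30, 0.55, 0.20, 0.10]

# Hypothetical source-derived threshold; it is never re-tuned on the target data.
source_thr = 0.40

toy_counts = confusion_at_threshold(toy_true, toy_prob, source_thr)
toy_sens, toy_spec = sensitivity_specificity(toy_counts)
print(toy_counts, toy_sens, toy_spec)
```

With these toy values, one PD clip falls below the threshold and one healthy clip lands above it, so both sensitivity and specificity are 2/3. Moving `source_thr` changes these numbers, which is why choosing it on the target test set would leak information.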

The cell includes several setup and validation steps to prevent silent errors. It ensures Google Drive is mounted, checks that no local files shadow PyTorch or Transformers, prints key paths and the active compute device, confirms that the D2 manifest contains all required columns (`split`, `clip_path`, `label_num`, `task`, `sex`), verifies that the manifest belongs to dataset D2, filters strictly to the test split, and stops early if any listed audio files are missing. Each clip is assigned a simple task group: `task == "vowl"` is treated as vowel and everything else as other. Sex values in D2 (male, female, and common variants) are normalized to `sex_norm` values of M, F, or UNK, so sex-based reporting stays consistent even if unexpected values appear.

For each seed, the cell loads the matching D4 head checkpoint (`best_heads.pt`) into a frozen Wav2Vec2 two-head classifier, with one head for vowel clips and one for other speech. It then runs inference on all D2 test clips to produce a Parkinson’s probability for each clip. From these scores it computes overall AUROC, which does not depend on a threshold, and fixed-threshold metrics at the D4-derived threshold: the confusion matrix, accuracy, precision, sensitivity (recall), specificity, F1 score, MCC, and the Fisher exact test p-value. Fairness is reported as ΔFNR, defined as FNR(F) minus FNR(M), where FNR = FN / (FN + TP) is measured on Parkinson’s-positive clips only, along with confusion counts split by sex. For each seed, the cell saves a ROC curve, an overall confusion matrix, sex-specific confusion matrices for M and F when those groups are present, and a detailed `metrics.json` file under
`<D2_OUT_ROOT>/Cross_Language_Zero_Shot_Runs/run_IT_to_SK_seed<seed>/`.
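
To make the fairness quantity concrete, here is a small worked example of the per-sex false negative rate and the signed gap. It is a simplified, pure-Python sketch; the helper names `fnr_by_group` and `delta_fnr` and all toy values are assumptions made for illustration and are not taken from the notebook's results.

```python
import math


def fnr_by_group(y_true, y_prob, groups, thr):
    """FNR = FN / (FN + TP) per group, counted on PD-positive clips only."""
    out = {}
    for g in sorted(set(groups)):
        fn = tp = 0
        for t, p, s in zip(y_true, y_prob, groups):
            if s != g or t != 1:
                continue  # skip other groups and healthy-control clips
            if p >= thr:
                tp += 1
            else:
                fn += 1
        out[g] = fn / (fn + tp) if (fn + tp) else float("nan")
    return out


def delta_fnr(fnr):
    """Signed gap FNR(F) - FNR(M); NaN if either group has no defined FNR."""
    f = fnr.get("F", float("nan"))
    m = fnr.get("M", float("nan"))
    return float("nan") if math.isnan(f) or math.isnan(m) else f - m


# Hypothetical clips: label (1 = PD), predicted PD probability, normalized sex.
toy_true = [1, 1, 1, 1, 0, 0]
toy_prob = [0.9, 0.2, 0.8, 0.7, 0.1, 0.6]
toy_sex = ["F", "F", "M", "M", "F", "M"]

toy_fnr = fnr_by_group(toy_true, toy_prob, toy_sex, thr=0.5)
toy_gap = delta_fnr(toy_fnr)
print(toy_fnr, toy_gap)
```

In this toy case one of the two female PD clips is missed and neither male PD clip is, so FNR(F) = 0.5, FNR(M) = 0.0, and ΔFNR = 0.5. A positive gap means PD is missed more often in female speakers, and a negative gap means the reverse.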

After all three seeds finish, results are combined across seeds. The cell reports the mean AUROC with a 95% confidence interval based on a t distribution with n = 3 (2 degrees of freedom), along with the mean and standard deviation of the fixed-threshold metrics and fairness values. A combined summary is saved as `summary_zero_shot2.json`, and the same record is appended to `history_index.jsonl`, both under `<D2_OUT_ROOT>/Cross_Language_Zero_Shot_Runs/`. The cell ends by unassigning the Colab runtime to shut down the GPU.
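
The three-seed confidence interval can be checked by hand. For 2 degrees of freedom the t distribution has the closed-form CDF F(t) = 1/2 + t / (2·sqrt(2 + t²)), so the two-sided critical value can be computed without a statistics library. The sketch below does this and builds the interval for three made-up AUROC values; the function names and the toy AUROCs are assumptions for illustration only.

```python
import math
import statistics


def t_crit_df2(confidence=0.95):
    """Two-sided t critical value for exactly 2 degrees of freedom.

    Inverting F(t) = 1/2 + t / (2 * sqrt(2 + t^2)) at p = (1 + confidence) / 2
    gives t = u * sqrt(2 / (1 - u^2)) with u = confidence.
    """
    u = confidence
    return u * math.sqrt(2.0 / (1.0 - u * u))


def mean_ci95_three_seeds(values):
    """Mean and 95% t-interval half-width for exactly three values."""
    if len(values) != 3:
        raise ValueError("this closed form only covers n = 3 (df = 2)")
    mean = statistics.fmean(values)
    sd = statistics.stdev(values)  # sample SD (ddof = 1)
    half_width = t_crit_df2(0.95) * sd / math.sqrt(3)
    return mean, half_width


# Hypothetical per-seed AUROCs, not results from this study.
toy_aurocs = [0.70, 0.75, 0.80]
toy_mean, toy_half = mean_ci95_three_seeds(toy_aurocs)
print(f"t_crit = {t_crit_df2():.6f}")
print(f"mean = {toy_mean:.4f}, 95% CI = [{toy_mean - toy_half:.4f}, {toy_mean + toy_half:.4f}]")
```

The critical value is about 4.3027, more than twice the familiar 1.96 from the normal approximation, so intervals from only three seeds are wide even when the seeds agree closely.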

In [None]:
# =========================
# IT → SK Zero-shot (D4 → D2): Test-only transfer with fixed threshold
# - Loads D4-trained heads (most recent train+val experiment)
# - Runs inference on D2 TEST only (from manifest_all.csv)
# - Uses one fixed threshold read from D4 monolingual summary_test.json
# - Saves per-seed metrics + plots, plus an aggregated summary under D2
# =========================

import os, json, math, random, time, warnings
from pathlib import Path

import numpy as np
import pandas as pd
import soundfile as sf

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

from transformers import Wav2Vec2Model

from sklearn.metrics import roc_auc_score, roc_curve, confusion_matrix, matthews_corrcoef
from scipy.stats import fisher_exact
import matplotlib.pyplot as plt

from tqdm.auto import tqdm

# -------------------------
# Import safety: prevent local files from overriding torch/transformers
# Inputs: /content directory
# Output: raises early if a conflicting file/folder exists
# -------------------------
if os.path.exists("/content/torch.py") or os.path.exists("/content/torch/__init__.py"):
    raise RuntimeError("Found local /content/torch.py or /content/torch/ that shadows PyTorch. Rename/remove it and restart runtime.")
if os.path.exists("/content/transformers.py") or os.path.exists("/content/transformers/__init__.py"):
    raise RuntimeError("Found local /content/transformers.py or /content/transformers/ that shadows Hugging Face Transformers. Rename/remove it and restart runtime.")

# -------------------------
# Drive mount (Colab)
# Inputs: Google Drive
# Output: /content/drive mounted for reading inputs and writing results
# -------------------------
if not os.path.isdir("/content/drive/MyDrive"):
    from google.colab import drive  # type: ignore
    drive.mount("/content/drive")

# -------------------------
# Paths: target (D2) + source (D4)
# Inputs: optional globals D2_OUT_ROOT, D4_OUT_ROOT (if set earlier), else fallbacks
# Output: D2_MANIFEST_ALL and D4_MONO_SUMMARY paths
# -------------------------
D2_OUT_ROOT_FALLBACK = "/content/drive/MyDrive/AI_PD_Project/Datasets/D2-Slovak (EWA-DB)/EWA-DB/preprocessed_v1"
D4_OUT_ROOT_FALLBACK = "/content/drive/MyDrive/AI_PD_Project/Datasets/D4-Italian (IPVS)/preprocessed_v1"

D2_OUT_ROOT = globals().get("D2_OUT_ROOT", D2_OUT_ROOT_FALLBACK)
D4_OUT_ROOT = globals().get("D4_OUT_ROOT", D4_OUT_ROOT_FALLBACK)

D2_MANIFEST_ALL = f"{D2_OUT_ROOT}/manifests/manifest_all.csv"

# D4 monolingual summary (provides the fixed threshold source)
D4_MONO_SUMMARY = Path(D4_OUT_ROOT) / "monolingual_test_runs" / "summary_test.json"

# -------------------------
# Run settings (kept consistent with other test cells)
# Inputs: constants below
# Output: printed configuration + runtime behavior (batching, AMP, device)
# -------------------------
TRANSFER_TAG    = "IT_to_SK"          # confirmed
SEEDS           = [1337, 2024, 7777]
BACKBONE_CKPT   = "facebook/wav2vec2-base"
SR_EXPECTED     = 16000
TINY_THRESH     = 1e-4

EFFECTIVE_BS    = 64
PER_DEVICE_BS   = 16
GRAD_ACCUM      = max(1, EFFECTIVE_BS // PER_DEVICE_BS)  # printed only

DROPOUT_P       = 0.2

NUM_WORKERS     = 0
PIN_MEMORY      = False

USE_AMP         = True

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
try:
    torch.set_float32_matmul_precision("high")
except Exception:
    pass

warnings.filterwarnings("ignore", message="None of the inputs have requires_grad=True. Gradients will be None")
warnings.filterwarnings("ignore", message="Passing `gradient_checkpointing` to a config initialization is deprecated")

print("D2_OUT_ROOT (target):", D2_OUT_ROOT)
print("D2_MANIFEST_ALL:", D2_MANIFEST_ALL)
print("D4_OUT_ROOT (source):", D4_OUT_ROOT)
print("D4_MONO_SUMMARY:", str(D4_MONO_SUMMARY))
print("DEVICE:", DEVICE)
if DEVICE.type == "cuda":
    print("GPU:", torch.cuda.get_device_name(0))
print("PER_DEVICE_BS:", PER_DEVICE_BS, "| EFFECTIVE_BS:", PER_DEVICE_BS * GRAD_ACCUM)
print("NUM_WORKERS:", NUM_WORKERS, "| PIN_MEMORY:", PIN_MEMORY)
print("USE_AMP:", bool(USE_AMP and DEVICE.type == "cuda"))

# -------------------------
# Fixed threshold: read from D4 monolingual summary_test.json
# Inputs: D4_MONO_SUMMARY (JSON)
# Output: FIXED_THR (float) used for all three seeds on D2
# -------------------------
def _safe_float(x):
    try:
        v = float(x)
        return v if np.isfinite(v) else float("nan")
    except Exception:
        return float("nan")

if not D4_MONO_SUMMARY.exists():
    raise FileNotFoundError(
        "Missing D4 monolingual summary_test.json.\n"
        f"Expected: {str(D4_MONO_SUMMARY)}\n"
        "Run the D4 monolingual test cell first (the one that writes summary_test.json)."
    )

with open(D4_MONO_SUMMARY, "r", encoding="utf-8") as f:
    d4_sum = json.load(f)

FIXED_THR = float("nan")

# Preferred: mean threshold already aggregated across seeds
if isinstance(d4_sum.get("val_opt_threshold_mean_sd"), dict):
    FIXED_THR = _safe_float(d4_sum["val_opt_threshold_mean_sd"].get("mean", float("nan")))

# Fallback: compute mean from per-seed thresholds
if not np.isfinite(FIXED_THR):
    by_seed = d4_sum.get("val_opt_threshold_by_seed", None)
    if isinstance(by_seed, dict) and len(by_seed) > 0:
        FIXED_THR = _safe_float(np.nanmean([_safe_float(v) for v in by_seed.values()]))

# Optional fallback for alternate schema (if present)
if not np.isfinite(FIXED_THR):
    ts = d4_sum.get("threshold_selection", None)
    if isinstance(ts, dict) and isinstance(ts.get("val_selected_threshold_mean_sd"), dict):
        FIXED_THR = _safe_float(ts["val_selected_threshold_mean_sd"].get("mean", float("nan")))

if not np.isfinite(FIXED_THR):
    raise RuntimeError(
        "Could not read a finite fixed threshold from D4 monolingual summary_test.json.\n"
        "Expected keys (preferred): val_opt_threshold_mean_sd.mean, or fallback val_opt_threshold_by_seed."
    )

print(f"\nFixed transfer threshold (mean D4 VAL-opt, applied to ALL seeds): {FIXED_THR:.6f}")

# -------------------------
# Target manifest: load D2 and keep TEST only
# Inputs: D2 manifest_all.csv
# Output: test_df with required columns + basic counts printed
# -------------------------
if not os.path.exists(D2_MANIFEST_ALL):
    raise FileNotFoundError(f"Missing D2 manifest_all.csv: {D2_MANIFEST_ALL}")

m_all = pd.read_csv(D2_MANIFEST_ALL)

# Require these fields for inference + fairness plots
req_cols = {"split", "clip_path", "label_num", "task", "sex"}
missing = [c for c in sorted(req_cols) if c not in m_all.columns]
if missing:
    raise ValueError(f"D2 manifest missing required columns: {missing}. Found: {list(m_all.columns)}")

# Ensure the manifest content is actually D2
if "dataset" in m_all.columns and m_all["dataset"].notna().any():
    dataset_id = str(m_all["dataset"].astype(str).value_counts(dropna=True).idxmax())
    m_all = m_all[m_all["dataset"].astype(str) == dataset_id].copy()
else:
    dataset_id = "DX"

if dataset_id != "D2":
    raise RuntimeError(
        f"Expected dataset_id=='D2' but got {dataset_id!r}. "
        "This usually means D2_OUT_ROOT is wrong or inherited from a previous cell.\n"
        f"D2_OUT_ROOT={D2_OUT_ROOT}"
    )

# Keep a stable column set (missing ones are filled with NaN)
keep_cols = ["clip_path", "label_num", "task", "speaker_id", "sex", "age", "duration_sec", "split"]
for c in keep_cols:
    if c not in m_all.columns:
        m_all[c] = np.nan
m_all = m_all[keep_cols].copy()

test_df = m_all[m_all["split"].isin(["test"])].copy().reset_index(drop=True)

print(f"\nTarget dataset inferred: {dataset_id}")
print(f"TEST rows: {len(test_df)}")
if len(test_df) == 0:
    raise RuntimeError("After filtering to split=='test', D2 manifest has 0 rows.")

print("TEST label counts:", test_df["label_num"].value_counts(dropna=False).to_dict())
print("TEST sex counts (raw):", test_df["sex"].value_counts(dropna=False).to_dict())

# -------------------------
# Path check: ensure target audio files exist
# Inputs: test_df.clip_path
# Output: raises early if any clip files are missing (shows a few examples)
# -------------------------
def _fail_fast_missing_paths(df: pd.DataFrame, name: str):
    missing_paths = []
    for p in tqdm(df["clip_path"].astype(str).tolist(), desc=f"Check {name} clip_path exists", dynamic_ncols=True):
        if not os.path.exists(p):
            missing_paths.append(p)
            if len(missing_paths) >= 10:
                break
    if missing_paths:
        raise FileNotFoundError(f"{name}: missing clip_path(s). Examples: {missing_paths}")

_fail_fast_missing_paths(test_df, "TEST")

# -------------------------
# Task grouping: map each row to vowel vs other
# Inputs: test_df.task
# Output: test_df.task_group used to pick the correct head during inference
# -------------------------
def _task_group(task_val) -> str:
    return "vowel" if str(task_val) == "vowl" else "other"

test_df["task_group"] = test_df["task"].apply(_task_group)

# -------------------------
# Sex normalization for fairness/charts (D2 uses male/female)
# Inputs: test_df.sex
# Output: test_df.sex_norm in {M,F,UNK}
# -------------------------
def normalize_sex_d2(val) -> str:
    """
    Returns 'M', 'F', or 'UNK'
    D2 encoding: male/female (strings)
    Also handles common variants.
    """
    if pd.isna(val):
        return "UNK"
    s = str(val).strip().lower()
    if s in {"m", "male", "man", "masc", "masculine"}:
        return "M"
    if s in {"f", "female", "woman", "fem", "feminine"}:
        return "F"
    return "UNK"

test_df["sex_norm"] = test_df["sex"].apply(normalize_sex_d2)
print("TEST sex counts (normalized):", test_df["sex_norm"].value_counts(dropna=False).to_dict())
if (test_df["sex_norm"] == "UNK").any():
    print("NOTE: Some D2 'sex' values could not be normalized to M/F and were counted as 'UNK' for fairness and sex charts.")

# -------------------------
# Dataset + collator: read audio and build attention masks
# Inputs: test_df rows + audio files
# Output: DataLoader batches with padded input_values, attention_mask, labels, task_group, sex_norm
# -------------------------
class AudioManifestDataset(Dataset):
    def __init__(self, df: pd.DataFrame):
        self.df = df.reset_index(drop=True)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx: int):
        row = self.df.iloc[idx]
        clip_path = str(row["clip_path"])
        label = int(row["label_num"])
        task_group = str(row["task_group"])
        sex_norm = str(row["sex_norm"])

        if not os.path.exists(clip_path):
            raise FileNotFoundError(f"Missing clip_path: {clip_path}")

        y, sr = sf.read(clip_path, always_2d=False)
        if isinstance(y, np.ndarray) and y.ndim > 1:
            y = np.mean(y, axis=1)
        y = np.asarray(y, dtype=np.float32)

        if int(sr) != SR_EXPECTED:
            raise ValueError(f"Sample rate mismatch (expected {SR_EXPECTED}): {clip_path} has sr={sr}")

        # Vowel clips can end in a zero-padded tail: mask out every sample after the
        # last one whose magnitude exceeds TINY_THRESH. All other clips stay fully unmasked.
        attn = np.ones((len(y),), dtype=np.int64)
        if task_group == "vowel":
            loud = np.flatnonzero(np.abs(y).astype(np.float64) > float(TINY_THRESH))
            if loud.size > 0:
                attn[int(loud[-1]) + 1:] = 0

        return {
            "input_values": torch.from_numpy(y),
            "attention_mask": torch.from_numpy(attn),
            "labels": torch.tensor(label, dtype=torch.long),
            "task_group": task_group,
            "sex_norm": sex_norm,
        }

def collate_fn(batch):
    max_len = int(max(b["input_values"].numel() for b in batch))
    input_vals, attn_masks, labels, task_groups, sex_norms = [], [], [], [], []
    for b in batch:
        x = b["input_values"]
        a = b["attention_mask"]
        pad = max_len - x.numel()
        if pad > 0:
            x = torch.cat([x, torch.zeros(pad, dtype=x.dtype)], dim=0)
            a = torch.cat([a, torch.zeros(pad, dtype=a.dtype)], dim=0)
        input_vals.append(x)
        attn_masks.append(a)
        labels.append(b["labels"])
        task_groups.append(b["task_group"])
        sex_norms.append(b["sex_norm"])
    return {
        "input_values": torch.stack(input_vals, dim=0),
        "attention_mask": torch.stack(attn_masks, dim=0),
        "labels": torch.stack(labels, dim=0),
        "task_group": task_groups,
        "sex_norm": sex_norms,
    }

# -------------------------
# Model: frozen wav2vec2 backbone + two heads (vowel vs other)
# Inputs: backbone checkpoint name + dropout_p
# Output: logits for PD vs healthy for each clip, using the correct head per task_group
# -------------------------
class Wav2Vec2TwoHeadClassifier(nn.Module):
    def __init__(self, ckpt: str, dropout_p: float):
        super().__init__()
        self.backbone = Wav2Vec2Model.from_pretrained(ckpt, use_safetensors=True)
        for p in self.backbone.parameters():
            p.requires_grad = False

        H = int(self.backbone.config.hidden_size)
        self.pre_vowel = nn.Sequential(nn.LayerNorm(H), nn.Dropout(float(dropout_p)))
        self.pre_other = nn.Sequential(nn.LayerNorm(H), nn.Dropout(float(dropout_p)))
        self.head_vowel = nn.Linear(H, 2)
        self.head_other = nn.Linear(H, 2)

    def masked_mean_pool(self, last_hidden, attn_mask_samples):
        feat_mask = self.backbone._get_feature_vector_attention_mask(last_hidden.shape[1], attn_mask_samples)
        feat_mask = feat_mask.to(last_hidden.device).unsqueeze(-1).type_as(last_hidden)
        masked = last_hidden * feat_mask
        denom = feat_mask.sum(dim=1).clamp(min=1.0)
        return masked.sum(dim=1) / denom

    def _heads_fp32(self, x_fp_any: torch.Tensor, head: nn.Module) -> torch.Tensor:
        # Keep head math in fp32 even when AMP is enabled
        x = x_fp_any.float()
        if x.is_cuda:
            with torch.amp.autocast(device_type="cuda", enabled=False):
                return head(x)
        return head(x)

    def forward_logits(self, input_values, attention_mask, task_group):
        with torch.no_grad():
            out = self.backbone(input_values=input_values, attention_mask=attention_mask)
            last_hidden = out.last_hidden_state

        pooled = self.masked_mean_pool(last_hidden, attention_mask)  # [B,H]

        z_v = self.pre_vowel(pooled.float())
        z_o = self.pre_other(pooled.float())

        logits_v = self._heads_fp32(z_v, self.head_vowel)
        logits_o = self._heads_fp32(z_o, self.head_other)

        # Select the correct head per row based on task_group
        is_vowel = torch.tensor([tg == "vowel" for tg in task_group], device=pooled.device, dtype=torch.bool)
        logits = logits_o.clone()
        if is_vowel.any():
            logits[is_vowel] = logits_v[is_vowel]
        return logits

# -------------------------
# Metrics + plot helpers
# Inputs: y_true and y_prob (PD probability), plus a threshold for thresholded metrics
# Output: numeric metrics dict + PNG charts saved to disk
# -------------------------
def compute_auc(y_true, y_prob):
    y_true = np.asarray(y_true, dtype=np.int64)
    y_prob = np.asarray(y_prob, dtype=np.float64)
    if len(np.unique(y_true)) < 2:
        return float("nan")
    return float(roc_auc_score(y_true, y_prob))

def compute_threshold_metrics(y_true, y_prob, thr=0.5):
    y_true = np.asarray(y_true, dtype=np.int64)
    y_prob = np.asarray(y_prob, dtype=np.float64)
    y_pred = (y_prob >= float(thr)).astype(np.int64)

    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
    TN, FP, FN, TP = int(cm[0, 0]), int(cm[0, 1]), int(cm[1, 0]), int(cm[1, 1])

    eps = 1e-12
    acc = (TP + TN) / max(1, (TP + TN + FP + FN))
    prec = TP / (TP + FP + eps)
    rec = TP / (TP + FN + eps)     # sensitivity
    f1 = 2 * prec * rec / (prec + rec + eps)
    spec = TN / (TN + FP + eps)

    try:
        mcc = float(matthews_corrcoef(y_true, y_pred)) if len(np.unique(y_pred)) > 1 else float("nan")
    except Exception:
        mcc = float("nan")

    try:
        _, pval = fisher_exact([[TN, FP], [FN, TP]], alternative="two-sided")
        pval = float(pval)
    except Exception:
        pval = float("nan")

    return {
        "threshold": float(thr),
        "confusion_matrix": {"TN": TN, "FP": FP, "FN": FN, "TP": TP},
        "accuracy": float(acc),
        "precision": float(prec),
        "recall": float(rec),
        "f1_score": float(f1),
        "sensitivity": float(rec),
        "specificity": float(spec),
        "mcc": float(mcc),
        "p_value_fisher_two_sided": float(pval),
    }

def save_roc_curve_png(y_true, y_prob, out_png, title_suffix="Test"):
    fpr, tpr, _ = roc_curve(y_true, y_prob)
    plt.figure()
    plt.plot(fpr, tpr)
    plt.plot([0, 1], [0, 1], linestyle="--")
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    plt.title(f"ROC Curve ({title_suffix})")
    plt.tight_layout()
    plt.savefig(out_png, dpi=150)
    plt.close()

def save_confusion_png(y_true, y_prob, out_png, thr=0.5, title_suffix="Test"):
    y_true = np.asarray(y_true, dtype=np.int64)
    y_prob = np.asarray(y_prob, dtype=np.float64)
    y_pred = (y_prob >= float(thr)).astype(int)
    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])

    plt.figure()
    plt.imshow(cm)
    plt.xticks([0, 1], ["Healthy(0)", "PD(1)"])
    plt.yticks([0, 1], ["Healthy(0)", "PD(1)"])
    plt.xlabel("Predicted")
    plt.ylabel("True")
    plt.title(f"Confusion Matrix ({title_suffix}, thr={thr:.2f})")
    for (i, j), v in np.ndenumerate(cm):
        plt.text(j, i, str(v), ha="center", va="center")
    plt.tight_layout()
    plt.savefig(out_png, dpi=150)
    plt.close()

# -------------------------
# Fairness: FNR per sex and ΔFNR = FNR(F) - FNR(M)
# Inputs: y_true, y_prob, sex_norm array, threshold
# Output: per-group FNR details + signed and absolute gap
# -------------------------
def compute_fnr_by_group_signed(y_true, y_prob, groups, thr=0.5):
    y_true = np.asarray(y_true, dtype=np.int64)
    y_prob = np.asarray(y_prob, dtype=np.float64)
    groups = np.asarray(list(groups), dtype=object)
    y_pred = (y_prob >= float(thr)).astype(np.int64)

    out = {}
    for g in sorted(set(groups.tolist())):
        mask_g = (groups == g)
        if int(mask_g.sum()) == 0:
            continue

        pos_mask = mask_g & (y_true == 1)
        n_pos = int(pos_mask.sum())
        if n_pos == 0:
            out[g] = {"n_total": int(mask_g.sum()), "n_pos": 0, "tp": 0, "fn": 0, "fnr": float("nan")}
            continue

        tp = int(((y_pred == 1) & pos_mask).sum())
        fn = int(((y_pred == 0) & pos_mask).sum())
        fnr = float(fn / max(1, (fn + tp)))

        out[g] = {"n_total": int(mask_g.sum()), "n_pos": int(n_pos), "tp": int(tp), "fn": int(fn), "fnr": float(fnr)}

    fnr_m = out.get("M", {}).get("fnr", float("nan"))
    fnr_f = out.get("F", {}).get("fnr", float("nan"))

    if (not np.isnan(fnr_m)) and (not np.isnan(fnr_f)):
        delta_signed = float(fnr_f - fnr_m)
        delta_abs = float(abs(delta_signed))
    else:
        delta_signed = float("nan")
        delta_abs = float("nan")

    return out, delta_signed, delta_abs

def compute_confusion_counts(y_true, y_prob, thr=0.5):
    y_true = np.asarray(y_true, dtype=np.int64)
    y_prob = np.asarray(y_prob, dtype=np.float64)
    y_pred = (y_prob >= float(thr)).astype(np.int64)
    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
    TN, FP, FN, TP = int(cm[0, 0]), int(cm[0, 1]), int(cm[1, 0]), int(cm[1, 1])
    return {"TN": TN, "FP": FP, "FN": FN, "TP": TP}

def compute_confusion_by_group(y_true, y_prob, groups, thr=0.5):
    y_true = np.asarray(y_true, dtype=np.int64)
    y_prob = np.asarray(y_prob, dtype=np.float64)
    groups = np.asarray(list(groups), dtype=object)
    out = {}
    for g in sorted(set(groups.tolist())):
        mask = (groups == g)
        if int(mask.sum()) == 0:
            continue
        out[g] = {"n": int(mask.sum()), "confusion": compute_confusion_counts(y_true[mask], y_prob[mask], thr=thr)}
    return out

# -------------------------
# Reproducibility: set all RNG seeds
# Inputs: seed int
# Output: deterministic settings for random, numpy, torch
# -------------------------
def set_all_seeds(seed: int):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# -------------------------
# Source checkpoint selection: pick most recent D4 trainval exp with all seeds
# Inputs: D4 trainval_runs folder
# Output: chosen_exp folder used to load best_heads.pt per seed
# -------------------------
D4_TRAINVAL_ROOT = Path(D4_OUT_ROOT) / "trainval_runs"
if not D4_TRAINVAL_ROOT.exists():
    raise FileNotFoundError(f"Missing D4 trainval_runs folder: {str(D4_TRAINVAL_ROOT)}")

exp_dirs = sorted([p for p in D4_TRAINVAL_ROOT.glob("exp_*") if p.is_dir()], key=lambda p: p.stat().st_mtime, reverse=True)
if not exp_dirs:
    raise FileNotFoundError(f"No exp_* folders found under: {str(D4_TRAINVAL_ROOT)}")

def _has_all_seeds(exp_path: Path, seeds: list):
    for s in seeds:
        p = exp_path / f"run_D4_seed{s}" / "best_heads.pt"
        if not p.exists():
            return False
    return True

chosen_exp = None
for ed in exp_dirs:
    if _has_all_seeds(ed, SEEDS):
        chosen_exp = ed
        break

if chosen_exp is None:
    sample = exp_dirs[0]
    raise FileNotFoundError(
        "Could not find a D4 trainval experiment with all 3 best_heads.pt files.\n"
        f"Expected: {str(D4_TRAINVAL_ROOT)}/exp_*/run_D4_seedXXXX/best_heads.pt\n"
        f"Most recent exp checked: {str(sample)}"
    )

print("\nUsing SOURCE (D4) Train+Val experiment folder:")
print(" ", str(chosen_exp))

# -------------------------
# Output folder: write only under D2
# Inputs: D2_OUT_ROOT
# Output: ZS_ROOT folder for per-seed runs + summary files
# -------------------------
ZS_ROOT = Path(D2_OUT_ROOT) / "Cross_Language_Zero_Shot_Runs"
ZS_ROOT.mkdir(parents=True, exist_ok=True)

# -------------------------
# DataLoader: D2 TEST only (plus a small warm-up read)
# Inputs: test_df and audio files
# Output: test_loader ready for inference
# -------------------------
test_ds = AudioManifestDataset(test_df)
test_loader = DataLoader(
    test_ds,
    batch_size=PER_DEVICE_BS,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=PIN_MEMORY,
    collate_fn=collate_fn,
)

print("\nWarm-up: loading up to 2 TEST batches...")
t0 = time.time()
nb = len(test_loader)
wb = min(2, nb)
if wb == 0:
    raise RuntimeError("TEST DataLoader has 0 batches. Check test_df length and PER_DEVICE_BS.")
it = iter(test_loader)
for i in range(wb):
    _ = next(it)
    print(f"  loaded TEST warmup batch {i+1}/{wb}")
print(f"Warm-up done in {time.time()-t0:.2f}s")

# -------------------------
# Checkpoint loader: load only the trained heads from best_heads.pt
# Inputs: model instance + best_heads.pt path
# Output: model with head weights restored for this seed
# -------------------------
def load_heads_into_model(model, best_heads_path: Path):
    if not best_heads_path.exists():
        raise FileNotFoundError(f"Missing best_heads.pt: {str(best_heads_path)}")
    state = torch.load(str(best_heads_path), map_location="cpu")
    needed = ["pre_vowel", "pre_other", "head_vowel", "head_other"]
    missing = [k for k in needed if k not in state]
    if missing:
        raise KeyError(
            f"best_heads.pt missing keys {missing}. Found keys: {list(state.keys())}. "
            "This zero-shot code expects the trainval save format."
        )
    model.pre_vowel.load_state_dict(state["pre_vowel"], strict=True)
    model.pre_other.load_state_dict(state["pre_other"], strict=True)
    model.head_vowel.load_state_dict(state["head_vowel"], strict=True)
    model.head_other.load_state_dict(state["head_other"], strict=True)
    return model

# -------------------------
# Inference: run model over D2 test_loader and collect probabilities
# Inputs: loader + model
# Output: arrays (y_true, y_prob, sex_norm) for metrics and plots
# -------------------------
def run_inference(loader, model, desc):
    use_amp = bool(USE_AMP and DEVICE.type == "cuda")
    all_probs, all_true, all_sex = [], [], []

    pbar = tqdm(loader, desc=desc, dynamic_ncols=True)
    with torch.inference_mode():
        for batch in pbar:
            input_values = batch["input_values"].to(DEVICE, non_blocking=False)
            attention_mask = batch["attention_mask"].to(DEVICE, non_blocking=False)
            labels = batch["labels"].to(DEVICE, non_blocking=False)
            task_group = batch["task_group"]
            sex_norm = batch["sex_norm"]

            with torch.amp.autocast(device_type="cuda", enabled=use_amp):
                logits = model.forward_logits(input_values, attention_mask, task_group)

            probs = torch.softmax(logits.float(), dim=-1)[:, 1].detach().cpu().numpy().astype(np.float64)
            all_probs.extend(probs.tolist())
            all_true.extend(labels.detach().cpu().numpy().astype(np.int64).tolist())
            all_sex.extend(list(sex_norm))

    return np.asarray(all_true, dtype=np.int64), np.asarray(all_probs, dtype=np.float64), np.asarray(all_sex, dtype=object)

# -------------------------
# Per-seed run: load D4 heads, evaluate D2 at FIXED_THR, save metrics + PNGs
# Inputs: seed, chosen_exp, FIXED_THR, D2 test_loader
# Output: run_<TRANSFER_TAG>_seedXXXX folder with metrics.json + plots
# -------------------------
def run_seed(seed: int):
    set_all_seeds(seed)

    run_dir = ZS_ROOT / f"run_{TRANSFER_TAG}_seed{seed}"
    run_dir.mkdir(parents=True, exist_ok=True)

    best_heads_path = chosen_exp / f"run_D4_seed{seed}" / "best_heads.pt"

    print(f"\n[seed={seed}] Loading SOURCE heads from:")
    print(" ", str(best_heads_path))
    print(f"[seed={seed}] Evaluating TARGET D2 TEST @ fixed_thr={FIXED_THR:.6f}")

    model = Wav2Vec2TwoHeadClassifier(BACKBONE_CKPT, dropout_p=DROPOUT_P).to(DEVICE)
    model = load_heads_into_model(model, best_heads_path)
    model.eval()

    yt, pt, st = run_inference(test_loader, model, desc=f"[seed={seed}] D2 TEST (zero-shot)")

    test_auc = compute_auc(yt, pt)  # threshold-free
    thr_metrics = compute_threshold_metrics(yt, pt, thr=FIXED_THR)

    fnr_by_sex, delta_f_minus_m, delta_abs = compute_fnr_by_group_signed(yt, pt, st, thr=FIXED_THR)
    confusion_by_sex = compute_confusion_by_group(yt, pt, st, thr=FIXED_THR)

    # Plots: always save overall ROC + confusion matrix
    roc_png = run_dir / "roc_curve.png"
    cm_png  = run_dir / "confusion_matrix.png"
    save_roc_curve_png(yt, pt, str(roc_png), title_suffix=f"D2 Test (seed={seed})")
    save_confusion_png(yt, pt, str(cm_png), thr=FIXED_THR, title_suffix=f"D2 Test (seed={seed})")

    # Sex-specific confusion matrices (only if that sex exists in D2 TEST)
    cm_m_png = None
    cm_f_png = None
    mask_m = (st == "M")
    mask_f = (st == "F")

    if int(mask_m.sum()) > 0:
        cm_m_png = run_dir / "confusion_matrix_M.png"
        save_confusion_png(yt[mask_m], pt[mask_m], str(cm_m_png), thr=FIXED_THR, title_suffix=f"D2 Test SEX=M (seed={seed})")

    if int(mask_f.sum()) > 0:
        cm_f_png = run_dir / "confusion_matrix_F.png"
        save_confusion_png(yt[mask_f], pt[mask_f], str(cm_f_png), thr=FIXED_THR, title_suffix=f"D2 Test SEX=F (seed={seed})")

    metrics = {
        "transfer": {
            "tag": TRANSFER_TAG,
            "source_dataset": "D4",
            "target_dataset": "D2",
            "threshold_policy": "Fixed threshold = mean D4 VAL-opt threshold from D4 monolingual test summary (applied to ALL seeds)",
            "fixed_threshold_value": float(FIXED_THR),
            "d4_mono_summary_path": str(D4_MONO_SUMMARY),
        },

        "target": {
            "dataset": "D2",
            "seed": int(seed),
            "n_test": int(len(test_df)),
            "label_counts_test": test_df["label_num"].value_counts(dropna=False).to_dict(),
            "sex_counts_test_norm": test_df["sex_norm"].value_counts(dropna=False).to_dict(),
        },

        "test_auroc": float(test_auc),

        "threshold_metrics_test": thr_metrics,
        "test_threshold_used": float(FIXED_THR),

        "fairness_test": {
            "definition": "H3: ΔFNR = FNR(F) - FNR(M), where FNR(sex)=FN/(FN+TP) computed on PD-only true labels at fixed transfer threshold.",
            "fnr_by_sex_norm": fnr_by_sex,
            "delta_fnr_F_minus_M": float(delta_f_minus_m),
            "delta_fnr_abs": float(delta_abs),
            "note": "If n_PD for a sex is 0, its FNR is NaN and ΔFNR is NaN.",
            "sex_normalization_note": "D2 mapping: male->M, female->F; otherwise UNK.",
        },

        "confusion_by_sex_norm": confusion_by_sex,

        "artifacts": {
            "roc_curve_png": str(roc_png),
            "confusion_matrix_png": str(cm_png),
            "confusion_matrix_M_png": str(cm_m_png) if cm_m_png is not None else None,
            "confusion_matrix_F_png": str(cm_f_png) if cm_f_png is not None else None,
        },

        "paths": {
            "d2_out_root": str(D2_OUT_ROOT),
            "zero_shot_root": str(ZS_ROOT),
            "run_dir": str(run_dir),
            "source_trainval_experiment_used": str(chosen_exp),
            "source_best_heads_path": str(best_heads_path),
        },

        "backbone_ckpt": BACKBONE_CKPT,
        "dropout_p": float(DROPOUT_P),
        "timestamp": time.strftime("%Y%m%d_%H%M%S"),
    }

    metrics_path = run_dir / "metrics.json"
    with open(metrics_path, "w", encoding="utf-8") as f:
        json.dump(metrics, f, indent=2)

    print(f"[seed={seed}] DONE | test_AUROC={test_auc:.6f}")

    def _fmt_fnr(g):
        d = fnr_by_sex.get(g, None)
        if d is None:
            return "n/a"
        return f"fnr={d['fnr']:.6f} (n_PD={d['n_pos']}, fn={d['fn']}, tp={d['tp']})"

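    print(
        f"[seed={seed}] TEST metrics @ fixed_thr={FIXED_THR:.6f}: "
        f"acc={thr_metrics['accuracy']:.6f}, sens={thr_metrics['sensitivity']:.6f}, "
        f"spec={thr_metrics['specificity']:.6f}, f1={thr_metrics['f1_score']:.6f}, "
        f"mcc={thr_metrics['mcc']:.6f}"
    )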
    print(f"[seed={seed}] FAIRNESS (H3) @ fixed_thr={FIXED_THR:.6f}:")
    print("  M:", _fmt_fnr("M"))
    print("  F:", _fmt_fnr("F"))
    if "UNK" in fnr_by_sex:
        print("  UNK:", _fmt_fnr("UNK"))
    print("  ΔFNR (F-M):", f"{delta_f_minus_m:.6f}" if not np.isnan(delta_f_minus_m) else "nan")
    print("  |ΔFNR|:", f"{delta_abs:.6f}" if not np.isnan(delta_abs) else "nan")

    print(f"[seed={seed}] WROTE:")
    print(" ", str(metrics_path))
    print(" ", str(roc_png))
    print(" ", str(cm_png))
    if cm_m_png is not None:
        print(" ", str(cm_m_png))
    if cm_f_png is not None:
        print(" ", str(cm_f_png))

    return {
        "seed": int(seed),
        "test_auroc": float(test_auc),
        "thr_metrics_test": thr_metrics,
        "fnr_by_sex_test": fnr_by_sex,
        "delta_f_minus_m_test": float(delta_f_minus_m),
        "delta_abs_test": float(delta_abs),
        "run_dir": str(run_dir),
    }

# -------------------------
# Aggregate across 3 seeds: AUROC mean ± 95% CI, others mean ± SD
# Inputs: per-seed outputs from run_seed()
# Output: summary_zero_shot2.json + history_index.jsonl under ZS_ROOT
# -------------------------
seed_results = []
for seed in SEEDS:
    seed_results.append(run_seed(seed))

aurocs = [r["test_auroc"] for r in seed_results]
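t_crit = 4.302652729911275  # two-sided 95% t critical value for df = 2, i.e. exactly 3 seeds
n = len(aurocs)
if n != 3:
    print(f"WARNING: t_crit assumes 3 seeds (df=2) but {n} AUROC values were found; the 95% CI below is not valid for this n.")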
mean_auc = float(np.mean(aurocs))
std_auc = float(np.std(aurocs, ddof=1)) if n > 1 else 0.0
half_width = float(t_crit * (std_auc / math.sqrt(n))) if n > 1 else 0.0
ci95 = [float(mean_auc - half_width), float(mean_auc + half_width)]

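def _mean_sd(vals):
    """Mean and sample SD (ddof=1) ignoring NaNs; SD is 0.0 with fewer than two finite values."""
    vals = np.asarray(vals, dtype=np.float64)
    n_valid = int(np.sum(~np.isnan(vals)))
    mean = float(np.nanmean(vals)) if n_valid > 0 else float("nan")
    sd = float(np.nanstd(vals, ddof=1)) if n_valid > 1 else 0.0
    return mean, sd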

thr_list = [r["thr_metrics_test"] for r in seed_results]
keys = ["accuracy","precision","recall","f1_score","sensitivity","specificity","mcc","p_value_fisher_two_sided"]

agg = {}
for k in keys:
    v = [float(tm.get(k, float("nan"))) for tm in thr_list]
    mu, sd = _mean_sd(v)
    agg[k] = {
        "mean": mu,
        "sd": sd,
        "values_by_seed": {str(r["seed"]): float(tm.get(k, float("nan"))) for r, tm in zip(seed_results, thr_list)}
    }

cm_by_seed = {str(r["seed"]): r["thr_metrics_test"]["confusion_matrix"] for r in seed_results}

# Fairness aggregation (H3): mean ± SD across seeds
fnr_by_seed = {str(r["seed"]): r["fnr_by_sex_test"] for r in seed_results}
delta_signed_by_seed = {str(r["seed"]): float(r["delta_f_minus_m_test"]) for r in seed_results}
delta_abs_by_seed = {str(r["seed"]): float(r["delta_abs_test"]) for r in seed_results}

fnr_m_vals, fnr_f_vals, n_pd_m_vals, n_pd_f_vals = [], [], [], []
d_signed_vals, d_abs_vals = [], []

for r in seed_results:
    s = str(r["seed"])
    d = fnr_by_seed.get(s, {})
    fnr_m_vals.append(float(d.get("M", {}).get("fnr", float("nan"))))
    fnr_f_vals.append(float(d.get("F", {}).get("fnr", float("nan"))))
    n_pd_m_vals.append(float(d.get("M", {}).get("n_pos", float("nan"))))
    n_pd_f_vals.append(float(d.get("F", {}).get("n_pos", float("nan"))))
    d_signed_vals.append(float(delta_signed_by_seed.get(s, float("nan"))))
    d_abs_vals.append(float(delta_abs_by_seed.get(s, float("nan"))))

fnr_m_mean, fnr_m_sd = _mean_sd(fnr_m_vals)
fnr_f_mean, fnr_f_sd = _mean_sd(fnr_f_vals)
d_signed_mean, d_signed_sd = _mean_sd(d_signed_vals)
d_abs_mean, d_abs_sd = _mean_sd(d_abs_vals)

print("\nZero-shot TEST AUROC by seed:")
for r in seed_results:
    print(f"  seed {r['seed']}: {r['test_auroc']:.6f}")
print(f"\nMean Zero-shot TEST AUROC: {mean_auc:.6f}")
print(f"95% CI (t, n=3): [{ci95[0]:.6f}, {ci95[1]:.6f}]")

print(f"\nFixed threshold used for ALL seeds (mean D4 VAL-opt): {FIXED_THR:.6f}")

print("\nThreshold metrics on D2 TEST @ fixed threshold (mean ± SD across seeds):")
for k in ["accuracy","precision","sensitivity","specificity","f1_score","mcc"]:
    mu = agg[k]["mean"]
    sd = agg[k]["sd"]
    print(f"  {k}: {mu:.6f} ± {sd:.6f}")
print("  fisher_p_value_two_sided:", f"{agg['p_value_fisher_two_sided']['mean']:.6g} ± {agg['p_value_fisher_two_sided']['sd']:.6g}")

print("\nFAIRNESS (H3) on D2 TEST @ fixed threshold (mean ± SD across seeds):")
print(f"  FNR_M: {fnr_m_mean:.6f} ± {fnr_m_sd:.6f}")
print(f"  FNR_F: {fnr_f_mean:.6f} ± {fnr_f_sd:.6f}")
print(f"  ΔFNR (F-M): {d_signed_mean:.6f} ± {d_signed_sd:.6f}")
print(f"  |ΔFNR|: {d_abs_mean:.6f} ± {d_abs_sd:.6f}")

summary = {
    "transfer": {
        "tag": TRANSFER_TAG,
        "source_dataset": "D4",
        "target_dataset": "D2",
        "fixed_threshold_value": float(FIXED_THR),
        "fixed_threshold_source": {
            "path": str(D4_MONO_SUMMARY),
            "key": "val_opt_threshold_mean_sd.mean",
            "note": "Mean D4 VAL-opt threshold recorded by D4 monolingual test run, applied to ALL seeds.",
        },
        "source_trainval_experiment_used": str(chosen_exp),
    },

    "target": {
        "dx_out_root": str(D2_OUT_ROOT),
        "manifest_all": str(D2_MANIFEST_ALL),
        "n_test": int(len(test_df)),
        "label_counts_test": test_df["label_num"].value_counts(dropna=False).to_dict(),
        "sex_counts_test_norm": test_df["sex_norm"].value_counts(dropna=False).to_dict(),
    },

    "seeds": SEEDS,

    "test_aurocs_by_seed": {str(r["seed"]): float(r["test_auroc"]) for r in seed_results},
    "mean_test_auroc": float(mean_auc),
    "t_crit_df2_95": float(t_crit),
    "half_width_95_ci": float(half_width),
    "ci95_test_auroc": ci95,

    "threshold_metrics_test_mean_sd": agg,
    "confusion_matrix_by_seed": cm_by_seed,

    "fairness_test": {
        "definition": "H3: ΔFNR = FNR(F) - FNR(M), where FNR(sex)=FN/(FN+TP) computed on PD-only true labels at the fixed transfer threshold.",
        "fnr_by_sex_norm_by_seed": fnr_by_seed,
        "delta_fnr_F_minus_M_by_seed": delta_signed_by_seed,
        "delta_fnr_abs_by_seed": delta_abs_by_seed,
        "fnr_M_mean_sd": {"mean": float(fnr_m_mean), "sd": float(fnr_m_sd), "values_by_seed": {str(s): float(v) for s, v in zip(SEEDS, fnr_m_vals)}},
        "fnr_F_mean_sd": {"mean": float(fnr_f_mean), "sd": float(fnr_f_sd), "values_by_seed": {str(s): float(v) for s, v in zip(SEEDS, fnr_f_vals)}},
        "delta_fnr_F_minus_M_mean_sd": {"mean": float(d_signed_mean), "sd": float(d_signed_sd), "values_by_seed": {str(s): float(v) for s, v in zip(SEEDS, d_signed_vals)}},
        "delta_fnr_abs_mean_sd": {"mean": float(d_abs_mean), "sd": float(d_abs_sd), "values_by_seed": {str(s): float(v) for s, v in zip(SEEDS, d_abs_vals)}},
        "denominators_PD_by_seed": {str(s): {"n_PD_M": float(n_pd_m_vals[i]), "n_PD_F": float(n_pd_f_vals[i])} for i, s in enumerate(SEEDS)},
        "sex_normalization_note": "D2 mapping: male->M, female->F; otherwise UNK. ΔFNR computed only when both M and F have defined FNR.",
    },

    "run_dirs": [r["run_dir"] for r in seed_results],
    "timestamp": time.strftime("%Y%m%d_%H%M%S"),
}

# Save the aggregated summary and append the same record to the run history
summary_path = ZS_ROOT / "summary_zero_shot2.json"
with open(summary_path, "w", encoding="utf-8") as f:
    json.dump(summary, f, indent=2)

history_path = ZS_ROOT / "history_index.jsonl"
with open(history_path, "a", encoding="utf-8") as f:
    f.write(json.dumps(summary) + "\n")

print("\nWROTE summary:", str(summary_path))
print("APPENDED history index:", str(history_path))
print("Open this folder to access artifacts:", str(ZS_ROOT))

# -------------------------
# Stop runtime (Colab)
# Inputs: none
# Output: attempts to release the GPU instance
# -------------------------
print("\nAll done. Unassigning the runtime to stop the L4 instance...")
try:
    from google.colab import runtime  # type: ignore
    print("Calling runtime.unassign() now.")
    runtime.unassign()
except Exception as e:
    print("Could not unassign runtime automatically. You can stop the runtime manually in Colab.")
    print("Reason:", repr(e))

The following cell runs a cross-language zero-shot test from English (UK, D5) to Slovak (D2). It checks whether classifier heads trained on the source dataset (D5, MDVR-KCL English) can detect Parkinson’s disease in the target dataset (D2, Slovak) without any retraining on D2. The evaluation uses only the D2 test split from `D2/manifests/manifest_all.csv`. The cell finds the most recent completed D5 train-and-validation experiment that contains saved head checkpoints for all three seeds (1337, 2024, 7777) and applies those trained heads directly to D2 audio.

A single fixed decision threshold is used for all seeds and all D2 test clips. This threshold is read from D5’s previously saved `monolingual_test_runs/summary_test.json`. The preferred value is `val_opt_threshold_mean_sd.mean`, with a fallback to the mean of `val_opt_threshold_by_seed` if that key is unavailable. Because no threshold is tuned on D2, the target test data stay untouched and every seed is evaluated under the same decision rule.
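
The threshold lookup described above can be sketched as follows. This is a minimal illustration rather than the notebook's actual code: the helper name `resolve_fixed_threshold` is hypothetical, the summary is passed in as an already-parsed dictionary, and `val_opt_threshold_by_seed` is assumed to map each seed to a single numeric threshold.

```python
import math


def resolve_fixed_threshold(summary: dict) -> float:
    """Return the preferred mean threshold, else the mean of the per-seed thresholds."""
    # Preferred key: val_opt_threshold_mean_sd.mean
    preferred = (summary.get("val_opt_threshold_mean_sd") or {}).get("mean")
    if isinstance(preferred, (int, float)) and math.isfinite(preferred):
        return float(preferred)

    # Fallback key: val_opt_threshold_by_seed (assumed layout: {seed: threshold})
    by_seed = summary.get("val_opt_threshold_by_seed") or {}
    values = [float(v) for v in by_seed.values()]
    if not values:
        raise KeyError("No usable threshold found in the source summary.")
    return sum(values) / len(values)


# Illustrative values only, not thresholds from the study.
example_summary = {"val_opt_threshold_by_seed": {"1337": 0.40, "2024": 0.50, "7777": 0.60}}
fixed_thr = resolve_fixed_threshold(example_summary)
print(f"fixed threshold = {fixed_thr:.2f}")
```

Resolving the value once, before any seed is evaluated, is what keeps the decision rule identical across seeds.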

Before running inference, the cell performs setup and safety checks. It makes sure Google Drive is mounted, confirms that no local files are shadowing PyTorch or Transformers, prints key paths and the active compute device, loads the D2 manifest, checks that all required columns exist, filters strictly to the test split, and stops early if any referenced audio files are missing. Each clip is assigned to a task group: `task == "vowl"` is treated as vowel and every other value is treated as other. Sex values in D2 (male, female, and common variants) are normalized to `sex_norm` values of M, F, or UNK, so sex-based reporting remains consistent even if unexpected values appear.
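
The two row-level mappings above can be illustrated with a small standalone sketch. The function names `to_task_group` and `to_sex_norm` are hypothetical, and the accepted spelling variants are an assumption chosen for illustration; only the `"vowl"` rule and the M/F/UNK output values come from the description.

```python
import math


def to_task_group(task) -> str:
    """Only the exact task label 'vowl' counts as a vowel clip."""
    return "vowel" if str(task) == "vowl" else "other"


def to_sex_norm(value) -> str:
    """Map raw sex values to 'M', 'F', or 'UNK' (variant lists are illustrative)."""
    if value is None or (isinstance(value, float) and math.isnan(value)):
        return "UNK"
    s = str(value).strip().lower()
    if s in {"m", "male"}:
        return "M"
    if s in {"f", "female"}:
        return "F"
    return "UNK"  # anything unexpected is kept, but reported separately


for raw in ["male", " Female ", "n/a", None]:
    print(repr(raw), "->", to_sex_norm(raw))
print(to_task_group("vowl"), to_task_group("read"))
```

Mapping unrecognized values to UNK instead of dropping the rows keeps the overall metrics computed on the full test set while excluding those clips from the M versus F comparison.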

For each seed, the cell loads the matching D5 head checkpoint (`best_heads.pt`) into a frozen Wav2Vec2 two-head classifier, with one head for vowel clips and one for all other clips. It then runs inference on the full D2 test set to produce a Parkinson’s probability for each clip. From these scores it computes the overall AUROC, which is threshold-independent, and a set of metrics at the fixed D5-derived threshold: the confusion matrix, accuracy, precision, sensitivity (recall), specificity, F1 score, MCC, and the Fisher exact test p-value. Fairness is measured as ΔFNR, defined as FNR(F) minus FNR(M) and computed only on true Parkinson’s clips; confusion counts are also reported by sex. For each seed, the cell saves a ROC curve, an overall confusion matrix, additional sex-specific confusion matrices when both M and F are present, and a detailed `metrics.json` file under
`<D2_OUT_ROOT>/Cross_Language_Zero_Shot_Runs/run_ENUK_to_SK_seed<seed>/`.
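
As a worked illustration of the ΔFNR definition, the sketch below computes the per-sex false negative rate on a handful of made-up labels. The helper names `fnr_by_sex` and `delta_fnr` are hypothetical, the predictions are assumed to be already thresholded to 0/1, and the toy arrays are not study data.

```python
import math


def fnr_by_sex(y_true, y_pred, sex):
    """FNR = FN / (FN + TP) per sex, using only clips whose true label is PD (1)."""
    counts = {}
    for truth, pred, group in zip(y_true, y_pred, sex):
        if truth != 1:
            continue  # healthy clips do not enter the FNR
        c = counts.setdefault(group, {"fn": 0, "tp": 0})
        c["tp" if pred == 1 else "fn"] += 1
    return {g: c["fn"] / (c["fn"] + c["tp"]) for g, c in counts.items()}


def delta_fnr(fnr):
    """Signed gap FNR(F) - FNR(M); NaN when either sex has no PD clips."""
    if "F" not in fnr or "M" not in fnr:
        return math.nan
    return fnr["F"] - fnr["M"]


# Toy labels for illustration only.
y_true = [1, 1, 1, 1, 1, 1, 0, 0]
y_pred = [1, 0, 1, 1, 1, 0, 0, 1]
sex = ["F", "F", "F", "F", "M", "M", "F", "M"]

rates = fnr_by_sex(y_true, y_pred, sex)
gap = delta_fnr(rates)
print(rates, gap)  # {'F': 0.25, 'M': 0.5} -0.25
```

In this toy case one of four female PD clips and one of two male PD clips are missed, so the signed gap is negative: the female group has the lower miss rate.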

After all three seeds finish, results are combined across seeds. The cell reports the mean AUROC with a 95% confidence interval based on a t distribution with n = 3 (2 degrees of freedom), along with the mean and standard deviation of the fixed-threshold metrics and fairness values. A combined summary is saved as `summary_zero_shot3.json` and the same record is appended to `history_index.jsonl`, both under `<D2_OUT_ROOT>/Cross_Language_Zero_Shot_Runs/`. The cell ends by unassigning the Colab runtime to shut down the GPU.
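
The three-seed confidence interval can be sketched with the standard library alone. The function name `mean_ci95_three_seeds` is hypothetical and the AUROC values below are placeholders, not results; the critical value is the two-sided 95% t quantile for 2 degrees of freedom.

```python
import math
import statistics

T_CRIT_DF2_95 = 4.302652729911275  # two-sided 95% t critical value, df = 2


def mean_ci95_three_seeds(values):
    """Mean and 95% t-interval for exactly three per-seed values."""
    if len(values) != 3:
        raise ValueError("The hard-coded critical value is only valid for 3 values.")
    mean = statistics.fmean(values)
    sd = statistics.stdev(values)  # sample SD (n - 1 in the denominator)
    half_width = T_CRIT_DF2_95 * sd / math.sqrt(len(values))
    return mean, (mean - half_width, mean + half_width)


# Placeholder AUROC values for illustration only.
mean_auc, (lo, hi) = mean_ci95_three_seeds([0.70, 0.75, 0.80])
print(f"mean={mean_auc:.4f}, 95% CI=[{lo:.4f}, {hi:.4f}]")
```

With only three seeds the critical value is large (about 4.30, versus 1.96 for a normal approximation), so the interval is wide even when the per-seed values are close together.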

In [None]:
# =========================
# ENUK → SK Zero-shot (D5 → D2): Test-only transfer with fixed threshold
# - Loads D5-trained heads (most recent train+val experiment)
# - Runs inference on D2 TEST only (from manifest_all.csv)
# - Uses one fixed threshold read from D5 monolingual summary_test.json
# - Saves per-seed metrics + plots, plus an aggregated summary under D2
# =========================

import os, json, math, random, time, warnings
from pathlib import Path

import numpy as np
import pandas as pd
import soundfile as sf

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

from transformers import Wav2Vec2Model

from sklearn.metrics import roc_auc_score, roc_curve, confusion_matrix, matthews_corrcoef
from scipy.stats import fisher_exact
import matplotlib.pyplot as plt

from tqdm.auto import tqdm

# -------------------------
# Import safety: ensure local files do not override core libraries
# Inputs: /content directory
# Output: raises early if torch/transformers are shadowed by local names
# -------------------------
if os.path.exists("/content/torch.py") or os.path.exists("/content/torch/__init__.py"):
    raise RuntimeError("Found local /content/torch.py or /content/torch/ that shadows PyTorch. Rename/remove it and restart runtime.")
if os.path.exists("/content/transformers.py") or os.path.exists("/content/transformers/__init__.py"):
    raise RuntimeError("Found local /content/transformers.py or /content/transformers/ that shadows Hugging Face Transformers. Rename/remove it and restart runtime.")

# -------------------------
# Drive mount (Colab)
# Inputs: Google Drive
# Output: /content/drive mounted for reading manifests/checkpoints and writing results
# -------------------------
if not os.path.isdir("/content/drive/MyDrive"):
    from google.colab import drive  # type: ignore
    drive.mount("/content/drive")

# -------------------------
# Paths: target (D2) + source (D5)
# Inputs: optional globals D2_OUT_ROOT, D5_OUT_ROOT (if set earlier), else fallbacks
# Output: D2_MANIFEST_ALL and D5_MONO_SUMMARY paths
# -------------------------
D2_OUT_ROOT_FALLBACK = "/content/drive/MyDrive/AI_PD_Project/Datasets/D2-Slovak (EWA-DB)/EWA-DB/preprocessed_v1"
D5_OUT_ROOT_FALLBACK = "/content/drive/MyDrive/AI_PD_Project/Datasets/D5-English (MDVR-KCL)/preprocessed_v2"

D2_OUT_ROOT = globals().get("D2_OUT_ROOT", D2_OUT_ROOT_FALLBACK)
D5_OUT_ROOT = globals().get("D5_OUT_ROOT", D5_OUT_ROOT_FALLBACK)

D2_MANIFEST_ALL = f"{D2_OUT_ROOT}/manifests/manifest_all.csv"

# D5 monolingual summary (provides the fixed threshold source)
D5_MONO_SUMMARY = Path(D5_OUT_ROOT) / "monolingual_test_runs" / "summary_test.json"

# -------------------------
# Run settings (kept consistent with other test cells)
# Inputs: constants below
# Output: printed configuration + runtime behavior (batching, AMP, device)
# -------------------------
TRANSFER_TAG    = "ENUK_to_SK"        # D5 -> D2
SEEDS           = [1337, 2024, 7777]
BACKBONE_CKPT   = "facebook/wav2vec2-base"
SR_EXPECTED     = 16000
TINY_THRESH     = 1e-4

EFFECTIVE_BS    = 64
PER_DEVICE_BS   = 16
GRAD_ACCUM      = max(1, EFFECTIVE_BS // PER_DEVICE_BS)  # printed only

DROPOUT_P       = 0.2

NUM_WORKERS     = 0
PIN_MEMORY      = False

USE_AMP         = True

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
try:
    torch.set_float32_matmul_precision("high")
except Exception:
    pass

warnings.filterwarnings("ignore", message="None of the inputs have requires_grad=True. Gradients will be None")
warnings.filterwarnings("ignore", message="Passing `gradient_checkpointing` to a config initialization is deprecated")

print("D2_OUT_ROOT (target):", D2_OUT_ROOT)
print("D2_MANIFEST_ALL:", D2_MANIFEST_ALL)
print("D5_OUT_ROOT (source):", D5_OUT_ROOT)
print("D5_MONO_SUMMARY:", str(D5_MONO_SUMMARY))
print("DEVICE:", DEVICE)
if DEVICE.type == "cuda":
    print("GPU:", torch.cuda.get_device_name(0))
print("PER_DEVICE_BS:", PER_DEVICE_BS, "| EFFECTIVE_BS:", PER_DEVICE_BS * GRAD_ACCUM)
print("NUM_WORKERS:", NUM_WORKERS, "| PIN_MEMORY:", PIN_MEMORY)
print("USE_AMP:", bool(USE_AMP and DEVICE.type == "cuda"))

# -------------------------
# Fixed threshold: read from D5 monolingual summary_test.json
# Inputs: D5_MONO_SUMMARY (JSON)
# Output: FIXED_THR (float) used for all three seeds on D2
# -------------------------
if not D5_MONO_SUMMARY.exists():
    raise FileNotFoundError(
        "Missing D5 monolingual summary_test.json.\n"
        f"Expected: {str(D5_MONO_SUMMARY)}\n"
        "Run the D5 monolingual test cell first (the one that writes summary_test.json)."
    )

with open(D5_MONO_SUMMARY, "r", encoding="utf-8") as f:
    d5_sum = json.load(f)

# Preferred: summary["val_opt_threshold_mean_sd"]["mean"]
if "val_opt_threshold_mean_sd" in d5_sum and isinstance(d5_sum["val_opt_threshold_mean_sd"], dict):
    FIXED_THR = float(d5_sum["val_opt_threshold_mean_sd"].get("mean", float("nan")))
else:
    # Fallback: mean of per-seed values if present
    by_seed = d5_sum.get("val_opt_threshold_by_seed", None)
    if isinstance(by_seed, dict) and len(by_seed) > 0:
        FIXED_THR = float(np.nanmean([float(v) for v in by_seed.values()]))
    else:
        FIXED_THR = float("nan")

if not np.isfinite(FIXED_THR):
    raise RuntimeError(
        "Could not read a finite fixed threshold from D5 monolingual summary_test.json.\n"
        "Expected keys: val_opt_threshold_mean_sd.mean (preferred) or val_opt_threshold_by_seed."
    )

print(f"\nFixed transfer threshold (from D5 monolingual mean VAL-opt): {FIXED_THR:.6f}")

# -------------------------
# Target manifest: load D2 and keep TEST only
# Inputs: D2 manifest_all.csv
# Output: test_df with required columns + basic counts printed
# -------------------------
if not os.path.exists(D2_MANIFEST_ALL):
    raise FileNotFoundError(f"Missing D2 manifest_all.csv: {D2_MANIFEST_ALL}")

m_all = pd.read_csv(D2_MANIFEST_ALL)

# Require these fields for inference + fairness plots
req_cols = {"split", "clip_path", "label_num", "task", "sex"}
missing = [c for c in sorted(req_cols) if c not in m_all.columns]
if missing:
    raise ValueError(f"D2 manifest missing required columns: {missing}. Found: {list(m_all.columns)}")

# Ensure the manifest content is actually D2
if "dataset" in m_all.columns and m_all["dataset"].notna().any():
    # Drop missing values before casting to str so NaN is never counted as the literal "nan"
    dataset_id = str(m_all["dataset"].dropna().astype(str).value_counts().idxmax())
    m_all = m_all[m_all["dataset"].astype(str) == dataset_id].copy()
else:
    dataset_id = "DX"

if dataset_id != "D2":
    raise RuntimeError(
        f"Expected dataset_id=='D2' but got {dataset_id!r}. "
        "This usually means D2_OUT_ROOT is wrong or inherited from a previous cell.\n"
        f"D2_OUT_ROOT={D2_OUT_ROOT}"
    )

# Keep a stable column set (missing ones are filled with NaN)
keep_cols = ["clip_path", "label_num", "task", "speaker_id", "sex", "age", "duration_sec", "split"]
for c in keep_cols:
    if c not in m_all.columns:
        m_all[c] = np.nan
m_all = m_all[keep_cols].copy()

test_df = m_all[m_all["split"].isin(["test"])].copy().reset_index(drop=True)

print(f"\nTarget dataset inferred: {dataset_id}")
print(f"TEST rows: {len(test_df)}")
if len(test_df) == 0:
    raise RuntimeError("After filtering to split=='test', D2 manifest has 0 rows.")

print("TEST label counts:", test_df["label_num"].value_counts(dropna=False).to_dict())
print("TEST sex counts (raw):", test_df["sex"].value_counts(dropna=False).to_dict())

# -------------------------
# Path check: ensure target audio files exist
# Inputs: test_df.clip_path
# Output: raises early if any clip files are missing (shows a few examples)
# -------------------------
def _fail_fast_missing_paths(df: pd.DataFrame, name: str):
    missing_paths = []
    for p in tqdm(df["clip_path"].astype(str).tolist(), desc=f"Check {name} clip_path exists", dynamic_ncols=True):
        if not os.path.exists(p):
            missing_paths.append(p)
            if len(missing_paths) >= 10:
                break
    if missing_paths:
        raise FileNotFoundError(f"{name}: missing clip_path(s). Examples: {missing_paths}")

_fail_fast_missing_paths(test_df, "TEST")

# -------------------------
# Task grouping: map each row to vowel vs other
# Inputs: test_df.task
# Output: test_df.task_group used to pick the correct head during inference
# -------------------------
def _task_group(task_val) -> str:
    return "vowel" if str(task_val) == "vowl" else "other"

test_df["task_group"] = test_df["task"].apply(_task_group)

# -------------------------
# Sex normalization for fairness/charts (D2 uses male/female)
# Inputs: test_df.sex
# Output: test_df.sex_norm in {M,F,UNK}
# -------------------------
def normalize_sex_d2(val) -> str:
    """
    Returns 'M', 'F', or 'UNK'
    D2 encoding: male/female (strings)
    Also handles common variants.
    """
    if pd.isna(val):
        return "UNK"
    s = str(val).strip().lower()
    if s in {"m", "male", "man", "masc", "masculine"}:
        return "M"
    if s in {"f", "female", "woman", "fem", "feminine"}:
        return "F"
    return "UNK"

test_df["sex_norm"] = test_df["sex"].apply(normalize_sex_d2)
print("TEST sex counts (normalized):", test_df["sex_norm"].value_counts(dropna=False).to_dict())
if (test_df["sex_norm"] == "UNK").any():
    print("NOTE: Some D2 'sex' values could not be normalized to M/F and were counted as 'UNK' for fairness and sex charts.")

# -------------------------
# Dataset + collator: read audio and build attention masks
# Inputs: test_df rows + audio files
# Output: DataLoader batches with padded input_values, attention_mask, labels, task_group, sex_norm
# -------------------------
class AudioManifestDataset(Dataset):
    def __init__(self, df: pd.DataFrame):
        self.df = df.reset_index(drop=True)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx: int):
        row = self.df.iloc[idx]
        clip_path = str(row["clip_path"])
        label = int(row["label_num"])
        task_group = str(row["task_group"])
        sex_norm = str(row["sex_norm"])

        if not os.path.exists(clip_path):
            raise FileNotFoundError(f"Missing clip_path: {clip_path}")

        y, sr = sf.read(clip_path, always_2d=False)
        if isinstance(y, np.ndarray) and y.ndim > 1:
            y = np.mean(y, axis=1)
        y = np.asarray(y, dtype=np.float32)

        if int(sr) != SR_EXPECTED:
            raise ValueError(f"Sample rate mismatch (expected {SR_EXPECTED}): {clip_path} has sr={sr}")

        # Attention mask: all ones by default. For vowel clips, mask out the
        # near-silent padded tail (every sample after the last one whose
        # magnitude exceeds TINY_THRESH). An all-silent clip keeps a full mask.
        attn = np.ones((len(y),), dtype=np.int64)
        if task_group == "vowel":
            loud = np.flatnonzero(np.abs(y.astype(np.float64)) > TINY_THRESH)
            if loud.size > 0:
                attn[int(loud[-1]) + 1:] = 0

        return {
            "input_values": torch.from_numpy(y),
            "attention_mask": torch.from_numpy(attn),
            "labels": torch.tensor(label, dtype=torch.long),
            "task_group": task_group,
            "sex_norm": sex_norm,
        }

def collate_fn(batch):
    max_len = int(max(b["input_values"].numel() for b in batch))
    input_vals, attn_masks, labels, task_groups, sex_norms = [], [], [], [], []
    for b in batch:
        x = b["input_values"]
        a = b["attention_mask"]
        pad = max_len - x.numel()
        if pad > 0:
            x = torch.cat([x, torch.zeros(pad, dtype=x.dtype)], dim=0)
            a = torch.cat([a, torch.zeros(pad, dtype=a.dtype)], dim=0)
        input_vals.append(x)
        attn_masks.append(a)
        labels.append(b["labels"])
        task_groups.append(b["task_group"])
        sex_norms.append(b["sex_norm"])
    return {
        "input_values": torch.stack(input_vals, dim=0),
        "attention_mask": torch.stack(attn_masks, dim=0),
        "labels": torch.stack(labels, dim=0),
        "task_group": task_groups,
        "sex_norm": sex_norms,
    }

# -------------------------
# Model: frozen wav2vec2 backbone + two heads (vowel vs other)
# Inputs: backbone checkpoint name + dropout_p
# Output: logits for PD vs healthy for each clip, using the correct head per task_group
# -------------------------
class Wav2Vec2TwoHeadClassifier(nn.Module):
    def __init__(self, ckpt: str, dropout_p: float):
        super().__init__()
        self.backbone = Wav2Vec2Model.from_pretrained(ckpt, use_safetensors=True)
        for p in self.backbone.parameters():
            p.requires_grad = False

        H = int(self.backbone.config.hidden_size)
        self.pre_vowel = nn.Sequential(nn.LayerNorm(H), nn.Dropout(float(dropout_p)))
        self.pre_other = nn.Sequential(nn.LayerNorm(H), nn.Dropout(float(dropout_p)))
        self.head_vowel = nn.Linear(H, 2)
        self.head_other = nn.Linear(H, 2)

    def masked_mean_pool(self, last_hidden, attn_mask_samples):
        feat_mask = self.backbone._get_feature_vector_attention_mask(last_hidden.shape[1], attn_mask_samples)
        feat_mask = feat_mask.to(last_hidden.device).unsqueeze(-1).type_as(last_hidden)
        masked = last_hidden * feat_mask
        denom = feat_mask.sum(dim=1).clamp(min=1.0)
        return masked.sum(dim=1) / denom

    def _heads_fp32(self, x_fp_any: torch.Tensor, head: nn.Module) -> torch.Tensor:
        # Keep head math in fp32 even when AMP is enabled
        x = x_fp_any.float()
        if x.is_cuda:
            with torch.amp.autocast(device_type="cuda", enabled=False):
                return head(x)
        return head(x)

    def forward_logits(self, input_values, attention_mask, task_group):
        with torch.no_grad():
            out = self.backbone(input_values=input_values, attention_mask=attention_mask)
            last_hidden = out.last_hidden_state

        pooled = self.masked_mean_pool(last_hidden, attention_mask)  # [B,H]

        z_v = self.pre_vowel(pooled.float())
        z_o = self.pre_other(pooled.float())

        logits_v = self._heads_fp32(z_v, self.head_vowel)
        logits_o = self._heads_fp32(z_o, self.head_other)

        # Select the correct head per row based on task_group
        is_vowel = torch.tensor([tg == "vowel" for tg in task_group], device=pooled.device, dtype=torch.bool)
        logits = logits_o.clone()
        if is_vowel.any():
            logits[is_vowel] = logits_v[is_vowel]
        return logits

# -------------------------
# Metrics + plot helpers
# Inputs: y_true and y_prob (PD probability), plus a threshold for thresholded metrics
# Output: numeric metrics dict + PNG charts saved to disk
# -------------------------
def compute_auc(y_true, y_prob):
    y_true = np.asarray(y_true, dtype=np.int64)
    y_prob = np.asarray(y_prob, dtype=np.float64)
    if len(np.unique(y_true)) < 2:
        return float("nan")
    return float(roc_auc_score(y_true, y_prob))

def compute_threshold_metrics(y_true, y_prob, thr=0.5):
    y_true = np.asarray(y_true, dtype=np.int64)
    y_prob = np.asarray(y_prob, dtype=np.float64)
    y_pred = (y_prob >= float(thr)).astype(np.int64)

    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
    TN, FP, FN, TP = int(cm[0, 0]), int(cm[0, 1]), int(cm[1, 0]), int(cm[1, 1])

    eps = 1e-12
    acc = (TP + TN) / max(1, (TP + TN + FP + FN))
    prec = TP / (TP + FP + eps)
    rec = TP / (TP + FN + eps)     # sensitivity
    f1 = 2 * prec * rec / (prec + rec + eps)
    spec = TN / (TN + FP + eps)

    try:
        mcc = float(matthews_corrcoef(y_true, y_pred)) if len(np.unique(y_pred)) > 1 else float("nan")
    except Exception:
        mcc = float("nan")

    try:
        _, pval = fisher_exact([[TN, FP], [FN, TP]], alternative="two-sided")
        pval = float(pval)
    except Exception:
        pval = float("nan")

    return {
        "threshold": float(thr),
        "confusion_matrix": {"TN": TN, "FP": FP, "FN": FN, "TP": TP},
        "accuracy": float(acc),
        "precision": float(prec),
        "recall": float(rec),
        "f1_score": float(f1),
        "sensitivity": float(rec),
        "specificity": float(spec),
        "mcc": float(mcc),
        "p_value_fisher_two_sided": float(pval),
    }

def save_roc_curve_png(y_true, y_prob, out_png, title_suffix="Test"):
    fpr, tpr, _ = roc_curve(y_true, y_prob)
    plt.figure()
    plt.plot(fpr, tpr)
    plt.plot([0, 1], [0, 1], linestyle="--")
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    plt.title(f"ROC Curve ({title_suffix})")
    plt.tight_layout()
    plt.savefig(out_png, dpi=150)
    plt.close()

def save_confusion_png(y_true, y_prob, out_png, thr=0.5, title_suffix="Test"):
    y_true = np.asarray(y_true, dtype=np.int64)
    y_prob = np.asarray(y_prob, dtype=np.float64)
    y_pred = (y_prob >= float(thr)).astype(int)
    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])

    plt.figure()
    plt.imshow(cm)
    plt.xticks([0, 1], ["Healthy(0)", "PD(1)"])
    plt.yticks([0, 1], ["Healthy(0)", "PD(1)"])
    plt.xlabel("Predicted")
    plt.ylabel("True")
    plt.title(f"Confusion Matrix ({title_suffix}, thr={thr:.2f})")
    for (i, j), v in np.ndenumerate(cm):
        plt.text(j, i, str(v), ha="center", va="center")
    plt.tight_layout()
    plt.savefig(out_png, dpi=150)
    plt.close()

# -------------------------
# Fairness: FNR per sex and ΔFNR = FNR(F) - FNR(M)
# Inputs: y_true, y_prob, sex_norm array, threshold
# Output: per-group FNR details + signed and absolute gap
# -------------------------
def compute_fnr_by_group_signed(y_true, y_prob, groups, thr=0.5):
    y_true = np.asarray(y_true, dtype=np.int64)
    y_prob = np.asarray(y_prob, dtype=np.float64)
    groups = np.asarray(list(groups), dtype=object)
    y_pred = (y_prob >= float(thr)).astype(np.int64)

    out = {}
    for g in sorted(set(groups.tolist())):
        mask_g = (groups == g)
        if int(mask_g.sum()) == 0:
            continue

        pos_mask = mask_g & (y_true == 1)
        n_pos = int(pos_mask.sum())
        if n_pos == 0:
            out[g] = {"n_total": int(mask_g.sum()), "n_pos": 0, "tp": 0, "fn": 0, "fnr": float("nan")}
            continue

        tp = int(((y_pred == 1) & pos_mask).sum())
        fn = int(((y_pred == 0) & pos_mask).sum())
        fnr = float(fn / max(1, (fn + tp)))

        out[g] = {"n_total": int(mask_g.sum()), "n_pos": int(n_pos), "tp": int(tp), "fn": int(fn), "fnr": float(fnr)}

    fnr_m = out.get("M", {}).get("fnr", float("nan"))
    fnr_f = out.get("F", {}).get("fnr", float("nan"))

    if (not np.isnan(fnr_m)) and (not np.isnan(fnr_f)):
        delta_signed = float(fnr_f - fnr_m)
        delta_abs = float(abs(delta_signed))
    else:
        delta_signed = float("nan")
        delta_abs = float("nan")

    return out, delta_signed, delta_abs

def compute_confusion_counts(y_true, y_prob, thr=0.5):
    y_true = np.asarray(y_true, dtype=np.int64)
    y_prob = np.asarray(y_prob, dtype=np.float64)
    y_pred = (y_prob >= float(thr)).astype(np.int64)
    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
    TN, FP, FN, TP = int(cm[0, 0]), int(cm[0, 1]), int(cm[1, 0]), int(cm[1, 1])
    return {"TN": TN, "FP": FP, "FN": FN, "TP": TP}

def compute_confusion_by_group(y_true, y_prob, groups, thr=0.5):
    y_true = np.asarray(y_true, dtype=np.int64)
    y_prob = np.asarray(y_prob, dtype=np.float64)
    groups = np.asarray(list(groups), dtype=object)
    out = {}
    for g in sorted(set(groups.tolist())):
        mask = (groups == g)
        if int(mask.sum()) == 0:
            continue
        out[g] = {"n": int(mask.sum()), "confusion": compute_confusion_counts(y_true[mask], y_prob[mask], thr=thr)}
    return out

# -------------------------
# Reproducibility: set all RNG seeds
# Inputs: seed int
# Output: deterministic settings for random, numpy, torch
# -------------------------
def set_all_seeds(seed: int):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# -------------------------
# Source checkpoint selection: pick most recent D5 trainval exp with all seeds
# Inputs: D5 trainval_runs folder
# Output: chosen_exp folder used to load best_heads.pt per seed
# -------------------------
D5_TRAINVAL_ROOT = Path(D5_OUT_ROOT) / "trainval_runs"
if not D5_TRAINVAL_ROOT.exists():
    raise FileNotFoundError(f"Missing D5 trainval_runs folder: {str(D5_TRAINVAL_ROOT)}")

exp_dirs = sorted([p for p in D5_TRAINVAL_ROOT.glob("exp_*") if p.is_dir()], key=lambda p: p.stat().st_mtime, reverse=True)
if not exp_dirs:
    raise FileNotFoundError(f"No exp_* folders found under: {str(D5_TRAINVAL_ROOT)}")

def _has_all_seeds(exp_path: Path, seeds: list):
    for s in seeds:
        p = exp_path / f"run_D5_seed{s}" / "best_heads.pt"
        if not p.exists():
            return False
    return True

chosen_exp = None
for ed in exp_dirs:
    if _has_all_seeds(ed, SEEDS):
        chosen_exp = ed
        break

if chosen_exp is None:
    sample = exp_dirs[0]
    raise FileNotFoundError(
        "Could not find a D5 trainval experiment with all 3 best_heads.pt files.\n"
        f"Expected: {str(D5_TRAINVAL_ROOT)}/exp_*/run_D5_seedXXXX/best_heads.pt\n"
        f"Most recent exp checked: {str(sample)}"
    )

print("\nUsing SOURCE (D5) Train+Val experiment folder:")
print(" ", str(chosen_exp))

# -------------------------
# Output folder: write only under D2
# Inputs: D2_OUT_ROOT
# Output: ZS_ROOT folder for per-seed runs + summary files
# -------------------------
ZS_ROOT = Path(D2_OUT_ROOT) / "Cross_Language_Zero_Shot_Runs"
ZS_ROOT.mkdir(parents=True, exist_ok=True)

# -------------------------
# DataLoader: D2 TEST only (plus a small warm-up read)
# Inputs: test_df and audio files
# Output: test_loader ready for inference
# -------------------------
test_ds = AudioManifestDataset(test_df)
test_loader = DataLoader(
    test_ds,
    batch_size=PER_DEVICE_BS,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=PIN_MEMORY,
    collate_fn=collate_fn,
)

print("\nWarm-up: loading up to 2 TEST batches...")
t0 = time.time()
nb = len(test_loader)
wb = min(2, nb)
if wb == 0:
    raise RuntimeError("TEST DataLoader has 0 batches. Check test_df length and PER_DEVICE_BS.")
it = iter(test_loader)
for i in range(wb):
    _ = next(it)
    print(f"  loaded TEST warmup batch {i+1}/{wb}")
print(f"Warm-up done in {time.time()-t0:.2f}s")

# -------------------------
# Checkpoint loader: load only the trained heads from best_heads.pt
# Inputs: model instance + best_heads.pt path
# Output: model with head weights restored for this seed
# -------------------------
def load_heads_into_model(model: Wav2Vec2TwoHeadClassifier, best_heads_path: Path):
    if not best_heads_path.exists():
        raise FileNotFoundError(f"Missing best_heads.pt: {str(best_heads_path)}")
    state = torch.load(str(best_heads_path), map_location="cpu")
    needed = ["pre_vowel", "pre_other", "head_vowel", "head_other"]
    missing = [k for k in needed if k not in state]
    if missing:
        raise KeyError(
            f"best_heads.pt missing keys {missing}. Found keys: {list(state.keys())}. "
            "This zero-shot code expects the trainval save format."
        )
    model.pre_vowel.load_state_dict(state["pre_vowel"], strict=True)
    model.pre_other.load_state_dict(state["pre_other"], strict=True)
    model.head_vowel.load_state_dict(state["head_vowel"], strict=True)
    model.head_other.load_state_dict(state["head_other"], strict=True)
    return model

# -------------------------
# Inference: run model over D2 test_loader and collect probabilities
# Inputs: loader + model
# Output: arrays (y_true, y_prob, sex_norm) for metrics and plots
# -------------------------
def run_inference(loader, model, desc):
    use_amp = bool(USE_AMP and DEVICE.type == "cuda")
    all_probs, all_true, all_sex = [], [], []

    pbar = tqdm(loader, desc=desc, dynamic_ncols=True)
    with torch.inference_mode():
        for batch in pbar:
            input_values = batch["input_values"].to(DEVICE, non_blocking=False)
            attention_mask = batch["attention_mask"].to(DEVICE, non_blocking=False)
            labels = batch["labels"].to(DEVICE, non_blocking=False)
            task_group = batch["task_group"]
            sex_norm = batch["sex_norm"]

            with torch.amp.autocast(device_type="cuda", enabled=use_amp):
                logits = model.forward_logits(input_values, attention_mask, task_group)

            probs = torch.softmax(logits.float(), dim=-1)[:, 1].detach().cpu().numpy().astype(np.float64)
            all_probs.extend(probs.tolist())
            all_true.extend(labels.detach().cpu().numpy().astype(np.int64).tolist())
            all_sex.extend(list(sex_norm))

    return np.asarray(all_true, dtype=np.int64), np.asarray(all_probs, dtype=np.float64), np.asarray(all_sex, dtype=object)

# -------------------------
# Per-seed run: load D5 heads, evaluate D2 at FIXED_THR, save metrics + PNGs
# Inputs: seed, chosen_exp, FIXED_THR, D2 test_loader
# Output: run_<TRANSFER_TAG>_seedXXXX folder with metrics.json + plots
# -------------------------
def run_seed(seed: int):
    set_all_seeds(seed)

    run_dir = ZS_ROOT / f"run_{TRANSFER_TAG}_seed{seed}"
    run_dir.mkdir(parents=True, exist_ok=True)

    best_heads_path = chosen_exp / f"run_D5_seed{seed}" / "best_heads.pt"

    print(f"\n[seed={seed}] Loading SOURCE heads from:")
    print(" ", str(best_heads_path))
    print(f"[seed={seed}] Evaluating TARGET D2 TEST @ fixed_thr={FIXED_THR:.6f}")

    model = Wav2Vec2TwoHeadClassifier(BACKBONE_CKPT, dropout_p=DROPOUT_P).to(DEVICE)
    model = load_heads_into_model(model, best_heads_path)
    model.eval()

    yt, pt, st = run_inference(test_loader, model, desc=f"[seed={seed}] D2 TEST (zero-shot)")

    test_auc = compute_auc(yt, pt)  # threshold-free
    thr_metrics = compute_threshold_metrics(yt, pt, thr=FIXED_THR)

    fnr_by_sex, delta_f_minus_m, delta_abs = compute_fnr_by_group_signed(yt, pt, st, thr=FIXED_THR)
    confusion_by_sex = compute_confusion_by_group(yt, pt, st, thr=FIXED_THR)

    # Plots: always save overall ROC + confusion matrix
    roc_png = run_dir / "roc_curve.png"
    cm_png  = run_dir / "confusion_matrix.png"
    save_roc_curve_png(yt, pt, str(roc_png), title_suffix=f"D2 Test (seed={seed})")
    save_confusion_png(yt, pt, str(cm_png), thr=FIXED_THR, title_suffix=f"D2 Test (seed={seed})")

    # Sex-specific confusion matrices (only if that sex exists in D2 TEST)
    cm_m_png = None
    cm_f_png = None
    mask_m = (st == "M")
    mask_f = (st == "F")

    if int(mask_m.sum()) > 0:
        cm_m_png = run_dir / "confusion_matrix_M.png"
        save_confusion_png(yt[mask_m], pt[mask_m], str(cm_m_png), thr=FIXED_THR, title_suffix=f"D2 Test SEX=M (seed={seed})")

    if int(mask_f.sum()) > 0:
        cm_f_png = run_dir / "confusion_matrix_F.png"
        save_confusion_png(yt[mask_f], pt[mask_f], str(cm_f_png), thr=FIXED_THR, title_suffix=f"D2 Test SEX=F (seed={seed})")

    # Structured metrics saved for later analysis/paper tables
    metrics = {
        "transfer": {
            "tag": TRANSFER_TAG,
            "source_dataset": "D5",
            "target_dataset": "D2",
            "threshold_policy": "Fixed threshold = mean VAL-opt threshold from D5 monolingual test summary (Youden J on D5 VAL)",
            "fixed_threshold_value": float(FIXED_THR),
            "d5_mono_summary_path": str(D5_MONO_SUMMARY),
        },

        "target": {
            "dataset": "D2",
            "seed": int(seed),
            "n_test": int(len(test_df)),
            "label_counts_test": test_df["label_num"].value_counts(dropna=False).to_dict(),
            "sex_counts_test_norm": test_df["sex_norm"].value_counts(dropna=False).to_dict(),
        },

        "test_auroc": float(test_auc),

        "threshold_metrics_test": thr_metrics,
        "test_threshold_used": float(FIXED_THR),

        "fairness_test": {
            "definition": "H3: ΔFNR = FNR(F) - FNR(M), where FNR(sex)=FN/(FN+TP) computed on PD-only true labels at fixed transfer threshold.",
            "fnr_by_sex_norm": fnr_by_sex,
            "delta_fnr_F_minus_M": float(delta_f_minus_m),
            "delta_fnr_abs": float(delta_abs),
            "note": "If n_PD for a sex is 0, its FNR is NaN and ΔFNR is NaN.",
            "sex_normalization_note": "D2 mapping: male->M, female->F; otherwise UNK.",
        },

        "confusion_by_sex_norm": confusion_by_sex,

        "artifacts": {
            "roc_curve_png": str(roc_png),
            "confusion_matrix_png": str(cm_png),
            "confusion_matrix_M_png": str(cm_m_png) if cm_m_png is not None else None,
            "confusion_matrix_F_png": str(cm_f_png) if cm_f_png is not None else None,
        },

        "paths": {
            "d2_out_root": str(D2_OUT_ROOT),
            "zero_shot_root": str(ZS_ROOT),
            "run_dir": str(run_dir),
            "source_trainval_experiment_used": str(chosen_exp),
            "source_best_heads_path": str(best_heads_path),
        },

        "backbone_ckpt": BACKBONE_CKPT,
        "dropout_p": float(DROPOUT_P),
        "timestamp": time.strftime("%Y%m%d_%H%M%S"),
    }

    metrics_path = run_dir / "metrics.json"
    with open(metrics_path, "w", encoding="utf-8") as f:
        json.dump(metrics, f, indent=2)

    print(f"[seed={seed}] DONE | test_AUROC={test_auc:.6f}")

    def _fmt_fnr(g):
        d = fnr_by_sex.get(g, None)
        if d is None:
            return "n/a"
        return f"fnr={d['fnr']:.6f} (n_PD={d['n_pos']}, fn={d['fn']}, tp={d['tp']})"

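    print(f"[seed={seed}] TEST metrics @ fixed_thr={FIXED_THR:.6f}:")
    print("  " + " | ".join(f"{k}={thr_metrics[k]:.6f}" for k in ["accuracy", "precision", "sensitivity", "specificity", "f1_score", "mcc"]))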
    print(f"[seed={seed}] FAIRNESS (H3) @ fixed_thr={FIXED_THR:.6f}:")
    print("  M:", _fmt_fnr("M"))
    print("  F:", _fmt_fnr("F"))
    if "UNK" in fnr_by_sex:
        print("  UNK:", _fmt_fnr("UNK"))
    print("  ΔFNR (F-M):", f"{delta_f_minus_m:.6f}" if not np.isnan(delta_f_minus_m) else "nan")
    print("  |ΔFNR|:", f"{delta_abs:.6f}" if not np.isnan(delta_abs) else "nan")

    print(f"[seed={seed}] WROTE:")
    print(" ", str(metrics_path))
    print(" ", str(roc_png))
    print(" ", str(cm_png))
    if cm_m_png is not None:
        print(" ", str(cm_m_png))
    if cm_f_png is not None:
        print(" ", str(cm_f_png))

    return {
        "seed": int(seed),
        "test_auroc": float(test_auc),
        "thr_metrics_test": thr_metrics,
        "fnr_by_sex_test": fnr_by_sex,
        "delta_f_minus_m_test": float(delta_f_minus_m),
        "delta_abs_test": float(delta_abs),
        "run_dir": str(run_dir),
    }

# -------------------------
# Aggregate across 3 seeds: AUROC mean ± 95% CI, others mean ± SD
# Inputs: per-seed outputs from run_seed()
# Output: summary_zero_shot3.json + history_index.jsonl under ZS_ROOT
# -------------------------
seed_results = []
for seed in SEEDS:
    seed_results.append(run_seed(seed))

aurocs = [r["test_auroc"] for r in seed_results]
t_crit = 4.302652729911275  # two-sided 95% t critical value for df=2; valid only when exactly 3 seeds are used
n = len(aurocs)
mean_auc = float(np.mean(aurocs))
std_auc = float(np.std(aurocs, ddof=1)) if n > 1 else 0.0
half_width = float(t_crit * (std_auc / math.sqrt(n))) if n > 1 else 0.0
ci95 = [float(mean_auc - half_width), float(mean_auc + half_width)]

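def _mean_sd(vals):
    # Mean and sample SD (ddof=1) over non-NaN values.
    # SD falls back to 0.0 when fewer than two valid values are available.
    vals = np.asarray(vals, dtype=np.float64)
    n_valid = int(np.sum(~np.isnan(vals)))
    mean = float(np.nanmean(vals))
    sd = float(np.nanstd(vals, ddof=1)) if n_valid > 1 else 0.0
    return mean, sd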

thr_list = [r["thr_metrics_test"] for r in seed_results]
keys = ["accuracy","precision","recall","f1_score","sensitivity","specificity","mcc","p_value_fisher_two_sided"]

agg = {}
for k in keys:
    v = [float(tm.get(k, float("nan"))) for tm in thr_list]
    mu, sd = _mean_sd(v)
    agg[k] = {
        "mean": mu,
        "sd": sd,
        "values_by_seed": {str(r["seed"]): float(tm.get(k, float("nan"))) for r, tm in zip(seed_results, thr_list)}
    }

cm_by_seed = {str(r["seed"]): r["thr_metrics_test"]["confusion_matrix"] for r in seed_results}

# Fairness aggregation (H3): mean ± SD across seeds
fnr_by_seed = {str(r["seed"]): r["fnr_by_sex_test"] for r in seed_results}
delta_signed_by_seed = {str(r["seed"]): float(r["delta_f_minus_m_test"]) for r in seed_results}
delta_abs_by_seed = {str(r["seed"]): float(r["delta_abs_test"]) for r in seed_results}

fnr_m_vals, fnr_f_vals, n_pd_m_vals, n_pd_f_vals = [], [], [], []
d_signed_vals, d_abs_vals = [], []

for r in seed_results:
    s = str(r["seed"])
    d = fnr_by_seed.get(s, {})
    fnr_m_vals.append(float(d.get("M", {}).get("fnr", float("nan"))))
    fnr_f_vals.append(float(d.get("F", {}).get("fnr", float("nan"))))
    n_pd_m_vals.append(float(d.get("M", {}).get("n_pos", float("nan"))))
    n_pd_f_vals.append(float(d.get("F", {}).get("n_pos", float("nan"))))
    d_signed_vals.append(float(delta_signed_by_seed.get(s, float("nan"))))
    d_abs_vals.append(float(delta_abs_by_seed.get(s, float("nan"))))

fnr_m_mean, fnr_m_sd = _mean_sd(fnr_m_vals)
fnr_f_mean, fnr_f_sd = _mean_sd(fnr_f_vals)
d_signed_mean, d_signed_sd = _mean_sd(d_signed_vals)
d_abs_mean, d_abs_sd = _mean_sd(d_abs_vals)

print("\nZero-shot TEST AUROC by seed:")
for r in seed_results:
    print(f"  seed {r['seed']}: {r['test_auroc']:.6f}")
print(f"\nMean Zero-shot TEST AUROC: {mean_auc:.6f}")
print(f"95% CI (t, n=3): [{ci95[0]:.6f}, {ci95[1]:.6f}]")

print(f"\nFixed threshold used for ALL seeds (from D5 mono mean VAL-opt): {FIXED_THR:.6f}")

print("\nThreshold metrics on D2 TEST @ fixed threshold (mean ± SD across seeds):")
for k in ["accuracy","precision","sensitivity","specificity","f1_score","mcc"]:
    mu = agg[k]["mean"]
    sd = agg[k]["sd"]
    print(f"  {k}: {mu:.6f} ± {sd:.6f}")
print("  fisher_p_value_two_sided:", f"{agg['p_value_fisher_two_sided']['mean']:.6g} ± {agg['p_value_fisher_two_sided']['sd']:.6g}")

print("\nFAIRNESS (H3) on D2 TEST @ fixed threshold (mean ± SD across seeds):")
print(f"  FNR_M: {fnr_m_mean:.6f} ± {fnr_m_sd:.6f}")
print(f"  FNR_F: {fnr_f_mean:.6f} ± {fnr_f_sd:.6f}")
print(f"  ΔFNR (F-M): {d_signed_mean:.6f} ± {d_signed_sd:.6f}")
print(f"  |ΔFNR|: {d_abs_mean:.6f} ± {d_abs_sd:.6f}")

summary = {
    "transfer": {
        "tag": TRANSFER_TAG,
        "source_dataset": "D5",
        "target_dataset": "D2",
        "fixed_threshold_value": float(FIXED_THR),
        "fixed_threshold_source": {
            "path": str(D5_MONO_SUMMARY),
            "key": "val_opt_threshold_mean_sd.mean",
            "note": "Mean across seeds of D5 VAL-opt thresholds (Youden J) as recorded by D5 monolingual test run.",
        },
        "source_trainval_experiment_used": str(chosen_exp),
    },

    "target": {
        "dx_out_root": str(D2_OUT_ROOT),
        "manifest_all": str(D2_MANIFEST_ALL),
        "n_test": int(len(test_df)),
        "label_counts_test": test_df["label_num"].value_counts(dropna=False).to_dict(),
        "sex_counts_test_norm": test_df["sex_norm"].value_counts(dropna=False).to_dict(),
    },

    "seeds": SEEDS,

    "test_aurocs_by_seed": {str(r["seed"]): float(r["test_auroc"]) for r in seed_results},
    "mean_test_auroc": float(mean_auc),
    "t_crit_df2_95": float(t_crit),
    "half_width_95_ci": float(half_width),
    "ci95_test_auroc": ci95,

    "threshold_metrics_test_mean_sd": agg,
    "confusion_matrix_by_seed": cm_by_seed,

    "fairness_test": {
        "definition": "H3: ΔFNR = FNR(F) - FNR(M), where FNR(sex)=FN/(FN+TP) computed on PD-only true labels at the fixed transfer threshold.",
        "fnr_by_sex_norm_by_seed": fnr_by_seed,
        "delta_fnr_F_minus_M_by_seed": delta_signed_by_seed,
        "delta_fnr_abs_by_seed": delta_abs_by_seed,
        "fnr_M_mean_sd": {"mean": float(fnr_m_mean), "sd": float(fnr_m_sd), "values_by_seed": {str(s): float(v) for s, v in zip(SEEDS, fnr_m_vals)}},
        "fnr_F_mean_sd": {"mean": float(fnr_f_mean), "sd": float(fnr_f_sd), "values_by_seed": {str(s): float(v) for s, v in zip(SEEDS, fnr_f_vals)}},
        "delta_fnr_F_minus_M_mean_sd": {"mean": float(d_signed_mean), "sd": float(d_signed_sd), "values_by_seed": {str(s): float(v) for s, v in zip(SEEDS, d_signed_vals)}},
        "delta_fnr_abs_mean_sd": {"mean": float(d_abs_mean), "sd": float(d_abs_sd), "values_by_seed": {str(s): float(v) for s, v in zip(SEEDS, d_abs_vals)}},
        "denominators_PD_by_seed": {str(s): {"n_PD_M": float(n_pd_m_vals[i]), "n_PD_F": float(n_pd_f_vals[i])} for i, s in enumerate(SEEDS)},
        "sex_normalization_note": "D2 mapping: male->M, female->F; otherwise UNK. ΔFNR computed only when both M and F have defined FNR.",
    },

    "run_dirs": [r["run_dir"] for r in seed_results],
    "timestamp": time.strftime("%Y%m%d_%H%M%S"),
}

# Save the aggregated summary and append the same record to the history index
summary_path = ZS_ROOT / "summary_zero_shot3.json"
with open(summary_path, "w", encoding="utf-8") as f:
    json.dump(summary, f, indent=2)

history_path = ZS_ROOT / "history_index.jsonl"
with open(history_path, "a", encoding="utf-8") as f:
    f.write(json.dumps(summary) + "\n")

print("\nWROTE summary:", str(summary_path))
print("APPENDED history index:", str(history_path))
print("Open this folder to access artifacts:", str(ZS_ROOT))

# -------------------------
# Stop runtime (Colab)
# Inputs: none
# Output: attempts to release the GPU instance
# -------------------------
print("\nAll done. Unassigning the runtime to stop the L4 instance...")
try:
    from google.colab import runtime  # type: ignore
    print("Calling runtime.unassign() now.")
    runtime.unassign()
except Exception as e:
    print("Could not unassign runtime automatically. You can stop the runtime manually in Colab.")
    print("Reason:", repr(e))

The following cell runs a cross-language zero-shot test from English Ah Sound (D6) to Slovak (D2). It evaluates whether classifier heads trained on the source dataset D6 can detect Parkinson’s disease in the target dataset D2 without any retraining on D2. The cell reads only the D2 test split from `<D2_OUT_ROOT>/manifests/manifest_all.csv` and evaluates it with the most recently modified D6 train-and-validation experiment that contains saved head checkpoints for all three seeds: 1337, 2024, and 7777.

A single fixed decision threshold is used for all seeds and all D2 test clips. It is read from D6’s saved `monolingual_test_runs/summary_test.json`: the preferred value is `val_opt_threshold_mean_sd.mean`, and the mean of `val_opt_threshold_by_seed` serves as a fallback when the preferred value is unavailable. Because no threshold is tuned on D2, the evaluation remains strictly zero-shot and consistent across seeds.
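
The lookup order described above can be sketched as a small helper. This is a minimal illustration rather than the cell's exact code: `pick_fixed_threshold` and the example threshold values are hypothetical, and only the two JSON key names come from the notebook.

```python
import math

def pick_fixed_threshold(summary: dict) -> float:
    """Return the preferred mean threshold, else the mean of finite per-seed values."""
    preferred = summary.get("val_opt_threshold_mean_sd")
    if isinstance(preferred, dict):
        value = preferred.get("mean")
        if isinstance(value, (int, float)) and math.isfinite(value):
            return float(value)

    by_seed = summary.get("val_opt_threshold_by_seed")
    if isinstance(by_seed, dict):
        finite = [float(v) for v in by_seed.values() if math.isfinite(float(v))]
        if finite:
            return sum(finite) / len(finite)

    raise ValueError("No finite threshold found in the summary.")

# Hypothetical summaries (values are made up for illustration)
thr_preferred = pick_fixed_threshold({"val_opt_threshold_mean_sd": {"mean": 0.42}})
thr_fallback = pick_fixed_threshold(
    {"val_opt_threshold_by_seed": {"1337": 0.25, "2024": 0.5, "7777": 0.75}}
)
```

Raising when neither key yields a finite number mirrors the fail-fast behavior of the cell, which refuses to run inference without a usable threshold.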

Before inference begins, the cell runs several safety and data-preparation checks. It confirms that no local files are shadowing PyTorch or Transformers, mounts Google Drive if it is not already mounted, prints the active paths and compute device, loads the D2 manifest, verifies that all required columns are present, filters strictly to the test split, and stops early if any referenced audio files are missing. Each clip is assigned to a task group: clips with `task == "vowl"` are treated as vowel, and all others are treated as other. Sex values in D2 (male, female, and common variants) are normalized to `sex_norm` values of M, F, or UNK, so sex-based reporting still works if unexpected values appear.
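
The two row-level mappings described above can be illustrated on a few made-up rows. This is a simplified sketch: the function names, the `"read"` task code, and the sample rows are hypothetical, and the set of accepted sex variants is shorter than the one the cell uses.

```python
def task_group(task) -> str:
    # Per the text above, "vowl" marks vowel clips; everything else is "other"
    return "vowel" if str(task) == "vowl" else "other"

def normalize_sex(value) -> str:
    # None and NaN (NaN is the only value not equal to itself) count as unknown
    if value is None or value != value:
        return "UNK"
    s = str(value).strip().lower()
    if s in {"m", "male"}:
        return "M"
    if s in {"f", "female"}:
        return "F"
    return "UNK"

# Hypothetical (task, sex) pairs as they might appear in a manifest
rows = [("vowl", "Male"), ("read", " female "), ("vowl", None), ("read", "n/a")]
mapped = [(task_group(t), normalize_sex(s)) for t, s in rows]
```

Mapping unrecognized values to UNK instead of raising keeps those clips in the overall metrics while excluding them from the M versus F comparison.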

For each seed, the cell locates the D6 head checkpoint (`best_heads.pt`) in the selected D6 experiment folder. It rebuilds the same two-head classifier used elsewhere, with a frozen Wav2Vec2 backbone, one head for vowel clips, and one head for other clips, and loads the saved head weights. Inference is then run on the full D2 test set to produce a Parkinson’s probability for each clip.

From these scores, the cell computes the overall AUROC, which is threshold-free, and fixed-threshold metrics at the D6-derived threshold: the confusion matrix, accuracy, precision, sensitivity (recall), specificity, F1 score, MCC, and the two-sided Fisher exact test p-value. Fairness is measured as ΔFNR, defined as FNR(F) minus FNR(M) and computed only on true Parkinson’s clips. Confusion counts by sex are also recorded.

For each seed, the cell saves a ROC curve, an overall confusion matrix, sex-specific confusion matrices when both M and F are present, and a detailed `metrics.json` file under
`<D2_OUT_ROOT>/Cross_Language_Zero_Shot_Runs/run_ENUS_AH_to_SK_seed<seed>/`.
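
To make the ΔFNR definition concrete, the sketch below works through it on six made-up clips. The helper names and all data values are hypothetical; the only parts taken from the notebook are the formula FNR = FN / (FN + TP) on true Parkinson’s clips and the sign convention FNR(F) minus FNR(M).

```python
def fnr(y_true, y_pred):
    """False-negative rate FN / (FN + TP) over items whose true label is 1."""
    fn = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 0)
    tp = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 1)
    return fn / (fn + tp) if (fn + tp) > 0 else float("nan")

def delta_fnr(y_true, y_prob, sex, thr):
    # Apply the fixed threshold once, then score each sex group separately
    y_pred = [1 if p >= thr else 0 for p in y_prob]
    by_sex = {}
    for g in ("M", "F"):
        idx = [i for i, s in enumerate(sex) if s == g]
        by_sex[g] = fnr([y_true[i] for i in idx], [y_pred[i] for i in idx])
    return by_sex, by_sex["F"] - by_sex["M"]

# Toy data: label 1 = PD, 0 = healthy (values are made up)
y_true = [1, 1, 1, 1, 0, 0]
y_prob = [0.9, 0.2, 0.8, 0.7, 0.1, 0.6]
sex = ["M", "M", "F", "F", "M", "F"]

by_sex, gap = delta_fnr(y_true, y_prob, sex, thr=0.5)
# M: one of two PD clips is missed (FNR 0.5); F: none of two is missed (FNR 0.0)
```

A negative gap, as in this toy case, means Parkinson’s clips from female speakers are missed less often than those from male speakers at the chosen threshold.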

After all three seeds complete, results are combined across seeds. The cell reports the mean AUROC with a 95% confidence interval based on a t distribution with n = 3 (2 degrees of freedom), along with the mean and standard deviation of the fixed-threshold metrics and fairness values. A combined summary is written to `summary_zero_shot4.json`, and the same record is appended to `history_index.jsonl`, both under `<D2_OUT_ROOT>/Cross_Language_Zero_Shot_Runs/`. Finally, the Colab runtime is unassigned to shut down the GPU.
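
The confidence interval described above is the usual small-sample t interval, mean ± t × SD / √n. The sketch below shows the arithmetic under the assumption that SciPy is available; the three AUROC values are hypothetical and are not results from this study.

```python
import math
import statistics
from scipy.stats import t

aurocs = [0.70, 0.72, 0.74]  # hypothetical per-seed AUROCs, for illustration only

n = len(aurocs)
mean_auc = statistics.mean(aurocs)
sd_auc = statistics.stdev(aurocs)      # sample SD (n - 1 in the denominator)
t_crit = t.ppf(0.975, df=n - 1)        # two-sided 95% critical value, about 4.3027 for df = 2
half_width = t_crit * sd_auc / math.sqrt(n)
ci95 = (mean_auc - half_width, mean_auc + half_width)

print(f"mean AUROC = {mean_auc:.4f}, 95% CI = [{ci95[0]:.4f}, {ci95[1]:.4f}]")
```

With only three seeds the critical value is large, so the interval is wide even when the per-seed AUROCs are close together; it describes seed-to-seed variation, not sampling uncertainty in the D2 test set.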

In [None]:
# =========================
# ENUS_AH → SK Zero-shot (D6 → D2): Test-only transfer with fixed threshold
# - Loads D6-trained heads (most recent train+val experiment)
# - Runs inference on D2 TEST only (from manifest_all.csv)
# - Uses one fixed threshold read from D6 monolingual summary_test.json
# - Saves per-seed metrics + plots, plus an aggregated summary under D2
# =========================

import os, json, math, random, time, warnings
from pathlib import Path

import numpy as np
import pandas as pd
import soundfile as sf

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

from transformers import Wav2Vec2Model

from sklearn.metrics import roc_auc_score, roc_curve, confusion_matrix, matthews_corrcoef
from scipy.stats import fisher_exact
import matplotlib.pyplot as plt

from tqdm.auto import tqdm

# -------------------------
# Safety checks: avoid importing local files named like core libraries
# Inputs: /content directory
# Output: raises an error early if a naming conflict exists
# -------------------------
if os.path.exists("/content/torch.py") or os.path.exists("/content/torch/__init__.py"):
    raise RuntimeError("Found local /content/torch.py or /content/torch/ that shadows PyTorch. Rename/remove it and restart runtime.")
if os.path.exists("/content/transformers.py") or os.path.exists("/content/transformers/__init__.py"):
    raise RuntimeError("Found local /content/transformers.py or /content/transformers/ that shadows Hugging Face Transformers. Rename/remove it and restart runtime.")

# -------------------------
# Drive mount (Colab)
# Inputs: Google Drive
# Output: /content/drive mounted so manifests, checkpoints, and outputs can be read/written
# -------------------------
if not os.path.isdir("/content/drive/MyDrive"):
    from google.colab import drive  # type: ignore
    drive.mount("/content/drive")

# -------------------------
# Paths: target (D2) + source (D6)
# Inputs: optional globals D2_OUT_ROOT, D6_OUT_ROOT (if set earlier), else fallbacks
# Output: D2_MANIFEST_ALL and D6_MONO_SUMMARY paths
# -------------------------
D2_OUT_ROOT_FALLBACK = "/content/drive/MyDrive/AI_PD_Project/Datasets/D2-Slovak (EWA-DB)/EWA-DB/preprocessed_v1"
D6_OUT_ROOT_FALLBACK = "/content/drive/MyDrive/AI_PD_Project/Datasets/D6-Ah Sound (Figshare)/preprocessed_v1"

D2_OUT_ROOT = globals().get("D2_OUT_ROOT", D2_OUT_ROOT_FALLBACK)
D6_OUT_ROOT = globals().get("D6_OUT_ROOT", D6_OUT_ROOT_FALLBACK)

D2_MANIFEST_ALL = f"{D2_OUT_ROOT}/manifests/manifest_all.csv"

# D6 monolingual summary path (contains the fixed threshold source)
D6_MONO_SUMMARY = Path(D6_OUT_ROOT) / "monolingual_test_runs" / "summary_test.json"

# -------------------------
# Run settings (kept consistent with other test cells)
# Inputs: constants below
# Output: printed configuration + runtime behavior (batching, AMP, device)
# -------------------------
TRANSFER_TAG    = "ENUS_AH_to_SK"     # D6 -> D2
SEEDS           = [1337, 2024, 7777]
BACKBONE_CKPT   = "facebook/wav2vec2-base"
SR_EXPECTED     = 16000
TINY_THRESH     = 1e-4

EFFECTIVE_BS    = 64
PER_DEVICE_BS   = 16
GRAD_ACCUM      = max(1, EFFECTIVE_BS // PER_DEVICE_BS)  # printed only

DROPOUT_P       = 0.2

NUM_WORKERS     = 0
PIN_MEMORY      = False

USE_AMP         = True

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
try:
    torch.set_float32_matmul_precision("high")
except Exception:
    pass

warnings.filterwarnings("ignore", message="None of the inputs have requires_grad=True. Gradients will be None")
warnings.filterwarnings("ignore", message="Passing `gradient_checkpointing` to a config initialization is deprecated")

print("D2_OUT_ROOT (target):", D2_OUT_ROOT)
print("D2_MANIFEST_ALL:", D2_MANIFEST_ALL)
print("D6_OUT_ROOT (source):", D6_OUT_ROOT)
print("D6_MONO_SUMMARY:", str(D6_MONO_SUMMARY))
print("DEVICE:", DEVICE)
if DEVICE.type == "cuda":
    print("GPU:", torch.cuda.get_device_name(0))
print("PER_DEVICE_BS:", PER_DEVICE_BS, "| EFFECTIVE_BS:", PER_DEVICE_BS * GRAD_ACCUM)
print("NUM_WORKERS:", NUM_WORKERS, "| PIN_MEMORY:", PIN_MEMORY)
print("USE_AMP:", bool(USE_AMP and DEVICE.type == "cuda"))

# -------------------------
# Fixed threshold: read from D6 monolingual summary_test.json
# Inputs: D6_MONO_SUMMARY (JSON)
# Output: FIXED_THR (float) used for all three seeds on D2
# -------------------------
if not D6_MONO_SUMMARY.exists():
    raise FileNotFoundError(
        "Missing D6 monolingual summary_test.json.\n"
        f"Expected: {str(D6_MONO_SUMMARY)}\n"
        "Run the D6 monolingual test cell first (the one that writes summary_test.json)."
    )

with open(D6_MONO_SUMMARY, "r", encoding="utf-8") as f:
    d6_sum = json.load(f)

# D6 schema: summary['val_opt_threshold_mean_sd']['mean'] (preferred)
FIXED_THR = float("nan")
mean_sd = d6_sum.get("val_opt_threshold_mean_sd", None)
if isinstance(mean_sd, dict) and mean_sd.get("mean") is not None:
    FIXED_THR = float(mean_sd["mean"])

# Fallback: mean of per-seed values when the preferred value is missing or non-finite
if not np.isfinite(FIXED_THR):
    by_seed = d6_sum.get("val_opt_threshold_by_seed", None)
    if isinstance(by_seed, dict) and len(by_seed) > 0:
        FIXED_THR = float(np.nanmean([float(v) for v in by_seed.values()]))

if not np.isfinite(FIXED_THR):
    raise RuntimeError(
        "Could not read a finite fixed threshold from D6 monolingual summary_test.json.\n"
        "Expected keys: val_opt_threshold_mean_sd.mean (preferred) or val_opt_threshold_by_seed."
    )

print(f"\nFixed transfer threshold (from D6 monolingual mean VAL-opt): {FIXED_THR:.6f}")

# -------------------------
# Target manifest: load D2 and keep TEST only
# Inputs: D2 manifest_all.csv
# Output: test_df with required columns + basic counts printed
# -------------------------
if not os.path.exists(D2_MANIFEST_ALL):
    raise FileNotFoundError(f"Missing D2 manifest_all.csv: {D2_MANIFEST_ALL}")

m_all = pd.read_csv(D2_MANIFEST_ALL)

# Keep sex for fairness and sex-specific confusion matrices
req_cols = {"split", "clip_path", "label_num", "task", "sex"}
missing = [c for c in sorted(req_cols) if c not in m_all.columns]
if missing:
    raise ValueError(f"D2 manifest missing required columns: {missing}. Found: {list(m_all.columns)}")

# Infer dataset_id and ensure this is really D2
if "dataset" in m_all.columns and m_all["dataset"].notna().any():
    dataset_id = str(m_all["dataset"].astype(str).value_counts(dropna=True).idxmax())
    m_all = m_all[m_all["dataset"].astype(str) == dataset_id].copy()
else:
    dataset_id = "DX"

if dataset_id != "D2":
    raise RuntimeError(
        f"Expected dataset_id=='D2' but got {dataset_id!r}. "
        "This usually means D2_OUT_ROOT is wrong or inherited from a previous cell.\n"
        f"D2_OUT_ROOT={D2_OUT_ROOT}"
    )

# Keep a stable column set (missing ones are filled with NaN)
keep_cols = ["clip_path", "label_num", "task", "speaker_id", "sex", "age", "duration_sec", "split"]
for c in keep_cols:
    if c not in m_all.columns:
        m_all[c] = np.nan
m_all = m_all[keep_cols].copy()

test_df = m_all[m_all["split"].isin(["test"])].copy().reset_index(drop=True)

print(f"\nTarget dataset inferred: {dataset_id}")
print(f"TEST rows: {len(test_df)}")
if len(test_df) == 0:
    raise RuntimeError("After filtering to split=='test', D2 manifest has 0 rows.")

print("TEST label counts:", test_df["label_num"].value_counts(dropna=False).to_dict())
print("TEST sex counts (raw):", test_df["sex"].value_counts(dropna=False).to_dict())

# -------------------------
# Path check: ensure target audio files exist
# Inputs: test_df.clip_path
# Output: raises early if any clip files are missing (shows a few examples)
# -------------------------
def _fail_fast_missing_paths(df: pd.DataFrame, name: str):
    missing_paths = []
    for p in tqdm(df["clip_path"].astype(str).tolist(), desc=f"Check {name} clip_path exists", dynamic_ncols=True):
        if not os.path.exists(p):
            missing_paths.append(p)
            if len(missing_paths) >= 10:
                break
    if missing_paths:
        raise FileNotFoundError(f"{name}: missing clip_path(s). Examples: {missing_paths}")

_fail_fast_missing_paths(test_df, "TEST")

# -------------------------
# Task grouping: map each row to vowel vs other
# Inputs: test_df.task
# Output: test_df.task_group used to pick the correct head during inference
# -------------------------
def _task_group(task_val) -> str:
    return "vowel" if str(task_val) == "vowl" else "other"

test_df["task_group"] = test_df["task"].apply(_task_group)

# -------------------------
# Sex normalization for fairness/charts (D2 uses male/female)
# Inputs: test_df.sex
# Output: test_df.sex_norm in {M,F,UNK}
# -------------------------
def normalize_sex_d2(val) -> str:
    """
    Returns 'M', 'F', or 'UNK'
    D2 encoding (per manifest):
      male / female
    Also handles common variants.
    """
    if pd.isna(val):
        return "UNK"

    s = str(val).strip().lower()
    if s in {"m", "male", "man", "masc", "masculine"}:
        return "M"
    if s in {"f", "female", "woman", "fem", "feminine"}:
        return "F"
    return "UNK"

test_df["sex_norm"] = test_df["sex"].apply(normalize_sex_d2)
print("TEST sex counts (normalized):", test_df["sex_norm"].value_counts(dropna=False).to_dict())
if (test_df["sex_norm"] == "UNK").any():
    print("NOTE: Some D2 'sex' values could not be normalized to M/F and were counted as 'UNK' for fairness and sex charts.")

# -------------------------
# Dataset + collator: read audio and build attention masks
# Inputs: test_df rows + audio files
# Output: DataLoader batches with padded input_values, attention_mask, labels, task_group, sex_norm
# -------------------------
class AudioManifestDataset(Dataset):
    def __init__(self, df: pd.DataFrame):
        self.df = df.reset_index(drop=True)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx: int):
        row = self.df.iloc[idx]
        clip_path = str(row["clip_path"])
        label = int(row["label_num"])
        task_group = str(row["task_group"])
        sex_norm = str(row["sex_norm"])

        if not os.path.exists(clip_path):
            raise FileNotFoundError(f"Missing clip_path: {clip_path}")

        y, sr = sf.read(clip_path, always_2d=False)
        if isinstance(y, np.ndarray) and y.ndim > 1:
            y = np.mean(y, axis=1)
        y = np.asarray(y, dtype=np.float32)

        if int(sr) != SR_EXPECTED:
            raise ValueError(f"Sample rate mismatch (expected {SR_EXPECTED}): {clip_path} has sr={sr}")

        # attention_mask mainly matters for vowel clips: mask out the near-silent padded tail
        attn = np.ones((len(y),), dtype=np.int64)
        if task_group == "vowel":
            # Index of the last sample whose magnitude exceeds TINY_THRESH (compared in float64)
            nonsilent = np.flatnonzero(np.abs(y.astype(np.float64)) > TINY_THRESH)
            if nonsilent.size > 0:
                attn[int(nonsilent[-1]) + 1:] = 0

        return {
            "input_values": torch.from_numpy(y),
            "attention_mask": torch.from_numpy(attn),
            "labels": torch.tensor(label, dtype=torch.long),
            "task_group": task_group,
            "sex_norm": sex_norm,
        }

def collate_fn(batch):
    max_len = int(max(b["input_values"].numel() for b in batch))
    input_vals, attn_masks, labels, task_groups, sex_norms = [], [], [], [], []
    for b in batch:
        x = b["input_values"]
        a = b["attention_mask"]
        pad = max_len - x.numel()
        if pad > 0:
            x = torch.cat([x, torch.zeros(pad, dtype=x.dtype)], dim=0)
            a = torch.cat([a, torch.zeros(pad, dtype=a.dtype)], dim=0)
        input_vals.append(x)
        attn_masks.append(a)
        labels.append(b["labels"])
        task_groups.append(b["task_group"])
        sex_norms.append(b["sex_norm"])
    return {
        "input_values": torch.stack(input_vals, dim=0),
        "attention_mask": torch.stack(attn_masks, dim=0),
        "labels": torch.stack(labels, dim=0),
        "task_group": task_groups,
        "sex_norm": sex_norms,
    }

# -------------------------
# Model: frozen wav2vec2 backbone + two heads (vowel vs other)
# Inputs: backbone checkpoint name + dropout_p
# Output: logits for PD vs healthy for each clip, using the correct head per task_group
# -------------------------
class Wav2Vec2TwoHeadClassifier(nn.Module):
    def __init__(self, ckpt: str, dropout_p: float):
        super().__init__()
        self.backbone = Wav2Vec2Model.from_pretrained(ckpt, use_safetensors=True)
        for p in self.backbone.parameters():
            p.requires_grad = False

        H = int(self.backbone.config.hidden_size)
        self.pre_vowel = nn.Sequential(nn.LayerNorm(H), nn.Dropout(float(dropout_p)))
        self.pre_other = nn.Sequential(nn.LayerNorm(H), nn.Dropout(float(dropout_p)))
        self.head_vowel = nn.Linear(H, 2)
        self.head_other = nn.Linear(H, 2)

    def masked_mean_pool(self, last_hidden, attn_mask_samples):
        feat_mask = self.backbone._get_feature_vector_attention_mask(last_hidden.shape[1], attn_mask_samples)
        feat_mask = feat_mask.to(last_hidden.device).unsqueeze(-1).type_as(last_hidden)
        masked = last_hidden * feat_mask
        denom = feat_mask.sum(dim=1).clamp(min=1.0)
        return masked.sum(dim=1) / denom

    def _heads_fp32(self, x_fp_any: torch.Tensor, head: nn.Module) -> torch.Tensor:
        x = x_fp_any.float()
        if x.is_cuda:
            with torch.amp.autocast(device_type="cuda", enabled=False):
                return head(x)
        return head(x)

    def forward_logits(self, input_values, attention_mask, task_group):
        with torch.no_grad():
            out = self.backbone(input_values=input_values, attention_mask=attention_mask)
            last_hidden = out.last_hidden_state

        pooled = self.masked_mean_pool(last_hidden, attention_mask)  # [B,H]

        z_v = self.pre_vowel(pooled.float())
        z_o = self.pre_other(pooled.float())

        logits_v = self._heads_fp32(z_v, self.head_vowel)
        logits_o = self._heads_fp32(z_o, self.head_other)

        # Select the correct head per row based on task_group
        is_vowel = torch.tensor([tg == "vowel" for tg in task_group], device=pooled.device, dtype=torch.bool)
        logits = logits_o.clone()
        if is_vowel.any():
            logits[is_vowel] = logits_v[is_vowel]
        return logits

# -------------------------
# Metrics + plot helpers
# Inputs: y_true and y_prob (PD probability), plus a threshold for thresholded metrics
# Output: numeric metrics dict + PNG charts saved to disk
# -------------------------
def compute_auc(y_true, y_prob):
    y_true = np.asarray(y_true, dtype=np.int64)
    y_prob = np.asarray(y_prob, dtype=np.float64)
    if len(np.unique(y_true)) < 2:
        return float("nan")
    return float(roc_auc_score(y_true, y_prob))

def compute_threshold_metrics(y_true, y_prob, thr=0.5):
    y_true = np.asarray(y_true, dtype=np.int64)
    y_prob = np.asarray(y_prob, dtype=np.float64)
    y_pred = (y_prob >= float(thr)).astype(np.int64)

    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
    TN, FP, FN, TP = int(cm[0, 0]), int(cm[0, 1]), int(cm[1, 0]), int(cm[1, 1])

    eps = 1e-12
    acc = (TP + TN) / max(1, (TP + TN + FP + FN))
    prec = TP / (TP + FP + eps)
    rec = TP / (TP + FN + eps)     # sensitivity
    f1 = 2 * prec * rec / (prec + rec + eps)
    spec = TN / (TN + FP + eps)

    try:
        mcc = float(matthews_corrcoef(y_true, y_pred)) if len(np.unique(y_pred)) > 1 else float("nan")
    except Exception:
        mcc = float("nan")

    try:
        _, pval = fisher_exact([[TN, FP], [FN, TP]], alternative="two-sided")
        pval = float(pval)
    except Exception:
        pval = float("nan")

    return {
        "threshold": float(thr),
        "confusion_matrix": {"TN": TN, "FP": FP, "FN": FN, "TP": TP},
        "accuracy": float(acc),
        "precision": float(prec),
        "recall": float(rec),
        "f1_score": float(f1),
        "sensitivity": float(rec),
        "specificity": float(spec),
        "mcc": float(mcc),
        "p_value_fisher_two_sided": float(pval),
    }

def save_roc_curve_png(y_true, y_prob, out_png, title_suffix="Test"):
    fpr, tpr, _ = roc_curve(y_true, y_prob)
    plt.figure()
    plt.plot(fpr, tpr)
    plt.plot([0, 1], [0, 1], linestyle="--")
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    plt.title(f"ROC Curve ({title_suffix})")
    plt.tight_layout()
    plt.savefig(out_png, dpi=150)
    plt.close()

def save_confusion_png(y_true, y_prob, out_png, thr=0.5, title_suffix="Test"):
    y_true = np.asarray(y_true, dtype=np.int64)
    y_prob = np.asarray(y_prob, dtype=np.float64)
    y_pred = (y_prob >= float(thr)).astype(int)
    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])

    plt.figure()
    plt.imshow(cm)
    plt.xticks([0, 1], ["Healthy(0)", "PD(1)"])
    plt.yticks([0, 1], ["Healthy(0)", "PD(1)"])
    plt.xlabel("Predicted")
    plt.ylabel("True")
    plt.title(f"Confusion Matrix ({title_suffix}, thr={thr:.2f})")
    for (i, j), v in np.ndenumerate(cm):
        plt.text(j, i, str(v), ha="center", va="center")
    plt.tight_layout()
    plt.savefig(out_png, dpi=150)
    plt.close()

# -------------------------
# Fairness: FNR per sex and ΔFNR = FNR(F) - FNR(M)
# Inputs: y_true, y_prob, sex_norm array, threshold
# Output: per-group FNR details + signed and absolute gap
# -------------------------
def compute_fnr_by_group_signed(y_true, y_prob, groups, thr=0.5):
    y_true = np.asarray(y_true, dtype=np.int64)
    y_prob = np.asarray(y_prob, dtype=np.float64)
    groups = np.asarray(list(groups), dtype=object)
    y_pred = (y_prob >= float(thr)).astype(np.int64)

    out = {}
    for g in sorted(set(groups.tolist())):
        mask_g = (groups == g)
        if int(mask_g.sum()) == 0:
            continue

        pos_mask = mask_g & (y_true == 1)
        n_pos = int(pos_mask.sum())
        if n_pos == 0:
            out[g] = {"n_total": int(mask_g.sum()), "n_pos": 0, "tp": 0, "fn": 0, "fnr": float("nan")}
            continue

        tp = int(((y_pred == 1) & pos_mask).sum())
        fn = int(((y_pred == 0) & pos_mask).sum())
        fnr = float(fn / max(1, (fn + tp)))

        out[g] = {"n_total": int(mask_g.sum()), "n_pos": int(n_pos), "tp": int(tp), "fn": int(fn), "fnr": float(fnr)}

    fnr_m = out.get("M", {}).get("fnr", float("nan"))
    fnr_f = out.get("F", {}).get("fnr", float("nan"))

    if (not np.isnan(fnr_m)) and (not np.isnan(fnr_f)):
        delta_signed = float(fnr_f - fnr_m)
        delta_abs = float(abs(delta_signed))
    else:
        delta_signed = float("nan")
        delta_abs = float("nan")

    return out, delta_signed, delta_abs

def compute_confusion_counts(y_true, y_prob, thr=0.5):
    y_true = np.asarray(y_true, dtype=np.int64)
    y_prob = np.asarray(y_prob, dtype=np.float64)
    y_pred = (y_prob >= float(thr)).astype(np.int64)
    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
    TN, FP, FN, TP = int(cm[0, 0]), int(cm[0, 1]), int(cm[1, 0]), int(cm[1, 1])
    return {"TN": TN, "FP": FP, "FN": FN, "TP": TP}

def compute_confusion_by_group(y_true, y_prob, groups, thr=0.5):
    y_true = np.asarray(y_true, dtype=np.int64)
    y_prob = np.asarray(y_prob, dtype=np.float64)
    groups = np.asarray(list(groups), dtype=object)
    out = {}
    for g in sorted(set(groups.tolist())):
        mask = (groups == g)
        if int(mask.sum()) == 0:
            continue
        out[g] = {"n": int(mask.sum()), "confusion": compute_confusion_counts(y_true[mask], y_prob[mask], thr=thr)}
    return out

# -------------------------
# Reproducibility: set all RNG seeds
# Inputs: seed int
# Output: deterministic settings for random, numpy, torch
# -------------------------
def set_all_seeds(seed: int):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# -------------------------
# Source checkpoint selection: pick most recent D6 trainval exp with all seeds
# Inputs: D6 trainval_runs folder
# Output: chosen_exp folder used to load best_heads.pt per seed
# -------------------------
D6_TRAINVAL_ROOT = Path(D6_OUT_ROOT) / "trainval_runs"
if not D6_TRAINVAL_ROOT.exists():
    raise FileNotFoundError(f"Missing D6 trainval_runs folder: {str(D6_TRAINVAL_ROOT)}")

exp_dirs = sorted([p for p in D6_TRAINVAL_ROOT.glob("exp_*") if p.is_dir()], key=lambda p: p.stat().st_mtime, reverse=True)
if not exp_dirs:
    raise FileNotFoundError(f"No exp_* folders found under: {str(D6_TRAINVAL_ROOT)}")

def _has_all_seeds(exp_path: Path, seeds: list):
    for s in seeds:
        p = exp_path / f"run_D6_seed{s}" / "best_heads.pt"
        if not p.exists():
            return False
    return True

chosen_exp = None
for ed in exp_dirs:
    if _has_all_seeds(ed, SEEDS):
        chosen_exp = ed
        break

if chosen_exp is None:
    sample = exp_dirs[0]
    raise FileNotFoundError(
        "Could not find a D6 trainval experiment with all 3 best_heads.pt files.\n"
        f"Expected: {str(D6_TRAINVAL_ROOT)}/exp_*/run_D6_seedXXXX/best_heads.pt\n"
        f"Most recent exp checked: {str(sample)}"
    )

print("\nUsing SOURCE (D6) Train+Val experiment folder:")
print(" ", str(chosen_exp))

# -------------------------
# Output folder: write only under D2
# Inputs: D2_OUT_ROOT
# Output: ZS_ROOT folder for per-seed runs + summary files
# -------------------------
ZS_ROOT = Path(D2_OUT_ROOT) / "Cross_Language_Zero_Shot_Runs"
ZS_ROOT.mkdir(parents=True, exist_ok=True)

# -------------------------
# DataLoader: D2 TEST only (plus a small warm-up read)
# Inputs: test_df and audio files
# Output: test_loader ready for inference
# -------------------------
test_ds = AudioManifestDataset(test_df)
test_loader = DataLoader(
    test_ds,
    batch_size=PER_DEVICE_BS,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=PIN_MEMORY,
    collate_fn=collate_fn,
)

print("\nWarm-up: loading up to 2 TEST batches...")
t0 = time.time()
nb = len(test_loader)
wb = min(2, nb)
if wb == 0:
    raise RuntimeError("TEST DataLoader has 0 batches. Check test_df length and PER_DEVICE_BS.")
it = iter(test_loader)
for i in range(wb):
    _ = next(it)
    print(f"  loaded TEST warmup batch {i+1}/{wb}")
print(f"Warm-up done in {time.time()-t0:.2f}s")

# -------------------------
# Checkpoint loader: load only the trained heads from best_heads.pt
# Inputs: model instance + best_heads.pt path
# Output: model with head weights restored for this seed
# -------------------------
def load_heads_into_model(model: Wav2Vec2TwoHeadClassifier, best_heads_path: Path):
    if not best_heads_path.exists():
        raise FileNotFoundError(f"Missing best_heads.pt: {str(best_heads_path)}")
    state = torch.load(str(best_heads_path), map_location="cpu")
    needed = ["pre_vowel", "pre_other", "head_vowel", "head_other"]
    missing = [k for k in needed if k not in state]
    if missing:
        raise KeyError(
            f"best_heads.pt missing keys {missing}. Found keys: {list(state.keys())}. "
            "This zero-shot code expects the trainval save format."
        )
    model.pre_vowel.load_state_dict(state["pre_vowel"], strict=True)
    model.pre_other.load_state_dict(state["pre_other"], strict=True)
    model.head_vowel.load_state_dict(state["head_vowel"], strict=True)
    model.head_other.load_state_dict(state["head_other"], strict=True)
    return model

# -------------------------
# Inference: run model over D2 test_loader and collect probabilities
# Inputs: loader + model
# Output: arrays (y_true, y_prob, sex_norm) for metrics and plots
# -------------------------
def run_inference(loader, model, desc):
    use_amp = bool(USE_AMP and DEVICE.type == "cuda")
    all_probs, all_true, all_sex = [], [], []

    pbar = tqdm(loader, desc=desc, dynamic_ncols=True)
    with torch.inference_mode():
        for batch in pbar:
            input_values = batch["input_values"].to(DEVICE, non_blocking=False)
            attention_mask = batch["attention_mask"].to(DEVICE, non_blocking=False)
            labels = batch["labels"].to(DEVICE, non_blocking=False)
            task_group = batch["task_group"]
            sex_norm = batch["sex_norm"]

            with torch.amp.autocast(device_type="cuda", enabled=use_amp):
                logits = model.forward_logits(input_values, attention_mask, task_group)

            probs = torch.softmax(logits.float(), dim=-1)[:, 1].detach().cpu().numpy().astype(np.float64)
            all_probs.extend(probs.tolist())
            all_true.extend(labels.detach().cpu().numpy().astype(np.int64).tolist())
            all_sex.extend(list(sex_norm))

    return np.asarray(all_true, dtype=np.int64), np.asarray(all_probs, dtype=np.float64), np.asarray(all_sex, dtype=object)

# -------------------------
# Per-seed run: load D6 heads, evaluate D2 at FIXED_THR, save metrics + PNGs
# Inputs: seed, chosen_exp, FIXED_THR, D2 test_loader
# Output: run_<TRANSFER_TAG>_seedXXXX folder with metrics.json + plots
# -------------------------
def run_seed(seed: int):
    set_all_seeds(seed)

    run_dir = ZS_ROOT / f"run_{TRANSFER_TAG}_seed{seed}"
    run_dir.mkdir(parents=True, exist_ok=True)

    best_heads_path = chosen_exp / f"run_D6_seed{seed}" / "best_heads.pt"

    print(f"\n[seed={seed}] Loading SOURCE heads from:")
    print(" ", str(best_heads_path))
    print(f"[seed={seed}] Evaluating TARGET D2 TEST @ fixed_thr={FIXED_THR:.6f}")

    model = Wav2Vec2TwoHeadClassifier(BACKBONE_CKPT, dropout_p=DROPOUT_P).to(DEVICE)
    model = load_heads_into_model(model, best_heads_path)
    model.eval()

    yt, pt, st = run_inference(test_loader, model, desc=f"[seed={seed}] D2 TEST (zero-shot)")

    test_auc = compute_auc(yt, pt)  # threshold-free
    thr_metrics = compute_threshold_metrics(yt, pt, thr=FIXED_THR)

    fnr_by_sex, delta_f_minus_m, delta_abs = compute_fnr_by_group_signed(yt, pt, st, thr=FIXED_THR)
    confusion_by_sex = compute_confusion_by_group(yt, pt, st, thr=FIXED_THR)

    # Plots: always save overall ROC + confusion matrix
    roc_png = run_dir / "roc_curve.png"
    cm_png  = run_dir / "confusion_matrix.png"
    save_roc_curve_png(yt, pt, str(roc_png), title_suffix=f"D2 Test (seed={seed})")
    save_confusion_png(yt, pt, str(cm_png), thr=FIXED_THR, title_suffix=f"D2 Test (seed={seed})")

    # Sex-specific confusion matrices (only if that sex exists in D2 TEST)
    cm_m_png = None
    cm_f_png = None
    mask_m = (st == "M")
    mask_f = (st == "F")

    if int(mask_m.sum()) > 0:
        cm_m_png = run_dir / "confusion_matrix_M.png"
        save_confusion_png(yt[mask_m], pt[mask_m], str(cm_m_png), thr=FIXED_THR, title_suffix=f"D2 Test SEX=M (seed={seed})")

    if int(mask_f.sum()) > 0:
        cm_f_png = run_dir / "confusion_matrix_F.png"
        save_confusion_png(yt[mask_f], pt[mask_f], str(cm_f_png), thr=FIXED_THR, title_suffix=f"D2 Test SEX=F (seed={seed})")

    # Structured metrics saved for later analysis/paper tables
    metrics = {
        "transfer": {
            "tag": TRANSFER_TAG,
            "source_dataset": "D6",
            "target_dataset": "D2",
            "threshold_policy": "Fixed threshold = mean VAL-opt threshold from D6 monolingual test summary (Youden J on D6 VAL)",
            "fixed_threshold_value": float(FIXED_THR),
            "d6_mono_summary_path": str(D6_MONO_SUMMARY),
        },

        "target": {
            "dataset": "D2",
            "seed": int(seed),
            "n_test": int(len(test_df)),
            "label_counts_test": test_df["label_num"].value_counts(dropna=False).to_dict(),
            "sex_counts_test_norm": test_df["sex_norm"].value_counts(dropna=False).to_dict(),
        },

        "test_auroc": float(test_auc),

        "threshold_metrics_test": thr_metrics,
        "test_threshold_used": float(FIXED_THR),

        "fairness_test": {
            "definition": "H3: ΔFNR = FNR(F) - FNR(M), where FNR(sex)=FN/(FN+TP) computed on PD-only true labels at fixed transfer threshold.",
            "fnr_by_sex_norm": fnr_by_sex,
            "delta_fnr_F_minus_M": float(delta_f_minus_m),
            "delta_fnr_abs": float(delta_abs),
            "note": "If n_PD for a sex is 0, its FNR is NaN and ΔFNR is NaN.",
            "sex_normalization_note": "D2 mapping: male->M, female->F; otherwise UNK.",
        },

        "confusion_by_sex_norm": confusion_by_sex,

        "artifacts": {
            "roc_curve_png": str(roc_png),
            "confusion_matrix_png": str(cm_png),
            "confusion_matrix_M_png": str(cm_m_png) if cm_m_png is not None else None,
            "confusion_matrix_F_png": str(cm_f_png) if cm_f_png is not None else None,
        },

        "paths": {
            "d2_out_root": str(D2_OUT_ROOT),
            "zero_shot_root": str(ZS_ROOT),
            "run_dir": str(run_dir),
            "source_trainval_experiment_used": str(chosen_exp),
            "source_best_heads_path": str(best_heads_path),
        },

        "backbone_ckpt": BACKBONE_CKPT,
        "dropout_p": float(DROPOUT_P),
        "timestamp": time.strftime("%Y%m%d_%H%M%S"),
    }

    metrics_path = run_dir / "metrics.json"
    with open(metrics_path, "w", encoding="utf-8") as f:
        json.dump(metrics, f, indent=2)

    print(f"[seed={seed}] DONE | test_AUROC={test_auc:.6f}")

    def _fmt_fnr(g):
        d = fnr_by_sex.get(g, None)
        if d is None:
            return "n/a"
        return f"fnr={d['fnr']:.6f} (n_PD={d['n_pos']}, fn={d['fn']}, tp={d['tp']})"

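    print(
        f"[seed={seed}] TEST metrics @ fixed_thr={FIXED_THR:.6f}: "
        f"acc={thr_metrics['accuracy']:.6f} | sens={thr_metrics['sensitivity']:.6f} | "
        f"spec={thr_metrics['specificity']:.6f} | mcc={thr_metrics['mcc']:.6f}"
    )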
    print(f"[seed={seed}] FAIRNESS (H3) @ fixed_thr={FIXED_THR:.6f}:")
    print("  M:", _fmt_fnr("M"))
    print("  F:", _fmt_fnr("F"))
    if "UNK" in fnr_by_sex:
        print("  UNK:", _fmt_fnr("UNK"))
    print("  ΔFNR (F-M):", f"{delta_f_minus_m:.6f}" if not np.isnan(delta_f_minus_m) else "nan")
    print("  |ΔFNR|:", f"{delta_abs:.6f}" if not np.isnan(delta_abs) else "nan")

    print(f"[seed={seed}] WROTE:")
    print(" ", str(metrics_path))
    print(" ", str(roc_png))
    print(" ", str(cm_png))
    if cm_m_png is not None:
        print(" ", str(cm_m_png))
    if cm_f_png is not None:
        print(" ", str(cm_f_png))

    return {
        "seed": int(seed),
        "test_auroc": float(test_auc),
        "thr_metrics_test": thr_metrics,
        "fnr_by_sex_test": fnr_by_sex,
        "delta_f_minus_m_test": float(delta_f_minus_m),
        "delta_abs_test": float(delta_abs),
        "run_dir": str(run_dir),
    }

# -------------------------
# Aggregate across 3 seeds: AUROC mean ± 95% CI, others mean ± SD
# Inputs: per-seed outputs from run_seed()
# Output: summary_zero_shot4.json + history_index.jsonl under ZS_ROOT
# -------------------------
seed_results = []
for seed in SEEDS:
    seed_results.append(run_seed(seed))

aurocs = [r["test_auroc"] for r in seed_results]
t_crit = 4.302652729911275  # two-sided 95% t critical value for df=2; only valid when len(SEEDS) == 3
n = len(aurocs)
mean_auc = float(np.mean(aurocs))
std_auc = float(np.std(aurocs, ddof=1)) if n > 1 else 0.0
half_width = float(t_crit * (std_auc / math.sqrt(n))) if n > 1 else 0.0
ci95 = [float(mean_auc - half_width), float(mean_auc + half_width)]

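def _mean_sd(vals):
    # NaN-aware mean and sample SD (ddof=1); SD falls back to 0.0 with fewer than 2 finite values
    vals = np.asarray(vals, dtype=np.float64)
    n_valid = int(np.sum(~np.isnan(vals)))
    mean = float(np.nanmean(vals)) if n_valid > 0 else float("nan")
    sd = float(np.nanstd(vals, ddof=1)) if n_valid > 1 else 0.0
    return mean, sd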

thr_list = [r["thr_metrics_test"] for r in seed_results]
keys = ["accuracy","precision","recall","f1_score","sensitivity","specificity","mcc","p_value_fisher_two_sided"]

agg = {}
for k in keys:
    v = [float(tm.get(k, float("nan"))) for tm in thr_list]
    mu, sd = _mean_sd(v)
    agg[k] = {
        "mean": mu,
        "sd": sd,
        "values_by_seed": {str(r["seed"]): float(tm.get(k, float("nan"))) for r, tm in zip(seed_results, thr_list)}
    }

cm_by_seed = {str(r["seed"]): r["thr_metrics_test"]["confusion_matrix"] for r in seed_results}

# FAIRNESS aggregation (H3)
fnr_by_seed = {str(r["seed"]): r["fnr_by_sex_test"] for r in seed_results}
delta_signed_by_seed = {str(r["seed"]): float(r["delta_f_minus_m_test"]) for r in seed_results}
delta_abs_by_seed = {str(r["seed"]): float(r["delta_abs_test"]) for r in seed_results}

fnr_m_vals, fnr_f_vals, n_pd_m_vals, n_pd_f_vals = [], [], [], []
d_signed_vals, d_abs_vals = [], []

for r in seed_results:
    s = str(r["seed"])
    d = fnr_by_seed.get(s, {})
    fnr_m_vals.append(float(d.get("M", {}).get("fnr", float("nan"))))
    fnr_f_vals.append(float(d.get("F", {}).get("fnr", float("nan"))))
    n_pd_m_vals.append(float(d.get("M", {}).get("n_pos", float("nan"))))
    n_pd_f_vals.append(float(d.get("F", {}).get("n_pos", float("nan"))))
    d_signed_vals.append(float(delta_signed_by_seed.get(s, float("nan"))))
    d_abs_vals.append(float(delta_abs_by_seed.get(s, float("nan"))))

fnr_m_mean, fnr_m_sd = _mean_sd(fnr_m_vals)
fnr_f_mean, fnr_f_sd = _mean_sd(fnr_f_vals)
d_signed_mean, d_signed_sd = _mean_sd(d_signed_vals)
d_abs_mean, d_abs_sd = _mean_sd(d_abs_vals)

print("\nZero-shot TEST AUROC by seed:")
for r in seed_results:
    print(f"  seed {r['seed']}: {r['test_auroc']:.6f}")
print(f"\nMean Zero-shot TEST AUROC: {mean_auc:.6f}")
print(f"95% CI (t, n=3): [{ci95[0]:.6f}, {ci95[1]:.6f}]")

print(f"\nFixed threshold used for ALL seeds (from D6 mono mean VAL-opt): {FIXED_THR:.6f}")

print("\nThreshold metrics on D2 TEST @ fixed threshold (mean ± SD across seeds):")
for k in ["accuracy","precision","sensitivity","specificity","f1_score","mcc"]:
    mu = agg[k]["mean"]
    sd = agg[k]["sd"]
    print(f"  {k}: {mu:.6f} ± {sd:.6f}")
print("  fisher_p_value_two_sided:", f"{agg['p_value_fisher_two_sided']['mean']:.6g} ± {agg['p_value_fisher_two_sided']['sd']:.6g}")

print("\nFAIRNESS (H3) on D2 TEST @ fixed threshold (mean ± SD across seeds):")
print(f"  FNR_M: {fnr_m_mean:.6f} ± {fnr_m_sd:.6f}")
print(f"  FNR_F: {fnr_f_mean:.6f} ± {fnr_f_sd:.6f}")
print(f"  ΔFNR (F-M): {d_signed_mean:.6f} ± {d_signed_sd:.6f}")
print(f"  |ΔFNR|: {d_abs_mean:.6f} ± {d_abs_sd:.6f}")

summary = {
    "transfer": {
        "tag": TRANSFER_TAG,
        "source_dataset": "D6",
        "target_dataset": "D2",
        "fixed_threshold_value": float(FIXED_THR),
        "fixed_threshold_source": {
            "path": str(D6_MONO_SUMMARY),
            "key": "val_opt_threshold_mean_sd.mean",
            "note": "Mean across seeds of D6 VAL-opt thresholds (Youden J) as recorded by D6 monolingual test run.",
        },
        "source_trainval_experiment_used": str(chosen_exp),
    },

    "target": {
        "dx_out_root": str(D2_OUT_ROOT),
        "manifest_all": str(D2_MANIFEST_ALL),
        "n_test": int(len(test_df)),
        "label_counts_test": test_df["label_num"].value_counts(dropna=False).to_dict(),
        "sex_counts_test_norm": test_df["sex_norm"].value_counts(dropna=False).to_dict(),
    },

    "seeds": SEEDS,

    "test_aurocs_by_seed": {str(r["seed"]): float(r["test_auroc"]) for r in seed_results},
    "mean_test_auroc": float(mean_auc),
    "t_crit_df2_95": float(t_crit),
    "half_width_95_ci": float(half_width),
    "ci95_test_auroc": ci95,

    "threshold_metrics_test_mean_sd": agg,
    "confusion_matrix_by_seed": cm_by_seed,

    "fairness_test": {
        "definition": "H3: ΔFNR = FNR(F) - FNR(M), where FNR(sex)=FN/(FN+TP) computed on PD-only true labels at the fixed transfer threshold.",
        "fnr_by_sex_norm_by_seed": fnr_by_seed,
        "delta_fnr_F_minus_M_by_seed": delta_signed_by_seed,
        "delta_fnr_abs_by_seed": delta_abs_by_seed,
        "fnr_M_mean_sd": {"mean": float(fnr_m_mean), "sd": float(fnr_m_sd), "values_by_seed": {str(s): float(v) for s, v in zip(SEEDS, fnr_m_vals)}},
        "fnr_F_mean_sd": {"mean": float(fnr_f_mean), "sd": float(fnr_f_sd), "values_by_seed": {str(s): float(v) for s, v in zip(SEEDS, fnr_f_vals)}},
        "delta_fnr_F_minus_M_mean_sd": {"mean": float(d_signed_mean), "sd": float(d_signed_sd), "values_by_seed": {str(s): float(v) for s, v in zip(SEEDS, d_signed_vals)}},
        "delta_fnr_abs_mean_sd": {"mean": float(d_abs_mean), "sd": float(d_abs_sd), "values_by_seed": {str(s): float(v) for s, v in zip(SEEDS, d_abs_vals)}},
        "denominators_PD_by_seed": {str(s): {"n_PD_M": float(n_pd_m_vals[i]), "n_PD_F": float(n_pd_f_vals[i])} for i, s in enumerate(SEEDS)},
        "sex_normalization_note": "D2 mapping: male->M, female->F; otherwise UNK. ΔFNR computed only when both M and F have defined FNR.",
    },

    "run_dirs": [r["run_dir"] for r in seed_results],
    "timestamp": time.strftime("%Y%m%d_%H%M%S"),
}

summary_path = ZS_ROOT / "summary_zero_shot4.json"
with open(summary_path, "w", encoding="utf-8") as f:
    json.dump(summary, f, indent=2)

history_path = ZS_ROOT / "history_index.jsonl"
with open(history_path, "a", encoding="utf-8") as f:
    f.write(json.dumps(summary) + "\n")

print("\nWROTE summary:", str(summary_path))
print("APPENDED history index:", str(history_path))
print("Open this folder to access artifacts:", str(ZS_ROOT))

# -------------------------
# Stop runtime (Colab)
# Inputs: none
# Output: attempts to release the GPU instance
# -------------------------
print("\nAll done. Unassigning the runtime to stop the L4 instance...")
try:
    from google.colab import runtime  # type: ignore
    print("Calling runtime.unassign() now.")
    runtime.unassign()
except Exception as e:
    print("Could not unassign runtime automatically. You can stop the runtime manually in Colab.")
    print("Reason:", repr(e))

The following cell runs a cross-language zero-shot test from Italian to Spanish (D4 → D1). It checks whether a model trained on the Italian dataset (D4) can detect Parkinson’s disease on the Spanish dataset (D1) without any retraining on D1. The cell loads the most recent D4 train+val experiment that contains completed head checkpoints for all three seeds (1337, 2024, 7777), and applies those heads directly to the D1 test split only, using audio paths from `<D1_OUT_ROOT>/manifests/manifest_all.csv`.

A single fixed decision threshold is used for all seeds. This threshold is read from D4’s saved `monolingual_test_runs/summary_test.json`, using `val_opt_threshold_mean_sd.mean` when available and otherwise falling back to the mean of `val_opt_threshold_by_seed`. Because the threshold comes entirely from the source side, no D1 labels influence the decision rule, and the cell stops with an error if no finite threshold can be read.
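
As a rough illustration of where such a threshold could come from, the sketch below picks a validation threshold by maximizing Youden's J (TPR minus FPR) and then averages per-seed thresholds into one fixed value. It assumes the VAL-opt thresholds were chosen with Youden's J, as the metadata strings in the zero-shot cells describe. The function name `youden_j_threshold`, the scores, and the two extra per-seed values are hypothetical and are not taken from the notebook's saved summaries.

```python
import numpy as np
from sklearn.metrics import roc_curve

def youden_j_threshold(y_true, y_prob):
    """Return the ROC threshold that maximizes J = TPR - FPR (hypothetical helper)."""
    fpr, tpr, thresholds = roc_curve(y_true, y_prob)
    best = int(np.argmax(tpr - fpr))  # first maximum if several thresholds tie
    return float(thresholds[best])

# Made-up validation labels and PD probabilities, for illustration only.
y_val = np.array([0, 0, 0, 0, 1, 1, 1, 1])
p_val = np.array([0.10, 0.20, 0.30, 0.85, 0.60, 0.80, 0.90, 0.95])
thr = youden_j_threshold(y_val, p_val)

# Hypothetical per-seed thresholds, averaged into one fixed transfer threshold.
thr_by_seed = {"1337": thr, "2024": 0.55, "7777": 0.50}
fixed_thr = float(np.mean(list(thr_by_seed.values())))
print(f"seed threshold={thr:.2f} | fixed transfer threshold={fixed_thr:.2f}")
```

Predictions made with `y_prob >= fixed_thr` follow the same convention `roc_curve` uses for its thresholds, so the selected value can be applied directly to target-side probabilities.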

Before evaluation, the cell prepares the D1 test data. It checks that all required manifest columns exist, filters the data to `split == "test"`, and stops early if any audio clip paths are missing. Each clip is assigned to a task group, with “vowel” used when `task == "vowl"` and “other” for all remaining cases. D1 sex values are normalized into `sex_norm` (M, F, UNK) using D1’s numeric encoding, where 0 maps to F and 1 maps to M. This allows consistent sex-based reporting and plots.

For each seed, the cell loads `best_heads.pt` from the selected D4 experiment folder, rebuilds the same frozen Wav2Vec2 two-head classifier used during training, and runs inference on all D1 test clips to produce Parkinson’s disease probabilities. It then computes AUROC as a threshold-free measure of performance, followed by fixed-threshold metrics at the D4-derived threshold. These include the confusion matrix, accuracy, precision, sensitivity (recall), specificity, F1 score, MCC, and the two-sided Fisher exact test p-value. Fairness is measured as ΔFNR, defined as FNR(F) minus FNR(M), where FNR = FN / (FN + TP) is computed only on true Parkinson’s cases. The cell saves ROC curves and confusion matrix images for the full test set, as well as sex-specific confusion matrices when both M and F samples are present. A per-seed `metrics.json` file and all plots are written under
`<D1_OUT_ROOT>/Cross_Language_Zero_Shot_Runs/run_IT_to_ES_seed<seed>/`.
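
The ΔFNR definition above can be made concrete with a small worked example. This is a minimal sketch under stated assumptions: `fnr_by_sex_sketch` is a hypothetical stand-in, not the notebook's own `compute_fnr_by_group_signed` helper, and the labels, probabilities, and threshold are made up.

```python
import numpy as np

def fnr_by_sex_sketch(y_true, y_prob, sex, thr):
    """Per-sex FNR on true PD cases and signed delta FNR(F) - FNR(M) (hypothetical helper)."""
    y_true = np.asarray(y_true, dtype=np.int64)
    y_pred = (np.asarray(y_prob, dtype=np.float64) >= float(thr)).astype(np.int64)
    sex = np.asarray(sex, dtype=object)
    out = {}
    for g in ("M", "F"):
        pos = (sex == g) & (y_true == 1)        # true PD clips for this sex
        n_pos = int(pos.sum())
        fn = int((y_pred[pos] == 0).sum())      # PD clips predicted healthy
        fnr = fn / n_pos if n_pos > 0 else float("nan")
        out[g] = {"n_pos": n_pos, "fn": fn, "tp": n_pos - fn, "fnr": fnr}
    delta = out["F"]["fnr"] - out["M"]["fnr"]   # NaN if either sex has no PD clips
    return out, delta

# Made-up example: 3 male and 3 female PD clips plus 2 healthy clips.
y_true = [1, 1, 1, 1, 0, 0, 1, 1]
y_prob = [0.9, 0.2, 0.7, 0.4, 0.1, 0.8, 0.3, 0.6]
sex    = ["M", "M", "F", "F", "M", "F", "F", "M"]
by_sex, delta_fnr = fnr_by_sex_sketch(y_true, y_prob, sex, thr=0.5)
print(by_sex, f"delta FNR (F-M) = {delta_fnr:.4f}")
```

In this toy case one of three male PD clips and two of three female PD clips fall below the threshold, so FNR(M) = 1/3, FNR(F) = 2/3, and ΔFNR = 1/3. A positive value means PD is missed more often in female speakers at the fixed threshold.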

After all three seeds finish, the cell combines the results across seeds. It reports the mean AUROC with a 95% confidence interval based on a t distribution with n = 3 seeds (2 degrees of freedom), along with the mean and standard deviation of the fixed-threshold metrics and fairness values. A combined `summary_zero_shot.json` file is written, the same summary is appended to `history_index.jsonl` under `<D1_OUT_ROOT>/Cross_Language_Zero_Shot_Runs/`, and the Colab runtime is unassigned to shut down the GPU.
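
The seed-level confidence interval can be sketched as follows. This is an illustrative version, not the notebook's code: it takes the critical value from `scipy.stats.t.ppf` rather than a hardcoded constant, the helper name `mean_t_ci` is hypothetical, and the AUROC values are made up.

```python
import math
import numpy as np
from scipy import stats

def mean_t_ci(values, confidence=0.95):
    """Mean and two-sided t-based confidence interval across seeds (hypothetical helper)."""
    vals = np.asarray(values, dtype=np.float64)
    n = int(vals.size)
    mean = float(vals.mean())
    if n < 2:
        return mean, (mean, mean)               # no spread can be estimated from one value
    sd = float(vals.std(ddof=1))                # sample standard deviation
    t_crit = float(stats.t.ppf(0.5 + confidence / 2.0, df=n - 1))
    half = t_crit * sd / math.sqrt(n)
    return mean, (mean - half, mean + half)

aurocs = [0.70, 0.72, 0.74]                     # made-up per-seed AUROCs
mean_auc, (ci_lo, ci_hi) = mean_t_ci(aurocs)
print(f"mean={mean_auc:.4f} | 95% CI=[{ci_lo:.4f}, {ci_hi:.4f}]")
```

With only three seeds the critical value is about 4.30, so the interval is wide and, for values near 0 or 1, can extend beyond the valid AUROC range. It is best read as a measure of seed-to-seed variability rather than as a tight estimate.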

In [None]:
# =========================
# IT → ES Zero-shot (D4 → D1): Test-only transfer with fixed threshold
# - Loads D4-trained heads (most recent train+val experiment)
# - Runs inference on D1 TEST only (from manifest_all.csv)
# - Uses one fixed threshold read from D4 monolingual summary_test.json
# - Saves per-seed metrics + plots, plus an aggregated summary under D1
# =========================

import os, json, math, random, time, warnings
from pathlib import Path

import numpy as np
import pandas as pd
import soundfile as sf

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

from transformers import Wav2Vec2Model

from sklearn.metrics import roc_auc_score, roc_curve, confusion_matrix, matthews_corrcoef
from scipy.stats import fisher_exact
import matplotlib.pyplot as plt

from tqdm.auto import tqdm

# -------------------------
# Safety checks: avoid importing local files named like core libraries
# Inputs: /content directory
# Output: raises an error early if a naming conflict exists
# -------------------------
if os.path.exists("/content/torch.py") or os.path.exists("/content/torch/__init__.py"):
    raise RuntimeError("Found local /content/torch.py or /content/torch/ that shadows PyTorch. Rename/remove it and restart runtime.")
if os.path.exists("/content/transformers.py") or os.path.exists("/content/transformers/__init__.py"):
    raise RuntimeError("Found local /content/transformers.py or /content/transformers/ that shadows Hugging Face Transformers. Rename/remove it and restart runtime.")

# -------------------------
# Drive mount (Colab)
# Inputs: Google Drive
# Output: /content/drive mounted so manifests, checkpoints, and outputs can be read/written
# -------------------------
if not os.path.isdir("/content/drive/MyDrive"):
    from google.colab import drive  # type: ignore
    drive.mount("/content/drive")

# -------------------------
# Paths: target (D1) + source (D4)
# Inputs: optional globals D1_OUT_ROOT, D4_OUT_ROOT (if set earlier), else fallbacks
# Output: D1_MANIFEST_ALL and D4_MONO_SUMMARY paths
# -------------------------
D1_OUT_ROOT_FALLBACK = "/content/drive/MyDrive/AI_PD_Project/Datasets/D1-NeuroVoz-Castillan Spanish/preprocessed_v1"
D4_OUT_ROOT_FALLBACK = "/content/drive/MyDrive/AI_PD_Project/Datasets/D4-Italian (IPVS)/preprocessed_v1"

D1_OUT_ROOT = globals().get("D1_OUT_ROOT", D1_OUT_ROOT_FALLBACK)
D4_OUT_ROOT = globals().get("D4_OUT_ROOT", D4_OUT_ROOT_FALLBACK)

D1_MANIFEST_ALL = f"{D1_OUT_ROOT}/manifests/manifest_all.csv"

# D4 monolingual summary path (contains the fixed threshold source)
D4_MONO_SUMMARY = Path(D4_OUT_ROOT) / "monolingual_test_runs" / "summary_test.json"

# -------------------------
# Run settings (kept consistent with other test cells)
# Inputs: constants below
# Output: printed configuration + runtime behavior (batching, AMP, device)
# -------------------------
TRANSFER_TAG    = "IT_to_ES"          # D4 -> D1
SEEDS           = [1337, 2024, 7777]
BACKBONE_CKPT   = "facebook/wav2vec2-base"
SR_EXPECTED     = 16000
TINY_THRESH     = 1e-4

EFFECTIVE_BS    = 64
PER_DEVICE_BS   = 16
GRAD_ACCUM      = max(1, EFFECTIVE_BS // PER_DEVICE_BS)  # printed only

DROPOUT_P       = 0.2

NUM_WORKERS     = 0
PIN_MEMORY      = False

USE_AMP         = True

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
try:
    torch.set_float32_matmul_precision("high")
except Exception:
    pass

warnings.filterwarnings("ignore", message="None of the inputs have requires_grad=True. Gradients will be None")
warnings.filterwarnings("ignore", message="Passing `gradient_checkpointing` to a config initialization is deprecated")

print("D1_OUT_ROOT (target):", D1_OUT_ROOT)
print("D1_MANIFEST_ALL:", D1_MANIFEST_ALL)
print("D4_OUT_ROOT (source):", D4_OUT_ROOT)
print("D4_MONO_SUMMARY:", str(D4_MONO_SUMMARY))
print("DEVICE:", DEVICE)
if DEVICE.type == "cuda":
    print("GPU:", torch.cuda.get_device_name(0))
print("PER_DEVICE_BS:", PER_DEVICE_BS, "| EFFECTIVE_BS:", PER_DEVICE_BS * GRAD_ACCUM)
print("NUM_WORKERS:", NUM_WORKERS, "| PIN_MEMORY:", PIN_MEMORY)
print("USE_AMP:", bool(USE_AMP and DEVICE.type == "cuda"))

# -------------------------
# Fixed threshold: read from D4 monolingual summary_test.json
# Inputs: D4_MONO_SUMMARY (JSON)
# Output: FIXED_THR (float) used for all three seeds on D1
# -------------------------
if not D4_MONO_SUMMARY.exists():
    raise FileNotFoundError(
        "Missing D4 monolingual summary_test.json.\n"
        f"Expected: {str(D4_MONO_SUMMARY)}\n"
        "Run the D4 monolingual test cell first (the one that writes summary_test.json)."
    )

with open(D4_MONO_SUMMARY, "r", encoding="utf-8") as f:
    d4_sum = json.load(f)

# D4 schema: summary["val_opt_threshold_mean_sd"]["mean"] (preferred)
if "val_opt_threshold_mean_sd" in d4_sum and isinstance(d4_sum["val_opt_threshold_mean_sd"], dict):
    FIXED_THR = float(d4_sum["val_opt_threshold_mean_sd"].get("mean", float("nan")))
else:
    # Fallback: mean of per-seed values if present
    by_seed = d4_sum.get("val_opt_threshold_by_seed", None)
    if isinstance(by_seed, dict) and len(by_seed) > 0:
        FIXED_THR = float(np.nanmean([float(v) for v in by_seed.values()]))
    else:
        FIXED_THR = float("nan")

if not np.isfinite(FIXED_THR):
    raise RuntimeError(
        "Could not read a finite fixed threshold from D4 monolingual summary_test.json.\n"
        "Expected keys: val_opt_threshold_mean_sd.mean (preferred) or val_opt_threshold_by_seed."
    )

print(f"\nFixed transfer threshold (from D4 monolingual mean VAL-opt): {FIXED_THR:.6f}")

# -------------------------
# Target manifest: load D1 and keep TEST only
# Inputs: D1 manifest_all.csv
# Output: test_df with required columns + basic counts printed
# -------------------------
if not os.path.exists(D1_MANIFEST_ALL):
    raise FileNotFoundError(f"Missing D1 manifest_all.csv: {D1_MANIFEST_ALL}")

m_all = pd.read_csv(D1_MANIFEST_ALL)

# Keep sex for fairness and sex-specific confusion matrices
req_cols = {"split", "clip_path", "label_num", "task", "sex"}
missing = [c for c in sorted(req_cols) if c not in m_all.columns]
if missing:
    raise ValueError(f"D1 manifest missing required columns: {missing}. Found: {list(m_all.columns)}")

# Infer dataset_id and ensure this is really D1
if "dataset" in m_all.columns and m_all["dataset"].notna().any():
    dataset_id = str(m_all["dataset"].astype(str).value_counts(dropna=True).idxmax())
    m_all = m_all[m_all["dataset"].astype(str) == dataset_id].copy()
else:
    dataset_id = "DX"

if dataset_id != "D1":
    raise RuntimeError(
        f"Expected dataset_id=='D1' but got {dataset_id!r}. "
        "This usually means D1_OUT_ROOT is wrong or inherited from a previous cell.\n"
        f"D1_OUT_ROOT={D1_OUT_ROOT}"
    )

# Keep a stable column set (missing ones are filled with NaN)
keep_cols = ["clip_path", "label_num", "task", "speaker_id", "sex", "age", "duration_sec", "split"]
for c in keep_cols:
    if c not in m_all.columns:
        m_all[c] = np.nan
m_all = m_all[keep_cols].copy()

test_df = m_all[m_all["split"].isin(["test"])].copy().reset_index(drop=True)

print(f"\nTarget dataset inferred: {dataset_id}")
print(f"TEST rows: {len(test_df)}")
if len(test_df) == 0:
    raise RuntimeError("After filtering to split=='test', D1 manifest has 0 rows.")

print("TEST label counts:", test_df["label_num"].value_counts(dropna=False).to_dict())
print("TEST sex counts (raw):", test_df["sex"].value_counts(dropna=False).to_dict())

# -------------------------
# Path check: ensure target audio files exist
# Inputs: test_df.clip_path
# Output: raises early if any clip files are missing (shows a few examples)
# -------------------------
def _fail_fast_missing_paths(df: pd.DataFrame, name: str):
    missing_paths = []
    for p in tqdm(df["clip_path"].astype(str).tolist(), desc=f"Check {name} clip_path exists", dynamic_ncols=True):
        if not os.path.exists(p):
            missing_paths.append(p)
            if len(missing_paths) >= 10:
                break
    if missing_paths:
        raise FileNotFoundError(f"{name}: missing clip_path(s). Examples: {missing_paths}")

_fail_fast_missing_paths(test_df, "TEST")

# -------------------------
# Task grouping: map each row to vowel vs other
# Inputs: test_df.task
# Output: test_df.task_group used to pick the correct head during inference
# -------------------------
def _task_group(task_val) -> str:
    return "vowel" if str(task_val) == "vowl" else "other"

test_df["task_group"] = test_df["task"].apply(_task_group)

# -------------------------
# Sex normalization for fairness/charts (D1 uses 0/1 encoding)
# Inputs: test_df.sex
# Output: test_df.sex_norm in {M,F,UNK}
# -------------------------
def normalize_sex_d1(val) -> str:
    """
    Returns 'M', 'F', or 'UNK'
    D1 encoding:
      0 -> F
      1 -> M
    Also handles common string encodings.
    """
    if pd.isna(val):
        return "UNK"

    # numeric
    try:
        fv = float(val)
        if np.isfinite(fv) and abs(fv - round(fv)) < 1e-9:
            iv = int(round(fv))
            if iv == 0:
                return "F"
            if iv == 1:
                return "M"
    except Exception:
        pass

    # string
    s = str(val).strip().lower()
    if s in {"m", "male", "man", "masc", "masculine"}:
        return "M"
    if s in {"f", "female", "woman", "fem", "feminine"}:
        return "F"
    return "UNK"

test_df["sex_norm"] = test_df["sex"].apply(normalize_sex_d1)
print("TEST sex counts (normalized):", test_df["sex_norm"].value_counts(dropna=False).to_dict())
if (test_df["sex_norm"] == "UNK").any():
    print("NOTE: Some D1 'sex' values could not be normalized to M/F and were counted as 'UNK' for fairness and sex charts.")

# -------------------------
# Dataset + collator: read audio and build attention masks
# Inputs: test_df rows + audio files
# Output: DataLoader batches with padded input_values, attention_mask, labels, task_group, sex_norm
# -------------------------
class AudioManifestDataset(Dataset):
    def __init__(self, df: pd.DataFrame):
        self.df = df.reset_index(drop=True)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx: int):
        row = self.df.iloc[idx]
        clip_path = str(row["clip_path"])
        label = int(row["label_num"])
        task_group = str(row["task_group"])
        sex_norm = str(row["sex_norm"])

        if not os.path.exists(clip_path):
            raise FileNotFoundError(f"Missing clip_path: {clip_path}")

        y, sr = sf.read(clip_path, always_2d=False)
        if isinstance(y, np.ndarray) and y.ndim > 1:
            y = np.mean(y, axis=1)
        y = np.asarray(y, dtype=np.float32)

        if int(sr) != SR_EXPECTED:
            raise ValueError(f"Sample rate mismatch (expected {SR_EXPECTED}): {clip_path} has sr={sr}")

        # attention_mask mainly matters for vowel clips: mask the near-silent padded tail
        attn = np.ones((len(y),), dtype=np.int64)
        if task_group == "vowel":
            # Indices of samples whose magnitude exceeds TINY_THRESH (compared in float64)
            loud = np.flatnonzero(np.abs(y.astype(np.float64)) > TINY_THRESH)
            if loud.size > 0:
                # Keep everything up to the last loud sample; zero out the tail after it
                attn[int(loud[-1]) + 1:] = 0

        return {
            "input_values": torch.from_numpy(y),
            "attention_mask": torch.from_numpy(attn),
            "labels": torch.tensor(label, dtype=torch.long),
            "task_group": task_group,
            "sex_norm": sex_norm,
        }

def collate_fn(batch):
    max_len = int(max(b["input_values"].numel() for b in batch))
    input_vals, attn_masks, labels, task_groups, sex_norms = [], [], [], [], []
    for b in batch:
        x = b["input_values"]
        a = b["attention_mask"]
        pad = max_len - x.numel()
        if pad > 0:
            x = torch.cat([x, torch.zeros(pad, dtype=x.dtype)], dim=0)
            a = torch.cat([a, torch.zeros(pad, dtype=a.dtype)], dim=0)
        input_vals.append(x)
        attn_masks.append(a)
        labels.append(b["labels"])
        task_groups.append(b["task_group"])
        sex_norms.append(b["sex_norm"])
    return {
        "input_values": torch.stack(input_vals, dim=0),
        "attention_mask": torch.stack(attn_masks, dim=0),
        "labels": torch.stack(labels, dim=0),
        "task_group": task_groups,
        "sex_norm": sex_norms,
    }

# -------------------------
# Model: frozen wav2vec2 backbone + two heads (vowel vs other)
# Inputs: backbone checkpoint name + dropout_p
# Output: logits for PD vs healthy for each clip, using the correct head per task_group
# -------------------------
class Wav2Vec2TwoHeadClassifier(nn.Module):
    def __init__(self, ckpt: str, dropout_p: float):
        super().__init__()
        self.backbone = Wav2Vec2Model.from_pretrained(ckpt, use_safetensors=True)
        for p in self.backbone.parameters():
            p.requires_grad = False

        H = int(self.backbone.config.hidden_size)
        self.pre_vowel = nn.Sequential(nn.LayerNorm(H), nn.Dropout(float(dropout_p)))
        self.pre_other = nn.Sequential(nn.LayerNorm(H), nn.Dropout(float(dropout_p)))
        self.head_vowel = nn.Linear(H, 2)
        self.head_other = nn.Linear(H, 2)

    def masked_mean_pool(self, last_hidden, attn_mask_samples):
        feat_mask = self.backbone._get_feature_vector_attention_mask(last_hidden.shape[1], attn_mask_samples)
        feat_mask = feat_mask.to(last_hidden.device).unsqueeze(-1).type_as(last_hidden)
        masked = last_hidden * feat_mask
        denom = feat_mask.sum(dim=1).clamp(min=1.0)
        return masked.sum(dim=1) / denom

    def _heads_fp32(self, x_fp_any: torch.Tensor, head: nn.Module) -> torch.Tensor:
        x = x_fp_any.float()
        if x.is_cuda:
            with torch.amp.autocast(device_type="cuda", enabled=False):
                return head(x)
        return head(x)

    def forward_logits(self, input_values, attention_mask, task_group):
        with torch.no_grad():
            out = self.backbone(input_values=input_values, attention_mask=attention_mask)
            last_hidden = out.last_hidden_state

        pooled = self.masked_mean_pool(last_hidden, attention_mask)  # [B,H]

        z_v = self.pre_vowel(pooled.float())
        z_o = self.pre_other(pooled.float())

        logits_v = self._heads_fp32(z_v, self.head_vowel)
        logits_o = self._heads_fp32(z_o, self.head_other)

        # Select the correct head per row based on task_group
        is_vowel = torch.tensor([tg == "vowel" for tg in task_group], device=pooled.device, dtype=torch.bool)
        logits = logits_o.clone()
        if is_vowel.any():
            logits[is_vowel] = logits_v[is_vowel]
        return logits

# -------------------------
# Metrics + plot helpers
# Inputs: y_true and y_prob (PD probability), plus a threshold for thresholded metrics
# Output: numeric metrics dict + PNG charts saved to disk
# -------------------------
def compute_auc(y_true, y_prob):
    y_true = np.asarray(y_true, dtype=np.int64)
    y_prob = np.asarray(y_prob, dtype=np.float64)
    if len(np.unique(y_true)) < 2:
        return float("nan")
    return float(roc_auc_score(y_true, y_prob))

def compute_threshold_metrics(y_true, y_prob, thr=0.5):
    y_true = np.asarray(y_true, dtype=np.int64)
    y_prob = np.asarray(y_prob, dtype=np.float64)
    y_pred = (y_prob >= float(thr)).astype(np.int64)

    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
    TN, FP, FN, TP = int(cm[0, 0]), int(cm[0, 1]), int(cm[1, 0]), int(cm[1, 1])

    eps = 1e-12
    acc = (TP + TN) / max(1, (TP + TN + FP + FN))
    prec = TP / (TP + FP + eps)
    rec = TP / (TP + FN + eps)     # sensitivity
    f1 = 2 * prec * rec / (prec + rec + eps)
    spec = TN / (TN + FP + eps)

    try:
        mcc = float(matthews_corrcoef(y_true, y_pred)) if len(np.unique(y_pred)) > 1 else float("nan")
    except Exception:
        mcc = float("nan")

    try:
        _, pval = fisher_exact([[TN, FP], [FN, TP]], alternative="two-sided")
        pval = float(pval)
    except Exception:
        pval = float("nan")

    return {
        "threshold": float(thr),
        "confusion_matrix": {"TN": TN, "FP": FP, "FN": FN, "TP": TP},
        "accuracy": float(acc),
        "precision": float(prec),
        "recall": float(rec),
        "f1_score": float(f1),
        "sensitivity": float(rec),
        "specificity": float(spec),
        "mcc": float(mcc),
        "p_value_fisher_two_sided": float(pval),
    }

def save_roc_curve_png(y_true, y_prob, out_png, title_suffix="Test"):
    fpr, tpr, _ = roc_curve(y_true, y_prob)
    plt.figure()
    plt.plot(fpr, tpr)
    plt.plot([0, 1], [0, 1], linestyle="--")
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    plt.title(f"ROC Curve ({title_suffix})")
    plt.tight_layout()
    plt.savefig(out_png, dpi=150)
    plt.close()

def save_confusion_png(y_true, y_prob, out_png, thr=0.5, title_suffix="Test"):
    y_true = np.asarray(y_true, dtype=np.int64)
    y_prob = np.asarray(y_prob, dtype=np.float64)
    y_pred = (y_prob >= float(thr)).astype(int)
    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])

    plt.figure()
    plt.imshow(cm)
    plt.xticks([0, 1], ["Healthy(0)", "PD(1)"])
    plt.yticks([0, 1], ["Healthy(0)", "PD(1)"])
    plt.xlabel("Predicted")
    plt.ylabel("True")
    plt.title(f"Confusion Matrix ({title_suffix}, thr={thr:.2f})")
    for (i, j), v in np.ndenumerate(cm):
        plt.text(j, i, str(v), ha="center", va="center")
    plt.tight_layout()
    plt.savefig(out_png, dpi=150)
    plt.close()

# -------------------------
# Fairness: FNR per sex and ΔFNR = FNR(F) - FNR(M)
# Inputs: y_true, y_prob, sex_norm array, threshold
# Output: per-group FNR details + signed and absolute gap
# -------------------------
def compute_fnr_by_group_signed(y_true, y_prob, groups, thr=0.5):
    y_true = np.asarray(y_true, dtype=np.int64)
    y_prob = np.asarray(y_prob, dtype=np.float64)
    groups = np.asarray(list(groups), dtype=object)
    y_pred = (y_prob >= float(thr)).astype(np.int64)

    out = {}
    for g in sorted(set(groups.tolist())):
        mask_g = (groups == g)
        if int(mask_g.sum()) == 0:
            continue

        pos_mask = mask_g & (y_true == 1)
        n_pos = int(pos_mask.sum())
        if n_pos == 0:
            out[g] = {"n_total": int(mask_g.sum()), "n_pos": 0, "tp": 0, "fn": 0, "fnr": float("nan")}
            continue

        tp = int(((y_pred == 1) & pos_mask).sum())
        fn = int(((y_pred == 0) & pos_mask).sum())
        fnr = float(fn / max(1, (fn + tp)))

        out[g] = {"n_total": int(mask_g.sum()), "n_pos": int(n_pos), "tp": int(tp), "fn": int(fn), "fnr": float(fnr)}

    fnr_m = out.get("M", {}).get("fnr", float("nan"))
    fnr_f = out.get("F", {}).get("fnr", float("nan"))

    if (not np.isnan(fnr_m)) and (not np.isnan(fnr_f)):
        delta_signed = float(fnr_f - fnr_m)
        delta_abs = float(abs(delta_signed))
    else:
        delta_signed = float("nan")
        delta_abs = float("nan")

    return out, delta_signed, delta_abs

def compute_confusion_counts(y_true, y_prob, thr=0.5):
    y_true = np.asarray(y_true, dtype=np.int64)
    y_prob = np.asarray(y_prob, dtype=np.float64)
    y_pred = (y_prob >= float(thr)).astype(np.int64)
    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
    TN, FP, FN, TP = int(cm[0, 0]), int(cm[0, 1]), int(cm[1, 0]), int(cm[1, 1])
    return {"TN": TN, "FP": FP, "FN": FN, "TP": TP}

def compute_confusion_by_group(y_true, y_prob, groups, thr=0.5):
    y_true = np.asarray(y_true, dtype=np.int64)
    y_prob = np.asarray(y_prob, dtype=np.float64)
    groups = np.asarray(list(groups), dtype=object)
    out = {}
    for g in sorted(set(groups.tolist())):
        mask = (groups == g)
        if int(mask.sum()) == 0:
            continue
        out[g] = {"n": int(mask.sum()), "confusion": compute_confusion_counts(y_true[mask], y_prob[mask], thr=thr)}
    return out

# -------------------------
# Reproducibility: set all RNG seeds
# Inputs: seed int
# Output: deterministic settings for random, numpy, torch
# -------------------------
def set_all_seeds(seed: int):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# -------------------------
# Source checkpoint selection: pick most recent D4 trainval exp with all seeds
# Inputs: D4 trainval_runs folder
# Output: chosen_exp folder used to load best_heads.pt per seed
# -------------------------
D4_TRAINVAL_ROOT = Path(D4_OUT_ROOT) / "trainval_runs"
if not D4_TRAINVAL_ROOT.exists():
    raise FileNotFoundError(f"Missing D4 trainval_runs folder: {str(D4_TRAINVAL_ROOT)}")

exp_dirs = sorted([p for p in D4_TRAINVAL_ROOT.glob("exp_*") if p.is_dir()], key=lambda p: p.stat().st_mtime, reverse=True)
if not exp_dirs:
    raise FileNotFoundError(f"No exp_* folders found under: {str(D4_TRAINVAL_ROOT)}")

def _has_all_seeds(exp_path: Path, seeds: list):
    for s in seeds:
        p = exp_path / f"run_D4_seed{s}" / "best_heads.pt"
        if not p.exists():
            return False
    return True

chosen_exp = None
for ed in exp_dirs:
    if _has_all_seeds(ed, SEEDS):
        chosen_exp = ed
        break

if chosen_exp is None:
    sample = exp_dirs[0]
    raise FileNotFoundError(
        f"Could not find a D4 trainval experiment with best_heads.pt for all {len(SEEDS)} seeds.\n"
        f"Expected: {str(D4_TRAINVAL_ROOT)}/exp_*/run_D4_seedXXXX/best_heads.pt\n"
        f"Most recent exp checked: {str(sample)}"
    )

print("\nUsing SOURCE (D4) Train+Val experiment folder:")
print(" ", str(chosen_exp))

# -------------------------
# Output folder: write only under D1
# Inputs: D1_OUT_ROOT
# Output: ZS_ROOT folder for per-seed runs + summary files
# -------------------------
ZS_ROOT = Path(D1_OUT_ROOT) / "Cross_Language_Zero_Shot_Runs"
ZS_ROOT.mkdir(parents=True, exist_ok=True)

# -------------------------
# DataLoader: D1 TEST only (plus a small warm-up read)
# Inputs: test_df and audio files
# Output: test_loader ready for inference
# -------------------------
test_ds = AudioManifestDataset(test_df)
test_loader = DataLoader(
    test_ds,
    batch_size=PER_DEVICE_BS,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=PIN_MEMORY,
    collate_fn=collate_fn,
)

print("\nWarm-up: loading up to 2 TEST batches...")
t0 = time.time()
nb = len(test_loader)
wb = min(2, nb)
if wb == 0:
    raise RuntimeError("TEST DataLoader has 0 batches. Check test_df length and PER_DEVICE_BS.")
it = iter(test_loader)
for i in range(wb):
    _ = next(it)
    print(f"  loaded TEST warmup batch {i+1}/{wb}")
print(f"Warm-up done in {time.time()-t0:.2f}s")

# -------------------------
# Checkpoint loader: load only the trained heads from best_heads.pt
# Inputs: model instance + best_heads.pt path
# Output: model with head weights restored for this seed
# -------------------------
def load_heads_into_model(model: Wav2Vec2TwoHeadClassifier, best_heads_path: Path):
    if not best_heads_path.exists():
        raise FileNotFoundError(f"Missing best_heads.pt: {str(best_heads_path)}")
    state = torch.load(str(best_heads_path), map_location="cpu")
    needed = ["pre_vowel", "pre_other", "head_vowel", "head_other"]
    missing = [k for k in needed if k not in state]
    if missing:
        raise KeyError(
            f"best_heads.pt missing keys {missing}. Found keys: {list(state.keys())}. "
            "This zero-shot code expects the trainval save format."
        )
    model.pre_vowel.load_state_dict(state["pre_vowel"], strict=True)
    model.pre_other.load_state_dict(state["pre_other"], strict=True)
    model.head_vowel.load_state_dict(state["head_vowel"], strict=True)
    model.head_other.load_state_dict(state["head_other"], strict=True)
    return model

# -------------------------
# Inference: run model over D1 test_loader and collect probabilities
# Inputs: loader + model
# Output: arrays (y_true, y_prob, sex_norm) for metrics and plots
# -------------------------
def run_inference(loader, model, desc):
    use_amp = bool(USE_AMP and DEVICE.type == "cuda")
    all_probs, all_true, all_sex = [], [], []

    pbar = tqdm(loader, desc=desc, dynamic_ncols=True)
    with torch.inference_mode():
        for batch in pbar:
            input_values = batch["input_values"].to(DEVICE, non_blocking=False)
            attention_mask = batch["attention_mask"].to(DEVICE, non_blocking=False)
            labels = batch["labels"].to(DEVICE, non_blocking=False)
            task_group = batch["task_group"]
            sex_norm = batch["sex_norm"]

            with torch.amp.autocast(device_type="cuda", enabled=use_amp):
                logits = model.forward_logits(input_values, attention_mask, task_group)

            probs = torch.softmax(logits.float(), dim=-1)[:, 1].detach().cpu().numpy().astype(np.float64)
            all_probs.extend(probs.tolist())
            all_true.extend(labels.detach().cpu().numpy().astype(np.int64).tolist())
            all_sex.extend(list(sex_norm))

    return np.asarray(all_true, dtype=np.int64), np.asarray(all_probs, dtype=np.float64), np.asarray(all_sex, dtype=object)

# -------------------------
# Per-seed run: load D4 heads, evaluate D1 at FIXED_THR, save metrics + PNGs
# Inputs: seed, chosen_exp, FIXED_THR, D1 test_loader
# Output: run_<TRANSFER_TAG>_seedXXXX folder with metrics.json + plots
# -------------------------
def run_seed(seed: int):
    set_all_seeds(seed)

    run_dir = ZS_ROOT / f"run_{TRANSFER_TAG}_seed{seed}"
    run_dir.mkdir(parents=True, exist_ok=True)

    best_heads_path = chosen_exp / f"run_D4_seed{seed}" / "best_heads.pt"

    print(f"\n[seed={seed}] Loading SOURCE heads from:")
    print(" ", str(best_heads_path))
    print(f"[seed={seed}] Evaluating TARGET D1 TEST @ fixed_thr={FIXED_THR:.6f}")

    model = Wav2Vec2TwoHeadClassifier(BACKBONE_CKPT, dropout_p=DROPOUT_P).to(DEVICE)
    model = load_heads_into_model(model, best_heads_path)
    model.eval()

    yt, pt, st = run_inference(test_loader, model, desc=f"[seed={seed}] D1 TEST (zero-shot)")

    test_auc = compute_auc(yt, pt)  # threshold-free
    thr_metrics = compute_threshold_metrics(yt, pt, thr=FIXED_THR)

    fnr_by_sex, delta_f_minus_m, delta_abs = compute_fnr_by_group_signed(yt, pt, st, thr=FIXED_THR)
    confusion_by_sex = compute_confusion_by_group(yt, pt, st, thr=FIXED_THR)

    # Plots: always save overall ROC + confusion matrix
    roc_png = run_dir / "roc_curve.png"
    cm_png  = run_dir / "confusion_matrix.png"
    save_roc_curve_png(yt, pt, str(roc_png), title_suffix=f"D1 Test (seed={seed})")
    save_confusion_png(yt, pt, str(cm_png), thr=FIXED_THR, title_suffix=f"D1 Test (seed={seed})")

    # Sex-specific confusion matrices (only if that sex exists in D1 TEST)
    cm_m_png = None
    cm_f_png = None
    mask_m = (st == "M")
    mask_f = (st == "F")

    if int(mask_m.sum()) > 0:
        cm_m_png = run_dir / "confusion_matrix_M.png"
        save_confusion_png(yt[mask_m], pt[mask_m], str(cm_m_png), thr=FIXED_THR, title_suffix=f"D1 Test SEX=M (seed={seed})")

    if int(mask_f.sum()) > 0:
        cm_f_png = run_dir / "confusion_matrix_F.png"
        save_confusion_png(yt[mask_f], pt[mask_f], str(cm_f_png), thr=FIXED_THR, title_suffix=f"D1 Test SEX=F (seed={seed})")

    # Structured metrics saved for later analysis/paper tables
    metrics = {
        "transfer": {
            "tag": TRANSFER_TAG,
            "source_dataset": "D4",
            "target_dataset": "D1",
            "threshold_policy": "Fixed threshold = mean VAL-opt threshold from D4 monolingual test summary (Youden J on D4 VAL)",
            "fixed_threshold_value": float(FIXED_THR),
            "d4_mono_summary_path": str(D4_MONO_SUMMARY),
        },

        "target": {
            "dataset": "D1",
            "seed": int(seed),
            "n_test": int(len(test_df)),
            "label_counts_test": test_df["label_num"].value_counts(dropna=False).to_dict(),
            "sex_counts_test_norm": test_df["sex_norm"].value_counts(dropna=False).to_dict(),
        },

        "test_auroc": float(test_auc),

        "threshold_metrics_test": thr_metrics,
        "test_threshold_used": float(FIXED_THR),

        "fairness_test": {
            "definition": "H3: ΔFNR = FNR(F) - FNR(M), where FNR(sex)=FN/(FN+TP) computed on PD-only true labels at fixed transfer threshold.",
            "fnr_by_sex_norm": fnr_by_sex,
            "delta_fnr_F_minus_M": float(delta_f_minus_m),
            "delta_fnr_abs": float(delta_abs),
            "note": "If n_PD for a sex is 0, its FNR is NaN and ΔFNR is NaN.",
            "sex_normalization_note": "D1 mapping: 0->F, 1->M; otherwise UNK.",
        },

        "confusion_by_sex_norm": confusion_by_sex,

        "artifacts": {
            "roc_curve_png": str(roc_png),
            "confusion_matrix_png": str(cm_png),
            "confusion_matrix_M_png": str(cm_m_png) if cm_m_png is not None else None,
            "confusion_matrix_F_png": str(cm_f_png) if cm_f_png is not None else None,
        },

        "paths": {
            "d1_out_root": str(D1_OUT_ROOT),
            "zero_shot_root": str(ZS_ROOT),
            "run_dir": str(run_dir),
            "source_trainval_experiment_used": str(chosen_exp),
            "source_best_heads_path": str(best_heads_path),
        },

        "backbone_ckpt": BACKBONE_CKPT,
        "dropout_p": float(DROPOUT_P),
        "timestamp": time.strftime("%Y%m%d_%H%M%S"),
    }

    metrics_path = run_dir / "metrics.json"
    with open(metrics_path, "w", encoding="utf-8") as f:
        json.dump(metrics, f, indent=2)

    print(f"[seed={seed}] DONE | test_AUROC={test_auc:.6f}")

    def _fmt_fnr(g):
        d = fnr_by_sex.get(g, None)
        if d is None:
            return "n/a"
        return f"fnr={d['fnr']:.6f} (n_PD={d['n_pos']}, fn={d['fn']}, tp={d['tp']})"

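    print(
        f"[seed={seed}] TEST metrics @ fixed_thr={FIXED_THR:.6f}: "
        f"acc={thr_metrics['accuracy']:.6f} sens={thr_metrics['sensitivity']:.6f} "
        f"spec={thr_metrics['specificity']:.6f} f1={thr_metrics['f1_score']:.6f} mcc={thr_metrics['mcc']:.6f}"
    )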
    print(f"[seed={seed}] FAIRNESS (H3) @ fixed_thr={FIXED_THR:.6f}:")
    print("  M:", _fmt_fnr("M"))
    print("  F:", _fmt_fnr("F"))
    if "UNK" in fnr_by_sex:
        print("  UNK:", _fmt_fnr("UNK"))
    print("  ΔFNR (F-M):", f"{delta_f_minus_m:.6f}" if not np.isnan(delta_f_minus_m) else "nan")
    print("  |ΔFNR|:", f"{delta_abs:.6f}" if not np.isnan(delta_abs) else "nan")

    print(f"[seed={seed}] WROTE:")
    print(" ", str(metrics_path))
    print(" ", str(roc_png))
    print(" ", str(cm_png))
    if cm_m_png is not None:
        print(" ", str(cm_m_png))
    if cm_f_png is not None:
        print(" ", str(cm_f_png))

    return {
        "seed": int(seed),
        "test_auroc": float(test_auc),
        "thr_metrics_test": thr_metrics,
        "fnr_by_sex_test": fnr_by_sex,
        "delta_f_minus_m_test": float(delta_f_minus_m),
        "delta_abs_test": float(delta_abs),
        "run_dir": str(run_dir),
    }

# -------------------------
# Aggregate across 3 seeds: AUROC mean ± 95% CI, others mean ± SD
# Inputs: per-seed outputs from run_seed()
# Output: summary_zero_shot.json + history_index.jsonl under ZS_ROOT
# -------------------------
seed_results = []
for seed in SEEDS:
    seed_results.append(run_seed(seed))

aurocs = [r["test_auroc"] for r in seed_results]
t_crit = 4.302652729911275  # two-sided 95% t critical value for df = n - 1 = 2; valid only when len(SEEDS) == 3
n = len(aurocs)
mean_auc = float(np.mean(aurocs))
std_auc = float(np.std(aurocs, ddof=1)) if n > 1 else 0.0
half_width = float(t_crit * (std_auc / math.sqrt(n))) if n > 1 else 0.0
ci95 = [float(mean_auc - half_width), float(mean_auc + half_width)]

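def _mean_sd(vals):
    # Mean and sample SD (ddof=1) over non-NaN entries; SD is 0.0 when fewer than two values are available.
    vals = np.asarray(vals, dtype=np.float64)
    n_valid = int(np.sum(~np.isnan(vals)))
    mean = float(np.nanmean(vals)) if n_valid > 0 else float("nan")
    sd = float(np.nanstd(vals, ddof=1)) if n_valid > 1 else 0.0
    return mean, sd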

thr_list = [r["thr_metrics_test"] for r in seed_results]
keys = ["accuracy","precision","recall","f1_score","sensitivity","specificity","mcc","p_value_fisher_two_sided"]

agg = {}
for k in keys:
    v = [float(tm.get(k, float("nan"))) for tm in thr_list]
    mu, sd = _mean_sd(v)
    agg[k] = {
        "mean": mu,
        "sd": sd,
        "values_by_seed": {str(r["seed"]): float(tm.get(k, float("nan"))) for r, tm in zip(seed_results, thr_list)}
    }

cm_by_seed = {str(r["seed"]): r["thr_metrics_test"]["confusion_matrix"] for r in seed_results}

# FAIRNESS aggregation (H3)
fnr_by_seed = {str(r["seed"]): r["fnr_by_sex_test"] for r in seed_results}
delta_signed_by_seed = {str(r["seed"]): float(r["delta_f_minus_m_test"]) for r in seed_results}
delta_abs_by_seed = {str(r["seed"]): float(r["delta_abs_test"]) for r in seed_results}

fnr_m_vals, fnr_f_vals, n_pd_m_vals, n_pd_f_vals = [], [], [], []
d_signed_vals, d_abs_vals = [], []

for r in seed_results:
    s = str(r["seed"])
    d = fnr_by_seed.get(s, {})
    fnr_m_vals.append(float(d.get("M", {}).get("fnr", float("nan"))))
    fnr_f_vals.append(float(d.get("F", {}).get("fnr", float("nan"))))
    n_pd_m_vals.append(float(d.get("M", {}).get("n_pos", float("nan"))))
    n_pd_f_vals.append(float(d.get("F", {}).get("n_pos", float("nan"))))
    d_signed_vals.append(float(delta_signed_by_seed.get(s, float("nan"))))
    d_abs_vals.append(float(delta_abs_by_seed.get(s, float("nan"))))

fnr_m_mean, fnr_m_sd = _mean_sd(fnr_m_vals)
fnr_f_mean, fnr_f_sd = _mean_sd(fnr_f_vals)
d_signed_mean, d_signed_sd = _mean_sd(d_signed_vals)
d_abs_mean, d_abs_sd = _mean_sd(d_abs_vals)

print("\nZero-shot TEST AUROC by seed:")
for r in seed_results:
    print(f"  seed {r['seed']}: {r['test_auroc']:.6f}")
print(f"\nMean Zero-shot TEST AUROC: {mean_auc:.6f}")
print(f"95% CI (t, n=3): [{ci95[0]:.6f}, {ci95[1]:.6f}]")

print(f"\nFixed threshold used for ALL seeds (from D4 mono mean VAL-opt): {FIXED_THR:.6f}")

print("\nThreshold metrics on D1 TEST @ fixed threshold (mean ± SD across seeds):")
for k in ["accuracy","precision","sensitivity","specificity","f1_score","mcc"]:
    mu = agg[k]["mean"]
    sd = agg[k]["sd"]
    print(f"  {k}: {mu:.6f} ± {sd:.6f}")
print("  fisher_p_value_two_sided:", f"{agg['p_value_fisher_two_sided']['mean']:.6g} ± {agg['p_value_fisher_two_sided']['sd']:.6g}")

print("\nFAIRNESS (H3) on D1 TEST @ fixed threshold (mean ± SD across seeds):")
print(f"  FNR_M: {fnr_m_mean:.6f} ± {fnr_m_sd:.6f}")
print(f"  FNR_F: {fnr_f_mean:.6f} ± {fnr_f_sd:.6f}")
print(f"  ΔFNR (F-M): {d_signed_mean:.6f} ± {d_signed_sd:.6f}")
print(f"  |ΔFNR|: {d_abs_mean:.6f} ± {d_abs_sd:.6f}")

summary = {
    "transfer": {
        "tag": TRANSFER_TAG,
        "source_dataset": "D4",
        "target_dataset": "D1",
        "fixed_threshold_value": float(FIXED_THR),
        "fixed_threshold_source": {
            "path": str(D4_MONO_SUMMARY),
            "key": "val_opt_threshold_mean_sd.mean",
            "note": "Mean across seeds of D4 VAL-opt thresholds (Youden J) as recorded by D4 monolingual test run.",
        },
        "source_trainval_experiment_used": str(chosen_exp),
    },

    "target": {
        "dx_out_root": str(D1_OUT_ROOT),
        "manifest_all": str(D1_MANIFEST_ALL),
        "n_test": int(len(test_df)),
        "label_counts_test": test_df["label_num"].value_counts(dropna=False).to_dict(),
        "sex_counts_test_norm": test_df["sex_norm"].value_counts(dropna=False).to_dict(),
    },

    "seeds": SEEDS,

    "test_aurocs_by_seed": {str(r["seed"]): float(r["test_auroc"]) for r in seed_results},
    "mean_test_auroc": float(mean_auc),
    "t_crit_df2_95": float(t_crit),
    "half_width_95_ci": float(half_width),
    "ci95_test_auroc": ci95,

    "threshold_metrics_test_mean_sd": agg,
    "confusion_matrix_by_seed": cm_by_seed,

    "fairness_test": {
        "definition": "H3: ΔFNR = FNR(F) - FNR(M), where FNR(sex)=FN/(FN+TP) computed on PD-only true labels at the fixed transfer threshold.",
        "fnr_by_sex_norm_by_seed": fnr_by_seed,
        "delta_fnr_F_minus_M_by_seed": delta_signed_by_seed,
        "delta_fnr_abs_by_seed": delta_abs_by_seed,
        "fnr_M_mean_sd": {"mean": float(fnr_m_mean), "sd": float(fnr_m_sd), "values_by_seed": {str(s): float(v) for s, v in zip(SEEDS, fnr_m_vals)}},
        "fnr_F_mean_sd": {"mean": float(fnr_f_mean), "sd": float(fnr_f_sd), "values_by_seed": {str(s): float(v) for s, v in zip(SEEDS, fnr_f_vals)}},
        "delta_fnr_F_minus_M_mean_sd": {"mean": float(d_signed_mean), "sd": float(d_signed_sd), "values_by_seed": {str(s): float(v) for s, v in zip(SEEDS, d_signed_vals)}},
        "delta_fnr_abs_mean_sd": {"mean": float(d_abs_mean), "sd": float(d_abs_sd), "values_by_seed": {str(s): float(v) for s, v in zip(SEEDS, d_abs_vals)}},
        "denominators_PD_by_seed": {str(s): {"n_PD_M": float(n_pd_m_vals[i]), "n_PD_F": float(n_pd_f_vals[i])} for i, s in enumerate(SEEDS)},
        "sex_normalization_note": "D1 mapping: 0->F, 1->M; otherwise UNK. ΔFNR computed only when both M and F have defined FNR.",
    },

    "run_dirs": [r["run_dir"] for r in seed_results],
    "timestamp": time.strftime("%Y%m%d_%H%M%S"),
}

summary_path = ZS_ROOT / "summary_zero_shot.json"
with open(summary_path, "w", encoding="utf-8") as f:
    json.dump(summary, f, indent=2)

history_path = ZS_ROOT / "history_index.jsonl"
with open(history_path, "a", encoding="utf-8") as f:
    f.write(json.dumps(summary) + "\n")

print("\nWROTE summary:", str(summary_path))
print("APPENDED history index:", str(history_path))
print("Open this folder to access artifacts:", str(ZS_ROOT))

# -------------------------
# Stop runtime (Colab)
# Inputs: none
# Output: attempts to release the GPU instance
# -------------------------
print("\nAll done. Unassigning the runtime to stop the L4 instance...")
try:
    from google.colab import runtime  # type: ignore
    print("Calling runtime.unassign() now.")
    runtime.unassign()
except Exception as e:
    print("Could not unassign runtime automatically. You can stop the runtime manually in Colab.")
    print("Reason:", repr(e))

The following cell runs a zero-shot transfer test from D2 (Slovak) to D5 (English): it measures how well model heads trained on D2 perform on D5 test clips without any retraining on D5. The cell reads `manifest_all.csv` from D5, keeps only rows marked as `test`, verifies that all referenced audio files exist, and standardizes two fields used later in reporting: `task_group`, which separates vowel clips from other clips based on the `task` value, and `sex_norm`, which is set to M, F, or UNK.

The classification threshold is not tuned on D5 data. Instead, the cell loads a single fixed threshold from D2’s saved `monolingual_test_runs/summary_test.json`. It first looks for the newer key `threshold_selection.val_selected_threshold_mean_sd.mean` and then tries the older summary formats in priority order. If no usable value is found, it falls back to reading a stored threshold from the `best_heads.pt` checkpoints of the most recent D2 trainval experiment and averaging it across seeds.

The model setup matches the training configuration: a frozen Wav2Vec2 backbone with two small task-specific heads, one for vowel clips and one for other speech. The cell automatically finds the most recent D2 trainval experiment folder that contains `best_heads.pt` for all three seeds (1337, 2024, and 7777). For each seed, it loads the saved D2 heads, runs inference on the full D5 test set with progress bars and a short warm-up, and produces a Parkinson’s disease probability for each clip. It then computes AUROC as a threshold-free measure; calculates threshold-based metrics at the fixed threshold, including accuracy, precision, sensitivity (recall), specificity, F1 score, MCC, and the Fisher exact test p-value; and computes a fairness metric defined as ΔFNR = FNR(F) - FNR(M). The false negative rate, FNR = FN / (FN + TP), is calculated using only true Parkinson’s cases.
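
As a small illustration of the fairness metric described above, the sketch below computes FNR per sex and ΔFNR on made-up toy data. The helper name `fnr_by_group` and all values are hypothetical and are not taken from the notebook or its datasets; the notebook cell itself uses NumPy arrays for the same calculation.

```python
def fnr_by_group(y_true, y_pred, groups):
    """Return {group: FNR}, with FNR = FN / (FN + TP) over true PD rows only."""
    out = {}
    for g in sorted(set(groups)):
        rows = [(t, p) for t, p, s in zip(y_true, y_pred, groups) if s == g and t == 1]
        fn = sum(1 for _, p in rows if p == 0)
        tp = sum(1 for _, p in rows if p == 1)
        # FNR is undefined (NaN) when a group has no true PD cases.
        out[g] = fn / (fn + tp) if (fn + tp) > 0 else float("nan")
    return out

# Toy data (illustrative only): label 1 = PD, 0 = healthy control.
y_true = [1, 1, 1, 1, 1, 1, 0, 0]
y_prob = [0.9, 0.4, 0.7, 0.2, 0.6, 0.8, 0.1, 0.6]
sex    = ["F", "F", "F", "F", "M", "M", "M", "F"]

thr = 0.5  # assumed fixed threshold for this toy example
y_pred = [int(p >= thr) for p in y_prob]

fnr = fnr_by_group(y_true, y_pred, sex)
delta_fnr = fnr["F"] - fnr["M"]  # signed gap, F minus M
print(fnr, delta_fnr)
```

In this toy case two of the four female PD clips fall below the threshold while both male PD clips are detected, so the signed gap is positive. Healthy-control rows, including the false positive, do not enter the FNR at all.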

For each seed, results are saved under a timestamped folder inside D5 at
`<D5_OUT_ROOT>/zeroshot_transfer_runs/D2_to_D5_<timestamp>/run_D5_seedXXXX/`.
Saved outputs include `metrics.json`, a ROC curve plot, an overall confusion matrix plot, and confusion matrices split by sex when enough data is available.

After all three seeds complete, the cell combines results across seeds. It reports the mean AUROC with a 95% confidence interval based on a t distribution with n = 3, computes the mean and standard deviation of the threshold-based metrics, and summarizes fairness with the mean and standard deviation of FNR by sex and of ΔFNR. A combined `summary_transfer.json` file and a `history_index.jsonl` file are written in the transfer run root. Finally, the Colab runtime is unassigned to stop the GPU session.
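
The confidence interval described above can be sketched as follows. This is a minimal standalone example, assuming exactly three per-seed AUROC values; the function name `mean_ci95_three_seeds` and the AUROC values are made up for illustration and are not results from this study.

```python
import math
import statistics

# Two-sided 95% critical value of Student's t with df = n - 1 = 2.
T_CRIT_DF2_95 = 4.302652729911275

def mean_ci95_three_seeds(values):
    """Return (mean, half_width) of a t-based 95% CI for exactly three values."""
    if len(values) != 3:
        raise ValueError("The hard-coded critical value is only valid for n = 3.")
    mean = statistics.fmean(values)
    sd = statistics.stdev(values)  # sample SD (n - 1 in the denominator)
    half_width = T_CRIT_DF2_95 * sd / math.sqrt(len(values))
    return mean, half_width

# Hypothetical per-seed AUROCs (illustrative only).
aurocs = [0.70, 0.74, 0.78]
mean_auc, half_width = mean_ci95_three_seeds(aurocs)
print(f"mean={mean_auc:.4f}, 95% CI=[{mean_auc - half_width:.4f}, {mean_auc + half_width:.4f}]")
```

With only three seeds the critical value is about 4.30 rather than the 1.96 of the normal approximation, so the interval is wide even when the per-seed values are close together.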

In [None]:
# Zero-Shot Transfer: D2 Heads on D5 Test (SK → EN)
# Loads the most recent D2 trainval heads (3 seeds), applies a single fixed threshold from D2 monolingual testing,
# and evaluates only the D5 test split with overall metrics and sex-based fairness (ΔFNR).

# ============================================================
# ZERO-SHOT TRANSFER TEST (CRASH-PROOF, WITH PROGRESS + STORED METRICS, SK to EN)
# - TARGET: D5 TEST ONLY using <D5_OUT_ROOT>/manifests/manifest_all.csv (test split only)
# - SOURCE: D2 trainval heads from MOST RECENT trainval experiment under:
#     <D2_OUT_ROOT>/trainval_runs/exp_*/run_D2_seed{seed}/best_heads.pt
# - FIXED THRESHOLD: read from D2 monolingual summary_test.json
#     preferred key: threshold_selection.val_selected_threshold_mean_sd.mean
#     (also supports older key variants + checkpoint fallback val_threshold)
# - Evaluates 3 seeds separately (1337, 2024, 7777)
# - Reports: mean Target Test AUROC ± 95% CI (t, n=3)
# - Reports: threshold metrics at FIXED threshold (mean ± SD across seeds)
# - Fairness (H3): ΔFNR = FNR(F) - FNR(M) at FIXED threshold (per seed + aggregated)
# - Saves confusion charts overall + by sex at FIXED threshold
# - Writes ONLY under:
#     <D5_OUT_ROOT>/zeroshot_transfer_runs/D2_to_D5_<timestamp>/run_D5_seedXXXX/
#   plus summary_transfer.json and history_index.jsonl in that transfer root
# - Unassigns runtime at end (L4)
# ============================================================

import os, json, math, random, time, warnings
from pathlib import Path

import numpy as np
import pandas as pd
import soundfile as sf

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

from transformers import Wav2Vec2Model

from sklearn.metrics import roc_auc_score, roc_curve, confusion_matrix, matthews_corrcoef
from scipy.stats import fisher_exact
import matplotlib.pyplot as plt

from tqdm.auto import tqdm


# -------------------------
# Mount Drive (Colab)
# Inputs: Colab runtime state
# Output: /content/drive/MyDrive available for reading manifests and writing results
# -------------------------
if not os.path.isdir("/content/drive/MyDrive"):
    from google.colab import drive  # type: ignore
    drive.mount("/content/drive")


# -------------------------
# Target and source roots
# Inputs: D5 preprocessed root (target) and D2 preprocessed root (source)
# Output: key file paths (D5 manifest_all.csv, D2 monolingual summary_test.json)
# -------------------------
D5_OUT_ROOT = "/content/drive/MyDrive/AI_PD_Project/Datasets/D5-English (MDVR-KCL)/preprocessed_v2"
D5_MANIFEST_ALL = f"{D5_OUT_ROOT}/manifests/manifest_all.csv"

D2_SOURCE_ROOT = "/content/drive/MyDrive/AI_PD_Project/Datasets/D2-Slovak (EWA-DB)/EWA-DB/"
D2_OUT_ROOT = "/content/drive/MyDrive/AI_PD_Project/Datasets/D2-Slovak (EWA-DB)/EWA-DB/preprocessed_v1"
D2_MONO_SUMMARY = Path(D2_OUT_ROOT) / "monolingual_test_runs" / "summary_test.json"


# -------------------------
# Fixed runtime settings
# Inputs: constants for evaluation
# Output: consistent inference setup across seeds (sample rate, batch size, AMP, device)
# -------------------------
SEEDS          = [1337, 2024, 7777]
BACKBONE_CKPT  = "facebook/wav2vec2-base"   # fallback if ckpt does not override
SR_EXPECTED    = 16000
TINY_THRESH    = 1e-4

EFFECTIVE_BS   = 64
PER_DEVICE_BS  = 16
GRAD_ACCUM     = max(1, EFFECTIVE_BS // PER_DEVICE_BS)  # printed only

NUM_WORKERS    = 0
PIN_MEMORY     = False

USE_AMP        = True

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
try:
    torch.set_float32_matmul_precision("high")
except Exception:
    pass

warnings.filterwarnings("ignore", message="None of the inputs have requires_grad=True. Gradients will be None")


print("D5_OUT_ROOT (target):", D5_OUT_ROOT)
print("D5_MANIFEST_ALL:", D5_MANIFEST_ALL)
print("D2_SOURCE_ROOT (source root):", D2_SOURCE_ROOT)
print("D2_OUT_ROOT (source preprocessed):", D2_OUT_ROOT)
print("D2_MONO_SUMMARY:", str(D2_MONO_SUMMARY))
print("DEVICE:", DEVICE)
if DEVICE.type == "cuda":
    print("GPU:", torch.cuda.get_device_name(0))
print("PER_DEVICE_BS:", PER_DEVICE_BS, "| EFFECTIVE_BS:", PER_DEVICE_BS * GRAD_ACCUM)
print("NUM_WORKERS:", NUM_WORKERS, "| PIN_MEMORY:", PIN_MEMORY)
print("USE_AMP:", bool(USE_AMP and DEVICE.type == "cuda"))


# -------------------------
# Fixed threshold loader (from D2 monolingual summary)
# Inputs: D2 summary_test.json (new or older formats)
# Outputs: single scalar FIXED_THR plus a string describing which key was used
# -------------------------
def _dig(d, path):
    cur = d
    for k in path:
        if not isinstance(cur, dict) or k not in cur:
            return None
        cur = cur[k]
    return cur

def _as_finite_float(x):
    try:
        v = float(x)
        return v if np.isfinite(v) else None
    except Exception:
        return None

def _load_fixed_thr_from_summary(summary: dict):
    # Candidate paths are checked in priority order to support older runs.
    candidates = [
        (["threshold_selection", "val_selected_threshold_mean_sd", "mean"], "threshold_selection.val_selected_threshold_mean_sd.mean"),
        (["threshold_selection", "val_selected_threshold_by_seed"], "threshold_selection.val_selected_threshold_by_seed (mean)"),

        (["val_selected_threshold_mean_sd", "mean"], "val_selected_threshold_mean_sd.mean"),
        (["val_selected_threshold_by_seed"], "val_selected_threshold_by_seed (mean)"),

        (["val_opt_threshold_mean_sd", "mean"], "val_opt_threshold_mean_sd.mean"),
        (["val_opt_threshold_by_seed"], "val_opt_threshold_by_seed (mean)"),

        (["val_optimal_threshold", "mean_sd", "mean"], "val_optimal_threshold.mean_sd.mean"),
    ]

    for path, key_name in candidates:
        v = _dig(summary, path)
        if v is None:
            continue

        fv = _as_finite_float(v)
        if fv is not None:
            return float(fv), key_name

        # If the field is a dict (seed → threshold), average finite values.
        if isinstance(v, dict) and len(v) > 0:
            vals = []
            for _, vv in v.items():
                fvv = _as_finite_float(vv)
                if fvv is not None:
                    vals.append(fvv)
            if len(vals) > 0:
                return float(np.mean(vals)), key_name

    return None, "NOT_FOUND"

def _find_latest_exp_with_all_seeds(trainval_root: Path, seeds):
    # Picks the most recently modified exp_* that contains best_heads.pt for all seeds.
    exp_dirs = sorted([p for p in trainval_root.glob("exp_*") if p.is_dir()],
                      key=lambda p: p.stat().st_mtime, reverse=True)
    if not exp_dirs:
        raise FileNotFoundError(f"No exp_* folders found under: {trainval_root}")

    def has_all(exp_path: Path):
        for s in seeds:
            p = exp_path / f"run_D2_seed{s}" / "best_heads.pt"
            if not p.exists():
                return False
        return True

    for ed in exp_dirs:
        if has_all(ed):
            return ed

    raise FileNotFoundError(
        f"Could not find an exp_* folder that has all best_heads.pt files for seeds={seeds} under {trainval_root}"
    )

def _fallback_thr_from_best_heads(d2_out_root: str, seeds):
    # Last-resort fallback: read a threshold-like field directly from the checkpoints and average across seeds.
    trainval_root = Path(d2_out_root) / "trainval_runs"
    chosen_exp = _find_latest_exp_with_all_seeds(trainval_root, seeds)

    thr_by_seed = {}
    for s in seeds:
        ckpt_path = chosen_exp / f"run_D2_seed{s}" / "best_heads.pt"
        state = torch.load(str(ckpt_path), map_location="cpu")
        if not isinstance(state, dict):
            continue

        for k in ["val_threshold", "val_thr", "threshold", "thr"]:
            if k in state:
                fv = _as_finite_float(state[k])
                if fv is not None:
                    thr_by_seed[str(s)] = float(fv)
                    break

    if len(thr_by_seed) == 0:
        return None, "NO_THRESHOLD_IN_BEST_HEADS", str(chosen_exp), thr_by_seed

    return float(np.mean(list(thr_by_seed.values()))), "best_heads.pt:val_threshold(mean)", str(chosen_exp), thr_by_seed

if not D2_MONO_SUMMARY.exists():
    raise FileNotFoundError(
        "Missing D2 monolingual summary_test.json.\n"
        f"Expected: {str(D2_MONO_SUMMARY)}\n"
        "Run the D2 monolingual test cell first (the one that writes summary_test.json)."
    )

with open(D2_MONO_SUMMARY, "r", encoding="utf-8") as f:
    d2_summary = json.load(f)

FIXED_THR, THR_SRC_KEY = _load_fixed_thr_from_summary(d2_summary)

if FIXED_THR is None:
    FIXED_THR, THR_SRC_KEY, chosen_exp_fb, thr_by_seed_fb = _fallback_thr_from_best_heads(D2_OUT_ROOT, SEEDS)
    if FIXED_THR is None:
        raise RuntimeError(
            "Could not read a finite fixed threshold from D2 monolingual summary_test.json and fallback failed.\n"
            f"Summary path: {str(D2_MONO_SUMMARY)}\n"
            f"Fallback exp: {chosen_exp_fb}\n"
            f"thr_by_seed: {thr_by_seed_fb}"
        )

print(f"\nFixed transfer threshold (from D2 monolingual mean VAL-selected): {FIXED_THR:.6f}")
print("Threshold source key:", THR_SRC_KEY)


# -------------------------
# Load target manifest (D5) and keep TEST only
# Inputs: D5 manifest_all.csv
# Outputs: test_df table with clip paths, labels, task, sex, and other light metadata
# -------------------------
if not os.path.exists(D5_MANIFEST_ALL):
    raise FileNotFoundError(f"Missing D5 manifest_all.csv: {D5_MANIFEST_ALL}")

m = pd.read_csv(D5_MANIFEST_ALL)

req_cols = {"split", "clip_path", "label_num", "task", "sex", "age"}
missing = [c for c in sorted(req_cols) if c not in m.columns]
if missing:
    raise ValueError(f"D5 manifest missing required columns: {missing}. Found: {list(m.columns)}")

# TEST only (target reporting is test only)
m = m[m["split"].isin(["test"])].copy()
if len(m) == 0:
    raise RuntimeError("After filtering to split == 'test', D5 manifest has 0 rows.")

# dataset_id is used for naming output folders only.
if "dataset" in m.columns and m["dataset"].notna().any():
    dataset_id = str(m["dataset"].value_counts(dropna=True).idxmax())
    m = m[m["dataset"].astype(str) == dataset_id].copy()
else:
    dataset_id = "D5"

keep_cols = ["clip_path", "label_num", "task", "speaker_id", "sex", "age", "duration_sec", "split"]
for c in keep_cols:
    if c not in m.columns:
        m[c] = np.nan
m = m[keep_cols].copy()

test_df = m.reset_index(drop=True)

print(f"\nTarget dataset inferred: {dataset_id}")
print(f"TEST rows: {len(test_df)}")
print("TEST label counts:", test_df["label_num"].value_counts(dropna=False).to_dict())
print("TEST sex counts (raw):", test_df["sex"].value_counts(dropna=False).to_dict())

if len(test_df) == 0:
    raise RuntimeError("Target TEST split has 0 rows.")


# -------------------------
# Fail-fast: target clip files must exist
# Inputs: test_df clip_path values
# Output: early error with a few missing examples (prevents long runs that fail later)
# -------------------------
def _fail_fast_missing_paths(df: pd.DataFrame, name: str):
    missing_paths = []
    for p in tqdm(df["clip_path"].astype(str).tolist(), desc=f"Check {name} clip_path exists", dynamic_ncols=True):
        if not os.path.exists(p):
            missing_paths.append(p)
            if len(missing_paths) >= 10:
                break
    if missing_paths:
        raise FileNotFoundError(f"{name}: missing clip_path(s). Examples: {missing_paths}")

_fail_fast_missing_paths(test_df, "TEST")


# -------------------------
# Task routing (vowel vs other)
# Inputs: manifest task values
# Output: task_group column used to select the correct head at inference
# -------------------------
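def _task_group(task_val) -> str:
    # Only the literal task value "vowel" routes to the vowel head; everything else uses the "other" head.
    return "vowel" if str(task_val) == "vowel" else "other"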

test_df["task_group"] = test_df["task"].apply(_task_group)


# -------------------------
# Sex normalization for fairness and plots
# Inputs: raw sex values from the manifest
# Output: sex_norm in {M, F, UNK} for group metrics and per-sex confusion charts
# -------------------------
def normalize_sex(val) -> str:
    if pd.isna(val):
        return "UNK"
    s = str(val).strip().lower()
    if s in {"m", "male", "man", "masc", "masculine"}:
        return "M"
    if s in {"f", "female", "woman", "fem", "feminine"}:
        return "F"
    return "UNK"

test_df["sex_norm"] = test_df["sex"].apply(normalize_sex)
print("TEST sex counts (normalized):", test_df["sex_norm"].value_counts(dropna=False).to_dict())
if (test_df["sex_norm"] == "UNK").any():
    print("NOTE: D5 v2 'sex' is blank or unmapped for some rows; they were counted as 'UNK' for fairness and sex charts.")


# -------------------------
# Audio dataset and collator (pads to batch max length)
# Inputs: test_df with clip_path, label_num, task_group, sex_norm
# Outputs: batched tensors (input_values, attention_mask, labels) plus task_group and sex_norm lists
# -------------------------
class AudioManifestDataset(Dataset):
    def __init__(self, df: pd.DataFrame):
        self.df = df.reset_index(drop=True)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx: int):
        row = self.df.iloc[idx]
        clip_path = str(row["clip_path"])
        label = int(row["label_num"])
        task_group = str(row["task_group"])
        sex_norm = str(row["sex_norm"])

        if not os.path.exists(clip_path):
            raise FileNotFoundError(f"Missing clip_path: {clip_path}")

        y, sr = sf.read(clip_path, always_2d=False)
        if isinstance(y, np.ndarray) and y.ndim > 1:
            y = np.mean(y, axis=1)
        y = np.asarray(y, dtype=np.float32)

        if int(sr) != SR_EXPECTED:
            raise ValueError(f"Sample rate mismatch (expected {SR_EXPECTED}): {clip_path} has sr={sr}")

        # Attention mask rule:
        # - vowel: mask trailing near-zero padding (prevents scoring silence)
        # - other: keep full attention
        attn = np.ones((len(y),), dtype=np.int64)
        if task_group == "vowel":
            k = -1
            for j in range(len(y) - 1, -1, -1):
                if abs(float(y[j])) > TINY_THRESH:
                    k = j
                    break
            if k >= 0:
                attn[:k+1] = 1
                attn[k+1:] = 0
            else:
                attn[:] = 1
        else:
            attn[:] = 1

        return {
            "input_values": torch.from_numpy(y),
            "attention_mask": torch.from_numpy(attn),
            "labels": torch.tensor(label, dtype=torch.long),
            "task_group": task_group,
            "sex_norm": sex_norm,
        }

def collate_fn(batch):
    """Pads waveforms and attention masks to the longest clip in the batch."""
    max_len = int(max(b["input_values"].numel() for b in batch))
    input_vals, attn_masks, labels, task_groups, sex_norms = [], [], [], [], []
    for b in batch:
        x = b["input_values"]
        a = b["attention_mask"]
        pad = max_len - x.numel()
        if pad > 0:
            x = torch.cat([x, torch.zeros(pad, dtype=x.dtype)], dim=0)
            a = torch.cat([a, torch.zeros(pad, dtype=a.dtype)], dim=0)
        input_vals.append(x)
        attn_masks.append(a)
        labels.append(b["labels"])
        task_groups.append(b["task_group"])
        sex_norms.append(b["sex_norm"])
    return {
        "input_values": torch.stack(input_vals, dim=0),
        "attention_mask": torch.stack(attn_masks, dim=0),
        "labels": torch.stack(labels, dim=0),
        "task_group": task_groups,
        "sex_norm": sex_norms,
    }


# -------------------------
# Transfer model (frozen backbone + two heads)
# Inputs: backbone checkpoint name + dropout
# Output: logits for class 0/1, routing by task_group (vowel vs other)
# -------------------------
class FrozenW2V2TwoHeadTest(nn.Module):
    def __init__(self, ckpt: str, dropout_p: float):
        super().__init__()
        self.backbone = Wav2Vec2Model.from_pretrained(ckpt)
        for p in self.backbone.parameters():
            p.requires_grad = False

        H = int(self.backbone.config.hidden_size)
        self.head_vowel = nn.Sequential(
            nn.LayerNorm(H),
            nn.Dropout(float(dropout_p)),
            nn.Linear(H, 2),
        )
        self.head_other = nn.Sequential(
            nn.LayerNorm(H),
            nn.Dropout(float(dropout_p)),
            nn.Linear(H, 2),
        )

    def masked_mean_pool(self, last_hidden, attn_mask_samples):
        feat_mask = self.backbone._get_feature_vector_attention_mask(last_hidden.shape[1], attn_mask_samples)
        feat_mask = feat_mask.to(last_hidden.device).unsqueeze(-1).type_as(last_hidden)
        masked = last_hidden * feat_mask
        denom = feat_mask.sum(dim=1).clamp(min=1.0)
        return masked.sum(dim=1) / denom

    def _heads_fp32(self, x_any: torch.Tensor, head: nn.Module) -> torch.Tensor:
        # Heads run in float32 to avoid dtype edge cases during AMP inference.
        x = x_any.float()
        if x.is_cuda:
            with torch.amp.autocast(device_type="cuda", enabled=False):
                return head(x)
        return head(x)

    def forward_logits(self, input_values, attention_mask, task_group):
        # Backbone is frozen and runs under no_grad.
        with torch.no_grad():
            out = self.backbone(input_values=input_values, attention_mask=attention_mask)
            last_hidden = out.last_hidden_state

        pooled = self.masked_mean_pool(last_hidden, attention_mask)

        logits_v = self._heads_fp32(pooled, self.head_vowel)
        logits_o = self._heads_fp32(pooled, self.head_other)

        # Route each sample to the correct head.
        is_vowel = torch.tensor([tg == "vowel" for tg in task_group], device=pooled.device, dtype=torch.bool)
        logits = logits_o.clone()
        if is_vowel.any():
            logits[is_vowel] = logits_v[is_vowel]
        return logits


def load_heads_into_model(model: FrozenW2V2TwoHeadTest, best_heads_path: Path, state: dict):
    """
    Loads head weights from best_heads.pt.
    Inputs: checkpoint dict and expected head keys
    Output: model with head_vowel and head_other weights populated
    """
    if not best_heads_path.exists():
        raise FileNotFoundError(f"Missing best_heads.pt: {str(best_heads_path)}")

    if not isinstance(state, dict):
        raise ValueError(f"Unexpected best_heads.pt type: {type(state)}")

    if "head_vowel" in state and "head_other" in state and isinstance(state["head_vowel"], dict) and isinstance(state["head_other"], dict):
        model.head_vowel.load_state_dict(state["head_vowel"], strict=True)
        model.head_other.load_state_dict(state["head_other"], strict=True)
        return model

    for wrap_key in ["state_dict", "model_state_dict", "model"]:
        if wrap_key in state and isinstance(state[wrap_key], dict):
            sd = state[wrap_key]
            # Check for head weights before loading, so a bad checkpoint never partially mutates the model.
            has_head = any(k.startswith("head_vowel.") or k.startswith("head_other.") for k in sd.keys())
            if not has_head:
                raise KeyError(f"Checkpoint wrapper '{wrap_key}' did not contain head keys. First keys: {list(sd.keys())[:25]}")
            # strict=False: a heads-only checkpoint legitimately lacks backbone keys.
            model.load_state_dict(sd, strict=False)
            return model

    raise KeyError(
        "best_heads.pt did not match expected formats.\n"
        f"Top-level keys: {list(state.keys())[:50]}"
    )


# -------------------------
# Metrics and plots
# Inputs: y_true and y_score (probability of class 1)
# Outputs: AUROC, threshold metrics, ROC and confusion PNGs
# -------------------------
def compute_auc(y_true, y_prob):
    y_true = np.asarray(y_true, dtype=np.int64)
    y_prob = np.asarray(y_prob, dtype=np.float64)
    if len(np.unique(y_true)) < 2:
        return float("nan")
    return float(roc_auc_score(y_true, y_prob))

def compute_threshold_metrics(y_true, y_prob, thr):
    # Computes confusion matrix and common classification metrics at a fixed threshold.
    y_true = np.asarray(y_true, dtype=np.int64)
    y_prob = np.asarray(y_prob, dtype=np.float64)
    y_pred = (y_prob >= float(thr)).astype(np.int64)

    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
    TN, FP, FN, TP = int(cm[0, 0]), int(cm[0, 1]), int(cm[1, 0]), int(cm[1, 1])

    eps = 1e-12
    acc = (TP + TN) / max(1, (TP + TN + FP + FN))
    prec = TP / (TP + FP + eps)
    rec = TP / (TP + FN + eps)     # sensitivity
    f1 = 2 * prec * rec / (prec + rec + eps)
    spec = TN / (TN + FP + eps)

    try:
        mcc = float(matthews_corrcoef(y_true, y_pred)) if len(np.unique(y_pred)) > 1 else float("nan")
    except Exception:
        mcc = float("nan")

    try:
        _, pval = fisher_exact([[TN, FP], [FN, TP]], alternative="two-sided")
        pval = float(pval)
    except Exception:
        pval = float("nan")

    return {
        "threshold": float(thr),
        "confusion_matrix": {"TN": TN, "FP": FP, "FN": FN, "TP": TP},
        "accuracy": float(acc),
        "precision": float(prec),
        "recall": float(rec),
        "f1_score": float(f1),
        "sensitivity": float(rec),
        "specificity": float(spec),
        "mcc": float(mcc),
        "p_value_fisher_two_sided": float(pval),
    }

def save_roc_curve_png(y_true, y_prob, out_png, title_suffix="Set"):
    fpr, tpr, _ = roc_curve(y_true, y_prob)
    plt.figure()
    plt.plot(fpr, tpr)
    plt.plot([0, 1], [0, 1], linestyle="--")
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    plt.title(f"ROC Curve ({title_suffix})")
    plt.tight_layout()
    plt.savefig(out_png, dpi=150)
    plt.close()

def save_confusion_png(y_true, y_prob, out_png, thr, title_suffix="Set"):
    y_true = np.asarray(y_true, dtype=np.int64)
    y_prob = np.asarray(y_prob, dtype=np.float64)
    y_pred = (y_prob >= float(thr)).astype(int)
    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
    plt.figure()
    plt.imshow(cm)
    plt.xticks([0, 1], ["Healthy(0)", "PD(1)"])
    plt.yticks([0, 1], ["Healthy(0)", "PD(1)"])
    plt.xlabel("Predicted")
    plt.ylabel("True")
    plt.title(f"Confusion Matrix ({title_suffix}, thr={thr:.4f})")
    for (i, j), v in np.ndenumerate(cm):
        plt.text(j, i, str(v), ha="center", va="center")
    plt.tight_layout()
    plt.savefig(out_png, dpi=150)
    plt.close()


# -------------------------
# Fairness: FNR by sex and ΔFNR (F − M)
# Inputs: labels, scores, sex_norm, fixed threshold
# Outputs: per-group FNR plus ΔFNR and |ΔFNR|
# -------------------------
def compute_fnr_by_group_signed(y_true, y_prob, groups, thr):
    y_true = np.asarray(y_true, dtype=np.int64)
    y_prob = np.asarray(y_prob, dtype=np.float64)
    groups = np.asarray(list(groups), dtype=object)
    y_pred = (y_prob >= float(thr)).astype(np.int64)

    out = {}
    for g in sorted(set(groups.tolist())):
        mask_g = (groups == g)
        if int(mask_g.sum()) == 0:
            continue

        pos_mask = mask_g & (y_true == 1)
        n_pos = int(pos_mask.sum())
        if n_pos == 0:
            out[g] = {"n_total": int(mask_g.sum()), "n_pos": 0, "tp": 0, "fn": 0, "fnr": float("nan")}
            continue

        tp = int(((y_pred == 1) & pos_mask).sum())
        fn = int(((y_pred == 0) & pos_mask).sum())
        fnr = float(fn / max(1, (fn + tp)))

        out[g] = {"n_total": int(mask_g.sum()), "n_pos": int(n_pos), "tp": int(tp), "fn": int(fn), "fnr": float(fnr)}

    fnr_m = out.get("M", {}).get("fnr", float("nan"))
    fnr_f = out.get("F", {}).get("fnr", float("nan"))

    if (not np.isnan(fnr_m)) and (not np.isnan(fnr_f)):
        delta_signed = float(fnr_f - fnr_m)   # H3
        delta_abs = float(abs(delta_signed))
    else:
        delta_signed = float("nan")
        delta_abs = float("nan")

    return out, delta_signed, delta_abs
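

# Worked example of the H3 definition above, as a stdlib-only sketch.
# Assumptions: `_demo_fnr_gap` and any toy inputs passed to it are hypothetical and
# are not used by the pipeline. It only illustrates FNR(sex) = FN / (FN + TP) on
# PD-only rows and the signed gap FNR(F) - FNR(M) at a fixed threshold.
def _demo_fnr_gap(labels, scores, sexes, thr):
    fnr = {}
    for g in ("M", "F"):
        # Keep only true-PD rows of this sex; a miss (FN) is a score below thr.
        pd_scores = [s for y, s, x in zip(labels, scores, sexes) if x == g and y == 1]
        if not pd_scores:
            fnr[g] = float("nan")  # no PD speakers of this sex: FNR is undefined
            continue
        misses = sum(1 for s in pd_scores if s < thr)
        fnr[g] = misses / len(pd_scores)
    # The gap is NaN whenever either group has no PD rows.
    return fnr, fnr["F"] - fnr["M"]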


# -------------------------
# Confusion counts by sex group (for reporting)
# Inputs: labels, scores, sex_norm, fixed threshold
# Outputs: per-group confusion counts (TN/FP/FN/TP)
# -------------------------
def compute_confusion_counts(y_true, y_prob, thr):
    y_true = np.asarray(y_true, dtype=np.int64)
    y_prob = np.asarray(y_prob, dtype=np.float64)
    y_pred = (y_prob >= float(thr)).astype(np.int64)
    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
    TN, FP, FN, TP = int(cm[0, 0]), int(cm[0, 1]), int(cm[1, 0]), int(cm[1, 1])
    return {"TN": TN, "FP": FP, "FN": FN, "TP": TP}

def compute_confusion_by_group(y_true, y_prob, groups, thr):
    y_true = np.asarray(y_true, dtype=np.int64)
    y_prob = np.asarray(y_prob, dtype=np.float64)
    groups = np.asarray(list(groups), dtype=object)
    out = {}
    for g in sorted(set(groups.tolist())):
        mask = (groups == g)
        if int(mask.sum()) == 0:
            continue
        out[g] = {"n": int(mask.sum()), "confusion": compute_confusion_counts(y_true[mask], y_prob[mask], thr=thr)}
    return out


# -------------------------
# Reproducible seeding (per seed run)
# Inputs: seed integer
# Output: deterministic torch/numpy/python RNG behavior
# -------------------------
def set_all_seeds(seed: int):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


# -------------------------
# Choose the latest D2 trainval experiment that has all seeds
# Inputs: D2 trainval_runs/exp_* folders
# Output: chosen_exp used to load best_heads.pt for each seed
# -------------------------
D2_TRAINVAL_ROOT = Path(D2_OUT_ROOT) / "trainval_runs"
if not D2_TRAINVAL_ROOT.exists():
    raise FileNotFoundError(f"Missing D2 trainval_runs folder: {str(D2_TRAINVAL_ROOT)}")

exp_dirs = sorted([p for p in D2_TRAINVAL_ROOT.glob("exp_*") if p.is_dir()], key=lambda p: p.stat().st_mtime, reverse=True)
if not exp_dirs:
    raise FileNotFoundError(f"No exp_* folders found under: {str(D2_TRAINVAL_ROOT)}")

def _has_all_seeds_d2(exp_path: Path, seeds: list):
    for s in seeds:
        p = exp_path / f"run_D2_seed{s}" / "best_heads.pt"
        if not p.exists():
            return False
    return True

chosen_exp = None
for ed in exp_dirs:
    if _has_all_seeds_d2(ed, SEEDS):
        chosen_exp = ed
        break

if chosen_exp is None:
    sample = exp_dirs[0]
    raise FileNotFoundError(
        f"Could not find a D2 trainval experiment with best_heads.pt for all seeds={SEEDS}.\n"
        f"Expected under: {str(D2_TRAINVAL_ROOT)}/exp_*/run_D2_seedXXXX/best_heads.pt\n"
        f"Most recent exp checked: {str(sample)}"
    )

print("\nUsing SOURCE (D2) Train+Val experiment folder:")
print(" ", str(chosen_exp))


# -------------------------
# Transfer output folder (one run root for all seeds)
# Inputs: D5_OUT_ROOT + dataset_id + timestamp
# Outputs: TRANSFER_ROOT with per-seed subfolders and summary files
# -------------------------
TRANSFER_ROOT = Path(D5_OUT_ROOT) / "zeroshot_transfer_runs" / f"D2_to_{dataset_id}_{time.strftime('%Y%m%d_%H%M%S')}"
TRANSFER_ROOT.mkdir(parents=True, exist_ok=True)


# -------------------------
# Target test DataLoader (D5 test only)
# Inputs: test_df
# Output: test_loader for batched inference
# -------------------------
test_ds = AudioManifestDataset(test_df)
test_loader = DataLoader(
    test_ds,
    batch_size=PER_DEVICE_BS,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=PIN_MEMORY,
    collate_fn=collate_fn,
)

def warmup_loader(loader, name: str, max_batches: int = 3):
    """Quick I/O warm-up to catch obvious read/shape issues early."""
    print(f"\nWarm-up: loading up to {max_batches} {name} batches...")
    t0 = time.time()
    nb = len(loader)
    wb = min(max_batches, nb)
    if wb == 0:
        raise RuntimeError(f"{name} DataLoader has 0 batches. Check df length and PER_DEVICE_BS.")
    it = iter(loader)
    for i in range(wb):
        _ = next(it)
        print(f"  loaded {name} warmup batch {i+1}/{wb}")
    print(f"Warm-up done in {time.time()-t0:.2f}s")

warmup_loader(test_loader, "TEST", max_batches=2)


# -------------------------
# Inference loop (probability of PD)
# Inputs: model, dataloader
# Outputs: arrays of y_true, y_score, sex_norm for metrics and plots
# -------------------------
def run_inference(model: FrozenW2V2TwoHeadTest, loader: DataLoader, desc: str, use_amp: bool):
    all_probs, all_true, all_sex = [], [], []
    pbar = tqdm(loader, desc=desc, dynamic_ncols=True)
    with torch.inference_mode():
        for batch in pbar:
            input_values = batch["input_values"].to(DEVICE, non_blocking=False)
            attention_mask = batch["attention_mask"].to(DEVICE, non_blocking=False)
            labels = batch["labels"].to(DEVICE, non_blocking=False)
            task_group = batch["task_group"]
            sex_norm = batch["sex_norm"]

            with torch.amp.autocast(device_type="cuda", enabled=use_amp):
                logits = model.forward_logits(input_values, attention_mask, task_group)

            probs = torch.softmax(logits.float(), dim=-1)[:, 1].detach().cpu().numpy().astype(np.float64)
            all_probs.extend(probs.tolist())
            all_true.extend(labels.detach().cpu().numpy().astype(np.int64).tolist())
            all_sex.extend(list(sex_norm))

    return np.asarray(all_true, dtype=np.int64), np.asarray(all_probs, dtype=np.float64), np.asarray(all_sex, dtype=object)


# -------------------------
# Single seed run: load D2 heads and evaluate D5 test at FIXED_THR
# Inputs: seed, chosen_exp checkpoints, fixed threshold, D5 test_loader
# Outputs: metrics.json + ROC/confusion PNGs in run_{dataset}_seedXXXX/
# -------------------------
def run_seed(seed: int):
    set_all_seeds(seed)

    run_dir = TRANSFER_ROOT / f"run_{dataset_id}_seed{seed}"
    run_dir.mkdir(parents=True, exist_ok=True)

    best_heads_path = chosen_exp / f"run_D2_seed{seed}" / "best_heads.pt"

    print(f"\n[seed={seed}] Loading SOURCE heads from:")
    print(" ", str(best_heads_path))
    print(f"[seed={seed}] Evaluating TARGET {dataset_id} TEST @ fixed_thr={FIXED_THR:.6f}")

    state = torch.load(str(best_heads_path), map_location="cpu")

    # Optionally inherit backbone/dropout values from the checkpoint for consistency.
    ckpt_backbone = str(state.get("backbone_ckpt", BACKBONE_CKPT)) if isinstance(state, dict) else BACKBONE_CKPT
    ckpt_dropout  = float(state.get("dropout_p", 0.10)) if isinstance(state, dict) else 0.10

    model = FrozenW2V2TwoHeadTest(ckpt_backbone, dropout_p=ckpt_dropout).to(DEVICE)
    model = load_heads_into_model(model, best_heads_path, state)
    model.eval()

    use_amp = bool(USE_AMP and DEVICE.type == "cuda")

    # Target inference
    yt, pt, st = run_inference(model, test_loader, desc=f"[seed={seed}] Target TEST", use_amp=use_amp)

    test_auc = compute_auc(yt, pt)
    thr_metrics = compute_threshold_metrics(yt, pt, thr=FIXED_THR)

    # Fairness at the same fixed threshold used for all seeds.
    fnr_by_sex, delta_f_minus_m, delta_abs = compute_fnr_by_group_signed(yt, pt, st, thr=FIXED_THR)

    # Confusion by sex (uses all labels, not PD-only).
    confusion_by_sex = compute_confusion_by_group(yt, pt, st, thr=FIXED_THR)

    # Plots saved under the per-seed folder.
    roc_png = run_dir / "roc_curve.png"
    cm_png  = run_dir / "confusion_matrix.png"
    save_roc_curve_png(yt, pt, str(roc_png), title_suffix=f"Target Test (seed={seed})")
    save_confusion_png(yt, pt, str(cm_png), thr=FIXED_THR, title_suffix=f"Target Test (seed={seed})")

    cm_m_png = None
    cm_f_png = None

    mask_m = (st == "M")
    mask_f = (st == "F")

    if int(mask_m.sum()) > 0:
        cm_m_png = run_dir / "confusion_matrix_M.png"
        save_confusion_png(yt[mask_m], pt[mask_m], str(cm_m_png), thr=FIXED_THR, title_suffix=f"Target SEX=M (seed={seed})")

    if int(mask_f.sum()) > 0:
        cm_f_png = run_dir / "confusion_matrix_F.png"
        save_confusion_png(yt[mask_f], pt[mask_f], str(cm_f_png), thr=FIXED_THR, title_suffix=f"Target SEX=F (seed={seed})")

    # metrics.json is the per-seed record used by the final summary.
    metrics = {
        "source_dataset": "D2",
        "target_dataset": dataset_id,
        "seed": int(seed),

        "fixed_threshold": float(FIXED_THR),
        "fixed_threshold_source_key": THR_SRC_KEY,
        "fixed_threshold_source_summary": str(D2_MONO_SUMMARY),

        "n_test_target": int(len(test_df)),
        "label_counts_test_target": test_df["label_num"].value_counts(dropna=False).to_dict(),
        "sex_counts_test_target_norm": test_df["sex_norm"].value_counts(dropna=False).to_dict(),

        "target_test_auroc": float(test_auc),
        "threshold_metrics_target_test": thr_metrics,

        "fairness_target_test": {
            "definition": "H3: ΔFNR = FNR(F) - FNR(M), where FNR(sex)=FN/(FN+TP) computed on PD-only true labels at fixed_threshold.",
            "fnr_by_sex_norm": fnr_by_sex,
            "delta_fnr_F_minus_M": float(delta_f_minus_m),
            "delta_fnr_abs": float(delta_abs),
            "note": "If n_PD for a sex is 0, its FNR is NaN and ΔFNR is NaN.",
            "sex_normalization_note": "sex_norm in {M,F,UNK}. Values not mapped to M/F counted as UNK.",
        },

        "confusion_by_sex_norm": confusion_by_sex,

        "artifacts": {
            "roc_curve_png": str(roc_png),
            "confusion_matrix_png": str(cm_png),
            "confusion_matrix_M_png": str(cm_m_png) if cm_m_png is not None else None,
            "confusion_matrix_F_png": str(cm_f_png) if cm_f_png is not None else None,
        },

        "d5_out_root": D5_OUT_ROOT,
        "d5_manifest_all": D5_MANIFEST_ALL,

        "d2_out_root": D2_OUT_ROOT,
        "d2_source_root": D2_SOURCE_ROOT,
        "d2_trainval_experiment_used": str(chosen_exp),
        "best_heads_path_used": str(best_heads_path),

        "backbone_ckpt_used": ckpt_backbone,
        "dropout_p_used": float(ckpt_dropout),

        "timestamp": time.strftime("%Y%m%d_%H%M%S"),
    }

    metrics_path = run_dir / "metrics.json"
    with open(metrics_path, "w", encoding="utf-8") as f:
        json.dump(metrics, f, indent=2)

    def _fmt_fnr(g):
        d = fnr_by_sex.get(g, None)
        if d is None:
            return "n/a"
        return f"fnr={d['fnr']:.6f} (n_PD={d['n_pos']}, fn={d['fn']}, tp={d['tp']})"

    print(f"[seed={seed}] DONE | target_test_AUROC={test_auc:.6f} | fixed_thr={FIXED_THR:.6f}")
    print(f"[seed={seed}] FAIRNESS (H3) @ fixed_thr={FIXED_THR:.6f}:")
    print("  M:", _fmt_fnr("M"))
    print("  F:", _fmt_fnr("F"))
    if "UNK" in fnr_by_sex:
        print("  UNK:", _fmt_fnr("UNK"))
    print("  ΔFNR (F-M):", f"{delta_f_minus_m:.6f}" if not np.isnan(delta_f_minus_m) else "nan")
    print("  |ΔFNR|:", f"{delta_abs:.6f}" if not np.isnan(delta_abs) else "nan")

    print(f"[seed={seed}] WROTE:")
    print(" ", str(metrics_path))
    print(" ", str(roc_png))
    print(" ", str(cm_png))
    if cm_m_png is not None:
        print(" ", str(cm_m_png))
    if cm_f_png is not None:
        print(" ", str(cm_f_png))

    return {
        "seed": int(seed),
        "target_test_auroc": float(test_auc),
        "thr_metrics": thr_metrics,
        "fnr_by_sex": fnr_by_sex,
        "delta_signed": float(delta_f_minus_m),
        "delta_abs": float(delta_abs),
        "confusion_by_sex": confusion_by_sex,
        "run_dir": str(run_dir),
    }


# -------------------------
# Run all seeds and write transfer summary
# Inputs: 3 per-seed results
# Outputs: summary_transfer.json + history_index.jsonl under TRANSFER_ROOT
# -------------------------
seed_results = []
for seed in SEEDS:
    seed_results.append(run_seed(seed))

aurocs = [r["target_test_auroc"] for r in seed_results]
t_crit = 4.302652729911275  # two-sided 95% Student-t critical value for df=2; valid only when len(SEEDS) == 3
n = len(aurocs)
mean_auc = float(np.mean(aurocs))
std_auc = float(np.std(aurocs, ddof=1)) if n > 1 else 0.0
half_width = float(t_crit * (std_auc / math.sqrt(n))) if n > 1 else 0.0
ci95 = [float(mean_auc - half_width), float(mean_auc + half_width)]
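
# Sketch of where the hard-coded t critical value comes from (stdlib only).
# Assumptions: `_demo_t_ci_three_seeds` is a hypothetical helper that is not used
# by the pipeline. For df=2 the Student-t quantile has the closed form
# Q(p) = (2p - 1) / sqrt(2p(1 - p)), so the constant can be checked without SciPy.
# This holds for exactly three values (df = n - 1 = 2) and nothing else.
def _demo_t_ci_three_seeds(values):
    import math
    import statistics

    if len(values) != 3:
        raise ValueError("closed form below is specific to n=3 (df=2)")
    p = 0.975  # upper tail point of a two-sided 95% interval
    t_df2 = (2.0 * p - 1.0) / math.sqrt(2.0 * p * (1.0 - p))
    mean = statistics.fmean(values)
    # Half-width = t * (sample SD / sqrt(n)), matching the computation above.
    half = t_df2 * statistics.stdev(values) / math.sqrt(3)
    return mean, half, t_df2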

def _mean_sd(vals):
    # NaN-aware mean and sample SD (SD is 0.0 with fewer than two finite values).
    vals = np.asarray(vals, dtype=np.float64)
    sd = float(np.nanstd(vals, ddof=1)) if np.sum(~np.isnan(vals)) > 1 else 0.0
    return float(np.nanmean(vals)), sd

keys = ["accuracy","precision","recall","f1_score","sensitivity","specificity","mcc","p_value_fisher_two_sided"]
agg = {}
for k in keys:
    v = [float(r["thr_metrics"].get(k, float("nan"))) for r in seed_results]
    mu, sd = _mean_sd(v)
    agg[k] = {
        "mean": mu,
        "sd": sd,
        "values_by_seed": {str(r["seed"]): float(r["thr_metrics"].get(k, float("nan"))) for r in seed_results},
    }

cm_by_seed = {str(r["seed"]): r["thr_metrics"]["confusion_matrix"] for r in seed_results}

fnr_m_vals, fnr_f_vals, n_pd_m_vals, n_pd_f_vals = [], [], [], []
d_signed_vals, d_abs_vals = [], []

fnr_by_seed = {}
delta_signed_by_seed = {}
delta_abs_by_seed = {}
confusion_by_sex_by_seed = {}
run_dirs = []

for r in seed_results:
    s = str(r["seed"])
    fnr_by_seed[s] = r["fnr_by_sex"]
    delta_signed_by_seed[s] = float(r["delta_signed"])
    delta_abs_by_seed[s] = float(r["delta_abs"])
    confusion_by_sex_by_seed[s] = r["confusion_by_sex"]
    run_dirs.append(r["run_dir"])

    d = r["fnr_by_sex"]
    fnr_m_vals.append(float(d.get("M", {}).get("fnr", float("nan"))))
    fnr_f_vals.append(float(d.get("F", {}).get("fnr", float("nan"))))
    n_pd_m_vals.append(float(d.get("M", {}).get("n_pos", float("nan"))))
    n_pd_f_vals.append(float(d.get("F", {}).get("n_pos", float("nan"))))
    d_signed_vals.append(float(r["delta_signed"]))
    d_abs_vals.append(float(r["delta_abs"]))

fnr_m_mean, fnr_m_sd = _mean_sd(fnr_m_vals)
fnr_f_mean, fnr_f_sd = _mean_sd(fnr_f_vals)
d_signed_mean, d_signed_sd = _mean_sd(d_signed_vals)
d_abs_mean, d_abs_sd = _mean_sd(d_abs_vals)

print("\nTARGET Test AUROC by seed:")
for r in seed_results:
    print(f"  seed {r['seed']}: {r['target_test_auroc']:.6f}")
print(f"\nMean TARGET Test AUROC: {mean_auc:.6f}")
print(f"95% CI (t, n=3): [{ci95[0]:.6f}, {ci95[1]:.6f}]")

print(f"\nThreshold metrics on TARGET TEST @ fixed_thr={FIXED_THR:.6f} (mean ± SD across seeds):")
for k in ["accuracy","precision","sensitivity","specificity","f1_score","mcc"]:
    mu = agg[k]["mean"]
    sd = agg[k]["sd"]
    print(f"  {k}: {mu:.6f} ± {sd:.6f}")
print("  fisher_p_value_two_sided:", f"{agg['p_value_fisher_two_sided']['mean']:.6g} ± {agg['p_value_fisher_two_sided']['sd']:.6g}")

print(f"\nFAIRNESS (H3) on TARGET TEST @ fixed_thr={FIXED_THR:.6f} (mean ± SD):")
print(f"  FNR_M: {fnr_m_mean:.6f} ± {fnr_m_sd:.6f}")
print(f"  FNR_F: {fnr_f_mean:.6f} ± {fnr_f_sd:.6f}")
print(f"  ΔFNR (F-M): {d_signed_mean:.6f} ± {d_signed_sd:.6f}")
print(f"  |ΔFNR|: {d_abs_mean:.6f} ± {d_abs_sd:.6f}")

summary = {
    "source_dataset": "D2",
    "target_dataset": dataset_id,

    "d5_out_root": D5_OUT_ROOT,
    "d5_manifest_all": D5_MANIFEST_ALL,

    "d2_source_root": D2_SOURCE_ROOT,
    "d2_out_root": D2_OUT_ROOT,
    "d2_monolingual_summary": str(D2_MONO_SUMMARY),
    "d2_trainval_experiment_used": str(chosen_exp),

    "seeds": SEEDS,

    "fixed_threshold": float(FIXED_THR),
    "fixed_threshold_source_key": THR_SRC_KEY,

    "n_test_target": int(len(test_df)),
    "label_counts_test_target": test_df["label_num"].value_counts(dropna=False).to_dict(),
    "sex_counts_test_target_norm": test_df["sex_norm"].value_counts(dropna=False).to_dict(),

    "test_aurocs_by_seed": {str(r["seed"]): float(r["target_test_auroc"]) for r in seed_results},
    "mean_test_auroc": float(mean_auc),
    "t_crit_df2_95": float(t_crit),
    "half_width_95_ci": float(half_width),
    "ci95_test_auroc": ci95,

    "threshold_metrics_mean_sd": agg,
    "confusion_matrix_by_seed": cm_by_seed,
    "run_dirs": run_dirs,

    "fairness_test": {
        "definition": "H3: ΔFNR = FNR(F) - FNR(M), where FNR(sex)=FN/(FN+TP) computed on PD-only true labels at fixed_threshold.",
        "fnr_by_sex_norm_by_seed": fnr_by_seed,
        "delta_fnr_F_minus_M_by_seed": delta_signed_by_seed,
        "delta_fnr_abs_by_seed": delta_abs_by_seed,
        "fnr_M_mean_sd": {"mean": float(fnr_m_mean), "sd": float(fnr_m_sd), "values_by_seed": {str(s): float(v) for s, v in zip(SEEDS, fnr_m_vals)}},
        "fnr_F_mean_sd": {"mean": float(fnr_f_mean), "sd": float(fnr_f_sd), "values_by_seed": {str(s): float(v) for s, v in zip(SEEDS, fnr_f_vals)}},
        "delta_fnr_F_minus_M_mean_sd": {"mean": float(d_signed_mean), "sd": float(d_signed_sd), "values_by_seed": {str(s): float(v) for s, v in zip(SEEDS, d_signed_vals)}},
        "delta_fnr_abs_mean_sd": {"mean": float(d_abs_mean), "sd": float(d_abs_sd), "values_by_seed": {str(s): float(v) for s, v in zip(SEEDS, d_abs_vals)}},
        "denominators_PD_by_seed": {str(s): {"n_PD_M": float(n_pd_m_vals[i]), "n_PD_F": float(n_pd_f_vals[i])} for i, s in enumerate(SEEDS)},
        "sex_normalization_note": "sex_norm in {M,F,UNK}. Values not mapped to M/F counted as UNK. ΔFNR computed only when both M and F have defined FNR.",
    },

    "confusion_by_sex_norm_by_seed": confusion_by_sex_by_seed,

    "timestamp": time.strftime("%Y%m%d_%H%M%S"),
}

summary_path = TRANSFER_ROOT / "summary_transfer.json"
with open(summary_path, "w", encoding="utf-8") as f:
    json.dump(summary, f, indent=2)

history_path = TRANSFER_ROOT / "history_index.jsonl"
with open(history_path, "a", encoding="utf-8") as f:
    f.write(json.dumps(summary) + "\n")

print("\nWROTE summary:", str(summary_path))
print("APPENDED history index:", str(history_path))
print("Open this folder to access artifacts:", str(TRANSFER_ROOT))


# -------------------------
# Stop runtime (Colab L4)
# Inputs: Colab runtime availability
# Output: runtime unassigned (or a message explaining manual stop)
# -------------------------
print("\nAll done. Unassigning the runtime to stop the L4 instance...")
try:
    from google.colab import runtime  # type: ignore
    print("Calling runtime.unassign() now.")
    runtime.unassign()
except Exception as e:
    print("Could not unassign runtime automatically. You can stop the runtime manually in Colab.")
    print("Reason:", repr(e))

# Multilingual Model D7 Using D1, D4, D5, and D6 Datasets

The following cell builds the merged multilingual dataset D7 by combining four already-preprocessed datasets, D1, D4, D5v2, and D6, into one standard, self-contained folder. It reads each source dataset’s `manifests/manifest_all.csv` as the source of truth, keeps the original train, validation, and test split assignments, and performs no re-splitting. Every referenced audio clip is then copied into the D7 folder so the merged dataset can be used on its own without depending on the original dataset locations.

Before any files are copied, the cell applies strict rules to avoid unintended changes. It locks the manifest schema and column order to a fixed header, checks that all required columns are present, converts any literal `"NaN"` strings into true missing values, and normalizes key fields so they are consistent across all source datasets. The numeric label is forced to 0 or 1, the string label is forced to Healthy or Parkinson and kept consistent with the numeric label, and sex is restricted to M, F, or blank, with unknown values left blank instead of creating new categories. To avoid speaker ID collisions between datasets, the cell rewrites `speaker_key_rel` by prefixing it with the source dataset code, such as `D1__`. The dataset column is set to `"D7"` so downstream code treats the merged data as a single dataset.

The cell then performs a full preflight inventory check before copying any audio. For each row, it computes a deterministic destination filename based on the source dataset, label, speaker, task, and a per-split counter. It verifies that every source audio file exists, checks for destination filename collisions, and confirms that any destination file that already exists has the same size as its source file. This allows safe re-runs, since correct existing files are neither overwritten nor recopied. Even if the preflight step finds problems, the cell still writes the standard log and configuration files first, so issues are easy to debug and no run ends without metadata.

If the preflight step finds any fatal issues, the cell stops immediately without copying files or rewriting the manifest. In that case, it still writes `logs/preprocess_warnings.csv` with structured errors and warnings, `logs/dataset_summary.json` with counts and status information, and `config/run_config.json` describing the inputs, schema rules, and preflight results.

If the preflight step passes, the cell copies only the missing audio clips into `D7/clips/<split>/` using a no-overwrite policy and logs any copy errors. If any copy fails, the cell stops without rewriting the manifest. Otherwise, it updates the merged table so `clip_path` points to the new D7 locations and `sample_id` matches the new filename stem. Final checks confirm that the schema order is correct, that there are no literal `"NaN"` strings, and that label and sex values are valid. The cell then writes the manifest, log, and configuration files atomically (through a temporary file that is renamed into place) to avoid partial files if the run is interrupted. The outputs are the `clips/` folder with audio organized by split, `manifests/manifest_all.csv` with the merged data and D7 local paths, `config/run_config.json` with final run details and copy statistics, `logs/preprocess_warnings.csv` with the full warning and error log, `logs/dataset_summary.json` with final counts and success status, and an empty `trainval_runs/` folder created for future training outputs.

In [None]:
# Build D7 Multilingual Dataset (Merge Builder)
# Merges multiple already-preprocessed datasets into one self-contained D7 folder by copying clips into D7,
# keeping the original split assignments, and writing a single locked-schema manifest plus logs and config.

# =========================
# D7 MERGE-BUILDER (CRASH-PROOF, STANDARDS-SAFE) — D1 + D4 + D5v2 + D6 → D7
#
# Purpose
# - Create a multilingual training root (D7) that mirrors the standard dataset layout
#   and is fully self-sufficient (clips stored under D7).
#
# Guarantees / Invariants
# - Manifest schema is locked to the attached sample header; 'source_dataset' is appended at the end.
# - Missing values remain true NaN in-memory and are written as blanks in CSV (never the literal string "NaN").
# - label_num ∈ {0, 1} and label_str ∈ {Healthy, Parkinson}, and both are mutually consistent.
# - sex ∈ {M, F} or blank/NaN.
# - speaker_key_rel is prefixed with the source dataset code to prevent collisions across merged datasets.
# - clip files are copied into D7 with collision-proof filenames; manifest clip_path is rewritten to D7 paths.
#
# Outputs written (and only these)
# - clips/
# - manifests/manifest_all.csv
# - config/run_config.json
# - logs/preprocess_warnings.csv
# - logs/dataset_summary.json
# - trainval_runs/ (folder only; training code writes inside later)
# =========================

import os, json, re, shutil
from pathlib import Path
from datetime import datetime

import numpy as np
import pandas as pd
from tqdm.auto import tqdm

# -------------------------
# Mount Drive (Colab)
# Inputs: Colab runtime state
# Output: /content/drive available when running in Colab
# -------------------------
# Safe no-op outside Colab or if Drive is already mounted.
try:
    from google.colab import drive  # type: ignore
    if not os.path.exists("/content/drive"):
        drive.mount("/content/drive")
except Exception:
    pass

# -------------------------
# Merge inputs and output root
# Inputs: source manifest_all.csv files (already preprocessed)
# Output: D7_OUT_ROOT where the merged clips and manifest will be written
# -------------------------
D7_OUT_ROOT = Path("/content/drive/MyDrive/AI_PD_Project/Datasets/D7-Multilingual (D1_D4_D5v2_D6)/preprocessed_v1")

# Each source points to its already-preprocessed manifest_all.csv.
# Clips referenced by these manifests will be copied into D7.
SOURCES = [
    {"src": "D1", "manifest_all": "/content/drive/MyDrive/AI_PD_Project/Datasets/D1-NeuroVoz-Castillan Spanish/preprocessed_v1/manifests/manifest_all.csv"},
    {"src": "D4", "manifest_all": "/content/drive/MyDrive/AI_PD_Project/Datasets/D4-Italian (IPVS)/preprocessed_v1/manifests/manifest_all.csv"},
    {"src": "D5", "manifest_all": "/content/drive/MyDrive/AI_PD_Project/Datasets/D5-English (MDVR-KCL)/preprocessed_v2/manifests/manifest_all.csv"},  # D5 v2
    {"src": "D6", "manifest_all": "/content/drive/MyDrive/AI_PD_Project/Datasets/D6-Ah Sound (Figshare)/preprocessed_v1/manifests/manifest_all.csv"},
]

# Keep only these splits from each source; no split changes are made.
D7_INCLUDE_SPLITS = ["train", "val", "test"]

# -------------------------
# Locked manifest schema
# Inputs: expected column list and order
# Output: fail-fast protection against schema drift across sources
# -------------------------
# This list defines the canonical column set and the exact order expected in every source manifest.
# A source manifest missing any of these columns triggers a fail-fast error to prevent schema drift.
CANON_COLS = [
    "split",
    "dataset",
    "task",
    "speaker_id",
    "sample_id",
    "label_str",
    "label_num",
    "age",
    "sex",
    "speaker_key_rel",
    "clip_path",
    "duration_sec",
    "source_path",
    "clip_start_sec",
    "clip_end_sec",
    "sr_hz",
    "channels",
    "clip_is_contiguous",
]
# D7 adds provenance at the end only (no reordering of canonical fields).
FINAL_COLS = CANON_COLS + ["source_dataset"]

# Redundant minimum set used by core operations.
REQ = ["split", "clip_path", "label_num", "label_str", "sex", "dataset", "speaker_id", "task", "sample_id", "speaker_key_rel"]

# -------------------------
# Create D7 folder structure
# Inputs: D7_OUT_ROOT and split list
# Output: standard subfolders (clips/manifests/config/logs/trainval_runs)
# -------------------------
# Only these folders are created or written. No other directories are touched.
clips_dir = D7_OUT_ROOT / "clips"
manif_dir = D7_OUT_ROOT / "manifests"
cfg_dir   = D7_OUT_ROOT / "config"
logs_dir  = D7_OUT_ROOT / "logs"
trainval_runs_dir = D7_OUT_ROOT / "trainval_runs"

for sp in D7_INCLUDE_SPLITS:
    (clips_dir / sp).mkdir(parents=True, exist_ok=True)
manif_dir.mkdir(parents=True, exist_ok=True)
cfg_dir.mkdir(parents=True, exist_ok=True)
logs_dir.mkdir(parents=True, exist_ok=True)
trainval_runs_dir.mkdir(parents=True, exist_ok=True)

# -------------------------
# Utilities: guardrails, atomic writes, normalization
# Inputs: simple values and dataframes
# Outputs: consistent CSV writing and strict field standardization
# -------------------------
def require(cond, msg):
    """Stops immediately when an invariant is violated."""
    if not cond:
        raise RuntimeError(msg)

def atomic_write_text(dst: Path, text: str):
    """Writes through a temp file to avoid partial output on interruption."""
    tmp = dst.with_suffix(dst.suffix + ".tmp")
    with open(tmp, "w", encoding="utf-8") as f:
        f.write(text)
    os.replace(tmp, dst)

def atomic_write_csv(dst: Path, df: pd.DataFrame):
    """Writes CSV with blanks for NaN values (never the literal string 'NaN')."""
    tmp = dst.with_suffix(dst.suffix + ".tmp")
    df.to_csv(tmp, index=False, na_rep="")
    os.replace(tmp, dst)

def normalize_sex_to_MF(x):
    """
    Normalizes sex to 'M' or 'F' when possible.
    Unknown values are set to NaN to keep categories stable.

    Update: also handle numeric 1/0 encodings that may appear in D1 as floats (1.0/0.0).
    """
    if pd.isna(x):
        return np.nan

    # Handle numeric 1/0 directly (covers int/float like 1, 0, 1.0, 0.0)
    if isinstance(x, (int, np.integer)):
        if int(x) == 1:
            return "M"
        if int(x) == 0:
            return "F"
        return np.nan
    if isinstance(x, (float, np.floating)):
        if np.isfinite(x):
            if float(x) == 1.0:
                return "M"
            if float(x) == 0.0:
                return "F"
        return np.nan

    s = str(x).strip().lower()
    if s in ["m", "male", "1", "1.0"]:
        return "M"
    if s in ["f", "female", "0", "0.0"]:
        return "F"
    if s in ["", "nan", "none", "unknown", "u"]:
        return np.nan
    return np.nan  # do not invent categories

def normalize_label_num(x):
    """
    Normalizes label_num to integers 0/1.
    Allows common encodings; unknown values become NaN (then fail-fast later).
    """
    if pd.isna(x):
        return np.nan
    if isinstance(x, str):
        t = x.strip().lower()
        if t in ["0", "healthy", "hc"]:
            return 0
        if t in ["1", "parkinson", "pd"]:
            return 1
        return np.nan
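    try:
        v = float(x)
    except Exception:
        return np.nan
    # Reject non-finite or non-integral values (e.g. 0.5) instead of silently truncating them.
    return int(v) if np.isfinite(v) and v == int(v) else np.nan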

def label_str_from_num(v):
    """Maps label_num to canonical strings: Healthy or Parkinson."""
    if pd.isna(v):
        return np.nan
    v = int(v)
    if v == 0:
        return "Healthy"
    if v == 1:
        return "Parkinson"
    return np.nan

def hc_pd_from_label_num(v):
    """Short label for filenames only (manifest keeps the canonical label_str)."""
    return "PD" if int(v) == 1 else "HC"

def safe_token(s, max_len=32, default="NA"):
    """
    Creates a filename-safe token from a value (speaker_id, task).
    Used only for output clip filenames.
    """
    if pd.isna(s):
        return default
    s = str(s).strip()
    s = re.sub(r"\s+", "_", s)
    s = re.sub(r"[^A-Za-z0-9_]+", "", s)
    return (s[:max_len] if s else default)

# -------------------------
# Warning log accumulator
# Inputs: event details during preflight/copy
# Output: logs/preprocess_warnings.csv at multiple points
# -------------------------
warnings_rows = []
def add_warn(src, level, code, message, **extra):
    """Appends a structured log row for later CSV export."""
    row = {
        "ts": datetime.utcnow().isoformat(),
        "src": src,
        "level": level,
        "code": code,
        "message": message,
    }
    row.update(extra)
    warnings_rows.append(row)

def count_by_level(rows):
    """Counts INFO/WARN/ERROR entries for quick summaries."""
    out = {"ERROR": 0, "WARN": 0, "INFO": 0}
    for r in rows:
        lvl = str(r.get("level", "INFO")).upper()
        if lvl not in out:
            out[lvl] = 0
        out[lvl] += 1
    return out

# -------------------------
# Load, validate, normalize, and merge manifests
# Inputs: each source manifest_all.csv
# Output: merged dataframe with locked columns plus source_dataset
# -------------------------
print("D7_OUT_ROOT:", str(D7_OUT_ROOT))
print("Sources:")
for s in SOURCES:
    print(f"  - {s['src']}: {s['manifest_all']}")

parts = []
source_counts = {}

for s in SOURCES:
    src = s["src"]
    mpath = Path(s["manifest_all"])
    require(mpath.exists(), f"{src} manifest not found: {mpath}")

    df = pd.read_csv(mpath)

    # Schema lock: all canonical columns must exist and keep the expected meaning.
    missing = [c for c in CANON_COLS if c not in df.columns]
    require(len(missing) == 0, f"{src} manifest missing required columns (locked schema): {missing}")

    # Keep only requested splits; no split mutation is performed.
    df = df[df["split"].isin(D7_INCLUDE_SPLITS)].copy()

    # Convert literal "NaN" strings to true NaN for common fields.
    for col in ["sex", "age", "duration_sec", "clip_start_sec", "clip_end_sec", "speaker_key_rel", "speaker_id", "task", "sample_id"]:
        if col in df.columns:
            df[col] = df[col].replace("NaN", np.nan)

    # Normalize labels and enforce 0/1 + Healthy/Parkinson consistency.
    df["label_num"] = df["label_num"].map(normalize_label_num)
    bad_num = sorted(set(df["label_num"].dropna().unique()) - {0, 1})
    require(len(bad_num) == 0, f"{src} label_num has values outside {{0,1}} after normalization: {bad_num}")

    df["label_str"] = df["label_num"].map(label_str_from_num)
    bad_str = sorted(set(df["label_str"].dropna().unique()) - {"Healthy", "Parkinson"})
    require(len(bad_str) == 0, f"{src} label_str has values outside {{Healthy, Parkinson}} after normalization: {bad_str}")

    # Normalize sex to M/F/blank only.
    df["sex"] = df["sex"].map(normalize_sex_to_MF)
    bad_sex = sorted(set(df["sex"].dropna().unique()) - {"M", "F"})
    require(len(bad_sex) == 0, f"{src} sex has unexpected values after normalization: {bad_sex}")

    # Force dataset = D7 so training code writes runs using the merged dataset id.
    df["dataset"] = "D7"

    # Prefix speaker keys to avoid collisions across datasets.
    df["speaker_key_rel"] = df["speaker_key_rel"].astype(str).map(lambda x: f"{src}__{x}")

    # Add provenance at the end only.
    df["source_dataset"] = src

    out = df[FINAL_COLS].copy()
    parts.append(out)
    source_counts[src] = int(len(out))

require(len(parts) > 0, "No rows loaded after split filtering.")
merged = pd.concat(parts, axis=0, ignore_index=True)

print("\nMerged rows (after split filter) by source:", source_counts)
print("Total merged rows:", int(len(merged)))
print("Merged split counts:", merged["split"].value_counts(dropna=False).to_dict())
print("Merged label counts:", merged["label_num"].value_counts(dropna=False).to_dict())

# -------------------------
# Preflight: plan destination filenames and check files before copying
# Inputs: merged rows (source clip_path values)
# Outputs: dest_paths/dest_names aligned to merged rows, plus preflight stats and warnings
# -------------------------
# This stage checks everything before touching outputs:
# - every source clip exists
# - destination name collisions are detected
# - if a destination file already exists, its size must match the source
# - copy plan is recorded for reproducibility
print("\nPreflight: computing destination filenames and verifying source clips exist...")

split_counter_pre = {sp: 0 for sp in D7_INCLUDE_SPLITS}
dest_paths = []
dest_names = []
src_paths = []

seen_dest = set()

n_src_missing = 0
n_dest_collision = 0
n_dest_exists_ok = 0
n_dest_exists_mismatch = 0
n_will_copy = 0

pbar = tqdm(total=len(merged), desc="D7: Preflight", dynamic_ncols=True)
for i, row in merged.iterrows():
    sp = str(row["split"])
    src = str(row["source_dataset"])

    require(sp in D7_INCLUDE_SPLITS, f"Unexpected split value in merged table: {sp}")

    split_counter_pre[sp] += 1
    gi = split_counter_pre[sp]

    # Filename label uses PD/HC/UNK; manifest keeps the canonical label fields.
    y = row["label_num"]
    if pd.isna(y):
        add_warn("D7", "ERROR", "LABEL_NAN", "label_num is NaN after normalization", row_index=int(i))
        hc_pd = "UNK"
    else:
        hc_pd = hc_pd_from_label_num(y)

    spkid_tok = safe_token(row["speaker_id"], max_len=32, default="NA")
    task_tok  = safe_token(row["task"], max_len=12, default="0")

    # Deterministic output naming: stable within a run given the merged row order.
    out_name = f"D7_{src}_{hc_pd}_{spkid_tok}_{task_tok}_{gi:06d}.wav"
    out_path = clips_dir / sp / out_name

    # Detect deterministic collisions (should not happen unless inputs collide unexpectedly).
    if str(out_path) in seen_dest:
        n_dest_collision += 1
        add_warn("D7", "ERROR", "NAME_COLLISION", "Destination filename collision during preflight", out_path=str(out_path), row_index=int(i))
    else:
        seen_dest.add(str(out_path))

    src_path = Path(str(row["clip_path"]))
    src_paths.append(str(src_path))
    dest_paths.append(str(out_path))
    dest_names.append(out_name)

    # Missing source clips are fatal (recorded first, copying is blocked later).
    if not src_path.exists():
        n_src_missing += 1
        if n_src_missing <= 50:
            add_warn("D7", "ERROR", "MISSING_SOURCE_CLIP", "Source clip file missing", clip_path=str(src_path), row_index=int(i))
        pbar.update(1)
        continue

    # No overwrite policy: if a destination exists, it must match the source size.
    if out_path.exists():
        try:
            if out_path.stat().st_size == src_path.stat().st_size:
                n_dest_exists_ok += 1
            else:
                n_dest_exists_mismatch += 1
                add_warn(
                    "D7", "ERROR", "DEST_EXISTS_SIZE_MISMATCH",
                    "Destination exists but file size differs from source",
                    src_path=str(src_path), dest_path=str(out_path),
                    src_size=int(src_path.stat().st_size), dest_size=int(out_path.stat().st_size),
                    row_index=int(i)
                )
        except Exception as e:
            n_dest_exists_mismatch += 1
            add_warn(
                "D7", "ERROR", "DEST_EXISTS_STAT_ERROR",
                "Failed to stat source/destination file during preflight",
                src_path=str(src_path), dest_path=str(out_path),
                error=repr(e), row_index=int(i)
            )
    else:
        n_will_copy += 1

    pbar.update(1)

pbar.close()

# Preflight stats are written to logs before copying starts.
preflight_stats = {
    "total_rows": int(len(merged)),
    "split_counters": {k: int(v) for k, v in split_counter_pre.items()},
    "n_src_missing": int(n_src_missing),
    "n_dest_collision": int(n_dest_collision),
    "n_dest_exists_ok": int(n_dest_exists_ok),
    "n_dest_exists_mismatch": int(n_dest_exists_mismatch),
    "n_will_copy": int(n_will_copy),
    "warnings_by_level": count_by_level(warnings_rows),
}

print("\nPreflight summary:")
print("  Source clips missing:", n_src_missing)
print("  Destination collisions:", n_dest_collision)
print("  Destination exists (size OK):", n_dest_exists_ok)
print("  Destination exists (size mismatch/stat error):", n_dest_exists_mismatch)
print("  Will copy (missing in D7):", n_will_copy)
print("  Warnings by level:", preflight_stats["warnings_by_level"])

# -------------------------
# Write early logs and config (always)
# Inputs: merged table and preflight stats
# Outputs: run_config.json, dataset_summary.json, preprocess_warnings.csv
# -------------------------
# These files are written before copying so a failed run still leaves a usable trace.
def summarize(df):
    """Small count summary used in dataset_summary.json."""
    out = {}
    out["total_rows"] = int(len(df))
    out["by_split"] = {sp: int((df["split"] == sp).sum()) for sp in D7_INCLUDE_SPLITS}
    out["by_source_dataset"] = {sd: int((df["source_dataset"] == sd).sum()) for sd in sorted(df["source_dataset"].unique())}
    out["by_source_dataset_by_split"] = {}
    for sd in sorted(df["source_dataset"].unique()):
        out["by_source_dataset_by_split"][sd] = {sp: int(((df["source_dataset"]==sd) & (df["split"]==sp)).sum()) for sp in D7_INCLUDE_SPLITS}
    out["label_counts_total"] = {str(k): int(v) for k, v in df["label_num"].value_counts(dropna=False).to_dict().items()}
    return out

run_config = {
    "dataset": "D7",
    "mode": "merge_builder",
    "created_utc": datetime.utcnow().isoformat(),
    "d7_out_root": str(D7_OUT_ROOT),
    "include_splits": D7_INCLUDE_SPLITS,
    "source_manifests": {s["src"]: s["manifest_all"] for s in SOURCES},
    "manifest_schema_locked_to_attachment": {
        "canon_cols": CANON_COLS,
        "final_cols": FINAL_COLS,
        "label_str_values": ["Healthy", "Parkinson"],
        "label_num_values": [0, 1],
        "sex_values": ["M", "F"],
        "notes": [
            "speaker_key_rel prefixed with {SRC}__ to avoid collisions",
            "clips copied into D7; clip_path rewritten to D7 paths",
            "sample_id set to the output filename stem for determinism",
            "existing destination clips are not overwritten (copy skipped when size matches)",
            "missing values written as blanks (true NaN), never literal 'NaN'",
        ],
    },
    "preflight": preflight_stats,
}

early_summary = {
    "dataset": "D7",
    "created_utc": datetime.utcnow().isoformat(),
    "d7_out_root": str(D7_OUT_ROOT),
    "include_splits": D7_INCLUDE_SPLITS,
    "sources": [{"src": s["src"], "manifest_all": s["manifest_all"]} for s in SOURCES],
    "source_row_counts_after_split_filter": source_counts,
    "counts": summarize(merged),
    "status": "PRECHECK_COMPLETE",
    "preflight": preflight_stats,
}

atomic_write_csv(logs_dir / "preprocess_warnings.csv", pd.DataFrame(warnings_rows))
atomic_write_text(cfg_dir / "run_config.json", json.dumps(run_config, indent=2))
atomic_write_text(logs_dir / "dataset_summary.json", json.dumps(early_summary, indent=2))

# If any ERROR was logged, stop before copying to avoid partial outputs.
fatal_pre = [w for w in warnings_rows if str(w.get("level", "")).upper() == "ERROR"]
if fatal_pre:
    print("\n❌ Preflight found ERROR conditions. No copying performed.")
    print("   Logs written:")
    print("   -", str(logs_dir / "preprocess_warnings.csv"))
    print("   -", str(logs_dir / "dataset_summary.json"))
    print("   -", str(cfg_dir / "run_config.json"))
    raise RuntimeError(f"Preflight failed with {len(fatal_pre)} ERROR(s). See logs/preprocess_warnings.csv.")

# -------------------------
# Copy stage: copy missing clips only (no overwrite)
# Inputs: src_paths + dest_paths planned in preflight
# Outputs: D7 clips/ populated; warnings updated if a copy fails
# -------------------------
print("\nCopy stage: copying only missing clips (no overwrite when destination exists)...")

copied = 0
skipped_exists = 0

pbar = tqdm(total=len(merged), desc="D7: Copying clips", dynamic_ncols=True)
for i in range(len(merged)):
    src_path = Path(src_paths[i])
    out_path = Path(dest_paths[i])

    # Defensive guard in case a source file disappears after preflight.
    if not src_path.exists():
        add_warn("D7", "ERROR", "SOURCE_CLIP_DISAPPEARED", "Source clip missing during copy stage", clip_path=str(src_path), row_index=int(i))
        pbar.update(1)
        continue

    if out_path.exists():
        # Skip and keep the existing file (preflight already checked size match).
        skipped_exists += 1
        pbar.update(1)
        continue

    try:
        shutil.copy2(src_path, out_path)
        copied += 1
    except Exception as e:
        add_warn(
            "D7", "ERROR", "COPY_FAILED", "Failed to copy source clip into D7",
            src_path=str(src_path), dest_path=str(out_path), error=repr(e), row_index=int(i)
        )

    pbar.update(1)

pbar.close()

fatal_copy = [w for w in warnings_rows if str(w.get("level", "")).upper() == "ERROR"]
copy_stats = {
    "copied": int(copied),
    "skipped_exists": int(skipped_exists),
    "total_rows": int(len(merged)),
    "warnings_by_level": count_by_level(warnings_rows),
}
print("\nCopy summary:")
print("  Copied:", copied)
print("  Skipped (already exists):", skipped_exists)
print("  Warnings by level:", copy_stats["warnings_by_level"])

# Always refresh warnings CSV after the copy attempt.
atomic_write_csv(logs_dir / "preprocess_warnings.csv", pd.DataFrame(warnings_rows))

# Stop if any copy-stage errors occurred (manifest is not rewritten in that case).
if fatal_copy:
    fail_summary = dict(early_summary)
    fail_summary["status"] = "FAILED_DURING_COPY"
    fail_summary["copy_stats"] = copy_stats
    atomic_write_text(logs_dir / "dataset_summary.json", json.dumps(fail_summary, indent=2))
    print("\n❌ Copy stage encountered ERROR conditions.")
    print("   Logs written:")
    print("   -", str(logs_dir / "preprocess_warnings.csv"))
    print("   -", str(logs_dir / "dataset_summary.json"))
    print("   -", str(cfg_dir / "run_config.json"))
    raise RuntimeError(f"Copy failed with {len(fatal_copy)} ERROR(s). See logs/preprocess_warnings.csv.")

# Now that copies are confirmed, rewrite manifest fields to point to D7 copies.
merged["clip_path"] = dest_paths
merged["sample_id"] = [Path(n).stem for n in dest_names]

# -------------------------
# Final checks before writing manifest
# Inputs: merged table after rewrite
# Output: hard stop if any standard rule is violated
# -------------------------
require(list(merged.columns) == FINAL_COLS, "Internal error: merged columns are not in the locked FINAL_COLS order.")

for c in FINAL_COLS:
    if merged[c].dtype == object and (merged[c] == "NaN").any():
        raise RuntimeError(f"Found literal string 'NaN' in column '{c}'. This violates standardization rules.")

require(set(merged["label_num"].dropna().unique()).issubset({0, 1}), "label_num contains values outside {0,1}.")
require(set(merged["label_str"].dropna().unique()).issubset({"Healthy", "Parkinson"}), "label_str contains values outside {Healthy, Parkinson}.")
require(((merged["label_num"] == 0) == (merged["label_str"] == "Healthy")).all(), "Inconsistent labels: label_num == 0 must pair exactly with label_str == 'Healthy'.")
require(((merged["label_num"] == 1) == (merged["label_str"] == "Parkinson")).all(), "Inconsistent labels: label_num == 1 must pair exactly with label_str == 'Parkinson'.")

bad_sex = sorted(set(merged["sex"].dropna().unique()) - {"M", "F"})
require(len(bad_sex) == 0, f"sex contains unexpected values: {bad_sex}")

# -------------------------
# Write final artifacts
# Inputs: merged table + updated logs/config
# Outputs: manifests/manifest_all.csv and updated dataset_summary.json, preprocess_warnings.csv, run_config.json
# -------------------------
manifest_path = manif_dir / "manifest_all.csv"
atomic_write_csv(manifest_path, merged)

final_summary = dict(early_summary)
final_summary["status"] = "SUCCESS"
final_summary["copy_stats"] = copy_stats
atomic_write_text(logs_dir / "dataset_summary.json", json.dumps(final_summary, indent=2))

atomic_write_csv(logs_dir / "preprocess_warnings.csv", pd.DataFrame(warnings_rows))

run_config["copy_stats"] = copy_stats
atomic_write_text(cfg_dir / "run_config.json", json.dumps(run_config, indent=2))

print("\n✅ D7 merge complete.")
print(f"- Manifest: {manifest_path}")
print(f"- Clips root: {clips_dir}")
print(f"- Summary: {logs_dir / 'dataset_summary.json'}")
print(f"- Warnings: {logs_dir / 'preprocess_warnings.csv'}")
print(f"- Config: {cfg_dir / 'run_config.json'}")

The following cell trains and validates the D7 multilingual model heads only. It does not run any test evaluation. The cell uses the merged D7 manifest file (`<DX_OUT_ROOT>/manifests/manifest_all.csv`), keeps only rows marked as `train` or `val`, checks that all required columns are present, and stops early with clear errors if the manifest is missing or if any audio clip paths do not exist. A new experiment folder is created under `<DX_OUT_ROOT>/trainval_runs/exp_<tag>_<timestamp>/` so each run is saved separately and can be traced later.

The model uses a frozen Wav2Vec2 backbone, meaning the feature extractor is not updated during training, and two small trainable heads. One head is used for vowel clips and the other for all remaining speech clips. A simple rule routes a clip to the vowel head only when `task == "vowl"`; every other clip is routed to the other head. For vowel clips, the data loader builds an attention mask that zeroes out everything after the last sample whose absolute amplitude exceeds a small threshold (`TINY_THRESH = 1e-4`), which helps prevent learning from trailing near-silence. Audio is loaded from disk, averaged to mono if it has multiple channels, checked to be 16 kHz, and zero-padded to the longest clip in each batch so batches can be processed efficiently.
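
The routing rule and the trailing-silence mask can be sketched without any audio libraries. This is a minimal pure-Python illustration on a toy list of samples; the function name `trailing_silence_mask` is an assumption for this example, while the threshold and task value mirror the settings used in the cell.

```python
TINY_THRESH = 1e-4          # same value as the notebook setting
VOWEL_TASK_VALUE = "vowl"   # exact task string routed to the vowel head

def task_group(task_val):
    """Route a clip to the 'vowel' or 'other' head based on its task value."""
    return "vowel" if str(task_val) == VOWEL_TASK_VALUE else "other"

def trailing_silence_mask(samples, thresh=TINY_THRESH):
    """Return a 0/1 mask that is 1 up to the last sample with |x| > thresh.

    If no sample exceeds the threshold, the mask keeps every sample.
    """
    last = -1
    for j in range(len(samples) - 1, -1, -1):
        if abs(samples[j]) > thresh:
            last = j
            break
    if last < 0:
        return [1] * len(samples)
    return [1] * (last + 1) + [0] * (len(samples) - last - 1)

wave = [0.0, 0.2, -0.3, 0.00001, 0.0]   # toy waveform, not real audio
mask = trailing_silence_mask(wave)
print(task_group("vowl"), mask)
```

Only the tail is masked. Quiet samples in the middle of a clip stay attended, so pauses inside a phonation are not removed.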

Training is run separately for three random seeds: 1337, 2024, and 7777. For each seed, the model trains for up to 10 epochs (`MAX_EPOCHS`) and stops early once validation AUROC has failed to improve for 2 consecutive epochs (`PATIENCE`). After every epoch, the cell runs validation and computes validation AUROC. When a new best AUROC is reached, the cell keeps a snapshot of the trainable components only, which includes the two heads and their small LayerNorm and dropout blocks. At this best validation epoch, the cell also computes a validation-optimal probability threshold using Youden’s J statistic, which selects the point on the ROC curve that maximizes the true positive rate minus the false positive rate. For clarity, the cell records threshold-based metrics on validation at both a fixed threshold of 0.5 and the Youden-optimal threshold. These metrics include accuracy, precision, sensitivity (recall), specificity, F1 score, MCC, and the Fisher exact test p-value. A ROC curve and confusion matrix plots are also saved.
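
To make the Youden's J selection concrete, here is a small pure-Python sketch. It is a simplified illustration, not the cell's implementation: the cell uses scikit-learn's `roc_curve`, whereas this version simply scans the observed scores. The function name `youden_j_threshold` and the toy labels and probabilities are assumptions.

```python
def youden_j_threshold(y_true, y_prob):
    """Return (threshold, J) maximizing J = TPR - FPR over observed scores.

    A clip is predicted positive when its probability is >= threshold.
    """
    n_pos = sum(1 for y in y_true if y == 1)
    n_neg = len(y_true) - n_pos
    if n_pos == 0 or n_neg == 0:
        return float("nan"), float("nan")   # J is undefined with a single class
    best_thr, best_j = float("nan"), float("-inf")
    # Scan candidate thresholds from the highest score downwards
    for thr in sorted(set(y_prob), reverse=True):
        tp = sum(1 for y, p in zip(y_true, y_prob) if y == 1 and p >= thr)
        fp = sum(1 for y, p in zip(y_true, y_prob) if y == 0 and p >= thr)
        j = tp / n_pos - fp / n_neg
        if j > best_j:                      # ties keep the higher threshold
            best_thr, best_j = thr, j
    return best_thr, best_j

# Toy, perfectly separable example
thr, j = youden_j_threshold([0, 0, 1, 1], [0.1, 0.2, 0.7, 0.9])
print(thr, j)
```

Because the threshold is chosen on the validation split, metrics reported at that threshold on the same split are optimistic, which is one reason the cell also records metrics at the fixed 0.5 threshold.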

Outputs are written at two levels. For each seed, results are saved under `<EXP_ROOT>/run_<DATASET>_seed<SEED>/`, which contains `best_heads.pt`, `metrics.json`, and the generated plots. At the experiment level, a single `<EXP_ROOT>/summary_trainval.json` file is written. This summary includes the run settings, validation AUROC results across seeds with the mean and a 95% t-based confidence interval (computed with the critical value for three seeds), and the canonical threshold information used later by test cells, including `val_optimal_threshold.by_seed`, `val_optimal_threshold.mean_sd.mean`, and `val_optimal_threshold.mean_sd.sd`.
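
The across-seed aggregation can be sketched as follows, using the same hardcoded critical value as the cell (two-sided 95%, df = 2). The helper name `t_ci95` and the AUROC values are made up for illustration; with a different number of seeds the critical value would need to change.

```python
import math
import statistics

T_CRIT_DF2_95 = 4.302652729911275   # two-sided 95% t critical value for df=2 (n=3)

def t_ci95(values, t_crit=T_CRIT_DF2_95):
    """Return (mean, low, high) for a t-based CI; t_crit must match df = n - 1."""
    n = len(values)
    mean = statistics.fmean(values)
    if n < 2:
        return mean, mean, mean
    sd = statistics.stdev(values)            # sample SD (ddof=1)
    half = t_crit * sd / math.sqrt(n)
    return mean, mean - half, mean + half

# Made-up AUROC values for three seeds, for illustration only
mean, low, high = t_ci95([0.80, 0.90, 1.00])
print(f"mean={mean:.3f}  CI=[{low:.3f}, {high:.3f}]")
```

With only three seeds the interval is wide, and it is not constrained to the [0, 1] range of AUROC, so it is best read as a rough indicator of seed-to-seed variability.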

The experiment summary is also appended to the running history file at `<DX_OUT_ROOT>/trainval_runs/history_index.jsonl` so previous experiments remain easy to find. At the end of the cell, the Colab runtime is unassigned to cleanly shut down the GPU instance.
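
Because the history file is JSON Lines (one JSON object per line), earlier experiments can be listed with a few lines of standard-library code. The sketch below writes to a temporary folder with made-up records; `append_history` and `read_history` are hypothetical helper names, and only the `experiment_tag` and `mean_auroc` fields are borrowed from the summary written by the cell.

```python
import json
import os
import tempfile

def append_history(path, record):
    """Append one experiment record as a single JSON line."""
    with open(path, "a", encoding="utf-8") as f:
        f.write(json.dumps(record) + "\n")

def read_history(path):
    """Read every non-empty line of a JSONL file back into a list of dicts."""
    with open(path, "r", encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]

# Demo in a temporary folder with made-up records
with tempfile.TemporaryDirectory() as tmp:
    demo_path = os.path.join(tmp, "history_index.jsonl")
    append_history(demo_path, {"experiment_tag": "demo_a", "mean_auroc": 0.5})
    append_history(demo_path, {"experiment_tag": "demo_b", "mean_auroc": 0.6})
    history = read_history(demo_path)

for rec in history:
    print(rec["experiment_tag"], rec["mean_auroc"])
```

Appending rather than rewriting means an interrupted run cannot corrupt earlier entries, at the cost of the file growing with every experiment.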

In [None]:
# D7 Train + Val (Frozen Backbone, Two Heads)
# Trains only the small task heads (vowel vs other) on the D7 train split, validates on the D7 val split,
# saves the best heads per seed, computes a VAL-optimal threshold from the best AUROC epoch, writes per-seed
# artifacts and a single experiment summary, then unassigns the runtime.

# =========================
# Train + Val ONLY (CRASH-PROOF, WITH PROGRESS + HISTORY) - D7 (Multilingual)
# - Frozen Wav2Vec2 backbone
# - Two task heads + tiny LayerNorm + Dropout pre-head (trainable heads only)
# - Uses ONLY: <DX_OUT_ROOT>/manifests/manifest_all.csv
# - Writes ONLY under: <DX_OUT_ROOT>/trainval_runs/exp_<tag>_<timestamp>/
# - Saves best-epoch plots + metrics per seed, plus per-experiment summary + history_index.jsonl
# - Adds additional metrics: accuracy, precision, recall/sensitivity, specificity, F1, MCC, Fisher p-value
# - Determines VAL-opt threshold via Youden J at the BEST-AUROC epoch (per seed)
# - Stores thresholds ONLY as the canonical aggregate in summary_trainval.json:
#     val_optimal_threshold.by_seed
#     val_optimal_threshold.mean_sd.mean
#     val_optimal_threshold.mean_sd.sd
# - Ends by unassigning Colab runtime (L4) with messages
# =========================

import os, json, math, random, time, warnings
from pathlib import Path

import numpy as np
import pandas as pd
import soundfile as sf

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

from transformers import Wav2Vec2Model

from sklearn.metrics import (
    roc_auc_score, roc_curve,
    confusion_matrix, accuracy_score,
    precision_recall_fscore_support,
    matthews_corrcoef
)
import matplotlib.pyplot as plt
from tqdm.auto import tqdm

# -------------------------
# Guard against local name conflicts
# Inputs: local filesystem under /content/
# Output: stops early with a clear error if a local file would break imports
# -------------------------
if os.path.exists("/content/torch.py") or os.path.exists("/content/torch/__init__.py"):
    raise RuntimeError("Found local /content/torch.py or /content/torch/ that shadows PyTorch. Rename/remove it and restart runtime.")
if os.path.exists("/content/transformers.py") or os.path.exists("/content/transformers/__init__.py"):
    raise RuntimeError("Found local /content/transformers.py or /content/transformers/ that shadows Hugging Face Transformers. Rename/remove it and restart runtime.")

# -------------------------
# Google Drive mount (safe if already mounted)
# Inputs: Colab runtime state
# Output: ensures /content/drive/MyDrive is available when running in Colab
# -------------------------
try:
    from google.colab import drive  # type: ignore
    if not os.path.isdir("/content/drive/MyDrive"):
        drive.mount("/content/drive")
except Exception:
    pass

# -------------------------
# Resolve run root and manifest path
# Inputs: DX_OUT_ROOT (if already set) or fallback D7 path
# Output: DX_OUT_ROOT and MANIFEST_ALL used by the rest of the cell
# -------------------------
# Rule:
# - Prefer runtime variable DX_OUT_ROOT if already defined (allows reuse across cells).
# - Otherwise use the D7 fallback path below.
D7_OUT_ROOT_FALLBACK = "/content/drive/MyDrive/AI_PD_Project/Datasets/D7-Multilingual (D1_D4_D5v2_D6)/preprocessed_v1"
DX_OUT_ROOT = str(globals().get("DX_OUT_ROOT", D7_OUT_ROOT_FALLBACK))
MANIFEST_ALL = f"{DX_OUT_ROOT}/manifests/manifest_all.csv"

# Keep for later cells that only need DX_OUT_ROOT
globals()["DX_OUT_ROOT"] = DX_OUT_ROOT

# -------------------------
# Experiment folder setup
# Inputs: EXPERIMENT_TAG + timestamp
# Output: a new exp_* folder that keeps older runs intact
# -------------------------
EXPERIMENT_TAG = "frozen_LNDO"
RUN_STAMP = time.strftime("%Y%m%d_%H%M%S")

TRAINVAL_ROOT = Path(DX_OUT_ROOT) / "trainval_runs"
EXP_ROOT = TRAINVAL_ROOT / f"exp_{EXPERIMENT_TAG}_{RUN_STAMP}"
EXP_ROOT.mkdir(parents=True, exist_ok=True)

# -------------------------
# Training settings
# Inputs: fixed hyperparameters and routing rules
# Output: consistent train/val behavior across seeds
# -------------------------
MAX_EPOCHS     = 10
EFFECTIVE_BS   = 64
PER_DEVICE_BS  = 16
GRAD_ACCUM     = max(1, EFFECTIVE_BS // PER_DEVICE_BS)

LR             = 1e-3
PATIENCE       = 2
SEEDS          = [1337, 2024, 7777]

BACKBONE_CKPT  = "facebook/wav2vec2-base"
SR_EXPECTED    = 16000
TINY_THRESH    = 1e-4

DROPOUT_P      = 0.2

NUM_WORKERS    = 0
PIN_MEMORY     = False

# Task routing rule: tasks with this exact value use the vowel head
VOWEL_TASK_VALUE = "vowl"

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
try:
    torch.set_float32_matmul_precision("high")
except Exception:
    pass

warnings.filterwarnings("ignore", message="None of the inputs have requires_grad=True. Gradients will be None")
warnings.filterwarnings("ignore", message="Passing `gradient_checkpointing` to a config initialization is deprecated")
warnings.filterwarnings("ignore", category=UserWarning, module="huggingface_hub")

print("DX_OUT_ROOT:", DX_OUT_ROOT)
print("MANIFEST_ALL:", MANIFEST_ALL)
print("DEVICE:", DEVICE)
if DEVICE.type == "cuda":
    print("GPU:", torch.cuda.get_device_name(0))
print("PER_DEVICE_BS:", PER_DEVICE_BS, "| GRAD_ACCUM:", GRAD_ACCUM, "| EFFECTIVE_BS:", PER_DEVICE_BS * GRAD_ACCUM)
print("NUM_WORKERS:", NUM_WORKERS, "| PIN_MEMORY:", PIN_MEMORY)
print("EXPERIMENT_TAG:", EXPERIMENT_TAG, "| RUN_STAMP:", RUN_STAMP)

# -------------------------
# Load manifest and split into train/val tables
# Inputs: manifest_all.csv
# Output: train_df and val_df with required columns only
# -------------------------
if not os.path.exists(MANIFEST_ALL):
    raise FileNotFoundError(
        "Missing manifest_all.csv at:\n"
        f"  {MANIFEST_ALL}\n"
        "Confirm that the D7 merge-builder wrote manifests/manifest_all.csv under DX_OUT_ROOT."
    )

m = pd.read_csv(MANIFEST_ALL)

req_cols = {"split", "clip_path", "label_num", "task"}
missing = [c for c in sorted(req_cols) if c not in m.columns]
if missing:
    raise ValueError(f"Manifest missing required columns: {missing}. Found: {list(m.columns)}")

m = m[m["split"].isin(["train", "val"])].copy()
if len(m) == 0:
    raise RuntimeError("After filtering to split in {train,val}, manifest has 0 rows.")

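# Infer dataset_id (the most frequent value in the dataset column) for naming.
# Note: rows whose dataset value differs from dataset_id are dropped by the filter below.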
if "dataset" in m.columns and m["dataset"].notna().any():
    dataset_id = str(m["dataset"].astype(str).value_counts(dropna=True).idxmax())
    m = m[m["dataset"].astype(str) == dataset_id].copy()
else:
    dataset_id = "DX"

keep_cols = ["clip_path", "label_num", "task", "speaker_id", "duration_sec", "split"]
for c in keep_cols:
    if c not in m.columns:
        m[c] = np.nan
m = m[keep_cols].copy()

train_df = m[m["split"] == "train"].copy().reset_index(drop=True)
val_df   = m[m["split"] == "val"].copy().reset_index(drop=True)

print(f"\nDataset inferred: {dataset_id}")
print(f"Train rows: {len(train_df)} | Val rows: {len(val_df)}")
print("Train label counts:", train_df["label_num"].value_counts(dropna=False).to_dict())
print("Val label counts:",   val_df["label_num"].value_counts(dropna=False).to_dict())

if len(train_df) == 0 or len(val_df) == 0:
    raise RuntimeError("Train or Val split has 0 rows.")

# -------------------------
# Fail fast if audio files are missing
# Inputs: clip_path values in train_df/val_df
# Output: stops early with a few missing examples instead of failing mid-epoch
# -------------------------
def _fail_fast_missing_paths(df: pd.DataFrame, name: str):
    missing_paths = []
    for p in tqdm(df["clip_path"].astype(str).tolist(), desc=f"Check {name} clip_path exists", dynamic_ncols=True):
        if not os.path.exists(p):
            missing_paths.append(p)
            if len(missing_paths) >= 10:
                break
    if missing_paths:
        raise FileNotFoundError(f"{name}: missing clip_path(s). Examples: {missing_paths}")

_fail_fast_missing_paths(train_df, "TRAIN")
_fail_fast_missing_paths(val_df, "VAL")

# -------------------------
# Add task group used for head routing
# Inputs: task column in the manifest
# Output: task_group column with values "vowel" or "other"
# -------------------------
def _task_group(task_val) -> str:
    return "vowel" if str(task_val) == VOWEL_TASK_VALUE else "other"

train_df["task_group"] = train_df["task"].apply(_task_group)
val_df["task_group"]   = val_df["task"].apply(_task_group)

# -------------------------
# Dataset and collator
# Inputs: tables + audio files
# Output: padded batches with attention masks and per-item task_group
# -------------------------
class AudioManifestDataset(Dataset):
    """
    Loads a clip and creates an attention mask.

    Mask rule:
    - vowel clips: mask trailing near-zeros to reduce learning from padded silence.
    - other clips: keep full attention.
    """
    def __init__(self, df: pd.DataFrame):
        self.df = df.reset_index(drop=True)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx: int):
        row = self.df.iloc[idx]
        clip_path = str(row["clip_path"])
        label = int(row["label_num"])
        task_group = str(row["task_group"])

        if not os.path.exists(clip_path):
            raise FileNotFoundError(f"Missing clip_path: {clip_path}")

        y, sr = sf.read(clip_path, always_2d=False)
        if isinstance(y, np.ndarray) and y.ndim > 1:
            y = np.mean(y, axis=1)
        y = np.asarray(y, dtype=np.float32)

        if int(sr) != SR_EXPECTED:
            raise ValueError(f"Sample rate mismatch (expected {SR_EXPECTED}): {clip_path} has sr={sr}")

        attn = np.ones((len(y),), dtype=np.int64)

        if task_group == "vowel":
            # Find last non-trivial sample and mask the trailing near-zero region.
            # If no sample exceeds the threshold, full attention is kept.
            nonzero_idx = np.flatnonzero(np.abs(y.astype(np.float64)) > TINY_THRESH)
            if nonzero_idx.size > 0:
                attn[int(nonzero_idx[-1]) + 1:] = 0

        return {
            "input_values": torch.from_numpy(y),                 # float32 [T]
            "attention_mask": torch.from_numpy(attn),            # int64   [T]
            "labels": torch.tensor(label, dtype=torch.long),     # int64   []
            "task_group": task_group,                            # str
        }

def collate_fn(batch):
    """Pads waveforms and masks to the longest clip in the batch."""
    max_len = int(max(b["input_values"].numel() for b in batch))
    input_vals, attn_masks, labels, task_groups = [], [], [], []
    for b in batch:
        x = b["input_values"]
        a = b["attention_mask"]
        pad = max_len - x.numel()
        if pad > 0:
            x = torch.cat([x, torch.zeros(pad, dtype=x.dtype)], dim=0)
            a = torch.cat([a, torch.zeros(pad, dtype=a.dtype)], dim=0)
        input_vals.append(x)
        attn_masks.append(a)
        labels.append(b["labels"])
        task_groups.append(b["task_group"])
    return {
        "input_values": torch.stack(input_vals, dim=0),      # [B,T]
        "attention_mask": torch.stack(attn_masks, dim=0),    # [B,T]
        "labels": torch.stack(labels, dim=0),                # [B]
        "task_group": task_groups,                           # list[str]
    }

# -------------------------
# Model definition
# Inputs: backbone checkpoint + dropout rate
# Output: frozen backbone model with two small, trainable heads
# -------------------------
class Wav2Vec2TwoHeadClassifier(nn.Module):
    """
    Frozen backbone with two task heads.

    Trainable parts:
    - LayerNorm+Dropout blocks (pre_vowel, pre_other)
    - Linear heads (head_vowel, head_other)
    """
    def __init__(self, ckpt: str, dropout_p: float):
        super().__init__()
        self.backbone = Wav2Vec2Model.from_pretrained(
            ckpt,
            use_safetensors=True,
            local_files_only=False
        )
        for p in self.backbone.parameters():
            p.requires_grad = False

        H = int(self.backbone.config.hidden_size)
        self.pre_vowel = nn.Sequential(nn.LayerNorm(H), nn.Dropout(float(dropout_p)))
        self.pre_other = nn.Sequential(nn.LayerNorm(H), nn.Dropout(float(dropout_p)))

        self.head_vowel = nn.Linear(H, 2)
        self.head_other = nn.Linear(H, 2)
        self.loss_fn = nn.CrossEntropyLoss()

    def masked_mean_pool(self, last_hidden, attn_mask_samples):
        # Convert sample-level mask to feature-level mask, then average only valid frames
        feat_mask = self.backbone._get_feature_vector_attention_mask(last_hidden.shape[1], attn_mask_samples)
        feat_mask = feat_mask.to(last_hidden.device).unsqueeze(-1).type_as(last_hidden)
        masked = last_hidden * feat_mask
        denom = feat_mask.sum(dim=1).clamp(min=1.0)
        return masked.sum(dim=1) / denom

    def forward(self, input_values, attention_mask, labels, task_group):
        # Backbone stays frozen and runs without gradients
        with torch.no_grad():
            out = self.backbone(input_values=input_values, attention_mask=attention_mask)
            last_hidden = out.last_hidden_state  # [B,T',H]

        pooled = self.masked_mean_pool(last_hidden, attention_mask).float()  # [B,H]

        z_v = self.pre_vowel(pooled)
        z_o = self.pre_other(pooled)

        logits_v = self.head_vowel(z_v)  # 2-class logits
        logits_o = self.head_other(z_o)

        # Route each item to the matching head
        is_vowel = torch.tensor([tg == "vowel" for tg in task_group], device=pooled.device, dtype=torch.bool)
        logits = logits_o.clone()
        if is_vowel.any():
            logits[is_vowel] = logits_v[is_vowel]

        loss = self.loss_fn(logits, labels)
        return loss, logits

# -------------------------
# Metric utilities
# Inputs: true labels, probabilities, threshold
# Output: AUROC, threshold metrics, plots, and simple mean±SD summaries
# -------------------------
def compute_auc(y_true, y_prob):
    y_true = np.asarray(y_true, dtype=np.int64)
    y_prob = np.asarray(y_prob, dtype=np.float64)
    if len(np.unique(y_true)) < 2:
        return float("nan")
    return float(roc_auc_score(y_true, y_prob))

def compute_threshold_metrics(y_true, y_prob, thr=0.5):
    """
    Converts probabilities into hard predictions using thr, then computes common metrics.
    """
    y_true = np.asarray(y_true, dtype=np.int64)
    y_prob = np.asarray(y_prob, dtype=np.float64)
    y_pred = (y_prob >= float(thr)).astype(np.int64)

    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
    tn, fp, fn, tp = (cm.ravel().tolist() if cm.size == 4 else [0, 0, 0, 0])

    acc = float(accuracy_score(y_true, y_pred))
    prec, rec, f1, _ = precision_recall_fscore_support(y_true, y_pred, average="binary", zero_division=0)
    mcc = float(matthews_corrcoef(y_true, y_pred)) if len(np.unique(y_true)) > 1 else float("nan")

    sensitivity = float(rec)
    specificity = float(tn / (tn + fp)) if (tn + fp) > 0 else float("nan")

    p_value = float("nan")
    try:
        from scipy.stats import fisher_exact  # type: ignore
        _, p_value = fisher_exact([[tn, fp], [fn, tp]], alternative="two-sided")
        p_value = float(p_value)
    except Exception:
        p_value = float("nan")

    return {
        "threshold": float(thr),
        "tn": int(tn), "fp": int(fp), "fn": int(fn), "tp": int(tp),
        "accuracy": float(acc),
        "precision": float(prec),
        "recall": float(rec),
        "f1": float(f1),
        "sensitivity": float(sensitivity),
        "specificity": float(specificity),
        "mcc": float(mcc),
        "p_value_fisher": float(p_value),
    }

def compute_youden_j_threshold(y_true, y_prob):
    """
    Finds the ROC threshold that maximizes Youden J (TPR - FPR).
    Returns an optimal threshold plus a small details dict.
    """
    y_true = np.asarray(y_true, dtype=np.int64)
    y_prob = np.asarray(y_prob, dtype=np.float64)
    if len(np.unique(y_true)) < 2:
        return float("nan"), {"youden_j": float("nan"), "tpr": float("nan"), "fpr": float("nan")}
    fpr, tpr, thr = roc_curve(y_true, y_prob)
    j = tpr - fpr
    idx = int(np.argmax(j))
    return float(thr[idx]), {"youden_j": float(j[idx]), "tpr": float(tpr[idx]), "fpr": float(fpr[idx])}

def save_roc_curve_png(y_true, y_prob, out_png):
    # Writes a simple ROC plot for the best epoch (VAL only)
    fpr, tpr, _ = roc_curve(y_true, y_prob)
    plt.figure()
    plt.plot(fpr, tpr)
    plt.plot([0, 1], [0, 1], linestyle="--")
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    plt.title("ROC Curve (Val)")
    plt.tight_layout()
    plt.savefig(out_png, dpi=150)
    plt.close()

def save_confusion_png(y_true, y_prob, out_png, thr=0.5):
    # Writes a confusion matrix image for a chosen threshold (VAL only)
    y_true = np.asarray(y_true, dtype=np.int64)
    y_prob = np.asarray(y_prob, dtype=np.float64)
    y_pred = (y_prob >= float(thr)).astype(int)
    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])

    plt.figure()
    plt.imshow(cm)
    plt.xticks([0, 1], ["Healthy(0)", "PD(1)"])
    plt.yticks([0, 1], ["Healthy(0)", "PD(1)"])
    plt.xlabel("Predicted")
    plt.ylabel("True")
    plt.title(f"Confusion Matrix (Val, thr={thr:.4f})")
    for (i, j), v in np.ndenumerate(cm):
        plt.text(j, i, str(v), ha="center", va="center")
    plt.tight_layout()
    plt.savefig(out_png, dpi=150)
    plt.close()

def mean_sd(vals):
    vals = np.asarray(vals, dtype=np.float64)
    mu = float(np.nanmean(vals)) if np.any(~np.isnan(vals)) else float("nan")
    sd = float(np.nanstd(vals, ddof=1)) if np.sum(~np.isnan(vals)) > 1 else 0.0
    return mu, sd

# -------------------------
# Seed control
# Inputs: seed integer
# Output: repeatable training and validation behavior for that seed
# -------------------------
def set_all_seeds(seed: int):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# -------------------------
# One seed run: train, validate, keep best epoch
# Inputs: train_df, val_df, and fixed settings
# Outputs: best_heads.pt + plots + metrics.json under run_<dataset>_seedXXXX
# -------------------------
def run_trainval_once(seed: int):
    set_all_seeds(seed)

    run_dir = EXP_ROOT / f"run_{dataset_id}_seed{seed}"
    run_dir.mkdir(parents=True, exist_ok=True)

    train_ds = AudioManifestDataset(train_df)
    val_ds   = AudioManifestDataset(val_df)

    train_loader = DataLoader(
        train_ds,
        batch_size=PER_DEVICE_BS,
        shuffle=True,
        num_workers=NUM_WORKERS,
        pin_memory=PIN_MEMORY,
        collate_fn=collate_fn
    )
    val_loader = DataLoader(
        val_ds,
        batch_size=PER_DEVICE_BS,
        shuffle=False,
        num_workers=NUM_WORKERS,
        pin_memory=PIN_MEMORY,
        collate_fn=collate_fn
    )

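    # Quick warm-up: catches basic data issues before the first full epoch
    # (capped at the number of available batches so a small train set cannot raise StopIteration)
    n_warm = min(3, len(train_loader))
    print(f"\n[seed={seed}] Warm-up: loading {n_warm} train batches...")
    t0 = time.time()
    it = iter(train_loader)
    for i in range(n_warm):
        _ = next(it)
        print(f"  loaded warmup batch {i+1}/{n_warm}")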
    print(f"[seed={seed}] Warm-up done in {time.time()-t0:.2f}s")

    model = Wav2Vec2TwoHeadClassifier(BACKBONE_CKPT, dropout_p=DROPOUT_P).to(DEVICE)

    # Only train the small head blocks (backbone remains frozen)
    trainable_params = (
        list(model.pre_vowel.parameters()) + list(model.pre_other.parameters()) +
        list(model.head_vowel.parameters()) + list(model.head_other.parameters())
    )
    opt = torch.optim.Adam(trainable_params, lr=LR)

    best_auc = -1.0
    best_epoch = -1
    no_improve = 0

    # Snapshot of the best epoch (used for saving heads and computing thresholds)
    best_state = None
    best_val_probs = None
    best_val_true = None

    best_thr_youden = float("nan")
    best_thr_youden_details = None
    best_val_metrics_thr05 = None
    best_val_metrics_thr_opt = None

    for epoch in range(1, MAX_EPOCHS + 1):
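        # Note: model.train() also switches the frozen backbone to training mode, so any
        # dropout or SpecAugment masking enabled in the checkpoint config stays active here;
        # the backbone weights themselves are never updated.
        model.train()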
        train_losses = []
        opt.zero_grad(set_to_none=True)

        pbar = tqdm(train_loader, desc=f"[seed={seed}] Train epoch {epoch}", dynamic_ncols=True)
        step = 0
        for batch in pbar:
            step += 1
            input_values = batch["input_values"].to(DEVICE, non_blocking=False)
            attention_mask = batch["attention_mask"].to(DEVICE, non_blocking=False)
            labels = batch["labels"].to(DEVICE, non_blocking=False)
            task_group = batch["task_group"]

            loss, _ = model(input_values, attention_mask, labels, task_group)
            loss = loss / GRAD_ACCUM
            loss.backward()

            train_losses.append(float(loss.detach().cpu().item()) * GRAD_ACCUM)

            if (step % GRAD_ACCUM) == 0:
                opt.step()
                opt.zero_grad(set_to_none=True)

        # Final optimizer step if the epoch ends mid-accumulation
        if (step % GRAD_ACCUM) != 0:
            opt.step()
            opt.zero_grad(set_to_none=True)

        avg_train_loss = float(np.mean(train_losses)) if train_losses else float("nan")

        # Validation: collect probabilities for the full VAL set
        model.eval()
        all_probs, all_true = [], []
        vpbar = tqdm(val_loader, desc=f"[seed={seed}] Val epoch {epoch}", dynamic_ncols=True)
        with torch.inference_mode():
            for batch in vpbar:
                input_values = batch["input_values"].to(DEVICE, non_blocking=False)
                attention_mask = batch["attention_mask"].to(DEVICE, non_blocking=False)
                labels = batch["labels"].to(DEVICE, non_blocking=False)
                task_group = batch["task_group"]

                _, logits = model(input_values, attention_mask, labels, task_group)
                probs = torch.softmax(logits.float(), dim=-1)[:, 1].detach().cpu().numpy().astype(np.float64)
                all_probs.extend(probs.tolist())
                all_true.extend(labels.detach().cpu().numpy().astype(np.int64).tolist())

        val_auc = compute_auc(all_true, all_probs)
        print(f"seed={seed} | epoch {epoch:02d}/{MAX_EPOCHS} | train_loss={avg_train_loss:.5f} | val_AUROC={val_auc:.5f}")

        # Keep the epoch with the best VAL AUROC
        improved = (not math.isnan(val_auc)) and (val_auc > best_auc + 1e-12)
        if improved:
            best_auc = float(val_auc)
            best_epoch = int(epoch)
            no_improve = 0

            # Save only the trainable parts (for later test-only reload)
            best_state = {
                "pre_vowel": {k: v.detach().cpu().clone() for k, v in model.pre_vowel.state_dict().items()},
                "pre_other": {k: v.detach().cpu().clone() for k, v in model.pre_other.state_dict().items()},
                "head_vowel": {k: v.detach().cpu().clone() for k, v in model.head_vowel.state_dict().items()},
                "head_other": {k: v.detach().cpu().clone() for k, v in model.head_other.state_dict().items()},
            }

            best_val_probs = list(all_probs)
            best_val_true  = list(all_true)

            # Reference metrics at thr=0.5
            best_val_metrics_thr05 = compute_threshold_metrics(best_val_true, best_val_probs, thr=0.5)

            # VAL-opt threshold is computed only at the best AUROC epoch
            thr_opt, details = compute_youden_j_threshold(best_val_true, best_val_probs)
            best_thr_youden = float(thr_opt)
            best_thr_youden_details = details
            best_val_metrics_thr_opt = compute_threshold_metrics(best_val_true, best_val_probs, thr=best_thr_youden)
        else:
            no_improve += 1

        # Early stopping when AUROC has not improved for PATIENCE epochs
        if no_improve >= PATIENCE:
            break

    if best_state is None or best_val_probs is None or best_val_true is None:
        raise RuntimeError(
            "No best epoch captured. Validation AUROC may be NaN due to single-class validation split "
            "or earlier failures."
        )

    # Save best epoch outputs (heads + plots + metrics)
    best_heads_path = run_dir / "best_heads.pt"
    torch.save(best_state, str(best_heads_path))

    roc_png = run_dir / "roc_curve.png"
    cm_png_05 = run_dir / "confusion_matrix_thr0p5.png"
    cm_png_opt = run_dir / "confusion_matrix_thr_opt.png"

    save_roc_curve_png(np.asarray(best_val_true, dtype=np.int64), np.asarray(best_val_probs, dtype=np.float64), str(roc_png))
    save_confusion_png(np.asarray(best_val_true, dtype=np.int64), np.asarray(best_val_probs, dtype=np.float64), str(cm_png_05), thr=0.5)
    if not np.isnan(best_thr_youden):
        save_confusion_png(np.asarray(best_val_true, dtype=np.int64), np.asarray(best_val_probs, dtype=np.float64), str(cm_png_opt), thr=float(best_thr_youden))

    # Per-seed metrics file is detailed; the experiment summary stores the canonical threshold aggregate
    metrics = {
        "dataset": dataset_id,
        "seed": int(seed),
        "best_val_auroc": float(best_auc),
        "best_epoch": int(best_epoch),

        "n_train": int(len(train_df)),
        "n_val": int(len(val_df)),
        "label_counts_train": train_df["label_num"].value_counts(dropna=False).to_dict(),
        "label_counts_val": val_df["label_num"].value_counts(dropna=False).to_dict(),

        "experiment_tag": EXPERIMENT_TAG,
        "run_stamp": RUN_STAMP,

        "dropout_p": float(DROPOUT_P),
        "lr": float(LR),
        "effective_batch_size": int(PER_DEVICE_BS * GRAD_ACCUM),
        "per_device_batch_size": int(PER_DEVICE_BS),
        "grad_accum_steps": int(GRAD_ACCUM),

        "backbone_ckpt": BACKBONE_CKPT,

        "val_opt_threshold_method": "Youden J (maximize TPR - FPR on VAL ROC curve)",
        "val_opt_threshold": float(best_thr_youden),
        "val_opt_details": best_thr_youden_details,

        "thr_metrics_val_thr0p5": best_val_metrics_thr05,
        "thr_metrics_val_thr_opt": best_val_metrics_thr_opt,

        "artifacts": {
            "roc_curve_png": str(roc_png),
            "confusion_thr0p5_png": str(cm_png_05),
            "confusion_thr_opt_png": str(cm_png_opt),
            "best_heads_pt": str(best_heads_path),
        },
    }

    metrics_path = run_dir / "metrics.json"
    with open(metrics_path, "w", encoding="utf-8") as f:
        json.dump(metrics, f, indent=2)

    print(f"[seed={seed}] VAL-opt threshold (Youden J): {float(best_thr_youden):.6f}")
    print(f"[seed={seed}] WROTE:")
    print(" ", str(metrics_path))
    print(" ", str(roc_png))
    print(" ", str(cm_png_05))
    print(" ", str(cm_png_opt))
    print(" ", str(best_heads_path))

    return {
        "seed": int(seed),
        "best_val_auroc": float(best_auc),
        "best_epoch": int(best_epoch),
        "val_opt_thr": float(best_thr_youden),
        "run_dir": str(run_dir),
        "seed_metrics": metrics,
    }

# -------------------------
# Run all seeds and write experiment summary
# Inputs: SEEDS list
# Outputs: summary_trainval.json in EXP_ROOT and history_index.jsonl in trainval_runs/
# -------------------------
results = []
for seed in SEEDS:
    results.append(run_trainval_once(seed))

aucs = [r["best_val_auroc"] for r in results]
thr_vals = [r["val_opt_thr"] for r in results]

# Aggregate AUROC across seeds (CI is for n=3 only)
t_crit = 4.302652729911275  # df=2, 95% CI
n = len(aucs)
mean_auc = float(np.mean(aucs))
std_auc = float(np.std(aucs, ddof=1)) if n > 1 else 0.0
half_width = float(t_crit * (std_auc / math.sqrt(n))) if n > 1 else 0.0
ci95 = [float(mean_auc - half_width), float(mean_auc + half_width)]

thr_mean, thr_sd = mean_sd(thr_vals)

print("\nAUROC by seed:")
for r in results:
    print(f"  seed {r['seed']}: {r['best_val_auroc']:.6f}")
print(f"\nMean AUROC: {mean_auc:.6f}")
print(f"95% CI (t, n=3): [{ci95[0]:.6f}, {ci95[1]:.6f}]")

print("\nVAL-opt thresholds (Youden J) by seed:")
for r in results:
    print(f"  seed {r['seed']}: {r['val_opt_thr']:.6f}")
print(f"  mean ± SD: {thr_mean:.6f} ± {thr_sd:.6f}")

# Canonical threshold storage for test-only cells
val_optimal_threshold_obj = {
    "method": "Youden J (maximize TPR - FPR on VAL ROC curve)",
    "by_seed": {str(r["seed"]): float(r["val_opt_thr"]) for r in results},
    "mean_sd": {"mean": float(thr_mean), "sd": float(thr_sd)},
}

exp_summary = {
    "dataset": dataset_id,
    "dx_out_root": DX_OUT_ROOT,
    "manifest_all": MANIFEST_ALL,

    "experiment_tag": EXPERIMENT_TAG,
    "run_stamp": RUN_STAMP,
    "exp_root": str(EXP_ROOT),
    "run_dirs": [r["run_dir"] for r in results],
    "seeds": SEEDS,

    "aurocs": [float(x) for x in aucs],
    "mean_auroc": float(mean_auc),
    "t_crit_df2_95": float(t_crit),
    "half_width_95_ci": float(half_width),
    "ci95": ci95,

    "n_train": int(len(train_df)),
    "n_val": int(len(val_df)),
    "label_counts_train": train_df["label_num"].value_counts(dropna=False).to_dict(),
    "label_counts_val": val_df["label_num"].value_counts(dropna=False).to_dict(),

    "effective_batch_size": int(PER_DEVICE_BS * GRAD_ACCUM),
    "per_device_batch_size": int(PER_DEVICE_BS),
    "grad_accum_steps": int(GRAD_ACCUM),

    "backbone_ckpt": BACKBONE_CKPT,
    "dropout_p": float(DROPOUT_P),
    "lr": float(LR),

    "val_optimal_threshold": val_optimal_threshold_obj,

    # Full per-seed payload kept for traceability
    "per_seed_metrics": [r["seed_metrics"] for r in results],
}

summary_path = EXP_ROOT / "summary_trainval.json"
with open(summary_path, "w", encoding="utf-8") as f:
    json.dump(exp_summary, f, indent=2)

history_path = TRAINVAL_ROOT / "history_index.jsonl"
TRAINVAL_ROOT.mkdir(parents=True, exist_ok=True)
with open(history_path, "a", encoding="utf-8") as f:
    f.write(json.dumps(exp_summary) + "\n")

print("\nWROTE per-experiment summary:", str(summary_path))
print("APPENDED global history index:", str(history_path))
print("\nOpen this folder to access artifacts:", str(EXP_ROOT))

# -------------------------
# Stop the runtime to release the GPU
# Inputs: Colab runtime module (if available)
# Output: ends the session cleanly when running on paid GPU instances
# -------------------------
print("\nAll done. Unassigning the runtime to stop the L4 instance...")
try:
    from google.colab import runtime  # type: ignore
    print("Calling runtime.unassign() now.")
    runtime.unassign()
except Exception as e:
    print("Could not unassign runtime automatically. Error:", repr(e))
    print("Manual stop: Runtime -> Disconnect and delete runtime.")

The following cell evaluates the D7 multilingual model on its own D7 test split. It does not train or update any model parameters. The cell reads the D7 manifest file (`<DX_OUT_ROOT>/manifests/manifest_all.csv`), checks that the required columns (`split`, `clip_path`, `label_num`, `task`) are present, filters the data to `split == "test"`, and verifies that every audio file listed for the test split exists before continuing. It also adds two helper fields: `task_group`, which routes each clip to a classification head and is set to “vowel” only when `task == "vowl"` and to “other” otherwise, and `sex_norm`, which maps the raw `sex` value to M, F, or UNK (numeric 0 maps to F, numeric 1 maps to M, and missing or unrecognized values map to UNK) and is used for the fairness metrics and per-sex charts.

The model is rebuilt in the same way as in the other evaluation cells. It uses a frozen Wav2Vec2 backbone (`facebook/wav2vec2-base`) with two small classification heads, one for vowel clips and one for all other clips. For vowel clips, the attention mask is zeroed after the last sample whose absolute amplitude exceeds 1e-4, so that trailing near-silence has less influence on the pooled representation. For other clips, full attention is used. A DataLoader is then created and warmed up by loading up to three batches to catch data issues early, such as missing files or incorrect sample rates.
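
The trailing-silence rule for vowel clips can be illustrated on a toy signal. This is a minimal plain-Python sketch of the rule as described above, not the cell's own implementation; the function name `trailing_silence_mask` and the sample values are hypothetical.

```python
def trailing_silence_mask(samples, tiny_thresh=1e-4):
    """Return a 0/1 mask that drops samples after the last one louder than tiny_thresh."""
    last_loud = -1
    for i, s in enumerate(samples):
        if abs(s) > tiny_thresh:
            last_loud = i
    if last_loud < 0:
        # Nothing exceeds the threshold: keep full attention instead of masking everything.
        return [1] * len(samples)
    return [1] * (last_loud + 1) + [0] * (len(samples) - last_loud - 1)

# Illustrative clip: a quiet sample in the middle, then near-silence at the end.
clip = [0.0, 0.2, -0.1, 0.00005, 0.3, 0.00002, 0.0, 0.0]
mask = trailing_silence_mask(clip)
print(mask)  # [1, 1, 1, 1, 1, 0, 0, 0]
```

Only the tail is masked. The quiet sample in the middle of the clip stays attended because a louder sample follows it.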

The cell then automatically finds the most recent qualifying D7 train and validation experiment under `<DX_OUT_ROOT>/trainval_runs/exp_*`, ordered by folder modification time. A qualifying experiment must contain a `best_heads.pt` file for each of the three seeds (1337, 2024, and 7777) and a readable `summary_trainval.json` that includes the validation-optimal thresholds, both by seed and as an overall mean with standard deviation. From this summary, the cell takes the stored mean validation-optimal threshold as the single shared threshold. If the stored mean is NaN, it falls back to the mean of the available per-seed thresholds, and if none of those are usable, to 0.5.
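
The threshold selection order can be sketched as follows. This is an illustrative sketch under the assumption that the summary has the `by_seed` and `mean_sd` layout described above; the helper name `pick_shared_threshold`, the note strings, and the threshold values are hypothetical.

```python
import math

def pick_shared_threshold(val_opt, seeds):
    """Return (threshold, note): stored mean first, then per-seed mean, then 0.5."""
    mean = float(val_opt.get("mean_sd", {}).get("mean", float("nan")))
    if not math.isnan(mean):
        return mean, None
    by_seed = val_opt.get("by_seed", {})
    usable = [
        float(by_seed[str(s)])
        for s in seeds
        if str(s) in by_seed and not math.isnan(float(by_seed[str(s)]))
    ]
    if usable:
        return sum(usable) / len(usable), "stored mean missing; used mean of per-seed thresholds"
    return 0.5, "no usable thresholds; used 0.5"

# Illustrative summary in which the stored mean is NaN, so the per-seed fallback applies.
summary = {
    "by_seed": {"1337": 0.40, "2024": 0.50, "7777": 0.60},
    "mean_sd": {"mean": float("nan"), "sd": float("nan")},
}
thr, note = pick_shared_threshold(summary, [1337, 2024, 7777])
print(round(thr, 6), "|", note)
```

Returning a note alongside the threshold makes it possible to record in the output files whenever a fallback was used.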

Using this shared threshold for all seeds, the cell runs inference on the D7 test set separately for each seed. For each seed, it loads the saved head weights, produces Parkinson’s disease probability scores, and computes AUROC on the full test set. It also computes threshold-based metrics at the shared threshold: confusion matrix counts, accuracy, precision, sensitivity (recall), specificity, F1 score, MCC, and the two-sided Fisher exact test p-value. Fairness is evaluated using the H3 definition at the same threshold: ΔFNR = FNR(F) - FNR(M), where FNR is computed only among true Parkinson’s cases as FN / (FN + TP).
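
A small worked example may make the ΔFNR definition concrete. The labels, probabilities, and sex values below are invented for illustration only, and `fnr_among_pd` is a hypothetical helper rather than a function from this notebook.

```python
def fnr_among_pd(y_true, y_pred):
    """False-negative rate over true PD cases only: FN / (FN + TP)."""
    tp = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 1)
    fn = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 0)
    return fn / (fn + tp) if (fn + tp) > 0 else float("nan")

# Toy data: label 1 = PD, label 0 = healthy control.
y_true = [1, 1, 1, 1, 0, 1, 1, 1, 1, 0]
y_prob = [0.9, 0.8, 0.3, 0.7, 0.2, 0.6, 0.4, 0.2, 0.9, 0.1]
sex = ["M", "M", "M", "M", "M", "F", "F", "F", "F", "F"]

thr = 0.5  # shared threshold (illustrative value)
y_pred = [1 if p >= thr else 0 for p in y_prob]

def fnr_for(group):
    # Restrict to one sex, then compute FNR among that group's true PD cases.
    idx = [i for i, s in enumerate(sex) if s == group]
    return fnr_among_pd([y_true[i] for i in idx], [y_pred[i] for i in idx])

fnr_m = fnr_for("M")           # 1 missed of 4 PD cases -> 0.25
fnr_f = fnr_for("F")           # 2 missed of 4 PD cases -> 0.50
delta_fnr = fnr_f - fnr_m      # 0.25
print(fnr_m, fnr_f, delta_fnr)
```

A positive ΔFNR means PD cases are missed more often among female speakers than among male speakers at the chosen threshold. Healthy controls do not enter the calculation.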

For each seed, outputs are written to a dedicated folder at
`<DX_OUT_ROOT>/multilingual_test_runs/run_<DATASET>_seed<SEED>/`.
These outputs include `metrics.json`, a ROC curve image, an overall confusion matrix image, and confusion matrices split by sex (M and F) when clips of that sex are present. After all three seeds complete, the cell aggregates results across seeds. It reports the mean AUROC with a 95% confidence interval based on a t distribution (n = 3, 2 degrees of freedom), along with the mean and standard deviation of the threshold-based metrics and the fairness results, including ΔFNR and the underlying FNR values. A combined `summary_test.json` file is written, and the same summary is appended to `history_index.jsonl` in the `multilingual_test_runs/` folder. Finally, the Colab runtime is unassigned to stop the GPU instance.
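
The confidence interval across three seeds can be sketched as below. The AUROC values are placeholders rather than results from this study, and `mean_ci95_three_seeds` is a hypothetical helper. The critical value is the two-sided 95% value of Student's t with 2 degrees of freedom, so the sketch applies only when exactly three values are given.

```python
import math
import statistics

T_CRIT_DF2_95 = 4.302652729911275  # two-sided 95% critical value of Student's t, df = 2

def mean_ci95_three_seeds(values):
    """Return (mean, half_width) of a t-based 95% CI for exactly three values."""
    if len(values) != 3:
        raise ValueError("T_CRIT_DF2_95 is valid only for n = 3")
    mean = statistics.fmean(values)
    sd = statistics.stdev(values)  # sample standard deviation (ddof = 1)
    half_width = T_CRIT_DF2_95 * sd / math.sqrt(len(values))
    return mean, half_width

aucs = [0.80, 0.82, 0.84]  # placeholder AUROCs, one per seed
mean_auc, half_width = mean_ci95_three_seeds(aucs)
print(f"mean={mean_auc:.4f}  95% CI=[{mean_auc - half_width:.4f}, {mean_auc + half_width:.4f}]")
```

With only three seeds the critical value is large, so even a small spread between seeds produces a wide interval.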

In [None]:
# Multilingual Test (D7 → D7)
# Runs evaluation only (no training). Reads the D7 test split from manifest_all.csv, loads the most recent
# train+val experiment that has finished heads for all 3 seeds, then scores the D7 test set using a single
# shared threshold (the mean VAL-opt threshold from the train+val summary). Writes per-seed metrics and plots,
# plus an aggregated summary across seeds, then unassigns the runtime.

# =========================
# TEST ONLY (CRASH-PROOF, WITH PROGRESS + STORED METRICS) — D7 (Multilingual) to D7
# - Uses ONLY: <DX_OUT_ROOT>/manifests/manifest_all.csv (test split only)
# - Loads finished heads from MOST RECENT trainval experiment under:
#     <DX_OUT_ROOT>/trainval_runs/exp_*/run_<DATASET>_seed{seed}/best_heads.pt
# - Reads VAL-opt thresholds ONLY from that trainval run's summary_trainval.json:
#     summary_trainval.json -> val_optimal_threshold.by_seed
#     summary_trainval.json -> val_optimal_threshold.mean_sd.mean   (canonical aggregate)
#     summary_trainval.json -> val_optimal_threshold.mean_sd.sd
# - Evaluates 3 seeds separately (1337, 2024, 7777)
# - Reports:
#     * mean Test AUROC ± 95% CI (t, n=3)
#     * Threshold metrics on TEST @ MEAN VAL-opt threshold (shared across seeds) as mean ± SD
#     * FAIRNESS (H3) on TEST @ shared threshold as mean ± SD:
#         ΔFNR = FNR(F) - FNR(M), where FNR(sex)=FN/(FN+TP) on PD-only true labels
#     * Confusion charts split by sex (M/F) on TEST @ shared threshold
# - Writes ONLY under: <DX_OUT_ROOT>/multilingual_test_runs/run_<DATASET>_seedXXXX/
#   plus summary_test.json + history_index.jsonl in multilingual_test_runs/
# - Unassigns runtime at end (L4)
# - This cell does NOT re-fit any model parameters.
# - Threshold(s) are taken verbatim from summary_trainval.json (produced earlier).
# =========================

import os, json, math, random, time, warnings
from pathlib import Path

import numpy as np
import pandas as pd
import soundfile as sf

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

from transformers import Wav2Vec2Model

from sklearn.metrics import roc_auc_score, confusion_matrix, roc_curve, matthews_corrcoef
import matplotlib.pyplot as plt
from tqdm.auto import tqdm

# -------------------------
# Guard against local name conflicts
# Inputs: local filesystem under /content/
# Output: stops early with a clear error if a local file would break imports
# -------------------------
if os.path.exists("/content/torch.py") or os.path.exists("/content/torch/__init__.py"):
    raise RuntimeError("Found local /content/torch.py or /content/torch/ that shadows PyTorch. Rename/remove it and restart runtime.")
if os.path.exists("/content/transformers.py") or os.path.exists("/content/transformers/__init__.py"):
    raise RuntimeError("Found local /content/transformers.py or /content/transformers/ that shadows Hugging Face Transformers. Rename/remove it and restart runtime.")

# -------------------------
# Google Drive mount (safe if already mounted)
# Inputs: Colab runtime state
# Output: ensures /content/drive/MyDrive is available when running in Colab
# -------------------------
try:
    from google.colab import drive  # type: ignore
    if not os.path.isdir("/content/drive/MyDrive"):
        drive.mount("/content/drive")
except Exception:
    pass

# -------------------------
# Resolve run root and manifest path
# Inputs: DX_OUT_ROOT (if already set) or fallback D7 path
# Output: DX_OUT_ROOT and MANIFEST_ALL used by the rest of the cell
# -------------------------
# Rule:
# - Prefer runtime variable DX_OUT_ROOT if already defined (allows reuse across cells).
# - Otherwise use the D7 fallback path below.
D7_OUT_ROOT_FALLBACK = "/content/drive/MyDrive/AI_PD_Project/Datasets/D7-Multilingual (D1_D4_D5v2_D6)/preprocessed_v1"
DX_OUT_ROOT = str(globals().get("DX_OUT_ROOT", D7_OUT_ROOT_FALLBACK))
MANIFEST_ALL = f"{DX_OUT_ROOT}/manifests/manifest_all.csv"
globals()["DX_OUT_ROOT"] = DX_OUT_ROOT  # keep consistent across cells

# -------------------------
# Evaluation settings
# Inputs: fixed constants (seeds, model checkpoint, batch sizes)
# Output: consistent evaluation behavior across seeds and runs
# -------------------------
SEEDS          = [1337, 2024, 7777]
BACKBONE_CKPT  = "facebook/wav2vec2-base"
SR_EXPECTED    = 16000
TINY_THRESH    = 1e-4

EFFECTIVE_BS   = 64
PER_DEVICE_BS  = 16
GRAD_ACCUM     = max(1, EFFECTIVE_BS // PER_DEVICE_BS)

DROPOUT_P      = 0.2

NUM_WORKERS    = 0
PIN_MEMORY     = False

USE_AMP        = True  # evaluation autocast only
VOWEL_TASK_VALUE = "vowl"

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
try:
    torch.set_float32_matmul_precision("high")
except Exception:
    pass

warnings.filterwarnings("ignore", message="None of the inputs have requires_grad=True. Gradients will be None")
warnings.filterwarnings("ignore", message="Passing `gradient_checkpointing` to a config initialization is deprecated")

print("DX_OUT_ROOT:", DX_OUT_ROOT)
print("MANIFEST_ALL:", MANIFEST_ALL)
print("DEVICE:", DEVICE)
if DEVICE.type == "cuda":
    print("GPU:", torch.cuda.get_device_name(0))
print("PER_DEVICE_BS:", PER_DEVICE_BS, "| EFFECTIVE_BS:", PER_DEVICE_BS * GRAD_ACCUM)
print("NUM_WORKERS:", NUM_WORKERS, "| PIN_MEMORY:", PIN_MEMORY)
print("USE_AMP:", bool(USE_AMP and DEVICE.type == "cuda"))

# -------------------------
# Load manifest and build the TEST table
# Inputs: manifest_all.csv
# Output: test_df (rows where split == "test"), plus dataset_id label for folder naming
# -------------------------
if not os.path.exists(MANIFEST_ALL):
    raise FileNotFoundError(f"Missing manifest_all.csv: {MANIFEST_ALL}")

m_all = pd.read_csv(MANIFEST_ALL)

req_cols_core = {"split", "clip_path", "label_num", "task"}
missing_core = [c for c in sorted(req_cols_core) if c not in m_all.columns]
if missing_core:
    raise ValueError(f"Manifest missing required columns: {missing_core}. Found: {list(m_all.columns)}")

# Infer dataset_id (naming only)
if "dataset" in m_all.columns and m_all["dataset"].notna().any():
    dataset_id = str(m_all["dataset"].astype(str).value_counts(dropna=True).idxmax())
    m_all = m_all[m_all["dataset"].astype(str) == dataset_id].copy()
else:
    dataset_id = "DX"

print("\nDataset inferred:", dataset_id)

keep_cols = ["clip_path", "label_num", "task", "speaker_id", "sex", "age", "duration_sec", "split"]
for c in keep_cols:
    if c not in m_all.columns:
        m_all[c] = np.nan
m_all = m_all[keep_cols].copy()

test_df = m_all[m_all["split"].isin(["test"])].reset_index(drop=True)

print("TEST rows:", len(test_df))
print("TEST label counts:", test_df["label_num"].value_counts(dropna=False).to_dict())
print("TEST sex counts (raw):", test_df["sex"].value_counts(dropna=False).to_dict() if "sex" in test_df.columns else {})

if len(test_df) == 0:
    raise RuntimeError("After filtering to split=='test', manifest has 0 rows.")

# -------------------------
# Fail fast if audio files are missing
# Inputs: test_df.clip_path
# Output: stops early with a few missing examples instead of failing mid-run
# -------------------------
def _fail_fast_missing_paths(df: pd.DataFrame, name: str):
    missing_paths = []
    for p in tqdm(df["clip_path"].astype(str).tolist(), desc=f"Check {name} clip_path exists", dynamic_ncols=True):
        if not os.path.exists(p):
            missing_paths.append(p)
            if len(missing_paths) >= 10:
                break
    if missing_paths:
        raise FileNotFoundError(f"{name}: missing clip_path(s). Examples: {missing_paths}")

_fail_fast_missing_paths(test_df, "TEST")

# -------------------------
# Add task group and normalize sex labels
# Inputs: task, sex columns in the manifest
# Output: test_df gains task_group and sex_norm for routing heads and fairness metrics
# -------------------------
def _task_group(task_val) -> str:
    return "vowel" if str(task_val) == VOWEL_TASK_VALUE else "other"

test_df["task_group"] = test_df["task"].apply(_task_group)

def normalize_sex(val) -> str:
    """
    Returns 'M', 'F', or 'UNK'.

    Handles:
    - common strings (male/female, m/f, etc.)
    - numeric encodings:
        0 -> F
        1 -> M
    """
    if pd.isna(val):
        return "UNK"

    try:
        fv = float(val)
        if np.isfinite(fv) and abs(fv - round(fv)) < 1e-9:
            iv = int(round(fv))
            if iv == 0:
                return "F"
            if iv == 1:
                return "M"
    except Exception:
        pass

    s = str(val).strip().lower()
    if s in {"m", "male", "man", "masc", "masculine"}:
        return "M"
    if s in {"f", "female", "woman", "fem", "feminine"}:
        return "F"
    return "UNK"

test_df["sex_norm"] = test_df["sex"].apply(normalize_sex) if "sex" in test_df.columns else "UNK"
print("TEST sex counts (normalized):", test_df["sex_norm"].value_counts(dropna=False).to_dict())
if (test_df["sex_norm"] == "UNK").any():
    print("NOTE: Some 'sex' values could not be normalized to M/F and were counted as 'UNK' for fairness and sex charts.")

# -------------------------
# Dataset and batch collator
# Inputs: test_df and audio files on disk
# Output: DataLoader batches with padding + attention masks, plus labels and metadata
# -------------------------
class AudioManifestDataset(Dataset):
    """
    Loads a clip and builds an attention mask in sample space.

    Attention mask rule:
    - vowel clips: mask trailing near-zeros to reduce impact of padded silence.
    - other clips: full attention (all ones).
    """
    def __init__(self, df: pd.DataFrame):
        self.df = df.reset_index(drop=True)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx: int):
        row = self.df.iloc[idx]
        clip_path = str(row["clip_path"])
        label = int(row["label_num"])
        task_group = str(row["task_group"])
        sex_norm = str(row["sex_norm"])

        if not os.path.exists(clip_path):
            raise FileNotFoundError(f"Missing clip_path: {clip_path}")

        y, sr = sf.read(clip_path, always_2d=False)
        if isinstance(y, np.ndarray) and y.ndim > 1:
            y = np.mean(y, axis=1)
        y = np.asarray(y, dtype=np.float32)

        if int(sr) != SR_EXPECTED:
            raise ValueError(f"Sample rate mismatch (expected {SR_EXPECTED}): {clip_path} has sr={sr}")

        attn = np.ones((len(y),), dtype=np.int64)

        if task_group == "vowel":
            # Zero the mask after the last sample whose magnitude exceeds TINY_THRESH.
            # If no sample exceeds it, keep full attention so the clip is not fully masked.
            loud = np.flatnonzero(np.abs(y) > TINY_THRESH)
            if loud.size > 0:
                attn[int(loud[-1]) + 1:] = 0

        return {
            "input_values": torch.from_numpy(y),                 # float32 [T]
            "attention_mask": torch.from_numpy(attn),            # int64   [T]
            "labels": torch.tensor(label, dtype=torch.long),     # int64   []
            "task_group": task_group,                            # str
            "sex_norm": sex_norm,                                # str
        }

def collate_fn(batch):
    # Pads to the longest clip in the batch and pads attention_mask with zeros
    max_len = int(max(b["input_values"].numel() for b in batch))
    input_vals, attn_masks, labels, task_groups, sex_norms = [], [], [], [], []
    for b in batch:
        x = b["input_values"]
        a = b["attention_mask"]
        pad = max_len - x.numel()
        if pad > 0:
            x = torch.cat([x, torch.zeros(pad, dtype=x.dtype)], dim=0)
            a = torch.cat([a, torch.zeros(pad, dtype=a.dtype)], dim=0)
        input_vals.append(x)
        attn_masks.append(a)
        labels.append(b["labels"])
        task_groups.append(b["task_group"])
        sex_norms.append(b["sex_norm"])
    return {
        "input_values": torch.stack(input_vals, dim=0),
        "attention_mask": torch.stack(attn_masks, dim=0),
        "labels": torch.stack(labels, dim=0),
        "task_group": task_groups,
        "sex_norm": sex_norms,
    }

# -------------------------
# Two-head classifier (inference only)
# Inputs: Wav2Vec2 backbone checkpoint + loaded head weights
# Output: logits per clip routed through vowel vs other head
# -------------------------
class Wav2Vec2TwoHeadClassifier(nn.Module):
    def __init__(self, ckpt: str, dropout_p: float):
        super().__init__()
        self.backbone = Wav2Vec2Model.from_pretrained(ckpt, use_safetensors=True)
        for p in self.backbone.parameters():
            p.requires_grad = False

        H = int(self.backbone.config.hidden_size)
        self.pre_vowel = nn.Sequential(nn.LayerNorm(H), nn.Dropout(float(dropout_p)))
        self.pre_other = nn.Sequential(nn.LayerNorm(H), nn.Dropout(float(dropout_p)))
        self.head_vowel = nn.Linear(H, 2)
        self.head_other = nn.Linear(H, 2)

    def masked_mean_pool(self, last_hidden, attn_mask_samples):
        # Converts sample-level mask to feature-level mask, then averages only valid frames
        feat_mask = self.backbone._get_feature_vector_attention_mask(last_hidden.shape[1], attn_mask_samples)
        feat_mask = feat_mask.to(last_hidden.device).unsqueeze(-1).type_as(last_hidden)
        masked = last_hidden * feat_mask
        denom = feat_mask.sum(dim=1).clamp(min=1.0)
        return masked.sum(dim=1) / denom

    def _heads_fp32(self, x_any: torch.Tensor, head: nn.Module) -> torch.Tensor:
        # Runs heads in float32 for stable probabilities even when AMP is enabled
        x = x_any.float()
        if x.is_cuda:
            with torch.amp.autocast(device_type="cuda", enabled=False):
                return head(x)
        return head(x)

    def forward_logits(self, input_values, attention_mask, task_group):
        with torch.no_grad():
            out = self.backbone(input_values=input_values, attention_mask=attention_mask)
            last_hidden = out.last_hidden_state

        pooled = self.masked_mean_pool(last_hidden, attention_mask)  # [B,H]

        z_v = self.pre_vowel(pooled.float())
        z_o = self.pre_other(pooled.float())

        logits_v = self._heads_fp32(z_v, self.head_vowel)
        logits_o = self._heads_fp32(z_o, self.head_other)

        # Route each item to its matching head
        is_vowel = torch.tensor([tg == "vowel" for tg in task_group], device=pooled.device, dtype=torch.bool)
        logits = logits_o.clone()
        if is_vowel.any():
            logits[is_vowel] = logits_v[is_vowel]
        return logits

# -------------------------
# Metrics, plots, and summary helpers
# Inputs: y_true, y_prob, threshold
# Output: AUROC, threshold metrics, ROC/CM images, mean±SD utilities
# -------------------------
def compute_auc(y_true, y_prob):
    y_true = np.asarray(y_true, dtype=np.int64)
    y_prob = np.asarray(y_prob, dtype=np.float64)
    if len(np.unique(y_true)) < 2:
        return float("nan")
    return float(roc_auc_score(y_true, y_prob))

def compute_threshold_metrics(y_true, y_prob, thr=0.5):
    y_true = np.asarray(y_true, dtype=np.int64)
    y_prob = np.asarray(y_prob, dtype=np.float64)
    y_pred = (y_prob >= float(thr)).astype(np.int64)

    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
    tn, fp, fn, tp = int(cm[0, 0]), int(cm[0, 1]), int(cm[1, 0]), int(cm[1, 1])

    eps = 1e-12
    acc = (tp + tn) / max(1, (tp + tn + fp + fn))
    prec = tp / (tp + fp + eps)
    rec = tp / (tp + fn + eps)     # sensitivity
    f1 = 2 * prec * rec / (prec + rec + eps)
    spec = tn / (tn + fp + eps)

    try:
        mcc = float(matthews_corrcoef(y_true, y_pred)) if len(np.unique(y_pred)) > 1 else float("nan")
    except Exception:
        mcc = float("nan")

    pval = float("nan")
    try:
        from scipy.stats import fisher_exact  # type: ignore
        _, pval = fisher_exact([[tn, fp], [fn, tp]], alternative="two-sided")
        pval = float(pval)
    except Exception:
        pval = float("nan")

    return {
        "threshold": float(thr),
        "tn": tn, "fp": fp, "fn": fn, "tp": tp,
        "accuracy": float(acc),
        "precision": float(prec),
        "recall": float(rec),
        "f1": float(f1),
        "sensitivity": float(rec),
        "specificity": float(spec),
        "mcc": float(mcc),
        "p_value_fisher": float(pval),
    }

def save_roc_curve_png(y_true, y_prob, out_png, title_suffix="Test"):
    fpr, tpr, _ = roc_curve(np.asarray(y_true, dtype=np.int64), np.asarray(y_prob, dtype=np.float64))
    plt.figure()
    plt.plot(fpr, tpr)
    plt.plot([0, 1], [0, 1], linestyle="--")
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    plt.title(f"ROC Curve ({title_suffix})")
    plt.tight_layout()
    plt.savefig(out_png, dpi=150)
    plt.close()

def save_confusion_png(y_true, y_prob, out_png, thr=0.5, title_suffix="Test"):
    y_true = np.asarray(y_true, dtype=np.int64)
    y_prob = np.asarray(y_prob, dtype=np.float64)
    y_pred = (y_prob >= float(thr)).astype(int)
    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])

    plt.figure()
    plt.imshow(cm)
    plt.xticks([0, 1], ["Healthy(0)", "PD(1)"])
    plt.yticks([0, 1], ["Healthy(0)", "PD(1)"])
    plt.xlabel("Predicted")
    plt.ylabel("True")
    plt.title(f"Confusion Matrix ({title_suffix}, thr={thr:.4f})")
    for (i, j), v in np.ndenumerate(cm):
        plt.text(j, i, str(v), ha="center", va="center")
    plt.tight_layout()
    plt.savefig(out_png, dpi=150)
    plt.close()

def mean_sd(vals):
    vals = np.asarray(vals, dtype=np.float64)
    mu = float(np.nanmean(vals)) if np.any(~np.isnan(vals)) else float("nan")
    sd = float(np.nanstd(vals, ddof=1)) if np.sum(~np.isnan(vals)) > 1 else 0.0
    return mu, sd

# -------------------------
# Fairness metric: FNR by sex and ΔFNR
# Inputs: y_true, y_prob, sex_norm, threshold
# Output: per-sex PD-only FNR and the difference F minus M
# -------------------------
def compute_fnr_by_sex_and_delta(y_true, y_prob, sex_norm, thr):
    """
    FNR is computed on PD-only true labels (y_true==1):
      FNR(sex) = FN/(FN+TP)
    Returns:
      fnr_by_sex: per-sex counts and FNR
      delta_f_minus_m: FNR(F) - FNR(M)
    """
    y_true = np.asarray(y_true, dtype=np.int64)
    y_prob = np.asarray(y_prob, dtype=np.float64)
    sex_norm = np.asarray(list(sex_norm), dtype=object)
    y_pred = (y_prob >= float(thr)).astype(np.int64)

    out = {}
    for g in ["M", "F", "UNK"]:
        mask_g = (sex_norm == g)
        if int(mask_g.sum()) == 0:
            continue
        pos_mask = mask_g & (y_true == 1)
        n_pos = int(pos_mask.sum())
        if n_pos == 0:
            out[g] = {"n_total": int(mask_g.sum()), "n_pos": 0, "tp": 0, "fn": 0, "fnr": float("nan")}
            continue
        tp = int(((y_pred == 1) & pos_mask).sum())
        fn = int(((y_pred == 0) & pos_mask).sum())
        fnr = float(fn / max(1, (fn + tp)))
        out[g] = {"n_total": int(mask_g.sum()), "n_pos": int(n_pos), "tp": int(tp), "fn": int(fn), "fnr": float(fnr)}

    fnr_m = out.get("M", {}).get("fnr", float("nan"))
    fnr_f = out.get("F", {}).get("fnr", float("nan"))
    if (not np.isnan(fnr_m)) and (not np.isnan(fnr_f)):
        delta = float(fnr_f - fnr_m)
    else:
        delta = float("nan")

    return out, delta

# -------------------------
# Seed control
# Inputs: seed integer
# Output: repeatable dataloader order and model behavior for that seed
# -------------------------
def set_all_seeds(seed: int):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# -------------------------
# Build the TEST DataLoader + quick warm-up
# Inputs: test_df and batching settings
# Output: test_loader ready for inference; warm-up catches shape/path issues early
# -------------------------
test_ds = AudioManifestDataset(test_df)
test_loader = DataLoader(
    test_ds,
    batch_size=PER_DEVICE_BS,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=PIN_MEMORY,
    collate_fn=collate_fn,
)

print("\nWarm-up: loading up to 3 TEST batches...")
t0 = time.time()
nb = len(test_loader)
wb = min(3, nb)
if wb == 0:
    raise RuntimeError("TEST DataLoader has 0 batches. Check test_df length and PER_DEVICE_BS.")
it = iter(test_loader)
for i in range(wb):
    _ = next(it)
    print(f"  loaded warmup TEST batch {i+1}/{wb}")
print(f"Warm-up done in {time.time()-t0:.2f}s")

# -------------------------
# Choose the most recent finished train+val experiment
# Inputs: trainval_runs/exp_* folders under DX_OUT_ROOT
# Output: chosen_exp and chosen_summary with thresholds and head checkpoints for all seeds
# -------------------------
TRAINVAL_ROOT = Path(DX_OUT_ROOT) / "trainval_runs"
if not TRAINVAL_ROOT.exists():
    raise FileNotFoundError(f"Missing trainval_runs folder: {str(TRAINVAL_ROOT)}")

exp_dirs = sorted([p for p in TRAINVAL_ROOT.glob("exp_*") if p.is_dir()], key=lambda p: p.stat().st_mtime, reverse=True)
if not exp_dirs:
    raise FileNotFoundError(f"No exp_* folders found under: {str(TRAINVAL_ROOT)}")

def _has_all_seeds(exp_path: Path, dataset_id: str, seeds: list):
    # Checks that each seed has a saved best_heads.pt
    for s in seeds:
        p = exp_path / f"run_{dataset_id}_seed{s}" / "best_heads.pt"
        if not p.exists():
            return False
    return True

def _load_trainval_summary(exp_path: Path) -> dict:
    # Loads summary_trainval.json if present and readable
    sp = exp_path / "summary_trainval.json"
    if not sp.exists():
        return {}
    try:
        with open(sp, "r", encoding="utf-8") as f:
            return json.load(f)
    except Exception:
        return {}

def _summary_has_threshold_keys(summary: dict) -> bool:
    # Ensures the threshold fields needed for evaluation are present
    try:
        v = summary["val_optimal_threshold"]
        _ = v["by_seed"]
        _ = v["mean_sd"]["mean"]
        _ = v["mean_sd"]["sd"]
        return True
    except Exception:
        return False

chosen_exp = None
chosen_summary = None

for ed in exp_dirs:
    if not _has_all_seeds(ed, dataset_id, SEEDS):
        continue
    summ = _load_trainval_summary(ed)
    if not _summary_has_threshold_keys(summ):
        continue
    chosen_exp = ed
    chosen_summary = summ
    break

if chosen_exp is None or chosen_summary is None:
    sample = exp_dirs[0]
    raise FileNotFoundError(
        "Could not find a recent trainval experiment with:\n"
        "  - all 3 best_heads.pt files\n"
        "  - a readable summary_trainval.json\n"
        "  - val_optimal_threshold.by_seed + mean_sd.mean + mean_sd.sd\n"
        f"Expected under: {str(TRAINVAL_ROOT)}/exp_*/\n"
        f"Most recent exp checked: {str(sample)}"
    )

print("\nUsing Train+Val experiment folder:")
print(" ", str(chosen_exp))

# -------------------------
# Load VAL-opt thresholds and choose the shared TEST threshold
# Inputs: chosen_summary["val_optimal_threshold"]
# Output: TEST_THR_MEAN_VAL_OPT used for all seeds on test
# -------------------------
val_opt = chosen_summary["val_optimal_threshold"]
thr_by_seed = {str(k): float(v) for k, v in (val_opt.get("by_seed") or {}).items()}
thr_mean = float(val_opt.get("mean_sd", {}).get("mean", float("nan")))
thr_sd   = float(val_opt.get("mean_sd", {}).get("sd", float("nan")))

print("\nVAL-opt thresholds loaded from trainval summary_trainval.json:")
print("  method:", str(val_opt.get("method", "unknown")))
print("  by_seed:", {k: (f"{v:.6f}" if not np.isnan(v) else "nan") for k, v in thr_by_seed.items()})
print("  mean ± SD:", (f"{thr_mean:.6f} ± {thr_sd:.6f}" if not np.isnan(thr_mean) else f"nan ± {thr_sd:.6f}"))

TEST_THR_SOURCE = "trainval summary_trainval.json -> val_optimal_threshold.mean_sd.mean (canonical aggregate)"
TEST_THR_MEAN_VAL_OPT = float(thr_mean)
TEST_THR_NOTE = None

if np.isnan(TEST_THR_MEAN_VAL_OPT):
    # Fallback only if the canonical mean is missing
    vals = []
    for s in SEEDS:
        v = thr_by_seed.get(str(s), float("nan"))
        if not np.isnan(v):
            vals.append(float(v))
    if len(vals) > 0:
        TEST_THR_MEAN_VAL_OPT = float(np.mean(vals))
        TEST_THR_NOTE = "val_optimal_threshold.mean_sd.mean was NaN. Fallback used: mean(thr_by_seed over available seeds)."
    else:
        TEST_THR_MEAN_VAL_OPT = 0.5
        TEST_THR_NOTE = "val_optimal_threshold.mean_sd.mean was NaN and thr_by_seed had no usable values. Fallback used: thr=0.5."

print("\nTEST threshold policy:")
print("  Using MEAN VAL-opt threshold for ALL seeds/tests:", f"{TEST_THR_MEAN_VAL_OPT:.6f}")
print("  Source:", TEST_THR_SOURCE)
if TEST_THR_NOTE is not None:
    print("  NOTE:", TEST_THR_NOTE)

# -------------------------
# Output folder for test runs
# Inputs: DX_OUT_ROOT
# Output: multilingual_test_runs/ with one folder per seed plus summary files
# -------------------------
TEST_ROOT = Path(DX_OUT_ROOT) / "multilingual_test_runs"
TEST_ROOT.mkdir(parents=True, exist_ok=True)

# -------------------------
# Load saved heads into the model
# Inputs: best_heads.pt from the chosen train+val run
# Output: model with the trained head weights ready for inference
# -------------------------
def load_heads_into_model(model: Wav2Vec2TwoHeadClassifier, best_heads_path: Path):
    if not best_heads_path.exists():
        raise FileNotFoundError(f"Missing best_heads.pt: {str(best_heads_path)}")
    state = torch.load(str(best_heads_path), map_location="cpu")
    model.pre_vowel.load_state_dict(state["pre_vowel"], strict=True)
    model.pre_other.load_state_dict(state["pre_other"], strict=True)
    model.head_vowel.load_state_dict(state["head_vowel"], strict=True)
    model.head_other.load_state_dict(state["head_other"], strict=True)
    return model

# -------------------------
# Inference loop
# Inputs: DataLoader + model
# Output: arrays of true labels, PD probabilities, and sex labels (for fairness splits)
# -------------------------
def infer_probs(loader, model, desc):
    use_amp = bool(USE_AMP and DEVICE.type == "cuda")
    all_probs, all_true, all_sex = [], [], []

    pbar = tqdm(loader, desc=desc, dynamic_ncols=True)
    with torch.inference_mode():
        for batch in pbar:
            input_values = batch["input_values"].to(DEVICE, non_blocking=False)
            attention_mask = batch["attention_mask"].to(DEVICE, non_blocking=False)
            labels = batch["labels"].to(DEVICE, non_blocking=False)
            task_group = batch["task_group"]
            sex_norm = batch["sex_norm"]

            # Use the active device type so the autocast context matches DEVICE on CPU-only runs.
            with torch.amp.autocast(device_type=DEVICE.type, enabled=use_amp):
                logits = model.forward_logits(input_values, attention_mask, task_group)

            probs = torch.softmax(logits.float(), dim=-1)[:, 1].detach().cpu().numpy().astype(np.float64)
            all_probs.extend(probs.tolist())
            all_true.extend(labels.detach().cpu().numpy().astype(np.int64).tolist())
            all_sex.extend(list(sex_norm))

    return (
        np.asarray(all_true, dtype=np.int64),
        np.asarray(all_probs, dtype=np.float64),
        np.asarray(all_sex, dtype=object),
    )

# -------------------------
# Per-seed evaluation and artifact writing
# Inputs: a seed, chosen train+val experiment, shared threshold
# Outputs: run_<dataset>_seedXXXX/metrics.json + ROC/CM plots (overall + by sex)
# -------------------------
def run_test_once(seed: int):
    set_all_seeds(seed)

    run_dir = TEST_ROOT / f"run_{dataset_id}_seed{seed}"
    run_dir.mkdir(parents=True, exist_ok=True)

    best_heads_path = chosen_exp / f"run_{dataset_id}_seed{seed}" / "best_heads.pt"
    thr_used = float(TEST_THR_MEAN_VAL_OPT)

    print(f"\n[seed={seed}] Loading model + heads from:")
    print(" ", str(best_heads_path))
    print(f"[seed={seed}] Using TEST threshold (shared mean VAL-opt): {thr_used:.6f}")
    if TEST_THR_NOTE is not None:
        print(f"[seed={seed}] NOTE: {TEST_THR_NOTE}")

    model = Wav2Vec2TwoHeadClassifier(BACKBONE_CKPT, dropout_p=DROPOUT_P).to(DEVICE)
    model = load_heads_into_model(model, best_heads_path)
    model.eval()

    yt_true, yt_prob, yt_sex = infer_probs(test_loader, model, desc=f"[seed={seed}] Test")
    test_auc = compute_auc(yt_true, yt_prob)

    thr_metrics_test = compute_threshold_metrics(yt_true, yt_prob, thr=thr_used)
    fnr_by_sex, delta_f_minus_m = compute_fnr_by_sex_and_delta(yt_true, yt_prob, yt_sex, thr=thr_used)

    # Plots (overall)
    roc_png = run_dir / "roc_curve.png"
    cm_png  = run_dir / "confusion_matrix.png"
    save_roc_curve_png(yt_true, yt_prob, str(roc_png), title_suffix=f"Test (seed={seed})")
    save_confusion_png(yt_true, yt_prob, str(cm_png), thr=thr_used, title_suffix=f"Test (seed={seed})")

    # Plots (by sex)
    cm_m_png = None
    cm_f_png = None
    mask_m = (yt_sex == "M")
    mask_f = (yt_sex == "F")

    if int(mask_m.sum()) > 0:
        cm_m_png = run_dir / "confusion_matrix_M.png"
        save_confusion_png(yt_true[mask_m], yt_prob[mask_m], str(cm_m_png), thr=thr_used, title_suffix=f"Test SEX=M (seed={seed})")

    if int(mask_f.sum()) > 0:
        cm_f_png = run_dir / "confusion_matrix_F.png"
        save_confusion_png(yt_true[mask_f], yt_prob[mask_f], str(cm_f_png), thr=thr_used, title_suffix=f"Test SEX=F (seed={seed})")

    metrics = {
        "dataset": dataset_id,
        "seed": int(seed),

        "dx_out_root": DX_OUT_ROOT,
        "manifest_all": MANIFEST_ALL,

        "trainval_experiment_used": str(chosen_exp),
        "best_heads_path": str(best_heads_path),

        "n_test": int(len(test_df)),
        "label_counts_test": test_df["label_num"].value_counts(dropna=False).to_dict(),
        "sex_counts_test_norm": test_df["sex_norm"].value_counts(dropna=False).to_dict(),

        "test_auroc": float(test_auc),

        "threshold_source": TEST_THR_SOURCE,
        "val_optimal_threshold_canonical": {
            "method": str(val_opt.get("method", "unknown")),
            "by_seed": {k: float(v) for k, v in thr_by_seed.items()},
            "mean_sd": {"mean": float(thr_mean), "sd": float(thr_sd)},
        },
        "test_threshold_used": float(thr_used),
        "test_threshold_note": (TEST_THR_NOTE if TEST_THR_NOTE is not None else ""),

        "threshold_metrics_test_at_val_opt": thr_metrics_test,

        "fairness_test_at_val_opt": {
            "definition": "H3: ΔFNR = FNR(F) - FNR(M), where FNR(sex)=FN/(FN+TP) computed on PD-only true labels at test_threshold_used.",
            "threshold_used": float(thr_used),
            "fnr_by_sex_norm": fnr_by_sex,
            "delta_fnr_F_minus_M": float(delta_f_minus_m),
            "note": "If n_PD for a sex is 0, its FNR is NaN and ΔFNR is NaN.",
            "sex_normalization_note": "Normalizer accepts numeric 0->F, 1->M and common strings; otherwise UNK.",
        },

        "artifacts": {
            "roc_curve_png": str(roc_png),
            "confusion_matrix_png": str(cm_png),
            "confusion_matrix_M_png": str(cm_m_png) if cm_m_png is not None else None,
            "confusion_matrix_F_png": str(cm_f_png) if cm_f_png is not None else None,
        },

        "backbone_ckpt": BACKBONE_CKPT,
        "dropout_p": float(DROPOUT_P),
        "use_amp": bool(USE_AMP and DEVICE.type == "cuda"),
        "timestamp": time.strftime("%Y%m%d_%H%M%S"),
    }

    metrics_path = run_dir / "metrics.json"
    with open(metrics_path, "w", encoding="utf-8") as f:
        json.dump(metrics, f, indent=2)

    print(f"[seed={seed}] DONE | test_AUROC={test_auc:.6f} | thr_used={thr_used:.6f}")
    print(f"[seed={seed}] WROTE:")
    print(" ", str(metrics_path))
    print(" ", str(roc_png))
    print(" ", str(cm_png))
    if cm_m_png is not None:
        print(" ", str(cm_m_png))
    if cm_f_png is not None:
        print(" ", str(cm_f_png))

    return {
        "seed": int(seed),
        "thr_used": float(thr_used),
        "test_auc": float(test_auc),
        "thr_metrics": thr_metrics_test,
        "fnr_by_sex": fnr_by_sex,
        "delta_f_minus_m": float(delta_f_minus_m),
        "run_dir": str(run_dir),
    }

# -------------------------
# Run all seeds and aggregate results
# Inputs: SEEDS list
# Outputs: per-seed artifacts + summary_test.json + history_index.jsonl
# -------------------------
results = []
for seed in SEEDS:
    results.append(run_test_once(seed))

# ---- AUROC aggregation (TEST)
aurocs = [r["test_auc"] for r in results]
t_crit = 4.302652729911275  # df=2, 95% CI
n = len(aurocs)
mean_auc = float(np.mean(aurocs))
std_auc = float(np.std(aurocs, ddof=1)) if n > 1 else 0.0
half_width = float(t_crit * (std_auc / math.sqrt(n))) if n > 1 else 0.0
ci95 = [float(mean_auc - half_width), float(mean_auc + half_width)]

# ---- threshold metrics aggregation @ shared threshold
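metric_keys = ["accuracy","precision","recall","f1","sensitivity","specificity","mcc","p_value_fisher"]
# compute_threshold_metrics may store F1 and the Fisher p-value under longer key names
# ("f1_score", "p_value_fisher_two_sided"); accept either so they do not silently become NaN.
metric_aliases = {"f1": "f1_score", "p_value_fisher": "p_value_fisher_two_sided"}

def _metric_value(thr_metrics, k):
    if k in thr_metrics:
        return float(thr_metrics[k])
    return float(thr_metrics.get(metric_aliases.get(k, k), float("nan")))

metric_agg = {}
for k in metric_keys:
    by_seed = {str(r["seed"]): _metric_value(r["thr_metrics"], k) for r in results}
    mu, sd = mean_sd(list(by_seed.values()))
    metric_agg[k] = {"mean": float(mu), "sd": float(sd), "by_seed": by_seed}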

# ---- fairness aggregation @ shared threshold
delta_vals = [float(r["delta_f_minus_m"]) for r in results]
delta_mean, delta_sd = mean_sd(delta_vals)

fnr_m_vals = []
fnr_f_vals = []
n_pd_m_vals = []
n_pd_f_vals = []
for r in results:
    d = r["fnr_by_sex"] or {}
    fnr_m_vals.append(float(d.get("M", {}).get("fnr", float("nan"))))
    fnr_f_vals.append(float(d.get("F", {}).get("fnr", float("nan"))))
    n_pd_m_vals.append(float(d.get("M", {}).get("n_pos", float("nan"))))
    n_pd_f_vals.append(float(d.get("F", {}).get("n_pos", float("nan"))))

fnr_m_mean, fnr_m_sd = mean_sd(fnr_m_vals)
fnr_f_mean, fnr_f_sd = mean_sd(fnr_f_vals)

print("\nTest AUROC by seed:")
for r in results:
    print(f"  seed {r['seed']}: {r['test_auc']:.6f}")
print("\nMean Test AUROC:", f"{mean_auc:.6f}")
print(f"95% CI (t, n=3): [{ci95[0]:.6f}, {ci95[1]:.6f}]")

print("\nTEST threshold used (shared mean VAL-opt):", f"{TEST_THR_MEAN_VAL_OPT:.6f}")
if TEST_THR_NOTE is not None:
    print("NOTE:", TEST_THR_NOTE)

print("\nThreshold metrics on TEST @ shared mean VAL-opt threshold (mean ± SD across seeds):")
for k in ["accuracy","precision","sensitivity","specificity","f1","mcc"]:
    print(f"  {k}: {metric_agg[k]['mean']:.6f} ± {metric_agg[k]['sd']:.6f}")

print("\nFAIRNESS (H3) on TEST @ shared threshold (mean ± SD across seeds):")
print(f"  FNR_M: {fnr_m_mean:.6f} ± {fnr_m_sd:.6f}")
print(f"  FNR_F: {fnr_f_mean:.6f} ± {fnr_f_sd:.6f}")
print(f"  ΔFNR (F-M): {delta_mean:.6f} ± {delta_sd:.6f}")

summary = {
    "dataset": dataset_id,
    "dx_out_root": DX_OUT_ROOT,
    "manifest_all": MANIFEST_ALL,
    "trainval_experiment_used": str(chosen_exp),
    "seeds": SEEDS,

    "n_test": int(len(test_df)),
    "label_counts_test": test_df["label_num"].value_counts(dropna=False).to_dict(),
    "sex_counts_test_norm": test_df["sex_norm"].value_counts(dropna=False).to_dict(),

    "val_optimal_threshold": {
        "source": "trainval summary_trainval.json",
        "method": str(val_opt.get("method", "unknown")),
        "by_seed": {str(k): float(v) for k, v in thr_by_seed.items()},
        "mean_sd": {"mean": float(thr_mean), "sd": float(thr_sd)},
    },

    "test_threshold_used": {
        "policy": "mean_val_optimal_threshold_shared_across_seeds",
        "threshold_used": float(TEST_THR_MEAN_VAL_OPT),
        "note": (TEST_THR_NOTE if TEST_THR_NOTE is not None else ""),
        "source": TEST_THR_SOURCE,
    },

    "test_aurocs_by_seed": {str(r["seed"]): float(r["test_auc"]) for r in results},
    "mean_test_auroc": float(mean_auc),
    "t_crit_df2_95": float(t_crit),
    "half_width_95_ci": float(half_width),
    "ci95_test_auroc": ci95,

    "threshold_metrics_test_at_val_opt_mean_sd": metric_agg,

    "fairness_test_at_val_opt": {
        "definition": "H3: ΔFNR = FNR(F) - FNR(M), where FNR(sex)=FN/(FN+TP) computed on PD-only true labels at the shared mean VAL-opt threshold.",
        "delta_fnr_F_minus_M_by_seed": {str(r["seed"]): float(r["delta_f_minus_m"]) for r in results},
        "delta_fnr_F_minus_M_mean_sd": {"mean": float(delta_mean), "sd": float(delta_sd)},
        "fnr_M_mean_sd": {"mean": float(fnr_m_mean), "sd": float(fnr_m_sd), "by_seed": {str(SEEDS[i]): float(fnr_m_vals[i]) for i in range(len(SEEDS))}},
        "fnr_F_mean_sd": {"mean": float(fnr_f_mean), "sd": float(fnr_f_sd), "by_seed": {str(SEEDS[i]): float(fnr_f_vals[i]) for i in range(len(SEEDS))}},
        "denominators_PD_by_seed": {str(SEEDS[i]): {"n_PD_M": float(n_pd_m_vals[i]), "n_PD_F": float(n_pd_f_vals[i])} for i in range(len(SEEDS))},
        "sex_normalization_note": "Normalizer accepts numeric 0->F, 1->M and common strings; otherwise UNK.",
    },

    "run_dirs": [r["run_dir"] for r in results],
    "timestamp": time.strftime("%Y%m%d_%H%M%S"),
}

summary_path = TEST_ROOT / "summary_test.json"
with open(summary_path, "w", encoding="utf-8") as f:
    json.dump(summary, f, indent=2)

history_path = TEST_ROOT / "history_index.jsonl"
with open(history_path, "a", encoding="utf-8") as f:
    f.write(json.dumps(summary) + "\n")

print("\nWROTE summary:", str(summary_path))
print("APPENDED history index:", str(history_path))
print("Open this folder to access artifacts:", str(TEST_ROOT))

# -------------------------
# Stop the runtime to release the GPU
# Inputs: Colab runtime module (if available)
# Output: ends the session cleanly when running on paid GPU instances
# -------------------------
print("\nAll done. Unassigning the runtime to stop the L4 instance...")
try:
    from google.colab import runtime  # type: ignore
    print("Calling runtime.unassign() now.")
    runtime.unassign()
except Exception as e:
    print("Could not unassign runtime automatically. You can stop the runtime manually in Colab.")
    print("Reason:", repr(e))

The following cell runs a test-only evaluation of the fixed D7 Base model on the D2 test split and saves traceable outputs: predictions, plots, and summary metrics. It reads `manifest_all.csv` from D2, checks that the required columns (`split`, `clip_path`, `label_num`, `task`, `sex`, `age`) are present, keeps only rows where `split == "test"`, and verifies that every audio file path exists. It also creates two grouping fields used later in reporting: `task_group`, set to `vowel` when the task is `vowl` and `other` otherwise, and `sex_norm`, which maps the raw sex value to M, F, or UNK. The normalizer treats numeric 0 as F and 1 as M, accepts common strings such as `male` and `female`, and returns UNK for anything else, including missing values.
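
As a small illustration of the two grouping fields, the sketch below applies the same rules to a few invented manifest rows. It is a simplified stand-in rather than the cell's own code: `toy_task_group` and `toy_normalize_sex` are hypothetical names, the rows (including the `read` task label) are made up, and the normalizer covers only a subset of the strings the real one accepts.

```python
# Invented manifest rows for illustration only (not taken from D2).
toy_manifest = [
    {"split": "train", "task": "vowl", "sex": "male"},
    {"split": "test",  "task": "vowl", "sex": 0},
    {"split": "test",  "task": "read", "sex": " Male "},
    {"split": "test",  "task": "read", "sex": None},
]

def toy_task_group(task) -> str:
    """'vowel' for the sustained-phonation task code, 'other' for everything else."""
    return "vowel" if str(task) == "vowl" else "other"

def toy_normalize_sex(val) -> str:
    """Reduced normalizer: numeric 0 -> F, 1 -> M, a few strings, otherwise UNK."""
    if val is None:
        return "UNK"
    if isinstance(val, (int, float)) and val in (0, 1):
        return "F" if val == 0 else "M"
    s = str(val).strip().lower()
    if s in {"m", "male"}:
        return "M"
    if s in {"f", "female"}:
        return "F"
    return "UNK"

# Keep only test rows, then attach the two derived fields.
toy_test = [dict(r) for r in toy_manifest if r["split"] == "test"]
for r in toy_test:
    r["task_group"] = toy_task_group(r["task"])
    r["sex_norm"] = toy_normalize_sex(r["sex"])

for r in toy_test:
    print(r["task"], "->", r["task_group"], "|", repr(r["sex"]), "->", r["sex_norm"])
```

Rows that normalize to UNK stay in the test set for AUROC and threshold metrics; they are only excluded from the M and F subgroups.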

The model used here matches the training setup. The cell rebuilds a frozen Wav2Vec2 backbone with two small classification heads, one for vowel clips and one for other clips. It then loads the trained head weights from a single D7 train and validation experiment folder, using one `best_heads.pt` file per seed. In the code below, that folder is the most recently modified `exp_*` directory under `trainval_runs` that contains `summary_trainval.json` and a `best_heads.pt` file for each of the three seeds. The cell also loads a single global decision threshold from that experiment's `summary_trainval.json`, using `val_optimal_threshold.mean_sd.mean`. This same threshold is applied to all three test runs, one per seed. If the stored mean threshold is missing, the cell falls back to averaging the available per-seed thresholds, and uses 0.5 only as a last resort.
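
The threshold fallback order described above can be sketched as follows. This is a minimal illustration under stated assumptions, not the cell's implementation: `resolve_shared_threshold` and its `source` strings are hypothetical, and the dictionary layout is assumed to follow the `val_optimal_threshold` block (`mean_sd.mean` and `by_seed`) written to the summary files.

```python
import math

def resolve_shared_threshold(summary: dict, default: float = 0.5):
    """Return (threshold, source) using: stored mean, then mean of per-seed values, then default."""
    val_opt = summary.get("val_optimal_threshold") or {}

    # 1) Preferred: the stored mean of the validation-optimal thresholds.
    mean_val = (val_opt.get("mean_sd") or {}).get("mean")
    if isinstance(mean_val, (int, float)) and math.isfinite(mean_val):
        return float(mean_val), "stored_mean"

    # 2) Fallback: average whatever finite per-seed thresholds are available.
    by_seed = val_opt.get("by_seed") or {}
    finite = [float(v) for v in by_seed.values()
              if isinstance(v, (int, float)) and math.isfinite(v)]
    if finite:
        return sum(finite) / len(finite), "mean_of_per_seed"

    # 3) Last resort: a fixed default.
    return float(default), "default"

# Invented example values, for illustration only.
print(resolve_shared_threshold({"val_optimal_threshold": {"mean_sd": {"mean": 0.42}}}))
print(resolve_shared_threshold({"val_optimal_threshold": {"by_seed": {"1337": 0.4, "2024": 0.6}}}))
print(resolve_shared_threshold({}))
```

Returning the source alongside the value makes it easy to record in the output metadata which branch produced the threshold.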

Inference is then run on the D2 test set for each of the three seeds (1337, 2024, and 7777), producing a Parkinson's disease probability score for every clip. For each seed, the cell computes three groups of results: AUROC on the full test set; threshold-based metrics at the shared threshold, namely the confusion matrix, accuracy, precision, sensitivity (recall), specificity, F1 score, MCC, and the two-sided Fisher exact test p-value; and a fairness metric using the H3 definition. Fairness is reported as ΔFNR = FNR(F) - FNR(M), where the false negative rate is calculated only on true Parkinson's cases as FN / (FN + TP), using the normalized sex labels.
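
To make the ΔFNR definition concrete, the following sketch works through it on a handful of invented clips. The helper name `fnr_for_group` and all values are assumptions for illustration; the only parts taken from the notebook are the prediction rule (a clip is predicted PD when its score is at or above the threshold) and the formula FNR = FN / (FN + TP) over true PD rows.

```python
import math

def fnr_for_group(y_true, y_prob, sex, group, thr):
    """FNR over true-PD rows of one sex group; NaN when that group has no PD rows."""
    fn = tp = 0
    for yt, yp, s in zip(y_true, y_prob, sex):
        if yt == 1 and s == group:
            if yp >= thr:
                tp += 1   # PD clip correctly flagged
            else:
                fn += 1   # PD clip missed
    n_pos = fn + tp
    return fn / n_pos if n_pos > 0 else math.nan

# Invented toy data: 1 = PD, 0 = Healthy.
y_true = [1,   1,   1,   1,   1,   0,   0]
y_prob = [0.9, 0.3, 0.7, 0.2, 0.1, 0.8, 0.4]
sex    = ["M", "M", "F", "F", "F", "M", "F"]
thr = 0.5

fnr_m = fnr_for_group(y_true, y_prob, sex, "M", thr)  # 1 missed of 2 PD clips
fnr_f = fnr_for_group(y_true, y_prob, sex, "F", thr)  # 2 missed of 3 PD clips
delta_fnr = fnr_f - fnr_m

print(f"FNR_M={fnr_m:.4f}  FNR_F={fnr_f:.4f}  dFNR(F-M)={delta_fnr:.4f}")
```

A positive ΔFNR means PD clips from female speakers are missed more often than those from male speakers at the chosen threshold. Healthy rows do not enter the calculation.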

To support later analysis and writing, the cell saves both per-clip outputs and visual artifacts. For each seed, it writes a `predictions.csv` file that includes the clip path, true label, predicted score, normalized sex, speaker ID, task group, seed, the source train and validation experiment tag, the run timestamp, and the global threshold used. It also saves ROC curve and confusion matrix images, including confusion matrices split by sex when both M and F samples are available. All outputs are written to two locations: a stable, tag-based folder that always reflects the latest run for that tag, and a run-stamped folder that preserves a snapshot of the specific execution.
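
The two-location write pattern can be sketched as below. This is only an illustration of the idea: the function name `write_predictions_twice`, the folder names, and the CSV column names are hypothetical and do not reproduce the cell's actual layout.

```python
import csv
import shutil
import tempfile
from pathlib import Path

def write_predictions_twice(rows, fieldnames, stable_dir: Path, snapshot_dir: Path):
    """Write predictions.csv into the stable folder, then copy it into the snapshot folder."""
    stable_dir.mkdir(parents=True, exist_ok=True)
    snapshot_dir.mkdir(parents=True, exist_ok=True)

    # The stable copy is overwritten on every run for the same tag.
    stable_csv = stable_dir / "predictions.csv"
    with open(stable_csv, "w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(rows)

    # The snapshot copy lives under a run-stamped folder and is never reused.
    snapshot_csv = snapshot_dir / "predictions.csv"
    shutil.copy2(stable_csv, snapshot_csv)
    return stable_csv, snapshot_csv

# Demo in a temporary directory with one invented row.
root = Path(tempfile.mkdtemp())
rows = [{"clip_path": "a.wav", "y_true": 1, "y_prob": 0.73, "seed": 1337}]
stable_csv, snapshot_csv = write_predictions_twice(
    rows,
    list(rows[0].keys()),
    stable_dir=root / "latest" / "demo_tag" / "seed1337",
    snapshot_dir=root / "runs" / "demo_tag__20250101_000000" / "seed1337",
)
print(stable_csv)
print(snapshot_csv)
```

Copying the finished file, rather than writing it twice, keeps the two copies byte-identical.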

After all three seeds finish, the cell aggregates results across seeds and writes a single `summary_test.json` file. The summary reports the mean AUROC with a 95% confidence interval from a t distribution with n = 3 (2 degrees of freedom), the mean and standard deviation of the threshold-based metrics, and the mean and standard deviation of ΔFNR along with the underlying FNR values for M and F. The summary is also appended to a `history_index.jsonl` log for tracking runs over time. Finally, the cell writes small pointer JSON files that help locate the latest run, writes builder-aligned config and log stubs (backing up any existing files first), prints the key output locations, and unassigns the Colab runtime to stop GPU usage.
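
The confidence interval arithmetic is small enough to show in full. The sketch below uses the same critical value that appears in the aggregation code (two-sided 95%, 2 degrees of freedom) and the sample standard deviation. The function name `mean_ci95_three_seeds` and the AUROC values are invented for illustration and are not results from this study.

```python
import math
import statistics

# Two-sided 95% critical value of the t distribution with df = 2 (valid only for n = 3).
T_CRIT_DF2_95 = 4.302652729911275

def mean_ci95_three_seeds(values):
    """Mean and 95% t-interval for exactly three per-seed values."""
    if len(values) != 3:
        raise ValueError("This critical value only applies to n = 3 (df = 2).")
    mean = statistics.fmean(values)
    sd = statistics.stdev(values)                  # sample SD (ddof = 1)
    half_width = T_CRIT_DF2_95 * sd / math.sqrt(3)
    return mean, (mean - half_width, mean + half_width)

# Invented per-seed AUROC values.
toy_aurocs = [0.80, 0.82, 0.84]
toy_mean, (toy_lo, toy_hi) = mean_ci95_three_seeds(toy_aurocs)
print(f"mean={toy_mean:.4f}  95% CI=[{toy_lo:.4f}, {toy_hi:.4f}]")
```

With only three seeds the interval is wide: the half-width is roughly 2.5 times the sample standard deviation, so even closely agreeing seeds produce a visibly broad interval.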

In [None]:
# =========================
# D7 BASE → D2 TEST (Test-only evaluation)
# Purpose:
# - Run D2 test inference using frozen D7 BASE heads (3 seeds)
# - Use ONE shared threshold: mean VAL-opt threshold saved by the base trainval run
# Outputs:
# - Per-seed predictions.csv + metrics.json in a stable tag folder and a run-stamped folder
# - summary_test.json + history_index.jsonl under multilingual_test_runs/
# - Builder-aligned run records under config/ and logs/
# =========================

import os, json, math, random, time, warnings, shutil
from pathlib import Path

import numpy as np
import pandas as pd
import soundfile as sf

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

from transformers import Wav2Vec2Model

from sklearn.metrics import roc_auc_score, roc_curve, confusion_matrix, matthews_corrcoef
from scipy.stats import fisher_exact
import matplotlib.pyplot as plt
from tqdm.auto import tqdm

# -------------------------
# Shadowing guardrails
# Inputs: local filesystem
# Outputs: fail-fast error if a local file would override torch/transformers imports
# -------------------------
if os.path.exists("/content/torch.py") or os.path.exists("/content/torch/__init__.py"):
    raise RuntimeError("Found local /content/torch.py or /content/torch/ that shadows PyTorch. Rename/remove it and restart runtime.")
if os.path.exists("/content/transformers.py") or os.path.exists("/content/transformers/__init__.py"):
    raise RuntimeError("Found local /content/transformers.py or /content/transformers/ that shadows Hugging Face Transformers. Rename/remove it and restart runtime.")

# -------------------------
# Drive mount (Colab only)
# Inputs: none
# Outputs: Drive mounted when needed; no-op otherwise
# -------------------------
try:
    from google.colab import drive  # type: ignore
    if not os.path.isdir("/content/drive/MyDrive"):
        drive.mount("/content/drive")
except Exception:
    pass

# -------------------------
# Paths and run identifiers
# Inputs: D7_OUT_ROOT, D2_OUT_ROOT
# Outputs: a run stamp and a safe tag used for folder names and metadata
# -------------------------
D7_OUT_ROOT = "/content/drive/MyDrive/AI_PD_Project/Datasets/D7-Multilingual (D1_D4_D5v2_D6)/preprocessed_v1"
D2_OUT_ROOT = "/content/drive/MyDrive/AI_PD_Project/Datasets/D2-Slovak (EWA-DB)/EWA-DB/preprocessed_v1"

D2_MANIFEST_ALL = f"{D2_OUT_ROOT}/manifests/manifest_all.csv"

# Keep DX_OUT_ROOT aligned with run root (D7)
DX_OUT_ROOT = D7_OUT_ROOT
globals()["DX_OUT_ROOT"] = DX_OUT_ROOT
globals()["D7_OUT_ROOT"] = D7_OUT_ROOT
globals()["D2_OUT_ROOT"] = D2_OUT_ROOT

RUN_STAMP = time.strftime("%Y%m%d_%H%M%S")

def _sanitize_tag(s: str) -> str:
    """Folder-safe tag: keep letters, numbers, dash, underscore; convert others to '_'."""
    s = str(s).strip()
    out = []
    for ch in s:
        if ch.isalnum() or ch in ["-", "_"]:
            out.append(ch)
        else:
            out.append("_")
    out = "".join(out).strip("_")
    return out if out else "tag"

# -------------------------
# Fixed BASE trainval experiment
# Inputs: a single base trainval experiment folder
# Outputs: FULL_TRAINVAL_EXP_TAG used in predictions and summaries
# -------------------------
_trainval_root = Path(D7_OUT_ROOT) / "trainval_runs"
if not _trainval_root.exists():
    raise FileNotFoundError(f"Missing trainval_runs folder under D7_OUT_ROOT: {str(_trainval_root)}")

def _is_complete_base_exp_dir(exp_dir: Path) -> bool:
    if not exp_dir.is_dir():
        return False
    if not (exp_dir / "summary_trainval.json").exists():
        return False
    for seed in [1337, 2024, 7777]:
        if not (exp_dir / f"run_D7_seed{seed}" / "best_heads.pt").exists():
            return False
    return True

_exp_dirs = sorted([p for p in _trainval_root.glob("exp_*") if p.is_dir()], key=lambda p: p.stat().st_mtime, reverse=True)
BASE_TRAINVAL_EXP_DIR = None
for p in _exp_dirs:
    if _is_complete_base_exp_dir(p):
        BASE_TRAINVAL_EXP_DIR = p
        break

if BASE_TRAINVAL_EXP_DIR is None:
    raise FileNotFoundError(
        "No suitable base trainval experiment folder found under: "
        f"{str(_trainval_root)}. Expected an exp_* folder containing summary_trainval.json "
        "and run_D7_seed{1337,2024,7777}/best_heads.pt."
    )

FULL_TRAINVAL_EXP_TAG = BASE_TRAINVAL_EXP_DIR.name  # must match the actual exp folder name

# Tag SAFE: build a readable tag that is unique per execution
TAG_RAW = f"exp_frozen_LNDO_base_initBaseline_{RUN_STAMP}"
TAG_SAFE = _sanitize_tag(TAG_RAW)

# -------------------------
# Runtime and evaluation settings
# Inputs: constants (seeds, checkpoint, sample rate, batching)
# Outputs: consistent inference behavior across runs
# -------------------------
SEEDS          = [1337, 2024, 7777]
BACKBONE_CKPT  = "facebook/wav2vec2-base"
SR_EXPECTED    = 16000
TINY_THRESH    = 1e-4

EFFECTIVE_BS   = 64
PER_DEVICE_BS  = 16
GRAD_ACCUM     = max(1, EFFECTIVE_BS // PER_DEVICE_BS)

DROPOUT_P      = 0.2

NUM_WORKERS    = 0
PIN_MEMORY     = False

USE_AMP        = True

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
try:
    torch.set_float32_matmul_precision("high")
except Exception:
    pass

# Silence known non-critical warnings
warnings.filterwarnings("ignore", message="None of the inputs have requires_grad=True. Gradients will be None")
warnings.filterwarnings("ignore", message="Passing `gradient_checkpointing` to a config initialization is deprecated")

# Run header (prints only; no logic changes)
print("D7_OUT_ROOT:", D7_OUT_ROOT)
print("D2_OUT_ROOT:", D2_OUT_ROOT)
print("D2_MANIFEST_ALL:", D2_MANIFEST_ALL)
print("BASE_TRAINVAL_EXP_DIR:", str(BASE_TRAINVAL_EXP_DIR))
print("FULL_TRAINVAL_EXP_TAG:", FULL_TRAINVAL_EXP_TAG)
print("TAG_SAFE:", TAG_SAFE)
print("DEVICE:", DEVICE)
if DEVICE.type == "cuda":
    print("GPU:", torch.cuda.get_device_name(0))
print("PER_DEVICE_BS:", PER_DEVICE_BS, "| EFFECTIVE_BS:", PER_DEVICE_BS * GRAD_ACCUM)
print("NUM_WORKERS:", NUM_WORKERS, "| PIN_MEMORY:", PIN_MEMORY)
print("USE_AMP:", bool(USE_AMP and DEVICE.type == "cuda"))

# -------------------------
# Backup helper for builder-aligned logs/config
# Inputs: destination path
# Outputs: a timestamped backup copy if the destination already exists
# -------------------------
def _backup_if_exists(path: Path):
    if not path.exists():
        return None
    ts = time.strftime("%Y%m%d_%H%M%S")
    bak = path.with_name(path.name + f".bak_{ts}")
    if path.is_dir():
        shutil.copytree(path, bak)
    else:
        shutil.copy2(path, bak)
    return str(bak)

# -------------------------
# Load D2 manifest and build the TEST table
# Inputs: D2 manifest_all.csv
# Outputs: test_df with required columns + basic counts printed
# -------------------------
if not os.path.exists(D2_MANIFEST_ALL):
    raise FileNotFoundError(f"Missing D2 manifest_all.csv: {D2_MANIFEST_ALL}")

m_all = pd.read_csv(D2_MANIFEST_ALL)

# Minimum fields needed for inference, grouping, and fairness reporting
req_cols = {"split", "clip_path", "label_num", "task", "sex", "age"}
missing = [c for c in sorted(req_cols) if c not in m_all.columns]
if missing:
    raise ValueError(f"D2 manifest missing required columns: {missing}. Found: {list(m_all.columns)}")

# Infer dataset id from the manifest, then keep only that dataset
if "dataset" in m_all.columns and m_all["dataset"].notna().any():
    d2_dataset_id = str(m_all["dataset"].astype(str).value_counts(dropna=True).idxmax())
    m_all = m_all[m_all["dataset"].astype(str) == d2_dataset_id].copy()
else:
    d2_dataset_id = "DX"

# Guard: this cell is strictly D2 test evaluation
if d2_dataset_id != "D2":
    raise RuntimeError(
        f"Expected D2 dataset_id=='D2' but got {d2_dataset_id!r}. "
        "This usually means D2_OUT_ROOT is wrong or the manifest is not D2. "
        f"D2_OUT_ROOT={D2_OUT_ROOT}"
    )

# Keep a small set of columns; create missing optional fields as NaN
keep_cols = ["clip_path", "label_num", "task", "speaker_id", "sex", "age", "duration_sec", "split"]
for c in keep_cols:
    if c not in m_all.columns:
        m_all[c] = np.nan
m_all = m_all[keep_cols].copy()

test_df = m_all[m_all["split"].astype(str) == "test"].reset_index(drop=True)

print(f"\nD2 dataset inferred: {d2_dataset_id}")
print(f"D2 TEST rows: {len(test_df)}")
print("D2 TEST label counts:", test_df["label_num"].value_counts(dropna=False).to_dict())
print("D2 TEST sex counts (raw):", test_df["sex"].value_counts(dropna=False).to_dict())

if len(test_df) == 0:
    raise RuntimeError("After filtering to split=='test', D2 manifest has 0 rows.")

# -------------------------
# Missing file check (fail-fast)
# Inputs: test_df clip_path
# Outputs: stops early if any referenced audio file is missing
# -------------------------
def _fail_fast_missing_paths(df: pd.DataFrame, name: str):
    missing_paths = []
    for p in tqdm(df["clip_path"].astype(str).tolist(), desc=f"Check {name} clip_path exists", dynamic_ncols=True):
        if not os.path.exists(p):
            missing_paths.append(p)
            if len(missing_paths) >= 10:
                break
    if missing_paths:
        raise FileNotFoundError(f"{name}: missing clip_path(s). Examples: {missing_paths}")

_fail_fast_missing_paths(test_df, "D2 TEST")

# -------------------------
# Task grouping and sex normalization
# Inputs: D2 task and sex columns
# Outputs: task_group and sex_norm columns used in metrics, plots, and predictions.csv
# -------------------------
def _task_group(task_val) -> str:
    """Map raw task to a 2-level group used to pick the correct head."""
    return "vowel" if str(task_val) == "vowl" else "other"

test_df["task_group"] = test_df["task"].apply(_task_group)

def normalize_sex(val) -> str:
    """
    Returns 'M', 'F', or 'UNK'

    Expected values: 'male'/'female', but supports common variants and numeric codes.
    """
    if pd.isna(val):
        return "UNK"

    # numeric mapping (if present)
    try:
        fv = float(val)
        if np.isfinite(fv) and abs(fv - round(fv)) < 1e-9:
            iv = int(round(fv))
            if iv == 0:
                return "F"
            if iv == 1:
                return "M"
    except Exception:
        pass

    s = str(val).strip().lower()
    if s in {"m", "male", "man", "masc", "masculine"}:
        return "M"
    if s in {"f", "female", "woman", "fem", "feminine"}:
        return "F"
    return "UNK"

test_df["sex_norm"] = test_df["sex"].apply(normalize_sex)
print("D2 TEST sex counts (normalized):", test_df["sex_norm"].value_counts(dropna=False).to_dict())
if (test_df["sex_norm"] == "UNK").any():
    print("NOTE: Some 'sex' values could not be normalized to M/F and were counted as 'UNK' for fairness and sex charts.")

# -------------------------
# Dataset and collator
# Inputs: test_df + audio files (16 kHz expected)
# Outputs: batches with audio tensors + per-clip metadata for predictions.csv
# -------------------------
class AudioManifestDataset(Dataset):
    def __init__(self, df: pd.DataFrame):
        self.df = df.reset_index(drop=True)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx: int):
        row = self.df.iloc[idx]
        clip_path = str(row["clip_path"])
        label = int(row["label_num"])
        task_group = str(row["task_group"])
        sex_norm = str(row["sex_norm"])
        speaker_id = row["speaker_id"] if "speaker_id" in row.index else np.nan

        if not os.path.exists(clip_path):
            raise FileNotFoundError(f"Missing clip_path: {clip_path}")

        y, sr = sf.read(clip_path, always_2d=False)
        if isinstance(y, np.ndarray) and y.ndim > 1:
            y = np.mean(y, axis=1)
        y = np.asarray(y, dtype=np.float32)

        if int(sr) != SR_EXPECTED:
            raise ValueError(f"Sample rate mismatch (expected {SR_EXPECTED}): {clip_path} has sr={sr}")

        # attention_mask marks valid samples; vowel clips may include padded silence.
        # For vowel clips, mask out everything after the last sample whose magnitude
        # exceeds TINY_THRESH; an all-silent clip keeps a full mask.
        attn = np.ones((len(y),), dtype=np.int64)
        if task_group == "vowel":
            loud = np.flatnonzero(np.abs(y.astype(np.float64)) > TINY_THRESH)
            if loud.size > 0:
                attn[int(loud[-1]) + 1:] = 0

        return {
            "input_values": torch.from_numpy(y),
            "attention_mask": torch.from_numpy(attn),
            "labels": torch.tensor(label, dtype=torch.long),
            "task_group": task_group,
            "sex_norm": sex_norm,
            "clip_path": clip_path,
            "speaker_id": speaker_id,
        }

def collate_fn(batch):
    """Pad audio and attention masks to the max length in the batch; keep metadata lists."""
    max_len = int(max(b["input_values"].numel() for b in batch))
    input_vals, attn_masks, labels = [], [], []
    task_groups, sex_norms, clip_paths, speaker_ids = [], [], [], []

    for b in batch:
        x = b["input_values"]
        a = b["attention_mask"]
        pad = max_len - x.numel()
        if pad > 0:
            x = torch.cat([x, torch.zeros(pad, dtype=x.dtype)], dim=0)
            a = torch.cat([a, torch.zeros(pad, dtype=a.dtype)], dim=0)

        input_vals.append(x)
        attn_masks.append(a)
        labels.append(b["labels"])
        task_groups.append(b["task_group"])
        sex_norms.append(b["sex_norm"])
        clip_paths.append(b["clip_path"])
        speaker_ids.append(b["speaker_id"])

    return {
        "input_values": torch.stack(input_vals, dim=0),
        "attention_mask": torch.stack(attn_masks, dim=0),
        "labels": torch.stack(labels, dim=0),
        "task_group": task_groups,
        "sex_norm": sex_norms,
        "clip_path": clip_paths,
        "speaker_id": speaker_ids,
    }

# -------------------------
# Model definition (frozen backbone, two heads)
# Inputs: backbone checkpoint name + dropout
# Outputs: logits for PD/Healthy using the head chosen by task_group
# -------------------------
class Wav2Vec2TwoHeadClassifier(nn.Module):
    def __init__(self, ckpt: str, dropout_p: float):
        super().__init__()
        self.backbone = Wav2Vec2Model.from_pretrained(ckpt, use_safetensors=True)
        for p in self.backbone.parameters():
            p.requires_grad = False

        H = int(self.backbone.config.hidden_size)
        self.pre_vowel = nn.Sequential(nn.LayerNorm(H), nn.Dropout(float(dropout_p)))
        self.pre_other = nn.Sequential(nn.LayerNorm(H), nn.Dropout(float(dropout_p)))
        self.head_vowel = nn.Linear(H, 2)
        self.head_other = nn.Linear(H, 2)

    def masked_mean_pool(self, last_hidden, attn_mask_samples):
        """Mean pooling over non-masked frames (mask derived from the sample-level attention_mask)."""
        feat_mask = self.backbone._get_feature_vector_attention_mask(last_hidden.shape[1], attn_mask_samples)
        feat_mask = feat_mask.to(last_hidden.device).unsqueeze(-1).type_as(last_hidden)
        masked = last_hidden * feat_mask
        denom = feat_mask.sum(dim=1).clamp(min=1.0)
        return masked.sum(dim=1) / denom

    def _heads_fp32(self, x_fp_any: torch.Tensor, head: nn.Module) -> torch.Tensor:
        """Run heads in float32 for numerical stability."""
        x = x_fp_any.float()
        if x.is_cuda:
            with torch.amp.autocast(device_type="cuda", enabled=False):
                return head(x)
        return head(x)

    def forward_logits(self, input_values, attention_mask, task_group):
        """Backbone forward (no grad) + pooled features + head selection by task_group."""
        with torch.no_grad():
            out = self.backbone(input_values=input_values, attention_mask=attention_mask)
            last_hidden = out.last_hidden_state

        pooled = self.masked_mean_pool(last_hidden, attention_mask)  # [B,H]

        z_v = self.pre_vowel(pooled.float())
        z_o = self.pre_other(pooled.float())

        logits_v = self._heads_fp32(z_v, self.head_vowel)
        logits_o = self._heads_fp32(z_o, self.head_other)

        is_vowel = torch.tensor([tg == "vowel" for tg in task_group], device=pooled.device, dtype=torch.bool)
        logits = logits_o.clone()
        if is_vowel.any():
            logits[is_vowel] = logits_v[is_vowel]
        return logits

# -------------------------
# Metrics and plotting helpers
# Inputs: y_true, y_prob (PD probability), threshold
# Outputs: AUROC, threshold metrics dict, ROC/CM png files
# -------------------------
def compute_auc(y_true, y_prob):
    y_true = np.asarray(y_true, dtype=np.int64)
    y_prob = np.asarray(y_prob, dtype=np.float64)
    if len(np.unique(y_true)) < 2:
        return float("nan")
    return float(roc_auc_score(y_true, y_prob))

def compute_threshold_metrics(y_true, y_prob, thr=0.5):
    """Compute confusion-derived metrics at a fixed threshold."""
    y_true = np.asarray(y_true, dtype=np.int64)
    y_prob = np.asarray(y_prob, dtype=np.float64)
    y_pred = (y_prob >= float(thr)).astype(np.int64)

    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
    TN, FP, FN, TP = int(cm[0, 0]), int(cm[0, 1]), int(cm[1, 0]), int(cm[1, 1])

    eps = 1e-12
    acc = (TP + TN) / max(1, (TP + TN + FP + FN))
    prec = TP / (TP + FP + eps)
    rec = TP / (TP + FN + eps)     # sensitivity
    f1 = 2 * prec * rec / (prec + rec + eps)
    spec = TN / (TN + FP + eps)

    try:
        mcc = float(matthews_corrcoef(y_true, y_pred)) if len(np.unique(y_pred)) > 1 else float("nan")
    except Exception:
        mcc = float("nan")

    try:
        _, pval = fisher_exact([[TN, FP], [FN, TP]], alternative="two-sided")
        pval = float(pval)
    except Exception:
        pval = float("nan")

    return {
        "threshold": float(thr),
        "confusion_matrix": {"TN": TN, "FP": FP, "FN": FN, "TP": TP},
        "accuracy": float(acc),
        "precision": float(prec),
        "recall": float(rec),
        "f1_score": float(f1),
        "sensitivity": float(rec),
        "specificity": float(spec),
        "mcc": float(mcc),
        "p_value_fisher_two_sided": float(pval),
    }

def save_roc_curve_png(y_true, y_prob, out_png, title_suffix="Test"):
    """Write a basic ROC curve plot."""
    y_true = np.asarray(y_true, dtype=np.int64)
    y_prob = np.asarray(y_prob, dtype=np.float64)
    fpr, tpr, _ = roc_curve(y_true, y_prob)
    plt.figure()
    plt.plot(fpr, tpr)
    plt.plot([0, 1], [0, 1], linestyle="--")
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    plt.title(f"ROC Curve ({title_suffix})")
    plt.tight_layout()
    plt.savefig(out_png, dpi=150)
    plt.close()

def save_confusion_png(y_true, y_prob, out_png, thr=0.5, title_suffix="Test"):
    """Write a confusion matrix plot at the chosen threshold."""
    y_true = np.asarray(y_true, dtype=np.int64)
    y_prob = np.asarray(y_prob, dtype=np.float64)
    y_pred = (y_prob >= float(thr)).astype(int)
    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])

    plt.figure()
    plt.imshow(cm)
    plt.xticks([0, 1], ["Healthy(0)", "PD(1)"])
    plt.yticks([0, 1], ["Healthy(0)", "PD(1)"])
    plt.xlabel("Predicted")
    plt.ylabel("True")
    plt.title(f"Confusion Matrix ({title_suffix}, thr={thr:.4f})")
    for (i, j), v in np.ndenumerate(cm):
        plt.text(j, i, str(v), ha="center", va="center")
    plt.tight_layout()
    plt.savefig(out_png, dpi=150)
    plt.close()

def mean_sd(vals):
    """Mean and sample SD over seeds (NaN-safe). SD is reported as 0.0 when fewer than two non-NaN values exist."""
    vals = np.asarray(vals, dtype=np.float64)
    mu = float(np.nanmean(vals)) if np.any(~np.isnan(vals)) else float("nan")
    sd = float(np.nanstd(vals, ddof=1)) if np.sum(~np.isnan(vals)) > 1 else 0.0
    return mu, sd

# -------------------------
# Fairness metric (H3): ΔFNR = FNR(F) - FNR(M)
# Inputs: y_true, y_prob, sex_norm, threshold
# Outputs: per-sex FNR details and ΔFNR (F minus M)
# -------------------------
def compute_fnr_by_sex_and_delta(y_true, y_prob, sex_norm, thr):
    """
    PD-only FNR at the chosen threshold:
      FNR(sex) = FN / (FN + TP) using only rows with y_true == 1
    """
    y_true = np.asarray(y_true, dtype=np.int64)
    y_prob = np.asarray(y_prob, dtype=np.float64)
    sex_norm = np.asarray(list(sex_norm), dtype=object)
    y_pred = (y_prob >= float(thr)).astype(np.int64)

    out = {}
    for g in ["M", "F", "UNK"]:
        mask_g = (sex_norm == g)
        if int(mask_g.sum()) == 0:
            continue
        pos_mask = mask_g & (y_true == 1)
        n_pos = int(pos_mask.sum())
        if n_pos == 0:
            out[g] = {"n_total": int(mask_g.sum()), "n_pos": 0, "tp": 0, "fn": 0, "fnr": float("nan")}
            continue
        tp = int(((y_pred == 1) & pos_mask).sum())
        fn = int(((y_pred == 0) & pos_mask).sum())
        fnr = float(fn / max(1, (fn + tp)))
        out[g] = {"n_total": int(mask_g.sum()), "n_pos": int(n_pos), "tp": int(tp), "fn": int(fn), "fnr": float(fnr)}

    fnr_m = out.get("M", {}).get("fnr", float("nan"))
    fnr_f = out.get("F", {}).get("fnr", float("nan"))
    if (not np.isnan(fnr_m)) and (not np.isnan(fnr_f)):
        delta = float(fnr_f - fnr_m)
    else:
        delta = float("nan")

    return out, delta

# -------------------------
# Seed control
# Inputs: integer seed
# Outputs: seeded Python/NumPy/PyTorch RNGs and deterministic cuDNN settings (best effort; some CUDA ops can remain nondeterministic)
# -------------------------
def set_all_seeds(seed: int):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# -------------------------
# Build D2 TEST DataLoader + quick warm-up
# Inputs: test_df
# Outputs: test_loader ready for inference; warm-up checks read/decode path
# -------------------------
test_ds = AudioManifestDataset(test_df)
test_loader = DataLoader(
    test_ds,
    batch_size=PER_DEVICE_BS,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=PIN_MEMORY,
    collate_fn=collate_fn,
)

print("\nWarm-up: loading up to 3 D2 TEST batches...")
t0 = time.time()
nb = len(test_loader)
wb = min(3, nb)
if wb == 0:
    raise RuntimeError("D2 TEST DataLoader has 0 batches. Check test_df length and PER_DEVICE_BS.")
it = iter(test_loader)
for i in range(wb):
    _ = next(it)
    print(f"  loaded warmup D2 TEST batch {i+1}/{wb}")
print(f"Warm-up done in {time.time()-t0:.2f}s")

# -------------------------
# Load the shared threshold from base trainval summary
# Inputs: BASE_TRAINVAL_EXP_DIR/summary_trainval.json
# Outputs: TEST_THR_MEAN_VAL_OPT used for ALL seeds on D2 test
# -------------------------
summary_path = BASE_TRAINVAL_EXP_DIR / "summary_trainval.json"
if not summary_path.exists():
    raise FileNotFoundError(f"Missing summary_trainval.json in base exp: {str(summary_path)}")

with open(summary_path, "r", encoding="utf-8") as f:
    chosen_summary = json.load(f)

try:
    val_opt = chosen_summary["val_optimal_threshold"]
    thr_mean = float(val_opt.get("mean_sd", {}).get("mean", float("nan")))
    thr_sd   = float(val_opt.get("mean_sd", {}).get("sd", float("nan")))
    thr_by_seed = {str(k): float(v) for k, v in (val_opt.get("by_seed") or {}).items()}
except Exception as e:
    raise RuntimeError(f"Could not parse val_optimal_threshold from {str(summary_path)}. Reason: {repr(e)}") from e

TEST_THR_SOURCE = "trainval summary_trainval.json -> val_optimal_threshold.mean_sd.mean (canonical aggregate)"
TEST_THR_MEAN_VAL_OPT = float(thr_mean)
TEST_THR_NOTE = None

# Fallbacks only if the stored aggregate threshold is unusable
if np.isnan(TEST_THR_MEAN_VAL_OPT):
    vals = []
    for s in SEEDS:
        v = thr_by_seed.get(str(s), float("nan"))
        if not np.isnan(v):
            vals.append(float(v))
    if len(vals) > 0:
        TEST_THR_MEAN_VAL_OPT = float(np.mean(vals))
        TEST_THR_NOTE = "val_optimal_threshold.mean_sd.mean was NaN. Fallback used: mean(thr_by_seed over available seeds)."
    else:
        TEST_THR_MEAN_VAL_OPT = 0.5
        TEST_THR_NOTE = "val_optimal_threshold.mean_sd.mean was NaN and thr_by_seed had no usable values. Fallback used: thr=0.5."

print("\nVAL-opt thresholds loaded from BASE trainval summary_trainval.json:")
print("  method:", str(val_opt.get("method", "unknown")))
print("  by_seed:", {k: (f"{v:.6f}" if not np.isnan(v) else "nan") for k, v in thr_by_seed.items()})
print("  mean ± SD:", (f"{thr_mean:.6f} ± {thr_sd:.6f}" if not np.isnan(thr_mean) else f"nan ± {thr_sd:.6f}"))

print("\nTEST threshold policy:")
print("  Using MEAN VAL-opt threshold for ALL seeds/tests:", f"{TEST_THR_MEAN_VAL_OPT:.6f}")
print("  Source:", TEST_THR_SOURCE)
if TEST_THR_NOTE is not None:
    print("  NOTE:", TEST_THR_NOTE)

# -------------------------
# Output folder layout
# Inputs: TAG_SAFE and RUN_STAMP
# Outputs:
# - TAG_ROOT: stable folder for this tag (accumulates per-seed outputs)
# - STAMP_ROOT: unique folder for this execution
# - CFG_DIR/LOG_DIR: builder-aligned run records
# -------------------------
TEST_ROOT = Path(D7_OUT_ROOT) / "multilingual_test_runs"
TEST_ROOT.mkdir(parents=True, exist_ok=True)

TAG_ROOT = TEST_ROOT / f"run_{TAG_SAFE}"
TAG_ROOT.mkdir(parents=True, exist_ok=True)

STAMP_ROOT = TEST_ROOT / f"run_{TAG_SAFE}__{RUN_STAMP}"
STAMP_ROOT.mkdir(parents=True, exist_ok=True)

CFG_DIR = Path(D7_OUT_ROOT) / "config" / f"D7_{TAG_SAFE}_on_D2_Test"
LOG_DIR = Path(D7_OUT_ROOT) / "logs"   / f"D7_{TAG_SAFE}_on_D2_Test"
CFG_DIR.mkdir(parents=True, exist_ok=True)
LOG_DIR.mkdir(parents=True, exist_ok=True)

# -------------------------
# Write builder-aligned run records (config + dataset summary + empty warnings)
# Inputs: test_df counts + threshold policy + run metadata
# Outputs: run_config.json, dataset_summary.json, preprocess_warnings.csv
# -------------------------
run_config_path = CFG_DIR / "run_config.json"
dataset_summary_path = LOG_DIR / "dataset_summary.json"
preprocess_warn_path = LOG_DIR / "preprocess_warnings.csv"

bak1 = _backup_if_exists(run_config_path)
bak2 = _backup_if_exists(dataset_summary_path)
bak3 = _backup_if_exists(preprocess_warn_path)

if bak1: print("Backed up existing run_config.json to:", bak1)
if bak2: print("Backed up existing dataset_summary.json to:", bak2)
if bak3: print("Backed up existing preprocess_warnings.csv to:", bak3)

run_config = {
    "run_type": "test_only",
    "model_family": "D7_multilingual_base",
    "d7_out_root": D7_OUT_ROOT,
    "d2_out_root": D2_OUT_ROOT,
    "d2_manifest_all": D2_MANIFEST_ALL,
    "base_trainval_exp_dir": str(BASE_TRAINVAL_EXP_DIR),
    "trainval_exp_tag": FULL_TRAINVAL_EXP_TAG,
    "tag_safe": TAG_SAFE,
    "run_stamp": RUN_STAMP,
    "seeds": SEEDS,
    "backbone_ckpt": BACKBONE_CKPT,
    "dropout_p": float(DROPOUT_P),
    "batching": {"per_device_bs": int(PER_DEVICE_BS), "effective_bs": int(PER_DEVICE_BS * GRAD_ACCUM)},
    "threshold_policy": {
        "source": TEST_THR_SOURCE,
        "threshold_used_global": float(TEST_THR_MEAN_VAL_OPT),
        "note": (TEST_THR_NOTE if TEST_THR_NOTE is not None else "")
    }
}
with open(run_config_path, "w", encoding="utf-8") as f:
    json.dump(run_config, f, indent=2)

ds_summary = {
    "dataset_id": "D2",
    "split": "test",
    "n_rows": int(len(test_df)),
    "label_counts": test_df["label_num"].value_counts(dropna=False).to_dict(),
    "sex_counts_raw": test_df["sex"].value_counts(dropna=False).to_dict(),
    "sex_counts_norm": test_df["sex_norm"].value_counts(dropna=False).to_dict(),
    "task_group_counts": test_df["task_group"].value_counts(dropna=False).to_dict(),
    "manifest_path": D2_MANIFEST_ALL,
    "generated_at": RUN_STAMP,
}
with open(dataset_summary_path, "w", encoding="utf-8") as f:
    json.dump(ds_summary, f, indent=2)

# Test-only run: warnings file is a structured placeholder
pd.DataFrame(columns=["warning_type", "clip_path", "detail"]).to_csv(preprocess_warn_path, index=False)

print("\nWROTE builder-aligned artifacts:")
print(" ", str(run_config_path))
print(" ", str(dataset_summary_path))
print(" ", str(preprocess_warn_path))

# -------------------------
# Load heads into the model
# Inputs: best_heads.pt from the base trainval run for a given seed
# Outputs: model with pre-blocks and heads loaded (backbone stays frozen)
# -------------------------
def load_heads_into_model(model: Wav2Vec2TwoHeadClassifier, best_heads_path: Path):
    if not best_heads_path.exists():
        raise FileNotFoundError(f"Missing best_heads.pt: {str(best_heads_path)}")
    state = torch.load(str(best_heads_path), map_location="cpu")
    model.pre_vowel.load_state_dict(state["pre_vowel"], strict=True)
    model.pre_other.load_state_dict(state["pre_other"], strict=True)
    model.head_vowel.load_state_dict(state["head_vowel"], strict=True)
    model.head_other.load_state_dict(state["head_other"], strict=True)
    return model

# -------------------------
# Inference loop (collect per-clip metadata)
# Inputs: test_loader + model
# Outputs: arrays for y_true, y_score, sex_norm plus clip_path/speaker_id/task_group lists
# -------------------------
def infer_probs_and_meta(loader, model, desc):
    use_amp = bool(USE_AMP and DEVICE.type == "cuda")
    all_probs, all_true, all_sex = [], [], []
    all_clip, all_spk, all_tg = [], [], []

    pbar = tqdm(loader, desc=desc, dynamic_ncols=True)
    with torch.inference_mode():
        for batch in pbar:
            input_values = batch["input_values"].to(DEVICE, non_blocking=False)
            attention_mask = batch["attention_mask"].to(DEVICE, non_blocking=False)
            labels = batch["labels"].to(DEVICE, non_blocking=False)

            task_group = batch["task_group"]
            sex_norm = batch["sex_norm"]
            clip_paths = batch["clip_path"]
            speaker_ids = batch["speaker_id"]

            with torch.amp.autocast(device_type="cuda", enabled=use_amp):
                logits = model.forward_logits(input_values, attention_mask, task_group)

            probs = torch.softmax(logits.float(), dim=-1)[:, 1].detach().cpu().numpy().astype(np.float64)

            all_probs.extend(probs.tolist())
            all_true.extend(labels.detach().cpu().numpy().astype(np.int64).tolist())
            all_sex.extend(list(sex_norm))

            all_clip.extend(list(clip_paths))
            all_spk.extend([("" if pd.isna(x) else str(x)) for x in speaker_ids])
            all_tg.extend(list(task_group))

    return (
        np.asarray(all_true, dtype=np.int64),
        np.asarray(all_probs, dtype=np.float64),
        np.asarray(all_sex, dtype=object),
        list(all_clip),
        list(all_spk),
        list(all_tg),
    )

# -------------------------
# Single-seed execution
# Inputs: seed, shared threshold, base heads file for that seed
# Outputs:
# - predictions.csv and metrics.json in both TAG and STAMP folders
# - ROC and confusion plots (overall + by sex when available)
# -------------------------
def run_test_once(seed: int):
    set_all_seeds(seed)

    per_tag_seed_dir = TAG_ROOT / f"run_D7_on_D2test_seed{seed}"
    per_tag_seed_dir.mkdir(parents=True, exist_ok=True)

    per_stamp_seed_dir = STAMP_ROOT / f"run_D7_on_D2test_seed{seed}"
    per_stamp_seed_dir.mkdir(parents=True, exist_ok=True)

    best_heads_path = BASE_TRAINVAL_EXP_DIR / f"run_D7_seed{seed}" / "best_heads.pt"
    if not best_heads_path.exists():
        raise FileNotFoundError(f"Missing best_heads.pt for seed={seed}: {str(best_heads_path)}")

    thr_used = float(TEST_THR_MEAN_VAL_OPT)

    print(f"\n[seed={seed}] Loading model + BASE heads from:")
    print(" ", str(best_heads_path))
    print(f"[seed={seed}] Using TEST threshold (shared mean VAL-opt): {thr_used:.6f}")
    if TEST_THR_NOTE is not None:
        print(f"[seed={seed}] NOTE: {TEST_THR_NOTE}")

    model = Wav2Vec2TwoHeadClassifier(BACKBONE_CKPT, dropout_p=DROPOUT_P).to(DEVICE)
    model = load_heads_into_model(model, best_heads_path)
    model.eval()

    yt_true, yt_prob, yt_sex, clip_paths, speaker_ids, task_groups = infer_probs_and_meta(
        test_loader, model, desc=f"[seed={seed}] D2 TEST"
    )

    test_auc = compute_auc(yt_true, yt_prob)
    thr_metrics_test = compute_threshold_metrics(yt_true, yt_prob, thr=thr_used)
    fnr_by_sex, delta_f_minus_m = compute_fnr_by_sex_and_delta(yt_true, yt_prob, yt_sex, thr=thr_used)

    # Write plots to both locations so the stable tag folder and the run-stamped folder are complete
    for out_dir in [per_tag_seed_dir, per_stamp_seed_dir]:
        roc_png = out_dir / "roc_curve.png"
        cm_png  = out_dir / "confusion_matrix.png"
        save_roc_curve_png(yt_true, yt_prob, str(roc_png), title_suffix=f"D2 TEST (seed={seed})")
        save_confusion_png(yt_true, yt_prob, str(cm_png), thr=thr_used, title_suffix=f"D2 TEST (seed={seed})")

        mask_m = (yt_sex == "M")
        mask_f = (yt_sex == "F")

        if int(mask_m.sum()) > 0:
            save_confusion_png(yt_true[mask_m], yt_prob[mask_m], str(out_dir / "confusion_matrix_M.png"),
                               thr=thr_used, title_suffix=f"D2 TEST SEX=M (seed={seed})")
        if int(mask_f.sum()) > 0:
            save_confusion_png(yt_true[mask_f], yt_prob[mask_f], str(out_dir / "confusion_matrix_F.png"),
                               thr=thr_used, title_suffix=f"D2 TEST SEX=F (seed={seed})")

    # predictions.csv: one row per clip with score and metadata
    pred_df = pd.DataFrame({
        "clip_path": clip_paths,
        "y_true": yt_true.astype(int),
        "y_score": yt_prob.astype(float),
        "sex_norm": [str(x) for x in yt_sex.tolist()],
        "speaker_id": speaker_ids,
        "task_group": task_groups,
        "seed": int(seed),
        "trainval_exp_tag": FULL_TRAINVAL_EXP_TAG,
        "run_stamp": RUN_STAMP,
        "threshold_used_global": float(thr_used),
    })

    pred_path_tag = per_tag_seed_dir / "predictions.csv"
    pred_path_stamp = per_stamp_seed_dir / "predictions.csv"
    pred_df.to_csv(pred_path_tag, index=False)
    pred_df.to_csv(pred_path_stamp, index=False)

    # metrics.json: per-seed snapshot used by the final summary
    seed_metrics = {
        "seed": int(seed),
        "d7_out_root": D7_OUT_ROOT,
        "d2_out_root": D2_OUT_ROOT,
        "d2_manifest_all": D2_MANIFEST_ALL,
        "base_trainval_exp_dir": str(BASE_TRAINVAL_EXP_DIR),
        "trainval_exp_tag": FULL_TRAINVAL_EXP_TAG,
        "best_heads_path": str(best_heads_path),

        "tag_safe": TAG_SAFE,
        "run_stamp": RUN_STAMP,

        "n_test": int(len(test_df)),
        "label_counts_test": test_df["label_num"].value_counts(dropna=False).to_dict(),
        "sex_counts_test_norm": test_df["sex_norm"].value_counts(dropna=False).to_dict(),

        "test_auroc": float(test_auc),

        "threshold_source": TEST_THR_SOURCE,
        "val_optimal_threshold_canonical": {
            "method": str(val_opt.get("method", "unknown")),
            "by_seed": {k: float(v) for k, v in thr_by_seed.items()},
            "mean_sd": {"mean": float(thr_mean), "sd": float(thr_sd)},
        },
        "test_threshold_used": float(thr_used),
        "test_threshold_note": (TEST_THR_NOTE if TEST_THR_NOTE is not None else ""),

        "threshold_metrics_test_at_val_opt": thr_metrics_test,

        "fairness_test_at_val_opt": {
            "definition": "H3: ΔFNR = FNR(F) - FNR(M), where FNR(sex)=FN/(FN+TP) computed on PD-only true labels at test_threshold_used.",
            "threshold_used": float(thr_used),
            "fnr_by_sex_norm": fnr_by_sex,
            "delta_fnr_F_minus_M": float(delta_f_minus_m),
            "note": "If n_PD for a sex is 0, its FNR is NaN and ΔFNR is NaN.",
        },

        "artifacts": {
            "tag_seed_dir": str(per_tag_seed_dir),
            "stamp_seed_dir": str(per_stamp_seed_dir),
            "predictions_csv_tag": str(pred_path_tag),
            "predictions_csv_stamp": str(pred_path_stamp),
        },
    }

    for out_dir in [per_tag_seed_dir, per_stamp_seed_dir]:
        mp = out_dir / "metrics.json"
        with open(mp, "w", encoding="utf-8") as f:
            json.dump(seed_metrics, f, indent=2)

    print(f"[seed={seed}] DONE | test_AUROC={test_auc:.6f} | thr_used={thr_used:.6f}")
    print(f"[seed={seed}] WROTE predictions:")
    print(" ", str(pred_path_tag))
    print(" ", str(pred_path_stamp))

    return {
        "seed": int(seed),
        "thr_used": float(thr_used),
        "test_auc": float(test_auc),
        "thr_metrics": thr_metrics_test,
        "fnr_by_sex": fnr_by_sex,
        "delta_f_minus_m": float(delta_f_minus_m),
        "tag_seed_dir": str(per_tag_seed_dir),
        "stamp_seed_dir": str(per_stamp_seed_dir),
    }

# -------------------------
# Run all seeds and aggregate results
# Inputs: SEEDS
# Outputs: printed per-seed results + mean/CI + aggregated metrics
# -------------------------
results = []
for seed in SEEDS:
    results.append(run_test_once(seed))

aurocs = [r["test_auc"] for r in results]
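n = len(aurocs)
# Two-sided 95% Student-t critical value for df = n - 1.
# The hardcoded constant is valid only for three seeds (df=2); other seed counts use SciPy.
if n == 3:
    t_crit = 4.302652729911275
elif n > 1:
    from scipy.stats import t as student_t
    t_crit = float(student_t.ppf(0.975, df=n - 1))
else:
    t_crit = float("nan")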
mean_auc = float(np.mean(aurocs))
std_auc = float(np.std(aurocs, ddof=1)) if n > 1 else 0.0
half_width = float(t_crit * (std_auc / math.sqrt(n))) if n > 1 else 0.0
ci95 = [float(mean_auc - half_width), float(mean_auc + half_width)]

metric_keys = ["accuracy","precision","sensitivity","specificity","f1_score","mcc","p_value_fisher_two_sided"]
metric_agg = {}
for k in metric_keys:
    vals = [float(r["thr_metrics"].get(k, float("nan"))) for r in results]
    mu, sd = mean_sd(vals)
    metric_agg[k] = {
        "mean": float(mu),
        "sd": float(sd),
        "by_seed": {str(r["seed"]): float(r["thr_metrics"].get(k, float("nan"))) for r in results},
    }

delta_vals = [float(r["delta_f_minus_m"]) for r in results]
delta_mean, delta_sd = mean_sd(delta_vals)

fnr_m_vals = []
fnr_f_vals = []
n_pd_m_vals = []
n_pd_f_vals = []
for r in results:
    d = r["fnr_by_sex"] or {}
    fnr_m_vals.append(float(d.get("M", {}).get("fnr", float("nan"))))
    fnr_f_vals.append(float(d.get("F", {}).get("fnr", float("nan"))))
    n_pd_m_vals.append(float(d.get("M", {}).get("n_pos", float("nan"))))
    n_pd_f_vals.append(float(d.get("F", {}).get("n_pos", float("nan"))))

fnr_m_mean, fnr_m_sd = mean_sd(fnr_m_vals)
fnr_f_mean, fnr_f_sd = mean_sd(fnr_f_vals)

# Console summary (prints only)
print("\nD2 TEST AUROC by seed:")
for r in results:
    print(f"  seed {r['seed']}: {r['test_auc']:.6f}")
print("\nMean D2 TEST AUROC:", f"{mean_auc:.6f}")
print(f"95% CI (t, n={n}): [{ci95[0]:.6f}, {ci95[1]:.6f}]")

print("\nD2 TEST threshold used (shared mean VAL-opt):", f"{TEST_THR_MEAN_VAL_OPT:.6f}")
if TEST_THR_NOTE is not None:
    print("NOTE:", TEST_THR_NOTE)

print("\nThreshold metrics on D2 TEST @ shared mean VAL-opt threshold (mean ± SD across seeds):")
for k in ["accuracy","precision","sensitivity","specificity","f1_score","mcc"]:
    print(f"  {k}: {metric_agg[k]['mean']:.6f} ± {metric_agg[k]['sd']:.6f}")

print("\nFAIRNESS (H3) on D2 TEST @ shared threshold (mean ± SD across seeds):")
print(f"  FNR_M: {fnr_m_mean:.6f} ± {fnr_m_sd:.6f}")
print(f"  FNR_F: {fnr_f_mean:.6f} ± {fnr_f_sd:.6f}")
print(f"  ΔFNR (F-M): {delta_mean:.6f} ± {delta_sd:.6f}")

# -------------------------
# Write summary_test.json and append history_index.jsonl
# Inputs: results + test_df counts + threshold policy
# Outputs: summary_test.json, history_index.jsonl, and pointer json files
# -------------------------
summary = {
    "run_type": "D7_base_on_D2_test",
    "d7_out_root": D7_OUT_ROOT,
    "d2_out_root": D2_OUT_ROOT,
    "d2_manifest_all": D2_MANIFEST_ALL,

    "base_trainval_exp_dir": str(BASE_TRAINVAL_EXP_DIR),
    "trainval_exp_tag": FULL_TRAINVAL_EXP_TAG,

    "tag_safe": TAG_SAFE,
    "run_stamp": RUN_STAMP,
    "seeds": SEEDS,

    "n_test": int(len(test_df)),
    "label_counts_test": test_df["label_num"].value_counts(dropna=False).to_dict(),
    "sex_counts_test_norm": test_df["sex_norm"].value_counts(dropna=False).to_dict(),
    "task_group_counts_test": test_df["task_group"].value_counts(dropna=False).to_dict(),

    "val_optimal_threshold": {
        "source": "trainval summary_trainval.json",
        "method": str(val_opt.get("method", "unknown")),
        "by_seed": {str(k): float(v) for k, v in thr_by_seed.items()},
        "mean_sd": {"mean": float(thr_mean), "sd": float(thr_sd)},
    },

    "test_threshold_used": {
        "policy": "mean_val_optimal_threshold_shared_across_seeds",
        "threshold_used_global": float(TEST_THR_MEAN_VAL_OPT),
        "note": (TEST_THR_NOTE if TEST_THR_NOTE is not None else ""),
        "source": TEST_THR_SOURCE,
    },

    "test_aurocs_by_seed": {str(r["seed"]): float(r["test_auc"]) for r in results},
    "mean_test_auroc": float(mean_auc),
    "t_crit_df2_95": float(t_crit),
    "half_width_95_ci": float(half_width),
    "ci95_test_auroc": ci95,

    "threshold_metrics_test_at_val_opt_mean_sd": metric_agg,

    "fairness_test_at_val_opt": {
        "definition": "H3: ΔFNR = FNR(F) - FNR(M), where FNR(sex)=FN/(FN+TP) computed on PD-only true labels at the shared mean VAL-opt threshold.",
        "delta_fnr_F_minus_M_by_seed": {str(r["seed"]): float(r["delta_f_minus_m"]) for r in results},
        "delta_fnr_F_minus_M_mean_sd": {"mean": float(delta_mean), "sd": float(delta_sd)},
        "fnr_M_mean_sd": {"mean": float(fnr_m_mean), "sd": float(fnr_m_sd), "by_seed": {str(SEEDS[i]): float(fnr_m_vals[i]) for i in range(len(SEEDS))}},
        "fnr_F_mean_sd": {"mean": float(fnr_f_mean), "sd": float(fnr_f_sd), "by_seed": {str(SEEDS[i]): float(fnr_f_vals[i]) for i in range(len(SEEDS))}},
        "denominators_PD_by_seed": {str(SEEDS[i]): {"n_PD_M": float(n_pd_m_vals[i]), "n_PD_F": float(n_pd_f_vals[i])} for i in range(len(SEEDS))},
    },

    "tag_run_dir": str(TAG_ROOT),
    "stamp_run_dir": str(STAMP_ROOT),
    "tag_seed_dirs": {str(r["seed"]): r["tag_seed_dir"] for r in results},
    "stamp_seed_dirs": {str(r["seed"]): r["stamp_seed_dir"] for r in results},

    "builder_aligned": {
        "run_config_json": str(run_config_path),
        "dataset_summary_json": str(dataset_summary_path),
        "preprocess_warnings_csv": str(preprocess_warn_path),
    },
}

summary_path = TEST_ROOT / "summary_test.json"
with open(summary_path, "w", encoding="utf-8") as f:
    json.dump(summary, f, indent=2)

history_path = TEST_ROOT / "history_index.jsonl"
with open(history_path, "a", encoding="utf-8") as f:
    f.write(json.dumps(summary) + "\n")

# Small pointer files for quick navigation (safe to overwrite)
with open(TEST_ROOT / "last_run_pointer.json", "w", encoding="utf-8") as f:
    json.dump({"tag_safe": TAG_SAFE, "run_stamp": RUN_STAMP, "stamp_run_dir": str(STAMP_ROOT)}, f, indent=2)

with open(TAG_ROOT / "tag_run_pointer.json", "w", encoding="utf-8") as f:
    json.dump({"tag_safe": TAG_SAFE, "run_stamp": RUN_STAMP, "stamp_run_dir": str(STAMP_ROOT)}, f, indent=2)

print("\nWROTE summary:", str(summary_path))
print("APPENDED history index:", str(history_path))
print("TAG folder (stable):", str(TAG_ROOT))
print("STAMP folder (this execution):", str(STAMP_ROOT))
print("Open this folder to access artifacts:", str(TEST_ROOT))

# -------------------------
# Stop the Colab runtime (GPU release)
# Inputs: none
# Outputs: runtime unassigned when supported
# -------------------------
print("\nAll done. Unassigning the runtime to stop the L4 instance...")
try:
    from google.colab import runtime  # type: ignore
    print("Calling runtime.unassign() now.")
    runtime.unassign()
except Exception as e:
    print("Could not unassign runtime automatically. You can stop the runtime manually in Colab.")
    print("Reason:", repr(e))

The following cell builds an enhanced D7 training split called **train_enh1** by combining the full original D7 training data with a speaker-balanced 10 percent sample of D2 training speakers. It reads `manifest_all.csv` from both datasets and enforces a fixed column schema so the output stays compatible with the rest of the pipeline. Validation checks run early and fail fast: dataset identifiers must be exactly D7 or D2, labels must be Healthy (0) or Parkinson’s (1), and sex values are standardized to M or F, with D2 values explicitly mapped from “male” and “female”. Any literal `"NaN"` strings are converted to real missing values so they do not cause issues later.
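
The checks described above can be sketched roughly as follows. This is a minimal illustration, not the cell's actual code: `normalize_manifest`, `SEX_MAP`, and `_norm_sex` are hypothetical names, and the rule that an unrecognized sex value raises an error is an assumption.

```python
import numpy as np
import pandas as pd

# Hypothetical mapping; the source states only that D2 uses "male"/"female".
SEX_MAP = {"m": "M", "f": "F", "male": "M", "female": "F"}

def _norm_sex(value):
    """Map one raw sex value to M/F; missing or unknown values become NaN."""
    if pd.isna(value):
        return np.nan
    return SEX_MAP.get(str(value).strip().lower(), np.nan)

def normalize_manifest(df: pd.DataFrame, required_cols: list) -> pd.DataFrame:
    """Enforce the column order and validate dataset, label, and sex values."""
    missing = [c for c in required_cols if c not in df.columns]
    if missing:
        raise RuntimeError(f"Manifest is missing columns: {missing}")
    out = df[required_cols].copy()

    # Literal "NaN"/"nan" strings become real missing values.
    out = out.mask(out.isin(["NaN", "nan"]))

    # Dataset identifiers must be exactly D7 or D2 (a missing value also fails).
    bad = set(out["dataset"].unique()) - {"D7", "D2"}
    if bad:
        raise RuntimeError(f"Unexpected dataset identifiers: {sorted(map(str, bad))}")

    # Labels must be 0 (Healthy) or 1 (Parkinson's).
    labels = pd.to_numeric(out["label_num"], errors="coerce")
    if not labels.isin([0, 1]).all():
        raise RuntimeError("label_num must be 0 (Healthy) or 1 (Parkinson's)")
    out["label_num"] = labels.astype("int64")

    # Standardize sex; a non-missing value that cannot be mapped is an error (assumption).
    raw_sex = out["sex"]
    out["sex"] = raw_sex.map(_norm_sex)
    unmapped = raw_sex.notna() & out["sex"].isna()
    if unmapped.any():
        raise RuntimeError(f"Unrecognized sex values: {sorted(set(map(str, raw_sex[unmapped])))}")
    return out
```

Running every check before any file operation keeps a malformed manifest from producing a partially built split.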

From the D2 training split, the cell selects 10 percent of speakers using speaker-level sampling. The target count is rounded up and, if the result is odd, reduced by one so that it is even, which allows equal numbers of Healthy and Parkinson’s speakers to be drawn. All training clips for each selected speaker are included, and a strict check confirms that each speaker has a single, consistent diagnosis label across all clips. A fixed random seed is used so the same speakers are chosen every time, making the process repeatable.
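
As a rough sketch of this sampling policy, assuming a NumPy `default_rng` generator (the cell's actual random source is not described here) and a hypothetical function name `sample_balanced_speakers`:

```python
import math
import numpy as np
import pandas as pd

def sample_balanced_speakers(train_df: pd.DataFrame, frac: float = 0.10, seed: int = 1337) -> list:
    """Pick a class-balanced, repeatable sample of speakers from a training split."""
    # Each speaker must carry exactly one diagnosis label across all clips.
    n_labels = train_df.groupby("speaker_id")["label_num"].nunique()
    if (n_labels != 1).any():
        raise RuntimeError(f"Speakers with inconsistent labels: {list(n_labels[n_labels != 1].index)}")
    spk_label = train_df.groupby("speaker_id")["label_num"].first()

    # Round up, then step down by one if odd so the count splits evenly by class.
    n_target = math.ceil(frac * len(spk_label))
    if n_target % 2 == 1:
        n_target -= 1
    per_class = n_target // 2

    rng = np.random.default_rng(seed)  # assumption: any seeded generator would serve
    chosen = []
    for label in (0, 1):
        # Sort the pool so the draw does not depend on manifest row order.
        pool = sorted(spk_label[spk_label == label].index)
        if len(pool) < per_class:
            raise RuntimeError(f"Not enough speakers with label {label}")
        chosen.extend(rng.choice(pool, size=per_class, replace=False).tolist())
    return sorted(chosen)
```

Sampling whole speakers rather than clips keeps all recordings of one person on the same side of any split, which avoids speaker leakage.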

Before copying any audio files, the cell checks that every source file referenced by the selected D7 and D2 rows actually exists. It then prepares a copy plan and checks the destination folder under `clips/train_enh1/` to avoid overwriting files by mistake. Files are copied, not moved, so the original datasets are left unchanged. If a destination file already exists and has the same file size, it is skipped. If a file exists but the size does not match, the run stops immediately to prevent silent errors.
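
The copy rule in this paragraph reduces to a small decision per file. The sketch below assumes a hypothetical helper name, `copy_no_overwrite`, and uses file size as the only equality test, as the paragraph describes.

```python
import shutil
from pathlib import Path

def copy_no_overwrite(src: Path, dst: Path) -> str:
    """Copy src to dst without overwriting; return 'copied' or 'skipped'."""
    if not src.exists():
        raise FileNotFoundError(f"Source file is missing: {src}")
    if dst.exists():
        # Same size: treat as already copied and leave it alone.
        if dst.stat().st_size == src.stat().st_size:
            return "skipped"
        # Different size: stop rather than silently replace or keep a bad file.
        raise RuntimeError(f"Destination exists with a different size: {dst}")
    dst.parent.mkdir(parents=True, exist_ok=True)
    shutil.copy2(src, dst)  # copy, never move, so the source dataset is unchanged
    return "copied"
```

A size comparison is cheap but does not detect same-size files with different content; a checksum would be stricter at the cost of reading every file.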

Two filename rules keep every clip traceable and avoid name collisions. Original D7 training clips keep their existing filenames. D2 clips are renamed deterministically so that each name encodes the class, speaker identity, task type, and a stable index. Structured outputs (a run configuration file, a preprocessing warnings log, and a dataset summary file) are written early and atomically through a temporary file, and any existing copy is backed up first so earlier results are not overwritten.
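
To illustrate what a deterministic D2 filename could look like, here is a sketch. The exact pattern used by the cell is not given in this description, so the `D2_<class>_<speaker>_<task>_<index>` layout, the `HC`/`PD` class codes, and the helper names `_token` and `d2_clip_name` are all assumptions.

```python
import re

def _token(value, max_len: int = 32, default: str = "NA") -> str:
    """Reduce a value to a short filename-safe token (hypothetical helper)."""
    text = re.sub(r"[^A-Za-z0-9]+", "-", str(value)).strip("-")
    return text[:max_len] if text else default

def d2_clip_name(label_num, speaker_id, task, index, ext: str = ".wav") -> str:
    """Build a name that encodes class, speaker, task, and a stable index."""
    cls = "PD" if int(label_num) == 1 else "HC"
    return f"D2_{cls}_{_token(speaker_id)}_{_token(task)}_{int(index):04d}{ext}"

print(d2_clip_name(1, "spk 07", "vowel", 3))  # D2_PD_spk-07_vowel_0003.wav
```

Because the name depends only on the clip's metadata and index, rerunning the builder maps each source clip to the same destination, which is what makes the skip-if-same-size rule safe.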

After all files are copied, the cell creates `manifest_train_enh1.csv` using the locked schema and adds a final `source_dataset` column. All rows are marked with `split = "train_enh1"` and keep `dataset = "D7"` so the enhanced data can be used directly by the standard D7 training code. File paths and sample IDs are updated to match the copied files, and the source of each clip, either D7 or D2, is clearly recorded. Final checks confirm that labels and sex values are valid, there are no literal `"NaN"` strings, and every audio file listed in the manifest exists. The cell finishes by writing a success summary with counts by clip, speaker, label, sex, and source, and prints the main output locations for reference.
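
The final checks could be collected into one function that returns a list of problems, as sketched below. `validate_final_manifest` is a hypothetical name, and the sketch assumes `clip_path` values resolve from the current working directory, which may not match how the notebook stores paths.

```python
from pathlib import Path
import pandas as pd

def validate_final_manifest(df: pd.DataFrame, final_cols: list) -> list:
    """Return a list of problems found in the manifest; empty means it passed."""
    if list(df.columns) != list(final_cols):
        return ["columns do not match the locked schema order"]

    problems = []
    # Labels must be 0 (Healthy) or 1 (Parkinson's).
    if not pd.to_numeric(df["label_num"], errors="coerce").isin([0, 1]).all():
        problems.append("label_num contains values other than 0 or 1")
    # Sex must be M or F wherever it is present.
    if not df["sex"].dropna().isin(["M", "F"]).all():
        problems.append("sex contains values other than M or F")
    # No cell may hold the literal string "NaN".
    if df.isin(["NaN"]).any().any():
        problems.append('literal "NaN" strings are present')
    # Every listed audio file must exist on disk.
    missing = [str(p) for p in df["clip_path"] if not Path(str(p)).exists()]
    if missing:
        problems.append(f"{len(missing)} clip files are missing")
    return problems
```

Returning all problems at once, instead of raising on the first, makes a failed build easier to diagnose; the caller can still raise if the list is non-empty.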

In [None]:
# =========================
# D7 TRAIN_ENH1 BUILDER — Add ceil(10% of D2 TRAIN speakers)
# Purpose:
# - Create an augmented D7 training clip folder by copying:
#   1) all existing D7 train clips (keep original filenames)
#   2) plus a speaker-balanced 10% sample of D2 train speakers (copy all their train clips)
# Outputs:
# - New clips folder: clips/train_enh1/
# - New manifest: manifests/manifest_train_enh1.csv (locked schema)
# - Run records: config/.../run_config.json and logs/.../dataset_summary.json + preprocess_warnings.csv
# Notes:
# - Copy only (never move)
# - No overwrite: if destination exists and size matches, skip; if size differs, fail-fast
# =========================

import os, json, re, shutil, math
from pathlib import Path
from datetime import datetime

import numpy as np
import pandas as pd
from tqdm.auto import tqdm

# -------------------------
# Mount Drive (Colab only)
# Inputs: none
# Outputs: Drive mounted at /content/drive when needed
# -------------------------
try:
    from google.colab import drive  # type: ignore
    if not os.path.exists("/content/drive"):
        drive.mount("/content/drive")
except Exception:
    pass

# -------------------------
# Paths and output locations
# Inputs: D7_OUT_ROOT + D2_OUT_ROOT (preprocessed dataset roots)
# Outputs: standard folders created if missing (clips, manifests, config, logs)
# -------------------------
D7_OUT_ROOT = Path("/content/drive/MyDrive/AI_PD_Project/Datasets/D7-Multilingual (D1_D4_D5v2_D6)/preprocessed_v1")
D2_OUT_ROOT = Path("/content/drive/MyDrive/AI_PD_Project/Datasets/D2-Slovak (EWA-DB)/EWA-DB/preprocessed_v1")

D7_MANIFEST_ALL = D7_OUT_ROOT / "manifests" / "manifest_all.csv"
D2_MANIFEST_ALL = D2_OUT_ROOT / "manifests" / "manifest_all.csv"

# train_enh → train_enh1 (folder + manifest)
TRAIN_ENH_DIR = D7_OUT_ROOT / "clips" / "train_enh1"
MANIFEST_TRAIN_ENH = D7_OUT_ROOT / "manifests" / "manifest_train_enh1.csv"

# Standard layout dirs (match merge-builder style)
clips_dir = D7_OUT_ROOT / "clips"
manif_dir = D7_OUT_ROOT / "manifests"
cfg_dir   = D7_OUT_ROOT / "config" / "D7_Enh1_on_D2_Test"
logs_dir  = D7_OUT_ROOT / "logs" / "D7_Enh1_on_D2_Test"

for d in [clips_dir, manif_dir, cfg_dir, logs_dir]:
    d.mkdir(parents=True, exist_ok=True)
TRAIN_ENH_DIR.mkdir(parents=True, exist_ok=True)

# All rows written into manifest_train_enh1.csv use split="train_enh1"
TRAIN_ENH_SPLIT_NAME = "train_enh1"

# -------------------------
# Locked manifest schema
# Inputs: D7/D2 manifest_all.csv
# Outputs: train_enh manifest with the same column order + a source_dataset column
# -------------------------
CANON_COLS = [
    "split",
    "dataset",
    "task",
    "speaker_id",
    "sample_id",
    "label_str",
    "label_num",
    "age",
    "sex",
    "speaker_key_rel",
    "clip_path",
    "duration_sec",
    "source_path",
    "clip_start_sec",
    "clip_end_sec",
    "sr_hz",
    "channels",
    "clip_is_contiguous",
]
FINAL_COLS = CANON_COLS + ["source_dataset"]

# -------------------------
# Sampling policy (10% of D2 train speakers)
# Inputs: D2 train speakers + labels
# Outputs: a fixed, repeatable speaker sample (seeded)
# -------------------------
TEN_PCT         = 0.10
ROUNDING_POLICY = "ceil_then_make_even_by_rounding_down_if_needed"
BALANCE_POLICY  = "speaker-balanced"  # equal speakers HC/PD
SPEAKER_ID_COL  = "speaker_id"
RNG_SEED        = 1337

# Label mapping used everywhere
LABEL_MAP_NOTE  = "label_num mapping: 0=Healthy, 1=Parkinson"

# -------------------------
# Guardrails and file writing helpers
# Inputs: runtime checks + planned outputs
# Outputs: warnings table + atomic writes for csv/json
# -------------------------
warnings_rows = []

def require(cond: bool, msg: str):
    """Fail-fast check for required conditions."""
    if not cond:
        raise RuntimeError(msg)

def add_warn(src: str, level: str, code: str, message: str, **extra):
    """Structured warning log row saved to preprocess_warnings.csv."""
    row = {
        "ts": datetime.utcnow().isoformat(),
        "src": src,
        "level": str(level).upper(),
        "code": code,
        "message": message,
    }
    row.update(extra)
    warnings_rows.append(row)

def count_by_level(rows):
    """Count INFO/WARN/ERROR in the warnings table."""
    out = {"ERROR": 0, "WARN": 0, "INFO": 0}
    for r in rows:
        lvl = str(r.get("level", "INFO")).upper()
        out[lvl] = out.get(lvl, 0) + 1
    return out

def _maybe_backup_existing(dst: Path):
    """
    Backup an existing output file before rewriting it.
    Used for builder outputs (manifest/config/logs).
    """
    if dst.exists():
        ts = datetime.utcnow().strftime("%Y%m%d_%H%M%S")
        bak = dst.with_suffix(dst.suffix + f".bak_{ts}")
        try:
            shutil.copy2(dst, bak)
        except Exception as e:
            # Chain the original error so the root cause stays visible in the traceback.
            raise RuntimeError(f"Failed to create backup for existing file: {dst}") from e

def atomic_write_text(dst: Path, text: str):
    """Write a text file via a temp file to avoid partial results."""
    _maybe_backup_existing(dst)
    tmp = dst.with_suffix(dst.suffix + ".tmp")
    with open(tmp, "w", encoding="utf-8") as f:
        f.write(text)
    os.replace(tmp, dst)

def atomic_write_csv(dst: Path, df: pd.DataFrame):
    """Write a CSV via a temp file; missing values are blank, not the string 'NaN'."""
    _maybe_backup_existing(dst)
    tmp = dst.with_suffix(dst.suffix + ".tmp")
    df.to_csv(tmp, index=False, na_rep="")
    os.replace(tmp, dst)

def safe_token(s, max_len=32, default="NA"):
    """
    Make a short filename-safe token.
    Used only for new D2-derived filenames.
    """
    if pd.isna(s):
        return default
    s = str(s).strip()
    s = re.sub(r"\s+", "_", s)
    s = re.sub(r"[^A-Za-z0-9_]+", "", s)
    return (s[:max_len] if s else default)

def label_str_from_num(v):
    """Convert numeric labels into canonical strings."""
    if pd.isna(v):
        return np.nan
    iv = int(v)
    if iv == 0:
        return "Healthy"
    if iv == 1:
        return "Parkinson"
    return np.nan

def hc_pd_from_label_num(v):
    """Short class tag used only in new filenames."""
    return "PD" if int(v) == 1 else "HC"

def normalize_sex_to_MF_D7(x):
    """D7 sex normalization to M/F with NaN for unknown."""
    if pd.isna(x):
        return np.nan
    s = str(x).strip().lower()
    if s in ["m", "male", "1"]:
        return "M"
    if s in ["f", "female", "0"]:
        return "F"
    if s in ["", "nan", "none", "unknown", "u"]:
        return np.nan
    return np.nan

def normalize_sex_to_MF_D2(x):
    """D2 sex normalization from male/female to M/F."""
    if pd.isna(x):
        return np.nan
    s = str(x).strip().lower()
    if s in ["male", "m"]:
        return "M"
    if s in ["female", "f"]:
        return "F"
    if s in ["", "nan", "none", "unknown", "u"]:
        return np.nan
    return np.nan

def summarize_counts(df: pd.DataFrame, split_name: str):
    """Small summary block written to dataset_summary.json."""
    out = {}
    out["total_rows"] = int(len(df))
    out["split"] = split_name
    out["label_counts_total"] = {str(k): int(v) for k, v in df["label_num"].value_counts(dropna=False).to_dict().items()}
    out["by_source_dataset"] = {sd: int((df["source_dataset"] == sd).sum()) for sd in sorted(df["source_dataset"].dropna().unique())}
    out["sex_counts"] = {str(k): int(v) for k, v in df["sex"].value_counts(dropna=False).to_dict().items()}
    out["n_unique_speakers"] = int(df["speaker_id"].astype(str).nunique()) if "speaker_id" in df.columns else int(0)
    return out

# -------------------------
# Load manifests and validate schema
# Inputs: D7 manifest_all.csv and D2 manifest_all.csv
# Outputs: in-memory tables with cleaned NaNs, validated labels, normalized sex
# -------------------------
print("D7_OUT_ROOT:", str(D7_OUT_ROOT))
print("D2_OUT_ROOT:", str(D2_OUT_ROOT))
print("D7_MANIFEST_ALL:", str(D7_MANIFEST_ALL))
print("D2_MANIFEST_ALL:", str(D2_MANIFEST_ALL))
print("TRAIN_ENH_DIR:", str(TRAIN_ENH_DIR))
print("MANIFEST_TRAIN_ENH:", str(MANIFEST_TRAIN_ENH))
print("cfg_dir:", str(cfg_dir))
print("logs_dir:", str(logs_dir))
print("TRAIN_ENH_SPLIT_NAME:", TRAIN_ENH_SPLIT_NAME)

require(D7_MANIFEST_ALL.exists(), f"Missing D7 manifest_all.csv: {str(D7_MANIFEST_ALL)}")
require(D2_MANIFEST_ALL.exists(), f"Missing D2 manifest_all.csv: {str(D2_MANIFEST_ALL)}")

d7 = pd.read_csv(D7_MANIFEST_ALL)
d2 = pd.read_csv(D2_MANIFEST_ALL)

missing_d7 = [c for c in CANON_COLS if c not in d7.columns]
missing_d2 = [c for c in CANON_COLS if c not in d2.columns]
require(len(missing_d7) == 0, f"D7 manifest missing required columns (locked schema): {missing_d7}")
require(len(missing_d2) == 0, f"D2 manifest missing required columns (locked schema): {missing_d2}")

# Add source_dataset if missing so the final merged manifest can always track provenance.
if "source_dataset" not in d7.columns:
    d7["source_dataset"] = "D7"
if "source_dataset" not in d2.columns:
    d2["source_dataset"] = "D2"

# Replace the literal text "NaN" with true missing values for consistent typing.
for df in [d7, d2]:
    for col in ["sex", "age", "duration_sec", "clip_start_sec", "clip_end_sec", "speaker_key_rel", "speaker_id", "task", "sample_id"]:
        if col in df.columns:
            df[col] = df[col].replace("NaN", np.nan)

# Confirm each manifest really corresponds to the intended dataset ID.
def infer_dataset_id(df: pd.DataFrame, fallback: str) -> str:
    if "dataset" in df.columns and df["dataset"].notna().any():
        return str(df["dataset"].astype(str).value_counts(dropna=True).idxmax())
    return fallback

d7_dataset_id = infer_dataset_id(d7, "DX")
d2_dataset_id = infer_dataset_id(d2, "DX")

require(d7_dataset_id == "D7", f"Expected D7 manifest dataset=='D7' but inferred {d7_dataset_id!r}. Check D7_OUT_ROOT/manifest.")
require(d2_dataset_id == "D2", f"Expected D2 manifest dataset=='D2' but inferred {d2_dataset_id!r}. Check D2_OUT_ROOT/manifest.")

# Keep only the intended dataset rows (extra safety if manifests were ever merged).
d7 = d7[d7["dataset"].astype(str) == "D7"].copy()
d2 = d2[d2["dataset"].astype(str) == "D2"].copy()

# Enforce labels are strictly binary and set canonical label_str.
for name, df in [("D7", d7), ("D2", d2)]:
    bad = sorted(set(df["label_num"].dropna().unique()) - {0, 1})
    require(len(bad) == 0, f"{name} label_num contains values outside {{0,1}}: {bad}")
    df["label_str"] = df["label_num"].map(label_str_from_num)
    bad_str = sorted(set(df["label_str"].dropna().unique()) - {"Healthy", "Parkinson"})
    require(len(bad_str) == 0, f"{name} label_str contains unexpected values: {bad_str}")

# Normalize sex to M/F for consistent downstream metrics.
d7["sex"] = d7["sex"].map(normalize_sex_to_MF_D7)
d2["sex"] = d2["sex"].map(normalize_sex_to_MF_D2)

bad_sex_d7 = sorted(set(d7["sex"].dropna().unique()) - {"M", "F"})
bad_sex_d2 = sorted(set(d2["sex"].dropna().unique()) - {"M", "F"})
require(len(bad_sex_d7) == 0, f"D7 sex contains unexpected values after normalization: {bad_sex_d7}")
require(len(bad_sex_d2) == 0, f"D2 sex contains unexpected values after normalization: {bad_sex_d2}")

# -------------------------
# D7 train rows to include (baseline portion)
# Inputs: D7 manifest rows where split == "train"
# Outputs: d7_train table used for copying + manifest building
# -------------------------
d7_train = d7[d7["split"].astype(str) == "train"].copy()
require(len(d7_train) > 0, "D7 train split has 0 rows. Expected split=='train' to exist in D7 manifest.")

print("\nD7 train rows:", int(len(d7_train)))
print("D7 train label counts:", d7_train["label_num"].value_counts(dropna=False).to_dict())

# -------------------------
# D2 speaker-balanced sampling (10% of train speakers)
# Inputs: D2 manifest rows where split == "train"
# Outputs: selected_speakers list + d2_sel (all train clips for those speakers)
# -------------------------
d2_train = d2[d2["split"].astype(str) == "train"].copy()
require(len(d2_train) > 0, "D2 train split has 0 rows. Expected split=='train' to exist in D2 manifest.")

print("\nD2 train rows:", int(len(d2_train)))
print("D2 train label counts:", d2_train["label_num"].value_counts(dropna=False).to_dict())

require(SPEAKER_ID_COL in d2_train.columns, f"D2 manifest missing speaker id column: {SPEAKER_ID_COL}")

# Check that each speaker belongs to exactly one class (required for speaker-level sampling).
speaker_labels = (
    d2_train
    .groupby(SPEAKER_ID_COL)["label_num"]
    .apply(lambda s: sorted(set(s.dropna().astype(int).tolist())))
)
mixed = speaker_labels[speaker_labels.apply(lambda x: len(x) != 1)]
if len(mixed) > 0:
    sample_mixed = mixed.head(10).to_dict()
    raise RuntimeError(
        "D2 train has speakers with mixed label_num values across clips. "
        "Speaker-level sampling requires each speaker to belong to exactly one class.\n"
        f"Examples (speaker_id -> labels): {sample_mixed}"
    )

speaker_to_label = speaker_labels.apply(lambda x: int(x[0])).to_dict()
all_speakers = sorted(list(speaker_to_label.keys()))
total_speakers = len(all_speakers)

hc_speakers = sorted([spk for spk, y in speaker_to_label.items() if y == 0])
pd_speakers = sorted([spk for spk, y in speaker_to_label.items() if y == 1])

require(len(hc_speakers) > 0 and len(pd_speakers) > 0, "D2 train does not contain both HC and PD speakers; cannot do balanced sampling.")

# Compute target speaker count: ceil(10%), then force even so HC/PD can be equal.
target_total = int(math.ceil(TEN_PCT * total_speakers))
if target_total % 2 != 0:
    target_total -= 1  # make even

# Ensure at least one speaker per class.
if target_total < 2:
    target_total = 2

target_per_class = target_total // 2

require(target_per_class <= len(hc_speakers), f"Not enough HC speakers in D2 train for target_per_class={target_per_class}. HC speakers={len(hc_speakers)}")
require(target_per_class <= len(pd_speakers), f"Not enough PD speakers in D2 train for target_per_class={target_per_class}. PD speakers={len(pd_speakers)}")

# Sample speakers deterministically from the fixed seed.
rng = np.random.default_rng(RNG_SEED)
sel_hc = sorted(rng.choice(hc_speakers, size=target_per_class, replace=False).tolist())
sel_pd = sorted(rng.choice(pd_speakers, size=target_per_class, replace=False).tolist())
selected_speakers = sorted(sel_hc + sel_pd)

# Keep all D2 train clips for the selected speakers.
d2_sel = d2_train[d2_train[SPEAKER_ID_COL].astype(str).isin([str(x) for x in selected_speakers])].copy()

print("\nD2 speaker sampling:")
print("  speaker_id_col used:", SPEAKER_ID_COL)
print("  total D2 train speakers:", total_speakers)
print(f"  target speakers total (10% ceil, balanced, odd->down): {target_total} => {target_per_class} HC + {target_per_class} PD")
print("  selected D2 rows (all clips for selected speakers):", int(len(d2_sel)))
print("  selected label counts:", d2_sel["label_num"].value_counts(dropna=False).to_dict())

# -------------------------
# Verify every referenced clip exists (before any copying)
# Inputs: d7_train clip_path and d2_sel clip_path
# Outputs: fail-fast error if any source audio is missing
# -------------------------
def fail_fast_missing_paths(df: pd.DataFrame, name: str):
    missing_paths = []
    for p in tqdm(df["clip_path"].astype(str).tolist(), desc=f"Check {name} clip_path exists", dynamic_ncols=True):
        if not os.path.exists(p):
            missing_paths.append(p)
            if len(missing_paths) >= 25:
                break
    if missing_paths:
        raise FileNotFoundError(f"{name}: missing clip_path(s). Examples (up to 25): {missing_paths}")

fail_fast_missing_paths(d7_train, "D7 TRAIN")
fail_fast_missing_paths(d2_sel, "D2 TRAIN (selected speakers)")

# -------------------------
# Plan destination filenames (D7 keeps names, D2 gets collision-proof names)
# Inputs: d7_train + d2_sel
# Outputs: copy_plan list (source path → destination path)
# -------------------------
# Rules:
# - D7 train clips: copy into train_enh1 using the same basename
# - D2 selected clips: rename deterministically to avoid collisions in the shared folder
copy_plan = []  # dicts: src_path, dst_path, origin, source_dataset, row_index

# D7 → train_enh1 (same filenames)
for i, row in d7_train.reset_index(drop=True).iterrows():
    src_path = Path(str(row["clip_path"]))
    dst_path = TRAIN_ENH_DIR / src_path.name
    copy_plan.append({
        "src_path": str(src_path),
        "dst_path": str(dst_path),
        "origin": "D7_train_existing",
        "source_dataset": str(row.get("source_dataset", "D7")),
        "row_index": int(i),
    })

# D2 selected → train_enh1 (deterministic order and names)
d2_sel_reset = d2_sel.reset_index(drop=True).copy()

# Sort so the numbering stays stable across reruns.
d2_sel_reset["_speaker_tok"] = d2_sel_reset["speaker_id"].map(lambda x: safe_token(x, 32, "NA"))
d2_sel_reset["_task_tok"] = d2_sel_reset["task"].map(lambda x: safe_token(x, 12, "0"))
d2_sel_reset["_clip_path_str"] = d2_sel_reset["clip_path"].astype(str)
d2_sel_reset = d2_sel_reset.sort_values(by=["_speaker_tok", "_task_tok", "_clip_path_str"]).reset_index(drop=True)

for j, row in d2_sel_reset.iterrows():
    src_path = Path(str(row["clip_path"]))
    hc_pd = hc_pd_from_label_num(row["label_num"])
    spk_tok = safe_token(row["speaker_id"], 32, "NA")
    task_tok = safe_token(row["task"], 12, "0")

    out_name = f"D7_D2add_{hc_pd}_{spk_tok}_{task_tok}_{j+1:06d}.wav"
    dst_path = TRAIN_ENH_DIR / out_name

    copy_plan.append({
        "src_path": str(src_path),
        "dst_path": str(dst_path),
        "origin": "D2_train_selected",
        "source_dataset": "D2",
        "row_index": int(j),
    })

# Guard: two sources that map to the same destination would be silently skipped
# by the no-overwrite copy stage, so require unique destination paths up front.
_dst_counts = pd.Series([it["dst_path"] for it in copy_plan]).value_counts()
_dup_dst = _dst_counts[_dst_counts > 1]
require(len(_dup_dst) == 0, f"copy_plan has colliding destination paths. Examples (up to 10): {_dup_dst.head(10).index.tolist()}")

# -------------------------
# Preflight checks (no overwrite policy)
# Inputs: copy_plan + filesystem
# Outputs: preflight stats + warnings; stops early on size mismatches
# -------------------------
n_dest_exists_ok = 0
n_dest_exists_mismatch = 0
n_will_copy = 0

print("\nPreflight: destination existence and size checks (no overwrite)...")
for item in tqdm(copy_plan, desc="Preflight (dest checks)", dynamic_ncols=True):
    sp = Path(item["src_path"])
    dp = Path(item["dst_path"])
    require(sp.exists(), f"Source clip missing at preflight (should have been caught earlier): {str(sp)}")

    if dp.exists():
        try:
            if dp.stat().st_size == sp.stat().st_size:
                n_dest_exists_ok += 1
            else:
                n_dest_exists_mismatch += 1
                add_warn(
                    "D7_TRAIN_ENH", "ERROR", "DEST_EXISTS_SIZE_MISMATCH",
                    "Destination exists but file size differs from source",
                    src_path=str(sp), dest_path=str(dp),
                    src_size=int(sp.stat().st_size), dest_size=int(dp.stat().st_size),
                )
        except Exception as e:
            n_dest_exists_mismatch += 1
            add_warn(
                "D7_TRAIN_ENH", "ERROR", "DEST_EXISTS_STAT_ERROR",
                "Failed to stat source/destination during preflight",
                src_path=str(sp), dest_path=str(dp), error=repr(e),
            )
    else:
        n_will_copy += 1

preflight_stats = {
    "total_planned_files": int(len(copy_plan)),
    "n_dest_exists_ok": int(n_dest_exists_ok),
    "n_dest_exists_mismatch": int(n_dest_exists_mismatch),
    "n_will_copy": int(n_will_copy),
    "warnings_by_level": count_by_level(warnings_rows),
}

print("\nPreflight summary:")
print("  Planned files:", int(len(copy_plan)))
print("  Destination exists (size OK):", n_dest_exists_ok)
print("  Destination exists (size mismatch/stat error):", n_dest_exists_mismatch)
print("  Will copy:", n_will_copy)
print("  Warnings by level:", preflight_stats["warnings_by_level"])

# -------------------------
# Write run_config + early summary (before copying)
# Inputs: selection details + preflight stats
# Outputs: config/run_config.json, logs/dataset_summary.json, logs/preprocess_warnings.csv
# -------------------------
run_config = {
    "dataset": "D7",
    "mode": "train_enh_builder",
    "created_utc": datetime.utcnow().isoformat(),

    "d7_out_root": str(D7_OUT_ROOT),
    "d2_out_root": str(D2_OUT_ROOT),
    "d7_manifest_all": str(D7_MANIFEST_ALL),
    "d2_manifest_all": str(D2_MANIFEST_ALL),

    "train_enh_dir": str(TRAIN_ENH_DIR),
    "manifest_train_enh": str(MANIFEST_TRAIN_ENH),
    "train_enh_split_name": TRAIN_ENH_SPLIT_NAME,

    "policy": {
        "definition": "Add ceil(10%) of D2 TRAIN by speaker into D7 training clips folder",
        "pct_speakers": float(TEN_PCT),
        "rounding_policy": ROUNDING_POLICY,
        "balance_policy": BALANCE_POLICY,
        "speaker_id_col_used": SPEAKER_ID_COL,
        "rng_seed": int(RNG_SEED),
        "label_note": LABEL_MAP_NOTE,
        "file_operation": "copy",
        "no_overwrite_rule": "skip if dest exists with matching size; error if size differs",
    },

    "inputs": {
        "d7_train_rows": int(len(d7_train)),
        "d7_train_label_counts": d7_train["label_num"].value_counts(dropna=False).to_dict(),
        "d2_train_rows": int(len(d2_train)),
        "d2_train_label_counts": d2_train["label_num"].value_counts(dropna=False).to_dict(),
        "d2_train_total_speakers": int(total_speakers),
        "d2_train_speakers_hc": int(len(hc_speakers)),
        "d2_train_speakers_pd": int(len(pd_speakers)),
    },

    "selection": {
        "target_total_speakers": int(target_total),
        "target_per_class": int(target_per_class),
        "selected_speakers_hc": sel_hc,
        "selected_speakers_pd": sel_pd,
        "selected_speakers_all": selected_speakers,
        "selected_d2_rows": int(len(d2_sel)),
        "selected_d2_label_counts": d2_sel["label_num"].value_counts(dropna=False).to_dict(),
    },

    "preflight": preflight_stats,
}

early_summary = {
    "dataset": "D7",
    "created_utc": datetime.utcnow().isoformat(),
    "status": "PRECHECK_COMPLETE",
    "d7_out_root": str(D7_OUT_ROOT),
    "train_enh_dir": str(TRAIN_ENH_DIR),
    "train_enh_split_name": TRAIN_ENH_SPLIT_NAME,
    "preflight": preflight_stats,
    "selection": run_config["selection"],
}

atomic_write_csv(logs_dir / "preprocess_warnings.csv", pd.DataFrame(warnings_rows))
atomic_write_text(cfg_dir / "run_config.json", json.dumps(run_config, indent=2))
atomic_write_text(logs_dir / "dataset_summary.json", json.dumps(early_summary, indent=2))

# Stop here if preflight found any ERROR-level warnings.
fatal_pre = [w for w in warnings_rows if str(w.get("level", "")).upper() == "ERROR"]
if fatal_pre:
    print("\n❌ Preflight found ERROR conditions. No copying performed.")
    print("   Logs written:")
    print("   -", str(logs_dir / "preprocess_warnings.csv"))
    print("   -", str(logs_dir / "dataset_summary.json"))
    print("   -", str(cfg_dir / "run_config.json"))
    raise RuntimeError(f"Preflight failed with {len(fatal_pre)} ERROR(s). See logs/preprocess_warnings.csv.")

# -------------------------
# Copy clips into clips/train_enh1 (skip existing)
# Inputs: copy_plan
# Outputs: copied audio files under clips/train_enh1/ + updated warnings log
# -------------------------
print("\nCopy stage: copying into train_enh1 (no overwrite)...")

copied = 0
skipped_exists = 0
copy_errors = 0

for item in tqdm(copy_plan, desc="Copying clips", dynamic_ncols=True):
    sp = Path(item["src_path"])
    dp = Path(item["dst_path"])

    if not sp.exists():
        copy_errors += 1
        add_warn("D7_TRAIN_ENH", "ERROR", "SOURCE_CLIP_DISAPPEARED", "Source clip missing during copy stage", clip_path=str(sp))
        continue

    if dp.exists():
        skipped_exists += 1
        continue

    try:
        shutil.copy2(sp, dp)
        copied += 1
    except Exception as e:
        copy_errors += 1
        add_warn("D7_TRAIN_ENH", "ERROR", "COPY_FAILED", "Failed to copy file", src_path=str(sp), dest_path=str(dp), error=repr(e))

copy_stats = {
    "copied": int(copied),
    "skipped_exists": int(skipped_exists),
    "copy_errors": int(copy_errors),
    "total_planned_files": int(len(copy_plan)),
    "warnings_by_level": count_by_level(warnings_rows),
}

print("\nCopy summary:")
print("  Copied:", copied)
print("  Skipped (already exists):", skipped_exists)
print("  Copy errors:", copy_errors)
print("  Warnings by level:", copy_stats["warnings_by_level"])

# Persist warnings after copying so failures still leave a readable trail.
atomic_write_csv(logs_dir / "preprocess_warnings.csv", pd.DataFrame(warnings_rows))

# Stop if any ERROR occurred during copying.
fatal_copy = [w for w in warnings_rows if str(w.get("level", "")).upper() == "ERROR"]
if fatal_copy:
    fail_summary = dict(early_summary)
    fail_summary["status"] = "FAILED_DURING_COPY"
    fail_summary["copy_stats"] = copy_stats
    atomic_write_text(logs_dir / "dataset_summary.json", json.dumps(fail_summary, indent=2))
    run_config["copy_stats"] = copy_stats
    atomic_write_text(cfg_dir / "run_config.json", json.dumps(run_config, indent=2))

    print("\n❌ Copy stage encountered ERROR conditions.")
    print("   Logs written:")
    print("   -", str(logs_dir / "preprocess_warnings.csv"))
    print("   -", str(logs_dir / "dataset_summary.json"))
    print("   -", str(cfg_dir / "run_config.json"))
    raise RuntimeError(f"Copy failed with {len(fatal_copy)} ERROR(s). See logs/preprocess_warnings.csv.")

# -------------------------
# Build manifest_train_enh1.csv (locked schema)
# Inputs: D7 train rows + selected D2 train rows
# Outputs: merged train_enh manifest pointing to clips/train_enh1/ paths
# -------------------------
# All rows are marked split="train_enh1".
# dataset is kept as "D7" so D7 training code treats this as a D7 training split.

# D7 existing train rows → train_enh manifest (paths rewritten to train_enh1 folder)
d7_enh = d7_train[CANON_COLS].copy()
d7_enh["split"] = TRAIN_ENH_SPLIT_NAME
d7_enh["dataset"] = "D7"
d7_enh["clip_path"] = d7_train["clip_path"].astype(str).map(lambda p: str(TRAIN_ENH_DIR / Path(p).name))
d7_enh["sample_id"] = d7_enh["clip_path"].map(lambda p: Path(p).stem)
d7_enh["source_dataset"] = d7_train.get("source_dataset", pd.Series(["D7"] * len(d7_train))).astype(str).tolist()

# D2 selected rows → train_enh manifest (paths rewritten to the renamed files in train_enh1)
d2_enh = d2_sel_reset[CANON_COLS].copy()
d2_enh["split"] = TRAIN_ENH_SPLIT_NAME
d2_enh["dataset"] = "D7"

# Reuse the destination paths planned in copy_plan (same row order as d2_sel_reset)
# so the manifest cannot drift from the filenames that were actually copied.
d2_dst_paths = [Path(it["dst_path"]) for it in copy_plan if it["origin"] == "D2_train_selected"]
require(len(d2_dst_paths) == len(d2_sel_reset), "copy_plan D2 entries do not match the selected D2 rows.")

d2_enh["clip_path"] = [str(p) for p in d2_dst_paths]
d2_enh["sample_id"] = [p.stem for p in d2_dst_paths]
d2_enh["source_dataset"] = "D2"

# Combine into one training split and enforce exact column order.
train_enh = pd.concat([d7_enh, d2_enh], axis=0, ignore_index=True)
train_enh = train_enh[FINAL_COLS].copy()

# Validate no literal "NaN" strings slipped into the final manifest.
for c in FINAL_COLS:
    if train_enh[c].dtype == object and (train_enh[c] == "NaN").any():
        raise RuntimeError(f"Found literal string 'NaN' in column '{c}'. This violates NaN policy.")

# Sanity checks for labels and sex.
require(set(train_enh["label_num"].dropna().unique()).issubset({0, 1}), "train_enh manifest label_num contains values outside {0,1}.")
require(set(train_enh["label_str"].dropna().unique()).issubset({"Healthy", "Parkinson"}), "train_enh manifest label_str contains values outside {Healthy, Parkinson}.")
require(((train_enh["label_num"] == 0) == (train_enh["label_str"] == "Healthy")).all(), "train_enh manifest mismatch label_num==0 vs label_str.")
require(((train_enh["label_num"] == 1) == (train_enh["label_str"] == "Parkinson")).all(), "train_enh manifest mismatch label_num==1 vs label_str.")
bad_sex = sorted(set(train_enh["sex"].dropna().unique()) - {"M", "F"})
require(len(bad_sex) == 0, f"train_enh manifest sex contains unexpected values: {bad_sex}")

# Verify the manifest points only to files that exist under clips/train_enh1.
missing_enh = []
for p in tqdm(train_enh["clip_path"].astype(str).tolist(), desc="Check TRAIN_ENH clip_path exists", dynamic_ncols=True):
    if not os.path.exists(p):
        missing_enh.append(p)
        if len(missing_enh) >= 25:
            break
require(len(missing_enh) == 0, f"train_enh manifest points to missing files. Examples (up to 25): {missing_enh}")

# -------------------------
# Write final artifacts and final summary
# Inputs: train_enh dataframe + copy_stats + selection info
# Outputs: manifests/manifest_train_enh1.csv and logs/config JSON summaries
# -------------------------
atomic_write_csv(MANIFEST_TRAIN_ENH, train_enh)

final_summary = {
    "dataset": "D7",
    "created_utc": datetime.utcnow().isoformat(),
    "status": "SUCCESS",
    "d7_out_root": str(D7_OUT_ROOT),
    "train_enh_dir": str(TRAIN_ENH_DIR),
    "train_enh_split_name": TRAIN_ENH_SPLIT_NAME,
    "manifest_train_enh": str(MANIFEST_TRAIN_ENH),
    "selection": run_config["selection"],
    "copy_stats": copy_stats,
    "counts_train_enh": summarize_counts(train_enh, split_name=TRAIN_ENH_SPLIT_NAME),
}

atomic_write_text(logs_dir / "dataset_summary.json", json.dumps(final_summary, indent=2))
atomic_write_csv(logs_dir / "preprocess_warnings.csv", pd.DataFrame(warnings_rows))

run_config["copy_stats"] = copy_stats
run_config["outputs"] = {
    "train_enh_dir": str(TRAIN_ENH_DIR),
    "manifest_train_enh": str(MANIFEST_TRAIN_ENH),
    "dataset_summary_json": str(logs_dir / "dataset_summary.json"),
    "preprocess_warnings_csv": str(logs_dir / "preprocess_warnings.csv"),
    "run_config_json": str(cfg_dir / "run_config.json"),
}
atomic_write_text(cfg_dir / "run_config.json", json.dumps(run_config, indent=2))

print("\n✅ D7 train_enh build complete.")
print("- Train_enh folder:", str(TRAIN_ENH_DIR))
print("- Manifest:", str(MANIFEST_TRAIN_ENH))
print("- Summary:", str(logs_dir / "dataset_summary.json"))
print("- Warnings:", str(logs_dir / "preprocess_warnings.csv"))
print("- Config:", str(cfg_dir / "run_config.json"))

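The speaker-count rule used by the builder cell above (ceil of 10%, rounded down to an even number, with a floor of two speakers) can be checked in isolation. The sketch below only restates that arithmetic. `balanced_target` is a hypothetical helper name, not a function from the notebook, and the speaker totals are made-up inputs.

```python
import math

def balanced_target(total_speakers, pct=0.10):
    """Return (target_total, target_per_class) under the ceil-then-even rule."""
    target_total = math.ceil(pct * total_speakers)
    if target_total % 2 != 0:
        target_total -= 1                 # odd -> round down so HC and PD can be equal
    target_total = max(target_total, 2)   # keep at least one speaker per class
    return target_total, target_total // 2

# Made-up speaker totals, for illustration only.
for n in (5, 47, 93):
    print(n, balanced_target(n))
```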
The following cell trains and validates the **D7 enhanced model (train_enh1)** using three fixed random seeds, while keeping the Wav2Vec2 backbone frozen and training only the small head layers. Training data comes from `manifests/manifest_train_enh1.csv` (rows where `split == "train_enh1"`), and validation data comes from the standard D7 `manifests/manifest_all.csv` (rows where `split == "val"`). The code does not reshuffle or recreate any splits and relies entirely on the existing manifest files. Before training starts, it checks that both manifests exist, confirms required columns are present, verifies that the validation data belongs to dataset D7, and stops immediately if any audio file listed in the train or validation splits is missing.

To handle clips of different lengths, the cell derives a `task_group` field with a simple rule: clips with `task == "vowl"` are treated as vowel clips, and all remaining clips are grouped as other. A custom dataset loader reads each audio file, converts it to mono if needed, checks that the sample rate is 16 kHz, and builds an attention mask. For vowel clips, the attention mask hides trailing near-silent samples so the model does not learn from padded silence. For other clips, the attention mask remains fully active. A collate function pads each batch to the longest clip in that batch and pads the waveform and the attention mask consistently.
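The masking and padding steps described above can be sketched with NumPy. This is a minimal illustration, not the notebook's loader: the amplitude threshold `SILENCE_EPS`, the function names, and the choice to keep an all-silent clip fully unmasked are assumptions.

```python
import numpy as np

SILENCE_EPS = 1e-3  # assumed amplitude threshold; the notebook's value may differ

def trailing_silence_mask(wave, is_vowel, eps=SILENCE_EPS):
    """1 = keep sample, 0 = trailing near-silent sample (vowel clips only)."""
    mask = np.ones(len(wave), dtype=np.int64)
    if is_vowel:
        loud = np.flatnonzero(np.abs(wave) > eps)
        # Assumption: an all-silent clip keeps a full mask to avoid an empty one.
        last = loud[-1] + 1 if loud.size else len(wave)
        mask[last:] = 0
    return mask

def collate(items):
    """Right-pad waveforms and masks to the longest clip in the batch."""
    max_len = max(len(w) for w, _ in items)
    waves = np.zeros((len(items), max_len), dtype=np.float32)
    masks = np.zeros((len(items), max_len), dtype=np.int64)
    for i, (w, m) in enumerate(items):
        waves[i, : len(w)] = w
        masks[i, : len(m)] = m   # padded positions stay 0 in the mask
    return waves, masks

# Toy clips: the vowel clip ends in two silent samples, the other clip is shorter.
vowel = np.array([0.2, -0.3, 0.1, 0.0, 0.0], dtype=np.float32)
other = np.array([0.2, 0.0, 0.0], dtype=np.float32)
batch = [
    (vowel, trailing_silence_mask(vowel, is_vowel=True)),
    (other, trailing_silence_mask(other, is_vowel=False)),
]
waves, masks = collate(batch)
print(masks)
```

In this toy batch both rows end with zeros in the mask, but for different reasons: the vowel row masks its silent tail, while the other row is masked only where the collate step added padding.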

The model is a two-head classifier built on top of a frozen `facebook/wav2vec2-base` backbone. It includes two small trainable pre-head blocks (LayerNorm and Dropout), one for vowel clips and one for other clips, followed by two linear classification heads. During the forward pass, the backbone produces features, a masked mean pooling step is applied using the attention mask, and each sample is routed to the vowel head or the other head based on its `task_group`. Only the pre-head blocks and the linear heads are trained. The backbone weights are never updated.
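Masked mean pooling and per-sample head routing can be illustrated without the backbone. The sketch below uses NumPy arrays in place of Wav2Vec2 features and plain weight matrices in place of the trained heads. The LayerNorm and Dropout blocks are omitted, the mask is assumed to be already at frame resolution, and the two-logit head size is an assumption.

```python
import numpy as np

def masked_mean(features, frame_mask):
    """Average features over unmasked frames. features: (B, T, D), frame_mask: (B, T)."""
    m = frame_mask[:, :, None].astype(features.dtype)
    denom = np.clip(m.sum(axis=1), 1.0, None)   # guard against an all-zero mask
    return (features * m).sum(axis=1) / denom

def route_heads(pooled, is_vowel, w_vowel, w_other):
    """Pick the vowel-head or other-head logits for each sample."""
    logits_vowel = pooled @ w_vowel
    logits_other = pooled @ w_other
    return np.where(is_vowel[:, None], logits_vowel, logits_other)

# Toy features: 2 samples, 3 frames, 2 feature dimensions.
features = np.array([
    [[1.0, 2.0], [3.0, 4.0], [9.0, 9.0]],   # third frame is masked out
    [[2.0, 0.0], [4.0, 2.0], [6.0, 4.0]],
])
frame_mask = np.array([[1, 1, 0], [1, 1, 1]])
pooled = masked_mean(features, frame_mask)

# Stand-in head weights (identity and negated identity) so the routing is visible.
is_vowel = np.array([True, False])
logits = route_heads(pooled, is_vowel, np.eye(2), -np.eye(2))
print(pooled)
print(logits)
```

Computing both heads and selecting with `np.where` keeps the example short; a real model could instead index the batch by group and run each head only on its own samples.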

A key feature of this cell is initialization from a prior baseline model. Before training the enhanced model, the code searches `trainval_runs/exp_*/` and automatically selects the most recent experiment folder that contains a valid `summary_trainval.json` and all three `best_heads.pt` files, one for each seed. For each seed run, it loads that seed’s baseline `best_heads.pt` into the enhanced model’s heads and then continues training on the `train_enh1` data from that starting point. This helps keep results comparable and reduces sensitivity to random initialization.
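One way to select the baseline experiment is sketched below with the standard library only. The per-seed folder layout (`seed_<seed>/best_heads.pt`) and the use of folder modification time as the meaning of "most recent" are assumptions, and the helper names are hypothetical; the notebook may order or lay out runs differently.

```python
import json
import os
import tempfile
from pathlib import Path

SEEDS = (1337, 2024, 7777)

def is_complete(exp_dir, seeds=SEEDS):
    """True if the folder has a parseable summary and one best_heads.pt per seed."""
    try:
        json.loads((exp_dir / "summary_trainval.json").read_text(encoding="utf-8"))
    except (OSError, ValueError):
        return False
    # Assumed layout: exp_*/seed_<seed>/best_heads.pt
    return all((exp_dir / f"seed_{s}" / "best_heads.pt").is_file() for s in seeds)

def latest_complete_experiment(runs_root):
    """Newest complete exp_* folder by modification time, or None if there is none."""
    candidates = [p for p in Path(runs_root).glob("exp_*") if p.is_dir() and is_complete(p)]
    return max(candidates, key=lambda p: p.stat().st_mtime, default=None)

# Demo on a throwaway directory tree: the newest folder is missing its head files.
with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    for name, complete, mtime in [("exp_old", True, 1000), ("exp_new", True, 2000), ("exp_broken", False, 3000)]:
        exp = root / name
        for s in SEEDS:
            (exp / f"seed_{s}").mkdir(parents=True)
            if complete:
                (exp / f"seed_{s}" / "best_heads.pt").write_bytes(b"")
        (exp / "summary_trainval.json").write_text("{}", encoding="utf-8")
        os.utime(exp, (mtime, mtime))   # set the folder timestamp last
    chosen_name = latest_complete_experiment(root).name

print(chosen_name)
```

The incomplete folder is skipped even though it is the newest, which matches the requirement that all three seed checkpoints be present before a run is used for initialization.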

For each seed (1337, 2024, 7777), the cell runs a full training and validation loop with progress bars and early stopping, using a patience of two epochs on validation AUROC. Gradient accumulation brings the effective batch size to the intended value of 64 from a per-device batch size of 16, which corresponds to four accumulation steps per optimizer update. After each epoch, the model is evaluated on the validation set and AUROC is computed. If AUROC improves, the current heads are saved as the best snapshot and the validation probabilities are stored. When training for a seed finishes, the best heads are written to `best_heads.pt`, a validation ROC plot is saved, and confusion matrix plots are generated at both threshold 0.5 and the chosen optimal threshold.
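The control flow of early stopping and the accumulation arithmetic can be sketched without any model. The AUROC sequence below is made up, the helper names are hypothetical, and treating a tie as "no improvement" is an assumption about the notebook's comparison.

```python
def accumulation_steps(effective_batch, per_device_batch):
    """Number of micro-batches to accumulate before one optimizer step."""
    if effective_batch % per_device_batch != 0:
        raise ValueError("effective batch must be a multiple of the per-device batch")
    return effective_batch // per_device_batch

def run_with_early_stopping(val_aurocs, patience=2):
    """Walk per-epoch validation AUROCs; return (best_epoch, best_auroc, epochs_run)."""
    best_auc, best_epoch, bad_epochs = float("-inf"), None, 0
    for epoch, auc in enumerate(val_aurocs, start=1):
        if auc > best_auc:
            # This is where the best heads and validation probabilities would be snapshotted.
            best_auc, best_epoch, bad_epochs = auc, epoch, 0
        else:
            bad_epochs += 1
            if bad_epochs >= patience:
                return best_epoch, best_auc, epoch
    return best_epoch, best_auc, len(val_aurocs)

print(accumulation_steps(64, 16))
# Made-up validation AUROCs: training stops before the last value is ever reached.
print(run_with_early_stopping([0.81, 0.84, 0.83, 0.84, 0.86], patience=2))
```

In the made-up sequence the fifth epoch would have scored higher, which shows the trade-off of a short patience: it saves compute but can stop before a later improvement.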

The optimal validation threshold is computed only at the best-AUROC epoch using Youden’s J statistic, which maximizes TPR minus FPR on the validation ROC curve. For that best epoch, the cell computes and stores threshold-based validation metrics at both threshold 0.5 and the Youden-optimal threshold. These metrics include accuracy, precision, recall (sensitivity), specificity, F1 score, MCC, and the p-value of Fisher’s exact test. Each seed writes a `metrics.json` file that records the data sources used, details of the baseline initialization, the best epoch, best AUROC, optimal threshold information, and paths to all saved outputs.
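Youden's J and a few of the threshold-based metrics can be sketched directly from the definitions. This version scans the observed scores with NumPy instead of calling `sklearn.metrics.roc_curve`, so its tie-breaking may differ from the notebook's. The labels and probabilities are made up, and Fisher's exact test is left out because it needs SciPy.

```python
import math
import numpy as np

def youden_threshold(y_true, y_prob):
    """Return (threshold, J) maximizing TPR - FPR over the observed scores."""
    y_true = np.asarray(y_true)
    y_prob = np.asarray(y_prob, dtype=float)
    n_pos = int((y_true == 1).sum())
    n_neg = int((y_true == 0).sum())
    best_thr, best_j = None, -math.inf
    for thr in np.unique(y_prob)[::-1]:          # scan from the highest score down
        pred = y_prob >= thr
        tpr = (pred & (y_true == 1)).sum() / n_pos
        fpr = (pred & (y_true == 0)).sum() / n_neg
        if tpr - fpr > best_j:                   # ties keep the higher threshold
            best_thr, best_j = float(thr), float(tpr - fpr)
    return best_thr, best_j

def threshold_metrics(y_true, y_prob, thr):
    """Accuracy, sensitivity, specificity and MCC at a given threshold."""
    y_true = np.asarray(y_true)
    pred = np.asarray(y_prob, dtype=float) >= thr
    tp = int((pred & (y_true == 1)).sum())
    fp = int((pred & (y_true == 0)).sum())
    fn = int((~pred & (y_true == 1)).sum())
    tn = int((~pred & (y_true == 0)).sum())
    denom = math.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
    return {
        "accuracy": (tp + tn) / (tp + tn + fp + fn),
        "sensitivity": tp / (tp + fn) if tp + fn else 0.0,
        "specificity": tn / (tn + fp) if tn + fp else 0.0,
        "mcc": (tp * tn - fp * fn) / denom if denom else 0.0,
    }

# Made-up validation labels (0 = Healthy, 1 = Parkinson) and probabilities.
y_true = [0, 0, 0, 0, 1, 1, 1, 1]
y_prob = [0.10, 0.20, 0.35, 0.40, 0.30, 0.55, 0.70, 0.90]
thr, j = youden_threshold(y_true, y_prob)
metrics_at_opt = threshold_metrics(y_true, y_prob, thr)
metrics_at_half = threshold_metrics(y_true, y_prob, 0.5)
print(thr, j, metrics_at_opt)
```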

After all three seeds finish, the cell combines results across seeds and writes an experiment-level `summary_trainval.json` under a new folder named `trainval_runs/exp_<EXPERIMENT_TAG>_<timestamp>/`. This summary includes the AUROC value for each seed, the mean AUROC with a 95% t-based confidence interval over the three seeds (two degrees of freedom), and the validation-optimal thresholds stored in a standard format with per-seed values plus their overall mean and standard deviation. The same summary is appended as one line to the global `trainval_runs/history_index.jsonl` file for long-term tracking. The cell also mirrors small run configuration and dataset summary files into the builder-style `config/` and `logs/` folders for consistent record keeping, without changing the main training outputs. Finally, it prints the main output locations and unassigns the Colab runtime to stop the GPU instance.
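The t-based interval over three seeds can be reproduced with the standard library. The sketch below uses the closed-form Student-t quantile for two degrees of freedom, so it applies only to exactly three values. The AUROC values are made up and the helper names are hypothetical.

```python
import math
import statistics

def t_quantile_df2(p):
    """Student-t quantile for 2 degrees of freedom (closed form)."""
    return (2.0 * p - 1.0) / math.sqrt(2.0 * p * (1.0 - p))

def mean_ci95_three_seeds(values):
    """Return (mean, lower, upper) for a two-sided 95% t interval with n = 3."""
    if len(values) != 3:
        raise ValueError("this sketch assumes exactly three seed results")
    mean = statistics.fmean(values)
    sd = statistics.stdev(values)               # sample standard deviation (n - 1)
    half_width = t_quantile_df2(0.975) * sd / math.sqrt(len(values))
    return mean, mean - half_width, mean + half_width

# Made-up per-seed AUROC values, for illustration only.
mean, lower, upper = mean_ci95_three_seeds([0.80, 0.82, 0.84])
print(f"mean={mean:.4f}, 95% CI=({lower:.4f}, {upper:.4f})")
```

With only two degrees of freedom the critical value is about 4.30 rather than the familiar 1.96, so intervals from three seeds are wide even when the seeds agree closely.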

In [None]:
# =========================
# D7 Enhanced Train + Val (train_enh1), Baseline-Initialized Heads
# Inputs: D7 enhanced-train manifest (train_enh1 split) + D7 full manifest (val split) + baseline best_heads.pt (per seed)
# Outputs: Per-seed best heads, metrics, ROC/confusion plots, plus an experiment summary + history index entries
# =========================
# Train + Val ONLY (CRASH-PROOF, WITH PROGRESS + HISTORY) — D7 ENHANCED (train_enh1)
# - Frozen Wav2Vec2 backbone
# - Two task heads + small LayerNorm + Dropout blocks (only these parts train)
# - Uses ONLY two manifests:
#     * Enhanced TRAIN manifest: train_enh1 split
#     * Standard VAL manifest: val split
# - Initializes heads from the most recent BASELINE D7 train+val experiment (same 3 seeds),
#   then continues training on train_enh1
# - Saves per-seed best artifacts and a per-experiment summary + append-only history index
# - Adds extra threshold-based metrics and computes a VAL-optimal threshold (Youden J) at the best-AUROC epoch
# - Ends by unassigning the Colab runtime (GPU stop)
#
# NOTES
# - Train comes from the enhanced train manifest (already points to enhanced clips)
# - Val comes from the standard D7 manifest
# - dataset_id is inferred from val (expected "D7"); no re-splitting happens here
# =========================

import os, json, math, random, time, warnings
from pathlib import Path
from datetime import datetime

import numpy as np
import pandas as pd
import soundfile as sf

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

from transformers import Wav2Vec2Model

from sklearn.metrics import (
    roc_auc_score, roc_curve,
    confusion_matrix, accuracy_score,
    precision_recall_fscore_support,
    matthews_corrcoef
)
import matplotlib.pyplot as plt
from tqdm.auto import tqdm

# -------------------------
# 0) Safety checks: avoid importing a local file that masks real libraries
# -------------------------
if os.path.exists("/content/torch.py") or os.path.exists("/content/torch/__init__.py"):
    raise RuntimeError("Found local /content/torch.py or /content/torch/ that shadows PyTorch. Rename/remove it and restart runtime.")
if os.path.exists("/content/transformers.py") or os.path.exists("/content/transformers/__init__.py"):
    raise RuntimeError("Found local /content/transformers.py or /content/transformers/ that shadows Hugging Face Transformers. Rename/remove it and restart runtime.")

# -------------------------
# 1) Drive access (Colab): mount only if not already available
# -------------------------
try:
    from google.colab import drive  # type: ignore
    if not os.path.isdir("/content/drive/MyDrive"):
        drive.mount("/content/drive")
except Exception:
    pass

# -------------------------
# 2) Root paths and manifest names
# Inputs: a single D7 preprocessed root (DX_OUT_ROOT)
# Outputs: resolved manifest file paths and stable config/log folders
# -------------------------
D7_OUT_ROOT_FALLBACK = "/content/drive/MyDrive/AI_PD_Project/Datasets/D7-Multilingual (D1_D4_D5v2_D6)/preprocessed_v1"
DX_OUT_ROOT = str(globals().get("DX_OUT_ROOT", D7_OUT_ROOT_FALLBACK))
globals()["DX_OUT_ROOT"] = DX_OUT_ROOT

MANIFEST_ALL = f"{DX_OUT_ROOT}/manifests/manifest_all.csv"
MANIFEST_TRAIN_ENH = f"{DX_OUT_ROOT}/manifests/manifest_train_enh1.csv"
TRAIN_ENH_SPLIT_NAME = "train_enh1"

# Builder-style mirror folders (lightweight copies of key run info)
cfg_dir  = Path(DX_OUT_ROOT) / "config" / "D7_Enh1_on_D2_Test"
logs_dir = Path(DX_OUT_ROOT) / "logs" / "D7_Enh1_on_D2_Test"
cfg_dir.mkdir(parents=True, exist_ok=True)
logs_dir.mkdir(parents=True, exist_ok=True)

# -------------------------
# 3) Experiment identity and output folder for this run
# Output: a new exp_... folder under trainval_runs/ (never overwrites older experiments)
# -------------------------
EXPERIMENT_TAG = "frozen_LNDO_trainEnh1_initBaseline"
RUN_STAMP = time.strftime("%Y%m%d_%H%M%S")

TRAINVAL_ROOT = Path(DX_OUT_ROOT) / "trainval_runs"
EXP_ROOT = TRAINVAL_ROOT / f"exp_{EXPERIMENT_TAG}_{RUN_STAMP}"
EXP_ROOT.mkdir(parents=True, exist_ok=True)

# -------------------------
# 4) Training settings and runtime options
# Inputs: fixed hyperparameters, 3 seeds, and expected audio format
# -------------------------
MAX_EPOCHS     = 10
EFFECTIVE_BS   = 64
PER_DEVICE_BS  = 16
GRAD_ACCUM     = max(1, EFFECTIVE_BS // PER_DEVICE_BS)

LR             = 1e-3
PATIENCE       = 2
SEEDS          = [1337, 2024, 7777]

BACKBONE_CKPT  = "facebook/wav2vec2-base"
SR_EXPECTED    = 16000
TINY_THRESH    = 1e-4

DROPOUT_P      = 0.2

NUM_WORKERS    = 0
PIN_MEMORY     = False

VOWEL_TASK_VALUE = "vowl"

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
try:
    torch.set_float32_matmul_precision("high")
except Exception:
    pass

# Quiet known warnings to keep logs readable
warnings.filterwarnings("ignore", message="None of the inputs have requires_grad=True. Gradients will be None")
warnings.filterwarnings("ignore", message="Passing `gradient_checkpointing` to a config initialization is deprecated")
warnings.filterwarnings("ignore", category=UserWarning, module="huggingface_hub")

# Quick run summary (prints only; does not change any artifacts)
print("DX_OUT_ROOT:", DX_OUT_ROOT)
print("MANIFEST_TRAIN_ENH:", MANIFEST_TRAIN_ENH)
print("TRAIN_ENH_SPLIT_NAME:", TRAIN_ENH_SPLIT_NAME)
print("MANIFEST_ALL (val source):", MANIFEST_ALL)
print("DEVICE:", DEVICE)
if DEVICE.type == "cuda":
    print("GPU:", torch.cuda.get_device_name(0))
print("PER_DEVICE_BS:", PER_DEVICE_BS, "| GRAD_ACCUM:", GRAD_ACCUM, "| EFFECTIVE_BS:", PER_DEVICE_BS * GRAD_ACCUM)
print("NUM_WORKERS:", NUM_WORKERS, "| PIN_MEMORY:", PIN_MEMORY)
print("EXPERIMENT_TAG:", EXPERIMENT_TAG, "| RUN_STAMP:", RUN_STAMP)
print("EXP_ROOT:", str(EXP_ROOT))
print("cfg_dir:", str(cfg_dir))
print("logs_dir:", str(logs_dir))

# -------------------------
# 5) Read manifests and build the train/val tables
# Inputs: enhanced TRAIN manifest + standard ALL manifest
# Outputs: train_df (train_enh1) and val_df (val) with the same core columns
# -------------------------
if not os.path.exists(MANIFEST_TRAIN_ENH):
    raise FileNotFoundError(
        "Missing manifest_train_enh1.csv at:\n"
        f"  {MANIFEST_TRAIN_ENH}\n"
        "Run the train_enh builder first."
    )
if not os.path.exists(MANIFEST_ALL):
    raise FileNotFoundError(
        "Missing manifest_all.csv at:\n"
        f"  {MANIFEST_ALL}\n"
        "Confirm D7 merge-builder wrote manifests/manifest_all.csv under DX_OUT_ROOT."
    )

m_train = pd.read_csv(MANIFEST_TRAIN_ENH)
m_all   = pd.read_csv(MANIFEST_ALL)

# Minimum required columns for training and validation
req_cols = {"split", "clip_path", "label_num", "task"}
for name, df in [("manifest_train_enh1", m_train), ("manifest_all", m_all)]:
    missing = [c for c in sorted(req_cols) if c not in df.columns]
    if missing:
        raise ValueError(f"{name} missing required columns: {missing}. Found: {list(df.columns)}")

# Split selection: enhanced train split + standard val split
m_train = m_train[m_train["split"].astype(str) == TRAIN_ENH_SPLIT_NAME].copy()
m_val   = m_all[m_all["split"].astype(str) == "val"].copy()

if len(m_train) == 0:
    raise RuntimeError(f"After filtering manifest_train_enh1.csv to split=={TRAIN_ENH_SPLIT_NAME!r}, 0 rows remain.")
if len(m_val) == 0:
    raise RuntimeError("After filtering manifest_all.csv to split=='val', 0 rows remain.")

# Infer dataset_id from validation data (used for naming outputs and as a guard)
if "dataset" in m_val.columns and m_val["dataset"].notna().any():
    dataset_id = str(m_val["dataset"].astype(str).value_counts(dropna=True).idxmax())
    m_val = m_val[m_val["dataset"].astype(str) == dataset_id].copy()
else:
    dataset_id = "DX"

# Guard: this training script expects D7 validation data
if dataset_id != "D7":
    raise RuntimeError(f"Dataset inferred from VAL manifest is {dataset_id!r}. Expected 'D7'. Check DX_OUT_ROOT/manifests/manifest_all.csv.")

# Keep a consistent, compact set of columns across train and val
keep_cols = ["clip_path", "label_num", "task", "speaker_id", "duration_sec", "split"]
for df in [m_train, m_val]:
    for c in keep_cols:
        if c not in df.columns:
            df[c] = np.nan

m_train = m_train[keep_cols].copy()
m_val   = m_val[keep_cols].copy()

train_df = m_train.copy().reset_index(drop=True)
val_df   = m_val.copy().reset_index(drop=True)

print(f"\nDataset inferred (from VAL): {dataset_id}")
print(f"Train rows ({TRAIN_ENH_SPLIT_NAME}): {len(train_df)} | Val rows: {len(val_df)}")
print("Train label counts:", train_df["label_num"].value_counts(dropna=False).to_dict())
print("Val label counts:",   val_df["label_num"].value_counts(dropna=False).to_dict())

# -------------------------
# 6) Fail-fast: confirm audio files exist before training starts
# Input: train_df and val_df clip_path values
# Output: raises early with example missing paths, instead of failing mid-epoch
# -------------------------
def _fail_fast_missing_paths(df: pd.DataFrame, name: str):
    missing_paths = []
    for p in tqdm(df["clip_path"].astype(str).tolist(), desc=f"Check {name} clip_path exists", dynamic_ncols=True):
        if not os.path.exists(p):
            missing_paths.append(p)
            if len(missing_paths) >= 10:
                break
    if missing_paths:
        raise FileNotFoundError(f"{name}: missing clip_path(s). Examples: {missing_paths}")

_fail_fast_missing_paths(train_df, "TRAIN_ENH1")
_fail_fast_missing_paths(val_df, "VAL")

# -------------------------
# 7) Task grouping for the two-head model
# Rule: task == "vowl" -> vowel head, otherwise -> other head
# -------------------------
def _task_group(task_val) -> str:
    return "vowel" if str(task_val) == VOWEL_TASK_VALUE else "other"

train_df["task_group"] = train_df["task"].apply(_task_group)
val_df["task_group"]   = val_df["task"].apply(_task_group)

# -------------------------
# 8) Dataset + batch collation (waveforms + attention masks)
# Inputs: manifest rows with clip_path, label_num, and task_group
# Outputs: padded tensors for model input, plus per-item task_group routing
# -------------------------
class AudioManifestDataset(Dataset):
    """
    Loads a clip and creates an attention mask at the sample level.

    Masking rule:
    - vowel clips: mask trailing near-silence (often padding)
    - other clips: keep all samples active
    """
    def __init__(self, df: pd.DataFrame):
        self.df = df.reset_index(drop=True)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx: int):
        row = self.df.iloc[idx]
        clip_path = str(row["clip_path"])
        label = int(row["label_num"])
        task_group = str(row["task_group"])

        if not os.path.exists(clip_path):
            raise FileNotFoundError(f"Missing clip_path: {clip_path}")

        # Load audio and force mono float32
        y, sr = sf.read(clip_path, always_2d=False)
        if isinstance(y, np.ndarray) and y.ndim > 1:
            y = np.mean(y, axis=1)
        y = np.asarray(y, dtype=np.float32)

        # Hard check: training assumes a single sample rate everywhere
        if int(sr) != SR_EXPECTED:
            raise ValueError(f"Sample rate mismatch (expected {SR_EXPECTED}): {clip_path} has sr={sr}")

        # Default: attend to everything
        attn = np.ones((len(y),), dtype=np.int64)

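        # Vowel: ignore trailing near-zeros so padding does not affect training
        if task_group == "vowel":
            # Indices of samples above the near-silence threshold (float64 comparison, as before)
            active = np.flatnonzero(np.abs(y.astype(np.float64)) > TINY_THRESH)
            if active.size > 0:
                # Mask everything after the last non-silent sample
                attn[int(active[-1]) + 1:] = 0
            # All-silent clips keep the default all-ones mask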

        return {
            "input_values": torch.from_numpy(y),                 # float32 [T]
            "attention_mask": torch.from_numpy(attn),            # int64   [T]
            "labels": torch.tensor(label, dtype=torch.long),     # int64   []
            "task_group": task_group,                            # str
        }

def collate_fn(batch):
    """Pads waveforms and masks to the longest clip in the batch."""
    max_len = int(max(b["input_values"].numel() for b in batch))
    input_vals, attn_masks, labels, task_groups = [], [], [], []
    for b in batch:
        x = b["input_values"]
        a = b["attention_mask"]
        pad = max_len - x.numel()
        if pad > 0:
            x = torch.cat([x, torch.zeros(pad, dtype=x.dtype)], dim=0)
            a = torch.cat([a, torch.zeros(pad, dtype=a.dtype)], dim=0)
        input_vals.append(x)
        attn_masks.append(a)
        labels.append(b["labels"])
        task_groups.append(b["task_group"])
    return {
        "input_values": torch.stack(input_vals, dim=0),      # [B,T]
        "attention_mask": torch.stack(attn_masks, dim=0),    # [B,T]
        "labels": torch.stack(labels, dim=0),                # [B]
        "task_group": task_groups,                           # list[str]
    }

# -------------------------
# 9) Two-head classifier with a frozen Wav2Vec2 backbone
# Inputs: waveform + attention mask + task_group routing
# Outputs: loss and logits (PD probability comes from softmax(logits)[:,1])
# -------------------------
class Wav2Vec2TwoHeadClassifier(nn.Module):
    """
    Trainable parts:
    - LayerNorm+Dropout blocks (one per task group)
    - Linear heads (one per task group)
    Backbone stays frozen.
    """
    def __init__(self, ckpt: str, dropout_p: float):
        super().__init__()
        self.backbone = Wav2Vec2Model.from_pretrained(
            ckpt,
            use_safetensors=True,
            local_files_only=False
        )
        for p in self.backbone.parameters():
            p.requires_grad = False

        H = int(self.backbone.config.hidden_size)
        self.pre_vowel = nn.Sequential(nn.LayerNorm(H), nn.Dropout(float(dropout_p)))
        self.pre_other = nn.Sequential(nn.LayerNorm(H), nn.Dropout(float(dropout_p)))

        self.head_vowel = nn.Linear(H, 2)
        self.head_other = nn.Linear(H, 2)
        self.loss_fn = nn.CrossEntropyLoss()

    def masked_mean_pool(self, last_hidden, attn_mask_samples):
        # Converts sample-level mask into feature-level mask, then average-pools valid frames
        feat_mask = self.backbone._get_feature_vector_attention_mask(last_hidden.shape[1], attn_mask_samples)
        feat_mask = feat_mask.to(last_hidden.device).unsqueeze(-1).type_as(last_hidden)
        masked = last_hidden * feat_mask
        denom = feat_mask.sum(dim=1).clamp(min=1.0)
        return masked.sum(dim=1) / denom

    def forward(self, input_values, attention_mask, labels, task_group):
        # Backbone forward is wrapped in no_grad since it is frozen
        with torch.no_grad():
            out = self.backbone(input_values=input_values, attention_mask=attention_mask)
            last_hidden = out.last_hidden_state  # [B,T',H]

        pooled = self.masked_mean_pool(last_hidden, attention_mask).float()  # [B,H]

        # Separate feature transforms for the two task groups
        z_v = self.pre_vowel(pooled)
        z_o = self.pre_other(pooled)

        # Compute both heads, then select per item using task_group
        logits_v = self.head_vowel(z_v)
        logits_o = self.head_other(z_o)

        is_vowel = torch.tensor([tg == "vowel" for tg in task_group], device=pooled.device, dtype=torch.bool)
        logits = logits_o.clone()
        if is_vowel.any():
            logits[is_vowel] = logits_v[is_vowel]

        loss = self.loss_fn(logits, labels)
        return loss, logits

# -------------------------
# 9.5) Baseline head loading
# Input: best_heads.pt saved from an earlier baseline train+val run (per seed)
# Output: model heads set to baseline weights before enhanced training starts
# -------------------------
def load_heads_into_model(model: Wav2Vec2TwoHeadClassifier, best_heads_path: Path):
    if not best_heads_path.exists():
        raise FileNotFoundError(f"Missing best_heads.pt: {str(best_heads_path)}")
    state = torch.load(str(best_heads_path), map_location="cpu")
    model.pre_vowel.load_state_dict(state["pre_vowel"], strict=True)
    model.pre_other.load_state_dict(state["pre_other"], strict=True)
    model.head_vowel.load_state_dict(state["head_vowel"], strict=True)
    model.head_other.load_state_dict(state["head_other"], strict=True)
    return model

# -------------------------
# 10) Metric helpers
# Inputs: y_true labels and y_prob PD probabilities
# Outputs: AUROC, threshold-based metrics, and a VAL-optimal threshold from ROC (Youden J)
# -------------------------
def compute_auc(y_true, y_prob):
    y_true = np.asarray(y_true, dtype=np.int64)
    y_prob = np.asarray(y_prob, dtype=np.float64)
    if len(np.unique(y_true)) < 2:
        return float("nan")
    return float(roc_auc_score(y_true, y_prob))

def compute_threshold_metrics(y_true, y_prob, thr=0.5):
    y_true = np.asarray(y_true, dtype=np.int64)
    y_prob = np.asarray(y_prob, dtype=np.float64)
    y_pred = (y_prob >= float(thr)).astype(np.int64)

    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
    tn, fp, fn, tp = (cm.ravel().tolist() if cm.size == 4 else [0, 0, 0, 0])

    # Common classification metrics at a fixed threshold
    acc = float(accuracy_score(y_true, y_pred))
    prec, rec, f1, _ = precision_recall_fscore_support(y_true, y_pred, average="binary", zero_division=0)
    mcc = float(matthews_corrcoef(y_true, y_pred)) if len(np.unique(y_true)) > 1 else float("nan")

    sensitivity = float(rec)
    specificity = float(tn / (tn + fp)) if (tn + fp) > 0 else float("nan")

    # Fisher exact p-value for the 2x2 confusion table (when SciPy is available)
    p_value = float("nan")
    try:
        from scipy.stats import fisher_exact  # type: ignore
        _, p_value = fisher_exact([[tn, fp], [fn, tp]], alternative="two-sided")
        p_value = float(p_value)
    except Exception:
        p_value = float("nan")

    return {
        "threshold": float(thr),
        "tn": int(tn), "fp": int(fp), "fn": int(fn), "tp": int(tp),
        "accuracy": float(acc),
        "precision": float(prec),
        "recall": float(rec),
        "f1": float(f1),
        "sensitivity": float(sensitivity),
        "specificity": float(specificity),
        "mcc": float(mcc),
        "p_value_fisher": float(p_value),
    }

def compute_youden_j_threshold(y_true, y_prob):
    # Picks the threshold that maximizes (TPR - FPR) on the validation ROC curve
    y_true = np.asarray(y_true, dtype=np.int64)
    y_prob = np.asarray(y_prob, dtype=np.float64)
    if len(np.unique(y_true)) < 2:
        return float("nan"), {"youden_j": float("nan"), "tpr": float("nan"), "fpr": float("nan")}
    fpr, tpr, thr = roc_curve(y_true, y_prob)
    j = tpr - fpr
    idx = int(np.argmax(j))
    return float(thr[idx]), {"youden_j": float(j[idx]), "tpr": float(tpr[idx]), "fpr": float(fpr[idx])}

def save_roc_curve_png(y_true, y_prob, out_png):
    # Saves a simple ROC plot for the best validation epoch
    fpr, tpr, _ = roc_curve(y_true, y_prob)
    plt.figure()
    plt.plot(fpr, tpr)
    plt.plot([0, 1], [0, 1], linestyle="--")
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    plt.title("ROC Curve (Val)")
    plt.tight_layout()
    plt.savefig(out_png, dpi=150)
    plt.close()

def save_confusion_png(y_true, y_prob, out_png, thr=0.5):
    # Saves a confusion matrix image for a chosen threshold
    y_true = np.asarray(y_true, dtype=np.int64)
    y_prob = np.asarray(y_prob, dtype=np.float64)
    y_pred = (y_prob >= float(thr)).astype(int)
    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])

    plt.figure()
    plt.imshow(cm)
    plt.xticks([0, 1], ["Healthy(0)", "PD(1)"])
    plt.yticks([0, 1], ["Healthy(0)", "PD(1)"])
    plt.xlabel("Predicted")
    plt.ylabel("True")
    plt.title(f"Confusion Matrix (Val, thr={thr:.4f})")
    for (i, j), v in np.ndenumerate(cm):
        plt.text(j, i, str(v), ha="center", va="center")
    plt.tight_layout()
    plt.savefig(out_png, dpi=150)
    plt.close()

def mean_sd(vals):
    # Small helper for reporting mean ± SD across the three seeds
    vals = np.asarray(vals, dtype=np.float64)
    mu = float(np.nanmean(vals)) if np.any(~np.isnan(vals)) else float("nan")
    sd = float(np.nanstd(vals, ddof=1)) if np.sum(~np.isnan(vals)) > 1 else 0.0
    return mu, sd

# -------------------------
# 11) Reproducibility: set all random seeds (Python, NumPy, PyTorch)
# -------------------------
def set_all_seeds(seed: int):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# -------------------------
# 11.5) Find the baseline experiment used to initialize heads
# Inputs: trainval_runs/exp_* folders
# Output: the most recent exp_* that contains summary_trainval.json and 3 best_heads.pt files
# -------------------------
BASELINE_TRAINVAL_ROOT = Path(DX_OUT_ROOT) / "trainval_runs"
if not BASELINE_TRAINVAL_ROOT.exists():
    raise FileNotFoundError(f"Missing trainval_runs folder under DX_OUT_ROOT: {str(BASELINE_TRAINVAL_ROOT)}")

exp_dirs = sorted([p for p in BASELINE_TRAINVAL_ROOT.glob("exp_*") if p.is_dir()],
                  key=lambda p: p.stat().st_mtime, reverse=True)
if not exp_dirs:
    raise FileNotFoundError(f"No exp_* folders found under: {str(BASELINE_TRAINVAL_ROOT)}")

train_dataset_id = "D7"

def _has_all_seeds_and_summary(exp_path: Path, dataset_id: str, seeds: list):
    summary_path = exp_path / "summary_trainval.json"
    if not summary_path.exists():
        return False
    for s in seeds:
        p = exp_path / f"run_{dataset_id}_seed{s}" / "best_heads.pt"
        if not p.exists():
            return False
    return True

baseline_exp = None
for ed in exp_dirs:
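    # Skip the current experiment folder to avoid self-selection
    if ed.resolve() == EXP_ROOT.resolve():
        continue
    # Skip earlier runs of this enhanced experiment so only baseline runs can be selected
    if ed.name.startswith(f"exp_{EXPERIMENT_TAG}_"):
        continue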
    if _has_all_seeds_and_summary(ed, train_dataset_id, SEEDS):
        baseline_exp = ed
        break

if baseline_exp is None:
    sample = exp_dirs[0]
    raise FileNotFoundError(
        "Could not find a baseline D7 trainval experiment with all 3 best_heads.pt files and summary_trainval.json.\n"
        f"Expected under: {str(BASELINE_TRAINVAL_ROOT)}/exp_*/run_D7_seedXXXX/best_heads.pt and summary_trainval.json\n"
        f"Most recent exp checked: {str(sample)}"
    )

baseline_summary_path = baseline_exp / "summary_trainval.json"
with open(baseline_summary_path, "r", encoding="utf-8") as f:
    baseline_summary = json.load(f)

print("\nBaseline initialization experiment selected:")
print(" ", str(baseline_exp))
print(" ", "summary:", str(baseline_summary_path))
for s in SEEDS:
    p = baseline_exp / f"run_{train_dataset_id}_seed{s}" / "best_heads.pt"
    if not p.exists():
        raise RuntimeError(f"Baseline artifact missing after selection. Missing: {str(p)}")

# -------------------------
# 12) Per-seed training loop with early stopping on validation AUROC
# Inputs: train_df + val_df and baseline-initialized heads
# Outputs: best_heads.pt + best-epoch plots + metrics.json for that seed
# -------------------------
def run_trainval_once(seed: int):
    set_all_seeds(seed)

    run_dir = EXP_ROOT / f"run_{dataset_id}_seed{seed}"
    run_dir.mkdir(parents=True, exist_ok=True)

    # Dataset objects keep I/O simple and consistent across train and val
    train_ds = AudioManifestDataset(train_df)
    val_ds   = AudioManifestDataset(val_df)

    train_loader = DataLoader(
        train_ds,
        batch_size=PER_DEVICE_BS,
        shuffle=True,
        num_workers=NUM_WORKERS,
        pin_memory=PIN_MEMORY,
        collate_fn=collate_fn
    )
    val_loader = DataLoader(
        val_ds,
        batch_size=PER_DEVICE_BS,
        shuffle=False,
        num_workers=NUM_WORKERS,
        pin_memory=PIN_MEMORY,
        collate_fn=collate_fn
    )

    # Warm-up: confirm batches can be formed and read successfully
    # (zip stops early if the loader yields fewer than 3 batches, avoiding StopIteration)
    print(f"\n[seed={seed}] Warm-up: loading up to 3 train batches...")
    t0 = time.time()
    for i, _ in zip(range(3), train_loader):
        print(f"  loaded warmup batch {i+1}/3")
    print(f"[seed={seed}] Warm-up done in {time.time()-t0:.2f}s")

    # Build model and start from baseline heads for this seed
    model = Wav2Vec2TwoHeadClassifier(BACKBONE_CKPT, dropout_p=DROPOUT_P).to(DEVICE)

    baseline_heads_path = baseline_exp / f"run_{train_dataset_id}_seed{seed}" / "best_heads.pt"
    print(f"[seed={seed}] Initializing heads from baseline:")
    print(" ", str(baseline_heads_path))
    model = load_heads_into_model(model, baseline_heads_path)
    model.train()

    # Optimizer only sees trainable head parameters (backbone stays frozen)
    trainable_params = (
        list(model.pre_vowel.parameters()) + list(model.pre_other.parameters()) +
        list(model.head_vowel.parameters()) + list(model.head_other.parameters())
    )
    opt = torch.optim.Adam(trainable_params, lr=LR)

    # Track the best validation AUROC and stop after PATIENCE non-improving epochs
    best_auc = -1.0
    best_epoch = -1
    no_improve = 0

    # Save the best heads and the validation predictions from the best epoch
    best_state = None
    best_val_probs = None
    best_val_true = None

    # Threshold tracking for reporting at the best epoch
    best_thr_youden = float("nan")
    best_thr_youden_details = None
    best_val_metrics_thr05 = None
    best_val_metrics_thr_opt = None

    for epoch in range(1, MAX_EPOCHS + 1):
        # ---- Train phase ----
        model.train()
        train_losses = []
        opt.zero_grad(set_to_none=True)

        pbar = tqdm(train_loader, desc=f"[seed={seed}] Train epoch {epoch}", dynamic_ncols=True)
        step = 0
        for batch in pbar:
            step += 1
            input_values = batch["input_values"].to(DEVICE, non_blocking=False)
            attention_mask = batch["attention_mask"].to(DEVICE, non_blocking=False)
            labels = batch["labels"].to(DEVICE, non_blocking=False)
            task_group = batch["task_group"]

            loss, _ = model(input_values, attention_mask, labels, task_group)
            loss = loss / GRAD_ACCUM
            loss.backward()

            train_losses.append(float(loss.detach().cpu().item()) * GRAD_ACCUM)

            # Gradient accumulation to reach the effective batch size
            if (step % GRAD_ACCUM) == 0:
                opt.step()
                opt.zero_grad(set_to_none=True)

        # Final step if the epoch ends mid-accumulation
        if (step % GRAD_ACCUM) != 0:
            opt.step()
            opt.zero_grad(set_to_none=True)

        avg_train_loss = float(np.mean(train_losses)) if train_losses else float("nan")

        # ---- Validation phase ----
        model.eval()
        all_probs, all_true = [], []
        vpbar = tqdm(val_loader, desc=f"[seed={seed}] Val epoch {epoch}", dynamic_ncols=True)
        with torch.inference_mode():
            for batch in vpbar:
                input_values = batch["input_values"].to(DEVICE, non_blocking=False)
                attention_mask = batch["attention_mask"].to(DEVICE, non_blocking=False)
                labels = batch["labels"].to(DEVICE, non_blocking=False)
                task_group = batch["task_group"]

                _, logits = model(input_values, attention_mask, labels, task_group)
                probs = torch.softmax(logits.float(), dim=-1)[:, 1].detach().cpu().numpy().astype(np.float64)
                all_probs.extend(probs.tolist())
                all_true.extend(labels.detach().cpu().numpy().astype(np.int64).tolist())

        # Main selection metric for the best epoch
        val_auc = compute_auc(all_true, all_probs)
        print(f"seed={seed} | epoch {epoch:02d}/{MAX_EPOCHS} | train_loss={avg_train_loss:.5f} | val_AUROC={val_auc:.5f}")

        # Save best epoch artifacts by AUROC
        improved = (not math.isnan(val_auc)) and (val_auc > best_auc + 1e-12)
        if improved:
            best_auc = float(val_auc)
            best_epoch = int(epoch)
            no_improve = 0

            # Save only the trainable parts (heads + small pre-head blocks)
            best_state = {
                "pre_vowel": {k: v.detach().cpu().clone() for k, v in model.pre_vowel.state_dict().items()},
                "pre_other": {k: v.detach().cpu().clone() for k, v in model.pre_other.state_dict().items()},
                "head_vowel": {k: v.detach().cpu().clone() for k, v in model.head_vowel.state_dict().items()},
                "head_other": {k: v.detach().cpu().clone() for k, v in model.head_other.state_dict().items()},
            }

            # Save validation outputs for plots and threshold calculation
            best_val_probs = list(all_probs)
            best_val_true  = list(all_true)

            # Baseline threshold report at 0.5
            best_val_metrics_thr05 = compute_threshold_metrics(best_val_true, best_val_probs, thr=0.5)

            # VAL-optimal threshold from ROC curve at the best epoch
            thr_opt, details = compute_youden_j_threshold(best_val_true, best_val_probs)
            best_thr_youden = float(thr_opt)
            best_thr_youden_details = details
            best_val_metrics_thr_opt = compute_threshold_metrics(best_val_true, best_val_probs, thr=best_thr_youden)
        else:
            no_improve += 1

        # Early stopping: stop when validation AUROC no longer improves
        if no_improve >= PATIENCE:
            break

    # Guard: ensure at least one best epoch was captured
    if best_state is None or best_val_probs is None or best_val_true is None:
        raise RuntimeError(
            "No best epoch captured. Validation AUROC may be NaN due to single-class validation split "
            "or earlier failures."
        )

    # Save the best heads for reuse in later test-only code
    best_heads_path = run_dir / "best_heads.pt"
    torch.save(best_state, str(best_heads_path))

    # Best-epoch plots (ROC + confusion matrices at 0.5 and at VAL-opt threshold)
    roc_png = run_dir / "roc_curve.png"
    cm_png_05 = run_dir / "confusion_matrix_thr0p5.png"
    cm_png_opt = run_dir / "confusion_matrix_thr_opt.png"

    save_roc_curve_png(np.asarray(best_val_true, dtype=np.int64), np.asarray(best_val_probs, dtype=np.float64), str(roc_png))
    save_confusion_png(np.asarray(best_val_true, dtype=np.int64), np.asarray(best_val_probs, dtype=np.float64), str(cm_png_05), thr=0.5)
    if not np.isnan(best_thr_youden):
        save_confusion_png(np.asarray(best_val_true, dtype=np.int64), np.asarray(best_val_probs, dtype=np.float64), str(cm_png_opt), thr=float(best_thr_youden))

    # Per-seed metrics file (used later for threshold lookup and provenance)
    metrics = {
        "dataset": dataset_id,
        "seed": int(seed),
        "best_val_auroc": float(best_auc),
        "best_epoch": int(best_epoch),

        # Data sources (train_enh1 + standard val)
        "train_manifest_used": MANIFEST_TRAIN_ENH,
        "val_manifest_used": MANIFEST_ALL,
        "train_split_name": TRAIN_ENH_SPLIT_NAME,

        # Baseline initialization details (where the starting heads came from)
        "init_heads": {
            "mode": "baseline_best_heads",
            "baseline_exp_used": str(baseline_exp),
            "baseline_summary_path": str(baseline_summary_path),
            "baseline_best_heads_path": str(baseline_heads_path),
        },

        "n_train": int(len(train_df)),
        "n_val": int(len(val_df)),
        "label_counts_train": train_df["label_num"].value_counts(dropna=False).to_dict(),
        "label_counts_val": val_df["label_num"].value_counts(dropna=False).to_dict(),

        "experiment_tag": EXPERIMENT_TAG,
        "run_stamp": RUN_STAMP,

        "dropout_p": float(DROPOUT_P),
        "lr": float(LR),
        "effective_batch_size": int(PER_DEVICE_BS * GRAD_ACCUM),
        "per_device_batch_size": int(PER_DEVICE_BS),
        "grad_accum_steps": int(GRAD_ACCUM),

        "backbone_ckpt": BACKBONE_CKPT,

        # Threshold selection based on validation ROC at the best epoch
        "val_opt_threshold_method": "Youden J (maximize TPR - FPR on VAL ROC curve)",
        "val_opt_threshold": float(best_thr_youden),
        "val_opt_details": best_thr_youden_details,

        "thr_metrics_val_thr0p5": best_val_metrics_thr05,
        "thr_metrics_val_thr_opt": best_val_metrics_thr_opt,

        "artifacts": {
            "roc_curve_png": str(roc_png),
            "confusion_thr0p5_png": str(cm_png_05),
            "confusion_thr_opt_png": str(cm_png_opt),
            "best_heads_pt": str(best_heads_path),
        },
    }

    metrics_path = run_dir / "metrics.json"
    with open(metrics_path, "w", encoding="utf-8") as f:
        json.dump(metrics, f, indent=2)

    print(f"[seed={seed}] VAL-opt threshold (Youden J): {float(best_thr_youden):.6f}")
    print(f"[seed={seed}] WROTE:")
    print(" ", str(metrics_path))
    print(" ", str(roc_png))
    print(" ", str(cm_png_05))
    print(" ", str(cm_png_opt))
    print(" ", str(best_heads_path))

    return {
        "seed": int(seed),
        "best_val_auroc": float(best_auc),
        "best_epoch": int(best_epoch),
        "val_opt_thr": float(best_thr_youden),
        "run_dir": str(run_dir),
        "seed_metrics": metrics,
    }

# -------------------------
# 13) Run all seeds, then write a single experiment summary
# Outputs: summary_trainval.json (experiment-level) + history_index.jsonl (append-only)
# -------------------------
results = []
for seed in SEEDS:
    results.append(run_trainval_once(seed))

aucs = [r["best_val_auroc"] for r in results]
thr_vals = [r["val_opt_thr"] for r in results]

# 95% CI for mean AUROC across 3 seeds (t distribution, df=2)
# NOTE: t_crit is hard-coded for df=2, so the interval is only valid when len(SEEDS) == 3.
t_crit = 4.302652729911275  # df=2, 95% CI
n = len(aucs)
mean_auc = float(np.mean(aucs))
std_auc = float(np.std(aucs, ddof=1)) if n > 1 else 0.0
half_width = float(t_crit * (std_auc / math.sqrt(n))) if n > 1 else 0.0
ci95 = [float(mean_auc - half_width), float(mean_auc + half_width)]

thr_mean, thr_sd = mean_sd(thr_vals)

print("\nAUROC by seed:")
for r in results:
    print(f"  seed {r['seed']}: {r['best_val_auroc']:.6f}")
print(f"\nMean AUROC: {mean_auc:.6f}")
print(f"95% CI (t, n=3): [{ci95[0]:.6f}, {ci95[1]:.6f}]")

print("\nVAL-opt thresholds (Youden J) by seed:")
for r in results:
    print(f"  seed {r['seed']}: {r['val_opt_thr']:.6f}")
print(f"  mean ± SD: {thr_mean:.6f} ± {thr_sd:.6f}")

# Store thresholds in one canonical place (by-seed + mean/sd)
val_optimal_threshold_obj = {
    "method": "Youden J (maximize TPR - FPR on VAL ROC curve)",
    "by_seed": {str(r["seed"]): float(r["val_opt_thr"]) for r in results},
    "mean_sd": {"mean": float(thr_mean), "sd": float(thr_sd)},
}

# Experiment-level summary that later test scripts can read
exp_summary = {
    "dataset": dataset_id,
    "dx_out_root": DX_OUT_ROOT,

    # Data sources
    "train_manifest_used": MANIFEST_TRAIN_ENH,
    "val_manifest_used": MANIFEST_ALL,
    "train_split_name": TRAIN_ENH_SPLIT_NAME,

    # Baseline initialization (shared baseline exp across all seeds)
    "init_heads": {
        "mode": "baseline_best_heads",
        "baseline_exp_used": str(baseline_exp),
        "baseline_summary_path": str(baseline_summary_path),
        "baseline_best_heads_by_seed": {
            str(s): str(baseline_exp / f"run_{train_dataset_id}_seed{s}" / "best_heads.pt") for s in SEEDS
        },
    },

    "experiment_tag": EXPERIMENT_TAG,
    "run_stamp": RUN_STAMP,
    "exp_root": str(EXP_ROOT),
    "run_dirs": [r["run_dir"] for r in results],
    "seeds": SEEDS,

    # AUROC aggregation across seeds
    "aurocs": [float(x) for x in aucs],
    "mean_auroc": float(mean_auc),
    "t_crit_df2_95": float(t_crit),
    "half_width_95_ci": float(half_width),
    "ci95": ci95,

    # Basic dataset counts (train/val)
    "n_train": int(len(train_df)),
    "n_val": int(len(val_df)),
    "label_counts_train": train_df["label_num"].value_counts(dropna=False).to_dict(),
    "label_counts_val": val_df["label_num"].value_counts(dropna=False).to_dict(),

    # Training settings that matter for reproducibility
    "effective_batch_size": int(PER_DEVICE_BS * GRAD_ACCUM),
    "per_device_batch_size": int(PER_DEVICE_BS),
    "grad_accum_steps": int(GRAD_ACCUM),

    "backbone_ckpt": BACKBONE_CKPT,
    "dropout_p": float(DROPOUT_P),
    "lr": float(LR),

    # Canonical threshold block
    "val_optimal_threshold": val_optimal_threshold_obj,
    "per_seed_metrics": [r["seed_metrics"] for r in results],
}

summary_path = EXP_ROOT / "summary_trainval.json"
with open(summary_path, "w", encoding="utf-8") as f:
    json.dump(exp_summary, f, indent=2)

# Append-only global history file (keeps a record of all experiments)
history_path = TRAINVAL_ROOT / "history_index.jsonl"
TRAINVAL_ROOT.mkdir(parents=True, exist_ok=True)
with open(history_path, "a", encoding="utf-8") as f:
    f.write(json.dumps(exp_summary) + "\n")

# -------------------------
# 13.5) Builder-style mirror artifacts (small, stable files)
# Outputs: run_config_trainval.json + dataset_summary_trainval.json + placeholder warnings CSV
# -------------------------
trainval_run_config = {
    "dataset": dataset_id,
    "mode": "trainval_enh1",
    "created_utc": datetime.utcnow().isoformat(),
    "dx_out_root": DX_OUT_ROOT,
    "train_manifest_used": MANIFEST_TRAIN_ENH,
    "val_manifest_used": MANIFEST_ALL,
    "train_split_name": TRAIN_ENH_SPLIT_NAME,
    "experiment_tag": EXPERIMENT_TAG,
    "run_stamp": RUN_STAMP,
    "exp_root": str(EXP_ROOT),
    "baseline_exp_used": str(baseline_exp),
    "baseline_summary_path": str(baseline_summary_path),
    "seeds": SEEDS,
}
with open(cfg_dir / "run_config_trainval.json", "w", encoding="utf-8") as f:
    json.dump(trainval_run_config, f, indent=2)

trainval_dataset_summary = {
    "dataset": dataset_id,
    "created_utc": datetime.utcnow().isoformat(),
    "status": "SUCCESS",
    "train_split_name": TRAIN_ENH_SPLIT_NAME,
    "n_train": int(len(train_df)),
    "n_val": int(len(val_df)),
    "label_counts_train": train_df["label_num"].value_counts(dropna=False).to_dict(),
    "label_counts_val": val_df["label_num"].value_counts(dropna=False).to_dict(),
    "exp_root": str(EXP_ROOT),
    "summary_trainval_json": str(summary_path),
}
with open(logs_dir / "dataset_summary_trainval.json", "w", encoding="utf-8") as f:
    json.dump(trainval_dataset_summary, f, indent=2)

with open(logs_dir / "trainval_warnings.csv", "w", encoding="utf-8") as f:
    f.write("ts,level,message\n")  # placeholder: errors surface as exceptions in this cell

print("\nWROTE per-experiment summary:", str(summary_path))
print("APPENDED global history index:", str(history_path))
print("\nWROTE (builder-aligned):")
print(" ", str(cfg_dir / "run_config_trainval.json"))
print(" ", str(logs_dir / "dataset_summary_trainval.json"))
print(" ", str(logs_dir / "trainval_warnings.csv"))
print("\nOpen this folder to access artifacts:", str(EXP_ROOT))

# -------------------------
# 14) Release runtime resources (stop L4 GPU)
# -------------------------
print("\nAll done. Unassigning the runtime to stop the L4 instance...")
try:
    from google.colab import runtime  # type: ignore
    print("Calling runtime.unassign() now.")
    runtime.unassign()
except Exception as e:
    print("Could not unassign runtime automatically. Error:", repr(e))
    print("Manual stop: Runtime -> Disconnect and delete runtime.")

The following cell runs a **test-only evaluation** of the **D7 enhanced model (trainEnh1)** on the **D2 test split**, with no training involved. It reads only the D2 `manifest_all.csv`, keeps rows where `split == "test"`, confirms that the manifest truly belongs to D2 as a hard safety check, and verifies that every listed audio file exists before continuing. It also adds two simple helper columns used during evaluation: **task_group** (set to “vowel” when `task == "vowl"`, otherwise “other”) and a **normalized sex** label where D2 values `"male"` and `"female"` are mapped to `M` and `F`, and anything else is set to `UNK`.
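
The two helper columns can be illustrated with a minimal sketch that does not use pandas. The function names `task_group` and `normalize_sex`, the task code `"read"`, and the sample rows are hypothetical; only the mapping rules come from the description above.

```python
def task_group(task):
    """Return "vowel" only for the exact task code "vowl"; everything else is "other"."""
    return "vowel" if str(task) == "vowl" else "other"


def normalize_sex(value):
    """Map the exact strings "male"/"female" to "M"/"F"; anything else becomes "UNK"."""
    return {"male": "M", "female": "F"}.get(value, "UNK")


# Hypothetical manifest rows (illustrative values, not taken from D2).
rows = [
    {"task": "vowl", "sex": "male"},
    {"task": "read", "sex": "female"},
    {"task": "vowl", "sex": "Male"},  # wrong case, so it is not recognized
    {"task": "read", "sex": None},    # missing value
]
helper_columns = [(task_group(r["task"]), normalize_sex(r["sex"])) for r in rows]
print(helper_columns)
```

Because the match is case-sensitive, a value such as `"Male"` lands in `UNK` instead of being silently counted as `M`, which keeps the subgroup counts honest.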

Next, the cell automatically finds the **most recent D7 enhanced trainval (train and validation) experiment**, ranked by folder modification time, under `<D7_OUT_ROOT>/trainval_runs/exp_*/` whose folder name includes “trainEnh1” (case-insensitive). The experiment must contain a `summary_trainval.json` file and a `best_heads.pt` file for each of the three seeds (1337, 2024, 7777). After selecting the experiment, the cell checks again that all three head files are present. From the experiment summary, it reads a **single global decision threshold** from `val_optimal_threshold.mean_sd.mean`, which is the mean of the per-seed Youden J thresholds found on the validation split. This same threshold is used for all seeds. If the value is missing or invalid, the code falls back to 0.5 and records that this fallback was used.
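
The threshold lookup with its fallback could look roughly like the sketch below. The helper name `read_global_threshold` is hypothetical, and treating "invalid" as "non-finite or outside [0, 1]" is an assumption made for illustration; the cell itself may apply a different validity rule.

```python
import math


def read_global_threshold(summary, fallback=0.5):
    """Return (threshold, used_fallback) from a parsed summary_trainval.json dict."""
    try:
        value = float(summary["val_optimal_threshold"]["mean_sd"]["mean"])
    except (KeyError, TypeError, ValueError):
        # Key missing, wrong container type, or value not convertible to float.
        return fallback, True
    # Assumption: "invalid" means non-finite or outside the closed interval [0, 1].
    if not math.isfinite(value) or not 0.0 <= value <= 1.0:
        return fallback, True
    return value, False


ok = read_global_threshold({"val_optimal_threshold": {"mean_sd": {"mean": 0.42}}})
bad = read_global_threshold({"val_optimal_threshold": {"mean_sd": {"mean": float("nan")}}})
missing = read_global_threshold({})
print(ok, bad, missing)
```

Returning the fallback flag alongside the value makes it straightforward to record in the run summary whether 0.5 was used.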

For each seed, the cell rebuilds the model using a **frozen Wav2Vec2 backbone with two task-specific heads** (one for vowel clips and one for other clips), loads the saved head weights, and runs inference on the full D2 test set. For each seed it writes a dedicated output folder that includes a `predictions.csv` file with clip-level results and metadata (clip path, true label, PD probability, sex, speaker ID, task group, seed, run tag, and threshold used), a `metrics.json` file with AUROC, threshold-based metrics, fairness values, and artifact paths, and several plots. The plots include a ROC curve, an overall confusion matrix, and confusion matrices split by sex for M and F when enough data is available.
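
The two-head routing can be sketched without PyTorch as follows. The function names and logits are hypothetical, and converting the two-class logits to a PD probability with a softmax over `[healthy, pd]` is an assumption made for illustration.

```python
import math


def pd_probability(logits):
    """Softmax probability of class 1 (PD) from a pair of logits [healthy, pd]."""
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in logits]
    return exps[1] / sum(exps)


def route_by_task_group(logits_vowel, logits_other, task_groups):
    """Pick the vowel-head logits for "vowel" clips and the other-head logits otherwise."""
    return [
        lv if tg == "vowel" else lo
        for lv, lo, tg in zip(logits_vowel, logits_other, task_groups)
    ]


# Hypothetical per-head logits for three clips.
logits_vowel = [[0.0, 2.0], [1.0, 1.0], [3.0, 0.0]]
logits_other = [[2.0, 0.0], [0.0, 0.0], [0.0, 3.0]]
task_groups = ["vowel", "other", "other"]

routed = route_by_task_group(logits_vowel, logits_other, task_groups)
probs = [pd_probability(pair) for pair in routed]
print([round(p, 4) for p in probs])
```

Both heads see every pooled embedding, but only the logits from the head that matches the clip's task group contribute to that clip's prediction.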

After all three seeds finish, the cell prints and saves an overall summary. This includes the **mean test AUROC with a 95% t-based confidence interval (n=3)**, along with threshold-based metrics reported as **mean ± standard deviation** across seeds. It also computes the paper-ready fairness metric **H3** on the D2 test set at the same global threshold, defined as **ΔFNR = FNR(F) − FNR(M)** together with its absolute value, where FNR is calculated only on true Parkinson’s cases. Finally, the cell writes results to a structured directory under `monolingual_test_runs/`, updates a pointer to the latest run, appends the summary to a history log, and writes a stable tag-based pointer. It then mirrors key metadata into builder-style `config/` and `logs/` folders (backing up existing files if needed) and unassigns the Colab runtime to stop the GPU.
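
The two summary statistics can be sketched in plain Python as shown below. The function names and toy inputs are hypothetical; the critical value is the df=2 constant used elsewhere in this notebook, so the interval helper only accepts exactly three values.

```python
import math

T_CRIT_DF2_95 = 4.302652729911275  # two-sided 95% critical value of Student's t, df = 2


def mean_ci95_three_seeds(values):
    """Mean and 95% t-interval for exactly three per-seed values."""
    if len(values) != 3:
        raise ValueError("T_CRIT_DF2_95 is only valid for n = 3")
    mean = sum(values) / 3
    sd = math.sqrt(sum((v - mean) ** 2 for v in values) / 2)  # sample SD (ddof = 1)
    half_width = T_CRIT_DF2_95 * sd / math.sqrt(3)
    return mean, (mean - half_width, mean + half_width)


def fnr(y_true, y_pred):
    """False negative rate over true PD cases (label 1); NaN when there are none."""
    preds_for_pd = [p for t, p in zip(y_true, y_pred) if t == 1]
    if not preds_for_pd:
        return float("nan")
    return sum(1 for p in preds_for_pd if p == 0) / len(preds_for_pd)


def delta_fnr(y_true, y_pred, sex):
    """Signed gap FNR(F) - FNR(M) and its absolute value."""
    def subset(group):
        pairs = [(t, p) for t, p, s in zip(y_true, y_pred, sex) if s == group]
        return [t for t, _ in pairs], [p for _, p in pairs]

    signed = fnr(*subset("F")) - fnr(*subset("M"))
    return signed, abs(signed)


# Hypothetical toy inputs (not results from this study).
mean_auc, ci = mean_ci95_three_seeds([0.80, 0.90, 1.00])
y_true = [1, 1, 1, 1, 0, 0]
y_pred = [1, 0, 1, 1, 0, 1]
sex = ["F", "F", "M", "M", "F", "M"]
signed, absolute = delta_fnr(y_true, y_pred, sex)
print(mean_auc, ci, signed, absolute)
```

A positive signed ΔFNR means PD cases are missed more often among female speakers than male speakers at the chosen threshold; the absolute value reports the size of the gap regardless of direction.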

In [None]:
# =========================
# D7 trainEnh1 Heads → D2 Test Evaluation
# Inputs: D2 manifest (test split) and latest matching D7 trainval experiment (best_heads + summary_trainval)
# Outputs: per-seed predictions and metrics, plots, run summary files, plus latest and pointer files
# =========================
# =========================
# TEST ONLY (CRASH-PROOF, WITH PROGRESS + STORED METRICS) — D7 ENHANCED → D2 TEST
# - Evaluates the D7 ENHANCED trained heads (trained on train_enh1) on the D2 TEST split only
# - Uses ONLY D2: <D2_OUT_ROOT>/manifests/manifest_all.csv  (TEST split)
# - Loads finished heads from MOST RECENT D7 *ENHANCED* trainval experiment under:
#     <D7_OUT_ROOT>/trainval_runs/exp_*/run_D7_seed{seed}/best_heads.pt
#   Selection rule: exp folder name must contain substring "trainEnh1" (case-insensitive)
#   and must contain all three seeds + summary_trainval.json.
# - Uses the SINGLE MEAN VAL-optimal threshold stored by that D7 trainval in:
#     summary_trainval.json -> val_optimal_threshold.mean_sd.mean
#   (No VAL threshold recomputation in this cell)
# - Evaluates 3 seeds separately (1337, 2024, 7777)
# - Reports:
#     * mean Test AUROC ± 95% CI (t, n=3)
#     * A single threshold used for ALL seeds (mean val-opt threshold) + note if fallback to 0.5
#     * Threshold metrics on D2 TEST @ that single threshold as mean ± SD
#     * FAIRNESS (H3) on D2 TEST @ that single threshold as mean ± SD
#     * Confusion charts split by sex (M/F) on D2 TEST @ that single threshold
# - Writes all artifacts under:
#     <D7_OUT_ROOT>/monolingual_test_runs/run_<FULL_TRAINVAL_EXP_TAG>__<RUN_STAMP>/...
#   plus:
#     * monolingual_test_runs/last_run_pointer.json (intentional overwrite)
#     * monolingual_test_runs/summary_latest.json (intentional overwrite)
#     * monolingual_test_runs/history_index.jsonl (append-only)
#     * monolingual_test_runs/run_<TAG>/tag_run_pointer.json (never overwritten across tags)
#
# D2-SPECIFIC NOTE (the manifest):
# - sex is encoded as the exact strings "male" / "female" (case-sensitive)
#   This code maps: "male" -> M, "female" -> F, and anything else -> UNK
#
# GUARDS:
# A) Hard-assert that the dataset id inferred from the D2 manifest equals "D2"
# B) Re-assert all best_heads.pt exist immediately after chosen_exp is selected
# =========================

import os, json, math, random, time, warnings
from pathlib import Path
from datetime import datetime

import numpy as np
import pandas as pd
import soundfile as sf

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

from transformers import Wav2Vec2Model

from sklearn.metrics import roc_auc_score, roc_curve, confusion_matrix, matthews_corrcoef
import matplotlib.pyplot as plt
from tqdm.auto import tqdm

# -------------------------
# 0) Environment safety checks
# -------------------------
# Prevents accidental import of local files named like core libraries.
if os.path.exists("/content/torch.py") or os.path.exists("/content/torch/__init__.py"):
    raise RuntimeError("Found local /content/torch.py or /content/torch/ that shadows PyTorch. Rename/remove it and restart runtime.")
if os.path.exists("/content/transformers.py") or os.path.exists("/content/transformers/__init__.py"):
    raise RuntimeError("Found local /content/transformers.py or /content/transformers/ that shadows Hugging Face Transformers. Rename/remove it and restart runtime.")

# -------------------------
# 1) Drive access (Colab-friendly)
# -------------------------
# Mounts Google Drive when needed; safe to skip outside Colab.
try:
    from google.colab import drive  # type: ignore
    if not os.path.isdir("/content/drive/MyDrive"):
        drive.mount("/content/drive")
except Exception:
    pass

# -------------------------
# 2) Resolve input roots and basic run helpers
# -------------------------
# Inputs:
# - D2 manifest is read from the D2 root
# - D7 trainval experiments are searched under the D7 root
# Outputs:
# - all test artifacts are written under the D7 root
D7_OUT_ROOT_FALLBACK = "/content/drive/MyDrive/AI_PD_Project/Datasets/D7-Multilingual (D1_D4_D5v2_D6)/preprocessed_v1"
D2_OUT_ROOT_FALLBACK = "/content/drive/MyDrive/AI_PD_Project/Datasets/D2-Slovak (EWA-DB)/EWA-DB/preprocessed_v1"

D7_OUT_ROOT = str(globals().get("D7_OUT_ROOT", D7_OUT_ROOT_FALLBACK))
D2_OUT_ROOT = str(globals().get("D2_OUT_ROOT", D2_OUT_ROOT_FALLBACK))

D2_MANIFEST_ALL = f"{D2_OUT_ROOT}/manifests/manifest_all.csv"

# Keep DX_OUT_ROOT aligned with the run root (D7), since the shared code style refers to DX_OUT_ROOT.
DX_OUT_ROOT = D7_OUT_ROOT
globals()["DX_OUT_ROOT"] = DX_OUT_ROOT
globals()["D7_OUT_ROOT"] = D7_OUT_ROOT
globals()["D2_OUT_ROOT"] = D2_OUT_ROOT

# Timestamp used in output folder naming and backups.
RUN_STAMP = time.strftime("%Y%m%d_%H%M%S")

def _backup_if_exists(p: Path):
    # Moves an existing file aside before overwriting it.
    if p.exists():
        bak = p.with_suffix(p.suffix + f".bak_{RUN_STAMP}")
        try:
            p.rename(bak)
            print(f"BACKUP: {str(p)} -> {str(bak)}")
        except Exception as e:
            raise RuntimeError(f"Could not backup existing file before overwrite: {str(p)}. Error: {repr(e)}")

def _sanitize_tag(s: str) -> str:
    # Makes a filesystem-safe tag for folder names.
    s = str(s).strip()
    out = []
    for ch in s:
        if ch.isalnum() or ch in ["-", "_"]:
            out.append(ch)
        else:
            out.append("_")
    out = "".join(out).strip("_")
    return out if out else "tag"

# -------------------------
# 3) Fixed evaluation settings
# -------------------------
# Keeps evaluation consistent across runs and seeds.
SEEDS          = [1337, 2024, 7777]
BACKBONE_CKPT  = "facebook/wav2vec2-base"
SR_EXPECTED    = 16000
TINY_THRESH    = 1e-4

# Batch settings for inference (no training happens in this cell).
EFFECTIVE_BS   = 64
PER_DEVICE_BS  = 16
GRAD_ACCUM     = max(1, EFFECTIVE_BS // PER_DEVICE_BS)

# Head architecture settings (must match trainval).
DROPOUT_P      = 0.2

# Dataloader stability settings (safe defaults for Colab).
NUM_WORKERS    = 0
PIN_MEMORY     = False

# Mixed precision speeds up GPU inference when available.
USE_AMP        = True

# Only experiments whose name includes this substring are considered.
REQUIRED_EXP_SUBSTRING = "trainEnh1"  # case-insensitive

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
try:
    torch.set_float32_matmul_precision("high")
except Exception:
    pass

# Quiet known, non-actionable warnings.
warnings.filterwarnings("ignore", message="None of the inputs have requires_grad=True. Gradients will be None")
warnings.filterwarnings("ignore", message="Passing `gradient_checkpointing` to a config initialization is deprecated")
warnings.filterwarnings("ignore", category=UserWarning, module="huggingface_hub")

print("D7_OUT_ROOT:", D7_OUT_ROOT)
print("D2_OUT_ROOT:", D2_OUT_ROOT)
print("D2_MANIFEST_ALL:", D2_MANIFEST_ALL)
print("DEVICE:", DEVICE)
if DEVICE.type == "cuda":
    print("GPU:", torch.cuda.get_device_name(0))
print("PER_DEVICE_BS:", PER_DEVICE_BS, "| EFFECTIVE_BS:", PER_DEVICE_BS * GRAD_ACCUM)
print("NUM_WORKERS:", NUM_WORKERS, "| PIN_MEMORY:", PIN_MEMORY)
print("USE_AMP:", bool(USE_AMP and DEVICE.type == "cuda"))
print("Enhanced exp required substring (case-insensitive):", REQUIRED_EXP_SUBSTRING)

# -------------------------
# 4) Load D2 manifest and build the TEST table
# -------------------------
# Reads D2 manifest_all.csv, checks required columns, and keeps only split=="test".
# Also confirms the manifest truly belongs to dataset "D2" (Guard A).
if not os.path.exists(D2_MANIFEST_ALL):
    raise FileNotFoundError(f"Missing D2 manifest_all.csv: {D2_MANIFEST_ALL}")

m_all = pd.read_csv(D2_MANIFEST_ALL)

# Required fields for evaluation and subgroup reporting.
req_cols = {"split", "clip_path", "label_num", "task", "sex", "age"}
missing = [c for c in sorted(req_cols) if c not in m_all.columns]
if missing:
    raise ValueError(f"D2 manifest missing required columns: {missing}. Found: {list(m_all.columns)}")

# Infer dataset id from the manifest to avoid mixing roots by mistake.
if "dataset" in m_all.columns and m_all["dataset"].notna().any():
    # Drop missing values before casting to str, so NaN is never counted as the literal id "nan".
    d2_dataset_id = str(m_all["dataset"].dropna().astype(str).value_counts().idxmax())
    m_all = m_all[m_all["dataset"].astype(str) == d2_dataset_id].copy()
else:
    d2_dataset_id = "DX"

# --------- GUARD A ----------
# Stops immediately if the loaded manifest is not D2.
if d2_dataset_id != "D2":
    raise RuntimeError(
        f"Expected D2 dataset_id=='D2' but got {d2_dataset_id!r}. "
        "This usually means D2_OUT_ROOT is wrong or the manifest is not D2. "
        f"D2_OUT_ROOT={D2_OUT_ROOT}"
    )

# Keep a minimal, consistent column set for inference + reporting.
keep_cols = ["clip_path", "label_num", "task", "speaker_id", "sex", "age", "duration_sec", "split"]
for c in keep_cols:
    if c not in m_all.columns:
        m_all[c] = np.nan
m_all = m_all[keep_cols].copy()

# IMPORTANT: TEST split only
test_df = m_all[m_all["split"].astype(str) == "test"].reset_index(drop=True)

print(f"\nD2 dataset inferred: {d2_dataset_id}")
print(f"D2 TEST rows: {len(test_df)}")
print("D2 TEST label counts:", test_df["label_num"].value_counts(dropna=False).to_dict())
print("D2 TEST sex counts (raw):", test_df["sex"].value_counts(dropna=False).to_dict())

if len(test_df) == 0:
    raise RuntimeError("After filtering to split=='test', D2 manifest has 0 rows.")

# -------------------------
# 5) Confirm all test audio files exist
# -------------------------
# Prevents running a long evaluation that will fail midway due to missing clips.
def _fail_fast_missing_paths(df: pd.DataFrame, name: str):
    missing_paths = []
    for p in tqdm(df["clip_path"].astype(str).tolist(), desc=f"Check {name} clip_path exists", dynamic_ncols=True):
        if not os.path.exists(p):
            missing_paths.append(p)
            if len(missing_paths) >= 10:
                break
    if missing_paths:
        raise FileNotFoundError(f"{name}: missing clip_path(s). Examples: {missing_paths}")

_fail_fast_missing_paths(test_df, "D2 TEST")

# -------------------------
# 6) Add task grouping used by the two-head model
# -------------------------
# "vowl" clips go through the vowel head; everything else goes through the other head.
def _task_group(task_val) -> str:
    return "vowel" if str(task_val) == "vowl" else "other"

test_df["task_group"] = test_df["task"].apply(_task_group)

# -------------------------
# 6.5) Normalize sex values for subgroup reporting
# -------------------------
# D2 uses exact strings "male"/"female"; everything else is treated as unknown.
def normalize_sex_d2_case_sensitive(val) -> str:
    if pd.isna(val):
        return "UNK"
    if val == "male":
        return "M"
    if val == "female":
        return "F"
    return "UNK"

test_df["sex_norm"] = test_df["sex"].apply(normalize_sex_d2_case_sensitive)
print("D2 TEST sex counts (normalized):", test_df["sex_norm"].value_counts(dropna=False).to_dict())
if (test_df["sex_norm"] == "UNK").any():
    print("NOTE: Some D2 'sex' values were not exactly 'male'/'female' and were mapped to 'UNK'.")

# -------------------------
# 7) Dataset and batching logic
# -------------------------
# Loads audio, checks sample rate, and builds attention masks:
# - vowel: mask tries to ignore trailing silence
# - other: mask keeps all samples
class AudioManifestDataset(Dataset):
    def __init__(self, df: pd.DataFrame):
        self.df = df.reset_index(drop=True)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx: int):
        row = self.df.iloc[idx]
        clip_path = str(row["clip_path"])
        label = int(row["label_num"])
        task_group = str(row["task_group"])
        sex_norm = str(row["sex_norm"])
        speaker_id = row["speaker_id"] if "speaker_id" in row.index else np.nan

        if not os.path.exists(clip_path):
            raise FileNotFoundError(f"Missing clip_path: {clip_path}")

        y, sr = sf.read(clip_path, always_2d=False)
        if isinstance(y, np.ndarray) and y.ndim > 1:
            y = np.mean(y, axis=1)
        y = np.asarray(y, dtype=np.float32)

        # Hard check to match the model’s expected input rate.
        if int(sr) != SR_EXPECTED:
            raise ValueError(f"Sample rate mismatch (expected {SR_EXPECTED}): {clip_path} has sr={sr}")

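        # Attention mask is used after feature extraction to ignore padded or silent regions.
        attn = np.ones((len(y),), dtype=np.int64)
        if task_group == "vowel":
            # Zero the mask after the last sample whose magnitude exceeds TINY_THRESH.
            # If no sample exceeds it, the mask stays all ones (same as for "other" clips).
            # The float64 cast keeps the comparison identical to a per-sample float() check.
            loud = np.flatnonzero(np.abs(y.astype(np.float64)) > TINY_THRESH)
            if loud.size > 0:
                attn[int(loud[-1]) + 1:] = 0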

        return {
            "input_values": torch.from_numpy(y),
            "attention_mask": torch.from_numpy(attn),
            "labels": torch.tensor(label, dtype=torch.long),
            "task_group": task_group,
            "sex_norm": sex_norm,
            "clip_path": clip_path,
            "speaker_id": speaker_id,
        }

def collate_fn(batch):
    # Pads variable-length waveforms to the longest clip in the batch.
    max_len = int(max(b["input_values"].numel() for b in batch))
    input_vals, attn_masks, labels = [], [], []
    task_groups, sex_norms, clip_paths, speaker_ids = [], [], [], []
    for b in batch:
        x = b["input_values"]
        a = b["attention_mask"]
        pad = max_len - x.numel()
        if pad > 0:
            x = torch.cat([x, torch.zeros(pad, dtype=x.dtype)], dim=0)
            a = torch.cat([a, torch.zeros(pad, dtype=a.dtype)], dim=0)
        input_vals.append(x)
        attn_masks.append(a)
        labels.append(b["labels"])
        task_groups.append(b["task_group"])
        sex_norms.append(b["sex_norm"])
        clip_paths.append(b["clip_path"])
        speaker_ids.append(b["speaker_id"])
    return {
        "input_values": torch.stack(input_vals, dim=0),
        "attention_mask": torch.stack(attn_masks, dim=0),
        "labels": torch.stack(labels, dim=0),
        "task_group": task_groups,
        "sex_norm": sex_norms,
        "clip_path": clip_paths,
        "speaker_id": speaker_ids,
    }

# -------------------------
# 8) Model definition (same structure as training)
# -------------------------
# Backbone is frozen; only the two heads are loaded from best_heads.pt.
class Wav2Vec2TwoHeadClassifier(nn.Module):
    def __init__(self, ckpt: str, dropout_p: float):
        super().__init__()
        self.backbone = Wav2Vec2Model.from_pretrained(ckpt, use_safetensors=True)
        for p in self.backbone.parameters():
            p.requires_grad = False

        H = int(self.backbone.config.hidden_size)
        self.pre_vowel = nn.Sequential(nn.LayerNorm(H), nn.Dropout(float(dropout_p)))
        self.pre_other = nn.Sequential(nn.LayerNorm(H), nn.Dropout(float(dropout_p)))
        self.head_vowel = nn.Linear(H, 2)
        self.head_other = nn.Linear(H, 2)

    def masked_mean_pool(self, last_hidden, attn_mask_samples):
        # Pools time frames using the attention mask mapped into feature space.
        feat_mask = self.backbone._get_feature_vector_attention_mask(last_hidden.shape[1], attn_mask_samples)
        feat_mask = feat_mask.to(last_hidden.device).unsqueeze(-1).type_as(last_hidden)
        masked = last_hidden * feat_mask
        denom = feat_mask.sum(dim=1).clamp(min=1.0)
        return masked.sum(dim=1) / denom

    def _heads_fp32(self, x_any: torch.Tensor, head: nn.Module) -> torch.Tensor:
        # Runs heads in float32 for numeric stability.
        x = x_any.float()
        if x.is_cuda:
            with torch.amp.autocast(device_type="cuda", enabled=False):
                return head(x)
        return head(x)

    def forward_logits(self, input_values, attention_mask, task_group):
        # Backbone forward is frozen and runs without gradients.
        with torch.no_grad():
            out = self.backbone(input_values=input_values, attention_mask=attention_mask)
            last_hidden = out.last_hidden_state

        pooled = self.masked_mean_pool(last_hidden, attention_mask)

        z_v = self.pre_vowel(pooled.float())
        z_o = self.pre_other(pooled.float())

        logits_v = self._heads_fp32(z_v, self.head_vowel)
        logits_o = self._heads_fp32(z_o, self.head_other)

        # Routes each item to the correct head based on task_group.
        is_vowel = torch.tensor([tg == "vowel" for tg in task_group], device=pooled.device, dtype=torch.bool)
        logits = logits_o.clone()
        if is_vowel.any():
            logits[is_vowel] = logits_v[is_vowel]
        return logits

# -------------------------
# 9) Metrics and plot writers
# -------------------------
# Computes AUROC, threshold metrics, ROC curves, and confusion matrix images.
def compute_auc(y_true, y_prob):
    y_true = np.asarray(y_true, dtype=np.int64)
    y_prob = np.asarray(y_prob, dtype=np.float64)
    if len(np.unique(y_true)) < 2:
        return float("nan")
    return float(roc_auc_score(y_true, y_prob))

def compute_threshold_metrics(y_true, y_prob, thr=0.5):
    # Standard classification metrics at a fixed probability threshold.
    y_true = np.asarray(y_true, dtype=np.int64)
    y_prob = np.asarray(y_prob, dtype=np.float64)
    y_pred = (y_prob >= float(thr)).astype(np.int64)

    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
    TN, FP, FN, TP = int(cm[0, 0]), int(cm[0, 1]), int(cm[1, 0]), int(cm[1, 1])

    eps = 1e-12
    acc = (TP + TN) / max(1, (TP + TN + FP + FN))
    prec = TP / (TP + FP + eps)
    rec = TP / (TP + FN + eps)     # sensitivity
    f1 = 2 * prec * rec / (prec + rec + eps)
    spec = TN / (TN + FP + eps)

    try:
        mcc = float(matthews_corrcoef(y_true, y_pred)) if len(np.unique(y_pred)) > 1 else float("nan")
    except Exception:
        mcc = float("nan")

    pval = float("nan")
    try:
        from scipy.stats import fisher_exact  # type: ignore
        _, pval = fisher_exact([[TN, FP], [FN, TP]], alternative="two-sided")
        pval = float(pval)
    except Exception:
        pval = float("nan")

    return {
        "threshold": float(thr),
        "tn": TN, "fp": FP, "fn": FN, "tp": TP,
        "accuracy": float(acc),
        "precision": float(prec),
        "recall": float(rec),
        "f1": float(f1),
        "sensitivity": float(rec),
        "specificity": float(spec),
        "mcc": float(mcc),
        "p_value_fisher_two_sided": float(pval),
    }

def save_roc_curve_png(y_true, y_prob, out_png, title_suffix="Test"):
    # Saves an ROC curve image for quick visual checks.
    fpr, tpr, _ = roc_curve(y_true, y_prob)
    plt.figure()
    plt.plot(fpr, tpr)
    plt.plot([0, 1], [0, 1], linestyle="--")
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    plt.title(f"ROC Curve ({title_suffix})")
    plt.tight_layout()
    plt.savefig(out_png, dpi=150)
    plt.close()

def save_confusion_png(y_true, y_prob, out_png, thr=0.5, title_suffix="Test"):
    # Saves a confusion matrix image at the chosen threshold.
    y_true = np.asarray(y_true, dtype=np.int64)
    y_prob = np.asarray(y_prob, dtype=np.float64)
    y_pred = (y_prob >= float(thr)).astype(int)
    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])

    plt.figure()
    plt.imshow(cm)
    plt.xticks([0, 1], ["Healthy(0)", "PD(1)"])
    plt.yticks([0, 1], ["Healthy(0)", "PD(1)"])
    plt.xlabel("Predicted")
    plt.ylabel("True")
    plt.title(f"Confusion Matrix ({title_suffix}, thr={thr:.4f})")
    for (i, j), v in np.ndenumerate(cm):
        plt.text(j, i, str(v), ha="center", va="center")
    plt.tight_layout()
    plt.savefig(out_png, dpi=150)
    plt.close()

# -------------------------
# 9.5) Fairness metric (sex-based ΔFNR)
# -------------------------
# Computes FNR per sex and ΔFNR = FNR(F) - FNR(M) using PD-only true labels.
def compute_fnr_by_group_signed(y_true, y_prob, groups, thr=0.5):
    y_true = np.asarray(y_true, dtype=np.int64)
    y_prob = np.asarray(y_prob, dtype=np.float64)
    groups = np.asarray(list(groups), dtype=object)
    y_pred = (y_prob >= float(thr)).astype(np.int64)

    out = {}
    for g in sorted(set(groups.tolist())):
        mask_g = (groups == g)
        if int(mask_g.sum()) == 0:
            continue

        pos_mask = mask_g & (y_true == 1)
        n_pos = int(pos_mask.sum())
        if n_pos == 0:
            out[g] = {"n_total": int(mask_g.sum()), "n_pos": 0, "tp": 0, "fn": 0, "fnr": float("nan")}
            continue

        tp = int(((y_pred == 1) & pos_mask).sum())
        fn = int(((y_pred == 0) & pos_mask).sum())
        fnr = float(fn / max(1, (fn + tp)))
        out[g] = {"n_total": int(mask_g.sum()), "n_pos": int(n_pos), "tp": int(tp), "fn": int(fn), "fnr": float(fnr)}

    fnr_m = out.get("M", {}).get("fnr", float("nan"))
    fnr_f = out.get("F", {}).get("fnr", float("nan"))
    if (not np.isnan(fnr_m)) and (not np.isnan(fnr_f)):
        delta_signed = float(fnr_f - fnr_m)
        delta_abs = float(abs(delta_signed))
    else:
        delta_signed = float("nan")
        delta_abs = float("nan")

    return out, delta_signed, delta_abs

def compute_confusion_counts(y_true, y_prob, thr=0.5):
    # Convenience wrapper for confusion counts at a threshold.
    y_true = np.asarray(y_true, dtype=np.int64)
    y_prob = np.asarray(y_prob, dtype=np.float64)
    y_pred = (y_prob >= float(thr)).astype(np.int64)
    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
    return {"TN": int(cm[0, 0]), "FP": int(cm[0, 1]), "FN": int(cm[1, 0]), "TP": int(cm[1, 1])}

def compute_confusion_by_group(y_true, y_prob, groups, thr=0.5):
    # Builds confusion counts per group label (M/F/UNK).
    y_true = np.asarray(y_true, dtype=np.int64)
    y_prob = np.asarray(y_prob, dtype=np.float64)
    groups = np.asarray(list(groups), dtype=object)
    out = {}
    for g in sorted(set(groups.tolist())):
        mask = (groups == g)
        if int(mask.sum()) == 0:
            continue
        out[g] = {"n": int(mask.sum()), "confusion": compute_confusion_counts(y_true[mask], y_prob[mask], thr=thr)}
    return out

# -------------------------
# 10) Reproducibility seed setter
# -------------------------
# Keeps per-seed evaluation deterministic.
def set_all_seeds(seed: int):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# -------------------------
# 11) Find the latest matching D7 trainval experiment
# -------------------------
# Searches exp_* folders (newest first) and picks the first that:
# - includes REQUIRED_EXP_SUBSTRING in the folder name
# - contains all three best_heads.pt files and summary_trainval.json
TRAINVAL_ROOT = Path(D7_OUT_ROOT) / "trainval_runs"
if not TRAINVAL_ROOT.exists():
    raise FileNotFoundError(f"Missing trainval_runs folder under D7_OUT_ROOT: {str(TRAINVAL_ROOT)}")

exp_dirs = sorted([p for p in TRAINVAL_ROOT.glob("exp_*") if p.is_dir()], key=lambda p: p.stat().st_mtime, reverse=True)
if not exp_dirs:
    raise FileNotFoundError(f"No exp_* folders found under: {str(TRAINVAL_ROOT)}")

train_dataset_id = "D7"  # expected naming from trainval code (run_D7_seedXXXX)

def _is_enhanced_exp_dir(exp_path: Path, required_substring: str) -> bool:
    return (required_substring.lower() in exp_path.name.lower())

def _has_all_seeds_and_summary(exp_path: Path, dataset_id: str, seeds: list) -> bool:
    summary_path = exp_path / "summary_trainval.json"
    if not summary_path.exists():
        return False
    for s in seeds:
        p = exp_path / f"run_{dataset_id}_seed{s}" / "best_heads.pt"
        if not p.exists():
            return False
    return True

chosen_exp = None
for ed in exp_dirs:
    if not _is_enhanced_exp_dir(ed, REQUIRED_EXP_SUBSTRING):
        continue
    if _has_all_seeds_and_summary(ed, train_dataset_id, SEEDS):
        chosen_exp = ed
        break

if chosen_exp is None:
    sample = exp_dirs[0]
    raise FileNotFoundError(
        "Could not find a recent D7 *ENHANCED* trainval experiment folder that:\n"
        f"  (1) contains substring '{REQUIRED_EXP_SUBSTRING}' (case-insensitive) in the exp folder name, and\n"
        "  (2) contains all 3 best_heads.pt files + summary_trainval.json.\n\n"
        f"Most recent exp checked (for reference): {str(sample)}"
    )

# Full experiment folder name is used as the run tag.
FULL_TRAINVAL_EXP_TAG = chosen_exp.name
TAG_SAFE = _sanitize_tag(FULL_TRAINVAL_EXP_TAG)
RUN_PARENT_DIRNAME = f"run_{TAG_SAFE}__{RUN_STAMP}"

# Output layout:
# - RUN_ROOT is the unique folder for this run (tag + timestamp)
# - TAG_ROOT is the stable folder used to store a tag-specific pointer file
TEST_ROOT = Path(D7_OUT_ROOT) / "monolingual_test_runs"
RUN_ROOT  = TEST_ROOT / RUN_PARENT_DIRNAME
RUN_ROOT.mkdir(parents=True, exist_ok=True)

TAG_ROOT = TEST_ROOT / f"run_{TAG_SAFE}"
TAG_ROOT.mkdir(parents=True, exist_ok=True)

# Additional config/log locations that align with the builder naming style.
cfg_dir  = Path(D7_OUT_ROOT) / "config" / f"D7_{TAG_SAFE}_on_D2_Test"
logs_dir = Path(D7_OUT_ROOT) / "logs"   / f"D7_{TAG_SAFE}_on_D2_Test"
cfg_dir.mkdir(parents=True, exist_ok=True)
logs_dir.mkdir(parents=True, exist_ok=True)

RUN_CONFIG_PATH       = cfg_dir / "run_config.json"
WARNINGS_CSV_PATH     = logs_dir / "preprocess_warnings.csv"
DATASET_SUMMARY_PATH  = logs_dir / "dataset_summary.json"

print("\nUsing D7 ENHANCED Train+Val experiment folder:")
print(" ", str(chosen_exp))
print("FULL_TRAINVAL_EXP_TAG:", FULL_TRAINVAL_EXP_TAG)
print("RUN_ROOT:", str(RUN_ROOT))
print("cfg_dir:", str(cfg_dir))
print("logs_dir:", str(logs_dir))

# --------- GUARD B ----------
# Confirms expected head files exist after experiment selection.
for s in SEEDS:
    p = chosen_exp / f"run_{train_dataset_id}_seed{s}" / "best_heads.pt"
    if not p.exists():
        raise RuntimeError(f"Trainval artifact missing after choosing exp. Missing: {str(p)}")

# Read trainval summary to fetch the global (mean) threshold.
summary_trainval_path = chosen_exp / "summary_trainval.json"
with open(summary_trainval_path, "r", encoding="utf-8") as f:
    d7_trainval_summary = json.load(f)

# -------------------------
# 11.5) Select one global threshold for all test seeds
# -------------------------
# Uses mean val-opt threshold from trainval summary; falls back to 0.5 if missing.
val_opt_obj = (d7_trainval_summary or {}).get("val_optimal_threshold", {}) or {}
thr_mean_sd = (val_opt_obj.get("mean_sd", {}) or {})

def _get_mean_val_opt_threshold() -> float:
    try:
        return float(thr_mean_sd.get("mean", float("nan")))
    except Exception:
        return float("nan")

THR_MEAN_FROM_TRAINVAL = _get_mean_val_opt_threshold()

if np.isnan(THR_MEAN_FROM_TRAINVAL):
    THR_USED_GLOBAL = 0.5
    THR_GLOBAL_NOTE = (
        "Mean val-opt threshold was missing/NaN in D7 enhanced summary_trainval.json. "
        "Fallback: THR_USED_GLOBAL=0.5 for ALL seeds."
    )
else:
    THR_USED_GLOBAL = float(THR_MEAN_FROM_TRAINVAL)
    THR_GLOBAL_NOTE = None

print("\nVAL-opt threshold selection for TEST (GLOBAL):")
print("  Source: summary_trainval.json -> val_optimal_threshold.mean_sd.mean")
print(f"  THR_USED_GLOBAL: {THR_USED_GLOBAL:.6f}")
if THR_GLOBAL_NOTE is not None:
    print("  NOTE:", THR_GLOBAL_NOTE)

# -------------------------
# 13) Build the D2 test DataLoader
# -------------------------
# Feeds audio batches to the model for inference.
test_ds = AudioManifestDataset(test_df)
test_loader = DataLoader(
    test_ds,
    batch_size=PER_DEVICE_BS,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=PIN_MEMORY,
    collate_fn=collate_fn,
)

# -------------------------
# 14) Warm-up loader to catch issues early
# -------------------------
# Loads a few batches to surface shape or file problems before full inference.
print("\nWarm-up: loading up to 3 D2 TEST batches...")
t0 = time.time()

def _warmup(loader, name):
    nb = len(loader)
    wb = min(3, nb)
    if wb == 0:
        raise RuntimeError(f"{name} DataLoader has 0 batches. Check df length and PER_DEVICE_BS.")
    it = iter(loader)
    for i in range(wb):
        _ = next(it)
        print(f"  loaded warmup {name} batch {i+1}/{wb}")

_warmup(test_loader, "D2 TEST")
print(f"Warm-up done in {time.time()-t0:.2f}s")

# -------------------------
# 15) Load trained heads into a fresh model instance
# -------------------------
# Loads only the head parameters saved by trainval (backbone stays frozen).
def load_heads_into_model(model: Wav2Vec2TwoHeadClassifier, best_heads_path: Path):
    if not best_heads_path.exists():
        raise FileNotFoundError(f"Missing best_heads.pt: {str(best_heads_path)}")
    state = torch.load(str(best_heads_path), map_location="cpu")
    model.pre_vowel.load_state_dict(state["pre_vowel"], strict=True)
    model.pre_other.load_state_dict(state["pre_other"], strict=True)
    model.head_vowel.load_state_dict(state["head_vowel"], strict=True)
    model.head_other.load_state_dict(state["head_other"], strict=True)
    return model

# -------------------------
# 16) Inference helper that also returns metadata
# -------------------------
# Returns:
# - y_true and y_score for metrics
# - sex_norm, task_group, clip_path, speaker_id for predictions.csv and subgroup reports
def _infer_probs_with_meta(loader, model, desc):
    use_amp = bool(USE_AMP and DEVICE.type == "cuda")
    all_probs, all_true, all_sex = [], [], []
    all_clip, all_spk, all_task = [], [], []

    pbar = tqdm(loader, desc=desc, dynamic_ncols=True)
    with torch.inference_mode():
        for batch in pbar:
            input_values = batch["input_values"].to(DEVICE, non_blocking=False)
            attention_mask = batch["attention_mask"].to(DEVICE, non_blocking=False)
            labels = batch["labels"].to(DEVICE, non_blocking=False)
            task_group = batch["task_group"]
            sex_norm = batch["sex_norm"]
            clip_paths = batch["clip_path"]
            speaker_ids = batch["speaker_id"]

            with torch.amp.autocast(device_type="cuda", enabled=use_amp):
                logits = model.forward_logits(input_values, attention_mask, task_group)

            probs = torch.softmax(logits.float(), dim=-1)[:, 1].detach().cpu().numpy().astype(np.float64)
            all_probs.extend(probs.tolist())
            all_true.extend(labels.detach().cpu().numpy().astype(np.int64).tolist())
            all_sex.extend(list(sex_norm))
            all_task.extend(list(task_group))
            all_clip.extend(list(clip_paths))
            all_spk.extend([("" if (x is None or (isinstance(x, float) and np.isnan(x))) else str(x)) for x in speaker_ids])

    return (
        np.asarray(all_true, dtype=np.int64),
        np.asarray(all_probs, dtype=np.float64),
        np.asarray(all_sex, dtype=object),
        np.asarray(all_clip, dtype=object),
        np.asarray(all_spk, dtype=object),
        np.asarray(all_task, dtype=object),
    )

# -------------------------
# 17) Run evaluation for one seed
# -------------------------
# Produces:
# - predictions.csv for that seed
# - metrics.json for that seed
# - ROC and confusion plots (overall + by sex when available)
def run_test_once(seed: int):
    set_all_seeds(seed)

    run_dir = RUN_ROOT / f"run_{train_dataset_id}_on_{d2_dataset_id}test_seed{seed}"
    run_dir.mkdir(parents=True, exist_ok=True)

    best_heads_path = chosen_exp / f"run_{train_dataset_id}_seed{seed}" / "best_heads.pt"

    print(f"\n[seed={seed}] Loading model + heads from:")
    print(" ", str(best_heads_path))

    model = Wav2Vec2TwoHeadClassifier(BACKBONE_CKPT, dropout_p=DROPOUT_P).to(DEVICE)
    model = load_heads_into_model(model, best_heads_path)
    model.eval()

    # Single global threshold applied to all seeds.
    thr_used = float(THR_USED_GLOBAL)
    thr_note = THR_GLOBAL_NOTE

    # Inference (includes metadata for predictions.csv).
    yt_true, yt_prob, yt_sex, yt_clip, yt_spk, yt_task = _infer_probs_with_meta(
        test_loader, model, desc=f"[seed={seed}] Test (D2 TEST)"
    )
    test_auc = compute_auc(yt_true, yt_prob)

    # Metrics and subgroup summaries at the global threshold.
    thr_metrics_test = compute_threshold_metrics(yt_true, yt_prob, thr=thr_used)
    fnr_by_sex, delta_f_minus_m, delta_abs = compute_fnr_by_group_signed(yt_true, yt_prob, yt_sex, thr=thr_used)
    confusion_by_sex = compute_confusion_by_group(yt_true, yt_prob, yt_sex, thr=thr_used)

    # Plots (overall).
    roc_png = run_dir / "roc_curve.png"
    cm_png  = run_dir / "confusion_matrix.png"
    save_roc_curve_png(yt_true, yt_prob, str(roc_png), title_suffix=f"D2 TEST (seed={seed})")
    save_confusion_png(yt_true, yt_prob, str(cm_png), thr=thr_used, title_suffix=f"D2 TEST (seed={seed})")

    # Plots (by sex: M and F only).
    cm_m_png = None
    cm_f_png = None
    mask_m = (yt_sex == "M")
    mask_f = (yt_sex == "F")

    if int(mask_m.sum()) > 0:
        cm_m_png = run_dir / "confusion_matrix_M.png"
        save_confusion_png(yt_true[mask_m], yt_prob[mask_m], str(cm_m_png), thr=thr_used, title_suffix=f"D2 TEST SEX=M (seed={seed})")

    if int(mask_f.sum()) > 0:
        cm_f_png = run_dir / "confusion_matrix_F.png"
        save_confusion_png(yt_true[mask_f], yt_prob[mask_f], str(cm_f_png), thr=thr_used, title_suffix=f"D2 TEST SEX=F (seed={seed})")

    # predictions.csv: one row per clip, includes model score and key metadata.
    pred_df = pd.DataFrame({
        "clip_path": yt_clip.astype(str),
        "y_true": yt_true.astype(int),
        "y_score": yt_prob.astype(float),
        "sex_norm": yt_sex.astype(str),
        "speaker_id": yt_spk.astype(str),
        "task_group": yt_task.astype(str),
        "seed": int(seed),
        "trainval_exp_tag": str(FULL_TRAINVAL_EXP_TAG),
        "run_stamp": str(RUN_STAMP),
        "threshold_used_global": float(thr_used),
    })
    pred_csv_path = run_dir / "predictions.csv"
    pred_df.to_csv(pred_csv_path, index=False)

    # metrics.json: full record of settings, metrics, and artifact paths for this seed.
    metrics = {
        "train_dataset": train_dataset_id,
        "test_dataset": d2_dataset_id,
        "seed": int(seed),

        "n_test": int(len(test_df)),
        "label_counts_test": test_df["label_num"].value_counts(dropna=False).to_dict(),
        "sex_counts_test_norm": test_df["sex_norm"].value_counts(dropna=False).to_dict(),

        "test_auroc": float(test_auc),

        "threshold_source": "D7 enhanced trainval summary_trainval.json -> val_optimal_threshold.mean_sd.mean",
        "trainval_experiment_used": str(chosen_exp),
        "trainval_exp_tag": str(FULL_TRAINVAL_EXP_TAG),
        "trainval_summary_path": str(summary_trainval_path),

        "test_threshold_used_global": float(thr_used),
        "test_threshold_note_global": thr_note,

        "threshold_metrics_test_at_thr_used": thr_metrics_test,

        "fairness_test_at_thr_used": {
            "definition": "H3: ΔFNR = FNR(F) - FNR(M), where FNR(sex)=FN/(FN+TP) computed on PD-only true labels at test_threshold_used_global.",
            "threshold_used": float(thr_used),
            "fnr_by_sex_norm": fnr_by_sex,
            "delta_fnr_F_minus_M": float(delta_f_minus_m),
            "delta_fnr_abs": float(delta_abs),
            "note": "If n_PD for a sex is 0, its FNR is NaN and ΔFNR is NaN.",
            "sex_normalization_note": "D2 mapping: exact 'male'->M and 'female'->F (case-sensitive); otherwise UNK.",
        },

        "confusion_by_sex_norm_at_thr_used": confusion_by_sex,

        "artifacts": {
            "predictions_csv": str(pred_csv_path),
            "roc_curve_png": str(roc_png),
            "confusion_matrix_png": str(cm_png),
            "confusion_matrix_M_png": str(cm_m_png) if cm_m_png is not None else None,
            "confusion_matrix_F_png": str(cm_f_png) if cm_f_png is not None else None,
        },

        "d7_out_root": D7_OUT_ROOT,
        "d2_out_root": D2_OUT_ROOT,
        "d2_manifest_all": D2_MANIFEST_ALL,

        "best_heads_path": str(best_heads_path),
        "backbone_ckpt": BACKBONE_CKPT,
        "dropout_p": float(DROPOUT_P),
        "timestamp": time.strftime("%Y%m%d_%H%M%S"),
    }

    metrics_path = run_dir / "metrics.json"
    with open(metrics_path, "w", encoding="utf-8") as f:
        json.dump(metrics, f, indent=2)

    print(f"[seed={seed}] DONE | test_AUROC={test_auc:.6f}")
    print(f"[seed={seed}] Threshold used (GLOBAL mean from D7 enhanced trainval): {thr_used:.6f}")
    print(f"[seed={seed}] WROTE:")
    print(" ", str(metrics_path))
    print(" ", str(pred_csv_path))
    print(" ", str(roc_png))
    print(" ", str(cm_png))

    return {
        "seed": int(seed),
        "thr_used": float(thr_used),
        "thr_note": thr_note,
        "test_auc": float(test_auc),
        "thr_metrics_test": thr_metrics_test,
        "fnr_by_sex": fnr_by_sex,
        "delta_signed": float(delta_f_minus_m),
        "delta_abs": float(delta_abs),
        "run_dir": str(run_dir),
        "predictions_csv": str(pred_csv_path),
    }

# -------------------------
# 18) Run all seeds and aggregate results
# -------------------------
# Computes:
# - AUROC mean and 95% CI (t, n=3)
# - threshold metrics mean ± SD across seeds
# - fairness mean ± SD across seeds
results = []
for seed in SEEDS:
    results.append(run_test_once(seed))

aurocs = [r["test_auc"] for r in results]
t_crit = 4.302652729911275  # two-sided 95% t critical value for df=2; valid only when len(SEEDS) == 3
n = len(aurocs)
mean_auc = float(np.mean(aurocs))
std_auc = float(np.std(aurocs, ddof=1)) if n > 1 else 0.0
half_width = float(t_crit * (std_auc / math.sqrt(n))) if n > 1 else 0.0
ci95 = [float(mean_auc - half_width), float(mean_auc + half_width)]

def _mean_sd(vals):
    # NaN-aware mean and standard deviation.
    vals = np.asarray(vals, dtype=np.float64)
    mu = float(np.nanmean(vals)) if np.any(~np.isnan(vals)) else float("nan")
    sd = float(np.nanstd(vals, ddof=1)) if np.sum(~np.isnan(vals)) > 1 else 0.0
    return mu, sd

# Aggregate standard threshold metrics across seeds.
thr_list = [r["thr_metrics_test"] for r in results]
keys = ["accuracy","precision","recall","f1","sensitivity","specificity","mcc","p_value_fisher_two_sided"]
agg = {}
for k in keys:
    v = [float(tm.get(k, float("nan"))) for tm in thr_list]
    mu, sd = _mean_sd(v)
    agg[k] = {
        "mean": float(mu),
        "sd": float(sd),
        "values_by_seed": {str(s): float(tm.get(k, float("nan"))) for s, tm in zip(SEEDS, thr_list)},
    }

# Confusion counts per seed (at the global threshold).
cm_by_seed = {
    str(s): {"tn": int(thr_list[i]["tn"]), "fp": int(thr_list[i]["fp"]), "fn": int(thr_list[i]["fn"]), "tp": int(thr_list[i]["tp"])}
    for i, s in enumerate(SEEDS)
}

# Fairness summaries per seed (ΔFNR and FNR per sex).
fnr_by_seed = {str(r["seed"]): r["fnr_by_sex"] for r in results}
delta_signed_by_seed = {str(r["seed"]): float(r["delta_signed"]) for r in results}
delta_abs_by_seed = {str(r["seed"]): float(r["delta_abs"]) for r in results}

fnr_m_vals, fnr_f_vals, n_pd_m_vals, n_pd_f_vals = [], [], [], []
d_signed_vals, d_abs_vals = [], []
for r in results:
    d = r["fnr_by_sex"] or {}
    fnr_m_vals.append(float(d.get("M", {}).get("fnr", float("nan"))))
    fnr_f_vals.append(float(d.get("F", {}).get("fnr", float("nan"))))
    n_pd_m_vals.append(float(d.get("M", {}).get("n_pos", float("nan"))))
    n_pd_f_vals.append(float(d.get("F", {}).get("n_pos", float("nan"))))
    d_signed_vals.append(float(r["delta_signed"]))
    d_abs_vals.append(float(r["delta_abs"]))

fnr_m_mean, fnr_m_sd = _mean_sd(fnr_m_vals)
fnr_f_mean, fnr_f_sd = _mean_sd(fnr_f_vals)
d_signed_mean, d_signed_sd = _mean_sd(d_signed_vals)
d_abs_mean, d_abs_sd = _mean_sd(d_abs_vals)

print("\nTest AUROC by seed:")
for r in results:
    print(f"  seed {r['seed']}: {r['test_auc']:.6f}")
print(f"\nMean Test AUROC: {mean_auc:.6f}")
print(f"95% CI (t, n=3): [{ci95[0]:.6f}, {ci95[1]:.6f}]")

print("\nTEST threshold used (GLOBAL mean val-opt from D7 enhanced trainval):")
print(f"  THR_USED_GLOBAL: {THR_USED_GLOBAL:.6f}")
if THR_GLOBAL_NOTE is not None:
    print("  NOTE:", THR_GLOBAL_NOTE)

print("\nThreshold metrics on D2 TEST @ THR_USED_GLOBAL (mean ± SD across seeds):")
for k in ["accuracy","precision","sensitivity","specificity","f1","mcc"]:
    print(f"  {k}: {agg[k]['mean']:.6f} ± {agg[k]['sd']:.6f}")
print("  fisher_p_value_two_sided:", f"{agg['p_value_fisher_two_sided']['mean']:.6g} ± {agg['p_value_fisher_two_sided']['sd']:.6g}")

print("\nFAIRNESS (H3) on D2 TEST @ THR_USED_GLOBAL across seeds (mean ± SD):")
print(f"  FNR_M: {fnr_m_mean:.6f} ± {fnr_m_sd:.6f}")
print(f"  FNR_F: {fnr_f_mean:.6f} ± {fnr_f_sd:.6f}")
print(f"  ΔFNR (F-M): {d_signed_mean:.6f} ± {d_signed_sd:.6f}")
print(f"  |ΔFNR|: {d_abs_mean:.6f} ± {d_abs_sd:.6f}")

# -------------------------
# 18.1) Build and write run-level summary + pointer files
# -------------------------
# Writes:
# - summary_test.json inside the run folder
# - history_index.jsonl appended (run log)
# - summary_latest.json overwritten (quick "latest run" reference)
# - last_run_pointer.json overwritten (quick pointer)
# - tag_run_pointer.json written under a stable tag folder
summary = {
    "run_tag_full_trainval_exp": str(FULL_TRAINVAL_EXP_TAG),
    "run_tag_safe": str(TAG_SAFE),
    "run_stamp": str(RUN_STAMP),

    "train_dataset": train_dataset_id,
    "test_dataset": d2_dataset_id,

    "d7_out_root": D7_OUT_ROOT,
    "d2_out_root": D2_OUT_ROOT,
    "d2_manifest_all": D2_MANIFEST_ALL,

    "enhanced_exp_required_substring_case_insensitive": REQUIRED_EXP_SUBSTRING,
    "trainval_experiment_used": str(chosen_exp),
    "trainval_summary_path": str(summary_trainval_path),

    "seeds": SEEDS,

    "n_test": int(len(test_df)),
    "label_counts_test": test_df["label_num"].value_counts(dropna=False).to_dict(),
    "sex_counts_test_norm": test_df["sex_norm"].value_counts(dropna=False).to_dict(),

    "test_threshold_used": {
        "threshold_used_global": float(THR_USED_GLOBAL),
        "threshold_source": "D7 enhanced trainval summary_trainval.json -> val_optimal_threshold.mean_sd.mean",
        "note_global": (THR_GLOBAL_NOTE if THR_GLOBAL_NOTE is not None else ""),
        "per_seed_repetition_for_audit": {str(r["seed"]): float(r["thr_used"]) for r in results},
    },

    "test_aurocs_by_seed": {str(r["seed"]): float(r["test_auc"]) for r in results},
    "mean_test_auroc": float(mean_auc),
    "t_crit_df2_95": float(t_crit),
    "half_width_95_ci": float(half_width),
    "ci95_test_auroc": ci95,

    "threshold_metrics_mean_sd_test_at_thr_used": agg,
    "confusion_matrix_by_seed_test_at_thr_used": cm_by_seed,

    "fairness_test_at_thr_used": {
        "definition": "H3: ΔFNR = FNR(F) - FNR(M), where FNR(sex)=FN/(FN+TP) computed on PD-only true labels at test_threshold_used_global.",
        "fnr_by_sex_norm_by_seed": fnr_by_seed,
        "delta_fnr_F_minus_M_by_seed": delta_signed_by_seed,
        "delta_fnr_abs_by_seed": delta_abs_by_seed,
        "fnr_M_mean_sd": {"mean": float(fnr_m_mean), "sd": float(fnr_m_sd)},
        "fnr_F_mean_sd": {"mean": float(fnr_f_mean), "sd": float(fnr_f_sd)},
        "delta_fnr_F_minus_M_mean_sd": {"mean": float(d_signed_mean), "sd": float(d_signed_sd)},
        "delta_fnr_abs_mean_sd": {"mean": float(d_abs_mean), "sd": float(d_abs_sd)},
    },

    "run_dirs": [r["run_dir"] for r in results],
    "predictions_csv_by_seed": {str(r["seed"]): str(r["predictions_csv"]) for r in results},
}

summary_path = RUN_ROOT / "summary_test.json"
with open(summary_path, "w", encoding="utf-8") as f:
    json.dump(summary, f, indent=2)

history_path = TEST_ROOT / "history_index.jsonl"
with open(history_path, "a", encoding="utf-8") as f:
    f.write(json.dumps(summary) + "\n")

summary_latest_path = TEST_ROOT / "summary_latest.json"
latest_summary_obj = {
    "run_tag_full_trainval_exp": str(FULL_TRAINVAL_EXP_TAG),
    "run_tag_safe": str(TAG_SAFE),
    "run_stamp": str(RUN_STAMP),
    "run_root": str(RUN_ROOT),
    "summary_test_json": str(summary_path),
    "seed_run_dirs": [str(RUN_ROOT / f"run_{train_dataset_id}_on_{d2_dataset_id}test_seed{s}") for s in SEEDS],
}
with open(summary_latest_path, "w", encoding="utf-8") as f:
    json.dump(latest_summary_obj, f, indent=2)

global_pointer_path = TEST_ROOT / "last_run_pointer.json"
with open(global_pointer_path, "w", encoding="utf-8") as f:
    json.dump(latest_summary_obj, f, indent=2)

tag_pointer_path = TAG_ROOT / "tag_run_pointer.json"
tag_pointer_obj = dict(latest_summary_obj)
tag_pointer_obj["tag_root"] = str(TAG_ROOT)
with open(tag_pointer_path, "w", encoding="utf-8") as f:
    json.dump(tag_pointer_obj, f, indent=2)

print("\nWROTE run summary:", str(summary_path))
print("APPENDED history index:", str(history_path))
print("WROTE latest summary:", str(summary_latest_path))
print("WROTE global pointer:", str(global_pointer_path))
print("WROTE tag pointer:", str(tag_pointer_path))
print("Open this folder to access artifacts:", str(RUN_ROOT))

# -------------------------
# 18.5) Write builder-style config/log placeholders
# -------------------------
# Writes:
# - run_config.json (run settings and pointers)
# - dataset_summary.json (key dataset stats for this test run)
# - preprocess_warnings.csv placeholder (kept for consistent folder structure)
_backup_if_exists(RUN_CONFIG_PATH)
_backup_if_exists(WARNINGS_CSV_PATH)
_backup_if_exists(DATASET_SUMMARY_PATH)

run_config = {
    "mode": f"D7_{TAG_SAFE}_on_D2_Test",
    "created_utc": datetime.utcnow().isoformat(),
    "run_stamp": RUN_STAMP,

    "d7_out_root": D7_OUT_ROOT,
    "d2_out_root": D2_OUT_ROOT,
    "d2_manifest_all": D2_MANIFEST_ALL,

    "enhanced_exp_required_substring_case_insensitive": REQUIRED_EXP_SUBSTRING,
    "trainval_experiment_used": str(chosen_exp),
    "trainval_exp_tag": str(FULL_TRAINVAL_EXP_TAG),
    "trainval_summary_path": str(summary_trainval_path),

    "threshold_source": "summary_trainval.json -> val_optimal_threshold.mean_sd.mean",
    "threshold_used_global": float(THR_USED_GLOBAL),
    "threshold_note_global": (THR_GLOBAL_NOTE if THR_GLOBAL_NOTE is not None else ""),

    "seeds": SEEDS,
    "use_amp": bool(USE_AMP and DEVICE.type == "cuda"),
    "per_device_bs": int(PER_DEVICE_BS),
    "effective_bs": int(PER_DEVICE_BS * GRAD_ACCUM),
    "num_workers": int(NUM_WORKERS),
    "pin_memory": bool(PIN_MEMORY),
    "backbone_ckpt": BACKBONE_CKPT,
    "dropout_p": float(DROPOUT_P),

    "monolingual_test_runs_root": str(TEST_ROOT),
    "run_root": str(RUN_ROOT),
    "summary_test_json": str(summary_path),
    "last_run_pointer_json": str(global_pointer_path),
    "tag_run_pointer_json": str(tag_pointer_path),
}
with open(RUN_CONFIG_PATH, "w", encoding="utf-8") as f:
    json.dump(run_config, f, indent=2)

dataset_summary = {
    "mode": f"D7_{TAG_SAFE}_on_D2_Test",
    "created_utc": datetime.utcnow().isoformat(),
    "status": "SUCCESS",
    "run_stamp": RUN_STAMP,

    "d2_dataset_id": d2_dataset_id,
    "n_test": int(len(test_df)),
    "label_counts_test": test_df["label_num"].value_counts(dropna=False).to_dict(),
    "sex_counts_test_raw": test_df["sex"].value_counts(dropna=False).to_dict(),
    "sex_counts_test_norm": test_df["sex_norm"].value_counts(dropna=False).to_dict(),

    "trainval_experiment_used": str(chosen_exp),
    "trainval_exp_tag": str(FULL_TRAINVAL_EXP_TAG),
    "trainval_summary_path": str(summary_trainval_path),

    "threshold_used_global": float(THR_USED_GLOBAL),

    "monolingual_test_runs_root": str(TEST_ROOT),
    "run_root": str(RUN_ROOT),
    "summary_test_json": str(summary_path),
    "last_run_pointer_json": str(global_pointer_path),
}
with open(DATASET_SUMMARY_PATH, "w", encoding="utf-8") as f:
    json.dump(dataset_summary, f, indent=2)

with open(WARNINGS_CSV_PATH, "w", encoding="utf-8") as f:
    f.write("ts,level,message\n")

print("\nWROTE (builder-aligned):")
print(" ", str(RUN_CONFIG_PATH))
print(" ", str(WARNINGS_CSV_PATH))
print(" ", str(DATASET_SUMMARY_PATH))

# -------------------------
# 19) Stop runtime (release GPU)
# -------------------------
# Frees Colab GPU resources after the run finishes.
print("\nAll done. Unassigning the runtime...")
try:
    from google.colab import runtime  # type: ignore
    runtime.unassign()
except Exception as e:
    print("Could not unassign runtime automatically. Stop runtime manually if needed.")
    print("Reason:", repr(e))

The following cell builds the enhanced D7 training split **train_enh2** by starting with the full original D7 training set and then adding a new, class-balanced random subset of the D2 training data. The added D2 data comes from **10 percent of all D2 training speakers** (rounded up), where the 10 percent is computed from the total number of D2 training speakers before any exclusions. Speakers that were already included in the earlier **train_enh1** build are explicitly excluded, so the D2 speakers added to the two enhanced training sets do not overlap. The output is a new clips folder and a matching manifest that can be used directly in later D7 training and validation steps.

The cell begins by mounting Google Drive if needed and defining the preprocessed data roots for D7 and D2. It creates the required output folders under D7 for clips, manifests, logs, and configuration files. It then loads `manifest_all.csv` for both datasets, along with `manifest_train_enh1.csv`, which is used to identify D2 speakers that must be excluded. At this stage, several basic checks are applied: the column layout must match the expected schema, dataset identifiers must be correct, and class labels must be limited to Healthy (0) or Parkinson’s (1). Any literal `"NaN"` strings are converted to real missing values, and sex values are standardized to a simple **M/F** format.
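
The cleanup and validation steps in this paragraph can be sketched as below. This is a minimal illustration rather than the builder's actual code: the helper name `normalize_manifest`, the accepted raw sex spellings, and the toy column set are assumptions.

```python
import numpy as np
import pandas as pd

def normalize_manifest(df: pd.DataFrame, expected_dataset: str) -> pd.DataFrame:
    """Hypothetical helper: basic manifest checks plus NaN and sex cleanup."""
    out = df.copy()

    # Convert literal "NaN" strings into real missing values.
    out = out.replace("NaN", np.nan)

    # The dataset identifier must match, and labels must be Healthy (0) or Parkinson's (1).
    if not (out["dataset"] == expected_dataset).all():
        raise ValueError(f"Unexpected dataset id (expected {expected_dataset})")
    if not set(out["label_num"].unique()).issubset({0, 1}):
        raise ValueError("label_num must contain only 0 or 1")

    # Standardize sex to M/F. The accepted spellings are an assumption;
    # missing or unrecognized values end up as missing.
    sex_map = {"m": "M", "male": "M", "f": "F", "female": "F"}
    out["sex"] = out["sex"].map(
        lambda v: sex_map.get(str(v).strip().lower()) if pd.notna(v) else np.nan
    )
    return out
```

Raising on the first failed check, instead of silently repairing rows, matches the fail-fast behavior the text describes for schema and label problems.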

Next, the cell gathers the source rows for the build. This includes all D7 rows with `split = "train"` and all D2 rows with `split = "train"`. For D2, it checks that each speaker has a single, consistent diagnosis across all clips, which is required for speaker-level sampling. Using the full D2 training speaker count, it computes the target number of speakers as **10 percent**, rounds this up, and then adjusts it to be an even number so Healthy and Parkinson’s speakers can be selected in equal amounts. The exclusion list from **train_enh1** is applied, and the code confirms that enough eligible speakers remain. From the remaining pool, the required number of Healthy and Parkinson’s speakers is selected at random using a fixed seed, and all training clips from those speakers are included.
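
The fixed-denominator target and the balanced, seeded draw can be sketched as follows. The function names, the choice to round an odd target up (rather than down) to the next even number, and the use of NumPy's `default_rng` are assumptions made for illustration; the builder cell may implement these details differently.

```python
import numpy as np

def target_speaker_count(n_train_speakers: int, percent: int = 10) -> int:
    """ceil(percent% of n) in integer arithmetic, bumped to the next even number."""
    n = (n_train_speakers * percent + 99) // 100  # integer ceiling
    return n + (n % 2)                            # assumption: odd targets round up

def draw_balanced_speakers(speaker_labels: dict, excluded: set,
                           n_target: int, seed: int) -> list:
    """Draw n_target/2 Healthy (0) and n_target/2 Parkinson's (1) speakers."""
    per_class = n_target // 2
    rng = np.random.default_rng(seed)
    chosen = []
    for label in (0, 1):
        # Eligible pool: correct class and not used in the earlier draw.
        pool = sorted(s for s, y in speaker_labels.items()
                      if y == label and s not in excluded)
        if len(pool) < per_class:
            raise ValueError(f"Not enough eligible speakers for label {label}")
        chosen.extend(rng.choice(pool, size=per_class, replace=False).tolist())
    return sorted(chosen)
```

Note that `n_target` is computed from the full speaker count before `excluded` is applied, so exclusions shrink the eligible pool without changing the draw size. Sorting the pool before sampling keeps the draw reproducible regardless of dictionary order.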

Before copying any audio, the cell checks that every referenced source file exists for both the D7 training data and the selected D2 subset. It then creates a copy plan with strict rules to avoid overwriting files. D7 training clips are copied into `clips/train_enh2/` using their original filenames. D2 clips are copied into the same folder using new, deterministic filenames that encode the split, class, speaker ID, task type, and a stable index. A preflight check looks for existing destination files: files with the correct size are skipped, while any size mismatch or file error causes the run to stop. At this point, the cell writes initial run metadata, warnings, and a preliminary dataset summary so the process is traceable even if it fails.
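
A minimal sketch of the deterministic naming and the preflight rule is shown below. The filename pattern in `d2_dest_name` and the function names are hypothetical; only the rule itself (skip when sizes match, stop on any mismatch) is taken from the description above.

```python
import tempfile
from pathlib import Path

def d2_dest_name(split: str, label_str: str, speaker_id: str, task: str, index: int) -> str:
    """Hypothetical naming pattern; the real builder's exact format may differ."""
    return f"{split}_{label_str}_{speaker_id}_{task}_{index:05d}.wav"

def preflight(copy_plan):
    """Classify (src, dst) pairs: copy if dst is missing, skip if sizes match, else stop."""
    to_copy, to_skip = [], []
    for src, dst in copy_plan:
        src, dst = Path(src), Path(dst)
        if not src.is_file():
            raise FileNotFoundError(f"Missing source file: {src}")
        if not dst.exists():
            to_copy.append((src, dst))
        elif dst.stat().st_size == src.stat().st_size:
            to_skip.append((src, dst))  # already copied in an earlier run
        else:
            raise RuntimeError(f"Size mismatch, refusing to overwrite: {dst}")
    return to_copy, to_skip
```

Because the names depend only on stable fields, re-running the builder maps each clip to the same destination, which is what makes the size-based skip safe.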

During the copy step, only missing files are copied. Any copy error is recorded and treated as a failure. After copying completes successfully, the cell creates `manifest_train_enh2.csv` by combining the rewritten D7 training rows with the newly added D2 rows. All rows are marked with `split = "train_enh2"`. D7 rows keep their identity but point to the new clip locations. D2 rows are rewritten so they appear as D7 entries for training, with updated file paths and sample IDs, and each one includes a `source_dataset` field set to `"D2"` to preserve where it came from. The final manifest is checked to ensure there are no literal `"NaN"` strings and that every listed audio file exists, and then it is written safely to disk.
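
The manifest rewrite can be illustrated with the sketch below. It assumes a `clip_path` column holding the audio path, derives the new D2 sample IDs from the destination filenames, and tags D7 rows with `source_dataset = "D7"`; all three choices, along with the function name, are assumptions and not confirmed details of the builder.

```python
from pathlib import PurePosixPath
import pandas as pd

def build_enh2_manifest(d7_train: pd.DataFrame, d2_selected: pd.DataFrame,
                        d2_dest_names: list, clips_dir: str) -> pd.DataFrame:
    """Hypothetical sketch: merge D7 train rows with rewritten D2 rows."""
    # D7 rows keep their identity but point at the new clips folder.
    d7 = d7_train.copy()
    d7["source_dataset"] = "D7"  # assumption: D7 rows are tagged too
    d7["clip_path"] = [f"{clips_dir}/{PurePosixPath(p).name}" for p in d7["clip_path"]]

    # D2 rows are presented as D7 entries; provenance is kept in source_dataset.
    d2 = d2_selected.copy()
    d2["source_dataset"] = "D2"
    d2["dataset"] = "D7"
    d2["clip_path"] = [f"{clips_dir}/{name}" for name in d2_dest_names]
    d2["sample_id"] = [PurePosixPath(name).stem for name in d2_dest_names]

    out = pd.concat([d7, d2], ignore_index=True)
    out["split"] = "train_enh2"

    # Final check: no literal "NaN" strings may remain in any column.
    if (out.astype(str) == "NaN").any().any():
        raise ValueError('Manifest contains literal "NaN" strings')
    return out
```

A file-existence check over `clip_path` would follow in a real run; it is omitted here so the sketch does not depend on files on disk.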

The cell finishes by writing a final dataset summary and updating the run configuration with details about speaker selection, exclusions, file checks, and copy results. It prints the paths to the new **train_enh2** clips folder, the corresponding manifest, and the related log and configuration files for reference.

In [None]:
# =========================
# D7 train_enh2 Builder (adds ceil(10%) of D2 train speakers)
# Inputs: D7 full manifest, D2 full manifest, prior train_enh1 manifest
# Outputs: train_enh2 clips folder, train_enh2 manifest, run config, warnings log, summary JSON
#
# Fixed-denominator policy:
# - The 10% target is computed from the FULL D2 train speaker count (before exclusions)
# - Exclusions only affect the eligible pool, not the draw size
#
# - Creates: <D7_OUT_ROOT>/clips/train_enh2/
# - Writes:
#     * manifests/manifest_train_enh2.csv
#     * config/D7_Enh2_on_D2_Test/run_config.json
#     * logs/D7_Enh2_on_D2_Test/preprocess_warnings.csv
#     * logs/D7_Enh2_on_D2_Test/dataset_summary.json
# - Copy only, no overwrite
# =========================

import os, json, re, shutil, math
from pathlib import Path
from datetime import datetime

import numpy as np
import pandas as pd
from tqdm.auto import tqdm

# -------------------------
# 0) Drive access (Colab-friendly)
# -------------------------
# Makes sure files can be read and written when running in Google Colab.
try:
    from google.colab import drive  # type: ignore
    if not os.path.exists("/content/drive"):
        drive.mount("/content/drive")
except Exception:
    pass

# -------------------------
# 1) Key inputs and output locations
# -------------------------
# Inputs: D7 manifest_all, D2 manifest_all, and the prior train_enh1 manifest (for exclusions).
# Outputs: a new train_enh2 clips folder and manifest, plus logs and a run config record.
D7_OUT_ROOT = Path("/content/drive/MyDrive/AI_PD_Project/Datasets/D7-Multilingual (D1_D4_D5v2_D6)/preprocessed_v1")
D2_OUT_ROOT = Path("/content/drive/MyDrive/AI_PD_Project/Datasets/D2-Slovak (EWA-DB)/EWA-DB/preprocessed_v1")

D7_MANIFEST_ALL = D7_OUT_ROOT / "manifests" / "manifest_all.csv"
D2_MANIFEST_ALL = D2_OUT_ROOT / "manifests" / "manifest_all.csv"

# Prior selection source (used to enforce zero overlap with draw 1)
MANIFEST_TRAIN_ENH1 = Path(
    "/content/drive/MyDrive/AI_PD_Project/Datasets/D7-Multilingual (D1_D4_D5v2_D6)/preprocessed_v1/manifests/manifest_train_enh1.csv"
)

# New outputs
TRAIN_ENH2_DIR = D7_OUT_ROOT / "clips" / "train_enh2"
MANIFEST_TRAIN_ENH2 = D7_OUT_ROOT / "manifests" / "manifest_train_enh2.csv"

# Logs/config for this builder run
RUN_FOLDER = "D7_Enh2_on_D2_Test"
LOGS_SUBDIR = D7_OUT_ROOT / "logs" / RUN_FOLDER
CFG_SUBDIR  = D7_OUT_ROOT / "config" / RUN_FOLDER

WARNINGS_CSV = LOGS_SUBDIR / "preprocess_warnings.csv"
SUMMARY_JSON = LOGS_SUBDIR / "dataset_summary.json"
RUN_CONFIG_JSON = CFG_SUBDIR / "run_config.json"

# Standard folders (created if missing)
clips_dir = D7_OUT_ROOT / "clips"
manif_dir = D7_OUT_ROOT / "manifests"
logs_dir  = D7_OUT_ROOT / "logs"
cfg_dir   = D7_OUT_ROOT / "config"

for d in [clips_dir, manif_dir, logs_dir, cfg_dir, LOGS_SUBDIR, CFG_SUBDIR]:
    d.mkdir(parents=True, exist_ok=True)
TRAIN_ENH2_DIR.mkdir(parents=True, exist_ok=True)

# -------------------------
# 2) Manifest schema (fixed column list)
# -------------------------
# Locks column names and ordering so downstream training code stays stable.
CANON_COLS = [
    "split",
    "dataset",
    "task",
    "speaker_id",
    "sample_id",
    "label_str",
    "label_num",
    "age",
    "sex",
    "speaker_key_rel",
    "clip_path",
    "duration_sec",
    "source_path",
    "clip_start_sec",
    "clip_end_sec",
    "sr_hz",
    "channels",
    "clip_is_contiguous",
]
FINAL_COLS = CANON_COLS + ["source_dataset"]

# -------------------------
# 3) Sampling policy settings (kept explicit for repeatability)
# -------------------------
# - target size uses FULL D2 train speaker count (before exclusions)
# - speaker-balanced draw (same number of HC and PD speakers)
TEN_PCT         = 0.10
ROUNDING_POLICY = "ceil_then_make_even_by_rounding_down_if_needed"
BALANCE_POLICY  = "speaker-balanced"
SPEAKER_ID_COL  = "speaker_id"

# Random seed used for this draw (draw 2)
RNG_SEED        = 2024

LABEL_MAP_NOTE  = "label_num mapping: 0=Healthy, 1=Parkinson"

# -------------------------
# 4) Small utilities and logging
# -------------------------
# warnings_rows is written to CSV and used to stop on fatal issues.
warnings_rows = []

def require(cond: bool, msg: str):
    # Simple guard helper used throughout the script.
    if not cond:
        raise RuntimeError(msg)

def add_warn(src: str, level: str, code: str, message: str, **extra):
    # Records issues in a structured way (later written to preprocess_warnings.csv).
    row = {
        "ts": datetime.utcnow().isoformat(),
        "src": src,
        "level": str(level).upper(),
        "code": code,
        "message": message,
    }
    row.update(extra)
    warnings_rows.append(row)

def count_by_level(rows):
    # Quick summary for console prints and run summaries.
    out = {"ERROR": 0, "WARN": 0, "INFO": 0}
    for r in rows:
        lvl = str(r.get("level", "INFO")).upper()
        out[lvl] = out.get(lvl, 0) + 1
    return out

def atomic_write_text(dst: Path, text: str):
    # Writes via a temp file then renames, to avoid partial files on crashes.
    tmp = dst.with_suffix(dst.suffix + ".tmp")
    with open(tmp, "w", encoding="utf-8") as f:
        f.write(text)
    os.replace(tmp, dst)

def atomic_write_csv(dst: Path, df: pd.DataFrame):
    # Same atomic pattern for CSV (also enforces blank for NaN).
    tmp = dst.with_suffix(dst.suffix + ".tmp")
    df.to_csv(tmp, index=False, na_rep="")
    os.replace(tmp, dst)

def safe_token(s, max_len=32, default="NA"):
    # Makes a filesystem-safe token for deterministic file naming.
    if pd.isna(s):
        return default
    s = str(s).strip()
    s = re.sub(r"\s+", "_", s)
    s = re.sub(r"[^A-Za-z0-9_]+", "", s)
    return (s[:max_len] if s else default)

def label_str_from_num(v):
    # Human-readable label string derived from label_num.
    if pd.isna(v):
        return np.nan
    iv = int(v)
    if iv == 0:
        return "Healthy"
    if iv == 1:
        return "Parkinson"
    return np.nan

def hc_pd_from_label_num(v):
    # Compact label used in generated filenames.
    return "PD" if int(v) == 1 else "HC"

def normalize_sex_to_MF_D7(x):
    # Standardize D7 sex values to M/F/NaN.
    if pd.isna(x):
        return np.nan
    s = str(x).strip().lower()
    if s in ["m", "male"]:
        return "M"
    if s in ["f", "female"]:
        return "F"
    if s in ["", "nan", "none", "unknown", "u"]:
        return np.nan
    return np.nan

def normalize_sex_to_MF_D2(x):
    # Standardize D2 sex values to M/F/NaN (D2 uses male/female).
    if pd.isna(x):
        return np.nan
    s = str(x).strip().lower()
    if s in ["male", "m"]:
        return "M"
    if s in ["female", "f"]:
        return "F"
    if s in ["", "nan", "none", "unknown", "u"]:
        return np.nan
    return np.nan

def summarize_counts(df: pd.DataFrame, split_name: str):
    # Final dataset summary used by the logs JSON.
    out = {}
    out["total_rows"] = int(len(df))
    out["split"] = split_name
    out["label_counts_total"] = {str(k): int(v) for k, v in df["label_num"].value_counts(dropna=False).to_dict().items()}
    out["by_source_dataset"] = {sd: int((df["source_dataset"] == sd).sum()) for sd in sorted(df["source_dataset"].dropna().unique())}
    out["sex_counts"] = {str(k): int(v) for k, v in df["sex"].value_counts(dropna=False).to_dict().items()}
    out["n_unique_speakers"] = int(df["speaker_id"].astype(str).nunique()) if "speaker_id" in df.columns else int(0)
    return out

# -------------------------
# 5) Load manifests and validate basics
# -------------------------
# Loads D7 and D2 manifests, checks required columns, normalizes key fields.
# Also loads train_enh1 to build the exclusion list (zero overlap with draw 1).
print("D7_OUT_ROOT:", str(D7_OUT_ROOT))
print("D2_OUT_ROOT:", str(D2_OUT_ROOT))
print("D7_MANIFEST_ALL:", str(D7_MANIFEST_ALL))
print("D2_MANIFEST_ALL:", str(D2_MANIFEST_ALL))
print("MANIFEST_TRAIN_ENH1:", str(MANIFEST_TRAIN_ENH1))
print("TRAIN_ENH2_DIR:", str(TRAIN_ENH2_DIR))
print("MANIFEST_TRAIN_ENH2:", str(MANIFEST_TRAIN_ENH2))
print("LOGS_SUBDIR:", str(LOGS_SUBDIR))
print("CFG_SUBDIR:", str(CFG_SUBDIR))
print("RNG_SEED:", int(RNG_SEED))

require(D7_MANIFEST_ALL.exists(), f"Missing D7 manifest_all.csv: {str(D7_MANIFEST_ALL)}")
require(D2_MANIFEST_ALL.exists(), f"Missing D2 manifest_all.csv: {str(D2_MANIFEST_ALL)}")
require(MANIFEST_TRAIN_ENH1.exists(), f"Missing prior manifest_train_enh1.csv: {str(MANIFEST_TRAIN_ENH1)}")

d7 = pd.read_csv(D7_MANIFEST_ALL)
d2 = pd.read_csv(D2_MANIFEST_ALL)
enh1 = pd.read_csv(MANIFEST_TRAIN_ENH1)

# Schema checks for D7/D2; enh1 must include speaker_id and source_dataset.
missing_d7 = [c for c in CANON_COLS if c not in d7.columns]
missing_d2 = [c for c in CANON_COLS if c not in d2.columns]
require(len(missing_d7) == 0, f"D7 manifest missing required columns: {missing_d7}")
require(len(missing_d2) == 0, f"D2 manifest missing required columns: {missing_d2}")

require("speaker_id" in enh1.columns, f"manifest_train_enh1.csv missing 'speaker_id'. Found: {list(enh1.columns)}")
require("source_dataset" in enh1.columns, f"manifest_train_enh1.csv missing 'source_dataset'. Found: {list(enh1.columns)}")

# Ensure source_dataset exists on both base manifests for later summaries.
if "source_dataset" not in d7.columns:
    d7["source_dataset"] = "D7"
if "source_dataset" not in d2.columns:
    d2["source_dataset"] = "D2"

# Convert literal "NaN" strings into real missing values.
for df in [d7, d2, enh1]:
    for col in ["sex", "age", "duration_sec", "clip_start_sec", "clip_end_sec", "speaker_key_rel", "speaker_id", "task", "sample_id"]:
        if col in df.columns:
            df[col] = df[col].replace("NaN", np.nan)

def infer_dataset_id(df: pd.DataFrame, fallback: str) -> str:
    # Infers the dataset id from the most common non-null value in the dataset column.
    # Nulls are dropped BEFORE casting to str; otherwise NaN becomes the string "nan"
    # and could be counted as a real dataset id.
    if "dataset" in df.columns and df["dataset"].notna().any():
        return str(df["dataset"].dropna().astype(str).value_counts().idxmax())
    return fallback

# Hard guards: these manifests should identify themselves as D7 and D2.
require(infer_dataset_id(d7, "DX") == "D7", "Expected D7 manifest dataset=='D7'.")
require(infer_dataset_id(d2, "DX") == "D2", "Expected D2 manifest dataset=='D2'.")

# Keep only the intended dataset rows.
d7 = d7[d7["dataset"].astype(str) == "D7"].copy()
d2 = d2[d2["dataset"].astype(str) == "D2"].copy()

# Validate labels and create label_str.
for name, df in [("D7", d7), ("D2", d2)]:
    bad = sorted(set(df["label_num"].dropna().unique()) - {0, 1})
    require(len(bad) == 0, f"{name} label_num contains values outside {{0,1}}: {bad}")
    df["label_str"] = df["label_num"].map(label_str_from_num)

# Standardize sex encoding to M/F/NaN.
d7["sex"] = d7["sex"].map(normalize_sex_to_MF_D7)
d2["sex"] = d2["sex"].map(normalize_sex_to_MF_D2)

# -------------------------
# 6) Select D7 train split (base content for train_enh2)
# -------------------------
# All D7 train clips are included in train_enh2 (copied into a new folder).
d7_train = d7[d7["split"].astype(str) == "train"].copy()
require(len(d7_train) > 0, "D7 train split has 0 rows.")
print("\nD7 train rows:", int(len(d7_train)))
print("D7 train label counts:", d7_train["label_num"].value_counts(dropna=False).to_dict())

# -------------------------
# 7) Build D2 exclusion list from train_enh1
# -------------------------
# Excludes any D2 speakers already used in train_enh1 to enforce zero overlap.
enh1_d2 = enh1[enh1["source_dataset"].astype(str) == "D2"].copy()
prior_d2_speakers = sorted(enh1_d2["speaker_id"].dropna().astype(str).unique().tolist())

print("\nZero-overlap exclusion from manifest_train_enh1:")
print("  Prior D2 rows in enh1:", int(len(enh1_d2)))
print("  Prior D2 unique speakers to exclude:", int(len(prior_d2_speakers)))
print("  Example excluded speakers (up to 20):", prior_d2_speakers[:20])

# -------------------------
# 8) Compute fixed draw size from FULL D2 train speaker universe
# -------------------------
# Key rule: the 10% target uses the full D2 train speaker count (before exclusions).
# Additional rule: the target is forced to be even, then split evenly across HC and PD.
d2_train = d2[d2["split"].astype(str) == "train"].copy()
require(len(d2_train) > 0, "D2 train split has 0 rows.")
require(SPEAKER_ID_COL in d2_train.columns, f"D2 missing speaker id column: {SPEAKER_ID_COL}")

print("\nD2 train rows:", int(len(d2_train)))
print("D2 train label counts:", d2_train["label_num"].value_counts(dropna=False).to_dict())

# Guard: each speaker must have a single consistent label across all clips.
speaker_labels = (
    d2_train.groupby(SPEAKER_ID_COL)["label_num"]
    .apply(lambda s: sorted(set(s.dropna().astype(int).tolist())))
)
mixed = speaker_labels[speaker_labels.apply(lambda x: len(x) != 1)]
if len(mixed) > 0:
    raise RuntimeError(
        "D2 train has speakers with mixed label_num values across clips; Definition B sampling cannot proceed.\n"
        f"Examples: {mixed.head(10).to_dict()}"
    )

# Build the full speaker list and full class lists (HC vs PD).
speaker_to_label_full = speaker_labels.apply(lambda x: int(x[0])).to_dict()
all_speakers_full = sorted([str(x) for x in speaker_to_label_full.keys()])
total_speakers_full = len(all_speakers_full)

hc_full = sorted([str(spk) for spk, y in speaker_to_label_full.items() if int(y) == 0])
pd_full = sorted([str(spk) for spk, y in speaker_to_label_full.items() if int(y) == 1])
require(len(hc_full) > 0 and len(pd_full) > 0, "D2 train does not contain both HC and PD speakers.")

# Fixed-denominator target computation (draw size does not change after exclusions).
target_total = int(math.ceil(TEN_PCT * total_speakers_full))
if target_total % 2 != 0:
    target_total -= 1
if target_total < 2:
    target_total = 2
target_per_class = target_total // 2

print("\nFixed-denominator target computation:")
print("  total D2 train speakers (FULL):", total_speakers_full)
print(f"  target_total = ceil(0.10 * {total_speakers_full}) then even:", target_total, f"=> {target_per_class} HC + {target_per_class} PD")

# -------------------------
# 9) Apply exclusions, then sample from the eligible pool
# -------------------------
# Exclusions shrink the eligible pool only; the target draw size stays the same.
exclude_set = set(prior_d2_speakers)
speaker_to_label_eligible = {str(spk): int(lbl) for spk, lbl in speaker_to_label_full.items() if str(spk) not in exclude_set}

hc_eligible = sorted([spk for spk, y in speaker_to_label_eligible.items() if y == 0])
pd_eligible = sorted([spk for spk, y in speaker_to_label_eligible.items() if y == 1])

print("\nEligible pool after exclusion (does not affect target size):")
print("  eligible speakers total:", len(speaker_to_label_eligible))
print("  eligible HC speakers:", len(hc_eligible))
print("  eligible PD speakers:", len(pd_eligible))

require(len(hc_eligible) >= target_per_class, f"Not enough eligible HC speakers to draw {target_per_class}. Eligible HC={len(hc_eligible)}")
require(len(pd_eligible) >= target_per_class, f"Not enough eligible PD speakers to draw {target_per_class}. Eligible PD={len(pd_eligible)}")

# Random but repeatable sampling (seeded RNG).
rng = np.random.default_rng(int(RNG_SEED))
sel_hc = sorted(rng.choice(hc_eligible, size=target_per_class, replace=False).tolist())
sel_pd = sorted(rng.choice(pd_eligible, size=target_per_class, replace=False).tolist())
selected_speakers = sorted(sel_hc + sel_pd)

# Keep all D2 train clips for the selected speakers.
d2_sel = d2_train[d2_train[SPEAKER_ID_COL].astype(str).isin([str(x) for x in selected_speakers])].copy()

print("\nD2 speaker sampling (DRAW 2, fixed denominator + zero overlap):")
print("  target_total (fixed):", target_total, "=>", target_per_class, "HC +", target_per_class, "PD")
print("  selected D2 rows (all clips for selected speakers):", int(len(d2_sel)))
print("  selected label counts:", d2_sel["label_num"].value_counts(dropna=False).to_dict())

# -------------------------
# 10) Confirm all source clip paths exist (fail early)
# -------------------------
# Prevents partial builds caused by missing audio files.
def fail_fast_missing_paths(df: pd.DataFrame, name: str):
    missing_paths = []
    for p in tqdm(df["clip_path"].astype(str).tolist(), desc=f"Check {name} clip_path exists", dynamic_ncols=True):
        if not os.path.exists(p):
            missing_paths.append(p)
            if len(missing_paths) >= 25:
                break
    if missing_paths:
        raise FileNotFoundError(f"{name}: missing clip_path(s). Examples (up to 25): {missing_paths}")

fail_fast_missing_paths(d7_train, "D7 TRAIN")
fail_fast_missing_paths(d2_sel, "D2 TRAIN (selected speakers)")

# -------------------------
# 11) Build the copy plan and enforce no-overwrite rules
# -------------------------
# Two sources are copied into one folder:
# - D7 train clips keep their original filenames
# - D2 clips get deterministic new filenames to avoid collisions
copy_plan = []

# A) D7 → train_enh2 (same filenames)
for i, row in d7_train.reset_index(drop=True).iterrows():
    src_path = Path(str(row["clip_path"]))
    dst_path = TRAIN_ENH2_DIR / src_path.name
    copy_plan.append({
        "src_path": str(src_path),
        "dst_path": str(dst_path),
        "origin": "D7_train_existing",
        "source_dataset": str(row.get("source_dataset", "D7")),
        "row_index": int(i),
    })

# B) D2 selected → train_enh2 (deterministic names)
d2_sel_reset = d2_sel.reset_index(drop=True).copy()
d2_sel_reset["_speaker_tok"] = d2_sel_reset["speaker_id"].map(lambda x: safe_token(x, 32, "NA"))
d2_sel_reset["_task_tok"] = d2_sel_reset["task"].map(lambda x: safe_token(x, 12, "0"))
d2_sel_reset["_clip_path_str"] = d2_sel_reset["clip_path"].astype(str)
d2_sel_reset = d2_sel_reset.sort_values(by=["_speaker_tok", "_task_tok", "_clip_path_str"]).reset_index(drop=True)

for j, row in d2_sel_reset.iterrows():
    src_path = Path(str(row["clip_path"]))
    hc_pd = hc_pd_from_label_num(row["label_num"])
    spk_tok = safe_token(row["speaker_id"], 32, "NA")
    task_tok = safe_token(row["task"], 12, "0")
    out_name = f"D7_D2add_train_enh2_{hc_pd}_{spk_tok}_{task_tok}_{j+1:06d}.wav"
    dst_path = TRAIN_ENH2_DIR / out_name
    copy_plan.append({
        "src_path": str(src_path),
        "dst_path": str(dst_path),
        "origin": "D2_train_selected",
        "source_dataset": "D2",
        "row_index": int(j),
    })

# Preflight: decide copy vs skip, and detect dangerous conflicts.
n_dest_exists_ok = 0
n_dest_exists_mismatch = 0
n_will_copy = 0

print("\nPreflight: destination existence and size checks (no overwrite)...")
for item in tqdm(copy_plan, desc="Preflight (dest checks)", dynamic_ncols=True):
    sp = Path(item["src_path"])
    dp = Path(item["dst_path"])
    require(sp.exists(), f"Source clip missing at preflight: {str(sp)}")
    if dp.exists():
        try:
            # If sizes match, treat as already-copied and skip later.
            if dp.stat().st_size == sp.stat().st_size:
                n_dest_exists_ok += 1
            else:
                # Same name but different content is treated as a fatal error.
                n_dest_exists_mismatch += 1
                add_warn(
                    "D7_TRAIN_ENH2", "ERROR", "DEST_EXISTS_SIZE_MISMATCH",
                    "Destination exists but file size differs from source",
                    src_path=str(sp), dest_path=str(dp),
                    src_size=int(sp.stat().st_size), dest_size=int(dp.stat().st_size),
                )
        except Exception as e:
            n_dest_exists_mismatch += 1
            add_warn(
                "D7_TRAIN_ENH2", "ERROR", "DEST_EXISTS_STAT_ERROR",
                "Failed to stat source/destination during preflight",
                src_path=str(sp), dest_path=str(dp), error=repr(e),
            )
    else:
        n_will_copy += 1

preflight_stats = {
    "total_planned_files": int(len(copy_plan)),
    "n_dest_exists_ok": int(n_dest_exists_ok),
    "n_dest_exists_mismatch": int(n_dest_exists_mismatch),
    "n_will_copy": int(n_will_copy),
    "warnings_by_level": count_by_level(warnings_rows),
}

print("\nPreflight summary:")
print("  Planned files:", int(len(copy_plan)))
print("  Destination exists (size OK):", n_dest_exists_ok)
print("  Destination exists (size mismatch/stat error):", n_dest_exists_mismatch)
print("  Will copy:", n_will_copy)
print("  Warnings by level:", preflight_stats["warnings_by_level"])

# -------------------------
# 12) Write run config and early summary (before copying)
# -------------------------
# These files capture the selection and policies even if copying fails later.
run_config = {
    "dataset": "D7",
    "mode": "train_enh2_builder",
    "created_utc": datetime.utcnow().isoformat(),
    "run_folder": RUN_FOLDER,
    "d7_out_root": str(D7_OUT_ROOT),
    "d2_out_root": str(D2_OUT_ROOT),
    "d7_manifest_all": str(D7_MANIFEST_ALL),
    "d2_manifest_all": str(D2_MANIFEST_ALL),
    "manifest_train_enh1": str(MANIFEST_TRAIN_ENH1),
    "train_enh2_dir": str(TRAIN_ENH2_DIR),
    "manifest_train_enh2": str(MANIFEST_TRAIN_ENH2),
    "policy": {
        "definition": "Add ceil(10%) of D2 TRAIN by speaker (Definition B) into D7 training clips folder",
        "pct_speakers": float(TEN_PCT),
        "denominator_policy": "FIXED_FULL_D2_TRAIN_SPEAKER_COUNT",
        "rounding_policy": ROUNDING_POLICY,
        "balance_policy": BALANCE_POLICY,
        "speaker_id_col_used": SPEAKER_ID_COL,
        "rng_seed": int(RNG_SEED),
        "label_note": LABEL_MAP_NOTE,
        "file_operation": "copy",
        "no_overwrite_rule": "skip if dest exists with matching size; error if size differs",
        "zero_overlap": {
            "enabled": True,
            "exclude_speakers_source": "manifest_train_enh1 where source_dataset == 'D2'",
            "n_excluded_d2_speakers": int(len(prior_d2_speakers)),
            "excluded_d2_speakers": prior_d2_speakers,
        },
    },
    "inputs": {
        "d7_train_rows": int(len(d7_train)),
        "d2_train_rows": int(len(d2_train)),
        "d2_train_total_speakers_full": int(total_speakers_full),
        "d2_train_total_speakers_eligible": int(len(speaker_to_label_eligible)),
        "d2_train_speakers_hc_full": int(len(hc_full)),
        "d2_train_speakers_pd_full": int(len(pd_full)),
        "d2_train_speakers_hc_eligible": int(len(hc_eligible)),
        "d2_train_speakers_pd_eligible": int(len(pd_eligible)),
    },
    "selection": {
        "target_total_speakers_fixed": int(target_total),
        "target_per_class_fixed": int(target_per_class),
        "selected_speakers_hc": sel_hc,
        "selected_speakers_pd": sel_pd,
        "selected_speakers_all": selected_speakers,
        "selected_d2_rows": int(len(d2_sel)),
        "selected_d2_label_counts": d2_sel["label_num"].value_counts(dropna=False).to_dict(),
    },
    "preflight": preflight_stats,
}

early_summary = {
    "dataset": "D7",
    "created_utc": datetime.utcnow().isoformat(),
    "status": "PRECHECK_COMPLETE",
    "d7_out_root": str(D7_OUT_ROOT),
    "train_enh2_dir": str(TRAIN_ENH2_DIR),
    "preflight": preflight_stats,
    "selection": run_config["selection"],
}

atomic_write_csv(WARNINGS_CSV, pd.DataFrame(warnings_rows))
atomic_write_text(RUN_CONFIG_JSON, json.dumps(run_config, indent=2))
atomic_write_text(SUMMARY_JSON, json.dumps(early_summary, indent=2))

# Stop immediately if preflight found fatal conflicts.
fatal_pre = [w for w in warnings_rows if str(w.get("level", "")).upper() == "ERROR"]
if fatal_pre:
    raise RuntimeError(f"Preflight failed with {len(fatal_pre)} ERROR(s). See {str(WARNINGS_CSV)}")

# -------------------------
# 13) Copy clips into the train_enh2 folder (copy-only, no overwrite)
# -------------------------
# Copies only missing destination files; existing files are left untouched.
print("\nCopy stage: copying into train_enh2 (no overwrite)...")

copied = 0
skipped_exists = 0
copy_errors = 0

for item in tqdm(copy_plan, desc="Copying clips", dynamic_ncols=True):
    sp = Path(item["src_path"])
    dp = Path(item["dst_path"])

    if not sp.exists():
        # Protects against sources disappearing mid-run.
        copy_errors += 1
        add_warn("D7_TRAIN_ENH2", "ERROR", "SOURCE_CLIP_DISAPPEARED", "Source clip missing during copy stage", clip_path=str(sp))
        continue

    if dp.exists():
        skipped_exists += 1
        continue

    try:
        shutil.copy2(sp, dp)
        copied += 1
    except Exception as e:
        copy_errors += 1
        add_warn("D7_TRAIN_ENH2", "ERROR", "COPY_FAILED", "Failed to copy file", src_path=str(sp), dest_path=str(dp), error=repr(e))

copy_stats = {
    "copied": int(copied),
    "skipped_exists": int(skipped_exists),
    "copy_errors": int(copy_errors),
    "total_planned_files": int(len(copy_plan)),
    "warnings_by_level": count_by_level(warnings_rows),
}

# Persist warnings after the copy stage.
atomic_write_csv(WARNINGS_CSV, pd.DataFrame(warnings_rows))

# Stop if any fatal copy errors occurred.
fatal_copy = [w for w in warnings_rows if str(w.get("level", "")).upper() == "ERROR"]
if fatal_copy:
    run_config["copy_stats"] = copy_stats
    atomic_write_text(RUN_CONFIG_JSON, json.dumps(run_config, indent=2))
    fail_summary = dict(early_summary)
    fail_summary["status"] = "FAILED_DURING_COPY"
    fail_summary["copy_stats"] = copy_stats
    atomic_write_text(SUMMARY_JSON, json.dumps(fail_summary, indent=2))
    raise RuntimeError(f"Copy failed with {len(fatal_copy)} ERROR(s). See {str(WARNINGS_CSV)}")

# -------------------------
# 14) Build the train_enh2 manifest
# -------------------------
# Creates a single manifest that points to the new train_enh2 clip paths.
# - D7 rows: same clips and filenames, with clip_path pointing at the copy in train_enh2
# - D2 rows: new deterministic filenames, dataset set to D7, source_dataset set to D2
d7_enh2 = d7_train[CANON_COLS].copy()
d7_enh2["split"] = "train_enh2"
d7_enh2["dataset"] = "D7"
d7_enh2["clip_path"] = d7_train["clip_path"].astype(str).map(lambda p: str(TRAIN_ENH2_DIR / Path(p).name))
d7_enh2["sample_id"] = d7_enh2["clip_path"].map(lambda p: Path(p).stem)
d7_enh2["source_dataset"] = d7_train.get("source_dataset", pd.Series(["D7"] * len(d7_train))).astype(str).tolist()

d2_enh2 = d2_sel_reset[CANON_COLS].copy()
d2_enh2["split"] = "train_enh2"
d2_enh2["dataset"] = "D7"

# Recreate the deterministic output names so the manifest matches the copied files.
new_paths, new_ids = [], []
for j, row in d2_sel_reset.iterrows():
    hc_pd = hc_pd_from_label_num(row["label_num"])
    spk_tok = safe_token(row["speaker_id"], 32, "NA")
    task_tok = safe_token(row["task"], 12, "0")
    out_name = f"D7_D2add_train_enh2_{hc_pd}_{spk_tok}_{task_tok}_{j+1:06d}.wav"
    out_path = TRAIN_ENH2_DIR / out_name
    new_paths.append(str(out_path))
    new_ids.append(out_path.stem)

d2_enh2["clip_path"] = new_paths
d2_enh2["sample_id"] = new_ids
d2_enh2["source_dataset"] = "D2"

# Combine D7 base train + selected D2 additions into one training manifest.
train_enh2 = pd.concat([d7_enh2, d2_enh2], axis=0, ignore_index=True)
train_enh2 = train_enh2[FINAL_COLS].copy()

# Guard: do not allow the literal string "NaN" in output columns.
for c in FINAL_COLS:
    if train_enh2[c].dtype == object and (train_enh2[c] == "NaN").any():
        raise RuntimeError(f"Found literal string 'NaN' in column '{c}'.")

# Final check: every manifest clip_path must exist on disk.
missing_enh2 = []
for p in tqdm(train_enh2["clip_path"].astype(str).tolist(), desc="Check TRAIN_ENH2 clip_path exists", dynamic_ncols=True):
    if not os.path.exists(p):
        missing_enh2.append(p)
        if len(missing_enh2) >= 25:
            break
require(len(missing_enh2) == 0, f"train_enh2 manifest points to missing files. Examples: {missing_enh2[:25]}")

# -------------------------
# 15) Write the manifest and final summaries
# -------------------------
# Outputs:
# - manifest_train_enh2.csv (training input for the train_enh2 train+val cell)
# - dataset_summary.json and preprocess_warnings.csv (run record)
# - run_config.json updated with outputs and copy stats
atomic_write_csv(MANIFEST_TRAIN_ENH2, train_enh2)

final_summary = {
    "dataset": "D7",
    "created_utc": datetime.utcnow().isoformat(),
    "status": "SUCCESS",
    "d7_out_root": str(D7_OUT_ROOT),
    "train_enh2_dir": str(TRAIN_ENH2_DIR),
    "manifest_train_enh2": str(MANIFEST_TRAIN_ENH2),
    "selection": run_config["selection"],
    "copy_stats": copy_stats,
    "counts_train_enh2": summarize_counts(train_enh2, split_name="train_enh2"),
}

atomic_write_text(SUMMARY_JSON, json.dumps(final_summary, indent=2))

run_config["copy_stats"] = copy_stats
run_config["outputs"] = {
    "train_enh2_dir": str(TRAIN_ENH2_DIR),
    "manifest_train_enh2": str(MANIFEST_TRAIN_ENH2),
    "dataset_summary_json": str(SUMMARY_JSON),
    "preprocess_warnings_csv": str(WARNINGS_CSV),
    "run_config_json": str(RUN_CONFIG_JSON),
}
atomic_write_text(RUN_CONFIG_JSON, json.dumps(run_config, indent=2))

print("\n✅ D7 train_enh2 build complete (FIXED DENOMINATOR).")
print("- Train_enh2 folder:", str(TRAIN_ENH2_DIR))
print("- Manifest:", str(MANIFEST_TRAIN_ENH2))
print("- Summary:", str(SUMMARY_JSON))
print("- Warnings:", str(WARNINGS_CSV))
print("- Config:", str(RUN_CONFIG_JSON))

The following cell trains and validates the **D7 enhanced model (train_enh2)** using only two existing inputs: (1) `manifest_train_enh2.csv` filtered to rows where `split == "train_enh2"` for training, and (2) the standard D7 `manifest_all.csv` filtered to rows where `split == "val"` for validation. No new data splits are created. The cell simply reads the prepared manifests, checks that required columns are present, confirms that the validation data truly belongs to **D7**, and stops immediately if any listed audio file paths are missing.

The model uses a **frozen wav2vec2 backbone**, so the main speech encoder is not updated during training. Only small trainable components are updated: **two separate classification heads** (one for vowel clips and one for all other speech) and a small **LayerNorm plus Dropout block** in front of each head. Each clip is assigned to a task group using a strict rule: **`task == "vowl"` is treated as “vowel”, everything else as “other”**. During data loading, an attention mask is created so that for vowel clips the model **ignores trailing padded silence** (based on a small amplitude threshold), while other clips use full attention. This reduces the chance of learning from padding instead of real speech.
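
The masking rule can be sketched on plain Python lists. This is an illustration under stated assumptions rather than the notebook's code: the function names, the threshold value `1e-4`, and the fallback for an all-silent clip are hypothetical, and a real data loader would build the mask as a tensor.

```python
from typing import List, Sequence


def trailing_silence_mask(samples: Sequence[float], amp_threshold: float = 1e-4) -> List[int]:
    """Return 1 up to the last sample whose |amplitude| exceeds the threshold, 0 after it."""
    last_active = -1
    for i, x in enumerate(samples):
        if abs(x) > amp_threshold:
            last_active = i
    if last_active < 0:
        # Assumption: an all-silent clip keeps full attention instead of an empty mask.
        return [1] * len(samples)
    n_tail = len(samples) - last_active - 1
    return [1] * (last_active + 1) + [0] * n_tail


def build_attention_mask(samples: Sequence[float], task: str, amp_threshold: float = 1e-4) -> List[int]:
    """Vowel clips ignore trailing padding; every other task attends to all samples."""
    if task == "vowl":
        return trailing_silence_mask(samples, amp_threshold)
    return [1] * len(samples)
```

Only the trailing run is masked, so quiet samples inside the phonation still receive attention.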

Before training starts, the cell automatically finds the **most recent baseline D7 train and validation experiment** and uses it to **initialize the trainable heads**. A guard ensures that an enhanced run is not used as the baseline by excluding any prior experiment whose `summary_trainval.json` shows that it was trained on a `train_enh` split. Training is run for **three fixed random seeds** (1337, 2024, 7777). For each seed, the model trains for up to 10 epochs with early stopping (patience of 2 epochs), using gradient accumulation to reach an effective batch size of 64. After every epoch, the model is evaluated on the D7 validation set and the **best epoch is selected based on validation AUROC**.
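
The early-stopping and accumulation logic can be sketched without any training code. The helper names are hypothetical, the micro-batch size of 16 in the usage note is an assumed value (only the effective batch size of 64 comes from the text), and requiring a strict AUROC improvement to reset patience is also an assumption.

```python
from typing import List, Tuple


def accumulation_steps(effective_batch: int, micro_batch: int) -> int:
    """Number of micro-batches to accumulate before one optimizer step."""
    if effective_batch % micro_batch != 0:
        raise ValueError("effective_batch must be a multiple of micro_batch")
    return effective_batch // micro_batch


def select_best_epoch(val_aurocs: List[float], patience: int = 2) -> Tuple[int, float, int]:
    """Replay per-epoch validation AUROCs and apply early stopping.

    Returns (best_epoch, best_auroc, epochs_run), with epochs counted from 1.
    """
    best_epoch, best_auroc = 0, float("-inf")
    since_best, epochs_run = 0, 0
    for epoch, auroc in enumerate(val_aurocs, start=1):
        epochs_run = epoch
        if auroc > best_auroc:
            # New best epoch: remember it and reset the patience counter.
            best_epoch, best_auroc, since_best = epoch, auroc, 0
        else:
            since_best += 1
            if since_best >= patience:
                break  # no improvement for `patience` consecutive epochs
    return best_epoch, best_auroc, epochs_run
```

With made-up AUROCs `[0.70, 0.74, 0.73, 0.72, 0.80]`, training stops after epoch 4 and epoch 2 is kept, so the later value of 0.80 is never reached. `accumulation_steps(64, 16)` gives 4 accumulation steps per optimizer update.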

At the best validation AUROC epoch for each seed, the cell computes a **validation-optimal probability threshold** using **Youden’s J statistic** (the ROC point that maximizes TPR minus FPR). It also computes threshold-based validation metrics at both **0.5** and the **optimal threshold**: accuracy, precision, recall (sensitivity), specificity, F1, MCC, and the p-value of Fisher’s exact test. For each seed, the cell saves the best head weights (`best_heads.pt`), a validation ROC curve, confusion matrix plots, and a detailed `metrics.json` describing the settings, the baseline initialization used, and the results.
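
As a dependency-free illustration of the threshold selection, the sketch below scans every observed score as a candidate threshold and keeps the one with the largest J = TPR − FPR. It is not the notebook's implementation, which imports `roc_curve` from scikit-learn. The function names, the `p >= threshold` decision rule, and the tie-breaking (highest threshold wins) are assumptions.

```python
import math
from typing import Dict, Sequence, Tuple


def youden_j_threshold(y_true: Sequence[int], y_prob: Sequence[float]) -> Tuple[float, float]:
    """Return (threshold, J) maximizing J = TPR - FPR over the observed scores."""
    pos = sum(1 for y in y_true if y == 1)
    neg = len(y_true) - pos
    if pos == 0 or neg == 0:
        raise ValueError("Both classes are required to compute Youden's J")
    best_thr, best_j = 0.5, float("-inf")
    for thr in sorted(set(y_prob), reverse=True):
        tp = sum(1 for y, p in zip(y_true, y_prob) if y == 1 and p >= thr)
        fp = sum(1 for y, p in zip(y_true, y_prob) if y == 0 and p >= thr)
        j = tp / pos - fp / neg
        if j > best_j:
            best_thr, best_j = thr, j
    return best_thr, best_j


def metrics_at_threshold(y_true: Sequence[int], y_prob: Sequence[float], thr: float) -> Dict[str, float]:
    """Sensitivity, specificity, and MCC for predictions made with p >= thr."""
    tp = sum(1 for y, p in zip(y_true, y_prob) if y == 1 and p >= thr)
    fn = sum(1 for y, p in zip(y_true, y_prob) if y == 1 and p < thr)
    fp = sum(1 for y, p in zip(y_true, y_prob) if y == 0 and p >= thr)
    tn = sum(1 for y, p in zip(y_true, y_prob) if y == 0 and p < thr)
    denom = math.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
    return {
        "sensitivity": tp / (tp + fn) if (tp + fn) else 0.0,
        "specificity": tn / (tn + fp) if (tn + fp) else 0.0,
        "mcc": (tp * tn - fp * fn) / denom if denom else 0.0,
    }
```

Comparing the output at 0.5 with the output at the selected threshold shows how much sensitivity or specificity moves when the cut-off is tuned on validation data.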

After all three seeds complete, the cell writes a single experiment-level `summary_trainval.json` inside a new timestamped folder under `trainval_runs/exp_<tag>_<timestamp>/`. This summary includes per-seed AUROC values, the mean AUROC with a 95% confidence interval across seeds, and the **canonical validation threshold summary**, stored as:

* `val_optimal_threshold.by_seed`
* `val_optimal_threshold.mean_sd.mean`
* `val_optimal_threshold.mean_sd.sd`
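
The shape of this aggregate can be sketched as follows. The key names come from the list above; `summarize_thresholds` is a hypothetical helper, and the use of the sample standard deviation (n − 1 in the denominator) is an assumption, since the text does not say which variant the notebook uses.

```python
import statistics
from typing import Dict


def summarize_thresholds(by_seed: Dict[int, float]) -> dict:
    """Build the val_optimal_threshold block from per-seed thresholds."""
    values = [float(v) for v in by_seed.values()]
    # Assumption: sample SD (ddof=1); falls back to 0.0 when only one seed is given.
    sd = statistics.stdev(values) if len(values) > 1 else 0.0
    return {
        "val_optimal_threshold": {
            "by_seed": {str(seed): float(thr) for seed, thr in by_seed.items()},
            "mean_sd": {"mean": statistics.mean(values), "sd": sd},
        }
    }
```

Called with made-up thresholds such as `{1337: 0.40, 2024: 0.50, 7777: 0.60}` (illustrative inputs, not results from the study), the mean is 0.5 and the sample SD is 0.1.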

The full experiment summary is also appended as a new line to `trainval_runs/history_index.jsonl`, which keeps a running record of all experiments. The cell finishes by **unassigning the Colab runtime** to shut down the GPU instance.
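
Appending to a JSON Lines history file can be sketched as below. `append_history` and `read_history` are hypothetical helper names, and the record fields in the usage note are placeholders, not the actual summary schema.

```python
import json
from pathlib import Path
from typing import List


def append_history(history_path: Path, summary: dict) -> None:
    """Append one experiment summary as a single JSON line."""
    history_path.parent.mkdir(parents=True, exist_ok=True)
    with open(history_path, "a", encoding="utf-8") as f:
        f.write(json.dumps(summary, sort_keys=True) + "\n")


def read_history(history_path: Path) -> List[dict]:
    """Read every record back, skipping blank lines."""
    with open(history_path, "r", encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]
```

Because each record occupies exactly one line, earlier experiments are never rewritten, and the file can be loaded later with `read_history(Path("trainval_runs/history_index.jsonl"))` to compare runs.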

In [None]:
# =========================
# D7 train_enh2 Train + Val (Frozen Backbone, Two Heads)
# Inputs: D7 train_enh2 train manifest + D7 val split from the full manifest
# Outputs: Per-seed best heads, per-seed metrics and plots, experiment summary, history log
# =========================
# Train + Val ONLY (CRASH-PROOF, WITH PROGRESS + HISTORY) — D7 ENHANCED (train_enh2)
# - Frozen Wav2Vec2 backbone
# - Two task heads + tiny LayerNorm + Dropout pre-head (trainable heads only)
# - Uses ONLY: <DX_OUT_ROOT>/manifests/manifest_train_enh2.csv  (split == "train_enh2")
#            + <DX_OUT_ROOT>/manifests/manifest_all.csv         (split == "val")
# - Initializes heads from MOST RECENT BASELINE D7 trainval experiment under:
#     <DX_OUT_ROOT>/trainval_runs/exp_*/run_D7_seed{seed}/best_heads.pt
#   Baseline guard: excludes any exp whose summary shows train_manifest_used contains "train_enh"
# - Writes ONLY under: <DX_OUT_ROOT>/trainval_runs/exp_<tag>_<timestamp>/
# - Saves best-epoch plots + metrics per seed, plus per-experiment summary + history_index.jsonl
# - Adds additional metrics: accuracy, precision, recall/sensitivity, specificity, F1, MCC, Fisher p-value
# - Determines VAL-opt threshold via Youden J at the BEST-AUROC epoch (per seed)
# - Stores thresholds ONLY as the canonical aggregate in summary_trainval.json:
#     val_optimal_threshold.by_seed
#     val_optimal_threshold.mean_sd.mean
#     val_optimal_threshold.mean_sd.sd
# - Ends by unassigning Colab runtime (L4) with messages
#
# NOTES
# - Train rows come from the prebuilt train_enh2 manifest (no resplitting here)
# - Val rows come from the standard D7 val split in the full manifest
# - dataset_id is inferred from the val manifest (expected "D7") and used for folder naming
# =========================

import os, json, math, random, time, warnings
from pathlib import Path

import numpy as np
import pandas as pd
import soundfile as sf

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

from transformers import Wav2Vec2Model

from sklearn.metrics import (
    roc_auc_score, roc_curve,
    confusion_matrix, accuracy_score,
    precision_recall_fscore_support,
    matthews_corrcoef
)
import matplotlib.pyplot as plt
from tqdm.auto import tqdm

# -------------------------
# 0) Safety check: prevent importing the wrong torch/transformers
# -------------------------
# Colab can accidentally import a local file named torch.py or transformers.py.
# Fail early so results are not silently corrupted.
if os.path.exists("/content/torch.py") or os.path.exists("/content/torch/__init__.py"):
    raise RuntimeError("Found local /content/torch.py or /content/torch/ that shadows PyTorch. Rename/remove it and restart runtime.")
if os.path.exists("/content/transformers.py") or os.path.exists("/content/transformers/__init__.py"):
    raise RuntimeError("Found local /content/transformers.py or /content/transformers/ that shadows Hugging Face Transformers. Rename/remove it and restart runtime.")

# -------------------------
# 1) Mount Drive (safe if already mounted)
# -------------------------
# Needed to read manifests and write training outputs.
try:
    from google.colab import drive  # type: ignore
    if not os.path.isdir("/content/drive/MyDrive"):
        drive.mount("/content/drive")
except Exception:
    pass

# -------------------------
# 2) Set run root and input manifest files
# -------------------------
# Inputs: train_enh2 manifest (train split) + full manifest (val split).
D7_OUT_ROOT_FALLBACK = "/content/drive/MyDrive/AI_PD_Project/Datasets/D7-Multilingual (D1_D4_D5v2_D6)/preprocessed_v1"
DX_OUT_ROOT = str(globals().get("DX_OUT_ROOT", D7_OUT_ROOT_FALLBACK))
globals()["DX_OUT_ROOT"] = DX_OUT_ROOT

MANIFEST_ALL = f"{DX_OUT_ROOT}/manifests/manifest_all.csv"

# train_enh2 manifest + split label used inside the file
MANIFEST_TRAIN_ENH2 = f"{DX_OUT_ROOT}/manifests/manifest_train_enh2.csv"
TRAIN_SPLIT_NAME = "train_enh2"

# -------------------------
# 3) Experiment naming and output folder
# -------------------------
# Outputs: a new exp_* folder containing 3 seed runs + an experiment summary.
EXPERIMENT_TAG = "frozen_LNDO_trainEnh2_initBaseline"
RUN_STAMP = time.strftime("%Y%m%d_%H%M%S")

TRAINVAL_ROOT = Path(DX_OUT_ROOT) / "trainval_runs"
EXP_ROOT = TRAINVAL_ROOT / f"exp_{EXPERIMENT_TAG}_{RUN_STAMP}"
EXP_ROOT.mkdir(parents=True, exist_ok=True)

# -------------------------
# 4) Fixed settings (kept stable for repeatable comparisons)
# -------------------------
# These settings define training length, batch behavior, and model configuration.
MAX_EPOCHS     = 10
EFFECTIVE_BS   = 64
PER_DEVICE_BS  = 16
GRAD_ACCUM     = max(1, EFFECTIVE_BS // PER_DEVICE_BS)

LR             = 1e-3
PATIENCE       = 2
SEEDS          = [1337, 2024, 7777]

BACKBONE_CKPT  = "facebook/wav2vec2-base"
SR_EXPECTED    = 16000
TINY_THRESH    = 1e-4

DROPOUT_P      = 0.2

NUM_WORKERS    = 0
PIN_MEMORY     = False

VOWEL_TASK_VALUE = "vowl"

# Pick GPU if available, otherwise run on CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
try:
    torch.set_float32_matmul_precision("high")
except Exception:
    pass

# Reduce noisy warnings that do not affect results.
warnings.filterwarnings("ignore", message="None of the inputs have requires_grad=True. Gradients will be None")
warnings.filterwarnings("ignore", message="Passing `gradient_checkpointing` to a config initialization is deprecated")
warnings.filterwarnings("ignore", category=UserWarning, module="huggingface_hub")

# Print key run info.
print("DX_OUT_ROOT:", DX_OUT_ROOT)
print("MANIFEST_TRAIN_ENH2:", MANIFEST_TRAIN_ENH2)
print("TRAIN_SPLIT_NAME:", TRAIN_SPLIT_NAME)
print("MANIFEST_ALL (val source):", MANIFEST_ALL)
print("DEVICE:", DEVICE)
if DEVICE.type == "cuda":
    print("GPU:", torch.cuda.get_device_name(0))
print("PER_DEVICE_BS:", PER_DEVICE_BS, "| GRAD_ACCUM:", GRAD_ACCUM, "| EFFECTIVE_BS:", PER_DEVICE_BS * GRAD_ACCUM)
print("NUM_WORKERS:", NUM_WORKERS, "| PIN_MEMORY:", PIN_MEMORY)
print("EXPERIMENT_TAG:", EXPERIMENT_TAG, "| RUN_STAMP:", RUN_STAMP)
print("EXP_ROOT:", str(EXP_ROOT))

# -------------------------
# 5) Read manifests and build train/val tables
# -------------------------
# Inputs:
# - Train: manifest_train_enh2.csv filtered to split == train_enh2
# - Val:   manifest_all.csv filtered to split == val
# Output: train_df and val_df with a consistent set of columns.
if not os.path.exists(MANIFEST_TRAIN_ENH2):
    raise FileNotFoundError(
        "Missing manifest_train_enh2.csv at:\n"
        f"  {MANIFEST_TRAIN_ENH2}\n"
        "Run the train_enh2 builder first."
    )
if not os.path.exists(MANIFEST_ALL):
    raise FileNotFoundError(
        "Missing manifest_all.csv at:\n"
        f"  {MANIFEST_ALL}\n"
        "Confirm D7 merge-builder wrote manifests/manifest_all.csv under DX_OUT_ROOT."
    )

m_train = pd.read_csv(MANIFEST_TRAIN_ENH2)
m_all   = pd.read_csv(MANIFEST_ALL)

# Basic schema check (only the columns needed later).
req_cols = {"split", "clip_path", "label_num", "task"}
for name, df in [("manifest_train_enh2", m_train), ("manifest_all", m_all)]:
    missing = [c for c in sorted(req_cols) if c not in df.columns]
    if missing:
        raise ValueError(f"{name} missing required columns: {missing}. Found: {list(df.columns)}")

# Split selection (train_enh2 vs val).
m_train = m_train[m_train["split"].astype(str) == TRAIN_SPLIT_NAME].copy()
m_val   = m_all[m_all["split"].astype(str) == "val"].copy()

if len(m_train) == 0:
    raise RuntimeError(f"After filtering manifest_train_enh2.csv to split=={TRAIN_SPLIT_NAME!r}, 0 rows remain.")
if len(m_val) == 0:
    raise RuntimeError("After filtering manifest_all.csv to split=='val', 0 rows remain.")

# Infer dataset_id from val (expected D7); used for naming run folders.
if "dataset" in m_val.columns and m_val["dataset"].notna().any():
    dataset_id = str(m_val["dataset"].astype(str).value_counts(dropna=True).idxmax())
    m_val = m_val[m_val["dataset"].astype(str) == dataset_id].copy()
else:
    dataset_id = "DX"

# Hard guard: this cell is only for D7.
if dataset_id != "D7":
    raise RuntimeError(f"Dataset inferred from VAL manifest is {dataset_id!r}. Expected 'D7'. Check DX_OUT_ROOT/manifests/manifest_all.csv.")

# Keep a stable minimal column set across train and val.
keep_cols = ["clip_path", "label_num", "task", "speaker_id", "duration_sec", "split"]
for df in [m_train, m_val]:
    for c in keep_cols:
        if c not in df.columns:
            df[c] = np.nan

m_train = m_train[keep_cols].copy()
m_val   = m_val[keep_cols].copy()

train_df = m_train.copy().reset_index(drop=True)
val_df   = m_val.copy().reset_index(drop=True)

# Quick split stats.
print(f"\nDataset inferred (from VAL): {dataset_id}")
print(f"Train rows ({TRAIN_SPLIT_NAME}): {len(train_df)} | Val rows: {len(val_df)}")
print("Train label counts:", train_df["label_num"].value_counts(dropna=False).to_dict())
print("Val label counts:",   val_df["label_num"].value_counts(dropna=False).to_dict())

# -------------------------
# 6) Fail fast if audio files are missing
# -------------------------
# This prevents long training runs from failing halfway through.
def _fail_fast_missing_paths(df: pd.DataFrame, name: str):
    missing_paths = []
    for p in tqdm(df["clip_path"].astype(str).tolist(), desc=f"Check {name} clip_path exists", dynamic_ncols=True):
        if not os.path.exists(p):
            missing_paths.append(p)
            if len(missing_paths) >= 10:
                break
    if missing_paths:
        raise FileNotFoundError(f"{name}: missing clip_path(s). Examples: {missing_paths}")

_fail_fast_missing_paths(train_df, "TRAIN_ENH2")
_fail_fast_missing_paths(val_df, "VAL")

# -------------------------
# 7) Map each clip to a task group
# -------------------------
# Rule: task == "vowl" is treated as vowel, everything else as other.
def _task_group(task_val) -> str:
    return "vowel" if str(task_val) == VOWEL_TASK_VALUE else "other"

train_df["task_group"] = train_df["task"].apply(_task_group)
val_df["task_group"]   = val_df["task"].apply(_task_group)

# -------------------------
# 8) Dataset and batch collation
# -------------------------
# Inputs: the manifest tables (train_df, val_df).
# Outputs: padded batches of waveforms, attention masks, labels, and task_group.
class AudioManifestDataset(Dataset):
    """
    Reads one audio clip and builds a sample-level attention mask.

    Mask rule:
    - vowel: mask trailing near-silence so padding does not drive learning
    - other: keep full attention
    """
    def __init__(self, df: pd.DataFrame):
        self.df = df.reset_index(drop=True)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx: int):
        row = self.df.iloc[idx]
        clip_path = str(row["clip_path"])
        label = int(row["label_num"])
        task_group = str(row["task_group"])

        if not os.path.exists(clip_path):
            raise FileNotFoundError(f"Missing clip_path: {clip_path}")

        # Load audio, convert to mono if needed, enforce dtype.
        y, sr = sf.read(clip_path, always_2d=False)
        if isinstance(y, np.ndarray) and y.ndim > 1:
            y = np.mean(y, axis=1)
        y = np.asarray(y, dtype=np.float32)

        # Guard: model assumes a single sample rate.
        if int(sr) != SR_EXPECTED:
            raise ValueError(f"Sample rate mismatch (expected {SR_EXPECTED}): {clip_path} has sr={sr}")

        # Build attention mask in sample space.
        attn = np.ones((len(y),), dtype=np.int64)

        if task_group == "vowel":
            # Find the last sample that is not near-zero, then mask everything after it.
            # If the whole clip is near-silent, keep full attention.
            # (attn already starts as all ones, so "other" clips need no change.)
            nonsilent = np.flatnonzero(np.abs(y).astype(np.float64) > TINY_THRESH)
            if nonsilent.size > 0:
                attn[int(nonsilent[-1]) + 1:] = 0

        return {
            "input_values": torch.from_numpy(y),                 # float32 [T]
            "attention_mask": torch.from_numpy(attn),            # int64   [T]
            "labels": torch.tensor(label, dtype=torch.long),     # int64   []
            "task_group": task_group,                            # str
        }

def collate_fn(batch):
    """Pads waveforms and masks to the longest clip in the batch."""
    max_len = int(max(b["input_values"].numel() for b in batch))
    input_vals, attn_masks, labels, task_groups = [], [], [], []
    for b in batch:
        x = b["input_values"]
        a = b["attention_mask"]
        pad = max_len - x.numel()
        if pad > 0:
            x = torch.cat([x, torch.zeros(pad, dtype=x.dtype)], dim=0)
            a = torch.cat([a, torch.zeros(pad, dtype=a.dtype)], dim=0)
        input_vals.append(x)
        attn_masks.append(a)
        labels.append(b["labels"])
        task_groups.append(b["task_group"])
    return {
        "input_values": torch.stack(input_vals, dim=0),      # [B,T]
        "attention_mask": torch.stack(attn_masks, dim=0),    # [B,T]
        "labels": torch.stack(labels, dim=0),                # [B]
        "task_group": task_groups,                           # list[str]
    }

# -------------------------
# 9) Model: frozen backbone with two task-specific heads
# -------------------------
# Backbone: wav2vec2-base (frozen)
# Trainable: LayerNorm+Dropout blocks + two linear heads (vowel/other)
class Wav2Vec2TwoHeadClassifier(nn.Module):
    """
    Frozen backbone with two heads.
    Only head-related layers are updated during training.
    """
    def __init__(self, ckpt: str, dropout_p: float):
        super().__init__()
        self.backbone = Wav2Vec2Model.from_pretrained(
            ckpt,
            use_safetensors=True,
            local_files_only=False
        )
        for p in self.backbone.parameters():
            p.requires_grad = False

        H = int(self.backbone.config.hidden_size)
        self.pre_vowel = nn.Sequential(nn.LayerNorm(H), nn.Dropout(float(dropout_p)))
        self.pre_other = nn.Sequential(nn.LayerNorm(H), nn.Dropout(float(dropout_p)))

        self.head_vowel = nn.Linear(H, 2)
        self.head_other = nn.Linear(H, 2)
        self.loss_fn = nn.CrossEntropyLoss()

    def masked_mean_pool(self, last_hidden, attn_mask_samples):
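        # Convert sample-level mask to feature-frame mask, then do masked mean pooling.
        # Note: _get_feature_vector_attention_mask is a private Transformers helper
        # (underscore-prefixed), so it may change between library versions.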
        feat_mask = self.backbone._get_feature_vector_attention_mask(last_hidden.shape[1], attn_mask_samples)
        feat_mask = feat_mask.to(last_hidden.device).unsqueeze(-1).type_as(last_hidden)
        masked = last_hidden * feat_mask
        denom = feat_mask.sum(dim=1).clamp(min=1.0)
        return masked.sum(dim=1) / denom

    def forward(self, input_values, attention_mask, labels, task_group):
        # Feature extraction is frozen; gradients flow only through the heads.
        with torch.no_grad():
            out = self.backbone(input_values=input_values, attention_mask=attention_mask)
            last_hidden = out.last_hidden_state  # [B,T',H]

        pooled = self.masked_mean_pool(last_hidden, attention_mask).float()  # [B,H]

        z_v = self.pre_vowel(pooled)
        z_o = self.pre_other(pooled)

        logits_v = self.head_vowel(z_v)
        logits_o = self.head_other(z_o)

        # Select the head based on task_group for each sample in the batch.
        is_vowel = torch.tensor([tg == "vowel" for tg in task_group], device=pooled.device, dtype=torch.bool)
        logits = logits_o.clone()
        if is_vowel.any():
            logits[is_vowel] = logits_v[is_vowel]

        loss = self.loss_fn(logits, labels)
        return loss, logits

# -------------------------
# 9.5) Head initialization from baseline run
# -------------------------
# Starts train_enh2 heads from the most recent baseline heads for the same seed.
def load_heads_into_model(model: Wav2Vec2TwoHeadClassifier, best_heads_path: Path):
    if not best_heads_path.exists():
        raise FileNotFoundError(f"Missing best_heads.pt: {str(best_heads_path)}")
    state = torch.load(str(best_heads_path), map_location="cpu")
    model.pre_vowel.load_state_dict(state["pre_vowel"], strict=True)
    model.pre_other.load_state_dict(state["pre_other"], strict=True)
    model.head_vowel.load_state_dict(state["head_vowel"], strict=True)
    model.head_other.load_state_dict(state["head_other"], strict=True)
    return model

# -------------------------
# 10) Metrics and plotting helpers
# -------------------------
# Outputs: AUROC, threshold-based metrics, optional Youden-J threshold, and PNG plots.
def compute_auc(y_true, y_prob):
    y_true = np.asarray(y_true, dtype=np.int64)
    y_prob = np.asarray(y_prob, dtype=np.float64)
    if len(np.unique(y_true)) < 2:
        return float("nan")
    return float(roc_auc_score(y_true, y_prob))

def compute_threshold_metrics(y_true, y_prob, thr=0.5):
    # Converts probabilities into predicted labels, then reports common classification metrics.
    y_true = np.asarray(y_true, dtype=np.int64)
    y_prob = np.asarray(y_prob, dtype=np.float64)
    y_pred = (y_prob >= float(thr)).astype(np.int64)

    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
    tn, fp, fn, tp = (cm.ravel().tolist() if cm.size == 4 else [0, 0, 0, 0])

    acc = float(accuracy_score(y_true, y_pred))
    prec, rec, f1, _ = precision_recall_fscore_support(y_true, y_pred, average="binary", zero_division=0)
    mcc = float(matthews_corrcoef(y_true, y_pred)) if len(np.unique(y_true)) > 1 else float("nan")

    sensitivity = float(rec)
    specificity = float(tn / (tn + fp)) if (tn + fp) > 0 else float("nan")

    # Fisher exact test on the 2x2 confusion matrix (if available).
    p_value = float("nan")
    try:
        from scipy.stats import fisher_exact  # type: ignore
        _, p_value = fisher_exact([[tn, fp], [fn, tp]], alternative="two-sided")
        p_value = float(p_value)
    except Exception:
        p_value = float("nan")

    return {
        "threshold": float(thr),
        "tn": int(tn), "fp": int(fp), "fn": int(fn), "tp": int(tp),
        "accuracy": float(acc),
        "precision": float(prec),
        "recall": float(rec),
        "f1": float(f1),
        "sensitivity": float(sensitivity),
        "specificity": float(specificity),
        "mcc": float(mcc),
        "p_value_fisher": float(p_value),
    }

def compute_youden_j_threshold(y_true, y_prob):
    # Picks the threshold that maximizes (TPR - FPR) on the ROC curve.
    y_true = np.asarray(y_true, dtype=np.int64)
    y_prob = np.asarray(y_prob, dtype=np.float64)
    if len(np.unique(y_true)) < 2:
        return float("nan"), {"youden_j": float("nan"), "tpr": float("nan"), "fpr": float("nan")}
    fpr, tpr, thr = roc_curve(y_true, y_prob)
    j = tpr - fpr
    idx = int(np.argmax(j))
    return float(thr[idx]), {"youden_j": float(j[idx]), "tpr": float(tpr[idx]), "fpr": float(fpr[idx])}

def save_roc_curve_png(y_true, y_prob, out_png):
    # Saves a simple ROC curve image for the best epoch (val).
    fpr, tpr, _ = roc_curve(y_true, y_prob)
    plt.figure()
    plt.plot(fpr, tpr)
    plt.plot([0, 1], [0, 1], linestyle="--")
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    plt.title("ROC Curve (Val)")
    plt.tight_layout()
    plt.savefig(out_png, dpi=150)
    plt.close()

def save_confusion_png(y_true, y_prob, out_png, thr=0.5):
    # Saves a confusion matrix image for a chosen threshold.
    y_true = np.asarray(y_true, dtype=np.int64)
    y_prob = np.asarray(y_prob, dtype=np.float64)
    y_pred = (y_prob >= float(thr)).astype(int)
    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])

    plt.figure()
    plt.imshow(cm)
    plt.xticks([0, 1], ["Healthy(0)", "PD(1)"])
    plt.yticks([0, 1], ["Healthy(0)", "PD(1)"])
    plt.xlabel("Predicted")
    plt.ylabel("True")
    plt.title(f"Confusion Matrix (Val, thr={thr:.4f})")
    for (i, j), v in np.ndenumerate(cm):
        plt.text(j, i, str(v), ha="center", va="center")
    plt.tight_layout()
    plt.savefig(out_png, dpi=150)
    plt.close()

def mean_sd(vals):
    # Small helper for mean and sample SD (used for 3-seed summaries).
    vals = np.asarray(vals, dtype=np.float64)
    mu = float(np.nanmean(vals)) if np.any(~np.isnan(vals)) else float("nan")
    sd = float(np.nanstd(vals, ddof=1)) if np.sum(~np.isnan(vals)) > 1 else 0.0
    return mu, sd

# -------------------------
# 11) Seed control for repeatable runs
# -------------------------
def set_all_seeds(seed: int):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# -------------------------
# 11.5) Find the most recent BASELINE experiment for head init
# -------------------------
# Baseline is detected by reading summary_trainval.json and rejecting any run that used train_enh manifests.
BASELINE_TRAINVAL_ROOT = Path(DX_OUT_ROOT) / "trainval_runs"
if not BASELINE_TRAINVAL_ROOT.exists():
    raise FileNotFoundError(f"Missing trainval_runs folder under DX_OUT_ROOT: {str(BASELINE_TRAINVAL_ROOT)}")

exp_dirs = sorted([p for p in BASELINE_TRAINVAL_ROOT.glob("exp_*") if p.is_dir()],
                  key=lambda p: p.stat().st_mtime, reverse=True)
if not exp_dirs:
    raise FileNotFoundError(f"No exp_* folders found under: {str(BASELINE_TRAINVAL_ROOT)}")

train_dataset_id = "D7"

def _is_baseline_exp(exp_path: Path) -> bool:
    # Uses summary_trainval.json as the record of which train manifest was used.
    summary_path = exp_path / "summary_trainval.json"
    if not summary_path.exists():
        return False
    try:
        with open(summary_path, "r", encoding="utf-8") as f:
            s = json.load(f)
        train_manifest_used = str(s.get("train_manifest_used", "")).lower()
        # "manifest_train_enh..." also contains "train_enh", so one check covers both.
        if "train_enh" in train_manifest_used:
            return False
        return True
    except Exception:
        return False

def _has_all_seeds(exp_path: Path, dataset_id: str, seeds: list) -> bool:
    # Ensures each seed has a best_heads.pt file.
    for s in seeds:
        p = exp_path / f"run_{dataset_id}_seed{s}" / "best_heads.pt"
        if not p.exists():
            return False
    return True

baseline_exp = None
for ed in exp_dirs:
    if ed.resolve() == EXP_ROOT.resolve():
        continue
    if _is_baseline_exp(ed) and _has_all_seeds(ed, train_dataset_id, SEEDS):
        baseline_exp = ed
        break

if baseline_exp is None:
    raise FileNotFoundError(
        "Could not find a BASELINE D7 trainval experiment with all 3 best_heads.pt files.\n"
        "Baseline guard excludes experiments whose summary_trainval.json shows train_manifest_used contains 'train_enh'.\n"
        f"Searched under: {str(BASELINE_TRAINVAL_ROOT)}/exp_*/run_D7_seedXXXX/best_heads.pt"
    )

baseline_summary_path = baseline_exp / "summary_trainval.json"
with open(baseline_summary_path, "r", encoding="utf-8") as f:
    baseline_summary = json.load(f)

# Print baseline selection.
print("\nBaseline initialization experiment selected:")
print(" ", str(baseline_exp))
print(" ", "summary:", str(baseline_summary_path))
print(" ", "train_manifest_used (baseline):", baseline_summary.get("train_manifest_used", "NA"))

# -------------------------
# 12) One seed run: train until no improvement, keep best epoch
# -------------------------
# Outputs per seed: best_heads.pt, metrics.json, ROC and confusion plots.
def run_trainval_once(seed: int):
    set_all_seeds(seed)

    run_dir = EXP_ROOT / f"run_{dataset_id}_seed{seed}"
    run_dir.mkdir(parents=True, exist_ok=True)

    # Build datasets and loaders for this seed.
    train_ds = AudioManifestDataset(train_df)
    val_ds   = AudioManifestDataset(val_df)

    train_loader = DataLoader(
        train_ds,
        batch_size=PER_DEVICE_BS,
        shuffle=True,
        num_workers=NUM_WORKERS,
        pin_memory=PIN_MEMORY,
        collate_fn=collate_fn
    )
    val_loader = DataLoader(
        val_ds,
        batch_size=PER_DEVICE_BS,
        shuffle=False,
        num_workers=NUM_WORKERS,
        pin_memory=PIN_MEMORY,
        collate_fn=collate_fn
    )

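    # Quick loader check so failures happen immediately.
    # Cap the warm-up at the number of available batches to avoid StopIteration on tiny splits.
    n_warm = min(3, len(train_loader))
    print(f"\n[seed={seed}] Warm-up: loading {n_warm} train batches...")
    t0 = time.time()
    it = iter(train_loader)
    for i in range(n_warm):
        _ = next(it)
        print(f"  loaded warmup batch {i+1}/{n_warm}")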
    print(f"[seed={seed}] Warm-up done in {time.time()-t0:.2f}s")

    # Create model and initialize heads from baseline (same seed).
    model = Wav2Vec2TwoHeadClassifier(BACKBONE_CKPT, dropout_p=DROPOUT_P).to(DEVICE)

    baseline_heads_path = baseline_exp / f"run_{train_dataset_id}_seed{seed}" / "best_heads.pt"
    print(f"[seed={seed}] Initializing heads from baseline:")
    print(" ", str(baseline_heads_path))
    model = load_heads_into_model(model, baseline_heads_path)
    model.train()

    # Only head-related layers are optimized.
    trainable_params = (
        list(model.pre_vowel.parameters()) + list(model.pre_other.parameters()) +
        list(model.head_vowel.parameters()) + list(model.head_other.parameters())
    )
    opt = torch.optim.Adam(trainable_params, lr=LR)

    # Track best epoch by val AUROC.
    best_auc = -1.0
    best_epoch = -1
    no_improve = 0

    best_state = None
    best_val_probs = None
    best_val_true = None

    # Store threshold details for the best epoch only.
    best_thr_youden = float("nan")
    best_thr_youden_details = None
    best_val_metrics_thr05 = None
    best_val_metrics_thr_opt = None

    for epoch in range(1, MAX_EPOCHS + 1):
        model.train()
        train_losses = []
        opt.zero_grad(set_to_none=True)

        # Train loop (with grad accumulation to reach the effective batch size).
        pbar = tqdm(train_loader, desc=f"[seed={seed}] Train epoch {epoch}", dynamic_ncols=True)
        step = 0
        for batch in pbar:
            step += 1
            input_values = batch["input_values"].to(DEVICE, non_blocking=False)
            attention_mask = batch["attention_mask"].to(DEVICE, non_blocking=False)
            labels = batch["labels"].to(DEVICE, non_blocking=False)
            task_group = batch["task_group"]

            loss, _ = model(input_values, attention_mask, labels, task_group)
            loss = loss / GRAD_ACCUM
            loss.backward()

            train_losses.append(float(loss.detach().cpu().item()) * GRAD_ACCUM)

            if (step % GRAD_ACCUM) == 0:
                opt.step()
                opt.zero_grad(set_to_none=True)

        # Flush any partial accumulation at the end of the epoch.
        if (step % GRAD_ACCUM) != 0:
            opt.step()
            opt.zero_grad(set_to_none=True)

        avg_train_loss = float(np.mean(train_losses)) if train_losses else float("nan")

        # Val loop: collect probabilities for AUROC and threshold selection.
        model.eval()
        all_probs, all_true = [], []
        vpbar = tqdm(val_loader, desc=f"[seed={seed}] Val epoch {epoch}", dynamic_ncols=True)
        with torch.inference_mode():
            for batch in vpbar:
                input_values = batch["input_values"].to(DEVICE, non_blocking=False)
                attention_mask = batch["attention_mask"].to(DEVICE, non_blocking=False)
                labels = batch["labels"].to(DEVICE, non_blocking=False)
                task_group = batch["task_group"]

                _, logits = model(input_values, attention_mask, labels, task_group)
                probs = torch.softmax(logits.float(), dim=-1)[:, 1].detach().cpu().numpy().astype(np.float64)
                all_probs.extend(probs.tolist())
                all_true.extend(labels.detach().cpu().numpy().astype(np.int64).tolist())

        val_auc = compute_auc(all_true, all_probs)
        print(f"seed={seed} | epoch {epoch:02d}/{MAX_EPOCHS} | train_loss={avg_train_loss:.5f} | val_AUROC={val_auc:.5f}")

        # Update "best" only when AUROC improves.
        improved = (not math.isnan(val_auc)) and (val_auc > best_auc + 1e-12)
        if improved:
            best_auc = float(val_auc)
            best_epoch = int(epoch)
            no_improve = 0

            # Save only head-related weights (backbone is frozen and unchanged).
            best_state = {
                "pre_vowel": {k: v.detach().cpu().clone() for k, v in model.pre_vowel.state_dict().items()},
                "pre_other": {k: v.detach().cpu().clone() for k, v in model.pre_other.state_dict().items()},
                "head_vowel": {k: v.detach().cpu().clone() for k, v in model.head_vowel.state_dict().items()},
                "head_other": {k: v.detach().cpu().clone() for k, v in model.head_other.state_dict().items()},
            }

            best_val_probs = list(all_probs)
            best_val_true  = list(all_true)

            # Save metrics at a fixed threshold (0.5) and at the val-opt threshold.
            best_val_metrics_thr05 = compute_threshold_metrics(best_val_true, best_val_probs, thr=0.5)

            thr_opt, details = compute_youden_j_threshold(best_val_true, best_val_probs)
            best_thr_youden = float(thr_opt)
            best_thr_youden_details = details
            best_val_metrics_thr_opt = compute_threshold_metrics(best_val_true, best_val_probs, thr=best_thr_youden)
        else:
            no_improve += 1

        # Early stop after PATIENCE epochs without improvement.
        if no_improve >= PATIENCE:
            break

    if best_state is None or best_val_probs is None or best_val_true is None:
        raise RuntimeError(
            "No best epoch captured. Validation AUROC may be NaN due to single-class validation split "
            "or earlier failures."
        )

    # Save best heads for this seed.
    best_heads_path = run_dir / "best_heads.pt"
    torch.save(best_state, str(best_heads_path))

    # Save plots for the best epoch only.
    roc_png = run_dir / "roc_curve.png"
    cm_png_05 = run_dir / "confusion_matrix_thr0p5.png"
    cm_png_opt = run_dir / "confusion_matrix_thr_opt.png"

    save_roc_curve_png(np.asarray(best_val_true, dtype=np.int64), np.asarray(best_val_probs, dtype=np.float64), str(roc_png))
    save_confusion_png(np.asarray(best_val_true, dtype=np.int64), np.asarray(best_val_probs, dtype=np.float64), str(cm_png_05), thr=0.5)
    if not np.isnan(best_thr_youden):
        save_confusion_png(np.asarray(best_val_true, dtype=np.int64), np.asarray(best_val_probs, dtype=np.float64), str(cm_png_opt), thr=float(best_thr_youden))

    # Write per-seed metrics (includes val-opt threshold for later test-time reuse).
    metrics = {
        "dataset": dataset_id,
        "seed": int(seed),
        "best_val_auroc": float(best_auc),
        "best_epoch": int(best_epoch),

        "train_manifest_used": MANIFEST_TRAIN_ENH2,
        "val_manifest_used": MANIFEST_ALL,

        "init_heads": {
            "mode": "baseline_best_heads",
            "baseline_exp_used": str(baseline_exp),
            "baseline_summary_path": str(baseline_summary_path),
            "baseline_best_heads_path": str(baseline_heads_path),
        },

        "n_train": int(len(train_df)),
        "n_val": int(len(val_df)),
        "label_counts_train": train_df["label_num"].value_counts(dropna=False).to_dict(),
        "label_counts_val": val_df["label_num"].value_counts(dropna=False).to_dict(),

        "experiment_tag": EXPERIMENT_TAG,
        "run_stamp": RUN_STAMP,

        "dropout_p": float(DROPOUT_P),
        "lr": float(LR),
        "effective_batch_size": int(PER_DEVICE_BS * GRAD_ACCUM),
        "per_device_batch_size": int(PER_DEVICE_BS),
        "grad_accum_steps": int(GRAD_ACCUM),

        "backbone_ckpt": BACKBONE_CKPT,

        "val_opt_threshold_method": "Youden J (maximize TPR - FPR on VAL ROC curve)",
        "val_opt_threshold": float(best_thr_youden),
        "val_opt_details": best_thr_youden_details,

        "thr_metrics_val_thr0p5": best_val_metrics_thr05,
        "thr_metrics_val_thr_opt": best_val_metrics_thr_opt,

        "artifacts": {
            "roc_curve_png": str(roc_png),
            "confusion_thr0p5_png": str(cm_png_05),
            "confusion_thr_opt_png": str(cm_png_opt),
            "best_heads_pt": str(best_heads_path),
        },
    }

    metrics_path = run_dir / "metrics.json"
    with open(metrics_path, "w", encoding="utf-8") as f:
        json.dump(metrics, f, indent=2)

    # Progress prints.
    print(f"[seed={seed}] VAL-opt threshold (Youden J): {float(best_thr_youden):.6f}")
    print(f"[seed={seed}] WROTE:")
    print(" ", str(metrics_path))
    print(" ", str(roc_png))
    print(" ", str(cm_png_05))
    print(" ", str(cm_png_opt))
    print(" ", str(best_heads_path))

    return {
        "seed": int(seed),
        "best_val_auroc": float(best_auc),
        "best_epoch": int(best_epoch),
        "val_opt_thr": float(best_thr_youden),
        "run_dir": str(run_dir),
        "seed_metrics": metrics,
    }

# -------------------------
# 13) Run all seeds and write the experiment summary
# -------------------------
# Outputs:
# - Per-seed runs under the experiment folder
# - summary_trainval.json for the experiment
# - history_index.jsonl appended at the trainval_runs root
results = []
for seed in SEEDS:
    results.append(run_trainval_once(seed))

aucs = [r["best_val_auroc"] for r in results]
thr_vals = [r["val_opt_thr"] for r in results]

# 95% CI across seeds (t distribution). t_crit is hardcoded for df=2,
# so the interval is only valid when exactly 3 seeds are run.
t_crit = 4.302652729911275  # df=2, 95% CI
n = len(aucs)
mean_auc = float(np.mean(aucs))
std_auc = float(np.std(aucs, ddof=1)) if n > 1 else 0.0
half_width = float(t_crit * (std_auc / math.sqrt(n))) if n > 1 else 0.0
ci95 = [float(mean_auc - half_width), float(mean_auc + half_width)]

# Aggregate threshold across seeds (mean ± SD).
thr_mean, thr_sd = mean_sd(thr_vals)

# Print summary stats: per-seed AUROC, mean AUROC with 95% CI, and per-seed thresholds.
print("\nAUROC by seed:")
for r in results:
    print(f"  seed {r['seed']}: {r['best_val_auroc']:.6f}")
print(f"\nMean AUROC: {mean_auc:.6f}")
print(f"95% CI (t, n=3): [{ci95[0]:.6f}, {ci95[1]:.6f}]")

print("\nVAL-opt thresholds (Youden J) by seed:")
for r in results:
    print(f"  seed {r['seed']}: {r['val_opt_thr']:.6f}")
print(f"  mean ± SD: {thr_mean:.6f} ± {thr_sd:.6f}")

# Canonical threshold record saved only in the experiment summary.
val_optimal_threshold_obj = {
    "method": "Youden J (maximize TPR - FPR on VAL ROC curve)",
    "by_seed": {str(r["seed"]): float(r["val_opt_thr"]) for r in results},
    "mean_sd": {"mean": float(thr_mean), "sd": float(thr_sd)},
}

# Experiment-level summary used later by test-only cells.
exp_summary = {
    "dataset": dataset_id,
    "dx_out_root": DX_OUT_ROOT,

    "train_manifest_used": MANIFEST_TRAIN_ENH2,
    "val_manifest_used": MANIFEST_ALL,

    "init_heads": {
        "mode": "baseline_best_heads",
        "baseline_exp_used": str(baseline_exp),
        "baseline_summary_path": str(baseline_summary_path),
        "baseline_best_heads_by_seed": {
            str(s): str(baseline_exp / f"run_{train_dataset_id}_seed{s}" / "best_heads.pt") for s in SEEDS
        },
    },

    "experiment_tag": EXPERIMENT_TAG,
    "run_stamp": RUN_STAMP,
    "exp_root": str(EXP_ROOT),
    "run_dirs": [r["run_dir"] for r in results],
    "seeds": SEEDS,

    "aurocs": [float(x) for x in aucs],
    "mean_auroc": float(mean_auc),
    "t_crit_df2_95": float(t_crit),
    "half_width_95_ci": float(half_width),
    "ci95": ci95,

    "n_train": int(len(train_df)),
    "n_val": int(len(val_df)),
    "label_counts_train": train_df["label_num"].value_counts(dropna=False).to_dict(),
    "label_counts_val": val_df["label_num"].value_counts(dropna=False).to_dict(),

    "effective_batch_size": int(PER_DEVICE_BS * GRAD_ACCUM),
    "per_device_batch_size": int(PER_DEVICE_BS),
    "grad_accum_steps": int(GRAD_ACCUM),

    "backbone_ckpt": BACKBONE_CKPT,
    "dropout_p": float(DROPOUT_P),
    "lr": float(LR),

    "val_optimal_threshold": val_optimal_threshold_obj,
    "per_seed_metrics": [r["seed_metrics"] for r in results],
}

# Write experiment summary and append global history log.
summary_path = EXP_ROOT / "summary_trainval.json"
with open(summary_path, "w", encoding="utf-8") as f:
    json.dump(exp_summary, f, indent=2)

history_path = TRAINVAL_ROOT / "history_index.jsonl"
TRAINVAL_ROOT.mkdir(parents=True, exist_ok=True)
with open(history_path, "a", encoding="utf-8") as f:
    f.write(json.dumps(exp_summary) + "\n")

print("\nWROTE per-experiment summary:", str(summary_path))
print("APPENDED global history index:", str(history_path))
print("\nOpen this folder to access artifacts:", str(EXP_ROOT))

# -------------------------
# 14) Unassign runtime (stop L4)
# -------------------------
# Frees the GPU at the end so the session does not keep running.
print("\nAll done. Unassigning the runtime to stop the L4 instance...")
try:
    from google.colab import runtime  # type: ignore
    print("Calling runtime.unassign() now.")
    runtime.unassign()
except Exception as e:
    print("Could not unassign runtime automatically. Error:", repr(e))
    print("Manual stop: Runtime -> Disconnect and delete runtime.")

The following cell evaluates the **enhanced D7 model trained on the trainEnh2 dataset** on the **D2 test split only**, using the original D2 `manifest_all.csv` as the sole source of test clips. It selects the most recent D7 train-and-validation experiment whose folder name includes **“trainEnh2”** (case-insensitive) and that contains every required file: the three saved model heads for seeds 1337, 2024, and 7777, and the matching training summary file. As a safety check, the cell confirms that the test manifest belongs to **dataset D2** and rechecks that all model files exist after the experiment folder is selected.

A single global decision threshold, the mean of the per-seed validation-optimal (Youden J) thresholds, is loaded from the train-and-validation summary and applied to all three seeds during testing. This threshold is not tuned on D2 data. If the threshold is missing or cannot be read, the cell falls back to **0.5** and records that choice in the outputs. Each D2 test audio file is checked to ensure that it exists and is sampled at **16 kHz**. Clips are grouped into two task types: **vowel** for sustained vowel tasks, and **other** for all remaining speech tasks. For vowel clips, an attention mask is built to ignore the padded silence at the end of the signal, so the model does not score non-speech regions. Sex values are standardized for later analysis by mapping the exact strings “male” and “female” to **M** and **F**; all other values are treated as unknown.
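
The threshold lookup and its fallback can be sketched in isolation as follows. This is a minimal illustration rather than the cell's actual code: the helper name `load_global_threshold` and the returned source labels are hypothetical, and the range check on the value is an added assumption. The key path `val_optimal_threshold -> mean_sd -> mean` is the one written by the train-and-validation summary.

```python
import json
import math

FALLBACK_THR = 0.5  # default used when no usable threshold is stored


def load_global_threshold(summary_path):
    """Return (threshold, source) read from a summary_trainval.json file.

    Hypothetical helper: falls back to 0.5 when the file, the key path,
    or a finite value in [0, 1] is missing (the range check is an assumption).
    """
    try:
        with open(summary_path, "r", encoding="utf-8") as f:
            summary = json.load(f)
        thr = float(summary["val_optimal_threshold"]["mean_sd"]["mean"])
    except (OSError, ValueError, KeyError, TypeError):
        # Missing file, malformed JSON, missing key, or non-numeric value.
        return FALLBACK_THR, "fallback_0.5"
    if not math.isfinite(thr) or not (0.0 <= thr <= 1.0):
        return FALLBACK_THR, "fallback_0.5"
    return thr, "trainval_mean_val_opt"
```

Returning the source label next to the value makes it straightforward to record in the outputs whether the stored threshold or the fallback was used.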

For each random seed, the cell rebuilds the model, which consists of a frozen speech feature extractor and two task-specific classification heads, and loads the trained weights for that seed. It then runs inference on the full D2 test set. Per-clip results are written to a `predictions.csv` file that includes the clip path, true label, predicted probability, sex, speaker ID, task group, seed, run tag, and the decision threshold used. Performance is reported using **AUROC** together with standard threshold-based metrics (accuracy, precision, recall, specificity, F1 score, Matthews correlation coefficient, and a two-sided Fisher exact test p-value), along with a fairness measure defined as the signed difference in false negative rates, FNR(F) − FNR(M), between female and male speakers at the chosen threshold. For each seed, the cell also saves an ROC curve, an overall confusion matrix, and separate confusion matrices for male and female groups when enough data are available.
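
The sign convention of the fairness measure is easiest to see on a small worked example. The sketch below uses plain Python and made-up toy values; the function names `fnr` and `delta_fnr` are hypothetical, and undefined rates are returned as `None` here for simplicity, whereas the cell itself records them as NaN.

```python
def fnr(y_true, y_pred):
    """False negative rate FN / (FN + TP); None when there are no positives."""
    fn = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 0)
    tp = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 1)
    return fn / (fn + tp) if (fn + tp) > 0 else None


def delta_fnr(y_true, y_prob, sex, thr):
    """Signed gap FNR(F) - FNR(M) at a fixed threshold (None if undefined)."""
    y_pred = [1 if p >= thr else 0 for p in y_prob]
    rates = {}
    for g in ("F", "M"):
        idx = [i for i, s in enumerate(sex) if s == g]
        rates[g] = fnr([y_true[i] for i in idx], [y_pred[i] for i in idx])
    if rates["F"] is None or rates["M"] is None:
        return None
    return rates["F"] - rates["M"]


# Toy data (made up for illustration): four PD clips per sex.
y_true = [1, 1, 1, 1, 1, 1, 1, 1]
y_prob = [0.9, 0.2, 0.3, 0.8, 0.9, 0.7, 0.6, 0.1]
sex = ["F", "F", "F", "F", "M", "M", "M", "M"]

# F misses 2 of 4 (FNR 0.5), M misses 1 of 4 (FNR 0.25).
print(delta_fnr(y_true, y_prob, sex, thr=0.5))  # 0.25
```

A positive value means that PD clips from female speakers are missed more often than those from male speakers at the chosen threshold; a negative value means the opposite.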

After all three seeds are evaluated, the cell produces an overall summary. This includes the **mean AUROC with a 95% confidence interval** across seeds (t distribution, n = 3), as well as the **mean and standard deviation** of all threshold-based metrics and fairness values. Outputs are stored in a timestamped run folder under `monolingual_test_runs`, together with structured summary files that make it easy to track and compare results with earlier runs. Pointer files are updated to mark the latest run and to maintain a running history tied to the training tag.
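
The across-seed interval is mean ± t · SD / √n, with the sample standard deviation and the two-sided 95% critical value of Student's t at n − 1 degrees of freedom. A minimal standard-library sketch is shown below; the helper name `mean_ci95` and the small lookup table are illustrative assumptions, and the AUROC values are made up.

```python
import math
import statistics

# Two-sided 95% critical values of Student's t, indexed by degrees of freedom.
# Only a few entries are listed; extend the table if more seeds are used.
T_CRIT_95 = {1: 12.706204736, 2: 4.302652730, 3: 3.182446305, 4: 2.776445105}


def mean_ci95(values):
    """Return (mean, low, high) for a t-based 95% CI across runs."""
    n = len(values)
    mean = statistics.fmean(values)
    if n < 2:
        return mean, mean, mean  # a single run has no spread to estimate
    sd = statistics.stdev(values)  # sample SD (n - 1 in the denominator)
    half = T_CRIT_95[n - 1] * sd / math.sqrt(n)
    return mean, mean - half, mean + half


# Made-up AUROC values for three seeds (df = 2).
mean_auc, low, high = mean_ci95([0.80, 0.82, 0.84])
print(f"mean={mean_auc:.4f}  95% CI=[{low:.4f}, {high:.4f}]")
```

With only three seeds the critical value is large (about 4.30), so the interval is wide even when the per-seed values are close together.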

At the end, the cell writes a small set of configuration and log files that match the dataset and evaluation naming scheme, backing up any existing versions first. It then releases the Colab runtime resources, including the GPU, to signal that the evaluation is complete.

In [None]:
# D7 trainEnh2 heads: test on D2 test split (3 seeds)
# Inputs: D2 manifest (test split only), most recent D7 trainEnh2 trainval run (best_heads.pt + summary_trainval.json)
# Outputs: per-seed predictions and metrics + plots, plus run-level summaries and pointers

# =========================
# TEST ONLY (CRASH-PROOF, WITH PROGRESS + STORED METRICS) — D7 ENHANCED → D2 TEST
# - Evaluates the D7 ENHANCED trained heads (trained on train_enh2) on the D2 TEST split only
# - Uses ONLY D2: <D2_OUT_ROOT>/manifests/manifest_all.csv  (TEST split)
# - Loads finished heads from MOST RECENT D7 *ENHANCED* trainval experiment under:
#     <D7_OUT_ROOT>/trainval_runs/exp_*/run_D7_seed{seed}/best_heads.pt
#   Selection rule: exp folder name must contain substring "trainEnh2" (case-insensitive)
#   and must contain all three seeds + summary_trainval.json.
# - Uses the SINGLE MEAN VAL-optimal threshold stored by that D7 trainval in:
#     summary_trainval.json -> val_optimal_threshold.mean_sd.mean
#   (No VAL threshold recomputation in this cell)
# - Evaluates 3 seeds separately (1337, 2024, 7777)
# - Reports:
#     * mean Test AUROC ± 95% CI (t, n=3)
#     * A single threshold used for ALL seeds (mean val-opt threshold) + note if fallback to 0.5
#     * Threshold metrics on D2 TEST @ that single threshold as mean ± SD
#     * FAIRNESS (H3) on D2 TEST @ that single threshold as mean ± SD
#     * Confusion charts split by sex (M/F) on D2 TEST @ that single threshold
# - Writes all artifacts under:
#     <D7_OUT_ROOT>/monolingual_test_runs/run_<FULL_TRAINVAL_EXP_TAG>__<RUN_STAMP>/...
#   plus:
#     * monolingual_test_runs/last_run_pointer.json (intentional overwrite)
#     * monolingual_test_runs/summary_latest.json (intentional overwrite)
#     * monolingual_test_runs/history_index.jsonl (append-only)
#     * monolingual_test_runs/run_<TAG>/tag_run_pointer.json (never overwritten across tags)
#
# D2-SPECIFIC NOTE (the manifest):
# - sex is encoded as the exact strings "male" / "female" (case-sensitive)
#   This code maps: "male" -> M, "female" -> F, and anything else -> UNK
#
# GUARDS:
# A) Hard-assert D2 dataset_id == "D2" after inferring it from the D2 manifest
# B) Re-assert all best_heads.pt exist immediately after chosen_exp is selected
# =========================

import os, json, math, random, time, warnings
from pathlib import Path
from datetime import datetime

import numpy as np
import pandas as pd
import soundfile as sf

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

from transformers import Wav2Vec2Model

from sklearn.metrics import roc_auc_score, roc_curve, confusion_matrix, matthews_corrcoef
import matplotlib.pyplot as plt
from tqdm.auto import tqdm

# -------------------------
# 0) Environment guards
#    Prevents importing the wrong library due to a local file/folder name clash.
# -------------------------
if os.path.exists("/content/torch.py") or os.path.exists("/content/torch/__init__.py"):
    raise RuntimeError("Found local /content/torch.py or /content/torch/ that shadows PyTorch. Rename/remove it and restart runtime.")
if os.path.exists("/content/transformers.py") or os.path.exists("/content/transformers/__init__.py"):
    raise RuntimeError("Found local /content/transformers.py or /content/transformers/ that shadows Hugging Face Transformers. Rename/remove it and restart runtime.")

# -------------------------
# 1) Drive mount (Colab-safe)
#    Makes files available when running in Colab; harmless elsewhere.
# -------------------------
try:
    from google.colab import drive  # type: ignore
    if not os.path.isdir("/content/drive/MyDrive"):
        drive.mount("/content/drive")
except Exception:
    pass

# -------------------------
# 2) Input roots and shared run settings
#    - Reads D2 manifest from D2_OUT_ROOT
#    - Reads D7 trainval artifacts and writes test outputs under D7_OUT_ROOT
# -------------------------
D7_OUT_ROOT_FALLBACK = "/content/drive/MyDrive/AI_PD_Project/Datasets/D7-Multilingual (D1_D4_D5v2_D6)/preprocessed_v1"
D2_OUT_ROOT_FALLBACK = "/content/drive/MyDrive/AI_PD_Project/Datasets/D2-Slovak (EWA-DB)/EWA-DB/preprocessed_v1"

D7_OUT_ROOT = str(globals().get("D7_OUT_ROOT", D7_OUT_ROOT_FALLBACK))
D2_OUT_ROOT = str(globals().get("D2_OUT_ROOT", D2_OUT_ROOT_FALLBACK))

D2_MANIFEST_ALL = f"{D2_OUT_ROOT}/manifests/manifest_all.csv"

# Keep DX_OUT_ROOT aligned to the run root (D7) for consistent naming across cells.
DX_OUT_ROOT = D7_OUT_ROOT
globals()["DX_OUT_ROOT"] = DX_OUT_ROOT
globals()["D7_OUT_ROOT"] = D7_OUT_ROOT
globals()["D2_OUT_ROOT"] = D2_OUT_ROOT

# Timestamp used to create a unique output folder name for this run.
RUN_STAMP = time.strftime("%Y%m%d_%H%M%S")

# Backup helper for files that are intentionally overwritten in the builder-aligned outputs.
def _backup_if_exists(p: Path):
    if p.exists():
        bak = p.with_suffix(p.suffix + f".bak_{RUN_STAMP}")
        try:
            p.rename(bak)
            print(f"BACKUP: {str(p)} -> {str(bak)}")
        except Exception as e:
            raise RuntimeError(f"Could not backup existing file before overwrite: {str(p)}. Error: {repr(e)}")

# Stable tag name for folder names and pointer files.
def _sanitize_tag(s: str) -> str:
    s = str(s).strip()
    out = []
    for ch in s:
        if ch.isalnum() or ch in ["-", "_"]:
            out.append(ch)
        else:
            out.append("_")
    out = "".join(out).strip("_")
    return out if out else "tag"

# -------------------------
# 3) Fixed evaluation settings
#    Mirrors the training setup where needed (seeds, backbone, audio sampling rate).
# -------------------------
SEEDS          = [1337, 2024, 7777]
BACKBONE_CKPT  = "facebook/wav2vec2-base"
SR_EXPECTED    = 16000
TINY_THRESH    = 1e-4

# Batch sizing: PER_DEVICE_BS controls runtime memory; GRAD_ACCUM kept for consistent prints.
EFFECTIVE_BS   = 64
PER_DEVICE_BS  = 16
GRAD_ACCUM     = max(1, EFFECTIVE_BS // PER_DEVICE_BS)

DROPOUT_P      = 0.2

NUM_WORKERS    = 0
PIN_MEMORY     = False

USE_AMP        = True

# Trainval folder selection filter for trainEnh2.
REQUIRED_EXP_SUBSTRING = "trainEnh2"  # case-insensitive

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
try:
    torch.set_float32_matmul_precision("high")
except Exception:
    pass

# Reduce noisy warnings not relevant to test-only inference.
warnings.filterwarnings("ignore", message="None of the inputs have requires_grad=True. Gradients will be None")
warnings.filterwarnings("ignore", message="Passing `gradient_checkpointing` to a config initialization is deprecated")
warnings.filterwarnings("ignore", category=UserWarning, module="huggingface_hub")

print("D7_OUT_ROOT:", D7_OUT_ROOT)
print("D2_OUT_ROOT:", D2_OUT_ROOT)
print("D2_MANIFEST_ALL:", D2_MANIFEST_ALL)
print("DEVICE:", DEVICE)
if DEVICE.type == "cuda":
    print("GPU:", torch.cuda.get_device_name(0))
print("PER_DEVICE_BS:", PER_DEVICE_BS, "| EFFECTIVE_BS:", PER_DEVICE_BS * GRAD_ACCUM)
print("NUM_WORKERS:", NUM_WORKERS, "| PIN_MEMORY:", PIN_MEMORY)
print("USE_AMP:", bool(USE_AMP and DEVICE.type == "cuda"))
print("Enhanced exp required substring (case-insensitive):", REQUIRED_EXP_SUBSTRING)

# -------------------------
# 4) Load D2 manifest and build the D2 TEST table
#    - Keeps only split == "test"
#    - Confirms the manifest belongs to dataset "D2" (Guard A)
# -------------------------
if not os.path.exists(D2_MANIFEST_ALL):
    raise FileNotFoundError(f"Missing D2 manifest_all.csv: {D2_MANIFEST_ALL}")

m_all = pd.read_csv(D2_MANIFEST_ALL)

# Basic required columns for testing and group analysis.
req_cols = {"split", "clip_path", "label_num", "task", "sex", "age"}
missing = [c for c in sorted(req_cols) if c not in m_all.columns]
if missing:
    raise ValueError(f"D2 manifest missing required columns: {missing}. Found: {list(m_all.columns)}")

# Infer dataset id from the most common value in the manifest (then filter to it).
if "dataset" in m_all.columns and m_all["dataset"].notna().any():
    d2_dataset_id = str(m_all["dataset"].astype(str).value_counts(dropna=True).idxmax())
    m_all = m_all[m_all["dataset"].astype(str) == d2_dataset_id].copy()
else:
    d2_dataset_id = "DX"

# --------- GUARD A ----------
# Ensures the evaluation is truly on D2 and not an accidental path mix-up.
if d2_dataset_id != "D2":
    raise RuntimeError(
        f"Expected D2 dataset_id=='D2' but got {d2_dataset_id!r}. "
        "This usually means D2_OUT_ROOT is wrong or the manifest is not D2. "
        f"D2_OUT_ROOT={D2_OUT_ROOT}"
    )

# Keep a small, consistent set of columns (fill with NaN if missing).
keep_cols = ["clip_path", "label_num", "task", "speaker_id", "sex", "age", "duration_sec", "split"]
for c in keep_cols:
    if c not in m_all.columns:
        m_all[c] = np.nan
m_all = m_all[keep_cols].copy()

# IMPORTANT: TEST split only
test_df = m_all[m_all["split"].astype(str) == "test"].reset_index(drop=True)

print(f"\nD2 dataset inferred: {d2_dataset_id}")
print(f"D2 TEST rows: {len(test_df)}")
print("D2 TEST label counts:", test_df["label_num"].value_counts(dropna=False).to_dict())
print("D2 TEST sex counts (raw):", test_df["sex"].value_counts(dropna=False).to_dict())

if len(test_df) == 0:
    raise RuntimeError("After filtering to split=='test', D2 manifest has 0 rows.")

# -------------------------
# 5) Fail-fast: confirm D2 TEST clips exist on disk
#    Stops early with a short list of missing paths.
# -------------------------
def _fail_fast_missing_paths(df: pd.DataFrame, name: str):
    missing_paths = []
    for p in tqdm(df["clip_path"].astype(str).tolist(), desc=f"Check {name} clip_path exists", dynamic_ncols=True):
        if not os.path.exists(p):
            missing_paths.append(p)
            if len(missing_paths) >= 10:
                break
    if missing_paths:
        raise FileNotFoundError(f"{name}: missing clip_path(s). Examples: {missing_paths}")

_fail_fast_missing_paths(test_df, "D2 TEST")

# -------------------------
# 6) Add task grouping used by the two-head model
#    Exact rule: task == "vowl" -> vowel, everything else -> other
# -------------------------
def _task_group(task_val) -> str:
    return "vowel" if str(task_val) == "vowl" else "other"

test_df["task_group"] = test_df["task"].apply(_task_group)

# -------------------------
# 6.5) Normalize sex for fairness and subgroup plots
#    D2 encoding is case-sensitive: "male"/"female", otherwise UNK
# -------------------------
def normalize_sex_d2_case_sensitive(val) -> str:
    if pd.isna(val):
        return "UNK"
    if val == "male":
        return "M"
    if val == "female":
        return "F"
    return "UNK"

test_df["sex_norm"] = test_df["sex"].apply(normalize_sex_d2_case_sensitive)
print("D2 TEST sex counts (normalized):", test_df["sex_norm"].value_counts(dropna=False).to_dict())
if (test_df["sex_norm"] == "UNK").any():
    print("NOTE: Some D2 'sex' values were not exactly 'male'/'female' and were mapped to 'UNK'.")

# -------------------------
# 7) Dataset and batching
#    - Loads audio from clip_path
#    - Builds attention_mask to ignore trailing padding for vowel clips
#    - Returns extra fields for predictions.csv
# -------------------------
class AudioManifestDataset(Dataset):
    def __init__(self, df: pd.DataFrame):
        self.df = df.reset_index(drop=True)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx: int):
        row = self.df.iloc[idx]
        clip_path = str(row["clip_path"])
        label = int(row["label_num"])
        task_group = str(row["task_group"])
        sex_norm = str(row["sex_norm"])
        speaker_id = row["speaker_id"] if "speaker_id" in row.index else np.nan

        if not os.path.exists(clip_path):
            raise FileNotFoundError(f"Missing clip_path: {clip_path}")

        y, sr = sf.read(clip_path, always_2d=False)
        if isinstance(y, np.ndarray) and y.ndim > 1:
            y = np.mean(y, axis=1)
        y = np.asarray(y, dtype=np.float32)

        # Strict sample rate check to match training assumptions.
        if int(sr) != SR_EXPECTED:
            raise ValueError(f"Sample rate mismatch (expected {SR_EXPECTED}): {clip_path} has sr={sr}")

        # Attention mask logic:
        # - vowel: mask out the trailing near-zero tail so padded silence is excluded from pooling
        # - other: keep all samples
        attn = np.ones((len(y),), dtype=np.int64)
        if task_group == "vowel":
            # Indices of samples whose magnitude exceeds TINY_THRESH; the last one ends the signal.
            nz = np.flatnonzero(np.abs(y) > TINY_THRESH)
            if nz.size > 0:
                attn[int(nz[-1]) + 1:] = 0

        return {
            "input_values": torch.from_numpy(y),
            "attention_mask": torch.from_numpy(attn),
            "labels": torch.tensor(label, dtype=torch.long),
            "task_group": task_group,
            "sex_norm": sex_norm,
            "clip_path": clip_path,
            "speaker_id": speaker_id,
        }

# Pads variable-length audio in a batch and carries over metadata lists.
def collate_fn(batch):
    max_len = int(max(b["input_values"].numel() for b in batch))
    input_vals, attn_masks, labels = [], [], []
    task_groups, sex_norms, clip_paths, speaker_ids = [], [], [], []
    for b in batch:
        x = b["input_values"]
        a = b["attention_mask"]
        pad = max_len - x.numel()
        if pad > 0:
            x = torch.cat([x, torch.zeros(pad, dtype=x.dtype)], dim=0)
            a = torch.cat([a, torch.zeros(pad, dtype=a.dtype)], dim=0)
        input_vals.append(x)
        attn_masks.append(a)
        labels.append(b["labels"])
        task_groups.append(b["task_group"])
        sex_norms.append(b["sex_norm"])
        clip_paths.append(b["clip_path"])
        speaker_ids.append(b["speaker_id"])
    return {
        "input_values": torch.stack(input_vals, dim=0),
        "attention_mask": torch.stack(attn_masks, dim=0),
        "labels": torch.stack(labels, dim=0),
        "task_group": task_groups,
        "sex_norm": sex_norms,
        "clip_path": clip_paths,
        "speaker_id": speaker_ids,
    }

# -------------------------
# 8) Model definition (backbone frozen, heads loaded from trainval)
#    - Two heads: one for vowel, one for other
#    - Chooses head per sample using task_group
# -------------------------
class Wav2Vec2TwoHeadClassifier(nn.Module):
    def __init__(self, ckpt: str, dropout_p: float):
        super().__init__()
        self.backbone = Wav2Vec2Model.from_pretrained(ckpt, use_safetensors=True)
        for p in self.backbone.parameters():
            p.requires_grad = False

        H = int(self.backbone.config.hidden_size)
        self.pre_vowel = nn.Sequential(nn.LayerNorm(H), nn.Dropout(float(dropout_p)))
        self.pre_other = nn.Sequential(nn.LayerNorm(H), nn.Dropout(float(dropout_p)))
        self.head_vowel = nn.Linear(H, 2)
        self.head_other = nn.Linear(H, 2)

    # Mean pooling with masking aligned to wav2vec2 feature frames.
    def masked_mean_pool(self, last_hidden, attn_mask_samples):
        feat_mask = self.backbone._get_feature_vector_attention_mask(last_hidden.shape[1], attn_mask_samples)
        feat_mask = feat_mask.to(last_hidden.device).unsqueeze(-1).type_as(last_hidden)
        masked = last_hidden * feat_mask
        denom = feat_mask.sum(dim=1).clamp(min=1.0)
        return masked.sum(dim=1) / denom

    # Forces head computation in fp32 (stable with AMP).
    def _heads_fp32(self, x_any: torch.Tensor, head: nn.Module) -> torch.Tensor:
        x = x_any.float()
        if x.is_cuda:
            with torch.amp.autocast(device_type="cuda", enabled=False):
                return head(x)
        return head(x)

    # Produces logits for class 0/1 using the head matching task_group.
    def forward_logits(self, input_values, attention_mask, task_group):
        with torch.no_grad():
            out = self.backbone(input_values=input_values, attention_mask=attention_mask)
            last_hidden = out.last_hidden_state

        pooled = self.masked_mean_pool(last_hidden, attention_mask)

        z_v = self.pre_vowel(pooled.float())
        z_o = self.pre_other(pooled.float())

        logits_v = self._heads_fp32(z_v, self.head_vowel)
        logits_o = self._heads_fp32(z_o, self.head_other)

        is_vowel = torch.tensor([tg == "vowel" for tg in task_group], device=pooled.device, dtype=torch.bool)
        logits = logits_o.clone()
        if is_vowel.any():
            logits[is_vowel] = logits_v[is_vowel]
        return logits

# -------------------------
# 9) Metrics and plotting helpers
#    - AUROC uses probabilities for class 1
#    - Threshold metrics use a fixed threshold (global mean from trainval)
# -------------------------
def compute_auc(y_true, y_prob):
    y_true = np.asarray(y_true, dtype=np.int64)
    y_prob = np.asarray(y_prob, dtype=np.float64)
    if len(np.unique(y_true)) < 2:
        return float("nan")
    return float(roc_auc_score(y_true, y_prob))

def compute_threshold_metrics(y_true, y_prob, thr=0.5):
    y_true = np.asarray(y_true, dtype=np.int64)
    y_prob = np.asarray(y_prob, dtype=np.float64)
    y_pred = (y_prob >= float(thr)).astype(np.int64)

    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
    TN, FP, FN, TP = int(cm[0, 0]), int(cm[0, 1]), int(cm[1, 0]), int(cm[1, 1])

    eps = 1e-12
    acc = (TP + TN) / max(1, (TP + TN + FP + FN))
    prec = TP / (TP + FP + eps)
    rec = TP / (TP + FN + eps)     # sensitivity
    f1 = 2 * prec * rec / (prec + rec + eps)
    spec = TN / (TN + FP + eps)

    try:
        mcc = float(matthews_corrcoef(y_true, y_pred)) if len(np.unique(y_pred)) > 1 else float("nan")
    except Exception:
        mcc = float("nan")

    pval = float("nan")
    try:
        from scipy.stats import fisher_exact  # type: ignore
        _, pval = fisher_exact([[TN, FP], [FN, TP]], alternative="two-sided")
        pval = float(pval)
    except Exception:
        pval = float("nan")

    return {
        "threshold": float(thr),
        "tn": TN, "fp": FP, "fn": FN, "tp": TP,
        "accuracy": float(acc),
        "precision": float(prec),
        "recall": float(rec),
        "f1": float(f1),
        "sensitivity": float(rec),
        "specificity": float(spec),
        "mcc": float(mcc),
        "p_value_fisher_two_sided": float(pval),
    }

# Saves a basic ROC curve image for quick inspection.
def save_roc_curve_png(y_true, y_prob, out_png, title_suffix="Test"):
    fpr, tpr, _ = roc_curve(y_true, y_prob)
    plt.figure()
    plt.plot(fpr, tpr)
    plt.plot([0, 1], [0, 1], linestyle="--")
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    plt.title(f"ROC Curve ({title_suffix})")
    plt.tight_layout()
    plt.savefig(out_png, dpi=150)
    plt.close()

# Saves a basic confusion matrix image at a chosen threshold.
def save_confusion_png(y_true, y_prob, out_png, thr=0.5, title_suffix="Test"):
    y_true = np.asarray(y_true, dtype=np.int64)
    y_prob = np.asarray(y_prob, dtype=np.float64)
    y_pred = (y_prob >= float(thr)).astype(int)
    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])

    plt.figure()
    plt.imshow(cm)
    plt.xticks([0, 1], ["Healthy(0)", "PD(1)"])
    plt.yticks([0, 1], ["Healthy(0)", "PD(1)"])
    plt.xlabel("Predicted")
    plt.ylabel("True")
    plt.title(f"Confusion Matrix ({title_suffix}, thr={thr:.4f})")
    for (i, j), v in np.ndenumerate(cm):
        plt.text(j, i, str(v), ha="center", va="center")
    plt.tight_layout()
    plt.savefig(out_png, dpi=150)
    plt.close()

# -------------------------
# 9.5) Fairness helper (H3)
#    Computes FNR for M and F, then ΔFNR = FNR(F) - FNR(M)
# -------------------------
def compute_fnr_by_group_signed(y_true, y_prob, groups, thr=0.5):
    y_true = np.asarray(y_true, dtype=np.int64)
    y_prob = np.asarray(y_prob, dtype=np.float64)
    groups = np.asarray(list(groups), dtype=object)
    y_pred = (y_prob >= float(thr)).astype(np.int64)

    out = {}
    for g in sorted(set(groups.tolist())):
        mask_g = (groups == g)
        if int(mask_g.sum()) == 0:
            continue

        pos_mask = mask_g & (y_true == 1)
        n_pos = int(pos_mask.sum())
        if n_pos == 0:
            out[g] = {"n_total": int(mask_g.sum()), "n_pos": 0, "tp": 0, "fn": 0, "fnr": float("nan")}
            continue

        tp = int(((y_pred == 1) & pos_mask).sum())
        fn = int(((y_pred == 0) & pos_mask).sum())
        fnr = float(fn / max(1, (fn + tp)))
        out[g] = {"n_total": int(mask_g.sum()), "n_pos": int(n_pos), "tp": int(tp), "fn": int(fn), "fnr": float(fnr)}

    fnr_m = out.get("M", {}).get("fnr", float("nan"))
    fnr_f = out.get("F", {}).get("fnr", float("nan"))
    if (not np.isnan(fnr_m)) and (not np.isnan(fnr_f)):
        delta_signed = float(fnr_f - fnr_m)
        delta_abs = float(abs(delta_signed))
    else:
        delta_signed = float("nan")
        delta_abs = float("nan")

    return out, delta_signed, delta_abs

# Small helpers to report confusion matrices by subgroup.
def compute_confusion_counts(y_true, y_prob, thr=0.5):
    y_true = np.asarray(y_true, dtype=np.int64)
    y_prob = np.asarray(y_prob, dtype=np.float64)
    y_pred = (y_prob >= float(thr)).astype(np.int64)
    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
    return {"TN": int(cm[0, 0]), "FP": int(cm[0, 1]), "FN": int(cm[1, 0]), "TP": int(cm[1, 1])}

def compute_confusion_by_group(y_true, y_prob, groups, thr=0.5):
    y_true = np.asarray(y_true, dtype=np.int64)
    y_prob = np.asarray(y_prob, dtype=np.float64)
    groups = np.asarray(list(groups), dtype=object)
    out = {}
    for g in sorted(set(groups.tolist())):
        mask = (groups == g)
        if int(mask.sum()) == 0:
            continue
        out[g] = {"n": int(mask.sum()), "confusion": compute_confusion_counts(y_true[mask], y_prob[mask], thr=thr)}
    return out

# -------------------------
# 10) Seed control
#    Keeps results repeatable for a given seed.
# -------------------------
def set_all_seeds(seed: int):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# -------------------------
# 11) Find the most recent matching D7 trainval experiment (trainEnh2)
#    Requires: summary_trainval.json + best_heads.pt for all three seeds.
# -------------------------
TRAINVAL_ROOT = Path(D7_OUT_ROOT) / "trainval_runs"
if not TRAINVAL_ROOT.exists():
    raise FileNotFoundError(f"Missing trainval_runs folder under D7_OUT_ROOT: {str(TRAINVAL_ROOT)}")

exp_dirs = sorted([p for p in TRAINVAL_ROOT.glob("exp_*") if p.is_dir()], key=lambda p: p.stat().st_mtime, reverse=True)
if not exp_dirs:
    raise FileNotFoundError(f"No exp_* folders found under: {str(TRAINVAL_ROOT)}")

train_dataset_id = "D7"  # run folder naming from trainval code (run_D7_seedXXXX)

def _is_enhanced_exp_dir(exp_path: Path, required_substring: str) -> bool:
    return (required_substring.lower() in exp_path.name.lower())

def _has_all_seeds_and_summary(exp_path: Path, dataset_id: str, seeds: list) -> bool:
    summary_path = exp_path / "summary_trainval.json"
    if not summary_path.exists():
        return False
    for s in seeds:
        p = exp_path / f"run_{dataset_id}_seed{s}" / "best_heads.pt"
        if not p.exists():
            return False
    return True

chosen_exp = None
for ed in exp_dirs:
    if not _is_enhanced_exp_dir(ed, REQUIRED_EXP_SUBSTRING):
        continue
    if _has_all_seeds_and_summary(ed, train_dataset_id, SEEDS):
        chosen_exp = ed
        break

if chosen_exp is None:
    sample = exp_dirs[0]
    raise FileNotFoundError(
        "Could not find a recent D7 *ENHANCED* trainval experiment folder that:\n"
        f"  (1) contains substring '{REQUIRED_EXP_SUBSTRING}' (case-insensitive) in the exp folder name, and\n"
        "  (2) contains all 3 best_heads.pt files + summary_trainval.json.\n\n"
        f"Most recent exp checked (for reference): {str(sample)}"
    )

# Run tag uses the full experiment folder name (keeps traceability across runs).
FULL_TRAINVAL_EXP_TAG = chosen_exp.name
TAG_SAFE = _sanitize_tag(FULL_TRAINVAL_EXP_TAG)
RUN_PARENT_DIRNAME = f"run_{TAG_SAFE}__{RUN_STAMP}"

# Output structure: one stamped folder per run + one stable folder per tag.
TEST_ROOT = Path(D7_OUT_ROOT) / "monolingual_test_runs"
RUN_ROOT  = TEST_ROOT / RUN_PARENT_DIRNAME
RUN_ROOT.mkdir(parents=True, exist_ok=True)

TAG_ROOT = TEST_ROOT / f"run_{TAG_SAFE}"
TAG_ROOT.mkdir(parents=True, exist_ok=True)

# Builder-aligned config and logs folders (stable for this enhancement tag).
ENH_TAG_SAFE = _sanitize_tag(REQUIRED_EXP_SUBSTRING)
cfg_dir  = Path(D7_OUT_ROOT) / "config" / f"D7_{ENH_TAG_SAFE}_on_D2_Test"
logs_dir = Path(D7_OUT_ROOT) / "logs"   / f"D7_{ENH_TAG_SAFE}_on_D2_Test"
cfg_dir.mkdir(parents=True, exist_ok=True)
logs_dir.mkdir(parents=True, exist_ok=True)

RUN_CONFIG_PATH       = cfg_dir / "run_config.json"
WARNINGS_CSV_PATH     = logs_dir / "preprocess_warnings.csv"
DATASET_SUMMARY_PATH  = logs_dir / "dataset_summary.json"

print("\nUsing D7 ENHANCED Train+Val experiment folder:")
print(" ", str(chosen_exp))
print("FULL_TRAINVAL_EXP_TAG:", FULL_TRAINVAL_EXP_TAG)
print("RUN_ROOT:", str(RUN_ROOT))
print("cfg_dir:", str(cfg_dir))
print("logs_dir:", str(logs_dir))

# --------- GUARD B ----------
# Double-check that the chosen folder still contains the required head files.
for s in SEEDS:
    p = chosen_exp / f"run_{train_dataset_id}_seed{s}" / "best_heads.pt"
    if not p.exists():
        raise RuntimeError(f"Trainval artifact missing after choosing exp. Missing: {str(p)}")

# Load trainval summary for the global threshold.
summary_trainval_path = chosen_exp / "summary_trainval.json"
with open(summary_trainval_path, "r", encoding="utf-8") as f:
    d7_trainval_summary = json.load(f)

# -------------------------
# 11.5) Threshold: use one global mean val-opt threshold for all seeds
#    Falls back to 0.5 if the summary value is missing.
# -------------------------
val_opt_obj = (d7_trainval_summary or {}).get("val_optimal_threshold", {}) or {}
thr_mean_sd = (val_opt_obj.get("mean_sd", {}) or {})

def _get_mean_val_opt_threshold() -> float:
    try:
        return float(thr_mean_sd.get("mean", float("nan")))
    except Exception:
        return float("nan")

THR_MEAN_FROM_TRAINVAL = _get_mean_val_opt_threshold()

if np.isnan(THR_MEAN_FROM_TRAINVAL):
    THR_USED_GLOBAL = 0.5
    THR_GLOBAL_NOTE = (
        "Mean val-opt threshold was missing/NaN in D7 enhanced summary_trainval.json. "
        "Fallback: THR_USED_GLOBAL=0.5 for ALL seeds."
    )
else:
    THR_USED_GLOBAL = float(THR_MEAN_FROM_TRAINVAL)
    THR_GLOBAL_NOTE = None

print("\nVAL-opt threshold selection for TEST (GLOBAL):")
print("  Source: summary_trainval.json -> val_optimal_threshold.mean_sd.mean")
print(f"  THR_USED_GLOBAL: {THR_USED_GLOBAL:.6f}")
if THR_GLOBAL_NOTE is not None:
    print("  NOTE:", THR_GLOBAL_NOTE)

# -------------------------
# 13) Build DataLoader for D2 TEST
# -------------------------
test_ds = AudioManifestDataset(test_df)
test_loader = DataLoader(
    test_ds,
    batch_size=PER_DEVICE_BS,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=PIN_MEMORY,
    collate_fn=collate_fn,
)

# -------------------------
# 14) Warm-up: read a few batches to catch data issues early
# -------------------------
print("\nWarm-up: loading up to 3 D2 TEST batches...")
t0 = time.time()

def _warmup(loader, name):
    nb = len(loader)
    wb = min(3, nb)
    if wb == 0:
        raise RuntimeError(f"{name} DataLoader has 0 batches. Check df length and PER_DEVICE_BS.")
    it = iter(loader)
    for i in range(wb):
        _ = next(it)
        print(f"  loaded warmup {name} batch {i+1}/{wb}")

_warmup(test_loader, "D2 TEST")
print(f"Warm-up done in {time.time()-t0:.2f}s")

# -------------------------
# 15) Load heads from best_heads.pt into the model skeleton
#    Backbone stays frozen; only head weights are loaded.
# -------------------------
def load_heads_into_model(model: Wav2Vec2TwoHeadClassifier, best_heads_path: Path):
    if not best_heads_path.exists():
        raise FileNotFoundError(f"Missing best_heads.pt: {str(best_heads_path)}")
    state = torch.load(str(best_heads_path), map_location="cpu")
    model.pre_vowel.load_state_dict(state["pre_vowel"], strict=True)
    model.pre_other.load_state_dict(state["pre_other"], strict=True)
    model.head_vowel.load_state_dict(state["head_vowel"], strict=True)
    model.head_other.load_state_dict(state["head_other"], strict=True)
    return model

# -------------------------
# 16) Inference helper
#    Returns probabilities plus metadata needed for predictions.csv.
# -------------------------
def _infer_probs_with_meta(loader, model, desc):
    use_amp = bool(USE_AMP and DEVICE.type == "cuda")
    all_probs, all_true, all_sex = [], [], []
    all_clip, all_spk, all_task = [], [], []

    pbar = tqdm(loader, desc=desc, dynamic_ncols=True)
    with torch.inference_mode():
        for batch in pbar:
            input_values = batch["input_values"].to(DEVICE, non_blocking=False)
            attention_mask = batch["attention_mask"].to(DEVICE, non_blocking=False)
            labels = batch["labels"].to(DEVICE, non_blocking=False)
            task_group = batch["task_group"]
            sex_norm = batch["sex_norm"]
            clip_paths = batch["clip_path"]
            speaker_ids = batch["speaker_id"]

            with torch.amp.autocast(device_type="cuda", enabled=use_amp):
                logits = model.forward_logits(input_values, attention_mask, task_group)

            probs = torch.softmax(logits.float(), dim=-1)[:, 1].detach().cpu().numpy().astype(np.float64)
            all_probs.extend(probs.tolist())
            all_true.extend(labels.detach().cpu().numpy().astype(np.int64).tolist())
            all_sex.extend(list(sex_norm))
            all_task.extend(list(task_group))
            all_clip.extend(list(clip_paths))
            all_spk.extend([("" if (x is None or (isinstance(x, float) and np.isnan(x))) else str(x)) for x in speaker_ids])

    return (
        np.asarray(all_true, dtype=np.int64),
        np.asarray(all_probs, dtype=np.float64),
        np.asarray(all_sex, dtype=object),
        np.asarray(all_clip, dtype=object),
        np.asarray(all_spk, dtype=object),
        np.asarray(all_task, dtype=object),
    )

# -------------------------
# 17) Single-seed evaluation
#    Writes: metrics.json, predictions.csv, roc_curve.png, confusion matrix images.
# -------------------------
def run_test_once(seed: int):
    set_all_seeds(seed)

    run_dir = RUN_ROOT / f"run_{train_dataset_id}_on_{d2_dataset_id}test_seed{seed}"
    run_dir.mkdir(parents=True, exist_ok=True)

    best_heads_path = chosen_exp / f"run_{train_dataset_id}_seed{seed}" / "best_heads.pt"

    print(f"\n[seed={seed}] Loading model + heads from:")
    print(" ", str(best_heads_path))

    model = Wav2Vec2TwoHeadClassifier(BACKBONE_CKPT, dropout_p=DROPOUT_P).to(DEVICE)
    model = load_heads_into_model(model, best_heads_path)
    model.eval()

    # One global threshold across all seeds for apples-to-apples comparison.
    thr_used = float(THR_USED_GLOBAL)
    thr_note = THR_GLOBAL_NOTE

    # Inference (also collects fields for predictions.csv)
    yt_true, yt_prob, yt_sex, yt_clip, yt_spk, yt_task = _infer_probs_with_meta(
        test_loader, model, desc=f"[seed={seed}] Test (D2 TEST)"
    )
    test_auc = compute_auc(yt_true, yt_prob)

    # Metrics and subgroup diagnostics at thr_used
    thr_metrics_test = compute_threshold_metrics(yt_true, yt_prob, thr=thr_used)
    fnr_by_sex, delta_f_minus_m, delta_abs = compute_fnr_by_group_signed(yt_true, yt_prob, yt_sex, thr=thr_used)
    confusion_by_sex = compute_confusion_by_group(yt_true, yt_prob, yt_sex, thr=thr_used)

    # Plots (overall)
    roc_png = run_dir / "roc_curve.png"
    cm_png  = run_dir / "confusion_matrix.png"
    save_roc_curve_png(yt_true, yt_prob, str(roc_png), title_suffix=f"D2 TEST (seed={seed})")
    save_confusion_png(yt_true, yt_prob, str(cm_png), thr=thr_used, title_suffix=f"D2 TEST (seed={seed})")

    # Plots (by sex: M/F only)
    cm_m_png = None
    cm_f_png = None
    mask_m = (yt_sex == "M")
    mask_f = (yt_sex == "F")

    if int(mask_m.sum()) > 0:
        cm_m_png = run_dir / "confusion_matrix_M.png"
        save_confusion_png(yt_true[mask_m], yt_prob[mask_m], str(cm_m_png), thr=thr_used, title_suffix=f"D2 TEST SEX=M (seed={seed})")

    if int(mask_f.sum()) > 0:
        cm_f_png = run_dir / "confusion_matrix_F.png"
        save_confusion_png(yt_true[mask_f], yt_prob[mask_f], str(cm_f_png), thr=thr_used, title_suffix=f"D2 TEST SEX=F (seed={seed})")

    # predictions.csv: one row per clip, includes metadata for later analysis.
    pred_df = pd.DataFrame({
        "clip_path": yt_clip.astype(str),
        "y_true": yt_true.astype(int),
        "y_score": yt_prob.astype(float),
        "sex_norm": yt_sex.astype(str),
        "speaker_id": yt_spk.astype(str),
        "task_group": yt_task.astype(str),
        "seed": int(seed),
        "trainval_exp_tag": str(FULL_TRAINVAL_EXP_TAG),
        "run_stamp": str(RUN_STAMP),
        "threshold_used_global": float(thr_used),
    })
    pred_csv_path = run_dir / "predictions.csv"
    pred_df.to_csv(pred_csv_path, index=False)

    # metrics.json: per-seed results and pointers to artifacts.
    metrics = {
        "train_dataset": train_dataset_id,
        "test_dataset": d2_dataset_id,
        "seed": int(seed),

        "n_test": int(len(test_df)),
        "label_counts_test": test_df["label_num"].value_counts(dropna=False).to_dict(),
        "sex_counts_test_norm": test_df["sex_norm"].value_counts(dropna=False).to_dict(),

        "test_auroc": float(test_auc),

        "threshold_source": "D7 enhanced trainval summary_trainval.json -> val_optimal_threshold.mean_sd.mean",
        "trainval_experiment_used": str(chosen_exp),
        "trainval_exp_tag": str(FULL_TRAINVAL_EXP_TAG),
        "trainval_summary_path": str(summary_trainval_path),

        "test_threshold_used_global": float(thr_used),
        "test_threshold_note_global": thr_note,

        "threshold_metrics_test_at_thr_used": thr_metrics_test,

        "fairness_test_at_thr_used": {
            "definition": "H3: ΔFNR = FNR(F) - FNR(M), where FNR(sex)=FN/(FN+TP) computed on PD-only true labels at test_threshold_used_global.",
            "threshold_used": float(thr_used),
            "fnr_by_sex_norm": fnr_by_sex,
            "delta_fnr_F_minus_M": float(delta_f_minus_m),
            "delta_fnr_abs": float(delta_abs),
            "note": "If n_PD for a sex is 0, its FNR is NaN and ΔFNR is NaN.",
            "sex_normalization_note": "D2 mapping: exact 'male'->M and 'female'->F (case-sensitive); otherwise UNK.",
        },

        "confusion_by_sex_norm_at_thr_used": confusion_by_sex,

        "artifacts": {
            "predictions_csv": str(pred_csv_path),
            "roc_curve_png": str(roc_png),
            "confusion_matrix_png": str(cm_png),
            "confusion_matrix_M_png": str(cm_m_png) if cm_m_png is not None else None,
            "confusion_matrix_F_png": str(cm_f_png) if cm_f_png is not None else None,
        },

        "d7_out_root": D7_OUT_ROOT,
        "d2_out_root": D2_OUT_ROOT,
        "d2_manifest_all": D2_MANIFEST_ALL,

        "best_heads_path": str(best_heads_path),
        "backbone_ckpt": BACKBONE_CKPT,
        "dropout_p": float(DROPOUT_P),
        "timestamp": time.strftime("%Y%m%d_%H%M%S"),
    }

    metrics_path = run_dir / "metrics.json"
    with open(metrics_path, "w", encoding="utf-8") as f:
        json.dump(metrics, f, indent=2)

    print(f"[seed={seed}] DONE | test_AUROC={test_auc:.6f}")
    print(f"[seed={seed}] Threshold used (GLOBAL mean from D7 enhanced trainval): {thr_used:.6f}")
    print(f"[seed={seed}] WROTE:")
    print(" ", str(metrics_path))
    print(" ", str(pred_csv_path))
    print(" ", str(roc_png))
    print(" ", str(cm_png))

    return {
        "seed": int(seed),
        "thr_used": float(thr_used),
        "thr_note": thr_note,
        "test_auc": float(test_auc),
        "thr_metrics_test": thr_metrics_test,
        "fnr_by_sex": fnr_by_sex,
        "delta_signed": float(delta_f_minus_m),
        "delta_abs": float(delta_abs),
        "run_dir": str(run_dir),
        "predictions_csv": str(pred_csv_path),
    }

# -------------------------
# 18) Run all seeds and aggregate results
#    - AUROC: mean ± 95% CI (t, n=3)
#    - Other metrics and fairness: mean ± SD
# -------------------------
results = []
for seed in SEEDS:
    results.append(run_test_once(seed))

aurocs = [r["test_auc"] for r in results]
t_crit = 4.302652729911275  # two-sided 95% t critical value for df=2; valid only when len(SEEDS) == 3
n = len(aurocs)
mean_auc = float(np.mean(aurocs))
std_auc = float(np.std(aurocs, ddof=1)) if n > 1 else 0.0
half_width = float(t_crit * (std_auc / math.sqrt(n))) if n > 1 else 0.0
ci95 = [float(mean_auc - half_width), float(mean_auc + half_width)]

def _mean_sd(vals):
    vals = np.asarray(vals, dtype=np.float64)
    mu = float(np.nanmean(vals)) if np.any(~np.isnan(vals)) else float("nan")
    sd = float(np.nanstd(vals, ddof=1)) if np.sum(~np.isnan(vals)) > 1 else 0.0
    return mu, sd

# Aggregate threshold-based metrics.
thr_list = [r["thr_metrics_test"] for r in results]
keys = ["accuracy","precision","recall","f1","sensitivity","specificity","mcc","p_value_fisher_two_sided"]
agg = {}
for k in keys:
    v = [float(tm.get(k, float("nan"))) for tm in thr_list]
    mu, sd = _mean_sd(v)
    agg[k] = {
        "mean": float(mu),
        "sd": float(sd),
        "values_by_seed": {str(s): float(tm.get(k, float("nan"))) for s, tm in zip(SEEDS, thr_list)},
    }

# Confusion counts by seed (useful for spot checks).
cm_by_seed = {
    str(s): {"tn": int(thr_list[i]["tn"]), "fp": int(thr_list[i]["fp"]), "fn": int(thr_list[i]["fn"]), "tp": int(thr_list[i]["tp"])}
    for i, s in enumerate(SEEDS)
}

# Fairness aggregation (signed and absolute ΔFNR).
fnr_by_seed = {str(r["seed"]): r["fnr_by_sex"] for r in results}
delta_signed_by_seed = {str(r["seed"]): float(r["delta_signed"]) for r in results}
delta_abs_by_seed = {str(r["seed"]): float(r["delta_abs"]) for r in results}

fnr_m_vals, fnr_f_vals = [], []
d_signed_vals, d_abs_vals = [], []
for r in results:
    d = r["fnr_by_sex"] or {}
    fnr_m_vals.append(float(d.get("M", {}).get("fnr", float("nan"))))
    fnr_f_vals.append(float(d.get("F", {}).get("fnr", float("nan"))))
    d_signed_vals.append(float(r["delta_signed"]))
    d_abs_vals.append(float(r["delta_abs"]))

fnr_m_mean, fnr_m_sd = _mean_sd(fnr_m_vals)
fnr_f_mean, fnr_f_sd = _mean_sd(fnr_f_vals)
d_signed_mean, d_signed_sd = _mean_sd(d_signed_vals)
d_abs_mean, d_abs_sd = _mean_sd(d_abs_vals)

# Console summary (quick read without opening files).
print("\nTest AUROC by seed:")
for r in results:
    print(f"  seed {r['seed']}: {r['test_auc']:.6f}")
print(f"\nMean Test AUROC: {mean_auc:.6f}")
print(f"95% CI (t, n=3): [{ci95[0]:.6f}, {ci95[1]:.6f}]")

print("\nTEST threshold used (GLOBAL mean val-opt from D7 enhanced trainval):")
print(f"  THR_USED_GLOBAL: {THR_USED_GLOBAL:.6f}")
if THR_GLOBAL_NOTE is not None:
    print("  NOTE:", THR_GLOBAL_NOTE)

print("\nThreshold metrics on D2 TEST @ THR_USED_GLOBAL (mean ± SD across seeds):")
for k in ["accuracy","precision","sensitivity","specificity","f1","mcc"]:
    print(f"  {k}: {agg[k]['mean']:.6f} ± {agg[k]['sd']:.6f}")
print("  fisher_p_value_two_sided:", f"{agg['p_value_fisher_two_sided']['mean']:.6g} ± {agg['p_value_fisher_two_sided']['sd']:.6g}")

print("\nFAIRNESS (H3) on D2 TEST @ THR_USED_GLOBAL across seeds (mean ± SD):")
print(f"  FNR_M: {fnr_m_mean:.6f} ± {fnr_m_sd:.6f}")
print(f"  FNR_F: {fnr_f_mean:.6f} ± {fnr_f_sd:.6f}")
print(f"  ΔFNR (F-M): {d_signed_mean:.6f} ± {d_signed_sd:.6f}")
print(f"  |ΔFNR|: {d_abs_mean:.6f} ± {d_abs_sd:.6f}")

# -------------------------
# 18.1) Build run-level summary object for files and pointers
#    Includes per-seed results plus aggregated statistics.
# -------------------------
summary = {
    "run_tag_full_trainval_exp": str(FULL_TRAINVAL_EXP_TAG),
    "run_tag_safe": str(TAG_SAFE),
    "enh_tag": str(REQUIRED_EXP_SUBSTRING),
    "run_stamp": str(RUN_STAMP),

    "train_dataset": train_dataset_id,
    "test_dataset": d2_dataset_id,

    "d7_out_root": D7_OUT_ROOT,
    "d2_out_root": D2_OUT_ROOT,
    "d2_manifest_all": D2_MANIFEST_ALL,

    "enhanced_exp_required_substring_case_insensitive": REQUIRED_EXP_SUBSTRING,
    "trainval_experiment_used": str(chosen_exp),
    "trainval_summary_path": str(summary_trainval_path),

    "seeds": SEEDS,

    "n_test": int(len(test_df)),
    "label_counts_test": test_df["label_num"].value_counts(dropna=False).to_dict(),
    "sex_counts_test_norm": test_df["sex_norm"].value_counts(dropna=False).to_dict(),

    "test_threshold_used": {
        "threshold_used_global": float(THR_USED_GLOBAL),
        "threshold_source": "D7 enhanced trainval summary_trainval.json -> val_optimal_threshold.mean_sd.mean",
        "note_global": (THR_GLOBAL_NOTE if THR_GLOBAL_NOTE is not None else ""),
        "per_seed_repetition_for_audit": {str(r["seed"]): float(r["thr_used"]) for r in results},
    },

    "test_aurocs_by_seed": {str(r["seed"]): float(r["test_auc"]) for r in results},
    "mean_test_auroc": float(mean_auc),
    "t_crit_df2_95": float(t_crit),
    "half_width_95_ci": float(half_width),
    "ci95_test_auroc": ci95,

    "threshold_metrics_mean_sd_test_at_thr_used": agg,
    "confusion_matrix_by_seed_test_at_thr_used": cm_by_seed,

    "fairness_test_at_thr_used": {
        "definition": "H3: ΔFNR = FNR(F) - FNR(M), where FNR(sex)=FN/(FN+TP) computed on PD-only true labels at test_threshold_used_global.",
        "fnr_by_sex_norm_by_seed": fnr_by_seed,
        "delta_fnr_F_minus_M_by_seed": delta_signed_by_seed,
        "delta_fnr_abs_by_seed": delta_abs_by_seed,
        "fnr_M_mean_sd": {"mean": float(fnr_m_mean), "sd": float(fnr_m_sd)},
        "fnr_F_mean_sd": {"mean": float(fnr_f_mean), "sd": float(fnr_f_sd)},
        "delta_fnr_F_minus_M_mean_sd": {"mean": float(d_signed_mean), "sd": float(d_signed_sd)},
        "delta_fnr_abs_mean_sd": {"mean": float(d_abs_mean), "sd": float(d_abs_sd)},
    },

    "run_dirs": [r["run_dir"] for r in results],
    "predictions_csv_by_seed": {str(r["seed"]): str(r["predictions_csv"]) for r in results},
}

# -------------------------
# 18.2) Write summary files and pointer files
#    - summary_test.json inside this run folder
#    - history_index.jsonl append-only
#    - summary_latest.json and last_run_pointer.json overwritten for convenience
#    - tag_run_pointer.json stable per experiment tag
# -------------------------
summary_path = RUN_ROOT / "summary_test.json"
with open(summary_path, "w", encoding="utf-8") as f:
    json.dump(summary, f, indent=2)

history_path = TEST_ROOT / "history_index.jsonl"
with open(history_path, "a", encoding="utf-8") as f:
    f.write(json.dumps(summary) + "\n")

summary_latest_path = TEST_ROOT / "summary_latest.json"
latest_summary_obj = {
    "run_tag_full_trainval_exp": str(FULL_TRAINVAL_EXP_TAG),
    "enh_tag": str(REQUIRED_EXP_SUBSTRING),
    "run_stamp": str(RUN_STAMP),
    "run_root": str(RUN_ROOT),
    "summary_test_json": str(summary_path),
    "seed_run_dirs": [str(RUN_ROOT / f"run_{train_dataset_id}_on_{d2_dataset_id}test_seed{s}") for s in SEEDS],
}
with open(summary_latest_path, "w", encoding="utf-8") as f:
    json.dump(latest_summary_obj, f, indent=2)

global_pointer_path = TEST_ROOT / "last_run_pointer.json"
with open(global_pointer_path, "w", encoding="utf-8") as f:
    json.dump(latest_summary_obj, f, indent=2)

tag_pointer_path = TAG_ROOT / "tag_run_pointer.json"
tag_pointer_obj = dict(latest_summary_obj)
tag_pointer_obj["tag_root"] = str(TAG_ROOT)
with open(tag_pointer_path, "w", encoding="utf-8") as f:
    json.dump(tag_pointer_obj, f, indent=2)

print("\nWROTE run summary:", str(summary_path))
print("APPENDED history index:", str(history_path))
print("WROTE latest summary:", str(summary_latest_path))
print("WROTE global pointer:", str(global_pointer_path))
print("WROTE tag pointer:", str(tag_pointer_path))
print("Open this folder to access artifacts:", str(RUN_ROOT))

# -------------------------
# 18.5) Builder-aligned artifacts (config + simple logs placeholder)
#    Creates stable config/log summaries under the enhancement tag folder.
# -------------------------
_backup_if_exists(RUN_CONFIG_PATH)
_backup_if_exists(WARNINGS_CSV_PATH)
_backup_if_exists(DATASET_SUMMARY_PATH)

run_config = {
    "mode": f"D7_{ENH_TAG_SAFE}_on_D2_Test",
    "created_utc": datetime.utcnow().isoformat(),
    "run_stamp": RUN_STAMP,

    "enh_tag": str(REQUIRED_EXP_SUBSTRING),

    "d7_out_root": D7_OUT_ROOT,
    "d2_out_root": D2_OUT_ROOT,
    "d2_manifest_all": D2_MANIFEST_ALL,

    "enhanced_exp_required_substring_case_insensitive": REQUIRED_EXP_SUBSTRING,
    "trainval_experiment_used": str(chosen_exp),
    "trainval_exp_tag": str(FULL_TRAINVAL_EXP_TAG),
    "trainval_summary_path": str(summary_trainval_path),

    "threshold_source": "summary_trainval.json -> val_optimal_threshold.mean_sd.mean",
    "threshold_used_global": float(THR_USED_GLOBAL),
    "threshold_note_global": (THR_GLOBAL_NOTE if THR_GLOBAL_NOTE is not None else ""),

    "seeds": SEEDS,
    "use_amp": bool(USE_AMP and DEVICE.type == "cuda"),
    "per_device_bs": int(PER_DEVICE_BS),
    "effective_bs": int(PER_DEVICE_BS * GRAD_ACCUM),
    "num_workers": int(NUM_WORKERS),
    "pin_memory": bool(PIN_MEMORY),
    "backbone_ckpt": BACKBONE_CKPT,
    "dropout_p": float(DROPOUT_P),

    "monolingual_test_runs_root": str(TEST_ROOT),
    "run_root": str(RUN_ROOT),
    "summary_test_json": str(summary_path),
    "last_run_pointer_json": str(global_pointer_path),
    "tag_run_pointer_json": str(tag_pointer_path),
}
with open(RUN_CONFIG_PATH, "w", encoding="utf-8") as f:
    json.dump(run_config, f, indent=2)

dataset_summary = {
    "mode": f"D7_{ENH_TAG_SAFE}_on_D2_Test",
    "created_utc": datetime.utcnow().isoformat(),
    "status": "SUCCESS",
    "run_stamp": RUN_STAMP,

    "enh_tag": str(REQUIRED_EXP_SUBSTRING),

    "d2_dataset_id": d2_dataset_id,
    "n_test": int(len(test_df)),
    "label_counts_test": test_df["label_num"].value_counts(dropna=False).to_dict(),
    "sex_counts_test_raw": test_df["sex"].value_counts(dropna=False).to_dict(),
    "sex_counts_test_norm": test_df["sex_norm"].value_counts(dropna=False).to_dict(),

    "trainval_experiment_used": str(chosen_exp),
    "trainval_exp_tag": str(FULL_TRAINVAL_EXP_TAG),
    "trainval_summary_path": str(summary_trainval_path),

    "threshold_used_global": float(THR_USED_GLOBAL),

    "monolingual_test_runs_root": str(TEST_ROOT),
    "run_root": str(RUN_ROOT),
    "summary_test_json": str(summary_path),
    "last_run_pointer_json": str(global_pointer_path),
}
with open(DATASET_SUMMARY_PATH, "w", encoding="utf-8") as f:
    json.dump(dataset_summary, f, indent=2)

# Placeholder for symmetry with preprocessing logs.
with open(WARNINGS_CSV_PATH, "w", encoding="utf-8") as f:
    f.write("ts,level,message\n")

print("\nWROTE (builder-aligned):")
print(" ", str(RUN_CONFIG_PATH))
print(" ", str(WARNINGS_CSV_PATH))
print(" ", str(DATASET_SUMMARY_PATH))

# -------------------------
# 19) Unassign runtime (stop GPU)
# -------------------------
print("\nAll done. Unassigning the runtime...")
try:
    from google.colab import runtime  # type: ignore
    runtime.unassign()
except Exception as e:
    print("Could not unassign runtime automatically. Stop runtime manually if needed.")
    print("Reason:", repr(e))

The following cell builds the enhanced D7 training split **train_enh3** by starting with the full original D7 training set and adding a new, speaker-balanced subset from the D2 training data. The D2 portion is about **10 percent of all D2 training speakers** and is selected at the speaker level, so when a speaker is chosen, all of that speaker’s training clips are included. Any D2 speakers that were already used in **train_enh1** or **train_enh2** are excluded, so **train_enh3 only adds D2 speakers that have not appeared in earlier enhanced sets**.

The cell first loads `manifest_all.csv` for D7 and D2, along with the manifests from the earlier enhanced builds. It checks that all required columns exist, confirms that dataset names are correct, and applies the following data rules: literal `"NaN"` strings are converted to real missing values, sex values are standardized to **M** or **F** (unrecognized values are left missing), and class labels are limited to Healthy (0) and Parkinson’s (1). Using the earlier enhanced manifests, it collects the full list of D2 speakers that have already been used and removes them from the pool of eligible speakers.
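
The cleaning and exclusion rules above can be sketched as follows. This is a minimal illustration, not the notebook's implementation: `normalize_sex`, `clean_manifest`, and `used_d2_speakers` are hypothetical helper names, and the sketch assumes the earlier enhanced manifests mark D2-derived rows with `source_dataset == "D2"`.

```python
import numpy as np
import pandas as pd

SEX_MAP = {"m": "M", "male": "M", "f": "F", "female": "F"}

def normalize_sex(value):
    """Map common spellings to M/F; anything else becomes missing."""
    if pd.isna(value):
        return np.nan
    return SEX_MAP.get(str(value).strip().lower(), np.nan)

def clean_manifest(df: pd.DataFrame) -> pd.DataFrame:
    """Hypothetical helper: apply the cleaning rules to one manifest."""
    out = df.copy()
    # Turn literal "NaN" strings into real missing values.
    for col in out.columns:
        literal_nan = out[col].map(
            lambda v: isinstance(v, str) and v.strip() == "NaN"
        ).astype(bool)
        if literal_nan.any():
            out[col] = out[col].mask(literal_nan, np.nan)
    out["sex"] = out["sex"].map(normalize_sex)
    # Keep only Healthy (0) and Parkinson's (1) rows.
    labels = pd.to_numeric(out["label_num"], errors="coerce")
    keep = labels.isin([0, 1])
    out = out.loc[keep].copy()
    out["label_num"] = labels.loc[keep].astype(int)
    return out.reset_index(drop=True)

def used_d2_speakers(*prior_manifests: pd.DataFrame) -> set:
    """Union of D2 speaker IDs across all earlier enhanced manifests.

    Assumes D2-derived rows carry source_dataset == "D2".
    """
    used = set()
    for manifest in prior_manifests:
        d2_rows = manifest[manifest["source_dataset"] == "D2"]
        used.update(d2_rows["speaker_id"].astype(str))
    return used
```

The eligible pool is then the D2 training speakers whose IDs are not in the returned set.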

Next, the cell calculates how many D2 speakers to add. The target is **10 percent of the total D2 training speaker count**, rounded up and then, if the result is odd, reduced by one so that Healthy and Parkinson’s speakers can be selected in equal numbers. This target is fixed before exclusions are applied: exclusions shrink the eligible pool but not the denominator. If, after exclusions, one or both classes have too few eligible speakers, the per-class target is reduced equally for both classes, and the adjustment is recorded. A fixed random seed is used so the same speakers are chosen each time the cell is run.
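
The target arithmetic can be illustrated with a short sketch. The names `plan_speaker_target` and `draw_speakers` are hypothetical, the use of `numpy.random.default_rng` is an assumption about how the seed is applied, and rounding odd totals down follows the wording of the cell's `ROUNDING_POLICY` string.

```python
import math
import numpy as np

def plan_speaker_target(n_d2_train_speakers, n_eligible_hc, n_eligible_pd, fraction=0.10):
    """Hypothetical helper: fixed-denominator target, then balanced downsizing."""
    # The target comes from the FULL D2 train speaker count (before exclusions).
    target_total = math.ceil(fraction * n_d2_train_speakers)
    if target_total % 2 == 1:
        target_total -= 1  # assumed: odd totals are rounded down to stay even
    target_per_class = target_total // 2
    # Exclusions only shrink the eligible pool; both classes are reduced equally.
    final_per_class = min(target_per_class, n_eligible_hc, n_eligible_pd)
    return {
        "target_total": target_total,
        "target_per_class": target_per_class,
        "final_per_class": final_per_class,
        "downsized": final_per_class < target_per_class,
    }

def draw_speakers(eligible_ids, k, seed):
    """Draw k speakers reproducibly; sorting makes the result independent of input order."""
    rng = np.random.default_rng(seed)
    pool = sorted(str(s) for s in eligible_ids)
    return sorted(rng.choice(pool, size=k, replace=False).tolist())
```

With illustrative numbers (not taken from the study), 125 D2 training speakers give a raw target of ceil(12.5) = 13, reduced to 12, so 6 per class. If only 4 eligible Parkinson’s speakers remained after exclusions, both classes would drop to 4.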

With the speaker list finalized, the cell prepares a copy plan for the new folder `clips/train_enh3/`. All original D7 training clips are copied using their original filenames. Clips from the selected D2 speakers are copied using new, predictable filenames that include class, speaker ID, task type, and a stable index. This avoids name conflicts and keeps files easy to trace. Before copying, the cell checks that every source file exists and applies a strict no-overwrite rule: files that already exist with the same size are skipped, while any mismatch or file error stops the run. Initial configuration files and summaries are written at this point so the setup is recorded even if copying does not finish.
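
A minimal sketch of the naming and no-overwrite rules follows. The filename pattern, the `.wav` suffix, and the helper names `d2_clip_name` and `copy_no_overwrite` are assumptions made for illustration; the notebook's actual pattern may differ.

```python
import shutil
import tempfile
from pathlib import Path

def d2_clip_name(label_num, speaker_id, task, index, suffix=".wav"):
    """Hypothetical naming pattern: class, speaker ID, task type, stable index."""
    cls = "PD" if int(label_num) == 1 else "HC"
    return f"D2_{cls}_{speaker_id}_{task}_{index:04d}{suffix}"

def copy_no_overwrite(src: Path, dst: Path) -> str:
    """Copy src to dst; skip same-size files and stop on any mismatch."""
    if not src.is_file():
        raise FileNotFoundError(f"Missing source clip: {src}")
    if dst.exists():
        if dst.stat().st_size == src.stat().st_size:
            return "skipped"
        raise RuntimeError(f"Refusing to overwrite (size mismatch): {dst}")
    dst.parent.mkdir(parents=True, exist_ok=True)
    shutil.copy2(src, dst)  # the source file is never modified
    return "copied"

# Small demonstration in a temporary folder.
with tempfile.TemporaryDirectory() as tmp:
    src = Path(tmp) / "source.wav"
    src.write_bytes(b"\x00" * 16)
    dst = Path(tmp) / "train_enh3" / d2_clip_name(1, "spk01", "vowel", 0)
    first = copy_no_overwrite(src, dst)   # "copied"
    second = copy_no_overwrite(src, dst)  # "skipped"
```

Comparing file sizes is a cheap check that makes reruns idempotent; it does not prove the contents match, so a stricter variant could compare hashes instead.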

During the copy step, files are copied without changing the original datasets. The cell tracks how many files are copied and how many are skipped. Any copy failure causes the run to stop with a clear error. After copying completes, the cell creates `manifest_train_enh3.csv` by combining the original D7 training rows with the newly added D2 rows. All rows are labeled with `split = "train_enh3"`. File paths are updated to point to the new clip locations, and D2-derived rows include `source_dataset = "D2"` while still being treated as part of the D7 training set. The final manifest is checked to confirm that all files exist and that no literal `"NaN"` values remain.
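
The manifest assembly and final checks might look like the sketch below. `build_train_enh3_manifest` and `validate_manifest` are hypothetical names. The sketch assumes both input frames already hold the new `clip_path` values, and it interprets "treated as part of the D7 training set" as setting `dataset = "D7"` on the added rows, which is an assumption.

```python
from pathlib import Path
import pandas as pd

def build_train_enh3_manifest(d7_train: pd.DataFrame, d2_added: pd.DataFrame) -> pd.DataFrame:
    """Combine original D7 train rows with the newly added D2 rows."""
    d7 = d7_train.copy()
    if "source_dataset" not in d7.columns:
        # Assumption: D7 rows record their own dataset name as the source.
        d7["source_dataset"] = d7["dataset"]
    d2 = d2_added.copy()
    d2["source_dataset"] = "D2"  # provenance stays visible
    d2["dataset"] = "D7"         # assumed meaning of "treated as part of D7"
    out = pd.concat([d7, d2], ignore_index=True)
    out["split"] = "train_enh3"
    return out

def validate_manifest(df: pd.DataFrame) -> list:
    """Return a list of problems; an empty list means the manifest passed."""
    problems = []
    missing = [p for p in df["clip_path"] if not Path(str(p)).is_file()]
    if missing:
        problems.append(f"{len(missing)} clip_path entries do not exist")
    for col in df.columns:
        has_literal_nan = df[col].map(
            lambda v: isinstance(v, str) and v.strip() == "NaN"
        ).any()
        if has_literal_nan:
            problems.append(f"literal 'NaN' string in column {col}")
    return problems
```

Returning a list of problems rather than raising on the first one lets every issue be written to the warnings log before the run stops.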

The cell finishes by writing the final outputs in a safe, non-destructive way. These include the new `train_enh3` clip folder, the training manifest, a detailed run configuration file, a warnings log, and a dataset summary that reports counts and key details of the build.

In [None]:
# D7 trainEnh3 builder: add a new D2 speaker draw into D7 training
# Inputs: D7 manifest (train split), D2 manifest (train split), prior train_enh1 and train_enh2 manifests
# Outputs: train_enh3 clip folder (copied files), manifest_train_enh3.csv, run_config.json, dataset_summary.json, preprocess_warnings.csv
#
# =========================
# D7 TRAIN_ENH3 BUILDER (CRASH-PROOF, THIRD-PARTY AUDITABLE)
# - Add ~9–10% of D2 TRAIN speakers (Definition B) with:
#   * FIXED DENOMINATOR POLICY (same as train_enh2):
#       target_total = ceil(10% of FULL D2 TRAIN speaker count), then make even
#       exclusions only shrink eligible pool, not the denominator
#   * FORCE ZERO OVERLAP with prior draws:
#       exclude ANY D2 speakers used in manifest_train_enh1 OR manifest_train_enh2
#   * If eligible pool is insufficient to meet target_per_class:
#       downsize BOTH classes to keep speaker-balanced sampling
#
# - Creates: <D7_OUT_ROOT>/clips/train_enh3/
# - Writes:
#     * manifests/manifest_train_enh3.csv
#     * config/D7_Enh3_on_D2_Test/run_config.json
#     * logs/D7_Enh3_on_D2_Test/preprocess_warnings.csv
#     * logs/D7_Enh3_on_D2_Test/dataset_summary.json
# - COPY only, no overwrite
# =========================

import os, json, re, shutil, math
from pathlib import Path
from datetime import datetime

import numpy as np
import pandas as pd
from tqdm.auto import tqdm

# -------------------------
# 0) Drive mount (only needed in Colab)
# -------------------------
try:
    from google.colab import drive  # type: ignore
    if not os.path.exists("/content/drive"):
        drive.mount("/content/drive")
except Exception:
    pass

# -------------------------
# 1) Key inputs and output pointers
#    - Inputs: D7/D2 manifests and prior enh1/enh2 manifests (to enforce zero overlap)
#    - Outputs: train_enh3 clips folder + manifest + logs + config
# -------------------------
D7_OUT_ROOT = Path("/content/drive/MyDrive/AI_PD_Project/Datasets/D7-Multilingual (D1_D4_D5v2_D6)/preprocessed_v1")
D2_OUT_ROOT = Path("/content/drive/MyDrive/AI_PD_Project/Datasets/D2-Slovak (EWA-DB)/EWA-DB/preprocessed_v1")

D7_MANIFEST_ALL = D7_OUT_ROOT / "manifests" / "manifest_all.csv"
D2_MANIFEST_ALL = D2_OUT_ROOT / "manifests" / "manifest_all.csv"

# Prior selection sources (used only to extract excluded D2 speakers)
MANIFEST_TRAIN_ENH1 = Path("/content/drive/MyDrive/AI_PD_Project/Datasets/D7-Multilingual (D1_D4_D5v2_D6)/preprocessed_v1/manifests/manifest_train_enh1.csv")
MANIFEST_TRAIN_ENH2 = Path("/content/drive/MyDrive/AI_PD_Project/Datasets/D7-Multilingual (D1_D4_D5v2_D6)/preprocessed_v1/manifests/manifest_train_enh2.csv")

# New outputs
TRAIN_ENH3_DIR = D7_OUT_ROOT / "clips" / "train_enh3"
MANIFEST_TRAIN_ENH3 = D7_OUT_ROOT / "manifests" / "manifest_train_enh3.csv"

# -------------------------
# 1.5) Run subfolders for logs and config
# -------------------------
RUN_FOLDER_LOGS = "D7_Enh3_on_D2_Test"
RUN_FOLDER_CFG  = "D7_Enh3_on_D2_Test"

LOGS_SUBDIR = D7_OUT_ROOT / "logs" / RUN_FOLDER_LOGS
CFG_SUBDIR  = D7_OUT_ROOT / "config" / RUN_FOLDER_CFG

WARNINGS_CSV    = LOGS_SUBDIR / "preprocess_warnings.csv"
SUMMARY_JSON    = LOGS_SUBDIR / "dataset_summary.json"
RUN_CONFIG_JSON = CFG_SUBDIR  / "run_config.json"

# -------------------------
# 1.6) Sampling knobs and file copy rules
# -------------------------
TEN_PCT         = 0.10
ROUNDING_POLICY = "ceil_then_make_even_by_rounding_down_if_needed"
BALANCE_POLICY  = "speaker-balanced"
SPEAKER_ID_COL  = "speaker_id"

RNG_SEED        = 7777
COPY_ALL_D7_TRAIN_CLIPS = True

LABEL_MAP_NOTE  = "label_num mapping: 0=Healthy, 1=Parkinson"

# -------------------------
# 2) Create required folders (safe if already present)
# -------------------------
clips_dir = D7_OUT_ROOT / "clips"
manif_dir = D7_OUT_ROOT / "manifests"
logs_dir  = D7_OUT_ROOT / "logs"
cfg_dir   = D7_OUT_ROOT / "config"

for d in [clips_dir, manif_dir, logs_dir, cfg_dir, LOGS_SUBDIR, CFG_SUBDIR]:
    d.mkdir(parents=True, exist_ok=True)
TRAIN_ENH3_DIR.mkdir(parents=True, exist_ok=True)

# -------------------------
# 3) Locked manifest schema (kept consistent across datasets)
# -------------------------
CANON_COLS = [
    "split",
    "dataset",
    "task",
    "speaker_id",
    "sample_id",
    "label_str",
    "label_num",
    "age",
    "sex",
    "speaker_key_rel",
    "clip_path",
    "duration_sec",
    "source_path",
    "clip_start_sec",
    "clip_end_sec",
    "sr_hz",
    "channels",
    "clip_is_contiguous",
]
FINAL_COLS = CANON_COLS + ["source_dataset"]

# -------------------------
# 4) Small helpers
#    - require(): stop early with a clear error
#    - add_warn(): collect warnings and errors for the log file
#    - atomic writes: avoid partial files if the runtime stops mid-write
# -------------------------
warnings_rows = []

def require(cond: bool, msg: str):
    if not cond:
        raise RuntimeError(msg)

def add_warn(src: str, level: str, code: str, message: str, **extra):
    row = {
        "ts": datetime.utcnow().isoformat(),
        "src": src,
        "level": str(level).upper(),
        "code": code,
        "message": message,
    }
    row.update(extra)
    warnings_rows.append(row)

def count_by_level(rows):
    out = {"ERROR": 0, "WARN": 0, "INFO": 0}
    for r in rows:
        lvl = str(r.get("level", "INFO")).upper()
        out[lvl] = out.get(lvl, 0) + 1
    return out

def atomic_write_text(dst: Path, text: str):
    tmp = dst.with_suffix(dst.suffix + ".tmp")
    with open(tmp, "w", encoding="utf-8") as f:
        f.write(text)
    os.replace(tmp, dst)

def atomic_write_csv(dst: Path, df: pd.DataFrame):
    tmp = dst.with_suffix(dst.suffix + ".tmp")
    df.to_csv(tmp, index=False, na_rep="")
    os.replace(tmp, dst)

# Safe tokens for deterministic file names (avoid spaces and symbols)
def safe_token(s, max_len=32, default="NA"):
    if pd.isna(s):
        return default
    s = str(s).strip()
    s = re.sub(r"\s+", "_", s)
    s = re.sub(r"[^A-Za-z0-9_]+", "", s)
    return (s[:max_len] if s else default)

# Convert numeric label to text label (used for readability in manifests)
def label_str_from_num(v):
    if pd.isna(v):
        return np.nan
    iv = int(v)
    if iv == 0:
        return "Healthy"
    if iv == 1:
        return "Parkinson"
    return np.nan

def hc_pd_from_label_num(v):
    return "PD" if int(v) == 1 else "HC"

# Normalize sex values to M/F/NaN (one function per dataset; the rules are currently identical)
def normalize_sex_to_MF_D7(x):
    if pd.isna(x):
        return np.nan
    s = str(x).strip().lower()
    if s in ["m", "male"]:
        return "M"
    if s in ["f", "female"]:
        return "F"
    if s in ["", "nan", "none", "unknown", "u"]:
        return np.nan
    return np.nan

def normalize_sex_to_MF_D2(x):
    if pd.isna(x):
        return np.nan
    s = str(x).strip().lower()
    if s in ["male", "m"]:
        return "M"
    if s in ["female", "f"]:
        return "F"
    if s in ["", "nan", "none", "unknown", "u"]:
        return np.nan
    return np.nan

# Quick summary used in the final dataset_summary.json
def summarize_counts(df: pd.DataFrame, split_name: str):
    out = {}
    out["total_rows"] = int(len(df))
    out["split"] = split_name
    out["label_counts_total"] = {str(k): int(v) for k, v in df["label_num"].value_counts(dropna=False).to_dict().items()}
    out["by_source_dataset"] = {sd: int((df["source_dataset"] == sd).sum()) for sd in sorted(df["source_dataset"].dropna().unique())}
    out["sex_counts"] = {str(k): int(v) for k, v in df["sex"].value_counts(dropna=False).to_dict().items()}
    out["n_unique_speakers"] = int(df["speaker_id"].astype(str).nunique()) if "speaker_id" in df.columns else int(0)
    return out

# -------------------------
# 5) Load manifests and validate schema
#    - Stops early if a required input is missing or malformed
# -------------------------
print("D7_OUT_ROOT:", str(D7_OUT_ROOT))
print("D2_OUT_ROOT:", str(D2_OUT_ROOT))
print("D7_MANIFEST_ALL:", str(D7_MANIFEST_ALL))
print("D2_MANIFEST_ALL:", str(D2_MANIFEST_ALL))
print("MANIFEST_TRAIN_ENH1:", str(MANIFEST_TRAIN_ENH1))
print("MANIFEST_TRAIN_ENH2:", str(MANIFEST_TRAIN_ENH2))
print("TRAIN_ENH3_DIR:", str(TRAIN_ENH3_DIR))
print("MANIFEST_TRAIN_ENH3:", str(MANIFEST_TRAIN_ENH3))
print("LOGS_SUBDIR:", str(LOGS_SUBDIR))
print("CFG_SUBDIR:", str(CFG_SUBDIR))
print("RNG_SEED:", int(RNG_SEED))
print("COPY_ALL_D7_TRAIN_CLIPS:", bool(COPY_ALL_D7_TRAIN_CLIPS))

require(D7_MANIFEST_ALL.exists(), f"Missing D7 manifest_all.csv: {str(D7_MANIFEST_ALL)}")
require(D2_MANIFEST_ALL.exists(), f"Missing D2 manifest_all.csv: {str(D2_MANIFEST_ALL)}")
require(MANIFEST_TRAIN_ENH1.exists(), f"Missing prior manifest_train_enh1.csv: {str(MANIFEST_TRAIN_ENH1)}")
require(MANIFEST_TRAIN_ENH2.exists(), f"Missing prior manifest_train_enh2.csv: {str(MANIFEST_TRAIN_ENH2)}")

d7 = pd.read_csv(D7_MANIFEST_ALL)
d2 = pd.read_csv(D2_MANIFEST_ALL)
enh1 = pd.read_csv(MANIFEST_TRAIN_ENH1)
enh2 = pd.read_csv(MANIFEST_TRAIN_ENH2)

# Verify the core columns exist in both base manifests
missing_d7 = [c for c in CANON_COLS if c not in d7.columns]
missing_d2 = [c for c in CANON_COLS if c not in d2.columns]
require(len(missing_d7) == 0, f"D7 manifest missing required columns: {missing_d7}")
require(len(missing_d2) == 0, f"D2 manifest missing required columns: {missing_d2}")

# Prior enh manifests must contain speaker_id and source_dataset to support exclusions
for name, df in [("enh1", enh1), ("enh2", enh2)]:
    require("speaker_id" in df.columns, f"{name} missing 'speaker_id'. Found: {list(df.columns)}")
    require("source_dataset" in df.columns, f"{name} missing 'source_dataset'. Found: {list(df.columns)}")

# Ensure source_dataset exists everywhere (used for reporting and exclusions)
if "source_dataset" not in d7.columns:
    d7["source_dataset"] = "D7"
if "source_dataset" not in d2.columns:
    d2["source_dataset"] = "D2"

# Convert literal "NaN" strings into real missing values
for df in [d7, d2, enh1, enh2]:
    for col in ["sex", "age", "duration_sec", "clip_start_sec", "clip_end_sec", "speaker_key_rel", "speaker_id", "task", "sample_id"]:
        if col in df.columns:
            df[col] = df[col].replace("NaN", np.nan)

# Basic dataset identity checks (avoid mixing wrong manifests)
def infer_dataset_id(df: pd.DataFrame, fallback: str) -> str:
    if "dataset" in df.columns and df["dataset"].notna().any():
        return str(df["dataset"].astype(str).value_counts(dropna=True).idxmax())
    return fallback

require(infer_dataset_id(d7, "DX") == "D7", "Expected D7 manifest dataset=='D7'.")
require(infer_dataset_id(d2, "DX") == "D2", "Expected D2 manifest dataset=='D2'.")

# Keep only the intended dataset rows (defensive filter)
d7 = d7[d7["dataset"].astype(str) == "D7"].copy()
d2 = d2[d2["dataset"].astype(str) == "D2"].copy()

# Validate labels and rebuild label_str for consistency
for name, df in [("D7", d7), ("D2", d2)]:
    bad = sorted(set(df["label_num"].dropna().unique()) - {0, 1})
    require(len(bad) == 0, f"{name} label_num contains values outside {{0,1}}: {bad}")
    df["label_str"] = df["label_num"].map(label_str_from_num)

# Standardize sex encoding
d7["sex"] = d7["sex"].map(normalize_sex_to_MF_D7)
d2["sex"] = d2["sex"].map(normalize_sex_to_MF_D2)

# -------------------------
# 6) Extract D7 training rows (base training content)
# -------------------------
d7_train = d7[d7["split"].astype(str) == "train"].copy()
require(len(d7_train) > 0, "D7 train split has 0 rows.")
print("\nD7 train rows:", int(len(d7_train)))
print("D7 train label counts:", d7_train["label_num"].value_counts(dropna=False).to_dict())

# -------------------------
# 7) Build the D2 speaker exclusion list from enh1 and enh2 (union)
#    - Ensures the new draw shares zero D2 speakers with previous draws
# -------------------------
enh1_d2 = enh1[enh1["source_dataset"].astype(str) == "D2"].copy()
enh2_d2 = enh2[enh2["source_dataset"].astype(str) == "D2"].copy()

prior_d2_speakers_enh1 = set(enh1_d2["speaker_id"].dropna().astype(str).unique().tolist())
prior_d2_speakers_enh2 = set(enh2_d2["speaker_id"].dropna().astype(str).unique().tolist())
prior_d2_speakers_union = sorted(list(prior_d2_speakers_enh1.union(prior_d2_speakers_enh2)))

print("\nZero-overlap exclusion from manifest_train_enh1 + manifest_train_enh2:")
print("  Prior D2 unique speakers in enh1:", int(len(prior_d2_speakers_enh1)))
print("  Prior D2 unique speakers in enh2:", int(len(prior_d2_speakers_enh2)))
print("  UNION excluded D2 unique speakers:", int(len(prior_d2_speakers_union)))
print("  Example excluded speakers (up to 20):", prior_d2_speakers_union[:20])

# -------------------------
# 8) Build the full D2 train speaker set and compute the fixed target size
#    - Target is based on ALL D2 train speakers (before exclusions)
#    - Definition B requires each speaker to have a single consistent label
# -------------------------
d2_train = d2[d2["split"].astype(str) == "train"].copy()
require(len(d2_train) > 0, "D2 train split has 0 rows.")
require(SPEAKER_ID_COL in d2_train.columns, f"D2 missing speaker id column: {SPEAKER_ID_COL}")

print("\nD2 train rows:", int(len(d2_train)))
print("D2 train label counts:", d2_train["label_num"].value_counts(dropna=False).to_dict())

# Speaker label consistency check (one label per speaker)
speaker_labels = (
    d2_train.groupby(SPEAKER_ID_COL)["label_num"]
    .apply(lambda s: sorted(set(s.dropna().astype(int).tolist())))
)
mixed = speaker_labels[speaker_labels.apply(lambda x: len(x) != 1)]
if len(mixed) > 0:
    # len(x) != 1 also catches speakers whose clips carry no usable label at all
    raise RuntimeError(
        "D2 train has speakers with zero or multiple label_num values across clips; Definition B sampling cannot proceed.\n"
        f"Examples: {mixed.head(10).to_dict()}"
    )

speaker_to_label_full = speaker_labels.apply(lambda x: int(x[0])).to_dict()
all_speakers_full = sorted([str(x) for x in speaker_to_label_full.keys()])
total_speakers_full = len(all_speakers_full)

hc_full = sorted([str(spk) for spk, y in speaker_to_label_full.items() if int(y) == 0])
pd_full = sorted([str(spk) for spk, y in speaker_to_label_full.items() if int(y) == 1])
require(len(hc_full) > 0 and len(pd_full) > 0, "D2 train does not contain both HC and PD speakers.")

# Fixed-denominator target (computed before applying exclusions)
target_total = int(math.ceil(TEN_PCT * total_speakers_full))
if target_total % 2 != 0:
    target_total -= 1
if target_total < 2:
    target_total = 2
target_per_class = target_total // 2

print("\nFixed-denominator target computation (same policy as train_enh2):")
print("  total D2 train speakers (FULL):", total_speakers_full)
print(f"  target_total = ceil(0.10 * {total_speakers_full}) then even:", target_total, f"=> {target_per_class} HC + {target_per_class} PD")

# -------------------------
# 9) Apply exclusions to form the eligible pool, then sample speakers
#    - If needed, downsize both classes to keep HC/PD speaker counts matched
# -------------------------
exclude_set = set(prior_d2_speakers_union)
speaker_to_label_eligible = {str(spk): int(lbl) for spk, lbl in speaker_to_label_full.items() if str(spk) not in exclude_set}

hc_eligible = sorted([spk for spk, y in speaker_to_label_eligible.items() if y == 0])
pd_eligible = sorted([spk for spk, y in speaker_to_label_eligible.items() if y == 1])

print("\nEligible pool after exclusion (does not affect denominator target):")
print("  eligible speakers total:", len(speaker_to_label_eligible))
print("  eligible HC speakers:", len(hc_eligible))
print("  eligible PD speakers:", len(pd_eligible))

# Downsize if eligible pool is too small (keeps the draw speaker-balanced)
max_per_class = min(len(hc_eligible), len(pd_eligible), int(target_per_class))
require(max_per_class >= 1, "After exclusions, eligible pool has <1 HC or <1 PD speaker. Cannot build a balanced train_enh3 draw.")

downsized = (max_per_class != int(target_per_class))
effective_per_class = int(max_per_class)
effective_total = int(2 * effective_per_class)

if downsized:
    add_warn(
        "D7_TRAIN_ENH3", "WARN", "DOWNSIZED_DUE_TO_ELIGIBLE_POOL",
        "Eligible pool insufficient to meet fixed-denominator target_per_class; downsized to keep balanced sampling.",
        target_per_class=int(target_per_class),
        effective_per_class=int(effective_per_class),
        eligible_hc=int(len(hc_eligible)),
        eligible_pd=int(len(pd_eligible)),
    )

print("\nEffective draw size (speaker-balanced):")
print("  requested target_per_class:", int(target_per_class))
print("  effective_per_class:", int(effective_per_class))
print("  effective_total speakers:", int(effective_total), f"(~{100.0*effective_total/max(1,total_speakers_full):.2f}% of FULL D2 train speaker denom)")

# Random draw (deterministic via RNG_SEED)
rng = np.random.default_rng(int(RNG_SEED))
sel_hc = sorted(rng.choice(hc_eligible, size=effective_per_class, replace=False).tolist())
sel_pd = sorted(rng.choice(pd_eligible, size=effective_per_class, replace=False).tolist())
selected_speakers = sorted(sel_hc + sel_pd)

# Pull all D2 train clips for the selected speakers
d2_sel = d2_train[d2_train[SPEAKER_ID_COL].astype(str).isin([str(x) for x in selected_speakers])].copy()

print("\nD2 speaker sampling (DRAW 3, zero overlap with enh1+enh2):")
print("  selected HC speakers:", int(len(sel_hc)))
print("  selected PD speakers:", int(len(sel_pd)))
print("  selected D2 rows (all clips for selected speakers):", int(len(d2_sel)))
print("  selected label counts:", d2_sel["label_num"].value_counts(dropna=False).to_dict())

# -------------------------
# 10) Fail-fast: confirm all source clips exist before planning copies
# -------------------------
def fail_fast_missing_paths(df: pd.DataFrame, name: str):
    missing_paths = []
    for p in tqdm(df["clip_path"].astype(str).tolist(), desc=f"Check {name} clip_path exists", dynamic_ncols=True):
        if not os.path.exists(p):
            missing_paths.append(p)
            if len(missing_paths) >= 25:
                break
    if missing_paths:
        raise FileNotFoundError(f"{name}: missing clip_path(s). Examples (up to 25): {missing_paths}")

if COPY_ALL_D7_TRAIN_CLIPS:
    fail_fast_missing_paths(d7_train, "D7 TRAIN")
fail_fast_missing_paths(d2_sel, "D2 TRAIN (selected speakers)")

# -------------------------
# 11) Build a copy plan and enforce no-overwrite rules
#    - If destination exists with same file size: skip
#    - If destination exists with different size: log ERROR and stop
# -------------------------
copy_plan = []

# A) D7 train clips copied into train_enh3 with the same file names
if COPY_ALL_D7_TRAIN_CLIPS:
    for i, row in d7_train.reset_index(drop=True).iterrows():
        src_path = Path(str(row["clip_path"]))
        dst_path = TRAIN_ENH3_DIR / src_path.name
        copy_plan.append({
            "src_path": str(src_path),
            "dst_path": str(dst_path),
            "origin": "D7_train_existing",
            "source_dataset": str(row.get("source_dataset", "D7")),
            "row_index": int(i),
        })

# B) D2 selected clips copied into train_enh3 with deterministic new names
d2_sel_reset = d2_sel.reset_index(drop=True).copy()
d2_sel_reset["_speaker_tok"] = d2_sel_reset["speaker_id"].map(lambda x: safe_token(x, 32, "NA"))
d2_sel_reset["_task_tok"] = d2_sel_reset["task"].map(lambda x: safe_token(x, 12, "0"))
d2_sel_reset["_clip_path_str"] = d2_sel_reset["clip_path"].astype(str)
d2_sel_reset = d2_sel_reset.sort_values(by=["_speaker_tok", "_task_tok", "_clip_path_str"]).reset_index(drop=True)

for j, row in d2_sel_reset.iterrows():
    src_path = Path(str(row["clip_path"]))
    hc_pd = hc_pd_from_label_num(row["label_num"])
    spk_tok = safe_token(row["speaker_id"], 32, "NA")
    task_tok = safe_token(row["task"], 12, "0")
    out_name = f"D7_D2add_train_enh3_{hc_pd}_{spk_tok}_{task_tok}_{j+1:06d}.wav"
    dst_path = TRAIN_ENH3_DIR / out_name
    copy_plan.append({
        "src_path": str(src_path),
        "dst_path": str(dst_path),
        "origin": "D2_train_selected",
        "source_dataset": "D2",
        "row_index": int(j),
    })

n_dest_exists_ok = 0
n_dest_exists_mismatch = 0
n_will_copy = 0

print("\nPreflight: destination existence and size checks (no overwrite)...")
for item in tqdm(copy_plan, desc="Preflight (dest checks)", dynamic_ncols=True):
    sp = Path(item["src_path"])
    dp = Path(item["dst_path"])
    require(sp.exists(), f"Source clip missing at preflight: {str(sp)}")
    if dp.exists():
        try:
            if dp.stat().st_size == sp.stat().st_size:
                n_dest_exists_ok += 1
            else:
                n_dest_exists_mismatch += 1
                add_warn(
                    "D7_TRAIN_ENH3", "ERROR", "DEST_EXISTS_SIZE_MISMATCH",
                    "Destination exists but file size differs from source",
                    src_path=str(sp), dest_path=str(dp),
                    src_size=int(sp.stat().st_size), dest_size=int(dp.stat().st_size),
                )
        except Exception as e:
            n_dest_exists_mismatch += 1
            add_warn(
                "D7_TRAIN_ENH3", "ERROR", "DEST_EXISTS_STAT_ERROR",
                "Failed to stat source/destination during preflight",
                src_path=str(sp), dest_path=str(dp), error=repr(e),
            )
    else:
        n_will_copy += 1

preflight_stats = {
    "total_planned_files": int(len(copy_plan)),
    "n_dest_exists_ok": int(n_dest_exists_ok),
    "n_dest_exists_mismatch": int(n_dest_exists_mismatch),
    "n_will_copy": int(n_will_copy),
    "warnings_by_level": count_by_level(warnings_rows),
}

print("\nPreflight summary:")
print("  Planned files:", int(len(copy_plan)))
print("  Destination exists (size OK):", n_dest_exists_ok)
print("  Destination exists (size mismatch/stat error):", n_dest_exists_mismatch)
print("  Will copy:", n_will_copy)
print("  Warnings by level:", preflight_stats["warnings_by_level"])

# -------------------------
# 12) Write config and early summary before copying
#    - Captures the draw, exclusions, and preflight checks even if the session stops later
# -------------------------
run_config = {
    "dataset": "D7",
    "mode": "train_enh3_builder",
    "created_utc": datetime.utcnow().isoformat(),
    "run_folder_logs": RUN_FOLDER_LOGS,
    "run_folder_cfg": RUN_FOLDER_CFG,
    "d7_out_root": str(D7_OUT_ROOT),
    "d2_out_root": str(D2_OUT_ROOT),
    "d7_manifest_all": str(D7_MANIFEST_ALL),
    "d2_manifest_all": str(D2_MANIFEST_ALL),
    "manifest_train_enh1": str(MANIFEST_TRAIN_ENH1),
    "manifest_train_enh2": str(MANIFEST_TRAIN_ENH2),
    "train_enh3_dir": str(TRAIN_ENH3_DIR),
    "manifest_train_enh3": str(MANIFEST_TRAIN_ENH3),
    "policy": {
        "definition": "Add ~10% of D2 TRAIN by speaker (Definition B) into D7 training clips folder",
        "pct_speakers": float(TEN_PCT),
        "denominator_policy": "FIXED_FULL_D2_TRAIN_SPEAKER_COUNT",
        "rounding_policy": ROUNDING_POLICY,
        "balance_policy": BALANCE_POLICY,
        "speaker_id_col_used": SPEAKER_ID_COL,
        "rng_seed": int(RNG_SEED),
        "copy_all_d7_train_clips": bool(COPY_ALL_D7_TRAIN_CLIPS),
        "label_note": LABEL_MAP_NOTE,
        "file_operation": "copy",
        "no_overwrite_rule": "skip if dest exists with matching size; error if size differs",
        "zero_overlap": {
            "enabled": True,
            "exclude_speakers_source": "manifest_train_enh1 + manifest_train_enh2 where source_dataset == 'D2'",
            "n_excluded_d2_speakers_union": int(len(prior_d2_speakers_union)),
            "excluded_d2_speakers_union": prior_d2_speakers_union,
        },
        "downsizing_rule_if_insufficient_eligible_pool": "downsize BOTH classes to min(eligible_hc, eligible_pd, target_per_class) to keep speaker-balanced",
    },
    "inputs": {
        "d7_train_rows": int(len(d7_train)),
        "d2_train_rows": int(len(d2_train)),
        "d2_train_total_speakers_full": int(total_speakers_full),
        "d2_train_speakers_hc_full": int(len(hc_full)),
        "d2_train_speakers_pd_full": int(len(pd_full)),
        "d2_train_speakers_hc_eligible": int(len(hc_eligible)),
        "d2_train_speakers_pd_eligible": int(len(pd_eligible)),
    },
    "selection": {
        "target_total_speakers_fixed": int(target_total),
        "target_per_class_fixed": int(target_per_class),
        "effective_total_speakers": int(effective_total),
        "effective_per_class": int(effective_per_class),
        "downsized": bool(downsized),
        "selected_speakers_hc": sel_hc,
        "selected_speakers_pd": sel_pd,
        "selected_speakers_all": selected_speakers,
        "selected_d2_rows": int(len(d2_sel)),
        "selected_d2_label_counts": d2_sel["label_num"].value_counts(dropna=False).to_dict(),
    },
    "preflight": preflight_stats,
}

early_summary = {
    "dataset": "D7",
    "created_utc": datetime.utcnow().isoformat(),
    "status": "PRECHECK_COMPLETE",
    "d7_out_root": str(D7_OUT_ROOT),
    "train_enh3_dir": str(TRAIN_ENH3_DIR),
    "preflight": preflight_stats,
    "selection": run_config["selection"],
}

atomic_write_csv(WARNINGS_CSV, pd.DataFrame(warnings_rows))
atomic_write_text(RUN_CONFIG_JSON, json.dumps(run_config, indent=2))
atomic_write_text(SUMMARY_JSON, json.dumps(early_summary, indent=2))

# Stop before copying if any ERROR was logged during preflight
fatal_pre = [w for w in warnings_rows if str(w.get("level", "")).upper() == "ERROR"]
if fatal_pre:
    raise RuntimeError(f"Preflight failed with {len(fatal_pre)} ERROR(s). See {str(WARNINGS_CSV)}")

# -------------------------
# 13) Copy files into train_enh3 (copy only, never overwrite)
# -------------------------
print("\nCopy stage: copying into train_enh3 (no overwrite)...")

copied = 0
skipped_exists = 0
copy_errors = 0

for item in tqdm(copy_plan, desc="Copying clips", dynamic_ncols=True):
    sp = Path(item["src_path"])
    dp = Path(item["dst_path"])

    if not sp.exists():
        copy_errors += 1
        add_warn("D7_TRAIN_ENH3", "ERROR", "SOURCE_CLIP_DISAPPEARED", "Source clip missing during copy stage", clip_path=str(sp))
        continue

    if dp.exists():
        skipped_exists += 1
        continue

    try:
        shutil.copy2(sp, dp)
        copied += 1
    except Exception as e:
        copy_errors += 1
        add_warn("D7_TRAIN_ENH3", "ERROR", "COPY_FAILED", "Failed to copy file", src_path=str(sp), dest_path=str(dp), error=repr(e))

copy_stats = {
    "copied": int(copied),
    "skipped_exists": int(skipped_exists),
    "copy_errors": int(copy_errors),
    "total_planned_files": int(len(copy_plan)),
    "warnings_by_level": count_by_level(warnings_rows),
}

atomic_write_csv(WARNINGS_CSV, pd.DataFrame(warnings_rows))

# Stop if any copy ERROR happened (keeps outputs consistent)
fatal_copy = [w for w in warnings_rows if str(w.get("level", "")).upper() == "ERROR"]
if fatal_copy:
    run_config["copy_stats"] = copy_stats
    atomic_write_text(RUN_CONFIG_JSON, json.dumps(run_config, indent=2))
    fail_summary = dict(early_summary)
    fail_summary["status"] = "FAILED_DURING_COPY"
    fail_summary["copy_stats"] = copy_stats
    atomic_write_text(SUMMARY_JSON, json.dumps(fail_summary, indent=2))
    raise RuntimeError(f"Copy failed with {len(fatal_copy)} ERROR(s). See {str(WARNINGS_CSV)}")

# -------------------------
# 14) Create manifest_train_enh3.csv
#    - Combines: D7 train clips + selected D2 clips (renamed)
#    - Updates clip_path to point to files inside train_enh3
# -------------------------
parts = []

# D7 portion (same clips, new split name)
d7_enh3 = d7_train[CANON_COLS].copy()
d7_enh3["split"] = "train_enh3"
d7_enh3["dataset"] = "D7"
d7_enh3["clip_path"] = d7_train["clip_path"].astype(str).map(lambda p: str(TRAIN_ENH3_DIR / Path(p).name))
d7_enh3["sample_id"] = d7_enh3["clip_path"].map(lambda p: Path(p).stem)
d7_enh3["source_dataset"] = d7_train.get("source_dataset", pd.Series(["D7"] * len(d7_train))).astype(str).tolist()
parts.append(d7_enh3)

# D2 add-on portion (new file names already planned above)
d2_enh3 = d2_sel_reset[CANON_COLS].copy()
d2_enh3["split"] = "train_enh3"
d2_enh3["dataset"] = "D7"

new_paths, new_ids = [], []
for j, row in d2_sel_reset.iterrows():
    hc_pd = hc_pd_from_label_num(row["label_num"])
    spk_tok = safe_token(row["speaker_id"], 32, "NA")
    task_tok = safe_token(row["task"], 12, "0")
    out_name = f"D7_D2add_train_enh3_{hc_pd}_{spk_tok}_{task_tok}_{j+1:06d}.wav"
    out_path = TRAIN_ENH3_DIR / out_name
    new_paths.append(str(out_path))
    new_ids.append(out_path.stem)

d2_enh3["clip_path"] = new_paths
d2_enh3["sample_id"] = new_ids
d2_enh3["source_dataset"] = "D2"
parts.append(d2_enh3)

train_enh3 = pd.concat(parts, axis=0, ignore_index=True)
train_enh3 = train_enh3[FINAL_COLS].copy()

# Ensure missing values are real NaN, not the literal string "NaN"
for c in FINAL_COLS:
    if train_enh3[c].dtype == object and (train_enh3[c] == "NaN").any():
        raise RuntimeError(f"Found literal string 'NaN' in column '{c}'.")

# Final file existence check for the new manifest
missing_enh3 = []
for p in tqdm(train_enh3["clip_path"].astype(str).tolist(), desc="Check TRAIN_ENH3 clip_path exists", dynamic_ncols=True):
    if not os.path.exists(p):
        missing_enh3.append(p)
        if len(missing_enh3) >= 25:
            break
require(len(missing_enh3) == 0, f"train_enh3 manifest points to missing files. Examples: {missing_enh3[:25]}")

# -------------------------
# 15) Write outputs and a final summary
# -------------------------
atomic_write_csv(MANIFEST_TRAIN_ENH3, train_enh3)

final_summary = {
    "dataset": "D7",
    "created_utc": datetime.utcnow().isoformat(),
    "status": "SUCCESS",
    "d7_out_root": str(D7_OUT_ROOT),
    "train_enh3_dir": str(TRAIN_ENH3_DIR),
    "manifest_train_enh3": str(MANIFEST_TRAIN_ENH3),
    "selection": run_config["selection"],
    "copy_stats": copy_stats,
    "counts_train_enh3": summarize_counts(train_enh3, split_name="train_enh3"),
}

atomic_write_text(SUMMARY_JSON, json.dumps(final_summary, indent=2))

# Update run_config with final outputs and copy stats
run_config["copy_stats"] = copy_stats
run_config["outputs"] = {
    "train_enh3_dir": str(TRAIN_ENH3_DIR),
    "manifest_train_enh3": str(MANIFEST_TRAIN_ENH3),
    "dataset_summary_json": str(SUMMARY_JSON),
    "preprocess_warnings_csv": str(WARNINGS_CSV),
    "run_config_json": str(RUN_CONFIG_JSON),
}
atomic_write_text(RUN_CONFIG_JSON, json.dumps(run_config, indent=2))

print("\n✅ D7 train_enh3 build complete (FIXED DENOMINATOR + ZERO OVERLAP enh1+enh2).")
print("- Train_enh3 folder:", str(TRAIN_ENH3_DIR))
print("- Manifest:", str(MANIFEST_TRAIN_ENH3))
print("- Summary:", str(SUMMARY_JSON))
print("- Warnings:", str(WARNINGS_CSV))
print("- Config:", str(RUN_CONFIG_JSON))
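
The speaker-draw sizing used by the builder cell above can be summarized in a short standalone sketch. It mirrors the fixed-denominator rule (ceil of 10% of all D2 train speakers, rounded down to an even number, minimum 2) and the balanced downsizing after exclusions. The function name `plan_balanced_draw`, the toy speaker IDs, and the use of the standard-library `random.Random` in place of the NumPy generator are assumptions made for illustration only, so the speakers it selects will not match the notebook's draw.

```python
import math
import random

def plan_balanced_draw(speaker_to_label, excluded, pct=0.10, seed=7777):
    """Hypothetical helper: size and draw a speaker-balanced HC/PD sample."""
    # The target is fixed by the FULL speaker count, before exclusions.
    target_total = math.ceil(pct * len(speaker_to_label))
    if target_total % 2 != 0:
        target_total -= 1  # make even by rounding down
    target_total = max(2, target_total)
    target_per_class = target_total // 2

    # Eligible pool: drop speakers already used by earlier draws.
    hc = sorted(s for s, y in speaker_to_label.items() if y == 0 and s not in excluded)
    pd_ = sorted(s for s, y in speaker_to_label.items() if y == 1 and s not in excluded)

    # Downsize both classes together so the draw stays balanced.
    per_class = min(len(hc), len(pd_), target_per_class)
    if per_class < 1:
        raise RuntimeError("Eligible pool lacks HC or PD speakers.")

    rng = random.Random(seed)  # assumption: stand-in for np.random.default_rng
    return {
        "target_per_class": target_per_class,
        "effective_per_class": per_class,
        "hc": sorted(rng.sample(hc, per_class)),
        "pd": sorted(rng.sample(pd_, per_class)),
    }

# Toy pool (not real data): 29 HC + 28 PD speakers, 4 HC used by earlier draws.
labels = {f"HC{i:02d}": 0 for i in range(29)}
labels.update({f"PD{i:02d}": 1 for i in range(28)})
plan = plan_balanced_draw(labels, excluded={"HC00", "HC01", "HC02", "HC03"})
print(plan["target_per_class"], plan["effective_per_class"], plan["hc"], plan["pd"])
```

Because the target is computed before exclusions, a shrinking eligible pool lowers only the effective draw size and never the reported denominator.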

The following cell runs **training and validation for the D7 enhanced model using the train_enh3 split**, without changing any data splits or reprocessing audio. Training data is read from `manifests/manifest_train_enh3.csv` (rows where `split == "train_enh3"`), and validation data is read from `manifests/manifest_all.csv` (rows where `split == "val"`). The cell checks that all required columns are present, confirms the dataset identifier is **D7**, and stops early if any referenced audio files are missing.

The model uses a **frozen Wav2Vec2 backbone** (`facebook/wav2vec2-base`) with **two small classification heads**, one for vowel clips and one for other speech clips. Each head includes a small LayerNorm and dropout block (dropout probability 0.2). Audio is loaded from `clip_path`, verified to be **16 kHz**, and padded for batching. For vowel clips, the attention mask excludes trailing near-silence so that quiet or padded regions do not affect learning. Other clips use the full attention mask.
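
The two masking behaviors can be illustrated with a minimal sample-level sketch. It is not the notebook's implementation: the helper name `build_attention_mask` is hypothetical, plain Python lists stand in for tensors, and the default threshold simply reuses the value of `TINY_THRESH` (1e-4) from the settings, on the assumption that it marks "near-silence".

```python
def build_attention_mask(samples, padded_len, trim_trailing_silence, thresh=1e-4):
    """Hypothetical helper: 1 marks samples the model may attend to, 0 marks ignored ones."""
    n_valid = len(samples)
    if trim_trailing_silence:
        # Index just past the last sample whose magnitude exceeds the threshold.
        last_loud = 0
        for i, x in enumerate(samples):
            if abs(x) > thresh:
                last_loud = i + 1
        # Keep the full clip if nothing exceeds the threshold (avoids an empty mask).
        if last_loud > 0:
            n_valid = last_loud
    # Everything after the valid region, including batch padding, is masked out.
    return [1] * n_valid + [0] * (padded_len - n_valid)

clip = [0.0, 0.2, -0.3, 0.00001, 0.0]  # toy waveform, padded to 8 samples in the batch
vowel_mask = build_attention_mask(clip, padded_len=8, trim_trailing_silence=True)
other_mask = build_attention_mask(clip, padded_len=8, trim_trailing_silence=False)
print(vowel_mask)
print(other_mask)
```

In this toy case the vowel mask ends after the third sample, while the non-vowel mask covers all five real samples and excludes only the padding.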

Before training starts, the cell locates the **most recent baseline D7 train+validation experiment** that was not trained on any `train_enh*` data and that contains `best_heads.pt` for all three seeds. For each seed (1337, 2024, 7777), the enhanced model heads are initialized from this baseline run. Training then updates only the head parameters using Adam (learning rate 1e-3) with gradient accumulation (micro-batches of 16 accumulated to an effective batch size of 64), for at most 10 epochs. After each epoch, the model is evaluated on the validation set, **validation AUROC** is recorded, and early stopping is applied if AUROC does not improve for 2 epochs (`PATIENCE = 2`).
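
The early-stopping rule can be replayed on a list of per-epoch validation AUROC values. This is a hedged sketch rather than the training loop itself: `select_best_epoch` is a hypothetical name, the AUROC values are made up, and it assumes that "no improvement" means not strictly exceeding the best value seen so far.

```python
def select_best_epoch(val_aurocs, patience=2, max_epochs=10):
    """Hypothetical helper: replay per-epoch validation AUROC values with early stopping."""
    best_auroc, best_epoch, stale = float("-inf"), None, 0
    epochs_run = 0
    for epoch, auroc in enumerate(val_aurocs[:max_epochs], start=1):
        epochs_run = epoch
        if auroc > best_auroc:
            # A real loop would save the head weights (best_heads.pt) at this point.
            best_auroc, best_epoch, stale = auroc, epoch, 0
        else:
            stale += 1
            if stale >= patience:
                break  # stop after `patience` epochs without improvement
    return best_epoch, best_auroc, epochs_run

history = [0.70, 0.74, 0.73, 0.72, 0.80]  # made-up values, not results
best_epoch, best_auroc, epochs_run = select_best_epoch(history)
print(best_epoch, best_auroc, epochs_run)
```

With these toy values training stops after epoch 4, so the later 0.80 is never reached. That is the usual trade-off of a short patience window.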

For each seed, the cell saves the **best head weights** (`best_heads.pt`), a `metrics.json` file, and validation plots, including a ROC curve and confusion matrices at threshold 0.5 and at an optimal threshold. The validation-optimal threshold is selected using **Youden’s J statistic**, based on the best-AUROC epoch.
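
Youden's J is J = sensitivity + specificity - 1, which equals TPR - FPR, and the selected threshold is the one that maximizes it. The sketch below is a pure-Python stand-in under stated assumptions: `youden_threshold` is a hypothetical name, the labels and scores are toy values, and a clip is called PD when its score is greater than or equal to the threshold. The notebook imports `roc_curve` from scikit-learn, which would normally supply the candidate thresholds.

```python
def youden_threshold(y_true, y_score):
    """Hypothetical helper: threshold maximizing J = TPR - FPR (predict PD when score >= threshold)."""
    n_pos = sum(1 for y in y_true if y == 1)
    n_neg = len(y_true) - n_pos
    if n_pos == 0 or n_neg == 0:
        raise ValueError("Both classes are required.")
    best_thr, best_j = None, float("-inf")
    # Every distinct score is a candidate threshold, scanned from high to low.
    for thr in sorted(set(y_score), reverse=True):
        tp = sum(1 for y, s in zip(y_true, y_score) if y == 1 and s >= thr)
        fp = sum(1 for y, s in zip(y_true, y_score) if y == 0 and s >= thr)
        j = tp / n_pos - fp / n_neg
        if j > best_j:  # ties keep the higher threshold
            best_thr, best_j = thr, j
    return best_thr, best_j

y_true = [0, 0, 1, 1, 0, 1, 1]                       # toy labels: 0=Healthy, 1=Parkinson
y_score = [0.10, 0.40, 0.35, 0.80, 0.20, 0.70, 0.60]  # toy scores, not results
thr, j = youden_threshold(y_true, y_score)
print(thr, j)
```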

After all three seeds complete, the cell writes an experiment-level `summary_trainval.json` that includes AUROC per seed, mean AUROC with a 95% t-based confidence interval (n=3), and the stored validation-optimal thresholds under `val_optimal_threshold.by_seed` and `val_optimal_threshold.mean_sd`. The experiment summary is also appended to the global `trainval_runs/history_index.jsonl`. The cell finishes by attempting to **unassign the Colab runtime** to release the GPU.
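
For three seeds, the t-based interval is mean ± t(0.975, df=2) · s / √3, where s is the sample standard deviation. The following sketch shows the arithmetic under stated assumptions: `t_ci95_n3` is a hypothetical name, the AUROC values are made up, and the critical value uses the closed-form Student-t quantile that exists for 2 degrees of freedom instead of a statistics library call.

```python
import math
import statistics

def t_ci95_n3(values):
    """Hypothetical helper: mean and 95% t-interval for exactly three values (df = 2)."""
    if len(values) != 3:
        raise ValueError("This sketch is specific to n = 3.")
    p = 0.975
    # Closed-form Student-t quantile for 2 degrees of freedom (about 4.303).
    t_crit = (2 * p - 1) / math.sqrt(2 * p * (1 - p))
    mean = statistics.mean(values)
    # statistics.stdev is the sample standard deviation (n - 1 in the denominator).
    half_width = t_crit * statistics.stdev(values) / math.sqrt(len(values))
    return mean, mean - half_width, mean + half_width

mean, lo, hi = t_ci95_n3([0.80, 0.82, 0.84])  # made-up AUROCs, not results
print(round(mean, 4), round(lo, 4), round(hi, 4))
```

With only two degrees of freedom the critical value is large, so the interval is wide even when the three seeds agree closely.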

In [None]:
# D7 trainEnh3: Train and validate task heads (baseline-initialized)
# Inputs: D7 train_enh3 manifest (train rows) and D7 full manifest (val rows)
# Outputs: Per-seed best heads, per-seed metrics and plots, experiment summary, history entry
#
# =========================
# Train + Val ONLY (CRASH-PROOF, WITH PROGRESS + HISTORY) — D7 ENHANCED (train_enh3)
# - Frozen Wav2Vec2 backbone
# - Two task heads with small LayerNorm + Dropout blocks (trainable heads only)
# - Uses: one manifest for train_enh3 split + one manifest for val split
# - Initializes heads from the most recent BASELINE D7 trainval run (not train_enh*)
# - Writes: a new exp_*/ folder with per-seed runs, plus summary_trainval.json and history_index.jsonl entry
# - Adds threshold metrics (thr=0.5 and val-opt thr) and stores val-opt thresholds in summary_trainval.json
# - Ends by unassigning the Colab runtime (L4) with messages
# =========================

import os, json, math, random, time, warnings
from pathlib import Path

import numpy as np
import pandas as pd
import soundfile as sf

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

from transformers import Wav2Vec2Model

from sklearn.metrics import (
    roc_auc_score, roc_curve,
    confusion_matrix, accuracy_score,
    precision_recall_fscore_support,
    matthews_corrcoef
)
import matplotlib.pyplot as plt
from tqdm.auto import tqdm

# -------------------------
# 0) Safety check: avoid importing a local file named torch.py or transformers.py
# -------------------------
if os.path.exists("/content/torch.py") or os.path.exists("/content/torch/__init__.py"):
    raise RuntimeError("Found local /content/torch.py or /content/torch/ that shadows PyTorch. Rename/remove it and restart runtime.")
if os.path.exists("/content/transformers.py") or os.path.exists("/content/transformers/__init__.py"):
    raise RuntimeError("Found local /content/transformers.py or /content/transformers/ that shadows Hugging Face Transformers. Rename/remove it and restart runtime.")

# -------------------------
# 1) Drive mount (safe if already mounted)
# -------------------------
try:
    from google.colab import drive  # type: ignore
    if not os.path.isdir("/content/drive/MyDrive"):
        drive.mount("/content/drive")
except Exception:
    pass

# -------------------------
# 2) Project root and manifest pointers
#    - DX_OUT_ROOT is the D7 preprocessed root (read manifests, write trainval runs)
# -------------------------
D7_OUT_ROOT_FALLBACK = "/content/drive/MyDrive/AI_PD_Project/Datasets/D7-Multilingual (D1_D4_D5v2_D6)/preprocessed_v1"
DX_OUT_ROOT = str(globals().get("DX_OUT_ROOT", D7_OUT_ROOT_FALLBACK))
globals()["DX_OUT_ROOT"] = DX_OUT_ROOT

MANIFEST_ALL = f"{DX_OUT_ROOT}/manifests/manifest_all.csv"

# Train_enh3 manifest and split label used for filtering rows
MANIFEST_TRAIN_ENH3 = f"{DX_OUT_ROOT}/manifests/manifest_train_enh3.csv"
TRAIN_SPLIT_NAME = "train_enh3"

# -------------------------
# 3) Run identity and output folder for this experiment
# -------------------------
EXPERIMENT_TAG = "frozen_LNDO_trainEnh3_initBaseline"
RUN_STAMP = time.strftime("%Y%m%d_%H%M%S")

TRAINVAL_ROOT = Path(DX_OUT_ROOT) / "trainval_runs"
EXP_ROOT = TRAINVAL_ROOT / f"exp_{EXPERIMENT_TAG}_{RUN_STAMP}"
EXP_ROOT.mkdir(parents=True, exist_ok=True)

# -------------------------
# 4) Training settings (kept aligned with the reference style)
# -------------------------
MAX_EPOCHS     = 10
EFFECTIVE_BS   = 64
PER_DEVICE_BS  = 16
GRAD_ACCUM     = max(1, EFFECTIVE_BS // PER_DEVICE_BS)

LR             = 1e-3
PATIENCE       = 2
SEEDS          = [1337, 2024, 7777]

BACKBONE_CKPT  = "facebook/wav2vec2-base"
SR_EXPECTED    = 16000
TINY_THRESH    = 1e-4

DROPOUT_P      = 0.2

NUM_WORKERS    = 0
PIN_MEMORY     = False

VOWEL_TASK_VALUE = "vowl"

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
try:
    torch.set_float32_matmul_precision("high")
except Exception:
    pass

# Reduce common noisy warnings (keeps the notebook output readable)
warnings.filterwarnings("ignore", message="None of the inputs have requires_grad=True. Gradients will be None")
warnings.filterwarnings("ignore", message="Passing `gradient_checkpointing` to a config initialization is deprecated")
warnings.filterwarnings("ignore", category=UserWarning, module="huggingface_hub")

print("DX_OUT_ROOT:", DX_OUT_ROOT)
print("MANIFEST_TRAIN_ENH3:", MANIFEST_TRAIN_ENH3)
print("TRAIN_SPLIT_NAME:", TRAIN_SPLIT_NAME)
print("MANIFEST_ALL (val source):", MANIFEST_ALL)
print("DEVICE:", DEVICE)
if DEVICE.type == "cuda":
    print("GPU:", torch.cuda.get_device_name(0))
print("PER_DEVICE_BS:", PER_DEVICE_BS, "| GRAD_ACCUM:", GRAD_ACCUM, "| EFFECTIVE_BS:", PER_DEVICE_BS * GRAD_ACCUM)
print("NUM_WORKERS:", NUM_WORKERS, "| PIN_MEMORY:", PIN_MEMORY)
print("EXPERIMENT_TAG:", EXPERIMENT_TAG, "| RUN_STAMP:", RUN_STAMP)
print("EXP_ROOT:", str(EXP_ROOT))

# -------------------------
# 5) Load train and val tables from manifests
#    - Train: manifest_train_enh3.csv filtered to split == train_enh3
#    - Val:   manifest_all.csv filtered to split == val
# -------------------------
if not os.path.exists(MANIFEST_TRAIN_ENH3):
    raise FileNotFoundError(
        "Missing manifest_train_enh3.csv at:\n"
        f"  {MANIFEST_TRAIN_ENH3}\n"
        "Run the train_enh3 builder first."
    )
if not os.path.exists(MANIFEST_ALL):
    raise FileNotFoundError(
        "Missing manifest_all.csv at:\n"
        f"  {MANIFEST_ALL}\n"
        "Confirm the D7 merge-builder wrote manifests/manifest_all.csv under DX_OUT_ROOT."
    )

m_train = pd.read_csv(MANIFEST_TRAIN_ENH3)
m_all   = pd.read_csv(MANIFEST_ALL)

# Minimum columns needed for training and evaluation
req_cols = {"split", "clip_path", "label_num", "task"}
for name, df in [("manifest_train_enh3", m_train), ("manifest_all", m_all)]:
    missing = [c for c in sorted(req_cols) if c not in df.columns]
    if missing:
        raise ValueError(f"{name} missing required columns: {missing}. Found: {list(df.columns)}")

# Split filtering (no resplitting here)
m_train = m_train[m_train["split"].astype(str) == TRAIN_SPLIT_NAME].copy()
m_val   = m_all[m_all["split"].astype(str) == "val"].copy()

if len(m_train) == 0:
    raise RuntimeError(f"After filtering manifest_train_enh3.csv to split=={TRAIN_SPLIT_NAME!r}, 0 rows remain.")
if len(m_val) == 0:
    raise RuntimeError("After filtering manifest_all.csv to split=='val', 0 rows remain.")

# Infer dataset_id from the val manifest (used for run folder naming; expected "D7")
if "dataset" in m_val.columns and m_val["dataset"].notna().any():
    dataset_id = str(m_val["dataset"].astype(str).value_counts(dropna=True).idxmax())
    m_val = m_val[m_val["dataset"].astype(str) == dataset_id].copy()
else:
    dataset_id = "DX"

# Guard: this cell is for D7 training
if dataset_id != "D7":
    raise RuntimeError(f"Dataset inferred from VAL manifest is {dataset_id!r}. Expected 'D7'. Check DX_OUT_ROOT/manifests/manifest_all.csv.")

# Keep a small, consistent set of columns (missing columns become NaN)
keep_cols = ["clip_path", "label_num", "task", "speaker_id", "duration_sec", "split"]
for df in [m_train, m_val]:
    for c in keep_cols:
        if c not in df.columns:
            df[c] = np.nan

m_train = m_train[keep_cols].copy()
m_val   = m_val[keep_cols].copy()

train_df = m_train.copy().reset_index(drop=True)
val_df   = m_val.copy().reset_index(drop=True)

print(f"\nDataset inferred (from VAL): {dataset_id}")
print(f"Train rows ({TRAIN_SPLIT_NAME}): {len(train_df)} | Val rows: {len(val_df)}")
print("Train label counts:", train_df["label_num"].value_counts(dropna=False).to_dict())
print("Val label counts:",   val_df["label_num"].value_counts(dropna=False).to_dict())

# -------------------------
# 6) Fail-fast file check: stop early if clip files are missing
# -------------------------
def _fail_fast_missing_paths(df: pd.DataFrame, name: str):
    missing_paths = []
    for p in tqdm(df["clip_path"].astype(str).tolist(), desc=f"Check {name} clip_path exists", dynamic_ncols=True):
        if not os.path.exists(p):
            missing_paths.append(p)
            if len(missing_paths) >= 10:
                break
    if missing_paths:
        raise FileNotFoundError(f"{name}: missing clip_path(s). Examples: {missing_paths}")

_fail_fast_missing_paths(train_df, "TRAIN_ENH3")
_fail_fast_missing_paths(val_df, "VAL")

# -------------------------
# 7) Task grouping used by the two-head model
# -------------------------
def _task_group(task_val) -> str:
    return "vowel" if str(task_val) == VOWEL_TASK_VALUE else "other"

train_df["task_group"] = train_df["task"].apply(_task_group)
val_df["task_group"]   = val_df["task"].apply(_task_group)

# -------------------------
# 8) Dataset and batch padding
#    - Loads waveforms from clip_path
#    - Builds attention masks so vowel padding is ignored
# -------------------------
class AudioManifestDataset(Dataset):
    """
    Reads one audio clip and creates an attention mask.

    Mask rule:
    - vowel clips: mask trailing near-silence so padded zeros do not affect training
    - other clips: keep all samples unmasked
    """
    def __init__(self, df: pd.DataFrame):
        self.df = df.reset_index(drop=True)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx: int):
        row = self.df.iloc[idx]
        clip_path = str(row["clip_path"])
        label = int(row["label_num"])
        task_group = str(row["task_group"])

        if not os.path.exists(clip_path):
            raise FileNotFoundError(f"Missing clip_path: {clip_path}")

        y, sr = sf.read(clip_path, always_2d=False)
        if isinstance(y, np.ndarray) and y.ndim > 1:
            y = np.mean(y, axis=1)
        y = np.asarray(y, dtype=np.float32)

        if int(sr) != SR_EXPECTED:
            raise ValueError(f"Sample rate mismatch (expected {SR_EXPECTED}): {clip_path} has sr={sr}")

        attn = np.ones((len(y),), dtype=np.int64)

        # Vowel clips often include zero padding; mask the padded tail
        if task_group == "vowel":
            # Indices of samples above the near-silence threshold (compared in float64)
            loud = np.flatnonzero(np.abs(y.astype(np.float64)) > TINY_THRESH)
            if loud.size > 0:
                # Zero the mask after the last non-silent sample
                attn[int(loud[-1]) + 1:] = 0
            # An entirely near-silent clip keeps the full mask

        return {
            "input_values": torch.from_numpy(y),                 # float32 [T]
            "attention_mask": torch.from_numpy(attn),            # int64   [T]
            "labels": torch.tensor(label, dtype=torch.long),     # int64   []
            "task_group": task_group,                            # str
        }

def collate_fn(batch):
    """Pads all waveforms and masks to the longest clip in the batch."""
    max_len = int(max(b["input_values"].numel() for b in batch))
    input_vals, attn_masks, labels, task_groups = [], [], [], []
    for b in batch:
        x = b["input_values"]
        a = b["attention_mask"]
        pad = max_len - x.numel()
        if pad > 0:
            x = torch.cat([x, torch.zeros(pad, dtype=x.dtype)], dim=0)
            a = torch.cat([a, torch.zeros(pad, dtype=a.dtype)], dim=0)
        input_vals.append(x)
        attn_masks.append(a)
        labels.append(b["labels"])
        task_groups.append(b["task_group"])
    return {
        "input_values": torch.stack(input_vals, dim=0),      # [B,T]
        "attention_mask": torch.stack(attn_masks, dim=0),    # [B,T]
        "labels": torch.stack(labels, dim=0),                # [B]
        "task_group": task_groups,                           # list[str]
    }

# -------------------------
# 9) Model: frozen backbone + two task-specific heads
# -------------------------
class Wav2Vec2TwoHeadClassifier(nn.Module):
    """
    Frozen Wav2Vec2 backbone with two small classification heads.
    Trainable parts:
    - pre_vowel, pre_other (LayerNorm + Dropout)
    - head_vowel, head_other (Linear -> 2 logits)
    """
    def __init__(self, ckpt: str, dropout_p: float):
        super().__init__()
        self.backbone = Wav2Vec2Model.from_pretrained(
            ckpt,
            use_safetensors=True,
            local_files_only=False
        )
        for p in self.backbone.parameters():
            p.requires_grad = False

        H = int(self.backbone.config.hidden_size)
        self.pre_vowel = nn.Sequential(nn.LayerNorm(H), nn.Dropout(float(dropout_p)))
        self.pre_other = nn.Sequential(nn.LayerNorm(H), nn.Dropout(float(dropout_p)))

        self.head_vowel = nn.Linear(H, 2)
        self.head_other = nn.Linear(H, 2)
        self.loss_fn = nn.CrossEntropyLoss()

    def masked_mean_pool(self, last_hidden, attn_mask_samples):
        # Convert sample-level mask to feature-level mask and pool only valid frames
        feat_mask = self.backbone._get_feature_vector_attention_mask(last_hidden.shape[1], attn_mask_samples)
        feat_mask = feat_mask.to(last_hidden.device).unsqueeze(-1).type_as(last_hidden)
        masked = last_hidden * feat_mask
        denom = feat_mask.sum(dim=1).clamp(min=1.0)
        return masked.sum(dim=1) / denom

    def forward(self, input_values, attention_mask, labels, task_group):
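        # Backbone forward is kept in no_grad to train heads only.
        # Note: model.train() also switches the frozen backbone to train mode, so any
        # dropout or masking configured in the checkpoint stays active during head training.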
        with torch.no_grad():
            out = self.backbone(input_values=input_values, attention_mask=attention_mask)
            last_hidden = out.last_hidden_state  # [B,T',H]

        pooled = self.masked_mean_pool(last_hidden, attention_mask).float()  # [B,H]

        z_v = self.pre_vowel(pooled)
        z_o = self.pre_other(pooled)

        logits_v = self.head_vowel(z_v)
        logits_o = self.head_other(z_o)

        # Choose which head to use per sample based on task_group
        is_vowel = torch.tensor([tg == "vowel" for tg in task_group], device=pooled.device, dtype=torch.bool)
        logits = logits_o.clone()
        if is_vowel.any():
            logits[is_vowel] = logits_v[is_vowel]

        loss = self.loss_fn(logits, labels)
        return loss, logits

# -------------------------
# 9.5) Baseline head initialization
# -------------------------
def load_heads_into_model(model: Wav2Vec2TwoHeadClassifier, best_heads_path: Path):
    if not best_heads_path.exists():
        raise FileNotFoundError(f"Missing best_heads.pt: {str(best_heads_path)}")
    state = torch.load(str(best_heads_path), map_location="cpu")
    model.pre_vowel.load_state_dict(state["pre_vowel"], strict=True)
    model.pre_other.load_state_dict(state["pre_other"], strict=True)
    model.head_vowel.load_state_dict(state["head_vowel"], strict=True)
    model.head_other.load_state_dict(state["head_other"], strict=True)
    return model

# -------------------------
# 10) Metrics and plotting helpers
# -------------------------
def compute_auc(y_true, y_prob):
    y_true = np.asarray(y_true, dtype=np.int64)
    y_prob = np.asarray(y_prob, dtype=np.float64)
    if len(np.unique(y_true)) < 2:
        return float("nan")
    return float(roc_auc_score(y_true, y_prob))

def compute_threshold_metrics(y_true, y_prob, thr=0.5):
    y_true = np.asarray(y_true, dtype=np.int64)
    y_prob = np.asarray(y_prob, dtype=np.float64)
    y_pred = (y_prob >= float(thr)).astype(np.int64)

    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
    tn, fp, fn, tp = (cm.ravel().tolist() if cm.size == 4 else [0, 0, 0, 0])

    # Common threshold metrics (kept together for easy reporting)
    acc = float(accuracy_score(y_true, y_pred))
    prec, rec, f1, _ = precision_recall_fscore_support(y_true, y_pred, average="binary", zero_division=0)
    mcc = float(matthews_corrcoef(y_true, y_pred)) if len(np.unique(y_true)) > 1 else float("nan")

    sensitivity = float(rec)
    specificity = float(tn / (tn + fp)) if (tn + fp) > 0 else float("nan")

    # Fisher exact test on the 2x2 confusion table (if available)
    p_value = float("nan")
    try:
        from scipy.stats import fisher_exact  # type: ignore
        _, p_value = fisher_exact([[tn, fp], [fn, tp]], alternative="two-sided")
        p_value = float(p_value)
    except Exception:
        p_value = float("nan")

    return {
        "threshold": float(thr),
        "tn": int(tn), "fp": int(fp), "fn": int(fn), "tp": int(tp),
        "accuracy": float(acc),
        "precision": float(prec),
        "recall": float(rec),
        "f1": float(f1),
        "sensitivity": float(sensitivity),
        "specificity": float(specificity),
        "mcc": float(mcc),
        "p_value_fisher": float(p_value),
    }

def compute_youden_j_threshold(y_true, y_prob):
    # Picks the ROC threshold that maximizes (TPR - FPR) on the val set
    y_true = np.asarray(y_true, dtype=np.int64)
    y_prob = np.asarray(y_prob, dtype=np.float64)
    if len(np.unique(y_true)) < 2:
        return float("nan"), {"youden_j": float("nan"), "tpr": float("nan"), "fpr": float("nan")}
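    fpr, tpr, thr = roc_curve(y_true, y_prob)
    j = tpr - fpr
    # thr[0] is a sentinel above every score (not an attainable cutoff), so skip it
    idx = int(np.argmax(j[1:])) + 1 if len(j) > 1 else 0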
    return float(thr[idx]), {"youden_j": float(j[idx]), "tpr": float(tpr[idx]), "fpr": float(fpr[idx])}

def save_roc_curve_png(y_true, y_prob, out_png):
    # Simple ROC plot for the best epoch (val)
    fpr, tpr, _ = roc_curve(y_true, y_prob)
    plt.figure()
    plt.plot(fpr, tpr)
    plt.plot([0, 1], [0, 1], linestyle="--")
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    plt.title("ROC Curve (Val)")
    plt.tight_layout()
    plt.savefig(out_png, dpi=150)
    plt.close()

def save_confusion_png(y_true, y_prob, out_png, thr=0.5):
    # Confusion matrix at a chosen threshold
    y_true = np.asarray(y_true, dtype=np.int64)
    y_prob = np.asarray(y_prob, dtype=np.float64)
    y_pred = (y_prob >= float(thr)).astype(int)
    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])

    plt.figure()
    plt.imshow(cm)
    plt.xticks([0, 1], ["Healthy(0)", "PD(1)"])
    plt.yticks([0, 1], ["Healthy(0)", "PD(1)"])
    plt.xlabel("Predicted")
    plt.ylabel("True")
    plt.title(f"Confusion Matrix (Val, thr={thr:.4f})")
    for (i, j), v in np.ndenumerate(cm):
        plt.text(j, i, str(v), ha="center", va="center")
    plt.tight_layout()
    plt.savefig(out_png, dpi=150)
    plt.close()

def mean_sd(vals):
    # Small helper for reporting across seeds
    vals = np.asarray(vals, dtype=np.float64)
    mu = float(np.nanmean(vals)) if np.any(~np.isnan(vals)) else float("nan")
    sd = float(np.nanstd(vals, ddof=1)) if np.sum(~np.isnan(vals)) > 1 else 0.0
    return mu, sd

# -------------------------
# 11) Reproducibility: set all RNG seeds
# -------------------------
def set_all_seeds(seed: int):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# -------------------------
# 11.5) Find the most recent BASELINE D7 trainval run for head initialization
#      - Uses summary_trainval.json to exclude any train_enh runs
#      - Requires all three best_heads.pt files
# -------------------------
BASELINE_TRAINVAL_ROOT = Path(DX_OUT_ROOT) / "trainval_runs"
if not BASELINE_TRAINVAL_ROOT.exists():
    raise FileNotFoundError(f"Missing trainval_runs folder under DX_OUT_ROOT: {str(BASELINE_TRAINVAL_ROOT)}")

exp_dirs = sorted([p for p in BASELINE_TRAINVAL_ROOT.glob("exp_*") if p.is_dir()],
                  key=lambda p: p.stat().st_mtime, reverse=True)
if not exp_dirs:
    raise FileNotFoundError(f"No exp_* folders found under: {str(BASELINE_TRAINVAL_ROOT)}")

train_dataset_id = "D7"

def _is_baseline_exp(exp_path: Path) -> bool:
    # Baseline = train_manifest_used does not contain train_enh markers
    summary_path = exp_path / "summary_trainval.json"
    if not summary_path.exists():
        return False
    try:
        with open(summary_path, "r", encoding="utf-8") as f:
            s = json.load(f)
        train_manifest_used = str(s.get("train_manifest_used", "")).lower()
        # Any manifest name containing "train_enh" marks an enhanced (non-baseline) run
        return "train_enh" not in train_manifest_used
    except Exception:
        return False

def _has_all_seeds(exp_path: Path, dataset_id: str, seeds: list) -> bool:
    # Need all per-seed head checkpoints to initialize consistently
    for s in seeds:
        p = exp_path / f"run_{dataset_id}_seed{s}" / "best_heads.pt"
        if not p.exists():
            return False
    return True

baseline_exp = None
for ed in exp_dirs:
    if ed.resolve() == EXP_ROOT.resolve():
        continue
    if _is_baseline_exp(ed) and _has_all_seeds(ed, train_dataset_id, SEEDS):
        baseline_exp = ed
        break

if baseline_exp is None:
    raise FileNotFoundError(
        "Could not find a BASELINE D7 trainval experiment with all 3 best_heads.pt files.\n"
        "Baseline guard excludes experiments whose summary_trainval.json shows train_manifest_used contains 'train_enh'.\n"
        f"Searched under: {str(BASELINE_TRAINVAL_ROOT)}/exp_*/run_D7_seedXXXX/best_heads.pt"
    )

baseline_summary_path = baseline_exp / "summary_trainval.json"
with open(baseline_summary_path, "r", encoding="utf-8") as f:
    baseline_summary = json.load(f)

print("\nBaseline initialization experiment selected:")
print(" ", str(baseline_exp))
print(" ", "summary:", str(baseline_summary_path))
print(" ", "train_manifest_used (baseline):", baseline_summary.get("train_manifest_used", "NA"))

# -------------------------
# 12) Single-seed train+val loop with early stopping
#      - Tracks best val AUROC
#      - Saves best heads and best-epoch val plots
#      - Computes val-opt threshold at the best AUROC epoch
# -------------------------
def run_trainval_once(seed: int):
    set_all_seeds(seed)

    run_dir = EXP_ROOT / f"run_{dataset_id}_seed{seed}"
    run_dir.mkdir(parents=True, exist_ok=True)

    train_ds = AudioManifestDataset(train_df)
    val_ds   = AudioManifestDataset(val_df)

    train_loader = DataLoader(
        train_ds,
        batch_size=PER_DEVICE_BS,
        shuffle=True,
        num_workers=NUM_WORKERS,
        pin_memory=PIN_MEMORY,
        collate_fn=collate_fn
    )
    val_loader = DataLoader(
        val_ds,
        batch_size=PER_DEVICE_BS,
        shuffle=False,
        num_workers=NUM_WORKERS,
        pin_memory=PIN_MEMORY,
        collate_fn=collate_fn
    )

    # Quick read of a few batches to catch loader issues early
    n_warmup = min(3, len(train_loader))  # small datasets may have fewer than 3 batches
    print(f"\n[seed={seed}] Warm-up: loading {n_warmup} train batches...")
    t0 = time.time()
    it = iter(train_loader)
    for i in range(n_warmup):
        _ = next(it)
        print(f"  loaded warmup batch {i+1}/{n_warmup}")
    print(f"[seed={seed}] Warm-up done in {time.time()-t0:.2f}s")

    model = Wav2Vec2TwoHeadClassifier(BACKBONE_CKPT, dropout_p=DROPOUT_P).to(DEVICE)

    # Head initialization from the selected baseline experiment (same seed)
    baseline_heads_path = baseline_exp / f"run_{train_dataset_id}_seed{seed}" / "best_heads.pt"
    print(f"[seed={seed}] Initializing heads from baseline:")
    print(" ", str(baseline_heads_path))
    model = load_heads_into_model(model, baseline_heads_path)
    model.train()

    # Only train the small head blocks (backbone stays frozen)
    trainable_params = (
        list(model.pre_vowel.parameters()) + list(model.pre_other.parameters()) +
        list(model.head_vowel.parameters()) + list(model.head_other.parameters())
    )
    opt = torch.optim.Adam(trainable_params, lr=LR)

    best_auc = -1.0
    best_epoch = -1
    no_improve = 0

    best_state = None
    best_val_probs = None
    best_val_true = None

    best_thr_youden = float("nan")
    best_thr_youden_details = None
    best_val_metrics_thr05 = None
    best_val_metrics_thr_opt = None

    for epoch in range(1, MAX_EPOCHS + 1):
        model.train()
        train_losses = []
        opt.zero_grad(set_to_none=True)

        pbar = tqdm(train_loader, desc=f"[seed={seed}] Train epoch {epoch}", dynamic_ncols=True)
        step = 0
        for batch in pbar:
            step += 1
            input_values = batch["input_values"].to(DEVICE, non_blocking=False)
            attention_mask = batch["attention_mask"].to(DEVICE, non_blocking=False)
            labels = batch["labels"].to(DEVICE, non_blocking=False)
            task_group = batch["task_group"]

            loss, _ = model(input_values, attention_mask, labels, task_group)
            loss = loss / GRAD_ACCUM
            loss.backward()

            train_losses.append(float(loss.detach().cpu().item()) * GRAD_ACCUM)

            if (step % GRAD_ACCUM) == 0:
                opt.step()
                opt.zero_grad(set_to_none=True)

        # Final step if the epoch ends mid accumulation
        if (step % GRAD_ACCUM) != 0:
            opt.step()
            opt.zero_grad(set_to_none=True)

        avg_train_loss = float(np.mean(train_losses)) if train_losses else float("nan")

        # Validation pass (collect probabilities for AUROC and thresholds)
        model.eval()
        all_probs, all_true = [], []
        vpbar = tqdm(val_loader, desc=f"[seed={seed}] Val epoch {epoch}", dynamic_ncols=True)
        with torch.inference_mode():
            for batch in vpbar:
                input_values = batch["input_values"].to(DEVICE, non_blocking=False)
                attention_mask = batch["attention_mask"].to(DEVICE, non_blocking=False)
                labels = batch["labels"].to(DEVICE, non_blocking=False)
                task_group = batch["task_group"]

                _, logits = model(input_values, attention_mask, labels, task_group)
                probs = torch.softmax(logits.float(), dim=-1)[:, 1].detach().cpu().numpy().astype(np.float64)
                all_probs.extend(probs.tolist())
                all_true.extend(labels.detach().cpu().numpy().astype(np.int64).tolist())

        val_auc = compute_auc(all_true, all_probs)
        print(f"seed={seed} | epoch {epoch:02d}/{MAX_EPOCHS} | train_loss={avg_train_loss:.5f} | val_AUROC={val_auc:.5f}")

        # Track the best val AUROC and snapshot head weights at that epoch
        improved = (not math.isnan(val_auc)) and (val_auc > best_auc + 1e-12)
        if improved:
            best_auc = float(val_auc)
            best_epoch = int(epoch)
            no_improve = 0

            best_state = {
                "pre_vowel": {k: v.detach().cpu().clone() for k, v in model.pre_vowel.state_dict().items()},
                "pre_other": {k: v.detach().cpu().clone() for k, v in model.pre_other.state_dict().items()},
                "head_vowel": {k: v.detach().cpu().clone() for k, v in model.head_vowel.state_dict().items()},
                "head_other": {k: v.detach().cpu().clone() for k, v in model.head_other.state_dict().items()},
            }

            best_val_probs = list(all_probs)
            best_val_true  = list(all_true)

            # Metrics at a fixed threshold and at the val-opt threshold
            best_val_metrics_thr05 = compute_threshold_metrics(best_val_true, best_val_probs, thr=0.5)

            thr_opt, details = compute_youden_j_threshold(best_val_true, best_val_probs)
            best_thr_youden = float(thr_opt)
            best_thr_youden_details = details
            best_val_metrics_thr_opt = compute_threshold_metrics(best_val_true, best_val_probs, thr=best_thr_youden)
        else:
            no_improve += 1

        # Early stop if AUROC does not improve for PATIENCE epochs
        if no_improve >= PATIENCE:
            break

    if best_state is None or best_val_probs is None or best_val_true is None:
        raise RuntimeError(
            "No best epoch captured. Validation AUROC may be NaN due to single-class validation split "
            "or earlier failures."
        )

    # Save best heads for this seed (used later by test-only cells)
    best_heads_path = run_dir / "best_heads.pt"
    torch.save(best_state, str(best_heads_path))

    # Save best-epoch plots for quick inspection
    roc_png = run_dir / "roc_curve.png"
    cm_png_05 = run_dir / "confusion_matrix_thr0p5.png"
    cm_png_opt = run_dir / "confusion_matrix_thr_opt.png"

    save_roc_curve_png(np.asarray(best_val_true, dtype=np.int64), np.asarray(best_val_probs, dtype=np.float64), str(roc_png))
    save_confusion_png(np.asarray(best_val_true, dtype=np.int64), np.asarray(best_val_probs, dtype=np.float64), str(cm_png_05), thr=0.5)
    if not np.isnan(best_thr_youden):
        save_confusion_png(np.asarray(best_val_true, dtype=np.int64), np.asarray(best_val_probs, dtype=np.float64), str(cm_png_opt), thr=float(best_thr_youden))

    # Per-seed metrics payload saved next to the model heads
    metrics = {
        "dataset": dataset_id,
        "seed": int(seed),
        "best_val_auroc": float(best_auc),
        "best_epoch": int(best_epoch),

        "train_manifest_used": MANIFEST_TRAIN_ENH3,
        "val_manifest_used": MANIFEST_ALL,

        "init_heads": {
            "mode": "baseline_best_heads",
            "baseline_exp_used": str(baseline_exp),
            "baseline_summary_path": str(baseline_summary_path),
            "baseline_best_heads_path": str(baseline_heads_path),
        },

        "n_train": int(len(train_df)),
        "n_val": int(len(val_df)),
        "label_counts_train": train_df["label_num"].value_counts(dropna=False).to_dict(),
        "label_counts_val": val_df["label_num"].value_counts(dropna=False).to_dict(),

        "experiment_tag": EXPERIMENT_TAG,
        "run_stamp": RUN_STAMP,

        "dropout_p": float(DROPOUT_P),
        "lr": float(LR),
        "effective_batch_size": int(PER_DEVICE_BS * GRAD_ACCUM),
        "per_device_batch_size": int(PER_DEVICE_BS),
        "grad_accum_steps": int(GRAD_ACCUM),

        "backbone_ckpt": BACKBONE_CKPT,

        "val_opt_threshold_method": "Youden J (maximize TPR - FPR on VAL ROC curve)",
        "val_opt_threshold": float(best_thr_youden),
        "val_opt_details": best_thr_youden_details,

        "thr_metrics_val_thr0p5": best_val_metrics_thr05,
        "thr_metrics_val_thr_opt": best_val_metrics_thr_opt,

        "artifacts": {
            "roc_curve_png": str(roc_png),
            "confusion_thr0p5_png": str(cm_png_05),
            "confusion_thr_opt_png": str(cm_png_opt),
            "best_heads_pt": str(best_heads_path),
        },
    }

    metrics_path = run_dir / "metrics.json"
    with open(metrics_path, "w", encoding="utf-8") as f:
        json.dump(metrics, f, indent=2)

    print(f"[seed={seed}] VAL-opt threshold (Youden J): {float(best_thr_youden):.6f}")
    print(f"[seed={seed}] WROTE:")
    print(" ", str(metrics_path))
    print(" ", str(roc_png))
    print(" ", str(cm_png_05))
    print(" ", str(cm_png_opt))
    print(" ", str(best_heads_path))

    return {
        "seed": int(seed),
        "best_val_auroc": float(best_auc),
        "best_epoch": int(best_epoch),
        "val_opt_thr": float(best_thr_youden),
        "run_dir": str(run_dir),
        "seed_metrics": metrics,
    }

# -------------------------
# 13) Run all seeds and write the experiment summary
#      - AUROC summary: mean and 95% CI across seeds
#      - Threshold summary: mean and SD across seeds
# -------------------------
results = []
for seed in SEEDS:
    results.append(run_trainval_once(seed))

aucs = [r["best_val_auroc"] for r in results]
thr_vals = [r["val_opt_thr"] for r in results]

t_crit = 4.302652729911275  # two-sided 95% t critical value for df=2 (valid only for n=3 seeds)
n = len(aucs)
mean_auc = float(np.mean(aucs))
std_auc = float(np.std(aucs, ddof=1)) if n > 1 else 0.0
half_width = float(t_crit * (std_auc / math.sqrt(n))) if n > 1 else 0.0
ci95 = [float(mean_auc - half_width), float(mean_auc + half_width)]

thr_mean, thr_sd = mean_sd(thr_vals)

print("\nAUROC by seed:")
for r in results:
    print(f"  seed {r['seed']}: {r['best_val_auroc']:.6f}")
print(f"\nMean AUROC: {mean_auc:.6f}")
print(f"95% CI (t, n=3): [{ci95[0]:.6f}, {ci95[1]:.6f}]")

print("\nVAL-opt thresholds (Youden J) by seed:")
for r in results:
    print(f"  seed {r['seed']}: {r['val_opt_thr']:.6f}")
print(f"  mean ± SD: {thr_mean:.6f} ± {thr_sd:.6f}")

val_optimal_threshold_obj = {
    "method": "Youden J (maximize TPR - FPR on VAL ROC curve)",
    "by_seed": {str(r["seed"]): float(r["val_opt_thr"]) for r in results},
    "mean_sd": {"mean": float(thr_mean), "sd": float(thr_sd)},
}

# Compact experiment summary used by downstream test-only code
exp_summary = {
    "dataset": dataset_id,
    "dx_out_root": DX_OUT_ROOT,

    "train_manifest_used": MANIFEST_TRAIN_ENH3,
    "val_manifest_used": MANIFEST_ALL,

    "init_heads": {
        "mode": "baseline_best_heads",
        "baseline_exp_used": str(baseline_exp),
        "baseline_summary_path": str(baseline_summary_path),
        "baseline_best_heads_by_seed": {
            str(s): str(baseline_exp / f"run_{train_dataset_id}_seed{s}" / "best_heads.pt") for s in SEEDS
        },
    },

    "experiment_tag": EXPERIMENT_TAG,
    "run_stamp": RUN_STAMP,
    "exp_root": str(EXP_ROOT),
    "run_dirs": [r["run_dir"] for r in results],
    "seeds": SEEDS,

    "aurocs": [float(x) for x in aucs],
    "mean_auroc": float(mean_auc),
    "t_crit_df2_95": float(t_crit),
    "half_width_95_ci": float(half_width),
    "ci95": ci95,

    "n_train": int(len(train_df)),
    "n_val": int(len(val_df)),
    "label_counts_train": train_df["label_num"].value_counts(dropna=False).to_dict(),
    "label_counts_val": val_df["label_num"].value_counts(dropna=False).to_dict(),

    "effective_batch_size": int(PER_DEVICE_BS * GRAD_ACCUM),
    "per_device_batch_size": int(PER_DEVICE_BS),
    "grad_accum_steps": int(GRAD_ACCUM),

    "backbone_ckpt": BACKBONE_CKPT,
    "dropout_p": float(DROPOUT_P),
    "lr": float(LR),

    "val_optimal_threshold": val_optimal_threshold_obj,
    "per_seed_metrics": [r["seed_metrics"] for r in results],
}

summary_path = EXP_ROOT / "summary_trainval.json"
with open(summary_path, "w", encoding="utf-8") as f:
    json.dump(exp_summary, f, indent=2)

# Append-only history log for quick run tracking
history_path = TRAINVAL_ROOT / "history_index.jsonl"
TRAINVAL_ROOT.mkdir(parents=True, exist_ok=True)
with open(history_path, "a", encoding="utf-8") as f:
    f.write(json.dumps(exp_summary) + "\n")

print("\nWROTE per-experiment summary:", str(summary_path))
print("APPENDED global history index:", str(history_path))
print("\nOpen this folder to access artifacts:", str(EXP_ROOT))

# -------------------------
# 14) Runtime shutdown (stop L4)
# -------------------------
print("\nAll done. Unassigning the runtime to stop the L4 instance...")
try:
    from google.colab import runtime  # type: ignore
    print("Calling runtime.unassign() now.")
    runtime.unassign()
except Exception as e:
    print("Could not unassign runtime automatically. Error:", repr(e))
    print("Manual stop: Runtime -> Disconnect and delete runtime.")

The following cell runs a **test-only evaluation** of the **D7 enhanced model trained with train_enh3** on the **D2 test split**. No training is done in this cell. Its only data input is `D2/manifests/manifest_all.csv`: it keeps rows where `split == "test"` and confirms that the manifest truly belongs to **dataset D2** (the cell stops immediately if it does not). It also checks that every file listed in `clip_path` exists before continuing, so problems are caught early.

The cell then prepares the test data in a consistent way. Each clip is assigned to a simple **task group**: `"vowel"` when `task == "vowl"`, otherwise `"other"`. Sex values are standardized for fairness analysis using D2’s exact labels: `"male"` is mapped to `"M"`, `"female"` to `"F"`, and anything else to `"UNK"`. The resulting sex counts are printed. During audio loading, each clip is verified to be sampled at **16 kHz**. For vowel clips, the attention mask is trimmed so trailing near-silent padding does not affect the model. Other clips use the full attention mask.
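The grouping and sex-normalization rules above can be sketched as follows. This is a minimal illustration rather than the cell's actual code: the helper names `task_group` and `normalize_sex`, the `SEX_MAP` constant, and the toy rows are assumptions, and only the mapping rules themselves come from the description.

```python
from collections import Counter

VOWEL_TASK_VALUE = "vowl"                 # task code that marks sustained-vowel clips
SEX_MAP = {"male": "M", "female": "F"}    # D2's exact labels (assumed case-sensitive)


def task_group(task_val) -> str:
    """Return "vowel" for sustained-vowel clips and "other" for everything else."""
    return "vowel" if str(task_val) == VOWEL_TASK_VALUE else "other"


def normalize_sex(sex_val) -> str:
    """Map "male"/"female" to "M"/"F"; any other value (including missing) is "UNK"."""
    if isinstance(sex_val, str):
        return SEX_MAP.get(sex_val, "UNK")
    return "UNK"


# Hypothetical manifest rows, for illustration only
rows = [
    {"task": "vowl", "sex": "male"},
    {"task": "read", "sex": "female"},
    {"task": "vowl", "sex": None},
]
for row in rows:
    row["task_group"] = task_group(row["task"])
    row["sex_norm"] = normalize_sex(row["sex"])

sex_counts = Counter(row["sex_norm"] for row in rows)
print(dict(sex_counts))
```

Matching on the exact label strings means an unexpected spelling surfaces as `"UNK"` in the printed counts instead of being silently merged into a sex group.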

Next, the cell locates the **most recent D7 train+validation experiment** (by folder modification time) under `D7/trainval_runs/`: an `exp_*` folder whose name contains `"trainEnh3"` (case-insensitive) and that includes `summary_trainval.json` plus a `best_heads.pt` file for **all three seeds** (1337, 2024, 7777). After selecting this experiment, it re-checks that all expected `best_heads.pt` files exist as a second safety step. From the selected experiment’s `summary_trainval.json`, it reads a **single global decision threshold** from `val_optimal_threshold.mean_sd.mean`. This same threshold is used for **all three seeds**. If the value is missing, cannot be read as a number, or is NaN, the cell falls back to **0.5** and records this fallback.
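The threshold lookup and its fallback can be illustrated with a small sketch. The helper name `read_global_threshold` and the wording of the note are assumptions made for this example; only the key path `val_optimal_threshold.mean_sd.mean` and the `0.5` fallback come from the description above.

```python
import math


def read_global_threshold(summary, fallback=0.5):
    """Return (threshold, note) from a trainval summary dictionary.

    Reads summary["val_optimal_threshold"]["mean_sd"]["mean"]. If any level is
    missing, the value is not numeric, or it is NaN, the fallback is returned
    together with a short note so the substitution can be recorded.
    """
    try:
        mean_sd = ((summary or {}).get("val_optimal_threshold") or {}).get("mean_sd") or {}
        value = float(mean_sd.get("mean"))
    except (AttributeError, TypeError, ValueError):
        value = float("nan")
    if math.isnan(value):
        return fallback, "fallback used: mean val-opt threshold missing or invalid"
    return value, None


# Toy summaries for illustration only (not values from any real run).
thr_ok, note_ok = read_global_threshold({"val_optimal_threshold": {"mean_sd": {"mean": 0.42, "sd": 0.03}}})
thr_missing, note_missing = read_global_threshold({})
thr_bad, note_bad = read_global_threshold({"val_optimal_threshold": {"mean_sd": {"mean": None}}})
```

Returning the note alongside the value makes it straightforward to store the fallback reason in the per-seed and overall summaries.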

For each seed, the cell rebuilds the model using a frozen Wav2Vec2 backbone and the same two-head structure used during training. It loads the trained head weights for that seed and runs inference on the full D2 test set. For each seed, it computes **test AUROC** and saves a ROC curve, an overall confusion matrix at the global threshold, and separate confusion matrices for sex `"M"` and `"F"` when those groups are present. It also writes a per-seed `predictions.csv` containing the clip path, true label, predicted PD probability, normalized sex, speaker ID, task group, seed, and run metadata including the threshold used.
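As a rough illustration of the two-head structure, the sketch below selects the head by task group for a single clip and converts its two logits into a PD probability. It is a scalar toy version under stated assumptions: the function name `pd_probability` is hypothetical, class index 1 is taken to be PD (as in the cell), and the real model operates on batched tensors.

```python
import math


def pd_probability(logits_vowel, logits_other, task_group):
    """Pick the head by task group and return the softmax probability of class 1 (PD).

    Each logits argument is a (healthy_logit, pd_logit) pair for one clip.
    """
    healthy_logit, pd_logit = logits_vowel if task_group == "vowel" else logits_other
    # Two-class softmax, shifted by the max logit for numerical stability.
    m = max(healthy_logit, pd_logit)
    e_healthy = math.exp(healthy_logit - m)
    e_pd = math.exp(pd_logit - m)
    return e_pd / (e_healthy + e_pd)


# Toy logits for illustration only (not model outputs).
p_vowel = pd_probability((0.0, 0.0), (2.0, -2.0), "vowel")
p_other = pd_probability((0.0, 0.0), (2.0, -2.0), "other")
```

Any task group other than `"vowel"` is routed to the second head, which matches the `"vowel"`/`"other"` grouping applied to the manifest.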

After all three seeds are evaluated, the cell computes and saves overall results. It reports **mean test AUROC with a 95% t-based confidence interval (n=3)** and **threshold-based metrics** (accuracy, precision, sensitivity/recall, specificity, F1, MCC, Fisher p-value) as mean ± SD across seeds. It also computes the main fairness metric **H3** at the same global threshold: **ΔFNR = FNR(F) − FNR(M)**, where FNR is calculated only on true Parkinson’s cases (FN divided by FN + TP). It reports FNR for M and F, the signed ΔFNR, and |ΔFNR| as mean ± SD across seeds, with NaN used when a sex group has no PD cases.
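The two summary statistics above can be sketched with the standard library alone. This is an illustrative sketch, not the cell's code: the function names are hypothetical, the inputs are toy values rather than study results, and the confidence interval relies on the closed-form Student-t quantile for 2 degrees of freedom, which applies only because n = 3.

```python
import math
import statistics


def mean_ci95_three_seeds(values):
    """Mean and 95% t-based confidence interval for exactly three values."""
    if len(values) != 3:
        raise ValueError("this sketch hard-codes df = 2 (n = 3)")
    p = 0.975
    # Closed-form Student-t quantile for df = 2: t_p = (2p - 1) / sqrt(2p(1 - p)).
    t_crit = (2 * p - 1) / math.sqrt(2 * p * (1 - p))  # about 4.303
    mean = statistics.fmean(values)
    half_width = t_crit * statistics.stdev(values) / math.sqrt(3)
    return mean, mean - half_width, mean + half_width


def fnr(y_true, y_pred):
    """False-negative rate over true PD cases only; NaN when there are none."""
    tp = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 1)
    fn = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 0)
    return fn / (fn + tp) if (fn + tp) > 0 else float("nan")


def delta_fnr(y_true, y_pred, sex):
    """Signed gap FNR(F) - FNR(M); NaN if either group has no PD cases."""
    def group_fnr(g):
        idx = [i for i, s in enumerate(sex) if s == g]
        return fnr([y_true[i] for i in idx], [y_pred[i] for i in idx])
    return group_fnr("F") - group_fnr("M")


# Toy inputs for illustration only.
ci_mean, ci_lo, ci_hi = mean_ci95_three_seeds([0.80, 0.82, 0.84])
gap = delta_fnr(
    y_true=[1, 1, 1, 1, 0, 0],
    y_pred=[1, 0, 1, 1, 0, 1],
    sex=["F", "F", "M", "M", "F", "M"],
)
```

In the toy example one of the two female PD clips is missed and neither male PD clip is, so the signed gap is positive, meaning a higher miss rate for the female group. With only three seeds the t critical value is large, so the interval is wide even when the seed-to-seed spread is small.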

All outputs are saved under a new run folder in `D7/monolingual_test_runs/`, named using the selected train+val experiment tag plus a timestamp. For traceability, the cell also updates lightweight pointers and history files: it writes `summary_latest.json` and `last_run_pointer.json` (backing up any existing versions), appends a one-line entry to `history_index.jsonl`, and writes a stable per-tag pointer in `monolingual_test_runs/run_<TAG_SAFE>/tag_run_pointer.json`. Finally, for consistency with the preprocessing layout, it writes a small `run_config.json` and summary log files under `D7/config/...` and `D7/logs/...`, and clears the GPU cache at the end.
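The naming and bookkeeping pattern can be sketched as below, using a temporary directory so nothing touches the real project folders. The helper names, the pointer fields, and the placement of `history_index.jsonl` directly under the test root are assumptions for illustration; only the file names and the `run_<tag>__<timestamp>` layout come from the description above.

```python
import json
import tempfile
from pathlib import Path


def sanitize_tag(tag):
    """Replace anything other than letters, digits, '-' and '_' with '_'."""
    cleaned = "".join(ch if (ch.isalnum() or ch in "-_") else "_" for ch in str(tag).strip())
    return cleaned.strip("_") or "tag"


def record_run(test_root, exp_tag, stamp):
    """Create a run folder, refresh the pointer file (with backup), append to history."""
    test_root = Path(test_root)
    run_root = test_root / f"run_{sanitize_tag(exp_tag)}__{stamp}"
    run_root.mkdir(parents=True, exist_ok=True)

    # Hypothetical pointer contents; the real cell may store more fields.
    entry = {"exp_tag": exp_tag, "run_root": str(run_root), "stamp": stamp}

    pointer = test_root / "last_run_pointer.json"
    if pointer.exists():
        # Keep the previous pointer under a timestamped backup name.
        pointer.rename(pointer.with_name(pointer.name + f".bak_{stamp}"))
    pointer.write_text(json.dumps(entry, indent=2), encoding="utf-8")

    # Append-only history: one JSON object per line.
    with open(test_root / "history_index.jsonl", "a", encoding="utf-8") as f:
        f.write(json.dumps(entry) + "\n")
    return run_root


with tempfile.TemporaryDirectory() as tmp:
    record_run(tmp, "exp trainEnh3 (demo)", "20250101_000000")
    record_run(tmp, "exp trainEnh3 (demo)", "20250101_000100")
    history_lines = (Path(tmp) / "history_index.jsonl").read_text(encoding="utf-8").splitlines()
    backups = sorted(p.name for p in Path(tmp).glob("last_run_pointer.json.bak_*"))
    run_dirs = sorted(p.name for p in Path(tmp).iterdir() if p.is_dir())
```

Because the history file is append-only and the pointer is backed up before being replaced, an earlier run can still be traced after later runs have updated the pointer.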

In [None]:
# =========================
# Enhanced D7 Heads on D2 Test (trainEnh3, fixed threshold)
# =========================
# Purpose: evaluate saved D7 trainEnh3 heads on the D2 test split using one shared threshold for all seeds.
# Inputs: D2 manifest (test split) and the most recent matching D7 trainval experiment (heads + summary).
# Outputs: per-seed metrics and predictions, plots, and a run summary plus pointer files for quick lookup.
# =========================

import os, json, math, random, time, warnings
from pathlib import Path
from datetime import datetime

import numpy as np
import pandas as pd
import soundfile as sf

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

from transformers import Wav2Vec2Model

from sklearn.metrics import roc_auc_score, roc_curve, confusion_matrix, matthews_corrcoef
import matplotlib.pyplot as plt
from tqdm.auto import tqdm

# -------------------------
# 0) Import safety check
# -------------------------
# Stops early if local files would override PyTorch or Transformers imports.
if os.path.exists("/content/torch.py") or os.path.exists("/content/torch/__init__.py"):
    raise RuntimeError("Found local /content/torch.py or /content/torch/ that shadows PyTorch. Rename/remove it and restart runtime.")
if os.path.exists("/content/transformers.py") or os.path.exists("/content/transformers/__init__.py"):
    raise RuntimeError("Found local /content/transformers.py or /content/transformers/ that shadows Hugging Face Transformers. Rename/remove it and restart runtime.")

# -------------------------
# 1) Drive mount (best-effort)
# -------------------------
# Mounts Google Drive if running in Colab and not already mounted.
try:
    from google.colab import drive  # type: ignore
    if not os.path.isdir("/content/drive/MyDrive"):
        drive.mount("/content/drive")
except Exception:
    pass

# -------------------------
# 2) Resolve run roots and manifest path
# -------------------------
# Uses notebook globals if present; otherwise falls back to the defaults below.
# DX_OUT_ROOT follows the D7 root because outputs are written next to D7 trainval runs.
D7_OUT_ROOT_FALLBACK = "/content/drive/MyDrive/AI_PD_Project/Datasets/D7-Multilingual (D1_D4_D5v2_D6)/preprocessed_v1"
D2_OUT_ROOT_FALLBACK = "/content/drive/MyDrive/AI_PD_Project/Datasets/D2-Slovak (EWA-DB)/EWA-DB/preprocessed_v1"

D7_OUT_ROOT = str(globals().get("D7_OUT_ROOT", D7_OUT_ROOT_FALLBACK))
D2_OUT_ROOT = str(globals().get("D2_OUT_ROOT", D2_OUT_ROOT_FALLBACK))

D2_MANIFEST_ALL = f"{D2_OUT_ROOT}/manifests/manifest_all.csv"

DX_OUT_ROOT = D7_OUT_ROOT
globals()["DX_OUT_ROOT"] = DX_OUT_ROOT
globals()["D7_OUT_ROOT"] = D7_OUT_ROOT
globals()["D2_OUT_ROOT"] = D2_OUT_ROOT

# Run stamp used only for naming backups and this test run folder.
RUN_STAMP = time.strftime("%Y%m%d_%H%M%S")

# Creates a timestamped backup before overwriting files intended to be replaced.
def _backup_if_exists(p: Path):
    if p.exists():
        bak = p.with_suffix(p.suffix + f".bak_{RUN_STAMP}")
        try:
            p.rename(bak)
            print(f"BACKUP: {str(p)} -> {str(bak)}")
        except Exception as e:
            raise RuntimeError(f"Could not backup existing file before overwrite: {str(p)}. Error: {repr(e)}")

# Converts an arbitrary string into a filesystem-safe tag used in folder names.
def _sanitize_tag(s: str) -> str:
    s = str(s).strip()
    out = []
    for ch in s:
        if ch.isalnum() or ch in ["-", "_"]:
            out.append(ch)
        else:
            out.append("_")
    out = "".join(out).strip("_")
    return out if out else "tag"

# -------------------------
# 3) Fixed evaluation settings
# -------------------------
# Matches the training backbone and common inference settings.
SEEDS          = [1337, 2024, 7777]
BACKBONE_CKPT  = "facebook/wav2vec2-base"
SR_EXPECTED    = 16000
TINY_THRESH    = 1e-4

# Batch sizing values are printed for consistency tracking (inference uses PER_DEVICE_BS).
EFFECTIVE_BS   = 64
PER_DEVICE_BS  = 16
GRAD_ACCUM     = max(1, EFFECTIVE_BS // PER_DEVICE_BS)

DROPOUT_P      = 0.2

NUM_WORKERS    = 0
PIN_MEMORY     = False

USE_AMP        = True

# Trainval experiment tag used to pick the most recent enhanced run.
ENH_TAG = "trainEnh3"
REQUIRED_EXP_SUBSTRING = ENH_TAG  # case-insensitive
ENH_TAG_SAFE = _sanitize_tag(ENH_TAG)  # kept for traceability; not used for folder naming in this cell

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
try:
    torch.set_float32_matmul_precision("high")
except Exception:
    pass

# Reduces common non-critical warnings in notebook output.
warnings.filterwarnings("ignore", message="None of the inputs have requires_grad=True. Gradients will be None")
warnings.filterwarnings("ignore", message="Passing `gradient_checkpointing` to a config initialization is deprecated")
warnings.filterwarnings("ignore", category=UserWarning, module="huggingface_hub")

# Quick run context printout.
print("D7_OUT_ROOT:", D7_OUT_ROOT)
print("D2_OUT_ROOT:", D2_OUT_ROOT)
print("D2_MANIFEST_ALL:", D2_MANIFEST_ALL)
print("DEVICE:", DEVICE)
if DEVICE.type == "cuda":
    print("GPU:", torch.cuda.get_device_name(0))
print("PER_DEVICE_BS:", PER_DEVICE_BS, "| EFFECTIVE_BS:", PER_DEVICE_BS * GRAD_ACCUM)
print("NUM_WORKERS:", NUM_WORKERS, "| PIN_MEMORY:", PIN_MEMORY)
print("USE_AMP:", bool(USE_AMP and DEVICE.type == "cuda"))
print("Enhanced exp required substring (case-insensitive):", REQUIRED_EXP_SUBSTRING)

# -------------------------
# 4) Load D2 manifest and build the test table
# -------------------------
# Reads the manifest, checks required columns, confirms it is D2, then filters to split == "test".
if not os.path.exists(D2_MANIFEST_ALL):
    raise FileNotFoundError(f"Missing D2 manifest_all.csv: {D2_MANIFEST_ALL}")

m_all = pd.read_csv(D2_MANIFEST_ALL)

# Required columns for evaluation and fairness reporting.
req_cols = {"split", "clip_path", "label_num", "task", "sex", "age"}
missing = [c for c in sorted(req_cols) if c not in m_all.columns]
if missing:
    raise ValueError(f"D2 manifest missing required columns: {missing}. Found: {list(m_all.columns)}")

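# Uses the most common non-null dataset id (when present) and filters to it.
# Nulls are dropped before astype(str); otherwise NaN becomes the string "nan" and could win the count.
if "dataset" in m_all.columns and m_all["dataset"].notna().any():
    d2_dataset_id = str(m_all["dataset"].dropna().astype(str).value_counts().idxmax())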
    m_all = m_all[m_all["dataset"].astype(str) == d2_dataset_id].copy()
else:
    d2_dataset_id = "DX"

# --------- GUARD A ----------
# Hard stop if the manifest does not identify as D2.
if d2_dataset_id != "D2":
    raise RuntimeError(
        f"Expected D2 dataset_id=='D2' but got {d2_dataset_id!r}. "
        "This usually means D2_OUT_ROOT is wrong or the manifest is not D2. "
        f"D2_OUT_ROOT={D2_OUT_ROOT}"
    )

# Keeps a stable set of columns used later (fills missing ones with NaN).
keep_cols = ["clip_path", "label_num", "task", "speaker_id", "sex", "age", "duration_sec", "split"]
for c in keep_cols:
    if c not in m_all.columns:
        m_all[c] = np.nan
m_all = m_all[keep_cols].copy()

# IMPORTANT: test split only.
test_df = m_all[m_all["split"].astype(str) == "test"].reset_index(drop=True)

print(f"\nD2 dataset inferred: {d2_dataset_id}")
print(f"D2 TEST rows: {len(test_df)}")
print("D2 TEST label counts:", test_df["label_num"].value_counts(dropna=False).to_dict())
print("D2 TEST sex counts (raw):", test_df["sex"].value_counts(dropna=False).to_dict())

if len(test_df) == 0:
    raise RuntimeError("After filtering to split=='test', D2 manifest has 0 rows.")

# -------------------------
# 5) Fail fast: confirm audio files exist
# -------------------------
# Checks file existence early to avoid long runs that fail late.
def _fail_fast_missing_paths(df: pd.DataFrame, name: str):
    missing_paths = []
    for p in tqdm(df["clip_path"].astype(str).tolist(), desc=f"Check {name} clip_path exists", dynamic_ncols=True):
        if not os.path.exists(p):
            missing_paths.append(p)
            if len(missing_paths) >= 10:
                break
    if missing_paths:
        raise FileNotFoundError(f"{name}: missing clip_path(s). Examples: {missing_paths}")

_fail_fast_missing_paths(test_df, "D2 TEST")

# -------------------------
# 6) Task grouping for the two-head model
# -------------------------
# Converts raw task into the head selector used during inference.
def _task_group(task_val) -> str:
    return "vowel" if str(task_val) == "vowl" else "other"

test_df["task_group"] = test_df["task"].apply(_task_group)

# -------------------------
# 6.5) Sex normalization for fairness reporting
# -------------------------
# D2 uses exact strings "male"/"female"; anything else becomes "UNK".
def normalize_sex_d2_case_sensitive(val) -> str:
    if pd.isna(val):
        return "UNK"
    if val == "male":
        return "M"
    if val == "female":
        return "F"
    return "UNK"

test_df["sex_norm"] = test_df["sex"].apply(normalize_sex_d2_case_sensitive)
print("D2 TEST sex counts (normalized):", test_df["sex_norm"].value_counts(dropna=False).to_dict())
if (test_df["sex_norm"] == "UNK").any():
    print("NOTE: Some D2 'sex' values were not exactly 'male'/'female' and were mapped to 'UNK'.")

# -------------------------
# 7) Dataset and batch padding
# -------------------------
# Loads audio from clip_path and builds an attention mask for pooling.
# For vowel clips, the mask ignores trailing near-silence so padding does not affect pooling.
class AudioManifestDataset(Dataset):
    def __init__(self, df: pd.DataFrame):
        self.df = df.reset_index(drop=True)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx: int):
        row = self.df.iloc[idx]
        clip_path = str(row["clip_path"])
        label = int(row["label_num"])
        task_group = str(row["task_group"])
        sex_norm = str(row["sex_norm"])
        speaker_id = row["speaker_id"] if "speaker_id" in row.index else np.nan

        if not os.path.exists(clip_path):
            raise FileNotFoundError(f"Missing clip_path: {clip_path}")

        # Reads audio and converts to mono float32.
        y, sr = sf.read(clip_path, always_2d=False)
        if isinstance(y, np.ndarray) and y.ndim > 1:
            y = np.mean(y, axis=1)
        y = np.asarray(y, dtype=np.float32)

        # Enforces a single sample rate for consistent model input.
        if int(sr) != SR_EXPECTED:
            raise ValueError(f"Sample rate mismatch (expected {SR_EXPECTED}): {clip_path} has sr={sr}")

        # Attention mask: 1 = keep, 0 = ignore during pooling.
        attn = np.ones((len(y),), dtype=np.int64)
        if task_group == "vowel":
            # Finds the last non-tiny sample and masks the tail after it.
            # If no sample exceeds TINY_THRESH, the full mask is kept.
            nz = np.flatnonzero(np.abs(y.astype(np.float64)) > TINY_THRESH)
            if nz.size > 0:
                attn[int(nz[-1]) + 1:] = 0
        # Non-vowel clips keep the full mask created above.

        return {
            "input_values": torch.from_numpy(y),
            "attention_mask": torch.from_numpy(attn),
            "labels": torch.tensor(label, dtype=torch.long),
            "task_group": task_group,
            "sex_norm": sex_norm,
            "clip_path": clip_path,
            "speaker_id": speaker_id,
        }

# Pads variable-length audio in a batch to the longest clip.
def collate_fn(batch):
    max_len = int(max(b["input_values"].numel() for b in batch))
    input_vals, attn_masks, labels = [], [], []
    task_groups, sex_norms, clip_paths, speaker_ids = [], [], [], []
    for b in batch:
        x = b["input_values"]
        a = b["attention_mask"]
        pad = max_len - x.numel()
        if pad > 0:
            x = torch.cat([x, torch.zeros(pad, dtype=x.dtype)], dim=0)
            a = torch.cat([a, torch.zeros(pad, dtype=a.dtype)], dim=0)
        input_vals.append(x)
        attn_masks.append(a)
        labels.append(b["labels"])
        task_groups.append(b["task_group"])
        sex_norms.append(b["sex_norm"])
        clip_paths.append(b["clip_path"])
        speaker_ids.append(b["speaker_id"])
    return {
        "input_values": torch.stack(input_vals, dim=0),
        "attention_mask": torch.stack(attn_masks, dim=0),
        "labels": torch.stack(labels, dim=0),
        "task_group": task_groups,
        "sex_norm": sex_norms,
        "clip_path": clip_paths,
        "speaker_id": speaker_ids,
    }

# -------------------------
# 8) Two-head classifier (frozen backbone)
# -------------------------
# Uses a frozen Wav2Vec2 backbone and switches heads by task_group.
class Wav2Vec2TwoHeadClassifier(nn.Module):
    def __init__(self, ckpt: str, dropout_p: float):
        super().__init__()
        self.backbone = Wav2Vec2Model.from_pretrained(ckpt, use_safetensors=True)
        for p in self.backbone.parameters():
            p.requires_grad = False

        H = int(self.backbone.config.hidden_size)
        self.pre_vowel = nn.Sequential(nn.LayerNorm(H), nn.Dropout(float(dropout_p)))
        self.pre_other = nn.Sequential(nn.LayerNorm(H), nn.Dropout(float(dropout_p)))
        self.head_vowel = nn.Linear(H, 2)
        self.head_other = nn.Linear(H, 2)

    # Mean-pools frame features using the attention mask.
    def masked_mean_pool(self, last_hidden, attn_mask_samples):
        feat_mask = self.backbone._get_feature_vector_attention_mask(last_hidden.shape[1], attn_mask_samples)
        feat_mask = feat_mask.to(last_hidden.device).unsqueeze(-1).type_as(last_hidden)
        masked = last_hidden * feat_mask
        denom = feat_mask.sum(dim=1).clamp(min=1.0)
        return masked.sum(dim=1) / denom

    # Runs heads in fp32 for stable probabilities even when AMP is enabled.
    def _heads_fp32(self, x_any: torch.Tensor, head: nn.Module) -> torch.Tensor:
        x = x_any.float()
        if x.is_cuda:
            with torch.amp.autocast(device_type="cuda", enabled=False):
                return head(x)
        return head(x)

    # Returns logits for the PD vs healthy classes, using the correct head per item.
    def forward_logits(self, input_values, attention_mask, task_group):
        with torch.no_grad():
            out = self.backbone(input_values=input_values, attention_mask=attention_mask)
            last_hidden = out.last_hidden_state

        pooled = self.masked_mean_pool(last_hidden, attention_mask)

        z_v = self.pre_vowel(pooled.float())
        z_o = self.pre_other(pooled.float())

        logits_v = self._heads_fp32(z_v, self.head_vowel)
        logits_o = self._heads_fp32(z_o, self.head_other)

        is_vowel = torch.tensor([tg == "vowel" for tg in task_group], device=pooled.device, dtype=torch.bool)
        logits = logits_o.clone()
        if is_vowel.any():
            logits[is_vowel] = logits_v[is_vowel]
        return logits

# -------------------------
# 9) Metrics and plotting helpers
# -------------------------
# AUROC is threshold-independent; other metrics use a chosen threshold.
def compute_auc(y_true, y_prob):
    y_true = np.asarray(y_true, dtype=np.int64)
    y_prob = np.asarray(y_prob, dtype=np.float64)
    if len(np.unique(y_true)) < 2:
        return float("nan")
    return float(roc_auc_score(y_true, y_prob))

def compute_threshold_metrics(y_true, y_prob, thr=0.5):
    y_true = np.asarray(y_true, dtype=np.int64)
    y_prob = np.asarray(y_prob, dtype=np.float64)
    y_pred = (y_prob >= float(thr)).astype(np.int64)

    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
    TN, FP, FN, TP = int(cm[0, 0]), int(cm[0, 1]), int(cm[1, 0]), int(cm[1, 1])

    eps = 1e-12
    acc = (TP + TN) / max(1, (TP + TN + FP + FN))
    prec = TP / (TP + FP + eps)
    rec = TP / (TP + FN + eps)     # sensitivity
    f1 = 2 * prec * rec / (prec + rec + eps)
    spec = TN / (TN + FP + eps)

    # MCC can be undefined if predictions collapse to a single class.
    try:
        mcc = float(matthews_corrcoef(y_true, y_pred)) if len(np.unique(y_pred)) > 1 else float("nan")
    except Exception:
        mcc = float("nan")

    # Fisher exact test on the 2x2 table (best-effort if scipy is available).
    pval = float("nan")
    try:
        from scipy.stats import fisher_exact  # type: ignore
        _, pval = fisher_exact([[TN, FP], [FN, TP]], alternative="two-sided")
        pval = float(pval)
    except Exception:
        pval = float("nan")

    return {
        "threshold": float(thr),
        "tn": TN, "fp": FP, "fn": FN, "tp": TP,
        "accuracy": float(acc),
        "precision": float(prec),
        "recall": float(rec),
        "f1": float(f1),
        "sensitivity": float(rec),
        "specificity": float(spec),
        "mcc": float(mcc),
        "p_value_fisher_two_sided": float(pval),
    }

# Saves a simple ROC curve plot for a single seed run.
def save_roc_curve_png(y_true, y_prob, out_png, title_suffix="Test"):
    fpr, tpr, _ = roc_curve(y_true, y_prob)
    plt.figure()
    plt.plot(fpr, tpr)
    plt.plot([0, 1], [0, 1], linestyle="--")
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    plt.title(f"ROC Curve ({title_suffix})")
    plt.tight_layout()
    plt.savefig(out_png, dpi=150)
    plt.close()

# Saves a labeled 2x2 confusion matrix image.
def save_confusion_png(y_true, y_prob, out_png, thr=0.5, title_suffix="Test"):
    y_true = np.asarray(y_true, dtype=np.int64)
    y_prob = np.asarray(y_prob, dtype=np.float64)
    y_pred = (y_prob >= float(thr)).astype(int)
    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])

    plt.figure()
    plt.imshow(cm)
    plt.xticks([0, 1], ["Healthy(0)", "PD(1)"])
    plt.yticks([0, 1], ["Healthy(0)", "PD(1)"])
    plt.xlabel("Predicted")
    plt.ylabel("True")
    plt.title(f"Confusion Matrix ({title_suffix}, thr={thr:.4f})")
    for (i, j), v in np.ndenumerate(cm):
        plt.text(j, i, str(v), ha="center", va="center")
    plt.tight_layout()
    plt.savefig(out_png, dpi=150)
    plt.close()

# -------------------------
# 9.5) Fairness (H3): signed ΔFNR = FNR(F) - FNR(M)
# -------------------------
# Computes FNR per sex group using PD-only cases, then returns signed and absolute gaps.
def compute_fnr_by_group_signed(y_true, y_prob, groups, thr=0.5):
    y_true = np.asarray(y_true, dtype=np.int64)
    y_prob = np.asarray(y_prob, dtype=np.float64)
    groups = np.asarray(list(groups), dtype=object)
    y_pred = (y_prob >= float(thr)).astype(np.int64)

    out = {}
    for g in sorted(set(groups.tolist())):
        mask_g = (groups == g)
        if int(mask_g.sum()) == 0:
            continue

        pos_mask = mask_g & (y_true == 1)
        n_pos = int(pos_mask.sum())
        if n_pos == 0:
            out[g] = {"n_total": int(mask_g.sum()), "n_pos": 0, "tp": 0, "fn": 0, "fnr": float("nan")}
            continue

        tp = int(((y_pred == 1) & pos_mask).sum())
        fn = int(((y_pred == 0) & pos_mask).sum())
        fnr = float(fn / max(1, (fn + tp)))
        out[g] = {"n_total": int(mask_g.sum()), "n_pos": int(n_pos), "tp": int(tp), "fn": int(fn), "fnr": float(fnr)}

    fnr_m = out.get("M", {}).get("fnr", float("nan"))
    fnr_f = out.get("F", {}).get("fnr", float("nan"))
    if (not np.isnan(fnr_m)) and (not np.isnan(fnr_f)):
        delta_signed = float(fnr_f - fnr_m)
        delta_abs = float(abs(delta_signed))
    else:
        delta_signed = float("nan")
        delta_abs = float("nan")

    return out, delta_signed, delta_abs

# Confusion counts are saved both overall and by sex for quick diagnostics.
def compute_confusion_counts(y_true, y_prob, thr=0.5):
    y_true = np.asarray(y_true, dtype=np.int64)
    y_prob = np.asarray(y_prob, dtype=np.float64)
    y_pred = (y_prob >= float(thr)).astype(np.int64)
    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
    return {"TN": int(cm[0, 0]), "FP": int(cm[0, 1]), "FN": int(cm[1, 0]), "TP": int(cm[1, 1])}

def compute_confusion_by_group(y_true, y_prob, groups, thr=0.5):
    y_true = np.asarray(y_true, dtype=np.int64)
    y_prob = np.asarray(y_prob, dtype=np.float64)
    groups = np.asarray(list(groups), dtype=object)
    out = {}
    for g in sorted(set(groups.tolist())):
        mask = (groups == g)
        if int(mask.sum()) == 0:
            continue
        out[g] = {"n": int(mask.sum()), "confusion": compute_confusion_counts(y_true[mask], y_prob[mask], thr=thr)}
    return out

# -------------------------
# 10) Seed control
# -------------------------
# Fixes random states to make runs repeatable per seed.
def set_all_seeds(seed: int):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# -------------------------
# 11) Select the most recent matching D7 enhanced trainval experiment
# -------------------------
# Looks under trainval_runs/exp_* for the newest folder whose name contains trainEnh3 and has all three heads + summary.
TRAINVAL_ROOT = Path(D7_OUT_ROOT) / "trainval_runs"
if not TRAINVAL_ROOT.exists():
    raise FileNotFoundError(f"Missing trainval_runs folder under D7_OUT_ROOT: {str(TRAINVAL_ROOT)}")

exp_dirs = sorted([p for p in TRAINVAL_ROOT.glob("exp_*") if p.is_dir()], key=lambda p: p.stat().st_mtime, reverse=True)
if not exp_dirs:
    raise FileNotFoundError(f"No exp_* folders found under: {str(TRAINVAL_ROOT)}")

train_dataset_id = "D7"  # expected naming from trainval code (run_D7_seedXXXX)

def _is_enhanced_exp_dir(exp_path: Path, required_substring: str) -> bool:
    return (required_substring.lower() in exp_path.name.lower())

def _has_all_seeds_and_summary(exp_path: Path, dataset_id: str, seeds: list) -> bool:
    summary_path = exp_path / "summary_trainval.json"
    if not summary_path.exists():
        return False
    for s in seeds:
        p = exp_path / f"run_{dataset_id}_seed{s}" / "best_heads.pt"
        if not p.exists():
            return False
    return True

chosen_exp = None
for ed in exp_dirs:
    if not _is_enhanced_exp_dir(ed, REQUIRED_EXP_SUBSTRING):
        continue
    if _has_all_seeds_and_summary(ed, train_dataset_id, SEEDS):
        chosen_exp = ed
        break

if chosen_exp is None:
    sample = exp_dirs[0]
    raise FileNotFoundError(
        "Could not find a recent D7 *ENHANCED* trainval experiment folder that:\n"
        f"  (1) contains substring '{REQUIRED_EXP_SUBSTRING}' (case-insensitive) in the exp folder name, and\n"
        "  (2) contains all 3 best_heads.pt files + summary_trainval.json.\n\n"
        f"Most recent exp checked (for reference): {str(sample)}"
    )

# Uses the full exp folder name as the tag so outputs are easy to trace back.
FULL_TRAINVAL_EXP_TAG = chosen_exp.name
TAG_SAFE = _sanitize_tag(FULL_TRAINVAL_EXP_TAG)
RUN_PARENT_DIRNAME = f"run_{TAG_SAFE}__{RUN_STAMP}"

# Output locations: one folder per run stamp, plus a stable per-tag folder for pointers.
TEST_ROOT = Path(D7_OUT_ROOT) / "monolingual_test_runs"
RUN_ROOT  = TEST_ROOT / RUN_PARENT_DIRNAME
RUN_ROOT.mkdir(parents=True, exist_ok=True)

TAG_ROOT = TEST_ROOT / f"run_{TAG_SAFE}"
TAG_ROOT.mkdir(parents=True, exist_ok=True)

# Builder-aligned config and logs folders (kept stable per tag).
cfg_dir  = Path(D7_OUT_ROOT) / "config" / f"D7_{TAG_SAFE}_on_D2_Test"
logs_dir = Path(D7_OUT_ROOT) / "logs"   / f"D7_{TAG_SAFE}_on_D2_Test"
cfg_dir.mkdir(parents=True, exist_ok=True)
logs_dir.mkdir(parents=True, exist_ok=True)

RUN_CONFIG_PATH       = cfg_dir / "run_config.json"
WARNINGS_CSV_PATH     = logs_dir / "preprocess_warnings.csv"
DATASET_SUMMARY_PATH  = logs_dir / "dataset_summary.json"

print("\nUsing D7 ENHANCED Train+Val experiment folder:")
print(" ", str(chosen_exp))
print("FULL_TRAINVAL_EXP_TAG:", FULL_TRAINVAL_EXP_TAG)
print("ENH_TAG:", ENH_TAG)
print("RUN_ROOT:", str(RUN_ROOT))
print("cfg_dir:", str(cfg_dir))
print("logs_dir:", str(logs_dir))

# --------- GUARD B ----------
# Re-checks that all seed head files exist under the chosen experiment.
for s in SEEDS:
    p = chosen_exp / f"run_{train_dataset_id}_seed{s}" / "best_heads.pt"
    if not p.exists():
        raise RuntimeError(f"Trainval artifact missing after choosing exp. Missing: {str(p)}")

# Loads the trainval summary used to fetch the shared mean threshold.
summary_trainval_path = chosen_exp / "summary_trainval.json"
with open(summary_trainval_path, "r", encoding="utf-8") as f:
    d7_trainval_summary = json.load(f)

# -------------------------
# 11.5) Shared threshold from trainval summary
# -------------------------
# Uses the mean validation-optimal threshold for all test seeds; falls back to 0.5 if missing.
val_opt_obj = (d7_trainval_summary or {}).get("val_optimal_threshold", {}) or {}
thr_mean_sd = (val_opt_obj.get("mean_sd", {}) or {})

def _get_mean_val_opt_threshold() -> float:
    try:
        return float(thr_mean_sd.get("mean", float("nan")))
    except Exception:
        return float("nan")

THR_MEAN_FROM_TRAINVAL = _get_mean_val_opt_threshold()

if np.isnan(THR_MEAN_FROM_TRAINVAL):
    THR_USED_GLOBAL = 0.5
    THR_GLOBAL_NOTE = (
        "Mean val-opt threshold was missing/NaN in D7 enhanced summary_trainval.json. "
        "Fallback: THR_USED_GLOBAL=0.5 for ALL seeds."
    )
else:
    THR_USED_GLOBAL = float(THR_MEAN_FROM_TRAINVAL)
    THR_GLOBAL_NOTE = None

print("\nVAL-opt threshold selection for TEST (GLOBAL):")
print("  Source: summary_trainval.json -> val_optimal_threshold.mean_sd.mean")
print(f"  THR_USED_GLOBAL: {THR_USED_GLOBAL:.6f}")
if THR_GLOBAL_NOTE is not None:
    print("  NOTE:", THR_GLOBAL_NOTE)

# -------------------------
# 13) D2 test DataLoader
# -------------------------
# Builds the DataLoader used for all seeds.
test_ds = AudioManifestDataset(test_df)
test_loader = DataLoader(
    test_ds,
    batch_size=PER_DEVICE_BS,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=PIN_MEMORY,
    collate_fn=collate_fn,
)

# -------------------------
# 14) Warm-up read (sanity check)
# -------------------------
# Loads a few batches to catch DataLoader issues early.
print("\nWarm-up: loading up to 3 D2 TEST batches...")
t0 = time.time()

def _warmup(loader, name):
    nb = len(loader)
    wb = min(3, nb)
    if wb == 0:
        raise RuntimeError(f"{name} DataLoader has 0 batches. Check df length and PER_DEVICE_BS.")
    it = iter(loader)
    for i in range(wb):
        _ = next(it)
        print(f"  loaded warmup {name} batch {i+1}/{wb}")

_warmup(test_loader, "D2 TEST")
print(f"Warm-up done in {time.time()-t0:.2f}s")

# -------------------------
# 15) Load saved heads into the model
# -------------------------
# Loads only the trained head and pre-head blocks; backbone stays pretrained and frozen.
def load_heads_into_model(model: Wav2Vec2TwoHeadClassifier, best_heads_path: Path):
    if not best_heads_path.exists():
        raise FileNotFoundError(f"Missing best_heads.pt: {str(best_heads_path)}")
    state = torch.load(str(best_heads_path), map_location="cpu")
    model.pre_vowel.load_state_dict(state["pre_vowel"], strict=True)
    model.pre_other.load_state_dict(state["pre_other"], strict=True)
    model.head_vowel.load_state_dict(state["head_vowel"], strict=True)
    model.head_other.load_state_dict(state["head_other"], strict=True)
    return model

# -------------------------
# 16) Inference helper with metadata
# -------------------------
# Runs inference and returns probabilities plus fields needed for predictions.csv.
def _infer_probs_with_meta(loader, model, desc):
    use_amp = bool(USE_AMP and DEVICE.type == "cuda")
    all_probs, all_true, all_sex = [], [], []
    all_clip, all_spk, all_task = [], [], []

    pbar = tqdm(loader, desc=desc, dynamic_ncols=True)
    with torch.inference_mode():
        for batch in pbar:
            input_values = batch["input_values"].to(DEVICE, non_blocking=False)
            attention_mask = batch["attention_mask"].to(DEVICE, non_blocking=False)
            labels = batch["labels"].to(DEVICE, non_blocking=False)
            task_group = batch["task_group"]
            sex_norm = batch["sex_norm"]
            clip_paths = batch["clip_path"]
            speaker_ids = batch["speaker_id"]

            with torch.amp.autocast(device_type="cuda", enabled=use_amp):
                logits = model.forward_logits(input_values, attention_mask, task_group)

            # Class-1 probability is treated as PD probability.
            probs = torch.softmax(logits.float(), dim=-1)[:, 1].detach().cpu().numpy().astype(np.float64)
            all_probs.extend(probs.tolist())
            all_true.extend(labels.detach().cpu().numpy().astype(np.int64).tolist())
            all_sex.extend(list(sex_norm))
            all_task.extend(list(task_group))
            all_clip.extend(list(clip_paths))
            all_spk.extend([("" if (x is None or (isinstance(x, float) and np.isnan(x))) else str(x)) for x in speaker_ids])

    return (
        np.asarray(all_true, dtype=np.int64),
        np.asarray(all_probs, dtype=np.float64),
        np.asarray(all_sex, dtype=object),
        np.asarray(all_clip, dtype=object),
        np.asarray(all_spk, dtype=object),
        np.asarray(all_task, dtype=object),
    )

# -------------------------
# 17) Single-seed evaluation on D2 test
# -------------------------
# For each seed: load heads, run inference, compute metrics at THR_USED_GLOBAL, save plots and predictions.
def run_test_once(seed: int):
    set_all_seeds(seed)

    run_dir = RUN_ROOT / f"run_{train_dataset_id}_on_{d2_dataset_id}test_seed{seed}"
    run_dir.mkdir(parents=True, exist_ok=True)

    best_heads_path = chosen_exp / f"run_{train_dataset_id}_seed{seed}" / "best_heads.pt"

    print(f"\n[seed={seed}] Loading model + heads from:")
    print(" ", str(best_heads_path))

    model = Wav2Vec2TwoHeadClassifier(BACKBONE_CKPT, dropout_p=DROPOUT_P).to(DEVICE)
    model = load_heads_into_model(model, best_heads_path)
    model.eval()

    thr_used = float(THR_USED_GLOBAL)
    thr_note = THR_GLOBAL_NOTE

    # Inference (with meta for predictions.csv).
    yt_true, yt_prob, yt_sex, yt_clip, yt_spk, yt_task = _infer_probs_with_meta(
        test_loader, model, desc=f"[seed={seed}] Test (D2 TEST)"
    )
    test_auc = compute_auc(yt_true, yt_prob)

    # Metrics at the shared threshold.
    thr_metrics_test = compute_threshold_metrics(yt_true, yt_prob, thr=thr_used)
    fnr_by_sex, delta_f_minus_m, delta_abs = compute_fnr_by_group_signed(yt_true, yt_prob, yt_sex, thr=thr_used)
    confusion_by_sex = compute_confusion_by_group(yt_true, yt_prob, yt_sex, thr=thr_used)

    # Plots (overall).
    roc_png = run_dir / "roc_curve.png"
    cm_png  = run_dir / "confusion_matrix.png"
    save_roc_curve_png(yt_true, yt_prob, str(roc_png), title_suffix=f"D2 TEST (seed={seed})")
    save_confusion_png(yt_true, yt_prob, str(cm_png), thr=thr_used, title_suffix=f"D2 TEST (seed={seed})")

    # Plots split by sex (M and F only).
    cm_m_png = None
    cm_f_png = None
    mask_m = (yt_sex == "M")
    mask_f = (yt_sex == "F")

    if int(mask_m.sum()) > 0:
        cm_m_png = run_dir / "confusion_matrix_M.png"
        save_confusion_png(yt_true[mask_m], yt_prob[mask_m], str(cm_m_png), thr=thr_used, title_suffix=f"D2 TEST SEX=M (seed={seed})")

    if int(mask_f.sum()) > 0:
        cm_f_png = run_dir / "confusion_matrix_F.png"
        save_confusion_png(yt_true[mask_f], yt_prob[mask_f], str(cm_f_png), thr=thr_used, title_suffix=f"D2 TEST SEX=F (seed={seed})")

    # predictions.csv includes clip id fields plus model score and group fields.
    pred_df = pd.DataFrame({
        "clip_path": yt_clip.astype(str),
        "y_true": yt_true.astype(int),
        "y_score": yt_prob.astype(float),
        "sex_norm": yt_sex.astype(str),
        "speaker_id": yt_spk.astype(str),
        "task_group": yt_task.astype(str),
        "seed": int(seed),
        "trainval_exp_tag": str(FULL_TRAINVAL_EXP_TAG),
        "run_stamp": str(RUN_STAMP),
        "threshold_used_global": float(thr_used),
    })
    pred_csv_path = run_dir / "predictions.csv"
    pred_df.to_csv(pred_csv_path, index=False)

    # metrics.json captures run settings, aggregate stats, fairness, and artifact paths for this seed.
    metrics = {
        "enh_tag": str(ENH_TAG),

        "train_dataset": train_dataset_id,
        "test_dataset": d2_dataset_id,
        "seed": int(seed),

        "n_test": int(len(test_df)),
        "label_counts_test": test_df["label_num"].value_counts(dropna=False).to_dict(),
        "sex_counts_test_norm": test_df["sex_norm"].value_counts(dropna=False).to_dict(),

        "test_auroc": float(test_auc),

        "threshold_source": "D7 enhanced trainval summary_trainval.json -> val_optimal_threshold.mean_sd.mean",
        "trainval_experiment_used": str(chosen_exp),
        "trainval_exp_tag": str(FULL_TRAINVAL_EXP_TAG),
        "trainval_summary_path": str(summary_trainval_path),

        "test_threshold_used_global": float(thr_used),
        "test_threshold_note_global": thr_note,

        "threshold_metrics_test_at_thr_used": thr_metrics_test,

        "fairness_test_at_thr_used": {
            "definition": "H3: ΔFNR = FNR(F) - FNR(M), where FNR(sex)=FN/(FN+TP) computed on PD-only true labels at test_threshold_used_global.",
            "threshold_used": float(thr_used),
            "fnr_by_sex_norm": fnr_by_sex,
            "delta_fnr_F_minus_M": float(delta_f_minus_m),
            "delta_fnr_abs": float(delta_abs),
            "note": "If n_PD for a sex is 0, its FNR is NaN and ΔFNR is NaN.",
            "sex_normalization_note": "D2 mapping: exact 'male'->M and 'female'->F (case-sensitive); otherwise UNK.",
        },

        "confusion_by_sex_norm_at_thr_used": confusion_by_sex,

        "artifacts": {
            "predictions_csv": str(pred_csv_path),
            "roc_curve_png": str(roc_png),
            "confusion_matrix_png": str(cm_png),
            "confusion_matrix_M_png": str(cm_m_png) if cm_m_png is not None else None,
            "confusion_matrix_F_png": str(cm_f_png) if cm_f_png is not None else None,
        },

        "d7_out_root": D7_OUT_ROOT,
        "d2_out_root": D2_OUT_ROOT,
        "d2_manifest_all": D2_MANIFEST_ALL,

        "best_heads_path": str(best_heads_path),
        "backbone_ckpt": BACKBONE_CKPT,
        "dropout_p": float(DROPOUT_P),
        "timestamp": time.strftime("%Y%m%d_%H%M%S"),
    }

    metrics_path = run_dir / "metrics.json"
    with open(metrics_path, "w", encoding="utf-8") as f:
        json.dump(metrics, f, indent=2)

    print(f"[seed={seed}] DONE | test_AUROC={test_auc:.6f}")
    print(f"[seed={seed}] Threshold used (GLOBAL mean from D7 enhanced trainval): {thr_used:.6f}")
    print(f"[seed={seed}] WROTE:")
    print(" ", str(metrics_path))
    print(" ", str(pred_csv_path))
    print(" ", str(roc_png))
    print(" ", str(cm_png))

    return {
        "seed": int(seed),
        "thr_used": float(thr_used),
        "thr_note": thr_note,
        "test_auc": float(test_auc),
        "thr_metrics_test": thr_metrics_test,
        "fnr_by_sex": fnr_by_sex,
        "delta_signed": float(delta_f_minus_m),
        "delta_abs": float(delta_abs),
        "run_dir": str(run_dir),
        "predictions_csv": str(pred_csv_path),
    }

# -------------------------
# 18) Run all seeds and aggregate results
# -------------------------
# Reports AUROC mean with 95% CI (n=3) and other metrics as mean ± SD across seeds.
results = []
for seed in SEEDS:
    results.append(run_test_once(seed))

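aurocs = [r["test_auc"] for r in results]
n = len(aurocs)
# Two-sided 95% Student-t critical value for df=2; only valid when exactly three seeds are run.
t_crit = 4.302652729911275
if n != 3:
    print(f"WARNING: t_crit assumes n=3 seeds (df=2) but n={n}; the reported 95% CI will be miscalibrated.")
mean_auc = float(np.mean(aurocs))
std_auc = float(np.std(aurocs, ddof=1)) if n > 1 else 0.0
half_width = float(t_crit * (std_auc / math.sqrt(n))) if n > 1 else 0.0
ci95 = [float(mean_auc - half_width), float(mean_auc + half_width)]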

# Aggregates mean and SD for a list of values (NaN-safe).
def _mean_sd(vals):
    vals = np.asarray(vals, dtype=np.float64)
    mu = float(np.nanmean(vals)) if np.any(~np.isnan(vals)) else float("nan")
    sd = float(np.nanstd(vals, ddof=1)) if np.sum(~np.isnan(vals)) > 1 else 0.0
    return mu, sd

thr_list = [r["thr_metrics_test"] for r in results]
keys = ["accuracy","precision","recall","f1","sensitivity","specificity","mcc","p_value_fisher_two_sided"]
agg = {}
for k in keys:
    v = [float(tm.get(k, float("nan"))) for tm in thr_list]
    mu, sd = _mean_sd(v)
    agg[k] = {
        "mean": float(mu),
        "sd": float(sd),
        "values_by_seed": {str(s): float(tm.get(k, float("nan"))) for s, tm in zip(SEEDS, thr_list)},
    }

# Stores confusion counts per seed for quick cross-checks.
cm_by_seed = {
    str(s): {"tn": int(thr_list[i]["tn"]), "fp": int(thr_list[i]["fp"]), "fn": int(thr_list[i]["fn"]), "tp": int(thr_list[i]["tp"])}
    for i, s in enumerate(SEEDS)
}

# Fairness aggregation across seeds.
fnr_by_seed = {str(r["seed"]): r["fnr_by_sex"] for r in results}
delta_signed_by_seed = {str(r["seed"]): float(r["delta_signed"]) for r in results}
delta_abs_by_seed = {str(r["seed"]): float(r["delta_abs"]) for r in results}

fnr_m_vals, fnr_f_vals = [], []
d_signed_vals, d_abs_vals = [], []
for r in results:
    d = r["fnr_by_sex"] or {}
    fnr_m_vals.append(float(d.get("M", {}).get("fnr", float("nan"))))
    fnr_f_vals.append(float(d.get("F", {}).get("fnr", float("nan"))))
    d_signed_vals.append(float(r["delta_signed"]))
    d_abs_vals.append(float(r["delta_abs"]))

fnr_m_mean, fnr_m_sd = _mean_sd(fnr_m_vals)
fnr_f_mean, fnr_f_sd = _mean_sd(fnr_f_vals)
d_signed_mean, d_signed_sd = _mean_sd(d_signed_vals)
d_abs_mean, d_abs_sd = _mean_sd(d_abs_vals)

# Console summary for quick review.
print("\nTest AUROC by seed:")
for r in results:
    print(f"  seed {r['seed']}: {r['test_auc']:.6f}")
print(f"\nMean Test AUROC: {mean_auc:.6f}")
print(f"95% CI (t, n=3): [{ci95[0]:.6f}, {ci95[1]:.6f}]")

print("\nTEST threshold used (GLOBAL mean val-opt from D7 enhanced trainval):")
print(f"  THR_USED_GLOBAL: {THR_USED_GLOBAL:.6f}")
if THR_GLOBAL_NOTE is not None:
    print("  NOTE:", THR_GLOBAL_NOTE)

print("\nThreshold metrics on D2 TEST @ THR_USED_GLOBAL (mean ± SD across seeds):")
for k in ["accuracy","precision","sensitivity","specificity","f1","mcc"]:
    print(f"  {k}: {agg[k]['mean']:.6f} ± {agg[k]['sd']:.6f}")
print("  fisher_p_value_two_sided:", f"{agg['p_value_fisher_two_sided']['mean']:.6g} ± {agg['p_value_fisher_two_sided']['sd']:.6g}")

print("\nFAIRNESS (H3) on D2 TEST @ THR_USED_GLOBAL across seeds (mean ± SD):")
print(f"  FNR_M: {fnr_m_mean:.6f} ± {fnr_m_sd:.6f}")
print(f"  FNR_F: {fnr_f_mean:.6f} ± {fnr_f_sd:.6f}")
print(f"  ΔFNR (F - M): {d_signed_mean:.6f} ± {d_signed_sd:.6f}")
print(f"  |ΔFNR|: {d_abs_mean:.6f} ± {d_abs_sd:.6f}")

# -------------------------
# 19) Write summary and pointer files
# -------------------------
# Saves a latest summary (overwritten with backup), a last-run pointer, an append-only history line, and a tag pointer.
summary_latest = {
    "enh_tag": str(ENH_TAG),
    "train_dataset": train_dataset_id,
    "test_dataset": d2_dataset_id,

    "chosen_trainval_exp": str(chosen_exp),
    "trainval_exp_tag": str(FULL_TRAINVAL_EXP_TAG),
    "run_stamp": str(RUN_STAMP),

    "threshold_used_global": float(THR_USED_GLOBAL),
    "threshold_note_global": THR_GLOBAL_NOTE,

    "test_auroc": {
        "by_seed": {str(r["seed"]): float(r["test_auc"]) for r in results},
        "mean": float(mean_auc),
        "std": float(std_auc),
        "ci95_t_n3": [float(ci95[0]), float(ci95[1])],
    },
    "threshold_metrics_test_at_thr_used": agg,
    "confusion_counts_by_seed_at_thr_used": cm_by_seed,

    "fairness_test_at_thr_used": {
        "definition": "H3: ΔFNR = FNR(F) - FNR(M), where FNR(sex)=FN/(FN+TP) computed on PD-only true labels at threshold_used_global.",
        "fnr_by_seed": fnr_by_seed,
        "delta_fnr_F_minus_M_by_seed": delta_signed_by_seed,
        "delta_fnr_abs_by_seed": delta_abs_by_seed,
        "fnr_M_mean_sd": {"mean": float(fnr_m_mean), "sd": float(fnr_m_sd)},
        "fnr_F_mean_sd": {"mean": float(fnr_f_mean), "sd": float(fnr_f_sd)},
        "delta_signed_mean_sd": {"mean": float(d_signed_mean), "sd": float(d_signed_sd)},
        "delta_abs_mean_sd": {"mean": float(d_abs_mean), "sd": float(d_abs_sd)},
    },

    "run_root": str(RUN_ROOT),
    "runs": {str(r["seed"]): {"run_dir": r["run_dir"], "predictions_csv": r["predictions_csv"]} for r in results},
}

summary_latest_path = TEST_ROOT / "summary_latest.json"
_backup_if_exists(summary_latest_path)
with open(summary_latest_path, "w", encoding="utf-8") as f:
    json.dump(summary_latest, f, indent=2)

# Kept for compatibility with prior notebooks: this is the same file as summary_latest_path
# (a plain path alias, not a filesystem symlink), so the backup above already covers it.
summary_symlink_path = summary_latest_path

# Stores the most recent run folder pointer (overwritten with backup).
last_run_pointer_path = TEST_ROOT / "last_run_pointer.json"
_backup_if_exists(last_run_pointer_path)
last_run_pointer_obj = {
    "run_root": str(RUN_ROOT),
    "run_parent_dirname": str(RUN_PARENT_DIRNAME),
    "trainval_exp_tag": str(FULL_TRAINVAL_EXP_TAG),
    "trainval_exp_path": str(chosen_exp),
    "enh_tag": str(ENH_TAG),
    "run_stamp": str(RUN_STAMP),
    "timestamp": time.strftime("%Y%m%d_%H%M%S"),
}
with open(last_run_pointer_path, "w", encoding="utf-8") as f:
    json.dump(last_run_pointer_obj, f, indent=2)

# Appends a one-line history record so older runs stay discoverable.
history_index_path = TEST_ROOT / "history_index.jsonl"
history_record = {
    "timestamp": time.strftime("%Y%m%d_%H%M%S"),
    "run_stamp": str(RUN_STAMP),
    "run_root": str(RUN_ROOT),
    "trainval_exp_tag": str(FULL_TRAINVAL_EXP_TAG),
    "trainval_exp_path": str(chosen_exp),
    "enh_tag": str(ENH_TAG),
    "summary_latest_path": str(summary_latest_path),
}
with open(history_index_path, "a", encoding="utf-8") as f:
    f.write(json.dumps(history_record) + "\n")

# Stores a tag-scoped pointer (stable location per experiment tag).
tag_pointer_path = TAG_ROOT / "tag_run_pointer.json"
tag_pointer_obj = dict(last_run_pointer_obj)
tag_pointer_obj["tag_root"] = str(TAG_ROOT)
_backup_if_exists(tag_pointer_path)
with open(tag_pointer_path, "w", encoding="utf-8") as f:
    json.dump(tag_pointer_obj, f, indent=2)

print("\nWROTE (index/pointers):")
print(" ", str(summary_latest_path))
print(" ", str(last_run_pointer_path))
print(" ", str(history_index_path))
print(" ", str(tag_pointer_path))

# -------------------------
# 20) Builder-aligned config and logs
# -------------------------
# Writes a compact run_config plus minimal log placeholders for consistent folder structure.
run_config_obj = {
    "mode": f"D7_{TAG_SAFE}_on_D2_Test",
    "enh_tag": str(ENH_TAG),
    "trainval_exp_tag": str(FULL_TRAINVAL_EXP_TAG),
    "trainval_exp_path": str(chosen_exp),
    "d7_out_root": str(D7_OUT_ROOT),
    "d2_out_root": str(D2_OUT_ROOT),
    "d2_manifest_all": str(D2_MANIFEST_ALL),
    "seeds": [int(s) for s in SEEDS],
    "per_device_bs": int(PER_DEVICE_BS),
    "effective_bs": int(PER_DEVICE_BS * GRAD_ACCUM),
    "use_amp": bool(USE_AMP and DEVICE.type == "cuda"),
    "threshold_used_global": float(THR_USED_GLOBAL),
    "run_stamp": str(RUN_STAMP),
    "timestamp": time.strftime("%Y%m%d_%H%M%S"),
}
_backup_if_exists(RUN_CONFIG_PATH)
with open(RUN_CONFIG_PATH, "w", encoding="utf-8") as f:
    json.dump(run_config_obj, f, indent=2)

# Creates an empty warnings CSV if it does not exist (structure consistency).
if not WARNINGS_CSV_PATH.exists():
    WARNINGS_CSV_PATH.parent.mkdir(parents=True, exist_ok=True)
    pd.DataFrame(columns=["warning_type","detail"]).to_csv(WARNINGS_CSV_PATH, index=False)

# Saves a small dataset summary for this evaluation run.
dataset_summary_obj = {
    "mode": f"D7_{TAG_SAFE}_on_D2_Test",
    "dataset_id": str(d2_dataset_id),
    "split": "test",
    "n_rows": int(len(test_df)),
    "label_counts": test_df["label_num"].value_counts(dropna=False).to_dict(),
    "sex_counts_raw": test_df["sex"].value_counts(dropna=False).to_dict(),
    "sex_counts_norm": test_df["sex_norm"].value_counts(dropna=False).to_dict(),
    "timestamp": time.strftime("%Y%m%d_%H%M%S"),
}
_backup_if_exists(DATASET_SUMMARY_PATH)
with open(DATASET_SUMMARY_PATH, "w", encoding="utf-8") as f:
    json.dump(dataset_summary_obj, f, indent=2)

print("\nWROTE (builder-aligned config/logs):")
print(" ", str(RUN_CONFIG_PATH))
print(" ", str(WARNINGS_CSV_PATH))
print(" ", str(DATASET_SUMMARY_PATH))

# -------------------------
# 21) Free GPU memory (optional)
# -------------------------
# Releases cached, unused CUDA memory back to the driver after the run (live tensors are not freed).
try:
    torch.cuda.empty_cache()
except Exception:
    pass

# Ablation 1: Threshold Sweeps

The following cell runs a **post-training decision threshold sweep** for the **D7 trainEnh1 model evaluated on the D2 test split**, without retraining or changing the model. It automatically finds the most recent D7 training and validation experiment whose folder name contains **“trainEnh1”** (case-insensitive match) and checks that all required files are present for each of the three seeds, including the training summary and saved head weights. Using these fixed model heads, it runs inference once per seed on the D2 test data to obtain a fixed set of predicted probabilities, then evaluates how performance and fairness change as the decision threshold is varied from 0.01 to 0.99 over 199 evenly spaced values.

The cell starts with basic setup and safety checks. It avoids issues from locally named files that could override core libraries, mounts Google Drive if needed, resolves dataset output paths, and prints key settings such as device type, batch size, and sweep resolution. It then loads the D2 `manifest_all.csv`, confirms it belongs to **dataset D2**, filters to the test split only, and prints basic label and sex counts. All referenced audio files are checked, and the run stops immediately if any are missing.

For data preparation, each test clip is assigned to a simple task group: **vowel** for sustained vowel recordings and **other** for all remaining speech tasks. Sex metadata is standardized from D2’s original values by mapping exact, case-sensitive `male` and `female` to **M** and **F**; any other value, including a missing entry, is labeled **UNK**. Audio is expected at **16 kHz**, and a clip with any other sample rate stops the run rather than being resampled. Attention masks are created so padded regions are ignored. For vowel clips, the mask also excludes everything after the last sample whose absolute amplitude exceeds 1e-4, so trailing near-silence is not treated as useful signal.
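
As a minimal sketch of the vowel masking rule described above, the snippet below builds a sample-level mask with plain Python lists. The function name `trailing_silence_mask` and the toy waveform are illustrative assumptions, not part of the notebook; the 1e-4 cutoff mirrors `TINY_THRESH` in the cell, and an all-silent clip keeps a full mask, as the cell does.

```python
def trailing_silence_mask(samples, tiny_thresh=1e-4):
    """Return a 1/0 mask that zeroes every sample after the last one above tiny_thresh."""
    last_loud = -1
    # Scan backwards so the search stops at the first audible sample from the end.
    for j in range(len(samples) - 1, -1, -1):
        if abs(float(samples[j])) > tiny_thresh:
            last_loud = j
            break
    if last_loud < 0:
        # No sample exceeds the cutoff: keep the full mask rather than masking everything.
        return [1] * len(samples)
    return [1] * (last_loud + 1) + [0] * (len(samples) - last_loud - 1)


# Hypothetical toy waveform: three audible samples followed by near-silence.
toy_wave = [0.2, -0.1, 0.05, 0.00001, 0.0]
toy_mask = trailing_silence_mask(toy_wave)
print(toy_mask)  # [1, 1, 1, 0, 0]
```

Only the tail is masked; quiet samples in the middle of a clip stay in the pooled representation.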

Inference uses the same model structure as training: a frozen Wav2Vec2 backbone (`facebook/wav2vec2-base`) with two task-specific classification heads, one for vowel clips and one for non-vowel clips. For each seed, the saved head weights are loaded and the full D2 test set is processed once. The predicted Parkinson’s probabilities are stored along with the true labels and sex information. **AUROC is computed for each seed and averaged across seeds** as a threshold-free reference before sweeping.
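
The per-seed AUROC and its average can be illustrated with a small, dependency-free sketch. The cell itself imports `roc_auc_score` from scikit-learn; the pairwise definition below is a swapped-in equivalent for illustration (the fraction of PD/healthy pairs ranked correctly, with ties counted as half). The labels and scores are made-up toy values, and the helper name `pairwise_auroc` is hypothetical.

```python
from statistics import mean


def pairwise_auroc(y_true, y_score):
    """AUROC as the share of (positive, negative) pairs where the positive scores higher."""
    pos = [s for t, s in zip(y_true, y_score) if t == 1]
    neg = [s for t, s in zip(y_true, y_score) if t == 0]
    if not pos or not neg:
        return float("nan")  # undefined when one class is absent
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0 for p in pos for n in neg)
    return wins / (len(pos) * len(neg))


# Toy data: the same four clips scored by three seeds (values are illustrative only).
y_true = [0, 0, 1, 1]
scores_by_seed = {
    1337: [0.1, 0.4, 0.35, 0.8],
    2024: [0.2, 0.3, 0.6, 0.7],
    7777: [0.5, 0.6, 0.4, 0.9],
}

auroc_by_seed = {seed: pairwise_auroc(y_true, s) for seed, s in scores_by_seed.items()}
mean_auroc = mean(auroc_by_seed.values())
print(auroc_by_seed)  # {1337: 0.75, 2024: 1.0, 7777: 0.5}
print(mean_auroc)     # 0.75
```

Because AUROC depends only on the ranking of scores, it is unaffected by the threshold chosen later in the sweep.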

With predictions fixed, the cell performs the threshold sweep. At each threshold, standard classification metrics (sensitivity, specificity, accuracy, precision, F1 score, Matthews correlation coefficient, and the Fisher’s exact test p-value) are computed per seed and then averaged across the three seeds. Fairness is evaluated by computing the false negative rate for Parkinson’s cases separately for males and females, then summarizing the sex difference as both a signed value (female minus male) and an absolute gap. The sweep table keeps both averaged results and per-seed values so differences across seeds remain visible.
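
The per-threshold computation can be sketched for a single seed as follows. This is an illustration under stated assumptions, not the notebook's implementation: the function name `rates_at_threshold`, the toy data, and the rule that a score at or above the threshold counts as PD are all assumptions. The FNR definition, FN/(FN+TP) on PD-only clips, and the signed gap FNR(F) - FNR(M) follow the definitions used in the notebook.

```python
def rates_at_threshold(y_true, y_score, sex, thr):
    """Sensitivity, specificity, and per-sex FNR for one seed at one threshold."""
    tp = fp = tn = fn = 0
    fn_by_sex = {"M": 0, "F": 0}
    pd_by_sex = {"M": 0, "F": 0}
    for t, s, g in zip(y_true, y_score, sex):
        pred = 1 if s >= thr else 0  # assumption: scores at or above thr are called PD
        if t == 1:
            if g in pd_by_sex:
                pd_by_sex[g] += 1
            if pred == 1:
                tp += 1
            else:
                fn += 1
                if g in fn_by_sex:
                    fn_by_sex[g] += 1
        elif pred == 1:
            fp += 1
        else:
            tn += 1
    nan = float("nan")
    fnr = {g: (fn_by_sex[g] / pd_by_sex[g] if pd_by_sex[g] else nan) for g in ("M", "F")}
    delta = fnr["F"] - fnr["M"]  # signed gap, female minus male
    return {
        "thr": thr,
        "sensitivity": tp / (tp + fn) if (tp + fn) else nan,
        "specificity": tn / (tn + fp) if (tn + fp) else nan,
        "fnr_M": fnr["M"],
        "fnr_F": fnr["F"],
        "delta_fnr": delta,
        "abs_delta_fnr": abs(delta),
    }


# Toy single-seed data (illustrative values only).
y_true = [1, 1, 1, 1, 0, 0, 0, 0]
y_score = [0.9, 0.4, 0.7, 0.2, 0.3, 0.6, 0.1, 0.5]
sex = ["M", "M", "F", "F", "M", "F", "M", "F"]

sweep = [rates_at_threshold(y_true, y_score, sex, thr) for thr in (0.25, 0.45)]
for row in sweep:
    print(row)
```

In the full sweep this row would be computed once per seed and per threshold, with the per-seed values kept next to their mean.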

A single recommended threshold is chosen using a clear rule called **Policy B+**. This threshold minimizes the mean absolute sex gap in false negative rate while also meeting two requirements at the same time: mean sensitivity of at least **0.60** and mean specificity of at least **0.50**. If no threshold meets both conditions, the cell reports which requirement failed and selects the closest available alternative based on how far it misses the targets.
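
The selection rule can be sketched as a constrained minimum over the sweep table. The row keys, the function name `select_policy_b_plus`, and the toy rows below are hypothetical. The fallback shown here (smallest combined shortfall against the two targets) is only one plausible reading of “closest available alternative” and should be treated as an assumption rather than the notebook's exact rule.

```python
TARGET_SENS = 0.60
MIN_SPEC = 0.50


def select_policy_b_plus(rows, target_sens=TARGET_SENS, min_spec=MIN_SPEC):
    """Pick the row with the smallest mean |dFNR| among rows meeting both constraints."""
    feasible = [
        r for r in rows
        if r["mean_sensitivity"] >= target_sens and r["mean_specificity"] >= min_spec
    ]
    if feasible:
        return min(feasible, key=lambda r: r["mean_abs_delta_fnr"]), "feasible"

    # Assumed fallback: choose the row that misses the two targets by the least in total.
    def shortfall(r):
        return (max(0.0, target_sens - r["mean_sensitivity"])
                + max(0.0, min_spec - r["mean_specificity"]))

    return min(rows, key=shortfall), "fallback"


# Toy sweep rows with seed-averaged metrics (illustrative values only).
rows = [
    {"thr": 0.30, "mean_sensitivity": 0.80, "mean_specificity": 0.45, "mean_abs_delta_fnr": 0.02},
    {"thr": 0.40, "mean_sensitivity": 0.70, "mean_specificity": 0.55, "mean_abs_delta_fnr": 0.10},
    {"thr": 0.50, "mean_sensitivity": 0.62, "mean_specificity": 0.66, "mean_abs_delta_fnr": 0.06},
    {"thr": 0.60, "mean_sensitivity": 0.50, "mean_specificity": 0.80, "mean_abs_delta_fnr": 0.01},
]

chosen, status = select_policy_b_plus(rows)
print(chosen["thr"], status)  # 0.5 feasible
```

The rows at 0.30 and 0.60 have smaller gaps but each violates one constraint, which is why the guardrails matter: without them the rule could favor a threshold with a small gap but unusable sensitivity or specificity.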

All outputs are saved to a timestamped folder under the D7 threshold sweep directory. This includes the full sweep table, a summary explaining the selected experiment and threshold choice, and plots showing how sensitivity, specificity, and fairness change with the threshold, along with a trade-off curve that clearly marks the chosen operating point.

In [None]:
# =========================
# Threshold Sweep With Fairness Guardrail (D7 trainEnh1 → D2 test)
# =========================
# Purpose: run inference once per seed using saved D7 trainEnh1 heads, then sweep decision thresholds on the same scores.
# Inputs: saved head weights for three seeds, plus the D2 manifest (test split) pointing to clip files and labels.
# Outputs: a sweep table (CSV), a summary (JSON), and a few simple plots saved in a new timestamped sweep folder.
# =========================

import os, json, math, random, time, warnings
from pathlib import Path

import numpy as np
import pandas as pd
import soundfile as sf

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

from transformers import Wav2Vec2Model
from sklearn.metrics import roc_auc_score, confusion_matrix, matthews_corrcoef
import matplotlib.pyplot as plt
from tqdm.auto import tqdm

# -------------------------
# 0) Import safety check
# -------------------------
# Stops early if local files would override PyTorch or Transformers.
if os.path.exists("/content/torch.py") or os.path.exists("/content/torch/__init__.py"):
    raise RuntimeError("Found local /content/torch.py or /content/torch/ that shadows PyTorch. Rename/remove it and restart runtime.")
if os.path.exists("/content/transformers.py") or os.path.exists("/content/transformers/__init__.py"):
    raise RuntimeError("Found local /content/transformers.py or /content/transformers/ that shadows Hugging Face Transformers. Rename/remove it and restart runtime.")

# -------------------------
# 1) Drive mount (best-effort)
# -------------------------
# Mounts Google Drive if running in Colab and not already mounted.
try:
    from google.colab import drive  # type: ignore
    if not os.path.isdir("/content/drive/MyDrive"):
        drive.mount("/content/drive")
except Exception:
    pass

# -------------------------
# 2) Root paths
# -------------------------
# Uses existing notebook globals if present; otherwise uses the fallbacks below.
D7_OUT_ROOT_FALLBACK = "/content/drive/MyDrive/AI_PD_Project/Datasets/D7-Multilingual (D1_D4_D5v2_D6)/preprocessed_v1"
D2_OUT_ROOT_FALLBACK = "/content/drive/MyDrive/AI_PD_Project/Datasets/D2-Slovak (EWA-DB)/EWA-DB/preprocessed_v1"

D7_OUT_ROOT = str(globals().get("D7_OUT_ROOT", D7_OUT_ROOT_FALLBACK))
D2_OUT_ROOT = str(globals().get("D2_OUT_ROOT", D2_OUT_ROOT_FALLBACK))
D2_MANIFEST_ALL = f"{D2_OUT_ROOT}/manifests/manifest_all.csv"

# Exposes roots for other cells that might reuse them.
globals()["D7_OUT_ROOT"] = D7_OUT_ROOT
globals()["D2_OUT_ROOT"] = D2_OUT_ROOT

# -------------------------
# 3) Run configuration
# -------------------------
# Core settings: seeds, backbone name, expected audio sample rate, and inference batching.
SEEDS          = [1337, 2024, 7777]
BACKBONE_CKPT  = "facebook/wav2vec2-base"
SR_EXPECTED    = 16000
TINY_THRESH    = 1e-4

PER_DEVICE_BS  = 16
NUM_WORKERS    = 0
PIN_MEMORY     = False
USE_AMP        = True
DROPOUT_P      = 0.2

# Experiment selector: matches the exp_* folder name tag for trainEnh1.
REQUIRED_EXP_SUBSTRING = "trainEnh1"  # case-insensitive match against exp_* folder name

# Policy B+: reduce sex FNR gap while keeping sensitivity and specificity above minimums.
TARGET_SENS = 0.60
MIN_SPEC    = 0.50   # avoids trivial thresholds that label nearly everything as PD
POLICY_TEXT = (
    "Policy B+: minimize mean(|ΔFNR|) subject to "
    f"mean(sensitivity) >= {TARGET_SENS:.2f} AND mean(specificity) >= {MIN_SPEC:.2f}"
)

# Threshold grid for sweeping.
THR_MIN, THR_MAX, THR_STEPS = 0.01, 0.99, 199
THR_GRID = np.linspace(THR_MIN, THR_MAX, THR_STEPS).astype(np.float64)

# Device selection and minor numeric tuning.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
try:
    torch.set_float32_matmul_precision("high")
except Exception:
    pass

# Keeps output readable by filtering common non-critical warnings.
warnings.filterwarnings("ignore", message="None of the inputs have requires_grad=True. Gradients will be None")
warnings.filterwarnings("ignore", message="Passing `gradient_checkpointing` to a config initialization is deprecated")
warnings.filterwarnings("ignore", category=UserWarning, module="huggingface_hub")

# Prints key settings for quick verification.
print("D7_OUT_ROOT:", D7_OUT_ROOT)
print("D2_OUT_ROOT:", D2_OUT_ROOT)
print("D2_MANIFEST_ALL:", D2_MANIFEST_ALL)
print("DEVICE:", DEVICE)
if DEVICE.type == "cuda":
    print("GPU:", torch.cuda.get_device_name(0))
print("PER_DEVICE_BS:", PER_DEVICE_BS)
print("NUM_WORKERS:", NUM_WORKERS, "| PIN_MEMORY:", PIN_MEMORY)
print("USE_AMP:", bool(USE_AMP and DEVICE.type == "cuda"))
print("Required exp substring (case-insensitive):", REQUIRED_EXP_SUBSTRING)
print("Policy B+ constraints:")
print("  mean sensitivity >=", TARGET_SENS)
print("  mean specificity >=", MIN_SPEC)
print(f"Threshold sweep grid: {THR_MIN:.2f}..{THR_MAX:.2f} with {THR_STEPS} steps")

# -------------------------
# 4) Load D2 manifest and keep only test split
# -------------------------
# Reads the D2 manifest, checks required columns, confirms it is actually D2, then filters to split == "test".
if not os.path.exists(D2_MANIFEST_ALL):
    raise FileNotFoundError(f"Missing D2 manifest_all.csv: {D2_MANIFEST_ALL}")

m_all = pd.read_csv(D2_MANIFEST_ALL)

req_cols = {"split", "clip_path", "label_num", "task", "sex", "age"}
missing = [c for c in sorted(req_cols) if c not in m_all.columns]
if missing:
    raise ValueError(f"D2 manifest missing required columns: {missing}. Found: {list(m_all.columns)}")

# Uses the most common dataset label (when present) and filters to that dataset.
if "dataset" in m_all.columns and m_all["dataset"].notna().any():
    d2_dataset_id = str(m_all["dataset"].astype(str).value_counts(dropna=True).idxmax())
    m_all = m_all[m_all["dataset"].astype(str) == d2_dataset_id].copy()
else:
    d2_dataset_id = "DX"

# Hard stop if the manifest is not D2.
if d2_dataset_id != "D2":
    raise RuntimeError(
        f"Expected D2 dataset_id=='D2' but got {d2_dataset_id!r}. "
        "This usually means D2_OUT_ROOT is wrong or the manifest is not D2. "
        f"D2_OUT_ROOT={D2_OUT_ROOT}"
    )

# Keeps a stable set of columns used downstream (fills missing ones with NaN).
keep_cols = ["clip_path", "label_num", "task", "speaker_id", "sex", "age", "duration_sec", "split"]
for c in keep_cols:
    if c not in m_all.columns:
        m_all[c] = np.nan
m_all = m_all[keep_cols].copy()

# Test split only.
test_df = m_all[m_all["split"].astype(str) == "test"].reset_index(drop=True)

print(f"\nD2 dataset inferred: {d2_dataset_id}")
print(f"D2 TEST rows: {len(test_df)}")
print("D2 TEST label counts:", test_df["label_num"].value_counts(dropna=False).to_dict())
print("D2 TEST sex counts (raw):", test_df["sex"].value_counts(dropna=False).to_dict())

if len(test_df) == 0:
    raise RuntimeError("After filtering to split=='test', D2 manifest has 0 rows.")

# -------------------------
# 5) Fail fast: confirm audio files exist
# -------------------------
# Checks for missing clip paths early to avoid long runs that fail late.
def _fail_fast_missing_paths(df: pd.DataFrame, name: str):
    missing_paths = []
    for p in tqdm(df["clip_path"].astype(str).tolist(), desc=f"Check {name} clip_path exists", dynamic_ncols=True):
        if not os.path.exists(p):
            missing_paths.append(p)
            if len(missing_paths) >= 10:
                break
    if missing_paths:
        raise FileNotFoundError(f"{name}: missing clip_path(s). Examples: {missing_paths}")

_fail_fast_missing_paths(test_df, "D2 TEST")

# -------------------------
# 6) Task grouping (vowel vs other)
# -------------------------
# Maps each clip to the head it should use during inference.
def _task_group(task_val) -> str:
    return "vowel" if str(task_val) == "vowl" else "other"

test_df["task_group"] = test_df["task"].apply(_task_group)

# -------------------------
# 6.5) Sex normalization
# -------------------------
# Normalizes D2 "male"/"female" into "M"/"F" and maps anything else to "UNK".
def normalize_sex_d2_case_sensitive(val) -> str:
    if pd.isna(val):
        return "UNK"
    if val == "male":
        return "M"
    if val == "female":
        return "F"
    return "UNK"

test_df["sex_norm"] = test_df["sex"].apply(normalize_sex_d2_case_sensitive)
print("D2 TEST sex counts (normalized):", test_df["sex_norm"].value_counts(dropna=False).to_dict())
if (test_df["sex_norm"] == "UNK").any():
    print("NOTE: Some D2 'sex' values were not exactly 'male'/'female' and were mapped to 'UNK'.")

# -------------------------
# 7) Dataset and collator
# -------------------------
# Builds model inputs from audio files and creates an attention mask for pooling.
# For vowel clips, the mask removes trailing padding-like silence based on a tiny amplitude threshold.
class AudioManifestDataset(Dataset):
    def __init__(self, df: pd.DataFrame):
        self.df = df.reset_index(drop=True)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx: int):
        row = self.df.iloc[idx]
        clip_path = str(row["clip_path"])
        label = int(row["label_num"])
        task_group = str(row["task_group"])
        sex_norm = str(row["sex_norm"])

        if not os.path.exists(clip_path):
            raise FileNotFoundError(f"Missing clip_path: {clip_path}")

        # Reads audio and converts to mono float32.
        y, sr = sf.read(clip_path, always_2d=False)
        if isinstance(y, np.ndarray) and y.ndim > 1:
            y = np.mean(y, axis=1)
        y = np.asarray(y, dtype=np.float32)

        # Enforces a single sample rate so the backbone sees consistent inputs.
        if int(sr) != SR_EXPECTED:
            raise ValueError(f"Sample rate mismatch (expected {SR_EXPECTED}): {clip_path} has sr={sr}")

        # Attention mask: 1 = keep, 0 = ignore during pooling.
        attn = np.ones((len(y),), dtype=np.int64)
        if task_group == "vowel":
            # Finds the last non-tiny sample and ignores anything after it.
            k = -1
            for j in range(len(y) - 1, -1, -1):
                if abs(float(y[j])) > TINY_THRESH:
                    k = j
                    break
            if k >= 0:
                attn[:k+1] = 1
                attn[k+1:] = 0
            else:
                attn[:] = 1
        else:
            attn[:] = 1

        return {
            "input_values": torch.from_numpy(y),
            "attention_mask": torch.from_numpy(attn),
            "labels": torch.tensor(label, dtype=torch.long),
            "task_group": task_group,
            "sex_norm": sex_norm,
        }

# Pads variable-length audio in a batch to the longest clip length.
def collate_fn(batch):
    max_len = int(max(b["input_values"].numel() for b in batch))
    input_vals, attn_masks, labels, task_groups, sex_norms = [], [], [], [], []
    for b in batch:
        x = b["input_values"]
        a = b["attention_mask"]
        pad = max_len - x.numel()
        if pad > 0:
            x = torch.cat([x, torch.zeros(pad, dtype=x.dtype)], dim=0)
            a = torch.cat([a, torch.zeros(pad, dtype=a.dtype)], dim=0)
        input_vals.append(x)
        attn_masks.append(a)
        labels.append(b["labels"])
        task_groups.append(b["task_group"])
        sex_norms.append(b["sex_norm"])
    return {
        "input_values": torch.stack(input_vals, dim=0),
        "attention_mask": torch.stack(attn_masks, dim=0),
        "labels": torch.stack(labels, dim=0),
        "task_group": task_groups,
        "sex_norm": sex_norms,
    }

# DataLoader over the D2 test split.
test_ds = AudioManifestDataset(test_df)
test_loader = DataLoader(
    test_ds,
    batch_size=PER_DEVICE_BS,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=PIN_MEMORY,
    collate_fn=collate_fn,
)

# -------------------------
# 8) Model (frozen backbone + two heads)
# -------------------------
# Loads a frozen Wav2Vec2 backbone and switches between vowel and other heads per clip.
class Wav2Vec2TwoHeadClassifier(nn.Module):
    def __init__(self, ckpt: str, dropout_p: float):
        super().__init__()
        self.backbone = Wav2Vec2Model.from_pretrained(ckpt, use_safetensors=True)
        for p in self.backbone.parameters():
            p.requires_grad = False

        H = int(self.backbone.config.hidden_size)
        self.pre_vowel = nn.Sequential(nn.LayerNorm(H), nn.Dropout(float(dropout_p)))
        self.pre_other = nn.Sequential(nn.LayerNorm(H), nn.Dropout(float(dropout_p)))
        self.head_vowel = nn.Linear(H, 2)
        self.head_other = nn.Linear(H, 2)

    # Mean-pools frame features using the attention mask (ignores masked regions).
    def masked_mean_pool(self, last_hidden, attn_mask_samples):
        feat_mask = self.backbone._get_feature_vector_attention_mask(last_hidden.shape[1], attn_mask_samples)
        feat_mask = feat_mask.to(last_hidden.device).unsqueeze(-1).type_as(last_hidden)
        masked = last_hidden * feat_mask
        denom = feat_mask.sum(dim=1).clamp(min=1.0)
        return masked.sum(dim=1) / denom

    # Runs the head in fp32 for stability even when AMP is enabled.
    def _heads_fp32(self, x_any: torch.Tensor, head: nn.Module) -> torch.Tensor:
        x = x_any.float()
        if x.is_cuda:
            with torch.amp.autocast(device_type="cuda", enabled=False):
                return head(x)
        return head(x)

    # Produces logits using the head selected by task_group.
    def forward_logits(self, input_values, attention_mask, task_group):
        with torch.no_grad():
            out = self.backbone(input_values=input_values, attention_mask=attention_mask)
            last_hidden = out.last_hidden_state

        pooled = self.masked_mean_pool(last_hidden, attention_mask)
        z_v = self.pre_vowel(pooled.float())
        z_o = self.pre_other(pooled.float())

        logits_v = self._heads_fp32(z_v, self.head_vowel)
        logits_o = self._heads_fp32(z_o, self.head_other)

        is_vowel = torch.tensor([tg == "vowel" for tg in task_group], device=pooled.device, dtype=torch.bool)
        logits = logits_o.clone()
        if is_vowel.any():
            logits[is_vowel] = logits_v[is_vowel]
        return logits

# Loads only the head weights saved during training (backbone stays the pretrained frozen model).
def load_heads_into_model(model: Wav2Vec2TwoHeadClassifier, best_heads_path: Path):
    if not best_heads_path.exists():
        raise FileNotFoundError(f"Missing best_heads.pt: {str(best_heads_path)}")
    state = torch.load(str(best_heads_path), map_location="cpu")
    model.pre_vowel.load_state_dict(state["pre_vowel"], strict=True)
    model.pre_other.load_state_dict(state["pre_other"], strict=True)
    model.head_vowel.load_state_dict(state["head_vowel"], strict=True)
    model.head_other.load_state_dict(state["head_other"], strict=True)
    return model

# -------------------------
# 9) Seed control and single-pass inference
# -------------------------
# set_all_seeds fixes every RNG for reproducibility; infer_once then runs one inference pass per seed
# and returns fixed probabilities, labels, and sex tags for threshold sweeping.
def set_all_seeds(seed: int):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

def infer_once(seed: int, chosen_exp: Path, train_dataset_id: str):
    set_all_seeds(seed)
    best_heads_path = chosen_exp / f"run_{train_dataset_id}_seed{seed}" / "best_heads.pt"

    model = Wav2Vec2TwoHeadClassifier(BACKBONE_CKPT, dropout_p=DROPOUT_P).to(DEVICE)
    model = load_heads_into_model(model, best_heads_path)
    model.eval()

    use_amp = bool(USE_AMP and DEVICE.type == "cuda")
    all_probs, all_true, all_sex = [], [], []

    # Forward pass over D2 test (no thresholding here).
    pbar = tqdm(test_loader, desc=f"[seed={seed}] Inference D2 TEST", dynamic_ncols=True)
    with torch.inference_mode():
        for batch in pbar:
            input_values = batch["input_values"].to(DEVICE, non_blocking=False)
            attention_mask = batch["attention_mask"].to(DEVICE, non_blocking=False)
            labels = batch["labels"].to(DEVICE, non_blocking=False)
            task_group = batch["task_group"]
            sex_norm = batch["sex_norm"]

            with torch.amp.autocast(device_type=DEVICE.type, enabled=use_amp):
                logits = model.forward_logits(input_values, attention_mask, task_group)

            # Probability for PD is the softmax class-1 probability.
            probs = torch.softmax(logits.float(), dim=-1)[:, 1].detach().cpu().numpy().astype(np.float64)
            all_probs.extend(probs.tolist())
            all_true.extend(labels.detach().cpu().numpy().astype(np.int64).tolist())
            all_sex.extend(list(sex_norm))

    y_true = np.asarray(all_true, dtype=np.int64)
    y_prob = np.asarray(all_probs, dtype=np.float64)
    y_sex  = np.asarray(all_sex, dtype=object)

    # AUROC is reported per seed as a threshold-independent reference.
    auc = float("nan")
    if len(np.unique(y_true)) >= 2:
        auc = float(roc_auc_score(y_true, y_prob))

    return {
        "seed": int(seed),
        "best_heads_path": str(best_heads_path),
        "y_true": y_true,
        "y_prob": y_prob,
        "y_sex": y_sex,
        "auroc": float(auc),
    }

# -------------------------
# 10) Metric helpers for threshold sweep
# -------------------------
# Computes standard metrics from a thresholded prediction and sex-based FNR gap on PD-only cases.
def confusion_counts(y_true, y_prob, thr):
    y_pred = (y_prob >= float(thr)).astype(np.int64)
    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
    TN, FP, FN, TP = int(cm[0, 0]), int(cm[0, 1]), int(cm[1, 0]), int(cm[1, 1])
    return TN, FP, FN, TP

def threshold_metrics(y_true, y_prob, thr):
    TN, FP, FN, TP = confusion_counts(y_true, y_prob, thr)
    eps = 1e-12
    acc  = (TP + TN) / max(1, (TP + TN + FP + FN))
    prec = TP / (TP + FP + eps)
    sens = TP / (TP + FN + eps)  # recall/sensitivity
    spec = TN / (TN + FP + eps)
    f1   = 2 * prec * sens / (prec + sens + eps)

    # MCC can be undefined if predictions collapse to a single class.
    try:
        y_pred = (y_prob >= float(thr)).astype(np.int64)
        mcc = float(matthews_corrcoef(y_true, y_pred)) if len(np.unique(y_pred)) > 1 else float("nan")
    except Exception:
        mcc = float("nan")

    # Fisher exact test on the 2x2 table (best-effort if scipy is available).
    pval = float("nan")
    try:
        from scipy.stats import fisher_exact  # type: ignore
        _, pval = fisher_exact([[TN, FP], [FN, TP]], alternative="two-sided")
        pval = float(pval)
    except Exception:
        pval = float("nan")

    return {
        "tn": TN, "fp": FP, "fn": FN, "tp": TP,
        "accuracy": float(acc),
        "precision": float(prec),
        "sensitivity": float(sens),
        "specificity": float(spec),
        "f1": float(f1),
        "mcc": float(mcc),
        "fisher_p_two_sided": float(pval),
    }

def fnr_by_sex_signed_delta(y_true, y_prob, y_sex, thr):
    # FNR is computed within each sex group on true PD cases only: FN / (FN + TP).
    y_pred = (y_prob >= float(thr)).astype(np.int64)
    out = {}

    for g in ["M", "F"]:
        mask_g = (y_sex == g)
        pos_mask = mask_g & (y_true == 1)
        n_pos = int(pos_mask.sum())
        if n_pos == 0:
            out[g] = {"n_pos": 0, "fn": 0, "tp": 0, "fnr": float("nan")}
            continue
        tp = int(((y_pred == 1) & pos_mask).sum())
        fn = int(((y_pred == 0) & pos_mask).sum())
        fnr = float(fn / max(1, (fn + tp)))
        out[g] = {"n_pos": int(n_pos), "fn": int(fn), "tp": int(tp), "fnr": float(fnr)}

    fnr_m = out["M"]["fnr"]
    fnr_f = out["F"]["fnr"]
    if (not np.isnan(fnr_m)) and (not np.isnan(fnr_f)):
        delta = float(fnr_f - fnr_m)     # ΔFNR = F - M (signed)
        absd  = float(abs(delta))
    else:
        delta = float("nan")
        absd  = float("nan")

    return out, delta, absd

# Small helpers to aggregate values across seeds while tolerating NaNs.
def mean_sd(vals):
    vals = np.asarray(vals, dtype=np.float64)
    mu = float(np.nanmean(vals)) if np.any(~np.isnan(vals)) else float("nan")
    sd = float(np.nanstd(vals, ddof=1)) if np.sum(~np.isnan(vals)) > 1 else 0.0
    return mu, sd

def safe_nanmean(x):
    x = np.asarray(x, dtype=np.float64)
    if np.any(~np.isnan(x)):
        return float(np.nanmean(x))
    return float("nan")

def safe_nansd(x):
    x = np.asarray(x, dtype=np.float64)
    if np.sum(~np.isnan(x)) > 1:
        return float(np.nanstd(x, ddof=1))
    return 0.0

# -------------------------
# 11) Find the most recent matching trainval experiment
# -------------------------
# Selects the newest exp_* folder that matches REQUIRED_EXP_SUBSTRING and has all seed heads plus summary_trainval.json.
TRAINVAL_ROOT = Path(D7_OUT_ROOT) / "trainval_runs"
if not TRAINVAL_ROOT.exists():
    raise FileNotFoundError(f"Missing trainval_runs folder under D7_OUT_ROOT: {str(TRAINVAL_ROOT)}")

exp_dirs = sorted([p for p in TRAINVAL_ROOT.glob("exp_*") if p.is_dir()], key=lambda p: p.stat().st_mtime, reverse=True)
if not exp_dirs:
    raise FileNotFoundError(f"No exp_* folders found under: {str(TRAINVAL_ROOT)}")

train_dataset_id = "D7"

def is_match(exp_path: Path, required_substring: str) -> bool:
    return required_substring.lower() in exp_path.name.lower()

def has_all_seeds_and_summary(exp_path: Path, dataset_id: str, seeds: list) -> bool:
    if not (exp_path / "summary_trainval.json").exists():
        return False
    for s in seeds:
        p = exp_path / f"run_{dataset_id}_seed{s}" / "best_heads.pt"
        if not p.exists():
            return False
    return True

chosen_exp = None
for ed in exp_dirs:
    if not is_match(ed, REQUIRED_EXP_SUBSTRING):
        continue
    if has_all_seeds_and_summary(ed, train_dataset_id, SEEDS):
        chosen_exp = ed
        break

if chosen_exp is None:
    examples = [p.name for p in exp_dirs[:10]]
    raise FileNotFoundError(
        "Could not find a recent D7 trainval experiment folder that:\n"
        f"  (1) contains substring '{REQUIRED_EXP_SUBSTRING}' (case-insensitive) in the exp folder name, and\n"
        "  (2) contains all 3 best_heads.pt files + summary_trainval.json.\n"
        f"Checked under: {str(TRAINVAL_ROOT)}\n\n"
        f"Example exp_* folder names (most recent first):\n  - " + "\n  - ".join(examples) + "\n\n"
        "Fix: set REQUIRED_EXP_SUBSTRING to match the exact tag used in the exp folder name (underscores vs camelCase matter)."
    )

print("\nUsing trainval experiment folder:")
print(" ", str(chosen_exp))

# Loads the trainval summary only to print the mean Youden-J threshold as a reference.
summary_trainval_path = chosen_exp / "summary_trainval.json"
with open(summary_trainval_path, "r", encoding="utf-8") as f:
    trainval_summary = json.load(f)

val_opt = ((trainval_summary or {}).get("val_optimal_threshold", {}) or {})
thr_mean = float((((val_opt.get("mean_sd", {}) or {}).get("mean", float("nan")))))
print("\nTrainval mean val-opt threshold (Youden J) for reference:")
print("  summary_trainval.json -> val_optimal_threshold.mean_sd.mean =", f"{thr_mean:.6f}" if not np.isnan(thr_mean) else "nan")

# -------------------------
# 12) Create output folder for this sweep run
# -------------------------
# Uses a timestamp so each sweep run writes to a fresh folder.
TS_ROOT = Path(D7_OUT_ROOT) / "threshold_sweeps"
TS_ROOT.mkdir(parents=True, exist_ok=True)
timestamp = time.strftime("%Y%m%d_%H%M%S")
OUT_DIR = TS_ROOT / f"run_D7_{REQUIRED_EXP_SUBSTRING}_on_D2test_{timestamp}"
OUT_DIR.mkdir(parents=True, exist_ok=True)

# -------------------------
# 13) Inference once per seed
# -------------------------
# Produces fixed prediction probabilities used by all threshold evaluations.
print("\nRunning inference ONCE per seed (then sweep thresholds on fixed predictions)...")
seed_payloads = []
for s in SEEDS:
    seed_payloads.append(infer_once(s, chosen_exp, train_dataset_id))

print("\nAUROC by seed (ranking metric, threshold-independent):")
for sp in seed_payloads:
    print(f"  seed {sp['seed']}: AUROC={sp['auroc']:.6f}")
auc_mean, auc_sd = mean_sd([sp["auroc"] for sp in seed_payloads])
print(f"Mean AUROC: {auc_mean:.6f} ± {auc_sd:.6f}")

# -------------------------
# 14) Threshold sweep across seeds
# -------------------------
# For each threshold: compute metrics per seed, then take the mean across seeds.
rows = []
for thr in tqdm(THR_GRID, desc="Threshold sweep", dynamic_ncols=True):
    sens_list, absd_list, signd_list = [], [], []
    acc_list, prec_list, spec_list, f1_list, mcc_list, p_list = [], [], [], [], [], []

    fnr_m_list, fnr_f_list = [], []
    n_pd_m_list, n_pd_f_list = [], []

    for sp in seed_payloads:
        y_true = sp["y_true"]
        y_prob = sp["y_prob"]
        y_sex  = sp["y_sex"]

        tm = threshold_metrics(y_true, y_prob, thr)
        fnr_by_sex, delta_signed, delta_abs = fnr_by_sex_signed_delta(y_true, y_prob, y_sex, thr)

        sens_list.append(tm["sensitivity"])
        acc_list.append(tm["accuracy"])
        prec_list.append(tm["precision"])
        spec_list.append(tm["specificity"])
        f1_list.append(tm["f1"])
        mcc_list.append(tm["mcc"])
        p_list.append(tm["fisher_p_two_sided"])

        signd_list.append(delta_signed)
        absd_list.append(delta_abs)

        fnr_m_list.append(float(fnr_by_sex["M"]["fnr"]))
        fnr_f_list.append(float(fnr_by_sex["F"]["fnr"]))
        n_pd_m_list.append(float(fnr_by_sex["M"]["n_pos"]))
        n_pd_f_list.append(float(fnr_by_sex["F"]["n_pos"]))

    row = {
        "threshold": float(thr),

        "mean_sensitivity": safe_nanmean(sens_list),
        "sd_sensitivity": safe_nansd(sens_list),

        "mean_specificity": safe_nanmean(spec_list),
        "sd_specificity": safe_nansd(spec_list),

        "mean_abs_deltaFNR": safe_nanmean(absd_list),
        "sd_abs_deltaFNR": safe_nansd(absd_list),

        "mean_signed_deltaFNR_F_minus_M": safe_nanmean(signd_list),
        "sd_signed_deltaFNR": safe_nansd(signd_list),

        "mean_accuracy": safe_nanmean(acc_list),
        "mean_precision": safe_nanmean(prec_list),
        "mean_f1": safe_nanmean(f1_list),
        "mean_mcc": safe_nanmean(mcc_list),
        "mean_fisher_p_two_sided": safe_nanmean(p_list),

        "mean_FNR_M": safe_nanmean(fnr_m_list),
        "mean_FNR_F": safe_nanmean(fnr_f_list),
        "n_PD_M_each_seed": json.dumps({str(SEEDS[i]): int(n_pd_m_list[i]) for i in range(len(SEEDS))}),
        "n_PD_F_each_seed": json.dumps({str(SEEDS[i]): int(n_pd_f_list[i]) for i in range(len(SEEDS))}),
    }

    # Stores per-seed values to explain the means later.
    row["sensitivity_by_seed"] = json.dumps({str(sp["seed"]): float(sens_list[i]) for i, sp in enumerate(seed_payloads)})
    row["specificity_by_seed"] = json.dumps({str(sp["seed"]): float(spec_list[i]) for i, sp in enumerate(seed_payloads)})
    row["abs_deltaFNR_by_seed"] = json.dumps({str(sp["seed"]): float(absd_list[i]) for i, sp in enumerate(seed_payloads)})
    row["signed_deltaFNR_by_seed"] = json.dumps({str(sp["seed"]): float(signd_list[i]) for i, sp in enumerate(seed_payloads)})

    rows.append(row)

sweep_df = pd.DataFrame(rows)

# -------------------------
# 15) Choose threshold using Policy B+
# -------------------------
# 1) If both constraints are reachable: choose the eligible threshold with the smallest mean |ΔFNR|.
# 2) Otherwise: choose the threshold with the smallest total shortfall from the constraints, and report the gaps.
max_mean_sens = float(sweep_df["mean_sensitivity"].max())
thr_at_max_sens = float(sweep_df.loc[sweep_df["mean_sensitivity"].idxmax(), "threshold"])

max_mean_spec = float(sweep_df["mean_specificity"].max())
thr_at_max_spec = float(sweep_df.loc[sweep_df["mean_specificity"].idxmax(), "threshold"])

eligible = sweep_df[
    (sweep_df["mean_sensitivity"] >= TARGET_SENS) &
    (sweep_df["mean_specificity"] >= MIN_SPEC)
].copy()

constraint_reached = (len(eligible) > 0)

# Summarizes whether a threshold is short on sensitivity and/or specificity.
def _constraint_status_at_row(row_dict):
    sens_gap = float(TARGET_SENS - row_dict["mean_sensitivity"])
    spec_gap = float(MIN_SPEC - row_dict["mean_specificity"])
    return {
        "sens_ok": bool(row_dict["mean_sensitivity"] >= TARGET_SENS),
        "spec_ok": bool(row_dict["mean_specificity"] >= MIN_SPEC),
        "sens_gap_needed": float(max(0.0, sens_gap)),
        "spec_gap_needed": float(max(0.0, spec_gap)),
    }

if constraint_reached:
    # Picks the smallest mean |ΔFNR| among thresholds that meet both constraints.
    eligible = eligible.sort_values(
        by=["mean_abs_deltaFNR", "mean_sensitivity", "mean_specificity", "threshold"],
        ascending=[True, False, False, True]
    )
    chosen = eligible.iloc[0].to_dict()

    policy_note = (
        "Constraints reached:\n"
        f"  - mean(sensitivity) >= {TARGET_SENS:.2f}\n"
        f"  - mean(specificity) >= {MIN_SPEC:.2f}\n"
        "Chosen threshold minimizes mean(|ΔFNR|) among eligible thresholds "
        "(tie-breakers: higher sensitivity, higher specificity, then lower threshold)."
    )
else:
    # If no threshold meets both constraints, selects the closest option and explains what failed.
    can_reach_sens = bool((sweep_df["mean_sensitivity"] >= TARGET_SENS).any())
    can_reach_spec = bool((sweep_df["mean_specificity"] >= MIN_SPEC).any())

    working_df = sweep_df.copy()

    # If a constraint is unreachable, it is not required in the subset filter.
    subset = working_df.copy()
    if can_reach_sens:
        subset = subset[subset["mean_sensitivity"] >= TARGET_SENS].copy()
    if can_reach_spec:
        subset = subset[subset["mean_specificity"] >= MIN_SPEC].copy()

    if len(subset) == 0:
        subset = working_df.copy()

    # Defines how far below each constraint the mean values are (0 if satisfied).
    subset["_sens_gap"] = np.maximum(0.0, TARGET_SENS - subset["mean_sensitivity"])
    subset["_spec_gap"] = np.maximum(0.0, MIN_SPEC - subset["mean_specificity"])
    subset["_total_gap"] = subset["_sens_gap"] + subset["_spec_gap"]

    # Chooses the smallest total gap, then smallest mean |ΔFNR|, then higher sensitivity and specificity.
    subset = subset.sort_values(
        by=["_total_gap", "mean_abs_deltaFNR", "mean_sensitivity", "mean_specificity", "threshold"],
        ascending=[True, True, False, False, True]
    )

    chosen = subset.iloc[0].drop(labels=["_sens_gap", "_spec_gap", "_total_gap"]).to_dict()
    status = _constraint_status_at_row(chosen)

    failed_parts = []
    if not can_reach_sens:
        failed_parts.append(
            f"mean(sensitivity) >= {TARGET_SENS:.2f} (UNREACHABLE on this grid; max was {max_mean_sens:.6f} at thr={thr_at_max_sens:.4f})"
        )
    if not can_reach_spec:
        failed_parts.append(
            f"mean(specificity) >= {MIN_SPEC:.2f} (UNREACHABLE on this grid; max was {max_mean_spec:.6f} at thr={thr_at_max_spec:.4f})"
        )
    if not failed_parts:
        failed_parts.append(
            "Both constraints are individually reachable, but no single threshold meets BOTH at the same time on this grid."
        )

    policy_note = (
        "Constraints NOT jointly reachable.\n"
        "What failed:\n"
        "  - " + "\n  - ".join(failed_parts) + "\n"
        "Returned the threshold that minimizes total constraint violation (sum of gaps), then minimizes mean(|ΔFNR|), "
        "then prefers higher sensitivity and higher specificity.\n"
        f"At the chosen threshold, remaining gaps are:\n"
        f"  - sensitivity gap needed: {status['sens_gap_needed']:.6f}\n"
        f"  - specificity gap needed: {status['spec_gap_needed']:.6f}"
    )

print("\n================ POLICY RESULT ================")
print(POLICY_TEXT)
print("Constraint reached (both)?:", bool(constraint_reached))
print("Max achievable mean(sensitivity) on grid:", f"{max_mean_sens:.6f}", "| at threshold:", f"{thr_at_max_sens:.4f}")
print("Max achievable mean(specificity) on grid:", f"{max_mean_spec:.6f}", "| at threshold:", f"{thr_at_max_spec:.4f}")

print("\nChosen threshold:", f"{float(chosen['threshold']):.4f}")
print("Chosen mean(sensitivity):", f"{float(chosen['mean_sensitivity']):.6f}")
print("Chosen mean(specificity):", f"{float(chosen['mean_specificity']):.6f}")
print("Chosen mean(|ΔFNR|):", f"{float(chosen['mean_abs_deltaFNR']):.6f}")
print("Chosen mean signed ΔFNR (F-M):", f"{float(chosen['mean_signed_deltaFNR_F_minus_M']):.6f}")

print("\nDetails (no black box):")
print(policy_note)

# -------------------------
# 16) Save sweep results
# -------------------------
# Writes the full sweep table (all thresholds) and a compact JSON summary (chosen threshold + key stats).
sweep_csv = OUT_DIR / "sweep_table.csv"
sweep_df.to_csv(sweep_csv, index=False)

summary = {
    "train_dataset": "D7",
    "test_dataset": "D2",
    "split_swept": "D2 TEST",
    "required_exp_substring_case_insensitive": REQUIRED_EXP_SUBSTRING,
    "trainval_experiment_used": str(chosen_exp),
    "trainval_summary_path": str(summary_trainval_path),
    "seeds": SEEDS,
    "policy": POLICY_TEXT,
    "constraints": {
        "target_mean_sensitivity": float(TARGET_SENS),
        "min_mean_specificity": float(MIN_SPEC),
    },
    "constraint_reached_both": bool(constraint_reached),
    "max_mean_sensitivity_on_grid": float(max_mean_sens),
    "threshold_at_max_mean_sensitivity": float(thr_at_max_sens),
    "max_mean_specificity_on_grid": float(max_mean_spec),
    "threshold_at_max_mean_specificity": float(thr_at_max_spec),
    "youdenJ_mean_threshold_reference": (None if np.isnan(thr_mean) else float(thr_mean)),
    "chosen_threshold": float(chosen["threshold"]),
    "chosen_metrics": {
        "mean_sensitivity": float(chosen["mean_sensitivity"]),
        "mean_specificity": float(chosen["mean_specificity"]),
        "mean_abs_deltaFNR": float(chosen["mean_abs_deltaFNR"]),
        "mean_signed_deltaFNR_F_minus_M": float(chosen["mean_signed_deltaFNR_F_minus_M"]),
        "mean_accuracy": float(chosen["mean_accuracy"]),
        "mean_precision": float(chosen["mean_precision"]),
        "mean_f1": float(chosen["mean_f1"]),
        "mean_mcc": float(chosen["mean_mcc"]),
        "mean_fisher_p_two_sided": float(chosen["mean_fisher_p_two_sided"]),
        "mean_FNR_M": float(chosen["mean_FNR_M"]),
        "mean_FNR_F": float(chosen["mean_FNR_F"]),
        "sensitivity_by_seed": chosen.get("sensitivity_by_seed", ""),
        "specificity_by_seed": chosen.get("specificity_by_seed", ""),
        "abs_deltaFNR_by_seed": chosen.get("abs_deltaFNR_by_seed", ""),
        "signed_deltaFNR_by_seed": chosen.get("signed_deltaFNR_by_seed", ""),
        "n_PD_M_each_seed": chosen.get("n_PD_M_each_seed", ""),
        "n_PD_F_each_seed": chosen.get("n_PD_F_each_seed", ""),
    },
    "policy_note_transparent": policy_note,
    "paths": {
        "out_dir": str(OUT_DIR),
        "sweep_table_csv": str(sweep_csv),
    },
    "timestamp": timestamp,
}

summary_json = OUT_DIR / "sweep_summary.json"
with open(summary_json, "w", encoding="utf-8") as f:
    json.dump(summary, f, indent=2)

# -------------------------
# 17) Plots
# -------------------------
# Saves quick visuals for: sensitivity, specificity, |ΔFNR| vs threshold, and the sensitivity–fairness tradeoff.
plt.figure()
plt.plot(sweep_df["threshold"].values, sweep_df["mean_sensitivity"].values)
plt.axhline(TARGET_SENS, linestyle="--")
plt.axvline(float(chosen["threshold"]), linestyle="--")
plt.xlabel("Threshold")
plt.ylabel("Mean Sensitivity (across seeds)")
plt.title("Threshold Sweep: Mean Sensitivity vs Threshold (D7 trainEnh1 → D2 TEST)")
plt.tight_layout()
p1 = OUT_DIR / "sweep_sensitivity_vs_threshold.png"
plt.savefig(p1, dpi=150)
plt.close()

plt.figure()
plt.plot(sweep_df["threshold"].values, sweep_df["mean_specificity"].values)
plt.axhline(MIN_SPEC, linestyle="--")
plt.axvline(float(chosen["threshold"]), linestyle="--")
plt.xlabel("Threshold")
plt.ylabel("Mean Specificity (across seeds)")
plt.title("Threshold Sweep: Mean Specificity vs Threshold (D7 trainEnh1 → D2 TEST)")
plt.tight_layout()
p1b = OUT_DIR / "sweep_specificity_vs_threshold.png"
plt.savefig(p1b, dpi=150)
plt.close()

plt.figure()
plt.plot(sweep_df["threshold"].values, sweep_df["mean_abs_deltaFNR"].values)
plt.axvline(float(chosen["threshold"]), linestyle="--")
plt.xlabel("Threshold")
plt.ylabel("Mean |ΔFNR| (across seeds)")
plt.title("Threshold Sweep: Mean |ΔFNR| vs Threshold (D7 trainEnh1 → D2 TEST)")
plt.tight_layout()
p2 = OUT_DIR / "sweep_abs_deltaFNR_vs_threshold.png"
plt.savefig(p2, dpi=150)
plt.close()

plt.figure()
plt.plot(sweep_df["mean_sensitivity"].values, sweep_df["mean_abs_deltaFNR"].values)
plt.scatter([float(chosen["mean_sensitivity"])], [float(chosen["mean_abs_deltaFNR"])])
plt.axvline(TARGET_SENS, linestyle="--")
plt.xlabel("Mean Sensitivity (across seeds)")
plt.ylabel("Mean |ΔFNR| (across seeds)")
plt.title("Threshold Tradeoff: Sensitivity vs |ΔFNR| (D7 trainEnh1 → D2 TEST)")
plt.tight_layout()
p3 = OUT_DIR / "sweep_tradeoff.png"
plt.savefig(p3, dpi=150)
plt.close()

print("\n================ SAVED OUTPUTS ================")
print("OUT_DIR:", str(OUT_DIR))
print("Saved:", str(sweep_csv))
print("Saved:", str(summary_json))
print("Saved plots:")
print(" ", str(p1))
print(" ", str(p1b))
print(" ", str(p2))
print(" ", str(p3))
print("\nDone.")

The following cell runs a **post-training decision threshold sweep** for the **D7 trainEnh2 model evaluated on the D2 test split**, using models that have already been trained. No retraining is done in this step. The cell automatically finds the most recent D7 train/validation experiment whose folder name contains **“trainEnh2”** (matched case-insensitively) and checks that all required files exist, namely the training summary and the saved head weights for each of the three seeds. Using these fixed models, it runs one full inference pass per seed on the D2 test data to generate Parkinson’s probability scores, then studies how performance and sex-based fairness change as the decision threshold moves from 0.01 to 0.99 in 199 evenly spaced values.

The cell starts with basic setup and safety checks to ensure the run is clean and repeatable. It avoids import conflicts from locally named files, mounts Google Drive if needed, resolves the dataset paths, and prints key settings such as the compute device, batch size, threshold range, and policy targets. It then loads the D2 `manifest_all.csv`, confirms that required columns are present, verifies that the data belong to **dataset D2**, and filters the data to the test split only. Before inference begins, it checks that all referenced audio files exist and reports progress during this check.

Each test clip is assigned to a simple task group: **vowel** for sustained vowel recordings and **other** for all remaining speech tasks. Sex labels are standardized from the original D2 values by mapping `male` to **M** and `female` to **F**, with any other value treated as unknown. Audio is loaded at **16 kHz**, and attention masks are created so padded regions are ignored. For vowel clips, the mask also removes trailing near-silence so non-speech segments have less effect on the results.
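
The labeling and masking rules above can be sketched in plain Python. This is a minimal illustration, not the notebook's actual preprocessing code: the function names, the `UNK` tag, the `vowel_tasks` tuple, and the lowercase matching are assumptions made for the example, and the `1e-4` amplitude cutoff is borrowed from the `TINY_THRESH` setting in the cell below.

```python
def normalize_sex(raw) -> str:
    """Map raw D2 sex values to 'M' or 'F'; anything else becomes 'UNK'."""
    value = str(raw).strip().lower()  # case-insensitive matching is an assumption
    if value == "male":
        return "M"
    if value == "female":
        return "F"
    return "UNK"  # hypothetical tag for unknown values


def task_group(task, vowel_tasks=("vowel",)) -> str:
    """Return 'vowel' for sustained-vowel tasks and 'other' for everything else."""
    # vowel_tasks is a hypothetical list of task names that count as sustained vowels.
    return "vowel" if str(task).strip().lower() in vowel_tasks else "other"


def attention_mask(samples, is_vowel, tiny_thresh=1e-4):
    """Return a 0/1 mask per sample; vowel clips drop trailing near-silence."""
    mask = [1] * len(samples)
    if is_vowel:
        loud = [i for i, s in enumerate(samples) if abs(s) > tiny_thresh]
        if loud:  # keep the full mask if the whole clip is near-silent
            for i in range(loud[-1] + 1, len(samples)):
                mask[i] = 0
    return mask


clip = [0.0, 0.5, -0.3, 0.00001, 0.0]
vowel_mask = attention_mask(clip, is_vowel=True)    # trailing quiet samples masked out
speech_mask = attention_mask(clip, is_vowel=False)  # every sample kept
```

Only the tail of a vowel clip is trimmed in this sketch; quiet samples at the start or in the middle stay unmasked.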

Inference uses the same model structure as during training: a frozen speech feature extractor with two task-specific heads, one for vowel clips and one for non-vowel clips. For each seed, the saved head weights are loaded and the full D2 test set is processed once. The predicted probabilities are stored together with the true labels and sex information. **AUROC is computed for each seed and averaged across seeds** to provide a threshold-free reference before applying any decision threshold.
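
To make the two-head routing concrete, here is a small sketch that uses plain Python lists instead of tensors so the arithmetic stays visible. It assumes a frame-level mask is already available (the notebook derives it from the sample-level mask through the backbone) and it omits the LayerNorm and dropout that precede each head. The toy weights and helper names are illustrative only.

```python
def masked_mean_pool(frames, frame_mask):
    """Average the frame vectors whose mask is 1; padded frames are ignored."""
    kept = [f for f, m in zip(frames, frame_mask) if m]
    denom = max(1, len(kept))  # guards against an all-zero mask
    dim = len(frames[0])
    return [sum(f[d] for f in kept) / denom for d in range(dim)]


def linear(x, weight, bias):
    """Apply one linear layer: weight has one row per output logit."""
    return [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weight, bias)]


def route_logits(pooled, group, heads):
    """Pick the vowel head for vowel clips and the other head for all other clips."""
    weight, bias = heads["vowel" if group == "vowel" else "other"]
    return linear(pooled, weight, bias)


# Three frames of a 2-dimensional feature; the last frame is padding.
frames = [[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]]
pooled = masked_mean_pool(frames, frame_mask=[1, 1, 0])

# Toy head weights (hypothetical values, two logits each: HC and PD).
heads = {
    "vowel": ([[1.0, 0.0], [0.0, 1.0]], [0.0, 0.0]),
    "other": ([[0.0, 1.0], [1.0, 0.0]], [0.5, -0.5]),
}
vowel_logits = route_logits(pooled, "vowel", heads)
other_logits = route_logits(pooled, "reading", heads)
```

Because the padded frame is masked out, its large values do not affect the pooled vector, and the same pooled vector yields different logits depending on which head the clip is routed to.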

With predictions fixed, the cell performs the threshold sweep. At each threshold, standard classification metrics are calculated and averaged across the three seeds, including accuracy, precision, sensitivity, specificity, F1 score, Matthews correlation coefficient, and Fisher’s exact test p-value. Fairness is evaluated by comparing false negative rates for Parkinson’s cases between females and males, reported as both a signed difference and an absolute gap. The sweep table includes both the averaged values and the per-seed results so differences across seeds remain visible.
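
As a worked example of the per-threshold quantities, the sketch below computes sensitivity, specificity, and the signed sex gap ΔFNR = FNR(F) − FNR(M) on six made-up clips. The data and function names are invented for illustration. It returns NaN for empty groups, whereas the notebook's sweep code adds a small epsilon to some denominators, so edge-case values may differ.

```python
def confusion_counts(y_true, y_prob, thr):
    """Count TN, FP, FN, TP after predicting PD when probability >= thr."""
    tn = fp = fn = tp = 0
    for y, p in zip(y_true, y_prob):
        pred = 1 if p >= thr else 0
        if y == 1 and pred == 1:
            tp += 1
        elif y == 1:
            fn += 1
        elif pred == 1:
            fp += 1
        else:
            tn += 1
    return tn, fp, fn, tp


def sensitivity_specificity(y_true, y_prob, thr):
    tn, fp, fn, tp = confusion_counts(y_true, y_prob, thr)
    sens = tp / (tp + fn) if (tp + fn) else float("nan")
    spec = tn / (tn + fp) if (tn + fp) else float("nan")
    return sens, spec


def fnr_by_sex(y_true, y_prob, y_sex, thr):
    """False negative rate FN / (FN + TP) within each sex, on true PD cases only."""
    out = {}
    for group in ("M", "F"):
        idx = [i for i, s in enumerate(y_sex) if s == group and y_true[i] == 1]
        fn = sum(1 for i in idx if y_prob[i] < thr)
        out[group] = fn / len(idx) if idx else float("nan")
    return out


# Toy data: four PD clips (label 1) and two healthy clips (label 0).
y_true = [1, 1, 1, 1, 0, 0]
y_prob = [0.9, 0.4, 0.7, 0.8, 0.3, 0.6]
y_sex = ["M", "M", "F", "F", "M", "F"]

sens, spec = sensitivity_specificity(y_true, y_prob, thr=0.5)
rates = fnr_by_sex(y_true, y_prob, y_sex, thr=0.5)
signed_delta = rates["F"] - rates["M"]  # negative means fewer missed PD cases in females
abs_delta = abs(signed_delta)
```

At threshold 0.5 one of the two male PD clips is missed and neither female PD clip is, so the signed gap is negative while the absolute gap records only its size.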

A single operating threshold is then chosen using a clear rule called **Policy B+**. Among thresholds that meet two requirements at the same time, mean sensitivity of at least **0.60** and mean specificity of at least **0.50**, it picks the one with the smallest mean absolute sex gap in false negative rate, breaking ties by higher sensitivity, then higher specificity, then lower threshold. If no threshold meets both conditions, the cell reports which requirement failed and selects the threshold with the smallest total shortfall (the sum of the sensitivity and specificity gaps), again preferring a smaller mean absolute gap when shortfalls tie.
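
The selection rule can be sketched on a handful of sweep rows. This is a simplified stand-in for the pandas-based logic in the cell, assuming each row is a dict with the hypothetical keys `thr`, `sens`, `spec`, and `abs_dfnr`; the row values are made up.

```python
def choose_threshold(rows, target_sens=0.60, min_spec=0.50):
    """Return (chosen_row, constraints_reached) following the Policy B+ idea."""
    # Tie-breakers: smaller |dFNR|, higher sensitivity, higher specificity, lower threshold.
    def rank(r):
        return (r["abs_dfnr"], -r["sens"], -r["spec"], r["thr"])

    eligible = [r for r in rows if r["sens"] >= target_sens and r["spec"] >= min_spec]
    if eligible:
        return min(eligible, key=rank), True

    # Fallback: keep whichever constraints are individually reachable, if any row allows it.
    can_sens = any(r["sens"] >= target_sens for r in rows)
    can_spec = any(r["spec"] >= min_spec for r in rows)
    subset = [
        r for r in rows
        if (not can_sens or r["sens"] >= target_sens)
        and (not can_spec or r["spec"] >= min_spec)
    ]
    if not subset:
        subset = list(rows)

    def total_gap(r):
        return max(0.0, target_sens - r["sens"]) + max(0.0, min_spec - r["spec"])

    return min(subset, key=lambda r: (total_gap(r),) + rank(r)), False


# Made-up sweep rows for illustration only.
rows = [
    {"thr": 0.3, "sens": 0.80, "spec": 0.40, "abs_dfnr": 0.10},
    {"thr": 0.4, "sens": 0.70, "spec": 0.55, "abs_dfnr": 0.08},
    {"thr": 0.5, "sens": 0.62, "spec": 0.65, "abs_dfnr": 0.03},
    {"thr": 0.6, "sens": 0.50, "spec": 0.80, "abs_dfnr": 0.01},
]

chosen, reached = choose_threshold(rows)  # default targets 0.60 / 0.50
strict_chosen, strict_reached = choose_threshold(rows, target_sens=0.75, min_spec=0.60)
```

With the default targets, thresholds 0.4 and 0.5 are eligible and 0.5 wins on the smaller fairness gap. With the stricter targets no row satisfies both, so the row with the smallest combined shortfall (threshold 0.4) is returned and flagged as not meeting the constraints. Threshold 0.6 has the smallest gap overall but is never picked because it misses the sensitivity target.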

All outputs are saved to a timestamped folder under the D7 threshold sweep directory. This includes the full sweep table, a summary describing the selected experiment and threshold choice, and plots that show how sensitivity, specificity, and fairness change across thresholds, along with a tradeoff curve that clearly marks the chosen operating point.

In [None]:
# =========================
# Threshold Sweep With Fairness Guardrail (D7 trainEnh2 → D2 test)
# =========================
# Purpose: run inference once per seed using saved D7 trainEnh2 heads, then sweep decision thresholds on the same scores.
# Inputs: (1) saved model heads for the three seeds, (2) D2 manifest with test split clip paths and labels.
# Outputs: sweep_table.csv, sweep_summary.json, and a few sweep plots saved under a new timestamped sweep folder.
# =========================

import os, json, math, random, time, warnings
from pathlib import Path

import numpy as np
import pandas as pd
import soundfile as sf

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

from transformers import Wav2Vec2Model
from sklearn.metrics import roc_auc_score, confusion_matrix, matthews_corrcoef
import matplotlib.pyplot as plt
from tqdm.auto import tqdm

# -------------------------
# 0) Import safety check
# -------------------------
# Stops early if local files would override real PyTorch or Transformers imports.
if os.path.exists("/content/torch.py") or os.path.exists("/content/torch/__init__.py"):
    raise RuntimeError("Found local /content/torch.py or /content/torch/ that shadows PyTorch. Rename/remove it and restart runtime.")
if os.path.exists("/content/transformers.py") or os.path.exists("/content/transformers/__init__.py"):
    raise RuntimeError("Found local /content/transformers.py or /content/transformers/ that shadows Hugging Face Transformers. Rename/remove it and restart runtime.")

# -------------------------
# 1) Drive mount (best-effort)
# -------------------------
# Mounts Google Drive if running in Colab and not already mounted.
try:
    from google.colab import drive  # type: ignore
    if not os.path.isdir("/content/drive/MyDrive"):
        drive.mount("/content/drive")
except Exception:
    pass

# -------------------------
# 2) Root paths
# -------------------------
# Uses existing notebook globals if present; otherwise uses the fallbacks below.
D7_OUT_ROOT_FALLBACK = "/content/drive/MyDrive/AI_PD_Project/Datasets/D7-Multilingual (D1_D4_D5v2_D6)/preprocessed_v1"
D2_OUT_ROOT_FALLBACK = "/content/drive/MyDrive/AI_PD_Project/Datasets/D2-Slovak (EWA-DB)/EWA-DB/preprocessed_v1"

D7_OUT_ROOT = str(globals().get("D7_OUT_ROOT", D7_OUT_ROOT_FALLBACK))
D2_OUT_ROOT = str(globals().get("D2_OUT_ROOT", D2_OUT_ROOT_FALLBACK))
D2_MANIFEST_ALL = f"{D2_OUT_ROOT}/manifests/manifest_all.csv"

# Exposes these roots for downstream cells.
globals()["D7_OUT_ROOT"] = D7_OUT_ROOT
globals()["D2_OUT_ROOT"] = D2_OUT_ROOT

# -------------------------
# 3) Run configuration
# -------------------------
# Core settings: seeds, backbone, audio requirements, and inference batch sizing.
SEEDS          = [1337, 2024, 7777]
BACKBONE_CKPT  = "facebook/wav2vec2-base"
SR_EXPECTED    = 16000
TINY_THRESH    = 1e-4

PER_DEVICE_BS  = 16
NUM_WORKERS    = 0
PIN_MEMORY     = False
USE_AMP        = True
DROPOUT_P      = 0.2

# Experiment selector: matches the exp_* folder name for trainEnh2.
REQUIRED_EXP_SUBSTRING = "trainEnh2"  # case-insensitive match against exp_* folder name

# Policy B+: choose a threshold that reduces sex FNR gap while keeping sensitivity and specificity above minimums.
TARGET_SENS = 0.60
MIN_SPEC    = 0.50   # avoids trivial thresholds that label nearly everything as PD
POLICY_TEXT = (
    "Policy B+: minimize mean(|ΔFNR|) subject to "
    f"mean(sensitivity) >= {TARGET_SENS:.2f} AND mean(specificity) >= {MIN_SPEC:.2f}"
)

# Threshold grid for sweeping.
THR_MIN, THR_MAX, THR_STEPS = 0.01, 0.99, 199
THR_GRID = np.linspace(THR_MIN, THR_MAX, THR_STEPS).astype(np.float64)

# Device selection and minor numeric tuning.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
try:
    torch.set_float32_matmul_precision("high")
except Exception:
    pass

# Keeps output readable by filtering common non-critical warnings.
warnings.filterwarnings("ignore", message="None of the inputs have requires_grad=True. Gradients will be None")
warnings.filterwarnings("ignore", message="Passing `gradient_checkpointing` to a config initialization is deprecated")
warnings.filterwarnings("ignore", category=UserWarning, module="huggingface_hub")

# Prints key settings for quick verification.
print("D7_OUT_ROOT:", D7_OUT_ROOT)
print("D2_OUT_ROOT:", D2_OUT_ROOT)
print("D2_MANIFEST_ALL:", D2_MANIFEST_ALL)
print("DEVICE:", DEVICE)
if DEVICE.type == "cuda":
    print("GPU:", torch.cuda.get_device_name(0))
print("PER_DEVICE_BS:", PER_DEVICE_BS)
print("NUM_WORKERS:", NUM_WORKERS, "| PIN_MEMORY:", PIN_MEMORY)
print("USE_AMP:", bool(USE_AMP and DEVICE.type == "cuda"))
print("Required exp substring (case-insensitive):", REQUIRED_EXP_SUBSTRING)
print("Policy B+ constraints:")
print("  mean sensitivity >=", TARGET_SENS)
print("  mean specificity >=", MIN_SPEC)
print(f"Threshold sweep grid: {THR_MIN:.2f}..{THR_MAX:.2f} with {THR_STEPS} steps")

# -------------------------
# 4) Load D2 manifest and keep only test split
# -------------------------
# Reads the D2 manifest, checks required columns, confirms it is actually D2, then filters to split == "test".
if not os.path.exists(D2_MANIFEST_ALL):
    raise FileNotFoundError(f"Missing D2 manifest_all.csv: {D2_MANIFEST_ALL}")

m_all = pd.read_csv(D2_MANIFEST_ALL)

req_cols = {"split", "clip_path", "label_num", "task", "sex", "age"}
missing = [c for c in sorted(req_cols) if c not in m_all.columns]
if missing:
    raise ValueError(f"D2 manifest missing required columns: {missing}. Found: {list(m_all.columns)}")

# Uses the most common dataset label (when present) and filters to that dataset.
if "dataset" in m_all.columns and m_all["dataset"].notna().any():
    d2_dataset_id = str(m_all["dataset"].astype(str).value_counts(dropna=True).idxmax())
    m_all = m_all[m_all["dataset"].astype(str) == d2_dataset_id].copy()
else:
    d2_dataset_id = "DX"

# Hard stop if the manifest is not D2.
if d2_dataset_id != "D2":
    raise RuntimeError(
        f"Expected D2 dataset_id=='D2' but got {d2_dataset_id!r}. "
        "This usually means D2_OUT_ROOT is wrong or the manifest is not D2. "
        f"D2_OUT_ROOT={D2_OUT_ROOT}"
    )

# Keeps a stable set of columns used downstream (fills missing ones with NaN).
keep_cols = ["clip_path", "label_num", "task", "speaker_id", "sex", "age", "duration_sec", "split"]
for c in keep_cols:
    if c not in m_all.columns:
        m_all[c] = np.nan
m_all = m_all[keep_cols].copy()

# Test split only.
test_df = m_all[m_all["split"].astype(str) == "test"].reset_index(drop=True)

print(f"\nD2 dataset inferred: {d2_dataset_id}")
print(f"D2 TEST rows: {len(test_df)}")
print("D2 TEST label counts:", test_df["label_num"].value_counts(dropna=False).to_dict())
print("D2 TEST sex counts (raw):", test_df["sex"].value_counts(dropna=False).to_dict())

if len(test_df) == 0:
    raise RuntimeError("After filtering to split=='test', D2 manifest has 0 rows.")

# -------------------------
# 5) Fail fast: confirm audio files exist
# -------------------------
# Checks for missing clip paths early to avoid long runs that fail late.
def _fail_fast_missing_paths(df: pd.DataFrame, name: str):
    missing_paths = []
    for p in tqdm(df["clip_path"].astype(str).tolist(), desc=f"Check {name} clip_path exists", dynamic_ncols=True):
        if not os.path.exists(p):
            missing_paths.append(p)
            if len(missing_paths) >= 10:
                break
    if missing_paths:
        raise FileNotFoundError(f"{name}: missing clip_path(s). Examples: {missing_paths}")

_fail_fast_missing_paths(test_df, "D2 TEST")

# -------------------------
# 6) Task grouping (vowel vs other)
# -------------------------
# Maps each clip to the head it should use during inference.
def _task_group(task_val) -> str:
    return "vowel" if str(task_val) == "vowel" else "other"

test_df["task_group"] = test_df["task"].apply(_task_group)

# -------------------------
# 6.5) Sex normalization
# -------------------------
# Normalizes D2 "male"/"female" into "M"/"F" and maps anything else to "UNK".
def normalize_sex_d2_case_sensitive(val) -> str:
    if pd.isna(val):
        return "UNK"
    if val == "male":
        return "M"
    if val == "female":
        return "F"
    return "UNK"

test_df["sex_norm"] = test_df["sex"].apply(normalize_sex_d2_case_sensitive)
print("D2 TEST sex counts (normalized):", test_df["sex_norm"].value_counts(dropna=False).to_dict())
if (test_df["sex_norm"] == "UNK").any():
    print("NOTE: Some D2 'sex' values were not exactly 'male'/'female' and were mapped to 'UNK'.")

# -------------------------
# 7) Dataset and collator
# -------------------------
# Builds model inputs from audio files and creates an attention mask for pooling.
# For vowel clips, the mask removes trailing padding-like silence based on a tiny amplitude threshold.
class AudioManifestDataset(Dataset):
    def __init__(self, df: pd.DataFrame):
        self.df = df.reset_index(drop=True)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx: int):
        row = self.df.iloc[idx]
        clip_path = str(row["clip_path"])
        label = int(row["label_num"])
        task_group = str(row["task_group"])
        sex_norm = str(row["sex_norm"])

        if not os.path.exists(clip_path):
            raise FileNotFoundError(f"Missing clip_path: {clip_path}")

        y, sr = sf.read(clip_path, always_2d=False)
        if isinstance(y, np.ndarray) and y.ndim > 1:
            y = np.mean(y, axis=1)
        y = np.asarray(y, dtype=np.float32)

        if int(sr) != SR_EXPECTED:
            raise ValueError(f"Sample rate mismatch (expected {SR_EXPECTED}): {clip_path} has sr={sr}")

        # Attention mask: 1 = keep, 0 = ignore during pooling.
        attn = np.ones((len(y),), dtype=np.int64)
        if task_group == "vowel":
            # Zero out everything after the last non-tiny sample (trailing near-silence).
            # If no sample exceeds the threshold, the mask stays all ones.
            nonzero_idx = np.flatnonzero(np.abs(y.astype(np.float64)) > TINY_THRESH)
            if nonzero_idx.size > 0:
                attn[nonzero_idx[-1] + 1:] = 0

        return {
            "input_values": torch.from_numpy(y),
            "attention_mask": torch.from_numpy(attn),
            "labels": torch.tensor(label, dtype=torch.long),
            "task_group": task_group,
            "sex_norm": sex_norm,
        }

# Pads variable-length audio in a batch to the longest clip length.
def collate_fn(batch):
    max_len = int(max(b["input_values"].numel() for b in batch))
    input_vals, attn_masks, labels, task_groups, sex_norms = [], [], [], [], []
    for b in batch:
        x = b["input_values"]
        a = b["attention_mask"]
        pad = max_len - x.numel()
        if pad > 0:
            x = torch.cat([x, torch.zeros(pad, dtype=x.dtype)], dim=0)
            a = torch.cat([a, torch.zeros(pad, dtype=a.dtype)], dim=0)
        input_vals.append(x)
        attn_masks.append(a)
        labels.append(b["labels"])
        task_groups.append(b["task_group"])
        sex_norms.append(b["sex_norm"])
    return {
        "input_values": torch.stack(input_vals, dim=0),
        "attention_mask": torch.stack(attn_masks, dim=0),
        "labels": torch.stack(labels, dim=0),
        "task_group": task_groups,
        "sex_norm": sex_norms,
    }

# DataLoader over the D2 test split.
test_ds = AudioManifestDataset(test_df)
test_loader = DataLoader(
    test_ds,
    batch_size=PER_DEVICE_BS,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=PIN_MEMORY,
    collate_fn=collate_fn,
)

# -------------------------
# 8) Model definition (frozen backbone + two heads)
# -------------------------
# Loads a frozen Wav2Vec2 backbone and switches between vowel and other heads per clip.
class Wav2Vec2TwoHeadClassifier(nn.Module):
    def __init__(self, ckpt: str, dropout_p: float):
        super().__init__()
        self.backbone = Wav2Vec2Model.from_pretrained(ckpt, use_safetensors=True)
        for p in self.backbone.parameters():
            p.requires_grad = False

        H = int(self.backbone.config.hidden_size)
        self.pre_vowel = nn.Sequential(nn.LayerNorm(H), nn.Dropout(float(dropout_p)))
        self.pre_other = nn.Sequential(nn.LayerNorm(H), nn.Dropout(float(dropout_p)))
        self.head_vowel = nn.Linear(H, 2)
        self.head_other = nn.Linear(H, 2)

    # Mean-pools frame features using the attention mask (ignores masked regions).
    def masked_mean_pool(self, last_hidden, attn_mask_samples):
        feat_mask = self.backbone._get_feature_vector_attention_mask(last_hidden.shape[1], attn_mask_samples)
        feat_mask = feat_mask.to(last_hidden.device).unsqueeze(-1).type_as(last_hidden)
        masked = last_hidden * feat_mask
        denom = feat_mask.sum(dim=1).clamp(min=1.0)
        return masked.sum(dim=1) / denom

    # Runs the head in fp32 for stability even when AMP is enabled.
    def _heads_fp32(self, x_any: torch.Tensor, head: nn.Module) -> torch.Tensor:
        x = x_any.float()
        if x.is_cuda:
            with torch.amp.autocast(device_type="cuda", enabled=False):
                return head(x)
        return head(x)

    # Produces logits for each sample using the appropriate head (vowel or other).
    def forward_logits(self, input_values, attention_mask, task_group):
        with torch.no_grad():
            out = self.backbone(input_values=input_values, attention_mask=attention_mask)
            last_hidden = out.last_hidden_state

        pooled = self.masked_mean_pool(last_hidden, attention_mask)
        z_v = self.pre_vowel(pooled.float())
        z_o = self.pre_other(pooled.float())

        logits_v = self._heads_fp32(z_v, self.head_vowel)
        logits_o = self._heads_fp32(z_o, self.head_other)

        is_vowel = torch.tensor([tg == "vowel" for tg in task_group], device=pooled.device, dtype=torch.bool)
        logits = logits_o.clone()
        if is_vowel.any():
            logits[is_vowel] = logits_v[is_vowel]
        return logits

# Loads only the head weights saved during training (backbone stays the pretrained frozen model).
def load_heads_into_model(model: Wav2Vec2TwoHeadClassifier, best_heads_path: Path):
    if not best_heads_path.exists():
        raise FileNotFoundError(f"Missing best_heads.pt: {str(best_heads_path)}")
    state = torch.load(str(best_heads_path), map_location="cpu")
    model.pre_vowel.load_state_dict(state["pre_vowel"], strict=True)
    model.pre_other.load_state_dict(state["pre_other"], strict=True)
    model.head_vowel.load_state_dict(state["head_vowel"], strict=True)
    model.head_other.load_state_dict(state["head_other"], strict=True)
    return model

# -------------------------
# 9) Seed control and single-pass inference
# -------------------------
# Runs inference once per seed and returns fixed probabilities, labels, and sex tags for threshold sweeping.
def set_all_seeds(seed: int):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

def infer_once(seed: int, chosen_exp: Path, train_dataset_id: str):
    set_all_seeds(seed)
    best_heads_path = chosen_exp / f"run_{train_dataset_id}_seed{seed}" / "best_heads.pt"

    model = Wav2Vec2TwoHeadClassifier(BACKBONE_CKPT, dropout_p=DROPOUT_P).to(DEVICE)
    model = load_heads_into_model(model, best_heads_path)
    model.eval()

    use_amp = bool(USE_AMP and DEVICE.type == "cuda")
    all_probs, all_true, all_sex = [], [], []

    # Forward pass over D2 test once (no thresholding here).
    pbar = tqdm(test_loader, desc=f"[seed={seed}] Inference D2 TEST", dynamic_ncols=True)
    with torch.inference_mode():
        for batch in pbar:
            input_values = batch["input_values"].to(DEVICE, non_blocking=False)
            attention_mask = batch["attention_mask"].to(DEVICE, non_blocking=False)
            labels = batch["labels"].to(DEVICE, non_blocking=False)
            task_group = batch["task_group"]
            sex_norm = batch["sex_norm"]

            with torch.amp.autocast(device_type=DEVICE.type, enabled=use_amp):
                logits = model.forward_logits(input_values, attention_mask, task_group)

            # Probability for PD is the softmax class-1 probability.
            probs = torch.softmax(logits.float(), dim=-1)[:, 1].detach().cpu().numpy().astype(np.float64)
            all_probs.extend(probs.tolist())
            all_true.extend(labels.detach().cpu().numpy().astype(np.int64).tolist())
            all_sex.extend(list(sex_norm))

    y_true = np.asarray(all_true, dtype=np.int64)
    y_prob = np.asarray(all_probs, dtype=np.float64)
    y_sex  = np.asarray(all_sex, dtype=object)

    # AUROC is reported per seed as a threshold-independent reference.
    auc = float("nan")
    if len(np.unique(y_true)) >= 2:
        auc = float(roc_auc_score(y_true, y_prob))

    return {
        "seed": int(seed),
        "best_heads_path": str(best_heads_path),
        "y_true": y_true,
        "y_prob": y_prob,
        "y_sex": y_sex,
        "auroc": float(auc),
    }

# -------------------------
# 10) Metric helpers for sweeping thresholds
# -------------------------
# Computes standard metrics from a thresholded prediction and sex-based FNR gap on PD-only cases.
def confusion_counts(y_true, y_prob, thr):
    y_pred = (y_prob >= float(thr)).astype(np.int64)
    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
    TN, FP, FN, TP = int(cm[0, 0]), int(cm[0, 1]), int(cm[1, 0]), int(cm[1, 1])
    return TN, FP, FN, TP

def threshold_metrics(y_true, y_prob, thr):
    TN, FP, FN, TP = confusion_counts(y_true, y_prob, thr)
    eps = 1e-12
    acc  = (TP + TN) / max(1, (TP + TN + FP + FN))
    prec = TP / (TP + FP + eps)
    sens = TP / (TP + FN + eps)  # recall/sensitivity
    spec = TN / (TN + FP + eps)
    f1   = 2 * prec * sens / (prec + sens + eps)

    # MCC may be undefined if predictions collapse into a single class.
    try:
        y_pred = (y_prob >= float(thr)).astype(np.int64)
        mcc = float(matthews_corrcoef(y_true, y_pred)) if len(np.unique(y_pred)) > 1 else float("nan")
    except Exception:
        mcc = float("nan")

    # Fisher exact test on the 2x2 table (best-effort if scipy is available).
    pval = float("nan")
    try:
        from scipy.stats import fisher_exact  # type: ignore
        _, pval = fisher_exact([[TN, FP], [FN, TP]], alternative="two-sided")
        pval = float(pval)
    except Exception:
        pval = float("nan")

    return {
        "tn": TN, "fp": FP, "fn": FN, "tp": TP,
        "accuracy": float(acc),
        "precision": float(prec),
        "sensitivity": float(sens),
        "specificity": float(spec),
        "f1": float(f1),
        "mcc": float(mcc),
        "fisher_p_two_sided": float(pval),
    }

def fnr_by_sex_signed_delta(y_true, y_prob, y_sex, thr):
    # FNR is computed within each sex group on true PD cases only: FN / (FN + TP).
    y_pred = (y_prob >= float(thr)).astype(np.int64)
    out = {}

    for g in ["M", "F"]:
        mask_g = (y_sex == g)
        pos_mask = mask_g & (y_true == 1)
        n_pos = int(pos_mask.sum())
        if n_pos == 0:
            out[g] = {"n_pos": 0, "fn": 0, "tp": 0, "fnr": float("nan")}
            continue
        tp = int(((y_pred == 1) & pos_mask).sum())
        fn = int(((y_pred == 0) & pos_mask).sum())
        fnr = float(fn / max(1, (fn + tp)))
        out[g] = {"n_pos": int(n_pos), "fn": int(fn), "tp": int(tp), "fnr": float(fnr)}

    fnr_m = out["M"]["fnr"]
    fnr_f = out["F"]["fnr"]
    if (not np.isnan(fnr_m)) and (not np.isnan(fnr_f)):
        delta = float(fnr_f - fnr_m)     # ΔFNR = F - M (signed)
        absd  = float(abs(delta))
    else:
        delta = float("nan")
        absd  = float("nan")

    return out, delta, absd

# Small helpers to aggregate values across seeds while tolerating NaNs.
def mean_sd(vals):
    vals = np.asarray(vals, dtype=np.float64)
    mu = float(np.nanmean(vals)) if np.any(~np.isnan(vals)) else float("nan")
    sd = float(np.nanstd(vals, ddof=1)) if np.sum(~np.isnan(vals)) > 1 else 0.0
    return mu, sd

def safe_nanmean(x):
    x = np.asarray(x, dtype=np.float64)
    if np.any(~np.isnan(x)):
        return float(np.nanmean(x))
    return float("nan")

def safe_nansd(x):
    x = np.asarray(x, dtype=np.float64)
    if np.sum(~np.isnan(x)) > 1:
        return float(np.nanstd(x, ddof=1))
    return 0.0

# -------------------------
# 11) Find the most recent matching trainval experiment
# -------------------------
# Selects the newest exp_* folder that matches REQUIRED_EXP_SUBSTRING and has all seed heads plus summary_trainval.json.
TRAINVAL_ROOT = Path(D7_OUT_ROOT) / "trainval_runs"
if not TRAINVAL_ROOT.exists():
    raise FileNotFoundError(f"Missing trainval_runs folder under D7_OUT_ROOT: {str(TRAINVAL_ROOT)}")

exp_dirs = sorted([p for p in TRAINVAL_ROOT.glob("exp_*") if p.is_dir()], key=lambda p: p.stat().st_mtime, reverse=True)
if not exp_dirs:
    raise FileNotFoundError(f"No exp_* folders found under: {str(TRAINVAL_ROOT)}")

train_dataset_id = "D7"

def is_match(exp_path: Path, required_substring: str) -> bool:
    return required_substring.lower() in exp_path.name.lower()

def has_all_seeds_and_summary(exp_path: Path, dataset_id: str, seeds: list) -> bool:
    if not (exp_path / "summary_trainval.json").exists():
        return False
    for s in seeds:
        p = exp_path / f"run_{dataset_id}_seed{s}" / "best_heads.pt"
        if not p.exists():
            return False
    return True

chosen_exp = None
for ed in exp_dirs:
    if not is_match(ed, REQUIRED_EXP_SUBSTRING):
        continue
    if has_all_seeds_and_summary(ed, train_dataset_id, SEEDS):
        chosen_exp = ed
        break

if chosen_exp is None:
    examples = [p.name for p in exp_dirs[:10]]
    raise FileNotFoundError(
        "Could not find a recent D7 trainval experiment folder that:\n"
        f"  (1) contains substring '{REQUIRED_EXP_SUBSTRING}' (case-insensitive) in the exp folder name, and\n"
        f"  (2) contains best_heads.pt for all {len(SEEDS)} seeds + summary_trainval.json.\n"
        f"Checked under: {str(TRAINVAL_ROOT)}\n\n"
        f"Example exp_* folder names (most recent first):\n  - " + "\n  - ".join(examples) + "\n\n"
        "Fix: set REQUIRED_EXP_SUBSTRING to match the exact tag used in the exp folder name."
    )

print("\nUsing trainval experiment folder:")
print(" ", str(chosen_exp))

# Loads the trainval summary only to print the mean Youden-J threshold as a reference.
summary_trainval_path = chosen_exp / "summary_trainval.json"
with open(summary_trainval_path, "r", encoding="utf-8") as f:
    trainval_summary = json.load(f)

val_opt = ((trainval_summary or {}).get("val_optimal_threshold", {}) or {})
thr_mean = float((((val_opt.get("mean_sd", {}) or {}).get("mean", float("nan")))))
print("\nTrainval mean val-opt threshold (Youden J) for reference:")
print("  summary_trainval.json -> val_optimal_threshold.mean_sd.mean =", f"{thr_mean:.6f}" if not np.isnan(thr_mean) else "nan")

# -------------------------
# 12) Create output folder for this sweep run
# -------------------------
# Uses a timestamp so each sweep run writes to a fresh folder.
TS_ROOT = Path(D7_OUT_ROOT) / "threshold_sweeps"
TS_ROOT.mkdir(parents=True, exist_ok=True)
timestamp = time.strftime("%Y%m%d_%H%M%S")
OUT_DIR = TS_ROOT / f"run_D7_{REQUIRED_EXP_SUBSTRING}_on_D2test_{timestamp}"
OUT_DIR.mkdir(parents=True, exist_ok=True)

# -------------------------
# 13) Inference once per seed
# -------------------------
# Produces fixed prediction probabilities used by all threshold evaluations.
print("\nRunning inference ONCE per seed (then sweep thresholds on fixed predictions)...")
seed_payloads = []
for s in SEEDS:
    seed_payloads.append(infer_once(s, chosen_exp, train_dataset_id))

print("\nAUROC by seed (ranking metric, threshold-independent):")
for sp in seed_payloads:
    print(f"  seed {sp['seed']}: AUROC={sp['auroc']:.6f}")
auc_mean, auc_sd = mean_sd([sp["auroc"] for sp in seed_payloads])
print(f"Mean AUROC: {auc_mean:.6f} ± {auc_sd:.6f}")

# -------------------------
# 14) Threshold sweep across seeds
# -------------------------
# For each threshold: compute metrics per seed, then take the mean across seeds.
rows = []
for thr in tqdm(THR_GRID, desc="Threshold sweep", dynamic_ncols=True):
    sens_list, absd_list, signd_list = [], [], []
    acc_list, prec_list, spec_list, f1_list, mcc_list, p_list = [], [], [], [], [], []

    fnr_m_list, fnr_f_list = [], []
    n_pd_m_list, n_pd_f_list = [], []

    for sp in seed_payloads:
        y_true = sp["y_true"]
        y_prob = sp["y_prob"]
        y_sex  = sp["y_sex"]

        tm = threshold_metrics(y_true, y_prob, thr)
        fnr_by_sex, delta_signed, delta_abs = fnr_by_sex_signed_delta(y_true, y_prob, y_sex, thr)

        sens_list.append(tm["sensitivity"])
        acc_list.append(tm["accuracy"])
        prec_list.append(tm["precision"])
        spec_list.append(tm["specificity"])
        f1_list.append(tm["f1"])
        mcc_list.append(tm["mcc"])
        p_list.append(tm["fisher_p_two_sided"])

        signd_list.append(delta_signed)
        absd_list.append(delta_abs)

        fnr_m_list.append(float(fnr_by_sex["M"]["fnr"]))
        fnr_f_list.append(float(fnr_by_sex["F"]["fnr"]))
        n_pd_m_list.append(float(fnr_by_sex["M"]["n_pos"]))
        n_pd_f_list.append(float(fnr_by_sex["F"]["n_pos"]))

    row = {
        "threshold": float(thr),

        "mean_sensitivity": safe_nanmean(sens_list),
        "sd_sensitivity": safe_nansd(sens_list),

        "mean_specificity": safe_nanmean(spec_list),
        "sd_specificity": safe_nansd(spec_list),

        "mean_abs_deltaFNR": safe_nanmean(absd_list),
        "sd_abs_deltaFNR": safe_nansd(absd_list),

        "mean_signed_deltaFNR_F_minus_M": safe_nanmean(signd_list),
        "sd_signed_deltaFNR": safe_nansd(signd_list),

        "mean_accuracy": safe_nanmean(acc_list),
        "mean_precision": safe_nanmean(prec_list),
        "mean_f1": safe_nanmean(f1_list),
        "mean_mcc": safe_nanmean(mcc_list),
        "mean_fisher_p_two_sided": safe_nanmean(p_list),

        "mean_FNR_M": safe_nanmean(fnr_m_list),
        "mean_FNR_F": safe_nanmean(fnr_f_list),
        "n_PD_M_each_seed": json.dumps({str(SEEDS[i]): int(n_pd_m_list[i]) for i in range(len(SEEDS))}),
        "n_PD_F_each_seed": json.dumps({str(SEEDS[i]): int(n_pd_f_list[i]) for i in range(len(SEEDS))}),
    }

    # Stores per-seed values to explain the means later.
    row["sensitivity_by_seed"] = json.dumps({str(sp["seed"]): float(sens_list[i]) for i, sp in enumerate(seed_payloads)})
    row["specificity_by_seed"] = json.dumps({str(sp["seed"]): float(spec_list[i]) for i, sp in enumerate(seed_payloads)})
    row["abs_deltaFNR_by_seed"] = json.dumps({str(sp["seed"]): float(absd_list[i]) for i, sp in enumerate(seed_payloads)})
    row["signed_deltaFNR_by_seed"] = json.dumps({str(sp["seed"]): float(signd_list[i]) for i, sp in enumerate(seed_payloads)})

    rows.append(row)

sweep_df = pd.DataFrame(rows)

# -------------------------
# 15) Pick threshold using Policy B+
# -------------------------
# 1) If both constraints are reachable: choose the eligible threshold with the smallest mean |ΔFNR|.
# 2) Otherwise: choose the threshold with the smallest total shortfall from the constraints, and report the gaps.
max_mean_sens = float(sweep_df["mean_sensitivity"].max())
thr_at_max_sens = float(sweep_df.loc[sweep_df["mean_sensitivity"].idxmax(), "threshold"])

max_mean_spec = float(sweep_df["mean_specificity"].max())
thr_at_max_spec = float(sweep_df.loc[sweep_df["mean_specificity"].idxmax(), "threshold"])

eligible = sweep_df[
    (sweep_df["mean_sensitivity"] >= TARGET_SENS) &
    (sweep_df["mean_specificity"] >= MIN_SPEC)
].copy()

constraint_reached = (len(eligible) > 0)

def _constraint_status_at_row(row_dict):
    sens_gap = float(TARGET_SENS - row_dict["mean_sensitivity"])
    spec_gap = float(MIN_SPEC - row_dict["mean_specificity"])
    return {
        "sens_ok": bool(row_dict["mean_sensitivity"] >= TARGET_SENS),
        "spec_ok": bool(row_dict["mean_specificity"] >= MIN_SPEC),
        "sens_gap_needed": float(max(0.0, sens_gap)),
        "spec_gap_needed": float(max(0.0, spec_gap)),
    }

if constraint_reached:
    eligible = eligible.sort_values(
        by=["mean_abs_deltaFNR", "mean_sensitivity", "mean_specificity", "threshold"],
        ascending=[True, False, False, True]
    )
    chosen = eligible.iloc[0].to_dict()

    policy_note = (
        "Constraints reached:\n"
        f"  - mean(sensitivity) >= {TARGET_SENS:.2f}\n"
        f"  - mean(specificity) >= {MIN_SPEC:.2f}\n"
        "Chosen threshold minimizes mean(|ΔFNR|) among eligible thresholds "
        "(tie-breakers: higher sensitivity, higher specificity, then lower threshold)."
    )
else:
    can_reach_sens = bool((sweep_df["mean_sensitivity"] >= TARGET_SENS).any())
    can_reach_spec = bool((sweep_df["mean_specificity"] >= MIN_SPEC).any())

    working_df = sweep_df.copy()
    subset = working_df.copy()
    if can_reach_sens:
        subset = subset[subset["mean_sensitivity"] >= TARGET_SENS].copy()
    if can_reach_spec:
        subset = subset[subset["mean_specificity"] >= MIN_SPEC].copy()

    if len(subset) == 0:
        subset = working_df.copy()

    subset["_sens_gap"] = np.maximum(0.0, TARGET_SENS - subset["mean_sensitivity"])
    subset["_spec_gap"] = np.maximum(0.0, MIN_SPEC - subset["mean_specificity"])
    subset["_total_gap"] = subset["_sens_gap"] + subset["_spec_gap"]

    subset = subset.sort_values(
        by=["_total_gap", "mean_abs_deltaFNR", "mean_sensitivity", "mean_specificity", "threshold"],
        ascending=[True, True, False, False, True]
    )

    chosen = subset.iloc[0].drop(labels=["_sens_gap", "_spec_gap", "_total_gap"]).to_dict()
    status = _constraint_status_at_row(chosen)

    failed_parts = []
    if not can_reach_sens:
        failed_parts.append(
            f"mean(sensitivity) >= {TARGET_SENS:.2f} (UNREACHABLE on this grid; max was {max_mean_sens:.6f} at thr={thr_at_max_sens:.4f})"
        )
    if not can_reach_spec:
        failed_parts.append(
            f"mean(specificity) >= {MIN_SPEC:.2f} (UNREACHABLE on this grid; max was {max_mean_spec:.6f} at thr={thr_at_max_spec:.4f})"
        )
    if not failed_parts:
        failed_parts.append(
            "Both constraints are individually reachable, but no single threshold meets BOTH at the same time on this grid."
        )

    policy_note = (
        "Constraints NOT jointly reachable.\n"
        "What failed:\n"
        "  - " + "\n  - ".join(failed_parts) + "\n"
        "Returned the threshold that minimizes total constraint violation (sum of gaps), then minimizes mean(|ΔFNR|), "
        "then prefers higher sensitivity and higher specificity.\n"
        f"At the chosen threshold, remaining gaps are:\n"
        f"  - sensitivity gap needed: {status['sens_gap_needed']:.6f}\n"
        f"  - specificity gap needed: {status['spec_gap_needed']:.6f}"
    )

print("\n================ POLICY RESULT ================")
print(POLICY_TEXT)
print("Constraint reached (both)?:", bool(constraint_reached))
print("Max achievable mean(sensitivity) on grid:", f"{max_mean_sens:.6f}", "| at threshold:", f"{thr_at_max_sens:.4f}")
print("Max achievable mean(specificity) on grid:", f"{max_mean_spec:.6f}", "| at threshold:", f"{thr_at_max_spec:.4f}")

print("\nChosen threshold:", f"{float(chosen['threshold']):.4f}")
print("Chosen mean(sensitivity):", f"{float(chosen['mean_sensitivity']):.6f}")
print("Chosen mean(specificity):", f"{float(chosen['mean_specificity']):.6f}")
print("Chosen mean(|ΔFNR|):", f"{float(chosen['mean_abs_deltaFNR']):.6f}")
print("Chosen mean signed ΔFNR (F-M):", f"{float(chosen['mean_signed_deltaFNR_F_minus_M']):.6f}")

print("\nDetails (no black box):")
print(policy_note)

# -------------------------
# 16) Save sweep results
# -------------------------
# Writes the full sweep table (all thresholds) and a compact JSON summary (chosen threshold + key stats).
sweep_csv = OUT_DIR / "sweep_table.csv"
sweep_df.to_csv(sweep_csv, index=False)

summary = {
    "train_dataset": "D7",
    "test_dataset": "D2",
    "split_swept": "D2 TEST",
    "required_exp_substring_case_insensitive": REQUIRED_EXP_SUBSTRING,
    "trainval_experiment_used": str(chosen_exp),
    "trainval_summary_path": str(summary_trainval_path),
    "seeds": SEEDS,
    "policy": POLICY_TEXT,
    "constraints": {
        "target_mean_sensitivity": float(TARGET_SENS),
        "min_mean_specificity": float(MIN_SPEC),
    },
    "constraint_reached_both": bool(constraint_reached),
    "max_mean_sensitivity_on_grid": float(max_mean_sens),
    "threshold_at_max_mean_sensitivity": float(thr_at_max_sens),
    "max_mean_specificity_on_grid": float(max_mean_spec),
    "threshold_at_max_mean_specificity": float(thr_at_max_spec),
    "youdenJ_mean_threshold_reference": (None if np.isnan(thr_mean) else float(thr_mean)),
    "chosen_threshold": float(chosen["threshold"]),
    "chosen_metrics": {
        "mean_sensitivity": float(chosen["mean_sensitivity"]),
        "mean_specificity": float(chosen["mean_specificity"]),
        "mean_abs_deltaFNR": float(chosen["mean_abs_deltaFNR"]),
        "mean_signed_deltaFNR_F_minus_M": float(chosen["mean_signed_deltaFNR_F_minus_M"]),
        "mean_accuracy": float(chosen["mean_accuracy"]),
        "mean_precision": float(chosen["mean_precision"]),
        "mean_f1": float(chosen["mean_f1"]),
        "mean_mcc": float(chosen["mean_mcc"]),
        "mean_fisher_p_two_sided": float(chosen["mean_fisher_p_two_sided"]),
        "mean_FNR_M": float(chosen["mean_FNR_M"]),
        "mean_FNR_F": float(chosen["mean_FNR_F"]),
        "sensitivity_by_seed": chosen.get("sensitivity_by_seed", ""),
        "specificity_by_seed": chosen.get("specificity_by_seed", ""),
        "abs_deltaFNR_by_seed": chosen.get("abs_deltaFNR_by_seed", ""),
        "signed_deltaFNR_by_seed": chosen.get("signed_deltaFNR_by_seed", ""),
        "n_PD_M_each_seed": chosen.get("n_PD_M_each_seed", ""),
        "n_PD_F_each_seed": chosen.get("n_PD_F_each_seed", ""),
    },
    "policy_note_transparent": policy_note,
    "paths": {
        "out_dir": str(OUT_DIR),
        "sweep_table_csv": str(sweep_csv),
    },
    "timestamp": timestamp,
}

summary_json = OUT_DIR / "sweep_summary.json"
with open(summary_json, "w", encoding="utf-8") as f:
    json.dump(summary, f, indent=2)

# -------------------------
# 17) Plots
# -------------------------
# Saves quick visuals for: sensitivity, specificity, |ΔFNR| vs threshold, and the sensitivity–fairness tradeoff.
plt.figure()
plt.plot(sweep_df["threshold"].values, sweep_df["mean_sensitivity"].values)
plt.axhline(TARGET_SENS, linestyle="--")
plt.axvline(float(chosen["threshold"]), linestyle="--")
plt.xlabel("Threshold")
plt.ylabel("Mean Sensitivity (across seeds)")
plt.title(f"Threshold Sweep: Mean Sensitivity vs Threshold (D7 {REQUIRED_EXP_SUBSTRING} → D2 TEST)")
plt.tight_layout()
p1 = OUT_DIR / "sweep_sensitivity_vs_threshold.png"
plt.savefig(p1, dpi=150)
plt.close()

plt.figure()
plt.plot(sweep_df["threshold"].values, sweep_df["mean_specificity"].values)
plt.axhline(MIN_SPEC, linestyle="--")
plt.axvline(float(chosen["threshold"]), linestyle="--")
plt.xlabel("Threshold")
plt.ylabel("Mean Specificity (across seeds)")
plt.title(f"Threshold Sweep: Mean Specificity vs Threshold (D7 {REQUIRED_EXP_SUBSTRING} → D2 TEST)")
plt.tight_layout()
p1b = OUT_DIR / "sweep_specificity_vs_threshold.png"
plt.savefig(p1b, dpi=150)
plt.close()

plt.figure()
plt.plot(sweep_df["threshold"].values, sweep_df["mean_abs_deltaFNR"].values)
plt.axvline(float(chosen["threshold"]), linestyle="--")
plt.xlabel("Threshold")
plt.ylabel("Mean |ΔFNR| (across seeds)")
plt.title(f"Threshold Sweep: Mean |ΔFNR| vs Threshold (D7 {REQUIRED_EXP_SUBSTRING} → D2 TEST)")
plt.tight_layout()
p2 = OUT_DIR / "sweep_abs_deltaFNR_vs_threshold.png"
plt.savefig(p2, dpi=150)
plt.close()

plt.figure()
plt.plot(sweep_df["mean_sensitivity"].values, sweep_df["mean_abs_deltaFNR"].values)
plt.scatter([float(chosen["mean_sensitivity"])], [float(chosen["mean_abs_deltaFNR"])])
plt.axvline(TARGET_SENS, linestyle="--")
plt.xlabel("Mean Sensitivity (across seeds)")
plt.ylabel("Mean |ΔFNR| (across seeds)")
plt.title(f"Threshold Tradeoff: Sensitivity vs |ΔFNR| (D7 {REQUIRED_EXP_SUBSTRING} → D2 TEST)")
plt.tight_layout()
p3 = OUT_DIR / "sweep_tradeoff.png"
plt.savefig(p3, dpi=150)
plt.close()

print("\n================ SAVED OUTPUTS ================")
print("OUT_DIR:", str(OUT_DIR))
print("Saved:", str(sweep_csv))
print("Saved:", str(summary_json))
print("Saved plots:")
print(" ", str(p1))
print(" ", str(p1b))
print(" ", str(p2))
print(" ", str(p3))
print("\nDone.")

The following cell runs a **post-training threshold selection study** for the **D7 trainEnh3 model evaluated on the D2 test set**, using models that are already trained. No retraining is done. The cell finds the most recent D7 training experiment linked to **trainEnh3**, runs inference once per random seed to get a fixed set of prediction scores, and then studies how performance and sex-based fairness change as the decision threshold varies. The goal is to choose one operating threshold using a clear rule called **Policy B+**, which focuses on fairness while still meeting minimum sensitivity and specificity targets.

The cell starts with setup checks to ensure a clean run and to avoid import issues caused by local files or folders named `torch` or `transformers`. It mounts Google Drive if needed, resolves dataset paths from existing notebook variables or fallback defaults, and defines all fixed settings used in the study. These include the three random seeds (1337, 2024, 7777), the speech backbone checkpoint (`facebook/wav2vec2-base`), the required 16 kHz sample rate, batch size, mixed-precision settings, the text used to identify the correct training experiment, and a threshold grid of 199 evenly spaced values from 0.01 to 0.99. The Policy B+ rules are also defined here: among thresholds whose mean sensitivity across seeds is at least 0.60 and whose mean specificity is at least 0.50, preference is given to the one with the smallest mean absolute sex gap in false negative rate.
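
As a rough illustration of the grid described above, the sketch below rebuilds the 199 evenly spaced thresholds in plain Python (the cell itself uses `np.linspace`). It shows that the spacing works out to about 0.00495 rather than a round 0.005. The helper name `linspace_list` is hypothetical and exists only for this example.

```python
# Illustrative only: pure-Python stand-in for np.linspace(0.01, 0.99, 199).
THR_MIN, THR_MAX, THR_STEPS = 0.01, 0.99, 199

def linspace_list(lo, hi, n):
    """Return n evenly spaced values from lo to hi, both endpoints included."""
    step = (hi - lo) / (n - 1)
    return [lo + i * step for i in range(n)]

grid = linspace_list(THR_MIN, THR_MAX, THR_STEPS)
spacing = grid[1] - grid[0]  # 0.98 / 198, i.e. roughly 0.00495
print(len(grid), round(grid[0], 4), round(grid[-1], 4), round(spacing, 5))
```

A grid with a round 0.005 step would need 197 points instead; the 199-point grid is kept here because it is what the cell defines.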

Next, the D2 `manifest_all.csv` file is loaded and checked for required columns and correct dataset identity. The data are filtered to the **test split only**, and basic counts by label and sex are printed. A fail-fast check scans every listed audio path before any model work begins and raises an error showing up to 10 missing files. Two standardized fields are added for later analysis: a task group that separates sustained vowel clips from all other speech tasks, and a normalized sex field with values **M**, **F**, or **UNK**, mapped from the exact D2 labels `male` and `female`. A dataset and data loader are then created to load audio files, reject any clip whose sample rate is not 16 kHz, and build attention masks that ignore padded regions. For vowel clips, every sample after the last one whose absolute amplitude exceeds 1e-4 is also masked out, which reduces the effect of trailing silence.
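
The trailing-silence handling for vowel clips can be sketched in a few lines of plain Python. This is a simplified stand-in for the logic inside the dataset class, not the class itself; the function name `trailing_silence_mask` is hypothetical, and the cell works on NumPy arrays rather than lists.

```python
# Illustrative only: plain-Python version of the trailing-silence mask for vowel clips.
TINY_THRESH = 1e-4  # same value as in the cell's settings

def trailing_silence_mask(samples, tiny=TINY_THRESH):
    """Return a 0/1 mask that zeroes everything after the last non-tiny sample."""
    last_loud = -1
    for i in range(len(samples) - 1, -1, -1):
        if abs(samples[i]) > tiny:
            last_loud = i
            break
    if last_loud < 0:
        # Every sample is tiny: keep the whole clip rather than masking it all out.
        return [1] * len(samples)
    return [1] * (last_loud + 1) + [0] * (len(samples) - last_loud - 1)

mask = trailing_silence_mask([0.0, 0.5, -0.3, 0.00001, 0.0])
silent = trailing_silence_mask([0.0, 0.00005, 0.0])
print(mask, silent)
```

Note that only the tail is masked: leading silence and quiet stretches in the middle of a clip still count during pooling.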

The cell then finds the correct trained model heads by searching the D7 `trainval_runs` folder for the most recently modified `exp_*` folder whose name includes the trainEnh3 identifier (case-insensitive) and that contains all required files: one `best_heads.pt` per seed plus `summary_trainval.json`. The selected experiment folder is printed, and the mean validation-optimal (Youden J) threshold saved during training is read only as a reference. A new timestamped output folder is created so results from this sweep do not overwrite earlier runs.
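
The folder search can be tried out on a throwaway directory tree. The sketch below is a minimal version of that lookup under stated assumptions: `find_latest_experiment` and `_make_exp` are hypothetical helpers, the toy folder names are invented, and the placeholder files are empty. Only the file layout (`run_D7_seed<seed>/best_heads.pt` and `summary_trainval.json`) follows the cell.

```python
import os
import tempfile
from pathlib import Path

def find_latest_experiment(root, tag, seeds, dataset_id="D7"):
    """Return the newest exp_* folder matching tag with all artifacts, else None."""
    candidates = sorted(
        (p for p in Path(root).glob("exp_*") if p.is_dir()),
        key=lambda p: p.stat().st_mtime,
        reverse=True,  # newest first
    )
    for exp in candidates:
        if tag.lower() not in exp.name.lower():
            continue
        heads_ok = all(
            (exp / f"run_{dataset_id}_seed{s}" / "best_heads.pt").exists() for s in seeds
        )
        if heads_ok and (exp / "summary_trainval.json").exists():
            return exp
    return None

def _make_exp(root, name, seeds, mtime):
    """Create a toy experiment folder with empty placeholder files."""
    exp = Path(root) / name
    for s in seeds:
        run = exp / f"run_D7_seed{s}"
        run.mkdir(parents=True)
        (run / "best_heads.pt").touch()
    (exp / "summary_trainval.json").touch()
    os.utime(exp, (mtime, mtime))  # set the folder's modification time explicitly

SEEDS = [1337, 2024, 7777]
with tempfile.TemporaryDirectory() as tmp:
    _make_exp(tmp, "exp_old_trainEnh3", SEEDS, 1000)
    _make_exp(tmp, "exp_mid_TRAINENH3", SEEDS, 2000)
    _make_exp(tmp, "exp_new_trainEnh3", SEEDS[:2], 3000)  # one seed missing
    _make_exp(tmp, "exp_newest_baseline", SEEDS, 4000)    # wrong tag
    found = find_latest_experiment(tmp, "trainEnh3", SEEDS)
    found_name = found.name if found else None
print(found_name)
```

The newest matching folder is skipped because one seed is missing, so the search falls back to the next most recent complete folder. This is why a half-finished training run cannot be picked up by accident.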

Inference is run once per seed. For each seed, random states are fixed for repeatability, the frozen speech backbone and two task-specific heads are rebuilt, and the head weights saved for that seed during training are loaded. The model is then run in evaluation mode on the full D2 test set to produce one Parkinson’s disease probability score per clip. **AUROC is reported for each seed and summarized across seeds** as a threshold-independent performance check. The true labels, predicted probabilities, and sex labels are stored and reused for all threshold analysis.
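
To see why AUROC does not depend on the decision threshold, it helps to write it as a ranking quantity: the fraction of (Parkinson’s, healthy) clip pairs in which the Parkinson’s clip receives the higher score. The sketch below uses that pairwise definition in plain Python. It is an illustration only; the cell calls `sklearn.metrics.roc_auc_score`, and `auroc_pairwise` is a hypothetical name.

```python
def auroc_pairwise(y_true, y_prob):
    """AUROC as the share of (PD, HC) pairs where the PD clip scores higher; ties count 0.5."""
    pos = [p for t, p in zip(y_true, y_prob) if t == 1]
    neg = [p for t, p in zip(y_true, y_prob) if t == 0]
    if not pos or not neg:
        return float("nan")  # undefined with a single class, as in the cell
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(pos) * len(neg))

# Toy scores: three of the four (PD, HC) pairs are ranked correctly.
auc = auroc_pairwise([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8])
print(auc)
```

Because only the ordering of scores matters, the same AUROC applies to every threshold in the sweep. The quadratic loop is fine for a toy example but would be slow on a full test set.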

The threshold sweep then evaluates every value in the predefined grid. For each threshold and each seed, a clip is classified as Parkinson’s when its score is at or above the threshold, and standard classification metrics are computed, including accuracy, precision, sensitivity, specificity, F1 score, Matthews correlation coefficient, and the two-sided Fisher’s exact test p-value. Fairness is measured using false negative rates among Parkinson’s cases, calculated separately for males and females, along with the signed (female minus male) and absolute sex gaps. Results are then averaged across seeds, with standard deviations kept for sensitivity, specificity, and both sex gaps, and per-seed values preserved alongside. All threshold results are collected into a single table.
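
The fairness measure can be worked through on a handful of made-up clips. The sketch below computes the false negative rate per sex among true Parkinson’s cases and the signed and absolute gaps at a single threshold. The function name `fnr_gap_by_sex` and all scores are invented for illustration; the cell uses NumPy arrays and its own helper.

```python
def fnr_gap_by_sex(y_true, y_prob, y_sex, thr):
    """Return (FNR_M, FNR_F, signed gap F - M, absolute gap) at one threshold."""
    fnr = {}
    for group in ("M", "F"):
        # Keep only true PD clips (label 1) from this sex group.
        scores = [p for t, p, s in zip(y_true, y_prob, y_sex) if t == 1 and s == group]
        if not scores:
            fnr[group] = float("nan")  # no PD clips in this group
            continue
        missed = sum(1 for p in scores if p < thr)  # predicted HC although PD
        fnr[group] = missed / len(scores)
    delta = fnr["F"] - fnr["M"]
    return fnr["M"], fnr["F"], delta, abs(delta)

# Toy data: 2 male PD clips, 3 female PD clips, 2 healthy controls.
y_true = [1, 1, 1, 1, 1, 0, 0]
y_prob = [0.9, 0.2, 0.7, 0.6, 0.3, 0.1, 0.8]
y_sex = ["M", "M", "F", "F", "F", "M", "F"]

fnr_m, fnr_f, delta, abs_delta = fnr_gap_by_sex(y_true, y_prob, y_sex, thr=0.5)
print(fnr_m, fnr_f, delta, abs_delta)
```

Here one of two male cases and one of three female cases are missed, so the signed gap is negative (females are missed less often) and the absolute gap is one sixth. Healthy controls never enter the calculation, and clips with sex `UNK` are ignored as well.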

Finally, **Policy B+** is applied in a transparent way. The cell first checks whether any thresholds meet both the sensitivity and specificity requirements. If so, it selects the threshold with the smallest average absolute sex gap, using higher sensitivity and specificity as tie-breakers. If no threshold meets both requirements, it reports which conditions failed and selects the closest available option based on how far the results fall short, again favoring smaller fairness gaps and stronger performance. The chosen threshold and its main metrics are printed with a short explanation.
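
The selection rule can be sketched on a toy sweep table. The eligible branch follows the description above directly. The fallback branch is an assumption made for this example: it ranks thresholds by the summed shortfall on the two constraints, which may differ from the exact distance used in the cell. The function name, row keys, and all numbers are hypothetical.

```python
TARGET_SENS, MIN_SPEC = 0.60, 0.50  # same constraint values as in the cell's settings

def choose_threshold(rows, target_sens=TARGET_SENS, min_spec=MIN_SPEC):
    """Return (chosen_row, constraints_met) for a list of per-threshold summary rows."""
    eligible = [r for r in rows if r["sens"] >= target_sens and r["spec"] >= min_spec]
    if eligible:
        # Smallest mean |dFNR| wins; higher sensitivity, then specificity, break ties.
        best = min(eligible, key=lambda r: (r["gap"], -r["sens"], -r["spec"]))
        return best, True

    def shortfall(r):
        # ASSUMPTION: total distance below the two targets (0 when a target is met).
        return max(0.0, target_sens - r["sens"]) + max(0.0, min_spec - r["spec"])

    best = min(rows, key=lambda r: (shortfall(r), r["gap"], -r["sens"], -r["spec"]))
    return best, False

# Invented sweep rows: mean sensitivity, mean specificity, mean |dFNR| per threshold.
rows = [
    {"thr": 0.30, "sens": 0.80, "spec": 0.40, "gap": 0.02},  # fails specificity
    {"thr": 0.40, "sens": 0.70, "spec": 0.55, "gap": 0.10},
    {"thr": 0.45, "sens": 0.65, "spec": 0.60, "gap": 0.05},
    {"thr": 0.60, "sens": 0.50, "spec": 0.80, "gap": 0.01},  # fails sensitivity
]
chosen, ok = choose_threshold(rows)
print(chosen["thr"], ok)
```

The thresholds with the smallest gaps (0.30 and 0.60) are rejected because each breaks one constraint, which is the point of the guardrail: fairness is optimized only among operating points that remain clinically usable.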

All outputs are written to the timestamped sweep folder. These include the full sweep table, a summary file explaining the threshold choice, and plots showing how sensitivity, specificity, and fairness change with threshold, along with a trade-off curve that highlights the selected operating point.

In [None]:
# =========================
# Post Hoc Threshold Sweep With Fairness Guardrail
# =========================
# Runs inference one time per seed using an existing trained D7 trainEnh3 head,
# then sweeps decision thresholds on the fixed probabilities to study tradeoffs.
# Inputs: saved trained heads (per seed) and the D2 test manifest with clip paths.
# Outputs: a sweep table, a short JSON summary, and a few PNG plots saved in a new sweep folder.
# =========================

import os, json, math, random, time, warnings
from pathlib import Path

import numpy as np
import pandas as pd
import soundfile as sf

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

from transformers import Wav2Vec2Model
from sklearn.metrics import roc_auc_score, confusion_matrix, matthews_corrcoef
import matplotlib.pyplot as plt
from tqdm.auto import tqdm

# -------------------------
# Safety checks for common Colab import conflicts
# -------------------------
# Stop early if local files would override real PyTorch or Transformers imports.
if os.path.exists("/content/torch.py") or os.path.exists("/content/torch/__init__.py"):
    raise RuntimeError("Found local /content/torch.py or /content/torch/ that shadows PyTorch. Rename/remove it and restart runtime.")
if os.path.exists("/content/transformers.py") or os.path.exists("/content/transformers/__init__.py"):
    raise RuntimeError("Found local /content/transformers.py or /content/transformers/ that shadows Hugging Face Transformers. Rename/remove it and restart runtime.")

# -------------------------
# Drive mount (best-effort)
# -------------------------
# Mount Drive if running in Colab and not already mounted.
try:
    from google.colab import drive  # type: ignore
    if not os.path.isdir("/content/drive/MyDrive"):
        drive.mount("/content/drive")
except Exception:
    pass

# -------------------------
# Root paths used by this run
# -------------------------
# Use notebook globals if already set, otherwise fall back to the defaults below.
D7_OUT_ROOT_FALLBACK = "/content/drive/MyDrive/AI_PD_Project/Datasets/D7-Multilingual (D1_D4_D5v2_D6)/preprocessed_v1"
D2_OUT_ROOT_FALLBACK = "/content/drive/MyDrive/AI_PD_Project/Datasets/D2-Slovak (EWA-DB)/EWA-DB/preprocessed_v1"

D7_OUT_ROOT = str(globals().get("D7_OUT_ROOT", D7_OUT_ROOT_FALLBACK))
D2_OUT_ROOT = str(globals().get("D2_OUT_ROOT", D2_OUT_ROOT_FALLBACK))
D2_MANIFEST_ALL = f"{D2_OUT_ROOT}/manifests/manifest_all.csv"

# Re-export for downstream cells that may rely on these globals.
globals()["D7_OUT_ROOT"] = D7_OUT_ROOT
globals()["D2_OUT_ROOT"] = D2_OUT_ROOT

# -------------------------
# Fixed run settings
# -------------------------
# Seeds: run inference once per seed, then average metrics across seeds during the sweep.
SEEDS          = [1337, 2024, 7777]

# Backbone and audio expectations must match training.
BACKBONE_CKPT  = "facebook/wav2vec2-base"
SR_EXPECTED    = 16000
TINY_THRESH    = 1e-4

# DataLoader and inference settings.
PER_DEVICE_BS  = 16
NUM_WORKERS    = 0
PIN_MEMORY     = False
USE_AMP        = True
DROPOUT_P      = 0.2

# Which trainval experiment to load (matches exp_* folder name).
REQUIRED_EXP_SUBSTRING = "trainEnh3"  # case-insensitive match against exp_* folder name

# Threshold selection rule: prioritize small sex fairness gap while keeping performance acceptable.
TARGET_SENS = 0.60
MIN_SPEC    = 0.50
POLICY_TEXT = (
    "Policy B+: minimize mean(|ΔFNR|) subject to "
    f"mean(sensitivity) >= {TARGET_SENS:.2f} AND mean(specificity) >= {MIN_SPEC:.2f}"
)

# Threshold grid for the sweep.
THR_MIN, THR_MAX, THR_STEPS = 0.01, 0.99, 199
THR_GRID = np.linspace(THR_MIN, THR_MAX, THR_STEPS).astype(np.float64)

# Compute device selection.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
try:
    torch.set_float32_matmul_precision("high")
except Exception:
    pass

# Reduce noisy warnings for cleaner notebook output.
warnings.filterwarnings("ignore", message="None of the inputs have requires_grad=True. Gradients will be None")
warnings.filterwarnings("ignore", message="Passing `gradient_checkpointing` to a config initialization is deprecated")
warnings.filterwarnings("ignore", category=UserWarning, module="huggingface_hub")

# Print key settings for quick verification.
print("D7_OUT_ROOT:", D7_OUT_ROOT)
print("D2_OUT_ROOT:", D2_OUT_ROOT)
print("D2_MANIFEST_ALL:", D2_MANIFEST_ALL)
print("DEVICE:", DEVICE)
if DEVICE.type == "cuda":
    print("GPU:", torch.cuda.get_device_name(0))
print("PER_DEVICE_BS:", PER_DEVICE_BS)
print("NUM_WORKERS:", NUM_WORKERS, "| PIN_MEMORY:", PIN_MEMORY)
print("USE_AMP:", bool(USE_AMP and DEVICE.type == "cuda"))
print("Required exp substring (case-insensitive):", REQUIRED_EXP_SUBSTRING)
print("Policy B+ constraints:")
print("  mean sensitivity >=", TARGET_SENS)
print("  mean specificity >=", MIN_SPEC)
print(f"Threshold sweep grid: {THR_MIN:.2f}..{THR_MAX:.2f} with {THR_STEPS} steps")

# -------------------------
# Load D2 manifest and keep only the test split
# -------------------------
# Read the manifest, validate required columns, confirm the dataset is D2, then filter to split == "test".
if not os.path.exists(D2_MANIFEST_ALL):
    raise FileNotFoundError(f"Missing D2 manifest_all.csv: {D2_MANIFEST_ALL}")

m_all = pd.read_csv(D2_MANIFEST_ALL)

req_cols = {"split", "clip_path", "label_num", "task", "sex", "age"}
missing = [c for c in sorted(req_cols) if c not in m_all.columns]
if missing:
    raise ValueError(f"D2 manifest missing required columns: {missing}. Found: {list(m_all.columns)}")

# Identify dataset id when available, then restrict to that dataset block.
if "dataset" in m_all.columns and m_all["dataset"].notna().any():
    d2_dataset_id = str(m_all["dataset"].astype(str).value_counts(dropna=True).idxmax())
    m_all = m_all[m_all["dataset"].astype(str) == d2_dataset_id].copy()
else:
    d2_dataset_id = "DX"

# Hard guard: prevent accidentally running on the wrong manifest.
if d2_dataset_id != "D2":
    raise RuntimeError(
        f"Expected D2 dataset_id=='D2' but got {d2_dataset_id!r}. "
        "This usually means D2_OUT_ROOT is wrong or the manifest is not D2. "
        f"D2_OUT_ROOT={D2_OUT_ROOT}"
    )

# Keep a consistent set of columns used by the dataset class.
keep_cols = ["clip_path", "label_num", "task", "speaker_id", "sex", "age", "duration_sec", "split"]
for c in keep_cols:
    if c not in m_all.columns:
        m_all[c] = np.nan
m_all = m_all[keep_cols].copy()

test_df = m_all[m_all["split"].astype(str) == "test"].reset_index(drop=True)

print(f"\nD2 dataset inferred: {d2_dataset_id}")
print(f"D2 TEST rows: {len(test_df)}")
print("D2 TEST label counts:", test_df["label_num"].value_counts(dropna=False).to_dict())
print("D2 TEST sex counts (raw):", test_df["sex"].value_counts(dropna=False).to_dict())

if len(test_df) == 0:
    raise RuntimeError("After filtering to split=='test', D2 manifest has 0 rows.")

# -------------------------
# Fail fast if audio files are missing
# -------------------------
# Scans every clip_path and stops as soon as 10 missing examples are collected, before spending time on inference.
def _fail_fast_missing_paths(df: pd.DataFrame, name: str):
    missing_paths = []
    for p in tqdm(df["clip_path"].astype(str).tolist(), desc=f"Check {name} clip_path exists", dynamic_ncols=True):
        if not os.path.exists(p):
            missing_paths.append(p)
            if len(missing_paths) >= 10:
                break
    if missing_paths:
        raise FileNotFoundError(f"{name}: missing clip_path(s). Examples: {missing_paths}")

_fail_fast_missing_paths(test_df, "D2 TEST")

# -------------------------
# Add analysis-friendly columns (task group, normalized sex)
# -------------------------
# task_group is used to choose the correct head (vowel vs other).
def _task_group(task_val) -> str:
    return "vowel" if str(task_val) == "vowl" else "other"

test_df["task_group"] = test_df["task"].apply(_task_group)

# Normalize sex using D2's exact encoding ("male"/"female"); anything else becomes UNK.
def normalize_sex_d2_case_sensitive(val) -> str:
    if pd.isna(val):
        return "UNK"
    if val == "male":
        return "M"
    if val == "female":
        return "F"
    return "UNK"

test_df["sex_norm"] = test_df["sex"].apply(normalize_sex_d2_case_sensitive)
print("D2 TEST sex counts (normalized):", test_df["sex_norm"].value_counts(dropna=False).to_dict())
if (test_df["sex_norm"] == "UNK").any():
    print("NOTE: Some D2 'sex' values were not exactly 'male'/'female' and were mapped to 'UNK'.")

# -------------------------
# Dataset and collator for variable-length audio
# -------------------------
# Builds per-clip tensors and an attention mask. For vowel clips, the attention mask ignores trailing padding/silence.
class AudioManifestDataset(Dataset):
    def __init__(self, df: pd.DataFrame):
        self.df = df.reset_index(drop=True)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx: int):
        row = self.df.iloc[idx]
        clip_path = str(row["clip_path"])
        label = int(row["label_num"])
        task_group = str(row["task_group"])
        sex_norm = str(row["sex_norm"])

        if not os.path.exists(clip_path):
            raise FileNotFoundError(f"Missing clip_path: {clip_path}")

        y, sr = sf.read(clip_path, always_2d=False)
        if isinstance(y, np.ndarray) and y.ndim > 1:
            y = np.mean(y, axis=1)
        y = np.asarray(y, dtype=np.float32)

        if int(sr) != SR_EXPECTED:
            raise ValueError(f"Sample rate mismatch (expected {SR_EXPECTED}): {clip_path} has sr={sr}")

        # Attention mask marks which samples should count during pooling.
        attn = np.ones((len(y),), dtype=np.int64)
        if task_group == "vowel":
            # Find the last non-tiny sample, then mask out the remainder as padding/silence.
            # If every sample is tiny, the mask stays all ones.
            loud_idx = np.flatnonzero(np.abs(y.astype(np.float64)) > TINY_THRESH)
            if loud_idx.size > 0:
                attn[int(loud_idx[-1]) + 1:] = 0
        # Non-vowel clips are treated as fully valid audio (mask stays all ones).

        return {
            "input_values": torch.from_numpy(y),
            "attention_mask": torch.from_numpy(attn),
            "labels": torch.tensor(label, dtype=torch.long),
            "task_group": task_group,
            "sex_norm": sex_norm,
        }

# Pads a batch to the longest clip length and stacks tensors for model input.
def collate_fn(batch):
    max_len = int(max(b["input_values"].numel() for b in batch))
    input_vals, attn_masks, labels, task_groups, sex_norms = [], [], [], [], []
    for b in batch:
        x = b["input_values"]
        a = b["attention_mask"]
        pad = max_len - x.numel()
        if pad > 0:
            x = torch.cat([x, torch.zeros(pad, dtype=x.dtype)], dim=0)
            a = torch.cat([a, torch.zeros(pad, dtype=a.dtype)], dim=0)
        input_vals.append(x)
        attn_masks.append(a)
        labels.append(b["labels"])
        task_groups.append(b["task_group"])
        sex_norms.append(b["sex_norm"])
    return {
        "input_values": torch.stack(input_vals, dim=0),
        "attention_mask": torch.stack(attn_masks, dim=0),
        "labels": torch.stack(labels, dim=0),
        "task_group": task_groups,
        "sex_norm": sex_norms,
    }

# DataLoader over the D2 test split.
test_ds = AudioManifestDataset(test_df)
test_loader = DataLoader(
    test_ds,
    batch_size=PER_DEVICE_BS,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=PIN_MEMORY,
    collate_fn=collate_fn,
)

# -------------------------
# Model definition (backbone + two task heads)
# -------------------------
# Uses a frozen Wav2Vec2 backbone and switches between vowel and other heads per clip.
class Wav2Vec2TwoHeadClassifier(nn.Module):
    def __init__(self, ckpt: str, dropout_p: float):
        super().__init__()
        self.backbone = Wav2Vec2Model.from_pretrained(ckpt, use_safetensors=True)
        for p in self.backbone.parameters():
            p.requires_grad = False

        H = int(self.backbone.config.hidden_size)
        self.pre_vowel = nn.Sequential(nn.LayerNorm(H), nn.Dropout(float(dropout_p)))
        self.pre_other = nn.Sequential(nn.LayerNorm(H), nn.Dropout(float(dropout_p)))
        self.head_vowel = nn.Linear(H, 2)
        self.head_other = nn.Linear(H, 2)

    # Pools frame-level features into one vector per clip, ignoring masked-out regions.
    def masked_mean_pool(self, last_hidden, attn_mask_samples):
        feat_mask = self.backbone._get_feature_vector_attention_mask(last_hidden.shape[1], attn_mask_samples)
        feat_mask = feat_mask.to(last_hidden.device).unsqueeze(-1).type_as(last_hidden)
        masked = last_hidden * feat_mask
        denom = feat_mask.sum(dim=1).clamp(min=1.0)
        return masked.sum(dim=1) / denom

    # Runs heads in fp32 to avoid mixed-precision edge cases.
    def _heads_fp32(self, x_any: torch.Tensor, head: nn.Module) -> torch.Tensor:
        x = x_any.float()
        if x.is_cuda:
            with torch.amp.autocast(device_type="cuda", enabled=False):
                return head(x)
        return head(x)

    # Produces logits for PD vs HC, choosing the correct head based on task_group.
    def forward_logits(self, input_values, attention_mask, task_group):
        with torch.no_grad():
            out = self.backbone(input_values=input_values, attention_mask=attention_mask)
            last_hidden = out.last_hidden_state

        pooled = self.masked_mean_pool(last_hidden, attention_mask)
        z_v = self.pre_vowel(pooled.float())
        z_o = self.pre_other(pooled.float())

        logits_v = self._heads_fp32(z_v, self.head_vowel)
        logits_o = self._heads_fp32(z_o, self.head_other)

        is_vowel = torch.tensor([tg == "vowel" for tg in task_group], device=pooled.device, dtype=torch.bool)
        logits = logits_o.clone()
        if is_vowel.any():
            logits[is_vowel] = logits_v[is_vowel]
        return logits

# Load only the trained head weights from best_heads.pt into the model.
def load_heads_into_model(model: Wav2Vec2TwoHeadClassifier, best_heads_path: Path):
    if not best_heads_path.exists():
        raise FileNotFoundError(f"Missing best_heads.pt: {str(best_heads_path)}")
    state = torch.load(str(best_heads_path), map_location="cpu")
    model.pre_vowel.load_state_dict(state["pre_vowel"], strict=True)
    model.pre_other.load_state_dict(state["pre_other"], strict=True)
    model.head_vowel.load_state_dict(state["head_vowel"], strict=True)
    model.head_other.load_state_dict(state["head_other"], strict=True)
    return model

# -------------------------
# Deterministic setup and one-pass inference per seed
# -------------------------
# Inference returns fixed probabilities, labels, and sex labels, used later for threshold sweeping.
def set_all_seeds(seed: int):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

def infer_once(seed: int, chosen_exp: Path, train_dataset_id: str):
    set_all_seeds(seed)
    best_heads_path = chosen_exp / f"run_{train_dataset_id}_seed{seed}" / "best_heads.pt"

    model = Wav2Vec2TwoHeadClassifier(BACKBONE_CKPT, dropout_p=DROPOUT_P).to(DEVICE)
    model = load_heads_into_model(model, best_heads_path)
    model.eval()

    use_amp = bool(USE_AMP and DEVICE.type == "cuda")
    all_probs, all_true, all_sex = [], [], []

    # Single forward pass over the test set for this seed.
    pbar = tqdm(test_loader, desc=f"[seed={seed}] Inference D2 TEST", dynamic_ncols=True)
    with torch.inference_mode():
        for batch in pbar:
            input_values = batch["input_values"].to(DEVICE, non_blocking=False)
            attention_mask = batch["attention_mask"].to(DEVICE, non_blocking=False)
            labels = batch["labels"].to(DEVICE, non_blocking=False)
            task_group = batch["task_group"]
            sex_norm = batch["sex_norm"]

            with torch.amp.autocast(device_type=DEVICE.type, enabled=use_amp):
                logits = model.forward_logits(input_values, attention_mask, task_group)

            probs = torch.softmax(logits.float(), dim=-1)[:, 1].detach().cpu().numpy().astype(np.float64)
            all_probs.extend(probs.tolist())
            all_true.extend(labels.detach().cpu().numpy().astype(np.int64).tolist())
            all_sex.extend(list(sex_norm))

    y_true = np.asarray(all_true, dtype=np.int64)
    y_prob = np.asarray(all_probs, dtype=np.float64)
    y_sex  = np.asarray(all_sex, dtype=object)

    # AUROC is reported for reference since it does not depend on threshold.
    auc = float("nan")
    if len(np.unique(y_true)) >= 2:
        auc = float(roc_auc_score(y_true, y_prob))

    return {
        "seed": int(seed),
        "best_heads_path": str(best_heads_path),
        "y_true": y_true,
        "y_prob": y_prob,
        "y_sex": y_sex,
        "auroc": float(auc),
    }

# -------------------------
# Threshold-dependent metrics and fairness helper
# -------------------------
# Provides standard classification metrics and a sex-based ΔFNR (F minus M) on PD-only cases.
def confusion_counts(y_true, y_prob, thr):
    y_pred = (y_prob >= float(thr)).astype(np.int64)
    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
    TN, FP, FN, TP = int(cm[0, 0]), int(cm[0, 1]), int(cm[1, 0]), int(cm[1, 1])
    return TN, FP, FN, TP

def threshold_metrics(y_true, y_prob, thr):
    TN, FP, FN, TP = confusion_counts(y_true, y_prob, thr)
    eps = 1e-12
    acc  = (TP + TN) / max(1, (TP + TN + FP + FN))
    prec = TP / (TP + FP + eps)
    sens = TP / (TP + FN + eps)
    spec = TN / (TN + FP + eps)
    f1   = 2 * prec * sens / (prec + sens + eps)

    # MCC may be undefined if predictions are all one class.
    try:
        y_pred = (y_prob >= float(thr)).astype(np.int64)
        mcc = float(matthews_corrcoef(y_true, y_pred)) if len(np.unique(y_pred)) > 1 else float("nan")
    except Exception:
        mcc = float("nan")

    # Fisher test on the 2x2 confusion table (may fail if scipy not available).
    pval = float("nan")
    try:
        from scipy.stats import fisher_exact  # type: ignore
        _, pval = fisher_exact([[TN, FP], [FN, TP]], alternative="two-sided")
        pval = float(pval)
    except Exception:
        pval = float("nan")

    return {
        "tn": TN, "fp": FP, "fn": FN, "tp": TP,
        "accuracy": float(acc),
        "precision": float(prec),
        "sensitivity": float(sens),
        "specificity": float(spec),
        "f1": float(f1),
        "mcc": float(mcc),
        "fisher_p_two_sided": float(pval),
    }

def fnr_by_sex_signed_delta(y_true, y_prob, y_sex, thr):
    # FNR is computed on true PD cases within each sex: FN / (FN + TP).
    y_pred = (y_prob >= float(thr)).astype(np.int64)
    out = {}

    for g in ["M", "F"]:
        mask_g = (y_sex == g)
        pos_mask = mask_g & (y_true == 1)
        n_pos = int(pos_mask.sum())
        if n_pos == 0:
            out[g] = {"n_pos": 0, "fn": 0, "tp": 0, "fnr": float("nan")}
            continue
        tp = int(((y_pred == 1) & pos_mask).sum())
        fn = int(((y_pred == 0) & pos_mask).sum())
        fnr = float(fn / max(1, (fn + tp)))
        out[g] = {"n_pos": int(n_pos), "fn": int(fn), "tp": int(tp), "fnr": float(fnr)}

    fnr_m = out["M"]["fnr"]
    fnr_f = out["F"]["fnr"]
    if (not np.isnan(fnr_m)) and (not np.isnan(fnr_f)):
        delta = float(fnr_f - fnr_m)     # ΔFNR = F - M
        absd  = float(abs(delta))
    else:
        delta = float("nan")
        absd  = float("nan")

    return out, delta, absd

# Small utilities to summarize across seeds while tolerating NaNs.
def safe_nanmean(x):
    # Mean over non-NaN entries; NaN when no valid entry exists.
    x = np.asarray(x, dtype=np.float64)
    if np.any(~np.isnan(x)):
        return float(np.nanmean(x))
    return float("nan")

def safe_nansd(x):
    # Sample SD (ddof=1) over non-NaN entries; 0.0 when fewer than two valid entries exist.
    x = np.asarray(x, dtype=np.float64)
    if np.sum(~np.isnan(x)) > 1:
        return float(np.nanstd(x, ddof=1))
    return 0.0

def mean_sd(vals):
    # Convenience wrapper returning (mean, SD) with the same NaN handling as above.
    return safe_nanmean(vals), safe_nansd(vals)

# -------------------------
# Locate the most recent matching trainval experiment
# -------------------------
# Finds the newest exp_* folder that matches REQUIRED_EXP_SUBSTRING and contains all needed artifacts.
TRAINVAL_ROOT = Path(D7_OUT_ROOT) / "trainval_runs"
if not TRAINVAL_ROOT.exists():
    raise FileNotFoundError(f"Missing trainval_runs folder under D7_OUT_ROOT: {str(TRAINVAL_ROOT)}")

exp_dirs = sorted([p for p in TRAINVAL_ROOT.glob("exp_*") if p.is_dir()], key=lambda p: p.stat().st_mtime, reverse=True)
if not exp_dirs:
    raise FileNotFoundError(f"No exp_* folders found under: {str(TRAINVAL_ROOT)}")

train_dataset_id = "D7"

def is_match(exp_path: Path, required_substring: str) -> bool:
    return required_substring.lower() in exp_path.name.lower()

def has_all_seeds_and_summary(exp_path: Path, dataset_id: str, seeds: list) -> bool:
    if not (exp_path / "summary_trainval.json").exists():
        return False
    for s in seeds:
        p = exp_path / f"run_{dataset_id}_seed{s}" / "best_heads.pt"
        if not p.exists():
            return False
    return True

chosen_exp = None
for ed in exp_dirs:
    if not is_match(ed, REQUIRED_EXP_SUBSTRING):
        continue
    if has_all_seeds_and_summary(ed, train_dataset_id, SEEDS):
        chosen_exp = ed
        break

if chosen_exp is None:
    examples = [p.name for p in exp_dirs[:10]]
    raise FileNotFoundError(
        "Could not find a recent D7 trainval experiment folder that:\n"
        f"  (1) contains substring '{REQUIRED_EXP_SUBSTRING}' (case-insensitive) in the exp folder name, and\n"
        f"  (2) contains all {len(SEEDS)} best_heads.pt files (one per seed) + summary_trainval.json.\n"
        f"Checked under: {str(TRAINVAL_ROOT)}\n\n"
        f"Example exp_* folder names (most recent first):\n  - " + "\n  - ".join(examples) + "\n\n"
        "Fix: set REQUIRED_EXP_SUBSTRING to match the exact tag used in the exp folder name."
    )

print("\nUsing trainval experiment folder:")
print(" ", str(chosen_exp))

# Load the trainval summary only to show the mean Youden-J threshold as a reference point.
summary_trainval_path = chosen_exp / "summary_trainval.json"
with open(summary_trainval_path, "r", encoding="utf-8") as f:
    trainval_summary = json.load(f)

val_opt = (trainval_summary or {}).get("val_optimal_threshold") or {}
thr_mean = float((val_opt.get("mean_sd") or {}).get("mean", float("nan")))
print("\nTrainval mean val-opt threshold (Youden J) for reference:")
print("  summary_trainval.json -> val_optimal_threshold.mean_sd.mean =", f"{thr_mean:.6f}" if not np.isnan(thr_mean) else "nan")

# -------------------------
# Create a new output folder for this sweep run
# -------------------------
# Writes sweep table, summary JSON, and plots into a timestamped folder to avoid overwrites.
TS_ROOT = Path(D7_OUT_ROOT) / "threshold_sweeps"
TS_ROOT.mkdir(parents=True, exist_ok=True)
timestamp = time.strftime("%Y%m%d_%H%M%S")
OUT_DIR = TS_ROOT / f"run_D7_{REQUIRED_EXP_SUBSTRING}_on_D2test_{timestamp}"
OUT_DIR.mkdir(parents=True, exist_ok=True)

# -------------------------
# Inference once per seed (fixed predictions)
# -------------------------
# Produces y_prob arrays per seed; all later threshold results reuse these probabilities.
print("\nRunning inference ONCE per seed (then sweep thresholds on fixed predictions)...")
seed_payloads = []
for s in SEEDS:
    seed_payloads.append(infer_once(s, chosen_exp, train_dataset_id))

print("\nAUROC by seed (ranking metric, threshold-independent):")
for sp in seed_payloads:
    print(f"  seed {sp['seed']}: AUROC={sp['auroc']:.6f}")
auc_mean, auc_sd = mean_sd([sp["auroc"] for sp in seed_payloads])
print(f"Mean AUROC: {auc_mean:.6f} ± {auc_sd:.6f}")

# -------------------------
# Sweep thresholds and aggregate across seeds
# -------------------------
# For each threshold, compute mean sensitivity, mean specificity, and mean |ΔFNR| across seeds.
rows = []
for thr in tqdm(THR_GRID, desc="Threshold sweep", dynamic_ncols=True):
    sens_list, absd_list, signd_list = [], [], []
    acc_list, prec_list, spec_list, f1_list, mcc_list, p_list = [], [], [], [], [], []

    fnr_m_list, fnr_f_list = [], []
    n_pd_m_list, n_pd_f_list = [], []

    for sp in seed_payloads:
        y_true = sp["y_true"]
        y_prob = sp["y_prob"]
        y_sex  = sp["y_sex"]

        tm = threshold_metrics(y_true, y_prob, thr)
        fnr_by_sex, delta_signed, delta_abs = fnr_by_sex_signed_delta(y_true, y_prob, y_sex, thr)

        sens_list.append(tm["sensitivity"])
        acc_list.append(tm["accuracy"])
        prec_list.append(tm["precision"])
        spec_list.append(tm["specificity"])
        f1_list.append(tm["f1"])
        mcc_list.append(tm["mcc"])
        p_list.append(tm["fisher_p_two_sided"])

        signd_list.append(delta_signed)
        absd_list.append(delta_abs)

        fnr_m_list.append(float(fnr_by_sex["M"]["fnr"]))
        fnr_f_list.append(float(fnr_by_sex["F"]["fnr"]))
        n_pd_m_list.append(float(fnr_by_sex["M"]["n_pos"]))
        n_pd_f_list.append(float(fnr_by_sex["F"]["n_pos"]))

    # One row per threshold, with means across seeds and seed-level details stored as JSON strings.
    row = {
        "threshold": float(thr),

        "mean_sensitivity": safe_nanmean(sens_list),
        "sd_sensitivity": safe_nansd(sens_list),

        "mean_specificity": safe_nanmean(spec_list),
        "sd_specificity": safe_nansd(spec_list),

        "mean_abs_deltaFNR": safe_nanmean(absd_list),
        "sd_abs_deltaFNR": safe_nansd(absd_list),

        "mean_signed_deltaFNR_F_minus_M": safe_nanmean(signd_list),
        "sd_signed_deltaFNR": safe_nansd(signd_list),

        "mean_accuracy": safe_nanmean(acc_list),
        "mean_precision": safe_nanmean(prec_list),
        "mean_f1": safe_nanmean(f1_list),
        "mean_mcc": safe_nanmean(mcc_list),
        "mean_fisher_p_two_sided": safe_nanmean(p_list),

        "mean_FNR_M": safe_nanmean(fnr_m_list),
        "mean_FNR_F": safe_nanmean(fnr_f_list),
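        # Key by the seed stored in each payload (same convention as the *_by_seed fields below).
        "n_PD_M_each_seed": json.dumps({str(sp["seed"]): int(n_pd_m_list[i]) for i, sp in enumerate(seed_payloads)}),
        "n_PD_F_each_seed": json.dumps({str(sp["seed"]): int(n_pd_f_list[i]) for i, sp in enumerate(seed_payloads)}),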
    }

    # Store per-seed values to explain the averages when needed.
    row["sensitivity_by_seed"] = json.dumps({str(sp["seed"]): float(sens_list[i]) for i, sp in enumerate(seed_payloads)})
    row["specificity_by_seed"] = json.dumps({str(sp["seed"]): float(spec_list[i]) for i, sp in enumerate(seed_payloads)})
    row["abs_deltaFNR_by_seed"] = json.dumps({str(sp["seed"]): float(absd_list[i]) for i, sp in enumerate(seed_payloads)})
    row["signed_deltaFNR_by_seed"] = json.dumps({str(sp["seed"]): float(signd_list[i]) for i, sp in enumerate(seed_payloads)})

    rows.append(row)

sweep_df = pd.DataFrame(rows)

# -------------------------
# Choose a threshold using Policy B+ and explain the decision
# -------------------------
# Policy B+ prefers thresholds with low mean |ΔFNR| but rejects thresholds that break sensitivity/specificity constraints.
max_mean_sens = float(sweep_df["mean_sensitivity"].max())
thr_at_max_sens = float(sweep_df.loc[sweep_df["mean_sensitivity"].idxmax(), "threshold"])

max_mean_spec = float(sweep_df["mean_specificity"].max())
thr_at_max_spec = float(sweep_df.loc[sweep_df["mean_specificity"].idxmax(), "threshold"])

eligible = sweep_df[
    (sweep_df["mean_sensitivity"] >= TARGET_SENS) &
    (sweep_df["mean_specificity"] >= MIN_SPEC)
].copy()

constraint_reached = (len(eligible) > 0)

# Summarizes which constraint is met and how far below the target the other is (if any).
def _constraint_status_at_row(row_dict):
    sens_gap = float(TARGET_SENS - row_dict["mean_sensitivity"])
    spec_gap = float(MIN_SPEC - row_dict["mean_specificity"])
    return {
        "sens_ok": bool(row_dict["mean_sensitivity"] >= TARGET_SENS),
        "spec_ok": bool(row_dict["mean_specificity"] >= MIN_SPEC),
        "sens_gap_needed": float(max(0.0, sens_gap)),
        "spec_gap_needed": float(max(0.0, spec_gap)),
    }

if constraint_reached:
    # Select best eligible threshold: lowest mean |ΔFNR|, then stronger sensitivity/specificity, then lower threshold.
    eligible = eligible.sort_values(
        by=["mean_abs_deltaFNR", "mean_sensitivity", "mean_specificity", "threshold"],
        ascending=[True, False, False, True]
    )
    chosen = eligible.iloc[0].to_dict()

    policy_note = (
        "Constraints reached:\n"
        f"  - mean(sensitivity) >= {TARGET_SENS:.2f}\n"
        f"  - mean(specificity) >= {MIN_SPEC:.2f}\n"
        "Chosen threshold minimizes mean(|ΔFNR|) among eligible thresholds "
        "(tie-breakers: higher sensitivity, higher specificity, then lower threshold)."
    )
else:
    # If constraints cannot be met together, pick the closest option and state the gaps clearly.
    can_reach_sens = bool((sweep_df["mean_sensitivity"] >= TARGET_SENS).any())
    can_reach_spec = bool((sweep_df["mean_specificity"] >= MIN_SPEC).any())

    working_df = sweep_df.copy()
    subset = working_df.copy()
    if can_reach_sens:
        subset = subset[subset["mean_sensitivity"] >= TARGET_SENS].copy()
    if can_reach_spec:
        subset = subset[subset["mean_specificity"] >= MIN_SPEC].copy()

    if len(subset) == 0:
        subset = working_df.copy()

    subset["_sens_gap"] = np.maximum(0.0, TARGET_SENS - subset["mean_sensitivity"])
    subset["_spec_gap"] = np.maximum(0.0, MIN_SPEC - subset["mean_specificity"])
    subset["_total_gap"] = subset["_sens_gap"] + subset["_spec_gap"]

    subset = subset.sort_values(
        by=["_total_gap", "mean_abs_deltaFNR", "mean_sensitivity", "mean_specificity", "threshold"],
        ascending=[True, True, False, False, True]
    )

    chosen = subset.iloc[0].drop(labels=["_sens_gap", "_spec_gap", "_total_gap"]).to_dict()
    status = _constraint_status_at_row(chosen)

    failed_parts = []
    if not can_reach_sens:
        failed_parts.append(
            f"mean(sensitivity) >= {TARGET_SENS:.2f} (UNREACHABLE on this grid; max was {max_mean_sens:.6f} at thr={thr_at_max_sens:.4f})"
        )
    if not can_reach_spec:
        failed_parts.append(
            f"mean(specificity) >= {MIN_SPEC:.2f} (UNREACHABLE on this grid; max was {max_mean_spec:.6f} at thr={thr_at_max_spec:.4f})"
        )
    if not failed_parts:
        failed_parts.append(
            "Both constraints are individually reachable, but no single threshold meets BOTH at the same time on this grid."
        )

    policy_note = (
        "Constraints NOT jointly reachable.\n"
        "What failed:\n"
        "  - " + "\n  - ".join(failed_parts) + "\n"
        "Returned the threshold that minimizes total constraint violation (sum of gaps), then minimizes mean(|ΔFNR|), "
        "then prefers higher sensitivity and higher specificity.\n"
        f"At the chosen threshold, remaining gaps are:\n"
        f"  - sensitivity gap needed: {status['sens_gap_needed']:.6f}\n"
        f"  - specificity gap needed: {status['spec_gap_needed']:.6f}"
    )

# Print the policy outcome, the best achievable values on the grid, and the chosen-threshold metrics.
print("\n================ POLICY RESULT ================")
print(POLICY_TEXT)
print("Constraint reached (both)?:", bool(constraint_reached))
print("Max achievable mean(sensitivity) on grid:", f"{max_mean_sens:.6f}", "| at threshold:", f"{thr_at_max_sens:.4f}")
print("Max achievable mean(specificity) on grid:", f"{max_mean_spec:.6f}", "| at threshold:", f"{thr_at_max_spec:.4f}")

print("\nChosen threshold:", f"{float(chosen['threshold']):.4f}")
print("Chosen mean(sensitivity):", f"{float(chosen['mean_sensitivity']):.6f}")
print("Chosen mean(specificity):", f"{float(chosen['mean_specificity']):.6f}")
print("Chosen mean(|ΔFNR|):", f"{float(chosen['mean_abs_deltaFNR']):.6f}")
print("Chosen mean signed ΔFNR (F-M):", f"{float(chosen['mean_signed_deltaFNR_F_minus_M']):.6f}")

print("\nDetails:")
print(policy_note)

# -------------------------
# Save sweep table and a compact summary JSON
# -------------------------
# Writes the full sweep grid as CSV, plus a summary JSON with the chosen threshold and key metrics.
sweep_csv = OUT_DIR / "sweep_table.csv"
sweep_df.to_csv(sweep_csv, index=False)

summary = {
    "train_dataset": "D7",
    "test_dataset": "D2",
    "split_swept": "D2 TEST",
    "required_exp_substring_case_insensitive": REQUIRED_EXP_SUBSTRING,
    "trainval_experiment_used": str(chosen_exp),
    "trainval_summary_path": str(summary_trainval_path),
    "seeds": SEEDS,
    "policy": POLICY_TEXT,
    "constraints": {
        "target_mean_sensitivity": float(TARGET_SENS),
        "min_mean_specificity": float(MIN_SPEC),
    },
    "constraint_reached_both": bool(constraint_reached),
    "max_mean_sensitivity_on_grid": float(max_mean_sens),
    "threshold_at_max_mean_sensitivity": float(thr_at_max_sens),
    "max_mean_specificity_on_grid": float(max_mean_spec),
    "threshold_at_max_mean_specificity": float(thr_at_max_spec),
    "youdenJ_mean_threshold_reference": (None if np.isnan(thr_mean) else float(thr_mean)),
    "chosen_threshold": float(chosen["threshold"]),
    "chosen_metrics": {
        "mean_sensitivity": float(chosen["mean_sensitivity"]),
        "mean_specificity": float(chosen["mean_specificity"]),
        "mean_abs_deltaFNR": float(chosen["mean_abs_deltaFNR"]),
        "mean_signed_deltaFNR_F_minus_M": float(chosen["mean_signed_deltaFNR_F_minus_M"]),
        "mean_accuracy": float(chosen["mean_accuracy"]),
        "mean_precision": float(chosen["mean_precision"]),
        "mean_f1": float(chosen["mean_f1"]),
        "mean_mcc": float(chosen["mean_mcc"]),
        "mean_fisher_p_two_sided": float(chosen["mean_fisher_p_two_sided"]),
        "mean_FNR_M": float(chosen["mean_FNR_M"]),
        "mean_FNR_F": float(chosen["mean_FNR_F"]),
        "sensitivity_by_seed": chosen.get("sensitivity_by_seed", ""),
        "specificity_by_seed": chosen.get("specificity_by_seed", ""),
        "abs_deltaFNR_by_seed": chosen.get("abs_deltaFNR_by_seed", ""),
        "signed_deltaFNR_by_seed": chosen.get("signed_deltaFNR_by_seed", ""),
        "n_PD_M_each_seed": chosen.get("n_PD_M_each_seed", ""),
        "n_PD_F_each_seed": chosen.get("n_PD_F_each_seed", ""),
    },
    "policy_note_transparent": policy_note,
    "paths": {
        "out_dir": str(OUT_DIR),
        "sweep_table_csv": str(sweep_csv),
    },
    "timestamp": timestamp,
}

summary_json = OUT_DIR / "sweep_summary.json"
with open(summary_json, "w", encoding="utf-8") as f:
    json.dump(summary, f, indent=2)

# -------------------------
# Plots for quick visual checks
# -------------------------
# Saves four simple plots: sensitivity vs threshold, specificity vs threshold, |ΔFNR| vs threshold, and the tradeoff curve.
plt.figure()
plt.plot(sweep_df["threshold"].values, sweep_df["mean_sensitivity"].values)
plt.axhline(TARGET_SENS, linestyle="--")
plt.axvline(float(chosen["threshold"]), linestyle="--")
plt.xlabel("Threshold")
plt.ylabel("Mean Sensitivity (across seeds)")
plt.title("Threshold Sweep: Mean Sensitivity vs Threshold (D7 trainEnh3 → D2 TEST)")
plt.tight_layout()
p1 = OUT_DIR / "sweep_sensitivity_vs_threshold.png"
plt.savefig(p1, dpi=150)
plt.close()

plt.figure()
plt.plot(sweep_df["threshold"].values, sweep_df["mean_specificity"].values)
plt.axhline(MIN_SPEC, linestyle="--")
plt.axvline(float(chosen["threshold"]), linestyle="--")
plt.xlabel("Threshold")
plt.ylabel("Mean Specificity (across seeds)")
plt.title("Threshold Sweep: Mean Specificity vs Threshold (D7 trainEnh3 → D2 TEST)")
plt.tight_layout()
p1b = OUT_DIR / "sweep_specificity_vs_threshold.png"
plt.savefig(p1b, dpi=150)
plt.close()

plt.figure()
plt.plot(sweep_df["threshold"].values, sweep_df["mean_abs_deltaFNR"].values)
plt.axvline(float(chosen["threshold"]), linestyle="--")
plt.xlabel("Threshold")
plt.ylabel("Mean |ΔFNR| (across seeds)")
plt.title("Threshold Sweep: Mean |ΔFNR| vs Threshold (D7 trainEnh3 → D2 TEST)")
plt.tight_layout()
p2 = OUT_DIR / "sweep_abs_deltaFNR_vs_threshold.png"
plt.savefig(p2, dpi=150)
plt.close()

plt.figure()
plt.plot(sweep_df["mean_sensitivity"].values, sweep_df["mean_abs_deltaFNR"].values)
plt.scatter([float(chosen["mean_sensitivity"])], [float(chosen["mean_abs_deltaFNR"])])
plt.axvline(TARGET_SENS, linestyle="--")
plt.xlabel("Mean Sensitivity (across seeds)")
plt.ylabel("Mean |ΔFNR| (across seeds)")
plt.title("Threshold Tradeoff: Sensitivity vs |ΔFNR| (D7 trainEnh3 → D2 TEST)")
plt.tight_layout()
p3 = OUT_DIR / "sweep_tradeoff.png"
plt.savefig(p3, dpi=150)
plt.close()

# Final file list for quick confirmation.
print("\n================ SAVED OUTPUTS ================")
print("OUT_DIR:", str(OUT_DIR))
print("Saved:", str(sweep_csv))
print("Saved:", str(summary_json))
print("Saved plots:")
print(" ", str(p1))
print(" ", str(p1b))
print(" ", str(p2))
print(" ", str(p3))
print("\nDone.")

# Ablation 2: Task-Head Ablation: Vowel vs Non-Vowel Speech Contributions

The following cell prepares the **mechanistic evaluation** workspace and builds a dependable run registry so subsequent analysis cells can find the correct prediction files without guessing folder names. It mounts Google Drive if needed, defines the main dataset roots for D7 and D2, and lists the three fixed random seeds used in all test runs. It also sets a single shared output directory, `.../mechanistic_eval/`, where each model condition (base, trainEnh1, trainEnh2, trainEnh3) will have its own subfolder.

Next, the cell creates a small registry called `RUNS` that points to one `tag_run_pointer.json` file for each condition. Each pointer is located by searching the relevant test-run folder recursively and taking the most recently modified pointer file whose path contains the condition name. These pointer files act as stable identifiers for the runs. A helper function reads each pointer file and determines the expected `predictions.csv` path for each seed in a fixed order: it first looks for an explicit `predictions_csv_by_seed` entry, then a `runs` dictionary with per-seed information, then a `seed_run_dirs` list, and finally falls back to a standard path under the tag folder (`TAG_ROOT/run_D7_on_D2test_seed{seed}/predictions.csv`). This fallback allows the base run to work even if the pointer file contains only minimal information. If the resolved files are not all present, a second helper searches the sibling folders that carry the same tag and uses the newest one that contains `predictions.csv` for every seed.
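
The lookup order described above can be illustrated with a minimal sketch. The helper name `resolve_sketch`, the example pointer contents, and the `/tmp/tagA` paths are hypothetical and exist only for illustration; the sketch mirrors the schema keys named in the text and is not the notebook's own resolver.

```python
from pathlib import Path

def resolve_sketch(ptr: dict, tag_root: str, seeds: list) -> dict:
    """Illustrative only: apply the four-step lookup order for predictions.csv."""
    out = {}
    # (1) Explicit per-seed map
    by_seed = ptr.get("predictions_csv_by_seed")
    if isinstance(by_seed, dict):
        for s in seeds:
            if by_seed.get(str(s)):
                out[s] = by_seed[str(s)]
    # (2) Nested runs dictionary, consulted only when step 1 found nothing
    if not out and isinstance(ptr.get("runs"), dict):
        for s in seeds:
            entry = ptr["runs"].get(str(s))
            if isinstance(entry, dict) and entry.get("predictions_csv"):
                out[s] = entry["predictions_csv"]
    # (3) List of per-seed run directories (fills seeds that are still missing)
    for item in ptr.get("seed_run_dirs") or []:
        for s in seeds:
            if f"seed{s}" in str(item):
                out.setdefault(s, str(Path(item) / "predictions.csv"))
    # (4) Standard folder name under the tag folder
    for s in seeds:
        out.setdefault(s, str(Path(tag_root) / f"run_D7_on_D2test_seed{s}" / "predictions.csv"))
    return out

# Hypothetical pointer that names only one seed explicitly
ptr = {"predictions_csv_by_seed": {"1337": "/tmp/tagA/custom/predictions.csv"}}
paths = resolve_sketch(ptr, "/tmp/tagA", [1337, 2024])
for seed, path in paths.items():
    print(seed, "->", path)
```

In this sketch seed 1337 keeps its explicit path, while seed 2024 has no entry in the pointer and therefore receives the standard fallback path under the tag folder.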

For each condition, the cell creates the matching output directory under `mechanistic_eval/<tag>/` and prints a clear summary that shows the pointer file path, the resolved tag and run roots, and the resolved `predictions.csv` path for each seed. It then checks that every resolved `predictions.csv` file exists. If any file is missing, it prints the exact seed and path that failed and stops with an error. When all checks pass, the cell builds a final dictionary called `RUNS_RESOLVED`. For each condition, this dictionary stores the pointer path, resolved roots, per-seed prediction file paths, and the output folder for mechanistic evaluation. The cell finishes by confirming that `RUNS_RESOLVED` is ready for use by the later scoring and ablation cells.


In [None]:
# =========================
# Mechanistic evaluation run registry (paths + pointer-based resolution)
# -------------------------
# Purpose
# - Centralize condition names and pointer files for base and three enhanced draws
# - Resolve the exact per-seed predictions.csv locations from pointer files
#   (folders are searched only to locate pointers and as a fallback when files are missing)
#
# Inputs
# - Root folders for D7 outputs and D2 dataset
# - One tag_run_pointer.json per condition
#
# Outputs
# - MECH_EVAL_ROOT folder created (one subfolder per condition)
# - RUNS_RESOLVED dictionary populated and printed (per-seed predictions.csv paths)
# =========================

import os, json
from pathlib import Path

# -------------------------
# Drive mount (Colab)
# -------------------------
# Input: Google Drive availability in the runtime
# Output: /content/drive/MyDrive becomes accessible for reading and writing files
if not os.path.isdir("/content/drive/MyDrive"):
    from google.colab import drive  # type: ignore
    drive.mount("/content/drive")

# -------------------------
# Dataset roots
# -------------------------
# Inputs: D7 output root (contains test runs and trainval runs) and D2 root (contains manifest and clips)
# Outputs: used by later cells for reading predictions and writing mechanistic eval artifacts
D7_OUT_ROOT = "/content/drive/MyDrive/AI_PD_Project/Datasets/D7-Multilingual (D1_D4_D5v2_D6)/preprocessed_v1"
D2_OUT_ROOT = "/content/drive/MyDrive/AI_PD_Project/Datasets/D2-Slovak (EWA-DB)/EWA-DB/preprocessed_v1"

# -------------------------
# Seeds used throughout evaluation
# -------------------------
# Input: fixed seed list
# Output: defines which per-seed prediction files must exist for each condition
SEEDS = [1337, 2024, 7777]

# -------------------------
# Condition registry (pointer JSON locations)
# -------------------------
# Input: one pointer JSON per condition
# Output: RUNS dict used to build RUNS_RESOLVED
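def _latest_pointer_json(search_root: str, required_substrings: list) -> str:
    """Return the newest tag_run_pointer.json under search_root whose path contains
    every required substring (case-insensitive)."""
    search_root = str(search_root).strip()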
    root = Path(search_root)
    if not root.exists():
        raise FileNotFoundError(f"Search root not found: {search_root}")

    candidates = []
    for p in root.rglob("tag_run_pointer.json"):
        sp = str(p)
        ok = True
        for sub in required_substrings:
            if sub and (sub.lower() not in sp.lower()):
                ok = False
                break
        if ok:
            candidates.append(p)

    if not candidates:
        req = ", ".join([s for s in required_substrings if s])
        raise FileNotFoundError(
            f"No tag_run_pointer.json found under {search_root} matching: [{req}]"
        )

    candidates.sort(key=lambda x: x.stat().st_mtime, reverse=True)
    return str(candidates[0])

RUNS = {
    "base": {
        "pointer_json": _latest_pointer_json(
            f"{D7_OUT_ROOT}/multilingual_test_runs",
            ["base"],
        ),
    },
    "trainEnh1": {
        "pointer_json": _latest_pointer_json(
            f"{D7_OUT_ROOT}/monolingual_test_runs",
            ["trainEnh1"],
        ),
    },
    "trainEnh2": {
        "pointer_json": _latest_pointer_json(
            f"{D7_OUT_ROOT}/monolingual_test_runs",
            ["trainEnh2"],
        ),
    },
    "trainEnh3": {
        "pointer_json": _latest_pointer_json(
            f"{D7_OUT_ROOT}/monolingual_test_runs",
            ["trainEnh3"],
        ),
    },
}

# -------------------------
# Shared output root for mechanistic evaluation
# -------------------------
# Input: D7_OUT_ROOT
# Output: MECH_EVAL_ROOT created and used for all downstream artifacts
MECH_EVAL_ROOT = f"{D7_OUT_ROOT}/mechanistic_eval"
os.makedirs(MECH_EVAL_ROOT, exist_ok=True)

# -------------------------
# Pointer parsing and deterministic path resolution
# -------------------------
# Inputs: pointer JSON contents + seed list
# Output: resolved per-seed predictions.csv paths with no folder searching
def _read_json(path: str) -> dict:
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)

def resolve_predictions_paths_from_pointer(pointer_json_path: str, seeds: list) -> dict:
    """
    Resolve per-seed predictions.csv paths using known pointer schemas.

    Inputs
    - pointer_json_path: path to tag_run_pointer.json
    - seeds: list of expected seeds

    Outputs
    - run_root: best-effort run root (falls back to tag_root)
    - tag_root: folder that contains the pointer file
    - predictions_by_seed: dict(seed -> predictions.csv path)

    Resolution order:
      1) predictions_csv_by_seed
      2) runs[seed]["predictions_csv"]
      3) seed_run_dirs entries (append /predictions.csv)
      4) fallback under tag_root/run_D7_on_D2test_seed{seed}/predictions.csv
    """
    ptr = _read_json(pointer_json_path)

    pointer_path = Path(pointer_json_path)
    TAG_ROOT = pointer_path.parent  # tag folder that owns this pointer file

    # Prefer explicit run_root fields when present, else use TAG_ROOT
    run_root = (
        ptr.get("run_root")
        or ptr.get("run_dir")
        or ptr.get("root")
        or ptr.get("test_run_root")
        or None
    )
    if run_root is None:
        run_root = str(TAG_ROOT)

    run_root = str(run_root).strip()

    predictions_by_seed = {}

    # (1) Direct map: predictions_csv_by_seed = {"1337": ".../predictions.csv", ...}
    pbs = ptr.get("predictions_csv_by_seed")
    if isinstance(pbs, dict):
        for s in seeds:
            v = pbs.get(str(s))
            if isinstance(v, str) and len(v) > 0:
                predictions_by_seed[s] = v

    # (2) Nested map: runs = {"1337": {"predictions_csv": "..."}, ...}
    if not predictions_by_seed:
        runs = ptr.get("runs")
        if isinstance(runs, dict):
            for s in seeds:
                d = runs.get(str(s))
                if isinstance(d, dict):
                    v = d.get("predictions_csv")
                    if isinstance(v, str) and len(v) > 0:
                        predictions_by_seed[s] = v

    # (3) List of per-seed run dirs: seed_run_dirs = ["...seed1337", ...]
    if len(predictions_by_seed) < len(seeds):
        srd = ptr.get("seed_run_dirs")
        if isinstance(srd, list):
            for item in srd:
                if not isinstance(item, str):
                    continue
                for s in seeds:
                    if f"seed{s}" in item:
                        predictions_by_seed.setdefault(s, str(Path(item) / "predictions.csv"))

    # (4) Fallback to the standard per-seed folder name under TAG_ROOT
    for s in seeds:
        if s not in predictions_by_seed:
            predictions_by_seed[s] = str(TAG_ROOT / f"run_D7_on_D2test_seed{s}" / "predictions.csv")

    return {
        "run_root": run_root,
        "tag_root": str(TAG_ROOT),
        "predictions_by_seed": predictions_by_seed,
    }

def _best_existing_predictions_path(pointer_json_path: str, seeds: list) -> dict:
    """
    Resolve predictions.csv paths, but if the pointer's TAG_ROOT does not contain the files,
    fall back to searching for the newest folder under the same test-run parent that:
      - contains the same tag string (e.g., trainEnh3)
      - contains predictions.csv for all requested seeds
    """
    resolved = resolve_predictions_paths_from_pointer(pointer_json_path, seeds)

    # If all exist, done
    all_ok = True
    for s in seeds:
        if not os.path.isfile(resolved["predictions_by_seed"][s]):
            all_ok = False
            break
    if all_ok:
        return resolved

    pointer_path = Path(pointer_json_path)
    tag_root = pointer_path.parent
    parent = tag_root.parent  # monolingual_test_runs/ or multilingual_test_runs/
    tag_name = tag_root.name.lower()

    # Identify a stable tag token to match (trainEnh3 / trainEnh2 / trainEnh1 / base)
    want = None
    for tok in ["trainenh3", "trainenh2", "trainenh1", "base"]:
        if tok in tag_name:
            want = tok
            break
    if want is None:
        return resolved

    # Each candidate is (directory, {seed: predictions.csv path actually found}).
    candidates = []
    for d in parent.iterdir():
        if not d.is_dir() or want not in d.name.lower():
            continue

        # Prefer directories that are actually test run roots, not just tag folders
        paths = {s: d / f"run_D7_on_D2test_seed{s}" / "predictions.csv" for s in seeds}
        if all(p.is_file() for p in paths.values()):
            candidates.append((d, paths))

    if not candidates:
        # Last attempt: search recursively below each matching folder
        # (handles tag_root__timestamp patterns) and keep the paths that were found,
        # so the returned paths point at files that exist.
        for d in parent.iterdir():
            if not d.is_dir() or want not in d.name.lower():
                continue
            paths = {}
            for s in seeds:
                hit = next(
                    (p for p in d.rglob(f"run_D7_on_D2test_seed{s}/predictions.csv") if p.is_file()),
                    None,
                )
                if hit is None:
                    break
                paths[s] = hit
            if all(s in paths for s in seeds):
                candidates.append((d, paths))

    if not candidates:
        return resolved

    # Newest matching folder wins
    candidates.sort(key=lambda c: c[0].stat().st_mtime, reverse=True)
    best_dir, best_paths = candidates[0]

    resolved["run_root"] = str(best_dir)
    resolved["tag_root"] = str(best_dir)
    resolved["predictions_by_seed"] = {s: str(best_paths[s]) for s in seeds}
    return resolved

# -------------------------
# Validate pointers and predictions, then build RUNS_RESOLVED
# -------------------------
# Inputs: RUNS dict (conditions) + resolved predictions paths
# Outputs: RUNS_RESOLVED dict used directly by later evaluation cells
print("D7_OUT_ROOT:", D7_OUT_ROOT)
print("D2_OUT_ROOT:", D2_OUT_ROOT)
print("SEEDS:", SEEDS)
print("MECH_EVAL_ROOT:", MECH_EVAL_ROOT)
print()

RUNS_RESOLVED = {}

for tag, info in RUNS.items():
    pointer_path = str(info["pointer_json"]).strip()

    # Guard: pointer JSON must exist for this condition
    if not os.path.isfile(pointer_path):
        raise FileNotFoundError(f"[{tag}] pointer_json not found: {pointer_path}")

    resolved = _best_existing_predictions_path(pointer_path, SEEDS)

    # Per-condition output folder for mechanistic eval artifacts
    out_dir = os.path.join(MECH_EVAL_ROOT, tag)
    os.makedirs(out_dir, exist_ok=True)

    # Guard: predictions.csv must already exist for all seeds (this cell runs after tests)
    missing = []
    for s, p in resolved["predictions_by_seed"].items():
        if not os.path.isfile(p):
            missing.append((s, p))

    print(f"[{tag}]")
    print("  pointer_json:", pointer_path)
    print("  tag_root   :", resolved["tag_root"])
    print("  run_root   :", resolved["run_root"])
    print("  out_dir    :", out_dir)
    print("  predictions_by_seed:")
    for s in SEEDS:
        print(f"    seed{s}: {resolved['predictions_by_seed'][s]}")
    if missing:
        print("  !!! MISSING predictions.csv for:")
        for s, p in missing:
            print(f"    seed{s}: {p}")
        raise FileNotFoundError(f"[{tag}] Missing predictions.csv files (see above).")
    print()

    RUNS_RESOLVED[tag] = {
        "tag": tag,
        "pointer_json": pointer_path,
        "tag_root": resolved["tag_root"],
        "run_root": resolved["run_root"],
        "predictions_by_seed": resolved["predictions_by_seed"],
        "out_dir": out_dir,
    }

print("Cell 1 complete: RUNS_RESOLVED is ready for downstream scoring/ablation cells.")

The following cell runs a **data ablation analysis** using prediction files that already exist. It does not run any new model inference. It assumes that the run mapping (`RUNS_RESOLVED`), the mechanistic evaluation root folder (`MECH_EVAL_ROOT`), and the list of seeds (`SEEDS`) are already defined from the earlier setup cell, and it stops immediately with a clear error if any of these are missing. Using the resolved paths, it loads each condition’s saved `predictions.csv` file for all three seeds (Base, trainEnh1, trainEnh2, trainEnh3) and evaluates results separately for the two clip types: `task_group == "vowel"` and `task_group == "other"`.

For each condition and seed, the cell first checks that the predictions file contains all required columns, including `task_group`, `sex_norm`, and `threshold_used_global`. It confirms that the `seed` value stored in the file matches the expected seed, to avoid mixing results from different runs, and verifies that `threshold_used_global` is the same for all rows in that file. This stored threshold is then used to turn probabilities into binary predictions, so all models are evaluated in the same way they were originally tested. Within each task group, the cell computes AUROC (set to NaN if it cannot be computed because only one class is present) and standard threshold-based metrics at the stored threshold: accuracy, precision, sensitivity (recall), specificity, F1 score, MCC, and the full confusion matrix (TN, FP, FN, TP). It also computes a sex-based fairness metric within each task group using only Parkinson’s cases: false negative rate (FNR) for males and females separately, the signed gap ΔFNR = FNR(F) − FNR(M), and the absolute gap. These fairness calculations use only rows with sex values `M` and `F`, and automatically exclude `UNK`.
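
As a worked illustration of the fairness metric described above, the following minimal sketch computes FNR by sex on PD-only rows and the signed and absolute gaps. The rows, the threshold value of 0.5, and the helper names `fnr` and `fnr_for` are invented for illustration and are not taken from the study's data or code.

```python
def fnr(y_true, y_pred):
    """False negative rate on PD rows only: FN / (FN + TP); None if there are no PD rows."""
    pd_preds = [p for t, p in zip(y_true, y_pred) if t == 1]
    if not pd_preds:
        return None
    fn = sum(1 for p in pd_preds if p == 0)
    return fn / len(pd_preds)

# Toy rows: (y_true, y_score, sex_norm). Values are invented for illustration.
rows = [
    (1, 0.90, "M"), (1, 0.40, "M"), (1, 0.70, "M"), (1, 0.80, "M"),
    (1, 0.30, "F"), (1, 0.20, "F"), (1, 0.60, "F"), (1, 0.95, "F"),
    (0, 0.10, "M"), (0, 0.55, "F"), (1, 0.10, "UNK"),
]
threshold = 0.5  # stands in for the stored threshold_used_global

def fnr_for(sex):
    # Keep only rows of the requested sex; UNK rows are never selected.
    sub = [(t, int(score >= threshold)) for t, score, g in rows if g == sex]
    return fnr([t for t, _ in sub], [p for _, p in sub])

fnr_m = fnr_for("M")            # 1 missed PD case out of 4
fnr_f = fnr_for("F")            # 2 missed PD cases out of 4
delta_signed = fnr_f - fnr_m    # ΔFNR = FNR(F) − FNR(M)
delta_abs = abs(delta_signed)
print(fnr_m, fnr_f, delta_signed, delta_abs)
```

A positive signed gap means PD cases are missed more often for females than for males at the chosen threshold; the absolute gap discards the direction.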

All results are saved in three forms. First, a single long-format table called `combined_taskgroup_metrics.csv` is written under the mechanistic evaluation root, with one row per condition, seed, and task group containing all key metrics and counts. Second, for each condition, a JSON summary file named `taskgroup_eval_normal.json` is written under `.../mechanistic_eval/<tag>/`. This file includes detailed per-seed results as well as mean and standard deviation summaries across the three seeds for each task group. Third, for each condition, a confusion-matrix table called `taskgroup_confusion_tables.csv` is written in the same folder so threshold outcomes can be reviewed later without rerunning the analysis.
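
The across-seed summaries mentioned above can be sketched as follows, assuming a sample standard deviation (n − 1 denominator) over the seeds that produced a usable value. The function name `mean_sd` and the AUROC numbers are hypothetical and serve only as a worked example.

```python
import math
import statistics

def mean_sd(values):
    """Mean and sample SD (n - 1 denominator) over the non-NaN values."""
    vals = [float(v) for v in values if not math.isnan(float(v))]
    if not vals:
        return {"mean": math.nan, "sd": math.nan, "n": 0}
    if len(vals) == 1:
        return {"mean": vals[0], "sd": 0.0, "n": 1}
    return {"mean": statistics.mean(vals), "sd": statistics.stdev(vals), "n": len(vals)}

# Invented AUROC values for one condition and one task group, keyed by seed
auroc_by_seed = {1337: 0.80, 2024: 0.84, 7777: 0.82}
summary = mean_sd(auroc_by_seed.values())

# A seed whose slice contained a single class contributes NaN and is skipped
summary_with_gap = mean_sd([0.70, math.nan, 0.90])
print(summary, summary_with_gap)
```

Reporting `n` next to the mean and SD makes it visible when a seed was dropped because its metric could not be computed.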

The cell also creates visual summaries and saves them as PNG files. It generates four comparison plots under `.../mechanistic_eval/plots_all_conditions/`: AUROC by condition for vowel clips and for other clips, and ΔFNR by condition for vowel clips and for other clips. Each plot shows individual seed results as slightly jittered points, along with a mean line and error bars showing the standard deviation across seeds. In addition, the cell creates a single PD-denominator plot that is global rather than condition-specific. This plot shows the number of PD males and PD females for vowel and other clips, reflecting that these counts come from the fixed D2 test set and do not change across conditions. Finally, to keep each condition folder complete on its own, the cell copies all generated plots from the shared plots folder into `.../mechanistic_eval/<tag>/plots/` for every condition.

In [None]:
# =========================
# Data ablation from saved predictions (no new inference) + summary plots
# -------------------------
# Purpose
# - Compare performance separately on vowel clips vs other clips using the already-saved predictions.
# - Report both accuracy-style metrics and a sex-based fairness gap (ΔFNR on PD-only clips).
#
# Inputs
# - RUNS_RESOLVED (from Cell 1): per condition, per seed path to predictions.csv
# - Each predictions.csv must include every column in REQUIRED_COLS:
#   clip_path, y_true, y_score, sex_norm, speaker_id, task_group, seed, threshold_used_global
#
# Outputs
# - Per condition:
#   - taskgroup_eval_normal.json (per seed metrics + mean±SD summary)
#   - taskgroup_confusion_tables.csv (confusion counts per seed and task_group)
# - Combined:
#   - combined_taskgroup_metrics.csv (one row per condition, seed, task_group)
# - Plots:
#   - AUROC by condition (vowel and other, separate figures)
#   - ΔFNR by condition (vowel and other, separate figures)
#   - One global PD denominator plot (vowel vs other) since denominators do not change by condition
# =========================

import os, json
from pathlib import Path

import numpy as np
import pandas as pd

from sklearn.metrics import (
    roc_auc_score,
    confusion_matrix,
    accuracy_score,
    precision_score,
    recall_score,
    f1_score,
    matthews_corrcoef,
)

import matplotlib.pyplot as plt

# -------------------------
# 0) Guards (this cell depends on Cell 1 variables)
# -------------------------
if "RUNS_RESOLVED" not in globals():
    raise RuntimeError("RUNS_RESOLVED not found. Run Cell 1 first in the same runtime.")
if "MECH_EVAL_ROOT" not in globals():
    raise RuntimeError("MECH_EVAL_ROOT not found. Run Cell 1 first in the same runtime.")
if "SEEDS" not in globals():
    raise RuntimeError("SEEDS not found. Run Cell 1 first in the same runtime.")

# -------------------------
# 0.5) Plot readability defaults (150% text size)
# -------------------------
# Inputs: none (affects matplotlib defaults only)
# Outputs: larger text across ALL plots generated in this cell (AUROC, ΔFNR, PD denominators)
# Named sizes ("medium", "large", ...) are relative to font.size and scale with it, so only
# numeric defaults are multiplied (multiplying a named size by 1.5 raises a TypeError).
# Defaults are read from rcParamsDefault so re-running the cell does not compound the scaling.
_FONT_SCALE = 1.5

def _scaled_font(key):
    default = plt.rcParamsDefault[key]
    return default if isinstance(default, str) else float(default) * _FONT_SCALE

plt.rcParams.update({
    key: _scaled_font(key)
    for key in (
        "font.size", "axes.titlesize", "axes.labelsize",
        "xtick.labelsize", "ytick.labelsize", "legend.fontsize",
    )
})

# -------------------------
# 1) Expected schema and evaluation slices
# -------------------------
# Inputs: required columns present in predictions.csv
# Outputs: constants used for validation, slicing, and plot locations
REQUIRED_COLS = [
    "clip_path",
    "y_true",
    "y_score",
    "sex_norm",
    "speaker_id",
    "task_group",
    "seed",
    "threshold_used_global",
]
TASK_GROUPS = ["vowel", "other"]
SEX_LEVELS = ["M", "F"]  # UNK excluded from sex-specific fairness metrics

PLOTS_ALL_DIR = os.path.join(MECH_EVAL_ROOT, "plots_all_conditions")
os.makedirs(PLOTS_ALL_DIR, exist_ok=True)

# -------------------------
# 2) Helper functions (metrics + safety checks)
# -------------------------
# Inputs: arrays or small DataFrames
# Outputs: scalar metrics, summaries, and validation errors when needed
def _auc_or_nan(y_true, y_score):
    # AUROC needs both classes; return NaN if the slice is single-class
    y_true = np.asarray(y_true).astype(int)
    if len(np.unique(y_true)) < 2:
        return np.nan
    return float(roc_auc_score(y_true, y_score))

def _specificity_from_cm(cm):
    # Specificity = TN / (TN + FP)
    tn, fp, fn, tp = cm.ravel()
    denom = tn + fp
    return float(tn / denom) if denom > 0 else np.nan

def _fnr_on_pd_only(y_true, y_pred):
    # FNR computed only on PD clips: FN / (FN + TP)
    y_true = np.asarray(y_true).astype(int)
    y_pred = np.asarray(y_pred).astype(int)
    mask = (y_true == 1)
    if int(mask.sum()) == 0:
        return np.nan, 0
    fn = int((y_pred[mask] == 0).sum())
    tp = int((y_pred[mask] == 1).sum())
    denom = fn + tp
    fnr = float(fn / denom) if denom > 0 else np.nan
    return fnr, int(mask.sum())

def _mean_sd(values):
    # Mean and sample SD (ddof=1) over non-NaN values
    vals = np.asarray(values, dtype=float)
    vals = vals[~np.isnan(vals)]
    if len(vals) == 0:
        return {"mean": np.nan, "sd": np.nan, "n": 0}
    if len(vals) == 1:
        return {"mean": float(vals[0]), "sd": 0.0, "n": 1}
    return {"mean": float(np.mean(vals)), "sd": float(np.std(vals, ddof=1)), "n": int(len(vals))}

def _verify_required_cols(df, path):
    # Fail fast if predictions.csv schema is not as expected
    missing = [c for c in REQUIRED_COLS if c not in df.columns]
    if missing:
        raise ValueError(f"Missing required columns in {path}: {missing}")

def _read_predictions_csv(path):
    # Read and validate one predictions.csv
    df = pd.read_csv(path)
    _verify_required_cols(df, path)
    return df

def compute_metrics_for_subset(df_subset, threshold):
    """
    Compute metrics for one slice (one seed, one task_group).
    Returns:
    - AUROC
    - threshold-based metrics (confusion + common scores)
    - fairness: FNR by sex on PD-only clips and ΔFNR = FNR(F) − FNR(M)
    """
    y_true = df_subset["y_true"].astype(int).to_numpy()
    y_score = df_subset["y_score"].astype(float).to_numpy()
    y_pred = (y_score >= threshold).astype(int)

    auc = _auc_or_nan(y_true, y_score)

    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
    tn, fp, fn, tp = cm.ravel()

    acc = float(accuracy_score(y_true, y_pred))
    prec = float(precision_score(y_true, y_pred, zero_division=0))
    rec = float(recall_score(y_true, y_pred, zero_division=0))  # sensitivity/recall
    f1v = float(f1_score(y_true, y_pred, zero_division=0))
    spec = _specificity_from_cm(cm)
    mcc = float(matthews_corrcoef(y_true, y_pred)) if (len(np.unique(y_true)) > 1 or len(np.unique(y_pred)) > 1) else np.nan

    # Fairness: compute FNR separately for M and F on PD-only clips
    fnr_vals = {}
    n_pd_vals = {}
    for sex in SEX_LEVELS:
        dsex = df_subset[df_subset["sex_norm"] == sex]
        y_t = dsex["y_true"].astype(int).to_numpy()
        y_p = (dsex["y_score"].astype(float).to_numpy() >= threshold).astype(int)
        fnr, n_pd = _fnr_on_pd_only(y_t, y_p)
        fnr_vals[sex] = fnr
        n_pd_vals[sex] = n_pd

    delta_fnr = np.nan
    abs_delta_fnr = np.nan
    if not np.isnan(fnr_vals["F"]) and not np.isnan(fnr_vals["M"]):
        delta_fnr = float(fnr_vals["F"] - fnr_vals["M"])
        abs_delta_fnr = float(abs(delta_fnr))

    return {
        "auroc": auc,
        "threshold": float(threshold),
        "metrics_at_threshold": {
            "accuracy": acc,
            "precision": prec,
            "sensitivity_recall": rec,
            "specificity": spec,
            "f1": f1v,
            "mcc": mcc,
            "confusion": {"TN": int(tn), "FP": int(fp), "FN": int(fn), "TP": int(tp)},
            "n": int(len(df_subset)),
        },
        "fairness": {
            "FNR_M": fnr_vals["M"],
            "FNR_F": fnr_vals["F"],
            "delta_FNR_F_minus_M": delta_fnr,
            "abs_delta_FNR": abs_delta_fnr,
            "n_PD_M": int(n_pd_vals["M"]),
            "n_PD_F": int(n_pd_vals["F"]),
        },
    }

# -------------------------
# 3) Main scoring loop (condition → seed → task_group)
# -------------------------
# Inputs: RUNS_RESOLVED and SEEDS
# Outputs: per-condition JSON summaries and two combined in-memory tables for CSV + plots
all_rows = []       # one row per condition/seed/task_group
all_conf_rows = []  # confusion counts per condition/seed/task_group

for tag, info in RUNS_RESOLVED.items():
    out_dir = info["out_dir"]
    plots_dir = os.path.join(out_dir, "plots")
    os.makedirs(plots_dir, exist_ok=True)

    per_seed = {}

    for seed in SEEDS:
        pred_path = info["predictions_by_seed"][seed]
        df = _read_predictions_csv(pred_path)

        # Sanity: ensure the file contains exactly one seed and it matches the expected seed
        unique_seeds = sorted(set(df["seed"].astype(int).tolist()))
        if len(unique_seeds) != 1 or unique_seeds[0] != int(seed):
            raise ValueError(f"[{tag}] Seed mismatch in {pred_path}. Found seeds {unique_seeds}, expected {seed}.")

        # Threshold is read from the file and must be constant within the seed file
        thr_unique = df["threshold_used_global"].astype(float).to_numpy()
        thr_unique_rounded = sorted(set([round(float(x), 12) for x in thr_unique]))
        if len(thr_unique_rounded) != 1:
            raise ValueError(
                f"[{tag}] threshold_used_global is not constant within seed {seed} in {pred_path}. "
                f"Unique (rounded) thresholds: {thr_unique_rounded[:10]}"
            )
        threshold = float(thr_unique_rounded[0])

        per_seed.setdefault(seed, {})

        # Score vowel-only and other-only slices separately
        for tg in TASK_GROUPS:
            dfg = df[df["task_group"] == tg].copy()
            m = compute_metrics_for_subset(dfg, threshold)
            per_seed[seed][tg] = m

            # Flat row for combined CSV and plotting
            all_rows.append({
                "tag": tag,
                "seed": int(seed),
                "task_group": tg,
                "threshold_used_global": float(threshold),
                "auroc": m["auroc"],
                "accuracy": m["metrics_at_threshold"]["accuracy"],
                "precision": m["metrics_at_threshold"]["precision"],
                "sensitivity_recall": m["metrics_at_threshold"]["sensitivity_recall"],
                "specificity": m["metrics_at_threshold"]["specificity"],
                "f1": m["metrics_at_threshold"]["f1"],
                "mcc": m["metrics_at_threshold"]["mcc"],
                "FNR_M": m["fairness"]["FNR_M"],
                "FNR_F": m["fairness"]["FNR_F"],
                "delta_FNR_F_minus_M": m["fairness"]["delta_FNR_F_minus_M"],
                "abs_delta_FNR": m["fairness"]["abs_delta_FNR"],
                "n_PD_M": int(m["fairness"]["n_PD_M"]),
                "n_PD_F": int(m["fairness"]["n_PD_F"]),
                "n_rows": int(m["metrics_at_threshold"]["n"]),
                "predictions_csv": pred_path,
            })

            # Confusion table row for easier inspection
            conf = m["metrics_at_threshold"]["confusion"]
            all_conf_rows.append({
                "tag": tag,
                "seed": int(seed),
                "task_group": tg,
                "threshold_used_global": float(threshold),
                **conf,
                "n_rows": int(m["metrics_at_threshold"]["n"]),
                "predictions_csv": pred_path,
            })

    # Mean ± SD summary across seeds for each task_group
    summary = {}
    for tg in TASK_GROUPS:
        summary[tg] = {
            "auroc": _mean_sd([per_seed[s][tg]["auroc"] for s in SEEDS]),
            "accuracy": _mean_sd([per_seed[s][tg]["metrics_at_threshold"]["accuracy"] for s in SEEDS]),
            "precision": _mean_sd([per_seed[s][tg]["metrics_at_threshold"]["precision"] for s in SEEDS]),
            "sensitivity_recall": _mean_sd([per_seed[s][tg]["metrics_at_threshold"]["sensitivity_recall"] for s in SEEDS]),
            "specificity": _mean_sd([per_seed[s][tg]["metrics_at_threshold"]["specificity"] for s in SEEDS]),
            "f1": _mean_sd([per_seed[s][tg]["metrics_at_threshold"]["f1"] for s in SEEDS]),
            "mcc": _mean_sd([per_seed[s][tg]["metrics_at_threshold"]["mcc"] for s in SEEDS]),
            "FNR_M": _mean_sd([per_seed[s][tg]["fairness"]["FNR_M"] for s in SEEDS]),
            "FNR_F": _mean_sd([per_seed[s][tg]["fairness"]["FNR_F"] for s in SEEDS]),
            "delta_FNR_F_minus_M": _mean_sd([per_seed[s][tg]["fairness"]["delta_FNR_F_minus_M"] for s in SEEDS]),
            "abs_delta_FNR": _mean_sd([per_seed[s][tg]["fairness"]["abs_delta_FNR"] for s in SEEDS]),
            "n_PD_M": {"values_by_seed": {str(s): int(per_seed[s][tg]["fairness"]["n_PD_M"]) for s in SEEDS}},
            "n_PD_F": {"values_by_seed": {str(s): int(per_seed[s][tg]["fairness"]["n_PD_F"]) for s in SEEDS}},
        }

    # Per-condition JSON output written under its mechanistic evaluation folder
    out_json = {
        "tag": tag,
        "mode": "normal_routing_existing_predictions",
        "seeds": [int(s) for s in SEEDS],
        "per_seed": per_seed,
        "summary_mean_sd_across_seeds": summary,
        "out_dir": out_dir,
    }
    with open(os.path.join(out_dir, "taskgroup_eval_normal.json"), "w", encoding="utf-8") as f:
        json.dump(out_json, f, indent=2)

# -------------------------
# 4) Write combined CSV outputs
# -------------------------
# Inputs: all_rows and all_conf_rows collected above
# Outputs: combined_taskgroup_metrics.csv under MECH_EVAL_ROOT, plus per-tag confusion CSVs
combined_df = pd.DataFrame(all_rows)
combined_csv_path = os.path.join(MECH_EVAL_ROOT, "combined_taskgroup_metrics.csv")
combined_df.to_csv(combined_csv_path, index=False)

conf_df = pd.DataFrame(all_conf_rows)
for tag in RUNS_RESOLVED.keys():
    out_dir = RUNS_RESOLVED[tag]["out_dir"]
    conf_df_tag = conf_df[conf_df["tag"] == tag].copy()
    conf_df_tag.to_csv(os.path.join(out_dir, "taskgroup_confusion_tables.csv"), index=False)

print("Wrote:")
print("  ", combined_csv_path)
for tag in RUNS_RESOLVED.keys():
    print("  ", os.path.join(RUNS_RESOLVED[tag]["out_dir"], "taskgroup_eval_normal.json"))
    print("  ", os.path.join(RUNS_RESOLVED[tag]["out_dir"], "taskgroup_confusion_tables.csv"))

# -------------------------
# 5) Plotting helpers (seed points + mean±SD)
# -------------------------
# Inputs: combined_df filtered to one task_group
# Outputs: PNG figures saved to the shared plots folder
def _condition_order(tags):
    # Stable ordering for readability across plots
    preferred = ["base", "trainEnh1", "trainEnh2", "trainEnh3"]
    ordered = [t for t in preferred if t in tags] + [t for t in tags if t not in preferred]
    return ordered

def _jitter(xs, seed_index, jitter_scale=0.06):
    # Small horizontal offset so seed points do not overlap
    offsets = [-1, 0, 1]
    o = offsets[seed_index % len(offsets)]
    return xs + o * jitter_scale

def plot_metric_across_conditions(df_in, metric_col, title, ylabel, out_path):
    # Plot per-seed points and mean±SD across seeds for one metric
    tags = _condition_order(sorted(df_in["tag"].unique().tolist()))
    x = np.arange(len(tags), dtype=float)

    plt.figure(figsize=(10, 4.5))

    # Seed points
    for si, seed in enumerate(SEEDS):
        vals = []
        for t in tags:
            sub = df_in[(df_in["tag"] == t) & (df_in["seed"] == seed)]
            v = sub[metric_col].iloc[0] if len(sub) == 1 else np.nan
            vals.append(v)
        plt.scatter(_jitter(x, si), vals, label=f"seed{seed}", s=28)

    # Mean ± SD across seeds
    means = []
    sds = []
    for t in tags:
        sub = df_in[df_in["tag"] == t][metric_col].astype(float).to_numpy()
        means.append(np.nanmean(sub) if np.any(~np.isnan(sub)) else np.nan)
        sds.append(np.nanstd(sub, ddof=1) if np.sum(~np.isnan(sub)) >= 2 else 0.0)

    plt.errorbar(x, means, yerr=sds, fmt="o-", capsize=4)

    plt.xticks(x, tags, rotation=0)
    plt.title(title)
    plt.ylabel(ylabel)
    plt.grid(True, axis="y", alpha=0.3)
    plt.tight_layout()
    plt.savefig(out_path, dpi=200)
    plt.close()

def plot_pd_denoms_by_taskgroup_global(combined_df_in, title, out_path):
    """
    Global PD denominator plot.
    - X axis: task_group (vowel vs other)
    - Bars: PD counts for M and F
    Denominators should match across conditions because task_group and sex come from the dataset, not the model.
    """
    plt.figure(figsize=(8, 4.5))

    xlabels = TASK_GROUPS
    x = np.arange(len(xlabels), dtype=float)
    width = 0.35

    n_m = []
    n_f = []

    for tg in TASK_GROUPS:
        sub = combined_df_in[combined_df_in["task_group"] == tg].copy()

        # Robust aggregation (should be identical across tag and seed rows)
        n_m.append(float(np.nanmedian(sub["n_PD_M"].astype(float).to_numpy())))
        n_f.append(float(np.nanmedian(sub["n_PD_F"].astype(float).to_numpy())))

    plt.bar(x - width/2, n_m, width=width, label="n_PD_M")
    plt.bar(x + width/2, n_f, width=width, label="n_PD_F")

    plt.xticks(x, xlabels, rotation=0)
    plt.title(title)
    plt.ylabel("Count (PD only)")
    plt.grid(True, axis="y", alpha=0.3)
    plt.legend()
    plt.tight_layout()
    plt.savefig(out_path, dpi=200)
    plt.close()

# -------------------------
# 6) Generate plots requested for this cell
# -------------------------
# Inputs: combined_df written above
# Outputs: 5 PNG files saved under plots_all_conditions
# (two AUROC, two ΔFNR, one global denominator plot)

# AUROC by condition (separate plots for vowel and other)
for tg in TASK_GROUPS:
    df_tg = combined_df[combined_df["task_group"] == tg].copy()
    auroc_path = os.path.join(PLOTS_ALL_DIR, f"auroc_by_condition__taskgroup_{tg}.png")
    plot_metric_across_conditions(
        df_tg,
        metric_col="auroc",
        title=f"AUROC by condition (task_group = {tg})",
        ylabel="AUROC",
        out_path=auroc_path,
    )

# ΔFNR by condition (separate plots for vowel and other)
for tg in TASK_GROUPS:
    df_tg = combined_df[combined_df["task_group"] == tg].copy()
    gap_path = os.path.join(PLOTS_ALL_DIR, f"deltaFNR_F_minus_M_by_condition__taskgroup_{tg}.png")
    plot_metric_across_conditions(
        df_tg,
        metric_col="delta_FNR_F_minus_M",
        title=f"ΔFNR (FNR_F − FNR_M) by condition (task_group = {tg})",
        ylabel="ΔFNR (signed)",
        out_path=gap_path,
    )

# PD denominators (single global plot: vowel vs other)
denom_global_path = os.path.join(PLOTS_ALL_DIR, "pd_denominators_by_taskgroup_global.png")
plot_pd_denoms_by_taskgroup_global(
    combined_df,
    title="PD denominators by task group (global, D2 test)",
    out_path=denom_global_path,
)

print("Wrote plots under:")
print("  ", PLOTS_ALL_DIR)
print("  ", denom_global_path)

# -------------------------
# 7) Copy plots into each condition folder for convenience
# -------------------------
# Inputs: PNGs in plots_all_conditions
# Outputs: same PNGs duplicated into each <condition>/plots/ folder
PLOTS_TO_COPY = []

for tg in TASK_GROUPS:
    PLOTS_TO_COPY.append(f"auroc_by_condition__taskgroup_{tg}.png")
    PLOTS_TO_COPY.append(f"deltaFNR_F_minus_M_by_condition__taskgroup_{tg}.png")

PLOTS_TO_COPY.append("pd_denominators_by_taskgroup_global.png")

for tag in RUNS_RESOLVED.keys():
    plots_dir = os.path.join(RUNS_RESOLVED[tag]["out_dir"], "plots")
    os.makedirs(plots_dir, exist_ok=True)
    for fname in PLOTS_TO_COPY:
        src = os.path.join(PLOTS_ALL_DIR, fname)
        dst = os.path.join(plots_dir, fname)
        with open(src, "rb") as fsrc, open(dst, "wb") as fdst:
            fdst.write(fsrc.read())

print("Copied plots into each tag's plots/ folder as well.")

# Ablation 3: Forced Task-Head Routing Ablation

The following cell runs a focused **mechanistic evaluation** to see how much the model depends on the choice of task head (vowel head versus other head) when scoring D2 test audio. It produces the five required summary plots. Unlike earlier plotting-only cells, this cell **does run new inference**. It loads already trained model heads for the Base model and for the three enhanced models (trainEnh1, trainEnh2, trainEnh3), then re-scores the same D2 test clips multiple times while deliberately forcing the model to use one head or the other.

First, the cell makes sure Google Drive is mounted, then rebuilds a deterministic run map (`RUNS_RESOLVED`) from the `tag_run_pointer.json` files and the train–validation experiment folders. For each condition it selects the most recent matching folder, ordered by the `_YYYYMMDD_HHMMSS` stamp in the folder name and then by modification time, and it lists only the immediate children of the known run directories rather than walking the whole tree. This avoids relying on outputs from earlier cells. For each condition (base, trainEnh1, trainEnh2, trainEnh3), it resolves the locations of the existing “normal” `predictions.csv` files for each seed, checks that those files exist, and sets up a dedicated output folder under `.../mechanistic_eval/<tag>/`.
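
As a rough illustration of how a run folder can be picked deterministically, the sketch below sorts candidate folder names by a trailing `_YYYYMMDD_HHMMSS` stamp. The folder names and the `pick_latest` helper are made up for this example, and the sketch leaves out the modification-time tie-breaker and the required-file checks that the cell's own helpers apply.

```python
import re
from typing import List, Tuple

def parse_run_stamp(name: str) -> Tuple[int, int]:
    """Return (yyyymmdd, hhmmss) parsed from a folder name, or (0, 0) if absent."""
    m = re.search(r"_(\d{8})_(\d{6})(?:\D|$)", name)
    return (int(m.group(1)), int(m.group(2))) if m else (0, 0)

def pick_latest(names: List[str], tag: str) -> str:
    """Pick the newest name that contains `tag` (case-insensitive)."""
    matches = [n for n in names if tag.lower() in n.lower()]
    if not matches:
        raise FileNotFoundError(f"No folder name contains tag={tag!r}")
    # Names without a stamp sort as (0, 0), so any stamped folder wins over them.
    return max(matches, key=parse_run_stamp)

# Hypothetical folder names, for illustration only.
example_names = [
    "exp_trainEnh1_20250101_120000",
    "exp_trainEnh1_20250215_093000",
    "exp_trainEnh2_20250301_080000",
    "exp_trainEnh1_notes",
]
latest = pick_latest(example_names, "trainEnh1")
print(latest)  # exp_trainEnh1_20250215_093000
```

Sorting on the parsed stamp rather than on the raw string keeps the choice stable even when folder prefixes differ in length or case.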

Next, the cell loads the D2 master manifest (`manifests/manifest_all.csv`), filters it to **only the D2 test split**, and standardizes two fields needed for analysis. Sex is normalized to `M`, `F`, or `UNK` using D2’s original `male` and `female` values, and each clip is assigned a **task group**: `vowel` only if `task == "vowl"`, otherwise `other`. A PyTorch dataset and data loader are created to read audio from `clip_path` and build an attention mask. For vowel clips, the mask turns off after the last meaningful audio sample so padded silence does not affect the model. For other clips, the mask stays fully on. Audio and masks are padded consistently within each batch.
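
The masking rule is easier to see on a tiny waveform. The sketch below is a plain-Python illustration with made-up sample values; the helper name `trailing_silence_mask` is hypothetical, and the cell itself works on NumPy arrays. It shows that only trailing near-zero samples are masked for vowel clips, that leading silence is kept, and that an all-silent clip keeps a fully active mask.

```python
from typing import List

TINY_THRESH = 1e-4  # same value as the cell's TINY_THRESH setting

def trailing_silence_mask(samples: List[float], task_group: str) -> List[int]:
    """Return 1 for samples to attend to and 0 for samples to ignore."""
    mask = [1] * len(samples)
    if task_group != "vowel":
        return mask  # non-vowel clips keep every sample active
    # Index of the last sample whose magnitude exceeds the threshold (-1 if none).
    last_active = max(
        (i for i, v in enumerate(samples) if abs(v) > TINY_THRESH), default=-1
    )
    if last_active >= 0:
        mask[last_active + 1:] = [0] * (len(samples) - last_active - 1)
    return mask  # an all-silent clip keeps a fully active mask

# Toy waveform (made-up values): leading silence, signal, then near-zero padding.
toy_clip = [0.0, 0.2, -0.3, 0.00005, 0.0, 0.0]
vowel_mask = trailing_silence_mask(toy_clip, "vowel")
other_mask = trailing_silence_mask(toy_clip, "other")
silent_mask = trailing_silence_mask([0.0, 0.0, 0.0], "vowel")
print(vowel_mask)   # [1, 1, 1, 0, 0, 0]
print(other_mask)   # [1, 1, 1, 1, 1, 1]
print(silent_mask)  # [1, 1, 1]
```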

For each condition and each of the three seeds, the cell loads the saved heads (`best_heads.pt`) from the selected train–validation experiment and attaches them to a frozen wav2vec2 backbone. It reads the **single global threshold actually used in the normal test run** from the existing `predictions.csv` (`threshold_used_global`) and confirms that it is constant within that file. It then runs two new inference passes over the full D2 test set. One pass forces the **vowel head for every clip** and writes `predictions_force_vowel.csv`. The other pass forces the **other head for every clip** and writes `predictions_force_other.csv`. Both files include the same core fields as the normal predictions (clip path, true label, predicted probability, sex, speaker id, task group, seed), plus extra fields for traceability, such as which head was forced, which global threshold was used, which train–validation run the heads came from, and a run timestamp. All forced outputs are written under `.../mechanistic_eval/<tag>/forced_head_runs/run_D7_on_D2test_seed{seed}/` so they never overwrite the original test results.
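
To make the idea of forced routing concrete, here is a minimal sketch in plain Python. It assumes each head is a linear layer followed by a sigmoid acting on a pooled embedding; the real heads stored in `best_heads.pt` may be structured differently. The weights, the embedding, and the `score_clip` helper are hypothetical and exist only for this illustration.

```python
import math
from typing import Dict, List, Optional, Tuple

# Hypothetical heads: (weights, bias) acting on a 2-dimensional pooled embedding.
HEADS: Dict[str, Tuple[List[float], float]] = {
    "vowel": ([1.0, -0.5], 0.0),
    "other": ([-0.25, 0.75], 0.1),
}

def score_clip(embedding: List[float], task_group: str,
               force_head: Optional[str] = None) -> float:
    """Return a probability from the routed head, or from the forced head if given."""
    # Normal routing uses the clip's own task group; forcing overrides that choice.
    head_name = task_group if force_head is None else force_head
    weights, bias = HEADS[head_name]
    logit = sum(w * x for w, x in zip(weights, embedding)) + bias
    return 1.0 / (1.0 + math.exp(-logit))

# One made-up vowel clip scored under the three routing modes.
embedding = [2.0, 1.0]
normal = score_clip(embedding, "vowel")
forced_vowel = score_clip(embedding, "vowel", force_head="vowel")
forced_other = score_clip(embedding, "vowel", force_head="other")
print(f"normal={normal:.3f} force_vowel={forced_vowel:.3f} force_other={forced_other:.3f}")
```

For a vowel clip, normal routing and `force_vowel` give the same score by construction, so any change in the metrics for vowel clips comes entirely from the `force_other` pass (and the reverse holds for other clips).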

After inference, the cell computes performance and fairness metrics for three routing modes for each condition and seed: **normal routing**, **force_vowel**, and **force_other**. Metrics are computed separately for vowel-only clips and other-only clips. Performance is measured using AUROC, which is recorded as NaN if a subset contains only one class. Fairness is measured using the **sex-based difference in false negative rate among Parkinson’s cases only**, defined as ΔFNR = FNR(F) − FNR(M), evaluated at the same global threshold used in the normal test run. All results are collected into a long-format table (`forced_head_metrics_long.csv`) under `mechanistic_eval/` and into per-condition JSON summaries (`forced_head_eval_summary.json`) under each condition folder. A registry file (`forced_head_registry.json`) is also written to record exactly which prediction files and thresholds were used for each seed.
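
The fairness metric can be checked by hand on a few rows. The sketch below uses made-up scores and hypothetical helper names; it applies the same decision rule as the scoring code (`y_score >= threshold` predicts PD), computes the false negative rate on PD rows only for each sex, and returns NaN when either sex has no PD rows.

```python
import math
from typing import List, Tuple

def fnr_on_pd(y_true: List[int], y_score: List[float], threshold: float) -> float:
    """FNR = FN / (FN + TP), computed over PD rows (y_true == 1) only."""
    pd_scores = [s for t, s in zip(y_true, y_score) if t == 1]
    if not pd_scores:
        return math.nan
    misses = sum(1 for s in pd_scores if s < threshold)  # PD rows predicted healthy
    return misses / len(pd_scores)

def delta_fnr(rows: List[Tuple[str, int, float]], threshold: float) -> float:
    """rows are (sex, y_true, y_score); returns FNR(F) - FNR(M) or NaN."""
    fnr = {}
    for sex in ("M", "F"):
        y_true = [t for sx, t, _ in rows if sx == sex]
        y_score = [s for sx, _, s in rows if sx == sex]
        fnr[sex] = fnr_on_pd(y_true, y_score, threshold)
    if math.isnan(fnr["F"]) or math.isnan(fnr["M"]):
        return math.nan
    return fnr["F"] - fnr["M"]

# Toy rows (made-up values): four PD clips and one healthy clip per sex.
toy_rows = [
    ("F", 1, 0.9), ("F", 1, 0.4), ("F", 1, 0.3), ("F", 1, 0.7), ("F", 0, 0.2),
    ("M", 1, 0.8), ("M", 1, 0.6), ("M", 1, 0.2), ("M", 1, 0.55), ("M", 0, 0.6),
]
gap = delta_fnr(toy_rows, threshold=0.5)
print(gap)  # 0.25  (FNR_F = 2/4, FNR_M = 1/4)
```

A positive value means PD is missed more often for female speakers than for male speakers at that threshold. The healthy rows do not enter the calculation at all.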

Finally, the cell generates the five required plots and saves them as PNG files under `.../mechanistic_eval/plots_all_conditions/`. Four plots show grouped bar charts (mean with standard deviation across the three seeds) for AUROC and ΔFNR across the four conditions, with the three routing modes shown side by side, separately for vowel-only clips and other-only clips. The fifth plot shows “head sensitivity” by computing **AUROC(force vowel) − AUROC(force other)** for each condition and seed, and plotting results for vowel clips and other clips with mean and standard deviation across seeds. At the end, the cell attempts to unassign the Colab runtime, matching the behavior of the test-only cells, so the session is released cleanly after completion.
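
The head-sensitivity quantity is a per-seed difference that is then summarized across seeds. The sketch below shows the arithmetic with made-up AUROC values. It assumes the sample standard deviation (ddof = 1), as used by `_mean_sd` in the previous cell, and the `head_sensitivity` helper is hypothetical.

```python
import statistics
from typing import Dict, Tuple

def head_sensitivity(auroc_force_vowel: Dict[int, float],
                     auroc_force_other: Dict[int, float]) -> Tuple[float, float]:
    """Mean and sample SD of AUROC(force vowel) - AUROC(force other) across seeds."""
    seeds = sorted(auroc_force_vowel)
    diffs = [auroc_force_vowel[s] - auroc_force_other[s] for s in seeds]
    # statistics.stdev needs at least two values; fall back to 0.0 for a single seed.
    sd = statistics.stdev(diffs) if len(diffs) > 1 else 0.0
    return statistics.mean(diffs), sd

# Made-up AUROC values keyed by seed, purely to show the calculation.
toy_force_vowel = {1337: 0.80, 2024: 0.75, 7777: 0.85}
toy_force_other = {1337: 0.70, 2024: 0.70, 7777: 0.75}

mean_diff, sd_diff = head_sensitivity(toy_force_vowel, toy_force_other)
print(f"head sensitivity: mean={mean_diff:.3f}, sd={sd_diff:.3f}")
```

A mean near zero would suggest that the two heads rank clips similarly for that task group, while a large positive or negative mean points to a head that is better matched to those clips.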

In [None]:
# =========================
# Task-head ablation on D2 test (new inference) + required mechanistic plots
# -------------------------
# Purpose
# - Quantify how much performance and fairness depend on which head is used.
# - Re-run inference on the D2 test split while forcing:
#     (1) vowel head for every clip
#     (2) other head for every clip
# - Compare those to the normal routed mode using the same global threshold per seed.
#
# Inputs
# - Pointer JSON files that identify where existing per-seed predictions.csv live
# - D2 manifest (test split) and audio clips
# - Trained heads (best_heads.pt) from each condition and seed
#
# Outputs
# - Per seed, per condition: predictions_force_vowel.csv and predictions_force_other.csv
# - Long-form metrics table (CSV) and per-condition JSON summaries
# - Five PNG plots saved under a shared plots folder
# - Runtime unassigned at the end
# =========================

import os
import json
import math
import time
import warnings
import re
from pathlib import Path
from typing import Dict, Any, List, Tuple

import numpy as np
import pandas as pd

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

import soundfile as sf
from tqdm.auto import tqdm

from sklearn.metrics import roc_auc_score
import matplotlib.pyplot as plt

# -------------------------
# Drive mount (Colab)
# -------------------------
# Ensures files are reachable even if this runtime started fresh.
if not os.path.isdir("/content/drive/MyDrive"):
    from google.colab import drive  # type: ignore
    drive.mount("/content/drive")

# -------------------------
# Fixed inputs (roots, seeds, and run identifiers)
# -------------------------
# Defines where to read D2 data from and where to write mechanistic evaluation outputs.
D7_OUT_ROOT = "/content/drive/MyDrive/AI_PD_Project/Datasets/D7-Multilingual (D1_D4_D5v2_D6)/preprocessed_v1"
D2_OUT_ROOT = "/content/drive/MyDrive/AI_PD_Project/Datasets/D2-Slovak (EWA-DB)/EWA-DB/preprocessed_v1"
SEEDS = [1337, 2024, 7777]

# Output root for all forced-head artifacts and plots.
MECH_EVAL_ROOT = f"{D7_OUT_ROOT}/mechanistic_eval"
Path(MECH_EVAL_ROOT).mkdir(parents=True, exist_ok=True)

def _parse_run_stamp_from_name(name: str) -> Tuple[int, int]:
    """
    Tries to parse ..._YYYYMMDD_HHMMSS from a folder name.
    Returns (yyyymmdd, hhmmss) as ints, or (0, 0) if not found.
    """
    m = re.search(r"_(\d{8})_(\d{6})(?:\D|$)", str(name))
    if not m:
        return (0, 0)
    return (int(m.group(1)), int(m.group(2)))

def _sort_key_latest(p: Path) -> Tuple[int, int, float]:
    a, b = _parse_run_stamp_from_name(p.name)
    try:
        mt = p.stat().st_mtime
    except Exception:
        mt = 0.0
    return (a, b, mt)

def _find_latest_dir(parent: Path, required_substrings: List[str], must_have_relpaths: List[str]) -> Path:
    """
    Picks the latest directory under parent whose name contains all required_substrings
    (case-insensitive), and that contains all must_have_relpaths.
    """
    if not parent.is_dir():
        raise FileNotFoundError(f"Missing directory: {str(parent)}")

    req = [s.lower() for s in required_substrings if str(s).strip()]
    candidates = []
    for d in parent.iterdir():
        if not d.is_dir():
            continue
        nm = d.name.lower()
        ok = True
        for s in req:
            if s not in nm:
                ok = False
                break
        if not ok:
            continue
        for rel in must_have_relpaths:
            if not (d / rel).exists():
                ok = False
                break
        if ok:
            candidates.append(d)

    if not candidates:
        raise FileNotFoundError(
            f"No matching run directory found under: {str(parent)} "
            f"required_substrings={required_substrings} must_have_relpaths={must_have_relpaths}"
        )

    candidates.sort(key=_sort_key_latest, reverse=True)
    return candidates[0]

def _find_latest_trainval_exp(parent: Path, tag: str, seeds: List[int]) -> Path:
    """
    Finds the latest trainval exp directory that contains required best_heads.pt files.
    Selection rules:
    - For trainEnh*: folder name must contain that tag (case-insensitive).
    - For base: folder name must NOT contain 'trainenh' (case-insensitive).
    """
    if not parent.is_dir():
        raise FileNotFoundError(f"Missing directory: {str(parent)}")

    tag_l = tag.lower().strip()
    candidates = []
    for d in parent.iterdir():
        if not d.is_dir():
            continue
        nm = d.name.lower()

        if tag_l == "base":
            if "trainenh" in nm:
                continue
        else:
            if tag_l not in nm:
                continue

        ok = True
        for s in seeds:
            rel = Path(f"run_D7_seed{s}") / "best_heads.pt"
            if not (d / rel).is_file():
                ok = False
                break
        if ok:
            candidates.append(d)

    if not candidates:
        raise FileNotFoundError(
            f"No matching trainval exp dir found under: {str(parent)} for tag={tag} "
            f"that contains best_heads.pt for seeds={seeds}"
        )

    candidates.sort(key=_sort_key_latest, reverse=True)
    return candidates[0]

# Tag pointers (one per condition). Used to rebuild run roots deterministically.
RUN_POINTERS = {}
_multitest_parent = Path(D7_OUT_ROOT) / "multilingual_test_runs"
_monotest_parent = Path(D7_OUT_ROOT) / "monolingual_test_runs"

RUN_POINTERS["base"] = str(_find_latest_dir(
    parent=_multitest_parent,
    required_substrings=["base"],
    must_have_relpaths=["tag_run_pointer.json"],
) / "tag_run_pointer.json")

for k in ["trainEnh1", "trainEnh2", "trainEnh3"]:
    RUN_POINTERS[k] = str(_find_latest_dir(
        parent=_monotest_parent,
        required_substrings=[k],
        must_have_relpaths=["tag_run_pointer.json"],
    ) / "tag_run_pointer.json")

# Trainval experiment folders used to load best_heads.pt for each seed.
TRVAL_EXP_DIRS = {}
_trval_parent = Path(D7_OUT_ROOT) / "trainval_runs"
for tag in ["base", "trainEnh1", "trainEnh2", "trainEnh3"]:
    TRVAL_EXP_DIRS[tag] = str(_find_latest_trainval_exp(_trval_parent, tag=tag, seeds=SEEDS))

# -------------------------
# Resolve run roots and normal predictions.csv paths from pointer JSONs
# -------------------------
# Reconstructs where the existing predictions.csv files live, without relying on earlier cells.
def _load_json(path: str) -> Dict[str, Any]:
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)

def resolve_run_root_from_pointer(pointer_path: str) -> Tuple[str, str]:
    """
    Returns (tag_root, run_root).
    - tag_root: folder that contains tag_run_pointer.json
    - run_root: folder that contains per-seed run subfolders with predictions.csv
    """
    tag_root = str(Path(pointer_path).parent)
    j = _load_json(pointer_path)

    # Accept several common key names seen in run pointers.
    for k in ["run_root", "run_dir", "test_run_root", "root", "resolved_run_root", "resolved_test_run_root"]:
        v = j.get(k, None)
        if isinstance(v, str) and v.strip():
            v2 = v.strip()
            if v2.startswith("/"):
                return tag_root, v2
            return tag_root, str(Path(tag_root) / v2)

    # Fallback: tag_root is used directly (typical for base layout).
    return tag_root, tag_root

def build_predictions_paths(run_root: str, seeds: List[int]) -> Dict[int, str]:
    """
    Builds the expected per-seed predictions.csv paths under a resolved run folder.
    Fails fast if any seed file is missing.
    """
    out = {}
    for s in seeds:
        p = str(Path(run_root) / f"run_D7_on_D2test_seed{s}" / "predictions.csv")
        if not os.path.isfile(p):
            raise FileNotFoundError(f"predictions.csv not found for seed{s} at: {p}")
        out[int(s)] = p
    return out

RUNS_RESOLVED: Dict[str, Any] = {}

# Basic existence checks (no silent fallback).
for tag, p in RUN_POINTERS.items():
    if not os.path.isfile(p):
        raise FileNotFoundError(f"[{tag}] pointer_json not found: {p}")
for tag, p in TRVAL_EXP_DIRS.items():
    if not os.path.isdir(p):
        raise FileNotFoundError(f"[{tag}] trainval exp dir not found: {p}")

# Build a resolved record per condition: where to read normal preds and where to write new forced-head outputs.
for tag, pointer_path in RUN_POINTERS.items():
    tag_root, run_root = resolve_run_root_from_pointer(pointer_path)
    RUNS_RESOLVED[tag] = {
        "pointer_json": pointer_path,
        "tag_root": tag_root,
        "run_root": run_root,
        "out_dir": str(Path(MECH_EVAL_ROOT) / tag),
        "predictions_by_seed": build_predictions_paths(run_root, SEEDS),
        "trainval_exp_dir": TRVAL_EXP_DIRS[tag],
    }

print("\nRUNS_RESOLVED rebuilt (this runtime):")
for tag in ["base", "trainEnh1", "trainEnh2", "trainEnh3"]:
    print(f"  [{tag}] run_root: {RUNS_RESOLVED[tag]['run_root']}")

# -------------------------
# Inference settings (kept consistent with prior runs)
# -------------------------
# Controls model backbone, sampling rate, batching, and AMP behavior.
BACKBONE_CKPT = "facebook/wav2vec2-base"  # If another backbone was used elsewhere, update here once.
DROPOUT_P     = 0.2
SR_EXPECTED   = 16000
TINY_THRESH   = 1e-4

PER_DEVICE_BS = 16
NUM_WORKERS   = 0
PIN_MEMORY    = False
USE_AMP       = True

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("DEVICE:", DEVICE, "| USE_AMP:", bool(USE_AMP and DEVICE.type == "cuda"))

# -------------------------
# D2 test set loader (labels, sex, and task grouping)
# -------------------------
# Loads the manifest, filters to split=test, and prepares a DataLoader that yields:
# - waveform samples
# - attention mask
# - labels and metadata needed for slicing metrics
D2_MANIFEST_ALL = f"{D2_OUT_ROOT}/manifests/manifest_all.csv"
if not os.path.isfile(D2_MANIFEST_ALL):
    raise FileNotFoundError(f"Missing D2 manifest_all.csv: {D2_MANIFEST_ALL}")

m_all = pd.read_csv(D2_MANIFEST_ALL)
req_cols = {"split", "clip_path", "label_num", "task", "sex"}
missing = [c for c in sorted(req_cols) if c not in m_all.columns]
if missing:
    raise ValueError(f"D2 manifest missing required columns: {missing}")

m_test = m_all[m_all["split"].astype(str) == "test"].copy()
if len(m_test) == 0:
    raise ValueError("No rows found where split == 'test' in D2 manifest_all.csv")

# Normalizes D2 sex values into M/F/UNK.
def _sex_norm_from_d2(sex_val: Any) -> str:
    s = "" if sex_val is None else str(sex_val).strip().lower()
    if s == "male":
        return "M"
    if s == "female":
        return "F"
    return "UNK"

# Converts D2 task into the two groups used by the two-head model.
def _task_group_from_task(task_val: Any) -> str:
    t = "" if task_val is None else str(task_val).strip().lower()
    return "vowel" if t == "vowl" else "other"

m_test["sex_norm"] = m_test["sex"].apply(_sex_norm_from_d2)
m_test["task_group"] = m_test["task"].apply(_task_group_from_task)
m_test["y_true"] = m_test["label_num"].astype(int)
if "speaker_id" not in m_test.columns:
    m_test["speaker_id"] = ""

print("D2 TEST rows:", len(m_test))
print("Task group counts:", m_test["task_group"].value_counts().to_dict())

class D2TestDataset(Dataset):
    # Loads audio and constructs an attention mask aligned with the padding rule.
    def __init__(self, df: pd.DataFrame):
        self.df = df.reset_index(drop=True)

    def __len__(self) -> int:
        return int(len(self.df))

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        r = self.df.iloc[int(idx)]
        clip_path = str(r["clip_path"])
        if not os.path.exists(clip_path):
            raise FileNotFoundError(f"Missing clip_path: {clip_path}")

        y, sr = sf.read(clip_path, always_2d=False)
        if isinstance(y, np.ndarray) and y.ndim > 1:
            y = np.mean(y, axis=1)
        y = np.asarray(y, dtype=np.float32)

        if int(sr) != int(SR_EXPECTED):
            raise ValueError(f"Sample rate mismatch (expected {SR_EXPECTED}): {clip_path} has sr={sr}")

        task_group = str(r["task_group"])

        # Attention mask behavior:
        # - vowel: mask trailing near-zero padding so it does not affect pooling
        # - other: keep all samples active
        attn = np.ones((len(y),), dtype=np.int64)
        if task_group == "vowel":
            # Vectorized search for the last sample whose magnitude exceeds TINY_THRESH.
            active = np.flatnonzero(np.abs(y.astype(np.float64)) > float(TINY_THRESH))
            if active.size > 0:
                # Zero the mask after the last active sample.
                attn[int(active[-1]) + 1:] = 0
            # If no sample exceeds the threshold, the mask stays all ones.

        sid = r.get("speaker_id", "")
        sid = "" if (sid is None or (isinstance(sid, float) and np.isnan(sid))) else str(sid)

        return {
            "input_values": torch.from_numpy(y),
            "attention_mask": torch.from_numpy(attn),
            "labels": torch.tensor(int(r["y_true"]), dtype=torch.long),
            "task_group": task_group,
            "sex_norm": str(r["sex_norm"]),
            "clip_path": clip_path,
            "speaker_id": sid,
        }

def collate_fn(batch: List[Dict[str, Any]]) -> Dict[str, Any]:
    # Pads variable-length audio to a batch max length while padding the attention mask to match.
    max_len = int(max(b["input_values"].numel() for b in batch))
    xs, ams, ys = [], [], []
    tgs, sexs, cps, sids = [], [], [], []

    for b in batch:
        x = b["input_values"]
        a = b["attention_mask"]
        pad = max_len - x.numel()
        if pad > 0:
            x = torch.cat([x, torch.zeros(pad, dtype=x.dtype)], dim=0)
            a = torch.cat([a, torch.zeros(pad, dtype=a.dtype)], dim=0)
        xs.append(x)
        ams.append(a)
        ys.append(b["labels"])
        tgs.append(b["task_group"])
        sexs.append(b["sex_norm"])
        cps.append(b["clip_path"])
        sids.append(b["speaker_id"])

    return {
        "input_values": torch.stack(xs, dim=0),
        "attention_mask": torch.stack(ams, dim=0),
        "labels": torch.stack(ys, dim=0),
        "task_group": tgs,
        "sex_norm": sexs,
        "clip_path": cps,
        "speaker_id": sids,
    }

test_loader = DataLoader(
    D2TestDataset(m_test),
    batch_size=int(PER_DEVICE_BS),
    shuffle=False,
    num_workers=int(NUM_WORKERS),
    pin_memory=bool(PIN_MEMORY),
    collate_fn=collate_fn,
)

# -------------------------
# Two-head model definition + head loading
# -------------------------
# Backbone is frozen; only the two heads are loaded from best_heads.pt.
from transformers import Wav2Vec2Model

class Wav2Vec2TwoHeadClassifier(nn.Module):
    def __init__(self, ckpt: str, dropout_p: float):
        super().__init__()
        self.backbone = Wav2Vec2Model.from_pretrained(str(ckpt), use_safetensors=True)
        for p in self.backbone.parameters():
            p.requires_grad = False

        H = int(self.backbone.config.hidden_size)
        self.pre_vowel = nn.Sequential(nn.LayerNorm(H), nn.Dropout(float(dropout_p)))
        self.pre_other = nn.Sequential(nn.LayerNorm(H), nn.Dropout(float(dropout_p)))
        self.head_vowel = nn.Linear(H, 2)
        self.head_other = nn.Linear(H, 2)

    def masked_mean_pool(self, last_hidden: torch.Tensor, attn_mask_samples: torch.Tensor) -> torch.Tensor:
        feat_mask = self.backbone._get_feature_vector_attention_mask(last_hidden.shape[1], attn_mask_samples)
        feat_mask = feat_mask.to(last_hidden.device).unsqueeze(-1).type_as(last_hidden)
        masked = last_hidden * feat_mask
        denom = feat_mask.sum(dim=1).clamp(min=1.0)
        return masked.sum(dim=1) / denom

    def _heads_fp32(self, x_any: torch.Tensor, head: nn.Module) -> torch.Tensor:
        # Ensures head matmul runs in FP32 even when AMP is enabled.
        x = x_any.float()
        if x.is_cuda:
            with torch.amp.autocast(device_type="cuda", enabled=False):
                return head(x)
        return head(x)

    def forward_both_logits(self, input_values: torch.Tensor, attention_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        # Computes both heads for every clip, then returns both logits for forced selection.
        with torch.no_grad():
            out = self.backbone(input_values=input_values, attention_mask=attention_mask)
            last_hidden = out.last_hidden_state
        pooled = self.masked_mean_pool(last_hidden, attention_mask)
        z_v = self.pre_vowel(pooled.float())
        z_o = self.pre_other(pooled.float())
        logits_v = self._heads_fp32(z_v, self.head_vowel)
        logits_o = self._heads_fp32(z_o, self.head_other)
        return logits_v, logits_o

def load_heads_into_model(model: nn.Module, best_heads_path: str) -> nn.Module:
    # Loads saved head weights; supports both wrapped and raw state_dict formats.
    obj = torch.load(best_heads_path, map_location="cpu")
    if isinstance(obj, dict) and "state_dict" in obj and isinstance(obj["state_dict"], dict):
        sd = obj["state_dict"]
    elif isinstance(obj, dict):
        sd = obj
    else:
        raise RuntimeError(f"Unsupported best_heads.pt format: {best_heads_path}")
    model.load_state_dict(sd, strict=False)
    return model

# -------------------------
# Threshold used for scoring (read from normal predictions.csv)
# -------------------------
# Uses the single global threshold saved during the normal test run for that seed.
def read_thr_used_global(pred_csv_path: str) -> float:
    df = pd.read_csv(pred_csv_path)
    if "threshold_used_global" not in df.columns:
        raise KeyError(f"threshold_used_global column missing in: {pred_csv_path}")
    vals = pd.to_numeric(df["threshold_used_global"], errors="coerce").dropna().unique()
    if len(vals) == 0:
        raise ValueError(f"No usable threshold_used_global values in: {pred_csv_path}")
    if len(vals) > 1:
        raise RuntimeError(f"threshold_used_global not constant in: {pred_csv_path}. Unique: {vals[:10]}")
    return float(vals[0])

# -------------------------
# Forced-head inference and CSV writer
# -------------------------
# Runs inference once per forced mode and writes a predictions CSV with metadata for analysis.
def infer_forced(loader: DataLoader, model: Wav2Vec2TwoHeadClassifier, forced_head: str, desc: str):
    if forced_head not in ["vowel", "other"]:
        raise ValueError("forced_head must be 'vowel' or 'other'")

    amp_ok = bool(USE_AMP and DEVICE.type == "cuda")
    y_true_all, y_score_all = [], []
    sex_all, tg_all, clip_all, spk_all = [], [], [], []

    model.eval()
    with torch.no_grad():
        for batch in tqdm(loader, desc=desc):
            x = batch["input_values"].to(DEVICE, non_blocking=True)
            a = batch["attention_mask"].to(DEVICE, non_blocking=True)
            y = batch["labels"].detach().cpu().numpy().astype(np.int64)

            if amp_ok:
                with torch.amp.autocast(device_type="cuda", enabled=True):
                    lv, lo = model.forward_both_logits(x, a)
            else:
                lv, lo = model.forward_both_logits(x, a)

            logits = lv if forced_head == "vowel" else lo
            p = torch.softmax(logits.float(), dim=-1)[:, 1].detach().cpu().numpy().astype(np.float64)

            y_true_all.extend(y.tolist())
            y_score_all.extend(p.tolist())
            sex_all.extend(list(batch["sex_norm"]))
            tg_all.extend(list(batch["task_group"]))
            clip_all.extend(list(batch["clip_path"]))
            spk_all.extend(list(batch["speaker_id"]))

    return (
        np.asarray(clip_all, dtype=object),
        np.asarray(y_true_all, dtype=np.int64),
        np.asarray(y_score_all, dtype=np.float64),
        np.asarray(sex_all, dtype=object),
        np.asarray(spk_all, dtype=object),
        np.asarray(tg_all, dtype=object),
    )

def write_force_csv(out_csv: str, clip_path, y_true, y_score, sex_norm, speaker_id, task_group, seed, forced_head, thr, trainval_exp_tag, run_stamp):
    # Saves forced-head predictions with enough fields to reproduce slicing and scoring later.
    df = pd.DataFrame({
        "clip_path": clip_path,
        "y_true": y_true,
        "y_score": y_score,
        "sex_norm": sex_norm,
        "speaker_id": speaker_id,
        "task_group": task_group,
        "seed": int(seed),
        "forced_head": str(forced_head),
        "threshold_used_global": float(thr),
        "trainval_exp_tag": str(trainval_exp_tag),
        "run_stamp": str(run_stamp),
    })
    Path(out_csv).parent.mkdir(parents=True, exist_ok=True)
    df.to_csv(out_csv, index=False)

# -------------------------
# Metrics helpers (AUROC and ΔFNR)
# -------------------------
# Computes AUROC per task group, plus ΔFNR across sex within the PD subset.
def compute_auc(y_true: np.ndarray, y_prob: np.ndarray) -> float:
    # AUROC is undefined if only one class is present; return NaN to keep the pipeline running.
    if len(np.unique(y_true)) < 2:
        return float("nan")
    return float(roc_auc_score(y_true, y_prob))

def fairness_delta_fnr_pd_only(y_true: np.ndarray, y_prob: np.ndarray, sex_norm: np.ndarray, thr: float) -> Dict[str, Any]:
    """
    Computes false negative rate (FNR) within PD-only clips for M and F, then:
      ΔFNR = FNR(F) − FNR(M)
    """
    y_true = np.asarray(y_true, dtype=np.int64)
    y_prob = np.asarray(y_prob, dtype=np.float64)
    sex_norm = np.asarray(sex_norm, dtype=object)

    pd_mask = (y_true == 1)
    out = {"fnr_M": float("nan"), "fnr_F": float("nan"), "delta_f_minus_m": float("nan"),
           "n_pd_M": 0, "n_pd_F": 0}

    for g in ["M", "F"]:
        gm = pd_mask & (sex_norm == g)
        n_pd = int(gm.sum())
        out[f"n_pd_{g}"] = n_pd
        if n_pd == 0:
            continue
        pred = (y_prob[gm] >= float(thr)).astype(np.int64)
        fn = int((pred == 0).sum())
        out[f"fnr_{g}"] = float(fn / n_pd)

    if not np.isnan(out["fnr_M"]) and not np.isnan(out["fnr_F"]):
        out["delta_f_minus_m"] = float(out["fnr_F"] - out["fnr_M"])
    return out

def score_df_by_taskgroup(df: pd.DataFrame, thr: float) -> Dict[str, Any]:
    # Returns AUROC and ΔFNR metrics separately for vowel clips and other clips.
    out = {}
    for tg in ["vowel", "other"]:
        sub = df[df["task_group"].astype(str) == tg].copy()
        yt = sub["y_true"].to_numpy(dtype=np.int64)
        yp = sub["y_score"].to_numpy(dtype=np.float64)
        sx = sub["sex_norm"].to_numpy(dtype=object)
        out[tg] = {
            "n": int(len(sub)),
            "auc": compute_auc(yt, yp),
            "delta_fnr": fairness_delta_fnr_pd_only(yt, yp, sx, thr),
        }
    return out

def mean_sd(arr: np.ndarray) -> Tuple[float, float]:
    # Convenience for plotting mean ± SD across the three seeds.
    arr = np.asarray(arr, dtype=np.float64)
    return float(np.nanmean(arr)), float(np.nanstd(arr, ddof=0))

# -------------------------
# Forced-head inference for all conditions and seeds
# -------------------------
# For each condition and seed:
# - read the seed's saved threshold from normal predictions.csv
# - load best_heads.pt
# - write predictions_force_vowel.csv and predictions_force_other.csv
RUN_STAMP = time.strftime("%Y%m%d_%H%M%S")
PLOTS_ALL_DIR = str(Path(MECH_EVAL_ROOT) / "plots_all_conditions")
Path(PLOTS_ALL_DIR).mkdir(parents=True, exist_ok=True)

forced_index = {}
t0 = time.time()

for tag in ["base", "trainEnh1", "trainEnh2", "trainEnh3"]:
    forced_index[tag] = {"seeds": SEEDS, "files_by_seed": {}}

    exp_dir = Path(TRVAL_EXP_DIRS[tag])
    trainval_exp_tag = exp_dir.name

    for seed in SEEDS:
        normal_pred = RUNS_RESOLVED[tag]["predictions_by_seed"][seed]
        thr = read_thr_used_global(normal_pred)

        best_heads = exp_dir / f"run_D7_seed{seed}" / "best_heads.pt"
        if not best_heads.exists():
            raise FileNotFoundError(f"[{tag}] best_heads.pt missing for seed{seed}: {str(best_heads)}")

        out_run_dir = Path(MECH_EVAL_ROOT) / tag / "forced_head_runs" / f"run_D7_on_D2test_seed{seed}"
        out_force_v = out_run_dir / "predictions_force_vowel.csv"
        out_force_o = out_run_dir / "predictions_force_other.csv"

        # Build model and load the trained heads for this seed.
        model = Wav2Vec2TwoHeadClassifier(BACKBONE_CKPT, dropout_p=DROPOUT_P).to(DEVICE)
        model = load_heads_into_model(model, str(best_heads))

        # Forced vowel head for all clips.
        cp, yt, yp, sx, sid, tg = infer_forced(test_loader, model, "vowel", f"[{tag} seed{seed}] FORCE_VOWEL_HEAD")
        write_force_csv(str(out_force_v), cp, yt, yp, sx, sid, tg, seed, "vowel", thr, trainval_exp_tag, RUN_STAMP)

        # Forced other head for all clips.
        cp2, yt2, yp2, sx2, sid2, tg2 = infer_forced(test_loader, model, "other", f"[{tag} seed{seed}] FORCE_OTHER_HEAD")
        write_force_csv(str(out_force_o), cp2, yt2, yp2, sx2, sid2, tg2, seed, "other", thr, trainval_exp_tag, RUN_STAMP)

        forced_index[tag]["files_by_seed"][str(seed)] = {
            "threshold_used_global": float(thr),
            "predictions_normal": str(normal_pred),
            "predictions_force_vowel": str(out_force_v),
            "predictions_force_other": str(out_force_o),
        }

# Registry file listing all produced outputs (paths and thresholds).
registry_path = Path(MECH_EVAL_ROOT) / "forced_head_registry.json"
with open(registry_path, "w", encoding="utf-8") as f:
    json.dump(forced_index, f, indent=2)
print("\nWROTE:", str(registry_path))
print("Forced inference total seconds:", time.time() - t0)

# -------------------------
# Score all modes and write a long metrics table
# -------------------------
# Produces one row per: condition × seed × routing_mode × task_group.
rows = []
for tag in ["base", "trainEnh1", "trainEnh2", "trainEnh3"]:
    for seed in SEEDS:
        info = forced_index[tag]["files_by_seed"][str(seed)]
        thr = float(info["threshold_used_global"])

        df_n = pd.read_csv(info["predictions_normal"])
        df_v = pd.read_csv(info["predictions_force_vowel"])
        df_o = pd.read_csv(info["predictions_force_other"])

        # Verify the needed columns exist for scoring and slicing.
        for name, df in [("normal", df_n), ("force_vowel", df_v), ("force_other", df_o)]:
            need = ["y_true", "y_score", "sex_norm", "task_group"]
            miss = [c for c in need if c not in df.columns]
            if miss:
                raise KeyError(f"[{tag} seed{seed} {name}] missing required cols {miss}")

        for mode, df in [("normal", df_n), ("force_vowel", df_v), ("force_other", df_o)]:
            scored = score_df_by_taskgroup(df, thr)
            for tg in ["vowel", "other"]:
                rows.append({
                    "condition": tag,
                    "seed": int(seed),
                    "routing_mode": mode,
                    "task_group": tg,
                    "auc": float(scored[tg]["auc"]),
                    "delta_fnr_f_minus_m": float(scored[tg]["delta_fnr"]["delta_f_minus_m"]),
                    "n_pd_m": int(scored[tg]["delta_fnr"]["n_pd_M"]),
                    "n_pd_f": int(scored[tg]["delta_fnr"]["n_pd_F"]),
                    "threshold_used_global": float(thr),
                })

metrics_long = pd.DataFrame(rows)
metrics_csv = Path(MECH_EVAL_ROOT) / "forced_head_metrics_long.csv"
metrics_long.to_csv(metrics_csv, index=False)
print("WROTE:", str(metrics_csv))

# -------------------------
# Per-condition JSON summaries
# -------------------------
# Writes a compact per-seed summary for each condition so results stay localized.
for tag in ["base", "trainEnh1", "trainEnh2", "trainEnh3"]:
    out_json = Path(MECH_EVAL_ROOT) / tag / "forced_head_eval_summary.json"
    pack = {"tag": tag, "run_stamp": RUN_STAMP, "seeds": SEEDS, "by_seed": {}}

    for seed in SEEDS:
        info = forced_index[tag]["files_by_seed"][str(seed)]
        thr = float(info["threshold_used_global"])

        df_n = pd.read_csv(info["predictions_normal"])
        df_v = pd.read_csv(info["predictions_force_vowel"])
        df_o = pd.read_csv(info["predictions_force_other"])

        pack["by_seed"][str(seed)] = {
            "threshold_used_global": thr,
            "normal": score_df_by_taskgroup(df_n, thr),
            "force_vowel": score_df_by_taskgroup(df_v, thr),
            "force_other": score_df_by_taskgroup(df_o, thr),
            "files": info,
        }

    with open(out_json, "w", encoding="utf-8") as f:
        json.dump(pack, f, indent=2)
    print("WROTE:", str(out_json))

# -------------------------
# Plot helpers (mean ± SD across seeds)
# -------------------------
# Aggregates across seeds so each bar or point represents mean performance with variability.
CONDS = ["base", "trainEnh1", "trainEnh2", "trainEnh3"]
MODES = ["normal", "force_vowel", "force_other"]

def summarize_metric(task_group: str, metric_col: str) -> Dict[Tuple[str, str], Tuple[float, float]]:
    """
    Returns {(condition, routing_mode): (mean, sd)} over seeds.
    metric_col: "auc" or "delta_fnr_f_minus_m"
    """
    out = {}
    for cond in CONDS:
        for mode in MODES:
            sub = metrics_long[
                (metrics_long["condition"] == cond) &
                (metrics_long["routing_mode"] == mode) &
                (metrics_long["task_group"] == task_group)
            ]
            vals = sub[metric_col].to_numpy(dtype=np.float64)
            out[(cond, mode)] = mean_sd(vals)
    return out

def plot_grouped_bars(task_group: str, metric_col: str, ylabel: str, title: str, out_png: str):
    # Creates one grouped bar chart per task group with three routing modes.
    summ = summarize_metric(task_group, metric_col)
    x = np.arange(len(CONDS))
    bar_w = 0.24
    offsets = {"normal": -bar_w, "force_vowel": 0.0, "force_other": bar_w}

    plt.figure(figsize=(12, 4.8))
    for mode in MODES:
        means = [summ[(c, mode)][0] for c in CONDS]
        sds   = [summ[(c, mode)][1] for c in CONDS]
        plt.bar(x + offsets[mode], means, width=bar_w, yerr=sds, capsize=3, label=mode)

    plt.xticks(x, CONDS)
    plt.xlabel("condition")
    plt.ylabel(ylabel)
    plt.title(title)
    plt.legend()
    plt.tight_layout()
    plt.savefig(out_png, dpi=200)
    plt.close()
    print("WROTE:", out_png)

# -------------------------
# Required plots (5 total)
# -------------------------
# (1) AUROC by condition and routing mode — vowel-only
plot_grouped_bars(
    task_group="vowel",
    metric_col="auc",
    ylabel="AUROC (mean ± SD across seeds)",
    title="AUROC by condition and routing mode — vowel-only",
    out_png=str(Path(PLOTS_ALL_DIR) / "auroc_by_condition_vowel_only.png"),
)

# (2) AUROC by condition and routing mode — other-only
plot_grouped_bars(
    task_group="other",
    metric_col="auc",
    ylabel="AUROC (mean ± SD across seeds)",
    title="AUROC by condition and routing mode — other-only",
    out_png=str(Path(PLOTS_ALL_DIR) / "auroc_by_condition_other_only.png"),
)

# (3) ΔFNR by condition and routing mode — vowel-only
plot_grouped_bars(
    task_group="vowel",
    metric_col="delta_fnr_f_minus_m",
    ylabel="ΔFNR = FNR(F) − FNR(M) (mean ± SD across seeds)",
    title="ΔFNR by condition and routing mode — vowel-only",
    out_png=str(Path(PLOTS_ALL_DIR) / "deltaFNR_by_condition_vowel_only.png"),
)

# (4) ΔFNR by condition and routing mode — other-only
plot_grouped_bars(
    task_group="other",
    metric_col="delta_fnr_f_minus_m",
    ylabel="ΔFNR = FNR(F) − FNR(M) (mean ± SD across seeds)",
    title="ΔFNR by condition and routing mode — other-only",
    out_png=str(Path(PLOTS_ALL_DIR) / "deltaFNR_by_condition_other_only.png"),
)

# (5) Head sensitivity: AUROC(force vowel) − AUROC(force other), shown for vowel clips and other clips
sens_rows = []
for cond in CONDS:
    for seed in SEEDS:
        for tg in ["vowel", "other"]:
            a_v = metrics_long[(metrics_long["condition"] == cond) & (metrics_long["seed"] == seed) &
                               (metrics_long["routing_mode"] == "force_vowel") & (metrics_long["task_group"] == tg)]["auc"].iloc[0]
            a_o = metrics_long[(metrics_long["condition"] == cond) & (metrics_long["seed"] == seed) &
                               (metrics_long["routing_mode"] == "force_other") & (metrics_long["task_group"] == tg)]["auc"].iloc[0]
            sens_rows.append({"condition": cond, "seed": seed, "task_group": tg, "sens": float(a_v - a_o)})

sens_df = pd.DataFrame(sens_rows)

sens_png = str(Path(PLOTS_ALL_DIR) / "head_sensitivity_auc_forceV_minus_forceO.png")
plt.figure(figsize=(12, 4.8))
x = np.arange(len(CONDS))
for tg in ["vowel", "other"]:
    sub = sens_df[sens_df["task_group"] == tg]
    means = []
    sds = []
    for cond in CONDS:
        vals = sub[sub["condition"] == cond]["sens"].to_numpy(dtype=np.float64)
        m, sd = mean_sd(vals)
        means.append(m)
        sds.append(sd)
    plt.errorbar(x, means, yerr=sds, fmt="o-", capsize=3, label=tg)
plt.axhline(0.0, linewidth=1)
plt.xticks(x, CONDS)
plt.xlabel("condition")
plt.ylabel("AUROC(force vowel) − AUROC(force other) (mean ± SD)")
plt.title("Head sensitivity by task group")
plt.legend()
plt.tight_layout()
plt.savefig(sens_png, dpi=200)
plt.close()
print("WROTE:", sens_png)

# -------------------------
# Runtime unload (same pattern as other test-only cells)
# -------------------------
try:
    from google.colab import runtime  # type: ignore
    print("\nUnassigning runtime (Colab)...")
    runtime.unassign()
except Exception as e:
    print("\nNOTE: Could not unassign runtime. Reason:", repr(e))

# Interpretability Study (XAI): Prediction Score Distribution Analysis

The following cell creates simple **explainability-style score histograms** using prediction results that were already generated earlier. It does not run any new model inference. The cell reads the saved `predictions.csv` files for the Base D7 model and for three enhanced versions of D7 that were trained with 10 percent of D2 data using different speaker selections (**trainEnh1**, **trainEnh2**, and **trainEnh3**). For each model, predictions from the three random seeds are combined at the clip level by averaging the predicted Parkinson’s probability for the same audio file. This produces one pooled score per clip, which makes it easier to see how clearly the model separates Parkinson’s speech from healthy speech.
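
The seed pooling described above can be sketched on a toy example. This is a minimal plain-Python illustration, not the notebook's pandas implementation; the clip names, labels, and scores are made up, and `pool_by_clip` is a hypothetical helper introduced only for this sketch.

```python
from collections import defaultdict
from statistics import fmean

# Hypothetical per-seed rows: (seed, clip_path, y_true, y_score). Values are illustrative only.
rows = [
    (1337, "a.wav", 1, 0.9), (1337, "b.wav", 0, 0.2),
    (2024, "a.wav", 1, 0.7), (2024, "b.wav", 0, 0.1),
    (7777, "a.wav", 1, 0.8), (7777, "b.wav", 0, 0.3),
]

def pool_by_clip(rows):
    """Average y_score per clip across seeds; reject clips whose label differs between seeds."""
    scores = defaultdict(dict)  # clip_path -> {seed: y_score}
    labels = {}                 # clip_path -> y_true
    for seed, clip, y_true, y_score in rows:
        if labels.setdefault(clip, y_true) != y_true:
            raise ValueError(f"y_true mismatch across seeds for {clip}")
        scores[clip][seed] = y_score
    return {
        clip: {
            "y_true": labels[clip],
            "y_score_mean": fmean(by_seed.values()),
            "n_seeds": len(by_seed),
        }
        for clip, by_seed in scores.items()
    }

pooled = pool_by_clip(rows)
for clip, rec in pooled.items():
    print(clip, rec)
```

Keeping `n_seeds` next to the pooled score makes it easy to spot clips that were averaged over fewer seeds than expected.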

Using these pooled scores, the cell generates histogram plots showing how prediction scores are distributed. Two types of plots are created for each model. The first compares score distributions for Parkinson’s versus healthy clips overall. The second breaks the same comparison down by sex, with separate plots for male and female speakers. Clips with unknown or non-standard sex labels are excluded from the sex-specific plots so the comparisons remain clean and easy to interpret. Along with the plots, the pooled per-clip prediction values are saved as CSV files so the underlying data can be reviewed or reused later.
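
As a rough sketch of the slicing described above, the snippet below bins pooled scores into 50 equal-width bins over [0, 1] and counts them per label and per sex. It assumes a fixed [0, 1] range and uses made-up clips; `bin_index` and `slice_counts` are hypothetical helpers, not functions from the notebook.

```python
from collections import Counter

BINS = 50  # 50 bins over [0, 1] gives a bin width of 0.02

# Hypothetical pooled clips: (y_true, sex_norm, y_score_mean). Values are illustrative only.
clips = [
    (1, "M", 0.91), (1, "F", 0.65), (1, "UNK", 0.88),
    (0, "M", 0.07), (0, "F", 0.31), (0, "F", 1.00),
]

def bin_index(score, bins=BINS):
    """Map a probability in [0, 1] to a bin index; a score of 1.0 falls in the last bin."""
    return min(int(score * bins), bins - 1)

def slice_counts(clips, label, sex=None):
    """Histogram counts for one label (1 = PD, 0 = HC), optionally restricted to sex M or F."""
    picked = [s for y, sx, s in clips if y == label and (sex is None or sx == sex)]
    return Counter(bin_index(s) for s in picked)

pd_overall = slice_counts(clips, label=1)         # overall slice keeps the UNK clip
pd_male = slice_counts(clips, label=1, sex="M")   # UNK never matches "M" or "F"
hc_female = slice_counts(clips, label=0, sex="F")
print(pd_overall, pd_male, hc_female)
```

Because the sex-specific slices match only `"M"` or `"F"`, clips with unknown sex drop out of those plots without a separate filtering step.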

To load files reliably across different run folder layouts, the cell first searches for `predictions.csv` files anywhere under the expected run directory. If none are found there, it follows pointer files (`tag_run_pointer.json`, then `last_run_pointer.json`) to locate the run folder that contains the seed-specific prediction outputs. Each predictions file is checked for the required columns, and labels must match across seeds for the same clip. Clips that are missing results for any seed are reported with a warning rather than silently ignored.
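
The pointer-file lookup can be illustrated with a throwaway directory layout. This is a simplified sketch that handles only `tag_run_pointer.json` and its `run_root` key; the folder names are hypothetical, and `resolve_run_root` is a stand-in for illustration rather than the notebook's helper.

```python
import json
import tempfile
from pathlib import Path

def resolve_run_root(tag_root: Path) -> Path:
    """Return tag_root if it holds predictions.csv, otherwise follow tag_run_pointer.json."""
    if any(tag_root.glob("**/predictions.csv")):
        return tag_root
    pointer = tag_root / "tag_run_pointer.json"
    if pointer.exists():
        run_root = json.loads(pointer.read_text(encoding="utf-8")).get("run_root")
        if run_root:
            return Path(run_root)
    raise FileNotFoundError(f"No predictions.csv or usable pointer under {tag_root}")

# Throwaway layout: predictions live in a separate run folder, reachable only via the pointer.
tmp = Path(tempfile.mkdtemp())
tag_root = tmp / "tag"
run_root = tmp / "run_20250101"
(run_root / "seed1337").mkdir(parents=True)
(run_root / "seed1337" / "predictions.csv").write_text("clip_path,y_true,y_score\n", encoding="utf-8")
tag_root.mkdir()
(tag_root / "tag_run_pointer.json").write_text(json.dumps({"run_root": str(run_root)}), encoding="utf-8")

resolved = resolve_run_root(tag_root)
print(resolved)
```

Checking the folder itself before reading any pointer means a layout that already contains predictions is never redirected elsewhere.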

All histogram plots are written to an `xai_histograms/` subfolder inside each model’s run directory. Plots for the Base D7 model are clearly labeled as baseline results without any target data adaptation. Plots for the enhanced models are labeled to show that the model was trained with 10 percent of D2 data, with each enhanced draw shown separately. This layout keeps comparisons clear while maintaining a clean and traceable record of how score distributions change from the base model to each enhanced version.

In [None]:
# =========================
# XAI HISTOGRAMS (NO INFERENCE) — Base vs D7 + 10% D2 (3 draws)
# - Reads existing predictions.csv (no rerun)
# - Pools 3 seeds by ENSEMBLE MEAN per clip_path
# - Plots simple score histograms:
#     (A) Overall: PD vs HC
#     (B) By sex: PD M/F and HC M/F
# - Saves SEPARATE PNGs per histogram under:
#     <MODEL_TAG_ROOT>/xai_histograms/
#
# Note on file locations:
# Enhanced runs store predictions.csv under RUN_ROOT (tag+stamp), not TAG_ROOT.
# So we:
#   1) Try to find predictions.csv under TAG_ROOT
#   2) If none found, read TAG_ROOT/tag_run_pointer.json -> run_root, then search there
#
# Note on labeling:
# - "Base D7" does NOT include any 10% D2 training.
# - Therefore, we DO NOT label anything in the Base folder as "D7 + 10% D2".
#   In the Base folder we write BASE-only histograms.
#   In each Enhanced folder, we write Base vs that Enhanced draw.
# =========================

import json
from pathlib import Path

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# -------------------------
# 0) Inputs
# -------------------------
D7_OUT_ROOT = "/content/drive/MyDrive/AI_PD_Project/Datasets/D7-Multilingual (D1_D4_D5v2_D6)/preprocessed_v1"

BASE_TAG_ROOT = "/content/drive/MyDrive/AI_PD_Project/Datasets/D7-Multilingual (D1_D4_D5v2_D6)/preprocessed_v1/multilingual_test_runs/run_exp_frozen_LNDO_base_initBaseline_20251229_060853"

ENH_TAG_ROOTS = {
    "trainEnh1": "/content/drive/MyDrive/AI_PD_Project/Datasets/D7-Multilingual (D1_D4_D5v2_D6)/preprocessed_v1/monolingual_test_runs/run_exp_frozen_LNDO_trainEnh1_initBaseline_20251227_224654",
    "trainEnh2": "/content/drive/MyDrive/AI_PD_Project/Datasets/D7-Multilingual (D1_D4_D5v2_D6)/preprocessed_v1/monolingual_test_runs/run_exp_frozen_LNDO_trainEnh2_initBaseline_20251227_184205",
    "trainEnh3": "/content/drive/MyDrive/AI_PD_Project/Datasets/D7-Multilingual (D1_D4_D5v2_D6)/preprocessed_v1/monolingual_test_runs/run_exp_frozen_LNDO_trainEnh3_initBaseline_20251227_203719",
}

SEEDS = [1337, 2024, 7777]

# Recommended, defensible value: 50 bins for probabilities in [0,1] (bin width = 0.02)
BINS = 50

XAI_DIRNAME = "xai_histograms"

# -------------------------
# 1) Helpers
# -------------------------
def _read_json(p: Path) -> dict:
    with open(p, "r", encoding="utf-8") as f:
        return json.load(f)

def _resolve_run_root_from_tag(tag_root: Path) -> Path:
    """
    Returns a path where predictions.csv files actually live.
    - First try tag_root itself (Base typically stores preds in subfolders under tag_root)
    - If not found, try tag_root/tag_run_pointer.json -> run_root (Enhanced)
    - If still not found, try TEST_ROOT/last_run_pointer.json as fallback
    """
    tag_root = Path(tag_root)

    # Quick success: predictions exist under tag_root
    if list(tag_root.glob("**/predictions.csv")):
        return tag_root

    # Enhanced normal case: tag pointer exists under TAG_ROOT
    tag_ptr = tag_root / "tag_run_pointer.json"
    if tag_ptr.exists():
        obj = _read_json(tag_ptr)
        # Enhanced code may store either run_root or stamp_run_dir
        run_root = obj.get("run_root", None) or obj.get("stamp_run_dir", None)
        if run_root:
            return Path(run_root)

    # Fallback: last_run_pointer.json in the parent TEST_ROOT
    test_root = tag_root.parent
    last_ptr = test_root / "last_run_pointer.json"
    if last_ptr.exists():
        obj = _read_json(last_ptr)
        run_root = obj.get("run_root", None) or obj.get("stamp_run_dir", None)
        if run_root:
            return Path(run_root)

    raise FileNotFoundError(
        "Could not locate any predictions.csv under TAG_ROOT, and could not resolve a run_root via pointer files.\n"
        f"TAG_ROOT checked: {str(tag_root)}\n"
        f"Expected pointer file: {str(tag_root / 'tag_run_pointer.json')}\n"
        f"Fallback pointer file: {str(tag_root.parent / 'last_run_pointer.json')}"
    )

def _find_predictions_csvs(run_root: Path, seeds) -> list[tuple[int, Path]]:
    """
    Find predictions.csv for each seed under run_root, using either:
    - folder/file path containing 'seed{seed}', OR
    - reading the 'seed' column inside predictions.csv
    """
    run_root = Path(run_root)
    all_preds = sorted(run_root.glob("**/predictions.csv"))
    if not all_preds:
        raise FileNotFoundError(f"No predictions.csv found under resolved RUN_ROOT: {str(run_root)}")

    chosen = []

    # First pass: pick files by seed hint in path
    for s in seeds:
        matches = [p for p in all_preds if f"seed{s}" in str(p)]
        if len(matches) == 1:
            chosen.append((s, matches[0]))
        elif len(matches) > 1:
            # pick the deepest (typically per-seed run folder) to avoid accidental duplicates
            matches = sorted(matches, key=lambda p: len(p.parts), reverse=True)
            chosen.append((s, matches[0]))

    # If we got all seeds, done
    if len(chosen) == len(seeds):
        return chosen

    # Second pass: load each predictions.csv and filter by seed column if present
    chosen = []
    for s in seeds:
        found = None
        for p in all_preds:
            try:
                df = pd.read_csv(p)
            except Exception:
                continue
            if "seed" in df.columns and (df["seed"] == s).any():
                found = p
                break
        if found is None:
            raise FileNotFoundError(
                f"Could not find predictions.csv for seed={s} under {str(run_root)}.\n"
                "Searched by path hint 'seed####' and by reading 'seed' column."
            )
        chosen.append((s, found))

    return chosen

def load_and_pool_from_tag_root(tag_root: Path, model_name: str, seeds) -> pd.DataFrame:
    """
    Reads predictions.csv for seeds and returns one pooled row per clip_path:
      y_score_mean = mean(y_score across seeds)
    Keeps y_true and sex_norm (assumed consistent per clip_path).
    """
    tag_root = Path(tag_root)
    run_root = _resolve_run_root_from_tag(tag_root)
    print(f"\n[{model_name}] TAG_ROOT: {str(tag_root)}")
    print(f"[{model_name}] RESOLVED RUN_ROOT for predictions: {str(run_root)}")

    pred_files = _find_predictions_csvs(run_root, seeds)
    print(f"[{model_name}] predictions.csv files:")
    for s, p in pred_files:
        print(f"  seed={s}: {str(p)}")

    dfs = []
    for s, p in pred_files:
        df = pd.read_csv(p)

        required = ["clip_path", "y_true", "y_score", "sex_norm"]
        missing = [c for c in required if c not in df.columns]
        if missing:
            raise ValueError(
                f"[{model_name}] Missing columns in {str(p)}: {missing}\n"
                f"Columns found: {list(df.columns)}"
            )

        df = df[["clip_path", "y_true", "y_score", "sex_norm"]].copy()
        df["clip_path"] = df["clip_path"].astype(str)
        df["y_true"] = df["y_true"].astype(int)
        df["y_score"] = df["y_score"].astype(float)
        df["sex_norm"] = df["sex_norm"].astype(str).str.upper().str.strip()

        df["seed"] = int(s)
        dfs.append(df)

    all_df = pd.concat(dfs, axis=0, ignore_index=True)

    # Sanity: y_true should not disagree across seeds for the same clip_path
    chk = all_df.groupby("clip_path")["y_true"].nunique()
    bad = chk[chk > 1]
    if len(bad) > 0:
        ex = bad.index[:10].tolist()
        raise RuntimeError(
            f"[{model_name}] y_true mismatch across seeds for some clip_path. Examples: {ex}"
        )

    # Pool: mean score per clip_path across seeds
    pooled = (
        all_df.groupby("clip_path", as_index=False)
        .agg(
            y_true=("y_true", "first"),
            sex_norm=("sex_norm", "first"),
            y_score_mean=("y_score", "mean"),
            n_seeds=("seed", "nunique"),
        )
    )

    # Safety check: ensure each clip got all seeds
    missing_seeds = pooled[pooled["n_seeds"] != len(seeds)]
    if len(missing_seeds) > 0:
        print(f"[{model_name}] WARNING: {len(missing_seeds)} clips did not have all {len(seeds)} seeds.")
        print("Showing first 10 missing-seed clips:")
        display(missing_seeds.head(10))

    return pooled

def _save_hist(values: np.ndarray, out_png: Path, title: str, bins: int):
    """
    Save a simple histogram as a separate PNG.
    Guard: if values is empty, skip and print a note (prevents matplotlib errors and misleading blank plots).
    """
    out_png.parent.mkdir(parents=True, exist_ok=True)

    values = np.asarray(values, dtype=np.float64)
    if values.size == 0:
        print("NOTE: No data for plot, skipping:", out_png.name, "|", title)
        return

    plt.figure()
    plt.hist(values, bins=bins)
    plt.title(title)
    plt.xlabel("Predicted PD probability (y_score)")
    plt.ylabel("Count")
    plt.tight_layout()
    plt.savefig(out_png, dpi=200)
    plt.close()

def _prep_slices(df: pd.DataFrame) -> pd.DataFrame:
    """
    Standardize label and sex for plotting.
    """
    df = df.copy()
    df["label"] = np.where(df["y_true"] == 1, "PD", "HC")
    df["sex_norm"] = df["sex_norm"].astype(str).str.upper().str.strip()
    df.loc[~df["sex_norm"].isin(["M", "F"]), "sex_norm"] = "UNK"
    return df

def write_base_only_histograms(base_tag_root: Path, base_pooled: pd.DataFrame, bins: int):
    """
    Writes BASE-only plots under:
      <BASE_TAG_ROOT>/xai_histograms/
    This avoids any misleading "10% D2" wording in Base folder titles.
    """
    base_tag_root = Path(base_tag_root)
    out_dir = base_tag_root / XAI_DIRNAME
    out_dir.mkdir(parents=True, exist_ok=True)

    b = _prep_slices(base_pooled)

    # Overall PD/HC
    for label in ["PD", "HC"]:
        _save_hist(
            b.loc[b["label"] == label, "y_score_mean"].to_numpy(),
            out_dir / f"Base__{label}__hist.png",
            title=f"Base D7 → D2 TEST | {label} | pooled across 3 seeds",
            bins=bins,
        )

    # By sex (PD and HC)
    for label in ["PD", "HC"]:
        for sex in ["M", "F"]:
            _save_hist(
                b.loc[(b["label"] == label) & (b["sex_norm"] == sex), "y_score_mean"].to_numpy(),
                out_dir / f"Base__{label}__SEX_{sex}__hist.png",
                title=f"Base D7 → D2 TEST | {label} | sex={sex} | pooled across 3 seeds",
                bins=bins,
            )

    # Save pooled CSV (audit trail)
    pooled_out = out_dir / "pooled_predictions__Base_D7.csv"
    base_pooled.to_csv(pooled_out, index=False)

    print(f"\nWROTE BASE-only XAI histograms to: {str(out_dir)}")

def write_base_vs_enh_histograms(enh_tag_root: Path, base_pooled: pd.DataFrame, enh_pooled: pd.DataFrame,
                                draw_name: str, bins: int):
    """
    Writes Base vs Enhanced plots under:
      <ENH_TAG_ROOT>/xai_histograms/
    Output PNGs:
      - Base PD, Base HC
      - Enhanced PD, Enhanced HC
      - Base PD M/F + HC M/F
      - Enhanced PD M/F + HC M/F

    Title wording is explicit:
      - Base D7 (no target adaptation)
      - D7 + 10% D2 (this draw)
    """
    enh_tag_root = Path(enh_tag_root)
    out_dir = enh_tag_root / XAI_DIRNAME
    out_dir.mkdir(parents=True, exist_ok=True)

    b = _prep_slices(base_pooled)
    e = _prep_slices(enh_pooled)

    # Overall PD/HC
    for label in ["PD", "HC"]:
        _save_hist(
            b.loc[b["label"] == label, "y_score_mean"].to_numpy(),
            out_dir / f"Base__{label}__hist.png",
            title=f"Base D7 (no target adaptation) → D2 TEST | {label} | pooled across 3 seeds",
            bins=bins,
        )
        _save_hist(
            e.loc[e["label"] == label, "y_score_mean"].to_numpy(),
            out_dir / f"{draw_name}__{label}__hist.png",
            title=f"D7 + 10% D2 ({draw_name}) → D2 TEST | {label} | pooled across 3 seeds",
            bins=bins,
        )

    # By sex (PD and HC)
    for label in ["PD", "HC"]:
        for sex in ["M", "F"]:
            _save_hist(
                b.loc[(b["label"] == label) & (b["sex_norm"] == sex), "y_score_mean"].to_numpy(),
                out_dir / f"Base__{label}__SEX_{sex}__hist.png",
                title=f"Base D7 (no target adaptation) → D2 TEST | {label} | sex={sex} | pooled across 3 seeds",
                bins=bins,
            )
            _save_hist(
                e.loc[(e["label"] == label) & (e["sex_norm"] == sex), "y_score_mean"].to_numpy(),
                out_dir / f"{draw_name}__{label}__SEX_{sex}__hist.png",
                title=f"D7 + 10% D2 ({draw_name}) → D2 TEST | {label} | sex={sex} | pooled across 3 seeds",
                bins=bins,
            )

    # Save pooled CSVs (audit trail)
    base_out = out_dir / "pooled_predictions__Base_D7.csv"
    enh_out  = out_dir / f"pooled_predictions__{draw_name}_D7plus10pctD2.csv"
    base_pooled.to_csv(base_out, index=False)
    enh_pooled.to_csv(enh_out, index=False)

    print(f"\nWROTE Base vs {draw_name} XAI histograms to: {str(out_dir)}")

# -------------------------
# 2) Load + pool BASE once
# -------------------------
base_tag_root = Path(BASE_TAG_ROOT)
base_pooled = load_and_pool_from_tag_root(base_tag_root, model_name="BASE", seeds=SEEDS)

# Write BASE-only histograms under the BASE folder itself (titles are now correct and unambiguous)
write_base_only_histograms(base_tag_root=base_tag_root, base_pooled=base_pooled, bins=BINS)

# -------------------------
# 3) For each draw: load + pool enhanced, then write Base vs Enhanced into that draw folder
# -------------------------
for draw_name, tag_root in ENH_TAG_ROOTS.items():
    enh_tag_root = Path(tag_root)
    enh_pooled = load_and_pool_from_tag_root(enh_tag_root, model_name=draw_name, seeds=SEEDS)

    write_base_vs_enh_histograms(
        enh_tag_root=enh_tag_root,
        base_pooled=base_pooled,
        enh_pooled=enh_pooled,
        draw_name=draw_name,
        bins=BINS,
    )

print("\nDONE.")

The following cell creates a single ROC curve that summarizes TrainEnh2 performance on the D2 test set across multiple random seeds and highlights two practical decision thresholds on that curve. It locates the most recent TrainEnh2 test run by reading the test history index (the last matching entry wins), finds each seed’s `predictions.csv`, and computes an ROC curve and AUROC for every seed. It also reads the validation-selected Youden J threshold saved in the prediction files (the `threshold_used_global` column), then separately finds the latest threshold sweep run for TrainEnh2 and loads the recommended fairness-constrained threshold (`chosen_threshold`) from `sweep_summary.json`. If the summary does not report sensitivity and specificity for that threshold, the cell takes them from the row of `sweep_table.csv` whose threshold is closest and snaps the plotted threshold to that grid value.

After loading the ROC data for all seeds, the cell builds a mean ROC curve by interpolating each seed’s ROC onto a shared false positive rate grid, then averaging the true positive rates. It also computes a simple spread band using the standard deviation to show how results vary from seed to seed. The Youden J operating point is computed using the pooled D2 test predictions across all seeds so that the plotted sensitivity and specificity reflect actual test-set performance. For the fairness-constrained operating point, the cell uses the sensitivity and specificity reported by the sweep and converts specificity to false positive rate for plotting.
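The averaging step can be illustrated with a small dependency-free sketch. It is not the cell's code: the cell uses `np.interp` on a 501-point grid, while the hypothetical helpers `interp` and `mean_roc` below use plain Python and a coarse grid so the arithmetic is easy to follow. The two toy curves are made-up values.

```python
from bisect import bisect_right
from statistics import mean, stdev


def interp(x, xs, ys):
    """Piecewise-linear interpolation of (xs, ys) at x; xs must be non-decreasing."""
    if x <= xs[0]:
        return ys[0]
    if x >= xs[-1]:
        return ys[-1]
    j = bisect_right(xs, x)  # xs[j-1] <= x < xs[j]
    x0, x1 = xs[j - 1], xs[j]
    y0, y1 = ys[j - 1], ys[j]
    return y0 + (y1 - y0) * (x - x0) / (x1 - x0)


def mean_roc(curves, n_grid):
    """Interpolate each (fpr, tpr) curve onto a shared FPR grid, then average."""
    grid = [i / (n_grid - 1) for i in range(n_grid)]
    rows = [[interp(x, fpr, tpr) for x in grid] for fpr, tpr in curves]
    cols = list(zip(*rows))  # one column of TPR values per grid point
    mean_tpr = [mean(c) for c in cols]
    # Sample standard deviation across seeds; zero when only one seed exists
    std_tpr = [stdev(c) if len(c) > 1 else 0.0 for c in cols]
    return grid, mean_tpr, std_tpr


# Two toy per-seed ROC curves (illustrative values only)
curves = [
    ([0.0, 0.2, 1.0], [0.0, 0.8, 1.0]),
    ([0.0, 0.4, 1.0], [0.0, 0.8, 1.0]),
]
grid, mean_tpr, std_tpr = mean_roc(curves, n_grid=6)
for x, m, s in zip(grid, mean_tpr, std_tpr):
    print(f"fpr={x:.1f}  mean_tpr={m:.3f}  std={s:.3f}")
```

Averaging TPR at fixed FPR values (vertical averaging) is what lets curves with different numbers of ROC points be combined; the band is widest where the seeds disagree most.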

Finally, the cell generates and saves `mean_roc_curve_with_operating_points.png` in the same TrainEnh2 test run folder. The plot shows faint ROC curves for individual seeds, the mean ROC curve with a shaded variation band, a dashed chance line, and two clearly labeled markers. One marker represents the Youden J (validation-optimal) threshold, and the other represents the recommended fairness-constrained threshold. Each marker is annotated with the threshold value and the corresponding sensitivity and specificity.
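Each marker is placed at (false positive rate, sensitivity), where the false positive rate is one minus specificity. A minimal sketch of that conversion for a single threshold is shown below. The function mirrors the idea of the cell's `operating_point` helper but takes the labels and scores as arguments, and the toy labels and scores are made up for illustration.

```python
def operating_point(y_true, y_score, thr):
    """Sensitivity, specificity and FPR when predicting PD for y_score >= thr."""
    tp = fn = tn = fp = 0
    for t, s in zip(y_true, y_score):
        pred = 1 if s >= thr else 0
        if t == 1 and pred == 1:
            tp += 1
        elif t == 1 and pred == 0:
            fn += 1
        elif t == 0 and pred == 0:
            tn += 1
        else:
            fp += 1
    sens = tp / (tp + fn) if (tp + fn) else float("nan")
    spec = tn / (tn + fp) if (tn + fp) else float("nan")
    return {
        "thr": thr,
        "sens": sens,
        "spec": spec,
        "fpr": 1.0 - spec,               # x-coordinate on the ROC plot
        "youden_j": sens + spec - 1.0,   # Youden's J statistic at this threshold
    }


# Toy data: four PD clips (1) and four HC clips (0)
y_true = [1, 1, 1, 1, 0, 0, 0, 0]
y_score = [0.9, 0.8, 0.6, 0.3, 0.7, 0.4, 0.2, 0.1]
pt = operating_point(y_true, y_score, thr=0.5)
print(pt)
```

Because the Youden J threshold is selected on validation data and then applied to test predictions, the resulting marker need not sit at the point of the test ROC that maximizes J.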

# Mean ROC Curve for TrainEnh2 to D2 Test

In [None]:
# Mean ROC + Two Operating Points (TrainEnh2 → D2 Test)
# Inputs:
#   - TrainEnh2 test predictions (per seed)
#   - TrainEnh2 threshold sweep summary/table (fairness-constrained threshold)
# Outputs:
#   - mean_roc_curve_with_operating_points.png saved in the TrainEnh2 test run folder

import json
from pathlib import Path
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve, roc_auc_score

# -------------------------
# 0) Roots and tags
# -------------------------
D7_OUT_ROOT = "/content/drive/MyDrive/AI_PD_Project/Datasets/D7-Multilingual (D1_D4_D5v2_D6)/preprocessed_v1"
TARGET_ENH_TAG = "trainEnh2"

# -------------------------
# 1) Locate latest TrainEnh2 TEST run (predictions)
# -------------------------
test_root = Path(D7_OUT_ROOT) / "monolingual_test_runs"
history_path = test_root / "history_index.jsonl"
if not history_path.exists():
    raise FileNotFoundError("Missing history_index.jsonl under monolingual_test_runs.")

latest_test = None
with open(history_path, "r", encoding="utf-8") as f:
    for line in f:
        obj = json.loads(line)
        if str(obj.get("enh_tag", "")).lower() == TARGET_ENH_TAG.lower():
            latest_test = obj  # last match wins

if latest_test is None:
    raise FileNotFoundError("No TrainEnh2 entries found in test history.")

run_dirs = latest_test.get("run_dirs", [])
if not run_dirs:
    raise FileNotFoundError("TrainEnh2 test entry missing run_dirs.")

run_root = Path(run_dirs[0]).parent
if not run_root.exists():
    raise FileNotFoundError("TrainEnh2 test run folder does not exist.")

print("Test run folder:", run_root)

# -------------------------
# 2) Load per-seed predictions and read Youden J (VAL-optimal) threshold
# -------------------------
pred_paths = sorted(run_root.glob("run_*_seed*/predictions.csv"))
if not pred_paths:
    raise FileNotFoundError("No predictions.csv files found in test run.")

seed_rocs = []
youden_thresholds = []

for p in pred_paths:
    df = pd.read_csv(p)

    if not {"y_true", "y_score"}.issubset(df.columns):
        raise ValueError(f"Missing y_true/y_score in {p}")

    y_true_seed = df["y_true"].to_numpy(int)
    y_score_seed = df["y_score"].to_numpy(float)

    if len(np.unique(y_true_seed)) < 2:
        raise ValueError(f"Only one class present in {p}")

    fpr, tpr, _ = roc_curve(y_true_seed, y_score_seed)
    auc = float(roc_auc_score(y_true_seed, y_score_seed))
    seed = int(p.parent.name.split("seed")[-1])

    seed_rocs.append({"seed": seed, "fpr": fpr, "tpr": tpr, "auc": auc})

    if "threshold_used_global" in df.columns:
        youden_thresholds.append(float(df["threshold_used_global"].iloc[0]))

if not youden_thresholds:
    raise ValueError("VAL-optimal threshold (threshold_used_global) not found in predictions.csv.")

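youden_thr = float(youden_thresholds[0])
# Only the first file's value is used; warn if seeds disagree so the plotted point is not misread.
if len({round(t, 6) for t in youden_thresholds}) > 1:
    print("WARNING: threshold_used_global differs across seeds; using the first:", youden_thresholds)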
print(f"Youden J threshold (VAL-optimal): {youden_thr:.6f}")

# -------------------------
# 3) Locate latest TrainEnh2 threshold sweep (fairness threshold)
# -------------------------
sweep_root = Path(D7_OUT_ROOT) / "threshold_sweeps"
sweep_dirs = sorted(
    [d for d in sweep_root.glob("run_D7_trainEnh2_on_D2test_*") if d.is_dir()],
    key=lambda d: d.stat().st_mtime,
    reverse=True,
)
if not sweep_dirs:
    raise FileNotFoundError("No TrainEnh2 threshold_sweeps found under threshold_sweeps/.")

sweep_dir = sweep_dirs[0]
summary_path = sweep_dir / "sweep_summary.json"
table_path = sweep_dir / "sweep_table.csv"

if not summary_path.exists():
    raise FileNotFoundError("Missing sweep_summary.json in the latest threshold sweep folder.")

with open(summary_path, "r", encoding="utf-8") as f:
    sweep_summary = json.load(f)

# Fairness-constrained recommendation (as written by the sweep code)
if "chosen_threshold" not in sweep_summary:
    raise ValueError(
        "sweep_summary.json must contain 'chosen_threshold'. "
        f"Found keys: {list(sweep_summary.keys())}"
    )

fair_thr = float(sweep_summary["chosen_threshold"])
print(f"Fairness-constrained threshold (chosen_threshold): {fair_thr:.6f}")
print("Threshold sweep folder:", sweep_dir)

# -------------------------
# 3.5) Extract fairness operating-point sens/spec (best effort)
# -------------------------
def _find_metric(d: dict, candidates: list[str]):
    for k in candidates:
        if k in d and d[k] is not None:
            try:
                return float(d[k])
            except Exception:
                pass
    return None

fair_sens = None
fair_spec = None

chosen_metrics = sweep_summary.get("chosen_metrics", {}) or {}

# Try common names first
fair_sens = _find_metric(chosen_metrics, ["sensitivity", "mean_sensitivity", "tpr", "TPR"])
fair_spec = _find_metric(chosen_metrics, ["specificity", "mean_specificity", "tnr", "TNR"])

# If not found, fall back to sweep_table.csv by nearest threshold
if (fair_sens is None) or (fair_spec is None):
    if not table_path.exists():
        raise ValueError(
            "chosen_metrics did not contain sensitivity/specificity, and sweep_table.csv is missing. "
            "Cannot label fairness point with sens/spec."
        )

    tbl = pd.read_csv(table_path)

    # Find the threshold column
    thr_col = None
    for c in ["threshold", "thr", "thresh"]:
        if c in tbl.columns:
            thr_col = c
            break
    if thr_col is None:
        raise ValueError(f"sweep_table.csv missing a threshold column. Found: {list(tbl.columns)}")

    # Find sensitivity/specificity columns
    sens_col = None
    spec_col = None
    sens_candidates = ["mean_sensitivity", "sensitivity", "tpr", "TPR"]
    spec_candidates = ["mean_specificity", "specificity", "tnr", "TNR"]

    for c in sens_candidates:
        if c in tbl.columns:
            sens_col = c
            break
    for c in spec_candidates:
        if c in tbl.columns:
            spec_col = c
            break

    if sens_col is None or spec_col is None:
        raise ValueError(
            "sweep_table.csv does not have recognizable sensitivity/specificity columns. "
            f"Found: {list(tbl.columns)}"
        )

    # Pick nearest threshold row
    thr_vals = pd.to_numeric(tbl[thr_col], errors="coerce").to_numpy(dtype=float)
    if np.all(np.isnan(thr_vals)):
        raise ValueError("Threshold column in sweep_table.csv could not be parsed as numbers.")

    idx = int(np.nanargmin(np.abs(thr_vals - fair_thr)))
    fair_thr = float(thr_vals[idx])  # snap to actual grid value used in sweep
    fair_sens = float(tbl.iloc[idx][sens_col])
    fair_spec = float(tbl.iloc[idx][spec_col])

# -------------------------
# 4) Build mean ROC across seeds
# -------------------------
fpr_grid = np.linspace(0.0, 1.0, 501)
tprs = []

for r in seed_rocs:
    fpr = np.asarray(r["fpr"], dtype=float)
    tpr = np.asarray(r["tpr"], dtype=float)

    if fpr[0] > 0.0:
        fpr = np.concatenate([[0.0], fpr])
        tpr = np.concatenate([[0.0], tpr])
    if fpr[-1] < 1.0:
        fpr = np.concatenate([fpr, [1.0]])
        tpr = np.concatenate([tpr, [1.0]])

    tprs.append(np.interp(fpr_grid, fpr, tpr))

tprs = np.vstack(tprs)
mean_tpr = tprs.mean(axis=0)
std_tpr = tprs.std(axis=0, ddof=1) if tprs.shape[0] > 1 else np.zeros_like(mean_tpr)
mean_auc = float(np.mean([r["auc"] for r in seed_rocs]))

# -------------------------
# 5) Compute Youden J operating point on D2 TEST (pooled predictions)
# -------------------------
all_true, all_score = [], []
for p in pred_paths:
    df = pd.read_csv(p)
    all_true.append(df["y_true"].to_numpy(int))
    all_score.append(df["y_score"].to_numpy(float))

y_true = np.concatenate(all_true)
y_score = np.concatenate(all_score)

def operating_point(thr: float):
    y_pred = (y_score >= float(thr)).astype(int)

    tp = int(((y_pred == 1) & (y_true == 1)).sum())
    fn = int(((y_pred == 0) & (y_true == 1)).sum())
    tn = int(((y_pred == 0) & (y_true == 0)).sum())
    fp = int(((y_pred == 1) & (y_true == 0)).sum())

    eps = 1e-12
    sens = tp / (tp + fn + eps)
    spec = tn / (tn + fp + eps)
    fpr = 1.0 - spec
    return {"thr": float(thr), "sens": float(sens), "spec": float(spec), "fpr": float(fpr)}

pt_you = operating_point(youden_thr)

# Fairness point: use sweep-provided sens/spec, compute FPR from specificity
pt_fair = {
    "thr": float(fair_thr),
    "sens": float(fair_sens),
    "spec": float(fair_spec),
    "fpr": float(1.0 - float(fair_spec)),
}

print("\nOperating points:")
print(f"  Youden J (VAL-optimal): thr={pt_you['thr']:.6f}, sens={pt_you['sens']:.4f}, spec={pt_you['spec']:.4f}")
print(f"  Fairness-constrained:   thr={pt_fair['thr']:.6f}, sens={pt_fair['sens']:.4f}, spec={pt_fair['spec']:.4f}")

# -------------------------
# 6) Plot and save
# -------------------------
plt.figure()

# Seed curves (light context)
for r in seed_rocs:
    plt.plot(r["fpr"], r["tpr"], alpha=0.35)

# Mean ROC + spread band
plt.plot(fpr_grid, mean_tpr)
plt.fill_between(
    fpr_grid,
    np.clip(mean_tpr - std_tpr, 0.0, 1.0),
    np.clip(mean_tpr + std_tpr, 0.0, 1.0),
    alpha=0.15,
)

# Fairness-constrained point with sens/spec
plt.scatter(
    [pt_fair["fpr"]],
    [pt_fair["sens"]],
    s=80,
    zorder=5,
    label=(
        "Recomm. Fairness-Constrained Threshold\n"
        f"thr={pt_fair['thr']:.3f}, "
        f"Sens={pt_fair['sens']:.2f}, Spec={pt_fair['spec']:.2f}"
    ),
)

# Youden J point (VAL-optimal) with sens/spec
plt.scatter(
    [pt_you["fpr"]],
    [pt_you["sens"]],
    s=80,
    zorder=5,
    label=(
        "Youden J Threshold (VAL-optimal)\n"
        f"thr={pt_you['thr']:.3f}, "
        f"Sens={pt_you['sens']:.2f}, Spec={pt_you['spec']:.2f}"
    ),
)

# Chance line
plt.plot([0, 1], [0, 1], linestyle="--")

plt.xlabel("False Positive Rate")
plt.ylabel("True Positive Rate")
plt.title(f"Mean ROC (TrainEnh2 → D2 Test) | mean AUROC={mean_auc:.3f} | seeds={len(seed_rocs)}")
plt.legend(loc="lower right", frameon=True)
plt.tight_layout()

out_png = run_root / "mean_roc_curve_with_operating_points.png"
plt.savefig(out_png, dpi=150)
plt.show()
plt.close()

print("\nSaved plot:", str(out_png))

The following cell runs a complete “target data ablation” experiment to study how adding small, speaker-balanced amounts of D2 training data into the multilingual D7 training set affects performance on the D2 test split. It loads the base D7 manifest along with three prebuilt D7 training enhancement manifests (TrainEnh1, TrainEnh2, and TrainEnh3). These enhancement manifests act as separate, non-overlapping pools of D2 speakers that can be added in controlled steps. Multiple training conditions are created, starting with 0 percent added D2 speakers (base only) and increasing through finer increments (5, 10, 15, 20, 25, and 29 percent) by combining half or full portions of the enhancement pools. For each condition, the cell builds fresh training and validation manifests, optionally stages audio clips to local storage for faster access, trains a frozen Wav2Vec2 model with two small task-specific heads across three random seeds, evaluates each trained model on the D2 test split, and records AUROC. At the end, it writes a summary table and a curve plot showing mean AUROC versus the percentage of D2 speakers added, with error bars across seeds.

The setup section defines the D7 and D2 dataset locations on Drive, creates an output folder for the ablation run, and fixes all key training settings. These include random seeds, number of epochs, early stopping patience, learning rate, effective batch size through gradient accumulation, dropout, device selection, and mixed precision when running on GPU. An option is also provided to copy audio clips to local SSD for faster reading. Several helper functions handle repeated tasks such as checking file existence, reading and writing CSV files with proper missing values, setting random seeds, normalizing clip paths into a consistent format, printing split counts, and running quick checks for missing or corrupted audio files before training starts.

The model used in this cell keeps the Wav2Vec2 backbone frozen and trains only two small classifier heads, one intended for vowel clips and one for all other clip types. Frame-level features are pooled into a single vector per clip using the attention mask, and simple routing logic sends each clip to the correct head based on its task label. The dataset and collate code load waveforms from disk, build attention masks, pad batches to a common length, and keep clip paths in the batch so that slow or failed reads can be traced to specific files. For vowel clips, the attention mask ignores trailing near-silence, while other clips use the full waveform.
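The pooling idea can be sketched without any deep learning library. In the cell, pooling operates on tensors and uses the backbone's frame-level mask; the stand-in below uses nested lists and a hand-written 0/1 mask, so it only illustrates the arithmetic. The clamp on the denominator mirrors the cell's `clamp(min=1.0)`, which avoids division by zero when a mask is entirely zero.

```python
def masked_mean_pool(frames, mask):
    """Average the frame vectors whose mask entry is 1.

    frames: list of T vectors, each a list of H floats
    mask:   list of T ints (1 = real frame, 0 = padding or ignored frame)
    """
    hidden = len(frames[0])
    summed = [0.0] * hidden
    for vec, m in zip(frames, mask):
        if m:
            for k in range(hidden):
                summed[k] += vec[k]
    denom = max(1.0, float(sum(mask)))  # never divide by zero
    return [v / denom for v in summed]


# Three frames of width 2; the last frame is padding and must not affect the mean
frames = [[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]]
mask = [1, 1, 0]
pooled = masked_mean_pool(frames, mask)
print(pooled)
```

Ignoring masked frames matters here because clips in a batch are padded to a common length; an unmasked mean would be pulled toward the padding.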

A key part of the cell constructs the gradual D2 exposure steps in a controlled way. It reads each TrainEnh manifest, keeps only rows that truly come from D2 using a strict `source_dataset == "D2"` filter, and samples speakers rather than individual clips so all clips from a chosen speaker stay together. For the “half” steps, it selects half of the available Parkinson’s speakers and half of the available Healthy speakers from that pool, using a fixed random seed so the same speakers are chosen each time. It also checks that the enhancement pools do not share speakers, since the logic assumes the three draws are non-overlapping.
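A minimal sketch of speaker-level sampling and the overlap check follows. The function names, the rounding rule (floor of half the pool), and the toy speaker IDs are assumptions made for illustration; the cell's own implementation may differ in these details.

```python
import random


def sample_half_speakers(speaker_labels, seed, frac=0.5):
    """Pick `frac` of PD speakers and `frac` of HC speakers, reproducibly.

    speaker_labels: dict mapping speaker_id -> label (1 = PD, 0 = HC)
    """
    rng = random.Random(seed)  # private generator, so global state is untouched
    chosen = set()
    for label in (1, 0):
        # Sort first so the draw does not depend on dict insertion order
        pool = sorted(s for s, y in speaker_labels.items() if y == label)
        k = int(len(pool) * frac)  # assumption: round down
        chosen.update(rng.sample(pool, k))
    return chosen


def assert_disjoint(pools):
    """Raise if any speaker appears in more than one pool (name -> set of speakers)."""
    seen = {}
    for name, speakers in pools.items():
        for s in speakers:
            if s in seen:
                raise ValueError(f"speaker {s} is in both {seen[s]} and {name}")
            seen[s] = name


# Toy pool: four PD and four HC speakers (hypothetical IDs)
speaker_labels = {"p1": 1, "p2": 1, "p3": 1, "p4": 1,
                  "h1": 0, "h2": 0, "h3": 0, "h4": 0}
chosen = sample_half_speakers(speaker_labels, seed=1337)

# Keep every clip of a chosen speaker, so no speaker is split across conditions
clips = [{"clip_path": f"clips/train_enh1/{s}_{i}.wav", "speaker_id": s}
         for s in speaker_labels for i in range(2)]
selected = [c for c in clips if c["speaker_id"] in chosen]
print(sorted(chosen), len(selected))
```

Sampling speakers instead of clips keeps the added exposure interpretable as "percent of speakers" and prevents partial speakers from leaking into a condition.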

For each condition, the cell creates a condition-specific folder on Drive and writes a new `manifest_all.csv` that contains only training and validation splits. Training data consist of the base D7 training set plus the selected D2 rows for that condition, while validation data always come from the base D7 validation split. Duplicate clip paths are removed so the same audio file cannot be counted twice. A small JSON file is also written for each condition, recording how many speakers and clips were added and which enhancement parts were used, so each condition can be traced later.
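The manifest assembly can be sketched with plain dictionaries. This is an illustration under assumptions: rows are reduced to `clip_path` and `split`, the helper name `build_condition_manifest` is hypothetical, and the first occurrence of a duplicated path is the one kept.

```python
def build_condition_manifest(base_rows, d2_rows):
    """Combine base train rows, added D2 rows, and base val rows without duplicates.

    base_rows, d2_rows: lists of dicts with at least 'clip_path' and 'split'
    """
    train = [r for r in base_rows if r["split"] == "train"]
    # Added D2 rows always go into training, whatever split name they carried
    train = train + [dict(r, split="train") for r in d2_rows]
    val = [r for r in base_rows if r["split"] == "val"]

    seen, out = set(), []
    for r in train + val:
        if r["clip_path"] in seen:
            continue  # the same audio file must not be counted twice
        seen.add(r["clip_path"])
        out.append(r)
    return out


# Toy rows (hypothetical paths)
base_rows = [
    {"clip_path": "clips/train/a.wav", "split": "train"},
    {"clip_path": "clips/val/b.wav", "split": "val"},
    {"clip_path": "clips/test/c.wav", "split": "test"},
]
d2_rows = [
    {"clip_path": "clips/train_enh1/d.wav", "split": "train_enh1"},
    {"clip_path": "clips/train/a.wav", "split": "train"},  # duplicate of a base clip
]
manifest = build_condition_manifest(base_rows, d2_rows)
print([(r["clip_path"], r["split"]) for r in manifest])
```

Keeping validation fixed to the base D7 split means model selection is identical across conditions, so differences in D2 test AUROC can be attributed to the added training speakers.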

Training is then run for each condition and each seed. Before training starts, the cell checks that all referenced audio files exist and performs a quick audio header scan to catch corrupted clips early. Training uses early stopping based on validation AUROC and saves the best model weights, which include only the two heads and related normalization and dropout blocks. A per-epoch history file is also written. Watchdog timers print warnings if data loading or training becomes unusually slow and include example clip paths to help identify problematic files. After training finishes for a condition, the saved best weights are reloaded and the model is evaluated on the D2 test split, writing per-seed predictions and metrics to disk.
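The early-stopping rule can be sketched independently of the training loop. The helper below is hypothetical and assumes that an epoch counts as an improvement only when validation AUROC strictly exceeds the best value so far, and that the patience counter resets on each improvement; the cell may use a slightly different criterion.

```python
def run_early_stopping(val_aurocs, patience):
    """Return (best_epoch, best_auroc, epochs_run) for a sequence of validation AUROCs."""
    best_auc = float("-inf")
    best_epoch = None
    bad_epochs = 0
    epochs_run = 0
    for epoch, auc in enumerate(val_aurocs, start=1):
        epochs_run = epoch
        if auc > best_auc:
            # A real loop would save the best head weights at this point
            best_auc, best_epoch, bad_epochs = auc, epoch, 0
        else:
            bad_epochs += 1
            if bad_epochs >= patience:
                break  # stop: no improvement for `patience` consecutive epochs
    return best_epoch, best_auc, epochs_run


# Toy validation curve (made-up values); training stops before the last epoch
history = [0.70, 0.75, 0.74, 0.73, 0.80]
best_epoch, best_auc, epochs_run = run_early_stopping(history, patience=2)
print(best_epoch, best_auc, epochs_run)
```

With a small patience, a late improvement such as the final value in the toy curve is never reached, which is the usual trade-off between training time and the risk of stopping too early.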

At the end of the run, results from all conditions are combined into a single table containing per-seed AUROCs, the mean AUROC, and the standard deviation. This table is saved as a CSV in the ablation output folder. The cell then creates a plot of mean AUROC versus the percentage of D2 training speakers included, adds error bars to show variation across seeds, fits a quadratic trend line to show the overall pattern, and saves the figure as a PNG. Finally, it prints the total runtime and attempts to stop the Colab runtime automatically to avoid leaving a GPU session running.
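The per-condition aggregation can be sketched as follows. The column names and AUROC values are made up for illustration, and the sample standard deviation (n minus 1 in the denominator) is an assumption about how the spread is computed; the trend-line fit is not covered here.

```python
from statistics import mean, stdev


def summarize(per_seed_auroc):
    """Build one summary row per condition.

    per_seed_auroc: dict mapping percent of D2 speakers added -> list of AUROCs,
                    one per random seed
    """
    rows = []
    for pct in sorted(per_seed_auroc):
        vals = per_seed_auroc[pct]
        rows.append({
            "pct_d2_speakers": pct,
            "n_seeds": len(vals),
            "mean_auroc": mean(vals),
            # Sample standard deviation; zero when only one seed is available
            "std_auroc": stdev(vals) if len(vals) > 1 else 0.0,
        })
    return rows


# Toy results for two conditions (illustrative numbers, not study results)
toy = {0: [0.60, 0.70, 0.80], 10: [0.75, 0.75, 0.75]}
summary = summarize(toy)
for row in summary:
    print(row)
```

The `std_auroc` column is what the error bars on the curve would represent under this assumption.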

# Ablation 4: Target Data Ablation

In [None]:
# ============================================================
# Target-Data Ablation: Finer Cumulative D2 Exposure (Two-Head Wav2Vec2)
# ============================================================
# Inputs:
# - D7 base manifest (train, val) and D7 TrainEnh1/2/3 manifests
# - D7 clips folder
# - D2 manifest (test split) and D2 clips folder
# Outputs:
# - Per-condition manifest on Drive (train + val)
# - Per-condition TrainVal runs saved on Drive
# - Per-condition D2 test runs saved on Drive
# - Summary CSV and curve plot saved under target_data_ablation_finer/

import os, json, time, shutil
from pathlib import Path

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm.auto import tqdm

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

from transformers import Wav2Vec2Model
from sklearn.metrics import roc_auc_score

START_TIME = time.time()

# -------------------------
# Drive mount (Colab)
# -------------------------
try:
    from google.colab import drive  # type: ignore
    if not os.path.exists("/content/drive"):
        drive.mount("/content/drive")
except Exception:
    pass


# ============================================================
# Run setup
# ============================================================

D7_ROOT = Path("/content/drive/MyDrive/AI_PD_Project/Datasets/D7-Multilingual (D1_D4_D5v2_D6)/preprocessed_v1")
D2_ROOT = Path("/content/drive/MyDrive/AI_PD_Project/Datasets/D2-Slovak (EWA-DB)/EWA-DB/preprocessed_v1")

D7_MAN_DIR   = D7_ROOT / "manifests"
D7_CLIP_ROOT = D7_ROOT  # clip paths are stored like "clips/train/..wav"

ABL_ROOT = D7_ROOT / "target_data_ablation_finer"
ABL_ROOT.mkdir(parents=True, exist_ok=True)

SEEDS = [1337, 2024, 7777]

MAX_EPOCHS = 10
PATIENCE   = 2
LR         = 1e-3

EFFECTIVE_BS  = 64
PER_DEVICE_BS = 16
GRAD_ACCUM    = max(1, EFFECTIVE_BS // PER_DEVICE_BS)

BACKBONE_CKPT = "facebook/wav2vec2-base"
DROPOUT_P     = 0.2

# Task routing
VOWEL_TASK_VALUE = "vowel"
TINY_THRESH = 1e-4

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
USE_AMP = (DEVICE.type == "cuda")

# Progress + hang debugging
PRINT_EVERY_N_BATCHES = 25
FETCH_WATCHDOG_SEC    = 90
STEP_WATCHDOG_SEC     = 180
INFO_SCAN_MAX_PRINT   = 20

# Optional: stage clips to local SSD (recommended)
STAGE_TO_LOCAL   = True
LOCAL_CACHE_ROOT = Path("/content/abl_cache")
LOCAL_CACHE_ROOT.mkdir(parents=True, exist_ok=True)

# D2 filter rule: must use "source_dataset"
D2_SOURCE_KEY = "D2"  # strict on purpose; avoids silently pulling other datasets

# Finer points use fractions of each non-overlapping draw
DRAW_FRAC = {
    "A_half": 0.5,  # half of TrainEnh1
    "B_half": 0.5,  # half of TrainEnh2
    "C_half": 0.5,  # half of TrainEnh3
}

print("Target-Data Ablation: Finer Cumulative D2 Exposure")
print("--------------------------------------------------")
print(f"[Init] Output folder: {ABL_ROOT}")
print(f"[Init] Device: {DEVICE} | AMP: {USE_AMP}")
print(f"[Init] Stage to local SSD: {STAGE_TO_LOCAL}")


# ============================================================
# Helpers
# ============================================================

def _ensure_exists(p: Path, what: str) -> None:
    if not p.exists():
        raise FileNotFoundError(f"Missing {what} at: {str(p)}")

def _read_csv(p: Path) -> pd.DataFrame:
    return pd.read_csv(p)

def _write_csv(df: pd.DataFrame, p: Path) -> None:
    p.parent.mkdir(parents=True, exist_ok=True)
    df.to_csv(p, index=False, na_rep="")

def _set_seed(seed: int) -> None:
    import random
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed all CUDA devices when a GPU is present (an autocast context object is always truthy)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

def _new_run_stamp() -> str:
    return time.strftime("%Y%m%d_%H%M%S")

def _safe_unique_count(df: pd.DataFrame, col: str) -> int | None:
    if col not in df.columns:
        return None
    return int(df[col].nunique(dropna=True))

def _choose_train_split_name(df: pd.DataFrame) -> str:
    splits = df["split"].astype(str)
    if (splits == "train").any():
        return "train"
    candidates = splits[~splits.isin(["val", "test"])].value_counts()
    if len(candidates) == 0:
        raise ValueError("No usable training split found (excluding val/test).")
    return str(candidates.index[0])

def _norm_clip_path(raw: str) -> str:
    """
    Normalize to relative form:
      clips/<split_folder>/...wav
    """
    s = str(raw).replace("\\", "/").strip()
    if "://" in s:
        s = s.split("://", 1)[-1]
    if "/clips/" in s:
        s = "clips/" + s.split("/clips/", 1)[-1]
    elif s.startswith("clips/"):
        pass
    else:
        known = ["train", "val", "test", "train_enh1", "train_enh2", "train_enh3"]
        for k in known:
            if s.startswith(k + "/"):
                s = "clips/" + s
                break
    while "//" in s:
        s = s.replace("//", "/")
    return s.lstrip("/")

def _print_split_counts(manifest_path: Path, name: str) -> None:
    df = _read_csv(manifest_path)
    for c in ["split", "clip_path"]:
        if c not in df.columns:
            raise ValueError(f"{name}: missing required column '{c}'")

    spk_col = "speaker_id" if "speaker_id" in df.columns else None

    print(f"\n[Counts] {name}")
    for split in ["train", "val", "test"]:
        sub = df[df["split"].astype(str) == split]
        if len(sub) == 0:
            continue
        spk = int(sub[spk_col].nunique(dropna=True)) if spk_col else None
        msg = f"  - {split}: clips={len(sub)}"
        if spk is not None:
            msg += f" | speakers={spk}"
        print(msg)

def _scan_exists(clip_root: Path, df: pd.DataFrame, label: str) -> None:
    missing = 0
    for cp in df["clip_path"].astype(str).tolist():
        if not (clip_root / cp).exists():
            missing += 1
    print(f"[Check] {label}: total={len(df)} missing={missing}")

def _scan_soundfile_info(clip_root: Path, df: pd.DataFrame, label: str) -> None:
    """
    Quick file header scan (no full decode).
    """
    import soundfile as sf
    bad = []
    t0 = time.time()
    for cp in df["clip_path"].astype(str).tolist():
        p = clip_root / cp
        try:
            _ = sf.info(str(p))
        except Exception as e:
            bad.append((str(p), repr(e)))
            if len(bad) >= INFO_SCAN_MAX_PRINT:
                break
    dt = time.time() - t0
    if len(bad) == 0:
        print(f"[Info] {label}: header scan ok ({len(df)} files) | {dt:.1f}s")
    else:
        print(f"[Info] {label}: header scan found issues (showing up to {len(bad)}) | {dt:.1f}s")
        for p, e in bad:
            print("  -", p)
            print("    ", e)
        raise RuntimeError("Header scan failed. Fix or remove the listed files before training.")


# ============================================================
# Local staging (fast reads during training)
# ============================================================

def _copy_file_if_needed(src: Path, dst: Path) -> None:
    dst.parent.mkdir(parents=True, exist_ok=True)
    if dst.exists():
        if src.stat().st_size == dst.stat().st_size:
            return
        raise RuntimeError(f"Destination exists with different size:\n  src={src}\n  dst={dst}")
    shutil.copy2(str(src), str(dst))

def _stage_manifest_clips_to_local(cond_name: str, src_clip_root: Path, manifest_all: pd.DataFrame) -> Path:
    """
    Local staging for one condition.

    Input:
    - src_clip_root: root folder that contains "clips/..."
    - manifest_all: rows with clip_path values like "clips/train/...wav"

    Output:
    - local_root: /content/abl_cache/<cond_name> with:
        - clips/...
        - manifests/manifest_all.csv
    """
    # Reset this condition's local folder so stale clips from an earlier run cannot be reused
    local_root = LOCAL_CACHE_ROOT / cond_name
    if local_root.exists():
        shutil.rmtree(local_root)

    local_mans = local_root / "manifests"
    local_root.mkdir(parents=True, exist_ok=True)
    local_mans.mkdir(parents=True, exist_ok=True)

    _write_csv(manifest_all, local_mans / "manifest_all.csv")

    paths = manifest_all["clip_path"].astype(str).tolist()
    unique_paths = list(dict.fromkeys(paths))

    t0 = time.time()
    copied = 0
    for cp in tqdm(unique_paths, desc=f"Stage to local ({cond_name})", leave=False):
        src = src_clip_root / cp
        dst = local_root / cp
        if not src.exists():
            raise FileNotFoundError(f"Missing source file for staging: {src}")
        _copy_file_if_needed(src, dst)
        copied += 1
    dt = time.time() - t0
    print(f"[Stage] {cond_name}: copied_files={copied} | {dt/60.0:.1f} min")
    return local_root


# ============================================================
# Model (frozen backbone + two task heads)
# ============================================================

class FrozenW2VTwoHead(nn.Module):
    """
    - Wav2Vec2 backbone is frozen
    - Two small heads are trained:
      * vowel head for task == VOWEL_TASK_VALUE
      * other head for all other tasks
    """
    def __init__(self, backbone_ckpt: str, dropout_p: float):
        super().__init__()
        self.backbone = Wav2Vec2Model.from_pretrained(backbone_ckpt)
        for p in self.backbone.parameters():
            p.requires_grad = False

        H = int(self.backbone.config.hidden_size)
        self.pre_vowel = nn.Sequential(nn.LayerNorm(H), nn.Dropout(float(dropout_p)))
        self.pre_other = nn.Sequential(nn.LayerNorm(H), nn.Dropout(float(dropout_p)))
        self.head_vowel = nn.Linear(H, 2)
        self.head_other = nn.Linear(H, 2)
        self.loss_fn = nn.CrossEntropyLoss()

    def masked_mean_pool(self, last_hidden, attn_mask_samples):
        feat_mask = self.backbone._get_feature_vector_attention_mask(last_hidden.shape[1], attn_mask_samples)
        feat_mask = feat_mask.to(last_hidden.device).unsqueeze(-1).type_as(last_hidden)
        summed = (last_hidden * feat_mask).sum(dim=1)
        denom = feat_mask.sum(dim=1).clamp(min=1.0)
        return summed / denom

    def forward(self, input_values, attention_mask, task_group):
        out = self.backbone(input_values=input_values, attention_mask=attention_mask)
        pooled = self.masked_mean_pool(out.last_hidden_state, attention_mask)

        logits = []
        for i, tg in enumerate(task_group):
            if tg == VOWEL_TASK_VALUE:
                z = self.pre_vowel(pooled[i])
                logits.append(self.head_vowel(z))
            else:
                z = self.pre_other(pooled[i])
                logits.append(self.head_other(z))
        return torch.stack(logits, dim=0)
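

# ------------------------------------------------------------
# Illustrative sketch (not called by the pipeline): the masked mean
# pooling used by FrozenW2VTwoHead, written in plain Python for a
# single clip. `_sketch_masked_mean` is a hypothetical helper name
# introduced only for this example. `frames` holds one hidden vector
# per frame and `frame_mask` marks valid (1) versus padded (0) frames.
# The denominator is floored at 1.0, mirroring clamp(min=1.0), so a
# fully masked clip yields zeros instead of a division by zero.
# ------------------------------------------------------------

def _sketch_masked_mean(frames: list[list[float]], frame_mask: list[int]) -> list[float]:
    if len(frames) != len(frame_mask):
        raise ValueError("frames and frame_mask must have the same length")
    hidden = len(frames[0]) if frames else 0
    summed = [0.0] * hidden
    for vec, keep in zip(frames, frame_mask):
        if keep:  # padded frames contribute nothing to the sum
            for h in range(hidden):
                summed[h] += vec[h]
    denom = max(1.0, float(sum(frame_mask)))
    return [s / denom for s in summed]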


# ============================================================
# Dataset (adds clip_path in each sample for debug prints)
# ============================================================

class AudioManifestDataset(Dataset):
    """
    Loads audio from clip_path under clip_root.
    Builds an attention mask:
    - vowel clips: trim trailing near-silence
    - other clips: use full clip
    """
    def __init__(self, df: pd.DataFrame, clip_root: Path):
        self.df = df.reset_index(drop=True)
        self.clip_root = clip_root
        for c in ["clip_path", "label_num", "task"]:
            if c not in self.df.columns:
                raise ValueError(f"Manifest missing column '{c}'")

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx: int):
        row = self.df.iloc[idx]
        rel = str(row["clip_path"])
        label = int(row["label_num"])
        task_group = str(row["task"])

        wav_path = self.clip_root / rel

        import soundfile as sf
        y, _sr = sf.read(str(wav_path))

        if isinstance(y, np.ndarray) and y.ndim > 1:
            y = y.mean(axis=1)

        y = np.asarray(y, dtype=np.float32)
        attn = np.zeros_like(y, dtype=np.int64)

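        if task_group == VOWEL_TASK_VALUE:
            # Keep everything up to the last sample louder than TINY_THRESH.
            # If no sample exceeds the threshold, fall back to the full clip.
            loud = np.flatnonzero(np.abs(y.astype(np.float64)) > TINY_THRESH)
            if loud.size == 0:
                attn[:] = 1
            else:
                attn[: int(loud[-1]) + 1] = 1
        else:
            attn[:] = 1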

        return {
            "clip_path": rel,
            "input_values": torch.from_numpy(y),
            "attention_mask": torch.from_numpy(attn),
            "labels": torch.tensor(label, dtype=torch.long),
            "task_group": task_group,
        }
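

# ------------------------------------------------------------
# Illustrative sketch (not called by the pipeline): the vowel
# attention-mask rule from AudioManifestDataset on a plain list of
# samples. `_sketch_vowel_attention_mask` is a hypothetical name used
# only here. The mask keeps every sample up to the last one whose
# magnitude exceeds `thresh`; if nothing exceeds it, the whole clip is
# kept, which matches the fallback in __getitem__.
# ------------------------------------------------------------

def _sketch_vowel_attention_mask(samples: list[float], thresh: float) -> list[int]:
    last_loud = -1
    for j in range(len(samples) - 1, -1, -1):
        if abs(samples[j]) > thresh:
            last_loud = j
            break
    if last_loud < 0:
        return [1] * len(samples)  # fully quiet clip: keep everything
    return [1] * (last_loud + 1) + [0] * (len(samples) - last_loud - 1)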

def collate_fn(batch):
    """
    Pads waveforms and masks to max length in the batch.
    Keeps clip_path for debug.
    """
    max_len = int(max(b["input_values"].numel() for b in batch))
    xs, ms, ys, tgs, cps = [], [], [], [], []
    for b in batch:
        x = b["input_values"]
        m = b["attention_mask"]
        pad = max_len - x.numel()
        if pad > 0:
            x = torch.cat([x, torch.zeros(pad, dtype=x.dtype)], dim=0)
            m = torch.cat([m, torch.zeros(pad, dtype=m.dtype)], dim=0)
        xs.append(x)
        ms.append(m)
        ys.append(b["labels"])
        tgs.append(b["task_group"])
        cps.append(b["clip_path"])
    return {
        "clip_path": cps,
        "input_values": torch.stack(xs, dim=0),
        "attention_mask": torch.stack(ms, dim=0),
        "labels": torch.stack(ys, dim=0),
        "task_group": tgs,
    }


# ============================================================
# Metrics
# ============================================================

def _eval_auc(model: nn.Module, loader: DataLoader) -> float:
    model.eval()
    ys, ps = [], []
    with torch.no_grad():
        for batch in loader:
            x = batch["input_values"].to(DEVICE)
            m = batch["attention_mask"].to(DEVICE)
            y = batch["labels"].to(DEVICE)
            tg = batch["task_group"]

            with torch.cuda.amp.autocast(enabled=USE_AMP):
                logits = model(x, m, tg)
                prob_pd = torch.softmax(logits, dim=1)[:, 1]

            ys.append(y.detach().cpu().numpy())
            ps.append(prob_pd.detach().cpu().numpy())

    y_all = np.concatenate(ys)
    p_all = np.concatenate(ps)
    if len(np.unique(y_all)) < 2:
        return float("nan")
    return float(roc_auc_score(y_all, p_all))
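

# ------------------------------------------------------------
# Illustrative sketch (not called by the pipeline): AUROC as the
# probability that a randomly chosen PD clip scores higher than a
# randomly chosen HC clip, with ties counted as one half. This
# pairwise form is O(n_pos * n_neg) and is meant only to show what
# the metric measures; the pipeline itself uses roc_auc_score.
# `_sketch_auroc` is a hypothetical name. Like _eval_auc, it returns
# NaN when only one class is present.
# ------------------------------------------------------------
import math

def _sketch_auroc(y_true: list[int], y_score: list[float]) -> float:
    pos = [s for y, s in zip(y_true, y_score) if y == 1]
    neg = [s for y, s in zip(y_true, y_score) if y == 0]
    if not pos or not neg:
        return math.nan
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(pos) * len(neg))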


# ============================================================
# Build speaker-balanced subsets from TrainEnh draws
# ============================================================

def _select_d2_rows(enh_train: pd.DataFrame, manifest_name: str) -> pd.DataFrame:
    """
    Select only D2 rows using source_dataset.
    """
    if "source_dataset" not in enh_train.columns:
        raise ValueError(
            f"{manifest_name} missing 'source_dataset'. "
            "D2 filtering requires this column."
        )

    src = enh_train["source_dataset"].astype(str).str.strip()
    has_d2 = (src == D2_SOURCE_KEY).any()
    if not has_d2:
        top = src.value_counts().head(10).to_dict()
        raise ValueError(
            f"{manifest_name}: no rows where source_dataset == '{D2_SOURCE_KEY}'. "
            f"Top source_dataset values: {top}"
        )
    return enh_train[src == D2_SOURCE_KEY].copy()

def _pick_balanced_speakers(
    df_d2: pd.DataFrame,
    frac: float,
    seed: int,
    manifest_name: str
) -> tuple[set[str], dict]:
    """
    Speaker-level sampling with PD/HC balance.

    Input:
    - df_d2: D2-only rows from one TrainEnh draw (one speaker may have multiple clips)
    - frac: fraction of speakers to include within PD and HC separately

    Output:
    - chosen_speakers: speaker_id set
    - meta: counts for logging
    """
    for c in ["speaker_id", "label_num"]:
        if c not in df_d2.columns:
            raise ValueError(f"{manifest_name}: missing '{c}' needed for speaker sampling")

    # Cast speaker_id to str before de-duplicating so mixed int/str IDs collapse to one row
    spk = df_d2[["speaker_id", "label_num"]].copy()
    spk["speaker_id"] = spk["speaker_id"].astype(str)
    spk = spk.drop_duplicates()

    spk_pd = spk[spk["label_num"].astype(int) == 1]["speaker_id"].tolist()
    spk_hc = spk[spk["label_num"].astype(int) == 0]["speaker_id"].tolist()

    rng = np.random.RandomState(int(seed))
    rng.shuffle(spk_pd)
    rng.shuffle(spk_hc)

    n_pd = len(spk_pd)
    n_hc = len(spk_hc)

    k_pd = int(np.floor(frac * n_pd))
    k_hc = int(np.floor(frac * n_hc))

    chosen = set(spk_pd[:k_pd]) | set(spk_hc[:k_hc])

    meta = {
        "manifest": manifest_name,
        "frac": float(frac),
        "pd_speakers_total": int(n_pd),
        "hc_speakers_total": int(n_hc),
        "pd_speakers_kept": int(k_pd),
        "hc_speakers_kept": int(k_hc),
        "speakers_kept_total": int(len(chosen)),
    }
    return chosen, meta
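

# ------------------------------------------------------------
# Illustrative sketch (not called by the pipeline): per-class speaker
# sampling with k = floor(frac * n) inside PD and HC separately.
# `_sketch_pick_speakers` and its dict input are hypothetical. It uses
# random.Random rather than np.random.RandomState, so the speakers it
# picks will NOT match the ones chosen by _pick_balanced_speakers;
# only the counting rule is the same.
# ------------------------------------------------------------
import math
import random

def _sketch_pick_speakers(labels_by_speaker: dict[str, int], frac: float, seed: int) -> set[str]:
    rng = random.Random(seed)
    chosen: set[str] = set()
    for label in (1, 0):  # PD first, then HC
        group = sorted(spk for spk, lab in labels_by_speaker.items() if lab == label)
        rng.shuffle(group)
        k = math.floor(frac * len(group))
        chosen.update(group[:k])
    return chosen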

def _subset_rows_by_speakers(df_d2: pd.DataFrame, speakers: set[str]) -> pd.DataFrame:
    """
    Keep all clips for included speakers.
    """
    s = df_d2["speaker_id"].astype(str)
    return df_d2[s.isin(speakers)].copy()

def _load_draw_d2_rows(mpath: Path) -> tuple[pd.DataFrame, str]:
    """
    Read one TrainEnh manifest and return its D2-only training rows.
    """
    enh = _read_csv(mpath)
    for c in ["clip_path", "split", "label_num", "task", "source_dataset", "speaker_id"]:
        if c not in enh.columns:
            raise ValueError(f"{mpath.name} missing '{c}'")

    enh["clip_path"] = enh["clip_path"].map(_norm_clip_path)
    split_name = _choose_train_split_name(enh)
    enh_train = enh[enh["split"].astype(str) == split_name].copy()
    if len(enh_train) == 0:
        raise ValueError(f"{mpath.name}: split '{split_name}' is empty")

    d2_rows = _select_d2_rows(enh_train, mpath.name)
    return d2_rows, split_name

def _build_draw_buckets(
    m_enh1: Path,
    m_enh2: Path,
    m_enh3: Path,
) -> dict:
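    """
    Pre-compute the D2 speaker subsets used by the finer-grained exposure conditions.

    Each draw contributes its full D2 training rows plus a speaker-level
    subset sized by DRAW_FRAC (sampled within PD and HC separately).

    Outputs:
    - dict with full and half subsets:
      A_full, A_half, B_full, B_half, C_full, C_half
    - "meta": split names, pick seed, and per-draw speaker counts
    """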
    d2_a, split_a = _load_draw_d2_rows(m_enh1)
    d2_b, split_b = _load_draw_d2_rows(m_enh2)
    d2_c, split_c = _load_draw_d2_rows(m_enh3)

    # Use a fixed seed for deterministic speaker picks (not tied to training seed)
    pick_seed = 424242

    spk_a_half, meta_a = _pick_balanced_speakers(d2_a, DRAW_FRAC["A_half"], pick_seed, m_enh1.name)
    spk_b_half, meta_b = _pick_balanced_speakers(d2_b, DRAW_FRAC["B_half"], pick_seed, m_enh2.name)
    spk_c_half, meta_c = _pick_balanced_speakers(d2_c, DRAW_FRAC["C_half"], pick_seed, m_enh3.name)

    buckets = {
        "A_full": d2_a.copy(),
        "A_half": _subset_rows_by_speakers(d2_a, spk_a_half),
        "B_full": d2_b.copy(),
        "B_half": _subset_rows_by_speakers(d2_b, spk_b_half),
        "C_full": d2_c.copy(),
        "C_half": _subset_rows_by_speakers(d2_c, spk_c_half),
        "meta": {
            "split_used": {"A": split_a, "B": split_b, "C": split_c},
            "half_pick_seed": int(pick_seed),
            "half_pick_details": [meta_a, meta_b, meta_c],
        },
    }

    # Quick sanity: draws should be non-overlapping speakers (given the current process)
    spkA = set(d2_a["speaker_id"].astype(str).unique().tolist())
    spkB = set(d2_b["speaker_id"].astype(str).unique().tolist())
    spkC = set(d2_c["speaker_id"].astype(str).unique().tolist())
    overlaps = {
        "A∩B": int(len(spkA & spkB)),
        "A∩C": int(len(spkA & spkC)),
        "B∩C": int(len(spkB & spkC)),
    }
    print("\n[Build] Draw speaker overlaps (expected ~0):", overlaps)

    return buckets


# ============================================================
# Build condition manifest (base + selected draw buckets)
# ============================================================

def _prepare_condition_manifest_from_rows(
    cond_name: str,
    pct: int,
    extra_rows: pd.DataFrame,
    sources_meta: list[dict],
) -> Path:
    """
    Writes:
      <cond>/manifests/manifest_all.csv   (train + val)
      <cond>/logs/condition_meta.json
    """
    dst = ABL_ROOT / cond_name
    dst_mans = dst / "manifests"
    dst_logs = dst / "logs"
    dst.mkdir(parents=True, exist_ok=True)
    dst_mans.mkdir(parents=True, exist_ok=True)
    dst_logs.mkdir(parents=True, exist_ok=True)

    print(f"\n[Build] {cond_name} ({pct}%)")
    print(f"[Build] Writing to: {dst}")

    base_all = _read_csv(D7_MAN_DIR / "manifest_all.csv")
    for c in ["clip_path", "split", "label_num", "task"]:
        if c not in base_all.columns:
            raise ValueError(f"Base manifest missing '{c}'")

    base_all["clip_path"] = base_all["clip_path"].map(_norm_clip_path)

    base_train = base_all[base_all["split"].astype(str) == "train"].copy()
    base_val   = base_all[base_all["split"].astype(str) == "val"].copy()
    if len(base_train) == 0:
        raise ValueError(f"[Build] {cond_name}: base train is empty")
    if len(base_val) == 0:
        raise ValueError(f"[Build] {cond_name}: base val is empty")

    base_train_paths = set(base_train["clip_path"].astype(str).tolist())

    extra = extra_rows.copy() if len(extra_rows) > 0 else base_train.iloc[0:0].copy()
    extra["clip_path"] = extra["clip_path"].map(_norm_clip_path)

    # Dedup against base, and within extra
    cp = extra["clip_path"].astype(str)
    extra = extra[~cp.isin(base_train_paths)].copy()
    extra = extra.drop_duplicates(subset=["clip_path"]).copy()

    train_df = pd.concat([base_train, extra], axis=0, ignore_index=True)
    train_df["split"] = "train"
    val_df = base_val.copy()
    val_df["split"] = "val"
    out_all = pd.concat([train_df, val_df], axis=0, ignore_index=True)

    _write_csv(out_all, dst_mans / "manifest_all.csv")

    added_raw = int(len(extra_rows))
    added_dedup = int(len(extra))

    spk_added = _safe_unique_count(extra, "speaker_id")
    print(f"[Counts] Base train={len(base_train)} | Added D2 (raw)={added_raw} | Added D2 (dedup)={added_dedup}")
    print(f"[Counts] Final train={len(train_df)} | Final val={len(val_df)}")
    if spk_added is not None:
        print(f"[Counts] Added D2 speakers (if available)={spk_added}")

    meta = {
        "cond_name": cond_name,
        "pct_target_speakers": pct,
        "base_train_rows": int(len(base_train)),
        "base_val_rows": int(len(base_val)),
        "added_d2_raw_total": int(added_raw),
        "added_d2_dedup_total": int(added_dedup),
        "train_rows_written": int(len(train_df)),
        "val_rows_written": int(len(val_df)),
        "added_d2_unique_speakers_if_available": spk_added,
        "sources": sources_meta,
        "created_at": time.strftime("%Y-%m-%d %H:%M:%S"),
    }
    with open(dst_logs / "condition_meta.json", "w") as f:
        json.dump(meta, f, indent=2)

    return dst


# ============================================================
# Train + validation (runs saved on Drive, clips read from clip_root)
# ============================================================

def _run_trainval(manifest_path: Path, clip_root: Path, out_root: Path, tag: str) -> Path:
    run_stamp = _new_run_stamp()
    exp_root = out_root / "trainval_runs" / f"exp_{tag}_{run_stamp}"
    exp_root.mkdir(parents=True, exist_ok=True)

    print(f"\n[TrainVal] {tag}")
    print(f"[TrainVal] Writing to: {exp_root}")

    manifest = _read_csv(manifest_path)
    manifest["clip_path"] = manifest["clip_path"].map(_norm_clip_path)

    df_train = manifest[manifest["split"].astype(str) == "train"].copy()
    df_val   = manifest[manifest["split"].astype(str) == "val"].copy()

    if len(df_train) == 0:
        raise ValueError(f"[TrainVal] {tag}: train split is empty")
    if len(df_val) == 0:
        raise ValueError(f"[TrainVal] {tag}: val split is empty")

    print(f"[TrainVal] train clips={len(df_train)} | val clips={len(df_val)}")
    _scan_exists(clip_root, df_train, f"{tag} train")
    _scan_exists(clip_root, df_val,   f"{tag} val")
    _scan_soundfile_info(clip_root, df_train.head(500), f"{tag} train (first 500)")
    _scan_soundfile_info(clip_root, df_val, f"{tag} val")

    train_ds = AudioManifestDataset(df_train, clip_root)
    val_ds   = AudioManifestDataset(df_val, clip_root)

    train_loader = DataLoader(train_ds, batch_size=PER_DEVICE_BS, shuffle=True,  num_workers=0, collate_fn=collate_fn)
    val_loader   = DataLoader(val_ds,   batch_size=PER_DEVICE_BS, shuffle=False, num_workers=0, collate_fn=collate_fn)

    for seed in SEEDS:
        _set_seed(seed)
        seed_dir = exp_root / f"run_D7_seed{seed}"
        seed_dir.mkdir(parents=True, exist_ok=True)

        model = FrozenW2VTwoHead(BACKBONE_CKPT, DROPOUT_P).to(DEVICE)
        opt = torch.optim.AdamW([p for p in model.parameters() if p.requires_grad], lr=LR)
        # Fresh loss scaler per seed so AMP scale state does not carry over between seeds
        scaler = torch.cuda.amp.GradScaler(enabled=USE_AMP)

        print(f"[TrainVal] Seed {seed}: warmup fetch...")
        t_fetch = time.time()
        it = iter(train_loader)
        batch0 = next(it)
        print(f"[TrainVal] Seed {seed}: warmup fetch ok | {time.time()-t_fetch:.2f}s | sample: {batch0['clip_path'][0]}")

        best_auc = -1.0
        best_epoch = -1
        best_state = None
        bad = 0
        history = []

        pbar = tqdm(range(1, MAX_EPOCHS + 1), desc=f"Seed {seed} epochs", leave=False)
        for epoch in pbar:
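            # Note: train() also puts the frozen backbone in training mode, so its
            # internal dropout stays active (and SpecAugment masking, if the
            # checkpoint config enables it). Only the head parameters are updated.
            model.train()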
            opt.zero_grad(set_to_none=True)

            total_loss = 0.0
            steps = 0
            last_i = 0

            loader_iter = iter(train_loader)

            for i in range(1, len(train_loader) + 1):
                last_i = i

                t0_fetch = time.time()
                batch = next(loader_iter)
                dt_fetch = time.time() - t0_fetch
                if dt_fetch > FETCH_WATCHDOG_SEC:
                    print(f"\n[Hang] Slow batch fetch: epoch={epoch} i={i} fetch_sec={dt_fetch:.1f}")
                    print("[Hang] Sample clip paths:")
                    for cp in batch["clip_path"][:8]:
                        print("  -", cp)

                t0_step = time.time()

                x = batch["input_values"].to(DEVICE)
                m = batch["attention_mask"].to(DEVICE)
                y = batch["labels"].to(DEVICE)
                tg = batch["task_group"]

                with torch.cuda.amp.autocast(enabled=USE_AMP):
                    logits = model(x, m, tg)
                    loss = model.loss_fn(logits, y)

                scaler.scale(loss / GRAD_ACCUM).backward()

                if i % GRAD_ACCUM == 0:
                    scaler.step(opt)
                    scaler.update()
                    opt.zero_grad(set_to_none=True)

                total_loss += float(loss.item())
                steps += 1

                dt_step = time.time() - t0_step
                if dt_step > STEP_WATCHDOG_SEC:
                    print(f"\n[Hang] Slow step: epoch={epoch} i={i} step_sec={dt_step:.1f}")
                    print("[Hang] Sample clip paths:")
                    for cp in batch["clip_path"][:8]:
                        print("  -", cp)

                if i % PRINT_EVERY_N_BATCHES == 0:
                    print(f"[TrainVal] epoch={epoch} i={i}/{len(train_loader)} loss={float(loss.item()):.4f} fetch={dt_fetch:.2f}s step={dt_step:.2f}s")

            if last_i % GRAD_ACCUM != 0:
                scaler.step(opt)
                scaler.update()
                opt.zero_grad(set_to_none=True)

            val_auc = _eval_auc(model, val_loader)
            history.append({"epoch": epoch, "train_loss": total_loss / max(1, steps), "val_auc": val_auc})
            pbar.set_postfix({"val_auc": f"{val_auc:.4f}"})

            if val_auc > best_auc + 1e-6:
                best_auc = val_auc
                best_epoch = epoch
                bad = 0
                best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
            else:
                bad += 1
                if bad >= PATIENCE:
                    break

        if best_state is None:
            raise RuntimeError("No best_state captured; training failed.")

        torch.save(best_state, seed_dir / "best_heads.pt")
        with open(seed_dir / "history.json", "w") as f:
            json.dump(history, f, indent=2)
        with open(seed_dir / "summary_trainval_seed.json", "w") as f:
            json.dump({"seed": seed, "best_val_auc": best_auc, "best_epoch": best_epoch}, f, indent=2)

    with open(exp_root / "summary_trainval.json", "w") as f:
        json.dump({"tag": tag, "seeds": SEEDS}, f, indent=2)

    return exp_root
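

# ------------------------------------------------------------
# Illustrative sketch (not called by the pipeline): which batch
# indices trigger an optimizer step under the gradient-accumulation
# scheme in _run_trainval. A step fires every `grad_accum` batches,
# plus one final flush when the epoch length is not a multiple of
# `grad_accum`. `_sketch_optimizer_step_batches` is a hypothetical
# name. Note that in the training loop every loss is divided by
# GRAD_ACCUM, so the gradients of a shorter final group are scaled
# down relative to a full group.
# ------------------------------------------------------------

def _sketch_optimizer_step_batches(n_batches: int, grad_accum: int) -> list[int]:
    if n_batches < 1 or grad_accum < 1:
        raise ValueError("n_batches and grad_accum must be >= 1")
    steps = [i for i in range(1, n_batches + 1) if i % grad_accum == 0]
    if n_batches % grad_accum != 0:
        steps.append(n_batches)  # trailing flush for the partial group
    return steps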


# ============================================================
# D2 test (reads from D2, writes into condition folder)
# ============================================================

def _run_d2_test(exp_root: Path, tag: str, out_root: Path) -> Path:
    run_stamp = _new_run_stamp()
    run_root = out_root / "multilingual_test_runs" / f"{tag}_{run_stamp}"
    run_root.mkdir(parents=True, exist_ok=True)

    d2_manifest_path = D2_ROOT / "manifests" / "manifest_all.csv"
    _ensure_exists(d2_manifest_path, "D2 manifests/manifest_all.csv")

    d2_all = _read_csv(d2_manifest_path)
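    # Validate required columns before filtering so a missing column gives a clear error
    for c in ["clip_path", "split", "label_num", "task"]:
        if c not in d2_all.columns:
            raise ValueError(f"D2 manifest missing required column '{c}'")

    d2_all["clip_path"] = d2_all["clip_path"].map(_norm_clip_path)

    df_test = d2_all[d2_all["split"].astype(str) == "test"].copy()
    if len(df_test) == 0:
        raise ValueError(f"[D2 Test] {tag}: D2 manifest has no rows with split == 'test'")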

    d2_ds = AudioManifestDataset(df_test, D2_ROOT)
    d2_loader = DataLoader(d2_ds, batch_size=PER_DEVICE_BS, shuffle=False, num_workers=0, collate_fn=collate_fn)

    print(f"[D2 Test] {tag} | D2 test clips={len(d2_ds)}")
    print(f"[D2 Test] Writing to: {run_root}")

    aurocs = []

    for seed in SEEDS:
        seed_dir = run_root / f"run_D7_on_D2test_seed{seed}"
        seed_dir.mkdir(parents=True, exist_ok=True)

        best_heads = exp_root / f"run_D7_seed{seed}" / "best_heads.pt"
        _ensure_exists(best_heads, f"best_heads.pt for seed {seed}")

        model = FrozenW2VTwoHead(BACKBONE_CKPT, DROPOUT_P).to(DEVICE)
        state = torch.load(best_heads, map_location="cpu")
        model.load_state_dict(state, strict=True)
        model.eval()

        ys, ps = [], []
        with torch.no_grad():
            for batch in tqdm(d2_loader, desc=f"Infer seed {seed}", leave=False):
                x = batch["input_values"].to(DEVICE)
                m = batch["attention_mask"].to(DEVICE)
                y = batch["labels"].to(DEVICE)
                tg = batch["task_group"]

                with torch.cuda.amp.autocast(enabled=USE_AMP):
                    logits = model(x, m, tg)
                    prob_pd = torch.softmax(logits, dim=1)[:, 1]

                ys.append(y.detach().cpu().numpy())
                ps.append(prob_pd.detach().cpu().numpy())

        y_all = np.concatenate(ys)
        p_all = np.concatenate(ps)

        auc = float("nan") if len(np.unique(y_all)) < 2 else float(roc_auc_score(y_all, p_all))
        aurocs.append(auc)

        with open(seed_dir / "metrics.json", "w") as f:
            json.dump({"seed": seed, "d2_test_auroc": auc}, f, indent=2)

        pd.DataFrame({"y_true": y_all.astype(int), "y_score": p_all.astype(float)}).to_csv(
            seed_dir / "predictions.csv", index=False
        )

    with open(run_root / "summary_test.json", "w") as f:
        json.dump({"tag": tag, "seeds": SEEDS, "aurocs": aurocs}, f, indent=2)

    return run_root


# ============================================================
# Main run (finer points, then summary plot)
# ============================================================

_ensure_exists(D7_MAN_DIR / "manifest_all.csv", "D7 base manifest_all.csv")
_ensure_exists(D2_ROOT / "manifests" / "manifest_all.csv", "D2 manifest_all.csv")

M_ENH1 = D7_MAN_DIR / "manifest_train_enh1.csv"
M_ENH2 = D7_MAN_DIR / "manifest_train_enh2.csv"
M_ENH3 = D7_MAN_DIR / "manifest_train_enh3.csv"
_ensure_exists(M_ENH1, "D7 manifest_train_enh1.csv")
_ensure_exists(M_ENH2, "D7 manifest_train_enh2.csv")
_ensure_exists(M_ENH3, "D7 manifest_train_enh3.csv")

_print_split_counts(D2_ROOT / "manifests" / "manifest_all.csv", "D2 splits")

print("\n[Build] Pre-computing speaker-balanced half draws...")
draws = _build_draw_buckets(M_ENH1, M_ENH2, M_ENH3)

# Conditions:
# 0%: base
# 5%: base + A_half
# 10%: base + A_full
# 15%: base + A_full + B_half
# 20%: base + A_full + B_full
# 25%: base + A_full + B_full + C_half
# 29%: base + A_full + B_full + C_full
conditions = [
    ("base_0pct",    0,  []),
    ("enhA_5pct",    5,  ["A_half"]),
    ("enhA_10pct",   10, ["A_full"]),
    ("enhAB_15pct",  15, ["A_full", "B_half"]),
    ("enhAB_20pct",  20, ["A_full", "B_full"]),
    ("enhABC_25pct", 25, ["A_full", "B_full", "C_half"]),
    ("enhABC_29pct", 29, ["A_full", "B_full", "C_full"]),
]

results = []

for cond_name, pct, parts in conditions:
    # 1) Extra rows for this condition
    extra_rows_list = []
    sources_meta = []

    for key in parts:
        df_part = draws[key].copy()
        extra_rows_list.append(df_part)

        spk_ct = _safe_unique_count(df_part, "speaker_id")
        sources_meta.append({
            "source_part": key,
            "d2_rows": int(len(df_part)),
            "d2_speakers_if_available": spk_ct,
        })

    # Add draw metadata once for traceability
    if len(parts) > 0:
        sources_meta.append({"draw_meta": draws["meta"]})

    extra_rows = (
        pd.concat(extra_rows_list, axis=0, ignore_index=True)
        if len(extra_rows_list) > 0
        else pd.DataFrame()
    )

    # 2) Build condition manifest on Drive
    dx_drive = _prepare_condition_manifest_from_rows(
        cond_name=cond_name,
        pct=pct,
        extra_rows=extra_rows,
        sources_meta=sources_meta,
    )

    manifest_drive_path = dx_drive / "manifests" / "manifest_all.csv"
    manifest_drive = _read_csv(manifest_drive_path)
    manifest_drive["clip_path"] = manifest_drive["clip_path"].map(_norm_clip_path)

    df_train = manifest_drive[manifest_drive["split"].astype(str) == "train"].copy()
    df_val   = manifest_drive[manifest_drive["split"].astype(str) == "val"].copy()
    print(f"[Counts] {cond_name}: train={len(df_train)} | val={len(df_val)}")

    # 3) Choose where clips are read from during training
    clip_root = D7_CLIP_ROOT
    manifest_for_train = manifest_drive_path

    if STAGE_TO_LOCAL:
        local_root = _stage_manifest_clips_to_local(cond_name, D7_CLIP_ROOT, manifest_drive)
        clip_root = local_root
        manifest_for_train = local_root / "manifests" / "manifest_all.csv"

        _scan_exists(clip_root, df_train, f"{cond_name} train (local)")
        _scan_exists(clip_root, df_val,   f"{cond_name} val (local)")

    tag = f"target_ablation_finer_{cond_name}"

    # 4) Train (writes outputs to Drive condition folder)
    exp_root = _run_trainval(
        manifest_path=manifest_for_train,
        clip_root=clip_root,
        out_root=dx_drive,
        tag=tag
    )

    # 5) Test on D2 (writes outputs to Drive condition folder)
    test_run = _run_d2_test(exp_root=exp_root, tag=tag, out_root=dx_drive)

    # 6) Collect AUROCs
    aurocs = []
    for seed in SEEDS:
        mpath = test_run / f"run_D7_on_D2test_seed{seed}" / "metrics.json"
        _ensure_exists(mpath, f"metrics.json for {cond_name} seed {seed}")
        with open(mpath, "r") as f:
            aurocs.append(json.load(f)["d2_test_auroc"])

    mean_auc = float(np.nanmean(aurocs))
    sd_auc   = float(np.nanstd(aurocs, ddof=1)) if np.sum(~np.isnan(aurocs)) >= 2 else float("nan")

    results.append({
        "condition": cond_name,
        "pct_target_speakers": pct,
        "aurocs_by_seed": aurocs,
        "auroc_mean": mean_auc,
        "auroc_sd": sd_auc,
        "drive_condition_root": str(dx_drive),
        "clip_root_used_for_train": str(clip_root),
    })
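

# ------------------------------------------------------------
# Illustrative sketch (not called by the pipeline): the per-condition
# aggregation above in plain Python. NaN AUROCs are dropped, the mean
# is taken over the remaining seeds, and the sample SD (n - 1
# denominator, matching ddof=1) is reported only when at least two
# finite values remain. `_sketch_mean_sd` is a hypothetical name.
# ------------------------------------------------------------
import math
import statistics

def _sketch_mean_sd(values: list[float]) -> tuple[float, float]:
    finite = [v for v in values if not math.isnan(v)]
    mean = statistics.fmean(finite) if finite else math.nan
    sd = statistics.stdev(finite) if len(finite) >= 2 else math.nan
    return mean, sd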


# ============================================================
# Summary + curve plot
# ============================================================

summary_df = pd.DataFrame(results).sort_values("pct_target_speakers").reset_index(drop=True)
summary_csv = ABL_ROOT / "target_data_ablation_finer_summary.csv"
summary_df.to_csv(summary_csv, index=False)

print("\n[Done] Summary CSV:", summary_csv)
display(summary_df)

x = summary_df["pct_target_speakers"].to_numpy(dtype=float)
y = summary_df["auroc_mean"].to_numpy(dtype=float)
yerr = summary_df["auroc_sd"].to_numpy(dtype=float)

coeffs = np.polyfit(x, y, deg=2)
poly_fn = np.poly1d(coeffs)

x_fit = np.linspace(float(np.min(x)), float(np.max(x)), 200)
y_fit = poly_fn(x_fit)

plt.figure()
plt.errorbar(x, y, yerr=yerr, fmt="o", capsize=4, label=f"Mean ± SD ({len(SEEDS)} seeds)")
plt.plot(x_fit, y_fit, "--", linewidth=2, color="orange", label="Quadratic fit (trend)")

plt.xticks(x, [f"{int(v)}%" for v in x])
plt.xlabel("Percent of D2 train speakers included in training")
plt.ylabel("D2 test AUROC")
plt.title("Target-Data Ablation: Finer Cumulative D2 Exposure vs D2-Test AUROC")
plt.grid(True, alpha=0.3)
plt.legend()

plot_path = ABL_ROOT / "target_data_ablation_finer_curve_quadratic.png"
plt.savefig(plot_path, dpi=200, bbox_inches="tight")
plt.show()
print("[Done] Plot saved:", plot_path)


# ============================================================
# End
# ============================================================

elapsed_hr = (time.time() - START_TIME) / 3600.0
print(f"\nTotal wall time: {elapsed_hr:.2f} hours")

print("\nAll done. Unassigning the runtime to stop the L4 instance...")
try:
    from google.colab import runtime  # type: ignore
    print("Calling runtime.unassign() now.")
    runtime.unassign()
except Exception as e:
    print("Could not unassign runtime automatically. The runtime can be stopped manually in Colab.")
    print("Reason:", repr(e))