In [1]:
"""
SNIPPET S1: Data Audit & Subject/Visit Tables (REVISION 3 - Nearest-Day Matching)
Builds canonical visits_table.csv and subjects_table.csv from OASIS-2 + OASIS-3
"""

import os
import glob
import pandas as pd
import numpy as np
from pathlib import Path
from collections import Counter
import warnings
warnings.filterwarnings('ignore')

# ============================================================================
# HELPER FUNCTIONS
# ============================================================================

def standardize_oasis2_subject(raw_id):
    """Ensure OASIS-2 subject ID is in format OAS2_XXXX"""
    if pd.isna(raw_id):
        return None
    raw_id = str(raw_id).strip()
    if raw_id.startswith("OAS2_"):
        return raw_id
    elif raw_id.isdigit():
        return f"OAS2_{raw_id.zfill(4)}"
    else:
        parts = raw_id.replace("OAS", "").replace("_", "").strip()
        if parts.isdigit():
            return f"OAS2_{parts.zfill(4)}"
    return raw_id

def choose_canonical_t1(candidates):
    """Select preferred T1w file from multiple candidates (prefer run-01)"""
    if len(candidates) == 1:
        return candidates[0]
    
    # Prefer run-01
    for c in candidates:
        if 'run-01' in c or 'run-1' in c:
            return c
    
    # Otherwise take first alphabetically
    return sorted(candidates)[0]

def assign_oasis3_visit_index(df):
    """
    Assign chronological visit_index (1, 2, 3...) based on days_to_visit.
    For subjects with multiple visits, sort by days and number sequentially.
    """
    df = df.sort_values(["subject_id", "days_to_visit"])
    df["visit_index"] = df.groupby("subject_id").cumcount() + 1
    return df

# ============================================================================
# STEP 1: SCAN OASIS-2 MRI DIRECTORIES
# ============================================================================

def scan_oasis2_mri(root_part1, root_part2):
    """Scan OASIS-2 Part 1 and Part 2 for MRI visit directories"""
    print("\n" + "="*70)
    print("STEP 1: Scanning OASIS-2 MRI directories")
    print("="*70)
    
    records = []
    
    for root_label, root_path in [("PART1", root_part1), ("PART2", root_part2)]:
        print(f"\nScanning {root_label}: {root_path}")
        
        if not os.path.exists(root_path):
            print(f"  ‚ö†Ô∏è  WARNING: Path does not exist, skipping")
            continue
        
        subdirs = [d for d in os.listdir(root_path) 
                   if os.path.isdir(os.path.join(root_path, d))]
        
        count_valid = 0
        count_missing = 0
        
        for dir_name in subdirs:
            if not dir_name.startswith("OAS2_"):
                continue
            
            # Parse: OAS2_0001_MR1 -> subject="OAS2_0001", visit="MR1", index=1
            parts = dir_name.split("_")
            if len(parts) != 3:
                print(f"  ‚ö†Ô∏è  Unexpected format: {dir_name}")
                continue
            
            subject_id = f"{parts[0]}_{parts[1]}"  # "OAS2_0001"
            visit_id = parts[2]                     # "MR1"
            
            # Extract visit index
            try:
                visit_index = int(visit_id.replace("MR", ""))
            except ValueError:
                print(f"  ‚ö†Ô∏è  Cannot parse visit index: {dir_name}")
                continue
            
            # Check for MRI file
            visit_dir = os.path.join(root_path, dir_name, "RAW")
            mri_path = os.path.join(visit_dir, "mpr-1.nifti.hdr")
            
            if not os.path.exists(mri_path):
                alt_path = os.path.join(visit_dir, "mpr-1.hdr")
                if os.path.exists(alt_path):
                    mri_path = alt_path
                else:
                    count_missing += 1
                    continue
            
            records.append({
                "dataset": "OASIS2",
                "domain_id": 0,
                "subject_id": subject_id,
                "visit_id": visit_id,
                "visit_index": visit_index,
                "mri_path": mri_path,
            })
            count_valid += 1
        
        print(f"  ‚úì Found {count_valid} valid MRI sessions")
        if count_missing > 0:
            print(f"  ‚ö†Ô∏è  Skipped {count_missing} sessions (missing MRI file)")
    
    df = pd.DataFrame(records)
    print(f"\nüìä OASIS-2 Total: {len(df)} MRI sessions from {df['subject_id'].nunique()} subjects")
    return df


# ============================================================================
# STEP 2: SCAN OASIS-3 MRI DIRECTORIES (FIXED - Extract mri_days)
# ============================================================================

def scan_oasis3_mri(root_o3):
    """
    Scan OASIS-3 for T1w NIfTI files
    FIX: Extract numeric mri_days from directory name for temporal matching
    """
    print("\n" + "="*70)
    print("STEP 2: Scanning OASIS-3 MRI directories")
    print("="*70)
    print(f"Root: {root_o3}")
    
    if not os.path.exists(root_o3):
        print(f"  ‚ö†Ô∏è  WARNING: Path does not exist")
        return pd.DataFrame()
    
    records = []
    subdirs = [d for d in os.listdir(root_o3) 
               if os.path.isdir(os.path.join(root_o3, d))]
    
    count_valid = 0
    count_no_t1 = 0
    count_bad_days = 0
    
    for dir_name in subdirs:
        if not dir_name.startswith("OAS3"):
            continue
        
        if "_MR_" not in dir_name:
            continue
        
        # Parse: "OAS30006_MR_d2341" -> subject="OAS30006", visit="d2341", days=2341
        mri_session_label = dir_name  # Keep full label for reference
        
        subject_part, _, visit_part = dir_name.partition("_MR_")
        subject_id = subject_part  # "OAS30006"
        visit_id = visit_part      # "d2341"
        
        # FIX: Extract numeric days from visit_id
        try:
            mri_days = int(visit_id.replace("d", ""))  # 2341
        except ValueError:
            count_bad_days += 1
            mri_days = np.nan
        
        # Search for T1w NIfTI files
        session_dir = os.path.join(root_o3, dir_name)
        t1_candidates = []
        
        for anat_dir in glob.glob(os.path.join(session_dir, "anat*", "NIFTI")):
            t1_files = glob.glob(os.path.join(anat_dir, "**", "*T1w.nii"), recursive=True)
            t1_candidates.extend(t1_files)
        
        if len(t1_candidates) == 0:
            count_no_t1 += 1
            continue
        
        mri_path = choose_canonical_t1(t1_candidates)
        
        records.append({
            "dataset": "OASIS3",
            "domain_id": 1,
            "subject_id": subject_id,
            "mri_session_label": mri_session_label,  # Full MRI directory name
            "visit_id": visit_id,                    # "d2341"
            "mri_days": mri_days,                    # ‚Üê NEW: numeric days for matching
            "visit_index": None,                     # Will assign after merge
            "mri_path": mri_path,
        })
        count_valid += 1
    
    df = pd.DataFrame(records)
    print(f"  ‚úì Found {count_valid} valid T1w sessions")
    if count_no_t1 > 0:
        print(f"  ‚ö†Ô∏è  Skipped {count_no_t1} sessions (no T1w file)")
    if count_bad_days > 0:
        print(f"  ‚ö†Ô∏è  Skipped {count_bad_days} sessions (invalid day format)")
    
    print(f"\nüìä OASIS-3 Total: {len(df)} MRI sessions from {df['subject_id'].nunique()} subjects")
    
    # DEBUG: Show sample data
    if len(df) > 0:
        print(f"\nüîç Sample MRI sessions (first 5):")
        for _, row in df.head(5).iterrows():
            print(f"   {row['mri_session_label']} ‚Üí subject={row['subject_id']}, mri_days={row['mri_days']}")
    
    return df


# ============================================================================
# STEP 3: LOAD CLINICAL CSVs
# ============================================================================

def load_oasis2_clinical(csv_path):
    """Load and standardize OASIS-2 clinical CSV"""
    print("\n" + "="*70)
    print("STEP 3a: Loading OASIS-2 Clinical CSV")
    print("="*70)
    print(f"Path: {csv_path}")
    
    if not os.path.exists(csv_path):
        raise FileNotFoundError(f"OASIS-2 clinical CSV not found: {csv_path}")
    
    df = pd.read_csv(csv_path)
    print(f"‚úì Loaded {len(df)} rows, {len(df.columns)} columns")
    
    # Auto-detect columns
    col_map = {}
    for col in df.columns:
        col_lower = col.lower().strip()
        if 'subject' in col_lower or col_lower == 'id':
            col_map['subject_col'] = col
        elif 'visit' in col_lower and 'mri' not in col_lower:
            col_map['visit_col'] = col
        elif col_lower == 'cdr':
            col_map['cdr_col'] = col
        elif 'mmse' in col_lower:
            col_map['mmse_col'] = col
        elif col_lower == 'age':
            col_map['age_col'] = col
        elif 'm/f' in col_lower or col_lower == 'sex':
            col_map['sex_col'] = col
    
    # Validate
    required = ['subject_col', 'visit_col', 'cdr_col', 'mmse_col']
    missing = [k for k in required if k not in col_map]
    if missing:
        raise ValueError(f"Missing required columns in OASIS-2 CSV: {missing}")
    
    print(f"\n‚úì Column mapping:")
    for k, v in col_map.items():
        print(f"  {k:15} -> '{v}'")
    
    # Standardize
    df['subject_id'] = df[col_map['subject_col']].apply(standardize_oasis2_subject)
    df['visit_index'] = pd.to_numeric(df[col_map['visit_col']], errors='coerce')
    df['CDR'] = pd.to_numeric(df[col_map['cdr_col']], errors='coerce')
    df['MMSE'] = pd.to_numeric(df[col_map['mmse_col']], errors='coerce')
    
    if 'age_col' in col_map:
        df['Age'] = pd.to_numeric(df[col_map['age_col']], errors='coerce')
    if 'sex_col' in col_map:
        df['Sex'] = df[col_map['sex_col']].astype(str).str.strip()
    
    # Select columns
    keep_cols = ['subject_id', 'visit_index', 'CDR', 'MMSE']
    if 'Age' in df.columns:
        keep_cols.append('Age')
    if 'Sex' in df.columns:
        keep_cols.append('Sex')
    
    df_clean = df[keep_cols].copy()
    print(f"\n‚úì Standardized {len(df_clean)} clinical records")
    
    return df_clean


def load_oasis3_clinical(csv_path):
    """
    Load and standardize OASIS-3 clinical CSV
    FIXED: Keep clinical session label separate, use days_to_visit for matching
    """
    print("\n" + "="*70)
    print("STEP 3b: Loading OASIS-3 Clinical CSV")
    print("="*70)
    print(f"Path: {csv_path}")
    
    if not os.path.exists(csv_path):
        raise FileNotFoundError(f"OASIS-3 clinical CSV not found: {csv_path}")
    
    df = pd.read_csv(csv_path)
    print(f"‚úì Loaded {len(df)} rows, {len(df.columns)} columns")
    
    # Hard-coded column mapping for OASIS-3
    subject_col = "OASISID"
    session_col = "OASIS_session_label"
    cdr_col = "CDRTOT"
    mmse_col = "MMSE"
    age_col = "age at visit"
    days_col = "days_to_visit"
    
    # Validate
    required_cols = [subject_col, session_col, cdr_col, mmse_col, age_col, days_col]
    missing = [c for c in required_cols if c not in df.columns]
    if missing:
        raise ValueError(f"Missing required columns in OASIS-3 CSV: {missing}")
    
    print(f"\n‚úì Column mapping (hard-coded for OASIS-3):")
    print(f"  subject_col     -> '{subject_col}'")
    print(f"  session_col     -> '{session_col}'")
    print(f"  age_col         -> '{age_col}'")
    print(f"  mmse_col        -> '{mmse_col}'")
    print(f"  cdr_col         -> '{cdr_col}'")
    print(f"  days_col        -> '{days_col}'")
    
    # Standardize IDs
    df['subject_id'] = df[subject_col].astype(str).str.strip()
    
    # Keep clinical session label (different from MRI label!)
    df['clin_session_label'] = df[session_col].astype(str).str.strip()
    
    # Rename columns
    df.rename(columns={
        cdr_col: "CDR",
        mmse_col: "MMSE",
        age_col: "Age",
        days_col: "days_to_visit",
    }, inplace=True)
    
    # Convert to numeric
    df['CDR'] = pd.to_numeric(df['CDR'], errors='coerce')
    df['MMSE'] = pd.to_numeric(df['MMSE'], errors='coerce')
    df['Age'] = pd.to_numeric(df['Age'], errors='coerce')
    df['days_to_visit'] = pd.to_numeric(df['days_to_visit'], errors='coerce')
    
    # Select columns
    keep_cols = ['subject_id', 'clin_session_label', 'CDR', 'MMSE', 'Age', 'days_to_visit']
    df_clean = df[keep_cols].copy()
    
    print(f"\n‚úì Standardized {len(df_clean)} OASIS-3 clinical records")
    
    # Sanity check
    print("\nüìä OASIS-3 CDR/MMSE Sanity Check:")
    print(df_clean[['CDR', 'MMSE']].describe())
    print(f"\nCDR unique values: {sorted(df_clean['CDR'].dropna().unique())}")
    
    # DEBUG: Show sample data
    print(f"\nüîç Sample clinical records (first 5):")
    for _, row in df_clean.head(5).iterrows():
        print(f"   subject={row['subject_id']}, days={row['days_to_visit']}, CDR={row['CDR']}, label={row['clin_session_label']}")
    
    return df_clean


# ============================================================================
# STEP 4: MERGE MRI + CLINICAL
# ============================================================================

def merge_oasis2(mri_df, clinical_df):
    """Merge OASIS-2 MRI visits with clinical data"""
    print("\n" + "="*70)
    print("STEP 4a: Merging OASIS-2 MRI + Clinical")
    print("="*70)
    
    print(f"MRI visits before merge: {len(mri_df)}")
    print(f"Clinical records: {len(clinical_df)}")
    
    merged = mri_df.merge(
        clinical_df,
        on=['subject_id', 'visit_index'],
        how='inner'
    )
    
    print(f"‚úì After merge: {len(merged)} visits")
    
    # Filter: require valid CDR and MMSE
    before_filter = len(merged)
    merged = merged.dropna(subset=['CDR', 'MMSE'])
    print(f"‚úì After CDR/MMSE filter: {len(merged)} visits (dropped {before_filter - len(merged)})")
    
    # Assign labels
    def assign_label(cdr):
        if cdr == 0:
            return 0   # CN
        elif cdr >= 1:
            return 1   # AD
        else:
            return -1  # Exclude
    
    merged['label'] = merged['CDR'].apply(assign_label)
    
    # Keep only CN/AD
    before_label = len(merged)
    merged = merged[merged['label'].isin([0, 1])]
    print(f"‚úì After label filter (CN/AD only): {len(merged)} visits (dropped {before_label - len(merged)})")
    
    return merged


def merge_oasis3(mri_df, clinical_df, max_day_diff=365):
    """
    Merge OASIS-3 MRI visits with clinical data using nearest-neighbor temporal matching
    
    FIXED: Match by (subject_id + nearest days_to_visit) instead of exact session label
    
    Args:
        mri_df: MRI sessions with mri_days
        clinical_df: Clinical assessments with days_to_visit
        max_day_diff: Maximum allowed day difference (default 365 days = 1 year)
    """
    print("\n" + "="*70)
    print("STEP 4b: Merging OASIS-3 MRI + Clinical (nearest-day matching)")
    print("="*70)
    print(f"MRI visits before merge: {len(mri_df)}")
    print(f"Clinical records: {len(clinical_df)}")
    
    # Drop MRI sessions without valid mri_days
    mri_valid = mri_df.dropna(subset=["mri_days"]).copy()
    mri_valid["mri_days"] = mri_valid["mri_days"].astype(float)
    print(f"‚úì MRI sessions with valid days: {len(mri_valid)}")
    
    # Drop clinical records without valid days_to_visit
    clinical_valid = clinical_df.dropna(subset=["days_to_visit"]).copy()
    clinical_valid["days_to_visit"] = clinical_valid["days_to_visit"].astype(float)
    print(f"‚úì Clinical records with valid days: {len(clinical_valid)}")
    
    # Cartesian join on subject_id (all combinations within each subject)
    merged = mri_valid.merge(
        clinical_valid,
        on="subject_id",
        how="inner",
        suffixes=("_mri", "_clin"),
    )
    
    print(f"‚úì Total (subject, MRI, clinical) combinations: {len(merged)}")
    
    if len(merged) == 0:
        print("‚ùå No subject overlap between MRI and clinical for OASIS-3")
        return pd.DataFrame()
    
    # Compute absolute day difference
    merged["day_diff"] = (merged["mri_days"] - merged["days_to_visit"]).abs()
    
    # Filter to reasonable temporal proximity (‚â§ max_day_diff)
    if max_day_diff is not None:
        before = len(merged)
        merged = merged[merged["day_diff"] <= max_day_diff]
        print(f"‚úì After day_diff ‚â§ {max_day_diff} filter: {len(merged)} pairs (dropped {before - len(merged)})")
    
    if len(merged) == 0:
        print("‚ùå No MRI‚Äìclinical pairs within the day_diff threshold")
        return pd.DataFrame()
    
    # For each MRI session, keep the clinical row with MINIMUM day_diff
    idx = merged.groupby("mri_session_label")["day_diff"].idxmin()
    best = merged.loc[idx].copy()
    
    print(f"‚úì After selecting nearest clinical visit per MRI: {len(best)} visits")
    
    # Show example matches
    print(f"\nüîç Sample matches (first 5):")
    for _, row in best.head(5).iterrows():
        print(f"   MRI: {row['mri_session_label']} (day {row['mri_days']:.0f})")
        print(f"     ‚Üí Clinical: {row['clin_session_label']} (day {row['days_to_visit']:.0f})")
        print(f"     ‚Üí Diff: {row['day_diff']:.0f} days, CDR={row['CDR']}, MMSE={row['MMSE']}")
    
    # Filter: require valid CDR & MMSE
    before_filter = len(best)
    best = best.dropna(subset=["CDR", "MMSE"])
    print(f"\n‚úì After CDR/MMSE filter: {len(best)} visits (dropped {before_filter - len(best)})")
    
    # Assign labels
    def assign_label(cdr):
        if cdr == 0:
            return 0   # CN
        elif cdr >= 1:
            return 1   # AD
        else:
            return -1
    
    best["label"] = best["CDR"].apply(assign_label)
    before_label = len(best)
    best = best[best["label"].isin([0, 1])]
    print(f"‚úì After label filter (CN/AD only): {len(best)} visits (dropped {before_label - len(best)})")
    
    # Assign visit_index chronologically using days_to_visit
    best = assign_oasis3_visit_index(best)
    print("‚úì Assigned visit_index using days_to_visit")
    
    # Standardize output columns to match OASIS-2
    best["dataset"] = "OASIS3"
    best["domain_id"] = 1
    best["session_label"] = best["mri_session_label"]  # Use MRI label as primary identifier
    
    # Select final columns
    out_cols = [
        "dataset",
        "domain_id",
        "subject_id",
        "session_label",
        "visit_id",
        "visit_index",
        "mri_path",
        "CDR",
        "MMSE",
        "label",
        "Age",
        "days_to_visit",
        "day_diff",  # Keep for QC
    ]
    
    # Only keep columns that exist
    out_cols = [c for c in out_cols if c in best.columns]
    best_out = best[out_cols].copy()
    
    return best_out


# ============================================================================
# STEP 5: BUILD VISITS TABLE
# ============================================================================

def build_visits_table(o2_merged, o3_merged, output_path):
    """Combine OASIS-2 and OASIS-3 into unified visits table"""
    print("\n" + "="*70)
    print("STEP 5: Building Unified Visits Table")
    print("="*70)
    
    # Add session_label to OASIS-2 for schema consistency (set to None)
    if 'session_label' not in o2_merged.columns:
        o2_merged['session_label'] = None
    
    # Add day_diff to OASIS-2 for schema consistency (set to 0)
    if 'day_diff' not in o2_merged.columns:
        o2_merged['day_diff'] = 0.0
    
    df_all = pd.concat([o2_merged, o3_merged], axis=0, ignore_index=True)
    
    # Column ordering
    base_cols = [
        'dataset', 'domain_id', 'subject_id', 'session_label', 'visit_id', 
        'visit_index', 'mri_path', 'CDR', 'MMSE', 'label'
    ]
    
    # Add optional columns
    optional_cols = ['Age', 'Sex', 'days_to_visit', 'day_diff']
    cols = base_cols + [c for c in optional_cols if c in df_all.columns]
    
    df_all = df_all[cols]
    
    # Sanity checks
    assert df_all['mri_path'].notna().all(), "‚ùå Found missing mri_path values"
    
    # Save
    df_all.to_csv(output_path, index=False)
    print(f"\n‚úÖ Saved visits_table.csv: {len(df_all)} visits")
    
    # OASIS-3 temporal matching quality report
    if 'day_diff' in df_all.columns and 'OASIS3' in df_all['dataset'].values:
        o3_diffs = df_all[df_all['dataset'] == 'OASIS3']['day_diff']
        print(f"\nüìä OASIS-3 Temporal Matching Quality:")
        print(f"   Mean day difference: {o3_diffs.mean():.1f} days")
        print(f"   Median day difference: {o3_diffs.median():.1f} days")
        print(f"   Max day difference: {o3_diffs.max():.1f} days")
        print(f"   Within 30 days: {(o3_diffs <= 30).sum()} visits ({100*(o3_diffs <= 30).sum()/len(o3_diffs):.1f}%)")
        print(f"   Within 90 days: {(o3_diffs <= 90).sum()} visits ({100*(o3_diffs <= 90).sum()/len(o3_diffs):.1f}%)")
    
    return df_all


# ============================================================================
# STEP 6: BUILD SUBJECTS TABLE
# ============================================================================

def build_subjects_table(visits_df, output_path):
    """Aggregate visits into subject-level table"""
    print("\n" + "="*70)
    print("STEP 6: Building Subjects Table")
    print("="*70)
    
    grouped = visits_df.groupby(['dataset', 'subject_id', 'domain_id'])
    
    records = []
    for (dataset, subject_id, domain_id), g in grouped:
        # Baseline age
        ages = g['Age'].dropna().tolist() if 'Age' in g.columns else []
        baseline_age = min(ages) if len(ages) > 0 else None
        
        # Sex
        if 'Sex' in g.columns:
            sex_vals = g['Sex'].dropna().tolist()
            sex = Counter(sex_vals).most_common(1)[0][0] if len(sex_vals) > 0 else None
        else:
            sex = None
        
        # Visit counts
        n_total = len(g)
        n_cn = int((g['label'] == 0).sum())
        n_ad = int((g['label'] == 1).sum())
        
        records.append({
            'dataset': dataset,
            'domain_id': domain_id,
            'subject_id': subject_id,
            'Sex': sex,
            'baseline_age': baseline_age,
            'n_visits_total': n_total,
            'n_CN_visits': n_cn,
            'n_AD_visits': n_ad,
            'has_longitudinal': int(n_total >= 2),
        })
    
    subjects_df = pd.DataFrame(records)
    subjects_df.to_csv(output_path, index=False)
    print(f"\n‚úÖ Saved subjects_table.csv: {len(subjects_df)} subjects")
    
    return subjects_df


# ============================================================================
# STEP 7: QC LOGGING
# ============================================================================

def log_qc(subjects_df, visits_df):
    """Comprehensive quality control logging"""
    print("\n" + "="*70)
    print("STEP 7: QUALITY CONTROL SUMMARY")
    print("="*70)
    
    # Visit-level statistics
    print("\nüìä VISIT-LEVEL STATISTICS")
    print("-" * 70)
    
    visit_counts = visits_df.groupby(['dataset', 'label']).size().unstack(fill_value=0)
    visit_counts.columns = ['CN', 'AD']
    print("\nVisits by dataset and label:")
    print(visit_counts)
    
    cn_count = int((visits_df['label'] == 0).sum())
    ad_count = int((visits_df['label'] == 1).sum())
    ratio = cn_count / max(ad_count, 1)
    
    print(f"\n{'Total visits:':<25} {len(visits_df)}")
    print(f"{'CN visits:':<25} {cn_count} ({100*cn_count/len(visits_df):.1f}%)")
    print(f"{'AD visits:':<25} {ad_count} ({100*ad_count/len(visits_df):.1f}%)")
    print(f"{'CN:AD ratio:':<25} {ratio:.2f}:1")
    
    # Subject-level statistics
    print("\n\nüìä SUBJECT-LEVEL STATISTICS")
    print("-" * 70)
    
    subj_by_dataset = subjects_df.groupby('dataset')['subject_id'].count()
    print("\nSubjects by dataset:")
    for dataset, count in subj_by_dataset.items():
        print(f"  {dataset:<10} {count:>4} subjects")
    print(f"  {'TOTAL':<10} {len(subjects_df):>4} subjects")
    
    # Longitudinal structure
    print("\n\nüìä LONGITUDINAL STRUCTURE")
    print("-" * 70)
    
    long_by_dataset = subjects_df[subjects_df['has_longitudinal'] == 1].groupby('dataset')['subject_id'].count()
    print("\nSubjects with ‚â•2 visits:")
    for dataset, count in long_by_dataset.items():
        print(f"  {dataset:<10} {count:>4} subjects")
    
    visit_dist = subjects_df['n_visits_total'].value_counts().sort_index()
    print("\nDistribution of visits per subject:")
    for n_visits, count in visit_dist.items():
        print(f"  {n_visits:>2} visits: {count:>4} subjects")
    
    # Converter detection
    print("\n\nüìä CONVERTER ANALYSIS")
    print("-" * 70)
    
    converters = subjects_df[(subjects_df['n_CN_visits'] > 0) & (subjects_df['n_AD_visits'] > 0)]
    print(f"\nSubjects with BOTH CN and AD labels (potential converters): {len(converters)}")
    
    if len(converters) > 0:
        print("\nTop 10 converters:")
        print(converters[['dataset', 'subject_id', 'n_CN_visits', 'n_AD_visits', 'n_visits_total']].head(10))
    
    # Data quality checks
    print("\n\nüìä DATA QUALITY CHECKS")
    print("-" * 70)
    
    missing_mri = visits_df['mri_path'].isna().sum()
    print(f"‚úì Missing mri_path values: {missing_mri} (should be 0)")
    
    if 'Age' in visits_df.columns:
        age_range = visits_df['Age'].dropna()
        if len(age_range) > 0:
            print(f"‚úì Age range: {age_range.min():.1f} - {age_range.max():.1f} years (mean: {age_range.mean():.1f})")
    
    if 'MMSE' in visits_df.columns:
        mmse_range = visits_df['MMSE'].dropna()
        if len(mmse_range) > 0:
            print(f"‚úì MMSE range: {mmse_range.min():.0f} - {mmse_range.max():.0f} (mean: {mmse_range.mean():.1f})")
    
    # Success criteria assessment
    print("\n\nüìä SUCCESS CRITERIA ASSESSMENT")
    print("-" * 70)
    
    checks = []
    checks.append(("Total visits ‚â• 400", len(visits_df) >= 400))
    checks.append(("Both datasets present", len(visits_df['dataset'].unique()) == 2))
    checks.append(("CN:AD ratio 3:1 to 10:1", 3 <= ratio <= 10))
    
    # Check both datasets have CN and AD
    o2_has_both = False
    o3_has_both = False
    if 'OASIS2' in visit_counts.index:
        o2_has_both = (visit_counts.loc['OASIS2']['CN'] > 0) and (visit_counts.loc['OASIS2']['AD'] > 0)
    if 'OASIS3' in visit_counts.index:
        o3_has_both = (visit_counts.loc['OASIS3']['CN'] > 0) and (visit_counts.loc['OASIS3']['AD'] > 0)
    
    checks.append(("OASIS-2 has CN and AD", o2_has_both))
    checks.append(("OASIS-3 has CN and AD", o3_has_both))
    checks.append(("Longitudinal subjects ‚â• 50", (subjects_df['has_longitudinal'] == 1).sum() >= 50))
    checks.append(("No missing MRI paths", missing_mri == 0))
    
    all_pass = all([c[1] for c in checks])
    
    for check_name, passed in checks:
        status = "‚úÖ PASS" if passed else "‚ùå FAIL"
        print(f"{status}  {check_name}")
    
    if all_pass:
        print("\n" + "="*70)
        print("üéâ ALL SUCCESS CRITERIA MET - S1 ACCEPTED")
        print("="*70)
    else:
        print("\n" + "="*70)
        print("‚ö†Ô∏è  SOME CRITERIA NOT MET - REVIEW REQUIRED")
        print("="*70)
    
    return all_pass


# ============================================================================
# MAIN PIPELINE
# ============================================================================

def run_S1_pipeline():
    """Execute complete S1 data audit pipeline"""
    print("\n" + "="*70)
    print("üî¨ SNIPPET S1: DATA AUDIT & SUBJECT/VISIT TABLES (REVISION 3)")
    print("   KEY FIX: OASIS-3 nearest-neighbor temporal matching")
    print("="*70)
    
    # Step 1-2: Scan MRI directories
    o2_mri = scan_oasis2_mri(
        root_part1="/kaggle/input/oaisis-dataset-3-p1/OAS2_RAW_PART1",
        root_part2="/kaggle/input/oaisis-3-p2/OAS2_RAW_PART2",
    )
    
    o3_mri = scan_oasis3_mri(
        root_o3="/kaggle/input/oaisis-3/oaisis3",
    )
    
    # Step 3: Load clinical CSVs
    o2_clin = load_oasis2_clinical("/kaggle/input/mri-and-alzheimers/oasis_longitudinal.csv")
    o3_clin = load_oasis3_clinical("/kaggle/input/oaisis-3-longitiudinal/oaisis3longitiudinal.csv")
    
    # Step 4: Merge
    o2_merged = merge_oasis2(o2_mri, o2_clin)
    o3_merged = merge_oasis3(o3_mri, o3_clin, max_day_diff=365)
    
    # Step 5-6: Build tables
    visits_df = build_visits_table(
        o2_merged, o3_merged,
        output_path="visits_table.csv"
    )
    
    subjects_df = build_subjects_table(
        visits_df,
        output_path="subjects_table.csv"
    )
    
    # Step 7: QC logging
    all_pass = log_qc(subjects_df, visits_df)
    
    return visits_df, subjects_df, all_pass


# ============================================================================
# EXECUTE
# ============================================================================

if __name__ == "__main__":
    visits_df, subjects_df, success = run_S1_pipeline()
    
    # Display previews
    print("\n" + "="*70)
    print("üìã PREVIEW: visits_table.csv")
    print("="*70)
    print("\nFirst 5 OASIS-2 visits:")
    o2_sample = visits_df[visits_df['dataset'] == 'OASIS2'].head(5)[
        ['subject_id', 'visit_id', 'visit_index', 'CDR', 'MMSE', 'Age', 'label']
    ]
    print(o2_sample.to_string(index=False))
    
    if 'OASIS3' in visits_df['dataset'].values:
        print("\nFirst 5 OASIS-3 visits:")
        o3_sample = visits_df[visits_df['dataset'] == 'OASIS3'].head(5)[
            ['subject_id', 'visit_id', 'visit_index', 'CDR', 'MMSE', 'Age', 'label', 'day_diff']
        ]
        print(o3_sample.to_string(index=False))
    
    print("\n" + "="*70)
    print("üìã PREVIEW: subjects_table.csv (first 10 rows)")
    print("="*70)
    print(subjects_df.head(10).to_string(index=False))



üî¨ SNIPPET S1: DATA AUDIT & SUBJECT/VISIT TABLES (REVISION 3)
   KEY FIX: OASIS-3 nearest-neighbor temporal matching

STEP 1: Scanning OASIS-2 MRI directories

Scanning PART1: /kaggle/input/oaisis-dataset-3-p1/OAS2_RAW_PART1
  ‚úì Found 209 valid MRI sessions

Scanning PART2: /kaggle/input/oaisis-3-p2/OAS2_RAW_PART2
  ‚úì Found 164 valid MRI sessions

üìä OASIS-2 Total: 373 MRI sessions from 150 subjects

STEP 2: Scanning OASIS-3 MRI directories
Root: /kaggle/input/oaisis-3/oaisis3
  ‚úì Found 423 valid T1w sessions

üìä OASIS-3 Total: 423 MRI sessions from 300 subjects

üîç Sample MRI sessions (first 5):
   OAS30354_MR_d0056 ‚Üí subject=OAS30354, mri_days=56
   OAS30083_MR_d3827 ‚Üí subject=OAS30083, mri_days=3827
   OAS31019_MR_d1370 ‚Üí subject=OAS31019, mri_days=1370
   OAS30208_MR_d1703 ‚Üí subject=OAS30208, mri_days=1703
   OAS30830_MR_d0030 ‚Üí subject=OAS30830, mri_days=30

STEP 3a: Loading OASIS-2 Clinical CSV
Path: /kaggle/input/mri-and-alzheimers/oasis_longitudinal.csv

In [2]:
"""
SNIPPET S2: 3D Preprocessing + Hippocampal ROI + Full-Brain Volume (EXTENDED)

NEW: Also generates downsampled full-brain volumes for whole-brain CNN analysis
- Hippocampal ROI extraction (for XAI/regional analysis)
- Full T1w downsampled volumes (for full-brain CNN)
"""

import os
import glob
import pandas as pd
import numpy as np
import nibabel as nib
from pathlib import Path
from scipy import ndimage
from scipy.ndimage import zoom
from tqdm import tqdm
import warnings
import hashlib
warnings.filterwarnings('ignore')

# Nilearn for template loading
from nilearn import datasets as nilearn_datasets

# SimpleITK for registration
import SimpleITK as sitk

# Visualization
import matplotlib.pyplot as plt
from matplotlib import gridspec

# ============================================================================
# CONFIGURATION
# ============================================================================

class S2Config:
    """Central configuration for S2 preprocessing"""
    
    MNI_RESOLUTION = 2
    
    # Hippocampal atlas labels (auto-detected at runtime)
    HIPPO_LEFT_LABEL = None
    HIPPO_RIGHT_LABEL = None
    
    # Output directories
    OUTPUT_ROOT = "/kaggle/working/processed_mri"
    WHOLE_BRAIN_DIR = "whole_brain"      # Full MNI resolution
    FULL_BRAIN_DIR = "full_brain"        # NEW: Downsampled for CNN
    HIPPO_ROI_DIR = "hippo_roi"
    HIPPO_EXT_DIR = "hippo_roi_ext"
    QC_DIR = "qc"
    TEMP_DIR = "temp_cleaned"
    
    # ROI parameters
    HIPPO_ROI_MARGIN = 10
    HIPPO_EXT_MARGIN = 24
    TARGET_HIPPO_SHAPE = (80, 80, 80)
    TARGET_EXT_SHAPE = (112, 112, 80)
    
    # NEW: Full-brain CNN input shape (downsampled from MNI ~193x229x193)
    TARGET_FULL_SHAPE = (128, 160, 128)  # Manageable for 3D CNN
    
    # Registration parameters
    RIGID_ITERATIONS = 200
    
    # Intensity normalization
    PERCENTILE_LOW = 1
    PERCENTILE_HIGH = 99
    
    # QC sampling
    QC_SAMPLE_SIZE = 20


# ============================================================================
# UTILITY FUNCTIONS
# ============================================================================

def collapse_4d_to_3d(data):
    """Collapse 4D MRI volume to 3D"""
    if data.ndim == 3:
        return data
    
    if data.ndim == 4:
        if data.shape[-1] == 1:
            return data[..., 0]
        else:
            print(f"      4D volume with {data.shape[-1]} timepoints, averaging...")
            return np.mean(data, axis=-1)
    
    data_squeezed = np.squeeze(data)
    if data_squeezed.ndim == 3:
        return data_squeezed
    
    raise ValueError(f"Cannot collapse {data.ndim}D volume with shape {data.shape} to 3D")


def find_actual_oasis3_file(raw_path):
    """OASIS-3 files are in nested BIDS directories"""
    path_obj = Path(raw_path)
    
    if path_obj.is_dir():
        nii_files = list(path_obj.glob("*.nii")) + list(path_obj.glob("*.nii.gz"))
        if nii_files:
            for f in nii_files:
                if "T1w" in f.name or "t1" in f.name.lower():
                    return str(f)
            return str(nii_files[0])
    
    if path_obj.exists() and path_obj.is_file():
        return str(path_obj)
    
    if not path_obj.exists():
        parent = path_obj.parent
        if parent.exists() and parent.is_dir():
            nii_files = list(parent.glob("*.nii")) + list(parent.glob("*.nii.gz"))
            if nii_files:
                for f in nii_files:
                    if "T1w" in f.name:
                        return str(f)
                return str(nii_files[0])
    
    gz_path = Path(str(path_obj) + ".gz")
    if gz_path.exists():
        return str(gz_path)
    
    raise FileNotFoundError(f"Cannot find actual NIfTI file for: {raw_path}")


def prepare_volume_for_sitk(raw_path, dataset, temp_dir):
    """
    Load raw volume, handle 4D‚Üí3D, clean headers, save for SimpleITK
    
    FIXED: SimpleITK fallback for corrupted OASIS-3 files
    """
    # Find actual file
    if dataset == "OASIS3":
        actual_file = find_actual_oasis3_file(raw_path)
    else:
        actual_file = raw_path
    
    # Try loading with nibabel (standard path)
    data = None
    affine = None
    
    try:
        img = nib.load(actual_file)
        data = img.get_fdata()
        affine = img.affine
    except Exception as e:
        # Check if it's a byte-size mismatch error (corrupted file)
        error_msg = str(e).lower()
        is_corruption = ("expected" in error_msg and "bytes" in error_msg and "got" in error_msg)
        
        if not is_corruption:
            raise
        
        # Try SimpleITK fallback
        print(f"      ‚ö†Ô∏è nibabel failed (corrupted file), trying SimpleITK fallback...")
        try:
            sitk_img = sitk.ReadImage(str(actual_file))
            arr = sitk.GetArrayFromImage(sitk_img)
            data = np.transpose(arr, (2, 1, 0))
            
            spacing = sitk_img.GetSpacing()
            origin = sitk_img.GetOrigin()
            
            affine = np.eye(4)
            affine[0, 0] = spacing[0]
            affine[1, 1] = spacing[1]
            affine[2, 2] = spacing[2]
            affine[0, 3] = origin[0]
            affine[1, 3] = origin[1]
            affine[2, 3] = origin[2]
            
            print(f"      ‚úì SimpleITK fallback succeeded")
            
        except Exception as sitk_error:
            raise ValueError(f"Both nibabel and SimpleITK failed. Original: {e}, SimpleITK: {sitk_error}")
    
    if data is None:
        raise ValueError("Failed to load data from file")
    
    # Collapse 4D ‚Üí 3D if needed
    data_3d = collapse_4d_to_3d(data)
    
    # Validate 3D data
    if data_3d.size == 0:
        raise ValueError("Empty data array")
    
    if np.all(data_3d == 0):
        raise ValueError("All-zero volume")
    
    if data_3d.ndim != 3:
        raise ValueError(f"Expected 3D, got {data_3d.ndim}D with shape {data_3d.shape}")
    
    # Create cleaned image
    img_clean = nib.Nifti1Image(data_3d.astype(np.float32), affine)
    
    # Generate unique temp filename
    hash_suffix = hashlib.md5(actual_file.encode()).hexdigest()[:8]
    temp_filename = f"{dataset}_{hash_suffix}_cleaned.nii.gz"
    temp_path = os.path.join(temp_dir, temp_filename)
    
    os.makedirs(temp_dir, exist_ok=True)
    
    # Save cleaned version
    nib.save(img_clean, temp_path)
    
    return temp_path


# ============================================================================
# TEMPLATE LOADING
# ============================================================================

def load_mni_template_and_mask(config):
    """Load MNI152 template with auto-detected hippocampal labels"""
    print("\n" + "="*70)
    print("Loading MNI Template, Mask, and Atlas (Auto-Detection)")
    print("="*70)
    
    try:
        # Fetch MNI152 template
        print("‚úì Fetching MNI152 ICBM152 2009c template...")
        mni_data = nilearn_datasets.fetch_icbm152_2009()
        
        # Robust load
        if isinstance(mni_data['t1'], str):
            mni_template = nib.load(mni_data['t1'])
        else:
            mni_template = mni_data['t1']
        
        if isinstance(mni_data['mask'], str):
            mni_brain_mask = nib.load(mni_data['mask'])
        else:
            mni_brain_mask = mni_data['mask']
        
        print(f"‚úì MNI template: shape={mni_template.shape}")
        print(f"‚úì Brain mask: {int(mni_brain_mask.get_fdata().sum())} voxels")
        
        # Save template for SimpleITK
        template_dir = os.path.join(config.OUTPUT_ROOT, "mni_template")
        os.makedirs(template_dir, exist_ok=True)
        mni_template_path = os.path.join(template_dir, "mni_template.nii.gz")
        
        nib.save(mni_template, mni_template_path)
        print(f"‚úì Saved MNI template: {mni_template_path}")
        
        # Fetch Harvard-Oxford atlas
        print("\n‚úì Loading Harvard-Oxford subcortical atlas...")
        ho_data = nilearn_datasets.fetch_atlas_harvard_oxford('sub-maxprob-thr25-2mm')
        
        # Robust load
        raw_maps = ho_data.maps
        if isinstance(raw_maps, str):
            hippo_atlas = nib.load(raw_maps)
        elif isinstance(raw_maps, nib.nifti1.Nifti1Image):
            hippo_atlas = raw_maps
        else:
            hippo_atlas = raw_maps
        
        print(f"‚úì Atlas loaded: shape={hippo_atlas.shape}")
        
        # AUTO-DETECT HIPPOCAMPAL LABELS
        print("\n‚úì Auto-detecting hippocampal labels...")
        
        labels = ho_data.labels
        labels = [lab.decode("utf-8") if isinstance(lab, bytes) else lab for lab in labels]
        
        print(f"  Found {len(labels)} atlas regions")
        
        left_idx = None
        right_idx = None
        
        for idx, name in enumerate(labels):
            name_lower = name.lower()
            if "left" in name_lower and "hippocampus" in name_lower:
                left_idx = idx
                print(f"  Found: '{name}' at index {idx}")
            if "right" in name_lower and "hippocampus" in name_lower:
                right_idx = idx
                print(f"  Found: '{name}' at index {idx}")
        
        if left_idx is None or right_idx is None:
            print("\n  ‚ö†Ô∏è  Could not find hippocampal labels. Available:")
            for idx, name in enumerate(labels[:20]):
                print(f"    {idx}: {name}")
            raise ValueError(f"Hippocampal labels not found. Left={left_idx}, Right={right_idx}")
        
        # Update config
        config.HIPPO_LEFT_LABEL = left_idx
        config.HIPPO_RIGHT_LABEL = right_idx
        
        print(f"\n‚úì Detected hippocampal labels:")
        print(f"   Left:  {config.HIPPO_LEFT_LABEL}")
        print(f"   Right: {config.HIPPO_RIGHT_LABEL}")
        
        # Validate
        atlas_data = hippo_atlas.get_fdata()
        left_count = int(np.sum(atlas_data == config.HIPPO_LEFT_LABEL))
        right_count = int(np.sum(atlas_data == config.HIPPO_RIGHT_LABEL))
        
        print(f"\nüìä Hippocampus validation:")
        print(f"   Left (label {config.HIPPO_LEFT_LABEL}): {left_count} voxels")
        print(f"   Right (label {config.HIPPO_RIGHT_LABEL}): {right_count} voxels")
        
        if left_count == 0 or right_count == 0:
            raise ValueError(f"Hippocampal labels have no voxels! Left={left_count}, Right={right_count}")
        
        return mni_template, mni_brain_mask, hippo_atlas, mni_template_path
        
    except Exception as e:
        print(f"\n‚ùå CRITICAL ERROR: {e}")
        import traceback
        traceback.print_exc()
        raise


# ============================================================================
# REGISTRATION
# ============================================================================

def rigid_register_to_mni_sitk(moving_path, fixed_sitk):
    """Rigidly register moving image to MNI using SimpleITK"""
    moving_sitk = sitk.ReadImage(str(moving_path))
    
    if moving_sitk.GetSize()[0] == 0:
        raise ValueError("Moving image has zero size")
    
    registration = sitk.ImageRegistrationMethod()
    registration.SetMetricAsMattesMutualInformation(numberOfHistogramBins=50)
    registration.SetMetricSamplingStrategy(registration.RANDOM)
    registration.SetMetricSamplingPercentage(0.01)
    registration.SetInterpolator(sitk.sitkLinear)
    registration.SetOptimizerAsGradientDescent(
        learningRate=1.0, numberOfIterations=200,
        convergenceMinimumValue=1e-6, convergenceWindowSize=10
    )
    registration.SetOptimizerScalesFromPhysicalShift()
    
    initial_transform = sitk.CenteredTransformInitializer(
        fixed_sitk, moving_sitk, sitk.Euler3DTransform(),
        sitk.CenteredTransformInitializerFilter.GEOMETRY
    )
    registration.SetInitialTransform(initial_transform, inPlace=False)
    
    final_transform = registration.Execute(fixed_sitk, moving_sitk)
    
    resampler = sitk.ResampleImageFilter()
    resampler.SetReferenceImage(fixed_sitk)
    resampler.SetInterpolator(sitk.sitkLinear)
    resampler.SetDefaultPixelValue(0)
    resampler.SetTransform(final_transform)
    
    registered_sitk = resampler.Execute(moving_sitk)
    
    data_mni = sitk.GetArrayFromImage(registered_sitk)
    data_mni = np.transpose(data_mni, (2, 1, 0))
    
    return data_mni.astype(np.float32), True


# ============================================================================
# ROI & RESIZING UTILITIES (NEW: resize_to_shape)
# ============================================================================

def resize_to_shape(data, target_shape):
    """
    Resize 3D volume to target shape using trilinear interpolation
    
    Args:
        data: 3D numpy array
        target_shape: tuple (D, H, W) target dimensions
    
    Returns:
        resized: 3D numpy array with target_shape
    """
    if data.shape == target_shape:
        return data
    
    zoom_factors = (
        target_shape[0] / data.shape[0],
        target_shape[1] / data.shape[1],
        target_shape[2] / data.shape[2],
    )
    
    # Use scipy.ndimage.zoom with trilinear interpolation (order=1)
    resized = zoom(data, zoom_factors, order=1, mode='constant', cval=0)
    
    return resized.astype(np.float32)


def get_hippocampal_mask_and_bbox(hippo_atlas, config):
    """Extract hippocampal mask and bounding box"""
    atlas_data = hippo_atlas.get_fdata()
    
    hippo_mask_left = (atlas_data == config.HIPPO_LEFT_LABEL)
    hippo_mask_right = (atlas_data == config.HIPPO_RIGHT_LABEL)
    hippo_mask_total = hippo_mask_left | hippo_mask_right
    
    coords = np.where(hippo_mask_total)
    if len(coords[0]) == 0:
        raise ValueError("Empty hippocampal mask!")
    
    x_min, x_max = coords[0].min(), coords[0].max()
    y_min, y_max = coords[1].min(), coords[1].max()
    z_min, z_max = coords[2].min(), coords[2].max()
    
    return hippo_mask_total, (x_min, x_max, y_min, y_max, z_min, z_max)


def crop_roi_with_margin(data, bbox, margin, target_shape=None):
    """Crop ROI with margin"""
    x_min, x_max, y_min, y_max, z_min, z_max = bbox
    
    x0 = max(x_min - margin, 0)
    x1 = min(x_max + margin + 1, data.shape[0])
    y0 = max(y_min - margin, 0)
    y1 = min(y_max + margin + 1, data.shape[1])
    z0 = max(z_min - margin, 0)
    z1 = min(z_max + margin + 1, data.shape[2])
    
    cropped = data[x0:x1, y0:y1, z0:z1]
    
    if target_shape is not None:
        cropped = pad_or_crop_to_shape(cropped, target_shape)
    
    return cropped, (x0, x1, y0, y1, z0, z1)


def pad_or_crop_to_shape(data, target_shape):
    """Pad or crop to target shape"""
    current_shape = data.shape
    output = np.zeros(target_shape, dtype=data.dtype)
    
    slices_out = []
    slices_in = []
    
    for i in range(3):
        if current_shape[i] <= target_shape[i]:
            start = (target_shape[i] - current_shape[i]) // 2
            end = start + current_shape[i]
            slices_out.append(slice(start, end))
            slices_in.append(slice(None))
        else:
            start = (current_shape[i] - target_shape[i]) // 2
            end = start + target_shape[i]
            slices_out.append(slice(None))
            slices_in.append(slice(start, end))
    
    output[tuple(slices_out)] = data[tuple(slices_in)]
    return output


# ============================================================================
# PREPROCESSING PIPELINE (EXTENDED with full-brain volume)
# ============================================================================

def preprocess_visit(row, mni_template_sitk, mni_brain_mask, mni_affine, 
                     hippo_mask, bbox, config, temp_dir_absolute):
    """
    Complete preprocessing for one visit
    
    NEW: Also generates downsampled full-brain volume for CNN
    """
    dataset = row['dataset']
    subject_id = row['subject_id']
    visit_index = row['visit_index']
    mri_path = row['mri_path']
    
    cleaned_path = prepare_volume_for_sitk(mri_path, dataset, temp_dir_absolute)
    data_mni, reg_success = rigid_register_to_mni_sitk(cleaned_path, mni_template_sitk)
    
    if not reg_success:
        raise ValueError("Registration failed")
    
    mask_data = mni_brain_mask.get_fdata() > 0
    masked_vals = data_mni[mask_data]
    
    if masked_vals.size == 0:
        raise ValueError("No brain voxels after masking")
    
    p1 = np.percentile(masked_vals, config.PERCENTILE_LOW)
    p99 = np.percentile(masked_vals, config.PERCENTILE_HIGH)
    
    if (p99 - p1) <= 1e-6:
        raise ValueError(f"Invalid percentile range: p1={p1:.2f}, p99={p99:.2f}")
    
    data_norm = np.clip((data_mni - p1) / (p99 - p1), 0, 1).astype(np.float32)
    data_norm[~mask_data] = 0.0
    
    # Step 4a: Save whole-brain volume (full MNI resolution)
    wb_dir = os.path.join(config.OUTPUT_ROOT, config.WHOLE_BRAIN_DIR)
    os.makedirs(wb_dir, exist_ok=True)
    wb_path = os.path.join(wb_dir, f"MNI_{dataset}_{subject_id}_v{visit_index}.nii.gz")
    nib.save(nib.Nifti1Image(data_norm, mni_affine), wb_path)
    
    # Step 4b: Create downsampled full-brain volume for CNN (NEW)
    full_dir = os.path.join(config.OUTPUT_ROOT, config.FULL_BRAIN_DIR)
    os.makedirs(full_dir, exist_ok=True)
    
    data_full = resize_to_shape(data_norm, config.TARGET_FULL_SHAPE)
    
    full_path = os.path.join(
        full_dir,
        f"FULL_{dataset}_{subject_id}_v{visit_index}.nii.gz"
    )
    nib.save(nib.Nifti1Image(data_full, mni_affine), full_path)
    
    # Step 5: Extract hippocampal ROI
    hippo_roi, _ = crop_roi_with_margin(
        data_norm, bbox, margin=config.HIPPO_ROI_MARGIN,
        target_shape=config.TARGET_HIPPO_SHAPE
    )
    roi_dir = os.path.join(config.OUTPUT_ROOT, config.HIPPO_ROI_DIR)
    os.makedirs(roi_dir, exist_ok=True)
    roi_path = os.path.join(roi_dir, f"ROI_{dataset}_{subject_id}_v{visit_index}.nii.gz")
    nib.save(nib.Nifti1Image(hippo_roi, mni_affine), roi_path)
    
    # Step 6: Extract extended ROI
    hippo_ext, _ = crop_roi_with_margin(
        data_norm, bbox, margin=config.HIPPO_EXT_MARGIN,
        target_shape=config.TARGET_EXT_SHAPE
    )
    ext_dir = os.path.join(config.OUTPUT_ROOT, config.HIPPO_EXT_DIR)
    os.makedirs(ext_dir, exist_ok=True)
    ext_path = os.path.join(ext_dir, f"ROIEXT_{dataset}_{subject_id}_v{visit_index}.nii.gz")
    nib.save(nib.Nifti1Image(hippo_ext, mni_affine), ext_path)
    
    return {
        'mni_path': wb_path,
        'full_t1_path': full_path,      # NEW: Downsampled full-brain for CNN
        'hippo_roi_path': roi_path,
        'hippo_ext_path': ext_path,
        'p1': float(p1),
        'p99': float(p99),
        'mean_intensity': float(np.mean(masked_vals)),
        'std_intensity': float(np.std(masked_vals)),
        'nonzero_voxels': int(np.sum(mask_data)),
        'preproc_ok': True,
        'error_msg': ''
    }


# ============================================================================
# QC
# ============================================================================

def generate_qc_overlay(data_mni, hippo_mask, output_path):
    """Generate QC overlay"""
    fig = plt.figure(figsize=(15, 5))
    gs = gridspec.GridSpec(1, 3, figure=fig)
    
    coords = np.where(hippo_mask)
    if len(coords[0]) > 0:
        center_x, center_y, center_z = [int(np.mean(c)) for c in coords]
    else:
        center_x, center_y, center_z = [s//2 for s in data_mni.shape]
    
    ax1 = fig.add_subplot(gs[0, 0])
    ax1.imshow(data_mni[:, :, center_z].T, cmap='gray', origin='lower')
    ax1.contour(hippo_mask[:, :, center_z].T, colors='red', linewidths=1.5)
    ax1.set_title(f'Axial (z={center_z})')
    ax1.axis('off')
    
    ax2 = fig.add_subplot(gs[0, 1])
    ax2.imshow(data_mni[:, center_y, :].T, cmap='gray', origin='lower')
    ax2.contour(hippo_mask[:, center_y, :].T, colors='red', linewidths=1.5)
    ax2.set_title(f'Coronal (y={center_y})')
    ax2.axis('off')
    
    ax3 = fig.add_subplot(gs[0, 2])
    ax3.imshow(data_mni[center_x, :, :].T, cmap='gray', origin='lower')
    ax3.contour(hippo_mask[center_x, :, :].T, colors='red', linewidths=1.5)
    ax3.set_title(f'Sagittal (x={center_x})')
    ax3.axis('off')
    
    plt.tight_layout()
    plt.savefig(output_path, dpi=100, bbox_inches='tight')
    plt.close()


# ============================================================================
# MAIN PIPELINE
# ============================================================================

def setup_s2_environment(config):
    """Create output directories"""
    for subdir in [config.WHOLE_BRAIN_DIR, config.FULL_BRAIN_DIR,  # NEW: FULL_BRAIN_DIR
                   config.HIPPO_ROI_DIR, config.HIPPO_EXT_DIR, config.QC_DIR]:
        path = os.path.join(config.OUTPUT_ROOT, subdir)
        os.makedirs(path, exist_ok=True)


def run_s2_pipeline(visits_csv_path="/kaggle/working/visits_table.csv", config=None):
    """Execute S2 preprocessing pipeline"""
    if config is None:
        config = S2Config()
    
    print("\n" + "="*70)
    print("üî¨ SNIPPET S2: 3D PREPROCESSING (EXTENDED - Full-Brain + ROI)")
    print("="*70)
    print(f"\nOutputs:")
    print(f"  - Full MNI resolution: {config.WHOLE_BRAIN_DIR}/")
    print(f"  - Downsampled full-brain {config.TARGET_FULL_SHAPE}: {config.FULL_BRAIN_DIR}/")
    print(f"  - Hippocampal ROI {config.TARGET_HIPPO_SHAPE}: {config.HIPPO_ROI_DIR}/")
    print(f"  - Hippocampal extended {config.TARGET_EXT_SHAPE}: {config.HIPPO_EXT_DIR}/")
    
    setup_s2_environment(config)
    
    temp_dir_absolute = os.path.join(config.OUTPUT_ROOT, config.TEMP_DIR)
    os.makedirs(temp_dir_absolute, exist_ok=True)
    
    mni_template, mni_brain_mask, hippo_atlas, mni_template_path = load_mni_template_and_mask(config)
    mni_template_sitk = sitk.ReadImage(mni_template_path)
    
    hippo_mask, bbox = get_hippocampal_mask_and_bbox(hippo_atlas, config)
    print(f"\n‚úì Hippocampal bbox: {bbox}, voxels: {int(np.sum(hippo_mask))}")
    
    visits_df = pd.read_csv(visits_csv_path)
    print(f"\n‚úì Loaded {len(visits_df)} visits")
    print(f"   {visits_df['dataset'].value_counts().to_dict()}")
    
    print("\n" + "="*70)
    print(f"Processing {len(visits_df)} visits...")
    print("="*70)
    
    results = []
    n_success = 0
    n_fail = 0
    qc_count = 0
    
    for idx, row in tqdm(visits_df.iterrows(), total=len(visits_df), desc="Preprocessing"):
        try:
            res = preprocess_visit(
                row, mni_template_sitk, mni_brain_mask,
                mni_template.affine, hippo_mask, bbox, config,
                temp_dir_absolute=temp_dir_absolute
            )
            n_success += 1
            
            if qc_count < config.QC_SAMPLE_SIZE:
                qc_path = os.path.join(
                    config.OUTPUT_ROOT, config.QC_DIR,
                    f"qc_{row['dataset']}_{row['subject_id']}_v{row['visit_index']}.png"
                )
                data_mni = nib.load(res['mni_path']).get_fdata()
                generate_qc_overlay(data_mni, hippo_mask, qc_path)
                qc_count += 1
                
        except Exception as e:
            n_fail += 1
            if n_fail <= 5:
                print(f"\n‚ö†Ô∏è  Failed: {row['subject_id']}_v{row['visit_index']}: {str(e)[:100]}")
            
            res = {
                'mni_path': None, 'full_t1_path': None,  # NEW
                'hippo_roi_path': None, 'hippo_ext_path': None,
                'p1': np.nan, 'p99': np.nan, 'mean_intensity': np.nan,
                'std_intensity': np.nan, 'nonzero_voxels': np.nan,
                'preproc_ok': False, 'error_msg': str(e)[:200]
            }
        
        result = {**row.to_dict(), **res}
        results.append(result)
    
    processed_df = pd.DataFrame(results)
    output_csv = "/kaggle/working/processed_volumes.csv"
    processed_df.to_csv(output_csv, index=False)
    
    # Summary
    success_rate = 100*n_success/len(visits_df) if len(visits_df) > 0 else 0
    print(f"\n‚úÖ S2 COMPLETE")
    print(f"   Successful: {n_success}/{len(visits_df)} ({success_rate:.1f}%)")
    print(f"   Failed: {n_fail}")
    
    if n_success > 0:
        ok_df = processed_df[processed_df['preproc_ok']]
        print(f"\nüìä Summary:")
        print(f"   Datasets: {ok_df['dataset'].value_counts().to_dict()}")
        print(f"   Labels: CN={int((ok_df['label']==0).sum())}, AD={int((ok_df['label']==1).sum())}")
        print(f"   Mean intensity: {ok_df['mean_intensity'].mean():.3f}")
        
        # Verify full-brain files exist
        full_exist = ok_df['full_t1_path'].apply(lambda x: os.path.exists(x) if pd.notna(x) else False)
        print(f"   Full-brain volumes: {int(full_exist.sum())}/{len(ok_df)} exist")
    
    # Validation
    print("\n" + "="*70)
    print("VALIDATION (Data-Aware)")
    print("="*70)
    
    ok_df = processed_df[processed_df['preproc_ok']]
    o3_total = int((visits_df['dataset'] == 'OASIS3').sum())
    o3_success = int((ok_df['dataset'] == 'OASIS3').sum())
    o3_rate = 100 * o3_success / o3_total if o3_total > 0 else 0
    ad_count = int((ok_df['label'] == 1).sum())
    
    print(f"\nCriteria:")
    print(f"  Overall success rate: {success_rate:.1f}% (target: ‚â•80%)")
    print(f"  OASIS3 success rate:  {o3_rate:.1f}% (target: ‚â•70%)")
    print(f"  AD visits available:  {ad_count} (target: ‚â•50)")
    
    passed = (success_rate >= 80 and o3_rate >= 70 and ad_count >= 50)
    
    if passed:
        print(f"\n‚úÖ PASS: S2 preprocessing meets quality thresholds")
        print(f"   ({n_fail} failures due to corrupted source files)")
    else:
        print(f"\n‚ùå FAIL: S2 preprocessing below quality thresholds")
    
    # Cleanup
    import shutil
    if os.path.exists(temp_dir_absolute):
        n_temp = len(list(Path(temp_dir_absolute).glob("*")))
        shutil.rmtree(temp_dir_absolute)
        print(f"\n‚úì Cleaned up {n_temp} temp files")
    
    return processed_df


# ============================================================================
# EXECUTE
# ============================================================================

if __name__ == "__main__":
    processed_df = run_s2_pipeline()
    
    print("\n" + "="*70)
    print("üìã Sample Output:")
    print("="*70)
    ok_df = processed_df[processed_df['preproc_ok'] == True]
    if len(ok_df) > 0:
        cols = ['subject_id', 'visit_index', 'dataset', 'label', 'full_t1_path']
        print(ok_df[cols].head(5).to_string(index=False))



üî¨ SNIPPET S2: 3D PREPROCESSING (EXTENDED - Full-Brain + ROI)

Outputs:
  - Full MNI resolution: whole_brain/
  - Downsampled full-brain (128, 160, 128): full_brain/
  - Hippocampal ROI (80, 80, 80): hippo_roi/
  - Hippocampal extended (112, 112, 80): hippo_roi_ext/

Loading MNI Template, Mask, and Atlas (Auto-Detection)
‚úì Fetching MNI152 ICBM152 2009c template...

Added README.md to /root/nilearn_data


Dataset created in /root/nilearn_data/icbm152_2009

Downloading data from https://osf.io/7pj92/download ...


 ...done. (2 seconds, 0 min)
Extracting data from /root/nilearn_data/icbm152_2009/e05b733c275cab0eec856067143c9dc9/download..... done.


‚úì MNI template: shape=(197, 233, 189)
‚úì Brain mask: 1886539 voxels
‚úì Saved MNI template: /kaggle/working/processed_mri/mni_template/mni_template.nii.gz

‚úì Loading Harvard-Oxford subcortical atlas...

Dataset created in /root/nilearn_data/fsl

Downloading data from https://www.nitrc.org/frs/download.php/9902/HarvardOxford.tgz ...


 ...done. (1 seconds, 0 min)
Extracting data from /root/nilearn_data/fsl/8a6a179c4b7672ec60913c596b129eff/HarvardOxford.tgz..... done.


‚úì Atlas loaded: shape=(91, 109, 91)

‚úì Auto-detecting hippocampal labels...
  Found 22 atlas regions
  Found: 'Left Hippocampus' at index 9
  Found: 'Right Hippocampus' at index 19

‚úì Detected hippocampal labels:
   Left:  9
   Right: 19

üìä Hippocampus validation:
   Left (label 9): 691 voxels
   Right (label 19): 700 voxels

‚úì Hippocampal bbox: (27, 64, 42, 61, 22, 39), voxels: 1391

‚úì Loaded 575 visits
   {'OASIS3': 327, 'OASIS2': 248}

Processing 575 visits...


Preprocessing:  45%|‚ñà‚ñà‚ñà‚ñà‚ñå     | 259/575 [21:40<36:24,  6.91s/it]

      ‚ö†Ô∏è nibabel failed (corrupted file), trying SimpleITK fallback...
      ‚úì SimpleITK fallback succeeded


Preprocessing:  45%|‚ñà‚ñà‚ñà‚ñà‚ñå     | 260/575 [21:49<40:06,  7.64s/it]

      ‚ö†Ô∏è nibabel failed (corrupted file), trying SimpleITK fallback...
      ‚úì SimpleITK fallback succeeded


Preprocessing:  46%|‚ñà‚ñà‚ñà‚ñà‚ñå     | 264/575 [22:14<33:17,  6.42s/it]

      ‚ö†Ô∏è nibabel failed (corrupted file), trying SimpleITK fallback...
      ‚úì SimpleITK fallback succeeded


Preprocessing:  49%|‚ñà‚ñà‚ñà‚ñà‚ñâ     | 281/575 [23:44<25:25,  5.19s/it]

      ‚ö†Ô∏è nibabel failed (corrupted file), trying SimpleITK fallback...
      ‚úì SimpleITK fallback succeeded


Preprocessing:  49%|‚ñà‚ñà‚ñà‚ñà‚ñâ     | 283/575 [23:54<25:23,  5.22s/it]

      ‚ö†Ô∏è nibabel failed (corrupted file), trying SimpleITK fallback...
      ‚úì SimpleITK fallback succeeded


Preprocessing:  50%|‚ñà‚ñà‚ñà‚ñà‚ñà     | 290/575 [24:29<23:02,  4.85s/it]

      ‚ö†Ô∏è nibabel failed (corrupted file), trying SimpleITK fallback...
      ‚úì SimpleITK fallback succeeded


Preprocessing:  51%|‚ñà‚ñà‚ñà‚ñà‚ñà     | 292/575 [24:38<22:43,  4.82s/it]

      ‚ö†Ô∏è nibabel failed (corrupted file), trying SimpleITK fallback...
      ‚úì SimpleITK fallback succeeded


Preprocessing:  51%|‚ñà‚ñà‚ñà‚ñà‚ñà     | 293/575 [24:43<22:50,  4.86s/it]

      ‚ö†Ô∏è nibabel failed (corrupted file), trying SimpleITK fallback...
      ‚úì SimpleITK fallback succeeded


Preprocessing:  51%|‚ñà‚ñà‚ñà‚ñà‚ñà     | 294/575 [24:48<22:52,  4.89s/it]

      ‚ö†Ô∏è nibabel failed (corrupted file), trying SimpleITK fallback...
      ‚úì SimpleITK fallback succeeded


Preprocessing:  51%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñè    | 295/575 [24:53<22:50,  4.89s/it]

      ‚ö†Ô∏è nibabel failed (corrupted file), trying SimpleITK fallback...
      ‚úì SimpleITK fallback succeeded


Preprocessing:  53%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñé    | 307/575 [25:57<25:59,  5.82s/it]

      ‚ö†Ô∏è nibabel failed (corrupted file), trying SimpleITK fallback...
      ‚úì SimpleITK fallback succeeded


Preprocessing:  54%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñç    | 313/575 [26:26<21:24,  4.90s/it]

      ‚ö†Ô∏è nibabel failed (corrupted file), trying SimpleITK fallback...
      ‚úì SimpleITK fallback succeeded


Preprocessing:  58%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñä    | 332/575 [28:07<21:25,  5.29s/it]

      ‚ö†Ô∏è nibabel failed (corrupted file), trying SimpleITK fallback...
      ‚úì SimpleITK fallback succeeded


Preprocessing:  59%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñâ    | 341/575 [28:55<20:02,  5.14s/it]

      ‚ö†Ô∏è nibabel failed (corrupted file), trying SimpleITK fallback...
      ‚úì SimpleITK fallback succeeded


Preprocessing:  60%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà    | 346/575 [29:20<19:24,  5.08s/it]

      ‚ö†Ô∏è nibabel failed (corrupted file), trying SimpleITK fallback...
      ‚úì SimpleITK fallback succeeded


Preprocessing:  60%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà    | 347/575 [29:25<19:05,  5.02s/it]

      ‚ö†Ô∏è nibabel failed (corrupted file), trying SimpleITK fallback...
      ‚úì SimpleITK fallback succeeded


Preprocessing:  61%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñè   | 353/575 [30:00<21:31,  5.82s/it]

      ‚ö†Ô∏è nibabel failed (corrupted file), trying SimpleITK fallback...
      ‚úì SimpleITK fallback succeeded


Preprocessing:  63%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñé   | 362/575 [30:44<17:09,  4.83s/it]

      ‚ö†Ô∏è nibabel failed (corrupted file), trying SimpleITK fallback...
      ‚úì SimpleITK fallback succeeded


Preprocessing:  63%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñé   | 363/575 [30:49<17:43,  5.02s/it]

      ‚ö†Ô∏è nibabel failed (corrupted file), trying SimpleITK fallback...
      ‚úì SimpleITK fallback succeeded


Preprocessing:  65%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñç   | 371/575 [31:38<19:38,  5.77s/it]

      ‚ö†Ô∏è nibabel failed (corrupted file), trying SimpleITK fallback...
      ‚úì SimpleITK fallback succeeded


Preprocessing:  66%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñå   | 379/575 [32:18<16:08,  4.94s/it]

      ‚ö†Ô∏è nibabel failed (corrupted file), trying SimpleITK fallback...
      ‚úì SimpleITK fallback succeeded


Preprocessing:  66%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñå   | 380/575 [32:23<16:10,  4.98s/it]

      ‚ö†Ô∏è nibabel failed (corrupted file), trying SimpleITK fallback...
      ‚úì SimpleITK fallback succeeded


Preprocessing:  66%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñã   | 381/575 [32:29<16:15,  5.03s/it]

      ‚ö†Ô∏è nibabel failed (corrupted file), trying SimpleITK fallback...
      ‚úì SimpleITK fallback succeeded


Preprocessing:  66%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñã   | 382/575 [32:33<16:02,  4.98s/it]

      ‚ö†Ô∏è nibabel failed (corrupted file), trying SimpleITK fallback...
      ‚úì SimpleITK fallback succeeded


Preprocessing:  67%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñã   | 383/575 [32:38<15:40,  4.90s/it]

      ‚ö†Ô∏è nibabel failed (corrupted file), trying SimpleITK fallback...
      ‚úì SimpleITK fallback succeeded


Preprocessing:  67%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñã   | 388/575 [33:04<16:06,  5.17s/it]

      ‚ö†Ô∏è nibabel failed (corrupted file), trying SimpleITK fallback...
      ‚úì SimpleITK fallback succeeded


Preprocessing:  69%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñâ   | 397/575 [33:52<15:35,  5.25s/it]

      ‚ö†Ô∏è nibabel failed (corrupted file), trying SimpleITK fallback...
      ‚úì SimpleITK fallback succeeded


Preprocessing:  69%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñâ   | 398/575 [33:57<15:26,  5.23s/it]

      ‚ö†Ô∏è nibabel failed (corrupted file), trying SimpleITK fallback...
      ‚úì SimpleITK fallback succeeded


Preprocessing:  70%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñâ   | 400/575 [34:06<14:39,  5.03s/it]

      ‚ö†Ô∏è nibabel failed (corrupted file), trying SimpleITK fallback...
      ‚úì SimpleITK fallback succeeded


Preprocessing:  70%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñâ   | 401/575 [34:11<14:25,  4.98s/it]

      ‚ö†Ô∏è nibabel failed (corrupted file), trying SimpleITK fallback...
      ‚úì SimpleITK fallback succeeded


Preprocessing:  70%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà   | 404/575 [34:26<14:03,  4.93s/it]

      ‚ö†Ô∏è nibabel failed (corrupted file), trying SimpleITK fallback...
      ‚úì SimpleITK fallback succeeded


Preprocessing:  71%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñè  | 410/575 [35:00<16:00,  5.82s/it]

      ‚ö†Ô∏è nibabel failed (corrupted file), trying SimpleITK fallback...
      ‚úì SimpleITK fallback succeeded


Preprocessing:  72%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñè  | 413/575 [35:17<15:28,  5.73s/it]

      ‚ö†Ô∏è nibabel failed (corrupted file), trying SimpleITK fallback...
      ‚úì SimpleITK fallback succeeded


Preprocessing:  72%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñè  | 416/575 [35:37<17:08,  6.47s/it]

      ‚ö†Ô∏è nibabel failed (corrupted file), trying SimpleITK fallback...
      ‚úì SimpleITK fallback succeeded


Preprocessing:  73%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñé  | 419/575 [35:52<14:31,  5.59s/it]

      ‚ö†Ô∏è nibabel failed (corrupted file), trying SimpleITK fallback...
      ‚úì SimpleITK fallback succeeded


Preprocessing:  73%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñé  | 420/575 [35:57<13:59,  5.41s/it]

      ‚ö†Ô∏è nibabel failed (corrupted file), trying SimpleITK fallback...
      ‚úì SimpleITK fallback succeeded


Preprocessing:  77%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñã  | 445/575 [38:14<11:09,  5.15s/it]

      ‚ö†Ô∏è nibabel failed (corrupted file), trying SimpleITK fallback...
      ‚úì SimpleITK fallback succeeded


Preprocessing:  79%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñâ  | 454/575 [39:03<11:19,  5.62s/it]

      ‚ö†Ô∏è nibabel failed (corrupted file), trying SimpleITK fallback...
      ‚úì SimpleITK fallback succeeded


Preprocessing:  79%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñâ  | 457/575 [39:22<11:29,  5.84s/it]

      ‚ö†Ô∏è nibabel failed (corrupted file), trying SimpleITK fallback...
      ‚úì SimpleITK fallback succeeded


Preprocessing:  81%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà  | 465/575 [40:09<11:16,  6.15s/it]

      ‚ö†Ô∏è nibabel failed (corrupted file), trying SimpleITK fallback...
      ‚úì SimpleITK fallback succeeded


Preprocessing:  81%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà  | 467/575 [40:21<10:52,  6.04s/it]

      ‚ö†Ô∏è nibabel failed (corrupted file), trying SimpleITK fallback...
      ‚úì SimpleITK fallback succeeded


Preprocessing:  83%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñé | 479/575 [41:25<08:37,  5.40s/it]

      ‚ö†Ô∏è nibabel failed (corrupted file), trying SimpleITK fallback...
      ‚úì SimpleITK fallback succeeded


Preprocessing:  83%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñé | 480/575 [41:31<08:36,  5.44s/it]

      ‚ö†Ô∏è nibabel failed (corrupted file), trying SimpleITK fallback...
      ‚úì SimpleITK fallback succeeded


Preprocessing:  84%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñç | 482/575 [41:46<10:26,  6.74s/it]

      ‚ö†Ô∏è nibabel failed (corrupted file), trying SimpleITK fallback...
      ‚úì SimpleITK fallback succeeded


Preprocessing:  84%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñç | 483/575 [41:51<09:34,  6.24s/it]

      ‚ö†Ô∏è nibabel failed (corrupted file), trying SimpleITK fallback...
      ‚úì SimpleITK fallback succeeded


Preprocessing:  86%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñå | 494/575 [42:57<08:36,  6.38s/it]

      ‚ö†Ô∏è nibabel failed (corrupted file), trying SimpleITK fallback...
      ‚úì SimpleITK fallback succeeded


Preprocessing:  87%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñã | 498/575 [43:18<06:52,  5.35s/it]

      ‚ö†Ô∏è nibabel failed (corrupted file), trying SimpleITK fallback...
      ‚úì SimpleITK fallback succeeded


Preprocessing:  89%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñä | 509/575 [44:18<05:53,  5.35s/it]

      ‚ö†Ô∏è nibabel failed (corrupted file), trying SimpleITK fallback...
      ‚úì SimpleITK fallback succeeded


Preprocessing:  89%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñä | 510/575 [44:22<05:37,  5.19s/it]

      ‚ö†Ô∏è nibabel failed (corrupted file), trying SimpleITK fallback...
      ‚úì SimpleITK fallback succeeded


Preprocessing:  89%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñâ | 512/575 [44:32<05:21,  5.10s/it]

      ‚ö†Ô∏è nibabel failed (corrupted file), trying SimpleITK fallback...
      ‚úì SimpleITK fallback succeeded


Preprocessing:  91%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà | 521/575 [45:19<04:56,  5.49s/it]

      ‚ö†Ô∏è nibabel failed (corrupted file), trying SimpleITK fallback...
      ‚úì SimpleITK fallback succeeded


Preprocessing:  91%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà | 524/575 [45:39<05:15,  6.20s/it]

      ‚ö†Ô∏è nibabel failed (corrupted file), trying SimpleITK fallback...
      ‚úì SimpleITK fallback succeeded


Preprocessing:  94%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñç| 540/575 [47:00<02:52,  4.93s/it]

      ‚ö†Ô∏è nibabel failed (corrupted file), trying SimpleITK fallback...
      ‚úì SimpleITK fallback succeeded


Preprocessing:  95%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñç| 545/575 [47:28<02:54,  5.81s/it]

      ‚ö†Ô∏è nibabel failed (corrupted file), trying SimpleITK fallback...
      ‚úì SimpleITK fallback succeeded


Preprocessing:  95%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñå| 549/575 [47:52<02:35,  5.98s/it]

      ‚ö†Ô∏è nibabel failed (corrupted file), trying SimpleITK fallback...
      ‚úì SimpleITK fallback succeeded


Preprocessing:  96%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñå| 550/575 [47:57<02:22,  5.71s/it]

      ‚ö†Ô∏è nibabel failed (corrupted file), trying SimpleITK fallback...
      ‚úì SimpleITK fallback succeeded


Preprocessing:  97%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñã| 555/575 [48:22<01:42,  5.13s/it]

      ‚ö†Ô∏è nibabel failed (corrupted file), trying SimpleITK fallback...
      ‚úì SimpleITK fallback succeeded


Preprocessing:  99%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñâ| 571/575 [49:52<00:23,  5.84s/it]

      ‚ö†Ô∏è nibabel failed (corrupted file), trying SimpleITK fallback...
      ‚úì SimpleITK fallback succeeded


Preprocessing: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 575/575 [50:11<00:00,  5.24s/it]



‚úÖ S2 COMPLETE
   Successful: 575/575 (100.0%)
   Failed: 0

üìä Summary:
   Datasets: {'OASIS3': 327, 'OASIS2': 248}
   Labels: CN=499, AD=76
   Mean intensity: 516.185
   Full-brain volumes: 575/575 exist

VALIDATION (Data-Aware)

Criteria:
  Overall success rate: 100.0% (target: ‚â•80%)
  OASIS3 success rate:  100.0% (target: ‚â•70%)
  AD visits available:  76 (target: ‚â•50)

‚úÖ PASS: S2 preprocessing meets quality thresholds
   (0 failures due to corrupted source files)

‚úì Cleaned up 575 temp files

üìã Sample Output:
subject_id  visit_index dataset  label                                                             full_t1_path
 OAS2_0079            2  OASIS2      1 /kaggle/working/processed_mri/full_brain/FULL_OASIS2_OAS2_0079_v2.nii.gz
 OAS2_0044            1  OASIS2      1 /kaggle/working/processed_mri/full_brain/FULL_OASIS2_OAS2_0044_v1.nii.gz
 OAS2_0056            2  OASIS2      0 /kaggle/working/processed_mri/full_brain/FULL_OASIS2_OAS2_0056_v2.nii.gz
 OAS2_0062      

In [3]:
"""
SNIPPET S3: ROI-Focused Tri-Planar Slice Generator (PRODUCTION)

Generates 16 slices per visit (8 axial + 6 coronal + 2 sagittal) from hippocampal ROI

Input:
    /kaggle/working/processed_volumes.csv (from S2)
    Uses: hippo_ext_path (112√ó112√ó80 hippocampal-extended ROI)

Output:
    /kaggle/working/slices_roi/*.png (224√ó224 RGB images)
    /kaggle/working/slices_metadata_ROI.csv (metadata for all slices)

This feeds directly into:
    - S5: DSBN backbone + attention
    - S6: Focal loss training
    - S7: XAI (Grad-CAM, hippocampal Dice)
"""

import os
import numpy as np
import pandas as pd
import nibabel as nib
from pathlib import Path
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')

from PIL import Image
from skimage.transform import resize

# ============================================================================
# CONFIGURATION
# ============================================================================

class S3Config:
    """Configuration for tri-planar slice generation"""
    
    # Input
    PROCESSED_CSV = "/kaggle/working/processed_volumes.csv"
    VOLUME_COL = "hippo_ext_path"  # ROI-focused (use "full_t1_path" for full-brain variant)
    
    # Output
    OUTPUT_ROOT = "/kaggle/working/slices_roi"
    METADATA_CSV = "/kaggle/working/slices_metadata_ROI.csv"
    
    # Image parameters
    IMG_SIZE = (224, 224)  # H, W for ImageNet backbones
    
    # Slice counts per plane (K = 16 total)
    N_AXIAL = 8      # Superior-inferior
    N_CORONAL = 6    # Anterior-posterior
    N_SAGITTAL = 2   # Left-right (medial)
    
    # Quality control
    SAVE_QC_EXAMPLES = True
    QC_VISITS = 3  # Save example montages for first N visits


# ============================================================================
# HELPER FUNCTIONS
# ============================================================================

def compute_slice_indices(dim_size, n_slices):
    """
    Compute evenly spaced interior slice indices
    
    Avoids extreme border slices by using linspace on interior range
    
    Args:
        dim_size: dimension size (e.g., 80 for z-axis)
        n_slices: number of slices to extract (e.g., 8)
    
    Returns:
        sorted array of unique integer indices
    
    Example:
        dim_size=80, n_slices=8 ‚Üí [8, 18, 28, 38, 48, 58, 68, 78]
    """
    if n_slices >= dim_size:
        return np.arange(dim_size)
    
    # Use linspace to get n_slices+2 positions, then drop endpoints
    positions = np.linspace(0, dim_size - 1, num=n_slices + 2)
    inner = positions[1:-1]  # Drop first and last
    indices = np.round(inner).astype(int)
    
    # Ensure uniqueness and sort
    indices = sorted(set(indices))
    
    # If we lost indices due to rounding, fall back to linspace
    if len(indices) < n_slices:
        positions = np.linspace(0, dim_size - 1, num=n_slices)
        indices = np.round(positions).astype(int)
        indices = sorted(set(indices))
    
    return np.array(indices)


def normalize_slice_01(slice_2d):
    """
    Normalize 2D slice to [0, 1] range
    
    S2 already normalized volumes, but we add safeguard
    for numerical stability
    
    Args:
        slice_2d: 2D numpy array
    
    Returns:
        normalized float32 array in [0, 1]
    """
    s = slice_2d.astype(np.float32)
    min_val = s.min()
    max_val = s.max()
    
    if max_val <= min_val + 1e-6:
        # Constant slice (shouldn't happen with brain data)
        return np.zeros_like(s, dtype=np.float32)
    
    s = (s - min_val) / (max_val - min_val)
    return s


def resize_to_img(slice_2d, img_size):
    """
    Resize 2D slice to target size with bilinear interpolation
    
    Args:
        slice_2d: 2D array in [0, 1]
        img_size: tuple (H, W), e.g., (224, 224)
    
    Returns:
        resized float32 array in [0, 1]
    """
    H, W = img_size
    
    s_resized = resize(
        slice_2d,
        (H, W),
        order=1,           # Bilinear interpolation
        mode='constant',
        cval=0.0,
        anti_aliasing=True,
        preserve_range=True
    ).astype(np.float32)
    
    return s_resized


def gray_to_rgb(slice_2d):
    """
    Convert grayscale H√óW to RGB H√óW√ó3 by channel replication
    
    Required for ImageNet pretrained backbones expecting 3 channels
    """
    return np.stack([slice_2d, slice_2d, slice_2d], axis=-1)


def save_slice_png(slice_2d, output_path):
    """
    Save 2D slice as PNG image
    
    Args:
        slice_2d: float32 array in [0, 1], shape (H, W) or (H, W, 3)
        output_path: path to save PNG
    """
    # Ensure RGB
    if slice_2d.ndim == 2:
        slice_rgb = gray_to_rgb(slice_2d)
    else:
        slice_rgb = slice_2d
    
    # Convert to uint8 [0, 255]
    img_uint8 = (slice_rgb * 255.0).clip(0, 255).astype(np.uint8)
    
    # Save using PIL
    img_pil = Image.fromarray(img_uint8)
    img_pil.save(output_path)


def create_qc_montage(visit_id, slices_dict, output_path):
    """
    Create QC montage showing all 16 slices for a visit
    
    Args:
        visit_id: visit identifier
        slices_dict: dict with keys 'axial', 'coronal', 'sagittal', 
                     values are lists of (H,W) arrays
        output_path: path to save montage
    """
    import matplotlib.pyplot as plt
    from matplotlib import gridspec
    
    fig = plt.figure(figsize=(16, 10))
    gs = gridspec.GridSpec(3, 8, figure=fig, hspace=0.3, wspace=0.1)
    
    # Axial (top 2 rows, 4 per row)
    axial_slices = slices_dict['axial']
    for i, slice_2d in enumerate(axial_slices):
        row = i // 4
        col = i % 4
        ax = fig.add_subplot(gs[row, col])
        ax.imshow(slice_2d, cmap='gray', origin='lower')
        ax.set_title(f'Axial {i}', fontsize=10)
        ax.axis('off')
    
    # Coronal (bottom row, first 6 positions)
    coronal_slices = slices_dict['coronal']
    for i, slice_2d in enumerate(coronal_slices):
        ax = fig.add_subplot(gs[2, i])
        ax.imshow(slice_2d, cmap='gray', origin='lower')
        ax.set_title(f'Coronal {i}', fontsize=10)
        ax.axis('off')
    
    # Sagittal (bottom row, last 2 positions)
    sagittal_slices = slices_dict['sagittal']
    for i, slice_2d in enumerate(sagittal_slices):
        ax = fig.add_subplot(gs[2, 6 + i])
        ax.imshow(slice_2d, cmap='gray', origin='lower')
        ax.set_title(f'Sagittal {i}', fontsize=10)
        ax.axis('off')
    
    fig.suptitle(f'Visit: {visit_id} (16 slices)', fontsize=14, fontweight='bold')
    plt.savefig(output_path, dpi=100, bbox_inches='tight')
    plt.close()


# ============================================================================
# MAIN PIPELINE
# ============================================================================

def run_s3_slice_generation(config=None):
    """
    Execute tri-planar slice generation pipeline
    
    Returns:
        meta_df: DataFrame with metadata for all generated slices
    """
    if config is None:
        config = S3Config()
    
    print("\n" + "="*70)
    print("üì∏ SNIPPET S3: ROI-Focused Tri-Planar Slice Generator")
    print("="*70)
    print(f"\nConfiguration:")
    print(f"  Volume source: {config.VOLUME_COL}")
    print(f"  Output directory: {config.OUTPUT_ROOT}")
    print(f"  Target image size: {config.IMG_SIZE}")
    print(f"  Slices per visit: {config.N_AXIAL} axial + {config.N_CORONAL} coronal + "
          f"{config.N_SAGITTAL} sagittal = {config.N_AXIAL + config.N_CORONAL + config.N_SAGITTAL}")
    
    # Create output directory
    os.makedirs(config.OUTPUT_ROOT, exist_ok=True)
    
    if config.SAVE_QC_EXAMPLES:
        qc_dir = os.path.join(os.path.dirname(config.OUTPUT_ROOT), "slices_qc")
        os.makedirs(qc_dir, exist_ok=True)
    
    # Load processed volumes
    df = pd.read_csv(config.PROCESSED_CSV)
    print(f"\n‚úì Loaded {len(df)} rows from {config.PROCESSED_CSV}")
    
    # Filter valid visits
    df = df[(df['preproc_ok'] == True) & (df['label'].isin([0, 1]))].copy()
    
    # Ensure domain_id exists
    if 'domain_id' not in df.columns:
        df['domain_id'] = (df['dataset'] == 'OASIS3').astype(int)
    
    # Verify volume files exist
    missing_mask = ~df[config.VOLUME_COL].apply(
        lambda x: os.path.exists(x) if isinstance(x, str) else False
    )
    
    if missing_mask.sum() > 0:
        print(f"\n‚ö†Ô∏è  Removing {missing_mask.sum()} visits with missing {config.VOLUME_COL}")
        df = df[~missing_mask].copy()
    
    print(f"\nAfter filtering:")
    print(f"  Valid visits: {len(df)}")
    print(f"  CN (label=0): {int((df['label'] == 0).sum())}")
    print(f"  AD (label=1): {int((df['label'] == 1).sum())}")
    print(f"  Expected total slices: {len(df) * (config.N_AXIAL + config.N_CORONAL + config.N_SAGITTAL)}")
    
    # Container for metadata
    meta_rows = []
    qc_counter = 0
    
    # ========================================================================
    # Iterate over visits and generate 16 slices each
    # ========================================================================
    
    print("\n" + "="*70)
    print("Generating slices...")
    print("="*70)
    
    for visit_idx, row in tqdm(df.iterrows(), total=len(df), desc="Processing visits"):
        subject_id = row['subject_id']
        visit_index = row['visit_index']
        dataset = row['dataset']
        domain_id = int(row['domain_id'])
        label = int(row['label'])
        
        visit_id = f"{subject_id}_v{visit_index}"
        vol_path = row[config.VOLUME_COL]
        
        # Load 3D ROI volume
        img = nib.load(vol_path)
        vol = img.get_fdata().astype(np.float32)
        
        if vol.ndim != 3:
            print(f"\n‚ö†Ô∏è  Skipping {visit_id}: expected 3D volume, got shape {vol.shape}")
            continue
        
        Dx, Dy, Dz = vol.shape
        
        # Compute slice indices for each plane
        axial_indices = compute_slice_indices(Dz, config.N_AXIAL)
        coronal_indices = compute_slice_indices(Dy, config.N_CORONAL)
        sagittal_indices = compute_slice_indices(Dx, config.N_SAGITTAL)
        
        global_idx = 0
        qc_slices = {'axial': [], 'coronal': [], 'sagittal': []} if config.SAVE_QC_EXAMPLES else None
        
        # -------------------- AXIAL SLICES --------------------
        for rank, z_idx in enumerate(axial_indices):
            # Extract axial slice (x-y plane at z=z_idx)
            slice_2d = vol[:, :, z_idx]
            
            # Normalize and resize
            slice_2d = normalize_slice_01(slice_2d)
            slice_2d = resize_to_img(slice_2d, config.IMG_SIZE)
            
            # Save for QC
            if qc_slices is not None and qc_counter < config.QC_VISITS:
                qc_slices['axial'].append(slice_2d)
            
            # Save as PNG
            fname = f"{visit_id}_axial_{rank:02d}.png"
            out_path = os.path.join(config.OUTPUT_ROOT, fname)
            save_slice_png(slice_2d, out_path)
            
            # Record metadata
            meta_rows.append({
                'subject_id': subject_id,
                'visit_index': visit_index,
                'visit_id': visit_id,
                'dataset': dataset,
                'domain_id': domain_id,
                'label': label,
                'plane': 'axial',
                'plane_slice_rank': rank,
                'global_slice_idx': global_idx,
                'img_path': out_path
            })
            global_idx += 1
        
        # -------------------- CORONAL SLICES --------------------
        for rank, y_idx in enumerate(coronal_indices):
            # Extract coronal slice (x-z plane at y=y_idx)
            slice_2d = vol[:, y_idx, :]
            
            slice_2d = normalize_slice_01(slice_2d)
            slice_2d = resize_to_img(slice_2d, config.IMG_SIZE)
            
            if qc_slices is not None and qc_counter < config.QC_VISITS:
                qc_slices['coronal'].append(slice_2d)
            
            fname = f"{visit_id}_coronal_{rank:02d}.png"
            out_path = os.path.join(config.OUTPUT_ROOT, fname)
            save_slice_png(slice_2d, out_path)
            
            meta_rows.append({
                'subject_id': subject_id,
                'visit_index': visit_index,
                'visit_id': visit_id,
                'dataset': dataset,
                'domain_id': domain_id,
                'label': label,
                'plane': 'coronal',
                'plane_slice_rank': rank,
                'global_slice_idx': global_idx,
                'img_path': out_path
            })
            global_idx += 1
        
        # -------------------- SAGITTAL SLICES --------------------
        for rank, x_idx in enumerate(sagittal_indices):
            # Extract sagittal slice (y-z plane at x=x_idx)
            slice_2d = vol[x_idx, :, :]
            
            slice_2d = normalize_slice_01(slice_2d)
            slice_2d = resize_to_img(slice_2d, config.IMG_SIZE)
            
            if qc_slices is not None and qc_counter < config.QC_VISITS:
                qc_slices['sagittal'].append(slice_2d)
            
            fname = f"{visit_id}_sagittal_{rank:02d}.png"
            out_path = os.path.join(config.OUTPUT_ROOT, fname)
            save_slice_png(slice_2d, out_path)
            
            meta_rows.append({
                'subject_id': subject_id,
                'visit_index': visit_index,
                'visit_id': visit_id,
                'dataset': dataset,
                'domain_id': domain_id,
                'label': label,
                'plane': 'sagittal',
                'plane_slice_rank': rank,
                'global_slice_idx': global_idx,
                'img_path': out_path
            })
            global_idx += 1
        
        # Generate QC montage for first few visits
        if config.SAVE_QC_EXAMPLES and qc_counter < config.QC_VISITS:
            qc_path = os.path.join(qc_dir, f"qc_montage_{visit_id}.png")
            create_qc_montage(visit_id, qc_slices, qc_path)
            qc_counter += 1
        
        # Sanity check
        expected_slices = config.N_AXIAL + config.N_CORONAL + config.N_SAGITTAL
        if global_idx != expected_slices:
            print(f"\n‚ö†Ô∏è  Visit {visit_id}: generated {global_idx} slices, expected {expected_slices}")
    
    # ========================================================================
    # Save metadata CSV
    # ========================================================================
    
    meta_df = pd.DataFrame(meta_rows)
    meta_df.to_csv(config.METADATA_CSV, index=False)
    
    print("\n" + "="*70)
    print("‚úÖ S3 COMPLETE: Tri-Planar Slices Generated")
    print("="*70)
    
    print(f"\nüìä Summary:")
    print(f"  Total slices generated: {len(meta_df)}")
    print(f"  Visits processed: {len(df)}")
    print(f"  Slices per visit: {len(meta_df) / len(df):.1f} (expected: 16)")
    print(f"\n  Slice breakdown:")
    print(f"    Axial: {int((meta_df['plane'] == 'axial').sum())}")
    print(f"    Coronal: {int((meta_df['plane'] == 'coronal').sum())}")
    print(f"    Sagittal: {int((meta_df['plane'] == 'sagittal').sum())}")
    print(f"\n  Label distribution:")
    print(f"    CN visits: {int((meta_df.groupby('visit_id')['label'].first() == 0).sum())}")
    print(f"    AD visits: {int((meta_df.groupby('visit_id')['label'].first() == 1).sum())}")
    print(f"\n  Output files:")
    print(f"    Images: {config.OUTPUT_ROOT}/")
    print(f"    Metadata: {config.METADATA_CSV}")
    
    if config.SAVE_QC_EXAMPLES:
        print(f"    QC montages: {qc_dir}/ ({qc_counter} examples)")
    
    # Show sample metadata
    print("\nüìã Sample metadata (first 8 slices of first visit):")
    sample_cols = ['visit_id', 'plane', 'plane_slice_rank', 'global_slice_idx', 'label', 'domain_id']
    print(meta_df[sample_cols].head(8).to_string(index=False))
    
    # Validate
    print("\n" + "="*70)
    print("VALIDATION")
    print("="*70)
    
    # Check all files exist
    files_exist = meta_df['img_path'].apply(os.path.exists)
    n_missing = (~files_exist).sum()
    
    if n_missing > 0:
        print(f"‚ùå FAIL: {n_missing} image files missing!")
    else:
        print(f"‚úÖ PASS: All {len(meta_df)} image files exist")
    
    # Check slice counts per visit
    slices_per_visit = meta_df.groupby('visit_id').size()
    expected = config.N_AXIAL + config.N_CORONAL + config.N_SAGITTAL
    incorrect_counts = slices_per_visit[slices_per_visit != expected]
    
    if len(incorrect_counts) > 0:
        print(f"‚ö†Ô∏è  WARNING: {len(incorrect_counts)} visits have incorrect slice counts")
        print(incorrect_counts.head())
    else:
        print(f"‚úÖ PASS: All visits have exactly {expected} slices")
    
    return meta_df


# ============================================================================
# EXECUTE
# ============================================================================

if __name__ == "__main__":
    meta_df = run_s3_slice_generation()
    
    print("\n" + "="*70)
    print("üé¨ Ready for downstream tasks:")
    print("="*70)
    print("  Next: S5 - DSBN backbone + visit-level attention")
    print("  Then: S6 - Focal loss training")
    print("  Then: S7 - XAI (Grad-CAM, hippocampal Dice)")



üì∏ SNIPPET S3: ROI-Focused Tri-Planar Slice Generator

Configuration:
  Volume source: hippo_ext_path
  Output directory: /kaggle/working/slices_roi
  Target image size: (224, 224)
  Slices per visit: 8 axial + 6 coronal + 2 sagittal = 16

‚úì Loaded 575 rows from /kaggle/working/processed_volumes.csv

After filtering:
  Valid visits: 575
  CN (label=0): 499
  AD (label=1): 76
  Expected total slices: 9200

Generating slices...


Processing visits: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 575/575 [00:55<00:00, 10.44it/s]


‚úÖ S3 COMPLETE: Tri-Planar Slices Generated

üìä Summary:
  Total slices generated: 9200
  Visits processed: 575
  Slices per visit: 16.0 (expected: 16)

  Slice breakdown:
    Axial: 4600
    Coronal: 3450
    Sagittal: 1150

  Label distribution:
    CN visits: 499
    AD visits: 76

  Output files:
    Images: /kaggle/working/slices_roi/
    Metadata: /kaggle/working/slices_metadata_ROI.csv
    QC montages: /kaggle/working/slices_qc/ (3 examples)

üìã Sample metadata (first 8 slices of first visit):
    visit_id plane  plane_slice_rank  global_slice_idx  label  domain_id
OAS2_0079_v2 axial                 0                 0      1          0
OAS2_0079_v2 axial                 1                 1      1          0
OAS2_0079_v2 axial                 2                 2      1          0
OAS2_0079_v2 axial                 3                 3      1          0
OAS2_0079_v2 axial                 4                 4      1          0
OAS2_0079_v2 axial                 5              




In [4]:
"""
SNIPPET S4: Subject-Level Multi-Site Splits (LOCKED)

Creates subject-level stratified splits with locked test set:
- 80/20 train+val/test split (stratified by site √ó diagnosis)
- 5-fold CV on train+val subjects
- Propagates splits to visit and slice levels

Input:
    /kaggle/working/processed_volumes.csv (from S2)
    /kaggle/working/slices_metadata_ROI.csv (from S3)

Output:
    /kaggle/working/visits_with_splits.csv (visit-level splits)
    /kaggle/working/slices_metadata_ROI_splits.csv (slice-level splits)

Design:
    - Subject-level splits (no leakage)
    - Stratified by site (OASIS2/3) √ó diagnosis (CN/AD)
    - Fixed random seed for reproducibility
    - Test set locked across all experiments
"""

import os
import numpy as np
import pandas as pd
from sklearn.model_selection import StratifiedKFold, train_test_split

# ============================================================================
# CONFIGURATION
# ============================================================================

RANDOM_SEED = 20250126  # Fixed seed for reproducibility


# ============================================================================
# MAIN PIPELINE
# ============================================================================

def build_s4_splits(
    processed_csv="/kaggle/working/processed_volumes.csv",
    slices_csv="/kaggle/working/slices_metadata_ROI.csv",
):
    """
    Build subject-level stratified splits with locked test set
    
    Returns:
        visits_df: DataFrame with visit-level split information
        slices_merged: DataFrame with slice-level split information
    """
    
    print("\n" + "="*70)
    print("üß© SNIPPET S4: Subject-Level Multi-Site Splits (LOCKED)")
    print("="*70)
    print(f"\nRandom seed: {RANDOM_SEED}")

    # ========================================================================
    # 1. Load visit-level table from S2
    # ========================================================================
    
    print("\nStep 1: Loading visit-level data...")
    visits_df = pd.read_csv(processed_csv)
    print(f"  Loaded {len(visits_df)} visits from {processed_csv}")

    # Keep only valid CN/AD visits with successful preprocessing
    visits_df = visits_df[
        (visits_df["preproc_ok"] == True) & 
        (visits_df["label"].isin([0, 1]))
    ].copy()

    # Define visit_id consistently with S3
    visits_df["visit_id"] = visits_df.apply(
        lambda r: f"{r['subject_id']}_v{r['visit_index']}", axis=1
    )

    # Domain id: 0 = OASIS2, 1 = OASIS3
    if "domain_id" not in visits_df.columns:
        visits_df["domain_id"] = (visits_df["dataset"] == "OASIS3").astype(int)

    print(f"  After filtering: {len(visits_df)} valid CN/AD visits")
    print(f"    CN (label=0): {int((visits_df['label'] == 0).sum())}")
    print(f"    AD (label=1): {int((visits_df['label'] == 1).sum())}")

    # ========================================================================
    # 2. Build subject-level table for stratified splitting
    # ========================================================================
    
    print("\nStep 2: Building subject-level table...")
    
    subjects_df = (
        visits_df.groupby("subject_id")
        .agg(
            label=("label", "first"),  # Use first visit's label
            dataset=("dataset", lambda x: x.mode()[0]),  # Most common dataset
            n_visits=("visit_id", "count")
        )
        .reset_index()
    )

    # Stratification group key: label_site (e.g., "0_OASIS2", "1_OASIS3")
    subjects_df["group"] = (
        subjects_df["label"].astype(str) + "_" + subjects_df["dataset"].astype(str)
    )

    print(f"  Total subjects: {len(subjects_df)}")
    print(f"  Stratification groups:")
    for group, count in subjects_df["group"].value_counts().items():
        print(f"    {group}: {count} subjects")

    # ========================================================================
    # 3. 80/20 subject-level split ‚Üí locked test set
    # ========================================================================
    
    print("\nStep 3: Creating 80/20 train+val/test split (subject-level)...")
    
    trainval_subj, test_subj = train_test_split(
        subjects_df,
        test_size=0.2,
        random_state=RANDOM_SEED,
        stratify=subjects_df["group"]
    )

    print(f"  Train+Val subjects: {len(trainval_subj)}")
    print(f"  Test subjects:      {len(test_subj)}")
    
    # Show test set composition
    print(f"\n  Test set composition:")
    for group, count in test_subj["group"].value_counts().items():
        print(f"    {group}: {count} subjects")

    trainval_subj_ids = set(trainval_subj["subject_id"])
    test_subj_ids = set(test_subj["subject_id"])

    # Safety check: no overlap
    assert trainval_subj_ids.isdisjoint(test_subj_ids), \
        "‚ùå CRITICAL: Overlap between trainval and test subjects!"
    print("  ‚úì No overlap between train+val and test subjects")

    # ========================================================================
    # 4. 5-fold stratified CV on train+val subjects (subject-level)
    # ========================================================================
    
    print("\nStep 4: Creating 5-fold CV on train+val subjects...")
    
    skf = StratifiedKFold(
        n_splits=5, 
        shuffle=True, 
        random_state=RANDOM_SEED
    )

    trainval_subj = trainval_subj.reset_index(drop=True)
    fold_ids = np.full(len(trainval_subj), -1, dtype=int)

    for fold_idx, (_, val_index) in enumerate(
        skf.split(trainval_subj["subject_id"], trainval_subj["group"])
    ):
        fold_ids[val_index] = fold_idx

    trainval_subj["fold_id"] = fold_ids

    # Sanity check
    assert not (trainval_subj["fold_id"] == -1).any(), \
        "‚ùå CRITICAL: Unassigned fold_id found!"

    print(f"  ‚úì All train+val subjects assigned to folds")
    print(f"\n  Fold distribution (subject-level):")
    fold_counts = trainval_subj["fold_id"].value_counts().sort_index()
    for fold_id, count in fold_counts.items():
        print(f"    Fold {fold_id}: {count} subjects")

    # ========================================================================
    # 5. Map subject-level splits back to visits
    # ========================================================================
    
    print("\nStep 5: Mapping splits to visit level...")
    
    # Initialize split columns
    visits_df["split"] = "trainval"
    visits_df["fold_id"] = -1

    # Mark test visits
    visits_df.loc[visits_df["subject_id"].isin(test_subj_ids), "split"] = "test"

    # Assign fold_id for trainval visits
    fold_map = dict(zip(trainval_subj["subject_id"], trainval_subj["fold_id"]))
    mask_trainval = visits_df["subject_id"].isin(trainval_subj_ids)
    visits_df.loc[mask_trainval, "fold_id"] = (
        visits_df.loc[mask_trainval, "subject_id"].map(fold_map)
    )

    # Summary
    print(f"\n  Visit-level split summary:")
    for split, count in visits_df["split"].value_counts().items():
        print(f"    {split}: {count} visits")

    print(f"\n  Fold distribution (trainval visits only):")
    fold_visit_counts = (
        visits_df[visits_df["split"] == "trainval"]["fold_id"]
        .value_counts()
        .sort_index()
    )
    for fold_id, count in fold_visit_counts.items():
        print(f"    Fold {fold_id}: {count} visits")

    # Save visit-level splits
    visits_out = "/kaggle/working/visits_with_splits.csv"
    visits_df.to_csv(visits_out, index=False)
    print(f"\n  ‚úì Saved visit-level splits to: {visits_out}")

    # ========================================================================
    # 6. Propagate splits to slice-level metadata
    # ========================================================================
    
    print("\nStep 6: Propagating splits to slice level...")
    
    slices_df = pd.read_csv(slices_csv)
    print(f"  Loaded {len(slices_df)} slices from {slices_csv}")

    # Join on visit_id to get split/fold info
    slices_merged = slices_df.merge(
        visits_df[["visit_id", "split", "fold_id"]],
        on="visit_id",
        how="inner"
    )

    print(f"  After merge: {len(slices_merged)} slices with split/fold info")

    # Sanity: each visit still has 16 slices
    slices_per_visit = slices_merged.groupby("visit_id").size()
    expected_slices = 16
    
    incorrect_visits = slices_per_visit[slices_per_visit != expected_slices]
    if len(incorrect_visits) > 0:
        print(f"  ‚ö†Ô∏è  WARNING: {len(incorrect_visits)} visits do not have {expected_slices} slices")
        print(f"     Expected: {expected_slices}, found range: [{slices_per_visit.min()}, {slices_per_visit.max()}]")
    else:
        print(f"  ‚úì All visits have exactly {expected_slices} slices")

    # Save slice-level splits
    slices_out = "/kaggle/working/slices_metadata_ROI_splits.csv"
    slices_merged.to_csv(slices_out, index=False)
    print(f"  ‚úì Saved slice-level splits to: {slices_out}")

    # ========================================================================
    # 7. Final summary and validation
    # ========================================================================
    
    print("\n" + "="*70)
    print("‚úÖ S4 COMPLETE: Subject-Level Splits Ready")
    print("="*70)

    print(f"\nüìä Summary:")
    
    # Subject-level
    print(f"\n  Subject-level:")
    print(f"    Total subjects: {len(subjects_df)}")
    print(f"    Train+Val: {len(trainval_subj)} ({len(trainval_subj)/len(subjects_df)*100:.1f}%)")
    print(f"    Test: {len(test_subj)} ({len(test_subj)/len(subjects_df)*100:.1f}%)")

    # Visit-level
    print(f"\n  Visit-level:")
    total_visits = len(visits_df)
    trainval_visits = int((visits_df["split"] == "trainval").sum())
    test_visits = int((visits_df["split"] == "test").sum())
    print(f"    Total visits: {total_visits}")
    print(f"    Train+Val: {trainval_visits} ({trainval_visits/total_visits*100:.1f}%)")
    print(f"    Test: {test_visits} ({test_visits/total_visits*100:.1f}%)")

    # Slice-level
    print(f"\n  Slice-level:")
    total_slices = len(slices_merged)
    trainval_slices = int((slices_merged["split"] == "trainval").sum())
    test_slices = int((slices_merged["split"] == "test").sum())
    print(f"    Total slices: {total_slices}")
    print(f"    Train+Val: {trainval_slices} ({trainval_slices/total_slices*100:.1f}%)")
    print(f"    Test: {test_slices} ({test_slices/total_slices*100:.1f}%)")

    # Label distribution in test set
    print(f"\n  Test set label distribution:")
    test_visits_df = visits_df[visits_df["split"] == "test"]
    cn_test = int((test_visits_df["label"] == 0).sum())
    ad_test = int((test_visits_df["label"] == 1).sum())
    print(f"    CN: {cn_test} ({cn_test/len(test_visits_df)*100:.1f}%)")
    print(f"    AD: {ad_test} ({ad_test/len(test_visits_df)*100:.1f}%)")

    # Fold balance
    print(f"\n  Fold balance (train+val slices):")
    trainval_fold_counts = (
        slices_merged[slices_merged["split"] == "trainval"]["fold_id"]
        .value_counts()
        .sort_index()
    )
    for fold_id, count in trainval_fold_counts.items():
        pct = count / trainval_slices * 100
        print(f"    Fold {fold_id}: {count} slices ({pct:.1f}%)")

    # ========================================================================
    # Validation checks
    # ========================================================================
    
    print("\n" + "="*70)
    print("VALIDATION")
    print("="*70)

    checks_passed = 0
    checks_total = 0

    # Check 1: No subject overlap
    checks_total += 1
    if trainval_subj_ids.isdisjoint(test_subj_ids):
        print("‚úÖ No subject leakage between train+val and test")
        checks_passed += 1
    else:
        print("‚ùå CRITICAL: Subject leakage detected!")

    # Check 2: All visits assigned to split
    checks_total += 1
    unassigned_visits = visits_df[visits_df["split"].isna()]
    if len(unassigned_visits) == 0:
        print("‚úÖ All visits assigned to split")
        checks_passed += 1
    else:
        print(f"‚ùå {len(unassigned_visits)} visits not assigned to split")

    # Check 3: All trainval visits have fold_id
    checks_total += 1
    trainval_no_fold = visits_df[
        (visits_df["split"] == "trainval") & 
        (visits_df["fold_id"] == -1)
    ]
    if len(trainval_no_fold) == 0:
        print("‚úÖ All train+val visits assigned to fold")
        checks_passed += 1
    else:
        print(f"‚ùå {len(trainval_no_fold)} train+val visits missing fold_id")

    # Check 4: Test visits have fold_id = -1
    checks_total += 1
    test_with_fold = visits_df[
        (visits_df["split"] == "test") & 
        (visits_df["fold_id"] != -1)
    ]
    if len(test_with_fold) == 0:
        print("‚úÖ Test visits correctly marked (fold_id=-1)")
        checks_passed += 1
    else:
        print(f"‚ùå {len(test_with_fold)} test visits have invalid fold_id")

    # Check 5: Slice counts per visit
    checks_total += 1
    if len(incorrect_visits) == 0:
        print("‚úÖ All visits have correct slice count (16)")
        checks_passed += 1
    else:
        print(f"‚ö†Ô∏è  {len(incorrect_visits)} visits have incorrect slice count")

    print(f"\n{'='*70}")
    if checks_passed == checks_total:
        print(f"‚úÖ ALL VALIDATION CHECKS PASSED ({checks_passed}/{checks_total})")
    else:
        print(f"‚ö†Ô∏è  VALIDATION: {checks_passed}/{checks_total} checks passed")

    return visits_df, slices_merged


# ============================================================================
# EXECUTE
# ============================================================================

if __name__ == "__main__":
    visits_df, slices_df = build_s4_splits()
    
    print("\n" + "="*70)
    print("üé¨ Ready for downstream tasks:")
    print("="*70)
    print("  Next: S5 - DSBN backbone + visit-level attention")
    print("  Then: S6 - Focal loss training")
    print("  Then: S7 - XAI (Grad-CAM, hippocampal Dice)")
    print("\n  Files ready:")
    print("    - visits_with_splits.csv")
    print("    - slices_metadata_ROI_splits.csv")



üß© SNIPPET S4: Subject-Level Multi-Site Splits (LOCKED)

Random seed: 20250126

Step 1: Loading visit-level data...
  Loaded 575 visits from /kaggle/working/processed_volumes.csv
  After filtering: 575 valid CN/AD visits
    CN (label=0): 499
    AD (label=1): 76

Step 2: Building subject-level table...
  Total subjects: 346
  Stratification groups:
    0_OASIS3: 202 subjects
    0_OASIS2: 86 subjects
    1_OASIS3: 33 subjects
    1_OASIS2: 25 subjects

Step 3: Creating 80/20 train+val/test split (subject-level)...
  Train+Val subjects: 276
  Test subjects:      70

  Test set composition:
    0_OASIS3: 41 subjects
    0_OASIS2: 17 subjects
    1_OASIS3: 7 subjects
    1_OASIS2: 5 subjects
  ‚úì No overlap between train+val and test subjects

Step 4: Creating 5-fold CV on train+val subjects...
  ‚úì All train+val subjects assigned to folds

  Fold distribution (subject-level):
    Fold 0: 56 subjects
    Fold 1: 55 subjects
    Fold 2: 55 subjects
    Fold 3: 55 subjects
    Fold 4:

In [5]:
"""
SNIPPET S5: Dataset + Model Architecture (DSBN + Attention) - FIXED FOR S6

KEY CHANGE: Dataset now returns images in [0, 1] WITHOUT normalization
(S6 training loop will handle normalization + augmentation)
"""

import os
import numpy as np
import pandas as pd
from pathlib import Path
import warnings
warnings.filterwarnings('ignore')

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from PIL import Image
import torchvision.transforms as transforms
import torchvision.models as models

# ============================================================================
# CONFIGURATION
# ============================================================================

class S5Config:
    """Configuration for dataset and model"""
    
    # Data
    SLICES_CSV = "/kaggle/working/slices_metadata_ROI_splits.csv"
    
    # Model architecture
    BACKBONE = "vgg16_bn"  # Options: 'vgg16_bn', 'resnet50', 'densenet121'
    FEATURE_DIM = 512      # VGG16: 512, ResNet50: 2048, DenseNet121: 1024
    ATTENTION_DIM = 128
    NUM_DOMAINS = 2        # OASIS2=0, OASIS3=1
    USE_DSBN = True
    
    # Training
    BATCH_SIZE = 8
    NUM_WORKERS = 4
    
    # Image preprocessing
    IMG_SIZE = 224
    IMAGENET_MEAN = [0.485, 0.456, 0.406]
    IMAGENET_STD = [0.229, 0.224, 0.225]


# ============================================================================
# DATASET: TriPlanarVisitDataset (FIXED FOR S6)
# ============================================================================

class TriPlanarVisitDataset(Dataset):
    """
    Dataset for tri-planar visit-level classification
    
    FIXED FOR S6: Returns images in [0, 1] WITHOUT normalization
    (S6 training loop handles normalization + label-aware augmentation)
    
    Loads 16 slices per visit (8 axial + 6 coronal + 2 sagittal)
    Supports train/val/test splits with fold-based CV
    
    Args:
        csv_path: path to slices_metadata_ROI_splits.csv
        mode: 'train', 'val', or 'test'
        fold_id: fold number (0-4) for train/val split, ignored for test
        transform: NOT USED (kept for compatibility)
        config: S5Config instance
    """
    
    def __init__(self, csv_path, mode, fold_id=None, transform=None, config=None):
        self.csv_path = csv_path
        self.mode = mode
        self.fold_id = fold_id
        self.config = config or S5Config()
        
        # Load metadata
        self.df = pd.read_csv(csv_path)
        
        # Filter by split and fold
        if mode == "train":
            assert fold_id is not None, "fold_id required for train mode"
            self.df = self.df[
                (self.df["split"] == "trainval") & 
                (self.df["fold_id"] != fold_id)
            ].copy()
        elif mode == "val":
            assert fold_id is not None, "fold_id required for val mode"
            self.df = self.df[
                (self.df["split"] == "trainval") & 
                (self.df["fold_id"] == fold_id)
            ].copy()
        elif mode == "test":
            self.df = self.df[self.df["split"] == "test"].copy()
        else:
            raise ValueError(f"mode must be 'train', 'val', or 'test', got {mode}")
        
        # Group by visit_id
        self.visits = []
        grouped = self.df.groupby("visit_id")
        
        for visit_id, group in grouped:
            group_sorted = group.sort_values("global_slice_idx")
            
            if len(group_sorted) != 16:
                print(f"Warning: Visit {visit_id} has {len(group_sorted)} slices, skipping")
                continue
            
            self.visits.append({
                "visit_id": visit_id,
                "rows": group_sorted,
                "label": int(group_sorted["label"].iloc[0]),
                "domain_id": int(group_sorted["domain_id"].iloc[0]),
            })
        
        print(f"  {mode.capitalize()} dataset: {len(self.visits)} visits")
        if len(self.visits) > 0:
            cn_count = sum(1 for v in self.visits if v["label"] == 0)
            ad_count = sum(1 for v in self.visits if v["label"] == 1)
            print(f"    CN: {cn_count}, AD: {ad_count}")
    
    def __len__(self):
        return len(self.visits)
    
    def __getitem__(self, idx):
        """
        CRITICAL CHANGE: Returns images in [0, 1] WITHOUT normalization
        
        S6 training loop will handle:
        - Label-aware augmentation
        - ImageNet normalization
        """
        visit = self.visits[idx]
        rows = visit["rows"]
        
        # Load all 16 slices
        imgs = []
        for _, row in rows.iterrows():
            img_path = row["img_path"]
            
            # Load image and convert to tensor [0, 1]
            img = Image.open(img_path).convert("RGB")
            img = transforms.ToTensor()(img)  # Converts to [0, 1]
            
            imgs.append(img)
        
        # Stack to (K, C, H, W)
        imgs = torch.stack(imgs, dim=0)  # (16, 3, 224, 224) in [0, 1]
        
        label = torch.tensor(visit["label"], dtype=torch.float32)
        domain_id = torch.tensor(visit["domain_id"], dtype=torch.long)
        
        return imgs, label, domain_id


# ============================================================================
# DOMAIN-SPECIFIC BATCH NORMALIZATION
# ============================================================================

class DomainSpecificBN2d(nn.Module):
    """Domain-Specific Batch Normalization for multi-site adaptation"""
    
    def __init__(self, num_features, num_domains=2):
        super().__init__()
        self.num_domains = num_domains
        self.num_features = num_features
        
        self.bns = nn.ModuleList([
            nn.BatchNorm2d(num_features) for _ in range(num_domains)
        ])
    
    def forward(self, x, domain_id):
        out = torch.zeros_like(x)
        
        for d in range(self.num_domains):
            mask = (domain_id == d)
            if mask.any():
                out[mask] = self.bns[d](x[mask])
        
        return out


# ============================================================================
# VISIT-LEVEL ATTENTION
# ============================================================================

class VisitAttention(nn.Module):
    """Attention mechanism for aggregating K slice features into visit embedding"""
    
    def __init__(self, in_dim, attn_dim=128):
        super().__init__()
        
        self.W_a = nn.Linear(in_dim, attn_dim)
        self.b_a = nn.Parameter(torch.zeros(attn_dim))
        self.w_a = nn.Linear(attn_dim, 1, bias=False)
    
    def forward(self, F):
        H = torch.tanh(self.W_a(F) + self.b_a)
        s = self.w_a(H).squeeze(-1)
        alpha = torch.softmax(s, dim=1)
        v = torch.sum(alpha.unsqueeze(-1) * F, dim=1)
        
        return v, alpha


# ============================================================================
# VISIT CLASSIFIER
# ============================================================================

class VisitClassifier(nn.Module):
    """Simple linear classifier for visit embedding"""
    
    def __init__(self, in_dim):
        super().__init__()
        self.fc = nn.Linear(in_dim, 1)
    
    def forward(self, v):
        logits = self.fc(v).squeeze(-1)
        return logits


# ============================================================================
# COMPLETE MODEL: TriPlanarADNet
# ============================================================================

class TriPlanarADNet(nn.Module):
    """Complete tri-planar AD classification network"""
    
    def __init__(self, backbone_name='vgg16_bn', num_domains=2, 
                 use_dsbn=True, pretrained=True, config=None):
        super().__init__()
        
        self.config = config or S5Config()
        self.backbone_name = backbone_name
        self.use_dsbn = use_dsbn
        
        # Load backbone
        self.backbone, self.feature_dim = self._load_backbone(backbone_name, pretrained)
        
        # Domain-specific BN
        if use_dsbn:
            self.dsbn = DomainSpecificBN2d(
                num_features=self.feature_dim,
                num_domains=num_domains
            )
        else:
            self.dsbn = None
        
        # Visit-level attention
        self.attention = VisitAttention(
            in_dim=self.feature_dim,
            attn_dim=self.config.ATTENTION_DIM
        )
        
        # Classifier
        self.classifier = VisitClassifier(in_dim=self.feature_dim)
    
    def _load_backbone(self, backbone_name, pretrained):
        if backbone_name == 'vgg16_bn':
            vgg = models.vgg16_bn(pretrained=pretrained)
            backbone = vgg.features
            feature_dim = 512
        elif backbone_name == 'resnet50':
            resnet = models.resnet50(pretrained=pretrained)
            backbone = nn.Sequential(*list(resnet.children())[:-2])
            feature_dim = 2048
        elif backbone_name == 'densenet121':
            densenet = models.densenet121(pretrained=pretrained)
            backbone = densenet.features
            feature_dim = 1024
        else:
            raise ValueError(f"Unknown backbone: {backbone_name}")
        
        return backbone, feature_dim
    
    def forward(self, imgs, domain_id):
        B, K, C, H, W = imgs.shape
        
        # Reshape to process all slices
        x = imgs.view(B * K, C, H, W)
        
        # Repeat domain_id for each slice
        domain_id_slices = domain_id.unsqueeze(1).repeat(1, K).view(-1)
        
        # Extract conv features
        feat = self.backbone(x)
        
        # Apply DSBN
        if self.use_dsbn:
            feat = self.dsbn(feat, domain_id_slices)
        
        # Global average pooling
        feat_pooled = feat.mean(dim=[2, 3])
        
        # Reshape to (B, K, D)
        F = feat_pooled.view(B, K, -1)
        
        # Visit-level attention
        v, alpha = self.attention(F)
        
        # Classification
        logits = self.classifier(v)
        
        return logits, alpha


# ============================================================================
# TESTING: Verify Components
# ============================================================================

def test_s5_components():
    """Test dataset and model components"""
    
    print("\n" + "="*70)
    print("üß™ Testing S5 Components (Fixed for S6)")
    print("="*70)
    
    config = S5Config()
    
    # Test 1: Dataset
    print("\nTest 1: Dataset loading...")
    try:
        dataset = TriPlanarVisitDataset(
            csv_path=config.SLICES_CSV,
            mode="train",
            fold_id=0,
            transform=None,  # Not used anymore
            config=config
        )
        
        imgs, label, domain_id = dataset[0]
        print(f"  ‚úì Dataset shape: imgs={imgs.shape}, label={label}, domain_id={domain_id}")
        print(f"  ‚úì Image value range: [{imgs.min():.3f}, {imgs.max():.3f}] (should be ~[0, 1])")
        assert imgs.shape == (16, 3, 224, 224), f"Expected (16, 3, 224, 224), got {imgs.shape}"
        assert imgs.min() >= 0 and imgs.max() <= 1, "Images should be in [0, 1] range"
        print("  ‚úì Dataset test passed")
    except Exception as e:
        print(f"  ‚ùå Dataset test failed: {e}")
        import traceback
        traceback.print_exc()
        return False
    
    # Test 2: DSBN
    print("\nTest 2: Domain-Specific BN...")
    try:
        dsbn = DomainSpecificBN2d(num_features=512, num_domains=2)
        x = torch.randn(32, 512, 7, 7)
        domain_id = torch.randint(0, 2, (32,))
        out = dsbn(x, domain_id)
        print(f"  ‚úì DSBN output shape: {out.shape}")
        assert out.shape == x.shape
        print("  ‚úì DSBN test passed")
    except Exception as e:
        print(f"  ‚ùå DSBN test failed: {e}")
        return False
    
    # Test 3: Attention
    print("\nTest 3: Visit Attention...")
    try:
        attention = VisitAttention(in_dim=512, attn_dim=128)
        F = torch.randn(4, 16, 512)
        v, alpha = attention(F)
        print(f"  ‚úì Attention outputs: v={v.shape}, alpha={alpha.shape}")
        assert v.shape == (4, 512) and alpha.shape == (4, 16)
        assert torch.allclose(alpha.sum(dim=1), torch.ones(4), atol=1e-6)
        print("  ‚úì Attention test passed")
    except Exception as e:
        print(f"  ‚ùå Attention test failed: {e}")
        return False
    
    # Test 4: Full Model
    print("\nTest 4: Complete Model...")
    try:
        model = TriPlanarADNet(
            backbone_name='vgg16_bn',
            num_domains=2,
            use_dsbn=True,
            pretrained=False,
            config=config
        )
        
        imgs = torch.rand(2, 16, 3, 224, 224)  # [0, 1] range
        domain_id = torch.randint(0, 2, (2,))
        
        logits, alpha = model(imgs, domain_id)
        print(f"  ‚úì Model outputs: logits={logits.shape}, alpha={alpha.shape}")
        assert logits.shape == (2,) and alpha.shape == (2, 16)
        print("  ‚úì Model test passed")
    except Exception as e:
        print(f"  ‚ùå Model test failed: {e}")
        import traceback
        traceback.print_exc()
        return False
    
    print("\n" + "="*70)
    print("‚úÖ All S5 tests passed!")
    print("="*70)
    
    return True


# ============================================================================
# EXECUTE
# ============================================================================

if __name__ == "__main__":
    success = test_s5_components()
    
    if success:
        print("\n" + "="*70)
        print("üé¨ S5 Ready for S6 Training")
        print("="*70)
        print("\nKey change:")
        print("  - Dataset now returns images in [0, 1] WITHOUT normalization")
        print("  - S6 training loop will handle normalization + augmentation")
        print("\nNext: Run S6 training script")


üß™ Testing S5 Components (Fixed for S6)

Test 1: Dataset loading...
  Train dataset: 370 visits
    CN: 319, AD: 51
  ‚úì Dataset shape: imgs=torch.Size([16, 3, 224, 224]), label=0.0, domain_id=0
  ‚úì Image value range: [0.000, 0.988] (should be ~[0, 1])
  ‚úì Dataset test passed

Test 2: Domain-Specific BN...
  ‚úì DSBN output shape: torch.Size([32, 512, 7, 7])
  ‚úì DSBN test passed

Test 3: Visit Attention...
  ‚úì Attention outputs: v=torch.Size([4, 512]), alpha=torch.Size([4, 16])
  ‚úì Attention test passed

Test 4: Complete Model...
  ‚úì Model outputs: logits=torch.Size([2]), alpha=torch.Size([2, 16])
  ‚úì Model test passed

‚úÖ All S5 tests passed!

üé¨ S5 Ready for S6 Training

Key change:
  - Dataset now returns images in [0, 1] WITHOUT normalization
  - S6 training loop will handle normalization + augmentation

Next: Run S6 training script


In [6]:
"""
SNIPPET S6: STABLE BASELINE (Stage 1 Only + AD Duplication)

STRATEGY (Step 1 from roadmap):
1. Stage 1 ONLY (skip Stage 2 - unstable on small dataset)
2. BCE with pos_weight (stable)
3. AD duplication (not sampler)
4. Label-aware augmentation ON
5. Threshold search for evaluation

Goal: Establish stable baseline (AUC ‚â• 0.70, BalAcc ‚â• 0.70) before:
- Trying Focal Loss
- Adding ResNet50/DenseNet121
- Building ensemble
"""

import os
import numpy as np
import pandas as pd
from pathlib import Path
import warnings
warnings.filterwarnings('ignore')

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

from PIL import Image
import torchvision.transforms as transforms
from sklearn.metrics import roc_auc_score
from tqdm import tqdm

# ============================================================================
# CONFIGURATION
# ============================================================================

class S6Config:
    """Training configuration - Stable Baseline"""
    
    DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    SLICES_CSV = "/kaggle/working/slices_metadata_ROI_splits.csv"
    
    # Training (Stage 1 only for now)
    STAGE1_EPOCHS = 10
    USE_STAGE2 = False  # Disabled for stable baseline
    
    BATCH_SIZE = 4
    NUM_WORKERS = 2
    GRAD_CLIP = 5.0
    
    STAGE1_LR = 1e-3
    STAGE1_WEIGHT_DECAY = 1e-4
    STAGE1_PATIENCE = 8
    
    # AD duplication (replaces sampler)
    AD_DUP_FACTOR = 2  # Duplicate each AD visit 2x in training
    
    # Augmentation
    USE_AUG = True
    AUG_ROTATION = 15
    AUG_TRANSLATE = 0.02
    AUG_SCALE = (0.9, 1.1)
    AUG_BRIGHTNESS = 0.2
    AUG_CONTRAST = 0.2
    AUG_NOISE_SIGMA = 0.02
    
    BACKBONE = "vgg16_bn"
    NUM_FOLDS = 5
    
    CHECKPOINT_DIR = "/kaggle/working/checkpoints"
    RESULTS_CSV = "/kaggle/working/s6_results_baseline.csv"


# ============================================================================
# AUGMENTATION
# ============================================================================

IMAGENET_NORM = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225]
)

def create_ad_augmentation(config):
    return transforms.Compose([
        transforms.RandomRotation(degrees=config.AUG_ROTATION),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomAffine(
            degrees=0,
            translate=(config.AUG_TRANSLATE, config.AUG_TRANSLATE),
            scale=config.AUG_SCALE
        ),
        transforms.ColorJitter(
            brightness=config.AUG_BRIGHTNESS,
            contrast=config.AUG_CONTRAST
        ),
    ])


def create_cn_augmentation(config):
    return transforms.Compose([
        transforms.RandomRotation(degrees=config.AUG_ROTATION // 2),
        transforms.RandomHorizontalFlip(p=0.3),
    ])


def add_gaussian_noise(x, sigma=0.02):
    if sigma <= 0:
        return x
    noise = torch.randn_like(x) * sigma
    return torch.clamp(x + noise, 0.0, 1.0)


def augment_batch_label_aware(imgs, labels, config):
    """Label-aware augmentation: heavier for AD, lighter for CN"""
    B, K, C, H, W = imgs.shape
    imgs_aug = torch.zeros_like(imgs)
    
    ad_transform = create_ad_augmentation(config)
    cn_transform = create_cn_augmentation(config)
    
    for b in range(B):
        label = labels[b].item()
        
        if label == 1:  # AD
            transform = ad_transform
            use_noise = True
        else:  # CN
            transform = cn_transform
            use_noise = False
        
        for k in range(K):
            slice_img = imgs[b, k]
            pil_img = transforms.functional.to_pil_image(slice_img)
            pil_img = transform(pil_img)
            slice_tensor = transforms.functional.to_tensor(pil_img)
            
            if use_noise:
                slice_tensor = add_gaussian_noise(slice_tensor, sigma=config.AUG_NOISE_SIGMA)
            
            slice_tensor = IMAGENET_NORM(slice_tensor)
            imgs_aug[b, k] = slice_tensor
    
    return imgs_aug


def normalize_batch(imgs):
    B, K, C, H, W = imgs.shape
    imgs_norm = torch.zeros_like(imgs)
    
    for b in range(B):
        for k in range(K):
            imgs_norm[b, k] = IMAGENET_NORM(imgs[b, k])
    
    return imgs_norm


# ============================================================================
# METRICS WITH THRESHOLD SEARCH
# ============================================================================

def compute_metrics(logits, labels, search_best_thr=True):
    """Compute metrics with optimal threshold search"""
    probs = 1 / (1 + np.exp(-logits))
    y = labels.astype(int)
    
    if search_best_thr:
        thresholds = np.linspace(0.05, 0.95, 19)
        best_bal_acc = -1.0
        best_thr = 0.5
        best_stats = None
        
        for thr in thresholds:
            preds = (probs >= thr).astype(int)
            
            TP = ((y == 1) & (preds == 1)).sum()
            TN = ((y == 0) & (preds == 0)).sum()
            FP = ((y == 0) & (preds == 1)).sum()
            FN = ((y == 1) & (preds == 0)).sum()
            
            sens = TP / (TP + FN + 1e-8)
            spec = TN / (TN + FP + 1e-8)
            bal_acc = 0.5 * (sens + spec)
            
            if bal_acc > best_bal_acc:
                best_bal_acc = bal_acc
                best_thr = thr
                best_stats = (TP, TN, FP, FN, sens, spec)
        
        try:
            auc = roc_auc_score(y, probs)
        except ValueError:
            auc = 0.5
        
        TP, TN, FP, FN, sens, spec = best_stats
        
        return {
            'bal_acc': float(best_bal_acc),
            'auc': float(auc),
            'sens': float(sens),
            'spec': float(spec),
            'tp': int(TP),
            'tn': int(TN),
            'fp': int(FP),
            'fn': int(FN),
            'thr': float(best_thr)
        }
    
    # Fixed 0.5 threshold (for comparison)
    preds = (probs >= 0.5).astype(int)
    TP = ((y == 1) & (preds == 1)).sum()
    TN = ((y == 0) & (preds == 0)).sum()
    FP = ((y == 0) & (preds == 1)).sum()
    FN = ((y == 1) & (preds == 0)).sum()
    
    sens = TP / (TP + FN + 1e-8)
    spec = TN / (TN + FP + 1e-8)
    bal_acc = 0.5 * (sens + spec)
    
    try:
        auc = roc_auc_score(y, probs)
    except ValueError:
        auc = 0.5
    
    return {
        'bal_acc': float(bal_acc),
        'auc': float(auc),
        'sens': float(sens),
        'spec': float(spec),
        'tp': int(TP),
        'tn': int(TN),
        'fp': int(FP),
        'fn': int(FN),
        'thr': 0.5
    }


# ============================================================================
# TRAINING WITH AD DUPLICATION
# ============================================================================

def train_one_epoch_with_ad_dup(model, train_loader, criterion, optimizer, device, config):
    """
    Training with AD duplication instead of sampler
    
    Each batch: if it contains AD samples, duplicate them AD_DUP_FACTOR times
    """
    model.train()
    total_loss = 0.0
    n_batches = 0
    
    for imgs, labels, domain_ids in train_loader:
        imgs = imgs.to(device)
        labels = labels.to(device)
        domain_ids = domain_ids.to(device)
        
        # üî• AD DUPLICATION: duplicate AD samples in the batch
        if config.AD_DUP_FACTOR > 1:
            ad_mask = (labels == 1)
            if ad_mask.any():
                ad_imgs = imgs[ad_mask]
                ad_labels = labels[ad_mask]
                ad_domain_ids = domain_ids[ad_mask]
                
                # Duplicate AD samples
                for _ in range(config.AD_DUP_FACTOR - 1):
                    imgs = torch.cat([imgs, ad_imgs], dim=0)
                    labels = torch.cat([labels, ad_labels], dim=0)
                    domain_ids = torch.cat([domain_ids, ad_domain_ids], dim=0)
        
        # Apply augmentation
        if config.USE_AUG:
            imgs = augment_batch_label_aware(imgs, labels, config)
        else:
            imgs = normalize_batch(imgs)
        
        optimizer.zero_grad()
        logits, alpha = model(imgs, domain_ids)
        
        loss = criterion(logits, labels)
        loss.backward()
        
        if config.GRAD_CLIP is not None:
            torch.nn.utils.clip_grad_norm_(model.parameters(), config.GRAD_CLIP)
        
        optimizer.step()
        
        total_loss += loss.item()
        n_batches += 1
    
    return total_loss / max(n_batches, 1)


@torch.no_grad()
def evaluate(model, val_loader, device, config):
    model.eval()
    
    all_logits = []
    all_labels = []
    
    for imgs, labels, domain_ids in val_loader:
        imgs = imgs.to(device)
        labels = labels.to(device)
        domain_ids = domain_ids.to(device)
        
        imgs = normalize_batch(imgs)
        
        logits, alpha = model(imgs, domain_ids)
        
        all_logits.append(logits.detach().cpu().numpy())
        all_labels.append(labels.detach().cpu().numpy())
    
    if len(all_logits) == 0:
        return {'bal_acc': 0.0, 'auc': 0.0, 'sens': 0.0, 'spec': 0.0, 'thr': 0.5}
    
    logits_np = np.concatenate(all_logits, axis=0)
    labels_np = np.concatenate(all_labels, axis=0)
    
    metrics = compute_metrics(logits_np, labels_np, search_best_thr=True)
    
    return metrics


# ============================================================================
# STAGE 1 TRAINING (PRIMARY)
# ============================================================================

def freeze_module(module):
    for param in module.parameters():
        param.requires_grad = False


def unfreeze_module(module):
    for param in module.parameters():
        param.requires_grad = True


def stage1_training(model, train_loader, val_loader, device, config, pos_weight):
    """
    Stage 1: Linear probing (freeze backbone, train DSBN + attention + classifier)
    
    This is our PRIMARY training strategy (Stage 2 disabled for small dataset)
    """
    print("\n" + "="*70)
    print("STAGE 1: LINEAR PROBING (PRIMARY)")
    print("="*70)
    print(f"  Frozen: Backbone")
    print(f"  Training: DSBN, Attention, Classifier")
    print(f"  Loss: BCEWithLogitsLoss(pos_weight={pos_weight.item():.2f})")
    print(f"  AD duplication: {config.AD_DUP_FACTOR}x")
    print(f"  Augmentation: {config.USE_AUG}")
    
    freeze_module(model.backbone)
    
    if hasattr(model, 'dsbn') and model.dsbn is not None:
        unfreeze_module(model.dsbn)
    unfreeze_module(model.attention)
    unfreeze_module(model.classifier)
    
    criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight).to(device)
    
    optimizer = torch.optim.Adam(
        filter(lambda p: p.requires_grad, model.parameters()),
        lr=config.STAGE1_LR,
        weight_decay=config.STAGE1_WEIGHT_DECAY
    )
    
    best_bal_acc = 0.0
    best_auc = 0.0
    best_state = None
    no_improve = 0
    
    print(f"\n{'Epoch':>5} {'Train Loss':>12} {'Val BalAcc':>12} {'Val AUC':>10} {'Sensitivity':>12} {'Specificity':>12} {'Threshold':>10}")
    print("-" * 80)
    
    for epoch in range(config.STAGE1_EPOCHS):
        train_loss = train_one_epoch_with_ad_dup(
            model, train_loader, criterion, optimizer, device, config
        )
        val_metrics = evaluate(model, val_loader, device, config)
        
        bal_acc = val_metrics['bal_acc']
        auc = val_metrics['auc']
        sens = val_metrics['sens']
        spec = val_metrics['spec']
        thr = val_metrics.get('thr', 0.5)
        
        print(f"{epoch+1:>5} {train_loss:>12.4f} {bal_acc:>12.4f} {auc:>10.4f} {sens:>12.4f} {spec:>12.4f} {thr:>10.2f}")
        
        # Track best by balanced accuracy (primary metric)
        if bal_acc > best_bal_acc + 1e-4:
            best_bal_acc = bal_acc
            best_auc = auc
            best_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}
            no_improve = 0
        else:
            no_improve += 1
            if no_improve >= config.STAGE1_PATIENCE:
                print(f"\n  Early stopping at epoch {epoch+1}")
                break
    
    if best_state is not None:
        model.load_state_dict(best_state)
        print(f"\n  ‚úì Best BalAcc: {best_bal_acc:.4f}, AUC: {best_auc:.4f}")
    
    return best_bal_acc, best_auc


# ============================================================================
# MAIN TRAINING PIPELINE
# ============================================================================

def run_s6_training_baseline(config=None):
    """
    Run stable baseline training (Stage 1 only)
    
    This establishes baseline before:
    - Trying Focal Loss
    - Adding other backbones
    - Building ensemble
    """
    if config is None:
        config = S6Config()
    
    print("\n" + "="*70)
    print("üî• SNIPPET S6: STABLE BASELINE (Stage 1 Only)")
    print("="*70)
    print(f"\nConfiguration:")
    print(f"  Device: {config.DEVICE}")
    print(f"  Backbone: {config.BACKBONE}")
    print(f"  Batch size: {config.BATCH_SIZE}")
    print(f"  Stage 1 epochs: {config.STAGE1_EPOCHS}")
    print(f"  Stage 2: {'ENABLED' if config.USE_STAGE2 else 'DISABLED (stable baseline)'}")
    print(f"  AD duplication: {config.AD_DUP_FACTOR}x")
    print(f"  Augmentation: {config.USE_AUG}")
    
    os.makedirs(config.CHECKPOINT_DIR, exist_ok=True)
    
    fold_results = []
    
    for fold_id in range(config.NUM_FOLDS):
        print("\n" + "="*70)
        print(f"üìÇ FOLD {fold_id + 1}/{config.NUM_FOLDS}")
        print("="*70)
        
        train_dataset = TriPlanarVisitDataset(
            csv_path=config.SLICES_CSV,
            mode='train',
            fold_id=fold_id,
            transform=None,
            config=None
        )
        
        val_dataset = TriPlanarVisitDataset(
            csv_path=config.SLICES_CSV,
            mode='val',
            fold_id=fold_id,
            transform=None,
            config=None
        )
        
        # Compute pos_weight
        labels_train = np.array([v['label'] for v in train_dataset.visits])
        class_counts = np.bincount(labels_train.astype(int))
        cn_count = class_counts[0]
        ad_count = class_counts[1]
        
        pos_weight_value = cn_count / (ad_count + 1e-8)
        pos_weight = torch.tensor([pos_weight_value], dtype=torch.float32, device=config.DEVICE)
        
        print(f"\n  Class counts: CN={cn_count}, AD={ad_count}")
        print(f"  pos_weight: {pos_weight_value:.2f}")
        print(f"  Effective AD samples (with {config.AD_DUP_FACTOR}x dup): {ad_count * config.AD_DUP_FACTOR}")
        
        # Simple DataLoaders (no sampler, AD duplication happens in training loop)
        train_loader = DataLoader(
            train_dataset,
            batch_size=config.BATCH_SIZE,
            shuffle=True,
            num_workers=config.NUM_WORKERS,
            pin_memory=True
        )
        
        val_loader = DataLoader(
            val_dataset,
            batch_size=config.BATCH_SIZE,
            shuffle=False,
            num_workers=config.NUM_WORKERS,
            pin_memory=True
        )
        
        print(f"  Train: {len(train_dataset)} visits")
        print(f"  Val:   {len(val_dataset)} visits")
        
        model = TriPlanarADNet(
            backbone_name=config.BACKBONE,
            num_domains=2,
            use_dsbn=True,
            pretrained=True,
            config=None
        ).to(config.DEVICE)
        
        # Stage 1 (PRIMARY)
        best_bal_acc, best_auc = stage1_training(
            model, train_loader, val_loader, config.DEVICE, config, pos_weight
        )
        
        # Final evaluation
        final_metrics = evaluate(model, val_loader, config.DEVICE, config)
        
        print(f"\n‚úì Fold {fold_id + 1} Complete:")
        print(f"  Final BalAcc: {final_metrics['bal_acc']:.4f} at thr={final_metrics.get('thr', 0.5):.2f}")
        print(f"  Final AUC:    {final_metrics['auc']:.4f}")
        print(f"  Final Sens:   {final_metrics['sens']:.4f}")
        print(f"  Final Spec:   {final_metrics['spec']:.4f}")
        
        checkpoint_path = os.path.join(config.CHECKPOINT_DIR, f"{config.BACKBONE}_fold{fold_id}.pth")
        torch.save(model.state_dict(), checkpoint_path)
        
        fold_results.append({
            'fold': fold_id,
            'backbone': config.BACKBONE,
            'final_bal_acc': final_metrics['bal_acc'],
            'final_auc': final_metrics['auc'],
            'final_sens': final_metrics['sens'],
            'final_spec': final_metrics['spec'],
            'final_thr': final_metrics.get('thr', 0.5),
            'tp': final_metrics['tp'],
            'tn': final_metrics['tn'],
            'fp': final_metrics['fp'],
            'fn': final_metrics['fn']
        })
    
    results_df = pd.DataFrame(fold_results)
    
    print("\n" + "="*70)
    print(f"üìä FINAL RESULTS ({config.BACKBONE.upper()})")
    print("="*70)
    print("\n" + results_df.to_string(index=False))
    
    print(f"\nüìà Mean ¬± Std:")
    print(f"  Balanced Accuracy: {results_df['final_bal_acc'].mean():.4f} ¬± {results_df['final_bal_acc'].std():.4f}")
    print(f"  AUC:              {results_df['final_auc'].mean():.4f} ¬± {results_df['final_auc'].std():.4f}")
    print(f"  Sensitivity (AD): {results_df['final_sens'].mean():.4f} ¬± {results_df['final_sens'].std():.4f}")
    print(f"  Specificity (CN): {results_df['final_spec'].mean():.4f} ¬± {results_df['final_spec'].std():.4f}")
    print(f"  Optimal Threshold: {results_df['final_thr'].mean():.2f} ¬± {results_df['final_thr'].std():.2f}")
    
    results_df.to_csv(config.RESULTS_CSV, index=False)
    print(f"\n‚úì Results saved: {config.RESULTS_CSV}")
    
    # Assessment
    mean_bal_acc = results_df['final_bal_acc'].mean()
    mean_auc = results_df['final_auc'].mean()
    
    print("\n" + "="*70)
    print("BASELINE ASSESSMENT")
    print("="*70)
    
    if mean_auc >= 0.75 and mean_bal_acc >= 0.70:
        print(f"‚úÖ STRONG BASELINE: AUC={mean_auc:.4f}, BalAcc={mean_bal_acc:.4f}")
        print("   Next steps:")
        print("   1. Try Focal Loss (may improve by 2-3%)")
        print("   2. Add ResNet50 + DenseNet121")
        print("   3. Build 15-model ensemble (target: 0.80-0.85 AUC)")
    elif mean_auc >= 0.70 and mean_bal_acc >= 0.65:
        print(f"‚úÖ GOOD BASELINE: AUC={mean_auc:.4f}, BalAcc={mean_bal_acc:.4f}")
        print("   Next steps:")
        print("   1. Try increasing AD_DUP_FACTOR to 3")
        print("   2. Then proceed to multi-backbone ensemble")
    elif mean_auc >= 0.65:
        print(f"‚ö†Ô∏è  MODERATE BASELINE: AUC={mean_auc:.4f}, BalAcc={mean_bal_acc:.4f}")
        print("   Suggestions:")
        print("   1. Increase STAGE1_EPOCHS to 15")
        print("   2. Try AD_DUP_FACTOR = 3")
        print("   3. Verify ROI quality in S3")
    else:
        print(f"‚ö†Ô∏è  WEAK BASELINE: AUC={mean_auc:.4f}, BalAcc={mean_bal_acc:.4f}")
        print("   This suggests data/ROI issues, not just model tuning")
        print("   Consider:")
        print("   - Different ROI extraction")
        print("   - More slices per visit")
        print("   - Verify S2 preprocessing quality")
    
    return results_df


# ============================================================================
# EXECUTE
# ============================================================================

if __name__ == "__main__":
    # STEP 1: 1-FOLD VALIDATION
    print("\n" + "="*70)
    print("üéØ STEP 1: 1-FOLD VALIDATION (Stable Baseline)")
    print("="*70)
    print("\nGoal: Establish stable baseline")
    print("  Target: AUC ‚â• 0.70, BalAcc ‚â• 0.70")
    print("  Strategy: Stage 1 only, BCE + pos_weight, AD duplication")
    
    cfg = S6Config()
    cfg.NUM_FOLDS = 1  # Start with 1 fold
    cfg.STAGE1_EPOCHS = 10
    cfg.AD_DUP_FACTOR = 2
    cfg.USE_AUG = True
    cfg.USE_STAGE2 = False
    
    results = run_s6_training_baseline(cfg)
    
    print("\n" + "="*70)
    print("‚úÖ 1-FOLD BASELINE COMPLETE")
    print("="*70)
    print("\nIf successful (AUC ‚â• 0.70):")
    print("  ‚Üí Set NUM_FOLDS=5 for full CV")
    print("  ‚Üí Try Focal Loss as next experiment")
    print("  ‚Üí Then add ResNet50 + DenseNet121")
    
    print("\nIf AUC < 0.70:")
    print("  ‚Üí Increase AD_DUP_FACTOR to 3")
    print("  ‚Üí Increase STAGE1_EPOCHS to 15")
    print("  ‚Üí Check data quality")



üéØ STEP 1: 1-FOLD VALIDATION (Stable Baseline)

Goal: Establish stable baseline
  Target: AUC ‚â• 0.70, BalAcc ‚â• 0.70
  Strategy: Stage 1 only, BCE + pos_weight, AD duplication

üî• SNIPPET S6: STABLE BASELINE (Stage 1 Only)

Configuration:
  Device: cuda
  Backbone: vgg16_bn
  Batch size: 4
  Stage 1 epochs: 10
  Stage 2: DISABLED (stable baseline)
  AD duplication: 2x
  Augmentation: True

üìÇ FOLD 1/1
  Train dataset: 370 visits
    CN: 319, AD: 51
  Val dataset: 93 visits
    CN: 80, AD: 13

  Class counts: CN=319, AD=51
  pos_weight: 6.25
  Effective AD samples (with 2x dup): 102
  Train: 370 visits
  Val:   93 visits


Downloading: "https://download.pytorch.org/models/vgg16_bn-6c64b313.pth" to /root/.cache/torch/hub/checkpoints/vgg16_bn-6c64b313.pth
100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 528M/528M [00:02<00:00, 192MB/s]



STAGE 1: LINEAR PROBING (PRIMARY)
  Frozen: Backbone
  Training: DSBN, Attention, Classifier
  Loss: BCEWithLogitsLoss(pos_weight=6.25)
  AD duplication: 2x
  Augmentation: True

Epoch   Train Loss   Val BalAcc    Val AUC  Sensitivity  Specificity  Threshold
--------------------------------------------------------------------------------
    1       0.5640       0.5000     0.5365       0.0000       1.0000       0.05
    2       0.0673       0.5000     0.4644       0.0000       1.0000       0.05
    3       0.0231       0.5000     0.5077       0.0000       1.0000       0.05
    4       0.0148       0.5000     0.4337       0.0000       1.0000       0.05
    5       0.0091       0.5000     0.4433       0.0000       1.0000       0.05
    6       0.0119       0.5000     0.5125       0.0000       1.0000       0.05
    7       0.0077       0.5000     0.4385       0.0000       1.0000       0.05
    8       0.0080       0.5000     0.5692       0.0000       1.0000       0.05
    9       0.0043 