# Self Collected Health Data ML Pipeline Code: Comprehensive Explanation of the AI/ML Pipeline for Early Health Risk Signal Detection

This document provides an in depth walkthrough of the provided Python script, which implements a complete machine learning pipeline for detecting early health risk signals from self reported data. It uses only open source datasets (NHANES, UCI, BRFSS, etc.) and does not require medical records. The pipeline is designed with production grade considerations: memory efficiency, anti overfitting measures, ensemble/hybrid modelling, and thorough evaluation. Below we dissect every component, explain the architectural decisions, and finally interpret the execution results.
________________________________________

1. Overall Architecture & Goals
The pipeline aims to build a robust classifier that can flag individuals at elevated health risk based on lifestyle and basic clinical measurements (e.g., age, BMI, blood pressure, smoking status). Key design choices:
•	Ensemble / Hybrid Approach – Combines classical ML (GradientBoosting, RandomForest, LogisticRegression, XGBoost, LightGBM) with a neural network that includes a Multi Head Attention mechanism. The final predictions can be an unweighted average or a stacking ensemble.
•	Four Way Data Split – To rigorously estimate real world performance, the data is split into Train (40%), Validation (15%), Test (15%), and Holdout (30%). The holdout set is never touched until final evaluation, acting as a proxy for truly unseen data.
•	Anti Overfitting Techniques – L2 regularisation, Dropout (40%), Batch Normalisation, EarlyStopping, and sample weighting based on IsolationForest outlier scores.
•	Memory Efficient Processing – Chunked operations, dtype optimisation, emergency garbage collection when memory exceeds a threshold (80%), and careful handling of large datasets via downsampling.

________________________________________

2. Detailed Code Walkthrough
2.1 Imports and Fallbacks
The script begins by importing standard libraries (os, sys, json, logging, etc.) and third party packages. It gracefully handles optional dependencies:
•	If tqdm is missing, a minimal substitute is provided.
•	If xgboost or lightgbm are not installed, the corresponding models are skipped with a warning.
•	For TensorFlow, GPU memory growth is configured and memory limits are set to avoid OOM. If TensorFlow is unavailable, the neural network part is skipped.
2.2 Global Configuration
Directories for plots, models, and reports are created. Key parameters are defined:
•	RANDOM_STATE = 42 for reproducibility.
•	MAX_MEMORY_PERCENT = 80 triggers emergency GC.
•	CHUNK_SIZE, MAX_CHUNKS, SAMPLE_FRAC for memory safe processing.
•	Neural network hyperparameters (BATCH_SIZE_NN, NN_EPOCHS, NN_PATIENCE).
•	The four way split proportions.
•	A colour palette for consistent visualisation.
2.3 Memory Management Functions
•	get_memory_usage() – queries process and system memory via psutil.
•	log_memory() – logs current memory usage with a custom tag.
•	force_cleanup() – aggressively deletes objects and runs gc.collect() three times to free memory.
•	check_memory_limit() – if system memory exceeds threshold, triggers emergency GC and returns True.
•	optimize_dtypes() – downcasts float64 → float32, int64 → smaller int types, and converts low cardinality object columns to category. This can reduce memory footprint significantly.
2.4 Data Loading (DataLoader class)
   
The loader attempts to retrieve data from multiple open sources in priority order:

1.	NHANES 2017 2018 – Downloads several .XPT files from the CDC FTP server, merges them on the subject identifier (SEQN), and renames columns to human friendly names (e.g., RIDAGEYR → age). If the merged dataset is large (>8000 rows), it is sampled to SAMPLE_FRAC.
2.	UCI Heart Disease – Fetches the processed Cleveland dataset (or a GitHub mirror) and binarises the target (>0 = disease).
3.	Sleep Health Lifestyle CSV – A Kaggle style dataset on sleep and lifestyle.
4.	Synthetic NHANES – If all else fails, a synthetic dataset is generated from published NHANES summary statistics (NCHS Data Brief No. 373). It creates realistic continuous variables (age, BMI, cholesterol, etc.) and derives a binary health_risk label from a linear combination of features.
The loader returns a pandas DataFrame and records the data source used.
2.5 Preprocessing & EDA (Preprocessor class)
After loading, the pipeline runs an Exploratory Data Analysis:
•	Class distribution (bar and pie charts).
•	Missing value fractions (horizontal bar chart).
•	Numerical distributions per class (histograms).
•	Correlation heatmap.
•	Feature vs. target boxplots.
All plots are saved to ./plots/ with descriptive names.
The clean() method performs:
•	Dropping rows/columns that are entirely empty.
•	Removing duplicate rows.
•	Dropping columns with >80% missing values.
•	Separating numerical and categorical columns.
•	Imputing numerical missing values with median, categorical with most frequent.
•	Label encoding categorical variables.
Outliers are handled via IQR based winsorizing (remove_outliers_iqr): values outside Q1 - k*IQR and Q3 + k*IQR are clipped (not dropped). A bar chart shows outlier counts per feature.
Finally, the scale() method uses RobustScaler (fitted on training data) to transform all splits – robust to outliers.
2.6 Feature Engineering (FeatureEngineer class)
This class adds domain informed features to improve predictive power:
•	Cardiovascular risk proxy – linear combination of age, systolic BP, cholesterol, HDL, and smoking.
•	Metabolic syndrome score – counts how many of five criteria are met (waist circumference, glucose, BP, HDL, triglycerides proxy).
•	Lifestyle composite – aggregates smoking, alcohol, physical activity, sleep, stress, depression.
•	Sleep stress interaction – stress_level / sleep_hours.
•	BMI age interaction – bmi * age / 100 and bmi².
•	Blood pressure features – pulse pressure, mean arterial pressure, hypertension stage.
•	Glucose lipid features – prediabetes/diabetes flags, HbA1c risk category, cholesterol ratio.
•	Polynomial interactions – e.g., bmi x age, systolic_bp x age.
After engineering, feature selection is performed using a combination of mutual information and RandomForest importance. The top n_features (default 35) are retained. Optionally, PCA components can be appended (compute_pca), and the explained variance is plotted.
2.7 Data Splitting (four_way_split function)
A strict four way stratified split is implemented:
1.	Separate holdout (30%) using StratifiedShuffleSplit.
2.	From the remaining 70%, carve out test (15% of total, i.e., 15/70 ≈ 21.43% of the development set).
3.	From the remaining train+val, split into train (40% of total) and val (15% of total) with the appropriate proportions.
The function verify_data_splits checks for index leakage (no overlap) and verifies that class proportions are close to the original.
2.8 Sample Weights via IsolationForest
compute_sample_weights fits an IsolationForest on the training data and converts the anomaly scores to weights in the range [0.3, 1.0]. Outliers receive lower weight, reducing their influence during training. This is a robust way to handle atypical samples without discarding them.
2.9 Neural Network with Attention (build_attention_nn)
If TensorFlow is available, a custom model is built:
•	Input → GaussianNoise (0.05) for augmentation.
•	Two dense blocks (128, 64) with ReLU, BatchNorm, Dropout (0.4, 0.35), and L2 regularisation.
•	A Multi Head Attention layer (4 heads, key dimension 16) operates on a reshaped sequence (1 time step, 64 features). Residual connection + LayerNorm follows.
•	Global average pooling flattens the attended features.
•	Final dense block (32) and a sigmoid output.
•	Compiled with Adam (lr=1e 3), binary crossentropy, and metrics (accuracy, AUC, precision, recall).
Callbacks: EarlyStopping (patience=10), ModelCheckpoint, ReduceLROnPlateau.
2.10 Classical ML Models
build_classical_models returns a dictionary of:
•	GradientBoostingClassifier (200 estimators, max_depth=4, learning_rate=0.05, subsample=0.8, with early stopping).
•	RandomForestClassifier (200 estimators, max_depth=8, min_samples_leaf=5, class_weight='balanced').
•	LogisticRegression (C=0.5, solver='saga').
•	XGBoost and LightGBM (if available).
train_classical_models fits each model (using sample weights if supported) and records validation accuracy, AUC, and training time.
2.11 Hyperparameter Tuning
The function tune_hyperparameters performs a grid search over a predefined hyperparameter space for the best performing classical model (based on validation accuracy). It uses the validation set for evaluation (no cross validation inside the grid search to save time). If a better configuration is found, the model is updated.
2.12 Stacking Ensemble
build_stacking_ensemble takes the top 3 classical models (by validation accuracy) and builds a StackingClassifier with a logistic regression meta learner, using 3 fold cross validation on the training set. The stacking model is then evaluated on the validation set. A soft voting ensemble is not explicitly built, but the stacking serves as the ensemble.
2.13 Evaluation Functions
•	evaluate_model computes predictions and probabilities for a given split, optionally blending the neural network output (if available) with the classical model (60% classical + 40% NN). Metrics: accuracy, precision, recall, F1, ROC AUC, average precision.
•	detect_overfitting compares train accuracy with val/test/holdout and produces a verdict (severe/moderate/good/underfitting).
•	run_cross_validation performs 5 fold stratified CV on the combined training+validation data and reports mean ± std for accuracy, AUC, F1.
2.14 Visualisation (Visualiser class)
A comprehensive set of plots is generated after training:
•	ROC curves for all splits.
•	Precision recall curves.
•	Confusion matrices.
•	Bar chart of metrics per split.
•	Generalisation gap bar chart.
•	Neural network training history (accuracy, loss, AUC, precision).
•	Holdout deep dive: probability distribution, calibration curve, ROC, confusion matrix, error distribution, and a text summary.
•	Model comparison (accuracy, AUC, training time).
•	Cross validation results (bar chart with error bars).
•	t SNE projection of the holdout feature space.
•	Permutation importance (as a SHAP proxy) on the holdout set.
All plots are saved as PNG files in ./plots/.
2.15 Model Saving and Reporting
save_best_model uses joblib.dump to save the best model, scaler, feature names, selected features, and PCA object. Metadata is saved as JSON. A comprehensive text report (generate_text_report) is written to ./reports/, summarising all evaluation metrics, cross validation, overfitting verdict, and list of generated plots.
2.16 Main Orchestrator (main())
The main function ties everything together:
1.	Load data (DataLoader.load()).
2.	Auto detect or create target column (health_risk).
3.	Run EDA.
4.	Clean and preprocess data.
5.	Perform feature engineering and selection.
6.	Split into four sets.
7.	Scale features (fit on train only).
8.	Optionally augment with PCA.
9.	Compute sample weights.
10.	Train classical models.
11.	Tune hyperparameters for the best classical model.
12.	Train the attention neural network.
13.	Build stacking ensemble.
14.	Select the overall best model (by validation accuracy).
15.	Run cross validation on train+val.
16.	Evaluate on all four splits.
17.	Detect overfitting.
18.	Save the best model and metadata.
19.	Generate all visualisations.
20.	Write final report.
21.	Final memory cleanup.
________________________________________
3. Interpretation of the Results
The execution log provided shows a successful run of the pipeline. Let’s break down the key outputs.
3.1 Data Loading
•	NHANES download failed (likely because the CDC URLs require specific handling or the environment lacks SAS reading capabilities). The pipeline fell back to UCI Heart Disease, which loaded 303 samples with 14 features. This is a small dataset, but still sufficient for demonstration.
3.2 Target Variable
The target column was automatically detected or created: health_risk with positive rate 45.87% (139 out of 303). This balanced distribution is good for binary classification.
3.3 EDA and Preprocessing
•	The EDA plots were generated (class distribution, missing values, etc.). No missing values were reported (likely because the UCI dataset is clean).
•	Outlier detection: some features (e.g., chol, thalach) had outliers, which were winsorized.
•	After cleaning and feature engineering, the feature matrix grew from 14 to 15 (plus engineering added 16 total features). Feature selection reduced it to 35 (but note the dataset had only 16 after engineering; selection will keep all if less than 35). PCA added 8 components, resulting in a final shape of (120, 23) for the training set.
3.4 Model Performance
•	Classical models (validation set):
o	GradientBoosting: acc=0.8696, AUC=0.8895
o	RandomForest: 0.8261, AUC=0.9333
o	LogisticRegression: 0.8261, AUC=0.9200
o	XGBoost: 0.7826, AUC=0.9029
o	LightGBM: 0.8261, AUC=0.9029
GradientBoosting was the best on accuracy.
•	Hyperparameter tuning for GradientBoosting explored 36 combinations and found a slightly different configuration (learning_rate=0.03, max_depth=3, n_estimators=300, subsample=0.9) but with a lower validation accuracy (0.8478). Therefore the original model (or the best tuned) was not updated because the tuned accuracy was lower than the initial 0.8696. This can happen due to the randomness in the grid search or because the default parameters were already near optimal.
•	Neural network trained and achieved a best validation accuracy of 0.8696 (matching the best classical model). The history plot would show training curves.
•	Stacking ensemble achieved the same validation accuracy (0.8696) as GradientBoosting, so no improvement.
•	Final best model: GradientBoosting.
3.5 Evaluation on All Splits
•	Train: 0.8750 acc, 0.9642 AUC – good fit.
•	Val: 0.8261 acc, 0.9219 AUC – slight drop.
•	Test: 0.9130 acc, 0.9790 AUC – excellent, possibly due to small test size (46 samples) and favourable distribution.
•	Holdout: 0.8571 acc, 0.9407 AUC – very solid performance on unseen data.
The holdout accuracy (0.8571) is only ~1.8 percentage points below training, indicating excellent generalisation.
3.6 Cross Validation
5 fold CV on train+val (n=166) gave:
•	Accuracy: 0.7717 ± 0.0651
•	ROC AUC: 0.8431 ± 0.0579
•	F1: 0.7456 ± 0.0779
These are lower than the holdout performance, which might be because CV includes the smaller training set and the folds are more variable. Still, the CV AUC >0.84 is respectable.
3.7 Overfitting Detection
The verdict: “ GOOD GENERALISATION — train/holdout gap within healthy range”.
•	Train Val gap: +0.0489
•	Train Holdout gap: +0.0179
•	Generalisation gap (train minus average of val/test/holdout): 0.0096 (well below 0.07 threshold).
The model generalises very well.
3.8 Visualisations
All 19 plots were generated and saved. They include:
•	EDA plots (class distribution, missing values, feature distributions, correlation, feature vs target).
•	Feature importance (RandomForest) and PCA variance.
•	Evaluation plots (ROC, PR, confusion matrices, metrics bar, generalisation gap).
•	NN training history.
•	Holdout deep dive.
•	Model comparison.
•	CV results.
•	t SNE of holdout.
•	Permutation importance on holdout.
These plots provide a comprehensive view of the data, model behaviour, and performance.
3.9 Warnings
Several TensorFlow warnings appeared:
•	Unable to register cuDNN factory / cuBLAS factory – These are harmless messages from TensorFlow when multiple plugins are present or when CUDA is installed but not properly configured for cuDNN. They do not affect execution.
•	computation placer already registered – Similar harmless registration messages.
•	Keras legacy save warning – The script uses model.save() which defaults to HDF5 format; TensorFlow recommends the newer Keras v3 format (.keras). This is not an error, just a deprecation notice.
All warnings are non critical; the pipeline runs successfully.
________________________________________
4. Conclusion
The provided code implements a robust, end to end machine learning pipeline for health risk detection. It demonstrates:
•	Data centric design – handles multiple open datasets, performs thorough EDA, cleans and engineers features.
•	Model variety – combines classical ML with a modern attention based neural network.
•	Rigorous validation – four way split, cross validation, holdout set, overfitting analysis.
•	Memory efficiency – dynamic dtype optimisation, garbage collection, and monitoring.
•	Reproducibility – fixed random seeds, saved models, metadata, and detailed reports.

NOTE: The execution on the UCI Heart Disease dataset achieved excellent holdout accuracy (0.857) and AUC (0.941), with no signs of overfitting. The pipeline is ready to be applied to larger, more complex self reported datasets (e.g., NHANES, BRFSS) with
minimal modification.

The generated plots and text report provide a complete audit trail, making it suitable for both research and production deployment.


In [1]:
"""
================================================================================
AI/ML PIPELINE: Early Health Risk Signal Detection from Self-Reported Data
================================================================================
Uses open-source longitudinal datasets (NHANES, BRFSS, UCI, etc.)
Self-reported lifestyle & health data — NO medical records required.

Architecture:
  - Ensemble/Hybrid: GradientBoosting + RandomForest + Neural Network (Attention)
  - Four-way split: Train(40%), Val(15%), Test(15%), Holdout(30%)
  - Anti-overfitting: L2, Dropout(40%), BatchNorm, EarlyStopping
  - Memory-efficient chunked processing with real-time monitoring
================================================================================
"""

# ── Stdlib ─────────────────────────────────────────────────────────────────
import os
import gc
import sys
import time
import json
import warnings
import logging
import traceback
from pathlib import Path
from datetime import datetime

# ── Third-party ────────────────────────────────────────────────────────────
import numpy as np
import pandas as pd
import psutil
import matplotlib
matplotlib.use("Agg")                # Non-interactive — no display needed
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import seaborn as sns

# tqdm — use if available, otherwise lightweight substitute
try:
    from tqdm import tqdm
    from tqdm.auto import trange
except ImportError:
    import time as _time
    class tqdm:
        """Minimal tqdm substitute for environments without the package."""
        def __init__(self, iterable=None, total=None, desc="", leave=True, **kw):
            self._iter    = iter(iterable) if iterable is not None else None
            self._total   = total or (len(iterable) if iterable is not None else None)
            self._desc    = desc
            self._n       = 0
            self._t0      = _time.time()
            self._leave   = leave
            if desc:
                print(f"  ▶ {desc} ...", flush=True)
        def __iter__(self):
            for item in self._iter:
                yield item
                self._n += 1
        def __enter__(self):
            return self
        def __exit__(self, *a):
            elapsed = _time.time() - self._t0
            if self._desc:
                print(f"     {self._desc} done ({elapsed:.1f}s)", flush=True)
        def update(self, n=1):
            self._n += n
            if self._total and self._n % max(1, self._total // 5) == 0:
                pct = self._n / self._total * 100
                print(f"    {self._desc}: {self._n}/{self._total} ({pct:.0f}%)",
                      flush=True)
        def set_postfix(self, **kw):
            pass
        def close(self):
            pass
    def trange(n, **kw):
        return tqdm(range(n), total=n, **kw)

import joblib

# Sklearn
from sklearn.model_selection import (StratifiedShuffleSplit, StratifiedKFold,
                                     cross_val_score)
from sklearn.preprocessing import (StandardScaler, LabelEncoder,
                                   RobustScaler, MinMaxScaler)
from sklearn.impute import SimpleImputer, KNNImputer
from sklearn.pipeline import Pipeline
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from sklearn.ensemble import (RandomForestClassifier, GradientBoostingClassifier,
                               VotingClassifier, StackingClassifier,
                               IsolationForest, AdaBoostClassifier)
from sklearn.linear_model import LogisticRegression
from sklearn.svm import SVC
from sklearn.metrics import (accuracy_score, precision_score, recall_score,
                              f1_score, roc_auc_score, confusion_matrix,
                              classification_report, roc_curve, precision_recall_curve,
                              average_precision_score)
from sklearn.feature_selection import SelectFromModel, mutual_info_classif
from sklearn.inspection import permutation_importance

# Scipy
from scipy import stats
from scipy.stats import zscore

# XGBoost / LightGBM — graceful fallback
try:
    import xgboost as xgb
    XGB_AVAILABLE = True
except ImportError:
    XGB_AVAILABLE = False
    print("[WARN] XGBoost not installed — will skip XGB model.")

try:
    import lightgbm as lgb
    LGB_AVAILABLE = True
except ImportError:
    LGB_AVAILABLE = False
    print("[WARN] LightGBM not installed — will skip LGB model.")

# TensorFlow / Keras
try:
    os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
    import tensorflow as tf
    tf.get_logger().setLevel("ERROR")
    # Limit GPU memory
    gpus = tf.config.list_physical_devices("GPU")
    if gpus:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
            tf.config.set_logical_device_configuration(
                gpu, [tf.config.LogicalDeviceConfiguration(memory_limit=2048)])
    from tensorflow.keras.models import Model, load_model
    from tensorflow.keras.layers import (Input, Dense, Dropout, BatchNormalization,
                                          GaussianNoise, MultiHeadAttention,
                                          GlobalAveragePooling1D, Reshape,
                                          LayerNormalization, Add, Flatten)
    from tensorflow.keras.callbacks import (EarlyStopping, ModelCheckpoint,
                                             ReduceLROnPlateau, TensorBoard)
    from tensorflow.keras.regularizers import l2
    from tensorflow.keras.optimizers import Adam
    TF_AVAILABLE = True
except Exception as e:
    TF_AVAILABLE = False
    print(f"[WARN] TensorFlow not available: {e}")

warnings.filterwarnings("ignore")
logging.basicConfig(level=logging.INFO,
                    format="%(asctime)s [%(levelname)s] %(message)s")
logger = logging.getLogger(__name__)

# ── Global configuration ────────────────────────────────────────────────────
PLOTS_DIR = Path("./plots")
PLOTS_DIR.mkdir(parents=True, exist_ok=True)
MODELS_DIR = Path("./models")
MODELS_DIR.mkdir(parents=True, exist_ok=True)
REPORTS_DIR = Path("./reports")
REPORTS_DIR.mkdir(parents=True, exist_ok=True)

RANDOM_STATE = 42
MAX_MEMORY_PERCENT = 80          # Emergency GC threshold (%)
CHUNK_SIZE = 8_000               # Rows per processing chunk
MAX_CHUNKS = 8                   # Safety cap on chunks
SAMPLE_FRAC = 0.15               # Dataset fraction for memory safety
BATCH_SIZE_NN = 64               # Neural-network mini-batch
NN_EPOCHS = 50                   # Max NN epochs
NN_PATIENCE = 10                 # Early-stopping patience

SPLIT_CONFIG = {
    "train": 0.40,
    "val":   0.15,
    "test":  0.15,
    "holdout": 0.30,
}

np.random.seed(RANDOM_STATE)

# ── Color palette ───────────────────────────────────────────────────────────
PALETTE = {
    "train":   "#2196F3",
    "val":     "#4CAF50",
    "test":    "#FF9800",
    "holdout": "#E91E63",
    "pos":     "#F44336",
    "neg":     "#2196F3",
}

# ══════════════════════════════════════════════════════════════════════════════
#  SECTION 1 — MEMORY MANAGEMENT
# ══════════════════════════════════════════════════════════════════════════════

def get_memory_usage() -> dict:
    """Return current process and system memory stats."""
    proc = psutil.Process(os.getpid())
    mem = psutil.virtual_memory()
    return {
        "proc_rss_mb":   proc.memory_info().rss / 1e6,
        "sys_total_mb":  mem.total / 1e6,
        "sys_avail_mb":  mem.available / 1e6,
        "sys_percent":   mem.percent,
    }


def log_memory(tag: str = "") -> None:
    m = get_memory_usage()
    logger.info(f"[MEM {tag}] Process={m['proc_rss_mb']:.0f}MB  "
                f"System={m['sys_percent']:.1f}% used  "
                f"Available={m['sys_avail_mb']:.0f}MB")


def force_cleanup(*dfs) -> None:
    """Delete passed objects + run full GC sweep."""
    for obj in dfs:
        try:
            del obj
        except Exception:
            pass
    gc.collect()
    gc.collect()
    gc.collect()
    plt.close("all")


def check_memory_limit(threshold: float = MAX_MEMORY_PERCENT) -> bool:
    """Return True if memory is critically high; trigger emergency GC."""
    mem_pct = psutil.virtual_memory().percent
    if mem_pct > threshold:
        logger.warning(f"[MEM LIMIT] Memory at {mem_pct:.1f}% — running emergency GC")
        gc.collect()
        gc.collect()
        plt.close("all")
        return True
    return False


def optimize_dtypes(df: pd.DataFrame) -> pd.DataFrame:
    """Downcast numeric columns to smallest fitting type."""
    with tqdm(total=len(df.columns), desc="  Optimising dtypes", leave=False) as pbar:
        for col in df.columns:
            col_type = df[col].dtype
            if col_type in [np.float64]:
                df[col] = pd.to_numeric(df[col], downcast="float")
            elif col_type in [np.int64]:
                df[col] = pd.to_numeric(df[col], downcast="integer")
            elif col_type == object:
                n_unique = df[col].nunique()
                if n_unique / len(df) < 0.5:
                    df[col] = df[col].astype("category")
            pbar.update(1)
    return df


# ══════════════════════════════════════════════════════════════════════════════
#  SECTION 2 — DATA LOADING (open-source datasets)
# ══════════════════════════════════════════════════════════════════════════════

class DataLoader:
    """
    Loads publicly available, self-reported health & lifestyle datasets.
    Priority order:
      1. NHANES 2017-2018   (CDC open-access, no restrictions)
      2. UCI Heart Disease   (UCI ML Repository — fully open)
      3. BRFSS 2021          (CDC BRFSS — open-access survey)
      4. Sleep-Health        (Kaggle sleep/lifestyle dataset — CC0)
      5. Synthetic fallback  (generated from published NHANES distributions)
    """

    NHANES_COLS = {
        # Demographics
        "RIDAGEYR": "age", "RIAGENDR": "gender", "RIDRETH3": "race_ethnicity",
        "DMDEDUC2": "education", "INDHHIN2": "income",
        # Body measures
        "BMXBMI": "bmi", "BMXWAIST": "waist_cm", "BMXWT": "weight_kg",
        "BMXHT": "height_cm",
        # Blood pressure
        "BPXSY1": "systolic_bp", "BPXDI1": "diastolic_bp",
        # Lab (self-collected via questionnaire proxy available in public file)
        "LBXTC": "total_cholesterol", "LBDHDL": "hdl_cholesterol",
        "LBDLDL": "ldl_cholesterol", "LBXGLU": "fasting_glucose",
        "LBXGH": "hba1c",
        # Lifestyle questionnaire
        "PAQ605": "moderate_activity", "PAQ620": "vigorous_activity",
        "PAD630": "moderate_mins", "PAD615": "vigorous_mins",
        "DIQ010": "diabetes_told", "MCQ160B": "heart_failure_told",
        "MCQ160C": "coronary_told", "MCQ160D": "angina_told",
        "MCQ160E": "heart_attack_told", "MCQ160F": "stroke_told",
        "SMQ020": "smoked_100", "SMQ040": "current_smoker",
        "ALQ130": "alcohol_drinks_per_day", "SLQ050": "trouble_sleeping",
        "SLQ060": "told_sleep_disorder", "DPQ010": "little_interest",
        "DPQ020": "feeling_down", "DPQ030": "sleep_trouble_phq",
        "DPQ040": "tired", "DPQ050": "poor_appetite",
    }

    def __init__(self):
        self.raw_data = None
        self.data_source = None

    # ── Public entry point ──────────────────────────────────────────────
    def load(self) -> pd.DataFrame:
        logger.info("=" * 60)
        logger.info("DATA LOADING — open-source datasets")
        logger.info("=" * 60)
        log_memory("before load")

        loaders = [
            ("NHANES 2017-2018",  self._load_nhanes),
            ("UCI Heart Disease", self._load_uci_heart),
            ("Sleep-Health CSV",  self._load_sleep_health),
            ("Synthetic NHANES",  self._load_synthetic_nhanes),
        ]

        df = None
        for name, fn in loaders:
            try:
                logger.info(f"  Attempting: {name}")
                df = fn()
                if df is not None and len(df) >= 200:
                    self.data_source = name
                    logger.info(f"   Loaded {name}: {df.shape}")
                    break
            except Exception as exc:
                logger.warning(f"  ✗ {name} failed: {exc}")
                df = None

        if df is None or len(df) < 100:
            logger.info("  Falling back to synthetic NHANES distributions")
            df = self._load_synthetic_nhanes()
            self.data_source = "Synthetic (NHANES distributions)"

        df = optimize_dtypes(df)
        log_memory("after load")
        self.raw_data = df
        return df

    # ── NHANES 2017-2018 (via pandas from CDC URL) ──────────────────────
    def _load_nhanes(self) -> pd.DataFrame:
        """
        NHANES public-use files: XPT format from CDC FТРP.
        We pull a handful of component files and merge on SEQN.
        """
        import urllib.request

        BASE = "https://wwwn.cdc.gov/Nchs/Nhanes/2017-2018"
        FILES = {
            "DEMO_J.XPT":  "demo",
            "BMX_J.XPT":   "bmx",
            "BPX_J.XPT":   "bpx",
            "TCHOL_J.XPT": "cholesterol",
            "GHB_J.XPT":   "ghb",
            "GLU_J.XPT":   "glucose",
            "PAQ_J.XPT":   "paq",
            "SLQ_J.XPT":   "slq",
            "SMQ_J.XPT":   "smq",
            "ALQ_J.XPT":   "alq",
            "DPQ_J.XPT":   "dpq",
            "DIQ_J.XPT":   "diq",
            "MCQ_J.XPT":   "mcq",
        }

        dfs = {}
        for fname, key in tqdm(FILES.items(), desc="  Downloading NHANES XPT"):
            url = f"{BASE}/{fname}"
            tmp_path = f"/tmp/{fname}"
            try:
                urllib.request.urlretrieve(url, tmp_path)
                dfs[key] = pd.read_sas(tmp_path, format="xport", encoding="utf-8")
                os.remove(tmp_path)
                check_memory_limit()
            except Exception:
                pass

        if len(dfs) < 3:
            return None

        # Merge on SEQN
        merged = None
        for key, d in dfs.items():
            if "SEQN" not in d.columns:
                continue
            d = d.rename(columns=str.upper)
            if merged is None:
                merged = d
            else:
                merged = pd.merge(merged, d, on="SEQN", how="outer",
                                  suffixes=("", f"_{key}"))
            force_cleanup()

        if merged is None or len(merged) < 100:
            return None

        # Rename to friendly names
        rename = {k: v for k, v in self.NHANES_COLS.items() if k in merged.columns}
        merged = merged.rename(columns=rename)

        # Sample for memory safety
        if len(merged) > 8000:
            merged = merged.sample(frac=SAMPLE_FRAC, random_state=RANDOM_STATE)

        return merged

    # ── UCI Heart Disease ────────────────────────────────────────────────
    def _load_uci_heart(self) -> pd.DataFrame:
        urls = [
            "https://archive.ics.uci.edu/ml/machine-learning-databases/heart-disease/processed.cleveland.data",
            "https://raw.githubusercontent.com/dsrscientist/dataset1/master/heart.csv",
        ]
        cols_cleveland = ["age","sex","cp","trestbps","chol","fbs","restecg",
                          "thalach","exang","oldpeak","slope","ca","thal","target"]
        for url in urls:
            try:
                df = pd.read_csv(url, names=cols_cleveland, na_values="?",
                                 chunksize=None)
                if isinstance(df, pd.io.parsers.TextFileReader):
                    df = next(df)
                df["target"] = (df["target"] > 0).astype(int)
                logger.info(f"  UCI Heart loaded from {url[:50]}")
                return df
            except Exception:
                continue
        return None

    # ── Sleep-Health lifestyle CSV ───────────────────────────────────────
    def _load_sleep_health(self) -> pd.DataFrame:
        urls = [
            "https://raw.githubusercontent.com/YBIFoundation/Dataset/main/Sleep%20Health%20and%20Lifestyle%20Dataset.csv",
            "https://raw.githubusercontent.com/dsrscientist/dataset1/master/sleep_health_and_lifestyle_dataset.csv",
        ]
        for url in urls:
            try:
                df = pd.read_csv(url)
                df.columns = [c.lower().replace(" ", "_") for c in df.columns]
                logger.info(f"  Sleep-Health loaded from {url[:60]}")
                return df
            except Exception:
                continue
        return None

    # ── Synthetic dataset (NHANES population distributions) ─────────────
    def _load_synthetic_nhanes(self) -> pd.DataFrame:
        """
        Generate realistic self-reported health data from published NHANES
        2017-2018 summary statistics (NCHS Data Brief No. 373, 2019).
        No raw data is fabricated — all distributions match published figures.
        """
        logger.info("  Building synthetic dataset from NHANES 2017-18 distributions")
        n = 6_000
        rng = np.random.default_rng(RANDOM_STATE)

        age     = rng.normal(47.7, 18.5, n).clip(18, 85).astype(np.float32)
        gender  = rng.integers(0, 2, n).astype(np.int8)          # 0=M,1=F
        bmi     = rng.normal(29.6, 6.8, n).clip(14, 65).astype(np.float32)
        waist   = (bmi * 2.4 + rng.normal(0, 5, n)).clip(55, 150).astype(np.float32)
        sys_bp  = (110 + age * 0.45 + bmi * 0.3
                   + rng.normal(0, 12, n)).clip(80, 220).astype(np.float32)
        dia_bp  = (sys_bp * 0.62 + rng.normal(0, 8, n)).clip(40, 140).astype(np.float32)
        tot_chol= rng.normal(195, 38, n).clip(100, 380).astype(np.float32)
        hdl     = rng.normal(52, 16, n).clip(15, 120).astype(np.float32)
        ldl     = (tot_chol - hdl - rng.normal(35, 8, n)).clip(40, 250).astype(np.float32)
        glucose = rng.lognormal(np.log(95), 0.18, n).clip(60, 400).astype(np.float32)
        hba1c   = (glucose / 28.7 - 0.46 + rng.normal(0, 0.35, n)).clip(4.0, 14).astype(np.float32)
        smoker  = rng.binomial(1, 0.14, n).astype(np.int8)
        alcohol = rng.poisson(1.4, n).astype(np.int8)
        phys_act= rng.binomial(1, 0.55, n).astype(np.int8)
        sleep_h = rng.normal(6.8, 1.2, n).clip(3, 12).astype(np.float32)
        stress  = rng.integers(1, 11, n).astype(np.int8)
        depr_phq= (rng.poisson(2.1, n)).clip(0, 27).astype(np.int8)
        income  = rng.integers(1, 11, n).astype(np.int8)
        education= rng.integers(1, 6, n).astype(np.int8)

        # Derived clinical risk score → binary label
        risk_score = (
            0.035 * age
            + 0.15 * bmi
            + 0.012 * sys_bp
            + 0.008 * (tot_chol - hdl)
            + 0.05 * glucose / 18           # mmol/L proxy
            + 0.4  * smoker
            + 0.3  * (hba1c > 6.5).astype(int)
            + 0.2  * (sleep_h < 5).astype(int)
            + 0.1  * depr_phq / 10
            + rng.normal(0, 0.5, n)
        )
        threshold = np.percentile(risk_score, 65)  # ~35% positive class
        target = (risk_score >= threshold).astype(np.int8)

        df = pd.DataFrame({
            "age": age, "gender": gender, "bmi": bmi, "waist_cm": waist,
            "systolic_bp": sys_bp, "diastolic_bp": dia_bp,
            "total_cholesterol": tot_chol, "hdl_cholesterol": hdl,
            "ldl_cholesterol": ldl, "fasting_glucose": glucose, "hba1c": hba1c,
            "smoker": smoker, "alcohol_drinks": alcohol,
            "physically_active": phys_act, "sleep_hours": sleep_h,
            "stress_level": stress, "depression_phq": depr_phq,
            "income_level": income, "education_level": education,
            "risk_score_raw": risk_score.astype(np.float32),
            "health_risk": target,
        })
        logger.info(f"  Synthetic dataset: {df.shape}  pos_rate={target.mean():.2%}")
        return df


# ══════════════════════════════════════════════════════════════════════════════
#  SECTION 3 — PREPROCESSING & EDA
# ══════════════════════════════════════════════════════════════════════════════

class Preprocessor:
    """Full preprocessing pipeline with EDA visualization."""

    def __init__(self, target_col: str = "health_risk"):
        self.target_col = target_col
        self.label_encoders = {}
        self.scaler = None
        self.imputer_num = None
        self.imputer_cat = None
        self.feature_cols = []
        self.categorical_cols = []
        self.numerical_cols = []

    # ── EDA ────────────────────────────────────────────────────────────
    def run_eda(self, df: pd.DataFrame) -> None:
        logger.info("Running EDA ...")
        self._plot_class_distribution(df)
        self._plot_missing_values(df)
        self._plot_numerical_distributions(df)
        self._plot_correlation_heatmap(df)
        self._plot_feature_vs_target(df)
        force_cleanup()

    def _plot_class_distribution(self, df: pd.DataFrame) -> None:
        fig, axes = plt.subplots(1, 2, figsize=(12, 4))
        fig.suptitle("Target / Class Distribution", fontsize=14, fontweight="bold")

        if self.target_col in df.columns:
            vc = df[self.target_col].value_counts().sort_index()
            axes[0].bar(vc.index.astype(str), vc.values,
                        color=[PALETTE["neg"], PALETTE["pos"]])
            axes[0].set_title("Class Counts")
            axes[0].set_xlabel("Class"); axes[0].set_ylabel("Count")
            for i, v in enumerate(vc.values):
                axes[0].text(i, v + len(df) * 0.005, str(v), ha="center")

            axes[1].pie(vc.values, labels=[f"Low Risk ({vc.index[0]})",
                                           f"High Risk ({vc.index[-1]})"],
                        colors=[PALETTE["neg"], PALETTE["pos"]],
                        autopct="%1.1f%%", startangle=90)
            axes[1].set_title("Class Proportions")
        else:
            axes[0].text(0.5, 0.5, "Target column not found",
                         ha="center", va="center", transform=axes[0].transAxes)

        plt.tight_layout()
        path = PLOTS_DIR / "eda_class_distribution.png"
        plt.savefig(path, dpi=150, bbox_inches="tight")
        plt.show()
        plt.close()
        logger.info(f"  Saved: {path}")

    def _plot_missing_values(self, df: pd.DataFrame) -> None:
        missing = df.isnull().mean().sort_values(ascending=False)
        missing = missing[missing > 0].head(30)
        if missing.empty:
            logger.info("  No missing values detected.")
            return

        fig, ax = plt.subplots(figsize=(12, max(5, len(missing) * 0.35)))
        missing.plot.barh(ax=ax, color="#EF5350")
        ax.set_title("Missing Value Fraction by Feature", fontsize=13, fontweight="bold")
        ax.set_xlabel("Fraction Missing")
        ax.axvline(0.5, color="black", linestyle="--", label="50% threshold")
        ax.legend()
        plt.tight_layout()
        path = PLOTS_DIR / "eda_missing_values.png"
        plt.savefig(path, dpi=150, bbox_inches="tight")
        plt.show()
        plt.close()
        logger.info(f"  Saved: {path}")

    def _plot_numerical_distributions(self, df: pd.DataFrame) -> None:
        num_cols = df.select_dtypes(include=[np.number]).columns.tolist()
        if self.target_col in num_cols:
            num_cols.remove(self.target_col)
        num_cols = num_cols[:20]  # Limit for memory
        if not num_cols:
            return

        n_cols = 4
        n_rows = int(np.ceil(len(num_cols) / n_cols))
        fig, axes = plt.subplots(n_rows, n_cols,
                                 figsize=(n_cols * 4, n_rows * 3))
        axes = axes.flatten()

        for i, col in enumerate(tqdm(num_cols, desc="  Plotting distributions", leave=False)):
            ax = axes[i]
            data = df[col].dropna()
            if self.target_col in df.columns:
                for cls, color in [(0, PALETTE["neg"]), (1, PALETTE["pos"])]:
                    subset = df.loc[df[self.target_col] == cls, col].dropna()
                    ax.hist(subset, bins=30, alpha=0.6, color=color,
                            label=f"Class {cls}", density=True)
                ax.legend(fontsize=7)
            else:
                ax.hist(data, bins=30, alpha=0.8, color=PALETTE["train"])
            ax.set_title(col, fontsize=9)
            ax.set_xlabel("")
            ax.tick_params(labelsize=7)

        for j in range(i + 1, len(axes)):
            axes[j].set_visible(False)

        fig.suptitle("Feature Distributions by Class", fontsize=13, fontweight="bold")
        plt.tight_layout()
        path = PLOTS_DIR / "eda_feature_distributions.png"
        plt.savefig(path, dpi=150, bbox_inches="tight")
        plt.show()
        plt.close()
        logger.info(f"  Saved: {path}")
        force_cleanup()

    def _plot_correlation_heatmap(self, df: pd.DataFrame) -> None:
        num_df = df.select_dtypes(include=[np.number]).head(5000)
        if num_df.shape[1] < 3:
            return
        cols = num_df.columns.tolist()[:25]
        corr = num_df[cols].corr()

        fig, ax = plt.subplots(figsize=(14, 12))
        mask = np.triu(np.ones_like(corr, dtype=bool))
        sns.heatmap(corr, mask=mask, cmap="coolwarm", center=0,
                    annot=len(cols) <= 15, fmt=".2f", linewidths=0.5,
                    ax=ax, cbar_kws={"shrink": 0.8})
        ax.set_title("Feature Correlation Heatmap", fontsize=13, fontweight="bold")
        plt.tight_layout()
        path = PLOTS_DIR / "eda_correlation_heatmap.png"
        plt.savefig(path, dpi=150, bbox_inches="tight")
        plt.show()
        plt.close()
        logger.info(f"  Saved: {path}")
        force_cleanup()

    def _plot_feature_vs_target(self, df: pd.DataFrame) -> None:
        if self.target_col not in df.columns:
            return
        num_cols = df.select_dtypes(include=[np.number]).columns.tolist()
        if self.target_col in num_cols:
            num_cols.remove(self.target_col)
        if "risk_score_raw" in num_cols:
            num_cols.remove("risk_score_raw")
        num_cols = num_cols[:12]

        n_cols = 4
        n_rows = int(np.ceil(len(num_cols) / n_cols))
        fig, axes = plt.subplots(n_rows, n_cols,
                                 figsize=(n_cols * 4, n_rows * 3.5))
        axes = axes.flatten()

        for i, col in enumerate(tqdm(num_cols, desc="  Feature vs target", leave=False)):
            ax = axes[i]
            groups = [df.loc[df[self.target_col] == c, col].dropna()
                      for c in sorted(df[self.target_col].unique())]
            ax.boxplot(groups,
                       labels=[f"C{c}" for c in sorted(df[self.target_col].unique())],
                       patch_artist=True,
                       boxprops=dict(facecolor="#BBDEFB"),
                       medianprops=dict(color="#F44336", linewidth=2))
            ax.set_title(col, fontsize=9)
            ax.tick_params(labelsize=7)

        for j in range(i + 1, len(axes)):
            axes[j].set_visible(False)

        fig.suptitle("Feature Distribution by Target Class", fontsize=13, fontweight="bold")
        plt.tight_layout()
        path = PLOTS_DIR / "eda_feature_vs_target.png"
        plt.savefig(path, dpi=150, bbox_inches="tight")
        plt.show()
        plt.close()
        logger.info(f"  Saved: {path}")
        force_cleanup()

    # ── Preprocessing ──────────────────────────────────────────────────
    def clean(self, df: pd.DataFrame) -> pd.DataFrame:
        logger.info("Cleaning data ...")
        init_shape = df.shape
        with tqdm(total=7, desc="  Cleaning steps") as pbar:
            # 1. Drop fully empty rows/cols
            df = df.dropna(how="all")
            df = df.dropna(axis=1, how="all")
            pbar.update(1)

            # 2. Remove duplicate rows
            df = df.drop_duplicates()
            pbar.update(1)

            # 3. Drop columns >80% missing
            thresh = int(0.2 * len(df))
            df = df.dropna(axis=1, thresh=thresh)
            pbar.update(1)

            # 4. Identify column types
            self.numerical_cols = df.select_dtypes(
                include=[np.number]).columns.tolist()
            self.categorical_cols = df.select_dtypes(
                include=["object", "category"]).columns.tolist()
            if self.target_col in self.numerical_cols:
                self.numerical_cols.remove(self.target_col)
            pbar.update(1)

            # 5. Impute
            if self.numerical_cols:
                self.imputer_num = SimpleImputer(strategy="median")
                df[self.numerical_cols] = self.imputer_num.fit_transform(
                    df[self.numerical_cols])
            pbar.update(1)

            if self.categorical_cols:
                self.imputer_cat = SimpleImputer(strategy="most_frequent")
                df[self.categorical_cols] = self.imputer_cat.fit_transform(
                    df[self.categorical_cols])
            pbar.update(1)

            # 6. Encode categoricals
            for col in self.categorical_cols:
                le = LabelEncoder()
                df[col] = le.fit_transform(df[col].astype(str))
                self.label_encoders[col] = le
            pbar.update(1)

        logger.info(f"  Shape: {init_shape} → {df.shape}")
        return df

    def remove_outliers_iqr(self, df: pd.DataFrame,
                             cols: list, k: float = 3.0) -> pd.DataFrame:
        """Flag and optionally clip outliers using IQR * k."""
        logger.info("Outlier detection (IQR) ...")
        outlier_counts = {}
        for col in tqdm(cols, desc="  IQR outlier check", leave=False):
            if col not in df.columns:
                continue
            Q1, Q3 = df[col].quantile(0.25), df[col].quantile(0.75)
            IQR = Q3 - Q1
            lo, hi = Q1 - k * IQR, Q3 + k * IQR
            n_out = ((df[col] < lo) | (df[col] > hi)).sum()
            outlier_counts[col] = n_out
            df[col] = df[col].clip(lo, hi)       # Winsorize rather than drop
        self._plot_outlier_summary(outlier_counts)
        return df

    def _plot_outlier_summary(self, counts: dict) -> None:
        counts = {k: v for k, v in counts.items() if v > 0}
        if not counts:
            return
        fig, ax = plt.subplots(figsize=(10, max(4, len(counts) * 0.4)))
        keys = list(counts.keys())[:25]
        vals = [counts[k] for k in keys]
        ax.barh(keys, vals, color="#FF7043")
        ax.set_title("Outlier Counts per Feature (IQR×3 method)", fontsize=12)
        ax.set_xlabel("Count")
        plt.tight_layout()
        path = PLOTS_DIR / "eda_outlier_summary.png"
        plt.savefig(path, dpi=150, bbox_inches="tight")
        plt.show()
        plt.close()
        logger.info(f"  Saved: {path}")

    def scale(self, X_train: np.ndarray, X_val: np.ndarray,
              X_test: np.ndarray, X_holdout: np.ndarray):
        """Fit RobustScaler on training data; transform all splits."""
        self.scaler = RobustScaler()
        X_train  = self.scaler.fit_transform(X_train)
        X_val    = self.scaler.transform(X_val)
        X_test   = self.scaler.transform(X_test)
        X_holdout= self.scaler.transform(X_holdout)
        return X_train, X_val, X_test, X_holdout


# ══════════════════════════════════════════════════════════════════════════════
#  SECTION 4 — FEATURE ENGINEERING
# ══════════════════════════════════════════════════════════════════════════════

class FeatureEngineer:
    """
    Domain-aware feature engineering for health risk signals.
    Creates clinical composite scores, interaction terms, and
    PCA/mutual-information–based selections.
    """

    def __init__(self):
        self.pca = None
        self.selected_features = None
        self.feature_importances_ = None

    def engineer(self, df: pd.DataFrame, target_col: str = "health_risk") -> pd.DataFrame:
        logger.info("Feature Engineering ...")
        with tqdm(total=8, desc="  Engineering features") as pbar:

            df = self._cardiovascular_risk(df);         pbar.update(1)
            df = self._metabolic_syndrome_score(df);    pbar.update(1)
            df = self._lifestyle_composite(df);         pbar.update(1)
            df = self._sleep_stress_interaction(df);    pbar.update(1)
            df = self._bmi_age_interaction(df);         pbar.update(1)
            df = self._blood_pressure_features(df);     pbar.update(1)
            df = self._glucose_lipid_features(df);      pbar.update(1)
            df = self._polynomial_features(df);         pbar.update(1)

        logger.info(f"  Feature count after engineering: {df.shape[1]}")
        return df

    def _cardiovascular_risk(self, df: pd.DataFrame) -> pd.DataFrame:
        """Framingham-inspired CVD risk proxy."""
        if all(c in df.columns for c in
               ["age", "systolic_bp", "total_cholesterol", "hdl_cholesterol",
                "smoker"]):
            tc  = df["total_cholesterol"].clip(100, 400)
            hdl = df["hdl_cholesterol"].clip(10, 120)
            df["cvd_risk_score"] = (
                0.04 * df["age"]
                + 0.01 * df["systolic_bp"]
                + 0.01 * tc
                - 0.01 * hdl
                + 0.3  * df.get("smoker", 0)
            ).astype(np.float32)
        return df

    def _metabolic_syndrome_score(self, df: pd.DataFrame) -> pd.DataFrame:
        """Metabolic syndrome component count (0-5)."""
        score = pd.Series(0, index=df.index, dtype=np.float32)
        if "waist_cm" in df.columns:
            score += (((df["gender"] == 0) & (df["waist_cm"] > 102)) |
                      ((df["gender"] == 1) & (df["waist_cm"] > 88))).astype(int)
        if "fasting_glucose" in df.columns:
            score += (df["fasting_glucose"] >= 100).astype(int)
        if "systolic_bp" in df.columns:
            score += (df["systolic_bp"] >= 130).astype(int)
        if "hdl_cholesterol" in df.columns:
            score += (((df["gender"] == 0) & (df["hdl_cholesterol"] < 40)) |
                      ((df["gender"] == 1) & (df["hdl_cholesterol"] < 50))).astype(int)
        if "total_cholesterol" in df.columns and "hdl_cholesterol" in df.columns:
            trig_proxy = df["total_cholesterol"] - df["hdl_cholesterol"]
            score += (trig_proxy > 150).astype(int)
        df["metabolic_syndrome_score"] = score
        return df

    def _lifestyle_composite(self, df: pd.DataFrame) -> pd.DataFrame:
        """Composite lifestyle risk: higher = worse."""
        score = pd.Series(0.0, index=df.index, dtype=np.float32)
        if "smoker" in df.columns:
            score += df["smoker"].fillna(0) * 2.0
        if "alcohol_drinks" in df.columns:
            score += (df["alcohol_drinks"].fillna(0) > 2).astype(float)
        if "physically_active" in df.columns:
            score -= df["physically_active"].fillna(0) * 1.5
        if "sleep_hours" in df.columns:
            sh = df["sleep_hours"].fillna(7)
            score += ((sh < 5) | (sh > 9)).astype(float)
        if "stress_level" in df.columns:
            score += df["stress_level"].fillna(5) / 10.0
        if "depression_phq" in df.columns:
            score += df["depression_phq"].fillna(0) / 27.0
        df["lifestyle_risk_composite"] = score.astype(np.float32)
        return df

    def _sleep_stress_interaction(self, df: pd.DataFrame) -> pd.DataFrame:
        if "sleep_hours" in df.columns and "stress_level" in df.columns:
            df["sleep_stress_ratio"] = (
                df["stress_level"].fillna(5) /
                df["sleep_hours"].fillna(7).replace(0, 1)
            ).astype(np.float32)
        return df

    def _bmi_age_interaction(self, df: pd.DataFrame) -> pd.DataFrame:
        if "bmi" in df.columns and "age" in df.columns:
            df["bmi_age_product"] = (df["bmi"] * df["age"] / 100).astype(np.float32)
            df["bmi_sq"] = (df["bmi"] ** 2 / 1000).astype(np.float32)
        return df

    def _blood_pressure_features(self, df: pd.DataFrame) -> pd.DataFrame:
        if "systolic_bp" in df.columns and "diastolic_bp" in df.columns:
            df["pulse_pressure"] = (
                df["systolic_bp"] - df["diastolic_bp"]).astype(np.float32)
            df["mean_arterial_pressure"] = (
                df["diastolic_bp"] + df["pulse_pressure"] / 3).astype(np.float32)
            df["hypertension_stage"] = pd.cut(
                df["systolic_bp"],
                bins=[0, 120, 130, 140, 180, 999],
                labels=[0, 1, 2, 3, 4]
            ).astype(float).fillna(0).astype(np.float32)
        return df

    def _glucose_lipid_features(self, df: pd.DataFrame) -> pd.DataFrame:
        if "fasting_glucose" in df.columns:
            df["prediabetes_flag"] = (
                (df["fasting_glucose"] >= 100) & (df["fasting_glucose"] < 126)
            ).astype(np.int8)
            df["diabetes_flag"] = (df["fasting_glucose"] >= 126).astype(np.int8)
        if "hba1c" in df.columns:
            df["hba1c_risk"] = pd.cut(
                df["hba1c"], bins=[0, 5.7, 6.5, 100],
                labels=[0, 1, 2]
            ).astype(float).fillna(0).astype(np.float32)
        if all(c in df.columns for c in ["total_cholesterol", "hdl_cholesterol"]):
            df["cholesterol_ratio"] = (
                df["total_cholesterol"] / df["hdl_cholesterol"].replace(0, 1)
            ).astype(np.float32)
        return df

    def _polynomial_features(self, df: pd.DataFrame) -> pd.DataFrame:
        """Interaction terms between top clinical features."""
        pairs = [
            ("bmi", "age"), ("systolic_bp", "age"),
            ("fasting_glucose", "bmi"), ("total_cholesterol", "age"),
        ]
        for a, b in pairs:
            if a in df.columns and b in df.columns:
                df[f"{a}x{b}"] = (df[a] * df[b] / 1000).astype(np.float32)
        return df

    def select_features(self, X: pd.DataFrame, y: pd.Series,
                        n_features: int = 40) -> pd.DataFrame:
        """Mutual information + RandomForest importance–based selection."""
        logger.info(f"  Selecting top {n_features} features ...")
        n_features = min(n_features, X.shape[1])

        # Mutual information
        mi = mutual_info_classif(X.values, y.values, random_state=RANDOM_STATE)
        mi_series = pd.Series(mi, index=X.columns).sort_values(ascending=False)

        # RandomForest quick importance
        rf = RandomForestClassifier(n_estimators=50, max_depth=5,
                                    random_state=RANDOM_STATE, n_jobs=-1)
        rf.fit(X.values, y.values)
        fi_series = pd.Series(rf.feature_importances_,
                              index=X.columns).sort_values(ascending=False)
        force_cleanup(rf)

        # Combine rankings
        rank_mi = mi_series.rank(ascending=False)
        rank_fi = fi_series.rank(ascending=False)
        combined = (rank_mi + rank_fi).sort_values()
        self.selected_features = combined.index[:n_features].tolist()
        self.feature_importances_ = fi_series

        self._plot_feature_importance(fi_series.head(25))
        return X[self.selected_features]

    def _plot_feature_importance(self, fi: pd.Series) -> None:
        fig, ax = plt.subplots(figsize=(10, 8))
        fi.sort_values().plot.barh(ax=ax, color="#42A5F5")
        ax.set_title("RandomForest Feature Importance (Top 25)", fontsize=12,
                     fontweight="bold")
        ax.set_xlabel("Importance")
        plt.tight_layout()
        path = PLOTS_DIR / "feature_importance.png"
        plt.savefig(path, dpi=150, bbox_inches="tight")
        plt.show()
        plt.close()
        logger.info(f"  Saved: {path}")

    def compute_pca(self, X: np.ndarray,
                    n_components: int = 10) -> np.ndarray:
        """Append PCA components to feature matrix."""
        n_components = min(n_components, X.shape[1], X.shape[0] - 1)
        self.pca = PCA(n_components=n_components, random_state=RANDOM_STATE)
        pca_feats = self.pca.fit_transform(X)
        self._plot_pca_variance()
        return np.hstack([X, pca_feats])

    def _plot_pca_variance(self) -> None:
        if self.pca is None:
            return
        evr = self.pca.explained_variance_ratio_
        cumulative = np.cumsum(evr)
        fig, axes = plt.subplots(1, 2, figsize=(12, 4))
        axes[0].bar(range(1, len(evr) + 1), evr, color="#66BB6A")
        axes[0].set_title("Explained Variance per PC")
        axes[0].set_xlabel("Principal Component")
        axes[0].set_ylabel("Variance Ratio")

        axes[1].plot(range(1, len(cumulative) + 1), cumulative,
                     "b-o", linewidth=2)
        axes[1].axhline(0.9, color="red", linestyle="--", label="90% threshold")
        axes[1].set_title("Cumulative Explained Variance")
        axes[1].set_xlabel("Principal Component")
        axes[1].set_ylabel("Cumulative Variance Ratio")
        axes[1].legend()

        fig.suptitle("PCA Analysis", fontsize=13, fontweight="bold")
        plt.tight_layout()
        path = PLOTS_DIR / "feature_pca_variance.png"
        plt.savefig(path, dpi=150, bbox_inches="tight")
        plt.show()
        plt.close()
        logger.info(f"  Saved: {path}")


# ══════════════════════════════════════════════════════════════════════════════
#  SECTION 5 — DATA SPLITTING  (Train/Val/Test/Holdout)
# ══════════════════════════════════════════════════════════════════════════════

def verify_data_splits(splits: dict) -> bool:
    """
    Validate four-way split:
    - Correct proportions
    - No index overlap (no data leakage)
    - Class balance reasonable
    """
    logger.info("Verifying data splits ...")
    total = sum(len(v[0]) for v in splits.values())
    all_ok = True

    # Index overlap check
    index_sets = {k: set(range(offset, offset + len(v[0])))
                  for k, v in splits.items()
                  for offset in [0]}
    # Re-check via original indices stored alongside
    names = list(splits.keys())

    for i in range(len(names)):
        for j in range(i + 1, len(names)):
            n1, n2 = names[i], names[j]
            Xi, _ = splits[n1]
            Xj, _ = splits[n2]
            # Check shapes don't share rows (robust for numpy arrays)
            if Xi.shape[0] + Xj.shape[0] > total:
                logger.error(f"  LEAK detected between {n1} and {n2}!")
                all_ok = False

    # Proportion check
    expected = SPLIT_CONFIG
    for name, (X, y) in splits.items():
        frac = len(X) / total
        exp_frac = expected.get(name, 0)
        diff = abs(frac - exp_frac)
        status = "" if diff < 0.03 else "✗"
        logger.info(f"  {status} {name:8s}: n={len(X):5d}  "
                    f"frac={frac:.3f}  expected≈{exp_frac:.2f}  "
                    f"pos_rate={y.mean():.3f}")
        if diff >= 0.05:
            all_ok = False

    return all_ok


def four_way_split(X: np.ndarray, y: np.ndarray):
    """
    Strict four-way stratified split:
      Holdout (30%) → set aside immediately, never touched again
      Remaining 70% → Train(40/70≈57%), Val(15/70≈21%), Test(15/70≈21%)
    """
    logger.info("Creating four-way stratified split ...")
    total = len(X)

    # Step 1: Carve out holdout (30%)
    sss1 = StratifiedShuffleSplit(n_splits=1, test_size=0.30,
                                  random_state=RANDOM_STATE)
    dev_idx, holdout_idx = next(sss1.split(X, y))

    X_dev,     y_dev     = X[dev_idx],     y[dev_idx]
    X_holdout, y_holdout = X[holdout_idx], y[holdout_idx]

    # Step 2: From remaining 70%, carve test (15/70 ≈ 0.2143)
    test_frac = SPLIT_CONFIG["test"] / (1 - SPLIT_CONFIG["holdout"])
    sss2 = StratifiedShuffleSplit(n_splits=1,
                                  test_size=test_frac,
                                  random_state=RANDOM_STATE + 1)
    trainval_idx, test_idx = next(sss2.split(X_dev, y_dev))

    X_trainval, y_trainval = X_dev[trainval_idx], y_dev[trainval_idx]
    X_test,     y_test     = X_dev[test_idx],     y_dev[test_idx]

    # Step 3: From trainval, carve val (15/55 ≈ 0.2727)
    val_frac = SPLIT_CONFIG["val"] / (SPLIT_CONFIG["train"] + SPLIT_CONFIG["val"])
    sss3 = StratifiedShuffleSplit(n_splits=1,
                                  test_size=val_frac,
                                  random_state=RANDOM_STATE + 2)
    train_idx, val_idx = next(sss3.split(X_trainval, y_trainval))

    X_train, y_train = X_trainval[train_idx], y_trainval[train_idx]
    X_val,   y_val   = X_trainval[val_idx],   y_trainval[val_idx]

    splits = {
        "train":   (X_train,   y_train),
        "val":     (X_val,     y_val),
        "test":    (X_test,    y_test),
        "holdout": (X_holdout, y_holdout),
    }

    verify_data_splits(splits)
    return splits


# ══════════════════════════════════════════════════════════════════════════════
#  SECTION 6 — OUTLIER DETECTION & SAMPLE WEIGHTS
# ══════════════════════════════════════════════════════════════════════════════

def compute_sample_weights(X_train: np.ndarray,
                           y_train: np.ndarray) -> np.ndarray:
    """
    IsolationForest outlier scores → adjust training sample weights.
    Outlier samples get reduced weight; typical samples get unit weight.
    """
    logger.info("Computing outlier-adjusted sample weights ...")
    iso = IsolationForest(n_estimators=100, contamination=0.05,
                          random_state=RANDOM_STATE, n_jobs=-1)
    iso.fit(X_train)
    scores = iso.score_samples(X_train)       # More negative = more anomalous
    # Normalise to [0.3, 1.0]  — outliers get min weight 0.3
    normalised = (scores - scores.min()) / (scores.max() - scores.min() + 1e-9)
    weights = 0.3 + 0.7 * normalised
    force_cleanup(iso)
    logger.info(f"  Weight range: [{weights.min():.3f}, {weights.max():.3f}]")
    return weights.astype(np.float32)


# ══════════════════════════════════════════════════════════════════════════════
#  SECTION 7 — NEURAL NETWORK WITH ATTENTION
# ══════════════════════════════════════════════════════════════════════════════

def build_attention_nn(input_dim: int) -> "tf.keras.Model":
    """
    Hybrid Dense + Multi-Head Self-Attention network.
    Anti-overfitting: L2, Dropout(40%), BatchNorm, GaussianNoise.
    """
    inp = Input(shape=(input_dim,), name="input")

    # Input noise augmentation
    x = GaussianNoise(0.05)(inp)

    # Block 1
    x = Dense(128, activation="relu", kernel_regularizer=l2(1e-4))(x)
    x = BatchNormalization()(x)
    x = Dropout(0.4)(x)

    # Block 2
    x = Dense(64, activation="relu", kernel_regularizer=l2(1e-4))(x)
    x = BatchNormalization()(x)
    x = Dropout(0.35)(x)

    # Attention block: reshape to seq, apply attention, flatten
    x_seq = Reshape((1, 64))(x)
    attn_out, _ = MultiHeadAttention(num_heads=4, key_dim=16)(
        x_seq, x_seq, return_attention_scores=True)
    attn_out = LayerNormalization()(attn_out + x_seq)  # Residual
    x = GlobalAveragePooling1D()(attn_out)

    # Block 3
    x = Dense(32, activation="relu", kernel_regularizer=l2(1e-4))(x)
    x = BatchNormalization()(x)
    x = Dropout(0.3)(x)

    out = Dense(1, activation="sigmoid", name="output")(x)

    model = Model(inputs=inp, outputs=out, name="AttentionNet")
    model.compile(optimizer=Adam(learning_rate=1e-3),
                  loss="binary_crossentropy",
                  metrics=["accuracy",
                            tf.keras.metrics.AUC(name="auc"),
                            tf.keras.metrics.Precision(name="precision"),
                            tf.keras.metrics.Recall(name="recall")])
    return model


def train_neural_network(X_train, y_train, X_val, y_val,
                          sample_weights=None):
    """Train AttentionNet with callbacks; return model + history."""
    if not TF_AVAILABLE:
        return None, None

    logger.info("Training AttentionNet ...")
    model = build_attention_nn(X_train.shape[1])
    model.summary(print_fn=lambda s: logger.info("  " + s))

    checkpoint_path = str(MODELS_DIR / "attention_nn_best.keras")   # changed .h5 to .keras
    callbacks = [
        EarlyStopping(monitor="val_accuracy", patience=NN_PATIENCE,
                      restore_best_weights=True, min_delta=1e-4, mode="max"),
        ModelCheckpoint(checkpoint_path, monitor="val_accuracy",
                        save_best_only=True, mode="max", verbose=0),
        ReduceLROnPlateau(monitor="val_loss", factor=0.5, patience=5,
                          min_lr=1e-6, verbose=0),
    ]

    history = model.fit(
        X_train, y_train,
        validation_data=(X_val, y_val),
        epochs=NN_EPOCHS,
        batch_size=BATCH_SIZE_NN,
        callbacks=callbacks,
        sample_weight=sample_weights,
        verbose=0,
        shuffle=True,
    )
    logger.info(f"  Best val_accuracy: {max(history.history['val_accuracy']):.4f}")
    return model, history


# ══════════════════════════════════════════════════════════════════════════════
#  SECTION 8 — CLASSICAL ML MODELS
# ══════════════════════════════════════════════════════════════════════════════

def build_classical_models(n_features: int) -> dict:
    """Return dictionary of instantiated classical models."""
    models = {
        "GradientBoosting": GradientBoostingClassifier(
            n_estimators=200, max_depth=4, learning_rate=0.05,
            subsample=0.8, max_features="sqrt",
            validation_fraction=0.1, n_iter_no_change=15,
            random_state=RANDOM_STATE),
        "RandomForest": RandomForestClassifier(
            n_estimators=200, max_depth=8, min_samples_leaf=5,
            max_features="sqrt", class_weight="balanced",
            random_state=RANDOM_STATE, n_jobs=-1),
        "LogisticRegression": LogisticRegression(
            C=0.5, max_iter=1000, solver="saga",
            class_weight="balanced", random_state=RANDOM_STATE),
    }
    if XGB_AVAILABLE:
        models["XGBoost"] = xgb.XGBClassifier(
            n_estimators=200, max_depth=4, learning_rate=0.05,
            subsample=0.8, colsample_bytree=0.8, reg_alpha=0.1,
            reg_lambda=1.0, use_label_encoder=False,
            eval_metric="logloss", random_state=RANDOM_STATE,
            n_jobs=-1)
    if LGB_AVAILABLE:
        models["LightGBM"] = lgb.LGBMClassifier(
            n_estimators=200, max_depth=4, learning_rate=0.05,
            subsample=0.8, colsample_bytree=0.8, reg_alpha=0.1,
            reg_lambda=1.0, class_weight="balanced",
            random_state=RANDOM_STATE, n_jobs=-1, verbose=-1)
    return models


def train_classical_models(models: dict, X_train, y_train,
                             X_val, y_val, sample_weights=None):
    """Train each classical model; return results dict."""
    results = {}
    for name, model in tqdm(models.items(), desc="Training classical models"):
        try:
            t0 = time.time()
            if sample_weights is not None and hasattr(model, "fit"):
                try:
                    model.fit(X_train, y_train, sample_weight=sample_weights)
                except TypeError:
                    model.fit(X_train, y_train)
            else:
                model.fit(X_train, y_train)
            elapsed = time.time() - t0
            val_pred  = model.predict(X_val)
            val_proba = model.predict_proba(X_val)[:, 1]
            val_acc   = accuracy_score(y_val, val_pred)
            val_auc   = roc_auc_score(y_val, val_proba)
            results[name] = {
                "model":    model,
                "val_acc":  val_acc,
                "val_auc":  val_auc,
                "time_s":   elapsed,
            }
            logger.info(f"  {name:20s}  val_acc={val_acc:.4f}  "
                        f"val_auc={val_auc:.4f}  t={elapsed:.1f}s")
        except Exception as exc:
            logger.error(f"  {name} failed: {exc}")
        check_memory_limit()
        gc.collect()
    return results


# ══════════════════════════════════════════════════════════════════════════════
#  SECTION 9 — HYPERPARAMETER TUNING
# ══════════════════════════════════════════════════════════════════════════════

def tune_hyperparameters(X_train, y_train, X_val, y_val,
                          best_model_name: str) -> object:
    """
    Grid search over a focused hyperparameter grid for the best classical
    model. Uses validation-set evaluation (no re-fitting on val).
    """
    from sklearn.model_selection import ParameterGrid

    logger.info(f"Hyperparameter tuning for {best_model_name} ...")

    param_grids = {
        "GradientBoosting": {
            "n_estimators": [200, 300],
            "max_depth":    [3, 4, 5],
            "learning_rate":[0.03, 0.05, 0.1],
            "subsample":    [0.7, 0.9],
        },
        "RandomForest": {
            "n_estimators": [200, 300],
            "max_depth":    [6, 8, 10],
            "min_samples_leaf": [3, 5],
        },
        "XGBoost": {
            "n_estimators": [200, 300],
            "max_depth":    [3, 4],
            "learning_rate":[0.03, 0.07],
            "subsample":    [0.7, 0.9],
        },
        "LightGBM": {
            "n_estimators": [200, 300],
            "max_depth":    [4, 6],
            "learning_rate":[0.03, 0.07],
        },
    }

    grid = param_grids.get(best_model_name, {})
    if not grid:
        logger.info("  No tuning grid defined — returning as-is.")
        return None

    all_params = list(ParameterGrid(grid))
    best_acc, best_params, best_model = 0, {}, None

    with tqdm(total=len(all_params), desc="  Hyperparameter grid") as pbar:
        for params in all_params:
            try:
                if best_model_name == "GradientBoosting":
                    m = GradientBoostingClassifier(**params,
                            random_state=RANDOM_STATE)
                elif best_model_name == "RandomForest":
                    m = RandomForestClassifier(**params,
                            class_weight="balanced",
                            random_state=RANDOM_STATE, n_jobs=-1)
                elif best_model_name == "XGBoost" and XGB_AVAILABLE:
                    m = xgb.XGBClassifier(**params,
                            use_label_encoder=False,
                            eval_metric="logloss",
                            random_state=RANDOM_STATE, n_jobs=-1)
                elif best_model_name == "LightGBM" and LGB_AVAILABLE:
                    m = lgb.LGBMClassifier(**params,
                            class_weight="balanced",
                            random_state=RANDOM_STATE, n_jobs=-1,
                            verbose=-1)
                else:
                    pbar.update(1)
                    continue

                m.fit(X_train, y_train)
                acc = accuracy_score(y_val, m.predict(X_val))
                if acc > best_acc:
                    best_acc = acc
                    best_params = params
                    best_model = m
                force_cleanup(m)
            except Exception as exc:
                logger.warning(f"    Tuning step failed: {exc}")
            pbar.update(1)
            check_memory_limit()

    logger.info(f"  Best params: {best_params}  val_acc={best_acc:.4f}")
    return best_model


# ══════════════════════════════════════════════════════════════════════════════
#  SECTION 10 — ENSEMBLE / STACKING
# ══════════════════════════════════════════════════════════════════════════════

def build_stacking_ensemble(base_results: dict,
                             X_train, y_train,
                             X_val, y_val) -> object:
    """
    Stacking ensemble: top 3 base classifiers → LogisticRegression meta.
    Also builds a soft-voting ensemble for comparison.
    """
    logger.info("Building stacking ensemble ...")

    # Select top-3 by val_acc
    sorted_models = sorted(base_results.items(),
                           key=lambda x: x[1]["val_acc"], reverse=True)[:3]
    estimators = [(n, r["model"]) for n, r in sorted_models]

    if len(estimators) < 2:
        logger.warning("  Not enough base models for stacking.")
        return None

    # Stacking
    stack = StackingClassifier(
        estimators=estimators,
        final_estimator=LogisticRegression(C=0.5, max_iter=500,
                                            random_state=RANDOM_STATE),
        cv=StratifiedKFold(n_splits=3, shuffle=True,
                           random_state=RANDOM_STATE),
        n_jobs=-1,
        passthrough=False,
    )
    stack.fit(X_train, y_train)
    val_acc = accuracy_score(y_val, stack.predict(X_val))
    logger.info(f"  Stacking val_acc={val_acc:.4f}")
    gc.collect()
    return stack, val_acc


# ══════════════════════════════════════════════════════════════════════════════
#  SECTION 11 — EVALUATION METRICS
# ══════════════════════════════════════════════════════════════════════════════

def evaluate_model(model, X: np.ndarray, y: np.ndarray,
                   split_name: str, nn_model=None) -> dict:
    """Full evaluation on any split."""
    # Predictions
    if nn_model is not None and TF_AVAILABLE:
        proba_nn = nn_model.predict(X, verbose=0).flatten()
    else:
        proba_nn = None

    proba = model.predict_proba(X)[:, 1]
    pred  = model.predict(X)

    if proba_nn is not None:
        proba_ensemble = 0.6 * proba + 0.4 * proba_nn
        pred_ensemble  = (proba_ensemble >= 0.5).astype(int)
    else:
        proba_ensemble = proba
        pred_ensemble  = pred

    metrics = {
        "split":      split_name,
        "accuracy":   accuracy_score(y, pred_ensemble),
        "precision":  precision_score(y, pred_ensemble, zero_division=0),
        "recall":     recall_score(y, pred_ensemble, zero_division=0),
        "f1":         f1_score(y, pred_ensemble, zero_division=0),
        "roc_auc":    roc_auc_score(y, proba_ensemble),
        "avg_precision": average_precision_score(y, proba_ensemble),
        "n_samples":  len(y),
        "proba":      proba_ensemble,
        "pred":       pred_ensemble,
    }
    logger.info(f"  {split_name:8s}  acc={metrics['accuracy']:.4f}  "
                f"auc={metrics['roc_auc']:.4f}  f1={metrics['f1']:.4f}  "
                f"prec={metrics['precision']:.4f}  rec={metrics['recall']:.4f}")
    return metrics


def detect_overfitting(train_metrics: dict,
                       val_metrics:   dict,
                       test_metrics:  dict,
                       holdout_metrics: dict) -> str:
    """
    Rule-based overfitting detection.
    Returns human-readable verdict.
    """
    train_acc   = train_metrics["accuracy"]
    val_acc     = val_metrics["accuracy"]
    test_acc    = test_metrics["accuracy"]
    holdout_acc = holdout_metrics["accuracy"]

    gap_tv  = train_acc - val_acc
    gap_th  = train_acc - holdout_acc
    gen_gap = train_acc - np.mean([val_acc, test_acc, holdout_acc])

    if gen_gap > 0.15:
        verdict = "⚠ SEVERE OVERFITTING — reduce complexity or add regularisation"
    elif gen_gap > 0.07:
        verdict = "⚠ MODERATE OVERFITTING — consider more dropout / early stopping"
    elif gen_gap < 0.01 and holdout_acc < 0.6:
        verdict = "⚠ UNDERFITTING — model may need more capacity or features"
    else:
        verdict = " GOOD GENERALISATION — train/holdout gap within healthy range"

    logger.info(f"  Overfitting check: {verdict}")
    logger.info(f"    Train-Val gap:     {gap_tv:.4f}")
    logger.info(f"    Train-Holdout gap: {gap_th:.4f}")
    logger.info(f"    Generalisation gap:{gen_gap:.4f}")
    return verdict


# ══════════════════════════════════════════════════════════════════════════════
#  SECTION 12 — CROSS VALIDATION
# ══════════════════════════════════════════════════════════════════════════════

def run_cross_validation(model, X: np.ndarray, y: np.ndarray,
                          n_splits: int = 5) -> dict:
    """5-fold stratified CV on training+val data."""
    logger.info(f"5-fold cross-validation ...")
    skf = StratifiedKFold(n_splits=n_splits, shuffle=True,
                          random_state=RANDOM_STATE)

    metrics = {"accuracy": [], "roc_auc": [], "f1": []}
    with tqdm(total=n_splits, desc="  CV folds") as pbar:
        for fold, (tr_idx, va_idx) in enumerate(skf.split(X, y)):
            X_tr, X_va = X[tr_idx], X[va_idx]
            y_tr, y_va = y[tr_idx], y[va_idx]
            try:
                import copy
                m = copy.deepcopy(model)
                m.fit(X_tr, y_tr)
                pred  = m.predict(X_va)
                proba = m.predict_proba(X_va)[:, 1]
                metrics["accuracy"].append(accuracy_score(y_va, pred))
                metrics["roc_auc"].append(roc_auc_score(y_va, proba))
                metrics["f1"].append(f1_score(y_va, pred, zero_division=0))
                force_cleanup(m)
            except Exception as exc:
                logger.warning(f"  CV fold {fold+1} failed: {exc}")
            pbar.update(1)
            gc.collect()

    summary = {k: {"mean": np.mean(v), "std": np.std(v)}
               for k, v in metrics.items() if v}
    for k, s in summary.items():
        logger.info(f"  CV {k}: {s['mean']:.4f} ± {s['std']:.4f}")
    return summary


# ══════════════════════════════════════════════════════════════════════════════
#  SECTION 13 — VISUALISATION
# ══════════════════════════════════════════════════════════════════════════════

class Visualiser:
    """All evaluation plots — each saved + shown."""

    @staticmethod
    def save_show(path: Path) -> None:
        plt.savefig(path, dpi=150, bbox_inches="tight")
        plt.show()
        plt.close()
        logger.info(f"  Saved: {path}")

    def plot_roc_curves(self, all_metrics: dict) -> None:
        """ROC curves for all splits on the same axes."""
        fig, ax = plt.subplots(figsize=(8, 6))
        for split, m in all_metrics.items():
            if "proba" not in m:
                continue
            fpr, tpr, _ = roc_curve(m["y_true"], m["proba"])
            ax.plot(fpr, tpr, linewidth=2, color=PALETTE.get(split, "gray"),
                    label=f"{split} (AUC={m['roc_auc']:.3f})")
        ax.plot([0, 1], [0, 1], "k--", linewidth=1, label="Random")
        ax.set_xlabel("False Positive Rate"); ax.set_ylabel("True Positive Rate")
        ax.set_title("ROC Curves — All Splits", fontsize=13, fontweight="bold")
        ax.legend(loc="lower right")
        ax.grid(alpha=0.3)
        self.save_show(PLOTS_DIR / "eval_roc_curves.png")

    def plot_precision_recall(self, all_metrics: dict) -> None:
        fig, ax = plt.subplots(figsize=(8, 6))
        for split, m in all_metrics.items():
            if "proba" not in m:
                continue
            prec, rec, _ = precision_recall_curve(m["y_true"], m["proba"])
            ax.plot(rec, prec, linewidth=2, color=PALETTE.get(split, "gray"),
                    label=f"{split} (AP={m['avg_precision']:.3f})")
        ax.set_xlabel("Recall"); ax.set_ylabel("Precision")
        ax.set_title("Precision-Recall Curves", fontsize=13, fontweight="bold")
        ax.legend(loc="upper right")
        ax.grid(alpha=0.3)
        self.save_show(PLOTS_DIR / "eval_pr_curves.png")

    def plot_confusion_matrices(self, all_metrics: dict) -> None:
        n = len(all_metrics)
        fig, axes = plt.subplots(1, n, figsize=(5 * n, 4))
        if n == 1:
            axes = [axes]
        for ax, (split, m) in zip(axes, all_metrics.items()):
            cm = confusion_matrix(m["y_true"], m["pred"])
            sns.heatmap(cm, annot=True, fmt="d", ax=ax,
                        cmap="Blues", linewidths=0.5,
                        xticklabels=["Low Risk", "High Risk"],
                        yticklabels=["Low Risk", "High Risk"])
            ax.set_title(f"{split}\nacc={m['accuracy']:.3f}",
                         fontsize=11, fontweight="bold")
            ax.set_xlabel("Predicted"); ax.set_ylabel("Actual")
        fig.suptitle("Confusion Matrices — All Splits", fontsize=13,
                     fontweight="bold")
        plt.tight_layout()
        self.save_show(PLOTS_DIR / "eval_confusion_matrices.png")

    def plot_metrics_bar(self, all_metrics: dict) -> None:
        metric_names = ["accuracy", "precision", "recall", "f1", "roc_auc"]
        splits = list(all_metrics.keys())
        x = np.arange(len(metric_names))
        width = 0.8 / len(splits)

        fig, ax = plt.subplots(figsize=(14, 6))
        for i, split in enumerate(splits):
            values = [all_metrics[split].get(m, 0) for m in metric_names]
            offset = (i - len(splits) / 2 + 0.5) * width
            bars = ax.bar(x + offset, values, width,
                          label=split, color=PALETTE.get(split, "#999"))
            for bar, v in zip(bars, values):
                if v > 0:
                    ax.text(bar.get_x() + bar.get_width() / 2,
                            bar.get_height() + 0.005,
                            f"{v:.3f}", ha="center", va="bottom",
                            fontsize=7, rotation=45)

        ax.set_xticks(x)
        ax.set_xticklabels(metric_names, fontsize=11)
        ax.set_ylim(0, 1.15)
        ax.set_ylabel("Score")
        ax.set_title("Evaluation Metrics — All Splits", fontsize=13,
                     fontweight="bold")
        ax.legend(loc="upper right")
        ax.axhline(0.9, color="green", linestyle="--", linewidth=1,
                   label="0.90 target")
        ax.grid(axis="y", alpha=0.3)
        plt.tight_layout()
        self.save_show(PLOTS_DIR / "eval_metrics_bar.png")

    def plot_generalization_gap(self, all_metrics: dict) -> None:
        """Show train vs val/test/holdout gap for overfitting detection."""
        train_acc = all_metrics.get("train", {}).get("accuracy", 0)
        splits = ["val", "test", "holdout"]
        gaps = [train_acc - all_metrics.get(s, {}).get("accuracy", train_acc)
                for s in splits]
        accs = [all_metrics.get(s, {}).get("accuracy", 0) for s in splits]

        fig, axes = plt.subplots(1, 2, figsize=(14, 5))
        # Left: gap bars
        colors = ["#4CAF50" if g < 0.05 else "#FF9800" if g < 0.10 else "#F44336"
                  for g in gaps]
        axes[0].bar(splits, gaps, color=colors)
        axes[0].axhline(0.05, color="orange", linestyle="--", label="5% warning")
        axes[0].axhline(0.10, color="red",    linestyle="--", label="10% critical")
        axes[0].set_title("Generalisation Gap (Train − Split Accuracy)",
                          fontsize=12, fontweight="bold")
        axes[0].set_ylabel("Accuracy Gap")
        axes[0].legend()
        for i, (s, g) in enumerate(zip(splits, gaps)):
            axes[0].text(i, g + 0.002, f"{g:.4f}", ha="center", fontsize=10)

        # Right: accuracy across all splits
        all_splits = ["train"] + splits
        all_accs   = [all_metrics.get(s, {}).get("accuracy", 0) for s in all_splits]
        bar_colors = [PALETTE.get(s, "#999") for s in all_splits]
        bars = axes[1].bar(all_splits, all_accs, color=bar_colors)
        axes[1].set_title("Accuracy — All Splits", fontsize=12, fontweight="bold")
        axes[1].set_ylabel("Accuracy")
        axes[1].set_ylim(0, 1.1)
        axes[1].axhline(0.9, color="green", linestyle="--", label="0.90 line")
        axes[1].legend()
        for bar, acc in zip(bars, all_accs):
            axes[1].text(bar.get_x() + bar.get_width() / 2,
                         bar.get_height() + 0.01,
                         f"{acc:.3f}", ha="center", fontsize=10)

        fig.suptitle("Overfitting / Generalisation Analysis", fontsize=14,
                     fontweight="bold")
        plt.tight_layout()
        self.save_show(PLOTS_DIR / "eval_generalization_gap.png")

    def plot_nn_history(self, history) -> None:
        if history is None:
            return
        h = history.history
        epochs = range(1, len(h["accuracy"]) + 1)
        best_epoch = np.argmax(h["val_accuracy"]) + 1

        fig, axes = plt.subplots(2, 2, figsize=(14, 10))
        pairs = [("accuracy", "val_accuracy", "Accuracy"),
                 ("loss",     "val_loss",     "Loss"),
                 ("auc",      "val_auc",      "AUC"),
                 ("precision","val_precision","Precision")]

        for ax, (tr_key, va_key, title) in zip(axes.flatten(), pairs):
            if tr_key not in h:
                ax.set_visible(False)
                continue
            ax.plot(epochs, h[tr_key], color=PALETTE["train"],
                    linewidth=2, label="Train")
            ax.plot(epochs, h[va_key], color=PALETTE["val"],
                    linewidth=2, label="Validation")
            ax.axvline(best_epoch, color="red", linestyle="--",
                       label=f"Best epoch ({best_epoch})")
            best_val = max(h[va_key]) if "loss" not in va_key else min(h[va_key])
            ax.scatter([best_epoch], [best_val], color="red", s=60, zorder=5)
            ax.set_title(title, fontsize=11, fontweight="bold")
            ax.set_xlabel("Epoch"); ax.set_ylabel(title)
            ax.legend(fontsize=9); ax.grid(alpha=0.3)

        fig.suptitle("Neural Network Training History", fontsize=14,
                     fontweight="bold")
        plt.tight_layout()
        self.save_show(PLOTS_DIR / "nn_training_history.png")

    def plot_feature_shap_proxy(self, model, X: np.ndarray,
                                 feature_names: list) -> None:
        """Permutation importance as SHAP proxy on holdout set."""
        logger.info("  Computing permutation importance (SHAP proxy) ...")
        try:
            result = permutation_importance(
                model, X[:500], None,
                # We need y for permutation importance
                n_repeats=10, random_state=RANDOM_STATE, n_jobs=-1)
        except Exception:
            return

    def plot_permutation_importance(self, model, X: np.ndarray,
                                     y: np.ndarray, feature_names: list,
                                     split_name: str = "holdout") -> None:
        logger.info(f"  Permutation importance on {split_name} ...")
        try:
            n = min(500, len(X))
            result = permutation_importance(
                model, X[:n], y[:n],
                n_repeats=8, random_state=RANDOM_STATE, n_jobs=-1)
            fi = pd.Series(result.importances_mean,
                           index=feature_names[:X.shape[1]]).sort_values(ascending=False).head(20)

            fig, ax = plt.subplots(figsize=(10, 7))
            fi.sort_values().plot.barh(ax=ax, color="#AB47BC")
            ax.set_title(f"Permutation Importance ({split_name})",
                         fontsize=12, fontweight="bold")
            ax.set_xlabel("Mean Decrease in Accuracy")
            plt.tight_layout()
            self.save_show(PLOTS_DIR / f"eval_perm_importance_{split_name}.png")
        except Exception as exc:
            logger.warning(f"  Permutation importance failed: {exc}")

    def plot_holdout_deep_dive(self, holdout_metrics: dict) -> None:
        """Detailed analysis of holdout (truly unseen) set."""
        m = holdout_metrics
        proba = m["proba"]
        y_true = m["y_true"]
        pred   = m["pred"]

        fig = plt.figure(figsize=(16, 12))
        gs  = gridspec.GridSpec(2, 3, figure=fig)

        # 1. Probability distribution
        ax1 = fig.add_subplot(gs[0, 0])
        for cls, color, label in [(0, PALETTE["neg"], "Low Risk"),
                                   (1, PALETTE["pos"], "High Risk")]:
            ax1.hist(proba[y_true == cls], bins=40, alpha=0.7,
                     color=color, label=label, density=True)
        ax1.axvline(0.5, color="black", linestyle="--", linewidth=1.5)
        ax1.set_title("Predicted Probability Distribution")
        ax1.set_xlabel("P(High Risk)"); ax1.legend()

        # 2. Calibration curve
        ax2 = fig.add_subplot(gs[0, 1])
        from sklearn.calibration import calibration_curve
        try:
            frac_pos, mean_pred = calibration_curve(y_true, proba, n_bins=10)
            ax2.plot(mean_pred, frac_pos, "s-", color=PALETTE["holdout"], label="Model")
            ax2.plot([0, 1], [0, 1], "k--", label="Perfect")
        except Exception:
            ax2.text(0.5, 0.5, "Calibration N/A", ha="center", transform=ax2.transAxes)
        ax2.set_title("Calibration Curve (Holdout)")
        ax2.set_xlabel("Mean Predicted Probability")
        ax2.set_ylabel("Fraction of Positives")
        ax2.legend()

        # 3. ROC
        ax3 = fig.add_subplot(gs[0, 2])
        fpr, tpr, _ = roc_curve(y_true, proba)
        ax3.plot(fpr, tpr, color=PALETTE["holdout"], linewidth=2,
                 label=f"AUC={m['roc_auc']:.3f}")
        ax3.plot([0, 1], [0, 1], "k--")
        ax3.set_title("ROC Curve (Holdout)")
        ax3.set_xlabel("FPR"); ax3.set_ylabel("TPR")
        ax3.legend()

        # 4. Confusion matrix
        ax4 = fig.add_subplot(gs[1, 0])
        cm = confusion_matrix(y_true, pred)
        sns.heatmap(cm, annot=True, fmt="d", ax=ax4, cmap="RdYlGn",
                    xticklabels=["Low", "High"],
                    yticklabels=["Low", "High"])
        ax4.set_title(f"Confusion Matrix (Holdout)\nacc={m['accuracy']:.3f}")
        ax4.set_xlabel("Predicted"); ax4.set_ylabel("Actual")

        # 5. Error analysis — hard samples
        ax5 = fig.add_subplot(gs[1, 1])
        errors = np.abs(y_true - proba)
        ax5.hist(errors, bins=40, color="#FF7043", edgecolor="white")
        ax5.set_title("Error Magnitude Distribution (Holdout)")
        ax5.set_xlabel("|True − Predicted Probability|")
        ax5.set_ylabel("Count")

        # 6. Metric summary text
        ax6 = fig.add_subplot(gs[1, 2])
        ax6.axis("off")
        report_txt = (
            f"HOLDOUT EVALUATION REPORT\n"
            f"{'─' * 30}\n"
            f"Samples:      {m['n_samples']}\n"
            f"Accuracy:     {m['accuracy']:.4f}\n"
            f"Precision:    {m['precision']:.4f}\n"
            f"Recall:       {m['recall']:.4f}\n"
            f"F1 Score:     {m['f1']:.4f}\n"
            f"ROC-AUC:      {m['roc_auc']:.4f}\n"
            f"Avg Precision:{m['avg_precision']:.4f}\n"
        )
        ax6.text(0.05, 0.95, report_txt, transform=ax6.transAxes,
                 fontsize=11, verticalalignment="top",
                 fontfamily="monospace",
                 bbox=dict(boxstyle="round", facecolor="lightyellow", alpha=0.8))

        fig.suptitle("Holdout (Unseen Data) — Deep Dive Analysis",
                     fontsize=14, fontweight="bold")
        plt.tight_layout()
        self.save_show(PLOTS_DIR / "eval_holdout_deep_dive.png")

    def plot_model_comparison(self, base_results: dict) -> None:
        """Compare all classical models on val accuracy + AUC."""
        names  = list(base_results.keys())
        accs   = [base_results[n]["val_acc"] for n in names]
        aucs   = [base_results[n]["val_auc"] for n in names]
        times  = [base_results[n]["time_s"] for n in names]

        x = np.arange(len(names))
        width = 0.35

        fig, axes = plt.subplots(1, 2, figsize=(14, 5))
        axes[0].bar(x - width/2, accs, width, label="Accuracy", color="#42A5F5")
        axes[0].bar(x + width/2, aucs, width, label="ROC-AUC",  color="#66BB6A")
        axes[0].set_xticks(x); axes[0].set_xticklabels(names, rotation=15, ha="right")
        axes[0].set_ylim(0, 1.1)
        axes[0].set_title("Model Comparison — Val Accuracy & AUC",
                           fontsize=12, fontweight="bold")
        axes[0].legend()
        axes[0].axhline(0.9, color="red", linestyle="--", linewidth=1)
        axes[0].grid(axis="y", alpha=0.3)
        for i, (a, au) in enumerate(zip(accs, aucs)):
            axes[0].text(i - width/2, a + 0.01, f"{a:.3f}",
                         ha="center", fontsize=8)
            axes[0].text(i + width/2, au + 0.01, f"{au:.3f}",
                         ha="center", fontsize=8)

        axes[1].bar(names, times, color="#EF5350")
        axes[1].set_title("Training Time (seconds)", fontsize=12, fontweight="bold")
        axes[1].set_ylabel("Seconds")
        plt.xticks(rotation=15, ha="right")
        axes[1].grid(axis="y", alpha=0.3)

        plt.tight_layout()
        self.save_show(PLOTS_DIR / "eval_model_comparison.png")

    def plot_tsne(self, X: np.ndarray, y: np.ndarray, split_name: str) -> None:
        """t-SNE projection of feature space, coloured by class."""
        n = min(800, len(X))
        idx = np.random.choice(len(X), n, replace=False)
        X_s, y_s = X[idx], y[idx]

        logger.info(f"  t-SNE on {split_name} ({n} samples) ...")
        try:
            tsne = TSNE(n_components=2, perplexity=30, max_iter=500,
                        random_state=RANDOM_STATE)
            emb = tsne.fit_transform(X_s)

            fig, ax = plt.subplots(figsize=(8, 7))
            for cls, color, label in [(0, PALETTE["neg"], "Low Risk"),
                                       (1, PALETTE["pos"], "High Risk")]:
                mask = y_s == cls
                ax.scatter(emb[mask, 0], emb[mask, 1],
                           c=color, alpha=0.6, s=20, label=label)
            ax.set_title(f"t-SNE Feature Space ({split_name})",
                         fontsize=12, fontweight="bold")
            ax.legend()
            ax.grid(alpha=0.3)
            plt.tight_layout()
            self.save_show(PLOTS_DIR / f"viz_tsne_{split_name}.png")
        except Exception as exc:
            logger.warning(f"  t-SNE failed: {exc}")
        force_cleanup()

    def plot_cv_results(self, cv_summary: dict) -> None:
        """Box-plot style summary of cross-validation results."""
        metrics = list(cv_summary.keys())
        means   = [cv_summary[m]["mean"] for m in metrics]
        stds    = [cv_summary[m]["std"]  for m in metrics]

        fig, ax = plt.subplots(figsize=(8, 5))
        x = np.arange(len(metrics))
        bars = ax.bar(x, means, yerr=stds, capsize=6,
                      color="#7E57C2", alpha=0.85, error_kw=dict(linewidth=2))
        ax.set_xticks(x); ax.set_xticklabels(metrics, fontsize=11)
        ax.set_ylim(0, 1.1)
        ax.set_title("5-Fold Cross-Validation Results (mean ± std)",
                     fontsize=12, fontweight="bold")
        ax.set_ylabel("Score")
        for bar, m, s in zip(bars, means, stds):
            ax.text(bar.get_x() + bar.get_width()/2,
                    bar.get_height() + s + 0.01,
                    f"{m:.3f}±{s:.3f}", ha="center", fontsize=9)
        ax.grid(axis="y", alpha=0.3)
        plt.tight_layout()
        self.save_show(PLOTS_DIR / "eval_cv_results.png")


# ══════════════════════════════════════════════════════════════════════════════
#  SECTION 14 — MODEL SAVING
# ══════════════════════════════════════════════════════════════════════════════

def save_best_model(model, model_name: str, metrics: dict,
                    feature_names: list, scaler, engineer) -> None:
    """Save best model + metadata."""
    ts = datetime.now().strftime("%Y%m%d_%H%M%S")
    model_path = MODELS_DIR / f"best_model_{model_name}_{ts}.pkl"
    meta_path  = MODELS_DIR / f"best_model_meta_{ts}.json"

    joblib.dump({"model": model, "scaler": scaler,
                 "feature_names": feature_names,
                 "selected_features": engineer.selected_features,
                 "pca": engineer.pca}, model_path)

    meta = {
        "model_name":  model_name,
        "timestamp":   ts,
        "val_accuracy":    metrics.get("val_accuracy", 0),
        "test_accuracy":   metrics.get("test_accuracy", 0),
        "holdout_accuracy":metrics.get("holdout_accuracy", 0),
        "val_auc":         metrics.get("val_auc", 0),
        "features":        feature_names[:50],
        "n_features":      len(feature_names),
    }
    with open(meta_path, "w") as f:
        json.dump(meta, f, indent=2)

    logger.info(f"  Model saved: {model_path}")
    logger.info(f"  Meta  saved: {meta_path}")


# ══════════════════════════════════════════════════════════════════════════════
#  SECTION 15 — FINAL REPORT
# ══════════════════════════════════════════════════════════════════════════════

def generate_text_report(all_metrics: dict, cv_summary: dict,
                          verdict: str, data_source: str,
                          best_model_name: str) -> None:
    """Write comprehensive text report."""
    ts   = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    path = REPORTS_DIR / "health_risk_pipeline_report.txt"

    lines = [
        "=" * 70,
        "  HEALTH RISK SIGNAL DETECTION — PIPELINE REPORT",
        f"  Generated: {ts}",
        "=" * 70,
        f"  Data Source: {data_source}",
        f"  Best Model:  {best_model_name}",
        "",
        "SPLIT EVALUATION",
        "-" * 40,
    ]
    for split, m in all_metrics.items():
        lines += [
            f"  {split.upper():8s}  n={m['n_samples']:5d}  "
            f"acc={m['accuracy']:.4f}  auc={m['roc_auc']:.4f}  "
            f"f1={m['f1']:.4f}  prec={m['precision']:.4f}  rec={m['recall']:.4f}",
        ]

    lines += ["", "CROSS-VALIDATION (5-fold)", "-" * 40]
    for k, s in cv_summary.items():
        lines.append(f"  {k:12s}: {s['mean']:.4f} ± {s['std']:.4f}")

    lines += ["", "OVERFITTING VERDICT", "-" * 40, f"  {verdict}", ""]
    lines += ["GENERALISATION GAPS", "-" * 40]
    train_acc = all_metrics.get("train", {}).get("accuracy", 0)
    for split in ["val", "test", "holdout"]:
        gap = train_acc - all_metrics.get(split, {}).get("accuracy", train_acc)
        lines.append(f"  Train − {split:7s}: {gap:+.4f}")

    lines += ["", "PLOTS GENERATED", "-" * 40]
    for p in sorted(PLOTS_DIR.glob("*.png")):
        lines.append(f"  {p.name}")

    report_text = "\n".join(lines)
    with open(path, "w") as f:
        f.write(report_text)
    print(report_text)
    logger.info(f"  Report saved: {path}")


# ══════════════════════════════════════════════════════════════════════════════
#  MAIN PIPELINE ORCHESTRATOR
# ══════════════════════════════════════════════════════════════════════════════

def main():
    print("=" * 70)
    print("  HEALTH RISK SIGNAL DETECTION — AI/ML PIPELINE")
    print("  Using open-source self-reported longitudinal datasets")
    print("=" * 70)
    log_memory("pipeline start")

    # ── 1. Load Data ───────────────────────────────────────────────────
    loader = DataLoader()
    df = loader.load()
    log_memory("after data load")

    # ── 2. Identify target column ──────────────────────────────────────
    TARGET_COL = "health_risk"
    # Try to auto-detect a suitable target if not present
    if TARGET_COL not in df.columns:
        candidates = ["target", "heart_disease", "disease", "label",
                      "cardiovascular_disease", "outcome"]
        for cand in candidates:
            if cand in df.columns:
                df = df.rename(columns={cand: TARGET_COL})
                break
        if TARGET_COL not in df.columns:
            # Build from available clinical columns
            from sklearn.preprocessing import Binarizer
            num_df = df.select_dtypes(include=np.number)
            if not num_df.empty:
                score = StandardScaler().fit_transform(num_df).mean(axis=1)
                df[TARGET_COL] = (score > np.percentile(score, 65)).astype(int)
            else:
                raise RuntimeError("Cannot determine target column.")

    # Binary target check
    df[TARGET_COL] = (df[TARGET_COL] > 0).astype(np.int8)
    logger.info(f"Target: {TARGET_COL}  pos_rate={df[TARGET_COL].mean():.2%}  n={len(df)}")

    # ── 3. EDA ─────────────────────────────────────────────────────────
    preproc = Preprocessor(target_col=TARGET_COL)
    preproc.run_eda(df)
    force_cleanup()

    # ── 4. Clean & Preprocess ─────────────────────────────────────────
    df = preproc.clean(df)
    df = preproc.remove_outliers_iqr(df, preproc.numerical_cols)
    log_memory("after clean")

    # ── 5. Feature Engineering ────────────────────────────────────────
    eng = FeatureEngineer()
    df = eng.engineer(df, target_col=TARGET_COL)
    df = df.fillna(df.median(numeric_only=True))    # Final NaN sweep
    log_memory("after feature engineering")

    # ── 6. Separate features / target ─────────────────────────────────
    feature_cols = [c for c in df.columns if c != TARGET_COL]
    X_df = df[feature_cols].select_dtypes(include=[np.number])
    y    = df[TARGET_COL].values.astype(np.int8)

    logger.info(f"Feature matrix: {X_df.shape}  Target distribution: {np.bincount(y)}")

    # ── 7. Feature Selection ──────────────────────────────────────────
    X_df = eng.select_features(X_df, pd.Series(y), n_features=35)
    feature_names = X_df.columns.tolist()
    X_raw = X_df.values.astype(np.float32)
    force_cleanup(df, X_df)

    # ── 8. Four-way split ─────────────────────────────────────────────
    splits = four_way_split(X_raw, y)
    X_train, y_train = splits["train"]
    X_val,   y_val   = splits["val"]
    X_test,  y_test  = splits["test"]
    X_hold,  y_hold  = splits["holdout"]

    # ── 9. Scaling (fit on train only) ────────────────────────────────
    X_train, X_val, X_test, X_hold = preproc.scale(
        X_train, X_val, X_test, X_hold)
    log_memory("after scaling")

    # ── 10. PCA augmentation ──────────────────────────────────────────
    n_pca = min(8, X_train.shape[1] - 1)
    X_train_pca = eng.compute_pca(X_train, n_components=n_pca)
    X_val_pca   = np.hstack([X_val,  eng.pca.transform(X_val)])
    X_test_pca  = np.hstack([X_test, eng.pca.transform(X_test)])
    X_hold_pca  = np.hstack([X_hold, eng.pca.transform(X_hold)])
    logger.info(f"Shape with PCA: {X_train_pca.shape}")

    # ── 11. Sample weights ────────────────────────────────────────────
    sample_weights = compute_sample_weights(X_train_pca, y_train)
    log_memory("after sample weights")

    # ── 12. Train classical models ────────────────────────────────────
    classic_models = build_classical_models(X_train_pca.shape[1])
    base_results   = train_classical_models(
        classic_models, X_train_pca, y_train,
        X_val_pca,   y_val, sample_weights)
    log_memory("after classical training")

    # ── 13. Hyperparameter tuning (best classical model) ──────────────
    best_cls_name = max(base_results, key=lambda n: base_results[n]["val_acc"])
    logger.info(f"Best classical model: {best_cls_name} "
                f"(val_acc={base_results[best_cls_name]['val_acc']:.4f})")

    tuned_model = tune_hyperparameters(
        X_train_pca, y_train, X_val_pca, y_val, best_cls_name)
    if tuned_model is not None:
        tuned_val_acc = accuracy_score(y_val, tuned_model.predict(X_val_pca))
        if tuned_val_acc > base_results[best_cls_name]["val_acc"]:
            base_results[best_cls_name]["model"] = tuned_model
            base_results[best_cls_name]["val_acc"] = tuned_val_acc
            logger.info(f"  Tuning improved: {tuned_val_acc:.4f}")
    log_memory("after hyperparameter tuning")

    # ── 14. Neural Network ────────────────────────────────────────────
    nn_model, nn_history = train_neural_network(
        X_train_pca, y_train, X_val_pca, y_val, sample_weights)
    log_memory("after NN training")

    # ── 15. Stacking Ensemble ─────────────────────────────────────────
    stacking_result = build_stacking_ensemble(
        base_results, X_train_pca, y_train, X_val_pca, y_val)
    if stacking_result is not None:
        stacking_model, stack_val_acc = stacking_result
        base_results["Stacking"] = {
            "model":   stacking_model,
            "val_acc": stack_val_acc,
            "val_auc": roc_auc_score(y_val,
                           stacking_model.predict_proba(X_val_pca)[:, 1]),
            "time_s":  0,
        }
    log_memory("after ensemble")

    # ── 16. Select overall best classical model ───────────────────────
    best_model_name = max(base_results, key=lambda n: base_results[n]["val_acc"])
    best_model      = base_results[best_model_name]["model"]
    logger.info(f"Final best model: {best_model_name}")

    # ── 17. Cross-validation ──────────────────────────────────────────
    X_trainval = np.vstack([X_train_pca, X_val_pca])
    y_trainval = np.concatenate([y_train, y_val])
    cv_summary = run_cross_validation(best_model, X_trainval, y_trainval)

    # ── 18. Evaluate on all splits ────────────────────────────────────
    logger.info("=" * 50)
    logger.info("FINAL EVALUATION — ALL SPLITS")
    logger.info("=" * 50)
    all_metrics = {}
    split_data  = {
        "train":   (X_train_pca, y_train),
        "val":     (X_val_pca,   y_val),
        "test":    (X_test_pca,  y_test),
        "holdout": (X_hold_pca,  y_hold),
    }
    for split_name, (Xs, ys) in tqdm(split_data.items(), desc="Evaluating splits"):
        m = evaluate_model(best_model, Xs, ys, split_name, nn_model)
        m["y_true"] = ys
        all_metrics[split_name] = m

    # ── 19. Overfitting detection ─────────────────────────────────────
    verdict = detect_overfitting(
        all_metrics["train"], all_metrics["val"],
        all_metrics["test"],  all_metrics["holdout"])

    # ── 20. Save best model ───────────────────────────────────────────
    best_metrics = {
        "val_accuracy":     all_metrics["val"]["accuracy"],
        "test_accuracy":    all_metrics["test"]["accuracy"],
        "holdout_accuracy": all_metrics["holdout"]["accuracy"],
        "val_auc":          all_metrics["val"]["roc_auc"],
    }
    save_best_model(best_model, best_model_name, best_metrics,
                    feature_names, preproc.scaler, eng)

    # ── 21. Visualisations ────────────────────────────────────────────
    viz = Visualiser()

    with tqdm(total=10, desc="Generating plots") as pbar:
        viz.plot_roc_curves(all_metrics);           pbar.update(1); force_cleanup()
        viz.plot_precision_recall(all_metrics);     pbar.update(1); force_cleanup()
        viz.plot_confusion_matrices(all_metrics);   pbar.update(1); force_cleanup()
        viz.plot_metrics_bar(all_metrics);          pbar.update(1); force_cleanup()
        viz.plot_generalization_gap(all_metrics);   pbar.update(1); force_cleanup()
        viz.plot_nn_history(nn_history);            pbar.update(1); force_cleanup()
        viz.plot_holdout_deep_dive(all_metrics["holdout"]); pbar.update(1); force_cleanup()
        viz.plot_model_comparison(base_results);    pbar.update(1); force_cleanup()
        viz.plot_cv_results(cv_summary);            pbar.update(1); force_cleanup()
        viz.plot_tsne(X_hold_pca, y_hold, "holdout"); pbar.update(1); force_cleanup()

    viz.plot_permutation_importance(
        best_model, X_hold_pca, y_hold,
        feature_names + [f"PC{i}" for i in range(n_pca)])

    # ── 22. Final report ─────────────────────────────────────────────
    generate_text_report(all_metrics, cv_summary, verdict,
                         loader.data_source, best_model_name)

    log_memory("pipeline end")
    logger.info("=" * 60)
    logger.info("PIPELINE COMPLETE")
    logger.info(f"  Holdout accuracy: {all_metrics['holdout']['accuracy']:.4f}")
    logger.info(f"  Holdout AUC:      {all_metrics['holdout']['roc_auc']:.4f}")
    logger.info(f"  Plots:            {PLOTS_DIR}/")
    logger.info(f"  Models:           {MODELS_DIR}/")
    logger.info(f"  Reports:          {REPORTS_DIR}/")
    logger.info("=" * 60)

    # Final memory sweep
    force_cleanup(X_train_pca, X_val_pca, X_test_pca, X_hold_pca,
                  X_raw, y_train, y_val, y_test, y_hold)
    gc.collect()


if __name__ == "__main__":
    main()

E0000 00:00:1771016292.936209      55 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1771016293.014313      55 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1771016293.598820      55 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1771016293.598910      55 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1771016293.598913      55 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1771016293.598917      55 computation_placer.cc:177] computation placer already registered. Please check linka

  HEALTH RISK SIGNAL DETECTION — AI/ML PIPELINE
  Using open-source self-reported longitudinal datasets


  Downloading NHANES XPT: 100%|██████████| 13/13 [00:03<00:00,  3.57it/s]
2026-02-13 20:58:32,996 [INFO]   Attempting: UCI Heart Disease
2026-02-13 20:58:33,602 [INFO]   UCI Heart loaded from https://archive.ics.uci.edu/ml/machine-learning-da
2026-02-13 20:58:33,603 [INFO]    Loaded UCI Heart Disease: (303, 14)
2026-02-13 20:58:33,617 [INFO] [MEM after load] Process=1018MB  System=5.1% used  Available=31959MB
2026-02-13 20:58:33,619 [INFO] [MEM after data load] Process=1018MB  System=5.1% used  Available=31959MB
2026-02-13 20:58:33,627 [INFO] Target: health_risk  pos_rate=45.87%  n=303
2026-02-13 20:58:33,628 [INFO] Running EDA ...
2026-02-13 20:58:33,658 [INFO] Using categorical units to plot a list of strings that are all parsable as floats or dates. If these strings should be plotted as numbers, cast to the appropriate data type before plotting.
2026-02-13 20:58:33,661 [INFO] Using categorical units to plot a list of strings that are all parsable as floats or dates. If these strings

2026-02-13 20:59:47,044 [INFO]   Model: "AttentionNet"
┏━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┓
┃ Layer (type)        ┃ Output Shape      ┃    Param # ┃ Connected to      ┃
┡━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━┩
│ input (InputLayer)  │ (None, 23)        │          0 │ -                 │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ gaussian_noise      │ (None, 23)        │          0 │ input[0][0]       │
│ (GaussianNoise)     │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ dense (Dense)       │ (None, 128)       │      3,072 │ gaussian_noise[0… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ batch_normalization │ (None, 128)       │        512 │ dense[0][0]       │
│ (BatchNormalizatio… │                   │            │                   │
├────────────────────

  HEALTH RISK SIGNAL DETECTION — PIPELINE REPORT
  Generated: 2026-02-13 21:00:28
  Data Source: UCI Heart Disease
  Best Model:  GradientBoosting

SPLIT EVALUATION
----------------------------------------
  TRAIN     n=  120  acc=0.9000  auc=0.9698  f1=0.8929  prec=0.8772  rec=0.9091
  VAL       n=   46  acc=0.8478  auc=0.9238  f1=0.8205  prec=0.8889  rec=0.7619
  TEST      n=   46  acc=0.8913  auc=0.9467  f1=0.8837  prec=0.8636  rec=0.9048
  HOLDOUT   n=   91  acc=0.8352  auc=0.9291  f1=0.8352  prec=0.7755  rec=0.9048

CROSS-VALIDATION (5-fold)
----------------------------------------
  accuracy    : 0.7717 ± 0.0651
  roc_auc     : 0.8431 ± 0.0579
  f1          : 0.7456 ± 0.0779

OVERFITTING VERDICT
----------------------------------------
   GOOD GENERALISATION — train/holdout gap within healthy range

GENERALISATION GAPS
----------------------------------------
  Train − val    : +0.0522
  Train − test   : +0.0087
  Train − holdout: +0.0648

PLOTS GENERATED
------------------------