# Autism Screening: End-to-End EDA and PyTorch Baseline

This notebook explores a tabular dataset end-to-end using pandas, numpy, plotly, and PyTorch. It includes:

- Data overview and compact data dictionary
- Type inference and cleaning
- EDA with interactive Plotly visuals
- Preprocessing (encoding, scaling, splitting)
- PyTorch MLP baseline with early stopping
- Evaluation metrics and diagnostics
- Permutation importance for explainability
- Summary and next steps

All steps are reproducible with fixed seeds and configurable via the `CONFIG` dict below.


In [None]:
# Config & Imports
# Comment: Centralize configuration and set seeds for reproducibility.
import os
import sys
import json
from typing import Dict, List, Tuple, Optional

import numpy as np
import pandas as pd

import plotly.express as px
import plotly.graph_objects as go

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

# Reproducibility
def set_seeds(seed: int = 42):
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random`` module, NumPy's legacy global RNG and PyTorch
    (CPU plus all CUDA devices), then switches cuDNN to deterministic mode.

    Args:
        seed: Value applied to every generator.
    """
    # Local import: the standard-library RNG is only needed in this helper.
    import random

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Trade cuDNN autotuning for run-to-run reproducibility.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# Central notebook settings; later cells read their knobs from this mapping.
CONFIG: Dict[str, object] = dict(
    csv_path="autism_screening.csv",  # Update if needed
    target=None,  # e.g., "Class/ASD" or None
    problem_type="unspecified",  # "classification" | "regression" | "unspecified"
    test_size=0.2,
    val_size=0.2,
    seed=42,
    max_cat_cardinality=50,  # threshold for treating as categorical
    n_bins_hist=30,
)

# Seed the RNGs before any data-dependent work runs.
set_seeds(int(CONFIG["seed"]))

# Print package versions for reproducibility
print("Versions:")
for label, version in (
    ("python", sys.version.split()[0]),
    ("pandas", pd.__version__),
    ("numpy", np.__version__),
    ("torch", torch.__version__),
):
    print(f"{label}:", version)

# Utility: Safe display helper
def display_heading(text: str):
    """Print *text* framed by two 80-character ``=`` rules, preceded by a blank line."""
    rule = "=" * 80
    print(f"\n{rule}\n{text}\n{rule}")

# Utility: Numpy vectorized helpers
def np_zscore(x: np.ndarray) -> np.ndarray:
    """Standardize *x* along axis 0, ignoring NaNs when computing the statistics.

    Columns whose standard deviation is exactly zero are divided by 1 instead,
    so constant columns map to 0 rather than triggering a division by zero.
    """
    center = np.nanmean(x, axis=0)
    spread = np.nanstd(x, axis=0)
    safe_spread = np.where(spread == 0, 1.0, spread)
    centered = x - center
    return centered / safe_spread

def np_winsorize(x: np.ndarray, lower_q: float = 0.01, upper_q: float = 0.99) -> np.ndarray:
    """Clip values to per-column quantile bounds (winsorization).

    NaNs are ignored when the bounds are computed. The result has the same
    shape as the input: a 1-D input is treated as one column and returned 1-D.

    Args:
        x: 1-D or 2-D array-like of numeric values.
        lower_q: Lower quantile in [0, 1]; smaller values are raised to it.
        upper_q: Upper quantile in [0, 1]; larger values are lowered to it.

    Returns:
        Array with every column clipped to its own [lower, upper] bounds.

    Raises:
        ValueError: If the quantiles are not ordered within [0, 1].
    """
    if not 0.0 <= lower_q <= upper_q <= 1.0:
        raise ValueError(
            f"Expected 0 <= lower_q <= upper_q <= 1, got lower_q={lower_q}, upper_q={upper_q}"
        )
    arr = np.asarray(x)
    is_1d = arr.ndim == 1
    cols = arr.reshape(-1, 1) if is_1d else arr
    # Linear interpolation is numpy's default; spelling it out via `method=`
    # would only restrict the supported numpy versions.
    lows = np.nanpercentile(cols, lower_q * 100, axis=0)
    highs = np.nanpercentile(cols, upper_q * 100, axis=0)
    clipped = np.clip(cols, lows, highs)
    # Undo the temporary column reshape so callers get back the shape they passed in.
    return clipped.ravel() if is_1d else clipped



In [None]:
# Load & Inspect
# Comment: Load CSV and produce a compact overview and data dictionary.
from pathlib import Path

# Resolve the configured CSV location (a leading "~" expands to the home directory).
csv_path = Path(CONFIG["csv_path"]).expanduser()
# Raise explicitly rather than assert: assert statements are skipped under `python -O`.
if not csv_path.is_file():
    raise FileNotFoundError(f"CSV not found at {csv_path}")

df = pd.read_csv(csv_path)

# Quick structural overview: size, dtypes, duplicate rows, per-column missing counts.
display_heading("Basic Overview")
print("shape:", df.shape)
print("dtypes:\n", df.dtypes)
print("duplicates:", int(df.duplicated().sum()))
print("missing per column:\n", df.isna().sum())

# Build data dictionary
# For categoricals: list # unique and example values; for numericals: min/mean/std/max

def infer_types(df: pd.DataFrame, max_cat_cardinality: int) -> Tuple[List[str], List[str], List[str]]:
    """Partition the columns of *df* into numeric, categorical and datetime names.

    A column is datetime when its dtype already is, or when more than 90% of
    its values parse with ``pd.to_datetime``. Numeric dtypes are numeric.
    Every other column is categorical.

    Args:
        df: Frame whose columns are classified.
        max_cat_cardinality: Kept for interface compatibility. It does not
            change the result: high-cardinality non-numeric columns are still
            reported as categorical.

    Returns:
        ``(numeric_cols, categorical_cols, datetime_cols)``, each in the
        frame's column order.
    """
    # Local import: only needed to silence parser noise below.
    import warnings

    numeric_cols: List[str] = []
    categorical_cols: List[str] = []
    datetime_cols: List[str] = []
    for col in df.columns:
        s = df[col]
        if pd.api.types.is_datetime64_any_dtype(s):
            datetime_cols.append(col)
            continue
        if pd.api.types.is_numeric_dtype(s):
            numeric_cols.append(col)
            continue
        # Probe whether the text parses as dates. Free text makes pandas warn
        # that no format could be inferred; that is expected for this probe.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", UserWarning)
            parsed = pd.to_datetime(s, errors="coerce")
        if parsed.notna().mean() > 0.9:
            datetime_cols.append(col)
        else:
            categorical_cols.append(col)
    return numeric_cols, categorical_cols, datetime_cols

numeric_cols, categorical_cols, datetime_cols = infer_types(df, int(CONFIG["max_cat_cardinality"]))

# Assemble one summary record per column; the extra fields depend on the inferred type.
rows = []
for col in df.columns:
    s = df[col]
    entry: Dict[str, object] = {
        "column": col,
        "dtype": str(s.dtype),
        "%missing": round(float(s.isna().mean() * 100), 2),
    }
    if col in numeric_cols:
        # Numeric columns: basic distribution statistics (NaN when describe() lacks one).
        desc = s.describe(percentiles=[])
        for stat in ("min", "mean", "std", "max"):
            entry[stat] = float(desc.get(stat, np.nan))
    elif col in categorical_cols:
        # Categorical columns: cardinality plus up to five example values.
        examples = s.dropna().astype(str).unique()[:5]
        entry["#unique"] = int(s.nunique(dropna=True))
        entry["examples"] = ", ".join(map(str, examples))
    elif col in datetime_cols:
        # Datetime columns: earliest and latest parsable timestamps.
        parsed = pd.to_datetime(s, errors="coerce")
        has_dates = parsed.notna().any()
        entry["min"] = str(parsed.min()) if has_dates else None
        entry["max"] = str(parsed.max()) if has_dates else None
    rows.append(entry)

data_dict_df = pd.DataFrame(rows)

display_heading("Data Dictionary (compact)")
print(data_dict_df.to_string(index=False))


In [None]:
# Type Inference & Cleaning
# Comment: Normalize string categories and strip whitespace; prepare parsed datetimes.

def normalize_categorical_series(s: pd.Series) -> pd.Series:
    """Strip and lower-case string values while keeping missing values missing.

    Non-string series are returned unchanged. For object/string series every
    value is converted to text, trimmed and lower-cased. Values that were
    missing in the input (``None``/NaN), and values whose normalized text is
    the literal ``"nan"``, come back as NaN.

    Args:
        s: Series to normalize.

    Returns:
        The normalized series, or *s* itself when it is not string-like.
    """
    if not (s.dtype == object or pd.api.types.is_string_dtype(s)):
        return s
    # Remember real missing values first: astype(str) would turn None into
    # the text "None", which would then survive as a bogus category.
    was_missing = s.isna()
    normalized = s.astype(str).str.strip().str.lower()
    is_missing = was_missing | (normalized == "nan")
    return normalized.mask(is_missing, np.nan)

# Apply normalization to non-numeric columns.
# The numeric column set is computed once instead of once per loop iteration.
numeric_col_names = set(df.select_dtypes(include=[np.number]).columns)
for col in df.columns:
    if col not in numeric_col_names:
        df[col] = normalize_categorical_series(df[col])

# Parse datetime columns detected earlier; unparsable values become NaT.
# (`infer_datetime_format` is deprecated in pandas 2 and therefore not passed.)
for col in datetime_cols:
    df[col] = pd.to_datetime(df[col], errors='coerce')

# Recompute missingness after normalization
missing_after = df.isna().sum()

display_heading("Missingness After Cleaning (counts)")
print(missing_after[missing_after > 0].sort_values(ascending=False))

# Imputation plan: numeric -> median; categorical -> most frequent; datetime -> leave, or fill with median date
from collections import Counter

def impute_dataframe(df_in: pd.DataFrame,
                     numeric_cols: List[str],
                     categorical_cols: List[str],
                     datetime_cols: List[str]) -> Tuple[pd.DataFrame, Dict[str, Dict[str, object]]]:
    """Fill missing values column by column and report what was used.

    Numeric columns get their median (0.0 when the column has no values),
    categorical columns their most frequent value ("missing" when there is
    none), and datetime columns their median timestamp. A datetime column
    with no parsable value is left untouched.

    Returns:
        A filled copy of *df_in* (the input is not modified) and a plan that
        maps each column name to its type, strategy and fill value.
    """
    result = df_in.copy()
    plan: Dict[str, Dict[str, object]] = {}

    for name in numeric_cols:
        column = result[name]
        fill = float(column.median()) if column.notna().any() else 0.0
        result[name] = column.fillna(fill)
        plan[name] = {"type": "numeric", "strategy": "median", "value": fill}

    for name in categorical_cols:
        column = result[name]
        modes = column.mode(dropna=True)
        fill = "missing" if modes.empty else modes.iloc[0]
        result[name] = column.fillna(fill)
        plan[name] = {"type": "categorical", "strategy": "most_frequent", "value": fill}

    for name in datetime_cols:
        stamps = pd.to_datetime(result[name], errors='coerce')
        if not stamps.notna().any():
            # Nothing to derive a median from; record that no fill happened.
            plan[name] = {"type": "datetime", "strategy": "none", "value": None}
            continue
        middle = stamps.dropna().median()
        result[name] = stamps.fillna(middle)
        plan[name] = {"type": "datetime", "strategy": "median", "value": str(middle)}

    return result, plan

# Run the imputation and keep the per-column plan for reporting.
clean_df, impute_plan = impute_dataframe(df, numeric_cols, categorical_cols, datetime_cols)

display_heading("Imputation Plan Summary (per column)")
plan_json = json.dumps(impute_plan, indent=2)
print(plan_json[:4000])  # truncate if long


In [None]:
# Missingness Visuals (Plotly)
# Comment: Plot missing counts and a simple heatmap-like visualization.

# Plot the cleaned but NOT yet imputed frame: `clean_df` has already been
# filled, so its missingness charts would be empty by construction.
missing_counts = df.isna().sum().sort_values(ascending=False)
missing_df = missing_counts.reset_index()
missing_df.columns = ["column", "missing_count"]

fig_mc = px.bar(missing_df, x="column", y="missing_count", title="Missing Values per Column",
                labels={"missing_count": "Missing Count", "column": "Column"})
fig_mc.update_layout(xaxis_tickangle=-45, height=450)
fig_mc.show()

# Heatmap-like missingness: show a sample to avoid huge rendering
sample_n = min(500, len(df))
ms_sample = df.sample(sample_n, random_state=int(CONFIG["seed"]))
miss_bool = ms_sample.isna().astype(int)
# Transposed so that columns run along the y-axis and sampled rows along the x-axis.
fig_mh = px.imshow(miss_bool.T, aspect="auto", color_continuous_scale=[(0.0, "#1a1a1a"), (1.0, "#e74c3c")],
                   title="Missingness Heatmap (1=Missing) - Transposed for readability")
fig_mh.update_yaxes(title="Columns")
fig_mh.update_xaxes(title="Sample Rows")
fig_mh.show()


In [None]:
# Numeric EDA (histogram with violin marginal, box, correlations)
# Comment: Iterate over numeric columns to plot histograms and boxplots; then correlations.

# Only plot numeric columns that are still present after cleaning.
num_cols = [c for c in numeric_cols if c in clean_df.columns]

for col in num_cols:
    fig_h = px.histogram(clean_df, x=col, nbins=int(CONFIG["n_bins_hist"]), marginal="violin",
                         title=f"Histogram + Violin for {col}", opacity=0.85)
    fig_h.update_traces(marker_color="#2ecc71")
    fig_h.update_layout(bargap=0.05)
    fig_h.show()

    fig_b = px.box(clean_df, y=col, points="outliers", title=f"Boxplot (Outliers) for {col}")
    fig_b.update_traces(marker_color="#e67e22")
    fig_b.show()

# Correlations (Pearson & Spearman); a correlation matrix needs at least two columns.
if len(num_cols) >= 2:
    pearson_corr = clean_df[num_cols].corr(method='pearson')
    spearman_corr = clean_df[num_cols].corr(method='spearman')

    fig_cp = px.imshow(pearson_corr, text_auto=False, color_continuous_scale="RdBu_r",
                       title="Pearson Correlation (Numeric Features)")
    fig_cp.update_xaxes(side="bottom")
    fig_cp.show()

    fig_cs = px.imshow(spearman_corr, text_auto=False, color_continuous_scale="RdBu_r",
                       title="Spearman Correlation (Numeric Features)")
    fig_cs.update_xaxes(side="bottom")
    fig_cs.show()
else:
    print("Not enough numeric columns for correlations.")


In [None]:
# Numpy Utilities Demo: Z-score Outlier Rates and Winsorization Preview
# Comment: Use vectorized numpy ops to flag outliers and preview winsorization effects.

if len(num_cols) > 0:
    Xn = clean_df[num_cols].to_numpy(dtype=float)
    # Share of values per feature lying more than 3 standard deviations from the mean.
    Z = np_zscore(Xn)
    outlier_mask = np.abs(Z) > 3.0
    outlier_rates = outlier_mask.mean(axis=0)
    outlier_series = pd.Series(outlier_rates, index=num_cols).sort_values(ascending=False)

    # head(25) picks the largest rates; the ascending re-sort puts the largest bar on top.
    fig_or = px.bar(outlier_series.head(25).sort_values(), orientation="h",
                    title="Outlier Rate (>3σ) by Numeric Feature (Top 25)",
                    labels={"value": "Outlier Rate", "index": "Feature"})
    fig_or.show()

    # Winsorization preview (not used downstream, just diagnostic)
    Xw = np_winsorize(Xn, 0.01, 0.99)
    pre_std = np.nanstd(Xn, axis=0)
    post_std = np.nanstd(Xw, axis=0)
    std_drop = pd.Series(pre_std - post_std, index=num_cols)
    # Rank by reduction first so that the chart really shows the top 25,
    # not merely the first 25 columns.
    top_std_drop = std_drop.sort_values(ascending=False).head(25)
    fig_w = px.bar(top_std_drop.sort_values(), orientation="h",
                   title="Std Reduction after Winsorization (Top 25)",
                   labels={"value": "Std Reduction", "index": "Feature"})
    fig_w.show()
else:
    print("No numeric columns available for numpy utilities demo.")


In [None]:
# Categorical & Relationships EDA
# Comment: Top categories, stacked bars vs target (if exists), scatter and violin.

# Local to this cell: used only to probe for the optional statsmodels package.
import importlib.util

target_col = CONFIG["target"]
cat_cols = [c for c in categorical_cols if c in clean_df.columns]

# Top categories bar charts (20 most frequent values per column)
for col in cat_cols:
    vc = clean_df[col].value_counts(dropna=False).reset_index()
    vc.columns = [col, "count"]
    fig_bar = px.bar(vc.head(20), x=col, y="count", title=f"Top Categories for {col}")
    fig_bar.update_layout(xaxis_tickangle=-45)
    fig_bar.show()

# Stacked bars vs target (classification-like view)
if isinstance(target_col, str) and target_col in clean_df.columns:
    # Also shown for low-cardinality targets even when the problem type is not set.
    if CONFIG["problem_type"] == "classification" or clean_df[target_col].nunique() <= 20:
        for col in cat_cols:
            if col == target_col:
                continue
            cross = clean_df.groupby([col, target_col]).size().reset_index(name="count")
            fig_stack = px.bar(cross, x=col, y="count", color=target_col, barmode="stack",
                               title=f"Stacked Bar: {col} vs {target_col}")
            fig_stack.update_layout(xaxis_tickangle=-45)
            fig_stack.show()

# Relationships: scatter with trendline (for numeric pairs)
if len(num_cols) >= 2:
    # Plotly computes OLS trendlines with statsmodels. Fall back to a plain
    # scatter instead of aborting the cell when that package is absent.
    trendline_kind = "ols" if importlib.util.find_spec("statsmodels") is not None else None
    if trendline_kind is None:
        print("statsmodels not installed; drawing scatter plots without OLS trendline.")
    x_col = num_cols[0]
    for y_col in num_cols[1:3]:
        fig_sc = px.scatter(clean_df, x=x_col, y=y_col, trendline=trendline_kind,
                            title=f"Scatter with Trendline: {x_col} vs {y_col}")
        fig_sc.show()

# Violin/strip for target vs key features
if isinstance(target_col, str) and target_col in clean_df.columns:
    if target_col in num_cols:
        # numeric target vs categorical features: violin by category
        for col in cat_cols[:3]:
            fig_vi = px.violin(clean_df, x=col, y=target_col, box=True, points="all",
                               title=f"Violin: {target_col} by {col}")
            fig_vi.show()
    else:
        # categorical target vs numeric features
        for col in num_cols[:5]:
            fig_st = px.strip(clean_df, x=target_col, y=col, title=f"Strip: {col} by {target_col}")
            fig_st.show()


In [None]:
# Preprocessing for Modeling
# Comment: Encode categoricals, scale numerics, build train/val/test splits.
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

# Resolve target and problem type
TARGET = CONFIG["target"]
PROBLEM = CONFIG["problem_type"]

if isinstance(TARGET, str) and TARGET in clean_df.columns:
    y = clean_df[TARGET]
    X = clean_df.drop(columns=[TARGET])

    # Re-infer types on X only
    X_num, X_cat, X_dt = infer_types(X, int(CONFIG["max_cat_cardinality"]))

    # Raw datetimes cannot be converted into the float matrix the MLP needs.
    if X_dt:
        print("Dropping datetime columns before modeling:", X_dt)
        X = X.drop(columns=X_dt)

    # One-hot encoding choice: Simpler for MLP and avoids ordinal assumptions.
    # dtype=float keeps the frame purely numeric; recent pandas defaults to
    # bool dummies, which turn `.values` into an object array.
    X_enc = pd.get_dummies(X, columns=X_cat, dummy_na=False, drop_first=False, dtype=float)

    # Identify numeric columns post-encoding (the dummy columns stay 0/1).
    enc_num_cols = [c for c in X_enc.columns if c in X_num]
    if enc_num_cols:
        X_enc[enc_num_cols] = X_enc[enc_num_cols].astype(float)

    # Convert y for classification/regression
    if PROBLEM == "classification":
        # Map labels to 0..K-1 deterministically. This is done for numeric
        # labels as well, because CrossEntropyLoss needs contiguous indices.
        if not pd.api.types.is_numeric_dtype(y):
            classes = sorted(y.dropna().astype(str).unique())
            class_to_idx = {c: i for i, c in enumerate(classes)}
            y_enc = y.astype(str).map(class_to_idx)
        else:
            classes = sorted(y.dropna().unique())
            class_to_idx = {c: i for i, c in enumerate(classes)}
            y_enc = y.map(class_to_idx)
    elif PROBLEM == "regression":
        y_enc = pd.to_numeric(y, errors='coerce')
    else:
        y_enc = y  # unspecified; will skip modeling later if not supported

    # Train/Val/Test split (stratified if classification)
    test_size = float(CONFIG["test_size"])
    val_size = float(CONFIG["val_size"])

    stratify_vec = y_enc if PROBLEM == "classification" else None

    X_temp, X_test, y_temp, y_test = train_test_split(
        X_enc, y_enc, test_size=test_size, random_state=int(CONFIG["seed"]), stratify=stratify_vec)

    # Adjust val split relative to remaining
    val_relative = val_size / (1 - test_size)
    stratify_vec_temp = y_temp if PROBLEM == "classification" else None
    X_train, X_val, y_train, y_val = train_test_split(
        X_temp, y_temp, test_size=val_relative, random_state=int(CONFIG["seed"]), stratify=stratify_vec_temp)

    # Scale numeric columns to stabilize training. The scaler is fitted on the
    # training split only, so validation/test statistics do not leak into it.
    # `scaler` is created only when there is something to scale.
    X_train, X_val, X_test = X_train.copy(), X_val.copy(), X_test.copy()
    if enc_num_cols:
        scaler = StandardScaler()
        X_train[enc_num_cols] = scaler.fit_transform(X_train[enc_num_cols])
        X_val[enc_num_cols] = scaler.transform(X_val[enc_num_cols])
        X_test[enc_num_cols] = scaler.transform(X_test[enc_num_cols])

    display_heading("Split Shapes")
    print("X_train:", X_train.shape, "X_val:", X_val.shape, "X_test:", X_test.shape)
    if PROBLEM == "classification":
        def show_balance(name, y_series):
            """Print the relative class frequencies of *y_series*."""
            vc = y_series.value_counts(normalize=True).sort_index()
            print(name, ":", {int(k): round(float(v), 3) for k, v in vc.items()})
        show_balance("Train class balance", y_train)
        show_balance("Val class balance", y_val)
        show_balance("Test class balance", y_test)
else:
    print("No valid target provided; skipping modeling steps. Set CONFIG['target'] and 'problem_type'.")


In [None]:
# PyTorch Dataset & Model
# Comment: Tabular MLP baseline suitable for classification or regression.

if isinstance(TARGET, str) and TARGET in clean_df.columns and PROBLEM in ("classification", "regression"):

    class TabularDataset(Dataset):
        """In-memory dataset pairing an encoded feature frame with its target."""

        def __init__(self, X: pd.DataFrame, y: pd.Series):
            # Force a float32 matrix: a frame mixing float and bool columns
            # would otherwise yield an object array torch.tensor cannot convert.
            self.X = torch.tensor(X.to_numpy(dtype=np.float32), dtype=torch.float32)
            # y dtype depends on problem: class indices vs. a (n, 1) float column
            if PROBLEM == "classification":
                self.y = torch.tensor(y.to_numpy(), dtype=torch.long)
            else:
                self.y = torch.tensor(y.to_numpy(), dtype=torch.float32).unsqueeze(1)

        def __len__(self):
            return self.X.shape[0]

        def __getitem__(self, idx):
            return self.X[idx], self.y[idx]

    input_dim = X_train.shape[1]

    class MLP(nn.Module):
        """Two hidden layers (128 and 64 units) with batch norm, ReLU and dropout."""

        def __init__(self, input_dim: int, output_dim: int):
            super().__init__()
            hidden = 128
            self.net = nn.Sequential(
                nn.Linear(input_dim, hidden),
                nn.BatchNorm1d(hidden),
                nn.ReLU(),
                nn.Dropout(0.2),
                nn.Linear(hidden, 64),
                nn.BatchNorm1d(64),
                nn.ReLU(),
                nn.Dropout(0.2),
                nn.Linear(64, output_dim),
            )

        def forward(self, x):
            return self.net(x)

    if PROBLEM == "classification":
        # Size the output layer to cover the largest label index as well, so
        # non-contiguous integer labels cannot index past the logits.
        y_train_labels = pd.Series(y_train)
        num_classes = max(int(y_train_labels.nunique()), int(y_train_labels.max()) + 1)
        model = MLP(input_dim, num_classes)
        criterion = nn.CrossEntropyLoss()
    else:
        model = MLP(input_dim, 1)
        criterion = nn.MSELoss()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    train_ds = TabularDataset(X_train, y_train)
    val_ds = TabularDataset(X_val, y_val)
    test_ds = TabularDataset(X_test, y_test)

    train_batch_size = 64
    # BatchNorm1d cannot normalize a batch of one sample in training mode, so
    # drop the trailing batch when it would contain exactly one row.
    drop_trailing = len(train_ds) > train_batch_size and len(train_ds) % train_batch_size == 1
    train_loader = DataLoader(train_ds, batch_size=train_batch_size, shuffle=True, drop_last=drop_trailing)
    val_loader = DataLoader(val_ds, batch_size=256, shuffle=False)
    test_loader = DataLoader(test_ds, batch_size=256, shuffle=False)

    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)

    # Training loop with early stopping
    best_val = float('inf')
    best_state = None
    patience = 10
    wait = 0
    num_epochs = 100

    history = {"epoch": [], "train_loss": [], "val_loss": []}

    for epoch in range(1, num_epochs + 1):
        model.train()
        batch_losses = []
        for xb, yb in train_loader:
            xb, yb = xb.to(device), yb.to(device)
            optimizer.zero_grad()
            out = model(xb)
            # The same criterion call serves both problem types.
            loss = criterion(out, yb)
            loss.backward()
            optimizer.step()
            batch_losses.append(loss.item())
        train_loss = float(np.mean(batch_losses))

        # Validation
        model.eval()
        val_losses = []
        with torch.no_grad():
            for xb, yb in val_loader:
                xb, yb = xb.to(device), yb.to(device)
                out = model(xb)
                loss = criterion(out, yb)
                val_losses.append(loss.item())
        val_loss = float(np.mean(val_losses))

        history["epoch"].append(epoch)
        history["train_loss"].append(train_loss)
        history["val_loss"].append(val_loss)

        print(f"Epoch {epoch:03d} | train_loss={train_loss:.4f} | val_loss={val_loss:.4f}")

        # Only improvements larger than 1e-6 reset the patience counter.
        if val_loss < best_val - 1e-6:
            best_val = val_loss
            # Snapshot on CPU so the best weights do not alias the live parameters.
            best_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                print("Early stopping")
                break

    # Restore best
    if best_state is not None:
        model.load_state_dict(best_state)

    # Save artifacts
    ARTIFACTS = {
        "scaler_mean": scaler.mean_.tolist() if 'scaler' in locals() else None,
        "scaler_scale": scaler.scale_.tolist() if 'scaler' in locals() else None,
        "columns": X_enc.columns.tolist(),
        "problem": PROBLEM,
        "target": TARGET,
    }
    torch.save(model.state_dict(), "model_state_dict.pt")
    with open("artifacts.json", "w") as f:
        json.dump(ARTIFACTS, f, indent=2)
else:
    print("Modeling skipped due to missing/invalid target or problem type.")


In [None]:
# Evaluation & Diagnostics
# Comment: Compute metrics and plot learning curve; confusion matrix or residuals.
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score, confusion_matrix, classification_report
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score

if isinstance(TARGET, str) and TARGET in clean_df.columns and PROBLEM in ("classification", "regression"):
    # Learning curve
    hist_df = pd.DataFrame(history)
    fig_lc = px.line(hist_df, x="epoch", y=["train_loss", "val_loss"],
                     title="Learning Curve (Loss vs Epoch)")
    fig_lc.update_layout(yaxis_title="Loss")
    fig_lc.show()

    # Collect predictions on the held-out test split
    model.eval()
    all_preds = []
    all_probs = []
    all_true = []
    with torch.no_grad():
        for xb, yb in test_loader:
            xb = xb.to(device)
            out = model(xb)
            if PROBLEM == "classification":
                probs = torch.softmax(out, dim=1).cpu().numpy()
                preds = probs.argmax(axis=1)
                all_probs.append(probs)
                all_preds.append(preds)
                all_true.append(yb.numpy())
            else:
                # reshape(-1) stays 1-D even for a one-row batch; squeeze()
                # would give a 0-d array that np.concatenate rejects.
                preds = out.cpu().numpy().reshape(-1)
                all_preds.append(preds)
                all_true.append(yb.numpy().reshape(-1))

    y_true = np.concatenate(all_true)
    y_pred = np.concatenate(all_preds)

    if PROBLEM == "classification":
        acc = accuracy_score(y_true, y_pred)
        f1 = f1_score(y_true, y_pred, average="weighted")
        try:
            # binary or multiclass OneVsRest AUROC; the class count is read
            # from the width of the probability matrix the model produced.
            y_prob_all = np.concatenate(all_probs)
            if y_prob_all.shape[1] == 2:
                y_prob = y_prob_all[:, 1]
                auroc = roc_auc_score(y_true, y_prob)
            else:
                y_prob = y_prob_all
                auroc = roc_auc_score(pd.get_dummies(y_true), y_prob, average="weighted", multi_class="ovr")
        except Exception as exc:
            # AUROC is best-effort (e.g. a class missing from the test split);
            # report the reason instead of hiding it.
            print(f"AUROC unavailable: {exc}")
            auroc = np.nan

        print({"accuracy": round(acc, 4), "f1_weighted": round(f1, 4), "auroc": None if np.isnan(auroc) else round(float(auroc), 4)})

        # Confusion matrix
        cm = confusion_matrix(y_true, y_pred)
        fig_cm = px.imshow(cm, text_auto=True, color_continuous_scale="Blues",
                           title="Confusion Matrix", labels=dict(x="Predicted", y="True", color="Count"))
        fig_cm.update_xaxes(title="Predicted")
        fig_cm.update_yaxes(title="True")
        fig_cm.show()
    else:
        mae = mean_absolute_error(y_true, y_pred)
        # RMSE via sqrt(MSE): the `squared=False` shortcut is deprecated in
        # recent scikit-learn releases and later removed.
        rmse = float(np.sqrt(mean_squared_error(y_true, y_pred)))
        r2 = r2_score(y_true, y_pred)
        print({"MAE": round(mae, 4), "RMSE": round(rmse, 4), "R2": round(r2, 4)})

        # Residual diagnostics
        residuals = y_true - y_pred
        fig_res_sc = px.scatter(x=y_pred, y=residuals, labels={"x": "Predicted", "y": "Residual"},
                                title="Residuals vs Predicted")
        fig_res_sc.add_hline(y=0, line_dash="dash")
        fig_res_sc.show()

        fig_res_hist = px.histogram(residuals, nbins=40, title="Residuals Histogram")
        fig_res_hist.show()


In [None]:
# Explainability: Permutation Importance
# Comment: Compute permutation importance on validation set using numpy.

def permutation_importance(model: nn.Module, X: pd.DataFrame, y: pd.Series, metric_fn, n_repeats: int = 5) -> pd.Series:
    """Estimate feature importance as the mean metric drop after shuffling a column.

    Args:
        model: Trained network; it is switched to eval mode.
        X: Feature frame; its columns label the returned series.
        y: Target values aligned with *X*.
        metric_fn: Callable ``metric_fn(model, X_array, y_array)`` returning a
            score where higher is better.
        n_repeats: Number of independent shuffles per feature (at least 1).

    Returns:
        Mean drop in the metric per feature, sorted in descending order.

    Raises:
        ValueError: If *n_repeats* is smaller than 1.
    """
    if n_repeats < 1:
        raise ValueError(f"n_repeats must be >= 1, got {n_repeats}")
    model.eval()
    X_np = X.values.astype(np.float32)
    y_np = y.values
    baseline = metric_fn(model, X_np, y_np)
    rng = np.random.default_rng(int(CONFIG["seed"]))
    importances = np.zeros(X_np.shape[1], dtype=float)

    # One working copy is reused for all features; only the active column is
    # reset and reshuffled, instead of copying the full matrix per repeat.
    X_perm = X_np.copy()
    for j in range(X_np.shape[1]):
        drops = []
        for _ in range(n_repeats):
            X_perm[:, j] = X_np[:, j]
            rng.shuffle(X_perm[:, j])
            drops.append(baseline - metric_fn(model, X_perm, y_np))  # drop in metric
        # Restore the column before moving on to the next feature.
        X_perm[:, j] = X_np[:, j]
        importances[j] = float(np.mean(drops))
    return pd.Series(importances, index=X.columns).sort_values(ascending=False)

# Define metric compatible with classification/regression
if isinstance(TARGET, str) and TARGET in clean_df.columns and PROBLEM in ("classification", "regression"):
    def metric_fn(model, Xb, yb):
        """Score *model* on (Xb, yb): accuracy for classification, negative RMSE otherwise."""
        with torch.no_grad():
            xb = torch.tensor(Xb, dtype=torch.float32, device=device)
            out = model(xb)
            if PROBLEM == "classification":
                preds = out.softmax(dim=1).argmax(dim=1).cpu().numpy()
                return accuracy_score(yb, preds)
            else:
                # reshape(-1) stays 1-D even when only one row is scored.
                preds = out.cpu().numpy().reshape(-1)
                # sqrt(MSE) replaces the deprecated `squared=False` shortcut.
                rmse = float(np.sqrt(mean_squared_error(np.asarray(yb).reshape(-1), preds)))
                return -rmse  # negative RMSE (higher is better)

    imp_series = permutation_importance(model, X_val, y_val, metric_fn, n_repeats=5)
    fig_imp = px.bar(imp_series.head(25).sort_values(), orientation="h",
                     title="Permutation Importance (Top 25)", labels={"value": "Importance", "index": "Feature"})
    fig_imp.show()
else:
    print("Explainability skipped due to missing/invalid target or problem type.")


## Summary & Next Steps

- Data quality: Reviewed missingness and applied column-wise imputation strategies (median, mode, median timestamp).
- Feature types: Inferred numeric, categorical, and datetime automatically; normalized string categories.
- EDA: Produced 8+ interactive visuals (histograms, box plots, correlations, bars, stacked bars, scatter with trendline, violin/strip, missingness heatmap).
- Encoding & scaling: Used one-hot for categoricals and standardization for numerics to support MLP stability.
- Baseline model: Trained a small MLP with early stopping; logged learning curves and metrics.
- Metrics: Printed appropriate metrics based on problem type (classification vs regression) and diagnostics.
- Explainability: Computed permutation importance on validation; displayed top features.
- Reproducibility: Seeds set; versions printed; artifacts saved (`model_state_dict.pt`, `artifacts.json`).

Recommendations:
- Validate target definition and problem type; consider class imbalance handling (class weights, resampling) if skewed.
- Engineer domain features (e.g., aggregations, ratios); consider interaction terms.
- Try stronger tabular models (scikit-learn Gradient Boosting, XGBoost, CatBoost) and compare them against the MLP baseline.
- Tune the MLP hyperparameters (layers, hidden size, dropout, LR schedule) and use k-fold CV.
- Calibrate probabilities for classification (Platt/Isotonic) if needed.
- Address high-cardinality categoricals with target or leave-one-out encoding if present.

