# ScreenMind — Milestone 5: Experiment Tracking with Weights & Biases

**Goal:** Train 3 model variants, log every run to W&B, and compare them on a live dashboard.

**Why W&B?**  
In Milestone 4 we trained one model and printed results to the screen.  That works for one run,
but as soon as you want to compare architectures or tune hyperparameters you need a systematic record.
W&B stores every metric, every hyperparameter, and every artifact from every run in a searchable
cloud dashboard — so you can see at a glance which config won and why.

**The 3 variants we will compare:**

| Variant | Hidden layers | Regression loss | Key idea |
|---------|---------------|-----------------|----------|
| A (Baseline) | 128 → 64 | MSE | Exactly what we trained in M4 |
| B (Wider) | 256 → 128 → 64 | MSE | Wider and one layer deeper: more parameters → can it learn richer patterns? |
| C (Weighted loss) | 128 → 64 | WeightedMSE | Penalise errors on high-day cases 2× harder |

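Each variant trains **both** a classifier and a regressor, so all 6 runs appear in the same W&B project. The weighted loss only applies to the regressor: Variant C's classifier run uses the same architecture and loss as Variant A's.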
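For reference, here is a plain-Python sketch of a weighted MSE of the kind Variant C describes. It is not the `WeightedMSELoss` in `src/training/trainer.py`: the 14-day threshold and dividing by the number of samples (rather than by the sum of the weights) are assumptions, since this notebook only says that high-day errors are penalised 2×.

```python
def weighted_mse(y_true, y_pred, threshold=14.0, high_weight=2.0):
    """Weighted mean squared error for a days-count target.

    Each squared error is multiplied by `high_weight` when the true value is at
    or above `threshold` (a hypothetical cut-off), otherwise by 1. The weighted
    sum is divided by the number of samples, not by the sum of the weights.
    """
    if len(y_true) != len(y_pred):
        raise ValueError("y_true and y_pred must have the same length")
    if not y_true:
        raise ValueError("need at least one sample")
    total = 0.0
    for y, y_hat in zip(y_true, y_pred):
        weight = high_weight if y >= threshold else 1.0
        total += weight * (y - y_hat) ** 2
    return total / len(y_true)


def mse(y_true, y_pred):
    """Plain MSE, i.e. weighted_mse with every weight equal to 1."""
    return weighted_mse(y_true, y_pred, high_weight=1.0)


# One low-day respondent (0 days) and one high-day respondent (20 days)
print(mse([0, 20], [1, 18]))           # (1 + 4) / 2 = 2.5
print(weighted_mse([0, 20], [1, 18]))  # (1 + 2 * 4) / 2 = 4.5
```

Under this normalisation the weighted loss is never smaller than plain MSE on the same predictions, so Variant C's loss values are not on the same scale as A's and B's.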

## 1. Imports & Setup


In [17]:
import sys
sys.path.insert(0, '..')

import os
import numpy as np
import torch
import wandb

from src.data.preprocessing import load_processed
from src.models.mlp import MLP
from src.training.trainer import (
    make_loaders, make_criterion, train,
    evaluate_clf, evaluate_reg,
)

PROCESSED_DIR = '../data/processed'
MODELS_DIR    = '../data/models'
os.makedirs(MODELS_DIR, exist_ok=True)

data = load_processed(PROCESSED_DIR)
print('Data loaded.')
print(f'Device: {"cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"}')

Data loaded.
Device: cpu


## 2. How W&B Works

The W&B workflow has four steps:

```
wandb.init(project=..., config=...)   # 1. Open a new run, store hyperparams
    wandb.run.log({...})              # 2. Stream metrics each epoch
    wandb.run.summary.update({...})   # 3. Record final test metrics
wandb.finish()                        # 4. Close the run
```

**`config`** is a plain dict of all hyperparameters (architecture, lr, loss, etc.).
W&B stores it alongside the metrics so you can filter runs by config in the dashboard.

**`log`** is called once per epoch and streams numbers to the cloud in real time.
This is what draws the live loss curves on the dashboard.

**`summary`** is for final / aggregate metrics (test AUC, MAE, etc.) — values that
only exist after training is complete.

Our `train()` function already accepts an optional `wandb_run=` argument, so we simply pass it the run object returned by `wandb.init()`.
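The internals of `train()` are not shown in this notebook. As a sketch of the pattern, assuming it logs `epoch`, `train_loss` and `val_loss` once per epoch (the keys that appear in the run histories below) and stops after `patience` epochs without a new best validation loss, an optional-run training loop could look like this. `train_epochs`, `step` and `ListRun` are hypothetical stand-ins; `ListRun` only mimics a run's `.log(dict)` call so the sketch runs without W&B installed.

```python
from typing import Callable, Optional


class ListRun:
    """Hypothetical stand-in for a W&B run: records every dict passed to .log()."""

    def __init__(self):
        self.logged = []

    def log(self, data: dict) -> None:
        self.logged.append(dict(data))


def train_epochs(
    step: Callable[[int], tuple],
    max_epochs: int,
    patience: int,
    wandb_run: Optional[object] = None,
) -> dict:
    """Call step(epoch) -> (train_loss, val_loss) with patience-based early stopping.

    When wandb_run is None nothing is logged, so the same loop runs with or without W&B.
    """
    history = {"train_loss": [], "val_loss": []}
    best_val, best_epoch = float("inf"), 0
    for epoch in range(1, max_epochs + 1):
        train_loss, val_loss = step(epoch)
        history["train_loss"].append(train_loss)
        history["val_loss"].append(val_loss)
        if wandb_run is not None:
            wandb_run.log({"epoch": epoch, "train_loss": train_loss, "val_loss": val_loss})
        if val_loss < best_val:
            # New best: a real training loop would save the checkpoint here
            best_val, best_epoch = val_loss, epoch
        elif epoch - best_epoch >= patience:
            break  # `patience` epochs in a row without a new best
    return history


# Made-up validation curve: improves for 3 epochs, then gets worse
val_curve = [5.0, 4.0, 3.0, 3.5, 3.6, 3.7, 3.8, 3.9]
run = ListRun()
history = train_epochs(
    lambda epoch: (val_curve[epoch - 1] + 0.1, val_curve[epoch - 1]),
    max_epochs=len(val_curve),
    patience=3,
    wandb_run=run,
)
print(len(history["val_loss"]), len(run.logged))  # 6 6
```

Passing `wandb_run=None` runs the same loop with no logging, which is how the bonus section calls `train()`.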

## 3. Define the 3 Variant Configs

We express each variant as a plain Python dict.  This dict becomes the W&B config
and is also used to construct the model and criterion — one source of truth.

In [18]:
VARIANTS = [
    {
        "name":        "variant_a_baseline",
        "hidden_dims": [128, 64],
        "dropout":     0.3,
        "lr":          1e-3,
        "reg_loss":    "reg",          # plain MSELoss
        "batch_size":  512,
        "max_epochs":  100,
        "patience":    10,
        "description": "Baseline from Milestone 4 — replicated for fair comparison",
    },
    {
        "name":        "variant_b_wider",
        "hidden_dims": [256, 128, 64],  # extra layer, more parameters
        "dropout":     0.3,
        "lr":          1e-3,
        "reg_loss":    "reg",
        "batch_size":  512,
        "max_epochs":  100,
        "patience":    10,
        "description": "Wider MLP: 3 hidden layers instead of 2",
    },
    {
        "name":        "variant_c_weighted_loss",
        "hidden_dims": [128, 64],
        "dropout":     0.3,
        "lr":          1e-3,
        "reg_loss":    "reg_weighted",  # WeightedMSELoss
        "batch_size":  512,
        "max_epochs":  100,
        "patience":    10,
        "description": "Same architecture as A but WeightedMSE penalises high-day errors 2x",
    },
]

for v in VARIANTS:
    n_params = MLP(input_dim=15, hidden_dims=v['hidden_dims'], dropout=v['dropout']).count_parameters()
    print(f"{v['name']:<30}  hidden={v['hidden_dims']}  params={n_params:,}  reg_loss={v['reg_loss']}")

variant_a_baseline              hidden=[128, 64]  params=10,753  reg_loss=reg
variant_b_wider                 hidden=[256, 128, 64]  params=46,209  reg_loss=reg
variant_c_weighted_loss         hidden=[128, 64]  params=10,753  reg_loss=reg_weighted
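The printed counts can be checked by hand. The sketch below assumes each hidden layer is a `Linear` layer followed by a `BatchNorm1d` (two learnable parameters per unit), ending in a single output unit. That layout is inferred from the numbers, not read from `src/models/mlp.py`; without the BatchNorm term the baseline would have 10,369 parameters instead of 10,753. `mlp_param_count` is an illustrative helper.

```python
def mlp_param_count(input_dim, hidden_dims, out_dim=1, batchnorm=True):
    """Count learnable parameters of a Linear(+BatchNorm1d) stack with one output head.

    Assumed layout per hidden layer: Linear(prev, h) -> BatchNorm1d(h) -> activation -> dropout.
    Activations and dropout have no learnable parameters.
    """
    total = 0
    prev = input_dim
    for h in hidden_dims:
        total += prev * h + h      # Linear weights + biases
        if batchnorm:
            total += 2 * h         # BatchNorm1d affine weight (gamma) + bias (beta)
        prev = h
    total += prev * out_dim + out_dim  # output Linear layer
    return total


for dims in ([128, 64], [256, 128, 64]):
    print(dims, f"{mlp_param_count(15, dims):,}")
# [128, 64] 10,753
# [256, 128, 64] 46,209
```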


## 4. W&B Login

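Run this cell once to authenticate.  
It prompts you to paste your API key (free account at wandb.ai; the key is listed at wandb.ai/authorize).

After the first login, W&B caches your key locally, so you won't need to do this again. On a machine without an interactive prompt, you can set the `WANDB_API_KEY` environment variable instead.

In [10]:
wandb.login()  # prompts for an API key the first time; later calls reuse the cached key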

True

## 5. Run the Experiments

The loop below:
1. Opens a W&B run for each (variant, task) combination
2. Trains the model, streaming per-epoch losses to W&B
3. Evaluates on the test set and writes final metrics to the W&B summary
4. Closes the run

**6 runs total** (3 variants × 2 tasks: clf + reg).  
While it trains, open your [W&B project dashboard](https://wandb.ai) to watch the curves live.

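⏱️ On CPU this takes roughly 15-25 minutes in total. Each run trains for at most 100 epochs and usually stops earlier through early stopping (patience 10).

If each run ends with `wandb: ERROR The nbformat package was not found`, install `nbformat` (`pip install nbformat`). The error only stops W&B from saving the notebook's code history; the metrics are still logged.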
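Before launching, it can help to dry-run the plan without training anything. This sketch reproduces the loop's run names and its criterion rule (`'clf'` for every classifier, the variant's `reg_loss` for every regressor); `plan_runs` is a hypothetical helper, and the configs are trimmed to the two keys it reads.

```python
from itertools import product


def plan_runs(variants, tasks=("clf", "reg")):
    """Return (run_name, criterion_task) pairs in the same order as the training loop."""
    plan = []
    for cfg, task in product(variants, tasks):  # variants outer, tasks inner
        criterion_task = "clf" if task == "clf" else cfg["reg_loss"]
        plan.append((f"{cfg['name']}_{task}", criterion_task))
    return plan


variants = [
    {"name": "variant_a_baseline", "reg_loss": "reg"},
    {"name": "variant_b_wider", "reg_loss": "reg"},
    {"name": "variant_c_weighted_loss", "reg_loss": "reg_weighted"},
]
for run_name, criterion_task in plan_runs(variants):
    print(f"{run_name:<30} criterion={criterion_task}")
```

The plan makes one detail explicit: `variant_a_baseline_clf` and `variant_c_weighted_loss_clf` use the same criterion, and both variants use `[128, 64]` hidden layers, so those two classifier runs should differ only through random initialisation and batch order.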

In [11]:
all_results = []   # collect test metrics across all runs for the comparison table

for cfg in VARIANTS:
    for task in ('clf', 'reg'):

        run_name = f"{cfg['name']}_{task}"
        print(f"\n{'='*60}")
        print(f"Starting run: {run_name}")
        print(f"{'='*60}")

        # ── Open a W&B run ────────────────────────────────────────────────────
        # project: groups all runs together in one W&B project page
        # name:    human-readable label for this run in the dashboard
        # config:  stores all hyperparams so you can filter/sort by them later
        run = wandb.init(
            project="screenmind",
            name=run_name,
            config={**cfg, "task": task, "input_dim": 15},
            reinit=True,   # allow multiple inits in one notebook session
        )

        # ── Build model & loaders ─────────────────────────────────────────────
        model = MLP(
            input_dim=15,
            hidden_dims=cfg['hidden_dims'],
            dropout=cfg['dropout'],
            task=task,
        )

        train_loader, val_loader, test_loader = make_loaders(
            data, task=task, batch_size=cfg['batch_size']
        )

        # For the classifier always use 'clf'; for reg use whatever loss the variant specifies
        criterion_task = 'clf' if task == 'clf' else cfg['reg_loss']
        y_train = data['y_clf_train'] if task == 'clf' else data['y_reg_train']
        criterion = make_criterion(criterion_task, y_train)

        checkpoint_path = f'{MODELS_DIR}/{run_name}_best.pt'

        # ── Train (losses stream to W&B each epoch via wandb_run=run) ─────────
        history = train(
            model=model,
            train_loader=train_loader,
            val_loader=val_loader,
            criterion=criterion,
            checkpoint_path=checkpoint_path,
            lr=cfg['lr'],
            max_epochs=cfg['max_epochs'],
            patience=cfg['patience'],
            wandb_run=run,
        )

        # ── Evaluate on test set ──────────────────────────────────────────────
        model.load_state_dict(torch.load(checkpoint_path, weights_only=True))

        if task == 'clf':
            metrics = evaluate_clf(model, test_loader)
            print(f"  AUC={metrics['roc_auc']:.4f}  F1={metrics['f1']:.4f}  "
                  f"Recall={metrics['recall']:.4f}  Precision={metrics['precision']:.4f}")
        else:
            metrics = evaluate_reg(model, test_loader)
            print(f"  MAE={metrics['mae']:.3f}  RMSE={metrics['rmse']:.3f}  R²={metrics['r2']:.4f}")

        # ── Write final metrics to W&B summary (appears in the runs table) ────
        # summary values show up as columns in the W&B project table,
        # making it easy to sort runs by AUC or MAE at a glance.
        run.summary.update(metrics)
        run.summary["best_val_loss"] = min(history['val_loss'])
        run.summary["epochs_trained"] = len(history['val_loss'])

        all_results.append({"run": run_name, "task": task, **metrics})

        wandb.finish()

print("\n✓ All runs complete.")


Starting run: variant_a_baseline_clf


Training on cpu  |  max_epochs=100  |  patience=10
 Epoch    Train Loss      Val Loss    Best
--------------------------------------------
     1       0.86401       0.83318   ✓
     2       0.84638       0.83062   ✓
     3       0.84331       0.82943   ✓
     4       0.84115       0.82890   ✓
     5       0.83989       0.82988  
     6       0.83952       0.82975  
     7       0.83844       0.82843   ✓
     8       0.83818       0.82820   ✓
     9       0.83768       0.82789   ✓
    10       0.83697       0.82834  
    11       0.83658       0.82775   ✓
    12       0.83666       0.82723   ✓
    13       0.83663       0.82771  
    14       0.83622       0.82792  
    15       0.83512       0.82793  
    16       0.83560       0.82777  
    17       0.83500       0.82744  
    18       0.83566       0.82763  
    19       0.83502       0.82745  
    20       0.83434       0.82711   ✓
    21       0.83505       0.82705   ✓
    22       0.83483       0.82708  
    23       0.83509     

[34m[1mwandb[0m: [32m[41mERROR[0m The nbformat package was not found. It is required to save notebook history.


  AUC=0.8561  F1=0.4875  Recall=0.7746  Precision=0.3557


Run history:
  epoch           ▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇██
  train_loss      █▄▄▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▂▂▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁
  val_loss        █▆▄▄▅▃▃▃▃▃▃▂▃▂▂▂▂▂▂▂▂▂▂▂▂▂▁▂▂▂▂▁▂▁▁▂▂▃▂▂

Run summary:
  accuracy        0.78456
  best_val_loss   0.82595
  epoch           52
  epochs_trained  52
  f1              0.48753
  n_samples       67428
  precision       0.35571
  recall          0.77458
  roc_auc         0.85613
  train_loss      0.83058



Starting run: variant_a_baseline_reg


Training on cpu  |  max_epochs=100  |  patience=10
 Epoch    Train Loss      Val Loss    Best
--------------------------------------------
     1      50.49163      46.05582   ✓
     2      47.18585      45.90093   ✓
     3      46.97400      45.79371   ✓
     4      46.88458      45.69613   ✓
     5      46.74306      45.71967  
     6      46.69943      45.66703   ✓
     7      46.66086      45.61792   ✓
     8      46.58563      45.71859  
     9      46.60715      45.64407  
    10      46.56724      45.62900  
    11      46.52922      45.60190   ✓
    12      46.50346      45.59019   ✓
    13      46.52180      45.63809  
    14      46.48891      45.59873  
    15      46.46805      45.59681  
    16      46.42696      45.57326   ✓
    17      46.44442      45.70117  
    18      46.45649      45.60152  
    19      46.41776      45.56942   ✓
    20      46.38597      45.54929   ✓
    21      46.36916      45.58523  
    22      46.36635      45.57831  
    23      46.34396     

[34m[1mwandb[0m: [32m[41mERROR[0m The nbformat package was not found. It is required to save notebook history.


  MAE=4.155  RMSE=6.720  R²=0.3521


Run history:
  epoch           ▁▁▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▄▄▄▅▅▅▅▆▆▆▆▆▇▇▇▇▇█
  train_loss      ██▇▆▆▆▅▅▄▄▄▄▄▄▄▃▂▄▂▂▃▃▃▂▃▂▁▂▂▂▂▁▂▁▂▁▁▁▁▂
  val_loss        █▆▅▄▃▃▃▂▂▂▂▂▂▁▁▃▂▂▁▂▁▁▂▂▁▁▂▁▁▁▁▂▁▂▁▁▁▁▁▁

Run summary:
  best_val_loss   45.4903
  epoch           78.0
  epochs_trained  78.0
  mae             4.15545
  n_samples       67428.0
  r2              0.35212
  rmse            6.72001
  train_loss      46.17944
  val_loss        45.65292



Starting run: variant_b_wider_clf


Training on cpu  |  max_epochs=100  |  patience=10
 Epoch    Train Loss      Val Loss    Best
--------------------------------------------
     1       0.85767       0.83103   ✓
     2       0.84537       0.82989   ✓
     3       0.84214       0.82959   ✓
     4       0.83985       0.82853   ✓
     5       0.83958       0.82785   ✓
     6       0.83845       0.82787  
     7       0.83804       0.82904  
     8       0.83739       0.82805  
     9       0.83629       0.82694   ✓
    10       0.83650       0.82683   ✓
    11       0.83500       0.82682   ✓
    12       0.83588       0.82716  
    13       0.83519       0.82617   ✓
    14       0.83438       0.82652  
    15       0.83419       0.82776  
    16       0.83378       0.82612   ✓
    17       0.83370       0.82674  
    18       0.83356       0.82620  
    19       0.83285       0.82698  
    20       0.83223       0.82538   ✓
    21       0.83268       0.82612  
    22       0.83267       0.82566  
    23       0.83208     

[34m[1mwandb[0m: [32m[41mERROR[0m The nbformat package was not found. It is required to save notebook history.


  AUC=0.8564  F1=0.4847  Recall=0.7778  Precision=0.3520


Run history:
  epoch           ▁▁▁▂▂▂▂▃▃▃▃▄▄▄▄▅▅▅▅▆▆▆▆▇▇▇▇███
  train_loss      █▅▄▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁
  val_loss        █▇▆▅▄▄▆▄▃▃▃▃▂▂▄▂▃▂▃▁▂▁▂▂▃▂▂▂▂▂

Run summary:
  accuracy        0.78119
  best_val_loss   0.82538
  epoch           30
  epochs_trained  30
  f1              0.4847
  n_samples       67428
  precision       0.35204
  recall          0.77783
  roc_auc         0.85637
  train_loss      0.83003



Starting run: variant_b_wider_reg


Training on cpu  |  max_epochs=100  |  patience=10
 Epoch    Train Loss      Val Loss    Best
--------------------------------------------
     1      49.89957      45.75487   ✓
     2      47.17452      45.72248   ✓
     3      47.11304      45.67145   ✓
     4      46.91575      45.64142   ✓
     5      46.82398      45.68896  
     6      46.72905      45.58446   ✓
     7      46.66932      45.54983   ✓
     8      46.64963      45.58336  
     9      46.64902      45.57214  
    10      46.56154      45.56769  
    11      46.58170      45.52601   ✓
    12      46.46164      45.55773  
    13      46.44419      45.52750  
    14      46.44137      45.53166  
    15      46.36834      45.53254  
    16      46.30252      45.50596   ✓
    17      46.35646      45.53694  
    18      46.38856      45.53119  
    19      46.36033      45.52502  
    20      46.25116      45.56150  
    21      46.31244      45.53346  
    22      46.24528      45.53128  
    23      46.16934      45.51

[34m[1mwandb[0m: [32m[41mERROR[0m The nbformat package was not found. It is required to save notebook history.


  MAE=4.187  RMSE=6.725  R²=0.3512


Run history:
  epoch           ▁▁▂▂▂▂▃▃▃▄▄▄▄▅▅▅▅▆▆▆▇▇▇▇██
  train_loss      █▃▃▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁
  val_loss        █▇▆▅▆▃▂▃▃▃▂▂▂▂▂▁▂▂▂▃▂▂▁▂▂▂

Run summary:
  best_val_loss   45.50596
  epoch           26.0
  epochs_trained  26.0
  mae             4.1865
  n_samples       67428.0
  r2              0.35123
  rmse            6.72463
  train_loss      46.18983
  val_loss        45.55672



Starting run: variant_c_weighted_loss_clf


Training on cpu  |  max_epochs=100  |  patience=10
 Epoch    Train Loss      Val Loss    Best
--------------------------------------------
     1       0.86238       0.83131   ✓
     2       0.84539       0.82999   ✓
     3       0.84235       0.82879   ✓
     4       0.84071       0.82839   ✓
     5       0.83974       0.82901  
     6       0.83850       0.82818   ✓
     7       0.83880       0.82889  
     8       0.83728       0.82787   ✓
     9       0.83724       0.82801  
    10       0.83710       0.82812  
    11       0.83709       0.82745   ✓
    12       0.83605       0.82764  
    13       0.83562       0.82737   ✓
    14       0.83566       0.82759  
    15       0.83538       0.82670   ✓
    16       0.83485       0.82690  
    17       0.83542       0.82700  
    18       0.83545       0.82715  
    19       0.83490       0.82700  
    20       0.83492       0.82668   ✓
    21       0.83480       0.82727  
    22       0.83477       0.82717  
    23       0.83452       

[34m[1mwandb[0m: [32m[41mERROR[0m The nbformat package was not found. It is required to save notebook history.


  AUC=0.8563  F1=0.4864  Recall=0.7751  Precision=0.3544


Run history:
  epoch           ▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇▇██
  train_loss      █▄▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
  val_loss        █▆▅▄▅▄▅▄▄▄▃▃▃▂▃▃▃▃▂▃▃▃▃▂▃▂▁▂▃▂▁▁▂▂▁▁▁▂▁▁

Run summary:
  accuracy        0.78341
  best_val_loss   0.82568
  epoch           43
  epochs_trained  43
  f1              0.48639
  n_samples       67428
  precision       0.35438
  recall          0.77514
  roc_auc         0.8563
  train_loss      0.832



Starting run: variant_c_weighted_loss_reg


Training on cpu  |  max_epochs=100  |  patience=10
 Epoch    Train Loss      Val Loss    Best
--------------------------------------------
     1      83.80939      72.99685   ✓
     2      75.12024      72.79371   ✓
     3      74.86382      72.58828   ✓
     4      74.59381      72.54819   ✓
     5      74.45407      72.78791  
     6      74.30140      72.56605  
     7      74.15848      72.46526   ✓
     8      74.02201      72.47620  
     9      74.07983      72.27989   ✓
    10      74.10324      72.38026  
    11      73.94366      72.32950  
    12      73.82160      72.33438  
    13      73.91891      72.36322  
    14      73.90058      72.25301   ✓
    15      73.83213      72.30286  
    16      73.77368      72.25423  
    17      73.76725      72.22230   ✓
    18      73.82047      72.46912  
    19      73.75924      72.37707  
    20      73.74145      72.23087  
    21      73.71404      72.33050  
    22      73.68846      72.34884  
    23      73.65688      72.21

[34m[1mwandb[0m: [32m[41mERROR[0m The nbformat package was not found. It is required to save notebook history.


  MAE=4.675  RMSE=6.879  R²=0.3212


Run history:
  epoch           ▁▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
  train_loss      █▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
  val_loss        █▆▄▄▆▃▄▂▃▂▃▂▂▂▁▃▁▂▂▁▁▂▂▂▁▂▂▁▂▁▁▁▃▂▂▂▂▂▁▂

Run summary:
  best_val_loss   72.18144
  epoch           48.0
  epochs_trained  48.0
  mae             4.67487
  n_samples       67428.0
  r2              0.32119
  rmse            6.87853
  train_loss      73.40874
  val_loss        72.26317



✓ All runs complete.


## 6. Local Comparison Table

W&B gives you an interactive dashboard, but it's also useful to see the summary locally.

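**How to read the classifier table:**  
- Focus on **AUC** (overall discrimination) and **Recall** (did we catch high-risk patients?)  
- F1 will always look modest at a ~13% positive rate: high recall on a rare class brings many false positives, which keeps precision (and therefore F1) low, so compare F1 across runs rather than reading it in absolute terms  
- Variants A and C train identical classifiers (the weighted loss only changes the regressor), so the gap between their AUCs is a rough gauge of run-to-run noise

**How to read the regressor table:**  
- Focus on **MAE** (interpretable in days) and **R²** (how much variance explained)  
- Variant C (WeightedMSE) may trade overall MAE for better accuracy on high-day cases; the table only reports overall metrics, so it cannot confirm that trade-off on its own  
- Compare variants on test metrics, not validation loss: C's loss includes the extra weights, so its validation loss (about 72) is on a different scale from A's and B's (about 45.5)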

In [19]:
import pandas as pd

clf_rows = [r for r in all_results if r['task'] == 'clf']
reg_rows = [r for r in all_results if r['task'] == 'reg']

print('=== CLASSIFIER RESULTS ===')
clf_df = pd.DataFrame(clf_rows)[['run', 'roc_auc', 'f1', 'recall', 'precision', 'accuracy']]
clf_df.columns = ['Run', 'AUC', 'F1', 'Recall', 'Precision', 'Accuracy']
clf_df = clf_df.sort_values('AUC', ascending=False).reset_index(drop=True)
print(clf_df.to_string(index=False, float_format='{:.4f}'.format))

print()
print('=== REGRESSOR RESULTS ===')
reg_df = pd.DataFrame(reg_rows)[['run', 'mae', 'rmse', 'r2']]
reg_df.columns = ['Run', 'MAE', 'RMSE', 'R²']
reg_df = reg_df.sort_values('MAE').reset_index(drop=True)
print(reg_df.to_string(index=False, float_format='{:.4f}'.format))

print()
best_clf = clf_df.iloc[0]
best_reg = reg_df.iloc[0]
print(f"Best classifier: {best_clf['Run']}  (AUC={best_clf['AUC']:.4f})")
print(f"Best regressor:  {best_reg['Run']}  (MAE={best_reg['MAE']:.3f})")

=== CLASSIFIER RESULTS ===
                        Run    AUC     F1  Recall  Precision  Accuracy
        variant_b_wider_clf 0.8564 0.4847  0.7778     0.3520    0.7812
variant_c_weighted_loss_clf 0.8563 0.4864  0.7751     0.3544    0.7834
     variant_a_baseline_clf 0.8561 0.4875  0.7746     0.3557    0.7846

=== REGRESSOR RESULTS ===
                        Run    MAE   RMSE     R²
     variant_a_baseline_reg 4.1554 6.7200 0.3521
        variant_b_wider_reg 4.1865 6.7246 0.3512
variant_c_weighted_loss_reg 4.6749 6.8785 0.3212

Best classifier: variant_b_wider_clf  (AUC=0.8564)
Best regressor:  variant_a_baseline_reg  (MAE=4.155)
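The three classifier AUCs agree to three decimal places. One way to check whether gaps this small are more than noise is a paired bootstrap: resample test rows with replacement, score both models on the same resample, and look at the spread of the AUC difference. The sketch assumes you have per-row test probabilities from two models (`evaluate_clf`, as used here, returns only aggregate metrics) and 0/1 labels. `roc_auc` and `paired_bootstrap_auc_diff` are illustrative helpers written in plain Python for clarity; on ~67k test rows you would vectorise them or use scikit-learn's `roc_auc_score`.

```python
import random


def roc_auc(labels, scores):
    """ROC AUC via the rank-sum (Mann-Whitney U) formula; labels are 0/1, tied scores get average ranks."""
    pairs = sorted(zip(scores, labels), key=lambda p: p[0])
    n = len(pairs)
    ranks = [0.0] * n
    i = 0
    while i < n:
        j = i
        while j + 1 < n and pairs[j + 1][0] == pairs[i][0]:
            j += 1
        for k in range(i, j + 1):
            ranks[k] = (i + j) / 2 + 1  # average of the 1-based ranks i+1 .. j+1
        i = j + 1
    n_pos = sum(label for _, label in pairs)
    n_neg = n - n_pos
    if n_pos == 0 or n_neg == 0:
        raise ValueError("AUC needs at least one positive and one negative")
    rank_sum_pos = sum(r for r, (_, label) in zip(ranks, pairs) if label == 1)
    return (rank_sum_pos - n_pos * (n_pos + 1) / 2) / (n_pos * n_neg)


def paired_bootstrap_auc_diff(labels, scores_a, scores_b, n_boot=1000, seed=0):
    """Approximate 95% percentile interval for AUC(b) - AUC(a), resampling the same rows for both."""
    rng = random.Random(seed)
    n = len(labels)
    diffs = []
    for _ in range(n_boot):
        idx = [rng.randrange(n) for _ in range(n)]
        y = [labels[i] for i in idx]
        if 0 < sum(y) < n:  # skip resamples that contain only one class
            diffs.append(roc_auc(y, [scores_b[i] for i in idx])
                         - roc_auc(y, [scores_a[i] for i in idx]))
    diffs.sort()
    lo = diffs[int(0.025 * len(diffs))]
    hi = diffs[min(len(diffs) - 1, int(0.975 * len(diffs)))]
    return lo, hi


print(roc_auc([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8]))  # 0.75
```

If the interval for B minus A comfortably includes zero, the table's choice of the wider network as "best classifier" is not something to build on.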


## 7. Milestone 5 Summary

**What we did:**
- Added `WeightedMSELoss` to `src/training/trainer.py` — a custom loss that gives 2× penalty to high-day errors
- Wired W&B into the `train()` function via the optional `wandb_run=` parameter
- Trained 6 runs (3 variants × 2 tasks) and logged every epoch's loss to W&B
- Compared final test metrics in the W&B dashboard and locally

**What the W&B dashboard gives you:**
- Live loss curves for every run (find it at wandb.ai → project 'screenmind')
- A sortable table of all runs with their hyperparameters and final metrics
- Easy filtering: e.g. show only runs with `roc_auc` > 0.85 (summary values become columns in the runs table)

**Next: Milestone 6 — RAG Knowledge Base**  
We will embed scientific papers into ChromaDB and build a retriever that fetches relevant context
given a patient's risk factors. This becomes the knowledge source for the LangGraph agent.

## Bonus: Regression on the "Affected" Subpopulation (≥1 bad day)

### The zero-inflation problem

60% of BRFSS respondents report **0** bad mental health days.  When the
regression model trains on the full dataset, the loss is dominated by this
zero-inflated majority and the model learns to predict low values for almost everyone.

A different, arguably more actionable question is:

> **Among people who already have *some* mental health burden, how severe is it?**

We answer this by filtering to rows where `y_reg > 0` and retraining.

**What changes:**
- The task shifts from "predict burden for any US adult" to "predict severity among affected adults"
- Training set shrinks from ~315k → ~126k rows
- Mean target rises from ~4.4 days to ~11 days
- R² and MAE should be interpreted relative to this subgroup, not the full population
- This is a legitimate modelling choice — many clinical tools focus on *severity among the affected*
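The subset statistics printed below are linked by a simple identity for a non-negative target: mean(y) = P(y > 0) × mean(y | y > 0). With the training split, 126,006 / 314,659 ≈ 0.400 and 0.400 × 11.04 ≈ 4.42 days, which matches the full-population mean. The sketch checks the identity on a small made-up sample; `zero_inflation_decomposition` is an illustrative helper.

```python
from statistics import mean


def zero_inflation_decomposition(y):
    """Return (P(y > 0), mean of y where y > 0) for a non-negative target."""
    affected = [v for v in y if v > 0]
    share_affected = len(affected) / len(y)
    mean_affected = mean(affected) if affected else 0.0
    return share_affected, mean_affected


# Made-up MENTHLTH-style sample: 6 respondents with 0 bad days, 4 with some
sample = [0, 0, 0, 0, 0, 0, 2, 5, 14, 30]
share, m_aff = zero_inflation_decomposition(sample)
print(share, m_aff)                 # 0.4 12.75
print(share * m_aff, mean(sample))  # both about 5.1
```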

In [20]:
import os
os.makedirs(MODELS_DIR, exist_ok=True)

# ── 1. Filter each split to rows where y_reg > 0 ──────────────────────────────
mask_train = data['y_reg_train'] > 0
mask_val   = data['y_reg_val']   > 0
mask_test  = data['y_reg_test']  > 0

X_train_aff = data['X_train'][mask_train]
y_train_aff = data['y_reg_train'][mask_train]
X_val_aff   = data['X_val'][mask_val]
y_val_aff   = data['y_reg_val'][mask_val]
X_test_aff  = data['X_test'][mask_test]
y_test_aff  = data['y_reg_test'][mask_test]

print('Subset sizes (y_reg > 0):')
print(f'  Full train:     {len(data["X_train"]):>8,}  →  Affected: {len(X_train_aff):>7,}  ({len(X_train_aff)/len(data["X_train"])*100:.1f}%)')
print(f'  Full val:       {len(data["X_val"]):>8,}  →  Affected: {len(X_val_aff):>7,}  ({len(X_val_aff)/len(data["X_val"])*100:.1f}%)')
print(f'  Full test:      {len(data["X_test"]):>8,}  →  Affected: {len(X_test_aff):>7,}  ({len(X_test_aff)/len(data["X_test"])*100:.1f}%)')
print()
print(f'  Mean MENTHLTH (full population):    {data["y_reg_train"].mean():.2f} days')
print(f'  Mean MENTHLTH (affected subset):    {y_train_aff.mean():.2f} days')
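# Note: this uses the *test-set* mean, so it is a slightly optimistic baseline; for MAE
# the best constant is the median, and a train-set statistic is the stricter reference.
print(f'  Baseline MAE on subset (predict mean): {abs(y_test_aff - y_test_aff.mean()).mean():.3f}')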

# ── 2. Build DataLoaders manually ─────────────────────────────────────────────
from torch.utils.data import DataLoader
from src.training.trainer import BRFSSDataset

train_loader_aff = DataLoader(BRFSSDataset(X_train_aff, y_train_aff), batch_size=512, shuffle=True,  num_workers=0)
val_loader_aff   = DataLoader(BRFSSDataset(X_val_aff,   y_val_aff),   batch_size=512, shuffle=False, num_workers=0)
test_loader_aff  = DataLoader(BRFSSDataset(X_test_aff,  y_test_aff),  batch_size=512, shuffle=False, num_workers=0)

# ── 3. Train ──────────────────────────────────────────────────────────────────
model_aff = MLP(input_dim=15, hidden_dims=[128, 64], dropout=0.3, task='reg')
criterion_aff = make_criterion('reg', y_train_aff)

history_aff = train(
    model=model_aff,
    train_loader=train_loader_aff,
    val_loader=val_loader_aff,
    criterion=criterion_aff,
    checkpoint_path=f'{MODELS_DIR}/reg_affected_best.pt',
    lr=1e-3,
    max_epochs=100,
    patience=10,
)

# ── 4. Evaluate ───────────────────────────────────────────────────────────────
model_aff.load_state_dict(torch.load(f'{MODELS_DIR}/reg_affected_best.pt', weights_only=True))
metrics_aff = evaluate_reg(model_aff, test_loader_aff)

# Pull baseline reg metrics from the all_results list built by the loop in section 5
# (metrics_reg lives in 03_model.ipynb, not here)
baseline_reg = next(r for r in all_results if r['run'] == 'variant_a_baseline_reg')

print()
print('=== Affected Subset Regressor (y_reg > 0) ===')
print(f'  MAE:  {metrics_aff["mae"]:.3f} days')
print(f'  RMSE: {metrics_aff["rmse"]:.3f} days')
print(f'  R²:   {metrics_aff["r2"]:.4f}')
print()
print('=== Comparison ===')
print(f'  Full-population model  — MAE: {baseline_reg["mae"]:.3f}  R²: {baseline_reg["r2"]:.4f}')
print(f'  Affected-only model    — MAE: {metrics_aff["mae"]:.3f}  R²: {metrics_aff["r2"]:.4f}')
print()
print('Note: the two MAEs are NOT directly comparable — they measure error')
print('on different test sets (full vs affected subset) with different mean targets.')
print('R² is more informative: it measures how much variance within each group is explained.')

Subset sizes (y_reg > 0):
  Full train:      314,659  →  Affected: 126,006  (40.0%)
  Full val:         67,427  →  Affected:  26,830  (39.8%)
  Full test:        67,428  →  Affected:  26,769  (39.7%)

  Mean MENTHLTH (full population):    4.42 days
  Mean MENTHLTH (affected subset):    11.04 days
  Baseline MAE on subset (predict mean): 8.569
Training on cpu  |  max_epochs=100  |  patience=10
 Epoch    Train Loss      Val Loss    Best
--------------------------------------------
     1     132.09844      81.41433   ✓
     2      79.70052      75.26604   ✓
     3      78.28926      74.93330   ✓
     4      77.92303      74.82152   ✓
     5      77.72394      74.73298   ✓
     6      77.63907      74.76095  
     7      77.48473      74.60441   ✓
     8      77.38590      74.78923  
     9      77.32826      74.69868  
    10      77.25120      74.59886   ✓
    11      77.09087      74.69891  
    12      77.03165      74.56622   ✓
    13      77.06529      74.70201  
    14      76.8487

In [21]:
metrics_aff

{'mae': 6.804905891418457,
 'rmse': 8.617044068955904,
 'r2': 0.26925361156463623,
 'n_samples': 26769}

In [22]:
baseline_reg

{'run': 'variant_a_baseline_reg',
 'task': 'reg',
 'mae': 4.155448913574219,
 'rmse': 6.720008625751591,
 'r2': 0.35211753845214844,
 'n_samples': 67428}
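The bonus baseline ("predict mean") is a constant prediction. For MAE specifically, the constant that minimises error is the median rather than the mean, so a median baseline is the stricter yardstick when the target is skewed. Below is a small sketch with made-up numbers; on the real data you would take the constant from `y_train_aff` and score it on `y_test_aff`. `constant_mae` and `mae_skill` are illustrative helpers, not part of `src/`. Using the reported figures, the affected-only model's MAE of 6.805 days is roughly 21% below the 8.569-day predict-the-mean baseline.

```python
from statistics import mean, median


def constant_mae(y_true, constant):
    """MAE of predicting the same constant for every row."""
    return sum(abs(y - constant) for y in y_true) / len(y_true)


def mae_skill(model_mae, baseline_mae):
    """Relative MAE improvement over a baseline: 0.0 = no better, 1.0 = perfect."""
    return 1 - model_mae / baseline_mae


# Made-up skewed sample of bad-day counts (not BRFSS data)
y = [1, 1, 2, 3, 30]
print("predict mean:  ", constant_mae(y, mean(y)))
print("predict median:", constant_mae(y, median(y)))  # 6.2, the lower of the two
print("skill vs predict-mean baseline:", round(mae_skill(6.805, 8.569), 3))
```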