In [None]:
# # Trainer — AI-Powered Digital Epidemiologist (Modular & Professional)

# **What this notebook does**
# 1. Generate a synthetic vitals dataset (~ configurable size; default 210k rows)
# 2. Compute and save dataset statistics (JSON) for poster use
# 3. Create and save visualizations (histogram, boxplot, pie, line, scatter)
# 4. Preprocess (scaling) and save `scaler.pkl`
# 5. Train a classification model (PyTorch if available; else sklearn RandomForest)
#    - Training includes logging, checkpointing, optional early stopping and best-model saving
# 6. Save final artifacts to `model/artifacts/` and plots to `plots/`

# **How to run**
# - Copy cells in order and run them.
# - Tune hyperparameters in *Cell 2 (Configuration)* before running later cells.
# - If you want a full training on the entire dataset, set `TRAIN_ON_FULL_DATA = True` in cell 2 (but expect longer runtime).



In [5]:
# Cell 1 — Debug CWD (run first)
from pathlib import Path
cwd = Path.cwd()
print("Current working directory:", cwd)
print("Contents:", [p.name for p in cwd.iterdir()])


Current working directory: c:\Users\tanma\OneDrive\Desktop\Machine-Learning\model
Contents: ['artifacts', 'train.py', 'trainer.ipynb']


In [6]:
# Cell 2 — Configuration & robust artifact/plots detection
from pathlib import Path

cwd = Path.cwd()

# Robust detection for artifacts folder:
# If running inside model/ and artifacts exists: use cwd/artifacts
# Else if repo-root has model/artifacts: use that
if (cwd / "artifacts").exists():
    ARTIFACT_DIR = cwd / "artifacts"
elif (cwd / "model" / "artifacts").exists():
    ARTIFACT_DIR = cwd / "model" / "artifacts"
else:
    # default: create artifacts next to notebook
    ARTIFACT_DIR = cwd / "artifacts"

PLOTS_DIR = cwd / "plots"

ARTIFACT_DIR.mkdir(parents=True, exist_ok=True)
PLOTS_DIR.mkdir(parents=True, exist_ok=True)

# Dataset & training configuration (tweak these as needed)
DATASET_SIZE = 210_000       # >= 200,000 as requested (default 210k)
SEED = 42

TRAIN_ON_FULL_DATA = False   # If True, trains on whole DATASET_SIZE (can be slow)
SAMPLE_SIZE = 60_000         # sample size used for training if not training on full data
VAL_SPLIT = 0.15             # validation fraction

USE_PYTORCH = True           # try PyTorch first, fallback to sklearn if not available
BATCH_SIZE = 1024
EPOCHS = 10
LEARNING_RATE = 0.001
HIDDEN_UNITS = 64

EARLY_STOPPING = True
EARLY_STOPPING_PATIENCE = 3
CHECKPOINT_ON_IMPROVEMENT = True

SAVE_SUMMARY_JSON = True

print("ARTIFACT_DIR:", ARTIFACT_DIR.resolve())
print("PLOTS_DIR:   ", PLOTS_DIR.resolve())
print("DATASET_SIZE:", DATASET_SIZE)
print("TRAIN_ON_FULL_DATA:", TRAIN_ON_FULL_DATA, "| SAMPLE_SIZE:", SAMPLE_SIZE)


ARTIFACT_DIR: C:\Users\tanma\OneDrive\Desktop\Machine-Learning\model\artifacts
PLOTS_DIR:    C:\Users\tanma\OneDrive\Desktop\Machine-Learning\model\plots
DATASET_SIZE: 210000
TRAIN_ON_FULL_DATA: False | SAMPLE_SIZE: 60000


In [7]:
# Cell 3 — Imports & utilities
import os, json, time, math, random
from pathlib import Path
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import joblib
from sklearn.metrics import accuracy_score

# optional: use tqdm if available
try:
    from tqdm.auto import tqdm
except Exception:
    tqdm = lambda x, **kw: x

# Logger
import logging
logger = logging.getLogger("trainer")
if not logger.handlers:
    ch = logging.StreamHandler()
    ch.setLevel(logging.INFO)
    fmt = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s", "%H:%M:%S")
    ch.setFormatter(fmt)
    logger.addHandler(ch)
logger.setLevel(logging.INFO)

# reproducibility
def set_seed(seed: int = 42):
    random.seed(seed)
    np.random.seed(seed)
    try:
        import torch
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)
    except Exception:
        pass

# device detection
def get_device():
    try:
        import torch
        return torch.device("cuda" if torch.cuda.is_available() else "cpu")
    except Exception:
        return "cpu"

set_seed(SEED)
DEVICE = get_device()
logger.info(f"Device set to: {DEVICE}")


15:15:47 - INFO - Device set to: cpu


In [8]:
# Cell 4 — Generate synthetic dataset (or load if exists)
csv_path = ARTIFACT_DIR / "synthetic_epidemic_vitals.csv"

if csv_path.exists():
    logger.info("Found existing CSV - loading: %s", csv_path)
    df = pd.read_csv(csv_path)
else:
    logger.info("Generating synthetic dataset with %d rows (seed=%d)", DATASET_SIZE, SEED)
    rng = np.random.default_rng(SEED)

    # core distributions
    hr = rng.normal(loc=75, scale=12, size=DATASET_SIZE)
    rr = rng.normal(loc=16, scale=3, size=DATASET_SIZE)
    spo2 = rng.normal(loc=96.5, scale=2.5, size=DATASET_SIZE)

    # clamp to plausible physiological bounds
    hr = np.clip(hr, 35, 200)
    rr = np.clip(rr, 6, 40)
    spo2 = np.clip(spo2, 70, 100)

    # deterministic rule-based labels (repeatable)
    labels = []
    for h, r_, s in zip(hr, rr, spo2):
        if (s < 92) or (h > 140) or (r_ > 25):
            labels.append(2)   # High Risk
        elif (92 <= s < 94) or (120 < h <= 140) or (20 < r_ <= 25):
            labels.append(1)   # At Risk
        else:
            labels.append(0)   # Normal

    df = pd.DataFrame({
        "HR": np.round(hr, 2),
        "RR": np.round(rr, 2),
        "SpO2": np.round(spo2, 2),
        "Label": labels
    })

    df.to_csv(csv_path, index=False)
    logger.info("Saved synthetic CSV to: %s", csv_path)

logger.info("Dataset shape: %s", df.shape)
df.head()


15:15:52 - INFO - Generating synthetic dataset with 210000 rows (seed=42)
15:15:53 - INFO - Saved synthetic CSV to: c:\Users\tanma\OneDrive\Desktop\Machine-Learning\model\artifacts\synthetic_epidemic_vitals.csv
15:15:53 - INFO - Dataset shape: (210000, 4)


Unnamed: 0,HR,RR,SpO2,Label
0,78.66,18.97,93.5,1
1,62.52,20.12,92.95,1
2,84.01,14.61,95.31,0
3,86.29,15.72,96.16,0
4,51.59,16.31,95.91,0


In [9]:
# Cell 5 — Compute poster-ready statistics and save JSON
def series_stats(s: pd.Series):
    s_rounded = s.round(1)
    mode_vals = s_rounded.mode().tolist()
    mode_val = float(mode_vals[0]) if len(mode_vals) > 0 else float("nan")
    return {
        "count": int(s.count()),
        "mean": float(s.mean()),
        "median": float(s.median()),
        "mode": mode_val,
        "min": float(s.min()),
        "max": float(s.max()),
        "range": float(s.max() - s.min()),
        "std": float(s.std())
    }

poster_stats = {
    "dataset_rows": int(df.shape[0]),
    "dataset_columns": int(df.shape[1]),
    "features": ["HR", "RR", "SpO2", "Label"],
    "stats": {
        "HR": series_stats(df["HR"]),
        "RR": series_stats(df["RR"]),
        "SpO2": series_stats(df["SpO2"])
    },
    "label_counts": df["Label"].value_counts().sort_index().to_dict()
}

poster_json_path = ARTIFACT_DIR / "poster_dataset_stats.json"
with open(poster_json_path, "w") as f:
    json.dump(poster_stats, f, indent=2)

logger.info("Saved poster-ready stats to %s", poster_json_path)
poster_stats


15:16:01 - INFO - Saved poster-ready stats to c:\Users\tanma\OneDrive\Desktop\Machine-Learning\model\artifacts\poster_dataset_stats.json


{'dataset_rows': 210000,
 'dataset_columns': 4,
 'features': ['HR', 'RR', 'SpO2', 'Label'],
 'stats': {'HR': {'count': 210000,
   'mean': 74.99011947619047,
   'median': 74.96,
   'mode': 72.6,
   'min': 35.0,
   'max': 135.09,
   'range': 100.09,
   'std': 12.021493378945301},
  'RR': {'count': 210000,
   'mean': 16.000955904761906,
   'median': 16.0,
   'mode': 16.2,
   'min': 6.0,
   'max': 30.02,
   'range': 24.02,
   'std': 2.994740888214876},
  'SpO2': {'count': 210000,
   'mean': 96.40052785714286,
   'median': 96.49,
   'mode': 100.0,
   'min': 85.22,
   'max': 100.0,
   'range': 14.780000000000001,
   'std': 2.3313834606617228}},
 'label_counts': {0: 160339, 1: 41788, 2: 7873}}

In [10]:
# Cell 6 — Generate & save visuals (matplotlib)
plot_paths = {}

# HR Histogram
plt.figure(figsize=(7,4))
plt.hist(df["HR"], bins=60)
plt.title("Heart Rate Distribution")
plt.xlabel("HR (bpm)")
plt.ylabel("Count")
p = PLOTS_DIR / "hr_histogram.png"
plt.tight_layout()
plt.savefig(p)
plt.close()
plot_paths["hr_histogram"] = str(p)

# RR Boxplot
plt.figure(figsize=(4,4))
plt.boxplot(df["RR"], vert=True)
plt.title("Respiratory Rate (Boxplot)")
plt.ylabel("RR (breaths/min)")
p = PLOTS_DIR / "rr_boxplot.png"
plt.tight_layout()
plt.savefig(p)
plt.close()
plot_paths["rr_boxplot"] = str(p)

# Label Pie Chart
plt.figure(figsize=(5,5))
counts = df["Label"].value_counts().sort_index()
plt.pie(counts, labels=["Normal (0)","At Risk (1)","High Risk (2)"], autopct="%1.1f%%", startangle=90)
plt.title("Label Proportions")
p = PLOTS_DIR / "label_proportions_pie.png"
plt.tight_layout()
plt.savefig(p)
plt.close()
plot_paths["label_pie"] = str(p)

# SpO2 line (first 2000 samples to show trend)
plt.figure(figsize=(8,3))
plt.plot(df["SpO2"].values[:2000])
plt.title("SpO2 Trend (First 2000 samples)")
plt.xlabel("Index")
plt.ylabel("SpO2 (%)")
p = PLOTS_DIR / "spo2_line.png"
plt.tight_layout()
plt.savefig(p)
plt.close()
plot_paths["spo2_line"] = str(p)

# HR vs RR Scatter (sampled)
sample_df = df.sample(n=min(5000, len(df)), random_state=SEED)
plt.figure(figsize=(6,4))
plt.scatter(sample_df["HR"], sample_df["RR"], s=6, alpha=0.6)
plt.title("HR vs RR (sampled)")
plt.xlabel("HR (bpm)")
plt.ylabel("RR (breaths/min)")
p = PLOTS_DIR / "hr_vs_rr_scatter.png"
plt.tight_layout()
plt.savefig(p)
plt.close()
plot_paths["hr_rr_scatter"] = str(p)

logger.info("Saved plots to %s", PLOTS_DIR.resolve())
plot_paths


15:16:11 - INFO - Saved plots to C:\Users\tanma\OneDrive\Desktop\Machine-Learning\model\plots


{'hr_histogram': 'c:\\Users\\tanma\\OneDrive\\Desktop\\Machine-Learning\\model\\plots\\hr_histogram.png',
 'rr_boxplot': 'c:\\Users\\tanma\\OneDrive\\Desktop\\Machine-Learning\\model\\plots\\rr_boxplot.png',
 'label_pie': 'c:\\Users\\tanma\\OneDrive\\Desktop\\Machine-Learning\\model\\plots\\label_proportions_pie.png',
 'spo2_line': 'c:\\Users\\tanma\\OneDrive\\Desktop\\Machine-Learning\\model\\plots\\spo2_line.png',
 'hr_rr_scatter': 'c:\\Users\\tanma\\OneDrive\\Desktop\\Machine-Learning\\model\\plots\\hr_vs_rr_scatter.png'}

In [11]:
# Cell 7 — Scale features, optionally sample for training, and split
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split

X_full = df[["HR","RR","SpO2"]].values
y_full = df["Label"].values

scaler = StandardScaler()
scaler.fit(X_full)
X_full_scaled = scaler.transform(X_full)

scaler_path = ARTIFACT_DIR / "scaler.pkl"
joblib.dump(scaler, scaler_path)
logger.info("Saved scaler to %s", scaler_path)

# Choose subset for training to keep runs fast by default
if TRAIN_ON_FULL_DATA:
    X_for_train = X_full_scaled
    y_for_train = y_full
    logger.info("Training will use full dataset: %d rows", X_for_train.shape[0])
else:
    rng = np.random.default_rng(SEED)
    idx = rng.choice(len(X_full_scaled), size=min(SAMPLE_SIZE, len(X_full_scaled)), replace=False)
    X_for_train = X_full_scaled[idx]
    y_for_train = y_full[idx]
    logger.info("Training will use subsample: %d rows", X_for_train.shape[0])

# Train/validation split
X_train, X_val, y_train, y_val = train_test_split(X_for_train, y_for_train, test_size=VAL_SPLIT, random_state=SEED, stratify=y_for_train if len(np.unique(y_for_train))>1 else None)
logger.info("Train/Val sizes: %d / %d", X_train.shape[0], X_val.shape[0])


15:16:40 - INFO - Saved scaler to c:\Users\tanma\OneDrive\Desktop\Machine-Learning\model\artifacts\scaler.pkl
15:16:40 - INFO - Training will use subsample: 60000 rows
15:16:40 - INFO - Train/Val sizes: 51000 / 9000


In [12]:
# Cell 8 — Model builders & sklearn fallback

def build_torch_model(input_dim=3, hidden=HIDDEN_UNITS, num_classes=3):
    import torch.nn as nn
    class SimpleNet(nn.Module):
        def __init__(self, input_dim=input_dim, hidden=hidden, num_classes=num_classes):
            super().__init__()
            self.net = nn.Sequential(
                nn.Linear(input_dim, hidden),
                nn.ReLU(),
                nn.Linear(hidden, max(8, hidden//2)),
                nn.ReLU(),
                nn.Linear(max(8, hidden//2), num_classes)
            )
        def forward(self, x):
            return self.net(x)
    return SimpleNet()

def train_sklearn_model(X_train, y_train, X_val, y_val, out_path=ARTIFACT_DIR / "model_sklearn.joblib"):
    from sklearn.ensemble import RandomForestClassifier
    clf = RandomForestClassifier(n_estimators=100, max_depth=12, n_jobs=-1, random_state=SEED)
    clf.fit(X_train, y_train)
    joblib.dump(clf, out_path)
    logger.info("Saved sklearn model to %s", out_path)
    if X_val is not None:
        y_pred = clf.predict(X_val)
        acc = accuracy_score(y_val, y_pred)
        logger.info("Validation accuracy (sklearn): %.4f", acc)
    return str(out_path)


In [13]:
# Cell 9 — PyTorch training loop (returns saved checkpoint path)
def train_torch(X_train, y_train, X_val, y_val, device=DEVICE, 
                epochs=EPOCHS, batch_size=BATCH_SIZE, lr=LEARNING_RATE, 
                checkpoint_path=ARTIFACT_DIR / "model.pt", patience=EARLY_STOPPING_PATIENCE):
    import torch
    from torch.utils.data import DataLoader, TensorDataset

    num_classes = len(np.unique(np.concatenate([y_train, y_val])))
    model = build_torch_model(input_dim=X_train.shape[1], hidden=HIDDEN_UNITS, num_classes=num_classes)
    model.to(device)
    torch.set_num_threads(1)

    X_tr = torch.tensor(X_train, dtype=torch.float32)
    y_tr = torch.tensor(y_train, dtype=torch.long)
    X_v = torch.tensor(X_val, dtype=torch.float32)
    y_v = torch.tensor(y_val, dtype=torch.long)

    train_ds = TensorDataset(X_tr, y_tr)
    val_ds = TensorDataset(X_v, y_v)

    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False)

    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    best_val_loss = float("inf")
    epochs_no_improve = 0
    best_state = None
    history = {"train_loss": [], "val_loss": [], "val_acc": []}

    for epoch in range(1, epochs+1):
        model.train()
        running_loss = 0.0
        n_samples = 0
        for xb, yb in tqdm(train_loader, desc=f"Epoch {epoch}/{epochs}", leave=False):
            xb = xb.to(device)
            yb = yb.to(device)
            optimizer.zero_grad()
            out = model(xb)
            loss = criterion(out, yb)
            loss.backward()
            optimizer.step()
            batch_n = xb.size(0)
            running_loss += loss.item() * batch_n
            n_samples += batch_n
        train_loss = running_loss / max(1, n_samples)

        # validation
        model.eval()
        val_loss_total = 0.0
        val_samples = 0
        all_preds = []
        with torch.no_grad():
            for xb, yb in val_loader:
                xb = xb.to(device)
                yb = yb.to(device)
                out = model(xb)
                loss = criterion(out, yb)
                val_loss_total += loss.item() * xb.size(0)
                val_samples += xb.size(0)
                preds = out.argmax(dim=1).cpu().numpy()
                all_preds.append(preds)
        val_loss = val_loss_total / max(1, val_samples)
        if len(all_preds) > 0:
            all_preds = np.concatenate(all_preds)
            val_acc = float((all_preds == y_val[:len(all_preds)]).mean())
        else:
            val_acc = 0.0

        history["train_loss"].append(train_loss)
        history["val_loss"].append(val_loss)
        history["val_acc"].append(val_acc)

        logger.info("Epoch %d/%d — train_loss: %.4f — val_loss: %.4f — val_acc: %.4f", epoch, epochs, train_loss, val_loss, val_acc)

        # checkpoint & early stopping
        if val_loss < best_val_loss - 1e-6:
            best_val_loss = val_loss
            epochs_no_improve = 0
            best_state = model.state_dict()
            if CHECKPOINT_ON_IMPROVEMENT:
                torch.save(best_state, checkpoint_path)
                logger.info("Checkpoint saved to %s", checkpoint_path)
        else:
            epochs_no_improve += 1
            if EARLY_STOPPING and epochs_no_improve >= patience:
                logger.info("Early stopping triggered (no improvement in %d epochs).", patience)
                break

    # Save best model
    if best_state is not None:
        torch.save(best_state, checkpoint_path)
        logger.info("Best model saved to %s", checkpoint_path)
    else:
        torch.save(model.state_dict(), checkpoint_path)
        logger.info("Final model saved to %s (no improvement observed)", checkpoint_path)

    # Save history
    hist_path = ARTIFACT_DIR / "training_history.json"
    with open(hist_path, "w") as f:
        json.dump(history, f, indent=2)
    logger.info("Training history saved to %s", hist_path)

    return str(checkpoint_path)


In [14]:
# Cell 10 — Orchestration & training run
model_artifact_path = None
training_start = time.time()

if USE_PYTORCH:
    try:
        import torch
        logger.info("Attempting to train using PyTorch (device: %s)", DEVICE)
        model_artifact_path = train_torch(X_train, y_train, X_val, y_val, device=DEVICE, epochs=EPOCHS, batch_size=BATCH_SIZE, lr=LEARNING_RATE)
    except Exception as e:
        logger.warning("PyTorch training failed: %s", str(e))
        USE_PYTORCH = False

if not USE_PYTORCH:
    logger.info("Training with sklearn fallback (RandomForest). This may be slower for large data but is robust.")
    model_artifact_path = train_sklearn_model(X_train, y_train, X_val, y_val, out_path=ARTIFACT_DIR / "model_sklearn.joblib")

training_end = time.time()
duration_sec = training_end - training_start
logger.info("Training finished in %.1f seconds", duration_sec)

# Save summary JSON
summary = {
    "dataset_csv": str((csv_path).resolve()),
    "poster_stats_json": str(poster_json_path.resolve()),
    "plots": plot_paths,
    "scaler": str(scaler_path.resolve()),
    "model_artifact": model_artifact_path,
    "training_seconds": duration_sec,
    "config": {
        "DATASET_SIZE": DATASET_SIZE,
        "TRAIN_ON_FULL_DATA": TRAIN_ON_FULL_DATA,
        "SAMPLE_SIZE": SAMPLE_SIZE,
        "EPOCHS": EPOCHS,
        "BATCH_SIZE": BATCH_SIZE,
        "LEARNING_RATE": LEARNING_RATE,
        "HIDDEN_UNITS": HIDDEN_UNITS
    }
}
if SAVE_SUMMARY_JSON:
    with open(ARTIFACT_DIR / "dataset_training_summary.json", "w") as f:
        json.dump(summary, f, indent=2)
    logger.info("Saved dataset and training summary to %s", ARTIFACT_DIR / "dataset_training_summary.json")

model_artifact_path


15:17:02 - INFO - Attempting to train using PyTorch (device: cpu)
15:17:04 - INFO - Epoch 1/10 — train_loss: 0.7697 — val_loss: 0.4929 — val_acc: 0.8059
15:17:04 - INFO - Checkpoint saved to c:\Users\tanma\OneDrive\Desktop\Machine-Learning\model\artifacts\model.pt
15:17:04 - INFO - Epoch 2/10 — train_loss: 0.3728 — val_loss: 0.2738 — val_acc: 0.9048
15:17:04 - INFO - Checkpoint saved to c:\Users\tanma\OneDrive\Desktop\Machine-Learning\model\artifacts\model.pt
15:17:05 - INFO - Epoch 3/10 — train_loss: 0.2163 — val_loss: 0.1693 — val_acc: 0.9637
15:17:05 - INFO - Checkpoint saved to c:\Users\tanma\OneDrive\Desktop\Machine-Learning\model\artifacts\model.pt
15:17:05 - INFO - Epoch 4/10 — train_loss: 0.1455 — val_loss: 0.1217 — val_acc: 0.9743
15:17:05 - INFO - Checkpoint saved to c:\Users\tanma\OneDrive\Desktop\Machine-Learning\model\artifacts\model.pt
15:17:06 - INFO - Epoch 5/10 — train_loss: 0.1097 — val_loss: 0.0946 — val_acc: 0.9830
15:17:06 - INFO - Checkpoint saved to c:\Users\tanm

'c:\\Users\\tanma\\OneDrive\\Desktop\\Machine-Learning\\model\\artifacts\\model.pt'

In [15]:
# Cell 11 — Verify artifacts & sample
print("Artifacts written to:", ARTIFACT_DIR.resolve())
for p in sorted(ARTIFACT_DIR.glob("*")):
    print("-", p.name)

print("\nPlots written to:", PLOTS_DIR.resolve())
for p in sorted(PLOTS_DIR.glob("*.png")):
    print("-", p.name)

print("\nDataset sample (5 rows):")
display(df.sample(5, random_state=SEED))


Artifacts written to: C:\Users\tanma\OneDrive\Desktop\Machine-Learning\model\artifacts
- dataset_training_summary.json
- model.pt
- poster_dataset_stats.json
- scaler.pkl
- synthetic_epidemic_vitals.csv
- training_history.json

Plots written to: C:\Users\tanma\OneDrive\Desktop\Machine-Learning\model\plots
- hr_histogram.png
- hr_vs_rr_scatter.png
- label_proportions_pie.png
- rr_boxplot.png
- spo2_line.png

Dataset sample (5 rows):


Unnamed: 0,HR,RR,SpO2,Label
194949,66.65,14.64,97.75,0
161875,69.09,11.93,94.34,0
61912,65.03,18.72,96.43,0
35966,65.38,15.16,97.16,0
143596,68.17,14.12,97.16,0


In [16]:
# Cell 12 — Optional retrain helper (call retrain() after modifying Cell 2 config)
def retrain():
    global df, X_full_scaled, X_train, X_val, y_train, y_val, scaler_path, model_artifact_path, plot_paths, poster_stats, poster_json_path
    # reload dataset
    if not (ARTIFACT_DIR / "synthetic_epidemic_vitals.csv").exists():
        raise FileNotFoundError("synthetic_epidemic_vitals.csv missing; run Cell 4 to generate dataset first.")
    df = pd.read_csv(ARTIFACT_DIR / "synthetic_epidememic_vitals.csv")
    # recompute stats/plots
    poster_stats = {
        "dataset_rows": int(df.shape[0]),
        "stats": {
            "HR": series_stats(df["HR"]),
            "RR": series_stats(df["RR"]),
            "SpO2": series_stats(df["SpO2"])
        },
        "label_counts": df["Label"].value_counts().sort_index().to_dict()
    }
    with open(poster_json_path, "w") as f:
        json.dump(poster_stats, f, indent=2)
    plot_paths = save_plots(df) if 'save_plots' in globals() else None
    # scale and split
    X_full = df[["HR","RR","SpO2"]].values
    scaler = StandardScaler()
    scaler.fit(X_full)
    X_full_scaled = scaler.transform(X_full)
    joblib.dump(scaler, scaler_path)
    if not TRAIN_ON_FULL_DATA:
        rng = np.random.default_rng(SEED)
        idx = rng.choice(len(X_full_scaled), size=min(SAMPLE_SIZE, len(X_full_scaled)), replace=False)
        X_for_train = X_full_scaled[idx]
        y_for_train = df["Label"].values[idx]
    else:
        X_for_train = X_full_scaled
        y_for_train = df["Label"].values
    from sklearn.model_selection import train_test_split
    X_train, X_val, y_train, y_val = train_test_split(X_for_train, y_for_train, test_size=VAL_SPLIT, random_state=SEED)
    # train
    if USE_PYTORCH:
        model_artifact_path = train_torch(X_train, y_train, X_val, y_val)
    else:
        model_artifact_path = train_sklearn_model(X_train, y_train, X_val, y_val)
    logger.info("Retrain saved model to %s", model_artifact_path)

# Note: this helper is optional; if you want to retrain with different config, modify Cell 2 and call retrain()
