# Week-03 — Baseline CNNs: Data Loading & Environment Setup

**Goal of this notebook.**
Establish the reproducible foundation for **Chapter 2** by:
- Setting up the Python / PyTorch environment (Apple Silicon **MPS** on M-series).
- Recording exact package versions (for reproducibility).
- Creating project folders for data and outputs.
- (Next cells) Loading three **MedMNIST** subsets (**ChestMNIST**, **PneumoniaMNIST**, **BreastMNIST**) to train and evaluate baseline CNNs (ResNet-18).
- Keeping this notebook **explainable**: every step is paired with narrative and figures.

**Why start here?**
Chapter 2 asks: *What is baseline CNN performance and what are its limitations under small, “thick” data conditions?*
A clean, seeded setup makes every metric we report (Accuracy, Sensitivity/Specificity, F1, ROC-AUC, PR-AUC, Calibration) as repeatable as the backend allows; some GPU kernels are not bit-for-bit deterministic, so small run-to-run differences can remain. The limitations we find here will **motivate Chapter 3** (few-shot with clinical heuristics) and **Chapter 4** (augmentation / ROI).
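
To make the threshold-dependent metrics above concrete, here is a minimal sketch that derives Accuracy, Sensitivity, Specificity, and F1 from binary confusion-matrix counts. The function name `binary_metrics` and the counts are hypothetical and chosen only for illustration; they are not results from this project. ROC-AUC, PR-AUC, and calibration need predicted scores rather than counts, so they are not covered here.

```python
def binary_metrics(tp, fp, tn, fn):
    """Threshold-dependent metrics from binary confusion-matrix counts."""
    def safe_div(num, den):
        # Avoid ZeroDivisionError when a class is absent from the split.
        return num / den if den else float("nan")

    return {
        "accuracy": safe_div(tp + tn, tp + fp + tn + fn),
        "sensitivity": safe_div(tp, tp + fn),   # recall on the positive class
        "specificity": safe_div(tn, tn + fp),   # recall on the negative class
        "f1": safe_div(2 * tp, 2 * tp + fp + fn),
    }


# Hypothetical counts, for illustration only.
metrics = binary_metrics(tp=80, fp=10, tn=90, fn=20)
for name, value in metrics.items():
    print(f"{name:>12}: {value:.4f}")
```

Reporting Sensitivity and Specificity alongside Accuracy matters on imbalanced medical datasets, where a high Accuracy can hide a poor detection rate on the minority class.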

**Deliverables from this notebook (over Week-03):**
- Verified environment (MPS enabled), package versions snapshot.
- Dataset loaders & sanity checks (sizes, class balance, sample grids).
- Baseline **ResNet-18** training/evaluation per dataset.
- Saved figures: confusion matrices, ROC/PR curves, 2–3 Grad-CAMs per dataset.
- A concise results table to copy into **Chapter 2** (proposal).
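
The class-balance sanity check listed above could look roughly like the sketch below. It uses toy labels rather than real MedMNIST data, and the helper names `class_balance` and `positive_rates` are hypothetical. It assumes single-label datasets provide one integer label per image, and that a multi-label dataset (we expect ChestMNIST to be one) provides one multi-hot row per image.

```python
from collections import Counter


def class_balance(labels):
    """Fraction of samples per class for a flat sequence of integer labels."""
    counts = Counter(labels)
    total = sum(counts.values())
    return {cls: counts[cls] / total for cls in sorted(counts)}


def positive_rates(multi_hot_rows):
    """Per-label positive rate for multi-label data (one multi-hot row per sample)."""
    n_samples = len(multi_hot_rows)
    n_labels = len(multi_hot_rows[0])
    return [sum(row[j] for row in multi_hot_rows) / n_samples for j in range(n_labels)]


# Toy labels, for illustration only (not real dataset statistics).
toy_binary = [0, 1, 1, 1, 0, 1, 1, 1]
toy_multi = [[1, 0, 0], [0, 0, 1], [1, 0, 0], [0, 0, 0]]

balance = class_balance(toy_binary)
rates = positive_rates(toy_multi)
print("single-label balance:", balance)
print("multi-label positive rates:", rates)
```

Running this per split (train/val/test) before training makes it easy to spot a split whose class mix differs sharply from the others.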

# Environment Setup (Apple Silicon MPS, Reproducibility, Paths)

In [5]:
import os
import sys
import platform
import random
from pathlib import Path
from datetime import datetime
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import torch
import importlib
import pandas as pd

In [6]:
# -------------------------
# 0) Notebook display tweaks
# -------------------------
# High-DPI figures if running in IPython
try:
    from IPython import get_ipython
    ip = get_ipython()
    if ip is not None:
        ip.run_line_magic("config", "InlineBackend.figure_format = 'retina'")
except Exception:
    pass

matplotlib.rcParams["figure.dpi"] = 120
plt.rcParams["axes.grid"] = True

# -------------------------
# 1) Reproducibility
# -------------------------
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
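# Note: PYTHONHASHSEED is read at interpreter start-up, so setting it here only
# affects child processes launched later, not hashing in the running kernel.
os.environ["PYTHONHASHSEED"] = str(SEED)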
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

# Make CuDNN deterministic if we're on CUDA (harmless on MPS/CPU)
if hasattr(torch.backends, "cudnn"):
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# -------------------------
# 2) Device selection (prefers CUDA > MPS > CPU)
# -------------------------
def pick_device():
    if torch.cuda.is_available():
        name = torch.cuda.get_device_name(0)
        return torch.device("cuda"), f"CUDA: {name}"
    # Apple Silicon Metal backend (M-series)
    if getattr(torch.backends, "mps", None) and torch.backends.mps.is_available():
        return torch.device("mps"), "Apple Silicon MPS"
    return torch.device("cpu"), "CPU"

DEVICE, DEVICE_NAME = pick_device()

# Quick allocation test (catches subtle device issues early)
try:
    _ = torch.ones(1, device=DEVICE) * 1.0
    DEVICE_OK = True
except Exception as e:
    DEVICE_OK = False
    print(f"[setup] Allocation test failed on {DEVICE}: {e}")

# -------------------------
# 3) Ensure/record dependencies
# -------------------------
def ensure_package(mod_name, pip_name=None):
    """
    Import a module if present; if missing, attempt pip install, then import again.
    Returns the imported module, or re-raises the underlying error if the install
    or the second import fails.
    """
    pip_name = pip_name or mod_name
    try:
        return importlib.import_module(mod_name)
    except ImportError:
        print(f"[setup] '{mod_name}' not found. Attempting to install '{pip_name}' ...")
        try:
            import subprocess
            subprocess.check_call([sys.executable, "-m", "pip", "install", "--quiet", pip_name])
            # Refresh the import system's caches so the freshly installed package is found
            importlib.invalidate_caches()
            return importlib.import_module(mod_name)
        except Exception:
            print(f"[setup] Failed to install '{pip_name}'. Please install manually.")
            raise

# MedMNIST loader (we'll use it in the next cell)
medmnist = ensure_package("medmnist")

# torchvision for ResNet-18 later
torchvision = ensure_package("torchvision")

# sklearn for metrics later
sklearn = ensure_package("sklearn", "scikit-learn")

# -------------------------
# 4) Project paths
# -------------------------
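# NOTE: all paths are resolved from the current working directory. If the notebook is
# launched from inside prototyping/week-03 rather than the repository root, OUTPUT_DIR
# nests a second "prototyping/week-03" under that folder, and DATA_DIR lands there too.
PROJECT_ROOT = Path.cwd().resolve()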
DATA_DIR     = PROJECT_ROOT / "data" / "external" / "medmnist"
OUTPUT_DIR   = PROJECT_ROOT / "prototyping" / "week-03" / "outputs"

DATA_DIR.mkdir(parents=True, exist_ok=True)
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

RUN_ID = datetime.now().strftime("%Y%m%d-%H%M%S")
RUN_DIR = OUTPUT_DIR / f"run-{RUN_ID}"
RUN_DIR.mkdir(parents=True, exist_ok=True)

# -------------------------
# 5) Version snapshot (helps reproducibility & README/results-week-03.md)
# -------------------------
def get_ver(mod, attr="__version__"):
    try:
        return getattr(mod, attr)
    except Exception:
        return "N/A"

ENV_SNAPSHOT = {
    "python": sys.version.split()[0],
    "platform": platform.platform(),
    "processor": platform.processor() or platform.machine(),
    "torch": get_ver(torch),
    "torchvision": get_ver(torchvision),
    "medmnist": get_ver(medmnist),
    "sklearn": get_ver(sklearn),
    "matplotlib": get_ver(matplotlib),
    "numpy": np.__version__,
    "device": DEVICE_NAME,
    "mps_available": bool(getattr(torch.backends, "mps", None) and torch.backends.mps.is_available()),
    "cuda_available": torch.cuda.is_available(),
}

# Pretty print the environment summary
print("=== Week-03 Environment Summary ===")
for k, v in ENV_SNAPSHOT.items():
    print(f"{k:>12}: {v}")
print(f"\nSelected device  : {DEVICE} (ok={DEVICE_OK})")
print(f"Project root     : {PROJECT_ROOT}")
print(f"Data dir         : {DATA_DIR}")
print(f"Outputs (run)    : {RUN_DIR}")

# write a compact requirements snapshot for this notebook
req_lines = [
    f"# Week-03 requirements snapshot ({RUN_ID})",
    # The interpreter version is recorded as a comment: "python" is not a pip-installable requirement
    f"# python {ENV_SNAPSHOT['python']}",
    f"torch=={ENV_SNAPSHOT['torch']}",
    f"torchvision=={ENV_SNAPSHOT['torchvision']}",
    f"medmnist=={ENV_SNAPSHOT['medmnist']}",
    f"scikit-learn=={ENV_SNAPSHOT['sklearn']}",
    f"matplotlib=={ENV_SNAPSHOT['matplotlib']}",
    f"numpy=={ENV_SNAPSHOT['numpy']}",
]
REQ_PATH = OUTPUT_DIR / "requirements-week-03.txt"
with open(REQ_PATH, "w", encoding="utf-8") as f:
    f.write("\n".join(req_lines) + "\n")
print(f"\nSaved requirements snapshot → {REQ_PATH}")

# Sanity message about Apple MPS (useful for M-series laptops/desktops)
if DEVICE.type == "mps":
    print("\n[MPS] Apple Silicon backend is active. We'll use autocast(fp16) later for training to speed things up.")
elif DEVICE.type == "cuda":
    print("\n[CUDA] CUDA backend is active.")
else:
    print("\n[CPU] MPS/CUDA not available. This will be slower but still functional.")


=== Week-03 Environment Summary ===
      python: 3.11.9
    platform: macOS-26.0-arm64-arm-64bit
   processor: arm
       torch: 2.8.0
 torchvision: 0.23.0
    medmnist: 3.0.2
     sklearn: 1.7.2
  matplotlib: 3.10.6
       numpy: 2.3.3
      device: Apple Silicon MPS
mps_available: True
cuda_available: False

Selected device  : mps (ok=True)
Project root     : /Users/ali/Documents/MSC-Project/Project/prototyping/week-03
Data dir         : /Users/ali/Documents/MSC-Project/Project/prototyping/week-03/data/external/medmnist
Outputs (run)    : /Users/ali/Documents/MSC-Project/Project/prototyping/week-03/prototyping/week-03/outputs/run-20250926-184050

Saved requirements snapshot → /Users/ali/Documents/MSC-Project/Project/prototyping/week-03/prototyping/week-03/outputs/requirements-week-03.txt

[MPS] Apple Silicon backend is active. We'll use autocast(fp16) later for training to speed things up.
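
The run path in the output above contains `prototyping/week-03` twice, because `PROJECT_ROOT` is taken from the working directory and the notebook was started inside `prototyping/week-03`. One possible way to make the paths independent of the launch directory is to walk upward until a marker is found. The sketch below is an assumption-laden illustration, not part of the original setup: the helper `find_project_root` is hypothetical, and it assumes the repository root contains a `.git` directory. The demo runs in a temporary directory so it does not touch the real project.

```python
import tempfile
from pathlib import Path


def find_project_root(start, marker=".git"):
    """Walk upward from `start`; return the first directory that contains `marker`.

    Assumption: the repository root is identified by a `.git` directory.
    Falls back to `start` itself if no marker is found.
    """
    start = Path(start).resolve()
    for candidate in (start, *start.parents):
        if (candidate / marker).exists():
            return candidate
    return start


# Demo in a throwaway directory that mimics the layout seen in the output above.
with tempfile.TemporaryDirectory() as tmp:
    repo = Path(tmp).resolve() / "Project"
    (repo / ".git").mkdir(parents=True)            # hypothetical repo marker
    launch_dir = repo / "prototyping" / "week-03"  # where the notebook was started
    launch_dir.mkdir(parents=True)

    root = find_project_root(launch_dir)
    output_dir = root / "prototyping" / "week-03" / "outputs"

    root_found = root == repo
    relative_parts = output_dir.relative_to(repo).parts
    print("root found:", root_found)
    print("outputs   :", "/".join(relative_parts))
```

With a root resolved this way, `DATA_DIR` and `OUTPUT_DIR` would point to the same locations whether the notebook is launched from the repository root or from its own folder.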
