# Digit Classifier — End-to-End Pipeline

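This notebook walks through the full pipeline step by step: downloading and preprocessing the dataset, loading and splitting it, inspecting the model, training, and running inference.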

## 0. Setup

Install the package in editable mode with its `dev` extras, plus `ipywidgets`:

```bash
pip install -e ".[dev]"
pip install ipywidgets
```

In [1]:
import sys, os, pathlib

# Run everything from the project root so relative paths (datasets/, data/, etc.) resolve correctly.
# Find the project root by walking up from the current working directory until pyproject.toml is found.
# If no pyproject.toml is found, fall back to the parent directory.
_here = pathlib.Path.cwd()
_root = next((p for p in (_here, *_here.parents) if (p / "pyproject.toml").exists()), _here.parent)
os.chdir(_root)
sys.path.insert(0, str(_root / "src"))

from digit_classifier import __version__
print(f"digit_classifier v{__version__}")
print(f"Working directory: {os.getcwd()}")

digit_classifier v1.0.0
Working directory: /Users/ryanrudes/GitHub/digit_classification


## 1. Download the raw dataset

In [2]:
from digit_classifier.preprocessing import download_dataset

data_dir = download_dataset(output_zip="data/dataset.zip")

## 2. Preprocess and cache as `.npz`

In [3]:
from digit_classifier.preprocessing import preprocess_and_cache

npz_path = preprocess_and_cache(
    dataset_name="mnist_rgb_224",
    color=True,
    size=224,
    data_dir="data/dataset",
    output_dir="datasets",
)
print("Cached dataset at:", npz_path)

Cached dataset at: datasets/mnist_rgb_224.npz
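
To check what ended up in a cached `.npz` file, you can list its arrays with NumPy. The sketch below is only an illustration: `summarize_npz` is a hypothetical helper, and the key names `images` and `labels` in the stand-in archive are assumptions, not necessarily the keys `preprocess_and_cache` writes.

```python
import tempfile
from pathlib import Path

import numpy as np


def summarize_npz(path):
    """Return {array_name: (shape, dtype)} for every array stored in an .npz archive."""
    with np.load(path) as archive:
        return {
            name: (archive[name].shape, str(archive[name].dtype))
            for name in archive.files
        }


# Demo on a tiny stand-in archive; the key names "images" and "labels" are hypothetical.
with tempfile.TemporaryDirectory() as tmp:
    demo_path = Path(tmp) / "demo.npz"
    np.savez_compressed(
        demo_path,
        images=np.zeros((4, 3, 8, 8), dtype=np.uint8),
        labels=np.arange(4, dtype=np.int64),
    )
    summary = summarize_npz(demo_path)

for name, (shape, dtype) in summary.items():
    print(f"{name}: shape={shape}, dtype={dtype}")
```

Pointing `summarize_npz` at `npz_path` from the cell above should show the real array names and shapes.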


## 3. Load and split the dataset

In [4]:
from pathlib import Path
from digit_classifier.config import Config
from digit_classifier.training import load_cached_dataset
from digit_classifier.splitting import split_dataset, DEDUP_CACHE_DIR
from digit_classifier.external import DEFAULT_EXTERNAL_FRACTIONS, compute_external_manifest_hash

cfg = Config()

# Show dedup cache status before doing anything expensive
dataset_names = [ds.value for ds in DEFAULT_EXTERNAL_FRACTIONS]
manifest_hash = compute_external_manifest_hash(
    dataset_names, cfg.data.color, cfg.data.image_size,
    cfg.data.split_seed, cfg.data.train_fraction,
)
dedup_path = Path(DEDUP_CACHE_DIR) / f"dedup_indices_{manifest_hash}.json"
if dedup_path.exists():
    import json
    with open(dedup_path) as f:
        cached = json.load(f)
    total_cached = sum(len(v) for v in cached.values())
    print(f"✓ Dedup cache found: {dedup_path.name}")
    print(f"  {len(cached)} sources, {total_cached:,} surviving indices")
    print(f"  Deduplication will be skipped (cached)")
else:
    print(f"✗ No dedup cache found (hash: {manifest_hash})")
    print(f"  Deduplication will run on first load (may take a few minutes)")

# Load and split
images, labels, cached_mean, cached_std = load_cached_dataset(cfg)
print(f"\nLoaded {images.shape[0]} images, shape {images.shape}")

train_ds, val_ds, mean, std = split_dataset(
    images, labels, cached_mean, cached_std,
    mix_external=cfg.data.mix_external,
    external_fractions=DEFAULT_EXTERNAL_FRACTIONS,
    color=cfg.data.color,
    size=cfg.data.image_size,
)
print(f"\nTrain: {len(train_ds):,}, Val: {len(val_ds):,}")
print(f"Mean: {mean}")
print(f"Std:  {std}")

✓ Dedup cache found: dedup_indices_0fd3a92e5f62.json
  10 sources, 1,101,982 surviving indices
  Deduplication will be skipped (cached)

Loaded 10294 images, shape torch.Size([10294, 3, 224, 224])


Train: 1,111,246, Val: 1,030
Mean: (0.5714552998542786, 0.524692952632904, 0.4752778112888336)
Std:  (0.23079083859920502, 0.22527235746383667, 0.23291534185409546)
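
The printed `Mean` and `Std` each hold three values, one per RGB channel. How `split_dataset` computes and applies them is not shown in this notebook. Assuming the usual per-channel standardization, a minimal NumPy sketch looks like this (`channel_stats` and `normalize` are hypothetical helpers, and the random array stands in for real images):

```python
import numpy as np


def channel_stats(images):
    """Per-channel mean and std for a float array shaped (N, C, H, W)."""
    mean = images.mean(axis=(0, 2, 3))
    std = images.std(axis=(0, 2, 3))
    return mean, std


def normalize(images, mean, std):
    """Subtract the per-channel mean and divide by the per-channel std."""
    return (images - mean[None, :, None, None]) / std[None, :, None, None]


rng = np.random.default_rng(0)
demo = rng.random((16, 3, 8, 8))  # stand-in for images scaled to [0, 1]
demo_mean, demo_std = channel_stats(demo)
normalized = normalize(demo, demo_mean, demo_std)

# After normalization each channel has roughly zero mean and unit std.
print(normalized.mean(axis=(0, 2, 3)), normalized.std(axis=(0, 2, 3)))
```

Reusing the training-set statistics for validation and inference keeps all inputs on the same scale.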


## 4. Inspect the model

In [5]:
from digit_classifier.model import ResNeXt

model = ResNeXt(
    layers=list(cfg.model.layers),
    num_classes=cfg.model.num_classes,
    groups=cfg.model.groups,
    width_per_group=cfg.model.width_per_group,
    drop_path_rate=cfg.model.drop_path_rate,
)
total_params = sum(p.numel() for p in model.parameters())
print(f"ResNeXt parameters: {total_params:,}")

ResNeXt parameters: 81,426,762
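
The `groups` argument controls grouped convolutions, which is where ResNeXt differs from a plain ResNet. A grouped convolution splits the input channels into `groups` independent slices, so its weight count shrinks by that factor. The arithmetic can be sketched as follows; the channel counts are illustrative and not taken from `cfg.model`, and `conv_weight_count` is a hypothetical helper:

```python
def conv_weight_count(in_channels, out_channels, kernel_size, groups=1):
    """Weights in a 2D convolution (bias excluded) with the given group count."""
    if in_channels % groups or out_channels % groups:
        raise ValueError("in_channels and out_channels must be divisible by groups")
    # Each output channel only sees in_channels // groups input channels.
    return (in_channels // groups) * out_channels * kernel_size * kernel_size


dense = conv_weight_count(128, 128, 3)
grouped = conv_weight_count(128, 128, 3, groups=32)
print(f"dense: {dense:,}  grouped: {grouped:,}  ratio: {dense // grouped}x")
```

For these illustrative sizes the grouped layer has 4,608 weights against 147,456 for the dense one, a 32x reduction.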


## 5. Train

Run the full training loop. For a quick test, lower `--epochs` and turn off W&B logging and model compilation:

```bash
python -m digit_classifier train --epochs 5 --no-wandb --no-compile
```

Or call it programmatically:

In [None]:
from digit_classifier.config import Config, TrainingConfig
from digit_classifier.training import train

quick_cfg = Config(
    training=TrainingConfig(
        epochs=2,
        warmup_epochs=1,
        wandb_enabled=False,
        compile_model=False,
    ),
)

# Short 2-epoch smoke test; comment out the call below to skip it.
train(quick_cfg)

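
The `warmup_epochs` setting suggests the learning rate ramps up before the main schedule starts. The actual schedule used by `digit_classifier.training` is not shown in this notebook. As an illustration of one common pattern, here is linear warmup followed by cosine decay; `lr_at_epoch` and the `base_lr` value are hypothetical:

```python
import math


def lr_at_epoch(epoch, base_lr, total_epochs, warmup_epochs):
    """Linear warmup to base_lr, then cosine decay toward zero (epoch is 0-indexed)."""
    if epoch < warmup_epochs:
        return base_lr * (epoch + 1) / warmup_epochs
    decay_epochs = max(1, total_epochs - warmup_epochs)
    progress = (epoch - warmup_epochs) / decay_epochs
    return base_lr * 0.5 * (1.0 + math.cos(math.pi * progress))


# Assumed values for illustration only: 5 epochs, 2 of them warmup.
schedule = [
    lr_at_epoch(e, base_lr=1e-3, total_epochs=5, warmup_epochs=2)
    for e in range(5)
]
print(schedule)
```

With `warmup_epochs=1`, as in `quick_cfg`, this sketch would reach the full rate in the first epoch and spend the remaining epochs decaying.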

## 6. Inference

After training, run webcam inference with the best checkpoint from your run (replace `<run_id>` with the run's ID):

```bash
python -m digit_classifier infer --checkpoint checkpoints/<run_id>/best.pt
```