xkjjx/basic-digit-recognition

About

MNIST digit recognition with MLP and CNN models.

Structure

train/              # Training scripts (train_mlp.py, train_cnn.py)
visualize/          # Model visualization generators
utils/              # Format conversion (pth → json/onnx), HTML demo export
data/               # MNIST dataset
weights/            # Saved model weights
visualizations/     # Generated analysis outputs
common.py           # Shared utilities
test.py             # Model evaluation
train_and_test.py   # Unified train + test script

Usage

Train & Test:

uv run train_and_test.py --model cnn
uv run train_and_test.py --model mlp --epochs 50 --lr 0.0005
uv run train_and_test.py --model cnn --email  # requires SMTP_USER, SMTP_PASSWORD env vars

Train only:

uv run -m train.train_cnn
uv run -m train.train_mlp --scheduler step --batch-size 128

Test:

uv run test.py weights/cnn_weights.pth --model-type cnn

Visualize:

uv run -m visualize.visualize_cnn

Convert:

uv run -m utils.change_format

Export HTML demo:

uv run -m utils.export_html weights/model.onnx -o demo.html

Training Options

Both training scripts support:

  • --lr - Learning rate (default: 0.001)
  • --epochs - Training epochs (default: 100)
  • --batch-size - Batch size (default: 64)
  • --scheduler - LR scheduler: none, step, cosine, exponential, onecycle (default: cosine)
  • --format - Output format: pth, onnx, json (default: pth)
  • --num-rotations - Number of rotation angles for data augmentation (default: 0, disabled)
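
As a rough illustration of how the options above fit together, here is a minimal argparse sketch. It is not the repository's actual parser: the function name `build_parser` and the help strings are assumptions, and only the flag names, choices, and defaults are taken from the list above.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    """Hypothetical parser mirroring the documented training options."""
    parser = argparse.ArgumentParser(description="Training options (illustrative sketch)")
    parser.add_argument("--lr", type=float, default=0.001, help="Learning rate")
    parser.add_argument("--epochs", type=int, default=100, help="Training epochs")
    parser.add_argument("--batch-size", type=int, default=64, help="Batch size")
    parser.add_argument(
        "--scheduler",
        choices=["none", "step", "cosine", "exponential", "onecycle"],
        default="cosine",
        help="LR scheduler",
    )
    parser.add_argument(
        "--format",
        choices=["pth", "onnx", "json"],
        default="pth",
        help="Output format",
    )
    parser.add_argument(
        "--num-rotations",
        type=int,
        default=0,
        help="Number of rotation angles for data augmentation (0 disables it)",
    )
    return parser


# Equivalent to: uv run -m train.train_mlp --scheduler step --batch-size 128
args = build_parser().parse_args(["--scheduler", "step", "--batch-size", "128"])
print(args.scheduler, args.batch_size, args.lr)
```

Any flag that is not passed keeps its documented default, so the example above still trains with a learning rate of 0.001 for 100 epochs.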

GPU acceleration via CUDA or MPS (Apple Silicon) is automatic.
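
Automatic device selection of this kind is commonly written as a short fallback chain in PyTorch. The sketch below assumes that pattern; the helper name `pick_device` is hypothetical and the repository's own logic (presumably in common.py) may differ. It falls back to the CPU when PyTorch is not installed so that the snippet runs anywhere.

```python
def pick_device() -> str:
    """Return "cuda", "mps", or "cpu", preferring GPU backends when available."""
    try:
        import torch
    except ImportError:
        # Assumption for this sketch only: without PyTorch, report the CPU.
        return "cpu"

    if torch.cuda.is_available():
        return "cuda"

    # torch.backends.mps is missing in older PyTorch releases, so guard the lookup.
    mps = getattr(torch.backends, "mps", None)
    if mps is not None and mps.is_available():
        return "mps"

    return "cpu"


device = pick_device()
print(f"Using device: {device}")
```

The returned string can be passed to `torch.device(...)` or to a module's `.to(...)` call.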
