# Neural Network with JAX

In this lab, we will:
1. Download MNIST (if it doesn't already exist locally) and load it
   into NumPy arrays.
2. Build a simple Multi-Layer Perceptron in JAX.
3. Train the network on MNIST.
4. Evaluate the performance on test data.
5. Provide a custom inference function for your own handwriting
   images.

This lab is based on a
[JAX example](https://github.com/jax-ml/jax/blob/main/examples/mnist_classifier_fromscratch.py).

Note that MNIST is the "hello world" of machine learning, and many
examples are available online, including some simpler ones that use
libraries:
[JAX with pre-built optimizers](https://github.com/jax-ml/jax/blob/main/examples/mnist_classifier.py),
[Flax](https://flax.readthedocs.io/en/latest/mnist_tutorial.html),
[PyTorch](https://github.com/pytorch/examples/tree/main/mnist), and
[Keras](https://www.tensorflow.org/datasets/keras_example).

## MNIST Data Loader

We start by downloading the MNIST data set and storing it locally.
Our data loader then parses, reshapes, and normalizes the data, and
returns it in NumPy arrays.

In [None]:
from os.path import isfile
from urllib.request import urlretrieve

base_url = "https://storage.googleapis.com/cvdf-datasets/mnist/"

# File names
files = {
    "train_images": "train-images-idx3-ubyte.gz",
    "train_labels": "train-labels-idx1-ubyte.gz",
    "test_images":  "t10k-images-idx3-ubyte.gz",
    "test_labels":  "t10k-labels-idx1-ubyte.gz",
}

for key, file in files.items():
    if not isfile(file):
        url = base_url + file
        print(f"Downloading {url} to {file}...")
        urlretrieve(url, file)
    else:
        print(f"{file} exists; skip download")

In [None]:
import gzip
import struct
import array
from jax import numpy as np

# Parsing functions

def parse_labels(file):
    with gzip.open(file, "rb") as fh:
        _magic, num_data = struct.unpack(">II", fh.read(8))
        # Read the label data as 1-byte unsigned integers
        return np.array(array.array("B", fh.read()), dtype=np.uint8)

def parse_images(file):
    with gzip.open(file, "rb") as fh:
        _magic, num_data, rows, cols = struct.unpack(">IIII", fh.read(16))
        # Read the image data as 1-byte unsigned integers
        images = np.array(array.array("B", fh.read()), dtype=np.uint8)
        # Reshape to (num_data, 28, 28)
        images = images.reshape(num_data, rows, cols)
        return images
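
The parsing functions above rely on the IDX layout: a big-endian header (a magic number, the item count, and, for images, the row and column counts) followed by raw unsigned bytes. As a minimal sketch of that layout, the cell below builds a tiny synthetic image file in memory and reads it back the same way `parse_images` does. The 2x3 image size and the pixel values are made up for illustration; 2051 is the magic number used by the MNIST image files.

```python
import array
import gzip
import io
import struct

# Build a tiny synthetic IDX3 file in memory: 2 images of 2x3 pixels.
# (Hypothetical data; real MNIST images are 28x28.)
header = struct.pack(">IIII", 2051, 2, 2, 3)  # magic, count, rows, cols
pixels = bytes(range(12))                     # 2 * 2 * 3 = 12 pixel bytes
blob = gzip.compress(header + pixels)

# Read it back the same way parse_images does:
# 16 header bytes first, then the remaining bytes as unsigned 8-bit pixels.
with gzip.open(io.BytesIO(blob), "rb") as fh:
    magic, num_data, rows, cols = struct.unpack(">IIII", fh.read(16))
    data = array.array("B", fh.read())

print(magic, num_data, rows, cols, len(data))
```

The number of pixel bytes equals `num_data * rows * cols`, which is why the reshape in `parse_images` succeeds.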

In [None]:
# Parse raw data

train_images_raw = parse_images(files["train_images"])
train_labels_raw = parse_labels(files["train_labels"])

test_images_raw  = parse_images(files["test_images"])
test_labels_raw  = parse_labels(files["test_labels"])

In [None]:
# Standardize the images, i.e., flatten and normalize images to [0, 1]
def standardize(images):
    return images.reshape(-1, 28*28).astype(np.float32) / 255

train_images = standardize(train_images_raw)
test_images  = standardize(test_images_raw)
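
As a quick sanity check of what `standardize` produces, the sketch below applies the same arithmetic to a small synthetic batch. It assumes plain NumPy (imported as `onp`) instead of `jax.numpy` so that it runs without JAX installed; `standardize_demo` and the fake batch are hypothetical names used only for this illustration.

```python
import numpy as onp  # plain NumPy, an assumption made for this check

def standardize_demo(images):
    # Same arithmetic as standardize above: flatten, cast, scale to [0, 1].
    return images.reshape(-1, 28 * 28).astype(onp.float32) / 255

# Synthetic batch of 3 images: the first is all white, the others all black.
fake = onp.zeros((3, 28, 28), dtype=onp.uint8)
fake[0] = 255

out = standardize_demo(fake)
print(out.shape, out.dtype, out.min(), out.max())
```

Each 28x28 image becomes one row of 784 values, which is the input shape a fully connected layer expects.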

In [None]:
# One-hot encode labels
def one_hot(labels, num_classes=10):
    return np.eye(num_classes)[labels]

train_labels = one_hot(train_labels_raw, 10).astype(np.float32)
test_labels  = one_hot(test_labels_raw,  10).astype(np.float32)
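
The `one_hot` function works by indexing the rows of an identity matrix: row `k` of `eye(10)` is the one-hot vector for class `k`. The sketch below shows this on three made-up labels, again assuming plain NumPy (imported as `onp`) so that it runs without JAX.

```python
import numpy as onp  # plain NumPy, an assumption made for this check

labels = onp.array([0, 3, 9], dtype=onp.uint8)  # hypothetical labels

# Indexing the identity matrix with an integer array selects one row per label.
encoded = onp.eye(10, dtype=onp.float32)[labels]

print(encoded.shape)           # one row of length 10 per label
print(encoded.argmax(axis=1))  # argmax recovers the original labels
```

Because each row contains a single 1, taking `argmax` along axis 1 undoes the encoding, which is handy later when comparing predictions with labels.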