# Neural Network with JAX

In this lab, we will:
1. Download MNIST (unless it already exists locally) and load it into
   `jax.numpy` arrays.
2. Build a simple multi-layer perceptron (MLP) in JAX.
3. Train the network on MNIST.
4. Evaluate its performance on the test data.
5. Provide a custom inference function for your own handwritten digit
   images.

This lab is based on a
[JAX example](https://github.com/jax-ml/jax/blob/main/examples/mnist_classifier_fromscratch.py).

Note that MNIST is the "hello world" of machine learning, and many
examples are available online, including some simpler ones that use
higher-level libraries:
[JAX with pre-built optimizers](https://github.com/jax-ml/jax/blob/main/examples/mnist_classifier.py),
[Flax](https://flax.readthedocs.io/en/latest/mnist_tutorial.html),
[PyTorch](https://github.com/pytorch/examples/tree/main/mnist), and
[Keras](https://www.tensorflow.org/datasets/keras_example).

## MNIST Data Loader

We start by downloading the MNIST data set and storing it locally.
The data loader then parses the files, flattens and normalizes the
images, one-hot encodes the labels, and returns everything as
`jax.numpy` arrays.

In [None]:
from os.path import isfile
from urllib.request import urlretrieve

base_url = "https://storage.googleapis.com/cvdf-datasets/mnist/"

# File names
files = {
    "train_images": "train-images-idx3-ubyte.gz",
    "train_labels": "train-labels-idx1-ubyte.gz",
    "test_images":  "t10k-images-idx3-ubyte.gz",
    "test_labels":  "t10k-labels-idx1-ubyte.gz",
}

for key, file in files.items():
    if not isfile(file):
        url = base_url + file
        print(f"Downloading {url} to {file}...")
        urlretrieve(url, file)
    else:
        print(f"{file} exists; skip download")

In [None]:
import gzip
import struct
import array
from jax import numpy as np

# Parsing functions

def parse_labels(file):
    with gzip.open(file, "rb") as fh:
        _magic, num_data = struct.unpack(">II", fh.read(8))
        # Read the label data as 1-byte unsigned integers
        return np.array(array.array("B", fh.read()), dtype=np.uint8)

def parse_images(file):
    with gzip.open(file, "rb") as fh:
        _magic, num_data, rows, cols = struct.unpack(">IIII", fh.read(16))
        # Read the image data as 1-byte unsigned integers
        images = np.array(array.array("B", fh.read()), dtype=np.uint8)
        # Reshape to (num_data, 28, 28)
        images = images.reshape(num_data, rows, cols)
        return images
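The MNIST files use the IDX format: a big-endian header made of a 4-byte magic number (2051 for unsigned-byte image files, 2049 for label files), the item count, and, for images, the row and column counts, followed by the raw unsigned bytes. The sketch below builds a tiny synthetic IDX image file in memory and parses it with the same header logic as `parse_images`, so the format can be inspected without downloading anything. The helpers `make_idx_images` and `parse_images_from_bytes` are hypothetical and exist only for this illustration.

```python
import gzip
import io
import struct

import numpy as onp  # plain NumPy, to keep this sketch independent of JAX


def make_idx_images(images):
    """Hypothetical helper: encode a (n, rows, cols) uint8 array as gzipped IDX bytes."""
    n, rows, cols = images.shape
    # Magic 0x00000803: unsigned-byte data (0x08) with 3 dimensions (0x03)
    header = struct.pack(">IIII", 0x00000803, n, rows, cols)
    return gzip.compress(header + images.astype(onp.uint8).tobytes())


def parse_images_from_bytes(blob):
    """Hypothetical helper: same header logic as parse_images, but reads from memory."""
    with gzip.open(io.BytesIO(blob), "rb") as fh:
        magic, num_data, rows, cols = struct.unpack(">IIII", fh.read(16))
        assert magic == 2051, f"unexpected magic number {magic}"
        pixels = onp.frombuffer(fh.read(), dtype=onp.uint8)
        return pixels.reshape(num_data, rows, cols)


# Two fake 2x3 "images" with pixel values 0..11
fake = onp.arange(12, dtype=onp.uint8).reshape(2, 2, 3)
parsed = parse_images_from_bytes(make_idx_images(fake))
print(parsed.shape)  # (2, 2, 3)
print((parsed == fake).all())  # True
```

Using `numpy.frombuffer` here instead of `array.array` is only a stylistic choice; both read the payload as 1-byte unsigned integers.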

In [None]:
# Parse raw data

train_images_raw = parse_images(files["train_images"])
train_labels_raw = parse_labels(files["train_labels"])

test_images_raw  = parse_images(files["test_images"])
test_labels_raw  = parse_labels(files["test_labels"])

In [None]:
# Standardize the images, i.e., flatten and normalize images to [0, 1]
def standardize(images):
    return images.reshape(-1, 28*28).astype(np.float32) / 255

train_images = standardize(train_images_raw)
test_images  = standardize(test_images_raw)

In [None]:
# One-hot encode labels
def one_hot(labels, num_classes=10):
    return np.eye(num_classes)[labels]

train_labels = one_hot(train_labels_raw, 10).astype(np.float32)
test_labels  = one_hot(test_labels_raw,  10).astype(np.float32)
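To see what the two preprocessing steps produce, here is a minimal sketch on made-up toy inputs rather than MNIST: pixel values are divided by 255, and each label selects one row of the identity matrix. It also compares the result with `jax.nn.one_hot`, which JAX provides for the same purpose.

```python
from jax import nn
from jax import numpy as np


def standardize(images):
    # Flatten each 28x28 image to a 784-vector and scale pixels to [0, 1]
    return images.reshape(-1, 28 * 28).astype(np.float32) / 255


def one_hot(labels, num_classes=10):
    # Row i of the identity matrix is the one-hot vector for class i
    return np.eye(num_classes)[labels]


# Toy batch of two "images": all black and all white
toy_images = np.stack([np.zeros((28, 28), np.uint8), np.full((28, 28), 255, np.uint8)])
flat = standardize(toy_images)
print(flat.shape)  # (2, 784)
print(flat.min(), flat.max())  # values 0.0 and 1.0

toy_labels = np.array([3, 0], dtype=np.uint8)
print(one_hot(toy_labels, 4))  # one-hot rows for classes 3 and 0
print(np.allclose(one_hot(toy_labels, 4), nn.one_hot(toy_labels, 4)))  # True
```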

## Visualize Some Training and Testing Data

Let's take a look at our data set!

In [None]:
from matplotlib import pyplot as plt

In [None]:
plt.imshow(train_images_raw[0,:,:], cmap='gray')

In [None]:
plt.imshow(test_images_raw[0,:,:], cmap='gray')

## Define a Simple Neural Network in JAX

In this section, we introduce the core function needed to initialize
the parameters of a multi-layer network.
Each layer of the network is characterized by a weight matrix `W` and
a bias vector `b`.
We initialize them with random values drawn from a standard normal
distribution and scaled by a small factor, so that training starts
from small, stable values.

In [None]:
from numpy.random import RandomState

def init_params(scale, layer_sizes, rng=RandomState(0)):
    """
    Initialize the parameters (weights and biases) for each layer in the network.

    Parameters
    ----------
    scale : float
        A scaling factor to control the initial range of the weights.
    layer_sizes : list of int
        The sizes of each layer in the network.
        e.g., [784, 1024, 1024, 10] means:
            - Input layer: 784 units
            - Hidden layer 1: 1024 units
            - Hidden layer 2: 1024 units
            - Output layer: 10 units
    rng : numpy.random.RandomState
        Random state for reproducibility.

    Returns
    -------
    params : list of tuples (W, b)
        Each tuple contains (W, b) for a layer.
        - W is a (input_dim, output_dim) array of weights
        - b is a (output_dim,) array of biases
    """
    return [
        (scale * rng.randn(m, n), scale * rng.randn(n))
        for m, n in zip(layer_sizes[:-1], layer_sizes[1:])
    ]

In the above function,
* We specify a list of layer sizes, for example
  `[784, 1024, 1024, 10]`.
* For each pair of consecutive sizes `(m, n)`, we create a weight
  matrix `W` of shape `(m, n)` and a bias vector `b` of shape `(n,)`.
* Multiplying by `scale` keeps the initial values small, which helps
  prevent numerical issues early in training.
* We store all `(W, b)` pairs in a list, one pair per layer, to be
  used throughout training and inference.

By calling `init_params(scale, layer_sizes)`, you obtain an
easy-to-manipulate structure that keeps all the parameters needed for
your neural network.
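For the `[784, 1024, 1024, 10]` architecture used below, each layer's shapes follow directly from consecutive pairs of layer sizes. The sketch below lists those shapes and counts the parameters. It also shows a hypothetical `init_params_jax` variant that draws the weights from JAX's explicit-key PRNG (`jax.random`) instead of NumPy's `RandomState`; this is an alternative design choice, not what this lab uses.

```python
import jax
from numpy.random import RandomState


def init_params(scale, layer_sizes, rng=RandomState(0)):
    # Same scheme as above: one (W, b) pair per consecutive pair of layer sizes
    return [
        (scale * rng.randn(m, n), scale * rng.randn(n))
        for m, n in zip(layer_sizes[:-1], layer_sizes[1:])
    ]


def init_params_jax(scale, layer_sizes, key):
    """Hypothetical variant using jax.random with an explicit PRNG key."""
    params = []
    for m, n in zip(layer_sizes[:-1], layer_sizes[1:]):
        # Split off fresh subkeys for this layer's weights and biases
        key, w_key, b_key = jax.random.split(key, 3)
        params.append((scale * jax.random.normal(w_key, (m, n)),
                       scale * jax.random.normal(b_key, (n,))))
    return params


layer_sizes = [784, 1024, 1024, 10]
params = init_params(0.1, layer_sizes)
for i, (W, b) in enumerate(params):
    print(f"layer {i}: W {W.shape}, b {b.shape}")
# layer 0: W (784, 1024), b (1024,)
# layer 1: W (1024, 1024), b (1024,)
# layer 2: W (1024, 10), b (10,)

num_params = sum(W.size + b.size for W, b in params)
print(num_params)  # 1863690

params_jax = init_params_jax(0.1, layer_sizes, jax.random.PRNGKey(0))
print([W.shape for W, _ in params_jax])
```

Because `jax.random` keys are explicit values, calling `init_params_jax` twice with the same key returns identical parameters. By contrast, the default `RandomState(0)` argument of `init_params` is created once, when the function is defined, and advances on every call, so two calls without an explicit `rng` return different parameters.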


In [None]:
# Define network architecture and hyperparameters

layer_sizes = [784, 1024, 1024, 10]  # 2 hidden layers
param_scale = 0.1

In [None]:
# Initialize parameters

params = init_params(param_scale, layer_sizes)

## Forward Pass: The `predict` Function

Once the network parameters are initialized, we need a function to
perform the forward pass, producing an output for each batch of
inputs.
Below, we define `predict` to process data through multiple layers,
using a `tanh` activation on the hidden layers, and to apply a
numerically stable log-softmax to the output of the final layer.

In [None]:
from jax import numpy as np
from jax.scipy.special import logsumexp

In [None]:
def predict(params, inputs):
    """
    Compute the network's output logits for a batch of inputs, then convert
    them to log probabilities by subtracting their log-sum-exp (log-softmax).

    Network architecture:
      - Hidden layers use tanh activation
      - Output layer is linear (we'll do log-softmax here)

    Parameters
    ----------
    params : list of (W, b) tuples
        Network's parameters for each layer.
    inputs : np.ndarray
        A batch of input data of shape (batch_size, input_dim).

    Returns
    -------
    np.ndarray
        Log probabilities of shape (batch_size, 10).
    """
    activations = inputs

    # Hidden layers
    for w, b in params[:-1]:
        outputs = np.dot(activations, w) + b
        activations = np.tanh(outputs)

    # Final layer (logits)
    final_w, final_b = params[-1]
    logits = np.dot(activations, final_w) + final_b

    # Log-Softmax: subtract logsumexp for numerical stability
    return logits - logsumexp(logits, axis=1, keepdims=True)
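A quick sanity check of `predict` can be run on a hypothetical tiny network (layer sizes `[4, 3, 2]` and random inputs, chosen only to keep the example small). The output should have one row per input, the exponentiated rows should each sum to 1, and the result should match JAX's built-in `jax.nn.log_softmax` applied to the same logits.

```python
import jax
from jax import numpy as np
from jax.scipy.special import logsumexp


def predict(params, inputs):
    # Same forward pass as above: tanh hidden layers, log-softmax output
    activations = inputs
    for w, b in params[:-1]:
        activations = np.tanh(np.dot(activations, w) + b)
    final_w, final_b = params[-1]
    logits = np.dot(activations, final_w) + final_b
    return logits - logsumexp(logits, axis=1, keepdims=True)


# Hypothetical tiny network: 4 inputs -> 3 hidden units -> 2 classes
key = jax.random.PRNGKey(0)
k1, k2, k3, k4, kx = jax.random.split(key, 5)
params = [
    (jax.random.normal(k1, (4, 3)), jax.random.normal(k2, (3,))),
    (jax.random.normal(k3, (3, 2)), jax.random.normal(k4, (2,))),
]
inputs = jax.random.normal(kx, (5, 4))  # batch of 5 examples

log_probs = predict(params, inputs)
print(log_probs.shape)  # (5, 2)
print(np.exp(log_probs).sum(axis=1))  # each entry is 1 up to float rounding

# Recompute the logits and compare with JAX's built-in log-softmax
hidden = np.tanh(np.dot(inputs, params[0][0]) + params[0][1])
logits = np.dot(hidden, params[1][0]) + params[1][1]
print(np.allclose(log_probs, jax.nn.log_softmax(logits, axis=1), atol=1e-5))  # True
```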

In the above function,
* Hidden layers (`tanh`):
  Each hidden layer applies a linear transformation
  (`np.dot(activations, w) + b`) followed by the hyperbolic tangent
  activation function (`np.tanh`).
* Final layer (`logits`):
  The last layer's output is not passed through `tanh`; instead, we
  use it directly as logits.
* Log-softmax:
  We convert the logits to log probabilities by subtracting
  `logsumexp(logits)` along the class dimension.
  `logsumexp` is computed in a numerically stable way (it subtracts
  the largest logit before exponentiating), and the resulting log
  probabilities can be used directly to compute losses such as
  cross-entropy.
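Two claims about the log-softmax step can be checked directly. First, the naive formula `log(exp(z) / sum(exp(z)))` overflows for large logits, while the logsumexp form stays finite. Second, with one-hot targets, the cross-entropy loss reduces to the negative log probability of the correct class. The `cross_entropy` function below is a hypothetical helper for illustration, not code from this lab, and the logits are made up.

```python
from jax import numpy as np
from jax.scipy.special import logsumexp

# Logits large enough that exp() overflows
logits = np.array([[1000.0, 1001.0, 1002.0]])

# Naive log-softmax: exp(1000) is inf, so inf / inf gives nan
naive = np.log(np.exp(logits) / np.sum(np.exp(logits), axis=1, keepdims=True))
print(naive)  # all nan

# Stable log-softmax, as in predict()
stable = logits - logsumexp(logits, axis=1, keepdims=True)
print(stable)  # approximately [[-2.4076 -1.4076 -0.4076]]


def cross_entropy(log_probs, one_hot_targets):
    """Hypothetical helper: mean negative log probability of the true class."""
    return -np.mean(np.sum(log_probs * one_hot_targets, axis=1))


# With target class 2, the loss is -log p(class 2), about 0.4076
targets = np.array([[0.0, 0.0, 1.0]])
print(cross_entropy(stable, targets))
```

Multiplying by the one-hot targets zeroes out every entry except the true class, so summing over the class axis picks out its log probability.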