# Multilayer Perceptron (MNIST)

For a multilayer perceptron (MLP), the objective is to minimize a loss function over the dataset. Assuming supervised learning with input–label pairs $(x_i, y_i)$, the generic minimization problem is:

$$
\min_\theta \frac{1}{N} \sum_{i=1}^N \mathcal{L}(f_\theta(x_i), y_i)
$$

where:

* $f_\theta(x)$ is the output of the MLP with parameters $\theta$ (weights + biases across all layers),
* $\mathcal{L}(\cdot, \cdot)$ is the loss function, e.g. mean squared error (MSE), cross-entropy, etc.,
* $N$ is the number of training samples.

---

**For regression (squared-error / MSE loss):**

$$
\mathcal{L}(f_\theta(x), y) = \|f_\theta(x) - y\|^2
$$

**For binary classification (sigmoid output + binary cross-entropy):**
$$
\mathcal{L}(f_\theta(x), y) = -y \log f_\theta(x) - (1 - y) \log(1 - f_\theta(x))
$$

**For multiclass classification (softmax + cross-entropy):**

$$
\mathcal{L}(f_\theta(x), y) = -\sum_{k=1}^K y_k \log \left( \text{softmax}(f_\theta(x))_k \right)
$$

---

The MLP itself is defined recursively:

$$
\begin{aligned}
h^{(0)} &= x \\
h^{(l)} &= \sigma(W^{(l)} h^{(l-1)} + b^{(l)}), \quad l = 1, \dots, L-1 \\
f_\theta(x) &= W^{(L)} h^{(L-1)} + b^{(L)} \quad \text{(or apply output activation)}
\end{aligned}
$$

where:

* $\sigma$ is an activation function (ReLU, tanh, etc.),
* $\theta = \{W^{(l)}, b^{(l)}\}_{l=1}^L$.

---

You can add L2 regularization (weight decay) if desired, where $\lambda \ge 0$ sets its strength:

$$
\min_\theta \frac{1}{N} \sum_{i=1}^N \mathcal{L}(f_\theta(x_i), y_i) + \lambda \|\theta\|^2
$$

In [12]:
# Imports
# import os
# os.environ["JAX_PLATFORM_NAME"] = "METAL"          # before importing jax
# os.environ["JAX_PLATFORMS"] = "metal,cpu"        # allow cpu fallback for missing op
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import pandas as pd

# For image processing
from io import BytesIO
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np

In [13]:
# List the devices JAX can see (the recorded output shows a single CUDA device).
jax.devices()

[CudaDevice(id=0)]

In [14]:
# Load the MNIST training split. Per the displayed output, each row holds an
# 'image' dict with the encoded PNG under the 'bytes' key and an integer 'label'.
df = pd.read_parquet('data/mnist_train.parquet')
df

Unnamed: 0,image,label
0,{'bytes': b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHD...,5
1,{'bytes': b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHD...,0
2,{'bytes': b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHD...,4
3,"{'bytes': b""\x89PNG\r\n\x1a\n\x00\x00\x00\rIHD...",1
4,{'bytes': b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHD...,9
...,...,...
59995,{'bytes': b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHD...,8
59996,{'bytes': b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHD...,3
59997,{'bytes': b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHD...,5
59998,{'bytes': b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHD...,6


In [15]:
def show_image(byte_str):
    """Decode an encoded image held in memory and display it without axes.

    Args:
        byte_str: Raw bytes of an image file that PIL can open (the dataset
            rows shown in this notebook carry PNG data).
    """
    buffer = BytesIO(byte_str)
    plt.imshow(Image.open(buffer))
    plt.axis('off')
    plt.show()

# show_image(df['image'][2]['bytes']), df['label'][2]

# Re-display the DataFrame (show_image above is only defined here, not called).
df

Unnamed: 0,image,label
0,{'bytes': b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHD...,5
1,{'bytes': b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHD...,0
2,{'bytes': b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHD...,4
3,"{'bytes': b""\x89PNG\r\n\x1a\n\x00\x00\x00\rIHD...",1
4,{'bytes': b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHD...,9
...,...,...
59995,{'bytes': b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHD...,8
59996,{'bytes': b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHD...,3
59997,{'bytes': b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHD...,5
59998,{'bytes': b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHD...,6


In [16]:
def bytes_dict_to_jax_array(d):
    """Decode one dataset image cell into a JAX array.

    Args:
        d: Mapping whose 'bytes' entry holds the encoded image file contents.

    Returns:
        The decoded image as a `jnp` array.
    """
    encoded = d['bytes']
    return jnp.array(Image.open(BytesIO(encoded)))

def preprocess_df(df):
    """Turn a raw MNIST DataFrame into model inputs and one-hot targets.

    Args:
        df: DataFrame with an 'image' column of {'bytes': ...} dicts and a
            'label' column of class indices. The caller's frame is not modified.

    Returns:
        A pair (X, y): X has one flattened image per row with pixel values
        divided by 255, and y is the 10-class one-hot encoding of the labels.
    """
    decoded = df['image'].map(bytes_dict_to_jax_array)
    images = jnp.stack(decoded.tolist())
    # Flatten every image into a single feature vector and scale the pixels.
    flat = images.reshape(images.shape[0], -1) / 255
    targets = jax.nn.one_hot(df['label'], 10)
    return flat, targets

# Build the training matrices; the shapes shown below confirm 60000 flattened
# 784-pixel images and 60000 ten-way one-hot labels.
X_train, y_train = preprocess_df(df)
X_train.shape, y_train.shape

((60000, 784), (60000, 10))

In [17]:
def create_params(layers, key=None):
    """Randomly initialise the weights and biases of an MLP.

    Args:
        layers: Sequence of layer widths, input width first and output width
            last, e.g. ``[784, 128, 64, 10]``.
        key: Optional JAX PRNG key. When omitted, ``jax.random.key(42)`` is
            used, so repeated calls return identical parameters.

    Returns:
        A list with one ``(W, b)`` tuple per consecutive pair of widths, where
        ``W`` has shape ``(fan_out, fan_in)`` and ``b`` has shape
        ``(fan_out, 1)``. All entries are standard-normal draws.
    """
    if key is None:
        key = jax.random.key(42)

    def create_matrix_and_bias(n, m, layer_key):
        # Independent sub-keys so the weights and the bias are not correlated.
        k1, k2 = jax.random.split(layer_key)
        return jax.random.normal(k1, (m, n)), jax.random.normal(k2, (m, 1))

    # NOTE(review): the draws are not scaled by fan-in, so pre-activations grow
    # with layer width; consider He/Glorot scaling if training is unstable.
    layer_keys = jax.random.split(key, len(layers) - 1)
    return [
        create_matrix_and_bias(n, m, layer_key)
        for n, m, layer_key in zip(layers, layers[1:], layer_keys)
    ]

# Network widths: flattened input size, hidden layers of 128 and 64 units, and
# 10 outputs. The tree.map call just displays each parameter's shape.
layers = [X_train.shape[1], 128, 64, 10]
jax.tree.map(lambda p: p.shape, create_params(layers))

[((128, 784), (128, 1)), ((64, 128), (64, 1)), ((10, 64), (10, 1))]

In [18]:
@jax.jit
def fwd(params, X):
    """Run the MLP forward pass and return the raw output logits.

    Args:
        params: Non-empty sequence of ``(W, b)`` pairs, with ``W`` of shape
            ``(fan_out, fan_in)`` and ``b`` of shape ``(fan_out, 1)``.
        X: Input batch with one example per row, shape ``(batch, features)``.

    Returns:
        Array of shape ``(batch, outputs)``. ReLU is applied after every layer
        except the last, so the result is un-activated logits.

    Raises:
        ValueError: If ``params`` contains no layers.
    """
    if not params:
        raise ValueError("params must contain at least one (W, b) layer")

    *hidden_layers, (W_out, b_out) = params
    # Work column-major internally: each column of h is one example.
    h = X.T
    for W, b in hidden_layers:
        h = jax.nn.relu(W @ h + b)
    # The final layer stays linear; transpose back to one example per row.
    return (W_out @ h + b_out).T

# fwd(params, X_train).shape, y_train.shape
@jax.jit
def loss(params, X, y):
    """Mean softmax cross-entropy of the network's predictions.

    Args:
        params: Network parameters accepted by `fwd`.
        X: Input batch, one example per row.
        y: Targets with one row per example and one column per class
            (one-hot in this notebook).

    Returns:
        Scalar: the negative target-weighted log-probability, averaged over
        the examples in the batch.
    """
    log_probs = jax.nn.log_softmax(fwd(params, X))
    per_example = jnp.sum(y * log_probs, axis=1)
    return -jnp.mean(per_example)

In [19]:
# Jit-compiled gradient of `loss` with respect to its first argument (params).
grad_loss = jax.jit(jax.grad(loss))

@jax.jit
def train(params, X_train, y_train, lr=0.001, num_steps=30000, log_every=100):
    """Fit the MLP with full-batch gradient descent.

    Args:
        params: Initial parameters (pytree of arrays) accepted by `loss`.
        X_train: Training inputs, one example per row.
        y_train: Training targets matching `X_train` row for row.
        lr: Step size of each gradient-descent update.
        num_steps: Number of updates to perform.
        log_every: The loss is printed on every step whose index is a
            multiple of this value (step 0 included).

    Returns:
        The parameters after `num_steps` updates.

    Note:
        The function is jitted, so explicitly passed `lr`, `num_steps` and
        `log_every` values are traced rather than baked in as constants.
    """

    def body(i, params):
        grads = grad_loss(params, X_train, y_train)
        params = jax.tree.map(lambda param, grad: param - lr * grad, params, grads)

        def do_print(_):
            # Reports the loss of the freshly updated parameters.
            jax.debug.print("step {i}, loss: {l}", i=i, l=loss(params, X_train, y_train))
            return None

        # lax.cond keeps the (comparatively costly) loss evaluation off the
        # steps that do not log.
        jax.lax.cond(i % log_every == 0, do_print, lambda _: None, None)
        return params

    return jax.lax.fori_loop(0, num_steps, body, params)

# Full-batch gradient-descent run starting from a fresh initialisation.
params = train(create_params(layers), X_train, y_train)

In [20]:
# Training loss of the fitted parameters.
# NOTE(review): the recorded output below is a NameError, i.e. `params` did not
# exist when this cell ran — presumably the training cell had not finished.
loss(params, X_train, y_train)

NameError: name 'params' is not defined

In [None]:
# Load the test split and preprocess it exactly like the training data.
X_test, y_test = preprocess_df(pd.read_parquet('data/mnist_test.parquet'))

In [None]:
# compute accuracy: one-hot encode the arg-max prediction and multiply by the
# one-hot labels, so each correct example contributes exactly 1 to the sum;
# dividing by the number of test examples gives the fraction correct.
(jax.nn.one_hot(jnp.argmax(fwd(params, X_test),axis=1), 10) * y_test).sum() / y_test.shape[0]

In [None]:
# Accuracy is bad. Let's try rescaling inputs to have mean 0 and stddev 1.
# By default a single global mean/stddev over all pixels is used; pass axis=0
# to standardise every pixel (feature column) separately.
def rescale_inputs(X, mu=None, sigma=None, eps=0.0001, axis=None):
    """Standardise `X` as ``(X - mu) / (sigma + eps)``.

    Args:
        X: Input array, one example per row.
        mu: Mean to subtract. Computed from `X` along `axis` when omitted.
        sigma: Standard deviation to divide by. Computed from `X` along
            `axis` when omitted.
        eps: Small constant added to `sigma` to avoid division by zero.
        axis: Axis over which missing statistics are computed. ``None``
            (the default) gives one global scalar mean/stddev; ``0`` gives
            per-feature statistics that broadcast across the rows.

    Returns:
        A tuple ``(X_rescaled, mu, sigma)``. Pass the returned `mu` and
        `sigma` back in to apply the same transform to other data, e.g. the
        test set.
    """
    if mu is None:
        mu = jnp.mean(X, axis=axis)
    if sigma is None:
        sigma = jnp.std(X, axis=axis)
    return (X - mu) / (sigma + eps), mu, sigma

# NOTE: X_train is rebound to its rescaled version here, so every later cell
# that uses X_train sees rescaled inputs (and re-running this cell rescales the
# already-rescaled data again). mu/sigma are kept for transforming the test set.
X_train, mu, sigma = rescale_inputs(X_train)
params2 = train(create_params(layers), X_train, y_train)

In [None]:
# compute accuracy again; the test inputs are rescaled with the mu/sigma
# obtained from the training set so both splits go through the same transform.
(jax.nn.one_hot(jnp.argmax(fwd(params2, rescale_inputs(X_test, mu, sigma)[0]),axis=1), 10) * y_test).sum() / y_test.shape[0]

In [None]:
# didn't work! let's try training with mini-batches instead of the entire training set.
def get_batches(X, y, key, batch_size):
    """Yield shuffled, equally sized mini-batches covering one epoch.

    Args:
        X: Inputs, one example per row.
        y: Targets aligned with `X` along the first axis.
        key: JAX PRNG key that determines the shuffle.
        batch_size: Number of examples per yielded batch.

    Yields:
        ``(X_batch, y_batch)`` pairs. Trailing examples that do not fill a
        complete batch are dropped.
    """
    total = X.shape[0]
    order = jax.random.permutation(key, total)
    shuffled_X, shuffled_y = X[order], y[order]

    # Only whole batches are produced, hence the floor division.
    covered = (total // batch_size) * batch_size
    for start in range(0, covered, batch_size):
        stop = start + batch_size
        yield shuffled_X[start:stop], shuffled_y[start:stop]

def train_batched(params, X, y, key, batch_size=64, num_epochs=10, lr=0.001):
    """Fit the MLP with mini-batch stochastic gradient descent.

    Args:
        params: Initial parameters (pytree of arrays) accepted by `loss`.
        X: Training inputs, one example per row.
        y: Training targets aligned with `X`.
        key: JAX PRNG key; one sub-key per epoch drives the shuffling.
        batch_size: Examples per mini-batch (incomplete trailing batches are
            dropped by `get_batches`).
        num_epochs: Number of passes over the data.
        lr: Step size of each SGD update.

    Returns:
        The parameters after the final epoch.
    """
    keys = jax.random.split(key, num_epochs)

    # Only the single update step is jitted. Jitting the whole function would
    # unroll every epoch and every batch into one enormous trace, and would
    # also forbid passing batch_size/num_epochs as ordinary arguments.
    @jax.jit
    def sgd_step(current, bX, by):
        # The gradient must come from the current mini-batch, not from the
        # full training set.
        grads = grad_loss(current, bX, by)
        return jax.tree.map(lambda param, grad: param - lr * grad, current, grads)

    for i in range(num_epochs):
        for bX, by in get_batches(X, y, keys[i], batch_size):
            params = sgd_step(params, bX, by)

        # Progress report on the full data passed to this function.
        jax.debug.print("loss: {l}", l=loss(params, X, y))

    return params

In [None]:
# Mini-batch SGD run from a fresh initialisation, using train_batched's
# default batch_size=64 and num_epochs=10.
params = train_batched(create_params(layers), X_train, y_train, jax.random.key(69))

In [None]:
def get_epoch_batches(X, y, key, batch_size):
    """Shuffle the data and stack it into equally sized batches.

    Args:
        X: Inputs, one example per row.
        y: Targets aligned with `X` along the first axis.
        key: JAX PRNG key that determines the shuffle.
        batch_size: Number of examples per batch.

    Returns:
        ``(X_batches, y_batches)`` with a new leading batch axis, i.e. shapes
        ``(num_batches, batch_size, ...)``. Examples that do not fill a
        complete batch are discarded.
    """
    usable = (X.shape[0] // batch_size) * batch_size
    # Shuffle all indices first, then keep just enough for whole batches.
    order = jax.random.permutation(key, X.shape[0])[:usable]
    X_shuffled = X[order]
    y_shuffled = y[order]
    X_batches = X_shuffled.reshape((-1, batch_size) + X_shuffled.shape[1:])
    y_batches = y_shuffled.reshape((-1, batch_size) + y_shuffled.shape[1:])
    return X_batches, y_batches
    
def init_adam(params):
    """Create the initial Adam optimiser state for `params`.

    Args:
        params: Pytree of parameter arrays.

    Returns:
        A tuple ``(params, m, v, t)`` where `m`, `v` and `t` (first moment,
        second moment and step counter) are all-zero pytrees shaped like
        `params`; `params` itself is passed through unchanged.
    """
    def zeros_like_params():
        return jax.tree.map(jnp.zeros_like, params)

    first_moment = zeros_like_params()
    second_moment = zeros_like_params()
    step_count = zeros_like_params()
    return params, first_moment, second_moment, step_count

# *gasp* Adam!
def adam_update_single(params, m, v, t, g, lr=1e-3, b1=0.9, b2=0.999, eps=1e-8):
    """Apply one Adam step to a single parameter array.

    Args:
        params: Current parameter values.
        m: Running (biased) first-moment estimate.
        v: Running (biased) second-moment estimate.
        t: Number of steps taken so far.
        g: Gradient for `params`.
        lr: Learning rate.
        b1: Decay rate of the first moment.
        b2: Decay rate of the second moment.
        eps: Constant added to the denominator for numerical stability.

    Returns:
        ``(new_params, new_m, new_v, new_t)``; the moments are returned in
        their raw, uncorrected form so they can be carried to the next step.
    """
    step = t + 1

    # Exponential moving averages of the gradient and of its square.
    first_moment = b1 * m + (1 - b1) * g
    second_moment = b2 * v + (1 - b2) * jnp.square(g)

    # Bias correction compensates for the zero-initialised moments.
    first_unbiased = first_moment / (1 - b1**step)
    second_unbiased = second_moment / (1 - b2**step)

    updated = params - lr * first_unbiased / (jnp.sqrt(second_unbiased) + eps)
    return updated, first_moment, second_moment, step
    
@jax.jit
def adam_update(params, m, v, t, g, lr=1e-3, b1=0.9, b2=0.999, eps=1e-8):
    """Apply one Adam step to every leaf of a parameter pytree.

    Args:
        params: Pytree of parameter arrays.
        m: First-moment pytree with the same structure as `params`.
        v: Second-moment pytree with the same structure as `params`.
        t: Step-counter pytree with the same structure as `params`.
        g: Gradient pytree with the same structure as `params`.
        lr: Learning rate.
        b1: Decay rate of the first moment.
        b2: Decay rate of the second moment.
        eps: Constant added to the denominator for numerical stability.

    Returns:
        A 4-tuple ``(params, m, v, t)`` of updated pytrees.
    """
    def update_leaf(p_leaf, m_leaf, v_leaf, t_leaf, g_leaf):
        return adam_update_single(p_leaf, m_leaf, v_leaf, t_leaf, g_leaf, lr, b1, b2, eps)

    # Mapping leaf-wise yields a params-shaped tree whose leaves are 4-tuples ...
    per_leaf = jax.tree.map(update_leaf, params, m, v, t, g)

    # ... which is transposed into a 4-tuple of params-shaped trees.
    outer = jax.tree.structure(params)
    inner = jax.tree.structure(('*', '*', '*', '*'))
    return jax.tree.transpose(outer, inner, per_leaf)

def train_adam(X, y, key, init_params, batch_size=64, num_epochs=10):
    """Fit the MLP with mini-batch Adam.

    Args:
        X: Training inputs, one example per row.
        y: Training targets aligned with `X`.
        key: JAX PRNG key; one sub-key per epoch drives the shuffling.
        init_params: Initial parameter pytree.
        batch_size: Examples per mini-batch.
        num_epochs: Number of passes over the data.

    Returns:
        The parameters after the final epoch. The loss on the full `(X, y)`
        is printed once per epoch.
    """
    epoch_keys = jax.random.split(key, num_epochs)

    def run_batch(state, batch):
        # One Adam step on a single mini-batch; scan collects no per-step output.
        p, m, v, t = state
        batch_X, batch_y = batch
        grads = grad_loss(p, batch_X, batch_y)
        return adam_update(p, m, v, t, grads), None

    def run_epoch(epoch, state):
        # Reshuffle, then sweep over all batches with the state as the carry.
        batches = get_epoch_batches(X, y, epoch_keys[epoch], batch_size)
        new_state, _ = jax.lax.scan(run_batch, state, batches)
        jax.debug.print("epoch {i}, loss: {l}", i=epoch, l=loss(new_state[0], X, y))
        return new_state

    final_state = jax.lax.fori_loop(0, num_epochs, run_epoch, init_adam(init_params))
    return final_state[0]


In [None]:
# Adam run for 1000 epochs from a fresh initialisation (default batch_size=64).
params = train_adam(X_train, y_train, jax.random.key(66), create_params(layers), num_epochs=1000)

In [None]:
# Test accuracy of the Adam-trained model. `params` were fitted on the rescaled
# X_train, so X_test must go through the same transform (training mu/sigma)
# before it is fed to the network.
(jax.nn.one_hot(jnp.argmax(fwd(params, rescale_inputs(X_test, mu, sigma)[0]), axis=1), 10) * y_test).sum() / y_test.shape[0]