# Multilayer Perceptron (MNIST)

for a multilayer perceptron (mlp), the objective is to minimize a loss function over the dataset. assuming supervised learning with input–label pairs $(x_i, y_i)$, the generic minimization problem is:

$$
\min_\theta \frac{1}{N} \sum_{i=1}^N \mathcal{L}(f_\theta(x_i), y_i)
$$

where:

* $f_\theta(x)$ is the output of the MLP with parameters $\theta$ (weights + biases across all layers),
* $\mathcal{L}(\cdot, \cdot)$ is the loss function, e.g. mean squared error (MSE), cross-entropy, etc.,
* $N$ is the number of training samples.

---

**for regression (MSE loss):**

$$
\mathcal{L}(f_\theta(x), y) = \|f_\theta(x) - y\|^2
$$

**for binary classification (sigmoid output + binary cross entropy):**
$$
\mathcal{L}(f_\theta(x), y) = -y \log f_\theta(x) - (1 - y) \log(1 - f_\theta(x))
$$

**for multiclass classification (softmax + cross entropy):**

$$
\mathcal{L}(f_\theta(x), y) = -\sum_{k=1}^K y_k \log \left( \text{softmax}(f_\theta(x))_k \right)
$$

---

the MLP itself is defined recursively:

$$
\begin{aligned}
h^{(0)} &= x \\
h^{(l)} &= \sigma(W^{(l)} h^{(l-1)} + b^{(l)}), \quad l = 1, \dots, L-1 \\
f_\theta(x) &= W^{(L)} h^{(L-1)} + b^{(L)} \quad \text{(or apply output activation)}
\end{aligned}
$$

where:

* $\sigma$ is an activation function (ReLU, tanh, etc.),
* $\theta = \{W^{(l)}, b^{(l)}\}_{l=1}^L$.

---

you can add regularization if desired:

$$
\min_\theta \frac{1}{N} \sum_{i=1}^N \mathcal{L}(f_\theta(x_i), y_i) + \lambda \|\theta\|^2
$$

In [1]:
# Imports
# import os
# os.environ["JAX_PLATFORM_NAME"] = "METAL"          # before importing jax
# os.environ["JAX_PLATFORMS"] = "metal,cpu"        # allow cpu fallback for missing op
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import pandas as pd

# For image processing
from io import BytesIO
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np

In [2]:
jax.devices()

[CpuDevice(id=0)]

In [3]:
df = pd.read_parquet('data/mnist_train.parquet')
df

Unnamed: 0,image,label
0,{'bytes': b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHD...,5
1,{'bytes': b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHD...,0
2,{'bytes': b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHD...,4
3,"{'bytes': b""\x89PNG\r\n\x1a\n\x00\x00\x00\rIHD...",1
4,{'bytes': b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHD...,9
...,...,...
59995,{'bytes': b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHD...,8
59996,{'bytes': b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHD...,3
59997,{'bytes': b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHD...,5
59998,{'bytes': b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHD...,6


In [4]:
def show_image(byte_str):    
    img = Image.open(BytesIO(byte_str))
    plt.imshow(img)
    plt.axis('off')
    plt.show()

# show_image(df['image'][2]['bytes']), df['label'][2]

df

Unnamed: 0,image,label
0,{'bytes': b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHD...,5
1,{'bytes': b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHD...,0
2,{'bytes': b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHD...,4
3,"{'bytes': b""\x89PNG\r\n\x1a\n\x00\x00\x00\rIHD...",1
4,{'bytes': b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHD...,9
...,...,...
59995,{'bytes': b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHD...,8
59996,{'bytes': b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHD...,3
59997,{'bytes': b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHD...,5
59998,{'bytes': b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHD...,6


In [5]:
def bytes_dict_to_jax_array(d):
    img = Image.open(BytesIO(d['bytes']))
    return jnp.array(img)

def preprocess_df(df):
    df = df.copy()
    df['image'] = df['image'].map(bytes_dict_to_jax_array)
    X, y = jnp.stack(df['image'].tolist()), jax.nn.one_hot(df['label'], 10)
    X = X.reshape(X.shape[0], -1) / 255
    return X, y

X_train, y_train = preprocess_df(df)
X_train.shape, y_train.shape

((60000, 784), (60000, 10))

In [6]:
def create_params(layers):
    key = jax.random.key(42)
    def create_matrix_and_bias(n, m, key):
        k1, k2 = jax.random.split(key)
        return jax.random.normal(k1, (m, n)), jax.random.normal(k2, (m, 1))
    
    params = list(map(create_matrix_and_bias, layers, layers[1:], jax.random.split(jax.random.key(42), len(layers)-1)))
    return params

layers = [X_train.shape[1], 128, 64, 10]
jax.tree.map(lambda p: p.shape, create_params(layers))

[((128, 784), (128, 1)), ((64, 128), (64, 1)), ((10, 64), (10, 1))]

In [7]:
@jax.jit
def fwd(params, X):
    h = X.T
    for W, b in params:
        l = W @ h + b
        h = jax.nn.relu(l)
    return l.T # do not relu last layer

# fwd(params, X_train).shape, y_train.shape
@jax.jit
def loss(params, X, y):
    logits = fwd(params, X)
    return -jnp.mean(jnp.sum(y*jax.nn.log_softmax(logits), axis=1))

In [8]:
grad_loss = jax.jit(jax.grad(loss))

@jax.jit
def train(params, X_train, y_train):
    lr = 0.003

    def body(i, params):
        params = jax.tree.map(lambda param, grad: param - lr * grad, params, grad_loss(params, X_train, y_train))
        # params = [param - lr * grad for param, grad in zip(params, grad_loss(params, X_train, y_train))]
        # params -= lr * grad_loss(params, X_train, y_train)

        def do_print(_):
            jax.debug.print("step {i}, loss: {l}", i=i, l=loss(params, X_train, y_train))
            return None

        _ = jax.lax.cond(i % 100 == 0, do_print, lambda _: None, operand=None)
        return params

    params = jax.lax.fori_loop(0, 30000, body, params)
    return params

# params = train(create_params(layers))

In [11]:
import modal

app = modal.App("learning-jax")

@app.function()
def modal_train(X_train, y_train):
    params = train(create_params(layers), X_train, y_train)
    return params


with modal.enable_output():
    with app.run():
        # my_function.remote(42)
        params = modal_train.remote(X_train, y_train)


Output()

Output()

InvalidError: Function <function modal_train at 0x12e879c60> has size 190576541 bytes when packaged. This is larger than the maximum limit of 16 MiB. Try reducing the size of the closure by using parameters or mounts, not large global variables.

In [None]:
params = train(create_params(layers))

In [9]:
loss(params, X_train, y_train)

Array(3.8788896, dtype=float32)