# MNIST Example with Haliax
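This notebook trains a small multilayer perceptron (MLP) with Haliax on MNIST-shaped inputs (flattened 28×28 images, 10 classes). To stay self-contained, it uses randomly generated images and random labels instead of downloading the real MNIST dataset. The goal is to demonstrate the named-axis model, loss, and training loop, not to reach meaningful accuracy.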

In [1]:
import jax, jax.numpy as jnp
import haliax as hax
import haliax.nn as hnn
import equinox as eqx
from jax import random
import optax


In [2]:
# Number of random (synthetic) batches the training loop draws per epoch.
num_batches: int = 10


In [3]:
# Named axes shared by the model, the loss, and the training loop.
Batch = hax.Axis(name='batch', size=128)      # examples per training step
Image = hax.Axis(name='image', size=28 * 28)  # one flattened 28x28 image
Hidden = hax.Axis(name='hidden', size=256)    # passed to the MLP as its hidden width
Classes = hax.Axis(name='classes', size=10)   # output logits, one per class


class Net(eqx.Module):
    """Equinox module wrapping a Haliax MLP that maps the `Image` axis to `Classes` logits."""

    mlp: hnn.MLP

    @staticmethod
    def init(key):
        """Build a `Net` whose MLP uses width `Hidden` and depth 2, initialized from `key`."""
        return Net(mlp=hnn.MLP.init(Image, Classes, width=Hidden, depth=2, key=key))

    def __call__(self, x):
        """Run the wrapped MLP on `x`, a NamedArray carrying the `Image` axis."""
        return self.mlp(x)



In [4]:
def loss_fn(model, images, labels):
    """Mean cross-entropy of `model` on one batch of images.

    Args:
        model: callable taking a NamedArray with axes (batch, Image) and
            returning logits over `Classes` (e.g. the `Net` defined above).
        images: JAX array whose elements can be reshaped to (-1, Image.size);
            each resulting row is one flattened image.
        labels: NamedArray of one-hot targets over `Classes`, with a "batch"
            axis whose length matches the number of images.

    Returns:
        The mean loss as a 0-d JAX array, so the function can be differentiated
        directly with `jax.grad` / `eqx.filter_grad`.
    """
    flat = images.reshape(-1, Image.size)
    # Derive the batch axis from the data rather than forcing the global
    # `Batch` (size 128) onto it. Same name and identical Axis when the batch
    # has 128 rows, but a smaller (e.g. final partial) batch no longer fails.
    batch_axis = hax.Axis(Batch.name, flat.shape[0])
    imgs = hax.NamedArray(flat, (batch_axis, Image))
    logits = model(imgs)
    loss = hnn.cross_entropy_loss(logits, Classes, labels)
    return loss.mean().scalar()


In [5]:
key = random.PRNGKey(0)
# JAX keys must not be reused: previously `key` both seeded the model and was
# later split for data generation. Split off a dedicated init key first.
key, init_key = random.split(key)
model = Net.init(init_key)
opt = optax.adam(1e-3)
# Keep optimizer state over array leaves only (canonical Equinox pattern).
opt_state = opt.init(eqx.filter(model, eqx.is_array))
num_epochs = 1


@eqx.filter_jit
def train_step(model, opt_state, images, labels):
    """Apply one Adam update for a single batch; return (new_model, new_opt_state).

    `eqx.filter_jit` compiles the forward pass, backward pass, and optimizer
    update into one XLA program instead of dispatching each op eagerly.
    """
    grads = eqx.filter_grad(loss_fn)(model, images, labels)
    updates, new_opt_state = opt.update(grads, opt_state, eqx.filter(model, eqx.is_array))
    return eqx.apply_updates(model, updates), new_opt_state


for _epoch in range(num_epochs):
    for _ in range(num_batches):
        key, images_key, labels_key = random.split(key, 3)
        images = random.normal(images_key, (Batch.size, Image.size))
        label_ids = random.randint(labels_key, (Batch.size,), 0, Classes.size)
        labels = hnn.one_hot(hax.NamedArray(label_ids, (Batch,)), Classes)
        model, opt_state = train_step(model, opt_state, images, labels)
    print("epoch done")


epoch done
