In [1]:
import os; os.environ['JAX_PLATFORM_NAME'] = 'cpu'
from pathlib import Path

import equinox as eqx
import jax
import jax.numpy as jnp
from jaxtyping import Array, Float, Int, PyTree 
import h5py
import optax


In [2]:
FILENAME = Path("/home/raymo/Downloads/batch_1.hdf5")
INPUT_SHAPE = (200,)
SEED = 42
LEARNING_RATE = 3e-4

In [3]:
class MLP(eqx.Module):
    layers: list

    def __init__(self, key):
        key1, key2, key3, key4, key5, key6 = jax.random.split(key, 6)

        self.layers = [
            eqx.nn.Linear(200, 128, key=key4),
            jax.nn.relu,
            eqx.nn.Linear(128, 64, key=key5),
            jax.nn.relu,
            eqx.nn.Linear(64, 2, key=key6),
        ]

    def __call__(self, x: Float[Array, "200"]) -> Float[Array, "2"]:
        for layer in self.layers:
            x = layer(x)
        return x
    
key = jax.random.PRNGKey(SEED)
key, subkey = jax.random.split(key, 2)
model = MLP(subkey)

In [4]:
def loss(model: MLP, x: Float[Array, "batch 200"], y: Float[Array, "batch 2"]) -> Float[Array, ""]:
    preds = jax.vmap(model)(x)  # (batch, 2)
    return jnp.mean(jnp.sum((preds - y) ** 2, axis=-1))  # MSE over both outputs

In [5]:
test_x = jax.random.normal(key, (28, 200))     # (batch=28, 200)
test_y = jax.random.normal(key, (28, 2))       # (batch=28, 2)
pred_y = jax.vmap(model)(test_x)               # (28, 2)
print("Predicited y: ", pred_y[0])                               # show first prediction
value, grads = eqx.filter_value_and_grad(loss)(model, test_x, test_y)
print("Loss value: ", value)

Predicited y:  [-0.04515041  0.04674196]
Loss value:  1.7107732


In [None]:
loss = eqx.filter_jit(loss)  # JIT our loss function from earlier!

@eqx.filter_jit
def compute_accuracy(
    model: MLP, x: Float[Array, "batch 200"], y: Int[Array, "batch 2"]
) -> Float[Array, ""]:
    pred_y = jax.vmap(model)(x)
    accuracy = jnp.mean((pred_y == y).astype(jnp.float32))
    return accuracy

def evaluate(model: MLP, test_dataset):
    """This function evaluates the model on the test dataset,
    computing both the average loss and the average accuracy.
    """
    avg_loss = 0
    avg_acc = 0
    for x, y in test_dataset:
        x = x.numpy()
        y = y.numpy()
        # Note that all the JAX operations happen inside `loss` and `compute_accuracy`,
        # and both have JIT wrappers, so this is fast.
        avg_loss += loss(model, x, y)
        avg_acc += compute_accuracy(model, x, y)
    return avg_loss / len(test_dataset), avg_acc / len(test_dataset)