In [None]:
import jax
import jax.numpy as jnp
import equinox as eqx
import optax

In [None]:
class Linear(eqx.Module):
    """Affine layer computing ``weight @ x + bias`` as an Equinox module.

    Both parameters are sampled from a standard normal distribution when the
    module is constructed, each from its own subkey split off ``key``.
    """

    # Shape (size_out, size_in).
    weight: jax.Array
    # Shape (size_out,).
    bias: jax.Array

    def __init__(self, size_in, size_out, key: jax.Array):
        """Sample the layer's parameters.

        Args:
            size_in: number of input features (columns of ``weight``).
            size_out: number of output features (rows of ``weight``, length of ``bias``).
            key: JAX PRNG key; split into one subkey per parameter.
        """
        weight_key, bias_key = jax.random.split(key)
        self.weight = jax.random.normal(weight_key, shape=(size_out, size_in))
        self.bias = jax.random.normal(bias_key, shape=(size_out,))

    def __call__(self, x):
        """Apply the affine map to ``x``.

        In this notebook it is called per example under ``jax.vmap``.
        """
        projected = self.weight @ x
        return projected + self.bias

In [None]:
# Synthetic regression data: 20 evenly spaced inputs on [0, 1], stored as a
# (20, 1) column, with noiseless targets on the line y = 0.5 * x + 0.1.
x = jnp.linspace(0, 1, num=20)[:, None]
y = x * 0.5 + 0.1

In [None]:
@jax.jit
@jax.grad
def loss_fn(model, x, y):
    """Gradient of the mean-squared error with respect to ``model``.

    NOTE: despite the name, the decorated function does not return the loss.
    ``jax.grad`` differentiates the scalar MSE computed below with respect to
    the first argument. The result is a pytree with the same structure as
    ``model``, and ``jax.jit`` compiles the whole computation.

    Args:
        model: callable pytree (here an ``eqx.Module``) applied to one row of ``x``.
        x: inputs; ``model`` is mapped over the leading axis via ``jax.vmap``.
        y: targets, compared elementwise with the stacked predictions.
    """
    batched_model = jax.vmap(model)
    residual = batched_model(x) - y
    return jnp.mean(residual ** 2)

In [None]:
# Fit `linear` to (x, y) with plain SGD. The hyperparameters are named
# constants so they are easy to find and tune.
SEED = 42
LEARNING_RATE = 0.05
NUM_STEPS = 1000

key = jax.random.PRNGKey(SEED)
linear = Linear(1, 1, key)
optim = optax.sgd(LEARNING_RATE)
# Optimizer state is initialised from the model restricted to its array leaves.
state = optim.init(eqx.filter(linear, eqx.is_array))

# Use `_` as the loop variable so the loop does not rebind the builtin `iter`
# in the notebook's global namespace, which would break later cells that call it.
for _ in range(NUM_STEPS):
    # `loss_fn` is wrapped in jax.grad, so this is a gradient pytree, not a loss value.
    grad_loss = loss_fn(linear, x, y)
    updates, state = optim.update(grad_loss, state, linear)
    linear = eqx.apply_updates(linear, updates)

In [None]:
# Learned parameters. The targets were generated as y = 0.5 * x + 0.1, so if
# training converged these should be close to weight 0.5 and bias 0.1.
(linear.weight, linear.bias)