In [1]:
import jax
import jax.numpy as jnp
import equinox as eqx
import optax

In [2]:
class Linear(eqx.Module):
    """Affine layer computing ``weight @ x + bias``.

    Written for one input vector at a time; the notebook batches it with
    ``jax.vmap``. Both parameters are sampled from a standard normal
    distribution when the layer is constructed.
    """

    weight: jax.Array  # shape (size_out, size_in)
    bias: jax.Array  # shape (size_out,)

    def __init__(self, size_in, size_out, key: jax.Array):
        """Sample the layer parameters.

        Args:
            size_in: length of the input vector.
            size_out: length of the output vector.
            key: PRNG key, split into one sub-key per parameter.
        """
        weight_key, bias_key = jax.random.split(key)
        weight_shape = (size_out, size_in)
        bias_shape = (size_out,)
        self.weight = jax.random.normal(weight_key, weight_shape)
        self.bias = jax.random.normal(bias_key, bias_shape)

    def __call__(self, x):
        """Apply the affine map to ``x``."""
        projected = self.weight @ x
        return projected + self.bias

In [3]:
# Synthetic, noiseless 1-D regression data: evenly spaced points on [0, 1]
# with a linear target. The generating coefficients are named so the fitted
# model can later be compared against them.
TRUE_WEIGHT = 0.5
TRUE_BIAS = 0.1
N_POINTS = 20

x = jnp.linspace(0, 1, N_POINTS).reshape(-1, 1)  # shape (N_POINTS, 1)
y = TRUE_WEIGHT * x + TRUE_BIAS  # shape (N_POINTS, 1)

In [4]:
def compute_loss(model, x, y):
    """Mean-squared error of ``model`` over a batch.

    ``model`` is applied to each row (leading-axis slice) of ``x`` via
    ``jax.vmap`` and the predictions are compared element-wise with ``y``.

    Returns:
        Scalar mean of the squared residuals.
    """
    pred = jax.vmap(model)(x)
    return jnp.mean((pred - y) ** 2)


# NOTE: despite its name, ``loss_fn`` returns the *gradient* of the MSE with
# respect to ``model`` (a pytree with the same structure as ``model``), not
# the loss value. The name is kept because the training loop calls it; use
# ``compute_loss`` when the scalar loss itself is needed (e.g. for logging).
loss_fn = jax.jit(jax.grad(compute_loss))

In [5]:
# Training hyperparameters, named so they are easy to find and tune.
SEED = 42
LEARNING_RATE = 0.05
NUM_STEPS = 1000

key = jax.random.PRNGKey(SEED)
linear = Linear(1, 1, key)
optim = optax.sgd(LEARNING_RATE)
state = optim.init(linear)


@eqx.filter_jit
def train_step(model, opt_state, x, y):
    """Run one optimisation step and return ``(new_model, new_opt_state)``.

    Gradient computation, the optax update and ``eqx.apply_updates`` are
    compiled together. Without this, only the gradient was jitted and the
    optimizer update was dispatched eagerly on every iteration. ``x`` and ``y``
    are passed as arguments rather than closed over, so they are traced
    instead of being baked into the compiled function as constants.
    """
    grads = loss_fn(model, x, y)
    updates, opt_state = optim.update(grads, opt_state, model)
    return eqx.apply_updates(model, updates), opt_state


# `_` instead of `iter`, which shadowed the builtin for the rest of the kernel.
for _ in range(NUM_STEPS):
    linear, state = train_step(linear, state, x, y)

In [6]:
# The data were generated as y = 0.5 * x + 0.1 with no noise, so a converged
# fit should recover those coefficients; fail loudly if training went wrong.
assert jnp.allclose(linear.weight, 0.5, atol=1e-2), linear.weight
assert jnp.allclose(linear.bias, 0.1, atol=1e-2), linear.bias
linear.weight, linear.bias

(Array([[0.49944192]], dtype=float32), Array([0.10030087], dtype=float32))