In [None]:
import optax
import jax
import jax.numpy as np

In [None]:

# Adam optimizer with a fixed learning rate of 1e-3.
optimizer = optax.adam(learning_rate=1e-3)

# Parameters of a small linear model: a length-10 weight vector and one bias.
params = dict(
    w=np.ones((10,)),
    b=np.ones((1,)),
)

# Toy batch: three all-ones rows of width 10, with targets 0, 1, 2.
xs = np.ones((3, 10))
ys = np.arange(3)


def compute_loss(params, x, y):
    """Sum of optax.l2_loss between the linear prediction and the targets.

    Args:
        params: mapping with entries 'w' and 'b'.
        x: inputs, multiplied against params['w'] with the @ operator.
        y: targets handed to optax.l2_loss.

    Returns:
        The per-example losses reduced with np.sum.
    """
    predictions = x @ params['w'] + params['b']
    return np.sum(optax.l2_loss(predictions, y))


# jax.grad differentiates with respect to the first argument, i.e. params.
grad_fn = jax.grad(compute_loss)

# Build the optimizer's initial state for this parameter structure.
opt_state = optimizer.init(params)

# Gradient of the loss at the current parameters on the toy batch.
grads = grad_fn(params, xs, ys)

# Turn the raw gradients into updates and obtain the new optimizer state.
updates, opt_state = optimizer.update(grads, opt_state)
print(updates)

# Fold the updates into the parameters.
params = optax.apply_updates(params, updates)


# more flexible gradient transforms

# Toy batch for this section: three all-ones rows of width 10, targets 0, 1, 2.
xs = np.ones((3, 10))
ys = np.arange(3)

# Compose a gradient transformation out of individual optax stages.
gradient_transform = optax.chain(
    # NOTE(review): 1e-8 is a very tight clipping threshold -- confirm it is
    # intentional and not a placeholder.
    optax.clip_by_global_norm(max_norm=1e-8),
    optax.scale_by_adam(),
    # Multiply by -1e-3; same magnitude as the learning rate given to
    # optax.adam earlier in this notebook.
    optax.scale(step_size=-1e-3),
)

# Fresh state for the chained transform (rebinds opt_state).
opt_state = gradient_transform.init(params)

# Gradients at the current value of params.
grads = grad_fn(params, xs, ys)

# Run the gradients through the chain; updates are displayed, not applied.
updates, opt_state = gradient_transform.update(grads, opt_state)
updates

