In [2]:
import jax
from jax import numpy as jnp, random, lax, jit
from flax import linen as nn

# Fixed toy data: one all-ones input row of width 10 and an all-ones
# target vector of length 5.
X = jnp.ones(shape=(1, 10))
Y = jnp.ones(shape=(5,))

# Single Dense layer with 5 output features.
model = nn.Dense(5)

@jit
def predict(params):
  """Apply `model` to the module-level input `X` using `params`."""
  variables = {'params': params}
  return model.apply(variables, X)

@jit
def loss_fn(params):
  """Return the mean absolute difference between `Y` and `predict(params)`."""
  residual = Y - predict(params)
  return jnp.abs(residual).mean()

@jit
def init_params(rng):
  """Initialise `model` with `rng` and return its 'params' collection.

  `rng` is passed as the 'params' RNG stream and the module-level `X` is
  used as the example input for initialisation.
  """
  return model.init({'params': rng}, X)['params']

# Get initial parameters
params = init_params(jax.random.PRNGKey(42))
print("initial params", params)

# Run SGD.
# The step size and the differentiated loss are the same on every iteration,
# so they are built once here rather than inside the loop.
lr = 0.03
loss_and_grad = jax.value_and_grad(loss_fn)
for i in range(50):
  # Loss and gradient are evaluated at the current (pre-update) parameters.
  loss, grad = loss_and_grad(params)
  # The prediction printed here also uses the pre-update parameters.
  print(i, "loss = ", loss, "Yhat = ", predict(params))
  # Gradient-descent step applied to every leaf of the parameter tree.
  params = jax.tree_util.tree_map(lambda x, d: x - lr * d, params, grad)



initial params FrozenDict({
    bias: Array([0., 0., 0., 0., 0.], dtype=float32),
    kernel: Array([[ 0.55015737,  0.41833436,  0.33977556,  0.43434998, -0.12176939],
           [ 0.07848787,  0.27258518, -0.22658114,  0.60018104,  0.27871728],
           [-0.3267914 , -0.45051524,  0.02986265, -0.5590184 ,  0.30982274],
           [ 0.05870081,  0.20131111, -0.15255067,  0.26707688, -0.6626963 ],
           [ 0.62493724, -0.20424645,  0.04849606,  0.25078458, -0.457508  ],
           [ 0.11445389,  0.08716885, -0.08621331,  0.42504248,  0.48199227],
           [ 0.01878028,  0.07163057, -0.21260886,  0.54276705, -0.1087756 ],
           [-0.08556435,  0.16303335, -0.54164785,  0.03657064, -0.0329951 ],
           [-0.10136051, -0.04724246,  0.20461115,  0.13906626, -0.04074767],
           [ 0.22522342,  0.32235065,  0.13873868,  0.13579997,  0.6603415 ]],      dtype=float32),
})
0 loss =  0.74939424 Yhat =  [[ 1.1570246   0.83440995 -0.45811766  2.2726204   0.30638173]]
1 loss =  0.