<a href="https://colab.research.google.com/github/present42/PyTorchPractice/blob/main/Following_Flax_Tutorial_Managing_Parameters_and_State.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Managing Parameters and State

* manage the variables from initialization to updates
* split and re-assemble parameters and state
* use `vmap` with batch-dependent state

In [2]:
from flax import linen as nn

In [5]:
import jax.numpy as jnp
import jax

In [3]:
class BiasAdderWithRunningMean(nn.Module):
  """Return a running mean of the inputs plus a learnable bias.

  The running mean is stored as the `mean` variable of the `batch_stats`
  collection and the bias as the `bias` parameter. Both are created with
  shape `x.shape[1:]` and filled with zeros.
  """
  # Weight given to the previous running mean at each update.
  momentum: float = 0.9

  @nn.compact
  def __call__(self, x):
    feature_shape = x.shape[1:]

    # Checked before `self.variable` is called, so it reflects whether the
    # running mean already existed when this call started.
    already_tracked = self.has_variable('batch_stats', 'mean')
    running_mean = self.variable(
        'batch_stats', 'mean', jnp.zeros, feature_shape)
    bias = self.param(
        'bias', lambda rng, shape: jnp.zeros(shape), feature_shape)

    if already_tracked:
      # NOTE(review): `keepdims=True` keeps a leading axis of length 1 on the
      # batch mean, while the variable was created with shape `x.shape[1:]`,
      # so the stored value gains that axis after the first update -- confirm
      # this is intended.
      batch_mean = jnp.mean(x, axis=0, keepdims=True)
      running_mean.value = (
          self.momentum * running_mean.value
          + (1.0 - self.momentum) * batch_mean
      )

    return running_mean.value + bias

The tricky part of initialization is that we need to split the state variables from the parameters we are going to optimize.

First we define `update_step` as follows (with a dummy loss that should be replaced with yours):

In [9]:
import jax.random as random
import flax
import optax

In [10]:
# Derive separate PRNG keys for the dummy input (`i_key`) and for the model
# initialization (`m_key`).
key = random.key(0)
key, i_key, m_key = random.split(key, 3)
model = BiasAdderWithRunningMean()

dummy_input = random.normal(i_key, (10,)) # Dummy input data
variables = model.init(m_key, dummy_input)
print(variables)

# flax.core.pop returns a pair: a new collection with the 'params' entry
# removed, and the removed 'params' value. `variables` itself is left as is.
state, params = flax.core.pop(variables, 'params')
print(variables) # prints the same content as before the pop
del variables # delete variables to avoid wasting resources
# The optimizer state is built from the trainable parameters only.
tx = optax.sgd(learning_rate=0.02)
opt_state = tx.init(params)

{'batch_stats': {'mean': Array(0., dtype=float32)}, 'params': {'bias': Array(0., dtype=float32)}}
{'batch_stats': {'mean': Array(0., dtype=float32)}, 'params': {'bias': Array(0., dtype=float32)}}


In [11]:
def update_step(apply_fn, x, opt_state, params, state):
  """Run one optimization step on a dummy squared-error loss.

  Args:
    apply_fn: Callable invoked as `apply_fn(variables, x, mutable=...)` that
      returns `(output, updated_state)`; `model.apply` in this notebook.
    x: Input passed to `apply_fn`; it is also the target of the dummy loss.
    opt_state: Optimizer state for the module-level optimizer `tx`.
    params: Parameters that are differentiated and updated.
    state: Mapping of the non-parameter collections; every collection in it
      is marked mutable for the call to `apply_fn`.

  Returns:
    A tuple `(opt_state, params, updated_state)` with the new optimizer
    state, the updated parameters and the state returned by `apply_fn`.
  """
  def loss(params):
    y, updated_state = apply_fn({'params': params, **state},
                                x, mutable=list(state.keys()))
    loss_value = ((x - y) ** 2).sum() # replace with your loss here
    # The state travels as auxiliary output so it is not differentiated.
    return loss_value, updated_state

  # The loss value itself is not part of this function's return value.
  (_, updated_state), grads = jax.value_and_grad(
      loss, has_aux=True)(params)
  # Pass `params` as well: optimizers whose update rule depends on the current
  # parameters (e.g. weight decay) need them, and the others ignore them.
  updates, opt_state = tx.update(grads, opt_state, params) # update gradient
  params = optax.apply_updates(params, updates) # recompute params
  return opt_state, params, updated_state

In [12]:
# Number of times `update_step` is applied (also reused by the later loop).
num_epochs = 10

# Every iteration feeds the same dummy input and threads the optimizer state,
# the parameters and the model state from one call to the next.
for _ in range(num_epochs):
  opt_state, params, state = update_step(
      model.apply, dummy_input, opt_state, params, state
  )

## `vmap` across the batch dimension

In [13]:
from functools import partial

class MLP(nn.Module):
  """Two Dense -> BatchNorm -> ReLU stages followed by a Dense output layer.

  The BatchNorm layers are created with `axis_name="batch"`, and they use
  their running averages whenever `train` is False.
  """
  # Number of features of each of the two hidden Dense layers.
  hidden_size: int
  # Number of features of the final Dense layer.
  out_size: int

  @nn.compact
  def __call__(self, x, train=False):
    # Shared BatchNorm configuration; every call builds a fresh layer.
    make_norm = partial(
        nn.BatchNorm,
        use_running_average=not train,
        momentum=0.9,
        epsilon=1e-5,
        axis_name="batch", # Name batch dim
    )

    hidden = x
    # The two hidden stages are identical, so build them in a loop.
    for _ in range(2):
      hidden = nn.Dense(self.hidden_size)(hidden)
      hidden = make_norm()(hidden)
      hidden = nn.relu(hidden)

    return nn.Dense(self.out_size)(hidden)

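First, the `MLP` above names its batch axis `"batch"` through the `axis_name` argument of `nn.BatchNorm`. Secondly, we need to specify the same name when calling `vmap` in our training code: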

In [14]:
def update_step(apply_fn, x_batch, y_batch, opt_state, params, state):
  """Run one optimization step on a batch, vmapping the per-example loss.

  Args:
    apply_fn: Callable invoked as `apply_fn(variables, x, mutable=...)` that
      returns `(prediction, updated_state)`; `model.apply` in this notebook.
    x_batch: Inputs, mapped over their leading axis by `jax.vmap`.
    y_batch: Targets, mapped over their leading axis by `jax.vmap`.
    opt_state: Optimizer state for the module-level optimizer `tx`.
    params: Parameters that are differentiated and updated.
    state: Mapping of the non-parameter collections; every collection in it
      is marked mutable for the call to `apply_fn`.

  Returns:
    A tuple `(opt_state, params, updated_state, loss)` where `loss` is the
    mean of the per-example squared errors.
  """
  def batch_loss(params):
    def loss_fn(x, y):
      # NOTE(review): `apply_fn` is called without a `train` argument, so a
      # model such as `MLP` (where `train` defaults to False) runs with
      # `use_running_average=True` -- confirm whether `train=True` was meant.
      pred, updated_state = apply_fn(
          {'params': params, **state},
          x, mutable=list(state.keys())
      )
      return (pred - y) ** 2, updated_state

    per_example_loss, updated_state = jax.vmap(
        loss_fn, out_axes=(0, None), # Do not vmap `updated_state`
        axis_name='batch' # Name batch dim
    )(x_batch, y_batch)
    # The state travels as auxiliary output so it is not differentiated.
    return jnp.mean(per_example_loss), updated_state

  (loss, updated_state), grads = jax.value_and_grad(
      batch_loss, has_aux=True
  )(params)

  # Pass `params` as well: optimizers whose update rule depends on the current
  # parameters (e.g. weight decay) need them, and the others ignore them.
  updates, opt_state = tx.update(grads, opt_state, params)
  params = optax.apply_updates(params, updates)
  return opt_state, params, updated_state, loss



### Note that we also need to specify that the model state does not have a batch dimension, which is what `out_axes=(0, None)` does in `update_step` above

In [None]:
model = MLP(hidden_size=10, out_size=1)
variables = model.init(m_key, dummy_input)
# split state and params
state, params = flax.core.pop(variables, 'params')
del variables # delete variables to avoid wasting resources
tx = optax.sgd(learning_rate=0.02)
opt_state = tx.init(params)

# Dummy batch for the training loop: `X` and `Y` were not defined anywhere
# before this point. Each row of `X` has the same number of features as
# `dummy_input`, and each row of `Y` matches the model's `out_size=1`.
batch_size = 16
key, x_key, y_key = random.split(key, 3)
X = random.normal(x_key, (batch_size, dummy_input.shape[-1]))
Y = random.normal(y_key, (batch_size, 1))

# Thread the optimizer state, parameters and model state through every step.
for _ in range(num_epochs):
  opt_state, params, state, loss = update_step(
      model.apply, X, Y, opt_state, params, state
  )