# Parameter Management

Once we have chosen an architecture
and set our hyperparameters,
we proceed to the training loop,
where our goal is to find parameter values
that minimize our loss function.
After training, we will need these parameters
in order to make future predictions.
Additionally, we will sometimes wish
to extract the parameters,
perhaps to reuse them in some other context,
to save our model to disk so that
it may be executed in other software,
or to examine them in the hope of
gaining scientific understanding.

Most of the time, we will be able
to ignore the nitty-gritty details
of how parameters are declared
and manipulated, relying on deep learning frameworks
to do the heavy lifting.
However, when we move away from
stacked architectures with standard layers,
we will sometimes need to get into the weeds
of declaring and manipulating parameters.
In this section, we cover the following:

* Accessing parameters for debugging, diagnostics, and visualizations.
* Sharing parameters across different model components.


In [1]:
import jax
from jax import numpy as jnp
from flax import nnx

(**We start by focusing on an MLP with one hidden layer.**)


In [2]:
class Net(nnx.Module):
	def __init__(self, n_inputs: int, n_hiddens: int, n_outputs: int, *, rngs: nnx.Rngs):
		self.net1 = nnx.Linear(n_inputs, n_hiddens, rngs=rngs)
		self.net2 = nnx.Linear(n_hiddens, n_outputs, rngs=rngs)
	
	def __call__(self, x: jax.Array):
		x = self.net1(x)
		return self.net2(nnx.relu(x))

In [4]:
net = Net(4, 8, 1, rngs=nnx.Rngs(params=0))

In [6]:
x = jax.random.uniform(jax.random.key(1300), (2, 4))
net(x)

Array([[-0.03695716],
       [-0.0799413 ]], dtype=float32)

## [**Parameter Access**]
:label:`subsec_param-access`

Let's start with how to access parameters
from the models that you already know.


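We can inspect the parameters of the second fully connected layer as follows.
We first use `nnx.split` to separate the model's static structure (`graphdef`)
from its `Param` state (`params`), and then index that state by layer name.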


In [18]:
graphdef, params = nnx.split(net, nnx.Param)

In [21]:
params['net2']

State({
  'bias': VariableState(
    type=Param,
    value=Array([0.], dtype=float32)
  ),
  'kernel': VariableState(
    type=Param,
    value=Array([[ 0.00311489],
           [-0.13639002],
           [ 0.75477993],
           [ 0.49593627],
           [-0.39758068],
           [-0.5204415 ],
           [-0.04960723],
           [-0.48217064]], dtype=float32)
  )
})

We can see that this fully connected layer
contains two parameters:
a `kernel`, which holds the layer's weights,
and a `bias`.


### [**Targeted Parameters**]

Note that each parameter in the split state is represented
as a `VariableState` instance whose type is `Param`.
To do anything useful with the parameters,
we first need to access the underlying numerical values.
There are several ways to do this.
Some are simpler while others are more general.
The following code extracts the bias
from the second neural network layer, which returns a `VariableState` instance.
The underlying array is stored in that instance's `value` attribute.


In [22]:
params['net2']['bias']

VariableState(
  type=Param,
  value=Array([0.], dtype=float32)
)

In [23]:
bias = params['net2']['bias']
type(bias), bias

(flax.nnx.nnx.variables.VariableState,
 VariableState(
   type=Param,
   value=Array([0.], dtype=float32)
 ))

### [**All Parameters at Once**]

When we need to perform operations on all parameters,
accessing them one-by-one can grow tedious.
The situation can grow especially unwieldy
when we work with more complex, e.g., nested, modules,
since we would need to recurse
through the entire tree to extract
each sub-module's parameters.
Because the parameter state is a pytree,
`jax.tree_util.tree_map` can apply a function to every parameter for us.
Below we use it to list the shape of each parameter in all layers.


In [24]:
jax.tree_util.tree_map(lambda x: x.shape, params)

State({
  'net1': {
    'bias': VariableState(
      type=Param,
      value=(8,)
    ),
    'kernel': VariableState(
      type=Param,
      value=(4, 8)
    )
  },
  'net2': {
    'bias': VariableState(
      type=Param,
      value=(1,)
    ),
    'kernel': VariableState(
      type=Param,
      value=(8, 1)
    )
  }
})
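The same pytree utilities work on any nested container of arrays. As a minimal sketch, the code below uses a plain nested dict, `toy_params`, as a stand-in for the parameter state (an assumption made to keep the example independent of Flax), with the same shapes as the output above, and counts the total number of scalar parameters.

```python
import jax
import jax.numpy as jnp

# Stand-in for the parameter state above: a nested dict with the same shapes.
toy_params = {
    'net1': {'kernel': jnp.zeros((4, 8)), 'bias': jnp.zeros((8,))},
    'net2': {'kernel': jnp.zeros((8, 1)), 'bias': jnp.zeros((1,))},
}

# tree_leaves flattens the nested structure into a flat list of arrays.
leaves = jax.tree_util.tree_leaves(toy_params)
n_params = sum(leaf.size for leaf in leaves)

# tree_map applies a function to every leaf and keeps the nesting intact.
shapes = jax.tree_util.tree_map(lambda a: a.shape, toy_params)
print(n_params, shapes)
```

Here `tree_leaves` is handy for reductions over all parameters, while `tree_map` is the better fit when the result should keep the per-layer layout.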

## [**Tied Parameters**]

Often, we want to share parameters across multiple layers.
Let's see how to do this elegantly.
In the following we allocate a single fully connected layer, `shared`,
and apply it twice in the forward pass,
so that both applications use the same parameters.

In [25]:
class Net(nnx.Module):
	def __init__(self, n_inputs: int, *, rngs: nnx.Rngs):
		self.net1 = nnx.Linear(n_inputs, 8, rngs=rngs)
		self.shared = nnx.Linear(8, 8, rngs=rngs)
		self.net2 = nnx.Linear(8, 1, rngs=rngs)
	
	def __call__(self, x: jax.Array):
		x = nnx.relu(self.net1(x))
		x = nnx.relu(self.shared(x))
		x = nnx.relu(self.shared(x))
		return self.net2(x)

In [26]:
net = Net(4, rngs=nnx.Rngs(params=0))
net(x)

Array([[-0.08396542],
       [-0.06350967]], dtype=float32)

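In [None]:
# Check that the tied layer is stored only once: the state has three
# entries (net1, net2, shared) although the forward pass applies four layers
_, tied_params = nnx.split(net, nnx.Param)
print(len(tied_params) == 3)
list(tied_params.keys())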

In [27]:
graphdef, params = nnx.split(net, nnx.Param)

In [29]:
len(params)

3

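This example shows that the parameters
of the second and third layer are tied:
both applications go through the same `shared` module,
so the state holds only three entries
even though the forward pass applies four layers.
The two applications do not merely use equal values.
They use the exact same array.
Thus, if we change the parameters of `shared`,
both applications change, too.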


## Summary

We have several ways of accessing and tying model parameters.


## Exercises

1. Use the `NestMLP` model defined in :numref:`sec_model_construction` and access the parameters of the various layers.
1. Construct an MLP containing a shared parameter layer and train it. During the training process, observe the model parameters and gradients of each layer.
1. Why is sharing parameters a good idea?


[Discussions](https://discuss.d2l.ai/t/17990)
