# Parameter Initialization

Now that we know how to access the parameters,
let's look at how to initialize them properly.
We discussed the need for proper initialization in :numref:`sec_numerical_stability`.
The deep learning framework provides default random initializations for its layers.
However, we often want to initialize our weights
according to other protocols. The framework provides the most commonly
used protocols and also allows us to create custom initializers.


By default, Flax initializes weights using `jax.nn.initializers.lecun_normal`,
i.e., by drawing samples from a truncated normal distribution centered on 0 with
standard deviation $\sqrt{1 / \text{fan}_{\text{in}}}$,
where $\text{fan}_{\text{in}}$ is the number of input units of the weight tensor. The bias
parameters are all initialized to zero.
JAX's `jax.nn.initializers` module (re-exported by Flax as `nn.initializers`) provides a variety
of preset initialization methods.
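As a quick check of the description above, the sketch below draws a large weight matrix from `jax.nn.initializers.lecun_normal()` and compares its empirical standard deviation with $\sqrt{1/\text{fan}_{\text{in}}}$. It also relies on the fact that, in JAX, `lecun_normal` is the variance-scaling initializer with scale 1, `"fan_in"` mode, and a truncated normal distribution, and checks that both give identical samples for the same key. The $256 \times 512$ shape and the key are arbitrary choices for illustration.

```python
import jax
import jax.numpy as jnp

fan_in, fan_out = 256, 512
key = jax.random.PRNGKey(0)

# lecun_normal() returns an initializer with signature (key, shape, dtype)
lecun = jax.nn.initializers.lecun_normal()
# Kernel layout is (in, out), so fan_in is read from shape[-2]
w = lecun(key, (fan_in, fan_out))

# The same initializer spelled out as variance scaling
same = jax.nn.initializers.variance_scaling(1.0, "fan_in", "truncated_normal")
w_ref = same(key, (fan_in, fan_out))

target_std = jnp.sqrt(1.0 / fan_in)  # 0.0625 for fan_in = 256
print(float(w.std()), float(target_std))
print(bool(jnp.array_equal(w, w_ref)))
```

The truncation (at two standard deviations of the underlying normal) is compensated for inside JAX, which is why the empirical standard deviation still matches $\sqrt{1/\text{fan}_{\text{in}}}$.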


```python
import jax
from flax import linen as nn
from jax import numpy as jnp
from d2l import jax as d2l

net = nn.Sequential([nn.Dense(8), nn.relu, nn.Dense(1)])
X = jax.random.uniform(d2l.get_key(), (2, 4))
params = net.init(d2l.get_key(), X)
net.apply(params, X).shape
```

```
(2, 1)
```

## [**Built-in Initialization**]

Let's begin by calling on built-in initializers.
The code below initializes all weight parameters
as Gaussian random variables
with standard deviation 0.01, and sets all bias parameters to zero.


```python
weight_init = nn.initializers.normal(0.01)
bias_init = nn.initializers.zeros

net = nn.Sequential([nn.Dense(8, kernel_init=weight_init, bias_init=bias_init),
                     nn.relu,
                     nn.Dense(1, kernel_init=weight_init, bias_init=bias_init)])

params = net.init(jax.random.PRNGKey(d2l.get_seed()), X)
layer_0 = params['params']['layers_0']
layer_0['kernel'][:, 0], layer_0['bias'][0]
```

```
(Array([-0.00763717,  0.00054529,  0.01311344,  0.01136602], dtype=float32),
 Array(0., dtype=float32))
```
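Initializers such as `normal(0.01)` and `zeros` are ordinary functions of a PRNG key, a shape, and an optional dtype, so they can also be called directly, outside any layer. The sketch below uses the equivalent `jax.nn.initializers` names; the shapes and key are arbitrary.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
weight_init = jax.nn.initializers.normal(0.01)  # mean 0, stddev 0.01
bias_init = jax.nn.initializers.zeros

w = weight_init(key, (1000, 100))
b = bias_init(key, (100,))  # zeros ignores the key
w16 = weight_init(key, (3,), jnp.float16)  # dtype is the optional third argument

print(float(w.std()), float(jnp.abs(b).sum()), w16.dtype)
```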

We can also initialize all the parameters
to a given constant value (say, 1).


```python
weight_init = nn.initializers.constant(1)

net = nn.Sequential([nn.Dense(8, kernel_init=weight_init, bias_init=bias_init),
                     nn.relu,
                     nn.Dense(1, kernel_init=weight_init, bias_init=bias_init)])

params = net.init(jax.random.PRNGKey(d2l.get_seed()), X)
layer_0 = params['params']['layers_0']
layer_0['kernel'][:, 0], layer_0['bias'][0]
```

```
(Array([1., 1., 1., 1.], dtype=float32), Array(0., dtype=float32))
```
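A constant initializer never uses its key. The sketch below, using `jax.nn.initializers.constant` and `jax.nn.initializers.ones` on an arbitrary $4 \times 8$ shape, checks that `constant(1.0)` produces the same tensor as `ones` and that changing the key changes nothing.

```python
import jax
import jax.numpy as jnp

shape = (4, 8)
const_one = jax.nn.initializers.constant(1.0)(jax.random.PRNGKey(0), shape)
ones = jax.nn.initializers.ones(jax.random.PRNGKey(0), shape)
# A different key yields the same tensor, since the key is ignored
other = jax.nn.initializers.constant(1.0)(jax.random.PRNGKey(1), shape)

print(bool(jnp.array_equal(const_one, ones)), bool(jnp.array_equal(const_one, other)))  # True True
```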

[**We can also apply different initializers to certain blocks.**]
For example, below we initialize the first layer
with the Xavier initializer
and initialize the second layer
to a constant value of 42.


```python
net = nn.Sequential([nn.Dense(8, kernel_init=nn.initializers.xavier_uniform(),
                              bias_init=bias_init),
                     nn.relu,
                     nn.Dense(1, kernel_init=nn.initializers.constant(42),
                              bias_init=bias_init)])

params = net.init(jax.random.PRNGKey(d2l.get_seed()), X)
params['params']['layers_0']['kernel'][:, 0], params['params']['layers_2']['kernel']
```

```
(Array([ 0.09000139,  0.05019319,  0.22662596, -0.3504947 ], dtype=float32),
 Array([[42.],
        [42.],
        [42.],
        [42.],
        [42.],
        [42.],
        [42.],
        [42.]], dtype=float32))
```
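For the Xavier (Glorot) uniform initializer, the sampling range depends on both fan-in and fan-out: weights are drawn from $U(-a, a)$ with $a = \sqrt{6 / (\text{fan}_{\text{in}} + \text{fan}_{\text{out}})}$, which gives variance $2 / (\text{fan}_{\text{in}} + \text{fan}_{\text{out}})$. The sketch below checks both properties empirically on an arbitrarily chosen $300 \times 500$ matrix.

```python
import jax
import jax.numpy as jnp

fan_in, fan_out = 300, 500
key = jax.random.PRNGKey(0)
w = jax.nn.initializers.xavier_uniform()(key, (fan_in, fan_out))

# Bound of the uniform distribution and the resulting variance
a = jnp.sqrt(6.0 / (fan_in + fan_out))
target_var = 2.0 / (fan_in + fan_out)  # 0.0025 here

print(float(jnp.abs(w).max()), float(a))
print(float(w.var()), target_var)
```

For the first layer in the network above ($\text{fan}_{\text{in}} = 4$, $\text{fan}_{\text{out}} = 8$), $a = \sqrt{6/12} \approx 0.707$, consistent with the printed kernel values.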

## [**Custom Initialization**]

Sometimes, the initialization methods we need
are not provided by the deep learning framework.
In the example below, we define an initializer
for any weight parameter $w$ using the following strange distribution:

$$
\begin{aligned}
    w \sim \begin{cases}
        U(5, 10) & \text{ with probability } \frac{1}{4} \\
            0    & \text{ with probability } \frac{1}{2} \\
        U(-10, -5) & \text{ with probability } \frac{1}{4}
    \end{cases}
\end{aligned}
$$


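JAX initializers are functions that take a PRNG key, a `shape`, and a `dtype` as
arguments. Here we implement the function `my_init`, which returns a tensor of the
given shape and data type. It samples uniformly from $[-10, 10)$ and zeroes out every
entry with magnitude below 5; since half of that interval has $|w| < 5$, this yields
exactly the distribution above.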


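```python
def my_init(key, shape, dtype=jnp.float32):
    # Sample uniformly from [-10, 10) in the requested dtype
    data = jax.random.uniform(key, shape, dtype=dtype, minval=-10, maxval=10)
    # Zero out entries with |w| < 5 (probability 1/2); the survivors are uniform
    # on [5, 10) or [-10, -5], each with probability 1/4
    return data * (jnp.abs(data) >= 5)

net = nn.Sequential([nn.Dense(8, kernel_init=my_init), nn.relu, nn.Dense(1)])
params = net.init(d2l.get_key(), X)
print(params['params']['layers_0']['kernel'][:, :2])
```

```
[[ 0.         0.       ]
 [ 0.         0.       ]
 [-5.992775   0.       ]
 [-6.5313935 -7.7249002]]
```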
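To check that `my_init` really follows the mixture distribution defined above, one can draw many samples and count how often they land in each case. The sketch below repeats the definition so that it is self-contained; the key and the $1000 \times 1000$ sample size are arbitrary.

```python
import jax
import jax.numpy as jnp

def my_init(key, shape, dtype=jnp.float32):
    # Uniform on [-10, 10); entries with |w| < 5 are zeroed out
    data = jax.random.uniform(key, shape, dtype=dtype, minval=-10, maxval=10)
    return data * (jnp.abs(data) >= 5)

w = my_init(jax.random.PRNGKey(0), (1000, 1000))

# Empirical probability of each case in the mixture
frac_zero = float(jnp.mean(w == 0))
frac_pos = float(jnp.mean((w >= 5) & (w < 10)))
frac_neg = float(jnp.mean((w >= -10) & (w <= -5)))
print(frac_zero, frac_pos, frac_neg)  # each close to 1/2, 1/4, 1/4
```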


When initializing parameters in JAX and Flax, the returned dictionary of parameters
has type `flax.core.frozen_dict.FrozenDict`. JAX arrays are immutable: rather than
altering values in place, one creates updated copies (e.g., with `x.at[idx].set(value)`),
and the parameter dictionary follows the same convention. To make changes, one can use
`params.unfreeze()` to obtain a mutable copy.


## Summary

We can initialize parameters using built-in and custom initializers.

## Exercises

Look up the online documentation for more built-in initializers.
