This notebook follows the instruction on: https://flax.readthedocs.io/en/latest/guides/flax_basics.html#linear-regression-with-flax

In [1]:
import jax
from jax import jit
from typing import Any, Callable, Sequence
from jax import lax, random, numpy as jnp
from flax.core import freeze, unfreeze
from flax import linen as nn

# 1. Linear Regression with Flax

In [2]:
#create a dense layer instance
model = nn.Dense(features=5)

### Model parameters & initialization

In [3]:
key1, key2 = random.split(random.PRNGKey(0))
x = random.normal(key1, (10, )) # this x is used to trigger shape inference
params = model.init(key2, x)
jax.tree_map(lambda x: x.shape, params)

FrozenDict({
    params: {
        bias: (5,),
        kernel: (10, 5),
    },
})

Parameters are stored in a `FrozenDict` instance which deal with the function nature of JAX by preventing any mutation of the underlying dict and making the user aware of it.

In [4]:
print(params)

FrozenDict({
    params: {
        kernel: Array([[ 2.35571593e-01, -1.71652630e-01, -4.45728898e-02,
                -4.68043625e-01,  4.54595298e-01],
               [-6.87736511e-01,  3.67835432e-01, -1.79262117e-01,
                 1.29276216e-01, -2.42580175e-01],
               [ 2.02303097e-01, -2.49465629e-01,  2.74955630e-01,
                 4.73488301e-01, -1.98002532e-01],
               [ 2.74478376e-01, -1.21369645e-01, -2.25361690e-01,
                -4.78193611e-01, -9.63979959e-02],
               [-6.19886220e-02, -1.72743499e-01,  2.96947401e-04,
                -7.17593431e-01,  2.00894251e-01],
               [-5.60321212e-01,  3.27208459e-01,  1.06281511e-01,
                 1.28758654e-01,  1.16973273e-01],
               [ 1.82219014e-01,  1.11444041e-01, -1.62924170e-01,
                 3.24953273e-02, -1.67053357e-01],
               [ 4.31294084e-01,  2.08004534e-01,  1.47714198e-01,
                -8.51502791e-02, -1.26487076e-01],
               [ 3.29

In [5]:
print(model)

Dense(
    # attributes
    features = 5
    use_bias = True
    dtype = None
    param_dtype = float32
    precision = None
    kernel_init = init
    bias_init = zeros
    dot_general = dot_general
)


### Gradient Descent

In [6]:
# Set problem dimensions
n_samples = 20
x_dim = 10
y_dim = 5

# Generate random ground truth W and b
key = random.PRNGKey(0)
k1, k2 = random.split(key)
W = random.normal(k1, (x_dim, y_dim))
b = random.normal(k2, (y_dim,))

# Store the parameters in a pytree
true_params = freeze({'params':{'bias': b, 'kernel': W}})

# Generate samples with additional noise
key_sample, key_noise = random.split(k1)
x_samples = random.normal(key_sample, (n_samples, x_dim))
y_samples = jnp.dot(x_samples, W) + b + 0.1 * random.normal(key_noise, (n_samples, y_dim))
print('x shape: ', x_samples.shape, '; y shape', y_samples.shape)

x shape:  (20, 10) ; y shape (20, 5)


In [7]:
@jax.jit
def mse(params, x_batched, y_batched):
    # Define square loss for a single pair
    def squared_error(x, y):
        pred = model.apply(params, x)
        return jnp.inner(y-pred, y-pred) / 2.0
    # Vectorize the previous to compute the average of the loss on all samples
    return jnp.mean(jax.vmap(squared_error)(x_batched, y_batched), axis=0)

Perform gradient descent

In [8]:
learning_rate = 0.3
print('Loss for "true" W, b: ', mse(true_params, x_samples, y_samples))
loss_grad_fn = jax.value_and_grad(mse)

@jax.jit
def update_params(params, learning_rate, grads):
    params = jax.tree_util.tree_map(
        lambda p, g: p - learning_rate * g, params, grads
    )
    return params

for i in range(101):
    loss_val, grads = loss_grad_fn(params, x_samples, y_samples)
    params = update_params(params, learning_rate, grads)
    if i % 10 == 0:
        print(f'Loss step {i}: ', loss_val)

Loss for "true" W, b:  0.023639798
Loss step 0:  35.343876
Loss step 10:  0.5143469
Loss step 20:  0.11384161
Loss step 30:  0.03932675
Loss step 40:  0.019916205
Loss step 50:  0.014209128
Loss step 60:  0.012425651
Loss step 70:  0.0118503915
Loss step 80:  0.011661774
Loss step 90:  0.011599411
Loss step 100:  0.011578695


### Optimizing with Optax

Basic usage of Optax:
1. Choose an optimization method (e.g. `optax.adam`);
2. Create optimizer state from parameters (for the Adam optimizer, this state will contain the *momentum* values);
3. Compute the gradients of the loss with `jax.value_and_grad()`;
4. At every iteration, call the Optax `update` function to update the internal optimizer state and create an update to the parameters. Then add the update to the parameters with Optax's `apply_updates` method.

In [9]:
import optax
tx = optax.adam(learning_rate=learning_rate)
opt_state = tx.init(params)
loss_grad_fn = jax.value_and_grad(mse)

for i in range(101):
    loss_val, grads = loss_grad_fn(params, x_samples, y_samples)
    updates, opt_state = tx.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    if i % 10 == 0:
        print('Loss step {}: '.format(i), loss_val)

Loss step 0:  0.011577623
Loss step 10:  0.2614426
Loss step 20:  0.076756336
Loss step 30:  0.03644162
Loss step 40:  0.022014325
Loss step 50:  0.016178917
Loss step 60:  0.013002919
Loss step 70:  0.012026143
Loss step 80:  0.0117644435
Loss step 90:  0.011646055
Loss step 100:  0.011585526


### Serializing the result
Save and load model parameters

In [10]:
from flax import serialization
bytes_output = serialization.to_bytes(params)
dict_output = serialization.to_state_dict(params)

print('Dict output')
print(dict_output)

print('Bytes output')
print(bytes_output)

Dict output
{'params': {'bias': Array([-1.4555762, -2.0277987,  2.079098 ,  1.2186146, -0.9980971],      dtype=float32), 'kernel': Array([[ 1.0098814 ,  0.18934432,  0.04455043, -0.9280222 ,  0.34784007],
       [ 1.7298449 ,  0.98793644,  1.164046  ,  1.1006075 , -0.10653944],
       [-1.2029458 ,  0.2863525 ,  1.4155985 ,  0.11871001, -1.3141481 ],
       [-1.1941485 , -0.18958484,  0.0341387 ,  1.3169428 ,  0.08060349],
       [ 0.13852426,  1.3713043 , -1.3187183 ,  0.53152704, -2.2404993 ],
       [ 0.56293976,  0.812231  ,  0.31751972,  0.53455067,  0.90500313],
       [-0.37926075,  1.7410388 ,  1.0790282 , -0.50398386,  0.92830575],
       [ 0.97064954, -1.3153397 ,  0.33681548,  0.80993474, -1.2018455 ],
       [ 1.0194305 , -0.62024856,  1.081882  , -1.8389751 , -0.45805126],
       [-0.6436538 ,  0.4566669 , -1.1329137 , -0.6853869 ,  0.16829033]],      dtype=float32)}}
Bytes output
b'\x81\xa6params\x82\xa4bias\xc7!\x01\x93\x91\x05\xa7float32\xc4\x14RP\xba\xbft\xc7\x01\xc0\x

In [11]:
serialization.from_bytes(params, bytes_output)

FrozenDict({
    params: {
        bias: array([-1.4555762, -2.0277987,  2.079098 ,  1.2186146, -0.9980971],
              dtype=float32),
        kernel: array([[ 1.0098814 ,  0.18934432,  0.04455043, -0.9280222 ,  0.34784007],
               [ 1.7298449 ,  0.98793644,  1.164046  ,  1.1006075 , -0.10653944],
               [-1.2029458 ,  0.2863525 ,  1.4155985 ,  0.11871001, -1.3141481 ],
               [-1.1941485 , -0.18958484,  0.0341387 ,  1.3169428 ,  0.08060349],
               [ 0.13852426,  1.3713043 , -1.3187183 ,  0.53152704, -2.2404993 ],
               [ 0.56293976,  0.812231  ,  0.31751972,  0.53455067,  0.90500313],
               [-0.37926075,  1.7410388 ,  1.0790282 , -0.50398386,  0.92830575],
               [ 0.97064954, -1.3153397 ,  0.33681548,  0.80993474, -1.2018455 ],
               [ 1.0194305 , -0.62024856,  1.081882  , -1.8389751 , -0.45805126],
               [-0.6436538 ,  0.4566669 , -1.1329137 , -0.6853869 ,  0.16829033]],
              dtype=float32),
  

# Defining your own models

### Module basics

In [12]:
class ExplicitMLP(nn.Module):
    features: Sequence[int]
        
    def setup(self):
        self.layers = [nn.Dense(feat) for feat in self.features]
    
    def __call__(self, inputs):
        x = inputs
        for i, lyr in enumerate(self.layers):
            x = lyr(x)
            if i != len(self.layers) - 1:
                x = nn.relu(x)
        return x

key1, key2 = random.split(random.PRNGKey(0), 2)
x = random.uniform(key1, (4, 4))

model = ExplicitMLP(features=[3, 4, 5])
params = model.init(key2, x)
y = model.apply(params, x)

print('initialize parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, unfreeze(params)))
print('output:\n', y)

initialize parameter shapes:
 {'params': {'layers_0': {'bias': (3,), 'kernel': (4, 3)}, 'layers_1': {'bias': (4,), 'kernel': (3, 4)}, 'layers_2': {'bias': (5,), 'kernel': (4, 5)}}}
output:
 [[ 0.          0.          0.          0.          0.        ]
 [ 0.00723787 -0.00810345 -0.0255093   0.02151708 -0.01261237]
 [ 0.          0.          0.          0.          0.        ]
 [ 0.          0.          0.          0.          0.        ]]


Since **the modules structure and its parameters are not tied to each other**, we cannot directly call `model(x)` on a given input as it will return an error. The `__call__` function is being wrapped up in the `apply` one:

In [13]:
try:
    y = model(x)
except AttributeError as e:
    print(e)

"ExplicitMLP" object has no attribute "layers". If "layers" is defined in '.setup()', remember these fields are only accessible from inside 'init' or 'apply'.


In [14]:
class SimpleMLP(nn.Module):
    features: Sequence[int]

    @nn.compact
    def __call__(self, inputs):
        x = inputs
        for i, feat in enumerate(self.features):
            x = nn.Dense(feat, name=f'layers_{i}')(x)
            if i != len(self.features) - 1:
                x = nn.relu(x)
        return x

key1, key2 = random.split(random.PRNGKey(0), 2)
x = random.uniform(key1, (4, 4))

model = SimpleMLP(features=[3, 4, 5])
params = model.init(key2, x)
y = model.apply(params, x)

print('initialized parameter:\n', jax.tree_util.tree_map(jnp.shape, unfreeze(params)))
print('output:\n', y)

initialized parameter:
 {'params': {'layers_0': {'bias': (3,), 'kernel': (4, 3)}, 'layers_1': {'bias': (4,), 'kernel': (3, 4)}, 'layers_2': {'bias': (5,), 'kernel': (4, 5)}}}
output:
 [[ 0.          0.          0.          0.          0.        ]
 [ 0.00723787 -0.00810345 -0.0255093   0.02151708 -0.01261237]
 [ 0.          0.          0.          0.          0.        ]
 [ 0.          0.          0.          0.          0.        ]]


In [15]:
class SimpleDense(nn.Module):
    features: int
    kernel_init: Callable = nn.initializers.lecun_normal()
    bias_init: Callable = nn.initializers.zeros_init()

    @nn.compact
    def __call__(self, inputs):
        kernel = self.param('kernel',
                            self.kernel_init,
                            (inputs.shape[-1], self.features))

        y = lax.dot_general(inputs, kernel,
                            (((inputs.ndim-1,), (0,)),((), ())),)
        
        bias = self.param('bias',
                          self.bias_init,
                          (self.features))

        y = y + bias
        return y

key1, key2 = random.split(random.PRNGKey(0), 2)
x = random.uniform(key1, (4, 4))

model = SimpleDense(features=3)
params = model.init(key2, x)
y = model.apply(params, x)

print('initialized parameters:\n', params)
print('output:\n', y)

initialized parameters:
 FrozenDict({
    params: {
        kernel: Array([[ 0.61506   , -0.22728713,  0.60547006],
               [-0.2961799 ,  1.1232013 , -0.879759  ],
               [-0.35162622,  0.38064912,  0.68932474],
               [-0.1151355 ,  0.04567899, -1.091212  ]], dtype=float32),
        bias: Array([0., 0., 0.], dtype=float32),
    },
})
output:
 [[-0.029962    1.102088   -0.6660265 ]
 [-0.31092793  0.6323942  -0.5367881 ]
 [ 0.0142401   0.9424717  -0.6356147 ]
 [ 0.36818963  0.35865188 -0.00459227]]


### Variables and collections of variables

In [16]:
class BiasAdderWithRunningMean(nn.Module):
    decay: float = 0.99
    
    @nn.compact
    def __call__(self, x):
        # easy pattern to detect if we're initializing via empty variable tree
        is_initialized = self.has_variable('batch_stats', 'mean')
        ra_mean = self.variable('batch_stats', 'mean',
                                lambda s:jnp.zeros(s),
                                x.shape[1:])
        mean = ra_mean.value
        bias = self.param('bias', lambda rng, shape: jnp.zeros(shape), x.shape[1:])
        
        if is_initialized:
            ra_mean.value = self.decay * ra_mean.value + (1.0 - self.decay) * jnp.mean(x, axis=0, keepdims=True)

        return x - ra_mean.value + bias

key1, key2 = random.split(random.PRNGKey(0), 2)
x = jnp.ones((10, 5))
model = BiasAdderWithRunningMean()
variables = model.init(key1, x)

print('initialized variables:\n', variables)
y, updated_state = model.apply(variables, x, mutable=['batch_stats'])
print('updated_state:\n', updated_state)

initialized variables:
 FrozenDict({
    batch_stats: {
        mean: Array([0., 0., 0., 0., 0.], dtype=float32),
    },
    params: {
        bias: Array([0., 0., 0., 0., 0.], dtype=float32),
    },
})
updated_state:
 FrozenDict({
    batch_stats: {
        mean: Array([[0.01, 0.01, 0.01, 0.01, 0.01]], dtype=float32),
    },
})
