# Flax Basics

In this section we will learn how to:

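*   Build a model from Flax's built-in layers
*   Initialize the model parameters and write a training loop by hand
*   Simplify the training code with an optimizer library (Optax)
*   Serialize the model parameters
*   Write our own models and manage their state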

## Setting up the environment


In [1]:
# Install the latest JAXlib version.
!pip3 install --upgrade -q pip jax jaxlib 
# Install Flax at head:
!pip3 install --upgrade -q git+https://github.com/google/flax.git

In [2]:
import jax
from typing import Any, Callable, Sequence
from jax import lax, random, numpy as jnp
from flax.core import freeze, unfreeze  # convert between FrozenDict and plain dicts
from flax import linen as nn

## Example: linear regression

A linear regression model is just a single linear (dense) layer, which Flax already provides in `flax.linen` as `nn.Dense`.

In [3]:
# Output dimension is 5
model = nn.Dense(features=5)

Layers (and models) are all subclasses of `linen.Module`, an API modeled after PyTorch's `nn.Module`.

### Model parameters and initialization

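Unlike in PyTorch, a Flax model does not store its parameters. Instead, we create the parameters by calling the model's `init` method with a PRNGKey and a dummy input (an example input used only for initialization).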

In [4]:
key1, key2 = random.split(random.PRNGKey(0))
x = random.normal(key1, (10,)) # Dummy input with input dimension 10
params = model.init(key2, x) # Initialization call

jax.tree_util.tree_map(lambda x: x.shape, params) # Inspect the shapes of the parameters

FrozenDict({
    params: {
        bias: (5,),
        kernel: (10, 5),
    },
})

In [5]:
type(params)

flax.core.frozen_dict.FrozenDict

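*Note: in JAX and Flax, as in NumPy, vectors are row vectors rather than column vectors, so `nn.Dense` computes `x @ kernel + bias` with a kernel of shape `(input_dim, output_dim)`.*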


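*   The dummy input `x` is used to infer the parameter shapes: we only declared the output dimension of `Dense`, so the input dimension is taken from `x`.
*   The PRNG key is used by the initializer functions. An initializer creates a parameter: it is called with `(PRNG key, shape, dtype)` and returns an array of that `shape`.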

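The parameters are returned as a `FrozenDict`; "frozen" means immutable. This reflects JAX's functional programming style, in which JAX arrays cannot be modified in place either.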


In [6]:
try:
    params['new_key'] = jnp.ones((2,2))
except ValueError as e:
    print("Error: ", e)

Error:  FrozenDict is immutable.


To run a forward pass, call the model's `apply` method, passing the parameters explicitly along with the input:

In [7]:
model.apply(params, x)

DeviceArray([-1.3721182,  0.6113121,  0.6442853,  2.2192988, -1.1271175],            dtype=float32)

### Gradient descent

If you jumped here directly without going through the JAX part, here is the linear regression formulation we're going to use: from a set of data points $\{(x_i,y_i), i\in \{1,\ldots, k\}, x_i\in\mathbb{R}^n,y_i\in\mathbb{R}^m\}$, we try to find a set of parameters $W\in \mathcal{M}_{m,n}(\mathbb{R}), b\in\mathbb{R}^m$ such that the function $f_{W,b}(x)=Wx+b$ minimizes the mean squared error:
$$\mathcal{L}(W,b)=\frac{1}{k}\sum_{i=1}^{k} \frac{1}{2}\|y_i-f_{W,b}(x_i)\|^2_2$$

Here, we see that the tuple $(W,b)$ matches the parameters of the Dense layer (with the row-vector convention, the `kernel` stores $W^\top$, of shape `(n, m)`). We'll perform gradient descent using those. Let's first generate the fake data we'll use. The data is exactly the same as in the JAX part's linear regression pytree example.

In [8]:
# Set problem dimensions.
n_samples = 20
x_dim = 10
y_dim = 5

# Generate random ground truth W and b.
key = random.PRNGKey(0)
k1, k2 = random.split(key)
W = random.normal(k1, (x_dim, y_dim))
b = random.normal(k2, (y_dim,))
# Store the parameters in a pytree.
true_params = freeze({'params': {'bias': b, 'kernel': W}})

# Generate samples with additional noise.
key_sample, key_noise = random.split(k1)
x_samples = random.normal(key_sample, (n_samples, x_dim))
y_samples = jnp.dot(x_samples, W) + b + 0.1 * random.normal(key_noise,(n_samples, y_dim))
print('x shape:', x_samples.shape, '; y shape:', y_samples.shape)

x shape: (20, 10) ; y shape: (20, 5)


In [9]:
# Same as JAX version but using model.apply().
@jax.jit
def mse(params, x_batched, y_batched):
  # Define the squared loss for a single pair (x,y)
  def squared_error(x, y):
    pred = model.apply(params, x)
    return jnp.inner(y-pred, y-pred) / 2.0
  # Vectorize the previous to compute the average of the loss on all samples.
  return jnp.mean(jax.vmap(squared_error)(x_batched,y_batched), axis=0)

And finally perform the gradient descent.

In [10]:
learning_rate = 0.3  # Gradient step size.
print('Loss for "true" W,b: ', mse(true_params, x_samples, y_samples))
loss_grad_fn = jax.value_and_grad(mse)

@jax.jit
def update_params(params, learning_rate, grads):
  params = jax.tree_util.tree_map(
      lambda p, g: p - learning_rate * g, params, grads)
  return params

for i in range(101):
  # Perform one gradient update.
  loss_val, grads = loss_grad_fn(params, x_samples, y_samples)
  params = update_params(params, learning_rate, grads)
  if i % 10 == 0:
    print(f'Loss step {i}: ', loss_val)

Loss for "true" W,b:  0.02363973
Loss step 0:  35.330673
Loss step 10:  0.5138981
Loss step 20:  0.11392992
Loss step 30:  0.03919686
Loss step 40:  0.019862717
Loss step 50:  0.014238695
Loss step 60:  0.012534802
Loss step 70:  0.011963526
Loss step 80:  0.0117725
Loss step 90:  0.011667584
Loss step 100:  0.01167096


### Using Optax optimizers

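Although `flax.optim` also provides some optimization algorithms, [FLIP #1009](https://github.com/google/flax/blob/main/docs/flip/1009-optimizer-api.md) recommends using DeepMind's [Optax](https://github.com/deepmind/optax) instead (`flax.optim` has since been deprecated in favor of Optax).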

Using Optax is simple:

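1.   Choose an optimization algorithm (e.g. `optax.sgd`).
2.   Create the optimizer state from the model parameters.
3.   Compute the gradients of the loss with `jax.value_and_grad()`.
4.   At every iteration, call the Optax `update` function to update the optimizer state and compute the parameter updates, then apply them to the parameters with `optax.apply_updates`.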

Optax also supports learning-rate schedules and can compose several gradient transformations into one optimizer; see the [official documentation](https://optax.readthedocs.io/en/latest/).

In [11]:
import optax
tx = optax.sgd(learning_rate=learning_rate)
opt_state = tx.init(params)  # Create the optimizer state
loss_grad_fn = jax.value_and_grad(mse)

In [12]:
for i in range(101):
  loss_val, grads = loss_grad_fn(params, x_samples, y_samples)
  updates, opt_state = tx.update(grads, opt_state)
  params = optax.apply_updates(params, updates)
  if i % 10 == 0:
    print('Loss step {}: '.format(i), loss_val)

Loss step 0:  0.01167054
Loss step 10:  0.0117824925
Loss step 20:  0.01173518
Loss step 30:  0.011738507
Loss step 40:  0.011714945
Loss step 50:  0.011688645
Loss step 60:  0.011755346
Loss step 70:  0.011680829
Loss step 80:  0.011735042
Loss step 90:  0.011765196
Loss step 100:  0.0117326705


### Serialization

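Now that the model is trained, how do we save its parameters? Flax provides serialization utilities in `flax.serialization`: `to_bytes` encodes the parameters as msgpack bytes, and `to_state_dict` converts them to a nested dict of arrays.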

In [13]:
from flax import serialization
bytes_output = serialization.to_bytes(params)
dict_output = serialization.to_state_dict(params)
print('Dict output')
print(dict_output)
print('Bytes output')
print(bytes_output)

Dict output
{'params': {'bias': DeviceArray([-1.4535893, -2.025787 ,  2.0809538,  1.2200156, -0.9970508],            dtype=float32), 'kernel': DeviceArray([[ 1.0109067 ,  0.19055863,  0.04583546, -0.9280397 ,
               0.34755877],
             [ 1.730625  ,  0.99051666,  1.167061  ,  1.0984352 ,
              -0.1059608 ],
             [-1.1982636 ,  0.28850538,  1.4185867 ,  0.12010808,
              -1.3168883 ],
             [-1.1967102 , -0.19019204,  0.03367356,  1.3164332 ,
               0.0799135 ],
             [ 0.14068502,  1.3709251 , -1.3170009 ,  0.5340607 ,
              -2.2270784 ],
             [ 0.56380916,  0.81390435,  0.31864667,  0.53749406,
               0.90501535],
             [-0.378986  ,  1.7398658 ,  1.0832783 , -0.5030793 ,
               0.92784196],
             [ 0.97118044, -1.3149489 ,  0.33652323,  0.8073513 ,
              -1.2069335 ],
             [ 1.0206019 , -0.6180575 ,  1.0824292 , -1.8404183 ,
              -0.45773292],
           

To restore the saved parameters, pass a template with the same structure together with the bytes:

In [14]:
# from_bytes needs a template pytree with the same structure as the saved params
serialization.from_bytes(params, bytes_output)

FrozenDict({
    params: {
        bias: array([-1.4535893, -2.025787 ,  2.0809538,  1.2200156, -0.9970508],
              dtype=float32),
        kernel: array([[ 1.0109067 ,  0.19055863,  0.04583546, -0.9280397 ,  0.34755877],
               [ 1.730625  ,  0.99051666,  1.167061  ,  1.0984352 , -0.1059608 ],
               [-1.1982636 ,  0.28850538,  1.4185867 ,  0.12010808, -1.3168883 ],
               [-1.1967102 , -0.19019204,  0.03367356,  1.3164332 ,  0.0799135 ],
               [ 0.14068502,  1.3709251 , -1.3170009 ,  0.5340607 , -2.2270784 ],
               [ 0.56380916,  0.81390435,  0.31864667,  0.53749406,  0.90501535],
               [-0.378986  ,  1.7398658 ,  1.0832783 , -0.5030793 ,  0.92784196],
               [ 0.97118044, -1.3149489 ,  0.33652323,  0.8073513 , -1.2069335 ],
               [ 1.0206019 , -0.6180575 ,  1.0824292 , -1.8404183 , -0.45773292],
               [-0.64246374,  0.45634705, -1.1300361 , -0.68562186,  0.16957925]],
              dtype=float32),
  

## Designing models

Flax modules are modeled after PyTorch's `nn.Module`.

*Keep in mind that we imported* `linen as nn` *and this only works with the new linen API*

### Module basics



In [15]:
class ExplicitMLP(nn.Module):
  features: Sequence[int]  # module attribute; becomes an argument of the generated __init__()

  def setup(self):
    # we automatically know what to do with lists, dicts of submodules
    self.layers = [nn.Dense(feat) for feat in self.features]
    # for single submodules, we would just write:
    # self.layer1 = nn.Dense(feat1)

  def __call__(self, inputs):
    x = inputs
    for i, lyr in enumerate(self.layers):
      x = lyr(x)
      if i != len(self.layers) - 1:
        x = nn.relu(x)
    return x


In [16]:
key1, key2 = random.split(random.PRNGKey(0), 2)
x = random.uniform(key1, (4, 4))

model = ExplicitMLP(features=[3,4,5])
params = model.init(key2, x)
y = model.apply(params, x)

print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, unfreeze(params)))
print('output:\n', y)

initialized parameter shapes:
 {'params': {'layers_0': {'bias': (3,), 'kernel': (4, 3)}, 'layers_1': {'bias': (4,), 'kernel': (3, 4)}, 'layers_2': {'bias': (5,), 'kernel': (4, 5)}}}
output:
 [[ 0.          0.          0.          0.          0.        ]
 [ 0.00720215 -0.00805998 -0.02526569  0.02135181 -0.01251006]
 [ 0.          0.          0.          0.          0.        ]
 [ 0.          0.          0.          0.          0.        ]]


An `nn.Module` has:

*  A collection of data fields: every `nn.Module` subclass is a Python dataclass.
*  A `setup()` method, called after `__post_init__`, which registers the submodules, variables and parameters the model needs.
*  A `__call__` method, the counterpart of PyTorch's `forward()`.
*  The model structure defines a pytree of parameters following the same tree structure as the model: the params tree contains one `layers_n` sub-dict per layer, and each of those contains the parameters of the associated Dense layer. The layout is very explicit.
*  `apply()`, which wraps `__call__`: it binds the given variables to the module and then runs the method.

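*Note: lists are mostly managed as you would expect (WIP); there are corner cases you should be aware of, as pointed out* [here](https://github.com/google/flax/issues/524)

Because the parameters live outside the module, calling `model(x)` directly on an unbound module raises an error:
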
In [17]:
try:
    y = model(x) # Returns an error
except AttributeError as e:
    print(e)

"ExplicitMLP" object has no attribute "layers"


Since here we have a very simple model, we could have used an alternative (but equivalent) way of declaring the submodules inline in the `__call__` using the `@nn.compact` annotation like so:

In [18]:
class SimpleMLP(nn.Module):
  features: Sequence[int]

  @nn.compact
  def __call__(self, inputs):
    x = inputs
    for i, feat in enumerate(self.features):
      x = nn.Dense(feat, name=f'layers_{i}')(x)
      if i != len(self.features) - 1:
        x = nn.relu(x)
      # providing a name is optional though!
      # the default autonames would be "Dense_0", "Dense_1", ...
    return x

key1, key2 = random.split(random.PRNGKey(0), 2)
x = random.uniform(key1, (4,4))

model = SimpleMLP(features=[3,4,5])
params = model.init(key2, x)
y = model.apply(params, x)

print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, unfreeze(params)))
print('output:\n', y)

initialized parameter shapes:
 {'params': {'layers_0': {'bias': (3,), 'kernel': (4, 3)}, 'layers_1': {'bias': (4,), 'kernel': (3, 4)}, 'layers_2': {'bias': (5,), 'kernel': (4, 5)}}}
output:
 [[ 0.          0.          0.          0.          0.        ]
 [ 0.00720215 -0.00805998 -0.02526569  0.02135181 -0.01251006]
 [ 0.          0.          0.          0.          0.        ]
 [ 0.          0.          0.          0.          0.        ]]


There are, however, a few differences you should be aware of between the two declaration modes:

*   In `setup`, you are able to name some sublayers and keep them around for further use (e.g. encoder/decoder methods in autoencoders).
*   If you want to have multiple methods, then you **need** to declare the module using `setup`, as the `@nn.compact` annotation only allows one method to be annotated.
*   The last initialization will be handled differently. See these notes for more details (TODO: add notes link).


### Writing a new Module

Here we define a new module, a simplified `Dense` layer, using `@nn.compact`:

In [19]:
class SimpleDense(nn.Module):
  features: int
  kernel_init: Callable = nn.initializers.lecun_normal()
  bias_init: Callable = nn.initializers.zeros

  @nn.compact
  def __call__(self, inputs):
    kernel = self.param('kernel',
                        self.kernel_init, # Initialization function
                        (inputs.shape[-1], self.features))  # shape info.
    y = lax.dot_general(inputs, kernel,
                        (((inputs.ndim - 1,), (0,)), ((), ())),)  # contract the last axis of inputs with axis 0 of kernel; jnp.dot(inputs, kernel) is equivalent here
    bias = self.param('bias', self.bias_init, (self.features,))
    y = y + bias
    return y

key1, key2 = random.split(random.PRNGKey(0), 2)
x = random.uniform(key1, (4,4))

model = SimpleDense(features=3)
params = model.init(key2, x)
y = model.apply(params, x)

print('initialized parameters:\n', params)
print('output:\n', y)

initialized parameters:
 FrozenDict({
    params: {
        kernel: DeviceArray([[ 0.61505955, -0.22728656,  0.60547036],
                     [-0.29617673,  1.1231915 , -0.87975675],
                     [-0.35162923,  0.38064492,  0.6893282 ],
                     [-0.11513556,  0.04567894, -1.0912049 ]], dtype=float32),
        bias: DeviceArray([0., 0., 0.], dtype=float32),
    },
})
output:
 [[-0.03109741  1.1018791  -0.6654663 ]
 [-0.31054688  0.6324854  -0.53855133]
 [ 0.0125351   0.94367504 -0.6356659 ]
 [ 0.36548042  0.35944438 -0.00537109]]


Here, we see how to both declare and assign a parameter to the model using the `self.param` method. It takes `(name, init_fn, *init_args)` as input:

*   `name` is simply the name of the parameter that will end up in the parameter structure.
*   `init_fn` is a function with input `(PRNGKey, *init_args)` returning an Array, with `init_args` being the arguments needed to call the initialisation function.
*   `init_args` are the arguments to provide to the initialization function.

Such params can also be declared in the `setup` method; it won't be able to use shape inference because Flax is using lazy initialization at the first call site.

### Variables and collections of variables

As we've seen so far, working with models means working with:

*   A subclass of `nn.Module`;
*   A pytree of parameters for the model (typically from `model.init()`);

However this is not enough to cover everything that we would need for machine learning, especially neural networks. In some cases, you might want your neural network to keep track of some internal state while it runs (e.g. batch normalization layers). There is a way to declare variables beyond the parameters of the model with the `variable` method.

For demonstration purposes, we'll implement a simplified but similar mechanism to batch normalization: we'll store running averages and subtract those from the input at training time. For proper batchnorm, you should use (and look at) the implementation [here](https://github.com/google/flax/blob/main/flax/linen/normalization.py).

In [20]:
class BiasAdderWithRunningMean(nn.Module):
  decay: float = 0.99

  @nn.compact
  def __call__(self, x):
    # easy pattern to detect if we're initializing via empty variable tree
    is_initialized = self.has_variable('batch_stats', 'mean')
    ra_mean = self.variable('batch_stats', 'mean',
                            lambda s: jnp.zeros(s),
                            x.shape[1:])
    mean = ra_mean.value # This will either get the value or trigger init
    bias = self.param('bias', lambda rng, shape: jnp.zeros(shape), x.shape[1:])
    if is_initialized:
      ra_mean.value = self.decay * ra_mean.value + (1.0 - self.decay) * jnp.mean(x, axis=0, keepdims=True)

    return x - ra_mean.value + bias


key1, key2 = random.split(random.PRNGKey(0), 2)
x = jnp.ones((10,5))
model = BiasAdderWithRunningMean()
variables = model.init(key1, x)
print('initialized variables:\n', variables)
y, updated_state = model.apply(variables, x, mutable=['batch_stats'])
print('updated state:\n', updated_state)

initialized variables:
 FrozenDict({
    batch_stats: {
        mean: DeviceArray([0., 0., 0., 0., 0.], dtype=float32),
    },
    params: {
        bias: DeviceArray([0., 0., 0., 0., 0.], dtype=float32),
    },
})
updated state:
 FrozenDict({
    batch_stats: {
        mean: DeviceArray([[0.01, 0.01, 0.01, 0.01, 0.01]], dtype=float32),
    },
})


Here, `updated_state` returns only the state variables that are being mutated by the model while applying it on data. To update the variables and get the new parameters of the model, we can use the following pattern:

In [21]:
for val in [1.0, 2.0, 3.0]:
  x = val * jnp.ones((10,5))
  y, updated_state = model.apply(variables, x, mutable=['batch_stats'])
  old_state, params = variables.pop('params')
  variables = freeze({'params': params, **updated_state})
  print('updated state:\n', updated_state) # Shows only the mutable part

updated state:
 FrozenDict({
    batch_stats: {
        mean: DeviceArray([[0.01, 0.01, 0.01, 0.01, 0.01]], dtype=float32),
    },
})
updated state:
 FrozenDict({
    batch_stats: {
        mean: DeviceArray([[0.0299, 0.0299, 0.0299, 0.0299, 0.0299]], dtype=float32),
    },
})
updated state:
 FrozenDict({
    batch_stats: {
        mean: DeviceArray([[0.059601, 0.059601, 0.059601, 0.059601, 0.059601]], dtype=float32),
    },
})


From this simplified example, you should be able to derive a full BatchNorm implementation, or any layer involving a state. To finish, let's add an optimizer to see how to play with both parameters updated by an optimizer and state variables.

*This example isn't doing anything and is only for demonstration purposes.*

In [22]:
def update_step(tx, apply_fn, x, opt_state, params, state):

  def loss(params):
    y, updated_state = apply_fn({'params': params, **state},
                                x, mutable=list(state.keys()))
    l = ((x - y) ** 2).sum()
    return l, updated_state

  (l, state), grads = jax.value_and_grad(loss, has_aux=True)(params)
  updates, opt_state = tx.update(grads, opt_state)
  params = optax.apply_updates(params, updates)
  return opt_state, params, state

x = jnp.ones((10,5))
variables = model.init(random.PRNGKey(0), x)
state, params = variables.pop('params')
del variables
tx = optax.sgd(learning_rate=0.02)
opt_state = tx.init(params)

for _ in range(3):
  opt_state, params, state = update_step(tx, model.apply, x, opt_state, params, state)
  print('Updated state: ', state)

Updated state:  FrozenDict({
    batch_stats: {
        mean: DeviceArray([[0.01, 0.01, 0.01, 0.01, 0.01]], dtype=float32),
    },
})
Updated state:  FrozenDict({
    batch_stats: {
        mean: DeviceArray([[0.0199, 0.0199, 0.0199, 0.0199, 0.0199]], dtype=float32),
    },
})
Updated state:  FrozenDict({
    batch_stats: {
        mean: DeviceArray([[0.029701, 0.029701, 0.029701, 0.029701, 0.029701]], dtype=float32),
    },
})


Note that the above function has a quite verbose signature and it would not actually
work with `jax.jit()` because the function arguments are not "valid JAX types".

Flax provides a handy wrapper, `flax.training.train_state.TrainState`, that simplifies the above code; see:

https://flax.readthedocs.io/en/latest/flax.training.html#train-state

### Exporting to TensorFlow's SavedModel with jax2tf

JAX released an experimental converter called [jax2tf](https://github.com/google/jax/tree/main/jax/experimental/jax2tf), which allows converting trained Flax models into TensorFlow's SavedModel format (so it can be used for [TF Hub](https://www.tensorflow.org/hub), [TF.lite](https://www.tensorflow.org/lite), [TF.js](https://www.tensorflow.org/js), or other downstream applications). The repository contains more documentation and has various examples for Flax.

The cells below load a pretrained Flax model, CodeBERT, through Hugging Face `transformers`, and inspect its parameters and its underlying Flax module.

In [23]:
from transformers import FlaxAutoModel

In [24]:
codebert = FlaxAutoModel.from_pretrained("microsoft/codebert-base")

In [25]:
codebert.params

{'embeddings': {'LayerNorm': {'bias': DeviceArray([-1.33444816e-01,  1.73319560e-02, -8.60523141e-04,
                 1.26775308e-03,  1.13012940e-01, -1.95150599e-02,
                 9.65326205e-02,  4.82217632e-02,  1.05802231e-01,
                -4.21788804e-02,  8.72153230e-03,  3.98519747e-02,
                 5.37393279e-02, -3.14790495e-02, -3.56927328e-02,
                 5.33316750e-03, -7.08508445e-03,  8.25870112e-02,
                -5.61116524e-02,  1.84450543e-03,  3.05853486e-02,
                 7.22988183e-03,  2.34760251e-02,  4.56156489e-03,
                -1.73515100e-02,  1.37515888e-02, -9.71781909e-02,
                -4.63355333e-02, -3.19514684e-02,  6.35724962e-02,
                 7.57761970e-02,  2.70680953e-02, -5.68341650e-02,
                -5.13008200e-02,  3.54870707e-02, -1.42550012e-02,
                -1.06799910e-02, -2.42758412e-02, -1.70434520e-01,
                 1.02982009e-02, -1.14280116e-02,  5.66726318e-03,
                -3.34911831

In [26]:
codebert.module_class

transformers.models.roberta.modeling_flax_roberta.FlaxRobertaModule

In [27]:
codebert.module

FlaxRobertaModule(
    # attributes
    config = RobertaConfig {
      "_name_or_path": "microsoft/codebert-base",
      "architectures": [
        "RobertaModel"
      ],
      "attention_probs_dropout_prob": 0.1,
      "bos_token_id": 0,
      "classifier_dropout": null,
      "eos_token_id": 2,
      "hidden_act": "gelu",
      "hidden_dropout_prob": 0.1,
      "hidden_size": 768,
      "initializer_range": 0.02,
      "intermediate_size": 3072,
      "layer_norm_eps": 1e-05,
      "max_position_embeddings": 514,
      "model_type": "roberta",
      "num_attention_heads": 12,
      "num_hidden_layers": 12,
      "output_past": true,
      "pad_token_id": 1,
      "position_embedding_type": "absolute",
      "transformers_version": "4.20.1",
      "type_vocab_size": 1,
      "use_cache": true,
      "vocab_size": 50265
    }
    
    dtype = float32
    add_pooling_layer = True
)

In [28]:
type(codebert)

transformers.models.roberta.modeling_flax_roberta.FlaxRobertaModel