[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/flax/blob/main/docs/notebooks/annotated_mnist.ipynb)
[![Open On GitHub](https://img.shields.io/badge/Open-on%20GitHub-blue?logo=GitHub)](https://github.com/google/flax/blob/main/docs/notebooks/annotated_mnist.ipynb)

# MNIST 🌰


## 1. Imports

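Import JAX, [JAX NumPy](https://jax.readthedocs.io/en/latest/jax.numpy.html),
Flax, ordinary NumPy, Optax, and TensorFlow Datasets (TFDS).

Flax does not depend on any particular data-loading library. This example uses TFDS, but you can equally load data with PyTorch, plain NumPy, or anything else.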

```python
!pip3 install -q flax tensorflow_datasets tensorflow
```

```python
import jax
import jax.numpy as jnp                # JAX NumPy

from flax import linen as nn           # The Linen API
from flax.training import train_state  # Useful dataclass to keep train state

import numpy as np                     # Ordinary NumPy
import optax                           # Optimizers
import tensorflow_datasets as tfds     # TFDS for MNIST
```



## 2. Define the network

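As in PyTorch, you define a network by subclassing Flax's
[Module](https://flax.readthedocs.io/en/latest/flax.linen.html#core-module-abstraction) class.

The network in this example is very simple, so its submodules can be declared inline inside `__call__`, which is decorated with
[@compact](https://flax.readthedocs.io/en/latest/flax.linen.html#compact-methods). Note: a Module class can have at most one method decorated with `@compact`.


For more complex networks, you can instead declare the submodules in a `setup()` method and configure the module through dataclass fields.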

```python
class CNN(nn.Module):
  """A simple CNN model."""

  @nn.compact
  def __call__(self, x):
    x = nn.Conv(features=32, kernel_size=(3, 3))(x)
    x = nn.relu(x)
    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = nn.Conv(features=64, kernel_size=(3, 3))(x)
    x = nn.relu(x)
    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = x.reshape((x.shape[0], -1))  # flatten
    x = nn.Dense(features=256)(x)
    x = nn.relu(x)
    x = nn.Dense(features=10)(x)
    return x
```

## 3. Define the loss

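Use the cross-entropy loss `optax.softmax_cross_entropy()`. It expects `logits` and `labels` to both have shape `[batch, num_classes]`, so the integer labels must first be converted to one-hot vectors. It returns one loss value per example, which is then averaged over the batch with `.mean()`.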


```python
def cross_entropy_loss(logits, labels):
  """Mean softmax cross-entropy over the batch."""
  # Integer labels of shape [batch] -> one-hot of shape [batch, num_classes]
  labels_onehot = jax.nn.one_hot(labels, num_classes=10)
  return optax.softmax_cross_entropy(logits=logits, labels=labels_onehot).mean()
```

## 4. Metric computation

For loss and accuracy metrics, create a separate function:

```python
def compute_metrics(logits, labels):
  loss = cross_entropy_loss(logits=logits, labels=labels)
  # Fraction of examples whose highest-scoring class equals the label
  accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
  metrics = {
      'loss': loss,
      'accuracy': accuracy,
  }
  return metrics
```
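The accuracy line can be traced on a toy batch of three examples (values made up for illustration): `argmax` picks the predicted class per row, and the mean of the boolean comparison is the fraction of correct predictions.

```python
import jax.numpy as jnp

logits = jnp.array([[0.1, 2.0, -1.0],   # predicts class 1
                    [3.0, 0.0, 0.5],    # predicts class 0
                    [0.2, 0.1, 0.9]])   # predicts class 2
labels = jnp.array([1, 0, 0])

predictions = jnp.argmax(logits, -1)         # index of the largest logit per row
accuracy = jnp.mean(predictions == labels)   # booleans averaged -> 2 of 3 correct
print(predictions, accuracy)
```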

## 5. Loading data

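Define a function that loads the MNIST train and test splits into memory and converts the
images to floating-point numbers scaled to [0, 1]. Passing `batch_size=-1` makes TFDS return each whole split as a single batch.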

```python
def get_datasets():
  """Load MNIST train and test datasets into memory."""
  ds_builder = tfds.builder('mnist')
  ds_builder.download_and_prepare()
  train_ds = tfds.as_numpy(ds_builder.as_dataset(split='train', batch_size=-1))
  test_ds = tfds.as_numpy(ds_builder.as_dataset(split='test', batch_size=-1))
  train_ds['image'] = jnp.float32(train_ds['image']) / 255.
  test_ds['image'] = jnp.float32(test_ds['image']) / 255.
  return train_ds, test_ds
```
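Each split comes back as a dict with an `'image'` and a `'label'` array. To try the later cells without downloading MNIST, a hypothetical stand-in with the same structure can be built from random pixels. `fake_mnist` and the sizes below are illustrative assumptions, and because its labels are random, any accuracy measured on it is meaningless.

```python
import numpy as np
import jax.numpy as jnp


def fake_mnist(num_examples, seed=0):
  """Hypothetical stand-in mimicking the dict returned by get_datasets()."""
  rs = np.random.RandomState(seed)
  # MNIST images: 28x28 pixels, one channel, uint8 values in [0, 255].
  images = rs.randint(0, 256, size=(num_examples, 28, 28, 1)).astype(np.uint8)
  labels = rs.randint(0, 10, size=(num_examples,)).astype(np.int64)
  # Same normalization as get_datasets(): cast to float32, scale to [0, 1].
  return {'image': jnp.float32(images) / 255., 'label': labels}


fake_train_ds = fake_mnist(256)
print(fake_train_ds['image'].shape, fake_train_ds['image'].dtype,
      fake_train_ds['label'].shape)
```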

## 6. Create train state

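This step is key.


Because JAX follows a functional programming style, it is recommended to bundle all of the training state, such as the step number, parameters, and optimizer state, into a single dataclass. Flax provides the class [flax.training.train_state.TrainState](https://flax.readthedocs.io/en/latest/flax.training.html#train-state)
for this purpose. You can subclass it to track additional data, but this example is simple enough to use `TrainState` unmodified.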

```python
def create_train_state(rng, learning_rate, momentum):
  """Creates initial `TrainState`."""
  cnn = CNN()
  # A dummy batch of one 28x28x1 image lets Flax infer every parameter shape.
  # init() returns all variable collections; keep only the 'params' collection.
  params = cnn.init(rng, jnp.ones([1, 28, 28, 1]))['params']
  tx = optax.sgd(learning_rate, momentum)
  return train_state.TrainState.create(
      apply_fn=cnn.apply,
      params=params,
      tx=tx)
```

## 7. Training step

This is the central function: it performs one complete training step, which:

- Evaluates the neural network given the parameters and a batch of input images
  with the
  [Module.apply](https://flax.readthedocs.io/en/latest/flax.linen.html#flax.linen.Module.apply)
  method.
- Computes the `cross_entropy_loss` loss function.
- Evaluates the loss function and its gradient using
  [jax.value_and_grad](https://jax.readthedocs.io/en/latest/jax.html#jax.value_and_grad).
- Applies a
  [pytree](https://jax.readthedocs.io/en/latest/pytrees.html#pytrees-and-jax-functions)
  of gradients to the optimizer to update the model's parameters.
- Computes the metrics using `compute_metrics` (defined earlier).

Use JAX's [@jit](https://jax.readthedocs.io/en/latest/jax.html#jax.jit) decorator to JIT-compile the whole step with XLA, which improves execution speed.
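The `has_aux=True` pattern used in `train_step` can be seen in isolation on a toy linear model (the function and data below are illustrative, not part of the MNIST example). The differentiated function returns a `(loss, aux)` pair, and the transformed function returns `((loss, aux), grads)`, which is how `train_step` unpacks `(_, logits), grads`.

```python
import jax
import jax.numpy as jnp


def loss_fn(w, x, y):
  """Toy squared-error loss that also returns its predictions as auxiliary data."""
  preds = x @ w
  loss = jnp.mean((preds - y) ** 2)
  return loss, preds  # (value to differentiate, auxiliary output)


# has_aux=True: differentiate only the first output, pass the second through.
# jax.jit compiles the combined value-and-gradient computation with XLA.
grad_fn = jax.jit(jax.value_and_grad(loss_fn, has_aux=True))

w = jnp.array([1.0, -1.0])
x = jnp.array([[1.0, 2.0], [3.0, 4.0]])
y = jnp.array([0.0, 0.0])
(loss, preds), grads = grad_fn(w, x, y)  # grads has the same shape as w
print(loss, preds, grads)
```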

```python
@jax.jit
def train_step(state, batch):
  """Train for a single step."""
  def loss_fn(params):
    # CNN() holds no parameters; they are passed in explicitly, so apply() is a pure function
    logits = CNN().apply({'params': params}, batch['image'])
    loss = cross_entropy_loss(logits=logits, labels=batch['label'])
    return loss, logits
  # has_aux=True: loss_fn returns (loss, logits); only the loss is differentiated
  grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
  (_, logits), grads = grad_fn(state.params)
  # Applies the optimizer update to params and opt_state, and increments state.step
  state = state.apply_gradients(grads=grads)
  metrics = compute_metrics(logits=logits, labels=batch['label'])
  return state, metrics
```

## 8. Evaluation step

Create a function that evaluates your model on the test set with
[Module.apply](https://flax.readthedocs.io/en/latest/flax.linen.html#flax.linen.Module.apply).

```python
@jax.jit
def eval_step(params, batch):
  logits = CNN().apply({'params': params}, batch['image'])
  return compute_metrics(logits=logits, labels=batch['label'])
```

## 9. Train function

Define a function that trains for one epoch. It:

- Shuffles the training data before each epoch with
  [jax.random.permutation](https://jax.readthedocs.io/en/latest/_autosummary/jax.random.permutation.html).
- Runs an optimization step for each batch.
- Retrieves each minibatch's metrics from the device with `jax.device_get` and averages them into epoch-level metrics.
- Prints the epoch's training loss and accuracy, and returns the train state with updated
  parameters.
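The shuffling and batching logic can be traced on a toy dataset of 10 examples with `batch_size=3` (sizes chosen for illustration). The permutation is truncated to 9 indices, reshaped into 3 rows of 3, and each row selects one batch via `v[perm, ...]`, which picks those rows along axis 0 and keeps every remaining axis.

```python
import jax
import numpy as np

train_ds_size, batch_size = 10, 3
steps_per_epoch = train_ds_size // batch_size            # 3 full batches

perms = jax.random.permutation(jax.random.PRNGKey(0), train_ds_size)  # shuffled 0..9
perms = perms[:steps_per_epoch * batch_size]             # drop the incomplete 4th batch
perms = perms.reshape((steps_per_epoch, batch_size))     # one row of indices per batch

# Toy dataset in the same dict format; 'image' has extra trailing axes.
ds = {'image': np.arange(10 * 2 * 2).reshape(10, 2, 2), 'label': np.arange(10)}
batches = []
for perm in perms:
  # v[perm, ...] selects rows `perm` along axis 0 and keeps all other axes.
  batch = {k: v[perm, ...] for k, v in ds.items()}
  batches.append(batch)
  print(batch['label'], batch['image'].shape)
```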

A quick check of what `jax.random.permutation` returns:

```python
x = jnp.array([[1, 2, 3], [4, 5, 6]])
rng = jax.random.PRNGKey(0)
y = jax.random.permutation(rng, 6)  # a shuffled arange(6)
print(type(y), type(x))  # both are JAX arrays
print(y)
```

```python
def train_epoch(state, train_ds, batch_size, epoch, rng):
  """Train for a single epoch."""
  train_ds_size = len(train_ds['image'])
  steps_per_epoch = train_ds_size // batch_size

  # perms is a shuffled arange(train_ds_size)
  perms = jax.random.permutation(rng, train_ds_size)
  # Skip the incomplete last batch so every batch has the same shape;
  # a new input shape would force the jitted train_step to recompile.
  perms = perms[:steps_per_epoch * batch_size]
  perms = perms.reshape((steps_per_epoch, batch_size))
  batch_metrics = []
  for perm in perms:
    # v[perm, ...] selects the rows listed in perm along axis 0 and keeps all other axes
    batch = {k: v[perm, ...] for k, v in train_ds.items()}
    state, metrics = train_step(state, batch)
    batch_metrics.append(metrics)

  # compute mean of metrics across each batch in epoch.
  batch_metrics_np = jax.device_get(batch_metrics)
  epoch_metrics_np = {
      k: np.mean([metrics[k] for metrics in batch_metrics_np])
      for k in batch_metrics_np[0]}

  print('train epoch: %d, loss: %.4f, accuracy: %.2f' % (
      epoch, epoch_metrics_np['loss'], epoch_metrics_np['accuracy'] * 100))

  return state
```

## 10. Eval function

Create a model evaluation function that:

- Evaluates the whole test set with a single call to `eval_step`.
- Retrieves the evaluation metrics from the device with `jax.device_get`.
- Copies the metrics
  [data stored](https://flax.readthedocs.io/en/latest/design_notes/linen_design_principles.html#how-are-parameters-represented-and-how-do-we-handle-general-differentiable-algorithms-that-update-stateful-variables)
  in a JAX
  [pytree](https://jax.readthedocs.io/en/latest/pytrees.html#pytrees-and-jax-functions)
  into plain Python numbers.
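The conversion step can be seen on a hand-made metrics dict (the values are made up). `jax.tree_util.tree_map` applies `.item()` to every leaf of the pytree, turning 0-d arrays into Python floats while keeping the dict structure.

```python
import jax
import jax.numpy as jnp

# Stand-in for what eval_step returns: a dict (pytree) of 0-d device arrays.
metrics = {'loss': jnp.array(0.25), 'accuracy': jnp.array(0.5)}

metrics = jax.device_get(metrics)                              # copy leaves to host
summary = jax.tree_util.tree_map(lambda x: x.item(), metrics)  # 0-d arrays -> floats
print(summary)
```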

```python
def eval_model(params, test_ds):
  # The whole test set (10,000 images) is evaluated as a single batch
  metrics = eval_step(params, test_ds)
  metrics = jax.device_get(metrics)
  summary = jax.tree_util.tree_map(lambda x: x.item(), metrics)
  return summary['loss'], summary['accuracy']
```

## 11. Download data

```python
train_ds, test_ds = get_datasets()
```

## 12. Seed randomness

- Get one
  [PRNGKey](https://jax.readthedocs.io/en/latest/_autosummary/jax.random.PRNGKey.html#jax.random.PRNGKey)
  and
  [split](https://jax.readthedocs.io/en/latest/_autosummary/jax.random.split.html#jax.random.split)
  it to get a second key that you'll use for parameter initialization. (Learn
  more about
  [PRNG chains](https://flax.readthedocs.io/en/latest/design_notes/linen_design_principles.html#how-are-parameters-represented-and-how-do-we-handle-general-differentiable-algorithms-that-update-stateful-variables)
  and
  [JAX PRNG design](https://github.com/google/jax/blob/main/design_notes/prng.md).)
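A short sketch of the splitting pattern: the same seed always yields the same chain of keys, so runs are reproducible, while the two keys returned by one split drive independent random streams. The `normal` draws are only there to show that the keys differ.

```python
import jax
import jax.numpy as jnp

rng = jax.random.PRNGKey(0)
rng, init_rng = jax.random.split(rng)  # replace rng and get a fresh init key

# Same seed -> same keys: the whole chain is reproducible.
rng2, init_rng2 = jax.random.split(jax.random.PRNGKey(0))

# Different keys give independent samples.
a = jax.random.normal(init_rng, (3,))
b = jax.random.normal(rng, (3,))
print(a, b)
```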

```python
rng = jax.random.PRNGKey(0)
rng, init_rng = jax.random.split(rng)
```

## 13. Initialize train state

Remember that `create_train_state` initializes both the model parameters and the optimizer
state, and puts both into the `TrainState` dataclass that it returns.

```python
learning_rate = 0.1
momentum = 0.9
```

```python
state = create_train_state(init_rng, learning_rate, momentum)
del init_rng  # Must not be used anymore.
```

## 14. Train and evaluate

Once training and testing are done after 10 epochs, the output should show that your model was able to achieve approximately 99% accuracy.

```python
num_epochs = 10
batch_size = 32
```

```python
for epoch in range(1, num_epochs + 1):
  # Use a separate PRNG key to permute image data during shuffling
  rng, input_rng = jax.random.split(rng)
  # Run one epoch of optimization steps over the training set
  state = train_epoch(state, train_ds, batch_size, epoch, input_rng)
  # Evaluate on the test set after each training epoch
  test_loss, test_accuracy = eval_model(state.params, test_ds)
  print(' test epoch: %d, loss: %.2f, accuracy: %.2f' % (
      epoch, test_loss, test_accuracy * 100))
```

Congrats! You made it to the end of the annotated MNIST example. You can revisit
the same example, but structured differently as a couple of Python modules, test
modules, config files, another Colab, and documentation in Flax's Git repo:

https://github.com/google/flax/tree/main/examples/mnist