<a href="https://colab.research.google.com/github/rickqiu/flaxnn/blob/main/FlaxCNN_vs_KerasCNN_on_MNIST.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Flax CNN trains ~2x faster than Keras CNN on MNIST

This notebook compares the training time of a Flax CNN and a Keras CNN on the MNIST dataset, using the same number of epochs, batch size, and learning rate for both.

Before running, select a GPU runtime via **Runtime → Change runtime type**.

In [None]:
# Install the flax package
!pip install -q flax

[K     |████████████████████████████████| 202 kB 19.2 MB/s 
[K     |████████████████████████████████| 145 kB 34.5 MB/s 
[K     |████████████████████████████████| 596 kB 53.0 MB/s 
[K     |████████████████████████████████| 9.1 MB 60.3 MB/s 
[K     |████████████████████████████████| 217 kB 10.7 MB/s 
[K     |████████████████████████████████| 51 kB 5.7 MB/s 
[K     |████████████████████████████████| 72 kB 617 kB/s 
[?25h

In [None]:
# Import libraries
import jax
import jax.numpy as jnp                # JAX NumPy

from flax import linen as nn           # The Linen API
from flax.training import train_state  # Useful dataclass to keep train state
import optax                           # Optimizers

import numpy as np                     
import tensorflow as tf 
import tensorflow_datasets as tfds                
from tensorflow import keras           

import time                           

In [None]:
# Training settings shared by the Keras and the Flax runs:
# number of passes over the training set, and examples per optimization step.
NUM_EPOCH, BATCH_SIZE = 10, 32

## Keras CNN Model

First, we train a Keras CNN as the baseline model.

In [None]:
# Set up a Keras CNN with the same layer stack as the Flax CNN defined below.
# - padding="same": Flax's nn.Conv defaults to 'SAME' padding, whereas Keras
#   Conv2D defaults to "valid". Matching it gives both networks a 7x7x64
#   feature map before Flatten (28 -> 14 -> 7 after the two 2x2 poolings),
#   so both frameworks train the same-sized model and the timing comparison is fair.
# - No softmax on the last layer: the loss is compiled with from_logits=True,
#   which applies softmax itself, so the model must output raw logits (as the
#   Flax CNN does). A softmax here would be applied twice.
model = keras.Sequential(
    [
        keras.Input(shape=(28, 28, 1)),
        keras.layers.Conv2D(32, kernel_size=(3, 3), padding="same", activation="relu"),
        keras.layers.AveragePooling2D(pool_size=(2, 2)),
        keras.layers.Conv2D(64, kernel_size=(3, 3), padding="same", activation="relu"),
        keras.layers.AveragePooling2D(pool_size=(2, 2)),
        keras.layers.Flatten(),
        keras.layers.Dense(256, activation="relu"),
        keras.layers.Dense(10),  # raw logits, one per digit class
    ]
)

In [None]:
# Load the MNIST train/test splits with tfds as (image, label) tuples,
# together with the dataset metadata (used below for the shuffle buffer size).
(ds_train, ds_test), ds_info = tfds.load(
    'mnist',
    with_info=True,
    as_supervised=True,
    shuffle_files=True,
    split=['train', 'test'],
)

[1mDownloading and preparing dataset mnist/3.0.1 (download: 11.06 MiB, generated: 21.00 MiB, total: 32.06 MiB) to /root/tensorflow_datasets/mnist/3.0.1...[0m


local data directory. If you'd instead prefer to read directly from our public
GCS bucket (recommended if you're running on GCP), you can instead pass
`try_gcs=True` to `tfds.load` or set `data_dir=gs://tfds-data/datasets`.



Dl Completed...:   0%|          | 0/4 [00:00<?, ? file/s]


[1mDataset mnist downloaded and prepared to /root/tensorflow_datasets/mnist/3.0.1. Subsequent calls will reuse this data.[0m


In [None]:
# Preprocessing shared by the training and test pipelines
def normalize_img(image, label):
  """Cast a `uint8` image to `float32` scaled to [0, 1]; pass the label through."""
  scaled_image = tf.cast(image, tf.float32) / 255.
  return scaled_image, label

# Training pipeline: normalize, cache the decoded examples, shuffle with a
# buffer as large as the whole training split, batch, and prefetch.
ds_train = (
    ds_train
    .map(normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
    .cache()
    .shuffle(ds_info.splits['train'].num_examples)
    .batch(BATCH_SIZE)
    .prefetch(tf.data.AUTOTUNE)
)

In [None]:
# Test pipeline: normalize and batch, then cache the batches and prefetch.
# No shuffling is needed for evaluation.
ds_test = (
    ds_test
    .map(normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(BATCH_SIZE)
    .cache()
    .prefetch(tf.data.AUTOTUNE)
)

In [None]:
# Compile the model with Adam (learning rate 0.01, same as the Flax run).
# from_logits=True makes the loss apply softmax itself, so the model's final
# layer is expected to output raw logits rather than probabilities.
model.compile(
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.01),
)

In [None]:
# Train the Keras CNN for NUM_EPOCH epochs (validating on the test set after
# each epoch) and measure the elapsed time.
# time.perf_counter() is a monotonic, sub-second clock intended for measuring
# durations, so the result is neither rounded to whole seconds nor skewed by
# wall-clock adjustments.

# START TIME
start = time.perf_counter()

history = model.fit(ds_train, epochs=NUM_EPOCH, validation_data=ds_test)

end = time.perf_counter()
# END TIME

print(f' Duration: {end - start:.1f}s')

Epoch 1/10


  return dispatch_target(*args, **kwargs)


Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
 Duration: 124s


## Flax CNN Model

Next, we train a Flax CNN and compare its training time with that of the Keras CNN.

In [None]:
# Set up a Flax CNN with 6 layers
class CNN(nn.Module):
  """A simple CNN: two conv/ReLU/average-pool stages followed by two dense layers.

  The final Dense layer returns unnormalized logits (no softmax).
  """

  @nn.compact
  def __call__(self, x):
    # Two convolutional stages with 32 and 64 filters. 'SAME' padding is
    # Flax's default and is spelled out here to make the architecture explicit.
    for num_filters in (32, 64):
      x = nn.Conv(features=num_filters, kernel_size=(3, 3), padding='SAME')(x)
      x = nn.relu(x)
      x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    # Flatten everything except the leading batch axis.
    x = x.reshape((x.shape[0], -1))
    hidden = nn.relu(nn.Dense(features=256)(x))
    return nn.Dense(features=10)(hidden)

In [None]:
def cross_entropy_loss(*, logits, labels):
  """Mean softmax cross-entropy between `logits` and integer class `labels`.

  Args:
    logits: unnormalized scores; the last axis indexes the classes.
    labels: integer class indices with shape `logits.shape[:-1]`.

  Returns:
    The cross-entropy averaged over all examples (a scalar).
  """
  # Take the class count from the logits instead of hard-coding 10, so the
  # one-hot width always matches the model's output layer.
  labels_onehot = jax.nn.one_hot(labels, num_classes=logits.shape[-1])
  return optax.softmax_cross_entropy(logits=logits, labels=labels_onehot).mean()

In [None]:
def compute_metrics(*, logits, labels):
  """Return a dict with the batch's mean cross-entropy 'loss' and 'accuracy'."""
  predicted_classes = jnp.argmax(logits, -1)
  return {
      'loss': cross_entropy_loss(logits=logits, labels=labels),
      'accuracy': jnp.mean(predicted_classes == labels),
  }

In [None]:
def get_datasets():
  """Load the full MNIST train and test splits into memory.

  Each split is returned as a dict of arrays (as produced by
  `tfds.as_numpy(..., batch_size=-1)`), with 'image' converted to float32
  and scaled by 1/255.

  Returns:
    (train_ds, test_ds)
  """
  builder = tfds.builder('mnist')
  builder.download_and_prepare()
  loaded_splits = []
  for split_name in ('train', 'test'):
    split_ds = tfds.as_numpy(builder.as_dataset(split=split_name, batch_size=-1))
    split_ds['image'] = jnp.float32(split_ds['image']) / 255.
    loaded_splits.append(split_ds)
  train_ds, test_ds = loaded_splits
  return train_ds, test_ds

In [None]:
def create_train_state(rng, learning_rate):
  """Create the initial `TrainState` for a freshly initialized CNN.

  Parameters are initialized from `rng` using a dummy batch of one
  28x28x1 image; the optimizer is Adam with the given `learning_rate`.
  """
  net = CNN()
  dummy_batch = jnp.ones([1, 28, 28, 1])
  variables = net.init(rng, dummy_batch)
  return train_state.TrainState.create(
      apply_fn=net.apply,
      params=variables['params'],
      tx=optax.adam(learning_rate),
  )

In [None]:
@jax.jit
def train_step(state, batch):
  """Apply one optimizer update on `batch`; return the new state and batch metrics."""
  labels = batch['label']

  def loss_fn(params):
    # `state.apply_fn` is the CNN's `apply`, stored by `create_train_state`.
    logits = state.apply_fn({'params': params}, batch['image'])
    return cross_entropy_loss(logits=logits, labels=labels), logits

  (_, logits), grads = jax.value_and_grad(loss_fn, has_aux=True)(state.params)
  updated_state = state.apply_gradients(grads=grads)
  return updated_state, compute_metrics(logits=logits, labels=labels)

In [None]:
@jax.jit
def eval_step(params, batch):
  """Run the CNN forward on `batch` and return its loss/accuracy metrics."""
  variables = {'params': params}
  logits = CNN().apply(variables, batch['image'])
  return compute_metrics(logits=logits, labels=batch['label'])

In [None]:
def train_epoch(state, train_ds, batch_size, epoch, rng):
  """Train for a single epoch over `train_ds` in shuffled mini-batches.

  Args:
    state: current `TrainState`.
    train_ds: dict of arrays keyed by feature name ('image', 'label', ...),
      all indexed by example along their leading axis.
    batch_size: number of examples per optimization step.
    epoch: epoch number, used only in the progress printout.
    rng: PRNG key used to shuffle the example order.

  Returns:
    The updated `TrainState` after one pass over the complete batches.

  Raises:
    ValueError: if `batch_size` is not positive or `train_ds` holds fewer
      than `batch_size` examples (no complete batch can be formed).
  """
  if batch_size <= 0:
    raise ValueError(f'batch_size must be positive, got {batch_size}')
  train_ds_size = len(train_ds['image'])
  # Use the `batch_size` argument (not the global BATCH_SIZE) so the step
  # count always agrees with the reshape below.
  steps_per_epoch = train_ds_size // batch_size
  if steps_per_epoch == 0:
    raise ValueError(
        f'train_ds has {train_ds_size} examples, fewer than batch_size={batch_size}')

  perms = jax.random.permutation(rng, train_ds_size)
  perms = perms[:steps_per_epoch * batch_size]  # skip incomplete batch
  perms = perms.reshape((steps_per_epoch, batch_size))
  batch_metrics = []
  for perm in perms:
    batch = {k: v[perm, ...] for k, v in train_ds.items()}
    state, metrics = train_step(state, batch)
    batch_metrics.append(metrics)

  # Compute the mean of each metric across all batches in the epoch.
  batch_metrics_np = jax.device_get(batch_metrics)
  epoch_metrics_np = {
      k: np.mean([metrics[k] for metrics in batch_metrics_np])
      for k in batch_metrics_np[0]}

  print('train epoch: %d, loss: %.4f, accuracy: %.2f' % (
      epoch, epoch_metrics_np['loss'], epoch_metrics_np['accuracy'] * 100))

  return state

In [None]:
def eval_model(params, test_ds):
  """Evaluate `params` on the whole `test_ds` as a single batch.

  Returns:
    (loss, accuracy) converted to Python scalars.
  """
  host_metrics = jax.device_get(eval_step(params, test_ds))
  summary = jax.tree_util.tree_map(lambda value: value.item(), host_metrics)
  return summary['loss'], summary['accuracy']

In [None]:
# Load both MNIST splits fully into memory for the Flax training loop.
(train_ds, test_ds) = get_datasets()

Instructions for updating:
Use `tf.data.Dataset.get_single_element()`.


Instructions for updating:
Use `tf.data.Dataset.get_single_element()`.


In [None]:
# Root PRNG key from seed 0; split off a dedicated key for parameter init.
rng, init_rng = jax.random.split(jax.random.PRNGKey(0))

In [None]:
# Same learning rate (0.01) as the Keras Adam optimizer.
state = create_train_state(rng=init_rng, learning_rate=0.01)
del init_rng  # consumed by initialization; must not be reused

In [None]:
# Train the Flax CNN for NUM_EPOCH epochs (evaluating on the test set after
# each epoch) and measure the elapsed time.
# time.perf_counter() is a monotonic, sub-second clock intended for measuring
# durations, so the result is neither rounded to whole seconds nor skewed by
# wall-clock adjustments.

# START TIME
start = time.perf_counter()

for epoch in range(1, NUM_EPOCH + 1):
  # Use a fresh PRNG key each epoch to shuffle the training data.
  rng, input_rng = jax.random.split(rng)
  # Run one full pass (epoch) of optimization steps over the training set.
  state = train_epoch(state, train_ds, BATCH_SIZE, epoch, input_rng)
  # Evaluate on the test set after each training epoch. eval_model copies the
  # metrics to the host (jax.device_get), which waits for the queued device
  # work, so the timer measures completed computation.
  test_loss, test_accuracy = eval_model(state.params, test_ds)
  # Same loss precision as the per-epoch training printout.
  print(' test epoch: %d, loss: %.4f, accuracy: %.2f' % (epoch, test_loss, test_accuracy * 100))

end = time.perf_counter()
# END TIME
print(f' Duration: {end - start:.1f}s')

train epoch: 1, loss: 0.1367, accuracy: 95.84
 test epoch: 1, loss: 0.06, accuracy: 98.04
train epoch: 2, loss: 0.0658, accuracy: 98.02
 test epoch: 2, loss: 0.06, accuracy: 98.45
train epoch: 3, loss: 0.0542, accuracy: 98.40
 test epoch: 3, loss: 0.08, accuracy: 97.70
train epoch: 4, loss: 0.0455, accuracy: 98.66
 test epoch: 4, loss: 0.05, accuracy: 98.51
train epoch: 5, loss: 0.0480, accuracy: 98.69
 test epoch: 5, loss: 0.06, accuracy: 98.40
train epoch: 6, loss: 0.0421, accuracy: 98.81
 test epoch: 6, loss: 0.07, accuracy: 98.46
train epoch: 7, loss: 0.0370, accuracy: 99.01
 test epoch: 7, loss: 0.08, accuracy: 98.39
train epoch: 8, loss: 0.0389, accuracy: 98.98
 test epoch: 8, loss: 0.09, accuracy: 98.21
train epoch: 9, loss: 0.0343, accuracy: 99.18
 test epoch: 9, loss: 0.09, accuracy: 98.23
train epoch: 10, loss: 0.0402, accuracy: 99.02
 test epoch: 10, loss: 0.08, accuracy: 98.42
 Duration: 64s
