## Let's build a ResNet!

Now that we have an understanding of JAX, and how to build neural network layers and optimizations in Flax, let's put our skills to implementing one of the most cited academic [publications](https://arxiv.org/pdf/1512.03385.pdf) in machine learning: Deep Residual Learning for Image Recognition.

### The idea

The paper authors found that after batchnorm, networks with more layers often performed worse that those with fewer.

![Neural net errors](https://raw.githubusercontent.com/fastai/fastbook/823b69e00aa1e1c1a45fe88bd346f11e8f89c1ff//images/att_00042.png)
(Image from paper linked above by Kaimimg He and others).

The researchers experimented with the idea of adding extra layers as an 'identity mapping', which means they have parameters and are trainable, but which return inputs without changing them.

This idea resulted in 'skip connections', which leap over convolutions as in this diagram:

![Skip connections](https://raw.githubusercontent.com/fastai/fastbook/823b69e00aa1e1c1a45fe88bd346f11e8f89c1ff//images/att_00043.png)

The skip connections make the network make the larger architecture easier to train, and prevent overfitting.

#### Residuals

A 'residual' is basically a prediction minus the target. ResNet blocks beat most earlier benchmarks because rather than asking them to predict the target, they predict the *difference* between the target and the prediction. This architecture proved very strong in detecting slight differences in images (is it a wolf or an off-leash Husky running through a dark forest?).

In [23]:
from functools import partial

from flax import linen as nn
import jax.numpy as jnp
from typing import Any, Callable, Sequence, Tuple

In [43]:
class ResNetBlock(nn.Module):
  filters: int
  strides: Tuple[int, int] = (1, 1)

  @nn.compact
  def __call__(self, x):
    residual = x
    y = nn.Conv(self.filters, (3, 3), self.strides)(x)
    y = nn.BatchNorm()(y)
    y = nn.relu(y)
    y = nn.Conv(self.filters, (3, 3))(y)
    return x + y

    if residual.shape != y.shape:
          residual = nn.conv(self.filters, (1, 1),
                              self.strides, name='conv_proj')(residual)
          residual = self.norm(name='norm_proj')(residual)

    return self.act(residual + y)



### Bottleneck layers

To enable training deeper models without spiking memory and computation use, we can use bottleneck layers. These were also introduced in the original paper as suitable for ResNets with a depth of 50 or more layers.

In our original ResNet layer, we have two convolutions with kernel size 3. Bottleneck layers use a 1 x 1 convolution at the start and end, and a 3 x 3 layer in between.

![Bottleneck layers](https://raw.githubusercontent.com/fastai/fastbook/823b69e00aa1e1c1a45fe88bd346f11e8f89c1ff//images/att_00045.png)

These improve training when used with deeper models since they allow us to add more filters. Filters mean we can reduce the color channels of images, then restore them -- hence their name.

In [44]:
class BottleneckResNetBlock(nn.Module):
  filters: int
  strides: Tuple[int, int] = (1, 1)

  @nn.compact
  def __call__(self, x):
    residual = x
    y = nn.Conv(self.filters, (1, 1))(x)
    y = nn.BatchNorm()(y)
    y = nn.relu(y)
    y = nn.Conv(self.filters, (3, 3), self.strides)(y)
    y = nn.BatchNorm()(y)
    y = nn.relu(y)
    y = nn.Conv(self.filters * 4, (1, 1), self.strides)(y)
    y = nn.BatchNorm(scale_init=nn.initializers.zeros_init())(y)

    if residual.shape != y.shape:
          residual = nn.conv(self.filters * 4, kernel_size=(1, 1),
                              strides=(1, 1), name='conv_proj')(residual)
          residual = self.norm(name='norm_proj')(residual)

    return nn.relu(residual + y)




### Creating the ResNet

Now we have our blocks, we can stack them together to create a fully-fledged ResNet.

Blocks are generally grouped by shape, so if our model has [2,2,2,2] blocks, it means we have four groups of 2 blocks.

We will use the original `ResNetBlock` for those < 50 layers, and the `BottleneckResNetBlock` for larger architectures.

The various ResNet sizes (ResNet18, ResNet50 etc) simply denote the number of layers.

A ResNet 18's blocks are [2,2,2,2], so how do we get to 18 layers?

We have 1 initial conv layer, 8 conv layers in residual blocks, a final conv layer, which gives us 10 layers. The remaining 8 are fully connected layers that follow the last conv layer, which typically serve as the classifier head and output predictions.

In [42]:
from functools import partial

class ResNet(nn.Module):
  num_classes: int
  block_class: nn.Module
  num_blocks: Sequence[int]
  filters: int = 64
  dtype: Any = jnp.float32


  @nn.compact
  def __call__(self, x, train: bool = True):
    x = nn.Conv(self.filters, (7, 7), (2, 2),
             padding=[(3, 3), (3, 3)],
             use_bias=False,
             name='conv_init')(x)
    x = nn.BatchNorm(name='bn_init')(x, use_running_average=not train)
    x = nn.relu(x)
    x = nn.max_pool(x, (3, 3), (2, 2), padding='SAME')

    for i, block_size in enumerate(self.num_blocks):
      for j in range(block_size):
        strides = (2, 2) if i > 0 and j == 0 else (1, 1)
        x = self.block_class(self.filters * (2**i),
                             strides=strides,
                             name=f'block_{i}_{j}')(x)

    x = jnp.mean(x, axis=(1, 2))
    x = nn.Dense(self.num_classes, dtype=self.dtype)(x)
    x = jnp.asarray(x, self.dtype)

    return x

ResNet18 = partial(ResNet, num_blocks=[2, 2, 2, 2],
                   block_class=ResNetBlock)
ResNet34 = partial(ResNet, num_blocks=[3, 4, 6, 3],
                   block_class=ResNetBlock)
ResNet50 = partial(ResNet, num_blocks=[3, 4, 6, 3],
                   block_class=BottleneckResNetBlock)
ResNet101 = partial(ResNet, num_blocks=[3, 4, 23, 3],
                    block_class=BottleneckResNetBlock)
ResNet152 = partial(ResNet, num_blocks=[3, 8, 36, 3],
                    block_class=BottleneckResNetBlock)
ResNet200 = partial(ResNet, num_blocks=[3, 24, 36, 3],
                    block_class=BottleneckResNetBlock)

In [47]:
"""ImageNet input pipeline.
"""

import jax
import tensorflow as tf
import tensorflow_datasets as tfds


IMAGE_SIZE = 224
CROP_PADDING = 32
MEAN_RGB = [0.485 * 255, 0.456 * 255, 0.406 * 255]
STDDEV_RGB = [0.229 * 255, 0.224 * 255, 0.225 * 255]


def distorted_bounding_box_crop(image_bytes,
                                bbox,
                                min_object_covered=0.1,
                                aspect_ratio_range=(0.75, 1.33),
                                area_range=(0.05, 1.0),
                                max_attempts=100):
  """Generates cropped_image using one of the bboxes randomly distorted.

  See `tf.image.sample_distorted_bounding_box` for more documentation.

  Args:
    image_bytes: `Tensor` of binary image data.
    bbox: `Tensor` of bounding boxes arranged `[1, num_boxes, coords]`
        where each coordinate is [0, 1) and the coordinates are arranged
        as `[ymin, xmin, ymax, xmax]`. If num_boxes is 0 then use the whole
        image.
    min_object_covered: An optional `float`. Defaults to `0.1`. The cropped
        area of the image must contain at least this fraction of any bounding
        box supplied.
    aspect_ratio_range: An optional list of `float`s. The cropped area of the
        image must have an aspect ratio = width / height within this range.
    area_range: An optional list of `float`s. The cropped area of the image
        must contain a fraction of the supplied image within this range.
    max_attempts: An optional `int`. Number of attempts at generating a cropped
        region of the image of the specified constraints. After `max_attempts`
        failures, return the entire image.
  Returns:
    cropped image `Tensor`
  """
  shape = tf.io.extract_jpeg_shape(image_bytes)
  sample_distorted_bounding_box = tf.image.sample_distorted_bounding_box(
      shape,
      bounding_boxes=bbox,
      min_object_covered=min_object_covered,
      aspect_ratio_range=aspect_ratio_range,
      area_range=area_range,
      max_attempts=max_attempts,
      use_image_if_no_bounding_boxes=True)
  bbox_begin, bbox_size, _ = sample_distorted_bounding_box

  # Crop the image to the specified bounding box.
  offset_y, offset_x, _ = tf.unstack(bbox_begin)
  target_height, target_width, _ = tf.unstack(bbox_size)
  crop_window = tf.stack([offset_y, offset_x, target_height, target_width])
  image = tf.io.decode_and_crop_jpeg(image_bytes, crop_window, channels=3)

  return image


def _resize(image, image_size):
  return tf.image.resize([image], [image_size, image_size],
                         method=tf.image.ResizeMethod.BICUBIC)[0]


def _at_least_x_are_equal(a, b, x):
  """At least `x` of `a` and `b` `Tensors` are equal."""
  match = tf.equal(a, b)
  match = tf.cast(match, tf.int32)
  return tf.greater_equal(tf.reduce_sum(match), x)


def _decode_and_random_crop(image_bytes, image_size):
  """Make a random crop of image_size."""
  bbox = tf.constant([0.0, 0.0, 1.0, 1.0], dtype=tf.float32, shape=[1, 1, 4])
  image = distorted_bounding_box_crop(
      image_bytes,
      bbox,
      min_object_covered=0.1,
      aspect_ratio_range=(3. / 4, 4. / 3.),
      area_range=(0.08, 1.0),
      max_attempts=10)
  original_shape = tf.io.extract_jpeg_shape(image_bytes)
  bad = _at_least_x_are_equal(original_shape, tf.shape(image), 3)

  image = tf.cond(
      bad,
      lambda: _decode_and_center_crop(image_bytes, image_size),
      lambda: _resize(image, image_size))

  return image


def _decode_and_center_crop(image_bytes, image_size):
  """Crops to center of image with padding then scales image_size."""
  shape = tf.io.extract_jpeg_shape(image_bytes)
  image_height = shape[0]
  image_width = shape[1]

  padded_center_crop_size = tf.cast(
      ((image_size / (image_size + CROP_PADDING)) *
       tf.cast(tf.minimum(image_height, image_width), tf.float32)),
      tf.int32)

  offset_height = ((image_height - padded_center_crop_size) + 1) // 2
  offset_width = ((image_width - padded_center_crop_size) + 1) // 2
  crop_window = tf.stack([offset_height, offset_width,
                          padded_center_crop_size, padded_center_crop_size])
  image = tf.io.decode_and_crop_jpeg(image_bytes, crop_window, channels=3)
  image = _resize(image, image_size)

  return image


def normalize_image(image):
  image -= tf.constant(MEAN_RGB, shape=[1, 1, 3], dtype=image.dtype)
  image /= tf.constant(STDDEV_RGB, shape=[1, 1, 3], dtype=image.dtype)
  return image


def preprocess_for_train(image_bytes, dtype=tf.float32, image_size=IMAGE_SIZE):
  """Preprocesses the given image for training.

  Args:
    image_bytes: `Tensor` representing an image binary of arbitrary size.
    dtype: data type of the image.
    image_size: image size.

  Returns:
    A preprocessed image `Tensor`.
  """
  image = _decode_and_random_crop(image_bytes, image_size)
  image = tf.reshape(image, [image_size, image_size, 3])
  image = tf.image.random_flip_left_right(image)
  image = normalize_image(image)
  image = tf.image.convert_image_dtype(image, dtype=dtype)
  return image


def preprocess_for_eval(image_bytes, dtype=tf.float32, image_size=IMAGE_SIZE):
  """Preprocesses the given image for evaluation.

  Args:
    image_bytes: `Tensor` representing an image binary of arbitrary size.
    dtype: data type of the image.
    image_size: image size.

  Returns:
    A preprocessed image `Tensor`.
  """
  image = _decode_and_center_crop(image_bytes, image_size)
  image = tf.reshape(image, [image_size, image_size, 3])
  image = normalize_image(image)
  image = tf.image.convert_image_dtype(image, dtype=dtype)
  return image


def create_split(dataset_builder, batch_size, train, dtype=tf.float32,
                 image_size=IMAGE_SIZE, cache=False, shuffle_buffer_size=2_000,
                 prefetch=10):
  """Creates a split from the ImageNet dataset using TensorFlow Datasets.

  Args:
    dataset_builder: TFDS dataset builder for ImageNet.
    batch_size: the batch size returned by the data pipeline.
    train: Whether to load the train or evaluation split.
    dtype: data type of the image.
    image_size: The target size of the images.
    cache: Whether to cache the dataset.
    shuffle_buffer_size: Size of the shuffle buffer.
    prefetch: Number of items to prefetch in the dataset.
  Returns:
    A `tf.data.Dataset`.
  """
  if train:
    train_examples = dataset_builder.info.splits['train'].num_examples
    split_size = train_examples // jax.process_count()
    start = jax.process_index() * split_size
    split = f'train[{start}:{start + split_size}]'
  else:
    validate_examples = dataset_builder.info.splits['validation'].num_examples
    split_size = validate_examples // jax.process_count()
    start = jax.process_index() * split_size
    split = f'validation[{start}:{start + split_size}]'

  def decode_example(example):
    if train:
      image = preprocess_for_train(example['image'], dtype, image_size)
    else:
      image = preprocess_for_eval(example['image'], dtype, image_size)
    return {'image': image, 'label': example['label']}

  ds = dataset_builder.as_dataset(split=split, decoders={
      'image': tfds.decode.SkipDecoding(),
  })
  options = tf.data.Options()
  options.experimental_threading.private_threadpool_size = 48
  ds = ds.with_options(options)

  if cache:
    ds = ds.cache()

  if train:
    ds = ds.repeat()
    ds = ds.shuffle(shuffle_buffer_size, seed=0)

  ds = ds.map(decode_example, num_parallel_calls=tf.data.experimental.AUTOTUNE)
  ds = ds.batch(batch_size, drop_remainder=True)

  if not train:
    ds = ds.repeat()

  ds = ds.prefetch(prefetch)

  return ds

In [49]:
!pip install clu

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting clu
  Downloading clu-0.0.9-py3-none-any.whl (98 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/98.2 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m98.2/98.2 kB[0m [31m5.8 MB/s[0m eta [36m0:00:00[0m
Collecting ml-collections (from clu)
  Downloading ml_collections-0.1.1.tar.gz (77 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m77.9/77.9 kB[0m [31m9.7 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: ml-collections
  Building wheel for ml-collections (setup.py) ... [?25l[?25hdone
  Created wheel for ml-collections: filename=ml_collections-0.1.1-py3-none-any.whl size=94506 sha256=d76373859a0f34ede78c42eee25f5edc5687955cc5bd75f397d317c644810ae5
  Stored in directory: /root/.cache/pip/wheels/7

In [82]:
import functools
import time
from typing import Any

from absl import logging
from clu import metric_writers
from clu import periodic_actions
import flax
from flax import jax_utils
from flax.training import checkpoints
from flax.training import common_utils
from flax.training import dynamic_scale as dynamic_scale_lib
from flax.training import train_state
import jax
from jax import lax
import jax.numpy as jnp
from jax import random
import ml_collections
import optax
import tensorflow as tf
import tensorflow_datasets as tfds



NUM_CLASSES = 1000


def create_model(*, model_cls, half_precision, **kwargs):
  platform = jax.local_devices()[0].platform
  if half_precision:
    if platform == 'tpu':
      model_dtype = jnp.bfloat16
    else:
      model_dtype = jnp.float16
  else:
    model_dtype = jnp.float32
  return model_cls(num_classes=NUM_CLASSES, dtype=model_dtype, **kwargs)


def initialized(key, image_size, model):
  input_shape = (1, image_size, image_size, 3)
  @jax.jit
  def init(*args):
    return model.init(*args)
  variables = init({'params': key}, jnp.ones(input_shape, model.dtype))
  return variables['params'], variables['batch_stats']


def cross_entropy_loss(logits, labels):
  one_hot_labels = common_utils.onehot(labels, num_classes=NUM_CLASSES)
  xentropy = optax.softmax_cross_entropy(logits=logits, labels=one_hot_labels)
  return jnp.mean(xentropy)


def compute_metrics(logits, labels):
  loss = cross_entropy_loss(logits, labels)
  accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
  metrics = {
      'loss': loss,
      'accuracy': accuracy,
  }
  metrics = lax.pmean(metrics, axis_name='batch')
  return metrics


def create_learning_rate_fn(
    config: ml_collections.ConfigDict,
    base_learning_rate: float,
    steps_per_epoch: int):
  """Create learning rate schedule."""
  warmup_fn = optax.linear_schedule(
      init_value=0., end_value=base_learning_rate,
      transition_steps=flags['warmup_epochs'] * steps_per_epoch)
  cosine_epochs = max(flags['num_epochs'] - flags['warmup_epochs'], 1)
  cosine_fn = optax.cosine_decay_schedule(
      init_value=base_learning_rate,
      decay_steps=cosine_epochs * steps_per_epoch)
  schedule_fn = optax.join_schedules(
      schedules=[warmup_fn, cosine_fn],
      boundaries=[flags['warmup_epochs'] * steps_per_epoch])
  return schedule_fn


def train_step(state, batch, learning_rate_fn):
  """Perform a single training step."""
  def loss_fn(params):
    """loss function used for training."""
    logits, new_model_state = state.apply_fn(
        {'params': params, 'batch_stats': state.batch_stats},
        batch['image'],
        mutable=['batch_stats'])
    loss = cross_entropy_loss(logits, batch['label'])
    weight_penalty_params = jax.tree_util.tree_leaves(params)
    weight_decay = 0.0001
    weight_l2 = sum(jnp.sum(x ** 2)
                     for x in weight_penalty_params
                     if x.ndim > 1)
    weight_penalty = weight_decay * 0.5 * weight_l2
    loss = loss + weight_penalty
    return loss, (new_model_state, logits)

  step = state.step
  dynamic_scale = state.dynamic_scale
  lr = learning_rate_fn(step)

  if dynamic_scale:
    grad_fn = dynamic_scale.value_and_grad(
        loss_fn, has_aux=True, axis_name='batch')
    dynamic_scale, is_fin, aux, grads = grad_fn(state.params)
    # dynamic loss takes care of averaging gradients across replicas
  else:
    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    aux, grads = grad_fn(state.params)
    # Re-use same axis_name as in the call to `pmap(...train_step...)` below.
    grads = lax.pmean(grads, axis_name='batch')
  new_model_state, logits = aux[1]
  metrics = compute_metrics(logits, batch['label'])
  metrics['learning_rate'] = lr

  new_state = state.apply_gradients(
      grads=grads, batch_stats=new_model_state['batch_stats'])
  if dynamic_scale:
    # if is_fin == False the gradients contain Inf/NaNs and optimizer state and
    # params should be restored (= skip this step).
    new_state = new_state.replace(
        opt_state=jax.tree_util.tree_map(
            functools.partial(jnp.where, is_fin),
            new_state.opt_state,
            state.opt_state),
        params=jax.tree_util.tree_map(
            functools.partial(jnp.where, is_fin),
            new_state.params,
            state.params),
        dynamic_scale=dynamic_scale)
    metrics['scale'] = dynamic_scale.scale

  return new_state, metrics


def eval_step(state, batch):
  variables = {'params': state.params, 'batch_stats': state.batch_stats}
  logits = state.apply_fn(
      variables, batch['image'], train=False, mutable=False)
  return compute_metrics(logits, batch['label'])


def prepare_tf_data(xs):
  """Convert a input batch from tf Tensors to numpy arrays."""
  local_device_count = jax.local_device_count()
  def _prepare(x):
    # Use _numpy() for zero-copy conversion between TF and NumPy.
    x = x._numpy()  # pylint: disable=protected-access

    # reshape (host_batch_size, height, width, 3) to
    # (local_devices, device_batch_size, height, width, 3)
    return x.reshape((local_device_count, -1) + x.shape[1:])

  return jax.tree_util.tree_map(_prepare, xs)


def create_input_iter(dataset_builder, batch_size, image_size, dtype, train,
                      cache, shuffle_buffer_size, prefetch):
  ds = create_split(dataset_builder, batch_size, image_size=image_size, dtype=dtype,
      train=train, cache=cache, shuffle_buffer_size=shuffle_buffer_size,
      prefetch=prefetch)
  it = map(prepare_tf_data, ds)
  it = jax_utils.prefetch_to_device(it, 2)
  return it


class TrainState(train_state.TrainState):
  batch_stats: Any
  dynamic_scale: dynamic_scale_lib.DynamicScale


def restore_checkpoint(state, workdir):
  return checkpoints.restore_checkpoint(workdir, state)


def save_checkpoint(state, workdir):
  state = jax.device_get(jax.tree_util.tree_map(lambda x: x[0], state))
  step = int(state.step)
  logging.info('Saving checkpoint step %d.', step)
  checkpoints.save_checkpoint_multiprocess(workdir, state, step, keep=3)


# pmean only works inside pmap because it needs an axis name.
# This function will average the inputs across all devices.
cross_replica_mean = jax.pmap(lambda x: lax.pmean(x, 'x'), 'x')


def sync_batch_stats(state):
  """Sync the batch statistics across replicas."""
  # Each device has its own version of the running average batch statistics and
  # we sync them before evaluation.
  return state.replace(batch_stats=cross_replica_mean(state.batch_stats))


def create_train_state(rng, config: ml_collections.ConfigDict,
                       model, image_size, learning_rate_fn):
  """Create initial training state."""
  dynamic_scale = None
  platform = jax.local_devices()[0].platform
  if flags['half_precision'] and platform == 'gpu':
    dynamic_scale = dynamic_scale_lib.DynamicScale()
  else:
    dynamic_scale = None

  params, batch_stats = initialized(rng, image_size, model)
  tx = optax.sgd(
      learning_rate=learning_rate_fn,
      momentum=flags['momentum'],
      nesterov=True,
  )
  state = TrainState.create(
      apply_fn=model.apply,
      params=params,
      tx=tx,
      batch_stats=batch_stats,
      dynamic_scale=dynamic_scale)
  return state


def train_and_evaluate(FLAGS,
                       workdir: str) -> TrainState:
  """Execute model training and evaluation loop.

  Args:
    config: Hyperparameter configuration for training and evaluation.
    workdir: Directory where the tensorboard summaries are written to.

  Returns:
    Final TrainState.
  """

  writer = metric_writers.create_default_writer(
      logdir=workdir, just_logging=jax.process_index() != 0)

  rng = random.PRNGKey(0)

  image_size = 224

  if flags['batch_size'] % jax.device_count() > 0:
    raise ValueError('Batch size must be divisible by the number of devices')
  local_batch_size = flags['batch_size'] // jax.process_count()

  platform = jax.local_devices()[0].platform

  if flags['half_precision']:
    if platform == 'tpu':
      input_dtype = tf.bfloat16
    else:
      input_dtype = tf.float16
  else:
    input_dtype = tf.float32

  dataset_builder = tfds.builder(flags['dataset'])
  train_iter = create_input_iter(
      dataset_builder, local_batch_size, image_size, input_dtype, train=True,
      cache=flags['cache'], shuffle_buffer_size=flags['shuffle_buffer_size'],
      prefetch=flags['prefetch'])
  eval_iter = create_input_iter(
      dataset_builder, local_batch_size, image_size, input_dtype, train=False,
      cache=flags['cache'], shuffle_buffer_size=None, prefetch=flags['prefetch'])

  steps_per_epoch = (
      dataset_builder.info.splits['train'].num_examples // config.batch_size
  )

  if flags['num_train_steps'] <= 0:
    num_steps = int(steps_per_epoch * config.num_epochs)
  else:
    num_steps = flags['num_train_steps']

  if flags['steps_per_eval'] == -1:
    num_validation_examples = dataset_builder.info.splits[
        'validation'].num_examples
    steps_per_eval = num_validation_examples // flags['batch_size']
  else:
    steps_per_eval = flags['steps_per_eval']

  steps_per_checkpoint = steps_per_epoch * 10

  base_learning_rate = flags['learning_rate'] * flags['batch_size'] / 256.

  model_cls = getattr(models, flags['model'])
  model = create_model(
      model_cls=model_cls, half_precision=flags['half_precision'])

  learning_rate_fn = create_learning_rate_fn(
      config, base_learning_rate, steps_per_epoch)

  state = create_train_state(rng, config, model, image_size, learning_rate_fn)
  state = restore_checkpoint(state, workdir)
  # step_offset > 0 if restarting from checkpoint
  step_offset = int(state.step)
  state = jax_utils.replicate(state)

  p_train_step = jax.pmap(
      functools.partial(train_step, learning_rate_fn=learning_rate_fn),
      axis_name='batch')
  p_eval_step = jax.pmap(eval_step, axis_name='batch')

  train_metrics = []
  hooks = []
  if jax.process_index() == 0:
    hooks += [periodic_actions.Profile(num_profile_steps=5, logdir=workdir)]
  train_metrics_last_t = time.time()
  logging.info('Initial compilation, this might take some minutes...')
  for step, batch in zip(range(step_offset, num_steps), train_iter):
    state, metrics = p_train_step(state, batch)
    for h in hooks:
      h(step)
    if step == step_offset:
      logging.info('Initial compilation completed.')

    if flags['log_every_steps']:
      train_metrics.append(metrics)
      if (step + 1) % flags['log_every_steps == 0']:
        train_metrics = common_utils.get_metrics(train_metrics)
        summary = {
            f'train_{k}': v
            for k, v in jax.tree_util.tree_map(lambda x: x.mean(), train_metrics).items()
        }
        summary['steps_per_second'] = flags['log_every_steps / '](
            time.time() - train_metrics_last_t)
        writer.write_scalars(step + 1, summary)
        train_metrics = []
        train_metrics_last_t = time.time()

    if (step + 1) % steps_per_epoch == 0:
      epoch = step // steps_per_epoch
      eval_metrics = []

      # sync batch statistics across replicas
      state = sync_batch_stats(state)
      for _ in range(steps_per_eval):
        eval_batch = next(eval_iter)
        metrics = p_eval_step(state, eval_batch)
        eval_metrics.append(metrics)
      eval_metrics = common_utils.get_metrics(eval_metrics)
      summary = jax.tree_util.tree_map(lambda x: x.mean(), eval_metrics)
      logging.info('eval epoch: %d, loss: %.4f, accuracy: %.2f',
                   epoch, summary['loss'], summary['accuracy'] * 100)
      writer.write_scalars(
          step + 1, {f'eval_{key}': val for key, val in summary.items()})
      writer.flush()
    if (step + 1) % steps_per_checkpoint == 0 or step + 1 == num_steps:
      state = sync_batch_stats(state)
      save_checkpoint(state, workdir)

  # Wait until computations are done before exiting
  jax.random.normal(jax.random.PRNGKey(0), ()).block_until_ready()

  return state

In [83]:
!pip install configs

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [102]:
flags = {}
flags['model'] = 'ResNet50'
flags['dataset'] = 'imagenet2012:5.*.*'
flags['seed'] = 1

flags['learning_rate'] = 0.1
flags['warmup_epochs'] = 5.0
flags['momentum'] = 0.9
flags['batch_size'] = 128
flags['shuffle_buffer_size'] = 16 * 128
flags['prefetch'] = 10

flags['num_epochs'] = 100.0
flags['log_every_steps'] = 100

flags['cache'] = False
flags['half_precision'] = False

def metrics():
  return [
      'train_loss',
      'eval_loss',
      'train_accuracy',
      'eval_accuracy',
      'steps_per_second',
      'train_learning_rate',
  ]

In [85]:
flags

{'model': 'ResNet50',
 'dataset': 'imagenet2012:5.*.*',
 'seed': 1,
 'learning_rate': 0.1,
 'warmup_epochs': 5.0,
 'momentum': 0.9,
 'batch_size': 128,
 'shuffle_buffer_size': 2048,
 'prefetch': 10,
 'num_epochs': 100.0,
 'log_every_steps': 100,
 'cache': False,
 'half_precision': False}

In [87]:
dataset_builder = tfds.builder('imagenette')
dataset_builder.download_and_prepare()
ds = dataset_builder.as_dataset('train')
dataset_builder.info

Downloading and preparing dataset 1.45 GiB (download: 1.45 GiB, generated: 1.46 GiB, total: 2.91 GiB) to /root/tensorflow_datasets/imagenette/full-size-v2/1.0.0...


Dl Completed...: 0 url [00:00, ? url/s]

Dl Size...: 0 MiB [00:00, ? MiB/s]

Extraction completed...: 0 file [00:00, ? file/s]

Shuffling /root/tensorflow_datasets/imagenette/full-size-v2/1.0.0.incomplete82RKKC/imagenette-train.tfrecord*.…

Generating validation examples...:   0%|          | 0/3925 [00:00<?, ? examples/s]

Shuffling /root/tensorflow_datasets/imagenette/full-size-v2/1.0.0.incomplete82RKKC/imagenette-validation.tfrec…

Dataset imagenette downloaded and prepared to /root/tensorflow_datasets/imagenette/full-size-v2/1.0.0. Subsequent calls will reuse this data.


tfds.core.DatasetInfo(
    name='imagenette',
    full_name='imagenette/full-size-v2/1.0.0',
    description="""
    Imagenette is a subset of 10 easily classified classes from the Imagenet
    dataset. It was originally prepared by Jeremy Howard of FastAI. The objective
    behind putting together a small version of the Imagenet dataset was mainly
    because running new ideas/algorithms/experiments on the whole Imagenet take a
    lot of time.
    
    This version of the dataset allows researchers/practitioners to quickly try out
    ideas and share with others. The dataset comes in three variants:
    
    *   Full size
    *   320 px
    *   160 px
    
    Note: The v2 config correspond to the new 70/30 train/valid split (released in
    Dec 6 2019).
    """,
    config_description="""
    full-size variant.
    """,
    homepage='https://github.com/fastai/imagenette',
    data_path=PosixGPath('/tmp/tmp5nparc18tfds'),
    file_format=tfrecord,
    download_size=1.45 GiB,
    data

In [90]:
import json

![ ! -f mapping_imagenet.json ] && wget --no-check-certificate https://raw.githubusercontent.com/ozendelait/wordnet-to-json/master/mapping_imagenet.json

with open('mapping_imagenet.json') as f:
  mapping_imagenet = json.load(f)
# Mapping imagenette label name to imagenet label index.
imagenette_labels = {
    d['v3p0']: d['label']
    for d in mapping_imagenet
}
# Mapping imagenette label name to human-readable label.
imagenette_idx = {
    d['v3p0']: idx
    for idx, d in enumerate(mapping_imagenet)
}

def imagenette_label(idx):
  """Returns a short human-readable string for provided imagenette index."""
  net = dataset_builder.info.features['label'].int2str(idx)
  return imagenette_labels[net].split(',')[0]

def imagenette_imagenet2012(idx):
  """Returns the imagenet2012 index for provided imagenette index."""
  net = dataset_builder.info.features['label'].int2str(idx)
  return imagenette_idx[net]

def imagenet2012_label(idx):
  """Returns a short human-readable string for provided imagenet2012 index."""
  return mapping_imagenet[idx]['label'].split(',')[0]


In [93]:
train_ds = create_split(
    dataset_builder, 128, train=True,
)
eval_ds = create_split(
    dataset_builder, 128, train=False,
)




In [94]:
train_batch = next(iter(train_ds))
{k: (v.shape, v.dtype) for k, v in train_batch.items()}

{'image': (TensorShape([128, 224, 224, 3]), tf.float32),
 'label': (TensorShape([128]), tf.int64)}

In [96]:
# Regenerate datasets with updated batch_size.
train_ds = create_split(
    dataset_builder, flags['batch_size'], train=True,
)
eval_ds = create_split(
    dataset_builder, flags['batch_size'], train=False,
)



In [103]:
flags

{'model': 'ResNet50',
 'dataset': 'imagenet2012:5.*.*',
 'seed': 1,
 'learning_rate': 0.1,
 'warmup_epochs': 5.0,
 'momentum': 0.9,
 'batch_size': 128,
 'shuffle_buffer_size': 2048,
 'prefetch': 10,
 'num_epochs': 100.0,
 'log_every_steps': 100,
 'cache': False,
 'half_precision': False}

In [108]:
ds = tfds.load('imagenette')

In [109]:
# Takes ~1.5 min / epoch.
for lr in  (0.03, 0.1, 0.3):
  lr = flags['learning_rate']
  model_name = flags['model']
  name = f'{model_name}_lr={lr}'
  print(f'\n\n{name}')
  state = train_and_evaluate(flags, workdir=f'./models/{name}')



ResNet50_lr=0.1


AssertionError: ignored

In [105]:
logging.info('JAX process: %d / %d', jax.process_index(), jax.process_count())
logging.info('JAX local devices: %r', jax.local_devices())

  # Add a note so that we can tell which task is which JAX host.
  # (Depending on the platform task 0 is not guaranteed to be host 0)

train_and_evaluate(FLAGS= flags, workdir="./tmp")



AssertionError: ignored

In [106]:
!ls

mapping_imagenet.json  models  sample_data  tmp
