ResNet-50 on TPU


If you want to train the model on Cloud TPU through the managed service Cloud Machine Learning Engine, skip to the Train on Cloud Machine Learning Engine section.

Set up a Google Cloud project

Follow the instructions at the Quickstart Guide to get a GCE VM with access to a Cloud TPU. It is also recommended that you try the Cloud TPU ResNet tutorial, which covers both the quickstart and training of the ResNet algorithm.

To run this model, you will need:

  • A GCE VM instance with an associated Cloud TPU resource
  • A GCS bucket to store your training checkpoints (the "model directory")
  • (Optional): The ImageNet training and validation data preprocessed into TFRecord format, and stored in GCS.

Formatting the data

The data is expected to be formatted in TFRecord format, as generated by this script.

If you do not have the ImageNet dataset prepared, you can test the model with a randomly generated fake dataset located at gs://cloud-tpu-test-datasets/fake_imagenet.

Training the model

  1. Add the top-level /models folder to the Python path with the command:
export PYTHONPATH="$PYTHONPATH:/path/to/models"
  2. Train the model by executing the following command (substituting the appropriate values):
python \
  --tpu=$TPU_NAME \
  --data_dir=$DATA_DIR \
  --model_dir=$MODEL_DIR

$TPU_NAME is the name of the TPU node, the same name that appears when you run gcloud compute tpus list, or ctpu ls. (When using the shell created by ctpu up, this argument may not be necessary.)

$MODEL_DIR is a GCS location (a URL starting with gs://) to which both the GCE VM and the associated Cloud TPU have write access, something like gs://userid-dev-imagenet-output/model. TensorFlow can't create the bucket; you have to create it yourself with gsutil mb <bucket>. This bucket is used to save checkpoints and the training result, so training steps are cumulative when you reuse the model directory. For example, if you run 1000 steps and then reuse the same model directory, the next run skips those first 1000 steps and picks up where the previous run left off.
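The resume behavior comes down to simple arithmetic on the step count restored from the latest checkpoint. This is a minimal sketch, assuming each run is asked to reach a fixed total number of steps; remaining_steps is a hypothetical helper for illustration, not part of the training script.

```python
from typing import Optional


def remaining_steps(target_steps: int, checkpoint_step: Optional[int]) -> int:
    """Return how many steps a run still has to execute (hypothetical helper).

    `checkpoint_step` is the global step restored from the latest checkpoint in the
    model directory, or None when the model directory is still empty.
    """
    start = checkpoint_step or 0  # a fresh model directory starts at step 0
    return max(target_steps - start, 0)  # nothing left to do once the target is reached


print(remaining_steps(2000, None))  # 2000: first run, empty model directory
print(remaining_steps(2000, 1000))  # 1000: the first 1000 steps are skipped
```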

$DATA_DIR is a GCS location to which both the GCE VM and associated Cloud TPU have read access, something like gs://cloud-tpu-test-datasets/fake_imagenet. This location is expected to contain files with the prefixes train-* and validation-*. The former pattern is used to match files used for the training phase, the latter for the evaluation phase.
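The prefix convention can be illustrated with Python's fnmatch module. This is a minimal sketch, assuming the list of file paths has already been obtained (for example with gsutil ls); split_by_phase is a hypothetical helper and the shard names in the example are illustrative only.

```python
import fnmatch
from typing import Dict, Iterable, List

# Filename patterns for each phase, as described above.
PHASE_PATTERNS = {"train": "train-*", "validation": "validation-*"}


def split_by_phase(paths: Iterable[str]) -> Dict[str, List[str]]:
    """Group data file paths into training and evaluation sets by basename prefix."""
    groups: Dict[str, List[str]] = {phase: [] for phase in PHASE_PATTERNS}
    for path in paths:
        basename = path.rsplit("/", 1)[-1]  # match the file name, not the gs:// prefix
        for phase, pattern in PHASE_PATTERNS.items():
            if fnmatch.fnmatchcase(basename, pattern):
                groups[phase].append(path)
    # Files matching neither pattern are ignored.
    return {phase: sorted(found) for phase, found in groups.items()}


files = [
    "gs://my-bucket/imagenet/validation-00000-of-00128",
    "gs://my-bucket/imagenet/train-00001-of-01024",
    "gs://my-bucket/imagenet/train-00000-of-01024",
    "gs://my-bucket/imagenet/labels.txt",
]
groups = split_by_phase(files)
print(len(groups["train"]), len(groups["validation"]))  # 2 1
```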

Each file is a series of TFExample records. In the case of ResNet-50, the TFExample records have a specific format, as follows:

import tensorflow as tf

keys_to_features = {
    'image/encoded': tf.FixedLenFeature((), tf.string, ''),
    'image/format': tf.FixedLenFeature((), tf.string, 'jpeg'),
    'image/class/label': tf.FixedLenFeature([], tf.int64, -1),
    'image/class/text': tf.FixedLenFeature([], tf.string, ''),
    'image/object/bbox/xmin': tf.VarLenFeature(dtype=tf.float32),
    'image/object/bbox/ymin': tf.VarLenFeature(dtype=tf.float32),
    'image/object/bbox/xmax': tf.VarLenFeature(dtype=tf.float32),
    'image/object/bbox/ymax': tf.VarLenFeature(dtype=tf.float32),
    'image/object/class/label': tf.VarLenFeature(dtype=tf.int64),
}

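The four VarLenFeature bbox entries each hold one value per annotated box, so they have to be zipped back together before they can be used as boxes. A minimal pure-Python sketch of that step, assuming normalized [0, 1] coordinates and the [ymin, xmin, ymax, xmax] order that tf.image.sample_distorted_bounding_box expects; assemble_boxes is a hypothetical helper, not a function from this repository.

```python
from typing import List, Sequence, Tuple

# (ymin, xmin, ymax, xmax), the coordinate order tf.image.sample_distorted_bounding_box uses.
Box = Tuple[float, float, float, float]


def assemble_boxes(xmin: Sequence[float], ymin: Sequence[float],
                   xmax: Sequence[float], ymax: Sequence[float]) -> List[Box]:
    """Zip the four per-coordinate lists into one (ymin, xmin, ymax, xmax) tuple per box."""
    if not len(xmin) == len(ymin) == len(xmax) == len(ymax):
        raise ValueError("each bbox feature must hold exactly one value per box")
    return [(y0, x0, y1, x1) for x0, y0, x1, y1 in zip(xmin, ymin, xmax, ymax)]


# Two hypothetical boxes in normalized coordinates.
print(assemble_boxes([0.1, 0.5], [0.2, 0.4], [0.6, 0.9], [0.7, 0.8]))
# [(0.2, 0.1, 0.7, 0.6), (0.4, 0.5, 0.8, 0.9)]
print(assemble_boxes([], [], [], []))  # [] for an image without box annotations
```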
The training and validation data can also be sourced from Cloud Bigtable:

python \
  --tpu=$TPU_NAME \
  --model_dir=$MODEL_DIR \
  --bigtable_project=$PROJECT \
  --bigtable_instance=$INSTANCE \

In this case, the TFExample records are stored one per row in a Cloud Bigtable table. Rows are grouped by a row-key prefix for each category of data (training or evaluation), and within a prefix the rows are keyed by zero-padded indexes, e.g. train_0000003892.
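The zero padding matters because Bigtable sorts rows lexicographically by row key: without padding, train_10 would sort before train_9. A minimal sketch of building such keys, assuming a fixed width of 10 digits inferred from the train_0000003892 example above; row_key is a hypothetical helper, not a function from the training code.

```python
def row_key(prefix: str, index: int, width: int = 10) -> str:
    """Build a row key such as 'train_0000003892' (hypothetical helper)."""
    if not 0 <= index < 10 ** width:
        raise ValueError(f"index {index} does not fit in {width} digits")
    return f"{prefix}{index:0{width}d}"


print(row_key("train_", 3892))  # train_0000003892

# Unpadded keys sort as strings, not as numbers:
print(sorted(["train_10", "train_9"]))  # ['train_10', 'train_9']
# Zero-padded keys sort in numeric order:
print(sorted([row_key("train_", 10), row_key("train_", 9)]))
# ['train_0000000009', 'train_0000000010']
```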

You can also specify the following arguments when sourcing data from Cloud Bigtable, though they already have the right defaults for ResNet-50:

  --bigtable_train_prefix=train_ \        # row prefix for training rows
  --bigtable_eval_prefix=validation_ \    # row prefix for evaluation rows
  --bigtable_column_family=tfexample \
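Because Bigtable keeps rows sorted by key, reading every row under a prefix such as train_ amounts to scanning one contiguous key range. The sketch below shows how a prefix maps onto a half-open range [start, end); prefix_range is a hypothetical helper for illustration, not part of the training script or of any Bigtable client library.

```python
from typing import Tuple


def prefix_range(prefix: str) -> Tuple[bytes, bytes]:
    """Return the half-open row-key range [start, end) covering all keys with `prefix`.

    Hypothetical helper: an empty end key stands for "until the end of the table".
    """
    start = prefix.encode("utf-8")
    # Trailing 0xFF bytes cannot be incremented, so drop them before bumping the last byte.
    trimmed = start.rstrip(b"\xff")
    if not trimmed:
        return start, b""
    end = trimmed[:-1] + bytes([trimmed[-1] + 1])
    return start, end


start, end = prefix_range("train_")
print(start, end)  # b'train_' b'train`'
for key in (b"train_0000003892", b"validation_0000000001"):
    print(key, start <= key < end)
```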

Note that even when sourcing input data from Cloud Bigtable, $MODEL_DIR must still be a GCS location.

Project and Zone

If you are not running this script on a GCE VM in the same project and zone as your Cloud TPU, you will need to add the --project and --zone flags specifying the corresponding values for the Cloud TPU you'd like to use.

This trains a ResNet-50 model on ImageNet with a batch size of 1024 on a single Cloud TPU. With the default flags, the model should train to above 76% accuracy in around 17 hours (including the evaluation run every --steps_per_eval steps).
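The batch size can be related to step counts with a little arithmetic. A minimal sketch, assuming the standard ILSVRC-2012 training set of 1,281,167 images and the 90-epoch schedule described under Additional notes; training_plan is a hypothetical helper, and the steps_per_eval value is an arbitrary example rather than the script's default.

```python
import math

# Size of the ILSVRC-2012 (ImageNet-1k) training set.
IMAGENET_TRAIN_IMAGES = 1_281_167


def training_plan(batch_size: int, epochs: int, steps_per_eval: int):
    """Return (steps per epoch, total training steps, number of evaluation rounds)."""
    steps_per_epoch = IMAGENET_TRAIN_IMAGES / batch_size
    train_steps = math.ceil(steps_per_epoch * epochs)  # rounded up to a whole step
    eval_rounds = math.ceil(train_steps / steps_per_eval)  # one evaluation per chunk of steps
    return steps_per_epoch, train_steps, eval_rounds


print(training_plan(batch_size=1024, epochs=90, steps_per_eval=5000))
# (1251.1396484375, 112603, 23)
```

Rounding up is a choice made here; the training script may round differently.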

You can launch TensorBoard (e.g. tensorboard --logdir=$MODEL_DIR) to view loss curves and other metadata about your training run.

Note: if you launch TensorBoard on your GCE VM, connect to it securely by configuring either SSH port forwarding or a SOCKS proxy over SSH (recommended).

Alternatively, you can modify your GCE firewall rules to open a port, but this is not recommended because it exposes TensorBoard to insecure access from anyone on the internet.

Train on Cloud Machine Learning Engine

To train this model on Machine Learning Engine, you will need:

  • A GCP project with Cloud Machine Learning Engine enabled
  • A GCS bucket to store your training checkpoints (the "model directory") and for staging the training package
  • (Optional): The ImageNet training and validation data preprocessed into TFRecord format, and stored in GCS.

Run the following command from the top-level models folder:

gcloud ml-engine jobs submit training $JOB_NAME \
    --staging-bucket $STAGING_BUCKET \
    --runtime-version 1.9 \
    --scale-tier BASIC_TPU \
    --module-name official.resnet.resnet_main \
    --package-path official \
    --region $REGION \
    -- \
    --data_dir=$DATA_DIR \
    --model_dir=$OUTPUT_PATH \
    --resnet_depth=50 \
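In the command above, the bare -- separates two groups of flags: everything before it configures the job for gcloud, and everything after it is passed as command-line arguments to the training module. The sketch below only illustrates that split; split_user_args is hypothetical and is not gcloud's actual argument parser.

```python
from typing import List, Sequence, Tuple


def split_user_args(argv: Sequence[str]) -> Tuple[List[str], List[str]]:
    """Split argv at the first bare '--' into (job flags, user arguments)."""
    args = list(argv)
    if "--" not in args:
        return args, []
    cut = args.index("--")
    return args[:cut], args[cut + 1:]


job_flags, user_args = split_user_args([
    "--scale-tier", "BASIC_TPU",
    "--region", "$REGION",
    "--",
    "--data_dir=$DATA_DIR",
    "--model_dir=$OUTPUT_PATH",
    "--resnet_depth=50",
])
print(job_flags)  # ['--scale-tier', 'BASIC_TPU', '--region', '$REGION']
print(user_args)  # ['--data_dir=$DATA_DIR', '--model_dir=$OUTPUT_PATH', '--resnet_depth=50']
```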

Understanding the code

For more detailed information, read the documentation within each file.

  • Constructs the input pipeline which handles parsing, preprocessing, shuffling, and batching the data samples.
  • Main code which constructs the TPUEstimator and handles training and evaluating the model.
  • ResNet model code which constructs the network via modular residual blocks or bottleneck blocks.
  • Useful utilities for preprocessing and augmenting ImageNet data for ResNet training. Significantly improves final accuracy.

Additional notes

About the model and training regime

The model is based on the network architecture presented in Deep Residual Learning for Image Recognition by Kaiming He et al.

Specifically, the model uses post-activation residual units for ResNet-18 and 34, and post-activation bottleneck units for ResNet-50, 101, 152, and 200. There are a few differences from the original paper in both the model and the training regime:

  • The preprocessing and data augmentation are slightly different. In particular, normalization includes an additional step that rescales the inputs by the standard deviation of the dataset's RGB values.
  • We use a larger batch size of 1024 (by default) instead of 256 and scale the learning rate linearly with it. In addition, we adopt the learning rate schedule suggested by Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour and train for 90 epochs.
  • We use a slightly different weight initialization for the last batch normalization layer in each block, inspired by the above paper.
  • Evaluation is performed on a single center crop of the validation set rather than the 10-crop evaluation used in the original paper.
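The linear scaling rule and step schedule can be sketched as a function of the epoch. This is a minimal sketch, assuming the values reported by Goyal et al. (base rate 0.1 per 256 images, a 5-epoch warmup, then division by 10 at epochs 30, 60, and 80); the warmup here ramps linearly from zero for simplicity, and the training script's own constants and warmup shape may differ.

```python
def learning_rate(epoch: float, batch_size: int = 1024, base_lr: float = 0.1,
                  base_batch_size: int = 256, warmup_epochs: float = 5.0) -> float:
    """Step schedule with linear scaling and warmup (assumed Goyal et al. values)."""
    peak = base_lr * batch_size / base_batch_size  # linear scaling rule: 0.4 for 1024
    if epoch < warmup_epochs:
        return peak * epoch / warmup_epochs  # ramp linearly from 0 up to the peak
    decay = 1.0
    for boundary in (30, 60, 80):  # divide by 10 at each boundary
        if epoch >= boundary:
            decay *= 0.1
    return peak * decay


for epoch in (2.5, 10, 45, 70, 85):
    print(epoch, round(learning_rate(epoch), 6))
```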

Training/evaluating/predicting on CPU/GPU

To run the same code on CPU/GPU, set the flag --use_tpu=False. This uses the default devices available to TensorFlow on your machine. Checkpoints created on CPU/GPU and on TPU are compatible, so you can train on one type of device and then evaluate or predict with the trained model on another.

Serve the exported model on CPU/GPU

To serve the exported model on CPU, set the flag --data_format='channels_last' as inference on CPU only supports channels_last. Inference on GPU supports both channels_first and channels_last.
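The two --data_format values describe the order of the tensor axes: channels_first is [batch, channels, height, width] (NCHW) and channels_last is [batch, height, width, channels] (NHWC). A minimal sketch of the difference, using NumPy arrays purely for illustration (the model itself works on TensorFlow tensors) and a hypothetical batch of two 4x4 RGB images.

```python
import numpy as np

# Hypothetical batch: 2 RGB images of 4x4 pixels in channels_first (NCHW) layout.
nchw = np.arange(2 * 3 * 4 * 4, dtype=np.float32).reshape(2, 3, 4, 4)

# channels_last (NHWC) moves the channel axis to the end.
nhwc = np.transpose(nchw, (0, 2, 3, 1))

print(nchw.shape)  # (2, 3, 4, 4)
print(nhwc.shape)  # (2, 4, 4, 3)

# The same pixel value is reachable through either layout: [n, c, h, w] == [n, h, w, c].
assert nchw[1, 2, 3, 0] == nhwc[1, 3, 0, 2]
```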

Using different ResNet configurations

The default ResNet-50 configuration has been carefully tested with the default flags, but the model code also includes a few other commonly used configurations: ResNet-18, 34, 101, 152, and 200. The 18- and 34-layer configurations use residual blocks without bottlenecks, and the remaining configurations use bottleneck blocks. The configuration can be controlled via --resnet_depth. Bigger models require more training time and more memory, so you may need to lower --train_batch_size to avoid running out of memory.
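Each depth follows from the number of blocks per stage and the number of convolutions per block. A minimal sketch of that arithmetic, assuming the stage layouts from He et al. (and, for ResNet-200, from the authors' follow-up work on identity mappings); only weighted layers on the main path are counted, so projection shortcuts are excluded, and STAGE_BLOCKS is an illustrative table rather than the repository's own configuration.

```python
# Block type and blocks per stage for each depth (illustrative table).
STAGE_BLOCKS = {
    18: ("basic", [2, 2, 2, 2]),
    34: ("basic", [3, 4, 6, 3]),
    50: ("bottleneck", [3, 4, 6, 3]),
    101: ("bottleneck", [3, 4, 23, 3]),
    152: ("bottleneck", [3, 8, 36, 3]),
    200: ("bottleneck", [3, 24, 36, 3]),
}
# A residual (basic) block has two 3x3 convs; a bottleneck block has 1x1, 3x3, 1x1.
CONVS_PER_BLOCK = {"basic": 2, "bottleneck": 3}


def weighted_layers(depth: int) -> int:
    """Count convolutional and fully connected layers on the main path."""
    block, stages = STAGE_BLOCKS[depth]
    # +2 for the initial 7x7 convolution and the final fully connected layer.
    return sum(stages) * CONVS_PER_BLOCK[block] + 2


for depth in STAGE_BLOCKS:
    print(depth, weighted_layers(depth))
```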

Using your own data

To use your own data with this model, you first need to write an input pipeline similar to the existing ImageNet input pipeline. It is recommended to use the TFRecord format both for storing your data on disk (see the ImageNet dataset download script for details) and for the actual pipeline. Then replace the current imagenet_input in the main training code and adjust the dataset constants.

Benchmarking the training speed

Benchmarking code for DAWNBench can be found under the benchmark/ subdirectory. The benchmarking code imports the same models, inputs, and training regimes but includes some extra checkpointing and evaluation.