Object Detection Models on TPU


To get started, make sure you are using TensorFlow 1.13+ on Google Cloud. You also need to install a few packages:

sudo apt-get install -y python-tk && \
pip install Cython matplotlib opencv-python-headless pyyaml Pillow && \
pip install 'git+'
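One way to confirm the TensorFlow 1.13+ requirement from the shell is to compare dotted version strings. The sketch below is a minimal, hypothetical helper (`version_at_least` is not part of the repository). It assumes GNU `sort -V` is available.

```shell
#!/usr/bin/env bash
# Hypothetical helper: succeeds (exit 0) when version $1 >= version $2.
# Relies on GNU `sort -V`, which orders dotted version strings numerically.
version_at_least() {
  local have="$1" need="$2"
  # The smaller version sorts first; if that is $need, then $have >= $need.
  [ "$(printf '%s\n%s\n' "$need" "$have" | sort -V | head -n 1)" = "$need" ]
}

# Usage (assumes TensorFlow is importable from the active Python):
#   TF_VERSION="$(python -c 'import tensorflow as tf; print(tf.__version__)')"
#   version_at_least "$TF_VERSION" 1.13 || echo "TensorFlow 1.13+ required" >&2
```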

Next, download the code from the tpu GitHub repository, or use a Google Cloud VM on which the code is pre-installed.

git clone

Train RetinaNet on TPU

Train a vanilla ResNet-50-based RetinaNet.

TPU_NAME="<your GCP TPU name>"
MODEL_DIR="<path to the directory to store model files>"
RESNET_CHECKPOINT="<path to the pre-trained ResNet-50 checkpoint>"
TRAIN_FILE_PATTERN="<path to the TFRecord training data>"
EVAL_FILE_PATTERN="<path to the TFRecord validation data>"
VAL_JSON_FILE="<path to the validation annotation JSON file>"
python ~/tpu/models/official/detection/ \
  --use_tpu=True \
  --tpu="${TPU_NAME?}" \
  --num_cores=8 \
  --model_dir="${MODEL_DIR?}" \
  --mode=train \
  --eval_after_training=True \
  --params_override="{ type: retinanet, train: { checkpoint: { path: ${RESNET_CHECKPOINT?}, prefix: resnet50/ }, train_file_pattern: ${TRAIN_FILE_PATTERN?} }, eval: { val_json_file: ${VAL_JSON_FILE?}, eval_file_pattern: ${EVAL_FILE_PATTERN?} } }"
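The `${VAR?}` expansions in these commands are a standard shell guard. If the variable is unset, a non-interactive shell prints an error and stops instead of passing an empty flag value. The guard does not reject a variable that is set to an empty string. The small demonstration below uses a hypothetical variable, `DEMO_DIR`, and a hypothetical bucket path.

```shell
#!/usr/bin/env bash
# Illustrates the ${VAR?} guard used throughout the commands in this README.
# DEMO_DIR and the bucket path below are hypothetical, for this demo only.

# Prints "ok" if expanding ${DEMO_DIR?} succeeds in a subshell, else "error".
# The subshell matters: a failed ${VAR?} makes a non-interactive shell exit.
guard_result() {
  if ( echo "${DEMO_DIR?}" ) >/dev/null 2>&1; then echo ok; else echo error; fi
}

unset DEMO_DIR
echo "unset: $(guard_result)"   # an unset variable is rejected

DEMO_DIR=""
echo "empty: $(guard_result)"   # an empty value still passes ${VAR?}

DEMO_DIR="gs://my-bucket/retinanet"
echo "set:   $(guard_result)"   # a non-empty value passes
```

This is why `PARAMS_OVERRIDE=""` in the export command is accepted. If you would rather reject empty values too, the `${VAR:?}` form does that.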

Train a custom RetinaNet using a config file.

First, create a YAML config file, e.g. my_retinanet.yaml. This file specifies the parameters to override and should include at least the following fields.

# my_retinanet.yaml
type: 'retinanet'
train:
  train_file_pattern: <path to the TFRecord training data>
eval:
  eval_file_pattern: <path to the TFRecord validation data>
  val_json_file: <path to the validation annotation JSON file>
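If you prefer to create the config file from the shell, a heredoc works. The sketch below also adds a rough pre-flight check for the minimum fields. The `train:`/`eval:` nesting mirrors the `--params_override` string from the vanilla RetinaNet command. Both function names and all gs:// paths are hypothetical placeholders that you should replace with your own. The check is a plain text search, not a YAML parser, so it does not verify nesting.

```shell
#!/usr/bin/env bash
# Hypothetical helper: writes a minimal RetinaNet override config to $1.
# Replace the placeholder bucket paths with your own data locations.
write_retinanet_config() {
  local out="$1"
  cat > "$out" <<'EOF'
# my_retinanet.yaml
type: 'retinanet'
train:
  train_file_pattern: gs://my-bucket/data/train-*
eval:
  eval_file_pattern: gs://my-bucket/data/val-*
  val_json_file: gs://my-bucket/data/val_annotations.json
EOF
}

# Hypothetical pre-flight check: reports each missing minimum field on stderr
# and returns non-zero if any is absent (text search only, no YAML parsing).
check_retinanet_config() {
  local cfg="$1" key status=0
  for key in type train_file_pattern eval_file_pattern val_json_file; do
    if ! grep -Eq "^[[:space:]]*${key}:" "$cfg"; then
      echo "missing field: ${key}" >&2
      status=1
    fi
  done
  return "$status"
}

write_retinanet_config my_retinanet.yaml
check_retinanet_config my_retinanet.yaml && echo "my_retinanet.yaml looks complete"
```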

Once the YAML config file is created, launch training with the following command.

TPU_NAME="<your GCP TPU name>"
MODEL_DIR="<path to the directory to store model files>"
python ~/tpu/models/official/detection/ \
  --use_tpu=True \
  --tpu="${TPU_NAME?}" \
  --num_cores=8 \
  --model_dir="${MODEL_DIR?}" \
  --mode=train \
  --eval_after_training=True \
  --config_file="my_retinanet.yaml"

Available RetinaNet templates.

Export to SavedModel for serving

Once the training is finished, you can export the model in the SavedModel format for serving using the following command.

EXPORT_DIR="<path to the directory to store the exported model>"
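CHECKPOINT_PATH="<path to the trained checkpoint>"
USE_TPU="<true or false>"
PARAMS_OVERRIDE=""  # if any.
BATCH_SIZE="<batch size of the exported model>"
INPUT_TYPE="<type of the serving input>"
INPUT_NAME="<name of the serving input>"
INPUT_IMAGE_SIZE="<input image size>"
OUTPUT_IMAGE_INFO="<true or false>"
OUTPUT_NORMALIZED_COORDINATES="<true or false>"
python ~/tpu/models/official/detection/ \
  --export_dir="${EXPORT_DIR?}" \
  --checkpoint_path="${CHECKPOINT_PATH?}" \
  --use_tpu="${USE_TPU?}" \
  --params_override="${PARAMS_OVERRIDE?}" \
  --batch_size="${BATCH_SIZE?}" \
  --input_type="${INPUT_TYPE?}" \
  --input_name="${INPUT_NAME?}" \
  --input_image_size="${INPUT_IMAGE_SIZE?}" \
  --output_image_info="${OUTPUT_IMAGE_INFO?}" \
  --output_normalized_coordinates="${OUTPUT_NORMALIZED_COORDINATES?}"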
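To fill in `CHECKPOINT_PATH`, you can read the latest checkpoint from the `checkpoint` index file that TensorFlow writes into the model directory during training. That file contains a line such as `model_checkpoint_path: "model.ckpt-1000"`. The helper below is a hypothetical sketch, not part of the repository. It reads the index file's text on stdin and accepts both relative and absolute recorded paths. The usage comment assumes `gsutil` is installed for a gs:// `MODEL_DIR`.

```shell
#!/usr/bin/env bash
# Hypothetical helper: prints the latest checkpoint prefix recorded in a
# TensorFlow "checkpoint" index file read from stdin; $1 is the model dir.
latest_checkpoint() {
  local model_dir="$1" path
  # Extract the quoted value of the model_checkpoint_path line.
  path="$(sed -n 's/^model_checkpoint_path: "\(.*\)"$/\1/p' | head -n 1)"
  [ -n "$path" ] || return 1
  case "$path" in
    /*|gs://*) echo "$path" ;;            # already an absolute path
    *) echo "${model_dir%/}/${path}" ;;   # relative to the model directory
  esac
}

# Usage (assumes gsutil is installed and MODEL_DIR is a gs:// path):
#   CHECKPOINT_PATH="$(gsutil cat "${MODEL_DIR}/checkpoint" | latest_checkpoint "${MODEL_DIR}")"
```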
