Accompanying blog post: Distributed Training in TensorFlow with AI Platform & Docker

This repository provides code to train an image classification model in a distributed manner with tf.distribute.MirroredStrategy (single host, multiple GPUs) in TensorFlow 2.4.1. We use the following MLOps stack to do this:

  • Docker to create a custom image so that the code is reproducible.
  • AI Platform Training jobs (on GCP) to run the custom Docker image on multiple GPUs. The service also handles automatic provisioning and de-provisioning of resources.
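
MirroredStrategy keeps one model replica per GPU visible to TensorFlow, falls back to a single CPU replica when no GPU is visible, and splits each global batch evenly across the replicas. The sketch below gives a rough estimate of the replica count on a host from `CUDA_VISIBLE_DEVICES` (or `nvidia-smi -L`) and derives the global batch size. The helper names and the per-replica batch size of 64 are hypothetical and not taken from this repository's trainer, which may size its batches differently.

```shell
#!/usr/bin/env bash
# Hypothetical helpers (not part of this repository): roughly estimate how many
# replicas tf.distribute.MirroredStrategy would create and the global batch size.

# Count GPUs visible to TensorFlow on this host.
num_visible_gpus() {
  if [ "${CUDA_VISIBLE_DEVICES+set}" = set ]; then
    # An explicit device list limits visibility: "0,1,2,3" -> 4, "" -> 0.
    printf '%s\n' "$CUDA_VISIBLE_DEVICES" | tr ',' '\n' | grep -c .
  elif command -v nvidia-smi >/dev/null 2>&1; then
    # Otherwise every GPU the driver lists ("GPU 0: ...") is visible.
    nvidia-smi -L | grep -c '^GPU'
  else
    echo 0
  fi
}

# One replica per visible GPU; a single (CPU) replica when there is none.
num_replicas() {
  local n
  n=$(num_visible_gpus)
  if [ "$n" -gt 0 ]; then echo "$n"; else echo 1; fi
}

# Each global batch is split evenly across replicas, so
# global batch = per-replica batch x number of replicas.
global_batch_size() {
  local per_replica=$1
  echo $(( per_replica * $(num_replicas) ))
}

echo "replicas: $(num_replicas), global batch for 64 per replica: $(global_batch_size 64)"
```

For example, with `CUDA_VISIBLE_DEVICES=0,1,2,3` the helpers report 4 replicas and a global batch of 256 for 64 examples per replica.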

Training this way, rather than in a Jupyter Notebook environment, has the following advantages:

  • Resources (GPUs, CPUs, memory, etc.) are fully managed by the service we use to orchestrate the training workflow, in this case AI Platform.
  • Resources are automatically provisioned and de-provisioned by the service, so you do not pay for idle machines.

Other recipes included:

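  • Mixed-precision training (this only speeds things up on GPUs with Tensor Cores, such as the V100; elsewhere it brings little or no benefit).
  • Serialization of already resized and augmented images into TFRecords. Because resizing and augmentation happen once, when the TFRecords are written, these ops are removed from the data loading pipeline, which makes input loading more efficient.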
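
Tensor Cores are present on NVIDIA GPUs with compute capability 7.0 or higher (for example, V100 is 7.0, T4 is 7.5, and A100 is 8.0). A minimal sketch for checking the GPUs on a machine before enabling mixed precision follows. The helper names are hypothetical, and the `compute_cap` query field it relies on is only supported by recent NVIDIA drivers.

```shell
#!/usr/bin/env bash
# Hypothetical helper: does a GPU with the given compute capability
# (e.g. "7.0") have Tensor Cores? They require capability >= 7.0.
has_tensor_cores() {
  local major=${1%%.*}   # "7.5" -> "7"
  [ "$major" -ge 7 ]
}

# Print each visible GPU with a verdict. Needs nvidia-smi from a recent
# driver; older drivers do not support the compute_cap query field.
report_gpus() {
  nvidia-smi --query-gpu=name,compute_cap --format=csv,noheader |
    while IFS=',' read -r name cap; do
      cap=${cap# }   # drop the space that follows the CSV comma
      if has_tensor_cores "$cap"; then verdict=yes; else verdict=no; fi
      echo "$name (compute capability $cap): Tensor Cores: $verdict"
    done
}

if command -v nvidia-smi >/dev/null 2>&1; then
  report_gpus
fi
```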

Steps to run the code 💻

Note: You need a billing-enabled GCP project to follow these steps in full.

We will use an inexpensive AI Platform Notebook instance as our staging machine. From it, we will build our custom Docker image, push it to Google Container Registry (GCR), and submit a training job to AI Platform. We will also use this instance to create TensorFlow Records (TFRecords) from the original dataset (Cats vs. Dogs in this case) and upload them to a GCS Bucket. AI Platform Notebooks come pre-configured with many useful Python libraries, Linux packages such as Docker, and GCP command-line tools such as gcloud.

(I used an n1-standard-4 instance with TensorFlow 2.4 as the base image, which costs $0.141 per hour.)
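
AI Platform Notebooks ship with the required tools, but on any other staging machine it is worth confirming they are installed before you start. A minimal sketch, assuming the tools needed are `docker`, `gcloud`, and `gsutil` (the ones this workflow uses); the helper name is hypothetical.

```shell
#!/usr/bin/env bash
# Hypothetical helper: report which required command-line tools are missing.
# Prints nothing and returns 0 when every tool is on PATH.
missing_tools() {
  local status=0 tool
  for tool in "$@"; do
    if ! command -v "$tool" >/dev/null 2>&1; then
      echo "missing: $tool"
      status=1
    fi
  done
  return "$status"
}

# Tools used by the steps in this README.
missing_tools docker gcloud gsutil || echo "Install the tools above before continuing."
```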

  1. Set the following environment variables and make the shell scripts executable:

    $ export PROJECT_ID=your-gcp-project-id
    $ export BUCKET_NAME=gs://unique-gcs-bucket-name
    $ chmod +x scripts/*.sh
  2. Create a GCS Bucket:

    $ gsutil mb ${BUCKET_NAME}

    You can optionally specify the location (region) in which to create the bucket:

    $ gsutil mb -l asia-east1 ${BUCKET_NAME}

    If all of your resources are provisioned in that same region, you will likely get a slight performance boost.

  3. Create TFRecords and upload them to the GCS Bucket.

    $ cd scripts
    $ source
  4. Build the custom Docker image and run it locally:

    $ cd ~/Distributed-Training-in-TensorFlow-2-with-AI-Platform
    $ source scripts/
  5. If everything looks good, interrupt the local training run with Ctrl-C and submit the job to the cloud:

    $ source scripts/

... and done!
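
The steps above delegate to shell scripts in scripts/. As a rough, hypothetical sketch of the kind of commands steps 2, 4, and 5 involve, the functions below print (rather than run) `gsutil mb`, `docker build`/`docker run`/`docker push`, and `gcloud ai-platform jobs submit training`. The image name `tf-distributed-training`, the default region `us-central1`, the job-name prefix, and the function names are assumptions, not the repository's actual scripts. Set `DRY_RUN=0` to execute the commands instead.

```shell
#!/usr/bin/env bash
# Hypothetical sketch (not the repository's scripts) of the commands behind
# steps 2, 4 and 5. Commands are printed by default; set DRY_RUN=0 to run them.
DRY_RUN=${DRY_RUN:-1}
REGION=${REGION:-us-central1}   # assumption: keep bucket and job in one region
run() {
  if [ "$DRY_RUN" = 1 ]; then echo "+ $*"; else "$@"; fi
}

# gsutil needs a gs:// URL; accept BUCKET_NAME with or without the prefix.
bucket_uri() {
  echo "gs://${1#gs://}"
}

# Image tag in Google Container Registry (the image name is an assumption).
image_uri() {
  echo "gcr.io/$1/tf-distributed-training:latest"
}

# AI Platform job names must start with a letter and may contain only
# letters, digits and underscores, hence the underscore-separated timestamp.
job_name() {
  echo "${1}_$(date +%Y%m%d_%H%M%S)"
}

# Step 2: create the bucket, optionally pinned to a region.
create_bucket() {
  if [ -n "${2:-}" ]; then
    run gsutil mb -l "$2" "$(bucket_uri "$1")"
  else
    run gsutil mb "$(bucket_uri "$1")"
  fi
}

# Step 4: build the image and run it locally. Assumes the image's entrypoint
# starts training; --gpus needs the NVIDIA Container Toolkit on the host.
build_and_run_locally() {
  local image
  image=$(image_uri "$1")
  run docker build -t "$image" .
  run docker run --rm --gpus all "$image"
}

# Step 5: push the image to GCR and submit it as an AI Platform training job.
submit_to_cloud() {
  local image job
  image=$(image_uri "$1")
  job=$(job_name "$2")
  run gcloud auth configure-docker --quiet
  run docker push "$image"
  run gcloud ai-platform jobs submit training "$job" \
    --region="$REGION" \
    --master-image-uri="$image" \
    --config=config.yaml
  run gcloud ai-platform jobs stream-logs "$job"
}

create_bucket "${BUCKET_NAME:-unique-gcs-bucket-name}" "$REGION"
build_and_run_locally "${PROJECT_ID:-your-gcp-project-id}"
submit_to_cloud "${PROJECT_ID:-your-gcp-project-id}" cats_dogs
```

Creating the bucket and submitting the job with the same `REGION` follows the same-region advice given for bucket creation.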

Find my TensorBoard logs online here. The training artifacts (SavedModels, TensorBoard logs, and TFRecords) can be found here.

About the files 🍖

    ├── config.yaml: Specifies the type of machine to use to run training on Cloud.
    ├── scripts
    │   ├── Trains on Cloud with the given specifications. 
    │   ├── Trains locally. 
    │   └── Creates TFRecords and uploads them to a GCS Bucket. 
    └── trainer
        ├── Specifies hyperparameters and other constants. 
        ├── Driver code for creating TFRecords. It is called by ``. 
        ├── Contains utilities for the data loader. 
        ├── Contains the actual data loading and model training code.
        ├── Contains model building utilities. 
        ├── Parses the command-line arguments given and starts an experiment.
        └── Utilities for creating TFRecords. 
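
config.yaml declares the machine that AI Platform provisions for the job. As a hedged sketch, a single-host, multi-GPU configuration (the setup MirroredStrategy targets) can be written as follows. The machine type `n1-standard-16`, the GPU type, the GPU count, and the `write_config` helper are placeholder assumptions, not this repository's values.

```shell
#!/usr/bin/env bash
# Hypothetical helper: write an AI Platform Training config that requests one
# machine with several GPUs. All values are placeholders.
# Usage: write_config PATH [GPU_COUNT] [GPU_TYPE]
write_config() {
  local path=$1 count=${2:-4} gpu_type=${3:-NVIDIA_TESLA_V100}
  cat > "$path" <<EOF
trainingInput:
  scaleTier: CUSTOM
  masterType: n1-standard-16
  masterConfig:
    acceleratorConfig:
      count: $count
      type: $gpu_type
EOF
}

# Preview the generated config without touching the working directory.
tmp=$(mktemp)
write_config "$tmp"
cat "$tmp"
rm -f "$tmp"
```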

Acknowledgements 🙌

I am thankful to the ML-GDE program for providing generous GCP support.