# How to train a ResNet50 on RxRx1 using TPUs 

Colaboratory makes it easy to train models using [Cloud TPUs](https://cloud.google.com/tpu/). This notebook demonstrates how to use the code in [rxrx1-utils](https://github.com/recursionpharma/rxrx1-utils) to train a ResNet50 on the RxRx1 image set using a Colab TPU.

Be sure to select the TPU runtime (Runtime > Change runtime type > Hardware accelerator: TPU) before beginning!

In [0]:
import json
import os
import sys
import tensorflow as tf

In [0]:
if 'google.colab' in sys.modules:
    !git clone https://github.com/recursionpharma/rxrx1-utils
    sys.path.append('/content/rxrx1-utils')

    from google.colab import auth
    auth.authenticate_user()
    
from rxrx.main import main

Cloning into 'rxrx1-utils'...
remote: Enumerating objects: 54, done.

## Train

Set `MODEL_DIR` to a Google Cloud Storage path that you can write to. The code will write your checkpoints to this directory.
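
A mistyped `MODEL_DIR` is easier to catch before training starts than after the TPU has been initialized. The following is a minimal sketch of a format check. The helper name `split_gcs_path` and the example bucket are hypothetical and not part of rxrx1-utils. It only verifies that the string looks like `gs://bucket/prefix`; it does not confirm that the bucket exists or that you have write access.

```python
from urllib.parse import urlparse


def split_gcs_path(path):
    """Split 'gs://bucket/prefix' into (bucket, prefix).

    Hypothetical helper for illustration: raises ValueError if the
    path does not use the gs:// scheme or has no bucket name.
    """
    parsed = urlparse(path)
    if parsed.scheme != 'gs' or not parsed.netloc:
        raise ValueError('expected a gs://bucket/... path, got {!r}'.format(path))
    return parsed.netloc, parsed.path.lstrip('/')


# Placeholder value; substitute a bucket you can write to.
bucket, prefix = split_gcs_path('gs://my-bucket/rxrx1/run-1')
print(bucket, prefix)
```

Calling `split_gcs_path(MODEL_DIR)` right after setting `MODEL_DIR` would reject a local path such as `/content/model` before any training work begins.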

In [0]:
MODEL_DIR = 'gs://path/to/your/bucket'
URL_BASE_PATH = 'gs://rxrx1-us-central1/tfrecords/random-42'

# make sure we're in a TPU runtime
assert 'COLAB_TPU_ADDR' in os.environ

# set TPU-relevant args
tpu_grpc = 'grpc://{}'.format(os.environ['COLAB_TPU_ADDR'])
num_shards = 8  # colab uses Cloud TPU v2-8

# upload credentials to the TPU so it can read from and write to GCS
with tf.Session(tpu_grpc) as sess:
    with open('/content/adc.json') as f:
        data = json.load(f)
    tf.contrib.cloud.configure_gcs(sess, credentials=data)

main(use_tpu=True,
     tpu=tpu_grpc,
     gcp_project=None,
     tpu_zone=None,
     url_base_path=URL_BASE_PATH,
     use_cache=False,
     model_dir=MODEL_DIR,
     train_epochs=3,
     train_batch_size=512,
     num_train_images=73030,
     epochs_per_loop=1,
     log_step_count_epochs=1,
     num_cores=num_shards,
     data_format='channels_last',
     transpose_input=True,
     tf_precision='bfloat16',
     n_classes=1108,
     momentum=0.9,
     weight_decay=1e-4,
     base_learning_rate=0.2,
     warmup_epochs=5)
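
The Estimator warning in the log output for this cell reports `iterations_per_loop=142`. As a rough sanity check, the sketch below shows arithmetic that reproduces that value from the arguments passed to `main`. It assumes the step count is the integer quotient of `num_train_images` by `train_batch_size`, multiplied by `epochs_per_loop`; this matches the logged number but is an inference, not a reading of the rxrx1-utils source. The per-core batch size and total step count are likewise derived under that assumption.

```python
# Values copied from the main(...) call above.
num_train_images = 73030
train_batch_size = 512
num_cores = 8
epochs_per_loop = 1
train_epochs = 3

# Assumption: one epoch is the number of full batches in the training set.
steps_per_epoch = num_train_images // train_batch_size

# Assumption: steps run on the TPU per loop before returning to the host.
iterations_per_loop = steps_per_epoch * epochs_per_loop

# The global batch is split evenly across the TPU cores.
per_core_batch_size = train_batch_size // num_cores

# Assumption: total steps for the whole run.
total_train_steps = steps_per_epoch * train_epochs

print(steps_per_epoch, iterations_per_loop, per_core_batch_size, total_train_steps)
# prints: 142 142 64 426
```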

W0627 00:50:41.234449 140143222630272 estimator.py:1984] Estimator's model_fn (functools.partial(<function resnet_model_fn at 0x7f756b216950>, n_classes=1108, num_train_images=73030, data_format='channels_last', transpose_input=True, train_batch_size=512, iterations_per_loop=142, tf_precision='bfloat16', momentum=0.9, weight_decay=0.0001, base_learning_rate=0.2, warmup_epochs=5, model_dir='gs://recursion-tpu-training/berton/rxrx1_test/my_test_3', use_tpu=True, resnet_depth=50)) includes params argument, but params are not passed to Estimator.
W0627 00:50:41.382101 140143222630272 deprecation.py:323] From /usr/local/lib/python3.6/dist-packages/tensorflow/python/training/training_util.py:236: Variable.initialized_value (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version.
Instructions for updating:
Use Variable.read_value. Variables in 2.X are initialized automatically both in eager and graph (inside tf.defun) contexts.
W0627 00:50:41.424432 140143