# How to train a ResNet50 on RxRx1 using TPUs 

Colaboratory makes it easy to train models using [Cloud TPUs](https://cloud.google.com/tpu/), and this notebook demonstrates how to use the code in [rxrx1-utils](https://github.com/recursionpharma/rxrx1-utils) to train a ResNet50 on the RxRx1 image set using a Colab TPU.

Be sure to select the TPU runtime before beginning!

In [0]:
import json
import os
import sys
import tensorflow as tf

In [2]:
if 'google.colab' in sys.modules:
    !git clone https://github.com/recursionpharma/rxrx1-utils
    sys.path.append('/content/rxrx1-utils')

    from google.colab import auth
    auth.authenticate_user()
    
from rxrx.main import main

Cloning into 'rxrx1-utils'...
remote: Enumerating objects: 99, done.[K
remote: Counting objects: 100% (99/99), done.[K
remote: Compressing objects: 100% (53/53), done.[K
remote: Total 99 (delta 48), reused 92 (delta 42), pack-reused 0[K
Unpacking objects: 100% (99/99), done.


## Train

Set `MODEL_DIR` to a Google Cloud Storage bucket that you can write to. The code writes your checkpoints to this directory.
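Because training fails only after the TPU session starts if `MODEL_DIR` is wrong, a quick local check can save a round trip. The sketch below is a hypothetical helper (`check_model_dir` is not part of rxrx1-utils) that only validates the *form* of the path: it confirms a `gs://` scheme, a non-empty bucket name, and that the placeholder value has been replaced. It does not verify that you actually have write permission on the bucket.

```python
# Hypothetical helper, not part of rxrx1-utils: checks the format of MODEL_DIR only.
# It does NOT verify write permission on the bucket.
PLACEHOLDER = 'gs://path/to/your/bucket'


def check_model_dir(model_dir):
    """Return the bucket name, or raise ValueError if model_dir looks unusable."""
    if not model_dir.startswith('gs://'):
        raise ValueError('MODEL_DIR must be a gs:// path, got %r' % model_dir)
    # The bucket is the first path component after the scheme.
    bucket = model_dir[len('gs://'):].split('/', 1)[0]
    if not bucket:
        raise ValueError('MODEL_DIR has no bucket name: %r' % model_dir)
    if model_dir.rstrip('/') == PLACEHOLDER:
        raise ValueError('Replace the placeholder MODEL_DIR with your own bucket')
    return bucket


print(check_model_dir('gs://my-bucket/rxrx1/run1'))  # my-bucket
```

Running it on the placeholder value from the next cell raises a `ValueError`, which is the point: catch the unedited default before launching training.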

In [3]:
MODEL_DIR = 'gs://path/to/your/bucket'
URL_BASE_PATH = 'gs://rxrx1-us-central1/tfrecords/random-42'

# make sure we're in a TPU runtime
assert 'COLAB_TPU_ADDR' in os.environ

# set TPU-relevant args
tpu_grpc = 'grpc://{}'.format(os.environ['COLAB_TPU_ADDR'])
num_shards = 8  # colab uses Cloud TPU v2-8

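# upload credentials to the TPU so it can read/write GCS;
# adc.json holds the credentials written by auth.authenticate_user()
with tf.Session(tpu_grpc) as sess:
    with open('/content/adc.json') as f:
        data = json.load(f)
    tf.contrib.cloud.configure_gcs(sess, credentials=data)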

tf.logging.set_verbosity(tf.logging.INFO)

main(use_tpu=True,
     tpu=tpu_grpc,
     gcp_project=None,
     tpu_zone=None,
     url_base_path=URL_BASE_PATH,
     use_cache=False,
     model_dir=MODEL_DIR,
     train_epochs=1,
     train_batch_size=512,
     num_train_images=73030,
     epochs_per_loop=1,
     log_step_count_epochs=1,
     num_cores=num_shards,
     data_format='channels_last',
     transpose_input=True,
     tf_precision='bfloat16',
     n_classes=1108,
     momentum=0.9,
     weight_decay=1e-4,
     base_learning_rate=0.2,
     warmup_epochs=5)
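The log that follows reports `iterations_per_loop=142`. The worked arithmetic below shows where that number plausibly comes from, assuming the steps per epoch are computed with integer (floor) division of `num_train_images` by `train_batch_size`, and assuming, as is conventional for `TPUEstimator`, that `train_batch_size` is the global batch split evenly across the `num_cores` shards. The exact formula inside rxrx1-utils is an assumption here; only the value 142 is confirmed by the log.

```python
# Worked arithmetic for the arguments passed to main() above.
# Assumptions: floor division for steps per epoch, and train_batch_size
# being the global batch shared across all TPU cores.
num_train_images = 73030
train_batch_size = 512  # global batch size
num_cores = 8           # Cloud TPU v2-8
epochs_per_loop = 1

steps_per_epoch = num_train_images // train_batch_size     # 142 (512 * 142 = 72704)
iterations_per_loop = steps_per_epoch * epochs_per_loop     # steps run on-device per loop
per_core_batch = train_batch_size // num_cores              # examples per TPU core per step

print(steps_per_epoch, iterations_per_loop, per_core_batch)  # 142 142 64
```

Keeping `train_batch_size` divisible by `num_cores` matters because each core must receive the same number of examples per step.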

W0627 19:53:08.003671 139758653511552 deprecation_wrapper.py:119] From /content/rxrx1-utils/rxrx/main.py:280: The name tf.logging.info is deprecated. Please use tf.compat.v1.logging.info instead.

I0627 19:53:08.005592 139758653511552 main.py:280] tpu: grpc://10.106.194.154:8470
I0627 19:53:08.010348 139758653511552 main.py:283] gcp_project: None
W0627 19:53:10.223041 139758653511552 estimator.py:1984] Estimator's model_fn (functools.partial(<function resnet_model_fn at 0x7f1be104d8c8>, n_classes=1108, num_train_images=73030, data_format='channels_last', transpose_input=True, train_batch_size=512, iterations_per_loop=142, tf_precision='bfloat16', momentum=0.9, weight_decay=0.0001, base_learning_rate=0.2, warmup_epochs=5, model_dir='gs://recursion-tpu-training/berton/rxrx1_test/my_test', use_tpu=True, resnet_depth=50)) includes params argument, but params are not passed to Estimator.
I0627 19:53:10.225781 139758653511552 estimator.py:209] Using config: {'_model_dir': 'gs://recursion-tpu