<a href="https://colab.research.google.com/github/wenxuan0923/My-notes/blob/master/MNITS_TPU.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Train a Keras model on a TPU in Google Colab
### An example using the MNIST dataset with TF 2.0

Most tutorials available online use `tf.contrib.tpu.keras_to_tpu_model` to set up the distribution strategy. That function is no longer available in TF 2.0, where `tf.contrib` was removed. In this note I train a simple handwritten-digit classifier using Keras on a Colab TPU, using the `tf.distribute` API instead. Be sure to change the **runtime type** to **TPU** when running this notebook in Colab.

References:

[tf.distribute.Strategy Documentation](https://www.tensorflow.org/api_docs/python/tf/distribute/Strategy)<br>
[A helpful GitHub issue](https://github.com/huan/tensorflow-handbook-tpu/issues/1)


In [1]:
import pandas as pd
import numpy as np
import tensorflow as tf
import keras
from keras.datasets import mnist
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt

Using TensorFlow backend.


In [0]:
# Authenticates the Colab machine and also the TPU using your
# credentials so that they can access your private GCS buckets.
from google.colab import auth
auth.authenticate_user()

In [3]:
# Detect hardware
try:
  # TPU detection
  tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
  print('Running on TPU ', tpu.cluster_spec().as_dict()['worker'])
except ValueError:
  # No TPU was found. The cells below assume a TPU runtime
  # and will fail if tpu is None.
  tpu = None
  gpus = tf.config.experimental.list_logical_devices("GPU")

Running on TPU  ['10.21.193.202:8470']


In [4]:
# Select the appropriate distribution strategy
tf.config.experimental_connect_to_cluster(tpu)  # connect this runtime to the TPU worker
tf.tpu.experimental.initialize_tpu_system(tpu)  # initialize the TPU devices
strategy = tf.distribute.experimental.TPUStrategy(tpu)
print("REPLICAS: ", strategy.num_replicas_in_sync)

INFO:tensorflow:Initializing the TPU system: grpc://10.21.193.202:8470
INFO:tensorflow:Clearing out eager caches
INFO:tensorflow:Finished initializing TPU system.
INFO:tensorflow:Found TPU system:
INFO:tensorflow:*** Num TPU Cores: 8
INFO:tensorflow:*** Num TPU Workers: 1
INFO:tensorflow:*** Num TPU Cores Per Worker: 8
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:localhost/replica:0/task:0/device:CPU:0, CPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:localhost/replica:0/task:0/device:XLA_CPU:0, XLA_CPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:CPU:0, CPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:0, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:1, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:2, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:3, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:4, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:5, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:6, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:7, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU_SYSTEM:0, TPU_SYSTEM, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:XLA_CPU:0, XLA_CPU, 0, 0)


REPLICAS:  8


### Data preprocessing

In [15]:
# The original images are stored as 3D tensors: (samples, height, width)
(x_train, y_train), (x_test, y_test) = mnist.load_data()
print(x_train.shape)
print(x_test.shape)

(60000, 28, 28)
(10000, 28, 28)


In [0]:
def preprocess_input(x):
    desired_shape = (-1, 28, 28, 1)
    return x.reshape(desired_shape) / 255.0

def preprocess_output(y):
    return keras.utils.to_categorical(y)

In [8]:
x_train, x_test = map(preprocess_input, [x_train, x_test])
y_train, y_test = map(preprocess_output, [y_train, y_test])

print(x_train.shape)
print(y_train.shape)

(60000, 28, 28, 1)
(60000, 10)
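
To make the two preprocessing steps concrete, here is a minimal sketch that applies the same reshape, scaling and one-hot encoding to a tiny stand-in array. The random `images` and the `labels` values are made up for illustration, and `np.eye` is used here as a plain-NumPy stand-in for `keras.utils.to_categorical`, with the number of classes fixed at 10 by assumption.

```python
import numpy as np

# Four fake 28x28 grayscale images with uint8 pixels, standing in for MNIST.
rng = np.random.default_rng(0)
images = rng.integers(0, 256, size=(4, 28, 28), dtype=np.uint8)
labels = np.array([3, 0, 9, 3])  # hypothetical digit labels

# Add a trailing channel axis and scale the pixels to [0, 1],
# as preprocess_input does.
x = images.reshape(-1, 28, 28, 1) / 255.0

# One-hot encode the labels: row i has a single 1 in column labels[i].
one_hot = np.eye(10, dtype=np.float32)[labels]

print(x.shape, x.dtype)
print(one_hot.shape)
```

The trailing axis of size 1 is the channel dimension that `Conv2D` expects for grayscale input.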


In [9]:
X_train, X_val, y_train, y_val = train_test_split(x_train, y_train, test_size=0.2, random_state=42)
print(X_train.shape)
print(X_val.shape)
print(y_train.shape)
print(y_val.shape)

(48000, 28, 28, 1)
(12000, 28, 28, 1)
(48000, 10)
(12000, 10)


### Keras model: 3 convolutional layers, 2 dense layers

In [0]:
from keras import models, layers
def create_model():
    model = models.Sequential()
    model.add(layers.Conv2D(32, (3, 3), activation='relu'))
    model.add(layers.MaxPool2D(2, 2))
    model.add(layers.Dropout(0.25))

    model.add(layers.Conv2D(32, (3, 3), activation='relu'))
    model.add(layers.MaxPool2D(2, 2))
    model.add(layers.Dropout(0.25))

    model.add(layers.Conv2D(32, (3, 3), activation='relu'))
    model.add(layers.MaxPool2D(2, 2))
    model.add(layers.Dropout(0.25))

    model.add(layers.Flatten())
    model.add(layers.Dense(251, activation='relu'))
    model.add(layers.Dropout(0.5))
    model.add(layers.Dense(10, activation='softmax'))
    return model
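
As a sanity check on the architecture, the spatial size after each layer can be worked out by hand. The sketch below assumes the Keras defaults used in `create_model` (`padding='valid'` and stride 1 for `Conv2D`, pool size 2 with stride 2 for `MaxPool2D`); the helper names are illustrative only.

```python
def conv_valid(size, kernel=3):
    """Output width of a 'valid' convolution with stride 1."""
    return size - kernel + 1

def max_pool(size, pool=2, stride=2):
    """Output width of a 'valid' max pooling layer."""
    return (size - pool) // stride + 1

size = 28
sizes = [size]
for _ in range(3):  # three Conv2D + MaxPool2D blocks
    size = conv_valid(size)
    sizes.append(size)
    size = max_pool(size)
    sizes.append(size)

filters = 32
flat_features = size * size * filters  # input width of the first Dense layer

# Trainable parameters: kernel weights plus one bias per output unit.
conv_params = [3 * 3 * c_in * filters + filters for c_in in (1, 32, 32)]
dense_params = [flat_features * 251 + 251, 251 * 10 + 10]
total_params = sum(conv_params) + sum(dense_params)

print(sizes)
print(flat_features, total_params)
```

Under these assumptions the feature map shrinks to 1x1 after the third pooling layer, so `Flatten` passes only 32 features to the dense layers.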

### Model Parameters

In [0]:
total_train = len(X_train)
total_val = len(X_val)
BATCH_SIZE = 64 * strategy.num_replicas_in_sync  # Global batch size.
# The global batch size will be automatically sharded across all
# replicas by the tf.data.Dataset API. A single TPU has 8 cores.
# The best practice is to scale the batch size by the number of
# replicas (cores). The learning rate should be increased as well.
LEARNING_RATE = 0.01
LEARNING_RATE_EXP_DECAY = 0.6 if strategy.num_replicas_in_sync == 1 else 0.7
# The per-epoch learning rate is computed later in the callback function.
# A decay factor of 0.7 instead of 0.6 means a slower decay, i.e. the
# learning rate stays higher for longer.
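
The numbers these parameters produce for this run can be worked out directly. This is a small arithmetic sketch that hard-codes the 8 replicas and the 48000/12000 split reported earlier, rather than reading them from `strategy` and the arrays.

```python
NUM_REPLICAS = 8        # strategy.num_replicas_in_sync in the run above
PER_REPLICA_BATCH = 64
TOTAL_TRAIN = 48000
TOTAL_VAL = 12000

global_batch = PER_REPLICA_BATCH * NUM_REPLICAS

# Integer division drops the final partial batch of each epoch.
steps_per_epoch = TOTAL_TRAIN // global_batch
validation_steps = TOTAL_VAL // global_batch
unused_train = TOTAL_TRAIN - steps_per_epoch * global_batch

print(global_batch, steps_per_epoch, validation_steps, unused_train)
```

With a global batch of 512, each epoch covers 93 full training batches and leaves 384 training samples out of the step count.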

### Callbacks

In [0]:
# set up learning rate decay
lr_decay = keras.callbacks.LearningRateScheduler(
    lambda epoch: LEARNING_RATE * LEARNING_RATE_EXP_DECAY**epoch,
    verbose=True)
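
The schedule the callback applies can be previewed without training. The sketch below evaluates the same expression as the lambda for a zero-based epoch index, using the decay factor of 0.7 that applies when there is more than one replica; `scheduled_lr` is an illustrative helper name, not part of Keras.

```python
LEARNING_RATE = 0.01
LEARNING_RATE_EXP_DECAY = 0.7  # value used when there is more than one replica

def scheduled_lr(epoch, base_lr=LEARNING_RATE, decay=LEARNING_RATE_EXP_DECAY):
    """Learning rate for a zero-based epoch index."""
    return base_lr * decay ** epoch

schedule = [scheduled_lr(epoch) for epoch in range(14)]
for epoch, lr in enumerate(schedule):
    print(f"epoch {epoch + 1:2d}: lr = {lr:.6f}")
```

The first epoch uses the base rate of 0.01, and each later epoch multiplies the previous rate by 0.7.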

### Train and validate the model

In [13]:
with strategy.scope():
  model = create_model()
  model.compile(
      optimizer='rmsprop',
      loss='categorical_crossentropy',
      metrics=['accuracy']
  )

model.fit(
    X_train, 
    y_train,
    epochs = 14,
    steps_per_epoch = total_train//BATCH_SIZE,
    callbacks = [lr_decay],
    validation_data = (X_val, y_val),
    validation_steps = total_val//BATCH_SIZE,
)

Train on 48000 samples, validate on 12000 samples
Epoch 1/14

Epoch 00001: LearningRateScheduler setting learning rate to 0.01.
Epoch 2/14

Epoch 00002: LearningRateScheduler setting learning rate to 0.006999999999999999.
Epoch 3/14

Epoch 00003: LearningRateScheduler setting learning rate to 0.0049.
Epoch 4/14

Epoch 00004: LearningRateScheduler setting learning rate to 0.003429999999999999.
Epoch 5/14

Epoch 00005: LearningRateScheduler setting learning rate to 0.0024009999999999995.
Epoch 6/14

Epoch 00006: LearningRateScheduler setting learning rate to 0.0016806999999999994.
Epoch 7/14

Epoch 00007: LearningRateScheduler setting learning rate to 0.0011764899999999997.
Epoch 8/14

Epoch 00008: LearningRateScheduler setting learning rate to 0.0008235429999999996.
Epoch 9/14

Epoch 00009: LearningRateScheduler setting learning rate to 0.0005764800999999997.
Epoch 10/14

Epoch 00010: LearningRateScheduler setting learning rate to 0.0004035360699999998.
Epoch 11/14

Epoch 00011: Learnin

<keras.callbacks.callbacks.History at 0x7f7ab1280240>

In [0]:
model.save_weights('./mnist_TPU.h5', overwrite=True)