In [0]:
#Dependencies
import tensorflow as tf
import tensorflow.keras.backend as ker
from tensorflow.contrib.tpu.python.tpu import keras_support
from tensorflow.keras.layers import Input, Conv2D, BatchNormalization, Activation, AveragePooling2D, Dense, Dropout, Flatten
from tensorflow.keras.datasets import cifar10
from tensorflow.keras.models import Model, Sequential
from tensorflow.keras.utils import to_categorical
import numpy as np
import os

In [21]:
def basic_mlp(input,units):
  """Apply one hidden stage: Dense -> BatchNormalization -> ReLU -> Dropout(0.5).

  Args:
    input: tensor passed to the Dense layer.
    units: width handed to the Dense layer.

  Returns:
    The tensor produced by the final Dropout layer.
  """
  stage_layers = (
      Dense(units),
      BatchNormalization(),
      Activation("relu"),
      Dropout(0.5),
  )
  out = input
  # Thread the tensor through each layer in declaration order.
  for layer in stage_layers:
    out = layer(out)
  return out

def create_mlp():
  """Build the MLP classifier: 32*32*3 inputs -> 512 -> 256 -> 128 -> 64 -> 10-way softmax.

  Each hidden stage is created by basic_mlp and receives the output of the
  previous stage. Feeding every stage the raw input instead would leave only
  the last (64-unit) stage connected to the output layer.

  Returns:
    An uncompiled Model mapping the flattened input to 10 softmax outputs.
  """
  inputs = Input((32*32*3,))
  x = inputs
  # Chain the hidden stages so each one consumes the previous stage's output.
  for units in (512, 256, 128, 64):
    x = basic_mlp(x, units)
  outputs = Dense(10, activation="softmax")(x)
  return Model(inputs, outputs)

def main():
  """Train the MLP from create_mlp on flattened CIFAR-10 training images on a Colab TPU.

  Reads the TPU address from the COLAB_TPU_ADDR environment variable, converts
  the compiled Keras model with keras_to_tpu_model, and fits it for 15 epochs
  with a batch size of 1024. Returns nothing.
  """
  # Reset Keras backend state before building a new model.
  ker.clear_session()

  # CIFAR-10: only the training split is used; the test split is discarded.
  (train_images, train_labels), (_, _) = cifar10.load_data()
  scaled_images = train_images / 255.0
  flat_images = scaled_images.reshape(50000, -1)  # one row per training image
  onehot_labels = to_categorical(train_labels)

  # Model building
  mlp = create_mlp()

  # Adam wrapped in CrossShardOptimizer, then compiled into the Keras model.
  sharded_adam = tf.contrib.tpu.CrossShardOptimizer(
      tf.train.AdamOptimizer(learning_rate=1e-3))
  mlp.compile(sharded_adam, loss="categorical_crossentropy", metrics=["acc"])

  # Connect Colab to the TPU in this environment.
  tpu_address = "grpc://" + os.environ["COLAB_TPU_ADDR"]

  # Resolve the TPU and wrap it in a distribution strategy.
  resolver = tf.contrib.cluster_resolver.TPUClusterResolver(tpu_address)
  tpu_strategy = tf.contrib.tpu.TPUDistributionStrategy(resolver)

  # Convert the compiled Keras model to its TPU counterpart and train it.
  tpu_model = tf.contrib.tpu.keras_to_tpu_model(mlp, strategy=tpu_strategy)
  tpu_model.fit(flat_images, onehot_labels, batch_size=1024, epochs=15)
  
# Entry point: call main() when this code runs as the top-level module
# (a notebook cell also executes with __name__ set to "__main__").
if __name__=="__main__":
  main()

INFO:tensorflow:Querying Tensorflow master (grpc://10.69.143.154:8470) for TPU system metadata.
INFO:tensorflow:Found TPU system:
INFO:tensorflow:*** Num TPU Cores: 8
INFO:tensorflow:*** Num TPU Workers: 1
INFO:tensorflow:*** Num TPU Cores Per Worker: 8
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:CPU:0, CPU, -1, 1058564301551203744)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:XLA_CPU:0, XLA_CPU, 17179869184, 15312325690895376758)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:0, TPU, 17179869184, 6677000421848605297)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:1, TPU, 17179869184, 16733970090192387138)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:2, TPU, 17179869184, 11338638165492359624)
INFO:tensorflow:*** Available Device: _DeviceAttribute