[View in Colaboratory](https://colab.research.google.com/github/windpuppy/tensorflow/blob/master/test_tpu.ipynb)

In [8]:
import os
import pprint
import tensorflow as tf
import numpy as np
import time

# The TPU endpoint is read from the COLAB_TPU_ADDR environment variable; a
# missing variable is treated as "no TPU runtime attached".
if 'COLAB_TPU_ADDR' not in os.environ:
  # Raise instead of calling exit(): inside a notebook kernel exit() requests a
  # kernel shutdown but does not abort the running cell, so execution would
  # fall through to the os.environ lookup below and die with a KeyError.
  raise RuntimeError('ERROR: Not connected to a TPU runtime; please see the first cell in this notebook for instructions!')

tpu_address = 'grpc://' + os.environ['COLAB_TPU_ADDR']
print('TPU address is', tpu_address)

# Short-lived session used only to enumerate the devices behind the endpoint;
# the `with` block closes it again before anything else runs.
with tf.Session(tpu_address) as session:
  devices = session.list_devices()

print('TPU devices:', len(devices), "in total")
print('')

#
# small test
#

def add_op(x, y):
  """Return ``x + y``.

  The operands are combined with the ``+`` operator, so any pair of values
  that supports it (numbers, arrays, tensors, ...) is accepted.
  """
  total = x + y
  return total
  
# Graph inputs: two float32 vectors of length 10, fed at run time.
x = tf.placeholder(tf.float32, [10,])
y = tf.placeholder(tf.float32, [10,])
# Wrap add_op with tf.contrib.tpu.rewrite to obtain the ops executed below.
tpu_ops = tf.contrib.tpu.rewrite(add_op, [x, y])
session = tf.Session(tpu_address)

try:
  print('Initializing...')
  session.run(tf.contrib.tpu.initialize_system())
  print('Running ops')
  print(session.run(tpu_ops, {x: np.arange(10), y: np.arange(10)}))
finally:
  # For now, TPU sessions must be shutdown separately from
  # closing the session.
  print('Shutting down...')
  try:
    session.run(tf.contrib.tpu.shutdown_system())
  finally:
    # Release the session even when the shutdown op itself raises (for
    # example because initialization never succeeded); otherwise the
    # connection would be leaked.
    session.close()
print('')
  
#
# profiler test
#

# Side length of the square [N, N] matrices multiplied in the benchmark.
N = 4096
# Number of times the matrix product is repeated via tf.contrib.tpu.repeat.
COUNT = 100

def flops():
  """Build the benchmark computation and return its summed result.

  Two [N, N] tensors are drawn with tf.random_uniform, then a matrix product
  is applied COUNT times through tf.contrib.tpu.repeat. The values returned
  by the repeat are collapsed with tf.reduce_sum.
  """
  lhs = tf.random_uniform([N, N])
  rhs = tf.random_uniform([N, N])

  def _matmul(a, b):
    # Contract axis 1 of `a` with axis 0 of `b`. `b` is handed back unchanged
    # so each iteration receives the same pair structure it produced.
    product = tf.tensordot(a, b, axes=[[1], [0]])
    return product, b

  repeated = tf.contrib.tpu.repeat(COUNT, _matmul, [lhs, rhs])
  return tf.reduce_sum(repeated)
  
# Number of shards the benchmark is replicated over. It is used both to build
# the graph and to scale the FLOP count, so the two cannot drift apart.
NUM_SHARDS = 8

tpu_ops = tf.contrib.tpu.batch_parallel(flops, [], num_shards=NUM_SHARDS)
session = tf.Session(tpu_address)

try:
  print('Initializing...')
  session.run(tf.contrib.tpu.initialize_system())
  print('Profiling')
  # perf_counter is monotonic, so the measurement cannot be skewed by
  # wall-clock adjustments the way time.time() can.
  # NOTE(review): this is the first execution of tpu_ops, so the timing
  # presumably includes one-time compilation -- confirm before quoting numbers.
  start = time.perf_counter()
  session.run(tpu_ops)
  elapsed = time.perf_counter() - start
  # Every shard performs COUNT products of two N x N matrices, each costing
  # 2 * N^3 floating point operations.
  total_flops = NUM_SHARDS * COUNT * 2 * N * N * N
  print(elapsed, 'seconds, TFlops: {:.2f}'.format(1e-12 * total_flops / elapsed))
finally:
  # TPU system shutdown is a separate step from closing the session.
  print('Shutting down...')
  try:
    session.run(tf.contrib.tpu.shutdown_system())
  finally:
    # Release the session even when the shutdown op itself raises.
    session.close()

TPU address is grpc://10.63.130.250:8470
TPU devices: 12 in total

Initializing...
Running ops
[array([ 0.,  2.,  4.,  6.,  8., 10., 12., 14., 16., 18.], dtype=float32)]
Shutting down...

Initializing...
Profiling
1.035811424255371 seconds, TFlops: 106.15
Shutting down...
