In [0]:
# Colab line magic: select the TensorFlow 2.x release line for this runtime.
%tensorflow_version 2.x

In [8]:
import os
import time
import numpy as np
import tensorflow as tf
import tensorflow.keras.layers as l

# The TPU worker address is read from the COLAB_TPU_ADDR environment variable
# and handed to the resolver with a grpc:// scheme prefix.
resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
    tpu=f"grpc://{os.environ['COLAB_TPU_ADDR']}"
)
tf.config.experimental_connect_to_cluster(resolver)
# Kept as the cell's final bare expression so its return value (the TPU
# topology object) is displayed as the cell output.
tf.tpu.experimental.initialize_tpu_system(resolver)





INFO:tensorflow:Initializing the TPU system: 10.64.225.42:8470


INFO:tensorflow:Initializing the TPU system: 10.64.225.42:8470


INFO:tensorflow:Clearing out eager caches


INFO:tensorflow:Clearing out eager caches


INFO:tensorflow:Finished initializing TPU system.


INFO:tensorflow:Finished initializing TPU system.


<tensorflow.python.tpu.topology.Topology at 0x7fdfc2c87668>

In [0]:
# Benchmark configuration shared by the cells below.
num_img = 16          # leading (batch) dimension of the synthetic input
num_filter = 128      # `filters` argument passed to both convolution layers
kernel = (3,3)        # `kernel_size` argument passed to both convolution layers
dtype = tf.bfloat16   # dtype of the synthetic input and of both layers
# Number of timed forward calls: a 5000-image budget expressed in whole
# batches of `num_img`, plus 10 extra calls.
num_it = 10 + 5000//num_img

In [10]:
with tf.device('/job:worker/replica:0/task:0/device:TPU:0'):
    # A single fixed random batch is reused for every call, so no input
    # pipeline cost is part of the measurements.
    synthetic_data = tf.random.normal((num_img,224,224,128),
                                      mean=0.0,
                                      stddev=0.5,
                                      dtype=dtype)

    # The two layers under comparison, configured identically.
    depsep_conv_2d = l.SeparableConv2D(filters=num_filter, kernel_size=kernel, dtype=dtype)
    conv_2d = l.Conv2D(filters=num_filter, kernel_size=kernel, dtype=dtype)

    @tf.function
    def conv_2d_forward(data):
        """Apply the regular Conv2D layer to `data` and return the result."""
        return conv_2d(data)

    @tf.function
    def depsep_conv_2d_forward(data):
        """Apply the SeparableConv2D layer to `data` and return the result."""
        return depsep_conv_2d(data)

    def _mean_sync_call_time(forward, data, runs=30):
        """Return the mean wall-clock seconds of one `forward(data).numpy()` call.

        Args:
            forward: callable taking `data` and returning a tensor.
            data: input passed unchanged to every call.
            runs: number of back-to-back calls to average over.

        Returns:
            Elapsed seconds divided by `runs`. Each call includes both the
            forward pass and the `.numpy()` fetch of its result.
        """
        # perf_counter is monotonic and high-resolution, unlike time.time(),
        # which can jump if the system clock is adjusted mid-measurement.
        start = time.perf_counter()
        for _ in range(runs):
            forward(data).numpy()
        return (time.perf_counter() - start) / runs

    # warm-up: call each function both with and without fetching the result
    # so that tracing/compilation happens before anything is timed.
    print("Warm-up...")
    for _ in range(3):
        conv_2d_forward(synthetic_data)
        conv_2d_forward(synthetic_data).numpy()
        depsep_conv_2d_forward(synthetic_data)
        depsep_conv_2d_forward(synthetic_data).numpy()
    print("Warm-up completed")

    # Per-call cost of a complete "forward + fetch" round trip; the benchmark
    # cells subtract this from their total elapsed time.
    print("Remove constant overhead:")
    conv_overhead = _mean_sync_call_time(conv_2d_forward, synthetic_data)
    print("* Conv2D:", conv_overhead)
    depsep_conv_overhead = _mean_sync_call_time(depsep_conv_2d_forward, synthetic_data)
    print("* SeparableConv2D:", depsep_conv_overhead)

Warm-up...
Warm-up completed
Remove constant overhead:
* Conv2D: 0.3245253403981527
* SeparableConv2D: 0.3583238442738851


In [11]:
print("Benchmark for Conv2D", kernel, dtype)

with tf.device('/job:worker/replica:0/task:0/device:TPU:0'):
    # perf_counter is monotonic and high-resolution, unlike time.time(),
    # which can jump if the system clock is adjusted mid-measurement.
    st = time.perf_counter()
    for _ in range(num_it-1):
        conv_2d_forward(synthetic_data)
    # Only the last call fetches its result; presumably this is what makes
    # the timer wait for all previously issued work to finish.
    conv_2d_forward(synthetic_data).numpy()
    et = time.perf_counter()

    # conv_overhead is the mean duration of one complete
    # `conv_2d_forward(...).numpy()` call, i.e. one batch of compute plus the
    # result fetch. Subtracting it therefore leaves the time spent on the
    # other num_it-1 batches, so the throughput must count num_it-1 batches
    # (counting num_it would slightly overstate the score).
    tt = et - st - conv_overhead
    conv_2d_score = int(num_img*(num_it-1)/tt)

    print("Conv2D per sec:", conv_2d_score)

Benchmark for Conv2D (3, 3) <dtype: 'bfloat16'>
Conv2D per sec: 1446


In [12]:
print("Benchmark for SeparableConv2D", kernel, dtype)

with tf.device('/job:worker/replica:0/task:0/device:TPU:0'):
    # perf_counter is monotonic and high-resolution, unlike time.time(),
    # which can jump if the system clock is adjusted mid-measurement.
    st = time.perf_counter()
    for _ in range(num_it-1):
        depsep_conv_2d_forward(synthetic_data)
    # Only the last call fetches its result; presumably this is what makes
    # the timer wait for all previously issued work to finish.
    depsep_conv_2d_forward(synthetic_data).numpy()
    et = time.perf_counter()

    # depsep_conv_overhead is the mean duration of one complete
    # `depsep_conv_2d_forward(...).numpy()` call, i.e. one batch of compute
    # plus the result fetch. Subtracting it therefore leaves the time spent on
    # the other num_it-1 batches, so the throughput must count num_it-1
    # batches (counting num_it would slightly overstate the score).
    tt = et - st - depsep_conv_overhead
    depsep_conv_2d_score = int(num_img*(num_it-1)/tt)

    print("SeparableConv2D per sec:", depsep_conv_2d_score)

Benchmark for SeparableConv2D (3, 3) <dtype: 'bfloat16'>
SeparableConv2D per sec: 526
