### A very compact example of benchmarking ResNet50 training throughput with TensorFlow 2 (on synthetic, ImageNet-sized random data rather than MNIST)

Note that TensorFlow 2 can take quite a long time to initialize. The notebook may appear to have stalled, but allow up to a minute or so before it starts printing output.

In [3]:
# Benchmark of TensorFlow 2 ResNet50 training throughput on synthetic data

import argparse
import os
import numpy as np
import timeit

import tensorflow as tf
from tensorflow.keras import applications

# Benchmark configuration:
#   batch_size           - images per training step
#   num_batches_per_iter - timed steps per measurement
#   num_iters            - number of timed measurements
#   warmup_iters         - extra untimed steps before measuring
batch_size = 8
num_batches_per_iter = 3
num_iters = 3
warmup_iters = 2
# Number of output classes of the classifier head (Keras ResNet50 default).
num_classes = 1000

# Let TensorFlow grow GPU memory on demand instead of reserving it all up front.
gpus = tf.config.experimental.list_physical_devices('GPU')
for gpu in gpus:
    tf.config.experimental.set_memory_growth(gpu, True)

# NOTE: setting CUDA_VISIBLE_DEVICES here would be a no-op, because TensorFlow
# has already enumerated the GPUs above. To force a CPU-only run, set
# CUDA_VISIBLE_DEVICES=-1 in the environment before TensorFlow is imported
# (e.g. before starting the kernel).

# Set up standard model: an untrained (weights=None) Keras ResNet50.
model = applications.ResNet50(weights=None, classes=num_classes)
opt = tf.optimizers.SGD(0.01)

# Synthetic inputs and labels; no real dataset is loaded.
data = tf.random.uniform([batch_size, 224, 224, 3])
# maxval is exclusive for tf.random.uniform, so use num_classes to cover
# every label 0 .. num_classes - 1.
target = tf.random.uniform([batch_size, 1], minval=0, maxval=num_classes,
                           dtype=tf.int64)

@tf.function
def benchmark_step(first_batch=False):
    """Run one SGD training step of ``model`` on the fixed synthetic batch.

    Reads the module-level ``model``, ``opt``, ``data`` and ``target`` and
    updates ``model``'s trainable variables in place via ``opt``.

    Args:
        first_batch: Not read by this body. It appears to be a leftover from
            a Horovod-based version of this benchmark and is kept for
            call-site compatibility. Because it is a Python bool passed to a
            ``tf.function``, each distinct value triggers its own trace.
    """
    # Plain (non-distributed) gradient tape; no Horovod is involved here.
    with tf.GradientTape() as tape:
        probs = model(data, training=True)
        # Per-example losses (not reduced to a scalar). from_logits defaults
        # to False, matching the model's probability outputs.
        loss = tf.losses.sparse_categorical_crossentropy(target, probs)

    # For a non-scalar target, tape.gradient differentiates the sum of its
    # elements.
    gradients = tape.gradient(loss, model.trainable_variables)
    opt.apply_gradients(zip(gradients, model.trainable_variables))

def log(s, nl=True):
    """Print ``s`` to stdout, followed by a newline only if ``nl`` is true."""
    if nl:
        print(s)
    else:
        print(s, end='')


log('Model: %s' % 'ResNet50')
log('Batch size: %d' % batch_size)
# Run on a GPU only if TensorFlow detected one; otherwise fall back to CPU.
device = 'GPU' if len(gpus) > 0 else 'CPU'

with tf.device(device):
    # Warm-up: the first calls trace the tf.function (once per distinct
    # `first_batch` value), so they are kept out of the timed measurements.
    log('Running warmup...')
    benchmark_step(first_batch=True)
    timeit.timeit(lambda: benchmark_step(first_batch=False),
                  number=warmup_iters)

    # Benchmark: each measurement times `num_batches_per_iter` training steps.
    log('Running benchmark...')
    img_secs = []
    for x in range(num_iters):
        # Bound to `elapsed`, not `time`, so this cell does not shadow the
        # stdlib `time` module name in the notebook namespace.
        elapsed = timeit.timeit(lambda: benchmark_step(first_batch=False),
                                number=num_batches_per_iter)
        img_sec = batch_size * num_batches_per_iter / elapsed
        log('Iter #%d: %.1f img/sec per %s' % (x, img_sec, device))
        img_secs.append(img_sec)

    # Results: mean throughput across measurements. The "+-" value is
    # 1.96 * the (population) standard deviation of the per-iteration
    # throughputs, i.e. a spread of individual measurements, not a
    # confidence interval of the mean.
    img_sec_mean = np.mean(img_secs)
    img_sec_conf = 1.96 * np.std(img_secs)
    log('Img/sec per %s: %.1f +-%.1f' % (device, img_sec_mean, img_sec_conf))

Model: ResNet50
Batch size: 8
Running warmup...
Running benchmark...
Iter #0: 4.5 img/sec per CPU
Iter #1: 4.6 img/sec per CPU
Iter #2: 4.6 img/sec per CPU
Img/sec per CPU: 4.6 +-0.1
