## Setup

In [1]:
# Download the helper module that provides the `stem`, `learner` and `classifier`
# builders used to assemble the model further down. Comes from
# https://github.com/GoogleCloudPlatform/keras-idiomatic-programmer/tree/master/zoo
# NOTE(review): the shortened URL hides the exact file and revision being fetched;
# consider pinning a full URL so the notebook stays reproducible.
!wget -q https://git.io/Jshxv -O resnet20.py

In [2]:
from tensorflow.keras import layers
from tensorflow import keras
import tensorflow as tf
import numpy as np
import resnet20

# Make "mixed_float16" the global Keras dtype policy.
# Only enable this for tensor-core GPUs.
# NOTE(review): the policy is set here unconditionally, before the next cell
# chooses between a TPU and local GPUs — confirm "mixed_float16" is also the
# intended policy when the TPU branch is taken.
from tensorflow.keras import mixed_precision
mixed_precision.set_global_policy("mixed_float16")

INFO:tensorflow:Mixed precision compatibility check (mixed_float16): OK
Your GPUs will likely run quickly with dtype policy mixed_float16 as they all have compute capability of at least 7.0


In [3]:
# Pick a distribution strategy: use a TPU when a resolver can be created and
# initialised, otherwise fall back to MirroredStrategy. Only ValueError triggers
# the fallback; any other exception propagates.
tpu = None
try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)
    strategy = tf.distribute.TPUStrategy(tpu)
except ValueError:
    strategy = tf.distribute.MirroredStrategy()

# `strategy` is read by later cells (batch size, learning rate, model scope).
print("Number of accelerators: ", strategy.num_replicas_in_sync)

INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3')
Number of accelerators:  4


## Hyperparameters

In [4]:
# Side length passed to the Resizing layer of the augmentation pipeline.
RESIZE_TO = 72

# Number of epochs passed to model.fit.
EPOCHS = 100

# Batch size passed to model.fit: a base of 512 multiplied by the number of
# replicas the strategy keeps in sync.
BASE_BATCH_SIZE = 512
BATCH_SIZE = BASE_BATCH_SIZE * strategy.num_replicas_in_sync

## Dataset Utilities

In [5]:
# Load the CIFAR-10 train/test splits through Keras. These four names are read
# as globals by the augmentation and training helpers below.
# Presumably NumPy arrays of 32x32 RGB images with integer labels for 10 classes
# (the model takes (32, 32, 3) inputs and has a 10-way classifier) — see the
# Keras dataset docs to confirm dtype and label shape.
(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()

In [6]:
def get_augmentation_layers():
    """Build the preprocessing/augmentation pipeline placed in front of a model.

    The pipeline, in order: normalization, resizing to ``RESIZE_TO`` x
    ``RESIZE_TO``, random horizontal flip, random rotation (factor 0.02) and
    random zoom (height and width factor 0.2).

    Returns:
        A ``keras.Sequential`` named ``"data_augmentation"`` whose first
        (normalization) layer has been adapted to the global ``x_train``.
    """
    preprocessing = layers.experimental.preprocessing

    # Instantiate the stages in pipeline order.
    normalize = preprocessing.Normalization()
    resize = preprocessing.Resizing(RESIZE_TO, RESIZE_TO)
    flip = preprocessing.RandomFlip("horizontal")
    rotate = preprocessing.RandomRotation(factor=0.02)
    zoom = preprocessing.RandomZoom(height_factor=0.2, width_factor=0.2)

    pipeline = keras.Sequential(
        [normalize, resize, flip, rotate, zoom],
        name="data_augmentation",
    )

    # Compute the mean and the variance of the training data for normalization.
    # `normalize` is the same object as `pipeline.layers[0]`.
    normalize.adapt(x_train)

    return pipeline

## ResNet20 Utilities

In [7]:
def get_model():
    """Assemble the CIFAR-10 classifier from the ``resnet20`` helper module.

    Data flow: a (32, 32, 3) input -> the pipeline returned by
    ``get_augmentation_layers()`` -> ``stem`` -> ``learner`` -> ``classifier``
    with 10 outputs.

    Returns:
        An uncompiled ``tf.keras.Model`` mapping the input to the classifier
        output.
    """
    # Resolve the helper module through a function-local alias instead of the
    # global name `resnet20`. The training cell later rebinds that global to
    # the Keras model, after which `resnet20.stem` would raise AttributeError on
    # any further call to this function. The module is already imported at the
    # top of the notebook, so this import is served from the module cache.
    import resnet20 as resnet20_module

    n = 2
    depth = n * 9 + 2  # 20 for n = 2
    n_blocks = ((depth - 2) // 9) - 1

    # The input tensor
    inputs = layers.Input(shape=(32, 32, 3))
    data_augmentation = get_augmentation_layers()
    augmented = data_augmentation(inputs)

    # The Stem Convolution Group
    x = resnet20_module.stem(augmented)

    # The learner
    x = resnet20_module.learner(x, n_blocks)

    # The Classifier for 10 classes
    outputs = resnet20_module.classifier(x, 10)

    # Instantiate the Model
    model = tf.keras.Model(inputs, outputs)
    return model

In [8]:
def create_mlp_mixer():
    """Wrap ``mlp_mixer`` in a Keras model that augments its (32, 32, 3) input.

    NOTE(review): ``mlp_mixer`` and the constants ``NUM_MIXER_LAYERS``,
    ``PATCH_SIZE``, ``HIDDEN_SIZE``, ``MLP_SEQ_DIM`` and ``MLP_CHANNEL_DIM`` are
    not defined in the visible part of this notebook; unless they are provided
    elsewhere, calling this function raises ``NameError``.

    Returns:
        A ``tf.keras.Model`` named ``"mlp_mixer"``.
    """
    augmenter = get_augmentation_layers()

    image_input = layers.Input(shape=(32, 32, 3))
    mixer_output = mlp_mixer(
        augmenter(image_input),
        NUM_MIXER_LAYERS,
        PATCH_SIZE,
        HIDDEN_SIZE,
        MLP_SEQ_DIM,
        MLP_CHANNEL_DIM,
    )
    return tf.keras.Model(image_input, mixer_output, name="mlp_mixer")

In [9]:
def run_experiment(model):
    """Compile, train and evaluate ``model`` using the notebook's global data.

    Training reads ``x_train``/``y_train`` with ``BATCH_SIZE`` and ``EPOCHS`` and
    holds out 10% of the training arrays for validation. Weights are written to
    ``/tmp/checkpoint`` whenever ``val_accuracy`` improves; those weights are
    loaded back before evaluating on ``x_test``/``y_test``.

    Args:
        model: An uncompiled Keras model; it is compiled and trained in place.

    Returns:
        A ``(history, model)`` tuple: the object returned by ``model.fit`` and
        the same ``model`` instance carrying the restored checkpoint weights.
    """
    # Learning rate of 1e-3 multiplied by the number of replicas in sync.
    scaled_lr = 0.001 * strategy.num_replicas_in_sync
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=scaled_lr),
        loss="sparse_categorical_crossentropy",
        metrics=["accuracy"],
    )

    # Keep only the weights of the epoch with the best validation accuracy.
    best_weights_path = "/tmp/checkpoint"
    keep_best = keras.callbacks.ModelCheckpoint(
        filepath=best_weights_path,
        monitor="val_accuracy",
        save_best_only=True,
        save_weights_only=True,
    )

    history = model.fit(
        x_train,
        y_train,
        batch_size=BATCH_SIZE,
        epochs=EPOCHS,
        validation_split=0.1,
        callbacks=[keep_best],
    )

    # Evaluate the checkpointed weights rather than the last-epoch weights.
    model.load_weights(best_weights_path)
    _, top_1_accuracy = model.evaluate(x_test, y_test)
    print(f"Test accuracy: {round(top_1_accuracy * 100, 2)}%")

    return history, model

## Model Training and Evaluation

In [10]:
# Build, train, evaluate and export the model inside the strategy scope.
# NOTE(review): the assignment below rebinds the global name `resnet20`, which
# the import cell bound to the helper module, to the Keras model.
with strategy.scope():
    resnet20 = get_model()
    history, model = run_experiment(resnet20)
    # Export the trained model to the "resnet20_cifar10" path.
    model.save("resnet20_cifar10")

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Redu

In [11]:
# Parameter count in millions. Read it from `model`, the object returned by
# run_experiment, instead of `resnet20`: that name is shared with the imported
# helper module and only refers to the Keras model after the training cell ran.
model.count_params() / 1e6

0.571969