## Setup

In [1]:
from tensorflow.keras import layers
from tensorflow import keras
import tensorflow as tf

import tensorflow_datasets as tfds
tfds.disable_progress_bar()

from tensorflow.keras import mixed_precision
mixed_precision.set_global_policy("mixed_float16")

import matplotlib.pyplot as plt
import numpy as np

INFO:tensorflow:Mixed precision compatibility check (mixed_float16): OK
Your GPUs will likely run quickly with dtype policy mixed_float16 as they all have compute capability of at least 7.0


In [2]:
# Pick a distribution strategy: use a TPU when one can be resolved and
# initialised, otherwise mirror across whatever local devices are visible.
tpu = None
try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)
    strategy = tf.distribute.TPUStrategy(tpu)
except ValueError:
    # Any ValueError raised during TPU setup lands here and triggers the fallback.
    strategy = tf.distribute.MirroredStrategy()

print("Number of accelerators: ", strategy.num_replicas_in_sync)

INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3')
Number of accelerators:  4


## Define hyperparameters

In [3]:
# Size the raw images are resized to in `preprocess_dataset`; also the
# spatial size of the model's Input layer.
INP_DIM = (300, 300)
# Size produced by the Resizing layers inside `learnable_resizer`.
TARGET_DIM = (224, 224)
# Default interpolation mode handed to those Resizing layers.
INTERPOLATION = "bilinear"

# Let tf.data choose parallelism / prefetch buffer sizes.
AUTO = tf.data.AUTOTUNE
# Batch size passed to `.batch()`: 128 multiplied by the number of replicas in sync.
BATCH_SIZE = 128 * strategy.num_replicas_in_sync
# Number of epochs passed to `model.fit`.
EPOCHS = 10

## Load dataset and prepare data loaders

In [4]:
train_ds, validation_ds, test_ds = tfds.load(
    "cats_vs_dogs",
    # Use the first 90% for training and reserve 5% each for validation and test
    split=["train[:90%]", "train[90%:95%]", "train[95%:]"],
    as_supervised=True,  
)

In [5]:
def preprocess_dataset(image, label):
    """Resize one image to INP_DIM and one-hot encode its label over 2 classes.

    Args:
        image: image tensor for a single example.
        label: integer class label for that example.

    Returns:
        A `(resized_image, one_hot_label)` tuple.
    """
    target_size = (INP_DIM[0], INP_DIM[1])
    resized = tf.image.resize(image, target_size)
    encoded = tf.one_hot(label, depth=2)
    return (resized, encoded)

In [6]:
def _as_batches(dataset, shuffle=False):
    """Turn a raw split into a preprocessed, batched and prefetched dataset.

    When `shuffle` is true the examples are shuffled (buffer of
    BATCH_SIZE * 100) before preprocessing; otherwise their order is kept.
    """
    if shuffle:
        dataset = dataset.shuffle(BATCH_SIZE * 100)
    dataset = dataset.map(preprocess_dataset, num_parallel_calls=AUTO)
    dataset = dataset.batch(BATCH_SIZE)
    return dataset.prefetch(AUTO)


# Only the training split is shuffled; validation and test keep their order.
train_ds = _as_batches(train_ds, shuffle=True)
validation_ds = _as_batches(validation_ds)
test_ds = _as_batches(test_ds)

## Utility functions

In [7]:
# Sentinel meaning "the caller did not pass an activation". It has to be
# distinct from None, which callers use to ask for no activation at all.
_DEFAULT_ACTIVATION = object()


def conv_block(x, filters, kernel_size, strides, activation=_DEFAULT_ACTIVATION):
    """Apply Conv2D -> BatchNormalization -> optional activation to `x`.

    Args:
        x: input Keras tensor.
        filters: number of convolution filters.
        kernel_size: convolution kernel size.
        strides: convolution strides.
        activation: callable applied after batch normalization. When omitted,
            a new `LeakyReLU(0.2)` layer is created for this call. Pass `None`
            (or any falsy value) to skip the activation.

    Returns:
        The output tensor of the block.
    """
    # Default argument values are evaluated once, when the function is defined.
    # A layer instance used as the default would therefore be shared by every
    # call and every model built with this function, so create it per call.
    if activation is _DEFAULT_ACTIVATION:
        activation = layers.LeakyReLU(0.2)

    # The convolution has no bias: the BatchNormalization that follows
    # supplies its own learned offset.
    x = layers.Conv2D(filters, kernel_size, strides, padding="same",
                      use_bias=False)(x)
    x = layers.BatchNormalization()(x)
    if activation:
        x = activation(x)
    return x

In [8]:
def res_block(x):
    """Residual block: two 3x3, stride-1, 16-filter conv blocks plus identity.

    The second conv block is built without an activation, and its output is
    added element-wise to the block's input.
    """
    shortcut = x
    residual = conv_block(x, 16, 3, 1)
    residual = conv_block(residual, 16, 3, 1, activation=None)
    merged = layers.Add()([shortcut, residual])
    return merged

In [9]:
def learnable_resizer(inputs, 
    filters=16,
    num_res_blocks=1, 
    interpolation=INTERPOLATION):
    """Build the learnable image resizer on top of `inputs`.

    The result is the sum of two branches: a plain `Resizing` of `inputs` to
    TARGET_DIM, and a learned 3-channel correction produced by a small CNN
    that is itself resized to TARGET_DIM part-way through.

    Args:
        inputs: Keras tensor holding the images to resize. It is added to the
            output of a 3-filter convolution at the end, so it is presumably
            expected to have 3 channels.
        filters: number of filters for every convolution created here except
            the final 3-filter one.
            NOTE(review): `res_block` in this notebook hard-codes 16 filters,
            so other values look like they would break its residual Add;
            confirm before changing.
        num_res_blocks: number of residual blocks applied one after another
            to the bottleneck. 0 skips the residual blocks entirely.
        interpolation: interpolation mode used by both Resizing layers.

    Returns:
        Keras tensor with TARGET_DIM spatial size.
    """
    # We first need to resize to a fixed resolution to allow mini-batch learning
    naive_resize = layers.experimental.preprocessing.Resizing(*TARGET_DIM,
            interpolation=interpolation)(inputs)

    # First conv block without Batch Norm
    x = layers.Conv2D(filters=filters, kernel_size=7, strides=1, padding="same")(inputs)
    x = layers.LeakyReLU(0.2)(x)

    # Second conv block with Batch Norm
    x = layers.Conv2D(filters=filters, kernel_size=1, strides=1, padding="same")(x)
    x = layers.LeakyReLU(0.2)(x)
    x = layers.BatchNormalization()(x)

    # Intermediate resizing as bottleneck
    bottleneck = layers.experimental.preprocessing.Resizing(*TARGET_DIM,
            interpolation=interpolation)(x)

    # Residual passes: the first block consumes the bottleneck and each later
    # block consumes the previous block's output. Starting from `bottleneck`
    # also keeps `x` at TARGET_DIM when num_res_blocks is 0, so the skip
    # connection below still lines up.
    x = bottleneck
    for _ in range(num_res_blocks):
        x = res_block(x)

    # Projection
    x = layers.Conv2D(filters=filters, kernel_size=3, strides=1, padding="same",
                      use_bias=False)(x)
    x = layers.BatchNormalization()(x)

    # Skip connection
    x = layers.Add()([bottleneck, x])

    # Final resized image: learned 3-channel correction added to the naive resize
    x = layers.Conv2D(filters=3, kernel_size=7, strides=1, padding="same")(x)
    final_resize = layers.Add()([naive_resize, x])

    return final_resize

In [10]:
def get_model():
    """Assemble the classifier: rescaling -> learnable resizer -> DenseNet121.

    The DenseNet121 backbone is created without pretrained weights, with its
    classification top and 2 output classes. Its output goes through a linear
    Activation layer created with dtype float32. The backbone ends up as the
    second-to-last layer of the returned model.

    Returns:
        An uncompiled `tf.keras.Model` taking INP_DIM-sized, 3-channel images.
    """
    classifier = tf.keras.applications.DenseNet121(
        weights=None,
        include_top=True,
        classes=2,
    )
    classifier.trainable = True

    image_in = layers.Input((INP_DIM[0], INP_DIM[1], 3))
    scaled = layers.experimental.preprocessing.Rescaling(scale=1./255)(image_in)
    resized = learnable_resizer(scaled)
    predictions = classifier(resized)
    # Linear (identity) activation whose only job is to emit float32 outputs.
    predictions = layers.Activation("linear", dtype="float32")(predictions)

    return tf.keras.Model(image_in, predictions)

## Train and evaluate model

In [16]:
# Load the initial model and take the weights of its backbone (DenseNet121,
# the second-to-last layer). Those weights become the starting point for the
# backbone of the new model that contains the learnable image resizer.
initial_model = tf.keras.models.load_model("initial_model")
model = get_model()

pretrained_backbone = initial_model.layers[-2]
resizer_backbone = model.layers[-2]
resizer_backbone.set_weights(pretrained_backbone.get_weights())

# Sanity check on a single layer: the last backbone layer should now hold
# the same first weight array in both models.
np.allclose(
    pretrained_backbone.layers[-1].get_weights()[0],
    resizer_backbone.layers[-1].get_weights()[0],
)

True

In [17]:
# Create, initialise and compile the model inside the distribution scope;
# the backbone starts from the weights of the previously saved initial model.
with strategy.scope():
    initial_model = tf.keras.models.load_model("initial_model")
    model = get_model()
    backbone_weights = initial_model.layers[-2].get_weights()
    model.layers[-2].set_weights(backbone_weights)

    smoothed_cross_entropy = keras.losses.CategoricalCrossentropy(label_smoothing=0.1)
    model.compile(
        loss=smoothed_cross_entropy,
        optimizer="sgd",
        metrics=["accuracy"],
    )

model.fit(
    train_ds,
    validation_data=validation_ds,
    epochs=EPOCHS,
)

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).






Epoch 1/10
INFO:tensorflow:batch_all_reduce: 381 all-reduces with algorithm = nccl, num_packs = 1


INFO:tensorflow:batch_all_reduce: 381 all-reduces with algorithm = nccl, num_packs = 1


INFO:tensorflow:batch_all_reduce: 381 all-reduces with algorithm = nccl, num_packs = 1


INFO:tensorflow:batch_all_reduce: 381 all-reduces with algorithm = nccl, num_packs = 1


Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


<tensorflow.python.keras.callbacks.History at 0x7efbb6ce7750>

In [18]:
# Evaluate on the held-out test split; evaluate() yields the loss followed by
# the compiled accuracy metric, and only the accuracy is reported.
with strategy.scope():
    test_loss, test_acc = model.evaluate(test_ds)

test_acc_percent = test_acc * 100
print("Test accuracy: {:.2f}%".format(test_acc_percent))

# Persist the trained model to disk.
model.save("learnable_resizer_model")

Test accuracy: 52.02%
INFO:tensorflow:Assets written to: learnable_resizer_model/assets


INFO:tensorflow:Assets written to: learnable_resizer_model/assets
