In [None]:
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow.keras import layers, mixed_precision

# Keep notebook output clean: no tfds download/preparation progress bars.
tfds.disable_progress_bar()

# Global Keras dtype policy: compute in float16, keep variables in float32.
# Set before any model is built so every layer picks it up.
mixed_precision.set_global_policy("mixed_float16")

In [None]:
# Prefer a TPU when one can be resolved; otherwise fall back to
# MirroredStrategy over the locally visible devices.
tpu = None
try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)
    strategy = tf.distribute.TPUStrategy(tpu)
except ValueError:
    # The fallback assumes "no TPU available" surfaces as a ValueError from
    # the resolver; other failures propagate.
    strategy = tf.distribute.MirroredStrategy()

print("Number of accelerators: ", strategy.num_replicas_in_sync)

In [None]:
# (height, width) every image is resized to in preprocess_dataset.
TARGET_DIM = (224, 224)

# Let tf.data choose parallelism / prefetch depth at runtime.
AUTO = tf.data.AUTOTUNE

# Global batch size: 128 examples per replica times the number of replicas.
_PER_REPLICA_BATCH = 128
BATCH_SIZE = strategy.num_replicas_in_sync * _PER_REPLICA_BATCH

EPOCHS = 10

In [None]:
train_ds, validation_ds, test_ds = tfds.load(
    "cats_vs_dogs",
    # All three splits are carved out of the dataset's `train` split:
    # 90% for training, 5% for validation and 5% for test.
    split=["train[:90%]", "train[90%:95%]", "train[95%:]"],
    # Yield (image, label) tuples instead of feature dicts.
    as_supervised=True,
)

In [None]:
def preprocess_dataset(image, label):
    """Resize one example to TARGET_DIM and one-hot encode its label.

    Intended for `tf.data.Dataset.map` over (image, label) pairs.

    Args:
        image: a single image tensor (presumably H x W x 3 from cats_vs_dogs
            -- TODO confirm).
        label: an integer class id (0 or 1 -- the encoding uses depth 2).

    Returns:
        A ``(resized_image, one_hot_label)`` tuple.
    """
    target_height, target_width = TARGET_DIM[0], TARGET_DIM[1]
    resized_image = tf.image.resize(image, (target_height, target_width))
    one_hot_label = tf.one_hot(label, depth=2)
    return resized_image, one_hot_label

In [None]:
def _build_pipeline(ds, shuffle_buffer=None):
    """Turn a raw (image, label) split into a preprocessed, batched dataset.

    Args:
        ds: a ``tf.data.Dataset`` of raw (image, label) pairs.
        shuffle_buffer: when given, shuffle raw examples with this buffer size
            before preprocessing (used for the training split only).

    Returns:
        The dataset mapped through ``preprocess_dataset``, batched with
        ``BATCH_SIZE`` and prefetched.
    """
    if shuffle_buffer is not None:
        # NOTE: the buffer holds raw, un-resized images, so a large buffer can
        # consume significant host memory.
        ds = ds.shuffle(shuffle_buffer)
    return (
        ds
        .map(preprocess_dataset, num_parallel_calls=AUTO)
        .batch(BATCH_SIZE)
        .prefetch(AUTO)
    )


# NOTE: these rebind the split variables from the load cell; re-running this
# cell without re-running the load cell would batch the data twice.
train_ds = _build_pipeline(train_ds, shuffle_buffer=BATCH_SIZE * 100)
validation_ds = _build_pipeline(validation_ds)
test_ds = _build_pipeline(test_ds)

In [None]:
def get_model(num_classes=2):
    """Build a randomly initialised DenseNet121 classifier.

    Inputs are images of shape ``(TARGET_DIM[0], TARGET_DIM[1], 3)``. Pixel
    values are multiplied by 1/255 inside the model (assumes raw 0-255 pixel
    values from the input pipeline).

    The classifier head mirrors DenseNet121's own top (global average pooling
    followed by a dense layer). However, the softmax is a separate float32
    layer: under the global ``mixed_float16`` policy a softmax computed in
    float16 is less numerically stable. The Keras mixed-precision guide
    recommends running the output activation in float32.

    Args:
        num_classes: number of output classes; must match the one-hot depth
            used in preprocess_dataset (2).

    Returns:
        An uncompiled ``tf.keras.Model`` producing float32 class probabilities.
    """
    input_shape = (TARGET_DIM[0], TARGET_DIM[1], 3)

    # Headless backbone with global average pooling; input_shape keeps it in
    # sync with TARGET_DIM instead of relying on the 224x224 default.
    backbone = tf.keras.applications.DenseNet121(
        weights=None,
        include_top=False,
        pooling="avg",
        input_shape=input_shape,
    )
    backbone.trainable = True

    inputs = layers.Input(input_shape)
    x = layers.experimental.preprocessing.Rescaling(scale=1./255)(inputs)
    x = backbone(x)
    x = layers.Dense(num_classes)(x)
    # Softmax in float32 so probabilities fed to the loss are full precision.
    outputs = layers.Activation("softmax", dtype="float32")(x)

    return tf.keras.Model(inputs, outputs)

In [None]:
# Model and optimizer variables must be created inside the strategy scope so
# they are distributed across replicas.
with strategy.scope():
    model = get_model()
    model.compile(
        # Use tf.keras explicitly: a bare `keras` name is not imported here.
        loss=tf.keras.losses.CategoricalCrossentropy(label_smoothing=0.1),
        optimizer="sgd",
        metrics=["accuracy"],
    )

# fit() does not need the scope. Binding the History keeps it for later
# inspection and avoids dumping its repr as cell output.
history = model.fit(
    train_ds,
    validation_data=validation_ds,
    epochs=EPOCHS,
)

In [None]:
# evaluate() needs no strategy scope: the model was built under it already.
# Bind the loss to a real name rather than `_`, which would shadow IPython's
# last-output variable.
test_loss, test_acc = model.evaluate(test_ds)
print("Test accuracy: {:.2f}%".format(test_acc * 100))

# NOTE(review): no file extension -> TF2/Keras 2 writes a SavedModel directory;
# Keras 3 requires a `.keras` suffix instead -- confirm the TF version in use.
model.save("standard_densenet_model")