## Imports

In [None]:
import tensorflow as tf
import tensorflow_datasets as tfds
import tensorflow_hub as hub

from tensorflow import keras

# Turn off the progress bars that tensorflow_datasets would otherwise print.
tfds.disable_progress_bar()

## Constants

In [None]:
# TF-Hub handle of the pre-trained BiT module that the model wraps.
MODULE_URL = "https://tfhub.dev/google/bit/m-r50x3/1"

# --- Input pipeline settings ---
# Side length (in pixels) of the square images fed to the network.
SZ = 224
# Number of examples per batch produced by the tf.data pipeline.
BATCH_SIZE = 128
# Let tf.data choose parallelism / prefetch buffer sizes at runtime.
AUTO = tf.data.AUTOTUNE

# --- Task / training settings ---
# Used as the one-hot depth of the labels and the width of the final Dense layer.
NB_CLASSES = 5
# Number of epochs passed to `model.fit`.
NUM_EPOCHS = 10

## Data preprocessing and loading

In [None]:
def preprocess_dataset(is_training=True):
    """Return a per-example ``(image, label)`` mapping function.

    Args:
        is_training: When true, the returned function augments the image
            (resize to ``SZ + 20``, random ``SZ x SZ`` crop, random
            horizontal flip). Otherwise the image is only resized to
            ``SZ x SZ``.

    Returns:
        A callable taking ``(image, label)`` and returning the processed
        image together with the label one-hot encoded to ``NB_CLASSES``.
    """

    def _pp(image, label):
        encoded_label = tf.one_hot(label, depth=NB_CLASSES)

        # Evaluation path: plain resize, no randomness.
        if not is_training:
            return tf.image.resize(image, (SZ, SZ)), encoded_label

        # Training path: upscale slightly past the target size so that the
        # random crop below has room to pick different windows.
        enlarged_side = SZ + 20
        augmented = tf.image.resize(image, (enlarged_side, enlarged_side))
        augmented = tf.image.random_crop(augmented, (SZ, SZ, 3))
        augmented = tf.image.random_flip_left_right(augmented)
        return augmented, encoded_label

    return _pp


def prepare_dataset(dataset, is_training=True):
    """Turn a dataset of ``(image, label)`` pairs into a batched input pipeline.

    Args:
        dataset: Dataset object exposing ``shuffle``, ``map``, ``batch`` and
            ``prefetch`` (a ``tf.data.Dataset`` in this notebook).
        is_training: When true, examples are shuffled with a buffer of
            ``BATCH_SIZE * 10`` elements and the training-time preprocessing
            is applied; otherwise no shuffling is done and the evaluation
            preprocessing is used.

    Returns:
        The dataset after preprocessing, batching by ``BATCH_SIZE`` and
        prefetching.
    """
    pipeline = dataset.shuffle(BATCH_SIZE * 10) if is_training else dataset
    per_example_fn = preprocess_dataset(is_training)
    pipeline = pipeline.map(per_example_fn, num_parallel_calls=AUTO)
    pipeline = pipeline.batch(BATCH_SIZE)
    return pipeline.prefetch(AUTO)

In [None]:
# Load `tf_flowers` as (image, label) pairs: the first 90% of its `train`
# split is used for training and the remaining 10% for validation.
train_dataset, val_dataset = tfds.load(
    name="tf_flowers",
    split=["train[:90%]", "train[90%:]"],
    as_supervised=True,
)

# Report the split sizes before batching changes the cardinality.
num_train = train_dataset.cardinality()
print(f"Number of training examples: {num_train}")
num_val = val_dataset.cardinality()
print(f"Number of validation examples: {num_val}")

# Build the input pipelines; only the training one shuffles and augments.
train_dataset = prepare_dataset(train_dataset, is_training=True)
val_dataset = prepare_dataset(val_dataset, is_training=False)

## Model initialization

In [None]:
# Wrap the TF-Hub module at MODULE_URL as a Keras layer.
hub_module = hub.KerasLayer(MODULE_URL)

# Stack: input -> rescale pixels by 1/255 -> hub module -> linear head.
model = keras.Sequential(name="bit_teacher_flowers")
model.add(keras.Input((SZ, SZ, 3)))
model.add(keras.layers.Rescaling(scale=1.0 / 255))
model.add(hub_module)
# The head has no activation (it outputs logits) and starts from all-zero
# kernel weights.
model.add(keras.layers.Dense(NB_CLASSES, kernel_initializer="zeros"))

print(f"Number of parameters (millions): {model.count_params() / 1e6}.")

## Optimizer and loss function

In [None]:
# Base schedule length of 500 steps, rescaled by 512 / BATCH_SIZE.
# NOTE(review): SCHEDULE_LENGTH is not read by any code visible in this
# notebook, and SCHEDULE_BOUNDARIES below are not rescaled the same way.
SCHEDULE_LENGTH = 500
SCHEDULE_LENGTH = SCHEDULE_LENGTH * 512 / BATCH_SIZE

# Optimizer steps at which the learning rate drops to its next value.
SCHEDULE_BOUNDARIES = [200, 300, 400]
# Base learning rate of 0.003, scaled linearly with BATCH_SIZE / 512.
lr = 0.003 * BATCH_SIZE / 512

# Step the learning rate down at each entry of SCHEDULE_BOUNDARIES.
# NOTE(review): the multipliers go 1 -> 0.1 -> 0.001 -> 0.0001, so the second
# drop is 100x rather than 10x; confirm whether 0.01 was intended.
lr_schedule = tf.keras.optimizers.schedules.PiecewiseConstantDecay(
    boundaries=SCHEDULE_BOUNDARIES, values=[lr, lr * 0.1, lr * 0.001, lr * 0.0001]
)
optimizer = tf.keras.optimizers.SGD(learning_rate=lr_schedule, momentum=0.9)

# The input pipeline one-hot encodes the labels (`tf.one_hot` in
# `preprocess_dataset`), so the loss must be the dense categorical
# cross-entropy; the sparse variant expects integer class ids instead.
# `from_logits=True` because the final Dense layer has no softmax.
loss_fn = tf.keras.losses.CategoricalCrossentropy(from_logits=True)

## Train the model and save it

In [None]:
# Attach the optimizer, loss and accuracy metric to the model.
model.compile(optimizer=optimizer, loss=loss_fn, metrics=["accuracy"])

# `train_dataset` is already batched by `prepare_dataset`, so `batch_size` is
# not passed to `fit` (Keras documents that it must not be specified for
# dataset inputs). The training dataset is repeated indefinitely, which makes
# `steps_per_epoch` the thing that defines an epoch: 10 batches each.
history = model.fit(
    train_dataset.repeat(),
    steps_per_epoch=10,
    epochs=NUM_EPOCHS,
    validation_data=val_dataset,
)

In [None]:
# Persist the trained model to the path "bit_teacher_flowers" (the on-disk
# format is whatever the installed Keras version uses for this path).
model.save(filepath="bit_teacher_flowers")

## References

* [Official Colab Notebook from BiT authors](https://colab.research.google.com/github/google-research/big_transfer/blob/master/colabs/big_transfer_tf2.ipynb)