Find the Kaggle Kernel of this notebook [here](https://www.kaggle.com/spsayakpaul/train-bit).

This notebook fine-tunes a teacher model (based on [BiT ResNet101x3](https://arxiv.org/abs/1912.11370)) that is later used to train a student with function matching (proposed in [Knowledge distillation: A good teacher is patient and consistent](https://arxiv.org/abs/2106.05237)). You can find the distillation notebook [here](https://www.kaggle.com/spsayakpaul/funmatch-distillation). To run this notebook, you need a billing-enabled GCP account so that you can use a GCS bucket.

## Setup

In [1]:
from kaggle_secrets import UserSecretsClient
user_secrets = UserSecretsClient()
user_credential = user_secrets.get_gcloud_credential()
user_secrets.set_tensorflow_credential(user_credential)

In [2]:
import tensorflow_hub as hub
import tensorflow as tf

from pprint import pprint
import matplotlib.pyplot as plt
import numpy as np
import re

# Point the TF Hub cache at a GCS bucket so the Hub module can be loaded
# (TPU workers cannot read the notebook's local filesystem).
import os
os.environ["TFHUB_CACHE_DIR"] = "gs://funmatch-tf/model-cache-dir"

`gs://funmatch-tf` is the GCS bucket I created beforehand. To proceed, create a GCS bucket with a globally unique name and replace `funmatch-tf` with it.

In [3]:
try:  # Detect TPUs
    tpu = None
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver() # TPU detection
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)
    strategy = tf.distribute.TPUStrategy(tpu)
except ValueError: # Detect GPUs
    strategy = tf.distribute.MirroredStrategy() 

print("Number of accelerators: ", strategy.num_replicas_in_sync)

Number of accelerators:  8


## Hyperparameters and constants

In [4]:
BATCH_SIZE = 64 * strategy.num_replicas_in_sync
BIGGER = 160
RESIZE = 128
CENTRAL_FRAC = 0.875
AUTO = tf.data.AUTOTUNE

SCHEDULE_LENGTH = 500
SCHEDULE_BOUNDARIES = [200, 300, 400]
# Rescale the schedule length (defined for a batch size of 512) to the actual batch size.
SCHEDULE_LENGTH = SCHEDULE_LENGTH * 512 / BATCH_SIZE
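
As a quick sanity check of the arithmetic in this cell, the sketch below recomputes the two derived constants in plain Python. `num_replicas` stands in for `strategy.num_replicas_in_sync`; the 8-replica case matches the accelerator count printed earlier, while the 4-replica case is a hypothetical value added only for comparison.

```python
def scaled_constants(num_replicas, per_replica_batch=64, base_schedule_length=500):
    """Recompute BATCH_SIZE and SCHEDULE_LENGTH for a given replica count."""
    batch_size = per_replica_batch * num_replicas
    # The schedule length is defined relative to a reference batch size of 512.
    schedule_length = base_schedule_length * 512 / batch_size
    return batch_size, schedule_length

print(scaled_constants(8))  # (512, 500.0): the 8-replica setup used here
print(scaled_constants(4))  # (256, 1000.0): hypothetical smaller setup
```

With 8 replicas the global batch size equals the reference value of 512, so the schedule length stays at 500 steps.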

## Data loading and input preprocessing

To see how these TFRecords were created, refer to [this notebook](https://colab.research.google.com/github/sayakpaul/FunMatch-Distillation/blob/main/tfrecords_pets37.ipynb). **Be sure to update the GCS paths.**

In [5]:
# Adapted from https://github.com/GoogleCloudPlatform/training-data-analyst.
# Each shard name ends with its example count (e.g. "...-0-128.tfrec" holds 128 examples).
def count_data_items(filenames):
    n = [int(re.compile(r"-([0-9]*)\.").search(filename).group(1)) for filename in filenames]
    return np.sum(n)

train_pattern = "gs://funmatch-tf/train/*.tfrec"
train_filenames = tf.io.gfile.glob(train_pattern)
val_pattern = "gs://funmatch-tf/validation/*.tfrec"
val_filenames = tf.io.gfile.glob(val_pattern)
test_pattern = "gs://funmatch-tf/test/*.tfrec"
test_filenames = tf.io.gfile.glob(test_pattern)

DATASET_NUM_TRAIN_EXAMPLES = count_data_items(train_filenames)
STEPS_PER_EPOCH = 10

pprint(train_filenames[:5])
pprint(val_filenames[:5])
pprint(test_filenames[:5])

['gs://funmatch-tf/train/train_pets37-0-128.tfrec',
 'gs://funmatch-tf/train/train_pets37-1-128.tfrec',
 'gs://funmatch-tf/train/train_pets37-10-128.tfrec',
 'gs://funmatch-tf/train/train_pets37-11-128.tfrec',
 'gs://funmatch-tf/train/train_pets37-12-128.tfrec']
['gs://funmatch-tf/validation/validation_pets37-0-128.tfrec',
 'gs://funmatch-tf/validation/validation_pets37-1-128.tfrec',
 'gs://funmatch-tf/validation/validation_pets37-2-112.tfrec']
['gs://funmatch-tf/test/test_pets37-0-128.tfrec',
 'gs://funmatch-tf/test/test_pets37-1-128.tfrec',
 'gs://funmatch-tf/test/test_pets37-10-128.tfrec',
 'gs://funmatch-tf/test/test_pets37-11-128.tfrec',
 'gs://funmatch-tf/test/test_pets37-12-128.tfrec']
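
The shard names above end in the number of examples each file holds, which is what `count_data_items` relies on. Here is a minimal sketch of that parsing, run on the three validation shard names listed above so that it works without GCS access. It uses Python's built-in `sum` instead of `np.sum`, which is a substitution made here to keep the example dependency-free.

```python
import re

def count_data_items(filenames):
    # The digits between the last "-" and the extension give the shard's example count.
    pattern = re.compile(r"-([0-9]*)\.")
    return sum(int(pattern.search(name).group(1)) for name in filenames)

val_shards = [
    "gs://funmatch-tf/validation/validation_pets37-0-128.tfrec",
    "gs://funmatch-tf/validation/validation_pets37-1-128.tfrec",
    "gs://funmatch-tf/validation/validation_pets37-2-112.tfrec",
]

print(count_data_items(val_shards))  # 128 + 128 + 112 = 368
```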


In [7]:
# Parse one serialized TFRecord example into an (image, label) pair.
def read_tfrecord(example, train):
    features = {
        "image": tf.io.FixedLenFeature([], tf.string), 
        "class": tf.io.FixedLenFeature([], tf.int64)
    }
    
    example = tf.io.parse_single_example(example, features)
    image = tf.image.decode_jpeg(example["image"], channels=3)
    
    if train:
        image = augment(image)
    else:
        image = tf.image.central_crop(image, central_fraction=CENTRAL_FRAC)
        image = tf.image.resize(image, (RESIZE, RESIZE))
        
    image = tf.reshape(image, (RESIZE, RESIZE, 3))
    image = tf.cast(image, tf.float32) / 255.0  
    class_label = tf.cast(example["class"], tf.int32)
    return (image, class_label)

# Load the TFRecords and create tf.data.Dataset
def load_dataset(filenames, train):
    opt = tf.data.Options()
    opt.experimental_deterministic = False
    
    if not train:
        opt.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.DATA
    
    dataset = tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTO) 
    dataset = dataset.map(lambda x: (read_tfrecord(x, train)), num_parallel_calls=AUTO)
    dataset = dataset.with_options(opt)
    return dataset

# Augmentation motivated from here:
# https://github.com/google-research/big_transfer/blob/master/colabs/big_transfer_tf2.ipynb.
def augment(image):
    # Resize to a bigger shape, randomly horizontally flip it,
    # and then take the crops. 
    image = tf.image.resize(image, (BIGGER, BIGGER))
    image = tf.image.random_flip_left_right(image)
    image = tf.image.random_crop(image, [RESIZE, RESIZE, 3])
    return image

# Batch, shuffle, and repeat the dataset and prefetch it
# well before the current epoch ends
def batch_dataset(filenames, train, batch_size=BATCH_SIZE):
    dataset = load_dataset(filenames, train)
    if train:
        dataset = dataset.repeat(int(SCHEDULE_LENGTH * BATCH_SIZE / DATASET_NUM_TRAIN_EXAMPLES * STEPS_PER_EPOCH) + 1 + 50)
        dataset = dataset.shuffle(BATCH_SIZE*10)
    dataset = dataset.batch(batch_size)
    dataset = dataset.prefetch(AUTO) 
    return dataset
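
To make the two preprocessing paths concrete, here is a shape-only sketch in plain Python (no TensorFlow involved). The 160×160 input size is a hypothetical value, since the size of the stored images is not stated here, and `central_crop_size` is an illustrative helper that only approximates how `tf.image.central_crop` rounds its margins.

```python
BIGGER, RESIZE, CENTRAL_FRAC = 160, 128, 0.875

def central_crop_size(side, fraction):
    # Trim an equal margin from both ends, keeping roughly `fraction` of the side.
    margin = int((side - side * fraction) / 2)
    return side - 2 * margin

def train_shape(height, width):
    # Resize to (BIGGER, BIGGER) regardless of input size; the flip keeps the shape;
    # the random crop then yields (RESIZE, RESIZE, 3).
    max_offset = BIGGER - RESIZE  # largest possible crop origin on each axis
    return (RESIZE, RESIZE, 3), max_offset

def eval_shape(height, width):
    cropped = (central_crop_size(height, CENTRAL_FRAC),
               central_crop_size(width, CENTRAL_FRAC))
    # The centrally cropped image is then resized to (RESIZE, RESIZE).
    return cropped, (RESIZE, RESIZE, 3)

print(train_shape(160, 160))  # ((128, 128, 3), 32)
print(eval_shape(160, 160))   # ((140, 140), (128, 128, 3))
```

Both paths end at the same `(RESIZE, RESIZE, 3)` shape, which is why the `tf.reshape` in `read_tfrecord` can be shared between training and evaluation.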

In [8]:
training_dataset = batch_dataset(train_filenames, True)
validation_dataset = batch_dataset(val_filenames, False)
test_dataset = batch_dataset(test_filenames, False)

In [9]:
# sample_images, _ = next(iter(training_dataset))
# plt.figure(figsize=(10, 10))
# for n in range(25):
#     ax = plt.subplot(5, 5, n + 1)
#     plt.imshow(sample_images[n].numpy())
#     plt.axis("off")
# plt.show()

## Model related utilities

In [10]:
# Referenced from: https://github.com/google-research/big_transfer/blob/master/colabs/big_transfer_tf2.ipynb. 
class MyBiTModel(tf.keras.Model):
    def __init__(self, num_classes, module):
        super().__init__()

        self.num_classes = num_classes
        self.head = tf.keras.layers.Dense(num_classes, kernel_initializer="zeros")
        self.bit_model = module
  
    def call(self, images):
        bit_embedding = self.bit_model(images)
        return self.head(bit_embedding)
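
One consequence of `kernel_initializer="zeros"` is worth spelling out. With a zero kernel (and the default zero bias of a Keras `Dense` layer), every logit starts at 0, so the initial softmax is uniform over the 37 classes and the starting cross-entropy is ln(37) whatever the input. A small plain-Python check of that arithmetic:

```python
import math

def softmax(logits):
    # Subtract the maximum for numerical stability before exponentiating.
    peak = max(logits)
    exps = [math.exp(z - peak) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

num_classes = 37
probs = softmax([0.0] * num_classes)   # all-zero logits from the zero-initialized head
initial_loss = -math.log(probs[0])     # cross-entropy is the same for any true class

print(round(initial_loss, 4))  # 3.6109
```

A first-step training loss close to this value is therefore expected and does not indicate a problem with the data pipeline.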

In [11]:
# Define optimizer and loss

lr = (1e-5 * BATCH_SIZE / 512) * strategy.num_replicas_in_sync 

# Decay the learning rate at each step in SCHEDULE_BOUNDARIES
# (to 0.1x, 0.001x, and 0.0001x of the base rate).
lr_schedule = tf.keras.optimizers.schedules.PiecewiseConstantDecay(boundaries=SCHEDULE_BOUNDARIES, 
                                                                   values=[lr, lr*0.1, lr*0.001, lr*0.0001])
optimizer = tf.keras.optimizers.SGD(learning_rate=lr_schedule, momentum=0.9)
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
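
To see which learning rate applies at a given step, here is a plain-Python sketch of the piecewise-constant lookup, assuming the 8-replica setup reported earlier (so `BATCH_SIZE` is 512). It follows the documented behaviour of `PiecewiseConstantDecay`, where a step equal to a boundary still uses the earlier value. `lr_at` is an illustrative helper and not part of the notebook.

```python
import bisect

num_replicas = 8                    # assumption: matches the accelerator count printed earlier
batch_size = 64 * num_replicas      # 512
base_lr = (1e-5 * batch_size / 512) * num_replicas  # roughly 8e-5

boundaries = [200, 300, 400]
multipliers = [1.0, 0.1, 0.001, 0.0001]

def lr_at(step):
    # bisect_left keeps a step equal to a boundary in the earlier interval.
    return base_lr * multipliers[bisect.bisect_left(boundaries, step)]

for step in (0, 200, 201, 350, 449):
    print(step, lr_at(step))
```

Note that the multipliers skip `0.01`, so the second drop (at step 300) is a factor of 100 rather than 10.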

## Training and evaluation

In [12]:
# Target is 91.03%.
with strategy.scope():
    model_url = "https://tfhub.dev/google/bit/m-r101x3/1"
    module = hub.KerasLayer(model_url, trainable=True)
    model = MyBiTModel(num_classes=37, module=module)
    model.compile(optimizer=optimizer,
                  loss=loss_fn,
                  metrics=["accuracy"])
    
history = model.fit(
    training_dataset,
    validation_data=validation_dataset,
    steps_per_epoch=STEPS_PER_EPOCH,
    epochs=45
)

Epoch 1/45
...
Epoch 45/45


Should have trained for five more epochs: 45 epochs of 10 steps each cover only 450 steps, short of the 500-step `SCHEDULE_LENGTH`.

In [13]:
_, accuracy = model.evaluate(test_dataset)
print(f"Test top-1 accuracy: {round(accuracy * 100, 2)}%")

Test top-1 accuracy: 90.92%


In [14]:
model.save("gs://funmatch-tf/models/T-r101x3-128")