Find the Kaggle Kernel of this notebook [here](https://www.kaggle.com/spsayakpaul/train-bit-keras-tuner).

This notebook runs hyperparameter tuning on a teacher model (based on [BiT ResNet101x3](https://arxiv.org/abs/1912.11370)). The teacher is later used to train a student with function matching (proposed in [Knowledge distillation: A good teacher is patient and consistent](https://arxiv.org/abs/2106.05237)). You can find the distillation notebook [here](https://www.kaggle.com/spsayakpaul/funmatch-distillation). To run this notebook, you need a GCP account with billing enabled so that you can use a GCS bucket.

## Setup

In [1]:
from kaggle_secrets import UserSecretsClient
user_secrets = UserSecretsClient()
user_credential = user_secrets.get_gcloud_credential()
user_secrets.set_tensorflow_credential(user_credential)

In [2]:
!pip install keras-tuner -q --user



In [3]:
import tensorflow_hub as hub
import keras_tuner as kt
import tensorflow as tf

from pprint import pprint
import matplotlib.pyplot as plt
import numpy as np
import re

# On TPUs, TF-Hub modules have to be cached in GCS because the TPU workers cannot read from the notebook VM's local disk.
import os
os.environ["TFHUB_CACHE_DIR"] = "gs://funmatch-tf/model-cache-dir"

`gs://funmatch-tf` is the GCS bucket I created beforehand. To proceed, create a GCS bucket with a globally unique name and replace `funmatch-tf` with it throughout this notebook.

In [4]:
try:  # Detect TPUs
    tpu = None
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver() # TPU detection
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)
    strategy = tf.distribute.TPUStrategy(tpu)
except ValueError: # Detect GPUs
    strategy = tf.distribute.MirroredStrategy() 

print("Number of accelerators: ", strategy.num_replicas_in_sync)

Number of accelerators:  8


## Constants and hyperparameters

In [5]:
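BATCH_SIZE = 64 * strategy.num_replicas_in_sync  # 64 examples per replica
BIGGER = 160          # training images are resized to BIGGER x BIGGER before random cropping
RESIZE = 128          # final input resolution
CENTRAL_FRAC = 0.875  # central-crop fraction for evaluation images
AUTO = tf.data.AUTOTUNE

# BiT-style schedule, defined for a reference batch size of 512
SCHEDULE_LENGTH = 500
SCHEDULE_BOUNDARIES = [200, 300, 400]
SCHEDULE_LENGTH = (SCHEDULE_LENGTH * 512 / BATCH_SIZE)  # rescale to the actual batch size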
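The schedule constants follow the BiT fine-tuning recipe, which appears to be expressed for a batch size of 512; `SCHEDULE_LENGTH` is rescaled so that it corresponds to roughly the same number of training examples at the actual `BATCH_SIZE`. A minimal plain-Python sketch of that arithmetic, assuming the 64-examples-per-replica setting above (the helper `scaled_schedule_length` is hypothetical):

```python
# Sketch of the SCHEDULE_LENGTH rescaling above, in plain Python.
# Assumption: 64 examples per replica, matching BATCH_SIZE = 64 * num_replicas.
REFERENCE_BATCH = 512  # batch size the BiT schedule is expressed for
BASE_SCHEDULE_LENGTH = 500


def scaled_schedule_length(num_replicas, per_replica_batch=64):
    """Return SCHEDULE_LENGTH after rescaling for the given number of replicas."""
    batch_size = per_replica_batch * num_replicas
    # Smaller global batches -> proportionally more steps, and vice versa
    return BASE_SCHEDULE_LENGTH * REFERENCE_BATCH / batch_size


for replicas in (1, 2, 8):
    print(replicas, scaled_schedule_length(replicas))
```

With the 8 accelerators detected above, `BATCH_SIZE` is 512 and `SCHEDULE_LENGTH` stays at 500. Note that `SCHEDULE_BOUNDARIES` is not rescaled.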

## Data loading and input preprocessing

To learn how these TFRecords were created, refer to [this notebook](https://colab.research.google.com/github/sayakpaul/FunMatch-Distillation/blob/main/tfrecords_pets37.ipynb). **Be sure to update the GCS paths.**

In [6]:
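# Adapted from https://github.com/GoogleCloudPlatform/training-data-analyst.
# Each file name ends in "-<num_examples>.tfrec", so the dataset size can be
# computed from the file names without reading the records.
def count_data_items(filenames):
    n = [int(re.compile(r"-([0-9]*)\.").search(filename).group(1)) for filename in filenames]
    return np.sum(n)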

train_pattern = "gs://funmatch-tf/train/*.tfrec"
train_filenames = tf.io.gfile.glob(train_pattern)
val_pattern = "gs://funmatch-tf/validation/*.tfrec"
val_filenames = tf.io.gfile.glob(val_pattern)
test_pattern = "gs://funmatch-tf/test/*.tfrec"
test_filenames = tf.io.gfile.glob(test_pattern)

DATASET_NUM_TRAIN_EXAMPLES = count_data_items(train_filenames)
STEPS_PER_EPOCH = 10

pprint(train_filenames[:5])
pprint(val_filenames[:5])
pprint(test_filenames[:5])

['gs://funmatch-tf/train/train_pets37-0-128.tfrec',
 'gs://funmatch-tf/train/train_pets37-1-128.tfrec',
 'gs://funmatch-tf/train/train_pets37-10-128.tfrec',
 'gs://funmatch-tf/train/train_pets37-11-128.tfrec',
 'gs://funmatch-tf/train/train_pets37-12-128.tfrec']
['gs://funmatch-tf/validation/validation_pets37-0-128.tfrec',
 'gs://funmatch-tf/validation/validation_pets37-1-128.tfrec',
 'gs://funmatch-tf/validation/validation_pets37-2-112.tfrec']
['gs://funmatch-tf/test/test_pets37-0-128.tfrec',
 'gs://funmatch-tf/test/test_pets37-1-128.tfrec',
 'gs://funmatch-tf/test/test_pets37-10-128.tfrec',
 'gs://funmatch-tf/test/test_pets37-11-128.tfrec',
 'gs://funmatch-tf/test/test_pets37-12-128.tfrec']
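`count_data_items` never opens the TFRecords; it relies on the naming convention visible in the output above, where the number right before `.tfrec` is the number of examples in that shard. A standalone check of the same regular expression on the validation file names (the helper `count_from_filename` is hypothetical):

```python
import re

# Same pattern as count_data_items: digits between a "-" and the "." of the extension
SHARD_SIZE_PATTERN = re.compile(r"-([0-9]*)\.")


def count_from_filename(filename):
    return int(SHARD_SIZE_PATTERN.search(filename).group(1))


val_filenames = [
    "gs://funmatch-tf/validation/validation_pets37-0-128.tfrec",
    "gs://funmatch-tf/validation/validation_pets37-1-128.tfrec",
    "gs://funmatch-tf/validation/validation_pets37-2-112.tfrec",
]
counts = [count_from_filename(f) for f in val_filenames]
print(counts, sum(counts))  # [128, 128, 112] 368
```

The shard indices (`-0`, `-1`, `-2`) are not picked up because the pattern requires a `.` immediately after the digits.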


In [8]:
# Parse a serialized TFRecord example into an (image, label) pair.
def read_tfrecord(example, train):
    features = {
        "image": tf.io.FixedLenFeature([], tf.string), 
        "class": tf.io.FixedLenFeature([], tf.int64)
    }
    
    example = tf.io.parse_single_example(example, features)
    image = tf.image.decode_jpeg(example["image"], channels=3)
    
    if train:
        image = augment(image)
    else:
        image = tf.image.central_crop(image, central_fraction=CENTRAL_FRAC)
        image = tf.image.resize(image, (RESIZE, RESIZE))
        
    image = tf.reshape(image, (RESIZE, RESIZE, 3))
    image = tf.cast(image, tf.float32) / 255.0  
    class_label = tf.cast(example["class"], tf.int32)
    return (image, class_label)

# Load the TFRecords and create tf.data.Dataset
def load_dataset(filenames, train):
    opt = tf.data.Options()
    opt.experimental_deterministic = False
    
    if not train:
        opt.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.DATA
    
    dataset = tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTO) 
    dataset = dataset.map(lambda x: read_tfrecord(x, train), num_parallel_calls=AUTO)
    dataset = dataset.with_options(opt)
    return dataset

# Augmentation adapted from:
# https://github.com/google-research/big_transfer/blob/master/colabs/big_transfer_tf2.ipynb.
def augment(image):
    image = tf.image.resize(image, (BIGGER, BIGGER))
    image = tf.image.random_flip_left_right(image)
    image = tf.image.random_crop(image, [RESIZE, RESIZE, 3])
    return image

# Repeat and shuffle the training data, batch it, and prefetch upcoming
# batches while the current training step runs.
def batch_dataset(filenames, train, batch_size=BATCH_SIZE):
    dataset = load_dataset(filenames, train)
    if train:
        dataset = dataset.repeat(int(SCHEDULE_LENGTH * BATCH_SIZE / DATASET_NUM_TRAIN_EXAMPLES * STEPS_PER_EPOCH) + 1 + 50)
        dataset = dataset.shuffle(BATCH_SIZE*10)
    dataset = dataset.batch(batch_size)
    dataset = dataset.prefetch(AUTO) 
    return dataset
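The two preprocessing paths in `read_tfrecord` differ only in how they arrive at a `RESIZE x RESIZE` image: training images are resized to `BIGGER x BIGGER`, randomly flipped, and randomly cropped, whereas evaluation images are center-cropped and then resized. The NumPy sketch below reproduces only the flip and crop steps to show the resulting shapes. It is an illustration, not a drop-in replacement for the `tf.image` ops: resizing is omitted, and the rounding in `central_crop` is an assumption about how `tf.image.central_crop` computes its box.

```python
import numpy as np

BIGGER = 160
RESIZE = 128
CENTRAL_FRAC = 0.875


def random_flip_and_crop(image, size, rng):
    """Training-time sketch: random horizontal flip, then a random size x size crop."""
    if rng.random() < 0.5:
        image = image[:, ::-1, :]  # flip left-right
    height, width, _ = image.shape
    # Choose a top-left corner so that the crop fits inside the image
    top = int(rng.integers(0, height - size + 1))
    left = int(rng.integers(0, width - size + 1))
    return image[top:top + size, left:left + size, :]


def central_crop(image, fraction):
    """Eval-time sketch: keep the central `fraction` of each spatial dimension."""
    height, width, _ = image.shape
    # Trim an equal margin from opposite sides
    top = int((height - height * fraction) / 2)
    left = int((width - width * fraction) / 2)
    return image[top:height - top, left:width - left, :]


rng = np.random.default_rng(0)
# Stands in for an image that has already been resized to BIGGER x BIGGER
resized = np.zeros((BIGGER, BIGGER, 3), dtype=np.float32)
print(random_flip_and_crop(resized, RESIZE, rng).shape)  # (128, 128, 3)
# A hypothetical 500 x 375 photo before evaluation resizing
print(central_crop(np.zeros((500, 375, 3)), CENTRAL_FRAC).shape)  # (438, 329, 3)
```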

In [9]:
training_dataset = batch_dataset(train_filenames, True)
validation_dataset = batch_dataset(val_filenames, False)
test_dataset = batch_dataset(test_filenames, False)
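Because `training_dataset` is repeated a fixed number of times rather than indefinitely, it is worth checking that it holds enough examples for the search below (up to 45 epochs of `STEPS_PER_EPOCH = 10` steps). Each `fit` call should start iterating the dataset from the beginning, so one trial execution is the unit that matters. A plain-Python check of the expression used in `batch_dataset`, with a hypothetical training-set size of 3,000 examples (the notebook does not print the real `DATASET_NUM_TRAIN_EXAMPLES`):

```python
def repeat_count(schedule_length, batch_size, num_train_examples, steps_per_epoch):
    """Same expression as the dataset.repeat(...) argument in batch_dataset."""
    return int(schedule_length * batch_size / num_train_examples * steps_per_epoch) + 1 + 50


SCHEDULE_LENGTH = 500.0    # value after rescaling with 8 replicas
BATCH_SIZE = 512
STEPS_PER_EPOCH = 10
NUM_TRAIN_EXAMPLES = 3000  # hypothetical; the notebook uses DATASET_NUM_TRAIN_EXAMPLES
EPOCHS = 45                # as passed to tuner.search

repeats = repeat_count(SCHEDULE_LENGTH, BATCH_SIZE, NUM_TRAIN_EXAMPLES, STEPS_PER_EPOCH)
available_examples = repeats * NUM_TRAIN_EXAMPLES
needed_examples = EPOCHS * STEPS_PER_EPOCH * BATCH_SIZE  # one trial execution
print(repeats, available_examples >= needed_examples)
```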

In [10]:
# sample_images, _ = next(iter(training_dataset))
# plt.figure(figsize=(10, 10))
# for n in range(25):
#     ax = plt.subplot(5, 5, n + 1)
#     plt.imshow(sample_images[n].numpy())
#     plt.axis("off")
# plt.show()

## Model related utilities

### Set up a custom model class

In [11]:
# Adapted from: https://github.com/google-research/big_transfer/blob/master/colabs/big_transfer_tf2.ipynb.
class MyBiTModel(tf.keras.Model):
    def __init__(self, num_classes, module):
        super().__init__()

        self.num_classes = num_classes
        self.head = tf.keras.layers.Dense(num_classes, kernel_initializer="zeros")
        self.bit_model = module
  
    def call(self, images):
        bit_embedding = self.bit_model(images)
        return self.head(bit_embedding)
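Because the head's kernel is zero-initialized (and a `Dense` layer's bias also defaults to zeros), the model outputs all-zero logits before the first update, whatever the BiT embedding is. The softmax over 37 classes is then uniform, so the expected starting loss is ln(37). A plain-Python check of that arithmetic:

```python
import math

NUM_CLASSES = 37
# With kernel_initializer="zeros" and the default zero bias, every logit is 0 at step 0
logits = [0.0] * NUM_CLASSES
# Softmax of equal logits is uniform
exps = [math.exp(z) for z in logits]
probs = [e / sum(exps) for e in exps]
# Sparse categorical cross-entropy is the same for any true label
initial_loss = -math.log(probs[0])
print(round(initial_loss, 4))  # 3.6109
```

A first-epoch loss far from this value would suggest that the head or the labels are not wired up as intended.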

### Utility function to sample a learning rate

In [13]:
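def get_lr(hp):
    initial_lr = hp.Choice("learning_rate", values=[0.003, 1e-4, 1e-5, 5e-5])
    # Scale linearly relative to a reference batch size of 512.
    # BATCH_SIZE already includes num_replicas_in_sync.
    lr = (initial_lr * BATCH_SIZE / 512) * strategy.num_replicas_in_sync
    return lr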
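Since `BATCH_SIZE` is already `64 * strategy.num_replicas_in_sync`, `get_lr` takes the replica count into account twice. On the 8 accelerators used here, `BATCH_SIZE / 512` is 1 and the sampled value is then multiplied by 8, so the `learning_rate` values reported by the tuner are pre-scaling values. A plain-Python copy of the arithmetic (the helper `effective_lr` is hypothetical):

```python
CANDIDATE_LRS = [0.003, 1e-4, 1e-5, 5e-5]  # the hp.Choice values in get_lr


def effective_lr(initial_lr, num_replicas, per_replica_batch=64):
    """Plain-Python copy of the get_lr arithmetic for a given replica count."""
    batch_size = per_replica_batch * num_replicas  # mirrors BATCH_SIZE
    return (initial_lr * batch_size / 512) * num_replicas


for initial_lr in CANDIDATE_LRS:
    print(f"{initial_lr:g} -> {effective_lr(initial_lr, num_replicas=8):g}")
```

For example, the best trial in the results below (`learning_rate: 1e-05`) trained with a peak learning rate of 8e-05.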

### Model building and compiling utility

In [14]:
def build_model(hp):
    model_url = "https://tfhub.dev/google/bit/m-r101x3/1"
    module = hub.KerasLayer(model_url, trainable=True)
    model = MyBiTModel(num_classes=37, module=module)
    
    lr = get_lr(hp)
    lr_schedule = tf.keras.optimizers.schedules.PiecewiseConstantDecay(boundaries=SCHEDULE_BOUNDARIES, 
                                                                       values=[lr, lr*0.1, lr*0.001, lr*0.0001])
    optimizer = tf.keras.optimizers.SGD(learning_rate=lr_schedule, momentum=0.9)
    loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
    
    model.compile(optimizer=optimizer,
                  loss=loss_fn,
                  metrics=["accuracy"])
    return model
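`PiecewiseConstantDecay` switches to the next value only once the optimizer step passes a boundary: per its documentation, `values[0]` is used while `step <= boundaries[0]`, `values[1]` while `boundaries[0] < step <= boundaries[1]`, and so on. A plain-Python sketch of that lookup, using a hypothetical peak learning rate of 8e-5:

```python
from bisect import bisect_left

SCHEDULE_BOUNDARIES = [200, 300, 400]


def piecewise_constant(step, boundaries, values):
    """values[i] applies while step <= boundaries[i]; values[-1] applies after the last boundary."""
    return values[bisect_left(boundaries, step)]


peak_lr = 8e-5  # hypothetical, e.g. initial_lr = 1e-5 after get_lr scaling on 8 replicas
values = [peak_lr, peak_lr * 0.1, peak_lr * 0.001, peak_lr * 0.0001]
for step in [0, 200, 201, 300, 301, 400, 401, 449]:
    print(step, piecewise_constant(step, SCHEDULE_BOUNDARIES, values))
```

With `epochs=45` and `STEPS_PER_EPOCH = 10`, each execution runs at most 450 optimizer steps, so the last phase is reached only if `EarlyStopping` does not end the run earlier.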

## Running the search

As before, replace the GCS paths below with your own bucket.

In [18]:
tuner = kt.RandomSearch(
    build_model,
    objective="val_accuracy",
    executions_per_trial=2,
    max_trials=3,
    overwrite=True,
    directory="gs://funmatch-tf/keras_tuner",
    project_name="funmatch",
    distribution_strategy=strategy,
)

tuner.search(training_dataset,
             steps_per_epoch=STEPS_PER_EPOCH,
             epochs=45,
             validation_data=validation_dataset,
             callbacks=[tf.keras.callbacks.EarlyStopping("val_accuracy", patience=3)])

Trial 3 Complete [00h 11m 49s]
val_accuracy: 0.9361413419246674

Best val_accuracy So Far: 0.94701087474823
Total elapsed time: 00h 45m 44s


In [19]:
tuner.results_summary()

Results summary
Results in gs://funmatch-tf/keras_tuner/funmatch
Showing 10 best trials
Objective(name='val_accuracy', direction='max')
Trial summary
Hyperparameters:
learning_rate: 1e-05
Score: 0.94701087474823
Trial summary
Hyperparameters:
learning_rate: 0.0001
Score: 0.9361413419246674
Trial summary
Hyperparameters:
learning_rate: 0.003
Score: 0.8192934989929199


In [24]:
best_model = tuner.get_best_models(num_models=1)
_, accuracy = best_model[0].evaluate(test_dataset)
print(f"Test top-1 accuracy: {round(accuracy * 100, 2)}%")

Test top-1 accuracy: 90.32%
