## Imports

In [None]:
import tensorflow as tf
import tensorflow_hub as hub
import tensorflow_addons as tfa
import tensorflow_datasets as tfds

from tensorflow import keras

# Fix the global random seed so that repeated runs are comparable.
tf.keras.utils.set_random_seed(42)

# Turn off the TFDS progress bars to keep the cell output short.
tfds.disable_progress_bar()

In [None]:
import sys

sys.path.append("..")

from vit.deit_models import ViTDistilled
from vit.model_configs import base_config

## Constants

In [None]:
# Name of the model variant handed to `base_config.get_config`.
MODEL_TYPE = "deit_distilled_tiny_patch16_224"

# Training schedule.
NUM_EPOCHS = 20
BATCH_SIZE = 256

# Optimizer settings. `BASE_LR` is the reference rate that gets rescaled by
# the batch size before it is given to the optimizer.
BASE_LR = 5e-4
WEIGHT_DECAY = 1e-4

# Number of target classes: used as the one-hot depth of the labels and as
# the size of the student's classification output.
NB_CLASSES = 5

# Shorthand handed to `tf.data` for `num_parallel_calls` and `prefetch`.
AUTO = tf.data.AUTOTUNE

## Initialize model config

In [None]:
# Fetch the configuration of the selected variant.
deit_tiny_config = base_config.get_config(model_name=MODEL_TYPE, drop_path_rate=0.1)

# Override the number of output classes for this dataset. The assignment is
# done inside `unlocked()` -- presumably the config is locked by default
# (TODO confirm against `base_config`).
with deit_tiny_config.unlocked():
    deit_tiny_config.num_classes = NB_CLASSES

# Kept as the last expression of the cell so the notebook displays it.
deit_tiny_config.to_dict()

## Data preprocessing and loading

In [None]:
# Target image side length, taken from the model config. The preprocessing
# code resizes / crops every image to (SZ, SZ).
SZ = deit_tiny_config.image_size


def preprocess_dataset(is_training=True):
    """Build the per-example preprocessing function for a dataset split.

    Args:
        is_training: When true, the returned function applies the training
            augmentations (resize to `SZ + 20`, random `SZ x SZ x 3` crop,
            random horizontal flip). Otherwise the image is only resized to
            `SZ x SZ`.

    Returns:
        A function mapping `(image, label)` to `(image, one_hot_label)`, where
        the label is one-hot encoded with depth `NB_CLASSES`.
    """

    def _transform(image, label):
        if not is_training:
            # Evaluation: deterministic resize only.
            image = tf.image.resize(image, (SZ, SZ))
        else:
            # Training: upscale slightly past the target resolution so that
            # the random crop has some room to move, then flip at random.
            enlarged = tf.image.resize(image, (SZ + 20, SZ + 20))
            cropped = tf.image.random_crop(enlarged, (SZ, SZ, 3))
            image = tf.image.random_flip_left_right(cropped)
        return image, tf.one_hot(label, depth=NB_CLASSES)

    return _transform


def prepare_dataset(dataset, is_training=True):
    """Turn a raw `(image, label)` dataset into a batched input pipeline.

    Args:
        dataset: Dataset yielding `(image, label)` pairs.
        is_training: When true, examples are shuffled (buffer of
            `BATCH_SIZE * 10`) and the training augmentations are applied.

    Returns:
        The dataset after preprocessing, batching by `BATCH_SIZE` and
        prefetching.
    """
    # Shuffle before mapping; evaluation data keeps its original order.
    pipeline = dataset.shuffle(BATCH_SIZE * 10) if is_training else dataset
    pipeline = pipeline.map(preprocess_dataset(is_training), num_parallel_calls=AUTO)
    batched = pipeline.batch(BATCH_SIZE)
    return batched.prefetch(AUTO)

In [None]:
# Split the single `train` split of tf_flowers 90/10 into training and
# validation subsets; `as_supervised=True` is requested so that each element
# can be consumed as an `(image, label)` pair by the preprocessing function.
flowers_splits = tfds.load(
    "tf_flowers",
    split=["train[:90%]", "train[90%:]"],
    as_supervised=True,
)
train_dataset, val_dataset = flowers_splits

# Record the split sizes before the datasets are batched.
num_train, num_val = train_dataset.cardinality(), val_dataset.cardinality()
print(f"Number of training examples: {num_train}")
print(f"Number of validation examples: {num_val}")

# Build the input pipelines (augmentation is applied to the training split only).
train_dataset = prepare_dataset(train_dataset, is_training=True)
val_dataset = prepare_dataset(val_dataset, is_training=False)

## Initialize student and teacher models

In [None]:
# Instantiate the student model from its config.
deit_tiny = ViTDistilled(deit_tiny_config)

# Run a single forward pass on a dummy batch of two all-ones images --
# presumably so that the model's weights exist before they are counted.
resolution = deit_tiny_config.image_size
dummy_shape = (2, resolution, resolution, 3)
dummy_inputs = tf.ones(shape=dummy_shape)
_ = deit_tiny(dummy_inputs)

print(f"Number of parameters (millions): {deit_tiny.count_params() / 1e6}.")

In [None]:
# Load the already-trained teacher model from the `bit_teacher_flowers` path.
# It scores 98.37% on the validation set.
# To know how this was trained refer to `./bit-teacher.ipynb`.
bit_teacher_flowers = keras.models.load_model("bit_teacher_flowers")
# Report the teacher's size in the same unit as the student's for comparison.
print(f"Number of parameters (millions): {bit_teacher_flowers.count_params() / 1e6}.")

Here we can see that the teacher model has orders of magnitude more parameters than the student model.

## Wrap the training logic of DeiT

**Note:** Here we only follow the core principles of the distillation process laid out in the [original paper](https://arxiv.org/abs/2012.12877). The authors use additional data augmentation and regularization, which have been purposefully left out to keep the workflow simple to follow.

In [None]:
class DeiT(keras.Model):
    """Train a student model with hard-label distillation from a teacher.

    The training loss is the mean of two terms: the student's class-head loss
    against the ground-truth labels, and the student's distillation-head loss
    against the teacher's hard (arg-max) predictions. Only the student's
    trainable variables are updated.

    Reference:
    https://keras.io/examples/vision/knowledge_distillation/

    Args:
        student: Model being trained. It is always called on inputs divided
            by 255. With `training=True` it must return three values
            (class-head predictions, distillation-head predictions, and a
            third value that is ignored here); with `training=False` it must
            return two values (predictions and a second value that is
            ignored here).
        teacher: Model that provides the distillation targets. It is called
            with `training=False`.
        **kwargs: Forwarded to `keras.Model`.
    """

    def __init__(self, student, teacher, **kwargs):
        super().__init__(**kwargs)
        self.student = student
        self.teacher = teacher

    def compile(
        self,
        optimizer,
        metrics,
        student_loss_fn,
        distillation_loss_fn,
    ):
        """Configure the distiller for training.

        Args:
            optimizer: Optimizer used to update the student's weights.
            metrics: Metrics evaluated on the student's predictions.
            student_loss_fn: Loss between the ground-truth labels and the
                student's class-head predictions.
            distillation_loss_fn: Loss between the teacher's arg-max class
                indices and the student's distillation-head predictions.
        """
        super().compile(optimizer=optimizer, metrics=metrics)
        self.student_loss_fn = student_loss_fn
        self.distillation_loss_fn = distillation_loss_fn

    @staticmethod
    def _rescale(images):
        """Divide pixel values by 255 before they are fed to the student."""
        return images / 255.0

    def train_step(self, data):
        """Run one distillation step on a batch and update the student.

        Args:
            data: An `(inputs, labels)` pair.

        Returns:
            A dict with the compiled metrics plus `student_loss` and
            `distillation_loss`.
        """
        # Unpack data.
        x, y = data

        # Forward pass of the teacher, outside the gradient tape because the
        # teacher is not trained. Only the winning class index is needed, and
        # arg-max is unchanged by softmax (a monotonic map), so the teacher's
        # raw outputs are used directly.
        # NOTE(review): the teacher receives the un-rescaled inputs while the
        # student receives inputs / 255 -- presumably the teacher handles its
        # own input normalization; confirm against the teacher's training code.
        teacher_predictions = tf.argmax(self.teacher(x, training=False), -1)

        with tf.GradientTape() as tape:
            # Forward pass of student.
            cls_predictions, dist_predictions, _ = self.student(
                self._rescale(x), training=True
            )

            # Compute losses: class head vs. ground truth, distillation head
            # vs. the teacher's hard labels, weighted equally.
            student_loss = self.student_loss_fn(y, cls_predictions)
            distillation_loss = self.distillation_loss_fn(
                teacher_predictions, dist_predictions
            )
            loss = (student_loss + distillation_loss) / 2

        # Compute gradients with respect to the student's variables only.
        trainable_vars = self.student.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)

        # Update weights.
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))

        # Update the metrics configured in `compile()` using the average of
        # the two heads as the student's prediction.
        student_predictions = (cls_predictions + dist_predictions) / 2
        self.compiled_metrics.update_state(y, student_predictions)

        # Return a dict of performance.
        results = {m.name: m.result() for m in self.metrics}
        results.update(
            {"student_loss": student_loss, "distillation_loss": distillation_loss}
        )
        return results

    def test_step(self, data):
        """Evaluate the student on a batch; the teacher is not involved.

        Args:
            data: An `(inputs, labels)` pair.

        Returns:
            A dict with the compiled metrics plus `student_loss`.
        """
        # Unpack the data.
        x, y = data

        # Compute predictions. In inference mode the student returns two
        # values; only the first (the predictions) is used.
        y_prediction, _ = self.student(self._rescale(x), training=False)

        # Calculate the loss.
        student_loss = self.student_loss_fn(y, y_prediction)

        # Update the metrics.
        self.compiled_metrics.update_state(y, y_prediction)

        # Return a dict of performance.
        results = {m.name: m.result() for m in self.metrics}
        results.update({"student_loss": student_loss})
        return results

    def call(self, inputs):
        """Return the student's inference-mode outputs for `inputs`."""
        return self.student(self._rescale(inputs), training=False)

## Distill the teacher model into the student model

In [None]:
deit_distiller = DeiT(student=deit_tiny, teacher=bit_teacher_flowers)

# Scale the base learning rate linearly with the batch size, relative to a
# reference batch size of 512.
lr_scaled = (BASE_LR / 512) * BATCH_SIZE

adamw = tfa.optimizers.AdamW(weight_decay=WEIGHT_DECAY, learning_rate=lr_scaled)

# Ground-truth labels are one-hot encoded, so the class head uses the dense
# cross-entropy (with label smoothing).
student_ce = keras.losses.CategoricalCrossentropy(
    from_logits=True, label_smoothing=0.1
)

# The teacher's targets are integer class indices, hence the sparse variant.
distillation_ce = keras.losses.SparseCategoricalCrossentropy(from_logits=True)

deit_distiller.compile(
    optimizer=adamw,
    metrics=["accuracy"],
    student_loss_fn=student_ce,
    distillation_loss_fn=distillation_ce,
)
_ = deit_distiller.fit(train_dataset, validation_data=val_dataset, epochs=NUM_EPOCHS)

The model should reach about 68.5% - 69.5% accuracy on the validation set. The results may vary slightly depending on the hardware used.

If the same student model were trained from scratch (i.e., without distillation), it would reach about 65% - 66% accuracy on the validation set. To train such a model, adapt the following code:

```py
inputs = keras.Input((SZ, SZ, 3))
x = keras.layers.Rescaling(scale=1./255)(inputs)
outputs, _ = deit_tiny(x) # Second output in the tuple is a dictionary containing attention scores.
model = keras.Model(inputs, outputs)

model.compile(...)
model.fit(...)
```


