In [1]:
import tensorflow as tf 
import tensorflow_hub as hub
import tensorflow_addons as tfa
import tensorflow_datasets as tfds

from tensorflow import keras

tfds.disable_progress_bar()
tf.random.set_seed(42)

In [2]:
import sys

sys.path.append("..")

from vit.deit_models import ViTDistilled
from vit.model_configs import base_config

In [3]:
MODEL_TYPE = "deit_distilled_tiny_patch16_224"

BATCH_SIZE = 128
NUM_EPOCHS = 10

AUTO = tf.data.AUTOTUNE
NB_CLASSES = 5

In [4]:
deit_tiny_config = base_config.get_config(
    drop_path_rate=0.1, model_name=MODEL_TYPE
)
with deit_tiny_config.unlocked():
    deit_tiny_config.num_classes = NB_CLASSES

deit_tiny_config.to_dict()

{'classifier': 'token',
 'drop_path_rate': 0.1,
 'dropout_rate': 0.0,
 'image_size': 224,
 'init_values': None,
 'initializer_range': 0.02,
 'input_shape': (224, 224, 3),
 'layer_norm_eps': 1e-06,
 'mlp_units': [768, 192],
 'name': 'deit_distilled_tiny_patch16_224',
 'num_classes': 5,
 'num_heads': 3,
 'num_layers': 12,
 'num_patches': 196,
 'patch_size': 16,
 'pre_logits': False,
 'projection_dim': 192}

In [5]:
SZ = deit_tiny_config.image_size

def preprocess_dataset(is_training=True):
    def _pp(image, label):
        if is_training:
            # Resize to a bigger spatial resolution and take the random
            # crops.
            image = tf.image.resize(image, (SZ + 20, SZ + 20))
            image = tf.image.random_crop(image, (SZ, SZ, 3))
            image = tf.image.random_flip_left_right(image)
        else:
            image = tf.image.resize(image, (SZ, SZ))
        label = tf.one_hot(label, depth=NB_CLASSES)
        return image, label

    return _pp


def prepare_dataset(dataset, is_training=True):
    if is_training:
        dataset = dataset.shuffle(BATCH_SIZE * 10)
    dataset = dataset.map(preprocess_dataset(is_training), num_parallel_calls=AUTO)
    return dataset.batch(BATCH_SIZE).prefetch(AUTO)

In [6]:
train_dataset, val_dataset = tfds.load(
    "tf_flowers", split=["train[:90%]", "train[90%:]"], as_supervised=True
)
num_train = train_dataset.cardinality()
num_val = val_dataset.cardinality()
print(f"Number of training examples: {num_train}")
print(f"Number of validation examples: {num_val}")

train_dataset = prepare_dataset(train_dataset, is_training=True)
val_dataset = prepare_dataset(val_dataset, is_training=False)

Number of training examples: 3303
Number of validation examples: 367


2022-03-31 21:23:18.231035: I tensorflow/core/platform/cpu_feature_guard.cc:151] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [7]:
deit_tiny = ViTDistilled(deit_tiny_config)

deit_tiny = keras.Sequential([keras.Input((SZ, SZ, 3)), keras.layers.Rescaling(scale=1.0 / 255), deit_tiny])

resolution = deit_tiny_config.image_size
dummy_inputs = tf.ones((2, resolution, resolution, 3))
_ = deit_tiny(dummy_inputs)
print(f"Number of parameters (millions): {deit_tiny.count_params() / 1e6}.")

ValueError: All layers in a Sequential model should have a single output tensor. For multi-output layers, use the functional API.

In [None]:
bit_teacher_flowers = keras.models.load_model("gs://deit-tf/bit_teacher_flowers")
print(f"Number of parameters (millions): {bit_teacher_flowers.count_params() / 1e6}.")

In [8]:
class DeiT(keras.Model):
    def __init__(self, student, teacher, **kwargs):
        super().__init__(**kwargs)
        self.teacher = teacher
        self.student = student

    def compile(
        self,
        optimizer,
        metrics,
        student_loss_fn,
        distillation_loss_fn,
    ):
        super().compile(optimizer=optimizer, metrics=metrics)
        self.student_loss_fn = student_loss_fn
        self.distillation_loss_fn = distillation_loss_fn

    def train_step(self, data):
        # Unpack data.
        x, y = data

        # Forward pass of teacher
        teacher_predictions = self.teacher(x, training=False)

        with tf.GradientTape() as tape:
            # Forward pass of student.
            cls_predictions, dist_predictions, _ = self.student(x, training=True)

            # Compute losses.
            student_loss = self.student_loss_fn(y, cls_predictions)
            distillation_loss = self.distillation_loss_fn(
                teacher_predictions, dist_predictions
            )
            loss = (student_loss + distillation_loss) / 2

        # Compute gradients.
        trainable_vars = self.student.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)

        # Update weights.
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))

        # Update the metrics configured in `compile()`.
        student_predictions = (cls_predictions + dist_predictions) / 2
        self.compiled_metrics.update_state(y, student_predictions)

        # Return a dict of performance.
        results = {m.name: m.result() for m in self.metrics}
        results.update(
            {"student_loss": student_loss, "distillation_loss": distillation_loss}
        )
        return results

    def test_step(self, data):
        # Unpack the data.
        x, y = data

        # Compute predictions.
        y_prediction, _ = self.student(x, training=False)

        # Calculate the loss.
        student_loss = self.student_loss_fn(y, y_prediction)

        # Update the metrics.
        self.compiled_metrics.update_state(y, y_prediction)

        # Return a dict of performance.
        results = {m.name: m.result() for m in self.metrics}
        results.update({"student_loss": student_loss})
        return results
    
    def call(self, inputs):
        return self.student(inputs, training=False)

In [9]:
deit_distiller = DeiT(student=deit_tiny, teacher=bit_teacher_flowers)
deit_distiller.compile(
    optimizer=keras.optimizers.Adam(),
    metrics=["accuracy"],
    student_loss_fn=keras.losses.CategoricalCrossentropy(
        from_logits=True, label_smoothing=0.1
    ),
    distillation_loss_fn=keras.losses.CategoricalCrossentropy(from_logits=True),
)
_ = deit_distiller.fit(
    train_dataset.take(1), validation_data=val_dataset.take(1), epochs=1
)

NameError: name 'bit_teacher_flowers' is not defined