In [None]:
pip install tensorflow-addons

In [None]:
from typing import List

import tensorflow as tf
import tensorflow_addons as tfa
import tensorflow_datasets as tfds
import tensorflow_hub as hub
from tensorflow import keras
from tensorflow.keras import layers

# Keep TFDS download/prepare output out of the notebook.
tfds.disable_progress_bar()
# Fix the global seed so shuffling, augmentation and weight init are repeatable
# (same function as tf.keras.utils.set_random_seed; `keras` is tensorflow.keras).
keras.utils.set_random_seed(42)

In [None]:
# --- Architecture (values follow the MODEL_TYPE name: tiny, 16px patches, 224px input) ---
MODEL_TYPE = 'deit_distilled_tiny_patch16_224'
RESOLUTION = 224
PATCH_SIZE = 16
NUM_PATCHES = (RESOLUTION // PATCH_SIZE) ** 2  # (224 // 16) ** 2 = 196 patches
LAYER_NORM_EPS = 1e-6
PROJECTION_DIM = 192
NUM_HEADS = 3
NUM_LAYERS = 12
# Transformer MLP widths: expand 4x, then project back to the embedding width.
MLP_UNITS = [4 * PROJECTION_DIM, PROJECTION_DIM]
DROPOUT_RATE = 0.0
DROP_PATH_RATE = 0.1  # upper end of the per-block stochastic-depth schedule

# --- Optimization ---
NUM_EPOCHS = 20
BASE_LR = 5e-4
WEIGHT_DECAY = 1e-4

# --- Input pipeline / labels ---
BATCH_SIZE = 256
AUTO = tf.data.AUTOTUNE
NUM_CLASSES = 5  # width of the one-hot labels and of the classification head

## Dataset preparation

In [None]:
def preprocess_dataset(is_training=True):
  """Build the per-example map function used by the tf.data pipeline.

  Args:
    is_training: If True, the returned function augments the image: it is
      resized to (RESOLUTION + 20) on each side, randomly cropped back to
      RESOLUTION x RESOLUTION, and randomly flipped left/right. If False,
      the image is only resized to RESOLUTION x RESOLUTION.

  Returns:
    A callable `(image, label) -> (image, one_hot_label)`. Labels are one-hot
    encoded with depth NUM_CLASSES.
  """
  def _train_transform(image):
    # Upscale slightly so the random crop has room to move around.
    enlarged = tf.image.resize(image, (RESOLUTION + 20, RESOLUTION + 20))
    cropped = tf.image.random_crop(enlarged, (RESOLUTION, RESOLUTION, 3))
    return tf.image.random_flip_left_right(cropped)

  def _eval_transform(image):
    return tf.image.resize(image, (RESOLUTION, RESOLUTION))

  # Chosen once, in Python, when the map function is built.
  transform = _train_transform if is_training else _eval_transform

  def fn(image, label):
    return transform(image), tf.one_hot(label, depth=NUM_CLASSES)

  return fn

In [None]:
def prepare_dataset(dataset, is_training=True):
  """Turn a raw `(image, label)` dataset into batched, prefetched model input.

  Args:
    dataset: A tf.data.Dataset of `(image, label)` pairs.
    is_training: If True, shuffle first (buffer of BATCH_SIZE * 10 examples)
      and apply the training-time augmentation from `preprocess_dataset`.

  Returns:
    The mapped dataset, batched by BATCH_SIZE and prefetched with AUTOTUNE.
  """
  map_fn = preprocess_dataset(is_training)
  if is_training:
    dataset = dataset.shuffle(BATCH_SIZE * 10)
  batched = dataset.map(map_fn, num_parallel_calls=AUTO).batch(BATCH_SIZE)
  return batched.prefetch(AUTO)

In [None]:
# 90% / 10% split of the tf_flowers "train" split into train / validation.
train_dataset, val_dataset = tfds.load(
    "tf_flowers", split=['train[:90%]', 'train[90%:]'], as_supervised=True
)
# cardinality() returns a scalar int64 tensor. Keep the tensors (later cells may
# use them) but print plain ints instead of the `tf.Tensor(...)` repr.
num_train = train_dataset.cardinality()
num_val = val_dataset.cardinality()
print(f'num of train: {int(num_train)}')
print(f'num of val: {int(num_val)}')

train_dataset = prepare_dataset(train_dataset, is_training=True)
val_dataset = prepare_dataset(val_dataset, is_training=False)

## Implementation of DeiT

In [None]:
class StochasticDepth(layers.Layer):
  """Stochastic depth ("drop path") applied to a residual branch.

  When `training` is truthy, each example in the batch is kept with
  probability ``1 - drop_prop`` and zeroed otherwise. Kept examples are
  scaled by ``1 / (1 - drop_prop)``. When `training` is falsy, the input is
  returned unchanged.

  Args:
    drop_prop: Drop probability, as a Python number or a scalar tensor.
  """

  def __init__(self, drop_prop, **kwargs):
    super().__init__(**kwargs)
    self.drop_prop = drop_prop

  def call(self, x, training=True):
    if not training:
      return x

    # Do the arithmetic in x's dtype, so a float32 mask / keep_prob does not
    # clash with non-float32 activations (e.g. under mixed precision).
    keep_prob = tf.cast(1.0 - self.drop_prop, x.dtype)

    # Mask shape (batch, 1, ..., 1): one keep/drop decision per example,
    # broadcast over all other axes. tf.rank works without a static rank and
    # without relying on AutoGraph rewriting len().
    batch_dim = tf.shape(x)[:1]
    trailing_ones = tf.ones([tf.rank(x) - 1], dtype=tf.int32)
    mask_shape = tf.concat([batch_dim, trailing_ones], axis=0)

    # floor(keep_prob + U[0, 1)) is 1 with probability keep_prob, else 0.
    random_tensor = tf.floor(
        keep_prob + tf.random.uniform(mask_shape, 0, 1, dtype=x.dtype)
    )
    # divide_no_nan: with drop_prop == 1 every example is dropped; return 0
    # instead of 0 / 0 = NaN.
    return tf.math.divide_no_nan(x, keep_prob) * random_tensor

  def get_config(self):
    """Return constructor arguments so Keras can save/clone this layer."""
    config = super().get_config()
    # float() also accepts a scalar tensor, keeping the config serializable.
    config.update({'drop_prop': float(self.drop_prop)})
    return config

In [None]:
def mlp(x, dropout_rate, hidden_units):
  """Append a feed-forward stack to the symbolic tensor `x`.

  For each width in `hidden_units`, a Dense layer is followed by a Dropout
  layer with rate `dropout_rate`. Only the first Dense layer uses GELU; the
  others are linear.

  Args:
    x: Input tensor.
    dropout_rate: Dropout rate used after every Dense layer.
    hidden_units: Iterable of Dense layer widths, applied in order.

  Returns:
    The output tensor of the last Dropout (or `x` itself if `hidden_units`
    is empty).
  """
  for layer_index, num_units in enumerate(hidden_units):
    activation = None if layer_index else tf.nn.gelu
    projected = layers.Dense(num_units, activation=activation)(x)
    x = layers.Dropout(dropout_rate)(projected)
  return x

In [None]:
def transformer(drop_prob, name):
  """Build one pre-norm Transformer encoder block as a standalone keras.Model.

  Computes  y = x + DropPath(MHSA(LN(x)))  and then
            z = y + DropPath(MLP(LN(y))).

  Args:
    drop_prob: Stochastic-depth rate for both residual branches. If it is
      falsy (0 / 0.0), the StochasticDepth layers are omitted entirely.
    name: Name given to the returned model.

  Returns:
    A keras.Model taking a (batch, tokens, PROJECTION_DIM) tensor. The
    residual Add layers require the output to have the same shape.
  """
  # Leave the token axis dynamic. A length fixed from MODEL_TYPE
  # (NUM_PATCHES + 2 for "distilled") made the block reject the
  # NUM_PATCHES + 1 sequence (CLS token only) that the plain ViT classifier
  # produces, and vice versa.
  encoded_patches = layers.Input((None, PROJECTION_DIM))

  x1 = layers.LayerNormalization(epsilon=LAYER_NORM_EPS)(encoded_patches)
  attention_output = layers.MultiHeadAttention(
      num_heads=NUM_HEADS,
      # key_dim is the size of EACH head. Split the embedding across heads
      # (192 / 3 = 64), as in DeiT-tiny, instead of giving every head the
      # full PROJECTION_DIM.
      key_dim=PROJECTION_DIM // NUM_HEADS,
      dropout=DROPOUT_RATE,
  )(x1, x1)
  if drop_prob:
    attention_output = StochasticDepth(drop_prob)(attention_output)

  # First residual connection.
  x2 = layers.Add()([attention_output, encoded_patches])

  x3 = layers.LayerNormalization(epsilon=LAYER_NORM_EPS)(x2)
  x4 = mlp(x3, dropout_rate=DROPOUT_RATE, hidden_units=MLP_UNITS)
  if drop_prob:
    x4 = StochasticDepth(drop_prop=drop_prob)(x4)

  # Second residual connection.
  outputs = layers.Add()([x2, x4])
  return keras.Model(encoded_patches, outputs, name=name)


In [None]:
class ViTClassifier(keras.Model):
  """Vision Transformer classifier that predicts from a CLS token.

  Steps:
    1. Embed non-overlapping PATCH_SIZE x PATCH_SIZE patches with a strided
       Conv2D.
    2. Prepend a learnable CLS token.
    3. Add a learnable positional embedding.
    4. Run NUM_LAYERS Transformer blocks.
    5. Apply LayerNorm, then a Dense head on the CLS position.

  The head has no activation, so the model returns logits.
  """

  def __init__(self, **kwargs):
    super().__init__(**kwargs)

    # Patch embedding: kernel == stride == PATCH_SIZE maps each patch to a
    # PROJECTION_DIM vector; the grid is then flattened into a token axis.
    self.projection = keras.Sequential([
        layers.Conv2D(filters=PROJECTION_DIM, kernel_size=(PATCH_SIZE, PATCH_SIZE), strides=(PATCH_SIZE, PATCH_SIZE),
                      padding='VALID', name='conv_projection'),
        layers.Reshape(target_shape=(NUM_PATCHES, PROJECTION_DIM), name='flatten_projection'),
    ], name='projection')

    # Positional embedding: one slot for the CLS token plus one per patch.
    init_shape = (1, NUM_PATCHES + 1, PROJECTION_DIM)
    self.positional_embedding = tf.Variable(tf.zeros(init_shape), name='position_embedding')

    # Transformer blocks. The stochastic-depth rate grows linearly from 0.0
    # (first block) to DROP_PATH_RATE (last block). Convert to Python floats
    # so each block gets a plain number instead of a scalar tensor.
    dpr = [float(rate) for rate in tf.linspace(0.0, DROP_PATH_RATE, NUM_LAYERS)]
    self.transformer_blocks = [
        transformer(drop_prob=dpr[i], name=f'transformer_block_{i}')
        for i in range(NUM_LAYERS)
    ]

    # CLS token, shared across the batch and tiled per example in call().
    initial_value = tf.zeros((1, 1, PROJECTION_DIM))
    self.cls_token = tf.Variable(
        initial_value=initial_value, trainable=True, name='cls'
    )

    # Other layers
    self.dropout = layers.Dropout(DROPOUT_RATE)
    self.layer_norm = layers.LayerNormalization(epsilon=LAYER_NORM_EPS)
    self.head = layers.Dense(
        NUM_CLASSES, name='classification_head'
    )

  def call(self, inputs, training=True):
    n = tf.shape(inputs)[0]

    projected_patches = self.projection(inputs)
    compute_dtype = projected_patches.dtype

    # Both variables are created as float32 (tf.zeros). Cast them to the
    # activation dtype so the concat/add also work if the layers compute in
    # a different dtype (e.g. mixed precision).
    cls_token = tf.tile(self.cls_token, (n, 1, 1))
    cls_token = tf.cast(cls_token, compute_dtype)
    projected_patches = tf.concat([cls_token, projected_patches], axis=1)

    positional_embedding = tf.cast(self.positional_embedding, compute_dtype)
    encoded_patches = (
        positional_embedding + projected_patches
    ) # (B, number_patches, projection_dim)
    encoded_patches = self.dropout(encoded_patches, training=training)

    for transformer_module in self.transformer_blocks:
      encoded_patches = transformer_module(encoded_patches, training=training)

    representation = self.layer_norm(encoded_patches)
    # Classify from the CLS token, which sits at position 0.
    encoded_patches = representation[:, 0]
    output = self.head(encoded_patches)
    return output