In [None]:
pip install -qq -U tensorflow-addons

In [None]:
import tensorflow as tf

In [None]:
# Dataset: shape of one raw input image and the number of target classes.
INPUT_SHAPE = (32, 32, 3)
NUM_CLASSES = 100

# Patch grid: NUM_PATCHES is the number of whole PATCH_SIZE steps that fit
# along one IMAGE_SIZE side, squared (72 // 6 = 12 per side -> 144).
PATCH_SIZE = 6
IMAGE_SIZE = 72
NUM_PATCHES = (IMAGE_SIZE // PATCH_SIZE) * (IMAGE_SIZE // PATCH_SIZE)

# Width of the embedding vectors (used as the position-embedding output_dim).
PROJECTION_DIM = 64

In [None]:
# Load the CIFAR-100 train/test splits; 'fine' is the default label mode and
# is spelled out here because NUM_CLASSES is 100.
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar100.load_data(
    label_mode='fine'
)

## Data augmentation

According to DeiT, various techniques are required to train ViTs effectively.
Thus, we applied data augmentations such as CutMix, Mixup, AutoAugment, and Repeated Augment to all models.

In [None]:
# Preprocessing/augmentation pipeline, assembled layer by layer:
# normalize -> resize to IMAGE_SIZE x IMAGE_SIZE -> random flip/rotation/zoom.
data_augmentation = tf.keras.Sequential(name='data_augmentation')
data_augmentation.add(tf.keras.layers.Normalization())
data_augmentation.add(tf.keras.layers.Resizing(IMAGE_SIZE, IMAGE_SIZE))
data_augmentation.add(tf.keras.layers.RandomFlip('horizontal'))
# NOTE(review): Keras documents `factor` as a fraction of 2*pi, so 0.2 would
# allow rotations of roughly +/-72 degrees -- confirm this is intended.
data_augmentation.add(tf.keras.layers.RandomRotation(factor=0.2))
data_augmentation.add(
    tf.keras.layers.RandomZoom(height_factor=0.2, width_factor=0.2)
)

# The Normalization layer is the first layer (index 0); fit its mean and
# variance on the training images.
data_augmentation.layers[0].adapt(x_train)

## PatchEncoder

In [None]:
class PatchEncoder(tf.keras.layers.Layer):
  """Adds a learned position embedding to a sequence of patch embeddings.

  The layer owns an embedding table with one `projection_dim`-wide vector per
  position index in `[0, num_patches)` and adds the whole table to its input.

  Args:
    num_patches: Number of position indices (rows of the embedding table).
    projection_dim: Size of each position embedding vector.
    **kwargs: Forwarded unchanged to `tf.keras.layers.Layer`.
  """

  def __init__(self, num_patches=NUM_PATCHES, projection_dim=PROJECTION_DIM, **kwargs):
    super().__init__(**kwargs)
    self.num_patches = num_patches
    # Stored so get_config() can reproduce the constructor arguments.
    self.projection_dim = projection_dim
    self.position_embedding = tf.keras.layers.Embedding(
        input_dim=num_patches, output_dim=projection_dim
    )
    # Position indices 0 .. num_patches - 1, looked up on every call.
    self.positions = tf.range(start=0, limit=self.num_patches, delta=1)

  def call(self, encoded_patches):
    """Returns `encoded_patches` plus the position embeddings.

    Args:
      encoded_patches: Tensor of patch embeddings. Assumed to have trailing
        dimensions (num_patches, projection_dim) so that the addition
        broadcasts -- TODO confirm against the caller.

    Returns:
      The element-wise sum of the input and the position embedding table.
    """
    encoded_positions = self.position_embedding(self.positions)
    encoded_patches = encoded_patches + encoded_positions
    return encoded_patches

  def get_config(self):
    """Returns the layer config, including the constructor arguments.

    Without this override Keras cannot rebuild the layer from a saved or
    cloned model, because `num_patches` and `projection_dim` would be lost.
    """
    config = super().get_config()
    config.update(
        {
            'num_patches': self.num_patches,
            'projection_dim': self.projection_dim,
        }
    )
    return config

