In [None]:
pip install -qq -U tensorflow-addons

In [None]:
import tensorflow as tf
from tensorflow.keras import layers
import numpy as np
import matplotlib.pyplot as plt

In [None]:
# CIFAR-100: 100 classes of 32x32 RGB images.
NUM_CLASSES = 100
INPUT_SHAPE = (32, 32, 3)

# Images are resized to IMAGE_SIZE x IMAGE_SIZE and then cut into a grid of
# non-overlapping PATCH_SIZE x PATCH_SIZE patches.
IMAGE_SIZE, PATCH_SIZE = 72, 6
NUM_PATCHES = (IMAGE_SIZE // PATCH_SIZE) * (IMAGE_SIZE // PATCH_SIZE)

# Width of each patch token after the linear projection.
PROJECTION_DIM = 64

In [None]:
# Load the CIFAR-100 train/test splits as (images, labels) pairs.
(x_train, y_train), (x_test, y_test) = (
    tf.keras.datasets.cifar100.load_data()
)

## Data augmentation

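According to DeiT, various techniques are required to train ViTs effectively.
The original work therefore applied data augmentations such as CutMix, Mixup, AutoAugment and Repeated Augment to all models; this notebook uses a lighter pipeline: normalization and resizing, followed by random horizontal flips, rotations and zooms.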

In [None]:
# Preprocessing + augmentation pipeline: normalize, resize to IMAGE_SIZE,
# then apply a random horizontal flip, rotation and zoom.
data_augmentation = tf.keras.Sequential(name='data_augmentation')
data_augmentation.add(tf.keras.layers.Normalization())
data_augmentation.add(tf.keras.layers.Resizing(IMAGE_SIZE, IMAGE_SIZE))
data_augmentation.add(tf.keras.layers.RandomFlip('horizontal'))
data_augmentation.add(tf.keras.layers.RandomRotation(factor=0.2))
data_augmentation.add(
    tf.keras.layers.RandomZoom(height_factor=0.2, width_factor=0.2)
)

# The Normalization layer sits first in the pipeline; fit its mean and
# variance on the training images before the pipeline is used.
data_augmentation.layers[0].adapt(x_train)

## PatchEncoder

In [None]:
class PatchEncoder(tf.keras.layers.Layer):
  """Add a learned position embedding to a sequence of patch tokens.

  Args:
    num_patches: Number of patch tokens per example; one embedding row is
      learned for each position `0 .. num_patches - 1`.
    projection_dim: Size of each position embedding; it must match the last
      dimension of the incoming tokens so the two can be added.
    **kwargs: Forwarded to `tf.keras.layers.Layer`.
  """

  def __init__(self, num_patches=NUM_PATCHES, projection_dim=PROJECTION_DIM, **kwargs):
    super().__init__(**kwargs)
    self.num_patches = num_patches
    # Stored so `get_config` can round-trip the constructor arguments.
    self.projection_dim = projection_dim
    self.position_embedding = tf.keras.layers.Embedding(
        input_dim=num_patches, output_dim=projection_dim
    )
    # Position ids 0..num_patches-1, built once and reused on every call.
    self.positions = tf.range(start=0, limit=self.num_patches, delta=1)

  def call(self, encoded_patches):
    """Return `encoded_patches` plus the position embeddings.

    The embedding lookup yields a (num_patches, projection_dim) tensor that is
    added to (and broadcast against) `encoded_patches` — presumably shaped
    (batch, num_patches, projection_dim); TODO confirm against callers.
    """
    encoded_positions = self.position_embedding(self.positions)
    encoded_patches = encoded_patches + encoded_positions
    return encoded_patches

  def get_config(self):
    """Return the constructor arguments so the layer can be saved and cloned."""
    config = super().get_config()
    config.update(
        {
            "num_patches": self.num_patches,
            "projection_dim": self.projection_dim,
        }
    )
    return config

## Shifted patch tokenization

* start with an image
* shift the image by half a patch in each of the four diagonal directions
* concatenate the diagonally shifted images with the original image along the channel axis
* extract patches from the concatenated images
* flatten the spatial dimensions of all patches
* layer-normalize the flattened patches and then project them

In [None]:
class ShiftedPatchTokenization(layers.Layer):
  """Split images into patch tokens, optionally with Shifted Patch Tokenization.

  With `vanilla=False`, the input is concatenated along the channel axis with
  four copies of itself shifted by half a patch in each diagonal direction.
  The resulting patches are flattened, layer-normalized and linearly projected.
  With `vanilla=True`, the unshifted image is patched, flattened and projected
  directly, with no layer normalization.

  Args:
    image_size: Height and width, in pixels, of the square input images.
    patch_size: Height and width, in pixels, of each square patch.
    num_patches: Number of patches per image, used to flatten the patch grid;
      expected to equal `(image_size // patch_size) ** 2`.
    projection_dim: Size of each output token.
    vanilla: If True, skip the diagonal shifts and the layer normalization.
    **kwargs: Forwarded to `layers.Layer`.
  """

  # Valid `mode` values for `crop_shift_pad`, in concatenation order.
  SHIFT_MODES = ("left-up", "left-down", "right-up", "right-down")

  def __init__(
      self,
      image_size=IMAGE_SIZE,
      patch_size=PATCH_SIZE,
      num_patches=NUM_PATCHES,
      projection_dim=PROJECTION_DIM,
      vanilla=False,
      **kwargs
  ):
    super().__init__(**kwargs)
    self.vanilla = vanilla
    # Previously never assigned, which made `crop_shift_pad` raise
    # AttributeError whenever vanilla=False.
    self.image_size = image_size
    self.patch_size = patch_size
    self.num_patches = num_patches
    self.projection_dim = projection_dim
    self.flatten_patches = layers.Reshape((num_patches, -1))
    self.layer_norm = layers.LayerNormalization(epsilon=1e-6)
    self.projection = layers.Dense(units=projection_dim)
    self.half_patch = self.patch_size // 2

  def crop_shift_pad(self, images, mode):
    """Shift `images` by half a patch in one diagonal direction.

    The image is cropped by `half_patch` pixels on two sides and zero-padded
    back to `image_size` on the opposite two sides. This moves the content
    diagonally and fills the vacated border with zeros.

    Args:
      images: Batch of `image_size` x `image_size` images (NHWC, as implied
        by the `extract_patches` sizes used in `call`).
      mode: One of `SHIFT_MODES`.

    Returns:
      The shifted images, padded back to `image_size` x `image_size`.

    Raises:
      ValueError: If `mode` is not one of `SHIFT_MODES`.
    """
    half = self.half_patch
    # mode -> (crop_height, crop_width, shift_height, shift_width)
    offsets = {
        "left-up": (half, half, 0, 0),
        "left-down": (0, half, half, 0),
        "right-up": (half, 0, 0, half),
        "right-down": (0, 0, half, half),
    }
    if mode not in offsets:
      # Previously an unknown mode left the offsets unbound (UnboundLocalError).
      raise ValueError(
          f"Unknown shift mode {mode!r}; expected one of {self.SHIFT_MODES}."
      )
    crop_height, crop_width, shift_height, shift_width = offsets[mode]

    # Crop away `half_patch` rows/columns on one side ...
    crop = tf.image.crop_to_bounding_box(
        images,
        offset_height=crop_height,
        offset_width=crop_width,
        target_height=self.image_size - self.half_patch,
        target_width=self.image_size - self.half_patch,
    )
    # ... then pad back to full size on the opposite side, shifting the content.
    shift_pad = tf.image.pad_to_bounding_box(
        crop,
        offset_height=shift_height,
        offset_width=shift_width,
        target_height=self.image_size,
        target_width=self.image_size,
    )
    return shift_pad

  def call(self, images):
    """Tokenize a batch of images.

    Args:
      images: Batch of `image_size` x `image_size` images (NHWC).

    Returns:
      A tuple `(tokens, patches)`. `tokens` has shape
      `(batch, num_patches, projection_dim)`. `patches` is the raw
      `tf.image.extract_patches` output, whose last axis holds each
      `patch_size` x `patch_size` x channels patch flattened.
    """
    if not self.vanilla:
      # Stack the original image and its four diagonal shifts along channels.
      # Iterating SHIFT_MODES avoids mistyped mode strings (previously
      # 'rigth-up', which crashed the non-vanilla path).
      shifted = [self.crop_shift_pad(images, mode=m) for m in self.SHIFT_MODES]
      images = tf.concat([images] + shifted, axis=-1)
    # Non-overlapping patches: stride equals patch size, no padding.
    patches = tf.image.extract_patches(
        images=images,
        sizes=[1, self.patch_size, self.patch_size, 1],
        strides=[1, self.patch_size, self.patch_size, 1],
        rates=[1, 1, 1, 1],
        padding='VALID'
    )
    flat_patches = self.flatten_patches(patches)
    if not self.vanilla:
      tokens = self.projection(self.layer_norm(flat_patches))
    else:
      tokens = self.projection(flat_patches)
    return (tokens, patches)

  def get_config(self):
    """Return the constructor arguments so the layer can be saved and cloned."""
    config = super().get_config()
    config.update(
        {
            "image_size": self.image_size,
            "patch_size": self.patch_size,
            "num_patches": self.num_patches,
            "projection_dim": self.projection_dim,
            "vanilla": self.vanilla,
        }
    )
    return config

## Visualize the patches

In [None]:
# Pick one training image uniformly at random to visualize.
# `randint(n)` draws the index directly instead of first materializing an
# index array the size of the training set, as `choice(range(n))` does.
x_image = x_train[np.random.randint(x_train.shape[0])]

In [None]:
# Add a leading batch axis and upsample the image to IMAGE_SIZE x IMAGE_SIZE,
# mirroring the Resizing step of the augmentation pipeline.
batched_image = x_image[np.newaxis, ...]
resized_image = tf.image.resize(batched_image, size=(IMAGE_SIZE, IMAGE_SIZE))
del batched_image

In [None]:
# Tokenize the resized image with plain (unshifted) patching, after scaling
# pixel values to [0, 1]. The layer returns (tokens, raw patches).
token, patch = ShiftedPatchTokenization(vanilla=True)(
    resized_image / 255.0
)

In [None]:
# Token shape, then raw patch-grid shape, each on its own line.
print(token.shape, patch.shape, sep='\n')

In [None]:
# Drop the size-1 batch axis. The rank guard keeps the cell idempotent: on a
# re-run it would otherwise slice again and drop the patch-grid axis instead.
if len(patch.shape) == 4:
  token, patch = token[0], patch[0]

In [None]:
# Show the original (un-resized) training image for reference; pixel values
# are scaled to [0, 1] for imshow.
plt.imshow(x_image / 255.0)
plt.title('Original image')
plt.axis('off')
plt.show()

In [None]:
# Draw every patch in its grid position. Each entry of `patch[row, col]` is a
# PATCH_SIZE x PATCH_SIZE x channels patch flattened by extract_patches.
# Reshaping with -1 channels and keeping the first 3 shows the original RGB
# image for both the vanilla (3-channel) and shifted (15-channel) tokenizers.
n_rows, n_cols = patch.shape[0], patch.shape[1]
plt.figure(figsize=(4, 4))
for row in range(n_rows):
  for col in range(n_cols):
    plt.subplot(n_rows, n_cols, row * n_cols + col + 1)
    image = tf.reshape(patch[row, col], (PATCH_SIZE, PATCH_SIZE, -1))[..., :3]
    plt.imshow(image)
    plt.axis('off')
plt.suptitle(f'{n_rows}x{n_cols} patches of {PATCH_SIZE}x{PATCH_SIZE} px')
plt.show()