In [None]:
!pip install -U tensorflow-addons

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import tensorflow_addons as tfa
from tensorflow import keras
from tensorflow.keras import layers

## Data preparation

In [None]:
NUM_CLASSES = 100  # number of label classes used for one-hot encoding
INPUT_SHAPE = (32, 32, 3)  # shape of one input image fed to the model

(x_train, y_train), (x_test, y_test) = keras.datasets.cifar100.load_data()

# Rescale the pixel values by 1/255.
x_train = x_train / 255.0
x_test = x_test / 255.0

# One-hot encode the integer labels of both splits.
y_train, y_test = (
    keras.utils.to_categorical(labels, NUM_CLASSES) for labels in (y_train, y_test)
)

print(f"x_train shape: {x_train.shape} - y_train shape: {y_train.shape}")
print(f"x_test shape: {x_test.shape} - y_test shape: {y_test.shape}")

# Preview the first 25 training images on a 5x5 grid without ticks or grid lines.
plt.figure(figsize=(10, 10))
for slot, image in enumerate(x_train[:25], start=1):
    plt.subplot(5, 5, slot)
    plt.xticks([])
    plt.yticks([])
    plt.grid(False)
    plt.imshow(image)
plt.show()

## Hyper-parameters

In [None]:
# --- Model hyper-parameters ---
PATCH_SIZE = (2, 2)  # (patch_size_x, patch_size_y): 2-by-2 sized patches
dropout_rate = 0.03  # Dropout rate
num_heads = 8  # Attention heads
EMBED_DIM = 64  # Embedding dimension
num_mlp = 256  # MLP layer size
qkv_bias = True  # Add a learnable bias when projecting embedded patches to query, key and value
WINDOW_SIZE = 2  # Size of attention window
SHIFT_SIZE = 1  # Size of shifting window
IMAGE_DIMENSION = 32  # Initial image size

# INPUT_SHAPE is laid out as (rows, cols, channels), whereas PatchExtract reads
# PATCH_SIZE[0] as the patch width (x / columns) and PATCH_SIZE[1] as the patch
# height (y / rows). Pair each patch dimension with the matching image axis so
# the counts stay correct for non-square images or patches as well.
NUM_PATCH_X = INPUT_SHAPE[1] // PATCH_SIZE[0]
NUM_PATCH_Y = INPUT_SHAPE[0] // PATCH_SIZE[1]

# --- Training hyper-parameters ---
learning_rate = 1e-3
batch_size = 128
num_epochs = 40
validation_split = 0.1
weight_decay = 0.0001
label_smoothing = 0.1

## Window partition

In [None]:
def window_partition(x, window_size):
    """Cut a rank-4 feature map into square, non-overlapping windows.

    Args:
        x: tensor whose shape unpacks as (batch, height, width, channels).
        window_size: side length of each window; `height` and `width` are
            integer-divided by it.

    Returns:
        Tensor of shape (-1, window_size, window_size, channels) in which the
        windows of all batch items are stacked along the leading axis.

    Shape information is printed at every step as a debugging trace.
    """
    height, width, channels = x.shape[1:]
    print(f'height: {height}, width: {width}, channels: {channels} @ window_partition')

    patch_num_y, patch_num_x = height // window_size, width // window_size
    print(f'patch_num_y: {patch_num_y}, patch_num_x: {patch_num_x} @ window_partition')

    # Factorize each spatial axis into (window index, offset inside window);
    # the batch axis is left to be inferred.
    factorized_shape = (-1, patch_num_y, window_size, patch_num_x, window_size, channels)
    x = tf.reshape(x, shape=factorized_shape)
    print(f'x_reshaped: {x.shape} @ window_partition')

    # Bring both window indices in front of both in-window offsets.
    x = tf.transpose(x, perm=(0, 1, 3, 2, 4, 5))
    print(f'x_transposed[0, 1, 3, 2, 4, 5]: {x.shape} @ window_partition')

    # Collapse batch and window indices into one leading axis.
    return tf.reshape(x, shape=(-1, window_size, window_size, channels))

In [None]:
def window_reverse(windows, window_size, height, width, channels):
    """Reassemble stacked windows into a (batch, height, width, channels) map.

    Args:
        windows: tensor reshapeable to
            (-1, height // window_size, width // window_size,
             window_size, window_size, channels).
        window_size: side length of each window.
        height: height of the feature map to rebuild.
        width: width of the feature map to rebuild.
        channels: size of the last axis.

    Returns:
        Tensor of shape (-1, height, width, channels).
    """
    rows_of_windows, cols_of_windows = height // window_size, width // window_size

    # Recover the (window row, window column) indices from the leading axis.
    grid = tf.reshape(
        windows,
        shape=(-1, rows_of_windows, cols_of_windows, window_size, window_size, channels),
    )
    # Put each in-window row offset right after its window row index so the
    # final reshape can merge them into the full height (same for the width).
    grid = tf.transpose(grid, perm=(0, 1, 3, 2, 4, 5))
    return tf.reshape(grid, shape=(-1, height, width, channels))

In [None]:
class DropPath(layers.Layer):
    """Stochastic depth: randomly zero whole samples of a tensor during training.

    Each sample (index along axis 0) is kept with probability
    `1 - drop_prob`; kept samples are scaled by `1 / (1 - drop_prob)`.

    Args:
        drop_prob: probability of dropping a sample. `None` or `0` turns the
            layer into an identity.
    """

    def __init__(self, drop_prob=None, **kwargs):
        super(DropPath, self).__init__(**kwargs)
        self.drop_prob = drop_prob

    def call(self, x, training=None):
        """Apply drop-path to `x`.

        Args:
            x: input tensor; axis 0 is treated as the batch axis.
            training: when falsy (including `None`), `x` is returned unchanged
                so that inference is deterministic.

        Returns:
            A tensor with the same shape and dtype as `x`.
        """
        # Identity outside training, and when there is nothing to drop; this
        # also avoids evaluating `1 - None` for the default `drop_prob`.
        if not training or not self.drop_prob:
            return x

        keep_prob = 1 - self.drop_prob
        batch_size = tf.shape(x)[0]
        rank = x.shape.rank
        # One random draw per sample, broadcast over all remaining axes.
        shape = (batch_size,) + (1,) * (rank - 1)
        random_tensor = keep_prob + tf.random.uniform(shape, dtype=x.dtype)
        # floor() maps [keep_prob, 1 + keep_prob) to 0 (dropped) or 1 (kept).
        path_mask = tf.floor(random_tensor)
        return tf.math.divide(x, keep_prob) * path_mask

## Swin Transformer

In [None]:
class PatchExtract(layers.Layer):
  """Split a batch of images into flattened, non-overlapping patches.

  Args:
    patch_size: pair `(patch_size_x, patch_size_y)`; index 0 is used as the
      patch width (columns) and index 1 as the patch height (rows).
  """

  def __init__(self, patch_size, **kwargs):
    super(PatchExtract, self).__init__(**kwargs)
    self.patch_size_x = patch_size[0]
    self.patch_size_y = patch_size[1]

  def call(self, images):
    """Extract patches from `images`.

    Args:
      images: rank-4 image batch as accepted by `tf.image.extract_patches`.

    Returns:
      Tensor of shape (batch, patch_rows * patch_cols, patch_dim), where
      `patch_dim` is the last dimension produced by `extract_patches`.
    """
    batch_size = tf.shape(images)[0]
    patches = tf.image.extract_patches(
        images=images,
        sizes=(1, self.patch_size_y, self.patch_size_x, 1),
        strides=(1, self.patch_size_y, self.patch_size_x, 1), # non-overlapping
        rates=(1, 1, 1, 1), # no subsample
        padding='VALID'
    )
    # patches: (batch, patch_rows, patch_cols, patch_dim), where patch_dim is
    # patch_size_y * patch_size_x * channels.
    _, patch_num_y, patch_num_x, patch_dim = patches.shape
    # Use rows * cols rather than squaring the row count, so that a
    # non-square patch grid is flattened correctly.
    return tf.reshape(patches, (batch_size, patch_num_y * patch_num_x, patch_dim))


In [None]:
class PatchEmbedding(layers.Layer):
  """Project patches with a Dense layer and add a learned position embedding.

  Args:
    num_patch: number of positions in the position-embedding table.
    embed_dim: output size of both the projection and the position embedding.
  """

  def __init__(self, num_patch, embed_dim, **kwargs):
    super().__init__(**kwargs)
    self.num_patch = num_patch
    self.proj = layers.Dense(embed_dim)
    self.pos_embed = layers.Embedding(input_dim=num_patch, output_dim=embed_dim)

  def call(self, patch):
    """Return `proj(patch)` plus the embeddings of positions 0..num_patch-1."""
    positions = tf.range(self.num_patch)
    projected = self.proj(patch)
    position_vectors = self.pos_embed(positions)
    return projected + position_vectors

In [None]:
def create_swin_transformer(
    input_shape=INPUT_SHAPE,
    image_dimension=IMAGE_DIMENSION,
    patch_size=PATCH_SIZE,
    num_patch_x=NUM_PATCH_X,
    num_patch_y=NUM_PATCH_Y,
    embed_dim=EMBED_DIM,
    num_classes=NUM_CLASSES,
    shift_size=SHIFT_SIZE,
    window_size=WINDOW_SIZE,
):
  """Build the (still incomplete) Swin-style classifier as a Keras functional model.

  Pipeline as currently written: random crop -> random horizontal flip ->
  PatchExtract -> PatchEmbedding -> LayerNormalization -> global average
  pooling -> softmax Dense head. A shifted window partition is computed along
  the way but its result is only printed; it does not feed the head yet.
  Intermediate shapes are printed while the graph is built.

  Args:
    input_shape: shape of one input image, passed to `layers.Input`.
    image_dimension: height and width given to `layers.RandomCrop`.
    patch_size: forwarded to `PatchExtract`.
    num_patch_x: number of patches along the width; used as the grid width.
    num_patch_y: number of patches along the height; used as the grid height.
    embed_dim: embedding size passed to `PatchEmbedding`.
    num_classes: number of units of the softmax output layer.
    shift_size: roll applied to both spatial axes before window partitioning;
      no roll when it is not greater than 0.
    window_size: window side length passed to `window_partition`.

  Returns:
    A `keras.Model` mapping `inputs` to class probabilities.
  """
  # Echo the main arguments so the shape trace below can be read in context.
  print(f'input_shape: {input_shape}')
  print(f'num_patch_x: {num_patch_x}')
  print(f'num_patch_y: {num_patch_y}')
  print(f'image_dimension: {image_dimension}')

  inputs = layers.Input(shape=input_shape) # (32, 32, 3) with the default arguments
  x = layers.RandomCrop(image_dimension, image_dimension)(inputs) # random crop to image_dimension x image_dimension
  x = layers.RandomFlip('horizontal')(x)
  x = PatchExtract(patch_size)(x)
  print(f'patch extract: {x.shape}')
  x = PatchEmbedding(num_patch_x * num_patch_y, embed_dim)(x)
  print(f'patch embedding: {x.shape}')

  #
  # Swin Transformer Block (work in progress)
  #

  # skip connection and LN
  height, width = num_patch_y, num_patch_x
  # NOTE(review): `num_patches_before` and `x_skip` are assigned here but never
  # read later in this function -- presumably placeholders for the residual path.
  _, num_patches_before, channels = x.shape
  x_skip = x
  x = layers.LayerNormalization(epsilon=1e-5)(x)
  print(f'LN: {x.shape}')

  # Lay the token sequence out as a (height, width) grid, then optionally roll
  # it before partitioning it into windows.
  x = tf.reshape(x, shape=(-1, height, width, channels))
  print(f'x_reshaped: {x.shape}')
  if shift_size > 0:
    shifted_x = tf.roll(x, shift=[-shift_size, -shift_size], axis=[1, 2]) # roll by -shift_size along both spatial axes
  else:
    shifted_x = x
  print(f'shifted_x: {shifted_x.shape}')
  x_windows = window_partition(shifted_x, window_size)
  print(f'x_windows: {x_windows.shape}') # windows stacked on the leading axis; only printed, not used below

  # Temporary: flatten the *unshifted* grid `x` back into a token sequence so
  # the head can be built; `shifted_x` / `x_windows` are not consumed yet.
  x = tf.reshape(x, shape=(-1, height * width, embed_dim)) # temporary

  # Classification Head
  x = layers.GlobalAveragePooling1D()(x) # (batch_size, embed_dim)
  print(f'global avg pool1d: {x.shape}')
  outputs = layers.Dense(num_classes, activation='softmax')(x)
  print(f'outputs: {outputs.shape}')

  # Create model
  model = keras.Model(inputs=inputs, outputs=outputs)
  return model

In [None]:
# Build the model with the default hyper-parameters. The shape trace is printed
# while the graph is constructed, and the returned Model is the cell's output.
create_swin_transformer()