# Swin Transformer example with the MNIST dataset

In [1]:
import numpy as np
from tensorflow import keras
from tensorflow.keras.layers import Input, Dense, GlobalAveragePooling1D

In [2]:
from models import swin_layers

# Problem statement and data

**MNIST**

The MNIST dataset contains handwritten digits as grayscale images of 28-by-28 pixels. The pixel values are converted to floating-point numbers and min-max scaled to the [0, 1] range. Each image is labeled with one of ten categories, representing the digits 0-9.

**Problem statement**

A supervised image classification problem is used to demonstrate the Swin Transformer. Taking the preprocessed grayscale images as inputs, the Swin Transformer is trained to predict one of the ten digit labels.

In [3]:
# the data, split between train and test sets
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

# Scale images to the [0, 1] range
x_train = x_train.astype("float32") / 255
x_test = x_test.astype("float32") / 255

# Make sure images have shape (28, 28, 1)
x_train = np.expand_dims(x_train, -1)
x_test = np.expand_dims(x_test, -1)

y_train = keras.utils.to_categorical(y_train, 10)
y_test = keras.utils.to_categorical(y_test, 10)
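The preprocessing above can be checked on a small synthetic batch without downloading the dataset. This is a minimal sketch: the random `images` and the `labels` array are stand-ins, not MNIST data, and the identity-matrix indexing is only meant to mirror the one-hot vectors that `keras.utils.to_categorical` is expected to return.

```python
import numpy as np

# Synthetic stand-in for a small MNIST batch: uint8 images and integer labels.
rng = np.random.default_rng(0)
images = rng.integers(0, 256, size=(4, 28, 28), dtype=np.uint8)
labels = np.array([5, 0, 4, 1])

# Scale to [0, 1] and append a channel axis, as in the cell above.
x = np.expand_dims(images.astype("float32") / 255, -1)

# One-hot encode by indexing rows of a 10-by-10 identity matrix.
y = np.eye(10, dtype="float32")[labels]

print(x.shape, y.shape)  # (4, 28, 28, 1) (4, 10)
```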

# The Swin Transformer

Liu, Z., Lin, Y., Cao, Y., Hu, H., Wei, Y., Zhang, Z., Lin, S. and Guo, B., 2021. Swin transformer: Hierarchical vision transformer using shifted windows. arXiv preprint arXiv:2103.14030. https://arxiv.org/abs/2103.14030

## Hyperparameters

A simplified Swin Transformer configuration is used, with a patch embedding layer, two transformer blocks, and a patch merging layer. Global average pooling and a softmax output layer are applied after the Swin Transformer blocks.

See Liu et al. (2021) for more complicated architecture variants.

The hyperparameters of the Swin Transformer are listed as follows:

In [4]:
input_size = (28, 28, 1) # The image size of the MNIST
patch_size = (2, 2) # Segment 28-by-28 frames into 2-by-2 patches; patch contents and positions are embedded
n_labels = 10 # MNIST labels

# Dropout parameters
mlp_drop_rate = 0.01 # Dropout after each MLP layer
attn_drop_rate = 0.01 # Dropout after Swin-Attention
proj_drop_rate = 0.01 # Dropout at the end of each Swin-Attention block, i.e., after linear projections
drop_path_rate = 0.01 # Drop-path within skip-connections

# Self-attention parameters 
# (Fixed for all the blocks in this configuration, but can vary per block in larger architectures)
num_heads = 8 # Number of attention heads
embed_dim = 64 # Number of embedded dimensions
mlp_num = 256 # Number of MLP nodes
qkv_bias = True # Convert embedded patches to query, key, and values with a learnable additive value
qk_scale = None # None: re-scale queries based on the embedded dimensions per attention head; float: user-specified scaling factor

# Shift-window parameters
window_size = 2 # Size of attention window (height = width)
shift_size = window_size // 2 # Size of shifting (shift_size < window_size)

num_patch_x = input_size[0]//patch_size[0]
num_patch_y = input_size[1]//patch_size[1]
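The quantities implied by these hyperparameters can be worked out by hand. The sketch below repeats the values from the cell above and derives the token, window, and per-head sizes. The `default_qk_scale` line is an assumption: it follows the per-head scaling of Liu et al. (2021) and may not match what `swin_layers` does internally when `qk_scale` is `None`.

```python
# Values copied from the hyperparameter cell above.
input_size = (28, 28, 1)
patch_size = (2, 2)
window_size = 2
num_heads = 8
embed_dim = 64

num_patch_x = input_size[0] // patch_size[0]   # 14
num_patch_y = input_size[1] // patch_size[1]   # 14
num_patches = num_patch_x * num_patch_y        # 196 tokens per image

# Windows tile the patch grid without overlap, so the grid must divide evenly.
assert num_patch_x % window_size == 0 and num_patch_y % window_size == 0
num_windows = (num_patch_x // window_size) * (num_patch_y // window_size)  # 49
tokens_per_window = window_size * window_size  # 4

# Each attention head works on an equal slice of the embedding.
assert embed_dim % num_heads == 0
head_dim = embed_dim // num_heads              # 8

# Assumed default scaling when qk_scale is None (per Liu et al., 2021).
default_qk_scale = head_dim ** -0.5

print(num_patches, num_windows, tokens_per_window, head_dim)  # 196 49 4 8
```

With a window size of 2, each attention window covers only four tokens, which keeps this example cheap but gives each block a very small receptive field.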

## Model configuration

In [5]:
# The input section
IN = Input(input_size)
X = IN

# Patch embedding
X = swin_layers.PatchEmbed(patch_size=patch_size, embed_dim=embed_dim)(X)

# -------------------- Swin transformers -------------------- #
# Stage 1: window-attention + Swin-attention + patch-merging

for i in range(2):
    
    if i % 2 == 0:
        shift_size_temp = 0
    else:
        shift_size_temp = shift_size

    X = swin_layers.SwinTransformerBlock(dim=embed_dim, num_patch=(num_patch_x, num_patch_y), num_heads=num_heads, 
                             window_size=window_size, shift_size=shift_size_temp, mlp_num=mlp_num, qkv_bias=qkv_bias, qk_scale=qk_scale,
                             mlp_drop=mlp_drop_rate, attn_drop=attn_drop_rate, proj_drop=proj_drop_rate, drop_path_prob=drop_path_rate, 
                             prefix='swin_block{}'.format(i))(X)
# Patch-merging
#    Pooling patch sequences. Halve the patch grid along each axis (196 -> 49 patches) and double the embedded dimensions (64 -> 128)
X = swin_layers.PatchMerging((num_patch_x, num_patch_y), dim=embed_dim, prefix='down{}'.format(i))(X)

# ----------------------------------------------------------- #

# Convert patch sequences to vectors
X = GlobalAveragePooling1D()(X)

# The output section
OUT = Dense(n_labels, activation='softmax')(X)
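The alternation between `shift_size = 0` and `shift_size = window_size // 2` in the two blocks can be illustrated on a toy patch grid. This NumPy sketch is not the `swin_layers` implementation: `window_partition` is a hypothetical helper written for this example, and the 4-by-4 grid is chosen only to keep the output readable.

```python
import numpy as np

def window_partition(grid, window_size):
    """Split an (H, W) grid into non-overlapping (window_size, window_size) windows."""
    h, w = grid.shape
    grid = grid.reshape(h // window_size, window_size, w // window_size, window_size)
    return grid.transpose(0, 2, 1, 3).reshape(-1, window_size, window_size)

# Label each patch of a small 4-by-4 grid with its index.
grid = np.arange(16).reshape(4, 4)

# Block with shift_size = 0: regular windows.
regular = window_partition(grid, window_size=2)

# Block with shift_size = 1: cyclically shift the grid, then partition.
shifted_grid = np.roll(grid, shift=(-1, -1), axis=(0, 1))
shifted = window_partition(shifted_grid, window_size=2)

print(regular[0].tolist())  # [[0, 1], [4, 5]]
print(shifted[0].tolist())  # [[5, 6], [9, 10]]
print(shifted[3].tolist())  # [[15, 12], [3, 0]]
```

The first shifted window straddles four of the regular windows, which is how information passes between neighboring windows. The last shifted window groups patches from opposite corners of the grid; in Liu et al. (2021) an attention mask keeps such wrapped-around patches from attending to each other.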

In [6]:
# Model configuration
model = keras.models.Model(inputs=[IN,], outputs=[OUT,])

In [7]:
model.summary()

Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         [(None, 28, 28, 1)]       0         
_________________________________________________________________
PatchEmbed (PatchEmbed)      (None, 196, 64)           320       
_________________________________________________________________
swin_transformer_block (Swin (None, 196, 64)           50072     
_________________________________________________________________
swin_transformer_block_1 (Sw (None, 196, 64)           50856     
_________________________________________________________________
patch_merging (PatchMerging) (None, 49, 128)           33280     
_________________________________________________________________
global_average_pooling1d (Gl (None, 128)               0         
_________________________________________________________________
dense (Dense)                (None, 10)                1290  
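The parameter counts in the summary can be reproduced with simple arithmetic. The layer layouts below are assumptions, not taken from the `swin_layers` source: they are one decomposition that happens to match the printed numbers, which is consistent with that structure but does not prove it.

```python
embed_dim, num_heads, mlp_num = 64, 8, 256
patch_h, patch_w, channels = 2, 2, 1
window_size, num_windows, n_labels = 2, 49, 10
tokens = window_size * window_size  # tokens per attention window

# Patch embedding (assumed): one linear map per flattened patch, with bias.
patch_embed = patch_h * patch_w * channels * embed_dim + embed_dim

# Swin block (assumed): two LayerNorms, QKV and output projections with bias,
# a two-layer MLP, a relative position bias table, and a stored index table.
layer_norms = 2 * (2 * embed_dim)
qkv = embed_dim * 3 * embed_dim + 3 * embed_dim
proj = embed_dim * embed_dim + embed_dim
mlp = (embed_dim * mlp_num + mlp_num) + (mlp_num * embed_dim + embed_dim)
bias_table = (2 * window_size - 1) ** 2 * num_heads
index_table = tokens * tokens
swin_block = layer_norms + qkv + proj + mlp + bias_table + index_table

# Shifted block (assumed): additionally stores one (tokens x tokens) mask per window.
swin_block_shifted = swin_block + num_windows * tokens * tokens

# Patch merging (assumed): LayerNorm over 4*C features, then a bias-free 4*C -> 2*C projection.
patch_merging = 2 * (4 * embed_dim) + (4 * embed_dim) * (2 * embed_dim)

# Classifier head on the pooled 2*C features.
dense = 2 * embed_dim * n_labels + n_labels

print(patch_embed, swin_block, swin_block_shifted, patch_merging, dense)
# 320 50072 50856 33280 1290
```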

## Training

Gradient clipping is applied to prevent gradient explosion.

Note: the training of this example is not systematic, and is provided for illustration purposes only.
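Clipping by value, as requested through the optimizer's `clipvalue` argument, clamps each gradient component independently. The NumPy sketch below shows the effect on a made-up gradient vector; `clip_by_value` is a hypothetical helper for illustration, not a Keras function.

```python
import numpy as np

def clip_by_value(grad, clipvalue):
    """Clamp every gradient component to [-clipvalue, clipvalue]."""
    return np.clip(grad, -clipvalue, clipvalue)

# Made-up gradient with two components outside the allowed range.
grad = np.array([-2.0, -0.3, 0.0, 0.4, 5.0])
clipped = clip_by_value(grad, 0.5)

print(clipped.tolist())  # [-0.5, -0.3, 0.0, 0.4, 0.5]
```

Because each component is clamped separately, the direction of the gradient can change. Keras optimizers also accept `clipnorm`, which rescales the gradient instead and preserves its direction.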

In [8]:
# Compile the model
opt = keras.optimizers.Adam(learning_rate=1e-4, clipvalue=0.5)
model.compile(loss=keras.losses.categorical_crossentropy, optimizer=opt, metrics=['accuracy',])

# Training
model.fit(x_train, y_train, batch_size=32, epochs=10, validation_split=0.1)

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


<tensorflow.python.keras.callbacks.History at 0x2b2fbb90aa90>

## Evaluation

In [9]:
y_pred = model.predict(x_train[:10, ...])

In [10]:
y_pred[0]

array([1.2170502e-04, 1.0908453e-06, 1.1731385e-03, 6.4424365e-03,
       7.7731880e-07, 9.9109441e-01, 2.5945242e-08, 1.1356717e-03,
       2.2678056e-05, 7.9827478e-06], dtype=float32)
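Each row of `y_pred` is a softmax vector over the ten digit classes, so the predicted digit is the index of the largest entry. The sketch below copies the printed values into a standalone array (named `y_pred_0` here for illustration) and reads off the prediction.

```python
import numpy as np

# Softmax output printed above for the first training image.
y_pred_0 = np.array([1.2170502e-04, 1.0908453e-06, 1.1731385e-03, 6.4424365e-03,
                     7.7731880e-07, 9.9109441e-01, 2.5945242e-08, 1.1356717e-03,
                     2.2678056e-05, 7.9827478e-06], dtype=np.float32)

predicted_digit = int(np.argmax(y_pred_0))     # index of the most likely class
confidence = float(y_pred_0[predicted_digit])  # probability assigned to it

print(predicted_digit)  # 5
```

In the notebook, `np.argmax(y_pred, axis=1)` can be compared against `np.argmax(y_train[:10], axis=1)` to check the ten predictions against their one-hot labels.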

## Save and reuse

The model can be saved with `model.save()`, but it contains custom Python objects that are not part of `tensorflow.keras`. When reloading, it is therefore preferable to load only the weights and set them within a newly built model of the same configuration.

e.g.

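```python
# `dummy_loader` and `swin_transformer_model` are placeholders, not library functions
weights = dummy_loader(model_old_path)   # list of weight arrays from the saved model
model_new = swin_transformer_model(...)  # rebuild the same architecture
model_new.set_weights(weights)
```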
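One possible way to fill in the loader is to store the list returned by `model.get_weights()` in a single `.npz` file and read it back in order. This is a sketch under that assumption: `save_weight_list` and `load_weight_list` are hypothetical helpers, and the arrays below are stand-ins for real model weights.

```python
import os
import tempfile

import numpy as np

def save_weight_list(weights, path):
    """Store a list of arrays (e.g. from model.get_weights()) in one .npz file."""
    np.savez(path, *weights)

def load_weight_list(path):
    """Read the arrays back in their original order."""
    with np.load(path) as data:
        return [data["arr_{}".format(i)] for i in range(len(data.files))]

# Stand-in arrays; with a real model these would come from model.get_weights().
weights = [np.ones((4, 3), dtype="float32"), np.zeros(3, dtype="float32")]

path = os.path.join(tempfile.mkdtemp(), "swin_weights.npz")
save_weight_list(weights, path)
restored = load_weight_list(path)

print(len(restored), restored[0].shape, restored[1].shape)  # 2 (4, 3) (3,)
```

The restored list can then be passed to `set_weights()` on a model rebuilt with the same layers in the same order. Keras also offers `model.save_weights()` and `model.load_weights()`, which avoid the manual round trip.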