# Transformers for Vision
:label:`sec_vision-transformer`

The Transformer architecture was initially proposed 
for sequence to sequence learning, 
with a focus on machine translation. 
Subsequently, Transformers emerged as the model of choice 
in various natural language processing tasks :cite:`Radford.Narasimhan.Salimans.ea.2018,Radford.Wu.Child.ea.2019,brown2020language,Devlin.Chang.Lee.ea.2018,raffel2020exploring`. 
However, in the field of computer vision
the dominant architecture has remained
the CNN (:numref:`chap_modern_cnn`).
Naturally, researchers started to wonder
if it might be possible to do better
by adapting Transformer models to image data.
This question sparked immense interest
in the computer vision community.
Recently, :citet:`ramachandran2019stand` proposed 
a scheme for replacing convolution with self-attention. 
However, its use of specialized patterns in attention 
makes it hard to scale up models on hardware accelerators.
Then, :citet:`cordonnier2020relationship` theoretically proved 
that self-attention can learn to behave similarly to convolution. 
Empirically, $2 \times 2$ patches were taken from images as inputs, 
but the small patch size makes the model 
only applicable to image data with low resolutions.

Without specific constraints on patch size,
*vision Transformers* (ViTs)
extract patches from images
and feed them into a Transformer encoder
to obtain a global representation,
which will finally be transformed for classification :cite:`Dosovitskiy.Beyer.Kolesnikov.ea.2021`.
Notably, Transformers show better scalability than CNNs:
when training larger models on larger datasets,
vision Transformers outperform ResNets by a significant margin. 
Similar to the landscape of network architecture design in natural language processing,
Transformers also became a game-changer in computer vision.


## Model

:numref:`fig_vit` depicts
the model architecture of vision Transformers.
This architecture consists of a stem
that patchifies images, 
a body based on the multi-layer Transformer encoder,
and a head that transforms the global representation
into the output label.

![The vision Transformer architecture. In this example, an image is split into 9 patches. A special “&lt;cls&gt;” token and the 9 flattened image patches are transformed via patch embedding and $n$ Transformer encoder blocks into 10 representations, respectively. The “&lt;cls&gt;” representation is further transformed into the output label.](../img/vit.svg)
:label:`fig_vit`

Consider an input image with height $h$, width $w$,
and $c$ channels.
Specifying the patch height and width both as $p$,
the image is split into a sequence of $m = hw/p^2$ patches,
where each patch is flattened to a vector of length $cp^2$.
In this way, image patches can be treated similarly to tokens in text sequences by Transformer encoders.
A special “&lt;cls&gt;” (class) token and
the $m$ flattened image patches are linearly projected
into a sequence of $m+1$ vectors,
summed with learnable positional embeddings.
The multi-layer Transformer encoder
transforms $m+1$ input vectors
into the same amount of output vector representations of the same length.
It works exactly the same way as the original Transformer encoder in :numref:`fig_transformer`,
only differing in the position of normalization.
Since the “&lt;cls&gt;” token attends to all the image patches 
via self-attention (see :numref:`fig_cnn-rnn-self-attention`),
its representation from the Transformer encoder output
will be further transformed into the output label.

In [1]:
%use @file[../djl.json]
%use lets-plot
@file:DependsOn("../D2J-1.0-SNAPSHOT.jar")
import ai.djl.modality.nlp.DefaultVocabulary
import ai.djl.modality.nlp.Vocabulary
import ai.djl.modality.nlp.embedding.TrainableWordEmbedding
import jp.live.ugai.d2j.timemachine.RNNModelScratch
import jp.live.ugai.d2j.timemachine.TimeMachine.trainCh8
import jp.live.ugai.d2j.timemachine.TimeMachineDataset
import jp.live.ugai.d2j.timemachine.Vocab
import jp.live.ugai.d2j.RNNModel
import jp.live.ugai.d2j.util.StopWatch
import jp.live.ugai.d2j.util.Accumulator
import jp.live.ugai.d2j.util.Training
import jp.live.ugai.d2j.util.TrainingChapter9
import jp.live.ugai.d2j.lstm.Decoder
import jp.live.ugai.d2j.lstm.Encoder
import jp.live.ugai.d2j.lstm.EncoderDecoder
import jp.live.ugai.d2j.util.NMT
import jp.live.ugai.d2j.attention.AttentionDecoder
import jp.live.ugai.d2j.attention.MultiHeadAttention
import jp.live.ugai.d2j.PositionalEncoding
import jp.live.ugai.d2j.MaskedSoftmaxCELoss
import java.util.Locale
import kotlin.random.Random
import kotlin.collections.List
import kotlin.collections.Map
import kotlin.Pair

In [10]:
import ai.djl.basicdataset.cv.classification.FashionMnist
import ai.djl.metric.Metrics

In [11]:
    val manager = NDManager.newBaseManager()
    val ps = ParameterStore(manager, false)

## Patch Embedding

To implement a vision Transformer, let's start 
with patch embedding in :numref:`fig_vit`. 
Splitting an image into patches 
and linearly projecting these flattened patches
can be simplified as a single convolution operation, 
where both the kernel size and the stride size are set to the patch size.

In [14]:
class PatchEmbedding(imgSize: Int = 96, val patchSize: Int = 16, val numHiddens: Int = 512) : AbstractBlock() {
    val numPatches = (imgSize / patchSize) * (imgSize / patchSize)
    val conv = Conv2d.builder()
        .setKernelShape(Shape(patchSize.toLong(), patchSize.toLong()))
        .optStride(Shape(patchSize.toLong(), patchSize.toLong()))
        .setFilters(numHiddens)
        .build()

    override fun forwardInternal(
        parameterStore: ParameterStore,
        inputs: NDList,
        training: Boolean,
        params: PairList<String, Any>?
    ): NDList {
        // Output shape: (batch size, no. of patches, no. of channels)
        val f = conv.forward(parameterStore, inputs, training, params).head()
        return NDList(f.reshape(Shape(f.shape[0], f.shape[1], -1)).transpose(0, 2, 1))
    }

    override fun initializeChildBlocks(manager: NDManager, dataType: DataType, vararg inputShapes: Shape) {
        conv.initialize(manager, dataType, *inputShapes)
    }

    /* We won't implement this since we won't be using it but it's required as part of an AbstractBlock  */
    override fun getOutputShapes(inputShapes: Array<Shape>): Array<Shape> {
        return arrayOf(Shape(inputShapes[0][0], numPatches.toLong(), numHiddens.toLong()))
    }
}

In the following example, taking images with height and width of `img_size` as inputs,
the patch embedding outputs `(img_size//patch_size)**2` patches 
that are linearly projected to vectors of length `num_hiddens`.

In [15]:
    val manager = NDManager.newBaseManager()
    val ps = ParameterStore(manager, false)
    val imgSize = 96
    val patchSize = 16
    val numHiddens = 512
    val batchSize = 4
    val patchEmb = PatchEmbedding(imgSize, patchSize, numHiddens)
    val X = manager.randomNormal(Shape(batchSize.toLong(), 3, imgSize.toLong(), imgSize.toLong()))
    patchEmb.initialize(manager, DataType.FLOAT32, X.shape)
    println(patchEmb.forward(ps, NDList(X), false))


NDList size: 1
0 : (4, 36, 512) float32



## Vision Transformer Encoder
:label:`subsec_vit-encoder`

The MLP of the vision Transformer encoder is slightly different 
from the position-wise FFN of the original Transformer encoder 
(see :numref:`subsec_positionwise-ffn`).
First, here the activation function uses the Gaussian error linear unit (GELU),
which can be considered as a smoother version of the ReLU :cite:`hendrycks2016gaussian`.
Second, dropout is applied to the output of each fully connected layer in the MLP for regularization.

In [16]:
class ViTMLP(mlpNumHiddens: Int, mlpNumOutputs: Int, dropout: Float = 0.5f) : SequentialBlock() {
    val dense1 = Linear.builder().setUnits(mlpNumHiddens.toLong()).build()
    val gelu: (NDList) -> NDList = Activation::relu
    val dropout1 = Dropout.builder().optRate(dropout).build()
    val dense2 = Linear.builder().setUnits(mlpNumOutputs.toLong()).build()
    val dropout2 = Dropout.builder().optRate(dropout).build()

    init {
        add(dense1)
        add(gelu)
        add(dropout1)
        add(dense2)
        add(dropout2)
    }
}

The vision Transformer encoder block implementation
just follows the pre-normalization design in :numref:`fig_vit`,
where normalization is applied right *before* multi-head attention or the MLP.
In contrast to post-normalization ("add & norm" in :numref:`fig_transformer`),
where normalization is placed right *after* residual connections,
pre-normalization leads to more effective or efficient training for Transformers :cite:`baevski2018adaptive,wang2019learning,xiong2020layer`.


In [17]:
class ViTBlock(numHiddens: Int, val normShape: Int, mlpNumHiddens: Int, numHeads: Int, dropout: Float, useBias: Boolean = false) : AbstractBlock() {
    val ln1 = LayerNorm.builder().axis(normShape).build()
    val attention = MultiHeadAttention(numHiddens, numHeads, dropout, useBias)
    val ln2 = LayerNorm.builder().axis(normShape).build()
    val mlp = ViTMLP(mlpNumHiddens, numHiddens, dropout)

    override fun forwardInternal(
        parameterStore: ParameterStore,
        inputs: NDList,
        training: Boolean,
        params: PairList<String, Any>?
    ): NDList {
        var X = inputs[0]
        val validLens = if (inputs.size <2) null else inputs[1]
        X = ln1.forward(parameterStore, NDList(X), training, params).head()
        val att = attention.forward(parameterStore, NDList(X, X, X, validLens), training, params).head()
        X = X.add(att)
        val ln = ln2.forward(parameterStore, NDList(X), training, params).head()
        val ret = mlp.forward(parameterStore, NDList(ln), training, params)
        return ret
    }

    override fun initializeChildBlocks(manager: NDManager, dataType: DataType, vararg inputShapes: Shape) {
        val arr = List<Long>(normShape + 1) { normShape.toLong() }
        ln1.initialize(manager, dataType, Shape(arr))
        val shapes = arrayOf(inputShapes[0], inputShapes[0], inputShapes[0], inputShapes[1])

        attention.initialize(manager, dataType, *shapes)
        ln2.initialize(manager, dataType, Shape(arr))
        mlp.initialize(manager, dataType, inputShapes[0])
    }

    /* We won't implement this since we won't be using it but it's required as part of an AbstractBlock  */
    override fun getOutputShapes(inputShapes: Array<Shape>): Array<Shape> {
        return inputShapes
    }
}

Same as in :numref:`subsec_transformer-encoder`,
any vision Transformer encoder block does not change its input shape.


In [18]:
    val X1 = manager.ones(Shape(2, 100, 24))
    val encoderBlk = ViTBlock(24, 24, 48, 8, 0.5f)
    encoderBlk.initialize(manager, DataType.FLOAT32, X1.shape, Shape(2))
    println(encoderBlk.forward(ps, NDList(X1, null), false))
    println("Shapes : ${encoderBlk.getOutputShapes(arrayOf(X1.shape)).toList()}")


NDList size: 1
0 : (2, 100, 24) float32

Shapes : [(2, 100, 24)]


## Putting It All Together

The forward pass of vision Transformers below is straightforward.
First, input images are fed into an `PatchEmbedding` instance,
whose output is concatenated with the “&lt;cls&gt;”  token embedding.
They are summed with learnable positional embeddings before dropout.
Then the output is fed into the Transformer encoder that stacks `num_blks` instances of the `ViTBlock` class.
Finally, the representation of the “&lt;cls&gt;”  token is projected by the network head.

In [19]:
class ViT(
    val imgSize: Int,
    patchSize: Int,
    val numHiddens: Int,
    mlpNumHiddens: Int,
    numHeads: Int,
    numBlks: Int,
    embDropout: Float,
    blkDropout: Float,
    lr: Float = 0.1f,
    useBias: Boolean = false,
    val numClasses: Int = 10
) : AbstractBlock() {
    val patchEmbedding = PatchEmbedding(imgSize, patchSize, numHiddens)
    val clsToken = Parameter.builder()
        .optRequiresGrad(true)
        .setType(Parameter.Type.BIAS)
        .optShape(Shape(1, 1, numHiddens.toLong()))
        .build()
    val numSteps: Int = patchEmbedding.numPatches + 1
    val posEmbedding = Parameter.builder()
        .optRequiresGrad(true)
        .optShape(Shape(1, numSteps.toLong(), numHiddens.toLong()))
//    torch.randn(1, num_steps, num_hiddens))
        .setType(Parameter.Type.BIAS)
        .build()
    val dropOut = Dropout.builder().optRate(embDropout).build()
    val blks = mutableListOf<Pair<String, AbstractBlock>>()
    val blks0 = SequentialBlock()
    val head = SequentialBlock()
        .add(LayerNorm.builder().build())
        .add(Linear.builder().setUnits(numClasses.toLong()).build())
    init {
        addParameter(clsToken)
        addParameter(posEmbedding)
        clsToken.setInitializer(ConstantInitializer(0f))
        for (i in 0 until numBlks) {
            blks0.add(ViTBlock(numHiddens, numHiddens, mlpNumHiddens, numHeads, blkDropout, useBias))
        }
    }

    override fun forwardInternal(
        parameterStore: ParameterStore,
        inputs: NDList,
        training: Boolean,
        params: PairList<String, Any>?
    ): NDList {
        var X = patchEmbedding.forward(parameterStore, inputs, training, params).head()
        // X = torch.cat((self.cls_token.expand(X.shape[0], -1, -1), X), 1)

        X = clsToken.array.repeat(0, X.shape[0]).concat(X, 1)
        X = dropOut.forward(parameterStore, NDList(X.add(posEmbedding.array)), training, params).head()
        X = blks0.forward(parameterStore, NDList(X), training, params).head()
        return head.forward(parameterStore, NDList(X.get(NDIndex(":, 0"))), training, params)
    }

    override fun initializeChildBlocks(manager: NDManager, dataType: DataType, vararg inputShapes: Shape) {
        clsToken.initialize(manager, dataType)
        posEmbedding.initialize(manager, dataType)
        patchEmbedding.initialize(manager, dataType, Shape(inputShapes[0][0], inputShapes[0][1], imgSize.toLong(), imgSize.toLong()))
        blks0.initialize(manager, dataType, Shape(inputShapes[0][0], numSteps.toLong(), numHiddens.toLong()), Shape(inputShapes[0][0]))
        head.initialize(manager, dataType, Shape(inputShapes[0][0], numHiddens.toLong()))
    }

    override fun getOutputShapes(inputShapes: Array<Shape>): Array<Shape> {
        return arrayOf(Shape(inputShapes[0][0], numClasses.toLong()))
    }
}

## Training

Training a vision Transformer on the Fashion-MNIST dataset is just like how CNNs were trained in :numref:`chap_modern_cnn`.


In [None]:
    val imgSize0 = 28
    val patchSize0 = 16
    val numHiddens0 = 512
    val mlpNumHiddens0 = 2048
    val numHeads0 = 8
    val numBlks0 = 2
    val embDropout0 = 0.1f
    val blkDropout0 = 0.1f
    val batchSize0 = 256
    val lr = 0.001f
    val X0 = manager.ones(Shape(batchSize0.toLong(), 1, imgSize0.toLong(), imgSize0.toLong()))
    val encoder = ViT(imgSize0, patchSize0, numHiddens0, mlpNumHiddens0, numHeads0, numBlks0, embDropout0, blkDropout0, lr)
//    encoder.initialize(manager, DataType.FLOAT32, X0.shape)

    val randomShuffle = true

// Get Training and Validation Datasets

// Get Training and Validation Datasets
    val trainingSet = FashionMnist.builder()
        .optUsage(Dataset.Usage.TRAIN)
        .setSampling(batchSize0, randomShuffle)
        .optLimit(Long.MAX_VALUE)
        .build()

    val validationSet = FashionMnist.builder()
        .optUsage(Dataset.Usage.TEST)
        .setSampling(batchSize0, false)
        .optLimit(Long.MAX_VALUE)
        .build()
    val model: Model = Model.newInstance("softmax-regression")
    model.setBlock(encoder)
    val loss: Loss = Loss.softmaxCrossEntropyLoss()
    val lrt: Tracker = Tracker.fixed(lr)
    val adam: Optimizer = Optimizer.adam().optLearningRateTracker(lrt).build()
    val config: DefaultTrainingConfig = DefaultTrainingConfig(loss)
        .optOptimizer(adam) // Optimizer (loss function)
        .optInitializer(XavierInitializer(), "")
        .addEvaluator(Accuracy()) // Model Accuracy
        .addTrainingListeners(*TrainingListener.Defaults.logging()); // Logging

    val trainer: Trainer = model.newTrainer(config)
    trainer.initialize(X0.shape)
    trainer.metrics = Metrics()
    EasyTrain.fit(trainer, 3, trainingSet, validationSet)

    val batch = validationSet.getData(manager).iterator().next()
    val X3 = batch.getData().head()
    val yHat: IntArray = encoder.forward(ps, NDList(X3), false).head().argMax(1).toType(DataType.INT32, false).toIntArray()
    println(yHat.toList().subList(0, 20))
    println(batch.getLabels().head().toFloatArray().toList().subList(0, 20))


Training:    100% |████████████████████████████████████████| Accuracy: 0.66, SoftmaxCrossEntropyLoss: 1.04
Validating:  100% |████████████████████████████████████████|
Training:     68% |████████████████████████████            | Accuracy: 0.77, SoftmaxCrossEntropyLoss: 0.63

## Summary and Discussion

You may notice that for small datasets like Fashion-MNIST, 
our implemented vision Transformer 
does not outperform the ResNet in :numref:`sec_resnet`.
Similar observations can be made even on the ImageNet dataset (1.2 million images).
This is because Transformers *lack* those useful principles in convolution, 
such as translation invariance and locality (:numref:`sec_why-conv`).
However, the picture changes when training larger models on larger datasets (e.g., 300 million images),
where vision Transformers outperform ResNets by a large margin in image classification, demonstrating
intrinsic superiority of Transformers in scalability :cite:`Dosovitskiy.Beyer.Kolesnikov.ea.2021`.
The introduction of vision Transformers 
has changed the landscape of network design for modeling image data.
They were soon shown effective on the ImageNet dataset
with data-efficient training strategies of DeiT :cite:`touvron2021training`.
However, quadratic complexity of self-attention 
(:numref:`sec_self-attention-and-positional-encoding`)
makes the Transformer architecture
less suitable for higher-resolution images.
Towards a general-purpose backbone network in computer vision,
Swin Transformers addressed the quadratic computational complexity 
with respect to image size (:numref:`subsec_cnn-rnn-self-attention`)
and added back convolution-like priors,
extending the applicability of Transformers to a range of computer vision tasks 
beyond image classification with state-of-the-art results :cite:`liu2021swin`.



## Exercises

1. How does the value of `img_size` affect training time?
1. Instead of projecting the “&lt;cls&gt;” token representation to the output, how to project the averaged patch representations? Implement this change and see how it affects the accuracy.
1. Can you modify hyperparameters to improve the accuracy of the vision Transformer?
