# The Transformer Architecture
:label:`sec_transformer`


We have compared CNNs, RNNs, and self-attention in
:numref:`subsec_cnn-rnn-self-attention`.
Notably, self-attention
enjoys both parallel computation and
the shortest maximum path length.
It is therefore natural
to design deep architectures
using self-attention.
Unlike earlier self-attention models
that still rely on RNNs for input representations :cite:`Cheng.Dong.Lapata.2016,Lin.Feng.Santos.ea.2017,Paulus.Xiong.Socher.2017`,
the Transformer model
relies solely on attention mechanisms,
without any convolutional or recurrent layers :cite:`Vaswani.Shazeer.Parmar.ea.2017`.
Though originally proposed
for sequence-to-sequence learning on text data,
Transformers have become
pervasive across a wide range of
modern deep learning applications,
in areas such as language, vision, speech, and reinforcement learning.

## Model

As an instance of the encoder-decoder
architecture,
the overall architecture of
the Transformer
is presented in :numref:`fig_transformer`.
As we can see,
the Transformer is composed of an encoder and a decoder.
Unlike
Bahdanau attention
for sequence-to-sequence learning
in :numref:`fig_s2s_attention_details`,
the input (source) and output (target)
sequence embeddings
are summed with positional encoding
before being fed into
the encoder and the decoder,
both of which stack modules based on self-attention.

![The Transformer architecture.](https://d2l.ai/_images/transformer.svg)
:width:`400px`
:label:`fig_transformer`


Now we provide an overview of the
Transformer architecture in :numref:`fig_transformer`.
At a high level,
the Transformer encoder is a stack of multiple identical layers,
where each layer
has two sublayers (each denoted as $\mathrm{sublayer}$).
The first
is a multi-head self-attention pooling
and the second is a positionwise feed-forward network.
Specifically,
in the encoder self-attention,
queries, keys, and values all come from the
outputs of the previous encoder layer.
Inspired by the ResNet design in :numref:`sec_resnet`,
a residual connection is employed
around both sublayers.
In the Transformer,
for any input $\mathbf{x} \in \mathbb{R}^d$ at any position of the sequence,
we require that $\mathrm{sublayer}(\mathbf{x}) \in \mathbb{R}^d$ so that
the residual connection $\mathbf{x} + \mathrm{sublayer}(\mathbf{x}) \in \mathbb{R}^d$ is feasible.
This addition from the residual connection is immediately
followed by layer normalization :cite:`Ba.Kiros.Hinton.2016`,
so each sublayer together with its residual connection computes
$\mathrm{LayerNorm}(\mathbf{x} + \mathrm{sublayer}(\mathbf{x}))$.
As a result, the Transformer encoder outputs a $d$-dimensional vector representation
for each position of the input sequence.
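To make the "add & norm" step concrete, here is a minimal plain-Kotlin sketch (no DJL) of $\mathrm{LayerNorm}(\mathbf{x} + \mathrm{sublayer}(\mathbf{x}))$ for a single position. The helper names `layerNormVec` and `addAndNorm` are hypothetical, the toy sublayer simply doubles its input, and the sketch omits the learnable scale and shift parameters of layer normalization.

```kotlin
import kotlin.math.sqrt

// Layer normalization over one d-dimensional vector (no learnable scale/shift).
fun layerNormVec(v: DoubleArray, eps: Double = 1e-5): DoubleArray {
    val mean = v.average()
    val variance = v.map { (it - mean) * (it - mean) }.average()
    return DoubleArray(v.size) { (v[it] - mean) / sqrt(variance + eps) }
}

// Post-norm residual connection: LayerNorm(x + sublayer(x)).
// The sublayer must map R^d to R^d so that the addition is well defined.
fun addAndNorm(x: DoubleArray, sublayer: (DoubleArray) -> DoubleArray): DoubleArray {
    val s = sublayer(x)
    require(s.size == x.size) { "sublayer must preserve the dimension d" }
    return layerNormVec(DoubleArray(x.size) { x[it] + s[it] })
}

// Toy sublayer that doubles its input (hypothetical, for illustration only)
val xDemo = doubleArrayOf(1.0, 2.0, 3.0, 4.0)
val yDemo = addAndNorm(xDemo) { v -> DoubleArray(v.size) { 2 * v[it] } }
println(yDemo.joinToString())
```

The output keeps dimension $d$ and, apart from the small `eps`, has zero mean and unit variance across the features of this position.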

The Transformer decoder is also a stack of multiple identical layers
with residual connections and layer normalizations.
Besides the two sublayers described in
the encoder, the decoder inserts
a third sublayer, known as
the encoder-decoder attention,
between these two.
In the encoder-decoder attention,
queries come from the
outputs of the decoder self-attention sublayer,
and the keys and values come
from the Transformer encoder outputs.
In the decoder self-attention,
queries, keys, and values all come from the
outputs of the previous decoder layer.
However, each position in the decoder
may attend only to decoder positions
up to and including that position.
This *masked* attention
preserves the auto-regressive property,
ensuring that each prediction depends only
on the output tokens that have already been generated.
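The following plain-Kotlin sketch shows, under simplifying assumptions (a single head, no learned projections, no masking), how encoder-decoder attention combines the two sequences: queries come from the decoder (here 2 positions) while keys and values come from the encoder outputs (here 3 positions). The attention weights therefore form a 2 × 3 matrix, and the output has one vector per decoder position. The functions `softmaxOf` and `crossAttention` and the toy values are hypothetical.

```kotlin
import kotlin.math.exp
import kotlin.math.sqrt

// Numerically stable softmax over one row of scores.
fun softmaxOf(row: DoubleArray): DoubleArray {
    val m = row.maxOrNull() ?: 0.0
    val e = row.map { exp(it - m) }
    val s = e.sum()
    return DoubleArray(row.size) { e[it] / s }
}

// Single-head scaled dot-product attention without learned projections.
// queries: (numQueries, d) from the decoder; keys and values: (numKvs, d) from the encoder.
fun crossAttention(
    queries: Array<DoubleArray>,
    keys: Array<DoubleArray>,
    values: Array<DoubleArray>
): Pair<Array<DoubleArray>, Array<DoubleArray>> {
    val d = keys[0].size.toDouble()
    // Attention weights, shape (numQueries, numKvs)
    val weights = Array(queries.size) { i ->
        softmaxOf(DoubleArray(keys.size) { j ->
            queries[i].indices.sumOf { k -> queries[i][k] * keys[j][k] } / sqrt(d)
        })
    }
    // Weighted sums of the values, shape (numQueries, valueDim)
    val output = Array(queries.size) { i ->
        DoubleArray(values[0].size) { k -> values.indices.sumOf { j -> weights[i][j] * values[j][k] } }
    }
    return Pair(weights, output)
}

val encOutputsToy = arrayOf(doubleArrayOf(1.0, 0.0), doubleArrayOf(0.0, 1.0), doubleArrayOf(1.0, 1.0))
val decStatesToy = arrayOf(doubleArrayOf(1.0, 0.0), doubleArrayOf(0.0, 1.0))
val crossResult = crossAttention(decStatesToy, encOutputsToy, encOutputsToy)
val crossWeights = crossResult.first
val crossOut = crossResult.second
println("weights: ${crossWeights.size} x ${crossWeights[0].size}, output: ${crossOut.size} x ${crossOut[0].size}")
```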


We have already described and implemented
multi-head attention based on scaled dot-products
in :numref:`sec_multihead-attention`
and positional encoding in :numref:`subsec_positional-encoding`.
In the following, we will implement
the rest of the Transformer model.

In [1]:
%use @file[../djl-pytorch.json]
%use lets-plot
@file:DependsOn("../D2J-1.0-SNAPSHOT.jar")
import ai.djl.modality.nlp.DefaultVocabulary
import ai.djl.modality.nlp.Vocabulary
import ai.djl.modality.nlp.embedding.TrainableWordEmbedding
import jp.live.ugai.d2j.timemachine.RNNModelScratch
import jp.live.ugai.d2j.timemachine.TimeMachine.trainCh8
import jp.live.ugai.d2j.timemachine.TimeMachineDataset
import jp.live.ugai.d2j.timemachine.Vocab
import jp.live.ugai.d2j.RNNModel
import jp.live.ugai.d2j.util.StopWatch
import jp.live.ugai.d2j.util.Accumulator
import jp.live.ugai.d2j.util.Training
import jp.live.ugai.d2j.util.TrainingChapter9
import jp.live.ugai.d2j.lstm.Decoder
import jp.live.ugai.d2j.lstm.Encoder
import jp.live.ugai.d2j.lstm.EncoderDecoder
import jp.live.ugai.d2j.util.NMT
import jp.live.ugai.d2j.attention.AttentionDecoder
import jp.live.ugai.d2j.PositionalEncoding
import jp.live.ugai.d2j.MaskedSoftmaxCELoss
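import ai.djl.ndarray.NDArray
import ai.djl.ndarray.NDList
import ai.djl.ndarray.NDManager
import ai.djl.ndarray.types.DataType
import ai.djl.ndarray.types.Shape
import ai.djl.nn.AbstractBlock
import ai.djl.nn.Activation
import ai.djl.nn.SequentialBlock
import ai.djl.nn.core.Linear
import ai.djl.nn.norm.BatchNorm
import ai.djl.nn.norm.Dropout
import ai.djl.nn.norm.LayerNorm
import ai.djl.training.ParameterStore
import ai.djl.util.PairList
import java.util.Locale
import kotlin.random.Random
import kotlin.collections.List
import kotlin.collections.Map
import kotlin.Pair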

In [2]:
val manager = NDManager.newBaseManager()
val ps = ParameterStore(manager, false)

## [**Positionwise Feed-Forward Networks**]
:label:`subsec_positionwise-ffn`

The positionwise feed-forward network transforms
the representation at all the sequence positions
using the same MLP.
This is why we call it *positionwise*.
In the implementation below,
the input `X` with shape
(batch size, number of time steps or sequence length in tokens,
number of hidden units or feature dimension)
will be transformed by a two-layer MLP into
an output tensor of shape
(batch size, number of time steps, `ffn_num_outputs`).
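As a DJL-independent sketch of the positionwise idea, the code below applies one two-layer MLP with ReLU to every position of every sequence, so a (batch, steps, d) input becomes (batch, steps, outputs). The weights are arbitrary toy values, biases are omitted for brevity, and the names `Mlp` and `positionwise` are hypothetical; the DJL version is the `positionWiseFFN` function in the next cell.

```kotlin
// A two-layer MLP with ReLU, biases omitted for brevity.
class Mlp(private val w1: Array<DoubleArray>, private val w2: Array<DoubleArray>) {
    private fun linear(w: Array<DoubleArray>, x: DoubleArray) =
        DoubleArray(w.size) { o -> x.indices.sumOf { i -> w[o][i] * x[i] } }

    fun forward(x: DoubleArray): DoubleArray =
        linear(w2, linear(w1, x).map { maxOf(0.0, it) }.toDoubleArray())
}

// Apply the same MLP independently at every (batch, step) position.
fun positionwise(mlp: Mlp, x: Array<Array<DoubleArray>>): Array<Array<DoubleArray>> =
    Array(x.size) { b -> Array(x[b].size) { t -> mlp.forward(x[b][t]) } }

// Input feature dimension 4, hidden size 3, output size 2 (arbitrary toy weights).
val w1 = Array(3) { o -> DoubleArray(4) { i -> 0.1 * (o + 1) * (i - 1) } }
val w2 = Array(2) { o -> DoubleArray(3) { i -> 0.5 * (o - i) + 0.2 } }
val ffnInputs = Array(2) { Array(3) { DoubleArray(4) { 1.0 } } } // shape (2, 3, 4), all ones
val ffnOutputs = positionwise(Mlp(w1, w2), ffnInputs)
println("output shape: (${ffnOutputs.size}, ${ffnOutputs[0].size}, ${ffnOutputs[0][0].size})")
```

Because every position goes through the same weights, identical input positions (here all ones) yield identical output vectors.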

In [3]:

// A two-layer MLP (Linear -> ReLU -> Linear) applied independently at every position.
// Linear acts on the last axis, so (batch, steps, d) becomes (batch, steps, ffn_num_outputs).
fun positionWiseFFN(ffn_num_hiddens: Long, ffn_num_outputs: Long): AbstractBlock {
    val net = SequentialBlock()
    net.add(Linear.builder().setUnits(ffn_num_hiddens).build())
    net.add(Activation::relu)
    net.add(Linear.builder().setUnits(ffn_num_outputs).build())
    return net
}

The following example
shows that [**the innermost dimension
of a tensor changes**] to
the number of outputs in
the positionwise feed-forward network.
Since the same MLP is applied
at every position,
identical inputs at different positions
produce identical outputs.

In [4]:
val ffn = positionWiseFFN(4, 8)
ffn.initialize(manager, DataType.FLOAT32, Shape(2,3,4))
ffn.forward(ps, NDList(manager.ones(Shape(2,3,4))), false)[0][0]

ND: (3, 8) gpu(0) float32
Check the "Development Guideline"->Debug to enable array display.


## Residual Connection and Layer Normalization

Now let's focus on the "add & norm" component in :numref:`fig_transformer`.
As we described at the beginning of this section,
this is a residual connection immediately
followed by layer normalization.
Both are key to effective deep architectures.

In :numref:`sec_batch_norm`,
we explained how batch normalization
recenters and rescales across the examples within
a minibatch.
As discussed in :numref:`subsec_layer-normalization-in-bn`,
layer normalization is the same as batch normalization
except that the former
normalizes across the feature dimension,
thus enjoying benefits of scale independence and batch size independence.
Despite its pervasive applications
in computer vision,
batch normalization
is usually empirically
less effective than layer normalization
in natural language processing
tasks, whose inputs are often
variable-length sequences.

The following code snippet
[**compares the normalization across different dimensions
by layer normalization and batch normalization**].


In [5]:
val ln = LayerNorm.builder().build()
ln.initialize(manager, DataType.FLOAT32, Shape(2,2))
val bn = BatchNorm.builder().build()
bn.initialize(manager, DataType.FLOAT32, Shape(2,2))
val X = manager.create(floatArrayOf(1f,2f,2f,3f)).reshape(Shape(2,2))
print("LayerNorm: ")
println(ln.forward(ps, NDList(X), false)[0])
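print("BatchNorm: ")
// Use training mode so that batch normalization normalizes with minibatch statistics;
// in inference mode it would use its running statistics (mean 0, variance 1 for a fresh layer)
println(bn.forward(ps, NDList(X), true)[0])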

LayerNorm: ND: (2, 2) gpu(0) float32
Check the "Development Guideline"->Debug to enable array display.

BatchNorm: ND: (2, 2) gpu(0) float32
Check the "Development Guideline"->Debug to enable array display.
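Since the cell output above shows only array summaries, the following plain-Kotlin sketch computes both normalizations by hand on the same 2 × 2 input `[[1, 2], [2, 3]]`. It ignores the learnable scale and shift, and for batch normalization it uses the minibatch statistics (training-mode behavior). The helper name `standardize` is hypothetical.

```kotlin
import kotlin.math.sqrt

// Normalize a list of numbers to zero mean and unit variance.
fun standardize(v: List<Double>, eps: Double = 1e-5): List<Double> {
    val mean = v.average()
    val variance = v.map { (it - mean) * (it - mean) }.average()
    return v.map { (it - mean) / sqrt(variance + eps) }
}

val sample = listOf(listOf(1.0, 2.0), listOf(2.0, 3.0))

// Layer norm: normalize each row (one example) across its features.
val layerNormed = sample.map { row -> standardize(row) }

// Batch norm: normalize each column (one feature) across the minibatch.
val normedCols = sample[0].indices.map { j -> standardize(sample.map { it[j] }) }
val batchNormed = sample.indices.map { i -> normedCols.map { it[i] } }

println("LayerNorm: $layerNormed") // each row is approximately [-1.0, 1.0]
println("BatchNorm: $batchNormed") // rows are approximately [-1.0, -1.0] and [1.0, 1.0]
```

Layer normalization makes each example's features comparable to each other, whereas batch normalization makes each feature comparable across the examples of the minibatch.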



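Now we can implement the `AddNorm` class
[**using a residual connection followed by layer normalization**].
Dropout is also applied for regularization.
The same cell also redefines the masked softmax, scaled dot-product attention,
and multi-head attention components from earlier sections,
because the encoder and decoder blocks below are built on them.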


In [None]:
    class AddNorm(rate: Float) : AbstractBlock() {
        val dropout = Dropout.builder().optRate(rate).build()
        val ln = LayerNorm.builder().build()

        init {
            addChildBlock("dropout", dropout)
            addChildBlock("layerNorm", ln)
        }

        override fun forwardInternal(
            ps: ParameterStore,
            inputs: NDList,
            training: Boolean,
            params: PairList<String, Any>?
        ): NDList {
            val x = inputs[0]
            val y = inputs[1]
            val dropoutResult = dropout.forward(ps, NDList(y), training, params).singletonOrThrow()
            val result = ln.forward(ps, NDList(dropoutResult.add(x)), training, params)
            return result
        }

        override fun getOutputShapes(inputShapes: Array<Shape>): Array<Shape> {
            // A single output with the same shape as the residual input X
            return arrayOf(inputShapes[0])
        }

        override fun initializeChildBlocks(manager: NDManager, dataType: DataType, vararg inputShapes: Shape) {
            dropout.initialize(manager, dataType, *inputShapes)
            ln.initialize(manager, dataType, *inputShapes)
        }
    }

    fun maskedSoftmax(_X: NDArray, _validLens: NDArray?): NDArray {
        var validLens = _validLens ?: return _X.softmax(-1)
        // Align validLens to the same device as X (e.g., GPU)
        if (validLens.device != _X.device) {
            validLens = validLens.toDevice(_X.device, false)
        }
        val shape: Shape = _X.shape
        val lastDim = shape.get(shape.dimension() - 1)
        // validLens can be 1D (batch,) or 2D (batch, steps); flatten to match X reshaping
        var lens = if (validLens.dataType == DataType.FLOAT32) validLens else validLens.toType(DataType.FLOAT32, false)
        lens = if (lens.shape.dimension() == 1) {
            lens.repeat(shape.get(1))
        } else {
            lens.reshape(-1)
        }
        val X2 = _X.reshape(Shape(-1, lastDim))
        // If lens shape does not match, fall back to unmasked softmax
        if (lens.shape.size() != X2.shape.get(0)) {
            if (lens !== validLens) {
                lens.close()
            }
            return _X.softmax(-1)
        }
        // Build mask on the last dimension and apply a large negative bias
        val arange = _X.manager.arange(lastDim.toFloat()).reshape(1, -1)
        val mask = arange.lt(lens.reshape(-1, 1))
        val maskFloat = mask.toType(_X.dataType, false)
        val invMaskFloat = mask.logicalNot().toType(_X.dataType, false)
        val maskedInput = X2.mul(maskFloat).add(invMaskFloat.mul(-1.0E6F))
        val out = maskedInput.softmax(-1).reshape(shape)
        arange.close()
        mask.close()
        maskFloat.close()
        invMaskFloat.close()
        maskedInput.close()
        if (lens !== validLens) {
            lens.close()
        }
        return out
    }

    fun transposeQkv(_X: NDArray, numHeads: Int): NDArray {
        var X = _X
        X = X.reshape(X.shape[0], X.shape[1], numHeads.toLong(), -1)
        X = X.transpose(0, 2, 1, 3)
        return X.reshape(-1, X.shape[2], X.shape[3])
    }

    fun transposeOutput(_X: NDArray, numHeads: Int): NDArray {
        var X = _X
        X = X.reshape(-1, numHeads.toLong(), X.shape[1], X.shape[2])
        X = X.transpose(0, 2, 1, 3)
        return X.reshape(X.shape[0], X.shape[1], -1)
    }

    class DotProductAttention(dropout: Float) : AbstractBlock() {
        private val dropout: Dropout
        var attentionWeights: NDArray? = null
        private var outputShapes: Array<Shape> = arrayOf<Shape>()

        init {
            this.dropout = Dropout.builder().optRate(dropout).build()
            addChildBlock("dropout", this.dropout)
        }

        override fun forwardInternal(
            ps: ParameterStore,
            inputs: NDList,
            training: Boolean,
            params: PairList<String, Any>?
        ): NDList {
            val queries = inputs[0]
            val keys = inputs[1]
            val values = inputs[2]
            val validLens = if (inputs.size > 3) inputs[3] else null
            val d = keys.shape.get(keys.shape.dimension() - 1).toDouble()
            val scores = queries.matMul(keys.swapAxes(1, 2)).div(Math.sqrt(d))
            attentionWeights = maskedSoftmax(scores, validLens)
            val result = dropout.forward(ps, NDList(attentionWeights), training, params)
            return NDList(result[0].matMul(values))
        }

        override fun getOutputShapes(inputShapes: Array<Shape>): Array<Shape> {
            return outputShapes
        }

        override fun initializeChildBlocks(manager: NDManager, dataType: DataType, vararg inputShapes: Shape) {
            val batchSize = inputShapes[0].get(0)
            val numQueries = inputShapes[0].get(1)
            val numKvs = inputShapes[1].get(1)
            val scoresShape = Shape(batchSize, numQueries, numKvs)
            dropout.initialize(manager, dataType, scoresShape)
            outputShapes = dropout.getOutputShapes(arrayOf(scoresShape))
        }
    }

    class MultiHeadAttention(numHiddens: Int, private val numHeads: Int, dropout: Float, useBias: Boolean) :
        AbstractBlock() {
        var attention: DotProductAttention
        private val W_k: Linear
        private val W_q: Linear
        private val W_v: Linear
        private val W_o: Linear

        init {
            attention = DotProductAttention(dropout)
            W_q = Linear.builder().setUnits(numHiddens.toLong()).optBias(useBias).build()
            addChildBlock("W_q", W_q)
            W_k = Linear.builder().setUnits(numHiddens.toLong()).optBias(useBias).build()
            addChildBlock("W_k", W_k)
            W_v = Linear.builder().setUnits(numHiddens.toLong()).optBias(useBias).build()
            addChildBlock("W_v", W_v)
            W_o = Linear.builder().setUnits(numHiddens.toLong()).optBias(useBias).build()
            addChildBlock("W_o", W_o)
            val dropout1 = Dropout.builder().optRate(dropout).build()
            addChildBlock("dropout", dropout1)
        }

        override fun forwardInternal(
            ps: ParameterStore,
            inputs: NDList,
            training: Boolean,
            params: PairList<String, Any>?
        ): NDList {
            var queries = inputs[0]
            var keys = inputs[1]
            var values = inputs[2]
            val validLens = if (inputs.size > 3) inputs[3] else null
            val expandedValidLens = validLens?.repeat(0, numHeads.toLong())
            queries = transposeQkv(W_q.forward(ps, NDList(queries), training, params)[0], numHeads)
            keys = transposeQkv(W_k.forward(ps, NDList(keys), training, params)[0], numHeads)
            values = transposeQkv(W_v.forward(ps, NDList(values), training, params)[0], numHeads)
            val attnInputs = if (expandedValidLens == null) {
                NDList(queries, keys, values)
            } else {
                NDList(queries, keys, values, expandedValidLens)
            }
            val output = attention.forward(ps, attnInputs, training, params)[0]
            val outputConcat = transposeOutput(output, numHeads)
            return NDList(W_o.forward(ps, NDList(outputConcat), training, params)[0])
        }

        override fun getOutputShapes(inputShapes: Array<Shape>): Array<Shape> {
            throw UnsupportedOperationException("Not implemented")
        }

        override fun initializeChildBlocks(manager: NDManager, dataType: DataType, vararg inputShapes: Shape) {
            val sub = manager.newSubManager()
            var queries = sub.zeros(inputShapes[0], dataType)
            var keys = sub.zeros(inputShapes[1], dataType)
            var values = sub.zeros(inputShapes[2], dataType)
            var validLens = sub.zeros(inputShapes[3], dataType)
            validLens = validLens.repeat(0, numHeads.toLong())

            val ps = ParameterStore(sub, false)

            W_q.initialize(manager, dataType, queries.shape)
            W_k.initialize(manager, dataType, keys.shape)
            W_v.initialize(manager, dataType, values.shape)

            queries = transposeQkv(W_q.forward(ps, NDList(queries), false)[0], numHeads)
            keys = transposeQkv(W_k.forward(ps, NDList(keys), false)[0], numHeads)
            values = transposeQkv(W_v.forward(ps, NDList(values), false)[0], numHeads)

            val list = NDList(queries, keys, values, validLens)
            attention.initialize(sub, dataType, *list.shapes)
            val output = attention.forward(ps, list, false)[0]
            val outputConcat = transposeOutput(output, numHeads)
            W_o.initialize(manager, dataType, outputConcat.shape)
            sub.close()
        }
    }
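The `maskedSoftmax` helper above replaces every score beyond a sequence's valid length with a large negative value (-1e6) before the softmax, so those positions receive (numerically) zero weight. A plain-Kotlin sketch of the same idea for a single row of scores; the name `maskedSoftmaxRow` and the toy scores are hypothetical.

```kotlin
import kotlin.math.exp

// Softmax over one row of attention scores in which every position at index
// >= validLen is first replaced by a large negative value (-1e6).
fun maskedSoftmaxRow(scores: DoubleArray, validLen: Int): DoubleArray {
    val masked = DoubleArray(scores.size) { if (it < validLen) scores[it] else -1e6 }
    val m = masked.maxOrNull() ?: 0.0 // subtract the max for numerical stability
    val e = masked.map { exp(it - m) }
    val s = e.sum()
    return DoubleArray(scores.size) { e[it] / s }
}

val scoreRow = doubleArrayOf(0.5, 1.0, 2.0, 3.0)
val maskedWeights = maskedSoftmaxRow(scoreRow, validLen = 2)
println(maskedWeights.joinToString()) // the last two weights are 0.0
```

Even though the masked scores (2.0 and 3.0) are the largest, they get no weight; the remaining weight is distributed over the first two positions only.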

The residual connection requires that
the two inputs are of the same shape
so that [**the output tensor also has the same shape after the addition operation**].


In [None]:
val addNorm = AddNorm(0.5f)
addNorm.initialize(manager, DataType.FLOAT32, Shape(1,4))
addNorm.forward(ps, NDList(manager.ones(Shape(2,3,4)), manager.ones(Shape(2,3,4))), false)[0].shapeEquals(manager.ones(Shape(2,3,4)))
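In `AddNorm`, dropout is applied to the sublayer output `Y` before it is added to `X`. The plain-Kotlin sketch below shows that step, assuming the inverted-dropout convention used by common frameworks (kept activations are scaled by 1/(1 - p) during training; dropout is the identity at inference). The function names are hypothetical, and layer normalization is left out here.

```kotlin
import kotlin.random.Random

// Inverted dropout: in training, zero each element with probability p and scale the
// survivors by 1 / (1 - p) so the expected value is unchanged; at inference, identity.
fun invertedDropout(y: DoubleArray, p: Double, training: Boolean, rng: Random): DoubleArray {
    require(p >= 0.0 && p < 1.0) { "dropout rate must be in [0, 1)" }
    if (!training) return y.copyOf()
    return DoubleArray(y.size) { if (rng.nextDouble() < p) 0.0 else y[it] / (1 - p) }
}

// The part of AddNorm before layer normalization: X + dropout(Y).
fun residualWithDropout(x: DoubleArray, y: DoubleArray, p: Double, training: Boolean, rng: Random): DoubleArray {
    require(x.size == y.size) { "the residual connection needs inputs of the same shape" }
    val dropped = invertedDropout(y, p, training, rng)
    return DoubleArray(x.size) { x[it] + dropped[it] }
}

val dropRng = Random(0)
val onesVec = DoubleArray(8) { 1.0 }
val trainSum = residualWithDropout(onesVec, onesVec, 0.5, training = true, rng = dropRng)
val evalSum = residualWithDropout(onesVec, onesVec, 0.5, training = false, rng = dropRng)
println(trainSum.joinToString()) // each entry is 1.0 (dropped) or 3.0 (kept: 1 + 1 / 0.5)
println(evalSum.joinToString())  // every entry is 2.0
```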

## Encoder
:label:`subsec_transformer-encoder`

With all the essential components to assemble
the Transformer encoder,
let's start by
implementing [**a single layer within the encoder**].
The following `TransformerEncoderBlock` class
contains two sublayers: multi-head self-attention and positionwise feed-forward networks,
where a residual connection followed by layer normalization is employed
around both sublayers.

In [None]:
    class TransformerEncoderBlock(
        numHiddens: Int,
        ffnNumHiddens: Long,
        numHeads: Int,
        dropout: Float,
        useBias: Boolean = false
    ) : AbstractBlock() {
        val attention = MultiHeadAttention(numHiddens, numHeads, dropout, useBias)
        val addnorm1 = AddNorm(dropout)
        val ffn = positionWiseFFN(ffnNumHiddens, numHiddens.toLong())
        val addnorm2 = AddNorm(dropout)
        init {
            addChildBlock("attention", attention)
            addChildBlock("addnorm1", addnorm1)
            addChildBlock("ffn", ffn)
            addChildBlock("addnorm2", addnorm2)
        }

        override fun forwardInternal(
            ps: ParameterStore,
            inputs: NDList,
            training: Boolean,
            params: PairList<String, Any>?
        ): NDList {
            val x = inputs[0]
            val validLens = inputs[1]
            val y = addnorm1.forward(ps, NDList(x, attention.forward(ps, NDList(x, x, x, validLens), training, params).singletonOrThrow()), training, params)
            val ret = addnorm2.forward(ps, NDList(y.singletonOrThrow(), ffn.forward(ps, y, training, params).singletonOrThrow()), training, params)
            return ret
        }

        override fun getOutputShapes(inputShapes: Array<Shape>): Array<Shape> {
            // An encoder block preserves the shape of its input X
            return arrayOf(inputShapes[0])
        }

        override fun initializeChildBlocks(manager: NDManager, dataType: DataType, vararg inputShapes: Shape) {
            val shapes = arrayOf(inputShapes[0], inputShapes[0], inputShapes[0], inputShapes[1])
            attention.initialize(manager, dataType, *shapes)
            addnorm1.initialize(manager, dataType, inputShapes[0])
            ffn.initialize(manager, dataType, inputShapes[0])
            addnorm2.initialize(manager, dataType, inputShapes[0])
        }
    }

As we can see,
[**no layer in the Transformer encoder
changes the shape of its input.**]
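Part of this shape bookkeeping happens inside multi-head attention: the `transposeQkv` helper splits the hidden dimension into `numHeads` chunks and moves the heads into the batch axis, and `transposeOutput` undoes the move. The plain-Kotlin sketch below reproduces that index mapping on nested arrays (the names `splitHeads` and `mergeHeads` are hypothetical) for the (2, 100, 24) input with 8 heads used above.

```kotlin
// (batch, steps, numHeads * dh) -> (batch * numHeads, steps, dh), the index mapping of transposeQkv
fun splitHeads(x: Array<Array<DoubleArray>>, numHeads: Int): Array<Array<DoubleArray>> {
    require(x[0][0].size % numHeads == 0) { "hidden size must be divisible by numHeads" }
    val dh = x[0][0].size / numHeads
    return Array(x.size * numHeads) { bh ->
        val b = bh / numHeads // original batch index
        val h = bh % numHeads // head index
        Array(x[b].size) { t -> DoubleArray(dh) { k -> x[b][t][h * dh + k] } }
    }
}

// (batch * numHeads, steps, dh) -> (batch, steps, numHeads * dh), the index mapping of transposeOutput
fun mergeHeads(y: Array<Array<DoubleArray>>, numHeads: Int): Array<Array<DoubleArray>> {
    val dh = y[0][0].size
    return Array(y.size / numHeads) { b ->
        Array(y[0].size) { t -> DoubleArray(numHeads * dh) { j -> y[b * numHeads + j / dh][t][j % dh] } }
    }
}

// Distinct values so that any misplaced element would be detected
val xHeads = Array(2) { b -> Array(100) { t -> DoubleArray(24) { j -> (b * 10000 + t * 100 + j).toDouble() } } }
val splitX = splitHeads(xHeads, 8)
val mergedX = mergeHeads(splitX, 8)
println("split: (${splitX.size}, ${splitX[0].size}, ${splitX[0][0].size})")    // (16, 100, 3)
println("merged: (${mergedX.size}, ${mergedX[0].size}, ${mergedX[0][0].size})") // (2, 100, 24)
```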


In [None]:
val X = manager.ones(Shape(2, 100, 24))
val validLens = manager.create(floatArrayOf(3f,2f))
val encoderBlock = TransformerEncoderBlock(24,48,8, 0.5f)
encoderBlock.initialize(manager, DataType.FLOAT32, X.shape, validLens.shape)
encoderBlock.forward(ps, NDList(X, validLens), false)

In the following [**Transformer encoder**] implementation,
we stack `num_blks` instances of the `TransformerEncoderBlock` class above.
Since we use the fixed positional encoding,
whose values always lie between -1 and 1,
we rescale the learnable input embeddings
by multiplying them by the square root of the embedding dimension
before summing the input embeddings and the positional encoding.
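The plain-Kotlin sketch below illustrates the magnitudes involved. It assumes the sinusoidal encoding from :numref:`subsec_positional-encoding`, $p_{i,2j} = \sin(i/10000^{2j/d})$ and $p_{i,2j+1} = \cos(i/10000^{2j/d})$, together with a hypothetical embedding row whose entries are small, as after random initialization. `sinusoidalEncoding` and `embedRow` are illustrative names, not part of the D2J library.

```kotlin
import kotlin.math.abs
import kotlin.math.cos
import kotlin.math.pow
import kotlin.math.sin
import kotlin.math.sqrt

// Sinusoidal positional encoding of position `pos` for an even model dimension d:
// p[2j] = sin(pos / 10000^(2j/d)), p[2j+1] = cos(pos / 10000^(2j/d))
fun sinusoidalEncoding(pos: Int, d: Int): DoubleArray = DoubleArray(d) { idx ->
    val j = idx / 2
    val angle = pos / 10000.0.pow(2.0 * j / d)
    if (idx % 2 == 0) sin(angle) else cos(angle)
}

val dModel = 24
// Hypothetical learned embedding row with small entries, as after random initialization
val embedRow = DoubleArray(dModel) { if (it % 2 == 0) 0.01 else -0.01 }
val pe = sinusoidalEncoding(pos = 5, d = dModel)
// Rescale the embedding by sqrt(d), then add the positional encoding
val scaledEmbed = DoubleArray(dModel) { embedRow[it] * sqrt(dModel.toDouble()) }
val encoderInput = DoubleArray(dModel) { scaledEmbed[it] + pe[it] }
println("max |PE entry| = ${pe.maxOf { abs(it) }}, |scaled embedding entry| = ${abs(scaledEmbed[0])}")
```

The positional encoding is bounded by 1 in absolute value regardless of $d$, while the scaling by $\sqrt{d}$ enlarges the embedding entries so that they are less dominated by the positional signal.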

In [None]:
    class TransformerEncoder(
        vocabSize: Int,
        val numHiddens: Int,
        ffnNumHiddens: Long,
        numHeads: Long,
        numBlks: Int,
        dropout: Float,
        useBias: Boolean = false
    ) : Encoder() {

        private val embedding: TrainableWordEmbedding
        val posEncoding = PositionalEncoding(numHiddens, dropout, 1000, manager)
        val blks = mutableListOf<TransformerEncoderBlock>()
        val attentionWeights = Array<NDArray?>(numBlks) { null }

        /* The Transformer encoder: embedding, positional encoding, and numBlks encoder blocks. */
        init {
            val list: List<String> = (0 until vocabSize).map { it.toString() }
            val vocab: Vocabulary = DefaultVocabulary(list)
            // Embedding layer
            embedding = TrainableWordEmbedding.builder()
                .optNumEmbeddings(vocabSize)
                .setEmbeddingSize(numHiddens)
                .setVocabulary(vocab)
                .build()
            addChildBlock("embedding", embedding)
            repeat(numBlks) {
                // Register each block as a child so that its parameters are
                // collected (and therefore trained) together with the encoder
                val blk = TransformerEncoderBlock(numHiddens, ffnNumHiddens, numHeads.toInt(), dropout, useBias)
                blks.add(addChildBlock("blk$it", blk))
            }
        }

        override fun forwardInternal(
            ps: ParameterStore,
            inputs: NDList,
            training: Boolean,
            params: PairList<String, Any>?
        ): NDList {
            var X = inputs[0]
            X = X.toType(DataType.INT64, false)
            val validLens = inputs[1]
            val emb = embedding.forward(ps, NDList(X), training, params).singletonOrThrow().mul(Math.sqrt(numHiddens.toDouble()))
            X = posEncoding.forward(ps, NDList(emb), training, params).singletonOrThrow()
            for (i in 0 until blks.size) {
                X = blks[i].forward(ps, NDList(X, validLens), training, params).singletonOrThrow()
                attentionWeights[i] = blks[i].attention.attention.attentionWeights
            }
            return NDList(X, validLens)
        }

        override fun initializeChildBlocks(manager: NDManager, dataType: DataType, vararg inputShapes: Shape) {
            embedding.initialize(manager, dataType, *inputShapes)
            for (blk in blks) {
                blk.initialize(manager, dataType, inputShapes[0].add(numHiddens.toLong()), inputShapes[1])
            }
        }

        override fun getOutputShapes(inputShapes: Array<Shape>): Array<Shape> {
            return arrayOf(inputShapes[0].add(numHiddens.toLong()), inputShapes[1])
        }
    }


Below we specify hyperparameters to [**create a two-layer Transformer encoder**].
The shape of the Transformer encoder output
is (batch size, number of time steps, `num_hiddens`).


In [None]:
val encoder = TransformerEncoder(200, 24, 48, 8, 2, 0.5f)
encoder.initialize(manager, DataType.FLOAT32, Shape(2,100), validLens.shape)
encoder.forward(ps, NDList(manager.ones(Shape(2, 100)), validLens), false)

## Decoder

As shown in :numref:`fig_transformer`,
[**the Transformer decoder
is composed of multiple identical layers**].
Each layer is implemented in the following
`TransformerDecoderBlock` class,
which contains three sublayers:
decoder self-attention,
encoder-decoder attention,
and positionwise feed-forward networks.
Each sublayer is wrapped in
a residual connection
followed by layer normalization.


As we described earlier in this section,
in the masked multi-head decoder self-attention
(the first sublayer),
queries, keys, and values
all come from the outputs of the previous decoder layer.
When training sequence-to-sequence models,
tokens at all the positions (time steps)
of the output sequence
are known.
However,
during prediction
the output sequence is generated token by token;
thus,
at any decoder time step
only the generated tokens
can be used in the decoder self-attention.
To preserve auto-regression in the decoder,
its masked self-attention
specifies `dec_valid_lens` so that
each query
attends only to
decoder positions
up to and including the query position.

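A plain-Kotlin sketch of what `dec_valid_lens` encodes during training: each row is [1, 2, ..., num_steps], so the query at (0-based) position t has valid length t + 1 and may attend only to key positions s with s ≤ t. Expanding these valid lengths gives exactly a lower-triangular (causal) mask. The helper names are hypothetical.

```kotlin
// dec_valid_lens for training: shape (batchSize, numSteps), each row is [1, 2, ..., numSteps]
fun decValidLensSketch(batchSize: Int, numSteps: Int): Array<IntArray> =
    Array(batchSize) { IntArray(numSteps) { t -> t + 1 } }

// Query at position t may attend to key position s iff s < validLens[t] = t + 1, i.e. s <= t
fun causalMaskFromLens(validLens: IntArray): Array<BooleanArray> =
    Array(validLens.size) { t -> BooleanArray(validLens.size) { s -> s < validLens[t] } }

val decLens = decValidLensSketch(batchSize = 2, numSteps = 4)
val causal = causalMaskFromLens(decLens[0])
causal.forEach { row -> println(row.joinToString(" ") { if (it) "1" else "0" }) }
// 1 0 0 0
// 1 1 0 0
// 1 1 1 0
// 1 1 1 1
```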
In [None]:
    class TransformerDecoderBlock(
        val numHiddens: Int,
        ffnNumHiddens: Long,
        numHeads: Long,
        dropout: Float,
        _i: Int
    ) : AbstractBlock() {
        val i = _i
        val attention1 = MultiHeadAttention(numHiddens, numHeads.toInt(), dropout, false)
        val addnorm1 = AddNorm(dropout)
        val attention2 = MultiHeadAttention(numHiddens, numHeads.toInt(), dropout, false)
        val addnorm2 = AddNorm(dropout)
        val ffn = positionWiseFFN(ffnNumHiddens, numHiddens.toLong())
        val addnorm3 = AddNorm(dropout)

        init {
            addChildBlock("attention1", attention1)
            addChildBlock("addnorm1", addnorm1)
            addChildBlock("attention2", attention2)
            addChildBlock("addnorm2", addnorm2)
            addChildBlock("ffn", ffn)
            addChildBlock("addnorm3", addnorm3)
        }

        override fun initializeChildBlocks(manager: NDManager, dataType: DataType, vararg inputShapes: Shape) {
            val decShape = inputShapes[0]
            val decValidLensShape = Shape(decShape.get(0), decShape.get(1))
            val encOutputsShape = if (inputShapes.size > 1 && inputShapes[1].dimension() == 3) {
                inputShapes[1]
            } else {
                decShape
            }
            val encValidLensShape = if (inputShapes.size > 2) {
                inputShapes[2]
            } else {
                Shape(decShape.get(0))
            }
            attention1.initialize(manager, dataType, decShape, decShape, decShape, decValidLensShape)
            addnorm1.initialize(manager, dataType, decShape)
            attention2.initialize(manager, dataType, decShape, encOutputsShape, encOutputsShape, encValidLensShape)
            addnorm2.initialize(manager, dataType, decShape)
            ffn.initialize(manager, dataType, decShape)
            addnorm3.initialize(manager, dataType, decShape)
        }
        override fun forwardInternal(
            ps: ParameterStore,
            inputs: NDList,
            training: Boolean,
            params: PairList<String, Any>?
        ): NDList {
            val input0 = inputs[0]
            val encOutputs = inputs[1]
            val encValidLens = inputs[2]
            // During training, all the tokens of any output sequence are processed
            // at the same time, so no cache is passed in and the keys/values are
            // just `input0`. When decoding token by token during prediction,
            // inputs[3] caches the representations of the decoded output at this
            // block up to the previous time step; we append the current step to it.
            val cache = if (inputs.size > 3) inputs[3] else null
            val keyValues = cache?.concat(input0, 1) ?: input0

            val decValidLens: NDArray? = if (training) {
                val batchSize = input0.shape[0]
                val numSteps = input0.shape[1]
                // Shape of dec_valid_lens: (batch_size, num_steps), where every
                // row is [1, 2, ..., num_steps]
                input0.manager.arange(1f, (numSteps + 1).toFloat()).reshape(1, numSteps).repeat(0, batchSize)
            } else {
                // The cache holds only already generated tokens, so no mask is needed
                null
            }
            // Masked self-attention, then add & norm
            val X2 = attention1.forward(ps, NDList(input0, keyValues, keyValues, decValidLens), training)
            val Y = addnorm1.forward(ps, NDList(input0, X2.head()), training)
            // Encoder-decoder attention. Shape of encOutputs:
            // (batch_size, num_steps, num_hiddens)
            val Y2 = attention2.forward(ps, NDList(Y.head(), encOutputs, encOutputs, encValidLens), training)
            val Z = addnorm2.forward(ps, NDList(Y.head(), Y2.head()), training)
            // Positionwise FFN, then add & norm
            val out = addnorm3.forward(ps, NDList(Z.head(), ffn.forward(ps, Z, training).head()), training).head()
            // Pass the encoder state through and return the updated key/value cache
            return NDList(out, encOutputs, encValidLens, keyValues)
        }

        override fun getOutputShapes(inputShapes: Array<Shape>): Array<Shape> {
            return arrayOf<Shape>()
        }
    }
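During prediction, `TransformerDecoderBlock` receives a cache (its fourth input) holding the representations of the previously decoded tokens at that block and concatenates the current step along the time axis (`cache.concat(input0, 1)`). The plain-Kotlin sketch below mirrors that bookkeeping for a single sequence, using toy step representations; the function names are hypothetical.

```kotlin
// Append the current step's representation to a block's key/value cache,
// mirroring `cache.concat(input0, 1)` for a single sequence (one row per step).
fun appendStep(cache: List<DoubleArray>?, current: DoubleArray): List<DoubleArray> =
    (cache ?: emptyList()) + listOf(current)

// Simulate token-by-token decoding and record how many key/value rows
// the current query can attend to at each step.
fun cacheSizesDuringDecoding(numSteps: Int, dim: Int): List<Int> {
    var cache: List<DoubleArray>? = null // nothing cached before the first step
    val sizes = mutableListOf<Int>()
    for (step in 0 until numSteps) {
        val current = DoubleArray(dim) { step.toDouble() } // toy representation of this step
        val updated = appendStep(cache, current)
        cache = updated
        sizes.add(updated.size)
    }
    return sizes
}

val cacheSizes = cacheSizesDuringDecoding(numSteps = 3, dim = 4)
println(cacheSizes) // [1, 2, 3]
```

Since the cache only ever contains tokens generated so far, the query at each step cannot see the future, which is why no `dec_valid_lens` mask is needed at prediction time.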


To facilitate scaled dot-product operations
in the encoder-decoder attention
and addition operations in the residual connections,
[**the feature dimension (`num_hiddens`) of the decoder is
the same as that of the encoder.**]


In [None]:
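val decoderBlk = TransformerDecoderBlock(24, 48, 8, 0.5f, 0)
val X = manager.ones(Shape(2, 100, 24))
// Self-contained setup for this demo cell
val validLens = manager.create(floatArrayOf(3f, 2f))
val encoderBlock = TransformerEncoderBlock(24, 48, 8, 0.5f)
encoderBlock.initialize(manager, DataType.FLOAT32, X.shape, validLens.shape)
// Decoder block inputs: decoder input, encoder outputs, encoder valid lengths
decoderBlk.initialize(manager, DataType.FLOAT32, X.shape, X.shape, validLens.shape)
// The encoder block takes (X, validLens); its output serves as the encoder outputs
val encOutputs = encoderBlock.forward(ps, NDList(X, validLens), false).head()
// No key/value cache yet (null), as at the start of decoding
val decOut = decoderBlk.forward(ps, NDList(X, encOutputs, validLens, null), false).head()
// The decoder block preserves the (batch size, number of steps, num_hiddens) shape
println(decOut.shape == X.shape)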


Now we [**construct the entire Transformer decoder**]
composed of `num_blks` instances of `TransformerDecoderBlock`.
In the end,
a fully connected layer computes the prediction
for all the `vocab_size` possible output tokens.
Both of the decoder self-attention weights
and the encoder-decoder attention weights
are stored for later visualization.
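The final fully connected layer maps each decoder position's `num_hiddens`-dimensional representation to `vocab_size` scores. The plain-Kotlin sketch below uses a toy weight matrix (hypothetical values, bias omitted) and then, purely for illustration, picks the highest-scoring token at each position as a greedy decoder would; `projectToVocab` is a hypothetical name.

```kotlin
// Project each position's hidden vector (size numHiddens) to vocabSize logits (bias omitted).
fun projectToVocab(hidden: Array<DoubleArray>, w: Array<DoubleArray>): Array<DoubleArray> =
    Array(hidden.size) { t ->
        DoubleArray(w.size) { v -> hidden[t].indices.sumOf { k -> w[v][k] * hidden[t][k] } }
    }

// Toy output weights: vocabSize = 5 rows, numHiddens = 3 columns (hypothetical values)
val wOut = arrayOf(
    doubleArrayOf(1.0, 0.0, 0.0),
    doubleArrayOf(0.0, 1.0, 0.0),
    doubleArrayOf(0.0, 0.0, 1.0),
    doubleArrayOf(0.5, 0.5, 0.0),
    doubleArrayOf(0.0, 0.5, 0.5)
)
// Decoder outputs for 2 time steps
val decHidden = arrayOf(doubleArrayOf(0.1, 0.9, 0.0), doubleArrayOf(0.0, 0.2, 0.7))
val logits = projectToVocab(decHidden, wOut) // shape (numSteps, vocabSize) = (2, 5)
// Greedy choice: the highest-scoring token index at each position
val greedy = logits.map { row -> row.indices.maxByOrNull { row[it] }!! }
println(greedy) // [1, 2]
```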


In [None]:
    class TransformerDecoder(
        vocabSize: Int,
        val numHiddens: Int,
        ffnNumHiddens: Int,
        numHeads: Int,
        val numBlks: Int,
        dropout: Float
    ) : AttentionDecoder() {
        val list: List<String> = (0 until vocabSize).map { it.toString() }
        val vocab: Vocabulary = DefaultVocabulary(list)
        val embedding = TrainableWordEmbedding.builder()
            .optNumEmbeddings(vocabSize)
            .setEmbeddingSize(numHiddens)
            .setVocabulary(vocab)
            .build()
        val posEncoding = PositionalEncoding(numHiddens, dropout, 1000, manager)
        val blks = mutableListOf<TransformerDecoderBlock>()

        //            val attentionWeights = Array<NDArray?>(numBlks) { null }
        val linear = Linear.builder().setUnits(vocabSize.toLong()).build()
        var attentionWeightsArr2: MutableList<NDArray?>? = null
        var attentionWeightsArr1: MutableList<NDArray?>? = null

        init {
            addChildBlock("embedding", embedding)
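            repeat(numBlks) {
                // Register each block as a child so that its parameters are
                // collected (and therefore trained) together with the decoder
                val blk = TransformerDecoderBlock(
                    numHiddens,
                    ffnNumHiddens.toLong(),
                    numHeads.toLong(),
                    dropout,
                    it
                )
                blks.add(addChildBlock("blk$it", blk))
            }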
        }

        override fun initState(input: NDList): NDList {
            val encOutputs = input[0]
            val encValidLens = input[1]
            val state = NDList(encOutputs, encValidLens)
            repeat(numBlks) {
                state.add(null)
            }
            return state
        }

        override fun forwardInternal(
            ps: ParameterStore,
            inputs: NDList,
            training: Boolean,
            params: PairList<String, Any>?
        ): NDList {
            var X = inputs[0]
            X = X.toType(DataType.INT64, false)
            val state = inputs.subNDList(1)
            val encOutputs = state[0]
            val encValidLens = state[1]
            val pos = posEncoding.forward(
                ps,
                NDList(embedding.forward(ps, NDList(X), training, params).head().mul(Math.sqrt(numHiddens.toDouble()))),
                training,
                params
            )
            X = pos.head()
            attentionWeightsArr1 = if (!training) mutableListOf() else null
            attentionWeightsArr2 = if (!training) mutableListOf() else null
            val cacheStart = 2
            val newState = NDList(encOutputs, encValidLens)
            for (i in 0 until blks.size) {
                val cache = if (state.size > cacheStart + i) state[cacheStart + i] else null
                val blkOut = blks[i].forward(ps, NDList(X, encOutputs, encValidLens, cache), training, params)
                X = blkOut.head()
                val newCache = if (blkOut.size > 3) blkOut[3] else null
                newState.add(newCache)
                if (!training) {
                    attentionWeightsArr1!!.add(blks[i].attention1.attention.attentionWeights)
                    attentionWeightsArr2!!.add(blks[i].attention2.attention.attentionWeights)
                }
            }
            val ret = linear.forward(ps, NDList(X), training, params)
            return NDList(ret.head()).addAll(newState)
        }

        override fun initializeChildBlocks(manager: NDManager, dataType: DataType, vararg inputShapes: Shape) {
            embedding.initialize(manager, dataType, inputShapes[0])
            val decShape = inputShapes[0].add(numHiddens.toLong())
            posEncoding.initialize(manager, dataType, decShape)
            val encOutputsShape = inputShapes[1]
            val encValidLensShape = if (inputShapes.size > 2) {
                inputShapes[2]
            } else {
                Shape(encOutputsShape.get(0))
            }
            for (blk in blks) {
                blk.initialize(manager, dataType, decShape, encOutputsShape, encValidLensShape)
            }
        }
    }
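The decoder state created by `initState` holds the encoder outputs, the encoder valid lengths, and one cache slot per block (initially `null`), which `forwardInternal` reads starting at index `cacheStart = 2`. In the original d2l design, each block's cache accumulates the representations of all previously decoded positions, so that at prediction time the query of a new token can attend to every earlier position. Whether `TransformerDecoderBlock` in this port stores exactly that is not shown here, so the sketch below is a hypothetical, simplified illustration: `BlockCache` and `attend` are our own names, and the cached vectors serve directly as keys and values without learned projections.

```kotlin
import kotlin.math.exp
import kotlin.math.sqrt

// Hypothetical per-block cache: the representations of all positions decoded so far.
class BlockCache {
    private val seen = mutableListOf<DoubleArray>()

    // Append the current position and return every position it may attend to.
    fun update(current: DoubleArray): List<DoubleArray> {
        seen.add(current)
        return seen.toList()
    }
}

// Scaled dot-product attention of one query over the cached vectors, which serve
// as both keys and values here (no learned projections, for brevity).
fun attend(query: DoubleArray, keyValues: List<DoubleArray>): DoubleArray {
    val scale = sqrt(query.size.toDouble())
    val scores = keyValues.map { k -> query.indices.sumOf { j -> query[j] * k[j] } / scale }
    val maxScore = scores.maxOrNull() ?: 0.0 // subtract the max for a stable softmax
    val weights = scores.map { exp(it - maxScore) }
    val total = weights.sum()
    return DoubleArray(query.size) { j -> keyValues.indices.sumOf { i -> weights[i] / total * keyValues[i][j] } }
}

// One cache per decoder block (numBlks = 2); only block 0 is exercised here.
val caches = List(2) { BlockCache() }
val outputs = mutableListOf<DoubleArray>()
for (step in 0 until 3) {
    val x = doubleArrayOf(step.toDouble(), 1.0) // representation of the new token
    val keyValues = caches[0].update(x)
    outputs.add(attend(x, keyValues))
    println("step $step attends to ${keyValues.size} position(s)")
}
```

Each step computes attention only for the newest position; earlier positions are reused from the cache instead of being recomputed.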


## [**Training**]

Let's instantiate an encoder-decoder model
by following the Transformer architecture.
Here we specify that
both the Transformer encoder and the Transformer decoder
have 2 layers using 4-head attention.
Similar to :numref:`sec_seq2seq_training`,
we train the Transformer model
for sequence-to-sequence learning on the English-French machine translation dataset.

In [None]:
    var trainedNet: EncoderDecoder? = null
    var trainedSrcVocab: Vocab? = null
    var trainedTgtVocab: Vocab? = null
    var trainedNumSteps: Int = 0

    fun train() {

        val numHiddens = 256
        val numBlks = 2
        val ffnNumHiddens = 64
        val numHeads = 4
        val batchSize = 2
        // Note: 5 epochs is a minimal demo; increase for better quality.
        val numEpochs = 5
        val numSteps = 35

        val dropout = 0.2f
        val lr = 0.001f
        val device = manager.device

        val dataNMT = NMT.loadDataNMT(batchSize, numSteps, 600)
        val dataset: ArrayDataset = dataNMT.first
        val srcVocab: Vocab = dataNMT.second.first
        val tgtVocab: Vocab = dataNMT.second.second

        val encoder = TransformerEncoder(srcVocab.length(), numHiddens, ffnNumHiddens.toLong(), numHeads.toLong(), numBlks, dropout)
        encoder.initialize(manager, DataType.FLOAT32, Shape(batchSize.toLong(), numSteps.toLong()), Shape(batchSize.toLong()))

        val decoder = TransformerDecoder(tgtVocab.length(), numHiddens, ffnNumHiddens, numHeads, numBlks, dropout)
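        // Decoder input shapes: target tokens, encoder outputs, encoder valid lengths
        decoder.initialize(
            manager, DataType.FLOAT32,
            Shape(batchSize.toLong(), numSteps.toLong()),
            Shape(batchSize.toLong(), numSteps.toLong(), numHiddens.toLong()),
            Shape(batchSize.toLong())
        )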

        val net = EncoderDecoder(encoder, decoder)
        trainedNet = net
        trainedSrcVocab = srcVocab
        trainedTgtVocab = tgtVocab
        trainedNumSteps = numSteps
        fun trainSeq2Seq(
            net: EncoderDecoder,
            dataset: ArrayDataset,
            lr: Float,
            numEpochs: Int,
            tgtVocab: Vocab,
            device: Device
        ) {
            val loss: Loss = MaskedSoftmaxCELoss()
            val lrt: Tracker = Tracker.fixed(lr)
            val adam: Optimizer = Optimizer.adam().optLearningRateTracker(lrt).build()
            val config: DefaultTrainingConfig = DefaultTrainingConfig(loss)
                .optOptimizer(adam) // Adam optimizer with a fixed learning rate
                .optInitializer(XavierInitializer(), "")
            val model: ai.djl.Model = ai.djl.Model.newInstance("")
            model.block = net
            val trainer: Trainer = model.newTrainer(config)
            var watch: StopWatch
            var metric: Accumulator
            var lossValue = 0.0
            var speed = 0.0
            for (epoch in 1..numEpochs) {
                watch = StopWatch()
                metric = Accumulator(2) // Sum of training loss, no. of tokens
                // Iterate over dataset
                for (batch in dataset.getData(manager)) {
                    val X: NDArray = batch.data.get(0)
                    val lenX: NDArray = batch.data.get(1)
                    val Y: NDArray = batch.labels.get(0)
                    val lenY: NDArray = batch.labels.get(1)
                    val bos: NDArray = manager
                        .full(Shape(Y.shape[0]), tgtVocab.getIdx("<bos>"))
                        .reshape(-1, 1)
                    val decInput: NDArray = NDArrays.concat(
                        NDList(bos, Y.get(NDIndex(":, :-1"))),
                        1
                    ) // Teacher forcing
                    Engine.getInstance().newGradientCollector().use { gc ->
                        val yHat: NDArray = net.forward(
                            ParameterStore(manager, false),
                            NDList(X, decInput, lenX),
                            true
                        )
                            .get(0)
                        val l = loss.evaluate(NDList(Y, lenY), NDList(yHat))
                        gc.backward(l)
                        metric.add(floatArrayOf(l.sum().getFloat(), lenY.sum().getLong().toFloat()))
                    }
                    TrainingChapter9.gradClipping(net, 1, manager)
                    // Update parameters
                    trainer.step()
                }
                lossValue = metric.get(0).toDouble() / metric.get(1)
                speed = metric.get(1) / watch.stop()
                // epoch already counts from 1
                println("$epoch : $lossValue")
            }
            println("loss: %.3f, %.1f tokens/sec on %s".format(lossValue, speed, device.toString()))
        }
        trainSeq2Seq(net, dataset, lr, numEpochs, tgtVocab, device)
    }
    this.train()
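The training loop above uses teacher forcing: the decoder input is the ground-truth target sequence shifted right by one position, with `<bos>` prepended and the last token dropped (the `NDArrays.concat` call on `bos` and `Y.get(NDIndex(":, :-1"))`). A minimal list-based sketch of the same transformation follows; the function name `teacherForcingInputs` and the token indices are hypothetical.

```kotlin
// Teacher forcing: prepend <bos> and drop the last target token, so that the
// decoder input at position t is the ground-truth target token at position t - 1.
fun teacherForcingInputs(targets: List<List<Int>>, bosIdx: Int): List<List<Int>> =
    targets.map { seq -> listOf(bosIdx) + seq.dropLast(1) }

// Hypothetical token indices: <pad> = 0, <bos> = 1, <eos> = 2, words >= 3
val targets = listOf(
    listOf(5, 6, 2, 0), // two words, <eos>, one padding token
    listOf(7, 8, 9, 2)  // three words, <eos>
)
val decInputs = teacherForcingInputs(targets, bosIdx = 1)
println(decInputs) // [[1, 5, 6, 2], [1, 7, 8, 9]]
```

Because the inputs have the same length as the targets, all target positions can be predicted in one parallel pass during training.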


After training,
we use the Transformer model
to [**translate a few English sentences**] into French and compute their BLEU scores.
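For reference, the `bleu` function in the next code cell computes the d2l variant of BLEU. Here $p_n$ is the precision of $n$-grams (the number of matched $n$-grams in the prediction divided by the number of $n$-grams in the prediction, where each label $n$-gram can be matched at most as often as it occurs) and $k$ is the longest $n$-gram length considered:

$$\exp\left(\min\left(0, 1 - \frac{\mathrm{len}_{\textrm{label}}}{\mathrm{len}_{\textrm{pred}}}\right)\right) \prod_{n=1}^k p_n^{1/2^n}.$$

As a worked example with $k = 2$, take the label `je suis chez moi .` (5 tokens) and a hypothetical prediction `je suis chez .` (4 tokens). All 4 unigrams match, so $p_1 = 1$; of the bigrams `je suis`, `suis chez`, and `chez .`, the first two match, so $p_2 = 2/3$. Since the prediction is shorter than the label, the brevity penalty applies:

$$\exp\left(1 - \tfrac{5}{4}\right) \cdot 1^{1/2} \cdot \left(\tfrac{2}{3}\right)^{1/4} \approx 0.779 \times 0.904 \approx 0.704.$$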


In [None]:
        fun predictSeq2Seq(
            net: EncoderDecoder,
            srcSentence: String,
            srcVocab: Vocab,
            tgtVocab: Vocab,
            numSteps: Int,
            saveAttentionWeights: Boolean
        ): Pair<String, List<NDArray?>> {
            val srcTokens = srcVocab.getIdxs(srcSentence.lowercase(Locale.getDefault()).split(" ")) + listOf(srcVocab.getIdx("<eos>"))
            val encValidLen = manager.create(srcTokens.size).reshape(1)
            val truncateSrcTokens = NMT.truncatePad(srcTokens, numSteps, srcVocab.getIdx("<pad>"))
            // Add the batch axis
            val encX = manager.create(truncateSrcTokens.toIntArray()).expandDims(0)
            val encOutputs = net.encoder.forward(ParameterStore(manager, false), NDList(encX, encValidLen), false)
            var decState = net.decoder.initState(encOutputs)
            // Start decoding from a single <bos> token; add the batch axis: shape (1, 1)
            var decX = manager.create(floatArrayOf(tgtVocab.getIdx("<bos>").toFloat())).expandDims(0)
            val outputSeq: MutableList<Int> = mutableListOf()
            val attentionWeightSeq: MutableList<NDArray?> = mutableListOf()
            for (i in 0 until numSteps) {
                val output = net.decoder.forward(
                    ParameterStore(manager, false),
                    NDList(decX).addAll(decState),
                    false
                )
                val Y = output[0]
                decState = output.subNDList(1)
                // We use the token with the highest prediction likelihood as the input
                // of the decoder at the next time step
                decX = Y.argMax(2)
                val pred = decX.squeeze(0).getLong().toInt()
                // Save attention weights (to be covered later)
                if (saveAttentionWeights) {
                    attentionWeightSeq.add(net.decoder.attentionWeights)
                }
                // Once the end-of-sequence token is predicted, the generation of the
                // output sequence is complete
                if (pred == tgtVocab.getIdx("<eos>")) {
                    break
                }
                outputSeq.add(pred)
            }
            val outputString: String = tgtVocab.toTokens(outputSeq).joinToString(separator = " ")
            return Pair(outputString, attentionWeightSeq.toList())
        }

        /* Compute the BLEU. */
        fun bleu(predSeq: String, labelSeq: String, k: Int): Double {
            val predTokens = predSeq.split(" ")
            val labelTokens = labelSeq.split(" ")
            val lenPred = predTokens.size
            val lenLabel = labelTokens.size
            // Convert to Double first: Int / Int would truncate the length ratio
            var score = Math.exp(Math.min(0.0, 1.0 - lenLabel.toDouble() / lenPred))
            for (n in 1 until k + 1) {
                var numMatches = 0
                val labelSubs = mutableMapOf<String, Int>()
                for (i in 0 until lenLabel - n + 1) {
                    val key = labelTokens.subList(i, i + n).joinToString(separator = " ")
                    labelSubs.put(key, labelSubs.getOrDefault(key, 0) + 1)
                }
                for (i in 0 until lenPred - n + 1) {
                    val key = predTokens.subList(i, i + n).joinToString(separator = " ")
                    if (labelSubs.getOrDefault(key, 0) > 0) {
                        numMatches += 1
                        labelSubs.put(key, labelSubs.getOrDefault(key, 0) - 1)
                    }
                }
                score *= Math.pow(numMatches.toDouble() / (lenPred - n + 1).toDouble(), Math.pow(0.5, n.toDouble()))
            }
            return score
        }

        val net = trainedNet ?: error("Run the training cell first")
        val srcVocab = trainedSrcVocab ?: error("Run the training cell first")
        val tgtVocab = trainedTgtVocab ?: error("Run the training cell first")
        val numSteps = trainedNumSteps

        val engs = arrayOf("go .", "i lost .", "he's calm .", "i'm home .")
        val fras = arrayOf("va !", "j'ai perdu .", "il est calme .", "je suis chez moi .")
        for (i in engs.indices) {
            val pair = predictSeq2Seq(net, engs[i], srcVocab, tgtVocab, numSteps, false)
            val translation: String = pair.first
            println("%s => %s, bleu %.3f".format(engs[i], translation, bleu(translation, fras[i], 2)))
        }
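`predictSeq2Seq` performs greedy decoding: at each step it feeds the most likely token back into the decoder and stops at `<eos>` or after `numSteps` tokens. The sketch below isolates that control flow in plain Kotlin. The `step` function is a hypothetical stand-in for the decoder; to stay short it conditions only on the previous token, whereas the Transformer decoder also sees the encoder outputs and, through its state, all earlier tokens.

```kotlin
// Greedy decoding loop. `step` is a hypothetical stand-in for the decoder:
// it maps the previously emitted token to scores over the target vocabulary.
fun greedyDecode(
    step: (Int) -> DoubleArray,
    bosIdx: Int,
    eosIdx: Int,
    numSteps: Int
): List<Int> {
    val output = mutableListOf<Int>()
    var prev = bosIdx
    repeat(numSteps) {
        val scores = step(prev)
        // Index of the highest score (argmax)
        val pred = scores.indices.maxByOrNull { scores[it] } ?: return output
        if (pred == eosIdx) return output // stop at end-of-sequence
        output.add(pred)
        prev = pred // feed the prediction back as the next input
    }
    return output
}

// Toy vocabulary (hypothetical): 0 = <pad>, 1 = <bos>, 2 = <eos>, 3 = "va", 4 = "!"
val table = mapOf(
    1 to doubleArrayOf(0.0, 0.0, 0.1, 0.8, 0.1),  // after <bos>, "va" scores highest
    3 to doubleArrayOf(0.0, 0.0, 0.2, 0.1, 0.7),  // after "va", "!" scores highest
    4 to doubleArrayOf(0.0, 0.0, 0.9, 0.05, 0.05) // after "!", <eos> scores highest
)
val tokens = greedyDecode({ prev -> table.getValue(prev) }, bosIdx = 1, eosIdx = 2, numSteps = 10)
println(tokens) // [3, 4]
```

Greedy search is cheap but may miss a higher-scoring sequence overall, since each choice is made without looking ahead.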

Let's [**visualize the Transformer attention weights**] when translating the last English sentence into French.
The shape of the encoder self-attention weights
is (number of encoder layers, number of attention heads, `num_steps` or number of queries, `num_steps` or number of key-value pairs).


In the encoder self-attention,
both queries and keys come from the same input sequence.
Since padding tokens do not carry meaning,
with the specified valid length of the input sequence,
no query attends to positions of padding tokens.
In the following,
two layers of multi-head attention weights
are presented row by row.
Each head independently attends
based on separate representation subspaces of queries, keys, and values.
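Masking by valid length can be illustrated on rows of attention scores. As a plain-Kotlin sketch (the name `maskedSoftmax` is ours), we set the weights of padding positions to zero directly; d2l's reference implementation instead replaces the masked scores with a very large negative value before the softmax, which makes those weights effectively zero.

```kotlin
import kotlin.math.exp

// Masked softmax over one query's row of attention scores. Key positions at or
// beyond validLen (padding) receive weight exactly 0; the remaining weights sum to 1.
fun maskedSoftmax(scores: DoubleArray, validLen: Int): DoubleArray {
    require(validLen in 1..scores.size) { "validLen must be in 1..${scores.size}" }
    // Subtract the largest valid score for numerical stability
    val maxValid = (0 until validLen).maxOf { scores[it] }
    val exps = DoubleArray(scores.size) { i -> if (i < validLen) exp(scores[i] - maxValid) else 0.0 }
    val total = exps.sum()
    return DoubleArray(scores.size) { exps[it] / total }
}

// Every query of a sequence with 2 real tokens followed by 2 padding tokens
// uses the same valid length, so no query attends to the padding positions.
val scores = arrayOf(
    doubleArrayOf(1.0, 1.0, 5.0, 9.0),
    doubleArrayOf(2.0, 0.0, 3.0, 3.0)
)
val weights = scores.map { maskedSoftmax(it, validLen = 2) }
weights.forEach { println(it.toList()) }
```

Note that the large scores at the padding positions (5.0 and 9.0 in the first row) have no influence on the result.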

[**To visualize both the decoder self-attention weights and the encoder-decoder attention weights,
we need more data manipulations.**]
For example,
we fill the masked attention weights with zero.
Note that
the decoder self-attention weights
and the encoder-decoder attention weights
both have the same queries:
the beginning-of-sequence token followed by
the output tokens and possibly
end-of-sequence tokens.

Due to the auto-regressive property of the decoder self-attention,
no query attends to key-value pairs after the query position.
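During training all target positions are processed in parallel, so this property has to be enforced with an explicit mask that hides later positions. One way to express it is a lower-triangular boolean matrix; equivalently, the query at position $i$ (counting from 1) gets valid length $i$. The original d2l decoder block uses the latter form (`dec_valid_lens` equal to `1, ..., num_steps` for each sequence). The functions below are a hypothetical plain-Kotlin illustration of both forms, not part of the DJL port.

```kotlin
// Causal (auto-regressive) mask for decoder self-attention during training:
// the query at position i may attend only to key positions 0..i.
fun causalMask(numSteps: Int): Array<BooleanArray> =
    Array(numSteps) { i -> BooleanArray(numSteps) { j -> j <= i } }

// Equivalent per-query valid lengths 1, 2, ..., numSteps
fun causalValidLens(numSteps: Int): IntArray = IntArray(numSteps) { it + 1 }

val mask = causalMask(4)
mask.forEach { row -> println(row.joinToString(" ") { if (it) "1" else "0" }) }
println(causalValidLens(4).toList()) // [1, 2, 3, 4]
```

The number of visible keys in row $i$ equals that row's valid length, which is why the two forms are interchangeable.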

Similar to the case in the encoder self-attention,
via the specified valid length of the input sequence,
[**no query from the output sequence
attends to those padding tokens from the input sequence.**]

Although the Transformer architecture
was originally proposed for sequence-to-sequence learning,
as we will discover later in the book,
either the Transformer encoder
or the Transformer decoder
is often used on its own
for different deep learning tasks.


## Summary

The Transformer is an instance of the encoder-decoder architecture, 
though either the encoder or the decoder can be used individually in practice.
In the Transformer architecture, multi-head self-attention is used 
for representing the input sequence and the output sequence, 
though the decoder has to preserve the auto-regressive property via a masked version.
Both the residual connections and the layer normalization in the Transformer
are important for training a very deep model.
The positionwise feed-forward network in the Transformer model 
transforms the representation at all the sequence positions using the same MLP.

## Exercises

1. Train a deeper Transformer in the experiments. How does it affect the training speed and the translation performance?
1. Is it a good idea to replace scaled dot-product attention with additive attention in the Transformer? Why?
1. For language modeling, should we use the Transformer encoder, decoder, or both? How would you design this method?
1. What challenges do Transformers face when input sequences are very long? Why?
1. How can we improve the computational and memory efficiency of Transformers? Hint: you may refer to the survey paper by :citet:`Tay.Dehghani.Bahri.ea.2020`.
