# Bahdanau Attention

:label:`sec_seq2seq_attention`

When we encountered machine translation in :numref:`sec_seq2seq`,
we designed an encoder-decoder architecture 
based on two RNNs for sequence-to-sequence learning.
Specifically, the RNN encoder transforms a variable-length sequence
into a fixed-shape context variable.
Then, the RNN decoder generates the output (target) sequence token by token
based on the generated tokens and the context variable.
However, while not all input (source) tokens
are relevant when decoding a particular target token,
the *same* context variable
that encodes the entire input sequence
is still used at each decoding step.

In a separate but related challenge 
of handwriting generation for a given text sequence,
:citet:`Graves.2013` designed a differentiable attention model
to align text characters with the much longer pen trace,
where the alignment moves only in one direction.
Inspired by the idea of learning to align,
:citet:`Bahdanau.Cho.Bengio.2014` proposed a differentiable attention model
without the severe unidirectional alignment limitation.
When predicting a token,
if not all the input tokens are relevant,
the model aligns (or attends)
only to parts of the input sequence 
that are deemed relevant to the current prediction.
This is achieved by constructing the context variable
via an attention mechanism. 

## Model

To describe the Bahdanau-style attention mechanism
for the RNN encoder-decoder below,
we follow the same notation as in :numref:`sec_seq2seq`.
The new attention-based model follows 
the sequence-to-sequence architecture of :numref:`sec_seq2seq`,
but with the context variable $\mathbf{c}$ in :eqref:`eq_seq2seq_s_t`
replaced by $\mathbf{c}_{t'}$
at any decoding time step $t'$.
Suppose that there are $T$ tokens in the input sequence.
Then the context variable at the decoding time step $t'$
is the output of attention pooling:

$$\mathbf{c}_{t'} = \sum_{t=1}^T \alpha(\mathbf{s}_{t' - 1}, \mathbf{h}_t) \mathbf{h}_t,$$

where the decoder hidden state
$\mathbf{s}_{t' - 1}$ at time step $t' - 1$ is the query,
and the encoder hidden states $\mathbf{h}_t$
are both the keys and values,
and the attention weight $\alpha$
is computed as in
:eqref:`eq_attn-scoring-alpha`
using the additive attention scoring function
defined by :eqref:`eq_additive-attn`.
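
To make the attention pooling above concrete, here is a minimal sketch in plain Kotlin for a single query. It assumes the usual additive scoring form $a(\mathbf{q}, \mathbf{k}) = \mathbf{w}_v^\top \tanh(\mathbf{W}_q \mathbf{q} + \mathbf{W}_k \mathbf{k})$ followed by a softmax. The function names and the toy parameter values are illustrative assumptions, not part of the `AdditiveAttention` block used later.

```kotlin
import kotlin.math.exp
import kotlin.math.tanh

// Additive score a(q, k) = w_v^T tanh(W_q q + W_k k) for one query/key pair.
fun additiveScore(
    q: DoubleArray,
    k: DoubleArray,
    wQ: Array<DoubleArray>,
    wK: Array<DoubleArray>,
    wV: DoubleArray
): Double {
    var score = 0.0
    for (h in wV.indices) {
        var pre = 0.0
        for (j in q.indices) pre += wQ[h][j] * q[j]
        for (j in k.indices) pre += wK[h][j] * k[j]
        score += wV[h] * tanh(pre)
    }
    return score
}

// Numerically stable softmax over a vector of scores.
fun softmax(scores: DoubleArray): DoubleArray {
    val m = scores.maxOrNull() ?: 0.0
    val exps = scores.map { exp(it - m) }
    val z = exps.sum()
    return exps.map { it / z }.toDoubleArray()
}

// Returns (alpha, c) where c = sum_t alpha_t * h_t.
// The encoder states serve as both keys and values.
fun attentionContext(
    query: DoubleArray,
    encStates: List<DoubleArray>,
    wQ: Array<DoubleArray>,
    wK: Array<DoubleArray>,
    wV: DoubleArray
): Pair<DoubleArray, DoubleArray> {
    val scores = encStates.map { additiveScore(query, it, wQ, wK, wV) }.toDoubleArray()
    val alpha = softmax(scores)
    val context = DoubleArray(encStates[0].size)
    for (t in encStates.indices) {
        for (d in context.indices) context[d] += alpha[t] * encStates[t][d]
    }
    return Pair(alpha, context)
}

fun main() {
    // Hand-picked toy parameters (assumptions for illustration only).
    val identity = arrayOf(doubleArrayOf(1.0, 0.0), doubleArrayOf(0.0, 1.0))
    val wV = doubleArrayOf(1.0, 1.0)
    val query = doubleArrayOf(0.5, -0.5) // plays the role of s_{t'-1}
    val encStates = listOf(              // play the role of h_1, ..., h_T
        doubleArrayOf(1.0, 0.0),
        doubleArrayOf(0.0, 1.0),
        doubleArrayOf(1.0, 1.0)
    )
    val (alpha, context) = attentionContext(query, encStates, identity, identity, wV)
    println("attention weights: ${alpha.joinToString()}")
    println("context variable:  ${context.joinToString()}")
}
```

The weights always sum to one, so the context variable is a convex combination of the encoder hidden states.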

This RNN encoder-decoder architecture
augmented with Bahdanau attention
is depicted in :numref:`fig_s2s_attention_details`.

![Layers in an RNN encoder-decoder model with Bahdanau attention.](https://github.com/d2l-ai/d2l-en/raw/master/img/seq2seq-attention-details.svg)
:label:`fig_s2s_attention_details`


In [1]:
%use @file[../djl.json]
%use lets-plot
@file:DependsOn("../D2J-1.0-SNAPSHOT.jar")
import jp.live.ugai.d2j.timemachine.Vocab
import jp.live.ugai.d2j.RNNModel
import jp.live.ugai.d2j.util.StopWatch
import jp.live.ugai.d2j.util.Accumulator
import jp.live.ugai.d2j.util.Training
import jp.live.ugai.d2j.util.TrainingChapter9
import jp.live.ugai.d2j.lstm.Decoder
import jp.live.ugai.d2j.lstm.Encoder
import jp.live.ugai.d2j.lstm.EncoderDecoder
import jp.live.ugai.d2j.util.NMT
import jp.live.ugai.d2j.attention.AdditiveAttention
import jp.live.ugai.d2j.Seq2SeqEncoder
import jp.live.ugai.d2j.MaskedSoftmaxCELoss
import java.util.Locale
import kotlin.random.Random
import kotlin.collections.List
import kotlin.collections.Map
import kotlin.Pair

In [2]:
import ai.djl.modality.nlp.DefaultVocabulary
import ai.djl.modality.nlp.Vocabulary
import ai.djl.modality.nlp.embedding.TrainableWordEmbedding

In [3]:
System.setProperty("org.slf4j.simpleLogger.showThreadName", "false")
System.setProperty("org.slf4j.simpleLogger.showLogName", "true")
System.setProperty("org.slf4j.simpleLogger.log.ai.djl.pytorch", "WARN")
System.setProperty("org.slf4j.simpleLogger.log.ai.djl.mxnet", "ERROR")
System.setProperty("org.slf4j.simpleLogger.log.ai.djl.ndarray.index", "ERROR")
System.setProperty("org.slf4j.simpleLogger.log.ai.djl.tensorflow", "WARN")

val manager = NDManager.newBaseManager()
val ps = ParameterStore(manager, false)

## Defining the Decoder with Attention

To implement the RNN encoder-decoder 
with Bahdanau attention,
we only need to redefine the decoder.
To visualize the learned attention weights more conveniently,
the following `AttentionDecoder` class
defines [**the base interface for
decoders with attention mechanisms**].


In [4]:
abstract class AttentionDecoder : Decoder() {
    var attentionWeightArr: MutableList<Pair<FloatArray,Shape>> = mutableListOf()
    abstract override fun initState(encOutputs: NDList): NDList
    override fun getOutputShapes(inputShapes: Array<Shape>): Array<Shape> {
        throw UnsupportedOperationException("Not implemented")
    }
}

Now let's [**implement
the RNN decoder with Bahdanau attention**]
in the following `Seq2SeqAttentionDecoder` class.
The state of the decoder is initialized with
(i) the encoder final-layer hidden states at all the time steps 
(as keys and values of the attention);
(ii) the encoder all-layer hidden state at the final time step 
(to initialize the hidden state of the decoder);
and (iii) the encoder valid length 
(to exclude the padding tokens in attention pooling).
At each decoding time step,
the decoder final-layer hidden state at the previous time step 
is used as the query of the attention.
The attention output and the input embedding
are then concatenated
to form the input of the RNN decoder.
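
Item (iii) relies on a masked softmax: positions beyond the valid length should receive no attention weight. The sketch below shows one simple way to do this in plain Kotlin, by normalizing over the valid prefix only and leaving the padded positions at zero. The function name `maskedSoftmax` is an assumption for illustration; library implementations often get the same effect by replacing masked scores with a large negative value before an ordinary softmax.

```kotlin
import kotlin.math.exp

// Softmax restricted to the first `validLen` scores.
// Padded positions (index >= validLen) keep weight 0.
fun maskedSoftmax(scores: DoubleArray, validLen: Int): DoubleArray {
    require(validLen in 1..scores.size) { "validLen must be in 1..${scores.size}" }
    // Subtract the maximum valid score for numerical stability.
    var maxScore = Double.NEGATIVE_INFINITY
    for (t in 0 until validLen) maxScore = maxOf(maxScore, scores[t])
    val weights = DoubleArray(scores.size)
    var z = 0.0
    for (t in 0 until validLen) {
        weights[t] = exp(scores[t] - maxScore)
        z += weights[t]
    }
    for (t in 0 until validLen) weights[t] /= z
    return weights
}

fun main() {
    // Four source positions, of which only the first two are real tokens.
    val scores = doubleArrayOf(1.0, 1.0, 5.0, 5.0)
    println(maskedSoftmax(scores, 2).joinToString())
}
```

Even though the padded positions carry the largest raw scores in this toy input, they contribute nothing to the weighted sum.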

In [5]:
class Seq2SeqAttentionDecoder(
    vocabSize: Long,
    private val embedSize: Int,
    private val numHiddens: Int,
    private val numLayers: Int,
    dropout: Float = 0f
) : AttentionDecoder() {
    val attention = AdditiveAttention(numHiddens, dropout)
    val embedding: TrainableWordEmbedding
    val rnn = GRU.builder()
        .setNumLayers(numLayers)
        .setStateSize(numHiddens)
        .optReturnState(true)
        .optBatchFirst(false)
        .optDropRate(dropout)
        .build()
    val linear = Linear.builder().setUnits(vocabSize).build()

    init {
        val list: List<String> = (0 until vocabSize).map { it.toString() }
        val vocab: Vocabulary = DefaultVocabulary(list)
        // Embedding layer
        embedding = TrainableWordEmbedding.builder()
            .optNumEmbeddings(vocabSize.toInt())
            .setEmbeddingSize(embedSize)
            .setVocabulary(vocab)
            .build()
        addChildBlock("embedding", embedding)
        addChildBlock("rnn", rnn)
        addChildBlock("attention", attention)
        addChildBlock("linear", linear)
    }

    override fun initState(encOutputs: NDList): NDList {
        val outputs = encOutputs[0]
        val hiddenState = encOutputs[1]
        val encValidLens = if (encOutputs.size >= 3) encOutputs[2] else manager.create(0)
        return NDList(outputs.swapAxes(0, 1), hiddenState, encValidLens)
    }

    override fun initializeChildBlocks(manager: NDManager, dataType: DataType, vararg inputShapes: Shape) {
        embedding.initialize(manager, dataType, inputShapes[0])
        attention.initialize(manager, DataType.FLOAT32, inputShapes[1], inputShapes[1])
        rnn.initialize(manager, DataType.FLOAT32, Shape(1, 4, (numHiddens + embedSize).toLong()))
        linear.initialize(manager, DataType.FLOAT32, Shape(4, numHiddens.toLong()))
    }

    override fun forwardInternal(
        ps: ParameterStore,
        inputs: NDList,
        training: Boolean,
        params: PairList<String, Any>?
    ): NDList {
        var outputs: NDArray? = null
        val encOutputs = inputs[1]
        var hiddenState: NDArray = inputs[2]
        val encValidLens = inputs[3]
        var input = inputs[0]
        // Shape of encOutputs: (batchSize, numSteps, numHiddens)
        // Shape of hiddenState: (numLayers, batchSize, numHiddens)
        // Shape of X after swapping axes: (numSteps, batchSize, embedSize)
        val X = embedding.forward(ps, NDList(input), training, params)[0].swapAxes(0, 1)
        attentionWeightArr = mutableListOf()
        for (x in 0 until X.size(0)) {
            val query = hiddenState[-1].expandDims(1)
            val context = attention.forward(ps, NDList(query, encOutputs, encOutputs, encValidLens), training, params)
            val xArray = context[0].concat(X[x].expandDims(1), -1)
            val out = rnn.forward(ps, NDList(xArray.swapAxes(0, 1), hiddenState), training, params)
            hiddenState = out[1]
            outputs = if (outputs == null) out[0] else outputs.concat(out[0])
//            println(attention.attentionWeights?.shape)
//            println(attentionWeights)
            if (attention.attentionWeights != null) {
                val att = attention.attentionWeights!!
                attentionWeightArr.add(Pair(att.toFloatArray(), att.shape))
            }
        }
        val ret = linear.forward(ps, NDList(outputs), training)
        return NDList(ret[0].swapAxes(0, 1), encOutputs, hiddenState, encValidLens)
    }
}

In the following, we [**test the implemented
decoder**] with Bahdanau attention
using a minibatch of 4 sequence inputs
of 7 time steps.

In [6]:
    val vocabSize = 10
    val embedSize = 8
    val numHiddens = 16
    val numLayers = 2
    val batchSize = 4
    val numSteps = 7
    val encoder = Seq2SeqEncoder(vocabSize, embedSize, numHiddens, numLayers, 0f)
    encoder.initialize(manager, DataType.FLOAT32, Shape(batchSize.toLong(), batchSize.toLong()))
    val decoder = Seq2SeqAttentionDecoder(vocabSize.toLong(), embedSize, numHiddens, numLayers)
    decoder.initialize(
        manager,
        DataType.FLOAT32,
        Shape(batchSize.toLong(), numHiddens.toLong()),
        Shape(batchSize.toLong(), batchSize.toLong(), numHiddens.toLong()),
        Shape(1, batchSize.toLong(), (numHiddens + embedSize).toLong()),
        Shape(4, numHiddens.toLong())
    )
    val X = manager.zeros(Shape(batchSize.toLong(), numSteps.toLong()))
    val output = encoder.forward(ps, NDList(X), false)
    output.add(manager.create(0))
    val state = decoder.initState(output)
    println("State: $state")
    val ff = decoder.forward(ps, NDList(X).addAll(state), false)
    println(ff)
    println(ff[0].shape) // (batch_size, num_steps, vocab_size) (4, 7, 10)
    println(ff[1].shape) // (batch_size, num_steps, num_hiddens) (4, 7, 16)
    println(ff[2][0].shape) // (batch_size, num_hiddens) (4, 16)


State: NDList size: 3
0 : (4, 7, 16) float32
1 : (2, 4, 16) float32
2 : () int32

NDList size: 4
0 : (4, 7, 10) float32
1 : (4, 7, 16) float32
2 : (2, 4, 16) float32
3 : () int32

(4, 7, 10)
(4, 7, 16)
(4, 16)


## [**Training**]

Similar to :numref:`sec_seq2seq_training`,
here we specify hyperparameters,
instantiate
an encoder and a decoder with Bahdanau attention,
and train this model for machine translation.
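
The training loop in this section feeds the decoder with teacher forcing: the decoder input is the target sequence shifted right by one position, with `<bos>` prepended. The following plain-Kotlin sketch shows that shift on a single sequence. The function name and the token indices are hypothetical and chosen only for illustration.

```kotlin
// Build the teacher-forcing decoder input for one target sequence:
// prepend <bos> and drop the last target token, so the length is unchanged.
fun teacherForcingInput(target: IntArray, bosIdx: Int): IntArray {
    require(target.isNotEmpty()) { "target must contain at least one token" }
    return intArrayOf(bosIdx) + target.copyOfRange(0, target.size - 1)
}

fun main() {
    // Hypothetical indices: 2 = <bos>, 3 = <eos>, 1 = <pad>.
    val target = intArrayOf(5, 6, 3, 1)
    val decInput = teacherForcingInput(target, bosIdx = 2)
    println(decInput.joinToString()) // decoder sees <bos>, 5, 6, <eos>
}
```

At step $t'$ the decoder thus receives the ground-truth token from step $t' - 1$ rather than its own previous prediction.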


In [7]:
fun trainSeq2Seq(
        net: EncoderDecoder,
        dataset: ArrayDataset,
        lr: Float,
        numEpochs: Int,
        tgtVocab: Vocab,
        device: Device
    ) {
        val loss: Loss = MaskedSoftmaxCELoss()
        val lrt: Tracker = Tracker.fixed(lr)
        val adam: Optimizer = Optimizer.adam().optLearningRateTracker(lrt).build()
        val config: DefaultTrainingConfig = DefaultTrainingConfig(loss)
            .optOptimizer(adam) // Optimizer (Adam with a fixed learning rate)
            .optInitializer(XavierInitializer(), "")
        val model: Model = Model.newInstance("")
        model.block = net
        val trainer: Trainer = model.newTrainer(config)
//    val animator = Animator()
        var watch: StopWatch
        var metric: Accumulator
        var lossValue = 0.0
        var speed = 0.0
        for (epoch in 1..numEpochs) {
            watch = StopWatch()
            metric = Accumulator(2) // Sum of training loss, no. of tokens
            // Iterate over dataset
            for (batch in dataset.getData(manager)) {
                val X: NDArray = batch.data.get(0)
                val lenX: NDArray = batch.data.get(1)
                val Y: NDArray = batch.labels.get(0)
                val lenY: NDArray = batch.labels.get(1)
                val bos: NDArray = manager
                    .full(Shape(Y.shape[0]), tgtVocab.getIdx("<bos>"))
                    .reshape(-1, 1)
                val decInput: NDArray = NDArrays.concat(
                    NDList(bos, Y.get(NDIndex(":, :-1"))),
                    1
                ) // Teacher forcing
                Engine.getInstance().newGradientCollector().use { gc ->
                    val yHat: NDArray = net.forward(
                        ParameterStore(manager, false),
                        NDList(X, decInput, lenX),
                        true
                    )
                        .get(0)
                    val l = loss.evaluate(NDList(Y, lenY), NDList(yHat))
                    gc.backward(l)
                    metric.add(floatArrayOf(l.sum().getFloat(), lenY.sum().getLong().toFloat()))
                }
                TrainingChapter9.gradClipping(net, 1, manager)
                // Update parameters
                trainer.step()
            }
            lossValue = metric.get(0).toDouble() / metric.get(1)
            speed = metric.get(1) / watch.stop()
            if (epoch % 10 == 0) {
//            animator.add(epoch, lossValue.toFloat(), "loss")
//            animator.show()
                println("$epoch : $lossValue")
            }
        }
        println("loss: %.3f, %.1f tokens/sec on %s%n".format(lossValue, speed, device.toString()))
    }

In [12]:
    val embedSize = 32
    val numHiddens = 32
    val numLayers = 2
    val batchSize = 64
    val numSteps = 10
    val numEpochs = Integer.getInteger("MAX_EPOCH", 300)

    val dropout = 0.2f
    val lr = 0.001f
    val device = manager.device

    val dataNMT = NMT.loadDataNMT(batchSize, numSteps, 600)
    val dataset: ArrayDataset = dataNMT.first
    val srcVocab: Vocab = dataNMT.second.first
    val tgtVocab: Vocab = dataNMT.second.second

    val encoder = Seq2SeqEncoder(srcVocab.length(), embedSize, numHiddens, numLayers, dropout)
    val decoder = Seq2SeqAttentionDecoder(tgtVocab.length().toLong(), embedSize, numHiddens, numLayers)

    val net = EncoderDecoder(encoder, decoder)
    trainSeq2Seq(net, dataset, lr, numEpochs, tgtVocab, device)

10 : 0.13338356663607462
20 : 0.08894217234810052
30 : 0.07230630909913359
40 : 0.061881931969989055
50 : 0.053327786364174415
60 : 0.04650633600648582
70 : 0.04107880095054389
80 : 0.03648166288685864
90 : 0.03246828309666777
100 : 0.02908570801165572
110 : 0.026272295967231157
120 : 0.023878186859046788
130 : 0.021688097248580313
140 : 0.019819132787230822
150 : 0.01831368600604385
160 : 0.017004677368308182
170 : 0.015935419676389246
180 : 0.01492106299942659
190 : 0.01420922037280014
200 : 0.013304892415748906
210 : 0.012707202465503246
220 : 0.012035523001390477
230 : 0.011537431779221164
240 : 0.011089678497494605
250 : 0.010782665176121375
260 : 0.010251179214351433
270 : 0.00999254469137781
280 : 0.009777232728948372
290 : 0.009408628785765016
300 : 0.00927302462995545
loss: 0.009, 2065.8 tokens/sec on cpu()



After the model is trained,
we use it to [**translate a few English sentences**]
into French and compute their BLEU scores.


In [13]:
    fun predictSeq2Seq(
        net: EncoderDecoder,
        srcSentence: String,
        srcVocab: Vocab,
        tgtVocab: Vocab,
        numSteps: Int,
        device: Device,
        saveAttentionWeights: Boolean
    ): Pair<String, List<List<Pair<FloatArray, Shape>>>> {
        val srcTokens = srcVocab.getIdxs(srcSentence.lowercase(Locale.getDefault()).split(" ")) + listOf(srcVocab.getIdx("<eos>"))
        val encValidLen = manager.create(srcTokens.size)
        val truncateSrcTokens = NMT.truncatePad(srcTokens, numSteps, srcVocab.getIdx("<pad>"))
        // Add the batch axis
        val encX = manager.create(truncateSrcTokens.toIntArray()).expandDims(0)
        val encOutputs = net.encoder.forward(ParameterStore(manager, false), NDList(encX, encValidLen), false)
        var decState = net.decoder.initState(encOutputs.addAll(NDList(encValidLen)))
        // Add the batch axis
        var decX = manager.create(floatArrayOf(tgtVocab.getIdx("<bos>").toFloat())).expandDims(0)
        val outputSeq: MutableList<Int> = mutableListOf()
        val attentionWeightSeq: MutableList<List<Pair<FloatArray, Shape>>> = mutableListOf()
        for (i in 0 until numSteps) {
            val output = net.decoder.forward(
                ParameterStore(manager, false),
                NDList(decX).addAll(decState),
                false
            )
            val Y = output[0]
            decState = output.subNDList(1)
            // We use the token with the highest prediction likelihood as the input
            // of the decoder at the next time step
            decX = Y.argMax(2)
            val pred = decX.squeeze(0).getLong().toInt()
            // Save attention weights (to be covered later)
            if (saveAttentionWeights) {
                attentionWeightSeq.add((net.decoder as AttentionDecoder).attentionWeightArr)
            }
            // Once the end-of-sequence token is predicted, the generation of the
            // output sequence is complete
            if (pred == tgtVocab.getIdx("<eos>")) {
                break
            }
            outputSeq.add(pred)
        }
        val outputString: String = tgtVocab.toTokens(outputSeq).joinToString(separator = " ")
        return Pair(outputString, attentionWeightSeq)
    }

    /* Compute the BLEU. */
    fun bleu(predSeq: String, labelSeq: String, k: Int): Double {
        val predTokens = predSeq.split(" ")
        val labelTokens = labelSeq.split(" ")
        val lenPred = predTokens.size
        val lenLabel = labelTokens.size
        // Convert to Double first: Int / Int would truncate the length ratio
        var score = Math.exp(Math.min(0.0, 1.0 - lenLabel.toDouble() / lenPred))
        for (n in 1 until k + 1) {
            var numMatches = 0
            val labelSubs = mutableMapOf<String, Int>()
            for (i in 0 until lenLabel - n + 1) {
                val key = labelTokens.subList(i, i + n).joinToString(separator = " ")
                labelSubs.put(key, labelSubs.getOrDefault(key, 0) + 1)
            }
            for (i in 0 until lenPred - n + 1) {
                // val key =predTokens.subList(i, i + n).joinToString(" ")
                val key = predTokens.subList(i, i + n).joinToString(separator = " ")
                if (labelSubs.getOrDefault(key, 0) > 0) {
                    numMatches += 1
                    labelSubs.put(key, labelSubs.getOrDefault(key, 0) - 1)
                }
            }
            score *= Math.pow(numMatches.toDouble() / (lenPred - n + 1).toDouble(), Math.pow(0.5, n.toDouble()))
        }
        return score
    }


In [14]:
    val engs = arrayOf("go .", "i lost .", "he's calm .", "i'm home .")
    val fras = arrayOf("va !", "j'ai perdu .", "il est calme .", "je suis chez moi .")
    for (i in engs.indices) {
        val pair = predictSeq2Seq(net, engs[i], srcVocab, tgtVocab, numSteps, device, false)
        val translation: String = pair.first
        val attentionWeightSeq = pair.second
        println("%s => %s, bleu %.3f".format(engs[i], translation, bleu(translation, fras[i], 2)))
    }
    println(numSteps)

go . => va !, bleu 1.000
i lost . => j'ai perdu ., bleu 1.000
he's calm . => il est bon ., bleu 0.658
i'm home . => je suis chez moi ., bleu 1.000
10
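
As a check on the third score, the `bleu` function above computes, with $p_n$ denoting the $n$-gram precision and $k$ the longest $n$-gram considered,

$$\exp\left(\min\left(0, 1 - \frac{\textrm{len}_{\textrm{label}}}{\textrm{len}_{\textrm{pred}}}\right)\right) \prod_{n=1}^k p_n^{1/2^n}.$$

For the prediction "il est bon ." against the label "il est calme ." with $k = 2$, both sequences have 4 tokens, so the brevity factor is $\exp(0) = 1$. Three of the four predicted unigrams match ($p_1 = 3/4$), and one of the three predicted bigrams ("il est") matches ($p_2 = 1/3$), giving

$$1 \cdot \left(\tfrac{3}{4}\right)^{1/2} \left(\tfrac{1}{3}\right)^{1/4} \approx 0.866 \times 0.760 \approx 0.658,$$

which agrees with the printed value.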


By [**visualizing the attention weights**]
when translating the last English sentence,
we can see that each query assigns non-uniform weights
over key-value pairs.
This shows that at each decoding step,
different parts of the input sequence
are selectively aggregated in the attention pooling.

In [15]:
    val pair = predictSeq2Seq(net, engs.last(), srcVocab, tgtVocab, numSteps, device, true)
    val attentions = pair.second
    println(attentions.size)
    val matrix = manager.create(attentions[0].last().first).reshape(attentions[0].last().second)
        .concat(manager.create(attentions[1].last().first).reshape(attentions[1].last().second))
        .concat(manager.create(attentions[2].last().first).reshape(attentions[2].last().second))
        .concat(manager.create(attentions[3].last().first).reshape(attentions[3].last().second))
        .concat(manager.create(attentions[4].last().first).reshape(attentions[4].last().second)).reshape(5,10)
    println(matrix)
    val seriesX = mutableListOf<Long>()
    val seriesY = mutableListOf<Long>()
    val seriesW = mutableListOf<Float>()
    for(i in 0 until matrix.shape[0]) {
        val row = matrix.get(i)
        for(j in 0 until row.shape[0]) {
            seriesX.add(j)
            seriesY.add(i)
            seriesW.add(row.get(j).getFloat())
        }
    }
    val data = mapOf( "x" to seriesX, "y" to seriesY)
    var plot = letsPlot(data)
    plot += geomBin2D(drop=false, binWidth = Pair(1,1), position = positionIdentity){x="x"; y = "y"; weight = seriesW }
    plot += scaleFillGradient(low="blue", high="red")
//plot += scaleFillContinuous("red", "green")
    plot + ggsize(700, 200)

6
ND: (5, 10) cpu() float32
[[ 2.51013692e-03,  6.97515905e-04,  1.21010663e-02,  1.32435374e-02,  4.99029718e-02,  1.83202356e-01,  3.18012834e-01,  1.79773659e-01,  1.28771856e-01,  1.11784115e-01],
 [ 6.01784652e-03,  4.91461810e-03,  6.99551627e-02,  1.70189843e-01,  3.29204738e-01,  2.39405081e-01,  1.14986718e-01,  3.14249061e-02,  1.79261118e-02,  1.59750320e-02],
 [ 5.98380400e-04,  2.75306171e-04,  1.01401545e-02,  1.35161597e-02,  9.31207538e-02,  3.19551945e-01,  3.15946758e-01,  1.13761492e-01,  7.20553100e-02,  6.10337295e-02],
 [ 9.37038916e-04,  2.40883994e-04,  3.43956985e-03,  5.08535001e-03,  4.13465388e-02,  1.76077679e-01,  3.98106396e-01,  1.69010714e-01,  1.08022988e-01,  9.77328867e-02],
 [ 1.04698956e-04,  7.04538324e-05,  1.00894133e-03,  1.19440223e-03,  1.82277411e-02,  1.99527755e-01,  4.10029739e-01,  1.86897486e-01,  1.01734743e-01,  8.12040120e-02],
]
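
One quick way to read the printed matrix is to check, for each decoding step, that the weights sum to one and to find the source position that receives the largest weight. The plain-Kotlin sketch below does this for the first two rows, copied from the output above and rounded to four decimals. The helper name `argMax` is an assumption for illustration.

```kotlin
// Index of the largest weight in one row of attention weights.
fun argMax(row: DoubleArray): Int {
    var best = 0
    for (i in row.indices) {
        if (row[i] > row[best]) best = i
    }
    return best
}

fun main() {
    // First two rows of the printed (5, 10) matrix, rounded to four decimals.
    val rows = listOf(
        doubleArrayOf(0.0025, 0.0007, 0.0121, 0.0132, 0.0499, 0.1832, 0.3180, 0.1798, 0.1288, 0.1118),
        doubleArrayOf(0.0060, 0.0049, 0.0700, 0.1702, 0.3292, 0.2394, 0.1150, 0.0314, 0.0179, 0.0160)
    )
    for ((step, row) in rows.withIndex()) {
        println("step %d: sum = %.3f, largest weight at position %d".format(step, row.sum(), argMax(row)))
    }
}
```

A uniform distribution over the 10 positions would assign 0.1 to each, so weights above 0.3 at a single position indicate a clearly non-uniform allocation.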



## Summary

When predicting a token, if not all the input tokens are relevant, the RNN encoder-decoder with Bahdanau attention selectively aggregates different parts of the input sequence. This is achieved by treating the context variable as an output of additive attention pooling.
In the RNN encoder-decoder, Bahdanau attention treats the decoder hidden state at the previous time step as the query, and the encoder hidden states at all the time steps as both the keys and values.

## Exercises

1. Replace GRU with LSTM in the experiment.
1. Modify the experiment to replace the additive attention scoring function with the scaled dot-product. How does it influence the training efficiency?