# Multi-Head Attention
:label:`sec_multihead-attention`


In practice,
given the same set of queries, keys, and values,
we may want our model to
combine knowledge from
different behaviors of the same attention mechanism,
such as capturing dependencies of various ranges (e.g., shorter-range vs. longer-range)
within a sequence.
Thus, 
it may be beneficial 
to allow our attention mechanism
to jointly use different representation subspaces
of queries, keys, and values.



To this end,
instead of performing a single attention pooling,
queries, keys, and values
can be transformed
with $h$ independently learned linear projections.
Then these $h$ projected queries, keys, and values
are fed into attention pooling in parallel.
In the end,
$h$ attention pooling outputs
are concatenated and 
transformed with another learned linear projection
to produce the final output.
This design
is called *multi-head attention*,
where each of the $h$ attention pooling outputs
is a *head* :cite:`Vaswani.Shazeer.Parmar.ea.2017`.
:numref:`fig_multi-head-attention`
describes multi-head attention,
where fully connected layers
perform the learnable linear transformations.

![Multi-head attention, where multiple heads are concatenated then linearly transformed.](https://raw.githubusercontent.com/d2l-ai/d2l-en/master/img/multi-head-attention.svg)
:label:`fig_multi-head-attention`




## Model

Before providing the implementation of multi-head attention,
let us formalize this model mathematically.
Given a query $\mathbf{q} \in \mathbb{R}^{d_q}$,
a key $\mathbf{k} \in \mathbb{R}^{d_k}$,
and a value $\mathbf{v} \in \mathbb{R}^{d_v}$,
each attention head $\mathbf{h}_i$  ($i = 1, \ldots, h$)
is computed as

$$\mathbf{h}_i = f(\mathbf W_i^{(q)}\mathbf q, \mathbf W_i^{(k)}\mathbf k,\mathbf W_i^{(v)}\mathbf v) \in \mathbb R^{p_v},$$

where
$\mathbf W_i^{(q)}\in\mathbb R^{p_q\times d_q}$,
$\mathbf W_i^{(k)}\in\mathbb R^{p_k\times d_k}$,
and $\mathbf W_i^{(v)}\in\mathbb R^{p_v\times d_v}$
are learnable parameters,
and
$f$ is attention pooling,
such as
additive attention or scaled dot-product attention
in :numref:`sec_attention-scoring-functions`.
The multi-head attention output
is obtained by applying another linear transformation,
with learnable parameters
$\mathbf W_o\in\mathbb R^{p_o\times h p_v}$,
to the concatenation of the $h$ heads:

$$\mathbf W_o \begin{bmatrix}\mathbf h_1\\\vdots\\\mathbf h_h\end{bmatrix} \in \mathbb{R}^{p_o}.$$

Based on this design,
each head may attend to different parts of the input,
so the model can express more sophisticated functions
than the simple weighted average computed by a single attention pooling.
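To make the two formulas above concrete, here is a minimal plain-Kotlin sketch for a single query attending over a set of key-value pairs. It assumes that $f$ is scaled dot-product attention and ignores batching and masking; the per-head projections are applied to every key and value. The names `Matrix`, `matVec`, `scaledDotProduct`, and `multiHead` are hypothetical helpers for illustration only and are not part of DJL or the D2J library.

```kotlin
import kotlin.math.exp
import kotlin.math.sqrt

// Hypothetical alias for this sketch: a matrix stored as an array of rows
typealias Matrix = Array<DoubleArray>

// Computes w * x for w of shape (rows x cols) and x of length cols
fun matVec(w: Matrix, x: DoubleArray): DoubleArray =
    DoubleArray(w.size) { r -> x.indices.sumOf { c -> w[r][c] * x[c] } }

// Scaled dot-product attention pooling of one query over n key-value pairs
fun scaledDotProduct(q: DoubleArray, keys: List<DoubleArray>, values: List<DoubleArray>): DoubleArray {
    val scale = sqrt(q.size.toDouble())
    val scores = keys.map { k -> q.indices.sumOf { j -> q[j] * k[j] } / scale }
    val maxScore = scores.maxOrNull() ?: 0.0 // subtract the max before exp for numerical stability
    val exps = scores.map { exp(it - maxScore) }
    val total = exps.sum()
    val pooled = DoubleArray(values.first().size)
    for ((i, v) in values.withIndex()) {
        val weight = exps[i] / total // softmax attention weight of pair i
        for (j in v.indices) pooled[j] += weight * v[j]
    }
    return pooled
}

// Multi-head attention for one query: head i uses its own projections wq[i], wk[i], wv[i];
// the concatenated heads (length h * p_v) are mapped by wo (shape p_o x (h * p_v))
fun multiHead(
    q: DoubleArray,
    keys: List<DoubleArray>,
    values: List<DoubleArray>,
    wq: List<Matrix>,
    wk: List<Matrix>,
    wv: List<Matrix>,
    wo: Matrix
): DoubleArray {
    val heads = wq.indices.map { i ->
        scaledDotProduct(
            matVec(wq[i], q),
            keys.map { matVec(wk[i], it) },
            values.map { matVec(wv[i], it) }
        )
    }
    val concat = heads.flatMap { it.asList() }.toDoubleArray() // [h_1; ...; h_h]
    return matVec(wo, concat)
}
```

Each head sees only its own $p_v$-dimensional projection. Only $\mathbf W_o$ mixes information across heads.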


```kotlin
%use @file[../djl.json]
@file:DependsOn("../D2J-1.0-SNAPSHOT.jar")
import ai.djl.ndarray.NDArray
import ai.djl.ndarray.NDList
import ai.djl.ndarray.NDManager
import ai.djl.ndarray.types.DataType
import ai.djl.ndarray.types.Shape
import ai.djl.nn.AbstractBlock
import ai.djl.nn.core.Linear
import ai.djl.training.ParameterStore
import ai.djl.util.PairList
import jp.live.ugai.d2j.attention.DotProductAttention
```

```kotlin
val manager = NDManager.newBaseManager()
```

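To allow for parallel computation of multiple heads,
the `MultiHeadAttention` class below uses the two transposition functions defined next.
`transposeQkv` splits the last dimension into `numHeads` heads
and merges the head axis into the batch axis,
so that a single batched call to the attention pooling computes all heads at once;
`transposeOutput` reverses this operation.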

```kotlin
fun transposeQkv(X: NDArray, numHeads: Int): NDArray {
    // Shape of input `X`:
    // (`batchSize`, no. of queries or key-value pairs, `numHiddens`)
    // Shape after reshape:
    // (`batchSize`, no. of queries or key-value pairs, `numHeads`, `numHiddens` / `numHeads`)
    val split = X.reshape(X.shape[0], X.shape[1], numHeads.toLong(), -1)

    // Shape after transpose:
    // (`batchSize`, `numHeads`, no. of queries or key-value pairs, `numHiddens` / `numHeads`)
    val perHead = split.transpose(0, 2, 1, 3)

    // Shape of output:
    // (`batchSize` * `numHeads`, no. of queries or key-value pairs, `numHiddens` / `numHeads`)
    return perHead.reshape(-1, perHead.shape[2], perHead.shape[3])
}

fun transposeOutput(X: NDArray, numHeads: Int): NDArray {
    // Shape of input `X`:
    // (`batchSize` * `numHeads`, no. of queries, `numHiddens` / `numHeads`)
    // Shape after reshape: (`batchSize`, `numHeads`, no. of queries, `numHiddens` / `numHeads`)
    val split = X.reshape(-1, numHeads.toLong(), X.shape[1], X.shape[2])

    // Shape after transpose: (`batchSize`, no. of queries, `numHeads`, `numHiddens` / `numHeads`)
    val perQuery = split.transpose(0, 2, 1, 3)

    // Shape of output: (`batchSize`, no. of queries, `numHiddens`)
    return perQuery.reshape(perQuery.shape[0], perQuery.shape[1], -1)
}
```
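The index bookkeeping behind `transposeQkv` and `transposeOutput` can be checked without DJL. The sketch below assumes row-major storage and uses the hypothetical helpers `forEachHeadIndex`, `transposeQkvFlat`, and `transposeOutputFlat`. It maps element `[b, t, i * headDim + d]` of a (`batchSize`, n, `numHeads` * `headDim`) array to element `[b * numHeads + i, t, d]` of a (`batchSize` * `numHeads`, n, `headDim`) array. That is the effect of the reshape, transpose, reshape sequence above.

```kotlin
// Calls action(src, dst) for every element, where `src` indexes the layout
// (batch, n, numHeads * headDim) and `dst` indexes (batch * numHeads, n, headDim)
inline fun forEachHeadIndex(batch: Int, n: Int, numHeads: Int, headDim: Int, action: (Int, Int) -> Unit) {
    for (b in 0 until batch) {
        for (t in 0 until n) {
            for (i in 0 until numHeads) {
                for (d in 0 until headDim) {
                    val src = (b * n + t) * numHeads * headDim + i * headDim + d
                    val dst = ((b * numHeads + i) * n + t) * headDim + d
                    action(src, dst)
                }
            }
        }
    }
}

// Hypothetical flat-buffer counterpart of transposeQkv
fun transposeQkvFlat(x: DoubleArray, batch: Int, n: Int, numHeads: Int, headDim: Int): DoubleArray {
    val out = DoubleArray(x.size)
    forEachHeadIndex(batch, n, numHeads, headDim) { src, dst -> out[dst] = x[src] }
    return out
}

// Hypothetical flat-buffer counterpart of transposeOutput (the inverse permutation)
fun transposeOutputFlat(x: DoubleArray, batch: Int, n: Int, numHeads: Int, headDim: Int): DoubleArray {
    val out = DoubleArray(x.size)
    forEachHeadIndex(batch, n, numHeads, headDim) { src, dst -> out[src] = x[dst] }
    return out
}
```

For one example with two tokens `[0, 1]` and `[2, 3]` split into two heads of width 1, `transposeQkvFlat` yields `[0, 2, 1, 3]`. Head 0 holds the first feature of both tokens and head 1 holds the second. Because the index map is a bijection, `transposeOutputFlat` restores the original buffer.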


## Implementation

In our implementation,
we choose the scaled dot-product attention
for each head of the multi-head attention.
To avoid significant growth
of the computational cost and the number of parameters,
we set
$p_q = p_k = p_v = p_o / h$.
Note that the $h$ heads
can be computed in parallel
if we set
the number of outputs of the linear transformations
for the query, key, and value
to $p_q h = p_k h = p_v h = p_o$.
In the following implementation,
$p_o$ is specified via the argument `numHiddens`,
which must therefore be divisible by `numHeads`.
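The claim that this choice keeps the parameter count from growing can be checked with a little arithmetic. Each head's $\mathbf W_i^{(q)}$ has $(p_o/h) \times d_q$ entries, so the $h$ heads together have $p_o \times d_q$. That is exactly as many as a single head with $p_q = p_o$, and the same holds for keys and values. $\mathbf W_o$ is $p_o \times p_o$ in either case. The sketch below ignores biases, and `multiHeadWeightCount` is a hypothetical helper.

```kotlin
// Hypothetical helper: number of weights (biases ignored) in W_i^(q), W_i^(k), W_i^(v)
// for all h heads plus W_o, assuming p_q = p_k = p_v = p_o / h
fun multiHeadWeightCount(dq: Int, dk: Int, dv: Int, po: Int, h: Int): Int {
    require(po % h == 0) { "p_o must be divisible by h" }
    val p = po / h
    val perHead = p * dq + p * dk + p * dv // one head's query, key, and value projections
    val outputProjection = po * (h * p)    // W_o has shape p_o x (h * p_v)
    return h * perHead + outputProjection
}
```

For example, take $d_q = d_k = d_v = p_o = 100$, which matches `numHiddens = 100` in the test below. The count is 40,000 weights whether $h = 1$, $h = 4$, or $h = 5$. The same argument applies to the attention scores, since $h$ dot products of length $p_o/h$ cost as much as one of length $p_o$.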


```kotlin
class MultiHeadAttention(
    private val numHiddens: Int,
    private val numHeads: Int,
    dropout: Float,
    useBias: Boolean
) : AbstractBlock() {
    // Scaled dot-product attention pooling shared by all heads (dropout is handled inside it)
    val attention = DotProductAttention(dropout)
    private val W_q: Linear
    private val W_k: Linear
    private val W_v: Linear
    private val W_o: Linear

    init {
        // Each projection has `numHiddens` = p_o = h * p_q = h * p_k = h * p_v outputs,
        // so one Linear layer produces the projections for all heads at once
        W_q = Linear.builder().setUnits(numHiddens.toLong()).optBias(useBias).build()
        addChildBlock("W_q", W_q)
        W_k = Linear.builder().setUnits(numHiddens.toLong()).optBias(useBias).build()
        addChildBlock("W_k", W_k)
        W_v = Linear.builder().setUnits(numHiddens.toLong()).optBias(useBias).build()
        addChildBlock("W_v", W_v)
        W_o = Linear.builder().setUnits(numHiddens.toLong()).optBias(useBias).build()
        addChildBlock("W_o", W_o)
    }

    override fun forwardInternal(
        ps: ParameterStore,
        inputs: NDList,
        training: Boolean,
        params: PairList<String, Any>?
    ): NDList {
        // Shape of `queries`, `keys`, or `values`:
        // (`batchSize`, no. of queries or key-value pairs, `numHiddens`)
        // Shape of `validLens`:
        // (`batchSize`,) or (`batchSize`, no. of queries)
        val queries = inputs[0]
        val keys = inputs[1]
        val values = inputs[2]
        // On axis 0, copy the first item (scalar or vector) for
        // `numHeads` times, then copy the next item, and so on
        val validLens = inputs[3].repeat(0, numHeads.toLong())

        // After transposing, shape of `q`, `k`, or `v`:
        // (`batchSize` * `numHeads`, no. of queries or key-value pairs,
        // `numHiddens` / `numHeads`)
        val q = transposeQkv(W_q.forward(ps, NDList(queries), training, params)[0], numHeads)
        val k = transposeQkv(W_k.forward(ps, NDList(keys), training, params)[0], numHeads)
        val v = transposeQkv(W_v.forward(ps, NDList(values), training, params)[0], numHeads)

        // Shape of `output`: (`batchSize` * `numHeads`, no. of queries,
        // `numHiddens` / `numHeads`)
        val output = attention.forward(ps, NDList(q, k, v, validLens), training, params)[0]

        // Shape of `outputConcat`:
        // (`batchSize`, no. of queries, `numHiddens`)
        val outputConcat = transposeOutput(output, numHeads)
        return NDList(W_o.forward(ps, NDList(outputConcat), training, params)[0])
    }

    override fun getOutputShapes(inputShapes: Array<Shape>): Array<Shape> =
        // (`batchSize`, no. of queries, `numHiddens`)
        arrayOf(Shape(inputShapes[0][0], inputShapes[0][1], numHiddens.toLong()))

    override fun initializeChildBlocks(manager: NDManager, dataType: DataType, vararg inputShapes: Shape) {
        // Parameters are created in `manager`; the dummy arrays used only to infer
        // shapes live in a sub-manager that is closed afterwards
        manager.newSubManager().use { sub ->
            W_q.initialize(manager, dataType, inputShapes[0])
            W_k.initialize(manager, dataType, inputShapes[1])
            W_v.initialize(manager, dataType, inputShapes[2])
            val ps = ParameterStore(sub, false)
            val q = transposeQkv(W_q.forward(ps, NDList(sub.zeros(inputShapes[0], dataType)), false)[0], numHeads)
            val k = transposeQkv(W_k.forward(ps, NDList(sub.zeros(inputShapes[1], dataType)), false)[0], numHeads)
            val v = transposeQkv(W_v.forward(ps, NDList(sub.zeros(inputShapes[2], dataType)), false)[0], numHeads)
            val validLens = sub.zeros(inputShapes[3], dataType).repeat(0, numHeads.toLong())
            attention.initialize(manager, dataType, *NDList(q, k, v, validLens).shapes)
            // The concatenated heads have shape (`batchSize`, no. of queries, `numHiddens`)
            W_o.initialize(manager, dataType, Shape(inputShapes[0][0], inputShapes[0][1], numHiddens.toLong()))
        }
    }
}
```
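The `validLens.repeat(0, numHeads)` call in `forwardInternal` must match the row layout produced by `transposeQkv`, in which row `b * numHeads + i` holds head `i` of example `b`. The plain-Kotlin sketch below shows the resulting order for the 1-D case. It uses a hypothetical helper, `repeatPerHead`.

```kotlin
// Hypothetical helper mirroring `validLens.repeat(0, numHeads)` for a 1-D array:
// row b * numHeads + i of the result belongs to head i of example b
fun repeatPerHead(validLens: FloatArray, numHeads: Int): FloatArray =
    FloatArray(validLens.size * numHeads) { row -> validLens[row / numHeads] }
```

With `validLens = [3, 2]` and `numHeads = 5`, every head of the first example attends to 3 key-value pairs and every head of the second attends to 2. Tiling instead (`[3, 2, 3, 2, ...]`) would pair heads with the wrong examples.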


Let us test our implemented `MultiHeadAttention` class
using a toy example where keys and values are the same.
The shape of the multi-head attention output
is (`batchSize`, `numQueries`, `numHiddens`).


```kotlin
val numHiddens = 100
val numHeads = 5
val attention = MultiHeadAttention(numHiddens, numHeads, 0.5f, false)
```

```kotlin
val batchSize = 2
val numQueries = 4
val numKvpairs = 6
// Only the first 3 (first example) and 2 (second example) key-value pairs are attended to
val validLens = manager.create(floatArrayOf(3.0f, 2.0f))
val X = manager.ones(Shape(batchSize.toLong(), numQueries.toLong(), numHiddens.toLong()))
val Y = manager.ones(Shape(batchSize.toLong(), numKvpairs.toLong(), numHiddens.toLong()))

val ps = ParameterStore(manager, false)
val input = NDList(X, Y, Y, validLens)
attention.initialize(manager, DataType.FLOAT32, *input.shapes)
// `training = false`: inference mode, so dropout is not applied
val result = attention.forward(ps, input, false)
println(result[0].shape)
```

```
(2, 4, 100)
```


## Summary

* Multi-head attention combines knowledge of the same attention pooling via different representation subspaces of queries, keys, and values.
* To compute multiple heads of multi-head attention in parallel, proper tensor manipulation is needed.



## Exercises

1. Visualize attention weights of multiple heads in this experiment.
1. Suppose that we have a trained model based on multi-head attention and we want to prune the least important attention heads to increase the prediction speed. How can we design experiments to measure the importance of an attention head?
