# Attention Pooling

:label:`sec_attention-pooling`

You now know the major components of attention mechanisms.
The interactions between queries and keys 
induce an *attention pooling* over values.
In this section, we will describe attention pooling in greater detail
to give you a high-level view of how modern attention mechanisms work in practice.


To begin, we point out that the idea of computing weighted sums
according to some compatibility score is actually 
quite common in machine learning and statistics. 
In particular, kernel regression has this flavor. 
To compute the prediction for a given data point $x$,
a kernel regression model determines the *similarity*
between $x$ and each $x' \neq x$ in the dataset.
It then predicts the label as a weighted sum
of the labels of the training instances:
instances whose features are deemed similar
(according to some similarity function)
are weighted higher, and instances
with lower similarity are weighted lower.
This is precisely the behavior of the Nadaraya-Watson kernel regression model,
proposed in 1964.
Note that the crucial difference here is that 
in kernel regression, the weighting is computed
over the training *examples*
whereas in attention mechanisms, the weighting
is computed over the inputs (e.g., input tokens of one training example). 

To build your intuition, we briefly implement
the classic Nadaraya-Watson kernel regression model,
computing similarities according to a Gaussian kernel.

```kotlin
%use @file[../djl.json]
%use lets-plot
@file:DependsOn("../D2J-1.0-SNAPSHOT.jar")
@file:DependsOn("../mxnet-native-cu112mkl-1.9.1-linux-x86_64.jar")
import kotlin.random.Random
import kotlin.collections.List
import kotlin.collections.Map
import kotlin.Pair
```

```kotlin
// Manager that owns all NDArrays created in this section
val manager = NDManager.newBaseManager()
```

## [**Generating the Dataset**]

To keep things simple, let's consider
the following regression problem:
given a dataset of input-output pairs 
$\{(x_1, y_1), \ldots, (x_n, y_n)\}$,
we wish to learn a function $f$ 
that can accurately predict 
the target $y$ for any new input $x$.

Below, we generate an artificial dataset
according to the following nonlinear function with noise term $\epsilon$:

$$y_i = 2\sin(x_i) + x_i^{0.8} + \epsilon,$$

where $\epsilon$ obeys a normal distribution 
with zero mean and standard deviation 0.5.
We generate 50 training examples and 50 validation examples;
the validation inputs are evenly spaced on $[0, 5)$ and their targets are noise-free.
To better visualize the pattern of attention later, the training inputs are sorted.
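The same recipe can be sketched without DJL, in plain Kotlin, for comparison. This is a minimal sketch that rests on our own assumptions. It uses `java.util.Random` with a fixed seed, so its numbers will not match the DJL run. The names `truth`, `makeTrainingSet`, and `makeValidationSet` are hypothetical helpers and are not part of DJL or this notebook.

```kotlin
import java.util.Random
import kotlin.math.pow
import kotlin.math.sin

// Ground-truth function f(x) = 2 sin(x) + x^0.8, without the noise term.
fun truth(x: Double): Double = 2 * sin(x) + x.pow(0.8)

// n sorted training inputs drawn uniformly from [0, 5),
// with labels corrupted by Gaussian noise of standard deviation 0.5.
fun makeTrainingSet(n: Int, seed: Long = 0L): Pair<DoubleArray, DoubleArray> {
    val rng = Random(seed)
    val x = DoubleArray(n) { rng.nextDouble() * 5 }.also { it.sort() }
    val y = DoubleArray(n) { i -> truth(x[i]) + 0.5 * rng.nextGaussian() }
    return Pair(x, y)
}

// n evenly spaced validation inputs 0, 5/n, 2*5/n, ... with noise-free targets.
fun makeValidationSet(n: Int): Pair<DoubleArray, DoubleArray> {
    val x = DoubleArray(n) { i -> i * 5.0 / n }
    return Pair(x, DoubleArray(n) { i -> truth(x[i]) })
}

val trainSet = makeTrainingSet(50)
val valSet = makeValidationSet(50)
```

Because the inputs are sorted once, at generation time, the attention heatmaps later in this section can show a diagonal band.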

```kotlin
import ai.djl.training.dataset.ArrayDataset
import ai.djl.training.dataset.Batch
import ai.djl.metric.Metrics

val batchSize = 10
val n = 50

// Ground-truth function f(x) = 2 sin(x) + x^0.8 (without noise)
val f = { x: NDArray -> x.sin().mul(2).add(x.pow(0.8)) }
// n sorted training inputs drawn uniformly from [0, 5)
val xTrain = manager.randomUniform(0f, 1f, Shape(n.toLong())).mul(5).sort()
// Gaussian noise with mean 0 and standard deviation 0.5
val r = manager.randomNormal(Shape(n.toLong())).div(2)
val yTrain = f(xTrain).add(r)
// Evenly spaced, noise-free validation points 0.0, 0.1, ..., 4.9
val xVal = manager.arange(0f, 5f, 5.0f / n)
val yVal = f(xVal)
// Training and validation datasets (sequential sampling, no shuffling)
val nonLinearDataSet = ArrayDataset.Builder()
    .setData(xTrain)
    .optLabels(yTrain)
    .setSampling(batchSize, false)
    .build()
val nonLinearDataSetVar = ArrayDataset.Builder()
    .setData(xVal)
    .optLabels(yVal)
    .setSampling(batchSize, false)
    .build()
```


In the plots below, circles represent the training examples.
The line labeled "True" is the ground-truth data generation function `f`
without the noise term.
The line labeled "Pred" is the learned prediction function.


## Average Pooling

We begin with perhaps the world's "dumbest" estimator for this regression problem:
using average pooling to average over all the training outputs:

$$f(x) = \frac{1}{n}\sum_{i=1}^n y_i,$$
:eqlabel:`eq_avg-pooling`

which is plotted below. As we can see, this estimator is indeed not so smart.


```kotlin
println(xTrain)
// Average pooling: predict the mean training label for every query
val yHat = yTrain.mean().reshape(1).repeat(n.toLong())
println(yHat)
// Long-format data: the first n rows are the ground truth, the next n rows are the predictions
val data = mapOf(
    "x" to xVal.toFloatArray() + xVal.toFloatArray(),
    "y" to yVal.toFloatArray() + yHat.toFloatArray(),
    "label" to Array(n){"True"} + Array(n){"Pred"}
)
var plot = letsPlot(data)
plot += geomLine(size=2) { x = "x" ; y = "y" ; color = "label" }
// Noisy training examples drawn as circles
plot += geomPoint(size=3) { x = xTrain.toFloatArray() ; y = yTrain.toFloatArray() }
plot + ggsize(700,500)
```

```
ND: (50) gpu(0) float32
[0.1638, 0.1951, 0.4685, 0.5497, 0.7706, 1.04  , 1.0665, 1.1713, 1.2024, 1.2634, 1.3915, 1.4247, 1.4601, 1.4738, 1.5364, 1.597 , 1.7728, 1.7785, 1.7925, 1.8752, ... 30 more]

ND: (50) gpu(0) float32
[2.5126, 2.5126, 2.5126, 2.5126, 2.5126, 2.5126, 2.5126, 2.5126, 2.5126, 2.5126, 2.5126, 2.5126, 2.5126, 2.5126, 2.5126, 2.5126, 2.5126, 2.5126, 2.5126, 2.5126, ... 30 more]
```
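Average pooling does not look at the query at all. The helper `averagePool` below makes that explicit. It is plain Kotlin, a hypothetical helper rather than a DJL API. Whatever the queries are, every prediction equals the mean of the training labels.

```kotlin
// Average pooling ignores the queries: every prediction is the mean training label.
fun averagePool(queries: DoubleArray, yTrain: DoubleArray): DoubleArray {
    val mean = yTrain.average()
    return DoubleArray(queries.size) { mean }
}

// Three different queries give three identical predictions (toy labels).
val avgDemo = averagePool(doubleArrayOf(0.0, 2.5, 4.9), doubleArrayOf(1.0, 2.0, 6.0))
```

Here every entry of `avgDemo` is 3.0, the mean of the three toy labels.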



## [**Nonparametric Attention Pooling**]

Average pooling isn't very useful because 
it outputs the same prediction
regardless of the input $x$.
Instead, the method due to :citet:`Nadaraya.1964`
and :citet:`Watson.1964`
weighs the outputs $y_i$ according to their input locations:

$$f(x) = \sum_{i=1}^n \frac{K(x - x_i)}{\sum_{j=1}^n K(x - x_j)} y_i,$$
:eqlabel:`eq_nadaraya-watson`

where $K$ is a *kernel*.
The estimator in :eqref:`eq_nadaraya-watson`
is called *Nadaraya-Watson kernel regression*.
Recall the framework of attention mechanisms in :numref:`fig_qkv`.
From the perspective of attention,
we can rewrite :eqref:`eq_nadaraya-watson`
in a more generalized form of *attention pooling*:

$$f(x) = \sum_{i=1}^n \alpha(x, x_i) y_i,$$
:eqlabel:`eq_attn-pooling`


where $x$ is the query and $(x_i, y_i)$ is the key-value pair.
Comparing :eqref:`eq_attn-pooling` and :eqref:`eq_avg-pooling`,
the attention pooling here
is a weighted average of values $y_i$.
The *attention weight* $\alpha(x, x_i)$
in :eqref:`eq_attn-pooling`
is assigned to the corresponding value $y_i$
based on the interaction
between the query $x$ 
and the key $x_i$,
modeled by $\alpha$.
For any query, its attention weights over all the key-value pairs
are a valid probability distribution:
they are non-negative and sum up to one.
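The definitions above require only a kernel $K$. The attention weights $\alpha(x, x_i)$ then follow by normalization. Below is a minimal plain-Kotlin sketch that assumes scalar inputs. The functions `attentionWeights` and `nadarayaWatson`, along with the two example kernels, are illustrative names of our own and are not part of DJL.

```kotlin
import kotlin.math.abs
import kotlin.math.exp

// alpha(x, x_i) = K(x - x_i) / sum_j K(x - x_j) for a single query x.
fun attentionWeights(query: Double, keys: DoubleArray, kernel: (Double) -> Double): DoubleArray {
    val scores = DoubleArray(keys.size) { i -> kernel(query - keys[i]) }
    val total = scores.sum()
    return DoubleArray(scores.size) { i -> scores[i] / total }
}

// Nadaraya-Watson estimate: the attention-weighted average of the values y_i.
fun nadarayaWatson(query: Double, keys: DoubleArray, values: DoubleArray, kernel: (Double) -> Double): Double {
    val w = attentionWeights(query, keys, kernel)
    return w.indices.sumOf { i -> w[i] * values[i] }
}

// Two example kernels (assumed for illustration): an unnormalized Gaussian
// and a boxcar of half-width 1.
val gaussian: (Double) -> Double = { u -> exp(-u * u / 2) }
val boxcar: (Double) -> Double = { u -> if (abs(u) < 1.0) 1.0 else 0.0 }
```

With the boxcar kernel, the estimate is a plain average over the keys within distance 1 of the query. If no key lies that close, the normalizer is zero and the result is `NaN`.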

To build intuition for attention pooling,
consider a *Gaussian kernel* defined as

$$
K(u) = \frac{1}{\sqrt{2\pi}} \exp(-\frac{u^2}{2}).
$$


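Plugging the Gaussian kernel into
:eqref:`eq_nadaraya-watson` and writing the result
in the form of :eqref:`eq_attn-pooling` gives the following.
The constant factor $\frac{1}{\sqrt{2\pi}}$ appears in both
the numerator and the denominator, so it cancels: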

$$\begin{aligned} f(x) &=\sum_{i=1}^n \alpha(x, x_i) y_i\\ &= \sum_{i=1}^n \frac{\exp\left(-\frac{1}{2}(x - x_i)^2\right)}{\sum_{j=1}^n \exp\left(-\frac{1}{2}(x - x_j)^2\right)} y_i \\&= \sum_{i=1}^n \mathrm{softmax}\left(-\frac{1}{2}(x - x_i)^2\right) y_i. \end{aligned}$$
:eqlabel:`eq_nadaraya-watson-gaussian`

In :eqref:`eq_nadaraya-watson-gaussian`,
a key $x_i$ that is closer to the given query $x$ will get
*more attention* via a *larger attention weight* 
assigned to the key's corresponding value $y_i$.
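The sketch below checks numerically that the constant $\frac{1}{\sqrt{2\pi}}$ drops out. It is plain Kotlin, and its helper names are hypothetical. It computes the weights twice: once from the full Gaussian kernel, and once as $\mathrm{softmax}\left(-\frac{1}{2}(x - x_i)^2\right)$. The softmax subtracts the maximum score before exponentiating. This is a standard guard against overflow and does not change the result.

```kotlin
import kotlin.math.PI
import kotlin.math.exp
import kotlin.math.sqrt

// Full Gaussian kernel, including the 1/sqrt(2*pi) normalizing constant.
fun gaussianKernel(u: Double): Double = exp(-u * u / 2) / sqrt(2 * PI)

// Numerically stable softmax: shift by the max score before exponentiating.
fun softmax(scores: DoubleArray): DoubleArray {
    val m = scores.maxOrNull() ?: return scores
    val e = DoubleArray(scores.size) { i -> exp(scores[i] - m) }
    val s = e.sum()
    return DoubleArray(e.size) { i -> e[i] / s }
}

// "Kernel" form: K(x - x_i) / sum_j K(x - x_j).
fun kernelWeights(x: Double, keys: DoubleArray): DoubleArray {
    val k = DoubleArray(keys.size) { i -> gaussianKernel(x - keys[i]) }
    val s = k.sum()
    return DoubleArray(k.size) { i -> k[i] / s }
}

// "Attention" form: softmax(-(x - x_i)^2 / 2).
fun softmaxWeights(x: Double, keys: DoubleArray): DoubleArray =
    softmax(DoubleArray(keys.size) { i -> -(x - keys[i]) * (x - keys[i]) / 2 })
```

For any query and keys, `kernelWeights` and `softmaxWeights` agree up to floating-point rounding.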

The Nadaraya-Watson kernel regression is a nonparametric model;
thus :eqref:`eq_nadaraya-watson-gaussian`
is an example of *nonparametric attention pooling*.
In the following, we plot the prediction based on this
nonparametric attention model.


```kotlin
// Pairwise differences x_query - x_key, with shape (numQueries, numKeys)
fun diff(queries: NDArray, keys: NDArray) : NDArray {
    return queries.reshape(-1,1).sub(keys.reshape(1,-1))
}

// Gaussian attention pooling: softmax(-(x - x_i)^2 / 2) over the keys (axis 1),
// followed by a weighted sum of the values. Returns (predictions, attention weights).
fun attentionPool(queryKeyDiffs :NDArray, values: NDArray) : NDList {
    val attentionWeights = queryKeyDiffs.pow(2).div(2).mul(-1).softmax(1)
    return NDList(attentionWeights.dot(values), attentionWeights)
}

// Validation inputs are the queries; training inputs and labels are the keys and values
val aPool = attentionPool(diff(xVal, xTrain), yTrain)
val data = mapOf(
    "x" to xVal.toFloatArray() + xVal.toFloatArray(),
    "y" to yVal.toFloatArray() + aPool[0].toFloatArray(),
    "label" to Array(n){"True"} + Array(n){"Pred"}
)
var plot = letsPlot(data)
plot += geomLine(size=2) { x = "x" ; y = "y" ; color = "label" }
plot += geomPoint(size=3) { x = xTrain.toFloatArray() ; y = yTrain.toFloatArray() }
plot + ggsize(700,500)
```

Now let's take a look at the [**attention weights**].
Here the validation inputs are the queries and the training inputs are the keys.
Since both inputs are sorted,
we can see that the closer a query-key pair is,
the higher its attention weight in the attention pooling.


```kotlin
// Attention weights: rows are queries (validation inputs), columns are keys (training inputs)
val matrix = aPool[1]
val seriesX = mutableListOf<Long>()
val seriesY = mutableListOf<Long>()
val seriesW = mutableListOf<Float>()
for (i in 0 until matrix.shape[0]) {
    val row = matrix.get(i)
    for (j in 0 until row.shape[0]) {
        seriesX.add(j)  // key index
        seriesY.add(i)  // query index
        seriesW.add(row.get(j).getFloat())
    }
}
val data = mapOf("x" to seriesX, "y" to seriesY)
var plot = letsPlot(data)
// One 1x1 bin per (key, query) cell, filled according to its attention weight
plot += geomBin2D(drop=false, binWidth = Pair(1,1), position = positionIdentity){x="x"; y = "y"; weight = seriesW }
plot += scaleFillGradient(low="blue", high="red")
plot + ggsize(700, 200)
```
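The diagonal band has a simple cause. The Gaussian kernel decreases monotonically in $|x - x_i|$, so each query puts its largest weight on its nearest key. The plain-Kotlin sketch below checks this on small sorted toy arrays. Both the arrays and the helper names are made up for illustration.

```kotlin
import kotlin.math.abs
import kotlin.math.exp

// Row i holds the Gaussian attention weights of query i over all keys
// (the same layout as the heatmap).
fun attentionMatrix(queries: DoubleArray, keys: DoubleArray): Array<DoubleArray> =
    Array(queries.size) { i ->
        val s = DoubleArray(keys.size) { j -> exp(-(queries[i] - keys[j]) * (queries[i] - keys[j]) / 2) }
        val total = s.sum()
        DoubleArray(s.size) { j -> s[j] / total }
    }

// For each query, the index of the key that receives the most attention.
fun peakKey(weights: Array<DoubleArray>): IntArray =
    IntArray(weights.size) { i -> weights[i].indices.maxByOrNull { weights[i][it] } ?: -1 }

// For each query, the index of the closest key.
fun nearestKey(queries: DoubleArray, keys: DoubleArray): IntArray =
    IntArray(queries.size) { i -> keys.indices.minByOrNull { abs(queries[i] - keys[it]) } ?: -1 }
```

When queries and keys are both sorted, the peak index never decreases from one row to the next. That monotone path is the band visible in the heatmap.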

## [**Parametric Attention Pooling**]

Nadaraya-Watson kernel regression enjoys *consistency*:
given enough data (and a kernel bandwidth that shrinks appropriately as the dataset grows),
its estimate converges to the true regression function.
That said, we can also easily integrate 
learnable parameters into attention pooling.

As an example, slightly different from :eqref:`eq_nadaraya-watson-gaussian`,
in the following the distance between the query $x$ and the key $x_i$
is multiplied by a learnable parameter $w$:


$$\begin{aligned}f(x) &= \sum_{i=1}^n \alpha(x, x_i) y_i \\&= \sum_{i=1}^n \frac{\exp\left(-\frac{1}{2}((x - x_i)w)^2\right)}{\sum_{j=1}^n \exp\left(-\frac{1}{2}((x - x_j)w)^2\right)} y_i \\&= \sum_{i=1}^n \mathrm{softmax}\left(-\frac{1}{2}((x - x_i)w)^2\right) y_i.\end{aligned}$$
:eqlabel:`eq_nadaraya-watson-gaussian-para`

In the rest of the section, we will train this model 
by learning the parameter of the attention pooling
in :eqref:`eq_nadaraya-watson-gaussian-para`.
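The scale $w$ acts as an inverse kernel bandwidth. Multiplying the distance by $w$ is equivalent to using a Gaussian kernel of width $1/|w|$. The sketch below is plain Kotlin, and `parametricWeights` is a hypothetical helper. It shows that a larger $|w|$ concentrates the weights on nearby keys. Setting $w = 0$ makes the weights uniform, which recovers average pooling from :eqref:`eq_avg-pooling`.

```kotlin
import kotlin.math.exp

// Parametric Gaussian attention weights softmax(-((x - x_i) * w)^2 / 2) for one query x.
fun parametricWeights(x: Double, keys: DoubleArray, w: Double): DoubleArray {
    val scores = DoubleArray(keys.size) { i ->
        val d = (x - keys[i]) * w
        -d * d / 2
    }
    // Subtract the max score for numerical stability before exponentiating.
    val m = scores.maxOrNull() ?: return scores
    val e = DoubleArray(scores.size) { i -> exp(scores[i] - m) }
    val total = e.sum()
    return DoubleArray(e.size) { i -> e[i] / total }
}
```

Try comparing `parametricWeights(2.5, keys, 1.0)` with `parametricWeights(2.5, keys, 4.0)` on evenly spaced keys. The largest weight grows as $w$ increases.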


### Batch Matrix Multiplication
:label:`subsec_batch_dot`

To more efficiently compute attention for minibatches,
we can leverage batch matrix multiplication utilities
provided by deep learning frameworks.


Suppose that the first minibatch contains $n$ matrices
$\mathbf{X}_1, \ldots, \mathbf{X}_n$ of shape $a\times b$,
and the second minibatch contains $n$ matrices 
$\mathbf{Y}_1, \ldots, \mathbf{Y}_n$ of shape $b\times c$. 
Their batch matrix multiplication results in $n$ matrices
$\mathbf{X}_1\mathbf{Y}_1, \ldots, \mathbf{X}_n\mathbf{Y}_n$ 
of shape $a\times c$. 
Therefore, [**given two tensors of shape ($n$, $a$, $b$) and ($n$, $b$, $c$),
the shape of their batch matrix multiplication output is ($n$, $a$, $c$).**]

```kotlin
// A batch of 2 matrices of shape 1x4 times a batch of 2 matrices of shape 4x6
val X = manager.ones(Shape(2,1,4))
val Y = manager.ones(Shape(2,4,6))
X.matMul(Y)
```

```
ND: (2, 1, 6) gpu(0) float32
[[[4., 4., 4., 4., 4., 4.],
 ],
 [[4., 4., 4., 4., 4., 4.],
 ],
]
```
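Here is a naive reference sketch on nested Kotlin arrays, for readers curious about what `matMul` does with 3-D inputs. The helpers `batchMatMul` and `ones` are our own and are not DJL APIs. The sketch multiplies each pair of matrices in the batch independently, so inputs of shape ($n$, $a$, $b$) and ($n$, $b$, $c$) produce an output of shape ($n$, $a$, $c$).

```kotlin
// Batch matrix multiplication: out[k] = x[k] (a x b) times y[k] (b x c), for each k.
fun batchMatMul(x: Array<Array<DoubleArray>>, y: Array<Array<DoubleArray>>): Array<Array<DoubleArray>> {
    require(x.size == y.size) { "batch sizes differ: ${x.size} vs ${y.size}" }
    return Array(x.size) { k ->
        val a = x[k].size
        val b = y[k].size
        val c = y[k][0].size
        require(x[k].all { it.size == b }) { "inner dimensions differ in batch $k" }
        Array(a) { i -> DoubleArray(c) { j -> (0 until b).sumOf { t -> x[k][i][t] * y[k][t][j] } } }
    }
}

// Hypothetical helper: a tensor of shape (n, rows, cols) filled with ones.
fun ones(n: Int, rows: Int, cols: Int) = Array(n) { Array(rows) { DoubleArray(cols) { 1.0 } } }
```

`batchMatMul(ones(2, 1, 4), ones(2, 4, 6))` reproduces the shape (2, 1, 6) and the entries 4.0 of the DJL cell above.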


In the context of attention mechanisms,
we can [**use minibatch matrix multiplication 
to compute weighted averages of values in a minibatch.**]

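```kotlin
// Uniform weights of 0.1 over 10 values, for a batch of 2
val weights = manager.ones(Shape(2, 10)).mul(0.1)
val values = manager.arange(20.0f).reshape(Shape(2,10))
// (2, 1, 10) x (2, 10, 1) -> (2, 1, 1): one weighted average per batch element
weights.expandDims(1).matMul(values.expandDims(-1))
```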

```
ND: (2, 1, 1) gpu(0) float32
[[[ 4.5],
 ],
 [[14.5],
 ],
]
```
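The same computation can be written out without tensors. Treat each row of weights as a $1 \times 10$ matrix and each row of values as a $10 \times 1$ matrix. The batch product then reduces to one dot product per batch element. Below is a plain-Kotlin sketch; the helper `weightedAverages` is hypothetical.

```kotlin
// One weighted average per batch element: sum_i weights[k][i] * values[k][i].
fun weightedAverages(weights: Array<DoubleArray>, values: Array<DoubleArray>): DoubleArray {
    require(weights.size == values.size) { "batch sizes differ" }
    return DoubleArray(weights.size) { k ->
        require(weights[k].size == values[k].size) { "row $k: lengths differ" }
        weights[k].indices.sumOf { i -> weights[k][i] * values[k][i] }
    }
}

// Uniform weights 0.1 and values 0..19 arranged as two rows of ten.
val uniform = Array(2) { DoubleArray(10) { 0.1 } }
val arange = Array(2) { k -> DoubleArray(10) { i -> (10 * k + i).toDouble() } }
val averaged = weightedAverages(uniform, arange)
```

`averaged` holds 4.5 and 14.5 (up to floating-point rounding), which matches the DJL output.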


### Defining the Model

Using minibatch matrix multiplication,
below we define the parametric version
of Nadaraya-Watson kernel regression
based on the [**parametric attention pooling**] 
in :eqref:`eq_nadaraya-watson-gaussian-para`.


```kotlin
class NWKernelRegression(val keys: NDArray, val values0: NDArray) : AbstractBlock() {
    // Learnable scalar w that scales the query-key distance, initialized to 1
    val wParam : Parameter
    // Attention weights from the most recent forward pass (kept for visualization)
    var attention: NDArray? = null
    init {
        wParam = addParameter(
            Parameter.builder()
                .setName("weight")
                .setType(Parameter.Type.BIAS)
                .optShape(Shape(1))
                .optArray(manager.ones(Shape(1)))
                .build())
    }

    override fun forwardInternal(
        parameterStore: ParameterStore,
        X: NDList,
        training: Boolean,
        params: PairList<String, Any>?
    ): NDList {
        val input = X.head()
        val device: Device = input.getDevice()
        val tmpW = parameterStore.getValue(wParam, device, training)
        // Parametric Gaussian attention: softmax(-((x - x_i) w)^2 / 2) over the keys
        val ret = attentionPool(diff(input, keys).mul(tmpW), values0)
        attention = ret[1]
        return NDList(ret[0])
    }

    // One prediction per query, so the output shape equals the input shape
    override fun getOutputShapes(inputs: Array<Shape>): Array<Shape> {
        return inputs
    }
}
```


### Training

In the following, we [**transform the training dataset
to keys and values**] to train the attention model:
the training inputs serve as keys and the training labels as values.
In the parametric attention pooling,
for simplicity,
each training input attends to the key-value pairs
of all training examples, including itself, to predict its output.


```kotlin
val lr = 3f
val lrt = Tracker.fixed(lr)

val l2loss = Loss.l2Loss()
val sgd = Optimizer.sgd().setLearningRateTracker(lrt).build()

val config = DefaultTrainingConfig(l2loss)
    .optOptimizer(sgd) // Optimizer: SGD with a fixed learning rate
    .optDevices(Engine.getInstance().getDevices(1)) // single CPU/GPU
    .addEvaluator(Accuracy()) // Accuracy is not a meaningful metric for regression; watch the L2 loss
    .addEvaluator(l2loss)
    .addTrainingListeners(*TrainingListener.Defaults.logging()) // Logging

// Keys and values are the training inputs and labels
val model = Model.newInstance("NWKernelRegression")
val net = NWKernelRegression(xTrain, yTrain)
model.setBlock(net)
val trainer = model.newTrainer(config)
trainer.initialize(Shape(batchSize.toLong(), 2))
trainer.setMetrics(Metrics())
val numEpochs = 10
EasyTrain.fit(trainer, numEpochs, nonLinearDataSet, nonLinearDataSetVar)
// Print the learned value of w
println(net.parameters.get(0).value.array)
```


```
Training:    100% |████████████████████████████████████████| Accuracy: 0.66, L2Loss: 0.19, ...
Validating:  100% |████████████████████████████████████████|
Training:    100% |████████████████████████████████████████| Accuracy: 0.68, L2Loss: 0.11, ...
Validating:  100% |████████████████████████████████████████|
Training:    100% |████████████████████████████████████████| Accuracy: 0.70, L2Loss: 0.10, ...
Validating:  100% |████████████████████████████████████████|
Training:    100% |████████████████████████████████████████| Accuracy: 0.70, L2Loss: 0.10, ...
Validating:  100% |████████████████████████████████████████|
Training:    100% |████████████████████████████████████████| Accuracy: 0.72, L2Loss: 0.09, ...
Validating:  100% |████████████████████████████████████████|
Training:    100% |████████████████████████████████████████| Accuracy: 0.72, L2Loss: 0.09, ...
Validating:  100% |████████████████████████████████████████|
Training:    100% |████████████████████████████████████████| Acc
```

Because the model tries to fit the noisy training data,
the predicted line is less smooth 
than its nonparametric counterpart plotted earlier.

```kotlin
// Predict on the validation inputs (the queries) with the learned w
val ps = ParameterStore(manager, false)
val smooth = net.forward(ps, NDList(xVal), false)[0]
val data = mapOf(
    "x" to xVal.toFloatArray() + xVal.toFloatArray(),
    "y" to yVal.toFloatArray() + smooth.toFloatArray(),
    "label" to Array(n){"True"} + Array(n){"Pred"}
)
var plot = letsPlot(data)
plot += geomLine(size=2) { x = "x" ; y = "y" ; color = "label" }
plot += geomPoint(size=3) { x = xTrain.toFloatArray() ; y = yTrain.toFloatArray() }
plot + ggsize(700,500)
```
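One way to see why the fitted line tracks the noise starts from a simple fact: during training, each training input is also one of the keys. Its own label $y_i$ therefore gets the unnormalized weight $\exp(0) = 1$. Every other label is scaled down by $\exp\left(-\frac{1}{2}((x_i - x_j)w)^2\right)$. As $|w|$ grows, the prediction at $x_i$ approaches $y_i$ and the training loss approaches zero. The plain-Kotlin sketch below illustrates the effect. Its toy data and helper names are our own, and it does not reproduce the DJL training run.

```kotlin
import kotlin.math.exp

// Parametric Nadaraya-Watson prediction at x with scale w (stable softmax weighting).
fun predict(x: Double, keys: DoubleArray, values: DoubleArray, w: Double): Double {
    val scores = DoubleArray(keys.size) { i -> val d = (x - keys[i]) * w; -d * d / 2 }
    val m = scores.maxOrNull() ?: return Double.NaN
    val e = DoubleArray(scores.size) { i -> exp(scores[i] - m) }
    val total = e.sum()
    return e.indices.sumOf { i -> e[i] * values[i] } / total
}

// Mean squared training error when each training input attends to ALL
// training pairs, itself included (as in the training setup above).
fun trainingMse(xs: DoubleArray, ys: DoubleArray, w: Double): Double =
    xs.indices.map { i -> predict(xs[i], xs, ys, w) - ys[i] }.map { it * it }.average()

// Toy data made up for illustration.
val toyX = doubleArrayOf(0.0, 1.0, 2.0, 3.0, 4.0)
val toyY = doubleArrayOf(0.5, -0.3, 1.2, 0.1, 2.0)
val mseSmallW = trainingMse(toyX, toyY, 1.0)
val mseLargeW = trainingMse(toyX, toyY, 10.0)
```

With $w = 10$, the training error on this toy set is essentially zero, because each prediction collapses onto its own noisy label.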

Compared with nonparametric attention pooling,
[**the region with large attention weights becomes sharper**]
in the parametric setting.


```kotlin
// Attention weights from the model's most recent forward pass:
// rows are queries, columns are keys (training inputs)
val matrix = net.attention!!
val seriesX = mutableListOf<Long>()
val seriesY = mutableListOf<Long>()
val seriesW = mutableListOf<Float>()
for (i in 0 until matrix.shape[0]) {
    val row = matrix.get(i)
    for (j in 0 until row.shape[0]) {
        seriesX.add(j)  // key index
        seriesY.add(i)  // query index
        seriesW.add(row.get(j).getFloat())
    }
}
val data = mapOf("x" to seriesX, "y" to seriesY)
var plot = letsPlot(data)
// One 1x1 bin per (key, query) cell, filled according to its attention weight
plot += geomBin2D(drop=false, binWidth = Pair(1,1), position = positionIdentity){x="x"; y = "y"; weight = seriesW }
plot += scaleFillGradient(low="blue", high="red")
plot + ggsize(700, 200)
```

## Summary

Nadaraya-Watson kernel regression is an example 
of machine learning with attention mechanisms.
The attention pooling of Nadaraya-Watson kernel regression
is a weighted average of the training outputs.
From the attention perspective, the attention weight is assigned to a value 
based on a function of a query and the key that is paired with the value.
Attention pooling can be either nonparametric or parametric.


## Exercises

1. Increase the number of training examples. Can you learn nonparametric Nadaraya-Watson kernel regression better?
1. What is the value of our learned $w$ in the parametric attention pooling experiment? Why does it make the weighted region sharper when visualizing the attention weights?
1. How can we add hyperparameters to nonparametric Nadaraya-Watson kernel regression to predict better?
1. Design another parametric attention pooling for the kernel regression of this section. Train this new model and visualize its attention weights.