# Attention Pooling: Nadaraya-Watson Kernel Regression
:label:`sec_nadaraya-watson`

Now you know the major components of attention mechanisms under the framework in :numref:`fig_qkv`.
To recapitulate,
the interactions between
queries (volitional cues) and keys (nonvolitional cues)
result in *attention pooling*.
The attention pooling selectively aggregates values (sensory inputs) to produce the output.
In this section,
we will describe attention pooling in greater detail
to give you a high-level view of
how attention mechanisms work in practice.
Specifically,
the Nadaraya-Watson kernel regression model
proposed in 1964
is a simple yet complete example
for demonstrating machine learning with attention mechanisms.


```kotlin
%use @file[../djl-pytorch.json]
%use lets-plot
@file:DependsOn("../D2J-1.0-SNAPSHOT.jar")
import ai.djl.Model
import ai.djl.ndarray.NDArray
import ai.djl.ndarray.NDList
import ai.djl.ndarray.NDManager
import ai.djl.ndarray.index.NDIndex
import ai.djl.ndarray.types.DataType
import ai.djl.ndarray.types.Shape
import ai.djl.nn.AbstractBlock
import ai.djl.nn.Parameter
import ai.djl.training.DefaultTrainingConfig
import ai.djl.training.GradientCollector
import ai.djl.training.ParameterStore
import ai.djl.training.Trainer
import ai.djl.training.initializer.UniformInitializer
import ai.djl.training.loss.Loss
import ai.djl.training.optimizer.Optimizer
import ai.djl.training.tracker.Tracker
import ai.djl.util.PairList
```


```kotlin
val manager = NDManager.newBaseManager()
```


## Generating the Dataset

To keep things simple,
let us consider the following regression problem:
given a dataset of input-output pairs $\{(x_1, y_1), \ldots, (x_n, y_n)\}$,
how can we learn $f$ to predict the output $\hat{y} = f(x)$ for any new input $x$?

Here we generate an artificial dataset according to the following nonlinear function with the noise term $\epsilon$:

$$y_i = 2\sin(x_i) + x_i^{0.8} + \epsilon,$$

where $\epsilon$ obeys a normal distribution with zero mean and standard deviation 0.5.
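We generate 50 training examples and 50 testing examples.
The training inputs are drawn uniformly from $[0, 5)$, while the testing inputs are evenly spaced over $[0, 5)$ and paired with the noise-free ground-truth outputs.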
To better visualize the pattern of attention later, the training inputs are sorted.


```kotlin
val nTrain = 50 // No. of training examples
val xTrain = manager.randomUniform(0f, 1f, Shape(nTrain.toLong())).mul(5).sort() // Training inputs
```


```kotlin
val f: (NDArray) -> NDArray = { x -> x.sin().mul(2).add(x.pow(0.8)) }
val yTrain =
    f(xTrain).add(
        manager.randomNormal(
            0f,
            0.5f,
            Shape(nTrain.toLong()),
            DataType.FLOAT32
        )
    ) // Training outputs
val xTest = manager.arange(0f, 5f, 0.1f) // Testing examples
val yTruth = f(xTest) // Ground-truth outputs for the testing examples
val nTest = xTest.shape[0].toInt() // No. of testing examples
println(nTest)
```


50


The following function plots all the training examples (represented by circles),
the ground-truth data generation function `f` without the noise term (labeled by "Truth"), and the learned prediction function (labeled by "Pred").


```kotlin
fun plot(
    yHat: NDArray,
    trace1Name: String,
    trace2Name: String,
    xLabel: String,
    yLabel: String,
    width: Int,
    height: Int
): Any {
    val data = mapOf(
        "x" to (xTest.toFloatArray() + xTest.toFloatArray()),
        "y" to (yTruth.toFloatArray() + yHat.toFloatArray()),
        "label" to (Array(nTest) { trace1Name } + Array(nTest) { trace2Name })
    )
    var plot = letsPlot(data)
    plot += geomLine(size = 2) { x = "x"; y = "y"; color = "label" }
    plot += geomPoint(size = 3) { x = xTrain.toFloatArray(); y = yTrain.toFloatArray() }
    plot += xlab(xLabel) + ylab(yLabel)
    return plot + ggsize(width, height)
}
```


## Average Pooling

We begin with perhaps the world's "dumbest" estimator for this regression problem:
using average pooling to average over all the training outputs:

$$f(x) = \frac{1}{n}\sum_{i=1}^n y_i,$$
:eqlabel:`eq_avg-pooling`

which is plotted below. As we can see, this estimator is indeed not so smart.
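As a dependency-free sketch, the estimator in :eqref:`eq_avg-pooling` can be written so that it takes the queries as input but never reads them. The sketch uses plain Kotlin arrays instead of DJL. The name `averagePooling` and the toy numbers are hypothetical and are not drawn from the dataset generated above.

```kotlin
/**
 * Average pooling f(x) = (1/n) * sum_i y_i.
 * The queries are accepted only to make the point that they are ignored:
 * every query receives the same prediction.
 */
fun averagePooling(queries: DoubleArray, yTrain: DoubleArray): DoubleArray {
    require(yTrain.isNotEmpty()) { "need at least one training output" }
    val mean = yTrain.average()
    return DoubleArray(queries.size) { mean }
}

// Hypothetical toy outputs, not the dataset generated above
val toyY = doubleArrayOf(1.0, 2.0, 6.0)
val toyPred = averagePooling(doubleArrayOf(0.0, 2.5, 5.0), toyY) // [3.0, 3.0, 3.0]
```

Whatever the queries are, every prediction equals the mean of the training outputs. This is why the "Pred" line in the plot below is flat.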


```kotlin
var yHat = yTrain.mean().tile(nTest.toLong())
plot(yHat, "Truth", "Pred", "x", "y", 700, 500)
```


## Nonparametric Attention Pooling

Obviously,
average pooling omits the inputs $x_i$.
A better idea was proposed
by Nadaraya :cite:`Nadaraya.1964`
and Watson :cite:`Watson.1964`
to weigh the outputs $y_i$ according to their input locations:

$$f(x) = \sum_{i=1}^n \frac{K(x - x_i)}{\sum_{j=1}^n K(x - x_j)} y_i,$$
:eqlabel:`eq_nadaraya-watson`

where $K$ is a *kernel*.
The estimator in :eqref:`eq_nadaraya-watson`
is called *Nadaraya-Watson kernel regression*.
Here we will not dive into details of kernels.
Recall the framework of attention mechanisms in :numref:`fig_qkv`.
From the perspective of attention,
we can rewrite :eqref:`eq_nadaraya-watson`
in a more generalized form of *attention pooling*:

$$f(x) = \sum_{i=1}^n \alpha(x, x_i) y_i,$$
:eqlabel:`eq_attn-pooling`


where $x$ is the query and $(x_i, y_i)$ is the key-value pair.
Comparing :eqref:`eq_attn-pooling` and :eqref:`eq_avg-pooling`,
the attention pooling here
is a weighted average of values $y_i$.
The *attention weight* $\alpha(x, x_i)$
in :eqref:`eq_attn-pooling`
is assigned to the corresponding value $y_i$
based on the interaction
between the query $x$ and the key $x_i$
modeled by $\alpha$.
For any query, its attention weights over all the key-value pairs are a valid probability distribution: they are non-negative and sum up to one.
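To make the weights concrete, here is a minimal plain-Kotlin sketch of :eqref:`eq_nadaraya-watson` viewed as attention pooling for a single query. The kernel is passed in as a function. The helper names and the kernel $K(u) = 1/(1+u^2)$ used in the toy call are assumptions for illustration only and are not part of the original text.

```kotlin
/**
 * Attention weights alpha(x, x_i) = K(x - x_i) / sum_j K(x - x_j) for one query x.
 * The kernel is a parameter, so any non-negative kernel can be tried.
 */
fun kernelAttentionWeights(query: Double, keys: DoubleArray, kernel: (Double) -> Double): DoubleArray {
    val scores = DoubleArray(keys.size) { i -> kernel(query - keys[i]) }
    val total = scores.sum()
    require(total > 0.0) { "kernel scores must not all be zero" }
    return DoubleArray(scores.size) { i -> scores[i] / total }
}

/** Attention pooling f(x) = sum_i alpha(x, x_i) * y_i with keys x_i and values y_i. */
fun attentionPool(query: Double, keys: DoubleArray, values: DoubleArray, kernel: (Double) -> Double): Double {
    require(keys.size == values.size) { "each key needs exactly one value" }
    val weights = kernelAttentionWeights(query, keys, kernel)
    return weights.indices.sumOf { i -> weights[i] * values[i] }
}

// Hypothetical kernel chosen only for this illustration: K(u) = 1 / (1 + u^2)
val inverseQuadratic: (Double) -> Double = { u -> 1.0 / (1.0 + u * u) }
val toyWeights = kernelAttentionWeights(1.0, doubleArrayOf(0.0, 1.0, 2.0), inverseQuadratic)
val toyPooled = attentionPool(1.0, doubleArrayOf(0.0, 1.0, 2.0), doubleArrayOf(1.0, 2.0, 3.0), inverseQuadratic)
// toyWeights = [0.25, 0.5, 0.25], toyPooled = 2.0
```

The weights are non-negative and sum to one, and the key closest to the query (here the key `1.0`) receives the largest share.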

To gain some intuition about attention pooling,
consider a *Gaussian kernel* defined as

$$
K(u) = \frac{1}{\sqrt{2\pi}} \exp(-\frac{u^2}{2}).
$$


Plugging the Gaussian kernel into
:eqref:`eq_attn-pooling` and
:eqref:`eq_nadaraya-watson`
(the constant factor $1/\sqrt{2\pi}$ cancels between the numerator and the denominator)
gives

$$\begin{aligned} f(x) &=\sum_{i=1}^n \alpha(x, x_i) y_i\\ &= \sum_{i=1}^n \frac{\exp\left(-\frac{1}{2}(x - x_i)^2\right)}{\sum_{j=1}^n \exp\left(-\frac{1}{2}(x - x_j)^2\right)} y_i \\&= \sum_{i=1}^n \mathrm{softmax}\left(-\frac{1}{2}(x - x_i)^2\right) y_i. \end{aligned}$$
:eqlabel:`eq_nadaraya-watson-gaussian`

In :eqref:`eq_nadaraya-watson-gaussian`,
a key $x_i$ that is closer to the given query $x$ will get
*more attention* via a *larger attention weight* assigned to the key's corresponding value $y_i$.
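The following sketch checks numerically that normalizing the Gaussian kernel directly gives the same weights as applying softmax to $-\frac{1}{2}(x - x_i)^2$. It uses plain Kotlin, and the helper names are hypothetical. The softmax version subtracts the largest score before exponentiating. This is a standard numerical trick assumed here, not taken from the text. It leaves the weights unchanged, but it keeps them well defined when every key is far from the query. In that case the literal kernel values underflow to zero and the ratio becomes 0/0.

```kotlin
import kotlin.math.PI
import kotlin.math.abs
import kotlin.math.exp
import kotlin.math.sqrt

/** Gaussian kernel K(u) = exp(-u^2 / 2) / sqrt(2 * pi). */
fun gaussian(u: Double): Double = exp(-u * u / 2) / sqrt(2 * PI)

/** Weights computed literally as K(x - x_i) / sum_j K(x - x_j). */
fun weightsByKernel(query: Double, keys: DoubleArray): DoubleArray {
    val k = keys.map { gaussian(query - it) }
    val total = k.sum()
    return k.map { it / total }.toDoubleArray()
}

/** The same weights as softmax(-(x - x_i)^2 / 2), shifted by the max score for stability. */
fun weightsBySoftmax(query: Double, keys: DoubleArray): DoubleArray {
    val scores = keys.map { -(query - it) * (query - it) / 2 }
    val maxScore = scores.maxOrNull() ?: error("no keys")
    val e = scores.map { exp(it - maxScore) }
    val total = e.sum()
    return e.map { it / total }.toDoubleArray()
}

val demoKeys = doubleArrayOf(0.0, 0.5, 1.5, 3.0)
// Largest absolute difference between the two formulations for the query 1.2
val maxDiff = weightsByKernel(1.2, demoKeys).zip(weightsBySoftmax(1.2, demoKeys))
    .maxOf { (a, b) -> abs(a - b) }
```

For a query such as `100.0`, `weightsByKernel` returns `NaN` weights, while `weightsBySoftmax` still puts almost all of the weight on the nearest key.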

Notably, Nadaraya-Watson kernel regression is a nonparametric model;
thus :eqref:`eq_nadaraya-watson-gaussian`
is an example of *nonparametric attention pooling*.
In the following, we plot the prediction based on this
nonparametric attention model.
The predicted line is smooth and closer to the ground-truth than that produced by average pooling.


```kotlin
// Shape of `xRepeat`: (`nTest`, `nTrain`), where each row contains the
// same testing inputs (i.e., same queries)
val xRepeat = xTest.repeat(nTrain.toLong()).reshape(Shape(-1, nTrain.toLong()))
// Note that `xTrain` contains the keys. Shape of `attentionWeights`:
// (`nTest`, `nTrain`), where each row contains attention weights to be
// assigned among the values (`yTrain`) given each query
val attentionWeights = xRepeat.sub(xTrain).pow(2).div(2).mul(-1).softmax(-1)
// Each element of `yHat` is weighted average of values, where weights are
// attention weights
yHat = attentionWeights.matMul(yTrain.reshape(Shape(nTrain.toLong(), 1))).reshape(Shape(-1))
plot(yHat, "Truth", "Pred", "x", "y", 700, 500)
```


Now let us take a look at the attention weights.
Here the testing inputs are queries and the training inputs are keys.
Since both inputs are sorted,
we can see that the closer a query is to a key,
the higher the attention weight that key receives in the attention pooling.
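Here is one way to see why the large weights form a band along the diagonal. Softmax is monotone in its scores, so for each query the largest Gaussian weight goes to the nearest key. When the keys are sorted, the index of the nearest key can only move forward as the query grows. The small sketch below illustrates this. The sorted inputs and the helper name `strongestKey` are made up for illustration.

```kotlin
/** Index of the key that receives the largest Gaussian attention weight for one query. */
fun strongestKey(query: Double, keys: DoubleArray): Int =
    // softmax preserves the ordering of its scores, so the largest weight goes to the
    // largest score -(x - x_i)^2 / 2, i.e. to the key nearest the query
    keys.indices.maxByOrNull { i -> -(query - keys[i]) * (query - keys[i]) / 2 }
        ?: error("no keys")

// Hypothetical sorted keys and evenly spaced sorted queries
val sortedKeys = doubleArrayOf(0.2, 0.9, 1.7, 2.6, 3.1, 4.4)
val sortedQueries = DoubleArray(50) { it * 0.1 }
val strongest = sortedQueries.map { strongestKey(it, sortedKeys) }
```

The list `strongest` never decreases. It starts at the first key and ends at the last.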


```kotlin
fun plotHeatmap(att: NDArray, width: Int, height: Int): Any {
    val seriesX = mutableListOf<Long>()
    val seriesY = mutableListOf<Long>()
    val seriesW = mutableListOf<Float>()
    for (i in 0 until att.shape[0]) {
        val row = att.get(i)
        for (j in 0 until row.shape[0]) {
            seriesX.add(j)
            seriesY.add(i)
            seriesW.add(row.get(j).getFloat())
        }
    }
    val data = mapOf("x" to seriesX, "y" to seriesY)
    var plot = letsPlot(data)
    plot += geomBin2D(drop = false, binWidth = kotlin.Pair(1, 1), position = positionIdentity) {
        x = "x"; y = "y"; weight = seriesW
    }
    plot += scaleFillGradient(low = "blue", high = "red")
    return plot + ggsize(width, height)
}

plotHeatmap(attentionWeights, 500, 700)
```


## Parametric Attention Pooling

Nonparametric Nadaraya-Watson kernel regression
enjoys the *consistency* benefit:
given enough data this model converges to the optimal solution.
Nonetheless,
we can easily integrate learnable parameters into attention pooling.

As an example, slightly different from :eqref:`eq_nadaraya-watson-gaussian`,
in the following
the distance between the query $x$ and the key $x_i$
is multiplied by a learnable parameter $w$:


$$\begin{aligned}f(x) &= \sum_{i=1}^n \alpha(x, x_i) y_i \\&= \sum_{i=1}^n \frac{\exp\left(-\frac{1}{2}((x - x_i)w)^2\right)}{\sum_{j=1}^n \exp\left(-\frac{1}{2}((x - x_j)w)^2\right)} y_i \\&= \sum_{i=1}^n \mathrm{softmax}\left(-\frac{1}{2}((x - x_i)w)^2\right) y_i.\end{aligned}$$
:eqlabel:`eq_nadaraya-watson-gaussian-para`

In the rest of the section,
we will train this model by learning the parameter of
the attention pooling in :eqref:`eq_nadaraya-watson-gaussian-para`.
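Before training, it may help to see what $w$ does. The plain-Kotlin sketch below uses a hypothetical helper name and made-up, evenly spaced keys. It shows that a larger $|w|$ concentrates the weights on the keys close to the query. Because the scaled distance is squared, $w$ and $-w$ yield identical weights.

```kotlin
import kotlin.math.exp

/** Weights softmax(-((x - x_i) * w)^2 / 2) for one query; here w is a fixed number, not learned. */
fun parametricWeights(query: Double, keys: DoubleArray, w: Double): DoubleArray {
    val scores = keys.map { key ->
        val d = (query - key) * w
        -d * d / 2
    }
    val maxScore = scores.maxOrNull() ?: error("no keys")
    val e = scores.map { exp(it - maxScore) }
    val total = e.sum()
    return e.map { it / total }.toDoubleArray()
}

val evenKeys = DoubleArray(11) { it * 0.5 } // 0.0, 0.5, ..., 5.0
val flat = parametricWeights(2.5, evenKeys, w = 1.0)
val sharp = parametricWeights(2.5, evenKeys, w = 4.0)
// The key at index 5 equals the query, so it gets the largest weight in both cases,
// and that weight grows with |w|
```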


### Batch Matrix Multiplication
:label:`subsec_batch_dot`

To more efficiently compute attention
for minibatches,
we can leverage batch matrix multiplication utilities
provided by deep learning frameworks.


Suppose that the first minibatch contains $n$ matrices $\mathbf{X}_1, \ldots, \mathbf{X}_n$ of shape $a\times b$, and the second minibatch contains $n$ matrices $\mathbf{Y}_1, \ldots, \mathbf{Y}_n$ of shape $b\times c$. Their batch matrix multiplication
results in
$n$ matrices $\mathbf{X}_1\mathbf{Y}_1, \ldots, \mathbf{X}_n\mathbf{Y}_n$ of shape $a\times c$. Therefore, given two tensors of shape ($n$, $a$, $b$) and ($n$, $b$, $c$), the shape of their batch matrix multiplication output is ($n$, $a$, $c$).
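The shape rule can be spelled out with a naive loop. The sketch below uses nested Kotlin arrays purely for illustration. The name `batchMatMul` is hypothetical, and in practice one calls the framework's `matMul`, as in the cell that follows. The sketch multiplies the $k$-th matrix of the first batch by the $k$-th matrix of the second.

```kotlin
/**
 * Naive batch matrix multiplication: x has shape (n, a, b), y has shape (n, b, c),
 * and the result has shape (n, a, c) with result[k] = x[k] * y[k].
 */
fun batchMatMul(x: Array<Array<DoubleArray>>, y: Array<Array<DoubleArray>>): Array<Array<DoubleArray>> {
    require(x.size == y.size) { "batch sizes differ" }
    return Array(x.size) { k ->
        val xk = x[k]
        val yk = y[k]
        val b = yk.size
        require(xk.all { it.size == b }) { "inner dimensions differ in batch $k" }
        val c = if (b == 0) 0 else yk[0].size
        Array(xk.size) { i ->
            DoubleArray(c) { j ->
                var acc = 0.0
                for (t in 0 until b) acc += xk[i][t] * yk[t][j]
                acc
            }
        }
    }
}

/** A tensor of ones with shape (n, rows, cols). */
fun ones3d(n: Int, rows: Int, cols: Int): Array<Array<DoubleArray>> =
    Array(n) { Array(rows) { DoubleArray(cols) { 1.0 } } }

// Same shapes as the DJL example: (2, 1, 4) times (2, 4, 6) gives (2, 1, 6)
val bmmOut = batchMatMul(ones3d(2, 1, 4), ones3d(2, 4, 6))
```

Every entry of `bmmOut` is a sum of four products of ones, so each entry equals `4.0`.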


```kotlin
val X = manager.ones(Shape(2, 1, 4))
val Y = manager.ones(Shape(2, 4, 6))

X.matMul(Y).shape
```


(2, 1, 6)

In the context of attention mechanisms, we can use minibatch matrix multiplication to compute weighted averages of values in a minibatch.


```kotlin
val weights = manager.ones(Shape(2, 10)).mul(0.1)
val values = manager.arange(20f).reshape(Shape(2, 10))
weights.expandDims(1).matMul(values.expandDims(-1))
```


ND: (2, 1, 1) gpu(0) float32
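Each of the two batch elements above multiplies a $1\times 10$ row of weights by a $10\times 1$ column of values, which reduces to a dot product. A plain-Kotlin sketch of that reduction (with hypothetical names) suggests what the two entries of the $(2, 1, 1)$ result should be. The first is $0.1 \cdot (0 + \cdots + 9) = 4.5$ and the second is $0.1 \cdot (10 + \cdots + 19) = 14.5$, up to floating-point rounding.

```kotlin
/**
 * What each batch element of the product computes: a (1, m) row of weights times an
 * (m, 1) column of values is sum_i weights[i] * values[i], i.e. a weighted average
 * whenever the weights sum to one.
 */
fun weightedAverages(weights: Array<DoubleArray>, values: Array<DoubleArray>): DoubleArray =
    DoubleArray(weights.size) { k ->
        require(weights[k].size == values[k].size) { "length mismatch in batch $k" }
        weights[k].indices.sumOf { i -> weights[k][i] * values[k][i] }
    }

// Same inputs as the DJL cell: uniform weights 0.1 and the values 0..19 split into two rows
val uniformW = Array(2) { DoubleArray(10) { 0.1 } }
val rowsOfValues = Array(2) { k -> DoubleArray(10) { i -> (10 * k + i).toDouble() } }
val averages = weightedAverages(uniformW, rowsOfValues) // approximately [4.5, 14.5]
```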


### Defining the Model

Using minibatch matrix multiplication,
below we define the parametric version
of Nadaraya-Watson kernel regression
based on the parametric attention pooling in
:eqref:`eq_nadaraya-watson-gaussian-para`.


```kotlin
class NWKernelRegression : AbstractBlock() {
    private val w: Parameter
    var attentionWeights: NDArray? = null

    init {
        w = Parameter.builder()
            .optShape(Shape(1))
            .setName("w")
            .optInitializer(UniformInitializer())
            .build()
        addParameter(w)
    }

    override fun forwardInternal(
        parameterStore: ParameterStore,
        inputs: NDList,
        training: Boolean,
        params: PairList<String, Any>?
    ): NDList {
        // Shape of the output `queries` and `attentionWeights`:
        // (no. of queries, no. of key-value pairs)
        var queries = inputs[0]
        val keys = inputs[1]
        val values = inputs[2]
        queries = queries.repeat(keys.shape[1]).reshape(Shape(-1, keys.shape[1]))

        attentionWeights = queries.sub(keys).mul(w.array).pow(2).div(2).mul(-1).softmax(-1)
        // Shape of `values`: (no. of queries, no. of key-value pairs)
        return NDList(attentionWeights!!.mul(values).sum(intArrayOf(1)))
    }

    override fun getOutputShapes(inputShapes: Array<Shape>): Array<Shape> {
        throw UnsupportedOperationException("Not implemented")
    }
}
```


### Training

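In the following, we transform the training dataset
into keys and values to train the attention model.
In the parametric attention pooling,
each training input serves as a query over the key-value pairs of all the other training examples, excluding its own, to predict its output.
Excluding the query's own pair matters. That key sits at distance zero from the query, so otherwise the model could drive the training loss toward zero simply by letting $w$ grow until each input copies its own noisy output.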
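The leave-one-out construction can be sketched with plain Kotlin collections. The helper `leaveOneOut` and the three toy inputs are hypothetical. The sketch produces the same $(n, n-1)$ layout that the boolean-mask indexing in the DJL cell that follows builds for `keys` and `valuesKV`.

```kotlin
/**
 * Leave-one-out matrix: row i holds every entry of xs except entry i,
 * so the result has shape (n, n - 1).
 */
fun leaveOneOut(xs: DoubleArray): Array<DoubleArray> =
    Array(xs.size) { i ->
        xs.filterIndexed { j, _ -> j != i }.toDoubleArray()
    }

val looKeys = leaveOneOut(doubleArrayOf(0.1, 0.4, 0.9))
// looKeys = [[0.4, 0.9], [0.1, 0.9], [0.1, 0.4]]
```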


```kotlin
// Shape of `xTile`: (`nTrain`, `nTrain`), where each column contains the
// same training inputs
val xTile = xTrain.tile(longArrayOf(nTrain.toLong(), 1L))
// Shape of `yTile`: (`nTrain`, `nTrain`), where each column contains the
// same training outputs
val yTile = yTrain.tile(longArrayOf(nTrain.toLong(), 1L))
// Boolean mask that is false on the diagonal, so row i drops example i
val offDiagonal = manager.eye(nTrain).mul(-1).add(1).toType(DataType.BOOLEAN, false)
// Shape of `keys`: (`nTrain`, `nTrain` - 1)
var keys = xTile.get(NDIndex().addBooleanIndex(offDiagonal)).reshape(Shape(nTrain.toLong(), -1))
// Shape of `valuesKV`: (`nTrain`, `nTrain` - 1)
var valuesKV = yTile.get(NDIndex().addBooleanIndex(offDiagonal)).reshape(Shape(nTrain.toLong(), -1))
```


Using the squared loss and stochastic gradient descent,
we train the parametric attention model.
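The following is a rough sketch of the optimization without DJL, under stated assumptions. It computes the mean of $\frac{1}{2}(f(x_i) - y_i)^2$ over leave-one-out queries, which is the averaged form that DJL's `Loss.l2Loss()` computes by default. However, it swaps autograd and SGD for a central finite-difference gradient with a simple step-acceptance rule, and it runs on hypothetical noise-free toy data. The names `nwPredict`, `looLoss`, and `trainW` are illustrative only.

```kotlin
import kotlin.math.exp
import kotlin.math.pow
import kotlin.math.sin

/** Parametric Nadaraya-Watson prediction for one query with a scalar parameter w. */
fun nwPredict(query: Double, keys: DoubleArray, values: DoubleArray, w: Double): Double {
    val scores = keys.map { key ->
        val d = (query - key) * w
        -d * d / 2
    }
    val maxScore = scores.maxOrNull() ?: error("no keys")
    val e = scores.map { exp(it - maxScore) }
    val total = e.sum()
    return e.indices.sumOf { i -> e[i] / total * values[i] }
}

/** Mean of 0.5 * (f(x_i) - y_i)^2, where each x_i attends to all other training pairs. */
fun looLoss(xs: DoubleArray, ys: DoubleArray, w: Double): Double =
    xs.indices.map { i ->
        val keys = xs.filterIndexed { j, _ -> j != i }.toDoubleArray()
        val values = ys.filterIndexed { j, _ -> j != i }.toDoubleArray()
        val err = nwPredict(xs[i], keys, values, w) - ys[i]
        0.5 * err * err
    }.average()

/** Gradient descent on w with a finite-difference gradient; a step is kept only if it lowers the loss. */
fun trainW(xs: DoubleArray, ys: DoubleArray, w0: Double, lr0: Double, steps: Int): Double {
    var w = w0
    var lr = lr0
    val h = 1e-5
    for (iteration in 0 until steps) {
        val grad = (looLoss(xs, ys, w + h) - looLoss(xs, ys, w - h)) / (2 * h)
        val candidate = w - lr * grad
        if (looLoss(xs, ys, candidate) < looLoss(xs, ys, w)) {
            w = candidate
        } else {
            lr /= 2 // back off instead of accepting a step that increases the loss
        }
    }
    return w
}

// Hypothetical noise-free toy data from the same function f as above
val toyXs = DoubleArray(20) { it * 0.25 }
val toyYs = DoubleArray(20) { i -> 2 * sin(toyXs[i]) + toyXs[i].pow(0.8) }
val wTrained = trainW(toyXs, toyYs, w0 = 0.5, lr0 = 1.0, steps = 50)
```

Because a step is kept only when it lowers the loss, the loss never increases in this sketch. With a fixed learning rate, as in the DJL cell, that guarantee does not hold in general.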


```kotlin
val net = NWKernelRegression()
val loss = Loss.l2Loss()
// DJL's L2 loss averages over the examples, so the learning rate is scaled by
// `nTrain` to take steps comparable to SGD with rate 0.5 on the summed loss
val lrt = Tracker.fixed(0.5f * nTrain)
val sgd = Optimizer.sgd().setLearningRateTracker(lrt).build()
val config = DefaultTrainingConfig(loss).optOptimizer(sgd) // Squared loss with SGD
val model = Model.newInstance("")
model.block = net

val trainer: Trainer = model.newTrainer(config)
val ps = ParameterStore(manager, false)

for (epoch in 0 until 5) {
    trainer.newGradientCollector().use { gc ->
        // Each training input is a query over its leave-one-out keys and values
        val result = net.forward(ps, NDList(xTrain, keys, valuesKV), true).singletonOrThrow()
        val l = trainer.getLoss().evaluate(NDList(yTrain), NDList(result))
        gc.backward(l)
        println("epoch ${epoch + 1}, loss ${l.getFloat()}")
    }
    trainer.step()
}
```


After training the parametric attention model,
we can plot its prediction.
Because the model tries to fit the noisy training data,
the predicted line is less smooth
than its nonparametric counterpart plotted earlier.


```kotlin
// Shape of `keys`: (`nTest`, `nTrain`), where each column contains the same
// training inputs (i.e., same keys)
keys = xTrain.tile(longArrayOf(nTest.toLong(), 1L))

// Shape of `valuesKV`: (`nTest`, `nTrain`)
valuesKV = yTrain.tile(longArrayOf(nTest.toLong(), 1L))
// Prediction only, so run the forward pass with `training = false`
yHat = net.forward(ps, NDList(xTest, keys, valuesKV), false).singletonOrThrow()
plot(yHat, "Truth", "Pred", "x", "y", 700, 500)
```


Compared with nonparametric attention pooling,
the region with large attention weights becomes sharper
in the learnable, parametric setting.


```kotlin
plotHeatmap(net.attentionWeights!!, 500, 700)
```


## Summary

* Nadaraya-Watson kernel regression is an example of machine learning with attention mechanisms.
* The attention pooling of Nadaraya-Watson kernel regression is a weighted average of the training outputs. From the attention perspective, the attention weight is assigned to a value based on a function of a query and the key that is paired with the value.
* Attention pooling can be either nonparametric or parametric.


## Exercises

1. Increase the number of training examples. Can you learn nonparametric Nadaraya-Watson kernel regression better?
1. What is the value of our learned $w$ in the parametric attention pooling experiment? Why does it make the weighted region sharper when visualizing the attention weights?
1. How can we add hyperparameters to nonparametric Nadaraya-Watson kernel regression to predict better?
1. Design another parametric attention pooling for the kernel regression of this section. Train this new model and visualize its attention weights.
