# understand attention mechanism

In [1]:
#pragma cling add_include_path("../../../libtorch/include")
#pragma cling add_include_path("../../../libtorch/include/torch/csrc/api/include")
#pragma cling add_library_path("../../../libtorch/lib")
#pragma cling load("libtorch")

In [2]:
#include <iostream>
#include <tuple>
#include <string>
#include <vector>
#include <cmath>
#include <torch/torch.h>
#include <torch/script.h>
namespace nn = torch::nn;
namespace F = torch::nn::functional;

In [3]:
using torch::indexing::Slice;
using torch::indexing::None;

The interactions between queries (volitional cues) and keys (nonvolitional cues) result in attention pooling. The attention pooling selectively aggregates values (sensory inputs) to produce the output. In this section, we will describe attention pooling in greater detail to give you a high-level view of how attention mechanisms work in practice. Specifically, the Nadaraya-Watson kernel regression model proposed in 1964 is a simple yet complete example for demonstrating machine learning with attention mechanisms.

# 1 Intuition of attention mechanisms
## generate mock dataset

In [4]:
// Noise-free target function used to build the mock dataset:
// f(x) = 2 * sin(x) + x^0.8, evaluated element-wise.
torch::Tensor f(torch::Tensor x){
    torch::Tensor sine_term = 2 * torch::sin(x);
    torch::Tensor power_term = x.pow(0.8);
    return sine_term + power_term;
}

In [5]:
// Number of training examples.
// NOTE(review): no RNG seed is set before sampling, so the data (and every
// output recorded further down) changes on each run; consider calling
// torch::manual_seed in a configuration cell.
int n_train = 50;
// Training inputs: uniform samples scaled by 5, then sorted in ascending order
// (element 0 of the tuple returned by torch::sort holds the sorted values).
torch::Tensor x_train = torch::rand(n_train) * 5;
std::tuple<torch::Tensor, torch::Tensor> sorted_tensor = torch::sort(x_train);
x_train = std::get<0>(sorted_tensor);
    
// Training outputs: f(x) plus Gaussian noise with mean 0.0 and std 0.5.
torch::Tensor y_train = f(x_train) + torch::normal(0.0, 0.5, {n_train,});
// Test inputs on a 0.5-spaced grid over [0, 5) and their noise-free targets.
torch::Tensor x_test = torch::arange(0, 5, 0.5);
torch::Tensor y_truth = f(x_test);

// Number of test examples (length of x_test along dim 0).
int n_test = x_test.size(0);

In [6]:
// Show how many test inputs the 0.5-spaced grid produced.
std::cout << n_test;

10

In [8]:
// Print the sorted training inputs as a single row: indexing with None
// prepends an axis of size 1, giving shape (1, n_train).
std::cout << x_train.index({torch::indexing::None, Slice()});

Columns 1 to 10 0.0024  0.0126  0.1656  0.3623  0.4831  0.5598  0.5641  0.6235  0.6449  0.9649

Columns 11 to 20 1.0581  1.0602  1.1115  1.2249  1.2952  1.4877  1.5125  1.6455  1.7756  1.7802

Columns 21 to 30 2.1077  2.1463  2.1602  2.1610  2.4413  2.5440  2.6271  2.6898  2.7903  2.7922

Columns 31 to 40 2.8925  3.1440  3.1508  3.2526  3.3495  3.3830  3.6988  4.0370  4.1022  4.1065

Columns 41 to 50 4.1401  4.1482  4.2658  4.2835  4.3629  4.5923  4.6302  4.7272  4.8369  4.8993
[ CPUFloatType{1,50} ]

## 1.1 Average Pooling

In [9]:
// Average pooling baseline: predict the mean of the training outputs for
// every one of the n_test test inputs, ignoring the inputs entirely.
torch::Tensor y_hat = y_train.mean(0).repeat(n_test);

In [10]:
// Every entry is identical because the baseline does not depend on x_test.
std::cout << y_hat << std::endl;

 2.2468
 2.2468
 2.2468
 2.2468
 2.2468
 2.2468
 2.2468
 2.2468
 2.2468
 2.2468
[ CPUFloatType{10} ]


## 1.2 Nonparametric Attention Pooling

In [11]:
// Pairwise differences between (flattened) queries and keys.
// Entry (i, j) of the result is queries[i] - keys[j]; the result has shape
// (number of query elements, number of key elements).
torch::Tensor diff(torch::Tensor queries, torch::Tensor keys){
    torch::Tensor query_column = queries.reshape({-1, 1});
    torch::Tensor key_row = keys.reshape({1, -1});
    // Broadcasting a column against a row yields the full difference matrix.
    return query_column - key_row;
}

In [12]:
// Gaussian-kernel attention scores: -(q - k)^2 / 2 for every query/key pair.
// Returns a matrix with one row per query and one column per key.
torch::Tensor score_function(torch::Tensor queries,torch::Tensor keys){
    return - diff(queries, keys).pow(2) / 2;
}

In [13]:
// Attention pooling: softmax-normalise the scores along dim 1 (the key axis)
// and use the resulting weights to form a weighted sum of `values`.
// Returns (pooled output, attention weights).
std::tuple<torch::Tensor, torch::Tensor> attention_pool(torch::Tensor scores, torch::Tensor values){
    torch::Tensor weights = F::softmax(scores, /*dim=*/1);
    torch::Tensor pooled = torch::matmul(weights, values);
    return std::make_tuple(pooled, weights);
}

In [14]:
// Nonparametric attention pooling: test inputs are the queries, training
// inputs are the keys, and training outputs are the values.
// The statement is terminated with ';' so it is valid C++ and does not rely
// on the interpreter echoing the value of an unterminated declaration.
std::tuple<torch::Tensor, torch::Tensor> output_with_weight = attention_pool(score_function(x_test, x_train), y_train);

In [15]:
// Unpack the pooled predictions (element 0) and the attention-weight matrix
// (element 1) returned by attention_pool, then print both.
torch::Tensor output = std::get<0>(output_with_weight);
torch::Tensor attention_weights = std::get<1>(output_with_weight);

std::cout << output << std::endl;
std::cout << attention_weights << std::endl;

 1.7523
 2.0821
 2.3985
 2.6507
 2.7799
 2.7485
 2.5707
 2.3168
 2.0766
 1.9056
[ CPUFloatType{10} ]
Columns 1 to 6 7.5054e-02  7.5048e-02  7.4032e-02  7.0286e-02  6.6786e-02  6.4171e-02
 4.8873e-02  4.9121e-02  5.2307e-02  5.4793e-02  5.5307e-02  5.5216e-02
 2.8172e-02  2.8461e-02  3.2715e-02  3.7813e-02  4.0545e-02  4.2059e-02
 1.3874e-02  1.4088e-02  1.7481e-02  2.2294e-02  2.5393e-02  2.7370e-02
 5.6367e-03  5.7535e-03  7.7063e-03  1.0844e-02  1.3120e-02  1.4694e-02
 1.8407e-03  1.8886e-03  2.7306e-03  4.2394e-03  5.4489e-03  6.3408e-03
 4.7699e-04  4.9192e-04  7.6775e-04  1.3152e-03  1.7957e-03  2.1713e-03
 9.8390e-05  1.0199e-04  1.7183e-04  3.2479e-04  4.7106e-04  5.9182e-04
 1.6535e-05  1.7229e-05  3.1333e-05  6.5347e-05  1.0068e-04  1.3143e-04
 2.3583e-06  2.4700e-06  4.8489e-06  1.1158e-05  1.8262e-05  2.4770e-05

Columns 7 to 12 6.4014e-02  6.1796e-02  6.0962e-02  4.7120e-02  4.2879e-02  4.2784e-02
 5.5202e-02  5.4895e-02  5.4737e-02  4.9649e-02  4.7337e-02  4.7281e-02
 4.21

## 1.3 Parametric Attention Pooling

In [16]:
class NWKernelRegressionImpl : public torch::nn::Module {
public:
    NWKernelRegressionImpl() {
        //w = torch::nn::Parameter(torch::rand({1}, torch::requires_grad(true)));
        w = torch::rand({1}, torch::requires_grad(true));
        register_parameter("w", w);
    }

    torch::Tensor forward(torch::Tensor queries, torch::Tensor keys, torch::Tensor values) {
        // Shape of the output `queries` and `attention_weights`:
        // (no. of queries, no. of key-value pairs)
        queries = queries.repeat_interleave(keys.size(1)).reshape({-1, keys.size(1)});
        auto attention_weights = F::softmax(-1*torch::pow(((queries - keys)*w), 2) / 2, /*dim=*/1);
        // attention_weights = rorch::nn::functional::softmax(-1 * torch::pow((queries - keys)* w.item<float>(), 2) / 2, /*dim=*/1);
        // Shape of `values`: (no. of queries, no. of key-value pairs)
        return torch::bmm(attention_weights.unsqueeze(1),
                         values.unsqueeze(-1)).reshape(-1);
    }
public:
    torch::Tensor w; // w parameter
};
TORCH_MODULE(NWKernelRegression);

In [17]:
// Leave-one-out training data: example i attends to every training pair
// except its own.

// Shape of `X_tile`: (`n_train`, `n_train`); every row is a copy of `x_train`.
torch::Tensor X_tile = x_train.repeat({n_train, 1});

// Shape of `Y_tile`: (`n_train`, `n_train`); every row is a copy of `y_train`.
torch::Tensor Y_tile = y_train.repeat({n_train, 1});

// Boolean mask that is false on the diagonal and true everywhere else.
torch::Tensor off_diagonal = (1 - torch::eye(n_train)).to(torch::kBool);

// Shape of `keys`: (`n_train`, `n_train` - 1)
torch::Tensor keys = torch::masked_select(X_tile, off_diagonal).reshape({n_train, -1});
std::cout << keys.sizes() << "\n";

// Shape of `values`: (`n_train`, `n_train` - 1)
torch::Tensor values = torch::masked_select(Y_tile, off_diagonal).reshape({n_train, -1});
std::cout << values.sizes() << "\n";

[50, 49]
[50, 49]


In [18]:
// Using the squared loss and stochastic gradient descent, we [train the parametric attention model].
NWKernelRegression net = NWKernelRegression();
torch::nn::MSELoss loss = torch::nn::MSELoss(torch::nn::MSELossOptions(torch::kNone));
torch::optim::SGD optimizer = torch::optim::SGD(net->parameters(), 0.5);

std::vector<float> v_epoch, v_loss;

for( int epoch = 0; epoch < 5; epoch++ ) {
    optimizer.zero_grad();
    auto l = loss(net->forward(x_train, keys, values), y_train);
    l.sum().backward();
    optimizer.step();
    std::cout << "epoch: " << (epoch + 1) << ", loss: " << l.sum().item<float>() << std::endl;
    v_epoch.push_back((epoch + 1)*1.0);
    v_loss.push_back(l.sum().item<float>());
}

epoch: 1, loss: 54.1917
epoch: 2, loss: 54.0975
epoch: 3, loss: 54.1435
epoch: 4, loss: 53.8769
epoch: 5, loss: 14.2772


In [19]:
// Inspect the learned kernel scale after training.
std::cout << net->w << std::endl;

 4.3752
[ CPUFloatType{1} ]


# 2 general framework of attention mechanisms - Attention Scoring Functions
In Section 1, we used a Gaussian kernel to model interactions between queries and keys. Treating the exponent of the Gaussian kernel in (11.2.6) as an attention scoring function (or scoring function for short), the results of this function were essentially fed into a softmax operation. As a result, we obtained a probability distribution (attention weights) over values that are paired with keys. In the end, the output of the attention pooling is simply a weighted sum of the values based on these attention weights.

At a high level, we can use the above algorithm to instantiate the framework of attention mechanisms in Fig. 11.1.3. Denoting an attention scoring function by $a$, Fig. 11.3.1 illustrates how the output of attention pooling can be computed as a weighted sum of values. Since attention weights are a probability distribution, the weighted sum is essentially a weighted average.

![](https://d2l.ai/_images/attention-output.svg)

### elementary knowledge - Masked Softmax Operation
As we just mentioned, a softmax operation is used to output a probability distribution as attention weights. In some cases, not all the values should be fed into attention pooling. For instance, for efficient minibatch processing in Section 10.5, some text sequences are padded with special tokens that do not carry meaning. To get an attention pooling over only meaningful tokens as values, we can specify a valid sequence length (in number of tokens) to filter out those beyond this specified range when computing softmax. In this way, we can implement such a masked softmax operation in the following masked_softmax function, where any value beyond the valid length is masked as zero.

In [20]:
// Mask irrelevant entries in sequences.
//
// X:         2-D tensor; the code reads X.size(1) as the maximum length.
// valid_len: 1-D tensor with one valid length per row of X.
// value:     fill value written to every position at or beyond the valid
//            length of its row.
//
// Returns a new tensor with the masked positions replaced. The input `X` is
// left untouched: the earlier in-place write also modified the caller's
// tensor, including the original behind a reshaped view.
torch::Tensor sequence_mask(torch::Tensor X, torch::Tensor  valid_len, float value) {
    const int64_t maxlen = X.size(1);

    // Position indices 0..maxlen-1 as a row, compared against the valid
    // lengths as a column: keep(i, j) is true while j < valid_len[i].
    auto positions = torch::arange(maxlen, torch::TensorOptions().dtype(torch::kFloat32).device(X.device()));
    auto keep = positions.index({torch::indexing::None, Slice()}) < valid_len.index({Slice(), torch::indexing::None});

    // Out-of-place fill of everything that is not kept.
    return X.masked_fill(keep.logical_not(), value);
}

In [21]:
// Softmax over the last axis that ignores positions beyond a valid length.
//
// X:          tensor with at least 2 dimensions; the last axis is normalised.
// valid_lens: undefined or empty tensor -> plain softmax;
//             1-D tensor  -> one length per batch item, shared by all rows
//                            of that item;
//             otherwise   -> flattened and used as one length per row.
//
// Returns a tensor with the same shape as `X` whose masked positions receive
// (numerically) zero weight.
torch::Tensor masked_softmax(torch::Tensor X, torch::Tensor valid_lens) {
    if( ! valid_lens.defined() || (valid_lens.numel() == 0) ) {
        return F::softmax(X, /*dim=*/-1);
    }

    TORCH_CHECK(X.dim() >= 2, "masked_softmax expects X with at least 2 dimensions, got ", X.dim());

    // Take an owning copy of the shape: `X.sizes()` is a non-owning view into
    // X's metadata, and `X` is reassigned below before the shape is reused.
    const std::vector<int64_t> shape = X.sizes().vec();
    const int64_t rows_per_item = shape[shape.size() - 2];
    const int64_t last_dim = shape[shape.size() - 1];

    if( valid_lens.dim() == 1) {
        // Repeat each per-item length once for every row of that item.
        valid_lens = torch::repeat_interleave(valid_lens, rows_per_item);
    } else {
        valid_lens = valid_lens.reshape({-1});
    }

    // On the last axis, replace masked elements with a very large negative
    // value, whose exponentiation outputs 0.
    X = sequence_mask(X.reshape({-1, last_dim}), valid_lens, /*value=*/ -1e6);

    return F::softmax(X.reshape(shape), /*dim=*/-1);
}

In [22]:
// 1-D valid lengths: one length per batch item, applied to every row of that
// item (2 valid columns for the first item, 3 for the second).
torch::Tensor masked_attention_weights = masked_softmax(torch::rand({2, 2, 4}), torch::tensor({2, 3}));
std::cout << masked_attention_weights << std::endl;

(1,.,.) = 
  0.6018  0.3982  0.0000  0.0000
  0.4276  0.5724  0.0000  0.0000

(2,.,.) = 
  0.3252  0.2519  0.4229  0.0000
  0.3284  0.2495  0.4220  0.0000
[ CPUFloatType{2,2,4} ]


In [23]:
// 2-D valid lengths: a separate length for every (batch item, row) pair.
masked_attention_weights = masked_softmax(torch::rand({2, 2, 4}), torch::tensor({{1, 3}, {2, 4}}));
std::cout << masked_attention_weights << std::endl;

(1,.,.) = 
  1.0000  0.0000  0.0000  0.0000
  0.3829  0.2397  0.3774  0.0000

(2,.,.) = 
  0.4128  0.5872  0.0000  0.0000
  0.2063  0.2214  0.2841  0.2882
[ CPUFloatType{2,2,4} ]


## 2.1 Scaled Dot-Product Attention
A more computationally efficient design for the scoring function can be simply dot product. However, the dot product operation requires that both the query and the key have the same vector length, say $d$. Assume that all the elements of the query and the key are independent random variables with zero mean and unit variance. The dot product of both vectors has zero mean and a variance of $d$. To ensure that the variance of the dot product still remains one regardless of vector length, the scaled dot-product attention scoring function divides the dot product by $\sqrt{d}$:

$$a(\mathbf q, \mathbf k) = \mathbf{q}^\top \mathbf{k} / \sqrt{d}.$$

In [24]:
class DotProductAttentionImpl : public torch::nn::Module {
    public:
    // Scaled dot product attention.
	DotProductAttentionImpl(float dropout) {
        dpout = torch::nn::Dropout(dropout);
        register_module("dpout", dpout);
	}

    // Shape of `queries`: (`batch_size`, no. of queries, `d`)
    // Shape of `keys`: (`batch_size`, no. of key-value pairs, `d`)
    // Shape of `values`: (`batch_size`, no. of key-value pairs, value dimension)
    // Shape of `valid_lens`: (`batch_size`,) or (`batch_size`, no. of queries)
    torch::Tensor forward(torch::Tensor queries, torch::Tensor keys, torch::Tensor values, torch::Tensor valid_lens) {
        int n_shape = (queries.sizes()).size();
    	auto d = queries.sizes()[n_shape - 1];
        // Set `transpose_b=True` to swap the last two dimensions of `keys`
        auto scores = torch::bmm(queries, keys.transpose(1, 2)) / std::sqrt(d);
        attention_weights = masked_softmax(scores, valid_lens);
        return torch::bmm(dpout(attention_weights), values);
    }
    torch::nn::Dropout dpout{nullptr};
    torch::Tensor attention_weights;
};
TORCH_MODULE(DotProductAttention);

In [25]:
// Demo inputs: batch of 2, one query per item, 10 key-value pairs, d = 2.
// NOTE(review): `keys` and `values` were already declared in the
// leave-one-out cell of Section 1.3; re-declaring them here relies on the
// kernel permitting redefinition of globals — confirm on a fresh run.
torch::Tensor queries = torch::normal(0, 1, {2, 1, 2});
torch::Tensor keys = torch::ones({2, 10, 2});

// The two value matrices in the values minibatch are identical
torch::Tensor values = torch::arange(40, torch::TensorOptions().dtype(torch::kFloat32)).reshape({1, 10, 4}).repeat({2, 1, 1});

// Only the first 2 (item 1) and first 6 (item 2) key-value pairs are valid.
torch::Tensor valid_lens = torch::tensor({2, 6});

In [26]:
// Randomly drawn queries, shape (2, 1, 2).
std::cout << queries << std::endl;

(1,.,.) = 
  0.8010 -0.6466

(2,.,.) = 
  1.2837  0.0655
[ CPUFloatType{2,1,2} ]


In [27]:
// All keys are identical (all ones), so every valid key gets the same score.
std::cout << keys << std::endl;

(1,.,.) = 
  1  1
  1  1
  1  1
  1  1
  1  1
  1  1
  1  1
  1  1
  1  1
  1  1

(2,.,.) = 
  1  1
  1  1
  1  1
  1  1
  1  1
  1  1
  1  1
  1  1
  1  1
  1  1
[ CPUFloatType{2,10,2} ]


In [28]:
// Values 0..39 laid out as 10 rows of 4, repeated for both batch items.
std::cout << values << std::endl;

(1,.,.) = 
   0   1   2   3
   4   5   6   7
   8   9  10  11
  12  13  14  15
  16  17  18  19
  20  21  22  23
  24  25  26  27
  28  29  30  31
  32  33  34  35
  36  37  38  39

(2,.,.) = 
   0   1   2   3
   4   5   6   7
   8   9  10  11
  12  13  14  15
  16  17  18  19
  20  21  22  23
  24  25  26  27
  28  29  30  31
  32  33  34  35
  36  37  38  39
[ CPUFloatType{2,10,4} ]


In [29]:
// Build the attention layer with dropout probability 0.5, then switch it to
// evaluation mode before running the forward pass.
auto dattention = DotProductAttention(0.5);
dattention->eval();
// NOTE(review): `output` was already declared in the nonparametric-pooling
// cell of Section 1.2; this re-declaration relies on the kernel permitting
// redefinition of globals — confirm on a fresh run.
torch::Tensor output = dattention->forward(queries, keys, values, valid_lens);

In [30]:
// With identical keys the weights are uniform over the valid positions, so
// each output row is the mean of the valid value rows.
std::cout << output << std::endl;
std::cout << output.sizes() << std::endl;

(1,.,.) = 
  2  3  4  5

(2,.,.) = 
  10.0000  11.0000  12.0000  13.0000
[ CPUFloatType{2,1,4} ]
[2, 1, 4]


In [31]:
// Weights stored by the layer during the forward call above.
std::cout << dattention->attention_weights << std::endl;
std::cout << dattention->attention_weights.sizes() << std::endl;

(1,.,.) = 
 Columns 1 to 9  0.5000  0.5000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000

Columns 10 to 10  0.0000

(2,.,.) = 
 Columns 1 to 9  0.1667  0.1667  0.1667  0.1667  0.1667  0.1667  0.0000  0.0000  0.0000

Columns 10 to 10  0.0000
[ CPUFloatType{2,1,10} ]
[2, 1, 10]


## 2.2 Additive Attention
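In general, when queries and keys are vectors of different lengths, we can use additive attention as the scoring function. Given a query $\mathbf{q} \in \mathbb{R}^q$ and a key $\mathbf{k} \in \mathbb{R}^k$, the additive attention scoring function is

$$a(\mathbf q, \mathbf k) = \mathbf w_v^\top \text{tanh}(\mathbf W_q\mathbf q + \mathbf W_k \mathbf k) \in \mathbb{R},$$

where $\mathbf W_q \in \mathbb{R}^{h\times q}$, $\mathbf W_k \in \mathbb{R}^{h\times k}$, and $\mathbf w_v \in \mathbb{R}^{h}$ are learnable parameters and $h$ is the number of hidden units.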

In [32]:
struct AdditiveAttentionImpl : public torch::nn::Module {
    public:
    //Additive attention
    AdditiveAttentionImpl(int64_t key_size, int64_t query_size, int64_t num_hiddens, float dropout) {
        W_k = torch::nn::Linear(torch::nn::LinearOptions(key_size, num_hiddens).bias(false));
        W_q = torch::nn::Linear(torch::nn::LinearOptions(query_size, num_hiddens).bias(false));
        W_v = torch::nn::Linear(torch::nn::LinearOptions(num_hiddens, 1).bias(false));
        dpout = torch::nn::Dropout(dropout);
        register_module("W_k", W_k);
        register_module("W_q", W_q);
        register_module("W_v", W_v);
        register_module("dpout", dpout);
    }

    torch::Tensor forward(torch::Tensor queries, torch::Tensor keys, torch::Tensor values, torch::Tensor valid_lens) {
        queries = W_q->forward(queries);
        keys = W_k->forward(keys);
        // After dimension expansion, shape of `queries`: (`batch_size`, no. of
        // queries, 1, `num_hiddens`) and shape of `keys`: (`batch_size`, 1,
        // no. of key-value pairs, `num_hiddens`). Sum them up with broadcasting

        //std::cout << "queries: " << queries.sizes() << "\n";
        //std::cout << "keys: " << keys.sizes() << "\n";
        auto features = queries.unsqueeze(2) + keys.unsqueeze(1);
        features = torch::tanh(features);

        //std::cout << "features: " << features.sizes() << "\n";
        // There is only one output of `self.w_v`, so we remove the last
        // one-dimensional entry from the shape. Shape of `scores`:
        // (`batch_size`, no. of queries, no. of key-value pairs)
        auto scores = W_v->forward(features).squeeze(-1); //squeeze() 不加参数的，把所有为1的维度都压缩
        //std::cout << "scores: " << scores.sizes() << "\n";
        //std::cout << "valid_lens: " << valid_lens.numel() << "\n";

        attention_weights = masked_softmax(scores, valid_lens);
        //std::cout << "attention_weights: " << attention_weights.sizes() << "\n";

        // Shape of `values`: (`batch_size`, no. of key-value pairs, value dimension)
        return torch::bmm(dpout->forward(attention_weights), values);
    }
    torch::nn::Linear W_k{nullptr}, W_q{nullptr}, W_v{nullptr};
    torch::nn::Dropout dpout{nullptr};
    torch::Tensor attention_weights;
};

TORCH_MODULE(AdditiveAttention);

In [33]:
// Queries now have feature size 20 while keys keep feature size 2, so the
// query and key lengths differ.
queries = torch::normal(0, 1, {2, 1, 20});
keys = torch::ones({2, 10, 2});

// The two value matrices in the values minibatch are identical
values = torch::arange(40, torch::TensorOptions().dtype(torch::kFloat32)).reshape({1, 10, 4}).repeat({2, 1, 1});
valid_lens = torch::tensor({2, 6});

In [34]:
// Arguments: key_size = 2, query_size = 20, num_hiddens = 8, dropout = 0.1.
auto attention = AdditiveAttention(2, 20, 8, 0.1);

In [35]:
// Switch to evaluation mode before running the forward pass.
attention->eval();
// NOTE(review): `output` is declared again here (it already exists from
// earlier cells); this relies on the kernel permitting redefinition of
// globals — confirm on a fresh run.
auto output = attention->forward(queries, keys, values, valid_lens);

In [36]:
// Identical keys give uniform weights over the valid positions, so the
// result matches the dot-product attention example.
std::cout << output << std::endl;

(1,.,.) = 
  2  3  4  5

(2,.,.) = 
  10.0000  11.0000  12.0000  13.0000
[ CPUFloatType{2,1,4} ]


In [37]:
// Weights stored by the layer during the forward call above.
std::cout << attention->attention_weights << std::endl;
std::cout << attention->attention_weights.sizes() << std::endl;

(1,.,.) = 
 Columns 1 to 9  0.5000  0.5000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000

Columns 10 to 10  0.0000

(2,.,.) = 
 Columns 1 to 9  0.1667  0.1667  0.1667  0.1667  0.1667  0.1667  0.0000  0.0000  0.0000

Columns 10 to 10  0.0000
[ CPUFloatType{2,1,10} ]
[2, 1, 10]


# further reading
## multi-head attention
In practice, given the same set of queries, keys, and values we may want our model to combine knowledge from different behaviors of the same attention mechanism, such as capturing dependencies of various ranges (e.g., shorter-range vs. longer-range) within a sequence. Thus, it may be beneficial to allow our attention mechanism to jointly use different representation subspaces of queries, keys, and values.

To this end, instead of performing a single attention pooling, queries, keys, and values can be transformed with independently learned linear projections. Then these projected queries, keys, and values are fed into attention pooling in parallel. In the end, attention pooling outputs are concatenated and transformed with another learned linear projection to produce the final output. This design is called multi-head attention, where each of the attention pooling outputs is a head [Vaswani et al., 2017]. Using fully connected layers to perform learnable linear transformations, Fig. 11.5.1 describes multi-head attention.

![](https://d2l.ai/_images/multi-head-attention.svg)

In [38]:
// Transposition for parallel computation of multiple attention heads.
//
// Input `X`:  (`batch_size`, no. of queries or key-value pairs, `num_hiddens`)
// Output:     (`batch_size` * `num_heads`, no. of queries or key-value pairs,
//              `num_hiddens` / `num_heads`)
torch::Tensor transpose_qkv(torch::Tensor X, int64_t num_heads) {
    const int64_t batch_size = X.size(0);
    const int64_t num_items = X.size(1);

    // Split the hidden axis into (`num_heads`, `num_hiddens` / `num_heads`)
    // and move the head axis next to the batch axis:
    // (`batch_size`, `num_heads`, no. of items, `num_hiddens` / `num_heads`)
    auto per_head = X.reshape({batch_size, num_items, num_heads, -1})
                     .permute({0, 2, 1, 3});

    // Fold the head axis into the batch axis.
    return per_head.reshape({-1, per_head.size(2), per_head.size(3)});
}

In [39]:
// Reverse the operation of `transpose_qkv`.
//
// Input `X`:  (`batch_size` * `num_heads`, no. of items, per-head size)
// Output:     (`batch_size`, no. of items, `num_heads` * per-head size)
torch::Tensor transpose_output(torch::Tensor X, int64_t num_heads) {
    const int64_t num_items = X.size(1);
    const int64_t head_size = X.size(2);

    // Unfold the head axis from the batch axis, then move it behind the
    // item axis so the heads of one item sit next to each other.
    auto by_item = X.reshape({-1, num_heads, num_items, head_size})
                    .permute({0, 2, 1, 3});

    // Concatenate the heads along the last axis.
    return by_item.reshape({by_item.size(0), by_item.size(1), -1});
}

In [40]:
// Multi-head attention: project queries, keys and values, run scaled
// dot-product attention on all heads in parallel, then concatenate the heads
// and apply an output projection.
class MultiHeadAttentionImpl : public torch::nn::Module {
    public:
    int64_t num_heads;
    DotProductAttention attention{nullptr};
    torch::nn::Linear W_k{nullptr}, W_q{nullptr}, W_v{nullptr}, W_o{nullptr};

    // key_size / query_size / value_size: input feature sizes.
    // num_hiddens: total hidden size shared by all heads.
    // n_heads: number of attention heads.
    // dropout: dropout probability inside the attention layer.
    // bias: whether the linear projections carry a bias term.
    MultiHeadAttentionImpl(int64_t key_size, int64_t query_size, int64_t value_size, int64_t num_hiddens,
                 int64_t n_heads, float dropout, bool bias=false) {
        num_heads = n_heads;

        // Sub-modules are created first and registered afterwards, in the
        // order attention, W_q, W_k, W_v, W_o.
        attention = DotProductAttention(dropout);
        W_q = torch::nn::Linear(torch::nn::LinearOptions(query_size, num_hiddens).bias(bias));
        W_k = torch::nn::Linear(torch::nn::LinearOptions(key_size, num_hiddens).bias(bias));
        W_v = torch::nn::Linear(torch::nn::LinearOptions(value_size, num_hiddens).bias(bias));
        W_o = torch::nn::Linear(torch::nn::LinearOptions(num_hiddens, num_hiddens).bias(bias));

        register_module("attention", attention);
        register_module("W_q", W_q);
        register_module("W_k", W_k);
        register_module("W_v", W_v);
        register_module("W_o", W_o);
    }

    // Shape of `queries`, `keys`, or `values`:
    // (`batch_size`, no. of queries or key-value pairs, `num_hiddens`)
    // Shape of `valid_lens`:
    // (`batch_size`,) or (`batch_size`, no. of queries)
    // Returns: (`batch_size`, no. of queries, `num_hiddens`)
    torch::Tensor forward(torch::Tensor queries, torch::Tensor keys, torch::Tensor values, torch::Tensor valid_lens) {
        // After projection and transposing, each tensor has shape
        // (`batch_size` * `num_heads`, no. of queries or key-value pairs,
        // `num_hiddens` / `num_heads`)
        auto query_heads = transpose_qkv(W_q->forward(queries), num_heads);
        auto key_heads   = transpose_qkv(W_k->forward(keys), num_heads);
        auto value_heads = transpose_qkv(W_v->forward(values), num_heads);

        if( valid_lens.defined() ) {
            // On axis 0, copy the first item (scalar or vector) for
            // `num_heads` times, then copy the next item, and so on
            valid_lens = torch::repeat_interleave(valid_lens, /*repeats=*/num_heads, /*dim=*/0);
        }

        // Shape of `head_outputs`: (`batch_size` * `num_heads`,
        // no. of queries, `num_hiddens` / `num_heads`)
        auto head_outputs = attention->forward(query_heads, key_heads, value_heads, valid_lens);

        // Concatenate the heads back to (`batch_size`, no. of queries,
        // `num_hiddens`) and apply the output projection.
        return W_o->forward(transpose_output(head_outputs, num_heads));
    }
};
TORCH_MODULE(MultiHeadAttention);

In [41]:
// Demo configuration: 10 hidden units split over 5 heads (2 per head).
int64_t num_hiddens = 10, num_heads = 5;
int64_t batch_size = 2, num_queries = 4, num_kvpairs = 6;
// NOTE(review): `valid_lens` was already declared in the dot-product
// attention demo; this `auto` re-declaration relies on the kernel permitting
// redefinition of globals — confirm on a fresh run.
auto valid_lens = torch::tensor({3, 2});
// X supplies the queries, Y supplies the keys and values.
auto X = torch::ones({batch_size, num_queries, num_hiddens});
auto Y = torch::ones({batch_size, num_kvpairs, num_hiddens});

In [42]:
// key, query, value and hidden sizes are all num_hiddens; dropout is 0.5.
// NOTE(review): `attention` was already declared as an AdditiveAttention in
// Section 2.2; re-declaring it with a different type relies on the kernel
// permitting redefinition of globals — confirm on a fresh run.
auto attention = MultiHeadAttention(num_hiddens, num_hiddens, num_hiddens, num_hiddens, num_heads, 0.5);
attention->eval();
// The output keeps the query layout: (batch_size, num_queries, num_hiddens).
std::cout << attention->forward(X, Y, Y, valid_lens).sizes() << std::endl;
// Print the module tree (sub-modules in registration order).
std::cout << attention << std::endl;

[2, 4, 10]
__cling_N541::MultiHeadAttentionImpl(
  (attention): __cling_N525::DotProductAttentionImpl(
    (dpout): torch::nn::Dropout(p=0.5, inplace=false)
  )
  (W_q): torch::nn::Linear(in_features=10, out_features=10, bias=false)
  (W_k): torch::nn::Linear(in_features=10, out_features=10, bias=false)
  (W_v): torch::nn::Linear(in_features=10, out_features=10, bias=false)
  (W_o): torch::nn::Linear(in_features=10, out_features=10, bias=false)
)


## self-attention

In [44]:
// Self-attention: the same tensor X is used as queries, keys, and values.
torch::Tensor output_after_attention = attention ->forward(X, X, X, valid_lens);

In [46]:
// All rows are identical because X is filled with ones, so every query sees
// the same keys and values.
std::cout << output_after_attention << std::endl;
std::cout << output_after_attention.sizes() << std::endl;

(1,.,.) = 
 Columns 1 to 9  0.0065  0.4535 -0.2640 -0.3188  0.0791 -0.4609 -0.1906  0.5901 -0.1081
  0.0065  0.4535 -0.2640 -0.3188  0.0791 -0.4609 -0.1906  0.5901 -0.1081
  0.0065  0.4535 -0.2640 -0.3188  0.0791 -0.4609 -0.1906  0.5901 -0.1081
  0.0065  0.4535 -0.2640 -0.3188  0.0791 -0.4609 -0.1906  0.5901 -0.1081

Columns 10 to 10 -0.4416
 -0.4416
 -0.4416
 -0.4416

(2,.,.) = 
 Columns 1 to 9  0.0065  0.4535 -0.2640 -0.3188  0.0791 -0.4609 -0.1906  0.5901 -0.1081
  0.0065  0.4535 -0.2640 -0.3188  0.0791 -0.4609 -0.1906  0.5901 -0.1081
  0.0065  0.4535 -0.2640 -0.3188  0.0791 -0.4609 -0.1906  0.5901 -0.1081
  0.0065  0.4535 -0.2640 -0.3188  0.0791 -0.4609 -0.1906  0.5901 -0.1081

Columns 10 to 10 -0.4416
 -0.4416
 -0.4416
 -0.4416
[ CPUFloatType{2,4,10} ]
[2, 4, 10]


In [50]:
// Weights stored by the inner DotProductAttention during the last forward
// call; the leading axis is batch_size * num_heads.
std::cout << attention->attention->attention_weights << std::endl;
std::cout << attention->attention->attention_weights.sizes() << std::endl;

(1,.,.) = 
  0.3333  0.3333  0.3333  0.0000
  0.3333  0.3333  0.3333  0.0000
  0.3333  0.3333  0.3333  0.0000
  0.3333  0.3333  0.3333  0.0000

(2,.,.) = 
  0.3333  0.3333  0.3333  0.0000
  0.3333  0.3333  0.3333  0.0000
  0.3333  0.3333  0.3333  0.0000
  0.3333  0.3333  0.3333  0.0000

(3,.,.) = 
  0.3333  0.3333  0.3333  0.0000
  0.3333  0.3333  0.3333  0.0000
  0.3333  0.3333  0.3333  0.0000
  0.3333  0.3333  0.3333  0.0000

(4,.,.) = 
  0.3333  0.3333  0.3333  0.0000
  0.3333  0.3333  0.3333  0.0000
  0.3333  0.3333  0.3333  0.0000
  0.3333  0.3333  0.3333  0.0000

(5,.,.) = 
  0.3333  0.3333  0.3333  0.0000
  0.3333  0.3333  0.3333  0.0000
  0.3333  0.3333  0.3333  0.0000
  0.3333  0.3333  0.3333  0.0000

(6,.,.) = 
  0.5000  0.5000  0.0000  0.0000
  0.5000  0.5000  0.0000  0.0000
  0.5000  0.5000  0.0000  0.0000
  0.5000  0.5000  0.0000  0.0000

(7,.,.) = 
  0.5000  0.5000  0.0000  0.0000
  0.5000  0.5000  0.0000  0.0000
  0.5000  0.5000  0.0000  0.0000
  0.5000  0.5000  0.0000  

# Reference
* https://d2l.ai/chapter_attention-mechanisms-and-transformers/index.html
* https://github.com/jiamny/Dive_into_deep_learning_with_libtorch/blob/master/src/utils/ch_10_util.h