## Attention Scores
- Recap: attention weights and attention scores (the weights are the softmax of the scores)
- ![softmax-description](./imgs/65-2.png)
- ![softmax-description](./imgs/65-1.png)
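
A minimal sketch of the recap in PyTorch, assuming made-up scores for a single query over four key-value pairs (the numbers are illustrative, not from the figures). Softmax turns the scores into weights, and the output is the weighted average of the values:

```python
import torch
from torch import nn

# Hypothetical toy data: 1 query attending over 4 key-value pairs
scores = torch.tensor([2.0, 1.0, 0.1, -1.0])   # attention scores a(q, k_i)
values = torch.tensor([[1.0, 0.0],
                       [0.0, 1.0],
                       [1.0, 1.0],
                       [2.0, 2.0]])            # values v_i, shape (4, 2)

# Attention weights: softmax of the scores, non-negative and summing to 1
weights = nn.functional.softmax(scores, dim=-1)

# Attention pooling output: weighted average of the values, shape (2,)
output = weights @ values
print(weights, output)
```

Higher scores get larger weights, so the output is pulled toward the values whose keys scored highest.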

## Extending to Higher Dimensions
- Keys and values do not need to have the same dimension
- ![softmax-description](./imgs/65-3.png)
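
Written out, the general setup looks as follows (a sketch in the notation commonly used for this material; the symbols in the figure may differ). The output lives in the value space $\mathbb{R}^v$, so the key dimension $k$ only affects how the scores $a(\mathbf{q}, \mathbf{k}_i)$ are computed:

$$
\begin{aligned}
&\mathbf{q} \in \mathbb{R}^{q},\quad \mathbf{k}_i \in \mathbb{R}^{k},\quad \mathbf{v}_i \in \mathbb{R}^{v},\quad i = 1, \dots, m,\\
&f\big(\mathbf{q}, (\mathbf{k}_1, \mathbf{v}_1), \dots, (\mathbf{k}_m, \mathbf{v}_m)\big) = \sum_{i=1}^{m} \alpha(\mathbf{q}, \mathbf{k}_i)\, \mathbf{v}_i \in \mathbb{R}^{v},\\
&\alpha(\mathbf{q}, \mathbf{k}_i) = \operatorname{softmax}\big(a(\mathbf{q}, \mathbf{k}_i)\big) = \frac{\exp\big(a(\mathbf{q}, \mathbf{k}_i)\big)}{\sum_{j=1}^{m} \exp\big(a(\mathbf{q}, \mathbf{k}_j)\big)} \in \mathbb{R}.
\end{aligned}
$$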

## Additive Attention
- ![softmax-description](./imgs/65-4.png)
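
In its standard formulation, additive attention scores a query $\mathbf{q} \in \mathbb{R}^q$ against a key $\mathbf{k} \in \mathbb{R}^k$ with $a(\mathbf{q}, \mathbf{k}) = \mathbf{w}_v^\top \tanh(\mathbf{W}_q \mathbf{q} + \mathbf{W}_k \mathbf{k})$, where $\mathbf{W}_q \in \mathbb{R}^{h \times q}$, $\mathbf{W}_k \in \mathbb{R}^{h \times k}$ and $\mathbf{w}_v \in \mathbb{R}^h$ are learnable, so queries and keys may have different lengths. The PyTorch sketch below follows that formula; the class name, the helper `masked_softmax_simple` (a self-contained stand-in for `masked_softmax` that only accepts 1D `valid_lens`), and the omission of dropout are assumptions for illustration:

```python
import torch
from torch import nn


def masked_softmax_simple(X, valid_lens=None):
    """Hypothetical helper: softmax over the last axis of X (batch, n_queries, n_kv),
    giving weight 0 to key positions j >= valid_lens[b]; valid_lens is (batch,) or None."""
    if valid_lens is None:
        return nn.functional.softmax(X, dim=-1)
    # mask[b, j] is True for key positions that are kept
    mask = torch.arange(X.shape[-1], device=X.device)[None, :] < valid_lens[:, None]
    # Broadcast the (batch, n_kv) mask over the query axis
    X = X.masked_fill(~mask[:, None, :], -1e6)
    return nn.functional.softmax(X, dim=-1)


class AdditiveAttention(nn.Module):
    """Sketch of additive attention: a(q, k) = w_v^T tanh(W_q q + W_k k), no dropout."""

    def __init__(self, query_size, key_size, num_hiddens):
        super().__init__()
        self.W_q = nn.Linear(query_size, num_hiddens, bias=False)
        self.W_k = nn.Linear(key_size, num_hiddens, bias=False)
        self.w_v = nn.Linear(num_hiddens, 1, bias=False)

    def forward(self, queries, keys, values, valid_lens=None):
        # queries: (batch, n_queries, query_size); keys: (batch, n_kv, key_size)
        queries, keys = self.W_q(queries), self.W_k(keys)
        # Broadcast sum: (batch, n_queries, n_kv, num_hiddens)
        features = torch.tanh(queries.unsqueeze(2) + keys.unsqueeze(1))
        # Scores: (batch, n_queries, n_kv)
        scores = self.w_v(features).squeeze(-1)
        self.attention_weights = masked_softmax_simple(scores, valid_lens)
        # Weighted sum of values: (batch, n_queries, value_size)
        return torch.bmm(self.attention_weights, values)


# Usage: query size 20 and key size 2 differ; values have size 4
torch.manual_seed(0)
queries = torch.normal(0, 1, (2, 1, 20))
keys = torch.ones((2, 10, 2))
values = torch.arange(40, dtype=torch.float32).reshape(1, 10, 4).repeat(2, 1, 1)
valid_lens = torch.tensor([2, 6])
attention = AdditiveAttention(query_size=20, key_size=2, num_hiddens=8)
out = attention(queries, keys, values, valid_lens)
print(out.shape)  # torch.Size([2, 1, 4])
```

Because every key is identical here, all valid positions get equal weight, so each output is simply the mean of the first `valid_lens[b]` value rows.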

## Scaled Dot-Product Attention
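
Scaled dot-product attention requires queries and keys of the same length $d$ and scores them with $a(\mathbf{q}, \mathbf{k}) = \mathbf{q}^\top \mathbf{k} / \sqrt{d}$; for a batch, the output is $\operatorname{softmax}(\mathbf{Q}\mathbf{K}^\top / \sqrt{d})\,\mathbf{V}$. If the entries of $\mathbf{q}$ and $\mathbf{k}$ are independent with zero mean and unit variance, $\mathbf{q}^\top \mathbf{k}$ has variance $d$, and dividing by $\sqrt{d}$ brings it back to 1. A minimal sketch, where the function name and the valid-length masking scheme are assumptions (no dropout):

```python
import math
import torch
from torch import nn


def scaled_dot_product_attention(queries, keys, values, valid_lens=None):
    """Sketch: softmax(Q K^T / sqrt(d)) V with an optional per-example length mask.
    queries: (batch, n_queries, d); keys: (batch, n_kv, d); values: (batch, n_kv, v)."""
    d = queries.shape[-1]
    # Scores: (batch, n_queries, n_kv); dividing by sqrt(d) keeps their variance
    # near 1, assuming unit-variance query/key entries
    scores = torch.bmm(queries, keys.transpose(1, 2)) / math.sqrt(d)
    if valid_lens is not None:
        # Hypothetical masking scheme: keep key positions j < valid_lens[b]
        mask = torch.arange(scores.shape[-1], device=scores.device)[None, :] < valid_lens[:, None]
        scores = scores.masked_fill(~mask[:, None, :], -1e6)
    weights = nn.functional.softmax(scores, dim=-1)
    return torch.bmm(weights, values), weights


# Usage: queries and keys share the same length d = 2
queries = torch.normal(0, 1, (2, 1, 2))
keys = torch.ones((2, 10, 2))
values = torch.arange(40, dtype=torch.float32).reshape(1, 10, 4).repeat(2, 1, 1)
out, weights = scaled_dot_product_attention(queries, keys, values, torch.tensor([2, 6]))
print(out)  # identical keys, so each output is the mean of the first valid_len value rows
```

Unlike additive attention, there are no learnable parameters in the scoring function, which makes it a single batched matrix product. PyTorch 2.0 and later also provide `torch.nn.functional.scaled_dot_product_attention`, which takes a boolean or additive `attn_mask` rather than valid lengths.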

## Summary
- Attention

```python
import torch
from torch import nn
from d2l import torch as d2l
```

```python
# Masked softmax operation
def masked_softmax(X, valid_lens):
    """Perform softmax over the last axis, masking out elements beyond the valid lengths."""
    # `X`: 3D tensor; `valid_lens`: 1D or 2D tensor
    if valid_lens is None:
        return nn.functional.softmax(X, dim=-1)
    else:
        shape = X.shape
        if valid_lens.dim() == 1:
            # A 1D `valid_lens` holds one length per batch element; repeat each
            # entry shape[1] times so that every query row gets its own length
            valid_lens = torch.repeat_interleave(valid_lens, shape[1])
        else:
            # A 2D `valid_lens` already holds one length per query row
            valid_lens = valid_lens.reshape(-1)
        # On the last axis, replace masked elements with a very large negative
        # value so that their softmax (exponential) output is 0
        X = d2l.sequence_mask(X.reshape(-1, shape[-1]), valid_lens, value=-1e6)
        return nn.functional.softmax(X.reshape(shape), dim=-1)
```
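
If the d2l package is not available, `d2l.sequence_mask` can be replaced by a local function. The body below is a hypothetical sketch written for these notes, not d2l's own code; one deliberate difference is that it clones `X` instead of writing into it in place:

```python
import torch


def sequence_mask(X, valid_len, value=0.0):
    """Hypothetical stand-in for d2l.sequence_mask.
    X: (n, m) tensor; valid_len: (n,) tensor. In row i, every entry at
    column j >= valid_len[i] is replaced by `value`."""
    maxlen = X.size(1)
    # mask[i, j] is True where j < valid_len[i], i.e. the entry is kept
    mask = torch.arange(maxlen, device=X.device)[None, :] < valid_len[:, None]
    X = X.clone()  # avoid modifying the caller's tensor
    X[~mask] = value
    return X


print(sequence_mask(torch.ones(3, 4), torch.tensor([1, 2, 4]), value=-1e6))
```

Inside `masked_softmax`, the replaced entries become `-1e6`, and `exp(-1e6)` underflows to exactly 0 in float32, which is why the masked positions show `0.0000` in the outputs below.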

```python
masked_softmax(torch.rand(2, 2, 4), torch.tensor([2, 3]))
```

```
tensor([[[0.5619, 0.4381, 0.0000, 0.0000],
         [0.4803, 0.5197, 0.0000, 0.0000]],

        [[0.3957, 0.2597, 0.3446, 0.0000],
         [0.3613, 0.3236, 0.3150, 0.0000]]])
```

```python
masked_softmax(torch.rand(2, 2, 4), torch.tensor([[1, 3], [2, 4]]))
```

```
tensor([[[1.0000, 0.0000, 0.0000, 0.0000],
         [0.2565, 0.2286, 0.5149, 0.0000]],

        [[0.5116, 0.4884, 0.0000, 0.0000],
         [0.3399, 0.2703, 0.1819, 0.2079]]])
```

```python
weights = torch.ones((2, 10)) * 0.1
values = torch.arange(20.0).reshape((2, 10))
torch.bmm(weights.unsqueeze(1), values.unsqueeze(-1))
```

```
tensor([[[ 4.5000]],

        [[14.5000]]])
```
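
`torch.bmm` multiplies two batches of matrices with shapes (b, n, m) and (b, m, p) into (b, n, p). Unsqueezing turns each row of `weights` into a 1×10 matrix and each row of `values` into a 10×1 matrix, so every batch element reduces to a single weighted sum. The check below compares that against an element-wise product summed over the last axis:

```python
import torch

weights = torch.ones((2, 10)) * 0.1
values = torch.arange(20.0).reshape((2, 10))

# (2, 1, 10) @ (2, 10, 1) -> (2, 1, 1): one 1x10 by 10x1 product per batch element
batched = torch.bmm(weights.unsqueeze(1), values.unsqueeze(-1))

# The same weighted sums, computed element-wise and summed over the last axis
direct = (weights * values).sum(dim=-1)

print(batched.shape)                # torch.Size([2, 1, 1])
print(batched.reshape(-1), direct)  # both approximately 4.5 and 14.5
```

With uniform weights of 0.1 over 10 values, each result is just the row mean: (0 + ... + 9) / 10 = 4.5 and (10 + ... + 19) / 10 = 14.5.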

```python
class NWKernelRegression(nn.Module):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Mathematically, w controls the width of the Gaussian kernel window
        # (i.e., how smooth the fitted curve is)
        self.w = nn.Parameter(torch.rand((1,), requires_grad=True))

    def forward(self, queries, keys, values):
        # After the repeat, `queries` and `attention_weights` have shape
        # (number of queries, number of key-value pairs)
        queries = queries.repeat_interleave(keys.shape[1]).reshape((-1, keys.shape[1]))
        self.attention_weights = nn.functional.softmax(
            -((queries - keys) * self.w)**2 / 2, dim=1)
        # `values` has shape (number of queries, number of key-value pairs)
        return torch.bmm(self.attention_weights.unsqueeze(1),
                         values.unsqueeze(-1)).reshape(-1)
```
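
Read back into math, the forward pass computes, for a query $x$ with keys $x_i$ and values $y_i$ ($i = 1, \dots, n$), a softmax over $i$ of a Gaussian-kernel score scaled by the learnable $w$:

$$
f(x) = \sum_{i=1}^{n} \operatorname{softmax}\left(-\frac{1}{2}\big((x - x_i)\,w\big)^2\right) y_i
$$

A larger $|w|$ narrows the kernel, so the weights concentrate on the keys nearest to $x$ and the fitted curve becomes less smooth; a smaller $|w|$ spreads the weights out and smooths the fit.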