In [1]:
import math
import torch
from torch import nn
from d2l import torch as d2l

In [2]:
def masked_softmax(X, valid_lens):
    """Softmax over the last axis of ``X``, ignoring positions past each valid length.

    Args:
        X: Score tensor; the softmax is taken along its last dimension. The
            usual case is 3-D ``(batch, num_queries, num_keys)``.
        valid_lens: ``None`` for a plain softmax, or a tensor of valid lengths
            along the last axis of ``X``:
              * 1-D, one entry per leading (batch) index: that length is used
                for every row belonging to that batch element;
              * any other rank: flattened, giving one length per row of
                ``X.reshape(-1, X.shape[-1])``.

    Returns:
        A new tensor with the same shape as ``X``. Entries at positions
        ``>= valid_len`` are (numerically) zero, and each last-axis row sums
        to 1. ``X`` itself is never modified.
    """
    if valid_lens is None:
        return nn.functional.softmax(X, dim=-1)

    shape = X.shape
    if valid_lens.dim() == 1:
        # One length per batch element: repeat it for every row of that element.
        # The row count is the product of all middle axes (== shape[1] for 3-D).
        valid_lens = torch.repeat_interleave(valid_lens, math.prod(shape[1:-1]))
    else:
        valid_lens = valid_lens.reshape(-1)

    rows = X.reshape(-1, shape[-1])
    positions = torch.arange(shape[-1], device=X.device)
    keep = positions[None, :] < valid_lens.to(X.device)[:, None]
    # Out-of-place fill: `reshape` may alias the caller's tensor, so writing
    # into it would corrupt their data (and breaks autograd on leaf tensors).
    # A large finite negative, rather than -inf, keeps rows whose valid length
    # is 0 from becoming NaN after the softmax.
    rows = rows.masked_fill(~keep, -1e6)
    return nn.functional.softmax(rows.reshape(shape), dim=-1)

In [3]:
# 1-D valid_lens: one valid length per batch element, shared by both of its rows.
demo_scores = torch.rand(2, 2, 4)
demo_lens = torch.tensor([2, 3])
masked_softmax(demo_scores, demo_lens)

tensor([[[0.3321, 0.6679, 0.0000, 0.0000],
         [0.5055, 0.4945, 0.0000, 0.0000]],

        [[0.3033, 0.2700, 0.4267, 0.0000],
         [0.4513, 0.2955, 0.2532, 0.0000]]])

In [4]:
# 2-D valid_lens: an individual valid length for every (batch, row) pair.
demo_scores_2d = torch.rand(2, 2, 4)
per_row_lens = torch.tensor([[1, 3], [2, 4]])
masked_softmax(demo_scores_2d, per_row_lens)

tensor([[[1.0000, 0.0000, 0.0000, 0.0000],
         [0.2970, 0.3425, 0.3605, 0.0000]],

        [[0.4271, 0.5729, 0.0000, 0.0000],
         [0.1769, 0.2673, 0.2921, 0.2637]]])

In [5]:
class AdditiveAttention(nn.Module):
    """Additive attention: score(q, k) = w_v^T tanh(W_q q + W_k k).

    Args:
        key_size: Feature size of each key.
        query_size: Feature size of each query.
        num_hiddens: Width of the shared hidden space that queries and keys
            are projected into.
        dropout: Dropout probability applied to the attention weights.
        **kwargs: Forwarded to ``nn.Module.__init__``.
    """

    def __init__(self, key_size, query_size, num_hiddens, dropout, **kwargs):
        super().__init__(**kwargs)
        # Keep this creation order: it determines the order in which the
        # layers draw their random initial weights and their state_dict order.
        self.W_k = nn.Linear(key_size, num_hiddens, bias=False)
        self.W_q = nn.Linear(query_size, num_hiddens, bias=False)
        self.w_v = nn.Linear(num_hiddens, 1, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, queries, keys, values, valid_lens):
        """Return the attention-weighted sum of ``values``.

        Because ``torch.bmm`` below needs 3-D operands, inputs are batched:
        queries ``(batch, n_q, query_size)``, keys ``(batch, n_k, key_size)``,
        values ``(batch, n_k, value_dim)``. ``valid_lens`` is handed to
        ``masked_softmax`` unchanged (``None`` disables masking). The weights
        before dropout are kept on ``self.attention_weights`` for inspection.
        """
        q_proj = self.W_q(queries)
        k_proj = self.W_k(keys)
        # Pair every query with every key via broadcasting:
        # (batch, n_q, 1, h) + (batch, 1, n_k, h) -> (batch, n_q, n_k, h)
        pair_features = torch.tanh(q_proj.unsqueeze(2) + k_proj.unsqueeze(1))
        # Collapse the hidden axis to a single score per (query, key) pair.
        scores = self.w_v(pair_features).squeeze(-1)
        self.attention_weights = masked_softmax(scores, valid_lens)
        return torch.bmm(self.dropout(self.attention_weights), values)

In [7]:
# Standard-normal queries of shape (2, 1, 20) and all-ones keys of shape (2, 10, 2).
queries = torch.normal(0, 1, (2, 1, 20))
keys = torch.ones((2, 10, 2))
queries

tensor([[[-0.4614,  0.5076,  0.0334, -0.6187, -1.5714,  0.2323,  1.7643,
          -0.1371,  0.4528, -0.2451, -0.6314,  1.2588,  0.0040, -0.1359,
           0.4308, -0.2722, -0.7264,  1.1732,  0.4825, -0.7781]],

        [[ 0.4376, -0.8590, -0.1773,  0.2728,  0.9183, -1.1395, -0.0990,
          -1.1712, -0.3262, -1.7565, -1.2060,  1.7139,  1.1807, -1.8242,
           0.2295,  0.1091,  0.7169,  0.3899, -1.0987, -0.0684]]])