<a href="https://colab.research.google.com/github/thai94/d2l/blob/main/10.attention_mechanisms/10_3_attention_scoring_functions_pytorch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [24]:
import math
import torch
from torch import nn

In [25]:
def sequence_mask(X, valid_len, value=0):

  maxlen = X.size(1)
  mask = torch.arange((maxlen), dtype=torch.float32, device=X.device)[None, :] < valid_len[:, None]
  X[~mask] = value
  return X

In [26]:
def masked_softmax(X, valid_lens):

  if valid_lens is None:
    return nn.functional.softmax(X, dim=-1)
  
  shape = X.shape
  if valid_lens.dim() == 1:
    valid_lens = torch.repeat_interleave(valid_lens, shape[1])
  else:
    valid_lens = valid_lens.reshape(-1)
  X = sequence_mask(X.reshape(-1, shape[-1]), valid_lens,
                              value=-1e6)
  return nn.functional.softmax(X.reshape(shape), dim=-1)

In [27]:
masked_softmax(torch.rand(2, 2, 4), torch.tensor([2, 3]))

tensor([[[0.4661, 0.5339, 0.0000, 0.0000],
         [0.5988, 0.4012, 0.0000, 0.0000]],

        [[0.3728, 0.2808, 0.3463, 0.0000],
         [0.4424, 0.2788, 0.2788, 0.0000]]])

In [28]:
masked_softmax(torch.rand(2, 2, 4), torch.tensor([[1, 3], [2, 4]]))

tensor([[[1.0000, 0.0000, 0.0000, 0.0000],
         [0.2331, 0.2616, 0.5053, 0.0000]],

        [[0.3763, 0.6237, 0.0000, 0.0000],
         [0.1642, 0.4026, 0.2194, 0.2139]]])

In [29]:
class AdditiveAttention(nn.Module):

  def __init__(self, key_size, query_size, num_hiddens, dropout, **kwargs):
    super(AdditiveAttention, self).__init__(**kwargs)
    self.W_k = nn.Linear(key_size, num_hiddens, bias=False)
    self.W_q = nn.Linear(query_size, num_hiddens, bias=False)
    self.w_v = nn.Linear(num_hiddens, 1, bias=False)
    self.dropout = nn.Dropout(dropout)

  def forward(self, queries, keys, values, valid_lens):

    queries, keys = self.W_q(queries), self.W_k(keys)
    # queries shape: batch_size, no. of queries, num_hiddens
    # key shape: batch_size, no. of key-value pairs, num_hiddens

    # shape of `queries`: (`batch_size`, no. of queries, 1, num_hiddens)
    # key shape:          (batch_size, 1, no. of key-value pairs, num_hiddens)
    features = queries.unsqueeze(2) + keys.unsqueeze(1)
    features = torch.tanh(features)
    # scores: (`batch_size`, no. of queries, no. of key-value pairs)
    scores = self.w_v(features).squeeze(-1)
    self.attention_weights = masked_softmax(scores, valid_lens)

    ret = torch.bmm(self.dropout(self.attention_weights), values)

    print(queries.shape)
    print(keys.shape)
    print(features.shape)
    print(scores.shape)
    print(self.attention_weights.shape)
    print(ret.shape)

    return ret


In [30]:
queries, keys = torch.normal(0, 1, (2, 1, 20)), torch.ones((2, 10, 2))
# The two value matrices in the `values` minibatch are identical
values = torch.arange(40, dtype=torch.float32).reshape(1, 10, 4).repeat(
    2, 1, 1)
valid_lens = torch.tensor([2, 6])

attention = AdditiveAttention(key_size=2, query_size=20, num_hiddens=8,
                              dropout=0.1)
attention.eval()
attention(queries, keys, values, valid_lens)

torch.Size([2, 1, 8])
torch.Size([2, 10, 8])
torch.Size([2, 1, 10, 8])
torch.Size([2, 1, 10])
torch.Size([2, 1, 10])
torch.Size([2, 1, 4])


tensor([[[ 2.0000,  3.0000,  4.0000,  5.0000]],

        [[10.0000, 11.0000, 12.0000, 13.0000]]], grad_fn=<BmmBackward0>)

In [31]:
class DotProductAttention(nn.Module):

  def __init__(self, dropout, **kwargs):
    super(DotProductAttention, self).__init__(**kwargs)
    self.dropout = nn.Dropout(dropout)

  def forward(self, queries, keys, values, valid_lens=None):

    d = queries.shape[-1]
    scores = torch.bmm(queries, keys.transpose(1, 2)) / math.sqrt(d)
    self.attention_weights = masked_softmax(scores, valid_lens)
    return torch.bmm(self.dropout(self.attention_weights), values)

In [32]:
queries = torch.normal(0, 1, (2, 1, 2))
attention = DotProductAttention(dropout=0.5)
attention.eval()
attention(queries, keys, values, valid_lens)

tensor([[[ 2.0000,  3.0000,  4.0000,  5.0000]],

        [[10.0000, 11.0000, 12.0000, 13.0000]]])