In [1]:
import torch
from torch import nn
import math


In [16]:
def attention(Q, K, V, mask=None):
  """Single-head scaled dot-product attention.

  Args:
    Q: queries, shape [batch, query_len, query_dim // head].
    K: keys, shape [batch, key_len, query_dim // head].
    V: values, shape [batch, key_len, value_dim // head].
    mask: optional 1-D tensor of shape [batch]; entry b is the number of
      leading key positions that are valid for batch element b. Key
      positions with index >= mask[b] are suppressed. None disables masking.

  Returns:
    Tensor of shape [batch, query_len, value_dim // head].
  """
  # scores shape: [batch, query_len, key_len]
  scores = torch.bmm(Q, K.permute(0, 2, 1)) / math.sqrt(Q.shape[-1])
  if mask is not None:
    # Build the key index on the same device as the scores; a default (CPU)
    # arange would break the comparison / masked_fill for accelerator inputs.
    key_index = torch.arange(scores.shape[-1], device=scores.device)
    # valid shape: [batch, key_len]; True where the key position is in range.
    valid = key_index[None, :] < mask.to(scores.device)[:, None]
    # Insert a query axis so the same key mask applies to every query.
    # A large finite negative (not -inf) keeps a fully masked row finite.
    scores = scores.masked_fill(~valid[:, None], -1e6)

  return torch.bmm(nn.functional.softmax(scores, dim=-1), V)


def split_head(tensor, head):
  """Split the last axis of a [batch, time, dim] tensor into `head` groups.

  Returns a tensor of shape [head, batch, time, dim // head], i.e. the head
  axis is moved to the front.
  """
  n_batch, n_steps, n_feat = tensor.shape
  per_head = n_feat // head
  grouped = tensor.reshape(n_batch, n_steps, head, per_head)
  # [batch, time, head, per_head] -> [head, batch, time, per_head]
  return grouped.permute(2, 0, 1, 3)

def scaled_dot_multihead_attention(x_q, x_k, x_v, head_num, mask=None):
  """Run scaled dot-product attention independently per head and merge.

  Args:
    x_q: queries, shape [batch, query_len, query_dim].
    x_k: keys, shape [batch, key_len, query_dim].
    x_v: values, shape [batch, key_len, value_dim].
    head_num: number of heads the last axis of each input is split into.
    mask: optional tensor of shape [batch], forwarded unchanged to
      `attention` for every head. Defaults to None (no masking), matching
      the default used by `attention`.

  Returns:
    Tensor of shape [batch, query_len, value_dim]: the per-head outputs
    concatenated along the last axis in head order.
  """
  # After splitting:
  #   q_heads: [head, batch, query_len, query_dim // head]
  #   k_heads: [head, batch, key_len, query_dim // head]
  #   v_heads: [head, batch, key_len, value_dim // head]
  k_heads = split_head(x_k, head_num)
  q_heads = split_head(x_q, head_num)
  v_heads = split_head(x_v, head_num)

  # Iterating a tensor walks its first axis, so each step is one head.
  # Each element has shape [batch, query_len, value_dim // head].
  per_head_outputs = [
      attention(q, k, v, mask=mask)
      for q, k, v in zip(q_heads, k_heads, v_heads)
  ]

  return torch.cat(per_head_outputs, dim=2)


In [17]:
# Smoke test of scaled_dot_multihead_attention on deterministic inputs.
num_hiddens, num_heads = 5, 5
batch_size, num_queries, num_kvpairs = 2, 4, 6
valid_lens = torch.tensor([3, 2])  # passed as the per-batch mask

q = torch.arange(
    batch_size * num_queries * num_hiddens, dtype=torch.float32
).reshape(batch_size, num_queries, num_hiddens)
k = torch.arange(
    batch_size * num_kvpairs * num_hiddens, dtype=torch.float32
).reshape(batch_size, num_kvpairs, num_hiddens)
v = torch.arange(
    batch_size * num_kvpairs * num_hiddens, dtype=torch.float32
).reshape(batch_size, num_kvpairs, num_hiddens)

# Alternative all-ones inputs (disabled):
# q = torch.ones((batch_size, num_queries, num_hiddens))
# k = torch.ones((batch_size, num_kvpairs, num_hiddens))
# v = torch.ones((batch_size, num_kvpairs, num_hiddens))

scaled_dot_multihead_attention(q, k, v, head_num=num_heads, mask=valid_lens)


tensor([[[ 5.0000, 10.9661, 11.9998, 13.0000, 14.0000],
         [10.0000, 11.0000, 12.0000, 13.0000, 14.0000],
         [10.0000, 11.0000, 12.0000, 13.0000, 14.0000],
         [10.0000, 11.0000, 12.0000, 13.0000, 14.0000]],

        [[35.0000, 36.0000, 37.0000, 38.0000, 39.0000],
         [35.0000, 36.0000, 37.0000, 38.0000, 39.0000],
         [35.0000, 36.0000, 37.0000, 38.0000, 39.0000],
         [35.0000, 36.0000, 37.0000, 38.0000, 39.0000]]])

In [25]:
class MultiheadAttention(nn.Module):
  """Multi-head attention with bias-free linear input/output projections.

  Queries, keys and values are each projected to `latent_dim`, attended over
  with `head_num` heads via `scaled_dot_multihead_attention`, and the merged
  result is passed through a final `latent_dim -> latent_dim` projection.

  Args:
    q_dim: size of the last axis of the query input.
    k_dim: size of the last axis of the key input.
    v_dim: size of the last axis of the value input.
    latent_dim: projected feature size; must be divisible by `head_num`.
    head_num: number of attention heads (positive).

  Raises:
    ValueError: if `head_num` is not positive or does not divide `latent_dim`.
  """

  def __init__(self, q_dim, k_dim, v_dim, latent_dim, head_num) -> None:
    super().__init__()
    # Fail early with a clear message; otherwise the per-head reshape fails
    # later with an opaque size-mismatch error during the forward pass.
    if head_num <= 0 or latent_dim % head_num != 0:
      raise ValueError(
          f"latent_dim ({latent_dim}) must be divisible by a positive "
          f"head_num ({head_num})"
      )
    self.K_pre = nn.Linear(k_dim, latent_dim, bias=False)
    self.Q_pre = nn.Linear(q_dim, latent_dim, bias=False)
    self.V_pre = nn.Linear(v_dim, latent_dim, bias=False)
    self.final = nn.Linear(latent_dim, latent_dim, bias=False)
    self.head_num = head_num

  def forward(self, x_q, x_k, x_v, mask=None):
    """Project the inputs, apply multi-head attention, project the result.

    Args:
      x_q: queries; last axis must have size `q_dim`.
      x_k: keys; last axis must have size `k_dim`.
      x_v: values; last axis must have size `v_dim`.
      mask: optional tensor forwarded to `scaled_dot_multihead_attention`.

    Returns:
      Output of the final linear projection; last axis has size `latent_dim`.
    """
    # After projection k, q, v all have last axis of size latent_dim.
    k, q, v = self.K_pre(x_k), self.Q_pre(x_q), self.V_pre(x_v)
    attentions = scaled_dot_multihead_attention(q, k, v, self.head_num, mask)
    return self.final(attentions)

class Encoder(nn.Module):
  """Placeholder encoder module: registers no layers and defines no forward."""

  def __init__(self) -> None:
    super().__init__()
      

In [30]:
# Smoke test of the MultiheadAttention module on deterministic inputs.
# NOTE: the layer weights are randomly initialised and no seed is set here,
# so the printed values differ from run to run.
num_hiddens, num_heads = 5, 5
layer = MultiheadAttention(
    q_dim=num_hiddens,
    k_dim=num_hiddens,
    v_dim=num_hiddens,
    latent_dim=num_hiddens,
    head_num=num_heads,
)
layer.eval()

batch_size, num_queries, num_kvpairs = 2, 4, 6
valid_lens = torch.tensor([3, 2])  # passed as the per-batch mask
q = torch.arange(
    batch_size * num_queries * num_hiddens, dtype=torch.float32
).reshape(batch_size, num_queries, num_hiddens)
k = torch.arange(
    batch_size * num_kvpairs * num_hiddens, dtype=torch.float32
).reshape(batch_size, num_kvpairs, num_hiddens)
v = torch.arange(
    batch_size * num_kvpairs * num_hiddens, dtype=torch.float32
).reshape(batch_size, num_kvpairs, num_hiddens)
layer(q, k, v, mask=valid_lens)

tensor([[[ 0.7799,  1.1065,  0.7765, -2.0466,  2.3597],
         [ 0.5640,  1.0851,  0.3234, -2.2220,  2.5895],
         [ 0.3934,  1.0652, -0.0271, -2.3523,  2.7689],
         [ 0.3361,  1.0807, -0.1803, -2.4557,  2.9237]],

        [[ 1.7985,  2.4169,  1.1444, -5.5539, 10.9390],
         [ 1.7950,  2.4262,  1.1207, -5.5830, 10.9852],
         [ 1.7945,  2.4329,  1.1077, -5.6022, 11.0161],
         [ 1.7948,  2.4373,  1.1003, -5.6142, 11.0358]]],
       grad_fn=<UnsafeViewBackward0>)