<a href="https://colab.research.google.com/github/torrhen/papers/blob/main/Attention.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
from torch import nn

In [91]:
class ScaledDotProductAttention(nn.Module):
  '''
  Scaled Dot-Product Attention function as described in section 3.2.1. Used as part of the Multi-Head Attention layer.

  Computes softmax(Q @ K^T / sqrt(d_k)) @ V over the final two dimensions of the inputs.
  Any leading dimensions (e.g. batch and head) are handled as batch dimensions by matmul.

  TODO: attention dropout

  '''
  def __init__(self):
    super().__init__()
    # normalise the scores across the keys to obtain attention weights
    self.softmax = nn.Softmax(dim=-1)

  def forward(self, Q, K, V):
    '''
    Args:
      Q: queries [..., sz_q, d_k]
      K: keys    [..., sz_k, d_k]
      V: values  [..., sz_k, d_v]

    Returns:
      (out, attn): attention weighted values [..., sz_q, d_v] and
      attention weights [..., sz_q, sz_k]
    '''
    # size of each key vector, read before K is transposed
    d_k = K.shape[-1]

    # swap the final 2 dimensions of K to allow multiplication with Q
    K_t = K.transpose(-2, -1) # [..., sz_k, d_k] -> [..., d_k, sz_k]

    # calculate raw attention scores between Q and K
    scores = Q.matmul(K_t) # [..., sz_q, d_k] @ [..., d_k, sz_k] -> [..., sz_q, sz_k]

    # scale by sqrt(d_k) (not d_k) so large dot products do not saturate the softmax
    scores = scores / d_k ** 0.5

    # convert attention scores to weights
    attn = self.softmax(scores)

    # multiply attention weights with V
    out = attn.matmul(V) # [..., sz_q, sz_k] @ [..., sz_k, d_v] -> [..., sz_q, d_v]

    return out, attn # attention weighted values, attention weights


In [94]:
class MultiHeadAttention(nn.Module):
  '''
  Multi-Head Attention function as described in section 3.2.2

  Q, K and V are each linearly projected, split into h heads of size d_model // h,
  passed through the attention function in parallel, concatenated back to d_model
  and linearly projected once more.

  TODO: attention dropout, mask

  '''
  def __init__(self, d_model, h):
    '''
    Args:
      d_model: embedding size
      h: number of heads, must divide d_model exactly

    Raises:
      ValueError: if h is not positive or d_model is not divisible by h
    '''
    super().__init__()
    # an uneven split would silently drop features when reshaping into heads
    if h <= 0 or d_model % h != 0:
      raise ValueError(f'd_model ({d_model}) must be divisible by a positive number of heads h ({h})')
    # embedding size
    self.d_model = d_model
    # number of heads
    self.h = h
    # embedding projection size for query, keys and values vectors
    self.d_q = self.d_k = self.d_v = self.d_model // self.h
    # linear projection layers for embeddings
    self.fc_Q = nn.Linear(in_features=self.d_model, out_features=self.d_model)
    self.fc_K = nn.Linear(in_features=self.d_model, out_features=self.d_model)
    self.fc_V = nn.Linear(in_features=self.d_model, out_features=self.d_model)
    # attention function
    self.attention = ScaledDotProductAttention()
    # linear projection layer for attention
    self.fc_mh_out = nn.Linear(in_features=self.d_model, out_features=self.d_model)

  def _split_heads(self, x, d_head):
    # divide the embedding dimension into heads and move the head dimension forward
    x = x.reshape((x.shape[0], -1, self.h, d_head)) # [b, sz, d_model] -> [b, sz, h, d_head]
    return x.permute((0, 2, 1, 3)) # [b, sz, h, d_head] -> [b, h, sz, d_head]

  def forward(self, Q, K, V):
    '''
    Args:
      Q: queries [b, sz_q, d_model]
      K: keys    [b, sz_k, d_model]
      V: values  [b, sz_v, d_model]

    Returns:
      (mh_out, mh_attn): multi-head output [b, sz_q, d_model] and the
      attention weights returned by the attention function
    '''
    # linear projection of Q, K and V followed by the split into separate heads;
    # the batch size is read from each tensor instead of being fixed to 1
    p_Q = self._split_heads(self.fc_Q(Q), self.d_q) # [b, sz_q, d_model] -> [b, h, sz_q, d_q]
    p_K = self._split_heads(self.fc_K(K), self.d_k) # [b, sz_k, d_model] -> [b, h, sz_k, d_k]
    p_V = self._split_heads(self.fc_V(V), self.d_v) # [b, sz_v, d_model] -> [b, h, sz_v, d_v]

    # calculate the scaled dot product attention for each head in parallel
    mh_out, mh_attn = self.attention(p_Q, p_K, p_V)

    # move the head dimension of the attention weighted values back next to the features
    mh_out = mh_out.permute((0, 2, 1, 3)) # [b, h, sz_q, d_v] -> [b, sz_q, h, d_v]

    # concatenate heads of attention weighted values
    mh_out = mh_out.reshape((mh_out.shape[0], -1, self.d_model)) # [b, sz_q, h, d_v] -> [b, sz_q, h * d_v (d_model)]

    # linear projection of attention weighted values
    mh_out = self.fc_mh_out(mh_out) # [b, sz_q, d_model] -> [b, sz_q, d_model]

    return mh_out, mh_attn # multi-head output, multi-head attention weights

In [98]:
# fix the RNG so the random inputs below are reproducible
torch.manual_seed(10)

# one sequence per batch
batch_size = 1
# number of query, key and value vectors
sz_q, sz_k, sz_v = 10, 10, 10
# width of every embedding vector
d_model = 512

# random inputs, drawn in the order query -> keys -> values
Q = torch.randn(batch_size, sz_q, d_model, dtype=torch.float32) # query  [b, sz_q, d_model]
K = torch.randn(batch_size, sz_k, d_model, dtype=torch.float32) # keys   [b, sz_k, d_model]
V = torch.randn(batch_size, sz_v, d_model, dtype=torch.float32) # values [b, sz_v, d_model]

In [100]:
# smoke test: run the multi-head attention layer (8 heads) on the sample inputs
multihead_attention = MultiHeadAttention(d_model, 8)
mh_out, mh_attn = multihead_attention(Q, K, V)

# show the shape of the output and of the attention weights, one per line
print(mh_out.shape, mh_attn.shape, sep='\n')

torch.Size([1, 10, 512])
torch.Size([1, 8, 10, 10])
