<a href="https://colab.research.google.com/github/torrhen/deep-learning-papers/blob/main/Transformer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
from torch import nn

In [None]:
class ScaledDotProductAttention(nn.Module):
  '''
  Scaled Dot-Product Attention function as described in section 3.2.1. Used as part of the Multi-Head Attention layer.

  Computes softmax(Q @ K^T / sqrt(d_k)) @ V over the final two dimensions of the inputs.
  Any number of leading dimensions (e.g. batch, heads) is supported as long as they broadcast
  under matmul.
  '''
  def __init__(self):
    super(ScaledDotProductAttention, self).__init__()
    # normalise the scores over the key dimension to obtain attention weights
    self.softmax = nn.Softmax(dim=-1)

  def forward(self, Q, K, V):
    '''
    Args:
      Q: queries, [..., sz_q, d_k]
      K: keys,    [..., sz_k, d_k]
      V: values,  [..., sz_k, d_v]

    Returns:
      (out, attn): attention weighted values [..., sz_q, d_v] and attention weights [..., sz_q, sz_k]
    '''
    # size of each key vector, read before transposing so the index is unambiguous
    d_k = K.shape[-1]

    # transpose the final 2 dimensions of K to allow multiplication with Q
    K_t = K.transpose(-2, -1) # [..., sz_k, d_k] -> [..., d_k, sz_k]

    # calculate attention matrix between Q and K
    attn = Q.matmul(K_t) # [..., sz_q, d_k] @ [..., d_k, sz_k] -> [..., sz_q, sz_k]

    # scale attention matrix by factor sqrt(d_k) (not d_k) so the softmax does not saturate
    attn = attn / (d_k ** 0.5)
    # convert attention values to weights
    attn = self.softmax(attn)
    # multiply weighted attention with V
    out = attn.matmul(V) # [..., sz_q, sz_k] @ [..., sz_k, d_v] -> [..., sz_q, d_v]

    return out, attn # attention weighted values, attention weights


In [None]:
class MultiHeadAttention(nn.Module):
  '''
  Multi-Head Attention sub-layer as described in section 3.2.2. Used as part of the Encoder layer.

  Projects Q, K and V, splits the embedding dimension into h heads, applies scaled dot-product
  attention to every head in parallel, concatenates the heads and applies a final projection.

  TODO: mask

  '''
  def __init__(self, d_model, h):
    '''
    Args:
      d_model: embedding size
      h: number of attention heads, must divide d_model evenly

    Raises:
      ValueError: if d_model is not divisible by h
    '''
    super(MultiHeadAttention, self).__init__()
    # the embedding is split evenly between heads, so a remainder would silently corrupt the reshape
    if d_model % h != 0:
      raise ValueError(f"d_model ({d_model}) must be divisible by the number of heads h ({h})")
    # embedding size
    self.d_model = d_model
    # number of heads
    self.h = h
    # embedding projection size for query, keys and values vectors
    self.d_q = self.d_k = self.d_v = self.d_model // self.h
    # linear projection layers for embeddings
    self.fc_Q = nn.Linear(in_features=self.d_model, out_features=self.d_model)
    self.fc_K = nn.Linear(in_features=self.d_model, out_features=self.d_model)
    self.fc_V = nn.Linear(in_features=self.d_model, out_features=self.d_model)
    # attention function
    self.attention = ScaledDotProductAttention()
    # linear projection layer for attention
    self.fc_mh_out = nn.Linear(in_features=self.d_model, out_features=self.d_model)

  def forward(self, Q, K, V):
    '''
    Args:
      Q: queries, [b, sz_q, d_model]
      K: keys,    [b, sz_k, d_model]
      V: values,  [b, sz_v, d_model] (sz_v == sz_k)
      Unbatched [sz, d_model] inputs are treated as a batch of one.

    Returns:
      (mh_out, mh_attn): multi-head output [b, sz_q, d_model] and attention weights [b, h, sz_q, sz_k]
    '''
    # keep supporting unbatched inputs by adding a leading batch dimension of 1
    Q, K, V = (t.unsqueeze(0) if t.dim() == 2 else t for t in (Q, K, V))
    # take the batch size from the input rather than assuming 1, so batches are never merged into the sequence
    b = Q.shape[0]

    # linear projection of Q, K and V
    p_Q = self.fc_Q(Q) # [b, sz_q, d_model] -> [b, sz_q, d_model]
    p_K = self.fc_K(K) # [b, sz_k, d_model] -> [b, sz_k, d_model]
    p_V = self.fc_V(V) # [b, sz_v, d_model] -> [b, sz_v, d_model]

    # divide embedding dimension into separate heads for Q, K, V
    p_Q = p_Q.reshape((b, -1, self.h, self.d_q)) # [b, sz_q, d_model] -> [b, sz_q, h, d_q]
    p_K = p_K.reshape((b, -1, self.h, self.d_k)) # [b, sz_k, d_model] -> [b, sz_k, h, d_k]
    p_V = p_V.reshape((b, -1, self.h, self.d_v)) # [b, sz_v, d_model] -> [b, sz_v, h, d_v]

    # move the head dimension of Q, K and V
    p_Q = p_Q.permute((0, 2, 1, 3)) # [b, sz_q, h, d_q] -> [b, h, sz_q, d_q]
    p_K = p_K.permute((0, 2, 1, 3)) # [b, sz_k, h, d_k] -> [b, h, sz_k, d_k]
    p_V = p_V.permute((0, 2, 1, 3)) # [b, sz_v, h, d_v] -> [b, h, sz_v, d_v]

    # calculate the scaled dot product attention for each head in parallel
    mh_out, mh_attn = self.attention(p_Q, p_K, p_V)

    # move the head dimension of the attention weighted values
    mh_out = mh_out.permute((0, 2, 1, 3)) # [b, h, sz_q, d_v] -> [b, sz_q, h, d_v]

    # concatenate heads of attention weighted values
    mh_out = mh_out.reshape((b, -1, self.d_model)) # [b, sz_q, h, d_v] -> [b, sz_q, h * d_v (d_model)]

    # linear projection of attention weighted values
    mh_out = self.fc_mh_out(mh_out) # [b, sz_q, d_model] -> [b, sz_q, d_model]

    return mh_out, mh_attn # multi-head output, multi-head attention weights

In [None]:
# fix the random number generator so the sample tensors are reproducible
torch.manual_seed(10)

# problem dimensions: batch size, sequence sizes of Q, K and V, embedding dim
batch_size = 1
sz_q = sz_k = sz_v = 10
d_model = 512

# draw query, keys and values in that order, each of shape [b, sz, d_model]
Q, K, V = (
  torch.randn(batch_size, seq_len, d_model, dtype=torch.float32)
  for seq_len in (sz_q, sz_k, sz_v)
)

In [None]:
# smoke test: run the sample Q, K and V through an 8-head attention layer
multihead_attention = MultiHeadAttention(d_model, h=8)
mh_out, mh_attn = multihead_attention(Q, K, V)

# report the shapes of the multi-head output and of the attention weights, one per line
print(mh_out.shape, mh_attn.shape, sep="\n")

torch.Size([1, 10, 512])
torch.Size([1, 8, 10, 10])


In [None]:
class FeedForwardNetwork(nn.Module):
  '''
  Position-wise Feed Forward Network sub-layer as described in section 3.3. Used as part of the Encoder layer.

  Applies FFN(x) = fc_2(ReLU(fc_1(x))) to the final dimension of the input, so every position is
  transformed independently with the same weights.
  '''
  def __init__(self, d_model, d_ff):
    '''
    Args:
      d_model: input and output size
      d_ff: number of hidden units
    '''
    super(FeedForwardNetwork, self).__init__()
    # input size
    self.d_model = d_model
    # hidden units
    self.d_ff = d_ff
    # feed forward network layers
    self.fc_1 = nn.Linear(in_features=d_model, out_features=d_ff)
    # nn.Linear (capital L) with keyword out_features; the previous spelling raised at construction
    self.fc_2 = nn.Linear(in_features=d_ff, out_features=d_model)
    self.relu = nn.ReLU()

  def forward(self, x):
    '''
    Args:
      x: input of shape [..., d_model]

    Returns:
      tensor of shape [..., d_model]
    '''
    hidden = self.relu(self.fc_1(x)) # [..., d_model] -> [..., d_ff]
    return self.fc_2(hidden) # [..., d_ff] -> [..., d_model]

In [41]:
import torch
from torch import nn
import numpy as np

class PositionalEncoding(nn.Module):
  '''
  Positional Encoding as described in section 3.5

  Adds fixed sinusoidal position information to the input embeddings:
    PE(pos, 2i)   = sin(pos / 10000^(2i / d_model))
    PE(pos, 2i+1) = cos(pos / 10000^(2i / d_model))
  '''
  def __init__(self, d_model):
    '''
    Args:
      d_model: embedding size
    '''
    super(PositionalEncoding, self).__init__()
    # embedding size
    self.d_model = d_model
    # 2i / d_model
    self.exp = torch.arange(start=0, end=self.d_model, step=2, dtype=torch.float32) / self.d_model
    # 10000
    self.base = torch.full(size=(self.exp.shape[-1],), fill_value=10000.0, dtype=torch.float32)
    # 10000 ^ (2i / d_model)
    self.denominator = torch.pow(self.base, self.exp)

  def forward(self, x):
    '''
    Args:
      x: input embeddings of shape [..., sz_x, d_model]

    Returns:
      x with the positional encoding added, same shape as x
    '''
    # input sequence size
    sz_x = x.shape[-2]
    # build the encoding on the same device as the input so the addition below is valid
    denominator = self.denominator.to(x.device)

    # sequence positions 0 .. sz_x - 1 as a column so they broadcast against every frequency
    pos = torch.arange(sz_x, dtype=torch.float32, device=x.device).unsqueeze(-1) # [sz_x, 1]
    # pos / 10000^(2i / d_model); the position must enter here or every row is identical
    angles = pos / denominator # [sz_x, ceil(d_model / 2)]

    # initialise positional encoding for each sequence position
    pe = torch.zeros(size=(sz_x, self.d_model), device=x.device)
    # PE(pos, 2i) = sin(pos / 10000^(2i / d_model))
    pe[:, 0::2] = torch.sin(angles)
    # PE(pos, 2i+1) = cos(pos / 10000^(2i / d_model)); an odd d_model has one fewer odd column
    pe[:, 1::2] = torch.cos(angles[:, :self.d_model // 2])

    # combine input embedding and positional encoding (pe broadcasts over leading dimensions)
    x = x + pe

    return x



# demo: add positional encodings to a random batch of 3 sequences with 8 positions each
p = PositionalEncoding(512)
output = p(torch.randn(3, 8, d_model))

# show the encoded tensor followed by its shape, one per line
print(output, output.shape, sep="\n")

tensor([[[ 0.1409,  0.1681,  0.7124,  ..., -0.1476,  0.8687,  1.7650],
         [ 2.2070,  0.6970,  2.7329,  ...,  1.7209,  2.6870, -1.0465],
         [ 2.1869, -0.5471,  1.7622,  ...,  2.7226,  0.9175, -0.3200],
         ...,
         [ 1.4383,  0.6669,  0.8037,  ...,  0.5861,  2.2831, -0.4902],
         [ 1.0945,  0.6509,  1.4815,  ...,  1.7653, -0.6370, -0.0856],
         [ 0.9465, -0.3051,  2.0486,  ...,  0.2450,  1.3773, -0.5689]],

        [[-0.3004,  1.5436,  0.4095,  ...,  1.0171, -0.1957,  0.7254],
         [ 0.5467,  0.9954,  0.3952,  ...,  0.1653,  0.0851, -1.7218],
         [ 1.6440,  0.6102,  1.3465,  ...,  1.1019,  0.4261, -0.0400],
         ...,
         [ 0.0062,  1.7500,  1.0579,  ...,  1.1991,  0.7634, -0.2441],
         [ 2.1763,  1.8964, -0.9343,  ...,  0.7227,  2.0685,  0.2015],
         [ 1.3125,  0.6929, -0.0996,  ...,  0.3745,  1.0795, -0.3656]],

        [[ 1.0870,  0.6678,  0.8443,  ..., -0.2066, -0.3458, -0.0820],
         [ 1.1817,  0.3142,  0.3208,  ..., -0