<a href="https://colab.research.google.com/github/torrhen/paper-transformer/blob/main/Transformer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
from torch import nn
import math

In [None]:
class ScaledDotProductAttention(nn.Module):
  '''
  Scaled Dot-Product Attention function as described in section 3.2.1. Used as part of the Multi-Head Attention layer.

  Attention(Q, K, V) = softmax(Q K^T / sqrt(d_k)) V
  '''
  def __init__(self):
    super(ScaledDotProductAttention, self).__init__()
    # calculate attention weights over the key dimension
    self.softmax = nn.Softmax(dim=-1)

  def forward(self, Q, K, V, mask=None):
    '''
    Args:
      Q: queries, [..., sz_q, d_k].
      K: keys, [..., sz_k, d_k].
      V: values, [..., sz_k, d_v].
      mask: optional boolean tensor broadcastable to [..., sz_q, sz_k]; True entries are
        filled with -inf before the softmax, i.e. excluded from attention.

    Returns:
      (attention weighted values [..., sz_q, d_v], attention weights [..., sz_q, sz_k])
    '''
    # key/query projection size, used for the 1/sqrt(d_k) scaling
    d_k = K.shape[-1]

    # swap only the final 2 dimensions of K so any number of leading (batch, head) dims works
    # [..., sz_q, d_k] @ [..., d_k, sz_k] -> [..., sz_q, sz_k]
    attn = Q.matmul(K.transpose(-2, -1))

    # scale by sqrt(d_k) (not d_k) to keep dot products from pushing the softmax into saturation
    attn = attn / math.sqrt(d_k)

    # mask out illegal attention value connections
    # NOTE: a row that is masked everywhere becomes NaN after the softmax
    if mask is not None:
      attn = attn.masked_fill(mask, -math.inf)

    # convert attention values to weights
    attn = self.softmax(attn)
    # multiply weighted attention with V
    out = attn.matmul(V)

    return out, attn # attention weighted values, attention weights


In [None]:
class MultiHeadAttention(nn.Module):
  '''
  Multi-Head Attention sub-layer as described in section 3.2.2. Used by the Encoder and Decoder layers.

  Args:
    d_model: embedding size; must be divisible by h.
    h: number of attention heads.

  Raises:
    ValueError: if h is not positive or d_model is not divisible by h.
  '''
  def __init__(self, d_model, h):
    super(MultiHeadAttention, self).__init__()
    # the embedding is split evenly across heads; an uneven split would make the
    # reshape in forward() silently mis-assign features or fail with an obscure error
    if h <= 0 or d_model % h != 0:
      raise ValueError(f"d_model ({d_model}) must be divisible by the number of heads h ({h})")
    # embedding size
    self.d_model = d_model
    # number of heads
    self.h = h
    # per-head projection size for query, keys and values vectors
    self.d_q = self.d_k = self.d_v = self.d_model // self.h
    # linear projection layers for embeddings
    self.fc_Q = nn.Linear(in_features=self.d_model, out_features=self.d_model)
    self.fc_K = nn.Linear(in_features=self.d_model, out_features=self.d_model)
    self.fc_V = nn.Linear(in_features=self.d_model, out_features=self.d_model)
    # attention function
    self.attention = ScaledDotProductAttention()
    # linear projection layer for attention
    self.fc_mh_out = nn.Linear(in_features=self.d_model, out_features=self.d_model)

  def forward(self, Q, K, V, mask=None):
    '''
    Args:
      Q: queries, [b, sz_q, d_model].
      K: keys, [b, sz_k, d_model].
      V: values, [b, sz_v, d_model]; sz_v must equal sz_k.
      mask: optional boolean mask broadcastable to [b, h, sz_q, sz_k]; True entries are excluded from attention.

    Returns:
      (multi-head output [b, sz_q, d_model], attention weights [b, h, sz_q, sz_k])
    '''
    batch_size = Q.shape[0]
    # linear projection of Q, K and V
    p_Q = self.fc_Q(Q) # [b, sz_q, d_model] -> [b, sz_q, d_model]
    p_K = self.fc_K(K) # [b, sz_k, d_model] -> [b, sz_k, d_model]
    p_V = self.fc_V(V) # [b, sz_v, d_model] -> [b, sz_v, d_model]

    # divide embedding dimension into separate heads for Q, K, V
    p_Q = p_Q.reshape((batch_size, -1, self.h, self.d_q)) # [b, sz_q, d_model] -> [b, sz_q, h, d_q]
    p_K = p_K.reshape((batch_size, -1, self.h, self.d_k)) # [b, sz_k, d_model] -> [b, sz_k, h, d_k]
    p_V = p_V.reshape((batch_size, -1, self.h, self.d_v)) # [b, sz_v, d_model] -> [b, sz_v, h, d_v]

    # move the head dimension of Q, K and V in front of the sequence dimension
    p_Q = p_Q.permute((0, 2, 1, 3)) # [b, sz_q, h, d_q] -> [b, h, sz_q, d_q]
    p_K = p_K.permute((0, 2, 1, 3)) # [b, sz_k, h, d_k] -> [b, h, sz_k, d_k]
    p_V = p_V.permute((0, 2, 1, 3)) # [b, sz_v, h, d_v] -> [b, h, sz_v, d_v]

    # calculate the scaled dot product attention for each head in parallel
    mh_out, mh_attn = self.attention(p_Q, p_K, p_V, mask) # -> [b, h, sz_q, d_v], [b, h, sz_q, sz_k]

    # move the head dimension back behind the sequence dimension
    mh_out = mh_out.permute((0, 2, 1, 3)) # [b, h, sz_q, d_v] -> [b, sz_q, h, d_v]

    # concatenate heads of attention weighted values (reshape copies the non-contiguous permute)
    mh_out = mh_out.reshape((batch_size, -1, self.d_model)) # [b, sz_q, h, d_v] -> [b, sz_q, h * d_v (d_model)]

    # linear projection of attention weighted values
    mh_out = self.fc_mh_out(mh_out) # [b, sz_q, d_model] -> [b, sz_q, d_model]

    return mh_out, mh_attn # multi-head output, multi-head attention weights

In [None]:
class FeedForwardNetwork(nn.Module):
  '''
  Position-wise Feed Forward Network sub-layer (section 3.3), applied independently at every position:

    FFN(x) = max(0, x W_1 + b_1) W_2 + b_2

  Args:
    d_model: input / output feature size.
    d_ff: number of hidden units.
  '''
  def __init__(self, d_model, d_ff):
    super().__init__()
    self.d_model, self.d_ff = d_model, d_ff
    # expand to the hidden size, then project back to the model size
    self.fc_1 = nn.Linear(d_model, d_ff)
    self.fc_2 = nn.Linear(d_ff, d_model)
    self.relu = nn.ReLU()

  def forward(self, x):
    hidden = self.relu(self.fc_1(x))
    return self.fc_2(hidden)

In [None]:
import torch
from torch import nn
import numpy as np

class PositionalEncoding(nn.Module):
  '''
  Positional Encoding as described in section 3.5.

    PE(pos, 2i)   = sin(pos / 10000^(2i / d_model))
    PE(pos, 2i+1) = cos(pos / 10000^(2i / d_model))

  forward() adds the encoding to its input, treating the second-to-last dimension as the
  sequence (position) dimension; the last dimension must equal d_model.
  '''
  def __init__(self, d_model):
    super(PositionalEncoding, self).__init__()
    # embedding size
    self.d_model = d_model
    # 2i / d_model for every even feature index 2i
    self.exp = torch.arange(start=0, end=self.d_model, step=2, dtype=torch.float32) / self.d_model
    # 10000
    self.base = torch.full(size=(self.exp.shape[-1],), fill_value=10000.0, dtype=torch.float32)
    # 10000 ^ (2i / d_model); a non-persistent buffer follows .to(device) / .cuda()
    # without adding an entry to the module's state_dict
    self.register_buffer('denominator', torch.pow(self.base, self.exp), persistent=False)

  def forward(self, x):
    # input sequence size
    sz_x = x.shape[-2]
    denominator = self.denominator.to(x.device)

    # positions as a column so pos / denominator broadcasts to [sz_x, ceil(d_model / 2)]
    positions = torch.arange(sz_x, dtype=denominator.dtype, device=x.device).unsqueeze(-1)
    angles = positions / denominator

    pe = torch.zeros(size=(sz_x, self.d_model), dtype=denominator.dtype, device=x.device)
    # PE(pos, 2i) = sin(pos / 10000^(2i / d_model))
    pe[:, 0::2] = torch.sin(angles)
    # PE(pos, 2i+1) = cos(pos / 10000^(2i / d_model)); an odd d_model has one fewer odd column
    pe[:, 1::2] = torch.cos(angles[:, :self.d_model // 2])

    # combine input embedding and positional encoding
    return x + pe

In [None]:
class EncoderLayer(nn.Module):
  '''
  Encoder layer as described in section 3.1: a multi-head self-attention sub-layer followed by a
  position-wise feed forward sub-layer, each wrapped in a residual connection and layer norm.

  Args:
    d_model: embedding size.
    d_ff: feed forward network hidden units.
    h: number of attention heads (default 8); d_model must be divisible by h.
  '''
  def __init__(self, d_model, d_ff, h=8):
    super(EncoderLayer, self).__init__()
    # embedding size
    self.d_model = d_model
    # number of attention heads
    self.h = h
    # feed forward network hidden units
    self.d_ff = d_ff
    # multi-head attention sub-layer
    self.mha = MultiHeadAttention(self.d_model, self.h)
    # multi-head attention layer norm
    self.layer_norm_mha = nn.LayerNorm(normalized_shape=self.d_model)
    # feed forward network sub-layer
    self.ffn = FeedForwardNetwork(self.d_model, self.d_ff)
    # feed forward network layer norm
    self.layer_norm_ffn = nn.LayerNorm(normalized_shape=self.d_model)

  def forward(self, x, mask=None):
    '''
    Args:
      x: input, [b, sz, d_model].
      mask: optional boolean mask for self-attention (e.g. source padding); True entries are
        excluded from attention. Omitted by default, so every position attends to every other.
    '''
    # multi-head self-attention: queries, keys and values all come from x
    query = keys = values = x
    mha_out, mha_attn = self.mha(query, keys, values, mask)
    # residual connection and layer norm (post-norm, as in the paper)
    x = self.layer_norm_mha(x + mha_out)

    # feed forward network
    ffn_out = self.ffn(x)
    # residual connection and layer norm
    x = self.layer_norm_ffn(x + ffn_out)
    return x

In [None]:
class Encoder(nn.Module):
  '''
  Encoder as described in section 3.1. Contains a stack of N encoder layers.

  Args:
    N: number of encoder layers.
    d_model: embedding size.
    h: number of attention heads used by every layer.
    d_ff: feed forward network hidden units.
  '''
  def __init__(self, N, d_model, h, d_ff):
    super(Encoder, self).__init__()
    # number of encoder layers
    self.N = N
    # embedding size
    self.d_model = d_model
    # number of attention heads
    self.h = h
    # feed forward network hidden units
    self.d_ff = d_ff
    # encoder of N encoder layers
    self.encoder = nn.ModuleList([self._build_layer() for _ in range(self.N)])

  def _build_layer(self):
    '''Create one EncoderLayer whose attention sub-layer uses self.h heads.'''
    layer = EncoderLayer(self.d_model, self.d_ff)
    if layer.h != self.h:
      # EncoderLayer chooses its own head count when none is passed; replace its attention
      # sub-layer so the requested number of heads is actually used instead of silently ignored
      layer.h = self.h
      layer.mha = MultiHeadAttention(self.d_model, self.h)
    return layer

  def forward(self, x):
    # pass input through each layer of the encoder in order
    for encoder_layer in self.encoder:
      x = encoder_layer(x)
    return x

In [None]:
class DecoderLayer(nn.Module):
  '''
  Decoder layer as described in section 3.1: masked multi-head self-attention, multi-head
  attention over the encoder output, and a position-wise feed forward sub-layer, each wrapped
  in a residual connection and layer norm.

  Args:
    d_model: embedding size.
    d_ff: feed forward network hidden units.
    h: number of attention heads (default 8); d_model must be divisible by h.
  '''
  def __init__(self, d_model, d_ff, h=8):
    super(DecoderLayer, self).__init__()
    # embedding size
    self.d_model = d_model
    # number of attention heads
    self.h = h
    # feed forward network hidden units
    self.d_ff = d_ff

    # masked multi-head attention sub-layer
    self.masked_mha = MultiHeadAttention(self.d_model, self.h)
    # masked multi-head attention layer norm
    self.layer_norm_masked_mha = nn.LayerNorm(normalized_shape=self.d_model)

    # multi-head attention sub-layer
    self.mha = MultiHeadAttention(self.d_model, self.h)
    # multi-head attention layer norm
    self.layer_norm_mha = nn.LayerNorm(normalized_shape=self.d_model)

    # feed forward network sub-layer
    self.ffn = FeedForwardNetwork(self.d_model, self.d_ff)
    # feed forward network layer norm
    self.layer_norm_ffn = nn.LayerNorm(normalized_shape=self.d_model)

  def forward(self, x, encoder_output, mask=None, memory_mask=None):
    '''
    Args:
      x: decoder input, [b, sz_dec, d_model].
      encoder_output: encoder output, [b, sz_enc, d_model].
      mask: optional boolean mask for decoder self-attention (typically causal); True entries
        are excluded from attention.
      memory_mask: optional boolean mask for attention over encoder_output (e.g. source
        padding); True entries are excluded from attention.
    '''
    # masked multi-head self-attention
    query = keys = values = x
    masked_mha_out, masked_mha_attn = self.masked_mha(query, keys, values, mask)
    # residual connection and layer norm
    x = self.layer_norm_masked_mha(x + masked_mha_out)

    # multi-head attention: queries from the decoder, keys and values from the encoder
    query = x
    keys = values = encoder_output
    mha_out, mha_attn = self.mha(query, keys, values, memory_mask)
    # residual connection and layer norm
    x = self.layer_norm_mha(x + mha_out)

    # feed forward network
    ffn_out = self.ffn(x)
    # residual connection and layer norm
    x = self.layer_norm_ffn(x + ffn_out)

    return x

In [None]:
class Decoder(nn.Module):
  '''
  Decoder as described in section 3.1. Contains a stack of N decoder layers.

  Args:
    N: number of decoder layers.
    d_model: embedding size.
    h: number of attention heads used by every layer.
    d_ff: feed forward network hidden units.
  '''
  def __init__(self, N,  d_model, h, d_ff):
    super(Decoder, self).__init__()
    # number of decoder layers
    self.N = N
    # embedding size
    self.d_model = d_model
    # number of attention heads
    self.h = h
    # feed forward network hidden units
    self.d_ff = d_ff
    # decoder of N decoder layers
    self.decoder = nn.ModuleList([self._build_layer() for _ in range(self.N)])

  def _build_layer(self):
    '''Create one DecoderLayer whose two attention sub-layers use self.h heads.'''
    layer = DecoderLayer(self.d_model, self.d_ff)
    if layer.h != self.h:
      # DecoderLayer chooses its own head count when none is passed; replace both attention
      # sub-layers so the requested number of heads is actually used instead of silently ignored
      layer.h = self.h
      layer.masked_mha = MultiHeadAttention(self.d_model, self.h)
      layer.mha = MultiHeadAttention(self.d_model, self.h)
    return layer

  def forward(self, x, encoder_output, mask=None):
    # pass inputs through each layer of the decoder; every layer attends to the same encoder output
    for decoder_layer in self.decoder:
      x = decoder_layer(x, encoder_output, mask)
    return x

In [None]:
class Transformer(nn.Module):
  '''
  Transformer architecture as described in section 3.

  Args:
    N_enc: number of encoder layers.
    N_dec: number of decoder layers.
    d_model: embedding size.
    h: number of attention heads.
    d_ff: feed forward network hidden units.
    vocab_size: size of the output vocabulary (default 10000; set to the target vocabulary size).
  '''
  def __init__(self, N_enc, N_dec, d_model, h, d_ff, vocab_size=10000):
    super(Transformer, self).__init__()
    # number of encoder layers
    self.N_enc = N_enc
    # number of decoder layers
    self.N_dec = N_dec
    # embedding_size
    self.d_model = d_model
    # number of attention heads
    self.h = h
    # feed forward hidden units
    self.d_ff = d_ff
    # output vocabulary size
    self.vocab_size = vocab_size

    # TODO: input embedding
    # TODO: output embedding

    self.PE = PositionalEncoding(self.d_model)
    self.encoder = Encoder(self.N_enc, self.d_model, self.h, self.d_ff)
    self.decoder = Decoder(self.N_dec, self.d_model, self.h, self.d_ff)
    self.fc_o = nn.Linear(in_features=self.d_model, out_features=self.vocab_size)
    self.softmax = nn.Softmax(dim=-1)

  @staticmethod
  def _causal_mask(size, device):
    '''
    Boolean [size, size] mask that is True strictly above the diagonal, i.e. at every
    (query, key) pair where the key position lies in the query's future.
    '''
    # the attention function fills True entries with -inf before its softmax
    return torch.triu(torch.ones(size, size, dtype=torch.bool, device=device), diagonal=1)

  def forward(self, x_enc, x_dec):
    '''
    Args:
      x_enc: source embeddings, [b, sz_enc, d_model].
      x_dec: target embeddings, [b, sz_dec, d_model].

    Returns:
      Output probabilities over the vocabulary, [b, sz_dec, vocab_size].
    '''
    # TODO: input embedding
    # TODO: output embedding

    # positional encoding of input embedding
    x_enc = self.PE(x_enc)
    # positional encoding of output embedding
    x_dec = self.PE(x_dec)

    # encoder
    y_enc = self.encoder(x_enc)
    # decoder: the causal mask stops each target position from attending to later positions
    dec_mask = self._causal_mask(x_dec.shape[-2], x_dec.device)
    y_dec = self.decoder(x_dec, y_enc, dec_mask)

    # linear projection to vocabulary size
    output = self.fc_o(y_dec)
    # output probabilities
    output = self.softmax(output)

    return output

In [None]:
# required for access to torchtext datasets
!pip install -U torchdata

# required for tokenizer
!pip install -U spacy
!python -m spacy download de_core_news_sm
!python -m spacy download en_core_web_sm

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting torchdata
  Downloading torchdata-0.5.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (4.6 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m4.6/4.6 MB[0m [31m33.5 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting urllib3>=1.25
  Downloading urllib3-1.26.14-py2.py3-none-any.whl (140 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m140.6/140.6 KB[0m [31m13.7 MB/s[0m eta [36m0:00:00[0m
Collecting portalocker>=2.0.0
  Downloading portalocker-2.7.0-py2.py3-none-any.whl (15 kB)
Installing collected packages: urllib3, portalocker, torchdata
  Attempting uninstall: urllib3
    Found existing installation: urllib3 1.24.3
    Uninstalling urllib3-1.24.3:
      Successfully uninstalled urllib3-1.24.3
Successfully installed portalocker-2.7.0 torchdata-0.5.1 urllib3-1.26.14
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.

In [None]:
from torchtext.datasets import Multi30k
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator

# translation direction: German (source) -> English (target)
SRC_LANGUAGE = 'de'
TAR_LANGUAGE = 'en'

# one spaCy-backed tokenizer per language, keyed by language code
tokenizers = {
  SRC_LANGUAGE: get_tokenizer('spacy', language='de_core_news_sm'),
  TAR_LANGUAGE: get_tokenizer('spacy', language='en_core_web_sm'),
}

# yield list of tokens using language tokenizer
def yield_tokens(data_iter, language):
  # Generator: for every (de, en) sample in data_iter, yield the token list of the text
  # belonging to `language`, produced by that language's tokenizer.
  column_of = {SRC_LANGUAGE : 0, TAR_LANGUAGE : 1}
  for pair in data_iter:
    tokenize = tokenizers[language]
    text = pair[column_of[language]]
    yield tokenize(text)

# Special symbols; with special_first=True they occupy the first vocabulary indices in this order:
#   <unk> - infrequent or unknown words outside of the vocabulary
#   <pad> - pads sequences in a batch to the same length
#   <bos> - beginning of sentence
#   <eos> - end of sentence
UNK_IDX, PAD_IDX, BOS_IDX, EOS_IDX = 0, 1, 2, 3
special_symbols = ['<unk>', '<pad>', '<bos>', '<eos>']

# one vocabulary per language, built from the training split
vocabularies = {}
for language in (SRC_LANGUAGE, TAR_LANGUAGE):
  # load the training split afresh for each language's pass over the data
  train_iter = Multi30k(split='train', language_pair=(SRC_LANGUAGE, TAR_LANGUAGE))
  vocabularies[language] = build_vocab_from_iterator(
    yield_tokens(train_iter, language),
    min_freq=1,
    specials=special_symbols,
    special_first=True,
  )

# out-of-vocabulary lookups fall back to UNK_IDX
for language in (SRC_LANGUAGE, TAR_LANGUAGE):
  vocabularies[language].set_default_index(UNK_IDX)

In [None]:
# German (source) vocabulary: token -> index mapping, then its size
src_vocab = vocabularies[SRC_LANGUAGE]
print(src_vocab.get_stoi())
print(f"German vocabulary size: {len(src_vocab)}")

# English (target) vocabulary: token -> index mapping, then its size
tar_vocab = vocabularies[TAR_LANGUAGE]
print(tar_vocab.get_stoi())
print(f"English vocabulary size: {len(tar_vocab)}")


{'’': 19212, 'üppiges': 19210, 'überzeugen': 19206, 'überwältigt': 19205, 'überwiegende': 19202, 'überstehendes': 19199, 'überschlagenen': 19196, 'überschatteten': 19194, 'überraschten': 19193, 'überlisten': 19188, 'überlegen': 19187, 'überholt': 19185, 'übergroßer': 19183, 'übergießt': 19181, 'überfülltes': 19179, 'überfüllter': 19178, 'überfüllte': 19177, 'übereinstimmenden': 19175, 'überdimensioniertem': 19173, 'überdimensionierte': 19172, 'örtliche': 19167, 'öffentlich': 19165, 'äteres': 19160, 'ärztliche': 19158, 'ändern': 19154, 'ältliche': 19153, 'älterem': 19151, 'ähnelt': 19148, 'ägyptischem': 19147, 'Übungssaal': 19145, 'Übungsgrün': 19143, 'Überwurf': 19142, 'Überschwemmung': 19141, 'Überschlags': 19140, 'Übergabe': 19137, 'Überdachungen': 19136, 'Überblick': 19135, 'Überall': 19134, 'Über': 19133, 'Ölzeug': 19131, 'Ölverseuchung': 19130, 'Ölkanne': 19128, 'ÖPNV': 19125, 'Äxte': 19124, 'Ärzten': 19123, 'Ägyptisches': 19121, 'zögerndes': 19120, 'zwölf': 19118, 'zuzuwerfen': 1