In [2]:
import torch
import torch.nn as nn
import math
from datasets import load_dataset
from tokenizers import Tokenizer
from tokenizers.models import WordLevel
from tokenizers.trainers import WordLevelTrainer
from tokenizers.pre_tokenizers import Whitespace
from pathlib import Path
from torch.utils.data import Dataset, DataLoader, random_split

In [3]:
class InputEmbeddings(nn.Module):
    """Token-ID embedding lookup, scaled by sqrt(d_model).

    The embeddings are multiplied by ``sqrt(d_model)`` so that the positional
    encodings added afterwards are comparatively small and do not swamp the
    token meaning.

    Args:
        d_model: Embedding dimension.
        vocab_size: Number of rows in the embedding table.
    """

    def __init__(self, d_model: int, vocab_size: int) -> None:
        super().__init__()
        self.d_model = d_model
        self.vocab_size = vocab_size
        self.embedding = nn.Embedding(vocab_size, d_model)

    def forward(self, x):
        # x holds integer token IDs, e.g. (batch, seq_len) -> (batch, seq_len, d_model).
        scale = math.sqrt(self.d_model)
        looked_up = self.embedding(x)
        return looked_up * scale

class PositionalEmbedding(nn.Module):
    """Adds fixed sinusoidal positional encodings to embeddings, then applies dropout.

    The encoding table is computed once in ``__init__`` for ``maxlen`` positions and
    registered as a (non-trainable) buffer, so it follows the module across devices
    and is stored in ``state_dict``. ``forward`` slices it to the actual sequence length.

    Args:
        maxlen: Maximum sequence length the table supports.
        d_model: Embedding dimension (odd values are supported).
        dropout: Dropout probability applied after adding the encodings.
    """

    def __init__(self, maxlen: int, d_model: int, dropout: float) -> None:
        super().__init__()
        self.maxlen = maxlen
        self.dropout = nn.Dropout(p=dropout)
        self.d_model = d_model

        # Position index of each row, shape (maxlen, 1).
        position = torch.arange(0, maxlen, dtype=torch.float32).unsqueeze(1)
        # 1 / 10000^(2i / d_model) for each even dimension 2i, shape (ceil(d_model / 2),).
        div = torch.exp(torch.arange(0, d_model, 2, dtype=torch.float32) * -(math.log(10000.0) / d_model))
        angles = position * div  # (maxlen, ceil(d_model / 2))

        pos_emb = torch.zeros(maxlen, d_model)
        pos_emb[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one fewer odd column than even column.
        pos_emb[:, 1::2] = torch.cos(angles[:, : d_model // 2])

        # Registered as a buffer (not a Parameter) so it is never updated by the optimizer.
        # It must NOT be assigned as a plain attribute first: register_buffer raises
        # KeyError if an attribute with the same name already exists.
        self.register_buffer('pos_emb', pos_emb.unsqueeze(0))  # (1, maxlen, d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Add encodings to ``x`` of shape (batch, seq_len, d_model) and apply dropout.

        Raises:
            ValueError: if ``seq_len`` exceeds ``maxlen``.
        """
        seq_len = x.shape[1]
        if seq_len > self.maxlen:
            raise ValueError(f"sequence length {seq_len} exceeds maxlen {self.maxlen}")
        # The buffer does not require grad, so no gradient flows into the table.
        x = x + self.pos_emb[:, :seq_len, :]
        return self.dropout(x)

In [4]:
class LayerNorm(nn.Module):
    """Layer normalization over the last (feature) dimension with a scalar gain and bias.

    Each position is normalized using the mean and (unbiased) variance computed across
    its own features, then scaled by a single learnable ``alpha`` and shifted by a
    single learnable ``bias`` (both shape (1,), shared by all features).

    Args:
        eps: Added to the variance before the square root for numerical stability.
    """

    def __init__(self, eps: float = 1e-6) -> None:
        super().__init__()
        self.eps = eps
        self.alpha = nn.Parameter(torch.ones(1))
        self.bias = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        # Statistics are taken over dim=-1, i.e. per position, across features.
        centered = x - x.mean(dim=-1, keepdim=True)
        denom = torch.sqrt(x.var(dim=-1, keepdim=True) + self.eps)
        return self.alpha * centered / denom + self.bias

In [5]:
class FeedForward(nn.Module):
    """Position-wise feed-forward network.

    Computes ``linear2(dropout(relu(linear1(x))))``, expanding from ``d_model`` to
    ``d_ff`` and projecting back to ``d_model``.

    Args:
        d_model: Input/output dimension.
        d_ff: Hidden dimension.
        dropout: Dropout probability applied after the activation.
    """

    def __init__(self, d_model: int, d_ff: int = 2048, dropout: float = 0.1) -> None:
        super().__init__()
        # Registration order is kept so parameter order and state_dict keys are unchanged.
        self.linear1 = nn.Linear(d_model, d_ff)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.activation = nn.ReLU()

    def forward(self, x):
        hidden = self.linear1(x)
        hidden = self.activation(hidden)
        hidden = self.dropout(hidden)
        return self.linear2(hidden)

In [6]:
class MultiHeadAttentionBlock(nn.Module):
    """Multi-head scaled dot-product attention.

    Projects query/key/value to ``d_model``, splits them into ``n_heads`` heads of
    size ``d_model // n_heads``, attends per head, concatenates and applies ``w_out``.
    The post-softmax attention weights of the last call are kept in
    ``self.attention_scores`` for visualisation.

    Args:
        d_model: Model dimension; must be divisible by ``n_heads``.
        n_heads: Number of attention heads.
        dropout: Dropout probability applied to the attention weights.
    """

    def __init__(self, d_model: int, n_heads: int = 8, dropout: float = 0.1) -> None:
        super().__init__()
        assert d_model % n_heads == 0, "Number of heads must be a factor of model dimension"
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_head = d_model // n_heads

        # Projections for query, key, value and the concatenated output.
        self.w_query = nn.Linear(d_model, d_model)
        self.w_key = nn.Linear(d_model, d_model)
        self.w_value = nn.Linear(d_model, d_model)
        self.w_out = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)
        self.softmax = nn.Softmax(dim=-1)
        # Plain Python float: a tensor attribute here would not move with .to(device)
        # and would break division on GPU / change dtype under half precision.
        self.scale = math.sqrt(self.d_head)
        self.attention_scores = None

    def attention(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, mask: torch.Tensor = None, dropout: nn.Dropout = None) -> torch.Tensor:
        """Scaled dot-product attention on already-split heads.

        Args:
            query: (batch, n_heads, len_q, d_head).
            key: (batch, n_heads, len_k, d_head).
            value: (batch, n_heads, len_k, d_head).
            mask: Broadcastable to (batch, n_heads, len_q, len_k); positions where
                ``mask == 0`` are excluded from attention.
            dropout: Optional dropout applied to the attention weights.

        Returns:
            Tuple of the attended values (batch, n_heads, len_q, d_head) and the
            attention weights (batch, n_heads, len_q, len_k).
        """
        scores = torch.matmul(query, key.transpose(-2, -1)) / self.scale
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        # Normalise over the key dimension.
        scores = self.softmax(scores)
        if dropout is not None:
            scores = dropout(scores)
        return torch.matmul(scores, value), scores

    def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
        """Attend from ``query`` (batch, len_q, d_model) over ``key``/``value`` (batch, len_k, d_model).

        Returns:
            Tensor of shape (batch, len_q, d_model).
        """
        batch_size, len_q = query.shape[:2]
        len_k = key.shape[1]
        len_v = value.shape[1]

        # (batch, len, d_model) -> (batch, n_heads, len, d_head): every head sees the
        # whole sequence but only its own slice of the embedding space.
        q = self.w_query(query).view(batch_size, len_q, self.n_heads, self.d_head).transpose(1, 2)
        k = self.w_key(key).view(batch_size, len_k, self.n_heads, self.d_head).transpose(1, 2)
        v = self.w_value(value).view(batch_size, len_v, self.n_heads, self.d_head).transpose(1, 2)

        x, self.attention_scores = self.attention(q, k, v, mask, self.dropout)

        # Merge heads back: (batch, n_heads, len_q, d_head) -> (batch, len_q, d_model).
        # transpose makes x non-contiguous, so it must be made contiguous before view.
        x = x.transpose(1, 2).contiguous().view(batch_size, len_q, self.d_model)
        return self.w_out(x)

In [7]:
class ResidualConnection(nn.Module):
    """Pre-norm residual wrapper computing ``x + dropout(sublayer(norm(x)))``.

    ``sublayer`` is called with the normalised input, so it must be a callable
    (a module, or a lambda closing over extra arguments such as a mask) rather
    than a precomputed tensor.

    Args:
        dropout: Dropout probability applied to the sublayer output.
    """

    def __init__(self, dropout: float = 0.1) -> None:
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        self.norm = LayerNorm()

    def forward(self, x: torch.Tensor, sublayer: nn.Module):
        normed = self.norm(x)
        update = sublayer(normed)
        return x + self.dropout(update)

In [8]:
class EncoderBlock(nn.Module):
    """One pre-norm encoder layer: self-attention then feed-forward, each in a residual.

    Args:
        self_attention: Attention module used as self-attention (q = k = v = x).
        feed_forward: Position-wise feed-forward module.
        dropout: Dropout probability for both residual connections.
    """

    def __init__(self, self_attention: MultiHeadAttentionBlock, feed_forward: FeedForward, dropout: float = 0.1) -> None:
        super().__init__()
        self.self_attention = self_attention
        self.feed_forward = feed_forward
        self.res_connection1 = ResidualConnection(dropout)
        self.res_connection2 = ResidualConnection(dropout)

    def forward(self, x: torch.Tensor, src_mask: torch.Tensor) -> torch.Tensor:
        """Apply the layer to ``x`` (batch, seq_len, d_model).

        ``src_mask`` hides padding tokens from the attention computation.
        """
        # Sublayers are passed as callables so ResidualConnection can feed them the
        # normalised input; passing a precomputed tensor would skip the pre-norm
        # and fail when the residual tries to call it.
        x = self.res_connection1(x, lambda h: self.self_attention(h, h, h, src_mask))
        x = self.res_connection2(x, self.feed_forward)
        return x



In [9]:
class Encoder(nn.Module):
    """Stack of encoder layers followed by a final LayerNorm.

    Args:
        layers: Encoder blocks, applied in order. A plain list is accepted and is
            wrapped in ``nn.ModuleList`` so the blocks' parameters are registered.
    """

    def __init__(self, layers: nn.ModuleList) -> None:
        super().__init__()
        self.layers = nn.ModuleList(layers)
        self.norm = LayerNorm()

    # forward must be a method of the class (it was previously nested inside __init__,
    # leaving the module without a forward).
    def forward(self, x: torch.Tensor, src_mask: torch.Tensor) -> torch.Tensor:
        for layer in self.layers:
            x = layer(x, src_mask)
        return self.norm(x)
            

In [10]:
class DecoderBlock(nn.Module):
    """One pre-norm decoder layer: masked self-attention, cross-attention, feed-forward.

    Each sublayer is wrapped in its own residual connection.

    Args:
        self_attention: Attention over the target sequence (q = k = v = x).
        cross_attention: Attention from the target sequence over the encoder output.
        feed_forward: Position-wise feed-forward module.
        dropout: Dropout probability for the three residual connections.
    """

    def __init__(self, self_attention: MultiHeadAttentionBlock, cross_attention: MultiHeadAttentionBlock, feed_forward: FeedForward, dropout: float = 0.1) -> None:
        super().__init__()
        self.self_attention = self_attention
        self.cross_attention = cross_attention
        self.feed_forward = feed_forward
        self.res_connection1 = ResidualConnection(dropout)
        self.res_connection2 = ResidualConnection(dropout)
        self.res_connection3 = ResidualConnection(dropout)

    def forward(self, x: torch.Tensor, encoder_output: torch.Tensor, src_mask: torch.Tensor, trg_mask: torch.Tensor) -> torch.Tensor:
        """Apply the layer.

        ``trg_mask`` restricts target self-attention; ``src_mask`` hides source
        padding during cross-attention.
        """
        # Sublayers are passed as callables so each receives the normalised input.
        x = self.res_connection1(x, lambda h: self.self_attention(h, h, h, trg_mask))
        x = self.res_connection2(x, lambda h: self.cross_attention(h, encoder_output, encoder_output, src_mask))
        return self.res_connection3(x, self.feed_forward)
        

In [11]:
class Decoder(nn.Module):
    """Stack of decoder layers followed by a final LayerNorm.

    Args:
        layers: Decoder blocks, applied in order. A plain list is accepted and is
            wrapped in ``nn.ModuleList`` so the blocks' parameters are registered.
    """

    def __init__(self, layers: nn.ModuleList) -> None:
        super().__init__()
        # Previously never assigned, so forward raised AttributeError.
        self.layers = nn.ModuleList(layers)
        self.norm = LayerNorm()

    def forward(self, x: torch.Tensor, encoder_output: torch.Tensor, src_mask: torch.Tensor, trg_mask: torch.Tensor) -> torch.Tensor:
        for layer in self.layers:
            x = layer(x, encoder_output, src_mask, trg_mask)
        return self.norm(x)


In [12]:
class ProjectionLayer(nn.Module):
    """Final projection from model space to target-vocabulary log-probabilities.

    Args:
        d_model: Model dimension of the decoder output.
        vocab_size: Target vocabulary size.
    """

    def __init__(self, d_model: int, vocab_size: int) -> None:
        super().__init__()
        self.linear = nn.Linear(d_model, vocab_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # (batch_size, seq_len, d_model) -> (batch_size, seq_len, vocab_size)
        logits = self.linear(x)
        return torch.log_softmax(logits, dim=-1)

In [13]:
class Transformer(nn.Module):
    """Encoder-decoder Transformer assembled from pre-built components.

    No ``forward`` is defined: callers drive the model explicitly through
    ``encode``, ``decode`` and ``project``.
    """

    def __init__(self, encoder: Encoder, decoder: Decoder, src_embedding: InputEmbeddings, trg_embedding: InputEmbeddings, src_pos_embed: PositionalEmbedding,
                 trg_pos_embed: PositionalEmbedding, projection_layer: ProjectionLayer) -> None:
        super().__init__()
        # Assignment order determines parameter order (and hence the order in which
        # build_transformer's initialisation consumes random numbers); keep it fixed.
        self.encoder = encoder
        self.decoder = decoder
        self.src_embedding = src_embedding
        self.trg_embedding = trg_embedding
        self.src_pos_embed = src_pos_embed
        self.trg_pos_embed = trg_pos_embed
        self.projection_layer = projection_layer

    def encode(self, x: torch.Tensor, src_mask: torch.Tensor) -> torch.Tensor:
        """Embed source token IDs (batch_size, seq_len), add positions, run the encoder."""
        embedded = self.src_pos_embed(self.src_embedding(x))
        return self.encoder(embedded, src_mask)

    def decode(self, x: torch.Tensor, encoder_output: torch.Tensor, src_mask: torch.Tensor, trg_mask: torch.Tensor) -> torch.Tensor:
        """Embed target token IDs (batch_size, seq_len), add positions, run the decoder."""
        embedded = self.trg_pos_embed(self.trg_embedding(x))
        return self.decoder(embedded, encoder_output, src_mask, trg_mask)

    def project(self, x: torch.Tensor) -> torch.Tensor:
        """Map decoder output to the projection layer's output space."""
        return self.projection_layer(x)


def build_transformer(src_vocab_size: int,
                      trg_vocab_size: int,
                      src_sequence_length: int,
                      trg_sequence_length: int,
                      d_model: int = 512,
                      nhead: int = 8,
                      num_encoder_layers: int = 6,
                      num_decoder_layers: int = 6,
                      dim_feedforward: int = 2048,
                      dropout: float = 0.1) -> Transformer:
    """Build an encoder-decoder Transformer and Xavier-initialise its weight matrices.

    Args:
        src_vocab_size: Source vocabulary size.
        trg_vocab_size: Target vocabulary size (also the projection output size).
        src_sequence_length: Maximum source length (size of the source positional table).
        trg_sequence_length: Maximum target length (size of the target positional table).
        d_model: Model dimension.
        nhead: Number of attention heads.
        num_encoder_layers: Number of encoder blocks.
        num_decoder_layers: Number of decoder blocks.
        dim_feedforward: Hidden size of the feed-forward networks.
        dropout: Dropout probability used throughout.

    Returns:
        The assembled ``Transformer``.
    """
    src_embedding = InputEmbeddings(d_model, src_vocab_size)
    trg_embedding = InputEmbeddings(d_model, trg_vocab_size)
    # PositionalEmbedding takes (maxlen, d_model, dropout).
    src_pos_embed = PositionalEmbedding(src_sequence_length, d_model, dropout)
    trg_pos_embed = PositionalEmbedding(trg_sequence_length, d_model, dropout)

    encoder_blocks = []
    for _ in range(num_encoder_layers):
        encoder_self_attention = MultiHeadAttentionBlock(d_model, nhead, dropout)
        encoder_feedforward = FeedForward(d_model, dim_feedforward, dropout)
        encoder_blocks.append(EncoderBlock(encoder_self_attention, encoder_feedforward, dropout))

    decoder_blocks = []
    for _ in range(num_decoder_layers):
        decoder_self_attention = MultiHeadAttentionBlock(d_model, nhead, dropout)
        encoder_decoder_attention = MultiHeadAttentionBlock(d_model, nhead, dropout)
        decoder_feedforward = FeedForward(d_model, dim_feedforward, dropout)
        decoder_blocks.append(DecoderBlock(decoder_self_attention, encoder_decoder_attention, decoder_feedforward, dropout))

    # ModuleList registers the blocks as submodules so their parameters are trained/saved.
    encoder = Encoder(nn.ModuleList(encoder_blocks))
    decoder = Decoder(nn.ModuleList(decoder_blocks))

    # ProjectionLayer (linear + log_softmax), matching the Transformer's annotation.
    projection_layer = ProjectionLayer(d_model, trg_vocab_size)

    transformer = Transformer(encoder, decoder, src_embedding, trg_embedding, src_pos_embed, trg_pos_embed, projection_layer)

    # Glorot / fan_avg init for weight matrices; 1-D params (biases, norm gains) keep defaults.
    for p in transformer.parameters():
        if p.dim() > 1:
            nn.init.xavier_uniform_(p)

    return transformer

In [14]:
def get_all_sentences(dataset, lang):
    """Lazily yield the ``lang`` side of every translation pair in ``dataset``.

    Each item is read as ``item['translation'][lang]``.
    """
    yield from (pair['translation'][lang] for pair in dataset)

def build_tokenizer(config, dataset, lang):
    """Return a word-level tokenizer for ``lang``, loading it from disk when cached.

    The file path is ``config['tokenizer_file'].format(lang)``. If it exists, the
    tokenizer is loaded from it; otherwise a new WordLevel tokenizer is trained on the
    ``lang`` sentences of ``dataset`` and saved there.

    The tokenizer splits on whitespace, maps unknown words to ``[UNK]``, and defines
    the special tokens ``[UNK]``, ``[PAD]``, ``[SOS]`` and ``[EOS]``. Words appearing
    fewer than twice are not added to the vocabulary.
    """
    tokenizer_path = Path(config['tokenizer_file'].format(lang))
    if tokenizer_path.exists():
        return Tokenizer.from_file(str(tokenizer_path))

    # Unknown (out-of-vocabulary) words are replaced with [UNK].
    tokenizer = Tokenizer(WordLevel(unk_token='[UNK]'))
    tokenizer.pre_tokenizer = Whitespace()
    # min_frequency=2: a word must appear at least twice in the data to be kept.
    trainer = WordLevelTrainer(special_tokens=["[UNK]", "[PAD]", "[SOS]", "[EOS]"], min_frequency=2)
    tokenizer.train_from_iterator(get_all_sentences(dataset, lang), trainer=trainer)
    # Create the target directory first; save fails if it does not exist.
    tokenizer_path.parent.mkdir(parents=True, exist_ok=True)
    tokenizer.save(str(tokenizer_path))
    return tokenizer

def get_dataset(config, seed=None):
    """Load the opus_books pair, build/load both tokenizers, and split 90/10.

    Args:
        config: Mapping with at least ``lang_src``, ``lang_tgt`` and ``tokenizer_file``.
        seed: Optional integer; when given, the train/validation split is reproducible.

    Returns:
        ``(train_dataset_raw, val_dataset_raw, tokenizer_src, tokenizer_tgt)``.
        Previously the function computed these and discarded them, returning None.
    """
    dataset_raw = load_dataset('opus_books', f'{config["lang_src"]}-{config["lang_tgt"]}', split='train')

    tokenizer_src = build_tokenizer(config, dataset_raw, config["lang_src"])
    tokenizer_tgt = build_tokenizer(config, dataset_raw, config["lang_tgt"])

    # Split data 90% for training and 10% for validation.
    train_dataset_size = int(len(dataset_raw) * 0.9)
    val_dataset_size = len(dataset_raw) - train_dataset_size
    split_kwargs = {}
    if seed is not None:
        split_kwargs["generator"] = torch.Generator().manual_seed(seed)
    train_dataset_raw, val_dataset_raw = random_split(dataset_raw, [train_dataset_size, val_dataset_size], **split_kwargs)

    return train_dataset_raw, val_dataset_raw, tokenizer_src, tokenizer_tgt


In [16]:
class BilingualDataset(Dataset):
    """Map-style dataset that turns translation pairs into fixed-length model inputs.

    Each item of ``dataset`` is read as ``item['translation'][src_lang / tgt_lang]``.
    Tokenizers must provide ``encode(text).ids`` and ``token_to_id(token)``.

    Every returned sample has:
        encoder_input: [SOS] + src + [EOS] + [PAD]*, length ``seq_len``.
        decoder_input: [SOS] + tgt + [PAD]*, length ``seq_len``.
        label:         tgt + [EOS] + [PAD]*, length ``seq_len``.
        encoder_mask:  int (1, 1, seq_len), 1 on non-padding positions.
        decoder_mask:  int (1, seq_len, seq_len), 1 where position i may attend to j
                       (j <= i and j is not padding).
        src_txt / tgt_txt: the raw sentences.
    """

    def __init__(self, dataset, src_tokenizer, tgt_tokenizer, src_lang, tgt_lang, seq_len):
        super().__init__()
        self.seq_len = seq_len
        self.src_lang = src_lang
        self.tgt_lang = tgt_lang
        self.src_tokenizer = src_tokenizer
        self.tgt_tokenizer = tgt_tokenizer
        self.dataset = dataset

        # Special-token IDs as 1-element int64 tensors so they can be torch.cat'ed.
        # NOTE(review): IDs come from the target tokenizer and are also used on the
        # source side; this assumes both tokenizers give the special tokens the same IDs.
        self.sos_token = torch.tensor([tgt_tokenizer.token_to_id("[SOS]")], dtype=torch.int64)
        self.eos_token = torch.tensor([tgt_tokenizer.token_to_id("[EOS]")], dtype=torch.int64)
        self.pad_token = torch.tensor([tgt_tokenizer.token_to_id("[PAD]")], dtype=torch.int64)

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        pair = self.dataset[index]
        src_txt = pair['translation'][self.src_lang]
        tgt_txt = pair['translation'][self.tgt_lang]

        src_tokens = self.src_tokenizer.encode(src_txt).ids
        tgt_tokens = self.tgt_tokenizer.encode(tgt_txt).ids

        # The encoder input carries 2 extra tokens ([SOS], [EOS]); the decoder input
        # carries 1 ([SOS]) and the label carries 1 ([EOS]).
        enc_padding_token_count = self.seq_len - len(src_tokens) - 2
        dec_padding_token_count = self.seq_len - len(tgt_tokens) - 1
        if enc_padding_token_count < 0 or dec_padding_token_count < 0:
            raise ValueError(
                f'Sequence length is too long: source has {len(src_tokens)} tokens, '
                f'target has {len(tgt_tokens)} tokens, seq_len is {self.seq_len}'
            )

        src_ids = torch.tensor(src_tokens, dtype=torch.int64)
        tgt_ids = torch.tensor(tgt_tokens, dtype=torch.int64)
        enc_padding = self.pad_token.repeat(enc_padding_token_count)
        dec_padding = self.pad_token.repeat(dec_padding_token_count)

        encoder_input = torch.cat([self.sos_token, src_ids, self.eos_token, enc_padding], dim=0)
        decoder_input = torch.cat([self.sos_token, tgt_ids, dec_padding], dim=0)
        label = torch.cat([tgt_ids, self.eos_token, dec_padding], dim=0)

        assert encoder_input.shape[0] == self.seq_len
        assert decoder_input.shape[0] == self.seq_len
        assert label.shape[0] == self.seq_len

        # (1, seq_len) padding mask broadcast against the (1, seq_len, seq_len) causal mask.
        decoder_mask = ((decoder_input != self.pad_token).unsqueeze(0) & self.causal_mask(decoder_input.shape[0])).int()

        return {
            "encoder_input": encoder_input,
            "decoder_input": decoder_input,
            "label": label,
            # Shape (1, 1, seq_len).
            "encoder_mask": (encoder_input != self.pad_token).unsqueeze(0).unsqueeze(0).int(),
            "decoder_mask": decoder_mask,
            "src_txt": src_txt,
            "tgt_txt": tgt_txt,
        }

    def causal_mask(self, length):
        """Return a bool (1, length, length) mask that is True where column j <= row i.

        ``torch.triu(..., diagonal=1)`` keeps only entries strictly above the main
        diagonal; comparing with 0 selects the lower triangle including the diagonal.
        """
        upper = torch.triu(torch.ones((1, length, length)), diagonal=1)
        return upper == 0

