In [23]:
import torch 
import torch.nn as nn
import math

In [24]:
class InputEmbeddings(nn.Module):
    """Token-id to vector lookup, scaled by sqrt(d_model) as described in Section 3.4.

    Args:
        d_model: size of each embedding vector.
        vocabulary_size: number of rows in the embedding table (distinct token ids).
    """

    def __init__(self, d_model: int, vocabulary_size: int):
        super().__init__()
        self.d_model = d_model
        self.vocabulary_size = vocabulary_size
        self.embedding = nn.Embedding(vocabulary_size, d_model)

    def forward(self, x):
        # Section 3.4: the embedding weights are multiplied by sqrt(d_model).
        scale = math.sqrt(self.d_model)
        token_vectors = self.embedding(x)
        return token_vectors * scale

In [25]:
"""
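Reference Section 3.5.
Positional encodings inject each token's position into its embedding.
Self-attention by itself ignores token order, so without this signal the model could not tell apart
sequences that contain the same words in a different arrangement. Position (and therefore context)
also helps disambiguate words, e.g. "book" is a verb in the first sentence and a noun in the second:
"It is cheaper to book now"
"Can you return the book?"
We build one d_model-dimensional vector for each of the sentence_len positions.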
"""

class PositionalEmbeddings(nn.Module):
    """Add fixed sinusoidal positional encodings to embeddings (Section 3.5).

    PE(pos, 2i)   = sin(pos / 10000^(2i / d_model))
    PE(pos, 2i+1) = cos(pos / 10000^(2i / d_model))

    Args:
        d_model: feature size of the inputs the encodings are added to (odd values are supported).
        sentence_len: number of positions precomputed; inputs longer than this cannot be encoded.
        dropout: dropout probability applied after the encodings are added.
    """

    def __init__(self, d_model: int, sentence_len: int, dropout: float):
        super().__init__()
        self.d_model = d_model
        self.sentence_len = sentence_len
        self.dropout = nn.Dropout(dropout)
        positional_matrix = torch.zeros(sentence_len, d_model)
        # Column of positions 0 .. sentence_len - 1, shape (sentence_len, 1).
        position = torch.arange(0, sentence_len, dtype=torch.float32).unsqueeze(1)
        # 1 / 10000^(2i / d_model), computed in log space as exp(-2i * ln(10000) / d_model).
        frequencies = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        positional_matrix[:, 0::2] = torch.sin(position * frequencies)
        # When d_model is odd there is one fewer odd (cosine) column than even (sine)
        # column, so only the first d_model // 2 frequencies are used here.
        positional_matrix[:, 1::2] = torch.cos(position * frequencies[: d_model // 2])
        # Leading size-1 batch axis so the matrix broadcasts over batched inputs.
        # A buffer is saved in the state_dict but is not a trainable parameter.
        self.register_buffer("positional_matrix", positional_matrix.unsqueeze(0))

    def forward(self, x):
        # x.shape[1] is the sequence length; take that many positions. The encodings
        # are constant, so no gradient is tracked for them.
        x = x + self.positional_matrix[:, : x.shape[1], :].requires_grad_(False)
        return self.dropout(x)

In [26]:
"""
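Layer normalization normalizes each token's feature vector independently. It subtracts the mean
and divides by the standard deviation computed over the last (feature) dimension, then applies a
learnable scale (alpha) and shift (beta). Because the statistics are computed per token rather than
across the batch, the result does not depend on the batch size.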
"""

class LayerNormalization(nn.Module):
    """Normalize over the last dimension, then apply a learnable gain and bias.

    Computes alpha * (x - mean) / (std + epsilon) + beta, with mean and std taken
    over the final axis. std is torch's default (Bessel-corrected) estimate, and
    epsilon is added to the std rather than to the variance.

    Args:
        features: size of the last dimension; sets the shapes of alpha and beta.
        epsilon: small constant that keeps the division finite when std is 0.
    """

    def __init__(self, features: int, epsilon: float = 10**-6):
        super().__init__()
        self.epsilon = epsilon
        self.alpha = nn.Parameter(torch.ones(features))   # multiplicative gain
        self.beta = nn.Parameter(torch.zeros(features))   # additive bias

    def forward(self, x):
        # keepdim=True keeps the reduced axis (as size 1) so the statistics
        # broadcast back against x; PyTorch drops it by default.
        centered = x - x.mean(dim=-1, keepdim=True)
        spread = x.std(dim=-1, keepdim=True) + self.epsilon
        return self.alpha * (centered / spread) + self.beta

In [None]:
"""
Multi head attention implementation: 
Reference Section 3.2.1 and 3.2.2
"""
import torch.nn.functional as F

class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention (Sections 3.2.1 and 3.2.2).

    The model dimension is split across ``heads`` heads of size
    ``d_k = d_model // heads``. Each head attends independently; the head outputs
    are then concatenated and passed through the output projection ``w_o``.

    Args:
        d_model: model (embedding) dimension; must be divisible by ``heads``.
        heads: number of attention heads (h in the paper).
        dropout: dropout probability applied to the attention weights.
    """

    def __init__(self, d_model: int, heads: int, dropout: float):
        super().__init__()
        assert d_model % heads == 0, "d_model needs to be divisible by heads parameter"
        self.d_model = d_model
        self.heads = heads
        # Section 3.2.2: each head works in a d_k = d_model / h dimensional subspace.
        self.d_k = d_model // heads
        # Bias-free projections, matching the plain weight matrices W^Q, W^K, W^V, W^O in the paper.
        self.w_q = nn.Linear(d_model, d_model, bias=False)
        self.w_k = nn.Linear(d_model, d_model, bias=False)
        self.w_v = nn.Linear(d_model, d_model, bias=False)
        self.w_o = nn.Linear(d_model, d_model, bias=False)
        self.dropout = nn.Dropout(dropout)

    @staticmethod
    def attention(query, key, value, mask, dropout: nn.Dropout):
        """Scaled dot-product attention: softmax(Q K^T / sqrt(d_k)) V.

        Returns:
            A tuple (attended values, attention weights after softmax/dropout).
        """
        d_k = query.shape[-1]
        # Swap the last two axes of key so the product scores every query against every key.
        attention_scores = (query @ key.transpose(-2, -1)) / math.sqrt(d_k)
        if mask is not None:
            # Positions where mask == 0 (e.g. padding) get a huge negative score,
            # so softmax assigns them ~0 weight. mask must broadcast to the score shape.
            attention_scores.masked_fill_(mask == 0, -1e9)
        attention_scores = F.softmax(attention_scores, dim=-1)
        if dropout is not None:
            attention_scores = dropout(attention_scores)
        return torch.matmul(attention_scores, value), attention_scores

    def _split_heads(self, t):
        """Reshape (batch, seq_len, d_model) -> (batch, heads, seq_len, d_k)."""
        return t.view(t.shape[0], t.shape[1], self.heads, self.d_k).transpose(1, 2)

    def forward(self, q, k, v, mask=None):
        """Attend from queries ``q`` over keys ``k`` / values ``v``.

        Args:
            q: query input, (batch, seq_q, d_model).
            k: key input, (batch, seq_k, d_model).
            v: value input, (batch, seq_k, d_model).
            mask: optional; entries equal to 0 are blocked. Assumed broadcastable
                to (batch, heads, seq_q, seq_k) — confirm against callers.

        Returns:
            Tensor of shape (batch, seq_q, d_model). The attention weights are kept
            on ``self.attention_scores`` for inspection.
        """
        query = self._split_heads(self.w_q(q))
        key = self._split_heads(self.w_k(k))
        value = self._split_heads(self.w_v(v))
        x, self.attention_scores = MultiHeadAttention.attention(query, key, value, mask, self.dropout)
        # (batch, heads, seq_q, d_k) -> (batch, seq_q, heads * d_k): concatenate the heads.
        x = x.transpose(1, 2).contiguous().view(x.shape[0], -1, self.heads * self.d_k)
        return self.w_o(x)


In [28]:
"""
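Residual connections carry important information across layers by adding each sublayer's input back to its output.
Without them, training deep stacks is harder because gradients struggle to propagate back to the earlier layers
(the vanishing gradient problem).

Explanation of the vanishing gradient problem and residual layers:
https://medium.com/analytics-vidhya/understanding-and-implementation-of-residual-networks-resnets-b80f9a507b9c

See also Section 3.1. Note: the paper normalizes after the residual addition, LayerNorm(x + Sublayer(x)).
The class below is the "pre-norm" variant, x + Dropout(Sublayer(LayerNorm(x))).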
"""

class ResidualConnection(nn.Module):
    """Pre-norm residual wrapper around a sublayer.

    Returns x + Dropout(layer(LayerNorm(x))): the input is normalized before it is
    handed to the sublayer, and the un-normalized input is added back afterwards.

    Args:
        features: feature size passed to the internal LayerNormalization.
        dropout: dropout probability applied to the sublayer output.
    """

    def __init__(self, features: int, dropout: float):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.norm = LayerNormalization(features)

    def forward(self, x, layer):
        # `layer` is any callable taking and returning a tensor shaped like x.
        normalized = self.norm(x)
        sublayer_out = layer(normalized)
        return x + self.dropout(sublayer_out)

In [29]:
"""
Section 3.3 Feed-Forward Network Implementation
"""

class FeedForward(nn.Module):
    """Position-wise feed-forward network (Section 3.3).

    FFN(x) = max(0, x W1 + b1) W2 + b2, applied identically at every position,
    with dropout on the hidden activations.

    Args:
        d_model: input and output feature size.
        d_ff: hidden (inner) layer size; the inner dimensions of the two linears must match.
        dropout: dropout probability applied after the ReLU.
    """

    def __init__(self, d_model: int, d_ff: int, dropout: float):
        super().__init__()
        self.first_linear = nn.Linear(d_model, d_ff)
        self.dropout = nn.Dropout(dropout)
        self.second_linear = nn.Linear(d_ff, d_model)

    def forward(self, x):
        # The ReLU (the max(0, .) in the formula) is what makes this block non-linear;
        # without it the two linear layers collapse into a single linear map.
        hidden = torch.relu(self.first_linear(x))
        return self.second_linear(self.dropout(hidden))

In [30]:
"""
Architecture here is derived from Figure 1. 
"""

class Encoder(nn.Module):
    """Apply a stack of encoder layers in order, then a final LayerNormalization.

    Args:
        features: feature size for the final normalization.
        layers: encoder blocks, each called as layer(x, mask).
    """

    def __init__(self, features: int, layers: nn.ModuleList):
        super().__init__()
        self.features = features
        self.layers = layers
        self.norm = LayerNormalization(features)

    def forward(self, x, mask):
        hidden = x
        for encoder_layer in self.layers:
            hidden = encoder_layer(hidden, mask)
        return self.norm(hidden)

In [31]:
class EncoderBlock(nn.Module):
    """One encoder layer (Figure 1, left): self-attention, then feed-forward.

    Each sublayer is wrapped in its own ResidualConnection.

    Args:
        features: feature size for the residual connections' normalization.
        self_attention: attention module, called as self_attention(q, k, v, mask).
        feed_forward: position-wise feed-forward module.
        dropout: dropout probability for the residual connections.
    """

    def __init__(self, features: int, self_attention: MultiHeadAttention, feed_forward: FeedForward, dropout: float):
        super().__init__()
        self.self_attention = self_attention
        self.feed_forward = feed_forward
        self.residual_connections = nn.ModuleList([ResidualConnection(features, dropout) for _ in range(2)])

    def forward(self, x, encoder_masks):
        # Bottom residual sublayer of Figure 1: queries, keys and values all come from
        # the (normalized) input. The lambda defers the call so the residual wrapper
        # can pass in its normalized tensor `t`.
        x = self.residual_connections[0](x, lambda t: self.self_attention(t, t, t, encoder_masks))
        # Top residual sublayer of Figure 1: the feed-forward network.
        x = self.residual_connections[1](x, self.feed_forward)
        return x

In [32]:
class Decoder(nn.Module):
    """Apply a stack of decoder layers in order, then a final LayerNormalization.

    Args:
        features: feature size for the final normalization.
        layers: decoder blocks, each called as
            layer(x, encoder_output, encoder_masks, decoder_masks).
    """

    def __init__(self, features: int, layers: nn.ModuleList):
        super().__init__()
        self.features = features
        self.layers = layers
        self.norm = LayerNormalization(features)

    def forward(self, x, encoder_output, encoder_masks, decoder_masks):
        """Run the target representation x through every decoder block.

        The encoder output and both masks are forwarded unchanged to every block,
        matching the DecoderBlock.forward signature.
        """
        for layer in self.layers:
            x = layer(x, encoder_output, encoder_masks, decoder_masks)
        return self.norm(x)
    
class DecoderBlock(nn.Module):
    """One decoder layer (Figure 1, right).

    The layer runs three sublayers in order, each wrapped in its own
    ResidualConnection:
    masked self-attention over the target, cross-attention from the target to the
    encoder output, and the position-wise feed-forward network.

    Args:
        features: feature size for the residual connections' normalization.
        masked_self_attention: attention over the target sequence.
        cross_attention: attention whose keys/values are the encoder output.
        feed_forward: position-wise feed-forward module.
        dropout: dropout probability for the residual connections.
    """

    def __init__(self, features: int, masked_self_attention: MultiHeadAttention, cross_attention: MultiHeadAttention, feed_forward: FeedForward, dropout: float):
        super().__init__()
        self.masked_self_attention = masked_self_attention
        self.cross_attention = cross_attention
        self.feed_forward = feed_forward
        # Figure 1 shows three residual sublayers per decoder layer.
        self.residual_connections = nn.ModuleList(
            ResidualConnection(features, dropout) for _ in range(3)
        )

    def forward(self, x, encoder_output, encoder_masks, decoder_masks):
        # decoder_masks presumably blocks future/padded target positions — confirm with caller.
        self_attn_residual, cross_attn_residual, ff_residual = self.residual_connections
        x = self_attn_residual(x, lambda t: self.masked_self_attention(t, t, t, decoder_masks))
        # Queries come from the decoder; keys and values come from the encoder output.
        x = cross_attn_residual(x, lambda t: self.cross_attention(t, encoder_output, encoder_output, encoder_masks))
        return ff_residual(x, self.feed_forward)


In [33]:
class OutputLayer(nn.Module):
    """Final linear projection to vocabulary size followed by a softmax.

    Maps decoder features (..., d_model) to next-token probabilities
    (..., vocab_size). If training with nn.CrossEntropyLoss, which expects raw
    logits, use ``self.projection(x)`` directly instead of this softmax output.

    Args:
        d_model: feature size of the decoder output.
        vocab_size: number of target-vocabulary entries.
    """

    def __init__(self,  d_model, vocab_size):
        super().__init__()
        # The Linear must be stored as a module; softmax is applied to its output in forward.
        self.projection = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        return torch.softmax(self.projection(x), dim=-1)

In [34]:
class Transformer(nn.Module):
    """Encoder-decoder Transformer assembled from pre-built components (Figure 1).

    Exposes the three stages separately (encode, decode, project) rather than a
    single forward pass.
    """

    def __init__(self, encoder: Encoder, decoder: Decoder, source_embedding: InputEmbeddings, target_embeddings: InputEmbeddings, source_positional: PositionalEmbeddings, target_positional: PositionalEmbeddings, output_layer: OutputLayer):
        super().__init__()
        # Assignment order fixes the order of parameters(), which assemble() walks
        # when initializing weights, so it is kept unchanged.
        self.encoder = encoder
        self.decoder = decoder
        self.source_embedding = source_embedding
        self.target_embedding = target_embeddings
        self.source_positional = source_positional
        self.target_positional = target_positional
        self.output_layer = output_layer

    def encode(self, source, source_mask):
        """Embed and position-encode source token ids, then run the encoder."""
        embedded = self.source_embedding(source)
        positioned = self.source_positional(embedded)
        return self.encoder(positioned, source_mask)

    def decode(self, encoder_output: torch.Tensor, source_mask: torch.Tensor, target: torch.Tensor, target_mask: torch.Tensor):
        """Embed and position-encode target token ids, then run the decoder.

        Expected output shape: (batch, seq_len, d_model).
        """
        embedded = self.target_embedding(target)
        positioned = self.target_positional(embedded)
        return self.decoder(positioned, encoder_output, source_mask, target_mask)

    def project(self, x):
        """Map decoder features to the output layer's result, expected (batch, seq_len, vocab_size)."""
        projected = self.output_layer(x)
        return projected
        

In [None]:
def assemble(source_vocab_size: int, target_vocab_size: int, source_sentence_len: int, target_sentence_len: int, d_model: int = 512, N: int = 6, h: int = 8, dropout: float = 0.1, d_ff: int = 2048):
    """Build a full encoder-decoder Transformer with Xavier-initialized weights.

    Args:
        source_vocab_size / target_vocab_size: embedding-table sizes.
        source_sentence_len / target_sentence_len: maximum sequence lengths for the positional encodings.
        d_model: model dimension.
        N: number of encoder blocks and of decoder blocks.
        h: number of attention heads.
        dropout: dropout probability used throughout.
        d_ff: inner size of the feed-forward networks.

    Returns:
        A Transformer instance.
    """
    source_embedding = InputEmbeddings(d_model, source_vocab_size)
    target_embedding = InputEmbeddings(d_model, target_vocab_size)
    source_positional = PositionalEmbeddings(d_model, source_sentence_len, dropout)
    target_positional = PositionalEmbeddings(d_model, target_sentence_len, dropout)

    encoders = []
    for _ in range(N):
        self_attention_block = MultiHeadAttention(d_model, h, dropout)
        feed_forward = FeedForward(d_model, d_ff, dropout)
        encoders.append(EncoderBlock(d_model, self_attention_block, feed_forward, dropout))

    decoders = []
    for _ in range(N):
        self_attention_block = MultiHeadAttention(d_model, h, dropout)
        cross_attention_block = MultiHeadAttention(d_model, h, dropout)
        feed_forward = FeedForward(d_model, d_ff, dropout)
        # Decoder layers need both the masked self-attention and the cross-attention.
        decoders.append(DecoderBlock(d_model, self_attention_block, cross_attention_block, feed_forward, dropout))

    encoder = Encoder(d_model, nn.ModuleList(encoders))
    decoder = Decoder(d_model, nn.ModuleList(decoders))

    output_layer = OutputLayer(d_model, target_vocab_size)
    transformer = Transformer(encoder, decoder, source_embedding, target_embedding, source_positional, target_positional, output_layer)
    # Xavier/Glorot-uniform init for every parameter with more than one dimension
    # (embedding tables and linear weight matrices); 1-D biases and LayerNorm gains keep their defaults.
    for p in transformer.parameters():
        if p.dim() > 1:
            nn.init.xavier_uniform_(p)
    return transformer

### Resources

https://www.youtube.com/watch?v=ISNdQcPhsts
https://arxiv.org/pdf/1706.03762