In [1]:
import torch
import torch.nn as nn
from torch.autograd import Variable

def get_device():
    if torch.cuda.is_available():
        return torch.device('cuda')
    elif torch.backends.mps.is_available():
        return torch.device('mps')
    else:
        return torch.device('cpu')
    
import unittest

def make_src_mask(src,src_pad_idx,device=None):
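    """
    Creates a padding mask for the source input sequence.
    
    Args:
        src (torch.Tensor): The source token indices, shape (N, src_len).
        src_pad_idx (int): The index of the padding token.
        device (torch.device, optional): Device for the mask. Defaults to get_device().
    
    Returns:
        src_mask (torch.Tensor): Boolean mask of shape (N, 1, 1, src_len); True for real
            tokens, False for padding positions.
    """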
    if device is None:
        device = get_device()
    src_mask = (src != src_pad_idx).unsqueeze(1).unsqueeze(2).to(device)
    return src_mask

def make_trg_mask(trg,device=None):
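    """
    Creates a causal (look-ahead) mask for the target input sequence.
    
    Args:
        trg (torch.Tensor): The target token indices, shape (N, trg_len).
        device (torch.device, optional): Device for the mask. Defaults to get_device().
    
    Returns:
        trg_mask (torch.Tensor): Lower-triangular mask of shape (N, 1, trg_len, trg_len);
            1 where position i may attend to position j (j <= i), 0 otherwise.
    """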
    if device is None:
        device = get_device()
    N, trg_len = trg.shape
    trg_mask = torch.tril(torch.ones((trg_len, trg_len))).expand(
        N, 1, trg_len, trg_len
    ).to(device)
    
    return trg_mask

# Transformer from scratch

Writing a Transformer from scratch is an educational journey into the inner workings of one of the most influential models in the field of Natural Language Processing. Here's why focusing on this process can be so beneficial:

**Deep Understanding**: Utilizing pre-trained models or libraries can often lead to a surface-level understanding of the underlying mechanisms. On the other hand, creating a Transformer model from the ground up facilitates a deep, granular comprehension of each part of this complex architecture. You'll intimately understand the self-attention mechanism, positional encoding, layer normalization, and the importance of masks, among other concepts, as these are all pieces you'll have to create and interconnect.

**Educational Value**: Implementing a Transformer from scratch is akin to a hands-on masterclass in advanced deep learning. It's an effective method of learning that goes beyond theory, encouraging you to apply knowledge and solve problems as they arise. It enables you to grapple with the practical aspects of model development, including debugging, ensuring computational efficiency, and managing resources—a real-life skill set often omitted in theoretical coursework.

**Deeper Insights into NLP**: Transformers form the backbone of modern NLP, powering applications from translation to text generation. By building a Transformer yourself, you'll gain insights into why they're so effective for handling language data and how different components contribute to their exceptional performance. This can be a stepping stone to further explorations and innovations in NLP.

**Insight into Generative AI**: Transformers have revolutionized not just NLP, but the broader domain of generative AI as well. They lie at the core of cutting-edge models like GPT-3 that can generate human-like text, and have also found applications in areas like music and image generation. By constructing a Transformer from scratch, you'll gain a foundational understanding of the principles that drive these powerful generative models. This equips you to participate in and contribute to the ongoing advancements in this rapidly evolving field, whether by optimizing existing generative models or pioneering new approaches.

In conclusion, while using pre-built models has its advantages, nothing quite matches the depth of understanding and learning achieved by building a Transformer model from scratch. It's an endeavor that requires dedication and effort, but the resultant knowledge and skills gained make it a worthwhile investment for any serious learner or practitioner in the field of NLP.

![](imgs/attention-enc-decoder.png)

I decided to embark on this journey of building a Transformer from scratch, inspired by the significant advancements and prolific literature in the field. There is a rich body of scholarly articles, tutorials, and code repositories that break down the complexity of the Transformer architecture, serving as invaluable guides and references. I'm particularly grateful to those brave pioneers who ventured into uncharted territory and shared their experiences and insights. Their work has illuminated the path for others and has been instrumental in making these complex concepts more accessible to the broader AI community. Through this endeavor, I hope to deepen my own understanding of Transformer models, while also contributing to the growing body of knowledge on this subject.


## Getting Started with Transformers

I am familiar with [PyTorch](https://pytorch.org/) and will be using it to build the Transformer, but you can use any framework you prefer. As a developer, I find it helpful to have a clear plan of action before starting a project: it keeps me focused and motivated, and ensures that I don't get lost in the details. The famous paper [Attention Is All You Need](https://arxiv.org/abs/1706.03762) is a must-read. I will follow the paper's notation and variable names as closely as possible, so it's helpful to understand its terminology and to keep the paper open as a reference while going through this Jupyter Notebook. Here's the high-level plan I've laid out for this project:

* I will start by building the core components of the Transformer architecture, which I will combine to create the complete Transformer model. At a high level, I have decided to build the following components:

    * Multi-head attention
    * Transformer Block (which is used in Encoder and Decoder)
    * Encoder 
    * Decoder block which uses the Transformer Block and Multi-head attention
    * Decoder 
    * And finally Transformer which uses the Encoder and Decoder


* Every component mentioned above will:
  * Be defined as a [`nn.Module`](https://pytorch.org/docs/stable/generated/torch.nn.Module.html) subclass. This lets me leverage PyTorch's built-in capabilities, including tracking and updating the model parameters, computing gradients, and more.
  * Have a corresponding unit test to ensure that it works as expected. I will also use the unit tests to debug the code and fix any bugs that arise.
  * Have a [`mermaid`](https://mermaid.js.org/intro/) diagram to visualize the flow of data. I will use the diagrams to check that tensor dimensions match at each step, and as a visual aid for understanding the model architecture.


![](imgs/TrransformerFromScratch.png)

Without further ado, let's begin!


## But before we start, how does the data flow in a Transformer?

Assume you have a trained Transformer model that can translate English sentences to French. Here's how the data flows through the model during prediction:

1. **Source Token Sequence**: This is your input sentence in English, tokenized into a sequence of token indices (integers that the embedding layer later maps to word vectors). The model reads this sequence to understand the context of the input sentence. Assuming N=1 (N is the batch size, i.e. the number of sentences in the input), the shape of the tensor is (1, SeqLength), where SeqLength is the number of tokens in the English sentence.

1. **Source Mask**: This is applied to the source sequence. The purpose of the source mask is to prevent the attention mechanism from focusing on irrelevant parts of the input, such as padding tokens. In a batch of sequences, not all sequences will be of the same length. To make them uniform in length, padding tokens are added to the shorter sequences. These padded positions don't carry any meaningful information, and it's important to prevent the self-attention mechanism from considering these positions. A padding mask is used for this purpose. It is typically a tensor with values that are either 0 (for padding positions) or 1 (for actual data positions).

1. **Encoder**: The encoder processes the source token sequence and the source mask. Since we are in inference mode, the model isn't being trained, so the weights remain the same. The encoder's role is to create a rich representation of the source sentence that captures the semantic information of the input.

1. **Target Token Sequence**: At the start of the decoding process, this sequence contains only a start-of-sequence token. The goal of the model is to generate the rest of the sequence, token by token, which will form our translated sentence in French.

1. **Target Mask**: This mask ensures that the prediction for each position can depend only on known outputs at earlier positions, not on future positions, maintaining the auto-regressive property.

1. **Decoder**: The decoder processes the output from the encoder, the current target token sequence, and the target and source masks. It then generates a probability distribution for the next token (the next word in the French translation). The token with the highest probability is selected and appended to the target sequence. This updated target sequence is then fed back into the decoder, and the process repeats until an end-of-sequence token is produced or the sequence reaches a predefined maximum length.

This way, during inference/prediction mode, our English-to-French Transformer model generates a French translation of the input English sentence, one word at a time, based on both the original English sentence and the French words it has generated so far.
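To make the source mask (step 2) and target mask (step 5) concrete, here is a small sketch for a padded batch. The token indices and the choice of 0 as the padding index are assumptions for illustration; the logic mirrors `make_src_mask` and `make_trg_mask` defined at the top of this notebook, but everything stays on the CPU.

```python
import torch

# Hypothetical padded batch of 2 source sentences; 0 is assumed to be the padding index
src_pad_idx = 0
src = torch.tensor([[5, 7, 9, 0, 0],
                    [3, 4, 6, 8, 2]])  # shape (N=2, src_len=5)

# Source (padding) mask: True for real tokens, False for padding -> shape (N, 1, 1, src_len)
src_mask = (src != src_pad_idx).unsqueeze(1).unsqueeze(2)
print(src_mask.shape)     # torch.Size([2, 1, 1, 5])
print(src_mask[0, 0, 0])  # first sentence: the last two positions are padding

# Target (causal) mask for a target of length 4: row i may attend to positions 0..i only
trg_len = 4
trg_mask = torch.tril(torch.ones(trg_len, trg_len)).expand(src.shape[0], 1, trg_len, trg_len)
print(trg_mask[0, 0])     # lower-triangular matrix of ones
```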


```mermaid
graph TB
SRC[Source Token Sequence] --shape=N,SrcLen-->  ENC[Encoder]
SRCM[Source Mask] --> ENC
ENC --shape=N,SrcLen,E--> DEC[Decoder]
TRG[Target Token Sequence] --shape=N,TrgLen-->  DEC
SRCM --> DEC
TRGM[Target Mask] --> DEC
DEC --shape=N,TrgLen,E--> O[Output]
```
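The generation loop described in the Decoder step (pick the most likely token, append it, feed the sequence back in) is greedy decoding. Below is a minimal sketch of that loop. `greedy_decode`, `toy_step`, and the token indices are hypothetical names and values for illustration; `toy_step` stands in for "encoder output + decoder" and simply predicts the previous token plus one, so the loop can run without a trained model.

```python
import torch

VOCAB_SIZE = 6  # toy vocabulary size (assumption)

def toy_step(trg, trg_mask):
    """Hypothetical stand-in for encoder output + decoder: always predicts previous token + 1.

    trg: (1, T) token indices; trg_mask: (1, 1, T, T). Returns logits of shape (1, T, VOCAB_SIZE).
    A real decoder would use trg_mask; this toy ignores it.
    """
    logits = torch.zeros(trg.shape[0], trg.shape[1], VOCAB_SIZE)
    next_ids = (trg + 1) % VOCAB_SIZE
    logits.scatter_(2, next_ids.unsqueeze(-1), 1.0)  # put a 1 at the "predicted" token
    return logits

def greedy_decode(step_fn, sos_idx, eos_idx, max_len):
    """Greedy decoding: repeatedly append the most likely next token."""
    trg = torch.tensor([[sos_idx]])  # start with only the start-of-sequence token
    for _ in range(max_len - 1):
        T = trg.shape[1]
        trg_mask = torch.tril(torch.ones(T, T)).expand(1, 1, T, T)  # causal mask for current length
        logits = step_fn(trg, trg_mask)                              # (1, T, vocab)
        next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)   # (1, 1): prediction at last position
        trg = torch.cat([trg, next_token], dim=1)
        if next_token.item() == eos_idx:                             # stop at end-of-sequence
            break
    return trg

print(greedy_decode(toy_step, sos_idx=1, eos_idx=5, max_len=10))  # tensor([[1, 2, 3, 4, 5]])
```

Greedy decoding is the simplest strategy; beam search, which keeps several candidate sequences at each step, is a common alternative.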

## Multi-Head Self-Attention Block

In machine learning, particularly in models such as Transformers, an attention function is a crucial component that enhances the model's ability to focus on specific parts of the input data that are relevant for a given task.

Here's an expanded explanation of the key concepts:

**Query, Keys, and Values**: These are terms used in the attention mechanism. The 'query' is a vector that represents the current item or context we're focusing on. 'Keys' are vectors that represent the different items or contexts available in the input data. 'Values' are vectors that represent the actual content or information associated with each key. The query and keys always have the same dimension, but the values can have the same or a different dimension.

**Compatibility Function**: The weights assigned to each value in the weighted sum are determined by a 'compatibility function'. This function takes the query and a key as inputs and outputs a score that represents how well they match. In the Transformer, the compatibility function is the dot product of the query and the key, scaled by $1/\sqrt{d_k}$; the scores for all keys are then passed through a softmax, so the weights are positive and sum to 1. In other words, if the query matches well with a certain key (i.e., the compatibility function gives a high score), the corresponding value gets a high weight and therefore has a larger influence on the output of the attention function. This is how the attention mechanism allows the model to focus on the most relevant parts of the input data.

**Attention Function**: This is essentially a mapping procedure. It maps a query (a representation of the current context or focus) and a set of key-value pairs (representations of the possible inputs and their corresponding outputs or 'values') to an output. 

**Weighted Sum of Values**: The output of the attention function is calculated as a weighted sum of the values. This means each value vector is multiplied by a certain weight, and then all of these weighted values are added together to produce the output. The idea is to give more 'attention' (i.e., higher weight) to the values that are more relevant or useful for the current task.

$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$
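The formula translates almost line-for-line into PyTorch. The sketch below is a standalone version for illustration; the function name and the `-1e9` fill value for blocked positions are choices made here, not taken from the implementation later in this notebook. The mask follows this notebook's convention: 1 means attend, 0 means block.

```python
import math
import torch

def scaled_dot_product_attention(Q, K, V, mask=None):
    """softmax(Q K^T / sqrt(d_k)) V.

    Q: (..., Sq, d_k), K: (..., Sk, d_k), V: (..., Sk, d_v).
    mask (optional): broadcastable to (..., Sq, Sk); 1 = attend, 0 = block.
    """
    d_k = Q.shape[-1]
    scores = Q @ K.transpose(-2, -1) / math.sqrt(d_k)  # (..., Sq, Sk) compatibility scores
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)    # blocked positions get ~zero weight
    weights = torch.softmax(scores, dim=-1)             # each row sums to 1
    return weights @ V, weights                         # (..., Sq, d_v), (..., Sq, Sk)

torch.manual_seed(0)
Q = torch.randn(2, 3, 4)  # batch 2, 3 queries, d_k = 4
K = torch.randn(2, 5, 4)  # 5 keys
V = torch.randn(2, 5, 6)  # 5 values, d_v = 6
out, weights = scaled_dot_product_attention(Q, K, V)
print(out.shape, weights.shape)  # torch.Size([2, 3, 6]) torch.Size([2, 3, 5])
```

Note that the output has one row per query (Sq) and the value dimension (d_v), while the number of keys (Sk) only appears in the attention weights.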


![](imgs/mha.png)


**Multi-Head Attention**: Multi-head attention runs several attention functions ("heads") in parallel, each on its own learned linear projection of the queries, keys, and values. This allows the model to jointly attend to information from different embedding/representation subspaces at different positions; with a single head, averaging the attention weights would blur these signals together. Within one layer, each head can therefore focus on a different kind of relationship between tokens, and the heads' outputs are concatenated and projected back to the model dimension.

$$\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, ..., \text{head}_h)W_O$$
$$\text{where} \quad \text{head}_i = \text{Attention}(QW_{Qi}, KW_{Ki}, VW_{Vi})$$

The projections are defined by parameter matrices: 

- $W_{Qi} \in \mathbb{R}^{E_{\text{model}} \times E_q}$ 
- $W_{Ki} \in \mathbb{R}^{E_{\text{model}} \times E_k}$ 
- $W_{Vi} \in \mathbb{R}^{E_{\text{model}} \times E_v}$ 
- $W_O \in \mathbb{R}^{h E_v \times E_{\text{model}}}$

In our implementation, we use $h=8$ parallel attention layers, also known as heads. With $E_{\text{model}} = 512$ (the value used in the paper and in this notebook's unit tests), each head uses

$E_k = E_q = E_v = E_{\text{model}} / h = 64$.
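A quick check of this arithmetic and of how the heads are split. One $E_{\text{model}} \times E_{\text{model}}$ linear layer followed by a reshape is equivalent to $h$ separate $E_{\text{model}} \times 64$ projections, because head $i$ simply receives a contiguous 64-wide slice of the projected features. The batch size and sequence length below are illustrative.

```python
import torch

E_model, h = 512, 8
d_head = E_model // h           # E_k = E_q = E_v = 512 / 8 = 64
N, S = 2, 10                    # toy batch size and sequence length
x = torch.randn(N, S, E_model)  # stands in for the output of one E_model x E_model projection

# Split: (N, S, E_model) -> (N, S, h, d_head) -> (N, h, S, d_head)
heads = x.reshape(N, S, h, d_head).permute(0, 2, 1, 3)
print(d_head, heads.shape)      # 64 torch.Size([2, 8, 10, 64])

# Concat(head_1, ..., head_h): undo the permute and flatten the head dimension again
merged = heads.permute(0, 2, 1, 3).reshape(N, S, h * d_head)
```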


#### Input
1. **Values(N,Sv,E)**, **Keys(N,Sk,E)**, **Queries(N,Sq,E)**: 
    three input tensors with the same embedding size "E", each passed through its own linear projection. Keys and Values must have the same sequence length (Sk = Sv), because each key is paired with exactly one value; Queries can have a different sequence length "Sq" (for example, in the decoder's encoder-decoder attention the queries come from the decoder while the keys and values come from the encoder output).
    * **N**: Batch Size
    * **Sv**, **Sk**, **Sq**: Sequence Lengths of Values, Keys, Queries
    * **E**: Embedding Size
2. **Mask** (optional), one of two types:
    * **Target Mask**: Prevents the model from attending to future tokens. It is a tensor of shape (N, 1, Sq, Sq) where future positions are set to 0 and the rest to 1.
    * **Source Mask**: Prevents the model from attending to padding tokens. It is a tensor of shape (N, 1, 1, Sk) where padding positions are set to 0 and the rest to 1.
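Both mask shapes rely on broadcasting against the attention scores, which have shape (N, H, Sq, Sk). The sketch below shows this with `masked_fill`, using the same `-1e20` fill value as the implementation in this notebook; the sizes and mask contents are made up for illustration.

```python
import torch

N, H, Sq, Sk = 2, 8, 3, 5  # toy sizes
torch.manual_seed(0)

# Cross-attention scores: queries from one sequence (Sq), keys from another (Sk)
scores = torch.randn(N, H, Sq, Sk)

# Source (padding) mask, shape (N, 1, 1, Sk): example 0 has padding in its last two positions
src_mask = torch.ones(N, 1, 1, Sk)
src_mask[0, 0, 0, 3:] = 0
# The mask broadcasts over heads (dim 1) and queries (dim 2)
weights = torch.softmax(scores.masked_fill(src_mask == 0, float("-1e20")), dim=-1)

# Target (causal) mask, shape (N, 1, Sq, Sq): used in decoder self-attention, where Sk == Sq
causal_mask = torch.tril(torch.ones(Sq, Sq)).expand(N, 1, Sq, Sq)
self_scores = torch.randn(N, H, Sq, Sq)
causal_weights = torch.softmax(self_scores.masked_fill(causal_mask == 0, float("-1e20")), dim=-1)
```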

```mermaid
graph TB

V[Values] -- shape=N,Sv,E--> B[Linear]
K[Keys] --shape=N,Sk,E--> C[Linear]
Q[Queries] --shape=N,Sq,E--> D[Linear]
B --Values=N,Sv,E--> E[Reshape Values to N,Sv,H,Hd]
C --Keys=N,Sk,E--> F[Reshape Keys to N,Sk,H,Hd]
D --Queries=N,Sq,E--> G[Reshape Queries to N,Sq,H,Hd]
F --Keys=N,Sk,H,Hd--> H[Scaled Dot Product]
G --Queries=N,Sq,H,Hd--> H
H --Scores=N,H,Sq,Sk--> U[Mask]
U --> I[Softmax]
I --Softmax=N,H,Sq,Sk--> J[Dot Product]
E --Values=N,Sv,H,Hd--> J
J --WeightedValues=N,Sq,H,Hd--> L[Reshape to N,Sq,E or concatenate the heads]
L --MatchingValues=N,Sq,E--> M[Linear]
M --MatchingValues=N,Sq,E--> N[Output]
```

In [2]:
class MultiHeadSelfAttentionBlock(nn.Module):
    """
    Initializes the MultiHeadSelfAttentionBlock module.

    Args:
        embedding_size (int): The size of the input embeddings. The value should be 
                            divisible by num_heads. This is because the embeddings 
                            are split into `num_heads` different pieces during the 
                            self-attention process.
        num_heads (int): The number of attention heads. In the multi-head attention 
                        mechanism, the model generates `num_heads` different attention 
                        scores for each token in the sequence. This allows the model 
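                        to focus on different parts of the sequence for each token.
        device (torch.device, optional): The device to which inputs are moved in `forward`.
                        Defaults to the device returned by `get_device()`.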

    Raises:
        AssertionError: If `embedding_size` is not divisible by `num_heads`. The embedding 
                        size needs to be divisible by the number of heads to ensure even 
                        division of embeddings for multi-head attention.
    """    
    def __init__(self, embedding_size, num_heads,device=None):
        super(MultiHeadSelfAttentionBlock, self).__init__()
        if device is None:
            self.device = get_device()
        else:
            self.device = device
        self.embedding_size = embedding_size
        self.num_heads = num_heads
        self.head_dim = embedding_size // num_heads

        assert (
            self.head_dim * num_heads == embedding_size
        ), "embedding_size  needs to be divisible by num_heads"

        # Define the linear transformations for the input data 
        self.values_transform = nn.Linear(embedding_size, embedding_size)
        self.keys_transform = nn.Linear(embedding_size, embedding_size)
        self.queries_transform = nn.Linear(embedding_size, embedding_size)
        # Define the final output linear transformation
        self.linear_out = nn.Linear(embedding_size, embedding_size)

    def forward(self, values, keys, queries, mask):
        """
        Forward pass for the MultiHeadSelfAttentionBlock module.

        Args:
            values (torch.Tensor): The values tensor of shape (N, value_len, embedding_size),
                where N is the batch size, value_len is the sequence length for the values.
            keys (torch.Tensor): The keys tensor of shape (N, key_len, embedding_size), 
                where key_len is the sequence length for the keys.
            queries (torch.Tensor): The query tensor of shape (N, query_len, embedding_size), 
                where query_len is the sequence length for the queries.
            mask (torch.Tensor, optional): A mask tensor broadcastable to
                (N, num_heads, query_len, key_len), typically (N, 1, 1, key_len) for a
                padding mask or (N, 1, query_len, key_len) for a causal mask. Values are
                either 1 (positions to attend to) or 0 (positions to mask). Pass None to
                disable masking.

        Returns:
            torch.Tensor: The output tensor of shape (N, query_len, embedding_size), 
                where N is the batch size, query_len is the sequence length for the queries,
                and embedding_size is the dimension of the output embeddings. 
                This tensor represents the result of applying self-attention on the input.

        Note:
            The method first transforms the input tensors (values, keys, query) using 
            separate linear transformations. Then, it splits the transformed embeddings 
            into multiple heads and computes the attention scores (queries * keys). If a mask 
            is provided, it is applied to the scores. The scores are then normalized to create 
            the attention weights. The method then computes the weighted sum of the values 
            (attention * values), applies a final linear transformation, and returns the result.
        """        
        # Get the number of training examples
        num_examples = queries.shape[0]

        # Calculate sequence lengths
        value_len, key_len, query_len = values.shape[1], keys.shape[1], queries.shape[1]
        
        values = values.to(self.device)
        keys = keys.to(self.device)
        queries = queries.to(self.device)
        mask = mask.to(self.device) if mask is not None else None

        # Transform the input data
        values = self.values_transform(values)
        keys = self.keys_transform(keys)
        queries = self.queries_transform(queries)

        # Split the embeddings into multiple heads
        values = values.reshape(num_examples, value_len, self.num_heads, self.head_dim)
        keys = keys.reshape(num_examples, key_len, self.num_heads, self.head_dim)
        queries = queries.reshape(num_examples, query_len, self.num_heads, self.head_dim)

        # Compute attention scores (queries * keys)
        # queries shape: (N, query_len, num_heads, head_dim)
        # keys shape: (N, key_len, num_heads, head_dim)
        # attention_scores shape: (N, num_heads, query_len, key_len)
        attention_scores = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])

        # Scale by sqrt(d_k), where d_k is the per-head dimension (head_dim), as in the paper
        attention_scores = attention_scores / (self.head_dim ** 0.5)

        # Apply the mask to the scores
        # mask shape: (N, 1, 1, key_len) for padding or (N, 1, query_len, key_len) for causal masking
        if mask is not None:
            attention_scores = attention_scores.masked_fill(mask == 0, float("-1e20"))

        # Normalize the scores over the key dimension
        # attention shape: (N, num_heads, query_len, key_len)
        attention = torch.softmax(attention_scores, dim=3)

        # Compute weighted values (attention * values)
        # attention shape: (N, num_heads, query_len, key_len)
        # values shape: (N, value_len, num_heads, heads_dim)
        # weighted_values shape: (N, query_len, num_heads, heads_dim)
        weighted_values = torch.einsum("nhql,nlhd->nqhd", [attention, values])
        weighted_values = weighted_values.reshape(num_examples, query_len, self.num_heads * self.head_dim)

        # Apply the final linear transformation
        out = self.linear_out(weighted_values)

        return out

In [3]:
# Test it
class TestSelfAttention(unittest.TestCase):

    def test_init(self):
        # Test that the object initializes correctly
        try:
            self_attention = MultiHeadSelfAttentionBlock(embedding_size=512, num_heads=8)
        except Exception as e:
            self.fail(f"Initialization of SelfAttention failed with {e}")

    def test_forward_pass_small(self):
        # Create a SelfAttention object
        self_attention = MultiHeadSelfAttentionBlock(embedding_size=6, num_heads=2)
        self_attention.to(get_device())

        # Create dummy inputs for the forward pass
        # 1 example, 5 tokens, embedding size 6
        N, sequence_length = 1, 5  # Number of examples and sequence length
        values = torch.rand((N, sequence_length, 6))
        keys = torch.rand((N, sequence_length, 6))
        query = torch.rand((N, sequence_length, 6))
        mask = torch.ones((N, 1,1,sequence_length))

        # Test the forward pass
        try:
            output = self_attention(values, keys, query, mask)
        except Exception as e:
            self.fail(f"Forward pass of SelfAttention failed with {e}")


    def test_forward_pass(self):
        # Create a SelfAttention object
        self_attention = MultiHeadSelfAttentionBlock(embedding_size=512, num_heads=8)
        self_attention.to(get_device())

        # Create dummy inputs for the forward pass
        # 64 examples, 20 tokens, 512 embedding size
        N, sequence_length = 64, 20  # Number of examples and sequence length
        values = torch.rand((N, sequence_length, 512))
        keys = torch.rand((N, sequence_length, 512))
        query = torch.rand((N, sequence_length, 512))
        mask = torch.ones((N, 1,1,sequence_length))

        # Test the forward pass
        try:
            output = self_attention(values, keys, query, mask)
        except Exception as e:
            self.fail(f"Forward pass of SelfAttention failed with {e}")

    def test_output_shape(self):
        # Create a SelfAttention object
        self_attention = MultiHeadSelfAttentionBlock(embedding_size=512, num_heads=8)
        self_attention.to(get_device())

        # Create dummy inputs for the forward pass
        N, sequence_length = 64, 20  # Number of examples and sequence length
        values = torch.rand((N, sequence_length, 512))
        keys = torch.rand((N, sequence_length, 512))
        query = torch.rand((N, sequence_length, 512))
        mask = torch.ones((N, 1,1,sequence_length))

        # Perform the forward pass and check the output shape
        output = self_attention(values, keys, query, mask)
        self.assertEqual(output.shape, (N, sequence_length, 512),
                         "Output shape of SelfAttention forward pass is incorrect")
        
if __name__ == '__main__':
    unittest.main(argv=['first-arg-is-ignored'], exit=False)

....
----------------------------------------------------------------------
Ran 4 tests in 0.152s

OK


## Transformer Block

The forward pass in the Transformer Block is as follows:

1. **MultiHeadSelfAttentionBlock (MHA)**: `Queries (Q)`, `Keys (K)`, `Values (V)`, and a `Mask (M)` are inputs to the multi-head self-attention block. The shapes of Q, K, and V are (N, Sq, E), (N, Sk, E), and (N, Sv, E), respectively, where N is the batch size, Sq, Sk, and Sv are the sequence lengths of the queries, keys, and values (Sk = Sv always, and all three are equal in self-attention), and E is the embedding size. The mask has shape (N, 1, 1, Sk) for a source (padding) mask or (N, 1, Sq, Sq) for a target (causal) mask. This block computes a weighted sum of the 'Values' based on the similarity of 'Keys' and 'Queries'; positions where the mask is 0 receive effectively zero weight.

1. **LayerNorm + Dropout (LN1)**: The output from the MHA and the original Queries (residual connection) are added together and then passed through the first Layer Normalization and Dropout operation to stabilize the outputs and prevent overfitting. The output of this stage retains the shape (N, Sq, E).

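1. **FeedForwardBlock (FFN)**: The output of LN1 is passed through a position-wise Feed-Forward Network: two linear transformations with a ReLU activation in between, applied to each position independently. The first layer expands the embedding size from E to `forward_expansion` × E and the second projects it back to E, giving the block extra non-linear capacity to learn more complex representations.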

1. **LayerNorm + Dropout (LN2)**: The output from the FFN and the output of LN1 (again, a residual connection) are added together, then passed through the second Layer Normalization and Dropout operation. This further stabilizes the outputs and helps prevent overfitting.

1. **Output (O)**: The output of LN2 becomes the final output of this Transformer block. It has the same shape as the input queries (N, Sq, E) and can be used as input to the next Transformer block in the stack (if any), or can be used for further processing (like a linear transformation to obtain prediction scores for downstream tasks).

This high-level description covers one Transformer block, and in practice, several of these blocks are stacked to form the complete Transformer model. This stacking allows the model to learn more complex patterns and relationships in the data.
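To get a feel for the size of the feed-forward sublayer, here is a quick parameter count using the values from this notebook's unit tests (`embed_size=512`, `forward_expansion=4`). It uses the same `nn.Sequential` structure as the `TransformerBlock` class below: 512·2048 + 2048 (first layer) plus 2048·512 + 512 (second layer) gives 2,099,712 parameters per block.

```python
import torch
import torch.nn as nn

embed_size, forward_expansion = 512, 4  # the values used in this notebook's unit tests

# Same structure as the feed-forward network inside TransformerBlock
feed_forward = nn.Sequential(
    nn.Linear(embed_size, forward_expansion * embed_size),  # 512 -> 2048
    nn.ReLU(),
    nn.Linear(forward_expansion * embed_size, embed_size),  # 2048 -> 512
)

n_params = sum(p.numel() for p in feed_forward.parameters())
print(n_params)  # 2099712

x = torch.randn(2, 10, embed_size)  # (N, Sq, E)
print(feed_forward(x).shape)        # torch.Size([2, 10, 512]): applied position-wise, shape preserved
```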

```mermaid
graph TB
Q[Queries] --shape=N,Sq,E--> MHA
V[Values] -- shape=N,Sv,E--> MHA[MultiHeadSelfAttentionBlock]
K[Keys] --shape=N,Sk,E--> MHA
M[Mask] --shape=N,1,1,Sk-or-N,1,Sq,Sq--> MHA
Q --Queries=N,Sq,E--> LN1[LayerNorm + Dropout]
MHA --Queries=N,Sq,E--> LN1
LN1 --Queries=N,Sq,E--> FFN[FeedForwardBlock with forward expansion]
FFN --Queries=N,Sq,E--> LN2[LayerNorm + Dropout]
LN1 --Queries=N,Sq,E--> LN2
LN2 --Queries=N,Sq,E--> O[Output]
```

In [4]:
class TransformerBlock(nn.Module):
    """
    A Transformer Block class that defines a single block in a transformer model.

    Args:
        embed_size (int): The dimensionality of the input embeddings.
        num_heads (int): The number of attention heads for the self-attention mechanism.
        dropout_rate (float): The dropout rate used in the dropout layers to prevent overfitting.
        forward_expansion (int): The factor by which the dimensionality of the input is 
            expanded in the feed-forward network. The feed-forward network expands the 
            dimensionality of the input from `embed_size` to `forward_expansion * embed_size`
            and then reduces it back to `embed_size`.

    Attributes:
        self_attention (MultiHeadSelfAttentionBlock): The multi-head self-attention layer used in the transformer block.
        norm1 (nn.LayerNorm): The first layer normalization used to stabilize the outputs of 
            the self-attention layer.
        norm2 (nn.LayerNorm): The second layer normalization used to stabilize the outputs 
            of the feed-forward network.
        feed_forward (nn.Sequential): The feed-forward network used to transform the output 
            of the self-attention layer.
        dropout (nn.Dropout): The dropout layer used to prevent overfitting.

    The Transformer Block consists of a self-attention layer and a feed-forward network, 
    each followed by a residual connection, layer normalization, and dropout. 
    """    
    def __init__(self, embed_size, num_heads, dropout_rate, forward_expansion, device=None):
        super(TransformerBlock, self).__init__()
        if device is None:
            self.device = get_device()
        else:
            self.device = device
        # Transformer block consists of a self-attention layer and a feed-forward neural network
        self.self_attention = MultiHeadSelfAttentionBlock(embed_size, num_heads,device=device)
        
        # Normalization layers are used to stabilize the outputs of self-attention and feed-forward network
        self.norm1 = nn.LayerNorm(embed_size)
        self.norm2 = nn.LayerNorm(embed_size)

        # Feed-forward neural network - it's used to transform the output of the self-attention layer
        self.feed_forward = nn.Sequential(
            nn.Linear(embed_size, forward_expansion * embed_size),
            nn.ReLU(),
            nn.Linear(forward_expansion * embed_size, embed_size),
        )

        # Dropout is used to reduce overfitting
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, values, keys, queries, mask):
        """
        Forward pass of the Transformer Block.

        Args:
            values (torch.Tensor): The values used by the self-attention layer. 
                They have shape (N, value_len, embed_size) where N is the batch size, 
                value_len is the length of the value sequence, and embed_size is the size of the embeddings.
            keys (torch.Tensor): The keys used by the self-attention layer. 
                They have shape (N, key_len, embed_size) where N is the batch size, 
                key_len is the length of the key sequence, and embed_size is the size of the embeddings.
            queries (torch.Tensor): The queries used by the self-attention layer. 
                They have shape (N, query_len, embed_size) where N is the batch size, 
                query_len is the length of the query sequence, and embed_size is the size of the embeddings.
            mask (torch.Tensor): The mask to be applied on the attention scores to prevent the model 
                from attending to certain positions. Typically (N, 1, 1, key_len) for a padding mask 
                or (N, 1, query_len, key_len) for a causal mask; None disables masking.

        Returns:
            out (torch.Tensor): The output tensor from the transformer block, it has shape 
                (N, query_len, embed_size), where N is the batch size, query_len is the length 
                of the query sequence, and embed_size is the size of the embeddings.

        The forward method first applies the self-attention mechanism using the provided
        values, keys, and queries. The attention output is added to the queries (residual
        connection) and passed through a normalization layer and a dropout layer. The result
        is passed through the feed-forward network, whose output is again added to its input
        (residual connection), normalized, and passed through dropout. The final output is returned.
        """
        values = values.to(self.device)
        keys = keys.to(self.device)
        queries = queries.to(self.device)
        mask = mask.to(self.device) if mask is not None else None
        
        # Self-attention layer takes in values, keys and queries and returns an output tensor
        attention_output = self.self_attention(values, keys, queries, mask)

        # Add residual connection (skip connection), normalize and apply dropout
        # The normalization is applied on the sum of the original input `queries` and the output of the self-attention layer
        x = self.dropout(self.norm1(attention_output + queries))

        # Pass the output from the attention layer through the feed-forward network
        ff_output = self.feed_forward(x)
        
        # Add another residual connection, normalize and apply dropout
        # The normalization is applied on the sum of the previous output `x` and the output of the feed-forward network
        out = self.dropout(self.norm2(ff_output + x))

        return out

In [5]:
class TestTransformerBlock(unittest.TestCase):
    def setUp(self):
        self.embed_size = 512
        self.num_heads = 8
        self.dropout_rate = 0.1
        self.forward_expansion = 4
        self.transformer_block = TransformerBlock(
            self.embed_size, self.num_heads, self.dropout_rate, self.forward_expansion
        )
        self.transformer_block.to(get_device())

    def test_forward_pass(self):
        N = 64  # batch size
        seq_len = 10  # sequence length

        # Plain tensors suffice; Variable has been a no-op wrapper since PyTorch 0.4
        values = torch.rand(N, seq_len, self.embed_size)
        keys = torch.rand(N, seq_len, self.embed_size)
        queries = torch.rand(N, seq_len, self.embed_size)
        mask = torch.ones(N, 1, 1, seq_len)

        out = self.transformer_block(values, keys, queries, mask)
        
        # Assert that the output has the correct type and shape
        self.assertIsInstance(out, torch.Tensor)
        self.assertTupleEqual(out.shape, (N, seq_len, self.embed_size))

if __name__ == '__main__':
    unittest.main(argv=['first-arg-is-ignored'], exit=False)

.....
----------------------------------------------------------------------
Ran 5 tests in 0.149s

OK


## Encoder

The forward pass in the Encoder is as follows:

1. **Tokens Sequence**: The process begins with a batch of token sequences of shape (N, SequenceLength), where N is the batch size and SequenceLength is the length of each sequence in the batch.

1. **Word Embeddings Mapping**: The tokens are then passed through an embedding layer that maps each token to a high-dimensional vector. This results in a tensor with shape (N, SequenceLength, E), where E is the dimensionality of the embeddings.

1. **Positional Encoding**: Self-attention on its own is insensitive to token order, so information about each token's position is added to its embedding. In this implementation the positional encoding is a learned embedding table, `nn.Embedding(max_length, embed_size)`, indexed by position and summed with the word embeddings, so the shape stays (N, SequenceLength, E). Sequences longer than `max_length` therefore cannot be encoded.

1. **Dropout**: To prevent overfitting and improve generalization, a dropout layer is applied. During training, dropout randomly sets a fraction p of the input units to 0 and scales the remaining ones by 1/(1-p); at evaluation time it leaves its input unchanged. The shape remains (N, SequenceLength, E).

1. **Transformer Block**: The embeddings are then passed through a Transformer block, which consists of a self-attention mechanism and a feed-forward network, each wrapped in a residual connection followed by layer normalization and dropout. The encoder stacks `num_layers` such blocks, feeding the output of one block into the next.

1. **Output**: The output of the last Transformer block still has shape (N, SequenceLength, E). In the full encoder-decoder model, this output is passed to the decoder, where it serves as the keys and values of the encoder-decoder attention; for other tasks, such as sequence classification, it could instead feed a final linear layer.



```mermaid
graph TB
X[Tokens Sequence] --shape=N,SeqLength--> W[Word Embeddings Mapping]
W --shape=N,SeqLength,E--> PE[Positional Encoding]
PE --shape=N,SeqLength,E--> DO[Dropout]
DO --shape=N,SeqLength,E--> T[Transformer Block]
T --Number of Layers Times--> T
T --shape=N,SeqLength,E--> O[Output]
```
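The first four steps can be checked in isolation. The following sketch uses small, illustrative sizes (a vocabulary of 100 tokens, E = 16, `max_length` = 50, dropout p = 0.5; none of these come from the notebook) and mirrors the learned positional embedding used by the `Encoder` class, an `nn.Embedding` indexed by position. It prints the shape after each step and shows that `nn.Dropout` zeroes and rescales entries only in training mode.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Illustrative sizes (assumptions for this sketch only)
vocab_size, embed_size, max_length, p = 100, 16, 50, 0.5
N, seq_len = 4, 10

word_embedding = nn.Embedding(vocab_size, embed_size)
position_embedding = nn.Embedding(max_length, embed_size)  # learned positional encoding
dropout = nn.Dropout(p)

tokens = torch.randint(0, vocab_size, (N, seq_len))    # (N, SeqLength)
positions = torch.arange(seq_len).expand(N, seq_len)   # (N, SeqLength): 0, 1, ..., seq_len - 1
embeddings = word_embedding(tokens) + position_embedding(positions)  # (N, SeqLength, E)

dropout.train()
dropped = dropout(embeddings)    # same shape; roughly a fraction p of entries set to 0,
                                 # the survivors scaled by 1 / (1 - p)
dropout.eval()
unchanged = dropout(embeddings)  # identity at evaluation time

print(tokens.shape, embeddings.shape, dropped.shape)
```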

In [6]:
class Encoder(nn.Module):
    """
    The Encoder class for a Transformer model.
    
    Args:
        src_vocab_size (int): The size of the source vocabulary.
        embed_size (int): The dimensionality of the input embeddings.
        num_layers (int): The number of layers in the transformer.
        num_heads (int): The number of attention heads in the transformer block.
        device (torch.device): The device to run the model on (CPU or GPU).
        forward_expansion (int): The expansion factor for the feed forward network in transformer block.
        dropout_rate (float): The dropout rate used in the dropout layers to prevent overfitting.
        max_length (int): The maximum sequence length the model can handle.
    """
    def __init__(
        self,
        src_vocab_size,
        embed_size,
        num_layers,
        num_heads,
        forward_expansion,
        dropout_rate,
        max_length,
        device=None,
    ):

        super(Encoder, self).__init__()
        self.embed_size = embed_size
        if device is None:
            self.device = get_device()
        else:
            self.device = device

        # Embeddings for the input words and positional embeddings
        self.word_embedding = nn.Embedding(src_vocab_size, embed_size)
        self.position_embedding = nn.Embedding(max_length, embed_size)

        # Transformer blocks for the encoder layers
        self.layers = nn.ModuleList(
            [
                TransformerBlock(
                    embed_size,
                    num_heads,
                    dropout_rate=dropout_rate,
                    forward_expansion=forward_expansion,
                    device=self.device,
                )
                for _ in range(num_layers)
            ]
        )

        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x, mask):
        """
        Forward method for the Encoder class.
        
        Args:
            x (torch.Tensor): The input tensor of shape (batch_size, seq_length).
            mask (torch.Tensor): The mask to be applied on the attention outputs to prevent the model 
                from attending to certain positions.
        
        Returns:
            out (torch.Tensor): The output tensor from the encoder.
        """

        # Obtain the batch size and sequence length
        N, seq_length = x.shape
        x = x.to(self.device)
        mask = mask.to(self.device)

        # Create positional indices
        positions = torch.arange(0, seq_length).expand(N, seq_length).to(self.device)

        # Combine word embeddings and positional embeddings
        embeddings = self.word_embedding(x) + self.position_embedding(positions)

        # Apply dropout to the combined embeddings
        out = self.dropout(embeddings)

        # Pass the output through the layers
        for layer in self.layers:
            out = layer(out, out, out, mask)

        return out

In [7]:
class TestEncoder(unittest.TestCase):
    def setUp(self):
        self.src_vocab_size = 1000
        self.embed_size = 512
        self.num_layers = 6
        self.num_heads = 8
        self.forward_expansion = 4
        self.dropout_rate = 0.1
        self.max_length = 5000

        self.encoder = Encoder(
            src_vocab_size=self.src_vocab_size,
            embed_size=self.embed_size,
            num_layers=self.num_layers,
            num_heads=self.num_heads,
            forward_expansion=self.forward_expansion,
            dropout_rate=self.dropout_rate,
            max_length=self.max_length
        )

        self.encoder.to(get_device())

    def test_forward(self):
        batch_size = 32
        seq_length = 10
        src_vocab = torch.randint(0, self.src_vocab_size, (batch_size, seq_length)).to(get_device())
        # All ones: every position is visible (make_src_mask marks non-padding positions with True/1)
        src_mask = torch.ones((batch_size, 1, 1, seq_length)).to(get_device())
        
        out = self.encoder(src_vocab, src_mask)

        self.assertTupleEqual(out.shape, (batch_size, seq_length, self.embed_size))

if __name__ == '__main__':
    unittest.main(argv=['first-arg-is-ignored'], exit=False)

......
----------------------------------------------------------------------
Ran 6 tests in 0.255s

OK


## Decoder Block

The decoder block uses the same components as the encoder block (multi-head attention, layer normalization, and a feed-forward network), but adds a masked self-attention step controlled by the target mask and an encoder-decoder attention step over the encoder output. The forward pass in the decoder block is as follows:


1. **Target Mask & X embeddings**: The decoder block starts by receiving the target mask and the embeddings of the target sequence. The target mask prevents the model from peeking at future positions in the sequence, enforcing the autoregressive property. The embeddings (X) have shape (N, S, E), where N is the batch size, S is the sequence length, and E is the embedding size.

1. **MultiHeadSelfAttentionBlock** (MHA): The embeddings and the target mask are passed to a MultiHeadSelfAttentionBlock, which applies self-attention to the embeddings while using the target mask to block attention to future positions.

1. **LayerNorm + Dropout (LN1)**: The input embeddings X are added to the output of the self-attention block (a "residual connection" that helps mitigate the vanishing-gradient problem in deep networks), and the sum, still of shape (N, S, E), is layer-normalized and passed through dropout.

1. **Transformer Block (TR)**: The output of the previous step serves as the queries for the Transformer Block. It also receives the keys (K) and values (V), which are the output of the encoder. The source mask is applied here to prevent the decoder from attending to specific positions in the source sequence (for example, padding positions).

1. **Output (O)**: The output of the Transformer Block has shape (N, S, E), because its length follows the queries rather than the keys. It becomes the input X of the next decoder block; after the last block, the Decoder passes it through a linear layer to obtain a score for each token in the target vocabulary, and a softmax turns these scores into probabilities.


```mermaid
graph TB
T[Target Mask] --> MHA[MultiHeadSelfAttentionBlock]
X[X embeddings] --shape=N,S,E--> MHA

MHA --shape=N,S,E--> LN1[LayerNorm + Dropout]
X --shape=N,S,E--> LN1
LN1 --queries=N,Sq,E--> TR[Transformer Block]
V[Values] --shape=N,Sv,E--> TR
K[Keys] --shape=N,Sk,E--> TR
S[Source Mask] --> TR
TR --shape=N,S,E--> O[Output]
```
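To see how the two masks constrain attention, the sketch below builds a padding mask shaped like the one from `make_src_mask`, (N, 1, 1, S), and a causal mask shaped like the one from `make_trg_mask`, (N, 1, T, T), and applies each to random attention scores. It assumes, as is common but not shown in this section, that the attention block replaces scores where the mask is 0 with a large negative value before the softmax; `masked_softmax` is a hypothetical helper, and the token ids and padding index 0 are made up for the example.

```python
import torch

torch.manual_seed(0)
N, heads, T = 2, 4, 5
pad_idx = 0  # assumed padding index

# Source batch: the first sequence ends in two padding tokens
src = torch.tensor([[7, 3, 9, 0, 0],
                    [4, 8, 2, 6, 1]])
src_mask = (src != pad_idx).unsqueeze(1).unsqueeze(2)       # (N, 1, 1, S), like make_src_mask
trg_mask = torch.tril(torch.ones(T, T)).expand(N, 1, T, T)  # (N, 1, T, T), like make_trg_mask

def masked_softmax(scores, mask):
    """Hypothetical helper: block positions where mask == 0, then softmax over keys."""
    scores = scores.masked_fill(mask == 0, float("-1e20"))
    return torch.softmax(scores, dim=-1)

# Raw attention scores, shape (N, heads, query_len, key_len); here both lengths are 5
scores = torch.randn(N, heads, T, T)

# Both masks broadcast over the head dimension (and src_mask also over the query dimension)
src_attn = masked_softmax(scores, src_mask)
trg_attn = masked_softmax(scores, trg_mask)

print(src_attn[0, 0])  # columns 3 and 4 (padding) are 0 in every row
print(trg_attn[0, 0])  # lower-triangular: row i has weights only on positions 0..i
```

Because masked scores become exactly 0 after the softmax, any sufficiently large negative fill value works; `float("-inf")` is another common choice, but it yields NaN if an entire row is masked.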

In [8]:
class DecoderBlock(nn.Module):
    """
    The DecoderBlock class that forms a part of the Decoder in a Transformer model.

    Args:
        embed_size (int): The dimensionality of the input embeddings.
        num_heads (int): The number of attention heads for the self-attention mechanism.
        forward_expansion (int): The factor by which the dimensionality of the input is expanded in the feed-forward network.
        dropout_rate (float): The dropout rate used in the dropout layers to prevent overfitting.
        device (torch.device): The device to run the model on (CPU or GPU).
        
    Attributes:
        norm (nn.LayerNorm): Layer normalization.
        self_attention (MultiHeadSelfAttentionBlock): The masked self-attention mechanism.
        transformer_block (TransformerBlock): A transformer block.
        dropout (nn.Dropout): Dropout layer for regularization.
    """
    def __init__(self, embed_size, num_heads, forward_expansion, dropout_rate, device=None):
        super(DecoderBlock, self).__init__()
        if device is None:
            self.device = get_device()
        else:
            self.device = device        
        self.norm = nn.LayerNorm(embed_size)
        self.self_attention = MultiHeadSelfAttentionBlock(embed_size, num_heads)
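        # Encoder-decoder attention + feed-forward, reusing TransformerBlock on the same device
        self.transformer_block = TransformerBlock(
            embed_size,
            num_heads,
            dropout_rate=dropout_rate,
            forward_expansion=forward_expansion,
            device=self.device,
        )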
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x, value, key, src_mask, trg_mask):
        """
        Forward method for the DecoderBlock class.
        
        Args:
            x (torch.Tensor): The input tensor.
            value (torch.Tensor): The values for the encoder-decoder attention (the encoder output).
            key (torch.Tensor): The keys for the encoder-decoder attention (the encoder output).
            src_mask (torch.Tensor): The source mask to prevent attention to certain positions.
            trg_mask (torch.Tensor): The target mask to prevent attention to certain positions.
        
        Returns:
            out (torch.Tensor): The output tensor from the transformer block.
        """

        x = x.to(self.device)
        value = value.to(self.device)
        key = key.to(self.device)
        src_mask = src_mask.to(self.device)
        trg_mask = trg_mask.to(self.device)
        
        # Compute self attention
        attention = self.self_attention(x, x, x, trg_mask)
        
        # Apply normalization and dropout to the sum of the input and the attention output
        query = self.dropout(self.norm(attention + x))
        
        # Compute the output of the transformer block
        out = self.transformer_block(value, key, query, src_mask)
        
        return out

In [9]:
class TestDecoderBlock(unittest.TestCase):

    def setUp(self):
        self.batch_size = 64
        self.seq_length = 50
        self.embed_size = 512
        self.num_heads = 8
        self.forward_expansion = 4
        self.dropout_rate = 0.1

        # Initialize an instance of the DecoderBlock
        self.decoder_block = DecoderBlock(self.embed_size, self.num_heads, self.forward_expansion, self.dropout_rate)
        self.decoder_block.to(get_device())

        # Initialize some random test data
        self.x = torch.randn(self.batch_size, self.seq_length, self.embed_size)
        self.value = torch.randn(self.batch_size, self.seq_length, self.embed_size)
        self.key = torch.randn(self.batch_size, self.seq_length, self.embed_size)
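        # Masks use 1 for positions that may be attended to: no source padding here,
        # and a causal (lower-triangular) target mask of shape (N, 1, S, S)
        self.src_mask = torch.ones(self.batch_size, 1, 1, self.seq_length)
        self.trg_mask = torch.tril(torch.ones(self.seq_length, self.seq_length)).expand(
            self.batch_size, 1, self.seq_length, self.seq_length
        )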

    def test_decoder_block(self):
        # Forward pass through the DecoderBlock
        output = self.decoder_block(self.x, self.value, self.key, self.src_mask, self.trg_mask)

        # Assert the output shape is as expected
        self.assertEqual(output.shape, (self.batch_size, self.seq_length, self.embed_size))
        
if __name__ == '__main__':
    unittest.main(argv=['first-arg-is-ignored'], exit=False)

.......
----------------------------------------------------------------------
Ran 7 tests in 0.318s

OK
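Step 4 of the decoder block pairs decoder-side queries with encoder-side keys and values, so the two sequence lengths need not match. The sketch below checks this with PyTorch's built-in `nn.MultiheadAttention`, used here only as a stand-in for the notebook's own `MultiHeadSelfAttentionBlock`, whose code is not part of this section; the sizes are arbitrary, and the remaining residual, normalization, and feed-forward steps are omitted. Note that `nn.MultiheadAttention` uses the opposite boolean mask convention (True means *blocked*), so a causal mask in this notebook's 1 = attend convention has to be inverted first.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
N, T, S, E, heads = 2, 6, 9, 32, 4  # target length T differs from source length S

self_attn = nn.MultiheadAttention(E, heads, batch_first=True)
cross_attn = nn.MultiheadAttention(E, heads, batch_first=True)
norm = nn.LayerNorm(E)

x = torch.randn(N, T, E)        # target embeddings
enc_out = torch.randn(N, S, E)  # encoder output (keys and values)

# Causal mask in this notebook's convention (1 = may attend), then inverted
causal = torch.tril(torch.ones(T, T))
blocked = causal == 0           # True = not allowed, as nn.MultiheadAttention expects

# Masked self-attention, then residual connection and LayerNorm
attn_out, _ = self_attn(x, x, x, attn_mask=blocked)
query = norm(attn_out + x)

# Encoder-decoder attention: queries from the decoder, keys/values from the encoder
cross_out, weights = cross_attn(query, enc_out, enc_out)

print(cross_out.shape)  # follows the query length: (N, T, E)
print(weights.shape)    # one weight per (query, key) pair: (N, T, S)
```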


## Decoder

In the decoder, the forward pass is as follows:

1. **Tokens Sequence**: The target sequence, denoted as X, is the input to the decoder. Its shape is (N, SeqLength), where N is the batch size and SeqLength is the sequence length.

1. **Word Embeddings Mapping**: The input sequence X is passed through an embedding layer that transforms each token into a dense vector of size E. The resulting shape is (N, SeqLength, E), where E is the embedding size.

1. **Positional Encoding**: Positional encoding is added to these embeddings to encode the order of the tokens in the sequence, since this information is not inherently captured in the self-attention mechanism.

1. **Dropout**: A dropout layer is applied to prevent overfitting and to add some noise to the data.

1. **As X embeddings sequence** (identity): The output from the dropout layer serves as the initial sequence of embeddings X for the decoder block.

1. **Encoded Source Sequence**: The encoded sequence (Y) is the encoder's output for the source sequence. It supplies the keys and values for the encoder-decoder attention in the decoder and has shape (N, SrcLength, E), where SrcLength is the source sequence length, which need not equal SeqLength.

1. **As Key/Value embeddings sequence** (identity): The encoded sequence is used directly as both keys (ASK) and values (ASV), each of shape (N, SrcLength, E).

1. **Decoder Block**: The decoder block receives the input embeddings sequence X, the key embeddings ASK, and the value embeddings ASV. It applies masked self-attention on X and then a second attention step that uses the result as queries, ASK as keys, and ASV as values. The decoder stacks `num_layers` such blocks: the output of each block becomes the X input of the next, while ASK and ASV remain the same encoder output. Each pass lets the decoder refine its representation of the target sequence in the context of the source.

1. **Output**: The output of the final Decoder Block has shape (N, SeqLength, E). The decoder passes it through a linear layer (`fc_out`) to obtain scores of shape (N, SeqLength, trg_vocab_size), one per token in the target vocabulary; a softmax over the last dimension turns these scores into probabilities for predicting the next token.

```mermaid
graph TB
X[Tokens Sequence: X] --shape=N,SeqLength-->  WE[Word Embeddings Mapping]
WE --shape=N,SeqLength,E--> PE[Positional Encoding]
PE --shape=N,SeqLength,E--> DO[Dropout]
DO --shape=N,SeqLength,E--> ASX[As X embeddings sequence]
ASX --shape=N,SeqLength,E-->T[Decoder Block]
T --Number of Layers Times--> T
Y[Encoded Source Sequence] --shape=N,SrcLength,E --> ASV[As Value embeddings sequence]
Y --shape=N,SrcLength,E--> ASK[As Key embeddings sequence]
ASK --shape=N,SrcLength,E-->T[Decoder Block]
ASV --shape=N,SrcLength,E-->T[Decoder Block]
T --> O[Output]
```
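The `Decoder` class returns the raw scores (logits) from its final linear layer, without a softmax. A minimal sketch of turning such logits into probabilities and a greedy next-token choice, using a random tensor as a stand-in for the decoder output (the sizes are assumptions for illustration):

```python
import torch

torch.manual_seed(0)
N, seq_len, trg_vocab_size = 2, 5, 11

logits = torch.randn(N, seq_len, trg_vocab_size)  # stand-in for the Decoder output

# Softmax over the vocabulary dimension: one probability distribution per position
probs = torch.softmax(logits, dim=-1)

# With a causal mask, the last position summarizes the whole prefix, so the
# greedy choice for the next token of each sequence is read from there
next_token = probs[:, -1, :].argmax(dim=-1)  # shape (N,)

print(probs.sum(dim=-1))  # each entry is (numerically) 1
print(next_token)
```

Since softmax is monotonic, taking the argmax of the logits directly gives the same token; the softmax is only needed when actual probabilities are required, for example for sampling.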

In [10]:
class Decoder(nn.Module):
    """
    The Decoder class for a Transformer model.
    
    Args:
        trg_vocab_size (int): The size of the target vocabulary.
        embed_size (int): The dimensionality of the input embeddings.
        num_layers (int): The number of layers in the transformer.
        num_heads (int): The number of attention heads in the transformer block.
        forward_expansion (int): The expansion factor for the feed forward network in transformer block.
        dropout_rate (float): The dropout rate used in the dropout layers to prevent overfitting.
        device (torch.device): The device to run the model on (CPU or GPU).
        max_length (int): The maximum sequence length the model can handle.
    """
    def __init__(
        self,
        trg_vocab_size,
        embed_size,
        num_layers,
        num_heads,
        forward_expansion,
        dropout_rate,
        max_length,
        device = None
    ):

        super(Decoder, self).__init__()
        if device is None:
            self.device = get_device()
        else:
            self.device = device

        # Embeddings for the input words and positional embeddings
        self.word_embedding = nn.Embedding(trg_vocab_size, embed_size)
        self.position_embedding = nn.Embedding(max_length, embed_size)

        # Decoder blocks for the decoder layers
        self.layers = nn.ModuleList(
            [
                DecoderBlock(
                    embed_size,
                    num_heads,
                    forward_expansion,
                    dropout_rate,
                    device=self.device
                )
                for _ in range(num_layers)
            ]
        )

        # Fully connected layer to map the decoder's output to the target vocabulary size
        self.fc_out = nn.Linear(embed_size, trg_vocab_size)
        
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x, enc_out, src_mask, trg_mask):
        """
        Forward method for the Decoder class.
        
        Args:
            x (torch.Tensor): The input tensor of shape (batch_size, seq_length).
            enc_out (torch.Tensor): The output from the encoder, of shape (batch_size, src_seq_length, embed_size).
            src_mask (torch.Tensor): The source mask to prevent the model from attending to certain positions.
            trg_mask (torch.Tensor): The causal target mask that prevents each position from attending to later positions.

        Returns:
            out (torch.Tensor): Unnormalized scores (logits) of shape (batch_size, seq_length, trg_vocab_size).
        """

        # Obtain the batch size and sequence length
        N, seq_length = x.shape
        x = x.to(self.device)
        enc_out = enc_out.to(self.device)
        src_mask = src_mask.to(self.device)
        trg_mask = trg_mask.to(self.device)

        # Create positional indices
        positions = torch.arange(0, seq_length).expand(N, seq_length).to(self.device)

        # Combine word embeddings and positional embeddings
        embeddings = self.word_embedding(x) + self.position_embedding(positions)

        # Apply dropout to the combined embeddings
        x = self.dropout(embeddings)

        # Pass the output through the layers
        for layer in self.layers:
            x = layer(x, enc_out, enc_out, src_mask, trg_mask)

        # Apply the fully connected layer to the outputs of the layers
        out = self.fc_out(x)

        return out

In [11]:
class TestDecoder(unittest.TestCase):
    def setUp(self):
        self.batch_size = 64
        self.seq_length = 50
        self.embed_size = 512
        self.num_heads = 8
        self.num_layers = 8
        self.forward_expansion = 4
        self.dropout_rate = 0.1
        self.trg_vocab_size = 10000
        self.max_length = 1000

        self.decoder = Decoder(
            self.trg_vocab_size,
            self.embed_size,
            self.num_layers,
            self.num_heads,
            self.forward_expansion,
            self.dropout_rate,
            self.max_length
        )
        self.decoder.to(get_device())

        self.x = torch.randint(0, self.trg_vocab_size, (self.batch_size, self.seq_length))
        self.enc_out = torch.randn(self.batch_size, self.seq_length, self.embed_size)
        # The Transformer class (and its static mask helpers) is defined further below,
        # so use the module-level make_trg_mask here; the random encoder output has no
        # padding, so every source position is left visible
        self.src_mask = torch.ones((self.batch_size, 1, 1, self.seq_length))
        self.trg_mask = make_trg_mask(self.x)

    def test_forward(self):
        # Perform a forward pass through the decoder
        output = self.decoder(self.x, self.enc_out, self.src_mask, self.trg_mask)

        # Check the output size
        self.assertTupleEqual(output.size(), (self.batch_size, self.seq_length, self.trg_vocab_size))

if __name__ == '__main__':
    unittest.main(argv=['first-arg-is-ignored'], exit=False)


# Putting it all together

In [12]:
class Transformer(nn.Module):
    """
    The Transformer model class that combines an Encoder and a Decoder.

    Args:
        src_vocab_size (int): The size of the source vocabulary.
        trg_vocab_size (int): The size of the target vocabulary.
        src_pad_idx (int): The index of the source padding token in the source vocabulary.
        trg_pad_idx (int): The index of the target padding token in the target vocabulary.
        embed_size (int): The dimensionality of the input embeddings.
        num_layers (int): The number of layers in the transformer.
        forward_expansion (int): The expansion factor for the feed forward network in transformer block.
        heads (int): The number of attention heads in the transformer block.
        dropout (float): The dropout rate used in the dropout layers to prevent overfitting.
        device (torch.device): The device to run the model on (CPU or GPU).
        max_length (int): The maximum sequence length the model can handle.
    """

    def __init__(
        self,
        src_vocab_size,
        trg_vocab_size,
        src_pad_idx,
        trg_pad_idx,
        embed_size=512,
        num_layers=6,
        forward_expansion=4,
        heads=8,
        dropout=0,
        max_length=100,
        device=None
    ):

        super(Transformer, self).__init__()
        if device is None:
            self.device = get_device()
        else:
            self.device = device  
        self.encoder = Encoder(
            src_vocab_size,
            embed_size,
            num_layers,
            heads,
            forward_expansion,
            dropout,
            max_length,
            device=self.device
        )

        self.decoder = Decoder(
            trg_vocab_size,
            embed_size,
            num_layers,
            heads,
            forward_expansion,
            dropout,
            max_length,
            device=self.device
        )

        self.src_pad_idx = src_pad_idx
        self.trg_pad_idx = trg_pad_idx

    @staticmethod
    def make_src_mask(src,src_pad_idx,device=None):
        """
        Creates a mask for the source input sequence.
        
        Args:
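            src (torch.Tensor): The source input sequence of shape (N, src_len).
            src_pad_idx (int): The index of the padding token in the source vocabulary.
            device (torch.device, optional): The device for the mask; defaults to get_device().

        Returns:
            src_mask (torch.Tensor): A boolean mask of shape (N, 1, 1, src_len) that is True
                at non-padding positions.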
        """
        if device is None:
            device = get_device()
        src_mask = (src != src_pad_idx).unsqueeze(1).unsqueeze(2).to(device)
        return src_mask

    @staticmethod
    def make_trg_mask(trg,device=None):
        """
        Creates a mask for the target input sequence.
        
        Args:
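            trg (torch.Tensor): The target input sequence of shape (N, trg_len).
            device (torch.device, optional): The device for the mask; defaults to get_device().

        Returns:
            trg_mask (torch.Tensor): A lower-triangular (causal) mask of shape (N, 1, trg_len, trg_len).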
        """
        if device is None:
            device = get_device()
        N, trg_len = trg.shape
        trg_mask = torch.tril(torch.ones((trg_len, trg_len))).expand(
            N, 1, trg_len, trg_len
        ).to(device)
        
        return trg_mask

    def forward(self, src, trg):
        """
        Forward method for the Transformer class.
        
        Args:
            src (torch.Tensor): The source input sequence.
            trg (torch.Tensor): The target input sequence.
        
        Returns:
            out (torch.Tensor): The output tensor from the transformer.
        """
        src = src.to(self.device)
        trg = trg.to(self.device)
        src_mask = self.make_src_mask(src,self.src_pad_idx,self.device)
        trg_mask = self.make_trg_mask(trg,self.device)
        enc_src = self.encoder(src, src_mask)
        out = self.decoder(trg, enc_src, src_mask, trg_mask)
        return out
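
`make_trg_mask` builds only the causal (look-ahead) mask, and `trg_pad_idx` is stored on the model but not used in the forward pass. If padding positions in the target should also be hidden, one option is to combine the causal mask with a padding mask using a logical AND. The sketch below does this with a hypothetical `make_combined_trg_mask` helper, keeping the notebook's convention that True/1 means "may attend"; the example tokens and padding index 0 are assumptions.

```python
import torch

def make_combined_trg_mask(trg, trg_pad_idx):
    """Hypothetical helper: causal mask AND target padding mask, shape (N, 1, T, T)."""
    N, trg_len = trg.shape
    # (N, 1, 1, T): True where the key position holds a real token
    pad_mask = (trg != trg_pad_idx).unsqueeze(1).unsqueeze(2)
    # (T, T): True on and below the diagonal
    causal = torch.tril(torch.ones(trg_len, trg_len, dtype=torch.bool, device=trg.device))
    # Broadcasting (N, 1, 1, T) & (T, T) gives (N, 1, T, T)
    return pad_mask & causal

trg = torch.tensor([[1, 7, 4, 3, 0, 0],
                    [1, 5, 6, 2, 4, 7]])
mask = make_combined_trg_mask(trg, trg_pad_idx=0)
print(mask.shape)        # torch.Size([2, 1, 6, 6])
print(mask[0, 0].int())  # padding columns 4 and 5 are 0 in every row
```

Every row keeps at least the first position visible here, which avoids rows in which all scores are masked.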

In [13]:
class TestTransformer(unittest.TestCase):
    def setUp(self):
        self.src_pad_idx = 0
        self.trg_pad_idx = 0
        self.src_vocab_size = 10
        self.trg_vocab_size = 10
        self.model = Transformer(self.src_vocab_size, self.trg_vocab_size, self.src_pad_idx, self.trg_pad_idx)
        self.model.to(get_device())

        self.x = torch.tensor([[1, 5, 6, 4, 3, 9, 5, 2, 0], [1, 8, 7, 3, 4, 5, 6, 7, 2]]).to(get_device())
        self.trg = torch.tensor([[1, 7, 4, 3, 5, 9, 2, 0], [1, 5, 6, 2, 4, 7, 6, 2]]).to(get_device())

    def test_forward_pass(self):
        out = self.model(self.x, self.trg[:, :-1])
        self.assertTupleEqual(out.shape, (self.trg.shape[0], self.trg.shape[1] - 1, self.trg_vocab_size))

if __name__ == '__main__':
    unittest.main(argv=['first-arg-is-ignored'], exit=False)

.........
----------------------------------------------------------------------
Ran 9 tests in 1.048s

OK
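
The test above feeds `self.trg[:, :-1]` to the model. This is the usual teacher-forcing shift: the decoder sees tokens 0..T-2 and each position is trained to predict the following token, so the labels are `trg[:, 1:]`. A minimal sketch of the matching loss computation, using random logits as a stand-in for the model output and assuming `ignore_index` is set to the padding index so that padded targets do not contribute:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
trg_vocab_size, trg_pad_idx = 10, 0
trg = torch.tensor([[1, 7, 4, 3, 5, 9, 2, 0],
                    [1, 5, 6, 2, 4, 7, 6, 2]])  # same targets as in the test above

trg_input = trg[:, :-1]   # fed to the decoder, shape (2, 7)
trg_labels = trg[:, 1:]   # what each position should predict, shape (2, 7)

# Stand-in for model(src, trg_input): shape (N, T - 1, trg_vocab_size)
logits = torch.randn(trg_input.shape[0], trg_input.shape[1], trg_vocab_size, requires_grad=True)

criterion = nn.CrossEntropyLoss(ignore_index=trg_pad_idx)
# CrossEntropyLoss expects (batch, classes), so flatten the batch and time dimensions
loss = criterion(logits.reshape(-1, trg_vocab_size), trg_labels.reshape(-1))
loss.backward()
print(loss.item())
```

`nn.CrossEntropyLoss` applies log-softmax internally, which is why the decoder's raw logits can be passed to it directly.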
