# Encoder Layer

> Implement the Transformer's encoder layer, and the full encoder stack built from it, from scratch

In [None]:
#| default_exp transformer.encoder

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| hide
import nbdev; nbdev.nbdev_export()

In [None]:
#| export
from typing import Callable, Optional

import torch
from torch import nn

# from foundation.transformer.attention import MultiHeadAttention
from foundation.transformer.efficient_attention import MultiHeadAttention
from foundation.transformer.embedding import TextEmbedding
from foundation.transformer.positional_encoding import PositionalEncoding

In [None]:
#| export
class ResidualLayerNorm(nn.Module):
    """Add a sublayer's output to its residual input, then layer-normalise.

    Computes ``LayerNorm(dropout(x) + residual)``. Dropout is applied only to
    the sublayer output ``x``, so the identity (residual) path is never zeroed
    out during training.

    Args:
        d_model (int): size of the last dimension, normalised by LayerNorm.
        dropout (float, optional): dropout probability applied to ``x``.
            ``None`` is treated as ``0.0`` (no dropout). Defaults to 0.3.
    """

    def __init__(self, d_model: int, dropout: Optional[float] = 0.3):
        super().__init__()
        self.layer_norm = nn.LayerNorm(normalized_shape=d_model)
        # nn.Dropout cannot take None; map it to "no dropout" so the
        # Optional type hint is actually honoured.
        self.dropout = nn.Dropout(p=0.0 if dropout is None else dropout)

    def forward(self, x: torch.Tensor, residual: torch.Tensor) -> torch.Tensor:
        """Return ``LayerNorm(dropout(x) + residual)``.

        Args:
            x: output of the wrapped sublayer (attention or feed-forward).
            residual: the sublayer's input; must be broadcast-compatible with ``x``.
        """
        # Only the sublayer branch is dropped out; dropping the sum would also
        # randomly erase the residual connection.
        return self.layer_norm(self.dropout(x) + residual)

In [None]:
#| export
class PostionWiseFeedForward(nn.Module):
    """Position-wise feed-forward network: ``Linear -> ReLU -> Dropout -> Linear``.

    The same two-layer MLP is applied to every position, acting on the last
    dimension and mapping ``d_model -> d_ff -> d_model``.

    Note: the class name keeps its original spelling ("Postion") so existing
    imports keep working.
    """

    def __init__(self, d_model: int, d_ff: int, dropout: Optional[float] = 0.3):
        """Build the feed-forward sublayer.

        Args:
            d_model (int): the dimension of the text embedding (input and output size).
            d_ff (int): the hidden dimension of the feed-forward linear layer.
            dropout (float, optional): dropout probability applied after the ReLU.
                ``None`` is treated as ``0.0`` (no dropout). Defaults to 0.3.
        """
        super().__init__()
        # nn.Dropout rejects None; map it to "no dropout" to honour the Optional hint.
        p = 0.0 if dropout is None else dropout
        # Keep this exact nn.Sequential layout (indices 0..3) so state_dict keys
        # stay "feed_forward.0.*" / "feed_forward.3.*".
        self.feed_forward = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Dropout(p),
            nn.Linear(d_ff, d_model),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the MLP to the last dimension of ``x``.

        Expects ``x`` shaped ``[batch_size x seq_len x d_model]`` (any leading
        dimensions work, since ``nn.Linear`` acts on the last one) and returns
        a tensor of the same shape.
        """
        return self.feed_forward(x)

In [None]:
#| export
class EncoderLayer(nn.Module):
    """A single Transformer encoder block: self-attention, then feed-forward.

    Each sublayer is wrapped in a ``ResidualLayerNorm`` (normalisation after
    the residual addition):

        hidden = norm_1(mha(x, x, x), x)
        output = norm_2(feed_forward(hidden), hidden)

    Args:
        d_model (int): embedding dimension of the input and output.
        n_heads (int): number of attention heads given to ``MultiHeadAttention``.
        d_ff (int): hidden size of the position-wise feed-forward network.
        dropout (float, optional): dropout probability shared by every sublayer.
            Defaults to 0.3.
    """

    def __init__(self, d_model: int, n_heads: int, d_ff: int, dropout: Optional[float] = 0.3):
        super().__init__()
        # Keep this construction order: it determines parameter/state_dict
        # order and, under a fixed seed, the initial weights.
        self.mha = MultiHeadAttention(d_model, n_heads, dropout)
        self.norm_1 = ResidualLayerNorm(d_model, dropout)
        self.feed_forward = PostionWiseFeedForward(d_model, d_ff, dropout)
        self.norm_2 = ResidualLayerNorm(d_model, dropout)

    def forward(self, x: torch.Tensor, mask=None):
        """Run one encoder block.

        Args:
            x: layer input, ``[batch_size x seq_len x d_model]``.
            mask: optional attention mask, forwarded to ``MultiHeadAttention``.

        Returns:
            ``(output, attention_weights)``: the output is
            ``[batch_size x seq_len x d_model]``, and the weights are
            ``[batch_size x n_heads x seq_len x seq_len]``.
        """
        # Self-attention: query, key and value are all the layer input.
        attended, attn_weights = self.mha(x, x, x, mask=mask)
        hidden = self.norm_1(attended, x)

        ffn_out = self.feed_forward(hidden)
        return self.norm_2(ffn_out, hidden), attn_weights

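`n_heads`: the number of attention heads used by the multi-head attention inside each `EncoderLayer`; the `Encoder` below passes the same value to every layer it stacks.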

In [None]:
#| export
class Encoder(nn.Module):
    """Transformer encoder: token embedding, positional encoding, then N encoder layers.

    Args:
        d_model (int): embedding dimension used throughout the stack.
        n_heads (int): number of attention heads in each ``EncoderLayer``.
        n_layers (int): number of stacked ``EncoderLayer`` blocks.
        d_ff (int): hidden size of each layer's feed-forward network.
        dropout (float, optional): dropout probability passed to every layer.
            Defaults to 0.3.
        vocab_size (int, optional): vocabulary size passed to ``TextEmbedding``.
            Defaults to 10000 (the previously hard-coded value).
        padding_idx (int, optional): passed to ``TextEmbedding`` as its
            ``padding_idx``. Defaults to 0 (the previously hard-coded value).
    """

    def __init__(
        self, d_model: int, n_heads: int, n_layers: int,
        d_ff: int, dropout: Optional[float] = 0.3,
        vocab_size: int = 10000, padding_idx: int = 0,
    ):
        super().__init__()
        self.embedding = TextEmbedding(
            vocab_size=vocab_size,
            d_model=d_model,
            padding_idx=padding_idx,
        )
        self.positional_encoding = PositionalEncoding(d_model)
        self.encoders = nn.ModuleList([
            EncoderLayer(d_model, n_heads, d_ff, dropout)
            for _ in range(n_layers)
        ])

    def forward(self, x: torch.Tensor, mask=None):
        """Encode a batch of source token ids.

        Args:
            x: token ids, ``[batch_size x src_seq_len]``.
            mask: optional attention mask, forwarded unchanged to every layer.

        Returns:
            ``(encoding, encoder_attention_weights)``:
              * ``encoding``: the final encoding, ``[batch_size x src_seq_len x d_model]``.
              * ``encoder_attention_weights``: the attention weights of the LAST layer
                only, ``[batch_size x n_heads x src_seq_len x src_seq_len]``, or
                ``None`` when the stack is empty (``n_layers == 0``).
        """
        # shape(embeddings) = [batch_size x src_seq_len x d_model]
        embeddings = self.embedding(x)
        # shape(encoding) = [batch_size x src_seq_len x d_model]
        encoding = self.positional_encoding(embeddings)

        # Pre-bind so an empty stack returns None instead of raising UnboundLocalError.
        encoder_attention_weights = None
        for encoder in self.encoders:
            # Each layer overwrites the weights, so only the last layer's survive.
            encoding, encoder_attention_weights = encoder(encoding, mask)

        return encoding, encoder_attention_weights