## Encoder in Transformer

# In [6]:
# import some libraries we'll probably use
import numpy as np
import pandas as pd
import torch
# just used for plotting
from lets_plot import *
# Configure lets_plot's HTML output; like the import above, only needed for plotting.
LetsPlot.setup_html()

class EncoderLayer(torch.nn.Module):
    """A single post-LayerNorm Transformer encoder layer.

    Each of the two sublayers follows the paper's residual scheme
    (sections 3.3 and 5.4)::

        x = LayerNorm(x + Dropout(SelfAttention(x)))
        x = LayerNorm(x + Dropout(FeedForward(x)))

    Args:
        embed_dim: Model width, i.e. the size of the last input dimension.
        n_heads: Number of attention heads; ``torch.nn.MultiheadAttention``
            requires it to divide ``embed_dim``.
        dim_feedforward: Hidden width of the position-wise feed-forward network.
        dropout: Dropout probability applied to each sublayer's output before
            it is added back to the sublayer's input.
    """

    def __init__(self, embed_dim: int, n_heads: int, dim_feedforward: int = 128, dropout=0.1):
        super().__init__()
        self.embed_dim = embed_dim
        self.n_heads = n_heads
        # batch_first=True: inputs are laid out as (batch, seq, feature).
        self.mha = torch.nn.MultiheadAttention(embed_dim=embed_dim, num_heads=n_heads, batch_first=True, bias=True)
        self.layer_norm1 = torch.nn.LayerNorm(normalized_shape=embed_dim)
        self.layer_norm2 = torch.nn.LayerNorm(normalized_shape=embed_dim)

        # Section 5.4: apply dropout to the output of each sublayer before it
        # is added to the sublayer's input.
        self.dropout1 = torch.nn.Dropout(p=dropout)
        self.dropout2 = torch.nn.Dropout(p=dropout)

        # Section 3.3: position-wise feed-forward network, Linear -> ReLU -> Linear.
        self.position_wise_ff = torch.nn.Sequential(
            torch.nn.Linear(in_features=embed_dim, out_features=dim_feedforward, bias=True),
            torch.nn.ReLU(),
            torch.nn.Linear(in_features=dim_feedforward, out_features=embed_dim, bias=True),
        )

    def forward(self, x, src_key_padding_mask=None, src_mask=None):
        """Apply self-attention and the feed-forward sublayer to ``x``.

        Args:
            x: Input of shape ``(batch_size, seq_len, embed_dim)``.
            src_key_padding_mask: Optional ``(batch_size, seq_len)`` mask. For a
                boolean mask, ``True`` marks key positions that must not be
                attended to (e.g. padding).
            src_mask: Optional attention mask with one of these shapes:
                ``(seq_len, seq_len)``, shared by every sequence;
                ``(batch_size, seq_len, seq_len)``, one mask per sequence,
                shared by all heads;
                ``(batch_size * n_heads, seq_len, seq_len)``, one mask per head.
                For a boolean mask, ``True`` means "do not attend".

        Returns:
            A tensor with the same shape as ``x``.
        """
        if (
            src_mask is not None
            and src_mask.dim() == 3
            and x.dim() == 3
            and self.n_heads > 1
            and src_mask.size(0) == x.size(0)
        ):
            # MultiheadAttention only accepts 3-D masks of shape
            # (batch_size * n_heads, L, S), ordered batch-major. Repeat each
            # per-sequence mask once per head so (batch_size, L, S) works too.
            src_mask = src_mask.repeat_interleave(self.n_heads, dim=0)

        # The averaged attention weights are never used, so don't compute them.
        attn_output, _ = self.mha(
            x, x, x,
            key_padding_mask=src_key_padding_mask,
            attn_mask=src_mask,
            need_weights=False,
        )
        # Sublayer 1: dropout on the attention output, residual add, LayerNorm.
        x = self.layer_norm1(x + self.dropout1(attn_output))

        # Sublayer 2: feed-forward, then the same dropout/residual/LayerNorm.
        projection = self.position_wise_ff(x)
        x = self.layer_norm2(x + self.dropout2(projection))
        return x