In [None]:
import torch
import torch.nn as nn
from torch import Tensor
import torch.nn.functional as F
import math

In [None]:
# Hyperparameters (base Transformer configuration).
heads = 8    # number of attention heads
d = 512      # embedding (model) size
dff = 2048   # inner (expansion) size of the feed-forward sub-layer
N = 6        # number of encoder / decoder layers
p = 0.1      # dropout rate

# Toy inputs: integer token ids of shape (batch, sequence length).
src = torch.randint(0, 100, (1, 4))  # 4-token source sentence, ids drawn from [0, 100)
trg = torch.randint(0, 50, (1, 2))   # 2-token target sentence, ids drawn from [0, 50)


In [None]:
class Embedding(nn.Module):
    """Token embedding layer whose output is scaled by sqrt(d).

    The looked-up vectors are multiplied by sqrt(d), as in the original
    Transformer, so that they are not dwarfed by the positional encodings
    that get added afterwards. No dropout is applied here; in this notebook
    dropout is applied by ``PE`` after the positional encoding is added.

    Args:
        d: embedding size.
        vocab_size: number of distinct token ids.
    """

    def __init__(self, d : int, vocab_size : int):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d)
        self.scale = math.sqrt(d)

    def forward(self, x: Tensor) -> Tensor:
        """Map integer token ids ``x`` to scaled embedding vectors (one extra trailing dim of size d)."""
        return self.embedding(x) * self.scale

In [None]:
# Smoke test: embed the toy source batch (vocabulary of 100 ids) and inspect the shape.
e = Embedding(d=d, vocab_size=100)
e(src).shape

In [None]:
class PE(nn.Module):
    """Sinusoidal positional encoding followed by dropout.

    Builds a fixed ``(max_len, d)`` table with
    ``pe[pos, 2i] = sin(pos / 10000^(2i/d))`` and
    ``pe[pos, 2i+1] = cos(pos / 10000^(2i/d))``.

    The table is registered as a non-persistent buffer: it follows the module
    across ``.to(device)`` / ``.cuda()`` calls but is not written to the
    ``state_dict`` (it is fully determined by ``d`` and ``max_len``).

    Args:
        d: embedding size.
        p: dropout probability.
        max_len: longest sequence length the table supports.
    """

    def __init__(self, d : int, p : float, max_len : int = 100):
        super().__init__()
        pos = torch.arange(0, max_len, 1).unsqueeze(1)  # (max_len, 1)
        # torch.arange(0, d, 2) already enumerates the even indices 2i, so the
        # exponent 2i/d must not be multiplied by 2 a second time.
        div = torch.pow(10_000, torch.arange(0, d, 2) / d)
        angles = pos / div  # (max_len, ceil(d / 2))

        pe = torch.zeros(max_len, d)
        pe[:, 0::2] = torch.sin(angles)
        # For odd d there is one fewer odd column than even columns.
        pe[:, 1::2] = torch.cos(angles[:, : d // 2])
        self.register_buffer("pe", pe, persistent=False)

        self.dropout = nn.Dropout(p)

    def forward(self, x: Tensor) -> Tensor:
        """Add positional encodings to ``x`` of shape (batch, seq_len, d) and apply dropout.

        Raises:
            ValueError: if seq_len exceeds the ``max_len`` the table was built for.
        """
        seq_len = x.shape[1]
        if seq_len > self.pe.shape[0]:
            raise ValueError(
                f"sequence length {seq_len} exceeds max_len {self.pe.shape[0]}"
            )
        # (seq_len, d) broadcasts over the batch dimension.
        return self.dropout(x + self.pe[:seq_len])
        

In [None]:
# Smoke test: add positional encodings to the embedded source and inspect the shape.
pe = PE(d=d, p=p)
pe(e(src)).shape

In [None]:
class SelfAttention(nn.Module):
    """Multi-head scaled dot-product attention sub-layer with Add & Norm.

    Inputs are split into ``heads`` chunks of size ``head_dim = d // heads``.
    The same ``head_dim x head_dim`` projection is applied to every head
    (one shared ``Q``, ``K`` and ``V`` linear layer each). The concatenated
    heads go through an output linear layer, are added to the query input
    (residual connection) and layer-normalised.

    Args:
        heads: number of attention heads.
        d: embedding size; must be divisible by ``heads``.

    Raises:
        ValueError: if ``d`` is not divisible by ``heads``.
    """

    def __init__(self, heads : int, d : int):
        super().__init__()

        if d % heads != 0:
            raise ValueError(f"d ({d}) must be divisible by heads ({heads})")

        self.heads = heads
        self.head_dim = d // heads
        self.d = d
        self.Q = nn.Linear(self.head_dim, self.head_dim)
        self.K = nn.Linear(self.head_dim, self.head_dim)
        self.V = nn.Linear(self.head_dim, self.head_dim)

        self.linear = nn.Linear(self.d, self.d)
        self.norm = nn.LayerNorm(d)

    def forward(self, q: Tensor, k: Tensor, v: Tensor, mask=None) -> Tensor:
        """Attend from ``q`` over ``k`` / ``v``.

        Args:
            q: queries, shape (batch, q_len, d).
            k: keys, shape (batch, k_len, d).
            v: values, shape (batch, k_len, d).
            mask: optional tensor broadcastable to (batch, heads, q_len, k_len);
                positions where ``mask == 0`` are excluded from attention.

        Returns:
            Tensor of shape (batch, q_len, d).
        """
        batch = q.shape[0]
        q_len = q.shape[1]
        k_len = k.shape[1]
        v_len = v.shape[1]

        # Split the embedding into heads and project each input from its OWN source.
        Q = self.Q(q.reshape(batch, q_len, self.heads, self.head_dim))
        K = self.K(k.reshape(batch, k_len, self.heads, self.head_dim))
        V = self.V(v.reshape(batch, v_len, self.heads, self.head_dim))

        # (batch, heads, q_len, k_len), scaled by sqrt(d_k) where d_k is the per-head size.
        scores = torch.einsum("bqhd,bkhd->bhqk", Q, K) / math.sqrt(self.head_dim)

        if mask is not None:
            scores = scores.masked_fill(mask == 0, float("-inf"))

        # Normalise over the key axis so each query's weights sum to 1.
        weights = F.softmax(scores, dim=-1)
        # Weighted sum of values: the key index k is shared between weights and V.
        context = torch.einsum("bhqk,bkhd->bqhd", weights, V)
        concat = context.reshape(batch, q_len, self.d)

        # Output projection, residual connection with the query input, LayerNorm.
        return self.norm(self.linear(concat) + q)
            

In [None]:
# Smoke test: run self-attention (q = k = v) on the position-encoded source.
s = SelfAttention(heads=heads, d=d)
x = pe(e(src))
s(q=x, k=x, v=x).shape

In [None]:
class FeedForward(nn.Module):
    """Position-wise feed-forward sub-layer with Add & Norm.

    Two linear transformations with a ReLU in between (d -> df -> d),
    followed by a residual connection and LayerNorm.

    Args:
        d: embedding size (input and output width).
        df: width of the hidden (expansion) layer.
    """

    def __init__(self, d: int, df: int):
        super().__init__()

        # Use the constructor argument `df`, not a notebook-level global.
        self.ff = nn.Sequential(
            nn.Linear(d, df),
            nn.ReLU(),
            nn.Linear(df, d))

        self.norm = nn.LayerNorm(d)

    def forward(self, x: Tensor) -> Tensor:
        """Return LayerNorm(x + FFN(x)); the output has the same shape as ``x``."""
        return self.norm(x + self.ff(x))
        

In [None]:
# Smoke test: push the attention output through the feed-forward sub-layer.
f = FeedForward(d=d, df=dff)
x = s(q=x, k=x, v=x)
f(x).shape

In [None]:
class EncoderLayer(nn.Module):
    """Single encoder layer made of two sub-layers.

    The first sub-layer is multi-head attention (``SelfAttention``); its
    output is fed to the position-wise feed-forward network (``FeedForward``).

    Args:
        heads: number of attention heads.
        d: embedding size.
        dff: hidden width of the feed-forward network.
    """

    def __init__(self, heads: int, d: int, dff: int):
        super().__init__()
        self.attention = SelfAttention(heads, d)
        self.ff = FeedForward(d, dff)

    def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
        """Run attention on (q, k, v) without a mask, then the feed-forward network."""
        attended = self.attention(q, k, v)
        return self.ff(attended)
    

In [None]:
# Smoke test: one full encoder layer on the position-encoded source.
enc = EncoderLayer(heads=heads, d=d, dff=dff)
x = pe(e(src))
enc(q=x, k=x, v=x).shape

In [None]:
class DecoderLayer(nn.Module):
    """Single decoder layer.

    Applies masked attention over the target, then hands the result as the
    query to an ``EncoderLayer`` (attention + feed-forward) together with the
    supplied ``k`` and ``v``.

    Args:
        heads: number of attention heads.
        d: embedding size.
        dff: hidden width of the feed-forward network.
    """

    def __init__(self, heads: int, d: int, dff: int):
        super().__init__()
        self.masked_attention = SelfAttention(heads, d)
        self.enc_layer = EncoderLayer(heads, d, dff)

    def forward(self, x: Tensor, k: Tensor, v: Tensor, trg_mask: Tensor) -> Tensor:
        """Masked attention of ``x`` over itself, then attention + feed-forward with ``k`` / ``v``."""
        query = self.masked_attention(x, x, x, trg_mask)
        decoded = self.enc_layer(query, k, v)
        return decoded

In [None]:
class EncoderDecoder(nn.Module):
    """Stack of ``N`` encoder layers followed by ``N`` decoder layers.

    Args:
        heads: number of attention heads.
        d: embedding size.
        dff: hidden width of the feed-forward networks.
        N: number of encoder layers and of decoder layers.
    """

    def __init__(self, heads: int, d: int, dff: int, N: int):
        super().__init__()

        self.enc_layer = nn.ModuleList([EncoderLayer(heads, d, dff) for _ in range(N)])
        self.dec_layer = nn.ModuleList([DecoderLayer(heads, d, dff) for _ in range(N)])

    def forward(self, src: Tensor, trg: Tensor) -> Tensor:
        """Encode ``src``, then decode ``trg`` against the final encoder output.

        Args:
            src: embedded source, shape (batch, src_len, d).
            trg: embedded target, shape (batch, trg_len, d).

        Returns:
            Output of the last decoder layer.
        """
        for enc in self.enc_layer:
            src = enc(src, src, src)

        # The mask depends only on trg's shape and device, which the decoder
        # layers do not change, so it is built once instead of once per layer.
        trg_mask = self._make_mask(trg)
        for dec in self.dec_layer:
            trg = dec(trg, src, src, trg_mask)

        return trg

    def _make_mask(self, trg):
        """Build a lower-triangular (causal) mask of shape (batch, 1, trg_len, trg_len).

        ``trg`` is expected to be 3-D: (batch, trg_len, d). The mask is created
        on ``trg``'s device and broadcast over the batch with ``expand``, so it
        works for any batch size without copying.
        """
        batch, trg_len, _ = trg.shape
        mask = torch.tril(torch.ones(trg_len, trg_len, device=trg.device))
        return mask.expand(batch, 1, trg_len, trg_len)
            

In [None]:
# End-to-end smoke test of the full encoder-decoder stack.
encdec = EncoderDecoder(heads, d, dff, N)
# Separate names are used so that `s` (the SelfAttention instance called by
# earlier cells) is not rebound to a tensor.
src_emb = pe(e(src))
# The target side is built from `trg`, not `src`. The embedding `e` is reused
# here; its 100-id vocabulary covers the target ids, which are drawn from [0, 50).
trg_emb = pe(e(trg))
encdec(src_emb, trg_emb).shape