In [1]:
import torch
import torch.nn as nn
from torch import Tensor
import torch.nn.functional as F
import math

In [2]:
heads = 8
d = 512 # embedding size
dff = 2049 # expansiondim
N = 6 # layers
p = 0.1 # dropout rate

src = torch.randint(0, 100, (1, 4)) # 100 words in vocab
trg = torch.randint(0, 50, (1, 2)) # 2-word target sencence


In [3]:
class Embedding(nn.Module):
    # embedding layer with scaling and dropout
    def __init__(self, d : int, vocab_size : int):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d)
            
    def forward(self, x: Tensor) -> Tensor:
        return self.embedding(x)

In [4]:
e = Embedding(d, 100)
e(src).shape

torch.Size([1, 4, 512])

In [19]:
class PE(nn.Module):
    # Positional encoding fn with dropout
    # d = embedding size
    # p = prob of dropout
    def __init__(self, d : int, p : int, max_len = 100):
        super().__init__()
        self.pe = torch.zeros(max_len, d)
        pos = torch.arange(0, max_len, 1).unsqueeze(1)
        div = torch.pow(10_000, 2 * torch.arange(0, d, 2) / d)
        
        self.pe[:, 0::2] = torch.sin(pos / div)
        self.pe[:, 1::2] = torch.cos(pos / div)
        
        self.dropout = nn.Dropout(p)
    def forward(self, x: Tensor) -> Tensor:
        return self.dropout(x + self.pe[:x.shape[1]])
        

In [20]:
pe = PE(d, p)
pe(e(src)).shape

torch.Size([1, 4, 512])

In [21]:
class SelfAttention(nn.Module):
    # Multi head self-attention sub-layer
    
    def __init__(self, heads : int, d : int):
        super().__init__()
        
        self.heads = heads
        self.head_dim = d // heads
        self.d = d
        self.Q = nn.Linear(self.head_dim, self.head_dim)
        self.K = nn.Linear(self.head_dim, self.head_dim)
        self.V = nn.Linear(self.head_dim, self.head_dim)

        self.linear = nn.Linear(self.d, self.d)
        self.norm = nn.LayerNorm(d)
        
    def forward(self, q: Tensor, k: Tensor, v: Tensor, mask=None) -> Tensor:
        batch = q.shape[0]
        q_len = q.shape[1] # query length
        k_len = k.shape[1]
        v_len = v.shape[1]
        
        Q = self.Q(q.reshape(batch, q_len, self.heads, self.head_dim))
        K = self.K(q.reshape(batch, q_len, self.heads, self.head_dim))
        V = self.V(q.reshape(batch, q_len, self.heads, self.head_dim))
        
        QK = torch.einsum("bqhd, bkhd -> bhqk", [Q, K])
        scale = QK / math.sqrt(self.d)
        
        if mask is not None:
            scale = scale.masked_fill(mask == 0, float("-inf"))
            
        softmax = F.softmax(scale, dim=1)
        output = torch.einsum("bhqk, bvhd -> bqhd", [softmax, V])
        concat = output.reshape(batch, q_len, self.d)
        addnorm = self.norm(linear + q)
        
        return addnorm
            

In [22]:
s = SelfAttention(heads, d)
x = pe(e(src))
s(x, x, x).shape

NameError: name 'linear' is not defined