In [1]:
import torch
import torch.nn as nn
import math

In [2]:
class LayerNorm(nn.Module):
  """Layer normalization over the last dimension with a learnable scale and shift.

  Each vector along the last axis is shifted to zero mean and divided by its
  (biased) standard deviation, then multiplied by `gamma` and offset by `beta`.

  Args:
    d_model: length of the last dimension of the input; `gamma` and `beta`
      are 1-D parameters of this length.
    eps: constant added to the variance before taking the square root, so the
      division stays finite when the variance is zero.
  """

  def __init__(self, d_model, eps=1e-5):
    super().__init__()
    self.eps = eps
    # Scale starts at one and shift at zero, so a freshly built module is a
    # plain normalization.
    self.gamma = nn.Parameter(torch.ones(d_model))
    self.beta = nn.Parameter(torch.zeros(d_model))

  def forward(self, x):
    """Normalize `x` along its last dimension and apply the affine transform.

    Args:
      x: tensor whose last dimension has length `d_model`.

    Returns:
      Tensor of the same shape as `x`.
    """
    centered = x - x.mean(dim=-1, keepdim=True)
    # unbiased=False: divide by N rather than N - 1.
    variance = x.var(dim=-1, keepdim=True, unbiased=False)
    normalized = centered / torch.sqrt(variance + self.eps)
    return self.gamma * normalized + self.beta


In [3]:
class FeedForward(nn.Module):
  """Two-layer feed-forward network: Linear -> ReLU -> Dropout -> Linear.

  Both linear layers act on the last dimension of the input, so the output
  has the same shape as the input.

  Args:
    d_model: input and output feature size.
    d_ff: hidden feature size between the two linear layers.
    dropout: dropout probability applied to the hidden activations.
  """

  def __init__(self, d_model, d_ff, dropout=0.1):
    super().__init__()
    # Expand to the hidden width, then project back to the model width.
    self.fc1 = nn.Linear(d_model, d_ff)
    self.fc2 = nn.Linear(d_ff, d_model)
    self.dropout = nn.Dropout(dropout)
    self.relu = nn.ReLU()

  def forward(self, x):
    """Run `x` through the network.

    Args:
      x: tensor whose last dimension has length `d_model`.

    Returns:
      Tensor with the same shape as `x`.
    """
    hidden = self.dropout(self.relu(self.fc1(x)))
    return self.fc2(hidden)

In [4]:
class MultiHeadAttention(nn.Module):
  """Multi-head scaled dot-product self-attention.

  Queries, keys and values are all linear projections of the same input `x`.
  The projections are split into `num_heads` heads of size
  `d_model // num_heads`, attention is computed per head, and the heads are
  concatenated and passed through an output projection.

  Args:
    d_model: feature size of the input and output.
    num_heads: number of attention heads; must divide `d_model`.
    dropout: dropout probability applied to the attention weights.
  """

  def __init__(self, d_model, num_heads, dropout=0.1):
    super().__init__()

    assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

    self.d_model = d_model
    self.num_heads = num_heads
    self.head_dim = d_model // num_heads

    self.W_q = nn.Linear(d_model, d_model)
    self.W_k = nn.Linear(d_model, d_model)
    self.W_v = nn.Linear(d_model, d_model)

    self.W_o = nn.Linear(d_model, d_model)

    self.dropout = nn.Dropout(dropout)

  def split_heads(self, x):
    """Reshape (batch, seq_len, d_model) into (batch, num_heads, seq_len, head_dim)."""
    batch_size, seq_len, _ = x.shape

    x = x.view(batch_size, seq_len, self.num_heads, self.head_dim)

    return x.transpose(1, 2)

  def forward(self, x, mask=None):
    """Apply self-attention to `x`.

    Args:
      x: tensor of shape (batch, seq_len, d_model).
      mask: optional tensor broadcastable to
        (batch, num_heads, seq_len, seq_len). Positions where `mask == 0`
        are excluded from attention; all other positions are kept.

    Returns:
      Tensor of shape (batch, seq_len, d_model).
    """
    Q = self.split_heads(self.W_q(x))
    K = self.split_heads(self.W_k(x))
    V = self.split_heads(self.W_v(x))

    # Scaled dot-product scores: (batch, num_heads, seq_len, seq_len).
    scores = torch.matmul(Q, K.transpose(-2, -1))
    scores = scores / math.sqrt(self.head_dim)

    if mask is not None:
      # Use the most negative finite value of the score dtype instead of a
      # hard-coded -1e9: -1e9 cannot be represented in float16, so the fixed
      # constant is not safe under half precision. A finite fill (rather than
      # -inf) keeps fully masked rows at uniform weights instead of NaN.
      fill_value = torch.finfo(scores.dtype).min
      scores = scores.masked_fill(mask == 0, fill_value)

    attn_weights = torch.softmax(scores, dim=-1)
    attn_weights = self.dropout(attn_weights)

    attn_output = torch.matmul(attn_weights, V)

    # Merge heads back: (batch, num_heads, seq_len, head_dim) -> (batch, seq_len, d_model).
    # transpose() returns a non-contiguous view, so contiguous() is needed before view().
    batch_size = x.shape[0]
    attn_output = attn_output.transpose(1, 2)
    attn_output = attn_output.contiguous().view(
        batch_size, -1, self.d_model
    )

    return self.W_o(attn_output)

In [6]:
class TransformerBlock(nn.Module):
  """Transformer block: a self-attention sub-layer followed by a feed-forward sub-layer.

  Each sub-layer's output goes through dropout, is added back to the
  sub-layer's input (residual connection), and the sum is layer-normalized.

  Args:
    d_model: feature size of the input and output.
    num_heads: number of attention heads passed to `MultiHeadAttention`.
    d_ff: hidden size passed to `FeedForward`.
    dropout: dropout probability used inside both sub-layers and on their outputs.
  """

  def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
    super().__init__()
    # Sub-layers.
    self.attention = MultiHeadAttention(d_model, num_heads, dropout)
    self.ffn = FeedForward(d_model, d_ff, dropout)
    # One normalization per sub-layer.
    self.norm1 = LayerNorm(d_model)
    self.norm2 = LayerNorm(d_model)
    # Shared dropout applied to each sub-layer's output before the residual add.
    self.dropout = nn.Dropout(dropout)

  def forward(self, x, mask=None):
    """Run the block on `x`.

    Args:
      x: input tensor, passed to the attention sub-layer.
      mask: optional attention mask forwarded unchanged to `self.attention`.

    Returns:
      Output of the second layer normalization.
    """
    # Attention sub-layer: residual add, then normalize.
    attended = self.dropout(self.attention(x, mask))
    after_attention = self.norm1(x + attended)

    # Feed-forward sub-layer: residual add, then normalize.
    transformed = self.dropout(self.ffn(after_attention))
    return self.norm2(after_attention + transformed)

In [11]:
# Smoke test: push a random batch through one TransformerBlock and check that
# the output has the same shape as the input.

# Fixed seed so the random input, the weight initialization and the dropout
# masks (the module is left in its default training mode) repeat across runs.
torch.manual_seed(0)

batch_size = 2
seq_len = 5
d_model = 32

x = torch.randn(batch_size, seq_len, d_model)

transformer = TransformerBlock(
    d_model = d_model,  # reuse the size used to build `x` instead of repeating the literal
    num_heads = 4,
    d_ff = 128
)

out = transformer(x)

# The block maps (batch, seq_len, d_model) to the same shape.
assert out.shape == x.shape

print(f"shape: {out.shape}")
print("*"*100)
print(out)

shape: torch.Size([2, 5, 32])
****************************************************************************************************
tensor([[[ 1.0534, -0.6624, -1.4092, -0.4917,  0.5500,  0.0970,  0.1812,
           0.4793,  0.6594, -0.1894, -1.3995,  0.7596,  1.0723, -0.6384,
           0.9827, -0.5358,  2.2224, -0.3691,  0.0763, -1.7024, -1.1696,
           1.1941, -0.3627,  0.1955, -0.3510,  1.8046, -0.8134,  1.0535,
           0.1997,  0.4967, -2.1996, -0.7836],
         [-0.3104,  0.2464, -0.4946, -1.2386, -1.6280,  1.0460, -0.6358,
           1.7946, -1.0773,  0.4617, -0.1341, -2.0943,  1.0217,  0.1462,
           0.9793, -0.2250,  0.2189, -1.6106,  0.5425, -1.3726, -0.1398,
           1.1157,  0.3787,  0.1484,  0.5612,  0.7415, -0.2470, -0.1951,
          -1.3836,  0.0124,  1.5640,  1.8076],
         [ 2.3854, -1.8561, -0.8405, -0.2704,  0.9073,  0.0400,  1.4762,
          -1.4033,  0.0074,  0.0974,  0.1561,  0.4720,  0.8220,  1.0027,
          -0.8401, -1.3973,  0.3289, -0.3035,