In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

# Helper function to create positional encoding
def positional_encoding(max_len, d_model):
    position = torch.arange(0, max_len).unsqueeze(1).float()
    div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model))
    pos_enc = torch.zeros((max_len, d_model))
    pos_enc[:, 0::2] = torch.sin(position * div_term)
    pos_enc[:, 1::2] = torch.cos(position * div_term)
    return pos_enc.unsqueeze(0)

# Multi-Head Attention module
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads, dropout=0.1):
        super(MultiHeadAttention, self).__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads

        self.query = nn.Linear(d_model, d_model)
        self.key = nn.Linear(d_model, d_model)
        self.value = nn.Linear(d_model, d_model)
        self.fc_out = nn.Linear(d_model, d_model)

        self.dropout = nn.Dropout(p=dropout)

    def forward(self, query, key, value, mask):
        batch_size = query.shape[0]

        # Linear transformation for query, key, and value
        Q = self.query(query)
        K = self.key(key)
        V = self.value(value)

        # Reshape Q, K, V for multi-head attention
        Q = Q.view(batch_size, -1, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        K = K.view(batch_size, -1, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        V = V.view(batch_size, -1, self.num_heads, self.head_dim).permute(0, 2, 1, 3)

        # Scaled Dot-Product Attention
        energy = torch.einsum('bhid,bhjd->bhij', [Q, K]) / math.sqrt(self.head_dim)
        if mask is not None:
            energy = energy.masked_fill(mask == 0, float('-1e20'))

        attention = F.softmax(energy, dim=-1)
        x = torch.einsum('bhij,bhjd->bhid', [attention, V]).contiguous()

        # Concatenate heads and apply the final linear layer
        x = x.view(batch_size, -1, self.d_model)
        x = self.fc_out(x)
        x = self.dropout(x)
        return x

# Position-wise Feedforward module
class PositionwiseFeedforward(nn.Module):
    def __init__(self, d_model, d_ff, dropout=0.1):
        super(PositionwiseFeedforward, self).__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.dropout = nn.Dropout(p=dropout)
        self.linear2 = nn.Linear(d_ff, d_model)

    def forward(self, x):
        x = F.relu(self.linear1(x))
        x = self.dropout(x)
        x = self.linear2(x)
        return x

# Layer normalization with learnable parameters
class Norm(nn.Module):
    def __init__(self, d_model, eps=1e-6):
        super(Norm, self).__init__()
        self.size = d_model
        self.alpha = nn.Parameter(torch.ones(self.size))
        self.bias = nn.Parameter(torch.zeros(self.size))
        self.eps = eps

    def forward(self, x):
        norm = self.alpha * (x - x.mean(dim=-1, keepdim=True)) / (x.std(dim=-1, keepdim=True) + self.eps) + self.bias
        return norm

# Encoder layer
class EncoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super(EncoderLayer, self).__init__()
        self.norm1 = Norm(d_model)
        self.norm2 = Norm(d_model)
        self.attn = MultiHeadAttention(d_model, num_heads, dropout)
        self.ff = PositionwiseFeedforward(d_model, d_ff, dropout)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x, mask):
        x = x + self.dropout(self.attn(x, x, x, mask))
        x = self.norm1(x)
        x = x + self.dropout(self.ff(x))
        x = self.norm2(x)
        return x

# Decoder layer
class DecoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super(DecoderLayer, self).__init__()
        self.norm1 = Norm(d_model)
        self.norm2 = Norm(d_model)
        self.norm3 = Norm(d_model)
        self.attn1 = MultiHeadAttention(d_model, num_heads, dropout)
        self.attn2 = MultiHeadAttention(d_model, num_heads, dropout)
        self.ff = PositionwiseFeedforward(d_model, d_ff, dropout)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x, enc_output, look_ahead_mask, padding_mask):
        x = x + self.dropout(self.attn1(x, x, x, look_ahead_mask))
        x = self.norm1(x)
        x = x + self.dropout(self.attn2(x, enc_output, enc_output, padding_mask))
        x = self.norm2(x)
        x = x + self.dropout(self.ff(x))
        x = self.norm3(x)
        return x

# Transformer model
class Transformer(nn.Module):
    def __init__(self, num_layers, d_model, num_heads, d_ff, input_vocab_size, target_vocab_size, max_len, dropout=0.1):
        super(Transformer, self).__init__()
        self.embedding = nn.Embedding(input_vocab_size, d_model)
        self.pos_encoding = nn.Parameter(positional_encoding(max_len, d_model), requires_grad=False)

        self.encoder_layers = nn.ModuleList([EncoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)])
        self.decoder_layers = nn.ModuleList([DecoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)])

        self.fc_out = nn.Linear(d_model, target_vocab_size)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, src, trg, src_mask, trg_mask):
        src = self.embedding(src) + self.pos_encoding[:, :src.size(1), :]
        trg = self.embedding(trg) + self.pos_encoding[:, :trg.size(1), :]

        for layer in self.encoder_layers:
            src = layer(src, src_mask)

        for layer in self.decoder_layers:
            trg = layer(trg, src, trg_mask, src_mask)

        output = self.fc_out(trg)
        return output

# Create an instance of the Transformer model
num_layers = 6
d_model = 512
num_heads = 8
d_ff = 2048
input_vocab_size = 10000  # Adjust as needed
target_vocab_size = 10000  # Adjust as needed
max_len = 100
dropout = 0.1

transformer_model = Transformer(num_layers, d_model, num_heads, d_ff, input_vocab_size, target_vocab_size, max_len, dropout)

# Print the model architecture
print(transformer_model)


Transformer(
  (embedding): Embedding(10000, 512)
  (encoder_layers): ModuleList(
    (0-5): 6 x EncoderLayer(
      (norm1): Norm()
      (norm2): Norm()
      (attn): MultiHeadAttention(
        (query): Linear(in_features=512, out_features=512, bias=True)
        (key): Linear(in_features=512, out_features=512, bias=True)
        (value): Linear(in_features=512, out_features=512, bias=True)
        (fc_out): Linear(in_features=512, out_features=512, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (ff): PositionwiseFeedforward(
        (linear1): Linear(in_features=512, out_features=2048, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (linear2): Linear(in_features=2048, out_features=512, bias=True)
      )
      (dropout): Dropout(p=0.1, inplace=False)
    )
  )
  (decoder_layers): ModuleList(
    (0-5): 6 x DecoderLayer(
      (norm1): Norm()
      (norm2): Norm()
      (norm3): Norm()
      (attn1): MultiHeadAttention(
        (query):