In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

In [None]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    Query, key and value are each passed through their own linear projection,
    split into ``num_heads`` heads of width ``d_model // num_heads``, attended
    per head, merged back together and sent through an output projection.

    Args:
        d_model: Size of the last (feature) dimension of inputs and outputs.
        num_heads: Number of attention heads; must divide ``d_model`` evenly.

    Raises:
        AssertionError: If ``d_model`` is not divisible by ``num_heads``.
    """

    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.d_model = d_model
        self.num_heads = num_heads
        self.depth = d_model // num_heads  # per-head feature width

        self.query_linear = nn.Linear(d_model, d_model)
        self.key_linear = nn.Linear(d_model, d_model)
        self.value_linear = nn.Linear(d_model, d_model)

        self.out_linear = nn.Linear(d_model, d_model)

    def scaled_dot_product_attention(self, query, key, value, mask=None):
        """Compute ``softmax(Q K^T / sqrt(depth)) V``.

        Args:
            query, key, value: Tensors whose last two dimensions are
                ``(length, depth)``; leading dimensions must be compatible
                for ``torch.matmul``.
            mask: Optional tensor broadcastable to the score tensor. Positions
                where ``mask == 0`` are excluded from attention.

        Returns:
            Tuple ``(output, attention_weights)``.
        """
        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.depth)
        if mask is not None:
            # Fill with the most negative finite value of the score dtype.
            # A hard-coded -1e9 is not representable in float16 and makes
            # masked_fill raise an overflow error under half precision.
            fill_value = torch.finfo(scores.dtype).min
            scores = scores.masked_fill(mask == 0, fill_value)
        attention_weights = F.softmax(scores, dim=-1)
        output = torch.matmul(attention_weights, value)
        return output, attention_weights

    def split_heads(self, x):
        """Reshape ``(batch, length, d_model)`` to ``(batch, num_heads, length, depth)``."""
        batch_size = x.size(0)
        return x.view(batch_size, -1, self.num_heads, self.depth).transpose(1, 2)

    def forward(self, query, key, value, mask=None):
        """Apply multi-head attention.

        Args:
            query, key, value: Tensors with ``d_model`` as the last dimension
                and the batch as the first dimension.
            mask: Optional mask forwarded to
                :meth:`scaled_dot_product_attention`; it must broadcast
                against scores of shape
                ``(batch, num_heads, query_len, key_len)``.

        Returns:
            Tensor of shape ``(batch, query_len, d_model)``.
        """
        batch_size = query.size(0)

        query = self.split_heads(self.query_linear(query))
        key = self.split_heads(self.key_linear(key))
        value = self.split_heads(self.value_linear(value))

        attention_output, _ = self.scaled_dot_product_attention(query, key, value, mask)

        # Merge heads: (batch, heads, length, depth) -> (batch, length, d_model).
        # transpose() yields a non-contiguous view, so contiguous() is needed
        # before view().
        attention_output = (
            attention_output.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
        )

        return self.out_linear(attention_output)

In [None]:
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a batch of embeddings.

    Even feature indices hold ``sin(pos * freq)`` and odd indices hold
    ``cos(pos * freq)``, with frequencies ``10000 ** (-2i / d_model)``.

    Args:
        d_model: Size of the feature dimension (odd values are supported).
        max_len: Number of positions to precompute.
    """

    def __init__(self, d_model, max_len=5000):
        super().__init__()
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model))
        angles = position * div_term  # (max_len, ceil(d_model / 2))

        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer odd column than even columns, so
        # drop the surplus frequency instead of failing on a shape mismatch.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])

        # Register as a buffer so the table follows module.to()/.half()/.cuda().
        # persistent=False keeps it out of state_dict, so checkpoints saved
        # before this change still load with strict=True.
        self.register_buffer("pe", pe.unsqueeze(0), persistent=False)

    def forward(self, x):
        """Return ``x`` plus the encodings for its first ``x.size(1)`` positions.

        Args:
            x: Tensor of shape ``(batch, length, d_model)``.
        """
        # .to() is a no-op when the buffer already lives on x's device; it is
        # kept so inputs on a different device than the module still work.
        return x + self.pe[:, :x.size(1)].to(x.device)

In [None]:
class TransformerEncoderLayer(nn.Module):
    """One encoder block: self-attention followed by a feed-forward network.

    Each sub-layer output goes through dropout, is added back to the
    sub-layer input, and the sum is layer-normalised.

    Args:
        d_model: Feature size of the input and output.
        num_heads: Number of attention heads.
        dim_feedforward: Hidden width of the feed-forward network.
        dropout: Dropout probability applied to each sub-layer output.
    """

    def __init__(self, d_model, num_heads, dim_feedforward=2048, dropout=0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)

        # Two-layer MLP: expand to dim_feedforward, ReLU, project back.
        expand = nn.Linear(d_model, dim_feedforward)
        project = nn.Linear(dim_feedforward, d_model)
        self.feedforward = nn.Sequential(expand, nn.ReLU(), project)

        self.layer_norm1 = nn.LayerNorm(d_model)
        self.layer_norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, src, mask=None):
        """Run the block on ``src``; ``mask`` is passed to the attention module."""
        # Sub-layer 1: self-attention with residual connection and norm.
        attended = self.dropout(self.self_attn(src, src, src, mask))
        hidden = self.layer_norm1(src + attended)

        # Sub-layer 2: feed-forward with residual connection and norm.
        transformed = self.dropout(self.feedforward(hidden))
        return self.layer_norm2(hidden + transformed)


In [None]:
class TransformerEncoder(nn.Module):
    """Positional encoding followed by a stack of encoder layers.

    Args:
        d_model: Feature size of the input and output.
        num_heads: Number of attention heads in every layer.
        num_layers: How many encoder layers to stack.
        dim_feedforward: Hidden width of each layer's feed-forward network.
        dropout: Dropout probability used inside each layer.
    """

    def __init__(self, d_model, num_heads, num_layers, dim_feedforward=2048, dropout=0.1):
        super().__init__()
        self.layers = nn.ModuleList()
        for _ in range(num_layers):
            self.layers.append(
                TransformerEncoderLayer(d_model, num_heads, dim_feedforward, dropout)
            )
        self.positional_encoding = PositionalEncoding(d_model)

    def forward(self, src, mask=None):
        """Add position information to ``src`` and run it through every layer in order."""
        hidden = self.positional_encoding(src)
        for block in self.layers:
            hidden = block(hidden, mask)
        return hidden

In [None]:
class Transformer(nn.Module):
    """Encoder-only model: token embedding, encoder stack, projection to vocabulary logits.

    Args:
        d_model: Embedding and hidden feature size.
        num_heads: Number of attention heads per encoder layer.
        num_encoder_layers: Number of stacked encoder layers.
        vocab_size: Size of the token vocabulary (embedding rows and output logits).
        dim_feedforward: Hidden width of each layer's feed-forward network.
        dropout: Dropout probability used inside the encoder layers.
    """

    def __init__(self, d_model, num_heads, num_encoder_layers, vocab_size, dim_feedforward=2048, dropout=0.1):
        super().__init__()
        self.embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=d_model)
        self.encoder = TransformerEncoder(
            d_model,
            num_heads,
            num_encoder_layers,
            dim_feedforward,
            dropout,
        )
        self.fc_out = nn.Linear(in_features=d_model, out_features=vocab_size)

    def forward(self, src, src_mask=None):
        """Map token ids ``src`` to logits; ``src_mask`` is forwarded to the encoder."""
        token_vectors = self.embedding(src)
        contextual = self.encoder(token_vectors, src_mask)
        return self.fc_out(contextual)


In [None]:
# Demo: build the model and push one batch of random token ids through it.
d_model, num_heads = 512, 8
num_encoder_layers = 6
vocab_size = 10000

# 32 sequences of 10 token ids each, drawn uniformly from [0, vocab_size).
src_sequence = torch.randint(low=0, high=vocab_size, size=(32, 10))

transformer = Transformer(
    d_model=d_model,
    num_heads=num_heads,
    num_encoder_layers=num_encoder_layers,
    vocab_size=vocab_size,
)
output = transformer(src_sequence)
print(output.shape)  # (batch_size, sequence_length, vocab_size)


torch.Size([32, 10, 10000])
