In [None]:
## Create a Transformer encoder from scratch using the PyTorch library

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

# Multi-Head Attention module
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    The inputs are projected with separate linear layers into queries, keys
    and values, split into ``num_heads`` heads of size
    ``embed_size // num_heads``, attended per head, concatenated and passed
    through a final linear layer.

    Args:
        embed_size: Size of the last dimension of the inputs and the output.
            Must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
    """

    def __init__(self, embed_size, num_heads):
        super().__init__()
        assert embed_size % num_heads == 0, "Embedding size must be divisible by the number of heads"

        self.num_heads = num_heads
        self.head_dim = embed_size // num_heads

        self.query = nn.Linear(embed_size, embed_size)
        self.key = nn.Linear(embed_size, embed_size)
        self.value = nn.Linear(embed_size, embed_size)
        self.fc_out = nn.Linear(embed_size, embed_size)

    def forward(self, queries, keys, values, mask=None):
        """Apply attention of ``queries`` over ``keys`` / ``values``.

        Args:
            queries: Tensor of shape ``(batch, query_len, embed_size)``.
            keys: Tensor of shape ``(batch, key_len, embed_size)``.
            values: Tensor of shape ``(batch, key_len, embed_size)``.
            mask: Optional tensor broadcastable to
                ``(batch, num_heads, query_len, key_len)``. Positions where
                the mask equals 0 are excluded from attention.

        Returns:
            Tensor of shape ``(batch, query_len, embed_size)``.
        """
        N = queries.shape[0]
        Q = self.query(queries)
        K = self.key(keys)
        V = self.value(values)

        # Split into heads: (N, len, embed) -> (N, heads, len, head_dim)
        Q = Q.view(N, -1, self.num_heads, self.head_dim).transpose(1, 2)
        K = K.view(N, -1, self.num_heads, self.head_dim).transpose(1, 2)
        V = V.view(N, -1, self.num_heads, self.head_dim).transpose(1, 2)

        # Scaled dot-product attention scores: (N, heads, query_len, key_len)
        energy = torch.matmul(Q, K.transpose(-1, -2)) / math.sqrt(self.head_dim)

        if mask is not None:
            # Fill with the most negative finite value of the score dtype.
            # A hard-coded -1e20 is not representable in float16, so it
            # overflows when the scores are half precision (e.g. autocast).
            energy = energy.masked_fill(mask == 0, torch.finfo(energy.dtype).min)

        attention = torch.softmax(energy, dim=-1)
        out = torch.matmul(attention, V)

        # Concatenate heads: (N, heads, query_len, head_dim) -> (N, query_len, embed)
        out = out.transpose(1, 2).contiguous().view(N, -1, self.num_heads * self.head_dim)

        # Final linear layer
        out = self.fc_out(out)
        return out

# Position-wise Feed-Forward Network
class PositionWiseFeedForward(nn.Module):
    """Two linear layers with a ReLU in between, applied to the last dimension.

    Args:
        embed_size: Input and output feature size.
        ff_dim: Size of the hidden layer.
    """

    def __init__(self, embed_size, ff_dim):
        super().__init__()
        self.fc1 = nn.Linear(embed_size, ff_dim)   # expand to the hidden size
        self.fc2 = nn.Linear(ff_dim, embed_size)   # project back to embed_size

    def forward(self, x):
        """Return ``fc2(relu(fc1(x)))``."""
        hidden = self.fc1(x)
        activated = torch.relu(hidden)
        return self.fc2(activated)

# Positional Encoding
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a batch of embeddings.

    Even feature columns hold ``sin(pos * freq)`` and odd columns hold
    ``cos(pos * freq)``, with
    ``freq = exp(-log(10000) * i / embed_size)`` for ``i = 0, 2, 4, ...``.

    Args:
        embed_size: Feature size of the embeddings (even or odd).
        max_len: Number of positions to precompute.
    """

    def __init__(self, embed_size, max_len=5000):
        super().__init__()
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, embed_size, 2).float() * -(math.log(10000.0) / embed_size))
        angles = position * div_term  # (max_len, ceil(embed_size / 2))

        encoding = torch.zeros(max_len, embed_size)
        encoding[:, 0::2] = torch.sin(angles)
        # With an odd embed_size there is one fewer odd column than even
        # columns, so trim the angles to match instead of failing on a shape
        # mismatch. For an even embed_size the slice is a no-op.
        encoding[:, 1::2] = torch.cos(angles[:, : embed_size // 2])

        # Register as a buffer so module.to(...) / .double() move and cast it
        # with the rest of the model. persistent=False keeps it out of
        # state_dict, so the set of checkpoint keys stays the same.
        self.register_buffer("encoding", encoding.unsqueeze(0), persistent=False)

    def forward(self, x):
        """Return ``x`` plus the encodings for its first ``x.size(1)`` positions.

        Args:
            x: Tensor of shape ``(batch, seq_len, embed_size)``; ``seq_len``
                is expected not to exceed ``max_len``.
        """
        seq_len = x.size(1)
        return x + self.encoding[:, :seq_len, :].to(x.device)

# Transformer Encoder Layer
class TransformerEncoderLayer(nn.Module):
    """One encoder layer: self-attention then feed-forward, each followed by
    dropout, a residual connection and layer normalisation.

    Args:
        embed_size: Feature size of the layer's input and output.
        num_heads: Number of attention heads.
        ff_dim: Hidden size of the feed-forward sub-layer.
        dropout: Dropout probability applied to each sub-layer's output.
    """

    def __init__(self, embed_size, num_heads, ff_dim, dropout):
        super().__init__()
        self.attention = MultiHeadAttention(embed_size, num_heads)
        self.norm1 = nn.LayerNorm(embed_size)
        self.norm2 = nn.LayerNorm(embed_size)
        self.feed_forward = PositionWiseFeedForward(embed_size, ff_dim)
        # The same dropout module is shared by both sub-layers.
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """Run the layer on ``x``.

        Args:
            x: Input tensor; used as queries, keys and values of the
                self-attention.
            mask: Optional attention mask forwarded unchanged to the
                attention module. Defaults to ``None`` (no masking), matching
                the attention module's own default.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # Self-attention, then residual add and normalise
        attn_out = self.attention(x, x, x, mask)
        x = self.norm1(x + self.dropout(attn_out))

        # Feed-forward, then residual add and normalise
        ff_out = self.feed_forward(x)
        x = self.norm2(x + self.dropout(ff_out))
        return x

# Full Transformer Encoder
class TransformerEncoder(nn.Module):
    """Token embedding, positional encoding and a stack of encoder layers.

    Args:
        embed_size: Feature size of the embeddings and of every layer.
        num_heads: Number of attention heads per layer.
        ff_dim: Hidden size of each layer's feed-forward sub-layer.
        num_layers: Number of stacked encoder layers.
        vocab_size: Number of rows in the embedding table.
        max_len: Number of positions precomputed by the positional encoding.
        dropout: Dropout probability, used after the positional encoding and
            inside every layer.
    """

    def __init__(self, embed_size, num_heads, ff_dim, num_layers, vocab_size, max_len, dropout):
        super().__init__()
        self.embed_size = embed_size
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.positional_encoding = PositionalEncoding(embed_size, max_len)
        self.layers = nn.ModuleList([
            TransformerEncoderLayer(embed_size, num_heads, ff_dim, dropout)
            for _ in range(num_layers)
        ])
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """Encode a batch of token indices.

        Args:
            x: Integer tensor of token indices, e.g. ``(batch_size, seq_len)``.
            mask: Optional attention mask passed unchanged to every layer.
                Defaults to ``None`` (no masking).

        Returns:
            Tensor of shape ``(*x.shape, embed_size)``.
        """
        # Scale the embeddings by sqrt(embed_size) before adding positions.
        x = self.embedding(x) * math.sqrt(self.embed_size)
        x = self.positional_encoding(x)
        x = self.dropout(x)

        for layer in self.layers:
            x = layer(x, mask)

        return x

# Example usage
if __name__ == "__main__":
    # Model hyperparameters
    embed_size, num_heads, ff_dim = 512, 8, 2048
    num_layers = 6
    vocab_size, max_len = 10000, 100
    dropout = 0.1
    batch_size = 32

    model = TransformerEncoder(
        embed_size=embed_size,
        num_heads=num_heads,
        ff_dim=ff_dim,
        num_layers=num_layers,
        vocab_size=vocab_size,
        max_len=max_len,
        dropout=dropout,
    )

    # Random token ids of shape (batch_size, seq_len)
    src = torch.randint(0, vocab_size, (batch_size, max_len))
    # All-ones mask (nothing masked out); the singleton second dimension
    # is the extra dimension for num_heads.
    src_mask = torch.ones(batch_size, 1, max_len, max_len)

    out = model(src, src_mask)
    print(out.shape)  # Should be (32, max_len, embed_size)

torch.Size([32, 100, 512])
