In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    Projects queries, keys and values with separate linear layers, splits the
    ``d_model`` features into ``num_heads`` heads of size ``depth``, applies
    scaled dot-product attention per head, then concatenates the heads and
    applies a final linear projection (``dense``).

    Args:
        d_model: Model (feature) dimension; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
    """

    def __init__(self, d_model, num_heads):
        super(MultiHeadAttention, self).__init__()
        assert d_model % num_heads == 0, (
            f"d_model ({d_model}) must be divisible by num_heads ({num_heads})"
        )

        self.d_model = d_model
        self.num_heads = num_heads
        self.depth = d_model // num_heads  # per-head feature size

        self.wq = nn.Linear(d_model, d_model)
        self.wk = nn.Linear(d_model, d_model)
        self.wv = nn.Linear(d_model, d_model)

        self.dense = nn.Linear(d_model, d_model)

    def split_heads(self, x, batch_size):
        """Reshape ``(batch, seq, d_model)`` into ``(batch, num_heads, seq, depth)``."""
        x = x.view(batch_size, -1, self.num_heads, self.depth)
        return x.permute(0, 2, 1, 3)

    def forward(self, q, k, v, mask=None):
        """Attend from ``q`` over ``k`` / ``v``.

        Args:
            q, k, v: Tensors whose last dimension is ``d_model``, presumably
                ``(batch, seq_len, d_model)``; ``k`` and ``v`` must have the
                same sequence length.
            mask: Optional tensor broadcastable to
                ``(batch, num_heads, q_len, k_len)``. Positions where
                ``mask == 0`` receive (effectively) zero attention weight.

        Returns:
            Tensor of shape ``(batch, q_len, d_model)``.
        """
        batch_size = q.size(0)

        q = self.split_heads(self.wq(q), batch_size)
        k = self.split_heads(self.wk(k), batch_size)
        v = self.split_heads(self.wv(v), batch_size)

        # A plain Python float avoids allocating a new (CPU) tensor on every call.
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.depth)
        if mask is not None:
            # Use the most negative finite value rather than -inf: with -inf, a
            # query row whose keys are all masked softmaxes to NaN, and those NaNs
            # then leak into other positions (0 * NaN) in subsequent layers.
            # For partially masked rows the result is unchanged (exp underflows to 0).
            scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)

        attention_weights = F.softmax(scores, dim=-1)
        output = torch.matmul(attention_weights, v)

        # (batch, heads, seq, depth) -> (batch, seq, heads * depth)
        output = output.permute(0, 2, 1, 3).contiguous()
        output = output.view(batch_size, -1, self.d_model)
        return self.dense(output)

class FeedForwardNetwork(nn.Module):
    """Position-wise feed-forward sublayer: Linear -> ReLU -> Dropout -> Linear.

    Args:
        d_model: Input and output feature size.
        d_ff: Hidden (inner) feature size.
        dropout: Dropout probability applied to the hidden activations.
    """

    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        hidden = self.dropout(F.relu(self.fc1(x)))
        return self.fc2(hidden)

class EncoderLayer(nn.Module):
    """One post-norm Transformer encoder layer.

    Self-attention followed by a position-wise feed-forward network. Each
    sublayer's output goes through dropout, is added to that sublayer's input
    (residual connection), and the sum is layer-normalised.

    Args:
        d_model: Feature size of the input and output.
        num_heads: Number of attention heads.
        d_ff: Hidden size of the feed-forward network.
        dropout: Dropout probability for the residual branches (also passed to
            the feed-forward network).
    """

    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        self.mha = MultiHeadAttention(d_model, num_heads)
        self.ffn = FeedForwardNetwork(d_model, d_ff, dropout)

        self.layernorm1 = nn.LayerNorm(d_model)
        self.layernorm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        # Sublayer 1: self-attention -> dropout -> residual add -> norm.
        attended = self.layernorm1(x + self.dropout1(self.mha(x, x, x, mask)))
        # Sublayer 2: feed-forward -> dropout -> residual add -> norm.
        transformed = self.dropout2(self.ffn(attended))
        return self.layernorm2(attended + transformed)
    

class PositionalEncoding(nn.Module):
    """Adds fixed sinusoidal positional encodings to an input sequence.

    Feature index ``2i`` holds ``sin(pos * w_i)`` and index ``2i + 1`` holds
    ``cos(pos * w_i)``, with ``w_i = 10000^(-2i / d_model)``. The table is
    precomputed for ``max_len`` positions and registered as a buffer (part of
    the module state, not a trainable parameter).

    Args:
        d_model: Feature size of the inputs; odd values are supported.
        max_len: Longest sequence length supported.
    """

    def __init__(self, d_model, max_len=5000):
        super(PositionalEncoding, self).__init__()

        # (max_len, d_model) table of encodings.
        pe = torch.zeros(max_len, d_model)

        # Column vector of positions, shape (max_len, 1).
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)

        # One frequency per even feature index: exp(-2i * ln(10000) / d_model).
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # When d_model is odd there is one fewer odd column than even column, so
        # only the first d_model // 2 frequencies feed the cosine columns.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])

        # Leading batch dimension -> (1, max_len, d_model) so it broadcasts over batches.
        pe = pe.unsqueeze(0)

        # Buffer: moves with the module and is saved in state_dict, but is not trained.
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Return ``x`` plus the encodings for its first ``x.size(1)`` positions.

        Args:
            x: Tensor of shape ``(batch, seq_len, d_model)`` with
                ``seq_len <= max_len``.

        Raises:
            ValueError: If ``seq_len`` exceeds the precomputed ``max_len``.
        """
        seq_len = x.size(1)
        max_len = self.pe.size(1)
        if seq_len > max_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds PositionalEncoding max_len {max_len}"
            )
        return x + self.pe[:, :seq_len, :].to(x.device)

class TransformerEncoder(nn.Module):
    """Stack of ``num_layers`` encoder layers over token embeddings.

    Pipeline: token embedding -> sinusoidal positional encoding -> dropout ->
    ``num_layers`` ``EncoderLayer`` blocks applied in sequence.

    NOTE(review): unlike the original Transformer paper, the embeddings are not
    multiplied by sqrt(d_model) before the positional encoding is added.

    Args:
        num_layers: Number of stacked encoder layers.
        d_model: Embedding / model feature size.
        num_heads: Attention heads per layer.
        d_ff: Hidden size of each feed-forward sublayer.
        input_vocab_size: Number of rows in the token embedding table.
        max_seq_length: Longest input sequence the positional encoding supports.
        dropout: Dropout probability used throughout.
    """

    def __init__(self, num_layers, d_model, num_heads, d_ff, input_vocab_size, max_seq_length, dropout=0.1):
        super(TransformerEncoder, self).__init__()

        self.embedding = nn.Embedding(input_vocab_size, d_model)
        self.pos_encoding = PositionalEncoding(d_model, max_seq_length)

        self.enc_layers = nn.ModuleList([EncoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)])
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """Encode a batch of token ids.

        Args:
            x: Integer token ids of shape ``(batch, seq_len)`` with
                ``seq_len <= max_seq_length``.
            mask: Optional attention mask passed unchanged to every layer.
                Entries equal to 0 are masked out, and the mask must broadcast
                to ``(batch, num_heads, seq_len, seq_len)``. For example, a
                padding mask has shape ``(batch, 1, 1, seq_len)``.

        Returns:
            Tensor of shape ``(batch, seq_len, d_model)``.
        """
        x = self.embedding(x)
        x = self.pos_encoding(x)
        x = self.dropout(x)

        for layer in self.enc_layers:
            x = layer(x, mask)

        return x


In [2]:
# Define parameters
SEED = 0  # fixes weight init and the dummy batch so re-runs reproduce the outputs below
num_layers = 6
d_model = 512
num_heads = 8
d_ff = 2048
input_vocab_size = 10000
max_seq_length = 100
batch_size = 64

torch.manual_seed(SEED)

# Initialize encoder
encoder = TransformerEncoder(num_layers, d_model, num_heads, d_ff, input_vocab_size, max_seq_length)

# Dummy input: random token ids, shape (batch_size, max_seq_length)
input_seq = torch.randint(0, input_vocab_size, (batch_size, max_seq_length))

# Forward pass (module is in its default training mode, so dropout is active)
output = encoder(input_seq)

# Sanity check: one d_model-sized vector per input token
assert output.shape == (batch_size, max_seq_length, d_model)

In [4]:
# Expected: (batch, seq_len, d_model)
output.shape

torch.Size([64, 100, 512])

In [5]:
# Expected: (batch, seq_len) of integer token ids
input_seq.shape

torch.Size([64, 100])

In [9]:
# Token ids of the first sequence in the batch (drawn with torch.randint)
input_seq[0]

tensor([4665, 8070, 9905, 2164, 2671, 7713, 1857, 3571, 7090, 4751, 1240, 5461,
        3593, 5304, 6606, 4758, 2426,  375, 2734,  763,  738, 3224, 8400, 6668,
         202, 3738, 4653, 4162, 5298,  782, 9596, 7926, 4435, 8175, 9134, 4460,
        3636, 7937, 5878, 5154, 3400, 9042, 5594, 3278, 4849, 9870, 6618, 1064,
        1182, 2710, 1038, 8674, 9894, 8474, 7426, 5926,  210, 2739, 8315, 9734,
        1980, 9773, 5329, 2463, 5059, 9857, 8080, 3484, 3668, 4355,  687, 1623,
        2846, 8195, 9184, 3687,  884, 7340, 4699, 5926, 7479, 3179, 1759, 1152,
        4724, 7301, 2162, 8986, 9426, 9459,  183, 1523, 9217, 6525, 5674, 1061,
        2011, 9003, 4859, 8230])