In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class TransformerEncoderLayer(nn.Module):
    """Post-norm Transformer encoder layer: self-attention, then a position-wise FFN.

    Each sub-layer is applied as ``norm(x + dropout(sublayer(x)))``. Tensors use the
    ``nn.MultiheadAttention`` default layout (``batch_first=False``), i.e.
    ``(seq_len, batch, d_model)``.

    Args:
        d_model: Width of the input/output features.
        nhead: Number of attention heads (``d_model`` is split across them).
        dim_feedforward: Hidden width of the feed-forward sub-layer.
        dropout: Dropout probability used inside attention, inside the FFN and on
            both residual branches.
    """

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.activation = F.relu

    def forward(self, src, src_mask=None, src_key_padding_mask=None):
        """Encode ``src`` with one self-attention + feed-forward block.

        Args:
            src: Input features, ``(seq_len, batch, d_model)``.
            src_mask: Optional attention mask forwarded to ``attn_mask`` of
                ``nn.MultiheadAttention`` (``(seq_len, seq_len)`` or per-head 3-D).
            src_key_padding_mask: Optional ``(batch, seq_len)`` mask; ``True``
                marks padded positions that must not be attended to.

        Returns:
            Tensor with the same shape as ``src``.
        """
        # need_weights=False: only the attention output is used, so skip building
        # the averaged attention-weight map.
        attn_out = self.self_attn(
            src, src, src,
            attn_mask=src_mask,
            key_padding_mask=src_key_padding_mask,
            need_weights=False,
        )[0]
        src = self.norm1(src + self.dropout(attn_out))

        # Position-wise feed-forward network.
        ff_out = self.linear2(self.dropout(self.activation(self.linear1(src))))
        src = self.norm2(src + self.dropout(ff_out))

        return src


class TransformerDecoderLayer(nn.Module):
    """Post-norm Transformer decoder layer.

    Three sub-layers, each applied as ``norm(x + dropout(sublayer(x)))``:
    self-attention over the target (causally masked by default), cross-attention
    from target queries to ``memory``, and a position-wise feed-forward network.
    Tensors use the ``nn.MultiheadAttention`` default layout (``batch_first=False``),
    i.e. ``(seq_len, batch, d_model)``.

    Args:
        d_model: Width of the input/output features.
        nhead: Number of attention heads (``d_model`` is split across them).
        dim_feedforward: Hidden width of the feed-forward sub-layer.
        dropout: Dropout probability used inside attention, inside the FFN and on
            every residual branch.
    """

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.activation = F.relu

    def forward(self, tgt, memory, tgt_mask=None, memory_mask=None,
                tgt_key_padding_mask=None, memory_key_padding_mask=None,
                causal=True):
        """Decode ``tgt`` against encoder output ``memory``.

        Args:
            tgt: Target features, ``(tgt_len, batch, d_model)``.
            memory: Encoder output, ``(src_len, batch, d_model)``.
            tgt_mask: Optional self-attention mask. When given it is used as-is
                and ``causal`` is ignored.
            memory_mask: Optional cross-attention mask, ``(tgt_len, src_len)``.
            tgt_key_padding_mask: Optional ``(batch, tgt_len)`` mask; ``True``
                marks padded target positions.
            memory_key_padding_mask: Optional ``(batch, src_len)`` mask; ``True``
                marks padded source positions.
            causal: When ``tgt_mask`` is ``None`` and this is true (the default),
                position ``i`` may only attend to target positions ``<= i``, so
                teacher-forced training cannot peek at future tokens. Pass
                ``False`` for unmasked (bidirectional) self-attention.

        Returns:
            Tensor with the same shape as ``tgt``.
        """
        if tgt_mask is None and causal:
            # Boolean mask: True marks a forbidden (query i, key j) pair, i.e. j > i.
            tgt_len = tgt.size(0)
            tgt_mask = torch.triu(
                torch.ones(tgt_len, tgt_len, dtype=torch.bool, device=tgt.device),
                diagonal=1,
            )

        # Self-attention over the target sequence. need_weights=False skips the
        # unused averaged attention-weight map.
        self_out = self.self_attn(
            tgt, tgt, tgt,
            attn_mask=tgt_mask,
            key_padding_mask=tgt_key_padding_mask,
            need_weights=False,
        )[0]
        tgt = self.norm1(tgt + self.dropout(self_out))

        # Cross-attention: target queries attend to the encoder memory.
        cross_out = self.multihead_attn(
            tgt, memory, memory,
            attn_mask=memory_mask,
            key_padding_mask=memory_key_padding_mask,
            need_weights=False,
        )[0]
        tgt = self.norm2(tgt + self.dropout(cross_out))

        # Position-wise feed-forward network.
        ff_out = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
        tgt = self.norm3(tgt + self.dropout(ff_out))

        return tgt


class Transformer(nn.Module):
    """Encoder-decoder Transformer mapping token ids to per-position vocab logits.

    Token ids are laid out as ``(seq_len, batch)``, matching the
    ``batch_first=False`` layout of the attention layers. The output has shape
    ``(tgt_seq_len, batch, vocab_size)``.

    Fixed sinusoidal positional encodings are added to both the source and the
    target embeddings. Without them, attention is order-agnostic and the model
    could not distinguish permutations of the source sequence.

    Args:
        num_encoder_layers: Number of stacked encoder layers.
        num_decoder_layers: Number of stacked decoder layers.
        d_model: Embedding / hidden width.
        nhead: Number of attention heads per layer.
        vocab_size: Size of the shared source/target vocabulary.
        dim_feedforward: Hidden width of each layer's feed-forward sub-layer.
        dropout: Dropout probability passed to every layer.
    """

    def __init__(self, num_encoder_layers, num_decoder_layers, d_model, nhead, vocab_size, dim_feedforward=2048, dropout=0.1):
        super().__init__()
        # Token embeddings for the source ("encoder") and target ("decoder") sides.
        # Attribute names are kept so existing state_dict keys still load.
        self.encoder = nn.Embedding(vocab_size, d_model)
        self.decoder = nn.Embedding(vocab_size, d_model)

        self.encoder_layers = nn.ModuleList([
            TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout)
            for _ in range(num_encoder_layers)
        ])

        self.decoder_layers = nn.ModuleList([
            TransformerDecoderLayer(d_model, nhead, dim_feedforward, dropout)
            for _ in range(num_decoder_layers)
        ])

        self.fc_out = nn.Linear(d_model, vocab_size)

    @staticmethod
    def _positional_encoding(seq_len, d_model, device=None, dtype=torch.float32):
        """Return the ``(seq_len, d_model)`` sinusoidal position table.

        Feature ``2i`` holds ``sin(pos / 10000**(2i / d_model))`` and feature
        ``2i + 1`` holds the matching cosine. When ``d_model`` is odd, the final
        feature is a sine with no paired cosine. Computed in float32, then cast
        to ``dtype``.
        """
        position = torch.arange(seq_len, device=device, dtype=torch.float32).unsqueeze(1)
        inv_freq = 1.0 / (
            10000.0 ** (torch.arange(0, d_model, 2, device=device, dtype=torch.float32) / d_model)
        )
        angles = position * inv_freq  # (seq_len, ceil(d_model / 2))
        pe = torch.zeros(seq_len, d_model, device=device, dtype=torch.float32)
        pe[:, 0::2] = torch.sin(angles)
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        return pe.to(dtype)

    def _embed(self, embedding, tokens):
        """Embed ``tokens`` (sequence on dim 0) and add positional encodings."""
        emb = embedding(tokens)
        # Computed on the fly rather than stored as a buffer so that the
        # module's parameters and state_dict are unchanged.
        pe = self._positional_encoding(emb.size(0), emb.size(-1), device=emb.device, dtype=emb.dtype)
        # Insert singleton axes between the sequence and feature dims so the table
        # broadcasts over the batch axis (or adds directly for unbatched input).
        pe = pe.view(pe.size(0), *([1] * (emb.dim() - 2)), pe.size(1))
        return emb + pe

    def forward(self, src, tgt):
        """Compute vocabulary logits for every target position.

        Args:
            src: Source token ids, ``(src_seq_len, batch)``.
            tgt: Target token ids, ``(tgt_seq_len, batch)``.

        Returns:
            Logits of shape ``(tgt_seq_len, batch, vocab_size)``.
        """
        # Encode the source sequence.
        memory = self._embed(self.encoder, src)
        for layer in self.encoder_layers:
            memory = layer(memory)

        # Decode the target sequence against the encoder memory.
        tgt_emb = self._embed(self.decoder, tgt)
        for layer in self.decoder_layers:
            tgt_emb = layer(tgt_emb, memory)

        # Project to vocabulary logits.
        return self.fc_out(tgt_emb)


# ---- Transformer configuration ----------------------------------------------
d_model, nhead = 512, 8                        # model width, attention heads per layer
num_encoder_layers = num_decoder_layers = 6    # depth of the encoder and decoder stacks
vocab_size = 10000                             # token ids are drawn from [0, vocab_size)
dim_feedforward, dropout = 2048, 0.1           # FFN hidden width, dropout probability

# ---- Build the model (positional order matches the constructor signature) ---
model = Transformer(
    num_encoder_layers, num_decoder_layers, d_model, nhead,
    vocab_size, dim_feedforward, dropout,
)

# ---- Random token ids laid out as (sequence_length, batch_size) -------------
# The source (length 10) is drawn first, then the target (length 20); batch is 32.
src, tgt = (torch.randint(0, vocab_size, (seq_len, 32)) for seq_len in (10, 20))

# ---- One forward pass; expect (target_seq_length, batch_size, vocab_size) ---
output = model(src, tgt)
print(output.shape)


torch.Size([20, 32, 10000])
