In [13]:
import torch
import torch.nn as nn

class SimpleTransformer(nn.Module):
    def __init__(self, src_vocab_size, tgt_vocab_size, d_model, nhead, num_encoder_layers, num_decoder_layers, dim_feedforward, dropout=0.1):
        super(SimpleTransformer, self).__init__()

        # Embeddings
        self.src_embedding = nn.Embedding(src_vocab_size, d_model)
        self.tgt_embedding = nn.Embedding(tgt_vocab_size, d_model)

        # Encoder
        encoder_layer = nn.TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout)
        self.encoder = nn.TransformerEncoder(encoder_layer, num_encoder_layers)

        # Decoder
        decoder_layer = nn.TransformerDecoderLayer(d_model, nhead, dim_feedforward, dropout)
        self.decoder = nn.TransformerDecoder(decoder_layer, num_decoder_layers)

        # Output layer
        self.out = nn.Linear(d_model, tgt_vocab_size)

    def forward(self, src, tgt, src_mask=None, tgt_mask=None, memory_mask=None, src_key_padding_mask=None, tgt_key_padding_mask=None, memory_key_padding_mask=None):
        # Embed the source and target sequences
        src_embed = self.src_embedding(src)
        tgt_embed = self.tgt_embedding(tgt)

        # Encode the source sequence
        memory = self.encoder(src_embed, mask=src_mask, src_key_padding_mask=src_key_padding_mask)

        # Decode the target sequence
        output = self.decoder(tgt_embed, memory, tgt_mask=tgt_mask, memory_mask=memory_mask, 
                             tgt_key_padding_mask=tgt_key_padding_mask,
                             memory_key_padding_mask=memory_key_padding_mask)

        # Apply linear transformation to get final output
        output = self.out(output)
        
        return output
    
# example source and target sequences
src_seq = torch.tensor([[1, 2, 3, 4, 5, 0, 0]])  
tgt_seq = torch.tensor([[1, 2, 3, 4, 0, 0, 0]])

src_mask = (src_seq != 0).unsqueeze(1)  
tgt_mask = torch.tril(torch.ones((tgt_seq.size(1), tgt_seq.size(1)), dtype=torch.bool)).unsqueeze(0)  

# instantiate the model
model1 = SimpleTransformer(src_vocab_size=100, tgt_vocab_size=100, d_model=512, nhead=8, num_encoder_layers=6, num_decoder_layers=6, dim_feedforward=2048)
output1 = model1(src_seq, tgt_seq, src_mask=src_mask, tgt_mask=tgt_mask)
print("Output from Model 1:", output1)

RuntimeError: The shape of the 3D attn_mask is torch.Size([1, 1, 7]), but should be (56, 1, 1).

In [7]:
import torch
import torch.nn as nn
import math

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super(MultiHeadAttention, self).__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.fc_out = nn.Linear(d_model, d_model)
        
    def forward(self, Q, K, V, mask=None):
        N = Q.shape[0]
        Q_len, K_len, V_len = Q.shape[1], K.shape[1], V.shape[1]
        
        Q = self.W_q(Q).view(N, Q_len, self.num_heads, self.head_dim).transpose(1, 2)  # (N, num_heads, Q_len, head_dim)
        K = self.W_k(K).view(N, K_len, self.num_heads, self.head_dim).transpose(1, 2)  # (N, num_heads, K_len, head_dim)
        V = self.W_v(V).view(N, V_len, self.num_heads, self.head_dim).transpose(1, 2)  # (N, num_heads, V_len, head_dim)
        
        energy = torch.einsum("nqhd,nkhd->nhqk", [Q, K])
        
        if mask is not None:
            energy = energy.masked_fill(mask == 0, float('-1e20'))
        
        attention = torch.softmax(energy / (self.head_dim ** (1/2)), dim=-1)
        out = torch.einsum("nhql,nlhd->nqhd", [attention, V]).reshape(N, Q_len, self.d_model)
        
        return self.fc_out(out)
    
class FeedForward(nn.Module):
    def __init__(self, d_model, d_ff, dropout=0.1):
        super(FeedForward, self).__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x):
        return self.fc2(self.dropout(torch.relu(self.fc1(x))))
    
class EncoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout):
        super(EncoderLayer, self).__init__()
        self.mha = MultiHeadAttention(d_model, num_heads)
        self.ff = FeedForward(d_model, d_ff, dropout)
        self.layernorm1 = nn.LayerNorm(d_model)
        self.layernorm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x, mask):
        attn_output = self.mha(x, x, x, mask)
        x = self.layernorm1(x + self.dropout(attn_output))
        
        ff_output = self.ff(x)
        x = self.layernorm2(x + self.dropout(ff_output))
        
        return x

class DecoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout):
        super(DecoderLayer, self).__init__()
        self.masked_mha = MultiHeadAttention(d_model, num_heads)
        self.mha = MultiHeadAttention(d_model, num_heads)
        self.ff = FeedForward(d_model, d_ff, dropout)
        self.layernorm1 = nn.LayerNorm(d_model)
        self.layernorm2 = nn.LayerNorm(d_model)
        self.layernorm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x, enc_out, src_mask, tgt_mask):
        attn_output = self.masked_mha(x, x, x, tgt_mask)
        x = self.layernorm1(x + self.dropout(attn_output))
        
        attn_output = self.mha(x, enc_out, enc_out, src_mask)
        x = self.layernorm2(x + self.dropout(attn_output))
        
        ff_output = self.ff(x)
        x = self.layernorm3(x + self.dropout(ff_output))
        
        return x

class Encoder(nn.Module):
    def __init__(self, num_layers, d_model, num_heads, d_ff, input_vocab_size, max_len, dropout):
        super(Encoder, self).__init__()
        self.embedding = nn.Embedding(input_vocab_size, d_model)
        self.positional_encoding = self._generate_positional_encoding(max_len, d_model)
        self.layers = nn.ModuleList([EncoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)])
        self.dropout = nn.Dropout(dropout)
        
    def _generate_positional_encoding(self, max_len, d_model):
        pos = torch.arange(0, max_len).unsqueeze(1)
        i = torch.arange(0, d_model // 2).unsqueeze(0)
        angle_rates = 1 / torch.pow(10000, (2 * i) / d_model)
        pos_encoding = pos * angle_rates
        pos_encoding[..., 0::2] = torch.sin(pos_encoding[..., 0::2])
        pos_encoding[..., 1::2] = torch.cos(pos_encoding[..., 1::2])
        return pos_encoding.unsqueeze(0)
        
    def forward(self, x, mask):
        seq_len = x.shape[1]
        x = self.embedding(x)
        x += self.positional_encoding[:, :seq_len, :].to(x.device)
        x = self.dropout(x)
        for layer in self.layers:
            x = layer(x, mask)
        return x

class Decoder(nn.Module):
    def __init__(self, num_layers, d_model, num_heads, d_ff, target_vocab_size, max_len, dropout):
        super(Decoder, self).__init__()
        self.embedding = nn.Embedding(target_vocab_size, d_model)
        self.positional_encoding = self._generate_positional_encoding(max_len, d_model)
        self.layers = nn.ModuleList([DecoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)])
        self.dropout = nn.Dropout(dropout)
        
    def _generate_positional_encoding(self, max_len, d_model):
        pos = torch.arange(0, max_len).unsqueeze(1)
        i = torch.arange(0, d_model // 2).unsqueeze(0)
        angle_rates = 1 / torch.pow(10000, (2 * i) / d_model)
        pos_encoding = pos * angle_rates
        pos_encoding[..., 0::2] = torch.sin(pos_encoding[..., 0::2])
        pos_encoding[..., 1::2] = torch.cos(pos_encoding[..., 1::2])
        return pos_encoding.unsqueeze(0)
        
    def forward(self, x, enc_out, src_mask, tgt_mask):
        seq_len = x.shape[1]
        x = self.embedding(x)
        x += self.positional_encoding[:, :seq_len, :].to(x.device)
        x = self.dropout(x)
        for layer in self.layers:
            x = layer(x, enc_out, src_mask, tgt_mask)
        return x

class Transformer(nn.Module):
    def __init__(self, src_vocab_size, tgt_vocab_size, src_pad_idx, tgt_pad_idx, 
                 num_layers=6, d_model=512, num_heads=8, d_ff=2048, dropout=0.1, max_len=100):
        super(Transformer, self).__init__()
        self.encoder = Encoder(num_layers, d_model, num_heads, d_ff, src_vocab_size, max_len, dropout)
        self.decoder = Decoder(num_layers, d_model, num_heads, d_ff, tgt_vocab_size, max_len, dropout)
        self.fc_out = nn.Linear(d_model, tgt_vocab_size)
        self.src_pad_idx = src_pad_idx
        self.tgt_pad_idx = tgt_pad_idx
        
    def make_src_mask(self, src):
        src_mask = (src != self.src_pad_idx).unsqueeze(1).unsqueeze(2)
        return src_mask
    
    def make_tgt_mask(self, tgt):
        tgt_pad_mask = (tgt != self.tgt_pad_idx).unsqueeze(1).unsqueeze(2)
        tgt_len = tgt.shape[1]
        tgt_sub_mask = torch.tril(torch.ones((tgt_len, tgt_len))).bool().to(tgt.device)
        tgt_mask = tgt_pad_mask & tgt_sub_mask
        return tgt_mask
    
    def forward(self, src, tgt):
        src_mask = self.make_src_mask(src)
        tgt_mask = self.make_tgt_mask(tgt)
        
        enc_out = self.encoder(src, src_mask)
        dec_out = self.decoder(tgt, enc_out, src_mask, tgt_mask)
        
        out = self.fc_out(dec_out)
        return out

# Example usage:
# Instantiate the model
# src_vocab_size = 10000  # example vocab size for source language
# tgt_vocab_size = 10000  # example vocab size for target language
# src_pad_idx = 0  # Padding

src_seq = torch.tensor([[1, 2, 3, 4, 5, 0, 0]])  # 0 may represent padding
tgt_seq = torch.tensor([[1, 2, 3, 4, 0, 0, 0]])  # 0 for padding

model2 = Transformer(src_vocab_size=100, tgt_vocab_size=100, src_pad_idx=0, tgt_pad_idx=0, num_layers=6, d_model=512, num_heads=8, d_ff=2048)

output2 = model2(src_seq, tgt_seq)

print("Output from Model 2:", output2)

RuntimeError: The size of tensor a (512) must match the size of tensor b (256) at non-singleton dimension 2