## Transformer in image

<img src="./assert/transformer.png" width="50%" height="50%" alt="transformer">

In [20]:
# Add mask support to multi-head attention

import torch
import torch.nn as nn
import math

class MaskedMultiHeadAttention(nn.Module):
    """Unbatched multi-head scaled dot-product attention with an optional mask.

    Inputs are 2-D tensors: ``query`` is ``(q_len, d_model)`` while ``key`` and
    ``value`` are ``(kv_len, d_model)``. ``q_len`` and ``kv_len`` may differ,
    which is what cross-attention needs when source and target sequences have
    different lengths.

    Args:
        d_model: Width of the input and output features.
        n_heads: Number of attention heads; must divide ``d_model`` evenly.
    """

    def __init__(self, d_model, n_heads):
        super().__init__()
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"

        self.d_model = d_model
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads

        # Input projections for Q, K, V and the final output projection.
        self.w_query = nn.Linear(d_model, d_model)
        self.w_key = nn.Linear(d_model, d_model)
        self.w_value = nn.Linear(d_model, d_model)
        self.attention_scores = nn.Linear(d_model, d_model)

    def forward(self, query, key, value, mask=None):
        """Attend from ``query`` positions to ``key``/``value`` positions.

        Args:
            query: Tensor of shape ``(q_len, d_model)``.
            key: Tensor of shape ``(kv_len, d_model)``.
            value: Tensor of shape ``(kv_len, d_model)``.
            mask: Optional tensor broadcastable to ``(n_heads, q_len, kv_len)``.
                Positions where ``mask == 0`` are excluded from attention.

        Returns:
            Tensor of shape ``(q_len, d_model)``.
        """
        # Query and key/value lengths are tracked separately: reshaping the
        # keys with the query length breaks cross-attention whenever the two
        # sequences differ in length.
        q_len = query.shape[0]
        kv_len = key.shape[0]

        # For self-attention, query, key and value are all the same input.
        query = self.w_query(query)
        key = self.w_key(key)
        value = self.w_value(value)

        # Split the feature axis into heads: (len, d_model) -> (n_heads, len, head_dim).
        query = query.view(q_len, self.n_heads, self.head_dim).transpose(0, 1)
        key = key.view(kv_len, self.n_heads, self.head_dim).transpose(0, 1)
        value = value.view(kv_len, self.n_heads, self.head_dim).transpose(0, 1)

        # Per-head similarity scores, shape (n_heads, q_len, kv_len).
        attention_scores = torch.bmm(query, key.transpose(1, 2)) / math.sqrt(self.head_dim)

        if mask is not None:
            # Masked positions are set to -inf rather than 0: softmax
            # exponentiates its input and exp(0) = 1, so a 0 score would still
            # receive weight, whereas exp(-inf) = 0 removes the position and
            # the remaining weights still sum to 1.
            # NOTE: a row whose positions are all masked yields NaN after softmax.
            attention_scores = attention_scores.masked_fill(mask == 0, float("-inf"))

        attention = torch.matmul(torch.softmax(attention_scores, dim=-1), value)

        # Merge the heads back: (n_heads, q_len, head_dim) -> (q_len, d_model).
        attention = attention.transpose(0, 1).contiguous().view(q_len, self.d_model)
        return self.attention_scores(attention)
            
        
        

In [21]:
# Add & Norm component: a residual connection followed by layer normalization, i.e. layer_norm(x + sublayer(x))

class TransformerAddNorm(nn.Module):
    """Residual connection followed by layer normalization ("Add & Norm")."""

    def __init__(self, d_model):
        super().__init__()
        self.layer_norm = nn.LayerNorm(d_model)

    def forward(self, x, sublayer):
        """Return ``layer_norm(x + sublayer(x))``.

        ``sublayer`` is any callable that maps ``x`` to a tensor of the same
        shape, e.g. an attention block or a feed-forward network.
        """
        residual_sum = x + sublayer(x)
        return self.layer_norm(residual_sum)

class FeedForward(nn.Module):
    """Position-wise feed-forward network: Linear -> ReLU -> Linear."""

    def __init__(self, d_model, d_ff):
        super().__init__()
        # Widen to d_ff, apply the non-linearity, then project back to d_model.
        widen = nn.Linear(d_model, d_ff)
        narrow = nn.Linear(d_ff, d_model)
        self.feed_forward = nn.Sequential(widen, nn.ReLU(), narrow)

    def forward(self, x):
        """Apply the network to the last dimension of ``x``."""
        out = self.feed_forward(x)
        return out


In [22]:
import torch
import torch.nn as nn

class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding for unbatched ``(seq_len, d_model)`` input.

    For position ``pos`` and even feature index ``i`` (0, 2, 4, ...):

        PE[pos, i]     = sin(pos / 10000 ** (i / d_model))
        PE[pos, i + 1] = cos(pos / 10000 ** (i / d_model))

    The tables are registered as non-persistent buffers, so they follow the
    module across ``.to(device)`` / dtype casts without adding entries to
    ``state_dict``.

    Args:
        d_model: Feature width of the inputs.
        max_len: Number of positions to precompute.
    """

    def __init__(self, d_model, max_len=5000):
        super().__init__()
        self.d_model = d_model
        self.max_len = max_len

        # position: (max_len, 1), the index of each token in the sequence.
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)

        # i enumerates the even feature indices 0, 2, 4, ... It already
        # advances in steps of 2, so it plays the role of "2i" in the paper's
        # formula and the exponent is i / d_model (not 2 * i / d_model).
        i = torch.arange(0, d_model, step=2, dtype=torch.float)

        # Broadcasting (max_len, 1) / (ceil(d_model / 2),) gives the angle for
        # every (position, frequency) pair.
        # Example with d_model = 4: i = [0, 2], divisors = [1, 100], so the
        # row for position 3 is [3, 0.03].
        div_term = position / torch.pow(10000.0, i / d_model)

        # Even feature indices take the sine, odd feature indices the cosine.
        position_encoding = torch.zeros(max_len, d_model)
        position_encoding[:, 0::2] = torch.sin(div_term)
        # With an odd d_model there is one fewer odd column than even columns,
        # so drop the unused trailing frequency.
        position_encoding[:, 1::2] = torch.cos(div_term[:, : d_model // 2])

        self.register_buffer("position", position, persistent=False)
        self.register_buffer("div_term", div_term, persistent=False)
        self.register_buffer("position_encoding", position_encoding, persistent=False)

    def forward(self, x):
        """Add the positional encoding to ``x`` of shape ``(seq_len, d_model)``.

        Only the first ``max_len`` rows of the table exist; a longer input
        fails in the addition because the shapes no longer match.
        """
        seq_len = min(self.max_len, x.size(0))
        return x + self.position_encoding[:seq_len, :]
        

In [23]:
# Encoder

class EncoderLayer(nn.Module):
    """One encoder block: self-attention then feed-forward, each wrapped in Add & Norm."""

    def __init__(self, d_model, n_heads, d_ff):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.mha = MaskedMultiHeadAttention(d_model, n_heads)
        self.ffn = FeedForward(d_model, d_ff)
        self.add_norm_mha = TransformerAddNorm(d_model)
        self.add_norm_ffn = TransformerAddNorm(d_model)

    def forward(self, x, src_mask):
        """Transform ``x`` of shape ``(seq_len, d_model)``; ``src_mask`` is passed to the attention."""

        def self_attention(tokens):
            # Query, key and value all come from the same sequence.
            return self.mha(tokens, tokens, tokens, src_mask)

        attended = self.add_norm_mha(x, self_attention)
        return self.add_norm_ffn(attended, self.ffn)

class Encoder(nn.Module):
    """Positional encoding followed by a stack of ``n_layers`` encoder layers."""

    def __init__(self, d_model, n_heads, d_ff, n_layers):
        super().__init__()
        self.positional_encoding = PositionalEncoding(d_model)
        self.encoder_layers = nn.ModuleList(
            EncoderLayer(d_model, n_heads, d_ff) for _ in range(n_layers)
        )

    def forward(self, x, src_mask):
        """Encode ``x``; the same ``src_mask`` is handed to every layer."""
        hidden = self.positional_encoding(x)
        for encoder_layer in self.encoder_layers:
            hidden = encoder_layer(hidden, src_mask)
        return hidden
        

In [29]:
class DecoderLayer(nn.Module):
    """One decoder block: masked self-attention, cross-attention, feed-forward.

    Each of the three sub-layers is wrapped in its own Add & Norm. Positional
    encoding is not applied here (the layer never used one in ``forward``), so
    no per-layer encoding table is allocated.

    Args:
        d_model: Feature width of the inputs and outputs.
        n_heads: Number of attention heads in both attention blocks.
        d_ff: Hidden width of the feed-forward network.
    """

    def __init__(self, d_model, n_heads, d_ff):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.mha_1 = MaskedMultiHeadAttention(d_model, n_heads)
        self.add_norm_mha_1 = TransformerAddNorm(d_model)
        # NOTE: "corss" is a typo for "cross"; the attribute name is kept so
        # existing state_dict keys and external references keep working.
        self.corss_mha = MaskedMultiHeadAttention(d_model, n_heads)
        self.add_norm_mha_2 = TransformerAddNorm(d_model)
        self.ffn = FeedForward(d_model, d_ff)
        self.add_norm_ffn = TransformerAddNorm(d_model)

    def forward(self, x, src, src_mask, tgt_mask):
        """Run one decoder step over the whole target sequence.

        Args:
            x: Decoder input, shape ``(seq_len, d_model)``.
            src: Encoder output used as key and value in cross-attention.
            src_mask: Mask passed to the cross-attention (may be ``None``).
            tgt_mask: Mask passed to the self-attention (may be ``None``).

        Returns:
            Tensor with the same shape as ``x``.
        """
        # Self-attention over the target sequence.
        x = self.add_norm_mha_1(x, lambda y: self.mha_1(y, y, y, tgt_mask))
        # Cross-attention: queries from the decoder, keys/values from the encoder.
        x = self.add_norm_mha_2(x, lambda y: self.corss_mha(y, src, src, src_mask))
        x = self.add_norm_ffn(x, self.ffn)
        return x
    
class Decoder(nn.Module):
    """Positional encoding followed by a stack of ``n_layers`` decoder layers."""

    def __init__(self, d_model, n_heads, d_ff, n_layers):
        super().__init__()
        self.positional_encoding = PositionalEncoding(d_model)
        self.decoder_layers = nn.ModuleList(
            DecoderLayer(d_model, n_heads, d_ff) for _ in range(n_layers)
        )

    def forward(self, x, src, src_mask, tgt_mask):
        """Decode ``x`` against the encoder output ``src``; both masks go to every layer."""
        hidden = self.positional_encoding(x)
        for decoder_layer in self.decoder_layers:
            hidden = decoder_layer(hidden, src, src_mask, tgt_mask)
        return hidden
        

In [33]:
class Transformer(nn.Module):
    """Encoder-decoder stack followed by a linear projection and a softmax."""

    def __init__(self, d_model, n_heads, d_ff, n_layers):
        super().__init__()
        self.encoder = Encoder(d_model, n_heads, d_ff, n_layers)
        self.decoder = Decoder(d_model, n_heads, d_ff, n_layers)
        # The output projection keeps the width at d_model.
        self.linear = nn.Linear(d_model, d_model)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, src, tgt, src_mask, tgt_mask):
        """Encode ``src``, decode ``tgt`` against it, and return softmax-normalized rows."""
        memory = self.encoder(src, src_mask)
        decoded = self.decoder(tgt, memory, src_mask, tgt_mask)
        projected = self.linear(decoded)
        return self.softmax(projected)
        

In [34]:
# Training Setup and Process

# Seed PyTorch's global RNG so that subsequent random draws (e.g. the
# synthetic data generated below) are reproducible from run to run.
torch.manual_seed(42)

def create_mask(seq_len):
    """Build a causal (lower-triangular) mask of shape ``(seq_len, seq_len)``.

    Entry ``[i, j]`` is 1 when position ``i`` may attend to position ``j``
    (``j <= i``) and 0 for future positions.
    """
    return torch.ones(seq_len, seq_len).tril()

def generate_sample_data(seq_len, d_model, num_samples=32):
    """Create random source/target samples for unbatched training.

    Returns a ``(src_data, tgt_data)`` pair of lists, each holding
    ``num_samples`` standard-normal tensors of shape ``(seq_len, d_model)``.
    All source samples are drawn before any target sample.
    """

    def draw_samples():
        # One independent 2-D tensor per sample; no batch dimension.
        return [torch.randn(seq_len, d_model) for _ in range(num_samples)]

    src_data = draw_samples()
    tgt_data = draw_samples()
    return src_data, tgt_data

# Data dimensions
seq_len = 16       # Tokens per training sequence
num_samples = 100  # Number of synthetic (source, target) pairs

# Model hyperparameters
d_model = 1024  # Feature width used throughout the model
n_heads = 8     # Attention heads per attention block
d_ff = 2048     # Hidden width of the feed-forward network
n_layers = 8    # Layers in the encoder stack and in the decoder stack

# Synthetic training set: two lists of (seq_len, d_model) tensors
src_data, tgt_data = generate_sample_data(
    seq_len=seq_len, d_model=d_model, num_samples=num_samples
)

print(f"Training data: {num_samples} samples of length {seq_len}")
print(f"Source shape per sample: {src_data[0].shape}")
print(f"Target shape per sample: {tgt_data[0].shape}")


Training data: 100 samples of length 16
Source shape per sample: torch.Size([16, 1024])
Target shape per sample: torch.Size([16, 1024])


In [35]:

# Training Loop - No Batching (matches MHA implementation)

# Model, optimizer and loss
model = Transformer(d_model, n_heads, d_ff, n_layers)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
criterion = nn.MSELoss()

# Masks: the encoder attends everywhere, the decoder only to earlier positions
src_mask = None
tgt_mask = create_mask(seq_len)

model.train()
epochs = 50

print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

print("Starting training...")
for epoch in range(epochs):
    total_loss = 0

    # One optimizer step per sample (no batching); each sample is (seq_len, d_model).
    for src_sample, tgt_sample in zip(src_data, tgt_data):
        # The decoder is fed tgt_sample and the loss compares the output
        # against that same tgt_sample (no shift between input and label).
        # NOTE(review): the model output is softmax-normalized while the
        # targets are unbounded random values - confirm MSE is the intended loss.
        output = model(src_sample, tgt_sample, src_mask, tgt_mask)
        loss = criterion(output, tgt_sample)

        # Clear stale gradients, backpropagate, then update the weights.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    avg_loss = total_loss / num_samples

    # Report every 10th epoch.
    if epoch % 10 == 0:
        print(f"Epoch {epoch:2d}, Average Loss: {avg_loss:.6f}")

print("Training completed!")


TypeError: Transformer.__init__() takes 5 positional arguments but 6 were given

In [28]:
# Model Evaluation and Testing

# Switch to inference mode and disable gradient tracking for the check below.
model.eval()
with torch.no_grad():
    # Evaluate on a single freshly drawn random sample, each (seq_len, d_model).
    test_src = torch.randn(seq_len, d_model)
    test_tgt = torch.randn(seq_len, d_model)

    # "\n" (not "\\n") so a blank line is printed instead of a literal backslash-n.
    print("\n=== Model Evaluation ===")
    print(f"Test input shape: {test_src.shape}")
    print(f"Test target shape: {test_tgt.shape}")

    # Forward pass with the same masks that were used for training
    output = model(test_src, test_tgt, src_mask, tgt_mask)
    test_loss = criterion(output, test_tgt)

    print(f"Test output shape: {output.shape}")
    print(f"Test loss: {test_loss.item():.6f}")

    # The model ends in a softmax over the last dimension, so each row sums to ~1.
    print(f"Output sum per position (should be ~1.0): {output.sum(dim=-1)[:5]}")  # First 5 positions

print("\n=== Training Summary ===")
print(f"✅ Transformer model with {sum(p.numel() for p in model.parameters()):,} parameters")
print(f"✅ Trained on {num_samples} synthetic samples for {epochs} epochs")
print(f"✅ Final loss: {avg_loss:.6f}")
print(f"✅ Model successfully processes sequences of length {seq_len}")


\n=== Model Evaluation ===
Test input shape: torch.Size([16, 1024])
Test target shape: torch.Size([16, 1024])
Test output shape: torch.Size([16, 1024])
Test loss: 0.995932
Output sum per position (should be ~1.0): tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000])
\n=== Training Summary ===
✅ Transformer model with 169,051,136 parameters
✅ Trained on 100 synthetic samples for 50 epochs
✅ Final loss: 1.000419
✅ Model successfully processes sequences of length 16
