# 4. Decoder-Only Transformer

Decoder-only transformers (like GPT) use masked self-attention and feed-forward networks.
This architecture powers modern language models!


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np


## Key Components

1. **Masked Multi-Head Attention** - Each position can attend only to itself and earlier positions
2. **Feed-Forward Network** - Two linear layers with activation
3. **Layer Normalization** - Normalizes activations
4. **Residual Connections** - Adds input to output (x + f(x))


In [None]:
# Causal (look-ahead) masking: position i may attend only to positions <= i
seq_len = 5

# Additive mask: -inf strictly above the diagonal (future positions),
# 0 on and below the diagonal (current and earlier positions).
mask = torch.full((seq_len, seq_len), float('-inf')).triu(diagonal=1)

print("Causal mask (0 = allowed, -inf = masked):")
print(mask)
print("\nEach position can only attend to itself and previous positions!")


## Masked Multi-Head Attention

Same as multi-head attention, but with a causal mask applied to the attention scores before the softmax, so masked (future) positions receive zero attention weight.


In [None]:
class MaskedMultiHeadAttention(nn.Module):
    """Multi-head self-attention with optional (causal) masking.

    The input is projected to queries, keys and values, split into
    ``num_heads`` heads of size ``d_model // num_heads``, passed through
    scaled dot-product attention per head, and the concatenated heads are
    projected by ``W_o``.

    Args:
        d_model: Width of the input and output features.
        num_heads: Number of attention heads; must divide ``d_model`` evenly.
    """

    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0, (
            f"d_model ({d_model}) must be divisible by num_heads ({num_heads})"
        )

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads  # per-head feature size

        # Q/K/V projections carry no bias; the output projection does.
        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)
        self.W_o = nn.Linear(d_model, d_model)

    def forward(self, x, mask=None):
        """
        Args:
            x: [batch_size, seq_len, d_model]
            mask: Optional mask broadcastable to the [seq_len, seq_len]
                score matrix. Either a float tensor whose ``-inf`` entries
                mark blocked positions, or a boolean tensor in which
                ``True`` marks blocked positions.
        Returns:
            Tensor of shape [batch_size, seq_len, d_model].
        """
        batch_size, seq_len, d_model = x.shape

        Q = self.W_q(x)
        K = self.W_k(x)
        V = self.W_v(x)

        # Reshape for multi-head: [batch_size, num_heads, seq_len, d_k]
        Q = Q.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        K = K.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        V = V.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)

        # Scaled dot-product attention: [batch_size, num_heads, seq_len, seq_len]
        scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.d_k ** 0.5)

        # Apply mask if provided
        if mask is not None:
            if mask.dtype == torch.bool:
                # Boolean masks flag blocked positions directly. Comparing a
                # bool tensor with -inf is always False, which would silently
                # disable the mask.
                blocked = mask
            else:
                blocked = mask == float('-inf')
            # Keep the mask on the same device as the scores.
            scores = scores.masked_fill(blocked.to(scores.device), float('-inf'))

        # Blocked positions get zero weight after the softmax.
        attention_weights = F.softmax(scores, dim=-1)
        attention_output = torch.matmul(attention_weights, V)

        # Concatenate heads back to [batch_size, seq_len, d_model]
        attention_output = attention_output.transpose(1, 2).contiguous()
        attention_output = attention_output.view(batch_size, seq_len, d_model)

        output = self.W_o(attention_output)
        return output

# Smoke test: run masked attention on a random batch
d_model, num_heads = 64, 8
seq_len, batch_size = 10, 2

# Additive causal mask: -inf above the diagonal (future), 0 elsewhere
causal_mask = torch.full((seq_len, seq_len), float('-inf')).triu(diagonal=1)

# The module is built before the input is sampled
mha = MaskedMultiHeadAttention(d_model=d_model, num_heads=num_heads)
x = torch.randn(batch_size, seq_len, d_model)

output = mha(x, mask=causal_mask)
print(f"Input shape: {x.shape}")
print(f"Output shape: {output.shape}")
print("Masked multi-head attention with causal masking!")


In [None]:
class FeedForward(nn.Module):
    """Two-layer MLP applied over the last dimension: Linear -> ReLU -> Linear.

    Args:
        d_model: Input and output feature size.
        d_ff: Hidden (expanded) feature size.
    """

    def __init__(self, d_model, d_ff):
        super().__init__()
        self.linear1 = nn.Linear(in_features=d_model, out_features=d_ff)
        self.linear2 = nn.Linear(in_features=d_ff, out_features=d_model)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return a tensor with the same shape as ``x``."""
        hidden = self.linear1(x)          # expand: d_model -> d_ff
        activated = self.relu(hidden)
        return self.linear2(activated)    # contract: d_ff -> d_model

# Smoke test: the feed-forward network keeps the input shape
d_model, d_ff = 64, 256  # d_ff is typically 4x d_model
seq_len, batch_size = 10, 2

# The module is built before the input is sampled
ff = FeedForward(d_model=d_model, d_ff=d_ff)
x = torch.randn(batch_size, seq_len, d_model)

output = ff(x)
print(f"Input shape: {x.shape}")
print(f"Output shape: {output.shape}")
print(f"Expanded to {d_ff} dimensions, then back to {d_model}")


## Transformer Decoder Block

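Combines masked self-attention and a feed-forward network. Each sub-layer's output passes through dropout, is added to the sub-layer's input (residual connection), and the sum is then layer-normalized.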


In [None]:
class TransformerDecoderBlock(nn.Module):
    """One decoder layer: masked self-attention followed by feed-forward.

    Each sub-layer output goes through dropout, is added to the sub-layer
    input (residual), and the sum is layer-normalized.

    Args:
        d_model: Feature size of the input and output.
        num_heads: Number of attention heads.
        d_ff: Hidden size of the feed-forward network.
        dropout: Dropout probability applied to each sub-layer output.
    """

    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        self.attention = MaskedMultiHeadAttention(d_model, num_heads)
        self.feed_forward = FeedForward(d_model, d_ff)
        self.norm1 = nn.LayerNorm(normalized_shape=d_model)
        self.norm2 = nn.LayerNorm(normalized_shape=d_model)
        # A single dropout module is shared by both sub-layers.
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x, mask=None):
        """Run both sub-layers on ``x``; ``mask`` is passed to the attention."""
        # Sub-layer 1: self-attention -> dropout -> residual add -> norm
        attended = self.dropout(self.attention(x, mask))
        x = self.norm1(x + attended)

        # Sub-layer 2: feed-forward -> dropout -> residual add -> norm
        transformed = self.dropout(self.feed_forward(x))
        return self.norm2(x + transformed)

# Smoke test: a single decoder block keeps the input shape
d_model, num_heads, d_ff = 64, 8, 256
seq_len, batch_size = 10, 2

# Additive causal mask: -inf above the diagonal (future), 0 elsewhere
causal_mask = torch.full((seq_len, seq_len), float('-inf')).triu(diagonal=1)

# The block is built before the input is sampled
block = TransformerDecoderBlock(d_model=d_model, num_heads=num_heads, d_ff=d_ff)
x = torch.randn(batch_size, seq_len, d_model)

output = block(x, mask=causal_mask)
print(f"Input shape: {x.shape}")
print(f"Output shape: {output.shape}")
print("Decoder block with masked attention, feed-forward, and residuals!")


## Complete Decoder-Only Transformer

Stack multiple decoder blocks to create a full transformer model.


In [None]:
class DecoderOnlyTransformer(nn.Module):
    """Decoder-only transformer (GPT-style).

    Token and learned positional embeddings are summed, passed through a
    stack of ``TransformerDecoderBlock`` layers under a causal mask, then
    layer-normalized and projected to vocabulary logits.

    Args:
        vocab_size: Number of distinct token ids.
        d_model: Embedding / hidden feature size.
        num_heads: Attention heads per block.
        num_layers: Number of stacked decoder blocks.
        d_ff: Hidden size of each block's feed-forward network.
        max_seq_len: Number of positions in the positional-embedding table;
            inputs longer than this are rejected.
        dropout: Dropout probability for the embeddings and the blocks.
    """

    def __init__(self, vocab_size, d_model, num_heads, num_layers, d_ff, max_seq_len, dropout=0.1):
        super().__init__()
        self.d_model = d_model

        # Token and positional embeddings
        self.token_embedding = nn.Embedding(vocab_size, d_model)
        self.position_embedding = nn.Embedding(max_seq_len, d_model)

        # Stack of decoder blocks
        self.blocks = nn.ModuleList([
            TransformerDecoderBlock(d_model, num_heads, d_ff, dropout)
            for _ in range(num_layers)
        ])

        # Output projection
        self.ln_f = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, vocab_size, bias=False)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """
        Args:
            x: Token indices [batch_size, seq_len]
        Returns:
            logits: [batch_size, seq_len, vocab_size]
        Raises:
            IndexError: If ``seq_len`` exceeds the positional-embedding
                table size (``max_seq_len``).
        """
        batch_size, seq_len = x.shape

        # Fail early with a clear message instead of an opaque
        # "index out of range" raised inside the positional embedding lookup.
        max_seq_len = self.position_embedding.num_embeddings
        if seq_len > max_seq_len:
            raise IndexError(
                f"Input sequence length {seq_len} exceeds max_seq_len {max_seq_len}"
            )

        # Create causal mask: -inf above the diagonal (future), 0 elsewhere
        mask = torch.triu(torch.ones(seq_len, seq_len, device=x.device), diagonal=1)
        mask = mask.masked_fill(mask == 1, float('-inf'))

        # Token embeddings
        token_emb = self.token_embedding(x)  # [batch_size, seq_len, d_model]

        # Positional embeddings (broadcast over the batch dimension)
        positions = torch.arange(seq_len, device=x.device).unsqueeze(0)
        pos_emb = self.position_embedding(positions)  # [1, seq_len, d_model]

        # Combine embeddings
        x = self.dropout(token_emb + pos_emb)

        # Pass through decoder blocks
        for block in self.blocks:
            x = block(x, mask=mask)

        # Final layer norm and projection
        x = self.ln_f(x)
        logits = self.head(x)

        return logits

# Smoke test: build a small model and run one forward pass
vocab_size, max_seq_len = 1000, 256
d_model, num_heads, num_layers, d_ff = 128, 8, 4, 512

model = DecoderOnlyTransformer(
    vocab_size=vocab_size,
    d_model=d_model,
    num_heads=num_heads,
    num_layers=num_layers,
    d_ff=d_ff,
    max_seq_len=max_seq_len,
)

# Example input: a random batch of token indices in [0, vocab_size)
batch_size, seq_len = 2, 20
x = torch.randint(low=0, high=vocab_size, size=(batch_size, seq_len))

logits = model(x)

print(f"Input shape: {x.shape}")
print(f"Output logits shape: {logits.shape}")
print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")
print("\nDecoder-only transformer complete!")
