### **Decoder: Autoregressive Next-Token Prediction**

* This is a minimal decoder-only Transformer, built to show how information flows through a decoder during next-token prediction.
* The model takes a sequence of token representations, applies causal self-attention and feed-forward transformations, and produces vocabulary logits for each position in the sequence.

#### **Notation**

- **B** — batch size  
- **T** — sequence length (time steps)  
- **D** — model dimension (`d_model`)  
- **V** — vocabulary size  

All tensors follow the shape convention:

* (B, T, D) — for token representations  
* (B, T, V) — for output logits


In [None]:
# ------------------------
# Imports
# ------------------------
import torch
import torch.nn as nn
import math

# ------------------------
# Initial configuration
# ------------------------
vocab_size = 20   # V 
d_model = 4       # D (token vector dimension)
seq_len = 6       # T (number of tokens in input sentence)
batch_size = 1    # B


### **1 Single-Head Causal Self-Attention**

This section implements the **core operation of a decoder-only Transformer**: causal self-attention. The goal here is not efficiency or multi-head scaling, but to make each step of the attention mechanism **fully explicit and inspectable**:

- query, key, value projections
- causal masking
- attention weight computation
- context aggregation

This attention module operates on a full sequence and enforces **autoregressive (left-to-right) information flow**.
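
As a minimal sketch of the masking step listed above, the snippet below applies a lower-triangular mask and a softmax to a toy score matrix. The sequence length of 4 and the all-zero scores are illustrative assumptions, not values from this notebook; the `-1e9` fill value mirrors the one used in the `Attention` class.

```python
import torch

T = 4                                          # toy sequence length (assumption)
scores = torch.zeros(T, T)                     # equal raw scores for every pair

# Lower-triangular mask: position t may see positions 0..t only
causal_mask = torch.tril(torch.ones(T, T))
masked = scores.masked_fill(causal_mask == 0, -1e9)

# Softmax over the key dimension turns each row into attention weights
weights = torch.softmax(masked, dim=-1)
print(weights)
```

Because every visible score is equal, row `t` assigns `1 / (t + 1)` to each of its first `t + 1` positions and zero to the rest, so the last row is `[0.25, 0.25, 0.25, 0.25]`.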


In [2]:
# IN  : token representations (B, T, D)
# OUT : context-aware representations (B, T, D) and attention weights for inspection (B, T, T)

class Attention(nn.Module):
    def __init__(self, d_model):
        super().__init__()
        
        self.d_model = d_model

        # Linear projections for Q, K, V
        self.Wq_projection = nn.Linear(d_model, d_model, bias=False)
        self.Wk_projection = nn.Linear(d_model, d_model, bias=False)
        self.Wv_projection = nn.Linear(d_model, d_model, bias=False)

        # Output projection
        self.Wo_projection = nn.Linear(d_model, d_model, bias=False)

    def forward(self, X):
        B, S, D = X.shape  # S is the sequence length (T in the notation above)

        # Project input to queries, keys, values
        Q = self.Wq_projection(X)
        K = self.Wk_projection(X)
        V = self.Wv_projection(X)

        # Scaled dot-product attention
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_model)

        # Causal mask to prevent access to future tokens
        # (created on X's device so the module also works when X is on a GPU)
        causal_mask = torch.tril(torch.ones(S, S, device=X.device))
        scores = scores.masked_fill(causal_mask == 0, -1e9)

        # Attention weights
        weights = torch.softmax(scores, dim=-1)

        # Weighted sum of values
        output = torch.matmul(weights, V)

        # Final projection
        output = self.Wo_projection(output)

        return output, weights


### **2 Layer Normalization**

- Layer Normalization stabilizes training by **normalizing each token’s feature vector independently**: the D features of a token are shifted to zero mean and scaled to unit variance, then passed through a learned scale and shift.
- This keeps activations well-behaved and improves gradient flow in deep Transformers.
- This implementation uses PyTorch’s built-in `LayerNorm`.
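
To make the per-token normalization concrete, here is a small sketch that compares `nn.LayerNorm` with a manual computation over the last dimension. The input tensor and seed are illustrative assumptions; the comparison relies on the layer's freshly initialized scale of 1 and shift of 0.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
X = torch.randn(1, 6, 4)                       # (B, T, D), toy values
norm = nn.LayerNorm(4)
out = norm(X)

# Manual version: statistics are taken over the feature dimension only
mean = X.mean(dim=-1, keepdim=True)
var = X.var(dim=-1, keepdim=True, unbiased=False)   # LayerNorm uses biased variance
manual = (X - mean) / torch.sqrt(var + norm.eps)

matches = torch.allclose(out, manual, atol=1e-5)
per_token_mean = out.mean(dim=-1)              # one value per token, close to 0
print(matches, per_token_mean)
```

Each of the six tokens is normalized on its own, so changing one token leaves the normalized values of the others untouched.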

In [3]:
# IN  : token representations (B, T, D)
# OUT : normalized representations (B, T, D)

class LN(nn.Module):
    def __init__(self, d_model):
        super().__init__()
        self.norm = nn.LayerNorm(d_model)

    def forward(self, X):
        return self.norm(X)

### **3 Feed-Forward Network**

* The Feed-Forward Network applies a **position-wise nonlinearity** to each token independently.
* It expands the model dimension, applies a non-linear transformation, and projects back to the original dimension.
* This allows the decoder to learn **nonlinear feature interactions** after attention has mixed information across tokens.
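
The "position-wise" property can be checked directly. The sketch below builds the same expand, ReLU, project structure with `nn.Sequential` (a stand-in for the `FFN` class, chosen here only to keep the example self-contained) and confirms that running the whole sequence at once gives the same result as running each token separately.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model = 4
ffn = nn.Sequential(
    nn.Linear(d_model, 4 * d_model),           # expand D -> 4D
    nn.ReLU(),
    nn.Linear(4 * d_model, d_model),           # project back 4D -> D
)

X = torch.randn(1, 6, d_model)                 # (B, T, D), toy values
full = ffn(X)                                  # all positions at once

# Same network applied to one position at a time, then re-stacked along T
per_token = torch.stack([ffn(X[:, t]) for t in range(X.shape[1])], dim=1)

same = torch.allclose(full, per_token, atol=1e-6)
print(same)
```

No information moves between positions inside this network; any mixing across tokens has to come from attention.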

In [4]:
# IN  : (B, T, D)
# OUT : (B, T, D)

class FFN(nn.Module):
    def __init__(self, d_model):
        super().__init__()
        self.linear1 = nn.Linear(d_model, 4 * d_model)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(4 * d_model, d_model)

    def forward(self, X):
        h = self.relu(self.linear1(X))
        out = self.linear2(h)
        return out


### **4 Output Projection (Logits)**

* The output projection maps each token’s final representation to vocabulary-sized logits.
* These logits are later used for next-token prediction via a softmax and cross-entropy loss during training.
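
As a sketch of how logits turn into a training loss, the snippet below computes cross-entropy twice: once with `torch.nn.functional.cross_entropy`, which takes raw logits and applies the log-softmax internally, and once by hand. The random logits and targets are placeholders, not model outputs.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, T, V = 1, 6, 20
logits = torch.randn(B, T, V)                  # placeholder logits (B, T, V)
targets = torch.randint(0, V, (B, T))          # placeholder target token IDs

# Flatten so every position is treated as one classification example
flat_logits = logits.reshape(B * T, V)
flat_targets = targets.reshape(B * T)
loss = F.cross_entropy(flat_logits, flat_targets)

# Manual version: negative log-probability of the target token, averaged
log_probs = torch.log_softmax(flat_logits, dim=-1)
manual = -log_probs[torch.arange(B * T), flat_targets].mean()

print(loss.item(), manual.item())
```

The two values agree, which is why the model should output raw logits rather than probabilities when this loss is used.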


In [5]:
# IN  : (B, T, D)
# OUT : (B, T, V)

class Logit(nn.Module):
    def __init__(self, d_model, vocab_size):
        super().__init__()
        self.linear = nn.Linear(d_model, vocab_size)

    def forward(self, X):
        return self.linear(X)


### **5 Decoder Block**

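* This section wires together all components into a single decoder layer using residual connections and layer normalization. Each sub-layer output is added to its input and then normalized, i.e. `LN(X + sublayer(X))`.
* The causal mask makes the computation autoregressive: the logits at each position depend only on that position and earlier ones.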
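
One way to test the autoregressive property of such a layer is to perturb the last token and check that the outputs at earlier positions do not move. The sketch below uses a compact functional re-creation of the attention, residual, and normalization wiring (the helper `block` and the toy sizes are assumptions made for this example, not the notebook's `Decoder` class).

```python
import math
import torch
import torch.nn as nn

torch.manual_seed(0)
D, T = 4, 6
Wq, Wk, Wv, Wo = (nn.Linear(D, D, bias=False) for _ in range(4))
ln1, ln2 = nn.LayerNorm(D), nn.LayerNorm(D)
ffn = nn.Sequential(nn.Linear(D, 4 * D), nn.ReLU(), nn.Linear(4 * D, D))

def block(X):
    """Hypothetical helper: causal attention + FFN with post-LN residuals."""
    steps = X.shape[1]
    scores = Wq(X) @ Wk(X).transpose(-2, -1) / math.sqrt(D)
    mask = torch.tril(torch.ones(steps, steps, device=X.device))
    weights = torch.softmax(scores.masked_fill(mask == 0, -1e9), dim=-1)
    X = ln1(X + Wo(weights @ Wv(X)))
    return ln2(X + ffn(X))

with torch.no_grad():
    X = torch.randn(1, T, D)
    X_changed = X.clone()
    X_changed[:, -1, 0] += 1.0                 # perturb one feature of the last token
    out_a, out_b = block(X), block(X_changed)

earlier_same = torch.allclose(out_a[:, :-1], out_b[:, :-1], atol=1e-6)
last_differs = not torch.allclose(out_a[:, -1], out_b[:, -1])
print(earlier_same, last_differs)
```

If the mask were removed, the first check would fail, because earlier positions would then read from the modified last token.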

In [6]:
# IN  : (B, T, D)
# OUT : (B, T, V)

class Decoder(nn.Module):
    def __init__(self, vocab_size, d_model):
        super().__init__()

        self.attn = Attention(d_model)
        self.ln1 = LN(d_model)

        self.ffn = FFN(d_model)
        self.ln2 = LN(d_model)

        self.logit = Logit(d_model, vocab_size)

    def forward(self, X):
        attn_out, weights = self.attn(X)
        X = self.ln1(X + attn_out)

        ffn_out = self.ffn(X)
        X = self.ln2(X + ffn_out)

        logits = self.logit(X)
        return logits, weights


### **6 Single-Step Training Sanity Check**

* This section performs a minimal forward and backward pass to verify that the decoder is fully differentiable and can learn from data.
* The purpose here is **not convergence**, but to validate end-to-end data flow, loss computation, and gradient updates.
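
A useful reference point for this check is the loss of a model that guesses uniformly. As a sketch, all-zero logits (an assumption standing in for "no preference between tokens") give a uniform softmax, and the cross-entropy is then exactly `log(V)` regardless of the targets:

```python
import math
import torch
import torch.nn.functional as F

V = 20
logits = torch.zeros(6, V)                     # identical scores => uniform softmax
targets = torch.randint(0, V, (6,))            # arbitrary target token IDs

loss = F.cross_entropy(logits, targets)
baseline = math.log(V)                         # about 2.996 for V = 20
print(loss.item(), baseline)
```

A freshly initialized model should land near this value; a loss far above it usually points to a bug in shapes, targets, or initialization.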

In [None]:
# ------------------------
# Model & Optimizer
# ------------------------
model = Decoder(vocab_size, d_model)
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

# ------------------------
# Dummy Batch (Sanity Check)
# ------------------------
X = torch.randn(batch_size, seq_len, d_model)                         # input representations for each token
Y = torch.randint(0, vocab_size, (batch_size, seq_len))               # the target token IDs the model should predict

# ------------------------
# Forward
# ------------------------
logits, weights = model(X)   # (B, T, V)

B, T, V = logits.shape
# Flatten to (B*T, V) and (B*T,): one classification example per position
logits = logits.view(B * T, V)
Y = Y.view(B * T)

# ------------------------
# Loss
# ------------------------
loss = loss_fn(logits, Y)

# ------------------------
# Backward
# ------------------------
optimizer.zero_grad()
loss.backward()
optimizer.step()

### **Important Observations**

In [11]:
# --------------------------------------------------
# 1. Input / Output Shape Validation
# --------------------------------------------------
# Ensures that tensor dimensions align across the model.
# If shapes are correct, the forward pass is structurally sound.

print("Input X shape:", X.shape)           # (B, T, D)
print("\nOutput logits shape:", logits.shape) # (B*T, V): flattened before the loss


# --------------------------------------------------
# 2. Attention Weight Inspection
# --------------------------------------------------
# The attention matrix reveals autoregressive behavior.
# Each token should attend only to itself and previous tokens,
# confirming that causal masking is applied correctly.

print("\nAttention weights shape:", weights.shape)  # (B, T, T)
print("\nAttention weights (batch 0):")
print(weights[0])


# --------------------------------------------------
# 3. Logit Inspection
# --------------------------------------------------
# Logits represent unnormalized scores over the vocabulary.
# At initialization these scores are small and effectively random,
# so the softmax over them is close to (but not exactly) uniform.

print("\nLogits at first position:")
print(logits[0])


# --------------------------------------------------
# 4. Loss Inspection
# --------------------------------------------------
# With random initialization, the loss should be close to log(V)
# (log(20) is about 3.0 here), the loss of uniform guessing across the vocabulary.

print("\nLoss:", loss.item())


# --------------------------------------------------
# 5. Gradient Flow Check
# --------------------------------------------------
# A non-zero gradient at the query projection (Wq) confirms that
# learning signals propagate from the loss back through the logit layer,
# the FFN, and the attention weights to the earliest parameters.

print(
    "\nGrad norm (Wq):",
    model.attn.Wq_projection.weight.grad.norm()
)


Input X shape: torch.Size([1, 6, 4])

Output logits shape: torch.Size([6, 20])

Attention weights shape: torch.Size([1, 6, 6])

Attention weights (batch 0):
tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5292, 0.4708, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.4307, 0.4072, 0.1621, 0.0000, 0.0000, 0.0000],
        [0.1702, 0.2278, 0.3822, 0.2199, 0.0000, 0.0000],
        [0.1436, 0.1866, 0.2123, 0.2014, 0.2561, 0.0000],
        [0.2025, 0.1726, 0.1301, 0.1739, 0.1413, 0.1797]],
       grad_fn=<SelectBackward0>)

Logits at first position:
tensor([-0.6833,  0.2061,  0.3898,  0.3143, -0.5294, -0.1016,  0.1618,  0.4335,
         0.7510,  1.6260, -0.8837,  0.2265,  1.1763, -0.6008,  0.9962,  0.8220,
         0.1334, -0.5884, -0.3020,  0.5791], grad_fn=<SelectBackward0>)

Loss: 3.23246693611145

Grad norm (Wq): tensor(0.0190)
