In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

# Re-implementing the Transformer architecture
From the original paper *[Attention Is All You Need](https://arxiv.org/abs/1706.03762) (Vaswani et al., 2017)*

<p>
  <img src="./images/transformer.jpg" alt="1" width="700" />
</p>

- Encoder & Decoder stacks
- Positional encoding
- Multi-head attention layers
- Position-wise feed-forward networks

* `d_model`: dimension of the embeddings and all internal representations of the Transformer; must be divisible by `num_heads`
* `num_heads`: number of heads in the multi-head attention
* `num_encoders`: number of stacked Encoder blocks
* `num_decoders`: number of stacked Decoder blocks
* `src_vocab_size`: size of the source language vocabulary; determines the size of the input embedding layer
* `tgt_vocab_size`: size of the target language vocabulary; determines the final linear layer that predicts the next token
* `max_len`: maximum sequence length supported by the Positional Encoding

In [None]:
class Transformer(nn.Module):
    """Skeleton of the encoder-decoder Transformer (constructor only).

    This first version only wires the sub-modules together; the complete
    class, including the real ``forward``, is defined later in the notebook.

    Args:
        d_model: width of the embeddings and of every internal representation.
        num_heads: number of attention heads, forwarded to Encoder/Decoder.
        num_encoders: number of stacked encoder blocks.
        num_decoders: number of stacked decoder blocks.
        src_vocab_size: number of rows of the source embedding table.
        tgt_vocab_size: number of rows of the target embedding table and
            output width of the final projection.
        max_len: maximum sequence length handed to the positional encoding.
    """

    def __init__(
        self, d_model=512, num_heads=8, num_encoders=6, num_decoders=6,
        src_vocab_size=10000, tgt_vocab_size=10000, max_len=5000
    ):
        super().__init__()
        self.d_model = d_model
        # Token embeddings (one table per language)
        self.src_embedding = nn.Embedding(src_vocab_size, d_model)
        self.tgt_embedding = nn.Embedding(tgt_vocab_size, d_model)
        # Positional encoding shared by source and target
        self.pos_encoding = PositionalEncoding(d_model, max_len)
        # Encoder and decoder stacks
        self.encoder = Encoder(d_model, num_heads, num_encoders)
        self.decoder = Decoder(d_model, num_heads, num_decoders)
        # Projection from d_model to target-vocabulary logits
        self.output = nn.Linear(d_model, tgt_vocab_size)

    def forward(self):
        """Placeholder: the skeleton has no forward pass yet.

        Raises:
            NotImplementedError: always. A bare ``pass`` would silently return
                ``None`` and hide the fact that this cell is only a sketch.
        """
        raise NotImplementedError(
            "Transformer.forward is a placeholder in this skeleton; "
            "use the complete Transformer class defined later in the notebook."
        )

### The Two Main Components:

1. **``Encoder``**: Processes the input sequence (e.g., English sentence)

2. **``Decoder``**: Generates the output sequence (e.g., French translation)

The architecture has 6 (*num_encoders* and *num_decoders*) identical blocks of each.

* `src`: initial input (embedding + positional encoding)
* `src_mask`: padding token mask; boolean or 0/1 matrix for the source sentence
* `tgt`: decoder input (embedding + positional encoding)
* `enc`: final output of the Encoder; representation of the initial input
* `tgt_mask`: lower-triangular mask applied to the target sequence (to prevent seeing future tokens)
* `enc_mask`: padding mask for the source sentence

In [4]:
class Encoder(nn.Module):
    """Stack of ``num_encoders`` encoder layers applied one after another."""

    def __init__(self, d_model, num_heads, num_encoders):
        super().__init__()
        # Register the layers one by one so they appear as enc_layers.0, .1, ...
        self.enc_layers = nn.ModuleList()
        for _ in range(num_encoders):
            self.enc_layers.append(EncoderLayer(d_model, num_heads))

    def forward(self, src, src_mask):
        """Feed ``src`` through every layer in order.

        Args:
            src: input to the first layer.
            src_mask: mask handed unchanged to every layer.

        Returns:
            The output of the last layer (``src`` itself if there are no layers).
        """
        hidden = src
        for block in self.enc_layers:
            hidden = block(hidden, src_mask)
        return hidden

class Decoder(nn.Module):
    """Stack of ``num_decoders`` decoder layers applied one after another."""

    def __init__(self, d_model, num_heads, num_decoders):
        super().__init__()
        # Register the layers one by one so they appear as dec_layers.0, .1, ...
        self.dec_layers = nn.ModuleList()
        for _ in range(num_decoders):
            self.dec_layers.append(DecoderLayer(d_model, num_heads))

    def forward(self, tgt, enc, tgt_mask, enc_mask):
        """Feed ``tgt`` through every layer; each layer also receives ``enc``.

        Args:
            tgt: input to the first decoder layer.
            enc: encoder output, passed unchanged to every layer.
            tgt_mask: mask for the target sequence, passed to every layer.
            enc_mask: mask for the encoder output, passed to every layer.

        Returns:
            The output of the last layer (``tgt`` itself if there are no layers).
        """
        hidden = tgt
        for block in self.dec_layers:
            hidden = block(hidden, enc, tgt_mask, enc_mask)
        return hidden

### Encoder Layer: Self-Attention + FFN

Each encoder layer has two sub-layers:

1. **Multi-headed self-attention** (looks at input sequence)
2. **Position-wise feed-forward network**

With residual connections and layer normalization around each.

<p>
  <img src="./images/encoder.png" alt="Encoder Layer" width="700" />
</p>



* `src`: input to the block, output from the previous block (or embedding + positional encoding at the beginning)
* `self.self_attn(x, x, x)` → Query = Key = Value = x → **self-attention**
* `x + ...` → skip connection: the input is added to the output; prevents vanishing gradients and preserves the original information
* In *Multi-Head Attention*, there are 3 main inputs:

  * 1st `x` → ``Query (Q)`` → what I’m looking for
  * 2nd `x` → ``Key (K)`` → the indices I compare against
  * 3rd `x` → ``Value (V)`` → the values I retrieve
  * Encoder: Q = K = V = x → each word attends to every other word

In [None]:
class EncoderLayer(nn.Module):
    """One encoder block: self-attention, then a feed-forward network.

    Each sub-layer is wrapped as ``LayerNorm(x + Dropout(sublayer(x)))``.

    Args:
        d_model: width of the input and output representations.
        num_heads: number of attention heads.
        d_ff: hidden width handed to the feed-forward network.
        dropout: dropout probability used by the sub-layers and the residuals.
    """

    def __init__(self, d_model, num_heads, d_ff=2048, dropout=0.1):
        super().__init__()
        # Sub-layer 1: multi-head self-attention with its dropout and norm
        self.self_attn = MultiHeadedAttention(d_model, num_heads, dropout)
        self.attn_dropout = nn.Dropout(dropout)
        self.attn_norm = nn.LayerNorm(d_model)
        # Sub-layer 2: feed-forward network (d_model -> d_ff -> d_model)
        # with its dropout and norm
        self.ffn = FeedForward(d_model, d_ff, dropout)
        self.ffn_dropout = nn.Dropout(dropout)
        self.ffn_norm = nn.LayerNorm(d_model)

    def forward(self, src, src_mask=None):
        """Apply the two sub-layers to ``src``.

        Args:
            src: input to the block.
            src_mask: optional mask forwarded to the attention module.

        Returns:
            The normalized output of the feed-forward sub-layer.
        """
        # Query = Key = Value = src -> self-attention
        attended = self.self_attn(src, src, src, mask=src_mask)
        hidden = self.attn_norm(src + self.attn_dropout(attended))
        # Feed-forward with its own residual connection and normalization
        transformed = self.ffn(hidden)
        return self.ffn_norm(hidden + self.ffn_dropout(transformed))


### Decoder Layer: Self-Attention + Cross-Attention + FFN

Decoder layers have 3 sub-layers:

1. **Masked Multi-headed self-attention** (can't see future tokens)
2. **Multi-headed cross-attention**: Encoder-decoder attention
3. **Position-wise feed-forward network**

This is where the magic of generation happens!

<p>
  <img src="./images/decoder.png" alt="Decoder Layer" width="700" />
</p>



In [None]:
class DecoderLayer(nn.Module):
    """One decoder block: masked self-attention, cross-attention, feed-forward.

    Each sub-layer is wrapped as ``LayerNorm(x + Dropout(sublayer(x)))``; a
    single ``nn.Dropout`` module is shared by the three residual branches.

    Args:
        d_model: width of the input and output representations.
        num_heads: number of attention heads.
        d_ff: hidden width handed to the feed-forward network.
        dropout: dropout probability used by the sub-layers and the residuals.
    """

    def __init__(self, d_model, num_heads, d_ff=2048, dropout=0.1):
        super().__init__()
        # Sub-layer 1: self-attention over the target sequence
        self.masked_self_attn = MultiHeadedAttention(d_model, num_heads, dropout)
        self.masked_self_attn_norm = nn.LayerNorm(d_model)
        # Sub-layer 2: attention from the target onto the encoder output
        self.cross_attn = MultiHeadedAttention(d_model, num_heads, dropout)
        self.cross_attn_norm = nn.LayerNorm(d_model)
        # Sub-layer 3: feed-forward network
        self.ffn = FeedForward(d_model, d_ff, dropout)
        self.ffn_norm = nn.LayerNorm(d_model)
        # Dropout applied to every sub-layer output before the residual add
        self.dropout = nn.Dropout(dropout)

    def forward(self, tgt, enc, tgt_mask=None, enc_mask=None):
        """Apply the three sub-layers to ``tgt``.

        Args:
            tgt: input to the block.
            enc: encoder output, used as key and value in cross-attention.
            tgt_mask: optional mask forwarded to the self-attention module.
            enc_mask: optional mask forwarded to the cross-attention module.

        Returns:
            The normalized output of the feed-forward sub-layer.
        """
        # 1. self-attention: query = key = value = tgt
        self_attended = self.masked_self_attn(tgt, tgt, tgt, mask=tgt_mask)
        hidden = self.masked_self_attn_norm(tgt + self.dropout(self_attended))

        # 2. cross-attention: query from the decoder, key/value from the encoder
        cross_attended = self.cross_attn(hidden, enc, enc, mask=enc_mask)
        hidden = self.cross_attn_norm(hidden + self.dropout(cross_attended))

        # 3. feed-forward network
        transformed = self.ffn(hidden)
        return self.ffn_norm(hidden + self.dropout(transformed))
        

Both the Encoder and the Decoder are built from sub-layers, starting with the **Multi-Head Attention blocks**.

These have parallel attention heads, with the number of heads (``num_heads``) as a hyperparameter.

On the left is a zoomed-in view of the multi-head attention block.

<p>
  <img src="./images/mha.jpg" alt="Multi-Head Attention" width="700" />
</p>


### Multi-Head Attention: Parallel Processing

More heads mean more parameters and greater flexibility to learn patterns, for example

- Head 1: Subject-verb relationships   
- Head 2: Adjective-noun relationships   
- Head 3: Long-range dependencies  
- ... 

Then concatenate all outputs.


This ``MultiHeadedAttention`` uses a **loop over individual attention heads** (`nn.ModuleList` of `SelfAttention`). It is **intentionally written for maximum clarity**  

This version is **5-10× slower** and uses **more memory** than the vectorized version because:  
- It runs a Python loop over heads (no GPU parallelism across heads)  
- Each head has its own full linear layers (no weight sharing)  
- `torch.cat([...])` in a loop kills performance   

The **Fast and Memory-Efficient version** is implemented further below in this notebook.  

- The same Q, K, V are passed to all heads  
- Encoder: $Q = K = V$ --> Self-Attention
- Decoder: Cross-Attention --> $Q \neq K=V$: Q from decoder; K/V from encoder output   
For d\_model = 512;
- Each head produces an output of 64 dimensions → then concatenated → [batch, len\_seq, 512]


In [None]:
class MultiHeadedAttention(nn.Module):
    """Multi-head attention written as an explicit loop over single heads.

    Every head is a ``SelfAttention`` module with its own projections that maps
    ``d_model`` to ``d_model // num_heads`` features. The head outputs are
    concatenated on the last axis and passed through a final linear layer.

    Args:
        d_model: width of the inputs and of the output.
        num_heads: number of attention heads; must divide ``d_model``.
        dropout: dropout probability forwarded to every head.

    Raises:
        ValueError: if ``d_model`` is not divisible by ``num_heads`` (the
            concatenated heads would not add up to ``d_model`` features).
    """

    def __init__(self, d_model, num_heads, dropout=0.1):
        super().__init__()
        if d_model % num_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by num_heads ({num_heads})"
            )
        self.d_model = d_model
        self.num_heads = num_heads
        self.attn_output_size = self.d_model // self.num_heads
        # create num_heads attention heads, each with its own Wq, Wk, Wv;
        # the dropout rate is passed on so the constructor argument takes effect
        self.attentions = nn.ModuleList([
            SelfAttention(d_model, self.attn_output_size, dropout)
            for _ in range(num_heads)
        ])
        self.output = nn.Linear(self.d_model, self.d_model)

    def forward(self, q, k, v, mask=None):
        """Run every head on the same ``q``, ``k``, ``v`` and merge the results.

        Args:
            q, k, v: inputs handed unchanged to every head.
            mask: optional mask handed unchanged to every head.

        Returns:
            The output projection of the concatenated head outputs.
        """
        heads = [head(q, k, v, mask) for head in self.attentions]
        # concatenate on the feature axis: num_heads * attn_output_size = d_model
        merged = torch.cat(heads, dim=-1)
        return self.output(merged)
    

### Self-Attention

This is the core innovation that started it all.

- 3 *linear projections*:  
    - ``Q (Query)``: what am I looking for?  
    - ``K (Key)``: what information do I have as references / is available?  
    - ``V (Value)``: what do I retrieve / what information to return?  
- applies the learned weights to the actual data:  
$Q = x.W_q$  
$K = x.W_k$   
$V = x.W_v$  
then, the famous formula:
$$
\boxed{
Attention(Q, K, V) = softmax(\frac{Q.K^T}{\sqrt{d_k}}) . V
}
$$


#### Why divide by $\sqrt{d_k}$ ? --> Stabilization
From the original paper, *[Attention Is All You Need](https://arxiv.org/abs/1706.03762) (Vaswani et al., 2017)*

> *"We suspect that for large values of $d_k$, the dot products grow large in magnitude, pushing the softmax function into regions where it has extremely small gradients. To counteract this effect, we scale the dot products by $1/\sqrt{d_k}$."*

1. **Q and K have dimension $d_k = d_{model} // num_{heads}$**
   - Each component of $Q_i$ and $K_j$ is initialized with variance ~1 (``Xavier/Glorot`` or ``He`` init).
   - The dot product $q_i \cdot k_j$ is the sum of $d_k$ independent terms:
     $$
     q_i \cdot k_j = \sum_{m=1}^{d_k} q_{i,m} k_{j,m}
     $$

2. **Assuming independent, zero-mean components** (by the central limit theorem the sum is then approximately Gaussian), this sum has:
   - **Expectation**: 0
   - **Variance**: $d_k \times \text{Var}(q) \times \text{Var}(k) \approx d_k$
   - **Standard deviation**: $\sqrt{d_k}$

   Thus: $Q K^T$ has values that grow like $\sqrt{d_k}$

   For example if $d_k = 64$, the raw scores are ~8 times larger on average ($\sqrt{64}=8$).

3. **Problem with softmax**
   - Softmax is very sensitive to large values:
     $$
     \text{softmax}([10, 9, 0]) \approx [0.73, 0.27, 0]
     $$
     $$
     \text{softmax}([80, 72, 0]) \approx [1.0, 0, 0] \quad \text{(complete one-hot)}
     $$
   - Without scaling → scores are too large → **softmax degenerates into a near one-hot distribution** → gradients almost zero everywhere except at the max → learning becomes impossible.

### Role of $\sqrt{d_k}$
By dividing by $\sqrt{d_k}$, we bring the **variance of the dot product back to 1**:
$$
\frac{Q K^T}{\sqrt{d_k}} \quad \rightarrow \quad \text{variance} \approx 1
$$
→ Attention scores stay on a reasonable scale  
→ Softmax produces **smooth and differentiable distributions**  
→ Gradients propagate properly during backprop  
→ Attention can truly learn to weight the tokens


<p>
  <img src="./images/mha.jpg" alt="Multi-Head Attention" width="700" />
</p>


In [None]:
class SelfAttention(nn.Module):
    """A single scaled dot-product attention head with its own projections.

    Args:
        d_model: width of the incoming ``q``, ``k`` and ``v``.
        output_size: width of the projected query/key/value (the head size).
        dropout: dropout probability applied to the attention weights.
    """

    def __init__(self, d_model, output_size, dropout=0.1):
        super().__init__()
        self.query = nn.Linear(d_model, output_size)
        self.key = nn.Linear(d_model, output_size)
        self.value = nn.Linear(d_model, output_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, q, k, v, mask=None):
        """Compute ``softmax(Q K^T / sqrt(d_k)) V`` for one head.

        Args:
            q: (batch, q_len, d_model) tensor.
            k, v: (batch, k_len, d_model) tensors.
            mask: optional tensor in which 0 / False marks positions that must
                not be attended to. It must broadcast to (batch, q_len, k_len).
                A 4-D mask with a singleton head axis, e.g. (batch, 1, 1, k_len)
                or (1, 1, q_len, k_len), is also accepted: the head axis is
                dropped, since this module handles exactly one head.

        Returns:
            (batch, q_len, output_size) tensor.

        Raises:
            ValueError: if a 4-D mask has a head axis of size other than 1.
        """
        # project the inputs into the attention head's subspace
        query = self.query(q)       # [batch, q_len, output_size]
        key = self.key(k)           # [batch, k_len, output_size]
        value = self.value(v)       # [batch, k_len, output_size]

        dim_k = key.size(-1)
        # batch matrix multiplication: Q @ K.T / sqrt(d_k) -> [batch, q_len, k_len]
        scores = torch.bmm(query, key.transpose(-2, -1)) / math.sqrt(dim_k)

        # apply mask (padding or future): where 0 --> set -inf --> 0 after softmax
        if mask is not None:
            if mask.dim() == 4:
                # Without this, masked_fill would broadcast the 3-D scores up to
                # 4-D and the bmm below would fail.
                if mask.size(1) != 1:
                    raise ValueError(
                        "a 4-D mask must have a singleton head axis, "
                        f"got shape {tuple(mask.shape)}"
                    )
                mask = mask.squeeze(1)
            scores = scores.masked_fill(mask == 0, float('-inf'))

        # dim=-1: normalize over the keys, so each query row sums to 1
        weights = F.softmax(scores, dim=-1)
        weights = self.dropout(weights)
        # weights @ V
        return torch.bmm(weights, value)


### Feed-Forward Networks

After attention, each position gets processed independently.

$$FFN(x) = max(0, x.W_1 + b_1).W_2 + b_2$$

Linear transformation → ReLU → Linear transformation


In [9]:
class FeedForward(nn.Module):
    """Position-wise feed-forward network: Linear -> ReLU -> Dropout -> Linear.

    No dropout is applied to the output. ``EncoderLayer`` and ``DecoderLayer``
    already apply dropout to the output of this module before the residual
    addition, so a trailing dropout here would drop the same tensor twice.

    Args:
        d_model: input and output width.
        d_ff: width of the hidden layer.
        dropout: dropout probability applied to the hidden activations.
    """

    def __init__(self, d_model, d_ff=2048, dropout=0.1):
        super().__init__()
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ff),
            # in-place is safe here: the ReLU overwrites the fresh output of
            # the first linear layer, not the caller's tensor
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model),
        )

    def forward(self, x):
        """Apply the network to the last axis of ``x`` (width ``d_model``)."""
        return self.ffn(x)

### Positional Encoding: Teaching Position

Since attention has no inherent sense of order, we need to inject position information.

It's done using sinusoidal functions, which creates unique patterns for each position.



For a position `pos` (the index of the word: 0, 1, 2, …)
and a dimension index `i` (from 0 to $d_{model}/2 - 1$), we compute:

$$
\boxed{
\begin{cases}
PE_{(pos, 2i)}   \;\; = \sin\left(\dfrac{pos}{10000^{2i/d_{model}}}\right) \\[8pt]
PE_{(pos, 2i+1)} \;\; = \cos\left(\dfrac{pos}{10000^{2i/d_{model}}}\right)
\end{cases}
}
$$

We want to compute:
$$ div\_term = \frac{1}{10000^{2i / d_{model}}} $$

But in numerical math, it’s more stable to write it as:
$$ \exp\left( \ln(1) - \frac{2i}{d_{model}} \ln(10000) \right) = \exp\left( -\frac{2i}{d_{model}} \ln(10000) \right) $$

Thus, `-math.log(10000.0)` = **-9.21034...**
And when we multiply by `i`, we get frequencies that **decrease exponentially**.


In [None]:
class PositionalEncoding(nn.Module):
    """Adds fixed sinusoidal position information to a batch of embeddings.

    A table ``pe`` of shape (1, max_len, d_model) is computed once and stored
    as a non-trainable buffer: even columns hold ``sin(pos * div_term)`` and
    odd columns hold ``cos(pos * div_term)``.

    Args:
        d_model: width of the embeddings (even or odd).
        max_len: number of positions to precompute.
    """

    def __init__(self, d_model, max_len=5000):
        super().__init__()

        # positions column 0, 1, 2, ..., max_len-1
        # shape (max_len, 1) -> [[0], [1], [2], ..., [max_len-1]]
        position = torch.arange(max_len, dtype=torch.float).unsqueeze(1)

        # div_term = 1 / (10000^(2i / d_model)), one entry per even column
        # shape (ceil(d_model / 2),)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() *
            (-math.log(10000.0) / d_model)
        )
        angles = position * div_term                    # (max_len, ceil(d_model/2))

        pe = torch.zeros(max_len, d_model)
        # fill in even columns with sin and odd columns with cos
        pe[:, 0::2] = torch.sin(angles)
        # for odd d_model there is one odd column fewer than even columns,
        # so the last frequency is dropped for the cos part
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])

        # Store pe as a non-trainable buffer
        self.register_buffer('pe', pe.unsqueeze(0))     # shape: (1, max_len, d_model)

    def forward(self, x):
        """Return ``x`` plus the encodings of its first ``seq_len`` positions.

        Args:
            x: tensor of shape (batch_size, seq_len, d_model).

        Raises:
            ValueError: if ``seq_len`` exceeds the precomputed ``max_len``.
        """
        seq_len = x.size(1)
        max_len = self.pe.size(1)
        if seq_len > max_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds the maximum length "
                f"{max_len} supported by this PositionalEncoding"
            )
        return x + self.pe[:, :seq_len, :]

Finally, we bring all of this together in the forward pass of our transformer.

This completes the main `Transformer` class that was sketched at the beginning, now with its forward pass.



***The embeddings are scaled by $\sqrt{d_{model}}$***:

While the original paper doesn't explicitly justify this scaling, there are several widely accepted reasons:

-  **Balancing Embeddings and Positional Encodings**

--> Positional encodings (PEs) have a fixed and relatively small magnitude, usually between -1 and 1.

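--> *Problem Without Scaling*: word embeddings are often initialized with small values, with a variance roughly proportional to $1/d_{model}$. Without adjustment, their magnitude would be much smaller than that of the PEs. As a result, the positional information could dominate the semantic meaning of the words. (Note: PyTorch's `nn.Embedding`, used below, initializes its weights from $\mathcal{N}(0, 1)$ by default, so this argument applies as stated only when the embeddings are re-initialized with variance $\approx 1/d_{model}$.)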

--> Multiplying embeddings by $\sqrt{d_{model}}$ increases their magnitude so that the semantic information remains significant compared to positional signals. This ensures a balanced contribution between position and meaning in the input representation.

- **Training Stability**

--> *Variance Normalization:* since embeddings are initialized with a variance around $1/d_{model}$, multiplying them by $\sqrt{d_{model}}$ brings their overall variance close to 1.

--> Maintaining a unit-scale variance helps avoid vanishing or exploding gradients during backpropagation, making the optimization process more stable.


In [None]:
class Transformer(nn.Module):
    """Encoder-decoder Transformer operating on batches of token ids.

    Mask convention used throughout: ``True`` marks a position that may be
    attended to, ``False`` (or 0) a position that is masked out.

    Args:
        d_model: width of the embeddings and of every internal representation.
        num_heads: number of attention heads, forwarded to Encoder/Decoder.
        num_encoders: number of stacked encoder blocks.
        num_decoders: number of stacked decoder blocks.
        src_vocab_size: number of rows of the source embedding table.
        tgt_vocab_size: number of rows of the target embedding table and
            output width of the final projection.
        max_len: maximum sequence length handed to the positional encoding.
    """

    def __init__(
        self, d_model=512, num_heads=8, num_encoders=6, num_decoders=6,
        src_vocab_size=10000, tgt_vocab_size=10000, max_len=5000
    ):
        super().__init__()
        self.d_model = d_model
        # Encoder and Decoder
        self.encoder = Encoder(d_model, num_heads, num_encoders)
        self.decoder = Decoder(d_model, num_heads, num_decoders)
        # Positional Encoding
        self.pos_encoding = PositionalEncoding(d_model, max_len)
        # Embeddings
        self.src_embedding = nn.Embedding(src_vocab_size, d_model)
        self.tgt_embedding = nn.Embedding(tgt_vocab_size, d_model)
        # Output projection
        self.output = nn.Linear(d_model, tgt_vocab_size)

    def create_pad_mask(self, seq, pad_token):
        """Return a boolean mask that is False wherever ``seq`` is padding.

        Args:
            seq: (batch_size, seq_len) tensor of token ids.
            pad_token: id of the padding token.

        Returns:
            (batch_size, 1, 1, seq_len) boolean tensor, broadcastable over the
            head and query axes.
        """
        mask = (seq != pad_token)   # (batch_size, seq_len)
        return mask.unsqueeze(1).unsqueeze(2)

    def create_subsequent_mask(self, seq_len, device=None):
        """Return the causal (lower-triangular) mask for ``seq_len`` positions.

        Entry ``[i, j]`` is True when ``j <= i``, i.e. when position ``i`` may
        attend to position ``j``; future positions are False.

        Args:
            seq_len: length of the target sequence.
            device: optional device on which to create the mask.

        Returns:
            (1, 1, seq_len, seq_len) boolean tensor.
        """
        allowed = torch.tril(
            torch.ones(seq_len, seq_len, dtype=torch.bool, device=device)
        )      # (seq_len, seq_len)
        return allowed.unsqueeze(0).unsqueeze(0)

    def forward(self, src, tgt, src_pad_token=0, tgt_pad_token=0):
        """Compute next-token logits for every target position.

        Args:
            src: (batch_size, src_seq_len) tensor of source token ids.
            tgt: (batch_size, tgt_seq_len) tensor of target token ids.
            src_pad_token: padding id in ``src``.
            tgt_pad_token: padding id in ``tgt``.

        Returns:
            (batch_size, tgt_seq_len, tgt_vocab_size) tensor of logits.
        """
        # Create masks (True = may be attended to)
        src_mask = self.create_pad_mask(src, src_pad_token)
        tgt_pad_mask = self.create_pad_mask(tgt, tgt_pad_token)
        causal_mask = self.create_subsequent_mask(tgt.size(1), device=tgt.device)
        # a key is visible only if it is neither padding nor in the future
        tgt_mask = tgt_pad_mask & causal_mask   # (batch_size, 1, tgt_len, tgt_len)

        # Embedding with scaling by sqrt(d_model)
        scale = math.sqrt(self.d_model)
        src_emb = self.src_embedding(src) * scale
        tgt_emb = self.tgt_embedding(tgt) * scale

        # Add positional encoding
        src_emb = self.pos_encoding(src_emb)
        tgt_emb = self.pos_encoding(tgt_emb)

        # Encoder: shape (batch_size, src_seq_len, d_model)
        enc_out = self.encoder(src_emb, src_mask)
        # Decoder: shape (batch_size, tgt_seq_len, d_model)
        dec_out = self.decoder(tgt_emb, enc_out, tgt_mask, src_mask)

        # Output projection: shape (batch_size, tgt_seq_len, tgt_vocab_size)
        return self.output(dec_out)

    def training_step(self, vocab_size):
        """Run one toy teacher-forcing step and return its cross-entropy loss.

        The model consumes token ids, not strings, so the example sentences
        are written as hard-coded ids (0 = <pad>, 1 = <bos>, 2 = <eos>, other
        ids stand for words). Both vocabularies must therefore have at least
        8 entries.

        Args:
            vocab_size: size of the target vocabulary (last axis of the logits).

        Returns:
            Scalar loss tensor.
        """
        device = self.output.weight.device
        # "Hello my name is zedems"
        src = torch.tensor([[3, 4, 5, 6, 7]], device=device)
        # "<bos> Bonjour mon nom est zedems <eos>"
        tgt = torch.tensor([[1, 3, 4, 5, 6, 7, 2]], device=device)
        tgt_input = tgt[:, :-1]     # shifted right: starts with <bos>
        expected = tgt[:, 1:]       # next token at every position, ends with <eos>

        output = self(src, tgt_input)

        loss = F.cross_entropy(output.reshape(-1, vocab_size),
                               expected.reshape(-1))
        return loss

The original paper tackled sequence-to-sequence translation, like English <-> French.

During training and inference phases:

- **Training**

For English <-> French translation:

- Encoder encodes the English sentence.
- Decoder gets the French target *shifted right* (teacher forcing).
- Predictions are compared to the true French output with cross-entropy loss.

**Shifting right** feeds the decoder the previous correct token (e.g., \<BOS> + target[:-1]) to focus on history for next-token prediction.

This allows autoregressive modeling by predicting the next token using all prior ones while training all positions in parallel using masking.

- **Inference**: Autoregressive Generation

During inference, the output is generated step-by-step (autoregressive). Encoder runs once, decoder runs multiple times. Each step uses previous predictions.

### Multi-Head Attention: Fast and Memory-Efficient versions

In [None]:
class MultiHeadedAttention(nn.Module):
    """Vectorized multi-head attention: all heads are computed in one batch.

    One ``d_model -> d_model`` projection is used for each of Q, K and V; the
    result is split into ``num_heads`` chunks of ``head_dim`` features.

    Args:
        d_model: width of the inputs and of the output.
        num_heads: number of heads; must divide ``d_model``.
        dropout: dropout probability applied to the attention weights.
    """

    def __init__(self, d_model, num_heads, dropout=0.1):
        super().__init__()
        assert d_model % num_heads == 0
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads

        self.query = nn.Linear(d_model, d_model)
        self.key = nn.Linear(d_model, d_model)
        self.value = nn.Linear(d_model, d_model)

        self.out = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

    def _split_heads(self, x, batch, length):
        """(batch, length, d_model) -> (batch, num_heads, length, head_dim)."""
        return x.view(batch, length, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, q, k, v, mask=None):
        """
        q, k, v: (B, T, d_model)
        mask: (B, ..., T_q, T_k), broadcasted automatically; entries equal to 0
              are masked out
        returns: (B, T_q, d_model)
        """
        batch, q_len = q.shape[0], q.shape[1]
        k_len = k.shape[1]

        # Project, then move the head axis in front of the time axis
        q_heads = self._split_heads(self.query(q), batch, q_len)
        k_heads = self._split_heads(self.key(k), batch, k_len)
        v_heads = self._split_heads(self.value(v), batch, k_len)

        # Scaled dot-product scores: (B, num_heads, T_q, T_k)
        scores = (q_heads @ k_heads.transpose(-2, -1)) / math.sqrt(self.head_dim)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))

        # Normalize over the keys, then drop some attention weights
        weights = self.dropout(F.softmax(scores, dim=-1))

        # Weighted sum of the values: (B, num_heads, T_q, head_dim)
        context = weights @ v_heads

        # Recombine heads: (B, T_q, num_heads, head_dim) -> (B, T_q, d_model)
        merged = context.transpose(1, 2).contiguous().view(batch, q_len, self.d_model)
        return self.out(merged)

In [None]:
# highly optimized version with "scaled_dot_product_attention" pytorch function
class MultiHeadedAttention(nn.Module):
    """Multi-head attention built on ``F.scaled_dot_product_attention``.

    Args:
        d_model: width of the inputs and of the output.
        num_heads: number of heads; must divide ``d_model``.
        dropout: dropout probability applied to the attention weights while
            the module is in training mode.
    """

    def __init__(self, d_model, num_heads, dropout=0.1):
        super().__init__()
        assert d_model % num_heads == 0
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads

        self.query = nn.Linear(d_model, d_model)
        self.key = nn.Linear(d_model, d_model)
        self.value = nn.Linear(d_model, d_model)

        self.out = nn.Linear(d_model, d_model)
        # only the probability is stored; it is passed to the fused kernel
        self.dropout_p = dropout

    def forward(self, q, k, v, mask=None):
        """
        q, k, v: (B, T, d_model)
        mask: (B, ..., T_q, T_k), broadcasted automatically. Boolean or 0/1:
              entries equal to 0 / False are masked out, as in the other
              attention implementations of this notebook.
        returns: (B, T_q, d_model)
        """
        B, T_q, _ = q.size()
        _, T_k, _ = k.size()

        # Linear projections
        q = self.query(q).view(B, T_q, self.num_heads, self.head_dim)
        k = self.key(k).view(B, T_k, self.num_heads, self.head_dim)
        v = self.value(v).view(B, T_k, self.num_heads, self.head_dim)

        # Reshape for multi-head: (B, T, num_heads, head_dim) --> (B, num_heads, T, head_dim)
        q = q.permute(0, 2, 1, 3)
        k = k.permute(0, 2, 1, 3)
        v = v.permute(0, 2, 1, 3)

        # The fused kernel treats a boolean mask as "True = attend" but a float
        # mask as an additive bias. Convert 0/1 masks to boolean so that
        # "0 = masked" means the same thing here as with masked_fill(mask == 0).
        if mask is not None and mask.dtype != torch.bool:
            mask = mask != 0

        # Compute scaled dot-product attention using the optimized fused kernel
        # This single function handles the Q.K^T multiplication, scaling, masking, softmax, and A.V multiplication.
        # It can leverage backends like FlashAttention for significant speed and memory savings.
        attn_output = F.scaled_dot_product_attention(
            q, k, v,
            attn_mask=mask,
            dropout_p=self.dropout_p if self.training else 0.0,
        ) # Shape: (B, num_heads, T_q, head_dim)

        # Recombine the heads
        # (B, num_heads, T_q, head_dim) -> (B, T_q, num_heads, head_dim) -> (B, T_q, d_model)
        attn_output = attn_output.permute(0, 2, 1, 3).contiguous().view(B, T_q, self.d_model)

        # Final linear projection
        return self.out(attn_output)