# Implementing Transformer Architecture: A Step-by-Step Guide

## Paper Reference
- ["Attention Is All You Need"](https://arxiv.org/abs/1706.03762) (Vaswani et al., 2017)
- Key sections: 
  - 3.1: Encoder and Decoder Stacks
  - 3.2: Attention Mechanism
  - 3.3: Position-wise Feed-Forward Networks
  - 3.4: Embeddings and Softmax
  - 3.5: Positional Encoding
  - 5.4: Regularization (dropout strategy)

## Implementation Strategy
Break the architecture into manageable pieces and add complexity gradually:

1. Start with foundational components:
   - Embedding + Positional Encoding
   - Single-head self-attention
   
2. Build up attention mechanism:
   - Extend to multi-head attention
   - Add cross-attention capability
   - Implement attention masking

3. Construct larger components:
   - Encoder (self-attention + FFN)
   - Decoder (masked self-attention + cross-attention + FFN)
   
4. Combine into final architecture:
   - Encoder-Decoder stack
   - Full Transformer with input/output layers

## Development Tips
1. Visualization and Planning:
   - Draw out tensor dimensions on paper
   - Sketch attention patterns and masks
   - Map each component back to paper equations
   - This helps catch dimension mismatches early!

2. Dimension Cheat Sheet:
   - Input tokens: [batch_size, seq_len]
   - Embeddings: [batch_size, seq_len, d_model]
   - Attention matrices: [batch_size, num_heads, seq_len, seq_len]
   - FFN hidden layer: [batch_size, seq_len, d_ff]
   - Output logits: [batch_size, seq_len, vocab_size]

3. Common Pitfalls:
   - Forgetting to scale dot products by √d_k
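   - Applying the mask at the wrong point (it must be added to the scaled scores before the softmax)
   - Incorrect mask dimensions or application
   - Missing residual connections
   - Wrong order of layer norm and dropout (the paper applies dropout to the sub-layer output, then adds the residual, then normalizes)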
   - Tensor dimension mismatches in attention
   - Not handling padding properly

4. Performance Considerations:
   - Memory usage scales with sequence length squared
   - Attention computation is O(n²) with sequence length
   - Balance between d_model and num_heads
   - Trade-off between model size and batch size
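
To make the quadratic scaling concrete, here is a back-of-the-envelope sketch of the memory held by the attention matrices alone. It assumes float32 (4 bytes per element) and one `[batch_size, num_heads, seq_len, seq_len]` tensor per layer. The function name `attention_matrix_bytes` is illustrative, and real usage is higher because other activations and gradients are stored as well.

```python
def attention_matrix_bytes(batch_size, num_heads, seq_len, bytes_per_element=4):
    """Bytes needed for one [batch_size, num_heads, seq_len, seq_len] tensor."""
    return batch_size * num_heads * seq_len * seq_len * bytes_per_element


for seq_len in (512, 1024, 2048):
    mib = attention_matrix_bytes(batch_size=32, num_heads=8, seq_len=seq_len) / 2**20
    print(f"seq_len={seq_len:>4}: {mib:,.0f} MiB per attention layer")
```

Doubling `seq_len` quadruples the figure (256 MiB, 1,024 MiB, and 4,096 MiB for the three lengths above), which is why batch size often has to shrink as sequences grow.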

## Testing Strategy
- Test each component independently
- Verify shape preservation
- Check attention patterns
- Confirm mask effectiveness
- Validate gradient flow
- Monitor numerical stability
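
Several of these checks can be bundled into one small helper. The sketch below is one possible shape for such a test, not part of the original notebook: `check_shape_and_gradients` is a hypothetical name, and a small `nn.Sequential` stands in for any component that maps `[batch_size, seq_len, d_model]` to the same shape.

```python
import torch
import torch.nn as nn


def check_shape_and_gradients(module, x):
    """Run one forward/backward pass and report basic health checks."""
    module.train()
    x = x.detach().clone().requires_grad_(True)
    out = module(x)
    out.sum().backward()
    return {
        "shape_preserved": out.shape == x.shape,
        "finite_output": bool(torch.isfinite(out).all()),
        # every parameter should receive a finite gradient
        "all_params_have_grad": all(
            p.grad is not None and bool(torch.isfinite(p.grad).all())
            for p in module.parameters()
        ),
        "input_has_grad": x.grad is not None,
    }


# Stand-in block with the same [batch_size, seq_len, d_model] contract
torch.manual_seed(0)
block = nn.Sequential(nn.Linear(16, 64), nn.ReLU(), nn.Linear(64, 16))
report = check_shape_and_gradients(block, torch.randn(2, 5, 16))
print(report)
```

The same helper can be pointed at the FFN or an encoder layer once they are defined. Modules that take extra arguments, such as masks, would need a thin wrapper.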

Remember: The key to successfully implementing the Transformer is understanding how each piece fits together and maintaining clear dimension tracking throughout the implementation.

reference: https://huggingface.co/datasets/bird-of-paradise/transformer-from-scratch-tutorial/blob/main/Transformer_Implementation_Tutorial.ipynb

## Code Section
### Embedding and Positional Encoding
This implements the input embedding from Section 3.4 and positional encoding from Section 3.5 of the paper. Key points:
- Embedding dimension can differ from model dimension (using projection)
- Positional encoding uses sine and cosine functions
- Scale embeddings by $\sqrt{d_{model}}$
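- Apply layer normalization and then dropout to the sum of embeddings and positional encodings (the layer norm is an addition in this implementation; the paper applies only dropout at this point)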

Implementation tips:
- Use `nn.Embedding` for token embeddings
- Store scaling factor as float during initialization
- Remember to expand positional encoding for batch dimension
- Add assertion for input dtype (should be torch.long)
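
As a sanity check on the positional encoding, the formulas from Section 3.5, PE(pos, 2i) = sin(pos / 10000^(2i/d_model)) and PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model)), can be evaluated for a single position in plain Python. This is a minimal sketch: `positional_encoding_row` is an illustrative helper, not part of the original code. It also confirms that the `exp`/`log` form commonly used for the frequency term is the same quantity written differently.

```python
import math


def positional_encoding_row(pos, d_model):
    """Sinusoidal encoding for one position, following Section 3.5 of the paper."""
    row = []
    for two_i in range(0, d_model, 2):
        angle = pos / (10000 ** (two_i / d_model))
        row.append(math.sin(angle))  # even index 2i
        row.append(math.cos(angle))  # odd index 2i + 1
    return row[:d_model]


# Position 0 alternates sin(0) = 0 and cos(0) = 1.
print(positional_encoding_row(0, 4))  # [0.0, 1.0, 0.0, 1.0]

# The exp/log form is the same frequency, written differently.
d_model, two_i = 512, 10
direct = 1.0 / (10000 ** (two_i / d_model))
via_exp = math.exp(two_i * (-math.log(10000.0) / d_model))
print(math.isclose(direct, via_exp))  # True
```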

In [11]:
import math
import torch
import torch.nn as nn


# This class can be used for both the input and output vocabulary embeddings
class EmbeddingWithProjection(nn.Module):
    def __init__(self, vocab_size, d_embed, d_model, dropout=0.1):
        super().__init__()
        self.vocab_size = vocab_size
        self.d_embed = d_embed
        self.d_model = d_model
        self.embedding = nn.Embedding(self.vocab_size, self.d_embed)
        self.embd_proj = nn.Linear(self.d_embed, self.d_model)
        self.scaling = float(math.sqrt(self.d_model))
        self.layernorm = nn.LayerNorm(self.d_model)
        self.dropout = nn.Dropout(p=dropout)

    @staticmethod
    def create_positional_encoding(seq_length, d_model, batch_size=1):
        # Create position indices: [seq_length, 1]
        position = torch.arange(seq_length).unsqueeze(1).float()

        # Create dimension indices: [1, d_model//2]
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )

        # Create empty tensor: [seq_length, d_model]
        pe = torch.zeros(seq_length, d_model)

        # Compute sin and cos
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)

        # Add batch dimension and expand: [batch_size, seq_length, d_model]
        pe = pe.unsqueeze(0).expand(batch_size, -1, -1)

        return pe

    def forward(self, x):
        # The input is used as indices for the embedding lookup, so its dtype must be torch.long
        assert (
            x.dtype == torch.long
        ), f"Input tensor must have dtype torch.long, got {x.dtype}"
        batch_size, seq_length = x.size()  # [batch, seq_length]

        # token embedding
        token_embedding = self.embedding(x)  # [batch_size, seq_length, d_embed]
        # project the token embedding to the d_model space, then scale by sqrt(d_model)
        token_embedding = (
            self.embd_proj(token_embedding) * self.scaling
        )  # [batch_size, seq_length, d_model]

        # add positional encodings to projected,
        # scaled embeddings before applying layer norm and dropout.
        # The encoding is created on the CPU, so move it to the input's device.
        positional_encoding = self.create_positional_encoding(
            seq_length, self.d_model, batch_size
        ).to(token_embedding.device)  # [batch_size, seq_length, d_model]

        # In addition, we apply dropout to the sums of the embeddings
        # in both the encoder and decoder stacks. For the base model, we use a rate of Pdrop = 0.1.
        normalized_sum = self.layernorm(token_embedding + positional_encoding)
        final_output = self.dropout(normalized_sum)

        return final_output

In [20]:
# Method 1: total parameter count
def count_parameters(model):
    return sum(p.numel() for p in model.parameters())


# Method 2: separate trainable and non-trainable parameters
def count_parameters_detailed(model):
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    non_trainable = sum(p.numel() for p in model.parameters() if not p.requires_grad)
    total = trainable + non_trainable
    return {"trainable": trainable, "non_trainable": non_trainable, "total": total}


# Method 3: formatted display (e.g. in millions)
def format_parameter_count(count):
    if count >= 1e9:
        return f"{count/1e9:.2f}B"
    elif count >= 1e6:
        return f"{count/1e6:.2f}M"
    elif count >= 1e3:
        return f"{count/1e3:.2f}K"
    else:
        return str(count)


# Method 4: show the parameter count of each layer
def print_model_structure(model):
    total_params = 0
    print(f"{'Layer Name':<30} {'Parameters':<15} {'Shape'}")
    print("-" * 65)

    for name, param in model.named_parameters():
        param_count = param.numel()
        total_params += param_count
        print(
            f"{name:<30} {format_parameter_count(param_count):<15} {str(list(param.shape))}"
        )

    print("-" * 65)
    print(f"{'Total':<30} {format_parameter_count(total_params):<15}")
    return total_params


# Test the EmbeddingWithProjection model
vocab_size = 50000
d_embed = 1024
d_model = 512

embd = EmbeddingWithProjection(vocab_size, d_embed, d_model)


def custom_show_model_detail(model):
    print("=== Model structure ===")
    print(model)
    print("\n=== Parameter statistics ===")

    # simple count
    total_params = count_parameters(model)
    print(f"Total parameters: {total_params:,} ({format_parameter_count(total_params)})")

    # detailed count
    detailed = count_parameters_detailed(model)
    print(
        f"Trainable parameters: {detailed['trainable']:,} ({format_parameter_count(detailed['trainable'])})"
    )
    print(
        f"Non-trainable parameters: {detailed['non_trainable']:,} ({format_parameter_count(detailed['non_trainable'])})"
    )

    print("\n=== Per-layer parameter details ===")
    print_model_structure(model)


custom_show_model_detail(embd)

x = torch.randint(0, vocab_size, (1, 20))
print(x.shape)
print(x)
output = embd(x)
print(output.shape)
print(output)

=== Model structure ===
EmbeddingWithProjection(
  (embedding): Embedding(50000, 1024)
  (embd_proj): Linear(in_features=1024, out_features=512, bias=True)
  (layernorm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  (dropout): Dropout(p=0.1, inplace=False)
)

=== Parameter statistics ===
Total parameters: 51,725,824 (51.73M)
Trainable parameters: 51,725,824 (51.73M)
Non-trainable parameters: 0 (0)

=== Per-layer parameter details ===
Layer Name                     Parameters      Shape
-----------------------------------------------------------------
embedding.weight               51.20M          [50000, 1024]
embd_proj.weight               524.29K         [512, 1024]
embd_proj.bias                 512             [512]
layernorm.weight               512             [512]
layernorm.bias                 512             [512]
-----------------------------------------------------------------
Total                          51.73M         
torch.Size([1, 20])
tensor([[ 6370,  5742, 33006, 38192, 12627, 41345, 25454, 45323,  8666, 43669,
         11891, 48564,  6534, 1797

### Transformer Attention
Implements the core attention mechanism from Section 3.2.1. Formula: $\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$

Key points:
- Supports both self-attention and cross-attention
- Multi-head attention implementation per section 3.2.2
- Handles different sequence lengths for encoder/decoder
- Scales dot products by 1/√d_k
- Applies attention masking before softmax
- Applies dropout after softmax

Implementation tips:
- Use separate Q,K,V projections
- Handle masking through addition (not masked_fill)
- Remember to use broadcasting and reshape for multi-head attention
- Keep track of tensor dimensions at each step
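
The scale, mask, softmax ordering can be seen in isolation in a stripped-down function. This is a minimal sketch without projections or dropout: the function name `scaled_dot_product_attention` and the toy shapes are assumptions for illustration. It shows why an additive mask works: adding `-inf` to a score drives its softmax weight to exactly zero.

```python
import math

import torch
import torch.nn.functional as F


def scaled_dot_product_attention(q, k, v, additive_mask=None):
    """q: [..., q_len, d_k]; k, v: [..., kv_len, d_k]; mask broadcasts to [..., q_len, kv_len]."""
    scores = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(q.size(-1))
    if additive_mask is not None:
        scores = scores + additive_mask  # -inf entries vanish after softmax
    weights = F.softmax(scores, dim=-1)
    return torch.matmul(weights, v), weights


torch.manual_seed(0)
batch_size, num_head, seq_len, d_head = 2, 4, 3, 8
q = torch.randn(batch_size, num_head, seq_len, d_head)
k = torch.randn(batch_size, num_head, seq_len, d_head)
v = torch.randn(batch_size, num_head, seq_len, d_head)

# Hide the last key position from every query; [1, 1, 1, seq_len] broadcasts.
mask = torch.zeros(1, 1, 1, seq_len)
mask[..., -1] = float("-inf")

out, weights = scaled_dot_product_attention(q, k, v, mask)
print(out.shape)               # torch.Size([2, 4, 3, 8])
print(weights[..., -1].sum())  # tensor(0.)
```

One caveat with additive masks: if every key is masked for some query, the softmax row becomes NaN, so at least one position must stay visible.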

In [22]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import math


class TransformerAttention(nn.Module):
    """
    Transformer Multi-Head Scaled Dot-Product Attention Module
    Args:
        d_model: Total dimension of the model.
        num_head: Number of attention heads.
        dropout: Dropout rate applied to the attention weights.
        bias: Whether to include bias in linear projections.

    Inputs:
        sequence: input sequence for self-attention, or the query source for cross-attention
        key_value_states: source of the keys and values for cross-attention (None for self-attention)
        att_mask: optional additive mask, broadcastable to [batch_size, num_head, seq_len, kv_seq_len]
    """

    def __init__(
        self, d_model, num_head, dropout=0.1, bias=True
    ):  # infer d_k, d_v, d_q from d_model
        super().__init__()  # Missing in the original implementation
        assert d_model % num_head == 0, "d_model must be divisible by num_head"
        self.d_model = d_model
        self.num_head = num_head
        self.d_head = d_model // num_head
        self.dropout_rate = dropout  # Store dropout rate separately

        # linear transformations
        self.q_proj = nn.Linear(d_model, d_model, bias=bias)
        self.k_proj = nn.Linear(d_model, d_model, bias=bias)
        self.v_proj = nn.Linear(d_model, d_model, bias=bias)
        self.output_proj = nn.Linear(d_model, d_model, bias=bias)

        # Dropout layer
        self.dropout = nn.Dropout(p=dropout)

        # Initialize the scaling factor 1/sqrt(d_head)
        self.scaler = float(
            1.0 / math.sqrt(self.d_head)
        )  # Store as float in initialization

    def forward(self, sequence, key_value_states=None, att_mask=None):
        """Input shape: [batch_size, seq_len, d_model=num_head * d_head]"""
        # Argument names match the keyword arguments used by the encoder and decoder layers
        batch_size, seq_len, d_model = sequence.size()

        # Check only critical input dimensions
        assert (
            d_model == self.d_model
        ), f"Input dimension {d_model} doesn't match model dimension {self.d_model}"
        if key_value_states is not None:
            assert (
                key_value_states.size(-1) == self.d_model
            ), f"Cross attention key/value dimension {key_value_states.size(-1)} doesn't match model dimension {self.d_model}"

        # if key_value_states is provided, this layer is used as a cross-attention layer
        # for the decoder
        is_cross_attention = key_value_states is not None

        # Linear projections and reshape for multi-head
        Q_state = self.q_proj(sequence)
        if is_cross_attention:
            kv_seq_len = key_value_states.size(1)
            K_state = self.k_proj(key_value_states)
            V_state = self.v_proj(key_value_states)
        else:
            kv_seq_len = seq_len
            K_state = self.k_proj(sequence)
            V_state = self.v_proj(sequence)

        # [batch_size, self.num_head, seq_len, self.d_head]
        Q_state = Q_state.view(
            batch_size, seq_len, self.num_head, self.d_head
        ).transpose(1, 2)

        # in cross-attention, key/value sequence length might be different from query sequence length
        K_state = K_state.view(
            batch_size, kv_seq_len, self.num_head, self.d_head
        ).transpose(1, 2)
        V_state = V_state.view(
            batch_size, kv_seq_len, self.num_head, self.d_head
        ).transpose(1, 2)

        # Compute attention matrix: QK^T, result shape: [batch_size, num_head, seq_len, kv_seq_len]
        self.att_matrix = torch.matmul(Q_state, K_state.transpose(-1, -2)) * self.scaler

        # apply attention mask to attention matrix
        if att_mask is not None and not isinstance(att_mask, torch.Tensor):
            raise TypeError("att_mask must be a torch.Tensor")

        if att_mask is not None:
            self.att_matrix = self.att_matrix + att_mask

        # apply softmax to the last dimension to get the attention score: softmax(QK^T)
        # result shape: [batch_size, num_head, seq_len, kv_seq_len]
        att_score = F.softmax(self.att_matrix, dim=-1)

        # apply dropout to the attention weights, result shape: [batch_size, num_head, seq_len, kv_seq_len]
        att_score = self.dropout(att_score)

        # get final output: softmax(QK^T)V, result shape: [batch_size, num_head, seq_len, d_head]
        att_output = torch.matmul(att_score, V_state)

        # concatenate all attention heads
        att_output = att_output.transpose(
            1, 2
        )  # [batch_size, seq_len, num_head, d_head]
        att_output = att_output.contiguous().view(
            batch_size, seq_len, self.num_head * self.d_head
        )  # [batch_size, seq_len, d_model]

        # final linear transformation to the concatenated output
        att_output = self.output_proj(att_output)  # [batch_size, seq_len, d_model]

        assert att_output.size() == (
            batch_size,
            seq_len,
            self.d_model,
        ), f"Final output shape {att_output.size()} incorrect"

        return att_output

In [23]:
d_model = 512
num_head = 8  # the second argument is the number of heads, not the per-head dimension
attn = TransformerAttention(d_model, num_head, dropout=0.1, bias=True)
custom_show_model_detail(attn)

=== Model structure ===
TransformerAttention(
  (q_proj): Linear(in_features=512, out_features=512, bias=True)
  (k_proj): Linear(in_features=512, out_features=512, bias=True)
  (v_proj): Linear(in_features=512, out_features=512, bias=True)
  (output_proj): Linear(in_features=512, out_features=512, bias=True)
  (dropout): Dropout(p=0.1, inplace=False)
)

=== Parameter statistics ===
Total parameters: 1,050,624 (1.05M)
Trainable parameters: 1,050,624 (1.05M)
Non-trainable parameters: 0 (0)

=== Per-layer parameter details ===
Layer Name                     Parameters      Shape
-----------------------------------------------------------------
q_proj.weight                  262.14K         [512, 512]
q_proj.bias                    512             [512]
k_proj.weight                  262.14K         [512, 512]
k_proj.bias                    512             [512]
v_proj.weight                  262.14K         [512, 512]
v_proj.bias                    512             [512]
output_proj.weight             262.14K         [512, 512]
output_proj.bias               512             [512]

### Feed-Forward Network (FFN)
Implements the position-wise feed-forward network from Section 3.3: FFN(x) = max(0, xW₁ + b₁)W₂ + b₂

Key points:
- Two linear transformations with ReLU in between
- Inner-layer dimension d_ff is 2048 in the paper's base model (with d_model = 512)
- Applied identically to each position

Implementation tips:
- Use nn.Linear for transformations
- Remember to include bias terms
- Position-wise means same transformation for each position
- Dimension flow: d_model → d_ff → d_model
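
The dimension flow also fixes the parameter count, which is easy to verify by hand. The sketch below uses a hypothetical helper, `linear_params`, and the base-model sizes (d_model = 512, d_ff = 2048). The attention figure it computes agrees with the 1,050,624 total printed for `TransformerAttention` earlier.

```python
def linear_params(in_features, out_features, bias=True):
    """Parameter count of one nn.Linear-style layer."""
    return in_features * out_features + (out_features if bias else 0)


d_model, d_ff = 512, 2048

ffn_params = linear_params(d_model, d_ff) + linear_params(d_ff, d_model)
attention_params = 4 * linear_params(d_model, d_model)  # q, k, v, output projections

print(f"FFN:       {ffn_params:,}")        # 2,099,712
print(f"Attention: {attention_params:,}")  # 1,050,624
```

With these sizes the FFN holds roughly twice as many parameters as the attention sublayer.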

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class FFN(nn.Module):
    """
    Position-wise Feed-Forward Networks
    This consists of two linear transformations with a ReLU activation in between.

    FFN(x) = max(0, xW1 + b1 )W2 + b2
    d_model: embedding dimension (e.g., 512)
    d_ff: feed-forward dimension (e.g., 2048)

    """

    def __init__(self, d_model, d_ff):
        super().__init__()
        self.d_model = d_model
        self.d_ff = d_ff

        # Linear transformation y = xW+b
        self.fc1 = nn.Linear(self.d_model, self.d_ff, bias=True)
        self.fc2 = nn.Linear(self.d_ff, self.d_model, bias=True)

        # Xavier (Glorot) uniform initialization of the weights
        # (can help with training stability)
        nn.init.xavier_uniform_(self.fc1.weight)
        nn.init.xavier_uniform_(self.fc2.weight)

    def forward(self, input):
        # check input and first FF layer dimension matching
        batch_size, seq_length, d_input = input.size()
        assert (
            self.d_model == d_input
        ), "d_model must be the same dimension as the input"

        # First linear transformation followed by ReLU
        # There's no need for explicit torch.max() as F.relu() already implements max(0,x)
        f1 = F.relu(self.fc1(input))

        # max(0, xW_1 + b_1)W_2 + b_2
        f2 = self.fc2(f1)

        return f2

In [4]:
net = FFN(d_model=512, d_ff=2048)
print(net)

FFN(
  (fc1): Linear(in_features=512, out_features=2048, bias=True)
  (fc2): Linear(in_features=2048, out_features=512, bias=True)
)


### Transformer Encoder
Implements a single encoder layer from Section 3.1, consisting of:
- Multi-head self-attention
- Position-wise feed-forward network
- Residual connections and layer normalization


Implementation tips:
- Apply dropout before adding residual
- Keep model dimension consistent through the layer
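
The pattern shared by both sub-layers, `LayerNorm(x + Dropout(Sublayer(x)))`, can be written once as a wrapper. This is a sketch only: `ResidualPostNorm` is a hypothetical helper that the notebook does not define, and a plain `nn.Linear` stands in for the real sub-layer.

```python
import torch
import torch.nn as nn


class ResidualPostNorm(nn.Module):
    """Hypothetical helper: LayerNorm(x + Dropout(sublayer(x)))."""

    def __init__(self, d_model, sublayer, dropout=0.1):
        super().__init__()
        self.sublayer = sublayer
        self.dropout = nn.Dropout(p=dropout)
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x):
        # dropout on the sub-layer output, then residual add, then layer norm
        return self.norm(x + self.dropout(self.sublayer(x)))


d_model = 16
block = ResidualPostNorm(d_model, nn.Linear(d_model, d_model))
block.eval()  # disable dropout so the output is deterministic
x = torch.randn(2, 5, d_model)
y = block(x)
print(y.shape)  # torch.Size([2, 5, 16])
```

Because the residual add requires matching shapes, any sub-layer wrapped this way must map `d_model` back to `d_model`.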

In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class TransformerEncoder(nn.Module):
    """
    Encoder layer of the Transformer
    Sublayers: TransformerAttention
               Residual LayerNorm
               FFN
               Residual LayerNorm
    Args:
            d_model: 512 model hidden dimension
            d_ff: 2048 hidden dimension of the feed forward network
            num_head: 8 Number of attention heads.
            dropout:  0.1 dropout rate

            bias: Whether to include bias in linear projections.

    """

    def __init__(self, d_model, d_ff, num_head, dropout=0.1, bias=True):
        super().__init__()
        self.d_model = d_model
        self.d_ff = d_ff

        # attention sublayer
        self.att = TransformerAttention(
            d_model=d_model, num_head=num_head, dropout=dropout, bias=bias
        )

        # FFN sublayer
        self.ffn = FFN(d_model=d_model, d_ff=d_ff)

        # Dropout layer
        self.dropout = nn.Dropout(p=dropout)

        # layer-normalization layer
        self.LayerNorm_att = nn.LayerNorm(self.d_model)
        self.LayerNorm_ffn = nn.LayerNorm(self.d_model)

    def forward(self, embed_input, padding_mask=None):

        batch_size, seq_len, _ = embed_input.size()

        ## First sublayer: self-attention
        att_sublayer = self.att(
            sequence=embed_input, key_value_states=None, att_mask=padding_mask
        )  # [batch_size, sequence_length, d_model]

        # apply dropout before layer normalization for each sublayer
        att_sublayer = self.dropout(att_sublayer)
        # Residual layer normalization
        att_normalized = self.LayerNorm_att(
            embed_input + att_sublayer
        )  # [batch_size, sequence_length, d_model]

        ## Second sublayer: FFN
        ffn_sublayer = self.ffn(
            att_normalized
        )  # [batch_size, sequence_length, d_model]
        ffn_sublayer = self.dropout(ffn_sublayer)
        ffn_normalized = self.LayerNorm_ffn(
            att_normalized + ffn_sublayer
        )  # [batch_size, sequence_length, d_model]

        return ffn_normalized

In [6]:
net = TransformerEncoder(d_model=512, d_ff=2048, num_head=8, dropout=0.1, bias=True)
print(net)

TransformerEncoder(
  (att): TransformerAttention(
    (q_proj): Linear(in_features=512, out_features=512, bias=True)
    (k_proj): Linear(in_features=512, out_features=512, bias=True)
    (v_proj): Linear(in_features=512, out_features=512, bias=True)
    (output_proj): Linear(in_features=512, out_features=512, bias=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (ffn): FFN(
    (fc1): Linear(in_features=512, out_features=2048, bias=True)
    (fc2): Linear(in_features=2048, out_features=512, bias=True)
  )
  (dropout): Dropout(p=0.1, inplace=False)
  (LayerNorm_att): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  (LayerNorm_ffn): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
)


### Transformer Decoder
Implements a single decoder layer from Section 3.1, with three sub-layers:
- Masked multi-head self-attention
- Multi-head cross-attention with encoder output
- Position-wise feed-forward network

Key points:
- Self-attention uses causal masking
- Cross-attention allows attending to all encoder outputs
- Each sub-layer followed by residual connection and layer normalization
- Apply dropout to the output of previous sub-layer before residual connection and layer normalization

Implementation tips:
- Order of operations matters (masking before softmax)
- Each attention layer has its own projections
- Remember to pass encoder outputs for cross-attention
- Careful with mask dimensions in self and cross attention
- Key implementation detail for causal masking: create the mask from an upper triangular matrix:
  ```python
  mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1)
  mask = mask.masked_fill(mask == 1, float('-inf'))
  ```

In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class TransformerDecoder(nn.Module):
    """
    Decoder layer of the Transformer
    Sublayers: TransformerAttention with self-attention
               Residual LayerNorm
               TransformerAttention with cross-attention
               Residual LayerNorm
               FFN
               Residual LayerNorm
    Args:
            d_model: 512 model hidden dimension
            d_ff: 2048 hidden dimension of the feed forward network
            num_head: 8 Number of attention heads.
            dropout:  0.1 dropout rate

            bias: Whether to include bias in linear projections.

    """

    def __init__(self, d_model, d_ff, num_head, dropout=0.1, bias=True):
        super().__init__()
        self.d_model = d_model
        self.d_ff = d_ff

        # attention sublayers: self-attention and cross-attention each get their own projections
        self.att = TransformerAttention(
            d_model=d_model, num_head=num_head, dropout=dropout, bias=bias
        )
        self.cross_att = TransformerAttention(
            d_model=d_model, num_head=num_head, dropout=dropout, bias=bias
        )

        # FFN sublayer
        self.ffn = FFN(d_model=d_model, d_ff=d_ff)

        # Dropout layer
        self.dropout = nn.Dropout(p=dropout)

        # layer-normalization layer
        self.LayerNorm_att1 = nn.LayerNorm(self.d_model)
        self.LayerNorm_att2 = nn.LayerNorm(self.d_model)
        self.LayerNorm_ffn = nn.LayerNorm(self.d_model)

    @staticmethod
    def create_causal_mask(seq_len):
        mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1)
        mask = mask.masked_fill(mask == 1, float("-inf"))
        return mask

    def forward(self, embed_input, cross_input, padding_mask=None):
        """
        Args:
        embed_input: Decoder input sequence [batch_size, seq_len, d_model]
        cross_input: Encoder output sequence [batch_size, encoder_seq_len, d_model]
        padding_mask: Additive padding mask for cross-attention, broadcastable to
            [batch_size, num_head, seq_len, encoder_seq_len]
        Returns:
        Tensor: Decoded output [batch_size, seq_len, d_model]

        The causal mask for self-attention is created internally.
        """
        batch_size, seq_len, _ = embed_input.size()

        assert (
            embed_input.size(-1) == self.d_model
        ), f"Input dimension {embed_input.size(-1)} doesn't match model dimension {self.d_model}"
        assert (
            cross_input.size(-1) == self.d_model
        ), "Encoder output dimension doesn't match model dimension"

        # Generate and expand causal mask for self-attention
        causal_mask = self.create_causal_mask(seq_len).to(
            embed_input.device
        )  # [seq_len, seq_len]
        causal_mask = causal_mask.unsqueeze(0).unsqueeze(1)  # [1, 1, seq_len, seq_len]

        ## First sublayer: self-attention
        # After embedding and positional encoding, the input sequence feeds into this attention sublayer
        # Or, the output of the previous decoder layer feeds into this attention sublayer
        att_sublayer1 = self.att(
            sequence=embed_input, key_value_states=None, att_mask=causal_mask
        )  # [batch_size, sequence_length, d_model]
        # apply dropout before layer normalization for each sublayer
        att_sublayer1 = self.dropout(att_sublayer1)
        # Residual layer normalization
        att_normalized1 = self.LayerNorm_att1(
            embed_input + att_sublayer1
        )  # [batch_size, sequence_length, d_model]

        ## Second sublayer: cross attention
        # Queries come from the output of the self-attention sublayer
        # Keys and values come from the final encoder output
        att_sublayer2 = self.cross_att(
            sequence=att_normalized1,
            key_value_states=cross_input,
            att_mask=padding_mask,
        )  # [batch_size, sequence_length, d_model]
        # apply dropout before layer normalization for each sublayer
        att_sublayer2 = self.dropout(att_sublayer2)
        # Residual layer normalization
        att_normalized2 = self.LayerNorm_att2(
            att_normalized1 + att_sublayer2
        )  # [batch_size, sequence_length, d_model]

        # Third sublayer: FFN
        ffn_sublayer = self.ffn(
            att_normalized2
        )  # [batch_size, sequence_length, d_model]
        ffn_sublayer = self.dropout(ffn_sublayer)
        ffn_normalized = self.LayerNorm_ffn(
            att_normalized2 + ffn_sublayer
        )  # [batch_size, sequence_length, d_model]

        return ffn_normalized

In [8]:
net = TransformerDecoder(d_model=512, d_ff=2048, num_head=8, dropout=0.1, bias=True)
print(net)

TransformerDecoder(
  (att): TransformerAttention(
    (q_proj): Linear(in_features=512, out_features=512, bias=True)
    (k_proj): Linear(in_features=512, out_features=512, bias=True)
    (v_proj): Linear(in_features=512, out_features=512, bias=True)
    (output_proj): Linear(in_features=512, out_features=512, bias=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (cross_att): TransformerAttention(
    (q_proj): Linear(in_features=512, out_features=512, bias=True)
    (k_proj): Linear(in_features=512, out_features=512, bias=True)
    (v_proj): Linear(in_features=512, out_features=512, bias=True)
    (output_proj): Linear(in_features=512, out_features=512, bias=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (ffn): FFN(
    (fc1): Linear(in_features=512, out_features=2048, bias=True)
    (fc2): Linear(in_features=2048, out_features=512, bias=True)
  )
  (dropout): Dropout(p=0.1, inplace=False)
  (LayerNorm_att1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  (LayerNorm_att2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  (LayerNorm_ffn): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
)


### Encoder-Decoder Stack
Implements the full stack of encoder and decoder layers from Section 3.1.

Key points:
- Multiple encoder and decoder layers (typically 6)
- The output of the final encoder layer feeds into every decoder layer
- Maintains residual connections throughout the stack

Implementation tips:
- Use nn.ModuleList for layer stacks
- Share encoder outputs across decoder layers
- Maintain consistent masking throughout
- Handle padding masks separately from causal masks
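
Since the attention module adds its mask to the scores, both mask types can be built as additive tensors and combined by broadcasting. The sketch below is illustrative rather than part of the original notebook: the helper names and `pad_id=0` are assumptions, so substitute your tokenizer's padding id.

```python
import torch


def make_padding_mask(token_ids, pad_id=0):
    """Additive mask [batch_size, 1, 1, seq_len]: 0 for real tokens, -inf for padding."""
    mask = torch.zeros(token_ids.shape, dtype=torch.float)
    mask = mask.masked_fill(token_ids == pad_id, float("-inf"))
    return mask[:, None, None, :]


def make_causal_mask(seq_len):
    """Additive mask [1, 1, seq_len, seq_len]: -inf strictly above the diagonal."""
    mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1)
    mask = mask.masked_fill(mask == 1, float("-inf"))
    return mask[None, None, :, :]


# pad_id=0 is an assumption for this sketch.
tokens = torch.tensor([[5, 7, 9, 0], [3, 4, 0, 0]])
padding_mask = make_padding_mask(tokens)        # [2, 1, 1, 4]
causal_mask = make_causal_mask(tokens.size(1))  # [1, 1, 4, 4]
combined = padding_mask + causal_mask           # broadcasts to [2, 1, 4, 4]
print(combined.shape)
print(combined[1, 0])
```

Keeping the two masks separate until the point of use makes it easy to apply only the padding mask in the encoder and in cross-attention, and the combined mask in decoder self-attention when target padding matters.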


In [9]:
class TransformerEncoderDecoder(nn.Module):
    """
    Encoder-Decoder stack of the Transformer
    Sublayers:  Encoder x 6
                Decoder x 6
    Args:
            num_layer: 6 number of encoder layers and of decoder layers
            d_model: 512 model hidden dimension
            d_ff: 2048 hidden dimension of the feed forward network
            num_head: 8 Number of attention heads.
            dropout:  0.1 dropout rate

            bias: Whether to include bias in linear projections.

    """

    def __init__(self, num_layer, d_model, d_ff, num_head, dropout=0.1, bias=True):
        super().__init__()
        self.num_layer = num_layer
        self.d_model = d_model
        self.d_ff = d_ff
        self.num_head = num_head
        self.dropout = dropout
        self.bias = bias

        # Encoder stack
        self.encoder_stack = nn.ModuleList(
            [
                TransformerEncoder(
                    d_model=self.d_model,
                    d_ff=self.d_ff,
                    num_head=self.num_head,
                    dropout=self.dropout,
                    bias=self.bias,
                )
                for _ in range(self.num_layer)
            ]
        )

        # Decoder stack
        self.decoder_stack = nn.ModuleList(
            [
                TransformerDecoder(
                    d_model=self.d_model,
                    d_ff=self.d_ff,
                    num_head=self.num_head,
                    dropout=self.dropout,
                    bias=self.bias,
                )
                for _ in range(self.num_layer)
            ]
        )

    def forward(self, embed_encoder_input, embed_decoder_input, padding_mask=None):
        # Process through all encoder layers first
        encoder_output = embed_encoder_input
        for encoder in self.encoder_stack:
            encoder_output = encoder(encoder_output, padding_mask)

        # Use final encoder output for all decoder layers
        decoder_output = embed_decoder_input
        for decoder in self.decoder_stack:
            decoder_output = decoder(decoder_output, encoder_output, padding_mask)

        return decoder_output

### Full Transformer
Combines all components into complete architecture:
- Input embeddings for source and target
- Positional encoding
- Encoder-decoder stack
- Final linear and softmax layer

Key points:
- Handles different vocabulary sizes for source/target
- Shifts decoder inputs for teacher forcing
- Projects outputs to target vocabulary size
- Applies log softmax for training stability

Implementation tips:
- Handle start tokens for decoder input
- Maintain separate embeddings for source/target
- Remember to scale embeddings
- Consider sharing embedding weights with output layer
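
Teacher forcing and the log-softmax output fit together as follows. This is a sketch under stated assumptions: the token ids (`bos_id=1`, `pad_id=0`, end-of-sequence `2`) and the random stand-in for the model output are illustrative and do not come from the original notebook. `nn.NLLLoss` is used because it expects log-probabilities, which is what a final `LogSoftmax` produces.

```python
import torch
import torch.nn as nn


def shift_right(tgt_tokens, bos_id):
    """Prepend bos_id and drop the last token: decoder input for teacher forcing."""
    bos = torch.full(
        (tgt_tokens.size(0), 1), bos_id, dtype=tgt_tokens.dtype, device=tgt_tokens.device
    )
    return torch.cat([bos, tgt_tokens[:, :-1]], dim=1)


bos_id, pad_id, vocab_size = 1, 0, 10  # illustrative ids
tgt = torch.tensor([[4, 5, 6, 2], [7, 2, 0, 0]])  # 2 marks end-of-sequence here
decoder_input = shift_right(tgt, bos_id)
print(decoder_input)
# tensor([[1, 4, 5, 6],
#         [1, 7, 2, 0]])

# Stand-in for the model's [batch_size, tgt_len, vocab_size] log-probabilities.
torch.manual_seed(0)
log_probs = torch.log_softmax(torch.randn(2, 4, vocab_size), dim=-1)

# ignore_index skips padded target positions when averaging the loss.
loss_fn = nn.NLLLoss(ignore_index=pad_id)
loss = loss_fn(log_probs.reshape(-1, vocab_size), tgt.reshape(-1))
print(loss.item() > 0)  # True
```

At position t the decoder sees the targets up to t-1 and is scored against target t, so the unshifted `tgt` serves as the label tensor.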

In [10]:
class Transformer(nn.Module):
    def __init__(
        self,
        num_layer,
        d_model,
        d_embed,
        d_ff,
        num_head,
        src_vocab_size,
        tgt_vocab_size,
        max_position_embeddings=512,
        dropout=0.1,
        bias=True,
    ):
        super().__init__()

        self.tgt_vocab_size = tgt_vocab_size

        # Source and target embeddings
        # EmbeddingWithProjection (as defined above) has no max_position_embeddings
        # argument and builds positional encodings on the fly, so it is not forwarded.
        self.src_embedding = EmbeddingWithProjection(
            vocab_size=src_vocab_size,
            d_embed=d_embed,
            d_model=d_model,
            dropout=dropout,
        )

        self.tgt_embedding = EmbeddingWithProjection(
            vocab_size=tgt_vocab_size,
            d_embed=d_embed,
            d_model=d_model,
            dropout=dropout,
        )

        # Encoder-Decoder stack
        self.encoder_decoder = TransformerEncoderDecoder(
            num_layer=num_layer,
            d_model=d_model,
            d_ff=d_ff,
            num_head=num_head,
            dropout=dropout,
            bias=bias,
        )

        # Output projection and softmax
        self.output_projection = nn.Linear(d_model, tgt_vocab_size)
        self.softmax = nn.LogSoftmax(dim=-1)

    def shift_target_right(self, tgt_tokens):
        # Shift target tokens right by prepending a start token (token id 0 is used as the start token here)
        batch_size, seq_len = tgt_tokens.size()

        # Create start token (zeros)
        start_tokens = torch.zeros(
            batch_size, 1, dtype=tgt_tokens.dtype, device=tgt_tokens.device
        )

        # Concatenate start token and remove last token
        shifted_tokens = torch.cat([start_tokens, tgt_tokens[:, :-1]], dim=1)

        return shifted_tokens

    def forward(self, src_tokens, tgt_tokens, padding_mask=None):
        """
        Args:
            src_tokens: source sequence [batch_size, src_len]
            tgt_tokens: target sequence [batch_size, tgt_len]
            padding_mask: additive padding mask for the source [batch_size, 1, 1, src_len]
                (0 for real tokens, -inf for padding)
        Returns:
            output: [batch_size, tgt_len, tgt_vocab_size] log probabilities
        """
        # Shift target tokens right for teacher forcing
        shifted_tgt_tokens = self.shift_target_right(tgt_tokens)

        # Embed source and target sequences
        src_embedding = self.src_embedding(src_tokens)
        tgt_embedding = self.tgt_embedding(shifted_tgt_tokens)

        # Pass through encoder-decoder stack
        decoder_output = self.encoder_decoder(
            embed_encoder_input=src_embedding,
            embed_decoder_input=tgt_embedding,
            padding_mask=padding_mask,
        )

        # Project to vocabulary size and apply log softmax
        logits = self.output_projection(decoder_output)
        log_probs = self.softmax(logits)

        return log_probs