In [1]:
import math
import numpy as np
import torch
import torch.nn as nn

## Transformer

- Blocks
    - Input embedding
    - Position encoding
    - Attention
    - Layer normalization
    

### Input Embedding

## Self Attention

Self-attention is a key component of transformer architectures that allows a model to weigh the importance of different parts of the input sequence when processing each element.

### How it Works

1. **Input Transformation**
   - For each input element, create three vectors:
     - **Query (Q)**: What the current element is looking for
     - **Key (K)**: What this element offers to others
     - **Value (V)**: The actual content of the element
   - These are created by multiplying the input with learned weight matrices (WQ, WK, WV)

2. **Attention Score Calculation**
   ```python
   scores = (Q × K^T) / sqrt(d_k)
   ```
   - Multiply Query with Key transpose to get compatibility scores
   - Scale by sqrt(d_k) to prevent softmax from having extremely small gradients
   - d_k is the dimension of the key vectors

3. **Attention Weights**
   - Apply softmax to scores to get probabilities
   - These weights determine how much each position will focus on other positions

4. **Final Output**
   ```python
   attention = softmax(scores) × V
   ```
   - Multiply attention weights with Values
   - This produces the final attention-weighted representation

### Benefits

- **Global Context**: Each position can attend to all other positions
- **Parallel Processing**: All attention computations can be done simultaneously
- **No Sequential Bottleneck**: Unlike RNNs, information can flow directly between any positions
- **Interpretable**: Attention weights show which parts of input are important for each output

In [47]:
class SelfAttention(nn.Module):
    def __init__(self, d_model: int):
        super(SelfAttention, self).__init__()
        self.d_model = d_model
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)

    def attention(
        self, Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor
    ) -> torch.Tensor:
        d_k = Q.shape[-1]
        attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
        attn_probs = torch.softmax(attn_scores, dim=1)
        attention_output = torch.matmul(attn_probs, V)

        return attention_output

    def forward(
        self, Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor
    ) -> torch.Tensor:
        Q_prime = self.W_q(Q)
        K_prime = self.W_k(K)
        V_prime = self.W_v(V)

        return self.attention(Q_prime, K_prime, V_prime)


# Create test inputs
batch_size = 1
seq_length = 4
d_model = 8

# Create random input tensors
x = torch.randn(batch_size, seq_length, d_model)

# Initialize the self-attention module
self_attention = SelfAttention(d_model)

# Pass the same tensor as Q, K, and V (self-attention)
output = self_attention(x, x, x)

# Print shapes to verify
print(f"Input shape: {x.shape}")
print(f"Output shape: {output.shape}")

# Print a sample of the output
print("\nSample of output:")
print(output[0][0])  # First 5 values of the first sequence in the first batch

Input shape: torch.Size([1, 4, 8])
Output shape: torch.Size([1, 4, 8])

Sample of output:
tensor([-0.2497, -0.2451,  0.0241, -0.4867,  0.1290,  0.2022,  0.1524, -0.3079],
       grad_fn=<SelectBackward0>)


## Multi-head Attention

Multi-head attention is an enhancement to the basic attention mechanism that allows the model to capture different types of relationships between elements simultaneously.

### How Multi-Head Attention Works

1. **Split Into Heads**
   - Instead of one attention operation, perform multiple in parallel
   - Split the input embedding dimension (d_model) into h heads
   - Each head works with dimension d_k = d_model/h

2. **Per-Head Processing**
   ```python
   # For each head i:
   head_i = Attention(Q_i, K_i, V_i)
   where:
   Q_i = W_Q_i × X
   K_i = W_K_i × X
   V_i = W_V_i × X
   ```
   - Each head has its own learnable weight matrices
   - Each head can specialize in different relationship patterns

3. **Combine Heads**
   ```python
   MultiHead(Q, K, V) = Concat(head_1, ..., head_h) × W_O
   ```
   - Concatenate outputs from all heads
   - Project back to original dimension using W_O

### Benefits of Multi-Head Attention

1. **Diverse Feature Learning**
   - Different heads can learn different aspects:
     - Head 1 might focus on syntactic relationships
     - Head 2 might capture semantic similarities
     - Head 3 might learn positional patterns

2. **Parallel Processing**
   - All heads operate independently
   - Computation can be done in parallel using matrix operations

3. **Enhanced Representation**
   - Combines multiple views of the relationships
   - More robust than single-head attention
   - Can capture both fine and coarse-grained patterns

### Typical Configuration

- Common settings in transformer models:
  - 8 attention heads (h=8)
  - If d_model = 512, each head operates on d_k = 64 dimensions
  - Total computation remains similar to single-head attention
  - Final output dimension matches input (d_model)

In [54]:
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model: int, num_heads: int):
        super(MultiHeadAttention, self).__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def attention(
        self, Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor
    ) -> torch.Tensor:
        d_k = Q.shape[-1]
        attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
        attn_probs = torch.softmax(attn_scores, dim=1)
        attention_output = torch.matmul(attn_probs, V)

        return attention_output

    def split_heads(self, x: torch.Tensor) -> torch.Tensor:
        batch_size, seq_len, d_model = x.size()
        x = x.view(batch_size, seq_len, self.num_heads, self.d_k)
        x = x.transpose(1, 2)

        return x

    def combine_heads(self, x: torch.Tensor) -> torch.Tensor:
        batch_size, num_heads, seq_len, d_k = x.size()
        x = x.transpose(1, 2)  # (batch_size, seq_len, num_heads, d_k)
        x = x.contiguous().view(batch_size, seq_len, self.d_model)

        return x

    def forward(
        self, Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor
    ) -> torch.Tensor:
        Q_prime = self.W_q(Q)  # (batch_size, seq_len, d_model)
        K_prime = self.W_k(K)  # (batch_size, seq_len, d_model)
        V_prime = self.W_v(V)  # (batch_size, seq_len, d_model)

        Q_prime = self.split_heads(Q_prime)  # (batch_size, num_heads, seq_len, d_k)
        K_prime = self.split_heads(K_prime)  # (batch_size, num_heads, seq_len, d_k)
        V_prime = self.split_heads(V_prime)  # (batch_size, num_heads, seq_len, d_k)

        attention_output = self.attention(Q_prime, K_prime, V_prime)

        concat_output = self.combine_heads(attention_output)
        output = self.W_o(concat_output)  # (batch_size, seq_len, d_model)

        return output
    
# Create test inputs
batch_size = 1
seq_length = 4
d_model = 8

# Create random input tensors
x = torch.randn(batch_size, seq_length, d_model)

# Initialize the self-attention module
multi_head_attention = MultiHeadAttention(d_model, num_heads=4)

# Pass the same tensor as Q, K, and V (self-attention)
output = multi_head_attention(x, x, x)

# Print shapes to verify
print(f"Input shape: {x.shape}")
print(f"Output shape: {output.shape}")

# Print a sample of the output
print("\nSample of output:")
print(output[0][0])  # First 5 values of the first sequence in the first batch

Input shape: torch.Size([1, 4, 8])
Output shape: torch.Size([1, 4, 8])

Sample of output:
tensor([ 0.4102, -0.0562, -0.1898,  0.2277, -0.0764, -0.0449, -0.3150,  0.6217],
       grad_fn=<SelectBackward0>)
