# Attention Mechanisms in Deep Learning

## Introduction

In this notebook, we will:

- **Introduce** the concept of Attention Mechanisms and understand why they are pivotal in modern deep learning models.
- **Explore** the various types of attention, including Soft Attention, Hard Attention, and Self-Attention.
- **Implement** key attention mechanisms such as Scaled Dot-Product Attention and Additive Attention using PyTorch.
- **Provide** resources for further reading to deepen your understanding.

**Resources for Further Reading:**

- [Attention Mechanism Explained](https://towardsdatascience.com/attention-mechanism-explained-8f96b26ebae)
- [The Annotated Transformer](http://nlp.seas.harvard.edu/2018/04/03/attention.html)

**Prerequisites:**

- Familiarity with Python and PyTorch.
- Understanding of neural network fundamentals, especially sequence models like RNNs.

**Note:** Attention mechanisms have revolutionized the field of Natural Language Processing (NLP) and are integral to architectures like Transformers. They help models focus on relevant parts of the input sequence, addressing limitations inherent in traditional RNNs.

## Why Attention? Overcoming the Limitations of RNNs

Recurrent Neural Networks (RNNs) are powerful for modeling sequential data, but they have notable limitations:

- **Long-Term Dependencies:** RNNs struggle to capture dependencies between distant elements in a sequence due to issues like vanishing and exploding gradients.
- **Sequential Processing:** RNNs process data sequentially, making it difficult to parallelize computations.
- **Fixed-Size Context:** In encoder-decoder RNNs, the whole input sequence must be compressed into a fixed-size hidden state, which acts as a bottleneck on how much information about the input can be retained.

**Attention Mechanisms** address these challenges by allowing models to dynamically focus on different parts of the input sequence when producing each element of the output. This leads to better performance, especially in tasks requiring the integration of information from various parts of the input.

## Types of Attention

### 1. Soft vs. Hard Attention

- **Soft Attention:**
  - **Deterministic and Differentiable:** Allows the model to consider all parts of the input with varying degrees of importance.
  - **Weighted Sum:** Computes a weighted average of the input features.
  - **Backpropagation-Friendly:** Can be trained end-to-end using gradient-based optimization.

- **Hard Attention:**
  - **Stochastic and Non-Differentiable:** Selects specific parts of the input, often requiring reinforcement learning techniques for training.
  - **Discrete Selection:** Chooses exact elements or regions to focus on.
  - **Less Common:** Due to training difficulties, hard attention is less frequently used in practice.
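To make the contrast concrete, here is a small sketch with made-up scores and values (all sizes are assumptions for illustration). Soft attention takes the softmax-weighted average of every input; hard attention samples a single position from the same distribution and copies that input.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Toy setup (assumed sizes): 4 input positions, each a 3-dim feature vector
values = torch.randn(4, 3)
scores = torch.tensor([2.0, 0.5, -1.0, 0.1])  # unnormalized relevance scores

# Soft attention: deterministic, differentiable weighted average of all inputs
weights = F.softmax(scores, dim=0)
soft_context = weights @ values  # shape (3,)

# Hard attention: sample ONE position from the same distribution and copy it.
# The index is sampled, so no gradient flows back through the choice to `scores`.
idx = torch.multinomial(weights, num_samples=1).item()
hard_context = values[idx]  # shape (3,)

print("weights:", weights)
print("soft context:", soft_context)
print(f"hard context (position {idx}):", hard_context)
```

Training the sampling step of hard attention is what typically calls for reinforcement-learning-style estimators, whereas the soft version can be trained by ordinary backpropagation.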

### 2. Self-Attention

- **Definition:** A mechanism where different positions of a single sequence are related to each other to compute a representation of the sequence.
- **Usage:** Fundamental to Transformer architectures, enabling the model to capture dependencies regardless of their distance in the sequence.
- **Benefits:**
  - **Parallelization:** Unlike RNNs, self-attention allows for parallel processing of sequence elements.
  - **Long-Range Dependencies:** Effectively captures relationships between distant elements in the sequence.
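The parallelization claim can be checked directly. The sketch below uses random, hypothetical projection matrices `W_q`, `W_k`, `W_v` on a single sequence: computing the output one position at a time gives the same result as one batched matrix expression, because no position's output depends on another position's output.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
T, d = 5, 8                      # assumed sequence length and feature size
x = torch.randn(T, d)            # one input sequence
# Hypothetical projection matrices (scaled to keep the scores moderate)
W_q, W_k, W_v = (torch.randn(d, d) / d ** 0.5 for _ in range(3))

# Queries, keys and values all come from the SAME sequence x
Q, K, V = x @ W_q, x @ W_k, x @ W_v

# Position by position: each output row only needs Q[i] and all of K and V,
# so the iterations are independent of each other (unlike an RNN's recurrence)
rows = []
for i in range(T):
    w = F.softmax(Q[i] @ K.T / d ** 0.5, dim=-1)  # weights over all T positions
    rows.append(w @ V)
loop_out = torch.stack(rows)

# The same computation for every position at once, as two matrix products
matrix_out = F.softmax(Q @ K.T / d ** 0.5, dim=-1) @ V

print(torch.allclose(loop_out, matrix_out, atol=1e-5))  # True
```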

## Implementation

We will implement two primary attention mechanisms:

1. **Scaled Dot-Product Attention**
2. **Additive Attention**

Both implementations will be in PyTorch.

### 1. Scaled Dot-Product Attention

**Overview:**

Scaled Dot-Product Attention scores each query-key pair by their dot product, divides the scores by $\sqrt{d_k}$, applies a softmax over the keys to turn the scores into weights, and then uses these weights to form a weighted sum of the values.

**Formula:**

$$
\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V
$$

where:
- $Q$ = Query matrix
- $K$ = Key matrix
- $V$ = Value matrix
- $d_k$ = Dimension of the keys (and queries)
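Why divide by $\sqrt{d_k}$? If the components of a query and a key are independent with mean 0 and variance 1, their dot product has variance $d_k$. For large $d_k$ the raw scores therefore become large and push the softmax into regions with very small gradients, which is the motivation given by Vaswani et al. in *Attention Is All You Need*. The simulation below (with an assumed $d_k = 64$) shows that dividing by $\sqrt{d_k}$ brings the variance back to about 1.

```python
import torch

torch.manual_seed(0)
d_k = 64            # assumed key dimension
n = 10_000          # number of random query/key pairs

# Components drawn independently with mean 0 and variance 1
q = torch.randn(n, d_k)
k = torch.randn(n, d_k)

dots = (q * k).sum(dim=-1)       # n raw dot products
scaled = dots / d_k ** 0.5       # the same dot products divided by sqrt(d_k)

print(f"variance of q.k            : {dots.var().item():.1f}")    # roughly d_k = 64
print(f"variance of q.k / sqrt(d_k): {scaled.var().item():.2f}")  # roughly 1
```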

**Implementation:**


```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Define attention mechanisms
class ScaledDotProductAttention(nn.Module):
    def __init__(self, head_dim):
        super(ScaledDotProductAttention, self).__init__()
        self.scale = head_dim ** 0.5  # sqrt(d_k), where d_k is the query/key dimension (head_dim per head)
        
    def forward(self, Q, K, V, mask=None):
        """
        Q: [..., seq_len_q, d_k], e.g. [batch_size, num_heads, seq_len_q, head_dim]
        K: [..., seq_len_k, d_k]
        V: [..., seq_len_k, d_v]
        mask: optional, broadcastable to [..., seq_len_q, seq_len_k]; entries equal to 0 are masked out
        """
        scores = torch.matmul(Q, K.transpose(-2, -1)) / self.scale  # [..., seq_len_q, seq_len_k]
        if mask is not None:
            # A large negative score becomes (numerically) zero weight after the softmax
            scores = scores.masked_fill(mask == 0, -1e9)
        attn = F.softmax(scores, dim=-1)  # [..., seq_len_q, seq_len_k]; each row sums to 1
        output = torch.matmul(attn, V)  # [..., seq_len_q, d_v]
        return output, attn
```


**Example Usage:**

```python
# Example parameters
batch_size = 2
seq_len_q = 3
seq_len_k = 4
d_k = 5
d_v = 6

# Random tensors for Q, K, V
Q = torch.randn(batch_size, seq_len_q, d_k)
K = torch.randn(batch_size, seq_len_k, d_k)
V = torch.randn(batch_size, seq_len_k, d_v)

# Initialize attention module
attention = ScaledDotProductAttention(d_k)

# Forward pass
output, attn_weights = attention(Q, K, V)

print("Output shape:", output.shape)          # Expected: (2, 3, 6)
print("Attention weights shape:", attn_weights.shape)  # Expected: (2, 3, 4)
```

```text
Output shape: torch.Size([2, 3, 6])
Attention weights shape: torch.Size([2, 3, 4])
```
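The example above does not use the `mask` argument. The following self-contained sketch repeats the same computation in a standalone function, `sdpa` (a hypothetical name, not part of the notebook), adds a padding mask that hides the last key, and cross-checks the result against PyTorch's built-in `F.scaled_dot_product_attention`, assuming PyTorch 2.0 or later.

```python
import torch
import torch.nn.functional as F

def sdpa(Q, K, V, mask=None):
    """Same steps as ScaledDotProductAttention above; mask entries equal to 0 are hidden."""
    scores = Q @ K.transpose(-2, -1) / Q.size(-1) ** 0.5
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)
    attn = F.softmax(scores, dim=-1)
    return attn @ V, attn

torch.manual_seed(0)
B, H, Tq, Tk, d_k, d_v = 2, 2, 3, 4, 5, 6   # assumed toy sizes
Q = torch.randn(B, H, Tq, d_k)
K = torch.randn(B, H, Tk, d_k)
V = torch.randn(B, H, Tk, d_v)

# Padding mask: treat the last key of every sequence as padding.
# Shape [B, 1, 1, Tk] broadcasts over all heads and all query positions.
mask = torch.ones(B, 1, 1, Tk, dtype=torch.bool)
mask[..., -1] = False

out, attn = sdpa(Q, K, V, mask)
print(attn[0, 0])  # last column is 0: no query attends to the padded key

# Cross-check against PyTorch's built-in; for a boolean attn_mask,
# True marks positions that are allowed to take part in attention.
ref = F.scaled_dot_product_attention(Q, K, V, attn_mask=mask)
print(torch.allclose(out, ref, atol=1e-5))  # True
```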


### 2. Additive Attention

**Overview:**

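Additive Attention, introduced by Bahdanau et al., scores each query-key pair with a small feed-forward network: the query and the key are each linearly projected, the projections are summed (equivalent to a single linear layer applied to their concatenation), passed through a non-linear activation (usually $\tanh$), and reduced to a scalar score by a learned vector $\mathbf{v}$.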

**Formula:**

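$$
\text{Attention}(Q, K, V) = \text{softmax}(S)\,V, \qquad S_{ij} = \text{score}(\mathbf{q}_i, \mathbf{k}_j)
$$

where $\mathbf{q}_i$ is the $i$-th row of $Q$, $\mathbf{k}_j$ is the $j$-th row of $K$, the softmax is taken over $j$ (each row of $S$), and

$$
\text{score}(\mathbf{q}_i, \mathbf{k}_j) = \mathbf{v}^T \tanh(\mathbf{W}_q \mathbf{q}_i + \mathbf{W}_k \mathbf{k}_j)
$$

with learned matrices $\mathbf{W}_q$, $\mathbf{W}_k$ and a learned vector $\mathbf{v}$.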
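Describing the score as a network applied to the "concatenation" of query and key is equivalent to the sum $\mathbf{W}_q \mathbf{q}_i + \mathbf{W}_k \mathbf{k}_j$: a single linear layer applied to $[\mathbf{q}_i; \mathbf{k}_j]$ splits into one block of weight columns acting on the query and one acting on the key. A quick numerical check with assumed sizes (the `AdditiveAttention` module below uses two separate `nn.Linear` layers, whose two biases simply add up to one effective bias):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_q, d_k, d_attn = 5, 7, 10   # assumed sizes (query and key dims may differ)
q = torch.randn(d_q)
k = torch.randn(d_k)

# One linear layer applied to the concatenation [q; k] ...
W = nn.Linear(d_q + d_k, d_attn)
concat_form = W(torch.cat([q, k]))

# ... equals two separate projections W_q q + W_k k (plus the shared bias),
# where W_q and W_k are column blocks of W's weight matrix
W_q = W.weight[:, :d_q]
W_k = W.weight[:, d_q:]
split_form = W_q @ q + W_k @ k + W.bias

v = torch.randn(d_attn)             # stands in for the learned vector v
score = v @ torch.tanh(split_form)  # scalar score(q, k)

print(torch.allclose(concat_form, split_form, atol=1e-6))  # True
print(score)
```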

**Implementation:**

```python
class AdditiveAttention(nn.Module):
    def __init__(self, d_q, d_k, d_attn):
        super(AdditiveAttention, self).__init__()
        self.W_q = nn.Linear(d_q, d_attn)   # projects queries into the attention space
        self.W_k = nn.Linear(d_k, d_attn)   # projects keys into the attention space
        self.v = nn.Linear(d_attn, 1, bias=False)  # learned vector v: reduces each energy vector to a score

    def forward(self, Q, K, V, mask=None):
        """
        Q: Queries shape (batch_size, seq_len_q, d_q)
        K: Keys shape (batch_size, seq_len_k, d_k)
        V: Values shape (batch_size, seq_len_k, d_v)
        mask: Optional, broadcastable to (batch_size, seq_len_q, seq_len_k); entries equal to 0 are masked out
        """
        # Insert singleton dimensions so every query is paired with every key by broadcasting
        # Q: (batch_size, seq_len_q, 1, d_q)
        # K: (batch_size, 1, seq_len_k, d_k)
        Q_expanded = Q.unsqueeze(2)
        K_expanded = K.unsqueeze(1)
        
        # Apply linear layers and activation
        energy = torch.tanh(self.W_q(Q_expanded) + self.W_k(K_expanded))  # (batch_size, seq_len_q, seq_len_k, d_attn)
        scores = self.v(energy).squeeze(-1)  # (batch_size, seq_len_q, seq_len_k)
        
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        
        attn = F.softmax(scores, dim=-1)  # (batch_size, seq_len_q, seq_len_k)
        output = torch.matmul(attn, V)     # (batch_size, seq_len_q, d_v)
        return output, attn
```


**Example Usage:**

```python
# Example parameters
batch_size = 2
seq_len_q = 3
seq_len_k = 4
d_q = 5
d_k = 5
d_v = 6
d_attn = 10

# Random tensors for Q, K, V
Q = torch.randn(batch_size, seq_len_q, d_q)
K = torch.randn(batch_size, seq_len_k, d_k)
V = torch.randn(batch_size, seq_len_k, d_v)

# Initialize attention module
additive_attention = AdditiveAttention(d_q, d_k, d_attn)

# Forward pass
output, attn_weights = additive_attention(Q, K, V)

print("Output shape:", output.shape)          # Expected: (2, 3, 6)
print("Attention weights shape:", attn_weights.shape)  # Expected: (2, 3, 4)
```

```text
Output shape: torch.Size([2, 3, 6])
Attention weights shape: torch.Size([2, 3, 4])
```


## Self-Attention

**Overview:**

Self-Attention allows a sequence to interact with itself (i.e., different positions within the same sequence) to compute a representation of the sequence. This mechanism is pivotal in Transformer models.

**Key Components:**

- **Queries, Keys, Values:** Derived from the same input sequence.
- **Multi-Head Attention:** Extends self-attention by running multiple attention mechanisms in parallel, allowing the model to focus on different representation subspaces.

**Implementation Example:**

```python
class SelfAttention(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super(SelfAttention, self).__init__()
        assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
        
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        
        # Queries, keys and values are all projections of the same input x
        self.q_linear = nn.Linear(embed_dim, embed_dim)
        self.k_linear = nn.Linear(embed_dim, embed_dim)
        self.v_linear = nn.Linear(embed_dim, embed_dim)
        
        self.attention = ScaledDotProductAttention(self.head_dim)
        self.fc_out = nn.Linear(embed_dim, embed_dim)
        
    def forward(self, x, mask=None):
        """
        x: [B, T, E]
        mask: optional padding mask [B, T] (1 = real token, 0 = padding)
        """
        B, T, E = x.size()
        
        # Project and split the embedding into heads: [B, T, E] -> [B, num_heads, T, head_dim]
        Q = self.q_linear(x).view(B, T, self.num_heads, self.head_dim).transpose(1,2)
        K = self.k_linear(x).view(B, T, self.num_heads, self.head_dim).transpose(1,2)
        V = self.v_linear(x).view(B, T, self.num_heads, self.head_dim).transpose(1,2)
        
        if mask is not None:
            # [B, T] -> [B, 1, 1, T]: hide padded keys for every head and every query
            mask = mask.unsqueeze(1).unsqueeze(2)
        
        attended, attn_weights = self.attention(Q, K, V, mask)  # [B, num_heads, T, head_dim]
        # Merge the heads back: [B, num_heads, T, head_dim] -> [B, T, E]
        attended = attended.transpose(1,2).contiguous().view(B, T, E)
        output = self.fc_out(attended)
        return output, attn_weights
```


```python
# Example parameters
batch_size = 2
seq_len = 5
embed_dim = 16
num_heads = 4

# Random input tensor
x = torch.randn(batch_size, seq_len, embed_dim)

# Initialize self-attention module
self_attention = SelfAttention(embed_dim, num_heads)

# Forward pass
output, attn_weights = self_attention(x)

print("Output shape:", output.shape)  # Expected: (2, 5, 16)
print("Attention weights shape:", attn_weights.shape)  # Expected: (2, 4, 5, 5)
```

```text
Output shape: torch.Size([2, 5, 16])
Attention weights shape: torch.Size([2, 4, 5, 5])
```
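Because the queries, keys, and values all come from the same sequence and nothing in the computation depends on position, self-attention is permutation-equivariant: shuffling the input positions simply shuffles the output positions. The sketch below checks this with PyTorch's built-in `nn.MultiheadAttention` (used instead of the `SelfAttention` class so the block is self-contained). This property is why Transformers add positional information to their inputs.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
B, T, E, H = 2, 5, 16, 4     # same sizes as the example above
mha = nn.MultiheadAttention(E, H, batch_first=True)  # PyTorch's built-in multi-head attention
mha.eval()

x = torch.randn(B, T, E)
perm = torch.randperm(T)     # a random reordering of the positions
xp = x[:, perm]

with torch.no_grad():
    out, _ = mha(x, x, x)          # self-attention: query = key = value = x
    out_perm, _ = mha(xp, xp, xp)  # the same, on the shuffled sequence

# Shuffling the input positions just shuffles the output positions the same way
print(torch.allclose(out[:, perm], out_perm, atol=1e-5))  # True
```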


## Multi-Head Attention

Multi-head attention lets the model attend to different parts of the input sequence simultaneously by splitting the embedding space into multiple heads, each performing its own attention computation. This gives the model several independent views of the contextual relationships within the data. Unlike the `SelfAttention` class above, `MultiHeadAttention` takes separate `query`, `key`, and `value` inputs, so it can be used both for self-attention and for cross-attention between two sequences.
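Splitting into heads is just a reshape. The sketch below (sizes assumed) applies the same `view`/`transpose` used in the `SelfAttention` class above, confirms that each head works on its own contiguous slice of the embedding, and shows that merging the heads restores the original tensor exactly.

```python
import torch

torch.manual_seed(0)
B, T, E, H = 2, 5, 16, 4   # assumed sizes; E must be divisible by H
D = E // H                 # per-head dimension (head_dim)
x = torch.randn(B, T, E)

# Split: [B, T, E] -> [B, T, H, D] -> [B, H, T, D]
heads = x.view(B, T, H, D).transpose(1, 2)

# Head h sees a contiguous slice of the embedding: features h*D .. (h+1)*D - 1
h = 2
print(torch.equal(heads[:, h], x[..., h * D:(h + 1) * D]))  # True

# Merge: [B, H, T, D] -> [B, T, H, D] -> [B, T, E]. The transpose makes the
# tensor non-contiguous, so .contiguous() is needed before .view()
merged = heads.transpose(1, 2).contiguous().view(B, T, E)
print(torch.equal(merged, x))  # True
```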



```python
class MultiHeadAttention(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super(MultiHeadAttention, self).__init__()
        assert embed_dim % num_heads == 0, "Embedding dimension must be divisible by number of heads"
        
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        
        # Linear layers for Q, K, V
        self.q_linear = nn.Linear(embed_dim, embed_dim)
        self.k_linear = nn.Linear(embed_dim, embed_dim)
        self.v_linear = nn.Linear(embed_dim, embed_dim)
        
        self.attention = ScaledDotProductAttention(self.head_dim)
        self.fc_out = nn.Linear(embed_dim, embed_dim)
        
    def forward(self, query, key, value, mask=None):
        """
        query: [batch_size, query_len, embed_dim]
        key: [batch_size, key_len, embed_dim]
        value: [batch_size, key_len, embed_dim]
        mask: optional, broadcastable to [batch_size, num_heads, query_len, key_len]; entries equal to 0 are masked out
        """
        B, Tq, E = query.size()
        Bk, Tk, Ek = key.size()
        
        assert B == Bk and E == Ek, "Batch size and embedding dimensions must match between query and key/value."
        
        # Linear projections, split into heads
        Q = self.q_linear(query).view(B, Tq, self.num_heads, self.head_dim).transpose(1,2)  # [B, num_heads, Tq, head_dim]
        K = self.k_linear(key).view(B, Tk, self.num_heads, self.head_dim).transpose(1,2)    # [B, num_heads, Tk, head_dim]
        V = self.v_linear(value).view(B, Tk, self.num_heads, self.head_dim).transpose(1,2)  # [B, num_heads, Tk, head_dim]
        
        # Apply attention
        attended, attn_weights = self.attention(Q, K, V, mask)  # [B, num_heads, Tq, head_dim]
        
        # Concatenate heads
        attended = attended.transpose(1,2).contiguous().view(B, Tq, E)  # [B, Tq, E]
        
        # Final linear layer
        output = self.fc_out(attended)  # [B, Tq, E]
        
        return output, attn_weights
```


```python
# Example parameters
batch_size = 2
query_len = 5
key_len = 5
embed_dim = 16
num_heads = 4

# Random input tensors
query = torch.randn(batch_size, query_len, embed_dim)  # Shape: [2, 5, 16]
key = torch.randn(batch_size, key_len, embed_dim)      # Shape: [2, 5, 16]
value = torch.randn(batch_size, key_len, embed_dim)    # Shape: [2, 5, 16]

# Optional mask (set to None for simplicity)
mask = None  # Alternatively, define a mask tensor with shape [batch_size, num_heads, query_len, key_len]

# Initialize multi-head attention module
multi_head_attn = MultiHeadAttention(embed_dim, num_heads)

# Forward pass
output, attn_weights = multi_head_attn(query, key, value, mask)

print("Output shape:", output.shape)            # Expected: [2, 5, 16]
print("Attention weights shape:", attn_weights.shape)  # Expected: [2, 4, 5, 5]
```

```text
Output shape: torch.Size([2, 5, 16])
Attention weights shape: torch.Size([2, 4, 5, 5])
```


## Practical Example: Implementing Attention in a Sequence-to-Sequence Model

To solidify our understanding, let's build a small character-level sequence-to-sequence (Seq2Seq) model for a toy translation task. It is a Transformer-style encoder-decoder assembled from the `MultiHeadAttention` module above, which in turn uses Scaled Dot-Product Attention.

**Note:** This is a simplified example for educational purposes.

### 1. Preparing the Data

For demonstration, we'll use a handful of hard-coded English-French sentence pairs. In practice, you would train on a real parallel corpus of English-French sentence pairs.

```python
source_sentences = [
    "hello",
    "how are you",
    "good morning",
    "good night",
    "thank you"
]

target_sentences = [
    "bonjour",
    "comment ça va",
    "bonjour",
    "bonne nuit",
    "merci"
]

# Build character-level vocabularies: every distinct character (including space) is a token
source_vocab = sorted(list(set(" ".join(source_sentences))))
target_vocab = sorted(list(set(" ".join(target_sentences))))

# Add special tokens (indices 0, 1, 2)
source_vocab = ['<pad>', '<sos>', '<eos>'] + source_vocab
target_vocab = ['<pad>', '<sos>', '<eos>'] + target_vocab

source_char2idx = { ch:i for i,ch in enumerate(source_vocab) }
source_idx2char = { i:ch for i,ch in enumerate(source_vocab) }

target_char2idx = { ch:i for i,ch in enumerate(target_vocab) }
target_idx2char = { i:ch for i,ch in enumerate(target_vocab) }
```


```python
from torch.utils.data import DataLoader, Dataset

# Convert sentences to indices
def encode_sentence(sentence, char2idx):
    return [char2idx['<sos>']] + [char2idx[ch] for ch in sentence] + [char2idx['<eos>']]

encoded_sources = [encode_sentence(s, source_char2idx) for s in source_sentences]
encoded_targets = [encode_sentence(s, target_char2idx) for s in target_sentences]

print("Encoded Sources:", encoded_sources)
print("Encoded Targets:", encoded_targets)

# Define maximum sequence lengths
MAX_SOURCE_LEN = max(len(s) for s in encoded_sources)
MAX_TARGET_LEN = max(len(t) for t in encoded_targets)

# Pad sequences
def pad_sequence(seq, max_len, pad_idx):
    return seq + [pad_idx] * (max_len - len(seq))

padded_sources = [pad_sequence(s, MAX_SOURCE_LEN, source_char2idx['<pad>']) for s in encoded_sources]
padded_targets = [pad_sequence(s, MAX_TARGET_LEN, target_char2idx['<pad>']) for s in encoded_targets]

# Convert to tensors
source_tensor = torch.tensor(padded_sources, dtype=torch.long)
target_tensor = torch.tensor(padded_targets, dtype=torch.long)


# Create a simple Dataset
class TranslationDataset(Dataset):
    def __init__(self, src, tgt):
        self.src = src
        self.tgt = tgt
        
    def __len__(self):
        return len(self.src)
    
    def __getitem__(self, idx):
        return self.src[idx], self.tgt[idx]
```

```text
Encoded Sources: [[1, 8, 6, 11, 11, 14, 2], [1, 8, 14, 18, 3, 4, 15, 6, 3, 19, 14, 17, 2], [1, 7, 14, 14, 5, 3, 12, 14, 15, 13, 9, 13, 7, 2], [1, 7, 14, 14, 5, 3, 13, 9, 7, 8, 16, 2], [1, 16, 8, 4, 13, 10, 3, 19, 14, 17, 2]]
Encoded Targets: [[1, 5, 12, 11, 9, 12, 15, 13, 2], [1, 6, 12, 10, 10, 7, 11, 14, 3, 17, 4, 3, 16, 4, 2], [1, 5, 12, 11, 9, 12, 15, 13, 2], [1, 5, 12, 11, 11, 7, 3, 11, 15, 8, 14, 2], [1, 10, 7, 13, 6, 8, 2]]
```


```python
# Wrap the tensors in a Dataset and batch them with a DataLoader
dataset = TranslationDataset(source_tensor, target_tensor)
dataloader = DataLoader(dataset, batch_size=2, shuffle=True)
```


### 2. Defining the Encoder

The encoder turns the source sequence into contextualized representations, one per position. The decoder's cross-attention later uses these representations as its keys and values.

#### PositionwiseFeedForward 

The position-wise feed-forward network applies the same two linear transformations, with a ReLU activation in between, to each sequence position independently. This adds non-linear processing of the features at every position on top of the attention output.
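"Position-wise" means each position is transformed on its own. Because `nn.Linear` acts on the last dimension, applying the network to a whole `[B, T, E]` tensor is the same as applying it to each position separately. The check below uses an `nn.Sequential` stand-in with the same layer layout as the `PositionwiseFeedForward` class (sizes assumed).

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
B, T, E, FF = 2, 5, 16, 64   # assumed sizes
# Same layout as PositionwiseFeedForward: Linear -> ReLU -> Linear
ffn = nn.Sequential(nn.Linear(E, FF), nn.ReLU(), nn.Linear(FF, E))

x = torch.randn(B, T, E)
full = ffn(x)   # nn.Linear acts on the last dimension, so all positions are processed at once

# Applying the network to each position separately gives the same result
per_position = torch.stack([ffn(x[:, t]) for t in range(T)], dim=1)
print(torch.allclose(full, per_position, atol=1e-6))  # True
```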

```python
class PositionwiseFeedForward(nn.Module):
    def __init__(self, embed_dim, ff_dim):
        super(PositionwiseFeedForward, self).__init__()
        self.fc1 = nn.Linear(embed_dim, ff_dim)
        self.fc2 = nn.Linear(ff_dim, embed_dim)
        self.relu = nn.ReLU()
        
    def forward(self, x):
        return self.fc2(self.relu(self.fc1(x)))
```

```python
class Encoder(nn.Module):
    def __init__(self, input_dim, embed_dim, num_heads, ff_dim, num_layers):
        super(Encoder, self).__init__()
        self.embedding = nn.Embedding(input_dim, embed_dim)
        self.layers = nn.ModuleList([EncoderLayer(embed_dim, num_heads, ff_dim) for _ in range(num_layers)])
        self.fc_out = nn.Linear(embed_dim, embed_dim)
        
    def forward(self, x, mask=None):
        embedded = self.embedding(x)  # [B, T, E]; token embeddings only, no positional encoding is added
        for layer in self.layers:
            embedded = layer(embedded, mask)  # [B, T, E]
        output = self.fc_out(embedded)          # [B, T, E]
        return output


# EncoderLayer is only looked up when an Encoder is instantiated, so defining it here is fine
class EncoderLayer(nn.Module):
    def __init__(self, embed_dim, num_heads, ff_dim):
        super(EncoderLayer, self).__init__()
        self.self_attn = MultiHeadAttention(embed_dim, num_heads)
        self.ff = PositionwiseFeedForward(embed_dim, ff_dim)
        self.layernorm1 = nn.LayerNorm(embed_dim)
        self.layernorm2 = nn.LayerNorm(embed_dim)
        
    def forward(self, x, mask=None):
        # Self-attention, then residual connection + LayerNorm
        attn_output, _ = self.self_attn(x, x, x, mask)  # [B, T, E]
        x = self.layernorm1(x + attn_output)          # [B, T, E]
        
        # Feed-forward, then residual connection + LayerNorm
        ff_output = self.ff(x)                        # [B, T, E]
        x = self.layernorm2(x + ff_output)            # [B, T, E]
        
        return x
```
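Neither the `Encoder` nor the `Decoder` in this notebook adds positional information to the token embeddings, and self-attention on its own does not depend on input order, so the encoder effectively sees each source sentence as an unordered collection of characters (the decoder's causal mask may give it some order information). That evidently suffices for the five toy sentences, but the original Transformer adds sinusoidal positional encodings to the embeddings. Below is a minimal sketch of that encoding, which is *not* used by this notebook's model; the commented line at the end marks one hypothetical place it could be applied.

```python
import math
import torch

def sinusoidal_positional_encoding(max_len, embed_dim):
    """PE[pos, 2i] = sin(pos / 10000^(2i/d)), PE[pos, 2i+1] = cos(pos / 10000^(2i/d))."""
    assert embed_dim % 2 == 0, "this sketch assumes an even embed_dim"
    position = torch.arange(max_len, dtype=torch.float32).unsqueeze(1)      # [max_len, 1]
    # 10000^(-2i/d) for 2i = 0, 2, 4, ...
    div_term = torch.exp(torch.arange(0, embed_dim, 2, dtype=torch.float32)
                         * (-math.log(10000.0) / embed_dim))                 # [embed_dim / 2]
    pe = torch.zeros(max_len, embed_dim)
    pe[:, 0::2] = torch.sin(position * div_term)   # even feature indices
    pe[:, 1::2] = torch.cos(position * div_term)   # odd feature indices
    return pe                                      # [max_len, embed_dim]

pe = sinusoidal_positional_encoding(max_len=14, embed_dim=32)
print(pe.shape)  # torch.Size([14, 32])

# Hypothetical use inside Encoder.forward (not part of this notebook's model):
#   embedded = self.embedding(x) + pe[: x.size(1)].to(x.device)  # [B, T, E] + [T, E] broadcasts
```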


### 3. Defining the Decoder

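The decoder generates the output sequence. Each decoder layer applies masked self-attention over the target positions seen so far, cross-attention over the encoder's outputs, and a position-wise feed-forward network.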
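The decoder's self-attention must not look at future target characters, so it uses a causal (lower-triangular) mask combined with a padding mask, built the same way as in `create_masks` further below. A small sketch on one made-up decoder input (the token indices are assumptions) shows the resulting `[B, 1, T, T]` mask. Rows for padding positions are entirely masked, so the softmax spreads their weight uniformly over equal -1e9 scores; those positions are ignored by the loss anyway.

```python
import torch

pad_idx = 0                                        # assumed <pad> index, as in the vocabularies above
decoder_input = torch.tensor([[1, 5, 12, 0, 0]])   # [B=1, T=5]: <sos>, two characters, two pads

T = decoder_input.size(1)
# Causal (subsequent) mask: position i may attend to positions 0..i only
causal = torch.tril(torch.ones(T, T, dtype=torch.bool))         # [T, T]
# Padding mask on the query rows, as in create_masks: [B, 1, T, 1]
not_pad = (decoder_input != pad_idx).unsqueeze(1).unsqueeze(3)
trg_mask = not_pad & causal                                      # [B, 1, T, T]

# Rows 0-2 form a lower triangle; rows 3-4 (padding queries) are all zeros
print(trg_mask[0, 0].int())
```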

```python
class Decoder(nn.Module):
    def __init__(self, output_dim, embed_dim, num_heads, ff_dim, num_layers):
        super(Decoder, self).__init__()
        self.embedding = nn.Embedding(output_dim, embed_dim)
        self.layers = nn.ModuleList([DecoderLayer(embed_dim, num_heads, ff_dim) for _ in range(num_layers)])
        self.fc_out = nn.Linear(embed_dim, output_dim)
        
    def forward(self, x, encoder_output, src_mask=None, tgt_mask=None):
        embedded = self.embedding(x)  # [B, T, E]
        for layer in self.layers:
            embedded = layer(embedded, encoder_output, src_mask, tgt_mask)  # [B, T, E]
        output = self.fc_out(embedded)  # [B, T, output_dim]: logits over the target vocabulary
        return output

class DecoderLayer(nn.Module):
    def __init__(self, embed_dim, num_heads, ff_dim):
        super(DecoderLayer, self).__init__()
        self.self_attn = MultiHeadAttention(embed_dim, num_heads)
        self.cross_attn = MultiHeadAttention(embed_dim, num_heads)
        self.ff = PositionwiseFeedForward(embed_dim, ff_dim)
        self.layernorm1 = nn.LayerNorm(embed_dim)
        self.layernorm2 = nn.LayerNorm(embed_dim)
        self.layernorm3 = nn.LayerNorm(embed_dim)
        
    def forward(self, x, encoder_output, src_mask=None, tgt_mask=None):
        # Masked self-attention: tgt_mask stops each position from attending to later positions
        attn_output, _ = self.self_attn(x, x, x, tgt_mask) # [B, T, E]
        x = self.layernorm1(x + attn_output) # [B, T, E]
        
        # Cross-attention: queries from the decoder, keys/values from the encoder output
        attn_output, _ = self.cross_attn(x, encoder_output, encoder_output, src_mask)  # [B, T, E]
        x = self.layernorm2(x + attn_output)  # [B, T, E]
        
        # Feed-forward
        ff_output = self.ff(x)  # [B, T, E]
        x = self.layernorm3(x + ff_output)  # [B, T, E]
        
        return x
```


### 4. Defining the Seq2Seq Model

The `Seq2Seq` module combines the encoder and decoder into a single model: its forward pass encodes the source sequence and then decodes the target sequence while attending to the encoder output.

```python
class Seq2Seq(nn.Module):
    def __init__(self, encoder, decoder, device):
        super(Seq2Seq, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.device = device
        
    def forward(self, src, trg, src_mask=None, trg_mask=None):
        encoder_output = self.encoder(src, src_mask)      # [B, src_T, E]
        decoder_output = self.decoder(trg, encoder_output, src_mask, trg_mask)  # [B, trg_T, output_dim]
        return decoder_output
```

### 5. Training the Model

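We define a loss function (`CrossEntropyLoss`, ignoring padding positions), an Adam optimizer, and train for a fixed number of epochs with teacher forcing: the decoder receives the target sequence shifted right by one position and is trained to predict the next character at every step.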
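The shift used for teacher forcing, and the effect of `ignore_index` on the loss, can be seen on a toy example. Everything below (vocabulary size, token indices, and the random logits standing in for decoder output) is an assumption for illustration.

```python
import torch
import torch.nn as nn

pad_idx, vocab_size = 0, 6                  # assumed toy values
trg = torch.tensor([[1, 4, 5, 2, 0]])       # <sos>, two tokens, <eos>, <pad>

decoder_input = trg[:, :-1]   # [[1, 4, 5, 2]]  what the decoder sees
target = trg[:, 1:]           # [[4, 5, 2, 0]]  what it must predict at each step

torch.manual_seed(0)
logits = torch.randn(1, decoder_input.size(1), vocab_size)  # stand-in for decoder output [B, T-1, V]

criterion = nn.CrossEntropyLoss(ignore_index=pad_idx)
loss = criterion(logits.reshape(-1, vocab_size), target.reshape(-1))

# ignore_index drops the padded step: same as averaging over the 3 real targets only
manual = nn.CrossEntropyLoss()(logits[0, :3], target[0, :3])
print(torch.allclose(loss, manual))  # True
```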

```python
import torch.optim as optim

# Define model hyperparameters
INPUT_DIM = len(source_vocab)    # source vocabulary size (characters + special tokens)
OUTPUT_DIM = len(target_vocab)   # target vocabulary size
EMBED_DIM = 32      # Embedding / model dimension
NUM_HEADS = 4       # Number of attention heads (EMBED_DIM must be divisible by NUM_HEADS)
FF_DIM = 64         # Hidden size of the position-wise feed-forward network
NUM_ENCODER_LAYERS = 2
NUM_DECODER_LAYERS = 2

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

encoder = Encoder(INPUT_DIM, EMBED_DIM, NUM_HEADS, FF_DIM, NUM_ENCODER_LAYERS)
decoder = Decoder(OUTPUT_DIM, EMBED_DIM, NUM_HEADS, FF_DIM, NUM_DECODER_LAYERS)
model = Seq2Seq(encoder, decoder, device).to(device)


# Define loss (padding positions are ignored) and optimizer
criterion = nn.CrossEntropyLoss(ignore_index=target_char2idx['<pad>'])
optimizer = optim.Adam(model.parameters(), lr=0.001)
```

```python
# Function to create masks
def create_masks(src, decoder_input, num_heads=4):
    """
    Create source and target masks.
    
    Args:
        src (Tensor): Source sequences [B, src_T]
        decoder_input (Tensor): Decoder input sequences [B, trg_T - 1]
        num_heads (int): Number of attention heads
    
    Returns:
        src_mask (Tensor): [B, num_heads, 1, src_T]
        trg_mask (Tensor): [B, num_heads, trg_T - 1, trg_T - 1]
    """
    # Source mask (hides padded source positions): [B,1,1,src_T] -> [B,num_heads,1,src_T]
    src_mask = (src != source_char2idx['<pad>']).unsqueeze(1).unsqueeze(2)  # [B,1,1,src_T]
    src_mask = src_mask.repeat(1, num_heads, 1, 1)  # [B,num_heads,1,src_T]
    
    # Target mask (padding & causal): [B,1,T,1] & [T,T] -> [B,1,T,T] -> [B,num_heads,T,T]
    trg_pad_mask = (decoder_input != target_char2idx['<pad>']).unsqueeze(1).unsqueeze(3)  # [B,1,T,1]
    trg_len = decoder_input.size(1)
    trg_sub_mask = torch.tril(torch.ones((trg_len, trg_len), device=decoder_input.device)).bool()  # [T,T]
    trg_mask = trg_pad_mask & trg_sub_mask  # [B,1,T,T]
    trg_mask = trg_mask.repeat(1, num_heads, 1, 1)  # [B,num_heads,T,T]
    
    return src_mask, trg_mask

NUM_EPOCHS = 200

for epoch in range(NUM_EPOCHS):
    model.train()
    epoch_loss = 0
    for src_batch, trg_batch in dataloader:
        src_batch = src_batch.to(device)  # [B, src_T], src_T = MAX_SOURCE_LEN = 14
        trg_batch = trg_batch.to(device)  # [B, trg_T], trg_T = MAX_TARGET_LEN = 15
        
        # Teacher forcing: the decoder sees the target shifted right and predicts the next character
        decoder_input = trg_batch[:, :-1]  # [B, trg_T - 1 = 14]
        target = trg_batch[:, 1:].contiguous().view(-1)  # [B * 14]
        
        # Create masks based on decoder_input
        src_mask, trg_mask = create_masks(src_batch, decoder_input, num_heads=NUM_HEADS)  # trg_mask: [B, NUM_HEADS, 14, 14]
        
        optimizer.zero_grad()
        
        # Forward pass
        output = model(src_batch, decoder_input, src_mask, trg_mask)  # [B, trg_T - 1, OUTPUT_DIM]
        
        # Reshape output to [B * (trg_T - 1), OUTPUT_DIM]
        output = output.view(-1, OUTPUT_DIM)
        
        # Compute loss
        loss = criterion(output, target)
        loss.backward()
        
        optimizer.step()
        
        epoch_loss += loss.item()
    
    # Print loss every 10 epochs and the first epoch
    if (epoch+1) % 10 == 0 or epoch == 0:
        print(f'Epoch: {epoch+1}, Loss: {epoch_loss / len(dataloader):.4f}')
```

```text
Epoch: 1, Loss: 2.9783
Epoch: 10, Loss: 1.7617
Epoch: 20, Loss: 1.1020
Epoch: 30, Loss: 0.5524
Epoch: 40, Loss: 0.2994
Epoch: 50, Loss: 0.1741
Epoch: 60, Loss: 0.1579
Epoch: 70, Loss: 0.0988
Epoch: 80, Loss: 0.0630
Epoch: 90, Loss: 0.0480
Epoch: 100, Loss: 0.0399
Epoch: 110, Loss: 0.0323
Epoch: 120, Loss: 0.0264
Epoch: 130, Loss: 0.0235
Epoch: 140, Loss: 0.0198
Epoch: 150, Loss: 0.0184
Epoch: 160, Loss: 0.0160
Epoch: 170, Loss: 0.0150
Epoch: 180, Loss: 0.0127
Epoch: 190, Loss: 0.0114
Epoch: 200, Loss: 0.0108
```


```python
def translate(sentence, encoder, decoder, source_char2idx, target_idx2char, max_len=MAX_TARGET_LEN):
    """Greedy decoding: repeatedly append the most likely next character until <eos> or max_len."""
    # Put the modules that are actually used into evaluation mode
    encoder.eval()
    decoder.eval()
    with torch.no_grad():
        # Encode the source sentence
        encoded_src = encode_sentence(sentence, source_char2idx)
        encoded_src = pad_sequence(encoded_src, MAX_SOURCE_LEN, source_char2idx['<pad>'])
        src_tensor = torch.tensor(encoded_src, dtype=torch.long).unsqueeze(0).to(device)  # [1, src_seq_len]
        
        # Create source mask
        src_mask = (src_tensor != source_char2idx['<pad>']).unsqueeze(1).unsqueeze(2)  # [1,1,1,src_T]
        src_mask = src_mask.repeat(1, NUM_HEADS, 1, 1)  # [1,num_heads,1,src_T]
        
        # Encoder output
        encoder_out = encoder(src_tensor, src_mask)  # [1, src_T, E]
        
        # Initialize decoder input with <sos>
        decoder_input = torch.tensor([target_char2idx['<sos>']], dtype=torch.long).unsqueeze(0).to(device)  # [1,1]
        
        translated_sentence = ""
        
        for _ in range(max_len):
            # Create target mask
            trg_pad_mask = (decoder_input != target_char2idx['<pad>']).unsqueeze(1).unsqueeze(3)  # [1,1,T,1]
            trg_len = decoder_input.size(1)
            trg_sub_mask = torch.tril(torch.ones((trg_len, trg_len), device=device)).bool()  # [T,T]
            trg_mask = trg_pad_mask & trg_sub_mask  # [1,1,T,T]
            trg_mask = trg_mask.repeat(1, NUM_HEADS, 1, 1)  # [1,num_heads,T,T]
            
            # Forward pass
            output = decoder(decoder_input, encoder_out, src_mask, trg_mask)  # [1, T, OUTPUT_DIM]
            
            # Get the last time step
            output = output[:, -1, :]  # [1, OUTPUT_DIM]
            
            # Get the predicted token
            pred_token = output.argmax(1).item()
            pred_char = target_idx2char[pred_token]
            
            if pred_char == '<eos>':
                break
            translated_sentence += pred_char
            decoder_input = torch.cat([decoder_input, torch.tensor([[pred_token]], dtype=torch.long).to(device)], dim=1)  # [1, T+1]
        
        return translated_sentence
```

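```python
test_sentences = [
    "hello",
    "good night",
    "thank you",
    "how are you",
]

for sentence in test_sentences:
    translation = translate(sentence, encoder, decoder, source_char2idx, target_idx2char)
    print(f"{sentence} -> {translation}")
```

```text
hello -> bonjour
good night -> bonne nuit
thank you -> merci
how are you -> comment ça va
```

All four test sentences also appear in the training data, so these outputs only show that the model has fit the five training pairs; they say nothing about how well it would translate unseen input.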


## Analysis of Attention Mechanisms

Attention mechanisms allow models to dynamically focus on relevant parts of the input, enhancing performance in tasks like machine translation, text summarization, and more. They alleviate the limitations of RNNs by:

- **Capturing Long-Range Dependencies:** By directly connecting any two positions in the input, regardless of their distance.
- **Improving Parallelization:** Especially in Transformer models, attention enables parallel processing of sequence elements.
- **Enhancing Interpretability:** Attention weights can offer insight into which parts of the input the model focuses on during prediction, although they are not guaranteed to be a faithful explanation of why the model made that prediction.

## Further Steps

- **Explore Multi-Head Attention:** Understand how multiple attention heads can capture diverse aspects of the input.
- **Implement Transformer Models:** Dive deeper into Transformer architectures, which rely heavily on attention mechanisms.
- **Experiment with Different Attention Types:** Implement and compare soft attention, hard attention, and self-attention in various tasks.
- **Visualize Attention Weights:** Gain insights by visualizing where the model is focusing its attention during predictions.
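As a starting point for visualization, here is a dependency-free helper, `format_attention` (a hypothetical name), that prints an attention matrix as a labeled text table; for graphical heatmaps, `matplotlib`'s `imshow` is a common choice. Note that `DecoderLayer` above discards the cross-attention weights (`attn_output, _ = self.cross_attn(...)`), so you would need to return or store them to inspect what the trained model attends to. The weights below are made up for illustration.

```python
import torch

def format_attention(weights, query_labels, key_labels, precision=2):
    """Render a [num_queries, num_keys] attention matrix as a plain-text table."""
    # Column width: room for "0.00"-style numbers plus a space, or the longest key label
    col_w = max([precision + 3] + [len(k) for k in key_labels])
    row_w = max(len(q) for q in query_labels)
    lines = [" " * row_w + " " + " ".join(k.rjust(col_w) for k in key_labels)]
    for label, row in zip(query_labels, weights.tolist()):
        cells = " ".join(f"{w:.{precision}f}".rjust(col_w) for w in row)
        lines.append(label.rjust(row_w) + " " + cells)
    return "\n".join(lines)

# Toy weights for illustration only (not taken from the trained model)
weights = torch.softmax(torch.tensor([[2.0, 0.1, 0.1],
                                      [0.1, 2.0, 0.5]]), dim=-1)
table = format_attention(weights, ["b", "o"], ["h", "e", "l"])
print(table)
```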

**Remember:** Attention mechanisms are foundational to many state-of-the-art models in NLP and beyond. Mastering them will significantly enhance your ability to design and understand complex neural network architectures.

## References

- [Attention Mechanism Explained](https://towardsdatascience.com/attention-mechanism-explained-8f96b26ebae)
- [The Annotated Transformer](http://nlp.seas.harvard.edu/2018/04/03/attention.html)