[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/xiptos/is_notes/blob/main/attention_bahdanau.ipynb)

# Introduction to the attention mechanism

In [None]:
# Attention demo: classic (Bahdanau and Luong) on toy vectors

import torch
import torch.nn as nn
import torch.nn.functional as F

# Print tensors with 4 decimals and without scientific notation
torch.set_printoptions(sci_mode=False, precision=4)

# Toy encoder hidden states, shape (T, d_h) = (3, 4).
# Row i is the i-th standard basis vector:
#   h1 = [1, 0, 0, 0], h2 = [0, 1, 0, 0], h3 = [0, 0, 1, 0]
encoder_states = torch.eye(3, 4)

# Toy decoder state s_t, shape (d_h,) = (4,)
decoder_state = torch.tensor([0.5, 1.0, 0.5, 0.0])

## Bahdanau (additive) attention (small dimensions)

In [None]:
class BahdanauAttention(nn.Module):
    """Additive (Bahdanau-style) attention.

    Each encoder state h_i is scored against the decoder state s with
    score_i = v^T tanh(W_s s + W_h h_i). The scores are normalised by a
    softmax over the time axis, and the context vector is the resulting
    weighted sum of the encoder states.
    """

    def __init__(self, hidden_dim, attn_dim):
        """
        hidden_dim: size of the decoder/encoder state vectors
        attn_dim: size of the intermediate space the states are projected to
        """
        super().__init__()
        self.W_s = nn.Linear(hidden_dim, attn_dim, bias=False)
        self.W_h = nn.Linear(hidden_dim, attn_dim, bias=False)
        self.v = nn.Linear(attn_dim, 1, bias=False)

    def forward(self, decoder_state, encoder_states):
        """
        decoder_state: (hidden_dim,) or batched (..., hidden_dim)
        encoder_states: (T, hidden_dim) or batched (..., T, hidden_dim)

        Returns (context, attn_weights, scores) with shapes
        (..., hidden_dim), (..., T) and (..., T); for unbatched input these
        are (hidden_dim,), (T,) and (T,).
        """
        # Project the decoder state once and give it a singleton time axis so
        # it broadcasts against every encoder step: (..., 1, attn_dim)
        s_proj = self.W_s(decoder_state).unsqueeze(-2)
        # Project every encoder state: (..., T, attn_dim)
        h_proj = self.W_h(encoder_states)

        # Energies tanh(W_s s + W_h h_i): (..., T, attn_dim)
        energy = torch.tanh(s_proj + h_proj)
        # v^T * energy -> (..., T, 1) -> (..., T)
        scores = self.v(energy).squeeze(-1)

        # Attention weights: softmax over the time axis (the last one here)
        attn_weights = F.softmax(scores, dim=-1)

        # Context vector: weighted sum of encoder states over the time axis
        # (..., T, d) * (..., T, 1) -> (..., d)
        context = torch.sum(encoder_states * attn_weights.unsqueeze(-1), dim=-2)

        return context, attn_weights, scores

# Build a small additive-attention module (hidden size 4, attention size 3)
# and apply it to the toy decoder/encoder states defined earlier.
bahdanau = BahdanauAttention(hidden_dim=4, attn_dim=3)
context_b, alpha_b, scores_b = bahdanau(decoder_state, encoder_states)

# Report raw scores, normalised weights and the resulting context vector
print(
    f"Bahdanau scores: {scores_b}",
    f"Bahdanau attention weights: {alpha_b}",
    f"Bahdanau context vector: {context_b}",
    sep="\n",
)

## Luong (dot-product) attention

In [None]:
def luong_dot_attention(decoder_state, encoder_states):
    """
    Simple dot-product (Luong-style) attention.

    decoder_state: (d,) or batched (..., d)
    encoder_states: (T, d) or batched (..., T, d)

    Returns (context, attn_weights, scores) with shapes
    (..., d), (..., T) and (..., T); for unbatched input these are
    (d,), (T,) and (T,).
    """
    # Dot product of every encoder state with the decoder state.
    # (..., T, d) @ (..., d, 1) -> (..., T, 1) -> (..., T)
    scores = torch.matmul(encoder_states, decoder_state.unsqueeze(-1)).squeeze(-1)
    # Normalise over the time axis (the last axis of scores)
    attn_weights = F.softmax(scores, dim=-1)
    # Weighted sum of encoder states over the time axis -> (..., d)
    context = torch.sum(encoder_states * attn_weights.unsqueeze(-1), dim=-2)
    return context, attn_weights, scores

# Apply dot-product attention to the same toy decoder/encoder states
context_l, alpha_l, scores_l = luong_dot_attention(decoder_state, encoder_states)

# Report raw scores, normalised weights and the resulting context vector
print(
    f"Luong scores: {scores_l}",
    f"Luong attention weights: {alpha_l}",
    f"Luong context vector: {context_l}",
    sep="\n",
)

## Visual comparison and interpretation

In [None]:
# Side-by-side view of the two attention distributions
print("=== Bahdanau vs Luong ===")
print(f"Bahdanau weights: {alpha_b}")
print(f"Luong weights    : {alpha_l}")

# Element-wise absolute gap between the two weight vectors, summed over time
# steps: 0 means both mechanisms produced identical weights.
print("\nDo both mechanisms focus on similar time steps?")
print("Sum of absolute differences:", (alpha_b - alpha_l).abs().sum().item())

## Plot attention weights

In [None]:
# matplotlib is imported in this cell because the notebook's only other import
# of it sits in a later cell; without it, running the cells top-to-bottom
# fails here with a NameError on `plt`.
import matplotlib.pyplot as plt

# 1-based encoder positions for the x-axis
time_steps = range(1, encoder_states.size(0) + 1)

# Bahdanau weights: detach from autograd before converting to NumPy
plt.figure()
plt.stem(time_steps, alpha_b.detach().numpy(), markerfmt='o')
plt.xlabel("Encoder time step")
plt.ylabel("Bahdanau weight")
plt.title("Bahdanau Attention Weights")
plt.show()

# Luong weights, drawn in a separate figure with the same layout
plt.figure()
plt.stem(time_steps, alpha_l.detach().numpy(), markerfmt='o')
plt.xlabel("Encoder time step")
plt.ylabel("Luong weight")
plt.title("Luong Attention Weights")
plt.show()

In [None]:
import matplotlib.pyplot as plt
import numpy as np

def plot_attention_heatmap(attention_weights, title="Attention Heatmap"):
    """
    Draw a single-row heatmap of the attention weights of one decoder step.

    attention_weights: 1D tensor, list or NumPy array, shape (T,)
    title: text shown above the heatmap
    """
    if hasattr(attention_weights, "detach"):
        # Torch tensor: drop autograd history and move to host memory before
        # converting to NumPy.
        weights = attention_weights.detach().cpu().numpy()
    else:
        # Plain sequences / arrays have no .detach(); convert them directly.
        weights = np.asarray(attention_weights, dtype=float)
    weights = weights.reshape(1, -1)  # shape (1, T)
    num_steps = weights.shape[1]

    plt.figure(figsize=(6, 2))
    plt.imshow(weights, cmap="viridis", aspect="auto")

    plt.colorbar(label="Attention weight")
    plt.yticks([])  # hide y-axis since it's a single decoder step
    plt.xticks(
        ticks=np.arange(num_steps),
        labels=[f"t={step}" for step in range(1, num_steps + 1)],
    )
    plt.title(title)
    plt.xlabel("Encoder time step")
    plt.show()

# Example usage: one heatmap per attention mechanism
plot_attention_heatmap(alpha_b, title="Bahdanau Attention Heatmap")
plot_attention_heatmap(alpha_l, title="Luong Attention Heatmap")

# Attention by hand

## Exercise A – Softmax and Attention Weights

**Given scores for one decoder step:**

$$
e_{t,1} = 0,\quad e_{t,2} = 1,\quad e_{t,3} = 2
$$

1. Compute the attention weights:
   $$
   \alpha_{t,i} = \frac{\exp(e_{t,i})}{\sum_{j=1}^3 \exp(e_{t,j})}, \quad i = 1,2,3.
   $$

2. Verify that:
   $$
   \alpha_{t,1} + \alpha_{t,2} + \alpha_{t,3} = 1.
   $$

3. Which position has the highest attention?
   Does that correspond to the largest score?

> Hint: You may approximate $e \approx 2.72$ if you compute by hand.

## Exercise B – Context Vector with 1D Encoder States

Assume the encoder hidden states are **1-dimensional**:

- $h_1 = 1$
- $h_2 = 2$
- $h_3 = 4$

Use the attention weights $\alpha_{t,1}, \alpha_{t,2}, \alpha_{t,3}$ obtained in **Exercise A**.

1. Compute the context vector:
   $$
   c_t = \alpha_{t,1} h_1 + \alpha_{t,2} h_2 + \alpha_{t,3} h_3.
   $$

2. Is $c_t$ closer to 1, 2, or 4?
   How does this relate to which $\alpha_{t,i}$ was largest?

3. Intuitively, what does this say about which encoder position the model is “focusing” on?

## Exercise C – Dot-Product Attention with Tiny Vectors

Consider encoder states in 2D and one decoder state:

- $h_1 = (1, 0)$
- $h_2 = (0, 1)$
- $h_3 = (1, 1)$
- $s_t = (1, 1)$

1. Compute the dot-product scores:
   $$
   e_{t,i} = s_t^\top h_i, \quad i = 1,2,3.
   $$

2. Apply the softmax to obtain attention weights:
   $$
   \alpha_{t,i} = \frac{\exp(e_{t,i})}{\sum_{j=1}^3 \exp(e_{t,j})}.
   $$

3. Compute the **2D context vector**:
   $$
   c_t = \sum_{i=1}^3 \alpha_{t,i} h_i.
   $$

4. Interpret the result:
   - Which encoder state does the decoder “care” most about?
   - How can you see that both from the scores and from the context vector?

## Exercise D – Interpreting an Attention Matrix

Suppose you have **2 output tokens** and **3 input tokens**, with attention weights:

$$
\alpha =
\begin{bmatrix}
0.7 & 0.2 & 0.1 \\
0.1 & 0.3 & 0.6
\end{bmatrix}
$$

- **Rows** correspond to decoder/output positions (row 1 = output 1, row 2 = output 2).
- **Columns** correspond to encoder/input positions (column 1 = input 1, etc.).

Answer the following:

1. When predicting the **first output word** (row 1), which input word is most important? Why?

2. When predicting the **second output word** (row 2), which input word is most important? Why?

3. If this were a **machine translation** model:
   - What kind of alignment between input and output words does this matrix suggest?
   - Give a possible example of a short input and output sequence that could produce a pattern like this.