
# Understanding Multi-Head Self-Attention in Transformers

This notebook walks step by step through the multi-head self-attention mechanism used in Transformer models. A small example illustrates each computation: the linear transformations, the split into heads, scaled dot-product attention, and the final concatenation and linear transformation.

### Example Input

Let's assume we have a simple input sequence of three tokens, each represented as an embedding, with the following settings:
- Sequence length: 3
- Embedding size: 6
- Number of heads: 2

The steps involved in the multi-head self-attention mechanism are:

1. Linear transformations for Queries, Keys, and Values
2. Splitting into multiple heads
3. Scaled dot-product attention
4. Concatenation of heads and final linear transformation

We will illustrate each of these steps in detail.
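
Before going through the steps, it can help to track how the array shapes change. The sketch below only does shape bookkeeping for the example settings; it assumes the projections keep the embedding size (a common convention), and the dictionary labels are illustrative names rather than anything defined elsewhere in this notebook.

```python
# Example settings from the section above
seq_len, embed_size, num_heads = 3, 6, 2

# The embedding size must divide evenly across the heads
assert embed_size % num_heads == 0, "embed_size must be divisible by num_heads"
head_dim = embed_size // num_heads  # 6 // 2 = 3

# Expected shapes at each stage (assuming projections keep the embedding size)
shapes = {
    "input X": (seq_len, embed_size),
    "Q, K, V after projection": (seq_len, embed_size),
    "Q, K, V per head": (num_heads, seq_len, head_dim),
    "attention weights per head": (num_heads, seq_len, seq_len),
    "concatenated output": (seq_len, embed_size),
}
for name, shape in shapes.items():
    print(f"{name:28s} {shape}")
```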




## Step 1: Linear Transformations for Queries, Keys, and Values

Each input embedding is transformed into queries, keys, and values using learned weight matrices.
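
In matrix form, with the token embeddings stacked as the rows of $X$, the three projections can be written as below. The shape annotations assume that each projection keeps the model width $d_{\text{model}}$ (here 6) for a sequence of $n$ tokens (here 3); that is a common convention, stated here as an assumption.

$$
Q = X W_Q, \qquad K = X W_K, \qquad V = X W_V,
\qquad X \in \mathbb{R}^{n \times d_{\text{model}}}, \quad
W_Q, W_K, W_V \in \mathbb{R}^{d_{\text{model}} \times d_{\text{model}}}
$$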

### Example

Let's assume we have the following embedding matrix for a sequence of three tokens (each of size 6):


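In [3]:

```python
import numpy as np

# Example input embeddings (3 tokens, each of size 6)
X = np.array([[1, 0, 1, 0, 1, 0],
              [0, 2, 0, 2, 0, 2],
              [1, 1, 1, 1, 1, 1]])

# Learned weight matrices for queries, keys, and values.
# Each is 6x6 so the projected vectors keep the embedding size and can be
# split into 2 heads of size 3 in Step 2. For simplicity, the right three
# columns repeat the left three, so both heads will see the same numbers.
W_Q = np.array([[0.1, 0.2, 0.3, 0.1, 0.2, 0.3],
                [0.4, 0.5, 0.6, 0.4, 0.5, 0.6],
                [0.7, 0.8, 0.9, 0.7, 0.8, 0.9],
                [0.1, 0.2, 0.3, 0.1, 0.2, 0.3],
                [0.4, 0.5, 0.6, 0.4, 0.5, 0.6],
                [0.7, 0.8, 0.9, 0.7, 0.8, 0.9]])

W_K = W_Q  # For simplicity, using the same matrix for keys
W_V = W_Q  # For simplicity, using the same matrix for values

# Linear transformations
Q = np.dot(X, W_Q)
K = np.dot(X, W_K)
V = np.dot(X, W_V)

print("Queries:\n", Q)
print("Keys:\n", K)
print("Values:\n", V)
```

```
Queries:
 [[1.2 1.5 1.8 1.2 1.5 1.8]
 [2.4 3.  3.6 2.4 3.  3.6]
 [2.4 3.  3.6 2.4 3.  3.6]]
Keys:
 [[1.2 1.5 1.8 1.2 1.5 1.8]
 [2.4 3.  3.6 2.4 3.  3.6]
 [2.4 3.  3.6 2.4 3.  3.6]]
Values:
 [[1.2 1.5 1.8 1.2 1.5 1.8]
 [2.4 3.  3.6 2.4 3.  3.6]
 [2.4 3.  3.6 2.4 3.  3.6]]
```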



## Step 2: Splitting into Multiple Heads

We split the resulting queries, keys, and values into multiple heads along the feature dimension. With an embedding size of 6 and 2 heads, each head has a dimension of 3 (embedding size / number of heads).
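
To see what the split does to the data, the following minimal sketch uses a hypothetical matrix `Q_demo` with easy-to-track values (it is not the `Q` from Step 1). Under that assumption, it shows that reshaping and transposing simply hands each head a contiguous block of columns.

```python
import numpy as np

# Hypothetical projected matrix: 3 tokens, width 6, filled with 0..17
Q_demo = np.arange(18).reshape(3, 6)
num_heads = 2
seq_len, embed_size = Q_demo.shape
head_dim = embed_size // num_heads  # 6 // 2 = 3

# (seq_len, embed_size) -> (seq_len, num_heads, head_dim) -> (num_heads, seq_len, head_dim)
heads = Q_demo.reshape(seq_len, num_heads, head_dim).transpose(1, 0, 2)

print(heads.shape)  # (2, 3, 3)
print(heads[0])     # same values as columns 0..2 of Q_demo
print(heads[1])     # same values as columns 3..5 of Q_demo
```

Each head therefore works on its own slice of every token's projected vector, and no values are mixed between heads at this stage.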

### Example

Splitting the queries, keys, and values into 2 heads:


In [None]:

```python
# Splitting into 2 heads (each of size 3)
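def split_heads(X, num_heads):
    # Step 1 produced 2-D arrays (seq_length, embed_size); add a batch axis of size 1
    if X.ndim == 2:
        X = X[np.newaxis, ...]
    batch_size, seq_length, embed_size = X.shape
    if embed_size % num_heads != 0:
        raise ValueError("embed_size must be divisible by num_heads")
    head_dim = embed_size // num_heads
    X = X.reshape(batch_size, seq_length, num_heads, head_dim)
    X = X.transpose(0, 2, 1, 3)  # (batch_size, num_heads, seq_length, head_dim)
    return X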

num_heads = 2
Q_heads = split_heads(Q, num_heads)
K_heads = split_heads(K, num_heads)
V_heads = split_heads(V, num_heads)

print("Q_heads shape:", Q_heads.shape)
print("K_heads shape:", K_heads.shape)
print("V_heads shape:", V_heads.shape)
```



## Step 3: Scaled Dot-Product Attention

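For each head, we compute attention scores from the scaled dot products of queries and keys, turn the scores into weights with a softmax, and use those weights to form a weighted sum of the values.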
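
Written as a formula for a single head, where $d_k$ is the head dimension (3 in this example) and the softmax is applied to each row of the score matrix:

$$
\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{Q K^{\top}}{\sqrt{d_k}}\right) V
$$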

### Example

Compute the attention output for each head:


In [None]:

```python
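# Scaled dot-product attention
def scaled_dot_product_attention(Q, K, V):
    d_k = Q.shape[-1]
    # Swap the last two axes of K so this works for any number of leading (batch/head) axes
    scores = np.matmul(Q, np.swapaxes(K, -1, -2)) / np.sqrt(d_k)
    # Softmax over the key axis; subtracting the row maximum avoids overflow in exp
    exp_scores = np.exp(scores - np.max(scores, axis=-1, keepdims=True))
    weights = exp_scores / np.sum(exp_scores, axis=-1, keepdims=True)
    output = np.matmul(weights, V)
    return output

# Compute attention for each head; each output has shape (batch_size, seq_length, head_dim)
attention_outputs = [scaled_dot_product_attention(Q_heads[:, i], K_heads[:, i], V_heads[:, i]) for i in range(num_heads)]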

print("Attention outputs for each head:")
for i, attn_output in enumerate(attention_outputs):
    print(f"Head {i+1}:\n", attn_output)
```
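
To look at the attention weights themselves rather than only the weighted sums, a small single-head sketch can help. The helper `attention_weights` and the arrays `Q_h` and `K_h` are illustrative names introduced here; the values are the three-dimensional query rows printed in Step 1, reused as keys.

```python
import numpy as np

def attention_weights(Q, K):
    """Row-wise softmax of Q K^T / sqrt(d_k) for a single head (2-D inputs)."""
    d_k = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)
    scores = scores - scores.max(axis=-1, keepdims=True)  # stability shift
    exp_scores = np.exp(scores)
    return exp_scores / exp_scores.sum(axis=-1, keepdims=True)

# Illustrative single-head inputs (3 tokens, head dimension 3)
Q_h = np.array([[1.2, 1.5, 1.8],
                [2.4, 3.0, 3.6],
                [2.4, 3.0, 3.6]])
K_h = Q_h

W = attention_weights(Q_h, K_h)
print(W.round(3))
print(W.sum(axis=-1))  # each row of weights sums to 1
```

Because tokens 2 and 3 have identical keys in this example, every query assigns them equal weight, and because their queries are identical too, their rows of weights match.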



## Step 4: Concatenation of Heads and Linear Transformation

After computing the attention output for each head, we concatenate the outputs and apply a final linear transformation.
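
Concatenation is the inverse of the head split: it places each head's output back side by side along the feature axis. The sketch below uses a hypothetical stacked array `head_outputs` (not the notebook's `attention_outputs`) to check that concatenating per-head matrices gives the same result as a transpose followed by a reshape.

```python
import numpy as np

num_heads, seq_len, head_dim = 2, 3, 3

# Hypothetical per-head outputs stacked as (num_heads, seq_len, head_dim)
head_outputs = np.arange(num_heads * seq_len * head_dim).reshape(num_heads, seq_len, head_dim)

# Option 1: concatenate the per-head matrices along the feature axis
concat_a = np.concatenate(list(head_outputs), axis=-1)

# Option 2: move the head axis next to head_dim, then flatten the two together
concat_b = head_outputs.transpose(1, 0, 2).reshape(seq_len, num_heads * head_dim)

print(concat_a.shape)                      # (3, 6)
print(np.array_equal(concat_a, concat_b))  # True
```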

### Example

Concatenate the outputs and apply a linear transformation:


In [None]:

```python
# Concatenate outputs from each head
def concatenate_heads(outputs):
    # Each element of `outputs` has shape (batch_size, seq_length, head_dim)
    concatenated = np.concatenate(outputs, axis=-1)  # (batch_size, seq_length, num_heads * head_dim)
    return concatenated

# Linear transformation after concatenation (for simplicity, using an identity matrix)
W_O = np.eye(num_heads * Q_heads.shape[-1])
concatenated = concatenate_heads(attention_outputs)
final_output = np.dot(concatenated, W_O)

print("Concatenated output shape:", concatenated.shape)
print("Final output:\n", final_output)
```
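
The four steps can also be combined into a single function. The following is a minimal sketch under stated assumptions, not a reference implementation: the name `multi_head_self_attention` is introduced here, the input is one unbatched sequence, all weight matrices are square, and the weights are random placeholders rather than learned values.

```python
import numpy as np

def multi_head_self_attention(X, W_Q, W_K, W_V, W_O, num_heads):
    """Multi-head self-attention for one sequence X of shape (seq_len, d_model)."""
    seq_len, d_model = X.shape
    assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
    head_dim = d_model // num_heads

    def split(M):
        # (seq_len, d_model) -> (num_heads, seq_len, head_dim)
        return M.reshape(seq_len, num_heads, head_dim).transpose(1, 0, 2)

    # Steps 1 and 2: project, then split into heads
    Q, K, V = split(X @ W_Q), split(X @ W_K), split(X @ W_V)

    # Step 3: scaled dot-product attention for all heads at once
    scores = Q @ K.transpose(0, 2, 1) / np.sqrt(head_dim)  # (num_heads, seq_len, seq_len)
    scores = scores - scores.max(axis=-1, keepdims=True)   # stability shift
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    heads = weights @ V                                     # (num_heads, seq_len, head_dim)

    # Step 4: concatenate heads and apply the output projection
    concat = heads.transpose(1, 0, 2).reshape(seq_len, d_model)
    return concat @ W_O

# Placeholder weights (random, not learned) for the example input
rng = np.random.default_rng(0)
d_model, num_heads = 6, 2
X = np.array([[1, 0, 1, 0, 1, 0],
              [0, 2, 0, 2, 0, 2],
              [1, 1, 1, 1, 1, 1]], dtype=float)
W_Q, W_K, W_V, W_O = (rng.normal(size=(d_model, d_model)) for _ in range(4))

out = multi_head_self_attention(X, W_Q, W_K, W_V, W_O, num_heads)
print(out.shape)  # (3, 6)
```

Processing all heads in one batched matrix product avoids the Python loop over heads; the result has the same shape as the input sequence.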
