# Attention Is All You Need
## Scaled Dot Product Attention

At its core, the self-attention mechanism revolves around the interplay of three components: **key**, **query**, and **value**. These are vital for understanding how information is weighted and propagated in attention models, such as the Transformer.

$$ \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^{T}}{\sqrt{d_{k}}}\right) \cdot V $$

When $Q = K$, the term $QK^{T}$ captures the self-attention, indicating how similar elements within the matrix $Q$ are to one another.

### Why Use $\sqrt{d_k}$?
Under the assumption that the components of $q$ and $k$ are independent random variables with mean 0 and variance 1 (it is quite theoretical assumption that is not realistic for most cases), their dot product, $q \cdot k = \sum_{i=1}^{d_k} q_{i}k_{i}$ has mean 0 and variance $d_{k}$.

The mean can be determined using the **linearity of expectation**:

$$ E[q \cdot k] = E\left[\sum_{i=1}^{d_k} q_i k_i\right] $$

$$ = \sum_{i=1}^{d_k} E[q_ik_i] $$

Given the assumption that random variables are i.i.d (independently identically distributed):

$$ = \sum_{i=1}^{d_k} E[q_i]E[k_i] = 0 $$

Thus, the mean of $q \cdot k$ equals 0.

For variance, although variance is not strictly linear in the way that expectation is, in this context, since the random variables are independent, the variance of their sum is the sum of their variances. Hence, using a principle similar to the **linearity of expectation**:

$$ \text{var}[q \cdot k] = \text{var}\left[\sum_{i=1}^{d_k}q_ik_i\right] $$

$$ = \sum_{i=1}^{d_k}\text{var}[q_ik_i] = d_k $$

To make the dot product have a mean of 0 and standard deviation of 1, it's divided by $\sqrt{d_k}$. However, nowadays, this normalization is often omitted since a normal distribution is not always assumed, especially when layer normalization is not used. **Scaled Dot Product Attention** refers to the process of this calculation. Given that **Query**, **Key**, and **Value** are all $3 \times 1$ matrices:

$$ 
Q = K = V = \begin{bmatrix} 
v_1 \\ 
v_2 \\ 
v_3 
\end{bmatrix} 
$$

Since $QK^{T}$ results in a $3 \times 3$ matrix:

$$ 
QK^T = \begin{bmatrix} 
v_1 \cdot v_1 & v_1 \cdot v_2 & v_1 \cdot v_3 \\ 
v_2 \cdot v_1 & v_2 \cdot v_2 & v_2 \cdot v_3 \\ 
v_3 \cdot v_1 & v_3 \cdot v_2 & v_3 \cdot v_3 
\end{bmatrix} 
$$

We then divide $QK^{T}$ by $\sqrt{d_k}$, obtaining the **attention weight**:

$$ 
\text{Softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right) = \begin{bmatrix} 
w_{11} & w_{12} & w_{13} \\ 
w_{21} & w_{22} & w_{23} \\ 
w_{31} & w_{32} & w_{33} 
\end{bmatrix} 
$$

Given the value matrix, we compute:

$$ 
\text{Softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right) \times V = \begin{bmatrix} 
y_1 \\ 
y_2 \\ 
y_3 
\end{bmatrix} 
$$

The attention mechanism gauges the similarity between a *query* (the word we're focusing on) and a *key* (the word we're comparing against). The resulting similarity scores are then used to weigh the importance of words in the **Value** matrix. See below example code to understand how it goes:

### Transformer Block

In [6]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Attention Module
class ScaledDotProductAttention(nn.Module):
    def __init__(self, temperature, attn_dropout=0.1):
        super(ScaledDotProductAttention, self).__init__()
        self.temperature = temperature
        self.dropout = nn.Dropout(attn_dropout)

    def forward(self, q, k, v, mask=None):
        attn = torch.matmul(q, k.transpose(-2, -1)) / self.temperature
        if mask is not None:
            attn = attn.masked_fill(mask == 0, -torch.inf)
        attn = self.dropout(F.softmax(attn, dim=-1))
        output = torch.matmul(attn, v)
        return output, attn

# Sample sentences
sentences = ["This is a sample.", "Attention mechanisms are powerful.", "Scaled dot product is interesting."]

# Simple tokenization (split by space) and creating a vocabulary
tokens = [sentence.split() for sentence in sentences]
vocab = set(word for sentence in tokens for word in sentence)
vocab_size = len(vocab)
word_to_idx = {word: idx for idx, word in enumerate(vocab)}
idx_to_word = {idx: word for word, idx in word_to_idx.items()}

# Convert tokens to integers
token_ids = [[word_to_idx[word] for word in token] for token in tokens]

# Pad sequences (assuming a max sequence length of 10 for simplicity)
max_len, d_model = 10, 64
padded_token_ids = [token_id + [0] * (max_len - len(token_id)) for token_id in token_ids]

# Convert to tensor
input_tensor = torch.tensor(padded_token_ids)

# Embedding (assuming d_model = 64)
embedding = nn.Embedding(vocab_size, d_model)
embedded_input = embedding(input_tensor)

# Attention mechanism
temperature = torch.sqrt(torch.tensor(d_model).float())
# temperature = torch.sqrt(torch.tensor(d_model, dtype=torch.float32))
sdp_attention = ScaledDotProductAttention(temperature)

# Using the same embedded input for q, k, and v for simplicity
output, attn_weights = sdp_attention(embedded_input, embedded_input, embedded_input)

# Display the results
print("Output:", output)
print("Attention Weights:", attn_weights)
print(output.shape, attn_weights.shape)

Output: tensor([[[ 8.7608e-01,  3.3927e-01, -1.3244e+00,  ..., -4.0746e-02,
          -1.6846e-01, -1.5164e-01],
         [ 1.4763e+00,  1.6987e+00,  9.0214e-01,  ..., -1.2537e+00,
           8.5446e-01,  1.3122e+00],
         [ 1.3544e+00,  9.5870e-01, -4.4449e-02,  ...,  2.3614e+00,
           7.4944e-01,  1.7983e+00],
         ...,
         [ 3.3140e-01,  2.2513e+00, -1.4790e+00,  ..., -3.9521e-01,
          -1.2430e+00,  3.5047e-01],
         [ 3.9767e-01,  2.7016e+00, -1.7748e+00,  ..., -4.7426e-01,
          -1.4916e+00,  4.2055e-01],
         [ 3.3140e-01,  2.2513e+00, -1.4790e+00,  ..., -3.9521e-01,
          -1.2430e+00,  3.5047e-01]],

        [[ 1.2716e+00,  2.2604e+00, -3.6325e-01,  ..., -2.4550e-03,
          -1.1870e-01, -1.8570e+00],
         [ 4.0824e-01,  9.8044e-01,  1.0579e+00,  ..., -6.8351e-02,
           2.1602e+00, -3.0626e-01],
         [ 5.2817e-03,  2.1607e-02, -1.1288e-02,  ..., -3.0273e-03,
          -7.8106e-03, -1.4656e-03],
         ...,
         [ 3.9771