# Scaled Dot-Product Attention (SDPA)

## Summary

The **Scaled Dot-Product Attention** mechanism can be conceptualized as an associative retrieval system where the **Query** represents a specific semantic "question" directed toward **Keys**, which serve as indexing metadata for the underlying **Values**.


## Step-by-Step Explanation

### The Attention Mechanism as a Library Information Retrieval Metaphor

Let us define our variables within the library context:

<img src = "../images/SelfAttention.png">

* $Q$ **(Query)**: Your search intent (e.g., "How does backpropagation work?").
* $K$ **(Key)**: The spine labels or metadata of all books on the shelf (e.g., "Calculus," "Optimization," "Neural Networks").
* $V$ **(Value)**: The actual instructional content inside those books.

**The Three-Stage Retrieval Process**

1. Similarity Computation: <br>
    You compare your question ($Q$) against every label on the bookshelf ($K$) with a dot product (the `MatMul` step). This is the "matching" phase: if your question is about neural networks, the dot product yields a high score for the "Neural Networks" book.

2. Scaling: <br>
    When representations are very high-dimensional, the raw matching score between a query and a key (a particular book) can become so large that it overwhelms all other candidates. Scaling acts as a stabilizing adjustment that prevents this.

3. Normalization and Aggregation: <br>
    The softmax turns the scaled scores into what are commonly called **attention weights**. They are non-negative and sum to 1, so they behave like probabilities. For example, if one source receives a weight of $0.7$ and another $0.2$, the system does not exclusively select the highest-scoring source. Instead, it integrates information proportionally: the final representation is a weighted synthesis in which about 70% of the contribution comes from the primary source and 20% from a secondary but still relevant one.
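To make the metaphor concrete, here is a toy sketch in PyTorch with hand-picked numbers. The three-book shelf, the 2-dimensional vectors, and the names `shelf_keys`, `shelf_values`, and `query` are illustrative assumptions, not part of the original text; the point is only to show matching, scaling, weighting, and blending in sequence.

```python
import math
import torch

# Hypothetical 2-d "spine label" vectors for three books:
# Calculus, Optimization, Neural Networks
shelf_keys = torch.tensor([[1.0, 0.0],
                           [0.5, 0.5],
                           [0.0, 1.0]])
# Hypothetical "content" stored in each book (one-hot for readability)
shelf_values = torch.tensor([[1.0, 0.0, 0.0],
                             [0.0, 1.0, 0.0],
                             [0.0, 0.0, 1.0]])
# A query that leans toward "Neural Networks"
query = torch.tensor([0.2, 2.0])

d_k = query.shape[-1]
scores = shelf_keys @ query / math.sqrt(d_k)   # 1. matching + 2. scaling
weights = torch.softmax(scores, dim=-1)        # 3a. normalize into weights
answer = weights @ shelf_values                # 3b. weighted synthesis of contents

print(weights)  # largest weight on the "Neural Networks" book
print(answer)
```

Because each book's value here is a one-hot vector, the blended answer equals the weights, which makes the proportional synthesis easy to read off.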

### Mathematical Foundation

Let $X \in \mathbb{R}^{n \times d_{model}}$ represent your input sequence, where $n$ is the sequence length and $d_{model}$ is the embedding dimension. To generate the Query ($Q$), Key ($K$), and Value ($V$) matrices, we perform three independent linear transformations:

$$
Q = X W_q, \quad K = X W_k, \quad V = X W_v
$$

Where the weight matrices are defined as:

* $W_q, W_k \in \mathbb{R}^{d_{model} \times d_k}$
* $W_v \in \mathbb{R}^{d_{model} \times d_v}$
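A minimal sketch of the three projections, assuming PyTorch and illustrative sizes ($n = 5$, $d_{model} = 16$, $d_k = d_v = 8$). Note that `nn.Linear` stores its weight as a $(d_{out}, d_{in})$ matrix and computes `X @ weight.T`, so `weight.T` plays the role of $W_q$, $W_k$, or $W_v$ here.

```python
import torch
import torch.nn as nn

n, d_model, d_k, d_v = 5, 16, 8, 8   # illustrative sizes, not taken from the text

X = torch.randn(n, d_model)          # input sequence, one row per token

# Three independent projections; bias=False so each is exactly X @ W
W_q = nn.Linear(d_model, d_k, bias=False)
W_k = nn.Linear(d_model, d_k, bias=False)
W_v = nn.Linear(d_model, d_v, bias=False)

Q, K, V = W_q(X), W_k(X), W_v(X)
print(Q.shape, K.shape, V.shape)  # torch.Size([5, 8]) torch.Size([5, 8]) torch.Size([5, 8])
```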

**Random Initialization**

For initialization, we use **variance-scaling initialization**, such as **Xavier (Glorot) normal initialization**:

$$
W \sim \mathcal{N}\left(0, \frac{2}{d_{in} + d_{out}}\right)
$$

This keeps the variance of the activations (and of the back-propagated gradients) approximately constant across layers, preventing the signal from vanishing or exploding as it traverses the network.
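A quick empirical check, assuming PyTorch's `nn.init.xavier_normal_` (default gain of 1) and a square $512 \times 512$ layer: the sampled weights should have variance close to $2/(d_{in} + d_{out})$, and a unit-variance input should come out with roughly unit variance. The square shape is an assumption that makes the variance-preservation argument exact; for non-square layers the Xavier variance is a compromise between preserving the forward ($1/d_{in}$) and backward ($1/d_{out}$) signals.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_in, d_out = 512, 512
W = torch.empty(d_out, d_in)
nn.init.xavier_normal_(W)                  # std = sqrt(2 / (d_in + d_out)) with gain 1

print(W.var().item(), 2 / (d_in + d_out))  # empirical vs. target variance

x = torch.randn(1000, d_in)                # unit-variance inputs
y = x @ W.T                                # one linear layer
print(x.var().item(), y.var().item())      # both close to 1 when d_in == d_out
```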

---

**Positional Information Injection**

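We then inject position with **Rotary Position Embedding (RoPE)**. For position $i$, $R_i \in \mathbb{R}^{d_k \times d_k}$ is a block-diagonal matrix of $2 \times 2$ rotations, where the $m$-th block rotates by the angle $i\,\theta_m$ with $\theta_m = 10000^{-2m/d_k}$ ($m = 0, \dots, d_k/2 - 1$). Each query row and key row is rotated according to its own position: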

* **Rotated Query**: $Q' = \{ R_i q^{(i)} \}_{i=1}^n$
* **Rotated Key**: $K' = \{ R_j k^{(j)} \}_{j=1}^n$
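The sketch below builds $R_i$ explicitly as a dense matrix, assuming the adjacent-pair layout and base $10000$ (the helper `rope_matrix` is hypothetical; practical implementations rotate vectors element-wise instead of materializing $R_i$, and some libraries pair dimension $m$ with $m + d_k/2$ rather than $2m$ with $2m+1$). It checks the two properties the attention scores rely on: each $R_i$ is orthogonal, and $R_i^\top R_j = R_{j-i}$.

```python
import math
import torch

def rope_matrix(pos: int, d_k: int, base: float = 10000.0) -> torch.Tensor:
    """Hypothetical helper: explicit d_k x d_k block-diagonal rotation R_pos."""
    R = torch.zeros(d_k, d_k, dtype=torch.float64)
    for m in range(d_k // 2):
        angle = pos * base ** (-2 * m / d_k)      # pos * theta_m
        c, s = math.cos(angle), math.sin(angle)
        # 2x2 rotation block acting on dimensions (2m, 2m+1)
        R[2 * m, 2 * m], R[2 * m, 2 * m + 1] = c, -s
        R[2 * m + 1, 2 * m], R[2 * m + 1, 2 * m + 1] = s, c
    return R

d_k = 8
R2, R9, R7 = rope_matrix(2, d_k), rope_matrix(9, d_k), rope_matrix(7, d_k)

I = torch.eye(d_k, dtype=torch.float64)
print(torch.allclose(R2.T @ R2, I))    # orthogonal: R^T R = I
print(torch.allclose(R2.T @ R9, R7))   # R_i^T R_j = R_{j-i}
```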

---

**The Attention Operation**

The attention output is computed as:

$$
\text{Attention}(Q', K', V) = \text{softmax}\left(\frac{Q' (K')^\top}{\sqrt{d_k}}\right)V
$$
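A step-by-step sketch of this formula, assuming PyTorch 2.0 or later for the comparison against the built-in `torch.nn.functional.scaled_dot_product_attention`. The random tensors `Q_rot` and `K_rot` simply stand in for $Q'$ and $K'$ (the rotation is assumed to have been applied already).

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
n, d_k, d_v = 6, 16, 16
Q_rot = torch.randn(n, d_k)   # stands in for Q' (already rotated)
K_rot = torch.randn(n, d_k)   # stands in for K' (already rotated)
V = torch.randn(n, d_v)       # values are not rotated

scores = Q_rot @ K_rot.T / math.sqrt(d_k)   # (n, n) scaled similarities
weights = torch.softmax(scores, dim=-1)     # row-wise softmax: each row sums to 1
out = weights @ V                           # (n, d_v) weighted sum of values

# Built-in fused version; its default scale is also 1/sqrt(d_k).
# Inputs are given a (batch, heads) prefix of (1, 1).
ref = F.scaled_dot_product_attention(Q_rot[None, None], K_rot[None, None], V[None, None])
print(torch.allclose(out, ref[0, 0], atol=1e-5))
```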

**The Three Stages of Computation**

1. **Similarity Computation (Rotary Dot Product)**: <br>

    The product $Q' (K')^\top$ computes the similarity between the rotated query at position $i$ and the rotated key at position $j$. Because the rotations satisfy $R_i^\top R_j = R_{j-i}$, the score $\langle R_i q^{(i)}, R_j k^{(j)} \rangle = (q^{(i)})^\top R_{j-i}\, k^{(j)}$ depends on the positions only through their relative distance $(j-i)$, making the attention sensitive to relative position.

2. **Scaling**:<br>

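    The scores are divided by $\sqrt{d_k}$. If the components of $q$ and $k$ are independent with zero mean and unit variance, their dot product has variance $d_k$, so dividing by $\sqrt{d_k}$ restores unit variance. Without this, dot products of high-dimensional vectors reach extreme magnitudes and push the softmax into saturated regions with **vanishing gradients**. (The rotation itself preserves vector norms, so it does not enlarge the scores.)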

3. **Normalization and Aggregation**:<br>

    The row-wise softmax converts these relative-position-aware scores into a probability distribution over the keys. Finally, we multiply by $V$ (which is not rotated, since value vectors carry content rather than position) to aggregate information.
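The variance argument behind the scaling stage can be checked numerically. This sketch assumes query and key components drawn independently from a standard normal distribution, which is the idealized setting of the argument rather than what a trained model produces.

```python
import math
import torch

torch.manual_seed(0)
for d_k in (16, 64, 256):
    q = torch.randn(10_000, d_k)          # components with mean 0, variance 1
    k = torch.randn(10_000, d_k)
    dots = (q * k).sum(dim=-1)            # 10,000 independent dot products
    scaled = dots / math.sqrt(d_k)
    print(d_k, dots.var().item(), scaled.var().item())
# The raw variance grows roughly like d_k; the scaled variance stays near 1.
```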

---
**Causal Masking**

Because the product $QK^\top$ computes a score for every pair of positions, each token can attend to every other token in the sequence, including those that come after it. An auto-regressive model must not see the future, so we introduce **causal masking**, which forbids each position from attending to any token to its right.

We introduce a mask matrix $M$ into the similarity scores before the softmax operation.

The masked formula is:

$$
\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^\top}{\sqrt{d_k}} + M\right)V
$$

Where $M$ is a matrix of the same shape as the attention scores ($n \times n$). 

For a causal (auto-regressive) model, the elements of $M$ are defined as:

$$
M_{ij} = \begin{cases} 0 & \text{if } i \geq j \\ -\infty & \text{if } i < j \end{cases}
$$

When we add $-\infty$ to the scores where $i < j$ (the "future" tokens) and then apply the softmax, those entries vanish because

$$
e^{-\infty} = 0
$$

The attention weights for all future tokens become exactly zero. Consequently, the Query at position $i$ can only "see" and aggregate information from Keys at positions $0$ to $i$.
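A minimal sketch of the causal mask, assuming PyTorch and random stand-ins for $Q$, $K$, $V$. Here $M$ is added after scaling, as in typical implementations; because its entries are only $0$ and $-\infty$, this gives the same weights as adding $M$ before dividing by $\sqrt{d_k}$.

```python
import math
import torch

torch.manual_seed(0)
n, d_k = 5, 8
Q, K, V = torch.randn(n, d_k), torch.randn(n, d_k), torch.randn(n, d_k)

# M_ij = 0 if i >= j (past/present), -inf if i < j (future)
M = torch.triu(torch.full((n, n), float("-inf")), diagonal=1)

scores = Q @ K.T / math.sqrt(d_k) + M
weights = torch.softmax(scores, dim=-1)   # exp(-inf) = 0 for every future key
out = weights @ V                         # row i mixes only values 0..i

print(M)
print(weights)   # everything above the diagonal is exactly 0
```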

### Practical Application

#### Step 0: Preparation

* String Sequence: `"the cat ate the rat"`
* Token Sequence: `[9, 0, 2, 7, 0, 7, 3, 9, 0, 6, 7]`
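* Scenario: Let's observe how the Query for the token `c` (position 2, ID `2`) attends to the Key for the token `r` (position 9, ID `6`). Because `r` comes *after* `c`, this walkthrough assumes bidirectional (unmasked) attention; under the causal mask above, this score would be set to $-\infty$ and its weight would be zero.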

#### Step 1: Projecting BPE Sequences

**The Starting Point ($X$)**

The ID `2` is looked up in the embedding table to obtain a vector $x^{(2)} \in \mathbb{R}^{d_{model}}$. Let's assume $d_{model} = 512$.

$$
x^{(2)} = [0.15, -0.02, \dots, 0.88]
$$

**The Projection ($W_q, W_k, W_v$)**

1. **Generate Query**: $q^{(2)} = x^{(2)} \cdot W_q$. 

    This vector represents what the `c` token is **looking for** in other tokens.

2. **Generate Key**: $k^{(2)} = x^{(2)} \cdot W_k$. 

    This represents what **information** the `c` token offers to others.

3. **Generate Value**: $v^{(2)} = x^{(2)} \cdot W_v$. 

    This contains the **actual content** that is passed on: each output is a weighted sum of value vectors, so this vector contributes in proportion to the attention weight the `c` token receives.
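The projection step for the whole sequence can be sketched as follows. The token IDs and $d_{model} = 512$ come from the text; the vocabulary size of 10 (the largest ID is 9), $d_k = d_v = 64$ (matching Step 3), and the randomly initialized `nn.Embedding` and `nn.Linear` layers are assumptions for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
tokens = torch.tensor([9, 0, 2, 7, 0, 7, 3, 9, 0, 6, 7])  # token IDs from Step 0
vocab_size, d_model, d_k = 10, 512, 64                    # vocab_size and d_k are assumptions

embed = nn.Embedding(vocab_size, d_model)   # token ID -> vector x
W_q = nn.Linear(d_model, d_k, bias=False)
W_k = nn.Linear(d_model, d_k, bias=False)
W_v = nn.Linear(d_model, d_k, bias=False)

X = embed(tokens)                 # (11, 512): one row per position
Q, K, V = W_q(X), W_k(X), W_v(X)  # (11, 64) each

q_c = Q[2]   # what `c` at position 2 is looking for
k_r = K[9]   # what `r` at position 9 offers
v_r = V[9]   # the content `r` would contribute
print(X.shape, Q.shape, q_c.shape)
```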

#### Step 2: Similarity (The "Search")

The model computes the dot product between the rotated Query vector of `c` ($q'^{(2)}$) and the rotated Key vector of `r` ($k'^{(9)}$).

* **With RoPE**: 

    Because $q^{(2)}$ is rotated by angles proportional to position 2 and $k^{(9)}$ by angles proportional to position 9, their dot product depends only on the gap of $9 - 2 = 7$ positions. If the model has learned that a `c` and an `r` separated by 7 tokens are syntactically linked (e.g., in the words `cat` and `rat`), this score will be high.
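A sketch of the relative-offset property for this exact pair of positions, assuming the element-wise RoPE form with adjacent pairs and base $10000$ (the helper `apply_rope` and the random stand-in vectors are hypothetical). Shifting both tokens by the same amount leaves the score unchanged, while changing the gap changes it.

```python
import torch

def apply_rope(x: torch.Tensor, pos: int, base: float = 10000.0) -> torch.Tensor:
    """Hypothetical helper: rotate pairs (x[2m], x[2m+1]) by angle pos * theta_m."""
    d = x.shape[-1]
    theta = base ** (-torch.arange(0, d, 2, dtype=x.dtype) / d)  # theta_m = base^(-2m/d)
    cos, sin = torch.cos(pos * theta), torch.sin(pos * theta)
    x_even, x_odd = x[..., 0::2], x[..., 1::2]
    out = torch.empty_like(x)
    out[..., 0::2] = x_even * cos - x_odd * sin
    out[..., 1::2] = x_even * sin + x_odd * cos
    return out

torch.manual_seed(0)
d_k = 64
q_c = torch.randn(d_k, dtype=torch.float64)   # stand-in for q^{(2)}
k_r = torch.randn(d_k, dtype=torch.float64)   # stand-in for k^{(9)}

score_2_9 = apply_rope(q_c, 2) @ apply_rope(k_r, 9)      # gap 7
score_12_19 = apply_rope(q_c, 12) @ apply_rope(k_r, 19)  # same vectors, same gap 7
score_2_5 = apply_rope(q_c, 2) @ apply_rope(k_r, 5)      # gap 3

print(score_2_9.item(), score_12_19.item(), score_2_5.item())
```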

#### Step 3: Scaling (The "Stability")

Suppose the dot product result is $80.0$. If $d_k = 64$, we divide by $\sqrt{64} = 8$.

$$
\text{Score} = 80.0 / 8 = 10.0
$$

This keeps the softmax from "peaking" (saturating) on a single token too early, so the model can still assign some weight to, and pass gradients through, other tokens in the sequence (such as those in `ate` or `the`).
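The arithmetic above can be extended to a small comparison. The two competing scores, 72 and 64, are hypothetical; the sketch shows how the same scores behave with and without the $\sqrt{d_k}$ division, and uses the diagonal of the softmax Jacobian, $\partial w_i / \partial s_i = w_i (1 - w_i)$, as a rough proxy for how much gradient flows back.

```python
import math
import torch

d_k = 64
raw = torch.tensor([80.0, 72.0, 64.0])     # `r` plus two hypothetical competitors
scaled = raw / math.sqrt(d_k)              # tensor([10., 9., 8.])

w_raw = torch.softmax(raw, dim=-1)
w_scaled = torch.softmax(scaled, dim=-1)
print(w_raw)      # nearly all mass on the first token
print(w_scaled)   # roughly [0.67, 0.24, 0.09]

# Diagonal of the softmax Jacobian: dw_i/ds_i = w_i * (1 - w_i)
grad_diag_raw = w_raw * (1 - w_raw)
grad_diag_scaled = w_scaled * (1 - w_scaled)
print(grad_diag_raw[0].item(), grad_diag_scaled[0].item())  # about 3e-4 vs. about 0.22
```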

#### Step 4: Aggregation (The "Context")

The softmax assigns a high weight (e.g., $0.85$) to the `r` at position `9`. The final representation for `c` will now contain a heavy "dosage" of the Value vector for `r`.
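A sketch of this aggregation, assuming the remaining weight of $0.15$ is spread evenly over the other ten positions and using random stand-in value vectors (both are assumptions; the text only fixes the $0.85$ on `r`).

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
n, d_v = 11, 64
V = torch.randn(n, d_v)                 # stand-in value vectors for the 11 positions

w = torch.full((n,), 0.15 / (n - 1))    # hypothetical: remaining 0.15 spread evenly
w[9] = 0.85                             # weight on `r` at position 9

out_c = w @ V                           # new representation for `c`
cos = F.cosine_similarity(out_c, V[9], dim=0).item()
print(w.sum().item())                   # 1.0 up to float rounding
print(cos)                              # close to 1: dominated by the value of `r`
```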

**Result**: 

The representation of `cat` is now "aware" of the upcoming `rat` because of their relative spatial configuration.

## Code

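```python
import math
from typing import Optional

import torch
import torch.nn as nn
from torch import Tensor


class ScaledDotProductAttention(nn.Module):
    """Computes softmax(Q K^T / sqrt(d_k)) V, with an optional mask."""

    def __init__(self) -> None:
        super().__init__()
        # nn.Softmax is a module; F.softmax is a function that needs an input tensor
        self.softmax = nn.Softmax(dim=-1)

    def forward(
        self,
        Q: Tensor,
        K: Tensor,
        V: Tensor,
        mask: Optional[Tensor] = None,
    ) -> Tensor:
        # Q: (*batch, seq_len_q, d_k), K: (*batch, seq_len_k, d_k), V: (*batch, seq_len_k, d_v)
        d_k = Q.size(-1)

        # Stages 1 + 2: similarity scores scaled by sqrt(d_k) -> (*batch, seq_len_q, seq_len_k)
        atten_score = Q @ K.transpose(-2, -1) / math.sqrt(d_k)

        # The mask must broadcast to the score shape. Positions where mask == 0 (or False)
        # are set to -inf, so the softmax gives them exactly zero weight (e^{-inf} = 0).
        if mask is not None:
            atten_score = atten_score.masked_fill(mask == 0, float("-inf"))

        # Stage 3: row-wise softmax -> attention weights, then weighted sum of values
        atten_weights = self.softmax(atten_score)
        output = atten_weights @ V

        return output
```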
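The class follows the convention that `mask == 0` (or `False`) blocks a position. The self-contained check below, assuming PyTorch 2.0 or later, re-implements the same logic in a hypothetical helper `sdpa_masked_fill` and compares it with `torch.nn.functional.scaled_dot_product_attention`, whose boolean `attn_mask` uses `True` for positions that take part in attention. With a lower-triangular mask, both should also match the built-in `is_causal=True` path.

```python
import math
import torch
import torch.nn.functional as F

def sdpa_masked_fill(Q, K, V, mask=None):
    """Hypothetical helper: same logic as the class, mask == 0 -> blocked."""
    scores = Q @ K.transpose(-2, -1) / math.sqrt(Q.size(-1))
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float("-inf"))
    return torch.softmax(scores, dim=-1) @ V

torch.manual_seed(0)
B, H, n, d_k = 2, 4, 7, 16
Q, K, V = (torch.randn(B, H, n, d_k) for _ in range(3))
causal = torch.tril(torch.ones(n, n, dtype=torch.bool))  # True = may attend

mine = sdpa_masked_fill(Q, K, V, causal)
ref_bool = F.scaled_dot_product_attention(Q, K, V, attn_mask=causal)
ref_causal = F.scaled_dot_product_attention(Q, K, V, is_causal=True)
print(torch.allclose(mine, ref_bool, atol=1e-5), torch.allclose(mine, ref_causal, atol=1e-5))
```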