## Attention and Self-Attention

Suppose we have a set of $n$ information vectors (the "memory"):

$$
V_1, V_2, \dots, V_n \in \mathbb{R}^d.
$$



We want to produce a single output vector that is a weighted combination of these:

$$
\text{output} = \sum_{i=1}^n \alpha_i V_i,
$$

where the weights $\alpha_i$ satisfy:

- $\alpha_i \ge 0$ for all $i$,
- $\sum_{i=1}^n \alpha_i = 1$.

In other words, the weights form a **probability distribution** over the vectors $V_i$, telling us how much to "attend" to each one.
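A minimal sketch of this weighted combination in plain Python (lists instead of tensors; the helper name `weighted_combination` is our own, for illustration). It checks that the weights really form a distribution before mixing:

```python
import math

def weighted_combination(alphas, vectors):
    """Return sum_i alphas[i] * vectors[i] for equal-length vectors."""
    # The weights must be a probability distribution: non-negative, summing to 1.
    if any(a < 0 for a in alphas):
        raise ValueError("weights must be non-negative")
    if not math.isclose(sum(alphas), 1.0, rel_tol=1e-9):
        raise ValueError("weights must sum to 1")
    d = len(vectors[0])
    # Mix each coordinate separately.
    return [sum(a * v[k] for a, v in zip(alphas, vectors)) for k in range(d)]

V = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0]]
alphas = [0.5, 0.25, 0.25]
print(weighted_combination(alphas, V))  # [1.0, 0.75]
```

Because the weights are non-negative and sum to 1, the output always lies in the convex hull of the $V_i$; it can never be "larger" than the memory it is built from.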



## Why?

To produce these weights, attention has to answer two questions:

1. How similar is each memory vector $V_i$ to what we are looking for?
2. How do we convert these similarities into normalized weights $\alpha_i$?


## Idea of attention

- A **query vector** $Q$ describes what we are looking for.
- Each memory item has a **key vector** $K_i$ (used for matching) and a **value vector** $V_i$ (the content it contributes).
- The attention weights come from **query–key similarity**.

## In math

Each item in memory has:

- A **key** $K_i \in \mathbb{R}^d$, used to decide *how much* we attend to it.
- A **value** $V_i \in \mathbb{R}^d$, the actual content we mix together.

We also have a **query** vector $Q \in \mathbb{R}^d$.

1. Compute a score for each item $i$:

$$
s_i = \langle Q, K_i \rangle
$$

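Here $\langle \cdot, \cdot \rangle$ is a similarity function; the usual choice is the dot product $\langle Q, K_i \rangle = Q^\top K_i$.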

2. Normalize scores with a softmax:

$$
\alpha_i = \frac{\exp(s_i)}{\sum_{j=1}^n \exp(s_j)}.
$$

3. Combine values:

$$
\text{Attention}(Q, K, V) = \sum_{i=1}^n \alpha_i V_i.
$$
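The three steps translate almost line for line into plain Python. This is a minimal sketch (lists instead of tensors; the helper names `softmax`, `dot`, and `attend` are our own) with one standard tweak: the softmax subtracts the largest score first, which leaves the weights unchanged but keeps `exp` from overflowing on large scores.

```python
import math

def softmax(scores):
    # Subtracting the max does not change the result (softmax is invariant to
    # adding a constant to every score) but avoids overflow in exp().
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def attend(Q, keys, values):
    # 1. Score every memory item by its dot product with the query.
    scores = [dot(Q, K_i) for K_i in keys]
    # 2. Normalize the scores into a probability distribution.
    alphas = softmax(scores)
    # 3. Mix the values with those weights.
    d_v = len(values[0])
    output = [sum(a * V_i[k] for a, V_i in zip(alphas, values)) for k in range(d_v)]
    return output, alphas

keys = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
values = [[10.0, 0.0], [0.0, 10.0], [5.0, 5.0]]
Q = [1.0, 0.0]
out, alphas = attend(Q, keys, values)
print([round(a, 2) for a in alphas])  # [0.42, 0.16, 0.42]
print([round(x, 2) for x in out])     # [6.33, 3.67]
```

Keys 1 and 3 both have dot product 1 with the query, so they receive the same weight $e/(2e+1) \approx 0.42$, while key 2 (score 0) gets $1/(2e+1) \approx 0.16$.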

In vectorized form, stack the keys as the rows of $K \in \mathbb{R}^{n \times d}$ and the values as the rows of $V \in \mathbb{R}^{n \times d}$, and treat $Q$ as a $1 \times d$ row vector:

- Scores: $s = K Q^\top \in \mathbb{R}^n$,
- Weights: $\alpha = \text{softmax}(s) \in \mathbb{R}^n$,
- Output: $\text{Attention}(Q,K,V) = \alpha^\top V \in \mathbb{R}^d$.
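A small worked example of the vectorized form, with $n = d = 2$ and numbers chosen for illustration so that the softmax comes out exact (the query $Q$ is a row vector):

$$
K = \begin{pmatrix} 1 & 0 \\ 0 & 1 \end{pmatrix}, \quad
V = \begin{pmatrix} 1 & 2 \\ 3 & 4 \end{pmatrix}, \quad
Q = \begin{pmatrix} \ln 3 & 0 \end{pmatrix},
$$

$$
s = K Q^\top = \begin{pmatrix} \ln 3 \\ 0 \end{pmatrix}, \qquad
\alpha = \text{softmax}(s) = \begin{pmatrix} \frac{3}{3+1} \\ \frac{1}{3+1} \end{pmatrix} = \begin{pmatrix} 3/4 \\ 1/4 \end{pmatrix},
$$

$$
\alpha^\top V = \tfrac{3}{4}\,(1,\ 2) + \tfrac{1}{4}\,(3,\ 4) = (1.5,\ 2.5).
$$

The first key matches the query better, so the output lands three quarters of the way toward the first value.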

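## Scaled dot-product attention

In practice we handle many queries at once: stack $m$ queries as the rows of $Q \in \mathbb{R}^{m \times d_k}$, with keys $K \in \mathbb{R}^{n \times d_k}$ and values $V \in \mathbb{R}^{n \times d_v}$. The scaled version also divides the scores by $\sqrt{d_k}$, where $d_k$ is the query/key dimension, and applies the softmax to each row of the $m \times n$ score matrix:

$$
\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{Q K^\top}{\sqrt{d_k}}\right) V \in \mathbb{R}^{m \times d_v}.
$$

The scaling keeps the scores from growing with dimension: if the entries of a query and a key are independent with mean 0 and variance 1, their dot product has variance $d_k$. Without it, large scores would push the softmax toward a one-hot distribution whose gradients are very small.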
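A quick simulation of why the $\sqrt{d_k}$ factor is there, where $d_k$ is the dimension of queries and keys. It assumes (for illustration only) that query and key entries are independent standard normal samples; the helper name `dot_product_variances` is our own.

```python
import random
import statistics

def dot_product_variances(d_k, n_samples=4000, seed=0):
    """Sample random q, k in R^{d_k} with i.i.d. N(0, 1) entries and return the
    variance of q.k and of q.k / sqrt(d_k) over the samples."""
    rng = random.Random(seed)  # fixed seed so the run is reproducible
    raw = []
    for _ in range(n_samples):
        q = [rng.gauss(0.0, 1.0) for _ in range(d_k)]
        k = [rng.gauss(0.0, 1.0) for _ in range(d_k)]
        raw.append(sum(a * b for a, b in zip(q, k)))
    scaled = [s / d_k ** 0.5 for s in raw]
    return statistics.pvariance(raw), statistics.pvariance(scaled)

for d_k in (4, 64):
    raw_var, scaled_var = dot_product_variances(d_k)
    print(f"d_k={d_k:2d}  var(q.k)={raw_var:6.2f}  var(q.k/sqrt(d_k))={scaled_var:4.2f}")
```

Expect the first variance to track $d_k$ (roughly 4 and roughly 64) and the second to stay near 1 in both cases, which is exactly what the division is meant to achieve.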

## Self-attention

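- In attention in general, the queries may come from one sequence while the keys and values come from another (often called cross-attention).
- In **self-attention**, $Q$, $K$, and $V$ are all computed from the **same** sequence, so every position attends to every position of that sequence, including itself.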
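The scaled formula works for any $Q$, $K$, $V$; cross-attention and self-attention differ only in where those matrices come from. A minimal pure-Python sketch (lists of rows stand in for matrices; all function names are our own), called once with queries from one sequence and keys/values from another, and once with a single sequence $X$ for all three:

```python
import math

def softmax(row):
    m = max(row)  # subtracting the max avoids overflow without changing the result
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]

def scaled_dot_product_attention(Q, K, V):
    """Q: m x d_k, K: n x d_k, V: n x d_v (lists of rows).
    Returns the m x d_v output and the m x n weight matrix."""
    d_k = len(Q[0])
    # Row i of the score matrix holds q_i . k_j / sqrt(d_k) for every key j.
    scores = [[sum(a * b for a, b in zip(q, k)) / math.sqrt(d_k) for k in K] for q in Q]
    # Softmax each row, so every query gets its own distribution over the keys.
    weights = [softmax(row) for row in scores]
    # Output row i is the weighted combination of the value rows.
    d_v = len(V[0])
    output = [[sum(alpha * v[c] for alpha, v in zip(row, V)) for c in range(d_v)]
              for row in weights]
    return output, weights

# Cross-attention: 2 queries from sequence A attend over the 3 items of sequence B.
A = [[1.0, 0.0], [0.0, 1.0]]
B = [[1.0, 1.0], [2.0, 0.0], [0.0, 3.0]]
out, w = scaled_dot_product_attention(A, B, B)
print(len(out), len(w[0]))  # 2 3

# Self-attention: Q, K and V all come from the same sequence X.
X = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.5, -0.5]]
out, w = scaled_dot_product_attention(X, X, X)
print(len(out), len(w[0]))  # 4 4
```

In the cross-attention call the weight matrix is $2 \times 3$ (one row per query, one column per key); in the self-attention call it is $L \times L$ with $L = 4$, because every position of $X$ attends to every position of $X$.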


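## In linear algebra

Let $X \in \mathbb{R}^{L \times d_{\text{model}}}$ hold the input sequence, one token per row. Self-attention projects $X$ three ways and then applies scaled dot-product attention:
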
$$
  Q = X W_Q, \quad
  K = X W_K, \quad
  V = X W_V, \quad
  \text{Output} = \text{softmax}\left(\frac{Q K^\top}{\sqrt{d_k}}\right) V.
$$

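Here $W_Q, W_K \in \mathbb{R}^{d_{\text{model}} \times d_k}$ and $W_V \in \mathbb{R}^{d_{\text{model}} \times d_v}$ are learnable projection matrices, so $QK^\top$ is an $L \times L$ matrix (one row of attention weights per token after the softmax) and the output is $L \times d_v$.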


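## Multi-head self-attention

Multi-head self-attention runs $h$ self-attention "heads" in parallel. Head $j$ has its own projections $W_Q^{(j)}, W_K^{(j)}, W_V^{(j)}$ into a smaller dimension $d_{\text{head}} = d_{\text{model}} / h$; the $h$ head outputs are concatenated and mixed by an output projection $W_O$:

$$
\text{head}_j = \text{softmax}\left(\frac{(X W_Q^{(j)})(X W_K^{(j)})^\top}{\sqrt{d_{\text{head}}}}\right) X W_V^{(j)}, \qquad
\text{Output} = \text{Concat}(\text{head}_1, \dots, \text{head}_h)\, W_O.
$$

Because each head has its own projections, different heads can attend to different positions for the same token.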
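A minimal pure-Python sketch of the split / attend / concatenate recipe, under two simplifying assumptions: the inputs are already-projected $Q$, $K$, $V$, so each head simply takes a contiguous slice of $d_{\text{head}}$ columns (the same slicing the `MultiHeadSelfAttention` class in the code section gets from `view` and `transpose`), and the output projection $W_O$ is left out. All function names are our own, for illustration.

```python
import math

def softmax(row):
    m = max(row)  # subtracting the max avoids overflow without changing the result
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]

def attention(Q, K, V):
    """Scaled dot-product attention on lists of rows; returns len(Q) x len(V[0])."""
    d_k = len(Q[0])
    out = []
    for q in Q:
        weights = softmax([sum(a * b for a, b in zip(q, k)) / math.sqrt(d_k) for k in K])
        out.append([sum(w * v[c] for w, v in zip(weights, V)) for c in range(len(V[0]))])
    return out

def split_heads(M, h):
    """Split the columns of M into h contiguous blocks of width d_head, one per head."""
    d_head = len(M[0]) // h
    return [[row[j * d_head:(j + 1) * d_head] for row in M] for j in range(h)]

def merge_heads(heads):
    """Concatenate the per-head outputs column-wise (the Concat step)."""
    return [[x for head in heads for x in head[i]] for i in range(len(heads[0]))]

def multi_head_attention(Q, K, V, h):
    """Hypothetical helper: Q, K, V are assumed already projected (L x d_model);
    the output projection W_O is omitted."""
    assert len(Q[0]) % h == 0, "d_model must be divisible by the number of heads"
    heads = [attention(q, k, v)
             for q, k, v in zip(split_heads(Q, h), split_heads(K, h), split_heads(V, h))]
    return merge_heads(heads)

# Toy example with identity projections (Q = K = V = X): L = 3, d_model = 4, h = 2.
X = [[1.0, 0.0, 0.5, -1.0], [0.0, 1.0, 1.0, 0.0], [1.0, 1.0, -0.5, 0.5]]
out = multi_head_attention(X, X, X, h=2)
print(len(out), len(out[0]))  # 3 4
```

Each head only ever sees its own column slice, so the first $d_{\text{head}}$ output columns are exactly single-head attention on the first slice, and with $h = 1$ the function reduces to plain scaled dot-product attention.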

## In code

```python
import torch
import torch.nn.functional as F

X = torch.randn(4, 5)   # shape: (seq_len, d_model); random, no seed is set
X
```

```
tensor([[ 1.1550e+00,  1.3382e+00,  1.6987e-03, -1.2204e+00,  3.5535e-01],
        [-1.1931e+00,  9.6666e-01,  3.7223e-01,  2.2102e-01,  1.0763e+00],
        [ 9.9946e-02, -1.7015e-01, -1.2487e+00,  7.5870e-01, -4.2486e-01],
        [ 1.1354e+00,  1.1884e+00, -1.7155e+00,  5.7872e-01,  9.4685e-01]])
```

```python
def attention(Q, K, V):
    """
    Scaled dot-product attention for a single query.

    Q: (d_k,)
    K: (n, d_k)
    V: (n, d_v)
    returns: output (d_v,), weights (n,)
    """
    d_k = Q.shape[-1]
    scores = (K @ Q) / d_k**0.5          # (n,)
    weights = F.softmax(scores, dim=0)   # (n,)
    output = weights @ V                 # (d_v,)
    return output, weights
```


```python
# Example: use X as both keys and values; choose the first token as the query
Q = X[0]
K = X
V = X

out, w = attention(Q, K, V)
print("output:", out)
print("weights:", w)
```

```
output: tensor([ 0.9203,  1.2058, -0.4342, -0.5913,  0.5169])
weights: tensor([0.6394, 0.0777, 0.0450, 0.2378])
```


```python
def self_attention(X):
    """
    Self-attention without learned projections (Q = K = V = X).

    X: (L, d_model)
    returns: output (L, d_model), weights (L, L)
    """
    Q = X        # (L, d_model)
    K = X        # (L, d_model)
    V = X        # (L, d_model)

    d_k = X.shape[-1]
    # scores: (L, L)
    scores = Q @ K.T / d_k**0.5

    # weights: (L, L), softmax over each row
    weights = F.softmax(scores, dim=1)

    # output: (L, d_model)
    output = weights @ V
    return output, weights

out, w = self_attention(X)
print("output shape:", out.shape)
print("weights shape:", w.shape)
```

```
output shape: torch.Size([4, 5])
weights shape: torch.Size([4, 4])
```


```python
class SelfAttention(torch.nn.Module):
    """Single-head self-attention with learned Q/K/V projections."""

    def __init__(self, d_model, d_k):
        super().__init__()
        self.W_Q = torch.nn.Linear(d_model, d_k, bias=False)
        self.W_K = torch.nn.Linear(d_model, d_k, bias=False)
        # For simplicity the value projection also maps to d_k, so d_v = d_k here.
        self.W_V = torch.nn.Linear(d_model, d_k, bias=False)

    def forward(self, X):
        Q = self.W_Q(X)     # (L, d_k)
        K = self.W_K(X)     # (L, d_k)
        V = self.W_V(X)     # (L, d_k)

        d_k = Q.shape[-1]
        scores = Q @ K.T / d_k**0.5          # (L, L)
        weights = F.softmax(scores, dim=1)   # row-wise softmax
        return weights @ V                   # (L, d_k)

sa = SelfAttention(d_model=5, d_k=4)
out = sa(X)
out
```

```
tensor([[-0.1983,  0.2171,  0.2833,  0.2332],
        [-0.1736,  0.1477,  0.2335,  0.4675],
        [-0.1423,  0.1043,  0.2284,  0.6015],
        [-0.1806,  0.1563,  0.2324,  0.4391]], grad_fn=<MmBackward0>)
```

```python
class MultiHeadSelfAttention(torch.nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0
        self.d_head = d_model // num_heads
        self.num_heads = num_heads

        self.W_Q = torch.nn.Linear(d_model, d_model, bias=False)
        self.W_K = torch.nn.Linear(d_model, d_model, bias=False)
        self.W_V = torch.nn.Linear(d_model, d_model, bias=False)
        self.W_O = torch.nn.Linear(d_model, d_model, bias=False)

    def forward(self, X):
        L, d_model = X.shape

        Q = self.W_Q(X)     # (L, d_model)
        K = self.W_K(X)
        V = self.W_V(X)

        # reshape into heads
        Q = Q.view(L, self.num_heads, self.d_head).transpose(0, 1)  # (h, L, d_head)
        K = K.view(L, self.num_heads, self.d_head).transpose(0, 1)
        V = V.view(L, self.num_heads, self.d_head).transpose(0, 1)

        scores = Q @ K.transpose(1, 2) / (self.d_head ** 0.5)      # (h, L, L)
        weights = F.softmax(scores, dim=-1)
        out = weights @ V                                          # (h, L, d_head)

        # merge heads
        out = out.transpose(0, 1).contiguous().view(L, d_model)
        return self.W_O(out)

mhsa = MultiHeadSelfAttention(d_model=8, num_heads=2)
X2 = torch.randn(4, 8)
mhsa(X2)
```

```
tensor([[-0.0437, -0.0520,  0.3056,  0.2199,  0.0880, -0.1595,  0.0579,  0.0900],
        [-0.0722,  0.1792,  0.4519,  0.1704,  0.0363, -0.0569, -0.0437,  0.1474],
        [-0.0450, -0.0899,  0.2666,  0.2517,  0.1145, -0.1877,  0.0941,  0.1019],
        [-0.0671, -0.0384,  0.2782,  0.2572,  0.1151, -0.1889,  0.0960,  0.1342]],
       grad_fn=<MmBackward0>)
```