In [1]:
import os
import math
import pathlib
import sys
import torch
sys.path.append("..")

from jobs.configure import GPT2Config
from gpt.model import GPT2Attention
import torch.nn.functional as F

For simplicity, the attention layer is configured with small dimensions, as shown below.

In [2]:
batch_size = 16  # denote as B
token_count = 5  # denote as T
embed_size = 12  # denote as S
num_heads = 3    # denote as H

cfg = GPT2Config(
    block_size=token_count,
    n_embd=embed_size,
    n_head=num_heads,
)
layer = GPT2Attention(cfg)
layer.eval()

GPT2Attention(
  (c_attn): Linear(in_features=12, out_features=36, bias=False)
  (c_proj): Linear(in_features=12, out_features=12, bias=False)
  (attn_dropout): Dropout(p=0.2, inplace=False)
  (resid_dropout): Dropout(p=0.2, inplace=False)
)

## Causal Attention

A batch of token ids has shape (B, T). Passing it through an embedding layer yields a tensor of shape (B, T, S), as in the input tensor illustration below. The causal attention layer maps this tensor to a new tensor of the same shape (B, T, S) in its forward computation. This notebook demonstrates the steps of that computation in detail.
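As a quick shape check of the embedding step described above, here is a minimal sketch. The vocabulary size and the standalone `nn.Embedding` are assumptions made only for illustration; they are not taken from this repository's model.

```python
import torch
import torch.nn as nn

B, T, S = 16, 5, 12
vocab_size = 100  # hypothetical vocabulary size, chosen only for illustration

token_ids = torch.randint(0, vocab_size, (B, T))  # (B, T) integer token ids
embedding = nn.Embedding(vocab_size, S)           # stand-in embedding layer
embedded = embedding(token_ids)                   # (B, T, S)
```

Each integer id is replaced by a vector of length S, which is why the trailing dimension appears.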

In [3]:
input_tensor = torch.randn(batch_size, cfg.block_size, cfg.n_embd)  # (B, T, S)
B, T, S = input_tensor.size()
original_output = layer(input_tensor)
input_tensor.size() == original_output.size()

True

![input tensor](./images/causal_attention/input_tensor.png)

`input_tensor` is multiplied by the weight matrix of `c_attn` to map it into the query, key, and value tensors, respectively. Again, for simplicity, the illustrations from now on zoom into only the first element of the batch (i.e. `input_tensor[0]`).

In [4]:
q, k, v = layer.c_attn(input_tensor).split(S, dim=2)
q.size() == k.size() == v.size() == input_tensor.size()

True

![qkv split](./images/causal_attention/qkv_split.png)
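The fused projection followed by `split` can be read as three separate projections that share one weight matrix. The sketch below checks this on a standalone `nn.Linear` that stands in for `c_attn`; the sizes and the stand-in layer are assumptions for illustration only.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
B, T, S = 2, 5, 12  # small illustrative sizes

c_attn = nn.Linear(S, 3 * S, bias=False)  # stand-in for the fused q/k/v projection
x = torch.randn(B, T, S)

# Fused projection, then a split along the last dimension.
q, k, v = c_attn(x).split(S, dim=2)

# The fused weight has shape (3S, S); its row blocks act as separate projections.
w_q, w_k, w_v = c_attn.weight.split(S, dim=0)  # each (S, S)
q_ref, k_ref, v_ref = x @ w_q.T, x @ w_k.T, x @ w_v.T

qkv_split_matches = (
    torch.allclose(q, q_ref, atol=1e-5)
    and torch.allclose(k, k_ref, atol=1e-5)
    and torch.allclose(v, v_ref, atol=1e-5)
)
```

Fusing the three projections into one matrix multiplication is a common efficiency choice; the result is the same as projecting three times.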

Each of the query, key, and value tensors is reshaped into a tensor of shape (B, H, T, S // H), so that each head in multi-head attention receives its own slice of the input.

In [5]:
q = q.view(B, T, layer.num_heads, layer.head_dim).transpose(1, 2)
k = k.view(B, T, layer.num_heads, layer.head_dim).transpose(1, 2)
v = v.view(B, T, layer.num_heads, layer.head_dim).transpose(1, 2)
q.size()

torch.Size([16, 3, 5, 4])

![multihead view](./images/causal_attention/multihead_view.png)
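To see what the `view` and `transpose` pair does to the data, the following sketch uses a tensor filled with consecutive numbers and checks that head `h` receives columns `h*D` through `(h+1)*D - 1` of the original tensor. The name `D` (for S // H) and the sizes are illustrative assumptions.

```python
import torch

B, T, H, D = 2, 5, 3, 4  # D plays the role of S // H
S = H * D
q = torch.arange(B * T * S, dtype=torch.float32).view(B, T, S)

q_heads = q.view(B, T, H, D).transpose(1, 2)  # (B, H, T, D)

# Each head is a contiguous block of columns from the original tensor.
heads_are_column_blocks = all(
    torch.equal(q_heads[:, h], q[:, :, h * D:(h + 1) * D]) for h in range(H)
)
```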

Then, to ensure that the $i$-th token attends only to the $1,\ldots,i$-th tokens:
1. the dot product between query and key is calculated and scaled (these are the softmax logits)
2. the elements above the diagonal of this dot product are replaced with the smallest possible value (-inf)
3. the softmax function is applied, so that the replaced elements yield a softmax value of 0

In [6]:
att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))     # scaled dot product
att = att.masked_fill(layer.bias[:, :, :T, :T] == 0, float("-inf")) # masked_fill
att = F.softmax(att, dim=-1)
y = att @ v
y.size()

torch.Size([16, 3, 5, 4])

![masked scaled dot product](./images/causal_attention/masked_scaled_dot_product.gif)
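The effect of the mask is easiest to read on a tiny example. This sketch assumes the mask is a lower-triangular matrix of ones, as is common in GPT-2 style implementations; the definition of `layer.bias` is not shown in this notebook, so treat that as an assumption.

```python
import torch
import torch.nn.functional as F

T = 4
scores = torch.zeros(T, T)           # uniform logits keep the result easy to read
mask = torch.tril(torch.ones(T, T))  # assumed mask: 1 on and below the diagonal

masked = scores.masked_fill(mask == 0, float("-inf"))
weights = F.softmax(masked, dim=-1)
print(weights)
```

With uniform logits, row `i` (counting from 0) spreads its weight evenly over the first `i + 1` positions, and every entry above the diagonal is exactly 0, because `exp(-inf)` is 0.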

Since the resulting tensor `y` is still separated by head, the heads have to be combined into a single (B, T, S) tensor. Note that although the illustration draws `y` as several tensors, it is actually a single tensor of shape (B, H, T, S // H). Viewing it directly as (B, T, S) would mix elements from different tokens, so the head and token dimensions are first swapped back with `transpose`, and the result is made contiguous before `view` merges the heads, as below.

In [7]:
y = y.transpose(1, 2).contiguous().view(B, T, S)
y.size()

torch.Size([16, 5, 12])

![reshape](./images/causal_attention/reshape.gif)
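The following sketch illustrates why each part of `transpose(1, 2).contiguous().view(B, T, S)` matters. It uses a random tensor with illustrative sizes (`D` stands for S // H) rather than the notebook's `y`.

```python
import torch

torch.manual_seed(0)
B, H, T, D = 2, 3, 5, 4  # D plays the role of S // H
S = H * D
y = torch.randn(B, H, T, D)

merged = y.transpose(1, 2).contiguous().view(B, T, S)

# After merging, head h occupies columns h*D .. (h+1)*D - 1 again.
heads_restored = all(
    torch.equal(merged[:, :, h * D:(h + 1) * D], y[:, h]) for h in range(H)
)

# A direct view without the transpose runs, but mixes heads and tokens.
naive = y.view(B, T, S)
naive_differs = not torch.equal(merged, naive)

# Skipping .contiguous() fails: the transposed strides are incompatible with view.
try:
    y.transpose(1, 2).view(B, T, S)
    view_needs_contiguous = False
except RuntimeError:
    view_needs_contiguous = True
```

`Tensor.reshape` would also work in place of `.contiguous().view(...)`, since it copies when a view is not possible.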

After an additional matrix multiplication with the `c_proj` linear layer, this tensor becomes the output of the causal attention layer. In the full model, this output passes through the remaining layers before reaching lm_head, whose output is used as the logits of a softmax to compute the cross-entropy loss over next tokens.

In [8]:
torch.allclose(layer.c_proj(y), original_output)

True
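As a cross-check of the manual steps, the same masked, scaled dot-product attention can be compared against PyTorch's fused `F.scaled_dot_product_attention`, which is available in PyTorch 2.0 and later. This is a sketch on random tensors with illustrative sizes; it does not claim that `GPT2Attention` uses the fused call internally.

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, H, T, D = 2, 3, 5, 4  # D plays the role of S // H
q, k, v = (torch.randn(B, H, T, D) for _ in range(3))

# Manual path: scale, mask above the diagonal, softmax, weight the values.
mask = torch.tril(torch.ones(T, T, dtype=torch.bool))
att = (q @ k.transpose(-2, -1)) / math.sqrt(D)
att = att.masked_fill(~mask, float("-inf"))
y_manual = F.softmax(att, dim=-1) @ v

# Fused path with the built-in causal mask.
y_fused = F.scaled_dot_product_attention(q, k, v, is_causal=True)

paths_match = torch.allclose(y_manual, y_fused, atol=1e-5)
```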

## Key-Value Caching

Now suppose that a new token has been appended and the system is asked to generate the next token given this updated token sequence. In the illustration, query, key, and value elements already calculated in the previous iteration are underlined, and newly calculated elements are in boldface.

![qkv split](./images/kv_cache/qkv_split.png)

The illustration below demonstrates the flow of new token generation:

1. As explained above, the query (q) and key (k) tensors are divided into H sub-tensors, and their dot product is calculated to obtain the logits that weight the value (v) tensor.
2. Each dot product is then masked and passed through softmax to calculate the output tensor `y`.
3. However, only the last row of `y` is used for next token generation. Although every element of the value tensor is still needed, only the last row of the dot products is needed in this context. Whitened cells in the tensors represent the redundant elements of the softmax-applied dot product.
4. Tracing the computation of these non-redundant logits backwards, it turns out that only the last row of the query tensor is required.

![cache flow](./images/kv_cache/cache_flow.gif)
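The claim that the last output row needs only the last query row can be checked numerically. The sketch below uses random tensors with illustrative sizes and compares the last row of full causal attention with attention computed from a single query row.

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, H, T, D = 2, 3, 5, 4  # D plays the role of S // H
q, k, v = (torch.randn(B, H, T, D) for _ in range(3))

# Full causal attention over all T positions.
mask = torch.tril(torch.ones(T, T, dtype=torch.bool))
att = (q @ k.transpose(-2, -1)) / math.sqrt(D)
y_full = F.softmax(att.masked_fill(~mask, float("-inf")), dim=-1) @ v

# Last position only: one query row against all keys and values.
# No mask is needed, because the last token may attend to every position.
q_last = q[:, :, -1:, :]                                  # (B, H, 1, D)
att_last = (q_last @ k.transpose(-2, -1)) / math.sqrt(D)  # (B, H, 1, T)
y_last = F.softmax(att_last, dim=-1) @ v                  # (B, H, 1, D)

last_row_matches = torch.allclose(y_full[:, :, -1:, :], y_last, atol=1e-5)
```

The full key and value tensors are still required, which is what motivates caching them.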

In conclusion, to generate the next token more efficiently by avoiding duplicated query, key, and value computation:

* Only the query, key, and value corresponding to the new token have to be calculated
* The other elements of the key and value tensors should be cached and reused to form the full key and value tensors

![conclusion](./images/kv_cache/conclusion.png)
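The two points above can be sketched as a small incremental attention loop. The helpers `split_heads` and `step_with_cache`, and the standalone `c_attn` layer, are hypothetical names introduced only for this illustration; they are not the API of `GPT2Attention`, and the output projection is left out for brevity.

```python
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
S, H = 12, 3
D = S // H
c_attn = nn.Linear(S, 3 * S, bias=False)  # hypothetical stand-in for the fused projection


def split_heads(t):
    # (B, T, S) -> (B, H, T, D)
    B, T, _ = t.shape
    return t.view(B, T, H, D).transpose(1, 2)


def step_with_cache(x_new, cache):
    """Attend from one new token of shape (B, 1, S), reusing cached keys and values."""
    q, k, v = (split_heads(t) for t in c_attn(x_new).split(S, dim=2))
    if cache is not None:
        k = torch.cat([cache[0], k], dim=2)  # append along the token dimension
        v = torch.cat([cache[1], v], dim=2)
    att = F.softmax((q @ k.transpose(-2, -1)) / math.sqrt(D), dim=-1)
    return att @ v, (k, v)


x = torch.randn(2, 5, S)  # (B, T, S)

with torch.no_grad():
    # Incremental pass: one token at a time, projecting only the new token.
    cache, outputs = None, []
    for t in range(x.size(1)):
        y_t, cache = step_with_cache(x[:, t:t + 1], cache)
        outputs.append(y_t)
    y_cached = torch.cat(outputs, dim=2)  # (B, H, T, D)

    # Reference: full causal attention over the whole sequence at once.
    T = x.size(1)
    q, k, v = (split_heads(t) for t in c_attn(x).split(S, dim=2))
    mask = torch.tril(torch.ones(T, T, dtype=torch.bool))
    att = (q @ k.transpose(-2, -1)) / math.sqrt(D)
    y_full = F.softmax(att.masked_fill(~mask, float("-inf")), dim=-1) @ v

cache_matches_full = torch.allclose(y_cached, y_full, atol=1e-5)
```

Each step projects a single token and performs one `torch.cat` per cached tensor, instead of recomputing the query, key, and value for the whole sequence. The cache grows by one row per generated token, so memory is traded for compute.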