# KV Cache

1. Attention 类模型如 transformer / GPT 的训练和推理有什么特性
2. Decoder 预测时有什么计算特性
3. Decoder 有什么冗余计算
4. KV Cache 实现

## 1. Decoder Toy model

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
torch.manual_seed(42)

<torch._C.Generator at 0x10f3bcd90>

In [2]:
mask = -(1 - torch.tril(torch.ones(4, 4))) * torch.tensor(float('inf'))
mask = torch.nan_to_num(mask, nan=0.0)
print(mask)
p = F.softmax(mask, dim = -1)
print(p)

tensor([[ 0.0000e+00, -3.4028e+38, -3.4028e+38, -3.4028e+38],
        [ 0.0000e+00,  0.0000e+00, -3.4028e+38, -3.4028e+38],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00, -3.4028e+38],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00]])
tensor([[1.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500]])


In [3]:
class Attention(nn.Module):
    def __init__(self, dim = 512):
        super().__init__()
        self.dim = dim
        self.wq = nn.Linear(dim, dim)
        self.wk = nn.Linear(dim, dim)
        self.wv = nn.Linear(dim, dim)
        self.wo = nn.Linear(dim, dim)
        
    def forward(self, x, mask, verbose = False):
        q, k, v = self.wq(x), self.wk(x), self.wv(x)
        s = q@k.transpose(2,1) / math.sqrt(self.dim)
        if verbose:
            print('(q,k,v,s).shape:', q.shape, k.shape, v.shape, s.shape)
        s = s+mask.unsqueeze(0)
        p = F.softmax(s, dim = -1)
        z = p @ v
        return self.wo(z)
        
    
class SimplesDecoder(nn.Module):
    def __init__(self, dim = 512, vocab_size = 100, max_len = 1024):
        super().__init__()
        self.embd = nn.Embedding(vocab_size, dim)
        self.attn = Attention(dim)
        self.lm_head = nn.Linear(dim, vocab_size)
        self.mask = -(1 - torch.tril(torch.ones(max_len, max_len))) * torch.tensor(float('inf'))
        self.mask = torch.nan_to_num(self.mask, nan=0.0)

    def forward(self, x, verbose = False):
        bs, seq_len = x.shape
        X = self.embd(x)
        X = self.attn(X, self.mask[:seq_len, :seq_len], verbose=verbose)
        logits = self.lm_head(X)
        return logits

dim = 512,
seq_len = 16
vocab_size = 100
batch_size = 2
input_ids = torch.randint(vocab_size, [batch_size, seq_len])
# model = SimplesDecoder(dim = dim, vocab_size = vocab_size)
model = SimplesDecoder()
logits = model(input_ids)
print(logits.shape)
# print(logits)

torch.Size([2, 16, 100])


In [4]:
input_ids = torch.randint(vocab_size, [1, 2])
print(input_ids)
def generation(
    model = None,
    input_ids: torch.tensor = None,
    max_new_token: int = 100,
):
    for i in range(max_new_token):
        print(input_ids)
        with torch.no_grad():
            logits = model.forward(input_ids, verbose = True)
        logits = logits[:, -1, :] # zhe
        probs = F.softmax(logits, dim = -1)
        next_token_idx = torch.argmax(probs, dim=-1, keepdim=True)
        input_ids = torch.cat( [input_ids, next_token_idx], dim = -1 )
    return input_ids

# logits = model(input_ids)
result = generation(model, input_ids, max_new_token = 3)
print(result)

tensor([[69, 47]])
tensor([[69, 47]])
(q,k,v,s).shape: torch.Size([1, 2, 512]) torch.Size([1, 2, 512]) torch.Size([1, 2, 512]) torch.Size([1, 2, 2])
tensor([[69, 47, 15]])
(q,k,v,s).shape: torch.Size([1, 3, 512]) torch.Size([1, 3, 512]) torch.Size([1, 3, 512]) torch.Size([1, 3, 3])
tensor([[69, 47, 15, 15]])
(q,k,v,s).shape: torch.Size([1, 4, 512]) torch.Size([1, 4, 512]) torch.Size([1, 4, 512]) torch.Size([1, 4, 4])
tensor([[69, 47, 15, 15, 15]])


1. 在推理过程中， q,k,v 第 seq_len 维度累增，存在重复计算
2. 在推理过程中， score 可以看出 attention 重复计算

出现上述原因在于，next token prediction 在推理过程中，`input_ids` 在累增

1. 在第 t=1 `forward` 时不存在重复计算
2. 在 t>1 `forward` 时以累增的 `input_ids` 预测

next token prediction 预测的本质是：第 $t=n$ 时的 $q_{\color{red}{n}} $  与 $k_{1:n}, v_{1:n}$ 做注意力计算

上述代码中实际上是 $q_{\color{red}{1:} n}$  与 $k_{1:n}, v_{1:n}$

## KV-Cache

In [5]:
class AttentionKVCache(nn.Module):
    def __init__(self, dim = 512):
        super().__init__()
        self.dim = dim
        self.wq = nn.Linear(dim, dim)
        self.wk = nn.Linear(dim, dim)
        self.wv = nn.Linear(dim, dim)
        self.wo = nn.Linear(dim, dim)
        
        self.KV_cache = None
        
    def forward(self, x, mask, verbose = False):
        q, k, v = self.wq(x), self.wk(x), self.wv(x)

        
        if verbose:
            print('(q,k,v).shape:', q.shape, k.shape, v.shape,)
            if self.KV_cache is None:
                print('KV Cache is empty')
            else:
                print('KV Cache.shape:', self.KV_cache[0].shape, self.KV_cache[1].shape)


        if self.KV_cache is None:
            self.KV_cache = [k, v]
        else:
            self.KV_cache[0] = torch.cat((self.KV_cache[0], k), dim = 1) # k
            self.KV_cache[1] = torch.cat((self.KV_cache[1], v), dim = 1) # v

        
        s = q @ self.KV_cache[0].transpose(2,1) / math.sqrt(self.dim)

        if self.KV_cache is None:
            mask = mask.unsqueeze(0)
        else:
            mask = mask[-1, :].unsqueeze(0).unsqueeze(1) # 注意 mask 取最后一行 
        s = s + mask 

        if verbose:
            print('(s,mask).shape:',  s.shape, mask.shape)

        p = F.softmax(s, dim = -1)
        z = p @ self.KV_cache[1]
        return self.wo(z)

In [6]:
class SimplesDecoderKVCache(nn.Module):
    def __init__(self, dim = 512, vocab_size = 100, max_len = 1024):
        super().__init__()
        self.embd = nn.Embedding(vocab_size, dim)
        self.attn = AttentionKVCache(dim)
        self.lm_head = nn.Linear(dim, vocab_size)
        self.mask = -(1 - torch.tril(torch.ones(max_len, max_len))) * torch.tensor(float('inf'))
        self.mask = torch.nan_to_num(self.mask, nan=0.0)

    def forward(self, x, cur_len, verbose = False):
        bs, seq_len = x.shape
        X = self.embd(x)
        X = self.attn(X, self.mask[:cur_len, :cur_len], verbose=verbose)
        logits = self.lm_head(X)
        return logits
        
model = SimplesDecoderKVCache()

In [7]:
# input_ids = torch.randint(vocab_size, [1, 2])
print(input_ids)
def generation(
    model = None,
    input_ids: torch.tensor = None,
    max_new_token: int = 100,
):
    input_len = input_ids.shape[1]
    output_ids = input_ids.clone()
    for i in range(max_new_token):
        
        print('-' * 10, 'loop:', i, '-' * 10)
        print('input_ids:', input_ids)
        with torch.no_grad():
            logits = model.forward(input_ids, cur_len = input_len+i,  verbose = True)
        logits = logits[:, -1, :] # zhe
        probs = F.softmax(logits, dim = -1)
        next_token_idx = torch.argmax(probs, dim=-1, keepdim=True)
  
        # input_ids = torch.cat( [input_ids, next_token_idx], dim = -1 )
        input_ids = next_token_idx # 输入 1 个 token
        
        output_ids = torch.cat([output_ids, next_token_idx], dim = -1 )
        
    return output_ids

# logits = model(input_ids)
result = generation(model, input_ids, max_new_token = 3)
print(result)

tensor([[69, 47]])
---------- loop: 0 ----------
input_ids: tensor([[69, 47]])
(q,k,v).shape: torch.Size([1, 2, 512]) torch.Size([1, 2, 512]) torch.Size([1, 2, 512])
KV Cache is empty
(s,mask).shape: torch.Size([1, 2, 2]) torch.Size([1, 1, 2])
---------- loop: 1 ----------
input_ids: tensor([[65]])
(q,k,v).shape: torch.Size([1, 1, 512]) torch.Size([1, 1, 512]) torch.Size([1, 1, 512])
KV Cache.shape: torch.Size([1, 2, 512]) torch.Size([1, 2, 512])
(s,mask).shape: torch.Size([1, 1, 3]) torch.Size([1, 1, 3])
---------- loop: 2 ----------
input_ids: tensor([[65]])
(q,k,v).shape: torch.Size([1, 1, 512]) torch.Size([1, 1, 512]) torch.Size([1, 1, 512])
KV Cache.shape: torch.Size([1, 3, 512]) torch.Size([1, 3, 512])
(s,mask).shape: torch.Size([1, 1, 4]) torch.Size([1, 1, 4])
tensor([[69, 47, 65, 65, 98]])


## 推理分析

1. 在 for loop 上，输入由 `input_ids += next_token_idx`  变为 `input_ids = next_token_idx`
2. 在 attention 内部，遵循 1 q 和 kv-cache 进行算注意力
3. 在 t=1 时，填充 kv-cache，此时称为 prefill 阶段， 在 t>1 时, 一个 q 与多个 kv 算注意力，此时称为 decoding 阶段
4. prefill 与 training forward 的模式时相同的， 多 q 和 多kv， decoding 较为特殊是 1q 和 多KV
5. prefill, forward 计算是 block-wise 的，并行度高， decoding 是 line-wise 的，需要采用 batch-decoding 提高并行度
6. prefill, forward 有 compute-bound， decoding 有 memory-bound，例如 多层 attention 中，单个数据，需要多次访问 wq,wk,wv,wo 做投影，或者频繁加载kv cache计算
7. 不同的计算特性需要设计对应的机制

## KV Cache 分析

1. KV Cache 存储量: `bsz x seq_len x dim x attention_num_layers x bits`， 根据此模型可以有多种 kv-cache 存储量减少方法
2. generation 时随 seq_len 线形递增存储量
3. 一种 memory-efficient 方法则可以提高 cache 的管理效率， 管理的是存储过程
4. KV Cache 是 masked-self-attention 类模型的必要技术, Transformer 解码模型实际上也可以用 KV-Cache 技术
5. 分析 foward、prefill、decoding 的计算复杂度，从计算角度分析计算复杂度是否等同于计算效率？