# Generation Without KV-Cache

![NO_KVCACHE](./image/without_kv_cache.png)

In [1]:
import torch
import torch.nn.functional as F
from transformers import LlamaModel, LlamaConfig, LlamaForCausalLM
import time

# Seed the global RNG so the randomly initialised weights and the random prompt
# are the same on every "Restart Kernel -> Run All".
torch.manual_seed(0)

# A deliberately tiny LLaMA configuration for the demo.
config = LlamaConfig(
    vocab_size=100,
    hidden_size=256,
    intermediate_size=512,
    num_hidden_layers=2,
    num_attention_heads=4,
    num_key_value_heads=4,
)
# Build the model from the config (weights are randomly initialised, not pretrained).
model = LlamaForCausalLM(config)

# Create the prompt directly as random token ids instead of using a tokenizer.
# The upper bound is read from the config so the ids always stay inside the
# vocabulary, even if `vocab_size` above is changed.
X = torch.randint(0, config.vocab_size, (1, 10))
print(X.shape)

# Greedy decoding WITHOUT a KV cache: at every step the whole sequence generated
# so far is fed through the model again and only the last position's logits are used.
idx = {}
idx['input_ids'] = X

# Inference only: disable autograd so no graph is built for each forward pass.
with torch.no_grad():
    for i in range(4):
        print(f"\nGeneration第{i}个时的输入{idx['input_ids'].shape}：")
        print(f"Generation第{i}个时的输入{idx['input_ids']}：")
        output = model(**idx)
        logits = output['logits'][:, -1, :]  # logits of the last position: (batch, vocab)
        # keepdim=True yields (batch, 1), which can be concatenated directly and
        # also works for batch sizes larger than 1.
        idx_next = torch.argmax(logits, dim=-1, keepdim=True)
        time.sleep(1)  # presumably a pause so the step-by-step output is easy to follow

        # Append the new token; the input grows by one token per step.
        idx['input_ids'] = torch.cat((idx['input_ids'], idx_next), dim=-1)

In [8]:
# NOTE(review): `causal_mask` and `_update_causal_mask` are private members of the
# transformers LLaMA implementation and are not present in every library version —
# confirm against the installed version before relying on this cell.
# The original call passed `torch.randn((,2048,2048))`, which is a SyntaxError.
# A leading batch dimension of 1 is assumed here — TODO confirm the intended shape.
model.model.causal_mask
model.model._update_causal_mask(model.model.causal_mask, torch.randn((1, 2048, 2048)))

# Generation With KV-Cache

![KVCACHE](./image/with_kv_cache.png)

In [8]:
# author: xiaodongguaAIGC
# KV-Cache + Generation + decoder

import torch
import torch.nn.functional as F

D = 128  # dimension of the single attention head; also used as the embedding size
V = 64  # vocabulary size (number of rows in the embedding, width of the LM head)


# Decoder-Only
class xiaodonggua_kv_cache(torch.nn.Module):
    """Minimal single-head, decoder-only toy model that demonstrates a KV cache.

    The model embeds token ids, projects them to queries/keys/values, and keeps
    every key and value it has ever computed in ``cache_K`` / ``cache_V``. On
    later calls only the new tokens need to be passed in; their keys/values are
    appended to the cache along the sequence dimension (dim=1).

    Output projection, MLP, score scaling, causal masking and multi-head
    splitting are intentionally left out to keep the example short.

    Args:
        D: embedding / attention dimension (single head).
        V: vocabulary size.
    """

    def __init__(self, D, V):
        super().__init__()
        self.D = D
        self.V = V
        self.Embedding = torch.nn.Embedding(V, D)
        self.Wq = torch.nn.Linear(D, D)
        self.Wk = torch.nn.Linear(D, D)
        self.Wv = torch.nn.Linear(D, D)
        self.lm_head = torch.nn.Linear(D, V)  # LM_head
        self.cache_K = self.cache_V = None  # empty until the first forward call

    def reset_cache(self):
        """Drop the cached keys/values so a new, unrelated sequence can be decoded."""
        self.cache_K = self.cache_V = None

    def forward(self, X):
        """Run one prefill or decode step and update the KV cache.

        Args:
            X: integer token ids; indexed as (batch, seq) by the code below.
                Pass the full prompt on the first call and only the newly
                generated token(s) afterwards.

        Returns:
            Logits over the vocabulary for every position in ``X``.
        """
        X = self.Embedding(X)
        Q, K, V = self.Wq(X), self.Wk(X), self.Wv(X)
        print(f"input_Q:{Q.shape}")
        print(f"input_K:{K.shape}")
        print(f"input_V:{V.shape}")

        # Easy KV_Cache: the first call seeds the cache, later calls append the
        # new keys/values along the sequence dimension.
        # `is None` is required here: the attribute holds a tensor after the
        # first call, and `==` on a tensor is not an identity check.
        if self.cache_K is None:
            self.cache_K = K
            self.cache_V = V
        else:
            self.cache_K = torch.cat((self.cache_K, K), dim=1)
            self.cache_V = torch.cat((self.cache_V, V), dim=1)
        # Attend over everything seen so far, not just the current tokens.
        K, V = self.cache_K, self.cache_V

        print(f"cache_K:{self.cache_K.shape}")
        print(f"cache_V:{self.cache_V.shape}")

        # ignore proj/MLP/scaled/mask/multi-head
        attn = Q @ K.transpose(1, 2) @ V

        # Project the attention output to vocabulary logits.
        return self.lm_head(attn)

In [17]:
model = xiaodonggua_kv_cache(D, V)  # build the toy decoder model

# Create the prompt directly as random token ids instead of using a tokenizer.
# The upper bound is the vocabulary size V so every id is a valid embedding index.
X = torch.randint(0, V, (1, 10))
print(X.shape)

# Greedy decoding WITH the KV cache: the first step feeds the whole prompt,
# every later step feeds only the token that was just generated.
# Inference only: disable autograd so the cached keys/values do not keep a
# growing computation graph alive across steps.
with torch.no_grad():
    for i in range(3):
        print(f"\nGeneration {i} step input_shape: {X.shape}：")
        # Call the module rather than .forward() so nn.Module hooks are honoured.
        output = model(X)
        # argmax over logits equals argmax over softmax(logits), so the softmax is skipped.
        next_token = torch.argmax(output[:, -1, :], dim=-1)  # shape: (batch,)
        print(f'next_token预测:{next_token}')
        # (batch,) -> (batch, 1): a single new token per sequence for the next step.
        X = next_token.unsqueeze(-1)