<a href="https://colab.research.google.com/github/shunzh/llm.ipynb/blob/main/llm_with_kv_cache.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# 📘 Implementing a Decoder-Only Transformer from Scratch (with KV Cache)

We implement a decoder-only Transformer with a KV cache, which stores each layer's keys and values so that previously processed tokens are not recomputed at every decoding step.

We recommend going through the [original implementation without KV cache](https://github.com/shunzh/llm.ipynb/blob/main/llm.ipynb) first.

We first import the necessary libraries and set the random seed.

In [1]:
import math
import torch
import torch.nn as nn


# Set random seed
seed = 42
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)

## Architecture

We first define a model config object with default parameter values.


In [2]:
from dataclasses import dataclass

@dataclass
class Config:
    # The size of hidden state in transformer, also called d_model
    hidden_size: int = 512
    # The size of hidden state in MLP in the decoder block
    ff_hidden_size: int = 4 * 512
    # The number of decoder blocks
    num_hidden_layers: int = 2
    # Dropout rates for all modules that need dropout
    dropout_rate: float = 0.1
    vocab_size: int = 10000
    max_seq_len: int = 128

config = Config()
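As a small illustration of how the defaults can be overridden, the sketch below rebuilds the same dataclass and derives a hypothetical smaller variant with `dataclasses.replace`. The `small` variant and its values are assumptions made up for this example. Note that `ff_hidden_size` is a fixed default, so it does not follow `hidden_size` automatically.

```python
from dataclasses import dataclass, replace


@dataclass
class Config:
    hidden_size: int = 512
    ff_hidden_size: int = 4 * 512
    num_hidden_layers: int = 2
    dropout_rate: float = 0.1
    vocab_size: int = 10000
    max_seq_len: int = 128


base = Config()

# Hypothetical smaller variant: ff_hidden_size has to be overridden explicitly,
# otherwise it stays at the default of 2048.
small = replace(base, hidden_size=256, ff_hidden_size=4 * 256)
print(small)
```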

### Decoder Block

The decoder block is the core of the transformer. It has four modules inside (where the LayerNorm appears twice):

```
Input: x (batch_size, seq_len, hidden_size)
        │
        ▼
+-------------------+
|   LayerNorm       |
+-------------------+
        │
        ▼
+-------------------+
|  Self-Attention   |  (single head, with causal mask)
+-------------------+
        │
        ▼
+-------------------+
|   LayerNorm       |
+-------------------+
        │
        ▼
+-------------------+
|      MLP          |  (Linear → GELU → Linear)
+-------------------+
        │
        ▼
     Output x (batch_size, seq_len, hidden_size)
```

Let's define these modules.

In [3]:
class LayerNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-5):
        super().__init__()

        self.gamma = nn.Parameter(torch.ones(hidden_size)) # (hidden_size,)
        self.beta = nn.Parameter(torch.zeros(hidden_size)) # (hidden_size,)
        self.eps = eps

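    def forward(self, x):
        # x: (batch_size, seq_len, hidden_size)
        mean = x.mean(dim=-1, keepdim=True) # (batch_size, seq_len, 1)
        # Note: x.std uses the unbiased estimator (divides by hidden_size - 1) and eps is added to the std,
        # whereas nn.LayerNorm uses the biased variance and divides by sqrt(var + eps)
        std = x.std(dim=-1, keepdim=True) # (batch_size, seq_len, 1)
        return (x - mean) / (std + self.eps) * self.gamma + self.beta # (batch_size, seq_len, hidden_size)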


# Test layer norm
layer_norm = LayerNorm(config.hidden_size)
x = torch.randn(1, 3, config.hidden_size)
layer_norm_output = layer_norm(x)
print("The size of layer norm output is (batch_size, seq_len, hidden_size):", layer_norm_output.shape)

The size of layer norm output is (batch_size, seq_len, hidden_size): torch.Size([1, 3, 512])
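The `LayerNorm` above divides by `x.std`, which defaults to the unbiased estimator (dividing by `n - 1`), while the textbook layer norm uses the biased one (dividing by `n`). The plain-Python sketch below compares the two on a toy vector; the vector and variable names are made up for illustration. For `hidden_size = 512` the two standard deviations differ by a factor of `sqrt(512 / 511)`, roughly 0.1%, so the effect is small but not zero.

```python
import math
import statistics

x = [1.0, 2.0, 3.0, 4.0]  # a toy "hidden vector" with n = 4
mean = statistics.fmean(x)

unbiased_std = statistics.stdev(x)   # divides by n - 1, like torch.Tensor.std by default
biased_std = statistics.pstdev(x)    # divides by n, as in the textbook layer norm

norm_unbiased = [(v - mean) / unbiased_std for v in x]
norm_biased = [(v - mean) / biased_std for v in x]

print("unbiased std:", round(unbiased_std, 4))
print("biased std:  ", round(biased_std, 4))
print("ratio for n = 512:", round(math.sqrt(512 / 511), 5))
```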


In [4]:
class SingleHeadSelfAttention(nn.Module):
    def __init__(self, hidden_size, dropout_rate=0.1):
        super().__init__()

        self.hidden_size = hidden_size
        self.dropout_rate = dropout_rate

        # Attention: hidden state to query, key, value
        self.c_attn = nn.Linear(hidden_size, 3 * hidden_size)
        # Output projection
        self.c_proj = nn.Linear(hidden_size, hidden_size)

        # Dropouts
        self.attn_dropout = nn.Dropout(dropout_rate)
        self.proj_dropout = nn.Dropout(dropout_rate)

    def forward(self, x, past_kv=None, use_cache=False):
        # x: (batch_size, seq_len, hidden_size)
        # If use_cache == True, x only contains the new tokens because the keys and values of previous tokens are cached,
        #   past_kv is a tuple of (past_k, past_v), each of shape (batch_size, past_len, hidden_size),
        #   where past_len is the number of cached tokens
        # Let total_len = past_len + seq_len if past_kv is given, else seq_len
        batch_size, seq_len, hidden_size_in_data = x.shape
        assert self.hidden_size == hidden_size_in_data, f"Mismatch between hidden_size in config {self.hidden_size} and hidden_size in data {hidden_size_in_data}"

        c_attn_output = self.c_attn(x) # (batch_size, seq_len, 3 * hidden_size)

        # Split into query, key, and value
        q, k, v = c_attn_output.split(self.hidden_size, dim=-1)

        if past_kv is not None:
            past_k, past_v = past_kv
            k = torch.cat([past_k, k], dim=1) # (batch_size, total_len, hidden_size)
            v = torch.cat([past_v, v], dim=1) # (batch_size, total_len, hidden_size)

        # Compute attention scores
        # q: (.., seq_len, hidden_size)
        # k.transpose(-2, -1): (.., hidden_size, total_len)
        attn = (q @ k.transpose(-2, -1)) / (math.sqrt(self.hidden_size)) # (batch_size, seq_len, total_len)

        # Apply causal mask
        total_len = attn.shape[-1]
        # Example: with seq_len = 2 new tokens appended to the cached tokens, the mask looks like
        # [[1, 1, .., 1, 0],
        #  [1, 1, .., 1, 1]]
        # (during cached decoding, seq_len is typically 1)
        mask = torch.tril(torch.ones(1, total_len, total_len, device=attn.device))[:, -seq_len:, :] # (1, seq_len, total_len)
        attn = attn.masked_fill(mask == 0, float('-inf'))

        # Softmax
        attn = torch.softmax(attn, dim=-1)

        # Dropout attention
        attn = self.attn_dropout(attn)

        # Attention output
        # attn: (.., seq_len, total_len)
        # v: (.., total_len, hidden_size)
        attn_output = attn @ v # (batch_size, seq_len, hidden_size)

        # Final projection
        proj_output = self.c_proj(attn_output)
        proj_output = self.proj_dropout(proj_output) # (batch_size, seq_len, hidden_size)

        if use_cache:
            new_kv = (k, v) # Each (batch_size, total_len, hidden_size)
            return proj_output, new_kv
        else:
            return proj_output


# Test with and without kv cache
# Assume config.hidden_size is already set
x = torch.randn(1, 3, config.hidden_size)
self_attention = SingleHeadSelfAttention(config.hidden_size)
self_attention.eval()  # disable dropout

# Forward without cache
out_no_cache = self_attention(x, use_cache=False)

# Forward with cache (step by step)
past_kv = None
outputs = []
for i in range(x.size(1)):
    token = x[:, i:i+1]
    out, past_kv = self_attention(token, past_kv=past_kv, use_cache=True)
    outputs.append(out)
out_with_cache = torch.cat(outputs, dim=1)

# Print results
print("Output without cache:")
print(out_no_cache)

print("\nOutput with cache:")
print(out_with_cache)

assert torch.allclose(out_no_cache, out_with_cache, atol=1e-5), "Outputs do not match with and without KV cache"
print("✅ Outputs match with and without KV cache")

Output without cache:
tensor([[[-0.6716, -0.3282, -0.6029,  ...,  0.0846, -0.9988, -0.0512],
         [-0.6365, -0.0690, -0.4434,  ...,  0.1903, -0.4347,  0.0625],
         [-0.2435,  0.2562, -0.2738,  ...,  0.1553, -0.2531, -0.0460]]],
       grad_fn=<ViewBackward0>)

Output with cache:
tensor([[[-0.6716, -0.3282, -0.6029,  ...,  0.0846, -0.9988, -0.0512],
         [-0.6365, -0.0690, -0.4434,  ...,  0.1903, -0.4347,  0.0625],
         [-0.2435,  0.2562, -0.2738,  ...,  0.1553, -0.2531, -0.0460]]],
       grad_fn=<CatBackward0>)
✅ Outputs match with and without KV cache
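The mask slicing in `forward` is the part that makes cached and uncached attention agree: the full lower-triangular mask is built for `total_len` positions, and only the last `seq_len` rows (those of the new tokens) are kept. Below is a plain-Python sketch of the same idea using nested lists instead of tensors; the helper name `causal_mask_rows` is made up for this illustration.

```python
def causal_mask_rows(total_len, seq_len):
    """Return the last seq_len rows of a total_len x total_len lower-triangular mask."""
    full = [
        [1 if col <= row else 0 for col in range(total_len)]
        for row in range(total_len)
    ]
    return full[-seq_len:]


# Two new tokens on top of two cached tokens (total_len = 4)
for row in causal_mask_rows(total_len=4, seq_len=2):
    print(row)
# [1, 1, 1, 0]
# [1, 1, 1, 1]

# A single new token attends to every cached position and to itself
print(causal_mask_rows(total_len=4, seq_len=1))
# [[1, 1, 1, 1]]
```

With one new token per step, the mask row is all ones, so the mask only matters when several tokens are fed at once (for example, a multi-token prompt on the first step).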


In [5]:
class MLP(nn.Module):
    def __init__(self, hidden_size, ff_hidden_size, dropout_rate=0.1):
        super().__init__()

        self.hidden_size = hidden_size
        self.ff_hidden_size = ff_hidden_size

        self.c_fc = nn.Linear(self.hidden_size, self.ff_hidden_size)
        self.act = nn.GELU() # Or other activation functions
        self.c_proj = nn.Linear(self.ff_hidden_size, self.hidden_size)
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x):
        # x: (batch_size, seq_len, hidden_size)
        x = self.c_fc(x) # (batch_size, seq_len, ff_hidden_size)
        x = self.act(x) # (batch_size, seq_len, ff_hidden_size)
        x = self.c_proj(x) # (batch_size, seq_len, hidden_size)
        x = self.dropout(x)
        return x


# Test MLP
x = torch.randn(1, 3, config.hidden_size)
mlp = MLP(config.hidden_size, config.ff_hidden_size)
output = mlp(x)
print("Output shape:", output.shape)

Output shape: torch.Size([1, 3, 512])


With all the pieces defined above, we're ready to define the decoder block.

In [6]:
class DecoderBlock(nn.Module):
    def __init__(self, hidden_size, ff_hidden_size, dropout_rate=0.1):
        super().__init__()
        self.ln_1 = LayerNorm(hidden_size)
        self.attn = SingleHeadSelfAttention(hidden_size, dropout_rate)
        self.ln_2 = LayerNorm(hidden_size)
        self.mlp = MLP(hidden_size, ff_hidden_size, dropout_rate)

    def forward(self, x, past_kv=None, use_cache=False):
        # x: (batch_size, seq_len, hidden_size)
        # Layer norm 1
        x = self.ln_1(x)
        # Self attention + residual
        # Note: the residual is added to the normalized x here; the standard pre-LN
        # formulation would instead compute x + attn(ln_1(x)) with the un-normalized x
        if use_cache:
            attn_output, new_kv = self.attn(x, past_kv=past_kv, use_cache=True)
        else:
            attn_output = self.attn(x)
        x = x + attn_output
        # Layer norm 2
        x = self.ln_2(x)
        # MLP + residual (again on the normalized x)
        x = x + self.mlp(x)

        if use_cache:
            return x, new_kv
        else:
            return x


# Test Decoder
x = torch.randn(1, 3, config.hidden_size)
decoder = DecoderBlock(config.hidden_size, config.ff_hidden_size)
decoder.eval()  # disable dropout

# Forward without cache
out_full = decoder(x, use_cache=False)

# Forward with cache (incremental)
past_kv = None
outputs = []
for i in range(x.size(1)):
    token = x[:, i:i+1]
    out_step, past_kv = decoder(token, past_kv=past_kv, use_cache=True)
    outputs.append(out_step)
out_cached = torch.cat(outputs, dim=1)

# Compare
print("Output without cache:\n", out_full)
print("\nOutput with cache:\n", out_cached)

assert torch.allclose(out_full, out_cached, atol=1e-5), "Outputs do not match with and without KV cache"
print("✅ Outputs match with and without KV cache")

Output without cache:
 tensor([[[ 0.1712,  0.2079, -0.2557,  ...,  0.4331, -0.7992,  0.5643],
         [-1.7891, -1.1413,  1.2852,  ..., -0.1187,  1.4198,  0.8264],
         [-0.7370, -0.2091,  1.4891,  ...,  0.2402,  1.4331, -0.9390]]],
       grad_fn=<AddBackward0>)

Output with cache:
 tensor([[[ 0.1712,  0.2079, -0.2557,  ...,  0.4331, -0.7992,  0.5643],
         [-1.7891, -1.1413,  1.2852,  ..., -0.1187,  1.4198,  0.8264],
         [-0.7370, -0.2091,  1.4891,  ...,  0.2402,  1.4331, -0.9390]]],
       grad_fn=<CatBackward0>)
✅ Outputs match with and without KV cache
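Since every decoder block returns its own `(k, v)` pair, the cache grows with the number of layers and the number of cached tokens. The sketch below estimates its size, assuming float32 values (4 bytes each, PyTorch's default dtype) and the default config values; the helper `kv_cache_bytes` is a made-up name for this illustration.

```python
def kv_cache_bytes(num_layers, batch_size, total_len, hidden_size, bytes_per_value=4):
    # Each layer caches one key tensor and one value tensor,
    # both of shape (batch_size, total_len, hidden_size).
    return 2 * num_layers * batch_size * total_len * hidden_size * bytes_per_value


# Default config: 2 layers, hidden_size = 512, at most 128 tokens, batch size 1
size = kv_cache_bytes(num_layers=2, batch_size=1, total_len=128, hidden_size=512)
print(f"{size} bytes = {size / 2**20:.1f} MiB")
# 1048576 bytes = 1.0 MiB
```

The cache is tiny for this model, but the same formula grows linearly in each of its factors, so it becomes a real memory cost for larger models, longer contexts, and bigger batches.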


### The Complete Transformer Model

With the Decoder block defined above, we're ready to define the complete Transformer architecture.

```
Input: input_ids (batch_size, seq_len)
        │
        ▼
+------------------------+
|  Token Embedding       |
|  Position Embedding    |
+------------------------+
        │
        ▼
+------------------------+
|   DecoderBlock × N     |  (defined in the previous section)
+------------------------+
        │
        ▼
+------------------------+
|   Final LayerNorm      |
+------------------------+
        │
        ▼
+------------------------+
|  Linear (Language Head)|
+------------------------+
        │
        ▼
Output: logits (batch_size, seq_len, vocab_size)
```


In [7]:
class Transformer(nn.Module):
    def __init__(
        self,
        hidden_size,
        ff_hidden_size,
        vocab_size,
        max_seq_len,
        num_hidden_layers,
        dropout_rate=0.1,
    ):
        super().__init__()

        self.token_embed = nn.Embedding(vocab_size, hidden_size)
        self.position_embed = nn.Embedding(max_seq_len, hidden_size)
        self.embed_dropout = nn.Dropout(dropout_rate)

        self.hidden_layers = nn.ModuleList([DecoderBlock(hidden_size, ff_hidden_size, dropout_rate) for _ in range(num_hidden_layers)])
        # The final layer norm
        self.ln_f = nn.LayerNorm(hidden_size)

        # The final language head, which maps the last hidden state to logits
        self.language_head = nn.Linear(hidden_size, vocab_size)

    def forward(
        self,
        input_ids,
        past_kvs=None,
        use_cache=False,
    ):
        # input_ids: (batch_size, seq_len)
        batch_size, seq_len = input_ids.shape

        if past_kvs:
            past_len = past_kvs[0][0].shape[1]
        else:
            past_len = 0

        # Create position ids (past_len, past_len + 1, ..)
        position_ids = torch.arange(past_len, past_len + seq_len, device=input_ids.device).unsqueeze(0)  # (1, seq_len)

        # Embed tokens and positions, apply dropout
        x = self.token_embed(input_ids) + self.position_embed(position_ids)
        x = self.embed_dropout(x)

        # Transformer blocks
        new_kvs = [] if use_cache else None # new_kvs[i] will be the new_kv for the i-th layer
        for i, layer in enumerate(self.hidden_layers):
            past_kv = past_kvs[i] if past_kvs is not None else None
            if use_cache:
                x, new_kv = layer(x, past_kv=past_kv, use_cache=True)
                new_kvs.append(new_kv)
            else:
                x = layer(x)

        # Final layer norm
        x = self.ln_f(x)

        # Project to vocabulary
        logits = self.language_head(x) # (batch_size, seq_len, vocab_size)

        if use_cache:
            return logits, new_kvs
        else:
            return logits


# Test Transformer
input_ids = torch.randint(0, config.vocab_size, (1, 3))

model = Transformer(
    hidden_size=config.hidden_size,
    ff_hidden_size=config.ff_hidden_size,
    vocab_size=config.vocab_size,
    max_seq_len=config.max_seq_len,
    num_hidden_layers=config.num_hidden_layers,
    dropout_rate=config.dropout_rate,
)
model.eval()
print("Our Transformer model:\n", model)

# Forward without cache
logits_full = model(input_ids, use_cache=False)

# Forward with cache (token-by-token)
past_kvs = None
logits_cached = []
for i in range(input_ids.size(1)):
    logits_step, past_kvs = model(input_ids[:, i:i+1], past_kvs=past_kvs, use_cache=True)
    logits_cached.append(logits_step)
logits_cached = torch.cat(logits_cached, dim=1)

# Compare
print("Full logits:\n", logits_full)
print("Cached logits:\n", logits_cached)

assert torch.allclose(logits_full, logits_cached, atol=1e-5), "Outputs do not match with and without KV cache"
print("✅ Outputs match with and without KV cache")

Our Transformer model:
 Transformer(
  (token_embed): Embedding(10000, 512)
  (position_embed): Embedding(128, 512)
  (embed_dropout): Dropout(p=0.1, inplace=False)
  (hidden_layers): ModuleList(
    (0-1): 2 x DecoderBlock(
      (ln_1): LayerNorm()
      (attn): SingleHeadSelfAttention(
        (c_attn): Linear(in_features=512, out_features=1536, bias=True)
        (c_proj): Linear(in_features=512, out_features=512, bias=True)
        (attn_dropout): Dropout(p=0.1, inplace=False)
        (proj_dropout): Dropout(p=0.1, inplace=False)
      )
      (ln_2): LayerNorm()
      (mlp): MLP(
        (c_fc): Linear(in_features=512, out_features=2048, bias=True)
        (act): GELU(approximate='none')
        (c_proj): Linear(in_features=2048, out_features=512, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
    )
  )
  (ln_f): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  (language_head): Linear(in_features=512, out_features=10000, bias=True)
)
Full logits:
 tens
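As a sanity check on the printed architecture, the parameter count can be worked out by hand from the layer shapes shown above. The sketch below does this in plain Python for the default config (`vocab_size = 10000`); the helper `linear_params` is a made-up name, and the total changes if the config is modified.

```python
def linear_params(in_features, out_features):
    # Weight matrix plus bias vector of an nn.Linear layer
    return in_features * out_features + out_features


hidden, ff, vocab, max_len, layers = 512, 2048, 10000, 128, 2

embed = vocab * hidden + max_len * hidden    # token_embed + position_embed
block = (
    2 * hidden                               # ln_1 (gamma and beta)
    + linear_params(hidden, 3 * hidden)      # attn.c_attn
    + linear_params(hidden, hidden)          # attn.c_proj
    + 2 * hidden                             # ln_2 (gamma and beta)
    + linear_params(hidden, ff)              # mlp.c_fc
    + linear_params(ff, hidden)              # mlp.c_proj
)
final = 2 * hidden + linear_params(hidden, vocab)  # ln_f + language_head

total = embed + layers * block + final
print(f"{total:,} parameters")
# 16,621,328 parameters
```

Most of the parameters sit in the token embedding and the language head, which is typical for a small model with a comparatively large vocabulary.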

## Model Inference

In [None]:
from transformers import AutoTokenizer

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

# Update config
config.vocab_size = tokenizer.vocab_size

# Tokenize entire corpus
def tokenize(example):
    return tokenizer(example["text"])

Let's set the path of the trained checkpoint so it can be loaded later for inference.

In [9]:
from google.colab import drive
drive.mount('/content/drive')

# Set model path here
model_checkpoint_path = "/content/drive/MyDrive/Colab Notebooks/model.pth"

# Or save the checkpoint to this runtime without connecting to Google Drive.
# The checkpoint will be deleted after the runtime terminates.
# model_checkpoint_path = "/content/model.pth"

Mounted at /content/drive


### Model Inference



Let's first load the trained model from the checkpoint path set in the previous cell.

In [10]:
# Select the device: use the GPU if available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Using device", device)

# Initialize model
model = Transformer(
    hidden_size=config.hidden_size,
    ff_hidden_size=config.ff_hidden_size,
    vocab_size=config.vocab_size,
    max_seq_len=config.max_seq_len,
    num_hidden_layers=config.num_hidden_layers,
    dropout_rate=config.dropout_rate,
).to(device)

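# Load trained model (only a missing file is tolerated; other errors, such as a shape mismatch, are raised)
try:
    model.load_state_dict(torch.load(model_checkpoint_path, map_location=device))
except FileNotFoundError:
    print(f"Checkpoint {model_checkpoint_path} not found. Skip loading.")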

Using device cuda




#### Greedy Decoding

Since we're using a very small model and a very limited training set, the generated outputs may not be semantically meaningful.
Still, let's run greedy decoding with the trained model, both with and without the KV cache, and compare the outputs and the running time!

In [11]:
import time


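@torch.no_grad()  # gradients are not needed for inference
def greedy_decode(model, input_ids, max_len, device=torch.device("cpu")):
    # Note: `device` is unused; new tokens are created on the same device as input_ids
    model.eval()

    for _ in range(max_len - input_ids.size(1)):
        # Re-run the model on the entire sequence at every step
        logits = model(input_ids) # (batch_size, seq_len, vocab_size)
        next_tokens = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True) # (batch_size, 1)
        input_ids = torch.cat((input_ids, next_tokens), dim=-1)

    return input_ids


@torch.no_grad()  # gradients are not needed for inference
def greedy_decode_with_kv_cache(model, input_ids, max_len, device=torch.device("cpu")):
    # Note: `device` is unused; new tokens are created on the same device as input_ids
    model.eval()

    past_kvs = None
    generated = input_ids

    for _ in range(max_len - input_ids.size(1)):
        # First step: input_ids is the full prompt. Later steps: only the newest token.
        logits, past_kvs = model(input_ids, past_kvs=past_kvs, use_cache=True)
        next_token = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True) # (batch_size, 1)
        generated = torch.cat([generated, next_token], dim=-1)
        input_ids = next_token  # only feed the new token next step

    return generated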


# Use an empty prompt
input_ids = torch.tensor([[tokenizer.bos_token_id]], dtype=torch.long).to(device)

# Run greedy decoding
start_time = time.time()
output_ids = greedy_decode(model, input_ids, max_len=128, device=device)
end_time = time.time()
generated_text = tokenizer.decode(output_ids[0].tolist(), skip_special_tokens=True)

print("==== Output with NO KV Cache ====")
print(generated_text)
print("==== End of Output ====")
print(f"Time taken: {end_time - start_time} seconds")
print()

# Run greedy decoding with kv cache
start_time = time.time()
output_ids = greedy_decode_with_kv_cache(model, input_ids, max_len=128, device=device)
end_time = time.time()
generated_text = tokenizer.decode(output_ids[0].tolist(), skip_special_tokens=True)

print("==== Output with KV Cache ====")
print(generated_text)
print("==== End of Output ====")
print(f"Time taken: {end_time - start_time} seconds")

==== Output with NO KV Cache ====
,
And, as I am a king,
And, as I do not, and I am sure,
And, as I am a king,
And yet I am not to be a king.

DUKE VINCENTIO:
I am a gentleman, and I am a king,
And so I am a king, and a
motion generative to the crown.

DUKE VINCENTIO:
I know you, sir, sir, sir, sir, sir, sir, sir, sir, sir, sir, sir, sir, sir, sir, sir
==== End of Output ====
Time taken: 1.1075119972229004 seconds

==== Output with KV Cache ====
,
And, as I am a king,
And, as I do not, and I am sure,
And, as I am a king,
And yet I am not to be a king.

DUKE VINCENTIO:
I am a gentleman, and I am a king,
And so I am a king, and a
motion generative to the crown.

DUKE VINCENTIO:
I know you, sir, sir, sir, sir, sir, sir, sir, sir, sir, sir, sir, sir, sir, sir, sir
==== End of Output ====
Time taken: 0.3300652503967285 seconds
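Both runs produce the same text, and the cached run is roughly 3x faster here. A simple count of how many token positions each loop feeds through the model helps explain the gap. The sketch below mirrors the two decoding loops in plain Python; `positions_processed` is a made-up helper, and the count ignores that each cached step still attends over all cached keys and values, so it is only a rough proxy for the actual compute.

```python
def positions_processed(prompt_len, max_len, use_cache):
    """Count token positions fed through the model over all decoding steps."""
    total = 0
    cur_len = prompt_len
    first_step = True
    while cur_len < max_len:
        if use_cache and not first_step:
            total += 1        # only the newly generated token is fed
        else:
            total += cur_len  # the whole sequence so far is fed
        cur_len += 1
        first_step = False
    return total


# Same setting as above: a one-token prompt decoded up to 128 tokens
print(positions_processed(prompt_len=1, max_len=128, use_cache=False))  # 8128
print(positions_processed(prompt_len=1, max_len=128, use_cache=True))   # 127
```

The measured speedup is far smaller than the ratio of these counts, which is expected for a model this small: per-step overhead tends to dominate, and the gap should widen with longer sequences and larger models.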
