In [2]:
import torch
from config import Config
from model import MiniTransformer
from attention import NaiveAttention, PageStreamingAttention
from kv_cache import PagedKVCache

# Sanity check: a transformer built with PageStreamingAttention should produce
# the same output as one built with NaiveAttention when both share weights.

# Prefer the GPU, but fall back to CPU so the cell does not fail outright on
# machines without CUDA.
# NOTE(review): whether PageStreamingAttention / PagedKVCache support CPU is
# not visible from this cell -- confirm before relying on the fallback.
device = "cuda" if torch.cuda.is_available() else "cpu"
config = Config()

naive_model = MiniTransformer(config, NaiveAttention).to(device)
stream_model = MiniTransformer(config, PageStreamingAttention).to(device)

# Give both models identical parameters so that any output difference comes
# from the attention implementation rather than from random initialisation.
stream_model.load_state_dict(naive_model.state_dict())

naive_model.eval()
stream_model.eval()

# One random sequence of 128 token ids.
input_ids = torch.randint(0, config.vocab_size, (1, 128), device=device)

# One cache per layer.
# NOTE(review): the leading `1` presumably matches the batch dimension of
# `input_ids`, and the third argument is presumably the per-head dimension --
# confirm against the PagedKVCache signature.
kv_caches = [
    PagedKVCache(
        1,
        config.num_heads,
        config.hidden_size // config.num_heads,
        device=device,
    )
    for _ in range(config.num_layers)
]

# .eval() does not disable gradient tracking; without no_grad() both forward
# passes would record an autograd graph that this comparison never uses.
with torch.no_grad():
    # Naive
    naive_out = naive_model(input_ids)

    # Streaming
    stream_out = stream_model(input_ids, kv_caches)

    # Largest element-wise absolute deviation between the two outputs.
    diff = (naive_out - stream_out).abs().max()

print("Max difference:", diff.item())

Max difference: 1.6689300537109375e-06
