In [5]:
import torch
import torch.nn as nn
import math
import time

device = "cuda"
torch.set_grad_enabled(False)

torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

print("CUDA:", torch.cuda.is_available())
print("GPU:", torch.cuda.get_device_name(0))

CUDA: True
GPU: NVIDIA RTX 4000 Ada Generation


In [33]:
## lets do the model config

class Config:
    vocab_size = 32000
    hidden_size = 768
    num_heads = 12
    num_layers = 6
    max_seq_len = 32768

config = Config()

In [7]:
### Naive Attention with KV cache

class NaiveAttention(nn.Module):
    def __init__(self, hidden_size, num_heads):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        
        self.q_proj   =   nn.Linear(hidden_size, hidden_size)
        self.k_proj   =   nn.Linear(hidden_size, hidden_size)
        self.v_proj   =   nn.Linear(hidden_size, hidden_size)
        self.out_proj =   nn.Linear(hidden_size, hidden_size)

    def forward(self, x, kv_cache=None):
        B, T, C = x.shape

        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)

        q = q.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        v = v.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)

        ## this is where we update our KV cache
        if kv_cache is not None:
            k, v = kv_cache.update(k, v)

        attn = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        attn = torch.softmax(attn, dim=-1)

        out = torch.matmul(attn, v)
        out = out.transpose(1, 2).contiguous().view(B, T, C)

        return self.out_proj(out)

In [8]:
## KV Cache

class KVCache:
    def __init__(self, batch_size, num_heads, max_seq_len, head_dim):
        self.k = torch.zeros(batch_size, num_heads, max_seq_len, head_dim, device=device)
        self.v = torch.zeros(batch_size, num_heads, max_seq_len, head_dim, device=device)
        self.cur_pos = 0

    def update(self, new_k, new_v):
        B, H, T, D = new_k.shape

        self.k[:, :, self.cur_pos:self.cur_pos+T, :] = new_k
        self.v[:, :, self.cur_pos:self.cur_pos+T, :] = new_v
        self.cur_pos += T

        return (
            self.k[:, :, :self.cur_pos, :],
            self.v[:, :, :self.cur_pos, :]
        )

In [9]:
class TransformerBlock(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.ln1 = nn.LayerNorm(config.hidden_size)
        self.attn = NaiveAttention(config.hidden_size, config.num_heads)
        self.ln2 = nn.LayerNorm(config.hidden_size)
        self.mlp = nn.Sequential(
            nn.Linear(config.hidden_size, 4 * config.hidden_size),
            nn.GELU(),
            nn.Linear(4 * config.hidden_size, config.hidden_size)
        )

    def forward(self, x, kv_cache=None):
        x = x + self.attn(self.ln1(x), kv_cache)
        x = x + self.mlp(self.ln2(x))
        return x


class MiniTransformer(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.embed = nn.Embedding(config.vocab_size, config.hidden_size)
        self.layers = nn.ModuleList([
            TransformerBlock(config)
            for _ in range(config.num_layers)
        ])
        self.ln_f = nn.LayerNorm(config.hidden_size)
        self.head = nn.Linear(config.hidden_size, config.vocab_size)

    def forward(self, input_ids, kv_caches=None):
        x = self.embed(input_ids)

        for i, layer in enumerate(self.layers):
            cache = kv_caches[i] if kv_caches else None
            x = layer(x, cache)

        x = self.ln_f(x)
        return self.head(x)

In [10]:
## Instatiate model

model = MiniTransformer(config).to(device)
model.eval()

print("Model parameters:",
      sum(p.numel() for p in model.parameters()) / 1e6, "M")

Model parameters: 91.712768 M


In [11]:
def benchmark_prefill(seq_len=1024, batch_size=1):
    input_ids = torch.randint(0, config.vocab_size,
                              (batch_size, seq_len),
                              device=device)

    torch.cuda.synchronize()
    start = time.time()
    _ = model(input_ids)
    torch.cuda.synchronize()
    end = time.time()

    print(f"Prefill latency: {end-start:.4f}s")

benchmark_prefill(1024, 1)

Prefill latency: 0.1606s


In [16]:
## decode benchmark

def benchmark_decode(steps=128, batch_size=1):
    input_ids = torch.randint(0, config.vocab_size,
                              (batch_size, 1),
                              device=device)

    kv_caches = [
        KVCache(
            batch_size,
            config.num_heads,
            config.max_seq_len,
            config.hidden_size // config.num_heads
        )
        for _ in range(config.num_layers)
    ]

    torch.cuda.synchronize()
    start = time.time()

    for _ in range(steps):
        logits = model(input_ids, kv_caches)
        input_ids = torch.argmax(logits[:, -1, :],
                                 dim=-1,
                                 keepdim=True)

    torch.cuda.synchronize()
    end = time.time()

    print(f"Decode tokens/sec: {steps / (end - start):.2f}")

benchmark_decode(256, 1)

Decode tokens/sec: 323.65


In [19]:
## prefix scaling

for seq in [256, 512, 1024, 1536, 4048, 8026]:
    benchmark_prefill(seq, 1)

Prefill latency: 0.0064s
Prefill latency: 0.0060s
Prefill latency: 0.0104s
Prefill latency: 0.0210s
Prefill latency: 0.1097s
Prefill latency: 0.3973s


In [20]:
## decode scaling
for steps in [128, 256, 512, 1024]:
    benchmark_decode(steps, 1)


Decode tokens/sec: 349.10
Decode tokens/sec: 351.18
Decode tokens/sec: 350.37
Decode tokens/sec: 348.15


In [22]:
def print_gpu_memory():
    allocated = torch.cuda.memory_allocated() / 1e9
    reserved = torch.cuda.memory_reserved() / 1e9
    print(f"Allocated: {allocated:.2f} GB | Reserved: {reserved:.2f} GB")

torch.cuda.empty_cache()
print_gpu_memory()
benchmark_prefill(4096, 1)
print_gpu_memory()

Allocated: 0.38 GB | Reserved: 0.41 GB
Prefill latency: 0.1359s
Allocated: 0.38 GB | Reserved: 2.08 GB


In [23]:
benchmark_prefill(8192, 1)
print_gpu_memory()

Prefill latency: 0.4162s
Allocated: 0.38 GB | Reserved: 8.52 GB


In [24]:
torch.cuda.empty_cache()
print_gpu_memory()
benchmark_prefill(4096, 4)
print_gpu_memory()

Allocated: 0.38 GB | Reserved: 0.41 GB
Prefill latency: 0.4851s
Allocated: 0.38 GB | Reserved: 7.20 GB


In [25]:
## FlashAttention

class BlockedAttention(nn.Module):
    def __init__(self, hidden_size, num_heads, block_size=256):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.block_size = block_size
        
        self.q_proj = nn.Linear(hidden_size, hidden_size)
        self.k_proj = nn.Linear(hidden_size, hidden_size)
        self.v_proj = nn.Linear(hidden_size, hidden_size)
        self.out_proj = nn.Linear(hidden_size, hidden_size)

    def forward(self, x):
        B, T, C = x.shape

        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)

        q = q.view(B, T, self.num_heads, self.head_dim).transpose(1,2)
        k = k.view(B, T, self.num_heads, self.head_dim).transpose(1,2)
        v = v.view(B, T, self.num_heads, self.head_dim).transpose(1,2)

        scale = 1.0 / math.sqrt(self.head_dim)

        output = torch.zeros_like(q)

        for start in range(0, T, self.block_size):
            end = min(start + self.block_size, T)

            q_block = q[:, :, start:end, :]  # (B,H,block,D)

            # Online softmax components
            m_i = torch.full((B, self.num_heads, end-start), 
                             -float("inf"), device=q.device)
            l_i = torch.zeros((B, self.num_heads, end-start), 
                              device=q.device)
            acc = torch.zeros_like(q_block)

            for k_start in range(0, T, self.block_size):
                k_end = min(k_start + self.block_size, T)

                k_block = k[:, :, k_start:k_end, :]
                v_block = v[:, :, k_start:k_end, :]

                scores = torch.matmul(q_block, k_block.transpose(-2,-1)) * scale

                block_max = scores.max(dim=-1).values
                new_m = torch.maximum(m_i, block_max)

                exp_scores = torch.exp(scores - new_m.unsqueeze(-1))
                exp_m_diff = torch.exp(m_i - new_m)

                l_i = exp_m_diff * l_i + exp_scores.sum(dim=-1)
                acc = exp_m_diff.unsqueeze(-1) * acc + torch.matmul(exp_scores, v_block)

                m_i = new_m

            output[:, :, start:end, :] = acc / l_i.unsqueeze(-1)

        output = output.transpose(1,2).contiguous().view(B, T, C)

        return self.out_proj(output)

In [26]:
class BlockedTransformerBlock(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.ln1 = nn.LayerNorm(config.hidden_size)
        self.attn = BlockedAttention(config.hidden_size, config.num_heads, block_size=256)
        self.ln2 = nn.LayerNorm(config.hidden_size)
        self.mlp = nn.Sequential(
            nn.Linear(config.hidden_size, 4 * config.hidden_size),
            nn.GELU(),
            nn.Linear(4 * config.hidden_size, config.hidden_size)
        )

    def forward(self, x):
        x = x + self.attn(self.ln1(x))
        x = x + self.mlp(self.ln2(x))
        return x


class BlockedTransformer(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.embed = nn.Embedding(config.vocab_size, config.hidden_size)
        self.layers = nn.ModuleList([
            BlockedTransformerBlock(config)
            for _ in range(config.num_layers)
        ])
        self.ln_f = nn.LayerNorm(config.hidden_size)
        self.head = nn.Linear(config.hidden_size, config.vocab_size)

    def forward(self, input_ids):
        x = self.embed(input_ids)
        for layer in self.layers:
            x = layer(x)
        x = self.ln_f(x)
        return self.head(x)

In [27]:
blocked_model = BlockedTransformer(config).to(device)
blocked_model.eval()

BlockedTransformer(
  (embed): Embedding(32000, 768)
  (layers): ModuleList(
    (0-5): 6 x BlockedTransformerBlock(
      (ln1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (attn): BlockedAttention(
        (q_proj): Linear(in_features=768, out_features=768, bias=True)
        (k_proj): Linear(in_features=768, out_features=768, bias=True)
        (v_proj): Linear(in_features=768, out_features=768, bias=True)
        (out_proj): Linear(in_features=768, out_features=768, bias=True)
      )
      (ln2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (mlp): Sequential(
        (0): Linear(in_features=768, out_features=3072, bias=True)
        (1): GELU(approximate='none')
        (2): Linear(in_features=3072, out_features=768, bias=True)
      )
    )
  )
  (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  (head): Linear(in_features=768, out_features=32000, bias=True)
)

In [28]:
def benchmark_blocked(seq_len=4096, batch_size=1):
    input_ids = torch.randint(0, config.vocab_size,
                              (batch_size, seq_len),
                              device=device)

    torch.cuda.empty_cache()
    print_gpu_memory()

    torch.cuda.synchronize()
    start = time.time()
    _ = blocked_model(input_ids)
    torch.cuda.synchronize()
    end = time.time()

    print(f"Blocked prefill latency: {end-start:.4f}s")
    print_gpu_memory()

In [30]:
benchmark_blocked(4096, 1)
benchmark_blocked(32000, 1)

Allocated: 0.75 GB | Reserved: 3.83 GB
Blocked prefill latency: 0.3588s
Allocated: 1.27 GB | Reserved: 3.84 GB
Allocated: 0.75 GB | Reserved: 3.83 GB
Blocked prefill latency: 20.9284s
Allocated: 4.84 GB | Reserved: 7.93 GB


In [31]:
def decode_after_prefill(prefill_len, steps=256):
    input_ids = torch.randint(0, config.vocab_size,
                              (1, prefill_len),
                              device=device)

    kv_caches = [
        KVCache(
            1,
            config.num_heads,
            config.max_seq_len,
            config.hidden_size // config.num_heads
        )
        for _ in range(config.num_layers)
    ]

    # Prefill
    _ = model(input_ids, kv_caches)

    input_ids = torch.randint(0, config.vocab_size,
                              (1, 1),
                              device=device)

    torch.cuda.synchronize()
    start = time.time()

    for _ in range(steps):
        logits = model(input_ids, kv_caches)
        input_ids = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)

    torch.cuda.synchronize()
    end = time.time()

    print(f"Prefill {prefill_len} → decode tok/s:",
          steps / (end - start))

In [34]:
for ctx in [512, 1024, 2048, 4096]:
    decode_after_prefill(ctx, 256)

Prefill 512 → decode tok/s: 337.76002576904494
Prefill 1024 → decode tok/s: 350.66236060350633
Prefill 2048 → decode tok/s: 345.42498582902306
Prefill 4096 → decode tok/s: 351.57389214498545


In [35]:
def full_generation_time(total_tokens):
    input_ids = torch.randint(0, config.vocab_size,
                              (1, 1),
                              device=device)

    kv_caches = [
        KVCache(
            1,
            config.num_heads,
            config.max_seq_len,
            config.hidden_size // config.num_heads
        )
        for _ in range(config.num_layers)
    ]

    torch.cuda.synchronize()
    start = time.time()

    for _ in range(total_tokens):
        logits = model(input_ids, kv_caches)
        input_ids = torch.argmax(logits[:, -1, :],
                                 dim=-1,
                                 keepdim=True)

    torch.cuda.synchronize()
    end = time.time()

    print(f"Generate {total_tokens} tokens → {end-start:.4f}s")

In [36]:
for t in [256, 512, 1024]:
    full_generation_time(t)

Generate 256 tokens → 0.7628s
Generate 512 tokens → 1.4797s
Generate 1024 tokens → 3.0131s


In [37]:
## KV cache memory scaling

def kv_cache_memory(batch, seq_len):
    bytes_per_element = 2  # assuming FP16
    total = (
        batch *
        config.num_layers *
        config.num_heads *
        seq_len *
        (config.hidden_size // config.num_heads) *
        2 *  # K and V
        bytes_per_element
    )
    print(f"Approx KV cache size: {total / 1e9:.2f} GB")

In [38]:
kv_cache_memory(1, 4096)
kv_cache_memory(1, 8192)
kv_cache_memory(4, 8192)

Approx KV cache size: 0.08 GB
Approx KV cache size: 0.15 GB
Approx KV cache size: 0.60 GB


Allocated: 0.88 GB | Reserved: 3.83 GB
Prefill 4096 → decode tok/s: 0.0
Allocated: 0.88 GB | Reserved: 3.83 GB


Allocated: 0.88 GB | Reserved: 3.83 GB
Prefill 8192 → decode tok/s: 0.0
Allocated: 0.88 GB | Reserved: 10.27 GB
