In [15]:
import numpy as np
import pytest
import torch
import math
from torch.nn import functional as F
from typing import List, Tuple

import flashinfer

In [16]:
# Notebook-wide configuration; later cells read these names as module globals.
num_head = 2                      # used for both query and KV head counts in the FlashInfer calls below
head_dim = 128                    # size of the last tensor dimension (per-head feature size)
page_len = 16                     # NOTE(review): not referenced in the visible cells; `page_size` (defined later) is what the cache code uses
device = torch.device("cuda:0")   # every tensor in this notebook is created on this device
dtype = torch.bfloat16            # dtype of q/k/v and of the paged KV cache

In [17]:
def assert_close(a, b):
    """Assert that ``a`` and ``b`` match within a tolerance chosen from ``a.dtype``.

    Supported dtypes are float16, float32 and bfloat16; any other dtype
    raises ``KeyError`` from the tolerance lookup. A mismatch raises the
    ``AssertionError`` produced by ``torch.testing.assert_close``.
    """
    tolerances = {
        torch.float16: {"rtol": 1e-3, "atol": 5e-4},
        torch.float32: {"rtol": 1e-5, "atol": 5e-6},
        torch.bfloat16: {"rtol": 8e-3, "atol": 8e-3},
    }
    chosen = tolerances[a.dtype]
    torch.testing.assert_close(a, b, **chosen)

In [18]:
## q: (seqLen, num_head, head_dim)
## k: (seqLen, num_head, head_dim)
## v: (seqLen, num_head, head_dim)
def batch_prefill_baseline(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, seqLens: List[int]):
    """Reference prefill attention over a batch of sequences packed back to back.

    ``q``, ``k`` and ``v`` hold all sequences concatenated along dim 0;
    ``seqLens`` gives the number of rows belonging to each sequence, in order.
    Each sequence is attended independently with ``prefill_single_seq_attn``
    and the per-sequence outputs are concatenated along dim 0 again.
    """
    per_seq_outputs = []
    offset = 0
    for length in seqLens:
        end = offset + length
        per_seq_outputs.append(
            prefill_single_seq_attn(q[offset:end], k[offset:end], v[offset:end], length)
        )
        offset = end
    return torch.cat(per_seq_outputs, dim=0)

## q: (seqLen, num_head, head_dim)
## k: (seqLen, num_head, head_dim)
## v: (seqLen, num_head, head_dim)
def prefill_single_seq_attn(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, seqLen: int):
    """Reference causal self-attention for a single sequence.

    Args:
        q, k, v: tensors of shape (seqLen, num_head, head_dim).
        seqLen: number of tokens in the sequence; sizes the causal mask.

    Returns:
        Tensor of shape (seqLen, num_head, head_dim).

    The causal mask is created with the dtype and device of the score
    matrix, so the function does not depend on notebook-level globals and
    works for inputs of any floating dtype on any device.
    """
    qt = q.transpose(1, 0) # (num_head, seqLen, head_dim)
    kt = k.transpose(1, 0) # (num_head, seqLen, head_dim)
    vt = v.transpose(1, 0) # (num_head, seqLen, head_dim)
    scale = math.sqrt(kt.size(-1))

    qkProduct = (qt @ kt.transpose(-2, -1)) * (1.0 / scale) # (num_head, seqLen, seqLen)
    # -inf strictly above the diagonal, 0 elsewhere: token t may only attend to tokens <= t
    causalMask = torch.triu(
        torch.full((seqLen, seqLen), float('-inf'), dtype=qkProduct.dtype, device=qkProduct.device),
        diagonal=1,
    )
    softmax = F.softmax(qkProduct + causalMask, dim=-1)
    attn = softmax @ vt # (num_head, seqLen, seqLen) x (num_head, seqLen, head_dim) -> (num_head, seqLen, head_dim)
    return attn.transpose(0, 1) # (seqLen, num_head, head_dim)

In [19]:
## q: (numSeqs, num_head, head_dim)
## k: (seqLen, num_head, head_dim)
## v: (seqLen, num_head, head_dim)
def batch_decode_baseline(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, seqLens: List[int]):
    """Reference decode attention for a batch of sequences.

    ``q`` carries one query row per sequence, while ``k`` and ``v`` hold the
    cached keys/values of all sequences concatenated along dim 0, with
    ``seqLens`` giving each sequence's row count. No positional encoding is
    applied. Per-sequence outputs from ``decode_single_seq_attn`` are
    concatenated along dim 0.
    """
    outputs = []
    offset = 0
    for seq_idx, length in enumerate(seqLens):
        end = offset + length
        outputs.append(
            decode_single_seq_attn(q[seq_idx], k[offset:end], v[offset:end], length)
        )
        offset = end
    return torch.cat(outputs, dim=0)

def decode_single_seq_attn(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, seqLen):
    """Reference single-token (decode) attention for one sequence.

    Args:
        q: query of the token being decoded, shape (num_head, head_dim).
        k, v: cached keys/values, shape (seqLen, num_head, head_dim).
        seqLen: number of cached tokens; unused, kept for interface
            compatibility with existing callers.

    Returns:
        Tensor of shape (1, num_head, head_dim).

    The head count and head size are taken from the tensor shapes instead
    of notebook-level globals, matching ``prefill_single_seq_attn``.
    """
    head_size = k.size(-1)
    qt = q.reshape(-1, 1, head_size)  # (num_head, 1, head_dim)
    kt = k.transpose(1, 0) # (num_head, seqLen, head_dim)
    vt = v.transpose(1, 0) # (num_head, seqLen, head_dim)
    scale = math.sqrt(head_size)

    qkProduct = (qt @ kt.transpose(-2, -1)) * (1.0 / scale) # (num_head, 1, seqLen)
    # the causal mask is not needed for decoding
    softmax = F.softmax(qkProduct, dim=-1)
    attn = softmax @ vt # (num_head, 1, seqLen) x (num_head, seqLen, head_dim) -> (num_head, 1, head_dim)
    return attn.transpose(0, 1) # (1, num_head, head_dim)

In [24]:
class KvCache:
    """Page bookkeeping for the KV cache of a single sequence.

    The sequence occupies consecutive page indices starting at
    ``start_page_idx``. Attributes:

    * ``kv_page_indices``: list of page indices currently in use.
    * ``kv_last_page_len``: number of tokens stored in the last page
      (0 when the sequence is empty and owns no page yet).
    * ``kv_len``: total number of tokens in the sequence.
    """

    def __init__(self, start_page_idx: int, seq_len: int, page_size: int):
        """Lay out ``seq_len`` tokens over pages of ``page_size`` tokens.

        Raises:
            ValueError: if ``page_size`` is not positive or ``seq_len`` is negative.
        """
        if page_size <= 0:
            raise ValueError(f"page_size must be positive, got {page_size}")
        if seq_len < 0:
            raise ValueError(f"seq_len must be non-negative, got {seq_len}")
        self.page_size = page_size
        # remembered so an initially empty sequence knows which page to claim first
        self.start_page_idx = start_page_idx
        num_pages = (seq_len + page_size - 1) // page_size  # integer ceil
        self.kv_page_indices = list(range(start_page_idx, start_page_idx + num_pages))
        # an empty sequence has no last page; otherwise the count is in 1..page_size
        self.kv_last_page_len = seq_len - (num_pages - 1) * page_size if num_pages else 0
        self.kv_len = seq_len

    def increment(self):
        """Account for one appended token, claiming a new page when needed."""
        self.kv_len += 1
        if not self.kv_page_indices:
            # first token of a previously empty sequence
            self.kv_page_indices.append(self.start_page_idx)
            self.kv_last_page_len = 1
            return
        self.kv_last_page_len += 1
        if self.kv_last_page_len > self.page_size:
            # last page overflowed: the new token starts the next consecutive page
            self.kv_last_page_len -= self.page_size
            self.kv_page_indices.append(self.kv_page_indices[-1] + 1)
            

class BatchedKvCache:
    """Key-value cache for a batch of sequences.

    Owns one zero-initialised tensor ``kv_cache`` of shape
    (total_pages, 2, page_size, num_head, head_dim) and one ``KvCache``
    bookkeeping object per sequence. Each sequence is assigned a fixed,
    contiguous range of ``ceil(context_length / page_size)`` page indices.
    ``num_head``, ``head_dim``, ``dtype`` and ``device`` are read from the
    notebook globals.
    """

    def __init__(self, page_size: int, context_length: int, seq_lens: List[int]):
        pages_per_seq = math.ceil(context_length / page_size)
        self.page_size = page_size
        self.kv_cache = torch.zeros(
            pages_per_seq * len(seq_lens),
            2,
            page_size,
            num_head,
            head_dim,
            dtype=dtype,
            device=device,
        )
        # sequence i starts at page i * pages_per_seq
        self.kv_cache_pages_info = [
            KvCache(pages_per_seq * seq_idx, seq_len, page_size)
            for seq_idx, seq_len in enumerate(seq_lens)
        ]

    def increment(self, kv_active: List[bool]):
        """Advance the bookkeeping of every active sequence by one token."""
        for seq_idx, seq_cache in enumerate(self.kv_cache_pages_info):
            if not kv_active[seq_idx]:
                continue
            seq_cache.increment()

    def computeActiveKvData(self, kv_active: List[bool]):
        """Build the page tables of the active sequences as int32 tensors.

        Returns ``(kv_page_indices, kv_page_indptr, kv_last_page_len)``:
        the concatenated page indices, the offsets delimiting each active
        sequence inside that list (one more entry than active sequences),
        and the number of tokens in each active sequence's last page.
        """
        page_indices = []
        indptr = [0]
        last_page_lens = []
        for seq_idx, seq_cache in enumerate(self.kv_cache_pages_info):
            if not kv_active[seq_idx]:
                continue
            page_indices += seq_cache.kv_page_indices
            last_page_lens.append(seq_cache.kv_last_page_len)
            # running total of pages doubles as the next offset
            indptr.append(len(page_indices))

        def as_int32(values):
            return torch.tensor(values, dtype=torch.int32, device=device)

        kv_page_indices = as_int32(page_indices)
        kv_page_indptr = as_int32(indptr)
        kv_last_page_len = as_int32(last_page_lens)
        return kv_page_indices, kv_page_indptr, kv_last_page_len

In [25]:
class NaiveKvCache:
    """Key-value cache to assist baseline attention computation.

    Keeps one key tensor and one value tensor per sequence (``kList`` and
    ``vList``), each with the sequence's tokens along dim 0.
    """

    def __init__(self, k: torch.Tensor, v: torch.Tensor, seqLens: List[int]):
        """Split the packed ``k``/``v`` into per-sequence slices.

        ``k`` and ``v`` hold all sequences concatenated along dim 0 and
        ``seqLens`` gives the row count of each sequence, in order.
        """
        kList = []
        vList = []
        newKvIdx = 0
        for seqLen in seqLens:
            kList.append(k[newKvIdx: newKvIdx + seqLen])
            vList.append(v[newKvIdx: newKvIdx + seqLen])
            newKvIdx += seqLen

        self.kList = kList # per sequence key
        self.vList = vList # per sequence value

    def append(self, newK: torch.Tensor, newV: torch.Tensor, kv_active: List[bool]):
        """Append one new token to every active sequence.

        ``newK``/``newV`` contain one row per *active* sequence, in the
        order those sequences appear in ``kv_active``.
        """
        newKvIdx = 0
        for i, active in enumerate(kv_active):
            if active:
                self.kList[i] = torch.cat([self.kList[i], newK[newKvIdx: newKvIdx+1]], dim=0)
                self.vList[i] = torch.cat([self.vList[i], newV[newKvIdx: newKvIdx+1]], dim=0)
                newKvIdx += 1

    def getActiveKvData(self, kv_active: List[bool]):
        """Return ``(k, v, seqLens)`` for the active sequences.

        ``k`` and ``v`` are the active sequences' tensors concatenated
        along dim 0; ``seqLens`` lists each active sequence's length.
        """
        activeK = [ self.kList[i] for i, active in enumerate(kv_active) if active ]
        activeV = [ self.vList[i] for i, active in enumerate(kv_active) if active ]
        seqLens = [ k.shape[0] for k in activeK ]
        return torch.cat(activeK, dim=0), torch.cat(activeV, dim=0), seqLens

In [26]:
def batch_prefill_baseline_wrapper(q, naiveKvCache, kv_active):
    """Run the reference prefill attention on the active sequences of ``naiveKvCache``."""
    active_k, active_v, active_lens = naiveKvCache.getActiveKvData(kv_active)
    return batch_prefill_baseline(q, active_k, active_v, active_lens)

# hyperparameters
page_size = 16 # number of tokens a page contains
kv_layout = "NHD"
causal = True
pos_encoding_mode = "NONE"
context_length = 32 # per-sequence capacity in tokens; sizes each sequence's page range

# sequence info
seqLens = [3, 3, 3] # [18, 5, 22]
batch_size = len(seqLens)
totalSeqLen = sum(seqLens)

# q/k/v initialization (all sequences packed along dim 0)
torch.manual_seed(0xABCDABCD987)
k_prefill = torch.randn(totalSeqLen, num_head, head_dim, dtype=dtype, device=device)
v_prefill = torch.randn(totalSeqLen, num_head, head_dim, dtype=dtype, device=device)
q = torch.randn(totalSeqLen, num_head, head_dim, dtype=dtype, device=device)

# batch length info: prefix sums of seqLens with a leading 0, cast to int32
qo_indptr = torch.cat(
    [torch.zeros(1, dtype=torch.int32, device=device), torch.cumsum(torch.tensor(seqLens, dtype=torch.int32, device=device), dim=0)]
).int()

# kv cache allocation
kv_active = [ seqLen > 0 for seqLen in seqLens ]
batchKvCache = BatchedKvCache(page_size, context_length, seqLens)

# move the kv data to cache (both the naive per-sequence copy and the paged cache)
naiveKvCache = NaiveKvCache(k_prefill, v_prefill, seqLens)
kv_page_indices, kv_page_indptr, kv_last_page_len = batchKvCache.computeActiveKvData(kv_active)
flashinfer.append_paged_kv_cache(
    k_prefill,
    v_prefill,
    qo_indptr,
    batchKvCache.kv_cache,
    kv_page_indices,
    kv_page_indptr,
    kv_last_page_len)

# compute prefill attention
# allocate the 32 MiB workspace directly on `device` instead of on the CPU followed by a copy to GPU index 0
workspace_buffer = torch.empty(32 * 1024 * 1024, dtype=torch.int8, device=device)
prefill_wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
    workspace_buffer, kv_layout
)

# num_head is passed twice: the same head count is used for queries and for keys/values
prefill_wrapper.begin_forward(
    qo_indptr,
    kv_page_indptr,
    kv_page_indices,
    kv_last_page_len,
    num_head,
    num_head,
    head_dim,
)

# the FlashInfer result must match the naive PyTorch reference within dtype tolerance
flashinfer_attn = prefill_wrapper.forward(q, batchKvCache.kv_cache, causal=causal, pos_encoding_mode=pos_encoding_mode)
baseline_attn = batch_prefill_baseline_wrapper(q, naiveKvCache, kv_active)
assert_close(flashinfer_attn, baseline_attn)

In [27]:
# continuously decoding
def batch_decode_baseline_wrapper(q_decode, naiveKvCache, kv_active):
    """Run the reference decode attention on the active sequences of ``naiveKvCache``."""
    active_k, active_v, active_lens = naiveKvCache.getActiveKvData(kv_active)
    return batch_decode_baseline(q_decode, active_k, active_v, active_lens)

def batch_decode_flashinfer_wrapper(q_decode, batchKvCache, kv_active, workspace_buffer, kv_layout):
    """Run FlashInfer batch decode attention over the active sequences.

    Args:
        q_decode: one query row per active sequence.
        batchKvCache: ``BatchedKvCache`` holding the paged KV tensor and page tables.
        kv_active: per-sequence flags selecting which sequences take part.
        workspace_buffer: scratch tensor handed to the FlashInfer wrapper.
        kv_layout: KV layout string handed to the FlashInfer wrapper.

    Returns:
        Whatever ``BatchDecodeWithPagedKVCacheWrapper.forward`` returns for
        ``q_decode``.

    The page size and cache dtype are read from ``batchKvCache`` and the
    positional-encoding mode from the notebook-level ``pos_encoding_mode``,
    so ``begin_forward`` and ``forward`` always agree with each other and
    with the cache actually being decoded. ``num_head`` and ``head_dim``
    still come from the notebook globals.
    """
    kv_page_indices, kv_page_indptr, kv_last_page_len = batchKvCache.computeActiveKvData(kv_active)
    decode_wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
        workspace_buffer, kv_layout
    )
    decode_wrapper.begin_forward(
        kv_page_indptr,
        kv_page_indices,
        kv_last_page_len,
        num_head,  # query heads
        num_head,  # kv heads (same count as query heads here)
        head_dim,
        batchKvCache.page_size,
        pos_encoding_mode,  # same mode that is passed to forward() below
        batchKvCache.kv_cache.dtype,
    )
    return decode_wrapper.forward(q_decode, batchKvCache.kv_cache, pos_encoding_mode=pos_encoding_mode)

def append_kv_cache_wrapper(batchKvCache, kv_active, k_new=None, v_new=None):
    """Append one new token per active sequence to the paged KV cache.

    Args:
        batchKvCache: ``BatchedKvCache`` to update; its page bookkeeping is
            advanced by one token for every active sequence.
        kv_active: per-sequence flags selecting which sequences receive a token.
        k_new, v_new: keys/values to append, one row per active sequence.
            When omitted, the notebook-level ``k_decode`` / ``v_decode`` are
            used, which keeps the original two-argument call working.
    """
    if k_new is None:
        k_new = k_decode
    if v_new is None:
        v_new = v_decode
    batch_size_decode = sum(kv_active)
    # every active sequence contributes exactly one row: indptr = [0, 1, ..., n]
    qo_decode_indptr = torch.arange(0, batch_size_decode + 1, dtype=torch.int32, device=k_new.device)
    # advance the page tables first so they describe the cache *after* the append
    batchKvCache.increment(kv_active)
    kv_page_indices, kv_page_indptr, kv_last_page_len = batchKvCache.computeActiveKvData(kv_active)
    flashinfer.append_paged_kv_cache(
        k_new,
        v_new,
        qo_decode_indptr,
        batchKvCache.kv_cache,
        kv_page_indices,
        kv_page_indptr,
        kv_last_page_len
    )

for _ in range(5):
    kv_active = [False, True, True] # the first sequence stopped, so only 2nd and 3rd ones are active
    batch_size_decode = sum(kv_active)
    k_decode = torch.randn(batch_size_decode, num_head, head_dim, dtype=dtype, device=device)
    v_decode = torch.randn(batch_size_decode, num_head, head_dim, dtype=dtype, device=device)
    q_decode = torch.randn(batch_size_decode, num_head, head_dim, dtype=dtype, device=device)

    # decode: both implementations attend over the cache as it stands before this step's append
    flashinfer_decode_attn = batch_decode_flashinfer_wrapper(q_decode, batchKvCache, kv_active, workspace_buffer, kv_layout)
    baseline_decode_attn = batch_decode_baseline_wrapper(q_decode, naiveKvCache, kv_active)
    assert_close(flashinfer_decode_attn, baseline_decode_attn)

    # append kv data to both caches so the next iteration sees one more token
    naiveKvCache.append(k_decode, v_decode, kv_active)
    append_kv_cache_wrapper(batchKvCache, kv_active, k_decode, v_decode)