Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions test/test_pallas.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,9 +135,9 @@ def _ragged_pagedattention_generate_qkv(
"constant", 0)
q = torch.randn((max_num_batched_tokens, num_q_heads, head_dim),
dtype=dtype)
k_pages = torch.randn((num_pages, page_size, num_kv_heads, head_dim),
k_pages = torch.randn((num_pages, page_size, num_kv_heads * head_dim),
dtype=dtype)
v_pages = torch.randn((num_pages, page_size, num_kv_heads, head_dim),
v_pages = torch.randn((num_pages, page_size, num_kv_heads * head_dim),
dtype=dtype)
page_indices = torch.randint(
0, num_pages, (max_num_seqs, pages_per_seq), dtype=torch.int32)
Expand Down
80 changes: 10 additions & 70 deletions torch_xla/experimental/custom_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -893,70 +893,10 @@ def flash_attention(
sm_scale, ab, partition_spec, mesh)


def ceil_div(a, b):
  """Divide `a` by `b`, rounding up, using only floor division.

  Evaluates `(a + b - 1) // b`, which equals `ceil(a / b)` for integers
  when `b` is positive.

  Args:
    a: Dividend.
    b: Divisor; must be non-zero.

  Returns:
    The floor-divided quotient of `a + b - 1` by `b`.
  """
  assert b != 0
  # Bump the dividend so that floor division lands on the next multiple.
  bumped = a + b - 1
  return bumped // b


def validate_ragged_paged_attention_inputs(
    q,  # [max_num_batched_tokens, num_q_heads, head_dim]
    k_pages,  # [total_num_pages, page_size, num_kv_heads, head_dim]
    v_pages,  # [total_num_pages, page_size, num_kv_heads, head_dim]
    kv_lens,  # i32[max_num_seqs]
    page_indices,  # i32[max_num_seqs, pages_per_seq]
    cu_q_lens,  # i32[max_num_seqs + 1]
    num_seqs,  # i32[1]
):
  """Check the static shapes and dtypes of ragged paged attention inputs.

  Only properties available without reading tensor values are verified:
  K/V page shapes agree, Q and K/V share a head dimension, the per-sequence
  metadata tensors have the lengths implied by `page_indices.shape[0]`,
  the metadata tensors are int32, and the number of Q heads is a multiple
  of the number of KV heads. `num_seqs` is accepted but not inspected here.

  Raises:
    ValueError: If any of the checks above fails (or if a tensor does not
      have the rank needed to unpack its shape).
  """
  # Shape unpacking doubles as a rank check: q must be 3-D, k_pages 4-D and
  # page_indices 2-D.
  _, num_q_heads, q_head_dim = q.shape
  _, _, num_kv_heads, kv_head_dim = k_pages.shape
  max_num_seqs, _ = page_indices.shape

  if k_pages.shape != v_pages.shape:
    raise ValueError(
        f"{k_pages.shape=} and {v_pages.shape=} must have the same shape.")

  if kv_head_dim != q_head_dim:
    raise ValueError(
        f"Q head_dim {q_head_dim} must be the same as that of K/V {kv_head_dim}."
    )

  if kv_lens.shape != (max_num_seqs,):
    raise ValueError(f"Expected {kv_lens.shape=} to be ({max_num_seqs},) where"
                     " `max_num_seqs` is `page_indices.shape[0]`.")

  # cu_q_lens holds one more entry than there are sequence slots.
  cu_q_lens_size = max_num_seqs + 1
  if cu_q_lens.shape != (cu_q_lens_size,):
    raise ValueError(
        f"Expected {cu_q_lens.shape=} to be ({cu_q_lens_size},) where"
        " `max_num_seqs` is `page_indices.shape[0]`.")

  metadata_tensors = (kv_lens, page_indices, cu_q_lens)
  if any(t.dtype != torch.int32 for t in metadata_tensors):
    raise ValueError(
        "The dtype of `kv_lens`, `page_indices`, and `cu_q_lens` must be"
        f" int32. Got {kv_lens.dtype=}, {page_indices.dtype=},"
        f" {cu_q_lens.dtype=}.")

  if num_q_heads % num_kv_heads:
    raise ValueError(f"{num_q_heads=} must be divisible by {num_kv_heads=}")

  # NOTE: the conditions below depend on tensor *values* and therefore must
  # be checked at runtime; they are intentionally not enforced here:
  #   * num_seqs <= max_num_seqs
  #   * pages_per_seq >= ceil_div(max(kv_lens), page_size)
  #   * cu_q_lens[num_seqs] <= max_num_batched_tokens
  #   * for every i in range(num_seqs):
  #       cu_q_lens[i + 1] - cu_q_lens[i] <= kv_lens[i]


def _ragged_paged_attention_nonkernel(
queries, # [max_num_batched_tokens, num_q_heads, head_dim]
k_pages, # [total_num_pages, page_size, num_kv_heads, head_dim]
v_pages, # [total_num_pages, page_size, num_kv_heads, head_dim]
k_pages, # [total_num_pages, page_size, num_kv_heads * head_dim]
v_pages, # [total_num_pages, page_size, num_kv_heads * head_dim]
kv_lens, # i32[max_num_seqs]
page_indices, # i32[max_num_seqs, pages_per_seq]
cu_q_lens, # i32[max_num_seqs + 1]
Expand All @@ -965,8 +905,9 @@ def _ragged_paged_attention_nonkernel(
sm_scale=1.0,
mask_value=DEFAULT_MASK_VALUE,
):
_, _, num_kv_heads, head_dim = k_pages.shape
num_q_heads = queries.shape[1]
_, num_q_heads, head_dim = queries.shape
_, _, kv_model_dim = k_pages.shape
num_kv_heads = kv_model_dim // head_dim
assert num_q_heads % num_kv_heads == 0
num_query_per_kv = num_q_heads // num_kv_heads
outputs = []
Expand All @@ -977,8 +918,8 @@ def _ragged_paged_attention_nonkernel(
kv_len = kv_lens[i]
indices = page_indices[i]
q = queries[q_start:q_end]
k = k_pages[indices, :, :, :].reshape(-1, num_kv_heads, head_dim)[:kv_len]
v = v_pages[indices, :, :, :].reshape(-1, num_kv_heads, head_dim)[:kv_len]
k = k_pages[indices, :, :].reshape(-1, num_kv_heads, head_dim)[:kv_len]
v = v_pages[indices, :, :].reshape(-1, num_kv_heads, head_dim)[:kv_len]
k = torch.repeat_interleave(k, num_query_per_kv, dim=1)
v = torch.repeat_interleave(v, num_query_per_kv, dim=1)
attn = torch.einsum("qhd,khd->hqk", q, k)
Expand All @@ -998,8 +939,8 @@ def _ragged_paged_attention_nonkernel(
@requires_jax
def ragged_paged_attention(
q, # [max_num_batched_tokens, num_q_heads, head_dim]
k_pages, # [total_num_pages, page_size, num_kv_heads, head_dim]
v_pages, # [total_num_pages, page_size, num_kv_heads, head_dim]
k_pages, # [total_num_pages, page_size, num_kv_heads * head_dim]
v_pages, # [total_num_pages, page_size, num_kv_heads * head_dim]
kv_lens, # i32[max_num_seqs]
page_indices, # i32[max_num_seqs, pages_per_seq]
cu_q_lens, # i32[max_num_seqs + 1]
Expand All @@ -1014,8 +955,7 @@ def ragged_paged_attention(
):
if mask_value is None:
mask_value = DEFAULT_MASK_VALUE
validate_ragged_paged_attention_inputs(q, k_pages, v_pages, kv_lens,
page_indices, cu_q_lens, num_seqs)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why did we stop calling validate_ragged_paged_attention_inputs?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because JAX already performs these static shape checks.

if not use_kernel:
return _ragged_paged_attention_nonkernel(
q,
Expand Down
Loading
Loading