From 3cb21bd1fd3ab7ca598594548ed153136820ab6b Mon Sep 17 00:00:00 2001 From: Jeremy Reizenstein <669761+bottler@users.noreply.github.com> Date: Wed, 27 Aug 2025 05:46:52 -0700 Subject: [PATCH] pad dequantized paged fp8 kv with zeros (#4780) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/4780 X-link: https://github.com/facebookresearch/FBGEMM/pull/1803 When dequantizing the fp8 paged kv-cache, pad with zeros after the end of each used sequence (through the end of its last page) to avoid NaNs in flash attention 3. This is analogous to the non-paged case, which was tackled in D69522001. Differential Revision: D80977902 --- .../gen_ai/src/kv_cache/kv_cache.cu | 27 +++++++++++++++++-- 1 file changed, 25 insertions(+), 2 deletions(-) diff --git a/fbgemm_gpu/experimental/gen_ai/src/kv_cache/kv_cache.cu b/fbgemm_gpu/experimental/gen_ai/src/kv_cache/kv_cache.cu index 0f94ccc163..094ff3cd79 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/kv_cache/kv_cache.cu +++ b/fbgemm_gpu/experimental/gen_ai/src/kv_cache/kv_cache.cu @@ -2719,8 +2719,8 @@ __global__ void dequantize_fp8_cache_kernel_paged( auto max_t = kv_seqlen[b]; // one warp per T/H - for (auto t_h = threadIdx.y + blockIdx.y * blockDim.y; t_h < max_t * N_KVH; - t_h += blockDim.y * gridDim.y) { + auto t_h = threadIdx.y + blockIdx.y * blockDim.y; + for (; t_h < max_t * N_KVH; t_h += blockDim.y * gridDim.y) { auto h = t_h % N_KVH; auto t = t_h / N_KVH; @@ -2774,6 +2774,29 @@ __global__ void dequantize_fp8_cache_kernel_paged( *reinterpret_cast(&row_v_dq[4 * threadIdx.x]) = *reinterpret_cast(&kv_dq.vals[2]); } + + // zero out the rest of the page, because FA3 can be affected by + // NaN values beyond the sequence length. 
+ max_t = (max_t + page_size - 1) / page_size * page_size; + for (; t_h < max_t * N_KVH; t_h += blockDim.y * gridDim.y) { + if (4 * threadIdx.x >= D_H) { + continue; + } + auto h = t_h % N_KVH; + auto t = t_h / N_KVH; + + int page_logical_idx = t / page_size; + int page_offset = t % page_size; + int page_physical_idx = + block_tables[b * block_tables_b_stride + page_logical_idx]; + int physical_t = page_physical_idx * page_size + page_offset; + + auto* row_k_dq = &cache_K_dq[0][physical_t][h][0]; + auto* row_v_dq = &cache_V_dq[0][physical_t][h][0]; + + memset(&row_k_dq[4 * threadIdx.x], 0, sizeof(uint2)); + memset(&row_v_dq[4 * threadIdx.x], 0, sizeof(uint2)); + } } #endif