Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 25 additions & 2 deletions fbgemm_gpu/experimental/gen_ai/src/kv_cache/kv_cache.cu
Original file line number Diff line number Diff line change
Expand Up @@ -2719,8 +2719,8 @@ __global__ void dequantize_fp8_cache_kernel_paged(
auto max_t = kv_seqlen[b];

// one warp per T/H
for (auto t_h = threadIdx.y + blockIdx.y * blockDim.y; t_h < max_t * N_KVH;
t_h += blockDim.y * gridDim.y) {
auto t_h = threadIdx.y + blockIdx.y * blockDim.y;
for (; t_h < max_t * N_KVH; t_h += blockDim.y * gridDim.y) {
auto h = t_h % N_KVH;
auto t = t_h / N_KVH;

Expand Down Expand Up @@ -2774,6 +2774,29 @@ __global__ void dequantize_fp8_cache_kernel_paged(
*reinterpret_cast<uint2*>(&row_v_dq[4 * threadIdx.x]) =
*reinterpret_cast<uint2*>(&kv_dq.vals[2]);
}

// Zero out the unused tail of the last (partially filled) page, because
// FA3 can be affected by NaN values that lie beyond the sequence length.
max_t = (max_t + page_size - 1) / page_size * page_size;
for (; t_h < max_t * N_KVH; t_h += blockDim.y * gridDim.y) {
if (4 * threadIdx.x >= D_H) {
continue;
}
auto h = t_h % N_KVH;
auto t = t_h / N_KVH;

int page_logical_idx = t / page_size;
int page_offset = t % page_size;
int page_physical_idx =
block_tables[b * block_tables_b_stride + page_logical_idx];
int physical_t = page_physical_idx * page_size + page_offset;

auto* row_k_dq = &cache_K_dq[0][physical_t][h][0];
auto* row_v_dq = &cache_V_dq[0][physical_t][h][0];

memset(&row_k_dq[4 * threadIdx.x], 0, sizeof(uint2));
memset(&row_v_dq[4 * threadIdx.x], 0, sizeof(uint2));
}
}
#endif

Expand Down
Loading