From 056a4b7d872fecc1c83349ca61268a73a2081232 Mon Sep 17 00:00:00 2001
From: Kimish Patel
Date: Fri, 11 Apr 2025 07:05:23 -0700
Subject: [PATCH] [Executorch][llm] quantized sdpa. update attn_scores @ v gemm

Summary:
Dequantize gemm when doing prefill like op, else use custom kernel

Reviewed By: metascroy

Differential Revision: D71833065
---
 extension/llm/custom_ops/op_sdpa_impl.h | 108 +++++++++++++++++-------
 1 file changed, 78 insertions(+), 30 deletions(-)

diff --git a/extension/llm/custom_ops/op_sdpa_impl.h b/extension/llm/custom_ops/op_sdpa_impl.h
index 76dbf776700..689537923d5 100644
--- a/extension/llm/custom_ops/op_sdpa_impl.h
+++ b/extension/llm/custom_ops/op_sdpa_impl.h
@@ -202,6 +202,49 @@ void dequantize_per_channel_optimized(
   }
 }
 
+void dequant_and_gemm(
+    const int64_t m,
+    const int64_t n,
+    const int64_t k,
+    float* qk_data,
+    const int64_t qk_stride_m,
+    const MaybeQuantizedMatrixData& v_data,
+    const int64_t v_stride_n,
+    float* o_data,
+    const int64_t o_stride_m,
+    const float beta) {
+  std::vector<float> dequantized_v_data(v_data.m * v_data.n);
+  dequantize_per_channel_optimized(
+      static_cast<const int8_t*>(v_data.data),
+      static_cast<const float*>(v_data.scales),
+      static_cast<const int8_t*>(v_data.zero_points),
+      dequantized_v_data.data(),
+      -128,
+      127,
+      1,
+      0,
+      0,
+      v_data.m,
+      v_stride_n,
+      v_data.n,
+      v_data.n,
+      v_data.zero_points_stride);
+  ::executorch::cpublas::gemm(
+      ::executorch::cpublas::TransposeType::NoTranspose,
+      ::executorch::cpublas::TransposeType::NoTranspose,
+      n,
+      m,
+      k,
+      static_cast<float>(1),
+      dequantized_v_data.data(),
+      v_data.n,
+      qk_data,
+      qk_stride_m,
+      beta,
+      o_data,
+      o_stride_m);
+}
+
 template <typename accum_t>
 void _qk_at_v_gemm(
     const int64_t m,
@@ -216,36 +259,41 @@ void _qk_at_v_gemm(
     const accum_t beta) {
   if (v_data.dtype == ScalarType::Char) {
     if constexpr (std::is_same<accum_t, float>::value) {
-      std::vector<float> dequantized_v_data(v_data.m * v_data.n);
-      dequantize_per_channel_optimized(
-          static_cast<const int8_t*>(v_data.data),
-          static_cast<const float*>(v_data.scales),
-          static_cast<const int8_t*>(v_data.zero_points),
-          dequantized_v_data.data(),
-          -128,
-          127,
-          1,
-          0,
-          0,
-          v_data.m,
-          v_stride_n,
-          v_data.n,
-          v_data.n,
-          v_data.zero_points_stride);
-      ::executorch::cpublas::gemm(
-          ::executorch::cpublas::TransposeType::NoTranspose,
-          ::executorch::cpublas::TransposeType::NoTranspose,
-          n,
-          m,
-          k,
-          static_cast<accum_t>(1),
-          dequantized_v_data.data(),
-          v_data.n,
-          qk_data,
-          qk_stride_m,
-          beta,
-          o_data,
-          o_stride_m);
+      if (m > 4) {
+        // For larger batch sizes, dequantize and use BLAS for better
+        // performance
+        dequant_and_gemm(
+            m,
+            n,
+            k,
+            const_cast<float*>(qk_data),
+            qk_stride_m,
+            v_data,
+            v_stride_n,
+            o_data,
+            o_stride_m,
+            beta);
+      } else {
+        // For smaller batch sizes, use quantized gemm
+        int a_stride_m_tmp, b_stride_n_tmp;
+        auto kernel = torchao::kernels::cpu::quantized_matmul::
+            get_fp32_a_input_channelwise_8bit_b_f32_c_matmul(
+                m, n, k, false, false, a_stride_m_tmp, b_stride_n_tmp);
+        kernel(
+            m,
+            n,
+            k,
+            qk_data,
+            qk_stride_m /*lhs_stride_m*/,
+            static_cast<const int8_t*>(v_data.data),
+            v_stride_n /*rhs_stride_n*/,
+            o_data,
+            o_stride_m /*out_stride_n*/,
+            static_cast<const int8_t*>(v_data.zero_points),
+            static_cast<const float*>(v_data.scales),
+            beta,
+            v_data.zero_points_stride);
+      }
     } else {
       ET_CHECK_MSG(
           false, "Accumulation in dtype other than float not supported yet");