From 304dbb78df1fd15d75e38330bec65a2510ba4e7e Mon Sep 17 00:00:00 2001
From: Kimish Patel
Date: Wed, 2 Apr 2025 12:18:03 -0700
Subject: [PATCH] Add quantized q @ k test for intended use in quantized
 attention

Summary:
Adds tests for q @ k.T as intended for use in the SDPA op.

Reviewed By: metascroy

Differential Revision: D71370604
---
 .../cpu/aarch64/tests/test_qmatmul.cpp        |  98 ++++++++
 .../kernels/cpu/aarch64/tests/test_utils.h    |   1 +
 .../tests/test_utils_quantized_attention.h    | 235 ++++++++++++++++++
 3 files changed, 334 insertions(+)
 create mode 100644 torchao/experimental/kernels/cpu/aarch64/tests/test_utils_quantized_attention.h

diff --git a/torchao/experimental/kernels/cpu/aarch64/tests/test_qmatmul.cpp b/torchao/experimental/kernels/cpu/aarch64/tests/test_qmatmul.cpp
index e7e2d09c64..344b2c4915 100644
--- a/torchao/experimental/kernels/cpu/aarch64/tests/test_qmatmul.cpp
+++ b/torchao/experimental/kernels/cpu/aarch64/tests/test_qmatmul.cpp
@@ -12,6 +12,7 @@
 #include <arm_neon.h>
 #include <torchao/experimental/kernels/cpu/aarch64/matmul/matmul.h>
 #include <torchao/experimental/kernels/cpu/aarch64/tests/test_utils.h>
+#include <torchao/experimental/kernels/cpu/aarch64/tests/test_utils_quantized_attention.h>
 
 float kTol = 0.0001;
 
@@ -411,4 +412,101 @@ INSTANTIATE_TEST_SUITE_P(
     FP32A_QuantizedB_FP32C_Test,
     ::testing::Values(0.0, 1.0, 2.69));
 
+static void test_8bit_per_token_q_at_k_matmul_attention(
+    int b,
+    int s_q,
+    int s_k,
+    int h,
+    int d,
+    bool transpose = true) {
+  auto test_case = torchao::
+      channelwise_8bit_a_channelwise_8bit_b_q_at_k_attention_test_case::
+          generate(b, s_q, s_k, h, d, transpose);
+
+  using namespace torchao::kernels::cpu::aarch64::quantized_matmul::
+      channelwise_8bit_a_channelwise_8bit_b_1x8x16_f32_neondot;
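+
+  // For each (batch, head) pair the kernel computes one s_q x s_k tile of
+  // Q @ K^T. The strides below locate that head's Q and K submatrices, and
+  // the matching per-token (per-row) scales and zero points, inside the
+  // generated test tensors.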
+  size_t q_b_stride = test_case.b_q_stride;
+  size_t q_h_stride = test_case.h_q_stride;
+  size_t q_s_q_stride = test_case.s_q_stride;
+  size_t q_scale_zp_b_stride = test_case.b_q_qparams_stride;
+  size_t q_scale_zp_h_stride = test_case.h_q_qparams_stride;
+  size_t q_scale_zp_s_stride = test_case.s_q_qparams_stride;
+
+  size_t k_b_stride = test_case.b_k_stride;
+  size_t k_h_stride = test_case.h_k_stride;
+  size_t k_s_k_stride = test_case.s_k_stride;
+  size_t k_scale_zp_b_stride = test_case.b_k_qparams_stride;
+  size_t k_scale_zp_h_stride = test_case.h_k_qparams_stride;
+  size_t k_scale_zp_s_stride = test_case.s_k_qparams_stride;
+
+  std::vector<float> output(b * h * s_q * s_k);
+  size_t output_b_stride = h * s_q * s_k;
+  size_t output_h_stride = s_q * s_k;
+  size_t output_s_q_stride = s_k;
+
+  for (int b_idx = 0; b_idx < b; b_idx++) {
+    for (int h_idx = 0; h_idx < h; h_idx++) {
+      kernel(
+          s_q,
+          s_k,
+          d,
+          test_case.q_qvals.data() + b_idx * q_b_stride + h_idx * q_h_stride,
+          q_s_q_stride /*lhs_stride_m*/,
+          test_case.k_qvals.data() + b_idx * k_b_stride + h_idx * k_h_stride,
+          k_s_k_stride /*rhs_stride_n*/,
+          output.data() + b_idx * output_b_stride + h_idx * output_h_stride,
+          output_s_q_stride /*out_stride_n*/,
+          test_case.q_zeros.data() + b_idx * q_scale_zp_b_stride +
+              h_idx * q_scale_zp_h_stride,
+          test_case.k_zeros.data() + b_idx * k_scale_zp_b_stride +
+              h_idx * k_scale_zp_h_stride,
+          test_case.q_scales.data() + b_idx * q_scale_zp_b_stride +
+              h_idx * q_scale_zp_h_stride,
+          test_case.k_scales.data() + b_idx * k_scale_zp_b_stride +
+              h_idx * k_scale_zp_h_stride,
+          q_scale_zp_s_stride /*lhs qparams stride*/,
+          k_scale_zp_s_stride /*rhs qparams stride*/);
+    }
+  }
+
+  for (int i = 0; i < b * h * s_q * s_k; i++) {
+    EXPECT_NEAR(output[i], test_case.expected_output[i], kTol);
+  }
+}
+
+TEST(test_8bit_per_token_q_at_k_matmul_attention, Basic) {
+  test_8bit_per_token_q_at_k_matmul_attention(1, 16, 16, 8, 16);
+}
+
+TEST(test_8bit_per_token_q_at_k_matmul_attention, PrimeHeadsAndHeadDim) {
+  test_8bit_per_token_q_at_k_matmul_attention(1, 8, 8, 7, 33);
+}
+
+TEST(
+    test_8bit_per_token_q_at_k_matmul_attention,
+    PrimeHeadsAndHeadDimDiffSqSk) {
+  test_8bit_per_token_q_at_k_matmul_attention(1, 7, 16, 7, 33);
+}
+
+TEST(test_8bit_per_token_q_at_k_matmul_attention, PrimeHeadsAndSmallHeadDim) {
+  test_8bit_per_token_q_at_k_matmul_attention(1, 8, 8, 7, 3);
+}
+
+TEST(test_8bit_per_token_q_at_k_matmul_attention, BasicNoTransposed) {
+  test_8bit_per_token_q_at_k_matmul_attention(1, 16, 16, 8, 16, false);
+}
+
+TEST(
+    test_8bit_per_token_q_at_k_matmul_attention,
+    PrimeHeadsAndHeadDimDiffSqSkNoTranspose) {
+  test_8bit_per_token_q_at_k_matmul_attention(1, 7, 16, 7, 33, false);
+}
+
+TEST(
+    test_8bit_per_token_q_at_k_matmul_attention,
+    PrimeHeadsAndSmallHeadDimNoTranspose) {
+  test_8bit_per_token_q_at_k_matmul_attention(1, 8, 8, 7, 3, false);
+}
+
 #endif // defined(__aarch64__) || defined(__ARM_NEON)

diff --git a/torchao/experimental/kernels/cpu/aarch64/tests/test_utils.h b/torchao/experimental/kernels/cpu/aarch64/tests/test_utils.h
index e411211eb4..4f96f8bf96 100644
--- a/torchao/experimental/kernels/cpu/aarch64/tests/test_utils.h
+++ b/torchao/experimental/kernels/cpu/aarch64/tests/test_utils.h
@@ -86,6 +86,7 @@ inline float get_float_from_bf16(uint16_t bf16) {
 
 namespace test_utils {
 auto generate_per_token_quantized_tensor(int m, int n, bool transposed = false);
+
 auto generate_per_token_quantized_tensor(int m, int n, bool transposed) {
   auto activations = get_random_vector(m * n, -1.0, 1.0);
   auto activation_qvals = std::vector<int8_t>(m * n, 0);
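The hunk above cuts off the body of generate_per_token_quantized_tensor. For
orientation, a minimal sketch of the per-token (per-row) asymmetric int8
quantization it is expected to perform follows; the function name and the
qmin/qmax and rounding choices are illustrative assumptions, not torchao's
exact implementation:

// Sketch only: one (scale, zero_point) pair per row ("token"), int8 qvals.
// Dequantization must satisfy value ~= scale * (qval - zero), the identity
// the reference computation in generate() relies on.
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <tuple>
#include <vector>

std::tuple<std::vector<int8_t>, std::vector<float>, std::vector<int8_t>>
quantize_per_token_sketch(const std::vector<float>& x, int m, int n) {
  std::vector<int8_t> qvals(m * n);
  std::vector<float> scales(m);
  std::vector<int8_t> zeros(m);
  for (int i = 0; i < m; i++) {
    auto row = x.begin() + i * n;
    // Keep 0.0 representable so zero-padding quantizes exactly.
    float vmin = std::min(0.0f, *std::min_element(row, row + n));
    float vmax = std::max(0.0f, *std::max_element(row, row + n));
    float scale = (vmax - vmin) / 255.0f; // map the range onto [-128, 127]
    if (scale == 0.0f) {
      scale = 1.0f; // degenerate all-zero row
    }
    int zero = static_cast<int>(std::round(-128.0f - vmin / scale));
    scales[i] = scale;
    zeros[i] = static_cast<int8_t>(zero);
    for (int j = 0; j < n; j++) {
      int q = static_cast<int>(std::round(row[j] / scale)) + zero;
      qvals[i * n + j] = static_cast<int8_t>(std::clamp(q, -128, 127));
    }
  }
  return {qvals, scales, zeros};
}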
diff --git a/torchao/experimental/kernels/cpu/aarch64/tests/test_utils_quantized_attention.h b/torchao/experimental/kernels/cpu/aarch64/tests/test_utils_quantized_attention.h
new file mode 100644
index 0000000000..9ca86ece76
--- /dev/null
+++ b/torchao/experimental/kernels/cpu/aarch64/tests/test_utils_quantized_attention.h
@@ -0,0 +1,235 @@
+// Copyright (c) Meta Platforms, Inc. and affiliates.
+// All rights reserved.
+//
+// This source code is licensed under the license found in the
+// LICENSE file in the root directory of this source tree.
+
+#pragma once
+
+#if defined(__aarch64__) || defined(__ARM_NEON)
+
+#include <arm_neon.h>
+#include <cassert>
+#include <cstddef>
+#include <cstdint>
+#include <tuple>
+#include <vector>
+#include <torchao/experimental/kernels/cpu/aarch64/tests/test_utils.h>
+
+namespace torchao {
+struct channelwise_8bit_a_channelwise_8bit_b_q_at_k_attention_test_case {
+  int b;
+  int s_q;
+  int s_k;
+  int h;
+  int d;
+  bool transposed;
+
+  size_t b_q_stride;
+  size_t h_q_stride;
+  size_t s_q_stride;
+
+  size_t b_k_stride;
+  size_t h_k_stride;
+  size_t s_k_stride;
+
+  size_t b_q_qparams_stride;
+  size_t h_q_qparams_stride;
+  size_t s_q_qparams_stride;
+
+  size_t b_k_qparams_stride;
+  size_t h_k_qparams_stride;
+  size_t s_k_qparams_stride;
+
+  std::vector<float> expected_output;
+
+  std::vector<float> q;
+  std::vector<int8_t> q_qvals;
+  std::vector<float> q_scales;
+  std::vector<int8_t> q_zeros;
+
+  std::vector<float> k;
+  std::vector<int8_t> k_qvals;
+  std::vector<float> k_scales;
+  std::vector<int8_t> k_zeros;
+
+  channelwise_8bit_a_channelwise_8bit_b_q_at_k_attention_test_case(
+      int b_,
+      int s_q_,
+      int s_k_,
+      int h_,
+      int d_,
+      bool transposed_,
+      size_t b_q_stride_,
+      size_t h_q_stride_,
+      size_t s_q_stride_,
+      size_t b_k_stride_,
+      size_t h_k_stride_,
+      size_t s_k_stride_,
+      size_t b_q_qparams_stride_,
+      size_t h_q_qparams_stride_,
+      size_t s_q_qparams_stride_,
+      size_t b_k_qparams_stride_,
+      size_t h_k_qparams_stride_,
+      size_t s_k_qparams_stride_,
+      std::vector<float> expected_output_,
+      std::vector<float> q_,
+      std::vector<int8_t> q_qvals_,
+      std::vector<float> q_scales_,
+      std::vector<int8_t> q_zeros_,
+      std::vector<float> k_,
+      std::vector<int8_t> k_qvals_,
+      std::vector<float> k_scales_,
+      std::vector<int8_t> k_zeros_)
+      : b(b_),
+        s_q(s_q_),
+        s_k(s_k_),
+        h(h_),
+        d(d_),
+        transposed(transposed_),
+        b_q_stride(b_q_stride_),
+        h_q_stride(h_q_stride_),
+        s_q_stride(s_q_stride_),
+        b_k_stride(b_k_stride_),
+        h_k_stride(h_k_stride_),
+        s_k_stride(s_k_stride_),
+        b_q_qparams_stride(b_q_qparams_stride_),
+        h_q_qparams_stride(h_q_qparams_stride_),
+        s_q_qparams_stride(s_q_qparams_stride_),
+        b_k_qparams_stride(b_k_qparams_stride_),
+        h_k_qparams_stride(h_k_qparams_stride_),
+        s_k_qparams_stride(s_k_qparams_stride_),
+        expected_output(expected_output_),
+        q(q_),
+        q_qvals(q_qvals_),
+        q_scales(q_scales_),
+        q_zeros(q_zeros_),
+        k(k_),
+        k_qvals(k_qvals_),
+        k_scales(k_scales_),
+        k_zeros(k_zeros_) {
+    assert(expected_output.size() == b * s_q * h * s_k);
+    assert(q.size() == b * s_q * h * d);
+    assert(q_qvals.size() == b * s_q * h * d);
+    assert(q_scales.size() == b * s_q * h);
+    assert(q_zeros.size() == b * s_q * h);
+    assert(k.size() == b * s_k * h * d);
+    assert(k_qvals.size() == b * s_k * h * d);
+    assert(k_scales.size() == b * s_k * h);
+    assert(k_zeros.size() == b * s_k * h);
+  }
+
+  static channelwise_8bit_a_channelwise_8bit_b_q_at_k_attention_test_case
+  generate(int b, int s_q, int s_k, int h, int d, bool transposed = true) {
+    // Generate activations
+    auto [lhs, lhs_qvals, lhs_scales, lhs_zeros] =
+        torchao::test_utils::generate_per_token_quantized_tensor(
+            b * s_q * h, d);
+
+    auto [rhs, rhs_qvals, rhs_scales, rhs_zeros] =
+        torchao::test_utils::generate_per_token_quantized_tensor(
+            b * s_k * h, d);
+    // Q and K are both generated as (b * s * h) rows of length d, quantized
+    // per row (per token). The transposed flag below only changes the
+    // strides used to index those rows; it does not change their storage.
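+
+    // Two layouts are exercised:
+    //   transposed = true:  Q/K viewed as [b, h, s, d] -- each head's s x d
+    //                       block is contiguous.
+    //   transposed = false: Q/K viewed as [b, s, h, d] -- heads interleave
+    //                       across the sequence, so one head's rows sit
+    //                       h * d elements apart.
+    // Per-token qparams (one scale/zero per row) follow the same scheme.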
+
+    size_t b_q_stride = h * s_q * d;
+    size_t h_q_stride = s_q * d;
+    size_t s_q_stride = d;
+
+    size_t b_k_stride = h * s_k * d;
+    size_t h_k_stride = s_k * d;
+    size_t s_k_stride = d;
+
+    size_t b_q_qparams_stride = h * s_q;
+    size_t h_q_qparams_stride = s_q;
+    size_t s_q_qparams_stride = 1;
+
+    size_t b_k_qparams_stride = h * s_k;
+    size_t h_k_qparams_stride = s_k;
+    size_t s_k_qparams_stride = 1;
+
+    if (!transposed) {
+      h_q_stride = d;
+      s_q_stride = h * d;
+      h_k_stride = d;
+      s_k_stride = h * d;
+
+      s_q_qparams_stride = h;
+      h_q_qparams_stride = 1;
+
+      s_k_qparams_stride = h;
+      h_k_qparams_stride = 1;
+    }
+
+    // Compute expected output
+    std::vector<float> expected_output(b * h * s_q * s_k);
+    size_t b_out_stride = h * s_q * s_k;
+    size_t h_out_stride = s_q * s_k;
+    size_t s_q_out_stride = s_k;
+
+    for (int b_idx = 0; b_idx < b; b_idx++) {
+      for (int s_q_idx = 0; s_q_idx < s_q; s_q_idx++) {
+        for (int h_idx = 0; h_idx < h; h_idx++) {
+          for (int s_k_idx = 0; s_k_idx < s_k; s_k_idx++) {
+            float res = 0.0;
+            for (int d_idx = 0; d_idx < d; d_idx++) {
+              int lhs_idx = b_idx * b_q_stride + s_q_idx * s_q_stride +
+                  h_idx * h_q_stride + d_idx;
+              int rhs_idx = b_idx * b_k_stride + s_k_idx * s_k_stride +
+                  h_idx * h_k_stride + d_idx;
+              int lhs_scales_zp_idx = b_idx * b_q_qparams_stride +
+                  h_idx * h_q_qparams_stride + s_q_idx * s_q_qparams_stride;
+              int rhs_scales_zp_idx = b_idx * b_k_qparams_stride +
+                  h_idx * h_k_qparams_stride + s_k_idx * s_k_qparams_stride;
+              float lhs_dequant = lhs_scales[lhs_scales_zp_idx] *
+                  (lhs_qvals[lhs_idx] - lhs_zeros[lhs_scales_zp_idx]);
+
+              float rhs_dequant = rhs_scales[rhs_scales_zp_idx] *
+                  (rhs_qvals[rhs_idx] - rhs_zeros[rhs_scales_zp_idx]);
+
+              res += lhs_dequant * rhs_dequant;
+            }
+            expected_output
+                [b_idx * b_out_stride + s_q_idx * s_q_out_stride +
+                 h_idx * h_out_stride + s_k_idx] = res;
+          }
+        }
+      }
+    }
+
+    // Return test case
+    return channelwise_8bit_a_channelwise_8bit_b_q_at_k_attention_test_case(
+        b,
+        s_q,
+        s_k,
+        h,
+        d,
+        transposed,
+        b_q_stride,
+        h_q_stride,
+        s_q_stride,
+        b_k_stride,
+        h_k_stride,
+        s_k_stride,
+        b_q_qparams_stride,
+        h_q_qparams_stride,
+        s_q_qparams_stride,
+        b_k_qparams_stride,
+        h_k_qparams_stride,
+        s_k_qparams_stride,
+        expected_output,
+        lhs,
+        lhs_qvals,
+        lhs_scales,
+        lhs_zeros,
+        rhs,
+        rhs_qvals,
+        rhs_scales,
+        rhs_zeros);
+  }
+};
+
+} // namespace torchao
+
+#endif // defined(__aarch64__) || defined(__ARM_NEON)
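A note on the arithmetic under test: the reference above dequantizes Q and K
elementwise and accumulates in float, while a channelwise-8bit kernel such as
channelwise_8bit_a_channelwise_8bit_b_1x8x16_f32_neondot can keep the inner
loop in int8/int32 dot products and fold each row pair's scales and zero
points in once at the end; the two computations are algebraically identical.
The standalone sketch below demonstrates the identity (that the kernel
factors the math exactly this way is an assumption):

// Identity: sum_d (sa*(qa[d]-za)) * (sb*(qb[d]-zb))
//         = sa*sb * (sum qa*qb - zb*sum qa - za*sum qb + d*za*zb)
#include <cstdint>
#include <cstdio>

// Float reference, shaped like the inner loop of generate() above.
float dot_dequant_reference(
    const int8_t* qa, const int8_t* qb, int d,
    float sa, int8_t za, float sb, int8_t zb) {
  float acc = 0.0f;
  for (int i = 0; i < d; i++) {
    acc += (sa * (qa[i] - za)) * (sb * (qb[i] - zb));
  }
  return acc;
}

// Same result via integer accumulation; scales/zeros applied once at the end.
float dot_integer_accumulate(
    const int8_t* qa, const int8_t* qb, int d,
    float sa, int8_t za, float sb, int8_t zb) {
  int32_t qaqb = 0, sum_a = 0, sum_b = 0;
  for (int i = 0; i < d; i++) {
    qaqb += static_cast<int32_t>(qa[i]) * qb[i];
    sum_a += qa[i];
    sum_b += qb[i];
  }
  int32_t corrected = qaqb - static_cast<int32_t>(zb) * sum_a -
      static_cast<int32_t>(za) * sum_b +
      d * static_cast<int32_t>(za) * static_cast<int32_t>(zb);
  return sa * sb * static_cast<float>(corrected);
}

int main() {
  const int8_t qa[] = {12, -7, 100, -128, 55, 0, 33, -1};
  const int8_t qb[] = {-45, 90, 3, 17, -128, 127, 8, 4};
  float r0 = dot_dequant_reference(qa, qb, 8, 0.02f, 3, 0.05f, -9);
  float r1 = dot_integer_accumulate(qa, qb, 8, 0.02f, 3, 0.05f, -9);
  std::printf("reference=%f integer=%f\n", r0, r1); // match up to fp rounding
  return 0;
}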