diff --git a/torchao/experimental/kernels/cpu/fallback/matmul/fp32_a_channelwise_8bit_b_fp32_c.h b/torchao/experimental/kernels/cpu/fallback/matmul/fp32_a_channelwise_8bit_b_fp32_c.h
new file mode 100644
index 0000000000..58e2853617
--- /dev/null
+++ b/torchao/experimental/kernels/cpu/fallback/matmul/fp32_a_channelwise_8bit_b_fp32_c.h
@@ -0,0 +1,51 @@
+// Copyright (c) Meta Platforms, Inc. and affiliates.
+// All rights reserved.
+//
+// This source code is licensed under the license found in the
+// LICENSE file in the root directory of this source tree.
+
+#pragma once
+
+#include <cassert>
+#include <cstdint>
+
+// TODO: Remove all ::kernels. No need for extra namespace.
+namespace torchao::kernels::cpu::fallback::quantized_matmul {
+namespace fp32_a_input_channelwise_8bit_b_fp32 {
+template <bool a_transposed, bool b_transposed>
+void kernel(
+    int m,
+    int n,
+    int k,
+    const float* lhs,
+    int lhs_stride_m,
+    const int8_t* rhs,
+    int rhs_stride_n,
+    float* output,
+    int out_stride_m,
+    const int8_t* rhs_zero_points,
+    const float* rhs_scales,
+    const float beta,
+    const int rhs_qparams_stride) {
+  assert(a_transposed == false);
+  for (int m_idx = 0; m_idx < m; m_idx++) {
+    for (int n_idx = 0; n_idx < n; n_idx++) {
+      float res = 0.0;
+      for (int k_idx = 0; k_idx < k; k_idx++) {
+        int lhs_idx = m_idx * lhs_stride_m + k_idx;
+        int rhs_idx = k_idx * rhs_stride_n + n_idx;
+        if (b_transposed) {
+          rhs_idx = n_idx * rhs_stride_n + k_idx;
+        }
+        float rhs_dequant = rhs_scales[k_idx * rhs_qparams_stride] *
+            (static_cast<float>(rhs[rhs_idx]) -
+             static_cast<float>(rhs_zero_points[k_idx * rhs_qparams_stride]));
+
+        res += lhs[lhs_idx] * rhs_dequant;
+      }
+      output[m_idx * out_stride_m + n_idx] = output[m_idx * out_stride_m + n_idx] * beta + res;
+    }
+  }
+}
+} // namespace fp32_a_input_channelwise_8bit_b_fp32
+} // namespace torchao::kernels::cpu::fallback::quantized_matmul
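(Not part of the patch.) The fallback above is a plain dequantize-and-accumulate reference: for each output element it walks the k dimension, dequantizes B with the per-input-channel scale and zero point, and blends with the existing output via `beta`. A minimal sketch of driving it directly, assuming the torchao source root is on the include path and the `<a_transposed, b_transposed>` template parameters shown above:

```cpp
#include <cstdint>
#include <vector>

#include <torchao/experimental/kernels/cpu/fallback/matmul/fp32_a_channelwise_8bit_b_fp32_c.h>

int main() {
  // 1x2 fp32 activation times 2x3 int8 weight, with one scale/zero point per
  // input channel (k = 2).
  std::vector<float> lhs = {1.0f, 2.0f};               // m = 1, k = 2
  std::vector<int8_t> rhs = {10, 20, 30, 40, 50, 60};  // k = 2, n = 3, row-major
  std::vector<float> scales = {0.1f, 0.2f};            // one per input channel
  std::vector<int8_t> zeros = {0, 5};                  // one per input channel
  std::vector<float> out(3, 0.0f);

  torchao::kernels::cpu::fallback::quantized_matmul::
      fp32_a_input_channelwise_8bit_b_fp32::kernel<false, false>(
          /*m=*/1, /*n=*/3, /*k=*/2,
          lhs.data(), /*lhs_stride_m=*/2,
          rhs.data(), /*rhs_stride_n=*/3,
          out.data(), /*out_stride_m=*/3,
          zeros.data(), scales.data(),
          /*beta=*/0.0f, /*rhs_qparams_stride=*/1);
  // out[0] = 1*0.1*(10-0) + 2*0.2*(40-5) = 1.0 + 14.0 = 15.0, and so on.
  return 0;
}
```

Here `rhs_qparams_stride` is 1 because the scales and zero points are densely packed, one per input channel.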
diff --git a/torchao/experimental/kernels/cpu/interface/quantized_matmul.h b/torchao/experimental/kernels/cpu/interface/quantized_matmul.h
index 01a4c704c5..718f7eaad9 100644
--- a/torchao/experimental/kernels/cpu/interface/quantized_matmul.h
+++ b/torchao/experimental/kernels/cpu/interface/quantized_matmul.h
@@ -9,6 +9,8 @@
 #include
 #include

+#include <torchao/experimental/kernels/cpu/fallback/matmul/fp32_a_channelwise_8bit_b_fp32_c.h>
+
 #if defined(__aarch64__) || defined(__ARM_NEON)
 #include
 #include
@@ -85,4 +87,72 @@ get_int8_a_int8_b_channelwise_qmatmul(
        channelwise_8bit_a_channelwise_8bit_b::kernel;
   }
 }
+
+/*
+a_stride_m: stride of a in memory, i.e. how many elements apart consecutive
+rows of a are.
+b_stride_n: stride of b in memory, i.e. how many elements apart consecutive
+rows of b are. If b is transposed (n x k), this is the number of elements to
+skip to get to the next row of the n x k matrix; if b is not transposed
+(k x n), it is the number of elements to skip to get to the next row of the
+k x n matrix.
+
+The getter also returns, via a_stride_m and b_stride_n, the strides of a and b
+that should be passed to the returned kernel.
+
+TODO: Find a better way to pick the right ukernel, perhaps via a ukernel
+config + registry.
+*/
+using fp32_a_input_channelwise_8bit_b_f32_c_matmul_type = void (*)(
+    int,
+    int,
+    int,
+    const float*,
+    int,
+    const int8_t*,
+    int,
+    float*,
+    int,
+    const int8_t*,
+    const float*,
+    const float,
+    const int);
+
+fp32_a_input_channelwise_8bit_b_f32_c_matmul_type
+get_fp32_a_input_channelwise_8bit_b_f32_c_matmul(
+    int m,
+    int n,
+    int k,
+    bool a_transposed,
+    bool b_transposed,
+    int& a_stride_m,
+    int& b_stride_n);
+
+fp32_a_input_channelwise_8bit_b_f32_c_matmul_type
+get_fp32_a_input_channelwise_8bit_b_f32_c_matmul(
+    int m,
+    int n,
+    int k,
+    bool a_transposed,
+    bool b_transposed,
+    int& a_stride_m,
+    int& b_stride_n) {
+#if defined(__aarch64__) || defined(__ARM_NEON)
+  if (!a_transposed && !b_transposed && n >= 16) {
+    a_stride_m = k;
+    b_stride_n = n;
+    return aarch64::quantized_matmul::
+        fp32_a_input_channelwise_8bit_b_1x16x4_f32::kernel;
+  }
+#endif // defined(__aarch64__) || defined(__ARM_NEON)
+  assert(!a_transposed);
+  if (b_transposed) {
+    a_stride_m = k;
+    b_stride_n = k;
+    return torchao::kernels::cpu::fallback::quantized_matmul::
+        fp32_a_input_channelwise_8bit_b_fp32::kernel<false, true>;
+  } else {
+    a_stride_m = k;
+    b_stride_n = n;
+    return torchao::kernels::cpu::fallback::quantized_matmul::
+        fp32_a_input_channelwise_8bit_b_fp32::kernel<false, false>;
+  }
+}
 } // namespace torchao::kernels::cpu::quantized_matmul
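(Not part of the patch.) A hedged usage sketch of the new getter: it returns both the ukernel and, through the out-parameters, the strides the caller should pass to it. The wrapper name `fp32_x_int8_matmul` is hypothetical, and the sketch assumes a dense, non-transposed k x n B with one scale/zero point per input channel:

```cpp
#include <cstdint>

#include <torchao/experimental/kernels/cpu/interface/quantized_matmul.h>

// Computes out = beta * out + lhs (m x k, fp32) * dequant(rhs) (k x n, int8
// with per-input-channel scales/zero points).
void fp32_x_int8_matmul(
    int m, int n, int k,
    const float* lhs,
    const int8_t* rhs_qvals,
    const int8_t* rhs_zeros,
    const float* rhs_scales,
    float* out,
    float beta) {
  int a_stride_m = 0;
  int b_stride_n = 0;
  // Pick the best available ukernel for this shape and get the strides to use.
  auto kernel = torchao::kernels::cpu::quantized_matmul::
      get_fp32_a_input_channelwise_8bit_b_f32_c_matmul(
          m, n, k, /*a_transposed=*/false, /*b_transposed=*/false,
          a_stride_m, b_stride_n);
  kernel(
      m, n, k,
      lhs, a_stride_m,
      rhs_qvals, b_stride_n,
      out, /*out_stride_m=*/n,
      rhs_zeros, rhs_scales,
      beta, /*rhs_qparams_stride=*/1);
}
```

On aarch64 the getter dispatches to the NEON 1x16x4 ukernel when neither operand is transposed and n >= 16; all other shapes fall through to the reference kernel added in this patch, so callers do not have to branch on shape themselves.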
diff --git a/torchao/experimental/kernels/cpu/interface/test_qmatmul_interface.cpp b/torchao/experimental/kernels/cpu/interface/test_qmatmul_interface.cpp
index 3629f0960b..4024f3f1de 100644
--- a/torchao/experimental/kernels/cpu/interface/test_qmatmul_interface.cpp
+++ b/torchao/experimental/kernels/cpu/interface/test_qmatmul_interface.cpp
@@ -446,3 +446,185 @@ TEST(
   Run(
       /*m=*/4, /*k=*/2, /*n=*/1, 32);
 }
+
+class FP32A_QuantizedB_FP32C_Interface_Test
+    : public ::testing::TestWithParam<float> {
+ public:
+  int m;
+  int k;
+  int n;
+  int stride;
+
+  bool rhs_has_zeros;
+  bool lhs_is_transposed;
+  bool rhs_is_transposed;
+
+  std::vector<float> init_output;
+  std::vector<float> expected_output;
+
+  std::vector<float> lhs;
+
+  std::vector<float> rhs;
+  std::vector<int8_t> rhs_qvals;
+  std::vector<float> rhs_scales;
+  std::vector<int8_t> rhs_zeros;
+
+  void generate(
+      int m_,
+      int k_,
+      int n_,
+      bool rhs_has_zeros_,
+      bool lhs_is_transposed_,
+      bool rhs_is_transposed_,
+      int stride_ = 1) {
+    assert(!lhs_is_transposed_);
+    assert(rhs_has_zeros_);
+    m = m_;
+    k = k_;
+    n = n_;
+    stride = stride_;
+    rhs_has_zeros = rhs_has_zeros_;
+    lhs_is_transposed = lhs_is_transposed_;
+    rhs_is_transposed = rhs_is_transposed_;
+
+    assert(!rhs_is_transposed || stride == 1);
+
+    // Generate activations
+    lhs = get_random_vector(m * k, -1.0, 1.0);
+
+    // Note: instead of quantizing each output channel separately, we quantize
+    // each input channel. We generate a k x n matrix (not n x k) when rhs is
+    // not transposed, because each input channel is quantized separately.
+    std::tie(rhs, rhs_qvals, rhs_scales, rhs_zeros) =
+        generate_per_token_quantized_tensor(k * stride, n, rhs_is_transposed);
+
+    // Compute expected output
+    init_output = get_random_vector(m * n, -1.0, 1.0);
+
+    assert(init_output.size() == m * n);
+    assert(lhs.size() == m * k);
+    assert(rhs.size() == n * stride * k);
+    assert(rhs_qvals.size() == n * stride * k);
+    assert(rhs_scales.size() == k * stride);
+    assert(rhs_zeros.size() == k * stride);
+  }
+
+  void execute(float beta) {
+    // Compute expected output
+    expected_output = init_output;
+
+    for (int m_idx = 0; m_idx < m; m_idx++) {
+      for (int n_idx = 0; n_idx < n; n_idx++) {
+        float res = 0.0;
+        for (int k_idx = 0; k_idx < k; k_idx++) {
+          int lhs_idx = m_idx * k + k_idx;
+          int rhs_idx = k_idx * stride * n + n_idx;
+          if (rhs_is_transposed) {
+            rhs_idx = n_idx * k * stride + k_idx * stride;
+          }
+          float rhs_dequant = rhs_scales[k_idx * stride] *
+              (static_cast<float>(rhs_qvals[rhs_idx]) -
+               static_cast<float>(rhs_zeros[k_idx * stride]));
+
+          res += lhs[lhs_idx] * rhs_dequant;
+        }
+        expected_output[m_idx * n + n_idx] =
+            expected_output[m_idx * n + n_idx] * beta + res;
+      }
+    }
+  }
+
+  float beta() const {
+    return GetParam();
+  }
+};
+
+static void test_fp32_a_input_channelwise_8bit_b(
+    int m,
+    int k,
+    int n,
+    float beta,
+    FP32A_QuantizedB_FP32C_Interface_Test& test_case,
+    int stride = 1) {
+  test_case.execute(beta);
+
+  int a_stride_m, b_stride_n;
+  auto kernel = torchao::kernels::cpu::quantized_matmul::
+      get_fp32_a_input_channelwise_8bit_b_f32_c_matmul(
+          m, n, k, false, false, a_stride_m, b_stride_n);
+  b_stride_n = b_stride_n * stride;
+
+  std::vector<float> output(test_case.init_output);
+  kernel(
+      m,
+      n,
+      k,
+      test_case.lhs.data(),
+      a_stride_m /*lhs_stride_m*/,
+      test_case.rhs_qvals.data(),
+      b_stride_n /*rhs_stride_n*/,
+      output.data(),
+      n /*out_stride_m*/,
+      test_case.rhs_zeros.data(),
+      test_case.rhs_scales.data(),
+      beta,
+      stride /*rhs qparams stride*/);
+
+  for (int i = 0; i < m * n; i++) {
+    EXPECT_NEAR(output[i], test_case.expected_output[i], kTol);
+  }
+}
+
+TEST_P(FP32A_QuantizedB_FP32C_Interface_Test, BTranposedWithZeroPoints) {
+  generate(3, 128, 16, true, false, false);
+  test_fp32_a_input_channelwise_8bit_b(
+      /*m=*/3, /*k=*/128, /*n=*/16, beta(), *this);
+}
+
+TEST_P(
+    FP32A_QuantizedB_FP32C_Interface_Test,
+    BTranposedWithZeroPointsOddSizes) {
+  generate(4, 37, 19, true, false, false);
+  test_fp32_a_input_channelwise_8bit_b(
+      /*m=*/4, /*k=*/37, /*n=*/19, beta(), *this);
+}
+
+// Test shapes for which we have to use the fallback kernel
+TEST_P(
+    FP32A_QuantizedB_FP32C_Interface_Test,
+    BTranposedWithZeroPointsOddSizesFallback) {
+  generate(4, 37, 3, true, false, false);
+  test_fp32_a_input_channelwise_8bit_b(
+      /*m=*/4, /*k=*/37, /*n=*/3, beta(), *this);
+}
+
+TEST_P(
+    FP32A_QuantizedB_FP32C_Interface_Test,
+    BTranposedWithZeroPointsOddSizes2Fallback) {
+  generate(4, 1, 3, true, false, false);
+  test_fp32_a_input_channelwise_8bit_b(
+      /*m=*/4, /*k=*/1, /*n=*/3, beta(), *this);
+}
+
+TEST_P(
+    FP32A_QuantizedB_FP32C_Interface_Test,
+    BTranposedWithZeroPointsOddSizesStrided) {
+  generate(4, 37, 19, true, false, false, 32);
+  test_fp32_a_input_channelwise_8bit_b(
+      /*m=*/4, /*k=*/37, /*n=*/19, beta(), *this, 32);
+}
+
+TEST_P(
+    FP32A_QuantizedB_FP32C_Interface_Test,
+    BTranposedWithZeroPointsOddSizes2FallbackStrided) {
+  generate(4, 5, 3, true, false, false, 32);
+  test_fp32_a_input_channelwise_8bit_b(
+      /*m=*/4, /*k=*/5, /*n=*/3, beta(), *this, 32);
+}
+
+INSTANTIATE_TEST_SUITE_P(
+    F32AInt8BFP32CTest,
+    FP32A_QuantizedB_FP32C_Interface_Test,
+    ::testing::Values(0.0, 1.0, 3.1));
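(Not part of the patch.) One subtlety the `stride = 32` tests exercise: `b_stride_n` and `rhs_qparams_stride` let the quantized B and its per-input-channel qparams live in padded buffers, with the kernel reading only every stride-th row/entry. A small illustrative helper, with a hypothetical name, that mirrors the indexing used by the kernel and the test reference:

```cpp
#include <cstdint>

// Illustrative only: dequantize one entry of B when B and its per-input-channel
// qparams are stored in stride-padded buffers. The kernel sees this layout when
// called with b_stride_n = n * stride and rhs_qparams_stride = stride.
inline float dequant_b_entry(
    const int8_t* rhs_qvals,
    const float* rhs_scales,
    const int8_t* rhs_zeros,
    int n,
    int stride,
    int k_idx,
    int n_idx) {
  int rhs_idx = k_idx * (n * stride) + n_idx;  // row stride of padded B
  float scale = rhs_scales[k_idx * stride];    // qparams for input channel k_idx
  float zero = static_cast<float>(rhs_zeros[k_idx * stride]);
  return scale * (static_cast<float>(rhs_qvals[rhs_idx]) - zero);
}
```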