From a5939a3ff43b2b17f12f9a1566be14437b506428 Mon Sep 17 00:00:00 2001 From: Scott Roy <161522778+metascroy@users.noreply.github.com> Date: Sat, 6 Sep 2025 10:22:32 -0700 Subject: [PATCH] Delete copy of quantized SDPA in torchao/experimental (#2952) Summary: Pull Request resolved: https://github.com/pytorch/ao/pull/2952 This deletes the SDPA/matmul kernels in torchao/experimental. These kernels were already migrated to torchao/csrc/cpu/torch_free_kernels in the first diff, but we needed to keep them around in torchao/experimental to do a pin bump in ExecuTorch OSS after the first diff lands. X-link: https://github.com/pytorch/executorch/pull/13924 Test Plan: Imported from GitHub, without a `Test Plan:` line. Rollback Plan: Reviewed By: kimishpatel Differential Revision: D81640227 Pulled By: metascroy --- .../experimental/kernels/cpu/aarch64/macro.h | 9 - ...hannelwise_8bit_b_1x16x16_f32_smlal-impl.h | 384 ---------- ...annelwise_8bit_b_1x8x16_f32_neondot-impl.h | 340 --------- ...hannelwise_8bit_b_4x8x8_f32_neondot-impl.h | 411 ----------- ...input_channelwise_8bit_b_1x16x4_f32_impl.h | 281 -------- ...input_channelwise_8bit_b_4x16x4_f32_impl.h | 328 --------- .../kernels/cpu/aarch64/matmul/matmul.h | 318 --------- .../kernels/cpu/aarch64/matmul/matmul_utils.h | 153 ---- .../channelwise_8bit_a_channelwise_8bit_b.h | 133 ---- .../matmul/fp32_a_channelwise_8bit_b_fp32_c.h | 50 -- .../kernels/cpu/interface/quantized_matmul.h | 156 ----- .../cpu/interface/test_qmatmul_interface.cpp | 658 ------------------ 12 files changed, 3221 deletions(-) delete mode 100644 torchao/experimental/kernels/cpu/aarch64/macro.h delete mode 100644 torchao/experimental/kernels/cpu/aarch64/matmul/channelwise_8bit_a_channelwise_8bit_b_1x16x16_f32_smlal-impl.h delete mode 100644 torchao/experimental/kernels/cpu/aarch64/matmul/channelwise_8bit_a_channelwise_8bit_b_1x8x16_f32_neondot-impl.h delete mode 100644 torchao/experimental/kernels/cpu/aarch64/matmul/channelwise_8bit_a_channelwise_8bit_b_4x8x8_f32_neondot-impl.h delete mode 100644 torchao/experimental/kernels/cpu/aarch64/matmul/fp32_a_input_channelwise_8bit_b_1x16x4_f32_impl.h delete mode 100644 torchao/experimental/kernels/cpu/aarch64/matmul/fp32_a_input_channelwise_8bit_b_4x16x4_f32_impl.h delete mode 100644 torchao/experimental/kernels/cpu/aarch64/matmul/matmul.h delete mode 100644 torchao/experimental/kernels/cpu/aarch64/matmul/matmul_utils.h delete mode 100644 torchao/experimental/kernels/cpu/fallback/matmul/channelwise_8bit_a_channelwise_8bit_b.h delete mode 100644 torchao/experimental/kernels/cpu/fallback/matmul/fp32_a_channelwise_8bit_b_fp32_c.h delete mode 100644 torchao/experimental/kernels/cpu/interface/quantized_matmul.h delete mode 100644 torchao/experimental/kernels/cpu/interface/test_qmatmul_interface.cpp diff --git a/torchao/experimental/kernels/cpu/aarch64/macro.h b/torchao/experimental/kernels/cpu/aarch64/macro.h deleted file mode 100644 index 4861edbee7..0000000000 --- a/torchao/experimental/kernels/cpu/aarch64/macro.h +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) Meta Platforms, Inc. and affiliates. -// All rights reserved. -// -// This source code is licensed under the license found in the -// LICENSE file in the root directory of this source tree. 
- -#pragma once - -#define TORCHAO_ALWAYS_INLINE __attribute__((always_inline)) diff --git a/torchao/experimental/kernels/cpu/aarch64/matmul/channelwise_8bit_a_channelwise_8bit_b_1x16x16_f32_smlal-impl.h b/torchao/experimental/kernels/cpu/aarch64/matmul/channelwise_8bit_a_channelwise_8bit_b_1x16x16_f32_smlal-impl.h deleted file mode 100644 index 5ed3b686fd..0000000000 --- a/torchao/experimental/kernels/cpu/aarch64/matmul/channelwise_8bit_a_channelwise_8bit_b_1x16x16_f32_smlal-impl.h +++ /dev/null @@ -1,384 +0,0 @@ -// Copyright (c) Meta Platforms, Inc. and affiliates. -// All rights reserved. -// -// This source code is licensed under the license found in the -// LICENSE file in the root directory of this source tree. - -#pragma once - -#if defined(__aarch64__) && defined(__ARM_NEON) - -#include -#include -#include - -#include -#include -#include - -namespace torchao::kernels::cpu::aarch64::quantized_matmul { -namespace channelwise_8bit_a_channelwise_8bit_b_1x16x16_f32_smlal::internal { - -namespace { -/* -This function loads int8x16_t value from a, and 8 int8x16_t values from b. -For each int8x16_t of b: -- subl to subtarct a_zero_point from a, to get a_low, a_high -- 4 int32x4 accumulated values -- for i in [0, 8]: - - load b[i] - - subl to subtarct b_zero_point from b, to get b_low, b_high - - smlal_lane to multiply a_low[i] and b_low_low. - - smlal_lane to multiply a_low[i] and b_low_high. - - smlal_lane to multiply a_low[i] and b_high_low. - - smlal_lane to multiply a_low[i] and b_high_high. - - This produces 2 int32x4_t values -- for i in [0, 8]: - - load b[i] - - subl to subtarct b_zero_point from b, to get b_low, b_high - - smlal_lane to multiply a_low[i] and b_low_low. - - smlal_lane to multiply a_low[i] and b_low_high. - - smlal_lane to multiply a_low[i] and b_high_low. - - smlal_lane to multiply a_low[i] and b_high_high. - - This produces 2 int32x4_t values -Possibly better to transpose 16x16 of b and use dotprod. Left for future. 
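[Aside, not part of the original patch] A scalar reference of the math this 1x16x16 smlal kernel vectorizes can make the NEON path easier to check. The function name and the row-major k x n layout of b below are illustrative assumptions, not code from the deleted file:

#include <cstdint>
#include <cstddef>

// Reference (non-NEON) version of one 1 x n row of the channelwise
// int8 x int8 -> fp32 matmul: subtract per-row/per-column zero points,
// accumulate in int32, then scale by lhs_scale * rhs_scale.
inline void reference_row_qmatmul(
    const int8_t* a,             // 1 x k row of lhs
    const int8_t* b,             // k x n, row-major (ldb elements per k row)
    size_t k, size_t n, size_t ldb,
    int8_t a_zero_point,
    const int8_t* b_zero_points, // per output channel (n of them)
    float a_scale,
    const float* b_scales,       // per output channel (n of them)
    float* out) {                // 1 x n
  for (size_t j = 0; j < n; ++j) {
    int32_t acc = 0;
    for (size_t i = 0; i < k; ++i) {
      acc += (int32_t(a[i]) - a_zero_point) *
             (int32_t(b[i * ldb + j]) - b_zero_points[j]);
    }
    out[j] = float(acc) * a_scale * b_scales[j];
  }
}

The kernel below computes the same thing, but widens with vsubl/vmlal and processes 16 output channels per iteration.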
-*/ - -template -TORCHAO_ALWAYS_INLINE inline void block_mul_1x16x1( - const int16x4_t& a_vec, - const int8x16_t& b_vec, - const int8x16_t& b_zero_point_vec, - int32x4_t (&partial_sums)[4]) { - int16x8_t b_vec_low = - vsubl_s8(vget_low_s8(b_vec), vget_low_s8(b_zero_point_vec)); - int16x8_t b_vec_high = - vsubl_s8(vget_high_s8(b_vec), vget_high_s8(b_zero_point_vec)); - partial_sums[0] = - vmlal_lane_s16(partial_sums[0], vget_low_s16(b_vec_low), a_vec, lane); - partial_sums[1] = - vmlal_lane_s16(partial_sums[1], vget_high_s16(b_vec_low), a_vec, lane); - partial_sums[2] = - vmlal_lane_s16(partial_sums[2], vget_low_s16(b_vec_high), a_vec, lane); - partial_sums[3] = - vmlal_lane_s16(partial_sums[3], vget_high_s16(b_vec_high), a_vec, lane); -} - -void block_mul_1x16x16( - const int8_t* a, - const int8_t* b, - const size_t ldb, - const int8_t a_zero_point, - const int8_t* b_zero_point, - int32x4_t (&partial_sums)[4]) { - int8x16_t a_vec = vld1q_s8(a); - int8x8_t a_zero_point_vec = vdup_n_s8(a_zero_point); - int8x16_t b_zero_point_vec = vld1q_s8(b_zero_point); - int16x8_t a_vec_low = vsubl_s8(vget_low_s8(a_vec), a_zero_point_vec); - int16x8_t a_vec_high = vsubl_s8(vget_high_s8(a_vec), a_zero_point_vec); - - int8x16_t b_vec = vld1q_s8(b + 0 * ldb); - block_mul_1x16x1<0>( - vget_low_s16(a_vec_low), b_vec, b_zero_point_vec, partial_sums); - b_vec = vld1q_s8(b + 1 * ldb); - block_mul_1x16x1<1>( - vget_low_s16(a_vec_low), b_vec, b_zero_point_vec, partial_sums); - b_vec = vld1q_s8(b + 2 * ldb); - block_mul_1x16x1<2>( - vget_low_s16(a_vec_low), b_vec, b_zero_point_vec, partial_sums); - b_vec = vld1q_s8(b + 3 * ldb); - block_mul_1x16x1<3>( - vget_low_s16(a_vec_low), b_vec, b_zero_point_vec, partial_sums); - b_vec = vld1q_s8(b + 4 * ldb); - block_mul_1x16x1<0>( - vget_high_s16(a_vec_low), b_vec, b_zero_point_vec, partial_sums); - b_vec = vld1q_s8(b + 5 * ldb); - block_mul_1x16x1<1>( - vget_high_s16(a_vec_low), b_vec, b_zero_point_vec, partial_sums); - b_vec = vld1q_s8(b + 6 * ldb); - block_mul_1x16x1<2>( - vget_high_s16(a_vec_low), b_vec, b_zero_point_vec, partial_sums); - b_vec = vld1q_s8(b + 7 * ldb); - block_mul_1x16x1<3>( - vget_high_s16(a_vec_low), b_vec, b_zero_point_vec, partial_sums); - - // Second set of 8 channels - b_vec = vld1q_s8(b + 8 * ldb); - block_mul_1x16x1<0>( - vget_low_s16(a_vec_high), b_vec, b_zero_point_vec, partial_sums); - b_vec = vld1q_s8(b + 9 * ldb); - block_mul_1x16x1<1>( - vget_low_s16(a_vec_high), b_vec, b_zero_point_vec, partial_sums); - b_vec = vld1q_s8(b + 10 * ldb); - block_mul_1x16x1<2>( - vget_low_s16(a_vec_high), b_vec, b_zero_point_vec, partial_sums); - b_vec = vld1q_s8(b + 11 * ldb); - block_mul_1x16x1<3>( - vget_low_s16(a_vec_high), b_vec, b_zero_point_vec, partial_sums); - b_vec = vld1q_s8(b + 12 * ldb); - block_mul_1x16x1<0>( - vget_high_s16(a_vec_high), b_vec, b_zero_point_vec, partial_sums); - b_vec = vld1q_s8(b + 13 * ldb); - block_mul_1x16x1<1>( - vget_high_s16(a_vec_high), b_vec, b_zero_point_vec, partial_sums); - b_vec = vld1q_s8(b + 14 * ldb); - block_mul_1x16x1<2>( - vget_high_s16(a_vec_high), b_vec, b_zero_point_vec, partial_sums); - b_vec = vld1q_s8(b + 15 * ldb); - block_mul_1x16x1<3>( - vget_high_s16(a_vec_high), b_vec, b_zero_point_vec, partial_sums); -} - -TORCHAO_ALWAYS_INLINE inline void dequantize_1x16_int32_t( - const int32x4_t (&sums)[4], - const float* lhs_scales, - const float* rhs_scales, - float32x4_t (&outputs)[4]) { - float32x4_t scales_0123 = vmulq_n_f32(vld1q_f32(rhs_scales), lhs_scales[0]); - float32x4_t scales_4567 = - 
vmulq_n_f32(vld1q_f32(rhs_scales + 4), lhs_scales[0]); - float32x4_t scales_89ab = - vmulq_n_f32(vld1q_f32(rhs_scales + 8), lhs_scales[0]); - float32x4_t scales_cdef = - vmulq_n_f32(vld1q_f32(rhs_scales + 12), lhs_scales[0]); - - outputs[0] = vmulq_f32(vcvtq_f32_s32(sums[0]), scales_0123); - outputs[1] = vmulq_f32(vcvtq_f32_s32(sums[1]), scales_4567); - outputs[2] = vmulq_f32(vcvtq_f32_s32(sums[2]), scales_89ab); - outputs[3] = vmulq_f32(vcvtq_f32_s32(sums[3]), scales_cdef); -} - -template < - bool a_has_zeros, - bool b_has_zeros, - bool a_transposed, - bool b_transposed> -struct KernelImpl { - static void run( - int m, - int n, - int k, - const void* lhs, - int lhs_stride_m, - const void* rhs, - int rhs_stride_n, - float32_t* output, - int out_stride_m, - const int8_t* lhs_zero_points, - const int8_t* rhs_zero_points, - const float* lhs_scales, - const float* rhs_scales, - const int lhs_qparams_stride, - const int rhs_qparams_stride); -}; - -template <> -struct KernelImpl { - /** - * @brief Implements quantized matrix multiplication for 8-bit channelwise - * quantized matrices - * - * This specialized implementation handles the case where: - * - Both LHS and RHS have zero points (true, true) - * - Neither LHS nor RHS are transposed (false, false) - * - * The function performs a quantized matrix multiplication C = A * B where: - * - A is an m×k matrix (LHS) - * - B is a k×n matrix (RHS) - * - C is an m×n matrix (output) - * - * The implementation uses NEON intrinsics for vectorized computation and - * processes data in blocks of 16×16 for optimal performance on ARM - * architecture. - * - * @param m Number of rows in LHS and output - * @param n Number of columns in RHS and output - * @param k Number of columns in LHS and rows in RHS - * @param lhs Pointer to LHS matrix data (int8_t) - * @param lhs_stride_m Stride between rows of LHS - * @param rhs Pointer to RHS matrix data (int8_t) - * @param rhs_stride_n Stride between rows of RHS - * @param output Pointer to output matrix (float32_t) - * @param out_stride_m Stride between rows of output - * @param lhs_zero_points Zero points for LHS quantization (per-channel) - * @param rhs_zero_points Zero points for RHS quantization (per-channel) - * @param lhs_scales Scales for LHS quantization (per-channel) - * @param rhs_scales Scales for RHS quantization (per-channel) - * @param lhs_qparams_stride Stride for LHS quantization parameters - * @param rhs_qparams_stride Stride for RHS quantization parameters - */ - static void run( - int m, - int n, - int k, - const void* lhs, - int lhs_stride_m, - const void* rhs, - int rhs_stride_n, - float32_t* output, - int out_stride_m, - const int8_t* lhs_zero_points, - const int8_t* rhs_zero_points, - const float* lhs_scales, - const float* rhs_scales, - const int lhs_qparams_stride, - const int rhs_qparams_stride) { - // If lhs_zero_points and rhs_zero_points are not contiguous, transpose - std::unique_ptr lhs_zero_points_transposed = - std::make_unique(m); - std::unique_ptr lhs_scales_transposed = - std::make_unique(m); - if (lhs_qparams_stride > 1) { - utils::transpose_scales_and_zero_points( - lhs_zero_points, - lhs_scales, - lhs_zero_points_transposed.get(), - lhs_scales_transposed.get(), - m, - lhs_qparams_stride); - lhs_zero_points = lhs_zero_points_transposed.get(); - lhs_scales = lhs_scales_transposed.get(); - } - std::unique_ptr rhs_zero_points_transposed = - std::make_unique(n); - std::unique_ptr rhs_scales_transposed = - std::make_unique(n); - if (rhs_qparams_stride > 1) { - 
utils::transpose_scales_and_zero_points( - rhs_zero_points, - rhs_scales, - rhs_zero_points_transposed.get(), - rhs_scales_transposed.get(), - n, - rhs_qparams_stride); - rhs_zero_points = rhs_zero_points_transposed.get(); - rhs_scales = rhs_scales_transposed.get(); - } - - for (int m_idx = 0; m_idx < m; m_idx++) { - // Loop over 16 cols at a time - // Access to partial tiles must be protected:w - constexpr int nr = 16; - constexpr int kr = 16; - assert(n >= nr); - for (int n_idx = 0; n_idx < n; n_idx += nr) { - // If remaining is < nr, that must mean that (nr - remaining) items - // dont need to be computed. - // In order to avoid out-of-bounds access, we need to rewind n_indx a - // bit - // |-------------------|-------------------| - // 0-------------------8-------------------16 - // 0-------------------8-----10 - // If n = 10 and nr = 8 then at n_idx = 8, we need to rewind n_idx to - // 8 - (8 - 10) = 2 - int remaining = std::min(n - n_idx, nr); - n_idx = n_idx - (nr - remaining); - // Set activation_ptr to start of activation qvals for row m_idx - const int8_t* lhs_ptr = (const int8_t*)lhs + m_idx * lhs_stride_m; - const int8_t* rhs_ptr = (const int8_t*)rhs + n_idx; - int32x4_t int32_sums[nr / 4] = {vdupq_n_s32(0)}; - - // Loop k_idx by group - int k_idx = 0; - for (; (k_idx + kr) <= k; k_idx += kr) { - block_mul_1x16x16( - lhs_ptr, - rhs_ptr, - rhs_stride_n, - lhs_zero_points[m_idx], - rhs_zero_points + n_idx, - int32_sums); - lhs_ptr += kr; - rhs_ptr += kr * rhs_stride_n; - } - - int8x16_t b_zero_point_vec = vld1q_s8(rhs_zero_points + n_idx); - for (int ki = 0; ki < (k - k_idx); ++ki) { - // For each of the remaining k values - // Load 1 int8_t from lhs - // Load 16 int8_t from rhs - // And multiply + add into the 16 accumulators - // arranged as int32x4_t[4] - int16_t a_val = static_cast(lhs_ptr[ki]) - - static_cast(lhs_zero_points[m_idx]); - int8x16_t b_vec = vld1q_s8(rhs_ptr + ki * rhs_stride_n); - int16x8_t b_vec_low = - vsubl_s8(vget_low_s8(b_vec), vget_low_s8(b_zero_point_vec)); - int16x8_t b_vec_high = - vsubl_s8(vget_high_s8(b_vec), vget_high_s8(b_zero_point_vec)); - int32_sums[0] = - vmlal_n_s16(int32_sums[0], vget_low_s16(b_vec_low), a_val); - int32_sums[1] = - vmlal_n_s16(int32_sums[1], vget_high_s16(b_vec_low), a_val); - int32_sums[2] = - vmlal_n_s16(int32_sums[2], vget_low_s16(b_vec_high), a_val); - int32_sums[3] = - vmlal_n_s16(int32_sums[3], vget_high_s16(b_vec_high), a_val); - } - - float32x4_t res[4]; - dequantize_1x16_int32_t( - int32_sums, lhs_scales + m_idx, rhs_scales + n_idx, res); - - // Store result - // Because we adjust n_idx, we may end up writing the same location - // twice - float* store_loc = output + m_idx * out_stride_m + n_idx; - vst1q_f32(store_loc, res[0]); - vst1q_f32(store_loc + 4, res[1]); - vst1q_f32(store_loc + 8, res[2]); - vst1q_f32(store_loc + 12, res[3]); - } // n_idx - } // m_idx - } -}; - -} // namespace - -} // namespace channelwise_8bit_a_channelwise_8bit_b_1x16x16_f32_smlal::internal - -namespace channelwise_8bit_a_channelwise_8bit_b_1x16x16_f32_smlal { -template < - bool a_has_zeros, - bool b_has_zeros, - bool a_transposed, - bool b_transposed> -void kernel( - int m, - int n, - int k, - const void* lhs, - int lhs_stride_m, - const void* rhs, - int rhs_stride_n, - float32_t* output, - int out_stride_m, - const int8_t* lhs_zero_points, - const int8_t* rhs_zero_points, - const float* lhs_scales, - const float* rhs_scales, - const int lhs_qparams_stride, - const int rhs_qparams_stride) { - 
torchao::kernels::cpu::aarch64::quantized_matmul:: - channelwise_8bit_a_channelwise_8bit_b_1x16x16_f32_smlal::internal:: - KernelImpl::run( - m, - n, - k, - lhs, - lhs_stride_m, - rhs, - rhs_stride_n, - output, - out_stride_m, - lhs_zero_points, - rhs_zero_points, - lhs_scales, - rhs_scales, - lhs_qparams_stride, - rhs_qparams_stride); -} -} // namespace channelwise_8bit_a_channelwise_8bit_b_1x16x16_f32_smlal -} // namespace torchao::kernels::cpu::aarch64::quantized_matmul - -#endif // defined(__aarch64__) && defined(__ARM_NEON) diff --git a/torchao/experimental/kernels/cpu/aarch64/matmul/channelwise_8bit_a_channelwise_8bit_b_1x8x16_f32_neondot-impl.h b/torchao/experimental/kernels/cpu/aarch64/matmul/channelwise_8bit_a_channelwise_8bit_b_1x8x16_f32_neondot-impl.h deleted file mode 100644 index c976be39f5..0000000000 --- a/torchao/experimental/kernels/cpu/aarch64/matmul/channelwise_8bit_a_channelwise_8bit_b_1x8x16_f32_neondot-impl.h +++ /dev/null @@ -1,340 +0,0 @@ -// Copyright (c) Meta Platforms, Inc. and affiliates. -// All rights reserved. -// -// This source code is licensed under the license found in the -// LICENSE file in the root directory of this source tree. - -#pragma once - -#if defined(__aarch64__) && defined(__ARM_NEON) - -#include -#include -#include - -#include -#include -#include - -namespace torchao::kernels::cpu::aarch64::quantized_matmul { -namespace channelwise_8bit_a_channelwise_8bit_b_1x8x16_f32_neondot::internal { - -/* -This function loads int8x16_t value from a, and 8 int8x16_t values from b, and -computes 8 dot products, resulting in 8 int32x4_t values. -Furthermore the int8x16_t values from a are reduced via summing, resulting in -int32_t row_sum_a. Similar int8x16_t values from b are reduced via summing, -resulting in int32_t row_sum_b. -*/ -TORCHAO_ALWAYS_INLINE static void block_mul_1x8x16( - const int8_t* a, - const int8_t* b, - const size_t ldb, - int32x4_t (&partial_sums)[8], - int32_t& row_sum_a, - int32x4_t (&row_sum_b)[8]) { - int8x16_t a_vec = vld1q_s8(a); - int8x16_t ones = vdupq_n_s8(1); - row_sum_a = row_sum_a + vaddlvq_s8(a_vec); - -// godbolt (https://godbolt.org/z/9vbq1d1qY) shows this loops doesnt quantize -// get optimized by moving all the loads up in the unrolled loop. Just hoping -// OOO machine will take care of things Late replace this with macros so as to -// deconstruct the loop and do manual optimization. Or just write assembly. 
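[Aside, not part of the original patch] The reason block_mul_1x8x16 tracks row_sum_a and row_sum_b is the standard zero-point expansion applied later in dequantize_1x8_int32_t. A minimal scalar check of that identity, with illustrative names:

#include <cstdint>
#include <cstddef>

// sum_i (a_i - za)(b_i - zb)
//   = sum_i a_i*b_i - zb*sum_i a_i - za*sum_i b_i + k*za*zb
// so the kernel can accumulate raw dot products plus row sums and apply
// the correction once per output element instead of subtracting zero
// points inside the inner loop.
inline int32_t corrected_dot(
    const int8_t* a, const int8_t* b, size_t k, int32_t za, int32_t zb) {
  int32_t dot = 0, sum_a = 0, sum_b = 0;
  for (size_t i = 0; i < k; ++i) {
    dot += int32_t(a[i]) * int32_t(b[i]);
    sum_a += a[i];
    sum_b += b[i];
  }
  return dot - zb * sum_a - za * sum_b + int32_t(k) * za * zb;
}

The dequantized output is then corrected_dot * lhs_scale * rhs_scale, which is what dequantize_1x8_int32_t does in vector form for 8 output channels at a time.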
-#pragma unroll(8) - for (int i = 0; i < 8; ++i) { - int8x16_t b_vec = vld1q_s8(b); - b += ldb; - row_sum_b[i] = vdotq_s32(row_sum_b[i], b_vec, ones); - partial_sums[i] = vdotq_s32(partial_sums[i], a_vec, b_vec); - } -} - -TORCHAO_ALWAYS_INLINE static void reduce_1x8_int32x4_t_sums( - const int32x4_t (&partial_sums)[8], - int32_t (&sums)[8]) { -#pragma unroll(8) - for (int i = 0; i < 8; ++i) { - sums[i] = vaddvq_s32(partial_sums[i]); - } -} - -TORCHAO_ALWAYS_INLINE static void dequantize_1x8_int32_t( - const int32_t (&sums)[8], - int32_t& row_sum_lhs, - int32_t (&row_sum_rhs)[8], - const int8_t* lhs_zero_points, - const int8_t* rhs_zero_points, - const float* lhs_scales, - const float* rhs_scales, - const int32_t k, - float32x4x2_t& outputs) { - int32x4_t vec_sum_0123 = vld1q_s32(sums); - int32x4_t vec_sum_4567 = vld1q_s32(sums + 4); - - int32x4_t row_sum_rhs_x_lhs_zp_0123 = - vmulq_n_s32(vld1q_s32(row_sum_rhs), (int32_t)lhs_zero_points[0]); - int32x4_t row_sum_rhs_x_lhs_zp_4567 = - vmulq_n_s32(vld1q_s32(row_sum_rhs + 4), (int32_t)lhs_zero_points[0]); - - // Extract rhs zero point in int8x8_t and convert to int32x4_t - int16x8_t rhs_zero_points_vec_01234567 = vmovl_s8(vld1_s8(rhs_zero_points)); - int32x4_t rhs_zero_points_vec_0123 = - vmovl_s16(vget_low_s16(rhs_zero_points_vec_01234567)); - int32x4_t rhs_zero_points_vec_4567 = - vmovl_s16(vget_high_s16(rhs_zero_points_vec_01234567)); - int32x4_t row_sum_lhs_x_rhs_zp_0123 = - vmulq_n_s32(rhs_zero_points_vec_0123, row_sum_lhs); - int32x4_t row_sum_lhs_x_rhs_zp_4567 = - vmulq_n_s32(rhs_zero_points_vec_4567, row_sum_lhs); - - int32x4_t zp_rhs_x_zp_lhs_0123 = - vmulq_n_s32(rhs_zero_points_vec_0123, k * (int32_t)lhs_zero_points[0]); - int32x4_t zp_rhs_x_zp_lhs_4567 = - vmulq_n_s32(rhs_zero_points_vec_4567, k * (int32_t)lhs_zero_points[0]); - - vec_sum_0123 = vsubq_s32(vec_sum_0123, row_sum_rhs_x_lhs_zp_0123); - vec_sum_0123 = vsubq_s32(vec_sum_0123, row_sum_lhs_x_rhs_zp_0123); - vec_sum_0123 = vaddq_s32(vec_sum_0123, zp_rhs_x_zp_lhs_0123); - - vec_sum_4567 = vsubq_s32(vec_sum_4567, row_sum_rhs_x_lhs_zp_4567); - vec_sum_4567 = vsubq_s32(vec_sum_4567, row_sum_lhs_x_rhs_zp_4567); - vec_sum_4567 = vaddq_s32(vec_sum_4567, zp_rhs_x_zp_lhs_4567); - - float32x4_t scales_0123 = vmulq_n_f32(vld1q_f32(rhs_scales), lhs_scales[0]); - float32x4_t scales_4567 = - vmulq_n_f32(vld1q_f32(rhs_scales + 4), lhs_scales[0]); - - outputs.val[0] = vmulq_f32(vcvtq_f32_s32(vec_sum_0123), scales_0123); - outputs.val[1] = vmulq_f32(vcvtq_f32_s32(vec_sum_4567), scales_4567); -} - -template < - bool a_has_zeros, - bool b_has_zeros, - bool a_transposed, - bool b_transposed> -struct KernelImpl { - static void run( - int m, - int n, - int k, - const void* lhs, - int lhs_stride_m, - const void* rhs, - int rhs_stride_n, - float32_t* output, - int out_stride_m, - const int8_t* lhs_zero_points, - const int8_t* rhs_zero_points, - const float* lhs_scales, - const float* rhs_scales, - const int lhs_qparams_stride, - const int rhs_qparams_stride); -}; - -template <> -struct KernelImpl { - /** - * @brief Executes a quantized matrix multiplication with channelwise - * quantization parameters - * - * This function performs matrix multiplication between two 8-bit quantized - * matrices with per-channel quantization parameters. It handles the following - * operations: - * 1. Transposes quantization parameters if they're not contiguous - * 2. Processes the matrices in blocks of 8 columns at a time - * 3. Uses NEON dot product instructions for efficient computation - * 4. 
Handles edge cases for remaining elements - * 5. Dequantizes the results to floating point - * - * @param m Number of rows in the output matrix - * @param n Number of columns in the output matrix - * @param k Number of columns in lhs / rows in rhs - * @param lhs Pointer to the left-hand side matrix (quantized int8) - * @param lhs_stride_m Stride between rows of the lhs matrix - * @param rhs Pointer to the right-hand side matrix (quantized int8) - * @param rhs_stride_n Stride between rows of the rhs matrix. Expects matrix - * to be transposed. Thus of size [n x k] - * @param output Pointer to the output matrix (float32) - * @param out_stride_m Stride between rows of the output matrix - * @param lhs_zero_points Zero points for lhs quantization (per-channel) - * @param rhs_zero_points Zero points for rhs quantization (per-channel) - * @param lhs_scales Scales for lhs quantization (per-channel) - * @param rhs_scales Scales for rhs quantization (per-channel) - * @param lhs_qparams_stride Stride for lhs quantization parameters - * @param rhs_qparams_stride Stride for rhs quantization parameters - */ - static void run( - int m, - int n, - int k, - const void* lhs, - int lhs_stride_m, - const void* rhs, - int rhs_stride_n, - float32_t* output, - int out_stride_m, - const int8_t* lhs_zero_points, - const int8_t* rhs_zero_points, - const float* lhs_scales, - const float* rhs_scales, - const int lhs_qparams_stride, - const int rhs_qparams_stride) { - // If lhs_zero_points and rhs_zero_points are not contiguous, transpose - std::unique_ptr lhs_zero_points_transposed = - std::make_unique(m); - std::unique_ptr lhs_scales_transposed = - std::make_unique(m); - if (lhs_qparams_stride > 1) { - utils::transpose_scales_and_zero_points( - lhs_zero_points, - lhs_scales, - lhs_zero_points_transposed.get(), - lhs_scales_transposed.get(), - m, - lhs_qparams_stride); - lhs_zero_points = lhs_zero_points_transposed.get(); - lhs_scales = lhs_scales_transposed.get(); - } - std::unique_ptr rhs_zero_points_transposed = - std::make_unique(n); - std::unique_ptr rhs_scales_transposed = - std::make_unique(n); - if (rhs_qparams_stride > 1) { - utils::transpose_scales_and_zero_points( - rhs_zero_points, - rhs_scales, - rhs_zero_points_transposed.get(), - rhs_scales_transposed.get(), - n, - rhs_qparams_stride); - rhs_zero_points = rhs_zero_points_transposed.get(); - rhs_scales = rhs_scales_transposed.get(); - } - - for (int m_idx = 0; m_idx < m; m_idx++) { - // Loop over 8 cols at a time - // Access to partial tiles must be protected:w - constexpr int nr = 8; - constexpr int kr = 16; - assert(n >= nr); - for (int n_idx = 0; n_idx < n; n_idx += nr) { - // If remaining is < nr, that must mean that (nr - remaining) items - // dont need to be computed. 
- // In order to avoid out-of-bounds access, we need to rewind n_indx a - // bit - // |-------------------|-------------------| - // 0-------------------8-------------------16 - // 0-------------------8-----10 - // If n = 10 and nr = 8 then at n_idx = 8, we need to rewind n_idx to - // 8 - (8 - 10) = 2 - int remaining = std::min(n - n_idx, nr); - n_idx = n_idx - (nr - remaining); - // Set activation_ptr to start of activation qvals for row m_idx - const int8_t* lhs_ptr = (const int8_t*)lhs + m_idx * lhs_stride_m; - const int8_t* rhs_ptr = (const int8_t*)rhs + n_idx * rhs_stride_n; - int32x4_t int32_sums[nr] = {vdupq_n_s32(0)}; - int32_t row_sum_lhs = 0; - int32x4_t row_sum_rhs_vec[nr] = {vdupq_n_s32(0)}; - int32_t sums[nr]; - int32_t row_sum_rhs[nr]; - - // Loop k_idx by group - int k_idx = 0; - for (; (k_idx + kr) <= k; k_idx += kr) { - block_mul_1x8x16( - lhs_ptr, - rhs_ptr, - rhs_stride_n, - int32_sums, - row_sum_lhs, - row_sum_rhs_vec); - lhs_ptr += kr; - rhs_ptr += kr; - } - - reduce_1x8_int32x4_t_sums(int32_sums, sums); - reduce_1x8_int32x4_t_sums(row_sum_rhs_vec, row_sum_rhs); - for (int ki = 0; ki < (k - k_idx); ++ki) { - row_sum_lhs += (int32_t)lhs_ptr[ki]; - } - for (int ni = 0; ni < nr; ++ni) { - for (int ki = 0; ki < (k - k_idx); ++ki) { - sums[ni] += (int32_t)lhs_ptr[ki] * - (int32_t)(rhs_ptr + ni * rhs_stride_n)[ki]; - row_sum_rhs[ni] += (int32_t)(rhs_ptr + ni * rhs_stride_n)[ki]; - } - } - - float32x4x2_t res; - dequantize_1x8_int32_t( - sums, - row_sum_lhs, - row_sum_rhs, - lhs_zero_points + m_idx, - rhs_zero_points + n_idx, - lhs_scales + m_idx, - rhs_scales + n_idx, - k, - res); - - // Store result - // Because we adjust n_idx, we may end up writing the same location - // twice - float* store_loc = output + m_idx * out_stride_m + n_idx; - vst1q_f32(store_loc, res.val[0]); - vst1q_f32(store_loc + 4, res.val[1]); - } // n_idx - } // m_idx - } -}; - -} // namespace - // channelwise_8bit_a_channelwise_8bit_b_1x8x16_f32_neondot::internal - -namespace channelwise_8bit_a_channelwise_8bit_b_1x8x16_f32_neondot { -template < - bool a_has_zeros, - bool b_has_zeros, - bool a_transposed, - bool b_transposed> -void kernel( - int m, - int n, - int k, - const void* lhs, - int lhs_stride_m, - const void* rhs, - int rhs_stride_n, - float32_t* output, - int out_stride_m, - const int8_t* lhs_zero_points, - const int8_t* rhs_zero_points, - const float* lhs_scales, - const float* rhs_scales, - const int lhs_qparams_stride, - const int rhs_qparams_stride) { - torchao::kernels::cpu::aarch64::quantized_matmul:: - channelwise_8bit_a_channelwise_8bit_b_1x8x16_f32_neondot::internal:: - KernelImpl::run( - m, - n, - k, - lhs, - lhs_stride_m, - rhs, - rhs_stride_n, - output, - out_stride_m, - lhs_zero_points, - rhs_zero_points, - lhs_scales, - rhs_scales, - lhs_qparams_stride, - rhs_qparams_stride); -} -} // namespace channelwise_8bit_a_channelwise_8bit_b_1x8x16_f32_neondot -} // namespace torchao::kernels::cpu::aarch64::quantized_matmul - -#endif // defined(__aarch64__) && defined(__ARM_NEON) diff --git a/torchao/experimental/kernels/cpu/aarch64/matmul/channelwise_8bit_a_channelwise_8bit_b_4x8x8_f32_neondot-impl.h b/torchao/experimental/kernels/cpu/aarch64/matmul/channelwise_8bit_a_channelwise_8bit_b_4x8x8_f32_neondot-impl.h deleted file mode 100644 index 19bde9dad9..0000000000 --- a/torchao/experimental/kernels/cpu/aarch64/matmul/channelwise_8bit_a_channelwise_8bit_b_4x8x8_f32_neondot-impl.h +++ /dev/null @@ -1,411 +0,0 @@ -// Copyright (c) Meta Platforms, Inc. and affiliates. 
-// All rights reserved. -// -// This source code is licensed under the license found in the -// LICENSE file in the root directory of this source tree. - -#pragma once - -#if defined(__aarch64__) && defined(__ARM_NEON) - -#include -#include -#include - -#include -#include -#include - -namespace torchao::kernels::cpu::aarch64::quantized_matmul { -namespace channelwise_8bit_a_channelwise_8bit_b_4x8x8_f32_neondot::internal { - -TORCHAO_ALWAYS_INLINE static void block_mul_4x8x8( - const int8_t* a, - const size_t lda, - const int8_t* b, - int32x4_t (&partial_sums)[4][8 / 4], - int32_t (&row_sum_a)[4], - int32x4_t (&row_sum_b)[2]) { - int8x8_t a_vec[4]; - a_vec[0] = vld1_s8(a + 0 * lda); - a_vec[1] = vld1_s8(a + 1 * lda); - a_vec[2] = vld1_s8(a + 2 * lda); - a_vec[3] = vld1_s8(a + 3 * lda); - int8x16_t ones = vdupq_n_s8(1); - row_sum_a[0] = row_sum_a[0] + vaddlv_s8(a_vec[0]); - row_sum_a[1] = row_sum_a[1] + vaddlv_s8(a_vec[1]); - row_sum_a[2] = row_sum_a[2] + vaddlv_s8(a_vec[2]); - row_sum_a[3] = row_sum_a[3] + vaddlv_s8(a_vec[3]); - - int8x16_t b_vec[2]; - b_vec[0] = vld1q_s8(b); - b_vec[1] = vld1q_s8(b + 16); - row_sum_b[0] = vdotq_s32(row_sum_b[0], b_vec[0], ones); - row_sum_b[1] = vdotq_s32(row_sum_b[1], b_vec[1], ones); - // First 4x4 of the 4x8 tile - // Multiply with k = 0 thus (a_vec[0], 0) (a_vec[1], 0)... - partial_sums[0][0] = - vdotq_lane_s32(partial_sums[0][0], b_vec[0], a_vec[0], 0); - partial_sums[1][0] = - vdotq_lane_s32(partial_sums[1][0], b_vec[0], a_vec[1], 0); - partial_sums[2][0] = - vdotq_lane_s32(partial_sums[2][0], b_vec[0], a_vec[2], 0); - partial_sums[3][0] = - vdotq_lane_s32(partial_sums[3][0], b_vec[0], a_vec[3], 0); - // Second 4x4 of the 4x8 til - partial_sums[0][1] = - vdotq_lane_s32(partial_sums[0][1], b_vec[1], a_vec[0], 0); - partial_sums[1][1] = - vdotq_lane_s32(partial_sums[1][1], b_vec[1], a_vec[1], 0); - partial_sums[2][1] = - vdotq_lane_s32(partial_sums[2][1], b_vec[1], a_vec[2], 0); - partial_sums[3][1] = - vdotq_lane_s32(partial_sums[3][1], b_vec[1], a_vec[3], 0); - - // Second set of 4 channels - b = b + 32; - b_vec[0] = vld1q_s8(b); - b_vec[1] = vld1q_s8(b + 16); - row_sum_b[0] = vdotq_s32(row_sum_b[0], b_vec[0], ones); - row_sum_b[1] = vdotq_s32(row_sum_b[1], b_vec[1], ones); - // First 4x4 of the 4x8 tile - // Multiply with k = 0 thus (a_vec[0], 0) (a_vec[1], 0)... 
- partial_sums[0][0] = - vdotq_lane_s32(partial_sums[0][0], b_vec[0], a_vec[0], 1); - partial_sums[1][0] = - vdotq_lane_s32(partial_sums[1][0], b_vec[0], a_vec[1], 1); - partial_sums[2][0] = - vdotq_lane_s32(partial_sums[2][0], b_vec[0], a_vec[2], 1); - partial_sums[3][0] = - vdotq_lane_s32(partial_sums[3][0], b_vec[0], a_vec[3], 1); - // Second 4x4 of the 4x8 til - partial_sums[0][1] = - vdotq_lane_s32(partial_sums[0][1], b_vec[1], a_vec[0], 1); - partial_sums[1][1] = - vdotq_lane_s32(partial_sums[1][1], b_vec[1], a_vec[1], 1); - partial_sums[2][1] = - vdotq_lane_s32(partial_sums[2][1], b_vec[1], a_vec[2], 1); - partial_sums[3][1] = - vdotq_lane_s32(partial_sums[3][1], b_vec[1], a_vec[3], 1); -} - -TORCHAO_ALWAYS_INLINE static void dequantize_4x8_int32_t( - int32x4_t (&sums)[4][8 / 4], - int32_t (&row_sum_lhs)[4], - int32x4_t (&row_sum_rhs)[2], - const int8_t* lhs_zero_points, - const int8_t* rhs_zero_points, - const float* lhs_scales, - const float* rhs_scales, - const int32_t k, - float32x4_t (&outputs)[4][8 / 4]) { - int16x8_t rhs_zero_points_01234567 = vmovl_s8(vld1_s8(rhs_zero_points)); - int32x4_t rhs_zero_points_0123 = - vmovl_s16(vget_low_s16(rhs_zero_points_01234567)); - int32x4_t rhs_zero_points_4567 = - vmovl_s16(vget_high_s16(rhs_zero_points_01234567)); - int32x4_t row_sum_lhs_x_rhs_zp_0123 = - vmulq_n_s32(rhs_zero_points_0123, row_sum_lhs[0]); - int32x4_t row_sum_lhs_x_rhs_zp_4567 = - vmulq_n_s32(rhs_zero_points_4567, row_sum_lhs[0]); - // First 8 output channels adjustment - sums[0][0] = vsubq_s32(sums[0][0], row_sum_lhs_x_rhs_zp_0123); - sums[0][1] = vsubq_s32(sums[0][1], row_sum_lhs_x_rhs_zp_4567); - - // Add zp_rhs * zp_lhs * k - int32x4_t zp_rhs_x_zp_lhs_0123 = - vmulq_n_s32(rhs_zero_points_0123, k * (int32_t)lhs_zero_points[0]); - int32x4_t zp_rhs_x_zp_lhs_4567 = - vmulq_n_s32(rhs_zero_points_4567, k * (int32_t)lhs_zero_points[0]); - sums[0][0] = vaddq_s32(sums[0][0], zp_rhs_x_zp_lhs_0123); - sums[0][1] = vaddq_s32(sums[0][1], zp_rhs_x_zp_lhs_4567); - - row_sum_lhs_x_rhs_zp_0123 = vmulq_n_s32(rhs_zero_points_0123, row_sum_lhs[1]); - row_sum_lhs_x_rhs_zp_4567 = vmulq_n_s32(rhs_zero_points_4567, row_sum_lhs[1]); - // Second 8 output channels adjustment - sums[1][0] = vsubq_s32(sums[1][0], row_sum_lhs_x_rhs_zp_0123); - sums[1][1] = vsubq_s32(sums[1][1], row_sum_lhs_x_rhs_zp_4567); - - // Add zp_rhs * zp_lhs * k - zp_rhs_x_zp_lhs_0123 = - vmulq_n_s32(rhs_zero_points_0123, k * (int32_t)lhs_zero_points[1]); - zp_rhs_x_zp_lhs_4567 = - vmulq_n_s32(rhs_zero_points_4567, k * (int32_t)lhs_zero_points[1]); - sums[1][0] = vaddq_s32(sums[1][0], zp_rhs_x_zp_lhs_0123); - sums[1][1] = vaddq_s32(sums[1][1], zp_rhs_x_zp_lhs_4567); - - row_sum_lhs_x_rhs_zp_0123 = vmulq_n_s32(rhs_zero_points_0123, row_sum_lhs[2]); - row_sum_lhs_x_rhs_zp_4567 = vmulq_n_s32(rhs_zero_points_4567, row_sum_lhs[2]); - // Third 8 output channels adjustment - sums[2][0] = vsubq_s32(sums[2][0], row_sum_lhs_x_rhs_zp_0123); - sums[2][1] = vsubq_s32(sums[2][1], row_sum_lhs_x_rhs_zp_4567); - - // Add zp_rhs * zp_lhs * k - zp_rhs_x_zp_lhs_0123 = - vmulq_n_s32(rhs_zero_points_0123, k * (int32_t)lhs_zero_points[2]); - zp_rhs_x_zp_lhs_4567 = - vmulq_n_s32(rhs_zero_points_4567, k * (int32_t)lhs_zero_points[2]); - sums[2][0] = vaddq_s32(sums[2][0], zp_rhs_x_zp_lhs_0123); - sums[2][1] = vaddq_s32(sums[2][1], zp_rhs_x_zp_lhs_4567); - - row_sum_lhs_x_rhs_zp_0123 = vmulq_n_s32(rhs_zero_points_0123, row_sum_lhs[3]); - row_sum_lhs_x_rhs_zp_4567 = vmulq_n_s32(rhs_zero_points_4567, row_sum_lhs[3]); - // Fourth 8 output channels 
adjustment - sums[3][0] = vsubq_s32(sums[3][0], row_sum_lhs_x_rhs_zp_0123); - sums[3][1] = vsubq_s32(sums[3][1], row_sum_lhs_x_rhs_zp_4567); - - // Add zp_rhs * zp_lhs * k - zp_rhs_x_zp_lhs_0123 = - vmulq_n_s32(rhs_zero_points_0123, k * (int32_t)lhs_zero_points[3]); - zp_rhs_x_zp_lhs_4567 = - vmulq_n_s32(rhs_zero_points_4567, k * (int32_t)lhs_zero_points[3]); - sums[3][0] = vaddq_s32(sums[3][0], zp_rhs_x_zp_lhs_0123); - sums[3][1] = vaddq_s32(sums[3][1], zp_rhs_x_zp_lhs_4567); - - // Now adjust for rhs_zero_points * lhs_row_sum - int32x4_t row_sum_rhs_0123_x_lhs_zp = - vmulq_n_s32(row_sum_rhs[0], lhs_zero_points[0]); - int32x4_t row_sum_rhs_4567_x_lhs_zp = - vmulq_n_s32(row_sum_rhs[1], lhs_zero_points[0]); - sums[0][0] = vsubq_s32(sums[0][0], row_sum_rhs_0123_x_lhs_zp); - sums[0][1] = vsubq_s32(sums[0][1], row_sum_rhs_4567_x_lhs_zp); - - row_sum_rhs_0123_x_lhs_zp = vmulq_n_s32(row_sum_rhs[0], lhs_zero_points[1]); - row_sum_rhs_4567_x_lhs_zp = vmulq_n_s32(row_sum_rhs[1], lhs_zero_points[1]); - sums[1][0] = vsubq_s32(sums[1][0], row_sum_rhs_0123_x_lhs_zp); - sums[1][1] = vsubq_s32(sums[1][1], row_sum_rhs_4567_x_lhs_zp); - - row_sum_rhs_0123_x_lhs_zp = vmulq_n_s32(row_sum_rhs[0], lhs_zero_points[2]); - row_sum_rhs_4567_x_lhs_zp = vmulq_n_s32(row_sum_rhs[1], lhs_zero_points[2]); - sums[2][0] = vsubq_s32(sums[2][0], row_sum_rhs_0123_x_lhs_zp); - sums[2][1] = vsubq_s32(sums[2][1], row_sum_rhs_4567_x_lhs_zp); - - row_sum_rhs_0123_x_lhs_zp = vmulq_n_s32(row_sum_rhs[0], lhs_zero_points[3]); - row_sum_rhs_4567_x_lhs_zp = vmulq_n_s32(row_sum_rhs[1], lhs_zero_points[3]); - sums[3][0] = vsubq_s32(sums[3][0], row_sum_rhs_0123_x_lhs_zp); - sums[3][1] = vsubq_s32(sums[3][1], row_sum_rhs_4567_x_lhs_zp); - - float32x4_t rhs_scales_0123 = vld1q_f32(rhs_scales); - float32x4_t rhs_scales_4567 = vld1q_f32(rhs_scales + 4); - - float32x4_t scales_0123 = vmulq_n_f32(rhs_scales_0123, lhs_scales[0]); - float32x4_t scales_4567 = vmulq_n_f32(rhs_scales_4567, lhs_scales[0]); - - outputs[0][0] = vmulq_f32(vcvtq_f32_s32(sums[0][0]), scales_0123); - outputs[0][1] = vmulq_f32(vcvtq_f32_s32(sums[0][1]), scales_4567); - - scales_0123 = vmulq_n_f32(rhs_scales_0123, lhs_scales[1]); - scales_4567 = vmulq_n_f32(rhs_scales_4567, lhs_scales[1]); - outputs[1][0] = vmulq_f32(vcvtq_f32_s32(sums[1][0]), scales_0123); - outputs[1][1] = vmulq_f32(vcvtq_f32_s32(sums[1][1]), scales_4567); - - scales_0123 = vmulq_n_f32(rhs_scales_0123, lhs_scales[2]); - scales_4567 = vmulq_n_f32(rhs_scales_4567, lhs_scales[2]); - outputs[2][0] = vmulq_f32(vcvtq_f32_s32(sums[2][0]), scales_0123); - outputs[2][1] = vmulq_f32(vcvtq_f32_s32(sums[2][1]), scales_4567); - - scales_0123 = vmulq_n_f32(rhs_scales_0123, lhs_scales[3]); - scales_4567 = vmulq_n_f32(rhs_scales_4567, lhs_scales[3]); - outputs[3][0] = vmulq_f32(vcvtq_f32_s32(sums[3][0]), scales_0123); - outputs[3][1] = vmulq_f32(vcvtq_f32_s32(sums[3][1]), scales_4567); -} - -template < - bool a_has_zeros, - bool b_has_zeros, - bool a_transposed, - bool b_transposed> -struct KernelImpl { - static void run( - int m, - int n, - int k, - const void* lhs, - int lhs_stride_m, - const void* rhs, - int rhs_stride_n, - float32_t* output, - int out_stride_m, - const int8_t* lhs_zero_points, - const int8_t* rhs_zero_points, - const float* lhs_scales, - const float* rhs_scales, - const int lhs_qparams_stride, - const int rhs_qparams_stride); -}; - -template <> -struct KernelImpl { - static void run( - int m, - int n, - int k, - const void* lhs, - int lhs_stride_m, - const void* rhs, - int rhs_stride_n, - float32_t* 
output, - int out_stride_m, - const int8_t* lhs_zero_points, - const int8_t* rhs_zero_points, - const float* lhs_scales, - const float* rhs_scales, - const int lhs_qparams_stride, - const int rhs_qparams_stride) { - // If lhs_zero_points and rhs_zero_points are not contiguous, transpose - std::vector lhs_zero_points_transposed; - std::vector lhs_scales_transposed; - if (lhs_qparams_stride > 1) { - lhs_zero_points_transposed.resize(m); - lhs_scales_transposed.resize(m); - utils::transpose_scales_and_zero_points( - lhs_zero_points, - lhs_scales, - lhs_zero_points_transposed.data(), - lhs_scales_transposed.data(), - m, - lhs_qparams_stride); - lhs_zero_points = lhs_zero_points_transposed.data(); - lhs_scales = lhs_scales_transposed.data(); - } - std::vector rhs_zero_points_transposed; - std::vector rhs_scales_transposed; - if (rhs_qparams_stride > 1) { - rhs_zero_points_transposed.resize(n); - rhs_scales_transposed.resize(n); - utils::transpose_scales_and_zero_points( - rhs_zero_points, - rhs_scales, - rhs_zero_points_transposed.data(), - rhs_scales_transposed.data(), - n, - rhs_qparams_stride); - rhs_zero_points = rhs_zero_points_transposed.data(); - rhs_scales = rhs_scales_transposed.data(); - } - - constexpr int mr = 4; - constexpr int nr = 8; - constexpr int kr = 8; - assert(m % mr == 0); - assert(k % 16 == 0); - assert(n % nr == 0); - std::vector rhs_packed(n * k); - // Since we are casting int8_t to float32_t in order to tranpose matrix in a - // way to keep 4 of the k values to gether, we must adjust stride as well as - // k size - const size_t k_adjusted = k / 4; - const size_t rhs_stride_n_adjusted = rhs_stride_n / 4; - utils::pack_kxn_b_matrix_for_mx8_dotprod_ukernel( - static_cast(rhs), - rhs_stride_n_adjusted, - reinterpret_cast(rhs_packed.data()), - n, - k_adjusted); - size_t packed_block_stride = nr * k; - constexpr size_t packed_k_stride = nr * kr; - - for (int m_idx = 0; m_idx < m; m_idx += mr) { - for (int n_idx = 0; n_idx < n; n_idx += nr) { - // Set activation_ptr to start of activation qvals for row m_idx - const int8_t* lhs_ptr = (const int8_t*)lhs + m_idx * lhs_stride_m; - const int8_t* rhs_ptr = (const int8_t*)rhs_packed.data() + - (n_idx / nr) * packed_block_stride; - int32x4_t int32_sums[mr][nr / 4] = {{vdupq_n_s32(0)}}; - int32x4_t row_sum_rhs_vec[nr / 4] = {vdupq_n_s32(0)}; - int32_t row_sum_lhs[mr] = {0}; - - // Loop k_idx by group - int k_idx = 0; - for (; k_idx < k; k_idx += kr) { - block_mul_4x8x8( - lhs_ptr, - lhs_stride_m, - rhs_ptr, - int32_sums, - row_sum_lhs, - row_sum_rhs_vec); - lhs_ptr += kr; - rhs_ptr += packed_k_stride; - } - - float32x4_t res[mr][nr / 4]; - dequantize_4x8_int32_t( - int32_sums, - row_sum_lhs, - row_sum_rhs_vec, - lhs_zero_points + m_idx, - rhs_zero_points + n_idx, - lhs_scales + m_idx, - rhs_scales + n_idx, - k, - res); - - // Store result - // Because we adjust n_idx, we may end up writing the same location - // twice - float* store_loc = output + m_idx * out_stride_m + n_idx; - vst1q_f32(store_loc, res[0][0]); - vst1q_f32(store_loc + 4, res[0][1]); - store_loc += out_stride_m; - vst1q_f32(store_loc, res[1][0]); - vst1q_f32(store_loc + 4, res[1][1]); - store_loc += out_stride_m; - vst1q_f32(store_loc, res[2][0]); - vst1q_f32(store_loc + 4, res[2][1]); - store_loc += out_stride_m; - vst1q_f32(store_loc, res[3][0]); - vst1q_f32(store_loc + 4, res[3][1]); - } // n_idx - } // m_idx - } -}; - -} // namespace - // channelwise_8bit_a_channelwise_8bit_b_4x8x8_f32_neondot::internal - -namespace 
channelwise_8bit_a_channelwise_8bit_b_4x8x8_f32_neondot { -template < - bool a_has_zeros, - bool b_has_zeros, - bool a_transposed, - bool b_transposed> -void kernel( - int m, - int n, - int k, - const void* lhs, - int lhs_stride_m, - const void* rhs, - int rhs_stride_n, - float32_t* output, - int out_stride_m, - const int8_t* lhs_zero_points, - const int8_t* rhs_zero_points, - const float* lhs_scales, - const float* rhs_scales, - const int lhs_qparams_stride, - const int rhs_qparams_stride) { - torchao::kernels::cpu::aarch64::quantized_matmul:: - channelwise_8bit_a_channelwise_8bit_b_4x8x8_f32_neondot::internal:: - KernelImpl::run( - m, - n, - k, - lhs, - lhs_stride_m, - rhs, - rhs_stride_n, - output, - out_stride_m, - lhs_zero_points, - rhs_zero_points, - lhs_scales, - rhs_scales, - lhs_qparams_stride, - rhs_qparams_stride); -} -} // namespace channelwise_8bit_a_channelwise_8bit_b_4x8x8_f32_neondot -} // namespace torchao::kernels::cpu::aarch64::quantized_matmul - -#endif // defined(__aarch64__) && defined(__ARM_NEON) diff --git a/torchao/experimental/kernels/cpu/aarch64/matmul/fp32_a_input_channelwise_8bit_b_1x16x4_f32_impl.h b/torchao/experimental/kernels/cpu/aarch64/matmul/fp32_a_input_channelwise_8bit_b_1x16x4_f32_impl.h deleted file mode 100644 index 4fc393fcaf..0000000000 --- a/torchao/experimental/kernels/cpu/aarch64/matmul/fp32_a_input_channelwise_8bit_b_1x16x4_f32_impl.h +++ /dev/null @@ -1,281 +0,0 @@ -// Copyright (c) Meta Platforms, Inc. and affiliates. -// All rights reserved. -// -// This source code is licensed under the license found in the -// LICENSE file in the root directory of this source tree. - -#pragma once - -#if defined(__aarch64__) && defined(__ARM_NEON) - -#include -#include -#include - -#include -#include -#include - -namespace torchao::kernels::cpu::aarch64::quantized_matmul { -namespace fp32_a_input_channelwise_8bit_b_1x16x4_f32::internal { - -namespace { - -/* -This function loads float32x4_t value from a, and 16 int8x16_t values from b. -For each int8x16_t of b: -- 4 float32x4 accumulated values -- load 4 a in float32x4_t -- [The following repeats for each of the 4 lanes of a] -- for i in [0, 4]: - - load b[i] in int8x16_t - - subl to subtract b_zero_point from b, to get b_low, b_high - - vmovl to get b_low_low, b_low_high, b_high_low, b_high_high - - vcvtq to convert to float32x4_t, we will have 4 of these. 
-- for i in [0, 4]: for each of the 4 float32x4_t of b: - - vfmaq_lane_fp32 to multiply a[lane] and b[i] - - vfmaq_lane_fp32 to multiply a[lane] and b[i] - - vfmaq_lane_fp32 to multiply a[lane] and b[i] - - vfmaq_lane_fp32 to multiply a[lane] and b[i] -- By doing the above 4 times (lane=[0-3]), we used all values along k dim of a - and accumulated 4 float32x4_t values -*/ -TORCHAO_ALWAYS_INLINE inline void block_mul_1x16x1( - const float32_t a, - const int8x16_t& b_vec, - const int8_t b_zero_point, - const float b_scale, - float32x4_t (&partial_sums)[4]) { - int8x8_t b_zero_point_vec = vdup_n_s8(b_zero_point); - int16x8_t b_vec_low = vsubl_s8(vget_low_s8(b_vec), b_zero_point_vec); - int16x8_t b_vec_high = vsubl_s8(vget_high_s8(b_vec), b_zero_point_vec); - float32x4_t b_vec_low_low = vcvtq_f32_s32(vmovl_s16(vget_low_s16(b_vec_low))); - float32x4_t b_vec_low_high = - vcvtq_f32_s32(vmovl_s16(vget_high_s16(b_vec_low))); - float32x4_t b_vec_high_low = - vcvtq_f32_s32(vmovl_s16(vget_low_s16(b_vec_high))); - float32x4_t b_vec_high_high = - vcvtq_f32_s32(vmovl_s16(vget_high_s16(b_vec_high))); - b_vec_low_low = vmulq_n_f32(b_vec_low_low, b_scale); - b_vec_low_high = vmulq_n_f32(b_vec_low_high, b_scale); - b_vec_high_low = vmulq_n_f32(b_vec_high_low, b_scale); - b_vec_high_high = vmulq_n_f32(b_vec_high_high, b_scale); - - partial_sums[0] = vfmaq_n_f32(partial_sums[0], b_vec_low_low, a); - partial_sums[1] = vfmaq_n_f32(partial_sums[1], b_vec_low_high, a); - partial_sums[2] = vfmaq_n_f32(partial_sums[2], b_vec_high_low, a); - partial_sums[3] = vfmaq_n_f32(partial_sums[3], b_vec_high_high, a); -} - -void block_mul_1x16x4( - const float32_t* a, - const int8_t* b, - const size_t ldb, - const int8_t* b_zero_point, - const float* b_scale, - float32x4_t (&partial_sums)[4]) { - #pragma unroll(8) - for (int i = 0; i < 4; i++) { - int8x16_t b_vec = vld1q_s8(b + i * ldb); - block_mul_1x16x1(a[i], b_vec, b_zero_point[i], b_scale[i], partial_sums); - } -} - -} // namespace - -template -struct KernelImpl { - static void run( - int m, - int n, - int k, - const void* lhs, - int lhs_stride_m, - const void* rhs, - int rhs_stride_n, - float32_t* output, - int out_stride_m, - const int8_t* rhs_zero_points, - const float* rhs_scales, - const float beta, - const int rhs_qparams_stride); -}; - -/* -Document param meaning -rhs_stride_n: Since rhs transposed == false, the expected shape of rhs is k x n. -Thus rhs_stride_n is the stride of k dim, that how many bytes aparts elements -in k dim are. 
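[Aside, not part of the original patch] A scalar reference of what this kernel computes may help: fp32 lhs times an int8 rhs quantized per input channel (per k), with beta-weighted accumulation into the existing output. The names and the k x n row-major layout below are illustrative assumptions:

#include <cstdint>
#include <cstddef>

// out[j] = beta * out[j] + sum_k a[k] * (b[k][j] - zp[k]) * scale[k]
// where zp/scale are per input channel (per k), as in this kernel.
inline void reference_row_fp32_x_q8(
    const float* a,              // 1 x k
    const int8_t* b,             // k x n, row-major (ldb elements per k row)
    size_t k, size_t n, size_t ldb,
    const int8_t* zero_points,   // per k
    const float* scales,         // per k
    float beta,
    float* out) {                // 1 x n
  for (size_t j = 0; j < n; ++j) {
    float acc = 0.f;
    for (size_t i = 0; i < k; ++i) {
      acc += a[i] * float(int32_t(b[i * ldb + j]) - zero_points[i]) * scales[i];
    }
    out[j] = beta * out[j] + acc;
  }
}

Because beta accumulation reads the previous output, this is the one kernel in the directory where overlapping partial-tile writes matter; that is handled in the store path below.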
-*/ -template <> -struct KernelImpl { - static void run( - int m, - int n, - int k, - const float* lhs, - int lhs_stride_m, - const int8_t* rhs, - int rhs_stride_n, - float32_t* output, - int out_stride_m, - const int8_t* rhs_zero_points, - const float* rhs_scales, - const float beta, - const int rhs_qparams_stride) { - std::unique_ptr rhs_zero_points_transposed = std::make_unique(k); - std::unique_ptr rhs_scales_transposed = std::make_unique(k); - if (rhs_qparams_stride > 1) { - utils::transpose_scales_and_zero_points( - rhs_zero_points, - rhs_scales, - rhs_zero_points_transposed.get(), - rhs_scales_transposed.get(), - k, - rhs_qparams_stride); - rhs_zero_points = rhs_zero_points_transposed.get(); - rhs_scales = rhs_scales_transposed.get(); - } - - constexpr int nr = 16; - constexpr int kr = 4; - for (int m_idx = 0; m_idx < m; m_idx++) { - // Loop over 16 cols at a time - // Access to partial tiles must be protected:w - assert(n >= nr); - for (int n_idx = 0; n_idx < n; n_idx += nr) { - // If remaining is < nr, that must mean that (nr - remaining) items - // dont need to be computed. - // In order to avoid out-of-bounds access, we need to rewind n_indx a - // bit - // |-------------------|-------------------| - // 0-------------------8-------------------16 - // 0-------------------8-----10 - // If n = 10 and nr = 8 then at n_idx = 8, we need to rewind n_idx to - // 8 - (8 - 10) = 2 - int remaining = std::min(n - n_idx, nr); - n_idx = n_idx - (nr - remaining); - // Set activation_ptr to start of activation qvals for row m_idx - const float* lhs_ptr = lhs + m_idx * lhs_stride_m; - const int8_t* rhs_ptr = rhs + n_idx; - float32x4_t sums[nr / 4] = {vdupq_n_f32(0)}; - - // Loop k_idx by group - int k_idx = 0; - for (; (k_idx + kr) <= k; k_idx += kr) { - block_mul_1x16x4( - lhs_ptr, - rhs_ptr, - rhs_stride_n, - rhs_zero_points + k_idx, - rhs_scales + k_idx, - sums); - lhs_ptr += kr; - rhs_ptr += kr * rhs_stride_n; - } - - for (int ki = 0; ki < (k - k_idx); ++ki) { - // For each of the remaining k values - // Load 1 int8_t from lhs - // Load 16 int8_t from rhs - // And multiply + add into the 16 accumulators - // arranged as int32x4_t[4] - int8x16_t rhs_vec = vld1q_s8(rhs_ptr + ki * rhs_stride_n); - block_mul_1x16x1( - lhs_ptr[ki], - rhs_vec, - rhs_zero_points[k_idx + ki], - rhs_scales[k_idx + ki], - sums); - } - - // Store result - // Because we adjust n_idx, we may end up writing the same location - // twice - // Note that the reason this case is being handled only for this kernel - // and not others in this directory is because only for this kernel - // we support accumulation. - float* store_loc = output + m_idx * out_stride_m + n_idx; - if (remaining < 16) { - // If remaining is < 16, then not all of the 16 accumulators are - // valid. That is not all of float32x4_t[4] are valid. We need to - // find the first valid one, and then store the rest of the - // accumulators in the same order. - // First valid one is at 3 - ((remaining - 1) / 4) because: - // If remaining is say 10 then first 6 are not valid. - // Thus first group of 4 at sums[0] is not valid. - // In the second group of 4, the first 2 are not valid. - // Rest are valid. 
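[Aside, not part of the original patch] The partial-tile bookkeeping used just below, restated in scalar form with hypothetical names and assuming nr = 16 as in this kernel. After the rewind the tile ends exactly at column n, so its first (nr - remaining) lanes overlap columns already written by the previous tile and must be skipped; only the trailing `remaining` lanes are stored.

// Which of the four float32x4_t accumulator groups to start storing from,
// and how many lanes to skip at the front of that group.
struct PartialTile {
  int first_valid_group;    // index into sums[4]
  int skipped_lanes_in_it;  // invalid lanes at the front of that group
};

inline PartialTile partial_tile_layout(int remaining /* 1..15 */) {
  PartialTile t;
  t.first_valid_group = 3 - ((remaining - 1) / 4);
  t.skipped_lanes_in_it = 3 - ((remaining - 1) % 4);
  return t;
}
// e.g. remaining == 10 -> {1, 2}: sums[0] is skipped entirely, the first
// 2 lanes of sums[1] are skipped, and everything after that is stored.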
- int start_sum_idx = 3 - ((remaining - 1) / 4); - // If remaining is 11, then the sums[1] has 3 valid values - // so 3 - (11 -1) % 4 = 3 - 10 % 4 = 3 - 2 = 1 - // Thus there is 1 invalid value in the first group of 4 - int invalid_values_in_32x4_reg = 3 - (remaining - 1) % 4; - store_loc += start_sum_idx * 4; - store_loc += invalid_values_in_32x4_reg; - if (invalid_values_in_32x4_reg > 0) { - for (int val_idx = invalid_values_in_32x4_reg; val_idx < 4; - ++val_idx) { - *store_loc = sums[start_sum_idx][val_idx] + (*store_loc) * beta; - store_loc += 1; - } - start_sum_idx++; - } - for (int out_idx = 0, sum_idx = start_sum_idx; sum_idx < nr / 4; - out_idx += 4, ++sum_idx) { - float32x4_t sum_val = vld1q_f32(store_loc + out_idx); - sums[sum_idx] = vfmaq_n_f32(sums[sum_idx], sum_val, beta); - vst1q_f32(store_loc + out_idx, sums[sum_idx]); - } - } else { - for (int out_idx = 0, sum_idx = 0; out_idx < nr; - out_idx += 4, ++sum_idx) { - float32x4_t sum_val = vld1q_f32(store_loc + out_idx); - sums[sum_idx] = vfmaq_n_f32(sums[sum_idx], sum_val, beta); - vst1q_f32(store_loc + out_idx, sums[sum_idx]); - } - } - } // n_idx - } // m_idx - } -}; - -} // namespace fp32_a_input_channelwise_8bit_b_1x16x4_f32::internal - -namespace fp32_a_input_channelwise_8bit_b_1x16x4_f32 { -template -void kernel( - int m, - int n, - int k, - const float* lhs, - int lhs_stride_m, - const int8_t* rhs, - int rhs_stride_n, - float32_t* output, - int out_stride_m, - const int8_t* rhs_zero_points, - const float* rhs_scales, - const float beta, - const int rhs_qparams_stride) { - torchao::kernels::cpu::aarch64::quantized_matmul:: - fp32_a_input_channelwise_8bit_b_1x16x4_f32::internal:: - KernelImpl::run( - m, - n, - k, - lhs, - lhs_stride_m, - rhs, - rhs_stride_n, - output, - out_stride_m, - rhs_zero_points, - rhs_scales, - beta, - rhs_qparams_stride); -} -} // namespace fp32_a_input_channelwise_8bit_b_1x16x4_f32 -} // namespace torchao::kernels::cpu::aarch64::quantized_matmul - -#endif // defined(__aarch64__) && defined(__ARM_NEON) diff --git a/torchao/experimental/kernels/cpu/aarch64/matmul/fp32_a_input_channelwise_8bit_b_4x16x4_f32_impl.h b/torchao/experimental/kernels/cpu/aarch64/matmul/fp32_a_input_channelwise_8bit_b_4x16x4_f32_impl.h deleted file mode 100644 index a3dd44a10b..0000000000 --- a/torchao/experimental/kernels/cpu/aarch64/matmul/fp32_a_input_channelwise_8bit_b_4x16x4_f32_impl.h +++ /dev/null @@ -1,328 +0,0 @@ -// Copyright (c) Meta Platforms, Inc. and affiliates. -// All rights reserved. -// -// This source code is licensed under the license found in the -// LICENSE file in the root directory of this source tree. - -#pragma once - -#if defined(__aarch64__) && defined(__ARM_NEON) - -#include -#include -#include - -#include -#include -#include - -namespace torchao::kernels::cpu::aarch64::quantized_matmul { -namespace fp32_a_input_channelwise_8bit_b_4x16x4_f32::internal { - -namespace { - -/* -This function loads float32x4_t value from a, and 16 int8x16_t values from b. -For each int8x16_t of b: -- 4 float32x4 accumulated values -- load 4 a in float32x4_t -- [The following repeats for each of the 4 lanes of a] -- for i in [0, 4]: - - load b[i] in int8x16_t - - subl to subtarct b_zero_point from b, to get b_low, b_high - - vmovl to get b_low_low, b_low_high, b_high_low, b_high_high - - vcvtq to convert to float32x4_t, we will have 4 of these. 
-- for i in [0, 4]: for each of the 4 float32x4_t of b: - - vfmaq_lane_fp32 to multiply a[lane] and b[i] - - vfmaq_lane_fp32 to multiply a[lane] and b[i] - - vfmaq_lane_fp32 to multiply a[lane] and b[i] - - vfmaq_lane_fp32 to multiply a[lane] and b[i] -- By doing the above 4 times (lane=[0-3]), we used all values along k dim of a - and accumulated 4 float32x4_t values -*/ -TORCHAO_ALWAYS_INLINE inline void block_mul_4x16x1( - const float32x4_t& a, - const int8x16_t& b_vec, - const int8_t b_zero_point, - const float b_scale, - float32x4_t (&partial_sums)[4][4]) { - int8x8_t b_zero_point_vec = vdup_n_s8(b_zero_point); - int16x8_t b_vec_low = vsubl_s8(vget_low_s8(b_vec), b_zero_point_vec); - int16x8_t b_vec_high = vsubl_s8(vget_high_s8(b_vec), b_zero_point_vec); - float32x4_t b_vec_low_low = vcvtq_f32_s32(vmovl_s16(vget_low_s16(b_vec_low))); - float32x4_t b_vec_low_high = - vcvtq_f32_s32(vmovl_s16(vget_high_s16(b_vec_low))); - float32x4_t b_vec_high_low = - vcvtq_f32_s32(vmovl_s16(vget_low_s16(b_vec_high))); - float32x4_t b_vec_high_high = - vcvtq_f32_s32(vmovl_s16(vget_high_s16(b_vec_high))); - b_vec_low_low = vmulq_n_f32(b_vec_low_low, b_scale); - b_vec_low_high = vmulq_n_f32(b_vec_low_high, b_scale); - b_vec_high_low = vmulq_n_f32(b_vec_high_low, b_scale); - b_vec_high_high = vmulq_n_f32(b_vec_high_high, b_scale); - - partial_sums[0][0] = vfmaq_n_f32(partial_sums[0][0], b_vec_low_low, a[0]); - partial_sums[0][1] = vfmaq_n_f32(partial_sums[0][1], b_vec_low_high, a[0]); - partial_sums[0][2] = vfmaq_n_f32(partial_sums[0][2], b_vec_high_low, a[0]); - partial_sums[0][3] = vfmaq_n_f32(partial_sums[0][3], b_vec_high_high, a[0]); - - partial_sums[1][0] = vfmaq_n_f32(partial_sums[1][0], b_vec_low_low, a[1]); - partial_sums[1][1] = vfmaq_n_f32(partial_sums[1][1], b_vec_low_high, a[1]); - partial_sums[1][2] = vfmaq_n_f32(partial_sums[1][2], b_vec_high_low, a[1]); - partial_sums[1][3] = vfmaq_n_f32(partial_sums[1][3], b_vec_high_high, a[1]); - - partial_sums[2][0] = vfmaq_n_f32(partial_sums[2][0], b_vec_low_low, a[2]); - partial_sums[2][1] = vfmaq_n_f32(partial_sums[2][1], b_vec_low_high, a[2]); - partial_sums[2][2] = vfmaq_n_f32(partial_sums[2][2], b_vec_high_low, a[2]); - partial_sums[2][3] = vfmaq_n_f32(partial_sums[2][3], b_vec_high_high, a[2]); - - partial_sums[3][0] = vfmaq_n_f32(partial_sums[3][0], b_vec_low_low, a[3]); - partial_sums[3][1] = vfmaq_n_f32(partial_sums[3][1], b_vec_low_high, a[3]); - partial_sums[3][2] = vfmaq_n_f32(partial_sums[3][2], b_vec_high_low, a[3]); - partial_sums[3][3] = vfmaq_n_f32(partial_sums[3][3], b_vec_high_high, a[3]); -} - -TORCHAO_ALWAYS_INLINE inline void block_mul_4x16x4( - const float32_t* a, - const size_t lda, - const int8_t* b, - const size_t ldb, - const int8_t* b_zero_point, - const float* b_scale, - float32x4_t (&partial_sums)[4][4]) { - float32x4_t a_vec[4]; - utils::transpose_4x4(a, lda, a_vec); - - int8x16_t b_vec = vld1q_s8(b + 0 * ldb); - block_mul_4x16x1(a_vec[0], b_vec, b_zero_point[0], b_scale[0], partial_sums); - b_vec = vld1q_s8(b + 1 * ldb); - block_mul_4x16x1(a_vec[1], b_vec, b_zero_point[1], b_scale[1], partial_sums); - b_vec = vld1q_s8(b + 2 * ldb); - block_mul_4x16x1(a_vec[2], b_vec, b_zero_point[2], b_scale[2], partial_sums); - b_vec = vld1q_s8(b + 3 * ldb); - block_mul_4x16x1(a_vec[3], b_vec, b_zero_point[3], b_scale[3], partial_sums); -} - -} // namespace - -template -struct KernelImpl { - static void run( - int m, - int n, - int k, - const void* lhs, - int lhs_stride_m, - const void* rhs, - int rhs_stride_n, - float32_t* output, - 
int out_stride_m, - const int8_t* rhs_zero_points, - const float* rhs_scales, - const float beta, - const int rhs_qparams_stride); -}; - -/* -Document param meaning -rhs_stride_n: Since rhs transposed == false, the expected shape of rhs is k x n. -Thus rhs_stride_n is the stride of k dim, that how many bytes aparts elements -in k dim are. -*/ -template <> -struct KernelImpl { - static void run( - int m, - int n, - int k, - const float* lhs, - int lhs_stride_m, - const int8_t* rhs, - int rhs_stride_n, - float32_t* output, - int out_stride_m, - const int8_t* rhs_zero_points, - const float* rhs_scales, - const float beta, - const int rhs_qparams_stride) { - std::vector rhs_zero_points_transposed; - std::vector rhs_scales_transposed; - if (rhs_qparams_stride > 1) { - rhs_zero_points_transposed.resize(k); - rhs_scales_transposed.resize(k); - utils::transpose_scales_and_zero_points( - rhs_zero_points, - rhs_scales, - rhs_zero_points_transposed.data(), - rhs_scales_transposed.data(), - k, - rhs_qparams_stride); - rhs_zero_points = rhs_zero_points_transposed.data(); - rhs_scales = rhs_scales_transposed.data(); - } - - constexpr int mr = 4; - constexpr int nr = 16; - constexpr int kr = 4; - assert(m % mr == 0); - assert(kr == 4); - assert(n >= nr); - for (int m_idx = 0; m_idx < m; m_idx += mr) { - const float* lhs_ptr = lhs + m_idx * lhs_stride_m; - // Loop over 16 cols at a time - // Access to partial tiles must be protected - for (int n_idx = 0; n_idx < n; n_idx += nr) { - // If remaining is < nr, that must mean that (nr - remaining) items - // dont need to be computed. - // In order to avoid out-of-bounds access, we need to rewind n_indx a - // bit - // |-------------------|-------------------| - // 0-------------------8-------------------16 - // 0-------------------8-----10 - // If n = 10 and nr = 8 then at n_idx = 8, we need to rewind n_idx to - // 8 - (8 - 10) = 2 - int remaining = std::min(n - n_idx, nr); - n_idx = n_idx - (nr - remaining); - // Set activation_ptr to start of activation qvals for row m_idx - const int8_t* rhs_ptr = rhs + n_idx; - float32x4_t sums[mr][(nr / 4)] = {{vdupq_n_f32(0)}}; - - // Loop k_idx by group - int k_idx = 0; - const float* current_lhs_ptr = lhs_ptr; - for (; (k_idx + kr) <= k; k_idx += kr) { - block_mul_4x16x4( - current_lhs_ptr, - lhs_stride_m, - rhs_ptr, - rhs_stride_n, - rhs_zero_points + k_idx, - rhs_scales + k_idx, - sums); - current_lhs_ptr += kr; - rhs_ptr += kr * rhs_stride_n; - } - - for (int ki = 0; ki < (k - k_idx); ++ki) { - // For each of the remaining k values - // Load 1 int8_t from lhs - // Load 16 int8_t from rhs - // And multiply + add into the 16 accumulators - // arranged as int32x4_t[4] - int8x16_t rhs_vec = vld1q_s8(rhs_ptr + ki * rhs_stride_n); - float32x4_t lhs_vec = { - current_lhs_ptr[ki + 0 * lhs_stride_m], - current_lhs_ptr[ki + 1 * lhs_stride_m], - current_lhs_ptr[ki + 2 * lhs_stride_m], - current_lhs_ptr[ki + 3 * lhs_stride_m]}; - block_mul_4x16x1( - lhs_vec, - rhs_vec, - rhs_zero_points[k_idx + ki], - rhs_scales[k_idx + ki], - sums); - } - - // Store result - // Because we adjust n_idx, we may end up writing the same location - // twice - // Note that the reason this case is being handld only for this kernel - // and not others in this directory is because only for this kernel - // we support accumulation. - float* store_loc = output + m_idx * out_stride_m + n_idx; - if (remaining < 16) { - // If remaining is < 16, then not all of the 16 accumulators are - // valid. That is not all of float32x4_t[4] are valid. 
We need to - // find the first valid one, and then store the rest of the - // accumulators in the same order. - // First valid one is at 3 - ((remaining - 1) / 4) because: - // If remaining is say 10 then first 6 are not valid. - // Thus first group of 4 at sums[0] is not valid. - // In the second group of 4, the first 2 are not valid. - // Rest are valid. - int start_sum_idx = 3 - ((remaining - 1) / 4); - // If remaining is 11, then the sums[1] has 3 valid values - // so 3 - (11 -1) % 4 = 3 - 10 % 4 = 3 - 2 = 1 - // Thus there is 1 invalid value in the first group of 4 - int invalid_values_in_32x4_reg = 3 - (remaining - 1) % 4; - store_loc += start_sum_idx * 4; - store_loc += invalid_values_in_32x4_reg; - if (invalid_values_in_32x4_reg > 0) { - for (int m_out_idx = 0; m_out_idx < mr; m_out_idx++) { - float* store_loc_local = store_loc + m_out_idx * out_stride_m; - for (int val_idx = invalid_values_in_32x4_reg; val_idx < 4; - ++val_idx) { - *store_loc_local = sums[m_out_idx][start_sum_idx][val_idx] + - (*store_loc_local) * beta; - store_loc_local += 1; - } - } - start_sum_idx++; - store_loc += (4 - invalid_values_in_32x4_reg); - } - for (int m_out_idx = 0; m_out_idx < mr; m_out_idx++) { - float* store_loc_local = store_loc + m_out_idx * out_stride_m; - for (int out_idx = 0, sum_idx = start_sum_idx; sum_idx < nr / 4; - out_idx += 4, ++sum_idx) { - float32x4_t sum_val = vld1q_f32(store_loc_local + out_idx); - sums[m_out_idx][sum_idx] = - vfmaq_n_f32(sums[m_out_idx][sum_idx], sum_val, beta); - vst1q_f32(store_loc_local + out_idx, sums[m_out_idx][sum_idx]); - } - } - } else { - for (int m_out_idx = 0; m_out_idx < mr; m_out_idx++) { - float* store_loc_local = store_loc + m_out_idx * out_stride_m; - for (int out_idx = 0, sum_idx = 0; out_idx < nr; - out_idx += 4, ++sum_idx) { - float32x4_t sum_val = vld1q_f32(store_loc_local + out_idx); - sums[m_out_idx][sum_idx] = - vfmaq_n_f32(sums[m_out_idx][sum_idx], sum_val, beta); - vst1q_f32(store_loc_local + out_idx, sums[m_out_idx][sum_idx]); - } - } - } - } // n_idx - } // m_idx - } -}; - -} // namespace fp32_a_input_channelwise_8bit_b_4x16x4_f32::internal - -namespace fp32_a_input_channelwise_8bit_b_4x16x4_f32 { -template -void kernel( - int m, - int n, - int k, - const float* lhs, - int lhs_stride_m, - const int8_t* rhs, - int rhs_stride_n, - float32_t* output, - int out_stride_m, - const int8_t* rhs_zero_points, - const float* rhs_scales, - const float beta, - const int rhs_qparams_stride) { - torchao::kernels::cpu::aarch64::quantized_matmul:: - fp32_a_input_channelwise_8bit_b_4x16x4_f32::internal:: - KernelImpl::run( - m, - n, - k, - lhs, - lhs_stride_m, - rhs, - rhs_stride_n, - output, - out_stride_m, - rhs_zero_points, - rhs_scales, - beta, - rhs_qparams_stride); -} -} // namespace fp32_a_input_channelwise_8bit_b_4x16x4_f32 -} // namespace torchao::kernels::cpu::aarch64::quantized_matmul - -#endif // defined(__aarch64__) && defined(__ARM_NEON) diff --git a/torchao/experimental/kernels/cpu/aarch64/matmul/matmul.h b/torchao/experimental/kernels/cpu/aarch64/matmul/matmul.h deleted file mode 100644 index 86b14a52aa..0000000000 --- a/torchao/experimental/kernels/cpu/aarch64/matmul/matmul.h +++ /dev/null @@ -1,318 +0,0 @@ -// Copyright (c) Meta Platforms, Inc. and affiliates. -// All rights reserved. -// -// This source code is licensed under the license found in the -// LICENSE file in the root directory of this source tree. 
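The partial-tile handling in the 4x16x4 kernel above is easiest to follow with concrete numbers. The sketch below reproduces only the index arithmetic from its comments: when fewer than nr columns remain, n_idx is rewound by nr - remaining so all loads and stores stay in bounds, and the store path then works out which of the four float32x4_t accumulator groups, and which lanes of the first such group, overlap columns the previous tile already wrote. The names are mine and the program is illustrative only, not part of the deleted sources.

// partial_tile_math.cpp -- illustrative index arithmetic only.
#include <algorithm>
#include <cstdio>

int main() {
  constexpr int nr = 16;  // tile width, kept as four groups of 4 floats
  const int n = 19;       // example column count that is not a multiple of nr

  for (int n_idx = 0; n_idx < n; n_idx += nr) {
    int remaining = std::min(n - n_idx, nr);
    n_idx -= (nr - remaining);  // rewind so all loads/stores stay in bounds

    if (remaining < nr) {
      // First accumulator group that holds a column the previous tile did not
      // already write.
      int start_sum_idx = 3 - ((remaining - 1) / 4);
      // Leading lanes of that group that overlap the previous tile; the store
      // path skips them and writes the rest of the group scalar-wise.
      int skipped_lanes = 3 - ((remaining - 1) % 4);
      std::printf("n_idx=%d remaining=%d start_sum_idx=%d skipped_lanes=%d\n",
                  n_idx, remaining, start_sum_idx, skipped_lanes);
      // For n = 19 this prints: n_idx=3 remaining=3 start_sum_idx=3 skipped_lanes=1
    }
  }
  return 0;
}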
- -// TODO: this file will be deleted and replaced by -// torchao/experimental/kernels/cpu/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight/include.h -// It exists now to prevent breaking existing code in the interim. - -#pragma once - -#include -#if defined(__aarch64__) && defined(__ARM_NEON) - -#include - -namespace torchao::kernels::cpu::aarch64::quantized_matmul { -namespace channelwise_8bit_a_channelwise_8bit_b_1x8x16_f32_neondot { - -template < - bool a_has_zeros, - bool b_has_zeros, - bool a_transposed, - bool b_tranposed> -void kernel( - int m, - int n, - int k, - const void* lhs, - int lhs_stride_m, - const void* rhs, - int rhs_stride_n, - float32_t* output, - int out_stride_m, - const int8_t* lhs_zero_points, - const int8_t* rhs_zero_points, - const float* lhs_scales, - const float* rhs_scales, - const int lhs_qparams_stride, - const int rhs_qparams_stride); - -} // namespace channelwise_8bit_a_channelwise_8bit_b_1x8x16_f32_neondot - -namespace channelwise_8bit_a_channelwise_8bit_b_4x8x8_f32_neondot { - -template < - bool a_has_zeros, - bool b_has_zeros, - bool a_transposed, - bool b_tranposed> -void kernel( - int m, - int n, - int k, - const void* lhs, - int lhs_stride_m, - const void* rhs, - int rhs_stride_n, - float32_t* output, - int out_stride_m, - const int8_t* lhs_zero_points, - const int8_t* rhs_zero_points, - const float* lhs_scales, - const float* rhs_scales, - const int lhs_qparams_stride, - const int rhs_qparams_stride); - -} // namespace channelwise_8bit_a_channelwise_8bit_b_4x8x8_f32_neondot - -namespace channelwise_8bit_a_channelwise_8bit_b_f32 { - -template < - bool a_has_zeros, - bool b_has_zeros, - bool a_transposed, - bool b_tranposed> -void kernel( - int m, - int n, - int k, - const void* lhs, - int lhs_stride_m, - const void* rhs, - int rhs_stride_n, - float32_t* output, - int out_stride_m, - const int8_t* lhs_zero_points, - const int8_t* rhs_zero_points, - const float* lhs_scales, - const float* rhs_scales, - const int lhs_qparams_stride, - const int rhs_qparams_stride); - -template < - bool a_has_zeros, - bool b_has_zeros, - bool a_transposed, - bool b_tranposed> -void kernel( - int m, - int n, - int k, - const void* lhs, - int lhs_stride_m, - const void* rhs, - int rhs_stride_n, - float32_t* output, - int out_stride_m, - const int8_t* lhs_zero_points, - const int8_t* rhs_zero_points, - const float* lhs_scales, - const float* rhs_scales, - const int lhs_qparams_stride, - const int rhs_qparams_stride) { - // TODO: Replace this with KerneConfig based dispatch - constexpr size_t gemm_nr = 8; - constexpr size_t gemm_kr = 16; - if ((n % gemm_nr == 0) && (k % gemm_kr == 0) && m > 4) { - auto remaining_m = m % 4; - auto m_for_gemm_kernel = m - remaining_m; - channelwise_8bit_a_channelwise_8bit_b_4x8x8_f32_neondot:: - kernel( - m_for_gemm_kernel, - n, - k, - lhs, - lhs_stride_m, - rhs, - rhs_stride_n, - output, - out_stride_m, - lhs_zero_points, - rhs_zero_points, - lhs_scales, - rhs_scales, - lhs_qparams_stride, - rhs_qparams_stride); - output += m_for_gemm_kernel * out_stride_m; - lhs = (static_cast(lhs) + m_for_gemm_kernel * lhs_stride_m); - lhs_zero_points = lhs_zero_points + m_for_gemm_kernel * lhs_qparams_stride; - lhs_scales = lhs_scales + m_for_gemm_kernel * lhs_qparams_stride; - m = remaining_m; - } - if (m > 0) { - channelwise_8bit_a_channelwise_8bit_b_1x8x16_f32_neondot:: - kernel( - m, - n, - k, - lhs, - lhs_stride_m, - rhs, - rhs_stride_n, - output, - out_stride_m, - lhs_zero_points, - rhs_zero_points, - lhs_scales, - rhs_scales, - 
lhs_qparams_stride, - rhs_qparams_stride); - } -} - -} // namespace channelwise_8bit_a_channelwise_8bit_b_f32 - -namespace channelwise_8bit_a_channelwise_8bit_b_1x16x16_f32_smlal { - -template < - bool a_has_zeros, - bool b_has_zeros, - bool a_transposed, - bool b_tranposed> -void kernel( - int m, - int n, - int k, - const void* lhs, - int lhs_stride_m, - const void* rhs, - int rhs_stride_n, - float32_t* output, - int out_stride_m, - const int8_t* lhs_zero_points, - const int8_t* rhs_zero_points, - const float* lhs_scales, - const float* rhs_scales, - const int lhs_qparams_stride, - const int rhs_qparams_stride); - -} // namespace channelwise_8bit_a_channelwise_8bit_b_1x16x16_f32_smlal - -namespace fp32_a_input_channelwise_8bit_b_1x16x4_f32 { - -template -void kernel( - int m, - int n, - int k, - const float* lhs, - int lhs_stride_m, - const int8_t* rhs, - int rhs_stride_n, - float32_t* output, - int out_stride_m, - const int8_t* rhs_zero_points, - const float* rhs_scales, - const float beta, - const int rhs_qparams_stride); - -} // namespace fp32_a_input_channelwise_8bit_b_1x16x4_f32 - -namespace fp32_a_input_channelwise_8bit_b_4x16x4_f32 { - -template -void kernel( - int m, - int n, - int k, - const float* lhs, - int lhs_stride_m, - const int8_t* rhs, - int rhs_stride_n, - float32_t* output, - int out_stride_m, - const int8_t* rhs_zero_points, - const float* rhs_scales, - const float beta, - const int rhs_qparams_stride); - -} // namespace fp32_a_input_channelwise_8bit_b_4x16x4_f32 - -namespace fp32_a_input_channelwise_8bit_b_f32 { - -template -void kernel( - int m, - int n, - int k, - const float* lhs, - int lhs_stride_m, - const int8_t* rhs, - int rhs_stride_n, - float32_t* output, - int out_stride_m, - const int8_t* rhs_zero_points, - const float* rhs_scales, - const float beta, - const int rhs_qparams_stride); - -template -void kernel( - int m, - int n, - int k, - const float* lhs, - int lhs_stride_m, - const int8_t* rhs, - int rhs_stride_n, - float32_t* output, - int out_stride_m, - const int8_t* rhs_zero_points, - const float* rhs_scales, - const float beta, - const int rhs_qparams_stride) { - assert(n >= 16); - if (m > 16) { - auto remaining_m = m % 16; - auto m_for_gemm_kernel = m - remaining_m; - fp32_a_input_channelwise_8bit_b_4x16x4_f32:: - kernel( - m_for_gemm_kernel, - n, - k, - lhs, - lhs_stride_m, - rhs, - rhs_stride_n, - output, - out_stride_m, - rhs_zero_points, - rhs_scales, - beta, - rhs_qparams_stride); - output += m_for_gemm_kernel * out_stride_m; - lhs += m_for_gemm_kernel * lhs_stride_m; - m = remaining_m; - } - if (m > 0) { - fp32_a_input_channelwise_8bit_b_1x16x4_f32:: - kernel( - m, - n, - k, - lhs, - lhs_stride_m, - rhs, - rhs_stride_n, - output, - out_stride_m, - rhs_zero_points, - rhs_scales, - beta, - rhs_qparams_stride); - } -} - -} // namespace fp32_a_input_channelwise_8bit_b_f32 -} // namespace torchao::kernels::cpu::aarch64::quantized_matmul - -#include -#include -#include -#include -#include - -#endif // defined(__aarch64__) && defined(__ARM_NEON) diff --git a/torchao/experimental/kernels/cpu/aarch64/matmul/matmul_utils.h b/torchao/experimental/kernels/cpu/aarch64/matmul/matmul_utils.h deleted file mode 100644 index 0a3c8463a8..0000000000 --- a/torchao/experimental/kernels/cpu/aarch64/matmul/matmul_utils.h +++ /dev/null @@ -1,153 +0,0 @@ -// Copyright (c) Meta Platforms, Inc. and affiliates. -// All rights reserved. -// -// This source code is licensed under the license found in the -// LICENSE file in the root directory of this source tree. 
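Both dispatch wrappers in matmul.h follow the same pattern: run the blocked GEMM kernel on as many leading rows as it supports, then pass the leftover rows, with every row-indexed pointer and per-row quantization parameter advanced accordingly, to the single-row kernel. A minimal sketch of just that partitioning arithmetic, with hypothetical names and not tied to either wrapper, is below.

// m_split_sketch.cpp -- illustrative row partitioning only.
#include <cassert>
#include <cstdio>

// Rows the blocked kernel can take when it processes mr rows per iteration.
int rows_for_blocked_kernel(int m, int mr) {
  return m - (m % mr);
}

int main() {
  const int m = 11, mr = 4;
  const int m_gemm = rows_for_blocked_kernel(m, mr);  // 8 rows for the 4-row kernel
  const int m_gemv = m - m_gemm;                      // 3 rows for the 1-row kernel
  assert(m_gemm % mr == 0 && m_gemm + m_gemv == m);

  // Before the second call, each row-indexed pointer advances by m_gemm rows,
  // e.g. output by m_gemm * out_stride_m, lhs by m_gemm * lhs_stride_m, and the
  // lhs scales/zero points by m_gemm * lhs_qparams_stride.
  std::printf("blocked-kernel rows: %d, row-wise-kernel rows: %d\n", m_gemm, m_gemv);
  return 0;
}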
- -#pragma once - -#if defined(__aarch64__) || defined(__ARM_NEON) - -#include -#include -#include -#include - -namespace torchao::kernels::cpu::aarch64::quantized_matmul { -namespace utils { - -TORCHAO_ALWAYS_INLINE static void transpose_scales_and_zero_points( - const int8_t* zero_points, - const float* scales, - int8_t* zero_points_transposed, - float* scales_transposed, - const int m, - const int stride_m) { - // Process 8 elements at a time using NEON - int i = 0; - for (; i + 8 <= m; i += 8) { - // Load 8 zero points with stride_m - int8x8_t zp = { - zero_points[0 * stride_m], - zero_points[1 * stride_m], - zero_points[2 * stride_m], - zero_points[3 * stride_m], - zero_points[4 * stride_m], - zero_points[5 * stride_m], - zero_points[6 * stride_m], - zero_points[7 * stride_m]}; - zero_points += 8 * stride_m; - // Store contiguously - vst1_s8(zero_points_transposed + i, zp); - - // Load 8 scales with stride_m - float32x4_t scales_lo = { - scales[0 * stride_m], - scales[1 * stride_m], - scales[2 * stride_m], - scales[3 * stride_m]}; - float32x4_t scales_hi = { - scales[4 * stride_m], - scales[5 * stride_m], - scales[6 * stride_m], - scales[7 * stride_m]}; - scales += 8 * stride_m; - // Store contiguously - vst1q_f32(scales_transposed + i, scales_lo); - vst1q_f32(scales_transposed + i + 4, scales_hi); - } - - // Handle remaining elements - for (; i < m; i++) { - zero_points_transposed[i] = zero_points[0]; - scales_transposed[i] = scales[0]; - zero_points += stride_m; - scales += stride_m; - } -} - -void transpose_4x4( - const float32_t* a, - const size_t lda, - float32x4_t (&tranposed)[4]); - -TORCHAO_ALWAYS_INLINE inline void transpose_4x4( - const float32_t* a, - const size_t lda, - float32x4_t (&tranposed)[4]) { - float32x4_t a_vec_0 = vld1q_f32(a + 0 * lda); - float32x4_t a_vec_1 = vld1q_f32(a + 1 * lda); - float32x4_t a_vec_2 = vld1q_f32(a + 2 * lda); - float32x4_t a_vec_3 = vld1q_f32(a + 3 * lda); - // Transpose the 4x4 matrix formed by a_vec_0, a_vec_1, a_vec_2, a_vec_3 - float32x4x2_t a01 = vtrnq_f32(a_vec_0, a_vec_1); - float32x4x2_t a23 = vtrnq_f32(a_vec_2, a_vec_3); - - float32x4_t a_vec_0_t = - vcombine_f32(vget_low_f32(a01.val[0]), vget_low_f32(a23.val[0])); - float32x4_t a_vec_1_t = - vcombine_f32(vget_low_f32(a01.val[1]), vget_low_f32(a23.val[1])); - float32x4_t a_vec_2_t = - vcombine_f32(vget_high_f32(a01.val[0]), vget_high_f32(a23.val[0])); - float32x4_t a_vec_3_t = - vcombine_f32(vget_high_f32(a01.val[1]), vget_high_f32(a23.val[1])); - - tranposed[0] = a_vec_0_t; - tranposed[1] = a_vec_1_t; - tranposed[2] = a_vec_2_t; - tranposed[3] = a_vec_3_t; -} - -void pack_kxn_b_matrix_for_mx8_dotprod_ukernel( - const float32_t* a, - const size_t lda, - float32_t* b, - const size_t n, - const size_t k); - -// Really dong what xnnpack is doing -void pack_kxn_b_matrix_for_mx8_dotprod_ukernel( - const float32_t* a, - const size_t lda, - float32_t* b, - const size_t n, - const size_t k) { - assert(n % 8 == 0); - assert(k % 4 == 0); - // Transpose the matrix in 4x4 blocks - size_t packed_block_stride = 8 * k; - constexpr size_t block_stride_8x4 = 8 * 4; - for (size_t i = 0; i < n; i += 8) { - float32_t* b_ptr = b + (i / 8) * packed_block_stride; - for (size_t j = 0; j < k; j += 4) { - // Get the transposed 4x4 block - float32x4_t transposed_block0[4]; - float32x4_t transposed_block1[4]; - // This transposes the a[i: i + 4, j: j + 4] - // Thus tranposed_block0[0] = a[j: i: i + 4] - // Thus tranposed_block0[1] = a[j + 1: i: i + 4] - transpose_4x4(a + (i + 0) * lda + j, lda, 
transposed_block0); - // This transposes the a[i + 4: i + 8, j: j + 4] - // Thus tranposed_block1[0] = a[j: i + 4 : i + 8] - // Thus tranposed_block1[1] = a[j + 1: i + 4 : i + 8] - transpose_4x4(a + (i + 4) * lda + j, lda, transposed_block1); - - // Once you have 8x4 matrix of 32bit values transposed - // Store them by writing two adjucent 1x4 blocks so that - // all of the 8 values from n dim are together. - // Then pack the next set of k values. - float32_t* b_ptr_local = b_ptr + (j / 4) * block_stride_8x4; -#pragma unroll(4) - for (size_t ki = 0; ki < 4; ki++) { - float32_t* b_ptr_local_k = b_ptr_local + ki * 8; - vst1q_f32(b_ptr_local_k, transposed_block0[ki]); - vst1q_f32( - b_ptr_local_k + sizeof(float32x4_t) / 4, transposed_block1[ki]); - } - } - } -} -} // namespace utils -} // namespace torchao::kernels::cpu::aarch64::quantized_matmul - -#endif // defined(__aarch64__) || defined(__ARM_NEON) diff --git a/torchao/experimental/kernels/cpu/fallback/matmul/channelwise_8bit_a_channelwise_8bit_b.h b/torchao/experimental/kernels/cpu/fallback/matmul/channelwise_8bit_a_channelwise_8bit_b.h deleted file mode 100644 index 3b070eb2b3..0000000000 --- a/torchao/experimental/kernels/cpu/fallback/matmul/channelwise_8bit_a_channelwise_8bit_b.h +++ /dev/null @@ -1,133 +0,0 @@ -// Copyright (c) Meta Platforms, Inc. and affiliates. -// All rights reserved. -// -// This source code is licensed under the license found in the -// LICENSE file in the root directory of this source tree. - -#pragma once - -#include - -namespace torchao::kernels::cpu::fallback::quantized_matmul { -namespace channelwise_8bit_a_channelwise_8bit_b::internal { - -template < - bool a_has_zeros, - bool b_has_zeros, - bool a_transposed, - bool b_tranposed> -struct KernelImpl { - static void run( - int m, - int n, - int k, - const void* lhs, - int lhs_stride_m, - const void* rhs, - int rhs_stride_n, - float* output, - int out_stride_m, - const int8_t* lhs_zero_points, - const int8_t* rhs_zero_points, - const float* lhs_scales, - const float* rhs_scales, - const int lhs_qparams_stride, - const int rhs_qparams_stride); -}; - -template -struct KernelImpl { - static void run( - int m, - int n, - int k, - const void* lhs, - int lhs_stride_m, - const void* rhs, - int rhs_stride_n, - float* output, - int out_stride_m, - const int8_t* lhs_zero_points, - const int8_t* rhs_zero_points, - const float* lhs_scales, - const float* rhs_scales, - const int lhs_qparams_stride, - const int rhs_qparams_stride) { - const int8_t* lhs_qvals = static_cast(lhs); - const int8_t* rhs_qvals = static_cast(rhs); - for (int m_idx = 0; m_idx < m; m_idx++) { - for (int n_idx = 0; n_idx < n; n_idx++) { - float res = 0.0; - for (int k_idx = 0; k_idx < k; k_idx++) { - int lhs_idx = m_idx * lhs_stride_m + k_idx; - int rhs_idx = k_idx * rhs_stride_n + n_idx; - if (b_transposed) { - rhs_idx = n_idx * rhs_stride_n + k_idx; - } - - float lhs_dequant = lhs_scales[m_idx * lhs_qparams_stride] * - (static_cast(lhs_qvals[lhs_idx]) - - static_cast( - lhs_zero_points[m_idx * lhs_qparams_stride])); - - float rhs_dequant = rhs_scales[n_idx * rhs_qparams_stride] * - (static_cast(rhs_qvals[rhs_idx]) - - static_cast( - rhs_zero_points[n_idx * rhs_qparams_stride])); - - res += lhs_dequant * rhs_dequant; - } - output[m_idx * n + n_idx] = res; - } - } - } -}; - -} // namespace - // channelwise_8bit_a_channelwise_8bit_b::internal -} // namespace torchao::kernels::cpu::fallback::quantized_matmul - -// TODO: Remove all ::kernels. No need for extra namespace. 
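The NEON packing routine above (pack_kxn_b_matrix_for_mx8_dotprod_ukernel) rearranges a source indexed as n rows of k floats into panels of 8 source rows in which the 8 values sharing a k index are stored contiguously, ordered k = 0, 1, ..., k - 1. A scalar reference of that layout, written from the index math in the code and assuming the same n % 8 == 0 and k % 4 == 0 preconditions, is sketched below; the (j / 4) * 32 + (j % 4) * 8 offsets of the vector version collapse to j * 8 here. The function name is mine.

// pack_reference.cpp -- illustrative scalar model of the packed layout.
#include <cassert>
#include <cstddef>

// src has n rows of k floats with row stride lda (n % 8 == 0, k % 4 == 0).
// dst is grouped into panels of 8 source rows; within a panel, the 8 values
// that share a k index are contiguous.
void pack_8row_panels(const float* src, size_t lda,
                      float* dst, size_t n, size_t k) {
  assert(n % 8 == 0 && k % 4 == 0);
  for (size_t i = 0; i < n; ++i) {
    for (size_t j = 0; j < k; ++j) {
      dst[(i / 8) * 8 * k + j * 8 + (i % 8)] = src[i * lda + j];
    }
  }
}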
-namespace torchao::kernels::cpu::fallback::quantized_matmul { -namespace channelwise_8bit_a_channelwise_8bit_b { -template < - bool a_has_zeros, - bool b_has_zeros, - bool a_transposed, - bool b_transposed> -void kernel( - int m, - int n, - int k, - const void* lhs, - int lhs_stride_m, - const void* rhs, - int rhs_stride_n, - float* output, - int out_stride_m, - const int8_t* lhs_zero_points, - const int8_t* rhs_zero_points, - const float* lhs_scales, - const float* rhs_scales, - const int lhs_qparams_stride, - const int rhs_qparams_stride) { - channelwise_8bit_a_channelwise_8bit_b::internal:: - KernelImpl::run( - m, - n, - k, - lhs, - lhs_stride_m, - rhs, - rhs_stride_n, - output, - out_stride_m, - lhs_zero_points, - rhs_zero_points, - lhs_scales, - rhs_scales, - lhs_qparams_stride, - rhs_qparams_stride); -} -} // namespace channelwise_8bit_a_channelwise_8bit_b -} // namespace torchao::kernels::cpu::fallback::quantized_matmul diff --git a/torchao/experimental/kernels/cpu/fallback/matmul/fp32_a_channelwise_8bit_b_fp32_c.h b/torchao/experimental/kernels/cpu/fallback/matmul/fp32_a_channelwise_8bit_b_fp32_c.h deleted file mode 100644 index 58e2853617..0000000000 --- a/torchao/experimental/kernels/cpu/fallback/matmul/fp32_a_channelwise_8bit_b_fp32_c.h +++ /dev/null @@ -1,50 +0,0 @@ -// Copyright (c) Meta Platforms, Inc. and affiliates. -// All rights reserved. -// -// This source code is licensed under the license found in the -// LICENSE file in the root directory of this source tree. - -#pragma once - -#include - -// TODO: Remove all ::kernels. No need for extra namespace. -namespace torchao::kernels::cpu::fallback::quantized_matmul { -namespace fp32_a_input_channelwise_8bit_b_fp32 { -template -void kernel( - int m, - int n, - int k, - const float* lhs, - int lhs_stride_m, - const int8_t* rhs, - int rhs_stride_n, - float* output, - int out_stride_m, - const int8_t* rhs_zero_points, - const float* rhs_scales, - const float beta, - const int rhs_qparams_stride) { - assert(a_transposed == false); - for (int m_idx = 0; m_idx < m; m_idx++) { - for (int n_idx = 0; n_idx < n; n_idx++) { - float res = 0.0; - for (int k_idx = 0; k_idx < k; k_idx++) { - int lhs_idx = m_idx * lhs_stride_m + k_idx; - int rhs_idx = k_idx * rhs_stride_n + n_idx; - if (b_transposed) { - rhs_idx = n_idx * rhs_stride_n + k_idx; - } - float rhs_dequant = rhs_scales[k_idx * rhs_qparams_stride] * - (static_cast(rhs[rhs_idx]) - - static_cast(rhs_zero_points[k_idx * rhs_qparams_stride])); - - res += lhs[lhs_idx] * rhs_dequant; - } - output[m_idx * n + n_idx] = output[m_idx * n + n_idx] * beta + res; - } - } -} -} // namespace fp32_a_input_channelwise_8bit_b_fp32 -} // namespace torchao::kernels::cpu::fallback::quantized_matmul diff --git a/torchao/experimental/kernels/cpu/interface/quantized_matmul.h b/torchao/experimental/kernels/cpu/interface/quantized_matmul.h deleted file mode 100644 index 826fe9e85b..0000000000 --- a/torchao/experimental/kernels/cpu/interface/quantized_matmul.h +++ /dev/null @@ -1,156 +0,0 @@ -// Copyright (c) Meta Platforms, Inc. and affiliates. -// All rights reserved. -// -// This source code is licensed under the license found in the -// LICENSE file in the root directory of this source tree. - -#pragma once - -#include - -#include -#include - -#if defined(__aarch64__) && defined(__ARM_NEON) -#include -#endif // defined(__aarch64__) && defined(__ARM_NEON) - -namespace torchao::kernels::cpu::quantized_matmul { - -/* -a_stride_m: stride of a in memory to indiciate how far apart each row is. 
-b_stride_n: stride of b in memory to indiciate how far apart each row is. -If b is transposed (n x k), then this is how many bytes to skip to get to the -next row. If b is not transposed (k x n), then this is how many bytes to skip to -get to the next row. - -It also returns the stride of a and b, that should be used in the kernel. - -Will need to think of a better way to find the right -ukernel. Perhaps via ukernelconfig + registry?. -*/ -using int8_a_int8_b_channelwise_fp32_c_qmatmul_type = void (*)( - int, - int, - int, - const void*, - int, - const void*, - int, - float*, - int, - const int8_t*, - const int8_t*, - const float*, - const float*, - const int, - const int); - -int8_a_int8_b_channelwise_fp32_c_qmatmul_type -get_int8_a_int8_b_channelwise_qmatmul( - int m, - int n, - int k, - bool a_transposed, - bool b_transposed, - int& a_stride_m, - int& b_stride_n); - -int8_a_int8_b_channelwise_fp32_c_qmatmul_type -get_int8_a_int8_b_channelwise_qmatmul( - int m, - int n, - int k, - bool a_transposed, - bool b_transposed, - int& a_stride_m, - int& b_stride_n) { -#if defined(__aarch64__) && defined(__ARM_NEON) - if (!a_transposed && b_transposed && n >= 8) { - a_stride_m = k; - b_stride_n = k; - return aarch64::quantized_matmul:: - channelwise_8bit_a_channelwise_8bit_b_f32:: - kernel; - } -#endif // defined(__aarch64__) && defined(__ARM_NEON) - assert(!a_transposed); - if (b_transposed) { - a_stride_m = k; - b_stride_n = k; - return torchao::kernels::cpu::fallback::quantized_matmul:: - channelwise_8bit_a_channelwise_8bit_b::kernel; - } else { - return torchao::kernels::cpu::fallback::quantized_matmul:: - channelwise_8bit_a_channelwise_8bit_b::kernel; - } -} - -/* -a_stride_m: stride of a in memory to indiciate how far apart each row is. -b_stride_n: stride of b in memory to indiciate how far apart each row is. -If b is transposed (n x k), then this is how many bytes to skip to get to the -next row. If b is not transposed (k x n), then this is how many bytes to skip to -get to the next row. - -It also returns the stride of a and b, that should be used in the kernel. - -Will need to think of a better way to find the right -ukernel. Perhaps via ukernelconfig + registry?. 
-*/ -using fp32_a_input_channelwise_8bit_b_f32_c_matmul_type = void (*)( - int, - int, - int, - const float*, - int, - const int8_t*, - int, - float*, - int, - const int8_t*, - const float*, - const float, - const int); - -fp32_a_input_channelwise_8bit_b_f32_c_matmul_type -get_fp32_a_input_channelwise_8bit_b_f32_c_matmul( - int m, - int n, - int k, - bool a_transposed, - bool b_transposed, - int& a_stride_m, - int& b_stride_n); - -fp32_a_input_channelwise_8bit_b_f32_c_matmul_type -get_fp32_a_input_channelwise_8bit_b_f32_c_matmul( - int m, - int n, - int k, - bool a_transposed, - bool b_transposed, - int& a_stride_m, - int& b_stride_n) { -#if defined(__aarch64__) && defined(__ARM_NEON) - if (!a_transposed && !b_transposed && n >= 16) { - a_stride_m = k; - b_stride_n = n; - return aarch64::quantized_matmul::fp32_a_input_channelwise_8bit_b_f32:: - kernel; - } -#endif // defined(__aarch64__) && defined(__ARM_NEON) - assert(!a_transposed); - if (b_transposed) { - a_stride_m = k; - b_stride_n = k; - return torchao::kernels::cpu::fallback::quantized_matmul:: - fp32_a_input_channelwise_8bit_b_fp32::kernel; - } else { - a_stride_m = k; - b_stride_n = n; - return torchao::kernels::cpu::fallback::quantized_matmul:: - fp32_a_input_channelwise_8bit_b_fp32::kernel; - } -} -} // namespace torchao::kernels::cpu::quantized_matmul diff --git a/torchao/experimental/kernels/cpu/interface/test_qmatmul_interface.cpp b/torchao/experimental/kernels/cpu/interface/test_qmatmul_interface.cpp deleted file mode 100644 index 0fbe33ccdc..0000000000 --- a/torchao/experimental/kernels/cpu/interface/test_qmatmul_interface.cpp +++ /dev/null @@ -1,658 +0,0 @@ -// Copyright (c) Meta Platforms, Inc. and affiliates. -// All rights reserved. -// -// This source code is licensed under the license found in the -// LICENSE file in the root directory of this source tree. - -#include -#include -#include -#include -#include - -#include -#include - -float kTol = 0.0001; - -// This is unfortunately had to be copied over because code in test_utils.h -// depends on quantization kernels which are only buildable for ARM. -// I would like the testing code in this folder to be independent of the arch. 
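A caller of the interface above was expected to fetch a kernel through one of the getters, let it report the strides it wants, and then invoke the returned function pointer; the tests below do exactly that with generated data. The trimmed-down sketch here shows only the shape of such a call for the fp32-activation path, assuming the (now deleted) header were still available at its old include path and that rhs is a k x n, row-major int8 matrix with one scale and zero point per input channel; buffer contents are left unfilled.

// usage_sketch.cpp -- shape of a call through the deleted interface header.
#include <cstdint>
#include <vector>

#include <torchao/experimental/kernels/cpu/interface/quantized_matmul.h>

void run_fp32_a_int8_b(int m, int n, int k) {
  std::vector<float> lhs(m * k);            // fp32 activations, m x k, row-major
  std::vector<int8_t> rhs(k * n);           // int8 weights, k x n, row-major
  std::vector<int8_t> rhs_zeros(k);         // one zero point per input channel (k)
  std::vector<float> rhs_scales(k, 1.0f);   // one scale per input channel (k)
  std::vector<float> out(m * n, 0.0f);

  int a_stride_m = 0, b_stride_n = 0;
  auto kernel = torchao::kernels::cpu::quantized_matmul::
      get_fp32_a_input_channelwise_8bit_b_f32_c_matmul(
          m, n, k, /*a_transposed=*/false, /*b_transposed=*/false,
          a_stride_m, b_stride_n);

  // beta = 0.0f overwrites the output; beta = 1.0f accumulates into it.
  kernel(m, n, k,
         lhs.data(), a_stride_m,
         rhs.data(), b_stride_n,
         out.data(), /*out_stride_m=*/n,
         rhs_zeros.data(), rhs_scales.data(),
         /*beta=*/0.0f, /*rhs_qparams_stride=*/1);
}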
-namespace { -void get_qvals_range(int& qmin, int& qmax, int nbit, bool is_symmetric) { - if (is_symmetric) { - qmin = -(1 << (nbit - 1)) + 1; - qmax = -qmin; - } else { - qmin = -(1 << (nbit - 1)); - qmax = (1 << (nbit - 1)) - 1; - } -} - -void get_scale_and_zero( - float& scale, - int& zero, - float vmin, - float vmax, - int qmin, - int qmax) { - assert(qmin < qmax); - assert(vmin < vmax); - scale = (vmax - vmin) / (qmax - qmin); - zero = qmin - std::round(vmin / scale); -} - -inline std::vector -get_random_vector(int size, float min = -1.0, float max = 1.0) { - assert(min < max); - std::random_device random_device; - auto rng = std::mt19937(random_device()); - auto dist = std::bind(std::uniform_real_distribution(min, max), rng); - std::vector res(size); - std::generate(res.begin(), res.end(), std::ref(dist)); - return res; -} - -void quantize( - // Output - int8_t* qvals, - // Inputs - const float* vals, - int size, - float scale, - int8_t zero, - int8_t qmin, - int8_t qmax) { - float invScale = 1.0 / (scale + 1e-16); - int i = 0; - auto curr_rounding_mode = fegetround(); - fesetround(FE_TONEAREST); - for (; i < size; ++i) { - // Quantize remaining elements using scalar code - float val = vals[i]; - float qval_f32 = zero + val * invScale; - int32_t qval_s32 = static_cast(std::nearbyint(qval_f32)); - - // Clip to qmin and qmax - qval_s32 = std::max( - static_cast(qmin), - std::min(qval_s32, static_cast(qmax))); - - // Store the quantized value - qvals[i] = static_cast(qval_s32); - } - fesetround(int(curr_rounding_mode)); -} - -auto generate_per_token_quantized_tensor( - int m, - int n, - bool transposed = false) { - auto activations = get_random_vector(m * n, -1.0, 1.0); - auto activation_qvals = std::vector(m * n, 0); - auto activation_scales = std::vector(m, 0); - auto activation_zeros = std::vector(m, 0); - - // Quantize activations with 8-bit asymmetric - // TODO: replace with generic function that does not use aarch64 - // quantize method after we combine with torchao - int qmin, qmax, zero; - float vmin, vmax, scale; - get_qvals_range(qmin, qmax, /*nbit=*/8, /*is_symmetric=*/false); - for (int m_idx = 0; m_idx < m; m_idx++) { - auto minmax = std::minmax_element( - activations.data() + m_idx * n, activations.data() + (m_idx + 1) * n); - vmin = *minmax.first; - vmax = *minmax.second; - get_scale_and_zero(scale, zero, vmin, vmax, qmin, qmax); - activation_scales[m_idx] = scale; - activation_zeros[m_idx] = zero; - quantize( - /*qvals=*/activation_qvals.data() + m_idx * n, - /*vals=*/activations.data() + m_idx * n, - /*size=*/n, - scale, - zero, - qmin, - qmax); - } - - if (transposed) { - auto activations_t = std::vector(m * n, 0); - auto activation_qvals_t = std::vector(m * n, 0); - for (int m_idx = 0; m_idx < m; m_idx++) { - for (int n_idx = 0; n_idx < n; n_idx++) { - int activation_idx = m_idx * n + n_idx; - int tranposed_idx = n_idx * m + m_idx; - activations_t[tranposed_idx] = activations[activation_idx]; - activation_qvals_t[tranposed_idx] = activation_qvals[activation_idx]; - } - } - activations = activations_t; - activation_qvals = activation_qvals_t; - } - - return std::make_tuple( - activations, activation_qvals, activation_scales, activation_zeros); -} - -struct channelwise_8bit_a_channelwise_8bit_b_qmatmul_test_case { - int m; - int k; - int n; - int stride; - - bool lhs_has_zeros; - bool rhs_has_zeros; - bool lhs_is_transposed; - bool rhs_is_transposed; - - std::vector expected_output; - - std::vector lhs; - std::vector lhs_qvals; - std::vector lhs_scales; - std::vector 
lhs_zeros; - - std::vector rhs; - std::vector rhs_qvals; - std::vector rhs_scales; - std::vector rhs_zeros; - - channelwise_8bit_a_channelwise_8bit_b_qmatmul_test_case( - int m_, - int k_, - int n_, - int stride_, - bool lhs_has_zeros_, - bool rhs_has_zeros_, - bool lhs_is_transposed_, - bool rhs_is_transposed_, - std::vector expected_output_, - std::vector lhs_, - std::vector lhs_qvals_, - std::vector lhs_scales_, - std::vector lhs_zeros_, - std::vector rhs_, - std::vector rhs_qvals_, - std::vector rhs_scales_, - std::vector rhs_zeros_) - : m(m_), - k(k_), - n(n_), - stride(stride_), - lhs_has_zeros(lhs_has_zeros_), - rhs_has_zeros(rhs_has_zeros_), - lhs_is_transposed(lhs_is_transposed_), - rhs_is_transposed(rhs_is_transposed_), - expected_output(expected_output_), - lhs(lhs_), - lhs_qvals(lhs_qvals_), - lhs_scales(lhs_scales_), - lhs_zeros(lhs_zeros_), - rhs(rhs_), - rhs_qvals(rhs_qvals_), - rhs_scales(rhs_scales_), - rhs_zeros(rhs_zeros_) { - assert(expected_output.size() == m * n); - assert(lhs.size() == m * stride * k); - assert(lhs_qvals.size() == m * stride * k); - assert(lhs_scales.size() == m * stride); - assert(lhs_zeros.size() == m * stride); - assert(rhs.size() == n * stride * k); - assert(rhs_qvals.size() == n * stride * k); - assert(rhs_scales.size() == n * stride); - assert(rhs_zeros.size() == n * stride); - } - - static channelwise_8bit_a_channelwise_8bit_b_qmatmul_test_case generate( - int m, - int k, - int n, - bool lhs_has_zeros, - bool rhs_has_zeros, - bool lhs_is_transposed, - // rhs_is_transposed means generated b matrix is mxk instead of kxm - bool rhs_is_transposed, - int stride = 1) { - assert(!lhs_is_transposed); - assert(lhs_has_zeros); - assert(rhs_has_zeros); - assert(rhs_is_transposed || stride == 1); - // Generate activations - auto [lhs, lhs_qvals, lhs_scales, lhs_zeros] = - generate_per_token_quantized_tensor(m * stride, k); - - auto [rhs, rhs_qvals, rhs_scales, rhs_zeros] = - generate_per_token_quantized_tensor(n * stride, k, !rhs_is_transposed); - // Above function produces nxk matrix and to produce kxn you need transposed - // = true. we do !rhs_is_transposed becaues when rhs_is_transposed = true - // the shape should be nxk instead of kxn. 
- - // Compute expected output - std::vector expected_output(m * n); - - for (int m_idx = 0; m_idx < m; m_idx++) { - for (int n_idx = 0; n_idx < n; n_idx++) { - float res = 0.0; - for (int k_idx = 0; k_idx < k; k_idx++) { - int lhs_idx = m_idx * stride * k + k_idx; - int rhs_idx = k_idx * stride * n + n_idx * stride; - if (rhs_is_transposed) { - rhs_idx = n_idx * stride * k + k_idx; - } - float lhs_dequant = lhs_scales[m_idx * stride] * - (lhs_qvals[lhs_idx] - lhs_zeros[m_idx * stride]); - - float rhs_dequant = rhs_scales[n_idx * stride] * - (rhs_qvals[rhs_idx] - rhs_zeros[n_idx * stride]); - - res += lhs_dequant * rhs_dequant; - } - expected_output[m_idx * n + n_idx] = res; - } - } - - // Return test case - return channelwise_8bit_a_channelwise_8bit_b_qmatmul_test_case( - m, - k, - n, - stride, - lhs_has_zeros, - rhs_has_zeros, - lhs_is_transposed, - rhs_is_transposed, - expected_output, - lhs, - lhs_qvals, - lhs_scales, - lhs_zeros, - rhs, - rhs_qvals, - rhs_scales, - rhs_zeros); - } -}; -} // namespace - -template < - bool a_has_zeros, - bool b_has_zeros, - bool a_transposed, - bool b_transposed> -struct test_channelwise_8bit_channelwise_8bit_b { - static void Run(int m, int k, int n); -}; - -template -struct test_channelwise_8bit_channelwise_8bit_b< - a_has_zeros, - b_has_zeros, - false, - true> { - static void Run(int m, int k, int n, int stride = 1) { - auto test_case = - channelwise_8bit_a_channelwise_8bit_b_qmatmul_test_case::generate( - m, k, n, a_has_zeros, a_has_zeros, false, true, stride); - - int a_stride_m, b_stride_n; - auto kernel = torchao::kernels::cpu::quantized_matmul:: - get_int8_a_int8_b_channelwise_qmatmul( - m, n, k, false, true, a_stride_m, b_stride_n); - a_stride_m = a_stride_m * stride; - b_stride_n = b_stride_n * stride; - - std::vector output(m * n); - kernel( - m, - n, - k, - test_case.lhs_qvals.data(), - a_stride_m /*lsh_stride_m*/, - test_case.rhs_qvals.data(), - b_stride_n /*rsh_stride_n*/, - output.data(), - n /*out_stride_n*/, - test_case.lhs_zeros.data(), - test_case.rhs_zeros.data(), - test_case.lhs_scales.data(), - test_case.rhs_scales.data(), - stride, /*lhs qparams stride*/ - stride /*rhs qparams stride*/); - - for (int i = 0; i < m * n; i++) { - EXPECT_NEAR(output[i], test_case.expected_output[i], kTol); - } - } -}; - -TEST(test_channelwise_8bit_channelwise_8bit_b, TranposedBWithZeroPoints) { - test_channelwise_8bit_channelwise_8bit_b< - true /*a_has_zeros*/, - true /*b_has_zeros*/, - false /*a_transposed*/, - true /*b_transposed*/>:: - Run( - /*m=*/1, /*k=*/128, /*n=*/16); -} - -TEST(test_channelwise_8bit_channelwise_8bit_b, TranposeBWithZeroPointsLargeM) { - test_channelwise_8bit_channelwise_8bit_b< - true /*a_has_zeros*/, - true /*b_has_zeros*/, - false /*a_transposed*/, - true /*b_transposed*/>:: - Run( - /*m=*/4, /*k=*/128, /*n=*/16); -} - -TEST( - test_channelwise_8bit_channelwise_8bit_b, - TranposeBWithZeroPointsLargeMWithGemmGemvMix) { - test_channelwise_8bit_channelwise_8bit_b< - true /*a_has_zeros*/, - true /*b_has_zeros*/, - false /*a_transposed*/, - true /*b_transposed*/>:: - Run( - /*m=*/11, /*k=*/128, /*n=*/16); -} - -TEST( - test_channelwise_8bit_channelwise_8bit_b, - TranposedBWithZeroPointsOddSizes) { - test_channelwise_8bit_channelwise_8bit_b< - true /*a_has_zeros*/, - true /*b_has_zeros*/, - false /*a_transposed*/, - true /*b_transposed*/>:: - Run( - /*m=*/4, /*k=*/37, /*n=*/24); -} - -TEST( - test_channelwise_8bit_channelwise_8bit_b, - TranposedBWithZeroPointsOddSizes2) { - test_channelwise_8bit_channelwise_8bit_b< - true 
/*a_has_zeros*/, - true /*b_has_zeros*/, - false /*a_transposed*/, - true /*b_transposed*/>:: - Run( - /*m=*/4, /*k=*/37, /*n=*/19); -} - -// Test shapes for which we have to use fallback kernel -TEST( - test_channelwise_8bit_channelwise_8bit_b, - TranposedBWithZeroPointsOddSizesFallback) { - test_channelwise_8bit_channelwise_8bit_b< - true /*a_has_zeros*/, - true /*b_has_zeros*/, - false /*a_transposed*/, - true /*b_transposed*/>:: - Run( - /*m=*/4, /*k=*/37, /*n=*/5); -} - -// Test shapes for which we have to use fallback kernel -TEST( - test_channelwise_8bit_channelwise_8bit_b, - TranposedBWithZeroPointsOddSizesFallback2) { - test_channelwise_8bit_channelwise_8bit_b< - true /*a_has_zeros*/, - true /*b_has_zeros*/, - false /*a_transposed*/, - true /*b_transposed*/>:: - Run( - /*m=*/4, /*k=*/2, /*n=*/1); -} - -TEST( - test_channelwise_8bit_channelwise_8bit_b, - TranposeBWithZeroPointsLargeMStrided) { - test_channelwise_8bit_channelwise_8bit_b< - true /*a_has_zeros*/, - true /*b_has_zeros*/, - false /*a_transposed*/, - true /*b_transposed*/>:: - Run( - /*m=*/4, /*k=*/128, /*n=*/16, 5); -} - -TEST( - test_channelwise_8bit_channelwise_8bit_b, - TranposedBWithZeroPointsOddSizes2Strided) { - test_channelwise_8bit_channelwise_8bit_b< - true /*a_has_zeros*/, - true /*b_has_zeros*/, - false /*a_transposed*/, - true /*b_transposed*/>:: - Run( - /*m=*/4, /*k=*/37, /*n=*/19, 16); -} - -// Test shapes for which we have to use fallback kernel -TEST( - test_channelwise_8bit_channelwise_8bit_b, - TranposedBWithZeroPointsOddSizesFallbackStrided) { - test_channelwise_8bit_channelwise_8bit_b< - true /*a_has_zeros*/, - true /*b_has_zeros*/, - false /*a_transposed*/, - true /*b_transposed*/>:: - Run( - /*m=*/4, /*k=*/37, /*n=*/5, 7); -} - -// Test shapes for which we have to use fallback kernel -TEST( - test_channelwise_8bit_channelwise_8bit_b, - TranposedBWithZeroPointsOddSizesFallback2Strided) { - test_channelwise_8bit_channelwise_8bit_b< - true /*a_has_zeros*/, - true /*b_has_zeros*/, - false /*a_transposed*/, - true /*b_transposed*/>:: - Run( - /*m=*/4, /*k=*/2, /*n=*/1, 32); -} - -class FP32A_QuantizedB_FP32C_Interface_Test - : public ::testing::TestWithParam { - public: - int m; - int k; - int n; - int stride; - - bool rhs_has_zeros; - bool lhs_is_transposed; - bool rhs_is_transposed; - - std::vector init_output; - std::vector expected_output; - - std::vector lhs; - - std::vector rhs; - std::vector rhs_qvals; - std::vector rhs_scales; - std::vector rhs_zeros; - - void generate( - int m_, - int k_, - int n_, - bool rhs_has_zeros_, - bool lhs_is_transposed_, - bool rhs_is_transposed_, - int stride_ = 1) { - assert(!lhs_is_transposed_); - assert(rhs_has_zeros_); - m = m_; - k = k_; - n = n_; - stride = stride_; - rhs_has_zeros = rhs_has_zeros_; - lhs_is_transposed = lhs_is_transposed_; - rhs_is_transposed = rhs_is_transposed_; - - assert(!rhs_is_transposed || stride == 1); - - // Generate activations - lhs = get_random_vector(m * k, -1.0, 1.0); - - // The strange thing this is doing is that instead of quantizing - // each output channel separately, we are quantizing each input channel - // Reason why we do !rhs_is_transposed is because - // we actually want k x n matrix not n x k matrix - // because each input channel is quantized separately - std::tie(rhs, rhs_qvals, rhs_scales, rhs_zeros) = - generate_per_token_quantized_tensor(k * stride, n, rhs_is_transposed); - - // Compute expected output - init_output = get_random_vector(m * n, -1.0, 1.0); - - assert(init_output.size() == m * n); - 
assert(lhs.size() == m * k); - assert(rhs.size() == n * stride * k); - assert(rhs_qvals.size() == n * stride * k); - assert(rhs_scales.size() == k * stride); - assert(rhs_zeros.size() == k * stride); - } - - void execute(float beta) { - // Compute expected output - expected_output = init_output; - - for (int m_idx = 0; m_idx < m; m_idx++) { - for (int n_idx = 0; n_idx < n; n_idx++) { - float res = 0.0; - for (int k_idx = 0; k_idx < k; k_idx++) { - int lhs_idx = m_idx * k + k_idx; - int rhs_idx = k_idx * stride * n + n_idx; - if (rhs_is_transposed) { - rhs_idx = n_idx * k * stride + k_idx * stride; - } - float rhs_dequant = rhs_scales[k_idx * stride] * - (static_cast(rhs_qvals[rhs_idx]) - - static_cast(rhs_zeros[k_idx * stride])); - - res += lhs[lhs_idx] * rhs_dequant; - } - expected_output[m_idx * n + n_idx] = - expected_output[m_idx * n + n_idx] * beta + res; - } - } - } - - float beta() const { - return GetParam(); - } -}; - -static void test_fp32_a_input_channelwise_8bit_b( - int m, - int k, - int n, - float beta, - FP32A_QuantizedB_FP32C_Interface_Test& test_case, - int stride = 1) { - test_case.execute(beta); - - int a_stride_m, b_stride_n; - auto kernel = torchao::kernels::cpu::quantized_matmul:: - get_fp32_a_input_channelwise_8bit_b_f32_c_matmul( - m, n, k, false, false, a_stride_m, b_stride_n); - b_stride_n = b_stride_n * stride; - - std::vector output(test_case.init_output); - kernel( - m, - n, - k, - test_case.lhs.data(), - a_stride_m /*lhs_stride_m*/, - test_case.rhs_qvals.data(), - b_stride_n /*rhs_stride_n*/, - output.data(), - n /*out_stride_n*/, - test_case.rhs_zeros.data(), - test_case.rhs_scales.data(), - beta, - stride /*rhs qparams stride*/); - - for (int i = 0; i < m * n; i++) { - EXPECT_NEAR(output[i], test_case.expected_output[i], kTol); - } -} - -TEST_P(FP32A_QuantizedB_FP32C_Interface_Test, BTranposedWithZeroPoints) { - generate(3, 128, 16, true, false, false); - test_fp32_a_input_channelwise_8bit_b( - /*m=*/3, /*k=*/128, /*n=*/16, beta(), *this); -} - -TEST_P( - FP32A_QuantizedB_FP32C_Interface_Test, - BTranposedWithZeroPointsOddSizes) { - generate(4, 37, 19, true, false, false); - test_fp32_a_input_channelwise_8bit_b( - /*m=*/4, /*k=*/37, /*n=*/19, beta(), *this); -} - -// Test shapes for which we have to use fallback kernel -TEST_P( - FP32A_QuantizedB_FP32C_Interface_Test, - BTranposedWithZeroPointsOddSizesFallback) { - generate(4, 37, 3, true, false, false); - test_fp32_a_input_channelwise_8bit_b( - /*m=*/4, /*k=*/37, /*n=*/3, beta(), *this); -} - -TEST_P( - FP32A_QuantizedB_FP32C_Interface_Test, - BTranposedWithZeroPointsOddSizes2Fallback) { - generate(4, 1, 3, true, false, false); - test_fp32_a_input_channelwise_8bit_b( - /*m=*/4, /*k=*/1, /*n=*/3, beta(), *this); -} - -TEST_P( - FP32A_QuantizedB_FP32C_Interface_Test, - BTranposedWithZeroPointsOddSizesStrided) { - generate(4, 37, 19, true, false, false, 32); - test_fp32_a_input_channelwise_8bit_b( - /*m=*/4, /*k=*/37, /*n=*/19, beta(), *this, 32); -} - -TEST_P( - FP32A_QuantizedB_FP32C_Interface_Test, - BTranposedWithZeroPointsOddSizes2FallbackStrided) { - generate(4, 5, 3, true, false, false, 32); - test_fp32_a_input_channelwise_8bit_b( - /*m=*/4, /*k=*/5, /*n=*/3, beta(), *this, 32); -} - -TEST_P( - FP32A_QuantizedB_FP32C_Interface_Test, - BTranposedWithZeroPointsOddSizes2) { - generate(19, 37, 35, true, false, false); - test_fp32_a_input_channelwise_8bit_b( - /*m=*/19, /*k=*/37, /*n=*/35, beta(), *this); -} - -TEST_P( - FP32A_QuantizedB_FP32C_Interface_Test, - BTranposedWithZeroPointsOddSizesStrided2) { - 
generate(23, 37, 50, true, false, false, 32); - test_fp32_a_input_channelwise_8bit_b( - /*m=*/23, /*k=*/37, /*n=*/50, beta(), *this, 32); -} - -INSTANTIATE_TEST_SUITE_P( - F32AInt8BFP32CTest, - FP32A_QuantizedB_FP32C_Interface_Test, - ::testing::Values(0.0, 1.0, 3.1));
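As a closing reference for the deleted test helpers, the asymmetric 8-bit mapping they implement can be checked in a few lines: get_qvals_range(8, /*is_symmetric=*/false) yields [-128, 127], get_scale_and_zero maps an observed float range onto it, and quantize rounds to nearest before clamping. The self-contained sketch below re-derives those formulas rather than calling the deleted code, and omits the explicit rounding-mode handling the original quantize() performs.

// scale_zero_example.cpp -- worked example of the asymmetric 8-bit mapping.
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdio>

int main() {
  // get_qvals_range(nbit=8, is_symmetric=false): qmin = -2^7, qmax = 2^7 - 1.
  const int qmin = -128, qmax = 127;

  // get_scale_and_zero: map the observed range [vmin, vmax] onto [qmin, qmax].
  const float vmin = -1.0f, vmax = 1.0f;
  const float scale = (vmax - vmin) / float(qmax - qmin);  // 2 / 255, about 0.00784
  const int zero = qmin - int(std::round(vmin / scale));

  // quantize() (simplified): round val / scale, shift by zero, clamp to range.
  const float x = 0.5f;
  int q = zero + int(std::nearbyint(x / scale));
  q = std::max(qmin, std::min(q, qmax));

  // Dequantizing recovers x to within one quantization step.
  const float x_hat = scale * float(q - zero);
  std::printf("scale=%.6f zero=%d q=%d dequant=%.4f\n", scale, zero, q, x_hat);
  assert(std::fabs(x_hat - x) <= scale);
  return 0;
}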