From 7a766f4323876250ff3b1021f272b1194da5fa28 Mon Sep 17 00:00:00 2001 From: Scott Roy Date: Wed, 31 Jul 2024 10:28:23 -0700 Subject: [PATCH] Move lowbit universal kernels from torchaccel to torchao (#582) Summary: Pull Request resolved: https://github.com/pytorch/ao/pull/582 Moves torchaccel universal kernels to torchao/experimental. Differential Revision: D60292095 --- .../cpu/aarch64/benchmarks/CMakeLists.txt | 53 ++ .../benchmarks/benchmark_bitpacking.cpp | 301 ++++++++++ .../aarch64/benchmarks/benchmark_linear.cpp | 241 ++++++++ .../benchmarks/benchmark_quantization.cpp | 37 ++ .../kernels/cpu/aarch64/bitpacking/bitpack.h | 323 +++++++++++ .../kernels/cpu/aarch64/bitpacking/macro.h | 3 + .../kernels/cpu/aarch64/bitpacking/uint3.h | 327 +++++++++++ .../kernels/cpu/aarch64/bitpacking/uint4.h | 66 +++ ...se_lowbit_weight_1x1x32_f32_neondot-impl.h | 361 ++++++++++++ ...se_lowbit_weight_1x4x16_f32_neondot-impl.h | 472 +++++++++++++++ ...se_lowbit_weight_1x8x16_f32_neondot-impl.h | 546 ++++++++++++++++++ ...ion_prepare_activation_data_1xk_f32-impl.h | 117 ++++ .../kernels/cpu/aarch64/linear/linear.h | 162 ++++++ .../cpu/aarch64/quantization/quantize.cpp | 109 ++++ .../cpu/aarch64/quantization/quantize.h | 51 ++ .../cpu/aarch64/reduction/compute_sum.cpp | 20 + .../aarch64/reduction/find_min_and_max.cpp | 32 + .../kernels/cpu/aarch64/reduction/reduction.h | 24 + .../kernels/cpu/aarch64/tests/CMakeLists.txt | 66 +++ .../cpu/aarch64/tests/test_bitpacking.cpp | 353 +++++++++++ .../kernels/cpu/aarch64/tests/test_linear.cpp | 370 ++++++++++++ .../cpu/aarch64/tests/test_quantization.cpp | 66 +++ .../kernels/cpu/aarch64/tests/test_utils.h | 269 +++++++++ .../cpu/aarch64/tests/test_valpacking.cpp | 96 +++ .../cpu/aarch64/valpacking/interleave.cpp | 75 +++ .../kernels/cpu/aarch64/valpacking/valpack.h | 24 + .../kernels/cpu/build_and_run_benchmarks.sh | 19 + .../kernels/cpu/build_and_run_tests.sh | 13 + 28 files changed, 4596 insertions(+) create mode 100644 torchao/experimental/kernels/cpu/aarch64/benchmarks/CMakeLists.txt create mode 100644 torchao/experimental/kernels/cpu/aarch64/benchmarks/benchmark_bitpacking.cpp create mode 100644 torchao/experimental/kernels/cpu/aarch64/benchmarks/benchmark_linear.cpp create mode 100644 torchao/experimental/kernels/cpu/aarch64/benchmarks/benchmark_quantization.cpp create mode 100644 torchao/experimental/kernels/cpu/aarch64/bitpacking/bitpack.h create mode 100644 torchao/experimental/kernels/cpu/aarch64/bitpacking/macro.h create mode 100644 torchao/experimental/kernels/cpu/aarch64/bitpacking/uint3.h create mode 100644 torchao/experimental/kernels/cpu/aarch64/bitpacking/uint4.h create mode 100644 torchao/experimental/kernels/cpu/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight_1x1x32_f32_neondot-impl.h create mode 100644 torchao/experimental/kernels/cpu/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot-impl.h create mode 100644 torchao/experimental/kernels/cpu/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot-impl.h create mode 100644 torchao/experimental/kernels/cpu/aarch64/linear/channelwise_8bit_activation_prepare_activation_data_1xk_f32-impl.h create mode 100644 torchao/experimental/kernels/cpu/aarch64/linear/linear.h create mode 100644 torchao/experimental/kernels/cpu/aarch64/quantization/quantize.cpp create mode 100644 torchao/experimental/kernels/cpu/aarch64/quantization/quantize.h create mode 100644 torchao/experimental/kernels/cpu/aarch64/reduction/compute_sum.cpp 
create mode 100644 torchao/experimental/kernels/cpu/aarch64/reduction/find_min_and_max.cpp create mode 100644 torchao/experimental/kernels/cpu/aarch64/reduction/reduction.h create mode 100644 torchao/experimental/kernels/cpu/aarch64/tests/CMakeLists.txt create mode 100644 torchao/experimental/kernels/cpu/aarch64/tests/test_bitpacking.cpp create mode 100644 torchao/experimental/kernels/cpu/aarch64/tests/test_linear.cpp create mode 100644 torchao/experimental/kernels/cpu/aarch64/tests/test_quantization.cpp create mode 100644 torchao/experimental/kernels/cpu/aarch64/tests/test_utils.h create mode 100644 torchao/experimental/kernels/cpu/aarch64/tests/test_valpacking.cpp create mode 100644 torchao/experimental/kernels/cpu/aarch64/valpacking/interleave.cpp create mode 100644 torchao/experimental/kernels/cpu/aarch64/valpacking/valpack.h create mode 100644 torchao/experimental/kernels/cpu/build_and_run_benchmarks.sh create mode 100644 torchao/experimental/kernels/cpu/build_and_run_tests.sh diff --git a/torchao/experimental/kernels/cpu/aarch64/benchmarks/CMakeLists.txt b/torchao/experimental/kernels/cpu/aarch64/benchmarks/CMakeLists.txt new file mode 100644 index 0000000000..1c1a779dbe --- /dev/null +++ b/torchao/experimental/kernels/cpu/aarch64/benchmarks/CMakeLists.txt @@ -0,0 +1,53 @@ +# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + +cmake_minimum_required(VERSION 3.19) +project(benchmarks) +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_BUILD_TYPE Release) + +include(FetchContent) +FetchContent_Declare(googlebenchmark + GIT_REPOSITORY https://github.com/google/benchmark.git + GIT_TAG main) # need main for benchmark::benchmark + +set(BENCHMARK_ENABLE_TESTING OFF) +FetchContent_MakeAvailable( + googlebenchmark) + +add_compile_options("-Wall" "-Werror") + +include(CMakePrintHelpers) +message("TORCHAO_LIBRARIES: ${TORCHAO_LIBRARIES}") +include_directories(${TORCHAO_LIBRARIES}) + +add_library( + dep + ${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64/reduction/find_min_and_max.cpp + ${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64/reduction/compute_sum.cpp + ${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64/quantization/quantize.cpp + ${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64/valpacking/interleave.cpp +) + +add_executable(benchmark_quantization benchmark_quantization.cpp) +target_link_libraries( + benchmark_quantization + PRIVATE + benchmark::benchmark + dep +) + +add_executable(benchmark_bitpacking benchmark_bitpacking.cpp) +target_link_libraries( + benchmark_bitpacking + PRIVATE + benchmark::benchmark + dep +) + +add_executable(benchmark_linear benchmark_linear.cpp) +target_link_libraries( + benchmark_linear + PRIVATE + benchmark::benchmark + dep +) diff --git a/torchao/experimental/kernels/cpu/aarch64/benchmarks/benchmark_bitpacking.cpp b/torchao/experimental/kernels/cpu/aarch64/benchmarks/benchmark_bitpacking.cpp new file mode 100644 index 0000000000..d03a3bfca8 --- /dev/null +++ b/torchao/experimental/kernels/cpu/aarch64/benchmarks/benchmark_bitpacking.cpp @@ -0,0 +1,301 @@ +// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. 
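For readers unfamiliar with the Google Benchmark API used throughout the benchmark sources in this patch, the files all follow the same pattern: a static function reads its parameters from state.range(i), is registered with BENCHMARK(...)->ArgsProduct({...}), and BENCHMARK_MAIN() generates the driver. A minimal, self-contained sketch of that pattern follows; the workload and argument values here are hypothetical and not part of the patch.

#include <algorithm>
#include <cstdint>
#include <vector>
#include <benchmark/benchmark.h>

// Hypothetical benchmark: copy `size` bytes. Only the structure mirrors the
// benchmark files in this patch; the workload itself is illustrative.
static void benchmark_copy(benchmark::State& state) {
  int size = state.range(0);  // first value supplied by ArgsProduct
  std::vector<uint8_t> src(size, 1), dst(size, 0);
  for (auto _ : state) {  // timed region
    std::copy(src.begin(), src.end(), dst.begin());
    benchmark::DoNotOptimize(dst.data());  // keep the copy from being elided
  }
}
BENCHMARK(benchmark_copy)->ArgsProduct({{1024, 4096, 16384}});
BENCHMARK_MAIN();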
+ +#include +#include +#include + +#include +#include +#include +#include +#include + +namespace { + +// Benchmark utility to compare variants of uint3 packing +void pack_uint3_values( + uint8_t* packed, + uint8_t* unpacked, + int packed_size, + int unpacked_size, + int variant) { + constexpr int nbit = 3; + constexpr int bitsPerByte = 8; + assert(unpacked_size * nbit / bitsPerByte == packed_size); + assert(packed_size % variant == 0); + + uint8x16_t unpacked0; + uint8x16_t unpacked1; + uint8x16_t unpacked2; + uint8x16_t unpacked3; + uint8x16_t unpacked4; + uint8x16_t unpacked5; + uint8x16_t unpacked6; + uint8x16_t unpacked7; + + switch (variant) { + case 8: + for (int i = 0; i < unpacked_size; i += 8) { + torchao::bitpacking::internal::pack_8_uint3_values( + packed + ((i * nbit) / bitsPerByte), unpacked + i); + } + break; + case 64: + for (int i = 0; i < unpacked_size; i += 64) { + torchao::bitpacking::internal::vec_load_64_uint8_values( + unpacked0, unpacked1, unpacked2, unpacked3, unpacked + i); + torchao::bitpacking::internal::vec_pack_64_uint3_values( + packed + ((i * nbit) / bitsPerByte), + unpacked0, + unpacked1, + unpacked2, + unpacked3); + } + break; + case 128: + for (int i = 0; i < unpacked_size; i += 128) { + torchao::bitpacking::internal::vec_load_64_uint8_values( + unpacked0, unpacked1, unpacked2, unpacked3, unpacked + i); + torchao::bitpacking::internal::vec_load_64_uint8_values( + unpacked4, unpacked5, unpacked6, unpacked7, unpacked + i + 64); + torchao::bitpacking::internal::vec_pack_128_uint3_values( + packed + ((i * nbit) / bitsPerByte), + unpacked0, + unpacked1, + unpacked2, + unpacked3, + unpacked4, + unpacked5, + unpacked6, + unpacked7); + } + break; + } +} + +// Benchmark utility to compare variants of uint3 unpacking +void unpack_uint3_values( + uint8_t* unpacked, + uint8_t* packed, + int unpacked_size, + int packed_size, + int variant) { + constexpr int nbit = 3; + constexpr int bitsPerByte = 8; + assert(unpacked_size * nbit / bitsPerByte == packed_size); + assert(packed_size % variant == 0); + + uint8x16_t unpacked0; + uint8x16_t unpacked1; + uint8x16_t unpacked2; + uint8x16_t unpacked3; + uint8x16_t unpacked4; + uint8x16_t unpacked5; + uint8x16_t unpacked6; + uint8x16_t unpacked7; + + switch (variant) { + case 8: + for (int i = 0; i < unpacked_size; i += 8) { + torchao::bitpacking::internal::unpack_8_uint3_values( + unpacked + i, packed + ((i * nbit) / bitsPerByte)); + } + break; + case 64: + for (int i = 0; i < unpacked_size; i += 64) { + torchao::bitpacking::internal::vec_unpack_64_uint3_values( + unpacked0, + unpacked1, + unpacked2, + unpacked3, + packed + ((i * nbit) / bitsPerByte)); + torchao::bitpacking::internal::vec_store_64_uint8_values( + unpacked + i, unpacked0, unpacked1, unpacked2, unpacked3); + } + break; + case 128: + for (int i = 0; i < unpacked_size; i += 128) { + torchao::bitpacking::internal::vec_unpack_128_uint3_values( + unpacked0, + unpacked1, + unpacked2, + unpacked3, + unpacked4, + unpacked5, + unpacked6, + unpacked7, + packed + ((i * nbit) / bitsPerByte)); + torchao::bitpacking::internal::vec_store_64_uint8_values( + unpacked + i, unpacked0, unpacked1, unpacked2, unpacked3); + torchao::bitpacking::internal::vec_store_64_uint8_values( + unpacked + i + 64, unpacked4, unpacked5, unpacked6, unpacked7); + } + break; + } +} + +// Benchmark utility to compare variants of uint4 packing +void pack_uint4_values( + uint8_t* packed, + uint8_t* unpacked, + int packed_size, + int unpacked_size, + int variant) { + constexpr int nbit = 4; + constexpr int 
bitsPerByte = 8; + assert(unpacked_size * nbit / bitsPerByte == packed_size); + assert(packed_size % variant == 0); + + uint8x16_t unpacked0; + uint8x16_t unpacked1; + + switch (variant) { + case 2: + for (int i = 0; i < unpacked_size; i += 2) { + torchao::bitpacking::internal::pack_2_uint4_values( + packed + ((i * nbit) / bitsPerByte), unpacked + i); + } + break; + case 16: + for (int i = 0; i < unpacked_size; i += 16) { + unpacked0 = vld1q_u8(unpacked + i); + torchao::bitpacking::internal::vec_pack_16_uint4_values( + packed + ((i * nbit) / bitsPerByte), unpacked0); + } + break; + case 32: + for (int i = 0; i < unpacked_size; i += 32) { + unpacked0 = vld1q_u8(unpacked + i); + unpacked1 = vld1q_u8(unpacked + 16 + i); + torchao::bitpacking::internal::vec_pack_32_uint4_values( + packed + ((i * nbit) / bitsPerByte), unpacked0, unpacked1); + } + break; + } +} + +// Benchmark utility to compare variants of uint4 unpacking +void unpack_uint4_values( + uint8_t* unpacked, + uint8_t* packed, + int unpacked_size, + int packed_size, + int variant) { + constexpr int nbit = 4; + constexpr int bitsPerByte = 8; + assert(unpacked_size * nbit / bitsPerByte == packed_size); + assert(packed_size % variant == 0); + + uint8x16_t unpacked0; + uint8x16_t unpacked1; + + switch (variant) { + case 2: + for (int i = 0; i < unpacked_size; i += 2) { + torchao::bitpacking::internal::unpack_2_uint4_values( + unpacked + i, packed + ((i * nbit) / bitsPerByte)); + } + break; + case 16: + for (int i = 0; i < unpacked_size; i += 16) { + torchao::bitpacking::internal::vec_unpack_16_uint4_values( + unpacked0, packed + ((i * nbit) / bitsPerByte)); + vst1q_u8(unpacked + i, unpacked0); + } + break; + case 32: + for (int i = 0; i < unpacked_size; i += 32) { + torchao::bitpacking::internal::vec_unpack_32_uint4_values( + unpacked0, unpacked1, packed + ((i * nbit) / bitsPerByte)); + vst1q_u8(unpacked + i, unpacked0); + vst1q_u8(unpacked + 16 + i, unpacked1); + } + break; + } +} + +} // namespace + +static void benchmark_pack_uint3_values(benchmark::State& state) { + int unpacked_size = state.range(0); + int variant = state.range(1); + int nbit = 3; + + assert(unpacked_size % 8 == 0); + int packed_size = (unpacked_size / 8) * nbit; + + auto packed = std::vector(unpacked_size, 0); + auto unpacked = torchao::get_random_lowbit_vector(packed_size, 8); + + for (auto _ : state) { + pack_uint3_values( + packed.data(), unpacked.data(), packed_size, unpacked_size, variant); + } +} + +static void benchmark_unpack_uint3_values(benchmark::State& state) { + int unpacked_size = state.range(0); + int variant = state.range(1); + int nbit = 3; + + assert(unpacked_size % 8 == 0); + int packed_size = (unpacked_size / 8) * nbit; + + auto packed = torchao::get_random_lowbit_vector(packed_size, 8); + auto unpacked = std::vector(unpacked_size, 0); + + for (auto _ : state) { + unpack_uint3_values( + unpacked.data(), + packed.data(), + unpacked.size(), + packed.size(), + variant); + } +} + +static void benchmark_pack_uint4_values(benchmark::State& state) { + int unpacked_size = state.range(0); + int variant = state.range(1); + int nbit = 4; + + assert(unpacked_size % 8 == 0); + int packed_size = (unpacked_size / 8) * nbit; + + auto packed = std::vector(unpacked_size, 0); + auto unpacked = torchao::get_random_lowbit_vector(packed_size, 8); + + for (auto _ : state) { + pack_uint4_values( + packed.data(), unpacked.data(), packed_size, unpacked_size, variant); + } +} + +static void benchmark_unpack_uint4_values(benchmark::State& state) { + int unpacked_size = 
state.range(0); + int variant = state.range(1); + int nbit = 4; + + assert(unpacked_size % 8 == 0); + int packed_size = (unpacked_size / 8) * nbit; + + auto packed = torchao::get_random_lowbit_vector(packed_size, 8); + auto unpacked = std::vector(unpacked_size, 0); + + for (auto _ : state) { + unpack_uint4_values( + unpacked.data(), + packed.data(), + unpacked.size(), + packed.size(), + variant); + } +} + +BENCHMARK(benchmark_pack_uint3_values)->ArgsProduct({{128}, {8, 64, 128}}); +BENCHMARK(benchmark_unpack_uint3_values)->ArgsProduct({{128}, {8, 64, 128}}); +BENCHMARK(benchmark_pack_uint4_values)->ArgsProduct({{128}, {2, 16, 32}}); +BENCHMARK(benchmark_unpack_uint4_values)->ArgsProduct({{128}, {2, 16, 32}}); + +// Run the benchmark +BENCHMARK_MAIN(); diff --git a/torchao/experimental/kernels/cpu/aarch64/benchmarks/benchmark_linear.cpp b/torchao/experimental/kernels/cpu/aarch64/benchmarks/benchmark_linear.cpp new file mode 100644 index 0000000000..631bab42d4 --- /dev/null +++ b/torchao/experimental/kernels/cpu/aarch64/benchmarks/benchmark_linear.cpp @@ -0,0 +1,241 @@ +// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + +#include +#include +#include +#include +#include + +template +static void +channelwise_8bit_activation_groupwise_lowbit_weight_1x1x32_f32_neondot( + benchmark::State& state) { + int m = state.range(0); + int n = state.range(1); + int k = state.range(2); + int group_size = state.range(3); + + using namespace torchao::kernels::cpu::aarch64::linear:: + channelwise_8bit_activation_groupwise_lowbit_weight_1x1x32_f32_neondot; + + auto test_case = torchao:: + channelwise_8bit_activation_groupwise_lowbit_weight_test_case::generate( + m, + k, + n, + group_size, + weight_nbit, + has_weight_zeros, + has_bias, + has_clamp); + + std::vector activation_data( + activation_data_size(m, k, group_size)); + prepare_activation_data( + (void*)activation_data.data(), + m, + k, + group_size, + test_case.activations.data()); + + std::vector weight_data( + weight_data_size(n, k, group_size)); + prepare_weight_data( + (void*)weight_data.data(), + n, + k, + group_size, + test_case.weight_qvals.data(), + test_case.weight_scales.data(), + test_case.weight_zeros.data()); + + std::vector output(m * k); + for (auto _ : state) { + kernel( + output.data(), + /*output_m_stride=*/n, + m, + n, + k, + group_size, + weight_data.data(), + activation_data.data(), + test_case.bias.data(), + test_case.clamp_min, + test_case.clamp_max); + } +} + +template +static void +channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot( + benchmark::State& state) { + int m = state.range(0); + int n = state.range(1); + int k = state.range(2); + int group_size = state.range(3); + + using namespace torchao::kernels::cpu::aarch64::linear:: + channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot; + + auto test_case = torchao:: + channelwise_8bit_activation_groupwise_lowbit_weight_test_case::generate( + m, + k, + n, + group_size, + weight_nbit, + has_weight_zeros, + has_bias, + has_clamp); + + std::vector activation_data( + activation_data_size(m, k, group_size)); + prepare_activation_data( + (void*)activation_data.data(), + m, + k, + group_size, + test_case.activations.data()); + + std::vector weight_data( + weight_data_size(n, k, group_size)); + prepare_weight_data( + (void*)weight_data.data(), + n, + k, + group_size, + test_case.weight_qvals.data(), + test_case.weight_scales.data(), + test_case.weight_zeros.data()); + + std::vector output(m * k); + for (auto _ : state) { + 
kernel( + output.data(), + /*output_m_stride=*/n, + m, + n, + k, + group_size, + weight_data.data(), + activation_data.data(), + test_case.bias.data(), + test_case.clamp_min, + test_case.clamp_max); + } +} + +template +static void +channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot( + benchmark::State& state) { + int m = state.range(0); + int n = state.range(1); + int k = state.range(2); + int group_size = state.range(3); + + using namespace torchao::kernels::cpu::aarch64::linear:: + channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot; + + auto test_case = torchao:: + channelwise_8bit_activation_groupwise_lowbit_weight_test_case::generate( + m, + k, + n, + group_size, + weight_nbit, + has_weight_zeros, + has_bias, + has_clamp); + + std::vector activation_data( + activation_data_size(m, k, group_size)); + prepare_activation_data( + (void*)activation_data.data(), + m, + k, + group_size, + test_case.activations.data()); + + std::vector weight_data( + weight_data_size(n, k, group_size)); + prepare_weight_data( + (void*)weight_data.data(), + n, + k, + group_size, + test_case.weight_qvals.data(), + test_case.weight_scales.data(), + test_case.weight_zeros.data()); + + std::vector output(m * k); + for (auto _ : state) { + kernel( + output.data(), + /*output_m_stride=*/n, + m, + n, + k, + group_size, + weight_data.data(), + activation_data.data(), + test_case.bias.data(), + test_case.clamp_min, + test_case.clamp_max); + } +} + +#define BENCHMARK_PARAMS \ + { \ + /*m*/ {1}, /*n*/ {8}, /*k*/ {4096, 8192, 16384, 32768, 131072}, \ + /*group_size*/ { \ + 32, 256 \ + } \ + } + +#define BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT_1x1x32_F32_NEONDOT( \ + weight_nbit) \ + BENCHMARK( \ + channelwise_8bit_activation_groupwise_lowbit_weight_1x1x32_f32_neondot< \ + weight_nbit, \ + false, \ + false, \ + false>) \ + ->ArgsProduct(BENCHMARK_PARAMS) + +#define BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT_1x4x16_F32_NEONDOT( \ + weight_nbit) \ + BENCHMARK( \ + channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot< \ + weight_nbit, \ + false, \ + false, \ + false>) \ + ->ArgsProduct(BENCHMARK_PARAMS) + +#define BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT_1x8x16_F32_NEONDOT( \ + weight_nbit) \ + BENCHMARK( \ + channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot< \ + weight_nbit, \ + false, \ + false, \ + false>) \ + ->ArgsProduct(BENCHMARK_PARAMS) + +BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT_1x1x32_F32_NEONDOT( + 3); +BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT_1x1x32_F32_NEONDOT( + 4); +BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT_1x4x16_F32_NEONDOT( + 3); +BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT_1x4x16_F32_NEONDOT( + 4); +BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT_1x8x16_F32_NEONDOT( + 3); +BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT_1x8x16_F32_NEONDOT( + 4); + +// Run the benchmark +BENCHMARK_MAIN(); diff --git a/torchao/experimental/kernels/cpu/aarch64/benchmarks/benchmark_quantization.cpp b/torchao/experimental/kernels/cpu/aarch64/benchmarks/benchmark_quantization.cpp new file mode 100644 index 0000000000..942855c017 --- /dev/null +++ b/torchao/experimental/kernels/cpu/aarch64/benchmarks/benchmark_quantization.cpp @@ -0,0 +1,37 @@ +// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. 
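The linear benchmarks above exercise kernels whose semantics are documented in the kernel implementation headers later in this patch: activations are m x k with one 8-bit scale/zero pair per row, weights are n x k with one scale (and optional zero) per group of group_size values along k, both dequantize as scale * (qval - zero), and the output is clamp(activations * weights^T + bias). A plain scalar reference of that computation, illustrative only; the function name is hypothetical and the per-row bias indexing follows the kernel comments.

#include <algorithm>
#include <cstdint>

// Scalar reference for the quantized linear op benchmarked above (assumes
// has_weight_zeros, has_bias, and has_clamp are all true).
void reference_linear(
    float* output,            // m x n, row-major
    int m, int n, int k, int group_size,
    const int8_t* act_qvals,  // m x k activation qvals
    const float* act_scales,  // one scale per activation row
    const int8_t* act_zeros,  // one zero per activation row
    const int8_t* w_qvals,    // n x k weight qvals (one row per output column)
    const float* w_scales,    // one scale per (column, group)
    const int8_t* w_zeros,    // one zero per (column, group)
    const float* bias,        // m x 1, per the kernel comments
    float clamp_min, float clamp_max) {
  int groups_per_col = k / group_size;
  for (int i = 0; i < m; i++) {
    for (int j = 0; j < n; j++) {
      float acc = 0.0f;
      for (int g = 0; g < groups_per_col; g++) {
        float ws = w_scales[j * groups_per_col + g];
        int8_t wz = w_zeros[j * groups_per_col + g];
        for (int t = 0; t < group_size; t++) {
          int idx = g * group_size + t;
          float a = act_scales[i] * (act_qvals[i * k + idx] - act_zeros[i]);
          float w = ws * (w_qvals[j * k + idx] - wz);
          acc += a * w;
        }
      }
      acc += bias[i];
      output[i * n + j] = std::min(std::max(acc, clamp_min), clamp_max);
    }
  }
}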
+ +#include +#include +#include +#include + +static void benchmark_quantize(benchmark::State& state) { + int nbit = state.range(0); + int size = state.range(1); + auto vals = torchao::get_random_vector(size, -10, 10); + auto qvals = std::vector(size, 0); + + int qmin, qmax, zero; + float vmin, vmax, scale; + + for (auto _ : state) { + torchao::kernels::cpu::aarch64::reduction::find_min_and_max( + vmin, vmax, vals.data(), vals.size()); + + torchao::quantization::get_qvals_range( + qmin, qmax, nbit, /*is_symmetric=*/false); + + torchao::quantization::get_scale_and_zero( + scale, zero, vmin, vmax, qmin, qmax); + + torchao::kernels::cpu::aarch64::quantization::quantize( + qvals.data(), vals.data(), vals.size(), scale, zero, qmin, qmax); + } +} + +BENCHMARK(benchmark_quantize) + ->ArgsProduct( + {{3, 4, 8}, benchmark::CreateRange(1024, 131072, /*multi=*/4)}); + +// Run the benchmark +BENCHMARK_MAIN(); diff --git a/torchao/experimental/kernels/cpu/aarch64/bitpacking/bitpack.h b/torchao/experimental/kernels/cpu/aarch64/bitpacking/bitpack.h new file mode 100644 index 0000000000..4f1d9a5d5e --- /dev/null +++ b/torchao/experimental/kernels/cpu/aarch64/bitpacking/bitpack.h @@ -0,0 +1,323 @@ +// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + +#pragma once +#include +#include +#include +#include + +namespace torchao { +namespace bitpacking { + +namespace internal { +TORCHAO_ALWAYS_INLINE inline void vec_store_64_uint8_values( + uint8_t* dest, + const uint8x16_t& vec0, + const uint8x16_t& vec1, + const uint8x16_t& vec2, + const uint8x16_t& vec3) { + vst1q_u8(dest, vec0); + vst1q_u8(dest + 16, vec1); + vst1q_u8(dest + 32, vec2); + vst1q_u8(dest + 48, vec3); +} + +TORCHAO_ALWAYS_INLINE inline void vec_load_64_uint8_values( + uint8x16_t& vec0, + uint8x16_t& vec1, + uint8x16_t& vec2, + uint8x16_t& vec3, + const uint8_t* src) { + vec0 = vld1q_u8(src); + vec1 = vld1q_u8(src + 16); + vec2 = vld1q_u8(src + 32); + vec3 = vld1q_u8(src + 48); +} +} // namespace internal + +template +TORCHAO_ALWAYS_INLINE inline void vec_pack_32_lowbit_values( + uint8_t* packed, + const int8x16_t& unpacked0, + const int8x16_t& unpacked1) { + static_assert(nbit < 8); + static_assert(nbit >= 2); + + // Currently supported values + static_assert(nbit >= 3); + static_assert(nbit <= 4); + + // Shift unpacked values to nonnegative range + int8x16_t shift = vdupq_n_s8(1 << (nbit - 1)); + uint8x16_t shifted0 = vreinterpretq_u8_s8(vaddq_s8(unpacked0, shift)); + uint8x16_t shifted1 = vreinterpretq_u8_s8(vaddq_s8(unpacked1, shift)); + + switch (nbit) { + case 3: + uint8_t buffer[32]; + vst1q_u8(buffer, shifted0); + vst1q_u8(buffer + 16, shifted1); + + torchao::bitpacking::internal::pack_8_uint3_values(packed, buffer); + torchao::bitpacking::internal::pack_8_uint3_values( + packed + 3, buffer + 8); + torchao::bitpacking::internal::pack_8_uint3_values( + packed + 6, buffer + 16); + torchao::bitpacking::internal::pack_8_uint3_values( + packed + 9, buffer + 24); + break; + case 4: + torchao::bitpacking::internal::vec_pack_32_uint4_values( + packed, shifted0, shifted1); + break; + default: + assert(false); + } +} + +template +TORCHAO_ALWAYS_INLINE inline void vec_unpack_32_lowbit_values( + int8x16_t& unpacked0, + int8x16_t& unpacked1, + uint8_t* packed) { + static_assert(nbit < 8); + static_assert(nbit >= 2); + + // Currently supported values + static_assert(nbit >= 3); + static_assert(nbit <= 4); + + uint8x16_t shifted0; + uint8x16_t shifted1; + + switch (nbit) { + case 3: + uint8_t buffer[32]; + 
torchao::bitpacking::internal::unpack_8_uint3_values(buffer, packed); + torchao::bitpacking::internal::unpack_8_uint3_values( + buffer + 8, packed + 3); + torchao::bitpacking::internal::unpack_8_uint3_values( + buffer + 16, packed + 6); + torchao::bitpacking::internal::unpack_8_uint3_values( + buffer + 24, packed + 9); + shifted0 = vld1q_u8(buffer); + shifted1 = vld1q_u8(buffer + 16); + break; + case 4: + torchao::bitpacking::internal::vec_unpack_32_uint4_values( + shifted0, shifted1, packed); + break; + default: + assert(false); + } + + // unshift to move unpacked values to full range + int8x16_t unshift = vdupq_n_s8(-(1 << (nbit - 1))); + unpacked0 = vaddq_s8(vreinterpretq_s8_u8(shifted0), unshift); + unpacked1 = vaddq_s8(vreinterpretq_s8_u8(shifted1), unshift); +} + +template +TORCHAO_ALWAYS_INLINE inline void vec_pack_64_lowbit_values( + uint8_t* packed, + const int8x16_t& unpacked0, + const int8x16_t& unpacked1, + const int8x16_t& unpacked2, + const int8x16_t& unpacked3) { + static_assert(nbit < 8); + static_assert(nbit >= 2); + + // Currently supported values + static_assert(nbit >= 3); + static_assert(nbit <= 4); + + // Shift unpacked values to nonnegative range + int8x16_t shift = vdupq_n_s8(1 << (nbit - 1)); + uint8x16_t shifted0 = vreinterpretq_u8_s8(vaddq_s8(unpacked0, shift)); + uint8x16_t shifted1 = vreinterpretq_u8_s8(vaddq_s8(unpacked1, shift)); + uint8x16_t shifted2 = vreinterpretq_u8_s8(vaddq_s8(unpacked2, shift)); + uint8x16_t shifted3 = vreinterpretq_u8_s8(vaddq_s8(unpacked3, shift)); + + switch (nbit) { + case 3: + torchao::bitpacking::internal::vec_pack_64_uint3_values( + packed, shifted0, shifted1, shifted2, shifted3); + break; + case 4: + torchao::bitpacking::internal::vec_pack_32_uint4_values( + packed, shifted0, shifted1); + torchao::bitpacking::internal::vec_pack_32_uint4_values( + packed + 16, shifted2, shifted3); + break; + default: + assert(false); + } +} + +template +TORCHAO_ALWAYS_INLINE inline void vec_unpack_64_lowbit_values( + int8x16_t& unpacked0, + int8x16_t& unpacked1, + int8x16_t& unpacked2, + int8x16_t& unpacked3, + uint8_t* packed) { + static_assert(nbit < 8); + static_assert(nbit >= 2); + + // Currently supported values + static_assert(nbit >= 3); + static_assert(nbit <= 4); + + uint8x16_t shifted0; + uint8x16_t shifted1; + uint8x16_t shifted2; + uint8x16_t shifted3; + + switch (nbit) { + case 3: + torchao::bitpacking::internal::vec_unpack_64_uint3_values( + shifted0, shifted1, shifted2, shifted3, packed); + break; + case 4: + torchao::bitpacking::internal::vec_unpack_32_uint4_values( + shifted0, shifted1, packed); + torchao::bitpacking::internal::vec_unpack_32_uint4_values( + shifted2, shifted3, packed + 16); + break; + default: + assert(false); + } + + // unshift to move unpacked values to full range + int8x16_t unshift = vdupq_n_s8(-(1 << (nbit - 1))); + unpacked0 = vaddq_s8(vreinterpretq_s8_u8(shifted0), unshift); + unpacked1 = vaddq_s8(vreinterpretq_s8_u8(shifted1), unshift); + unpacked2 = vaddq_s8(vreinterpretq_s8_u8(shifted2), unshift); + unpacked3 = vaddq_s8(vreinterpretq_s8_u8(shifted3), unshift); +} + +template +TORCHAO_ALWAYS_INLINE inline void vec_pack_128_lowbit_values( + uint8_t* packed, + const int8x16_t& unpacked0, + const int8x16_t& unpacked1, + const int8x16_t& unpacked2, + const int8x16_t& unpacked3, + const int8x16_t& unpacked4, + const int8x16_t& unpacked5, + const int8x16_t& unpacked6, + const int8x16_t& unpacked7) { + static_assert(nbit < 8); + static_assert(nbit >= 2); + + // Currently supported values + static_assert(nbit >= 
3); + static_assert(nbit <= 4); + + // Shift unpacked values to nonnegative range + int8x16_t shift = vdupq_n_s8(1 << (nbit - 1)); + uint8x16_t shifted0 = vreinterpretq_u8_s8(vaddq_s8(unpacked0, shift)); + uint8x16_t shifted1 = vreinterpretq_u8_s8(vaddq_s8(unpacked1, shift)); + uint8x16_t shifted2 = vreinterpretq_u8_s8(vaddq_s8(unpacked2, shift)); + uint8x16_t shifted3 = vreinterpretq_u8_s8(vaddq_s8(unpacked3, shift)); + uint8x16_t shifted4 = vreinterpretq_u8_s8(vaddq_s8(unpacked4, shift)); + uint8x16_t shifted5 = vreinterpretq_u8_s8(vaddq_s8(unpacked5, shift)); + uint8x16_t shifted6 = vreinterpretq_u8_s8(vaddq_s8(unpacked6, shift)); + uint8x16_t shifted7 = vreinterpretq_u8_s8(vaddq_s8(unpacked7, shift)); + + switch (nbit) { + case 3: + torchao::bitpacking::internal::vec_pack_128_uint3_values( + packed, + shifted0, + shifted1, + shifted2, + shifted3, + shifted4, + shifted5, + shifted6, + shifted7); + break; + case 4: + torchao::bitpacking::internal::vec_pack_32_uint4_values( + packed, shifted0, shifted1); + torchao::bitpacking::internal::vec_pack_32_uint4_values( + packed + 16, shifted2, shifted3); + torchao::bitpacking::internal::vec_pack_32_uint4_values( + packed + 32, shifted4, shifted5); + torchao::bitpacking::internal::vec_pack_32_uint4_values( + packed + 48, shifted6, shifted7); + break; + default: + assert(false); + } +} + +template +TORCHAO_ALWAYS_INLINE inline void vec_unpack_128_lowbit_values( + int8x16_t& unpacked0, + int8x16_t& unpacked1, + int8x16_t& unpacked2, + int8x16_t& unpacked3, + int8x16_t& unpacked4, + int8x16_t& unpacked5, + int8x16_t& unpacked6, + int8x16_t& unpacked7, + uint8_t* packed) { + static_assert(nbit < 8); + static_assert(nbit >= 2); + + // Currently supported values + static_assert(nbit >= 3); + static_assert(nbit <= 4); + + uint8x16_t shifted0; + uint8x16_t shifted1; + uint8x16_t shifted2; + uint8x16_t shifted3; + uint8x16_t shifted4; + uint8x16_t shifted5; + uint8x16_t shifted6; + uint8x16_t shifted7; + + switch (nbit) { + case 3: + torchao::bitpacking::internal::vec_unpack_128_uint3_values( + shifted0, + shifted1, + shifted2, + shifted3, + shifted4, + shifted5, + shifted6, + shifted7, + packed); + break; + case 4: + torchao::bitpacking::internal::vec_unpack_32_uint4_values( + shifted0, shifted1, packed); + torchao::bitpacking::internal::vec_unpack_32_uint4_values( + shifted2, shifted3, packed + 16); + torchao::bitpacking::internal::vec_unpack_32_uint4_values( + shifted4, shifted5, packed + 32); + torchao::bitpacking::internal::vec_unpack_32_uint4_values( + shifted6, shifted7, packed + 48); + break; + default: + assert(false); + } + + // unshift to move unpacked values to full range + int8x16_t unshift = vdupq_n_s8(-(1 << (nbit - 1))); + unpacked0 = vaddq_s8(vreinterpretq_s8_u8(shifted0), unshift); + unpacked1 = vaddq_s8(vreinterpretq_s8_u8(shifted1), unshift); + unpacked2 = vaddq_s8(vreinterpretq_s8_u8(shifted2), unshift); + unpacked3 = vaddq_s8(vreinterpretq_s8_u8(shifted3), unshift); + unpacked4 = vaddq_s8(vreinterpretq_s8_u8(shifted4), unshift); + unpacked5 = vaddq_s8(vreinterpretq_s8_u8(shifted5), unshift); + unpacked6 = vaddq_s8(vreinterpretq_s8_u8(shifted6), unshift); + unpacked7 = vaddq_s8(vreinterpretq_s8_u8(shifted7), unshift); +} + +} // namespace bitpacking +} // namespace torchao diff --git a/torchao/experimental/kernels/cpu/aarch64/bitpacking/macro.h b/torchao/experimental/kernels/cpu/aarch64/bitpacking/macro.h new file mode 100644 index 0000000000..6bd06e0dfe --- /dev/null +++ b/torchao/experimental/kernels/cpu/aarch64/bitpacking/macro.h 
@@ -0,0 +1,3 @@ +#pragma once + +#define TORCHAO_ALWAYS_INLINE __attribute__((always_inline)) diff --git a/torchao/experimental/kernels/cpu/aarch64/bitpacking/uint3.h b/torchao/experimental/kernels/cpu/aarch64/bitpacking/uint3.h new file mode 100644 index 0000000000..b747148092 --- /dev/null +++ b/torchao/experimental/kernels/cpu/aarch64/bitpacking/uint3.h @@ -0,0 +1,327 @@ +// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + +#pragma once +#include +#include + +// This file contains bitpacking and unpacking methods for uint3. +// These are not inteded to be used outside of bitpacking directory. +// See bitpack.h for the interface. + +namespace torchao { +namespace bitpacking { +namespace internal { + +TORCHAO_ALWAYS_INLINE inline void pack_8_uint3_values( + uint8_t* packed, + const uint8_t* unpacked) { + // Given 8 unpacked uint3 values: 0ab, 1cd, 2ef, 3gh, 4ij, 5kl, 6mn, 7op, + // this function packs them as: + // b2: 7|6|5|4|3|2|1|0 (upper bits for all values) + // b10_0: gh|ef|cd|ab (lower 2 bits for first 4 values) + // b10_1: op|mn|kl|ij (lower 2 bits for last 4 values) + // These are stored in packed as: b2, b10_0, b10_1 + // + // Input is 8 bytes + // Output is 24 bits = 3 bytes + + // b2 + packed[0] = ((unpacked[0] & 4) >> 2) | ((unpacked[1] & 4) >> 1) | + ((unpacked[2] & 4)) | ((unpacked[3] & 4) << 1) | + ((unpacked[4] & 4) << 2) | ((unpacked[5] & 4) << 3) | + ((unpacked[6] & 4) << 4) | ((unpacked[7] & 4) << 5); + + // b10_0 + packed[1] = (unpacked[0] & 3) | ((unpacked[1] & 3) << 2) | + ((unpacked[2] & 3) << 4) | ((unpacked[3] & 3) << 6); + + // b10_1 + packed[2] = (unpacked[4] & 3) | ((unpacked[5] & 3) << 2) | + ((unpacked[6] & 3) << 4) | ((unpacked[7] & 3) << 6); +} + +TORCHAO_ALWAYS_INLINE inline void unpack_8_uint3_values( + uint8_t* unpacked, + const uint8_t* packed) { + // Unpacks data packed by pack_8_uint3_values + // + // Input is 24 bits = 3 bytes + // Output is 8 bytes + + uint8_t b2 = packed[0]; + uint8_t b10_0 = packed[1]; + uint8_t b10_1 = packed[2]; + + unpacked[0] = ((b2 & 1) << 2) | (b10_0 & 3); + unpacked[1] = ((b2 & 2) << 1) | ((b10_0 & 12) >> 2); + unpacked[2] = (b2 & 4) | ((b10_0 & 48) >> 4); + unpacked[3] = ((b2 & 8) >> 1) | ((b10_0 & 192) >> 6); + + unpacked[4] = ((b2 & 16) >> 2) | (b10_1 & 3); + unpacked[5] = ((b2 & 32) >> 3) | ((b10_1 & 12) >> 2); + unpacked[6] = ((b2 & 64) >> 4) | ((b10_1 & 48) >> 4); + unpacked[7] = ((b2 & 128) >> 5) | ((b10_1 & 192) >> 6); +} + +TORCHAO_ALWAYS_INLINE inline void vec_pack_64_uint3_values( + uint8_t* packed, + const uint8x16_t& unpacked0, + const uint8x16_t& unpacked1, + const uint8x16_t& unpacked2, + const uint8x16_t& unpacked3) { + // This function is a vectorized version of pack_8_uint3_values + // To understand it, please see pack_8_uint3_values first. 
+ // Before each code section, there is a comment indicating the + // code in pack_8_uint3_values that is being vectorized + // + // Input is 64 bytes + // Output is 3*64= 192 bits = 24 bytes + + uint8x8_t b2; + uint8x8_t mask; + + // b2 + // packed[0] = ((unpacked[0] & 4) >> 2) | ((unpacked[1] & 4) >> 1) | + // ((unpacked[2] & 4)) | ((unpacked[3] & 4) << 1) | + // ((unpacked[4] & 4) << 2) | ((unpacked[5] & 4) << 3) | + // ((unpacked[6] & 4) << 4) | ((unpacked[7] & 4) << 5); + mask = vdup_n_u8(4); + b2 = vshr_n_u8(vand_u8(vget_low_u8(unpacked0), mask), 2); + b2 = vorr_u8(b2, vshr_n_u8(vand_u8(vget_high_u8(unpacked0), mask), 1)); + + b2 = vorr_u8(b2, vand_u8(vget_low_u8(unpacked1), mask)); + b2 = vorr_u8(b2, vshl_n_u8(vand_u8(vget_high_u8(unpacked1), mask), 1)); + + b2 = vorr_u8(b2, vshl_n_u8(vand_u8(vget_low_u8(unpacked2), mask), 2)); + b2 = vorr_u8(b2, vshl_n_u8(vand_u8(vget_high_u8(unpacked2), mask), 3)); + + b2 = vorr_u8(b2, vshl_n_u8(vand_u8(vget_low_u8(unpacked3), mask), 4)); + b2 = vorr_u8(b2, vshl_n_u8(vand_u8(vget_high_u8(unpacked3), mask), 5)); + + vst1_u8(packed, b2); + + // b10_0 + // packed[1] = (unpacked[0] & 3) | ((unpacked[1] & 3) << 2) | + // ((unpacked[2] & 3) << 4) | ((unpacked[3] & 3) << 6); + mask = vdup_n_u8(3); + uint8x8_t b10_0; + + b10_0 = vand_u8(vget_low_u8(unpacked0), mask); + b10_0 = vorr_u8(b10_0, vshl_n_u8(vand_u8(vget_high_u8(unpacked0), mask), 2)); + + b10_0 = vorr_u8(b10_0, vshl_n_u8(vand_u8(vget_low_u8(unpacked1), mask), 4)); + b10_0 = vorr_u8(b10_0, vshl_n_u8(vand_u8(vget_high_u8(unpacked1), mask), 6)); + + vst1_u8(packed + 8, b10_0); + + // b10_1 + // packed[2] = (unpacked[4] & 3) | ((unpacked[5] & 3) << 2) | + // ((unpacked[6] & 3) << 4) | ((unpacked[7] & 3) << 6); + uint8x8_t b10_1; + + b10_1 = vand_u8(vget_low_u8(unpacked2), mask); + b10_1 = vorr_u8(b10_1, vshl_n_u8(vand_u8(vget_high_u8(unpacked2), mask), 2)); + + b10_1 = vorr_u8(b10_1, vshl_n_u8(vand_u8(vget_low_u8(unpacked3), mask), 4)); + b10_1 = vorr_u8(b10_1, vshl_n_u8(vand_u8(vget_high_u8(unpacked3), mask), 6)); + + vst1_u8(packed + 16, b10_1); +} + +TORCHAO_ALWAYS_INLINE inline void vec_unpack_64_uint3_values( + uint8x16_t& unpacked0, + uint8x16_t& unpacked1, + uint8x16_t& unpacked2, + uint8x16_t& unpacked3, + const uint8_t* packed) { + // Unpacks data packed by pack_64_uint3_values + // + // This function vectorizes vec_unpack_8_uint3_values + // To understand it, please see vec_unpack_8_uint3_values first. 
+ // Before each code section, there is a comment indicating the + // code in vec_unpack_8_uint3_values that is being vectorized + + // Input is 3*64= 192 bits = 24 bytes + // Output is 64 bytes + + uint8x8_t b2 = vld1_u8(packed); + uint8x8_t b10_0 = vld1_u8(packed + 8); + uint8x8_t unpacked_tmp0; + uint8x8_t unpacked_tmp1; + + // unpacked[0] = ((b2 & 1) << 2) | (b10_0 & 3); + unpacked_tmp0 = vshl_n_u8(vand_u8(b2, vdup_n_u8(1)), 2); + unpacked_tmp0 = vorr_u8(unpacked_tmp0, vand_u8(b10_0, vdup_n_u8(3))); + + // unpacked[1] = ((b2 & 2) << 1) | ((b10_0 & 12) >> 2); + unpacked_tmp1 = vshl_n_u8(vand_u8(b2, vdup_n_u8(2)), 1); + unpacked_tmp1 = + vorr_u8(unpacked_tmp1, vshr_n_u8(vand_u8(b10_0, vdup_n_u8(12)), 2)); + + unpacked0 = vcombine_u8(unpacked_tmp0, unpacked_tmp1); + + // unpacked[2] = (b2 & 4) | ((b10_0 & 48) >> 4); + unpacked_tmp0 = vand_u8(b2, vdup_n_u8(4)); + unpacked_tmp0 = + vorr_u8(unpacked_tmp0, vshr_n_u8(vand_u8(b10_0, vdup_n_u8(48)), 4)); + + // unpacked[3] = ((b2 & 8) >> 1) | ((b10_0 & 192) >> 6); + unpacked_tmp1 = vshr_n_u8(vand_u8(b2, vdup_n_u8(8)), 1); + unpacked_tmp1 = + vorr_u8(unpacked_tmp1, vshr_n_u8(vand_u8(b10_0, vdup_n_u8(192)), 6)); + + unpacked1 = vcombine_u8(unpacked_tmp0, unpacked_tmp1); + + // unpacked[4] = ((b2 & 16) >> 2) | (b10_1 & 3); + uint8x8_t b10_1 = vld1_u8(packed + 16); + unpacked_tmp0 = vshr_n_u8(vand_u8(b2, vdup_n_u8(16)), 2); + unpacked_tmp0 = vorr_u8(unpacked_tmp0, vand_u8(b10_1, vdup_n_u8(3))); + + // unpacked[5] = ((b2 & 32) >> 3) | ((b10_1 & 12) >> 2); + unpacked_tmp1 = vshr_n_u8(vand_u8(b2, vdup_n_u8(32)), 3); + unpacked_tmp1 = + vorr_u8(unpacked_tmp1, vshr_n_u8(vand_u8(b10_1, vdup_n_u8(12)), 2)); + + unpacked2 = vcombine_u8(unpacked_tmp0, unpacked_tmp1); + + // unpacked[6] = ((b2 & 64) >> 4) | ((b10_1 & 48) >> 4); + unpacked_tmp0 = vshr_n_u8(vand_u8(b2, vdup_n_u8(64)), 4); + unpacked_tmp0 = + vorr_u8(unpacked_tmp0, vshr_n_u8(vand_u8(b10_1, vdup_n_u8(48)), 4)); + + // unpacked[7] = ((b2 & 128) >> 5) | ((b10_1 & 192) >> 6); + unpacked_tmp1 = vshr_n_u8(vand_u8(b2, vdup_n_u8(128)), 5); + unpacked_tmp1 = + vorr_u8(unpacked_tmp1, vshr_n_u8(vand_u8(b10_1, vdup_n_u8(192)), 6)); + unpacked3 = vcombine_u8(unpacked_tmp0, unpacked_tmp1); +} + +TORCHAO_ALWAYS_INLINE inline void vec_pack_128_uint3_values( + uint8_t* packed, + const uint8x16_t& unpacked0, + const uint8x16_t& unpacked1, + const uint8x16_t& unpacked2, + const uint8x16_t& unpacked3, + const uint8x16_t& unpacked4, + const uint8x16_t& unpacked5, + const uint8x16_t& unpacked6, + const uint8x16_t& unpacked7) { + // This function is a vectorized version of pack_8_uint3_values + // To understand it, please see pack_8_uint3_values first. 
+ // Before each code section, there is a comment indicating the + // code in pack_8_uint3_values that is being vectorized + // + // Input is 128 bytes + // Output is 3*128= 384 bits = 48 bytes + + uint8x16_t b2; + uint8x16_t mask; + + // b2 + // packed[0] = ((unpacked[0] & 4) >> 2) | ((unpacked[1] & 4) >> 1) | + // ((unpacked[2] & 4)) | ((unpacked[3] & 4) << 1) | + // ((unpacked[4] & 4) << 2) | ((unpacked[5] & 4) << 3) | + // ((unpacked[6] & 4) << 4) | ((unpacked[7] & 4) << 5); + mask = vdupq_n_u8(4); + b2 = vshrq_n_u8(vandq_u8(unpacked0, mask), 2); + b2 = vorrq_u8(b2, vshrq_n_u8(vandq_u8(unpacked1, mask), 1)); + b2 = vorrq_u8(b2, vandq_u8(unpacked2, mask)); + b2 = vorrq_u8(b2, vshlq_n_u8(vandq_u8(unpacked3, mask), 1)); + b2 = vorrq_u8(b2, vshlq_n_u8(vandq_u8(unpacked4, mask), 2)); + b2 = vorrq_u8(b2, vshlq_n_u8(vandq_u8(unpacked5, mask), 3)); + b2 = vorrq_u8(b2, vshlq_n_u8(vandq_u8(unpacked6, mask), 4)); + b2 = vorrq_u8(b2, vshlq_n_u8(vandq_u8(unpacked7, mask), 5)); + + vst1q_u8(packed, b2); + + // b10_0 + // packed[1] = (unpacked[0] & 3) | ((unpacked[1] & 3) << 2) | + // ((unpacked[2] & 3) << 4) | ((unpacked[3] & 3) << 6); + mask = vdupq_n_u8(3); + uint8x16_t b10_0; + + b10_0 = vandq_u8(unpacked0, mask); + b10_0 = vorrq_u8(b10_0, vshlq_n_u8(vandq_u8(unpacked1, mask), 2)); + b10_0 = vorrq_u8(b10_0, vshlq_n_u8(vandq_u8(unpacked2, mask), 4)); + b10_0 = vorrq_u8(b10_0, vshlq_n_u8(vandq_u8(unpacked3, mask), 6)); + + vst1q_u8(packed + 16, b10_0); + + // b10_1 + // packed[2] = (unpacked[4] & 3) | ((unpacked[5] & 3) << 2) | + // ((unpacked[6] & 3) << 4) | ((unpacked[7] & 3) << 6); + uint8x16_t b10_1; + b10_1 = vandq_u8(unpacked4, mask); + b10_1 = vorrq_u8(b10_1, vshlq_n_u8(vandq_u8(unpacked5, mask), 2)); + b10_1 = vorrq_u8(b10_1, vshlq_n_u8(vandq_u8(unpacked6, mask), 4)); + b10_1 = vorrq_u8(b10_1, vshlq_n_u8(vandq_u8(unpacked7, mask), 6)); + + vst1q_u8(packed + 32, b10_1); +} + +TORCHAO_ALWAYS_INLINE inline void vec_unpack_128_uint3_values( + uint8x16_t& unpacked0, + uint8x16_t& unpacked1, + uint8x16_t& unpacked2, + uint8x16_t& unpacked3, + uint8x16_t& unpacked4, + uint8x16_t& unpacked5, + uint8x16_t& unpacked6, + uint8x16_t& unpacked7, + const uint8_t* packed) { + // Unpacks data packed by pack_128_uint3_values + // + // This function vectorizes vec_unpack_8_uint3_values + // To understand it, please see vec_unpack_8_uint3_values first. 
+ // Before each code section, there is a comment indicating the + // code in vec_unpack_8_uint3_values that is being vectorized + + // Input is 3*128 = 384 bits = 48 bytes + // Output is 128 bytes + + uint8x16_t b2 = vld1q_u8(packed); + uint8x16_t b10_0 = vld1q_u8(packed + 16); + + // unpacked[0] = ((b2 & 1) << 2) | (b10_0 & 3); + unpacked0 = vshlq_n_u8(vandq_u8(b2, vdupq_n_u8(1)), 2); + unpacked0 = vorrq_u8(unpacked0, vandq_u8(b10_0, vdupq_n_u8(3))); + + // unpacked[1] = ((b2 & 2) << 1) | ((b10_0 & 12) >> 2); + unpacked1 = vshlq_n_u8(vandq_u8(b2, vdupq_n_u8(2)), 1); + unpacked1 = + vorrq_u8(unpacked1, vshrq_n_u8(vandq_u8(b10_0, vdupq_n_u8(12)), 2)); + + // unpacked[2] = (b2 & 4) | ((b10_0 & 48) >> 4); + unpacked2 = vandq_u8(b2, vdupq_n_u8(4)); + unpacked2 = + vorrq_u8(unpacked2, vshrq_n_u8(vandq_u8(b10_0, vdupq_n_u8(48)), 4)); + + // unpacked[3] = ((b2 & 8) >> 1) | ((b10_0 & 192) >> 6); + unpacked3 = vshrq_n_u8(vandq_u8(b2, vdupq_n_u8(8)), 1); + unpacked3 = + vorrq_u8(unpacked3, vshrq_n_u8(vandq_u8(b10_0, vdupq_n_u8(192)), 6)); + + // unpacked[4] = ((b2 & 16) >> 2) | (b10_1 & 3); + uint8x16_t b10_1 = vld1q_u8(packed + 32); + unpacked4 = vshrq_n_u8(vandq_u8(b2, vdupq_n_u8(16)), 2); + unpacked4 = vorrq_u8(unpacked4, vandq_u8(b10_1, vdupq_n_u8(3))); + + // unpacked[5] = ((b2 & 32) >> 3) | ((b10_1 & 12) >> 2); + unpacked5 = vshrq_n_u8(vandq_u8(b2, vdupq_n_u8(32)), 3); + unpacked5 = + vorrq_u8(unpacked5, vshrq_n_u8(vandq_u8(b10_1, vdupq_n_u8(12)), 2)); + + // unpacked[6] = ((b2 & 64) >> 4) | ((b10_1 & 48) >> 4); + unpacked6 = vshrq_n_u8(vandq_u8(b2, vdupq_n_u8(64)), 4); + unpacked6 = + vorrq_u8(unpacked6, vshrq_n_u8(vandq_u8(b10_1, vdupq_n_u8(48)), 4)); + + // unpacked[7] = ((b2 & 128) >> 5) | ((b10_1 & 192) >> 6); + unpacked7 = vshrq_n_u8(vandq_u8(b2, vdupq_n_u8(128)), 5); + unpacked7 = + vorrq_u8(unpacked7, vshrq_n_u8(vandq_u8(b10_1, vdupq_n_u8(192)), 6)); +} + +} // namespace internal +} // namespace bitpacking +} // namespace torchao diff --git a/torchao/experimental/kernels/cpu/aarch64/bitpacking/uint4.h b/torchao/experimental/kernels/cpu/aarch64/bitpacking/uint4.h new file mode 100644 index 0000000000..c30949d72b --- /dev/null +++ b/torchao/experimental/kernels/cpu/aarch64/bitpacking/uint4.h @@ -0,0 +1,66 @@ +// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + +#pragma once +#include +#include + +// This file contains bitpacking and unpacking methods for uint4. +// These are not inteded to be used outside of bitpacking directory. +// See bitpack.h for the interface. 
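A worked example of the uint3 layout documented in pack_8_uint3_values above: eight 3-bit values become three bytes, one byte holding bit 2 of each value (value i at bit position i) and two bytes holding the lower 2-bit pairs of values 0-3 and 4-7. This standalone check is illustrative only and not part of the patch.

#include <cassert>
#include <cstdint>

int main() {
  // Eight uint3 values 0..7 (binary 000, 001, ..., 111).
  uint8_t unpacked[8] = {0, 1, 2, 3, 4, 5, 6, 7};
  uint8_t packed[3] = {0, 0, 0};

  // b2: bit 2 of value i goes to bit i; only values 4..7 have bit 2 set.
  for (int i = 0; i < 8; i++) {
    packed[0] |= ((unpacked[i] >> 2) & 1) << i;
  }
  // b10_0 / b10_1: lower 2 bits of values 0..3 and 4..7, in 2-bit fields.
  packed[1] = (unpacked[0] & 3) | ((unpacked[1] & 3) << 2) |
              ((unpacked[2] & 3) << 4) | ((unpacked[3] & 3) << 6);
  packed[2] = (unpacked[4] & 3) | ((unpacked[5] & 3) << 2) |
              ((unpacked[6] & 3) << 4) | ((unpacked[7] & 3) << 6);

  assert(packed[0] == 0xF0);  // 0b11110000
  assert(packed[1] == 0xE4);  // 0b11100100: 2-bit fields 3,2,1,0
  assert(packed[2] == 0xE4);  // lower 2 bits of 4..7 are again 0..3
  return 0;
}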
+ +namespace torchao { +namespace bitpacking { +namespace internal { + +TORCHAO_ALWAYS_INLINE inline void pack_2_uint4_values( + uint8_t* packed, + const uint8_t* unpacked) { + packed[0] = (unpacked[0] << 4) | (unpacked[1] & 0xF); +} + +TORCHAO_ALWAYS_INLINE inline void unpack_2_uint4_values( + uint8_t* unpacked, + const uint8_t* packed) { + unpacked[0] = packed[0] >> 4; + unpacked[1] = packed[0] & 0xF; +} + +TORCHAO_ALWAYS_INLINE inline void vec_pack_16_uint4_values( + uint8_t* packed, + const uint8x16_t& unpacked) { + uint8x8_t unpacked_low = vget_low_u8(unpacked); + uint8x8_t unpacked_high = vshl_n_u8(vget_high_u8(unpacked), 4); + uint8x8_t packed_to_st = vorr_u8(unpacked_low, unpacked_high); + vst1_u8(packed, packed_to_st); +} + +TORCHAO_ALWAYS_INLINE inline void vec_unpack_16_uint4_values( + uint8x16_t& unpacked, + const uint8_t* packed) { + uint8x8_t packed_ld = vld1_u8(packed); + uint8x8_t high = vshr_n_u8(packed_ld, 4); + uint8x8_t low = vand_u8(packed_ld, vdup_n_u8(0xF)); + unpacked = vcombine_u8(low, high); +} + +TORCHAO_ALWAYS_INLINE inline void vec_pack_32_uint4_values( + uint8_t* packed, + const uint8x16_t& unpacked0, + const uint8x16_t& unpacked1) { + uint8x16_t high = vshlq_n_u8(unpacked1, 4); + uint8x16_t packed_to_st = vorrq_u8(unpacked0, high); + vst1q_u8(packed, packed_to_st); +} + +TORCHAO_ALWAYS_INLINE inline void vec_unpack_32_uint4_values( + uint8x16_t& unpacked0, + uint8x16_t& unpacked1, + const uint8_t* packed) { + uint8x16_t packed_ld = vld1q_u8(packed); + unpacked1 = vshrq_n_u8(packed_ld, 4); + unpacked0 = vandq_u8(packed_ld, vdupq_n_u8(0xF)); +} + +} // namespace internal +} // namespace bitpacking +} // namespace torchao diff --git a/torchao/experimental/kernels/cpu/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight_1x1x32_f32_neondot-impl.h b/torchao/experimental/kernels/cpu/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight_1x1x32_f32_neondot-impl.h new file mode 100644 index 0000000000..626bff3487 --- /dev/null +++ b/torchao/experimental/kernels/cpu/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight_1x1x32_f32_neondot-impl.h @@ -0,0 +1,361 @@ +// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + +#pragma once +#include +#include +#include +#include +#include + +namespace torchao::kernels::cpu::aarch64::linear { +namespace channelwise_8bit_activation_groupwise_lowbit_weight_1x1x32_f32_neondot:: + internal { + +inline float clamp(float x, float min, float max) { + if (x < min) + return min; + if (x > max) + return max; + return x; +} + +// Implements variants of +// channelwise_8bit_activation_groupwise_lowbit_weight_1x1x32_f32_neondot +// to compute +// output = F(activations * weights + bias) +// where +// +// * activations are mxk and transposed, stored in row-major order. +// * weights are kxn and transposed, stored in column-major order. +// (can also be viewed as nxk non-transposed weights stored in row-major +// order). +// * bias are mx1. Ignored if has_bias = false. +// * F is an element-wise activation function, either clamp (has_clamp = true) +// or linear (has_clamp = false). +// * output is mxn. +// +// The suffix 1x1x32_f32_neondot indicates the tile size (1x1), the number of +// values unpacked in each inner loop (32), floating point type for output +// (f32), and main ISA instruction (neon_dot). 
+// +// Activations are channelwise 8-bit quantized, with a scale and zero per row +// Weights are groupwise lowbit (weight_nbit) quantized with a scale (and zero +// if has_weight_zeros = true) per group. +// +// Both activations and weights are dequantized with +// scale * (qval - zero) +// +// The output is computed by dequantizing the activations and weights and +// computing F(activations * weights + bias). +// +// Activations and weights are stored in a prepared format specific to +// this kernel: +// +// activation_data +// Per m_idx (row), activations are stored as follows: +// scale (float), zero (int8_t), +// group0_qvals (int8_t[group_size]), [group0_qvals_sum (int32_t)]? +// group1_qvals (int8_t[group_size]), [group1_qvals_sum (int32_t)]? +// ... +// +// The groupi_qvals_sum is only present if has_weight_zeros = true. +// +// weight_data +// Per n_idx (column), weights are stored as follows: +// group0_qvals (int8_t[group_size]), group0_scale (float), group0_qvals_sum +// (int32_t), [group0_zero (int8_t)]? +// ... +// The groupi_zero is only present if has_weight_zeros = true. +template +void kernel_impl( + // Outputs + float32_t* output, + // Inputs + int output_m_stride, + int m, + int n, + int k, + int group_size, + const void* weight_data, + const void* activation_data, + // Ignored if has_bias is false + const float* bias, + // Ignored if has_clamp is false + float clamp_min, + float clamp_max) { + assert(k % group_size == 0); + assert(group_size % 32 == 0); + constexpr int bytes_per_32_weight_values = 4 * weight_nbit; + + auto activation_data_byte_ptr = (char*)(activation_data); + char* activation_ptr; + + for (int m_idx = 0; m_idx < m; m_idx++) { + // Read activation scale and zero + float activation_scale = *((float*)activation_data_byte_ptr); + activation_data_byte_ptr += sizeof(float); + + int8_t activation_zero = *((int8_t*)activation_data_byte_ptr); + activation_data_byte_ptr += sizeof(int8_t); + + // Set weight_data_byte_ptr to start of weight_data + auto weight_data_byte_ptr = (char*)(weight_data); + for (int n_idx = 0; n_idx < n; n_idx++) { + // Set activation_ptr to start of activation qvals for row m_idx + activation_ptr = activation_data_byte_ptr; + float res = 0.0; + + // Loop k_idx by group + for (int k_idx = 0; k_idx < k; k_idx += group_size) { + // Process group in chunks of 32, accumulating dot products in acc + int32x4_t acc = vdupq_n_s32(0); + int8x16_t wq0, wq1, aq; + + for (int i = 0; i < group_size; i += 32) { + torchao::bitpacking::vec_unpack_32_lowbit_values( + /*unpacked0=*/wq0, + /*unpacked1=*/wq1, + /*packed=*/(uint8_t*)weight_data_byte_ptr); + + weight_data_byte_ptr += bytes_per_32_weight_values; + + // Dot product of first 16 values in chunk + aq = vld1q_s8((int8_t*)activation_ptr); + activation_ptr += 16; + acc = vdotq_s32(acc, wq0, aq); + + // Dot product of second 16 values in chunk + aq = vld1q_s8((int8_t*)activation_ptr); + activation_ptr += 16; + acc = vdotq_s32(acc, wq1, aq); + } + int32_t qval_dot = vaddvq_s32(acc); + + // Dequantize and accumulate in result + float weight_scale = *((float*)weight_data_byte_ptr); + weight_data_byte_ptr += sizeof(float); + + int32_t weight_qvals_sum = *((int32_t*)weight_data_byte_ptr); + weight_data_byte_ptr += sizeof(int32_t); + + if constexpr (has_weight_zeros) { + int32_t activation_qvals_sum = *((int32_t*)activation_ptr); + activation_ptr += sizeof(int32_t); + + int8_t weight_zero = *((int8_t*)weight_data_byte_ptr); + weight_data_byte_ptr += sizeof(int8_t); + + res += (weight_scale * 
activation_scale) * + (qval_dot - (activation_zero * weight_qvals_sum) - + (weight_zero * activation_qvals_sum) + + (group_size * weight_zero * activation_zero)); + } else { + res += (weight_scale * activation_scale) * + (qval_dot - activation_zero * weight_qvals_sum); + } + } // k_idx + if constexpr (has_bias) { + res += bias[m_idx]; + } + if constexpr (has_clamp) { + res = clamp(res, clamp_min, clamp_max); + } + output[m_idx * output_m_stride + n_idx] = res; + } // n_idx + activation_data_byte_ptr += (activation_ptr - activation_data_byte_ptr); + } // m_idx +} + +// Prepares weight data for kernel_impl. +// Per n_idx (column), weights are stored as follows: +// group0_qvals (int8_t[group_size]), group0_scale (float), group0_qvals_sum +// (int32_t), [group0_zero (int8_t)]? +// ... +// The groupi_zero is only present if has_weight_zeros = true. + +// Returns number of bytes required for weight_data +int inline weight_data_size_impl( + int n, + int k, + int group_size, + int weight_nbit, + bool has_weight_zeros) { + assert(k % group_size == 0); + assert(k % 32 == 0); + int groups_per_col = k / group_size; + int col_size = 0; + + // qvals + // (k * weight_bit) bits -> ((k / 8) * weight_bit) bytes + col_size += (k / 8) * weight_nbit; + + // scales + col_size += sizeof(float) * groups_per_col; + + // qvals_sum + col_size += sizeof(int32_t) * groups_per_col; + + // zeros + if (has_weight_zeros) { + col_size += sizeof(int8_t) * groups_per_col; + } + + return col_size * n; +} + +template +void prepare_weight_data_impl( + // Output + void* weight_data, + // Inputs + int n, + int k, + int group_size, + const int8_t* weight_qvals, + const float* weight_scales, + // Ignored if has_weight_zeros = false + const int8_t* weight_zeros) { + assert(k % group_size == 0); + assert(group_size % 32 == 0); + + auto weight_data_byte_ptr = (char*)weight_data; + constexpr int bytes_per_32_weight_values = 4 * weight_nbit; + + int8x16_t wq0, wq1; + + const int8_t* qvals_ptr = weight_qvals; + const float* scales_ptr = weight_scales; + const int8_t* zeros_ptr = weight_zeros; + + for (int n_idx = 0; n_idx < n; n_idx++) { + for (int k_idx = 0; k_idx < k; k_idx += group_size) { + int32_t group_qvals_sum = 0; + for (int i = 0; i < group_size; i += 32) { + wq0 = vld1q_s8(qvals_ptr); + wq1 = vld1q_s8(qvals_ptr + 16); + qvals_ptr += 32; + + group_qvals_sum += vaddlvq_s8(wq0) + vaddlvq_s8(wq1); + + torchao::bitpacking::vec_pack_32_lowbit_values( + /*packed=*/(uint8_t*)weight_data_byte_ptr, + /*unpacked0=*/wq0, + /*unpacked1=*/wq1); + weight_data_byte_ptr += bytes_per_32_weight_values; + } + *((float*)weight_data_byte_ptr) = *scales_ptr++; + weight_data_byte_ptr += sizeof(float); + + *((int32_t*)weight_data_byte_ptr) = group_qvals_sum; + weight_data_byte_ptr += sizeof(int32_t); + + if constexpr (has_weight_zeros) { + *((int8_t*)weight_data_byte_ptr) = *zeros_ptr++; + weight_data_byte_ptr += sizeof(int8_t); + } + } + } +} + +} // namespace + // channelwise_8bit_activation_groupwise_lowbit_weight_1x1x32_f32_neondot::internal +} // namespace torchao::kernels::cpu::aarch64::linear + +// Activation functions +template +int torchao::kernels::cpu::aarch64::linear:: + channelwise_8bit_activation_groupwise_lowbit_weight_1x1x32_f32_neondot:: + activation_data_size(int m, int k, int group_size) { + return torchao::kernels::cpu::aarch64::linear:: + channelwise_8bit_activation_prepare_activation_data_1xk_f32::internal:: + activation_data_size_impl(m, k, group_size, has_weight_zeros); +} + +template +void 
torchao::kernels::cpu::aarch64::linear:: + channelwise_8bit_activation_groupwise_lowbit_weight_1x1x32_f32_neondot:: + prepare_activation_data( + void* activation_data, + // Inputs + int m, + int k, + // Ignored if has_weight_zeros = false + int group_size, + const float* activations) { + torchao::kernels::cpu::aarch64::linear:: + channelwise_8bit_activation_prepare_activation_data_1xk_f32::internal:: + prepare_activation_data_impl( + activation_data, m, k, group_size, activations); +} + +// Weight functions +template +int torchao::kernels::cpu::aarch64::linear:: + channelwise_8bit_activation_groupwise_lowbit_weight_1x1x32_f32_neondot:: + weight_data_size(int n, int k, int group_size) { + return torchao::kernels::cpu::aarch64::linear:: + channelwise_8bit_activation_groupwise_lowbit_weight_1x1x32_f32_neondot:: + internal::weight_data_size_impl( + n, k, group_size, weight_nbit, has_weight_zeros); +} + +template +void torchao::kernels::cpu::aarch64::linear:: + channelwise_8bit_activation_groupwise_lowbit_weight_1x1x32_f32_neondot:: + prepare_weight_data( + void* weight_data, + // Inputs + int n, + int k, + int group_size, + const int8_t* weight_qvals, + const float* weight_scales, + const int8_t* weight_zeros) { + torchao::kernels::cpu::aarch64::linear:: + channelwise_8bit_activation_groupwise_lowbit_weight_1x1x32_f32_neondot:: + internal::prepare_weight_data_impl( + weight_data, + n, + k, + group_size, + weight_qvals, + weight_scales, + weight_zeros); +} + +// Kernel function +template +void torchao::kernels::cpu::aarch64::linear:: + channelwise_8bit_activation_groupwise_lowbit_weight_1x1x32_f32_neondot:: + kernel( + // Outputs + float32_t* output, + // Inputs + int output_m_stride, + int m, + int n, + int k, + int group_size, + const void* weight_data, + const void* activation_data, + // Not applied if nullptr + const float* bias, + // Ignored if has_clamp = false + float clamp_min, + float clamp_max) { + torchao::kernels::cpu::aarch64::linear:: + channelwise_8bit_activation_groupwise_lowbit_weight_1x1x32_f32_neondot:: + internal:: + kernel_impl( + output, + output_m_stride, + m, + n, + k, + group_size, + weight_data, + activation_data, + bias, + clamp_min, + clamp_max); +} diff --git a/torchao/experimental/kernels/cpu/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot-impl.h b/torchao/experimental/kernels/cpu/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot-impl.h new file mode 100644 index 0000000000..4dbade848f --- /dev/null +++ b/torchao/experimental/kernels/cpu/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot-impl.h @@ -0,0 +1,472 @@ +// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + +#pragma once +#include +#include +#include +#include +#include + +namespace torchao::kernels::cpu::aarch64::linear { +namespace channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot:: + internal { + +inline float32x4_t clamp(float32x4_t x, float min, float max) { + float32x4_t vec_min = vdupq_n_f32(min); + float32x4_t vec_max = vdupq_n_f32(max); + float32x4_t tmp = vmaxq_f32(x, vec_min); + return vminq_f32(tmp, vec_max); +} + +// Implements variants of +// channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot +// to compute +// output = F(activations * weights + bias) +// where +// +// * activations are mxk and transposed, stored in row-major order. +// * weights are kxn and transposed, stored in column-major order. 
+// (can also be viewed as nxk non-transposed weights stored in row-major +// order). +// * bias are mx1. Ignored if has_bias = false. +// * F is an element-wise activation function, either clamp (has_clamp = true) +// or linear (has_clamp = false). +// * output is mxn. +// +// The suffix 1x4x16_f32_neondot indicates the tile sizes (1x4 = 1x16 @ 16x4), +// floating point type for output (f32), and main ISA instruction (neon_dot). +// There are 64 = 4*16 weight values unpacked in each inner loop iteration. +// +// Activations are channelwise 8-bit quantized, with a scale and zero per row +// Weights are groupwise lowbit (weight_nbit) quantized with a scale (and zero +// if has_weight_zeros = true) per group. +// +// Both activations and weights are dequantized with +// scale * (qval - zero) +// +// The output is computed by dequantizing the activations and weights and +// computing F(activations * weights + bias). +// +// Activations and weights are stored in a prepared format specific to +// this kernel. See prepare_weight_data_impl and prepare_activation_data_impl +// functions for details. +// +// Kernel is roughly modeled on +// https://gitlab.arm.com/kleidi/kleidiai/-/blob/main/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod.c + +template +void kernel_impl( + // Outputs + float32_t* output, + // Inputs + int output_m_stride, + int m, + int n, + int k, + int group_size, + const void* weight_data, + const void* activation_data, + // Ignored if has_bias is false + const float* bias, + // Ignored if has_clamp is false + float clamp_min, + float clamp_max) { + assert(k % group_size == 0); + assert(group_size % 16 == 0); + + constexpr int bytes_per_64_weight_values = 8 * weight_nbit; + + auto activation_data_byte_ptr = (char*)(activation_data); + char* activation_ptr; + + for (int m_idx = 0; m_idx < m; m_idx++) { + // Read activation scale and zero + float activation_scale = *((float*)activation_data_byte_ptr); + activation_data_byte_ptr += sizeof(float); + + int activation_zero = (int)(*((int8_t*)activation_data_byte_ptr)); + activation_data_byte_ptr += sizeof(int8_t); + + // Set weight_data_byte_ptr to start of weight_data + auto weight_data_byte_ptr = (char*)(weight_data); + + // Loop over 4 cols at a time + // Weights and activations are padded when prepared, so the + // reads are legal, even if on a partial tile + for (int n_idx = 0; n_idx < n; n_idx += 4) { + // Set activation_ptr to start of activation qvals for row m_idx + activation_ptr = activation_data_byte_ptr; + float32x4_t res = vdupq_n_f32(0.0); + + // Loop k_idx by group + for (int k_idx = 0; k_idx < k; k_idx += group_size) { + // Iterating over k in chunks of 16, we compute the dot product + // between 16 values of activation data with 16 values in each of 4 cols + // of weight data. 
These dot products are stored in accumulators + // acc_cols0011 and acc_cols2233 as indicated in the below table: + // + // weight data activation data accumulator + // ------------------------------------------------------- + // 1st 8 vals of col0 1st 8 vals acc_cols0011[0] + // 2nd 8 vals of col0 2nd 8 vals acc_cols0011[1] + // 1st 8 vals of col1 1st 8 vals acc_cols0011[2] + // 2nd 8 vals of col1 2nd 8 vals acc_cols0011[3] + // 1st 8 vals of col2 1st 8 vals acc_cols2233[0] + // 2nd 8 vals of col2 2nd 8 vals acc_cols2233[1] + // 1st 8 vals of col3 1st 8 vals acc_cols2233[2] + // 2nd 8 vals of col3 2nd 8 vals acc_cols2233[3] + // + // The above computation scheme is what informs the weight valpacking + int32x4_t acc_cols0011 = vdupq_n_s32(0); + int32x4_t acc_cols2233 = vdupq_n_s32(0); + + // holds chunk of 16 activation_q values + int8x16_t act_q; + + // holds chunk of 8 activation vals, duplicated twice + int8x16_t act_q_dup; + + // holds chunk of 8 vals from weight_q col0, followed by 8 vals from + // weight_q col1 + int8x16_t weight_q_cols01_0; + int8x16_t weight_q_cols01_1; + + // holds chunk of 8 vals from weight_q col2, followed by 8 vals from + // weight_q col3 + int8x16_t weight_q_cols23_0; + int8x16_t weight_q_cols23_1; + + for (int i = 0; i < group_size; i += 16) { + // Each chunk is 64 values of unpacked data (4 cols x 16 vals/col). + // This comes out to (64 * weight_nbit / 8) bits = 8 * weight_nbit + // bytes of bitpacked data + torchao::bitpacking::vec_unpack_64_lowbit_values( + weight_q_cols01_0, + weight_q_cols23_0, + weight_q_cols01_1, + weight_q_cols23_1, + (uint8_t*)weight_data_byte_ptr); + weight_data_byte_ptr += bytes_per_64_weight_values; + + // Load 16 activation values + act_q = vld1q_s8((int8_t*)activation_ptr); + activation_ptr += 16; + + // Dot product of first 8 vals of activation data with first 8 vals of + // weight data. Note the sequence of operations here imply the + // following order on weight_data stored in unpacked_buffer: (1st 8 + // vals col0), (1st 8 vals col1), (1st 8 vals col2), (1st 8 vals + // col2). This order is accomplished by valpacking + act_q_dup = vcombine_s8(vget_low_s8(act_q), vget_low_s8(act_q)); + acc_cols0011 = vdotq_s32(acc_cols0011, weight_q_cols01_0, act_q_dup); + acc_cols2233 = vdotq_s32(acc_cols2233, weight_q_cols23_0, act_q_dup); + + // Dot product of second 8 vals of activation data with second 8 vals + // of weight data. 
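+          // (For reference: vdotq_s32(acc, w, a) forms four 4-element int8
+          // dot products, one per 4-byte lane, and accumulates them into the
+          // int32 lanes of acc. The vpaddq_s32 pairwise add after this loop
+          // therefore collapses lanes (0,1) and (2,3), leaving one full
+          // per-column dot product in each lane of qval_dot.)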
+ act_q_dup = vcombine_s8(vget_high_s8(act_q), vget_high_s8(act_q)); + acc_cols0011 = vdotq_s32(acc_cols0011, weight_q_cols01_1, act_q_dup); + acc_cols2233 = vdotq_s32(acc_cols2233, weight_q_cols23_1, act_q_dup); + } + // Reduce accumulators, so we have one dot product value per col + int32x4_t qval_dot = vpaddq_s32(acc_cols0011, acc_cols2233); + + // Dequantize and accumulate in result + float32x4_t weight_scales = vld1q_f32((float*)weight_data_byte_ptr); + weight_data_byte_ptr += 16; + + int32x4_t weight_qvals_sum = vld1q_s32((int32_t*)weight_data_byte_ptr); + weight_data_byte_ptr += 16; + + float32x4_t scale_factor = + vmulq_f32(weight_scales, vdupq_n_f32(activation_scale)); + int32x4_t term1 = vmulq_n_s32(weight_qvals_sum, activation_zero); + + if constexpr (has_weight_zeros) { + // Compute + // res += (weight_scale * activation_scale) * + // (qval_dot - (activation_zero * weight_qvals_sum) - + // (weight_zero * activation_qvals_sum) + + // (group_size * weight_zero * activation_zero)); + + int32_t activation_qvals_sum = *((int32_t*)activation_ptr); + activation_ptr += sizeof(int32_t); + + int32x4_t weight_zeros = vld1q_s32((int32_t*)weight_data_byte_ptr); + weight_data_byte_ptr += 16; + + int32x4_t term2 = vmulq_n_s32(weight_zeros, activation_qvals_sum); + int32x4_t term3 = + vmulq_n_s32(weight_zeros, group_size * activation_zero); + + int32x4_t tmp = vsubq_s32(qval_dot, term1); + tmp = vsubq_s32(tmp, term2); + tmp = vaddq_s32(tmp, term3); + res = vmlaq_f32(res, scale_factor, vcvtq_f32_s32(tmp)); + } else { + // Compute + // res += (weight_scale * activation_scale) * + // (qval_dot - activation_zero * weight_qvals_sum); + auto tmp = vsubq_s32(qval_dot, term1); + res = vmlaq_f32(res, scale_factor, vcvtq_f32_s32(tmp)); + } + + } // k_idx + if constexpr (has_bias) { + res = vaddq_f32(res, vdupq_n_f32(bias[m_idx])); + } + if constexpr (has_clamp) { + res = clamp(res, clamp_min, clamp_max); + } + vst1q_f32(output + m_idx * output_m_stride + n_idx, res); + } // n_idx + activation_data_byte_ptr += (activation_ptr - activation_data_byte_ptr); + } // m_idx +} + +// Prepares weight data for kernel_impl. 
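+// Per tile of 4 columns (n is padded up to a multiple of 4), each group of
+// group_size values is stored as:
+//   bitpacked, interleaved qvals for the 4 columns
+//     (group_size/16 chunks of 8*weight_nbit bytes),
+//   4 scales (float),
+//   4 qvals_sums (int32_t),
+//   [4 zeros (int32_t)]?
+// The zeros are only present if has_weight_zeros = true.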
+ +// Returns number of bytes required for weight_data +int inline weight_data_size_impl( + int n, + int k, + int group_size, + int weight_nbit, + bool has_weight_zeros) { + assert(k % group_size == 0); + int groups_per_col = k / group_size; + int col_size = 0; + + // qvals + col_size += (k / 8) * weight_nbit; + + // scales + col_size += sizeof(float) * groups_per_col; + + // qvals_sum + col_size += sizeof(int32_t) * groups_per_col; + + // zeros + if (has_weight_zeros) { + col_size += sizeof(int32_t) * groups_per_col; + } + + // Replace n with next multiple of 4 >= n + n = ((n + 3) >> 2) << 2; + + return col_size * n; +} + +template +void prepare_weight_data_impl( + // Output + void* weight_data, + // Inputs + int n, + int k, + int group_size, + const int8_t* weight_qvals, + const float* weight_scales, + // Ignored if has_weight_zeros = false + const int8_t* weight_zeros) { + assert(k % group_size == 0); + assert(group_size % 16 == 0); + int groups_per_k = k / group_size; + constexpr int bytes_per_64_weight_values = 8 * weight_nbit; + + auto weight_data_byte_ptr = (char*)weight_data; + const int8_t* qvals_ptr = weight_qvals; + const float* scales_ptr = weight_scales; + const int8_t* zeros_ptr = weight_zeros; + + int8_t interleaved_buffer[64]; + int8_t buffer[64]; + + for (int n_idx = 0; n_idx < n; n_idx += 4) { + for (int k_idx = 0; k_idx < k; k_idx += group_size) { + // Loop over group in chunks of 16, processing 4 columns at at time + int qvals_sum[4] = {0, 0, 0, 0}; + for (int i = 0; i < group_size; i += 16) { + std::memset(buffer, 0, 64); + // Loop over 4 cols +#pragma unroll(4) + for (int j = 0; j < 4; j++) { + if (n_idx + j < n) { + // If qvals_ptr are pre-packed in a naive way, this is where + // unpacking can occur + std::memcpy(buffer + 16 * j, qvals_ptr + k * j, 16); + qvals_sum[j] += + torchao::kernels::cpu::aarch64::reduction::compute_sum( + buffer + 16 * j, 16); + } + } + torchao::kernels::cpu::valpacking::interleave_data( + /*data_interleaved=*/interleaved_buffer, + /*data=*/buffer, + /*bytes_per_val=*/1, + /*vals_per_channel=*/16, + /*vals_per_group=*/16, + /*vals_per_chunk=*/8, + /*channels=*/4, + /*channel_stride_in_vals=*/16); + torchao::bitpacking::vec_pack_64_lowbit_values( + (uint8_t*)weight_data_byte_ptr, + vld1q_s8(interleaved_buffer), + vld1q_s8(interleaved_buffer + 16), + vld1q_s8(interleaved_buffer + 32), + vld1q_s8(interleaved_buffer + 48)); + qvals_ptr += 16; + weight_data_byte_ptr += bytes_per_64_weight_values; + } // loop over group + + // Store weight scales +#pragma unroll(4) + for (int j = 0; j < 4; j++) { + float32_t scale = 0.0; + if (n_idx + j < n) { + scale = *(scales_ptr + j * groups_per_k); + } + *((float*)weight_data_byte_ptr) = scale; + weight_data_byte_ptr += sizeof(float); + } + scales_ptr += 1; + + // Store weight qvals_sum +#pragma unroll(4) + for (int j = 0; j < 4; j++) { + *((int*)weight_data_byte_ptr) = qvals_sum[j]; + weight_data_byte_ptr += sizeof(int); + } + + // Store weight zeros + // I went back and forth on how to store weight_zero. + // Kernel computation is done in int32, so I'm converting these to + // int32 before storing (load 4 int32s in kernel). + // In the 1x8 kernel, we may want to store as int16_t, which reduces + // a load in the kernel (load 8 int16_ts in kernel, instead of 2 + // load 4 int32_ts), but adds 2 moves (int16 to int32). 
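+      // With int32 storage this is 4 * 4 = 16 bytes of zeros per group for
+      // the 4-column tile; int16 storage would halve that.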
+ if constexpr (has_weight_zeros) { +#pragma unroll(4) + for (int j = 0; j < 4; j++) { + int32_t zero = 0; + if (n_idx + j < n) { + zero = (int)(*(zeros_ptr + j * groups_per_k)); + } + *((int32_t*)weight_data_byte_ptr) = zero; + weight_data_byte_ptr += sizeof(int32_t); + } + zeros_ptr += 1; + } + } // k_idx + + // In the previous loop over k, we processed 4 columns at a time, + // but only advanced our pointers over the first column. + // So we advance over the other 3 columns here. + qvals_ptr += 3 * k; + scales_ptr += 3 * groups_per_k; + if constexpr (has_weight_zeros) { + zeros_ptr += 3 * groups_per_k; + } + } // n_idx +} + +} // namespace + // channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot::internal +} // namespace torchao::kernels::cpu::aarch64::linear + +// Activation functions +template +int torchao::kernels::cpu::aarch64::linear:: + channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot:: + activation_data_size(int m, int k, int group_size) { + return torchao::kernels::cpu::aarch64::linear:: + channelwise_8bit_activation_prepare_activation_data_1xk_f32::internal:: + activation_data_size_impl(m, k, group_size, has_weight_zeros); +} + +template +void torchao::kernels::cpu::aarch64::linear:: + channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot:: + prepare_activation_data( + void* activation_data, + // Inputs + int m, + int k, + // Ignored if has_weight_zeros = false + int group_size, + const float* activations) { + torchao::kernels::cpu::aarch64::linear:: + channelwise_8bit_activation_prepare_activation_data_1xk_f32::internal:: + prepare_activation_data_impl( + activation_data, m, k, group_size, activations); +} + +// Weight functions +template +int torchao::kernels::cpu::aarch64::linear:: + channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot:: + weight_data_size(int n, int k, int group_size) { + return torchao::kernels::cpu::aarch64::linear:: + channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot:: + internal::weight_data_size_impl( + n, k, group_size, weight_nbit, has_weight_zeros); +} + +template +void torchao::kernels::cpu::aarch64::linear:: + channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot:: + prepare_weight_data( + void* weight_data, + // Inputs + int n, + int k, + int group_size, + const int8_t* weight_qvals, + const float* weight_scales, + const int8_t* weight_zeros) { + torchao::kernels::cpu::aarch64::linear:: + channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot:: + internal::prepare_weight_data_impl( + weight_data, + n, + k, + group_size, + weight_qvals, + weight_scales, + weight_zeros); +} + +template +void torchao::kernels::cpu::aarch64::linear:: + channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot:: + kernel( + // Outputs + float32_t* output, + // Inputs + int output_m_stride, + int m, + int n, + int k, + int group_size, + const void* weight_data, + const void* activation_data, + // Not applied if nullptr + const float* bias, + // Ignored if has_clamp = false + float clamp_min, + float clamp_max) { + torchao::kernels::cpu::aarch64::linear:: + channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot:: + internal:: + kernel_impl( + output, + output_m_stride, + m, + n, + k, + group_size, + weight_data, + activation_data, + bias, + clamp_min, + clamp_max); +} diff --git a/torchao/experimental/kernels/cpu/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot-impl.h 
b/torchao/experimental/kernels/cpu/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot-impl.h
new file mode 100644
index 0000000000..3e8df5ae60
--- /dev/null
+++ b/torchao/experimental/kernels/cpu/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot-impl.h
@@ -0,0 +1,546 @@
+// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
+
+#pragma once
+#include
+#include
+#include
+#include
+#include
+
+namespace torchao::kernels::cpu::aarch64::linear {
+namespace channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot::
+    internal {
+
+inline float32x4_t
+vec_clamp(float32x4_t x, float32x4_t vec_min, float32x4_t vec_max) {
+  float32x4_t tmp = vmaxq_f32(x, vec_min);
+  return vminq_f32(tmp, vec_max);
+}
+
+// Implements variants of
+// channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot
+// to compute
+// output = F(activations * weights + bias)
+// where
+//
+// * activations are mxk and transposed, stored in row-major order.
+// * weights are kxn and transposed, stored in column-major order.
+//   (can also be viewed as nxk non-transposed weights stored in row-major
+//   order).
+// * bias are mx1. Ignored if has_bias = false.
+// * F is an element-wise activation function, either clamp (has_clamp = true)
+//   or linear (has_clamp = false).
+// * output is mxn.
+//
+// The suffix 1x8x16_f32_neondot indicates the tile sizes (1x8 = 1x16 @ 16x8),
+// floating point type for output (f32), and main ISA instruction (neon_dot).
+// There are 128 = 8*16 weight values unpacked in each inner loop iteration.
+//
+// Activations are channelwise 8-bit quantized, with a scale and zero per row.
+// Weights are groupwise lowbit (weight_nbit) quantized with a scale (and zero
+// if has_weight_zeros = true) per group.
+//
+// Both activations and weights are dequantized with
+// scale * (qval - zero)
+//
+// The output is computed by dequantizing the activations and weights and
+// computing F(activations * weights + bias).
+//
+// Activations and weights are stored in a prepared format specific to
+// this kernel. See prepare_weight_data_impl and prepare_activation_data_impl
+// functions for details.
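+//
+// For intuition, the update applied per group and output column is just the
+// expansion of the dequantized dot product:
+//   sum_j [a_scale * (a_qval[j] - a_zero)] * [w_scale * (w_qval[j] - w_zero)]
+//     = a_scale * w_scale *
+//         (sum_j a_qval[j] * w_qval[j]
+//          - a_zero * sum_j w_qval[j]
+//          - w_zero * sum_j a_qval[j]
+//          + group_size * a_zero * w_zero)
+// This is why prepare_weight_data_impl precomputes each group's qvals_sum
+// (and zero), and prepare_activation_data_impl precomputes the activation
+// qvals_sum when has_weight_zeros = true; if has_weight_zeros = false, the
+// last two terms vanish.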
+// +// Roughly inspired by +// https://gitlab.arm.com/kleidi/kleidiai/-/blob/main/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod.c?ref_type=heads + +template +void kernel_impl( + // Outputs + float32_t* output, + // Inputs + int output_m_stride, + int m, + int n, + int k, + int group_size, + const void* weight_data, + const void* activation_data, + // Ignored if has_bias is false + const float* bias, + // Ignored if has_clamp is false + float clamp_min, + float clamp_max) { + assert(k % group_size == 0); + assert(group_size % 16 == 0); + + constexpr int bytes_per_128_weight_values = 16 * weight_nbit; + + auto activation_data_byte_ptr = (char*)(activation_data); + char* activation_ptr; + + for (int m_idx = 0; m_idx < m; m_idx++) { + // Read activation scale and zero + float activation_scale = *((float*)activation_data_byte_ptr); + activation_data_byte_ptr += sizeof(float); + + int activation_zero = (int)(*((int8_t*)activation_data_byte_ptr)); + activation_data_byte_ptr += sizeof(int8_t); + + // Set weight_data_byte_ptr to start of weight_data + auto weight_data_byte_ptr = (char*)(weight_data); + + // Loop over 8 cols at a time + // Weights and activations are padded when prepared, so the + // reads are legal, even if on a partial tile + for (int n_idx = 0; n_idx < n; n_idx += 8) { + // Set activation_ptr to start of activation qvals for row m_idx + activation_ptr = activation_data_byte_ptr; + float32x4_t res_0123 = vdupq_n_f32(0.0); + float32x4_t res_4567 = vdupq_n_f32(0.0); + + // Loop k_idx by group + for (int k_idx = 0; k_idx < k; k_idx += group_size) { + // Iterating over k in chunks of 16, we compute the dot product + // between 16 values of activation data with 16 values in each of 8 cols + // of weight data. 
These dot products are stored in accumulators + // acc_cols0011, acc_cols2233, acc_cols4455, acc_cols6677 + // as indicated in the below table: + // + // weight data activation data accumulator + // ------------------------------------------------------- + // 1st 8 vals of col0 1st 8 vals acc_cols0011[0] + // 2nd 8 vals of col0 2nd 8 vals acc_cols0011[1] + // 1st 8 vals of col1 1st 8 vals acc_cols0011[2] + // 2nd 8 vals of col1 2nd 8 vals acc_cols0011[3] + // 1st 8 vals of col2 1st 8 vals acc_cols2233[0] + // 2nd 8 vals of col2 2nd 8 vals acc_cols2233[1] + // 1st 8 vals of col3 1st 8 vals acc_cols2233[2] + // 2nd 8 vals of col3 2nd 8 vals acc_cols2233[3] + // 1st 8 vals of col4 1st 8 vals acc_cols4455[0] + // 2nd 8 vals of col4 2nd 8 vals acc_cols4455[1] + // 1st 8 vals of col5 1st 8 vals acc_cols4455[2] + // 2nd 8 vals of col5 2nd 8 vals acc_cols4455[3] + // 1st 8 vals of col6 1st 8 vals acc_cols6677[0] + // 2nd 8 vals of col6 2nd 8 vals acc_cols6677[1] + // 1st 8 vals of col7 1st 8 vals acc_cols6677[2] + // 2nd 8 vals of col7 2nd 8 vals acc_cols6677[3] + // + // The above computation scheme is what informs the weight valpacking + int32x4_t acc_cols0011 = vdupq_n_s32(0); + int32x4_t acc_cols2233 = vdupq_n_s32(0); + int32x4_t acc_cols4455 = vdupq_n_s32(0); + int32x4_t acc_cols6677 = vdupq_n_s32(0); + + // holds chunk of 16 activation_q values + int8x16_t act_q; + + // holds chunk of 8 activation vals, duplicated twice + int8x16_t act_q_dup; + + // holds chunk of 8 vals from weight_q col0, followed by 8 vals from + // weight_q col1 + int8x16_t weight_q_cols01_0; + int8x16_t weight_q_cols01_1; + + // holds chunk of 8 vals from weight_q col2, followed by 8 vals from + // weight_q col3 + int8x16_t weight_q_cols23_0; + int8x16_t weight_q_cols23_1; + + // holds chunk of 8 vals from weight_q col4, followed by 8 vals from + // weight_q col5 + int8x16_t weight_q_cols45_0; + int8x16_t weight_q_cols45_1; + + // holds chunk of 8 vals from weight_q col6, followed by 8 vals from + // weight_q col7 + int8x16_t weight_q_cols67_0; + int8x16_t weight_q_cols67_1; + + for (int i = 0; i < group_size; i += 16) { + // Each chunk is 64 values of unpacked data (4 cols x 16 vals/col). + // This comes out to (64 * weight_nbit / 8) bits = 8 * weight_nbit + // bytes of bitpacked data + torchao::bitpacking::vec_unpack_128_lowbit_values( + weight_q_cols01_0, + weight_q_cols23_0, + weight_q_cols45_0, + weight_q_cols67_0, + weight_q_cols01_1, + weight_q_cols23_1, + weight_q_cols45_1, + weight_q_cols67_1, + (uint8_t*)weight_data_byte_ptr); + weight_data_byte_ptr += bytes_per_128_weight_values; + + // Load 16 activation values + act_q = vld1q_s8((int8_t*)activation_ptr); + activation_ptr += 16; + + // Dot product of first 8 vals of activation data with first 8 vals of + // weight data. Note the sequence of operations here imply the + // following order on weight_data stored in unpacked_buffer: (1st 8 + // vals col0), (1st 8 vals col1), (1st 8 vals col2), (1st 8 vals + // col2). This order is accomplished by valpacking + act_q_dup = vcombine_s8(vget_low_s8(act_q), vget_low_s8(act_q)); + acc_cols0011 = vdotq_s32(acc_cols0011, weight_q_cols01_0, act_q_dup); + acc_cols2233 = vdotq_s32(acc_cols2233, weight_q_cols23_0, act_q_dup); + acc_cols4455 = vdotq_s32(acc_cols4455, weight_q_cols45_0, act_q_dup); + acc_cols6677 = vdotq_s32(acc_cols6677, weight_q_cols67_0, act_q_dup); + + // Dot product of second 8 vals of activation data with second 8 vals + // of weight data. 
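+          // (Note: each vec_unpack_128_lowbit_values call above consumes
+          // bytes_per_128_weight_values = 16 * weight_nbit bytes of packed
+          // data and yields 128 values, i.e. 16 values for each of the 8
+          // columns in this tile.)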
+ act_q_dup = vcombine_s8(vget_high_s8(act_q), vget_high_s8(act_q)); + acc_cols0011 = vdotq_s32(acc_cols0011, weight_q_cols01_1, act_q_dup); + acc_cols2233 = vdotq_s32(acc_cols2233, weight_q_cols23_1, act_q_dup); + acc_cols4455 = vdotq_s32(acc_cols4455, weight_q_cols45_1, act_q_dup); + acc_cols6677 = vdotq_s32(acc_cols6677, weight_q_cols67_1, act_q_dup); + } + // Reduce accumulators, so we have one dot product value per col + int32x4_t qval_dot_0123 = vpaddq_s32(acc_cols0011, acc_cols2233); + int32x4_t qval_dot_4567 = vpaddq_s32(acc_cols4455, acc_cols6677); + + // Result is updated with: + // res += scale_factor * (qval_dot - term1 - term2 + term3), where + // * scale_factor = (weight_scale * activation_scale) + // * term1 = (activation_zero * weight_qvals_sum) + // * term2 = (weight_zero * activation_qvals_sum) + // * term3 = (group_size * weight_zero * activation_zero) + // If has_weight_zeros is false, terms 2 and 3 disappaer. + + // Compute scale_factor + float32x4_t activation_scales = vdupq_n_f32(activation_scale); + + float32x4_t weight_scales = vld1q_f32((float*)weight_data_byte_ptr); + weight_data_byte_ptr += 16; + + float32x4_t scale_factor_0123 = + vmulq_f32(weight_scales, activation_scales); + + weight_scales = vld1q_f32((float*)weight_data_byte_ptr); + weight_data_byte_ptr += 16; + + float32x4_t scale_factor_4567 = + vmulq_f32(weight_scales, activation_scales); + + // Compute term1 + int32x4_t weight_qvals_sum = vld1q_s32((int32_t*)weight_data_byte_ptr); + weight_data_byte_ptr += 16; + + int32x4_t term1_0123 = vmulq_n_s32(weight_qvals_sum, activation_zero); + + weight_qvals_sum = vld1q_s32((int32_t*)weight_data_byte_ptr); + weight_data_byte_ptr += 16; + + int32x4_t term1_4567 = vmulq_n_s32(weight_qvals_sum, activation_zero); + + if constexpr (has_weight_zeros) { + // Compute term2 and term3 + + int32_t activation_qvals_sum = *((int32_t*)activation_ptr); + activation_ptr += sizeof(int32_t); + + int32x4_t weight_zeros = vld1q_s32((int32_t*)weight_data_byte_ptr); + weight_data_byte_ptr += 16; + + int32x4_t term2_0123 = + vmulq_n_s32(weight_zeros, activation_qvals_sum); + + int32x4_t term3_0123 = + vmulq_n_s32(weight_zeros, group_size * activation_zero); + + weight_zeros = vld1q_s32((int32_t*)weight_data_byte_ptr); + weight_data_byte_ptr += 16; + + int32x4_t term2_4567 = + vmulq_n_s32(weight_zeros, activation_qvals_sum); + + int32x4_t term3_4567 = + vmulq_n_s32(weight_zeros, group_size * activation_zero); + + // Do updates + int32x4_t tmp = vsubq_s32(qval_dot_0123, term1_0123); + tmp = vsubq_s32(tmp, term2_0123); + tmp = vaddq_s32(tmp, term3_0123); + res_0123 = vmlaq_f32(res_0123, scale_factor_0123, vcvtq_f32_s32(tmp)); + + tmp = vsubq_s32(qval_dot_4567, term1_4567); + tmp = vsubq_s32(tmp, term2_4567); + tmp = vaddq_s32(tmp, term3_4567); + res_4567 = vmlaq_f32(res_4567, scale_factor_4567, vcvtq_f32_s32(tmp)); + } else { + // Do updates + int32x4_t tmp = vsubq_s32(qval_dot_0123, term1_0123); + res_0123 = vmlaq_f32(res_0123, scale_factor_0123, vcvtq_f32_s32(tmp)); + + tmp = vsubq_s32(qval_dot_4567, term1_4567); + res_4567 = vmlaq_f32(res_4567, scale_factor_4567, vcvtq_f32_s32(tmp)); + } + + } // k_idx + if constexpr (has_bias) { + float32x4_t vec_bias = vdupq_n_f32(bias[m_idx]); + res_0123 = vaddq_f32(res_0123, vec_bias); + res_4567 = vaddq_f32(res_4567, vec_bias); + } + if constexpr (has_clamp) { + float32x4_t vec_min = vdupq_n_f32(clamp_min); + float32x4_t vec_max = vdupq_n_f32(clamp_max); + res_0123 = vec_clamp(res_0123, vec_min, vec_max); + res_4567 = vec_clamp(res_4567, 
vec_min, vec_max);
+      }
+      vst1q_f32(output + m_idx * output_m_stride + n_idx, res_0123);
+      vst1q_f32(output + m_idx * output_m_stride + n_idx + 4, res_4567);
+    } // n_idx
+    activation_data_byte_ptr += (activation_ptr - activation_data_byte_ptr);
+  } // m_idx
+}
+
+// Prepares weight data for kernel_impl.
+
+// Returns number of bytes required for weight_data
+int inline weight_data_size_impl(
+    int n,
+    int k,
+    int group_size,
+    int weight_nbit,
+    bool has_weight_zeros) {
+  assert(k % group_size == 0);
+  int groups_per_col = k / group_size;
+  int col_size = 0;
+
+  // qvals
+  col_size += (k / 8) * weight_nbit;
+
+  // scales
+  col_size += sizeof(float) * groups_per_col;
+
+  // qvals_sum
+  col_size += sizeof(int32_t) * groups_per_col;
+
+  // zeros
+  if (has_weight_zeros) {
+    col_size += sizeof(int32_t) * groups_per_col;
+  }
+
+  // Replace n with next multiple of 8 >= n
+  n = ((n + 7) >> 3) << 3;
+
+  return col_size * n;
+}
+
+template
+void prepare_weight_data_impl(
+    // Output
+    void* weight_data,
+    // Inputs
+    int n,
+    int k,
+    int group_size,
+    const int8_t* weight_qvals,
+    const float* weight_scales,
+    // Ignored if has_weight_zeros = false
+    const int8_t* weight_zeros) {
+  assert(k % group_size == 0);
+  assert(group_size % 16 == 0);
+  int groups_per_k = k / group_size;
+  constexpr int bytes_per_128_weight_values = 16 * weight_nbit;
+
+  auto weight_data_byte_ptr = (char*)weight_data;
+  const int8_t* qvals_ptr = weight_qvals;
+  const float* scales_ptr = weight_scales;
+  const int8_t* zeros_ptr = weight_zeros;
+
+  int8_t interleaved_buffer[128];
+  int8_t buffer[128];
+
+  for (int n_idx = 0; n_idx < n; n_idx += 8) {
+    for (int k_idx = 0; k_idx < k; k_idx += group_size) {
+      // Loop over group in chunks of 16, processing 8 columns at a time
+      int qvals_sum[8] = {0, 0, 0, 0, 0, 0, 0, 0};
+      for (int i = 0; i < group_size; i += 16) {
+        std::memset(buffer, 0, 128);
+        // Loop over 8 cols
+#pragma unroll(8)
+        for (int j = 0; j < 8; j++) {
+          if (n_idx + j < n) {
+            // If qvals_ptr are pre-packed in a naive way, this is where
+            // unpacking can occur
+            std::memcpy(buffer + 16 * j, qvals_ptr + k * j, 16);
+            qvals_sum[j] +=
+                torchao::kernels::cpu::aarch64::reduction::compute_sum(
+                    buffer + 16 * j, 16);
+          }
+        }
+        torchao::kernels::cpu::valpacking::interleave_data(
+            /*data_interleaved=*/interleaved_buffer,
+            /*data=*/buffer,
+            /*bytes_per_val=*/1,
+            /*vals_per_channel=*/16,
+            /*vals_per_group=*/16,
+            /*vals_per_chunk=*/8,
+            /*channels=*/8,
+            /*channel_stride_in_vals=*/16);
+        torchao::bitpacking::vec_pack_128_lowbit_values(
+            (uint8_t*)weight_data_byte_ptr,
+            vld1q_s8(interleaved_buffer),
+            vld1q_s8(interleaved_buffer + 16),
+            vld1q_s8(interleaved_buffer + 32),
+            vld1q_s8(interleaved_buffer + 48),
+            vld1q_s8(interleaved_buffer + 64),
+            vld1q_s8(interleaved_buffer + 80),
+            vld1q_s8(interleaved_buffer + 96),
+            vld1q_s8(interleaved_buffer + 112));
+        qvals_ptr += 16;
+        weight_data_byte_ptr += bytes_per_128_weight_values;
+      } // loop over group
+
+      // Store weight scales
+#pragma unroll(8)
+      for (int j = 0; j < 8; j++) {
+        float32_t scale = 0.0;
+        if (n_idx + j < n) {
+          scale = *(scales_ptr + j * groups_per_k);
+        }
+        *((float*)weight_data_byte_ptr) = scale;
+        weight_data_byte_ptr += sizeof(float);
+      }
+      scales_ptr += 1;
+
+      // Store weight qvals_sum
+#pragma unroll(8)
+      for (int j = 0; j < 8; j++) {
+        *((int*)weight_data_byte_ptr) = qvals_sum[j];
+        weight_data_byte_ptr += sizeof(int);
+      }
+
+      // Store weight zeros
+      // TODO: test storing these as int16_t, which reduces
+      // a load in the kernel (load 8
int16_ts in kernel, instead of 2 + // load 4 int32_ts), but adds 2 moves (int16 to int32). + if constexpr (has_weight_zeros) { +#pragma unroll(8) + for (int j = 0; j < 8; j++) { + int32_t zero = 0; + if (n_idx + j < n) { + zero = (int)(*(zeros_ptr + j * groups_per_k)); + } + *((int32_t*)weight_data_byte_ptr) = zero; + weight_data_byte_ptr += sizeof(int32_t); + } + zeros_ptr += 1; + } + } // k_idx + + // In the previous loop over k, we processed 8 columns at a time, + // but only advanced our pointers over the first column. + // So we advance over the other 7 columns here. + qvals_ptr += 7 * k; + scales_ptr += 7 * groups_per_k; + if constexpr (has_weight_zeros) { + zeros_ptr += 7 * groups_per_k; + } + } // n_idx +} + +} // namespace + // channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot::internal +} // namespace torchao::kernels::cpu::aarch64::linear + +// Activation functions +template +int torchao::kernels::cpu::aarch64::linear:: + channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot:: + activation_data_size(int m, int k, int group_size) { + return torchao::kernels::cpu::aarch64::linear:: + channelwise_8bit_activation_prepare_activation_data_1xk_f32::internal:: + activation_data_size_impl(m, k, group_size, has_weight_zeros); +} + +template +void torchao::kernels::cpu::aarch64::linear:: + channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot:: + prepare_activation_data( + void* activation_data, + // Inputs + int m, + int k, + // Ignored if has_weight_zeros = false + int group_size, + const float* activations) { + torchao::kernels::cpu::aarch64::linear:: + channelwise_8bit_activation_prepare_activation_data_1xk_f32::internal:: + prepare_activation_data_impl( + activation_data, m, k, group_size, activations); +} + +// Weight functions +template +int torchao::kernels::cpu::aarch64::linear:: + channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot:: + weight_data_size(int n, int k, int group_size) { + return torchao::kernels::cpu::aarch64::linear:: + channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot:: + internal::weight_data_size_impl( + n, k, group_size, weight_nbit, has_weight_zeros); +} + +template +void torchao::kernels::cpu::aarch64::linear:: + channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot:: + prepare_weight_data( + void* weight_data, + // Inputs + int n, + int k, + int group_size, + const int8_t* weight_qvals, + const float* weight_scales, + const int8_t* weight_zeros) { + torchao::kernels::cpu::aarch64::linear:: + channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot:: + internal::prepare_weight_data_impl( + weight_data, + n, + k, + group_size, + weight_qvals, + weight_scales, + weight_zeros); +} + +template +void torchao::kernels::cpu::aarch64::linear:: + channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot:: + kernel( + // Outputs + float32_t* output, + // Inputs + int output_m_stride, + int m, + int n, + int k, + int group_size, + const void* weight_data, + const void* activation_data, + // Not applied if nullptr + const float* bias, + // Ignored if has_clamp = false + float clamp_min, + float clamp_max) { + torchao::kernels::cpu::aarch64::linear:: + channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot:: + internal:: + kernel_impl( + output, + output_m_stride, + m, + n, + k, + group_size, + weight_data, + activation_data, + bias, + clamp_min, + clamp_max); +} diff --git 
a/torchao/experimental/kernels/cpu/aarch64/linear/channelwise_8bit_activation_prepare_activation_data_1xk_f32-impl.h b/torchao/experimental/kernels/cpu/aarch64/linear/channelwise_8bit_activation_prepare_activation_data_1xk_f32-impl.h new file mode 100644 index 0000000000..7c2ba7d070 --- /dev/null +++ b/torchao/experimental/kernels/cpu/aarch64/linear/channelwise_8bit_activation_prepare_activation_data_1xk_f32-impl.h @@ -0,0 +1,117 @@ +// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + +#pragma once +#include +#include +#include + +namespace torchao::kernels::cpu::aarch64::linear { +namespace channelwise_8bit_activation_prepare_activation_data_1xk_f32:: + internal { + +// Prepares activation data for kernel_impl. +// Per m_idx (row), activations are stored as follows: +// scale (float), zero (int8_t), +// group0_qvals (int8_t[group_size]), [group0_qvals_sum (int32_t)]? +// group1_qvals (int8_t[group_size]), [group1_qvals_sum (int32_t)]? +// ... +// The groupi_qvals_sum is only present if has_weight_zeros = true. + +// Returns number of bytes required for activation_data +int inline activation_data_size_impl( + int m, + int k, + // Ignored if has_weight_zeros = false + int group_size, + bool has_weight_zeros) { + int row_size = 0; + + // scale + row_size += sizeof(float); + + // zero + row_size += sizeof(int8_t); + + // qvals + row_size += sizeof(int8_t) * k; + + // qvals_sum + if (has_weight_zeros) { + assert(k % group_size == 0); + int groups_per_row = k / group_size; + row_size += sizeof(int32_t) * groups_per_row; + } + + return row_size * m; +} + +template +void prepare_activation_data_impl( + // Output + void* activation_data, + // Inputs + int m, + int k, + // Ignored if has_weight_zeros = false + int group_size, + const float* activations) { + auto activation_data_byte_ptr = (char*)activation_data; + + float vmin, vmax, scale; + int qmin, qmax, zero, qvals_sum; + torchao::quantization::get_qvals_range( + qmin, qmax, /*nbit=*/8, /*is_symmetric=*/false); + + for (int m_idx = 0; m_idx < m; m_idx++) { + torchao::kernels::cpu::aarch64::reduction::find_min_and_max( + vmin, vmax, activations, k); + torchao::quantization::get_scale_and_zero( + scale, zero, vmin, vmax, qmin, qmax); + + // Save scale and zero + *(float32_t*)activation_data_byte_ptr = scale; + activation_data_byte_ptr += sizeof(float32_t); + + *(int8_t*)activation_data_byte_ptr = (int8_t)zero; + activation_data_byte_ptr += sizeof(int8_t); + + if constexpr (has_weight_zeros) { + for (int k_idx = 0; k_idx < k; k_idx += group_size) { + torchao::kernels::cpu::aarch64::quantization::quantize( + /*qvals=*/(int8_t*)activation_data_byte_ptr, + /*vals=*/activations, + /*size=*/group_size, + /*scale=*/scale, + /*zero=*/zero, + /*qmin=*/qmin, + /*qmax=*/qmax); + + qvals_sum = torchao::kernels::cpu::aarch64::reduction::compute_sum( + /*vals=*/(int8_t*)activation_data_byte_ptr, + /*size=*/group_size); + + activation_data_byte_ptr += group_size; + + *(int32_t*)activation_data_byte_ptr = qvals_sum; + activation_data_byte_ptr += sizeof(int32_t); + + activations += group_size; + } + } else { + torchao::kernels::cpu::aarch64::quantization::quantize( + /*qvals=*/(int8_t*)activation_data_byte_ptr, + /*vals=*/activations, + /*size=*/k, + /*scale=*/scale, + /*zero=*/zero, + /*qmin=*/qmin, + /*qmax=*/qmax); + activation_data_byte_ptr += k; + activations += k; + } + } +} + +} // namespace + // channelwise_8bit_activation_prepare_activation_data_1xk_f32::internal +} // namespace torchao::kernels::cpu::aarch64::linear diff --git 
a/torchao/experimental/kernels/cpu/aarch64/linear/linear.h b/torchao/experimental/kernels/cpu/aarch64/linear/linear.h new file mode 100644 index 0000000000..2607a2371a --- /dev/null +++ b/torchao/experimental/kernels/cpu/aarch64/linear/linear.h @@ -0,0 +1,162 @@ +// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + +#pragma once +#include + +namespace torchao::kernels::cpu::aarch64::linear { + +namespace channelwise_8bit_activation_groupwise_lowbit_weight_1x1x32_f32_neondot { + +template +int activation_data_size(int m, int k, int group_size); + +template +void prepare_activation_data( + void* activation_data, + // Inputs + int m, + int k, + // Ignored if has_weight_zeros = false + int group_size, + const float* activations); + +template +int weight_data_size(int n, int k, int group_size); + +template +void prepare_weight_data( + void* weight_data, + // Inputs + int n, + int k, + int group_size, + const int8_t* weight_qvals, + const float* weight_scales, + const int8_t* weight_zeros); + +template +void kernel( + // Outputs + float32_t* output, + // Inputs + int output_m_stride, + int m, + int n, + int k, + int group_size, + const void* weight_data, + const void* activation_data, + // Not applied if nullptr + const float* bias, + // Ignored if has_clamp = false + float clamp_min, + float clamp_max); + +} // namespace + // channelwise_8bit_activation_groupwise_lowbit_weight_1x1x32_f32_neondot + +namespace channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot { + +template +int activation_data_size(int m, int k, int group_size); + +template +void prepare_activation_data( + void* activation_data, + // Inputs + int m, + int k, + // Ignored if has_weight_zeros = false + int group_size, + const float* activations); + +template +int weight_data_size(int n, int k, int group_size); + +template +void prepare_weight_data( + void* weight_data, + // Inputs + int n, + int k, + int group_size, + const int8_t* weight_qvals, + const float* weight_scales, + const int8_t* weight_zeros); + +template +void kernel( + // Outputs + float32_t* output, + // Inputs + int output_m_stride, + int m, + int n, + int k, + int group_size, + const void* weight_data, + const void* activation_data, + // Not applied if nullptr + const float* bias, + // Ignored if has_clamp = false + float clamp_min, + float clamp_max); + +} // namespace + // channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot + +namespace channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot { + +template +int activation_data_size(int m, int k, int group_size); + +template +void prepare_activation_data( + void* activation_data, + // Inputs + int m, + int k, + // Ignored if has_weight_zeros = false + int group_size, + const float* activations); + +template +int weight_data_size(int n, int k, int group_size); + +template +void prepare_weight_data( + void* weight_data, + // Inputs + int n, + int k, + int group_size, + const int8_t* weight_qvals, + const float* weight_scales, + const int8_t* weight_zeros); + +template +void kernel( + // Outputs + float32_t* output, + // Inputs + int output_m_stride, + int m, + int n, + int k, + int group_size, + const void* weight_data, + const void* activation_data, + // Not applied if nullptr + const float* bias, + // Ignored if has_clamp = false + float clamp_min, + float clamp_max); + +} // namespace + // channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot + +} // namespace torchao::kernels::cpu::aarch64::linear + +#include +#include 
+#include diff --git a/torchao/experimental/kernels/cpu/aarch64/quantization/quantize.cpp b/torchao/experimental/kernels/cpu/aarch64/quantization/quantize.cpp new file mode 100644 index 0000000000..3aed6d0192 --- /dev/null +++ b/torchao/experimental/kernels/cpu/aarch64/quantization/quantize.cpp @@ -0,0 +1,109 @@ +// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + +#include +#include +#include +#include + +void torchao::quantization::get_qvals_range( + int& qmin, + int& qmax, + int nbit, + bool is_symmetric) { + if (is_symmetric) { + qmin = -(1 << (nbit - 1)) + 1; + qmax = -qmin; + } else { + qmin = -(1 << (nbit - 1)); + qmax = (1 << (nbit - 1)) - 1; + } +} + +float torchao::quantization::get_scale( + float vmin, + float vmax, + int qmin, + int qmax) { + assert(qmin < qmax); + assert(vmin < vmax); + return (vmax - vmin) / (qmax - qmin); +} + +void torchao::quantization::get_scale_and_zero( + float& scale, + int& zero, + float vmin, + float vmax, + int qmin, + int qmax) { + scale = torchao::quantization::get_scale(vmin, vmax, qmin, qmax); + zero = qmin - std::round(vmin / scale); +} + +namespace { +inline void +_vec_clip_inplace(int32x4_t& vec, int32x4_t vec_min, int32x4_t vec_max) { + vec = vmaxq_s32(vec, vec_min); + vec = vminq_s32(vec, vec_max); +} +} // namespace + +void torchao::kernels::cpu::aarch64::quantization::quantize( + // Output + int8_t* qvals, + // Inputs + const float32_t* vals, + int size, + float32_t scale, + int8_t zero, + int8_t qmin, + int8_t qmax) { + assert(size % 8 == 0); + + float32_t invScale = 1.0 / (scale + 1e-16); + float32x4_t vec_zero = vdupq_n_f32(zero); + float32x4_t vec_invScale = vdupq_n_f32(invScale); + int32x4_t vec_qmin = vdupq_n_s32(qmin); + int32x4_t vec_qmax = vdupq_n_s32(qmax); + + float32x4_t vec_val; + float32x4_t vec_qval_f32; + int32x4_t vec_qval_s32; + int16x4_t vec_qval_s16_0; + int16x4_t vec_qval_s16_1; + + for (int i = 0; i < size; i += 8) { + ////////////////////////////////////// + // Quantize first 4 element chunk to int16 + ////////////////////////////////////// + vec_val = vld1q_f32(vals + i); + + // Quantize and round + vec_qval_f32 = vfmaq_f32(vec_zero, vec_val, vec_invScale); + vec_qval_s32 = vcvtnq_s32_f32(vec_qval_f32); + + _vec_clip_inplace(vec_qval_s32, vec_qmin, vec_qmax); + + vec_qval_s16_0 = vqmovn_s32(vec_qval_s32); + + ////////////////////////////////////// + // Quantize second 4 element chunk to int16 + ////////////////////////////////////// + vec_val = vld1q_f32(vals + i + 4); + + // Quantize and round + vec_qval_f32 = vfmaq_f32(vec_zero, vec_val, vec_invScale); + vec_qval_s32 = vcvtnq_s32_f32(vec_qval_f32); + + _vec_clip_inplace(vec_qval_s32, vec_qmin, vec_qmax); + + vec_qval_s16_1 = vqmovn_s32(vec_qval_s32); + + ////////////////////////////////////// + // Store 8 quantized elements + ////////////////////////////////////// + int16x8_t vec_qval_s16_01 = vcombine_s16(vec_qval_s16_0, vec_qval_s16_1); + int8x8_t vec_qval_s8_01 = vqmovn_s16(vec_qval_s16_01); + vst1_s8(qvals + i, vec_qval_s8_01); + } +} diff --git a/torchao/experimental/kernels/cpu/aarch64/quantization/quantize.h b/torchao/experimental/kernels/cpu/aarch64/quantization/quantize.h new file mode 100644 index 0000000000..af49836596 --- /dev/null +++ b/torchao/experimental/kernels/cpu/aarch64/quantization/quantize.h @@ -0,0 +1,51 @@ +// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. 
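+
+// Quantization convention used by these helpers:
+//   dequantized val = scale * (qval - zero),
+// with scale = (vmax - vmin) / (qmax - qmin) and
+// zero = qmin - round(vmin / scale).
+// Worked example (8-bit asymmetric): vmin = 0.0, vmax = 2.55 gives
+// qmin = -128, qmax = 127, scale = 0.01, zero = -128. The value 1.0
+// quantizes to round(zero + 1.0 / scale) = -28, which dequantizes back to
+// 0.01 * (-28 - (-128)) = 1.0.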
+
+#pragma once
+#include
+
+// These methods are here temporarily
+// Eventually they will be moved to a non-arch specific location
+// or replaced by existing PyTorch functions
+// The quantize method in aarch64 namespace will remain here;
+// it is used for dynamic activation quantization
+namespace torchao {
+namespace quantization {
+
+void get_qvals_range(int& qmin, int& qmax, int nbit, bool is_symmetric);
+
+// val = scale * qval
+float get_scale(float vmin, float vmax, int qmin, int qmax);
+
+// val = scale * (qval - zero)
+void get_scale_and_zero(
+    float& scale,
+    int& zero,
+    float vmin,
+    float vmax,
+    int qmin,
+    int qmax);
+
+} // namespace quantization
+} // namespace torchao
+
+namespace torchao {
+namespace kernels {
+namespace cpu {
+namespace aarch64 {
+namespace quantization {
+void quantize(
+    // Output
+    int8_t* qvals,
+    // Inputs
+    const float32_t* vals,
+    int size,
+    float32_t scale,
+    int8_t zero,
+    int8_t qmin,
+    int8_t qmax);
+
+} // namespace quantization
+} // namespace aarch64
+} // namespace cpu
+} // namespace kernels
+} // namespace torchao
diff --git a/torchao/experimental/kernels/cpu/aarch64/reduction/compute_sum.cpp b/torchao/experimental/kernels/cpu/aarch64/reduction/compute_sum.cpp
new file mode 100644
index 0000000000..ab1f26180d
--- /dev/null
+++ b/torchao/experimental/kernels/cpu/aarch64/reduction/compute_sum.cpp
@@ -0,0 +1,20 @@
+// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
+
+#include
+
+int32_t torchao::kernels::cpu::aarch64::reduction::compute_sum(
+    const int8_t* vals,
+    int size) {
+  int32_t res = 0;
+  int i = 0;
+
+  // Only take full 16-element chunks in the vectorized loop; the scalar
+  // loop below handles any tail without reading past the end of vals.
+#pragma unroll(4)
+  for (; i + 16 <= size; i += 16) {
+    int8x16_t vec_vals = vld1q_s8(vals + i);
+    res += (int)(vaddlvq_s8(vec_vals));
+  }
+  for (; i < size; i += 1) {
+    res += vals[i];
+  }
+  return res;
+}
diff --git a/torchao/experimental/kernels/cpu/aarch64/reduction/find_min_and_max.cpp b/torchao/experimental/kernels/cpu/aarch64/reduction/find_min_and_max.cpp
new file mode 100644
index 0000000000..ed7ca01bb4
--- /dev/null
+++ b/torchao/experimental/kernels/cpu/aarch64/reduction/find_min_and_max.cpp
@@ -0,0 +1,32 @@
+// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
+
+#include
+
+void torchao::kernels::cpu::aarch64::reduction::find_min_and_max(
+    float32_t& min,
+    float32_t& max,
+    const float32_t* vals,
+    int size) {
+  // Seed the reduction with the first element so the result is correct
+  // even when the vectorized loop below does not run (size < 8).
+  min = vals[0];
+  max = vals[0];
+
+  float32x4_t mins = vdupq_n_f32(min);
+  float32x4_t maxes = vdupq_n_f32(max);
+  int i = 0;
+  for (; i + 8 <= size; i += 8) {
+    float32x4_t v1 = vld1q_f32(vals + i);
+    float32x4_t v2 = vld1q_f32(vals + i + 4);
+    mins = vminq_f32(mins, vminq_f32(v1, v2));
+    maxes = vmaxq_f32(maxes, vmaxq_f32(v1, v2));
+  }
+  min = vminvq_f32(mins);
+  max = vmaxvq_f32(maxes);
+
+  // Remainder
+  while (i < size) {
+    if (vals[i] < min) {
+      min = vals[i];
+    }
+    if (vals[i] > max) {
+      max = vals[i];
+    }
+    i += 1;
+  }
+}
diff --git a/torchao/experimental/kernels/cpu/aarch64/reduction/reduction.h b/torchao/experimental/kernels/cpu/aarch64/reduction/reduction.h
new file mode 100644
index 0000000000..25110f4f34
--- /dev/null
+++ b/torchao/experimental/kernels/cpu/aarch64/reduction/reduction.h
@@ -0,0 +1,24 @@
+// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
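+
+// find_min_and_max computes the range of a float buffer; it feeds
+// get_scale_and_zero when activations are dynamically quantized.
+// compute_sum returns the int32 sum of a buffer of int8 values; it supplies
+// the per-group qvals_sum terms used in the zero-point corrections of the
+// linear kernels.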
+ +#pragma once +#include +#include + +namespace torchao { +namespace kernels { +namespace cpu { +namespace aarch64 { +namespace reduction { +void find_min_and_max( + float32_t& min, + float32_t& max, + const float32_t* vals, + int size); + +int32_t compute_sum(const int8_t* vals, int size); + +} // namespace reduction +} // namespace aarch64 +} // namespace cpu +} // namespace kernels +} // namespace torchao diff --git a/torchao/experimental/kernels/cpu/aarch64/tests/CMakeLists.txt b/torchao/experimental/kernels/cpu/aarch64/tests/CMakeLists.txt new file mode 100644 index 0000000000..1b78f25b9c --- /dev/null +++ b/torchao/experimental/kernels/cpu/aarch64/tests/CMakeLists.txt @@ -0,0 +1,66 @@ +# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + +cmake_minimum_required(VERSION 3.19) +project(tests) +set(CMAKE_CXX_STANDARD 17) + +include(FetchContent) +FetchContent_Declare( + googletest + URL https://github.com/google/googletest/archive/03597a01ee50ed33e9dfd640b249b4be3799d395.zip +) +FetchContent_MakeAvailable(googletest) + +add_compile_options("-Wall" "-Werror") + +include(CMakePrintHelpers) +message("TORCHAO_LIBRARIES: ${TORCHAO_LIBRARIES}") +include_directories(${TORCHAO_LIBRARIES}) + +add_library( + dep + ${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64/reduction/find_min_and_max.cpp + ${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64/reduction/compute_sum.cpp + ${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64/quantization/quantize.cpp + ${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64/valpacking/interleave.cpp +) + +enable_testing() + +add_executable(test_quantization test_quantization.cpp) +target_link_libraries( + test_quantization + PRIVATE + GTest::gtest_main + dep +) + +add_executable(test_bitpacking test_bitpacking.cpp) +target_link_libraries( + test_bitpacking + PRIVATE + GTest::gtest_main + dep +) + +add_executable(test_linear test_linear.cpp) +target_link_libraries( + test_linear + PRIVATE + GTest::gtest_main + dep +) + +add_executable(test_valpacking test_valpacking.cpp) +target_link_libraries( + test_valpacking + PRIVATE + GTest::gtest_main + dep +) + +include(GoogleTest) +gtest_discover_tests(test_quantization) +gtest_discover_tests(test_bitpacking) +gtest_discover_tests(test_linear) +gtest_discover_tests(test_valpacking) diff --git a/torchao/experimental/kernels/cpu/aarch64/tests/test_bitpacking.cpp b/torchao/experimental/kernels/cpu/aarch64/tests/test_bitpacking.cpp new file mode 100644 index 0000000000..9e530da8e5 --- /dev/null +++ b/torchao/experimental/kernels/cpu/aarch64/tests/test_bitpacking.cpp @@ -0,0 +1,353 @@ +// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. 
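+
+// These tests check that bit-packing followed by unpacking is the identity:
+// first for the unsigned uint3/uint4 building blocks, then for the signed
+// "lowbit" wrappers, whose inputs lie in
+// [-(1 << (nbit - 1)), (1 << (nbit - 1)) - 1].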
+ +#include +#include +#include +#include +#include +#include +#include + +TEST(test_bitpacking_8_uint3_values, PackUnpackAreSame) { + int unpacked_bytes = 8; + int packed_bytes = 3; + auto input = torchao::get_random_lowbit_vector(unpacked_bytes, 3); + std::vector packed(packed_bytes, 0); + std::vector unpacked(unpacked_bytes, 0); + + torchao::bitpacking::internal::pack_8_uint3_values( + packed.data(), input.data()); + torchao::bitpacking::internal::unpack_8_uint3_values( + unpacked.data(), packed.data()); + for (int i = 0; i < unpacked_bytes; ++i) { + EXPECT_EQ(input[i], unpacked[i]); + } +} + +TEST(test_bitpacking_64_uint3_values, PackUnpackAreSame) { + int unpacked_bytes = 64; + int packed_bytes = 24; + auto input = torchao::get_random_lowbit_vector(unpacked_bytes, 3); + std::vector packed(packed_bytes, 0); + + uint8x16_t input0; + uint8x16_t input1; + uint8x16_t input2; + uint8x16_t input3; + + uint8x16_t unpacked0; + uint8x16_t unpacked1; + uint8x16_t unpacked2; + uint8x16_t unpacked3; + + torchao::bitpacking::internal::vec_load_64_uint8_values( + input0, input1, input2, input3, input.data()); + torchao::bitpacking::internal::vec_pack_64_uint3_values( + packed.data(), input0, input1, input2, input3); + torchao::bitpacking::internal::vec_unpack_64_uint3_values( + unpacked0, unpacked1, unpacked2, unpacked3, packed.data()); + + for (int i = 0; i < 16; ++i) { + EXPECT_EQ(input0[i], unpacked0[i]); + EXPECT_EQ(input1[i], unpacked1[i]); + EXPECT_EQ(input2[i], unpacked2[i]); + EXPECT_EQ(input3[i], unpacked3[i]); + } +} + +TEST(test_bitpacking_128_uint3_values, PackUnpackAreSame) { + int unpacked_bytes = 128; + int packed_bytes = 48; + auto input = torchao::get_random_lowbit_vector(unpacked_bytes, 3); + std::vector packed(packed_bytes, 0); + + uint8x16_t input0; + uint8x16_t input1; + uint8x16_t input2; + uint8x16_t input3; + uint8x16_t input4; + uint8x16_t input5; + uint8x16_t input6; + uint8x16_t input7; + + uint8x16_t unpacked0; + uint8x16_t unpacked1; + uint8x16_t unpacked2; + uint8x16_t unpacked3; + uint8x16_t unpacked4; + uint8x16_t unpacked5; + uint8x16_t unpacked6; + uint8x16_t unpacked7; + + torchao::bitpacking::internal::vec_load_64_uint8_values( + input0, input1, input2, input3, input.data()); + torchao::bitpacking::internal::vec_load_64_uint8_values( + input4, input5, input6, input7, input.data() + 64); + torchao::bitpacking::internal::vec_pack_128_uint3_values( + packed.data(), + input0, + input1, + input2, + input3, + input4, + input5, + input6, + input7); + torchao::bitpacking::internal::vec_unpack_128_uint3_values( + unpacked0, + unpacked1, + unpacked2, + unpacked3, + unpacked4, + unpacked5, + unpacked6, + unpacked7, + packed.data()); + + for (int i = 0; i < 16; ++i) { + EXPECT_EQ(input0[i], unpacked0[i]); + EXPECT_EQ(input1[i], unpacked1[i]); + EXPECT_EQ(input2[i], unpacked2[i]); + EXPECT_EQ(input3[i], unpacked3[i]); + EXPECT_EQ(input4[i], unpacked4[i]); + EXPECT_EQ(input5[i], unpacked5[i]); + EXPECT_EQ(input6[i], unpacked6[i]); + EXPECT_EQ(input7[i], unpacked7[i]); + } +} + +TEST(test_bitpacking_2_uint4_values, PackUnpackAreSame) { + int unpacked_bytes = 2; + int nbit = 4; + int packed_bytes = unpacked_bytes * nbit / 8; + + auto input = torchao::get_random_lowbit_vector(unpacked_bytes, nbit); + std::vector packed(packed_bytes, 0); + std::vector unpacked(unpacked_bytes, 0); + + torchao::bitpacking::internal::pack_2_uint4_values( + packed.data(), input.data()); + torchao::bitpacking::internal::unpack_2_uint4_values( + unpacked.data(), packed.data()); + for (int i = 0; i < 
unpacked_bytes; ++i) { + EXPECT_EQ(input[i], unpacked[i]); + } +} + +TEST(test_bitpacking_16_uint4_values, PackUnpackAreSame) { + int unpacked_bytes = 16; + int nbit = 4; + int packed_bytes = unpacked_bytes * nbit / 8; + auto input = torchao::get_random_lowbit_vector(unpacked_bytes, nbit); + std::vector packed(packed_bytes, 0); + + uint8x16_t input0; + uint8x16_t unpacked0; + + input0 = vld1q_u8(input.data()); + torchao::bitpacking::internal::vec_pack_16_uint4_values( + packed.data(), input0); + torchao::bitpacking::internal::vec_unpack_16_uint4_values( + unpacked0, packed.data()); + + for (int i = 0; i < 16; ++i) { + EXPECT_EQ(input0[i], unpacked0[i]); + } +} + +TEST(test_bitpacking_32_uint4_values, PackUnpackAreSame) { + int unpacked_bytes = 32; + int nbit = 2; + int packed_bytes = unpacked_bytes * nbit / 8; + auto input = torchao::get_random_lowbit_vector(unpacked_bytes, nbit); + std::vector packed(packed_bytes, 0); + + uint8x16_t input0; + uint8x16_t input1; + uint8x16_t unpacked0; + uint8x16_t unpacked1; + + input0 = vld1q_u8(input.data()); + input1 = vld1q_u8(input.data() + 16); + torchao::bitpacking::internal::vec_pack_32_uint4_values( + packed.data(), input0, input1); + torchao::bitpacking::internal::vec_unpack_32_uint4_values( + unpacked0, unpacked1, packed.data()); + + for (int i = 0; i < 16; ++i) { + EXPECT_EQ(input0[i], unpacked0[i]); + EXPECT_EQ(input1[i], unpacked1[i]); + } +} + +// Universal bitpacking tests +template +void test_bitpacking_32_lowbit_values() { + int unpacked_bytes = 32; + int packed_bytes = unpacked_bytes * nbit / 8; + auto input_shifted = + torchao::get_random_lowbit_vector(unpacked_bytes, nbit); + std::vector input(unpacked_bytes, 0); + int8_t low = -(1 << (nbit - 1)); + int8_t high = (1 << (nbit - 1)); + for (int i = 0; i < unpacked_bytes; ++i) { + input[i] = (int8_t)(input_shifted[i]) + low; + assert(input[i] >= low); + assert(input[i] <= high); + } + std::vector packed(packed_bytes, 0); + + int8x16_t input0; + int8x16_t input1; + int8x16_t unpacked0; + int8x16_t unpacked1; + input0 = vld1q_s8(input.data()); + input1 = vld1q_s8(input.data() + 16); + torchao::bitpacking::vec_pack_32_lowbit_values( + packed.data(), input0, input1); + torchao::bitpacking::vec_unpack_32_lowbit_values( + unpacked0, unpacked1, packed.data()); + + for (int i = 0; i < 16; ++i) { + EXPECT_EQ(input0[i], unpacked0[i]); + EXPECT_EQ(input1[i], unpacked1[i]); + } +} + +template +void test_bitpacking_64_lowbit_values() { + int unpacked_bytes = 64; + int packed_bytes = unpacked_bytes * nbit / 8; + auto input_shifted = + torchao::get_random_lowbit_vector(unpacked_bytes, nbit); + std::vector input(unpacked_bytes, 0); + int8_t low = -(1 << (nbit - 1)); + int8_t high = (1 << (nbit - 1)); + for (int i = 0; i < unpacked_bytes; ++i) { + input[i] = (int8_t)(input_shifted[i]) + low; + assert(input[i] >= low); + assert(input[i] <= high); + } + std::vector packed(packed_bytes, 0); + + int8x16_t input0; + int8x16_t input1; + int8x16_t input2; + int8x16_t input3; + int8x16_t unpacked0; + int8x16_t unpacked1; + int8x16_t unpacked2; + int8x16_t unpacked3; + input0 = vld1q_s8(input.data()); + input1 = vld1q_s8(input.data() + 16); + input2 = vld1q_s8(input.data() + 32); + input3 = vld1q_s8(input.data() + 48); + torchao::bitpacking::vec_pack_64_lowbit_values( + packed.data(), input0, input1, input2, input3); + torchao::bitpacking::vec_unpack_64_lowbit_values( + unpacked0, unpacked1, unpacked2, unpacked3, packed.data()); + + for (int i = 0; i < 16; ++i) { + EXPECT_EQ(input0[i], unpacked0[i]); + 
EXPECT_EQ(input1[i], unpacked1[i]); + EXPECT_EQ(input2[i], unpacked2[i]); + EXPECT_EQ(input3[i], unpacked3[i]); + } +} + +template +void test_bitpacking_128_lowbit_values() { + int unpacked_bytes = 128; + int packed_bytes = unpacked_bytes * nbit / 8; + auto input_shifted = + torchao::get_random_lowbit_vector(unpacked_bytes, nbit); + std::vector input(unpacked_bytes, 0); + int8_t low = -(1 << (nbit - 1)); + int8_t high = (1 << (nbit - 1)); + for (int i = 0; i < unpacked_bytes; ++i) { + input[i] = (int8_t)(input_shifted[i]) + low; + assert(input[i] >= low); + assert(input[i] <= high); + } + std::vector packed(packed_bytes, 0); + + int8x16_t input0; + int8x16_t input1; + int8x16_t input2; + int8x16_t input3; + int8x16_t input4; + int8x16_t input5; + int8x16_t input6; + int8x16_t input7; + int8x16_t unpacked0; + int8x16_t unpacked1; + int8x16_t unpacked2; + int8x16_t unpacked3; + int8x16_t unpacked4; + int8x16_t unpacked5; + int8x16_t unpacked6; + int8x16_t unpacked7; + + input0 = vld1q_s8(input.data()); + input1 = vld1q_s8(input.data() + 16); + input2 = vld1q_s8(input.data() + 32); + input3 = vld1q_s8(input.data() + 48); + input4 = vld1q_s8(input.data() + 64); + input5 = vld1q_s8(input.data() + 80); + input6 = vld1q_s8(input.data() + 96); + input7 = vld1q_s8(input.data() + 112); + torchao::bitpacking::vec_pack_128_lowbit_values( + packed.data(), + input0, + input1, + input2, + input3, + input4, + input5, + input6, + input7); + torchao::bitpacking::vec_unpack_128_lowbit_values( + unpacked0, + unpacked1, + unpacked2, + unpacked3, + unpacked4, + unpacked5, + unpacked6, + unpacked7, + packed.data()); + for (int i = 0; i < 16; ++i) { + EXPECT_EQ(input0[i], unpacked0[i]); + EXPECT_EQ(input1[i], unpacked1[i]); + EXPECT_EQ(input2[i], unpacked2[i]); + EXPECT_EQ(input3[i], unpacked3[i]); + EXPECT_EQ(input4[i], unpacked4[i]); + EXPECT_EQ(input5[i], unpacked5[i]); + EXPECT_EQ(input6[i], unpacked6[i]); + EXPECT_EQ(input7[i], unpacked7[i]); + } +} + +#define TEST_BITPACKING_32_LOWBIT_VALUES(nbit) \ + TEST(test_bitpacking_32_lowbit_values_##nbit, PackUnpackAreSame) { \ + test_bitpacking_32_lowbit_values(); \ + } + +#define TEST_BITPACKING_64_LOWBIT_VALUES(nbit) \ + TEST(test_bitpacking_64_lowbit_values_##nbit, PackUnpackAreSame) { \ + test_bitpacking_64_lowbit_values(); \ + } + +#define TEST_BITPACKING_128_LOWBIT_VALUES(nbit) \ + TEST(test_bitpacking_128_lowbit_values_##nbit, PackUnpackAreSame) { \ + test_bitpacking_128_lowbit_values(); \ + } + +TEST_BITPACKING_32_LOWBIT_VALUES(3); +TEST_BITPACKING_32_LOWBIT_VALUES(4); + +TEST_BITPACKING_64_LOWBIT_VALUES(3); +TEST_BITPACKING_64_LOWBIT_VALUES(4); + +TEST_BITPACKING_128_LOWBIT_VALUES(3); +TEST_BITPACKING_128_LOWBIT_VALUES(4); diff --git a/torchao/experimental/kernels/cpu/aarch64/tests/test_linear.cpp b/torchao/experimental/kernels/cpu/aarch64/tests/test_linear.cpp new file mode 100644 index 0000000000..4b61c162e0 --- /dev/null +++ b/torchao/experimental/kernels/cpu/aarch64/tests/test_linear.cpp @@ -0,0 +1,370 @@ +// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. 
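+
+// Each test builds a random channelwise-8bit-activation, groupwise-lowbit-
+// weight test case, runs prepare_activation_data, prepare_weight_data, and
+// kernel for one tile-size variant, and compares the result against the test
+// case's expected output within kTol. The variants below toggle
+// has_weight_zeros, has_bias, and has_clamp one at a time.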
+ +#include +#include +#include +#include +#include +#include + +float kTol = 0.0001; + +template +void test_channelwise_8bit_activation_groupwise_lowbit_weight_1x1x32_f32_neondot() { + int m = 7; + int k = 128; + int n = 13; + int group_size = 32; + + auto test_case = torchao:: + channelwise_8bit_activation_groupwise_lowbit_weight_test_case::generate( + m, + k, + n, + group_size, + weight_nbit, + has_weight_zeros, + has_bias, + has_clamp); + + using namespace torchao::kernels::cpu::aarch64::linear:: + channelwise_8bit_activation_groupwise_lowbit_weight_1x1x32_f32_neondot; + + std::vector activation_data( + activation_data_size(m, k, group_size)); + prepare_activation_data( + (void*)activation_data.data(), + m, + k, + group_size, + test_case.activations.data()); + + std::vector weight_data( + weight_data_size(n, k, group_size)); + prepare_weight_data( + (void*)weight_data.data(), + n, + k, + group_size, + test_case.weight_qvals.data(), + test_case.weight_scales.data(), + /*weight_zeros=*/test_case.weight_zeros.data()); + + std::vector output(m * k); + kernel( + output.data(), + /*output_m_stride=*/n, + m, + n, + k, + group_size, + weight_data.data(), + activation_data.data(), + /*bias=*/test_case.bias.data(), + /*clamp_min=*/test_case.clamp_min, + /*clamp_max=*/test_case.clamp_max); + + for (int i = 0; i < m * n; i++) { + EXPECT_NEAR(output[i], test_case.expected_output[i], kTol); + } +} + +TEST( + test_channelwise_8bit_activation_groupwise_lowbit_weight_1x1x32_f32_neondot, + Standard) { + constexpr int weight_nbit = 4; + constexpr bool has_weight_zeros = false; + constexpr bool has_bias = false; + constexpr bool has_clamp = false; + + test_channelwise_8bit_activation_groupwise_lowbit_weight_1x1x32_f32_neondot< + weight_nbit, + has_weight_zeros, + has_bias, + has_clamp>(); +} + +TEST( + test_channelwise_8bit_activation_groupwise_lowbit_weight_1x1x32_f32_neondot, + HasWeightZeros) { + constexpr int weight_nbit = 4; + constexpr bool has_weight_zeros = true; + constexpr bool has_bias = false; + constexpr bool has_clamp = false; + + test_channelwise_8bit_activation_groupwise_lowbit_weight_1x1x32_f32_neondot< + weight_nbit, + has_weight_zeros, + has_bias, + has_clamp>(); +} + +TEST( + test_channelwise_8bit_activation_groupwise_lowbit_weight_1x1x32_f32_neondot, + HasBias) { + constexpr int weight_nbit = 4; + constexpr bool has_weight_zeros = false; + constexpr bool has_bias = true; + constexpr bool has_clamp = false; + + test_channelwise_8bit_activation_groupwise_lowbit_weight_1x1x32_f32_neondot< + weight_nbit, + has_weight_zeros, + has_bias, + has_clamp>(); +} + +TEST( + test_channelwise_8bit_activation_groupwise_lowbit_weight_1x1x32_f32_neondot, + HasClamp) { + constexpr int weight_nbit = 4; + constexpr bool has_weight_zeros = false; + constexpr bool has_bias = false; + constexpr bool has_clamp = true; + + test_channelwise_8bit_activation_groupwise_lowbit_weight_1x1x32_f32_neondot< + weight_nbit, + has_weight_zeros, + has_bias, + has_clamp>(); +} + +template +void test_channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot() { + int m = 7; + int k = 64; + int n = 13; + int group_size = 16; + + auto test_case = torchao:: + channelwise_8bit_activation_groupwise_lowbit_weight_test_case::generate( + m, + k, + n, + group_size, + weight_nbit, + has_weight_zeros, + has_bias, + has_clamp); + + using namespace torchao::kernels::cpu::aarch64::linear:: + channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot; + + std::vector activation_data( + activation_data_size(m, k, 
group_size)); + prepare_activation_data( + (void*)activation_data.data(), + m, + k, + group_size, + test_case.activations.data()); + + std::vector weight_data( + weight_data_size(n, k, group_size)); + prepare_weight_data( + (void*)weight_data.data(), + n, + k, + group_size, + test_case.weight_qvals.data(), + test_case.weight_scales.data(), + /*weight_zeros=*/test_case.weight_zeros.data()); + + std::vector output(m * k); + kernel( + output.data(), + /*output_m_stride=*/n, + m, + n, + k, + group_size, + weight_data.data(), + activation_data.data(), + /*bias=*/test_case.bias.data(), + /*clamp_min=*/test_case.clamp_min, + /*clamp_max=*/test_case.clamp_max); + + for (int i = 0; i < m * n; i++) { + EXPECT_NEAR(output[i], test_case.expected_output[i], kTol); + } +} + +TEST( + test_channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot, + Standard) { + constexpr int weight_nbit = 4; + constexpr bool has_weight_zeros = false; + constexpr bool has_bias = false; + constexpr bool has_clamp = false; + + test_channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot< + weight_nbit, + has_weight_zeros, + has_bias, + has_clamp>(); +} + +TEST( + test_channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot, + HasWeightZeros) { + constexpr int weight_nbit = 4; + constexpr bool has_weight_zeros = true; + constexpr bool has_bias = false; + constexpr bool has_clamp = false; + + test_channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot< + weight_nbit, + has_weight_zeros, + has_bias, + has_clamp>(); +} + +TEST( + test_channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot, + HasBias) { + constexpr int weight_nbit = 4; + constexpr bool has_weight_zeros = false; + constexpr bool has_bias = true; + constexpr bool has_clamp = false; + + test_channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot< + weight_nbit, + has_weight_zeros, + has_bias, + has_clamp>(); +} + +TEST( + test_channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot, + HasClamp) { + constexpr int weight_nbit = 4; + constexpr bool has_weight_zeros = false; + constexpr bool has_bias = false; + constexpr bool has_clamp = true; + + test_channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot< + weight_nbit, + has_weight_zeros, + has_bias, + has_clamp>(); +} + +template +void test_channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot() { + int m = 7; + int k = 64; + int n = 13; + int group_size = 16; + + auto test_case = torchao:: + channelwise_8bit_activation_groupwise_lowbit_weight_test_case::generate( + m, + k, + n, + group_size, + weight_nbit, + has_weight_zeros, + has_bias, + has_clamp); + + using namespace torchao::kernels::cpu::aarch64::linear:: + channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot; + + std::vector activation_data( + activation_data_size(m, k, group_size)); + prepare_activation_data( + (void*)activation_data.data(), + m, + k, + group_size, + test_case.activations.data()); + + std::vector weight_data( + weight_data_size(n, k, group_size)); + prepare_weight_data( + (void*)weight_data.data(), + n, + k, + group_size, + test_case.weight_qvals.data(), + test_case.weight_scales.data(), + /*weight_zeros=*/test_case.weight_zeros.data()); + + std::vector output(m * k); + kernel( + output.data(), + /*output_m_stride=*/n, + m, + n, + k, + group_size, + weight_data.data(), + activation_data.data(), + /*bias=*/test_case.bias.data(), + /*clamp_min=*/test_case.clamp_min, + 
/*clamp_max=*/test_case.clamp_max); + + for (int i = 0; i < m * n; i++) { + EXPECT_NEAR(output[i], test_case.expected_output[i], kTol); + } +} + +TEST( + test_channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot, + Standard) { + constexpr int weight_nbit = 4; + constexpr bool has_weight_zeros = false; + constexpr bool has_bias = false; + constexpr bool has_clamp = false; + + test_channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot< + weight_nbit, + has_weight_zeros, + has_bias, + has_clamp>(); +} + +TEST( + test_channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot, + HasWeightZeros) { + constexpr int weight_nbit = 4; + constexpr bool has_weight_zeros = true; + constexpr bool has_bias = false; + constexpr bool has_clamp = false; + + test_channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot< + weight_nbit, + has_weight_zeros, + has_bias, + has_clamp>(); +} + +TEST( + test_channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot, + HasBias) { + constexpr int weight_nbit = 4; + constexpr bool has_weight_zeros = false; + constexpr bool has_bias = true; + constexpr bool has_clamp = false; + + test_channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot< + weight_nbit, + has_weight_zeros, + has_bias, + has_clamp>(); +} + +TEST( + test_channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot, + HasClamp) { + constexpr int weight_nbit = 4; + constexpr bool has_weight_zeros = false; + constexpr bool has_bias = false; + constexpr bool has_clamp = true; + + test_channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot< + weight_nbit, + has_weight_zeros, + has_bias, + has_clamp>(); +} diff --git a/torchao/experimental/kernels/cpu/aarch64/tests/test_quantization.cpp b/torchao/experimental/kernels/cpu/aarch64/tests/test_quantization.cpp new file mode 100644 index 0000000000..6fac44244a --- /dev/null +++ b/torchao/experimental/kernels/cpu/aarch64/tests/test_quantization.cpp @@ -0,0 +1,66 @@ +// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + +#include +#include +#include +#include +#include + +// Demonstrate some basic assertions. 
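The quantization test below drives find_min_and_max, get_qvals_range, get_scale_and_zero, and quantize, then checks that each dequantized value scale * (qval - zero) lands within 1e-4 of the tabulated expectation. The scalar sketch below models asymmetric affine quantization in a way that is consistent with those expected values (for example, 8-bit over [-8.1, 11.1] gives scale = 19.2 / 255 ≈ 0.07529); it is only an illustration, and the torchao helpers may differ in rounding details at the boundaries.

#include <algorithm>
#include <cmath>
#include <cstdint>

// Asymmetric affine quantization, sketched: map floats in [vmin, vmax]
// onto integers in [qmin, qmax] (e.g. [-128, 127] for 8-bit asymmetric).
struct QParams {
  float scale;
  int zero;
};

inline QParams scale_and_zero_ref(float vmin, float vmax, int qmin, int qmax) {
  float scale = (vmax - vmin) / static_cast<float>(qmax - qmin);
  int zero = qmin - static_cast<int>(std::round(vmin / scale));
  return {scale, zero};
}

inline int8_t quantize_ref(float v, QParams p, int qmin, int qmax) {
  int q = static_cast<int>(std::round(v / p.scale)) + p.zero;
  return static_cast<int8_t>(std::min(std::max(q, qmin), qmax));
}

// The test then verifies accuracy on the dequantized side:
// dq = p.scale * (q - p.zero).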
+TEST(test_quantize, ExpectedOutput) { + std::array vals = {1.0, 2.5, -5.2, 10.2, 11.1, -3.15, -8.1, 7.3}; + std::array>, 5> nBitToExpectedResult{ + {{2, {0.0, 0.0, -6.4, 12.8, 12.8, 0, -6.4, 6.4}}, + {3, + {0.0, + 2.74286, + -5.48571, + 10.9714, + 10.9714, + -2.74286, + -8.22857, + 8.22857}}, + {4, {1.28, 2.56, -5.12, 10.24, 11.52, -2.56, -7.68, 7.68}}, + {5, + {1.23871, + 2.47742, + -4.95484, + 9.90968, + 11.1484, + -3.09677, + -8.05161, + 7.43226}}, + {8, + {0.978824, + 2.48471, + -5.19529, + 10.1647, + 11.0682, + -3.16235, + -8.13177, + 7.30353}}}}; + + int qmin, qmax, zero; + float vmin, vmax, scale; + + torchao::kernels::cpu::aarch64::reduction::find_min_and_max( + vmin, vmax, vals.data(), vals.size()); + + std::vector qvals(vals.size()); + + for (auto [nbit, expectedResult] : nBitToExpectedResult) { + torchao::quantization::get_qvals_range( + qmin, qmax, nbit, /*is_symmetric=*/false); + + torchao::quantization::get_scale_and_zero( + scale, zero, vmin, vmax, qmin, qmax); + + torchao::kernels::cpu::aarch64::quantization::quantize( + qvals.data(), vals.data(), vals.size(), scale, zero, qmin, qmax); + + for (int i = 0; i < vals.size(); ++i) { + float dq = scale * (qvals[i] - zero); + EXPECT_NEAR(dq, expectedResult[i], 0.0001); + } + } +} diff --git a/torchao/experimental/kernels/cpu/aarch64/tests/test_utils.h b/torchao/experimental/kernels/cpu/aarch64/tests/test_utils.h new file mode 100644 index 0000000000..7f0fbe8e49 --- /dev/null +++ b/torchao/experimental/kernels/cpu/aarch64/tests/test_utils.h @@ -0,0 +1,269 @@ +// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + +#pragma once +#include +#include +#include +#include +#include + +namespace torchao { +inline std::vector +get_random_vector(int size, float min = -1.0, float max = 1.0) { + assert(min < max); + std::random_device random_device; + auto rng = std::mt19937(random_device()); + auto dist = std::bind(std::uniform_real_distribution(min, max), rng); + std::vector res(size); + std::generate(res.begin(), res.end(), std::ref(dist)); + return res; +} + +inline std::vector get_random_lowbit_vector(int size, int nbit) { + assert(nbit >= 2); + assert(nbit <= 8); + + int min = 0; + int max = (1 << nbit) - 1; + + std::random_device random_device; + auto rng = std::mt19937(random_device()); + auto dist = std::bind(std::uniform_int_distribution<>(min, max), rng); + + std::vector res(size); + std::generate(res.begin(), res.end(), std::ref(dist)); + return res; +} + +struct channelwise_8bit_activation_groupwise_lowbit_weight_test_case { + int m; + int k; + int n; + int weight_group_size; + int weight_nbit; + bool has_weight_zeros; + bool has_bias; + bool has_clamp; + float clamp_min; + float clamp_max; + + std::vector expected_output; + + std::vector activations; + std::vector activation_qvals; + std::vector activation_scales; + std::vector activation_zeros; + + std::vector weights; + std::vector weight_qvals; + std::vector weight_scales; + std::vector weight_zeros; + + std::vector bias; + + channelwise_8bit_activation_groupwise_lowbit_weight_test_case( + int m_, + int k_, + int n_, + int weight_group_size_, + int weight_nbit_, + bool has_weight_zeros_, + bool has_bias_, + bool has_clamp_, + float clamp_min_, + float clamp_max_, + std::vector expected_output_, + std::vector activations_, + std::vector activation_qvals_, + std::vector activation_scales_, + std::vector activation_zeros_, + std::vector weights_, + std::vector weight_qvals_, + std::vector weight_scales_, + std::vector weight_zeros_, + std::vector bias_) + 
: m(m_), + k(k_), + n(n_), + weight_group_size(weight_group_size_), + weight_nbit(weight_nbit_), + has_weight_zeros(has_weight_zeros_), + has_bias(has_bias_), + has_clamp(has_clamp_), + clamp_min(clamp_min_), + clamp_max(clamp_max_), + expected_output(expected_output_), + activations(activations_), + activation_qvals(activation_qvals_), + activation_scales(activation_scales_), + activation_zeros(activation_zeros_), + weights(weights_), + weight_qvals(weight_qvals_), + weight_scales(weight_scales_), + weight_zeros(weight_zeros_), + bias(bias_) { + assert(k % weight_group_size == 0); + assert(expected_output.size() == m * n); + assert(activations.size() == m * k); + assert(activation_qvals.size() == m * k); + assert(activation_scales.size() == m); + assert(activation_zeros.size() == m); + assert(weights.size() == n * k); + assert(weight_qvals.size() == n * k); + assert((weight_group_size * weight_scales.size()) == (n * k)); + assert((weight_group_size * weight_zeros.size()) == (n * k)); + assert(bias.size() == m); + + if (has_clamp) { + assert(clamp_min < clamp_max); + } + } + + static channelwise_8bit_activation_groupwise_lowbit_weight_test_case generate( + int m, + int k, + int n, + int weight_group_size, + int weight_nbit, + bool has_weight_zeros, + bool has_bias, + bool has_clamp) { + // activations is m x k (stored in row-major) + // weights is k x n (stored in column-major) + + // Generate activations + auto activations = get_random_vector(m * k, -1.0, 1.0); + auto activation_qvals = std::vector(m * k, 0); + auto activation_scales = std::vector(m, 0); + auto activation_zeros = std::vector(m, 0); + + // Quantize activations with 8-bit asymmetric + // TODO: replace with generic function that does not use aarch64 + // quantize method after we combine with torchao + int qmin, qmax, zero; + float vmin, vmax, scale; + torchao::quantization::get_qvals_range( + qmin, qmax, /*nbit=*/8, /*is_symmetric=*/false); + for (int m_idx = 0; m_idx < m; m_idx++) { + torchao::kernels::cpu::aarch64::reduction::find_min_and_max( + vmin, vmax, /*vals=*/activations.data() + m_idx * k, /*size=*/k); + torchao::quantization::get_scale_and_zero( + scale, zero, vmin, vmax, qmin, qmax); + activation_scales[m_idx] = scale; + activation_zeros[m_idx] = zero; + torchao::kernels::cpu::aarch64::quantization::quantize( + /*qvals=*/activation_qvals.data() + m_idx * k, + /*vals=*/activations.data() + m_idx * k, + /*size=*/k, + scale, + zero, + qmin, + qmax); + } + + // Generate weights + assert(k % weight_group_size == 0); + int n_weight_groups = (n * k) / weight_group_size; + auto weights = get_random_vector(n * k, -1.0, 1.0); + auto weight_qvals = std::vector(n * k, 0); + auto weight_scales = std::vector(n_weight_groups, 0.0); + auto weight_zeros = std::vector(n_weight_groups, 0); + + // Quantize weights with weight_nbit + // TODO: replace with generic function that does not use aarch64 + // quantize method after we combine with torchao + torchao::quantization::get_qvals_range( + qmin, qmax, /*nbit=*/weight_nbit, /*is_symmetric=*/false); + + int n_groups = (n * k) / weight_group_size; + for (int group_idx = 0; group_idx < n_groups; group_idx += 1) { + torchao::kernels::cpu::aarch64::reduction::find_min_and_max( + vmin, + vmax, + /*vals=*/weights.data() + group_idx * weight_group_size, + /*size=*/weight_group_size); + + if (has_weight_zeros) { + torchao::quantization::get_scale_and_zero( + scale, zero, vmin, vmax, qmin, qmax); + } else { + scale = torchao::quantization::get_scale(vmin, vmax, qmin, qmax); + zero = 0; + } + 
weight_scales[group_idx] = scale; + weight_zeros[group_idx] = zero; + + torchao::kernels::cpu::aarch64::quantization::quantize( + /*qvals=*/weight_qvals.data() + group_idx * weight_group_size, + /*vals=*/weights.data() + group_idx * weight_group_size, + /*size=*/weight_group_size, + scale, + zero, + qmin, + qmax); + } + + std::vector bias(m, 0.0); + if (has_bias) { + bias = get_random_vector(m, -1.0, 1.0); + } + + float clamp_min = 0.0; + float clamp_max = 0.0; + if (has_clamp) { + clamp_min = get_random_vector(1, -1.0, 0.2)[0]; + clamp_max = get_random_vector(1, 0.3, 1.0)[0]; + } + + // Compute expected output + std::vector expected_output(m * n); + for (int m_idx = 0; m_idx < m; m_idx++) { + for (int n_idx = 0; n_idx < n; n_idx++) { + float res = 0.0; + for (int k_idx = 0; k_idx < k; k_idx++) { + int activation_idx = m_idx * k + k_idx; + int weight_idx = n_idx * k + k_idx; + int weight_group_idx = weight_idx / weight_group_size; + + float activation_dequant = activation_scales[m_idx] * + (activation_qvals[activation_idx] - activation_zeros[m_idx]); + + float weight_dequant = weight_scales[weight_group_idx] * + (weight_qvals[weight_idx] - weight_zeros[weight_group_idx]); + + res += activation_dequant * weight_dequant; + } + res += bias[m_idx]; + if (has_clamp) { + res = std::min(std::max(res, clamp_min), clamp_max); + } + expected_output[m_idx * n + n_idx] = res; + } + } + + // Return test case + return channelwise_8bit_activation_groupwise_lowbit_weight_test_case( + m, + k, + n, + weight_group_size, + weight_nbit, + has_weight_zeros, + has_bias, + has_clamp, + clamp_min, + clamp_max, + expected_output, + activations, + activation_qvals, + activation_scales, + activation_zeros, + weights, + weight_qvals, + weight_scales, + weight_zeros, + bias); + } +}; + +} // namespace torchao diff --git a/torchao/experimental/kernels/cpu/aarch64/tests/test_valpacking.cpp b/torchao/experimental/kernels/cpu/aarch64/tests/test_valpacking.cpp new file mode 100644 index 0000000000..5497a62f72 --- /dev/null +++ b/torchao/experimental/kernels/cpu/aarch64/tests/test_valpacking.cpp @@ -0,0 +1,96 @@ +// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. 
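Restating the generator above in index form: weights are stored column-major, so the quantized weight for output column n_idx and reduction index k_idx lives at n_idx * k + k_idx, and its quantization group is g = (n_idx * k + k_idx) / weight_group_size. The reference output it computes is

  expected_output[m_idx * n + n_idx] =
      clamp( sum over k_idx of
               [activation_scales[m_idx] * (activation_qvals[m_idx * k + k_idx] - activation_zeros[m_idx])]
             * [weight_scales[g] * (weight_qvals[n_idx * k + k_idx] - weight_zeros[g])]
             + bias[m_idx],
             clamp_min, clamp_max )

where the clamp is applied only when has_clamp is set, and where has_weight_zeros = false means weight_zeros[g] is 0 and the scale comes from get_scale alone. This is the fp32 reference that the kernel tests compare against with EXPECT_NEAR.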
+ +#include +#include +#include + +TEST(InterleaveDataTest, InterleaveChannels) { + // interleave 4 rows of 6 elements + int bytes_per_val = 4; // int32_t + int vals_per_channel = 6; + int vals_per_group = 6; + int vals_per_chunk = 3; + int channels = 4; + int channel_stride_in_vals = vals_per_channel; + + int data_size = channels * vals_per_channel; + assert(data_size == 24); + int32_t data[data_size]; + int32_t data_interleaved[data_size]; + for (int i = 0; i < data_size; i++) { + data[i] = i; + data_interleaved[i] = 0; + } + int32_t expected_data_interleaved[] = {0, 1, 2, 6, 7, 8, 12, 13, + 14, 18, 19, 20, 3, 4, 5, 9, + 10, 11, 15, 16, 17, 21, 22, 23}; + + torchao::kernels::cpu::valpacking::interleave_data( + data_interleaved, + data, + bytes_per_val, + vals_per_channel, + vals_per_group, + vals_per_chunk, + channels, + channel_stride_in_vals); + + for (int i = 0; i < data_size; ++i) { + EXPECT_EQ(data_interleaved[i], expected_data_interleaved[i]); + } +} + +TEST(InterleaveDataTest, InterleaveChannelsAndGroups) { + // Test this example: + // + // group0 group1 group2 + // chunk0 chunk1 chunk0 chunk1 chunk0 chunk1 + // [(v00, v01 | v02, v03) | (v04, v05 | v06, v07) | (v08, v09 | v0a, v0b)] ch0 + // [(v10, v11 | v12, v13) | (v14, v15 | v16, v17) | (v18, v19 | v1a, v1b)] ch1 + // [(v20, v21 | v22, v23) | (v24, v25 | v26, v27) | (v28, v29 | v2a, v2b)] ch2 + // [(v30, v31 | v32, v33) | (v34, v35 | v36, v37) | (v38, v39 | v3a, v3b)] ch3 + // + // The output of this method is: + // + // v00, v01 | v10, v11 | v20, v21 | v30, v31 // chunk0, group0 channels + // v04, v05 | v14, v15 | v24, v25 | v34, v35 // chunk0, group1 channels + // v08, v09 | v18, v19 | v28, v29 | v38, v39 // chunk0, group2 channels + // v02, v03 | v12, v13 | v22, v23 | v32, v33 // chunk1, group0 channels + // v06, v07 | v16, v17 | v26, v27 | v36, v37 // chunk1, group1 channels + // v0a, v0b | v1a, v1b | v2a, v2b | v3a, v3b // chunk1, group2 channels + + // interleave 4 rows of 6 elements + int bytes_per_val = 4; // int32_t + int vals_per_channel = 12; + int vals_per_group = 4; + int vals_per_chunk = 2; + int channels = 4; + int channel_stride_in_vals = vals_per_channel; + + int data_size = channels * vals_per_channel; + assert(data_size == 48); + int32_t data[data_size]; + int32_t data_interleaved[data_size]; + for (int i = 0; i < data_size; i++) { + data[i] = i; + data_interleaved[i] = 0; + } + int32_t expected_data_interleaved[] = { + 0, 1, 12, 13, 24, 25, 36, 37, 4, 5, 16, 17, 28, 29, 40, 41, + 8, 9, 20, 21, 32, 33, 44, 45, 2, 3, 14, 15, 26, 27, 38, 39, + 6, 7, 18, 19, 30, 31, 42, 43, 10, 11, 22, 23, 34, 35, 46, 47}; + + torchao::kernels::cpu::valpacking::interleave_data( + data_interleaved, + data, + bytes_per_val, + vals_per_channel, + vals_per_group, + vals_per_chunk, + channels, + channel_stride_in_vals); + + for (int i = 0; i < data_size; ++i) { + EXPECT_EQ(data_interleaved[i], expected_data_interleaved[i]); + } +} diff --git a/torchao/experimental/kernels/cpu/aarch64/valpacking/interleave.cpp b/torchao/experimental/kernels/cpu/aarch64/valpacking/interleave.cpp new file mode 100644 index 0000000000..ace1d1697a --- /dev/null +++ b/torchao/experimental/kernels/cpu/aarch64/valpacking/interleave.cpp @@ -0,0 +1,75 @@ +// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + +#include +#include +#include + +// Interleaves data across channels (row/column) and groups. +// Each channel is the same size (vals_per_channel) and is +// divided into groups (vals_per_channel % vals_per_group == 0). 
+// Each group is divided into chunks (vals_per_group % vals_per_chunk == 0). +// Chunks are interleaved. +// +// Data is interleaved by iterating over chunks, then groups, and then channels. +// +// For example, given original data (depicted below with channels as +// rows, vals_per_channel=12, vals_per_group = 4, vals_per_chunk=2): +// +// group0 group1 group2 +// chunk0 chunk1 chunk0 chunk1 chunk0 chunk1 +// [(v00, v01 | v02, v03) | (v04, v05 | v06, v07) | (v08, v09 | v0a, v0b)] ch0 +// [(v10, v11 | v12, v13) | (v14, v15 | v16, v17) | (v18, v19 | v1a, v1b)] ch1 +// [(v20, v21 | v22, v23) | (v24, v25 | v26, v27) | (v28, v29 | v2a, v2b)] ch2 +// [(v30, v31 | v32, v33) | (v34, v35 | v36, v37) | (v38, v39 | v3a, v3b)] ch3 +// +// The output of this method is: +// +// v00, v01 | v10, v11 | v20, v21 | v30, v31 // chunk0, group0 channels +// v04, v05 | v14, v15 | v24, v25 | v34, v35 // chunk0, group1 channels +// v08, v09 | v18, v19 | v28, v29 | v38, v39 // chunk0, group2 channels +// v02, v03 | v12, v13 | v22, v23 | v32, v33 // chunk1, group0 channels +// v06, v07 | v16, v17 | v26, v27 | v36, v37 // chunk1, group1 channels +// v0a, v0b | v1a, v1b | v2a, v2b | v3a, v3b // chunk1, group2 channels +// +// For a given value, the value in the next channel is offset by +// channel_stride_in_vals. +// It may be that channel_stride_in_vals = vals_per_channel, +// but it can be something else if we are applying this method +// to a matrix tile. + +void torchao::kernels::cpu::valpacking::interleave_data( + void* data_interleaved, + const void* data, + int bytes_per_val, + int vals_per_channel, + int vals_per_group, + int vals_per_chunk, + int channels, + int channel_stride_in_vals) { + assert(vals_per_channel % vals_per_group == 0); + assert(vals_per_group % vals_per_chunk == 0); + + int chunks_per_group = vals_per_group / vals_per_chunk; + int groups_per_channel = vals_per_channel / vals_per_group; + int bytes_per_chunk = vals_per_chunk * bytes_per_val; + + int8_t* output_byte_ptr = (int8_t*)(data_interleaved); + const int8_t* input_byte_ptr = (int8_t*)(data); + + for (int chunk_idx = 0; chunk_idx < chunks_per_group; chunk_idx++) { + for (int group_idx = 0; group_idx < groups_per_channel; group_idx++) { + for (int channel_idx = 0; channel_idx < channels; channel_idx++) { + // Index of first value in chunk we're moving + int val_idx = (channel_idx * channel_stride_in_vals) + + (group_idx * vals_per_group) + (chunk_idx * vals_per_chunk); + + // Copy chunk to correct location + std::memcpy( + output_byte_ptr, + input_byte_ptr + val_idx * bytes_per_val, + bytes_per_chunk); + output_byte_ptr += bytes_per_chunk; + } + } + } +} diff --git a/torchao/experimental/kernels/cpu/aarch64/valpacking/valpack.h b/torchao/experimental/kernels/cpu/aarch64/valpacking/valpack.h new file mode 100644 index 0000000000..ecfb16ac83 --- /dev/null +++ b/torchao/experimental/kernels/cpu/aarch64/valpacking/valpack.h @@ -0,0 +1,24 @@ +// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + +#pragma once + +namespace torchao { +namespace kernels { +namespace cpu { +namespace valpacking { + +// TODO: should this be relocated out of aarch64? 
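A small worked example of the layout interleave_data produces (values chosen for illustration only, smaller than anything in the tests; the full layout is documented in interleave.cpp above): with vals_per_channel = 4, vals_per_group = 4, vals_per_chunk = 2, channels = 2, and channel_stride_in_vals = 4, input channels ch0 = [0, 1, 2, 3] and ch1 = [4, 5, 6, 7] interleave to

  [0, 1,   4, 5,   2, 3,   6, 7]
   ch0     ch1     ch0     ch1
   chunk0  chunk0  chunk1  chunk1

i.e. chunk 0 of every channel is emitted first, then chunk 1, matching the chunk, then group, then channel loop order of the implementation.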
+void interleave_data(
+    void* data_interleaved,
+    const void* data,
+    int bytes_per_val,
+    int vals_per_channel,
+    int vals_per_group,
+    int vals_per_chunk,
+    int channels,
+    int channel_stride_in_vals);
+
+} // namespace valpacking
+} // namespace cpu
+} // namespace kernels
+} // namespace torchao
diff --git a/torchao/experimental/kernels/cpu/build_and_run_benchmarks.sh b/torchao/experimental/kernels/cpu/build_and_run_benchmarks.sh
new file mode 100644
index 0000000000..83a7e3094f
--- /dev/null
+++ b/torchao/experimental/kernels/cpu/build_and_run_benchmarks.sh
@@ -0,0 +1,19 @@
+#!/bin/bash
+# Usage: sh build_and_run_benchmarks.sh {quantization|bitpacking|linear}
+
+SCRIPT_DIR=$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &> /dev/null && pwd)
+export TORCHAO_LIBRARIES=${SCRIPT_DIR}/../../../..
+export CMAKE_OUT=/tmp/cmake-out/torch_ao/benchmarks
+cmake -DTORCHAO_LIBRARIES=${TORCHAO_LIBRARIES} \
+    -S ${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64/benchmarks \
+    -B ${CMAKE_OUT}
+
+cmake --build ${CMAKE_OUT}
+
+# Run
+case "$1" in
+    quantization) ${CMAKE_OUT}/benchmark_quantization; ;;
+    bitpacking) ${CMAKE_OUT}/benchmark_bitpacking; ;;
+    linear) ${CMAKE_OUT}/benchmark_linear; ;;
+    *) echo "Unknown benchmark: $1. Please specify quantization, bitpacking, or linear."; exit 1; ;;
+esac
diff --git a/torchao/experimental/kernels/cpu/build_and_run_tests.sh b/torchao/experimental/kernels/cpu/build_and_run_tests.sh
new file mode 100644
index 0000000000..5ebc30f454
--- /dev/null
+++ b/torchao/experimental/kernels/cpu/build_and_run_tests.sh
@@ -0,0 +1,13 @@
+#!/bin/bash
+SCRIPT_DIR=$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &> /dev/null && pwd)
+export TORCHAO_LIBRARIES=${SCRIPT_DIR}/../../../..
+export CMAKE_OUT=/tmp/cmake-out/torch_ao/tests
+cmake -DTORCHAO_LIBRARIES=${TORCHAO_LIBRARIES} -S ${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64/tests -B ${CMAKE_OUT}
+
+cmake --build ${CMAKE_OUT}
+
+# Run
+ ${CMAKE_OUT}/test_quantization
+ ${CMAKE_OUT}/test_bitpacking
+ ${CMAKE_OUT}/test_linear
+ ${CMAKE_OUT}/test_valpacking
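Both helper scripts resolve TORCHAO_LIBRARIES relative to their own location, so they can be invoked from any directory, and build output lands under /tmp/cmake-out/torch_ao/. Assuming an aarch64 host (the kernels use NEON intrinsics), the invocations are:

  sh torchao/experimental/kernels/cpu/build_and_run_tests.sh
  sh torchao/experimental/kernels/cpu/build_and_run_benchmarks.sh bitpacking   # or: quantization, linear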