50 changes: 50 additions & 0 deletions torchao/experimental/kernels/cpu/fallback/matmul/fp32_a_channelwise_8bit_b_fp32_c.h
@@ -0,0 +1,50 @@
// Copyright (c) Meta Platforms, Inc. and affiliates.
// All rights reserved.
//
// This source code is licensed under the license found in the
// LICENSE file in the root directory of this source tree.

#pragma once

#include <cassert>

// TODO: Remove all ::kernels. No need for extra namespace.
namespace torchao::kernels::cpu::fallback::quantized_matmul {
namespace fp32_a_input_channelwise_8bit_b_fp32 {
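// Scalar fallback kernel computing
//   output = beta * output + lhs * dequantize(rhs)
// where rhs is 8-bit quantized with one scale and one zero point per input
// channel (i.e. per row of the k x n rhs matrix).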
template <bool b_has_zeros, bool a_transposed, bool b_transposed>
void kernel(
int m,
int n,
int k,
const float* lhs,
int lhs_stride_m,
const int8_t* rhs,
int rhs_stride_n,
float* output,
int out_stride_m,
const int8_t* rhs_zero_points,
const float* rhs_scales,
const float beta,
const int rhs_qparams_stride) {
  static_assert(!a_transposed, "transposed lhs is not supported");
for (int m_idx = 0; m_idx < m; m_idx++) {
for (int n_idx = 0; n_idx < n; n_idx++) {
float res = 0.0;
for (int k_idx = 0; k_idx < k; k_idx++) {
int lhs_idx = m_idx * lhs_stride_m + k_idx;
int rhs_idx = k_idx * rhs_stride_n + n_idx;
if (b_transposed) {
rhs_idx = n_idx * rhs_stride_n + k_idx;
}
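        // Dequantize the rhs element with the scale and zero point of
        // input channel k_idx.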
float rhs_dequant = rhs_scales[k_idx * rhs_qparams_stride] *
(static_cast<int16_t>(rhs[rhs_idx]) -
static_cast<int16_t>(rhs_zero_points[k_idx * rhs_qparams_stride]));

res += lhs[lhs_idx] * rhs_dequant;
}
      output[m_idx * out_stride_m + n_idx] =
          output[m_idx * out_stride_m + n_idx] * beta + res;
}
}
}
} // namespace fp32_a_input_channelwise_8bit_b_fp32
} // namespace torchao::kernels::cpu::fallback::quantized_matmul
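
For reference, a minimal sketch of exercising the fallback kernel above directly; the shapes and constant fill values are illustrative assumptions, not part of the change:

#include <cstdint>
#include <vector>

#include <torchao/experimental/kernels/cpu/fallback/matmul/fp32_a_channelwise_8bit_b_fp32_c.h>

int main() {
  const int m = 2, n = 3, k = 4;
  std::vector<float> lhs(m * k, 1.0f);  // fp32 activations, m x k
  std::vector<int8_t> rhs(k * n, 2);    // quantized weights, k x n
  std::vector<int8_t> zeros(k, 0);      // one zero point per input channel
  std::vector<float> scales(k, 0.5f);   // one scale per input channel
  std::vector<float> out(m * n, 0.0f);

  torchao::kernels::cpu::fallback::quantized_matmul::
      fp32_a_input_channelwise_8bit_b_fp32::kernel<true, false, false>(
          m, n, k, lhs.data(), /*lhs_stride_m=*/k, rhs.data(),
          /*rhs_stride_n=*/n, out.data(), /*out_stride_m=*/n, zeros.data(),
          scales.data(), /*beta=*/0.0f, /*rhs_qparams_stride=*/1);

  // Each output element is the sum over k of 1.0f * 0.5f * (2 - 0) = 4.0f.
  return 0;
}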
70 changes: 70 additions & 0 deletions torchao/experimental/kernels/cpu/interface/quantized_matmul.h
@@ -9,6 +9,8 @@
#include <cassert>

#include <torchao/experimental/kernels/cpu/fallback/matmul/channelwise_8bit_a_channelwise_8bit_b.h>
#include <torchao/experimental/kernels/cpu/fallback/matmul/fp32_a_channelwise_8bit_b_fp32_c.h>

#if defined(__aarch64__) || defined(__ARM_NEON)
#include <torchao/experimental/kernels/cpu/aarch64/matmul/channelwise_8bit_a_channelwise_8bit_b_1x16x16_f32_smlal-impl.h>
#include <torchao/experimental/kernels/cpu/aarch64/matmul/channelwise_8bit_a_channelwise_8bit_b_1x8x16_f32_neondot-impl.h>
@@ -85,4 +87,72 @@ get_int8_a_int8_b_channelwise_qmatmul(
channelwise_8bit_a_channelwise_8bit_b::kernel<true, true, false, false>;
}
}

/*
a_stride_m: stride of a in memory, in elements, indicating how far apart
each row is.
b_stride_n: stride of b in memory, in elements. If b is transposed (n x k),
this is how many elements to skip to get from one output channel to the
next. If b is not transposed (k x n), this is how many elements to skip to
get from one input channel to the next.

The chosen strides for a and b are returned through a_stride_m and
b_stride_n, and are the ones that should be passed to the returned kernel.

Will need to think of a better way to find the right
ukernel. Perhaps via a ukernel config + registry?
*/
using fp32_a_input_channelwise_8bit_b_f32_c_matmul_type = void (*)(
    int, // m
    int, // n
    int, // k
    const float*, // lhs
    int, // lhs_stride_m
    const int8_t*, // rhs
    int, // rhs_stride_n
    float*, // output
    int, // out_stride_m
    const int8_t*, // rhs_zero_points
    const float*, // rhs_scales
    const float, // beta
    const int); // rhs_qparams_stride

fp32_a_input_channelwise_8bit_b_f32_c_matmul_type
get_fp32_a_input_channelwise_8bit_b_f32_c_matmul(
int m,
int n,
int k,
bool a_transposed,
bool b_transposed,
int& a_stride_m,
int& b_stride_n);

fp32_a_input_channelwise_8bit_b_f32_c_matmul_type
get_fp32_a_input_channelwise_8bit_b_f32_c_matmul(
int m,
int n,
int k,
bool a_transposed,
bool b_transposed,
int& a_stride_m,
int& b_stride_n) {
#if defined(__aarch64__) || defined(__ARM_NEON)
if (!a_transposed && !b_transposed && n >= 16) {
a_stride_m = k;
b_stride_n = n;
return aarch64::quantized_matmul::
fp32_a_input_channelwise_8bit_b_1x16x4_f32::kernel<true, false, false>;
}
#endif // defined(__aarch64__) || defined(__ARM_NEON)
assert(!a_transposed);
if (b_transposed) {
a_stride_m = k;
b_stride_n = k;
return torchao::kernels::cpu::fallback::quantized_matmul::
fp32_a_input_channelwise_8bit_b_fp32::kernel<true, false, true>;
} else {
a_stride_m = k;
b_stride_n = n;
return torchao::kernels::cpu::fallback::quantized_matmul::
fp32_a_input_channelwise_8bit_b_fp32::kernel<true, false, false>;
}
}
} // namespace torchao::kernels::cpu::quantized_matmul
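
A similar sketch going through the dispatching interface above; again the shapes and data are illustrative assumptions:

#include <cstdint>
#include <vector>

#include <torchao/experimental/kernels/cpu/interface/quantized_matmul.h>

int main() {
  const int m = 2, n = 3, k = 4;
  std::vector<float> lhs(m * k, 1.0f);
  std::vector<int8_t> rhs(k * n, 2);
  std::vector<int8_t> zeros(k, 0);
  std::vector<float> scales(k, 0.5f);
  std::vector<float> out(m * n, 0.0f);

  // The getter picks a NEON ukernel when available (and when the shape
  // qualifies, here n >= 16), otherwise falls back to the scalar kernel,
  // and reports the strides the caller must use.
  int a_stride_m, b_stride_n;
  auto kernel = torchao::kernels::cpu::quantized_matmul::
      get_fp32_a_input_channelwise_8bit_b_f32_c_matmul(
          m, n, k, /*a_transposed=*/false, /*b_transposed=*/false,
          a_stride_m, b_stride_n);
  kernel(
      m, n, k, lhs.data(), a_stride_m, rhs.data(), b_stride_n, out.data(),
      /*out_stride_m=*/n, zeros.data(), scales.data(), /*beta=*/0.0f,
      /*rhs_qparams_stride=*/1);
  return 0;
}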
182 changes: 182 additions & 0 deletions torchao/experimental/kernels/cpu/interface/test_qmatmul_interface.cpp
@@ -446,3 +446,185 @@ TEST(
Run(
/*m=*/4, /*k=*/2, /*n=*/1, 32);
}

class FP32A_QuantizedB_FP32C_Interface_Test
: public ::testing::TestWithParam<float> {
public:
int m;
int k;
int n;
int stride;

bool rhs_has_zeros;
bool lhs_is_transposed;
bool rhs_is_transposed;

std::vector<float> init_output;
std::vector<float> expected_output;

std::vector<float> lhs;

std::vector<float> rhs;
std::vector<int8_t> rhs_qvals;
std::vector<float> rhs_scales;
std::vector<int8_t> rhs_zeros;

void generate(
int m_,
int k_,
int n_,
bool rhs_has_zeros_,
bool lhs_is_transposed_,
bool rhs_is_transposed_,
int stride_ = 1) {
assert(!lhs_is_transposed_);
assert(rhs_has_zeros_);
m = m_;
k = k_;
n = n_;
stride = stride_;
rhs_has_zeros = rhs_has_zeros_;
lhs_is_transposed = lhs_is_transposed_;
rhs_is_transposed = rhs_is_transposed_;

assert(!rhs_is_transposed || stride == 1);

// Generate activations
lhs = get_random_vector(m * k, -1.0, 1.0);

    // Note: instead of quantizing each output channel separately, we
    // quantize each input channel here. The quantized tensor is generated
    // as k x n rather than n x k because each input channel gets its own
    // scale and zero point.
std::tie(rhs, rhs_qvals, rhs_scales, rhs_zeros) =
generate_per_token_quantized_tensor(k * stride, n, rhs_is_transposed);

// Compute expected output
init_output = get_random_vector(m * n, -1.0, 1.0);

assert(init_output.size() == m * n);
assert(lhs.size() == m * k);
assert(rhs.size() == n * stride * k);
assert(rhs_qvals.size() == n * stride * k);
assert(rhs_scales.size() == k * stride);
assert(rhs_zeros.size() == k * stride);
}

void execute(float beta) {
// Compute expected output
expected_output = init_output;

for (int m_idx = 0; m_idx < m; m_idx++) {
for (int n_idx = 0; n_idx < n; n_idx++) {
float res = 0.0;
for (int k_idx = 0; k_idx < k; k_idx++) {
int lhs_idx = m_idx * k + k_idx;
int rhs_idx = k_idx * stride * n + n_idx;
if (rhs_is_transposed) {
rhs_idx = n_idx * k * stride + k_idx * stride;
}
float rhs_dequant = rhs_scales[k_idx * stride] *
(static_cast<int16_t>(rhs_qvals[rhs_idx]) -
static_cast<int16_t>(rhs_zeros[k_idx * stride]));

res += lhs[lhs_idx] * rhs_dequant;
}
expected_output[m_idx * n + n_idx] =
expected_output[m_idx * n + n_idx] * beta + res;
}
}
}

float beta() const {
return GetParam();
}
};

static void test_fp32_a_input_channelwise_8bit_b(
int m,
int k,
int n,
float beta,
FP32A_QuantizedB_FP32C_Interface_Test& test_case,
int stride = 1) {
test_case.execute(beta);

int a_stride_m, b_stride_n;
auto kernel = torchao::kernels::cpu::quantized_matmul::
get_fp32_a_input_channelwise_8bit_b_f32_c_matmul(
m, n, k, false, false, a_stride_m, b_stride_n);
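  // The getter returns the stride for a densely packed rhs; scale it by the
  // test's qparams stride to exercise a strided rhs layout.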
b_stride_n = b_stride_n * stride;

std::vector<float> output(test_case.init_output);
kernel(
m,
n,
k,
test_case.lhs.data(),
a_stride_m /*lhs_stride_m*/,
test_case.rhs_qvals.data(),
b_stride_n /*rhs_stride_n*/,
output.data(),
      n /*out_stride_m*/,
test_case.rhs_zeros.data(),
test_case.rhs_scales.data(),
beta,
stride /*rhs qparams stride*/);

for (int i = 0; i < m * n; i++) {
EXPECT_NEAR(output[i], test_case.expected_output[i], kTol);
}
}

TEST_P(FP32A_QuantizedB_FP32C_Interface_Test, BTransposedWithZeroPoints) {
generate(3, 128, 16, true, false, false);
test_fp32_a_input_channelwise_8bit_b(
/*m=*/3, /*k=*/128, /*n=*/16, beta(), *this);
}

TEST_P(
FP32A_QuantizedB_FP32C_Interface_Test,
    BTransposedWithZeroPointsOddSizes) {
generate(4, 37, 19, true, false, false);
test_fp32_a_input_channelwise_8bit_b(
/*m=*/4, /*k=*/37, /*n=*/19, beta(), *this);
}

// Test shapes for which we have to use fallback kernel
TEST_P(
FP32A_QuantizedB_FP32C_Interface_Test,
    BTransposedWithZeroPointsOddSizesFallback) {
generate(4, 37, 3, true, false, false);
test_fp32_a_input_channelwise_8bit_b(
/*m=*/4, /*k=*/37, /*n=*/3, beta(), *this);
}

TEST_P(
FP32A_QuantizedB_FP32C_Interface_Test,
    BTransposedWithZeroPointsOddSizes2Fallback) {
generate(4, 1, 3, true, false, false);
test_fp32_a_input_channelwise_8bit_b(
/*m=*/4, /*k=*/1, /*n=*/3, beta(), *this);
}

TEST_P(
FP32A_QuantizedB_FP32C_Interface_Test,
    BTransposedWithZeroPointsOddSizesStrided) {
generate(4, 37, 19, true, false, false, 32);
test_fp32_a_input_channelwise_8bit_b(
/*m=*/4, /*k=*/37, /*n=*/19, beta(), *this, 32);
}

TEST_P(
FP32A_QuantizedB_FP32C_Interface_Test,
    BTransposedWithZeroPointsOddSizes2FallbackStrided) {
generate(4, 5, 3, true, false, false, 32);
test_fp32_a_input_channelwise_8bit_b(
/*m=*/4, /*k=*/5, /*n=*/3, beta(), *this, 32);
}

INSTANTIATE_TEST_SUITE_P(
F32AInt8BFP32CTest,
FP32A_QuantizedB_FP32C_Interface_Test,
::testing::Values(0.0, 1.0, 3.1));