@@ -114,6 +114,69 @@ void pack_weights(
bias);
}

template <int weight_nbit, int nr, int kr, int sr>
void pack_weights_with_lut(
// Output
void* packed_weights,
// Inputs
int n,
int k,
int group_size,
const int8_t* weight_qval_idxs,
int n_luts,
const int8_t* luts,
const float* weight_scales,
// weight_zeros not packed if nullptr
const int8_t* weight_zeros,
// bias not packed if nullptr
const float* bias) {
torchao::kernels::cpu::aarch64::linear::
channelwise_8bit_activation_groupwise_lowbit_weight::weight_packing::
pack_weights_with_lut<weight_nbit, nr, kr, sr>(
packed_weights,
n,
k,
group_size,
weight_qval_idxs,
n_luts,
luts,
weight_scales,
weight_zeros,
bias);
}
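// A note on the LUT inputs above (an illustrative reading, based on how the
// kernel in this PR consumes the packed data): instead of raw quantized
// values, the weights arrive as weight_nbit-wide indices (weight_qval_idxs)
// into small int8 lookup tables (luts). Assuming the usual affine convention
// and one table of (1 << weight_nbit) int8 entries, the value of weight i in
// quantization group g is recovered roughly as
//
//   int8_t qval = lut[weight_qval_idxs[i]];
//   float w = weight_scales[g] *
//       (qval - (weight_zeros != nullptr ? weight_zeros[g] : 0));
//
// How the n_luts tables are assigned to rows is a layout detail of
// weight_packing::pack_weights_with_lut and is not spelled out here.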

inline size_t packed_weights_with_lut_size(
int n,
int k,
int group_size,
int weight_nbit,
bool has_weight_zeros,
bool has_bias,
int nr,
int kr,
int sr) {
(void)kr; // unused
(void)sr; // unused
return weight_packing::packed_weights_with_lut_size(
n, k, group_size, weight_nbit, has_weight_zeros, has_bias, nr);
}

inline size_t packed_weights_with_lut_offset(
int n_idx,
int k,
int group_size,
int weight_nbit,
bool has_weight_zeros,
bool has_bias,
int nr,
int kr,
int sr) {
assert(n_idx % nr == 0);
auto packed_weights_size_nr_cols = packed_weights_with_lut_size(
nr, k, group_size, weight_nbit, has_weight_zeros, has_bias, nr, kr, sr);
return (n_idx / nr) * packed_weights_size_nr_cols;
}
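// A minimal usage sketch for the two helpers above (illustrative only, not
// part of this change): the packed buffer is laid out as one fixed-size
// block per tile of nr columns, so a caller can size the buffer once and
// jump directly to any tile, e.g.
//
//   size_t tile_bytes = packed_weights_with_lut_size(
//       nr, k, group_size, weight_nbit, has_weight_zeros, has_bias,
//       nr, kr, sr);
//   for (int n_idx = 0; n_idx < n; n_idx += nr) {
//     void* tile_ptr = (char*)packed_weights +
//         packed_weights_with_lut_offset(
//             n_idx, k, group_size, weight_nbit,
//             has_weight_zeros, has_bias, nr, kr, sr);
//     // pack the columns [n_idx, n_idx + nr) into tile_ptr
//   }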

template <int weight_nbit>
void kernel_1x1x32_f32_neondot(
// Outputs
@@ -182,7 +245,7 @@ void kernel_1x4x16_f32_neondot(
has_clamp);
}

-template <int weight_nbit>
template <int weight_nbit, bool has_lut>
void kernel_1x8x16_f32_neondot(
// Outputs
float32_t* output,
@@ -200,7 +263,7 @@ void kernel_1x8x16_f32_neondot(
bool has_weight_zeros,
bool has_bias,
bool has_clamp) {
-kernel::kernel_1x8x16_f32_neondot<weight_nbit>(
kernel::kernel_1x8x16_f32_neondot<weight_nbit, has_lut>(
output,
output_m_stride,
m,
@@ -58,7 +58,7 @@ vec_clamp(float32x4_t x, float32x4_t vec_min, float32x4_t vec_max) {
// Roughly inspired by
// https://gitlab.arm.com/kleidi/kleidiai/-/blob/main/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod.c?ref_type=heads

-template <int weight_nbit>
template <int weight_nbit, bool has_lut>
void kernel_1x8x16_f32_neondot(
// Outputs
float32_t* output,
@@ -79,6 +79,11 @@ void kernel_1x8x16_f32_neondot(
assert(k % group_size == 0);
assert(group_size % 16 == 0);

int8x16_t lut;
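// lut holds the 16-byte lookup table for the current tile of 8 columns when
// instantiated with has_lut (it is reloaded per tile below); otherwise it is
// unused, and the cast that follows only silences the unused-variable
// warning.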
if constexpr (!has_lut) {
(void)lut; // unused
}

constexpr int bytes_per_128_weight_values = 16 * weight_nbit;

auto activation_data_byte_ptr = (char*)(activation_data);
@@ -99,6 +104,11 @@ void kernel_1x8x16_f32_neondot(
// Weights and activations are padded when prepared, so the
// reads are legal, even if on a partial tile
for (int n_idx = 0; n_idx < n; n_idx += 8) {
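// In the LUT layout, each tile of 8 columns begins with its 16-byte lookup
// table in the packed weight stream, so load it and step past it before the
// bit-packed index data for this tile.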
if constexpr (has_lut) {
lut = vld1q_s8((int8_t*)weight_data_byte_ptr);
weight_data_byte_ptr += 16;
}

// Set activation_ptr to start of activation qvals for row m_idx
activation_ptr = activation_data_byte_ptr;
float32x4_t res_0123 = vdupq_n_f32(0.0);
@@ -167,16 +177,33 @@ void kernel_1x8x16_f32_neondot(
// Each chunk is 64 values of unpacked data (4 cols x 16 vals/col).
// This comes out to (64 * weight_nbit) bits = 8 * weight_nbit
// bytes of bitpacked data
-torchao::bitpacking::vec_unpack_128_lowbit_values<weight_nbit>(
-weight_q_cols01_0,
-weight_q_cols23_0,
-weight_q_cols45_0,
-weight_q_cols67_0,
-weight_q_cols01_1,
-weight_q_cols23_1,
-weight_q_cols45_1,
-weight_q_cols67_1,
-(uint8_t*)weight_data_byte_ptr);

if constexpr (has_lut) {
torchao::bitpacking::vec_unpack_128_lowbit_values_with_lut<
weight_nbit>(
weight_q_cols01_0,
weight_q_cols23_0,
weight_q_cols45_0,
weight_q_cols67_0,
weight_q_cols01_1,
weight_q_cols23_1,
weight_q_cols45_1,
weight_q_cols67_1,
(uint8_t*)weight_data_byte_ptr,
lut);
} else {
torchao::bitpacking::vec_unpack_128_lowbit_values<weight_nbit>(
weight_q_cols01_0,
weight_q_cols23_0,
weight_q_cols45_0,
weight_q_cols67_0,
weight_q_cols01_1,
weight_q_cols23_1,
weight_q_cols45_1,
weight_q_cols67_1,
(uint8_t*)weight_data_byte_ptr);
}
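// Conceptually (the details live in torchao::bitpacking), the LUT branch
// above unpacks the weight_nbit indices into int8 lanes and then maps each
// lane through the 16-entry table, roughly one
//
//   values = vqtbl1q_s8(lut, vreinterpretq_u8_s8(indices));
//
// per vector, while the non-LUT branch decodes the bit-packed quantized
// values directly into the same eight registers.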

weight_data_byte_ptr += bytes_per_128_weight_values;

// Load 16 activation values