Skip to content

Commit b49bf38

Browse files
authored
[ET] correcting cpu ref quantize_per_channel logic to align with ATen (#12431)
# Context The quantize_per_channel was not perfectly aligned with the ATen implementation, and demonstrated errors when specifying different axes. This bug wasn't distinctly acknowledged given that the test cases only had one test for the whole operator. In order to align more closely with ATen, this change replaces the old `apply_over_dim_list` approach with a single-loop implementation that computes the channel index directly. # Changes We change the core logic for quantize_per_channel to more properly align with ATen's implementation, moving from the `apply_over_dim_list` approach to a single-loop implementation with direct channel index calculation. This also adds more comprehensive testing for quantize_per_channel so that a bug like this isn't missed again. Differential Revision: [D77746130](https://our.internmc.facebook.com/intern/diff/D77746130/)
1 parent bd48732 commit b49bf38

File tree

3 files changed

+263
-51
lines changed

3 files changed

+263
-51
lines changed

kernels/quantized/cpu/op_quantize.cpp

Lines changed: 23 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
* LICENSE file in the root directory of this source tree.
77
*/
88

9-
#include <executorch/kernels/portable/cpu/util/reduce_util.h>
109
#include <executorch/runtime/kernel/kernel_includes.h>
1110
#include <algorithm>
1211
#include <cinttypes>
@@ -282,55 +281,34 @@ Tensor& quantize_per_channel_out(
282281

283282
check_quantize_per_tensor_args(input, quant_min, quant_max, dtype, out);
284283

285-
// a list contains all dimensions except axis
286-
int64_t dims[kTensorDimensionLimit];
287-
for (int64_t i = 0; i < input.dim() - 1; i++) {
288-
if (i < axis) {
289-
dims[i] = i;
290-
} else {
291-
dims[i] = i - 1;
292-
}
293-
}
294284
const double* scale_data = scale.const_data_ptr<double>();
295285
const int64_t* zero_point_data = zero_point.const_data_ptr<int64_t>();
296286

297-
std::optional<executorch::aten::ArrayRef<int64_t>> optional_dim_list{
298-
executorch::aten::ArrayRef<int64_t>{dims, size_t(input.dim() - 1)}};
299-
300-
// Actual quantization logic
301-
// input, out are the input and output tensors
302-
// channel_ix is the index along the axis dimension. 0 <= channel_ix <
303-
// input.size(axis).
304-
// i.e. if the tensor has shape (N,C,H,W), axis being 1, then channel_ix
305-
// will be 0, 1, 2, ... C-1
306-
// in_ix is the flat index of the element you are quantizing.
307-
// in other words you are quantizing in_data[in_ix]
287+
// High-performance single loop with direct channel calculation
308288
#define QUANTIZE_IMPL(CTYPE_IN, CTYPE_OUT, out_dtype) \
309-
case ScalarType::out_dtype: \
310-
for (size_t channel_ix = 0; channel_ix < input.size(axis); ++channel_ix) { \
311-
double _scale = scale_data[channel_ix]; \
312-
int64_t _zero_point = zero_point_data[channel_ix]; \
313-
auto* out_data_ptr = out.mutable_data_ptr<CTYPE_OUT>(); \
314-
const auto* input_data_ptr = input.const_data_ptr<CTYPE_IN>(); \
315-
apply_over_dim_list( \
316-
[input_data_ptr, \
317-
out_data_ptr, \
318-
_scale, \
319-
_zero_point, \
320-
quant_min, \
321-
quant_max](size_t in_ix) { \
322-
out_data_ptr[in_ix] = quantize_val<CTYPE_OUT, CTYPE_IN>( \
323-
_scale, \
324-
_zero_point, \
325-
input_data_ptr[in_ix], \
326-
quant_min, \
327-
quant_max); \
328-
}, \
329-
input, \
330-
optional_dim_list, \
331-
channel_ix); \
289+
case ScalarType::out_dtype: { \
290+
auto* out_data_ptr = out.mutable_data_ptr<CTYPE_OUT>(); \
291+
const auto* input_data_ptr = input.const_data_ptr<CTYPE_IN>(); \
292+
const int64_t input_numel = input.numel(); \
293+
const int64_t axis_size = input.size(axis); \
294+
/* Calculate the stride pattern for efficient channel index calculation */ \
295+
int64_t axis_block_size = 1; \
296+
for (int64_t i = axis + 1; i < input.dim(); i++) { \
297+
axis_block_size *= input.size(i); \
332298
} \
333-
break;
299+
/* Single loop over all elements */ \
300+
for (int64_t i = 0; i < input_numel; i++) { \
301+
/* Calculate which channel this element belongs to */ \
302+
int64_t channel_idx = (i / axis_block_size) % axis_size; \
303+
/* Get quantization parameters for this channel */ \
304+
double _scale = scale_data[channel_idx]; \
305+
int64_t _zero_point = zero_point_data[channel_idx]; \
306+
/* Apply quantization */ \
307+
out_data_ptr[i] = quantize_val<CTYPE_OUT, CTYPE_IN>( \
308+
_scale, _zero_point, input_data_ptr[i], quant_min, quant_max); \
309+
} \
310+
} break;
311+
334312
#define CALCULATE_FLOAT_TYPE(CTYPE_IN, in_dtype) \
335313
case ScalarType::in_dtype: \
336314
switch (out.scalar_type()) { \

kernels/quantized/cpu/targets.bzl

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -51,12 +51,6 @@ _QUANT_OPS = (
5151
),
5252
op_target(
5353
name = "op_quantize",
54-
deps = [
55-
"//executorch/kernels/portable/cpu/util:reduce_util",
56-
],
57-
_aten_mode_deps = [
58-
"//executorch/kernels/portable/cpu/util:reduce_util_aten",
59-
],
6054
),
6155
)
6256

kernels/quantized/test/op_quantize_test.cpp

Lines changed: 240 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -206,3 +206,243 @@ TEST(OpQuantizeOutTest, QuantizePerChannel) {
206206

207207
EXPECT_TENSOR_EQ(out, expected);
208208
}
209+
210+
TEST(OpQuantizeOutTest, QuantizePerChannelAxis0) {
211+
TensorFactory<ScalarType::Float> tf_float;
212+
TensorFactory<ScalarType::Double> tf_double;
213+
TensorFactory<ScalarType::Long> tf_long;
214+
215+
Tensor input = tf_float.full({3, 2}, 4);
216+
Tensor scale = tf_double.make({3}, {0.5, 1.0, 2.0});
217+
Tensor zero_point = tf_long.make({3}, {100, 50, 25});
218+
int64_t quant_min = 0;
219+
int64_t quant_max = 255;
220+
221+
TensorFactory<ScalarType::Byte> tfo;
222+
Tensor out = tfo.zeros({3, 2});
223+
// Channel 0: 4 / 0.5 + 100 = 108
224+
// Channel 1: 4 / 1.0 + 50 = 54
225+
// Channel 2: 4 / 2.0 + 25 = 27
226+
Tensor expected = tfo.make({3, 2}, {108, 108, 54, 54, 27, 27});
227+
quantize_per_channel_out(
228+
input, scale, zero_point, 0, quant_min, quant_max, ScalarType::Byte, out);
229+
230+
EXPECT_TENSOR_EQ(out, expected);
231+
}
232+
233+
TEST(OpQuantizeOutTest, QuantizePerChannel3D) {
234+
TensorFactory<ScalarType::Float> tf_float;
235+
TensorFactory<ScalarType::Double> tf_double;
236+
TensorFactory<ScalarType::Long> tf_long;
237+
238+
// Test 3D tensor with axis=1 (middle dimension)
239+
Tensor input = tf_float.full({2, 3, 4}, 6);
240+
Tensor scale = tf_double.make({3}, {0.5, 1.0, 1.5});
241+
Tensor zero_point = tf_long.make({3}, {10, 20, 30});
242+
int64_t quant_min = -128;
243+
int64_t quant_max = 127;
244+
245+
TensorFactory<ScalarType::Char> tfo;
246+
Tensor out = tfo.zeros({2, 3, 4});
247+
// Channel 0: 6 / 0.5 + 10 = 22
248+
// Channel 1: 6 / 1.0 + 20 = 26
249+
// Channel 2: 6 / 1.5 + 30 = 34
250+
Tensor expected = tfo.make(
251+
{2, 3, 4},
252+
{
253+
22, 22, 22, 22, // First batch, channel 0
254+
26, 26, 26, 26, // First batch, channel 1
255+
34, 34, 34, 34, // First batch, channel 2
256+
22, 22, 22, 22, // Second batch, channel 0
257+
26, 26, 26, 26, // Second batch, channel 1
258+
34, 34, 34, 34 // Second batch, channel 2
259+
});
260+
quantize_per_channel_out(
261+
input, scale, zero_point, 1, quant_min, quant_max, ScalarType::Char, out);
262+
263+
EXPECT_TENSOR_EQ(out, expected);
264+
}
265+
266+
TEST(OpQuantizeOutTest, QuantizePerChannel4D) {
267+
TensorFactory<ScalarType::Float> tf_float;
268+
TensorFactory<ScalarType::Double> tf_double;
269+
TensorFactory<ScalarType::Long> tf_long;
270+
271+
// Test 4D tensor with axis=2 (typical conv weight layout: N,C,H,W)
272+
Tensor input = tf_float.full({2, 2, 3, 2}, 8);
273+
Tensor scale = tf_double.make({3}, {0.25, 0.5, 1.0});
274+
Tensor zero_point = tf_long.make({3}, {0, 10, 20});
275+
int64_t quant_min = -128;
276+
int64_t quant_max = 127;
277+
278+
TensorFactory<ScalarType::Char> tfo;
279+
Tensor out = tfo.zeros({2, 2, 3, 2});
280+
// Channel 0: 8 / 0.25 + 0 = 32
281+
// Channel 1: 8 / 0.5 + 10 = 26
282+
// Channel 2: 8 / 1.0 + 20 = 28
283+
std::vector<int8_t> expected_data;
284+
for (int n = 0; n < 2; n++) {
285+
for (int c = 0; c < 2; c++) {
286+
for (int h = 0; h < 3; h++) {
287+
for (int w = 0; w < 2; w++) {
288+
int8_t val = (h == 0) ? 32 : (h == 1) ? 26 : 28;
289+
expected_data.push_back(val);
290+
}
291+
}
292+
}
293+
}
294+
Tensor expected = tfo.make({2, 2, 3, 2}, expected_data);
295+
quantize_per_channel_out(
296+
input, scale, zero_point, 2, quant_min, quant_max, ScalarType::Char, out);
297+
298+
EXPECT_TENSOR_EQ(out, expected);
299+
}
300+
301+
TEST(OpQuantizeOutTest, QuantizePerChannelNegativeAxis) {
302+
TensorFactory<ScalarType::Float> tf_float;
303+
TensorFactory<ScalarType::Double> tf_double;
304+
TensorFactory<ScalarType::Long> tf_long;
305+
306+
Tensor input = tf_float.full({2, 3}, 5);
307+
Tensor scale = tf_double.make({3}, {0.5, 1.0, 2.0});
308+
Tensor zero_point = tf_long.make({3}, {0, 10, 20});
309+
int64_t quant_min = 0;
310+
int64_t quant_max = 255;
311+
312+
TensorFactory<ScalarType::Byte> tfo;
313+
Tensor out = tfo.zeros({2, 3});
314+
// Using axis=-1 should be equivalent to axis=1 for 2D tensor
315+
// Channel 0: 5 / 0.5 + 0 = 10
316+
// Channel 1: 5 / 1.0 + 10 = 15
317+
// Channel 2: 5 / 2.0 + 20 = 22 (rounded from 22.5)
318+
Tensor expected = tfo.make({2, 3}, {10, 15, 22, 10, 15, 22});
319+
quantize_per_channel_out(
320+
input,
321+
scale,
322+
zero_point,
323+
-1,
324+
quant_min,
325+
quant_max,
326+
ScalarType::Byte,
327+
out);
328+
329+
EXPECT_TENSOR_EQ(out, expected);
330+
}
331+
332+
TEST(OpQuantizeOutTest, QuantizePerChannelSingleChannel) {
333+
TensorFactory<ScalarType::Float> tf_float;
334+
TensorFactory<ScalarType::Double> tf_double;
335+
TensorFactory<ScalarType::Long> tf_long;
336+
337+
Tensor input = tf_float.full({3, 1, 4}, 7);
338+
Tensor scale = tf_double.make({1}, {0.5});
339+
Tensor zero_point = tf_long.make({1}, {128});
340+
int64_t quant_min = 0;
341+
int64_t quant_max = 255;
342+
343+
TensorFactory<ScalarType::Byte> tfo;
344+
Tensor out = tfo.zeros({3, 1, 4});
345+
// Single channel: 7 / 0.5 + 128 = 142
346+
Tensor expected = tfo.full({3, 1, 4}, 142);
347+
quantize_per_channel_out(
348+
input, scale, zero_point, 1, quant_min, quant_max, ScalarType::Byte, out);
349+
350+
EXPECT_TENSOR_EQ(out, expected);
351+
}
352+
353+
TEST(OpQuantizeOutTest, QuantizePerChannelDifferentInputTypes) {
354+
TensorFactory<ScalarType::Double> tf_double_input;
355+
TensorFactory<ScalarType::Double> tf_double;
356+
TensorFactory<ScalarType::Long> tf_long;
357+
358+
Tensor input = tf_double_input.full({2, 2}, 3.14159);
359+
Tensor scale = tf_double.make({2}, {0.01, 0.02});
360+
Tensor zero_point = tf_long.make({2}, {0, 100});
361+
int64_t quant_min = -128;
362+
int64_t quant_max = 127;
363+
364+
TensorFactory<ScalarType::Char> tfo;
365+
Tensor out = tfo.zeros({2, 2});
366+
// Channel 0: 3.14159 / 0.01 + 0 = 314 -> clamped to 127
367+
// Channel 1: 3.14159 / 0.02 + 100 = 257 -> clamped to 127
368+
Tensor expected = tfo.make({2, 2}, {127, 127, 127, 127});
369+
quantize_per_channel_out(
370+
input, scale, zero_point, 1, quant_min, quant_max, ScalarType::Char, out);
371+
372+
EXPECT_TENSOR_EQ(out, expected);
373+
}
374+
375+
TEST(OpQuantizeOutTest, QuantizePerChannelDifferentOutputTypes) {
376+
TensorFactory<ScalarType::Float> tf_float;
377+
TensorFactory<ScalarType::Double> tf_double;
378+
TensorFactory<ScalarType::Long> tf_long;
379+
380+
Tensor input = tf_float.full({2, 2}, 10);
381+
Tensor scale = tf_double.make({2}, {1.0, 2.0});
382+
Tensor zero_point = tf_long.make({2}, {1000, 2000});
383+
int64_t quant_min = -32768;
384+
int64_t quant_max = 32767;
385+
386+
// Test with 16-bit output
387+
TensorFactory<ScalarType::Short> tfo;
388+
Tensor out = tfo.zeros({2, 2});
389+
// Channel 0: 10 / 1.0 + 1000 = 1010
390+
// Channel 1: 10 / 2.0 + 2000 = 2005
391+
Tensor expected = tfo.make({2, 2}, {1010, 2005, 1010, 2005});
392+
quantize_per_channel_out(
393+
input,
394+
scale,
395+
zero_point,
396+
1,
397+
quant_min,
398+
quant_max,
399+
ScalarType::Short,
400+
out);
401+
402+
EXPECT_TENSOR_EQ(out, expected);
403+
}
404+
405+
TEST(OpQuantizeOutTest, QuantizePerChannelMixedValues) {
406+
TensorFactory<ScalarType::Float> tf_float;
407+
TensorFactory<ScalarType::Double> tf_double;
408+
TensorFactory<ScalarType::Long> tf_long;
409+
410+
// Test with different input values per position
411+
Tensor input = tf_float.make({2, 3}, {1.0, 2.0, 3.0, 4.0, 5.0, 6.0});
412+
Tensor scale = tf_double.make({3}, {0.5, 1.0, 1.5});
413+
Tensor zero_point = tf_long.make({3}, {10, 20, 30});
414+
int64_t quant_min = 0;
415+
int64_t quant_max = 255;
416+
417+
TensorFactory<ScalarType::Byte> tfo;
418+
Tensor out = tfo.zeros({2, 3});
419+
// Row 0: [1.0/0.5+10, 2.0/1.0+20, 3.0/1.5+30] = [12, 22, 32]
420+
// Row 1: [4.0/0.5+10, 5.0/1.0+20, 6.0/1.5+30] = [18, 25, 34]
421+
Tensor expected = tfo.make({2, 3}, {12, 22, 32, 18, 25, 34});
422+
quantize_per_channel_out(
423+
input, scale, zero_point, 1, quant_min, quant_max, ScalarType::Byte, out);
424+
425+
EXPECT_TENSOR_EQ(out, expected);
426+
}
427+
428+
TEST(OpQuantizeOutTest, QuantizePerChannelClampingBehavior) {
429+
TensorFactory<ScalarType::Float> tf_float;
430+
TensorFactory<ScalarType::Double> tf_double;
431+
TensorFactory<ScalarType::Long> tf_long;
432+
433+
// Test values that will exceed quant_min/quant_max bounds
434+
Tensor input = tf_float.make({1, 3}, {-100.0, 0.0, 100.0});
435+
Tensor scale = tf_double.make({3}, {1.0, 1.0, 1.0});
436+
Tensor zero_point = tf_long.make({3}, {0, 0, 0});
437+
int64_t quant_min = -10;
438+
int64_t quant_max = 10;
439+
440+
TensorFactory<ScalarType::Char> tfo;
441+
Tensor out = tfo.zeros({1, 3});
442+
// Values: [-100, 0, 100] should be clamped to [-10, 0, 10]
443+
Tensor expected = tfo.make({1, 3}, {-10, 0, 10});
444+
quantize_per_channel_out(
445+
input, scale, zero_point, 1, quant_min, quant_max, ScalarType::Char, out);
446+
447+
EXPECT_TENSOR_EQ(out, expected);
448+
}

0 commit comments

Comments
 (0)