From baa33cee9608a7e3de9362c4ec5f4964c509eca9 Mon Sep 17 00:00:00 2001
From: Supriya Rao
Date: Wed, 30 Sep 2020 14:57:56 -0700
Subject: [PATCH] [quant] PerChannelFloatQParams support for quint4x2 dtype

Summary:
Adds support for per-channel quantization using float qparams for the 4-bit dtype.
We use the new dispatch mechanism and the existing quantize/dequantize kernels, packing the 4-bit data according to the bit_width.
The size of a 4-bit quantized tensor is half that of an 8-bit quantized tensor.

Test Plan:
python test/test_quantization.py TestQuantizedTensor.test_quantize_per_channel_sub_byte

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: 97553da79ec9e1eb2a63516a5ce37a9747dd9264
Pull Request resolved: https://github.com/pytorch/pytorch/pull/45594
---
 .../native/quantized/affine_quantizer.cpp     | 20 ++------
 .../ATen/native/quantized/affine_quantizer.h  |  3 +-
 .../cpu/kernels/QuantizedOpKernels.cpp        | 48 ++++++++++++++-----
 test/quantization/test_quantized_tensor.py    | 48 ++++++++++++++++++-
 4 files changed, 88 insertions(+), 31 deletions(-)

diff --git a/aten/src/ATen/native/quantized/affine_quantizer.cpp b/aten/src/ATen/native/quantized/affine_quantizer.cpp
index 1e0285bc1426..1d0aed1174aa 100644
--- a/aten/src/ATen/native/quantized/affine_quantizer.cpp
+++ b/aten/src/ATen/native/quantized/affine_quantizer.cpp
@@ -174,7 +174,7 @@ Tensor quantize_tensor_per_channel_float_qparams(
   checkSameDevice(fn_name, rtensor, qtensor);
   checkSameSize(fn_name, qtensor, rtensor);
 
-  AT_DISPATCH_QINT_TYPES(qtensor.scalar_type(), fn_name, [&]() {
+  AT_DISPATCH_QINT_AND_SUB_BYTE_TYPES(qtensor.scalar_type(), fn_name, [&]() {
     checkQuantizedTensor<scalar_t>(fn_name, qtensor);
   });
 
@@ -269,7 +269,7 @@ Tensor dequantize_tensor_per_channel_float_qparams(
   checkSameDevice(fn_name, rtensor, qtensor);
   checkSameSize(fn_name, qtensor, rtensor);
 
-  AT_DISPATCH_QINT_TYPES(qtensor.scalar_type(), fn_name, [&]() {
+  AT_DISPATCH_QINT_AND_SUB_BYTE_TYPES(qtensor.scalar_type(), fn_name, [&]() {
     checkQuantizedTensor<scalar_t>(fn_name, qtensor);
   });
 
@@ -410,17 +410,13 @@ CAFFE2_API float dequantize_val(double scale, int64_t zero_point, T value) {
  * Note: For the case of embedding quantization we will set zero_point
  * to (-Xmin/scale), where Xmin is the min value in input tensor row.
  */
-template <typename T>
-T quantize_val_float_qparams(float scale, float zero_point, float value) {
-  int64_t qvalue;
+int quantize_val_float_qparams(float scale, float zero_point, float value, int qmin, int qmax) {
+  int qvalue;
 
-  // TODO make sure qmax and qmin for dtypes other than int8, uint8 is correctly defined.
-  constexpr int64_t qmin = std::numeric_limits<typename T::underlying>::min();
-  constexpr int64_t qmax = std::numeric_limits<typename T::underlying>::max();
   float inv_scale = scale == 0 ? 1.0f : 1.0f / scale;
   qvalue = lrintf(value * inv_scale + zero_point);
   qvalue = std::max(qmin, std::min(qvalue, qmax));
-  return static_cast<T>(qvalue);
+  return qvalue;
 }
 
 template <typename SRC_T, typename DST_T>
@@ -507,11 +503,5 @@ requantize_from_int(double, int64_t, int64_t);
 template CAFFE2_API qint32
 requantize_from_int<qint32>(double, int64_t, int64_t);
 
-template CAFFE2_API qint8
-quantize_val_float_qparams<qint8>(float scale, float zero_point, float value);
-template CAFFE2_API quint8
-quantize_val_float_qparams<quint8>(float scale, float zero_point, float value);
-template CAFFE2_API qint32
-quantize_val_float_qparams<qint32>(float scale, float zero_point, float value);
 } // namespace native
 } // namespace at
diff --git a/aten/src/ATen/native/quantized/affine_quantizer.h b/aten/src/ATen/native/quantized/affine_quantizer.h
index 2991922c7f86..670b119652cd 100644
--- a/aten/src/ATen/native/quantized/affine_quantizer.h
+++ b/aten/src/ATen/native/quantized/affine_quantizer.h
@@ -158,8 +158,7 @@ template <typename SRC_T, typename DST_T>
 CAFFE2_API DST_T
 requantize_from_int(double multiplier, int64_t zero_point, int64_t src);
 
-template <typename T>
-CAFFE2_API T quantize_val_float_qparams(float scale, float zero_point, float value);
+int quantize_val_float_qparams(float scale, float zero_point, float value, int qmin, int qmax);
 
 } // namespace native
 } // namespace at
diff --git a/aten/src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp b/aten/src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp
index bae9fe650046..a65e9f00f1d8 100644
--- a/aten/src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp
+++ b/aten/src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp
@@ -2592,7 +2592,8 @@ void dequantize_per_channel_affine_kernel(
       Tensor rtensor,
       Tensor scales,
       Tensor zero_points,
-      int64_t axis) {
+      int64_t axis,
+      int bit_width=8) {
 
   // For contiguous tensors, e.g. NCHW, arbitrary axis can be used.
   // For channels_last/3d however axis == 0 or 1.
@@ -2611,6 +2612,7 @@ void dequantize_per_channel_affine_kernel(
   check_tensor_memory_format(qtensor, rtensor);
   const auto* qd = qtensor.data_ptr<T>();
   float* rd = rtensor.data_ptr<float>();
+  const auto elem_per_byte = 8 / bit_width;
   if (axis == 1 && (rtensor.is_contiguous(MemoryFormat::ChannelsLast) ||
       rtensor.is_contiguous(MemoryFormat::ChannelsLast3d))) {
     for (auto b = 0; b < batches; ++b) {
@@ -2619,8 +2621,12 @@ void dequantize_per_channel_affine_kernel(
           auto i = b * channel * elements_per_channel + e * channel + c;
           // We need to convert the qint8 value to float to ensure the
           // subtraction subexpression returns a float
-          rd[i] = (static_cast<float>(qd[i].val_) - zero_points_data[c]) *
-              scales_data[c];
+          auto qvalue = qd[i / elem_per_byte].val_;
+          if (bit_width < 8) {
+            qvalue >>= (i % elem_per_byte) * bit_width;
+            qvalue &= (1 << bit_width) - 1;
+          }
+          rd[i] = (static_cast<float>(qvalue) - zero_points_data[c]) * scales_data[c];
         }
       }
     }
@@ -2632,8 +2638,12 @@ void dequantize_per_channel_affine_kernel(
               c * elements_per_channel + e;
           // We need to convert the qint8 value to float to ensure the
           // subtraction subexpression returns a float
-          rd[i] = (static_cast<float>(qd[i].val_) - zero_points_data[c]) *
-              scales_data[c];
+          auto qvalue = qd[i / elem_per_byte].val_;
+          if (bit_width < 8) {
+            qvalue >>= (i % elem_per_byte) * bit_width;
+            qvalue &= (1 << bit_width) - 1;
+          }
+          rd[i] = (static_cast<float>(qvalue) - zero_points_data[c]) * scales_data[c];
         }
       }
     }
@@ -2667,7 +2677,7 @@ void quantize_tensor_per_channel_float_qparams_cpu(
   TORCH_CHECK(rtensor.is_contiguous() || (axis <=1),
       "If tensor is channels_last contig then per channel quantization "
       "is supported only for axis = 0 or 1.");
-  AT_DISPATCH_QINT_TYPES(
+  AT_DISPATCH_QINT_AND_SUB_BYTE_TYPES(
       qtensor.scalar_type(), "quantize_tensor_per_channel_float_qparams_cpu", [&]() {
         int64_t batches = size_to_dim_(axis, rtensor.sizes());
         int64_t elements_per_channel =
@@ -2677,15 +2687,22 @@ void quantize_tensor_per_channel_float_qparams_cpu(
         auto zero_points_data = zero_points.data_ptr<float>();
         check_tensor_memory_format(rtensor, qtensor);
         const float* rdata = rtensor.data_ptr<float>();
-        auto qdata = qtensor.data_ptr<scalar_t>();
+        auto qdata = reinterpret_cast<underlying_t*>(qtensor.data_ptr<scalar_t>());
+        const auto elem_per_byte = CHAR_BIT / bit_width;
+        int qvalue = 0;
         if (axis == 1 && (rtensor.is_contiguous(MemoryFormat::ChannelsLast) ||
             rtensor.is_contiguous(MemoryFormat::ChannelsLast3d))) {
           for (auto b = 0; b < batches; ++b) {
            for (auto e = 0; e < elements_per_channel; ++e) {
              for (auto c = 0; c < channel; ++c) {
                auto i = b * channel * elements_per_channel + e * channel + c;
-                qdata[i] = quantize_val_float_qparams<scalar_t>(
-                    scales_data[c], zero_points_data[c], rdata[i]);
+                qvalue = quantize_val_float_qparams(
+                    scales_data[c], zero_points_data[c], rdata[i], quant_min, quant_max);
+                if (i % elem_per_byte == 0) {
+                  qdata[i / elem_per_byte] = static_cast<underlying_t>(qvalue);
+                } else {
+                  qdata[i / elem_per_byte] |= static_cast<underlying_t>(qvalue << ((i % elem_per_byte) * bit_width));
+                }
              }
            }
          }
@@ -2695,8 +2712,13 @@ void quantize_tensor_per_channel_float_qparams_cpu(
              for (auto e = 0; e < elements_per_channel; ++e) {
                auto i = b * channel * elements_per_channel +
                    c * elements_per_channel + e;
-                qdata[i] = quantize_val_float_qparams<scalar_t>(
-                    scales_data[c], zero_points_data[c], rdata[i]);
+                qvalue = quantize_val_float_qparams(
+                    scales_data[c], zero_points_data[c], rdata[i], quant_min, quant_max);
+                if (i % elem_per_byte == 0) {
+                  qdata[i / elem_per_byte] = static_cast<underlying_t>(qvalue);
+                } else {
+                  qdata[i / elem_per_byte] |= static_cast<underlying_t>(qvalue << ((i % elem_per_byte) * bit_width));
+                }
              }
            }
          }
@@ -2710,9 +2732,9 @@ void dequantize_tensor_per_channel_float_qparams_cpu(
     Tensor scales,
     Tensor zero_points,
     int64_t axis) {
-  AT_DISPATCH_QINT_TYPES(
+  AT_DISPATCH_QINT_AND_SUB_BYTE_TYPES(
       qtensor.scalar_type(), "dequantize_tensor_per_channel_float_qparams_cpu", [&]() {
-        dequantize_per_channel_affine_kernel<scalar_t>(qtensor, rtensor, scales, zero_points, axis);
+        dequantize_per_channel_affine_kernel<scalar_t>(qtensor, rtensor, scales, zero_points, axis, bit_width);
       });
 }
 
diff --git a/test/quantization/test_quantized_tensor.py b/test/quantization/test_quantized_tensor.py
index d42b9674bb0f..37ecb4fb42a5 100644
--- a/test/quantization/test_quantized_tensor.py
+++ b/test/quantization/test_quantized_tensor.py
@@ -360,6 +360,50 @@ def quantize_ref(data, scales, zero_points):
         zero_points = torch.tensor([0.1, 0.2, 1.], dtype=torch.float)
         self._test_quantize_per_channel(r, scales, zero_points, 0, True)
 
+    def test_quantize_per_channel_sub_byte(self):
+        """ Tests the per channel quantization scheme for 4-bit qtensors.
+        The scale and zero point for this have to be in floating point. """
+        r = torch.rand(3, 2, dtype=torch.float) * 4
+        scales = torch.tensor([0.2, 0.03, 0.1], dtype=torch.float)
+        zero_points = torch.tensor([0.1, 0.2, 0.3], dtype=torch.float)
+        qr = torch.quantize_per_channel(r, scales, zero_points, 0, torch.quint4x2)
+
+        def _get_qranges(bit_width):
+            if bit_width == 4:
+                return 0, 15
+
+        def _quantize_per_channel_sub_byte_ref(data, scales, zero_points, axis, bit_width):
+            dims = data.size()
+            data = data.view(-1, dims[axis], np.prod(dims[axis + 1:]))
+            qtensor_size = math.ceil(data.numel() / 2)
+            res = torch.empty(qtensor_size, dtype=torch.uint8)
+            elem_per_byte = 8 / bit_width
+            quant_min, quant_max = _get_qranges(bit_width)
+            for i in range(data.size()[0]):
+                for j in range(data.size()[1]):
+                    for k in range(data.size()[2]):
+                        inv_scale = 1.0 / scales[j]
+                        index = i * data.size()[1] * data.size()[2] + j * data.size()[2] + k
+                        qvalue = np.clip(
+                            np.round(data[i][j][k] * inv_scale + zero_points[j]), quant_min, quant_max).to(dtype=torch.int)
+                        res_idx = int(index / elem_per_byte)
+                        if (index % elem_per_byte == 0):
+                            res[res_idx] = qvalue
+                        else:
+                            res[res_idx] |= (qvalue << ((index % elem_per_byte) * bit_width))
+            return res
+
+        ref_res = _quantize_per_channel_sub_byte_ref(r, scales, zero_points, 0, 4)
+        self.assertTrue(np.allclose(qr.int_repr(), ref_res))
+
+        # Check 4D tensor with non-zero axis.
+        r = torch.rand(3, 2, 4, 5, dtype=torch.float) * 4
+        scales = torch.tensor([0.2, 0.03], dtype=torch.float)
+        zero_points = torch.tensor([0.1, 0.2], dtype=torch.float)
+        qr = torch.quantize_per_channel(r, scales, zero_points, axis=1, dtype=torch.quint4x2)
+        ref_res = _quantize_per_channel_sub_byte_ref(r, scales, zero_points, 1, 4)
+        self.assertTrue(np.allclose(qr.int_repr(), ref_res))
+
     def test_qtensor_permute(self):
         scale = 0.02
         zero_point = 1
@@ -447,7 +491,9 @@ def test_qtensor_per_channel_load_save(self):
         scales = torch.rand(10, dtype=torch.double) * 0.02 + 0.01
         zero_points = torch.round(torch.rand(10) * 20 + 1).to(torch.long)
         # quint32, cuda is not supported yet
-        for dtype in [torch.quint8, torch.qint8]:
+        for dtype in [torch.quint8, torch.qint8, torch.quint4x2]:
+            if dtype == torch.quint4x2:
+                zero_points = torch.ones(10, dtype=torch.float)
             qr = torch.quantize_per_channel(r, scales, zero_points, 1, dtype)
             with tempfile.NamedTemporaryFile() as f:
                 # Serializing and Deserializing Tensor
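
For reviewers who want to try the change out, here is a minimal usage sketch (not part of the patch) based on the new test above. It quantizes a small float tensor per channel to torch.quint4x2 and inspects the packed uint8 storage; each byte holds two 4-bit values, low nibble first, as implemented by the quantize kernel in this patch. The tensor values and shapes are illustrative only, and this requires a build that includes the patch.

```python
import torch

# Per-channel quantization with float scales and float zero points,
# packing two 4-bit values into each byte of storage.
r = torch.rand(3, 2, dtype=torch.float) * 4
scales = torch.tensor([0.2, 0.03, 0.1], dtype=torch.float)
zero_points = torch.tensor([0.1, 0.2, 0.3], dtype=torch.float)

qr = torch.quantize_per_channel(r, scales, zero_points, axis=0, dtype=torch.quint4x2)

# int_repr() exposes the packed uint8 storage: 3 * 2 = 6 four-bit values -> 3 bytes,
# half the size of the equivalent quint8 tensor.
packed = qr.int_repr()
print(packed.dtype, packed.numel())  # torch.uint8 3

# Dequantization uses the per-channel float qparams to recover approximate floats.
print(qr.dequantize())
```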