[quant] PerChannelFloatQParams support for quint4x2 dtype
Summary:
Adds support for per-channel quantization using float qparams for the 4-bit dtype (quint4x2).
We use the new dispatch mechanism and reuse the existing quantize/dequantize kernels, packing
the 4-bit data according to the bit_width.
A 4-bit quantized tensor is half the size of the corresponding 8-bit quantized tensor.

Test Plan:
python test/test_quantization.py TestQuantizedTensor.test_quantize_per_channel_sub_byte

ghstack-source-id: 97553da79ec9e1eb2a63516a5ce37a9747dd9264
Pull Request resolved: #45594
supriyar committed Sep 30, 2020
1 parent 6e63f37 commit baa33ce
Showing 4 changed files with 88 additions and 31 deletions.
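For orientation, here is a minimal usage sketch of the new 4-bit path. It mirrors the call made in test_quantize_per_channel_sub_byte below; the size comment restates the summary's halved-storage claim rather than output captured from this commit.

import torch

# Per-channel quantization to 4 bits with float qparams, mirroring the new test.
# zero_points are floats (conceptually -Xmin/scale), not integers.
r = torch.rand(3, 2, dtype=torch.float) * 4
scales = torch.tensor([0.2, 0.03, 0.1], dtype=torch.float)
zero_points = torch.tensor([0.1, 0.2, 0.3], dtype=torch.float)
qr = torch.quantize_per_channel(r, scales, zero_points, 0, torch.quint4x2)

# Two 4-bit values pack into each byte, so the packed storage holds
# ceil(3 * 2 / 2) = 3 bytes -- half the size of the quint8 equivalent.
print(qr.int_repr().numel())  # expected: 3
print(qr.dequantize())        # approximate reconstruction of r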
20 changes: 5 additions & 15 deletions aten/src/ATen/native/quantized/affine_quantizer.cpp
@@ -174,7 +174,7 @@ Tensor quantize_tensor_per_channel_float_qparams(
checkSameDevice(fn_name, rtensor, qtensor);
checkSameSize(fn_name, qtensor, rtensor);

-  AT_DISPATCH_QINT_TYPES(qtensor.scalar_type(), fn_name, [&]() {
+  AT_DISPATCH_QINT_AND_SUB_BYTE_TYPES(qtensor.scalar_type(), fn_name, [&]() {
checkQuantizedTensor<scalar_t>(fn_name, qtensor);
});

@@ -269,7 +269,7 @@ Tensor dequantize_tensor_per_channel_float_qparams(
checkSameDevice(fn_name, rtensor, qtensor);
checkSameSize(fn_name, qtensor, rtensor);

-  AT_DISPATCH_QINT_TYPES(qtensor.scalar_type(), fn_name, [&]() {
+  AT_DISPATCH_QINT_AND_SUB_BYTE_TYPES(qtensor.scalar_type(), fn_name, [&]() {
checkQuantizedTensor<scalar_t>(fn_name, qtensor);
});

@@ -410,17 +410,13 @@ CAFFE2_API float dequantize_val(double scale, int64_t zero_point, T value) {
* Note: For the case of embedding quantization we will set zero_point
* to (-Xmin/scale), where Xmin is the min value in input tensor row.
*/
-template <typename T>
-T quantize_val_float_qparams(float scale, float zero_point, float value) {
-  int64_t qvalue;
+int quantize_val_float_qparams(float scale, float zero_point, float value, int qmin, int qmax) {
+  int qvalue;

-  // TODO make sure qmax and qmin for dtypes other than int8, uint8 is correctly defined.
-  constexpr int64_t qmin = std::numeric_limits<typename T::underlying>::min();
-  constexpr int64_t qmax = std::numeric_limits<typename T::underlying>::max();
float inv_scale = scale == 0 ? 1.0f : 1.0f / scale;
qvalue = lrintf(value * inv_scale + zero_point);
return static_cast<T>(qvalue);
return qvalue;
}

template <typename SRC_T, typename DST_T>
@@ -507,11 +503,5 @@ requantize_from_int<quint8>(double, int64_t, int64_t);
template CAFFE2_API qint32
requantize_from_int<qint32>(double, int64_t, int64_t);

-template CAFFE2_API qint8
-quantize_val_float_qparams<qint8>(float scale, float zero_point, float value);
-template CAFFE2_API quint8
-quantize_val_float_qparams<quint8>(float scale, float zero_point, float value);
-template CAFFE2_API qint32
-quantize_val_float_qparams<qint32>(float scale, float zero_point, float value);
} // namespace native
} // namespace at
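The old templated helper derived qmin/qmax from the dtype's underlying type; the new free function takes the bounds explicitly so sub-byte ranges (e.g. [0, 15] for 4-bit, per _get_qranges in the test below) can be passed in. A rough Python restatement of its arithmetic, for reference only:

def quantize_val_float_qparams(scale, zero_point, value, qmin, qmax):
    # Guard against scale == 0, as the C++ code does.
    inv_scale = 1.0 if scale == 0 else 1.0 / scale
    # Round to nearest, then clamp into [qmin, qmax].
    qvalue = round(value * inv_scale + zero_point)
    return max(qmin, min(int(qvalue), qmax))

# e.g. scale=0.2, zero_point=0.1, 4-bit range [0, 15]:
# quantize_val_float_qparams(0.2, 0.1, 1.0, 0, 15) -> round(5.1) = 5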
3 changes: 1 addition & 2 deletions aten/src/ATen/native/quantized/affine_quantizer.h
@@ -158,8 +158,7 @@ template <typename DST_T>
CAFFE2_API DST_T
requantize_from_int(double multiplier, int64_t zero_point, int64_t src);

-template <typename T>
-CAFFE2_API T quantize_val_float_qparams(float scale, float zero_point, float value);
+int quantize_val_float_qparams(float scale, float zero_point, float value, int qmin, int qmax);

} // namespace native
} // namespace at
48 changes: 35 additions & 13 deletions aten/src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp
@@ -2592,7 +2592,8 @@ void dequantize_per_channel_affine_kernel(
Tensor rtensor,
Tensor scales,
Tensor zero_points,
-    int64_t axis) {
+    int64_t axis,
+    int bit_width=8) {

// For contiguous tensors, e.g. NCHW, arbitrary axis can be used.
// For channels_last/3d however axis == 0 or 1.
@@ -2611,6 +2612,7 @@
check_tensor_memory_format(qtensor, rtensor);
const auto* qd = qtensor.data_ptr<Q>();
float* rd = rtensor.data_ptr<float>();
+  const auto elem_per_byte = 8 / bit_width;
if (axis == 1 && (rtensor.is_contiguous(MemoryFormat::ChannelsLast) ||
rtensor.is_contiguous(MemoryFormat::ChannelsLast3d))) {
for (auto b = 0; b < batches; ++b) {
@@ -2619,8 +2621,12 @@
auto i = b * channel * elements_per_channel + e * channel + c;
// We need to convert the qint8 value to float to ensure the
// subtraction subexpression returns a float
-          rd[i] = (static_cast<float>(qd[i].val_) - zero_points_data[c]) *
-              scales_data[c];
+          auto qvalue = qd[i / elem_per_byte].val_;
+          if (bit_width < 8) {
+            qvalue >>= (i % elem_per_byte) * bit_width;
+            qvalue &= (1 << bit_width) - 1;
+          }
+          rd[i] = (static_cast<float>(qvalue) - zero_points_data[c]) * scales_data[c];
        }
}
}
}
@@ -2632,8 +2638,12 @@
c * elements_per_channel + e;
// We need to convert the qint8 value to float to ensure the
// subtraction subexpression returns a float
-          rd[i] = (static_cast<float>(qd[i].val_) - zero_points_data[c]) *
-              scales_data[c];
+          auto qvalue = qd[i / elem_per_byte].val_;
+          if (bit_width < 8) {
+            qvalue >>= (i % elem_per_byte) * bit_width;
+            qvalue &= (1 << bit_width) - 1;
+          }
+          rd[i] = (static_cast<float>(qvalue) - zero_points_data[c]) * scales_data[c];
        }
}
}
}
@@ -2667,7 +2677,7 @@ void quantize_tensor_per_channel_float_qparams_cpu(
TORCH_CHECK(rtensor.is_contiguous() || (axis <=1),
"If tensor is channels_last contig then per channel quantization "
"is supported only for axis = 0 or 1.");
-  AT_DISPATCH_QINT_TYPES(
+  AT_DISPATCH_QINT_AND_SUB_BYTE_TYPES(
qtensor.scalar_type(), "quantize_tensor_per_channel_float_qparams_cpu", [&]() {
int64_t batches = size_to_dim_(axis, rtensor.sizes());
int64_t elements_per_channel =
@@ -2677,15 +2687,22 @@
auto zero_points_data = zero_points.data_ptr<float>();
check_tensor_memory_format(rtensor, qtensor);
const float* rdata = rtensor.data_ptr<float>();
-        auto qdata = qtensor.data_ptr<scalar_t>();
+        auto qdata = reinterpret_cast<underlying_t*>(qtensor.data_ptr<scalar_t>());
+        const auto elem_per_byte = CHAR_BIT / bit_width;
+        int qvalue = 0;
if (axis == 1 && (rtensor.is_contiguous(MemoryFormat::ChannelsLast) ||
rtensor.is_contiguous(MemoryFormat::ChannelsLast3d))) {
for (auto b = 0; b < batches; ++b) {
for (auto e = 0; e < elements_per_channel; ++e) {
for (auto c = 0; c < channel; ++c) {
auto i = b * channel * elements_per_channel + e * channel + c;
-                qdata[i] = quantize_val_float_qparams<scalar_t>(
-                    scales_data[c], zero_points_data[c], rdata[i]);
+                qvalue = quantize_val_float_qparams(
+                    scales_data[c], zero_points_data[c], rdata[i], quant_min, quant_max);
+                if (i % elem_per_byte == 0) {
+                  qdata[i / elem_per_byte] = static_cast<underlying_t>(qvalue);
+                } else {
+                  qdata[i / elem_per_byte] |= static_cast<underlying_t>(qvalue << ((i % elem_per_byte) * bit_width));
+                }
              }
}
}
}
@@ -2695,8 +2712,13 @@
for (auto e = 0; e < elements_per_channel; ++e) {
auto i = b * channel * elements_per_channel +
c * elements_per_channel + e;
-              qdata[i] = quantize_val_float_qparams<scalar_t>(
-                  scales_data[c], zero_points_data[c], rdata[i]);
+              qvalue = quantize_val_float_qparams(
+                  scales_data[c], zero_points_data[c], rdata[i], quant_min, quant_max);
+              if (i % elem_per_byte == 0) {
+                qdata[i / elem_per_byte] = static_cast<underlying_t>(qvalue);
+              } else {
+                qdata[i / elem_per_byte] |= static_cast<underlying_t>(qvalue << ((i % elem_per_byte) * bit_width));
+              }
            }
}
}
}
@@ -2710,9 +2732,9 @@ void dequantize_tensor_per_channel_float_qparams_cpu(
Tensor scales,
Tensor zero_points,
int64_t axis) {
-  AT_DISPATCH_QINT_TYPES(
+  AT_DISPATCH_QINT_AND_SUB_BYTE_TYPES(
qtensor.scalar_type(), "dequantize_tensor_per_channel_float_qparams_cpu", [&]() {
-        dequantize_per_channel_affine_kernel<float, float, scalar_t>(qtensor, rtensor, scales, zero_points, axis);
+        dequantize_per_channel_affine_kernel<float, float, scalar_t>(qtensor, rtensor, scales, zero_points, axis, bit_width);
});
}

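Both kernels above share the same sub-byte layout: element i lands in byte i / elem_per_byte, at bit offset (i % elem_per_byte) * bit_width. A standalone Python sketch of that pack/unpack logic (helper names are hypothetical; bit_width = 4 is the only sub-byte case wired up here):

def pack_sub_byte(qvalues, bit_width=4):
    # Mirrors the quantize kernel: the first value in a byte overwrites it,
    # later values are OR-ed in at increasing bit offsets.
    elem_per_byte = 8 // bit_width
    packed = [0] * ((len(qvalues) + elem_per_byte - 1) // elem_per_byte)
    for i, q in enumerate(qvalues):
        if i % elem_per_byte == 0:
            packed[i // elem_per_byte] = q
        else:
            packed[i // elem_per_byte] |= q << ((i % elem_per_byte) * bit_width)
    return packed

def unpack_sub_byte(packed, n, bit_width=4):
    # Mirrors the dequantize kernel: shift right, then mask off bit_width bits.
    elem_per_byte = 8 // bit_width
    mask = (1 << bit_width) - 1
    return [(packed[i // elem_per_byte] >> ((i % elem_per_byte) * bit_width)) & mask
            for i in range(n)]

assert unpack_sub_byte(pack_sub_byte([5, 12, 3]), 3) == [5, 12, 3]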
48 changes: 47 additions & 1 deletion test/quantization/test_quantized_tensor.py
@@ -360,6 +360,50 @@ def quantize_ref(data, scales, zero_points):
zero_points = torch.tensor([0.1, 0.2, 1.], dtype=torch.float)
self._test_quantize_per_channel(r, scales, zero_points, 0, True)

+    def test_quantize_per_channel_sub_byte(self):
+        """ Tests the per channel quantization scheme for 4-bit qtensors.
+        The scale and zero point for this have to be in floating point. """
+        r = torch.rand(3, 2, dtype=torch.float) * 4
+        scales = torch.tensor([0.2, 0.03, 0.1], dtype=torch.float)
+        zero_points = torch.tensor([0.1, 0.2, 0.3], dtype=torch.float)
+        qr = torch.quantize_per_channel(r, scales, zero_points, 0, torch.quint4x2)
+
+        def _get_qranges(bit_width):
+            if bit_width == 4:
+                return 0, 15
+
+        def _quantize_per_channel_sub_byte_ref(data, scales, zero_points, axis, bit_width):
+            dims = data.size()
+            data = data.view(-1, dims[axis], np.prod(dims[axis + 1:]))
+            qtensor_size = math.ceil(data.numel() / 2)
+            res = torch.empty(qtensor_size, dtype=torch.uint8)
+            elem_per_byte = 8 / bit_width
+            quant_min, quant_max = _get_qranges(bit_width)
+            for i in range(data.size()[0]):
+                for j in range(data.size()[1]):
+                    for k in range(data.size()[2]):
+                        inv_scale = 1.0 / scales[j]
+                        index = i * data.size()[1] * data.size()[2] + j * data.size()[2] + k
+                        qvalue = np.clip(
+                            np.round(data[i][j][k] * inv_scale + zero_points[j]), quant_min, quant_max).to(dtype=torch.int)
+                        res_idx = int(index / elem_per_byte)
+                        if (index % elem_per_byte == 0):
+                            res[res_idx] = qvalue
+                        else:
+                            res[res_idx] |= (qvalue << ((index % elem_per_byte) * bit_width))
+            return res
+
+        ref_res = _quantize_per_channel_sub_byte_ref(r, scales, zero_points, 0, 4)
+        self.assertTrue(np.allclose(qr.int_repr(), ref_res))
+
+        # Check 4D tensor with non-zero axis.
+        r = torch.rand(3, 2, 4, 5, dtype=torch.float) * 4
+        scales = torch.tensor([0.2, 0.03], dtype=torch.float)
+        zero_points = torch.tensor([0.1, 0.2], dtype=torch.float)
+        qr = torch.quantize_per_channel(r, scales, zero_points, axis=1, dtype=torch.quint4x2)
+        ref_res = _quantize_per_channel_sub_byte_ref(r, scales, zero_points, 1, 4)
+        self.assertTrue(np.allclose(qr.int_repr(), ref_res))
+
def test_qtensor_permute(self):
scale = 0.02
zero_point = 1
@@ -447,7 +491,9 @@ def test_qtensor_per_channel_load_save(self):
scales = torch.rand(10, dtype=torch.double) * 0.02 + 0.01
zero_points = torch.round(torch.rand(10) * 20 + 1).to(torch.long)
# quint32, cuda is not supported yet
-        for dtype in [torch.quint8, torch.qint8]:
+        for dtype in [torch.quint8, torch.qint8, torch.quint4x2]:
+            if dtype == torch.quint4x2:
+                zero_points = torch.ones(10, dtype=torch.float)
qr = torch.quantize_per_channel(r, scales, zero_points, 1, dtype)
with tempfile.NamedTemporaryFile() as f:
# Serializing and Deserializing Tensor
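For context, the extended load/save test implies a round trip along these lines. This is a sketch assuming the test's setup (the input shape with 10 channels along axis 1 is an assumption, since the tensor construction sits above the visible hunk):

import tempfile
import torch

r = torch.rand(20, 10, dtype=torch.float) * 4 - 2  # assumed shape: 10 channels on axis 1
scales = torch.rand(10, dtype=torch.double) * 0.02 + 0.01
zero_points = torch.ones(10, dtype=torch.float)  # float qparams required for quint4x2
qr = torch.quantize_per_channel(r, scales, zero_points, 1, torch.quint4x2)
with tempfile.NamedTemporaryFile() as f:
    torch.save(qr, f)
    f.seek(0)
    qr2 = torch.load(f)
    # The packed 4-bit payload should survive the round trip.
    assert torch.equal(qr.int_repr(), qr2.int_repr())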
