[quant] PerChannelFloatQParams support for quint4x2 dtype #45594

Closed · wants to merge 5 commits
20 changes: 5 additions & 15 deletions aten/src/ATen/native/quantized/affine_quantizer.cpp
@@ -174,7 +174,7 @@ Tensor quantize_tensor_per_channel_float_qparams(
checkSameDevice(fn_name, rtensor, qtensor);
checkSameSize(fn_name, qtensor, rtensor);

AT_DISPATCH_QINT_TYPES(qtensor.scalar_type(), fn_name, [&]() {
AT_DISPATCH_QINT_AND_SUB_BYTE_TYPES(qtensor.scalar_type(), fn_name, [&]() {
checkQuantizedTensor<scalar_t>(fn_name, qtensor);
});

@@ -269,7 +269,7 @@ Tensor dequantize_tensor_per_channel_float_qparams(
checkSameDevice(fn_name, rtensor, qtensor);
checkSameSize(fn_name, qtensor, rtensor);

AT_DISPATCH_QINT_TYPES(qtensor.scalar_type(), fn_name, [&]() {
AT_DISPATCH_QINT_AND_SUB_BYTE_TYPES(qtensor.scalar_type(), fn_name, [&]() {
checkQuantizedTensor<scalar_t>(fn_name, qtensor);
});

@@ -410,17 +410,13 @@ CAFFE2_API float dequantize_val(double scale, int64_t zero_point, T value) {
* Note: For the case of embedding quantization we will set zero_point
* to (-Xmin/scale), where Xmin is the min value in input tensor row.
*/
template <typename T>
T quantize_val_float_qparams(float scale, float zero_point, float value) {
int64_t qvalue;
int quantize_val_float_qparams(float scale, float zero_point, float value, int qmin, int qmax) {
int qvalue;

// TODO make sure qmax and qmin for dtypes other than int8, uint8 is correctly defined.
constexpr int64_t qmin = std::numeric_limits<typename T::underlying>::min();
constexpr int64_t qmax = std::numeric_limits<typename T::underlying>::max();
float inv_scale = scale == 0 ? 1.0f : 1.0f / scale;
qvalue = lrintf(value * inv_scale + zero_point);
qvalue = std::max(qmin, std::min(qvalue, qmax));
return static_cast<T>(qvalue);
return qvalue;
}

template <typename SRC_T, typename DST_T>
@@ -507,11 +503,5 @@ requantize_from_int<quint8>(double, int64_t, int64_t);
template CAFFE2_API qint32
requantize_from_int<qint32>(double, int64_t, int64_t);

template CAFFE2_API qint8
quantize_val_float_qparams<qint8>(float scale, float zero_point, float value);
template CAFFE2_API quint8
quantize_val_float_qparams<quint8>(float scale, float zero_point, float value);
template CAFFE2_API qint32
quantize_val_float_qparams<qint32>(float scale, float zero_point, float value);
} // namespace native
} // namespace at
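
The helper above drops the template parameter and instead takes the clamp range explicitly, so sub-byte dtypes such as quint4x2 (quant_min 0, quant_max 15) can reuse the same rounding path. A minimal Python sketch of that arithmetic; the function name here is illustrative and not part of the diff:

```python
def quantize_val_float_qparams(scale, zero_point, value, qmin, qmax):
    # Mirrors the C++ helper: round to nearest, then clamp to [qmin, qmax].
    inv_scale = 1.0 if scale == 0 else 1.0 / scale
    qvalue = round(value * inv_scale + zero_point)
    return max(qmin, min(qvalue, qmax))

# Example with a 4-bit range (qmin=0, qmax=15), as used for quint4x2.
print(quantize_val_float_qparams(0.2, 0.1, 1.7, 0, 15))  # -> 9
```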
3 changes: 1 addition & 2 deletions aten/src/ATen/native/quantized/affine_quantizer.h
@@ -158,8 +158,7 @@ template <typename DST_T>
CAFFE2_API DST_T
requantize_from_int(double multiplier, int64_t zero_point, int64_t src);

template <typename T>
CAFFE2_API T quantize_val_float_qparams(float scale, float zero_point, float value);
int quantize_val_float_qparams(float scale, float zero_point, float value, int qmin, int qmax);

} // namespace native
} // namespace at
48 changes: 35 additions & 13 deletions aten/src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp
@@ -2592,7 +2592,8 @@ void dequantize_per_channel_affine_kernel(
Tensor rtensor,
Tensor scales,
Tensor zero_points,
int64_t axis) {
int64_t axis,
int bit_width=8) {

// For contiguous tensors, e.g. NCHW, arbitrary axis can be used.
// For channels_last/3d however axis == 0 or 1.
@@ -2611,6 +2612,7 @@
check_tensor_memory_format(qtensor, rtensor);
const auto* qd = qtensor.data_ptr<Q>();
float* rd = rtensor.data_ptr<float>();
const auto elem_per_byte = 8 / bit_width;
if (axis == 1 && (rtensor.is_contiguous(MemoryFormat::ChannelsLast) ||
rtensor.is_contiguous(MemoryFormat::ChannelsLast3d))) {
for (auto b = 0; b < batches; ++b) {
@@ -2619,8 +2621,12 @@
auto i = b * channel * elements_per_channel + e * channel + c;
// We need to convert the qint8 value to float to ensure the
// subtraction subexpression returns a float
rd[i] = (static_cast<float>(qd[i].val_) - zero_points_data[c]) *
scales_data[c];
auto qvalue = qd[i / elem_per_byte].val_;
if (bit_width < 8) {
qvalue >>= (i % elem_per_byte) * bit_width;
qvalue &= (1 << bit_width) - 1;
}
rd[i] = (static_cast<float>(qvalue) - zero_points_data[c]) * scales_data[c];
}
}
}
@@ -2632,8 +2638,12 @@
c * elements_per_channel + e;
// We need to convert the qint8 value to float to ensure the
// subtraction subexpression returns a float
rd[i] = (static_cast<float>(qd[i].val_) - zero_points_data[c]) *
scales_data[c];
auto qvalue = qd[i / elem_per_byte].val_;
if (bit_width < 8) {
qvalue >>= (i % elem_per_byte) * bit_width;
qvalue &= (1 << bit_width) - 1;
}
rd[i] = (static_cast<float>(qvalue) - zero_points_data[c]) * scales_data[c];
}
}
}
@@ -2667,7 +2677,7 @@ void quantize_tensor_per_channel_float_qparams_cpu(
TORCH_CHECK(rtensor.is_contiguous() || (axis <=1),
"If tensor is channels_last contig then per channel quantization "
"is supported only for axis = 0 or 1.");
AT_DISPATCH_QINT_TYPES(
AT_DISPATCH_QINT_AND_SUB_BYTE_TYPES(
qtensor.scalar_type(), "quantize_tensor_per_channel_float_qparams_cpu", [&]() {
int64_t batches = size_to_dim_(axis, rtensor.sizes());
int64_t elements_per_channel =
@@ -2677,15 +2687,22 @@
auto zero_points_data = zero_points.data_ptr<float>();
check_tensor_memory_format(rtensor, qtensor);
const float* rdata = rtensor.data_ptr<float>();
auto qdata = qtensor.data_ptr<scalar_t>();
auto qdata = reinterpret_cast<underlying_t*>(qtensor.data_ptr<scalar_t>());
const auto elem_per_byte = CHAR_BIT / bit_width;
int qvalue = 0;
if (axis == 1 && (rtensor.is_contiguous(MemoryFormat::ChannelsLast) ||
rtensor.is_contiguous(MemoryFormat::ChannelsLast3d))) {
for (auto b = 0; b < batches; ++b) {
for (auto e = 0; e < elements_per_channel; ++e) {
for (auto c = 0; c < channel; ++c) {
auto i = b * channel * elements_per_channel + e * channel + c;
qdata[i] = quantize_val_float_qparams<scalar_t>(
scales_data[c], zero_points_data[c], rdata[i]);
qvalue = quantize_val_float_qparams(
scales_data[c], zero_points_data[c], rdata[i], quant_min, quant_max);
if (i % elem_per_byte == 0) {
qdata[i / elem_per_byte] = static_cast<underlying_t>(qvalue);
} else {
qdata[i / elem_per_byte] |= static_cast<underlying_t>(qvalue << ((i % elem_per_byte) * bit_width));
}
}
}
}
@@ -2695,8 +2712,13 @@
for (auto e = 0; e < elements_per_channel; ++e) {
auto i = b * channel * elements_per_channel +
c * elements_per_channel + e;
qdata[i] = quantize_val_float_qparams<scalar_t>(
scales_data[c], zero_points_data[c], rdata[i]);
qvalue = quantize_val_float_qparams(
scales_data[c], zero_points_data[c], rdata[i], quant_min, quant_max);
if (i % elem_per_byte == 0) {
qdata[i / elem_per_byte] = static_cast<underlying_t>(qvalue);
} else {
qdata[i / elem_per_byte] |= static_cast<underlying_t>(qvalue << ((i % elem_per_byte) * bit_width));
}
}
}
}
@@ -2710,9 +2732,9 @@ void dequantize_tensor_per_channel_float_qparams_cpu(
Tensor scales,
Tensor zero_points,
int64_t axis) {
AT_DISPATCH_QINT_TYPES(
AT_DISPATCH_QINT_AND_SUB_BYTE_TYPES(
qtensor.scalar_type(), "dequantize_tensor_per_channel_float_qparams_cpu", [&]() {
dequantize_per_channel_affine_kernel<float, float, scalar_t>(qtensor, rtensor, scales, zero_points, axis);
dequantize_per_channel_affine_kernel<float, float, scalar_t>(qtensor, rtensor, scales, zero_points, axis, bit_width);
});
}

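
The kernels above pack two 4-bit values per byte: element i lives in byte i / elem_per_byte, with even elements in the low nibble and odd elements in the high nibble, which is what the shift by (i % elem_per_byte) * bit_width implements. A small self-contained sketch of the pack/unpack logic, assuming plain Python integers rather than the kernel's tensor storage:

```python
def pack_sub_byte(qvals, bit_width=4):
    # Pack already-quantized integers (each in [0, 2**bit_width - 1]) into bytes,
    # low bits first, matching the (i % elem_per_byte) * bit_width shift above.
    elem_per_byte = 8 // bit_width
    packed = bytearray((len(qvals) + elem_per_byte - 1) // elem_per_byte)
    for i, q in enumerate(qvals):
        packed[i // elem_per_byte] |= (q & ((1 << bit_width) - 1)) << ((i % elem_per_byte) * bit_width)
    return bytes(packed)

def unpack_sub_byte(packed, n, bit_width=4):
    # Inverse of pack_sub_byte: shift and mask, as in the dequantize kernel.
    elem_per_byte = 8 // bit_width
    mask = (1 << bit_width) - 1
    return [(packed[i // elem_per_byte] >> ((i % elem_per_byte) * bit_width)) & mask for i in range(n)]

qvals = [9, 3, 15, 0, 7]
packed = pack_sub_byte(qvals)           # 3 bytes hold 5 four-bit values
assert unpack_sub_byte(packed, 5) == qvals
```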
50 changes: 49 additions & 1 deletion test/quantization/test_quantized_tensor.py
@@ -429,6 +429,52 @@ def quantize_ref(data, scales, zero_points):
zero_points = torch.tensor([0.1, 0.2, 1.], dtype=torch.float)
self._test_quantize_per_channel(r, scales, zero_points, 0, True)

def test_quantize_per_channel_sub_byte(self):
""" Tests the per channel quantization scheme for 4-bit qtensors.
The scale and zero point for this have to be in floating point. """
r = torch.rand(3, 2, dtype=torch.float) * 4
scales = torch.tensor([0.2, 0.3, 0.1], dtype=torch.float)
zero_points = torch.tensor([0.1, 0.2, 0.3], dtype=torch.float)
qr = torch.quantize_per_channel(r, scales, zero_points, 0, torch.quint4x2)
dequant_tensor = qr.dequantize()

def _get_qranges(bit_width):
if bit_width == 4:
return 0, 15

def _quantize_per_channel_sub_byte_ref(data, scales, zero_points, axis, bit_width):
dims = data.size()
data = data.view(-1, dims[axis], np.prod(dims[axis + 1:]))
qtensor_size = math.ceil(data.numel() / 2)
res = torch.empty(qtensor_size, dtype=torch.uint8)
elem_per_byte = 8 / bit_width
quant_min, quant_max = _get_qranges(bit_width)
for i in range(data.size()[0]):
for j in range(data.size()[1]):
for k in range(data.size()[2]):
inv_scale = 1.0 / scales[j]
index = i * data.size()[1] * data.size()[2] + j * data.size()[2] + k
qvalue = np.clip(
np.round(data[i][j][k] * inv_scale + zero_points[j]), quant_min, quant_max).to(dtype=torch.int)
res_idx = int(index / elem_per_byte)
if (index % elem_per_byte == 0):
res[res_idx] = qvalue
else:
res[res_idx] |= (qvalue << ((index % elem_per_byte) * bit_width))
return res

ref_res = _quantize_per_channel_sub_byte_ref(r, scales, zero_points, 0, 4)
self.assertTrue(np.allclose(qr.int_repr(), ref_res))
self.assertTrue(np.allclose(r.numpy(), dequant_tensor.numpy(), atol=1 / np.min(scales.numpy())))

# Check 4D tensor with non-zero axis.
r = torch.rand(3, 2, 4, 5, dtype=torch.float) * 4
scales = torch.tensor([0.2, 0.03], dtype=torch.float)
zero_points = torch.tensor([0.1, 0.2], dtype=torch.float)
qr = torch.quantize_per_channel(r, scales, zero_points, axis=1, dtype=torch.quint4x2)
ref_res = _quantize_per_channel_sub_byte_ref(r, scales, zero_points, 1, 4)
self.assertTrue(np.allclose(qr.int_repr(), ref_res))

def test_qtensor_permute(self):
scale = 0.02
zero_point = 1
@@ -516,7 +562,9 @@ def test_qtensor_per_channel_load_save(self):
scales = torch.rand(10, dtype=torch.double) * 0.02 + 0.01
zero_points = torch.round(torch.rand(10) * 20 + 1).to(torch.long)
# quint32, cuda is not supported yet
for dtype in [torch.quint8, torch.qint8]:
for dtype in [torch.quint8, torch.qint8, torch.quint4x2]:
if dtype == torch.quint4x2:
zero_points = torch.ones(10, dtype=torch.float)
qr = torch.quantize_per_channel(r, scales, zero_points, 1, dtype)
with tempfile.NamedTemporaryFile() as f:
# Serializing and Deserializing Tensor
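
With the dispatch and kernel changes above, the 4-bit per-channel path can be exercised directly from Python. A short usage sketch, assuming a PyTorch build that includes this change:

```python
import torch

r = torch.rand(3, 2) * 4
scales = torch.tensor([0.2, 0.3, 0.1])
zero_points = torch.tensor([0.1, 0.2, 0.3])  # float zero points (PerChannelFloatQParams)

# Per-channel 4-bit quantization along dim 0; two values are packed per byte.
qr = torch.quantize_per_channel(r, scales, zero_points, axis=0, dtype=torch.quint4x2)
print(qr.int_repr())     # packed uint8 representation
print(qr.dequantize())   # approximate reconstruction of r
```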