Per-channel quantized tensor to have only a single axis (#26675)
Summary:
Pull Request resolved: pytorch/pytorch#26675

Based on an offline poll, we're very unlikely to have multi-axis quantized tensors in the foreseeable future. Let's simplify the API and just return an int instead of a list. This also matches the singular `axis` name.
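For illustration, a minimal sketch of how the change surfaces in the ATen API (the shapes, scale values, and main() harness below are made-up assumptions, not part of the commit; assumes a CPU build with quantized dtypes):

#include <ATen/ATen.h>

int main() {
  // Per-channel quantize a 2x3 float tensor along axis 0 (2 channels).
  at::Tensor r = at::rand({2, 3});
  at::Tensor scales = at::ones({2}, at::kDouble) * 0.1;  // one scale per channel
  at::Tensor zero_points = at::zeros({2}, at::kLong);    // one zero point per channel
  at::Tensor q = at::quantize_per_channel(
      r, scales, zero_points, /*axis=*/0, at::kQUInt8);

  // Before this commit, q_per_channel_axis() returned an IntArrayRef that
  // always held a single element; it now returns a plain int64_t.
  int64_t axis = q.q_per_channel_axis();
  return axis == 0 ? 0 : 1;
}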

Test Plan: Imported from OSS

Differential Revision: D17537052

Pulled By: dzhulgakov

fbshipit-source-id: 676abc3b251d288468aaed467b5e5ca4063b98b0
Dmytro Dzhulgakov authored and facebook-github-bot committed Sep 24, 2019
1 parent 7fba2d7 commit cf9bb7b
Showing 11 changed files with 37 additions and 45 deletions.
2 changes: 1 addition & 1 deletion aten/src/ATen/core/TensorBody.h
@@ -744,7 +744,7 @@ class CAFFE2_API Tensor {
int64_t q_zero_point() const;
Tensor q_per_channel_scales() const;
Tensor q_per_channel_zero_points() const;
-IntArrayRef q_per_channel_axis() const;
+int64_t q_per_channel_axis() const;
Tensor int_repr() const;
QScheme qscheme() const;
Tensor to(const TensorOptions & options, bool non_blocking=false, bool copy=false) const;
6 changes: 3 additions & 3 deletions aten/src/ATen/core/TensorMethods.h
@@ -3252,7 +3252,7 @@ inline Tensor Tensor::q_per_channel_zero_points() const {
op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast<Tensor&>(*this));
#endif
}
-inline IntArrayRef Tensor::q_per_channel_axis() const {
+inline int64_t Tensor::q_per_channel_axis() const {
#ifdef USE_STATIC_DISPATCH
switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) {
case Backend::QuantizedCPU:
@@ -3262,8 +3262,8 @@ inline IntArrayRef Tensor::q_per_channel_axis() const {
AT_ERROR("q_per_channel_axis not implemented for ", at::toString(type_set()));
}
#else
-static auto table = globalATenDispatch().getOpTable("aten::q_per_channel_axis(Tensor self) -> int[]");
-return table->getOp<IntArrayRef (const Tensor &)>(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast<Tensor&>(*this));
+static auto table = globalATenDispatch().getOpTable("aten::q_per_channel_axis(Tensor self) -> int");
+return table->getOp<int64_t (const Tensor &)>(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast<Tensor&>(*this));
#endif
}
inline Tensor Tensor::int_repr() const {
8 changes: 4 additions & 4 deletions aten/src/ATen/native/native_functions.yaml
@@ -995,7 +995,7 @@

# it's a factory function receiving a tensor argument, thus overriding explicitly
# other overrides are to provide a more helpful error message that dtype is required
-- func: _empty_per_channel_affine_quantized(int[] size, *, Tensor scales, Tensor zero_points, int[] axis, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=contiguous_format) -> Tensor
+- func: _empty_per_channel_affine_quantized(int[] size, *, Tensor scales, Tensor zero_points, int axis, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=contiguous_format) -> Tensor
category_override: factory
dispatch:
CPU: empty_per_channel_affine_quantized_other_backends_stub
@@ -3508,7 +3508,7 @@
dispatch:
CPU: quantize_per_tensor_cpu

-- func: quantize_per_channel(Tensor self, Tensor scales, Tensor zero_points, int[] axis, ScalarType dtype) -> Tensor
+- func: quantize_per_channel(Tensor self, Tensor scales, Tensor zero_points, int axis, ScalarType dtype) -> Tensor
variants: function
dispatch:
CPU: quantize_per_channel_cpu
@@ -3548,7 +3548,7 @@
dispatch:
QuantizedCPU: q_per_channel_zero_points_quant

-- func: q_per_channel_axis(Tensor self) -> int[]
+- func: q_per_channel_axis(Tensor self) -> int
variants: function, method
dispatch:
QuantizedCPU: q_per_channel_axis_quant
@@ -3564,7 +3564,7 @@
dispatch:
CPU: make_per_tensor_quantized_tensor_cpu

-- func: _make_per_channel_quantized_tensor(Tensor self, Tensor scale, Tensor zero_point, int[] axis) -> Tensor
+- func: _make_per_channel_quantized_tensor(Tensor self, Tensor scale, Tensor zero_point, int axis) -> Tensor
use_c10_dispatcher: unboxed_only
dispatch:
CPU: make_per_channel_quantized_tensor_cpu
7 changes: 3 additions & 4 deletions aten/src/ATen/native/quantized/QTensor.cpp
@@ -19,15 +19,14 @@ Tensor quantize_per_channel_cpu(
const Tensor& self,
const Tensor& scales,
const Tensor& zero_points,
-IntArrayRef axis,
+int64_t axis,
ScalarType dtype) {
TORCH_CHECK(scales.dim() == 1, "scale tensor must have dimension 1");
TORCH_CHECK(
zero_points.dim() == 1, "zero_points tensor must have dimension 1");
TORCH_CHECK(
scales.numel() == zero_points.numel(),
"number of elements in scales and zero_points must match");
-TORCH_CHECK(axis.size() == 1, "only axis of size 1 is supported right now");
double* scales_data = scales.data_ptr<double>();
int64_t* zero_points_data = zero_points.data_ptr<int64_t>();
std::vector<double> scale_vals(scales_data, scales_data + scales.numel());
@@ -94,7 +93,7 @@ Tensor q_per_channel_zero_points_quant(const Tensor& self) {
self.options().dtype(at::kLong));
}

-IntArrayRef q_per_channel_axis_quant(const Tensor& self) {
+int64_t q_per_channel_axis_quant(const Tensor& self) {
auto quantizer = get_qtensorimpl(self)->quantizer();
TORCH_CHECK(quantizer->qscheme() == kPerChannelAffine);
return static_cast<PerChannelAffineQuantizer*>(quantizer.get())->axis();
@@ -141,7 +140,7 @@ Tensor make_per_channel_quantized_tensor_cpu(
const Tensor& self,
const Tensor& scales,
const Tensor& zero_points,
-IntArrayRef axis) {
+int64_t axis) {
Tensor dst = at::_empty_per_channel_affine_quantized(
self.sizes(),
scales,
5 changes: 2 additions & 3 deletions aten/src/ATen/native/quantized/TensorFactories.cpp
@@ -31,7 +31,7 @@ Tensor empty_per_channel_affine_quantized_cpu(
IntArrayRef size,
const Tensor& scales,
const Tensor& zero_points,
-IntArrayRef axis,
+int64_t axis,
const TensorOptions& options,
c10::optional<c10::MemoryFormat> optional_memory_format) {
TORCH_CHECK(
@@ -46,7 +46,6 @@ Tensor empty_per_channel_affine_quantized_cpu(
TORCH_CHECK(
scales.numel() == zero_points.numel(),
"number of elements in scales and zero_points must match");
-TORCH_CHECK(axis.size() == 1, "only axis of size 1 is supported right now");
double* scales_data = scales.data_ptr<double>();
int64_t* zero_points_data = zero_points.data_ptr<int64_t>();
std::vector<double> scale_vals(scales_data, scales_data + scales.numel());
@@ -77,7 +76,7 @@ Tensor empty_per_channel_affine_quantized_other_backends_stub(
IntArrayRef,
const Tensor&,
const Tensor&,
-IntArrayRef,
+int64_t,
const TensorOptions&,
c10::optional<c10::MemoryFormat>) {
TORCH_CHECK(false, "Creation of quantized tensor requires quantized dtype like torch.quint8");
4 changes: 2 additions & 2 deletions aten/src/ATen/native/quantized/cpu/qconv_prepack.cpp
@@ -69,9 +69,9 @@ class QConvPackWeightInt8 final : public c10::OperatorKernel {
if (qtype == kPerTensorAffine) {
zero_points[0] = weight.q_zero_point();
} else if (qtype == kPerChannelAffine) {
-auto axis = weight.q_per_channel_axis();
+int64_t axis = weight.q_per_channel_axis();
TORCH_CHECK(
-axis.size() == 1 && axis[0] == 0,
+axis == 0,
"Only per output channel quantization is supported for the weights");
zero_points.resize(output_channels, 0);
for (int i = 0; i < output_channels; ++i) {
2 changes: 1 addition & 1 deletion aten/src/ATen/native/quantized/cpu/qconv_unpack.cpp
@@ -58,7 +58,7 @@ class QConvUnpackWeightsInt8 final : public c10::OperatorKernel {
{output_channels, C_per_G, kernel_h, kernel_w},
scales.toType(kDouble),
zero_points.toType(kLong),
-{0}, /* The output channel axis is 0 */
+0, /* The output channel axis is 0 */
device(kCPU).dtype(kQInt8),
MemoryFormat::ChannelsLast);
} else {
2 changes: 1 addition & 1 deletion aten/src/ATen/native/quantized/cpu/qlinear_unpack.cpp
@@ -40,7 +40,7 @@ class QLinearUnpackWeightInt8 final : public c10::OperatorKernel {
{N, K},
scales.toType(kDouble),
zero_points.toType(kLong),
-{0}, // The output channel axis is 0
+0, // The output channel axis is 0
device(kCPU).dtype(kQInt8));
}

24 changes: 11 additions & 13 deletions aten/src/ATen/quantized/Quantizer.cpp
@@ -262,16 +262,15 @@ Tensor quantize_tensor_per_channel_affine(Tensor rtensor,
Tensor qtensor,
const std::vector<double>& scales,
const std::vector<int64_t>& zero_points,
-IntArrayRef axis) {
+int64_t axis) {
auto fn_name = "quantize_tensor_per_channel_affine";
checkFloatCPUTensor(fn_name, rtensor);
checkQuantizedCPUTensor<T>(fn_name, qtensor);
checkZeroPoints<typename T::underlying>(fn_name, zero_points);
-int64_t channel_axis = axis[0];
-TORCH_CHECK(channel_axis < rtensor.dim(), "Channel axis out of range in per channel affine quantization.");
-int64_t batches = size_to_dim_(channel_axis, rtensor.sizes());
-int64_t elements_per_channel = size_from_dim_(channel_axis + 1, rtensor.sizes());
-int64_t channel = rtensor.size(channel_axis);
+TORCH_CHECK(0 <= axis && axis < rtensor.dim(), "Channel axis out of range in per channel affine quantization.");
+int64_t batches = size_to_dim_(axis, rtensor.sizes());
+int64_t elements_per_channel = size_from_dim_(axis + 1, rtensor.sizes());
+int64_t channel = rtensor.size(axis);
TORCH_CHECK(channel == int64_t(scales.size()),
"length of scales must equal to channel");
TORCH_CHECK(channel == int64_t(zero_points.size()),
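(Aside: the bookkeeping above is the standard contiguous-layout decomposition. For a tensor of shape (d0, ..., dn-1) quantized along axis, batches is the product of the sizes before the axis, channel is the size at the axis, and elements_per_channel is the product of the sizes after it; element (b, c, e) then sits at flat index (b * channel + c) * elements_per_channel + e and uses scales[c] and zero_points[c]. A standalone sketch with a made-up shape, standing in for the size_to_dim_ / size_from_dim_ helpers:)

#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  std::vector<int64_t> sizes = {2, 3, 4, 5};  // hypothetical contiguous tensor shape
  int64_t axis = 1;
  int64_t batches = 1;
  for (int64_t i = 0; i < axis; ++i) batches *= sizes[i];  // 2
  int64_t channel = sizes[axis];                           // 3
  int64_t elements_per_channel = 1;
  for (int64_t i = axis + 1; i < (int64_t)sizes.size(); ++i) elements_per_channel *= sizes[i];  // 20
  // The (batch, channel, element) triple loop visits every element exactly once:
  assert(batches * channel * elements_per_channel == 2 * 3 * 4 * 5);
  return 0;
}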
@@ -294,17 +293,16 @@ Tensor dequantize_tensor_per_channel_affine(Tensor qtensor,
Tensor rtensor,
const std::vector<double>& scales,
const std::vector<int64_t>& zero_points,
-IntArrayRef axis) {
+int64_t axis) {
auto fn_name = "dequantize_tensor_per_channel_affine";
checkFloatCPUTensor(fn_name, rtensor);
checkQuantizedCPUTensor<T>(fn_name, qtensor);
checkZeroPoints<typename T::underlying>(fn_name, zero_points);
-int64_t channel_axis = axis[0];
-TORCH_CHECK(channel_axis < qtensor.dim(),
+TORCH_CHECK(0 <= axis && axis < qtensor.dim(),
"Channel axis out of range in per channel affine dequantization.");
-int64_t batches = size_to_dim_(channel_axis, rtensor.sizes());
-int64_t elements_per_channel = size_from_dim_(channel_axis + 1, rtensor.sizes());
-int64_t channel = rtensor.size(channel_axis);
+int64_t batches = size_to_dim_(axis, rtensor.sizes());
+int64_t elements_per_channel = size_from_dim_(axis + 1, rtensor.sizes());
+int64_t channel = rtensor.size(axis);
TORCH_CHECK(channel == int64_t(scales.size()),
"length of scales must equal to channel");
TORCH_CHECK(channel == int64_t(zero_points.size()),
@@ -335,7 +333,7 @@ QuantizerPtr make_per_tensor_affine_quantizer(
QuantizerPtr make_per_channel_affine_quantizer(
const std::vector<double>& scales,
const std::vector<int64_t>& zero_points,
-IntArrayRef axis,
+int64_t axis,
ScalarType scalar_type) {
return c10::make_intrusive<PerChannelAffineQuantizer>(scalar_type,
scales, zero_points, axis);
20 changes: 8 additions & 12 deletions aten/src/ATen/quantized/Quantizer.h
@@ -179,15 +179,11 @@ struct CAFFE2_API PerChannelAffineQuantizer : public AffineQuantizer {
ScalarType scalar_type,
const std::vector<double>& scales,
const std::vector<int64_t>& zero_points,
-IntArrayRef axis)
-: AffineQuantizer(scalar_type),
-scales_(scales),
-zero_points_(zero_points),
-axis_(axis.vec()) {
-TORCH_CHECK(
-axis_.size() == 1,
-"Per channel affine quantization in multiple axis is not supported yet.");
-}
+int64_t axis)
+: AffineQuantizer(scalar_type),
+scales_(scales),
+zero_points_(zero_points),
+axis_(axis) {}

QScheme qscheme() const override {
return kPerChannelAffine;
@@ -201,7 +197,7 @@ struct CAFFE2_API PerChannelAffineQuantizer : public AffineQuantizer {
return zero_points_;
}

-IntArrayRef axis() const {
+int64_t axis() const {
return axis_;
}

@@ -223,7 +219,7 @@ struct CAFFE2_API PerChannelAffineQuantizer : public AffineQuantizer {
private:
const std::vector<double> scales_;
const std::vector<int64_t> zero_points_;
-const SmallVector<int64_t, 1> axis_;
+const int64_t axis_;
};

// This is an internal utility function for getting at the QTensorImpl,
@@ -258,7 +254,7 @@ make_per_tensor_affine_quantizer(
CAFFE2_API QuantizerPtr
make_per_channel_affine_quantizer(
const std::vector<double>& scales, const std::vector<int64_t>& zero_points,
-IntArrayRef axis, ScalarType scalar_type);
+int64_t axis, ScalarType scalar_type);

// Create a Quantized Tensor given arguments for normal Tensor and a quantizer
CAFFE2_API Tensor new_qtensor_cpu(
2 changes: 1 addition & 1 deletion aten/src/ATen/test/quantized_test.cpp
@@ -128,7 +128,7 @@ TEST(TestQTensor, EmptyPerchannelQuantized) {
{numel},
scales,
zero_points,
-{ch_axis},
+ch_axis,
at::device(at::kCPU).dtype(kQUInt8));
// Assigning to QTensor
auto* q_data = q.data_ptr<quint8>();
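End-to-end, the updated test corresponds to something like the following sketch (the values and the main() harness are made-up; only the call shapes and the plain-int axis mirror the test):

#include <ATen/ATen.h>

int main() {
  const int64_t numel = 10;
  const int64_t ch_axis = 0;  // quantize along dim 0, as in the test
  at::Tensor scales = at::ones({numel}, at::kDouble) * 0.5;
  at::Tensor zero_points = at::zeros({numel}, at::kLong);
  // axis is now a plain int argument rather than a single-element list:
  at::Tensor q = at::_empty_per_channel_affine_quantized(
      {numel}, scales, zero_points, ch_axis,
      at::device(at::kCPU).dtype(at::kQUInt8));
  return q.q_per_channel_axis() == ch_axis ? 0 : 1;
}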
