diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 5022e0ecb22d..5f0d21f6b9b2 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -4658,10 +4658,7 @@
 - func: fake_quantize_per_channel_affine(Tensor self, Tensor scale, Tensor zero_point, int axis, int quant_min, int quant_max) -> Tensor
   variants: function
   dispatch:
-    CPU, CUDA: fake_quantize_per_channel_affine
-
-- func: fake_quantize_per_channel_affine_backward(Tensor grad, Tensor self, Tensor scale, Tensor zero_point, int axis, int quant_min, int quant_max) -> Tensor
-  variants: function
+    Math: fake_quantize_per_channel_affine
 
 - func: fake_quantize_per_channel_affine_cachemask(Tensor self, Tensor scale, Tensor zero_point, int axis, int quant_min, int quant_max) -> (Tensor output, Tensor mask)
   variants: function
diff --git a/aten/src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp b/aten/src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp
index c2c633aa765a..85ab9eacf39c 100644
--- a/aten/src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp
+++ b/aten/src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp
@@ -2134,23 +2134,6 @@ void fake_quantize_learnable_tensor_grad_kernel_cpu(
   });
 }
 
-void fake_quant_per_channel_cpu(
-    TensorIterator& iter,
-    int64_t quant_min,
-    int64_t quant_max) {
-  cpu_kernel(iter, [=](float self, float scale, int64_t zero_point) -> float {
-    float inv_scale = 1.0f / scale;
-    return (std::fmin(
-                std::fmax(
-                    static_cast<int64_t>(
-                        zero_point + std::nearbyint(self * inv_scale)),
-                    quant_min),
-                quant_max) -
-            zero_point) *
-        scale;
-  });
-}
-
 void fake_quant_per_channel_cachemask_cpu(
     TensorIterator& iter,
     TensorIterator& iter_mask,
@@ -2180,19 +2163,6 @@
   });
 }
 
-void fake_quant_grad_per_channel_cpu(
-    TensorIterator& iter,
-    int64_t quant_min,
-    int64_t quant_max) {
-  cpu_kernel(
-      iter, [=](float x, float dy, float scale, int64_t zero_point) -> float {
-        float inv_scale = 1.0f / scale;
-        int64_t xq =
-            static_cast<int64_t>(zero_point + std::nearbyint(x * inv_scale));
-        return dy * (xq >= quant_min && xq <= quant_max);
-      });
-}
-
 void fake_quantize_learnable_channel_grad_kernel_cpu(
     TensorIterator& iter,
     int64_t quant_min,
@@ -3072,9 +3042,6 @@ REGISTER_DISPATCH(dequantize_tensor_per_channel_float_qparams_stub,
                   &dequantize_tensor_per_channel_float_qparams_cpu);
 REGISTER_DISPATCH(fake_quant_grad_learnable_tensor_stub,
                   &fake_quantize_learnable_tensor_grad_kernel_cpu);
-REGISTER_DISPATCH(fake_quant_grad_per_channel_stub,
-                  &fake_quant_grad_per_channel_cpu);
-REGISTER_DISPATCH(fake_quant_per_channel_stub, &fake_quant_per_channel_cpu);
 REGISTER_DISPATCH(fake_quant_per_channel_cachemask_stub,
                   &fake_quant_per_channel_cachemask_cpu);
 REGISTER_DISPATCH(fake_quant_tensor_cachemask_stub, &fake_quantize_tensor_cachemask_kernel);
diff --git a/aten/src/ATen/native/quantized/cuda/fake_quantize_core.cu b/aten/src/ATen/native/quantized/cuda/fake_quantize_core.cu
index bcb87f8f2091..4e9d3ee0bdbb 100644
--- a/aten/src/ATen/native/quantized/cuda/fake_quantize_core.cu
+++ b/aten/src/ATen/native/quantized/cuda/fake_quantize_core.cu
@@ -83,22 +83,6 @@ REGISTER_DISPATCH(fake_quant_grad_learnable_tensor_stub, &_fake_quantize_grad_le
 
 // Fake quantize per channel
 
-void fake_quant_per_channel_cuda(TensorIterator &iter, int64_t quant_min, int64_t quant_max) {
-  gpu_kernel(iter,
-    [=] GPU_LAMBDA (float input_val, float scale, int64_t zero_point) -> float {
-      float inv_scale = 1.0f / scale;
-      return (fminf(
-                  quant_max,
-                  fmaxf(
-                      quant_min,
-                      static_cast<int64_t>(
-                          std::nearbyint(input_val * inv_scale) +
-                          zero_point))) -
-              zero_point) *
-          scale;
-    });
-}
-
 void fake_quant_per_channel_cachemask_cuda(
   TensorIterator &iter, TensorIterator &iter_mask, int64_t quant_min, int64_t quant_max) {
   // TODO(future, optional): read once, write twice.  Not done at the moment
@@ -128,15 +112,6 @@ void fake_quant_per_channel_cachemask_cuda(
   });
 }
 
-void fake_quant_grad_per_channel_cuda(TensorIterator &iter, int64_t quant_min, int64_t quant_max) {
-  gpu_kernel(iter,
-    [=] GPU_LAMBDA (float x, float dy, float scale, int64_t zero_point) -> float {
-      float inv_scale = 1.0f / scale;
-      int64_t Xq = std::nearbyint(x * inv_scale) + zero_point;
-      return (Xq >= quant_min && Xq <= quant_max) * dy;
-    });
-}
-
 void _fake_quantize_grad_learnable_channel_kernel_cuda(TensorIterator &iter, int64_t quant_min, int64_t quant_max) {
   gpu_kernel_multiple_outputs(iter,
     [=] GPU_LAMBDA (float x_input, float dy_input, float scale_input, float zero_point_input) -> thrust::tuple<float, float, float> {
@@ -161,8 +136,6 @@ void _fake_quantize_grad_learnable_channel_kernel_cuda(TensorIterator &iter, int
   });
 }
 
-REGISTER_DISPATCH(fake_quant_per_channel_stub, &fake_quant_per_channel_cuda);
-REGISTER_DISPATCH(fake_quant_grad_per_channel_stub, &fake_quant_grad_per_channel_cuda);
 REGISTER_DISPATCH(fake_quant_per_channel_cachemask_stub, &fake_quant_per_channel_cachemask_cuda);
 REGISTER_DISPATCH(fake_quant_grad_learnable_channel_stub, &_fake_quantize_grad_learnable_channel_kernel_cuda);
 
diff --git a/aten/src/ATen/native/quantized/fake_quant_affine.h b/aten/src/ATen/native/quantized/fake_quant_affine.h
index 48f8574fe34f..e75223907a38 100644
--- a/aten/src/ATen/native/quantized/fake_quant_affine.h
+++ b/aten/src/ATen/native/quantized/fake_quant_affine.h
@@ -40,8 +40,6 @@ using fake_quant_per_channel_cachemask_fn = void (*)(
     int64_t quant_min,
     int64_t quant_max);
 
-DECLARE_DISPATCH(fake_quant_per_channel_fn, fake_quant_per_channel_stub);
-DECLARE_DISPATCH(fake_quant_per_channel_fn, fake_quant_grad_per_channel_stub);
 DECLARE_DISPATCH(fake_quant_per_channel_cachemask_fn, fake_quant_per_channel_cachemask_stub);
 DECLARE_DISPATCH(fake_quant_per_channel_fn, fake_quant_grad_learnable_channel_stub);
 
diff --git a/aten/src/ATen/native/quantized/fake_quant_per_channel_affine.cpp b/aten/src/ATen/native/quantized/fake_quant_per_channel_affine.cpp
index 5139aa02631f..d0c22f8d27a9 100644
--- a/aten/src/ATen/native/quantized/fake_quant_per_channel_affine.cpp
+++ b/aten/src/ATen/native/quantized/fake_quant_per_channel_affine.cpp
@@ -10,8 +10,6 @@ namespace at {
 namespace native {
 
 // Use REGISTER_DISPATCH to run CPU and CUDA backend.
-DEFINE_DISPATCH(fake_quant_per_channel_stub);
-DEFINE_DISPATCH(fake_quant_grad_per_channel_stub);
 DEFINE_DISPATCH(fake_quant_per_channel_cachemask_stub);
 DEFINE_DISPATCH(fake_quant_grad_learnable_channel_stub);
 
@@ -36,50 +34,9 @@ Tensor fake_quantize_per_channel_affine(
     int64_t axis,
     int64_t quant_min,
     int64_t quant_max) {
-  TORCH_CHECK(self.scalar_type() == ScalarType::Float);
-  TORCH_CHECK(scale.scalar_type() == ScalarType::Float,
-              "Scale must be Float, found ", scale.scalar_type());
-  TORCH_CHECK(zero_point.scalar_type() == ScalarType::Long,
-              "Zero-point must be Long, found ", zero_point.scalar_type());
-  TORCH_CHECK(scale.dim() == 1, "scale should be a 1-D tensor");
-  TORCH_CHECK(zero_point.dim() == 1, "zero point should be a 1-D tensor");
-  TORCH_CHECK(
-      scale.numel() == zero_point.numel(),
-      "scale and zero-point need to have the same dimensions");
-  TORCH_CHECK(
-      scale.numel() == self.size(axis),
-      "dimensions of scale and zero-point are not consistent with input tensor")
-
-  TORCH_CHECK(
-      quant_min <= quant_max,
-      "`quant_min` should be less than or \
-        equal to `quant_max`.");
-
-  TORCH_CHECK(
-      at::min(zero_point).item().toLong() >= quant_min &&
-          at::max(zero_point).item().toLong() <= quant_max,
-      "`zero_point` must be between `quant_min` and `quant_max`.");
-
-  TORCH_CHECK(
-      axis >= 0 && axis <= self.dim(),
-      "`axis` must be between 0 and number of dimensions of input");
-
-  auto Y = at::empty_like(self, self.options(), MemoryFormat::Preserve);
-
-  std::vector<int64_t> expected_shape(self.dim(), 1);
-  expected_shape[axis] = self.size(axis);
-
-  TensorIterator iter = TensorIteratorConfig()
-    .check_all_same_dtype(false)
-    .add_output(Y)
-    .add_input(self)
-    .add_input(native::_unsafe_view(scale, expected_shape))
-    .add_input(native::_unsafe_view(zero_point, expected_shape))
-    .build();
-
-  fake_quant_per_channel_stub(iter.device_type(), iter, quant_min, quant_max);
-
-  return Y;
+  const auto res = at::fake_quantize_per_channel_affine_cachemask(
+      self, scale, zero_point, axis, quant_min, quant_max);
+  return std::get<0>(res);
 }
 
 std::tuple<Tensor, Tensor> fake_quantize_per_channel_affine_cachemask(
@@ -172,87 +129,6 @@ Tensor fake_quantize_per_channel_affine_cachemask_backward(
   return dY * mask;
 }
 
-/* Backward path for per-channel fake-quantization of the 'inputs' tensor.
-
-Args:
-  X: Forward input tensor.
-  dY: Backward input tensor.
-  scale: scale of per tensor affine quantization
-  zero_point: zero_point of per tensor affine quantization
-  axis: int ,the axis over which quantization parameters vary
-  quant_min: int, minimum quantized value
-  quant_max: int, maximum quantized value
-
-Returns:
-  Gradient for per channel fake quant (double dtype).
-
-*/
-Tensor fake_quantize_per_channel_affine_backward(
-    const Tensor& dY,
-    const Tensor& X,
-    const Tensor& scale,
-    const Tensor& zero_point,
-    int64_t axis,
-    int64_t quant_min,
-    int64_t quant_max) {
-  TORCH_CHECK(dY.scalar_type() == ScalarType::Float);
-  TORCH_CHECK(X.scalar_type() == ScalarType::Float);
-  TORCH_CHECK(scale.scalar_type() == ScalarType::Float,
-              "Scale must be Float, found ", scale.scalar_type());
-  TORCH_CHECK(zero_point.scalar_type() == ScalarType::Long,
-              "Zero-point must be Long, found ", zero_point.scalar_type());
-
-  TORCH_CHECK(X.sizes() == dY.sizes(), "`X` and `dY` are not the same size");
-  TORCH_CHECK(
-      quant_min <= quant_max,
-      "`quant_min` should be less than or \
-        equal to `quant_max`.");
-  TORCH_CHECK(scale.dim() == 1, "scale should be a 1-D tensor");
-  TORCH_CHECK(zero_point.dim() == 1, "zero point should be a 1-D tensor");
-  TORCH_CHECK(
-      scale.numel() == zero_point.numel(),
-      "scale and zero-point need to have the same dimensions");
-  TORCH_CHECK(
-      scale.numel() == X.size(axis),
-      "dimensions of scale and zero-point are not consistent with input tensor")
-
-  TORCH_CHECK(
-      quant_min <= quant_max,
-      "`quant_min` should be less than or \
-        equal to `quant_max`.");
-
-  TORCH_CHECK(
-      at::min(zero_point).item().toLong() >= quant_min &&
-          at::max(zero_point).item().toLong() <= quant_max,
-      "`zero_point` must be between `quant_min` and `quant_max`.");
-
-  TORCH_CHECK(
-      axis >= 0 && axis <= X.dim(),
-      "`axis` must be between 0 and number of dimensions of input");
-
-  if (X.numel() <= 0) {
-    return X;
-  }
-
-  auto dX = at::empty_like(X, X.options(), MemoryFormat::Preserve);
-
-  std::vector<int64_t> expected_shape(X.dim(), 1);
-  expected_shape[axis] = X.size(axis);
-
-  TensorIterator iter = TensorIteratorConfig()
-    .check_all_same_dtype(false)
-    .add_output(dX)
-    .add_input(X)
-    .add_input(dY)
-    .add_input(native::_unsafe_view(scale, expected_shape))
-    .add_input(native::_unsafe_view(zero_point, expected_shape))
-    .build();
-
-  fake_quant_grad_per_channel_stub(iter.device_type(), iter, quant_min, quant_max);
-
-  return dX;
-}
-
 Tensor _get_rounded_zero_point(
     const Tensor& zero_point,
     int64_t quant_min,
diff --git a/test/backward_compatibility/check_backward_compatibility.py b/test/backward_compatibility/check_backward_compatibility.py
index 658c61dd4425..24d47d55eb6e 100644
--- a/test/backward_compatibility/check_backward_compatibility.py
+++ b/test/backward_compatibility/check_backward_compatibility.py
@@ -57,6 +57,7 @@
     ("aten::_multinomial_alias_draw", datetime.date(2021, 1, 31)),
     ("prim::profile_optional", datetime.date(2021, 1, 31)),
     ("aten::fake_quantize_per_tensor_affine_backward", datetime.date(2021, 2, 20)),
+    ("aten::fake_quantize_per_channel_affine_backward", datetime.date(2021, 2, 20)),
 ]
 
 def allow_listed(schema, allow_list):
diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml
index 436da7b2d3b8..55428902882e 100644
--- a/tools/autograd/derivatives.yaml
+++ b/tools/autograd/derivatives.yaml
@@ -464,9 +464,6 @@
 - name: _fake_quantize_learnable_per_tensor_affine(Tensor self, Tensor scale, Tensor zero_point, int quant_min, int quant_max) -> Tensor
   self, scale, zero_point: "grad.defined() ? _fake_quantize_learnable_per_tensor_affine_backward(grad, self, scale, zero_point, quant_min, quant_max) : std::tuple<Tensor, Tensor, Tensor>()"
 
-- name: fake_quantize_per_channel_affine(Tensor self, Tensor scale, Tensor zero_point, int axis, int quant_min, int quant_max) -> Tensor
-  self: fake_quantize_per_channel_affine_backward(grad, self, scale, zero_point, axis, quant_min, quant_max)
-
 - name: fake_quantize_per_channel_affine_cachemask(Tensor self, Tensor scale, Tensor zero_point, int axis, int quant_min, int quant_max) -> (Tensor output, Tensor mask)
   self: fake_quantize_per_channel_affine_cachemask_backward(grad, mask)
 
diff --git a/torch/quantization/fake_quantize.py b/torch/quantization/fake_quantize.py
index 03357facd314..576d3134c862 100644
--- a/torch/quantization/fake_quantize.py
+++ b/torch/quantization/fake_quantize.py
@@ -137,14 +137,9 @@ def forward(self, X):
 
         if self.fake_quant_enabled[0] == 1:
             if self.is_per_channel:
-                if self.training:
-                    X, _mask = torch.fake_quantize_per_channel_affine_cachemask(
-                        X, self.scale, self.zero_point,
-                        self.ch_axis, self.quant_min, self.quant_max)
-                else:
-                    X = torch.fake_quantize_per_channel_affine(
-                        X, self.scale, self.zero_point,
-                        self.ch_axis, self.quant_min, self.quant_max)
+                X = torch.fake_quantize_per_channel_affine(
+                    X, self.scale, self.zero_point,
+                    self.ch_axis, self.quant_min, self.quant_max)
             else:
                 X = torch.fake_quantize_per_tensor_affine(
                     X, float(self.scale), int(self.zero_point),
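Note on the semantics this patch relies on (a reviewer sketch, not part of the diff): the forward output of fake_quantize_per_channel_affine_cachemask matches the removed fake_quant_per_channel kernels, and the cached mask reproduces the removed backward as dY * mask, which is why the dedicated backward op and its dispatch stubs can be deleted. Below is a minimal Python sketch of that math; fake_quant_per_channel_ref is a hypothetical helper name, and round-to-nearest-even is assumed to mirror the kernels' std::nearbyint.

import torch

def fake_quant_per_channel_ref(x, scale, zero_point, axis, quant_min, quant_max):
    # Reshape per-channel qparams so they broadcast along `axis`.
    shape = [1] * x.dim()
    shape[axis] = x.size(axis)
    s = scale.reshape(shape)
    zp = zero_point.reshape(shape)
    # Quantize: torch.round rounds half to even, like std::nearbyint.
    xq = torch.round(x / s) + zp
    # This is the mask the fused cachemask op saves for backward.
    mask = (xq >= quant_min) & (xq <= quant_max)
    # Clamp and dequantize.
    y = (torch.clamp(xq, quant_min, quant_max) - zp) * s
    return y, mask

x = torch.randn(2, 3, 4)
scale = torch.rand(3) + 0.1
zp = torch.zeros(3, dtype=torch.long)
y_ref, mask = fake_quant_per_channel_ref(x, scale, zp, 1, 0, 255)
y = torch.fake_quantize_per_channel_affine(x, scale, zp, 1, 0, 255)
assert torch.allclose(y, y_ref)
# Backward of the fused op is dY * mask,
# matching fake_quantize_per_channel_affine_cachemask_backward above.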