From 983b8e6b62fb9bc7260d1c52bbe27e99b771ad56 Mon Sep 17 00:00:00 2001
From: Vasiliy Kuznetsov
Date: Wed, 27 Jan 2021 19:33:26 -0800
Subject: [PATCH] fake_quant: add a more memory efficient version (#50561)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/50561

Not for review yet, a bunch of TODOs need finalizing.

tl;dr: add an alternative implementation of `fake_quantize` which saves a mask
during the forward pass and uses it to calculate the backward. There are two
benefits:

1. the backward function no longer needs the input Tensor, and it can be gc'ed
   earlier by autograd. On MobileNetV2, this reduces QAT overhead by ~15%
   (TODO: link, and absolute numbers). We add an additional mask Tensor to pass
   around, but it is 4x smaller than the input tensor. A future optimization
   would be to pack the mask bitwise and unpack in the backward.

2. the computation of `qval` can be done only once in the forward and reused in
   the backward. No perf change observed, TODO verify with better metrics.

TODO: describe in more detail

Test Plan:
OSS / torchvision / MobileNetV2
```
python references/classification/train_quantization.py --print-freq 1 --data-path /data/local/packages/ai-group.imagenet-256-smallest-side/prod/ --output-dir ~/nfs/pytorch_vision_tests/ --backend qnnpack --epochs 5
TODO paste results here
```

TODO more

Imported from OSS

Reviewed By: ngimel

Differential Revision: D25918519

fbshipit-source-id: ec544ca063f984de0f765bf833f205c99d6c18b6
---
 aten/src/ATen/native/native_functions.yaml    |  8 ++
 .../cpu/kernels/QuantizedOpKernels.cpp        | 34 ++++++++
 .../quantized/cuda/fake_quantize_core.cu      | 32 ++++++++
 .../ATen/native/quantized/fake_quant_affine.h | 10 +++
 .../fake_quant_per_tensor_affine.cpp          | 81 ++++++++++++++++---
 .../pt/quantization_test.py                   | 14 ++--
 test/quantization/test_workflow_module.py     | 59 ++++++++++++++
 test/test_namedtuple_return_api.py            |  4 +-
 tools/autograd/derivatives.yaml               |  3 +
 torch/overrides.py                            |  1 +
 torch/quantization/fake_quantize.py           | 16 +++-
 11 files changed, 243 insertions(+), 19 deletions(-)

diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 1856b9a9bf13..1fd04b4255b7 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -4642,6 +4642,14 @@
 - func: fake_quantize_per_tensor_affine_backward(Tensor grad, Tensor self, float scale, int zero_point, int quant_min, int quant_max) -> Tensor
   variants: function
 
+- func: fake_quantize_per_tensor_affine_cachemask(Tensor self, float scale, int zero_point, int quant_min, int quant_max) -> (Tensor output, Tensor mask)
+  variants: function
+  dispatch:
+    CPU, CUDA: fake_quantize_per_tensor_affine_cachemask
+
+- func: fake_quantize_per_tensor_affine_cachemask_backward(Tensor grad, Tensor mask) -> Tensor
+  variants: function
+
 - func: _fake_quantize_learnable_per_tensor_affine(Tensor self, Tensor scale, Tensor zero_point, int quant_min, int quant_max) -> Tensor
   variants: function
   dispatch:
diff --git a/aten/src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp b/aten/src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp
index 8137049a75c8..5ed6e28663e0 100644
--- a/aten/src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp
+++ b/aten/src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp
@@ -2091,6 +2091,38 @@ void fake_quantize_grad_tensor_kernel(
   });
 }
 
+void fake_quantize_tensor_cachemask_kernel(
+    Tensor& output,
+    Tensor& mask,
+    const Tensor& input,
+    float sc,
+    int64_t z_point,
+    int64_t quant_min,
+    int64_t quant_max) {
+  float inv_scale = 1.0f / sc;
+
+  auto iter_combined = TensorIteratorConfig()
+    .check_all_same_dtype(false)
+    .add_output(output)
+    .add_output(mask)
+    .add_input(input)
+    .build();
+
+  // TODO(#51090): make it work for other dtypes
+  iter_combined.for_each([&](char** data, const int64_t* strides, int64_t n) {
+    for (int64_t i = 0; i < n; i++) {
+      float* output_val = (float*)(data[0] + i * strides[0]);
+      bool* mask_val = (bool*)(data[1] + i * strides[1]);
+      float* input_val = (float*)(data[2] + i * strides[2]);
+
+      const auto qval = static_cast<int64_t>(z_point + std::nearbyint(*input_val * inv_scale));
+      *output_val = (std::fmin(std::fmax(qval, quant_min), quant_max) - z_point) * sc;
+      *mask_val = ((quant_min <= qval) && (qval <= quant_max));
+    }
+  });
+
+}
+
 void fake_quantize_learnable_tensor_grad_kernel_cpu(
     TensorIterator& iter,
     float scale,
@@ -3054,6 +3086,8 @@ REGISTER_DISPATCH(fake_quant_grad_tensor_stub,
                   &fake_quantize_grad_tensor_kernel);
 REGISTER_DISPATCH(fake_quant_per_channel_stub, &fake_quant_per_channel_cpu);
 REGISTER_DISPATCH(fake_quant_tensor_stub, &fake_quantize_tensor_kernel);
+REGISTER_DISPATCH(fake_quant_tensor_cachemask_stub,
+                  &fake_quantize_tensor_cachemask_kernel);
 REGISTER_DISPATCH(qadaptive_avg_pool2d_nhwc_stub,
                   &qadaptive_avg_pool2d_nhwc_kernel);
 REGISTER_DISPATCH(qadaptive_avg_pool3d_ndhwc_stub,
diff --git a/aten/src/ATen/native/quantized/cuda/fake_quantize_core.cu b/aten/src/ATen/native/quantized/cuda/fake_quantize_core.cu
index e2f51398b48f..87937df546a8 100644
--- a/aten/src/ATen/native/quantized/cuda/fake_quantize_core.cu
+++ b/aten/src/ATen/native/quantized/cuda/fake_quantize_core.cu
@@ -68,6 +68,37 @@ void fake_quantize_grad_tensor_kernel_cuda(
   });
 }
 
+void fake_quantize_tensor_cachemask_kernel_cuda(
+    Tensor& output,
+    Tensor& mask,
+    const Tensor& input,
+    float scale,
+    int64_t zero_point,
+    int64_t quant_min,
+    int64_t quant_max) {
+
+  float inv_scale = 1.0f / scale;
+  auto iter = TensorIteratorConfig()
+    .check_all_same_dtype(false)
+    .add_output(output)
+    .add_output(mask)
+    .add_input(input)
+    .build();
+
+  gpu_kernel_multiple_outputs(
+    iter,
+    [=] GPU_LAMBDA (float input_val) -> thrust::tuple<float, bool> {
+      const auto qval = static_cast<int64_t>(std::nearbyint(input_val * inv_scale) + zero_point);
+      return {
+        // fake_quantized value
+        (fminf(quant_max, fmaxf(quant_min, qval)) - zero_point) * scale,
+        // mask for grad
+        ((quant_min <= qval) && (qval <= quant_max))
+      };
+    }
+  );
+}
+
 void _fake_quantize_grad_learnable_tensor_kernel_cuda(
     TensorIterator& iter,
     float scale,
@@ -96,6 +127,7 @@ void _fake_quantize_grad_learnable_tensor_kernel_cuda(
 }
 
 REGISTER_DISPATCH(fake_quant_tensor_stub, &fake_quantize_tensor_kernel_cuda);
+REGISTER_DISPATCH(fake_quant_tensor_cachemask_stub, &fake_quantize_tensor_cachemask_kernel_cuda);
 REGISTER_DISPATCH(fake_quant_grad_tensor_stub, &fake_quantize_grad_tensor_kernel_cuda);
 REGISTER_DISPATCH(fake_quant_grad_learnable_tensor_stub, &_fake_quantize_grad_learnable_tensor_kernel_cuda);
 
diff --git a/aten/src/ATen/native/quantized/fake_quant_affine.h b/aten/src/ATen/native/quantized/fake_quant_affine.h
index 7a90ff57ae1b..6865c75f4a49 100644
--- a/aten/src/ATen/native/quantized/fake_quant_affine.h
+++ b/aten/src/ATen/native/quantized/fake_quant_affine.h
@@ -26,6 +26,15 @@ using fake_quant_grad_tensor_fn = void (*)(
     int64_t quant_min,
     int64_t quant_max);
 
+using fake_quant_tensor_cachemask_fn = void (*)(
+    Tensor& output,
+    Tensor& mask,
+    const Tensor& input,
+    float sc,
+    int64_t z_point,
+
int64_t quant_min, + int64_t quant_max); + using fake_quant_learnable_grad_tensor_fn = void (*)( TensorIterator& iter, float scale, @@ -36,6 +45,7 @@ using fake_quant_learnable_grad_tensor_fn = void (*)( DECLARE_DISPATCH(fake_quant_tensor_fn, fake_quant_tensor_stub); DECLARE_DISPATCH(fake_quant_grad_tensor_fn, fake_quant_grad_tensor_stub); +DECLARE_DISPATCH(fake_quant_tensor_cachemask_fn, fake_quant_tensor_cachemask_stub); DECLARE_DISPATCH(fake_quant_learnable_grad_tensor_fn, fake_quant_grad_learnable_tensor_stub); using fake_quant_per_channel_fn = void (*)( diff --git a/aten/src/ATen/native/quantized/fake_quant_per_tensor_affine.cpp b/aten/src/ATen/native/quantized/fake_quant_per_tensor_affine.cpp index fb0853cf2ff2..a782033d5002 100644 --- a/aten/src/ATen/native/quantized/fake_quant_per_tensor_affine.cpp +++ b/aten/src/ATen/native/quantized/fake_quant_per_tensor_affine.cpp @@ -12,16 +12,19 @@ namespace native { // Use REGISTER_DISPATCH to run CPU and CUDA backend. DEFINE_DISPATCH(fake_quant_tensor_stub); DEFINE_DISPATCH(fake_quant_grad_tensor_stub); +DEFINE_DISPATCH(fake_quant_tensor_cachemask_stub); DEFINE_DISPATCH(fake_quant_grad_learnable_tensor_stub); /* Fake-quantizes the 'inputs' tensor. + Args: - X: Forward input tensor. + self: Forward input tensor. dY: Backward input tensor (_backward op only). scale: scale of per tensor affine quantization zero_point: zero_point of per tensor affine quantization quant_min: minimum quantized value quant_max: maximum quantized value + Returns: Quantized tensor (double dtype). @@ -50,22 +53,15 @@ Tensor fake_quantize_per_tensor_affine( /* Backward path to fake-quantize the 'inputs' tensor. Args: - X: Forward input tensor. dY: Backward input tensor. + X: Forward input tensor. scale: scale of per tensor affine quantization zero_point: zero_point of per tensor affine quantization quant_min: minimum quantized value quant_max: maximum quantized value - quant_delay: Count of global steps for which to delay the quantization. - See note in forward. - iter: The current quantization iteration used for `quant_delay`. + Returns: Quantized tensor (double dtype). - -Notes: - - quant_delay might be set to non-zero to help weights stabilize in the - beginning of the training. - - quantization range [0, 2^bits - 1] */ Tensor fake_quantize_per_tensor_affine_backward( @@ -95,6 +91,71 @@ Tensor fake_quantize_per_tensor_affine_backward( return dX; } +/* Fake-quantizes the 'inputs' tensor, saving a mask for the backward pass. + +This is numerically equivalent to `fake_quantize_per_tensor_affine`, +but has a lower memory overhead in the backward pass. + +Args: + self: Forward input tensor. + scale: scale of per tensor affine quantization + zero_point: zero_point of per tensor affine quantization + quant_min: minimum quantized value + quant_max: maximum quantized value + +Returns: + Quantized tensor (double dtype). + Mask (bool dtype). 
+*/
+std::tuple<Tensor, Tensor> fake_quantize_per_tensor_affine_cachemask(
+    const Tensor& self,
+    double scale,
+    int64_t zero_point,
+    int64_t quant_min,
+    int64_t quant_max) {
+  TORCH_CHECK(self.scalar_type() == ScalarType::Float);
+  TORCH_CHECK(
+      quant_min <= quant_max,
+      "`quant_min` should be less than or \
+        equal to `quant_max`.");
+  TORCH_CHECK(
+      zero_point >= quant_min && zero_point <= quant_max,
+      "`zero_point` must be between `quant_min` and `quant_max`.");
+
+  auto Y = at::empty_like(self, self.options(), MemoryFormat::Preserve);
+  auto mask = at::empty_like(self, at::kBool, MemoryFormat::Preserve);
+  fake_quant_tensor_cachemask_stub(
+      self.device().type(), Y, mask, self, scale, zero_point, quant_min, quant_max);
+  // TODO(future, optional): look into packing the mask further (BoolTensor uses
+  // 1 byte per element, we only need 1 bit per element).
+  return std::make_tuple(Y, mask);
+}
+
+/* Backward path to fake-quantize the 'inputs' tensor, with mask.
+
+Args:
+  dY: output grad.
+  mask: mask tensor from the forward pass.
+
+Returns:
+  dX (input grad).
+*/
+Tensor fake_quantize_per_tensor_affine_cachemask_backward(
+    const Tensor& dY,
+    const Tensor& mask) {
+  TORCH_CHECK(dY.scalar_type() == ScalarType::Float);
+  TORCH_CHECK(mask.scalar_type() == ScalarType::Bool);
+  TORCH_CHECK(mask.numel() == dY.numel(),
+      "`mask` and `dY` are not the same size: ",
+      "`mask` is size ", mask.numel(), " and `dY` is size ", dY.numel());
+  if (dY.numel() <= 0) {
+    return dY;
+  }
+  // Note: no additional kernels needed, since mask is pre-computed
+  // and we can use the existing tensor multiplication kernels.
+  return dY * mask;
+}
+
 int64_t _get_zero_point_from_tensor(
     const Tensor& zero_point,
     int64_t quant_min,
diff --git a/benchmarks/operator_benchmark/pt/quantization_test.py b/benchmarks/operator_benchmark/pt/quantization_test.py
index af09a5fa2523..a8377fb3e488 100644
--- a/benchmarks/operator_benchmark/pt/quantization_test.py
+++ b/benchmarks/operator_benchmark/pt/quantization_test.py
@@ -130,35 +130,38 @@ def forward(self, input, scales, zero_points, axis: int, dtype: int):
     'attr_names': ['N', 'C', 'H', 'W'],
     'attrs': [
         [1, 3, 512, 512],
-        [1, 3, 512, 512]
     ],
     'tags': ['short']
 }
 
 fake_quantize_configs_long_dict = {
     'N': [1],
-    'C': [1, 3, 8],
+    'C': [1, 3, 8, 32],
     'H': [256, 1024],
     'W': [256, 1024],
     'tags': ['long']
 }
 
 fake_quantize_configs_short = op_bench.config_list(
+    cross_product_configs={
+        'device': ('cpu', 'cuda'),
+    },
     **fake_quantize_configs_short_dict
 )
 
 fake_quantize_configs_long = op_bench.cross_product_configs(
+    device=('cpu', 'cuda'),
    **fake_quantize_configs_long_dict
 )
 
 class FakeQuantizeBenchmark(op_bench.TorchBenchmarkBase):
     r"""Benchmarks fake quantization with default parameters."""
-    def init(self, N, C, H, W):
+    def init(self, N, C, H, W, device):
         self.inputs = {
-            "input": torch.rand(N, C, H, W)
+            "input": torch.rand(N, C, H, W).to(device)
         }
-        self.op = tq.FakeQuantize()
+        self.op = tq.FakeQuantize().to(device)
         self.set_module_name('FakeQuantize')
 
     def forward(self, input):
@@ -169,6 +172,7 @@ def forward(self, input):
     fake_quantize_configs_short + fake_quantize_configs_long,
     FakeQuantizeBenchmark)
 
+
 # op_type is used to describe the type of operator used in benchmarking:
 #   py_module represents the operator written in Python that can
 #   backpropagate on scale and zero point.
diff --git a/test/quantization/test_workflow_module.py b/test/quantization/test_workflow_module.py index 866e1971ab19..869cd2cf3715 100644 --- a/test/quantization/test_workflow_module.py +++ b/test/quantization/test_workflow_module.py @@ -861,6 +861,65 @@ def test_backward_per_tensor(self, device, X): Y_prime.backward(dout) np.testing.assert_allclose(dX.cpu(), X.grad.cpu().detach().numpy(), rtol=tolerance, atol=tolerance) + def _test_forward_per_tensor_cachemask_impl(self, device): + for torch_type in (torch.qint8, torch.quint8): + X = torch.randn(4, 8).to(device) + # pick the scale + zp so that some values get clipped + obs = torch.quantization.MinMaxObserver(torch_type) + obs(X * 0.75) + scale, zero_point = obs.calculate_qparams() + scale, zero_point = float(scale), int(zero_point) + quant_min, quant_max = obs._calculate_qmin_qmax() + + Y_test, _mask = torch.fake_quantize_per_tensor_affine_cachemask( + X, scale, zero_point, quant_min, quant_max) + Y_ref = _fake_quantize_per_tensor_affine_reference( + X.cpu(), scale, zero_point, quant_min, quant_max).to(device) + self.assertTrue(torch.allclose(Y_test, Y_ref, rtol=tolerance, atol=tolerance)) + + def test_forward_per_tensor_cachemask_cpu(self): + device = torch.device('cpu') + self._test_forward_per_tensor_cachemask_impl(device) + + @unittest.skipIf(not TEST_CUDA, "No gpu is not available.") + def test_forward_per_tensor_cachemask_cuda(self): + device = torch.device('cuda') + self._test_forward_per_tensor_cachemask_impl(device) + + def _test_backward_per_tensor_cachemask_impl(self, device): + for torch_type in (torch.qint8, torch.quint8): + X = torch.randn(4, 8).to(device) + X.requires_grad_() + # pick the scale + zp so that some values get clipped + obs = torch.quantization.MinMaxObserver(torch_type) + obs(X * 0.75) + scale, zero_point = obs.calculate_qparams() + scale, zero_point = float(scale), int(zero_point) + quant_min, quant_max = obs._calculate_qmin_qmax() + + # forward pass + Y_test, mask = torch.fake_quantize_per_tensor_affine_cachemask( + X, scale, zero_point, quant_min, quant_max) + Y_ref = _fake_quantize_per_tensor_affine_reference( + X.cpu(), scale, zero_point, quant_min, quant_max).to(device) + self.assertTrue(torch.allclose(Y_test, Y_ref, rtol=tolerance, atol=tolerance)) + + # backward pass + dout = torch.rand(X.shape, dtype=torch.float).to(device) + dX = _fake_quantize_per_tensor_affine_grad_reference( + dout, X, scale, zero_point, quant_min, quant_max) + Y_test.backward(dout) + self.assertTrue(torch.allclose(dX, X.grad)) + + def test_backward_per_tensor_cachemask_cpu(self): + device = torch.device('cpu') + self._test_backward_per_tensor_cachemask_impl(device) + + @unittest.skipIf(not TEST_CUDA, "No gpu is not available.") + def test_backward_per_tensor_cachemask_cuda(self): + device = torch.device('cuda') + self._test_backward_per_tensor_cachemask_impl(device) + @given(device=st.sampled_from(['cpu', 'cuda'] if torch.cuda.is_available() else ['cpu']), X=hu.tensor(shapes=hu.array_shapes(1, 5,), elements=hu.floats(-1e3, 1e3, allow_nan=False, allow_infinity=False), diff --git a/test/test_namedtuple_return_api.py b/test/test_namedtuple_return_api.py index 00432c9e71cd..5ee70c0dacd1 100644 --- a/test/test_namedtuple_return_api.py +++ b/test/test_namedtuple_return_api.py @@ -14,7 +14,7 @@ 'max', 'min', 'median', 'nanmedian', 'mode', 'kthvalue', 'svd', 'symeig', 'eig', 'qr', 'geqrf', 'solve', 'slogdet', 'sort', 'topk', 'lstsq', 'triangular_solve', 'cummax', 'cummin', 'linalg_eigh', "unpack_dual", 'linalg_qr', - '_svd_helper', 
'linalg_svd', 'linalg_slogdet',
+    '_svd_helper', 'linalg_svd', 'linalg_slogdet', 'fake_quantize_per_tensor_affine_cachemask',
 }
 
@@ -68,6 +68,8 @@ def test_namedtuple_return(self):
             op(operators=['lstsq'], input=(a,), names=('solution', 'QR'), hasout=True),
             op(operators=['linalg_eigh'], input=("L",), names=('eigenvalues', 'eigenvectors'), hasout=True),
             op(operators=['linalg_slogdet'], input=(), names=('sign', 'logabsdet'), hasout=True),
+            op(operators=['fake_quantize_per_tensor_affine_cachemask'],
+               input=(0.1, 0, 0, 255), names=('output', 'mask',), hasout=False),
             op(operators=['unpack_dual'], input=(0,), names=('primal', 'tangent'), hasout=False),
         ]
 
diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml
index c199d5a4e9df..a9a751abd260 100644
--- a/tools/autograd/derivatives.yaml
+++ b/tools/autograd/derivatives.yaml
@@ -461,6 +461,9 @@
 - name: fake_quantize_per_tensor_affine(Tensor self, float scale, int zero_point, int quant_min, int quant_max) -> Tensor
   self: fake_quantize_per_tensor_affine_backward(grad, self, scale, zero_point, quant_min, quant_max)
 
+- name: fake_quantize_per_tensor_affine_cachemask(Tensor self, float scale, int zero_point, int quant_min, int quant_max) -> (Tensor output, Tensor mask)
+  self: fake_quantize_per_tensor_affine_cachemask_backward(grad, mask)
+
 - name: _fake_quantize_learnable_per_tensor_affine(Tensor self, Tensor scale, Tensor zero_point, int quant_min, int quant_max) -> Tensor
   self, scale, zero_point: "grad.defined() ? _fake_quantize_learnable_per_tensor_affine_backward(grad, self, scale, zero_point, quant_min, quant_max) : std::tuple<Tensor, Tensor, Tensor>()"
 
diff --git a/torch/overrides.py b/torch/overrides.py
index 1a5ebfb9a133..187bb0425dc7 100644
--- a/torch/overrides.py
+++ b/torch/overrides.py
@@ -391,6 +391,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
     torch.expm1: lambda input, out=None: -1,
     torch.fake_quantize_per_channel_affine: lambda input, scale, zero_point, axis, quant_min, quant_max: -1,
     torch.fake_quantize_per_tensor_affine: lambda input, scale, zero_point, quant_min, quant_max: -1,
+    torch.fake_quantize_per_tensor_affine_cachemask: lambda input, scale, zero_point, quant_min, quant_max: -1,
     torch.fbgemm_linear_fp16_weight: lambda input, packed_weight, bias: -1,
     torch.fbgemm_linear_fp16_weight_fp32_activation: lambda input, packed_weight, bias: -1,
     torch.fbgemm_linear_int8_weight: lambda input, weight, packed, col_offsets, weight_scale, weight_zero_point, bias: -1,
diff --git a/torch/quantization/fake_quantize.py b/torch/quantization/fake_quantize.py
index 46dba803a1ff..4d29db46acc4 100644
--- a/torch/quantization/fake_quantize.py
+++ b/torch/quantization/fake_quantize.py
@@ -140,9 +140,19 @@ def forward(self, X):
                 X = torch.fake_quantize_per_channel_affine(X, self.scale, self.zero_point,
                                                            self.ch_axis, self.quant_min, self.quant_max)
             else:
-                X = torch.fake_quantize_per_tensor_affine(X, float(self.scale),
-                                                          int(self.zero_point), self.quant_min,
-                                                          self.quant_max)
+                if self.training:
+                    # During training, use the memory optimized fake_quant
+                    # forward. It has a reduced memory overhead in the backward
+                    # pass compared to fake_quantize_per_tensor_affine.
+                    X, _mask = torch.fake_quantize_per_tensor_affine_cachemask(
+                        X, float(self.scale), int(self.zero_point),
+                        self.quant_min, self.quant_max)
+                else:
+                    # During inference, use the fastest fake_quant
+                    # which does not compute any extra info for the backward.
+ X = torch.fake_quantize_per_tensor_affine( + X, float(self.scale), int(self.zero_point), + self.quant_min, self.quant_max) return X @torch.jit.export
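
For reference, a minimal usage sketch of the new op pair introduced by this patch (not part of the diff; assumes a build that already includes these ops, and uses arbitrary example qparams):

```
import torch

x = torch.randn(4, 8, requires_grad=True)
scale, zero_point, quant_min, quant_max = 0.1, 0, 0, 255

# Existing op: autograd saves the input tensor for the backward pass.
y_ref = torch.fake_quantize_per_tensor_affine(
    x, scale, zero_point, quant_min, quant_max)

# New op: autograd only needs the bool mask, so the input can be freed earlier.
y, mask = torch.fake_quantize_per_tensor_affine_cachemask(
    x, scale, zero_point, quant_min, quant_max)

assert torch.allclose(y, y_ref)    # numerically equivalent forward
assert mask.dtype == torch.bool    # True where the value was not clamped

# Backward is computed as dX = dY * mask.
y.backward(torch.ones_like(y))
```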