fake_quant: add a more memory efficient version (#50561)
Summary:
Pull Request resolved: #50561

Not for review yet, a bunch of TODOs need finalizing.

tl;dr: add an alternative implementation of `fake_quantize` which saves
a mask during the forward pass and uses it to compute the backward
(a rough sketch of the intended semantics follows the list below).

There are two benefits:

1. the backward function no longer needs the input Tensor, so autograd can
free it earlier. On MobileNetV2, this reduces QAT overhead
by ~15% (TODO: link, and absolute numbers). We do add an extra mask Tensor
to pass around, but it is 4x smaller than the input tensor (1 byte per
element for bool vs 4 bytes for float32). A future optimization would be
to pack the mask bitwise and unpack it in the backward.

2. the computation of `qval` can be done only once in the forward and
reused in the backward. No perf change observed so far; TODO: verify with
better metrics.
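
For intuition, here is a rough sketch of the intended semantics in plain PyTorch
(helper names are made up for illustration; the actual implementation lives in the
ATen kernels changed below):

```
import torch

def fake_quant_cachemask_forward_sketch(x, scale, zero_point, quant_min, quant_max):
    # quantize, clamp, dequantize -- and remember which elements were in range
    q = torch.round(x / scale) + zero_point
    mask = (q >= quant_min) & (q <= quant_max)  # bool: 1 byte per element
    y = (torch.clamp(q, quant_min, quant_max) - zero_point) * scale
    return y, mask

def fake_quant_cachemask_backward_sketch(dy, mask):
    # the backward only needs the cached mask, not the input tensor,
    # so autograd can free the input earlier
    return dy * mask
```

Because the backward consumes only `mask`, the activation can be released as soon
as the forward finishes; the mask costs 1 byte per element versus 4 bytes for a
float32 activation.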

TODO: describe in more detail

Test Plan:
OSS / torchvision / MobileNetV2
```
python references/classification/train_quantization.py \
  --print-freq 1 \
  --data-path /data/local/packages/ai-group.imagenet-256-smallest-side/prod/ \
  --output-dir ~/nfs/pytorch_vision_tests/ \
  --backend qnnpack \
  --epochs 5
TODO paste results here
```

TODO more

Imported from OSS

Reviewed By: ngimel

Differential Revision: D25918519

fbshipit-source-id: ec544ca063f984de0f765bf833f205c99d6c18b6
vkuzo authored and facebook-github-bot committed Jan 28, 2021
1 parent d14d8c7 commit 983b8e6
Showing 11 changed files with 243 additions and 19 deletions.
8 changes: 8 additions & 0 deletions aten/src/ATen/native/native_functions.yaml
@@ -4642,6 +4642,14 @@
- func: fake_quantize_per_tensor_affine_backward(Tensor grad, Tensor self, float scale, int zero_point, int quant_min, int quant_max) -> Tensor
variants: function

- func: fake_quantize_per_tensor_affine_cachemask(Tensor self, float scale, int zero_point, int quant_min, int quant_max) -> (Tensor output, Tensor mask)
variants: function
dispatch:
CPU, CUDA: fake_quantize_per_tensor_affine_cachemask

- func: fake_quantize_per_tensor_affine_cachemask_backward(Tensor grad, Tensor mask) -> Tensor
variants: function

- func: _fake_quantize_learnable_per_tensor_affine(Tensor self, Tensor scale, Tensor zero_point, int quant_min, int quant_max) -> Tensor
variants: function
dispatch:
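
With the schema above registered, the new ops become callable from Python (the new
tests below call the forward this way); a minimal sketch with made-up quantization
parameters -- the explicit backward call is shown only for illustration, since
autograd normally invokes it through the registered derivative:

```
import torch

x = torch.randn(4, 8)
scale, zero_point, quant_min, quant_max = 0.1, 0, 0, 255  # illustrative qparams

y, mask = torch.fake_quantize_per_tensor_affine_cachemask(
    x, scale, zero_point, quant_min, quant_max)

# the backward op consumes only the output gradient and the saved mask
dy = torch.ones_like(x)
dx = torch.fake_quantize_per_tensor_affine_cachemask_backward(dy, mask)
```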
34 changes: 34 additions & 0 deletions aten/src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp
@@ -2091,6 +2091,38 @@ void fake_quantize_grad_tensor_kernel(
});
}

void fake_quantize_tensor_cachemask_kernel(
Tensor& output,
Tensor& mask,
const Tensor& input,
float sc,
int64_t z_point,
int64_t quant_min,
int64_t quant_max) {
float inv_scale = 1.0f / sc;

auto iter_combined = TensorIteratorConfig()
.check_all_same_dtype(false)
.add_output(output)
.add_output(mask)
.add_input(input)
.build();

// TODO(#51090): make it work for other dtypes
iter_combined.for_each([&](char** data, const int64_t* strides, int64_t n) {
for (int64_t i = 0; i < n; i++) {
float* output_val = (float*)(data[0] + i * strides[0]);
bool* mask_val = (bool*)(data[1] + i * strides[1]);
float* input_val = (float*)(data[2] + i * strides[2]);

const auto qval = static_cast<int64_t>(z_point + std::nearbyint(*input_val * inv_scale));
*output_val = (std::fmin(std::fmax(qval, quant_min), quant_max) - z_point) * sc;
*mask_val = ((quant_min <= qval) && (qval <= quant_max));
}
});

}

void fake_quantize_learnable_tensor_grad_kernel_cpu(
TensorIterator& iter,
float scale,
@@ -3054,6 +3086,8 @@ REGISTER_DISPATCH(fake_quant_grad_tensor_stub,
&fake_quantize_grad_tensor_kernel);
REGISTER_DISPATCH(fake_quant_per_channel_stub, &fake_quant_per_channel_cpu);
REGISTER_DISPATCH(fake_quant_tensor_stub, &fake_quantize_tensor_kernel);
REGISTER_DISPATCH(fake_quant_tensor_cachemask_stub,
&fake_quantize_tensor_cachemask_kernel);
REGISTER_DISPATCH(qadaptive_avg_pool2d_nhwc_stub,
&qadaptive_avg_pool2d_nhwc_kernel);
REGISTER_DISPATCH(qadaptive_avg_pool3d_ndhwc_stub,
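
To make the per-element math in `fake_quantize_tensor_cachemask_kernel` above
concrete, a small worked example (parameter values are illustrative:
scale=0.1, zero_point=0, quant_min=0, quant_max=255):

```
scale, zero_point, quant_min, quant_max = 0.1, 0, 0, 255
inv_scale = 1.0 / scale

for x in (1.27, 30.0, -0.2):
    qval = int(zero_point + round(x * inv_scale))
    y = (min(max(qval, quant_min), quant_max) - zero_point) * scale
    mask = quant_min <= qval <= quant_max
    print(x, qval, round(y, 4), mask)

# 1.27 -> qval 13,  y ~1.3,  mask True   (in range: gradient will pass)
# 30.0 -> qval 300, y 25.5,  mask False  (clipped above: gradient blocked)
# -0.2 -> qval -2,  y 0.0,   mask False  (clipped below: gradient blocked)
```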
32 changes: 32 additions & 0 deletions aten/src/ATen/native/quantized/cuda/fake_quantize_core.cu
@@ -68,6 +68,37 @@ void fake_quantize_grad_tensor_kernel_cuda(
});
}

void fake_quantize_tensor_cachemask_kernel_cuda(
Tensor& output,
Tensor& mask,
const Tensor& input,
float scale,
int64_t zero_point,
int64_t quant_min,
int64_t quant_max) {

float inv_scale = 1.0f / scale;
auto iter = TensorIteratorConfig()
.check_all_same_dtype(false)
.add_output(output)
.add_output(mask)
.add_input(input)
.build();

gpu_kernel_multiple_outputs(
iter,
[=] GPU_LAMBDA (float input_val) -> thrust::tuple<float, bool> {
const auto qval = static_cast<int64_t>(std::nearbyint(input_val * inv_scale) + zero_point);
return {
// fake_quantized value
(fminf(quant_max, fmaxf(quant_min, qval)) - zero_point) * scale,
// mask for grad
((quant_min <= qval) && (qval <= quant_max))
};
}
);
}

void _fake_quantize_grad_learnable_tensor_kernel_cuda(
TensorIterator& iter,
float scale,
@@ -96,6 +127,7 @@ void _fake_quantize_grad_learnable_tensor_kernel_cuda(
}

REGISTER_DISPATCH(fake_quant_tensor_stub, &fake_quantize_tensor_kernel_cuda);
REGISTER_DISPATCH(fake_quant_tensor_cachemask_stub, &fake_quantize_tensor_cachemask_kernel_cuda);
REGISTER_DISPATCH(fake_quant_grad_tensor_stub, &fake_quantize_grad_tensor_kernel_cuda);
REGISTER_DISPATCH(fake_quant_grad_learnable_tensor_stub, &_fake_quantize_grad_learnable_tensor_kernel_cuda);

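
A quick parity check one might run between the CPU and CUDA paths (hypothetical
snippet, assumes a CUDA device is available and uses made-up qparams):

```
import torch

x = torch.randn(1024)
args = (0.25, 0, 0, 255)  # illustrative scale, zero_point, quant_min, quant_max

y_cpu, m_cpu = torch.fake_quantize_per_tensor_affine_cachemask(x, *args)
y_gpu, m_gpu = torch.fake_quantize_per_tensor_affine_cachemask(x.cuda(), *args)

assert torch.allclose(y_cpu, y_gpu.cpu())
assert torch.equal(m_cpu, m_gpu.cpu())
```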
10 changes: 10 additions & 0 deletions aten/src/ATen/native/quantized/fake_quant_affine.h
@@ -26,6 +26,15 @@ using fake_quant_grad_tensor_fn = void (*)(
int64_t quant_min,
int64_t quant_max);

using fake_quant_tensor_cachemask_fn = void (*)(
Tensor& output,
Tensor& mask,
const Tensor& input,
float sc,
int64_t z_point,
int64_t quant_min,
int64_t quant_max);

using fake_quant_learnable_grad_tensor_fn = void (*)(
TensorIterator& iter,
float scale,
@@ -36,6 +45,7 @@ using fake_quant_learnable_grad_tensor_fn = void (*)(

DECLARE_DISPATCH(fake_quant_tensor_fn, fake_quant_tensor_stub);
DECLARE_DISPATCH(fake_quant_grad_tensor_fn, fake_quant_grad_tensor_stub);
DECLARE_DISPATCH(fake_quant_tensor_cachemask_fn, fake_quant_tensor_cachemask_stub);
DECLARE_DISPATCH(fake_quant_learnable_grad_tensor_fn, fake_quant_grad_learnable_tensor_stub);

using fake_quant_per_channel_fn = void (*)(
81 changes: 71 additions & 10 deletions aten/src/ATen/native/quantized/fake_quant_per_tensor_affine.cpp
@@ -12,16 +12,19 @@ namespace native {
// Use REGISTER_DISPATCH to run CPU and CUDA backend.
DEFINE_DISPATCH(fake_quant_tensor_stub);
DEFINE_DISPATCH(fake_quant_grad_tensor_stub);
DEFINE_DISPATCH(fake_quant_tensor_cachemask_stub);
DEFINE_DISPATCH(fake_quant_grad_learnable_tensor_stub);

/* Fake-quantizes the 'inputs' tensor.
Args:
self: Forward input tensor.
dY: Backward input tensor (_backward op only).
scale: scale of per tensor affine quantization
zero_point: zero_point of per tensor affine quantization
quant_min: minimum quantized value
quant_max: maximum quantized value
Returns:
Quantized tensor (double dtype).
@@ -50,22 +53,15 @@ Tensor fake_quantize_per_tensor_affine(
/* Backward path to fake-quantize the 'inputs' tensor.
Args:
dY: Backward input tensor.
X: Forward input tensor.
scale: scale of per tensor affine quantization
zero_point: zero_point of per tensor affine quantization
quant_min: minimum quantized value
quant_max: maximum quantized value
quant_delay: Count of global steps for which to delay the quantization.
See note in forward.
iter: The current quantization iteration used for `quant_delay`.
Returns:
Quantized tensor (double dtype).
Notes:
- quant_delay might be set to non-zero to help weights stabilize in the
beginning of the training.
- quantization range [0, 2^bits - 1]
*/

Tensor fake_quantize_per_tensor_affine_backward(
@@ -95,6 +91,71 @@ Tensor fake_quantize_per_tensor_affine_backward(
return dX;
}

/* Fake-quantizes the 'inputs' tensor, saving a mask for the backward pass.
This is numerically equivalent to `fake_quantize_per_tensor_affine`,
but has a lower memory overhead in the backward pass.
Args:
self: Forward input tensor.
scale: scale of per tensor affine quantization
zero_point: zero_point of per tensor affine quantization
quant_min: minimum quantized value
quant_max: maximum quantized value
Returns:
Quantized tensor (double dtype).
Mask (bool dtype).
*/
std::tuple<Tensor, Tensor> fake_quantize_per_tensor_affine_cachemask(
const Tensor& self,
double scale,
int64_t zero_point,
int64_t quant_min,
int64_t quant_max) {
TORCH_CHECK(self.scalar_type() == ScalarType::Float);
TORCH_CHECK(
quant_min <= quant_max,
"`quant_min` should be less than or \
equal to `quant_max`.");
TORCH_CHECK(
zero_point >= quant_min && zero_point <= quant_max,
"`zero_point` must be between `quant_min` and `quant_max`.");

auto Y = at::empty_like(self, self.options(), MemoryFormat::Preserve);
auto mask = at::empty_like(self, at::kBool, MemoryFormat::Preserve);
fake_quant_tensor_cachemask_stub(
self.device().type(), Y, mask, self, scale, zero_point, quant_min, quant_max);
// TODO(future, optional): look into packing the mask further (BoolTensor uses
// 1 byte per element, we only need 1 bit per element).
return std::make_tuple(Y, mask);
}

/* Backward path to fake-quantize the 'inputs' tensor, with mask.
Args:
dY: output grad.
mask: mask tensor from the forward pass.
Returns:
dX (input grad).
*/
Tensor fake_quantize_per_tensor_affine_cachemask_backward(
const Tensor& dY,
const Tensor& mask) {
TORCH_CHECK(dY.scalar_type() == ScalarType::Float);
TORCH_CHECK(mask.scalar_type() == ScalarType::Bool);
TORCH_CHECK(mask.numel() == dY.numel(),
"`mask` and `dY` are not the same size: ",
"`mask` is size ", mask.numel(), " and `dY` is size ", dY.numel());
if (dY.numel() <= 0) {
return dY;
}
// Note: no additional kernels needed, since mask is pre-computed
// and we can use the existing tensor multiplication kernels.
return dY * mask;
}

int64_t _get_zero_point_from_tensor(
const Tensor& zero_point,
int64_t quant_min,
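
Since the new op is documented above as numerically equivalent to
`fake_quantize_per_tensor_affine`, a sanity check one could run (parameters are
illustrative):

```
import torch

x = torch.randn(4, 8)
scale, zero_point, quant_min, quant_max = 0.1, 2, 0, 255

y_ref = torch.fake_quantize_per_tensor_affine(
    x, scale, zero_point, quant_min, quant_max)
y_new, _mask = torch.fake_quantize_per_tensor_affine_cachemask(
    x, scale, zero_point, quant_min, quant_max)

assert torch.allclose(y_ref, y_new)
```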
14 changes: 9 additions & 5 deletions benchmarks/operator_benchmark/pt/quantization_test.py
@@ -130,35 +130,38 @@ def forward(self, input, scales, zero_points, axis: int, dtype: int):
'attr_names': ['N', 'C', 'H', 'W'],
'attrs': [
[1, 3, 512, 512],
[1, 3, 512, 512]
],
'tags': ['short']
}

fake_quantize_configs_long_dict = {
'N': [1],
'C': [1, 3, 8, 32],
'H': [256, 1024],
'W': [256, 1024],
'tags': ['long']
}

fake_quantize_configs_short = op_bench.config_list(
cross_product_configs={
'device': ('cpu', 'cuda'),
},
**fake_quantize_configs_short_dict
)

fake_quantize_configs_long = op_bench.cross_product_configs(
device=('cpu', 'cuda'),
**fake_quantize_configs_long_dict
)


class FakeQuantizeBenchmark(op_bench.TorchBenchmarkBase):
r"""Benchmarks fake quantization with default parameters."""
def init(self, N, C, H, W, device):
self.inputs = {
"input": torch.rand(N, C, H, W).to(device)
}
self.op = tq.FakeQuantize().to(device)
self.set_module_name('FakeQuantize')

def forward(self, input):
@@ -169,6 +172,7 @@ def forward(self, input):
fake_quantize_configs_short + fake_quantize_configs_long,
FakeQuantizeBenchmark)


# op_type is used to describe the type of operator used in benchmarking:
# py_module represents the operator written in Python that can
# backpropagate on scale and zero point.
59 changes: 59 additions & 0 deletions test/quantization/test_workflow_module.py
@@ -861,6 +861,65 @@ def test_backward_per_tensor(self, device, X):
Y_prime.backward(dout)
np.testing.assert_allclose(dX.cpu(), X.grad.cpu().detach().numpy(), rtol=tolerance, atol=tolerance)

def _test_forward_per_tensor_cachemask_impl(self, device):
for torch_type in (torch.qint8, torch.quint8):
X = torch.randn(4, 8).to(device)
# pick the scale + zp so that some values get clipped
obs = torch.quantization.MinMaxObserver(torch_type)
obs(X * 0.75)
scale, zero_point = obs.calculate_qparams()
scale, zero_point = float(scale), int(zero_point)
quant_min, quant_max = obs._calculate_qmin_qmax()

Y_test, _mask = torch.fake_quantize_per_tensor_affine_cachemask(
X, scale, zero_point, quant_min, quant_max)
Y_ref = _fake_quantize_per_tensor_affine_reference(
X.cpu(), scale, zero_point, quant_min, quant_max).to(device)
self.assertTrue(torch.allclose(Y_test, Y_ref, rtol=tolerance, atol=tolerance))

def test_forward_per_tensor_cachemask_cpu(self):
device = torch.device('cpu')
self._test_forward_per_tensor_cachemask_impl(device)

@unittest.skipIf(not TEST_CUDA, "No GPU is available.")
def test_forward_per_tensor_cachemask_cuda(self):
device = torch.device('cuda')
self._test_forward_per_tensor_cachemask_impl(device)

def _test_backward_per_tensor_cachemask_impl(self, device):
for torch_type in (torch.qint8, torch.quint8):
X = torch.randn(4, 8).to(device)
X.requires_grad_()
# pick the scale + zp so that some values get clipped
obs = torch.quantization.MinMaxObserver(torch_type)
obs(X * 0.75)
scale, zero_point = obs.calculate_qparams()
scale, zero_point = float(scale), int(zero_point)
quant_min, quant_max = obs._calculate_qmin_qmax()

# forward pass
Y_test, mask = torch.fake_quantize_per_tensor_affine_cachemask(
X, scale, zero_point, quant_min, quant_max)
Y_ref = _fake_quantize_per_tensor_affine_reference(
X.cpu(), scale, zero_point, quant_min, quant_max).to(device)
self.assertTrue(torch.allclose(Y_test, Y_ref, rtol=tolerance, atol=tolerance))

# backward pass
dout = torch.rand(X.shape, dtype=torch.float).to(device)
dX = _fake_quantize_per_tensor_affine_grad_reference(
dout, X, scale, zero_point, quant_min, quant_max)
Y_test.backward(dout)
self.assertTrue(torch.allclose(dX, X.grad))

def test_backward_per_tensor_cachemask_cpu(self):
device = torch.device('cpu')
self._test_backward_per_tensor_cachemask_impl(device)

@unittest.skipIf(not TEST_CUDA, "No GPU is available.")
def test_backward_per_tensor_cachemask_cuda(self):
device = torch.device('cuda')
self._test_backward_per_tensor_cachemask_impl(device)

@given(device=st.sampled_from(['cpu', 'cuda'] if torch.cuda.is_available() else ['cpu']),
X=hu.tensor(shapes=hu.array_shapes(1, 5,),
elements=hu.floats(-1e3, 1e3, allow_nan=False, allow_infinity=False),
4 changes: 3 additions & 1 deletion test/test_namedtuple_return_api.py
@@ -14,7 +14,7 @@
'max', 'min', 'median', 'nanmedian', 'mode', 'kthvalue', 'svd', 'symeig', 'eig',
'qr', 'geqrf', 'solve', 'slogdet', 'sort', 'topk', 'lstsq',
'triangular_solve', 'cummax', 'cummin', 'linalg_eigh', "unpack_dual", 'linalg_qr',
'_svd_helper', 'linalg_svd', 'linalg_slogdet',
'_svd_helper', 'linalg_svd', 'linalg_slogdet', 'fake_quantize_per_tensor_affine_cachemask',
}


@@ -68,6 +68,8 @@ def test_namedtuple_return(self):
op(operators=['lstsq'], input=(a,), names=('solution', 'QR'), hasout=True),
op(operators=['linalg_eigh'], input=("L",), names=('eigenvalues', 'eigenvectors'), hasout=True),
op(operators=['linalg_slogdet'], input=(), names=('sign', 'logabsdet'), hasout=True),
op(operators=['fake_quantize_per_tensor_affine_cachemask'],
input=(0.1, 0, 0, 255), names=('output', 'mask',), hasout=False),
op(operators=['unpack_dual'], input=(0,), names=('primal', 'tangent'), hasout=False),
]

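
Because the op is whitelisted here as a namedtuple-returning function with fields
`output` and `mask`, the results can also be accessed by name; a small illustration
(qparams are made up):

```
import torch

res = torch.fake_quantize_per_tensor_affine_cachemask(
    torch.randn(2, 2), 0.1, 0, 0, 255)
print(res.output)  # fake-quantized tensor
print(res.mask)    # bool mask consumed by the backward
```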
3 changes: 3 additions & 0 deletions tools/autograd/derivatives.yaml
@@ -461,6 +461,9 @@
- name: fake_quantize_per_tensor_affine(Tensor self, float scale, int zero_point, int quant_min, int quant_max) -> Tensor
self: fake_quantize_per_tensor_affine_backward(grad, self, scale, zero_point, quant_min, quant_max)

- name: fake_quantize_per_tensor_affine_cachemask(Tensor self, float scale, int zero_point, int quant_min, int quant_max) -> (Tensor output, Tensor mask)
self: fake_quantize_per_tensor_affine_cachemask_backward(grad, mask)

- name: _fake_quantize_learnable_per_tensor_affine(Tensor self, Tensor scale, Tensor zero_point, int quant_min, int quant_max) -> Tensor
self, scale, zero_point: "grad.defined() ? _fake_quantize_learnable_per_tensor_affine_backward(grad, self, scale, zero_point, quant_min, quant_max) : std::tuple<Tensor, Tensor, Tensor>()"

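
With this derivative registered, autograd routes the output gradient through the
cached mask instead of recomputing anything from the input; a sketch of the
end-to-end behavior (illustrative parameters):

```
import torch

x = torch.randn(4, 8, requires_grad=True)
scale, zero_point, quant_min, quant_max = 0.1, 0, 0, 255

y, mask = torch.fake_quantize_per_tensor_affine_cachemask(
    x, scale, zero_point, quant_min, quant_max)

dout = torch.rand_like(y)
y.backward(dout)

# gradient flows only where the quantized value stayed inside [quant_min, quant_max]
assert torch.allclose(x.grad, dout * mask)
```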
