fake_quant: more memory efficient per-channel backward
Summary:

This is the same change as #50561, but for per-channel fake_quant: the forward pass now also returns a boolean mask recording which elements landed inside the quantization range, so the backward pass reduces to an elementwise multiply of the incoming gradient with that mask instead of re-running the quantization math.

TODO before land: expand this write-up.
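
A minimal usage sketch of the new op pair (shapes, qparams, and quant range below are illustrative; the `derivatives.yaml` entry added in this PR wires the mask-based backward up automatically):

```python
import torch

x = torch.randn(2, 3, 4, requires_grad=True)
scale = torch.tensor([0.1, 0.2, 0.3])           # one scale per channel along axis 1
zero_point = torch.zeros(3, dtype=torch.int64)  # must be int64, one per channel

# forward: fake-quantized output plus a bool mask of values inside [quant_min, quant_max]
y, mask = torch.fake_quantize_per_channel_affine_cachemask(
    x, scale, zero_point, 1, 0, 255)

# backward: just dY * mask, no re-quantization of the input
y.sum().backward()
# x.grad is 1.0 where mask is True and 0.0 where the value was clipped
```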

Test Plan:

```
python test/test_quantization.py TestFakeQuantize.test_forward_per_channel_cachemask_cpu
python test/test_quantization.py TestFakeQuantize.test_forward_per_channel_cachemask_cuda
python test/test_quantization.py TestFakeQuantize.test_backward_per_channel_cachemask_cpu
python test/test_quantization.py TestFakeQuantize.test_backward_per_channel_cachemask_cuda
```

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: 7733441a1ed07b2b9ceea9aa0db0bfac6b961a64
Pull Request resolved: #51255
vkuzo committed Jan 28, 2021
1 parent 12a434a commit 9ac5fa1
Showing 10 changed files with 238 additions and 2 deletions.
8 changes: 8 additions & 0 deletions aten/src/ATen/native/native_functions.yaml
@@ -4663,6 +4663,14 @@
- func: fake_quantize_per_channel_affine_backward(Tensor grad, Tensor self, Tensor scale, Tensor zero_point, int axis, int quant_min, int quant_max) -> Tensor
  variants: function

- func: fake_quantize_per_channel_affine_cachemask(Tensor self, Tensor scale, Tensor zero_point, int axis, int quant_min, int quant_max) -> (Tensor output, Tensor mask)
  variants: function
  dispatch:
    CPU, CUDA: fake_quantize_per_channel_affine_cachemask

- func: fake_quantize_per_channel_affine_cachemask_backward(Tensor grad, Tensor mask) -> Tensor
  variants: function

- func: _fake_quantize_learnable_per_channel_affine(Tensor self, Tensor scale, Tensor zero_point, int axis, int quant_min, int quant_max) -> Tensor
  variants: function
  dispatch:
30 changes: 30 additions & 0 deletions aten/src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp
@@ -2151,6 +2151,35 @@ void fake_quant_per_channel_cpu(
  });
}

void fake_quant_per_channel_cachemask_cpu(
    TensorIterator& iter,
    TensorIterator& iter_mask,
    int64_t quant_min,
    int64_t quant_max) {
  // TODO(future, optional): read once, write twice. Not done at the moment
  // for simplicity, as we do not expect this to be a bottleneck.

  // write mask
  cpu_kernel(iter_mask, [=](float self, float scale, int64_t zero_point) -> bool {
    float inv_scale = 1.0f / scale;
    const auto qval = static_cast<int64_t>(zero_point + std::nearbyint(self * inv_scale));
    return ((quant_min <= qval) && (qval <= quant_max));
  });

  // write fake_quant
  cpu_kernel(iter, [=](float self, float scale, int64_t zero_point) -> float {
    float inv_scale = 1.0f / scale;
    return (std::fmin(
                std::fmax(
                    static_cast<int64_t>(
                        zero_point + std::nearbyint(self * inv_scale)),
                    quant_min),
                quant_max) -
            zero_point) *
        scale;
  });
}

void fake_quant_grad_per_channel_cpu(
TensorIterator& iter,
int64_t quant_min,
@@ -3046,6 +3075,7 @@ REGISTER_DISPATCH(fake_quant_grad_learnable_tensor_stub,
REGISTER_DISPATCH(fake_quant_grad_per_channel_stub,
&fake_quant_grad_per_channel_cpu);
REGISTER_DISPATCH(fake_quant_per_channel_stub, &fake_quant_per_channel_cpu);
REGISTER_DISPATCH(fake_quant_per_channel_cachemask_stub, &fake_quant_per_channel_cachemask_cpu);
REGISTER_DISPATCH(fake_quant_tensor_cachemask_stub,
&fake_quantize_tensor_cachemask_kernel);
REGISTER_DISPATCH(qadaptive_avg_pool2d_nhwc_stub,
30 changes: 30 additions & 0 deletions aten/src/ATen/native/quantized/cuda/fake_quantize_core.cu
@@ -99,6 +99,35 @@ void fake_quant_per_channel_cuda(TensorIterator &iter, int64_t quant_min, int64_
  });
}

void fake_quant_per_channel_cachemask_cuda(
    TensorIterator &iter, TensorIterator &iter_mask, int64_t quant_min, int64_t quant_max) {
  // TODO(future, optional): read once, write twice. Not done at the moment
  // for simplicity, as we do not expect this to be a bottleneck.

  // write mask
  gpu_kernel(iter_mask,
    [=] GPU_LAMBDA (float input_val, float scale, int64_t zero_point) -> bool {
      float inv_scale = 1.0f / scale;
      const auto qval = static_cast<int64_t>(std::nearbyint(input_val * inv_scale) + zero_point);
      return ((quant_min <= qval) && (qval <= quant_max));
    });

  // write fake_quant
  gpu_kernel(iter,
    [=] GPU_LAMBDA (float input_val, float scale, int64_t zero_point) -> float {
      float inv_scale = 1.0f / scale;
      return (fminf(
                  quant_max,
                  fmaxf(
                      quant_min,
                      static_cast<int64_t>(
                          std::nearbyint(input_val * inv_scale) +
                          zero_point))) -
              zero_point) *
          scale;
    });
}

void fake_quant_grad_per_channel_cuda(TensorIterator &iter, int64_t quant_min, int64_t quant_max) {
gpu_kernel(iter,
[=] GPU_LAMBDA (float x, float dy, float scale, int64_t zero_point) -> float {
@@ -134,6 +163,7 @@ void _fake_quantize_grad_learnable_channel_kernel_cuda(TensorIterator &iter, int

REGISTER_DISPATCH(fake_quant_per_channel_stub, &fake_quant_per_channel_cuda);
REGISTER_DISPATCH(fake_quant_grad_per_channel_stub, &fake_quant_grad_per_channel_cuda);
REGISTER_DISPATCH(fake_quant_per_channel_cachemask_stub, &fake_quant_per_channel_cachemask_cuda);
REGISTER_DISPATCH(fake_quant_grad_learnable_channel_stub, &_fake_quantize_grad_learnable_channel_kernel_cuda);

} // namespace native
7 changes: 7 additions & 0 deletions aten/src/ATen/native/quantized/fake_quant_affine.h
@@ -34,8 +34,15 @@ using fake_quant_per_channel_fn = void (*)(
    int64_t quant_min,
    int64_t quant_max);

using fake_quant_per_channel_cachemask_fn = void (*)(
    TensorIterator &iter,
    TensorIterator &iter_mask,
    int64_t quant_min,
    int64_t quant_max);

DECLARE_DISPATCH(fake_quant_per_channel_fn, fake_quant_per_channel_stub);
DECLARE_DISPATCH(fake_quant_per_channel_fn, fake_quant_grad_per_channel_stub);
DECLARE_DISPATCH(fake_quant_per_channel_cachemask_fn, fake_quant_per_channel_cachemask_stub);
DECLARE_DISPATCH(fake_quant_per_channel_fn, fake_quant_grad_learnable_channel_stub);

} // namespace native
91 changes: 91 additions & 0 deletions aten/src/ATen/native/quantized/fake_quant_per_channel_affine.cpp
@@ -12,6 +12,7 @@ namespace native {
// Use REGISTER_DISPATCH to run CPU and CUDA backend.
DEFINE_DISPATCH(fake_quant_per_channel_stub);
DEFINE_DISPATCH(fake_quant_grad_per_channel_stub);
DEFINE_DISPATCH(fake_quant_per_channel_cachemask_stub);
DEFINE_DISPATCH(fake_quant_grad_learnable_channel_stub);

/* Per channel fake-quantizes the 'inputs' tensor.
@@ -81,6 +82,96 @@ Tensor fake_quantize_per_channel_affine(
  return Y;
}

std::tuple<Tensor, Tensor> fake_quantize_per_channel_affine_cachemask(
    const Tensor& self,
    const Tensor& scale,
    const Tensor& zero_point,
    int64_t axis,
    int64_t quant_min,
    int64_t quant_max) {
  TORCH_CHECK(self.scalar_type() == ScalarType::Float);
  TORCH_CHECK(scale.scalar_type() == ScalarType::Float,
              "Scale must be Float, found ", scale.scalar_type());
  TORCH_CHECK(zero_point.scalar_type() == ScalarType::Long,
              "Zero-point must be Long, found ", zero_point.scalar_type());
  TORCH_CHECK(scale.dim() == 1, "scale should be a 1-D tensor");
  TORCH_CHECK(zero_point.dim() == 1, "zero point should be a 1-D tensor");
  TORCH_CHECK(
      scale.numel() == zero_point.numel(),
      "scale and zero-point need to have the same dimensions");
  TORCH_CHECK(
      scale.numel() == self.size(axis),
      "dimensions of scale and zero-point are not consistent with input tensor")

  TORCH_CHECK(
      quant_min <= quant_max,
      "`quant_min` should be less than or \
        equal to `quant_max`.");

  TORCH_CHECK(
      at::min(zero_point).item().toLong() >= quant_min &&
          at::max(zero_point).item().toLong() <= quant_max,
      "`zero_point` must be between `quant_min` and `quant_max`.");

  TORCH_CHECK(
      axis >= 0 && axis <= self.dim(),
      "`axis` must be between 0 and number of dimensions of input");

  auto Y = at::empty_like(self, self.options(), MemoryFormat::Preserve);
  auto mask = at::empty_like(self, at::kBool, MemoryFormat::Preserve);

  std::vector<int64_t> expected_shape(self.dim(), 1);
  expected_shape[axis] = self.size(axis);

  TensorIterator iter = TensorIteratorConfig()
    .check_all_same_dtype(false)
    .add_output(Y)
    .add_input(self)
    .add_input(native::_unsafe_view(scale, expected_shape))
    .add_input(native::_unsafe_view(zero_point, expected_shape))
    .build();

  // TODO(future, optional): read once, write twice. Not done at the moment
  // for simplicity, as we do not expect this to be a bottleneck.
  TensorIterator iter_mask = TensorIteratorConfig()
    .check_all_same_dtype(false)
    .add_output(mask)
    .add_input(self)
    .add_input(native::_unsafe_view(scale, expected_shape))
    .add_input(native::_unsafe_view(zero_point, expected_shape))
    .build();

  // TODO(future, optional): look into packing the mask further (BoolTensor uses
  // 1 byte per element, we only need 1 bit per element).
  fake_quant_per_channel_cachemask_stub(iter.device_type(), iter, iter_mask, quant_min, quant_max);
  return std::make_tuple(Y, mask);
}

/* Backward path to fake-quantize the 'inputs' tensor per channel, with mask.
Args:
  dY: output grad.
  mask: mask tensor from the forward pass.
Returns:
  dX (input grad).
*/
Tensor fake_quantize_per_channel_affine_cachemask_backward(
    const Tensor& dY,
    const Tensor& mask) {
  TORCH_CHECK(dY.scalar_type() == ScalarType::Float);
  TORCH_CHECK(mask.scalar_type() == ScalarType::Bool);
  TORCH_CHECK(mask.numel() == dY.numel(),
      "`mask` and `dY` are not the same size: ",
      "`mask` is size ", mask.numel(), " and `dY` is size ", dY.numel());
  if (dY.numel() <= 0) {
    return dY;
  }
  // Note: no additional kernels needed, since mask is pre-computed
  // and we can use the existing tensor multiplication kernels.
  return dY * mask;
}

/* Backward path for per-channel fake-quantization of the 'inputs' tensor.
Args:
54 changes: 54 additions & 0 deletions test/quantization/test_workflow_module.py
@@ -1278,6 +1278,32 @@ def test_forward_per_channel(self, device, X):
            X, scale, zero_point, axis, quant_min, quant_max)
        np.testing.assert_allclose(Y, Y_prime.cpu(), rtol=tolerance, atol=tolerance)

    def _test_forward_per_channel_cachemask_impl(self, device):
        for torch_type in (torch.qint8, torch.quint8):

            X = torch.randn(1, 2, 4, 4).to(device)
            # pick the scale + zp so that some values get clipped
            axis = 1
            obs = torch.quantization.PerChannelMinMaxObserver(axis, torch_type).to(device)
            obs(X * 0.75)
            scale, zero_point = obs.calculate_qparams()
            # TODO(future PR): fix the wrong dtype in obs.calculate_qparams and remove the cast
            zero_point = zero_point.to(torch.int64)
            quant_min, quant_max = obs._calculate_qmin_qmax()

            Y = _fake_quantize_per_channel_affine_reference(
                X.cpu(), scale.cpu(), zero_point.cpu(), axis, quant_min, quant_max)
            Y_prime, _mask = torch.fake_quantize_per_channel_affine_cachemask(
                X, scale, zero_point, axis, quant_min, quant_max)
            np.testing.assert_allclose(Y, Y_prime.cpu(), rtol=tolerance, atol=tolerance)

    def test_forward_per_channel_cachemask_cpu(self):
        self._test_forward_per_channel_cachemask_impl('cpu')

    @unittest.skipIf(not TEST_CUDA, "No GPU available.")
    def test_forward_per_channel_cachemask_cuda(self):
        self._test_forward_per_channel_cachemask_impl('cuda')

    def _test_learnable_forward_per_channel(self, X_base, device, scale_base, zero_point_base, axis):
        r"""Tests the forward path of the learnable FakeQuantizePerTensorAffine op.
        """
@@ -1347,6 +1373,34 @@ def test_backward_per_channel(self, device, X):
        Y_prime.backward(dout)
        np.testing.assert_allclose(dX.cpu().detach().numpy(), X.grad.cpu().detach().numpy(), rtol=tolerance, atol=tolerance)

    def _test_backward_per_channel_cachemask_impl(self, device):
        for torch_type in (torch.qint8, torch.quint8):
            X = torch.randn(1, 2, 4, 4).to(device)
            # pick the scale + zp so that some values get clipped
            axis = 1
            obs = torch.quantization.PerChannelMinMaxObserver(axis, torch_type).to(device)
            obs(X * 0.75)
            scale, zero_point = obs.calculate_qparams()
            # TODO(future PR): fix the wrong dtype in obs.calculate_qparams and remove the cast
            zero_point = zero_point.to(torch.int64)
            quant_min, quant_max = obs._calculate_qmin_qmax()
            X.requires_grad_()
            Y_prime, _mask = torch.fake_quantize_per_channel_affine_cachemask(
                X, scale, zero_point, axis, quant_min, quant_max)
            dout = torch.rand(X.shape, dtype=torch.float).to(device)
            dX = _fake_quantize_per_channel_affine_grad_reference(
                dout, X, scale, zero_point, axis, quant_min, quant_max)
            Y_prime.backward(dout)
            np.testing.assert_allclose(
                dX.cpu().detach().numpy(), X.grad.cpu().detach().numpy(), rtol=tolerance, atol=tolerance)

    def test_backward_per_channel_cachemask_cpu(self):
        self._test_backward_per_channel_cachemask_impl('cpu')

    @unittest.skipIf(not TEST_CUDA, "No GPU available.")
    def test_backward_per_channel_cachemask_cuda(self):
        self._test_backward_per_channel_cachemask_impl('cuda')

    def _test_learnable_backward_per_channel(self, X_base, device, scale_base, zero_point_base, axis):
        r"""Tests the backward path of the learnable FakeQuantizePerTensorAffine op.
        """
6 changes: 6 additions & 0 deletions test/test_namedtuple_return_api.py
@@ -15,6 +15,7 @@
    'qr', 'geqrf', 'solve', 'slogdet', 'sort', 'topk', 'lstsq',
    'triangular_solve', 'cummax', 'cummin', 'linalg_eigh', "unpack_dual", 'linalg_qr',
    '_svd_helper', 'linalg_svd', 'linalg_slogdet', 'fake_quantize_per_tensor_affine_cachemask',
    'fake_quantize_per_channel_affine_cachemask',
}


@@ -51,6 +52,8 @@ def test_native_functions_yaml(self):

    def test_namedtuple_return(self):
        a = torch.randn(5, 5)
        per_channel_scale = torch.randn(5)
        per_channel_zp = torch.zeros(5, dtype=torch.int64)

        op = namedtuple('op', ['operators', 'input', 'names', 'hasout'])
        operators = [
@@ -70,6 +73,9 @@ def test_namedtuple_return(self):
            op(operators=['linalg_slogdet'], input=(), names=('sign', 'logabsdet'), hasout=True),
            op(operators=['fake_quantize_per_tensor_affine_cachemask'],
               input=(0.1, 0, 0, 255), names=('output', 'mask',), hasout=False),
            op(operators=['fake_quantize_per_channel_affine_cachemask'],
               input=(per_channel_scale, per_channel_zp, 1, 0, 255),
               names=('output', 'mask',), hasout=False),
            op(operators=['unpack_dual'], input=(0,), names=('primal', 'tangent'), hasout=False),
        ]

3 changes: 3 additions & 0 deletions tools/autograd/derivatives.yaml
@@ -467,6 +467,9 @@
- name: fake_quantize_per_channel_affine(Tensor self, Tensor scale, Tensor zero_point, int axis, int quant_min, int quant_max) -> Tensor
  self: fake_quantize_per_channel_affine_backward(grad, self, scale, zero_point, axis, quant_min, quant_max)

- name: fake_quantize_per_channel_affine_cachemask(Tensor self, Tensor scale, Tensor zero_point, int axis, int quant_min, int quant_max) -> (Tensor output, Tensor mask)
  self: fake_quantize_per_channel_affine_cachemask_backward(grad, mask)

- name: _fake_quantize_learnable_per_channel_affine(Tensor self, Tensor scale, Tensor zero_point, int axis, int quant_min, int quant_max) -> Tensor
  self, scale, zero_point: "grad.defined() ? _fake_quantize_learnable_per_channel_affine_backward(grad, self, scale, zero_point, axis, quant_min, quant_max) : std::tuple<Tensor, Tensor, Tensor>()"

1 change: 1 addition & 0 deletions torch/overrides.py
@@ -390,6 +390,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
        torch.exp2: lambda input, out=None: -1,
        torch.expm1: lambda input, out=None: -1,
        torch.fake_quantize_per_channel_affine: lambda input, scale, zero_point, axis, quant_min, quant_max: -1,
        torch.fake_quantize_per_channel_affine_cachemask: lambda input, scale, zero_point, axis, quant_min, quant_max: -1,
        torch.fake_quantize_per_tensor_affine: lambda input, scale, zero_point, quant_min, quant_max: -1,
        torch.fake_quantize_per_tensor_affine_cachemask: lambda input, scale, zero_point, quant_min, quant_max: -1,
        torch.fbgemm_linear_fp16_weight: lambda input, packed_weight, bias: -1,
10 changes: 8 additions & 2 deletions torch/quantization/fake_quantize.py
@@ -137,8 +137,14 @@ def forward(self, X):
 
         if self.fake_quant_enabled[0] == 1:
             if self.is_per_channel:
-                X = torch.fake_quantize_per_channel_affine(X, self.scale, self.zero_point,
-                                                           self.ch_axis, self.quant_min, self.quant_max)
+                if self.training:
+                    X, _mask = torch.fake_quantize_per_channel_affine_cachemask(
+                        X, self.scale, self.zero_point,
+                        self.ch_axis, self.quant_min, self.quant_max)
+                else:
+                    X = torch.fake_quantize_per_channel_affine(
+                        X, self.scale, self.zero_point,
+                        self.ch_axis, self.quant_min, self.quant_max)
             else:
                 X = torch.fake_quantize_per_tensor_affine(
                     X, float(self.scale), int(self.zero_point),
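As a rough illustration of the training/eval split above (tensor values are illustrative): both ops produce the same fake-quantized output, but only the cachemask variant allocates and returns the byte-per-element mask that the backward pass needs, so eval mode sticks with the plain op.

```python
import torch

x = torch.randn(2, 3, 4)
scale = torch.tensor([0.1, 0.2, 0.3])
zero_point = torch.zeros(3, dtype=torch.int64)

# training path: keep the mask so autograd can compute dX = dY * mask
y_train, mask = torch.fake_quantize_per_channel_affine_cachemask(
    x, scale, zero_point, 1, 0, 255)

# eval path: same numerics, no mask allocation
y_eval = torch.fake_quantize_per_channel_affine(
    x, scale, zero_point, 1, 0, 255)

assert torch.allclose(y_train, y_eval)
```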
