memory efficient per-channel fq: use it everywhere, delete old version
Summary:

This PR is the cleanup after #51159. At a high level, we make the new
(cachemask) definition of per-channel fake quant the definition used by autograd, but keep the old
function around as a thin wrapper so that the user-facing API stays the same.

In detail:

1. Point fake_quantize_per_channel_affine's implementation at fake_quantize_per_channel_affine_cachemask (sketched below).
2. Delete the fake_quantize_per_channel_affine backward; autograd will automatically use the cachemask backward.
3. Delete all of the fake_quantize_per_channel_affine kernels, since nothing uses them anymore.
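
To make the wrapper behavior concrete, here is a minimal sketch of the intended user-facing equivalence. This is illustrative only, not code from this commit; the shapes and qparams are made up. The old op now simply returns the first output of the memory-efficient cachemask variant:

```python
import torch

x = torch.randn(4, 3)
scale = torch.tensor([0.1, 0.2, 0.05, 0.3])
zero_point = torch.zeros(4, dtype=torch.long)
axis, quant_min, quant_max = 0, 0, 255

# the user-facing API is unchanged
y = torch.fake_quantize_per_channel_affine(
    x, scale, zero_point, axis, quant_min, quant_max)

# internally, the op is now a thin wrapper that returns the first output
# of the cachemask variant (the mask is only needed for the backward pass),
# so both calls should produce the same result
y_ref, _mask = torch.fake_quantize_per_channel_affine_cachemask(
    x, scale, zero_point, axis, quant_min, quant_max)
assert torch.equal(y, y_ref)
```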

Test Plan:

```
python test/test_quantization.py TestFakeQuantize
```

Reviewers:

Subscribers:

Tasks:

Tags:

[ghstack-poisoned]
vkuzo committed Jan 28, 2021
1 parent 3596d53 commit 3b53e9b
Showing 8 changed files with 8 additions and 204 deletions.
5 changes: 1 addition & 4 deletions aten/src/ATen/native/native_functions.yaml
@@ -4658,10 +4658,7 @@
- func: fake_quantize_per_channel_affine(Tensor self, Tensor scale, Tensor zero_point, int axis, int quant_min, int quant_max) -> Tensor
variants: function
dispatch:
CPU, CUDA: fake_quantize_per_channel_affine

- func: fake_quantize_per_channel_affine_backward(Tensor grad, Tensor self, Tensor scale, Tensor zero_point, int axis, int quant_min, int quant_max) -> Tensor
variants: function
Math: fake_quantize_per_channel_affine

- func: fake_quantize_per_channel_affine_cachemask(Tensor self, Tensor scale, Tensor zero_point, int axis, int quant_min, int quant_max) -> (Tensor output, Tensor mask)
variants: function
33 changes: 0 additions & 33 deletions aten/src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp
@@ -2134,23 +2134,6 @@ void fake_quantize_learnable_tensor_grad_kernel_cpu(
});
}

void fake_quant_per_channel_cpu(
TensorIterator& iter,
int64_t quant_min,
int64_t quant_max) {
cpu_kernel(iter, [=](float self, float scale, int64_t zero_point) -> float {
float inv_scale = 1.0f / scale;
return (std::fmin(
std::fmax(
static_cast<int64_t>(
zero_point + std::nearbyint(self * inv_scale)),
quant_min),
quant_max) -
zero_point) *
scale;
});
}

void fake_quant_per_channel_cachemask_cpu(
TensorIterator& iter,
TensorIterator& iter_mask,
@@ -2180,19 +2163,6 @@ void fake_quant_per_channel_cachemask_cpu(
});
}

void fake_quant_grad_per_channel_cpu(
TensorIterator& iter,
int64_t quant_min,
int64_t quant_max) {
cpu_kernel(
iter, [=](float x, float dy, float scale, int64_t zero_point) -> float {
float inv_scale = 1.0f / scale;
int64_t xq =
static_cast<int64_t>(zero_point + std::nearbyint(x * inv_scale));
return dy * (xq >= quant_min && xq <= quant_max);
});
}

void fake_quantize_learnable_channel_grad_kernel_cpu(
TensorIterator& iter,
int64_t quant_min,
@@ -3072,9 +3042,6 @@ REGISTER_DISPATCH(dequantize_tensor_per_channel_float_qparams_stub,
&dequantize_tensor_per_channel_float_qparams_cpu);
REGISTER_DISPATCH(fake_quant_grad_learnable_tensor_stub,
&fake_quantize_learnable_tensor_grad_kernel_cpu);
REGISTER_DISPATCH(fake_quant_grad_per_channel_stub,
&fake_quant_grad_per_channel_cpu);
REGISTER_DISPATCH(fake_quant_per_channel_stub, &fake_quant_per_channel_cpu);
REGISTER_DISPATCH(fake_quant_per_channel_cachemask_stub, &fake_quant_per_channel_cachemask_cpu);
REGISTER_DISPATCH(fake_quant_tensor_cachemask_stub,
&fake_quantize_tensor_cachemask_kernel);
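For reference, the math that the deleted per-channel kernels implemented (and which the retained cachemask kernels still compute on the forward path) can be sketched in plain PyTorch. This is an illustrative reconstruction, not code from this commit, and the function name is made up:

```python
import torch

def fake_quant_per_channel_reference(x, scale, zero_point, axis, quant_min, quant_max):
    # broadcast the per-channel qparams along `axis`
    shape = [1] * x.dim()
    shape[axis] = x.size(axis)
    s = scale.reshape(shape)
    zp = zero_point.reshape(shape)
    # quantize, clamp to [quant_min, quant_max], then dequantize
    xq = torch.clamp(torch.round(x / s) + zp, quant_min, quant_max)
    return (xq - zp) * s
```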
27 changes: 0 additions & 27 deletions aten/src/ATen/native/quantized/cuda/fake_quantize_core.cu
@@ -83,22 +83,6 @@ REGISTER_DISPATCH(fake_quant_grad_learnable_tensor_stub, &_fake_quantize_grad_le

// Fake quantize per channel

void fake_quant_per_channel_cuda(TensorIterator &iter, int64_t quant_min, int64_t quant_max) {
gpu_kernel(iter,
[=] GPU_LAMBDA (float input_val, float scale, int64_t zero_point) -> float {
float inv_scale = 1.0f / scale;
return (fminf(
quant_max,
fmaxf(
quant_min,
static_cast<int64_t>(
std::nearbyint(input_val * inv_scale) +
zero_point))) -
zero_point) *
scale;
});
}

void fake_quant_per_channel_cachemask_cuda(
TensorIterator &iter, TensorIterator &iter_mask, int64_t quant_min, int64_t quant_max) {
// TODO(future, optional): read once, write twice. Not done at the moment
@@ -128,15 +112,6 @@ void fake_quant_per_channel_cachemask_cuda(
});
}

void fake_quant_grad_per_channel_cuda(TensorIterator &iter, int64_t quant_min, int64_t quant_max) {
gpu_kernel(iter,
[=] GPU_LAMBDA (float x, float dy, float scale, int64_t zero_point) -> float {
float inv_scale = 1.0f / scale;
int64_t Xq = std::nearbyint(x * inv_scale) + zero_point;
return (Xq >= quant_min && Xq <= quant_max) * dy;
});
}

void _fake_quantize_grad_learnable_channel_kernel_cuda(TensorIterator &iter, int64_t quant_min, int64_t quant_max) {
gpu_kernel_multiple_outputs(iter,
[=] GPU_LAMBDA (float x_input, float dy_input, float scale_input, float zero_point_input) -> thrust::tuple<float, float, float> {
@@ -161,8 +136,6 @@ void _fake_quantize_grad_learnable_channel_kernel_cuda(TensorIterator &iter, int
});
}

REGISTER_DISPATCH(fake_quant_per_channel_stub, &fake_quant_per_channel_cuda);
REGISTER_DISPATCH(fake_quant_grad_per_channel_stub, &fake_quant_grad_per_channel_cuda);
REGISTER_DISPATCH(fake_quant_per_channel_cachemask_stub, &fake_quant_per_channel_cachemask_cuda);
REGISTER_DISPATCH(fake_quant_grad_learnable_channel_stub, &_fake_quantize_grad_learnable_channel_kernel_cuda);

2 changes: 0 additions & 2 deletions aten/src/ATen/native/quantized/fake_quant_affine.h
@@ -40,8 +40,6 @@ using fake_quant_per_channel_cachemask_fn = void (*)(
int64_t quant_min,
int64_t quant_max);

DECLARE_DISPATCH(fake_quant_per_channel_fn, fake_quant_per_channel_stub);
DECLARE_DISPATCH(fake_quant_per_channel_fn, fake_quant_grad_per_channel_stub);
DECLARE_DISPATCH(fake_quant_per_channel_cachemask_fn, fake_quant_per_channel_cachemask_stub);
DECLARE_DISPATCH(fake_quant_per_channel_fn, fake_quant_grad_learnable_channel_stub);

130 changes: 3 additions & 127 deletions aten/src/ATen/native/quantized/fake_quant_per_channel_affine.cpp
@@ -10,8 +10,6 @@ namespace at {
namespace native {

// Use REGISTER_DISPATCH to run CPU and CUDA backend.
DEFINE_DISPATCH(fake_quant_per_channel_stub);
DEFINE_DISPATCH(fake_quant_grad_per_channel_stub);
DEFINE_DISPATCH(fake_quant_per_channel_cachemask_stub);
DEFINE_DISPATCH(fake_quant_grad_learnable_channel_stub);

@@ -36,50 +34,9 @@ Tensor fake_quantize_per_channel_affine(
int64_t axis,
int64_t quant_min,
int64_t quant_max) {
TORCH_CHECK(self.scalar_type() == ScalarType::Float);
TORCH_CHECK(scale.scalar_type() == ScalarType::Float,
"Scale must be Float, found ", scale.scalar_type());
TORCH_CHECK(zero_point.scalar_type() == ScalarType::Long,
"Zero-point must be Long, found ", zero_point.scalar_type());
TORCH_CHECK(scale.dim() == 1, "scale should be a 1-D tensor");
TORCH_CHECK(zero_point.dim() == 1, "zero point should be a 1-D tensor");
TORCH_CHECK(
scale.numel() == zero_point.numel(),
"scale and zero-point need to have the same dimensions");
TORCH_CHECK(
scale.numel() == self.size(axis),
"dimensions of scale and zero-point are not consistent with input tensor")

TORCH_CHECK(
quant_min <= quant_max,
"`quant_min` should be less than or \
equal to `quant_max`.");

TORCH_CHECK(
at::min(zero_point).item().toLong() >= quant_min &&
at::max(zero_point).item().toLong() <= quant_max,
"`zero_point` must be between `quant_min` and `quant_max`.");

TORCH_CHECK(
axis >= 0 && axis <= self.dim(),
"`axis` must be between 0 and number of dimensions of input");

auto Y = at::empty_like(self, self.options(), MemoryFormat::Preserve);

std::vector<int64_t> expected_shape(self.dim(), 1);
expected_shape[axis] = self.size(axis);

TensorIterator iter = TensorIteratorConfig()
.check_all_same_dtype(false)
.add_output(Y)
.add_input(self)
.add_input(native::_unsafe_view(scale, expected_shape))
.add_input(native::_unsafe_view(zero_point, expected_shape))
.build();

fake_quant_per_channel_stub(iter.device_type(), iter, quant_min, quant_max);

return Y;
const auto res = at::fake_quantize_per_channel_affine_cachemask(
self, scale, zero_point, axis, quant_min, quant_max);
return std::get<0>(res);
}

std::tuple<Tensor, Tensor> fake_quantize_per_channel_affine_cachemask(
@@ -172,87 +129,6 @@ Tensor fake_quantize_per_channel_affine_cachemask_backward(
return dY * mask;
}

/* Backward path for per-channel fake-quantization of the 'inputs' tensor.
Args:
X: Forward input tensor.
dY: Backward input tensor.
scale: scale of per tensor affine quantization
zero_point: zero_point of per tensor affine quantization
axis: int ,the axis over which quantization parameters vary
quant_min: int, minimum quantized value
quant_max: int, maximum quantized value
Returns:
Gradient for per channel fake quant (double dtype).
*/
Tensor fake_quantize_per_channel_affine_backward(
const Tensor& dY,
const Tensor& X,
const Tensor& scale,
const Tensor& zero_point,
int64_t axis,
int64_t quant_min,
int64_t quant_max) {
TORCH_CHECK(dY.scalar_type() == ScalarType::Float);
TORCH_CHECK(X.scalar_type() == ScalarType::Float);
TORCH_CHECK(scale.scalar_type() == ScalarType::Float,
"Scale must be Float, found ", scale.scalar_type());
TORCH_CHECK(zero_point.scalar_type() == ScalarType::Long,
"Zero-point must be Long, found ", zero_point.scalar_type());

TORCH_CHECK(X.sizes() == dY.sizes(), "`X` and `dY` are not the same size");
TORCH_CHECK(
quant_min <= quant_max,
"`quant_min` should be less than or \
equal to `quant_max`.");
TORCH_CHECK(scale.dim() == 1, "scale should be a 1-D tensor");
TORCH_CHECK(zero_point.dim() == 1, "zero point should be a 1-D tensor");
TORCH_CHECK(
scale.numel() == zero_point.numel(),
"scale and zero-point need to have the same dimensions");
TORCH_CHECK(
scale.numel() == X.size(axis),
"dimensions of scale and zero-point are not consistent with input tensor")

TORCH_CHECK(
quant_min <= quant_max,
"`quant_min` should be less than or \
equal to `quant_max`.");

TORCH_CHECK(
at::min(zero_point).item().toLong() >= quant_min &&
at::max(zero_point).item().toLong() <= quant_max,
"`zero_point` must be between `quant_min` and `quant_max`.");

TORCH_CHECK(
axis >= 0 && axis <= X.dim(),
"`axis` must be between 0 and number of dimensions of input");

if (X.numel() <= 0) {
return X;
}

auto dX = at::empty_like(X, X.options(), MemoryFormat::Preserve);

std::vector<int64_t> expected_shape(X.dim(), 1);
expected_shape[axis] = X.size(axis);

TensorIterator iter = TensorIteratorConfig()
.check_all_same_dtype(false)
.add_output(dX)
.add_input(X)
.add_input(dY)
.add_input(native::_unsafe_view(scale, expected_shape))
.add_input(native::_unsafe_view(zero_point, expected_shape))
.build();

fake_quant_grad_per_channel_stub(iter.device_type(), iter, quant_min, quant_max);

return dX;
}

Tensor _get_rounded_zero_point(
const Tensor& zero_point,
int64_t quant_min,
@@ -57,6 +57,7 @@
("aten::_multinomial_alias_draw", datetime.date(2021, 1, 31)),
("prim::profile_optional", datetime.date(2021, 1, 31)),
("aten::fake_quantize_per_tensor_affine_backward", datetime.date(2021, 2, 20)),
("aten::fake_quantize_per_channel_affine_backward", datetime.date(2021, 2, 20)),
]

def allow_listed(schema, allow_list):
3 changes: 0 additions & 3 deletions tools/autograd/derivatives.yaml
@@ -464,9 +464,6 @@
- name: _fake_quantize_learnable_per_tensor_affine(Tensor self, Tensor scale, Tensor zero_point, int quant_min, int quant_max) -> Tensor
self, scale, zero_point: "grad.defined() ? _fake_quantize_learnable_per_tensor_affine_backward(grad, self, scale, zero_point, quant_min, quant_max) : std::tuple<Tensor, Tensor, Tensor>()"

- name: fake_quantize_per_channel_affine(Tensor self, Tensor scale, Tensor zero_point, int axis, int quant_min, int quant_max) -> Tensor
self: fake_quantize_per_channel_affine_backward(grad, self, scale, zero_point, axis, quant_min, quant_max)

- name: fake_quantize_per_channel_affine_cachemask(Tensor self, Tensor scale, Tensor zero_point, int axis, int quant_min, int quant_max) -> (Tensor output, Tensor mask)
self: fake_quantize_per_channel_affine_cachemask_backward(grad, mask)

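With the explicit derivatives.yaml entry removed and fake_quantize_per_channel_affine now dispatched as a Math (composite) op, autograd differentiates through the thin wrapper and picks up the cachemask backward (dX = dY * mask) automatically. A quick, illustrative sanity check of the expected behavior (made-up values, not part of this commit):

```python
import torch

x = torch.randn(2, 3, requires_grad=True)
scale = torch.tensor([0.1, 0.2])
zero_point = torch.zeros(2, dtype=torch.long)

y = torch.fake_quantize_per_channel_affine(x, scale, zero_point, 0, 0, 255)
y.sum().backward()
# the gradient is 1 where round(x/scale) + zero_point stayed inside
# [quant_min, quant_max] in the forward pass, and 0 where it was clamped
print(x.grad)
```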
11 changes: 3 additions & 8 deletions torch/quantization/fake_quantize.py
@@ -137,14 +137,9 @@ def forward(self, X):

if self.fake_quant_enabled[0] == 1:
if self.is_per_channel:
if self.training:
X, _mask = torch.fake_quantize_per_channel_affine_cachemask(
X, self.scale, self.zero_point,
self.ch_axis, self.quant_min, self.quant_max)
else:
X = torch.fake_quantize_per_channel_affine(
X, self.scale, self.zero_point,
self.ch_axis, self.quant_min, self.quant_max)
X = torch.fake_quantize_per_channel_affine(
X, self.scale, self.zero_point,
self.ch_axis, self.quant_min, self.quant_max)
else:
X = torch.fake_quantize_per_tensor_affine(
X, float(self.scale), int(self.zero_point),
