memory efficient per-channel fq: use it everywhere, delete old version #51265

Closed (wants to merge 2 commits)
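For reference, the public op whose kernels this PR consolidates is torch.fake_quantize_per_channel_affine. A minimal usage sketch (the shapes and quantization parameters below are illustrative, not taken from this PR):

    import torch

    x = torch.randn(2, 3, 4)                        # fake-quantize along dim 1
    scale = torch.tensor([0.1, 0.2, 0.3])           # float, one entry per channel
    zero_point = torch.zeros(3, dtype=torch.int64)  # long, one entry per channel

    y = torch.fake_quantize_per_channel_affine(
        x, scale, zero_point, axis=1, quant_min=-128, quant_max=127)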
aten/src/ATen/native/native_functions.yaml (1 addition, 4 deletions)
@@ -4658,10 +4658,7 @@
 - func: fake_quantize_per_channel_affine(Tensor self, Tensor scale, Tensor zero_point, int axis, int quant_min, int quant_max) -> Tensor
   variants: function
   dispatch:
-    CPU, CUDA: fake_quantize_per_channel_affine
-
-- func: fake_quantize_per_channel_affine_backward(Tensor grad, Tensor self, Tensor scale, Tensor zero_point, int axis, int quant_min, int quant_max) -> Tensor
-  variants: function
+    Math: fake_quantize_per_channel_affine

 - func: fake_quantize_per_channel_affine_cachemask(Tensor self, Tensor scale, Tensor zero_point, int axis, int quant_min, int quant_max) -> (Tensor output, Tensor mask)
   variants: function
aten/src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp (0 additions, 33 deletions)
@@ -2134,23 +2134,6 @@ void fake_quantize_learnable_tensor_grad_kernel_cpu(
});
}

void fake_quant_per_channel_cpu(
TensorIterator& iter,
int64_t quant_min,
int64_t quant_max) {
cpu_kernel(iter, [=](float self, float scale, int64_t zero_point) -> float {
float inv_scale = 1.0f / scale;
return (std::fmin(
std::fmax(
static_cast<int64_t>(
zero_point + std::nearbyint(self * inv_scale)),
quant_min),
quant_max) -
zero_point) *
scale;
});
}

void fake_quant_per_channel_cachemask_cpu(
TensorIterator& iter,
TensorIterator& iter_mask,
@@ -2180,19 +2163,6 @@ void fake_quant_per_channel_cachemask_cpu(
});
}

void fake_quant_grad_per_channel_cpu(
TensorIterator& iter,
int64_t quant_min,
int64_t quant_max) {
cpu_kernel(
iter, [=](float x, float dy, float scale, int64_t zero_point) -> float {
float inv_scale = 1.0f / scale;
int64_t xq =
static_cast<int64_t>(zero_point + std::nearbyint(x * inv_scale));
return dy * (xq >= quant_min && xq <= quant_max);
});
}

void fake_quantize_learnable_channel_grad_kernel_cpu(
TensorIterator& iter,
int64_t quant_min,
@@ -3072,9 +3042,6 @@ REGISTER_DISPATCH(dequantize_tensor_per_channel_float_qparams_stub,
&dequantize_tensor_per_channel_float_qparams_cpu);
REGISTER_DISPATCH(fake_quant_grad_learnable_tensor_stub,
&fake_quantize_learnable_tensor_grad_kernel_cpu);
REGISTER_DISPATCH(fake_quant_grad_per_channel_stub,
&fake_quant_grad_per_channel_cpu);
REGISTER_DISPATCH(fake_quant_per_channel_stub, &fake_quant_per_channel_cpu);
REGISTER_DISPATCH(fake_quant_per_channel_cachemask_stub, &fake_quant_per_channel_cachemask_cpu);
REGISTER_DISPATCH(fake_quant_tensor_cachemask_stub,
&fake_quantize_tensor_cachemask_kernel);
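The two kernels deleted above implement the standard per-channel fake-quantize math. A plain-PyTorch restatement of what fake_quant_per_channel_cpu and fake_quant_grad_per_channel_cpu computed (a readability sketch, not the dispatch code; it assumes scale and zero_point have already been broadcast along the channel axis, as the TensorIterator setup did):

    import torch

    def fake_quant_per_channel_ref(x, scale, zero_point, quant_min, quant_max):
        # forward: quantize, clamp to [quant_min, quant_max], dequantize
        xq = torch.round(x / scale) + zero_point
        return (torch.clamp(xq, quant_min, quant_max) - zero_point) * scale

    def fake_quant_grad_per_channel_ref(x, dy, scale, zero_point, quant_min, quant_max):
        # backward: straight-through estimator, pass gradient only where xq is in range
        xq = torch.round(x / scale) + zero_point
        mask = (xq >= quant_min) & (xq <= quant_max)
        return dy * mask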
aten/src/ATen/native/quantized/cuda/fake_quantize_core.cu (0 additions, 27 deletions)
@@ -83,22 +83,6 @@ REGISTER_DISPATCH(fake_quant_grad_learnable_tensor_stub, &_fake_quantize_grad_le

// Fake quantize per channel

void fake_quant_per_channel_cuda(TensorIterator &iter, int64_t quant_min, int64_t quant_max) {
gpu_kernel(iter,
[=] GPU_LAMBDA (float input_val, float scale, int64_t zero_point) -> float {
float inv_scale = 1.0f / scale;
return (fminf(
quant_max,
fmaxf(
quant_min,
static_cast<int64_t>(
std::nearbyint(input_val * inv_scale) +
zero_point))) -
zero_point) *
scale;
});
}

void fake_quant_per_channel_cachemask_cuda(
TensorIterator &iter, TensorIterator &iter_mask, int64_t quant_min, int64_t quant_max) {
// TODO(future, optional): read once, write twice. Not done at the moment
@@ -128,15 +112,6 @@ void fake_quant_per_channel_cachemask_cuda(
});
}

void fake_quant_grad_per_channel_cuda(TensorIterator &iter, int64_t quant_min, int64_t quant_max) {
gpu_kernel(iter,
[=] GPU_LAMBDA (float x, float dy, float scale, int64_t zero_point) -> float {
float inv_scale = 1.0f / scale;
int64_t Xq = std::nearbyint(x * inv_scale) + zero_point;
return (Xq >= quant_min && Xq <= quant_max) * dy;
});
}

void _fake_quantize_grad_learnable_channel_kernel_cuda(TensorIterator &iter, int64_t quant_min, int64_t quant_max) {
gpu_kernel_multiple_outputs(iter,
[=] GPU_LAMBDA (float x_input, float dy_input, float scale_input, float zero_point_input) -> thrust::tuple<float, float, float> {
@@ -161,8 +136,6 @@ void _fake_quantize_grad_learnable_channel_kernel_cuda(TensorIterator &iter, int
});
}

REGISTER_DISPATCH(fake_quant_per_channel_stub, &fake_quant_per_channel_cuda);
REGISTER_DISPATCH(fake_quant_grad_per_channel_stub, &fake_quant_grad_per_channel_cuda);
REGISTER_DISPATCH(fake_quant_per_channel_cachemask_stub, &fake_quant_per_channel_cachemask_cuda);
REGISTER_DISPATCH(fake_quant_grad_learnable_channel_stub, &_fake_quantize_grad_learnable_channel_kernel_cuda);

aten/src/ATen/native/quantized/fake_quant_affine.h (0 additions, 2 deletions)
@@ -40,8 +40,6 @@ using fake_quant_per_channel_cachemask_fn = void (*)(
int64_t quant_min,
int64_t quant_max);

DECLARE_DISPATCH(fake_quant_per_channel_fn, fake_quant_per_channel_stub);
DECLARE_DISPATCH(fake_quant_per_channel_fn, fake_quant_grad_per_channel_stub);
DECLARE_DISPATCH(fake_quant_per_channel_cachemask_fn, fake_quant_per_channel_cachemask_stub);
DECLARE_DISPATCH(fake_quant_per_channel_fn, fake_quant_grad_learnable_channel_stub);

aten/src/ATen/native/quantized/fake_quant_per_channel_affine.cpp (3 additions, 127 deletions)
@@ -10,8 +10,6 @@ namespace at {
namespace native {

// Use REGISTER_DISPATCH to run CPU and CUDA backend.
DEFINE_DISPATCH(fake_quant_per_channel_stub);
DEFINE_DISPATCH(fake_quant_grad_per_channel_stub);
DEFINE_DISPATCH(fake_quant_per_channel_cachemask_stub);
DEFINE_DISPATCH(fake_quant_grad_learnable_channel_stub);

@@ -36,50 +34,9 @@ Tensor fake_quantize_per_channel_affine(
     int64_t axis,
     int64_t quant_min,
     int64_t quant_max) {
-  TORCH_CHECK(self.scalar_type() == ScalarType::Float);
-  TORCH_CHECK(scale.scalar_type() == ScalarType::Float,
-      "Scale must be Float, found ", scale.scalar_type());
-  TORCH_CHECK(zero_point.scalar_type() == ScalarType::Long,
-      "Zero-point must be Long, found ", zero_point.scalar_type());
-  TORCH_CHECK(scale.dim() == 1, "scale should be a 1-D tensor");
-  TORCH_CHECK(zero_point.dim() == 1, "zero point should be a 1-D tensor");
-  TORCH_CHECK(
-      scale.numel() == zero_point.numel(),
-      "scale and zero-point need to have the same dimensions");
-  TORCH_CHECK(
-      scale.numel() == self.size(axis),
-      "dimensions of scale and zero-point are not consistent with input tensor")
-
-  TORCH_CHECK(
-      quant_min <= quant_max,
-      "`quant_min` should be less than or \
-        equal to `quant_max`.");
-
-  TORCH_CHECK(
-      at::min(zero_point).item().toLong() >= quant_min &&
-          at::max(zero_point).item().toLong() <= quant_max,
-      "`zero_point` must be between `quant_min` and `quant_max`.");
-
-  TORCH_CHECK(
-      axis >= 0 && axis <= self.dim(),
-      "`axis` must be between 0 and number of dimensions of input");
-
-  auto Y = at::empty_like(self, self.options(), MemoryFormat::Preserve);
-
-  std::vector<int64_t> expected_shape(self.dim(), 1);
-  expected_shape[axis] = self.size(axis);
-
-  TensorIterator iter = TensorIteratorConfig()
-    .check_all_same_dtype(false)
-    .add_output(Y)
-    .add_input(self)
-    .add_input(native::_unsafe_view(scale, expected_shape))
-    .add_input(native::_unsafe_view(zero_point, expected_shape))
-    .build();
-
-  fake_quant_per_channel_stub(iter.device_type(), iter, quant_min, quant_max);
-
-  return Y;
+  const auto res = at::fake_quantize_per_channel_affine_cachemask(
+      self, scale, zero_point, axis, quant_min, quant_max);
+  return std::get<0>(res);
 }

std::tuple<Tensor, Tensor> fake_quantize_per_channel_affine_cachemask(
@@ -172,87 +129,6 @@ Tensor fake_quantize_per_channel_affine_cachemask_backward(
return dY * mask;
}

/* Backward path for per-channel fake-quantization of the 'inputs' tensor.

Args:
X: Forward input tensor.
dY: Backward input tensor.
scale: scale of per tensor affine quantization
zero_point: zero_point of per tensor affine quantization
axis: int ,the axis over which quantization parameters vary
quant_min: int, minimum quantized value
quant_max: int, maximum quantized value

Returns:
Gradient for per channel fake quant (double dtype).

*/
Tensor fake_quantize_per_channel_affine_backward(
const Tensor& dY,
const Tensor& X,
const Tensor& scale,
const Tensor& zero_point,
int64_t axis,
int64_t quant_min,
int64_t quant_max) {
TORCH_CHECK(dY.scalar_type() == ScalarType::Float);
TORCH_CHECK(X.scalar_type() == ScalarType::Float);
TORCH_CHECK(scale.scalar_type() == ScalarType::Float,
"Scale must be Float, found ", scale.scalar_type());
TORCH_CHECK(zero_point.scalar_type() == ScalarType::Long,
"Zero-point must be Long, found ", zero_point.scalar_type());

TORCH_CHECK(X.sizes() == dY.sizes(), "`X` and `dY` are not the same size");
TORCH_CHECK(
quant_min <= quant_max,
"`quant_min` should be less than or \
equal to `quant_max`.");
TORCH_CHECK(scale.dim() == 1, "scale should be a 1-D tensor");
TORCH_CHECK(zero_point.dim() == 1, "zero point should be a 1-D tensor");
TORCH_CHECK(
scale.numel() == zero_point.numel(),
"scale and zero-point need to have the same dimensions");
TORCH_CHECK(
scale.numel() == X.size(axis),
"dimensions of scale and zero-point are not consistent with input tensor")

TORCH_CHECK(
quant_min <= quant_max,
"`quant_min` should be less than or \
equal to `quant_max`.");

TORCH_CHECK(
at::min(zero_point).item().toLong() >= quant_min &&
at::max(zero_point).item().toLong() <= quant_max,
"`zero_point` must be between `quant_min` and `quant_max`.");

TORCH_CHECK(
axis >= 0 && axis <= X.dim(),
"`axis` must be between 0 and number of dimensions of input");

if (X.numel() <= 0) {
return X;
}

auto dX = at::empty_like(X, X.options(), MemoryFormat::Preserve);

std::vector<int64_t> expected_shape(X.dim(), 1);
expected_shape[axis] = X.size(axis);

TensorIterator iter = TensorIteratorConfig()
.check_all_same_dtype(false)
.add_output(dX)
.add_input(X)
.add_input(dY)
.add_input(native::_unsafe_view(scale, expected_shape))
.add_input(native::_unsafe_view(zero_point, expected_shape))
.build();

fake_quant_grad_per_channel_stub(iter.device_type(), iter, quant_min, quant_max);

return dX;
}

Tensor _get_rounded_zero_point(
const Tensor& zero_point,
int64_t quant_min,
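With the change above, the forward of the public op simply returns the first output of fake_quantize_per_channel_affine_cachemask, and its gradient flows through fake_quantize_per_channel_affine_cachemask_backward (dY * mask). A small sketch of the equivalence this relies on (illustrative values; the asserts state the intended behavior rather than reproduce a test from this PR):

    import torch

    x = torch.randn(4, 3, requires_grad=True)
    scale = torch.tensor([0.05, 0.1, 0.2])
    zero_point = torch.zeros(3, dtype=torch.int64)

    out, mask = torch.fake_quantize_per_channel_affine_cachemask(
        x, scale, zero_point, 1, 0, 255)
    y = torch.fake_quantize_per_channel_affine(
        x, scale, zero_point, 1, 0, 255)
    assert torch.equal(y, out)  # forward matches the cachemask output

    y.sum().backward()          # backward is dY * mask (straight-through estimator)
    assert torch.equal(x.grad, mask.to(x.grad.dtype))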
test/backward_compatibility/check_backward_compatibility.py
@@ -57,6 +57,7 @@
("aten::_multinomial_alias_draw", datetime.date(2021, 1, 31)),
("prim::profile_optional", datetime.date(2021, 1, 31)),
("aten::fake_quantize_per_tensor_affine_backward", datetime.date(2021, 2, 20)),
("aten::fake_quantize_per_channel_affine_backward", datetime.date(2021, 2, 20)),
]

def allow_listed(schema, allow_list):
tools/autograd/derivatives.yaml (0 additions, 3 deletions)
@@ -464,9 +464,6 @@
- name: _fake_quantize_learnable_per_tensor_affine(Tensor self, Tensor scale, Tensor zero_point, int quant_min, int quant_max) -> Tensor
self, scale, zero_point: "grad.defined() ? _fake_quantize_learnable_per_tensor_affine_backward(grad, self, scale, zero_point, quant_min, quant_max) : std::tuple<Tensor, Tensor, Tensor>()"

- name: fake_quantize_per_channel_affine(Tensor self, Tensor scale, Tensor zero_point, int axis, int quant_min, int quant_max) -> Tensor
self: fake_quantize_per_channel_affine_backward(grad, self, scale, zero_point, axis, quant_min, quant_max)

- name: fake_quantize_per_channel_affine_cachemask(Tensor self, Tensor scale, Tensor zero_point, int axis, int quant_min, int quant_max) -> (Tensor output, Tensor mask)
self: fake_quantize_per_channel_affine_cachemask_backward(grad, mask)

torch/quantization/fake_quantize.py (3 additions, 8 deletions)
@@ -137,14 +132,9 @@ def forward(self, X):

         if self.fake_quant_enabled[0] == 1:
             if self.is_per_channel:
-                if self.training:
-                    X, _mask = torch.fake_quantize_per_channel_affine_cachemask(
-                        X, self.scale, self.zero_point,
-                        self.ch_axis, self.quant_min, self.quant_max)
-                else:
-                    X = torch.fake_quantize_per_channel_affine(
-                        X, self.scale, self.zero_point,
-                        self.ch_axis, self.quant_min, self.quant_max)
+                X = torch.fake_quantize_per_channel_affine(
+                    X, self.scale, self.zero_point,
+                    self.ch_axis, self.quant_min, self.quant_max)
             else:
                 X = torch.fake_quantize_per_tensor_affine(
                     X, float(self.scale), int(self.zero_point),
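End to end, the per-channel branch of torch.quantization.FakeQuantize now calls the same op in training and eval mode. A usage sketch (the observer and qscheme choices mirror the stock per-channel weight settings but are illustrative here):

    import torch
    from torch.quantization import FakeQuantize, MovingAveragePerChannelMinMaxObserver

    fq = FakeQuantize(observer=MovingAveragePerChannelMinMaxObserver,
                      quant_min=-128, quant_max=127,
                      dtype=torch.qint8,
                      qscheme=torch.per_channel_symmetric,
                      ch_axis=0)

    w = torch.randn(8, 16)
    fq.train()
    y_train = fq(w)  # per-channel path: torch.fake_quantize_per_channel_affine
    fq.eval()
    y_eval = fq(w)   # same op in eval; no separate cachemask-vs-affine split anymore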