mem-efficient fake_quant: delete old version
Summary:

Switches the default fake_quant path to use the new memory-efficient backward from
#50561.

Separated out for clean testing and review, but ideally this should be combined
with #50561.
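
For intuition only (this is a Python-level sketch, not the code in this diff, and the class name is hypothetical): the cachemask idea is that the forward saves a bool mask and the backward is just a masked copy of the incoming gradient, with nothing recomputed from the input.

```python
import torch

class FakeQuantMemEfficientSketch(torch.autograd.Function):
    # Illustrative sketch of the mask-based backward; not the kernel in this PR.

    @staticmethod
    def forward(ctx, x, scale, zero_point, quant_min, quant_max):
        q = torch.round(x / scale) + zero_point
        out = (torch.clamp(q, quant_min, quant_max) - zero_point) * scale
        # Save a bool mask (1 byte per element) instead of the float input.
        mask = (q >= quant_min) & (q <= quant_max)
        ctx.save_for_backward(mask)
        return out

    @staticmethod
    def backward(ctx, grad_output):
        (mask,) = ctx.saved_tensors
        # Straight-through estimator: the gradient passes only where the value
        # landed inside [quant_min, quant_max]; the input is never re-quantized.
        return grad_output * mask, None, None, None, None
```

A call would look like `FakeQuantMemEfficientSketch.apply(x, 0.1, 0, 0, 255)`.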

Test Plan:

```
python test/test_quantization.py TestFakeQuantize.test_forward_per_tensor_cpu
python test/test_quantization.py TestFakeQuantize.test_forward_per_tensor_cuda
python test/test_quantization.py TestFakeQuantize.test_backward_per_tensor_cpu
python test/test_quantization.py TestFakeQuantize.test_backward_per_tensor_cuda
```

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: ba2787ee2a6b023bab36733ec32c91fb174f2cc7
Pull Request resolved: #50857
vkuzo committed Jan 21, 2021
1 parent 6077315 commit dd0d21b
Showing 8 changed files with 114 additions and 318 deletions.
12 changes: 2 additions & 10 deletions aten/src/ATen/native/native_functions.yaml
@@ -4652,20 +4652,12 @@
dispatch:
QuantizedCPU, QuantizedCUDA: qscheme_quant

- func: fake_quantize_per_tensor_affine(Tensor self, float scale, int zero_point, int quant_min, int quant_max) -> Tensor
- func: fake_quantize_per_tensor_affine(Tensor self, float scale, int zero_point, int quant_min, int quant_max, bool compute_mask) -> (Tensor output, Tensor mask)
variants: function
dispatch:
CPU, CUDA: fake_quantize_per_tensor_affine

- func: fake_quantize_per_tensor_affine_backward(Tensor grad, Tensor self, float scale, int zero_point, int quant_min, int quant_max) -> Tensor
variants: function

- func: fake_quantize_per_tensor_affine_cachemask(Tensor self, float scale, int zero_point, int quant_min, int quant_max) -> (Tensor output, Tensor mask)
variants: function
dispatch:
CPU, CUDA: fake_quantize_per_tensor_affine_cachemask

- func: fake_quantize_per_tensor_affine_cachemask_backward(Tensor grad, Tensor mask) -> Tensor
- func: fake_quantize_per_tensor_affine_backward(Tensor grad, Tensor mask, bool compute_mask) -> Tensor
variants: function

- func: _fake_quantize_learnable_per_tensor_affine(Tensor self, Tensor scale, Tensor zero_point, int quant_min, int quant_max) -> Tensor
96 changes: 35 additions & 61 deletions aten/src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp
@@ -2055,71 +2055,49 @@ void q_batch_norm_kernel(
}

void fake_quantize_tensor_kernel(
Tensor& output,
const Tensor& input,
float sc,
int64_t z_point,
int64_t quant_min,
int64_t quant_max) {
float inv_scale = 1.0f / sc;
auto iter = TensorIterator::unary_op(output, input);
cpu_kernel(iter, [&](float self) -> float {
return (std::fmin(
std::fmax(
static_cast<int64_t>(
z_point + std::nearbyint(self * inv_scale)),
quant_min),
quant_max) -
z_point) *
sc;
});
}

void fake_quantize_grad_tensor_kernel(
Tensor& input_grad,
const Tensor& input,
const Tensor& output_grad,
float sc,
int64_t z_point,
int64_t quant_min,
int64_t quant_max) {
float inv_scale = 1.0f / sc;
auto iter = TensorIterator::binary_op(input_grad, input, output_grad);
cpu_kernel(iter, [&](float x, float dy) -> float {
int64_t xq = static_cast<int64_t>(z_point + std::nearbyint(x * inv_scale));
return dy * (xq >= quant_min && xq <= quant_max);
});
}

void fake_quantize_tensor_cachemask_kernel(
Tensor& output,
Tensor& mask,
const Tensor& input,
float sc,
int64_t z_point,
int64_t quant_min,
int64_t quant_max) {
int64_t quant_max,
bool compute_mask) {
float inv_scale = 1.0f / sc;

auto iter_combined = TensorIteratorConfig()
.check_all_same_dtype(false)
.add_output(output)
.add_output(mask)
.add_input(input)
.build();

iter_combined.for_each([&](char** data, const int64_t* strides, int64_t n) {
for (int64_t i = 0; i < n; i++) {
float* output_val = (float*)(data[0] + i * strides[0]);
bool* mask_val = (bool*)(data[1] + i * strides[1]);
float* input_val = (float*)(data[2] + i * strides[2]);

const auto qval = static_cast<int64_t>(z_point + std::nearbyint(*input_val * inv_scale));
*output_val = (std::fmin(std::fmax(qval, quant_min), quant_max) - z_point) * sc;
*mask_val = ((quant_min <= qval) && (qval <= quant_max));
}
});

if (compute_mask) {
auto iter_combined = TensorIteratorConfig()
.check_all_same_dtype(false)
.add_output(output)
.add_output(mask)
.add_input(input)
.build();

iter_combined.for_each([&](char** data, const int64_t* strides, int64_t n) {
for (int64_t i = 0; i < n; i++) {
float* output_val = (float*)(data[0] + i * strides[0]);
bool* mask_val = (bool*)(data[1] + i * strides[1]);
float* input_val = (float*)(data[2] + i * strides[2]);

const auto qval = static_cast<int64_t>(z_point + std::nearbyint(*input_val * inv_scale));
*output_val = (std::fmin(std::fmax(qval, quant_min), quant_max) - z_point) * sc;
*mask_val = ((quant_min <= qval) && (qval <= quant_max));
}
});
} else {
// single output, mask computation is skipped
auto iter = TensorIterator::unary_op(output, input);
cpu_kernel(iter, [&](float self) -> float {
return (std::fmin(
std::fmax(
static_cast<int64_t>(
z_point + std::nearbyint(self * inv_scale)),
quant_min),
quant_max) -
z_point) *
sc;
});
}
}

void fake_quantize_learnable_tensor_grad_kernel_cpu(
@@ -3081,12 +3059,8 @@ REGISTER_DISPATCH(fake_quant_grad_learnable_tensor_stub,
&fake_quantize_learnable_tensor_grad_kernel_cpu);
REGISTER_DISPATCH(fake_quant_grad_per_channel_stub,
&fake_quant_grad_per_channel_cpu);
REGISTER_DISPATCH(fake_quant_grad_tensor_stub,
&fake_quantize_grad_tensor_kernel);
REGISTER_DISPATCH(fake_quant_per_channel_stub, &fake_quant_per_channel_cpu);
REGISTER_DISPATCH(fake_quant_tensor_stub, &fake_quantize_tensor_kernel);
REGISTER_DISPATCH(fake_quant_tensor_cachemask_stub,
&fake_quantize_tensor_cachemask_kernel);
REGISTER_DISPATCH(qadaptive_avg_pool2d_nhwc_stub,
&qadaptive_avg_pool2d_nhwc_kernel);
REGISTER_DISPATCH(qadaptive_avg_pool3d_ndhwc_stub,
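
As a rough cross-check of the arithmetic in the CPU kernel above, here is a plain PyTorch reference with a hypothetical function name; it mirrors the quantize/clamp/dequantize lambda and the mask condition, but is not the dispatched kernel.

```python
import torch

def fake_quantize_reference(x, scale, zero_point, quant_min, quant_max, compute_mask=True):
    # Quantize, clamp, dequantize: the same arithmetic as the kernel lambda above.
    q = torch.round(x / scale) + zero_point
    out = (torch.clamp(q, quant_min, quant_max) - zero_point) * scale
    if not compute_mask:
        return out, None
    # The mask marks elements whose quantized value stayed inside [quant_min, quant_max].
    mask = (q >= quant_min) & (q <= quant_max)
    return out, mask

out, mask = fake_quantize_reference(torch.randn(8), scale=0.1, zero_point=0,
                                    quant_min=0, quant_max=255)
```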
108 changes: 39 additions & 69 deletions aten/src/ATen/native/quantized/cuda/fake_quantize_core.cu
@@ -21,82 +21,54 @@ Returns:
namespace at {
namespace native {
void fake_quantize_tensor_kernel_cuda(
Tensor& output,
const Tensor& input,
float scale,
int64_t zero_point,
int64_t quant_min,
int64_t quant_max) {
// scalar type of this function is guaranteed to be float
float inv_scale = 1.0f / scale;
auto iter = TensorIteratorConfig()
.check_all_same_dtype(false)
.add_output(output)
.add_input(input)
.build();
gpu_kernel(iter, [=] GPU_LAMBDA(float input_val) -> float {
return (fminf(
quant_max,
fmaxf(
quant_min,
static_cast<int64_t>(
std::nearbyint(input_val * inv_scale) + zero_point))) -
zero_point) *
scale;
});
}

void fake_quantize_grad_tensor_kernel_cuda(
Tensor& input_grad,
const Tensor& input,
const Tensor& output_grad,
float scale,
int64_t zero_point,
int64_t quant_min,
int64_t quant_max) {
// scalar type of this function is guaranteed to be float
float inv_scale = 1.0f / scale;
auto iter = TensorIteratorConfig()
.check_all_same_dtype(false)
.add_output(input_grad)
.add_input(output_grad)
.add_input(input)
.build();
gpu_kernel(iter, [=] GPU_LAMBDA(float dy, float x) -> float {
int64_t Xq = std::nearbyint(x * inv_scale) + zero_point;
return (Xq >= quant_min && Xq <= quant_max) * dy;
});
}

void fake_quantize_tensor_cachemask_kernel_cuda(
Tensor& output,
Tensor& mask,
const Tensor& input,
float scale,
int64_t zero_point,
int64_t quant_min,
int64_t quant_max) {
int64_t quant_max,
bool compute_mask) {

float inv_scale = 1.0f / scale;
auto iter = TensorIteratorConfig()
.check_all_same_dtype(false)
.add_output(output)
.add_output(mask)
.add_input(input)
.build();

gpu_kernel_multiple_outputs(
iter,
[=] GPU_LAMBDA (float input_val) -> thrust::tuple<float, bool> {
const auto qval = static_cast<int64_t>(std::nearbyint(input_val * inv_scale) + zero_point);
return {
// fake_quantized value
(fminf(quant_max, fmaxf(quant_min, qval)) - zero_point) * scale,
// mask for grad
((quant_min <= qval) && (qval <= quant_max))
};
}
);
if (compute_mask) {
auto iter = TensorIteratorConfig()
.check_all_same_dtype(false)
.add_output(output)
.add_output(mask)
.add_input(input)
.build();

gpu_kernel_multiple_outputs(
iter,
[=] GPU_LAMBDA (float input_val) -> thrust::tuple<float, bool> {
const auto qval = static_cast<int64_t>(std::nearbyint(input_val * inv_scale) + zero_point);
return {
// fake_quantized value
(fminf(quant_max, fmaxf(quant_min, qval)) - zero_point) * scale,
// mask for grad
((quant_min <= qval) && (qval <= quant_max))
};
}
);
} else {
auto iter = TensorIteratorConfig()
.check_all_same_dtype(false)
.add_output(output)
.add_input(input)
.build();
gpu_kernel(iter, [=] GPU_LAMBDA(float input_val) -> float {
return (fminf(
quant_max,
fmaxf(
quant_min,
static_cast<int64_t>(
std::nearbyint(input_val * inv_scale) + zero_point))) -
zero_point) *
scale;
});
}
}

void _fake_quantize_grad_learnable_tensor_kernel_cuda(
@@ -127,8 +99,6 @@ void _fake_quantize_grad_learnable_tensor_kernel_cuda(
}

REGISTER_DISPATCH(fake_quant_tensor_stub, &fake_quantize_tensor_kernel_cuda);
REGISTER_DISPATCH(fake_quant_tensor_cachemask_stub, &fake_quantize_tensor_cachemask_kernel_cuda);
REGISTER_DISPATCH(fake_quant_grad_tensor_stub, &fake_quantize_grad_tensor_kernel_cuda);
REGISTER_DISPATCH(fake_quant_grad_learnable_tensor_stub, &_fake_quantize_grad_learnable_tensor_kernel_cuda);

// Fake quantize per channel
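
A quick size check of where the memory saving comes from, assuming the old backward kept the fp32 input alive for gradient computation while the cachemask backward only needs the bool mask (illustrative tensor shape, not a benchmark):

```python
import torch

x = torch.randn(1024, 1024)                   # fp32 input the old backward consumed
mask = torch.ones_like(x, dtype=torch.bool)   # bool mask the new backward consumes

print(x.element_size() * x.nelement())        # 4194304 bytes (4 B per element)
print(mask.element_size() * mask.nelement())  # 1048576 bytes (1 B per element)
```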
22 changes: 2 additions & 20 deletions aten/src/ATen/native/quantized/fake_quant_affine.h
@@ -10,30 +10,14 @@ struct TensorIterator;
namespace native {

using fake_quant_tensor_fn = void (*)(
Tensor& output,
const Tensor& input,
float sc,
int64_t z_point,
int64_t quant_min,
int64_t quant_max);

using fake_quant_grad_tensor_fn = void (*)(
Tensor& input_grad,
const Tensor& input,
const Tensor& output_grad,
float sc,
int64_t z_point,
int64_t quant_min,
int64_t quant_max);

using fake_quant_tensor_cachemask_fn = void (*)(
Tensor& output,
Tensor& mask,
const Tensor& input,
float sc,
int64_t z_point,
int64_t quant_min,
int64_t quant_max);
int64_t quant_max,
bool compute_mask);

using fake_quant_learnable_grad_tensor_fn = void (*)(
TensorIterator& iter,
@@ -44,8 +28,6 @@ using fake_quant_learnable_grad_tensor_fn = void (*)(
int64_t quant_max);

DECLARE_DISPATCH(fake_quant_tensor_fn, fake_quant_tensor_stub);
DECLARE_DISPATCH(fake_quant_grad_tensor_fn, fake_quant_grad_tensor_stub);
DECLARE_DISPATCH(fake_quant_tensor_cachemask_fn, fake_quant_tensor_cachemask_stub);
DECLARE_DISPATCH(fake_quant_learnable_grad_tensor_fn, fake_quant_grad_learnable_tensor_stub);

using fake_quant_per_channel_fn = void (*)(
