[quant] Add quantized::leaky_relu that takes scale/zero_point as input (#45702)

Summary:
Pull Request resolved: #45702

#45593

Previously, quantized leaky_relu did not require observation and simply inherited
the quantization parameters from its input, but that does not work well in QAT.
This PR adds a quantized::leaky_relu that observes the output, and it will
become the default leaky_relu produced by our quantization tools (eager and graph mode).
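
For illustration (a minimal sketch, not part of the original commit message): this is roughly how the new op is expected to be called on a build that includes this patch. The tensor and quantization parameters below are made up.

```python
import torch

x = torch.randn(4, 4)
qx = torch.quantize_per_tensor(x, scale=0.05, zero_point=64, dtype=torch.quint8)

# Existing behavior: the quantized CPU kernel behind F.leaky_relu reuses the
# input's scale/zero_point for the output (no observation of the output range).
qy_inherited = torch.nn.functional.leaky_relu(qx, negative_slope=0.01)

# New op added by this PR: the caller passes the output quantization
# parameters explicitly (in practice they come from an observer during QAT).
qy_observed = torch.ops.quantized.leaky_relu(qx, 0.01, False, 0.1, 0)
```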

Test Plan: Imported from OSS

Reviewed By: raghuramank100

Differential Revision: D24067681

fbshipit-source-id: d216738344363794b82bd3d75c8587a4b9415bca
jerryzh168 authored and facebook-github-bot committed Oct 6, 2020
1 parent 001a799 commit d1fc155
Showing 4 changed files with 56 additions and 14 deletions.
2 changes: 1 addition & 1 deletion aten/src/ATen/native/native_functions.yaml
@@ -6970,7 +6970,7 @@
python_module: nn
dispatch:
CPU, CUDA: leaky_relu
QuantizedCPU: heaky_relu_quantized_cpu
QuantizedCPU: leaky_relu_quantized_cpu

- func: leaky_relu_backward(Tensor grad_output, Tensor self, Scalar negative_slope, bool self_is_result) -> Tensor
use_c10_dispatcher: full
21 changes: 20 additions & 1 deletion aten/src/ATen/native/quantized/cpu/qrelu.cpp
@@ -113,7 +113,7 @@ Tensor& leaky_relu_out_quantized_cpu(Tensor& result, const Tensor& self,
return result;
}

Tensor heaky_relu_quantized_cpu(const Tensor& self, Scalar negval) {
Tensor leaky_relu_quantized_cpu(const Tensor& self, Scalar negval) {
const auto qx = self.contiguous(self.suggest_memory_format());
auto qy = at::_empty_affine_quantized(qx.sizes(),
at::device(kCPU).dtype(self.scalar_type()),
@@ -170,8 +170,27 @@ class QRelu6 final {
}
};

class QLeakyRelu final {
public:
static Tensor run(Tensor self, Scalar negative_slope, bool inplace, double output_scale, int64_t output_zero_point) {
// The inplace argument is ignored for now. TODO: support inplace
if (inplace) {
TORCH_WARN("inplace=True is not supported for quantized::leaky_relu yet");
}
const auto qx = self.contiguous(self.suggest_memory_format());
auto qy = at::_empty_affine_quantized(qx.sizes(),
at::device(kCPU).dtype(self.scalar_type()),
output_scale,
output_zero_point,
self.suggest_memory_format());
qrelu_leaky_stub(self.device().type(), qy, qx, negative_slope);
return qy;
}
};

TORCH_LIBRARY_IMPL(quantized, QuantizedCPU, m) {
m.impl(TORCH_SELECTIVE_NAME("quantized::relu6"), TORCH_FN(QRelu6::run));
m.impl(TORCH_SELECTIVE_NAME("quantized::leaky_relu"), TORCH_FN(QLeakyRelu::run));
}

} // namespace
1 change: 1 addition & 0 deletions aten/src/ATen/native/quantized/library.cpp
@@ -157,6 +157,7 @@ TORCH_LIBRARY(quantized, m) {
m.def(TORCH_SELECTIVE_SCHEMA("quantized::max_pool1d(Tensor qx, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode) -> Tensor"));
m.def(TORCH_SELECTIVE_SCHEMA("quantized::max_pool2d(Tensor qx, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode) -> Tensor"));
m.def(TORCH_SELECTIVE_SCHEMA("quantized::relu6(Tensor qx, bool inplace=False) -> Tensor"));
m.def(TORCH_SELECTIVE_SCHEMA("quantized::leaky_relu(Tensor qx, Scalar negative_slope, bool inplace, float output_scale, int output_zero_point) -> Tensor"));
}

// According to #33294: The "_" prefix registration will be
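Not part of this diff, but for reference: the schema above also accepts keyword arguments matching the declared names, which is how the new test below invokes the op. A small sketch with made-up values:

```python
import torch

qx = torch.quantize_per_tensor(torch.randn(2, 3), scale=0.1,
                               zero_point=0, dtype=torch.quint8)
qy = torch.ops.quantized.leaky_relu(
    qx, negative_slope=0.1, inplace=False,
    output_scale=0.2, output_zero_point=0)
```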
46 changes: 34 additions & 12 deletions test/quantization/test_quantized_op.py
@@ -140,17 +140,17 @@ def _test_activation_function(self, X, fn_name, test_configs):
quantized_fn: a list of the quantized functions to be tested
reference_fn: the original reference function to be called on the
dequantized X
inplace_kwarg: the additional inplace keyword argument to test in-place
extra_kwargs: the additional keyword arguments
for each test entry in ops_under_test, it must have at least the fields
for quantized_fn and reference_fn. If inplace_kwarg is missing, the
quantized function is assumed to be either inplace by default or the
test is not testing an inplace function.
for quantized_fn and reference_fn.
output_range: the output range the operator will map to. By default, if it is
not specified, the range will not be controlled and will depend on Xmin and Xmax.
change_zero_point: a boolean flag indicating if the zero point parameter should
be determined based on torch_type during quantization (see sigmoid/hardsigmoid for
examples). By default, if it is not specified, change_zero_point is assumed to be
False and zero point will just take on the default value from X.
`output_is_observed`: if specified and True, extra output_scale/output_zero_point
keyword arguments are appended when calling the quantized op
"""
# Retrieves the default parameters from X.
X, (scale, zero_point, torch_type) = X
@@ -162,15 +162,15 @@ def _test_activation_function(self, X, fn_name, test_configs):
for op_group in test_configs:
ref_op = op_group['reference_fn']
for q_op in op_group['quantized_fn']:
# Retrieves the inplace keyword arguments
# some functions require inplace=True to test in-place.
extra_kwargs = op_group.get('extra_kwargs', dict())
output_is_observed = op_group.get('output_is_observed', False)
# Quantizes and dequantizes to account for max error.
qX = torch.quantize_per_tensor(X, scale=scale, zero_point=zero_point,
dtype=torch_type)
dqX = qX.dequantize()
dqY_hat = ref_op(dqX.clone())

# Retrieves the inplace keyword arguments
# some functions require inplace=True to test in-place.
inplace_kwarg = op_group.get('inplace_kwarg', dict())
dqY_hat = ref_op(dqX.clone(), **extra_kwargs)

# Adjusts output_scale if needed.
# The output_scale determines the quantization scale for functions that
@@ -194,8 +194,11 @@ def _test_activation_function(self, X, fn_name, test_configs):
zero_point=output_zero_point,
dtype=torch_type)

if output_is_observed:
extra_kwargs.update({'output_scale': scale, 'output_zero_point': zero_point})

# Finds qY using in-place or non-in-place quantized operators.
qY = q_op(qX, **inplace_kwarg)
qY = q_op(qX, **extra_kwargs)

self.assertEqual(qY, qY_hat, msg='{} - {} failed: ({} vs. {})'.format(
fn_name, q_op, qY, qY_hat
@@ -222,7 +225,7 @@ def test_qrelu(self, X):
torch.nn.quantized.functional.relu,
],
'reference_fn': torch.nn.functional.relu,
'inplace_kwarg': {
'extra_kwargs': {
'inplace': True
}
}
@@ -280,11 +283,30 @@ def test_qhardsigmoid(self, X):
]
self._test_activation_function(X, 'hardsigmoid', hardsigmoid_test_configs)

@override_qengines
@given(X=hu.tensor(shapes=hu.array_shapes(1, 5, 1, 5),
qparams=hu.qparams()))
def test_leaky_relu_observed_output(self, X):
leaky_relu_test_configs = [
{
'quantized_fn': [
torch.ops.quantized.leaky_relu
],
'reference_fn': torch.nn.functional.leaky_relu,
'extra_kwargs': {
'negative_slope': 0.1,
'inplace': False,
},
'output_is_observed': True,
}
]
self._test_activation_function(X, 'leaky_relu', leaky_relu_test_configs)

"""Tests the correctness of the quantized::relu op."""
@given(X=hu.tensor(shapes=hu.array_shapes(1, 5, 1, 5),
qparams=hu.qparams()),
alpha=st.floats(0.0, 1.0, allow_nan=False, allow_infinity=False))
def test_qrelu_leaky(self, X, alpha):
def test_leaky_relu(self, X, alpha):
X, (scale, zero_point, torch_type) = X

X = torch.from_numpy(X)
