add hardswish FP operator #34747

Closed
wants to merge 16 commits
28 changes: 28 additions & 0 deletions aten/src/ATen/native/Activation.cpp
@@ -23,6 +23,8 @@ DEFINE_DISPATCH(threshold_stub);
DEFINE_DISPATCH(hardtanh_backward_stub);
DEFINE_DISPATCH(hardsigmoid_stub);
DEFINE_DISPATCH(hardsigmoid_backward_stub);
DEFINE_DISPATCH(hardswish_stub);
DEFINE_DISPATCH(hardswish_backward_stub);
DEFINE_DISPATCH(hardshrink_stub);
DEFINE_DISPATCH(softshrink_stub);
DEFINE_DISPATCH(shrink_backward_stub);
@@ -136,6 +138,32 @@ Tensor elu_backward(
return iter.output();
}

Tensor hardswish(const Tensor& self) {
Tensor result;
auto iter = TensorIterator::unary_op(result, self);
hardswish_stub(iter.device_type(), iter);
return iter.output();
}

Tensor& hardswish_out(Tensor& result, const Tensor& self) {
auto iter = TensorIterator::unary_op(result, self);
hardswish_stub(iter.device_type(), iter);
return result;
}

Tensor& hardswish_(Tensor& self) {
auto iter = TensorIterator::unary_op(self, self);
hardswish_stub(iter.device_type(), iter);
return self;
}

Tensor hardswish_backward(const Tensor& grad_output, const Tensor& self) {
Tensor grad_input;
auto iter = TensorIterator::binary_op(grad_input, grad_output, self);
hardswish_backward_stub(iter.device_type(), iter);
return iter.output();
}

Tensor relu(const Tensor & self) {
return at::threshold(self, 0, 0);
}
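Note for readers skimming the diff: the functions above only build a TensorIterator and defer the element-wise math to hardswish_stub / hardswish_backward_stub, which each backend registers separately. A minimal Python sketch of what the forward stub computes per element (the helper name is illustrative, not part of this PR):

import torch

def hardswish_reference(x):
    # x * relu6(x + 3) / 6, written with existing ops; clamp(x + 3, 0, 6) is ReLU6(x + 3)
    return x * torch.clamp(x + 3.0, min=0.0, max=6.0) / 6.0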
4 changes: 4 additions & 0 deletions aten/src/ATen/native/Activation.h
@@ -18,6 +18,8 @@ using threshold_fn = void (*)(TensorIterator&, Scalar, Scalar);
using hardtanh_backward_fn = void (*)(TensorIterator&, Scalar, Scalar);
using hardsigmoid_fn = void(*)(TensorIterator&);
using hardsigmoid_backward_fn = void(*)(TensorIterator&);
using hardswish_fn = void(*)(TensorIterator&);
using hardswish_backward_fn = void(*)(TensorIterator&);
using shrink_fn = void (*)(TensorIterator&, Scalar);
using shrink_backward_fn = void (*)(TensorIterator&, Scalar);
using elu_fn = void (*)(TensorIterator&, Scalar, Scalar, Scalar);
@@ -37,6 +39,8 @@ DECLARE_DISPATCH(activation_backward_fn, GeluBackwardKernel);
DECLARE_DISPATCH(hardtanh_backward_fn, hardtanh_backward_stub);
DECLARE_DISPATCH(hardsigmoid_fn, hardsigmoid_stub);
DECLARE_DISPATCH(hardsigmoid_backward_fn, hardsigmoid_backward_stub);
DECLARE_DISPATCH(hardswish_fn, hardswish_stub);
DECLARE_DISPATCH(hardswish_backward_fn, hardswish_backward_stub);
DECLARE_DISPATCH(shrink_fn, hardshrink_stub);
DECLARE_DISPATCH(shrink_fn, softshrink_stub);
DECLARE_DISPATCH(shrink_backward_fn, shrink_backward_stub);
63 changes: 63 additions & 0 deletions aten/src/ATen/native/cpu/Activation.cpp
@@ -412,6 +412,67 @@ void hardtanh_backward_kernel(TensorIterator& iter, Scalar min, Scalar max) {
});
}

void hardswish_kernel(TensorIterator& iter) {
AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "hardswish_cpu", [&]() {
const scalar_t zero(0.0f);
const scalar_t three(3.0f);
const scalar_t six(6.0f);
using Vec = vec256::Vec256<scalar_t>;
const Vec kZeroVec(zero);
const Vec kThreeVec(three);
const Vec kSixVec(six);
cpu_kernel_vec(
iter,
[&](scalar_t x) {
return x * std::min(std::max(x + three, zero), six) / six;
},
[&](Vec x_vec) {
return x_vec * vec256::minimum(
vec256::maximum(x_vec + kThreeVec, kZeroVec),
kSixVec
) / kSixVec;
}
);
});
}

void hardswish_backward_kernel(TensorIterator& iter) {
AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "hardswish_backward_cpu", [&]() {
const scalar_t zero(0.0f);
const scalar_t three(3.0f);
const scalar_t neg_three(-3.0f);
const scalar_t one_half(0.5f);
using Vec = vec256::Vec256<scalar_t>;
const Vec kZeroVec(zero);
const Vec kThreeVec(three);
const Vec kNegThreeVec(neg_three);
const Vec kOneHalfVec(one_half);
cpu_kernel_vec(
iter,
[&](scalar_t grad_val, scalar_t self_val) {
if (self_val < neg_three) {
return zero;
} else if (self_val <= three) {
return grad_val * ((self_val / three) + one_half);
} else {
return grad_val;
}
},
[&](Vec grad_val, Vec self_val) {
return Vec::blendv(
Vec::blendv(
grad_val * ((self_val / kThreeVec) + kOneHalfVec),
grad_val,
self_val >= kThreeVec
),
kZeroVec,
self_val < kNegThreeVec
);
}
);
});
}

static void leaky_relu_kernel(TensorIterator& iter, Scalar negval_) {
AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "leaky_relu_cpu", [&] {
using Vec = Vec256<scalar_t>;
@@ -538,6 +599,8 @@ REGISTER_DISPATCH(GeluBackwardKernel, &GeluBackwardKernelImpl);
REGISTER_DISPATCH(hardtanh_backward_stub, &hardtanh_backward_kernel);
REGISTER_DISPATCH(hardsigmoid_stub, &hardsigmoid_kernel);
REGISTER_DISPATCH(hardsigmoid_backward_stub, &hardsigmoid_backward_kernel);
REGISTER_DISPATCH(hardswish_stub, &hardswish_kernel);
REGISTER_DISPATCH(hardswish_backward_stub, &hardswish_backward_kernel);
REGISTER_DISPATCH(hardshrink_stub, &hardshrink_kernel);
REGISTER_DISPATCH(softshrink_stub, &softshrink_kernel);
REGISTER_DISPATCH(shrink_backward_stub, &shrink_backward_kernel);
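The backward kernel above encodes the piecewise derivative of hardswish: zero below -3, the incoming gradient above 3, and grad * (x/3 + 1/2) in between; the vectorized lambda expresses the same selection with two nested Vec256::blendv calls instead of an if/else chain. A scalar Python sketch mirroring the scalar branch (helper name is illustrative only):

def hardswish_grad_reference(grad_val, self_val):
    # per-element rule implemented by hardswish_backward_kernel
    if self_val < -3.0:
        return 0.0
    elif self_val <= 3.0:
        return grad_val * (self_val / 3.0 + 0.5)
    else:
        return grad_val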
14 changes: 14 additions & 0 deletions aten/src/ATen/native/native_functions.yaml
@@ -5676,6 +5676,20 @@
CUDA: hardtanh_
QuantizedCPU: quantized_hardtanh_

- func: hardswish.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
python_module: nn

- func: hardswish(Tensor self) -> Tensor
use_c10_dispatcher: full
python_module: nn

- func: hardswish_(Tensor(a!) self) -> Tensor(a!)
python_module: nn

- func: hardswish_backward(Tensor grad_output, Tensor self) -> Tensor
use_c10_dispatcher: full
python_module: nn

- func: leaky_relu.out(Tensor self, Scalar negative_slope=0.01, *, Tensor(a!) out) -> Tensor(a!)
python_module: nn
dispatch:
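Because each schema above sets python_module: nn, the generated bindings land under torch._C._nn, which is what the torch.nn.functional.hardswish wrapper added later in this PR calls. A short sketch of the out-of-place and in-place entry points (how the .out overload is exposed is not shown here):

import torch

x = torch.randn(8)
y = torch._C._nn.hardswish(x)   # hardswish(Tensor self) -> Tensor
torch._C._nn.hardswish_(x)      # hardswish_(Tensor(a!) self) -> Tensor(a!), modifies x in place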
2 changes: 1 addition & 1 deletion benchmarks/operator_benchmark/benchmark_all_other_test.py
@@ -8,7 +8,7 @@
add_test, as_strided_test, batchnorm_test, binary_test, cat_test, # noqa
chunk_test, conv_test, diag_test, embeddingbag_test, fill_test, # noqa
gather_test, linear_test, matmul_test, pool_test, # noqa
softmax_test, hardsigmoid_test # noqa
softmax_test, hardsigmoid_test, hardswish_test # noqa
)

if __name__ == "__main__":
66 changes: 66 additions & 0 deletions benchmarks/operator_benchmark/pt/hardswish_test.py
@@ -0,0 +1,66 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals


import operator_benchmark as op_bench
import torch
import torch.nn as nn


"""
Microbenchmarks for the hardswish operators.
"""


# Configs for hardswish ops
hardswish_configs_short = op_bench.config_list(
attr_names=[
'N', 'C', 'H', 'W'
],
attrs=[
[1, 3, 256, 256],
[4, 3, 256, 256],
],
cross_product_configs={
'device': ['cpu'],
},
tags=['short']
)


hardswish_configs_long = op_bench.cross_product_configs(
N=[8, 16],
C=[3],
H=[256, 512],
W=[256, 512],
device=['cpu'],
tags=['long']
)


hardswish_ops_list = op_bench.op_list(
attr_names=['op_name', 'op_func'],
attrs=[
['Hardswish', nn.Hardswish],
],
)


class HardswishBenchmark(op_bench.TorchBenchmarkBase):
def init(self, N, C, H, W, device, op_func):
self.input_one = torch.rand(N, C, H, W, device=device)
self.op_func = op_func()

def forward(self):
return self.op_func(self.input_one)


op_bench.generate_pt_tests_from_op_list(hardswish_ops_list,
hardswish_configs_short + hardswish_configs_long,
HardswishBenchmark)


if __name__ == "__main__":
op_bench.benchmark_runner.main()
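The harness above exercises the NCHW shapes listed in the configs through the operator_benchmark framework. For a quick ad-hoc timing outside that framework, a rough standalone sketch (shapes and iteration count are arbitrary):

import time
import torch

x = torch.rand(4, 3, 256, 256)
torch.nn.functional.hardswish(x)  # warm-up
start = time.perf_counter()
for _ in range(100):
    torch.nn.functional.hardswish(x)
print("avg seconds per call:", (time.perf_counter() - start) / 100)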
23 changes: 23 additions & 0 deletions test/test_torch.py
@@ -14159,6 +14159,29 @@ def test_exp_slow(self, device, dtype):
b = torch.exp(torch.ones(1, dtype=dtype, device=device))
self.assertEqual(a, b.expand(2 ** 31))

@onlyCPU
@dtypes(torch.float, torch.double)
@unittest.skipIf(not TEST_NUMPY, "Numpy not found")
def test_hardswish(self, device, dtype):
inputValues = [-1000, -4, -3, -2, 0, 2, 3, 4, 1000]
expectedOutput = np.multiply(
inputValues,
np.minimum(np.maximum((np.add(inputValues, 3)), 0), 6) / 6.0)
precision_4dps = 0.0002

inputTensor = torch.tensor(inputValues, dtype=dtype, device=device)
expectedOutputTensor = \
torch.tensor(expectedOutput, dtype=dtype, device=device)

# normal
self.assertEqual(torch.nn.functional.hardswish(inputTensor),
expectedOutputTensor, precision_4dps)

# inplace
inputTensorCpy = inputTensor.clone().detach()
torch.nn.functional.hardswish(inputTensorCpy, inplace=True)
self.assertEqual(inputTensorCpy, expectedOutputTensor, precision_4dps)

@onlyCPU
@dtypes(torch.float, torch.double)
def test_sigmoid(self, device, dtype):
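The test above checks forward values against a NumPy reference. A natural companion check, not part of this PR, would be a gradcheck of the backward in double precision, with sample points kept away from the kinks at +/-3 where hardswish is not differentiable:

x = torch.tensor([-4.0, -1.0, 0.5, 4.0], dtype=torch.double, requires_grad=True)
torch.autograd.gradcheck(torch.nn.functional.hardswish, (x,))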
3 changes: 3 additions & 0 deletions tools/autograd/derivatives.yaml
@@ -417,6 +417,9 @@
- name: histc(Tensor self, int bins=100, Scalar min=0, Scalar max=0) -> Tensor
self: not_implemented("histc")

- name: hardswish(Tensor self) -> Tensor
self: hardswish_backward(grad, self)

- name: imag(Tensor self) -> Tensor
self: Scalar(std::complex<double>{0.0, 1.0})*grad.to(self.scalar_type())

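The entry above is what routes autograd to hardswish_backward. A minimal sanity check of the wiring, evaluated in the quadratic region where d/dx hardswish(x) = (2x + 3) / 6 (equivalently x/3 + 1/2):

import torch

x = torch.tensor(1.5, requires_grad=True)
torch.nn.functional.hardswish(x).backward()
print(x.grad)  # expect approximately (2 * 1.5 + 3) / 6 = 1.0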
1 change: 1 addition & 0 deletions torch/_overrides.py
@@ -133,6 +133,7 @@ def get_ignored_functions():
torch.clear_autocast_cache,
torch.autocast_increment_nesting,
torch.autocast_decrement_nesting,
torch.nn.functional.hardswish,
)

def get_testing_overrides():
29 changes: 25 additions & 4 deletions torch/nn/functional.py
@@ -1637,6 +1637,27 @@ def bilinear(input1, input2, weight, bias=None):
return torch.bilinear(input1, input2, weight, bias)


def hardswish(input, inplace=False):
r"""Applies the hardswish function, element-wise, as described in the paper:

`Searching for MobileNetV3`_.

.. math::
\text{Hardswish}(x) = x \cdot \frac{\text{ReLU6}(x + 3)}{6}

See :class:`~torch.nn.Hardswish` for more details.

.. _`Searching for MobileNetV3`:
https://arxiv.org/abs/1905.02244
"""
if not torch.jit.is_scripting():
if type(input) is not Tensor and has_torch_function((input,)):
return handle_torch_function(hardswish, (input,), input, inplace=inplace)
if inplace:
return torch._C._nn.hardswish_(input)
return torch._C._nn.hardswish(input)


def _no_grad_embedding_renorm_(weight, input, max_norm, norm_type):
# type: (Tensor, Tensor, float, float) -> Tensor
with torch.no_grad():
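A short usage sketch of the new functional (the comment shows the expected results, approximately, up to print formatting):

>>> import torch
>>> import torch.nn.functional as F
>>> x = torch.tensor([-4.0, -1.0, 0.0, 2.0, 5.0])
>>> F.hardswish(x)                # element-wise x * relu6(x + 3) / 6 -> [0, -1/3, 0, 5/3, 5]
>>> F.hardswish(x, inplace=True)  # writes the same values back into x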
@@ -2821,7 +2842,7 @@ def _interp_output_size(dim, closed_over_args): # noqa: F811
if size is None and scale_factor is None:
raise ValueError('either size or scale_factor should be defined')
if size is not None and scale_factor is not None:
raise ValueError('only one of size or scale_factor should be defined')
raise ValueError('only one of size or scale_factor should be defined')
if scale_factor is not None:
if isinstance(scale_factor, (list, tuple)):
if len(scale_factor) != dim:
@@ -2987,7 +3008,7 @@ def interpolate(input, size=None, scale_factor=None, mode='nearest', align_corne
if input.dim() == 3 and mode == 'nearest':
return torch._C._nn.upsample_nearest1d(input, _interp_output_size(1, closed_over_args), scale_factor_list[0])
elif input.dim() == 4 and mode == 'nearest':
return torch._C._nn.upsample_nearest2d(input, _interp_output_size(2, closed_over_args),
return torch._C._nn.upsample_nearest2d(input, _interp_output_size(2, closed_over_args),
scale_factor_list[0], scale_factor_list[1])
elif input.dim() == 5 and mode == 'nearest':
return torch._C._nn.upsample_nearest3d(input, _interp_output_size(3, closed_over_args),
@@ -3009,7 +3030,7 @@ def interpolate(input, size=None, scale_factor=None, mode='nearest', align_corne
raise NotImplementedError("Got 4D input, but linear mode needs 3D input")
elif input.dim() == 4 and mode == 'bilinear':
assert align_corners is not None
return torch._C._nn.upsample_bilinear2d(input, _interp_output_size(2, closed_over_args), align_corners,
return torch._C._nn.upsample_bilinear2d(input, _interp_output_size(2, closed_over_args), align_corners,
scale_factor_list[0], scale_factor_list[1])
elif input.dim() == 4 and mode == 'trilinear':
raise NotImplementedError("Got 4D input, but trilinear mode needs 5D input")
@@ -3023,7 +3044,7 @@ def interpolate(input, size=None, scale_factor=None, mode='nearest', align_corne
scale_factor_list[0], scale_factor_list[1], scale_factor_list[2])
elif input.dim() == 4 and mode == 'bicubic':
assert align_corners is not None
return torch._C._nn.upsample_bicubic2d(input, _interp_output_size(2, closed_over_args), align_corners,
return torch._C._nn.upsample_bicubic2d(input, _interp_output_size(2, closed_over_args), align_corners,
scale_factor_list[0], scale_factor_list[1])
else:
raise NotImplementedError("Input Error: Only 3D, 4D and 5D input Tensors supported"
4 changes: 2 additions & 2 deletions torch/nn/modules/__init__.py
@@ -5,7 +5,7 @@
from .activation import Threshold, ReLU, Hardtanh, ReLU6, Sigmoid, Tanh, \
Softmax, Softmax2d, LogSoftmax, ELU, SELU, CELU, GELU, Hardshrink, LeakyReLU, LogSigmoid, \
Softplus, Softshrink, MultiheadAttention, PReLU, Softsign, Softmin, Tanhshrink, RReLU, GLU, \
Hardsigmoid
Hardsigmoid, Hardswish
from .loss import L1Loss, NLLLoss, KLDivLoss, MSELoss, BCELoss, BCEWithLogitsLoss, NLLLoss2d, \
CosineEmbeddingLoss, CTCLoss, HingeEmbeddingLoss, MarginRankingLoss, \
MultiLabelMarginLoss, MultiLabelSoftMarginLoss, MultiMarginLoss, \
@@ -54,5 +54,5 @@
'ConstantPad3d', 'Bilinear', 'CosineSimilarity', 'Unfold', 'Fold',
'AdaptiveLogSoftmaxWithLoss', 'TransformerEncoder', 'TransformerDecoder',
'TransformerEncoderLayer', 'TransformerDecoderLayer', 'Transformer',
'Flatten', 'Hardsigmoid',
'Flatten', 'Hardsigmoid', 'Hardswish',
]
27 changes: 27 additions & 0 deletions torch/nn/modules/activation.py
@@ -318,6 +318,33 @@ def forward(self, input):
return torch.tanh(input)


class Hardswish(Module):
r"""Applies the hardswish function, element-wise, as described in the paper:

`Searching for MobileNetV3`_.

.. math::
\text{Hardswish}(x) = x \cdot \frac{\text{ReLU6}(x + 3)}{6}

Shape:
- Input: :math:`(N, *)` where `*` means any number of additional
dimensions
- Output: :math:`(N, *)`, same shape as the input

Examples::

>>> m = nn.Hardswish()
>>> input = torch.randn(2)
>>> output = m(input)

.. _`Searching for MobileNetV3`:
https://arxiv.org/abs/1905.02244
"""

def forward(self, input):
return F.hardswish(input)


class ELU(Module):
r"""Applies the element-wise function:

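For context, the new module slots into networks like any other element-wise activation. A small illustrative block in the spirit of MobileNetV3 (channel sizes and shapes are arbitrary, not taken from the paper):

import torch
import torch.nn as nn

block = nn.Sequential(
    nn.Conv2d(16, 16, kernel_size=3, padding=1),
    nn.BatchNorm2d(16),
    nn.Hardswish(),
)
y = block(torch.randn(1, 16, 32, 32))  # shape preserved: (1, 16, 32, 32)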