add Half support for sigmoid on CPU #96077

Closed · wants to merge 9 commits
aten/src/ATen/Dispatch.h (8 additions, 0 deletions)
@@ -258,6 +258,14 @@ inline void deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF_AND_COMPLEX() {}
   AT_DISPATCH_SWITCH( \
       TYPE, NAME, AT_DISPATCH_CASE_FLOATING_TYPES_AND_HALF(__VA_ARGS__))
 
+#define AT_DISPATCH_CASE_REDUCED_FLOATING_TYPES(...) \
+  AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
+  AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)
+
+#define AT_DISPATCH_REDUCED_FLOATING_TYPES(TYPE, NAME, ...) \
+  AT_DISPATCH_SWITCH( \
+      TYPE, NAME, AT_DISPATCH_CASE_REDUCED_FLOATING_TYPES(__VA_ARGS__))
+
 #define AT_DISPATCH_CASE_FLOATING_TYPES_AND(SCALARTYPE, ...) \
   AT_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__) \
   AT_DISPATCH_CASE(SCALARTYPE, __VA_ARGS__)
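For reference, a minimal sketch of how the new AT_DISPATCH_REDUCED_FLOATING_TYPES macro would be used in a kernel. The function name, the op-name string, and the loop body are illustrative only (not part of this PR), and the sketch assumes a contiguous CPU tensor:

#include <ATen/ATen.h>
#include <ATen/Dispatch.h>

// Illustrative only: fill a contiguous reduced-precision (Half or BFloat16)
// CPU tensor with 1.0 using the new dispatch macro.
void fill_one_reduced(at::Tensor& t) {
  AT_DISPATCH_REDUCED_FLOATING_TYPES(t.scalar_type(), "fill_one_reduced", [&]() {
    // Inside the lambda, scalar_t is at::Half or at::BFloat16; any other
    // dtype hits the usual "not implemented for ..." dispatch error.
    scalar_t* data = t.data_ptr<scalar_t>();
    for (int64_t i = 0; i < t.numel(); ++i) {
      data[i] = static_cast<scalar_t>(1.0f);
    }
  });
}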
aten/src/ATen/native/cpu/UnaryOpsKernel.cpp (18 additions, 15 deletions)
@@ -35,22 +35,25 @@ inline namespace CPU_CAPABILITY {
 using namespace vec;
 
 static void sigmoid_kernel(TensorIteratorBase& iter) {
-  if (iter.common_dtype() == kBFloat16) {
-    cpu_kernel_vec(
-        iter,
-        [=](BFloat16 a) -> BFloat16 {
-          float a0 = static_cast<float>(a);
-          return static_cast<float>(1) / (static_cast<float>(1) + std::exp((-a0)));
-        },
-        [=](Vectorized<BFloat16> a) {
-          Vectorized<float> a0, a1;
-          std::tie(a0, a1) = convert_bfloat16_float(a);
-          a0 = (Vectorized<float>(static_cast<float>(1)) + a0.neg().exp()).reciprocal();
-          a1 = (Vectorized<float>(static_cast<float>(1)) + a1.neg().exp()).reciprocal();
-          return convert_float_bfloat16(a0, a1);
-        });
+  const auto dtype = iter.common_dtype();
+  if (at::isReducedFloatingType(dtype)) {
+    AT_DISPATCH_REDUCED_FLOATING_TYPES(dtype, "sigmoid_cpu_reduced_float", [&]() {
+      cpu_kernel_vec(
+          iter,
+          [=](scalar_t a) -> scalar_t {
+            float a0 = static_cast<float>(a);
+            return static_cast<float>(1) / (static_cast<float>(1) + std::exp((-a0)));
+          },
+          [=](Vectorized<scalar_t> a) {
+            Vectorized<float> a0, a1;
+            std::tie(a0, a1) = convert_to_float<scalar_t>(a);
+            a0 = (Vectorized<float>(static_cast<float>(1)) + a0.neg().exp()).reciprocal();
+            a1 = (Vectorized<float>(static_cast<float>(1)) + a1.neg().exp()).reciprocal();
+            return convert_from_float<scalar_t>(a0, a1);
+          });
+    });
   } else {
-    AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(iter.common_dtype(), "sigmoid_cpu", [&]() {
+    AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(dtype, "sigmoid_cpu", [&]() {
       cpu_kernel_vec(
           iter,
           [=](scalar_t a) -> scalar_t {
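The change above routes Half the same way BFloat16 was already handled: the reduced-precision value is upcast to float, sigmoid(x) = 1 / (1 + exp(-x)) is evaluated in float, and the result is rounded back on store. A minimal standalone sketch of that scalar pattern for Half (the helper name is illustrative; the PR's vectorized path does the equivalent via convert_to_float / convert_from_float):

#include <ATen/ATen.h>
#include <cmath>

// Illustrative only: scalar sigmoid for Half, computed in float so that
// exp() keeps full range and precision, then rounded back to Half on return.
static at::Half sigmoid_half_scalar(at::Half a) {
  float a0 = static_cast<float>(a);
  float y = 1.0f / (1.0f + std::exp(-a0));
  return static_cast<at::Half>(y);
}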
c10/core/ScalarType.h (4 additions, 0 deletions)
@@ -263,6 +263,10 @@ static inline bool isFloatingType(ScalarType t) {
       t == ScalarType::Half || t == ScalarType::BFloat16);
 }
 
+static inline bool isReducedFloatingType(ScalarType t) {
+  return (t == ScalarType::Half || t == ScalarType::BFloat16);
+}
+
 static inline bool isComplexType(ScalarType t) {
   return (
       t == ScalarType::ComplexHalf || t == ScalarType::ComplexFloat ||
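A small sketch exercising the new predicate; the function name is illustrative. Note that it is true only for the real reduced-precision types, so ComplexHalf is excluded:

#include <c10/core/ScalarType.h>
#include <cassert>

// Illustrative only: what the new predicate reports for a few dtypes.
void check_reduced_float_predicate() {
  assert(c10::isReducedFloatingType(c10::ScalarType::Half));
  assert(c10::isReducedFloatingType(c10::ScalarType::BFloat16));
  assert(!c10::isReducedFloatingType(c10::ScalarType::Float));
  assert(!c10::isReducedFloatingType(c10::ScalarType::ComplexHalf));  // complex types are not "reduced floating"
}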
torch/testing/_internal/common_methods_invocations.py (1 addition, 1 deletion)
@@ -16561,7 +16561,7 @@ def reference_flatten(input, start_dim=0, end_dim=-1):
              dtypes=[torch.complex64, torch.cdouble]),
 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large',
              dtypes=[torch.chalf, torch.complex64, torch.cdouble])),
-dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16),
+dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
 dtypesIfCUDA=all_types_and_complex_and(torch.complex32, torch.bool, torch.half, torch.bfloat16),
 supports_forward_ad=True,
 supports_fwgrad_bwgrad=True,
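The OpInfo change above extends sigmoid's CPU test coverage to torch.float16, which is the user-visible effect of this PR (in Python, torch.randn(4).half().sigmoid() on CPU). A hedged C++ sketch of the same call through ATen; only at::sigmoid's new Half CPU path is from the PR, the rest is illustrative:

#include <ATen/ATen.h>
#include <iostream>

int main() {
  // Build a float tensor and downcast to Half; with this PR the CPU sigmoid
  // kernel handles Half directly instead of raising a dtype error.
  at::Tensor x = at::randn({4}).to(at::kHalf);
  at::Tensor y = at::sigmoid(x);
  std::cout << y << "\n";
  return 0;
}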