add Half support for sigmoid on CPU (#96077)
Pull Request resolved: #96077
Approved by: https://github.com/jgong5, https://github.com/ezyang
mingfeima authored and pytorchmergebot committed Apr 4, 2023
1 parent 89dc87a · commit 34c7adf
Showing 4 changed files with 31 additions and 16 deletions.
aten/src/ATen/Dispatch.h (8 additions, 0 deletions)
@@ -258,6 +258,14 @@ inline void deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF_AND_COMPLEX() {}
   AT_DISPATCH_SWITCH( \
       TYPE, NAME, AT_DISPATCH_CASE_FLOATING_TYPES_AND_HALF(__VA_ARGS__))
 
+#define AT_DISPATCH_CASE_REDUCED_FLOATING_TYPES(...)  \
+  AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
+  AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)
+
+#define AT_DISPATCH_REDUCED_FLOATING_TYPES(TYPE, NAME, ...) \
+  AT_DISPATCH_SWITCH(                                       \
+      TYPE, NAME, AT_DISPATCH_CASE_REDUCED_FLOATING_TYPES(__VA_ARGS__))
+
 #define AT_DISPATCH_CASE_FLOATING_TYPES_AND(SCALARTYPE, ...) \
   AT_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__)               \
   AT_DISPATCH_CASE(SCALARTYPE, __VA_ARGS__)
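For context: the new macro follows the existing AT_DISPATCH_* pattern. It switches on a runtime ScalarType and instantiates the lambda body once per reduced-precision type, binding scalar_t to at::Half or at::BFloat16; any other dtype hits the dispatch switch's "not implemented" error. A minimal caller sketch (the function and its name are hypothetical, not part of this commit):

    #include <ATen/Dispatch.h>
    #include <ATen/core/Tensor.h>

    // Hypothetical caller: run a body only for Half/BFloat16 tensors.
    void reduced_float_only(const at::Tensor& t) {
      AT_DISPATCH_REDUCED_FLOATING_TYPES(t.scalar_type(), "reduced_float_only", [&]() {
        // scalar_t is at::Half or at::BFloat16 in this scope.
        scalar_t one = static_cast<scalar_t>(1.0f);
        (void)one;  // placeholder for real kernel work
      });
    }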
aten/src/ATen/native/cpu/UnaryOpsKernel.cpp (18 additions, 15 deletions)
@@ -35,22 +35,25 @@ inline namespace CPU_CAPABILITY {
 using namespace vec;
 
 static void sigmoid_kernel(TensorIteratorBase& iter) {
-  if (iter.common_dtype() == kBFloat16) {
-    cpu_kernel_vec(
-        iter,
-        [=](BFloat16 a) -> BFloat16 {
-          float a0 = static_cast<float>(a);
-          return static_cast<float>(1) / (static_cast<float>(1) + std::exp((-a0)));
-        },
-        [=](Vectorized<BFloat16> a) {
-          Vectorized<float> a0, a1;
-          std::tie(a0, a1) = convert_bfloat16_float(a);
-          a0 = (Vectorized<float>(static_cast<float>(1)) + a0.neg().exp()).reciprocal();
-          a1 = (Vectorized<float>(static_cast<float>(1)) + a1.neg().exp()).reciprocal();
-          return convert_float_bfloat16(a0, a1);
-        });
+  const auto dtype = iter.common_dtype();
+  if (at::isReducedFloatingType(dtype)) {
+    AT_DISPATCH_REDUCED_FLOATING_TYPES(dtype, "sigmoid_cpu_reduced_float", [&]() {
+      cpu_kernel_vec(
+          iter,
+          [=](scalar_t a) -> scalar_t {
+            float a0 = static_cast<float>(a);
+            return static_cast<float>(1) / (static_cast<float>(1) + std::exp((-a0)));
+          },
+          [=](Vectorized<scalar_t> a) {
+            Vectorized<float> a0, a1;
+            std::tie(a0, a1) = convert_to_float<scalar_t>(a);
+            a0 = (Vectorized<float>(static_cast<float>(1)) + a0.neg().exp()).reciprocal();
+            a1 = (Vectorized<float>(static_cast<float>(1)) + a1.neg().exp()).reciprocal();
+            return convert_from_float<scalar_t>(a0, a1);
+          });
+    });
   } else {
-    AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(iter.common_dtype(), "sigmoid_cpu", [&]() {
+    AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(dtype, "sigmoid_cpu", [&]() {
       cpu_kernel_vec(
           iter,
           [=](scalar_t a) -> scalar_t {
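The user-visible effect of this kernel change: sigmoid on a Half CPU tensor now takes the same widen-to-float, compute, narrow-back path that BFloat16 already used, rather than having no CPU kernel at all. A hedged usage sketch in ATen C++ (illustrative only; build setup and tensor printing details may vary by version):

    #include <ATen/ATen.h>
    #include <iostream>

    int main() {
      // CPU Half tensor. Before this change, sigmoid_kernel had no Half case.
      at::Tensor x = at::randn({8}).to(at::kHalf);
      at::Tensor y = x.sigmoid();  // computed internally in float, stored as Half
      std::cout << y << "\n";
      return 0;
    }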
c10/core/ScalarType.h (4 additions, 0 deletions)
@@ -263,6 +263,10 @@ static inline bool isFloatingType(ScalarType t) {
       t == ScalarType::Half || t == ScalarType::BFloat16);
 }
 
+static inline bool isReducedFloatingType(ScalarType t) {
+  return (t == ScalarType::Half || t == ScalarType::BFloat16);
+}
+
 static inline bool isComplexType(ScalarType t) {
   return (
       t == ScalarType::ComplexHalf || t == ScalarType::ComplexFloat ||
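Note that isReducedFloatingType(t) implies isFloatingType(t), so callers can peel off the Half/BFloat16 case first and fall through to full-precision dispatch, which is exactly the shape of sigmoid_kernel above. A small sketch (the helper is hypothetical):

    #include <c10/core/ScalarType.h>

    // Hypothetical helper: classify a dtype the way sigmoid_kernel branches.
    const char* kernel_flavor(c10::ScalarType t) {
      if (c10::isReducedFloatingType(t)) {
        return "compute-in-float";   // Half, BFloat16
      }
      if (c10::isFloatingType(t)) {
        return "native-precision";   // Float, Double (reduced types handled above)
      }
      return "non-floating";
    }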
torch/testing/_internal/common_methods_invocations.py (1 addition, 1 deletion)
@@ -16561,7 +16561,7 @@ def reference_flatten(input, start_dim=0, end_dim=-1):
                      dtypes=[torch.complex64, torch.cdouble]),
         DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large',
                      dtypes=[torch.chalf, torch.complex64, torch.cdouble])),
-    dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16),
+    dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
     dtypesIfCUDA=all_types_and_complex_and(torch.complex32, torch.bool, torch.half, torch.bfloat16),
     supports_forward_ad=True,
     supports_fwgrad_bwgrad=True,
