diff --git a/kernels/portable/cpu/op_sigmoid.cpp b/kernels/portable/cpu/op_sigmoid.cpp index 84c4ea2f542..34b2ec60dec 100644 --- a/kernels/portable/cpu/op_sigmoid.cpp +++ b/kernels/portable/cpu/op_sigmoid.cpp @@ -8,6 +8,7 @@ #include +#include #include #include @@ -35,21 +36,26 @@ Tensor& sigmoid_out(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) { out, "Failed to resize output tensor."); - ScalarType in_type = in.scalar_type(); - ScalarType out_type = out.scalar_type(); - ET_SWITCH_REALHB_TYPES(in_type, ctx, "sigmoid.out", CTYPE_IN, [&]() { - ET_SWITCH_FLOATH_TYPES(out_type, ctx, "sigmoid.out", CTYPE_OUT, [&]() { - apply_unary_map_fn( - [](const CTYPE_IN val_in) { - // perform math in double to preserve precision - double in_casted = static_cast(val_in); - double out_val = 1.0 / (1.0 + exp(-in_casted)); - return static_cast(out_val); - }, - in.const_data_ptr(), - out.mutable_data_ptr(), - in.numel()); - }); + ScalarType compute_type = + executorch::runtime::isFloatingType(in.scalar_type()) ? in.scalar_type() + : ScalarType::Float; + compute_type = utils::get_compute_type(compute_type); + + // @lint-ignore CLANGTIDY facebook-hte-CArray + static constexpr const char op_name[] = "sigmoid.out"; + + ET_SWITCH_FLOAT_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() { + utils::apply_unitensor_elementwise_fn( + [](const CTYPE_COMPUTE val_in) { + CTYPE_COMPUTE out_val = static_cast(1.0) / + (static_cast(1.0) + exp(-val_in)); + return out_val; + }, + ctx, + in, + utils::SupportedTensorDtypes::REALHBBF16, + out, + utils::SupportedTensorDtypes::FLOATHBF16); }); return out; diff --git a/shim/xplat/executorch/kernels/portable/op_registration_util.bzl b/shim/xplat/executorch/kernels/portable/op_registration_util.bzl index f63932d4840..b89c4ac1a95 100644 --- a/shim/xplat/executorch/kernels/portable/op_registration_util.bzl +++ b/shim/xplat/executorch/kernels/portable/op_registration_util.bzl @@ -1074,6 +1074,9 @@ ATEN_OPS = ( name = "op_sigmoid", deps = [ "//executorch/kernels/portable/cpu/util:functional_util", + "//executorch/kernels/portable/cpu/util:elementwise_util", + "//executorch/kernels/portable/cpu/util:broadcast_util", + "//executorch/kernels/portable/cpu/util:dtype_util", ], ), op_target(