diff --git a/backends/cadence/aot/functions_fusion_g3.yaml b/backends/cadence/aot/functions_fusion_g3.yaml
index 0feb1e47891..269e8a08e4b 100644
--- a/backends/cadence/aot/functions_fusion_g3.yaml
+++ b/backends/cadence/aot/functions_fusion_g3.yaml
@@ -171,6 +171,11 @@
   kernels:
     - arg_meta: null
       kernel_name: cadence::impl::G3::exp_out
+
+- op: hardtanh.out
+  kernels:
+    - arg_meta: null
+      kernel_name: cadence::impl::G3::hardtanh_out
 
 # custom ops
 - func: cadence::quantize_per_tensor.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)
diff --git a/backends/cadence/fusion_g3/operators/CMakeLists.txt b/backends/cadence/fusion_g3/operators/CMakeLists.txt
index 561323e045e..ec3220179a6 100644
--- a/backends/cadence/fusion_g3/operators/CMakeLists.txt
+++ b/backends/cadence/fusion_g3/operators/CMakeLists.txt
@@ -50,6 +50,7 @@ set(_aten_ops__srcs
     "${CMAKE_CURRENT_SOURCE_DIR}/op_lt.cpp"
     "${CMAKE_CURRENT_SOURCE_DIR}/op_where.cpp"
     "${CMAKE_CURRENT_SOURCE_DIR}/op_clamp.cpp"
+    "${CMAKE_CURRENT_SOURCE_DIR}/op_hardtanh.cpp"
     "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_bmm.cpp"
     "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_clone.cpp"
     "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_div.cpp"
diff --git a/backends/cadence/fusion_g3/operators/op_hardtanh.cpp b/backends/cadence/fusion_g3/operators/op_hardtanh.cpp
new file mode 100644
index 00000000000..09a2535c0dc
--- /dev/null
+++ b/backends/cadence/fusion_g3/operators/op_hardtanh.cpp
@@ -0,0 +1,116 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <xa_nnlib_kernels_api.h>
+
+#include <executorch/backends/cadence/fusion_g3/operators/operators.h>
+
+#include <executorch/backends/cadence/fusion_g3/operators/xt_macros.h>
+
+#include <executorch/kernels/portable/cpu/scalar_utils.h>
+#include <executorch/kernels/portable/cpu/util/functional_util.h>
+#include <executorch/kernels/portable/cpu/util/math_util.h>
+#include <executorch/runtime/kernel/kernel_includes.h>
+#include <executorch/runtime/platform/assert.h>
+
+using ::executorch::aten::Scalar;
+using ::executorch::aten::ScalarType;
+using ::executorch::aten::Tensor;
+using ::executorch::runtime::Error;
+using ::executorch::runtime::KernelRuntimeContext;
+using ::torch::executor::native::utils::extract_scalar;
+using ::torch::executor::native::utils::get_scalar_dtype;
+
+namespace cadence {
+namespace impl {
+namespace G3 {
+namespace native {
+
+Tensor& hardtanh_out(
+    KernelRuntimeContext& ctx,
+    const Tensor& in,
+    const Scalar& min,
+    const Scalar& max,
+    Tensor& out) {
+  (void)ctx;
+
+#ifdef OP_ARG_CHECK
+  // Resize for dynamic shape
+  ET_KERNEL_CHECK_MSG(
+      ctx,
+      executorch::runtime::resize_tensor(out, in.sizes()) == Error::Ok,
+      InvalidArgument,
+      out,
+      "Failed to resize output tensor.");
+
+  ET_KERNEL_CHECK(
+      ctx,
+      executorch::runtime::tensors_have_same_dim_order(in, out),
+      InvalidArgument,
+      out);
+#endif
+
+  ScalarType in_type = in.scalar_type();
+  ScalarType min_type = get_scalar_dtype(min);
+  ScalarType max_type = get_scalar_dtype(max);
+  ScalarType out_type = out.scalar_type();
+
+  ET_KERNEL_CHECK(ctx, in_type == out_type, InvalidArgument, out);
+
+  if (in_type == ScalarType::Float) {
+    const float* const inp1_data = in.const_data_ptr<float>();
+    float* const out_data = out.mutable_data_ptr<float>();
+    float min_val, max_val;
+    extract_scalar(min, &min_val);
+    extract_scalar(max, &max_val);
+
+    XT_KERNEL_CHECK(
+        ctx,
+        out,
+        xa_nn_elm_clamp_scalar_f32_f32,
+        out_data,
+        inp1_data,
+        min_val,
+        max_val,
+        out.numel());
+  } else {
+    ET_SWITCH_REALHBF16_TYPES(in_type, ctx, "hardtanh.out", CTYPE, [&]() {
+      CTYPE min_casted;
+      ET_SWITCH_SCALAR_OBJ_TYPES(
+          min_type, ctx, "hardtanh.out", CTYPE_MIN, [&]() {
+            CTYPE_MIN min_val;
+            extract_scalar(min, &min_val);
+            min_casted = static_cast<CTYPE>(min_val);
+          });
+
+      CTYPE max_casted;
+      ET_SWITCH_SCALAR_OBJ_TYPES(
+          max_type, ctx, "hardtanh.out", CTYPE_MAX, [&]() {
+            CTYPE_MAX max_val;
+            extract_scalar(max, &max_val);
+            max_casted = static_cast<CTYPE>(max_val);
+          });
+
+      torch::executor::apply_unary_map_fn(
+          [min_casted, max_casted](const CTYPE val_in) {
+            return torch::executor::native::utils::min_override(
+                torch::executor::native::utils::max_override(
+                    val_in, min_casted),
+                max_casted);
+          },
+          in.const_data_ptr<CTYPE>(),
+          out.mutable_data_ptr<CTYPE>(),
+          in.numel());
+    });
+  }
+  return out;
+}
+
+} // namespace native
+} // namespace G3
+} // namespace impl
+} // namespace cadence