diff --git a/kernels/portable/cpu/op_clamp.cpp b/kernels/portable/cpu/op_clamp.cpp index ec34fa9bd35..0c3894bcf47 100644 --- a/kernels/portable/cpu/op_clamp.cpp +++ b/kernels/portable/cpu/op_clamp.cpp @@ -218,30 +218,36 @@ Tensor& clamp_tensor_out( ET_SWITCH_REALHB_TYPES(min_type, ctx, name, CTYPE_MIN, [&]() { ET_SWITCH_REALHB_TYPES(max_type, ctx, name, CTYPE_MAX, [&]() { ET_SWITCH_REALHB_TYPES(out_type, ctx, name, CTYPE_OUT, [&]() { - apply_ternary_elementwise_fn< - CTYPE_IN, - CTYPE_MIN, - CTYPE_MAX, - CTYPE_OUT>( + apply_ternary_elementwise_fn( [has_min, has_max]( - const CTYPE_IN val_in, - const CTYPE_MIN val_min, - const CTYPE_MAX val_max) { - CTYPE_OUT val_out = static_cast(val_in); + const CTYPE_OUT val_in, + const CTYPE_OUT val_min, + const CTYPE_OUT val_max) { + CTYPE_OUT val_out = val_in; if (has_min) { - val_out = utils::max_override( - val_out, static_cast(val_min)); + val_out = utils::max_override(val_out, val_min); } if (has_max) { - val_out = utils::min_override( - val_out, static_cast(val_max)); + val_out = utils::min_override(val_out, val_max); } return val_out; }, in, min, max, - out); + out, + [](const void* inPtr) { + return static_cast( + *reinterpret_cast(inPtr)); + }, + [](const void* minPtr) { + return static_cast( + *reinterpret_cast(minPtr)); + }, + [](const void* maxPtr) { + return static_cast( + *reinterpret_cast(maxPtr)); + }); }); }); }); diff --git a/kernels/portable/cpu/op_where.cpp b/kernels/portable/cpu/op_where.cpp index a7736247597..9c3f401c560 100644 --- a/kernels/portable/cpu/op_where.cpp +++ b/kernels/portable/cpu/op_where.cpp @@ -48,16 +48,26 @@ Tensor& where_out( ET_SWITCH_REALHBBF16_TYPES(b_type, ctx, name, CTYPE_B, [&]() { using CTYPE_OUT = typename torch::executor::promote_types::type; - apply_ternary_elementwise_fn( - [](const CTYPE_A val_a, const CTYPE_B val_b, const uint8_t val_c) { - CTYPE_OUT a_casted = static_cast(val_a); - CTYPE_OUT b_casted = static_cast(val_b); - return val_c ? a_casted : b_casted; - }, + apply_ternary_elementwise_fn( + [](const CTYPE_OUT val_a, + const CTYPE_OUT val_b, + const CTYPE_OUT val_c) { return val_c ? val_a : val_b; }, a, b, cond, - out); + out, + [](const void* aPtr) { + return static_cast( + *reinterpret_cast(aPtr)); + }, + [](const void* bPtr) { + return static_cast( + *reinterpret_cast(bPtr)); + }, + [](const void* cPtr) { + return static_cast( + *reinterpret_cast(cPtr)); + }); }); }); diff --git a/kernels/portable/cpu/util/broadcast_util.h b/kernels/portable/cpu/util/broadcast_util.h index 92d35f322fb..125f1efe3cd 100644 --- a/kernels/portable/cpu/util/broadcast_util.h +++ b/kernels/portable/cpu/util/broadcast_util.h @@ -313,28 +313,42 @@ inline void apply_binary_elementwise_fn( * Useful for ternary elementwise operators. For each element of the inputs, * perform a computation and write to the corresponding element of the output. * Tensor broadcasting is applied wherever it is required. + * + * In order to mitigate build time cost (straightforwardly |CTYPE_A| * + * |CTYPE_B| * |CTYPE_C| * |CTYPE_OUT|), all arguments to compute_fun + * are passed as CTYPE_OUT and we require conversion functions from + * each input type to the output type be provided. Each conversion + * function must take a void* pointing to an element of the + * corresponding tensor, load that element, and convert it to + * CTYPE_OUT. */ template < - typename CTYPE_A, - typename CTYPE_B, - typename CTYPE_C, typename CTYPE_OUT, - typename Op> + typename Op, + typename AToOutFunc, + typename BToOutFunc, + typename CToOutFunc> inline void apply_ternary_elementwise_fn( const Op& compute_fun, const Tensor& a, const Tensor& b, const Tensor& c, - const Tensor& out) { + const Tensor& out, + AToOutFunc aToOut, + BToOutFunc bToOut, + CToOutFunc cToOut) { const bool a_is_broadcasted = !out.sizes().equals(a.sizes()); const bool b_is_broadcasted = !out.sizes().equals(b.sizes()); const bool c_is_broadcasted = !out.sizes().equals(c.sizes()); const bool any_is_broadcasted = (a_is_broadcasted || b_is_broadcasted || c_is_broadcasted); - const CTYPE_A* const data_a = a.const_data_ptr(); - const CTYPE_B* const data_b = b.const_data_ptr(); - const CTYPE_C* const data_c = c.const_data_ptr(); + const char* const data_a = reinterpret_cast(a.const_data_ptr()); + const char* const data_b = reinterpret_cast(b.const_data_ptr()); + const char* const data_c = reinterpret_cast(c.const_data_ptr()); + const auto a_element_size = a.element_size(); + const auto b_element_size = b.element_size(); + const auto c_element_size = c.element_size(); CTYPE_OUT* const data_out = out.mutable_data_ptr(); for (size_t i = 0; i < out.numel(); ++i) { @@ -358,7 +372,9 @@ inline void apply_ternary_elementwise_fn( } data_out[i] = compute_fun( - data_a[a_linear_index], data_b[b_linear_index], data_c[c_linear_index]); + aToOut(&data_a[a_linear_index * a_element_size]), + bToOut(&data_b[b_linear_index * b_element_size]), + cToOut(&data_c[c_linear_index * c_element_size])); } }