diff --git a/kernels/portable/cpu/util/dtype_util.cpp b/kernels/portable/cpu/util/dtype_util.cpp index 299910da746..d240b9f83bc 100644 --- a/kernels/portable/cpu/util/dtype_util.cpp +++ b/kernels/portable/cpu/util/dtype_util.cpp @@ -28,17 +28,14 @@ bool check_tensor_dtype( case SupportedTensorDtypes::INTB: return executorch::runtime::tensor_is_integral_type(t, true); case SupportedTensorDtypes::BOOL_OR_BYTE: - return ( - executorch::runtime::tensor_is_type(t, ScalarType::Bool) || - executorch::runtime::tensor_is_type(t, ScalarType::Byte)); + return (executorch::runtime::tensor_is_type( + t, ScalarType::Bool, ScalarType::Byte)); case SupportedTensorDtypes::SAME_AS_COMPUTE: return executorch::runtime::tensor_is_type(t, compute_type); case SupportedTensorDtypes::SAME_AS_COMMON: { if (compute_type == ScalarType::Float) { - return ( - executorch::runtime::tensor_is_type(t, ScalarType::Float) || - executorch::runtime::tensor_is_type(t, ScalarType::Half) || - executorch::runtime::tensor_is_type(t, ScalarType::BFloat16)); + return (executorch::runtime::tensor_is_type( + t, ScalarType::Float, ScalarType::Half, ScalarType::BFloat16)); } else { return executorch::runtime::tensor_is_type(t, compute_type); } diff --git a/runtime/core/exec_aten/util/tensor_util.h b/runtime/core/exec_aten/util/tensor_util.h index d577251f4a4..e238ec301ee 100644 --- a/runtime/core/exec_aten/util/tensor_util.h +++ b/runtime/core/exec_aten/util/tensor_util.h @@ -484,6 +484,37 @@ inline bool tensor_is_type( return true; } +inline bool tensor_is_type( + executorch::aten::Tensor t, + executorch::aten::ScalarType dtype, + executorch::aten::ScalarType dtype2) { + ET_LOG_MSG_AND_RETURN_IF_FALSE( + t.scalar_type() == dtype || t.scalar_type() == dtype2, + "Expected to find %s or %s type, but tensor has type %s", + torch::executor::toString(dtype), + torch::executor::toString(dtype2), + torch::executor::toString(t.scalar_type())); + + return true; +} + +inline bool tensor_is_type( + executorch::aten::Tensor t, + executorch::aten::ScalarType dtype, + executorch::aten::ScalarType dtype2, + executorch::aten::ScalarType dtype3) { + ET_LOG_MSG_AND_RETURN_IF_FALSE( + t.scalar_type() == dtype || t.scalar_type() == dtype2 || + t.scalar_type() == dtype3, + "Expected to find %s, %s, or %s type, but tensor has type %s", + torch::executor::toString(dtype), + torch::executor::toString(dtype2), + torch::executor::toString(dtype3), + torch::executor::toString(t.scalar_type())); + + return true; +} + inline bool tensor_is_integral_type( executorch::aten::Tensor t, bool includeBool = false) {