diff --git a/extension/apple/ExecuTorch/Internal/ExecuTorchUtils.h b/extension/apple/ExecuTorch/Internal/ExecuTorchUtils.h index d514a101383..d6f4c036637 100644 --- a/extension/apple/ExecuTorch/Internal/ExecuTorchUtils.h +++ b/extension/apple/ExecuTorch/Internal/ExecuTorchUtils.h @@ -11,9 +11,11 @@ #ifdef __cplusplus #import +#import namespace executorch::extension::utils { using namespace aten; +using namespace runtime; /** * Deduces the scalar type for a given NSNumber based on its type encoding. @@ -41,6 +43,56 @@ static inline ScalarType deduceScalarType(NSNumber *number) { return ScalarType::Undefined; } +/** + * Converts the value held in the NSNumber to the specified C++ type T. + * + * @tparam T The target C++ numeric type. + * @param number The NSNumber instance to extract the value from. + * @return The value converted to type T. + */ +template +static inline T extractValue(NSNumber *number) { + ET_CHECK_MSG(!(isFloatingType(deduceScalarType(number)) && + isIntegralType(CppTypeToScalarType::value, true)), + "Cannot convert floating point to integral type"); + T value; + if constexpr (std::is_same_v) { + value = number.unsignedCharValue; + } else if constexpr (std::is_same_v) { + value = number.charValue; + } else if constexpr (std::is_same_v) { + value = number.shortValue; + } else if constexpr (std::is_same_v) { + value = number.intValue; + } else if constexpr (std::is_same_v) { + value = number.longLongValue; + } else if constexpr (std::is_same_v) { + value = number.floatValue; + } else if constexpr (std::is_same_v) { + value = number.doubleValue; + } else if constexpr (std::is_same_v) { + value = number.boolValue; + } else if constexpr (std::is_same_v) { + value = number.unsignedShortValue; + } else if constexpr (std::is_same_v) { + value = number.unsignedIntValue; + } else if constexpr (std::is_same_v) { + value = number.unsignedLongLongValue; + } else if constexpr (std::is_same_v) { + value = number.integerValue; + } else if constexpr (std::is_same_v) { + value = number.unsignedIntegerValue; + } else if constexpr (std::is_same_v || + std::is_same_v) { + value = T(number.floatValue); + } else { + static_assert(sizeof(T) == 0, "Unsupported type"); + } + ET_DCHECK_MSG(std::numeric_limits::lowest() <= value && value <= std::numeric_limits::max(), + "Value out of range"); + return value; +} + } // namespace executorch::extension::utils #endif // __cplusplus