diff --git a/backends/cadence/fusion_g3/operators/op_clamp.cpp b/backends/cadence/fusion_g3/operators/op_clamp.cpp index 9f3f72a674f..92fb97b1260 100644 --- a/backends/cadence/fusion_g3/operators/op_clamp.cpp +++ b/backends/cadence/fusion_g3/operators/op_clamp.cpp @@ -45,6 +45,7 @@ bool is_out_of_bounds(CTYPE_VAL val) { } ET_NODISCARD bool check_bounds( + KernelRuntimeContext& ctx, const Scalar& val_scalar, const ScalarType& val_type, const ScalarType& out_type, @@ -107,14 +108,14 @@ Tensor& clamp_out( if (has_min) { ET_KERNEL_CHECK( ctx, - check_bounds(min_opt.value(), min_type, out_type, "minimum"), + check_bounds(ctx, min_opt.value(), min_type, out_type, "minimum"), InvalidArgument, out); } if (has_max) { ET_KERNEL_CHECK( ctx, - check_bounds(max_opt.value(), max_type, out_type, "maximum"), + check_bounds(ctx, max_opt.value(), max_type, out_type, "maximum"), InvalidArgument, out); } diff --git a/extension/apple/ExecuTorch/Exported/ExecuTorchTensor.mm b/extension/apple/ExecuTorch/Exported/ExecuTorchTensor.mm index 3cf06207b45..3a2b640b7d7 100644 --- a/extension/apple/ExecuTorch/Exported/ExecuTorchTensor.mm +++ b/extension/apple/ExecuTorch/Exported/ExecuTorchTensor.mm @@ -265,9 +265,15 @@ - (NSString *)description { auto const count = _tensor->numel(); os << "\n count: " << count << ","; os << "\n scalars: ["; + // Create a minimal context for error handling in ET_SWITCH + struct { + [[noreturn]] void fail(torch::executor::Error /* error */) { + ET_CHECK_MSG(false, "Unsupported dtype in description"); + } + } ctx; ET_SWITCH_REALHBBF16_TYPES( static_cast(_tensor->scalar_type()), - nullptr, + ctx, "description", CTYPE, [&] { @@ -488,9 +494,15 @@ - (instancetype)initWithScalars:(NSArray *)scalars "Number of scalars does not match the shape"); std::vector data; data.resize(count * ExecuTorchSizeOfDataType(dataType)); + // Create a minimal context for error handling in ET_SWITCH + struct { + [[noreturn]] void fail(torch::executor::Error /* error */) { + ET_CHECK_MSG(false, "Unsupported dtype in initWithScalars"); + } + } ctx; for (NSUInteger index = 0; index < count; ++index) { ET_SWITCH_REALHBBF16_AND_UINT_TYPES( - static_cast(dataType), nil, "initWithScalars", CTYPE, [&] { + static_cast(dataType), ctx, "initWithScalars", CTYPE, [&] { reinterpret_cast(data.data())[index] = utils::toType(scalars[index]); } ); @@ -801,8 +813,14 @@ + (instancetype)fullTensorWithShape:(NSArray *)shape dataType:(ExecuTorchDataType)dataType shapeDynamism:(ExecuTorchShapeDynamism)shapeDynamism { Scalar fillValue; + // Create a minimal context for error handling in ET_SWITCH + struct { + [[noreturn]] void fail(torch::executor::Error /* error */) { + ET_CHECK_MSG(false, "Unsupported dtype in fullTensor"); + } + } ctx; ET_SWITCH_REALHBBF16_AND_UINT_TYPES( - static_cast(dataType), nil, "fullTensor", CTYPE, [&] { + static_cast(dataType), ctx, "fullTensor", CTYPE, [&] { fillValue = utils::toType(scalar); } ); diff --git a/extension/llm/runner/text_decoder_runner.h b/extension/llm/runner/text_decoder_runner.h index f583ed647a6..2f9e9a67331 100644 --- a/extension/llm/runner/text_decoder_runner.h +++ b/extension/llm/runner/text_decoder_runner.h @@ -68,12 +68,20 @@ class ET_EXPERIMENTAL TextDecoderRunner { const executorch::aten::Tensor& logits_tensor, const float temperature = 0.0f) { int32_t result = 0; + + // Create a minimal context for error handling in ET_SWITCH + struct { + [[noreturn]] void fail(torch::executor::Error /* error */) { + ET_CHECK_MSG(false, "Unsupported dtype in logits_to_token"); + } + } ctx; + 
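// Context for the hunks above: ET_SWITCH_* now expects a CONTEXT argument that
// provides fail(torch::executor::Error), so call sites that previously passed
// nullptr/nil and have no KernelRuntimeContext supply an anonymous local struct
// whose fail() simply aborts, preserving the old ET_CHECK_MSG behaviour. Below
// is a minimal standalone sketch of that pattern; the function name and the
// include paths are assumptions for illustration, not part of this patch.

#include <executorch/runtime/core/error.h>
#include <executorch/runtime/core/exec_aten/exec_aten.h>
#include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
#include <executorch/runtime/platform/assert.h>

float load_first_element(const executorch::aten::Tensor& t) {
  // Stand-in for a KernelRuntimeContext where none is available: an unhandled
  // dtype still aborts, exactly like the previous nullptr-based call sites.
  struct {
    [[noreturn]] void fail(torch::executor::Error /* error */) {
      ET_CHECK_MSG(false, "Unsupported dtype in load_first_element");
    }
  } ctx;
  float value = 0.0f;
  ET_SWITCH_REALHBBF16_TYPES(
      t.scalar_type(), ctx, "load_first_element", CTYPE, [&] {
        value = static_cast<float>(t.const_data_ptr<CTYPE>()[0]);
      });
  return value;
}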
ET_SWITCH_THREE_TYPES( Float, Half, BFloat16, logits_tensor.scalar_type(), - unused, + ctx, "logits_to_token", CTYPE, [&]() { diff --git a/extension/tensor/tensor_ptr.h b/extension/tensor/tensor_ptr.h index 3259bdbaf2b..59690de9f26 100644 --- a/extension/tensor/tensor_ptr.h +++ b/extension/tensor/tensor_ptr.h @@ -111,7 +111,15 @@ inline TensorPtr make_tensor_ptr( runtime::canCast(deduced_type, type), "Cannot cast deduced type to specified type."); std::vector casted_data(data.size() * runtime::elementSize(type)); - ET_SWITCH_REALHBBF16_TYPES(type, nullptr, "make_tensor_ptr", CTYPE, [&] { + + // Create a minimal context for error handling in ET_SWITCH + struct { + [[noreturn]] void fail(torch::executor::Error /* error */) { + ET_CHECK_MSG(false, "Unsupported dtype in make_tensor_ptr"); + } + } ctx; + + ET_SWITCH_REALHBBF16_TYPES(type, ctx, "make_tensor_ptr", CTYPE, [&] { std::transform( data.begin(), data.end(), diff --git a/extension/tensor/tensor_ptr_maker.cpp b/extension/tensor/tensor_ptr_maker.cpp index 8e7c908bf43..511b0ebe582 100644 --- a/extension/tensor/tensor_ptr_maker.cpp +++ b/extension/tensor/tensor_ptr_maker.cpp @@ -89,7 +89,14 @@ TensorPtr random_strided( empty_strided(std::move(sizes), std::move(strides), type, dynamism); std::default_random_engine gen{std::random_device{}()}; - ET_SWITCH_REALHBBF16_TYPES(type, nullptr, "random_strided", CTYPE, [&] { + // Create a minimal context for error handling in ET_SWITCH + struct { + [[noreturn]] void fail(torch::executor::Error /* error */) { + ET_CHECK_MSG(false, "Unsupported dtype in random_strided"); + } + } ctx; + + ET_SWITCH_REALHBBF16_TYPES(type, ctx, "random_strided", CTYPE, [&] { std::generate_n(tensor->mutable_data_ptr(), tensor->numel(), [&]() { return static_cast(distribution(gen)); }); @@ -124,7 +131,14 @@ TensorPtr full_strided( executorch::aten::TensorShapeDynamism dynamism) { auto tensor = empty_strided(std::move(sizes), std::move(strides), type, dynamism); - ET_SWITCH_REALHBBF16_TYPES(type, nullptr, "full_strided", CTYPE, [&] { + // Create a minimal context for error handling in ET_SWITCH + struct { + [[noreturn]] void fail(torch::executor::Error /* error */) { + ET_CHECK_MSG(false, "Unsupported data type in full_strided"); + } + } ctx; + + ET_SWITCH_REALHBBF16_TYPES(type, ctx, "full_strided", CTYPE, [&] { CTYPE value; ET_EXTRACT_SCALAR(fill_value, value); std::fill( diff --git a/kernels/optimized/cpu/op_add_sub_impl.h b/kernels/optimized/cpu/op_add_sub_impl.h index 3fc22d88a63..37761b44c9b 100644 --- a/kernels/optimized/cpu/op_add_sub_impl.h +++ b/kernels/optimized/cpu/op_add_sub_impl.h @@ -144,13 +144,13 @@ Tensor& opt_add_sub_out_impl( } else if (selected_optimized_path != ElementwiseOptimizedPath::kNone) { // Cannot apply the trick of -alpha here because alpha is Scalar without // support for - operator. At least not right now. 
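// Why the `return`s disappear in the hunks that follow: the default branch of
// ET_INTERNAL_SWITCH (see the scalar_type_util.h hunk at the end of this diff)
// now records the failure on CONTEXT and falls through instead of aborting, so
// a lambda dispatched through ET_SWITCH can no longer promise a value on every
// path. The lambdas are made explicitly void and errors propagate through ctx.
// Illustrative fragment; out_type, ctx, alpha, and op_name are the surrounding
// kernel's variables, as in op_add_sub_impl.h:

ET_SWITCH_REALB_TYPES(out_type, ctx, op_name, CTYPE, [&]() -> void {
  CTYPE alpha_val;
  ET_KERNEL_CHECK_MSG(
      ctx,
      torch::executor::native::utils::extract_scalar(alpha, &alpha_val),
      InvalidArgument,
      , // empty retval: the failure path is just `return;` in a void lambda
      "Failed to extract scalar alpha.");
  // ... typed elementwise work; nothing is returned from the lambda ...
});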
- ET_SWITCH_REALB_TYPES(out_type, ctx, op_name, CTYPE, [&]() { + ET_SWITCH_REALB_TYPES(out_type, ctx, op_name, CTYPE, [&]() -> void { CTYPE alpha_val; ET_KERNEL_CHECK_MSG( ctx, torch::executor::native::utils::extract_scalar(alpha, &alpha_val), InvalidArgument, - out, + , "Failed to extract scalar alpha."); using Vec = at::vec::Vectorized; Vec alpha_val_vec(alpha_val); @@ -164,13 +164,13 @@ Tensor& opt_add_sub_out_impl( auto add_lambda = [&alpha_val_vec](auto x, auto y) { return y - alpha_val_vec * x; }; - return torch::executor::handle_broadcast_elementwise( + torch::executor::handle_broadcast_elementwise( ctx, add_lambda, a, b, out, selected_optimized_path, alpha); } else { auto add_lambda = [&alpha_val_vec](auto x, auto y) { return x - alpha_val_vec * y; }; - return torch::executor::handle_broadcast_elementwise( + torch::executor::handle_broadcast_elementwise( ctx, add_lambda, a, b, out, selected_optimized_path, alpha); } } else { @@ -191,13 +191,13 @@ Tensor& opt_add_sub_out_impl( auto add_lambda = [&alpha_val_vec](auto x, auto y) { return y + alpha_val_vec * x; }; - return torch::executor::handle_broadcast_elementwise( + torch::executor::handle_broadcast_elementwise( ctx, add_lambda, a, b, out, selected_optimized_path, alpha); } else { auto add_lambda = [&alpha_val_vec](auto x, auto y) { return x + alpha_val_vec * y; }; - return torch::executor::handle_broadcast_elementwise( + torch::executor::handle_broadcast_elementwise( ctx, add_lambda, a, b, out, selected_optimized_path, alpha); } } diff --git a/kernels/optimized/cpu/op_div.cpp b/kernels/optimized/cpu/op_div.cpp index e2baf413989..7af2b4b4695 100644 --- a/kernels/optimized/cpu/op_div.cpp +++ b/kernels/optimized/cpu/op_div.cpp @@ -130,11 +130,11 @@ Tensor& opt_div_out( selected_optimized_path == ElementwiseOptimizedPath::kBroadcastNdByNdReverseArguments) { auto div_lambda = [](auto x, auto y) { return y / x; }; - return torch::executor::handle_broadcast_elementwise( + torch::executor::handle_broadcast_elementwise( ctx, div_lambda, a, b, out, selected_optimized_path); } else { auto div_lambda = [](auto x, auto y) { return x / y; }; - return torch::executor::handle_broadcast_elementwise( + torch::executor::handle_broadcast_elementwise( ctx, div_lambda, a, b, out, selected_optimized_path); } }); diff --git a/kernels/optimized/cpu/op_le.cpp b/kernels/optimized/cpu/op_le.cpp index 8e56e1ca4fc..51fca9b0063 100644 --- a/kernels/optimized/cpu/op_le.cpp +++ b/kernels/optimized/cpu/op_le.cpp @@ -57,7 +57,7 @@ Tensor& opt_le_tensor_out( // Handle optimized broadcast cases ET_SWITCH_REALB_TYPES(out_type, ctx, "le.Tensor_out", CTYPE, [&]() { auto le_lambda = [](auto x, auto y) { return x.le(y); }; - return torch::executor::handle_broadcast_elementwise( + torch::executor::handle_broadcast_elementwise( ctx, le_lambda, a, b, out, selected_optimized_path); }); } else { diff --git a/kernels/optimized/cpu/op_mul.cpp b/kernels/optimized/cpu/op_mul.cpp index 8783812ede1..0d132ab1e03 100644 --- a/kernels/optimized/cpu/op_mul.cpp +++ b/kernels/optimized/cpu/op_mul.cpp @@ -148,13 +148,13 @@ Tensor& opt_mul_out( ET_SWITCH_COMPLEXH_TYPES(out_type, ctx, "mul.out", CTYPE, [&]() { auto mul_lambda = [](auto x, auto y) { return x * y; }; - return torch::executor::handle_broadcast_elementwise( + torch::executor::handle_broadcast_elementwise( ctx, mul_lambda, a, b, out, selected_optimized_path); }); } else { ET_SWITCH_REALB_TYPES(out_type, ctx, "mul.out", CTYPE, [&]() { auto mul_lambda = [](auto x, auto y) { return x * y; }; - return 
torch::executor::handle_broadcast_elementwise( + torch::executor::handle_broadcast_elementwise( ctx, mul_lambda, a, b, out, selected_optimized_path); }); } diff --git a/kernels/portable/cpu/op_clamp.cpp b/kernels/portable/cpu/op_clamp.cpp index 31d4b8fdf56..b3aa41cda85 100644 --- a/kernels/portable/cpu/op_clamp.cpp +++ b/kernels/portable/cpu/op_clamp.cpp @@ -34,6 +34,7 @@ bool is_out_of_bounds(CTYPE_CAST val_cast) { } ET_NODISCARD bool check_bounds( + KernelRuntimeContext& ctx, const Scalar& val_scalar, const torch::executor::native::ScalarType& val_type, const torch::executor::native::ScalarType& out_type, @@ -107,14 +108,14 @@ Tensor& clamp_out( if (has_min) { ET_KERNEL_CHECK( ctx, - check_bounds(min_opt.value(), min_type, out_type, "minimum"), + check_bounds(ctx, min_opt.value(), min_type, out_type, "minimum"), InvalidArgument, out); } if (has_max) { ET_KERNEL_CHECK( ctx, - check_bounds(max_opt.value(), max_type, out_type, "maximum"), + check_bounds(ctx, max_opt.value(), max_type, out_type, "maximum"), InvalidArgument, out); } diff --git a/kernels/portable/cpu/op_convolution.cpp b/kernels/portable/cpu/op_convolution.cpp index 68991a09b33..f598ac99444 100644 --- a/kernels/portable/cpu/op_convolution.cpp +++ b/kernels/portable/cpu/op_convolution.cpp @@ -415,7 +415,7 @@ Tensor& convolution_out( ET_SWITCH_REALH_TYPES(in.scalar_type(), ctx, name, CTYPE, [&]() { const auto load_bias = bias.has_value() ? utils::internal::get_load_to_compute_fn( - bias.value(), utils::SupportedTensorDtypes::REALHBF16) + ctx, bias.value(), utils::SupportedTensorDtypes::REALHBF16) : nullptr; convolution_wrapper( in, diff --git a/kernels/portable/cpu/op_cumsum.cpp b/kernels/portable/cpu/op_cumsum.cpp index 1f4aa5c458e..3a518d30715 100644 --- a/kernels/portable/cpu/op_cumsum.cpp +++ b/kernels/portable/cpu/op_cumsum.cpp @@ -111,10 +111,10 @@ Tensor& cumsum_out( // @lint-ignore CLANGTIDY facebook-hte-CArray static constexpr const char op_name[] = "cumsum.out"; - ET_SWITCH_REALHBBF16_TYPES(out.scalar_type(), ctx, op_name, CTYPE_OUT, [&] { + ET_SWITCH_REALHBBF16_TYPES(out.scalar_type(), ctx, op_name, CTYPE_OUT, [&]() { const auto load_self = utils::internal::get_load_to_compute_fn( - self, utils::SupportedTensorDtypes::REALHBBF16); + ctx, self, utils::SupportedTensorDtypes::REALHBBF16); cumsum_tensors(self, load_self, dim, out); }); diff --git a/kernels/portable/cpu/op_fill.cpp b/kernels/portable/cpu/op_fill.cpp index 6c7032a3b41..3bbdb66646f 100644 --- a/kernels/portable/cpu/op_fill.cpp +++ b/kernels/portable/cpu/op_fill.cpp @@ -90,7 +90,7 @@ Tensor& fill_tensor_out( static constexpr const char op_name[] = "fill.Tensor_out"; ET_SWITCH_REALHBBF16_TYPES(a_type, ctx, op_name, CTYPE_A, [&] { - CTYPE_A b_casted; + CTYPE_A b_casted{}; ET_SWITCH_REALHBBF16_TYPES(b_type, ctx, op_name, CTYPE_B, [&] { CTYPE_B b_val; ET_EXTRACT_SCALAR_TENSOR(b, b_val); diff --git a/kernels/portable/cpu/op_index_put.cpp b/kernels/portable/cpu/op_index_put.cpp index 76bd7a48922..812d3e8fab3 100644 --- a/kernels/portable/cpu/op_index_put.cpp +++ b/kernels/portable/cpu/op_index_put.cpp @@ -160,6 +160,7 @@ Tensor& index_put_out( namespace { bool check_special_case_in_place_args( + KernelRuntimeContext& ctx, Tensor& in, TensorOptList indices, const Tensor& values, @@ -285,7 +286,8 @@ Tensor& index_put_( size_t dim = 0; ET_KERNEL_CHECK( ctx, - check_special_case_in_place_args(in, indices, values, accumulate, &dim), + check_special_case_in_place_args( + ctx, in, indices, values, accumulate, &dim), InvalidArgument, in); diff --git 
a/kernels/portable/cpu/op_scatter.cpp b/kernels/portable/cpu/op_scatter.cpp index 965afbb4b66..58341cefb1e 100644 --- a/kernels/portable/cpu/op_scatter.cpp +++ b/kernels/portable/cpu/op_scatter.cpp @@ -104,25 +104,20 @@ void scatter_value_helper( } // namespace Tensor& scatter_src_out( - KernelRuntimeContext& context, + KernelRuntimeContext& ctx, const Tensor& in, int64_t dim, const Tensor& index, const Tensor& src, Tensor& out) { - (void)context; - ET_KERNEL_CHECK( - context, + ctx, check_scatter_src_args(in, dim, index, src, out), InvalidArgument, out); ET_KERNEL_CHECK( - context, - resize_tensor(out, in.sizes()) == Error::Ok, - InvalidArgument, - out); + ctx, resize_tensor(out, in.sizes()) == Error::Ok, InvalidArgument, out); constexpr auto name = "scatter.src_out"; diff --git a/kernels/portable/cpu/op_scatter_add.cpp b/kernels/portable/cpu/op_scatter_add.cpp index b83a56c2e01..22fb3d161a8 100644 --- a/kernels/portable/cpu/op_scatter_add.cpp +++ b/kernels/portable/cpu/op_scatter_add.cpp @@ -52,38 +52,30 @@ void scatter_add_helper( } // namespace Tensor& scatter_add_out( - KernelRuntimeContext& context, + KernelRuntimeContext& ctx, const Tensor& self, int64_t dim, const Tensor& index, const Tensor& src, Tensor& out) { - (void)context; - ET_KERNEL_CHECK( - context, + ctx, check_scatter_add_args(self, dim, index, src, out), InvalidArgument, out); ET_KERNEL_CHECK( - context, - tensors_have_same_dim_order(self, src, out), - InvalidArgument, - out); + ctx, tensors_have_same_dim_order(self, src, out), InvalidArgument, out); ET_KERNEL_CHECK( - context, tensor_is_default_dim_order(index), InvalidArgument, out); + ctx, tensor_is_default_dim_order(index), InvalidArgument, out); if (dim < 0) { dim += nonzero_dim(self); } ET_KERNEL_CHECK( - context, - resize_tensor(out, self.sizes()) == Error::Ok, - InvalidArgument, - out); + ctx, resize_tensor(out, self.sizes()) == Error::Ok, InvalidArgument, out); ScalarType self_type = self.scalar_type(); diff --git a/kernels/portable/cpu/util/dtype_util.h b/kernels/portable/cpu/util/dtype_util.h index 15732219c8f..98cf0a573f5 100644 --- a/kernels/portable/cpu/util/dtype_util.h +++ b/kernels/portable/cpu/util/dtype_util.h @@ -31,10 +31,11 @@ using load_to_compute_fn = CTYPE_COMPUTE (*)(const void*); template load_to_compute_fn get_load_to_compute_fn_realhbbf16( + KernelRuntimeContext& context, const Tensor& t) { CTYPE_COMPUTE (*result)(const void*) = nullptr; ET_SWITCH_REALHBBF16_TYPES( - t.scalar_type(), unused, op_name, TENSOR_CTYPE, [&]() { + t.scalar_type(), context, op_name, TENSOR_CTYPE, [&]() -> void { result = internal::load_and_convert; }); return result; @@ -42,10 +43,11 @@ load_to_compute_fn get_load_to_compute_fn_realhbbf16( template load_to_compute_fn get_load_to_compute_fn_realhbf16( + KernelRuntimeContext& context, const Tensor& t) { CTYPE_COMPUTE (*result)(const void*) = nullptr; ET_SWITCH_REALHBF16_TYPES( - t.scalar_type(), unused, op_name, TENSOR_CTYPE, [&]() { + t.scalar_type(), context, op_name, TENSOR_CTYPE, [&]() -> void { result = internal::load_and_convert; }); return result; @@ -53,41 +55,59 @@ load_to_compute_fn get_load_to_compute_fn_realhbf16( template load_to_compute_fn get_load_to_compute_fn_floathbf16( + KernelRuntimeContext& context, const Tensor& t) { CTYPE_COMPUTE (*result)(const void*) = nullptr; ET_SWITCH_FLOATHBF16_TYPES( - t.scalar_type(), unused, op_name, TENSOR_CTYPE, [&]() { + t.scalar_type(), context, op_name, TENSOR_CTYPE, [&]() -> void { result = internal::load_and_convert; }); return result; } template 
-load_to_compute_fn get_load_to_compute_fn_intb(const Tensor& t) { +load_to_compute_fn get_load_to_compute_fn_intb( + KernelRuntimeContext& context, + const Tensor& t) { CTYPE_COMPUTE (*result)(const void*) = nullptr; ET_SWITCH_INT_TYPES_AND( - Bool, t.scalar_type(), unused, op_name, TENSOR_CTYPE, [&]() { + Bool, t.scalar_type(), context, op_name, TENSOR_CTYPE, [&]() -> void { result = internal::load_and_convert; }); return result; } template -load_to_compute_fn get_load_to_compute_fn_bool(const Tensor& t) { - ET_CHECK_MSG( - t.scalar_type() == ScalarType::Bool, - "Unhandled dtype %s for %s", - ::executorch::runtime::toString(t.scalar_type()), - op_name); - return internal::load_and_convert; +load_to_compute_fn get_load_to_compute_fn_bool( + KernelRuntimeContext& context, + const Tensor& t) { + CTYPE_COMPUTE (*result)(const void*) = nullptr; + if (t.scalar_type() != ScalarType::Bool) { + context.fail(torch::executor::Error::InvalidArgument); + ET_LOG( + Error, + "Unhandled dtype %s for %s", + ::executorch::runtime::toString(t.scalar_type()), + op_name); + } else { + result = internal::load_and_convert; + } + return result; } template load_to_compute_fn get_load_to_compute_fn_bool_or_byte( + KernelRuntimeContext& context, const Tensor& t) { CTYPE_COMPUTE (*result)(const void*) = nullptr; ET_SWITCH_TWO_TYPES( - Bool, Byte, t.scalar_type(), unused, op_name, TENSOR_CTYPE, [&]() { + Bool, + Byte, + t.scalar_type(), + context, + op_name, + TENSOR_CTYPE, + [&]() -> void { result = internal::load_and_convert; }); return result; @@ -95,14 +115,21 @@ load_to_compute_fn get_load_to_compute_fn_bool_or_byte( template load_to_compute_fn get_load_to_compute_fn_same_as_compute( + KernelRuntimeContext& context, const Tensor& t) { + CTYPE_COMPUTE (*result)(const void*) = nullptr; constexpr auto common_scalar_type = CppTypeToScalarType::value; - ET_CHECK_MSG( - t.scalar_type() == common_scalar_type, - "Unhandled dtype %s for %s", - ::executorch::runtime::toString(common_scalar_type), - op_name); - return internal::load_and_convert; + if (t.scalar_type() != common_scalar_type) { + context.fail(torch::executor::Error::InvalidArgument); + ET_LOG( + Error, + "Unhandled dtype %s for %s", + ::executorch::runtime::toString(t.scalar_type()), + op_name); + } else { + result = internal::load_and_convert; + } + return result; } template < @@ -110,12 +137,18 @@ template < const char* op_name, std::enable_if_t, bool> = true> load_to_compute_fn get_load_to_compute_fn_same_as_common( + KernelRuntimeContext& context, const Tensor& t) { CTYPE_COMPUTE (*result)(const void*) = nullptr; ET_SWITCH_THREE_TYPES( - Float, Half, BFloat16, t.scalar_type(), unused, op_name, T, [&]() { - result = internal::load_and_convert; - }); + Float, + Half, + BFloat16, + t.scalar_type(), + context, + op_name, + T, + [&]() -> void { result = internal::load_and_convert; }); return result; } @@ -124,8 +157,10 @@ template < const char* op_name, std::enable_if_t, bool> = true> load_to_compute_fn get_load_to_compute_fn_same_as_common( + KernelRuntimeContext& context, const Tensor& t) { - return get_load_to_compute_fn_same_as_compute(t); + return get_load_to_compute_fn_same_as_compute( + context, t); } template @@ -133,10 +168,12 @@ using store_compute_to_tensor_fn = void (*)(CTYPE_COMPUTE, void*); template store_compute_to_tensor_fn -get_store_compute_to_tensor_fn_realhbbf16(const Tensor& t) { +get_store_compute_to_tensor_fn_realhbbf16( + KernelRuntimeContext& context, + const Tensor& t) { void (*result)(CTYPE_COMPUTE, void*) = nullptr; 
ET_SWITCH_REALHBBF16_TYPES( - t.scalar_type(), unused, op_name, TENSOR_CTYPE, [&]() { + t.scalar_type(), context, op_name, TENSOR_CTYPE, [&]() -> void { result = internal::convert_and_store; }); return result; @@ -144,10 +181,12 @@ get_store_compute_to_tensor_fn_realhbbf16(const Tensor& t) { template store_compute_to_tensor_fn -get_store_compute_to_tensor_fn_realhbf16(const Tensor& t) { +get_store_compute_to_tensor_fn_realhbf16( + KernelRuntimeContext& context, + const Tensor& t) { void (*result)(CTYPE_COMPUTE, void*) = nullptr; ET_SWITCH_REALHBF16_TYPES( - t.scalar_type(), unused, op_name, TENSOR_CTYPE, [&]() { + t.scalar_type(), context, op_name, TENSOR_CTYPE, [&]() -> void { result = internal::convert_and_store; }); return result; @@ -155,10 +194,12 @@ get_store_compute_to_tensor_fn_realhbf16(const Tensor& t) { template store_compute_to_tensor_fn -get_store_compute_to_tensor_fn_floathbf16(const Tensor& t) { +get_store_compute_to_tensor_fn_floathbf16( + KernelRuntimeContext& context, + const Tensor& t) { void (*result)(CTYPE_COMPUTE, void*) = nullptr; ET_SWITCH_FLOATHBF16_TYPES( - t.scalar_type(), unused, op_name, TENSOR_CTYPE, [&]() { + t.scalar_type(), context, op_name, TENSOR_CTYPE, [&]() -> void { result = internal::convert_and_store; }); return result; @@ -166,10 +207,11 @@ get_store_compute_to_tensor_fn_floathbf16(const Tensor& t) { template store_compute_to_tensor_fn get_store_compute_to_tensor_fn_intb( + KernelRuntimeContext& context, const Tensor& t) { void (*result)(CTYPE_COMPUTE, void*) = nullptr; ET_SWITCH_INT_TYPES_AND( - Bool, t.scalar_type(), unused, op_name, TENSOR_CTYPE, [&]() { + Bool, t.scalar_type(), context, op_name, TENSOR_CTYPE, [&]() -> void { result = internal::convert_and_store; }); return result; @@ -177,21 +219,36 @@ store_compute_to_tensor_fn get_store_compute_to_tensor_fn_intb( template store_compute_to_tensor_fn get_store_compute_to_tensor_fn_bool( + KernelRuntimeContext& context, const Tensor& t) { - ET_CHECK_MSG( - t.scalar_type() == ScalarType::Bool, - "Unhandled dtype %s for %s", - ::executorch::runtime::toString(t.scalar_type()), - op_name); - return internal::convert_and_store; + void (*result)(CTYPE_COMPUTE, void*) = nullptr; + if (t.scalar_type() != ScalarType::Bool) { + context.fail(torch::executor::Error::InvalidArgument); + ET_LOG( + Error, + "Unhandled dtype %s for %s", + ::executorch::runtime::toString(t.scalar_type()), + op_name); + } else { + result = internal::convert_and_store; + } + return result; } template store_compute_to_tensor_fn -get_store_compute_to_tensor_fn_bool_or_byte(const Tensor& t) { +get_store_compute_to_tensor_fn_bool_or_byte( + KernelRuntimeContext& context, + const Tensor& t) { void (*result)(CTYPE_COMPUTE, void*) = nullptr; ET_SWITCH_TWO_TYPES( - Bool, Byte, t.scalar_type(), unused, op_name, TENSOR_CTYPE, [&]() { + Bool, + Byte, + t.scalar_type(), + context, + op_name, + TENSOR_CTYPE, + [&]() -> void { result = internal::convert_and_store; }); return result; @@ -199,14 +256,22 @@ get_store_compute_to_tensor_fn_bool_or_byte(const Tensor& t) { template store_compute_to_tensor_fn -get_store_compute_to_tensor_fn_same_as_compute(const Tensor& t) { +get_store_compute_to_tensor_fn_same_as_compute( + KernelRuntimeContext& context, + const Tensor& t) { + void (*result)(CTYPE_COMPUTE, void*) = nullptr; constexpr auto common_scalar_type = CppTypeToScalarType::value; - ET_CHECK_MSG( - t.scalar_type() == common_scalar_type, - "Unhandled dtype %s for %s", - ::executorch::runtime::toString(common_scalar_type), - op_name); - return 
internal::convert_and_store; + if (t.scalar_type() != common_scalar_type) { + context.fail(torch::executor::Error::InvalidArgument); + ET_LOG( + Error, + "Unhandled dtype %s for %s", + ::executorch::runtime::toString(t.scalar_type()), + op_name); + } else { + result = internal::convert_and_store; + } + return result; } template < @@ -214,10 +279,19 @@ template < const char* op_name, std::enable_if_t, bool> = true> store_compute_to_tensor_fn -get_store_compute_to_tensor_fn_same_as_common(const Tensor& t) { +get_store_compute_to_tensor_fn_same_as_common( + KernelRuntimeContext& context, + const Tensor& t) { void (*result)(CTYPE_COMPUTE, void*) = nullptr; ET_SWITCH_THREE_TYPES( - Float, Half, BFloat16, t.scalar_type(), unused, op_name, CTYPE, [&]() { + Float, + Half, + BFloat16, + t.scalar_type(), + context, + op_name, + CTYPE, + [&]() -> void { result = internal::convert_and_store; }); return result; @@ -228,9 +302,11 @@ template < const char* op_name, std::enable_if_t, bool> = true> store_compute_to_tensor_fn -get_store_compute_to_tensor_fn_same_as_common(const Tensor& t) { +get_store_compute_to_tensor_fn_same_as_common( + KernelRuntimeContext& context, + const Tensor& t) { return get_store_compute_to_tensor_fn_same_as_compute( - t); + context, t); } } // namespace internal @@ -251,25 +327,32 @@ namespace internal { template load_to_compute_fn get_load_to_compute_fn_impl( + KernelRuntimeContext& context, const Tensor& t, SupportedTensorDtypes dtypes) { switch (dtypes) { case SupportedTensorDtypes::REALHBBF16: - return get_load_to_compute_fn_realhbbf16(t); + return get_load_to_compute_fn_realhbbf16( + context, t); case SupportedTensorDtypes::REALHBF16: - return get_load_to_compute_fn_realhbf16(t); + return get_load_to_compute_fn_realhbf16( + context, t); case SupportedTensorDtypes::FLOATHBF16: - return get_load_to_compute_fn_realhbf16(t); + return get_load_to_compute_fn_realhbf16( + context, t); case SupportedTensorDtypes::INTB: - return get_load_to_compute_fn_intb(t); + return get_load_to_compute_fn_intb(context, t); case SupportedTensorDtypes::BOOL: - return get_load_to_compute_fn_bool(t); + return get_load_to_compute_fn_bool(context, t); case SupportedTensorDtypes::BOOL_OR_BYTE: - return get_load_to_compute_fn_bool_or_byte(t); + return get_load_to_compute_fn_bool_or_byte( + context, t); case SupportedTensorDtypes::SAME_AS_COMPUTE: - return get_load_to_compute_fn_same_as_compute(t); + return get_load_to_compute_fn_same_as_compute( + context, t); case SupportedTensorDtypes::SAME_AS_COMMON: - return get_load_to_compute_fn_same_as_common(t); + return get_load_to_compute_fn_same_as_common( + context, t); } ET_CHECK(false); return nullptr; @@ -281,34 +364,37 @@ load_to_compute_fn get_load_to_compute_fn_impl( // why; just be aware when trying to improve size further. 
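// With a context threaded through dtype_util.h, these helpers report an
// unsupported dtype via context.fail() and hand back a null function pointer
// instead of aborting. A hedged caller-side sketch modelled on the op_cumsum
// call site; the <CTYPE_OUT, op_name> template arguments are my reading of the
// call, and the explicit nullptr guard is my addition, not verbatim from the
// patch:

const auto load_self =
    utils::internal::get_load_to_compute_fn<CTYPE_OUT, op_name>(
        ctx, self, utils::SupportedTensorDtypes::REALHBBF16);
if (load_self == nullptr) {
  return out; // ctx already carries the InvalidArgument failure state
}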
template store_compute_to_tensor_fn get_store_compute_to_tensor_fn( + KernelRuntimeContext& context, const Tensor& t, SupportedTensorDtypes dtypes) { switch (dtypes) { case SupportedTensorDtypes::REALHBBF16: return get_store_compute_to_tensor_fn_realhbbf16( - t); + context, t); case SupportedTensorDtypes::REALHBF16: return get_store_compute_to_tensor_fn_realhbf16( - t); + context, t); case SupportedTensorDtypes::FLOATHBF16: return get_store_compute_to_tensor_fn_floathbf16( - t); + context, t); case SupportedTensorDtypes::INTB: - return get_store_compute_to_tensor_fn_intb(t); + return get_store_compute_to_tensor_fn_intb( + context, t); case SupportedTensorDtypes::BOOL: - return get_store_compute_to_tensor_fn_bool(t); + return get_store_compute_to_tensor_fn_bool( + context, t); case SupportedTensorDtypes::BOOL_OR_BYTE: return get_store_compute_to_tensor_fn_bool_or_byte< CTYPE_COMPUTE, - op_name>(t); + op_name>(context, t); case SupportedTensorDtypes::SAME_AS_COMPUTE: return get_store_compute_to_tensor_fn_same_as_compute< CTYPE_COMPUTE, - op_name>(t); + op_name>(context, t); case SupportedTensorDtypes::SAME_AS_COMMON: { return get_store_compute_to_tensor_fn_same_as_common< CTYPE_COMPUTE, - op_name>(t); + op_name>(context, t); } } ET_CHECK(false); @@ -322,6 +408,7 @@ inline constexpr const char kGenericElementwiseOpName[] = template load_to_compute_fn get_load_to_compute_fn( + KernelRuntimeContext& context, const Tensor& t, SupportedTensorDtypes dtypes) { // NOTE: Selective build relies on the operator name being passed @@ -335,7 +422,7 @@ load_to_compute_fn get_load_to_compute_fn( #else // EXECUTORCH_SELECTIVE_BUILD_DTYPE kGenericElementwiseOpName #endif // EXECUTORCH_SELECTIVE_BUILD_DTYPE - >(t, dtypes); + >(context, t, dtypes); } bool check_tensor_dtype( diff --git a/kernels/portable/cpu/util/elementwise_util.h b/kernels/portable/cpu/util/elementwise_util.h index 5bb5becf185..cc1110e10d7 100644 --- a/kernels/portable/cpu/util/elementwise_util.h +++ b/kernels/portable/cpu/util/elementwise_util.h @@ -119,9 +119,9 @@ inline void dtype_specialized_elementwise_fn_impl( // small-sized tests will test whether using Vectorized broke our // lambda. #ifndef NDEBUG - std::array loaded_inputs; + std::array loaded_inputs{}; #else // NDEBUG - std::array loaded_inputs; + std::array loaded_inputs{}; #endif // NDEBUG for (const auto input_idx : c10::irange(kNumInputs)) { loaded_inputs[input_idx] = inputs_data_ptrs[input_idx][idx]; @@ -136,7 +136,7 @@ inline void dtype_specialized_elementwise_fn_impl( // Main vectorized loop. for (auto idx = vectorized_begin; idx < vectorized_end; idx += Vec::size()) { - std::array loaded_vec_inputs; + std::array loaded_vec_inputs{}; for (const auto input_idx : c10::irange(kNumInputs)) { loaded_vec_inputs[input_idx] = Vec::loadu(&inputs_data_ptrs[input_idx][idx]); @@ -148,9 +148,9 @@ inline void dtype_specialized_elementwise_fn_impl( // Scalar epilogue. 
for (const auto idx : c10::irange(vectorized_end, end)) { #ifndef NDEBUG - std::array loaded_inputs; + std::array loaded_inputs{}; #else // NDEBUG - std::array loaded_inputs; + std::array loaded_inputs{}; #endif // NDEBUG for (const auto input_idx : c10::irange(kNumInputs)) { loaded_inputs[input_idx] = inputs_data_ptrs[input_idx][idx]; @@ -184,7 +184,7 @@ inline void dtype_specialized_elementwise_fn_impl( begin_it += begin; for (; (*begin_it)[0] < end; ++begin_it) { const auto& indexes = *begin_it; - std::array loaded_inputs; + std::array loaded_inputs{}; for (const auto idx : c10::irange(kNumInputs)) { loaded_inputs[idx] = inputs_data_ptrs[idx][indexes[idx + 1]]; } @@ -238,14 +238,14 @@ inline void apply_elementwise_fn_generic_impl( }; std::array inputs_info = {(InputInfo{ internal::get_load_to_compute_fn( - *inputs.first, inputs.second), + ctx, *inputs.first, inputs.second), reinterpret_cast(inputs.first->const_data_ptr()), inputs.first->element_size(), })...}; const auto store_compute_to_out = internal::get_store_compute_to_tensor_fn( - out, out_dtypes); + ctx, out, out_dtypes); char* const data_out = reinterpret_cast(out.mutable_data_ptr()); const auto out_element_size = out.element_size(); @@ -261,7 +261,7 @@ inline void apply_elementwise_fn_generic_impl( begin_it += begin; for (; (*begin_it)[0] < end; ++begin_it) { const auto& indexes = *begin_it; - std::array loaded_inputs; + std::array loaded_inputs{}; for (const auto idx : c10::irange(kNumInputs)) { const auto& input_info = inputs_info[idx]; loaded_inputs[idx] = input_info.load_to_compute( diff --git a/kernels/portable/test/dtype_selective_build_test.cpp b/kernels/portable/test/dtype_selective_build_test.cpp index 0492ee14b00..d536d90aa7c 100644 --- a/kernels/portable/test/dtype_selective_build_test.cpp +++ b/kernels/portable/test/dtype_selective_build_test.cpp @@ -15,6 +15,12 @@ using executorch::aten::ScalarType; using torch::executor::ScalarTypeToCppType; TEST(DtypeSelectiveBuildTest, UnknownOp) { + // Create a minimal context for error handling in ET_SWITCH + struct { + [[noreturn]] void fail(torch::executor::Error /* error */) { + ET_CHECK_MSG(false, "Unsupported dtype"); + } + } ctx; ET_EXPECT_DEATH( ET_SWITCH_TWO_TYPES( Float, @@ -29,6 +35,12 @@ TEST(DtypeSelectiveBuildTest, UnknownOp) { } TEST(DtypeSelectiveBuildTest, OpWithoutDtype) { + // Create a minimal context for error handling in ET_SWITCH + struct { + [[noreturn]] void fail(torch::executor::Error /* error */) { + ET_CHECK_MSG(false, "Unsupported dtype"); + } + } ctx; ET_EXPECT_DEATH( ET_SWITCH_TWO_TYPES( Float, @@ -43,6 +55,12 @@ TEST(DtypeSelectiveBuildTest, OpWithoutDtype) { } TEST(DtypeSelectiveBuildTest, OpWithDtype) { + // Create a minimal context for error handling in ET_SWITCH + struct { + [[noreturn]] void fail(torch::executor::Error /* error */) { + ET_CHECK_MSG(false, "Unsupported dtype"); + } + } ctx; ASSERT_EQ( ET_SWITCH_TWO_TYPES( Float, diff --git a/kernels/quantized/cpu/embeddingxb.cpp b/kernels/quantized/cpu/embeddingxb.cpp index 4a76eff1eef..0ad5470c2c3 100644 --- a/kernels/quantized/cpu/embeddingxb.cpp +++ b/kernels/quantized/cpu/embeddingxb.cpp @@ -258,6 +258,7 @@ void resize_out_tensor( Tensor& quantized_embedding_xbit_out( // TODO Evaluate whether this name is appropriate for an operator that takes // non quant input and returns fp output + KernelRuntimeContext& ctx, const Tensor& weight, const Tensor& weight_scales, const std::optional& opt_weight_zero_points, @@ -268,6 +269,8 @@ Tensor& quantized_embedding_xbit_out( int weight_nbit) { 
ScalarType out_type = out.scalar_type(); + resize_out_tensor(weight, indices, out, weight_nbit); + // TODO (jakeszwe): improve these to account for the size of out in relation // to weight and indices accounting for a possible batch dimension check_embedding_xbit_args( @@ -296,7 +299,6 @@ Tensor& quantized_embedding_xbit_out( } Tensor& quantized_embedding_xbit_out( - KernelRuntimeContext& context, const Tensor& weight, const Tensor& weight_scales, const std::optional& opt_weight_zero_points, @@ -307,9 +309,9 @@ Tensor& quantized_embedding_xbit_out( int weight_nbit) { // TODO(larryliu): Add a context arg to the real op function and remove this // wrapper - (void)context; - resize_out_tensor(weight, indices, out, weight_nbit); - return quantized_embedding_xbit_out( + KernelRuntimeContext context; + auto& res = quantized_embedding_xbit_out( + context, weight, weight_scales, opt_weight_zero_points, @@ -318,11 +320,14 @@ Tensor& quantized_embedding_xbit_out( indices, out, weight_nbit); + ET_CHECK(context.failure_state() == Error::Ok); + return res; } Tensor& quantized_embedding_xbit_dtype_out( // TODO Evaluate whether this name is appropriate for an operator that takes // non quant input and returns fp output + KernelRuntimeContext& ctx, const Tensor& weight, const Tensor& weight_scales, const std::optional& opt_weight_zero_points, @@ -332,6 +337,8 @@ Tensor& quantized_embedding_xbit_dtype_out( std::optional out_dtype, Tensor& out, int weight_nbit) { + resize_out_tensor(weight, indices, out, weight_nbit); + // TODO (jakeszwe): improve these to account for the size of out in relation // to weight and indices accounting for a possible batch dimension check_embedding_xbit_args( @@ -365,7 +372,6 @@ Tensor& quantized_embedding_xbit_dtype_out( } Tensor& quantized_embedding_xbit_dtype_out( - KernelRuntimeContext& context, const Tensor& weight, const Tensor& weight_scales, const std::optional& opt_weight_zero_points, @@ -377,9 +383,9 @@ Tensor& quantized_embedding_xbit_dtype_out( int weight_nbit) { // TODO(larryliu): Add a context arg to the real op function and remove this // wrapper - (void)context; - resize_out_tensor(weight, indices, out, weight_nbit); - return quantized_embedding_xbit_dtype_out( + KernelRuntimeContext context; + auto& res = quantized_embedding_xbit_dtype_out( + context, weight, weight_scales, opt_weight_zero_points, @@ -389,6 +395,8 @@ Tensor& quantized_embedding_xbit_dtype_out( out_dtype, out, weight_nbit); + ET_CHECK(context.failure_state() == Error::Ok); + return res; } } // namespace native diff --git a/kernels/quantized/cpu/op_embedding.cpp b/kernels/quantized/cpu/op_embedding.cpp index 899655c538f..8aa1696e8b6 100644 --- a/kernels/quantized/cpu/op_embedding.cpp +++ b/kernels/quantized/cpu/op_embedding.cpp @@ -232,6 +232,7 @@ void resize_out_tensor( Tensor& quantized_embedding_byte_out( // TODO Evaluate whether this name is appropriate for an operator that takes // non quant input and returns fp output + KernelRuntimeContext& ctx, const Tensor& weight, const Tensor& weight_scales, const std::optional& opt_weight_zero_points, @@ -242,6 +243,8 @@ Tensor& quantized_embedding_byte_out( ScalarType w_type = weight.scalar_type(); ScalarType out_type = out.scalar_type(); + resize_out_tensor(weight, indices, out); + // TODO (jakeszwe): improve these to account for the size of out in relation // to weight and indices accounting for a possible batch dimension check_embedding_byte_args( @@ -266,7 +269,6 @@ Tensor& quantized_embedding_byte_out( } Tensor& quantized_embedding_byte_out( - 
KernelRuntimeContext& context, const Tensor& weight, const Tensor& weight_scales, const std::optional& opt_weight_zero_points, @@ -276,9 +278,9 @@ Tensor& quantized_embedding_byte_out( Tensor& out) { // TODO(larryliu): Add a context arg to the real op function and remove this // wrapper - (void)context; - resize_out_tensor(weight, indices, out); - return quantized_embedding_byte_out( + KernelRuntimeContext context; + auto& res = quantized_embedding_byte_out( + context, weight, weight_scales, opt_weight_zero_points, @@ -286,11 +288,14 @@ Tensor& quantized_embedding_byte_out( weight_quant_max, indices, out); + ET_CHECK(context.failure_state() == Error::Ok); + return res; } Tensor& quantized_embedding_byte_dtype_out( // TODO Evaluate whether this name is appropriate for an operator that takes // non quant input and returns fp output + KernelRuntimeContext& ctx, const Tensor& weight, const Tensor& weight_scales, const std::optional& opt_weight_zero_points, @@ -299,6 +304,8 @@ Tensor& quantized_embedding_byte_dtype_out( const Tensor& indices, std::optional out_dtype, Tensor& out) { + resize_out_tensor(weight, indices, out); + // TODO (jakeszwe): improve these to account for the size of out in relation // to weight and indices accounting for a possible batch dimension check_embedding_byte_args( @@ -329,7 +336,6 @@ Tensor& quantized_embedding_byte_dtype_out( } Tensor& quantized_embedding_byte_dtype_out( - KernelRuntimeContext& context, const Tensor& weight, const Tensor& weight_scales, const std::optional& opt_weight_zero_points, @@ -340,9 +346,9 @@ Tensor& quantized_embedding_byte_dtype_out( Tensor& out) { // TODO(larryliu): Add a context arg to the real op function and remove this // wrapper - (void)context; - resize_out_tensor(weight, indices, out); - return quantized_embedding_byte_dtype_out( + KernelRuntimeContext context; + auto& res = quantized_embedding_byte_dtype_out( + context, weight, weight_scales, opt_weight_zero_points, @@ -351,6 +357,8 @@ Tensor& quantized_embedding_byte_dtype_out( indices, out_dtype, out); + ET_CHECK(context.failure_state() == Error::Ok); + return res; } } // namespace native diff --git a/kernels/quantized/cpu/op_mixed_linear.cpp b/kernels/quantized/cpu/op_mixed_linear.cpp index a9d5db10533..2bd61974d9e 100644 --- a/kernels/quantized/cpu/op_mixed_linear.cpp +++ b/kernels/quantized/cpu/op_mixed_linear.cpp @@ -61,15 +61,19 @@ bool check_quantized_mixed_linear_args( } Tensor& quantized_mixed_linear_out( + KernelRuntimeContext& ctx, const Tensor& in, const Tensor& weight, const Tensor& weight_scales, const std::optional& opt_weight_zero_points, const std::optional dtype, Tensor& out) { - // TODO (gjcomer) Replace with ET_KERNEL_CHECK when context is available. - ET_CHECK(check_quantized_mixed_linear_args( - in, weight, weight_scales, opt_weight_zero_points, dtype, out)); + ET_KERNEL_CHECK( + ctx, + check_quantized_mixed_linear_args( + in, weight, weight_scales, opt_weight_zero_points, dtype, out), + InvalidArgument, + out); ScalarType out_dtype = dtype.has_value() ? dtype.value() : out.scalar_type(); @@ -78,8 +82,11 @@ Tensor& quantized_mixed_linear_out( output_sizes[0] = in.size(0); output_sizes[1] = weight.size(0); - // TODO (gjcomer) Replace with ET_KERNEL_CHECK when context is available. 
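// The quantized embedding wrappers above (and the mixed_linear/mixed_mm ones
// below) follow one migration pattern for legacy overloads that cannot yet take
// a context: construct a local KernelRuntimeContext, forward to the ctx-aware
// overload, then turn any recorded failure back into a hard ET_CHECK so
// existing callers keep the abort-on-error behaviour. Generic sketch with
// placeholder names (my_op_out and its arguments are illustrative):

// ctx-aware overload, assumed to be defined alongside it.
Tensor& my_op_out(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out);

Tensor& my_op_out(const Tensor& in, Tensor& out) {
  KernelRuntimeContext context;
  auto& res = my_op_out(context, in, out);
  ET_CHECK(context.failure_state() == Error::Ok);
  return res;
}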
- ET_CHECK(resize_tensor(out, {output_sizes, output_ndim}) == Error::Ok); + ET_KERNEL_CHECK( + ctx, + resize_tensor(out, {output_sizes, output_ndim}) == Error::Ok, + InvalidArgument, + out); constexpr auto name = "quantized_decomposed::mixed_linear.out"; @@ -113,7 +120,6 @@ Tensor& quantized_mixed_linear_out( } Tensor& quantized_mixed_linear_out( - KernelRuntimeContext& ctx, const Tensor& in, const Tensor& weight, const Tensor& weight_scales, @@ -122,9 +128,11 @@ Tensor& quantized_mixed_linear_out( Tensor& out) { // TODO(mcandales): Remove the need for this wrapper // TODO(mkg): add support for dtype - (void)ctx; - return quantized_mixed_linear_out( - in, weight, weight_scales, opt_weight_zero_points, dtype, out); + KernelRuntimeContext context; + auto& res = quantized_mixed_linear_out( + context, in, weight, weight_scales, opt_weight_zero_points, dtype, out); + ET_CHECK(context.failure_state() == Error::Ok); + return res; } } // namespace native diff --git a/kernels/quantized/cpu/op_mixed_mm.cpp b/kernels/quantized/cpu/op_mixed_mm.cpp index 5e52c681e1b..87fb63ccc6b 100644 --- a/kernels/quantized/cpu/op_mixed_mm.cpp +++ b/kernels/quantized/cpu/op_mixed_mm.cpp @@ -52,20 +52,29 @@ bool check_quantized_mixed_mm_args( } Tensor& quantized_mixed_mm_out( + KernelRuntimeContext& ctx, const Tensor& in, const Tensor& weight, const Tensor& weight_scales, const std::optional& opt_weight_zero_points, Tensor& out) { - ET_CHECK(check_quantized_mixed_mm_args( - in, weight, weight_scales, opt_weight_zero_points, out)); + ET_KERNEL_CHECK( + ctx, + check_quantized_mixed_mm_args( + in, weight, weight_scales, opt_weight_zero_points, out), + InvalidArgument, + out); size_t output_ndim = 2; executorch::aten::SizesType output_sizes[kTensorDimensionLimit]; output_sizes[0] = in.size(0); output_sizes[1] = weight.size(1); - ET_CHECK(resize_tensor(out, {output_sizes, output_ndim}) == Error::Ok); + ET_KERNEL_CHECK( + ctx, + resize_tensor(out, {output_sizes, output_ndim}) == Error::Ok, + InvalidArgument, + out); constexpr auto name = "quantized_decomposed::mixed_mm.out"; @@ -88,16 +97,17 @@ Tensor& quantized_mixed_mm_out( } Tensor& quantized_mixed_mm_out( - KernelRuntimeContext& ctx, const Tensor& in, const Tensor& weight, const Tensor& weight_scales, const std::optional& opt_weight_zero_points, Tensor& out) { // TODO(mcandales): Remove the need for this wrapper - (void)ctx; - return quantized_mixed_mm_out( - in, weight, weight_scales, opt_weight_zero_points, out); + KernelRuntimeContext context; + auto& res = quantized_mixed_mm_out( + context, in, weight, weight_scales, opt_weight_zero_points, out); + ET_CHECK(context.failure_state() == Error::Ok); + return res; } } // namespace native diff --git a/runtime/core/exec_aten/util/scalar_type_util.h b/runtime/core/exec_aten/util/scalar_type_util.h index 9df5d1e47a2..895536b72be 100644 --- a/runtime/core/exec_aten/util/scalar_type_util.h +++ b/runtime/core/exec_aten/util/scalar_type_util.h @@ -910,20 +910,21 @@ struct promote_types { } #endif -#define ET_INTERNAL_SWITCH(TYPE, CONTEXT, NAME, ...) \ - [&] { \ - const auto& _st = TYPE; \ - constexpr const char* et_switch_name = NAME; \ - (void)et_switch_name; /* Suppress unused var */ \ - switch (_st) { \ - __VA_ARGS__ \ - default: \ - ET_CHECK_MSG( \ - false, \ - "Unhandled dtype %s for %s", \ - ::executorch::runtime::toString(_st), \ - et_switch_name); \ - } \ +#define ET_INTERNAL_SWITCH(TYPE, CONTEXT, NAME, ...) 
                                                                      \
+  [&] {                                                              \
+    const auto& _st = TYPE;                                          \
+    constexpr const char* et_switch_name = NAME;                     \
+    (void)et_switch_name; /* Suppress unused var */                  \
+    switch (_st) {                                                   \
+      __VA_ARGS__                                                    \
+      default:                                                       \
+        CONTEXT.fail(torch::executor::Error::InvalidArgument);       \
+        ET_LOG(                                                      \
+            Error,                                                   \
+            "Unhandled dtype %s for %s",                             \
+            ::executorch::runtime::toString(_st),                    \
+            et_switch_name);                                         \
+    }                                                                \
   }()
 
 #define ET_INTERNAL_SWITCH_CASE_INT_TYPES(CTYPE_ALIAS, ...) \
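
// Net effect of the ET_INTERNAL_SWITCH rewrite: CONTEXT must provide
// fail(torch::executor::Error); on an unhandled dtype the macro logs, records
// InvalidArgument, and falls through rather than aborting, leaving whatever the
// lambda would have computed untouched. Kernels that already receive a
// KernelRuntimeContext simply pass it straight through, as in this hedged
// call-site sketch (example_out and the copy body are placeholders, not code
// from the patch):

Tensor& example_out(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) {
  ET_KERNEL_CHECK(
      ctx, in.scalar_type() == out.scalar_type(), InvalidArgument, out);
  ET_SWITCH_REALHBBF16_TYPES(
      in.scalar_type(), ctx, "example.out", CTYPE, [&] {
        const CTYPE* in_data = in.const_data_ptr<CTYPE>();
        CTYPE* out_data = out.mutable_data_ptr<CTYPE>();
        for (size_t i = 0; i < static_cast<size_t>(in.numel()); ++i) {
          out_data[i] = in_data[i];
        }
      });
  return out;
}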