diff --git a/extension/pybindings/pybindings.cpp b/extension/pybindings/pybindings.cpp index 02534b303ed..dbb5bf0345b 100644 --- a/extension/pybindings/pybindings.cpp +++ b/extension/pybindings/pybindings.cpp @@ -140,7 +140,7 @@ void setup_output_storage( const std::vector>& output_storages) { if (output_storages.size() != method.outputs_size()) { THROW_IF_ERROR( - Error(), + Error::InvalidArgument, "number of output storages %zu does not match number of outputs %zu", output_storages.size(), method.outputs_size()); @@ -249,10 +249,10 @@ class Module final { const std::vector& args, const std::optional>>& output_storages = std::nullopt) { - auto& method = methods_[method_name]; + auto& method = get_method(method_name); exec_aten::ArrayRef input_evalue_list(args.data(), args.size()); - Error set_inputs_status = method->set_inputs(input_evalue_list); + Error set_inputs_status = method.set_inputs(input_evalue_list); THROW_IF_ERROR( set_inputs_status, "method->set_inputs() for method '%s' failed with error 0x%" PRIx32, @@ -273,9 +273,9 @@ class Module final { c10::autograd_dispatch_keyset); #endif if (output_storages) { - setup_output_storage(*method, *output_storages); + setup_output_storage(method, *output_storages); } - Error execute_status = method->execute(); + Error execute_status = method.execute(); THROW_IF_ERROR( execute_status, "method->execute() failed with error 0x%" PRIx32, @@ -302,7 +302,9 @@ class Module final { Method& get_method(const std::string& method_name) { if (methods_.count(method_name) == 0) { THROW_IF_ERROR( - Error(), "no such method in program: %s", method_name.c_str()); + Error::InvalidArgument, + "no such method in program: %s", + method_name.c_str()); } return *methods_[method_name].get(); } diff --git a/extension/pybindings/test/make_test.py b/extension/pybindings/test/make_test.py index 24b75b86518..b44de2680fb 100644 --- a/extension/pybindings/test/make_test.py +++ b/extension/pybindings/test/make_test.py @@ -341,6 +341,17 @@ def test_method_meta(tester) -> None: tester.assertEqual(output_tensor.nbytes(), 16) tester.assertEqual(str(output_tensor), tensor_info) + def test_bad_name(tester) -> None: + # Create an ExecuTorch program from ModuleAdd. + # pyre-fixme[16]: Callable `make_test` has no attribute `wrapper`. + exported_program, inputs = create_program(ModuleAdd()) + + # Use pybindings to load and execute the program. + executorch_module = load_fn(exported_program.buffer) + # Invoke the callable on executorch_module instead of calling module.forward. + with tester.assertRaises(RuntimeError): + executorch_module.run_method("not_a_real_method", inputs) + ######### RUN TEST CASES ######### test_e2e(tester) test_multiple_entry(tester) @@ -351,5 +362,6 @@ def test_method_meta(tester) -> None: test_quantized_ops(tester) test_constant_output_not_memory_planned(tester) test_method_meta(tester) + test_bad_name(tester) return wrapper