From dec955d5691a18063efa5b0f9b3f1d1bad21f334 Mon Sep 17 00:00:00 2001 From: Dhruv Matani Date: Mon, 2 Aug 2021 19:51:30 -0700 Subject: [PATCH] [PyTorch Edge] Simplify Exception Handling (Take-2) (module.cpp) Apply the same set of changes as in D27688352 to `module.cpp` as instructed by @xcheng16. Basically, this simplifies exception handling and allows propagation of the original message undisturbed to the caller so that we can figure out the lineage of the exception in crash tasks such as t96812652 Differential Revision: [D30038867](https://our.internmc.facebook.com/intern/diff/D30038867/) [ghstack-poisoned] --- test/mobile/test_lite_script_module.py | 20 +++++++-- torch/csrc/jit/mobile/module.cpp | 56 ++++++++++++-------------- 2 files changed, 41 insertions(+), 35 deletions(-) diff --git a/test/mobile/test_lite_script_module.py b/test/mobile/test_lite_script_module.py index b4db9936c798..82bf527d2fb9 100644 --- a/test/mobile/test_lite_script_module.py +++ b/test/mobile/test_lite_script_module.py @@ -363,7 +363,13 @@ def forward(self): _, lineno = inspect.getsourcelines(FooTest2) - with self.assertRaisesRegex(RuntimeError, 'test_lite_script_module.py\", line {}'.format(lineno + 3)): + # In C++ code, the type of exception thrown is torch::jit::JITException which + # does not extend c10::Error, and hence it isn't possible to add additional + # context to the exception message and preserve the correct C++ stack trace + # for symbolication. i.e. it isn't possible to add the debug handle string + # to show where in the Python code the exception occured w/o first changing + # torch::jit::JITException to extend c10::Error. + with self.assertRaisesRegex(torch.jit.Error, 'foo'): ft = FooTest2() loaded = self.getScriptExportImportCopy(ft) loaded() @@ -432,10 +438,16 @@ def forward(self, val: int, x, y, w): try: loaded(42, torch.rand(3, 4), torch.rand(3, 4), torch.rand(30, 40)) - except RuntimeError as e: + except torch.jit.Error as e: error_message = f"{e}" - self.assertTrue('test_lite_script_module.py\", line {}'.format(lineno + 8) in error_message) - self.assertTrue('top(FooTest5)' in error_message) + + # In C++ code, the type of exception thrown is torch::jit::JITException which + # does not extend c10::Error, and hence it isn't possible to add additional + # context to the exception message and preserve the correct C++ stack trace + # for symbolication. i.e. it isn't possible to add the debug handle string + # to show where in the Python code the exception occured w/o first changing + # torch::jit::JITException to extend c10::Error. + self.assertTrue('self.val and val are same' in error_message) class TestLiteScriptQuantizedModule(QuantizationLiteTestCase): diff --git a/torch/csrc/jit/mobile/module.cpp b/torch/csrc/jit/mobile/module.cpp index 3f2f569f4a84..fad6447679bc 100644 --- a/torch/csrc/jit/mobile/module.cpp +++ b/torch/csrc/jit/mobile/module.cpp @@ -7,6 +7,7 @@ #include #include +#include #include namespace torch { @@ -178,23 +179,41 @@ void Method::run(Stack& stack) const { debug_info->setMethodName(function_->name()); at::DebugInfoGuard guard(at::DebugInfoKind::MOBILE_RUNTIME_INFO, debug_info); + std::string error_message; + auto failure_guard = c10::make_scope_exit([&]() { + if (!observer) { + return; + } + +#if defined(SYMBOLICATE_MOBILE_DEBUG_HANDLE) + if (error_message.empty()) { + error_message = owner_->getDebugTable().getSourceDebugString( + function_->getExceptionDebugHandle(), getTopModuleTypeName(*owner_)); + } +#endif + + observer->onFailRunMethod( + instance_key, + error_message.empty() ? "Unknown exception" : error_message.c_str()); + }); + try { stack.insert(stack.begin(), owner_->_ivalue()); // self function_->run(stack); if (observer) { observer->onExitRunMethod(instance_key); } + failure_guard.release(); // This exception must be caught first as it derived from c10::Error } catch (c10::BackendRuntimeException& e) { #if defined(SYMBOLICATE_MOBILE_DEBUG_HANDLE) e.pushDebugHandle(function_->getExceptionDebugHandle()); // symbolicate all handles - e.add_context(owner_->getDebugTable().getSourceDebugString( - e.getDebugHandles(), getTopModuleTypeName(*owner_))); + auto debug_string = owner_->getDebugTable().getSourceDebugString( + e.getDebugHandles(), getTopModuleTypeName(*owner_)); + e.add_context(debug_string); #endif - if (observer) { - observer->onFailRunMethod(instance_key, e.what()); - } + error_message = e.what(); TORCH_RETHROW(e); } catch (c10::Error& error) { #if defined(SYMBOLICATE_MOBILE_DEBUG_HANDLE) @@ -202,33 +221,8 @@ void Method::run(Stack& stack) const { function_->getExceptionDebugHandle(), getTopModuleTypeName(*owner_)); error.add_context(debug_string); #endif - if (observer) { - observer->onFailRunMethod(instance_key, error.what()); - } + error_message = error.what(); TORCH_RETHROW(error); - } catch (...) { - auto currentException = std::current_exception(); - try { - if (!currentException) { - TORCH_CHECK(false, "Unknown exception"); - } else { - try { - std::rethrow_exception(currentException); - } catch (const std::exception& e) { - TORCH_CHECK(false, e.what()); - } - } - } catch (c10::Error& error) { -#if defined(SYMBOLICATE_MOBILE_DEBUG_HANDLE) - auto debug_string = owner_->getDebugTable().getSourceDebugString( - function_->getExceptionDebugHandle(), getTopModuleTypeName(*owner_)); - error.add_context(debug_string); -#endif - if (observer) { - observer->onFailRunMethod(instance_key, error.what()); - } - TORCH_RETHROW(error); - } } }