Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 16 additions & 4 deletions test/mobile/test_lite_script_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,7 +363,13 @@ def forward(self):

_, lineno = inspect.getsourcelines(FooTest2)

with self.assertRaisesRegex(RuntimeError, 'test_lite_script_module.py\", line {}'.format(lineno + 3)):
# In C++ code, the type of exception thrown is torch::jit::JITException which
# does not extend c10::Error, and hence it isn't possible to add additional
# context to the exception message and preserve the correct C++ stack trace
# for symbolication. i.e. it isn't possible to add the debug handle string
# to show where in the Python code the exception occured w/o first changing
# torch::jit::JITException to extend c10::Error.
with self.assertRaisesRegex(torch.jit.Error, 'foo'):
ft = FooTest2()
loaded = self.getScriptExportImportCopy(ft)
loaded()
Expand Down Expand Up @@ -432,10 +438,16 @@ def forward(self, val: int, x, y, w):

try:
loaded(42, torch.rand(3, 4), torch.rand(3, 4), torch.rand(30, 40))
except RuntimeError as e:
except torch.jit.Error as e:
error_message = f"{e}"
self.assertTrue('test_lite_script_module.py\", line {}'.format(lineno + 8) in error_message)
self.assertTrue('top(FooTest5)' in error_message)

# In C++ code, the type of exception thrown is torch::jit::JITException which
# does not extend c10::Error, and hence it isn't possible to add additional
# context to the exception message and preserve the correct C++ stack trace
# for symbolication. i.e. it isn't possible to add the debug handle string
# to show where in the Python code the exception occured w/o first changing
# torch::jit::JITException to extend c10::Error.
self.assertTrue('self.val and val are same' in error_message)


class TestLiteScriptQuantizedModule(QuantizationLiteTestCase):
Expand Down
56 changes: 25 additions & 31 deletions torch/csrc/jit/mobile/module.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include <exception>

#include <ATen/record_function.h>
#include <c10/util/ScopeExit.h>
#include <c10/util/irange.h>

namespace torch {
Expand Down Expand Up @@ -178,57 +179,50 @@ void Method::run(Stack& stack) const {
debug_info->setMethodName(function_->name());
at::DebugInfoGuard guard(at::DebugInfoKind::MOBILE_RUNTIME_INFO, debug_info);

std::string error_message;
auto failure_guard = c10::make_scope_exit([&]() {
if (!observer) {
return;
}

#if defined(SYMBOLICATE_MOBILE_DEBUG_HANDLE)
if (error_message.empty()) {
error_message = owner_->getDebugTable().getSourceDebugString(
function_->getExceptionDebugHandle(), getTopModuleTypeName(*owner_));
}
#endif

observer->onFailRunMethod(
instance_key,
error_message.empty() ? "Unknown exception" : error_message.c_str());
});

try {
stack.insert(stack.begin(), owner_->_ivalue()); // self
function_->run(stack);
if (observer) {
observer->onExitRunMethod(instance_key);
}
failure_guard.release();
// This exception must be caught first as it derived from c10::Error
} catch (c10::BackendRuntimeException& e) {
#if defined(SYMBOLICATE_MOBILE_DEBUG_HANDLE)
e.pushDebugHandle(function_->getExceptionDebugHandle());
// symbolicate all handles
e.add_context(owner_->getDebugTable().getSourceDebugString(
e.getDebugHandles(), getTopModuleTypeName(*owner_)));
auto debug_string = owner_->getDebugTable().getSourceDebugString(
e.getDebugHandles(), getTopModuleTypeName(*owner_));
e.add_context(debug_string);
#endif
if (observer) {
observer->onFailRunMethod(instance_key, e.what());
}
error_message = e.what();
TORCH_RETHROW(e);
} catch (c10::Error& error) {
#if defined(SYMBOLICATE_MOBILE_DEBUG_HANDLE)
auto debug_string = owner_->getDebugTable().getSourceDebugString(
function_->getExceptionDebugHandle(), getTopModuleTypeName(*owner_));
error.add_context(debug_string);
#endif
if (observer) {
observer->onFailRunMethod(instance_key, error.what());
}
error_message = error.what();
TORCH_RETHROW(error);
} catch (...) {
auto currentException = std::current_exception();
try {
if (!currentException) {
TORCH_CHECK(false, "Unknown exception");
} else {
try {
std::rethrow_exception(currentException);
} catch (const std::exception& e) {
TORCH_CHECK(false, e.what());
}
}
} catch (c10::Error& error) {
#if defined(SYMBOLICATE_MOBILE_DEBUG_HANDLE)
auto debug_string = owner_->getDebugTable().getSourceDebugString(
function_->getExceptionDebugHandle(), getTopModuleTypeName(*owner_));
error.add_context(debug_string);
#endif
if (observer) {
observer->onFailRunMethod(instance_key, error.what());
}
TORCH_RETHROW(error);
}
}
}

Expand Down