Skip to content

Commit

Permalink
[JIT] Add a flag to rethrow caught exception in jit interpreter (#63073)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #63073

It turned out that it's less than ideal to print out verbose stacktrace in exception messages in high-QPS services (see the related task) with a non-significant failure rate due to the truncation of long stacktrace which results in losing the original exception message thrown from native code. It is actually desirable to retain only the message of the original exception directly thrown from native code in such a usecase.

This change adds a new flag `torch_jit_disable_exception_stacktrace` to the pytorch jit interpreter to suppress stacktrace in the messages of exception thrown from the interpreter.

Reviewed By: Krovatkin

Differential Revision: D30241792

fbshipit-source-id: fbc90c11e99ab8c1623f016559ad5bfeeaca8465
  • Loading branch information
d1jang authored and facebook-github-bot committed Aug 13, 2021
1 parent 126ff62 commit 8ce8b39
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 5 deletions.
58 changes: 58 additions & 0 deletions test/cpp/jit/test_interpreter.cpp
@@ -1,3 +1,4 @@
#include <gmock/gmock.h>
#include <gtest/gtest.h>

#include <ATen/Parallel.h>
Expand Down Expand Up @@ -210,5 +211,62 @@ TEST(InterpreterTest, runAsyncBasicTest) {
interp.runAsync(stack)->wait();
ASSERT_TRUE(asyncCounter > 0);
}

TEST(
EnableRethrowCaughtExceptionTest,
EnableRethrowCaughtExceptionTestRethrowsCaughtException) {
auto graph = std::make_shared<Graph>();
std::unordered_map<std::string, Value*> vmap;
parseIR(
R"IR(
graph(%0 : Tensor,
%1 : Tensor):
%2 : int = prim::Constant[value=2]()
%3 : Tensor = aten::add(%0, %1, %2)
return (%3)
)IR",
&*graph,
vmap);
Code function(graph, "");
InterpreterState interp = InterpreterState(function);
auto a = at::zeros({2, 2}, at::kFloat);
auto b = at::ones({2, 3}, at::kFloat);
a.set_requires_grad(true);
a = a.to(at::kCPU);
std::vector<IValue> stack({a, b});

bool original_flag_value = FLAGS_torch_jit_enable_rethrow_caught_exception;
bool exception_handled = false;
try {
FLAGS_torch_jit_enable_rethrow_caught_exception = false;
interp.run(stack);
} catch (std::runtime_error& e) {
exception_handled = true;
std::string exception_msg = e.what();
EXPECT_THAT(
exception_msg,
::testing::HasSubstr("%3 : Tensor = aten::add(%0, %1, %2)"));
EXPECT_THAT(
exception_msg,
::testing::HasSubstr(
"The size of tensor a (2) must match the size of tensor b (3) at non-singleton dimension 1"));
}
EXPECT_TRUE(exception_handled);

exception_handled = false;
try {
FLAGS_torch_jit_enable_rethrow_caught_exception = true;
interp.run(stack);
} catch (c10::Error& e) {
exception_handled = true;
std::string exception_msg = e.what_without_backtrace();
EXPECT_STREQ(
exception_msg.c_str(),
"The size of tensor a (2) must match the size of tensor b (3) at non-singleton dimension 1");
}
EXPECT_TRUE(exception_handled);
FLAGS_torch_jit_enable_rethrow_caught_exception = original_flag_value;
}

} // namespace jit
} // namespace torch
18 changes: 13 additions & 5 deletions torch/csrc/jit/runtime/interpreter.cpp
Expand Up @@ -43,6 +43,11 @@ using torch::distributed::autograd::DistAutogradContainer;
#include <utility>
#include <vector>

C10_DEFINE_bool(
torch_jit_enable_rethrow_caught_exception,
false,
"enable rethrowing caught exception");

namespace torch {
namespace jit {

Expand Down Expand Up @@ -708,12 +713,15 @@ struct InterpreterStateImpl : c10::intrusive_ptr_target {
push(stack, IValue());
try {
f.run(stack);
} catch (std::exception& e) {
std::ostringstream ss;
ss << "The following operation failed in the TorchScript interpreter.\n";
formatStackTrace(ss);
ss << "RuntimeError: " << ExceptionMessage(e) << "\n";
} catch (std::exception& _) {
// TODO(T98048876): Handle `_` correctly.
}
}
if (FLAGS_torch_jit_enable_rethrow_caught_exception) {
if (future_) {
future_->setError(std::make_exception_ptr(e));
}
throw;
}
bool is_jit_exception = dynamic_cast<JITException*>(&e);
// Janky af. See https://github.com/pytorch/pytorch/issues/54612
Expand Down
1 change: 1 addition & 0 deletions torch/csrc/jit/runtime/interpreter.h
Expand Up @@ -10,6 +10,7 @@
#include <torch/csrc/jit/frontend/source_range.h>

C10_DECLARE_bool(torch_jit_disable_warning_prints);
C10_DECLARE_bool(torch_jit_enable_rethrow_caught_exception);

namespace at {
class Tensor;
Expand Down

0 comments on commit 8ce8b39

Please sign in to comment.