Preserve Python backtrace in autograd engine errors.
Pull Request resolved: #43684

This PR attempts to address #42560 by capturing the appropriate
exception_ptr in the autograd engine and passing it over to the Future.

As part of this change, there is a significant change to the Future API: setError now accepts only an exception_ptr (rather than an error string).
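
To make the new call pattern concrete, here is a minimal, self-contained sketch. The ToyFuture type below is illustrative only and is not the real c10::ivalue::Future; it just mirrors the setError/tryRetrieveErrorMessage shapes from the diff. Call sites capture std::current_exception() instead of forwarding e.what(), and a printable message is recovered later by rethrowing and catching.

```
// Sketch only: toy future, not c10::ivalue::Future.
#include <exception>
#include <iostream>
#include <stdexcept>
#include <string>
#include <utility>

struct ToyFuture {
  // New-style API: store the exception itself, not a flattened message.
  void setError(std::exception_ptr eptr) { eptr_ = std::move(eptr); }

  bool hasError() const { return static_cast<bool>(eptr_); }

  // Mirrors the tryRetrieveErrorMessage() helper added in this change:
  // the only way to get a message out of an exception_ptr is to rethrow it.
  std::string tryRetrieveErrorMessage() const {
    try {
      std::rethrow_exception(eptr_);
    } catch (const std::exception& e) {
      return e.what();
    } catch (...) {
      return "Unknown Exception Type";
    }
  }

  std::exception_ptr eptr_;
};

int main() {
  ToyFuture fut;
  try {
    throw std::runtime_error("Simulate error on backward pass");
  } catch (std::exception&) {
    // New pattern at call sites: keep the in-flight exception.
    fut.setError(std::current_exception());
  }

  if (fut.hasError()) {
    std::cout << fut.tryRetrieveErrorMessage() << "\n";
  }
  return 0;
}
```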

For the example in #42560, the exception trace would now look like:


```
> Traceback (most recent call last):
>   File "test_autograd.py", line 6914, in test_preserve_backtrace
>     Foo.apply(t).sum().backward()
>   File "torch/tensor.py", line 214, in backward
>     torch.autograd.backward(self, gradient, retain_graph, create_graph)
>   File "torch/autograd/__init__.py", line 127, in backward
>     allow_unreachable=True)  # allow_unreachable flag
>   File "torch/autograd/function.py", line 87, in apply
>     return self._forward_cls.backward(self, *args)
>   File "test_autograd.py", line 6910, in backward
>     raise ValueError("something")
> ValueError: something
```
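
The traceback survives because the exception object itself, rather than a flattened message, now travels through the future and is rethrown unchanged; the Python traceback lives on the Python exception state that PyTorch wraps in a C++ exception. The toy sketch below is not PyTorch code (BacktracedError and its payload are made up for illustration); it only shows how capturing and rethrowing an exception_ptr preserves the concrete exception type and anything attached to it.

```
// Toy illustration: the payload on this exception type stands in for the
// wrapped Python exception/traceback. Keeping the std::exception_ptr
// (instead of flattening to e.what()) preserves the whole object.
#include <exception>
#include <iostream>
#include <stdexcept>
#include <string>
#include <utility>

struct BacktracedError : std::runtime_error {
  BacktracedError(const std::string& msg, std::string tb)
      : std::runtime_error(msg), trace(std::move(tb)) {}
  std::string trace;  // stand-in for the preserved Python traceback
};

int main() {
  std::exception_ptr eptr;
  try {
    throw BacktracedError("something",
                          "  File \"test_autograd.py\", line 6910, in backward");
  } catch (...) {
    eptr = std::current_exception();  // what the engine now captures
  }

  // Later (e.g. when the backward future is waited on) the same object is
  // rethrown; the concrete type and its payload both survive.
  try {
    std::rethrow_exception(eptr);
  } catch (const BacktracedError& e) {
    std::cout << e.what() << "\n" << e.trace << "\n";
  }
  return 0;
}
```
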
ghstack-source-id: 111002998

Differential Revision: [D23365408](https://our.internmc.facebook.com/intern/diff/D23365408/)
pritamdamania committed Aug 29, 2020
1 parent 1f69968 commit a07bcdc
Showing 17 changed files with 120 additions and 82 deletions.
2 changes: 1 addition & 1 deletion aten/src/ATen/core/ivalue.cpp
@@ -770,7 +770,7 @@ CAFFE2_API intrusive_ptr<ivalue::Future> collectAny(
ctx->srcFutures =
List<intrusive_ptr<ivalue::Future>>(ctx->srcFutures.elementType());
if (src->hasError()) {
dst->setError(*src->error());
dst->setError(src->exception_ptr());
} else {
dst->markCompleted(src->constValue());
}
59 changes: 37 additions & 22 deletions aten/src/ATen/core/ivalue_inl.h
@@ -300,35 +300,32 @@ struct C10_EXPORT ivalue::Future : c10::intrusive_ptr_target {
markCompleted(IValue {});
}

virtual void setError(std::string err) {
setError(FutureError(std::move(err)));
}

void setError(FutureError&& error) {
void setError(std::exception_ptr eptr) {
std::unique_lock<std::mutex> lock(mutex_);
setErrorInternal(std::move(error), lock);
setErrorInternal(std::move(eptr), lock);
}

void setErrorIfNeeded(std::string errorMsg) {
void setErrorIfNeeded(std::exception_ptr eptr) {
std::unique_lock<std::mutex> lock(mutex_);
if (completed_) {
// This should be rare and shouldn't cause log spew. Its important to
// log errors and thats why we have this log here.
LOG(INFO) << "Skipping setting following error on the Future since " <<
"it is already marked completed (this is not neccessarily an error): "
<< errorMsg;
LOG(INFO)
<< "Skipping setting following error on the Future since "
<< "it is already marked completed (this is not neccessarily an error): "
<< tryRetrieveErrorMessageInternal(eptr);
return;
} else {
setErrorInternal(FutureError(std::move(errorMsg)), lock);
setErrorInternal(std::move(eptr), lock);
}
}

// Get the result of the current future.
virtual IValue value() {
std::unique_lock<std::mutex> lock(mutex_);
AT_ASSERT(completed());
if (error_) {
throw *error_;
if (eptr_) {
std::rethrow_exception(eptr_);
}
return value_;
}
@@ -338,7 +335,7 @@ struct C10_EXPORT ivalue::Future : c10::intrusive_ptr_target {
virtual const IValue& constValue() {
std::unique_lock<std::mutex> lock(mutex_);
AT_ASSERT(completed());
AT_ASSERT(!error_);
AT_ASSERT(!eptr_);
return value_;
}

@@ -375,31 +372,38 @@ struct C10_EXPORT ivalue::Future : c10::intrusive_ptr_target {
try {
fut->markCompleted(cb());
} catch (std::exception& e) {
fut->setError(e.what());
fut->setError(std::current_exception());
}
},
std::move(callback)));
return fut;
}

// Tries to retrieve the error message from std::exception_ptr.
std::string tryRetrieveErrorMessage() {
TORCH_CHECK(hasError(), "No error present on the future.");
std::unique_lock<std::mutex> lock(mutex_);
return tryRetrieveErrorMessageInternal(eptr_);
}

// Check if the current future has completed
virtual bool completed() const{
return completed_;
}

virtual bool hasValue() const {
std::unique_lock<std::mutex> lock(mutex_);
return completed_ && !error_;
return completed_ && !eptr_;
}

bool hasError() const {
std::unique_lock<std::mutex> lock(mutex_);
return error_ ? true : false;
return eptr_ ? true : false;
}

c10::optional<FutureError> error() const {
std::exception_ptr exception_ptr() const {
std::unique_lock<std::mutex> lock(mutex_);
return error_;
return eptr_;
}

CAFFE2_API friend std::ostream& operator<<(
@@ -412,11 +416,11 @@ struct C10_EXPORT ivalue::Future : c10::intrusive_ptr_target {

private:
void setErrorInternal(
FutureError error,
std::exception_ptr eptr,
std::unique_lock<std::mutex>& lock) {
AT_ASSERT(!completed());
completed_ = true;
error_ = std::move(error);
eptr_ = std::move(eptr);

std::vector<std::function<void(void)>> cbs;
cbs.swap(callbacks_);
@@ -428,14 +432,25 @@ struct C10_EXPORT ivalue::Future : c10::intrusive_ptr_target {
}
}

// Tries to retrieve the error message from std::exception_ptr.
std::string tryRetrieveErrorMessageInternal(std::exception_ptr eptr) {
try {
std::rethrow_exception(eptr);
} catch (const std::exception& e) {
return e.what();
} catch (...) {
return "Unknown Exception Type";
}
}

mutable std::mutex mutex_;
std::atomic_bool completed_ = {false}; // is this future complete
std::condition_variable finished_cv_;

IValue value_; // when finished the value
TypePtr type_;
std::vector<std::function<void(void)>> callbacks_;
c10::optional<FutureError> error_;
std::exception_ptr eptr_;
};

// Input is a list of Futures with the same target type.
4 changes: 2 additions & 2 deletions aten/src/ATen/test/ivalue_test.cpp
@@ -139,10 +139,10 @@ TEST(IValueTest, FutureExceptions) {
}
});
ivalue::Future::FutureError err("My Error");
f3->setError(std::move(err));
f3->setError(std::make_exception_ptr(err));
ASSERT_EQ(calledTimes, 1);
ASSERT_TRUE(f3->hasError());
ASSERT_EQ(std::string(f3->error()->what()), std::string("My Error"));
ASSERT_EQ(f3->tryRetrieveErrorMessage(), std::string("My Error"));
}

TEST(IValueTest, ValueEquality) {
12 changes: 8 additions & 4 deletions test/cpp/jit/test_misc.cpp
@@ -1906,7 +1906,8 @@ void testFutures() {
int sat1 = 0;
int sat2 = 0;
f1->addCallback([&]() { ++sat1; });
f1->setError("Failed");
f1->setError(
std::make_exception_ptr(c10::ivalue::Future::FutureError("Failed")));
ASSERT_EQ(sat1, 1);
ASSERT_TRUE(f1->completed());
ASSERT_TRUE(f1->hasError());
@@ -1920,8 +1921,10 @@
f1->addCallback([&]() { ++sat2; });
ASSERT_EQ(sat1, 1);
ASSERT_EQ(sat2, 1);
f1->setErrorIfNeeded("Dup");
ASSERT_TRUE(strcmp(f1->error()->what(), "Failed") == 0);
f1->setErrorIfNeeded(
std::make_exception_ptr(c10::ivalue::Future::FutureError("Dup")));
ASSERT_TRUE(
strcmp(f1->tryRetrieveErrorMessage().c_str(), "Failed") == 0);
ASSERT_EQ(sat1, 1);
ASSERT_EQ(sat2, 1);
}
@@ -2001,7 +2004,8 @@ void testFutures() {
futures.push_back(s4);
auto c5 = collectAll(futures);
ASSERT_FALSE(c5->completed());
s4->setError("Failed");
s4->setError(
std::make_exception_ptr(c10::ivalue::Future::FutureError("Failed")));
ASSERT_TRUE(c5->completed());
ASSERT_EQ(c5->value().toList().size(), 4);
try {
24 changes: 21 additions & 3 deletions test/test_autograd.py
@@ -253,7 +253,7 @@ def test_custom_function_exception(self):

tmp = (t1 + t2) * (t1 + t2)
t3 = TestAutograd.SimulateBackwardError.apply(tmp)
with self.assertRaisesRegex(RuntimeError, "Simulate error on backward pass"):
with self.assertRaisesRegex(Exception, "Simulate error on backward pass"):
t3.sum().backward()

def test_invalid_gradients(self):
@@ -2313,7 +2313,7 @@ def backward(ctx, grad):
return grad

d = ReentrantFunc.apply(c)
with self.assertRaisesRegex(RuntimeError, 'Simulate error'):
with self.assertRaisesRegex(Exception, 'Simulate error'):
d.sum().backward()

def test_broadcast_tensors(self):
@@ -6103,7 +6103,7 @@ def backward(ctx, grad):
t7 = t6 * t6

# Parent graph will error out first, while child graph will continue executing.
with self.assertRaisesRegex(RuntimeError, "Simulate error"):
with self.assertRaisesRegex(Exception, "Simulate error"):
torch.autograd.backward([t5.sum(), t7.sum()])

# No grads should be accumulated since child graph will stop execution
@@ -6899,6 +6899,24 @@ def train_fn_fork_join_calls_retain(x):
self.assertEqual(grad, grad1)
self.assertEqual(grad, grad2)

def test_preserve_backtrace(self):
class Foo(torch.autograd.Function):
@staticmethod
def forward(ctx, input):
return input

@staticmethod
def backward(ctx, *grad):
raise ValueError("something")

t = torch.rand(10, requires_grad=True)
try:
Foo.apply(t).sum().backward()
except Exception:
import traceback
tb = sys.exc_info()[2]
tb_str = "\n".join(traceback.format_tb(tb))
self.assertTrue('raise ValueError("something")' in tb_str)

for test in method_tests():
add_test(*test)
8 changes: 4 additions & 4 deletions torch/csrc/autograd/engine.cpp
@@ -441,7 +441,7 @@ void Engine::thread_on_exception(
std::shared_ptr<GraphTask> graph_task,
const std::shared_ptr<Node>& fn,
std::exception& e) {
graph_task->set_exception(e, fn);
graph_task->set_exception(std::current_exception(), fn);
}

bool GraphTask::completed() {
Expand Down Expand Up @@ -472,7 +472,7 @@ void GraphTask::mark_as_completed_and_run_post_processing() {
lock.unlock();
future_result_->markCompleted(std::move(vars));
} catch (std::exception& e) {
future_result_->setErrorIfNeeded(e.what());
future_result_->setErrorIfNeeded(std::current_exception());
}
}

@@ -522,11 +522,11 @@ void GraphTask::set_exception_without_signal(const std::shared_ptr<Node>& fn) {
}

void GraphTask::set_exception(
std::exception& e,
std::exception_ptr eptr,
const std::shared_ptr<Node>& fn) {
set_exception_without_signal(fn);
if (!future_completed_.exchange(true)) {
future_result_->setError(e.what());
future_result_->setError(std::move(eptr));
}
}

2 changes: 1 addition & 1 deletion torch/csrc/autograd/engine.h
@@ -129,7 +129,7 @@ struct GraphTask: std::enable_shared_from_this<GraphTask> {

// Set an appropriate exception on this graph_task which was encountered while
// running the provided function.
void set_exception(std::exception& e, const std::shared_ptr<Node>& fn);
void set_exception(std::exception_ptr eptr, const std::shared_ptr<Node>& fn);

// Set an appropriate exception on this graph_task which was encountered while
// running the provided function. But doesn't signal completion on
6 changes: 3 additions & 3 deletions torch/csrc/distributed/autograd/context/context.cpp
@@ -127,17 +127,17 @@ void DistAutogradContext::addOutstandingRpc(
futureMessage->addCallback([this](const rpc::FutureMessage& futureMessage) {
if (futureMessage.hasError()) {
// If we have an error, let the local autograd engine know about it.
std::runtime_error err((*futureMessage.error()).what());
std::unique_lock<std::mutex> lock(lock_);
if (graphTask_) {
graphTask_->set_exception_without_signal(nullptr);
lock.unlock();
if (!graphTask_->future_completed_.exchange(true)) {
graphTask_->future_result_->setErrorIfNeeded(err.what());
graphTask_->future_result_->setErrorIfNeeded(
std::make_exception_ptr(futureMessage.error()));
}
} else {
LOG(WARNING) << "Ignoring error since GraphTask is no longer valid: "
<< err.what();
<< (*futureMessage.error()).what();
}
}
});
46 changes: 24 additions & 22 deletions torch/csrc/distributed/autograd/engine/dist_engine.cpp
@@ -389,29 +389,31 @@ std::shared_ptr<rpc::FutureMessage> DistEngine::runEngineAndAccumulateGradients(
// future that waits for all gradient accumulation to finish.
auto accumulateGradFuture = std::make_shared<rpc::FutureMessage>();

futureGrads->addCallback([autogradContext, outputEdges, accumulateGradFuture, &futureGrads]() {
if (futureGrads->hasError()) {
// Don't accumulate gradients if we receive an error.
// We must add the node information here since DistEngine::execute
// waits on accumulateGradFuture and will throw an exception once we
// set the error below.
std::string errorMsg = c10::str(
"Error on Node ",
DistAutogradContainer::getInstance().getWorkerId(),
": ",
futureGrads->error()->what());
accumulateGradFuture->setError(errorMsg);
return;
}
futureGrads->addCallback(
[autogradContext, outputEdges, accumulateGradFuture, &futureGrads]() {
if (futureGrads->hasError()) {
// Don't accumulate gradients if we receive an error.
// We must add the node information here since DistEngine::execute
// waits on accumulateGradFuture and will throw an exception once we
// set the error below.
std::string errorMsg = c10::str(
"Error on Node ",
DistAutogradContainer::getInstance().getWorkerId(),
": ",
futureGrads->tryRetrieveErrorMessage());
accumulateGradFuture->setError(errorMsg);
return;
}

try {
const variable_list& grads = futureGrads->constValue().toTensorVector();
TORCH_INTERNAL_ASSERT(grads.size() == outputEdges.size());
accumulateGradFuture->markCompleted(rpc::Message());
} catch (std::exception& e) {
accumulateGradFuture->setErrorIfNeeded(e.what());
}
});
try {
const variable_list& grads =
futureGrads->constValue().toTensorVector();
TORCH_INTERNAL_ASSERT(grads.size() == outputEdges.size());
accumulateGradFuture->markCompleted(rpc::Message());
} catch (std::exception& e) {
accumulateGradFuture->setErrorIfNeeded(e.what());
}
});

return accumulateGradFuture;
}
6 changes: 4 additions & 2 deletions torch/csrc/distributed/rpc/python_functions.cpp
@@ -138,7 +138,8 @@ c10::intrusive_ptr<JitFuture> wrapFutureMessageInJitFuture(
at::wrapPropagateTLSState<void>([jitFuture, wp]() {
auto futureResponseMessage = wp.lock();
if (futureResponseMessage->hasError()) {
jitFuture->setError(futureResponseMessage->error()->what());
jitFuture->setError(
std::make_exception_ptr(futureResponseMessage->error()));
} else {
jitFuture->markCompleted(
toIValue(futureResponseMessage->constValue()));
@@ -154,7 +155,8 @@ c10::intrusive_ptr<JitFuture> wrapFutureMessageInJitFuture(
at::wrapPropagateTLSState<void>([wp, jitFuture]() {
auto futureResponseMessage = wp.lock();
if (futureResponseMessage->hasError()) {
jitFuture->setError(futureResponseMessage->error()->what());
jitFuture->setError(
std::make_exception_ptr(*futureResponseMessage->error()));
} else {
jitFuture->markCompleted(IValue());
}
