diff --git a/aten/src/ATen/core/ivalue.cpp b/aten/src/ATen/core/ivalue.cpp index 55da52e30fc49..6726c5c0cad92 100644 --- a/aten/src/ATen/core/ivalue.cpp +++ b/aten/src/ATen/core/ivalue.cpp @@ -770,7 +770,7 @@ CAFFE2_API intrusive_ptr collectAny( ctx->srcFutures = List>(ctx->srcFutures.elementType()); if (src->hasError()) { - dst->setError(*src->error()); + dst->setError(src->exception_ptr()); } else { dst->markCompleted(src->constValue()); } diff --git a/aten/src/ATen/core/ivalue_inl.h b/aten/src/ATen/core/ivalue_inl.h index 5f070909d1103..cee0e8be62487 100644 --- a/aten/src/ATen/core/ivalue_inl.h +++ b/aten/src/ATen/core/ivalue_inl.h @@ -300,26 +300,23 @@ struct C10_EXPORT ivalue::Future : c10::intrusive_ptr_target { markCompleted(IValue {}); } - virtual void setError(std::string err) { - setError(FutureError(std::move(err))); - } - - void setError(FutureError&& error) { + void setError(std::exception_ptr eptr) { std::unique_lock lock(mutex_); - setErrorInternal(std::move(error), lock); + setErrorInternal(std::move(eptr), lock); } - void setErrorIfNeeded(std::string errorMsg) { + void setErrorIfNeeded(std::exception_ptr eptr) { std::unique_lock lock(mutex_); if (completed_) { // This should be rare and shouldn't cause log spew. Its important to // log errors and thats why we have this log here. - LOG(INFO) << "Skipping setting following error on the Future since " << - "it is already marked completed (this is not neccessarily an error): " - << errorMsg; + LOG(INFO) + << "Skipping setting following error on the Future since " + << "it is already marked completed (this is not neccessarily an error): " + << tryRetrieveErrorMessageInternal(eptr); return; } else { - setErrorInternal(FutureError(std::move(errorMsg)), lock); + setErrorInternal(std::move(eptr), lock); } } @@ -327,8 +324,8 @@ struct C10_EXPORT ivalue::Future : c10::intrusive_ptr_target { virtual IValue value() { std::unique_lock lock(mutex_); AT_ASSERT(completed()); - if (error_) { - throw *error_; + if (eptr_) { + std::rethrow_exception(eptr_); } return value_; } @@ -338,7 +335,7 @@ struct C10_EXPORT ivalue::Future : c10::intrusive_ptr_target { virtual const IValue& constValue() { std::unique_lock lock(mutex_); AT_ASSERT(completed()); - AT_ASSERT(!error_); + AT_ASSERT(!eptr_); return value_; } @@ -375,13 +372,20 @@ struct C10_EXPORT ivalue::Future : c10::intrusive_ptr_target { try { fut->markCompleted(cb()); } catch (std::exception& e) { - fut->setError(e.what()); + fut->setError(std::current_exception()); } }, std::move(callback))); return fut; } + // Tries to retrieve the error message from std::exception_ptr. + std::string tryRetrieveErrorMessage() { + TORCH_CHECK(hasError(), "No error present on the future."); + std::unique_lock lock(mutex_); + return tryRetrieveErrorMessageInternal(eptr_); + } + // Check if the current future has completed virtual bool completed() const{ return completed_; @@ -389,17 +393,17 @@ struct C10_EXPORT ivalue::Future : c10::intrusive_ptr_target { virtual bool hasValue() const { std::unique_lock lock(mutex_); - return completed_ && !error_; + return completed_ && !eptr_; } bool hasError() const { std::unique_lock lock(mutex_); - return error_ ? true : false; + return eptr_ ? true : false; } - c10::optional error() const { + std::exception_ptr exception_ptr() const { std::unique_lock lock(mutex_); - return error_; + return eptr_; } CAFFE2_API friend std::ostream& operator<<( @@ -412,11 +416,11 @@ struct C10_EXPORT ivalue::Future : c10::intrusive_ptr_target { private: void setErrorInternal( - FutureError error, + std::exception_ptr eptr, std::unique_lock& lock) { AT_ASSERT(!completed()); completed_ = true; - error_ = std::move(error); + eptr_ = std::move(eptr); std::vector> cbs; cbs.swap(callbacks_); @@ -428,6 +432,17 @@ struct C10_EXPORT ivalue::Future : c10::intrusive_ptr_target { } } + // Tries to retrieve the error message from std::exception_ptr. + std::string tryRetrieveErrorMessageInternal(std::exception_ptr eptr) { + try { + std::rethrow_exception(eptr); + } catch (const std::exception& e) { + return e.what(); + } catch (...) { + return "Unknown Exception Type"; + } + } + mutable std::mutex mutex_; std::atomic_bool completed_ = {false}; // is this future complete std::condition_variable finished_cv_; @@ -435,7 +450,7 @@ struct C10_EXPORT ivalue::Future : c10::intrusive_ptr_target { IValue value_; // when finished the value TypePtr type_; std::vector> callbacks_; - c10::optional error_; + std::exception_ptr eptr_; }; // Input is a list of Futures with the same target type. diff --git a/aten/src/ATen/test/ivalue_test.cpp b/aten/src/ATen/test/ivalue_test.cpp index 52ce92a8e49b1..6474aa45d4dd2 100644 --- a/aten/src/ATen/test/ivalue_test.cpp +++ b/aten/src/ATen/test/ivalue_test.cpp @@ -139,10 +139,10 @@ TEST(IValueTest, FutureExceptions) { } }); ivalue::Future::FutureError err("My Error"); - f3->setError(std::move(err)); + f3->setError(std::make_exception_ptr(err)); ASSERT_EQ(calledTimes, 1); ASSERT_TRUE(f3->hasError()); - ASSERT_EQ(std::string(f3->error()->what()), std::string("My Error")); + ASSERT_EQ(f3->tryRetrieveErrorMessage(), std::string("My Error")); } TEST(IValueTest, ValueEquality) { diff --git a/test/cpp/jit/test_misc.cpp b/test/cpp/jit/test_misc.cpp index d2aa322c2bbd2..2aba55fdf9b82 100644 --- a/test/cpp/jit/test_misc.cpp +++ b/test/cpp/jit/test_misc.cpp @@ -1987,7 +1987,8 @@ void testFutures() { int sat1 = 0; int sat2 = 0; f1->addCallback([&]() { ++sat1; }); - f1->setError("Failed"); + f1->setError( + std::make_exception_ptr(c10::ivalue::Future::FutureError("Failed"))); ASSERT_EQ(sat1, 1); ASSERT_TRUE(f1->completed()); ASSERT_TRUE(f1->hasError()); @@ -2001,8 +2002,9 @@ void testFutures() { f1->addCallback([&]() { ++sat2; }); ASSERT_EQ(sat1, 1); ASSERT_EQ(sat2, 1); - f1->setErrorIfNeeded("Dup"); - ASSERT_TRUE(strcmp(f1->error()->what(), "Failed") == 0); + f1->setErrorIfNeeded( + std::make_exception_ptr(c10::ivalue::Future::FutureError("Dup"))); + ASSERT_TRUE(strcmp(f1->tryRetrieveErrorMessage().c_str(), "Failed") == 0); ASSERT_EQ(sat1, 1); ASSERT_EQ(sat2, 1); } @@ -2082,7 +2084,8 @@ void testFutures() { futures.push_back(s4); auto c5 = collectAll(futures); ASSERT_FALSE(c5->completed()); - s4->setError("Failed"); + s4->setError( + std::make_exception_ptr(c10::ivalue::Future::FutureError("Failed"))); ASSERT_TRUE(c5->completed()); ASSERT_EQ(c5->value().toList().size(), 4); try { diff --git a/test/test_autograd.py b/test/test_autograd.py index b4b8fd766f7c3..037899b18a8d8 100644 --- a/test/test_autograd.py +++ b/test/test_autograd.py @@ -253,7 +253,7 @@ def test_custom_function_exception(self): tmp = (t1 + t2) * (t1 + t2) t3 = TestAutograd.SimulateBackwardError.apply(tmp) - with self.assertRaisesRegex(RuntimeError, "Simulate error on backward pass"): + with self.assertRaisesRegex(Exception, "Simulate error on backward pass"): t3.sum().backward() def test_invalid_gradients(self): @@ -2313,7 +2313,7 @@ def backward(ctx, grad): return grad d = ReentrantFunc.apply(c) - with self.assertRaisesRegex(RuntimeError, 'Simulate error'): + with self.assertRaisesRegex(Exception, 'Simulate error'): d.sum().backward() def test_broadcast_tensors(self): @@ -6168,7 +6168,7 @@ def backward(ctx, grad): t7 = t6 * t6 # Parent graph will error out first, while child graph will continue executing. - with self.assertRaisesRegex(RuntimeError, "Simulate error"): + with self.assertRaisesRegex(Exception, "Simulate error"): torch.autograd.backward([t5.sum(), t7.sum()]) # No grads should be accumulated since child graph will stop execution @@ -6964,6 +6964,24 @@ def train_fn_fork_join_calls_retain(x): self.assertEqual(grad, grad1) self.assertEqual(grad, grad2) + def test_preserve_backtrace(self): + class Foo(torch.autograd.Function): + @staticmethod + def forward(ctx, input): + return input + + @staticmethod + def backward(ctx, *grad): + raise ValueError("something") + + t = torch.rand(10, requires_grad=True) + try: + Foo.apply(t).sum().backward() + except Exception: + import traceback + tb = sys.exc_info()[2] + tb_str = "\n".join(traceback.format_tb(tb)) + self.assertTrue('raise ValueError("something")' in tb_str) for test in method_tests(): add_test(*test) diff --git a/torch/csrc/autograd/engine.cpp b/torch/csrc/autograd/engine.cpp index 17bfd10354932..62ca26e469399 100644 --- a/torch/csrc/autograd/engine.cpp +++ b/torch/csrc/autograd/engine.cpp @@ -442,7 +442,7 @@ void Engine::thread_on_exception( std::shared_ptr graph_task, const std::shared_ptr& fn, std::exception& e) { - graph_task->set_exception(e, fn); + graph_task->set_exception(std::current_exception(), fn); } bool GraphTask::completed() { @@ -473,7 +473,7 @@ void GraphTask::mark_as_completed_and_run_post_processing() { lock.unlock(); future_result_->markCompleted(std::move(vars)); } catch (std::exception& e) { - future_result_->setErrorIfNeeded(e.what()); + future_result_->setErrorIfNeeded(std::current_exception()); } } @@ -523,11 +523,11 @@ void GraphTask::set_exception_without_signal(const std::shared_ptr& fn) { } void GraphTask::set_exception( - std::exception& e, + std::exception_ptr eptr, const std::shared_ptr& fn) { set_exception_without_signal(fn); if (!future_completed_.exchange(true)) { - future_result_->setError(e.what()); + future_result_->setError(std::move(eptr)); } } diff --git a/torch/csrc/autograd/engine.h b/torch/csrc/autograd/engine.h index f72bdec36a9df..0dde6e735d10c 100644 --- a/torch/csrc/autograd/engine.h +++ b/torch/csrc/autograd/engine.h @@ -129,7 +129,7 @@ struct GraphTask: std::enable_shared_from_this { // Set an appropriate exception on this graph_task which was encountered while // running the provided function. - void set_exception(std::exception& e, const std::shared_ptr& fn); + void set_exception(std::exception_ptr eptr, const std::shared_ptr& fn); // Set an appropriate exception on this graph_task which was encountered while // running the provided function. But doesn't signal completion on diff --git a/torch/csrc/distributed/autograd/context/context.cpp b/torch/csrc/distributed/autograd/context/context.cpp index 3401258e7cae5..a74a596d88ad6 100644 --- a/torch/csrc/distributed/autograd/context/context.cpp +++ b/torch/csrc/distributed/autograd/context/context.cpp @@ -127,17 +127,17 @@ void DistAutogradContext::addOutstandingRpc( futureMessage->addCallback([this](const rpc::FutureMessage& futureMessage) { if (futureMessage.hasError()) { // If we have an error, let the local autograd engine know about it. - std::runtime_error err((*futureMessage.error()).what()); std::unique_lock lock(lock_); if (graphTask_) { graphTask_->set_exception_without_signal(nullptr); lock.unlock(); if (!graphTask_->future_completed_.exchange(true)) { - graphTask_->future_result_->setErrorIfNeeded(err.what()); + graphTask_->future_result_->setErrorIfNeeded( + std::make_exception_ptr(*futureMessage.error())); } } else { LOG(WARNING) << "Ignoring error since GraphTask is no longer valid: " - << err.what(); + << (*futureMessage.error()).what(); } } }); diff --git a/torch/csrc/distributed/autograd/engine/dist_engine.cpp b/torch/csrc/distributed/autograd/engine/dist_engine.cpp index ccf373224c722..71ac010bf19a3 100644 --- a/torch/csrc/distributed/autograd/engine/dist_engine.cpp +++ b/torch/csrc/distributed/autograd/engine/dist_engine.cpp @@ -389,29 +389,31 @@ std::shared_ptr DistEngine::runEngineAndAccumulateGradients( // future that waits for all gradient accumulation to finish. auto accumulateGradFuture = std::make_shared(); - futureGrads->addCallback([autogradContext, outputEdges, accumulateGradFuture, &futureGrads]() { - if (futureGrads->hasError()) { - // Don't accumulate gradients if we receive an error. - // We must add the node information here since DistEngine::execute - // waits on accumulateGradFuture and will throw an exception once we - // set the error below. - std::string errorMsg = c10::str( - "Error on Node ", - DistAutogradContainer::getInstance().getWorkerId(), - ": ", - futureGrads->error()->what()); - accumulateGradFuture->setError(errorMsg); - return; - } + futureGrads->addCallback( + [autogradContext, outputEdges, accumulateGradFuture, &futureGrads]() { + if (futureGrads->hasError()) { + // Don't accumulate gradients if we receive an error. + // We must add the node information here since DistEngine::execute + // waits on accumulateGradFuture and will throw an exception once we + // set the error below. + std::string errorMsg = c10::str( + "Error on Node ", + DistAutogradContainer::getInstance().getWorkerId(), + ": ", + futureGrads->tryRetrieveErrorMessage()); + accumulateGradFuture->setError(errorMsg); + return; + } - try { - const variable_list& grads = futureGrads->constValue().toTensorVector(); - TORCH_INTERNAL_ASSERT(grads.size() == outputEdges.size()); - accumulateGradFuture->markCompleted(rpc::Message()); - } catch (std::exception& e) { - accumulateGradFuture->setErrorIfNeeded(e.what()); - } - }); + try { + const variable_list& grads = + futureGrads->constValue().toTensorVector(); + TORCH_INTERNAL_ASSERT(grads.size() == outputEdges.size()); + accumulateGradFuture->markCompleted(rpc::Message()); + } catch (std::exception& e) { + accumulateGradFuture->setErrorIfNeeded(e.what()); + } + }); return accumulateGradFuture; } diff --git a/torch/csrc/distributed/rpc/python_functions.cpp b/torch/csrc/distributed/rpc/python_functions.cpp index e42b466896797..b7c16639b19b1 100644 --- a/torch/csrc/distributed/rpc/python_functions.cpp +++ b/torch/csrc/distributed/rpc/python_functions.cpp @@ -138,7 +138,8 @@ c10::intrusive_ptr wrapFutureMessageInJitFuture( at::wrapPropagateTLSState([jitFuture, wp]() { auto futureResponseMessage = wp.lock(); if (futureResponseMessage->hasError()) { - jitFuture->setError(futureResponseMessage->error()->what()); + jitFuture->setError( + std::make_exception_ptr(*futureResponseMessage->error())); } else { jitFuture->markCompleted( toIValue(futureResponseMessage->constValue())); @@ -154,7 +155,8 @@ c10::intrusive_ptr wrapFutureMessageInJitFuture( at::wrapPropagateTLSState([wp, jitFuture]() { auto futureResponseMessage = wp.lock(); if (futureResponseMessage->hasError()) { - jitFuture->setError(futureResponseMessage->error()->what()); + jitFuture->setError( + std::make_exception_ptr(*futureResponseMessage->error())); } else { jitFuture->markCompleted(IValue()); } diff --git a/torch/csrc/distributed/rpc/request_callback_impl.cpp b/torch/csrc/distributed/rpc/request_callback_impl.cpp index 821d63011607e..b68cb4092b678 100644 --- a/torch/csrc/distributed/rpc/request_callback_impl.cpp +++ b/torch/csrc/distributed/rpc/request_callback_impl.cpp @@ -274,7 +274,7 @@ void RequestCallbackImpl::processScriptRemoteCall( try { ownerRRef->setValue(jitFuture->value()); } catch (const std::exception& e) { - ownerRRef->setError(e.what()); + ownerRRef->setError(std::current_exception()); } postProcessing(); }; @@ -297,7 +297,7 @@ void RequestCallbackImpl::processScriptRemoteCall( setRRefValue(valueJitFuture); }); } catch (const std::exception& e) { - ownerRRef->setError(e.what()); + ownerRRef->setError(std::current_exception()); postProcessing(); } } else { @@ -380,7 +380,7 @@ void RequestCallbackImpl::processPythonRemoteCall( responseFuture->markCompleted(std::move(m)); }); } catch (std::exception& e) { - ownerRRef->setError(e.what()); + ownerRRef->setError(std::current_exception()); auto m = RemoteRet(rrefId, forkId).toMessage(); m.setId(messageId); responseFuture->markCompleted(std::move(m)); @@ -397,12 +397,12 @@ void RequestCallbackImpl::processPythonRemoteCall( ownerRRef->setValue(std::move(py_ivalue)); } catch (py::error_already_set& e) { // py::error_already_set requires GIL to destruct, take special care. - ownerRRef->setError(e.what()); + ownerRRef->setError(std::current_exception()); py::gil_scoped_acquire acquire; e.restore(); PyErr_Clear(); } catch (std::exception& e) { - ownerRRef->setError(e.what()); + ownerRRef->setError(std::current_exception()); } markComplete(RemoteRet(rrefId, forkId).toMessage()); } @@ -418,7 +418,7 @@ void RequestCallbackImpl::processPythonRRefFetchCall( int64_t messageId) mutable { auto whenValueSet = rref->getFuture(); if (whenValueSet->hasError()) { - responseFuture->setError(whenValueSet->error()->what()); + responseFuture->setError(whenValueSet->tryRetrieveErrorMessage()); return; } try { diff --git a/torch/csrc/distributed/rpc/request_callback_no_python.cpp b/torch/csrc/distributed/rpc/request_callback_no_python.cpp index 4e085e269f2bf..895073a28377e 100644 --- a/torch/csrc/distributed/rpc/request_callback_no_python.cpp +++ b/torch/csrc/distributed/rpc/request_callback_no_python.cpp @@ -192,7 +192,7 @@ bool RequestCallbackNoPython::processScriptRemoteCallOp( } catch (const std::exception& e) { // Don't throw in this call, but rather transfer the exception // to the rref. - ownerRRef->setError(e.what()); + ownerRRef->setError(std::current_exception()); postProcessing(); return true; } @@ -321,7 +321,8 @@ void RequestCallbackNoPython::processRpc( whenValueSet->addCallback( [responseFuture, messageId, rref, whenValueSet]() { if (whenValueSet->hasError()) { - responseFuture->setError(whenValueSet->error()->what()); + responseFuture->setError( + whenValueSet->tryRetrieveErrorMessage()); return; } try { diff --git a/torch/csrc/distributed/rpc/rref_impl.cpp b/torch/csrc/distributed/rpc/rref_impl.cpp index 2f64231a9129e..34249172473c8 100644 --- a/torch/csrc/distributed/rpc/rref_impl.cpp +++ b/torch/csrc/distributed/rpc/rref_impl.cpp @@ -254,8 +254,8 @@ void OwnerRRef::setValue(IValue&& value) { future_->markCompleted(value); } -void OwnerRRef::setError(const std::string& error) { - future_->setErrorIfNeeded(error); +void OwnerRRef::setError(std::exception_ptr eptr) { + future_->setErrorIfNeeded(std::move(eptr)); } std::ostream& operator<<(std::ostream& os, const RRef& rref) { diff --git a/torch/csrc/distributed/rpc/rref_impl.h b/torch/csrc/distributed/rpc/rref_impl.h index 7a3cc6932cb9b..29aa355908fa9 100644 --- a/torch/csrc/distributed/rpc/rref_impl.h +++ b/torch/csrc/distributed/rpc/rref_impl.h @@ -384,7 +384,7 @@ class TORCH_API OwnerRRef final : public RRef { // does not create any new py::object. void setValue(IValue&& value); // Sets the value of this ``OwnerRRef`` to contain an exception. - void setError(const std::string& err); + void setError(std::exception_ptr eptr); // Has a value or error been set? bool hasValue() const; diff --git a/torch/csrc/distributed/rpc/torchscript_functions.cpp b/torch/csrc/distributed/rpc/torchscript_functions.cpp index c5e65621494bc..fec3f990774b2 100644 --- a/torch/csrc/distributed/rpc/torchscript_functions.cpp +++ b/torch/csrc/distributed/rpc/torchscript_functions.cpp @@ -68,7 +68,7 @@ c10::intrusive_ptr rpcTorchscript( auto futMessage = wp.lock(); if (futMessage->hasError()) { c10::ivalue::Future::FutureError jitFutErr(futMessage->error()->what()); - futPtr->setError(std::move(jitFutErr)); + futPtr->setError(std::make_exception_ptr(jitFutErr)); } else { futPtr->markCompleted(deserializeRespToIValue(futMessage->constValue())); } diff --git a/torch/csrc/jit/runtime/interpreter.cpp b/torch/csrc/jit/runtime/interpreter.cpp index f7e7a3bd3b5ac..337fe66c07897 100644 --- a/torch/csrc/jit/runtime/interpreter.cpp +++ b/torch/csrc/jit/runtime/interpreter.cpp @@ -1571,7 +1571,7 @@ struct InterpreterStateImpl : c10::intrusive_ptr_target { formatStackTrace(ss); ss << "RuntimeError: " << msg << "\n"; if (future_) { - future_->setError(Future::FutureError(ss.str())); + future_->setError(std::make_exception_ptr(Future::FutureError(ss.str()))); } else if (is_jit_exception) { throw JITException(ss.str()); } else { diff --git a/torch/lib/c10d/ProcessGroupNCCL.hpp b/torch/lib/c10d/ProcessGroupNCCL.hpp index 3ee8bb4adf63d..0ac6f80c1936d 100644 --- a/torch/lib/c10d/ProcessGroupNCCL.hpp +++ b/torch/lib/c10d/ProcessGroupNCCL.hpp @@ -224,10 +224,6 @@ class ProcessGroupNCCL : public ProcessGroup { value_ = std::move(value); } - void setError(std::string err) override { - error_ = FutureError(std::move(err)); - } - // Just returns FutureNCCL's value after wait returns. at::IValue value() override { TORCH_INTERNAL_ASSERT(hasValue(), "FutureNCCL's value is None.") @@ -279,7 +275,7 @@ class ProcessGroupNCCL : public ProcessGroup { // records callback's stream. (*thenFutCudaEvents)[0].record(*futureNCCLCallbackStream_); } catch (const std::exception& e) { - fut->setError(e.what()); + fut->setError(std::current_exception()); } }, std::move(callback)));