65 changes: 65 additions & 0 deletions test/test_autograd.py
@@ -3460,6 +3460,71 @@ def backward(ctx, gO):
out.backward()
self.assertIn('MyFunc.apply', str(w[0].message))

def test_nested_anomaly_detect_nan(self):
size = 10

class MyFunc(Function):
@staticmethod
def forward(ctx, inp1, fail_0th):
ctx.fail_0th = fail_0th
ctx.save_for_backward(inp1)
return inp1.sum(0, keepdim=True)

@staticmethod
def backward(ctx, gO):
inp, = ctx.saved_tensors
fail_0th = ctx.fail_0th
g = gO.clone().expand(size)
gI = MyFunc2.apply(g * inp, g + inp, fail_0th)
return gI, None

class MyFunc2(Function):
@staticmethod
def forward(ctx, inp1, inp2, fail_0th):
ctx.fail_0th = fail_0th
return inp1 * 2.0 + inp2

@staticmethod
def backward(ctx, gO):
fail_0th = ctx.fail_0th
g1 = gO.clone()
g2 = gO.clone()
g1[0] = 0
g2[0] = 0
# generate a nan
if fail_0th:
g1[0] /= 0
else:
g2[0] /= 0
return g1, g2, None

inp = torch.rand(size, requires_grad=True)
out = MyFunc.apply(inp, True)
ginp, = torch.autograd.grad(out, (inp,), create_graph=True)
gsum = ginp.sum()
gsum.backward() # should not fail

inp = torch.rand(size, requires_grad=True)
out = MyFunc.apply(inp, True)
ginp, = torch.autograd.grad(out, (inp,), create_graph=True)
gsum = ginp.sum()
with warnings.catch_warnings(record=True) as w:
with self.assertRaisesRegex(RuntimeError, "Function 'MyFunc2Backward' returned nan values in its 0th output."):
with detect_anomaly():
gsum.backward()
self.assertIn('No forward pass information', str(w[1].message))

inp = torch.rand(size, requires_grad=True)
with warnings.catch_warnings(record=True) as w:
with self.assertRaisesRegex(RuntimeError, "Function 'MyFunc2Backward' returned nan values in its 1th output."):
with detect_anomaly():
out = MyFunc.apply(inp, False)
ginp, = torch.autograd.grad(out, (inp,), create_graph=True)
gsum = ginp.sum()
gsum.backward()
self.assertIn('MyFunc2.apply', str(w[1].message))
self.assertIn('MyFunc.apply', str(w[2].message))

def test_anomaly_grad_warnings(self):
# PyTorch won't throw warnings if there is an error
# but we'd want to at least see them in stderr
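The new test above drives the nested case: a NaN produced by a custom Function that is itself invoked from another Function's backward. For reference, here is a condensed sketch of the single-level behaviour this PR builds on (not part of the diff; BadGrad is a made-up example class): anomaly mode names the backward node that produced the NaN and warns with the matching forward traceback.

import torch
from torch.autograd import Function, detect_anomaly

class BadGrad(Function):
    @staticmethod
    def forward(ctx, inp):
        return inp.clone()

    @staticmethod
    def backward(ctx, grad_out):
        g = grad_out.clone()
        g[0] = 0
        g[0] /= 0  # 0/0 -> NaN, produced only during the backward pass
        return g

x = torch.zeros(4, requires_grad=True)
with detect_anomaly():
    try:
        BadGrad.apply(x).sum().backward()
    except RuntimeError as err:
        # e.g. "Function 'BadGradBackward' returned nan values in its 0th output."
        print(err)

The nested test additionally checks that the parent chain introduced by the C++ changes below turns this single warning into a sequence: the failing node first, then each forward call that induced it.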
5 changes: 5 additions & 0 deletions torch/csrc/autograd/anomaly_mode.h
@@ -1,10 +1,14 @@
#pragma once

#include <string>
#include <memory>
#include <torch/csrc/WindowsTorchApiMacro.h>

namespace torch { namespace autograd {

// forward declaration of Node from function.h
struct Node;

struct TORCH_API AnomalyMode {
static bool is_enabled() {
return _enabled;
@@ -22,6 +26,7 @@ struct TORCH_API AnomalyMetadata {
virtual ~AnomalyMetadata();
virtual void store_stack() = 0;
virtual void print_stack(const std::string& current_node_name) = 0;
virtual void assign_parent(const std::shared_ptr<Node>& parent_node) = 0;
};

}}
1 change: 1 addition & 0 deletions torch/csrc/autograd/engine.cpp
@@ -375,6 +375,7 @@ auto Engine::thread_main(const std::shared_ptr<GraphTask>& graph_task) -> void {
// queue_callback() to find the target GraphTask to append final
// callbacks.
GraphTaskGuard guard(local_graph_task);
NodeGuard ndguard(task.fn_);
evaluate_function(local_graph_task, task.fn_.get(), task.inputs_, local_graph_task->cpu_ready_queue_);
} catch (std::exception& e) {
thread_on_exception(local_graph_task, task.fn_, e);
18 changes: 18 additions & 0 deletions torch/csrc/autograd/function.cpp
@@ -15,6 +15,24 @@

namespace torch { namespace autograd {

// The node that is currently being evaluated. In anomaly mode, it is assigned
// as the parent of any new node created while this node's backward is being
// evaluated.
static thread_local std::shared_ptr<Node> current_evaluating_node = nullptr;

NodeGuard::NodeGuard(std::shared_ptr<Node> node) {
last_evaluating_node_ = std::move(current_evaluating_node);
current_evaluating_node = std::move(node);
}
NodeGuard::~NodeGuard() {
// restore the previous evaluating node
current_evaluating_node = std::move(last_evaluating_node_);
}

void Node::assign_parent() {
metadata()->assign_parent(current_evaluating_node);
}

auto Node::name() const -> std::string {
return c10::demangle(typeid(*this).name());
}
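The NodeGuard above is a plain RAII guard around a thread-local pointer. As a rough Python analogue, only to illustrate the idea (the names below are invented; the real guard is the C++ class in this file): a thread-local "currently evaluating node" is set while one backward node runs and restored afterwards, so any node created during that evaluation can record it as its parent.

import threading

_tls = threading.local()

def current_evaluating_node():
    return getattr(_tls, "current_node", None)

class node_guard:
    """Sets the thread-local current node for the duration of one evaluation."""
    def __init__(self, node):
        self._node = node
    def __enter__(self):
        self._prev = current_evaluating_node()
        _tls.current_node = self._node
    def __exit__(self, *exc):
        # restore the previously evaluating node
        _tls.current_node = self._prev

with node_guard("MyFuncBackward"):
    assert current_evaluating_node() == "MyFuncBackward"
assert current_evaluating_node() is None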
19 changes: 19 additions & 0 deletions torch/csrc/autograd/function.h
@@ -37,6 +37,16 @@ using IndexRange = std::pair<size_t, size_t>;
// Custom deleter to prevent stack overflows.
TORCH_API void deleteNode(Node* function);

// Guard that sets and restores the evaluating node
class NodeGuard {
public:
explicit NodeGuard(std::shared_ptr<Node> node);
~NodeGuard();

private:
std::shared_ptr<Node> last_evaluating_node_;
};

//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// Node
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -97,6 +107,12 @@ struct TORCH_API Node : std::enable_shared_from_this<Node> {
next_edges_(std::move(next_edges)) {
if (AnomalyMode::is_enabled()) {
metadata()->store_stack();

      // If anomaly mode is enabled and the graph is being constructed, assign
      // the currently evaluating node as the parent of this node.
      // A parent is the Node that was being evaluated when this Node was
      // created; tracking parents lets anomaly mode report nested backward
      // passes together.
assign_parent();
}
}

@@ -222,6 +238,9 @@ struct TORCH_API Node : std::enable_shared_from_this<Node> {
return sequence_nr_;
}

  // Assigns the currently evaluating node as the parent of this node.
void assign_parent();

/// Returns the name of the dynamic type of the function, for debugging.
virtual std::string name() const;

60 changes: 55 additions & 5 deletions torch/csrc/autograd/python_anomaly_mode.cpp
@@ -2,6 +2,7 @@
#include <c10/util/Exception.h>
#include <pybind11/pybind11.h>
#include <torch/csrc/Exceptions.h>
#include <torch/csrc/autograd/python_cpp_function.h>
#include <torch/csrc/python_headers.h>
#include <torch/csrc/utils/auto_gil.h>
#include <torch/csrc/utils/object_ptr.h>
@@ -33,9 +34,51 @@ void PyAnomalyMetadata::print_stack(const std::string& current_node_name) {
if (!PyDict_Check(dict())) {
throw std::runtime_error("Anomaly metadata is not a python dictionary.");
}
  // PyDict_GetItemString returns a borrowed reference
  PyObject* trace_stack = PyDict_GetItemString(dict(), ANOMALY_TRACE_KEY);
  _print_stack(trace_stack, current_node_name, false);
  PyObject* pyparent(PyDict_GetItemString(dict(), ANOMALY_PARENT_KEY));

  // if there is no "parent_" in metadata, this metadata's node is the root,
  // so stop printing the traceback
while (pyparent) {
PyObject* parent_metadata(PyObject_GetAttrString(pyparent, "metadata"));
if (!parent_metadata) {
throw python_error();
}
PyObject* parent_name_pyobj(PyObject_CallMethod(pyparent, "name", ""));
if (!parent_name_pyobj) {
throw python_error();
}
const char* parent_name_char = PyUnicode_AsUTF8(parent_name_pyobj);
if (!parent_name_char) {
throw python_error();
}
const std::string parent_name(parent_name_char);
PyObject* parent_stack = PyDict_GetItemString(parent_metadata, ANOMALY_TRACE_KEY);
_print_stack(parent_stack, parent_name, true);
// get the parent of this node, if this node is a root, pyparent is simply null
pyparent = PyDict_GetItemString(parent_metadata, ANOMALY_PARENT_KEY);
}
}

void PyAnomalyMetadata::assign_parent(const std::shared_ptr<Node>& parent_node) {
  // store the Python object for parent_node under metadata["parent_"];
  // if parent_node is nullptr, do nothing: the "parent_" key is simply left
  // out of the metadata, which marks this node as a root

pybind11::gil_scoped_acquire gil;
if (!parent_node) return;

PyObject* pyobj = functionToPyObject(parent_node);
if (!pyobj) {
throw python_error();
}
if (PyDict_SetItemString(dict(), ANOMALY_PARENT_KEY, pyobj)) {
throw python_error();
}
}

void _print_stack(PyObject* stack, const std::string& current_node_name, bool is_parent) {
if (!stack) {
TORCH_WARN("Error detected in ", current_node_name, ". ",
"No forward pass information available. Enable detect anomaly "
@@ -55,9 +98,16 @@ void PyAnomalyMetadata::print_stack(const std::string& current_node_name) {
throw python_error();
}

TORCH_WARN("Error detected in ", current_node_name, ". ",
"Traceback of forward call that caused the error:\n",
THPUtils_unpackString(msg.get()));
if (!is_parent) {
TORCH_WARN("Error detected in ", current_node_name, ". ",
"Traceback of forward call that caused the error:\n",
THPUtils_unpackString(msg.get()));
} else {
TORCH_WARN("\n\n",
"Previous calculation was induced by ", current_node_name, ". "
"Traceback of forward call that induced the previous calculation:\n",
THPUtils_unpackString(msg.get()));
}
}

}}
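To make the metadata layout concrete, here is a self-contained Python rendering of the parent walk implemented above (an editorial sketch, not PyTorch API; FakeNode and warn_stack are stand-ins): each node's metadata dict may hold a "traceback_" entry and, with this PR, a "parent_" entry pointing at the node that was being evaluated when it was created.

import warnings
from dataclasses import dataclass, field

@dataclass
class FakeNode:
    name: str
    metadata: dict = field(default_factory=dict)  # holds "traceback_" / "parent_"

def warn_stack(stack, node_name, is_parent):
    if is_parent:
        warnings.warn(f"Previous calculation was induced by {node_name}:\n{stack}")
    else:
        warnings.warn(f"Error detected in {node_name}:\n{stack}")

def print_stack_chain(node):
    warn_stack(node.metadata.get("traceback_"), node.name, is_parent=False)
    parent = node.metadata.get("parent_")
    while parent is not None:  # a root node has no "parent_" entry
        warn_stack(parent.metadata.get("traceback_"), parent.name, is_parent=True)
        parent = parent.metadata.get("parent_")

root = FakeNode("MyFuncBackward", {"traceback_": "<forward stack of MyFunc>"})
child = FakeNode("MyFunc2Backward",
                 {"traceback_": "<forward stack of MyFunc2>", "parent_": root})
print_stack_chain(child)  # warns for MyFunc2Backward, then for its parent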
3 changes: 3 additions & 0 deletions torch/csrc/autograd/python_anomaly_mode.h
@@ -9,6 +9,7 @@ namespace torch { namespace autograd {

struct PyAnomalyMetadata : public AnomalyMetadata {
static constexpr char* ANOMALY_TRACE_KEY = "traceback_";
static constexpr char* ANOMALY_PARENT_KEY = "parent_";

PyAnomalyMetadata() {
pybind11::gil_scoped_acquire gil;
@@ -20,6 +21,7 @@ struct PyAnomalyMetadata : public AnomalyMetadata {
}
void store_stack() override;
void print_stack(const std::string& current_node_name) override;
void assign_parent(const std::shared_ptr<Node>& parent_node) override;

PyObject* dict() {
return dict_;
@@ -28,5 +30,6 @@
private:
PyObject* dict_;
};
void _print_stack(PyObject* trace_stack, const std::string& current_node_name, bool is_parent);

}}
9 changes: 9 additions & 0 deletions torch/csrc/autograd/python_function.cpp
@@ -21,6 +21,7 @@
#include <torch/csrc/jit/frontend/tracer.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/python/python_tracer.h>
#include <torch/csrc/utils/python_strings.h>
#include <torch/csrc/DynamicTypes.h>
#include <torch/csrc/Exceptions.h>

@@ -610,6 +611,13 @@ PyObject* process_outputs(PyObject *op_obj, const std::shared_ptr<PyNode>& cdata
return outputs.release();
}

PyObject* THPFunction_name(THPFunction *self, PyObject* noargs) {
HANDLE_TH_ERRORS
auto cdata = self->cdata.lock();
return THPUtils_packString(cdata->name());
END_HANDLE_TH_ERRORS
}

PyObject *THPFunction_apply(PyObject *cls, PyObject *inputs)
{
HANDLE_TH_ERRORS
Expand Down Expand Up @@ -1004,6 +1012,7 @@ static struct PyGetSetDef THPFunction_properties[] = {
};

static struct PyMethodDef THPFunction_methods[] = {
{(char*)"name", (PyCFunction)THPFunction_name, METH_NOARGS, nullptr},
{(char*)"apply", (PyCFunction)THPFunction_apply, METH_CLASS | METH_VARARGS, nullptr},
{(char*)"_do_backward", (PyCFunction)THPFunction_do_backward, METH_VARARGS, nullptr},
{(char*)"_register_hook_dict", (PyCFunction)THPFunction__register_hook_dict, METH_O, nullptr},
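The name method registered above is what lets print_stack call PyObject_CallMethod(pyparent, "name", "") on a Python custom Function node. A quick check from Python, assuming this patch is applied (Twice is a made-up example class; the exact string comes from the node's dynamic type):

import torch
from torch.autograd import Function

class Twice(Function):
    @staticmethod
    def forward(ctx, x):
        return 2 * x

    @staticmethod
    def backward(ctx, g):
        return 2 * g

y = Twice.apply(torch.ones(2, requires_grad=True))
print(y.grad_fn.name())  # expected to print something like "TwiceBackward"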