Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix anomaly mode memory leak #51610

Closed
wants to merge 7 commits into from
Closed
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
39 changes: 39 additions & 0 deletions test/test_autograd.py
Expand Up @@ -3799,6 +3799,45 @@ def __exit__(self, *args):
self.assertIn('Anomaly Detection has been enabled', str(w[0].message))
self.assertIn('Error detected in PowBackward0', s.captured)

def test_anomaly_pyobject_cleanup(self):
    # Test that python objects created are properly cleaned up: an object stored in
    # t.grad_fn's metadata dict must not stay alive after `t` has been deleted.
    import weakref

    def get_ref():
        # Returns `t` together with a weak reference to an object stored in
        # t.grad_fn's metadata dict.
        # NOTE(review): the indentation of this block was reconstructed; the exact
        # extent of the `with detect_anomaly():` body should be confirmed against
        # the upstream test.
        with detect_anomaly():
            # we use torch.exp here but any function that will construct a new node in its
            # backward call in grad mode will work
            x = torch.randn(2, 2, requires_grad=True)
            t = x.exp()

            # ExpBackward calls mul, creating the MulBackward node when create_graph=True.
            # In anomaly mode, a PyObject referencing MulBackward's "parent" ExpBackward is added to
            # MulBackward's anomaly metadata dict, creating the following reference chain:
            #
            # grad -> MulBackward -> PyObject -> ExpBackward
            #
            grad = torch.autograd.grad(t, x, torch.ones_like(t), create_graph=True)

            # We add a weak reference to a new Foo instance, which we insert into ExpBackward's metadata dict
            #
            # (PyObject) -> ExpBackward -> dict -> *Foo*
            #          t ----^            WeakRef ---^
            #
            # We want to test that PyObject is destroyed when grad goes out of scope at the end of this function.
            # We check this by verifying that Foo is no longer kept alive once t is destroyed.
            meta_dict = t.grad_fn.metadata
            class Foo(object):
                pass
            my_obj = Foo()
            meta_dict[0] = my_obj
            # Only a weak reference leaves this function, so the caller does not keep Foo alive itself.
            ref = weakref.ref(my_obj)
            return t, ref

    t, ref = get_ref()
    # While `t` is still held, the weak reference must still resolve to the Foo instance.
    self.assertIsNotNone(ref())
    del t
    # After `t` is deleted the weak reference must be dead; if it still resolves,
    # something is keeping the Foo instance alive.
    self.assertIsNone(ref())

@skipIfNoLapack
def test_eig_no_eigenvectors(self):
A = torch.tensor([[1., 2.], [2., 4.]], dtype=torch.float32, requires_grad=True)
Expand Down
16 changes: 8 additions & 8 deletions torch/csrc/autograd/python_anomaly_mode.cpp
Expand Up @@ -41,23 +41,23 @@ void PyAnomalyMetadata::print_stack(const std::string& current_node_name) {
// if there is no "parent_" in metadata, then it means this metadata's node
// is the root and stop printing the traceback
while (pyparent) {
PyObject* parent_metadata(PyObject_GetAttrString(pyparent, "metadata"));
THPObjectPtr parent_metadata(PyObject_GetAttrString(pyparent, "metadata"));
if (!parent_metadata) {
throw python_error();
}
PyObject* parent_name_pyobj(PyObject_CallMethod(pyparent, "name", ""));
THPObjectPtr parent_name_pyobj(PyObject_CallMethod(pyparent, "name", ""));
if (!parent_name_pyobj) {
throw python_error();
}
const char* parent_name_char = PyUnicode_AsUTF8(parent_name_pyobj);
const char* parent_name_char = PyUnicode_AsUTF8(parent_name_pyobj.get());
if (!parent_name_char) {
throw python_error();
}
const std::string parent_name(parent_name_char);
PyObject* parent_stack = PyDict_GetItemString(parent_metadata, ANOMALY_TRACE_KEY);
PyObject* parent_stack = PyDict_GetItemString(parent_metadata.get(), ANOMALY_TRACE_KEY);
_print_stack(parent_stack, parent_name, true);
// get the parent of this node, if this node is a root, pyparent is simply null
pyparent = PyDict_GetItemString(parent_metadata, ANOMALY_PARENT_KEY);
pyparent = PyDict_GetItemString(parent_metadata.get(), ANOMALY_PARENT_KEY);
}
}

Expand All @@ -69,11 +69,11 @@ void PyAnomalyMetadata::assign_parent(const std::shared_ptr<Node>& parent_node)
pybind11::gil_scoped_acquire gil;
if (!parent_node) return;

PyObject* pyobj = functionToPyObject(parent_node);
if (!pyobj) {
THPObjectPtr parent_node_(functionToPyObject(parent_node));
if (!parent_node_) {
throw python_error();
}
if (PyDict_SetItemString(dict(), ANOMALY_PARENT_KEY, pyobj)) {
if (PyDict_SetItemString(dict(), ANOMALY_PARENT_KEY, parent_node_.get())) {
throw python_error();
}
}
Expand Down