Skip to content

Commit

Permalink
Fix anomaly mode memory leak (#51610)
Browse files Browse the repository at this point in the history
Summary:
Fixes #51349

The memory leak happens when 1) `create_graph` is True AND 2) detect anomaly mode is on. When a backward node's constructor is called during backward, the currently evaluating node is assigned as the "parent" of the newly created node. The code that assigns the parent has the following issue:

`functionToPyObject(parent_node)` returns a new PyObject (with refcount 1) or, if the PyObject already exists, increments its refcount by 1. However, [PyDict_SetItem](https://github.com/python/cpython/blob/1b55b65638254aa78b005fbf0b71fb02499f1852/Objects/dictobject.c#L1532) calls into [insertdict](https://github.com/python/cpython/blob/v3.8.1/Objects/dictobject.c#L1034), which increments the refcount again. This means that when the dict is destroyed, the refcount of the PyObject is still at least one. This keeps `parent_node` (the backward function) alive, which in turn keeps the saved tensor alive.

Similar calls to `functionToPyObject` elsewhere in the codebase don't require a Py_DECREF when the result is then stored in a tuple (instead of a dict), because the analogous PyTuple_SetItem call does not increment the refcount (it steals the reference instead).

Pull Request resolved: #51610

Reviewed By: albanD

Differential Revision: D26240336

Pulled By: soulitzer

fbshipit-source-id: 2854528f66fab9dbce448f8a7ba732ce386a7310
  • Loading branch information
soulitzer authored and facebook-github-bot committed Feb 4, 2021
1 parent 0222966 commit 2e8e560
Show file tree
Hide file tree
Showing 2 changed files with 106 additions and 8 deletions.
98 changes: 98 additions & 0 deletions test/test_autograd.py
Expand Up @@ -3690,6 +3690,104 @@ def __exit__(self, *args):
self.assertIn('Anomaly Detection has been enabled', str(w[0].message))
self.assertIn('Error detected in PowBackward0', s.captured)

def test_anomaly_assign_parent_cleanup(self):
# Test that the Python objects created when assign_parent is called are properly
# cleaned up, i.e. that the "parent" PyObject stored in a node's anomaly metadata
# dict does not keep the parent node alive after everything else has let go of it.
import weakref

def get_ref():
# Build the graph, plant a weakly-referenced sentinel object in ExpBackward's
# metadata dict, and return the forward output together with the weak reference.
#
# We use torch.exp here, but any function that constructs a new node in its
# backward call in grad mode will work.
x = torch.randn(2, 2, requires_grad=True)
t = x.exp()

# ExpBackward calls mul, creating the MulBackward node when create_graph=True.
# In anomaly mode, a PyObject referencing MulBackward's "parent" ExpBackward is added to
# MulBackward's anomaly metadata dict, creating the following reference chain:
#
# grad -> MulBackward -> PyObject -> ExpBackward
#
# `grad` is deliberately bound to a local so that the chain above stays alive
# until get_ref returns.
with detect_anomaly():
grad = torch.autograd.grad(t, x, torch.ones_like(t), create_graph=True)

# We create a new Foo object, insert it into ExpBackward's metadata dict, and keep
# only a weak reference to it:
#
# (PyObject) -> ExpBackward -> dict -> *Foo*
# t ----^ WeakRef ---^
#
# We want to test that PyObject is destroyed when grad goes out of scope at the
# end of this function. We can test this by checking that Foo is no longer kept
# alive once t is destroyed.
class Foo(object):
pass
my_obj = Foo()
meta_dict = t.grad_fn.metadata
meta_dict[0] = my_obj
ref = weakref.ref(my_obj)
return t, ref

t, ref = get_ref()
# While t is alive, the sentinel must still be reachable through t's grad_fn metadata.
self.assertIsNotNone(ref())
del t
# After dropping t, nothing should keep the sentinel alive; a still-live weak
# reference here means the parent PyObject was leaked.
self.assertIsNone(ref())

def test_nested_anomaly_printstack_cleanup(self):
# Test that the metadata dict PyObject is properly destroyed after anomaly mode
# has printed the stack of a nested (parent) node.
import weakref

def get_ref():
# This is similar to the construction in test_anomaly_assign_parent_cleanup:
#
# MyFunc2Backward -> PyObject -> MyFuncBackward -> dict -> Foo
# out ---^ WeakRef ---^
#
# We want to check that Foo is still properly destroyed even when MyFunc2Backward's
# AnomalyMetadata calls printstack, which does some python object manipulation.
#
# You might be wondering why we still need test_anomaly_assign_parent_cleanup:
# if PyObject is not destroyed here, wouldn't this test detect that as well?
# The answer is that a custom function's PyObject (THPFunction) actually only holds
# a weak reference to the c++ node!
class MyFunc(Function):
@staticmethod
def forward(ctx, x):
ctx.save_for_backward(x)
return x

@staticmethod
def backward(ctx, gO):
# Applying MyFunc2 inside backward creates the nested node.
x, = ctx.saved_tensors
return MyFunc2.apply(x)

class MyFunc2(Function):
@staticmethod
def forward(ctx, x):
return x

@staticmethod
def backward(ctx, gO):
# Deliberately produce NaN so that anomaly mode raises (see the expected
# error message asserted below).
return gO + float("NaN")

inp = torch.rand(1, requires_grad=True)
out = MyFunc.apply(inp)
ginp, = torch.autograd.grad(out, (inp,), create_graph=True)

# Warnings emitted inside this block are recorded into `w`; `w` itself is not
# inspected by this test.
with warnings.catch_warnings(record=True) as w:
with self.assertRaisesRegex(RuntimeError, "Function 'MyFunc2Backward' returned nan values in its 0th output."):
with detect_anomaly():
ginp.backward()

# Plant a sentinel in MyFuncBackward's metadata dict and keep only a weak
# reference to it, so its lifetime can be observed from outside.
class Foo(object):
pass
my_obj = Foo()
meta_dict = out.grad_fn.metadata
meta_dict[0] = my_obj
ref = weakref.ref(my_obj)
return out, ref

t, ref = get_ref()
# While the returned output is alive, the sentinel must still be reachable.
self.assertIsNotNone(ref())
del t
# Once the output is dropped, nothing should keep the sentinel alive.
self.assertIsNone(ref())

@skipIfNoLapack
def test_eig_no_eigenvectors(self):
A = torch.tensor([[1., 2.], [2., 4.]], dtype=torch.float32, requires_grad=True)
Expand Down
16 changes: 8 additions & 8 deletions torch/csrc/autograd/python_anomaly_mode.cpp
Expand Up @@ -41,23 +41,23 @@ void PyAnomalyMetadata::print_stack(const std::string& current_node_name) {
// if there is no "parent_" in metadata, then it means this metadata's node
// is the root and stop printing the traceback
while (pyparent) {
PyObject* parent_metadata(PyObject_GetAttrString(pyparent, "metadata"));
THPObjectPtr parent_metadata(PyObject_GetAttrString(pyparent, "metadata"));
if (!parent_metadata) {
throw python_error();
}
PyObject* parent_name_pyobj(PyObject_CallMethod(pyparent, "name", ""));
THPObjectPtr parent_name_pyobj(PyObject_CallMethod(pyparent, "name", ""));
if (!parent_name_pyobj) {
throw python_error();
}
const char* parent_name_char = PyUnicode_AsUTF8(parent_name_pyobj);
const char* parent_name_char = PyUnicode_AsUTF8(parent_name_pyobj.get());
if (!parent_name_char) {
throw python_error();
}
const std::string parent_name(parent_name_char);
PyObject* parent_stack = PyDict_GetItemString(parent_metadata, ANOMALY_TRACE_KEY);
PyObject* parent_stack = PyDict_GetItemString(parent_metadata.get(), ANOMALY_TRACE_KEY);
_print_stack(parent_stack, parent_name, true);
// get the parent of this node, if this node is a root, pyparent is simply null
pyparent = PyDict_GetItemString(parent_metadata, ANOMALY_PARENT_KEY);
pyparent = PyDict_GetItemString(parent_metadata.get(), ANOMALY_PARENT_KEY);
}
}

Expand All @@ -69,11 +69,11 @@ void PyAnomalyMetadata::assign_parent(const std::shared_ptr<Node>& parent_node)
pybind11::gil_scoped_acquire gil;
if (!parent_node) return;

PyObject* pyobj = functionToPyObject(parent_node);
if (!pyobj) {
THPObjectPtr parent_node_(functionToPyObject(parent_node));
if (!parent_node_) {
throw python_error();
}
if (PyDict_SetItemString(dict(), ANOMALY_PARENT_KEY, pyobj)) {
if (PyDict_SetItemString(dict(), ANOMALY_PARENT_KEY, parent_node_.get())) {
throw python_error();
}
}
Expand Down

0 comments on commit 2e8e560

Please sign in to comment.