Remove incorrect THP{Cpp,}Function_traverse PyObject traversals #102860

Closed · wants to merge 2 commits
72 changes: 72 additions & 0 deletions test/test_autograd.py
@@ -7462,6 +7462,78 @@ def forward(self, x):
gc.collect()
self.assertIsNone(ref_())

@parametrize("use_custom_function", [True, False])
@parametrize("use_tensor_hook", [True, False])
def test_hook_closure_cycle(self, use_custom_function, use_tensor_hook):
# This creates a cycle between the hook and grad_fn_b
# hook -> closure -> grad_fn_b (python) -> grad_fn (cpp) -> hook (cpp)
# -> dict -> hook
#
# This test checks that grad_fn_b (python) only traverses the dict if
# it is the sole owner of the grad_fn_b (cpp) shared_ptr
#
# See: https://github.com/pytorch/pytorch/issues/102174
class Function(torch.autograd.Function):
@staticmethod
def forward(ctx, x):
return x

@staticmethod
def backward(ctx, grad):
return grad

class Test():
pass

count = [0]

def scope():
a = torch.tensor(1., requires_grad=True)
if use_custom_function:
b = Function.apply(a)
else:
b = a.clone()
grad_fn_b = b.grad_fn
obj = Test()

def hook(*args):
# Make sure this hook's closure holds onto grad_fn_b
# This forms a cycle between the hook and grad_fn_b
# We also hold onto a sentinel object 'obj' to track
# whether this cycle is still alive. See 'ref' below.
grad_fn_b
obj
count[0] += 1
if use_tensor_hook:
b.register_hook(hook)
else:
b.grad_fn.register_hook(hook)
c = b.clone()
ref = weakref.ref(obj)
return c, ref

with disable_gc():
out, ref = scope()
out.backward(retain_graph=True)

gc.collect()

# Make sure gc does not clear the cycle noted above,
# i.e. the hook is still alive and fires even after gc runs
out.backward(retain_graph=True)
self.assertEqual(count[0], 2)

# ref is still alive because the use_count of the cpp grad_fn
# shared_ptr > 1 since (1) the python grad_fn is alive, and (2) the
# rest of the graph holds onto the shared_ptr
self.assertIsNotNone(ref())

# Then delete the rest of the graph and check that ref is dead
del out
gc.collect()
self.assertIsNone(ref())

def test_full_backward_hook_double_backward(self):
x = torch.rand(1, requires_grad=True)
y = torch.rand_like(x)
8 changes: 8 additions & 0 deletions torch/csrc/autograd/python_cpp_function.cpp
@@ -77,6 +77,14 @@ PyObject* THPCppFunction_call(

int THPCppFunction_traverse(PyObject* self, visitproc visit, void* arg) {
auto& fn = *((THPCppFunction*)self)->cdata;
if ((((THPCppFunction*)self)->cdata).use_count() > 1) {
Collaborator: This definitely makes the diff easier to read but I think it's better to guard the calls below rather than early returning.
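
For context, a minimal sketch of what that guarded structure might look like, reusing the identifiers from the surrounding function. This is only an illustration of the reviewer's suggestion, not the code in this PR, and the elided hook loops are assumed to follow the same pattern as the existing ones:

// Sketch: keep a single exit point and wrap the traversal in the ownership
// check instead of returning early.
int THPCppFunction_traverse(PyObject* self, visitproc visit, void* arg) {
  auto& cdata = ((THPCppFunction*)self)->cdata;
  if (cdata.use_count() == 1) {
    // We are the sole owner of the cpp grad_fn, so visiting its Python hook
    // dicts cannot prematurely collect a grad_fn the rest of the graph needs.
    auto& fn = *cdata;
    for (const auto& hook : fn.tensor_pre_hooks()) {
      if (auto pyhook = dynamic_cast<PyFunctionTensorPreHook*>(hook.get())) {
        Py_VISIT(pyhook->dict);
      }
    }
    // ... same guarded pattern for the remaining hook collections ...
  }
  return 0;
}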

// The fields traversed below are owned by the cpp grad_fn, which we own a
// reference to. However, we should only traverse them if we are the only
// owner of the grad_fn; otherwise we risk prematurely gc'ing the grad_fn.
//
// See: https://github.com/pytorch/pytorch/issues/102174
return 0;
}
for (const auto& hook : fn.tensor_pre_hooks()) {
if (auto pyhook = dynamic_cast<PyFunctionTensorPreHook*>(hook.get())) {
Py_VISIT(pyhook->dict);
32 changes: 2 additions & 30 deletions torch/csrc/autograd/python_function.cpp
@@ -206,36 +206,8 @@ auto PyNode::name() const -> std::string {

// Traverse and clear are required for supporting Python's GC cycle handling.
static int THPFunction_traverse(THPFunction* self, visitproc visit, void* arg) {
// cdata could be null if the PyNode has already gone out of scope
// by the time we're GC'ing this THPFunction (e.g., the user saved grad_fn
// only).
//
// TODO: I'm not really sure if we're actually obligated to traverse PyObject
// that is stored in PyNode, since we don't really own that C++ object.
if (auto cdata = self->cdata.lock()) {
for (const auto& hook : cdata->tensor_pre_hooks()) {
if (auto pyhook = dynamic_cast<PyFunctionTensorPreHook*>(hook.get())) {
Py_VISIT(pyhook->dict);
}
}
// See NOTE [retains_grad_hook PyObject traversal]
for (const auto& pair : cdata->retains_grad_hooks()) {
if (auto pyhook =
dynamic_cast<PyFunctionTensorPreHook*>(pair.second.get())) {
Py_VISIT(pyhook->dict);
}
}
for (const auto& hook : cdata->pre_hooks()) {
if (auto pyhook = dynamic_cast<PyFunctionPreHook*>(hook.get())) {
Py_VISIT(pyhook->dict);
}
}
for (const auto& hook : cdata->post_hooks()) {
if (auto pyhook = dynamic_cast<PyFunctionPostHook*>(hook.get())) {
Py_VISIT(pyhook->dict);
}
}
}
// NB: We should not traverse PyObjects stored on the PyNode, since we only
// hold a weak reference to the PyNode.
Py_VISIT(self->to_save);
Py_VISIT(self->non_differentiable);
Py_VISIT(self->dirty_tensors);