Printing should not have (bad) autograd side effects #49756

Closed
albanD opened this issue Dec 22, 2020 · 4 comments
Labels: high priority, module: autograd, triaged

Comments

albanD (Collaborator) commented Dec 22, 2020

In the case where another view of a Tensor has been modified in-place, the next access to its grad_fn or grad information will trigger a rewrite of the graph to properly take into account the fact that other views have been updated.

The problem is that if this update happens during a print statement, it happens in no_grad mode and thus a bogus graph is created (it adds the AsStridedBackward Node, but this node points nowhere because collect_next_edges returned an empty list).

Repro:

import torch
a = torch.ones(3, 1, requires_grad=True)
b = a.view_as(a)
# No grad here just to avoid the issue with modifying leaves in-place
with torch.no_grad():
    a[0, 0] = 3
# Here the print triggers a recompute of the grad_fn.
# Removing this print makes the code work just fine.
print(b)
d = torch.sum(3 * b)
d.backward()

Fails with

RuntimeError: Index out of range
Exception raised from should_compute_output at /pytorch/torch/csrc/autograd/function.h:272 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x42 (0x7fbeba7ef8b2 in /usr/local/lib/python3.6/dist-packages/torch/lib/libc10.so)
frame #1: <unknown function> + 0x27c8a02 (0x7fbef648aa02 in /usr/local/lib/python3.6/dist-packages/torch/lib/libtorch_cpu.so)
frame #2: <unknown function> + 0x27744c0 (0x7fbef64364c0 in /usr/local/lib/python3.6/dist-packages/torch/lib/libtorch_cpu.so)
frame #3: torch::autograd::generated::AsStridedBackward::apply(std::vector<at::Tensor, std::allocator<at::Tensor> >&&) + 0x99 (0x7fbef643a549 in 
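
To make the bogus graph concrete: based on the description above (expected behaviour on an affected build, not verified output), accessing grad_fn directly inside a no_grad block should hit the same codepath and leave a disconnected AsStridedBackward node behind:

import torch
a = torch.ones(3, 1, requires_grad=True)
b = a.view_as(a)
with torch.no_grad():
    a[0, 0] = 3
    fn = b.grad_fn            # rewrite happens here, with grad mode disabled
print(type(fn).__name__)      # expected: AsStridedBackward
print(fn.next_functions)      # expected: () -- the node points nowhere,
                              # so any backward through it fails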

I think the right fix is to make grad_fn() in variable.cpp enable grad mode all the time when collecting edges, so that it always creates a valid graph (even if the user manually disabled grads before triggering the grad_fn update).
Indeed, if we get to this codepath, it means that the original view was a differentiable view, created in a block with grad enabled, and we should maintain that after the rewrite.

Alternative: we could modify the printing code to enable gradients when accessing the grad_fn to ensure it is properly created. But that wouldn't prevent the user from accessing the grad_fn in a no_grad block, leading to this exact same error.
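
As a stopgap in user code, one could force the rewrite to happen while grad mode is still enabled, so the later print only sees the already-rebuilt, valid grad_fn. This is a workaround sketch based on the behaviour described above, not a fix:

import torch
a = torch.ones(3, 1, requires_grad=True)
b = a.view_as(a)
with torch.no_grad():
    a[0, 0] = 3
_ = b.grad_fn        # touch grad_fn with grad enabled: the rewrite gets valid edges
print(b)             # printing now reuses the cached, valid grad_fn
d = torch.sum(3 * b)
d.backward()         # expected to run without the "Index out of range" error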

cc @ezyang @gchanan @zou3519 @bdhirsh @jbschlosser @albanD @gqchen @pearu @nikitaved

ngimel added the module: autograd and triaged labels on Dec 22, 2020
gchanan (Contributor) commented Dec 22, 2020

The other option, I guess, is to disallow accessing anything beyond the grad_fn name in a no_grad block; that would prevent users from printing out grad_fn metadata and such, though.

albanD (Collaborator, Author) commented Dec 23, 2020

But that would mean that the grad_fn name is wrong (it goes from ExpandBackward to AsStridedBackward in this case).

But outside of the printing, we still need access to the actual Node for things like torchviz to work.

albanD (Collaborator, Author) commented Dec 23, 2020

High priority at least to fix the print case.

jbschlosser (Contributor) commented Jan 29, 2021

It's unclear to me why collect_next_edges returns an empty list when grad mode is disabled. It's likely the unexpected semantics of this function are what caused the bug in the first place. IMO it should be the responsibility of this function to do exactly what it says, and the responsibility of the call site to determine whether the function should be called at all, possibly conditionally on whether grad mode is enabled. This makes the semantics much clearer.

In fact, for the vast majority of collect_next_edges call-sites, including all generated ones, there is a separate check of GradMode::is_enabled, usually through compute_requires_grad, that determines whether collect_next_edges needs to be called at all.

In the codebase, I found only 3 collect_next_edges call-sites that rely on an empty list of edges being returned when grad mode is disabled:

  • unpack_input() in torch/csrc/autograd/python_function.cpp
  • apply() in torch/csrc/autograd/custom_function.h
  • wrap_outputs() in torch/csrc/autograd/functions/utils.cpp

Of course, the inevitable check for grad mode enabled still happens at some point.

With this in mind, I propose that we:

  1. Remove the grad enabled check / empty list return logic from collect_next_edges
  2. Update the 3 call-sites to manually construct an empty list instead of calling collect_next_edges when grad mode is disabled

While I understand this proposed fix affects a broader cross-section of the codebase than the originally proposed fixes, I think this fix is better conceptually, and fixing the semantics of collect_next_edges makes future maintenance easier.
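
For readers less familiar with the autograd C++ internals, the proposed split can be sketched as a Python analogy (all names below are stand-ins for illustration, not actual PyTorch APIs):

# Today (analogy): the helper silently returns nothing when grad mode is off,
# so a caller that genuinely needs edges (like the grad_fn rewrite) gets a
# disconnected graph.
def collect_next_edges_current(tensors, grad_mode_enabled):
    if not grad_mode_enabled:
        return []                        # implicit, surprising behaviour
    return [t.grad_fn for t in tensors]

# Proposed (analogy): the helper always does what its name says; the few
# call-sites that want "skip when grad is off" make that check explicit.
def collect_next_edges_proposed(tensors):
    return [t.grad_fn for t in tensors]

def wrap_outputs_like_call_site(tensors, grad_mode_enabled):
    edges = collect_next_edges_proposed(tensors) if grad_mode_enabled else []
    return edges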

If this seems too dangerous, I have a branch ready to go with the proposed fix of temporarily enabling grad mode during collect_next_edges in grad_fn().

Thoughts on this? @albanD @gchanan
