Printing should not have (bad) autograd side effects #49756
The other option, I guess, is to disallow accessing anything beyond the grad_fn name in a no_grad block; that would prevent users from printing out grad_fn metadata and such, though.
But that would mean that the grad_fn name is wrong (it goes from ExpandBackward to AsStridedBackward in this case). And outside of printing, we still need access to the actual Node for things like torchviz to work.
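To illustrate why tools need access to the actual Node objects, here is a minimal sketch of the kind of graph traversal torchviz relies on (`walk` is a hypothetical helper, not part of any library):

```python
import torch

def walk(fn, depth=0):
    # Recursively print the autograd graph by following next_functions --
    # the same Node access pattern that tools like torchviz depend on.
    if fn is None:
        return
    print("  " * depth + type(fn).__name__)
    for next_fn, _ in fn.next_functions:
        walk(next_fn, depth + 1)

a = torch.rand(3, requires_grad=True)
b = (a * 2).sum()
walk(b.grad_fn)  # SumBackward0 -> MulBackward0 -> AccumulateGrad
```

If grad_fn access beyond the name were disabled under no_grad, traversals like this would break whenever they run inside a no_grad block.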
High priority, at least to fix the print case.
It's unclear to me why […]. In fact, for the vast majority of […]. In the codebase, I found only 3 […].
Of course, the inevitable check for grad mode enabled still happens at some point. With this in mind, I propose that we: […]
While I understand this proposed fix affects a broader cross-section of the codebase than the originally proposed fixes, I think this fix is better conceptually, and fixing the semantics of […]. If this seems too dangerous, I have a branch ready to go with the proposed fix of temporarily enabling grad mode during […].
In the case where another view of a Tensor has been modified in-place, the next access to its grad_fn or grad information will trigger a rewrite of the graph to properly account for the fact that other views have been updated.
The problem is that if this update happens during a print statement, it happens in no_grad mode, and thus a bogus graph is created (the AsStridedBackward Node is added, but it points nowhere because collect_next_edges returned an empty list).
Repro:
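The original repro snippet was lost in extraction; below is a hypothetical reconstruction consistent with the description above (an expand view whose shared storage is modified in-place, then printed under no_grad):

```python
import torch

# Hypothetical reconstruction of the lost repro: a differentiable expand view
# whose shared storage is then modified in-place.
a = torch.rand(1, 3, requires_grad=True).clone()  # non-leaf, so in-place ops are allowed
v = a.expand(2, 3)  # differentiable view: grad_fn is ExpandBackward0
a.mul_(2)           # in-place update; v's grad_fn is now stale and must be rewritten

with torch.no_grad():
    # On affected builds, this print rebuilt v's grad_fn as AsStridedBackward
    # while grad mode was disabled, so collect_next_edges returned an empty
    # list and the resulting graph was bogus.
    print(v)
```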
Fails with: […]
I think the right fix is to make grad_fn() in variable.cpp enable grad mode whenever it collects edges, to make sure it always creates a valid graph (even if the user manually disabled grads before triggering the grad_fn update). Indeed, if we get to this codepath, it means the original view was a differentiable view, created in a block with grad enabled, and we should maintain that after the rewrite.
Alternative: we could modify the printing code to enable gradients when accessing the grad_fn to ensure it is properly created. But that wouldn't prevent the user from accessing the grad_fn in a no_grad block, leading to this exact same error.
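This alternative can be sketched from user code (assuming the same kind of setup as the repro above; this is a workaround illustrating the idea, not the actual printing-code change): wrap the grad_fn access in torch.enable_grad() so the graph rewrite happens with grad enabled.

```python
import torch

a = torch.rand(1, 3, requires_grad=True).clone()  # non-leaf base
v = a.expand(2, 3)  # differentiable view
a.mul_(2)           # leaves v's grad_fn stale

with torch.no_grad():
    # User-level mirror of the alternative fix: re-enable grad mode just
    # while the stale grad_fn is rebuilt, so collect_next_edges sees grad
    # enabled and produces valid next edges.
    with torch.enable_grad():
        fn = v.grad_fn  # graph rewrite happens here
    print(v)            # safe: grad_fn is already up to date at this point
```

As the comment notes, this only protects the print path; a user accessing grad_fn directly inside a plain no_grad block would still hit the same problem.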
cc @ezyang @gchanan @zou3519 @bdhirsh @jbschlosser @albanD @gqchen @pearu @nikitaved