grad doesn't run when under torch.no_grad()
#13
One really annoying thing about torch.no_grad is that it is not traceable. JAX has a stop_gradient primitive that operates on Tensors, so it remains traceable: https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.stop_gradient.html However, I think it's useful to be able to use something like torch.no_grad inside of a transform. For example, one pattern I've seen is sketched below.
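The original snippet is not preserved in this thread; the following is a hypothetical illustration of the kind of pattern meant, assuming functorch's `grad` API: a function passed to `grad` that uses torch.no_grad internally to compute a value that should be treated as a constant.

```python
import torch
from functorch import grad

def loss_fn(weight, x):
    # Hypothetical example: compute a target under no_grad so that no
    # gradients should flow through it when the whole function is
    # differentiated with grad().
    with torch.no_grad():
        target = (weight * x).sum()
    pred = (2 * weight * x).sum()
    return (pred - target) ** 2

weight = torch.randn(3)
x = torch.randn(3)
dloss_dweight = grad(loss_fn)(weight, x)  # differentiate w.r.t. weight
```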
Proposal:
Alternatives:
New proposal: here's what I think the semantics should be.

Case 1: grad gets called inside torch.no_grad. grad should ignore the outer torch.no_grad.

Case 2: torch.no_grad gets called inside `grad`. grad should respect torch.no_grad.
How does one actually implement this? We can probably do something with a mode stack here...
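As a rough illustration of the mode-stack idea (this is a toy sketch, not functorch's actual implementation): each grad transform pushes an entry onto a stack, and a no_grad region is only honored if it was entered at or inside the current transform level.

```python
import contextlib

_transform_stack = []   # one entry per active grad transform
_no_grad_levels = []    # transform depth at which each no_grad was entered

@contextlib.contextmanager
def toy_grad_transform():
    _transform_stack.append(len(_transform_stack))
    try:
        yield
    finally:
        _transform_stack.pop()

@contextlib.contextmanager
def toy_no_grad():
    _no_grad_levels.append(len(_transform_stack))
    try:
        yield
    finally:
        _no_grad_levels.pop()

def grad_enabled_at_current_level():
    # no_grad only applies if it was entered at or inside the innermost
    # transform (Case 2); no_grad entered outside the transform (Case 1)
    # is ignored.
    return not any(level >= len(_transform_stack) for level in _no_grad_levels)

# Case 1: no_grad outside the transform is ignored inside it.
with toy_no_grad():
    with toy_grad_transform():
        assert grad_enabled_at_current_level()

# Case 2: no_grad inside the transform is respected.
with toy_grad_transform():
    with toy_no_grad():
        assert not grad_enabled_at_current_level()
```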
Fixes pytorch/functorch#13

Case 1: grad gets called inside torch.no_grad.
- grad should ignore torch.no_grad because it's "creating a new level of autograd above the current level"
- Another way to think about this is that grad(f) is a "function transform": its result should not be affected by context managers that are outside of the function f

Case 2: torch.no_grad gets called inside `grad`
- grad should respect torch.no_grad

See NOTE [grad and vjp interaction with no_grad] for implementation strategy. It unfortunately involves a mode.

Test Plan:
- Many tests
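A minimal sketch of the two cases under these semantics, assuming functorch's `grad` API (the expected values follow from the rules above, not from output reproduced here):

```python
import torch
from functorch import grad

def f(x):
    return x ** 2

x = torch.tensor(3.0)

# Case 1: grad called inside torch.no_grad. The outer no_grad is ignored
# because grad(f) creates a new level of autograd above the current one.
with torch.no_grad():
    dfx = grad(f)(x)  # expected: tensor(6.)

# Case 2: torch.no_grad used inside the function being transformed. Respected:
# no gradient flows through `shift`.
def g(x):
    with torch.no_grad():
        shift = x ** 2
    return x ** 2 + shift

dgx = grad(g)(x)  # expected: tensor(6.), since shift is treated as a constant
print(dfx, dgx)
```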
Not totally sure what the semantics should be... but I kinda think we should be ignoring torch.no_grad().