grad doesn't run when under torch.no_grad() #13

Closed
Chillee opened this issue May 5, 2021 · 2 comments · Fixed by #179
Chillee commented May 5, 2021

import torch
from functorch import grad

def f(x):
    return torch.sin(x)

with torch.no_grad():
    print(grad(f)(torch.randn(())))
# RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

I'm not totally sure what the semantics should be, but I think we should ignore torch.no_grad here.

zou3519 commented Jul 22, 2021

One really annoying thing about torch.no_grad is that it is not traceable. JAX has a stop_gradient primitive that operates on Tensors, so it is traceable: https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.stop_gradient.html

However, I think it's useful to be able to use something like torch.no_grad inside of a transform. For example, one pattern I've seen is:

def f(x):
    with torch.no_grad():
        shift = x.mean()
    return x - shift

Proposal:

  • torch.no_grad does affect grad/vjp transforms. Any computation that happens within torch.no_grad is invisible to vjp/grad
  • If a user calls grad/vjp inside of torch.no_grad, we raise a warning that explains that their gradients will be 0. (Or maybe this should be an error?)
  • For tracing... either we introduce something like stop_gradient(Tensor) -> Tensor or figure out how to "trace" torch.no_grad. This sounds a bit like factory function tracing and could potentially be done with a mode-based dispatch key

Alternatives:

  • functorch just straight up ignores torch.no_grad
  • we introduce a functorch.stop_gradient or something (a rough sketch of that idea is below)
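
For reference, here is a minimal sketch of what a Tensor-level stop-gradient could look like, assuming it is expressed with detach() rather than a new primitive. The stop_gradient helper below is hypothetical, not an existing functorch API:

import torch

# Hypothetical helper (not an existing functorch API): expressing "stop gradient"
# as an operation on the Tensor itself keeps it traceable, unlike toggling the
# global grad mode with torch.no_grad.
def stop_gradient(x: torch.Tensor) -> torch.Tensor:
    return x.detach()

def f(x):
    shift = stop_gradient(x.mean())  # same effect as computing the mean under no_grad
    return x - shift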

@zou3519 zou3519 self-assigned this Oct 4, 2021

zou3519 commented Oct 4, 2021

New proposal: here's what I think the semantics should be.

Case 1: grad gets called inside torch.no_grad.

  • grad should ignore torch.no_grad because it's "creating a new level of autograd above the current level"
  • Another way to think about this is that grad(f) is a "function transform": its result should not be affected by context managers that are outside of the function f

Case 2: torch.no_grad gets called inside `grad`

  • grad should respect torch.no_grad

How does one actually implement this? We can probably do something with a mode stack here...
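
A minimal sketch of the two cases under the proposed semantics (the printed results are what the proposal implies, not what the current release does; today Case 1 raises the RuntimeError above):

import torch
from functorch import grad

def f(x):
    return torch.sin(x)

# Case 1: grad is called inside torch.no_grad. Under the proposed semantics the
# outer no_grad is ignored, so this prints cos(1.0) instead of raising.
with torch.no_grad():
    print(grad(f)(torch.ones(())))

# Case 2: torch.no_grad is used inside the transformed function. The proposal says
# grad respects it, so the mean() computation contributes no gradient and the
# result is 1.0.
def g(x):
    with torch.no_grad():
        shift = x.mean()
    return x - shift

print(grad(g)(torch.ones(())))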

zou3519 added a commit that referenced this issue Oct 5, 2021
Fixes #13

Case 1: grad gets called inside torch.no_grad.

- grad should ignore torch.no_grad because it's "creating a new level of
autograd above the current level"
- Another way to think about this is that grad(f) is a "function
transform": its result should not be affected by context managers that
are outside of the function f

Case 2: torch.no_grad gets called inside `grad`
- grad should respect torch.no_grad

See NOTE [grad and vjp interaction with no_grad] for implementation
strategy. It unfortunately involves a mode.

Test Plan:
- Many tests
@zou3519 zou3519 mentioned this issue Oct 5, 2021
zou3519 added a commit that referenced this issue Oct 5, 2021
zou3519 added a commit that referenced this issue Oct 23, 2021
zou3519 added a commit to zou3519/pytorch that referenced this issue Jul 20, 2022
bigfootjon pushed a commit to pytorch/pytorch that referenced this issue Jul 21, 2022