Update functorch supported autograd.Function to allow mark_dirty #91222
Conversation
[ghstack-poisoned]
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/91222
Note: Links to docs will display an error until the docs builds have been completed. ✅ No Failures as of commit 15e8773. This comment was automatically generated by Dr. CI and updates every 15 minutes.
ghstack-source-id: 1b8229c8ef1e8b91db9a80a2dd974952bd30f9a0 Pull Request resolved: #91222
test/functorch/test_ops.py (Outdated)

```diff
@@ -450,7 +449,7 @@ def wrapped_fn(*args, **kwargs):
     xfail('nn.functional._scaled_dot_product_attention', device_type='cuda'),
     xfail('nn.functional.rrelu'),  # in-place test errors out with no formula implemented
-    xfail('NumpyExpMarkDirtyAutogradFunction'),  # https://github.com/pytorch/pytorch/issues/90225
+    xfail('NumpyExpMarkDirtyAutogradFunction'),  # TODO: calling in-place operation that would mutate a captured Tensor
```
This errors for a different reason now, need to investigate.
@zou3519 I think I figured out what the issue is here, but I'm not sure what the solution should be.

In C++ functorch:
- we create the dual tensor with a tangent that is captured (so it has an immutable wrapper)
- when we call into exp_, checkForInvalidMutationOnCaptures does not care about the tangent having an immutable wrapper, because that is hidden by the dual tensor's wrapper, which isn't immutable
- before we call into VariableType, we first exclude the dynamicLayerFront
- the forward-AD formula for in-place ops does tangent.copy_(new_tangent). Because we already excluded dynamicLayerFront, we just go into VariableType again (which is basically a no-op, since neither tensor has a tangent this time around)

In pyfunctorch:
- we still create that immutable wrapper for the tangent
- when we call process, we do not exclude dynamicLayerFront
- process constructs the single-layer autograd Function and calls apply (which calls into forward, then jvp)
- after forward is done (with no problems), jvp is performed, which does tangent.mul_(output). At this point, JvpTransform is still at the top of the stack
- since we did not exclude this time, we go into dynamicLayerFront again, which now errors in checkForInvalidMutationOnCaptures, because we're performing an in-place op on the tangent, which has the immutable wrapper
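The in-place forward-AD update described here can be observed directly with plain dual tensors, outside of functorch; a minimal sketch using `torch.autograd.forward_ad`:

```python
import torch
import torch.autograd.forward_ad as fwAD

x = torch.tensor([1.0, 2.0])
t = torch.tensor([0.5, 0.5])

with fwAD.dual_level():
    dual = fwAD.make_dual(x.clone(), t.clone())
    # In-place op: the forward-AD formula updates the tangent in place,
    # effectively tangent.copy_(exp(x) * tangent).
    dual.exp_()
    primal, tangent = fwAD.unpack_dual(dual)

# The tangent now carries d(exp)(x) applied to t, i.e. exp(x) * t.
assert torch.allclose(tangent, x.exp() * t)
assert torch.allclose(primal, x.exp())
```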
Click for repro:

```python
from functorch import vmap, jvp
import torch
import numpy as np

def to_numpy(tensor):
    return tensor.cpu().numpy()

class NumpyMul(torch.autograd.Function):
    @staticmethod
    def forward(x, y):
        return torch.tensor(to_numpy(x) * to_numpy(y), device=x.device)

    @staticmethod
    def setup_context(ctx, inputs, outputs):
        ctx.save_for_backward(*inputs)
        ctx.save_for_forward(*inputs)

    @staticmethod
    def backward(ctx, grad_output):
        x, y = ctx.saved_tensors
        gx = None
        if ctx.needs_input_grad[0]:
            gx = NumpyMul.apply(grad_output, y)
        gy = None
        if ctx.needs_input_grad[1]:
            gy = NumpyMul.apply(grad_output, x)
        return gx, gy

    @staticmethod
    def vmap(info, in_dims, x, y):
        x_bdim, y_bdim = in_dims
        x = x.movedim(x_bdim, -1) if x_bdim is not None else x.unsqueeze(-1)
        y = y.movedim(y_bdim, -1) if y_bdim is not None else y.unsqueeze(-1)
        result = NumpyMul.apply(x, y)
        result = result.movedim(-1, 0)
        return result, 0

    @staticmethod
    def jvp(ctx, x_tangent, y_tangent):
        x, y = ctx.saved_tensors
        return x_tangent * y + y_tangent * x

class NumpyExp_(torch.autograd.Function):
    @staticmethod
    def forward(x):
        x_np = to_numpy(x)
        np.exp(x_np, x_np)
        return x

    @staticmethod
    def setup_context(ctx, inputs, outputs):
        x, = inputs
        ctx.mark_dirty(x)
        ctx.save_for_backward(outputs)
        ctx.save_for_forward(outputs)

    @staticmethod
    def backward(ctx, grad_output):
        output, = ctx.saved_tensors
        return NumpyMul.apply(grad_output, output)

    @staticmethod
    def vmap(info, in_dims, x):
        NumpyExp_.apply(x)
        return x, in_dims[0]

    @staticmethod
    def jvp(ctx, x_tangent):
        output, = ctx.saved_tensors
        x_tangent.mul_(output)
        return x_tangent

def fn(x):
    # return torch.exp_(x)  # <-- does not error
    return NumpyExp_.apply(x)

a = torch.rand(4,)
b = torch.rand(4,)
with torch.autograd.function._set_autograd_function_extension_enabled(True):
    jvp(fn, (a,), (b,))
```
> jvp is performed, which does tangent.mul_(output)

Where is the mul_ in the code?
In the jvp that NumpyExp_ defines.
Maybe one solution could be:
- We could just say that it is okay for us to go through process again. It is basically a no-op, since the stack is the same. Technically it does more checks, but maybe that is fine and we actually want those checks?
- Currently, in the creation of a dual tensor, the primal and tangent are not explicitly wrapped; instead we rely on them getting automatically lifted. If we manually wrap the tangent (and primal) instead, this error should no longer trigger even if we go through process an extra time. Since the tangent is something the user passed in themselves, we should be okay with mutating it, and not mark it with the immutable wrapper.

Alternate solution (doesn't work):
- I also tried excluding manually in PyFunctorch's process to mimic the cpp version, but ran into an unwrapped_count > 0 INTERNAL ASSERT FAILED issue in the dead tensor wrapper fallback, and I'm not sure what that means yet. (What does this mean?)
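To make the immutable-wrapper idea concrete, here is a toy sketch (not the actual functorch TensorWrapper, purely an illustration of the proposed distinction between auto-lifted captures and explicitly wrapped user inputs):

```python
import torch

class ToyWrapper:
    """Toy stand-in for a functorch-style tensor wrapper (illustration only)."""
    def __init__(self, value, immutable):
        self.value = value
        self.immutable = immutable  # True for tensors auto-lifted as captures

    def mul_(self, other):
        if self.immutable:
            raise RuntimeError("cannot mutate a captured (immutable) tensor")
        self.value.mul_(other)
        return self

# A tensor auto-lifted from the environment would be marked immutable...
captured = ToyWrapper(torch.tensor(2.0), immutable=True)
# ...while a tangent the user passed in explicitly would be wrapped as mutable.
tangent = ToyWrapper(torch.tensor(3.0), immutable=False)

tangent.mul_(torch.tensor(4.0))  # fine: the user handed us this tensor
try:
    captured.mul_(torch.tensor(4.0))  # rejected: lifted capture is immutable
    mutation_blocked = False
except RuntimeError:
    mutation_blocked = True
```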
I'm still processing what is going on, but let me reply to your questions:

> I also tried excluding manually in PyFunctorch's process to mimic the cpp version but ran into an issue with unwrapped_count > 0 INTERNAL ASSERT FAILED in the dead tensor wrapper fallback and not sure what that means yet. (What does this mean?)

There's an invariant that a Tensor with a FuncTorchTensorWrapper dispatch key must be a TensorWrapper. Given that we hit the dead_tensor_fallback, at least one of the inputs must be a TensorWrapper. The assertion is complaining that none of the inputs are TensorWrappers.
The thing I am struggling with a bit right now is: does the in-place mutation check even make sense for forward-mode AD?
- If it does, then it sounds like C++ functorch is wrong, because it bypasses the check.
- If it doesn't, then to what extent can we just get rid of it from both C++ and Python functorch?
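For comparison, the same class of captured-mutation check is what makes vmap reject in-place writes into tensors closed over from outside the transformed function; a minimal illustration (assuming the `torch.func` API, `functorch.vmap` in older releases):

```python
import torch
from torch.func import vmap  # functorch.vmap in older releases

captured = torch.zeros(3)

def f(x):
    # In-place write into a tensor captured from the enclosing scope;
    # under vmap, x is batched while `captured` is not, so this is rejected.
    captured.add_(x)
    return x

try:
    vmap(f)(torch.ones(2, 3))
    raised = False
except RuntimeError:
    raised = True
```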
Claim: the input-mutation check makes sense for forward-mode AD. We want to prevent a situation where the dual tensor is created on the wrong TensorWrapper.

There are two cases here.

Case 1: a captured value is mutated in-place. If we have:

```python
y = torch.tensor(1.)

def f(x):
    y.copy_(x)
    return x + y

jvp(f, (x,), (t,))
```

Then the dual should be created on the wrapped version of y, not y itself. The in-place error checks should ideally raise an error in this situation.
Case 2: the tangent tensor is mutated in-place (which is what is happening in this PR).

```python
import torch
import torch.autograd.forward_ad as fwAD

x = torch.tensor(2.)
y = torch.tensor(3.)
with fwAD.dual_level():
    x_dual = fwAD.make_dual(x, y)
    y.copy_(x_dual)
    x, x_tangent = fwAD.unpack_dual(x_dual)
```

If we ran the functorch.jvp equivalent of the above, it's important that the tangent of x is a TensorWrapper, because it ends up getting its own tangent value.
Solution?

Given the above, I like one of the solutions you proposed, which is:

> Currently in the creation of a dual tensor, the primal and tangent are not explicitly wrapped, instead we rely on them to get automatically lifted. If we manually wrap tangent (and primal) instead, this error should no longer trigger even if we go through process an extra time. Since tangent is something the user passed in themselves, we should be okay with mutating it, and not mark it with the immutable wrapper.

functorch.jvp should wrap the primal and the tangent before calling make_dual. The end state is that we get TensorWrapper(primal) whose tangent is TensorWrapper(tangent).

Thoughts? Also, thank you for the detailed analysis; it saved me from stepping through the code in gdb.
Case 2 is actually a bug in PyTorch forward-mode AD. Even if we make the primal and tangent both have a TensorWrapper at the same level, the tangent's TensorWrapper should itself never have a tangent. Normally we'd error if we're setting a tangent that itself has a tangent, but we're getting around that check with an in-place op.

I think that morally the tangent should not be wrapped at the same level as the primal (I see the tangent as metadata that lives on the primal's wrapper, so in a sense it should be subordinate to the primal). The tangent is being wrapped today because we are computing with it while JVP is active; in theory we are only computing with plain tensors at that point, so (if the forward/backward AD kernels were separate) we should be able to exclude the Autograd key and properly unwrap and pop JVP off the stack before computing forward grads.

That being said, I still think it is a good idea to manually wrap the tangent at the same level as the primal today, to indicate that it is a tensor explicitly passed in, so that its AD metadata isn't immutable.
LGTM, some minor comments. I assume we are punting the handling of the TODO to the future (but feel free to dig into it more if you're interested)
```diff
@@ -953,6 +952,7 @@ def test_vmapvjp(self, device, dtype, op):
     # skip because this is flaky depending on what the max_norm is!
     skip('nn.functional.embedding', ''),
     skip('to'),  # RuntimeError: required rank 4 tensor to use channels_last format
+    xfail('NumpyExpMarkDirtyAutogradFunction'),  # vmap: inplace into a regular tensor
```
Just to check, this is not "calling in-place operation that would mutate a captured Tensor", right?
Yup, leaving this for a follow-up for now.

@pytorchbot merge -g

Merge failed. Reason: Not merging any PRs at the moment because there is a merge-blocking https://github.com/pytorch/pytorch/labels/ci:%20sev issue open at: Details for Dev Infra team: Raised by workflow job
```diff
-# def setup_context(ctx, outputs, x):
-#     y = outputs
+# def setup_context(ctx, inputs, output):
+#     y = output
```
Why the rename from outputs -> output? Is it a single output now? Or are they unpacked?
It has always been a single output; I'm just updating the name to reflect that. Since we return what the user returned from forward as-is, that can sometimes be a tuple, depending on what the user returns.
In an earlier version of this PR I made it always pass in a tuple for consistency, but after the discussion here #91222 (comment), I decided to revert that change.
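As a sketch of the single-output convention (using the PyTorch 2.x-style autograd.Function extension API; the `Square` class here is illustrative, not from this PR):

```python
import torch

class Square(torch.autograd.Function):
    @staticmethod
    def forward(x):
        # New-style forward: no ctx argument, just the inputs.
        return x ** 2

    @staticmethod
    def setup_context(ctx, inputs, output):
        # `output` is exactly what forward returned: a single Tensor here,
        # but it would be a tuple if forward returned a tuple.
        x, = inputs
        ctx.save_for_backward(x)

    @staticmethod
    def backward(ctx, grad_output):
        x, = ctx.saved_tensors
        return 2 * x * grad_output

a = torch.tensor(3.0, requires_grad=True)
Square.apply(a).backward()
```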
```diff
@@ -171,8 +188,8 @@ def mark_dirty_error(*args, **kwargs):
     #     return x.exp()
     #
     # @staticmethod
-    # def setup_context(ctx, outputs, x):
-    #     y = outputs
+    # def setup_context(ctx, inputs, output):
```
Did you actually swap the order? That wasn't reflected in the tests.
This is just an outdated comment; this is now the correct order.
(see the python bindings)
@pytorchbot rebase

@pytorchbot successfully started a rebase job. Check the current status here.

Rebase failed due to Command
Raised by https://github.com/pytorch/pytorch/actions/runs/3790618497
@pytorchbot merge -g

Merge started. Your change will be merged once all checks on your PR pass, since you used the green (-g) flag (ETA: 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Stack from ghstack (oldest at bottom):
Fixes #90225
Uses what was originally in #89860