Update functorch supported autograd.Function to allow mark_dirty #91222

Closed
soulitzer wants to merge 9 commits

Conversation

@soulitzer (Contributor) commented Dec 21, 2022:

Stack from ghstack (oldest at bottom):

Fixes #90225
Uses what was originally in #89860

@pytorch-bot (bot) commented Dec 21, 2022:

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/91222

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 15e8773:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

soulitzer added a commit that referenced this pull request Dec 21, 2022
ghstack-source-id: 1b8229c8ef1e8b91db9a80a2dd974952bd30f9a0
Pull Request resolved: #91222
@@ -450,7 +449,7 @@ def wrapped_fn(*args, **kwargs):
     xfail('nn.functional._scaled_dot_product_attention', device_type='cuda'),

     xfail('nn.functional.rrelu'),  # in-place test errors out with no formula implemented
-    xfail('NumpyExpMarkDirtyAutogradFunction'),  # https://github.com/pytorch/pytorch/issues/90225
+    xfail('NumpyExpMarkDirtyAutogradFunction'),  # TODO: calling in-place operation that would mutate a captured Tensor
@soulitzer (Contributor Author) commented:

This errors for a different reason now, need to investigate.

@soulitzer (Contributor Author) commented Dec 21, 2022:

@zou3519 I think I figured out what the issue is here, but I'm not sure what the solution is.

In cpp functorch:

  • we create a dual tensor whose tangent is captured (so it has an immutable wrapper)
  • when we call into exp_, checkForInvalidMutationOnCaptures does not care about the tangent having an immutable wrapper, because that is hidden by the dual tensor's wrapper, which isn't immutable
  • before we call into VariableType, we first exclude the dynamicLayerFront
  • the forward AD formula for in-place ops does tangent.copy_(new_tangent). Because we already excluded dynamicLayerFront, we just go into VariableType again (which is basically a no-op since neither tensor has a tangent this time around)

In pyfunctorch:

  • we still create that immutable wrapper for the tangent
  • when we call process, we do not exclude dynamicLayerFront
  • process constructs the single-layer autograd.Function and calls apply (which calls into forward, then jvp)
  • after forward is done (with no problems), jvp is performed, which does tangent.mul_(output). At this point, JvpTransform is still at the top of the stack.
  • since we did not exclude this time, we go into dynamicLayerFront, which now errors in checkForInvalidMutationOnCaptures because we're performing an in-place op on the tangent, which has the immutable wrapper
Repro:
from functorch import vmap, jvp
import torch
import numpy as np

def to_numpy(tensor):
    return tensor.cpu().numpy()


class NumpyMul(torch.autograd.Function):
    @staticmethod
    def forward(x, y):
        return torch.tensor(to_numpy(x) * to_numpy(y), device=x.device)

    @staticmethod
    def setup_context(ctx, inputs, outputs):
        ctx.save_for_backward(*inputs)
        ctx.save_for_forward(*inputs)

    @staticmethod
    def backward(ctx, grad_output):
        x, y = ctx.saved_tensors
        gx = None
        if ctx.needs_input_grad[0]:
            gx = NumpyMul.apply(grad_output, y)
        gy = None
        if ctx.needs_input_grad[1]:
            gy = NumpyMul.apply(grad_output, x)
        return gx, gy

    @staticmethod
    def vmap(info, in_dims, x, y):
        x_bdim, y_bdim = in_dims
        x = x.movedim(x_bdim, -1) if x_bdim is not None else x.unsqueeze(-1)
        y = y.movedim(y_bdim, -1) if y_bdim is not None else y.unsqueeze(-1)
        result = NumpyMul.apply(x, y)
        result = result.movedim(-1, 0)
        return result, 0

    @staticmethod
    def jvp(ctx, x_tangent, y_tangent):
        x, y = ctx.saved_tensors
        return x_tangent * y + y_tangent * x

class NumpyExp_(torch.autograd.Function):
    @staticmethod
    def forward(x):
        x_np = to_numpy(x)
        np.exp(x_np, x_np)
        return x

    @staticmethod
    def setup_context(ctx, inputs, outputs):
        x, = inputs
        ctx.mark_dirty(x)
        ctx.save_for_backward(outputs)
        ctx.save_for_forward(outputs)

    @staticmethod
    def backward(ctx, grad_output):
        output, = ctx.saved_tensors
        return NumpyMul.apply(grad_output, output)

    @staticmethod
    def vmap(info, in_dims, x):
        NumpyExp_.apply(x)
        return x, in_dims[0]

    @staticmethod
    def jvp(ctx, x_tangent):
        output, = ctx.saved_tensors
        x_tangent.mul_(output)
        return x_tangent


def fn(x):
    # return torch.exp_(x) <-- does not error
    return NumpyExp_.apply(x)

a = torch.rand(4,)
b = torch.rand(4,)

with torch.autograd.function._set_autograd_function_extension_enabled(True):
    jvp(fn, (a,), (b,))

Contributor commented:

> jvp is performed, which does tangent.mul_(output)

Where is the mul_ in the code?

@soulitzer (Contributor Author) commented:

In the jvp that NumpyExp_ defines.

@soulitzer (Contributor Author) commented Dec 22, 2022:

Maybe one solution could be:

  • We could just say that it is okay for us to go through process again. It is basically a no-op, since the stack is the same. Technically it does more checks, but maybe that is fine and we actually want those checks?
  • Currently, in the creation of a dual tensor, the primal and tangent are not explicitly wrapped; instead we rely on them getting automatically lifted. If we manually wrap the tangent (and primal) instead, this error should no longer trigger even if we go through process an extra time. Since the tangent is something the user passed in themselves, we should be okay with mutating it, and not mark it with the immutable wrapper.

Alternate solution (doesn't work):

  • I also tried excluding manually in PyFunctorch's process to mimic the cpp version, but ran into an unwrapped_count > 0 INTERNAL ASSERT FAILED in the dead tensor wrapper fallback, and I'm not sure what that means yet. (What does this mean?)

Contributor commented:

I'm still processing what is going on, but let me reply to your questions:

> I also tried excluding manually in PyFunctorch's process to mimic the cpp version but ran into an issue with unwrapped_count > 0 INTERNAL ASSERT FAILED in the dead tensor wrapper fallback and not sure what that means yet. (What does this mean?)

There's an invariant that a Tensor with the FuncTorchTensorWrapper dispatch key must be a TensorWrapper. Given that we hit the dead-tensor fallback, at least one of the inputs must be a TensorWrapper. The assertion is complaining that none of the inputs is a TensorWrapper.

Contributor commented:

The thing I am struggling a bit with right now is: does the in-place mutation check even make sense for forward-mode AD?

  • if it does, then it sounds like C++ functorch is wrong because it bypasses it
  • if it doesn't, then to what extent can we just get rid of it from C++ and Python functorch?

@zou3519 (Contributor) commented Dec 22, 2022:

Claim: the input-mutation check makes sense for forward-mode AD. We want to prevent a situation where the dual tensor is created on the wrong TensorWrapper.

There are two cases here:

Case 1: captured value mutated in-place. If we have:

import torch
from functorch import jvp

x = torch.tensor(2.)   # primal (values here are illustrative)
t = torch.tensor(1.)   # tangent

y = torch.tensor(1.)   # captured by f rather than passed in
def f(x):
    y.copy_(x)         # in-place mutation of the captured tensor
    return x + y

jvp(f, (x,), (t,))

Then the dual should be created on the wrapped version of y, not y itself. The in-place error checks should ideally raise an error in this situation.

Case 2: tangent tensor mutated in-place (which is what is happening in this PR).

import torch
import torch.autograd.forward_ad as fwAD

x = torch.tensor(2.)
y = torch.tensor(3.)

with fwAD.dual_level():
    x_dual = fwAD.make_dual(x, y)
    y.copy_(x_dual)
    x, x_tangent = fwAD.unpack_dual(x_dual)

If we ran the functorch.jvp equivalent of the above, it's important that the tangent of x is a TensorWrapper, because it ends up getting its own tangent value.

Solution?

Given the above, I like one of the solutions you proposed above, which is:

> Currently in the creation of a dual tensor, the primal and tangent are not explicitly wrapped, instead we rely on them to get automatically lifted. If we manually wrap tangent (and primal) instead, this error should no longer trigger even if we go through process an extra time. Since tangent is something the user passed in themselves, we should be okay with mutating it, and not mark it with the immutable wrapper.

functorch.jvp should wrap the primal and the tangent before calling make_dual. The end state is that we get TensorWrapper(primal) that has a tangent which is TensorWrapper(tangent).
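A minimal sketch of that end state, assuming functorch's existing _wrap_for_grad lifting primitive; the helper name and the exact call site inside jvp are illustrative, not the actual patch:

```python
import torch
import torch.autograd.forward_ad as fwAD
from torch._C._functorch import _wrap_for_grad

def _make_dual_explicitly_wrapped(primal, tangent, level):
    # Lift both tensors to the current interpreter level ourselves, so that
    # neither gets auto-lifted later as an immutable captured tensor.
    # End state: TensorWrapper(primal) whose tangent is TensorWrapper(tangent).
    wrapped_primal = _wrap_for_grad(primal, level)
    wrapped_tangent = _wrap_for_grad(tangent, level)
    return fwAD.make_dual(wrapped_primal, wrapped_tangent)
```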

Thoughts? Also, thank you for the detailed analysis; it saved me from stepping through the code in gdb.

@soulitzer (Contributor Author) commented Dec 23, 2022:

Case 2 is actually a bug in PyTorch forward-mode AD. Even if we make the primal and tangent both have a TensorWrapper at the same level, the tangent's TensorWrapper should itself never have a tangent. Normally we'd error when setting a tangent that itself has a tangent, but we're getting around that check with an in-place op.

I think that morally the tangent should not be wrapped at the same level as the primal (I see the tangent as metadata that lives on the primal's wrapper, so in a sense it should be subordinate to the primal). The tangent is being wrapped today because we compute with it while JVP is active; in theory we are only computing with plain tensors at that point, so (if the forward/backward AD kernels were separate) we should be able to exclude the Autograd key and properly unwrap and pop JVP off the stack before computing forward grads.

That being said, I still think it is a good idea to manually wrap the tangent at the same level as the primal today, to indicate that it is a tensor explicitly passed in, so that its AD metadata isn't immutable.
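For concreteness, a small sketch of the check being sidestepped, assuming the behavior described above (exact error messages and whether the direct route raises may vary by version):

```python
import torch
import torch.autograd.forward_ad as fwAD

x, y, z = torch.tensor(2.), torch.tensor(3.), torch.tensor(4.)

with fwAD.dual_level():
    x_dual = fwAD.make_dual(x, y)                 # y is x's tangent
    z_dual = fwAD.make_dual(z, torch.tensor(5.))
    # fwAD.make_dual(x, z_dual) would hit the "tangent that itself has a
    # tangent" check directly. The in-place route sidesteps it: copy_'s
    # forward-AD formula writes z_dual's tangent into y, so x's tangent
    # now carries a tangent of its own.
    y.copy_(z_dual)
```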

@soulitzer added the release notes: functorch label Dec 21, 2022
soulitzer added a commit that referenced this pull request Dec 21, 2022
ghstack-source-id: 1b8229c8ef1e8b91db9a80a2dd974952bd30f9a0
Pull Request resolved: #91222
soulitzer added a commit that referenced this pull request Dec 21, 2022
ghstack-source-id: a2ada18a146d8add1d2ef123dac672c2d64b9b2a
Pull Request resolved: #91222
soulitzer added a commit that referenced this pull request Dec 21, 2022
ghstack-source-id: 4cdc34f34d01998c58faf2f41a04ed671af6577f
Pull Request resolved: #91222
soulitzer added a commit that referenced this pull request Dec 21, 2022
ghstack-source-id: b93a19d4ccb827968aab8519784f1ef9c77b3cdf
Pull Request resolved: #91222
soulitzer added a commit that referenced this pull request Dec 21, 2022
ghstack-source-id: c99c1fc6d3b7aa20298840f47eb4ab3ea2cc348c
Pull Request resolved: #91222
@zou3519 (Contributor) left a comment:

LGTM, some minor comments. I assume we are punting the handling of the TODO to the future (but feel free to dig into it more if you're interested)

torch/_functorch/autograd_function.py (review thread resolved)
@@ -953,6 +952,7 @@ def test_vmapvjp(self, device, dtype, op):
     # skip because this is flaky depending on what the max_norm is!
     skip('nn.functional.embedding', ''),
     skip('to'),  # RuntimeError: required rank 4 tensor to use channels_last format
+    xfail('NumpyExpMarkDirtyAutogradFunction'),  # vmap: inplace into a regular tensor
Contributor commented:

Just to check, this is not "calling in-place operation that would mutate a captured Tensor", right?

test/functorch/test_ops.py (outdated, review thread resolved)
soulitzer added a commit that referenced this pull request Dec 23, 2022
ghstack-source-id: bca86129d2448fe005c249ec222f320fe2c83c7d
Pull Request resolved: #91222
@soulitzer (Contributor Author) commented:

> I assume we are punting the handling of the TODO to the future (but feel free to dig into it more if you're interested)

Yup, leaving this for a follow-up for now.

@soulitzer (Contributor Author) commented:

@pytorchbot merge -g

@pytorch-bot added the ciflow/trunk label Dec 23, 2022
@pytorchmergebot (Collaborator) commented:

Merge failed

Reason: Not merging any PRs at the moment because there is a merge blocking https://github.com/pytorch/pytorch/labels/ci:%20sev issue open at:
#91332


-# def setup_context(ctx, outputs, x):
-#     y = outputs
+# def setup_context(ctx, inputs, output):
+#     y = output
Collaborator commented:

Why the rename from outputs to output? Is it a single output now? Or are they unpacked?

@soulitzer (Contributor Author) commented:

It has always been a single output; I'm just updating the name to reflect that. Since we return what the user returned from forward as-is, it can sometimes be a tuple, depending on what the user returns.

In an earlier version of this PR I made it always pass in a tuple for consistency, but after discussion here #91222 (comment), I decided to revert that change.
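For illustration, a minimal autograd.Function using the (ctx, inputs, output) naming; the Exp class is hypothetical and just mirrors the comment shown in the diff below:

```python
import torch

class Exp(torch.autograd.Function):
    @staticmethod
    def forward(x):
        return x.exp()

    @staticmethod
    def setup_context(ctx, inputs, output):
        # `inputs` is always a tuple; `output` is forward's return value as-is,
        # a single tensor here (it would be a tuple only if forward returned one).
        y = output
        ctx.save_for_backward(y)

    @staticmethod
    def backward(ctx, grad_output):
        y, = ctx.saved_tensors
        return grad_output * y
```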

@@ -171,8 +188,8 @@ def mark_dirty_error(*args, **kwargs):
 # return x.exp()
 #
 # @staticmethod
-# def setup_context(ctx, outputs, x):
-#     y = outputs
+# def setup_context(ctx, inputs, output):
Collaborator commented:

Did you actually swap the order? That wasn't reflected in the tests.

@soulitzer (Contributor Author) commented:

This is just an outdated comment; this is now the correct order.

@soulitzer (Contributor Author) added:

(see the python bindings)

@soulitzer (Contributor Author) commented:

@pytorchbot rebase

@pytorchmergebot (Collaborator) commented:

@pytorchbot successfully started a rebase job. Check the current status here

@pytorchmergebot (Collaborator) commented:

Rebase failed due to Command git -C /home/runner/work/pytorch/pytorch rebase refs/remotes/origin/viable/strict gh/soulitzer/159/orig returned non-zero exit code 1

warning: skipped previously applied commit 17288ffde5
warning: skipped previously applied commit 26582f5770
hint: use --reapply-cherry-picks to include skipped commits
hint: Disable this message with "git config advice.skippedCherryPicks false"
Rebasing (1/1)
Auto-merging test/functorch/test_eager_transforms.py
Auto-merging test/functorch/test_ops.py
Auto-merging test/test_autograd.py
Auto-merging torch/_C/_functorch.pyi
CONFLICT (content): Merge conflict in torch/_C/_functorch.pyi
Auto-merging torch/_functorch/autograd_function.py
CONFLICT (content): Merge conflict in torch/_functorch/autograd_function.py
Auto-merging torch/testing/_internal/autograd_function_db.py
error: could not apply 1a9cb2038f... Update functorch supported autograd.Function to allow mark_dirty
hint: Resolve all conflicts manually, mark them as resolved with
hint: "git add/rm <conflicted_files>", then run "git rebase --continue".
hint: You can instead skip this commit: run "git rebase --skip".
hint: To abort and get back to the state before "git rebase", run "git rebase --abort".
Could not apply 1a9cb2038f... Update functorch supported autograd.Function to allow mark_dirty

Raised by https://github.com/pytorch/pytorch/actions/runs/3790618497

soulitzer added a commit that referenced this pull request Dec 28, 2022
ghstack-source-id: 700631c38f3c067eb2cf7dc4fd072d8bb494452b
Pull Request resolved: #91222
soulitzer added a commit that referenced this pull request Dec 28, 2022
ghstack-source-id: 9590b032ff4b1c8b075f9cd4c468c0c2cbdcae6e
Pull Request resolved: #91222
@soulitzer (Contributor Author) commented:

@pytorchbot merge -g

@pytorchmergebot (Collaborator) commented:

Merge started

Your change will be merged once all checks on your PR pass since you used the green (-g) flag (ETA: 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team


@facebook-github-bot deleted the gh/soulitzer/159/head branch June 8, 2023 18:47