
Conversation


@bdhirsh commented May 17, 2022

Confirmed that this fixes pytorch/functorch#806 locally.

The problem was that the torch.tensor() constructor needs to exclude both of the functorch FrontMode and BackMode dispatch keys until we call at::lift, since either key can potentially wrap its output in a TensorWrapper.
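As a toy illustration of the exclusion pattern (plain Python, not PyTorch's real dispatcher; the mode names and tuple-based "wrapping" here are made up for the sketch): construct the value with both mode keys excluded so nothing wraps intermediates, then re-include the keys and let a single lift-like call do all the wrapping.

```python
# Toy sketch of dispatch-key exclusion (NOT PyTorch internals): while a key
# is in EXCLUDED, that mode's wrapping is skipped; "lift" is modeled as the
# single point where wrapping is allowed to happen.
from contextlib import contextmanager

EXCLUDED = set()

@contextmanager
def exclude(*keys):
    added = [k for k in keys if k not in EXCLUDED]
    EXCLUDED.update(added)
    try:
        yield
    finally:
        EXCLUDED.difference_update(added)

def maybe_wrap(t, mode):
    # A mode only wraps when its key is not excluded.
    return t if mode in EXCLUDED else ("wrapped", mode, t)

def make_tensor(value, active_modes):
    with exclude(*active_modes):
        t = value                     # plain construction
        for mode in active_modes:
            t = maybe_wrap(t, mode)   # no-ops here: keys are excluded
        assert t == value             # nothing got wrapped prematurely
    # "lift" runs with the keys re-included: each mode wraps exactly once.
    for mode in active_modes:
        t = maybe_wrap(t, mode)
    return t

print(make_tensor(0.0, ["FrontMode", "BackMode"]))
# ('wrapped', 'BackMode', ('wrapped', 'FrontMode', 0.0))
```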

Separately, Richard also pointed out that we want at::lift() to be a primitive op w.r.t. functorch. Otherwise, if we have multiple layers of transforms like grad(grad(...)), the first layer will decompose at::lift() into a no-op, and the second layer won't see at::lift() at all (and won't perform a second level of wrapping).
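A minimal sketch of the layering argument (again toy Python, not functorch's actual machinery; the classes and level numbers are hypothetical): if lift stays primitive, every active transform level wraps the result; if the innermost level decomposes it into a no-op, the outer level never sees a lift call and the second wrap is lost.

```python
# Toy model of nested transform levels (hypothetical, not functorch's types).
class TensorWrapper:
    def __init__(self, value, level):
        self.value = value
        self.level = level

def lift_primitive(x, levels):
    # lift is primitive w.r.t. the transforms: each active level sees the
    # call and adds its own layer of wrapping, innermost level first.
    for level in levels:
        x = TensorWrapper(x, level)
    return x

def lift_decomposed(x, levels):
    # The innermost level decomposes lift into a no-op, so the outer levels
    # never see a lift call and no wrapping happens at all.
    return x

ok = lift_primitive(3.14, levels=[1, 2])    # grad(grad(...)): two levels
assert ok.level == 2 and ok.value.level == 1 and ok.value.value == 3.14

bad = lift_decomposed(3.14, levels=[1, 2])
assert bad == 3.14  # both layers of wrapping are missing
```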

I made at::lift() a CompositeExplicitAutograd op, but also needed to explicitly opt it out of autograd's view-tracking / error-checking logic.

The idea is that lift technically has incorrect alias info: the no-op kernel for lift just returns the input tensor instead of creating a fresh one. But we don't actually want autograd to care about lift, because lift should never be called on a requires_grad=True tensor.

Unfortunately, that still doesn't fix this extra test. I'll take a deeper look at why later:

    def test_tensor_ctor_inside_grad_nested(self, device):
        def foo(x):
            z = torch.tensor(0.)
            z.copy_(x)
            return z.sum()

        x = torch.tensor(3.14, device=device)
        functorch.grad(functorch.grad(foo))(x)

# fails with:
  File "/raid/hirsheybar/pytorch/functorch/functorch/_src/eager_transforms.py", line 1086, in wrapper
    flat_grad_input = _autograd_grad(flat_outputs, flat_diff_args, create_graph=True)
  File "/raid/hirsheybar/pytorch/functorch/functorch/_src/eager_transforms.py", line 104, in _autograd_grad
    grad_inputs = torch.autograd.grad(diff_outputs, inputs, grad_outputs,
  File "/raid/hirsheybar/pytorch/torch/autograd/__init__.py", line 276, in grad
    return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
NotImplementedError: Cannot access storage of TensorWrapper

Stack from ghstack:


facebook-github-bot commented May 17, 2022


❌ 1 New Failure

As of commit 10e0553 (more details on the Dr. CI page):

  • 1/1 failures introduced in this PR

🕵️ 1 new failure recognized by patterns

The following CI failures do not appear to be due to upstream breakages

See GitHub Actions build pull / linux-xenial-py3.7-gcc5.4 / test (backwards_compat, 1, 1, linux.2xlarge) (1/1)

Step: "Test"

2022-05-17T14:44:20.9996455Z processing existing schema:  text(__torch__.torch.classes.profiling.SourceRef _0) -> (str _0)
2022-05-17T14:44:20.9997743Z processing existing schema:  count(__torch__.torch.classes.profiling.InstructionStats _0) -> (int _0)
2022-05-17T14:44:20.9999210Z processing existing schema:  duration_ns(__torch__.torch.classes.profiling.InstructionStats _0) -> (int _0)
2022-05-17T14:44:21.0000559Z processing existing schema:  source(__torch__.torch.classes.profiling.SourceStats _0) -> (__torch__.torch.classes.profiling.SourceRef _0)
2022-05-17T14:44:21.0002522Z processing existing schema:  line_map(__torch__.torch.classes.profiling.SourceStats _0) -> (Dict(int, __torch__.torch.classes.profiling.InstructionStats) _0)
2022-05-17T14:44:21.0003826Z processing existing schema:  __init__(__torch__.torch.classes.profiling._ScriptProfile _0) -> (NoneType _0)
2022-05-17T14:44:21.0005166Z processing existing schema:  enable(__torch__.torch.classes.profiling._ScriptProfile _0) -> (NoneType _0)
2022-05-17T14:44:21.0006602Z processing existing schema:  disable(__torch__.torch.classes.profiling._ScriptProfile _0) -> (NoneType _0)
2022-05-17T14:44:21.0008569Z processing existing schema:  _dump_stats(__torch__.torch.classes.profiling._ScriptProfile _0) -> (__torch__.torch.classes.profiling.SourceStats[] _0)
2022-05-17T14:44:21.0010183Z processing existing schema:  __init__(__torch__.torch.classes.dist_rpc.WorkerInfo _0, str _1, int _2) -> (NoneType _0)
2022-05-17T14:44:21.0010427Z The PR is introducing backward incompatible changes to the operator library. Please contact PyTorch team to confirm whether this change is wanted or not. 
2022-05-17T14:44:21.0010439Z 
2022-05-17T14:44:21.0010856Z Broken ops: [
2022-05-17T14:44:21.0011035Z 	aten::to_sparse_csc(Tensor self) -> (Tensor)
2022-05-17T14:44:21.0011254Z 	aten::to_sparse_bsc(Tensor self, int[2] blocksize) -> (Tensor)
2022-05-17T14:44:21.0011317Z ]
2022-05-17T14:44:21.0971806Z + cleanup
2022-05-17T14:44:21.0971932Z + retcode=1
2022-05-17T14:44:21.0971997Z + set +x
2022-05-17T14:44:21.1002807Z ##[error]Process completed with exit code 1.
2022-05-17T14:44:21.1039804Z ##[group]Run pytorch/pytorch/.github/actions/get-workflow-job-id@master

This comment was automatically generated by Dr. CI (expand for details).


bdhirsh added a commit that referenced this pull request May 17, 2022
ghstack-source-id: ede7cbc
Pull Request resolved: #77650
@bdhirsh requested a review from zou3519 May 17, 2022 14:42

@albanD left a comment


autograd changes sound ok. I'll let Richard review the functorch side.


bdhirsh commented May 17, 2022

@pytorchbot merge on green

@github-actions

Hey @bdhirsh.
You've committed this PR, but it does not have both a 'release notes: ...' and 'topics: ...' label. Please add one of each to the PR. The 'release notes: ...' label should represent the part of PyTorch that this PR changes (fx, autograd, distributed, etc) and the 'topics: ...' label should represent the kind of PR it is (not user facing, new feature, bug fix, perf improvement, etc). The list of valid labels can be found here for the 'release notes: ...' and here for the 'topics: ...'.
For changes that are 'topic: not user facing' there is no need for a release notes label.

facebook-github-bot pushed a commit that referenced this pull request May 18, 2022
Summary:
Pull Request resolved: #77650

Approved by: https://github.com/zou3519

Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/cfc87cad02e47a3a62b4328454e825748c7be4fd

Reviewed By: b0noI

Differential Revision: D36451898

Pulled By: bdhirsh

fbshipit-source-id: 4fc62a40f275ea5a27f3f985927433224c62328a
@facebook-github-bot deleted the gh/bdhirsh/234/head branch May 21, 2022 14:17