Skip to content

Commit

Permalink
fix inference_mode with torch.compile (#101219)
Browse files Browse the repository at this point in the history
It looks like inference_mode wasn't playing well with functionalization.

If you run torch.compile on a function, and the inputs to the function are tensors created outside of inference mode, then we need to make sure that when we created functional tensor wrappers for those inputs during compilation, those functional wrappers properly mirror whether or not the original tensor is an inference tensor.

Hopefully fixes #101151

Pull Request resolved: #101219
Approved by: https://github.com/albanD, https://github.com/ezyang
  • Loading branch information
bdhirsh authored and pytorchmergebot committed May 19, 2023
1 parent 1fabee3 commit 11f7ae1
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 0 deletions.
4 changes: 4 additions & 0 deletions aten/src/ATen/FunctionalTensorWrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,10 @@ void FunctionalTensorWrapper::set_constructor_metadata() {
// TODO: metadata copying may not actually be necessary then
set_custom_sizes_strides(SizesStridesPolicy::CustomSizes);
set_custom_device(true);
// E.g. when running torch.compile under inference mode, we need to make sure that
// for any inputs that were created outside of inference mode (so they are not inference tensors),
// then the functional wrappers that we wrap them with should also not be inference tensors.
version_counter_ = value_.unsafeGetTensorImpl()->version_counter();
}

FunctionalTensorWrapper::FunctionalTensorWrapper(const Tensor& value)
Expand Down
11 changes: 11 additions & 0 deletions test/functorch/test_aotdispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -1821,6 +1821,17 @@ def f(x):
out = af(inp)
self.assertEqual(out, f(inp))

def test_inference_mode(self):
m = torch.nn.Linear(4, 4)
inp = torch.randn(4, 4)

aot_mod = aot_module(m, fw_compiler=nop)

with torch.inference_mode():
out_ref = m(inp)
out_test = aot_mod(inp)
self.assertEqual(out_ref, out_test)

def test_default_partitioner_saves_symints_not_tensors_for_bw(self):
"""
In this test, the important thing is that primals_1 is **only** needed in the backward
Expand Down

1 comment on commit 11f7ae1

@pytorchmergebot
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reverted #101219 on behalf of https://github.com/PaliC due to breaking inductor tests (comment)

Please sign in to comment.