fix inference mode / PyDispatcher / Functionalize interaction
ghstack-source-id: cf06aa9bd5ce15c2d6ebd36aea239514c5a38e96
Pull Request resolved: #103275
bdhirsh committed Jun 9, 2023
1 parent d89dd05 commit 6685a2a
Showing 3 changed files with 34 additions and 20 deletions.
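
For context, a minimal standalone repro of the failure this commit fixes, adapted from the regression test added below for https://github.com/pytorch/pytorch/issues/103132 (a sketch; the exact error and failing frame depend on the PyTorch build):

```python
import torch

class Repro(torch.nn.Module):
    def forward(self, param):
        return torch.matmul(param, param)

# A 3d tensor is needed to hit the failing path: the C++ matmul
# decomposition calls sizes(), which broke under inference mode
# combined with dynamic shapes.
inp = torch.randn(4, 4, 4, requires_grad=True)
model = torch.compile(Repro(), backend="aot_eager", dynamic=True)
with torch.inference_mode():
    model(inp)  # raised an error before this commit; succeeds after
```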
4 changes: 3 additions & 1 deletion c10/core/DispatchKeySet.cpp
@@ -52,7 +52,9 @@ constexpr DispatchKeySet math_dispatch_keyset = backend_dispatch_keyset |
     // where we would like to support composite implicit kernels but not
     // explicit kernels therefore we manually add the key to the
     // math_dispatch_keyset
-    DispatchKeySet{DispatchKey::NestedTensor};
+    DispatchKeySet{DispatchKey::NestedTensor} |
+    // Functionalize should always re-use CompositeImplicit decomps.
+    DispatchKeySet{DispatchKey::Functionalize};
 
 constexpr DispatchKeySet nested_dispatch_keyset =
     DispatchKeySet(
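One way to observe the effect of this keyset change (a sketch using the private helper `torch._C._dispatch_dump_table`, whose availability and output format are internal details and may differ across versions): a CompositeImplicitAutograd op such as `aten::matmul` should now show its Functionalize table entry resolving to the composite decomposition rather than a torchgen-generated kernel.

```python
import torch

# Dump the computed dispatch table for a CompositeImplicitAutograd op.
# With Functionalize included in math_dispatch_keyset, the Functionalize
# entry falls through to the CompositeImplicitAutograd decomposition.
print(torch._C._dispatch_dump_table("aten::matmul"))
```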
19 changes: 19 additions & 0 deletions test/dynamo/test_repros.py
@@ -3257,6 +3257,25 @@ def check_type(obj, types_or_checks):
         res = opt_check_type(torch.randn(4), [torch.Tensor])
         self.assertEqual(ref, res)
 
+    # Test for https://github.com/pytorch/pytorch/issues/103132
+    @torch._dynamo.config.patch("assume_static_by_default", False)
+    def test_inference_mode_dynamic_shapes(self):
+        class Repro(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, param):
+                z = torch.matmul(param, param)
+                return z
+
+        model = Repro()
+        # Need a 3d tensor to actually cause the error:
+        # we go down a path of the C++ matmul decomp that calls sizes().
+        inp = torch.randn(4, 4, 4, requires_grad=True)
+        model = torch.compile(model, backend="aot_eager", dynamic=True)
+        with torch.inference_mode():
+            model(inp)
+
     def test_kwargs_out_list_variable(self):
         class Repro(torch.nn.Module):
             def __init__(self):
31 changes: 12 additions & 19 deletions torchgen/gen_functionalization_type.py
@@ -689,22 +689,8 @@ def gen_functionalization_registration(
 ) -> List[str]:
     @with_native_function
     def emit_registration_helper(f: NativeFunction) -> str:
-        if f.has_composite_implicit_autograd_kernel:
-            metadata = composite_implicit_autograd_index.get_kernel(f)
-            assert metadata is not None
-            native_api_name = metadata.kernel
-            sig = NativeSignature(f.func, symint=metadata.supports_symint())
-            # Note [Composite view ops in the functionalization pass]
-            # We don't need to worry about implementing functionalization kernels for views with
-            # CompositeImplicitAutograd kernels, because we can just decompose them into their base operators.
-            # We can't just opt the entire Functionalization dispatch key into the composite keyset though,
-            # because we don't want to decompose non-view ops that are composite, like `at::ones`.
-            registration_str = (
-                f"static_cast<{sig.ptr_type()}>(at::native::{native_api_name})"
-            )
-        else:
-            # non-composite view ops (and inplace ops) get a normal registration.
-            registration_str = f"TORCH_FN(functionalization::{wrapper_name(f.func)})"
+        assert not f.has_composite_implicit_autograd_kernel
+        registration_str = f"TORCH_FN(functionalization::{wrapper_name(f.func)})"
         return f'm.impl("{f.func.name}", {registration_str});'
 
     # Don't generate kernels in mobile build
@@ -716,8 +702,13 @@ def emit_registration_helper(f: NativeFunction) -> str:
         # See Note [Functionalization <> torch.Tensor constructor]
         if str(g.view.func.name) == "lift_fresh":
             return []
-        view_str = [emit_registration_helper(g.view)]
-        if g.view_inplace is not None:
+        view_str = []
+        if not g.view.has_composite_implicit_autograd_kernel:
+            view_str.append(emit_registration_helper(g.view))
+        if (
+            g.view_inplace is not None
+            and not g.view_inplace.has_composite_implicit_autograd_kernel
+        ):
             assert g.view_inplace.is_view_op
             view_str.append(emit_registration_helper(g.view_inplace))
         return view_str
@@ -731,6 +722,8 @@ def emit_registration_helper(f: NativeFunction) -> str:
 
     registrations = []
     for f in fns:
+        if f.has_composite_implicit_autograd_kernel:
+            continue
         if str(f.func.name) == "lift":
             # See Note [Functionalization <> torch.Tensor constructor]
             return []
@@ -741,7 +734,7 @@ def emit_registration_helper(f: NativeFunction) -> str:
         # functionalization needs to generate and register kernels for inplace ops.
         # We *also* need to directly register CompositeImplicitAutograd kernels
         # so that they decompose properly before functionalization.
-        if modifies_arguments(f) or f.has_composite_implicit_autograd_kernel:
+        if modifies_arguments(f):
             registrations.append(emit_registration_helper(f))
     return registrations
 
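At the usage level, the net effect is that functionalization always reuses CompositeImplicit decompositions through the dispatcher alias instead of per-op generated registrations. A sketch using torch.func.functionalize (the choice of ops here is illustrative only):

```python
import torch
from torch.func import functionalize

def f(x):
    y = x.clone()
    y.add_(1)  # a local mutation for functionalize to rewrite
    # matmul is CompositeImplicitAutograd: under functionalization it
    # decomposes via the composite keyset, not a generated kernel.
    return torch.matmul(y, y)

out = functionalize(f)(torch.ones(3, 3))
```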
