Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix inference mode / PyDispatcher / Functionalize interaction #103275

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
4 changes: 3 additions & 1 deletion c10/core/DispatchKeySet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,9 @@ constexpr DispatchKeySet math_dispatch_keyset = backend_dispatch_keyset |
// where we would like to support composite implicit kernels but not
// explicit kernels therefore we manually add the key to the
// math_dispatch_keyset
DispatchKeySet{DispatchKey::NestedTensor};
DispatchKeySet{DispatchKey::NestedTensor} |
// Functionalize should always re-use CompositeImplicit decomps.
DispatchKeySet{DispatchKey::Functionalize};

constexpr DispatchKeySet nested_dispatch_keyset =
DispatchKeySet(
Expand Down
19 changes: 19 additions & 0 deletions test/dynamo/test_repros.py
Original file line number Diff line number Diff line change
Expand Up @@ -3287,6 +3287,25 @@ def check_type(obj, types_or_checks):
res = opt_check_type(torch.randn(4), [torch.Tensor])
self.assertEqual(ref, res)

# Test for https://github.com/pytorch/pytorch/issues/103132
@torch._dynamo.config.patch("assume_static_by_default", False)
def test_inference_mode_dynamic_shapes(self):
class Repro(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, param):
z = torch.matmul(param, param)
return z

model = Repro()
# Need a 3d tensor to actually cause the error:
# we go down a path of the C++ matmul decomp that calls sizes().
inp = torch.randn(4, 4, 4, requires_grad=True)
model = torch.compile(model, backend="aot_eager", dynamic=True)
with torch.inference_mode():
model(inp)

def test_kwargs_out_list_variable(self):
class Repro(torch.nn.Module):
def __init__(self):
Expand Down
31 changes: 12 additions & 19 deletions torchgen/gen_functionalization_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -687,22 +687,8 @@ def gen_functionalization_registration(
) -> List[str]:
@with_native_function
def emit_registration_helper(f: NativeFunction) -> str:
if f.has_composite_implicit_autograd_kernel:
metadata = composite_implicit_autograd_index.get_kernel(f)
assert metadata is not None
native_api_name = metadata.kernel
sig = NativeSignature(f.func, symint=metadata.supports_symint())
# Note [Composite view ops in the functionalization pass]
# We don't need to worry about implemententing functionalization kernels for views with
# CompositeImplicitAutograd kernels, because we can just decompose them into their base operators.
# We can't just opt the entire Functionalization dispatch key into the composite keyset though,
# because we don't want to decompose non-view ops that are composite, like `at::ones`.
registration_str = (
f"static_cast<{sig.ptr_type()}>(at::native::{native_api_name})"
)
else:
# non-composite view ops (and inplace ops) get a normal registration.
registration_str = f"TORCH_FN(functionalization::{wrapper_name(f.func)})"
assert not f.has_composite_implicit_autograd_kernel
registration_str = f"TORCH_FN(functionalization::{wrapper_name(f.func)})"
return f'm.impl("{f.func.name}", {registration_str});'

# Don't generate kernels in mobile build
Expand All @@ -714,8 +700,13 @@ def emit_registration_helper(f: NativeFunction) -> str:
# See Note [Functionalization <> torch.Tensor constructor]
if str(g.view.func.name) == "lift_fresh":
return []
view_str = [emit_registration_helper(g.view)]
if g.view_inplace is not None:
view_str = []
if not g.view.has_composite_implicit_autograd_kernel:
view_str.append(emit_registration_helper(g.view))
if (
g.view_inplace is not None
and not g.view_inplace.has_composite_implicit_autograd_kernel
):
assert g.view_inplace.is_view_op
view_str.append(emit_registration_helper(g.view_inplace))
return view_str
Expand All @@ -729,6 +720,8 @@ def emit_registration_helper(f: NativeFunction) -> str:

registrations = []
for f in fns:
if f.has_composite_implicit_autograd_kernel:
continue
if str(f.func.name) == "lift":
# See Note [Functionalization <> torch.Tensor constructor]
return []
Expand All @@ -739,7 +732,7 @@ def emit_registration_helper(f: NativeFunction) -> str:
# functionalization needs to generate and register kernals for inplace ops.
# We *also* need to directly register CompositeImplicitAUtograd kernels
# so that they decompose properly before functioanlization.
if modifies_arguments(f) or f.has_composite_implicit_autograd_kernel:
if modifies_arguments(f):
registrations.append(emit_registration_helper(f))
return registrations

Expand Down