fix inference mode / PyDispatcher / Functionalize interaction
ghstack-source-id: 5b7cf0d317ce5dcbbf93c7a76aa8dcfef9f38c21
Pull Request resolved: #103275
bdhirsh committed Jun 8, 2023
1 parent 605a852 commit 9bd4949
Showing 2 changed files with 27 additions and 0 deletions.
19 changes: 19 additions & 0 deletions test/dynamo/test_repros.py
@@ -3257,6 +3257,25 @@ def check_type(obj, types_or_checks):
        res = opt_check_type(torch.randn(4), [torch.Tensor])
        self.assertEqual(ref, res)

    # Test for https://github.com/pytorch/pytorch/issues/103132
    @torch._dynamo.config.patch("assume_static_by_default", False)
    def test_inference_mode_dynamic_shapes(self):
        class Repro(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, param):
                z = torch.matmul(param, param)
                return z

        model = Repro()
        # Need a 3d tensor to actually cause the error:
        # we go down a path of the C++ matmul decomp that calls sizes().
        inp = torch.randn(4, 4, 4, requires_grad=True)
        model = torch.compile(model, backend="aot_eager", dynamic=True)
        with torch.inference_mode():
            model(inp)

    def test_kwargs_out_list_variable(self):
        class Repro(torch.nn.Module):
            def __init__(self):
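For context, the failure in #103132 can also be reproduced outside the test suite. A minimal standalone sketch mirroring the test above (the function name `f` is illustrative, and `dynamic=True` stands in for the `assume_static_by_default` config patch used in the test):

```python
import torch

def f(param):
    return torch.matmul(param, param)

# A 3d tensor is needed to hit the path of the C++ matmul decomp that
# calls sizes(), which is what used to fail under inference mode.
inp = torch.randn(4, 4, 4, requires_grad=True)
compiled = torch.compile(f, backend="aot_eager", dynamic=True)

with torch.inference_mode():
    out = compiled(inp)  # raised before this fix; runs cleanly after
print(out.shape)  # torch.Size([4, 4, 4])
```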
8 changes: 8 additions & 0 deletions torch/_ops.py
@@ -122,6 +122,14 @@ def inner(fn):
                    f"Trying to override a python impl for {k} on operator {self.name()}"
                )
            self.py_kernels[k] = fn
            if (
                k == torch._C.DispatchKey.CompositeImplicitAutograd
                and torch._C.DispatchKey.Functionalize not in self.py_kernels
            ):
                # Functionalization codegen is sneaky: it registers the C++
                # CompositeImplicit kernels *directly* to the Functionalize key.
                # If a CompositeImplicit decomp is registered from Python, we want
                # functionalization to use it instead of the C++ decomp, but the
                # dispatcher won't do that on its own, because Functionalize isn't
                # part of the CompositeImplicitAutograd alias set.
                # (Open question: will we eventually need to do this for functorch
                # transform keys too?)
                self.py_kernels[torch._C.DispatchKey.Functionalize] = fn
            self._dispatch_cache.clear()
            return fn

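To illustrate the effect of the change above: once a Python kernel is registered at the `CompositeImplicitAutograd` key via `py_impl`, it is now mirrored onto the `Functionalize` key, so functionalization picks up the Python decomp instead of the C++ composite kernel. A rough sketch (the `matmul_decomp` kernel is hypothetical, handles only two shape cases, and assumes no Python impl is already registered for this key, since `py_impl` raises on double registration):

```python
import torch

aten = torch.ops.aten

# Hypothetical Python decomposition for aten::matmul, registered through
# the PyDispatcher at the CompositeImplicitAutograd key.
@aten.matmul.default.py_impl(torch._C.DispatchKey.CompositeImplicitAutograd)
def matmul_decomp(a, b):
    # Toy decomp: only the 2d and batched 3d cases.
    return torch.mm(a, b) if a.dim() == 2 else torch.bmm(a, b)

# With this patch, the same kernel is also visible at the Functionalize key.
assert torch._C.DispatchKey.Functionalize in aten.matmul.default.py_kernels
```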
