Skip to content

Commit

Permalink
[dynamo] graph break on const dict KeyError (#125882)
Browse files Browse the repository at this point in the history
Fixes #125866

Pull Request resolved: #125882
Approved by: https://github.com/jansel
  • Loading branch information
williamwen42 authored and pytorchmergebot committed May 15, 2024
1 parent b5432ad commit 100e3c1
Show file tree
Hide file tree
Showing 7 changed files with 28 additions and 13 deletions.
17 changes: 15 additions & 2 deletions test/dynamo/test_repros.py
Original file line number Diff line number Diff line change
Expand Up @@ -3346,8 +3346,7 @@ def fn(x):
return y

x = {"a": torch.tensor([1]), "b": torch.tensor([1])}
# FIXME It should be KeyError here
self.assertRaises(torch._dynamo.exc.InternalTorchDynamoError, lambda: fn(x))
self.assertRaises(KeyError, lambda: fn(x))

def test_attached_attribute_in_dir(self):
class MyModule(torch.nn.Module):
Expand Down Expand Up @@ -4949,6 +4948,20 @@ def f(a, tmp):
# grad state may not be properly reset after the error
self.assertTrue(torch.is_grad_enabled())

def test_const_dict_keyerror(self):
d = {}

def fn(x):
try:
y = d[0]
except KeyError:
y = 1
return x + y

opt_fn = torch.compile(fn, backend="eager")
inp = torch.randn(3, 3)
self.assertEqual(fn(inp), opt_fn(inp))


instantiate_parametrized_tests(ReproTests)

Expand Down
Empty file.
Empty file.
Empty file.
Empty file.
22 changes: 12 additions & 10 deletions torch/_dynamo/variables/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -787,27 +787,29 @@ def call_self_handler(tx, args, kwargs):

def constant_fold_handler(tx, args, kwargs):
# fast path
return builder(
tx,
fn(
try:
res = fn(
*[x.as_python_constant() for x in args],
),
)
)
except Exception as exc:
unimplemented(f"constant fold exception: {repr(exc)}")
return builder(tx, res)

else:

def constant_fold_handler(tx, args, kwargs):
# path with a runtime check
if check_unspec_or_constant_args(args, kwargs):
return builder(
tx,
fn(
try:
res = fn(
*[x.as_python_constant() for x in args],
**{
k: v.as_python_constant() for k, v in kwargs.items()
},
),
)
)
except Exception as exc:
unimplemented(f"constant fold exception: {repr(exc)}")
return builder(tx, res)

handlers.append(constant_fold_handler)

Expand Down
2 changes: 1 addition & 1 deletion torch/_dynamo/variables/dicts.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ def reconstruct(self, codegen):
def getitem_const(self, arg: VariableTracker):
key = ConstDictVariable._HashableTracker(arg)
if key not in self.items:
raise KeyError(arg.value)
unimplemented(f"dict KeyError: {arg.value}")
return self.items[key]

def call_method(
Expand Down

0 comments on commit 100e3c1

Please sign in to comment.