Skip to content

Commit

Permalink
[dynamo] graph break on const dict KeyError
Browse files Browse the repository at this point in the history
ghstack-source-id: 0a98e95179b0b809528ecdbb6a39235beea34978
Pull Request resolved: #125882
  • Loading branch information
williamwen42 committed May 10, 2024
1 parent 23e71ff commit b2566f8
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 13 deletions.
17 changes: 15 additions & 2 deletions test/dynamo/test_repros.py
Original file line number Diff line number Diff line change
Expand Up @@ -3320,8 +3320,7 @@ def fn(x):
return y

x = {"a": torch.tensor([1]), "b": torch.tensor([1])}
# FIXME It should be KeyError here
self.assertRaises(torch._dynamo.exc.InternalTorchDynamoError, lambda: fn(x))
self.assertRaises(KeyError, lambda: fn(x))

def test_attached_attribute_in_dir(self):
class MyModule(torch.nn.Module):
Expand Down Expand Up @@ -4845,6 +4844,20 @@ def ladder(x):
opt_ladder = torch.compile(ladder, fullgraph=True, backend="eager")
self.assertEqual(opt_ladder(data), ladder(data))

def test_const_dict_keyerror(self):
d = {}

def fn(x):
try:
y = d[0]
except KeyError:
y = 1
return x + y

opt_fn = torch.compile(fn, backend="eager")
inp = torch.randn(3, 3)
self.assertEqual(fn(inp), opt_fn(inp))


instantiate_parametrized_tests(ReproTests)

Expand Down
Empty file.
22 changes: 12 additions & 10 deletions torch/_dynamo/variables/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -787,27 +787,29 @@ def call_self_handler(tx, args, kwargs):

def constant_fold_handler(tx, args, kwargs):
# fast path
return builder(
tx,
fn(
try:
res = fn(
*[x.as_python_constant() for x in args],
),
)
)
except Exception as exc:
unimplemented(f"constant fold exception: {repr(exc)}")
return builder(tx, res)

else:

def constant_fold_handler(tx, args, kwargs):
# path with a runtime check
if check_unspec_or_constant_args(args, kwargs):
return builder(
tx,
fn(
try:
res = fn(
*[x.as_python_constant() for x in args],
**{
k: v.as_python_constant() for k, v in kwargs.items()
},
),
)
)
except Exception as exc:
unimplemented(f"constant fold exception: {repr(exc)}")
return builder(tx, res)

handlers.append(constant_fold_handler)

Expand Down
2 changes: 1 addition & 1 deletion torch/_dynamo/variables/dicts.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ def reconstruct(self, codegen):
def getitem_const(self, arg: VariableTracker):
key = ConstDictVariable._HashableTracker(arg)
if key not in self.items:
raise KeyError(arg.value)
unimplemented(f"dict KeyError: {arg.value}")
return self.items[key]

def call_method(
Expand Down

0 comments on commit b2566f8

Please sign in to comment.