Skip to content

Commit

Permalink
[dynamo] graph break on const dict KeyError
Browse files Browse the repository at this point in the history
ghstack-source-id: 7104397e77939827b03c15592bba1a946db9c3ba
Pull Request resolved: #125882
  • Loading branch information
williamwen42 committed May 9, 2024
1 parent 23e71ff commit 4b2bd28
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 11 deletions.
14 changes: 14 additions & 0 deletions test/dynamo/test_repros.py
Original file line number Diff line number Diff line change
Expand Up @@ -4845,6 +4845,20 @@ def ladder(x):
opt_ladder = torch.compile(ladder, fullgraph=True, backend="eager")
self.assertEqual(opt_ladder(data), ladder(data))

def test_const_dict_keyerror(self):
d = {}

def fn(x):
try:
y = d[0]
except KeyError:
y = 1
return x + y

opt_fn = torch.compile(fn, backend="eager")
inp = torch.randn(3, 3)
self.assertEqual(fn(inp), opt_fn(inp))


instantiate_parametrized_tests(ReproTests)

Expand Down
22 changes: 12 additions & 10 deletions torch/_dynamo/variables/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -787,27 +787,29 @@ def call_self_handler(tx, args, kwargs):

def constant_fold_handler(tx, args, kwargs):
# fast path
return builder(
tx,
fn(
try:
res = fn(
*[x.as_python_constant() for x in args],
),
)
)
except Exception as exc:
unimplemented(f"constant fold exception: {repr(exc)}")
return builder(tx, res)

else:

def constant_fold_handler(tx, args, kwargs):
# path with a runtime check
if check_unspec_or_constant_args(args, kwargs):
return builder(
tx,
fn(
try:
res = fn(
*[x.as_python_constant() for x in args],
**{
k: v.as_python_constant() for k, v in kwargs.items()
},
),
)
)
except Exception as exc:
unimplemented(f"constant fold exception: {repr(exc)}")
return builder(tx, res)

handlers.append(constant_fold_handler)

Expand Down
2 changes: 1 addition & 1 deletion torch/_dynamo/variables/dicts.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ def reconstruct(self, codegen):
def getitem_const(self, arg: VariableTracker):
key = ConstDictVariable._HashableTracker(arg)
if key not in self.items:
raise KeyError(arg.value)
unimplemented(f"dict KeyError: {arg.value}")
return self.items[key]

def call_method(
Expand Down

0 comments on commit 4b2bd28

Please sign in to comment.