Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[dynamo] graph break on const dict KeyError #125882

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
17 changes: 15 additions & 2 deletions test/dynamo/test_repros.py
Original file line number Diff line number Diff line change
Expand Up @@ -3320,8 +3320,7 @@ def fn(x):
return y

x = {"a": torch.tensor([1]), "b": torch.tensor([1])}
# FIXME It should be KeyError here
self.assertRaises(torch._dynamo.exc.InternalTorchDynamoError, lambda: fn(x))
self.assertRaises(KeyError, lambda: fn(x))

def test_attached_attribute_in_dir(self):
class MyModule(torch.nn.Module):
Expand Down Expand Up @@ -4845,6 +4844,20 @@ def ladder(x):
opt_ladder = torch.compile(ladder, fullgraph=True, backend="eager")
self.assertEqual(opt_ladder(data), ladder(data))

def test_const_dict_keyerror(self):
d = {}

def fn(x):
try:
y = d[0]
except KeyError:
y = 1
return x + y

opt_fn = torch.compile(fn, backend="eager")
inp = torch.randn(3, 3)
self.assertEqual(fn(inp), opt_fn(inp))


instantiate_parametrized_tests(ReproTests)

Expand Down
Empty file.
Empty file.
22 changes: 12 additions & 10 deletions torch/_dynamo/variables/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -787,27 +787,29 @@ def call_self_handler(tx, args, kwargs):

def constant_fold_handler(tx, args, kwargs):
# fast path
return builder(
tx,
fn(
try:
res = fn(
*[x.as_python_constant() for x in args],
),
)
)
except Exception as exc:
unimplemented(f"constant fold exception: {repr(exc)}")
Comment on lines +794 to +795
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What error is this catching? I worry this will be hiding bugs.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is catching an error on from a python call - fn would be a builtin function. Idea is that if we get a python error when trying to constant fold, we bubble up the error to skip the frame, then we evaluate the frame as normal, which would result in the same error. In the case of the repro, we would be catching a KeyError.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's make catch case more narrow, perhaps by checking the message and warning if it is somthing else.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that any possible exception raised by a function in _constant_fold_functions would have to be caught, e.g. KeyError, ZeroDivisionError, AttributeError, ValueError, etc. We could determine a list of allowed exceptions (by going through the builtin constant fold op list and manually checking for exceptions that could be raised), but if we miss any, then tracing any code like below would crash:

def fn(x):
    try:
        y = ... # builtin constant fold op with constants that raises an exception we missed
        # e.g. we forgot ZeroDivisionError for y = 1 / 0
    except Exception:
        y = 1
    return x + y

What kind of bugs do you think would be hidden? My intent here is that if running the builtin op results in an exception, we should skip the frame, then default evaluation will also raise the exception and run any exception handlers.

return builder(tx, res)

else:

def constant_fold_handler(tx, args, kwargs):
# path with a runtime check
if check_unspec_or_constant_args(args, kwargs):
return builder(
tx,
fn(
try:
res = fn(
*[x.as_python_constant() for x in args],
**{
k: v.as_python_constant() for k, v in kwargs.items()
},
),
)
)
except Exception as exc:
unimplemented(f"constant fold exception: {repr(exc)}")
return builder(tx, res)

handlers.append(constant_fold_handler)

Expand Down
2 changes: 1 addition & 1 deletion torch/_dynamo/variables/dicts.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ def reconstruct(self, codegen):
def getitem_const(self, arg: VariableTracker):
key = ConstDictVariable._HashableTracker(arg)
if key not in self.items:
raise KeyError(arg.value)
unimplemented(f"dict KeyError: {arg.value}")
return self.items[key]

def call_method(
Expand Down