[dynamo] Graph break properly on dict key error #125866

Closed
williamwen42 opened this issue May 9, 2024 · 0 comments
Labels: module: dynamo, oncall: pt2, triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

williamwen42 (Member) commented May 9, 2024

Found while debugging #93624.

Code:

import torch

d = {}

def fn(x):
    try:
        y = d[0]
    except KeyError:
        y = 1
    return x + y

opt_fn = torch.compile(fn, backend="eager")
opt_fn(torch.randn(3, 3))

Error:

Traceback (most recent call last):
  File "/data/users/williamwen/pytorch2/playground3.py", line 13, in <module>
    opt_fn(torch.randn(3, 3))
  File "/data/users/williamwen/pytorch2/torch/_dynamo/eval_frame.py", line 420, in _fn
    return fn(*args, **kwargs)
  File "/data/users/williamwen/pytorch2/torch/_dynamo/convert_frame.py", line 981, in catch_errors
    return callback(frame, cache_entry, hooks, frame_state, skip=1)
  File "/data/users/williamwen/pytorch2/torch/_dynamo/convert_frame.py", line 822, in _convert_frame
    result = inner_convert(
  File "/data/users/williamwen/pytorch2/torch/_dynamo/convert_frame.py", line 410, in _convert_frame_assert
    return _compile(
  File "/data/users/williamwen/pytorch2/torch/_utils_internal.py", line 70, in wrapper_function
    return function(*args, **kwargs)
  File "/data/users/williamwen/py310-env/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/data/users/williamwen/pytorch2/torch/_dynamo/convert_frame.py", line 732, in _compile
    raise InternalTorchDynamoError(str(e)).with_traceback(
  File "/data/users/williamwen/pytorch2/torch/_dynamo/convert_frame.py", line 703, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "/data/users/williamwen/pytorch2/torch/_dynamo/utils.py", line 273, in time_wrapper
    r = func(*args, **kwargs)
  File "/data/users/williamwen/pytorch2/torch/_dynamo/convert_frame.py", line 570, in compile_inner
    out_code = transform_code_object(code, transform)
  File "/data/users/williamwen/pytorch2/torch/_dynamo/bytecode_transformation.py", line 1167, in transform_code_object
    transformations(instructions, code_options)
  File "/data/users/williamwen/pytorch2/torch/_dynamo/convert_frame.py", line 172, in _fn
    return fn(*args, **kwargs)
  File "/data/users/williamwen/pytorch2/torch/_dynamo/convert_frame.py", line 517, in transform
    tracer.run()
  File "/data/users/williamwen/pytorch2/torch/_dynamo/symbolic_convert.py", line 2234, in run
    super().run()
  File "/data/users/williamwen/pytorch2/torch/_dynamo/symbolic_convert.py", line 884, in run
    while self.step():
  File "/data/users/williamwen/pytorch2/torch/_dynamo/symbolic_convert.py", line 799, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/data/users/williamwen/pytorch2/torch/_dynamo/symbolic_convert.py", line 494, in wrapper
    return inner_fn(self, inst)
  File "/data/users/williamwen/pytorch2/torch/_dynamo/symbolic_convert.py", line 231, in impl
    self.push(fn_var.call_function(self, self.popn(nargs), {}))
  File "/data/users/williamwen/pytorch2/torch/_dynamo/variables/builtin.py", line 946, in call_function
    return handler(tx, args, kwargs)
  File "/data/users/williamwen/pytorch2/torch/_dynamo/variables/builtin.py", line 830, in builtin_dipatch
    rv = fn(tx, args, kwargs)
  File "/data/users/williamwen/pytorch2/torch/_dynamo/variables/builtin.py", line 750, in call_self_handler
    result = self_handler(tx, *args, **kwargs)
  File "/data/users/williamwen/pytorch2/torch/_dynamo/variables/builtin.py", line 1346, in call_getitem
    return args[0].call_method(tx, "__getitem__", args[1:], kwargs)
  File "/data/users/williamwen/pytorch2/torch/_dynamo/variables/dicts.py", line 224, in call_method
    return self.getitem_const(args[0])
  File "/data/users/williamwen/pytorch2/torch/_dynamo/variables/dicts.py", line 200, in getitem_const
    raise KeyError(arg.value)
torch._dynamo.exc.InternalTorchDynamoError: 0

from user code:
   File "/data/users/williamwen/pytorch2/playground3.py", line 7, in fn
    y = d[0]

Expected behavior: Dynamo should graph break cleanly on the missing-key lookup d[0] instead of failing with an InternalTorchDynamoError.
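To make the difference concrete, here is a toy model in plain Python. It is a hypothetical sketch, not Dynamo's implementation and not the eventual fix: GraphBreak, InternalCompilerError, ToyConstDict, compile_frame and the trace_* functions are illustrative names invented for this example. The buggy lookup lets a raw KeyError escape the tracer, and the compile step wraps it as an internal error, mirroring "InternalTorchDynamoError: 0" in the traceback above. The fixed lookup instead raises a graph-break signal, so the frame falls back to running the user's try/except as ordinary Python.

```python
# Toy model (hypothetical, not Dynamo's real code) contrasting two ways a
# tracer can react to a missing key in a dict whose contents it knows.


class GraphBreak(Exception):
    """Hypothetical signal: 'cannot trace this, run the frame eagerly instead'."""


class InternalCompilerError(Exception):
    """Hypothetical stand-in for an internal error surfaced to the user."""


class ToyConstDict:
    """Toy traced dict whose contents are known at trace time."""

    def __init__(self, items):
        self.items = dict(items)

    def getitem_buggy(self, key):
        # Failure mode from the traceback: a raw KeyError escapes the tracer.
        return self.items[key]

    def getitem_fixed(self, key):
        # Desired behavior: turn the missing key into a graph break so the
        # user's own try/except KeyError runs under normal Python.
        if key not in self.items:
            raise GraphBreak(f"dict key {key!r} not found")
        return self.items[key]


def compile_frame(trace, eager_fn, *args):
    """Try the toy trace; fall back to eager on a graph break, wrap anything else."""
    try:
        return ("compiled", trace(*args))
    except GraphBreak:
        return ("eager", eager_fn(*args))
    except Exception as e:
        # str(KeyError(0)) is "0", matching the message in the report.
        raise InternalCompilerError(str(e)) from e


d = ToyConstDict({})


def eager_fn(x):
    # The user function from the repro, run as ordinary Python.
    try:
        y = d.items[0]
    except KeyError:
        y = 1
    return x + y


# Simplification: the toy "traces" only the lookup and does not model the
# user's try/except, so any exception from the lookup reaches compile_frame.
def trace_buggy(x):
    return x + d.getitem_buggy(0)


def trace_fixed(x):
    return x + d.getitem_fixed(0)


print(compile_frame(trace_fixed, eager_fn, 2.0))  # ('eager', 3.0)
try:
    compile_frame(trace_buggy, eager_fn, 2.0)
except InternalCompilerError as e:
    print("InternalCompilerError:", e)  # InternalCompilerError: 0
```

In this toy, the graph-break path gives up on compiling the whole frame; a real tracer would typically still compile the portion before the break, which this sketch does not model.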

cc @ezyang @msaroufim @bdhirsh @anijain2305 @chauhang @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng

williamwen42 added the triaged, oncall: pt2, and module: dynamo labels May 9, 2024
williamwen42 self-assigned this May 9, 2024
williamwen42 changed the title from "[dynamo] Graph break properly on dict key errror" to "[dynamo] Graph break properly on dict key error" May 9, 2024
ZelboK pushed a commit to ZelboK/pytorch that referenced this issue May 19, 2024