Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[dynamo, 3.12] fix positions and offsets of added instructions when we clean #123991

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 26 additions & 1 deletion test/dynamo/test_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -10109,7 +10109,7 @@ def fn(inp):
opt_fn = torch.compile(fn, backend="eager")
opt_fn(inp)

def test_312_binary_slice_with_graph_break(self):
def test_312_binary_slice_with_graph_break1(self):
l1 = torch.nn.Linear(5, 5)
l2 = torch.nn.Linear(5, 5)

Expand All @@ -10122,6 +10122,31 @@ def fn(x):
opt_fn = torch.compile(fn, backend="eager")
opt_fn(torch.randn(5, 5))

def test_312_binary_slice_with_graph_break2(self):
class Foo:
def __setitem__(self, key, val):
pass

def __getitem__(self, key):
torch._dynamo.graph_break()
return 1

foo = Foo()

def fn(x):
# graph break in a STORE_SLICE instruction
foo[:] = x
# graph break in BINARY_SLICE with has_backedge check
x = x + foo[:]
if x is None:
x = x + 1
else:
x = x + 1
return x

opt_fn = torch.compile(fn, backend="eager")
opt_fn(torch.randn(5, 5))

def test_super_after_graph_break(self):
class Foo(torch.nn.Sequential):
def __init__(self, layers):
Expand Down
7 changes: 5 additions & 2 deletions torch/_dynamo/bytecode_transformation.py
Original file line number Diff line number Diff line change
Expand Up @@ -804,6 +804,7 @@ def remove_jump_if_none(instructions: List[Instruction]) -> None:
if "_NONE" in inst.opname:
is_op = create_instruction("IS_OP", arg=int("NOT" in inst.opname))
is_op.argval = is_op.arg
is_op.positions = inst.positions
if sys.version_info < (3, 12):
jump_op = create_instruction(
"POP_JUMP_FORWARD_IF_TRUE"
Expand All @@ -813,6 +814,7 @@ def remove_jump_if_none(instructions: List[Instruction]) -> None:
)
else:
jump_op = create_instruction("POP_JUMP_IF_TRUE", target=inst.target)
jump_op.positions = inst.positions
# update inst.exn_tab_entry.end if necessary
if inst.exn_tab_entry and inst.exn_tab_entry.end is inst:
inst.exn_tab_entry.end = jump_op
Expand All @@ -838,6 +840,7 @@ def remove_binary_store_slice(instructions: List[Instruction]) -> None:
if inst.exn_tab_entry and inst.exn_tab_entry.end is inst:
inst.exn_tab_entry.end = subscr_inst
subscr_inst.exn_tab_entry = copy.copy(inst.exn_tab_entry)
subscr_inst.positions = inst.positions
# modify inst in-place to preserve jump target
inst.opcode = dis.opmap["BUILD_SLICE"]
inst.opname = "BUILD_SLICE"
Expand Down Expand Up @@ -1176,10 +1179,10 @@ def cleaned_instructions(code, safe=False) -> List[Instruction]:
explicit_super(code, instructions)
if sys.version_info >= (3, 11):
remove_jump_if_none(instructions)
if sys.version_info >= (3, 12):
remove_binary_store_slice(instructions)
update_offsets(instructions)
devirtualize_jumps(instructions)
if sys.version_info >= (3, 12):
remove_binary_store_slice(instructions)
return instructions


Expand Down
Loading