Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[dynamo 3.11] fix push null timing in resume functions #96504

Closed
wants to merge 8 commits into from
12 changes: 7 additions & 5 deletions torch/_dynamo/resume_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,17 +267,19 @@ def update(instructions: List[Instruction], code_options: Dict[str, Any]):
prefix = []
cleanup = []
hooks = {fn.stack_index: fn for fn in setup_fns}
null_idxes_i = 0
for i in range(nstack):
while (
null_idxes_i < len(null_idxes)
and null_idxes[null_idxes_i] == i + null_idxes_i
):
prefix.append(create_instruction("PUSH_NULL"))
null_idxes_i += 1
prefix.append(create_instruction("LOAD_FAST", argval=f"___stack{i}"))
if i in hooks:
prefix.extend(hooks.pop(i)(code_options, cleanup))
assert not hooks

if sys.version_info >= (3, 11):
for idx in null_idxes:
prefix.append(create_instruction("PUSH_NULL"))
prefix.extend(create_rot_n(idx))

prefix.append(create_jump_absolute(target))

# because the line number table monotonically increases from co_firstlineno
Expand Down
12 changes: 10 additions & 2 deletions torch/_dynamo/symbolic_convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -1880,14 +1880,22 @@ def create_call_resume_at(self, inst):
# Python does not allow NULL to be passed as a function argument, so
# we remove NULLs from the stack and restore them in the
# prologue of the resume function

# sorted list of indices of nulls on the stack
null_idxes: List[int] = []
if sys.version_info >= (3, 11):
# find indices of NullVariables
for i, var in enumerate(self.stack):
if isinstance(var, NullVariable):
null_idxes.append(i)
# generate bytecode to pop the nulls
null_cnt = 0
for i, var in enumerate(reversed(self.stack)):
if isinstance(var, NullVariable):
for j in range(2, i + 2 - len(null_idxes)):
for j in range(2, i + 2 - null_cnt):
cg.append_output(create_instruction("SWAP", j))
null_idxes.append(i + 1)
cg.extend_output(cg.pop_null())
null_cnt += 1

# we popped all nulls from the stack at runtime,
# so we should not count NullVariables
Expand Down