Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[dynamo 3.11] support prefix instructions MAKE_CELL, COPY_FREE_VARS, RETURN_GENERATOR, RESUME #96506

Closed
wants to merge 10 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 2 additions & 0 deletions torch/_dynamo/bytecode_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,8 @@ def walk(state, start):
state.reads.add(inst.argval)
elif "STORE" in inst.opname:
state.writes.add(inst.argval)
elif inst.opname == "MAKE_CELL":
pass
else:
raise NotImplementedError(f"unhandled {inst.opname}")
if inst.opcode in JUMP_OPCODES:
Expand Down
13 changes: 12 additions & 1 deletion torch/_dynamo/eval_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import contextlib
import dataclasses
import dis
import functools
import inspect
import logging
Expand Down Expand Up @@ -341,11 +342,21 @@ def __init__(self):
super().__init__(callback=None)


def first_real_inst_idx(code):
if sys.version_info < (3, 11):
return 0
for inst in dis.get_instructions(code):
if inst.opname == "RESUME":
return inst.offset // 2
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the `// 2` OK here? Might some opcodes not be 2 bytes in size?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We follow CPython's implementation of `_co_firsttraceable` (codeobject.c), which uses a bytecode index (in 2-byte units), not an instruction index.

raise RuntimeError("RESUME instruction not found in code")

williamwen42 marked this conversation as resolved.
Show resolved Hide resolved

def catch_errors_wrapper(callback, hooks: Hooks):
@functools.wraps(callback)
def catch_errors(frame, cache_size):
if (
frame.f_lasti >= 0
# TODO: the first condition is not covered by any test
frame.f_lasti >= first_real_inst_idx(frame.f_code)
or skipfiles.check(frame.f_code.co_filename)
or config.disable
):
Expand Down
25 changes: 25 additions & 0 deletions torch/_dynamo/output_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import logging
import operator
import re
import sys
import traceback
from dataclasses import dataclass
from typing import Any, Dict, List, NamedTuple, Optional, OrderedDict, Set, Union
Expand Down Expand Up @@ -517,6 +518,27 @@ def compile_subgraph(
if not all(block.can_restore() for block in tx.block_stack):
unimplemented("compile_subgraph with block_depth != 0")

prefix_insts: List[Instruction] = []
if sys.version_info >= (3, 11):
# prefix instructions (Python 3.11+)
for inst in tx.prefix_insts:
if inst.opname == "MAKE_CELL":
prefix_insts.append(
create_instruction("MAKE_CELL", argval=inst.argval)
)
elif inst.opname == "COPY_FREE_VARS":
prefix_insts.append(
create_instruction(
"COPY_FREE_VARS", len(tx.code_options["co_freevars"])
)
)
else:
prefix_insts.append(inst)

def append_prefix_insts():
self.add_output_instructions(prefix_insts)
prefix_insts.clear()

for block in reversed(tx.block_stack):
block.exit(tx)

Expand Down Expand Up @@ -544,6 +566,7 @@ def compile_subgraph(

# to handle random calls
if len(tx.random_calls) > 0:
append_prefix_insts()
random_calls_instructions = []
self.random_values_var = self.new_var("random_values")
rand_fn_name = unique_id("__gen_rand_values")
Expand Down Expand Up @@ -576,6 +599,7 @@ def compile_subgraph(
and len(set(stack_values)) == len(stack_values)
and self.side_effects.is_empty()
):
append_prefix_insts()
# optimization to generate better code in a common case
self.add_output_instructions(
self.compile_and_call_fx_graph(tx, list(reversed(stack_values)), root)
Expand Down Expand Up @@ -609,6 +633,7 @@ def compile_subgraph(
output.append(pass2.create_store(graph_output_var))
else:
output.append(create_instruction("POP_TOP"))
append_prefix_insts()
self.add_output_instructions(output + pass2.get_instructions())

# restore all the live local vars
Expand Down
5 changes: 5 additions & 0 deletions torch/_dynamo/resume_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,11 @@ def update(instructions: List[Instruction], code_options: Dict[str, Any]):
(target,) = [i for i in instructions if i.offset == offset]

prefix = []
if sys.version_info >= (3, 11):
if freevars:
prefix.append(create_instruction("COPY_FREE_VARS", len(freevars)))
prefix.append(create_instruction("RESUME", 0))

cleanup = []
hooks = {fn.stack_index: fn for fn in setup_fns}
null_idxes_i = 0
Expand Down
24 changes: 22 additions & 2 deletions torch/_dynamo/symbolic_convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -420,6 +420,8 @@ class InstructionTranslatorBase(Checkpointable[InstructionTranslatorGraphState])
lineno: int
mutated_closure_cell_contents: Set[str]
kw_names: Optional[ConstantVariable]
accept_prefix_inst: bool
prefix_insts: List[Instruction]

checkpoint: Optional[Tuple[Instruction, InstructionTranslatorGraphState]]
random_calls: List[
Expand Down Expand Up @@ -1502,9 +1504,12 @@ def LOAD_ASSERTION_ERROR(self, inst):
INPLACE_OR = stack_op(operator.ior)

# 3.11 opcodes
# note: passed opcodes are intentional
def RESUME(self, inst):
    """Handle the Python 3.11 ``RESUME`` instruction.

    When ``inst.arg`` is 0, the instruction is recorded through
    ``append_prefix_inst`` and ``accept_prefix_inst`` is then cleared, so
    no further prefix instructions are accepted afterwards. For any other
    ``arg`` nothing is recorded; the method only asserts that prefix
    collection has already been closed.
    """
    if inst.arg == 0:
        # Record the instruction before clearing the flag, so the call
        # happens while accept_prefix_inst is still set.
        self.append_prefix_inst(inst)
        self.accept_prefix_inst = False
    else:
        assert not self.accept_prefix_inst

def BINARY_OP(self, inst):
if sys.version_info >= (3, 11):
Expand Down Expand Up @@ -1600,6 +1605,19 @@ def BEFORE_WITH(self, inst):
self.push(exit)
self.push(ctx.enter(self))

def append_prefix_inst(self, inst):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need to enforce the ordering of the instructions here? Or should it always be OK, since we add them as we read the original frame?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We add instructions in the order they are read from the original frame. Any instructions that jump necessarily come after the prefix, and they shouldn't jump to a target before the prefix.

assert self.accept_prefix_inst
self.prefix_insts.append(inst)

def MAKE_CELL(self, inst):
    """Forward a ``MAKE_CELL`` instruction to ``append_prefix_inst``."""
    self.append_prefix_inst(inst)

def COPY_FREE_VARS(self, inst):
    """Forward a ``COPY_FREE_VARS`` instruction to ``append_prefix_inst``."""
    self.append_prefix_inst(inst)

def RETURN_GENERATOR(self, inst):
    """Forward a ``RETURN_GENERATOR`` instruction to ``append_prefix_inst``."""
    self.append_prefix_inst(inst)

def copy_graphstate(self) -> InstructionTranslatorGraphState:
"""Create a checkpoint of the current state by copying everything"""
return InstructionTranslatorGraphState(
Expand Down Expand Up @@ -1699,6 +1717,8 @@ def __init__(
self.block_stack = []
self.lineno = code_options["co_firstlineno"]
self.kw_names = None
self.accept_prefix_inst = True
self.prefix_insts = []

# Properties of the input/output code
self.instructions: List[Instruction] = instructions
Expand Down