Skip to content

Commit

Permalink
[dynamo 3.11] support prefix instructions MAKE_CELL, COPY_FREE_VARS, RETURN_GENERATOR, RESUME (#96506)
Browse files (browse the repository at this point in the history)

Pull Request resolved: #96506
Approved by: https://github.com/jansel
  • Loading branch information
williamwen42 authored and pytorchmergebot committed Mar 31, 2023
1 parent 05641b8 commit cb4bc8e
Show file tree
Hide file tree
Showing 5 changed files with 66 additions and 3 deletions.
2 changes: 2 additions & 0 deletions torch/_dynamo/bytecode_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,8 @@ def walk(state, start):
state.reads.add(inst.argval)
elif "STORE" in inst.opname:
state.writes.add(inst.argval)
elif inst.opname == "MAKE_CELL":
pass
else:
raise NotImplementedError(f"unhandled {inst.opname}")
if inst.opcode in JUMP_OPCODES:
Expand Down
13 changes: 12 additions & 1 deletion torch/_dynamo/eval_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import contextlib
import dataclasses
import dis
import functools
import inspect
import logging
Expand Down Expand Up @@ -341,11 +342,21 @@ def __init__(self):
super().__init__(callback=None)


def first_real_inst_idx(code):
    """Return the instruction index of the first ``RESUME`` in ``code``.

    On interpreters older than 3.11 this is always 0.  On 3.11+ the code
    object is disassembled and the byte offset of the first ``RESUME``
    instruction is halved to turn it into an instruction index.

    Raises:
        RuntimeError: on 3.11+ when no ``RESUME`` instruction is present.
    """
    if sys.version_info >= (3, 11):
        resume_offsets = (
            instruction.offset
            for instruction in dis.get_instructions(code)
            if instruction.opname == "RESUME"
        )
        byte_offset = next(resume_offsets, None)
        if byte_offset is None:
            raise RuntimeError("RESUME instruction not found in code")
        # dis reports byte offsets; halve to get an instruction index.
        return byte_offset // 2
    return 0


def catch_errors_wrapper(callback, hooks: Hooks):
@functools.wraps(callback)
def catch_errors(frame, cache_size):
if (
frame.f_lasti >= 0
# TODO: the first condition is not covered by any test
frame.f_lasti >= first_real_inst_idx(frame.f_code)
or skipfiles.check(frame.f_code.co_filename)
or config.disable
):
Expand Down
25 changes: 25 additions & 0 deletions torch/_dynamo/output_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import logging
import operator
import re
import sys
import traceback
from dataclasses import dataclass
from typing import Any, Dict, List, NamedTuple, Optional, OrderedDict, Set, Union
Expand Down Expand Up @@ -524,6 +525,27 @@ def compile_subgraph(
if not all(block.can_restore() for block in tx.block_stack):
unimplemented("compile_subgraph with block_depth != 0")

prefix_insts: List[Instruction] = []
if sys.version_info >= (3, 11):
# prefix instructions (Python 3.11+)
for inst in tx.prefix_insts:
if inst.opname == "MAKE_CELL":
prefix_insts.append(
create_instruction("MAKE_CELL", argval=inst.argval)
)
elif inst.opname == "COPY_FREE_VARS":
prefix_insts.append(
create_instruction(
"COPY_FREE_VARS", len(tx.code_options["co_freevars"])
)
)
else:
prefix_insts.append(inst)

def append_prefix_insts():
self.add_output_instructions(prefix_insts)
prefix_insts.clear()

for block in reversed(tx.block_stack):
block.exit(tx)

Expand Down Expand Up @@ -551,6 +573,7 @@ def compile_subgraph(

# to handle random calls
if len(tx.random_calls) > 0:
append_prefix_insts()
random_calls_instructions = []
self.random_values_var = self.new_var("random_values")
rand_fn_name = unique_id("__gen_rand_values")
Expand Down Expand Up @@ -583,6 +606,7 @@ def compile_subgraph(
and len(set(stack_values)) == len(stack_values)
and self.side_effects.is_empty()
):
append_prefix_insts()
# optimization to generate better code in a common case
self.add_output_instructions(
self.compile_and_call_fx_graph(tx, list(reversed(stack_values)), root)
Expand Down Expand Up @@ -616,6 +640,7 @@ def compile_subgraph(
output.append(pass2.create_store(graph_output_var))
else:
output.append(create_instruction("POP_TOP"))
append_prefix_insts()
self.add_output_instructions(output + pass2.get_instructions())

# restore all the live local vars
Expand Down
5 changes: 5 additions & 0 deletions torch/_dynamo/resume_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,11 @@ def update(instructions: List[Instruction], code_options: Dict[str, Any]):
(target,) = [i for i in instructions if i.offset == offset]

prefix = []
if sys.version_info >= (3, 11):
if freevars:
prefix.append(create_instruction("COPY_FREE_VARS", len(freevars)))
prefix.append(create_instruction("RESUME", 0))

cleanup = []
hooks = {fn.stack_index: fn for fn in setup_fns}
null_idxes_i = 0
Expand Down
24 changes: 22 additions & 2 deletions torch/_dynamo/symbolic_convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -420,6 +420,8 @@ class InstructionTranslatorBase(Checkpointable[InstructionTranslatorGraphState])
lineno: int
mutated_closure_cell_contents: Set[str]
kw_names: Optional[ConstantVariable]
accept_prefix_inst: bool
prefix_insts: List[Instruction]

checkpoint: Optional[Tuple[Instruction, InstructionTranslatorGraphState]]
random_calls: List[
Expand Down Expand Up @@ -1502,9 +1504,12 @@ def LOAD_ASSERTION_ERROR(self, inst):
INPLACE_OR = stack_op(operator.ior)

# 3.11 opcodes
# note: passed opcodes are intentional
def RESUME(self, inst):
    """Handle the 3.11 ``RESUME`` opcode.

    A ``RESUME`` whose ``arg`` is 0 is recorded as a prefix instruction and
    then closes the prefix: ``accept_prefix_inst`` is set to False so no
    further prefix instructions are accepted.  Any other ``arg`` only
    asserts that the prefix has already been closed.
    """
    if inst.arg != 0:
        # A non-zero RESUME must not appear while the prefix is still open.
        assert not self.accept_prefix_inst
        return
    self.append_prefix_inst(inst)
    self.accept_prefix_inst = False

def BINARY_OP(self, inst):
if sys.version_info >= (3, 11):
Expand Down Expand Up @@ -1600,6 +1605,19 @@ def BEFORE_WITH(self, inst):
self.push(exit)
self.push(ctx.enter(self))

def append_prefix_inst(self, inst):
    """Record ``inst`` in ``self.prefix_insts``.

    Asserts that prefix instructions are still being accepted
    (``self.accept_prefix_inst`` is true) before appending.
    """
    assert self.accept_prefix_inst
    self.prefix_insts.append(inst)

def MAKE_CELL(self, inst):
    """Handle ``MAKE_CELL`` by recording it as a prefix instruction."""
    self.append_prefix_inst(inst)

def COPY_FREE_VARS(self, inst):
    """Handle ``COPY_FREE_VARS`` by recording it as a prefix instruction."""
    self.append_prefix_inst(inst)

def RETURN_GENERATOR(self, inst):
    """Handle ``RETURN_GENERATOR`` by recording it as a prefix instruction."""
    self.append_prefix_inst(inst)

def copy_graphstate(self) -> InstructionTranslatorGraphState:
"""Create a checkpoint of the current state by copying everything"""
return InstructionTranslatorGraphState(
Expand Down Expand Up @@ -1699,6 +1717,8 @@ def __init__(
self.block_stack = []
self.lineno = code_options["co_firstlineno"]
self.kw_names = None
self.accept_prefix_inst = True
self.prefix_insts = []

# Properties of the input/output code
self.instructions: List[Instruction] = instructions
Expand Down

0 comments on commit cb4bc8e

Please sign in to comment.