Skip to content

Commit

Permalink
[dynamo 3.11] implement 3.11 linetable (#96509)
Browse files Browse the repository at this point in the history
Pull Request resolved: #96509
Approved by: https://github.com/jansel
  • Loading branch information
williamwen42 authored and pytorchmergebot committed Mar 31, 2023
1 parent 14ef91c commit 089134b
Show file tree
Hide file tree
Showing 2 changed files with 160 additions and 19 deletions.
73 changes: 71 additions & 2 deletions test/dynamo/test_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -1983,8 +1983,77 @@ def fn(a, b):

self.assertTrue(same(ref, res))

@unittest.skipIf(sys.version_info < (3, 10), "use linetable when python >= 3.10")
def test_linetable_writer(self):
@unittest.skipIf(sys.version_info < (3, 11), "linetable test for Python 3.11")
def test_linetable_311_writer1(self):
def fn():
a = 10
b = 20
c = a + b
f = "linetable_writer"
return f"Test if {f} generates correct co_linetable: {c}"

# Dynamo doesn't deal with column locations or end line numbers,
# so we only check that start line numbers in the linetables match.
keys = bytecode_transformation.get_code_keys()
code_options = {k: getattr(fn.__code__, k) for k in keys}
result = bytecode_transformation.clean_and_assemble_instructions(
bytecode_transformation.cleaned_instructions(fn.__code__),
keys,
code_options,
)
l1, l2 = list(fn.__code__.co_positions()), list(result[1].co_positions())
self.assertEqual(len(l1), len(l2))
for p1, p2 in zip(l1, l2):
# check that start line numbers match
self.assertEqual(p1[0], p2[0])
self.assertEqual(fn.__code__.co_lnotab, result[1].co_lnotab)

@unittest.skipIf(sys.version_info < (3, 11), "linetable test for Python 3.11")
def test_linetable_311_writer2(self):
"""
test large ops (LOAD_METHOD) and EXTENDED_ARGS
fn_str is in the form:
def fn():
...
x0 = 1
x1 = 1
...
l = [x0, x1, ...]
"""
fn_str = f"""\
def fn():
foo.bar(1, 2, 3)
{str(chr(10)).join(' ' * 4 + 'x' + str(i) + ' = 1' for i in range(1 << 9))}
l = [{str(' ').join('x' + str(i) + ',' for i in range(1 << 9))}]
"""
locals = {}
exec(fn_str, {}, locals)
fn = locals["fn"]
orig_inst_str = "\n".join(list(map(str, dis.get_instructions(fn))))
self.assertIn("EXTENDED_ARG", orig_inst_str)
self.assertIn("LOAD_METHOD", orig_inst_str)
keys = bytecode_transformation.get_code_keys()
code_options = {k: getattr(fn.__code__, k) for k in keys}
result = bytecode_transformation.clean_and_assemble_instructions(
bytecode_transformation.cleaned_instructions(fn.__code__),
keys,
code_options,
)
new_inst_str = "\n".join(list(map(str, result[0])))
self.assertIn("EXTENDED_ARG", new_inst_str)
self.assertIn("LOAD_METHOD", new_inst_str)
l1, l2 = list(fn.__code__.co_positions()), list(result[1].co_positions())
self.assertEqual(len(l1), len(l2))
for p1, p2 in zip(l1, l2):
# check that start line numbers match
self.assertEqual(p1[0], p2[0])
self.assertEqual(fn.__code__.co_lnotab, result[1].co_lnotab)

@unittest.skipIf(
sys.version_info < (3, 10) or sys.version_info >= (3, 11),
"linetable test for Python 3.10",
)
def test_linetable_310_writer(self):
def fn():
a = 10
b = 20
Expand Down
106 changes: 89 additions & 17 deletions torch/_dynamo/bytecode_transformation.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,13 +201,13 @@ def update(lineno_new, byteno_new):
return lnotab, update


def linetable_writer(first_lineno):
def linetable_310_writer(first_lineno):
"""
Used to create typing.CodeType.co_linetable
See https://github.com/python/cpython/blob/main/Objects/lnotab_notes.txt
This is the internal format of the line number table if Python >= 3.10
This is the internal format of the line number table for Python 3.10
"""
assert sys.version_info >= (3, 10)
assert sys.version_info >= (3, 10) and sys.version_info < (3, 11)
linetable = []
lineno = first_lineno
lineno_delta = 0
Expand Down Expand Up @@ -236,25 +236,92 @@ def end(total_bytes):
return linetable, update, end


def encode_varint(n):
"""
6-bit chunk encoding of an unsigned integer
See https://github.com/python/cpython/blob/3.11/Objects/locations.md
"""
assert n >= 0
b = [n & 63]
n >>= 6
while n > 0:
b[-1] |= 64
b.append(n & 63)
n >>= 6
return b


def linetable_311_writer(first_lineno):
"""
Used to create typing.CodeType.co_linetable
See https://github.com/python/cpython/blob/3.11/Objects/locations.md
This is the internal format of the line number table for Python 3.11
"""
assert sys.version_info >= (3, 11)
linetable = []
lineno = first_lineno

def update(lineno_new, inst_size):
nonlocal lineno

def _update(delta, size):
assert 0 < size <= 8
# first byte - always use no column info code (13)
linetable.append(0b1_1101_000 + size - 1)
# encode signed int
if delta < 0:
delta = ((-delta) << 1) | 1
else:
delta <<= 1
# encode unsigned int
linetable.extend(encode_varint(delta))

if lineno_new is None:
lineno_delta = 0
else:
lineno_delta = lineno_new - lineno
lineno = lineno_new
while inst_size > 8:
_update(lineno_delta, 8)
inst_size -= 8
_update(lineno_delta, inst_size)

return linetable, update


def assemble(instructions: List[Instruction], firstlineno):
"""Do the opposite of dis.get_instructions()"""
code = []
if sys.version_info < (3, 10):
lnotab, update_lineno = lnotab_writer(firstlineno)
else:
lnotab, update_lineno, end = linetable_writer(firstlineno)

for inst in instructions:
if inst.starts_line is not None:
update_lineno(inst.starts_line, len(code))
arg = inst.arg or 0
code.extend((inst.opcode, arg & 0xFF))
if sys.version_info >= (3, 11):
if sys.version_info >= (3, 11):
lnotab, update_lineno = linetable_311_writer(firstlineno)
num_ext = 0
for inst in instructions:
if inst.opname == "EXTENDED_ARG":
inst_size = 1
num_ext += 1
else:
inst_size = instruction_size(inst) // 2 + num_ext
num_ext = 0
update_lineno(inst.starts_line, inst_size)
num_ext = 0
arg = inst.arg or 0
code.extend((inst.opcode, arg & 0xFF))
for _ in range(instruction_size(inst) // 2 - 1):
code.extend((0, 0))
else:
if sys.version_info < (3, 10):
lnotab, update_lineno = lnotab_writer(firstlineno)
else:
lnotab, update_lineno, end = linetable_310_writer(firstlineno)

if sys.version_info >= (3, 10):
end(len(code))
for inst in instructions:
if inst.starts_line is not None:
update_lineno(inst.starts_line, len(code))
arg = inst.arg or 0
code.extend((inst.opcode, arg & 0xFF))

if sys.version_info >= (3, 10):
end(len(code))

return bytes(code), bytes(lnotab)

Expand Down Expand Up @@ -566,7 +633,7 @@ def should_compute_arg():
assert instructions[i].arg >= 0


def transform_code_object(code, transformations, safe=False):
def get_code_keys():
# Python 3.11 changes to code keys are not fully documented.
# See https://github.com/python/cpython/blob/3.11/Objects/clinic/codeobject.c.h#L24
# for new format.
Expand Down Expand Up @@ -602,6 +669,11 @@ def transform_code_object(code, transformations, safe=False):
"co_cellvars",
]
)
return keys


def transform_code_object(code, transformations, safe=False):
keys = get_code_keys()
code_options = {k: getattr(code, k) for k in keys}
assert len(code_options["co_varnames"]) == code_options["co_nlocals"]

Expand Down

0 comments on commit 089134b

Please sign in to comment.