Skip to content

Commit

Permalink
Rewrite _create_code() with Structural Pattern Matching (limited to tuples) (#496)
Browse files Browse the repository at this point in the history

* rewrite _create_code() with Structural Pattern Matching (limited to tuples)

* lnotab and linetable: fixup

* pattern matching: optimizations

* pattern matching: more optimizations

* _create_code: tests

* match/case: don't check code members' types
  • Loading branch information
leogama committed Jul 3, 2022
1 parent 3881a2b commit dc35b66
Show file tree
Hide file tree
Showing 2 changed files with 171 additions and 150 deletions.
298 changes: 148 additions & 150 deletions dill/_dill.py
Original file line number Diff line number Diff line change
Expand Up @@ -721,160 +721,158 @@ def _create_function(fcode, fglobals, fname=None, fdefaults=None,
# assert id(fglobals) == id(func.__globals__)
return func

class match:
    """
    Make available a limited structural pattern matching-like syntax for Python < 3.10

    Patterns can be only tuples (without types) currently.
    Inspired by the package pattern-matching-PEP634.

    Usage:

    >>> with match(args) as m:
    >>>     if m.case(('x', 'y')):
    >>>         # use m.x and m.y
    >>>     elif m.case(('x', 'y', 'z')):
    >>>         # use m.x, m.y and m.z

    Equivalent native code for Python >= 3.10:

    >>> match args:
    >>>     case (x, y):
    >>>         # use x and y
    >>>     case (x, y, z):
    >>>         # use x, y and z

    Attributes:
        value: the sequence being matched.
        args: the names of the last pattern that matched, or None.
        fields: dict binding the names in ``args`` to the items of ``value``.
    """
    # Class-level defaults keep normal attribute lookup working even when
    # __init__ did not run, so __getattr__ can never recurse on these names.
    args = None
    _fields = None

    def __init__(self, value):
        self.value = value
        self.args = None
        self._fields = None

    def __enter__(self):
        return self

    def __exit__(self, *exc_info):
        # Never suppress an exception raised inside the 'with' body.
        return False

    def case(self, args): # *args, **kwargs):
        """just handles tuple patterns

        Return True and remember the names in *args* when the matched value
        has as many items as the pattern; return False otherwise.
        """
        if len(self.value) != len(args): # + len(kwargs):
            return False
        #if not all(isinstance(arg, pat) for arg, pat in zip(self.value[len(args):], kwargs.values())):
        #    return False
        self.args = args # (*args, *kwargs)
        # Bindings cached for a previously matched pattern are stale now.
        self._fields = None
        return True

    @property
    def fields(self):
        """Mapping of pattern names to matched values, built on first use.

        Raises RuntimeError if no pattern has matched yet.
        """
        # Only bind names to values if necessary.
        if self._fields is None:
            if self.args is None:
                raise RuntimeError("no pattern has been matched yet")
            self._fields = dict(zip(self.args, self.value))
        return self._fields

    def __getattr__(self, item):
        # Only reached when regular lookup fails.  Unknown names must raise
        # AttributeError (not KeyError) so hasattr(), getattr() with a default
        # and similar attribute-protocol users behave correctly.
        if self.args is None:
            raise AttributeError("'match' object has no attribute %r" % (item,))
        try:
            return self.fields[item]
        except KeyError:
            raise AttributeError("'match' object has no attribute %r" % (item,)) from None

# Known CodeType constructor layouts, listed newest first.  Each entry holds a
# version tag, the code-object attribute that is new in that layout, and the
# ordered, space-separated names of the constructor parameters.
ALL_CODE_PARAMS = [
    # Version     New attribute         CodeType parameters
    ((3,11,'a'), 'co_endlinetable',
        'argcount posonlyargcount kwonlyargcount nlocals stacksize flags '
        'code consts names varnames filename name qualname firstlineno '
        'linetable endlinetable columntable exceptiontable freevars cellvars'),
    ((3,11),     'co_exceptiontable',
        'argcount posonlyargcount kwonlyargcount nlocals stacksize flags '
        'code consts names varnames filename name qualname firstlineno '
        'linetable exceptiontable freevars cellvars'),
    ((3,10),     'co_linetable',
        'argcount posonlyargcount kwonlyargcount nlocals stacksize flags '
        'code consts names varnames filename name firstlineno '
        'linetable freevars cellvars'),
    ((3,8),      'co_posonlyargcount',
        'argcount posonlyargcount kwonlyargcount nlocals stacksize flags '
        'code consts names varnames filename name firstlineno '
        'lnotab freevars cellvars'),
    ((3,7),      'co_kwonlyargcount',
        'argcount kwonlyargcount nlocals stacksize flags '
        'code consts names varnames filename name firstlineno '
        'lnotab freevars cellvars'),
]

# The first layout whose marker attribute exists on CodeType describes the
# running interpreter.
for version, new_attr, params in ALL_CODE_PARAMS:
    if not hasattr(CodeType, new_attr):
        continue
    CODE_VERSION, CODE_PARAMS = version, params.split()
    break

# Parameters of the current layout that may need str -> bytes encoding.
ENCODE_PARAMS = {
    'code', 'lnotab', 'linetable', 'endlinetable', 'columntable', 'exceptiontable',
} & set(CODE_PARAMS)

def _create_code(*args):
    """Rebuild a code object from the members stored in a pickle.

    Args:
        *args: the CodeType members in one of the layouts listed in
            ALL_CODE_PARAMS (15, 16, 18 or 20 items).  They may be preceded
            by one extra non-int item holding co_lnotab, as stored by
            Python >= 3.10.

    Returns:
        A CodeType instance built for the running interpreter.  When the
        stored layout belongs to this interpreter (CODE_VERSION) the members
        are passed straight through; otherwise they are mapped by name onto
        CODE_PARAMS, with defaults for members the stored layout lacks.
        Members that expose an ``encode`` method (text) are encoded first.

    Raises:
        UnpicklingError: if the number of members matches no known layout.
    """
    if not isinstance(args[0], int): # co_lnotab stored from >= 3.10
        LNOTAB, *args = args
    else: # from < 3.10 (or pre-LNOTAB storage)
        LNOTAB = b''

    with match(args) as m:
        # Python 3.11/3.12a (18 members)
        if m.case((
            'argcount', 'posonlyargcount', 'kwonlyargcount', 'nlocals', 'stacksize', 'flags', # args[0:6]
            'code', 'consts', 'names', 'varnames', 'filename', 'name', 'qualname', 'firstlineno', # args[6:14]
            'linetable', 'exceptiontable', 'freevars', 'cellvars' # args[14:]
        )):
            if CODE_VERSION == (3,11):
                # Same layout as this interpreter: skip the name binding.
                return CodeType(
                    *args[:6],
                    args[6].encode() if hasattr(args[6], 'encode') else args[6], # code
                    *args[7:14],
                    args[14].encode() if hasattr(args[14], 'encode') else args[14], # linetable
                    args[15].encode() if hasattr(args[15], 'encode') else args[15], # exceptiontable
                    args[16],
                    args[17],
                )
            fields = m.fields
        # Python 3.10 or 3.8/3.9 (16 members)
        elif m.case((
            'argcount', 'posonlyargcount', 'kwonlyargcount', 'nlocals', 'stacksize', 'flags', # args[0:6]
            'code', 'consts', 'names', 'varnames', 'filename', 'name', 'firstlineno', # args[6:13]
            'LNOTAB_OR_LINETABLE', 'freevars', 'cellvars' # args[13:]
        )):
            if CODE_VERSION == (3,10) or CODE_VERSION == (3,8):
                # NOTE(review): this fast path passes args[13] through as-is
                # and never consults a separately stored LNOTAB -- confirm
                # that is intended when 3.8/3.9 and 3.10 exchange pickles.
                return CodeType(
                    *args[:6],
                    args[6].encode() if hasattr(args[6], 'encode') else args[6], # code
                    *args[7:13],
                    args[13].encode() if hasattr(args[13], 'encode') else args[13], # lnotab/linetable
                    args[14],
                    args[15],
                )
            fields = m.fields
            # The 14th member is ambiguous in this layout; file it under the
            # name that the running interpreter's layout uses.
            if CODE_VERSION >= (3,10):
                fields['linetable'] = m.LNOTAB_OR_LINETABLE
            else:
                fields['lnotab'] = LNOTAB if LNOTAB else m.LNOTAB_OR_LINETABLE
        # Python 3.7 (15 args)
        elif m.case((
            'argcount', 'kwonlyargcount', 'nlocals', 'stacksize', 'flags', # args[0:5]
            'code', 'consts', 'names', 'varnames', 'filename', 'name', 'firstlineno', # args[5:12]
            'lnotab', 'freevars', 'cellvars' # args[12:]
        )):
            if CODE_VERSION == (3,7):
                return CodeType(
                    *args[:5],
                    args[5].encode() if hasattr(args[5], 'encode') else args[5], # code
                    *args[6:12],
                    args[12].encode() if hasattr(args[12], 'encode') else args[12], # lnotab
                    args[13],
                    args[14],
                )
            fields = m.fields
        # Python 3.11a (20 members)
        elif m.case((
            'argcount', 'posonlyargcount', 'kwonlyargcount', 'nlocals', 'stacksize', 'flags', # args[0:6]
            'code', 'consts', 'names', 'varnames', 'filename', 'name', 'qualname', 'firstlineno', # args[6:14]
            'linetable', 'endlinetable', 'columntable', 'exceptiontable', 'freevars', 'cellvars' # args[14:]
        )):
            if CODE_VERSION == (3,11,'a'):
                return CodeType(
                    *args[:6],
                    args[6].encode() if hasattr(args[6], 'encode') else args[6], # code
                    *args[7:14],
                    *(a.encode() if hasattr(a, 'encode') else a for a in args[14:18]), # linetable-exceptiontable
                    args[18],
                    args[19],
                )
            fields = m.fields
        else:
            raise UnpicklingError("pattern match for code object failed")

    # The args format doesn't match this version.
    # Fill in every member that the stored layout did not provide.
    fields.setdefault('posonlyargcount', 0)       # from python <= 3.7
    fields.setdefault('lnotab', LNOTAB)           # from python >= 3.10
    fields.setdefault('linetable', b'')           # from python <= 3.9
    fields.setdefault('qualname', fields['name']) # from python <= 3.10
    fields.setdefault('exceptiontable', b'')      # from python <= 3.10
    fields.setdefault('endlinetable', None)       # from python != 3.11a
    fields.setdefault('columntable', None)        # from python != 3.11a

    # Pick the members in the order this interpreter's CodeType expects.
    args = (fields[k].encode() if k in ENCODE_PARAMS and hasattr(fields[k], 'encode') else fields[k]
            for k in CODE_PARAMS)
    return CodeType(*args)

def _create_ftype(ftypeobj, func, args, kwds):
if kwds is None:
Expand Down
23 changes: 23 additions & 0 deletions dill/tests/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,29 @@ def test_functions():
assert dill.loads(dumped_func_e)(1, 2, 3, e2=4) == 12
assert dill.loads(dumped_func_e)(1, 2, 3, e2=4, e3=5) == 15''')

def test_code_object():
    """Check that _create_code() accepts the argument layout of every known code format."""
    from dill._dill import ALL_CODE_PARAMS, CODE_PARAMS, CODE_VERSION, _create_code
    code = function_c.__code__
    LNOTAB = getattr(code, 'co_lnotab', b'')
    fields = dict((f, getattr(code, 'co_'+f)) for f in CODE_PARAMS)

    # Members missing from this interpreter's code objects get placeholders;
    # the trailing notes say which Python versions define the member natively.
    placeholders = [
        ('posonlyargcount', 0),           # python >= 3.8
        ('lnotab', LNOTAB),               # python <= 3.9
        ('linetable', b''),               # python >= 3.10
        ('qualname', fields['name']),     # python >= 3.11
        ('exceptiontable', b''),          # python >= 3.11
        ('endlinetable', None),           # python == 3.11a
        ('columntable', None),            # python == 3.11a
    ]
    for member, value in placeholders:
        fields.setdefault(member, value)

    for version, _, params in ALL_CODE_PARAMS:
        args = tuple(fields[p] for p in params.split())
        calls = [args]
        if version >= (3,10):
            # These formats are also tried with co_lnotab as a leading argument.
            calls.append((fields['lnotab'],) + args)
        try:
            for call_args in calls:
                _create_code(*call_args)
        except Exception as error:
            raise Exception("failed to construct code object with format version {}".format(version)) from error

if __name__ == '__main__':
    # Run this module's tests in order when the file is executed as a script.
    test_functions()
    test_issue_510()
    test_code_object()

0 comments on commit dc35b66

Please sign in to comment.