From 6b90f52091331d841c5da1b63045f9dfcff5cb6a Mon Sep 17 00:00:00 2001 From: Leonardo Gama Date: Fri, 15 Jul 2022 11:15:02 -0300 Subject: [PATCH] fix dump_module() bugs and rename parameter 'main' to 'module' (#526) * fix dump_module() bugs and rename parameter 'main' to 'module' (fixes #525) New phrasing of mismatching modules error messages in load_session(): ```python >>> import dill >>> dill.dump_module() >>> dill.load_module(module='math') ValueError: can't update module 'math' with the saved state of module '__main__' >>> import types >>> main = types.ModuleType('__main__') >>> dill.load_module(module=main) ValueError: can't update module-type object '__main__' with the saved state of imported module '__main__' >>> dill.dump_module(module=main) >>> dill.load_module(module='__main__') ValueError: can't update imported module '__main__' with the saved state of module-type object '__main__' ``` * dump_module: clarify refimport description * improvements to 'refimported' handling and extra checks in *_module() functions * load_session(): clarify that the 'module' argument must match the session file's module --- .gitignore | 2 +- dill/_dill.py | 188 ++++++++++++++++++++++--------------- dill/tests/test_session.py | 45 ++++++--- 3 files changed, 146 insertions(+), 89 deletions(-) diff --git a/.gitignore b/.gitignore index 9e136965..477f7cec 100644 --- a/.gitignore +++ b/.gitignore @@ -5,4 +5,4 @@ /docs/build /build /README -/dill/info.py \ No newline at end of file +/dill/__info__.py diff --git a/dill/_dill.py b/dill/_dill.py index e341357c..9642045e 100644 --- a/dill/_dill.py +++ b/dill/_dill.py @@ -321,19 +321,22 @@ def loads(str, ignore=None, **kwds): ### Pickle the Interpreter Session import pathlib +import re import tempfile +from types import SimpleNamespace -SESSION_IMPORTED_AS_TYPES = (ModuleType, ClassType, TypeType, Exception, - FunctionType, MethodType, BuiltinMethodType) TEMPDIR = pathlib.PurePath(tempfile.gettempdir()) def _module_map(): """get map of imported modules""" - from collections import defaultdict, namedtuple - modmap = namedtuple('Modmap', ['by_name', 'by_id', 'top_level']) - modmap = modmap(defaultdict(list), defaultdict(list), {}) + from collections import defaultdict + modmap = SimpleNamespace( + by_name=defaultdict(list), + by_id=defaultdict(list), + top_level={}, + ) for modname, module in sys.modules.items(): - if not isinstance(module, ModuleType): + if modname in ('__main__', '__mp_main__') or not isinstance(module, ModuleType): continue if '.' not in modname: modmap.top_level[id(module)] = modname @@ -342,12 +345,19 @@ def _module_map(): modmap.by_id[id(modobj)].append((modobj, objname, modname)) return modmap +SESSION_IMPORTED_AS_TYPES = (ModuleType, TypeType, FunctionType, MethodType, BuiltinMethodType) +SESSION_IMPORTED_AS_MODULES = ('ctypes', 'typing', 'subprocess', 'threading', + r'concurrent\.futures(\.\w+)?', r'multiprocessing(\.\w+)?') +SESSION_IMPORTED_AS_MODULES = tuple(re.compile(x) for x in SESSION_IMPORTED_AS_MODULES) + def _lookup_module(modmap, name, obj, main_module): """lookup name or id of obj if module is imported""" for modobj, modname in modmap.by_name[name]: if modobj is obj and sys.modules[modname] is not main_module: return modname, name - if isinstance(obj, SESSION_IMPORTED_AS_TYPES): + __module__ = getattr(obj, '__module__', None) + if isinstance(obj, SESSION_IMPORTED_AS_TYPES) or (__module__ is not None + and any(regex.fullmatch(__module__) for regex in SESSION_IMPORTED_AS_MODULES)): for modobj, objname, modname in modmap.by_id[id(obj)]: if sys.modules[modname] is not main_module: return modname, objname @@ -359,36 +369,38 @@ def _stash_modules(main_module): imported = [] imported_as = [] - imported_top_level = [] # keep separeted for backwards compatibility + imported_top_level = [] # keep separeted for backward compatibility original = {} for name, obj in main_module.__dict__.items(): if obj is main_module: original[name] = newmod # self-reference - continue - + elif obj is main_module.__dict__: + original[name] = newmod.__dict__ # Avoid incorrectly matching a singleton value in another package (ex.: __doc__). - if any(obj is singleton for singleton in (None, False, True)) or \ - isinstance(obj, ModuleType) and _is_builtin_module(obj): # always saved by ref + elif any(obj is singleton for singleton in (None, False, True)) \ + or isinstance(obj, ModuleType) and _is_builtin_module(obj): # always saved by ref original[name] = obj - continue - - source_module, objname = _lookup_module(modmap, name, obj, main_module) - if source_module: - if objname == name: - imported.append((source_module, name)) - else: - imported_as.append((source_module, objname, name)) else: - try: - imported_top_level.append((modmap.top_level[id(obj)], name)) - except KeyError: - original[name] = obj + source_module, objname = _lookup_module(modmap, name, obj, main_module) + if source_module is not None: + if objname == name: + imported.append((source_module, name)) + else: + imported_as.append((source_module, objname, name)) + else: + try: + imported_top_level.append((modmap.top_level[id(obj)], name)) + except KeyError: + original[name] = obj if len(original) < len(main_module.__dict__): newmod.__dict__.update(original) newmod.__dill_imported = imported newmod.__dill_imported_as = imported_as newmod.__dill_imported_top_level = imported_top_level + if getattr(newmod, '__loader__', None) is None and _is_imported_module(main_module): + # Trick _is_imported_module() to force saving as an imported module. + newmod.__loader__ = True # will be discarded by save_module() return newmod else: return main_module @@ -407,7 +419,7 @@ def _restore_modules(unpickler, main_module): #NOTE: 06/03/15 renamed main_module to main def dump_module( filename = str(TEMPDIR/'session.pkl'), - main: Optional[Union[ModuleType, str]] = None, + module: Union[ModuleType, str] = None, refimported: bool = False, **kwds ) -> None: @@ -420,29 +432,31 @@ def dump_module( Parameters: filename: a path-like object or a writable stream. - main: a module object or the name of an importable module. - refimported: if `True`, all objects imported into the module's - namespace are saved by reference. *Note:* this is similar but - independent from ``dill.settings[`byref`]``, as ``refimported`` - refers to all imported objects, while ``byref`` only affects - select objects. + module: a module object or the name of an importable module. If `None` + (the default), :py:mod:`__main__` is saved. + refimported: if `True`, all objects identified as having been imported + into the module's namespace are saved by reference. *Note:* this is + similar but independent from ``dill.settings[`byref`]``, as + ``refimported`` refers to virtually all imported objects, while + ``byref`` only affects select objects. **kwds: extra keyword arguments passed to :py:class:`Pickler()`. Raises: :py:exc:`PicklingError`: if pickling fails. Examples: + - Save current interpreter session state: >>> import dill - >>> squared = lambda x:x*x + >>> squared = lambda x: x*x >>> dill.dump_module() # save state of __main__ to /tmp/session.pkl - Save the state of an imported/importable module: >>> import dill >>> import pox - >>> pox.plus_one = lambda x:x+1 + >>> pox.plus_one = lambda x: x+1 >>> dill.dump_module('pox_session.pkl', main=pox) - Save the state of a non-importable, module-type object: @@ -468,24 +482,34 @@ def dump_module( >>> [foo.sin(x) for x in foo.values] [0.8414709848078965, 0.9092974268256817, 0.1411200080598672] - *Changed in version 0.3.6:* the function ``dump_session()`` was renamed to - ``dump_module()``. + *Changed in version 0.3.6:* Function ``dump_session()`` was renamed to + ``dump_module()``. Parameters ``main`` and ``byref`` were renamed to + ``module`` and ``refimported``, respectively. - *Changed in version 0.3.6:* the parameter ``byref`` was renamed to - ``refimported``. + Note: + Currently, ``dill.settings['byref']`` and ``dill.settings['recurse']`` + don't apply to this function.` """ - if 'byref' in kwds: - warnings.warn( - "The argument 'byref' has been renamed 'refimported'" - " to distinguish it from dill.settings['byref'].", - PendingDeprecationWarning - ) - if refimported: - raise TypeError("both 'refimported' and 'byref' were used") - refimported = kwds.pop('byref') + for old_par, par in [('main', 'module'), ('byref', 'refimported')]: + if old_par in kwds: + message = "The argument %r has been renamed %r" % (old_par, par) + if old_par == 'byref': + message += " to distinguish it from dill.settings['byref']" + warnings.warn(message + ".", PendingDeprecationWarning) + if locals()[par]: # the defaults are None and False + raise TypeError("both %r and %r arguments were used" % (par, old_par)) + refimported = kwds.pop('byref', refimported) + module = kwds.pop('main', module) + from .settings import settings protocol = settings['protocol'] - if main is None: main = _main_module + main = module + if main is None: + main = _main_module + elif isinstance(main, str): + main = _import_module(main) + if not isinstance(main, ModuleType): + raise TypeError("%r is not a module" % main) if hasattr(filename, 'write'): file = filename else: @@ -510,7 +534,7 @@ def dump_module( # Backward compatibility. def dump_session(filename=str(TEMPDIR/'session.pkl'), main=None, byref=False, **kwds): warnings.warn("dump_session() has been renamed dump_module()", PendingDeprecationWarning) - dump_module(filename, main, refimported=byref, **kwds) + dump_module(filename, module=main, refimported=byref, **kwds) dump_session.__doc__ = dump_module.__doc__ class _PeekableReader: @@ -574,7 +598,7 @@ def _identify_module(file, main=None): def load_module( filename = str(TEMPDIR/'session.pkl'), - main: Union[ModuleType, str] = None, + module: Union[ModuleType, str] = None, **kwds ) -> Optional[ModuleType]: """Update :py:mod:`__main__` or another module with the state from the @@ -592,7 +616,9 @@ def load_module( Parameters: filename: a path-like object or a readable stream. - main: a module object or the name of an importable module. + module: a module object or the name of an importable module, either of + which must match the name and kind (importable or module-type + object) of the session file's module. **kwds: extra keyword arguments passed to :py:class:`Unpickler()`. Raises: @@ -609,11 +635,11 @@ def load_module( - Save the state of some modules: >>> import dill - >>> squared = lambda x:x*x + >>> squared = lambda x: x*x >>> dill.dump_module() # save state of __main__ to /tmp/session.pkl >>> >>> import pox # an imported module - >>> pox.plus_one = lambda x:x+1 + >>> pox.plus_one = lambda x: x+1 >>> dill.dump_module('pox_session.pkl', main=pox) >>> >>> from types import ModuleType @@ -659,19 +685,27 @@ def load_module( >>> from types import ModuleType >>> foo = ModuleType('foo') >>> foo.values = ['a','b'] - >>> foo.sin = lambda x:x*x + >>> foo.sin = lambda x: x*x >>> dill.load_module('foo_session.pkl', main=foo) >>> [foo.sin(x) for x in foo.values] [0.8414709848078965, 0.9092974268256817, 0.1411200080598672] - *Changed in version 0.3.6:* the function ``load_session()`` was renamed to - ``load_module()``. + *Changed in version 0.3.6:* Function ``load_session()`` was renamed to + ``load_module()``. Parameter ``main`` was renamed to ``module``. See also: :py:func:`load_module_asdict` to load the contents of module saved with :py:func:`dump_module` into a dictionary. """ - main_arg = main + if 'main' in kwds: + warnings.warn( + "The argument 'main' has been renamed 'module'.", + PendingDeprecationWarning + ) + if module is not None: + raise TypeError("both 'module' and 'main' arguments were used") + module = kwds.pop('main') + main = module if hasattr(filename, 'read'): file = filename else: @@ -681,9 +715,9 @@ def load_module( #FIXME: dill.settings are disabled unpickler = Unpickler(file, **kwds) unpickler._session = True - pickle_main = _identify_module(file, main) # Resolve unpickler._main + pickle_main = _identify_module(file, main) if main is None and pickle_main is not None: main = pickle_main if isinstance(main, str): @@ -694,7 +728,7 @@ def load_module( main = _import_module(main) if main is not None: if not isinstance(main, ModuleType): - raise ValueError("%r is not a module" % main) + raise TypeError("%r is not a module" % main) unpickler._main = main else: main = unpickler._main @@ -705,28 +739,26 @@ def load_module( is_runtime_mod = pickle_main.startswith('__runtime__.') if is_runtime_mod: pickle_main = pickle_main.partition('.')[-1] + error_msg = "can't update{} module{} %r with the saved state of{} module{} %r" if is_runtime_mod and is_main_imported: raise ValueError( - "can't restore non-imported module %r into an imported one" - % pickle_main + error_msg.format(" imported", "", "", "-type object") + % (main.__name__, pickle_main) ) if not is_runtime_mod and not is_main_imported: raise ValueError( - "can't restore imported module %r into a non-imported one" - % pickle_main - ) - if main.__name__ != pickle_main: - raise ValueError( - "can't restore module %r into module %r" + error_msg.format("", "-type object", " imported", "") % (pickle_main, main.__name__) ) + if main.__name__ != pickle_main: + raise ValueError(error_msg.format("", "", "", "") % (main.__name__, pickle_main)) # This is for find_class() to be able to locate it. if not is_main_imported: runtime_main = '__runtime__.%s' % main.__name__ sys.modules[runtime_main] = main - module = unpickler.load() + loaded = unpickler.load() finally: if not hasattr(filename, 'read'): # if newly opened file file.close() @@ -734,15 +766,17 @@ def load_module( del sys.modules[runtime_main] except (KeyError, NameError): pass - assert module is main - _restore_modules(unpickler, module) - if not (module is _main_module or module is main_arg): - return module + assert loaded is main + _restore_modules(unpickler, main) + if main is _main_module or main is module: + return None + else: + return main # Backward compatibility. def load_session(filename=str(TEMPDIR/'session.pkl'), main=None, **kwds): warnings.warn("load_session() has been renamed load_module().", PendingDeprecationWarning) - load_module(filename, main, **kwds) + load_module(filename, module=main, **kwds) load_session.__doc__ = load_module.__doc__ def load_module_asdict( @@ -774,6 +808,7 @@ def load_module_asdict( Note: If ``update`` is True, the saved module may be imported then updated. + If imported, the loaded module remains unchanged as in the general case. Example: >>> import dill @@ -796,8 +831,8 @@ def load_module_asdict( >>> new_var in main # would be True if the option 'update' was set False """ - if 'main' in kwds: - raise TypeError("'main' is an invalid keyword argument for load_module_asdict()") + if 'module' in kwds: + raise TypeError("'module' is an invalid keyword argument for load_module_asdict()") if hasattr(filename, 'read'): file = filename else: @@ -815,7 +850,6 @@ def load_module_asdict( main.__builtins__ = __builtin__ sys.modules[main_name] = main load_module(file, **kwds) - main.__session__ = str(filename) finally: if not hasattr(filename, 'read'): # if newly opened file file.close() @@ -826,6 +860,7 @@ def load_module_asdict( sys.modules[main_name] = old_main except NameError: # failed before setting old_main pass + main.__session__ = str(filename) return main.__dict__ ### End: Pickle the Interpreter @@ -2410,6 +2445,7 @@ def save_capsule(pickler, obj): _incedental_reverse_typemap['PyCapsuleType'] = PyCapsuleType _reverse_typemap['PyCapsuleType'] = PyCapsuleType _incedental_types.add(PyCapsuleType) + SESSION_IMPORTED_AS_TYPES += (PyCapsuleType,) else: _testcapsule = None @@ -2428,7 +2464,7 @@ def pickles(obj,exact=False,safe=False,**kwds): """ if safe: exceptions = (Exception,) # RuntimeError, ValueError else: - exceptions = (TypeError, AssertionError, PicklingError, UnpicklingError) + exceptions = (TypeError, AssertionError, NotImplementedError, PicklingError, UnpicklingError) try: pik = copy(obj, **kwds) #FIXME: should check types match first, then check content if "exact" diff --git a/dill/tests/test_session.py b/dill/tests/test_session.py index 8f687934..9124802c 100644 --- a/dill/tests/test_session.py +++ b/dill/tests/test_session.py @@ -9,6 +9,7 @@ import os import sys import __main__ +from contextlib import suppress from io import BytesIO import dill @@ -27,7 +28,7 @@ def _error_line(error, obj, refimported): if __name__ == '__main__' and len(sys.argv) >= 3 and sys.argv[1] == '--child': # Test session loading in a fresh interpreter session. refimported = (sys.argv[2] == 'True') - dill.load_module(session_file % refimported) + dill.load_module(session_file % refimported, module='__main__') def test_modules(refimported): # FIXME: In this test setting with CPython 3.7, 'calendar' is not included @@ -111,10 +112,8 @@ def _clean_up_cache(module): cached = module.__cached__ if hasattr(module, '__cached__') else cached pycache = os.path.join(os.path.dirname(module.__file__), '__pycache__') for remove, file in [(os.remove, cached), (os.removedirs, pycache)]: - try: + with suppress(OSError): remove(file) - except OSError: - pass atexit.register(_clean_up_cache, local_mod) @@ -163,16 +162,14 @@ def test_session_main(refimported): error = sp.call([python, __file__, '--child', str(refimported)], shell=shell) if error: sys.exit(error) finally: - try: + with suppress(OSError): os.remove(session_file % refimported) - except OSError: - pass # Test session loading in the same session. session_buffer = BytesIO() dill.dump_module(session_buffer, refimported=refimported) session_buffer.seek(0) - dill.load_module(session_buffer) + dill.load_module(session_buffer, module='__main__') ns.backup['_test_objects'](__main__, ns.backup, refimported) def test_session_other(): @@ -183,13 +180,13 @@ def test_session_other(): dict_objects = [obj for obj in module.__dict__.keys() if not obj.startswith('__')] session_buffer = BytesIO() - dill.dump_module(session_buffer, main=module) + dill.dump_module(session_buffer, module) for obj in dict_objects: del module.__dict__[obj] session_buffer.seek(0) - dill.load_module(session_buffer) #, main=module) + dill.load_module(session_buffer, module) assert all(obj in module.__dict__ for obj in dict_objects) assert module.selfref is module @@ -210,12 +207,12 @@ def test_runtime_module(): # without imported objects in the namespace. It's a contrived example because # even dill can't be in it. This should work after fixing #462. session_buffer = BytesIO() - dill.dump_module(session_buffer, main=runtime, refimported=True) + dill.dump_module(session_buffer, module=runtime, refimported=True) session_dump = session_buffer.getvalue() # Pass a new runtime created module with the same name. runtime = ModuleType(modname) # empty - return_val = dill.load_module(BytesIO(session_dump), main=runtime) + return_val = dill.load_module(BytesIO(session_dump), module=runtime) assert return_val is None assert runtime.__name__ == modname assert runtime.x == 42 @@ -228,6 +225,29 @@ def test_runtime_module(): assert runtime.x == 42 assert runtime not in sys.modules.values() +def test_refimported_imported_as(): + import collections + import concurrent.futures + import types + import typing + mod = sys.modules['__test__'] = types.ModuleType('__test__') + dill.executor = concurrent.futures.ThreadPoolExecutor(max_workers=1) + mod.Dict = collections.UserDict # select by type + mod.AsyncCM = typing.AsyncContextManager # select by __module__ + mod.thread_exec = dill.executor # select by __module__ with regex + + session_buffer = BytesIO() + dill.dump_module(session_buffer, mod, refimported=True) + session_buffer.seek(0) + mod = dill.load(session_buffer) + del sys.modules['__test__'] + + assert set(mod.__dill_imported_as) == { + ('collections', 'UserDict', 'Dict'), + ('typing', 'AsyncContextManager', 'AsyncCM'), + ('dill', 'executor', 'thread_exec'), + } + def test_load_module_asdict(): with TestNamespace(): session_buffer = BytesIO() @@ -256,4 +276,5 @@ def test_load_module_asdict(): test_session_main(refimported=True) test_session_other() test_runtime_module() + test_refimported_imported_as() test_load_module_asdict()