diff --git a/pyiron_workflow/mixin/storage.py b/pyiron_workflow/mixin/storage.py
index b44229b29..24d0abcaf 100644
--- a/pyiron_workflow/mixin/storage.py
+++ b/pyiron_workflow/mixin/storage.py
@@ -12,6 +12,7 @@
 import sys
 from typing import Optional
 
+import cloudpickle
 import h5io
 from pyiron_snippets.files import FileObject, DirectoryObject
 
@@ -88,6 +89,7 @@ def _delete(self):
 
 class PickleStorage(StorageInterface):
     _PICKLE_STORAGE_FILE_NAME = "pickle.pckl"
+    _CLOUDPICKLE_STORAGE_FILE_NAME = "cloudpickle.cpckl"
 
     def __init__(self, owner: HasPickleStorage):
         super().__init__(owner=owner)
@@ -97,12 +99,27 @@ def owner(self) -> HasPickleStorage:
         return self._owner
 
     def _save(self):
-        with open(self._pickle_storage_file_path, "wb") as file:
-            pickle.dump(self.owner, file)
+        try:
+            with open(self._pickle_storage_file_path, "wb") as file:
+                pickle.dump(self.owner, file)
+        except Exception:
+            # Fall back on cloudpickle, which can serialize objects (e.g.
+            # lambdas, locally defined classes) that plain pickle cannot
+            self._delete()
+            with open(self._cloudpickle_storage_file_path, "wb") as file:
+                cloudpickle.dump(self.owner, file)
 
     def _load(self):
-        with open(self._pickle_storage_file_path, "rb") as file:
-            inst = pickle.load(file)
+        if self._has_pickle_contents:
+            with open(self._pickle_storage_file_path, "rb") as file:
+                inst = pickle.load(file)
+        elif self._has_cloudpickle_contents:
+            with open(self._cloudpickle_storage_file_path, "rb") as file:
+                inst = cloudpickle.load(file)
+        else:
+            # Guard against an UnboundLocalError on `inst` below
+            raise FileNotFoundError(f"{self.owner.label} has no saved content")
+
         if inst.__class__ != self.owner.__class__:
             raise TypeError(
                 f"{self.owner.label} cannot load, as it has type "
@@ -111,23 +128,37 @@
             )
         self.owner.__setstate__(inst.__getstate__())
 
+    def _delete_file(self, file: str):
+        FileObject(file, self.owner.storage_directory).delete()
+
     def _delete(self):
-        if self.has_contents:
-            FileObject(
-                self._PICKLE_STORAGE_FILE_NAME, self.owner.storage_directory
-            ).delete()
+        if self._has_pickle_contents:
+            self._delete_file(self._PICKLE_STORAGE_FILE_NAME)
+        elif self._has_cloudpickle_contents:
+            self._delete_file(self._CLOUDPICKLE_STORAGE_FILE_NAME)
+
+    def _storage_path(self, file: str) -> str:
+        return str((self.owner.storage_directory.path / file).resolve())
 
     @property
     def _pickle_storage_file_path(self) -> str:
-        return str(
-            (
-                self.owner.storage_directory.path / self._PICKLE_STORAGE_FILE_NAME
-            ).resolve()
-        )
+        return self._storage_path(self._PICKLE_STORAGE_FILE_NAME)
+
+    @property
+    def _cloudpickle_storage_file_path(self) -> str:
+        return self._storage_path(self._CLOUDPICKLE_STORAGE_FILE_NAME)
 
     @property
     def _has_contents(self) -> bool:
+        return self._has_pickle_contents or self._has_cloudpickle_contents
+
+    @property
+    def _has_pickle_contents(self) -> bool:
         return os.path.isfile(self._pickle_storage_file_path)
 
+    @property
+    def _has_cloudpickle_contents(self) -> bool:
+        return os.path.isfile(self._cloudpickle_storage_file_path)
+
 
 class H5ioStorage(StorageInterface):
diff --git a/pyiron_workflow/node.py b/pyiron_workflow/node.py
index 678609d46..66762d158 100644
--- a/pyiron_workflow/node.py
+++ b/pyiron_workflow/node.py
@@ -143,7 +143,8 @@ class Node(
     - As long as you haven't put anything unpickleable on them, or defined them
       in an unpicklable place (e.g. in the `` of another function), you can
       simple (un)pickle nodes. There is no save/load interface for this right
-      now, just import pickle and do it.
+      now, just import pickle and do it. The "pickle" backend to the `Node.save`
+      method will fall back on `cloudpickle` as needed to overcome this.
     - Saving is triggered manually, or by setting a flag to save after the
       nodes runs.
     - At the end of instantiation, nodes will load automatically if they find saved
@@ -168,8 +169,8 @@ class Node(
     - [ALPHA ISSUE] There are three possible back-ends for saving: one leaning
       on `tinybase.storage.GenericStorage` (in practice,
       `H5ioStorage(GenericStorage)`), and the other that uses the `h5io` module
-      directly. The third (default) option is to use `pickle`. The backend used
-      is always the one on the graph root.
+      directly. The third (default) option is to use `(cloud)pickle`. The backend
+      used is always the one on the graph root.
     - [ALPHA ISSUE] The `h5io` backend is deprecated -- it can't handle custom
       reconstructors (i.e. when `__reduce__` returns a tuple with some
       non-standard callable as its first entry), and basically all our nodes do
diff --git a/tests/unit/test_node.py b/tests/unit/test_node.py
index 42ce5fab4..a93a98e51 100644
--- a/tests/unit/test_node.py
+++ b/tests/unit/test_node.py
@@ -20,7 +20,7 @@ class ANode(Node):
     """To de-abstract the class"""
 
     def _setup_node(self) -> None:
-        self._inputs = Inputs(InputData("x", self, type_hint=int))
+        self._inputs = Inputs(InputData("x", self, type_hint=int),)
         self._outputs = OutputsWithInjection(
             OutputDataWithInjection("y", self, type_hint=int),
         )
@@ -472,7 +472,31 @@ def test_storage(self):
                    force_run.outputs.y.value,
                    msg="Destroying the save should allow immediate re-running"
                )
+
+                hard_input = ANode(label="hard", storage_backend=backend)
+                hard_input.inputs.x.type_hint = callable
+                hard_input.inputs.x = lambda x: x * 2
+                if backend == "pickle":
+                    hard_input.save()
+                    reloaded = ANode(
+                        label=hard_input.label,
+                        storage_backend=backend
+                    )
+                    self.assertEqual(
+                        reloaded.inputs.x.value(4),
+                        hard_input.inputs.x.value(4),
+                        msg="Cloud pickle should be strong enough to recover this"
+                    )
+                else:
+                    with self.assertRaises(
+                        (TypeError, AttributeError),
+                        msg="Other backends are not powerful enough for some values"
+                    ):
+                        hard_input.save()
         finally:
+            # hard_input may not exist if an earlier assertion already failed
+            if "hard_input" in locals():
+                hard_input.delete_storage()
             self.n1.delete_storage()
 
     @unittest.skipIf(sys.version_info < (3, 11), "Storage will only work in 3.11+")