diff --git a/sktime/base/_base.py b/sktime/base/_base.py index 6c89b5ed34e..6ff54472b17 100644 --- a/sktime/base/_base.py +++ b/sktime/base/_base.py @@ -66,6 +66,11 @@ class name: BaseEstimator from sktime.exceptions import NotFittedError from sktime.utils.random_state import set_random_state +SERIALIZATION_FORMATS = { + "pickle", + "cloudpickle", +} + class BaseObject(_BaseObject): """Base class for parametric objects with tags in sktime. @@ -157,7 +162,7 @@ def _get_set_config_doc(cls): doc += doc_end return doc - def save(self, path=None): + def save(self, path=None, serialization_format="pickle"): """Save serialized self to bytes-like object or to (.zip) file. Behaviour: @@ -177,6 +182,12 @@ def save(self, path=None): path="/home/stored/estimator" then a zip file `estimator.zip` will be stored in `/home/stored/`. + serialization_format: str, default = "pickle" + Module to use for serialization. + The available options are "pickle" and "cloudpickle". + Note that non-default formats might require + installation of other soft dependencies. + Returns ------- if `path` is None - in-memory serialized self @@ -187,21 +198,44 @@ def save(self, path=None): from pathlib import Path from zipfile import ZipFile - if path is None: - return (type(self), pickle.dumps(self)) - if not isinstance(path, (str, Path)): + from sktime.utils.validation._dependencies import _check_soft_dependencies + + if serialization_format not in SERIALIZATION_FORMATS: + raise ValueError( + f"The provided `serialization_format`='{serialization_format}' " + "is not yet supported. The possible formats are: " + f"{SERIALIZATION_FORMATS}." + ) + + if path is not None and not isinstance(path, (str, Path)): raise TypeError( "`path` is expected to either be a string or a Path object " f"but found of type:{type(path)}." 
) - - path = Path(path) if isinstance(path, str) else path - path.mkdir() - - with open(path / "_metadata", "wb") as file: - pickle.dump(type(self), file) - with open(path / "_obj", "wb") as file: - pickle.dump(self, file) + if path is not None: + path = Path(path) if isinstance(path, str) else path + path.mkdir() + + if serialization_format == "cloudpickle": + _check_soft_dependencies("cloudpickle", severity="error") + import cloudpickle + + if path is None: + return (type(self), cloudpickle.dumps(self)) + + with open(path / "_metadata", "wb") as file: + cloudpickle.dump(type(self), file) + with open(path / "_obj", "wb") as file: + cloudpickle.dump(self, file) + + elif serialization_format == "pickle": + if path is None: + return (type(self), pickle.dumps(self)) + + with open(path / "_metadata", "wb") as file: + pickle.dump(type(self), file) + with open(path / "_obj", "wb") as file: + pickle.dump(self, file) shutil.make_archive(base_name=path, format="zip", root_dir=path) shutil.rmtree(path) diff --git a/sktime/base/_serialize.py b/sktime/base/_serialize.py index 5f2da84c4f4..b0aee1cd5d8 100644 --- a/sktime/base/_serialize.py +++ b/sktime/base/_serialize.py @@ -34,7 +34,7 @@ def load(serial): Examples -------- - Example 1: saving an estimator as pickle and loading + Example 1: saving an estimator in-memory and loading it back >>> from sktime.datasets import load_airline >>> from sktime.forecasting.naive import NaiveForecaster @@ -79,6 +79,30 @@ def load(serial): >>> # 4. continue using the loaded estimator >>> pred = cnn.predict(X=sample_test_X) # doctest: +SKIP >>> loaded_pred = loaded_cnn.predict(X=sample_test_X) # doctest: +SKIP + + Example 3: saving an estimator using cloudpickle's serialization functionality + and loading it back + Note: `cloudpickle` is a soft dependency and is not present + with the base-installation. 
+ + >>> from sktime.classification.feature_based import Catch22Classifier + >>> from sktime.datasets import load_basic_motions # doctest: +SKIP + >>> + >>> # 1. Fit the estimator + >>> X_train, y_train = load_basic_motions(split="TRAIN") # doctest: +SKIP + >>> X_test, y_test = load_basic_motions(split="TEST") # doctest: +SKIP + >>> est = Catch22Classifier().fit(X_train, y_train) # doctest: +SKIP + >>> + >>> # 2. save the fitted estimator + >>> cpkl_serialized = est.save(serialization_format="cloudpickle") # doctest: +SKIP + >>> + >>> # 3. load the saved estimator (possibly after sending it across a stream) + >>> from sktime.base import load # doctest: +SKIP + >>> loaded_est = load(cpkl_serialized) # doctest: +SKIP + >>> + >>> # 4. continue using the estimator as normal + >>> pred = est.predict(X_test) # doctest: +SKIP + >>> loaded_pred = loaded_est.predict(X_test) # doctest: +SKIP """ import pickle from pathlib import Path diff --git a/sktime/classification/deep_learning/base.py b/sktime/classification/deep_learning/base.py index 1dda1707b0c..9c60cc3575f 100644 --- a/sktime/classification/deep_learning/base.py +++ b/sktime/classification/deep_learning/base.py @@ -12,6 +12,7 @@ from sklearn.preprocessing import LabelEncoder, OneHotEncoder from sklearn.utils import check_random_state +from sktime.base._base import SERIALIZATION_FORMATS from sktime.classification.base import BaseClassifier from sktime.utils.validation._dependencies import _check_soft_dependencies @@ -203,7 +204,7 @@ def __setstate__(self, state): if hasattr(self, "history"): self.__dict__["history"] = self.history - def save(self, path=None): + def save(self, path=None, serialization_format="pickle"): """Save serialized self to bytes-like object or to (.zip) file. Behaviour: @@ -225,18 +226,62 @@ def save(self, path=None): path="/home/stored/estimator" then a zip file `estimator.zip` will be stored in `/home/stored/`. 
+ serialization_format: str, default = "pickle" + Module to use for serialization. + The available options are present under + `sktime.base._base.SERIALIZATION_FORMATS`. Note that non-default formats + might require installation of other soft dependencies. + Returns ------- if `path` is None - in-memory serialized self if `path` is file location - ZipFile with reference to the file """ import pickle - import shutil from pathlib import Path + + if serialization_format not in SERIALIZATION_FORMATS: + raise ValueError( + f"The provided `serialization_format`='{serialization_format}' " + "is not yet supported. The possible formats are: " + f"{SERIALIZATION_FORMATS}." + ) + + if path is not None and not isinstance(path, (str, Path)): + raise TypeError( + "`path` is expected to either be a string or a Path object " + f"but found of type:{type(path)}." + ) + + if path is not None: + path = Path(path) if isinstance(path, str) else path + path.mkdir() + + if serialization_format == "cloudpickle": + _check_soft_dependencies("cloudpickle", severity="error") + import cloudpickle + + return self._serialize_using_dump_func( + path=path, + dump=cloudpickle.dump, + dumps=cloudpickle.dumps, + ) + + elif serialization_format == "pickle": + return self._serialize_using_dump_func( + path=path, + dump=pickle.dump, + dumps=pickle.dumps, + ) + + def _serialize_using_dump_func(self, path, dump, dumps): + """Serialize & return DL Estimator using `dump` and `dumps` functions.""" + import shutil from zipfile import ZipFile + history = self.history.history if self.history is not None else None if path is None: - _check_soft_dependencies("h5py") + _check_soft_dependencies("h5py", severity="error") import h5py in_memory_model = None @@ -248,34 +293,25 @@ def save(self, path=None): h5file.flush() in_memory_model = h5file.id.get_file_image() - in_memory_history = pickle.dumps(self.history.history) - + in_memory_history = dumps(history) return ( type(self), ( - pickle.dumps(self), + dumps(self), 
in_memory_model, in_memory_history, ), ) - if not isinstance(path, (str, Path)): - raise TypeError( - "`path` is expected to either be a string or a Path object " - f"but found of type:{type(path)}." - ) - - path = Path(path) if isinstance(path, str) else path - path.mkdir() - if self.model_ is not None: self.model_.save(path / "keras/") with open(path / "history", "wb") as history_writer: - pickle.dump(self.history.history, history_writer) - - pickle.dump(type(self), open(path / "_metadata", "wb")) - pickle.dump(self, open(path / "_obj", "wb")) + dump(history, history_writer) + with open(path / "_metadata", "wb") as file: + dump(type(self), file) + with open(path / "_obj", "wb") as file: + dump(self, file) shutil.make_archive(base_name=path, format="zip", root_dir=path) shutil.rmtree(path) diff --git a/sktime/classification/deep_learning/cnn.py b/sktime/classification/deep_learning/cnn.py index 45d78cae6e3..042e8b2e81c 100644 --- a/sktime/classification/deep_learning/cnn.py +++ b/sktime/classification/deep_learning/cnn.py @@ -17,7 +17,6 @@ class CNNClassifier(BaseDeepClassifier): Parameters ---------- - should inherited fields be listed here? 
n_epochs : int, default = 2000 the number of epochs to train the model batch_size : int, default = 16 diff --git a/sktime/classification/kernel_based/tests/test_arsenal.py b/sktime/classification/kernel_based/tests/test_arsenal.py index 2017cf556f1..0281a9a6341 100644 --- a/sktime/classification/kernel_based/tests/test_arsenal.py +++ b/sktime/classification/kernel_based/tests/test_arsenal.py @@ -24,4 +24,4 @@ def test_contracted_arsenal(): ) arsenal.fit(X_train, y_train) - assert len(arsenal.estimators_) > 1 + assert len(arsenal.estimators_) >= 1 diff --git a/sktime/classification/tests/test_base.py b/sktime/classification/tests/test_base.py index caf5b1b5fbf..89dc59ad3df 100644 --- a/sktime/classification/tests/test_base.py +++ b/sktime/classification/tests/test_base.py @@ -508,3 +508,26 @@ def test_deep_estimator_full(optimizer): # check if components are same assert full_dummy.__dict__ == deserialized_full.__dict__ + + +DUMMY_EST_PARAMETERS_FOO = [None, 10.3, "string", {"key": "value"}, lambda x: x**2] + + +@pytest.mark.skipif( + not _check_soft_dependencies("cloudpickle", severity="none"), + reason="skip test if required soft dependency not available", +) +@pytest.mark.parametrize("foo", DUMMY_EST_PARAMETERS_FOO) +def test_save_estimator_using_cloudpickle(foo): + """Check if serialization works with cloudpickle.""" + from sktime.base._serialize import load + + est = _DummyComposite(foo) + + serialized = est.save(serialization_format="cloudpickle") + loaded_est = load(serialized) + + if callable(foo): + assert est.foo(2) == loaded_est.foo(2) + else: + assert est.foo == loaded_est.foo