Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH] estimator serialization: user choice of serialization_format, support for cloudpickle #5486

Merged
merged 12 commits into from Nov 3, 2023
58 changes: 46 additions & 12 deletions sktime/base/_base.py
Expand Up @@ -66,6 +66,11 @@ class name: BaseEstimator
from sktime.exceptions import NotFittedError
from sktime.utils.random_state import set_random_state

# Serialization backends accepted by the ``serialization_format`` argument of
# ``save``. "pickle" is always available; "cloudpickle" is a soft dependency
# whose presence is checked at call time via ``_check_soft_dependencies``.
SERIALIZATION_FORMATS = {
    "pickle",
    "cloudpickle",
}


class BaseObject(_BaseObject):
"""Base class for parametric objects with tags in sktime.
Expand Down Expand Up @@ -157,7 +162,7 @@ def _get_set_config_doc(cls):
doc += doc_end
return doc

def save(self, path=None):
def save(self, path=None, serialization_format="pickle"):
"""Save serialized self to bytes-like object or to (.zip) file.

Behaviour:
Expand All @@ -177,6 +182,12 @@ def save(self, path=None):
path="/home/stored/estimator" then a zip file `estimator.zip` will be
stored in `/home/stored/`.

serialization_format: str, default = "pickle"
Module to use for serialization.
The available options are "pickle" and "cloudpickle".
Note that non-default formats might require
installation of other soft dependencies.

Returns
-------
if `path` is None - in-memory serialized self
Expand All @@ -187,21 +198,44 @@ def save(self, path=None):
from pathlib import Path
from zipfile import ZipFile

if path is None:
return (type(self), pickle.dumps(self))
if not isinstance(path, (str, Path)):
from sktime.utils.validation._dependencies import _check_soft_dependencies

if serialization_format not in SERIALIZATION_FORMATS:
raise ValueError(
f"The provided `serialization_format`='{serialization_format}' "
"is not yet supported. The possible formats are: "
f"{SERIALIZATION_FORMATS}."
)

if path is not None and not isinstance(path, (str, Path)):
raise TypeError(
"`path` is expected to either be a string or a Path object "
f"but found of type:{type(path)}."
)

path = Path(path) if isinstance(path, str) else path
path.mkdir()

with open(path / "_metadata", "wb") as file:
pickle.dump(type(self), file)
with open(path / "_obj", "wb") as file:
pickle.dump(self, file)
if path is not None:
path = Path(path) if isinstance(path, str) else path
path.mkdir()

if serialization_format == "cloudpickle":
_check_soft_dependencies("cloudpickle", severity="error")
import cloudpickle

if path is None:
return (type(self), cloudpickle.dumps(self))

with open(path / "_metadata", "wb") as file:
cloudpickle.dump(type(self), file)
with open(path / "_obj", "wb") as file:
cloudpickle.dump(self, file)

elif serialization_format == "pickle":
if path is None:
return (type(self), pickle.dumps(self))

with open(path / "_metadata", "wb") as file:
pickle.dump(type(self), file)
with open(path / "_obj", "wb") as file:
pickle.dump(self, file)

shutil.make_archive(base_name=path, format="zip", root_dir=path)
shutil.rmtree(path)
Expand Down
26 changes: 25 additions & 1 deletion sktime/base/_serialize.py
Expand Up @@ -34,7 +34,7 @@ def load(serial):

Examples
--------
Example 1: saving an estimator as pickle and loading
Example 1: saving an estimator in-memory and loading it back

>>> from sktime.datasets import load_airline
>>> from sktime.forecasting.naive import NaiveForecaster
Expand Down Expand Up @@ -79,6 +79,30 @@ def load(serial):
>>> # 4. continue using the loaded estimator
>>> pred = cnn.predict(X=sample_test_X) # doctest: +SKIP
>>> loaded_pred = loaded_cnn.predict(X=sample_test_X) # doctest: +SKIP

Example 3: saving an estimator using cloudpickle's serialization functionality
and loading it back
    Note: `cloudpickle` is a soft dependency and is not included
    in the base installation.

>>> from sktime.classification.feature_based import Catch22Classifier
>>> from sktime.datasets import load_basic_motions # doctest: +SKIP
>>>
>>> # 1. Fit the estimator
>>> X_train, y_train = load_basic_motions(split="TRAIN") # doctest: +SKIP
>>> X_test, y_test = load_basic_motions(split="TEST") # doctest: +SKIP
>>> est = Catch22Classifier().fit(X_train, y_train) # doctest: +SKIP
>>>
>>> # 2. save the fitted estimator
>>> cpkl_serialized = est.save(serialization_format="cloudpickle") # doctest: +SKIP
>>>
>>> # 3. load the saved estimator (possibly after sending it across a stream)
>>> from sktime.base import load # doctest: +SKIP
>>> loaded_est = load(cpkl_serialized) # doctest: +SKIP
>>>
>>> # 4. continue using the estimator as normal
    >>> pred = est.predict(X_test)  # doctest: +SKIP
>>> loaded_pred = loaded_est.predict(X_test) # doctest: +SKIP
"""
import pickle
from pathlib import Path
Expand Down
74 changes: 55 additions & 19 deletions sktime/classification/deep_learning/base.py
Expand Up @@ -12,6 +12,7 @@
from sklearn.preprocessing import LabelEncoder, OneHotEncoder
from sklearn.utils import check_random_state

from sktime.base._base import SERIALIZATION_FORMATS
from sktime.classification.base import BaseClassifier
from sktime.utils.validation._dependencies import _check_soft_dependencies

Expand Down Expand Up @@ -203,7 +204,7 @@ def __setstate__(self, state):
if hasattr(self, "history"):
self.__dict__["history"] = self.history

def save(self, path=None):
def save(self, path=None, serialization_format="pickle"):
"""Save serialized self to bytes-like object or to (.zip) file.

Behaviour:
Expand All @@ -225,18 +226,62 @@ def save(self, path=None):
path="/home/stored/estimator" then a zip file `estimator.zip` will be
stored in `/home/stored/`.

serialization_format: str, default = "pickle"
Module to use for serialization.
The available options are present under
`sktime.base._base.SERIALIZATION_FORMATS`. Note that non-default formats
might require installation of other soft dependencies.

Returns
-------
if `path` is None - in-memory serialized self
if `path` is file location - ZipFile with reference to the file
"""
import pickle
import shutil
from pathlib import Path

if serialization_format not in SERIALIZATION_FORMATS:
raise ValueError(
f"The provided `serialization_format`='{serialization_format}' "
"is not yet supported. The possible formats are: "
f"{SERIALIZATION_FORMATS}."
)

if path is not None and not isinstance(path, (str, Path)):
raise TypeError(
"`path` is expected to either be a string or a Path object "
f"but found of type:{type(path)}."
)

if path is not None:
path = Path(path) if isinstance(path, str) else path
path.mkdir()

if serialization_format == "cloudpickle":
_check_soft_dependencies("cloudpickle", severity="error")
import cloudpickle

return self._serialize_using_dump_func(
path=path,
dump=cloudpickle.dump,
dumps=cloudpickle.dumps,
)

elif serialization_format == "pickle":
return self._serialize_using_dump_func(
path=path,
dump=pickle.dump,
dumps=pickle.dumps,
)

def _serialize_using_dump_func(self, path, dump, dumps):
"""Serialize & return DL Estimator using `dump` and `dumps` functions."""
import shutil
from zipfile import ZipFile

history = self.history.history if self.history is not None else None
if path is None:
_check_soft_dependencies("h5py")
_check_soft_dependencies("h5py", severity="error")
import h5py

in_memory_model = None
Expand All @@ -248,34 +293,25 @@ def save(self, path=None):
h5file.flush()
in_memory_model = h5file.id.get_file_image()

in_memory_history = pickle.dumps(self.history.history)

in_memory_history = dumps(history)
return (
type(self),
(
pickle.dumps(self),
dumps(self),
in_memory_model,
in_memory_history,
),
)

if not isinstance(path, (str, Path)):
raise TypeError(
"`path` is expected to either be a string or a Path object "
f"but found of type:{type(path)}."
)

path = Path(path) if isinstance(path, str) else path
path.mkdir()

if self.model_ is not None:
self.model_.save(path / "keras/")

with open(path / "history", "wb") as history_writer:
pickle.dump(self.history.history, history_writer)

pickle.dump(type(self), open(path / "_metadata", "wb"))
pickle.dump(self, open(path / "_obj", "wb"))
dump(history, history_writer)
with open(path / "_metadata", "wb") as file:
dump(type(self), file)
with open(path / "_obj", "wb") as file:
dump(self, file)

shutil.make_archive(base_name=path, format="zip", root_dir=path)
shutil.rmtree(path)
Expand Down
1 change: 0 additions & 1 deletion sktime/classification/deep_learning/cnn.py
Expand Up @@ -17,7 +17,6 @@ class CNNClassifier(BaseDeepClassifier):

Parameters
----------
should inherited fields be listed here?
n_epochs : int, default = 2000
the number of epochs to train the model
batch_size : int, default = 16
Expand Down
23 changes: 23 additions & 0 deletions sktime/classification/tests/test_base.py
Expand Up @@ -508,3 +508,26 @@ def test_deep_estimator_full(optimizer):

# check if components are same
assert full_dummy.__dict__ == deserialized_full.__dict__


DUMMY_EST_PARAMETERS_FOO = [None, 10.3, "string", {"key": "value"}, lambda x: x**2]


@pytest.mark.skipif(
    not _check_soft_dependencies("cloudpickle", severity="none"),
    reason="skip test if required soft dependency not available",
)
@pytest.mark.parametrize("foo", DUMMY_EST_PARAMETERS_FOO)
def test_save_estimator_using_cloudpickle(foo):
    """Check if serialization works with cloudpickle."""
    from sktime.base._serialize import load

    estimator = _DummyComposite(foo)

    # round-trip: serialize in-memory with cloudpickle, then load back
    deserialized = load(estimator.save(serialization_format="cloudpickle"))

    original_foo = estimator.foo
    restored_foo = deserialized.foo
    if callable(original_foo):
        # callables (e.g. lambdas) compare by identity, so compare outputs
        assert original_foo(2) == restored_foo(2)
    else:
        assert original_foo == restored_foo
4 changes: 4 additions & 0 deletions sktime/tests/_config.py
Expand Up @@ -178,6 +178,10 @@
"test_hierarchical_with_exogeneous", # refer to #4743
],
"Pipeline": ["test_inheritance"], # does not inherit from intermediate base classes
# Arsenal Classifier contract fails on `len(obj.estimators)=1`, refer to #5488
# The actual test is present in test_arsenal::test_contracted_arsenal.py, present
# here only for reference.
"Arsenal": ["test_contracted_arsenal"],
}

# We use estimator tags in addition to class hierarchies to further distinguish
Expand Down
2 changes: 1 addition & 1 deletion sktime/utils/mlflow_sktime.py
Expand Up @@ -28,7 +28,7 @@
`{"predict_method": {"predict": {}, "predict_interval": {"coverage": [0.1, 0.9]}}`.
`Dict[str, list]`, with default parameters in predict method, for example
`{"predict_method": ["predict", "predict_interval"}` (Note: when including
`predict_proba` method the former appraoch must be followed as `quantiles`
`predict_proba` method the former approach must be followed as `quantiles`
parameter has to be provided by the user). If no prediction config is defined
`pyfunc.predict()` will return output from sktime `predict()` method.
"""
Expand Down