Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH] estimator serialization: user choice of serialization_format, support for cloudpickle #5486

Merged
merged 12 commits into from Nov 3, 2023
58 changes: 46 additions & 12 deletions sktime/base/_base.py
Expand Up @@ -66,6 +66,11 @@ class name: BaseEstimator
from sktime.exceptions import NotFittedError
from sktime.utils.random_state import set_random_state

# Serialization backends accepted by ``save`` (see ``BaseObject.save``).
# Non-default formats may require installation of soft dependencies.
SERIALIZATION_FORMATS = {"pickle", "cloudpickle"}


class BaseObject(_BaseObject):
"""Base class for parametric objects with tags in sktime.
Expand Down Expand Up @@ -157,7 +162,7 @@ def _get_set_config_doc(cls):
doc += doc_end
return doc

def save(self, path=None):
def save(self, path=None, serialization_format="pickle"):
"""Save serialized self to bytes-like object or to (.zip) file.

Behaviour:
Expand All @@ -177,6 +182,12 @@ def save(self, path=None):
path="/home/stored/estimator" then a zip file `estimator.zip` will be
stored in `/home/stored/`.

serialization_format: str, default = "pickle"
Module to use for serialization.
The available options are present under
achieveordie marked this conversation as resolved.
Show resolved Hide resolved
`sktime.base._base.SERIALIZATION_FORMATS`. Note that non-default formats
might require installation of other soft dependencies.

Returns
-------
if `path` is None - in-memory serialized self
Expand All @@ -187,21 +198,44 @@ def save(self, path=None):
from pathlib import Path
from zipfile import ZipFile

if path is None:
return (type(self), pickle.dumps(self))
if not isinstance(path, (str, Path)):
from sktime.utils.validation._dependencies import _check_soft_dependencies

if serialization_format not in SERIALIZATION_FORMATS:
raise ValueError(
f"The provided `serialization_format`='{serialization_format}' "
"is not yet supported. The possible formats are: "
f"{SERIALIZATION_FORMATS}."
)

if path is not None and not isinstance(path, (str, Path)):
raise TypeError(
"`path` is expected to either be a string or a Path object "
f"but found of type:{type(path)}."
)

path = Path(path) if isinstance(path, str) else path
path.mkdir()

with open(path / "_metadata", "wb") as file:
pickle.dump(type(self), file)
with open(path / "_obj", "wb") as file:
pickle.dump(self, file)
if path is not None:
path = Path(path) if isinstance(path, str) else path
path.mkdir()

if serialization_format == "cloudpickle":
_check_soft_dependencies("cloudpickle", severity="error")
import cloudpickle

if path is None:
return (type(self), cloudpickle.dumps(self))

with open(path / "_metadata", "wb") as file:
cloudpickle.dump(type(self), file)
with open(path / "_obj", "wb") as file:
cloudpickle.dump(self, file)

elif serialization_format == "pickle":
if path is None:
return (type(self), pickle.dumps(self))

with open(path / "_metadata", "wb") as file:
pickle.dump(type(self), file)
with open(path / "_obj", "wb") as file:
pickle.dump(self, file)

shutil.make_archive(base_name=path, format="zip", root_dir=path)
shutil.rmtree(path)
Expand Down
26 changes: 25 additions & 1 deletion sktime/base/_serialize.py
Expand Up @@ -34,7 +34,7 @@ def load(serial):

Examples
--------
Example 1: saving an estimator as pickle and loading
Example 1: saving an estimator in-memory and loading it back

>>> from sktime.datasets import load_airline
>>> from sktime.forecasting.naive import NaiveForecaster
Expand Down Expand Up @@ -79,6 +79,30 @@ def load(serial):
>>> # 4. continue using the loaded estimator
>>> pred = cnn.predict(X=sample_test_X) # doctest: +SKIP
>>> loaded_pred = loaded_cnn.predict(X=sample_test_X) # doctest: +SKIP

Example 3: saving an estimator using cloudpickle's serialization functionality
and loading it back
    Note: `cloudpickle` is a soft dependency and is not present
    with the base installation.

>>> from sktime.classification.feature_based import Catch22Classifier
>>> from sktime.datasets import load_basic_motions # doctest: +SKIP
>>>
>>> # 1. Fit the estimator
>>> X_train, y_train = load_basic_motions(split="TRAIN") # doctest: +SKIP
>>> X_test, y_test = load_basic_motions(split="TEST") # doctest: +SKIP
>>> est = Catch22Classifier().fit(X_train, y_train) # doctest: +SKIP
>>>
>>> # 2. save the fitted estimator
>>> cpkl_serialized = est.save(serialization_format="cloudpickle") # doctest: +SKIP
>>>
>>> # 3. load the saved estimator (possibly after sending it across a stream)
>>> from sktime.base import load # doctest: +SKIP
>>> loaded_est = load(cpkl_serialized) # doctest: +SKIP
>>>
>>> # 4. continue using the estimator as normal
    >>> pred = est.predict(X_test)  # doctest: +SKIP
>>> loaded_pred = loaded_est.predict(X_test) # doctest: +SKIP
"""
import pickle
from pathlib import Path
Expand Down
74 changes: 55 additions & 19 deletions sktime/classification/deep_learning/base.py
Expand Up @@ -12,6 +12,7 @@
from sklearn.preprocessing import LabelEncoder, OneHotEncoder
from sklearn.utils import check_random_state

from sktime.base._base import SERIALIZATION_FORMATS
from sktime.classification.base import BaseClassifier
from sktime.utils.validation._dependencies import _check_soft_dependencies

Expand Down Expand Up @@ -203,7 +204,7 @@ def __setstate__(self, state):
if hasattr(self, "history"):
self.__dict__["history"] = self.history

def save(self, path=None):
def save(self, path=None, serialization_format="pickle"):
"""Save serialized self to bytes-like object or to (.zip) file.

Behaviour:
Expand All @@ -225,18 +226,62 @@ def save(self, path=None):
path="/home/stored/estimator" then a zip file `estimator.zip` will be
stored in `/home/stored/`.

serialization_format: str, default = "pickle"
Module to use for serialization.
The available options are present under
`sktime.base._base.SERIALIZATION_FORMATS`. Note that non-default formats
might require installation of other soft dependencies.

Returns
-------
if `path` is None - in-memory serialized self
if `path` is file location - ZipFile with reference to the file
"""
import pickle
import shutil
from pathlib import Path

if serialization_format not in SERIALIZATION_FORMATS:
raise ValueError(
f"The provided `serialization_format`='{serialization_format}' "
"is not yet supported. The possible formats are: "
f"{SERIALIZATION_FORMATS}."
)

if path is not None and not isinstance(path, (str, Path)):
raise TypeError(
"`path` is expected to either be a string or a Path object "
f"but found of type:{type(path)}."
)

if path is not None:
path = Path(path) if isinstance(path, str) else path
path.mkdir()

if serialization_format == "cloudpickle":
_check_soft_dependencies("cloudpickle", severity="error")
import cloudpickle

return self._serialize_using_dump_func(
path=path,
dump=cloudpickle.dump,
dumps=cloudpickle.dumps,
)

elif serialization_format == "pickle":
return self._serialize_using_dump_func(
path=path,
dump=pickle.dump,
dumps=pickle.dumps,
)

def _serialize_using_dump_func(self, path, dump, dumps):
"""Serialize & return DL Estimator using `dump` and `dumps` functions."""
import shutil
from zipfile import ZipFile

history = self.history.history if self.history is not None else None
if path is None:
_check_soft_dependencies("h5py")
_check_soft_dependencies("h5py", severity="error")
import h5py

in_memory_model = None
Expand All @@ -248,34 +293,25 @@ def save(self, path=None):
h5file.flush()
in_memory_model = h5file.id.get_file_image()

in_memory_history = pickle.dumps(self.history.history)

in_memory_history = dumps(history)
return (
type(self),
(
pickle.dumps(self),
dumps(self),
in_memory_model,
in_memory_history,
),
)

if not isinstance(path, (str, Path)):
raise TypeError(
"`path` is expected to either be a string or a Path object "
f"but found of type:{type(path)}."
)

path = Path(path) if isinstance(path, str) else path
path.mkdir()

if self.model_ is not None:
self.model_.save(path / "keras/")

with open(path / "history", "wb") as history_writer:
pickle.dump(self.history.history, history_writer)

pickle.dump(type(self), open(path / "_metadata", "wb"))
pickle.dump(self, open(path / "_obj", "wb"))
dump(history, history_writer)
with open(path / "_metadata", "wb") as file:
dump(type(self), file)
with open(path / "_obj", "wb") as file:
dump(self, file)

shutil.make_archive(base_name=path, format="zip", root_dir=path)
shutil.rmtree(path)
Expand Down
1 change: 0 additions & 1 deletion sktime/classification/deep_learning/cnn.py
Expand Up @@ -17,7 +17,6 @@ class CNNClassifier(BaseDeepClassifier):

Parameters
----------
should inherited fields be listed here?
n_epochs : int, default = 2000
the number of epochs to train the model
batch_size : int, default = 16
Expand Down
6 changes: 6 additions & 0 deletions sktime/classification/kernel_based/tests/test_arsenal.py
Expand Up @@ -6,6 +6,12 @@
from sktime.tests.test_switch import run_test_for_class


# A reference to this issue is also present inside sktime/tests/_config.py,
# and needs to be removed from `EXCLUDED_TESTS` upon resolution.
@pytest.mark.skip(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is sporadic, we should not skip it

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's the general consensus on dealing with such sporadic errors?

reason="Fails because of `len(obj.estimators_)==1`, "
"refer issue #5488 for details."
)
@pytest.mark.skipif(
not run_test_for_class(Arsenal),
reason="run test only if softdeps are present and incrementally (if requested)",
Expand Down
23 changes: 23 additions & 0 deletions sktime/classification/tests/test_base.py
Expand Up @@ -508,3 +508,26 @@ def test_deep_estimator_full(optimizer):

# check if components are same
assert full_dummy.__dict__ == deserialized_full.__dict__


DUMMY_EST_PARAMETERS_FOO = [None, 10.3, "string", {"key": "value"}, lambda x: x**2]


@pytest.mark.skipif(
    not _check_soft_dependencies("cloudpickle", severity="none"),
    reason="skip test if required soft dependency not available",
)
@pytest.mark.parametrize("foo", DUMMY_EST_PARAMETERS_FOO)
def test_save_estimator_using_cloudpickle(foo):
    """Check if serialization works with cloudpickle."""
    from sktime.base._serialize import load

    estimator = _DummyComposite(foo)

    # round-trip: serialize in-memory with cloudpickle, then deserialize
    roundtripped = load(estimator.save(serialization_format="cloudpickle"))

    if callable(foo):
        # lambdas do not compare equal after pickling; compare outputs instead
        assert estimator.foo(2) == roundtripped.foo(2)
    else:
        assert estimator.foo == roundtripped.foo
4 changes: 4 additions & 0 deletions sktime/tests/_config.py
Expand Up @@ -178,6 +178,10 @@
"test_hierarchical_with_exogeneous", # refer to #4743
],
"Pipeline": ["test_inheritance"], # does not inherit from intermediate base classes
# Arsenal Classifier contract fails on `len(obj.estimators)=1`, refer to #5488
# The actual test is present in test_arsenal::test_contracted_arsenal.py, present
# here only for reference.
"Arsenal": ["test_contracted_arsenal"],
}

# We use estimator tags in addition to class hierarchies to further distinguish
Expand Down
37 changes: 10 additions & 27 deletions sktime/utils/mlflow_sktime.py
Expand Up @@ -33,7 +33,7 @@
`pyfunc.predict()` will return output from sktime `predict()` method.
"""

__author__ = ["benjaminbluhm"]
__author__ = ["benjaminbluhm", "achieveordie"]
__all__ = [
"get_default_pip_requirements",
"get_default_conda_env",
Expand All @@ -44,13 +44,13 @@

import logging
import os
import pickle

import pandas as pd
import yaml

import sktime
from sktime import utils
from sktime.base._serialize import load
from sktime.utils.multiindex import flatten_multiindex
from sktime.utils.validation._dependencies import _check_soft_dependencies

Expand Down Expand Up @@ -258,7 +258,7 @@ def save_model(
if input_example is not None:
_save_example(mlflow_model, input_example, path)

model_data_subpath = "model.pkl"
model_data_subpath = "model"
model_data_path = os.path.join(path, model_data_subpath)
_save_model(
sktime_model, model_data_path, serialization_format=serialization_format
Expand Down Expand Up @@ -523,22 +523,12 @@ def _save_model(model, path, serialization_format):
from mlflow.exceptions import MlflowException
from mlflow.protos.databricks_pb2 import INTERNAL_ERROR

with open(path, "wb") as out:
if serialization_format == SERIALIZATION_FORMAT_PICKLE:
pickle.dump(model, out)
elif serialization_format == SERIALIZATION_FORMAT_CLOUDPICKLE:
_check_soft_dependencies("cloudpickle", severity="error")
import cloudpickle

cloudpickle.dump(model, out)
else:
raise MlflowException(
message="Unrecognized serialization format: "
"{serialization_format}".format(
serialization_format=serialization_format
),
error_code=INTERNAL_ERROR,
)
if serialization_format not in SUPPORTED_SERIALIZATION_FORMATS:
raise MlflowException(
message="Unrecognized serialization format: " f"{serialization_format}.",
error_code=INTERNAL_ERROR,
)
model.save(path=path, serialization_format=serialization_format)


def _load_model(path, serialization_format):
Expand All @@ -559,14 +549,7 @@ def _load_model(path, serialization_format):
error_code=INVALID_PARAMETER_VALUE,
)

with open(path, "rb") as pickled_model:
if serialization_format == SERIALIZATION_FORMAT_PICKLE:
return pickle.load(pickled_model)
elif serialization_format == SERIALIZATION_FORMAT_CLOUDPICKLE:
_check_soft_dependencies("cloudpickle", severity="error")
import cloudpickle

return cloudpickle.load(pickled_model)
return load(path)


def _load_pyfunc(path):
Expand Down