diff --git a/.all-contributorsrc b/.all-contributorsrc index 9e4c7275888..faf27c7a969 100644 --- a/.all-contributorsrc +++ b/.all-contributorsrc @@ -28,6 +28,7 @@ "question", "review", "talk", + "test", "tutorial" ] }, diff --git a/sktime/base/__init__.py b/sktime/base/__init__.py index b25471f756c..a0c4aaccbf7 100644 --- a/sktime/base/__init__.py +++ b/sktime/base/__init__.py @@ -2,11 +2,14 @@ # -*- coding: utf-8 -*- # copyright: sktime developers, BSD-3-Clause License (see LICENSE file) -__author__ = ["Markus Löning"] +"""Base classes for defining estimators and other objects in sktime.""" + +__author__ = ["mloning", "RNKuhns", "fkiraly"] __all__ = [ + "BaseObject", "BaseEstimator", "_HeterogenousMetaEstimator", ] -from sktime.base._base import BaseEstimator +from sktime.base._base import BaseObject, BaseEstimator from sktime.base._meta import _HeterogenousMetaEstimator diff --git a/sktime/base/_base.py b/sktime/base/_base.py index 51645d2ff74..a7e86d708e6 100644 --- a/sktime/base/_base.py +++ b/sktime/base/_base.py @@ -1,12 +1,54 @@ -#!/usr/bin/env python3 -u # -*- coding: utf-8 -*- -# copyright: sktime developers, BSD-3-Clause License (see LICENSE file) +""" +Base class template for objects and fittable objects. -__author__ = ["Markus Löning"] -__all__ = ["BaseEstimator"] +templates in this module: + + BaseObject - object with parameters and tags + BaseEstimator - BaseObject that can be fitted + +Interface specifications below. 
+ +--- + + class name: BaseObject + +Hyper-parameter inspection and setter methods: + inspect hyper-parameters - get_params() + setting hyper-parameters - set_params(**params) + +Tag inspection and setter methods: + inspect tags (all) - get_tags() + inspect tags (one tag) - get_tag(tag_name: str, tag_value_default=None) + inspect tags (class method) - get_class_tags() + inspect tags (one tag, class) - get_class_tag(tag_name: str, tag_value_default=None) + setting dynamic tags - set_tags(**tag_dict) + set/clone dynamic tags - clone_tags(estimator, tag_names=None) + +--- + + class name: BaseEstimator + +Provides all interface points of BaseObject, plus: + +Parameter inspection: + fitted parameter inspection - get_fitted_params() + +State: + fitted model/strategy - by convention, any attributes ending in "_" + fitted state flag - is_fitted (property) + fitted state check - check_is_fitted (raises error if not is_fitted) + +copyright: sktime developers, BSD-3-Clause License (see LICENSE file) +""" + +__author__ = ["mloning", "RNKuhns", "fkiraly"] +__all__ = ["BaseEstimator", "BaseObject"] import inspect +from copy import deepcopy + from sklearn import clone from sklearn.base import BaseEstimator as _BaseEstimator from sklearn.ensemble._base import _set_random_states @@ -14,14 +56,154 @@ from sktime.exceptions import NotFittedError -class BaseEstimator(_BaseEstimator): +class BaseObject(_BaseEstimator): + """Base class for parametric objects with tags in sktime. + + Extends scikit-learn's BaseEstimator to include sktime interface for tags. + """ + + def __init__(self): + self._tags_dynamic = dict() + super(BaseObject, self).__init__() + + @classmethod + def get_class_tags(cls): + """Get class tags from estimator class and all its parent classes.
+ + Returns + ------- + collected_tags : dictionary of tag names : tag values + collected from _tags class attribute via nested inheritance + NOT overridden by dynamic tags set by set_tags or clone_tags + """ + collected_tags = dict() + + # We exclude the last two parent classes: sklearn.base.BaseEstimator and + # the basic Python object. + for parent_class in reversed(inspect.getmro(cls)[:-2]): + if hasattr(parent_class, "_tags"): + # Need the if here because mixins might not have _tags + # but might do redundant work in estimators + # (i.e. calling more tags on BaseEstimator multiple times) + more_tags = parent_class._tags + collected_tags.update(more_tags) + + return deepcopy(collected_tags) + + @classmethod + def get_class_tag(cls, tag_name, tag_value_default=None): + """Get tag value from estimator class (only class tags). + + Parameters + ---------- + tag_name : str, name of tag value + tag_value_default : any type, default/fallback value if tag is not found + + Returns + ------- + tag_value : value of the tag tag_name in the class if found + if tag is not found, returns tag_value_default + """ + collected_tags = cls.get_class_tags() + + return collected_tags.get(tag_name, tag_value_default) + + def get_tags(self): + """Get tags from estimator class and dynamic tag overrides. + + Returns + ------- + collected_tags : dictionary of tag names : tag values + collected from _tags class attribute via nested inheritance + then any overrides and new tags from _tags_dynamic object attribute + """ + collected_tags = self.get_class_tags() + + if hasattr(self, "_tags_dynamic"): + collected_tags.update(self._tags_dynamic) + + return deepcopy(collected_tags) + + def get_tag(self, tag_name, tag_value_default=None): + """Get tag value from estimator class and dynamic tag overrides.
+ + Parameters + ---------- + tag_name : str, name of tag value + tag_value_default : any type, default/fallback value if tag is not found + + Returns + ------- + tag_value : value of the tag tag_name in self if found + if tag is not found, returns tag_value_default + """ + collected_tags = self.get_tags() + + return collected_tags.get(tag_name, tag_value_default) + + def set_tags(self, **tag_dict): + """Set dynamic tags to given values. + + Parameters + ---------- + tag_dict : dictionary of tag names : tag values + + Returns + ------- + reference to self + + State change + ------------ + sets tag values in tag_dict as dynamic tags in self + """ + self._tags_dynamic.update(deepcopy(tag_dict)) + + return self + + def clone_tags(self, estimator, tag_names=None): + """Clone/mirror tags from another estimator as dynamic override. + + Parameters + ---------- + estimator : an object inheriting from BaseObject + tag_names : list of str, or str; names of tags to clone + default = list of all tags in estimator + + Returns + ------- + reference to self + + State change + ------------ + sets tag values in tag_names from estimator as dynamic tags in self + """ + tags_est = deepcopy(estimator.get_tags()) + + # if tag_names is not passed, default is all tags in estimator + if tag_names is None: + tag_names = tags_est.keys() + else: + # if tag_names is passed, intersect keys with tags in estimator + if not isinstance(tag_names, list): + tag_names = [tag_names] + tag_names = [key for key in tag_names if key in tags_est.keys()] + + update_dict = {key: tags_est[key] for key in tag_names} + + self.set_tags(update_dict) + + return self + + +class BaseEstimator(BaseObject): """Base class for defining estimators in sktime. - Extends scikit-learn's BaseEstimator. + Extends sktime's BaseObject to include basic functionality for fittable estimators.
""" def __init__(self): self._is_fitted = False + super(BaseEstimator, self).__init__() @property def is_fitted(self): @@ -42,27 +224,6 @@ def check_is_fitted(self): f"been fitted yet; please call `fit` first." ) - @classmethod - def _all_tags(cls): - """Get tags from estimator class and all its parent classes.""" - # We here create a separate estimator tag interface in addition to the one in - # scikit-learn to make sure we do not interfere with scikit-learn's one - # when we inherit from scikit-learn classes. We also make estimator tags a - # class rather than object attribute. - collected_tags = dict() - - # We exclude the last two parent classes; sklearn.base.BaseEstimator and - # the basic Python object. - for parent_class in reversed(inspect.getmro(cls)[:-2]): - if hasattr(parent_class, "_tags"): - # Need the if here because mixins might not have _more_tags - # but might do redundant work in estimators - # (i.e. calling more tags on BaseEstimator multiple times) - more_tags = parent_class._tags - collected_tags.update(more_tags) - - return collected_tags - def _clone_estimator(base_estimator, random_state=None): estimator = clone(base_estimator) diff --git a/sktime/base/_meta.py b/sktime/base/_meta.py index 323b838613e..8173c3014a7 100644 --- a/sktime/base/_meta.py +++ b/sktime/base/_meta.py @@ -11,16 +11,17 @@ class _HeterogenousMetaEstimator(BaseEstimator, metaclass=ABCMeta): - """Handles parameter management for estimtators composed of named - estimators. + """Handles parameter management for estimtators composed of named estimators. 
- from sklearn utils.metaestimator.py + From sklearn utils.metaestimator.py """ def get_params(self, deep=True): + """Return estimator parameters.""" raise NotImplementedError("abstract method") def set_params(self, **params): + """Set estimator parameters.""" raise NotImplementedError("abstract method") def _get_params(self, attr, deep=True): diff --git a/sktime/base/tests/test_base.py b/sktime/base/tests/test_base.py new file mode 100644 index 00000000000..e6456fdd75a --- /dev/null +++ b/sktime/base/tests/test_base.py @@ -0,0 +1,154 @@ +# -*- coding: utf-8 -*- +""" +Tests for BaseObject universal base class. + +tests in this module: + + test_get_class_tags - tests get_class_tags inheritance logic + test_get_class_tag - tests get_class_tag logic, incl default value + test_get_tags - tests get_tags inheritance logic + test_get_tag - tests get_tag logic, incl default value + test_set_tags - tests set_tags logic and related get_tags inheritance + +copyright: sktime developers, BSD-3-Clause License (see LICENSE file) +""" + +__author__ = ["fkiraly"] + +__all__ = [ + "test_get_class_tags", + "test_get_class_tag", + "test_get_tags", + "test_get_tag", + "test_set_tags", +] + +from copy import deepcopy + +from sktime.base import BaseObject + + +# Fixture class for testing tag system +class FixtureClassParent(BaseObject): + + _tags = {"A": "1", "B": 2, "C": 1234, 3: "D"} + + +# Fixture class for testing tag system, child overrides tags +class FixtureClassChild(FixtureClassParent): + + _tags = {"A": 42, 3: "E"} + + +FIXTURE_CLASSCHILD = FixtureClassChild + +FIXTURE_CLASSCHILD_TAGS = {"A": 42, "B": 2, "C": 1234, 3: "E"} + +# Fixture class for testing tag system, object overrides class tags +FIXTURE_OBJECT = FixtureClassChild() +FIXTURE_OBJECT._tags_dynamic = {"A": 42424241, "B": 3} + +FIXTURE_OBJECT_TAGS = {"A": 42424241, "B": 3, "C": 1234, 3: "E"} + + +def test_get_class_tags(): + """Tests get_class_tags class method of BaseObject for correctness. 
+ + Raises + ------ + AssertionError if inheritance logic in get_class_tags is incorrect + """ + child_tags = FIXTURE_CLASSCHILD.get_class_tags() + + msg = "Inheritance logic in BaseObject.get_class_tags is incorrect" + + assert child_tags == FIXTURE_CLASSCHILD_TAGS, msg + + +def test_get_class_tag(): + """Tests get_class_tag class method of BaseObject for correctness. + + Raises + ------ + AssertionError if inheritance logic in get_class_tag is incorrect + AssertionError if default override logic in get_class_tag is incorrect + """ + child_tags = dict() + child_tags_keys = FIXTURE_CLASSCHILD_TAGS.keys() + + for key in child_tags_keys: + child_tags[key] = FIXTURE_CLASSCHILD.get_class_tag(key) + + child_tag_default = FIXTURE_CLASSCHILD.get_class_tag("foo", "bar") + child_tag_defaultNone = FIXTURE_CLASSCHILD.get_class_tag("bar") + + msg = "Inheritance logic in BaseObject.get_class_tag is incorrect" + + for key in child_tags_keys: + assert child_tags[key] == FIXTURE_CLASSCHILD_TAGS[key], msg + + msg = "Default override logic in BaseObject.get_class_tag is incorrect" + + assert child_tag_default == "bar", msg + assert child_tag_defaultNone is None, msg + + +def test_get_tags(): + """Tests get_tags method of BaseObject for correctness. + + Raises + ------ + AssertionError if inheritance logic in get_tags is incorrect + """ + object_tags = FIXTURE_OBJECT.get_tags() + + msg = "Inheritance logic in BaseObject.get_tags is incorrect" + + assert object_tags == FIXTURE_OBJECT_TAGS, msg + + +def test_get_tag(): + """Tests get_tag method of BaseObject for correctness.
+ + Raises + ------ + AssertionError if inheritance logic in get_tag is incorrect + AssertionError if default override logic in get_tag is incorrect + """ + object_tags = dict() + object_tags_keys = FIXTURE_OBJECT_TAGS.keys() + + for key in object_tags_keys: + object_tags[key] = FIXTURE_OBJECT.get_tag(key) + + object_tag_default = FIXTURE_OBJECT.get_tag("foo", "bar") + object_tag_defaultNone = FIXTURE_OBJECT.get_tag("bar") + + msg = "Inheritance logic in BaseObject.get_tag is incorrect" + + for key in object_tags_keys: + assert object_tags[key] == FIXTURE_OBJECT_TAGS[key], msg + + msg = "Default override logic in BaseObject.get_tag is incorrect" + + assert object_tag_default == "bar", msg + assert object_tag_defaultNone is None, msg + + +FIXTURE_TAG_SET = {"A": 42424243, "E": 3} +FIXTURE_OBJECT_SET = deepcopy(FIXTURE_OBJECT).set_tags(**FIXTURE_TAG_SET) +FIXTURE_OBJECT_SET_TAGS = {"A": 42424243, "B": 3, "C": 1234, 3: "E", "E": 3} +FIXTURE_OBJECT_SET_DYN = {"A": 42424243, "B": 3, "E": 3} + + +def test_set_tags(): + """Tests set_tags method of BaseObject for correctness. + + Raises + ------ + AssertionError if override logic in set_tags is incorrect + """ + msg = "Setter/override logic in BaseObject.set_tags is incorrect" + + assert FIXTURE_OBJECT_SET._tags_dynamic == FIXTURE_OBJECT_SET_DYN, msg + assert FIXTURE_OBJECT_SET.get_tags() == FIXTURE_OBJECT_SET_TAGS, msg diff --git a/sktime/classification/base.py b/sktime/classification/base.py index 3a849737d85..c5c0eb553fd 100644 --- a/sktime/classification/base.py +++ b/sktime/classification/base.py @@ -1,4 +1,30 @@ # -*- coding: utf-8 -*- +""" +Base class template for time series classifier scitype.
+ + class name: BaseClassifier + +Scitype defining methods: + fitting - fit(self, X, y) + predicting - predict(self, X) + +State: + fitted model/strategy - by convention, any attributes ending in "_" + fitted state flag - is_fitted (property) + fitted state inspection - check_is_fitted() + +Inspection methods: + hyper-parameter inspection - get_params() + fitted parameter inspection - get_fitted_params() + +State: + fitted model/strategy - by convention, any attributes ending in "_" + fitted state flag - is_fitted (property) + fitted state inspection - check_is_fitted() + +copyright: sktime developers, BSD-3-Clause License (see LICENSE file) +""" + __all__ = [ "BaseClassifier", "classifier_list", @@ -76,8 +102,7 @@ def fit(self, X, y): creates fitted model (attributes ending in "_") sets is_fitted flag to true """ - - coerce_to_numpy = self._all_tags()["coerce-X-to-numpy"] + coerce_to_numpy = self.get_class_tag("coerce-X-to-numpy", False) X, y = check_X_y(X, y, coerce_to_numpy=coerce_to_numpy) @@ -89,7 +114,7 @@ def fit(self, X, y): return self def predict(self, X): - """predicts labels for sequences in X + """Predicts labels for sequences in X. Parameters ---------- @@ -102,8 +127,7 @@ def predict(self, X): ------- y : array-like, shape = [n_instances] - predicted class labels """ - - coerce_to_numpy = self._all_tags()["coerce-X-to-numpy"] + coerce_to_numpy = self.get_class_tag("coerce-X-to-numpy", False) X = check_X(X, coerce_to_numpy=coerce_to_numpy) self.check_is_fitted() @@ -113,9 +137,36 @@ def predict(self, X): return y def predict_proba(self, X): + """Predicts labels probabilities for sequences in X. 
+ + Parameters + ---------- + X : 3D np.array, array-like or sparse matrix + of shape = [n_instances,n_dimensions,series_length] + or shape = [n_instances,series_length] + or single-column pd.DataFrame with pd.Series entries + + Returns + ------- + y : array-like, shape = [n_instances, n_classes] - predictive pmf + """ raise NotImplementedError("abstract method") def score(self, X, y): + """Scores predicted labels against ground truth labels on X. + + Parameters + ---------- + X : 3D np.array, array-like or sparse matrix + of shape = [n_instances,n_dimensions,series_length] + or shape = [n_instances,series_length] + or single-column pd.DataFrame with pd.Series entries + y : array-like, shape = [n_instances] - ground truth class labels + + Returns + ------- + float, accuracy score of predict(X) vs y + """ from sklearn.metrics import accuracy_score return accuracy_score(y, self.predict(X), normalize=True) @@ -144,7 +195,7 @@ def _fit(self, X, y): raise NotImplementedError("abstract method") def _predict(self, X): - """predicts labels for sequences in X + """Predicts labels for sequences in X. core logic @@ -159,7 +210,6 @@ def _predict(self, X): ------- y : array-like, shape = [n_instances] - predicted class labels """ - distributions = self.predict_proba(X) predictions = [] for instance_index in range(0, X.shape[0]): diff --git a/sktime/dists_kernels/_base.py b/sktime/dists_kernels/_base.py index e57b8a520b9..2f43043b4f9 100644 --- a/sktime/dists_kernels/_base.py +++ b/sktime/dists_kernels/_base.py @@ -1,6 +1,35 @@ # -*- coding: utf-8 -*- """ -Abstract base class for pairwise transformers (such as distance/kernel matrix makers) +Base class templates for distances or kernels between time series, and for tabular data. + +templates in this module: + + BasePairwiseTransformer - distances/kernels for tabular data + BasePairwiseTransformerPanel - distances/kernels for time series + +Interface specifications below.
+ +--- + class name: BasePairwiseTransformer + +Scitype defining methods: + computing distance/kernel matrix (shorthand) - __call__(self, X, X2=X) + computing distance/kernel matrix - transform(self, X, X2=X) + +Inspection methods: + hyper-parameter inspection - get_params() + +--- + class name: BasePairwiseTransformerPanel + +Scitype defining methods: + computing distance/kernel matrix (shorthand) - __call__(self, X, X2=X) + computing distance/kernel matrix - transform(self, X, X2=X) + +Inspection methods: + hyper-parameter inspection - get_params() + +copyright: sktime developers, BSD-3-Clause License (see LICENSE file) """ __author__ = ["fkiraly"] @@ -32,7 +61,8 @@ def __init__(self): self.X_equals_X2 = False def __call__(self, X, X2=None): - """ + """Compute distance/kernel matrix, call shorthand. + Behaviour: returns pairwise distance/kernel matrix between samples in X and X2 if X2 is not passed, is equal to X @@ -60,7 +90,8 @@ def __call__(self, X, X2=None): return self.transform(X=X, X2=X2) def transform(self, X, X2=None): - """ + """Compute distance/kernel matrix. + Behaviour: returns pairwise distance/kernel matrix between samples in X and X2 (equal to X if not passed) @@ -80,7 +111,6 @@ def transform(self, X, X2=None): X_equals_X2: bool = True if X2 was not passed, False if X2 was passed for use to make internal calculations efficient, e.g., in _transform """ - X = check_series(X) if X2 is None: @@ -102,12 +132,13 @@ def input_as_numpy(val): return self._transform(X=X, X2=X2) def _transform(self, X, X2=None): - """ + """Compute distance/kernel matrix. 
+ + Core logic + Behaviour: returns pairwise distance/kernel matrix between samples in X and X2 (equal to X if not passed) - core logic - Parameters ---------- X: pd.DataFrame of length n, or 2D np.array with n rows @@ -122,12 +153,14 @@ def _transform(self, X, X2=None): raise NotImplementedError def fit(self, X=None, X2=None): + """Fit method for interface compatibility (no logic inside).""" # no fitting logic, but in case fit is called or expected pass def _pairwise_panel_x_check(X): - """ + """Check and coerce input data. + Method used to check the input and convert numpy 3d or numpy list of df to list of dfs @@ -192,7 +225,8 @@ def __init__(self): self.X_equals_X2 = False def __call__(self, X, X2=None): - """ + """Compute distance/kernel matrix, call shorthand. + Behaviour: returns pairwise distance/kernel matrix between samples in X and X2 (equal to X if not passed) @@ -217,7 +251,8 @@ def __call__(self, X, X2=None): return self.transform(X=X, X2=X2) def transform(self, X, X2=None): - """ + """Compute distance/kernel matrix. + Behaviour: returns pairwise distance/kernel matrix between samples in X and X2 (equal to X if not passed) @@ -237,7 +272,6 @@ def transform(self, X, X2=None): X_equals_X2: bool = True if X2 was not passed, False if X2 was passed for use to make internal calculations efficient, e.g., in _transform """ - X = _pairwise_panel_x_check(X) if X2 is None: @@ -254,12 +288,13 @@ def transform(self, X, X2=None): return self._transform(X=X, X2=X2) def _transform(self, X, X2=None): - """ + """Compute distance/kernel matrix. 
+ + Core logic + Behaviour: returns pairwise distance/kernel matrix between samples in X and X2 (equal to X if not passed) - core logic - Parameters ---------- X: list of pd.DataFrame or 2D np.arrays, of length n @@ -274,5 +309,6 @@ def _transform(self, X, X2=None): raise NotImplementedError def fit(self, X=None, X2=None): + """Fit method for interface compatibility (no logic inside).""" # no fitting logic, but in case fit is called or expected pass diff --git a/sktime/dists_kernels/compose_tab_to_panel.py b/sktime/dists_kernels/compose_tab_to_panel.py index c655b9b253e..22cd7a22b33 100644 --- a/sktime/dists_kernels/compose_tab_to_panel.py +++ b/sktime/dists_kernels/compose_tab_to_panel.py @@ -1,6 +1,10 @@ # -*- coding: utf-8 -*- """ -Composers that create panel pairwise transformers from table pairwise transformers +Composers that create panel pairwise transformers from table pairwise transformers. + +Currently implemented composers in this module: + + AggrDist - panel distance from aggregation of tabular distance matrix entries """ __author__ = ["fkiraly"] @@ -11,7 +15,8 @@ class AggrDist(BasePairwiseTransformerPanel): - """ + """Panel distance from tabular distance aggregation. + panel distance obtained by applying aggregation function to tabular distance matrix example: AggrDist(ScipyDist()) is mean Euclidean distance between series @@ -46,24 +51,27 @@ def __init__( super(AggrDist, self).__init__() def _transform(self, X, X2=None): - """ + """Compute distance/kernel matrix. + + Core logic. 
+ Behaviour: returns pairwise distance/kernel matrix between samples in X and X2 if X2 is not passed, is equal to X if X/X2 is a pd.DataFrame and contains non-numeric columns, these are removed before computation - Args: - X: list of pd.DataFrame or 2D np.arrays, of length n + Parameters + ---------- + X: pd.DataFrame of length n, or 2D np.array with n rows + X2: pd.DataFrame of length m, or 2D np.array with m rows, optional + default X2 = X - Optional args: - X2: list of pd.DataFrame or 2D np.arrays, of length m - - Returns: - distmat: np.array of shape [n, m] - (i,j)-th entry contains distance/kernel between X[i] and X2[j] + Returns + ------- + distmat: np.array of shape [n, m] + (i,j)-th entry contains distance/kernel between X.iloc[i] and X2.iloc[j] """ - n = len(X) m = len(X2) @@ -75,7 +83,7 @@ def _transform(self, X, X2=None): aggfunc = np.mean aggfunc_is_symm = True - transformer_symm = self.transformer._all_tags()["symmetric"] + transformer_symm = self.transformer.get_tag("symmetric", False) # whether we know that resulting matrix must be symmetric # a sufficient condition for this: diff --git a/sktime/dists_kernels/scipy_dist.py b/sktime/dists_kernels/scipy_dist.py index 8c7e3d947a2..5ff618db50f 100644 --- a/sktime/dists_kernels/scipy_dist.py +++ b/sktime/dists_kernels/scipy_dist.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- -""" +"""Interface module to scipy. + Interface module to scipy.spatial's pairwise distance function cdist exposes parameters as scikit-learn hyper-parameters """ @@ -14,29 +15,31 @@ class ScipyDist(BasePairwiseTransformer): - """ + """Interface to scipy distances. 
+ computes pairwise distances using scipy.spatial.distance.cdist includes Euclidean distance and p-norm (Minkowski) distance note: weighted distances are not supported - Hyper-parameters: - metric: string or function, as in cdist; default = 'euclidean' - if string, one of: 'braycurtis', 'canberra', 'chebyshev', 'cityblock', - 'correlation', 'cosine', 'dice', 'euclidean', 'hamming', 'jaccard', - 'jensenshannon', 'kulsinski', 'mahalanobis', 'matching', 'minkowski', - 'rogerstanimoto', 'russellrao', 'seuclidean', 'sokalmichener', - 'sokalsneath', 'sqeuclidean', 'yule' - if function, should have signature 1D-np.array x 1D-np.array -> float - p: if metric='minkowski', the "p" in "p-norm", otherwise irrelevant - colalign: string, one of 'intersect' (default), 'force-align', 'none' - controls column alignment if X, X2 passed in fit are pd.DataFrame - columns between X and X2 are aligned via column names - if 'intersect', distance is computed on columns occurring both in X and X2, - other columns are discarded; column ordering in X2 is copied from X - if 'force-align', raises an error if the set of columns in X, X2 differs; - column ordering in X2 is copied from X - if 'none', X and X2 are passed through unmodified (no columns are aligned) - note: this will potentially align "non-matching" columns + Parameters + ---------- + metric: string or function, as in cdist; default = 'euclidean' + if string, one of: 'braycurtis', 'canberra', 'chebyshev', 'cityblock', + 'correlation', 'cosine', 'dice', 'euclidean', 'hamming', 'jaccard', + 'jensenshannon', 'kulsinski', 'mahalanobis', 'matching', 'minkowski', + 'rogerstanimoto', 'russellrao', 'seuclidean', 'sokalmichener', + 'sokalsneath', 'sqeuclidean', 'yule' + if function, should have signature 1D-np.array x 1D-np.array -> float + p: if metric='minkowski', the "p" in "p-norm", otherwise irrelevant + colalign: string, one of 'intersect' (default), 'force-align', 'none' + controls column alignment if X, X2 passed in fit are 
pd.DataFrame + columns between X and X2 are aligned via column names + if 'intersect', distance is computed on columns occurring both in X and X2, + other columns are discarded; column ordering in X2 is copied from X + if 'force-align', raises an error if the set of columns in X, X2 differs; + column ordering in X2 is copied from X + if 'none', X and X2 are passed through unmodified (no columns are aligned) + note: this will potentially align "non-matching" columns """ _tags = { @@ -52,25 +55,27 @@ def __init__(self, metric="euclidean", p=2, colalign="intersect"): super(ScipyDist, self).__init__() def _transform(self, X, X2=None): - """ + """Compute distance/kernel matrix. + + Core logic + Behaviour: returns pairwise distance/kernel matrix between samples in X and X2 if X2 is not passed, is equal to X if X/X2 is a pd.DataFrame and contains non-numeric columns, these are removed before computation - Args: - X: pd.DataFrame of length n, or 2D np.array of 'float' with n rows + Parameters + ---------- + X: pd.DataFrame of length n, or 2D np.array with n rows + X2: pd.DataFrame of length m, or 2D np.array with m rows, optional + default X2 = X - Optional args: - X2: pd.DataFrame of length m, or 2D np.array of 'float' with m rows - - Returns: - distmat: np.array of shape [n, m] - (i,j)-th entry contains distance between X.iloc[i] and X2.iloc[j] - (non-numeric columns are removed before for DataFrame X/X2) + Returns + ------- + distmat: np.array of shape [n, m] + (i,j)-th entry contains distance/kernel between X.iloc[i] and X2.iloc[j] """ - p = self.p metric = self.metric diff --git a/sktime/forecasting/base/_base.py b/sktime/forecasting/base/_base.py index 6be5daff6da..df403436242 100644 --- a/sktime/forecasting/base/_base.py +++ b/sktime/forecasting/base/_base.py @@ -181,7 +181,6 @@ def fit_predict( y_pred_int : pd.DataFrame - only if return_pred_int=True Prediction intervals """ - self.fit(y=y, X=X, fh=fh) return self._predict(fh=fh, X=X, 
return_pred_int=return_pred_int, alpha=alpha) @@ -550,7 +549,7 @@ def _set_fh(self, fh): ---------- fh : None, int, list, np.ndarray or ForecastingHorizon """ - requires_fh = self._all_tags().get("requires-fh-in-fit", True) + requires_fh = self.get_tag("requires-fh-in-fit") msg = ( f"This is because fitting of the `" @@ -735,7 +734,6 @@ def _compute_pred_int(self, alphas): Each series in the list will contain the errors for each point in the forecast for the corresponding alpha. """ - # this should be the NotImplementedError # but current interface assumes private method # _compute_pred_err(alphas), not _compute_pred_int @@ -746,7 +744,7 @@ def _compute_pred_int(self, alphas): # raise NotImplementedError("abstract method") def _compute_pred_err(self, alphas): - """ temporary loopthrough for _compute_pred_err""" + """Temporary loopthrough for _compute_pred_err.""" raise NotImplementedError("abstract method") def _predict_moving_cutoff( diff --git a/sktime/forecasting/compose/_pipeline.py b/sktime/forecasting/compose/_pipeline.py index 49dad3ce62b..cff5b8071ac 100644 --- a/sktime/forecasting/compose/_pipeline.py +++ b/sktime/forecasting/compose/_pipeline.py @@ -11,7 +11,6 @@ from sktime.forecasting.base._base import BaseForecaster from sktime.forecasting.base._base import DEFAULT_ALPHA from sktime.transformations.base import _SeriesToSeriesTransformer -from sktime.utils import _has_tag from sktime.utils.validation.series import check_series @@ -354,7 +353,7 @@ def _predict(self, fh=None, X=None, return_pred_int=False, alpha=DEFAULT_ALPHA): for _, _, transformer in self._iter_transformers(reverse=True): # skip sktime transformers where inverse transform # is not wanted ur meaningful (e.g. 
Imputer, HampelFilter) - skip_trafo = transformer._all_tags().get("skip-inverse-transform", False) + skip_trafo = transformer.get_tag("skip-inverse-transform", False) if not skip_trafo: y_pred = transformer.inverse_transform(y_pred) return y_pred @@ -394,6 +393,6 @@ def inverse_transform(self, Z, X=None): self.check_is_fitted() zt = check_series(Z, enforce_univariate=True) for _, _, transformer in self._iter_transformers(reverse=True): - if not _has_tag(transformer, "skip-inverse-transform"): + if not transformer.get_tag("skip-inverse-transform", False): zt = transformer.inverse_transform(zt, X) return zt diff --git a/sktime/forecasting/tests/test_sktime_forecasters.py b/sktime/forecasting/tests/test_sktime_forecasters.py index 5d654ca0503..f351c3fda8a 100644 --- a/sktime/forecasting/tests/test_sktime_forecasters.py +++ b/sktime/forecasting/tests/test_sktime_forecasters.py @@ -73,10 +73,12 @@ def test_oh_setting(Forecaster): # check setting/getting API for forecasting horizon # divide Forecasters into groups based on when fh is required -FORECASTERS_REQUIRED = [f for f in FORECASTERS if f._all_tags()["requires-fh-in-fit"]] +FORECASTERS_REQUIRED = [ + f for f in FORECASTERS if f.get_class_tag("requires-fh-in-fit", True) +] FORECASTERS_OPTIONAL = [ - f for f in FORECASTERS if not f._all_tags()["requires-fh-in-fit"] + f for f in FORECASTERS if not f.get_class_tag("requires-fh-in-fit", True) ] diff --git a/sktime/performance_metrics/base/_base.py b/sktime/performance_metrics/base/_base.py index db337d11930..9b14afa13f7 100644 --- a/sktime/performance_metrics/base/_base.py +++ b/sktime/performance_metrics/base/_base.py @@ -5,14 +5,13 @@ __author__ = ["Ryan Kuhns"] __all__ = ["BaseMetric"] -import inspect -from sklearn.base import BaseEstimator +from sktime.base import BaseObject -class BaseMetric(BaseEstimator): +class BaseMetric(BaseObject): """Base class for defining metrics in sktime. - Extends scikit-learn's BaseEstimator. + Extends sktime BaseObject. 
""" def __init__(self, func, name=None): @@ -22,27 +21,3 @@ def __init__(self, func, name=None): def __call__(self, y_true, y_pred, **kwargs): """Calculate metric value using underlying metric function.""" NotImplementedError("abstract method") - - # This is copied from sktime.base.BaseEstimator. Choice to copy was made to - # Avoid the not applicable functionality from BaseEstimator that tripped - # up unit tests (e.g. is_fitted, check_is_fitted). - @classmethod - def _all_tags(cls): - """Get tags from estimator class and all its parent classes.""" - # We here create a separate estimator tag interface in addition to the one in - # scikit-learn to make sure we do not interfere with scikit-learn's one - # when we inherit from scikit-learn classes. We also make estimator tags a - # class rather than object attribute. - collected_tags = dict() - - # We exclude the last two parent classes; sklearn.base.BaseEstimator and - # the basic Python object. - for parent_class in reversed(inspect.getmro(cls)[:-2]): - if hasattr(parent_class, "_tags"): - # Need the if here because mixins might not have _more_tags - # but might do redundant work in estimators - # (i.e. 
calling more tags on BaseEstimator multiple times) - more_tags = parent_class._tags - collected_tags.update(more_tags) - - return collected_tags diff --git a/sktime/transformations/tests/test_all_transformers.py b/sktime/transformations/tests/test_all_transformers.py index e7f0a012f7e..94fcbbf2d53 100644 --- a/sktime/transformations/tests/test_all_transformers.py +++ b/sktime/transformations/tests/test_all_transformers.py @@ -15,7 +15,6 @@ from sktime.transformations.base import _SeriesToPrimitivesTransformer from sktime.transformations.base import _SeriesToSeriesTransformer from sktime.utils import all_estimators -from sktime.utils import _has_tag from sktime.utils._testing.estimator_checks import _assert_array_almost_equal from sktime.utils._testing.estimator_checks import _construct_instance from sktime.utils._testing.estimator_checks import _make_args @@ -58,7 +57,7 @@ def check_series_to_primitive_transform_univariate(Estimator, **kwargs): def _check_raises_error(Estimator, **kwargs): with pytest.raises(ValueError, match=r"univariate"): - if _has_tag(Estimator, "fit-in-transform"): + if Estimator.get_class_tag("fit-in-transform", False): # As some estimators have an empty fit method, we here check if # they raise the appropriate error in transform rather than fit. 
_construct_fit_transform(Estimator, **kwargs) @@ -69,7 +68,7 @@ def _check_raises_error(Estimator, **kwargs): def check_series_to_primitive_transform_multivariate(Estimator): n_columns = 3 - if _has_tag(Estimator, "univariate-only"): + if Estimator.get_class_tag("univariate-only", False): _check_raises_error(Estimator, n_columns=n_columns) else: out = _construct_fit_transform(Estimator, n_columns=n_columns) @@ -82,7 +81,7 @@ def check_series_to_series_transform_univariate(Estimator): out = _construct_fit_transform( Estimator, n_timepoints=n_timepoints, - add_nan=_has_tag(Estimator, "handles-missing-data"), + add_nan=Estimator.get_class_tag("handles-missing-data", False), ) assert isinstance(out, (pd.Series, np.ndarray, pd.DataFrame)) @@ -90,7 +89,7 @@ def check_series_to_series_transform_univariate(Estimator): def check_series_to_series_transform_multivariate(Estimator): n_columns = 3 n_timepoints = 15 - if _has_tag(Estimator, "univariate-only"): + if Estimator.get_class_tag("univariate-only", False): _check_raises_error(Estimator, n_timepoints=n_timepoints, n_columns=n_columns) else: out = _construct_fit_transform( @@ -109,7 +108,7 @@ def check_panel_to_tabular_transform_univariate(Estimator): def check_panel_to_tabular_transform_multivariate(Estimator): n_instances = 5 - if _has_tag(Estimator, "univariate-only"): + if Estimator.get_class_tag("univariate-only", False): _check_raises_error(Estimator, n_instances=n_instances, n_columns=3) else: out = _construct_fit_transform(Estimator, n_instances=n_instances, n_columns=3) @@ -130,7 +129,7 @@ def check_panel_to_panel_transform_univariate(Estimator): def check_panel_to_panel_transform_multivariate(Estimator): n_instances = 5 - if _has_tag(Estimator, "univariate-only"): + if Estimator.get_class_tag("univariate-only", False): _check_raises_error(Estimator, n_instances=n_instances, n_columns=3) else: out = _construct_fit_transform(Estimator, n_instances=n_instances, n_columns=3) @@ -198,5 +197,5 @@ def 
_yield_transformer_checks(Estimator): yield from panel_to_tabular_checks if issubclass(Estimator, _PanelToPanelTransformer): yield from panel_to_panel_checks - if _has_tag(Estimator, "transform-returns-same-time-index"): + if Estimator.get_class_tag("transform-returns-same-time-index", False): yield check_transform_returns_same_time_index diff --git a/sktime/utils/__init__.py b/sktime/utils/__init__.py index 73367d331ed..15add14df6e 100644 --- a/sktime/utils/__init__.py +++ b/sktime/utils/__init__.py @@ -176,20 +176,3 @@ def _get_err_msg(estimator_type): return all_estimators else: return [estimator for (name, estimator) in all_estimators] - - -def _has_tag(Estimator, tag): - """Check whether an Estimator has the given tag or not. - - Parameters - ---------- - Estimator : Estimator class - tag : str - An Estimator tag like "skip-inverse-transform" - - Returns - ------- - bool - """ - # Check if tag is in all tags - return Estimator._all_tags().get(tag, False) diff --git a/sktime/utils/_testing/estimator_checks.py b/sktime/utils/_testing/estimator_checks.py index 660f0f2a6a7..6f9c33ca33b 100644 --- a/sktime/utils/_testing/estimator_checks.py +++ b/sktime/utils/_testing/estimator_checks.py @@ -48,7 +48,6 @@ from sktime.utils._testing.panel import make_regression_problem from sktime.utils._testing.panel import make_clustering_problem from sktime.utils.data_processing import is_nested_dataframe -from sktime.utils import _has_tag from sktime.clustering.base.base import BaseClusterer from sktime.annotation.base import BaseSeriesAnnotator @@ -131,8 +130,8 @@ def check_required_params(Estimator): def check_estimator_tags(Estimator): - assert hasattr(Estimator, "_all_tags") - all_tags = Estimator._all_tags() + assert hasattr(Estimator, "get_class_tags") + all_tags = Estimator.get_class_tags() assert isinstance(all_tags, dict) assert all([isinstance(key, str) for key in all_tags.keys()]) if hasattr(Estimator, "_tags"): @@ -421,7 +420,7 @@ def 
check_methods_do_not_change_state(Estimator): args = _make_args(estimator, method) getattr(estimator, method)(*args) - if method == "transform" and _has_tag(Estimator, "fit-in-transform"): + if method == "transform" and Estimator.get_class_tag("fit-in-transform"): # Some transformations fit during transform, as they apply # some transformation to each series passed to transform, # so transform will actually change the state of these estimator. @@ -506,7 +505,7 @@ def check_multiprocessing_idempotent(Estimator): def check_valid_estimator_tags(Estimator): # check if Estimator tags are in VALID_ESTIMATOR_TAGS - for tag in Estimator._all_tags().keys(): + for tag in Estimator.get_class_tags().keys(): assert tag in VALID_ESTIMATOR_TAGS diff --git a/sktime/utils/validation/forecasting.py b/sktime/utils/validation/forecasting.py index 4f834d87af0..e7800218382 100644 --- a/sktime/utils/validation/forecasting.py +++ b/sktime/utils/validation/forecasting.py @@ -19,7 +19,6 @@ import numpy as np import pandas as pd -from sktime.utils import _has_tag from sktime.utils.validation import is_int from sktime.utils.validation.series import check_equal_time_index from sktime.utils.validation.series import check_series @@ -346,7 +345,9 @@ def check_scoring(scoring, allow_y_pred_benchmark=False): if scoring is None: return MeanAbsolutePercentageError() - if _has_tag(scoring, "requires-y-pred-benchmark") and not allow_y_pred_benchmark: + scoring_req_bench = scoring.get_class_tag("requires-y-pred-benchmark", False) + + if scoring_req_bench and not allow_y_pred_benchmark: msg = """Scoring requiring benchmark forecasts (y_pred_benchmark) are not fully supported yet. Please use a performance metric that does not require y_pred_benchmark as a keyword argument in its call signature.