Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

BaseObject and rework of tags system, including dynamic tags #1134

Merged
merged 47 commits into from Jul 17, 2021
Merged
Show file tree
Hide file tree
Changes from 19 commits
Commits
Show all changes
47 commits
Select commit Hold shift + click to select a range
28421db
Added BaseObject
RNKuhns May 20, 2021
cb46ae4
Updated docs to comply with pydocstyle
RNKuhns May 20, 2021
6da97c0
added dynamic tag handling
fkiraly Jul 3, 2021
a8bb08f
Merge branch 'main' into 891-suggestions
fkiraly Jul 3, 2021
bf74e65
docstrings
fkiraly Jul 3, 2021
7a63b36
fixed BaseObject.__init__
fkiraly Jul 3, 2021
271f095
Merge branch 'main' into 891-suggestions
fkiraly Jul 5, 2021
645f297
Merge branch 'main' into 891-suggestions
fkiraly Jul 5, 2021
03a6104
Merge branch 'main' into 891-suggestions
fkiraly Jul 12, 2021
a86a661
Merge branch 'main' into 891-suggestions
fkiraly Jul 12, 2021
a23b77c
made changes to base class to comply with discussion result in #981
fkiraly Jul 13, 2021
952a9d3
BaseMetric inherits from BaseObject now
fkiraly Jul 13, 2021
d6052e1
changed all downstream references to new _base interface
fkiraly Jul 13, 2021
af2eb4e
one more
fkiraly Jul 13, 2021
2a424cd
removed _has_tag from utils, no longer used
fkiraly Jul 13, 2021
ad02fdb
linting
fkiraly Jul 13, 2021
2d41923
fixed baseobject import in metrics
fkiraly Jul 13, 2021
c9da460
updated basemetric docstring
fkiraly Jul 13, 2021
b3332c8
fixed get_class_tag reference in baseobject
fkiraly Jul 13, 2021
e18cb85
Update sktime/base/_base.py
fkiraly Jul 14, 2021
01b505f
Update sktime/base/_base.py
fkiraly Jul 14, 2021
326e028
renamed mirror_tags and tag_set to clone_tags and tag_names
fkiraly Jul 14, 2021
f084cc8
docstrings: renamed arguments to parameters
fkiraly Jul 14, 2021
d2db6ea
Merge branch 'main' into 891-refactor
fkiraly Jul 14, 2021
abae38a
updated module docstring
fkiraly Jul 14, 2021
0158fab
Merge branch '891-refactor' of https://github.com/alan-turing-institu…
fkiraly Jul 14, 2021
42871dd
removed type since class methods work no objects
fkiraly Jul 14, 2021
a0c0ba8
Merge branch 'main' into 891-refactor
fkiraly Jul 14, 2021
8ed82bf
Update sktime/base/_base.py
fkiraly Jul 15, 2021
3655e9e
added deepcopy in to return/input of dict
fkiraly Jul 15, 2021
95db8e2
Merge branch 'main' into 891-refactor
fkiraly Jul 15, 2021
d163a19
removed default tag from base forecaster _set_fh
fkiraly Jul 16, 2021
2f4e096
deep_equals utility for testing
fkiraly Jul 16, 2021
21a1f65
testing BaseObject methods
fkiraly Jul 16, 2021
790abfd
Merge branch 'main' into 891-refactor
fkiraly Jul 16, 2021
3b43b17
add comment for object fixture
fkiraly Jul 16, 2021
cba4d99
typo fixing
fkiraly Jul 16, 2021
c80fda1
added test contribution :smiley:
fkiraly Jul 16, 2021
feb01ac
changed all tag entries to primitives in test, removed deep_equals
fkiraly Jul 17, 2021
1d761a4
Merge branch 'main' into 891-refactor
fkiraly Jul 17, 2021
75edaa7
Merge branch 'main' into 891-refactor
fkiraly Jul 17, 2021
94a7834
made fixture variables private in test
fkiraly Jul 17, 2021
f733a7d
linting tests
fkiraly Jul 17, 2021
ae14882
classifier base docstrings
fkiraly Jul 17, 2021
07f3505
updated docstrings
fkiraly Jul 17, 2021
6c745eb
linting base forecaster
fkiraly Jul 17, 2021
3f84787
Merge branch 'main' into 891-refactor
fkiraly Jul 17, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
7 changes: 5 additions & 2 deletions sktime/base/__init__.py
Expand Up @@ -2,11 +2,14 @@
# -*- coding: utf-8 -*-
# copyright: sktime developers, BSD-3-Clause License (see LICENSE file)

__author__ = ["Markus Löning"]
"""Base classes for defining estimators and other objects in sktime."""

__author__ = ["mloning", "RNKuhns", "fkiraly"]
__all__ = [
"BaseObject",
"BaseEstimator",
"_HeterogenousMetaEstimator",
]

from sktime.base._base import BaseEstimator
from sktime.base._base import BaseObject, BaseEstimator
from sktime.base._meta import _HeterogenousMetaEstimator
219 changes: 192 additions & 27 deletions sktime/base/_base.py
@@ -1,9 +1,49 @@
#!/usr/bin/env python3 -u
# -*- coding: utf-8 -*-
# copyright: sktime developers, BSD-3-Clause License (see LICENSE file)
"""
Base class template for objects and fittable objects.

__author__ = ["Markus Löning"]
__all__ = ["BaseEstimator"]
templates in this module:

BaseObject - object with parameters and tags
BaseEstimator - BaseObject that can be fitted

Interface specifications below.

---

class name: BaseObject

Hyper-parameter inspection and setter methods:
inspect hyper-parameters - get_params()
setting hyper-parameters - set_params(**params)

Tag inspection and setter methods
inspect tags (all) - get_tags()
inspect tags (one tag) - get_tag(tag_name: str, tag_value_default=None)
inspect tags (class only) - get_class_tags()
inspect tags (class, one) - get_class_tag(tag_name:str, tag_value_default=None)
setting dynamic tags - set_tag(**tag_dict: dict)
set mirrored dynamic tags - mirror_tags(estimator)

---

class name: BaseEstimator

Provides all interface points of BaseObject, plus:

Parameter inspection:
fitted parameter inspection - get_fitted_params()

State:
fitted model/strategy - by convention, any attributes ending in "_"
fitted state flag - is_fitted (property)
fitted state check - check_is_fitted (raises error if not is_fitted)

copyright: sktime developers, BSD-3-Clause License (see LICENSE file)
"""

__author__ = ["mloning", "RNKuhns", "fkiraly"]
__all__ = ["BaseEstimator", "BaseObject"]

import inspect

Expand All @@ -14,14 +54,160 @@
from sktime.exceptions import NotFittedError


class BaseEstimator(_BaseEstimator):
class BaseObject(_BaseEstimator):
"""Base class for parametric objects with tags sktime.

Extends scikit-learn's BaseEstimator to include sktime interface for tags.
"""

def __init__(self):
self._tags_dynamic = dict()
fkiraly marked this conversation as resolved.
Show resolved Hide resolved
super(BaseObject, self).__init__()

@classmethod
def get_class_tags(cls):
"""Get class tags from estimator class and all its parent classes.

Returns
-------
collected_tags : dictionary of tag names : tag values
collected from _tags class attribute via nested inheritance
NOT overridden by dynamic tags set by set_tags or mirror_tags
"""
collected_tags = dict()

# We exclude the last two parent classes: sklearn.base.BaseEstimator and
# the basic Python object.
for parent_class in reversed(inspect.getmro(cls)[:-2]):
mloning marked this conversation as resolved.
Show resolved Hide resolved
if hasattr(parent_class, "_tags"):
# Need the if here because mixins might not have _more_tags
# but might do redundant work in estimators
# (i.e. calling more tags on BaseEstimator multiple times)
more_tags = parent_class._tags
collected_tags.update(more_tags)

return collected_tags

@classmethod
def get_class_tag(cls, tag_name, tag_value_default=None):
"""Get tag value from estimator class (only class tags).

Arguments
---------
tag_name : str, name of tag value
tag_value_default : any type, default/fallback value if tag is not found

Returns
-------
tag_value : value of the tag tag_name in self if found
if tag is not found, returns tag_value_default
"""
collected_tags = cls.get_class_tags()

if tag_name in collected_tags.keys():
return collected_tags[tag_name]
else:
return tag_value_default
fkiraly marked this conversation as resolved.
Show resolved Hide resolved

def get_tags(self):
"""Get tags from estimator class and dynamic tag overrides.

Returns
-------
collected_tags : dictionary of tag names : tag values
collected from _tags class attribute via nested inheritance
then any overrides and new tags from _tags_dynamic object attribute
"""
collected_tags = type(self).get_class_tags().copy()
fkiraly marked this conversation as resolved.
Show resolved Hide resolved

if hasattr(self, "_tags_dynamic"):
collected_tags.update(self._tags_dynamic)

return collected_tags

def get_tag(self, tag_name, tag_value_default=None):
"""Get tag value from estimator class and dynamic tag overrides.

Arguments
fkiraly marked this conversation as resolved.
Show resolved Hide resolved
---------
tag_name : str, name of tag value
tag_value_default : any type, default/fallback value if tag is not found

Returns
-------
tag_value : value of the tag tag_name in self if found
if tag is not found, returns tag_value_default
"""
collected_tags = self.get_tags()

if tag_name in collected_tags.keys():
return collected_tags[tag_name]
else:
return tag_value_default
fkiraly marked this conversation as resolved.
Show resolved Hide resolved

def set_tags(self, **tag_dict):
"""Set dynamic tags to given values.

Arguments
---------
tag_dict : dictionary of tag names : tag values

Returns
-------
reference to self

State change
------------
sets tag values in tag_dict as dynamic tags in self
"""
self._tags_dynamic.update(tag_dict.copy())
mloning marked this conversation as resolved.
Show resolved Hide resolved

return self

def mirror_tags(self, estimator, tag_set=None):
fkiraly marked this conversation as resolved.
Show resolved Hide resolved
"""Mirror tags from estimator as dynamic override.

Arguments
---------
estimator : an estimator inheriting from BaseEstimator
tag_set : list of str, or str; tag names
default = list of all tags in estimator

Returns
-------
reference to self

State change
------------
sets tag values in tag_set from estimator as dynamic tags in self
"""
tags_est = estimator.get_tags().copy()
fkiraly marked this conversation as resolved.
Show resolved Hide resolved

# if tag_set is not passed, default is all tags in estimator
if tag_set is None:
tag_set = tags_est.keys()
else:
# if tag_set is passed, intersect keys with tags in estimator
if not isinstance(tag_set, list):
tag_set = [tag_set]
tag_set = [key for key in tag_set if key in tags_est.keys()]

update_dict = {key: tags_est[key] for key in tag_set}

self.set_tags(update_dict)

return self


class BaseEstimator(BaseObject):
"""Base class for defining estimators in sktime.

Extends scikit-learn's BaseEstimator.
Extends sktime's BaseObject to include basic functionality for fittable estimators.
"""

def __init__(self):
self._is_fitted = False
super(BaseEstimator, self).__init__()

@property
def is_fitted(self):
Expand All @@ -42,27 +228,6 @@ def check_is_fitted(self):
f"been fitted yet; please call `fit` first."
)

@classmethod
def _all_tags(cls):
"""Get tags from estimator class and all its parent classes."""
# We here create a separate estimator tag interface in addition to the one in
# scikit-learn to make sure we do not interfere with scikit-learn's one
# when we inherit from scikit-learn classes. We also make estimator tags a
# class rather than object attribute.
collected_tags = dict()

# We exclude the last two parent classes; sklearn.base.BaseEstimator and
# the basic Python object.
for parent_class in reversed(inspect.getmro(cls)[:-2]):
if hasattr(parent_class, "_tags"):
# Need the if here because mixins might not have _more_tags
# but might do redundant work in estimators
# (i.e. calling more tags on BaseEstimator multiple times)
more_tags = parent_class._tags
collected_tags.update(more_tags)

return collected_tags


def _clone_estimator(base_estimator, random_state=None):
estimator = clone(base_estimator)
Expand Down
7 changes: 4 additions & 3 deletions sktime/base/_meta.py
Expand Up @@ -11,16 +11,17 @@


class _HeterogenousMetaEstimator(BaseEstimator, metaclass=ABCMeta):
"""Handles parameter management for estimtators composed of named
estimators.
"""Handles parameter management for estimtators composed of named estimators.

from sklearn utils.metaestimator.py
From sklearn utils.metaestimator.py
"""

def get_params(self, deep=True):
"""Return estimator parameters."""
raise NotImplementedError("abstract method")

def set_params(self, **params):
"""Set estimator parameters."""
raise NotImplementedError("abstract method")

def _get_params(self, attr, deep=True):
Expand Down
4 changes: 2 additions & 2 deletions sktime/classification/base.py
Expand Up @@ -77,7 +77,7 @@ def fit(self, X, y):
sets is_fitted flag to true
"""

coerce_to_numpy = self._all_tags()["coerce-X-to-numpy"]
coerce_to_numpy = self.get_class_tag("coerce-X-to-numpy", False)
mloning marked this conversation as resolved.
Show resolved Hide resolved

X, y = check_X_y(X, y, coerce_to_numpy=coerce_to_numpy)

Expand All @@ -103,7 +103,7 @@ def predict(self, X):
y : array-like, shape = [n_instances] - predicted class labels
"""

coerce_to_numpy = self._all_tags()["coerce-X-to-numpy"]
coerce_to_numpy = self.get_class_tag("coerce-X-to-numpy", False)

X = check_X(X, coerce_to_numpy=coerce_to_numpy)
self.check_is_fitted()
Expand Down
2 changes: 1 addition & 1 deletion sktime/dists_kernels/compose_tab_to_panel.py
Expand Up @@ -75,7 +75,7 @@ def _transform(self, X, X2=None):
aggfunc = np.mean
aggfunc_is_symm = True

transformer_symm = self.transformer._all_tags()["symmetric"]
transformer_symm = self.transformer.get_tag("symmetric", False)

# whether we know that resulting matrix must be symmetric
# a sufficient condition for this:
Expand Down
2 changes: 1 addition & 1 deletion sktime/forecasting/base/_base.py
Expand Up @@ -550,7 +550,7 @@ def _set_fh(self, fh):
----------
fh : None, int, list, np.ndarray or ForecastingHorizon
"""
requires_fh = self._all_tags().get("requires-fh-in-fit", True)
requires_fh = self.get_tag("requires-fh-in-fit", True)

msg = (
f"This is because fitting of the `"
Expand Down
5 changes: 2 additions & 3 deletions sktime/forecasting/compose/_pipeline.py
Expand Up @@ -11,7 +11,6 @@
from sktime.forecasting.base._base import BaseForecaster
from sktime.forecasting.base._base import DEFAULT_ALPHA
from sktime.transformations.base import _SeriesToSeriesTransformer
from sktime.utils import _has_tag
from sktime.utils.validation.series import check_series


Expand Down Expand Up @@ -354,7 +353,7 @@ def _predict(self, fh=None, X=None, return_pred_int=False, alpha=DEFAULT_ALPHA):
for _, _, transformer in self._iter_transformers(reverse=True):
# skip sktime transformers where inverse transform
# is not wanted ur meaningful (e.g. Imputer, HampelFilter)
skip_trafo = transformer._all_tags().get("skip-inverse-transform", False)
skip_trafo = transformer.get_tag("skip-inverse-transform", False)
if not skip_trafo:
y_pred = transformer.inverse_transform(y_pred)
return y_pred
Expand Down Expand Up @@ -394,6 +393,6 @@ def inverse_transform(self, Z, X=None):
self.check_is_fitted()
zt = check_series(Z, enforce_univariate=True)
for _, _, transformer in self._iter_transformers(reverse=True):
if not _has_tag(transformer, "skip-inverse-transform"):
if not transformer.get_tag("skip-inverse-transform", False):
zt = transformer.inverse_transform(zt, X)
return zt
6 changes: 4 additions & 2 deletions sktime/forecasting/tests/test_sktime_forecasters.py
Expand Up @@ -73,10 +73,12 @@ def test_oh_setting(Forecaster):
# check setting/getting API for forecasting horizon

# divide Forecasters into groups based on when fh is required
FORECASTERS_REQUIRED = [f for f in FORECASTERS if f._all_tags()["requires-fh-in-fit"]]
FORECASTERS_REQUIRED = [
f for f in FORECASTERS if f.get_class_tag("requires-fh-in-fit", True)
]

FORECASTERS_OPTIONAL = [
f for f in FORECASTERS if not f._all_tags()["requires-fh-in-fit"]
f for f in FORECASTERS if not f.get_class_tag("requires-fh-in-fit", True)
]


Expand Down