Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Suggestions for PR #891 - BaseObject with dynamic tag handling functionality and dedicated get_tag interface point #1099

Closed
wants to merge 10 commits into from
7 changes: 5 additions & 2 deletions sktime/base/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,14 @@
# -*- coding: utf-8 -*-
# copyright: sktime developers, BSD-3-Clause License (see LICENSE file)

__author__ = ["Markus Löning"]
"""Base classes for defining estimators and other objects in sktime."""

__author__ = ["mloning", "RNKuhns", "fkiraly"]
__all__ = [
"BaseObject",
"BaseEstimator",
"_HeterogenousMetaEstimator",
]

from sktime.base._base import BaseEstimator
from sktime.base._base import BaseObject, BaseEstimator
from sktime.base._meta import _HeterogenousMetaEstimator
195 changes: 168 additions & 27 deletions sktime/base/_base.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,47 @@
#!/usr/bin/env python3 -u
# -*- coding: utf-8 -*-
# copyright: sktime developers, BSD-3-Clause License (see LICENSE file)
"""
Base class template for objects and fittable objects.

__author__ = ["Markus Löning"]
__all__ = ["BaseEstimator"]
templates in this module:

BaseObject - object with parameters and tags
BaseEstimator - BaseObject that can be fitted

Interface specifications below.

---

class name: BaseObject

Hyper-parameter inspection and setter methods:
inspect hyper-parameters - get_params()
setting hyper-parameters - set_params(**params)

Tag inspection and setter methods
inspect tags (all) - get_tags()
inspect tags (individual) - get_tag(tag_name: str)
setting dynamic tags - set_tags(tag_dict: dict)
set mirrored dynamic tags - mirror_tags(estimator, tag_set=None)

---

class name: BaseEstimator

Provides all interface points of BaseObject, plus:

Parameter inspection:
fitted parameter inspection - get_fitted_params()

State:
fitted model/strategy - by convention, any attributes ending in "_"
fitted state flag - is_fitted (property)
fitted state check - check_is_fitted (raises error if not is_fitted)

copyright: sktime developers, BSD-3-Clause License (see LICENSE file)
"""

__author__ = ["mloning", "RNKuhns", "fkiraly"]
__all__ = ["BaseEstimator", "BaseObject"]

import inspect

Expand All @@ -14,14 +52,138 @@
from sktime.exceptions import NotFittedError


class BaseEstimator(_BaseEstimator):
class BaseObject(_BaseEstimator):
    """Base class for parametric objects with tags in sktime.

    Extends scikit-learn's BaseEstimator to include sktime interface for tags.

    Tags are looked up in two layers:

    * class-level tags, collected from the ``_tags`` class attribute of the
      class and its parent classes (see ``_all_tags``);
    * dynamic, per-object tags stored in the ``_tags_dynamic`` attribute,
      which override or extend the class-level tags (see ``set_tags``).
    """

    def __init__(self):
        # per-instance tag overrides, filled by set_tags / mirror_tags
        self._tags_dynamic = dict()
        super(BaseObject, self).__init__()

    @classmethod
    def _all_tags(cls):
        """Get tags from estimator class and all its parent classes.

        Creates a separate sktime tag interface in addition to the one in
        scikit-learn to make sure it does not interfere with scikit-learn's tag
        interface when inheriting from scikit-learn classes. Sktime's
        estimator tags are class rather than object attribute as in scikit-learn.

        Returns
        -------
        collected_tags : dict
            Newly created dictionary of tag names : tag values. Classes are
            visited from the most basic parent to ``cls``, so tags of a
            child class override those of its parents.
        """
        collected_tags = dict()

        # We exclude the last two parent classes: sklearn.base.BaseEstimator and
        # the basic Python object.
        for parent_class in reversed(inspect.getmro(cls)[:-2]):
            # Mixins and other classes in the MRO need not define _tags,
            # so only classes that carry the attribute contribute.
            if hasattr(parent_class, "_tags"):
                collected_tags.update(parent_class._tags)

        return collected_tags

    def get_tags(self):
        """Get tags from estimator class and dynamic tag overrides.

        Returns
        -------
        collected_tags : dictionary of tag names : tag values
            collected from _tags class attribute via nested inheritance
            then any overrides and new tags from _tags_dynamic object attribute
        """
        collected_tags = type(self)._all_tags().copy()

        # _tags_dynamic may be missing if __init__ of this class was not
        # called, in that case only the class-level tags are returned
        collected_tags.update(getattr(self, "_tags_dynamic", {}))

        return collected_tags

    def get_tag(self, tag_name, tag_value_default=None):
        """Get tag value from estimator class and dynamic tag overrides.

        Parameters
        ----------
        tag_name : str, name of tag value
        tag_value_default : any type, default/fallback value if tag is not found

        Returns
        -------
        tag_value : value of the tag tag_name in self if found
            if tag is not found, returns tag_value_default
        """
        return self.get_tags().get(tag_name, tag_value_default)

    def set_tags(self, tag_dict):
        """Set dynamic tags to given values.

        Parameters
        ----------
        tag_dict : dictionary of tag names : tag values

        Returns
        -------
        reference to self

        Notes
        -----
        State change: sets tag values in tag_dict as dynamic tags in self.
        tag_dict itself is not modified and no reference to it is kept.
        """
        # Consistent with get_tags: tolerate objects on which
        # BaseObject.__init__ was never called, instead of raising
        # AttributeError.
        if not hasattr(self, "_tags_dynamic"):
            self._tags_dynamic = dict()

        # dict.update copies the entries, so tag_dict need not be copied first
        self._tags_dynamic.update(tag_dict)

        return self

    def mirror_tags(self, estimator, tag_set=None):
        """Mirror tags from estimator as dynamic override.

        Parameters
        ----------
        estimator : object with a get_tags method, e.g., inheriting from
            BaseObject
        tag_set : str, or list/tuple/set of str; tag names
            default = all tags in estimator
            names that are not tags of estimator are ignored

        Returns
        -------
        reference to self

        Notes
        -----
        State change: sets tag values in tag_set from estimator as dynamic
        tags in self.
        """
        tags_est = estimator.get_tags()

        if tag_set is None:
            # if tag_set is not passed, default is all tags in estimator
            tag_set = tags_est
        elif not isinstance(tag_set, (list, tuple, set, frozenset)):
            # a single tag name was passed; a str must not be iterated
            # character by character
            tag_set = [tag_set]

        # intersect requested names with the tags present in estimator
        update_dict = {key: tags_est[key] for key in tag_set if key in tags_est}

        self.set_tags(update_dict)

        return self


class BaseEstimator(BaseObject):
"""Base class for defining estimators in sktime.

Extends scikit-learn's BaseEstimator.
Extends sktime's BaseObject to include basic functionality for fittable estimators.
"""

def __init__(self):
self._is_fitted = False
super(BaseEstimator, self).__init__()

@property
def is_fitted(self):
Expand All @@ -42,27 +204,6 @@ def check_is_fitted(self):
f"been fitted yet; please call `fit` first."
)

@classmethod
def _all_tags(cls):
"""Get tags from estimator class and all its parent classes."""
# We here create a separate estimator tag interface in addition to the one in
# scikit-learn to make sure we do not interfere with scikit-learn's one
# when we inherit from scikit-learn classes. We also make estimator tags a
# class rather than object attribute.
collected_tags = dict()

# We exclude the last two parent classes; sklearn.base.BaseEstimator and
# the basic Python object.
for parent_class in reversed(inspect.getmro(cls)[:-2]):
if hasattr(parent_class, "_tags"):
# Need the if here because mixins might not have _more_tags
# but might do redundant work in estimators
# (i.e. calling more tags on BaseEstimator multiple times)
more_tags = parent_class._tags
collected_tags.update(more_tags)

return collected_tags


def _clone_estimator(base_estimator, random_state=None):
estimator = clone(base_estimator)
Expand Down
7 changes: 4 additions & 3 deletions sktime/base/_meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,17 @@


class _HeterogenousMetaEstimator(BaseEstimator, metaclass=ABCMeta):
"""Handles parameter management for estimtators composed of named
estimators.
"""Handles parameter management for estimators composed of named estimators.

from sklearn utils.metaestimator.py
From sklearn utils.metaestimator.py
"""

def get_params(self, deep=True):
"""Return estimator parameters."""
raise NotImplementedError("abstract method")

def set_params(self, **params):
"""Set estimator parameters."""
raise NotImplementedError("abstract method")

def _get_params(self, attr, deep=True):
Expand Down