Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

BaseObject and rework of tags system, including dynamic tags #1134

Merged
merged 47 commits into from Jul 17, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
47 commits
Select commit Hold shift + click to select a range
28421db
Added BaseObject
RNKuhns May 20, 2021
cb46ae4
Updated docs to comply with pydocstyle
RNKuhns May 20, 2021
6da97c0
added dynamic tag handling
fkiraly Jul 3, 2021
a8bb08f
Merge branch 'main' into 891-suggestions
fkiraly Jul 3, 2021
bf74e65
docstrings
fkiraly Jul 3, 2021
7a63b36
fixed BaseObject.__init__
fkiraly Jul 3, 2021
271f095
Merge branch 'main' into 891-suggestions
fkiraly Jul 5, 2021
645f297
Merge branch 'main' into 891-suggestions
fkiraly Jul 5, 2021
03a6104
Merge branch 'main' into 891-suggestions
fkiraly Jul 12, 2021
a86a661
Merge branch 'main' into 891-suggestions
fkiraly Jul 12, 2021
a23b77c
made changes to base class to comply with discussion result in #981
fkiraly Jul 13, 2021
952a9d3
BaseMetric inherits from BaseObject now
fkiraly Jul 13, 2021
d6052e1
changed all downstream references to new _base interface
fkiraly Jul 13, 2021
af2eb4e
one more
fkiraly Jul 13, 2021
2a424cd
removed _has_tag from utils, no longer used
fkiraly Jul 13, 2021
ad02fdb
linting
fkiraly Jul 13, 2021
2d41923
fixed baseobject import in metrics
fkiraly Jul 13, 2021
c9da460
updated basemetric docstring
fkiraly Jul 13, 2021
b3332c8
fixed get_class_tag reference in baseobject
fkiraly Jul 13, 2021
e18cb85
Update sktime/base/_base.py
fkiraly Jul 14, 2021
01b505f
Update sktime/base/_base.py
fkiraly Jul 14, 2021
326e028
renamed mirror_tags and tag_set to clone_tags and tag_names
fkiraly Jul 14, 2021
f084cc8
docstrings: renamed arguments to parameters
fkiraly Jul 14, 2021
d2db6ea
Merge branch 'main' into 891-refactor
fkiraly Jul 14, 2021
abae38a
updated module docstring
fkiraly Jul 14, 2021
0158fab
Merge branch '891-refactor' of https://github.com/alan-turing-institu…
fkiraly Jul 14, 2021
42871dd
removed type since class methods work no objects
fkiraly Jul 14, 2021
a0c0ba8
Merge branch 'main' into 891-refactor
fkiraly Jul 14, 2021
8ed82bf
Update sktime/base/_base.py
fkiraly Jul 15, 2021
3655e9e
added deepcopy in to return/input of dict
fkiraly Jul 15, 2021
95db8e2
Merge branch 'main' into 891-refactor
fkiraly Jul 15, 2021
d163a19
removed default tag from base forecaster _set_fh
fkiraly Jul 16, 2021
2f4e096
deep_equals utility for testing
fkiraly Jul 16, 2021
21a1f65
testing BaseObject methods
fkiraly Jul 16, 2021
790abfd
Merge branch 'main' into 891-refactor
fkiraly Jul 16, 2021
3b43b17
add comment for object fixture
fkiraly Jul 16, 2021
cba4d99
typo fixing
fkiraly Jul 16, 2021
c80fda1
added test contribution :smiley:
fkiraly Jul 16, 2021
feb01ac
changed all tag entries to primitives in test, removed deep_equals
fkiraly Jul 17, 2021
1d761a4
Merge branch 'main' into 891-refactor
fkiraly Jul 17, 2021
75edaa7
Merge branch 'main' into 891-refactor
fkiraly Jul 17, 2021
94a7834
made fixture variables private in test
fkiraly Jul 17, 2021
f733a7d
linting tests
fkiraly Jul 17, 2021
ae14882
classifier base docstrings
fkiraly Jul 17, 2021
07f3505
updated docstrings
fkiraly Jul 17, 2021
6c745eb
linting base forecaster
fkiraly Jul 17, 2021
3f84787
Merge branch 'main' into 891-refactor
fkiraly Jul 17, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions .all-contributorsrc
Expand Up @@ -28,6 +28,7 @@
"question",
"review",
"talk",
"test",
"tutorial"
]
},
Expand Down
7 changes: 5 additions & 2 deletions sktime/base/__init__.py
Expand Up @@ -2,11 +2,14 @@
# -*- coding: utf-8 -*-
# copyright: sktime developers, BSD-3-Clause License (see LICENSE file)

__author__ = ["Markus Löning"]
"""Base classes for defining estimators and other objects in sktime."""

__author__ = ["mloning", "RNKuhns", "fkiraly"]
__all__ = [
"BaseObject",
"BaseEstimator",
"_HeterogenousMetaEstimator",
]

from sktime.base._base import BaseEstimator
from sktime.base._base import BaseObject, BaseEstimator
from sktime.base._meta import _HeterogenousMetaEstimator
215 changes: 188 additions & 27 deletions sktime/base/_base.py
@@ -1,27 +1,209 @@
#!/usr/bin/env python3 -u
# -*- coding: utf-8 -*-
# copyright: sktime developers, BSD-3-Clause License (see LICENSE file)
"""
Base class template for objects and fittable objects.

__author__ = ["Markus Löning"]
__all__ = ["BaseEstimator"]
templates in this module:

BaseObject - object with parameters and tags
BaseEstimator - BaseObject that can be fitted

Interface specifications below.

---

class name: BaseObject

Hyper-parameter inspection and setter methods:
inspect hyper-parameters - get_params()
setting hyper-parameters - set_params(**params)

Tag inspection and setter methods
inspect tags (all) - get_tags()
inspect tags (one tag) - get_tag(tag_name: str, tag_value_default=None)
inspect tags (class method) - get_class_tags()
inspect tags (one tag, class) - get_class_tag(tag_name:str, tag_value_default=None)
setting dynamic tags - set_tag(**tag_dict: dict)
set/clone dynamic tags - clone_tags(estimator, tag_names=None)

---

class name: BaseEstimator

Provides all interface points of BaseObject, plus:

Parameter inspection:
fitted parameter inspection - get_fitted_params()

State:
fitted model/strategy - by convention, any attributes ending in "_"
fitted state flag - is_fitted (property)
fitted state check - check_is_fitted (raises error if not is_fitted)

copyright: sktime developers, BSD-3-Clause License (see LICENSE file)
"""

__author__ = ["mloning", "RNKuhns", "fkiraly"]
__all__ = ["BaseEstimator", "BaseObject"]

import inspect

from copy import deepcopy

from sklearn import clone
from sklearn.base import BaseEstimator as _BaseEstimator
from sklearn.ensemble._base import _set_random_states

from sktime.exceptions import NotFittedError


class BaseEstimator(_BaseEstimator):
class BaseObject(_BaseEstimator):
"""Base class for parametric objects with tags sktime.

Extends scikit-learn's BaseEstimator to include sktime interface for tags.
"""

def __init__(self):
self._tags_dynamic = dict()
fkiraly marked this conversation as resolved.
Show resolved Hide resolved
super(BaseObject, self).__init__()

@classmethod
def get_class_tags(cls):
"""Get class tags from estimator class and all its parent classes.

Returns
-------
collected_tags : dictionary of tag names : tag values
collected from _tags class attribute via nested inheritance
NOT overridden by dynamic tags set by set_tags or mirror_tags
"""
collected_tags = dict()

# We exclude the last two parent classes: sklearn.base.BaseEstimator and
# the basic Python object.
for parent_class in reversed(inspect.getmro(cls)[:-2]):
mloning marked this conversation as resolved.
Show resolved Hide resolved
if hasattr(parent_class, "_tags"):
# Need the if here because mixins might not have _more_tags
# but might do redundant work in estimators
# (i.e. calling more tags on BaseEstimator multiple times)
more_tags = parent_class._tags
collected_tags.update(more_tags)

return deepcopy(collected_tags)

@classmethod
def get_class_tag(cls, tag_name, tag_value_default=None):
"""Get tag value from estimator class (only class tags).

Parameters
----------
tag_name : str, name of tag value
tag_value_default : any type, default/fallback value if tag is not found

Returns
-------
tag_value : value of the tag tag_name in self if found
if tag is not found, returns tag_value_default
"""
collected_tags = cls.get_class_tags()

return collected_tags.get(tag_name, tag_value_default)

def get_tags(self):
"""Get tags from estimator class and dynamic tag overrides.

Returns
-------
collected_tags : dictionary of tag names : tag values
collected from _tags class attribute via nested inheritance
then any overrides and new tags from _tags_dynamic object attribute
"""
collected_tags = self.get_class_tags()

if hasattr(self, "_tags_dynamic"):
collected_tags.update(self._tags_dynamic)

return deepcopy(collected_tags)

def get_tag(self, tag_name, tag_value_default=None):
"""Get tag value from estimator class and dynamic tag overrides.

Parameters
----------
tag_name : str, name of tag value
tag_value_default : any type, default/fallback value if tag is not found

Returns
-------
tag_value : value of the tag tag_name in self if found
if tag is not found, returns tag_value_default
"""
collected_tags = self.get_tags()

return collected_tags.get(tag_name, tag_value_default)

def set_tags(self, **tag_dict):
"""Set dynamic tags to given values.

Parameters
----------
tag_dict : dictionary of tag names : tag values

Returns
-------
reference to self

State change
------------
sets tag values in tag_dict as dynamic tags in self
"""
self._tags_dynamic.update(deepcopy(tag_dict))

return self

def clone_tags(self, estimator, tag_names=None):
"""clone/mirror tags from another estimator as dynamic override.

Parameters
----------
estimator : an estimator inheriting from BaseEstimator
tag_names : list of str, or str; names of tags to clone
default = list of all tags in estimator

Returns
-------
reference to self

State change
------------
sets tag values in tag_set from estimator as dynamic tags in self
"""
tags_est = deepcopy(estimator.get_tags())

# if tag_set is not passed, default is all tags in estimator
if tag_names is None:
tag_names = tags_est.keys()
else:
# if tag_set is passed, intersect keys with tags in estimator
if not isinstance(tag_names, list):
tag_names = [tag_names]
tag_names = [key for key in tag_names if key in tags_est.keys()]

update_dict = {key: tags_est[key] for key in tag_names}

self.set_tags(update_dict)

return self


class BaseEstimator(BaseObject):
"""Base class for defining estimators in sktime.

Extends scikit-learn's BaseEstimator.
Extends sktime's BaseObject to include basic functionality for fittable estimators.
"""

def __init__(self):
self._is_fitted = False
super(BaseEstimator, self).__init__()

@property
def is_fitted(self):
Expand All @@ -42,27 +224,6 @@ def check_is_fitted(self):
f"been fitted yet; please call `fit` first."
)

@classmethod
def _all_tags(cls):
"""Get tags from estimator class and all its parent classes."""
# We here create a separate estimator tag interface in addition to the one in
# scikit-learn to make sure we do not interfere with scikit-learn's one
# when we inherit from scikit-learn classes. We also make estimator tags a
# class rather than object attribute.
collected_tags = dict()

# We exclude the last two parent classes; sklearn.base.BaseEstimator and
# the basic Python object.
for parent_class in reversed(inspect.getmro(cls)[:-2]):
if hasattr(parent_class, "_tags"):
# Need the if here because mixins might not have _more_tags
# but might do redundant work in estimators
# (i.e. calling more tags on BaseEstimator multiple times)
more_tags = parent_class._tags
collected_tags.update(more_tags)

return collected_tags


def _clone_estimator(base_estimator, random_state=None):
estimator = clone(base_estimator)
Expand Down
7 changes: 4 additions & 3 deletions sktime/base/_meta.py
Expand Up @@ -11,16 +11,17 @@


class _HeterogenousMetaEstimator(BaseEstimator, metaclass=ABCMeta):
"""Handles parameter management for estimtators composed of named
estimators.
"""Handles parameter management for estimtators composed of named estimators.

from sklearn utils.metaestimator.py
From sklearn utils.metaestimator.py
"""

def get_params(self, deep=True):
"""Return estimator parameters."""
raise NotImplementedError("abstract method")

def set_params(self, **params):
"""Set estimator parameters."""
raise NotImplementedError("abstract method")

def _get_params(self, attr, deep=True):
Expand Down