Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Error handling for get_tag #1450

Merged
merged 17 commits into from
Oct 1, 2021
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion extension_templates/annotation.py
Expand Up @@ -114,7 +114,7 @@ def __init__(
# if est.foo == 42:
# self.set_tags(handles-missing-data=True)
# example 2: cloning tags from component
# self.clone_tags(est2, ["enforce-index-type", "handles-missing-data"])
# self.clone_tags(est2, ["enforce_index_type", "handles-missing-data"])

# todo: implement this, mandatory
def _fit(self, X, Y=None):
Expand Down
4 changes: 2 additions & 2 deletions extension_templates/classification.py
Expand Up @@ -69,7 +69,7 @@ class MyTSC(BaseClassifier):
_tags = {
"handles-missing-data": False, # can estimator handle missing data?
"X-y-must-have-same-index": True, # can estimator handle different X/y index?
"enforce-index-type": None, # index type that needs to be enforced in X/y
"enforce_index_type": None, # index type that needs to be enforced in X/y
}
# in case of inheritance, concrete class should typically set tags
# alternatively, descendants can set tags in __init__ (avoid this if possible)
Expand Down Expand Up @@ -103,7 +103,7 @@ def __init__(self, est, parama, est2=None, paramb="default", paramc=None):
# if est.foo == 42:
# self.set_tags(handles-missing-data=True)
# example 2: cloning tags from component
# self.clone_tags(est2, ["enforce-index-type", "handles-missing-data"])
# self.clone_tags(est2, ["enforce_index_type", "handles-missing-data"])

# todo: implement this, mandatory
def _fit(self, X, y):
Expand Down
2 changes: 1 addition & 1 deletion extension_templates/dist_kern_panel.py
Expand Up @@ -92,7 +92,7 @@ def __init__(self, est, parama, est2=None, paramb="default", paramc=None):
# if est.foo == 42:
# self.set_tags(handles-missing-data=True)
# example 2: cloning tags from component
# self.clone_tags(est2, ["enforce-index-type", "handles-missing-data"])
# self.clone_tags(est2, ["enforce_index_type", "handles-missing-data"])

# todo: implement this, mandatory
def _transform(self, X, X2=None):
Expand Down
2 changes: 1 addition & 1 deletion extension_templates/dist_kern_tab.py
Expand Up @@ -92,7 +92,7 @@ def __init__(self, est, parama, est2=None, paramb="default", paramc=None):
# if est.foo == 42:
# self.set_tags(handles-missing-data=True)
# example 2: cloning tags from component
# self.clone_tags(est2, ["enforce-index-type", "handles-missing-data"])
# self.clone_tags(est2, ["enforce_index_type", "handles-missing-data"])

# todo: implement this, mandatory
def _transform(self, X, X2=None):
Expand Down
4 changes: 2 additions & 2 deletions extension_templates/forecasting.py
Expand Up @@ -79,7 +79,7 @@ class MyForecaster(BaseForecaster):
"X_inner_mtype": "pd.DataFrame", # which types do _fit, _predict, assume for X?
"requires-fh-in-fit": True, # is forecasting horizon already required in fit?
"X-y-must-have-same-index": True, # can estimator handle different X/y index?
"enforce-index-type": None, # index type that needs to be enforced in X/y
"enforce_index_type": None, # index type that needs to be enforced in X/y
}
# in case of inheritance, concrete class should typically set tags
# alternatively, descendants can set tags in __init__ (avoid this if possible)
Expand Down Expand Up @@ -113,7 +113,7 @@ def __init__(self, est, parama, est2=None, paramb="default", paramc=None):
# if est.foo == 42:
# self.set_tags(handles-missing-data=True)
# example 2: cloning tags from component
# self.clone_tags(est2, ["enforce-index-type", "handles-missing-data"])
# self.clone_tags(est2, ["enforce_index_type", "handles-missing-data"])

# todo: implement this, mandatory
def _fit(self, y, X=None, fh=None):
Expand Down
22 changes: 17 additions & 5 deletions sktime/base/_base.py
Expand Up @@ -132,25 +132,37 @@ class attribute via nested inheritance and then any overrides

return deepcopy(collected_tags)

def get_tag(self, tag_name, tag_value_default=None):
def get_tag(self, tag_name, tag_value_default=None, raise_error=True):
"""Get tag value from estimator class and dynamic tag overrides.

Parameters
----------
tag_name : str
Name of tag value.
tag_value_default : any type
Default/fallback value if tag is not found.
Name of tag to be retrieved
tag_value_default : any type, optional; default=None
Default/fallback value if tag is not found
raise_error : bool
whether a ValueError is raised when the tag is not found

Returns
-------
tag_value :
Value of the `tag_name` tag in self. If not found, returns
`tag_value_default`.

Raises
------
ValueError if raise_error is True and tag_name does not exist
i.e., if tag_name is not in self.get_tags().keys()
"""
collected_tags = self.get_tags()

return collected_tags.get(tag_name, tag_value_default)
tag_value = collected_tags.get(tag_name, tag_value_default)

if raise_error and tag_name not in collected_tags.keys():
raise ValueError(f"Tag with name {tag_name} could not be found.")

return tag_value

def set_tags(self, **tag_dict):
"""Set dynamic tags to given values.
Expand Down
1 change: 1 addition & 0 deletions sktime/classification/base.py
Expand Up @@ -74,6 +74,7 @@ class BaseClassifier(BaseEstimator):

_tags = {
"coerce-X-to-numpy": True,
"coerce-X-to-pandas": False,
}

def __init__(self):
Expand Down
2 changes: 1 addition & 1 deletion sktime/forecasting/base/_base.py
Expand Up @@ -71,7 +71,7 @@ class BaseForecaster(BaseEstimator):
"X_inner_mtype": "pd.DataFrame", # which types do _fit/_predict, support for X?
"requires-fh-in-fit": True, # is forecasting horizon already required in fit?
"X-y-must-have-same-index": True, # can estimator handle different X/y index?
"enforce-index-type": None, # index type that needs to be enforced in X/y
"enforce_index_type": None, # index type that needs to be enforced in X/y
}

def __init__(self):
Expand Down
2 changes: 1 addition & 1 deletion sktime/forecasting/model_selection/_tune.py
Expand Up @@ -65,7 +65,7 @@ def __init__(
"y_inner_mtype",
"X_inner_mtype",
"X-y-must-have-same-index",
"enforce-index-type",
"enforce_index_type",
]

self.clone_tags(forecaster, tags_to_clone)
Expand Down
2 changes: 1 addition & 1 deletion sktime/registry/_tags.py
Expand Up @@ -94,7 +94,7 @@
"do X/y in fit/update and X/fh in predict have to be same indices?",
),
(
"enforce-index-type",
"enforce_index_type",
["forecaster", "classifier", "regressor"],
"type",
"passed to input checks, input conversion index type to enforce",
Expand Down
35 changes: 25 additions & 10 deletions sktime/transformations/base.py
@@ -1,7 +1,7 @@
#!/usr/bin/env python3 -u
# -*- coding: utf-8 -*-
"""Base class template for transformers."""

__author__ = ["Markus Löning"]
__author__ = ["mloning"]
__all__ = [
"BaseTransformer",
"_SeriesToPrimitivesTransformer",
Expand Down Expand Up @@ -34,14 +34,25 @@


class BaseTransformer(BaseEstimator):
"""Transformer base class"""
"""Transformer base class."""

# default tag values - these typically make the "safest" assumption
_tags = {
"univariate-only": False, # can the transformer handle multivariate X?
"handles-missing-data": False, # can estimator handle missing data?
"X-y-must-have-same-index": False, # can estimator handle different X/y index?
"enforce_index_type": None, # index type that needs to be enforced in X/y
fkiraly marked this conversation as resolved.
Show resolved Hide resolved
"fit-in-transform": True, # is fit empty and can be skipped? Yes = True
"transform-returns-same-time-index": False,
# does transform return have the same time index as input X
"skip-inverse-transform": False, # is inverse-transform skipped when called?
}

def __init__(self):
super(BaseTransformer, self).__init__()

def fit(self, Z, X=None):
"""
Fit transformer to X and y.
"""Fit transformer to X and y.

By default, fit is empty. Fittable transformations overwrite fit method.

Expand All @@ -60,7 +71,7 @@ def fit(self, Z, X=None):
return self

def transform(self, Z, X=None):
"""Transform data. Returns a transformed version of X."""
"""Transform data. Returns a transformed version of Z."""
raise NotImplementedError("abstract method")

def fit_transform(self, Z, X=None):
Expand Down Expand Up @@ -98,28 +109,32 @@ def fit_transform(self, Z, X=None):


class _SeriesToPrimitivesTransformer(BaseTransformer):
"""Transformer base class for series to primitive(s) transforms"""
"""Transformer base class for series to primitive(s) transforms."""

def transform(self, Z: Series, X=None) -> Primitives:
"""Transform data. Returns a transformed version of Z."""
raise NotImplementedError("abstract method")


class _SeriesToSeriesTransformer(BaseTransformer):
"""Transformer base class for series to series transforms"""
"""Transformer base class for series to series transforms."""

def transform(self, Z: Series, X=None) -> Series:
"""Transform data. Returns a transformed version of Z."""
raise NotImplementedError("abstract method")


class _PanelToTabularTransformer(BaseTransformer):
"""Transformer base class for panel to tabular transforms"""
"""Transformer base class for panel to tabular transforms."""

def transform(self, X: Panel, y=None) -> Tabular:
"""Transform data. Returns a transformed version of X."""
raise NotImplementedError("abstract method")


class _PanelToPanelTransformer(BaseTransformer):
"""Transformer base class for panel to panel transforms"""
"""Transformer base class for panel to panel transforms."""

def transform(self, X: Panel, y=None) -> Panel:
"""Transform data. Returns a transformed version of X."""
raise NotImplementedError("abstract method")