Skip to content

Commit

Permalink
Error handling for get_tag (#1450)
Browse files Browse the repository at this point in the history
  • Loading branch information
fkiraly committed Oct 1, 2021
1 parent be2ddca commit f869fa4
Show file tree
Hide file tree
Showing 12 changed files with 69 additions and 28 deletions.
2 changes: 1 addition & 1 deletion extension_templates/annotation.py
Expand Up @@ -114,7 +114,7 @@ def __init__(
# if est.foo == 42:
# self.set_tags(handles-missing-data=True)
# example 2: cloning tags from component
# self.clone_tags(est2, ["enforce-index-type", "handles-missing-data"])
# self.clone_tags(est2, ["enforce_index_type", "handles-missing-data"])

# todo: implement this, mandatory
def _fit(self, X, Y=None):
Expand Down
4 changes: 2 additions & 2 deletions extension_templates/classification.py
Expand Up @@ -69,7 +69,7 @@ class MyTSC(BaseClassifier):
_tags = {
"handles-missing-data": False, # can estimator handle missing data?
"X-y-must-have-same-index": True, # can estimator handle different X/y index?
"enforce-index-type": None, # index type that needs to be enforced in X/y
"enforce_index_type": None, # index type that needs to be enforced in X/y
}
# in case of inheritance, concrete class should typically set tags
# alternatively, descendants can set tags in __init__ (avoid this if possible)
Expand Down Expand Up @@ -103,7 +103,7 @@ def __init__(self, est, parama, est2=None, paramb="default", paramc=None):
# if est.foo == 42:
# self.set_tags(handles-missing-data=True)
# example 2: cloning tags from component
# self.clone_tags(est2, ["enforce-index-type", "handles-missing-data"])
# self.clone_tags(est2, ["enforce_index_type", "handles-missing-data"])

# todo: implement this, mandatory
def _fit(self, X, y):
Expand Down
2 changes: 1 addition & 1 deletion extension_templates/dist_kern_panel.py
Expand Up @@ -92,7 +92,7 @@ def __init__(self, est, parama, est2=None, paramb="default", paramc=None):
# if est.foo == 42:
# self.set_tags(handles-missing-data=True)
# example 2: cloning tags from component
# self.clone_tags(est2, ["enforce-index-type", "handles-missing-data"])
# self.clone_tags(est2, ["enforce_index_type", "handles-missing-data"])

# todo: implement this, mandatory
def _transform(self, X, X2=None):
Expand Down
2 changes: 1 addition & 1 deletion extension_templates/dist_kern_tab.py
Expand Up @@ -92,7 +92,7 @@ def __init__(self, est, parama, est2=None, paramb="default", paramc=None):
# if est.foo == 42:
# self.set_tags(handles-missing-data=True)
# example 2: cloning tags from component
# self.clone_tags(est2, ["enforce-index-type", "handles-missing-data"])
# self.clone_tags(est2, ["enforce_index_type", "handles-missing-data"])

# todo: implement this, mandatory
def _transform(self, X, X2=None):
Expand Down
4 changes: 2 additions & 2 deletions extension_templates/forecasting.py
Expand Up @@ -79,7 +79,7 @@ class MyForecaster(BaseForecaster):
"X_inner_mtype": "pd.DataFrame", # which types do _fit, _predict, assume for X?
"requires-fh-in-fit": True, # is forecasting horizon already required in fit?
"X-y-must-have-same-index": True, # can estimator handle different X/y index?
"enforce-index-type": None, # index type that needs to be enforced in X/y
"enforce_index_type": None, # index type that needs to be enforced in X/y
}
# in case of inheritance, concrete class should typically set tags
# alternatively, descendants can set tags in __init__ (avoid this if possible)
Expand Down Expand Up @@ -113,7 +113,7 @@ def __init__(self, est, parama, est2=None, paramb="default", paramc=None):
# if est.foo == 42:
# self.set_tags(handles-missing-data=True)
# example 2: cloning tags from component
# self.clone_tags(est2, ["enforce-index-type", "handles-missing-data"])
# self.clone_tags(est2, ["enforce_index_type", "handles-missing-data"])

# todo: implement this, mandatory
def _fit(self, y, X=None, fh=None):
Expand Down
22 changes: 17 additions & 5 deletions sktime/base/_base.py
Expand Up @@ -132,25 +132,37 @@ class attribute via nested inheritance and then any overrides

return deepcopy(collected_tags)

def get_tag(self, tag_name, tag_value_default=None):
def get_tag(self, tag_name, tag_value_default=None, raise_error=True):
"""Get tag value from estimator class and dynamic tag overrides.
Parameters
----------
tag_name : str
Name of tag value.
tag_value_default : any type
Default/fallback value if tag is not found.
Name of tag to be retrieved
tag_value_default : any type, optional; default=None
Default/fallback value if tag is not found
raise_error : bool
whether a ValueError is raised when the tag is not found
Returns
-------
tag_value :
Value of the `tag_name` tag in self. If not found, returns
`tag_value_default`.
Raises
------
ValueError if raise_error is True and tag_name does not exist
i.e., if tag_name is not in self.get_tags().keys()
"""
collected_tags = self.get_tags()

return collected_tags.get(tag_name, tag_value_default)
tag_value = collected_tags.get(tag_name, tag_value_default)

if raise_error and tag_name not in collected_tags.keys():
raise ValueError(f"Tag with name {tag_name} could not be found.")

return tag_value

def set_tags(self, **tag_dict):
"""Set dynamic tags to given values.
Expand Down
19 changes: 16 additions & 3 deletions sktime/base/tests/test_base.py
Expand Up @@ -22,6 +22,8 @@
"test_set_tags",
]

import pytest

from copy import deepcopy

from sktime.base import BaseObject
Expand Down Expand Up @@ -118,10 +120,10 @@ def test_get_tag():
object_tags_keys = FIXTURE_OBJECT_TAGS.keys()

for key in object_tags_keys:
object_tags[key] = FIXTURE_OBJECT.get_tag(key)
object_tags[key] = FIXTURE_OBJECT.get_tag(key, raise_error=False)

object_tag_default = FIXTURE_OBJECT.get_tag("foo", "bar")
object_tag_defaultNone = FIXTURE_OBJECT.get_tag("bar")
object_tag_default = FIXTURE_OBJECT.get_tag("foo", "bar", raise_error=False)
object_tag_defaultNone = FIXTURE_OBJECT.get_tag("bar", raise_error=False)

msg = "Inheritance logic in BaseObject.get_tag is incorrect"

Expand All @@ -134,6 +136,17 @@ def test_get_tag():
assert object_tag_defaultNone is None, msg


def test_get_tag_raises():
"""Tests that get_tag method raises error for unknown tag.
Raises
------
AssertError if get_tag does not raise error for unknown tag.
"""
with pytest.raises(ValueError, match=r"Tag with name"):
FIXTURE_OBJECT.get_tag("bar")


FIXTURE_TAG_SET = {"A": 42424243, "E": 3}
FIXTURE_OBJECT_SET = deepcopy(FIXTURE_OBJECT).set_tags(**FIXTURE_TAG_SET)
FIXTURE_OBJECT_SET_TAGS = {"A": 42424243, "B": 3, "C": 1234, 3: "E", "E": 3}
Expand Down
1 change: 1 addition & 0 deletions sktime/classification/base.py
Expand Up @@ -74,6 +74,7 @@ class BaseClassifier(BaseEstimator):

_tags = {
"coerce-X-to-numpy": True,
"coerce-X-to-pandas": False,
}

def __init__(self):
Expand Down
2 changes: 1 addition & 1 deletion sktime/forecasting/base/_base.py
Expand Up @@ -71,7 +71,7 @@ class BaseForecaster(BaseEstimator):
"X_inner_mtype": "pd.DataFrame", # which types do _fit/_predict, support for X?
"requires-fh-in-fit": True, # is forecasting horizon already required in fit?
"X-y-must-have-same-index": True, # can estimator handle different X/y index?
"enforce-index-type": None, # index type that needs to be enforced in X/y
"enforce_index_type": None, # index type that needs to be enforced in X/y
}

def __init__(self):
Expand Down
2 changes: 1 addition & 1 deletion sktime/forecasting/model_selection/_tune.py
Expand Up @@ -65,7 +65,7 @@ def __init__(
"y_inner_mtype",
"X_inner_mtype",
"X-y-must-have-same-index",
"enforce-index-type",
"enforce_index_type",
]

self.clone_tags(forecaster, tags_to_clone)
Expand Down
2 changes: 1 addition & 1 deletion sktime/registry/_tags.py
Expand Up @@ -94,7 +94,7 @@
"do X/y in fit/update and X/fh in predict have to be same indices?",
),
(
"enforce-index-type",
"enforce_index_type",
["forecaster", "classifier", "regressor"],
"type",
"passed to input checks, input conversion index type to enforce",
Expand Down
35 changes: 25 additions & 10 deletions sktime/transformations/base.py
@@ -1,7 +1,7 @@
#!/usr/bin/env python3 -u
# -*- coding: utf-8 -*-
"""Base class template for transformers."""

__author__ = ["Markus Löning"]
__author__ = ["mloning"]
__all__ = [
"BaseTransformer",
"_SeriesToPrimitivesTransformer",
Expand Down Expand Up @@ -34,14 +34,25 @@


class BaseTransformer(BaseEstimator):
"""Transformer base class"""
"""Transformer base class."""

# default tag values - these typically make the "safest" assumption
_tags = {
"univariate-only": False, # can the transformer handle multivariate X?
"handles-missing-data": False, # can estimator handle missing data?
"X-y-must-have-same-index": False, # can estimator handle different X/y index?
"enforce_index_type": None, # index type that needs to be enforced in X/y
"fit-in-transform": True, # is fit empty and can be skipped? Yes = True
"transform-returns-same-time-index": False,
# does transform return have the same time index as input X
"skip-inverse-transform": False, # is inverse-transform skipped when called?
}

def __init__(self):
super(BaseTransformer, self).__init__()

def fit(self, Z, X=None):
"""
Fit transformer to X and y.
"""Fit transformer to X and y.
By default, fit is empty. Fittable transformations overwrite fit method.
Expand All @@ -60,7 +71,7 @@ def fit(self, Z, X=None):
return self

def transform(self, Z, X=None):
"""Transform data. Returns a transformed version of X."""
"""Transform data. Returns a transformed version of Z."""
raise NotImplementedError("abstract method")

def fit_transform(self, Z, X=None):
Expand Down Expand Up @@ -98,28 +109,32 @@ def fit_transform(self, Z, X=None):


class _SeriesToPrimitivesTransformer(BaseTransformer):
"""Transformer base class for series to primitive(s) transforms"""
"""Transformer base class for series to primitive(s) transforms."""

def transform(self, Z: Series, X=None) -> Primitives:
"""Transform data. Returns a transformed version of Z."""
raise NotImplementedError("abstract method")


class _SeriesToSeriesTransformer(BaseTransformer):
"""Transformer base class for series to series transforms"""
"""Transformer base class for series to series transforms."""

def transform(self, Z: Series, X=None) -> Series:
"""Transform data. Returns a transformed version of Z."""
raise NotImplementedError("abstract method")


class _PanelToTabularTransformer(BaseTransformer):
"""Transformer base class for panel to tabular transforms"""
"""Transformer base class for panel to tabular transforms."""

def transform(self, X: Panel, y=None) -> Tabular:
"""Transform data. Returns a transformed version of X."""
raise NotImplementedError("abstract method")


class _PanelToPanelTransformer(BaseTransformer):
"""Transformer base class for panel to panel transforms"""
"""Transformer base class for panel to panel transforms."""

def transform(self, X: Panel, y=None) -> Panel:
"""Transform data. Returns a transformed version of X."""
raise NotImplementedError("abstract method")

0 comments on commit f869fa4

Please sign in to comment.