Skip to content

Commit

Permalink
Merge branch 'distance_refactor' of https://github.com/alan-turing-in…
Browse files Browse the repository at this point in the history
…stitute/sktime into distance_refactor
  • Loading branch information
Tony Bagnall committed Apr 9, 2022
2 parents f5e81ff + b76807f commit 7deed44
Show file tree
Hide file tree
Showing 8 changed files with 513 additions and 122 deletions.
134 changes: 96 additions & 38 deletions sktime/base/_meta.py
Expand Up @@ -7,6 +7,7 @@
__all__ = ["_HeterogenousMetaEstimator"]

from abc import ABCMeta
from inspect import isclass

from sklearn import clone

Expand Down Expand Up @@ -100,7 +101,35 @@ def _subset_dict_keys(self, dict_to_subset, keys):
subsetted_dict = dict((k, dict_to_subset[k]) for k in keys_in_both)
return subsetted_dict

def _check_estimators(self, estimators, attr_name="steps", cls_type=None):
@staticmethod
def _is_name_and_est(obj, cls_type=None):
"""Check whether obj is a tuple of type (str, cls_type).
Parameters
----------
cls_type : class or tuple of class, optional. Default = BaseEstimator.
class(es) that all estimators are checked to be an instance of
Returns
-------
bool : True if obj is (str, cls_type) tuple, False otherise
"""
if cls_type is None:
cls_type = BaseEstimator
if not isinstance(obj, tuple) or len(obj) != 2:
return False
if not isinstance(obj[0], str) or not isinstance(obj[1], cls_type):
return False
return True

def _check_estimators(
self,
estimators,
attr_name="steps",
cls_type=None,
allow_mix=True,
clone_ests=True,
):
"""Check that estimators is a list of estimators or list of str/est tuples.
Parameters
Expand All @@ -110,8 +139,12 @@ def _check_estimators(self, estimators, attr_name="steps", cls_type=None):
estimators should inherit from cls_type class
attr_name : str, optional. Default = "steps"
Name of checked attribute in error messages
cls_type : class, optional. Default = BaseEstimator.
class that all estimators are checked to be an instance of
cls_type : class or tuple of class, optional. Default = BaseEstimator.
class(es) that all estimators are checked to be an instance of
allow_mix : boolean, optional. Default = True.
whether mix of estimator and (str, estimator) is allowed in `estimators`
clone_ests : boolean, optional. Default = True.
whether estimators in return are cloned (True) or references (False).
Returns
-------
Expand All @@ -129,9 +162,15 @@ class that all estimators are checked to be an instance of
" of estimators, or a list of (string, estimator) tuples. "
)
if cls_type is None:
msg += f"All estimators in '{attr_name}' must be of type BaseEstimator."
cls_type = BaseEstimator
elif isclass(cls_type) or isinstance(cls_type, tuple):
msg += (
f"All estimators in '{attr_name}' must be of type "
f"{cls_type.__name__}."
)
else:
msg += f"All estimators must be of type {cls_type}."
raise TypeError("cls_type must be a class or tuple of classes")

if (
estimators is None
Expand All @@ -140,44 +179,74 @@ class that all estimators are checked to be an instance of
):
raise TypeError(msg)

if not isinstance(estimators[0], (cls_type, tuple)):
def is_est_is_tuple(obj):
"""Check whether obj is estimator of right type, or (str, est) tuple."""
is_est = isinstance(obj, cls_type)
is_tuple = self._is_name_and_est(obj, cls_type)

return is_est, is_tuple

if not all(any(is_est_is_tuple(x)) for x in estimators):
raise TypeError(msg)

if isinstance(estimators[0], cls_type):
if not all(isinstance(est, cls_type) for est in estimators):
raise TypeError(msg)
if isinstance(estimators[0], tuple):
if not all(isinstance(est, tuple) for est in estimators):
raise TypeError(msg)
if not all(isinstance(est[0], str) for est in estimators):
raise TypeError(msg)
if not all(isinstance(est[1], cls_type) for est in estimators):
raise TypeError(msg)
msg_no_mix = (
f"elements of {attr_name} must either all be estimators, "
f"or all (str, estimator) tuples, mix of the two is not allowed"
)

if not allow_mix and not all(is_est_is_tuple(x)[0] for x in estimators):
if not all(is_est_is_tuple(x)[1] for x in estimators):
raise TypeError(msg_no_mix)

return self._get_estimator_tuples(estimators, clone_ests=clone_ests)

return self._get_estimator_tuples(estimators, clone_ests=True)
def _coerce_estimator_tuple(self, obj, clone_est=False):
"""Coerce estimator or (str, estimator) tuple to (str, estimator) tuple.
Parameters
----------
obj : estimator or (str, estimator) tuple
assumes that this has been checked, no checks are performed
clone_est : boolean, optional. Default = False.
Whether to return clone of estimator in obj (True) or a reference (False).
Returns
-------
est_tuple : (str, stimator tuple)
obj if obj was (str, estimator) tuple
(obj class name, obj) if obj was estimator
"""
if isinstance(obj, tuple):
est = obj[1]
name = obj[0]
else:
est = obj
name = type(obj).__name__

if clone_est:
return (name, clone(est))
else:
return (name, est)

def _get_estimator_list(self, estimators):
"""Return list of estimators, from a list or tuple.
Arguments
---------
Parameters
----------
estimators : list of estimators, or list of (str, estimator tuples)
Returns
-------
list of estimators - identical with estimators if list of estimators
if list of (str, estimator) tuples, the str get removed
"""
if isinstance(estimators[0], tuple):
return [x[1] for x in estimators]
else:
return estimators
return [self._coerce_estimator_tuple(x)[1] for x in estimators]

def _get_estimator_names(self, estimators, make_unique=False):
"""Return names for the estimators, optionally made unique.
Arguments
---------
Parameters
----------
estimators : list of estimators, or list of (str, estimator tuples)
make_unique : bool, optional, default=False
whether names should be made unique in the return
Expand All @@ -188,27 +257,16 @@ def _get_estimator_names(self, estimators, make_unique=False):
names for estimators in estimators
if make_unique=True, made unique using _make_strings_unique
"""
if estimators is None or len(estimators) == 0:
names = []
elif isinstance(estimators[0], tuple):
names = [x[0] for x in estimators]
elif isinstance(estimators[0], BaseEstimator):
names = [type(e).__name__ for e in estimators]
else:
raise RuntimeError(
"unreachable condition in _get_estimator_names, "
" likely input assumptions are violated,"
" run _check_estimators before running _get_estimator_names"
)
names = [self._coerce_estimator_tuple(x)[0] for x in estimators]
if make_unique:
names = self._make_strings_unique(names)
return names

def _get_estimator_tuples(self, estimators, clone_ests=False):
"""Return list of estimator tuples, from a list or tuple.
Arguments
---------
Parameters
----------
estimators : list of estimators, or list of (str, estimator tuples)
clone_ests : bool, whether estimators get cloned in the process
Expand Down
62 changes: 62 additions & 0 deletions sktime/forecasting/base/_base.py
Expand Up @@ -108,6 +108,68 @@ def __init__(self):

super(BaseForecaster, self).__init__()

def __mul__(self, other):
"""Magic * method, return (right) concatenated TransformedTargetForecaster.
Implemented for `other` being a transformer, otherwise returns `NotImplemented`.
Parameters
----------
other: `sktime` transformer, must inherit from BaseTransformer
otherwise, `NotImplemented` is returned
Returns
-------
TransformedTargetForecaster object,
concatenation of `self` (first) with `other` (last).
not nested, contains only non-TransformerPipeline `sktime` transformers
"""
from sktime.forecasting.compose import TransformedTargetForecaster
from sktime.transformations.base import BaseTransformer
from sktime.transformations.series.adapt import TabularToSeriesAdaptor
from sktime.utils.sklearn import is_sklearn_transformer

# we wrap self in a pipeline, and concatenate with the other
# the TransformedTargetForecaster does the rest, e.g., dispatch on other
if isinstance(other, BaseTransformer):
self_as_pipeline = TransformedTargetForecaster(steps=[self])
return self_as_pipeline * other
elif is_sklearn_transformer(other):
return self * TabularToSeriesAdaptor(other)
else:
return NotImplemented

def __rmul__(self, other):
"""Magic * method, return (left) concatenated TransformerPipeline.
Implemented for `other` being a transformer, otherwise returns `NotImplemented`.
Parameters
----------
other: `sktime` transformer, must inherit from BaseTransformer
otherwise, `NotImplemented` is returned
Returns
-------
TransformedTargetForecaster object,
concatenation of `other` (first) with `self` (last).
not nested, contains only non-TransformerPipeline `sktime` steps
"""
from sktime.forecasting.compose import TransformedTargetForecaster
from sktime.transformations.base import BaseTransformer
from sktime.transformations.series.adapt import TabularToSeriesAdaptor
from sktime.utils.sklearn import is_sklearn_transformer

# we wrap self in a pipeline, and concatenate with the other
# the TransformedTargetForecaster does the rest, e.g., dispatch on other
if isinstance(other, BaseTransformer):
self_as_pipeline = TransformedTargetForecaster(steps=[self])
return other * self_as_pipeline
elif is_sklearn_transformer(other):
return TabularToSeriesAdaptor(other) * self
else:
return NotImplemented

def fit(self, y, X=None, fh=None):
"""Fit forecaster to training data.
Expand Down

0 comments on commit 7deed44

Please sign in to comment.