diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md index 8be89c6cfc4..508ec61c386 100644 --- a/CONTRIBUTORS.md +++ b/CONTRIBUTORS.md @@ -205,7 +205,7 @@ Thanks goes to these wonderful people ([emoji key](https://allcontributors.org/d Pulkit Verma
Pulkit Verma

📖 Quaterion
Quaterion

🐛 Rakshitha Godahewa
Rakshitha Godahewa

💻 📖 - Ramon Bussing
Ramon Bussing

📖 + Ramon Bussing
Ramon Bussing

📖 💻 RavenRudi
RavenRudi

💻 Rick van Hattem
Rick van Hattem

🚇 Rishabh Bali
Rishabh Bali

💻 diff --git a/sktime/datatypes/_convert.py b/sktime/datatypes/_convert.py index 785185d8b5e..d5bde4c2b02 100644 --- a/sktime/datatypes/_convert.py +++ b/sktime/datatypes/_convert.py @@ -93,6 +93,7 @@ def convert( as_scitype: str = None, store=None, store_behaviour: str = None, + return_to_mtype: bool = False, ): """Convert objects between different machine representations, subject to scitype. @@ -114,11 +115,15 @@ def convert( "freeze" - store is read-only, may be read/used by conversion but not changed "update" - store is updated from conversion and retains previous contents None - automatic: "update" if store is empty and not None; "freeze", otherwise + return_to_mtype: bool, optional (default=False) + if True, also returns the str of the mtype converted to Returns ------- - converted_obj : to_type - object obj converted to to_type - if obj was None, returns None + converted_obj : to_type - object ``obj`` converted to mtype ``to_type`` + if ``obj`` was ``None``, is ``None`` + to_type : str, only returned if ``return_to_mtype=True`` + mtype of ``converted_obj`` - useful of ``to_type`` was a list Raises ------ @@ -176,7 +181,10 @@ def convert( converted_obj = convert_dict[key](obj, store=store) - return converted_obj + if return_to_mtype: + return converted_obj, to_type + else: + return converted_obj # conversion based on queriable type to specified target @@ -186,6 +194,7 @@ def convert_to( as_scitype: str = None, store=None, store_behaviour: str = None, + return_to_mtype: bool = False, ): """Convert object to a different machine representation, subject to scitype. @@ -207,6 +216,8 @@ def convert_to( "freeze" - store is read-only, may be read/used by conversion but not changed "update" - store is updated from conversion and retains previous contents None - automatic: "update" if store is empty and not None; "freeze", otherwise + return_to_mtype: bool, optional (default=False) + if True, also returns the str of the mtype converted to Returns ------- @@ -219,6 +230,8 @@ def convert_to( converted_obj is converted to the first mtype in to_type that is of same scitype as obj case 4: if obj was None, converted_obj is also None + to_type : str, only returned if ``return_to_mtype=True`` + mtype of ``converted_obj`` - useful of ``to_type`` was a list Raises ------ @@ -254,6 +267,7 @@ def convert_to( as_scitype=as_scitype, store=store, store_behaviour=store_behaviour, + return_to_mtype=return_to_mtype, ) return converted_obj @@ -272,8 +286,7 @@ def _get_first_mtype_of_same_scitype(from_mtype, to_mtypes, varname="to_mtypes") ------- to_type : str - first mtype in to_mtypes that has same scitype as from_mtype """ - if isinstance(to_mtypes, str): - return to_mtypes + to_mtypes = _check_str_or_list_of_str(to_mtypes, obj_name=varname) if not isinstance(to_mtypes, list): raise TypeError(f"{varname} must be a str or a list of str") diff --git a/sktime/forecasting/model_selection/_split.py b/sktime/forecasting/model_selection/_split.py index 46010cf0a31..9b050cef5d5 100644 --- a/sktime/forecasting/model_selection/_split.py +++ b/sktime/forecasting/model_selection/_split.py @@ -28,7 +28,7 @@ from sklearn.model_selection import train_test_split as _train_test_split from sktime.base import BaseObject -from sktime.datatypes import check_is_scitype, convert_to +from sktime.datatypes import check_is_scitype, convert from sktime.datatypes._utilities import get_index_for_series, get_time_index, get_window from sktime.forecasting.base import ForecastingHorizon from sktime.forecasting.base._fh import VALID_FORECASTING_HORIZON_TYPES @@ -507,7 +507,7 @@ def split_series(self, y: ACCEPTED_Y_TYPES) -> Iterator[SPLIT_TYPE]: test : time series of same sktime mtype as `y` test series in the split """ - y, y_orig_mtype = self._check_y(y) + y_inner, y_orig_mtype, y_inner_mtype = self._check_y(y) use_iloc_or_loc = self.get_tag("split_series_uses", "iloc", raise_error=False) @@ -523,14 +523,14 @@ def split_series(self, y: ACCEPTED_Y_TYPES) -> Iterator[SPLIT_TYPE]: ) _split = getattr(self, splitter_name) - _slicer = getattr(y, use_iloc_or_loc) + _slicer = getattr(y_inner, use_iloc_or_loc) - for train, test in _split(y.index): + for train, test in _split(y_inner.index): y_train = _slicer[train] y_test = _slicer[test] - y_train = convert_to(y_train, y_orig_mtype) - y_test = convert_to(y_test, y_orig_mtype) + y_train = convert(y_train, from_type=y_inner_mtype, to_type=y_orig_mtype) + y_test = convert(y_test, from_type=y_inner_mtype, to_type=y_orig_mtype) yield y_train, y_test def _coerce_to_index(self, y: ACCEPTED_Y_TYPES) -> pd.Index: @@ -548,7 +548,7 @@ def _coerce_to_index(self, y: ACCEPTED_Y_TYPES) -> pd.Index: y_index : y, if y was pd.Index; otherwise _check_y(y).index """ if not isinstance(y, pd.Index): - y, _ = self._check_y(y, allow_index=True) + y = self._check_y(y, allow_index=True)[0] y_index = y.index else: y_index = y @@ -564,10 +564,14 @@ def _check_y(self, y, allow_index=False): Returns ------- - y_inner : time series y coerced to one of the sktime pandas based mtypes: + y_inner : pd.DataFrame or pd.Series, sktime time series data container + time series y coerced to one of the sktime pandas based mtypes: pd.DataFrame, pd.Series, pd-multiindex, pd_multiindex_hier returns pd.Series only if y was pd.Series, otherwise a pandas.DataFrame - y_mtype : original mtype of y + y_mtype : str, sktime mtype string + original mtype of y (the input) + y_inner_mtype : str, sktime mtype string + mtype of y_inner (the output) Raises ------ @@ -591,7 +595,7 @@ def _check_y(self, y, allow_index=False): "pd_multiindex_hier", ] y_valid, _, y_metadata = check_is_scitype( - y, scitype=ALLOWED_SCITYPES, return_metadata=True, var_name="y" + y, scitype=ALLOWED_SCITYPES, return_metadata=[], var_name="y" ) if allow_index: msg = ( @@ -622,11 +626,16 @@ def _check_y(self, y, allow_index=False): if not y_valid: raise TypeError(msg) - y_inner = convert_to(y, to_type=PANDAS_MTYPES) + y_mtype = y_metadata["mtype"] - mtype = y_metadata["mtype"] + y_inner, y_inner_mtype = convert( + y, + from_type=y_mtype, + to_type=PANDAS_MTYPES, + return_to_mtype=True, + ) - return y_inner, mtype + return y_inner, y_mtype, y_inner_mtype def get_n_splits(self, y: Optional[ACCEPTED_Y_TYPES] = None) -> int: """Return the number of splits.