Skip to content

Commit

Permalink
[ENH] speed up BaseSplitter boilerplate (#5063)
Browse files Browse the repository at this point in the history
This PR speeds up the checks and conversions boilerplate in
`BaseSplitter`.

* avoids unnecessary checks to detect mtype, by replacing `convert_to`
by a more specific `convert`
* prevents unnecessary metadata computation in `check_is_scitype` by
adjusting the requested metadata argument to none

To effect this, adds an argument and logic to `convert` and `convert_to`
that allows to return the mtype of the object that was returned - this
is not immediate from the arguments, as the `to_type` can be an
allow-list.

Related, and depends on (for changes to `convert`):
#5036
  • Loading branch information
fkiraly committed Aug 19, 2023
1 parent 8589fca commit cd9c587
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 19 deletions.
2 changes: 1 addition & 1 deletion CONTRIBUTORS.md
Expand Up @@ -205,7 +205,7 @@ Thanks goes to these wonderful people ([emoji key](https://allcontributors.org/d
<td align="center" valign="top" width="11.11%"><a href="https://github.com/pul95"><img src="https://avatars.githubusercontent.com/pul95?s=100" width="100px;" alt="Pulkit Verma"/><br /><sub><b>Pulkit Verma</b></sub></a><br /><a href="https://github.com/sktime/sktime/commits?author=pul95" title="Documentation">📖</a></td>
<td align="center" valign="top" width="11.11%"><a href="https://github.com/Quaterion"><img src="https://avatars2.githubusercontent.com/u/23200273?v=4?s=100" width="100px;" alt="Quaterion"/><br /><sub><b>Quaterion</b></sub></a><br /><a href="https://github.com/sktime/sktime/issues?q=author%3AQuaterion" title="Bug reports">🐛</a></td>
<td align="center" valign="top" width="11.11%"><a href="https://github.com/rakshitha123"><img src="https://avatars.githubusercontent.com/u/7654679?v=4?s=100" width="100px;" alt="Rakshitha Godahewa"/><br /><sub><b>Rakshitha Godahewa</b></sub></a><br /><a href="https://github.com/sktime/sktime/commits?author=rakshitha123" title="Code">💻</a> <a href="https://github.com/sktime/sktime/commits?author=rakshitha123" title="Documentation">📖</a></td>
<td align="center" valign="top" width="11.11%"><a href="https://github.com/Ram0nB"><img src="https://avatars.githubusercontent.com/u/45173421?s=100" width="100px;" alt="Ramon Bussing"/><br /><sub><b>Ramon Bussing</b></sub></a><br /><a href="https://github.com/sktime/sktime/commits?author=Ram0nB" title="Documentation">📖</a></td>
<td align="center" valign="top" width="11.11%"><a href="https://github.com/Ram0nB"><img src="https://avatars.githubusercontent.com/u/45173421?s=100" width="100px;" alt="Ramon Bussing"/><br /><sub><b>Ramon Bussing</b></sub></a><br /><a href="https://github.com/sktime/sktime/commits?author=Ram0nB" title="Documentation">📖</a> <a href="https://github.com/sktime/sktime/commits?author=Ram0nB" title="Code">💻</a></td>
<td align="center" valign="top" width="11.11%"><a href="https://github.com/RavenRudi"><img src="https://avatars.githubusercontent.com/u/46402968?v=4?s=100" width="100px;" alt="RavenRudi"/><br /><sub><b>RavenRudi</b></sub></a><br /><a href="https://github.com/sktime/sktime/commits?author=RavenRudi" title="Code">💻</a></td>
<td align="center" valign="top" width="11.11%"><a href="https://github.com/wolph"><img src="?s=100" width="100px;" alt="Rick van Hattem"/><br /><sub><b>Rick van Hattem</b></sub></a><br /><a href="#infra-wolph" title="Infrastructure (Hosting, Build-Tools, etc)">🚇</a></td>
<td align="center" valign="top" width="11.11%"><a href="https://github.com/Ris-Bali"><img src="https://avatars.githubusercontent.com/u/81592570?v=4?s=100" width="100px;" alt="Rishabh Bali"/><br /><sub><b>Rishabh Bali</b></sub></a><br /><a href="https://github.com/sktime/sktime/commits?author=Ris-Bali" title="Code">💻</a></td>
Expand Down
23 changes: 18 additions & 5 deletions sktime/datatypes/_convert.py
Expand Up @@ -93,6 +93,7 @@ def convert(
as_scitype: str = None,
store=None,
store_behaviour: str = None,
return_to_mtype: bool = False,
):
"""Convert objects between different machine representations, subject to scitype.
Expand All @@ -114,11 +115,15 @@ def convert(
"freeze" - store is read-only, may be read/used by conversion but not changed
"update" - store is updated from conversion and retains previous contents
None - automatic: "update" if store is empty and not None; "freeze", otherwise
return_to_mtype: bool, optional (default=False)
if True, also returns the str of the mtype converted to
Returns
-------
converted_obj : to_type - object obj converted to to_type
if obj was None, returns None
converted_obj : to_type - object ``obj`` converted to mtype ``to_type``
if ``obj`` was ``None``, is ``None``
to_type : str, only returned if ``return_to_mtype=True``
mtype of ``converted_obj`` - useful of ``to_type`` was a list
Raises
------
Expand Down Expand Up @@ -176,7 +181,10 @@ def convert(

converted_obj = convert_dict[key](obj, store=store)

return converted_obj
if return_to_mtype:
return converted_obj, to_type
else:
return converted_obj


# conversion based on queriable type to specified target
Expand All @@ -186,6 +194,7 @@ def convert_to(
as_scitype: str = None,
store=None,
store_behaviour: str = None,
return_to_mtype: bool = False,
):
"""Convert object to a different machine representation, subject to scitype.
Expand All @@ -207,6 +216,8 @@ def convert_to(
"freeze" - store is read-only, may be read/used by conversion but not changed
"update" - store is updated from conversion and retains previous contents
None - automatic: "update" if store is empty and not None; "freeze", otherwise
return_to_mtype: bool, optional (default=False)
if True, also returns the str of the mtype converted to
Returns
-------
Expand All @@ -219,6 +230,8 @@ def convert_to(
converted_obj is converted to the first mtype in to_type
that is of same scitype as obj
case 4: if obj was None, converted_obj is also None
to_type : str, only returned if ``return_to_mtype=True``
mtype of ``converted_obj`` - useful of ``to_type`` was a list
Raises
------
Expand Down Expand Up @@ -254,6 +267,7 @@ def convert_to(
as_scitype=as_scitype,
store=store,
store_behaviour=store_behaviour,
return_to_mtype=return_to_mtype,
)

return converted_obj
Expand All @@ -272,8 +286,7 @@ def _get_first_mtype_of_same_scitype(from_mtype, to_mtypes, varname="to_mtypes")
-------
to_type : str - first mtype in to_mtypes that has same scitype as from_mtype
"""
if isinstance(to_mtypes, str):
return to_mtypes
to_mtypes = _check_str_or_list_of_str(to_mtypes, obj_name=varname)

if not isinstance(to_mtypes, list):
raise TypeError(f"{varname} must be a str or a list of str")
Expand Down
35 changes: 22 additions & 13 deletions sktime/forecasting/model_selection/_split.py
Expand Up @@ -28,7 +28,7 @@
from sklearn.model_selection import train_test_split as _train_test_split

from sktime.base import BaseObject
from sktime.datatypes import check_is_scitype, convert_to
from sktime.datatypes import check_is_scitype, convert
from sktime.datatypes._utilities import get_index_for_series, get_time_index, get_window
from sktime.forecasting.base import ForecastingHorizon
from sktime.forecasting.base._fh import VALID_FORECASTING_HORIZON_TYPES
Expand Down Expand Up @@ -507,7 +507,7 @@ def split_series(self, y: ACCEPTED_Y_TYPES) -> Iterator[SPLIT_TYPE]:
test : time series of same sktime mtype as `y`
test series in the split
"""
y, y_orig_mtype = self._check_y(y)
y_inner, y_orig_mtype, y_inner_mtype = self._check_y(y)

use_iloc_or_loc = self.get_tag("split_series_uses", "iloc", raise_error=False)

Expand All @@ -523,14 +523,14 @@ def split_series(self, y: ACCEPTED_Y_TYPES) -> Iterator[SPLIT_TYPE]:
)

_split = getattr(self, splitter_name)
_slicer = getattr(y, use_iloc_or_loc)
_slicer = getattr(y_inner, use_iloc_or_loc)

for train, test in _split(y.index):
for train, test in _split(y_inner.index):
y_train = _slicer[train]
y_test = _slicer[test]

y_train = convert_to(y_train, y_orig_mtype)
y_test = convert_to(y_test, y_orig_mtype)
y_train = convert(y_train, from_type=y_inner_mtype, to_type=y_orig_mtype)
y_test = convert(y_test, from_type=y_inner_mtype, to_type=y_orig_mtype)
yield y_train, y_test

def _coerce_to_index(self, y: ACCEPTED_Y_TYPES) -> pd.Index:
Expand All @@ -548,7 +548,7 @@ def _coerce_to_index(self, y: ACCEPTED_Y_TYPES) -> pd.Index:
y_index : y, if y was pd.Index; otherwise _check_y(y).index
"""
if not isinstance(y, pd.Index):
y, _ = self._check_y(y, allow_index=True)
y = self._check_y(y, allow_index=True)[0]
y_index = y.index
else:
y_index = y
Expand All @@ -564,10 +564,14 @@ def _check_y(self, y, allow_index=False):
Returns
-------
y_inner : time series y coerced to one of the sktime pandas based mtypes:
y_inner : pd.DataFrame or pd.Series, sktime time series data container
time series y coerced to one of the sktime pandas based mtypes:
pd.DataFrame, pd.Series, pd-multiindex, pd_multiindex_hier
returns pd.Series only if y was pd.Series, otherwise a pandas.DataFrame
y_mtype : original mtype of y
y_mtype : str, sktime mtype string
original mtype of y (the input)
y_inner_mtype : str, sktime mtype string
mtype of y_inner (the output)
Raises
------
Expand All @@ -591,7 +595,7 @@ def _check_y(self, y, allow_index=False):
"pd_multiindex_hier",
]
y_valid, _, y_metadata = check_is_scitype(
y, scitype=ALLOWED_SCITYPES, return_metadata=True, var_name="y"
y, scitype=ALLOWED_SCITYPES, return_metadata=[], var_name="y"
)
if allow_index:
msg = (
Expand Down Expand Up @@ -622,11 +626,16 @@ def _check_y(self, y, allow_index=False):
if not y_valid:
raise TypeError(msg)

y_inner = convert_to(y, to_type=PANDAS_MTYPES)
y_mtype = y_metadata["mtype"]

mtype = y_metadata["mtype"]
y_inner, y_inner_mtype = convert(
y,
from_type=y_mtype,
to_type=PANDAS_MTYPES,
return_to_mtype=True,
)

return y_inner, mtype
return y_inner, y_mtype, y_inner_mtype

def get_n_splits(self, y: Optional[ACCEPTED_Y_TYPES] = None) -> int:
"""Return the number of splits.
Expand Down

0 comments on commit cd9c587

Please sign in to comment.