Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

[ENH] Speed up hierarchical checks and unify with panel approach #4061

Merged
merged 1 commit into from Jan 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
56 changes: 5 additions & 51 deletions sktime/datatypes/_hierarchical/_check.py
Expand Up @@ -43,9 +43,8 @@
__all__ = ["check_dict"]

import numpy as np
import pandas as pd

from sktime.datatypes._series._check import check_pddataframe_series
from sktime.datatypes._panel._check import check_pdmultiindex_panel


def _list_all_equal(obj):
Expand Down Expand Up @@ -77,56 +76,11 @@ def _ret(valid, msg, metadata, return_metadata):

def check_pdmultiindex_hierarchical(obj, return_metadata=False, var_name="obj"):

if not isinstance(obj, pd.DataFrame):
msg = f"{var_name} must be a pd.DataFrame, found {type(obj)}"
return _ret(False, msg, None, return_metadata)

if not isinstance(obj.index, pd.MultiIndex):
msg = f"{var_name} must have a MultiIndex, found {type(obj.index)}"
return _ret(False, msg, None, return_metadata)

# check that columns are unique
if not obj.columns.is_unique:
msg = f"{var_name} must have unique column indices, but found {obj.columns}"
return _ret(False, msg, None, return_metadata)

# check that there are 3 or more index levels
nlevels = obj.index.nlevels
if not nlevels > 2:
msg = (
f"{var_name} must have a MultiIndex with 3 or more levels, found {nlevels}"
)
return _ret(False, msg, None, return_metadata)

inst_inds = obj.index.droplevel(-1).unique()
panel_inds = inst_inds.droplevel(-1).unique()

check_res = [
check_pddataframe_series(obj.loc[i], return_metadata=True) for i in inst_inds
]
bad_inds = [i[1] for i in enumerate(inst_inds) if not check_res[i[0]][0]]

if len(bad_inds) > 0:
msg = (
f"{var_name}.loc[i] must be Series of mtype pd.DataFrame,"
f" not at i={bad_inds}"
)
return _ret(False, msg, None, return_metadata)

metadata = dict()
metadata["is_univariate"] = np.all([res[2]["is_univariate"] for res in check_res])
metadata["is_equally_spaced"] = np.all(
[res[2]["is_equally_spaced"] for res in check_res]
ret = check_pdmultiindex_panel(
obj, return_metadata=return_metadata, var_name=var_name, panel=False
)
metadata["is_empty"] = np.any([res[2]["is_empty"] for res in check_res])
metadata["n_instances"] = len(inst_inds)
metadata["n_panels"] = len(panel_inds)
metadata["is_one_series"] = len(inst_inds) == 1
metadata["is_one_panel"] = len(panel_inds) == 1
metadata["has_nans"] = obj.isna().values.any()
metadata["is_equal_length"] = _list_all_equal([len(obj.loc[i]) for i in inst_inds])

return _ret(True, None, metadata, return_metadata)

return ret


check_dict[("pd_multiindex_hier", "Hierarchical")] = check_pdmultiindex_hierarchical
25 changes: 23 additions & 2 deletions sktime/datatypes/_panel/_check.py
Expand Up @@ -115,6 +115,9 @@ def check_dflist_panel(obj, return_metadata=False, var_name="obj"):
metadata["is_empty"] = np.any([res[2]["is_empty"] for res in check_res])
metadata["has_nans"] = np.any([res[2]["has_nans"] for res in check_res])
metadata["is_one_series"] = n == 1
metadata["n_panels"] = 1
metadata["is_one_panel"] = True

metadata["n_instances"] = n

return _ret(True, None, metadata, return_metadata)
Expand Down Expand Up @@ -143,6 +146,8 @@ def check_numpy3d_panel(obj, return_metadata=False, var_name="obj"):

metadata["n_instances"] = obj.shape[0]
metadata["is_one_series"] = obj.shape[0] == 1
metadata["n_panels"] = 1
metadata["is_one_panel"] = True

# check whether there are any nans; only if requested
if return_metadata:
Expand All @@ -154,7 +159,7 @@ def check_numpy3d_panel(obj, return_metadata=False, var_name="obj"):
check_dict[("numpy3D", "Panel")] = check_numpy3d_panel


def check_pdmultiindex_panel(obj, return_metadata=False, var_name="obj"):
def check_pdmultiindex_panel(obj, return_metadata=False, var_name="obj", panel=True):

if not isinstance(obj, pd.DataFrame):
msg = f"{var_name} must be a pd.DataFrame, found {type(obj)}"
Expand All @@ -172,9 +177,14 @@ def check_pdmultiindex_panel(obj, return_metadata=False, var_name="obj"):

# check the number of index levels: exactly 2 for panel, 3 or more for hierarchical
nlevels = obj.index.nlevels
if not nlevels == 2:
if panel is True and not nlevels == 2:
msg = f"{var_name} must have a MultiIndex with 2 levels, found {nlevels}"
return _ret(False, msg, None, return_metadata)
elif panel is False and not nlevels > 2:
msg = (
f"{var_name} must have a MultiIndex with 3 or more levels, found {nlevels}"
)
return _ret(False, msg, None, return_metadata)

# check that no dtype is object
if "object" in obj.dtypes.values:
Expand Down Expand Up @@ -225,10 +235,17 @@ def check_pdmultiindex_panel(obj, return_metadata=False, var_name="obj"):
)
return _ret(False, msg, None, return_metadata)

if panel is True:
panel_inds = [1]
else:
panel_inds = inst_inds.droplevel(-1).unique()

metadata = dict()
metadata["is_univariate"] = len(obj.columns) < 2
metadata["is_equally_spaced"] = is_equally_spaced
metadata["is_empty"] = len(obj.index) < 1 or len(obj.columns) < 1
metadata["n_panels"] = len(panel_inds)
metadata["is_one_panel"] = len(panel_inds) == 1
metadata["n_instances"] = len(inst_inds)
metadata["is_one_series"] = len(inst_inds) == 1
metadata["has_nans"] = obj.isna().values.any()
Expand Down Expand Up @@ -360,6 +377,8 @@ def is_nested_dataframe(obj, return_metadata=False, var_name="obj"):
metadata["is_univariate"] = obj.shape[1] < 2
metadata["n_instances"] = len(obj)
metadata["is_one_series"] = len(obj) == 1
metadata["n_panels"] = 1
metadata["is_one_panel"] = True
if return_metadata:
metadata["has_nans"] = _nested_dataframe_has_nans(obj)
metadata["is_equal_length"] = not _nested_dataframe_has_unequal(obj)
Expand Down Expand Up @@ -394,6 +413,8 @@ def check_numpyflat_Panel(obj, return_metadata=False, var_name="obj"):
metadata["is_equal_length"] = True
metadata["n_instances"] = obj.shape[0]
metadata["is_one_series"] = obj.shape[0] == 1
metadata["n_panels"] = 1
metadata["is_one_panel"] = True

# check whether there are any nans; only if requested
if return_metadata:
Expand Down
6 changes: 6 additions & 0 deletions sktime/datatypes/_panel/_examples.py
Expand Up @@ -82,6 +82,8 @@
example_dict_metadata[("Panel", 0)] = {
"is_univariate": False,
"is_one_series": False,
"n_panels": 1,
"is_one_panel": True,
"is_equally_spaced": True,
"is_equal_length": True,
"is_empty": False,
Expand Down Expand Up @@ -140,6 +142,8 @@
example_dict_metadata[("Panel", 1)] = {
"is_univariate": True,
"is_one_series": False,
"n_panels": 1,
"is_one_panel": True,
"is_equally_spaced": True,
"is_equal_length": True,
"is_empty": False,
Expand Down Expand Up @@ -192,6 +196,8 @@
example_dict_metadata[("Panel", 2)] = {
"is_univariate": True,
"is_one_series": True,
"n_panels": 1,
"is_one_panel": True,
"is_equally_spaced": True,
"is_equal_length": True,
"is_empty": False,
Expand Down