Skip to content

Commit

Permalink
hier_fix (#4061)
Browse files Browse the repository at this point in the history
Roll out of panel datatype check changes to hierarchical datatype
  • Loading branch information
danbartl committed Jan 6, 2023
1 parent c5aa78f commit 4a0fbfe
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 53 deletions.
56 changes: 5 additions & 51 deletions sktime/datatypes/_hierarchical/_check.py
Expand Up @@ -43,9 +43,8 @@
__all__ = ["check_dict"]

import numpy as np
import pandas as pd

from sktime.datatypes._series._check import check_pddataframe_series
from sktime.datatypes._panel._check import check_pdmultiindex_panel


def _list_all_equal(obj):
Expand Down Expand Up @@ -77,56 +76,11 @@ def _ret(valid, msg, metadata, return_metadata):

def check_pdmultiindex_hierarchical(obj, return_metadata=False, var_name="obj"):

if not isinstance(obj, pd.DataFrame):
msg = f"{var_name} must be a pd.DataFrame, found {type(obj)}"
return _ret(False, msg, None, return_metadata)

if not isinstance(obj.index, pd.MultiIndex):
msg = f"{var_name} must have a MultiIndex, found {type(obj.index)}"
return _ret(False, msg, None, return_metadata)

# check that columns are unique
if not obj.columns.is_unique:
msg = f"{var_name} must have unique column indices, but found {obj.columns}"
return _ret(False, msg, None, return_metadata)

# check that there are 3 or more index levels
nlevels = obj.index.nlevels
if not nlevels > 2:
msg = (
f"{var_name} must have a MultiIndex with 3 or more levels, found {nlevels}"
)
return _ret(False, msg, None, return_metadata)

inst_inds = obj.index.droplevel(-1).unique()
panel_inds = inst_inds.droplevel(-1).unique()

check_res = [
check_pddataframe_series(obj.loc[i], return_metadata=True) for i in inst_inds
]
bad_inds = [i[1] for i in enumerate(inst_inds) if not check_res[i[0]][0]]

if len(bad_inds) > 0:
msg = (
f"{var_name}.loc[i] must be Series of mtype pd.DataFrame,"
f" not at i={bad_inds}"
)
return _ret(False, msg, None, return_metadata)

metadata = dict()
metadata["is_univariate"] = np.all([res[2]["is_univariate"] for res in check_res])
metadata["is_equally_spaced"] = np.all(
[res[2]["is_equally_spaced"] for res in check_res]
ret = check_pdmultiindex_panel(
obj, return_metadata=return_metadata, var_name=var_name, panel=False
)
metadata["is_empty"] = np.any([res[2]["is_empty"] for res in check_res])
metadata["n_instances"] = len(inst_inds)
metadata["n_panels"] = len(panel_inds)
metadata["is_one_series"] = len(inst_inds) == 1
metadata["is_one_panel"] = len(panel_inds) == 1
metadata["has_nans"] = obj.isna().values.any()
metadata["is_equal_length"] = _list_all_equal([len(obj.loc[i]) for i in inst_inds])

return _ret(True, None, metadata, return_metadata)

return ret


check_dict[("pd_multiindex_hier", "Hierarchical")] = check_pdmultiindex_hierarchical
25 changes: 23 additions & 2 deletions sktime/datatypes/_panel/_check.py
Expand Up @@ -115,6 +115,9 @@ def check_dflist_panel(obj, return_metadata=False, var_name="obj"):
metadata["is_empty"] = np.any([res[2]["is_empty"] for res in check_res])
metadata["has_nans"] = np.any([res[2]["has_nans"] for res in check_res])
metadata["is_one_series"] = n == 1
metadata["n_panels"] = 1
metadata["is_one_panel"] = True

metadata["n_instances"] = n

return _ret(True, None, metadata, return_metadata)
Expand Down Expand Up @@ -143,6 +146,8 @@ def check_numpy3d_panel(obj, return_metadata=False, var_name="obj"):

metadata["n_instances"] = obj.shape[0]
metadata["is_one_series"] = obj.shape[0] == 1
metadata["n_panels"] = 1
metadata["is_one_panel"] = True

# check whether there any nans; only if requested
if return_metadata:
Expand All @@ -154,7 +159,7 @@ def check_numpy3d_panel(obj, return_metadata=False, var_name="obj"):
check_dict[("numpy3D", "Panel")] = check_numpy3d_panel


def check_pdmultiindex_panel(obj, return_metadata=False, var_name="obj"):
def check_pdmultiindex_panel(obj, return_metadata=False, var_name="obj", panel=True):

if not isinstance(obj, pd.DataFrame):
msg = f"{var_name} must be a pd.DataFrame, found {type(obj)}"
Expand All @@ -172,9 +177,14 @@ def check_pdmultiindex_panel(obj, return_metadata=False, var_name="obj"):

# check that there are precisely two index levels
nlevels = obj.index.nlevels
if not nlevels == 2:
if panel is True and not nlevels == 2:
msg = f"{var_name} must have a MultiIndex with 2 levels, found {nlevels}"
return _ret(False, msg, None, return_metadata)
elif panel is False and not nlevels > 2:
msg = (
f"{var_name} must have a MultiIndex with 3 or more levels, found {nlevels}"
)
return _ret(False, msg, None, return_metadata)

# check that no dtype is object
if "object" in obj.dtypes.values:
Expand Down Expand Up @@ -225,10 +235,17 @@ def check_pdmultiindex_panel(obj, return_metadata=False, var_name="obj"):
)
return _ret(False, msg, None, return_metadata)

if panel is True:
panel_inds = [1]
else:
panel_inds = inst_inds.droplevel(-1).unique()

metadata = dict()
metadata["is_univariate"] = len(obj.columns) < 2
metadata["is_equally_spaced"] = is_equally_spaced
metadata["is_empty"] = len(obj.index) < 1 or len(obj.columns) < 1
metadata["n_panels"] = len(panel_inds)
metadata["is_one_panel"] = len(panel_inds) == 1
metadata["n_instances"] = len(inst_inds)
metadata["is_one_series"] = len(inst_inds) == 1
metadata["has_nans"] = obj.isna().values.any()
Expand Down Expand Up @@ -360,6 +377,8 @@ def is_nested_dataframe(obj, return_metadata=False, var_name="obj"):
metadata["is_univariate"] = obj.shape[1] < 2
metadata["n_instances"] = len(obj)
metadata["is_one_series"] = len(obj) == 1
metadata["n_panels"] = 1
metadata["is_one_panel"] = True
if return_metadata:
metadata["has_nans"] = _nested_dataframe_has_nans(obj)
metadata["is_equal_length"] = not _nested_dataframe_has_unequal(obj)
Expand Down Expand Up @@ -394,6 +413,8 @@ def check_numpyflat_Panel(obj, return_metadata=False, var_name="obj"):
metadata["is_equal_length"] = True
metadata["n_instances"] = obj.shape[0]
metadata["is_one_series"] = obj.shape[0] == 1
metadata["n_panels"] = 1
metadata["is_one_panel"] = True

# check whether there any nans; only if requested
if return_metadata:
Expand Down
6 changes: 6 additions & 0 deletions sktime/datatypes/_panel/_examples.py
Expand Up @@ -82,6 +82,8 @@
example_dict_metadata[("Panel", 0)] = {
"is_univariate": False,
"is_one_series": False,
"n_panels": 1,
"is_one_panel": True,
"is_equally_spaced": True,
"is_equal_length": True,
"is_empty": False,
Expand Down Expand Up @@ -140,6 +142,8 @@
example_dict_metadata[("Panel", 1)] = {
"is_univariate": True,
"is_one_series": False,
"n_panels": 1,
"is_one_panel": True,
"is_equally_spaced": True,
"is_equal_length": True,
"is_empty": False,
Expand Down Expand Up @@ -192,6 +196,8 @@
example_dict_metadata[("Panel", 2)] = {
"is_univariate": True,
"is_one_series": True,
"n_panels": 1,
"is_one_panel": True,
"is_equally_spaced": True,
"is_equal_length": True,
"is_empty": False,
Expand Down

0 comments on commit 4a0fbfe

Please sign in to comment.