From 1166a752dd6956cdfa53459e1dc69e23455b62cc Mon Sep 17 00:00:00 2001 From: Danbartl Date: Wed, 4 Jan 2023 14:17:40 +0100 Subject: [PATCH] hier_fix --- sktime/datatypes/_hierarchical/_check.py | 56 +++--------------------- sktime/datatypes/_panel/_check.py | 25 ++++++++++- sktime/datatypes/_panel/_examples.py | 6 +++ 3 files changed, 34 insertions(+), 53 deletions(-) diff --git a/sktime/datatypes/_hierarchical/_check.py b/sktime/datatypes/_hierarchical/_check.py index 7312b8fbf00..3be4fe2a349 100644 --- a/sktime/datatypes/_hierarchical/_check.py +++ b/sktime/datatypes/_hierarchical/_check.py @@ -43,9 +43,8 @@ __all__ = ["check_dict"] import numpy as np -import pandas as pd -from sktime.datatypes._series._check import check_pddataframe_series +from sktime.datatypes._panel._check import check_pdmultiindex_panel def _list_all_equal(obj): @@ -77,56 +76,11 @@ def _ret(valid, msg, metadata, return_metadata): def check_pdmultiindex_hierarchical(obj, return_metadata=False, var_name="obj"): - if not isinstance(obj, pd.DataFrame): - msg = f"{var_name} must be a pd.DataFrame, found {type(obj)}" - return _ret(False, msg, None, return_metadata) - - if not isinstance(obj.index, pd.MultiIndex): - msg = f"{var_name} must have a MultiIndex, found {type(obj.index)}" - return _ret(False, msg, None, return_metadata) - - # check that columns are unique - if not obj.columns.is_unique: - msg = f"{var_name} must have unique column indices, but found {obj.columns}" - return _ret(False, msg, None, return_metadata) - - # check that there are 3 or more index levels - nlevels = obj.index.nlevels - if not nlevels > 2: - msg = ( - f"{var_name} must have a MultiIndex with 3 or more levels, found {nlevels}" - ) - return _ret(False, msg, None, return_metadata) - - inst_inds = obj.index.droplevel(-1).unique() - panel_inds = inst_inds.droplevel(-1).unique() - - check_res = [ - check_pddataframe_series(obj.loc[i], return_metadata=True) for i in inst_inds - ] - bad_inds = [i[1] for i in enumerate(inst_inds) if not check_res[i[0]][0]] - - if len(bad_inds) > 0: - msg = ( - f"{var_name}.loc[i] must be Series of mtype pd.DataFrame," - f" not at i={bad_inds}" - ) - return _ret(False, msg, None, return_metadata) - - metadata = dict() - metadata["is_univariate"] = np.all([res[2]["is_univariate"] for res in check_res]) - metadata["is_equally_spaced"] = np.all( - [res[2]["is_equally_spaced"] for res in check_res] + ret = check_pdmultiindex_panel( + obj, return_metadata=return_metadata, var_name=var_name, panel=False ) - metadata["is_empty"] = np.any([res[2]["is_empty"] for res in check_res]) - metadata["n_instances"] = len(inst_inds) - metadata["n_panels"] = len(panel_inds) - metadata["is_one_series"] = len(inst_inds) == 1 - metadata["is_one_panel"] = len(panel_inds) == 1 - metadata["has_nans"] = obj.isna().values.any() - metadata["is_equal_length"] = _list_all_equal([len(obj.loc[i]) for i in inst_inds]) - - return _ret(True, None, metadata, return_metadata) + + return ret check_dict[("pd_multiindex_hier", "Hierarchical")] = check_pdmultiindex_hierarchical diff --git a/sktime/datatypes/_panel/_check.py b/sktime/datatypes/_panel/_check.py index 496bd47f30a..608b7a1168d 100644 --- a/sktime/datatypes/_panel/_check.py +++ b/sktime/datatypes/_panel/_check.py @@ -115,6 +115,9 @@ def check_dflist_panel(obj, return_metadata=False, var_name="obj"): metadata["is_empty"] = np.any([res[2]["is_empty"] for res in check_res]) metadata["has_nans"] = np.any([res[2]["has_nans"] for res in check_res]) metadata["is_one_series"] = n == 1 + metadata["n_panels"] = 1 + metadata["is_one_panel"] = True + metadata["n_instances"] = n return _ret(True, None, metadata, return_metadata) @@ -143,6 +146,8 @@ def check_numpy3d_panel(obj, return_metadata=False, var_name="obj"): metadata["n_instances"] = obj.shape[0] metadata["is_one_series"] = obj.shape[0] == 1 + metadata["n_panels"] = 1 + metadata["is_one_panel"] = True # check whether there any nans; only if requested if return_metadata: @@ -154,7 +159,7 @@ def check_numpy3d_panel(obj, return_metadata=False, var_name="obj"): check_dict[("numpy3D", "Panel")] = check_numpy3d_panel -def check_pdmultiindex_panel(obj, return_metadata=False, var_name="obj"): +def check_pdmultiindex_panel(obj, return_metadata=False, var_name="obj", panel=True): if not isinstance(obj, pd.DataFrame): msg = f"{var_name} must be a pd.DataFrame, found {type(obj)}" @@ -172,9 +177,14 @@ def check_pdmultiindex_panel(obj, return_metadata=False, var_name="obj"): # check that there are precisely two index levels nlevels = obj.index.nlevels - if not nlevels == 2: + if panel is True and not nlevels == 2: msg = f"{var_name} must have a MultiIndex with 2 levels, found {nlevels}" return _ret(False, msg, None, return_metadata) + elif panel is False and not nlevels > 2: + msg = ( + f"{var_name} must have a MultiIndex with 3 or more levels, found {nlevels}" + ) + return _ret(False, msg, None, return_metadata) # check that no dtype is object if "object" in obj.dtypes.values: @@ -225,10 +235,17 @@ def check_pdmultiindex_panel(obj, return_metadata=False, var_name="obj"): ) return _ret(False, msg, None, return_metadata) + if panel is True: + panel_inds = [1] + else: + panel_inds = inst_inds.droplevel(-1).unique() + metadata = dict() metadata["is_univariate"] = len(obj.columns) < 2 metadata["is_equally_spaced"] = is_equally_spaced metadata["is_empty"] = len(obj.index) < 1 or len(obj.columns) < 1 + metadata["n_panels"] = len(panel_inds) + metadata["is_one_panel"] = len(panel_inds) == 1 metadata["n_instances"] = len(inst_inds) metadata["is_one_series"] = len(inst_inds) == 1 metadata["has_nans"] = obj.isna().values.any() @@ -360,6 +377,8 @@ def is_nested_dataframe(obj, return_metadata=False, var_name="obj"): metadata["is_univariate"] = obj.shape[1] < 2 metadata["n_instances"] = len(obj) metadata["is_one_series"] = len(obj) == 1 + metadata["n_panels"] = 1 + metadata["is_one_panel"] = True if return_metadata: metadata["has_nans"] = _nested_dataframe_has_nans(obj) metadata["is_equal_length"] = not _nested_dataframe_has_unequal(obj) @@ -394,6 +413,8 @@ def check_numpyflat_Panel(obj, return_metadata=False, var_name="obj"): metadata["is_equal_length"] = True metadata["n_instances"] = obj.shape[0] metadata["is_one_series"] = obj.shape[0] == 1 + metadata["n_panels"] = 1 + metadata["is_one_panel"] = True # check whether there any nans; only if requested if return_metadata: diff --git a/sktime/datatypes/_panel/_examples.py b/sktime/datatypes/_panel/_examples.py index e29c7eb6968..02512f8bf41 100644 --- a/sktime/datatypes/_panel/_examples.py +++ b/sktime/datatypes/_panel/_examples.py @@ -82,6 +82,8 @@ example_dict_metadata[("Panel", 0)] = { "is_univariate": False, "is_one_series": False, + "n_panels": 1, + "is_one_panel": True, "is_equally_spaced": True, "is_equal_length": True, "is_empty": False, @@ -140,6 +142,8 @@ example_dict_metadata[("Panel", 1)] = { "is_univariate": True, "is_one_series": False, + "n_panels": 1, + "is_one_panel": True, "is_equally_spaced": True, "is_equal_length": True, "is_empty": False, @@ -192,6 +196,8 @@ example_dict_metadata[("Panel", 2)] = { "is_univariate": True, "is_one_series": True, + "n_panels": 1, + "is_one_panel": True, "is_equally_spaced": True, "is_equal_length": True, "is_empty": False,