Skip to content

Commit

Permalink
[ENH] improve performance of pandas based panel and hierarchical mty…
Browse files Browse the repository at this point in the history
…pe checks (#3935)

Increases speed of `pandas` based panel and hierarchical mtype checks by using `groupby`.

See #3827
  • Loading branch information
danbartl committed Jan 2, 2023
1 parent 1de1fb9 commit 1257ec2
Showing 1 changed file with 50 additions and 23 deletions.
73 changes: 50 additions & 23 deletions sktime/datatypes/_panel/_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,14 @@
import numpy as np
import pandas as pd

from sktime.datatypes._series._check import check_pddataframe_series
from sktime.utils.validation.series import is_integer_index
from sktime.datatypes._series._check import (
_index_equally_spaced,
check_pddataframe_series,
)
from sktime.utils.validation.series import is_in_valid_index_types, is_integer_index

VALID_MULTIINDEX_TYPES = (pd.RangeIndex, pd.Index)
VALID_INDEX_TYPES = (pd.RangeIndex, pd.PeriodIndex, pd.DatetimeIndex)


def is_in_valid_multiindex_types(x) -> bool:
Expand Down Expand Up @@ -161,8 +165,9 @@ def check_pdmultiindex_panel(obj, return_metadata=False, var_name="obj"):
return _ret(False, msg, None, return_metadata)

# check that columns are unique
if not obj.columns.is_unique:
msg = f"{var_name} must have unique column indices, but found {obj.columns}"
col_names = obj.columns
if not col_names.is_unique:
msg = f"{var_name} must have unique column indices, but found {col_names}"
return _ret(False, msg, None, return_metadata)

# check that there are precisely two index levels
Expand All @@ -171,41 +176,63 @@ def check_pdmultiindex_panel(obj, return_metadata=False, var_name="obj"):
msg = f"{var_name} must have a MultiIndex with 2 levels, found {nlevels}"
return _ret(False, msg, None, return_metadata)

# check that no dtype is object
if "object" in obj.dtypes.values:
msg = f"{var_name} should not have column of 'object' dtype"
return _ret(False, msg, None, return_metadata)

# check whether the time index is of valid type
if not is_in_valid_index_types(obj.index.get_level_values(-1)):
msg = (
f"{type(obj.index)} is not supported for {var_name}, use "
f"one of {VALID_INDEX_TYPES} or integer index instead."
)
return _ret(False, msg, None, return_metadata)

time_obj = obj.reset_index(-1).drop(obj.columns, axis=1)
time_grp = time_obj.groupby(level=0, group_keys=True, as_index=True)
inst_inds = time_obj.index.unique()

# check instance index being integer or range index
instind = obj.index.get_level_values(0)
if not is_in_valid_multiindex_types(instind):
if not is_in_valid_multiindex_types(inst_inds):
msg = (
f"instance index (first/highest index) must be {VALID_MULTIINDEX_TYPES}, "
f"integer index, but found {type(instind)}"
f"integer index, but found {type(inst_inds)}"
)
return _ret(False, msg, None, return_metadata)

inst_inds = obj.index.get_level_values(0).unique()
# inst_inds = np.unique(obj.index.get_level_values(0))
if pd.__version__ < "1.5.0":
# Earlier versions of pandas are very slow for this type of operation.
is_equally_list = [_index_equally_spaced(obj.loc[i].index) for i in inst_inds]
is_equally_spaced = all(is_equally_list)
montonic_list = [obj.loc[i].index.is_monotonic for i in inst_inds]
time_is_monotonic = len([i for i in montonic_list if i is False]) == 0
else:
timedelta_by_grp = (
time_grp.diff().groupby(level=0, group_keys=True, as_index=True).nunique()
)
timedelta_unique = timedelta_by_grp.iloc[:, 0].unique()
is_equally_spaced = len(timedelta_unique) == 1
time_is_monotonic = all(timedelta_unique >= 0)

check_res = [
check_pddataframe_series(obj.loc[i], return_metadata=True) for i in inst_inds
]
bad_inds = [i for i in range(len(inst_inds)) if not check_res[i][0]]
is_equal_length = time_grp.count()

if len(bad_inds) > 0:
# Check time index is ordered in time
if not time_is_monotonic:
msg = (
f"{var_name}.loc[i] must be Series of mtype pd.DataFrame,"
f" not at i={bad_inds}"
f"The (time) index of {var_name} must be sorted monotonically increasing, "
f"but found: {obj.index.get_level_values(-1)}"
)
return _ret(False, msg, None, return_metadata)

metadata = dict()
metadata["is_univariate"] = np.all([res[2]["is_univariate"] for res in check_res])
metadata["is_equally_spaced"] = np.all(
[res[2]["is_equally_spaced"] for res in check_res]
)
metadata["is_empty"] = np.any([res[2]["is_empty"] for res in check_res])
metadata["is_univariate"] = len(obj.columns) < 2
metadata["is_equally_spaced"] = is_equally_spaced
metadata["is_empty"] = len(obj.index) < 1 or len(obj.columns) < 1
metadata["n_instances"] = len(inst_inds)
metadata["is_one_series"] = len(inst_inds) == 1
metadata["has_nans"] = obj.isna().values.any()
metadata["is_equal_length"] = _list_all_equal([len(obj.loc[i]) for i in inst_inds])

metadata["is_equal_length"] = is_equal_length.nunique().shape[0] == 1
return _ret(True, None, metadata, return_metadata)


Expand Down

0 comments on commit 1257ec2

Please sign in to comment.