Skip to content

Commit

Permalink
clean up Check.__call__ and SeriesSchemaBase.__call__ API (#153)
Browse files Browse the repository at this point in the history
This diff makes the Check API cleaner by moving out the
error handling into the schema doing the data validation
instead of within the __call__ subroutine.

This also standardizes the Check.__call__ signature such that
it takes the dataframe or series object to be checked.

This also cleans up the SeriesSchemaBase API and its subclasses,
where the __call__ method takes a single argument df_or_series,
which is the object to be checked.
  • Loading branch information
cosmicBboy authored and mastersplinter committed Jan 12, 2020
1 parent b62fac4 commit 7bf2fe8
Show file tree
Hide file tree
Showing 8 changed files with 249 additions and 205 deletions.
2 changes: 1 addition & 1 deletion .pylintrc
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,4 @@ good-names=
x,

[MESSAGES CONTROL]
disable=R0913,W0222
disable=R0913
222 changes: 64 additions & 158 deletions pandera/checks.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
"""Data validation checks."""

from typing import Union, Optional, List, Dict, Callable
from collections import namedtuple
from typing import Dict, Union, Optional, List, Callable

import pandas as pd

from . import errors, constants
from .dtypes import PandasDtype


CheckResult = namedtuple(
"CheckResult", ["check_passed", "checked_object", "failure_cases"])


GroupbyObject = Union[
Expand Down Expand Up @@ -144,110 +149,11 @@ def __init__(
self.groups = groups
self.failure_cases = None

@property
def _error_message(self):
"""Check error message."""
name = getattr(self.fn, '__name__', self.fn.__class__.__name__)
if self.error:
return "%s: %s" % (name, self.error)
return "%s" % name

def _vectorized_error_message(
self,
parent_schema: type,
check_index: int,
failure_cases: Union[pd.DataFrame, pd.Series]) -> str:
"""Construct an error message when a validator fails.
:param parent_schema: class of schema being validated.
:param check_index: The validator that failed.
:param failure_cases: The failure cases encountered by the element-wise
or vectorized validator.
"""
return (
"%s failed element-wise validator %d:\n"
"%s\nfailure cases:\n%s" %
(parent_schema, check_index,
self._error_message,
self._format_failure_cases(failure_cases)))

def _generic_error_message(
self,
parent_schema: type,
check_index: int) -> str:
"""Construct an error message when a check validator fails.
:param parent_schema: class of schema being validated.
:param check_index: The validator that failed.
"""
return "%s failed series validator %d: %s" % \
(parent_schema, check_index, self._error_message)

def _format_failure_cases(
self,
failure_cases: Union[pd.DataFrame, pd.Series]) -> pd.DataFrame:
"""Construct readable error messages for vectorized_error_message.
:param failure_cases: The failure cases encountered by the element-wise
or vectorized validator.
:returns: DataFrame where index contains failure cases, the "index"
column contains a list of integer indexes in the validation
DataFrame that caused the failure, and a "count" column
representing how many failures of that case occurred.
"""
if hasattr(failure_cases, "index") and \
isinstance(failure_cases.index, pd.MultiIndex):
index_name = failure_cases.index.name
failure_cases = (
failure_cases
.rename("failure_case")
.reset_index()
.assign(
index=lambda df: (
df.apply(tuple, axis=1).astype(str)
)
)
)
elif isinstance(failure_cases, pd.DataFrame):
index_name = failure_cases.index.name
failure_cases = (
failure_cases
.pipe(lambda df: pd.Series(
df.itertuples()).map(lambda x: x.__repr__()))
.rename("failure_case")
.reset_index()
)
elif isinstance(failure_cases, pd.Series):
index_name = failure_cases.index.name
failure_cases = (
failure_cases
.rename("failure_case")
.reset_index()
)
else:
raise TypeError(
"type of failure_cases argument not understood: %s" %
type(failure_cases))

index_name = "index" if index_name is None else index_name
failure_cases = (
failure_cases
.groupby("failure_case")[index_name].agg([list, len])
.rename(columns={"list": index_name, "len": "count"})
.sort_values("count", ascending=False)
)

self.failure_cases = failure_cases
return failure_cases.head(self.n_failure_cases)

def _format_groupby_input(
self,
groupby_obj: GroupbyObject,
groups: List[str]
) -> Union[Dict[str, Union[pd.Series, pd.DataFrame]]]:
) -> Union[Dict[str, Union[pd.Series, pd.DataFrame]]]:
# pylint: disable=no-self-use
"""Format groupby object into dict of groups to Series or DataFrame.
Expand All @@ -271,7 +177,8 @@ def _format_groupby_input(
def _prepare_series_input(
self,
series: pd.Series,
dataframe_context: pd.DataFrame) -> SeriesCheckObj:
dataframe_context: Optional[pd.DataFrame] = None
) -> SeriesCheckObj:
"""Prepare input for Column check.
:param pd.Series series: one-dimensional ndarray with axis labels
Expand All @@ -296,8 +203,7 @@ def _prepare_series_input(
return self._format_groupby_input(groupby_obj, self.groups)
raise TypeError("Type %s not recognized for `groupby` argument.")


def prepare_dataframe_input(
def _prepare_dataframe_input(
self, dataframe: pd.DataFrame) -> DataFrameCheckObj:
"""Prepare input for DataFrameSchema check.
Expand All @@ -310,66 +216,61 @@ def prepare_dataframe_input(
groupby_obj = dataframe.groupby(self.groupby)
return self._format_groupby_input(groupby_obj, self.groups)

def _vectorized_check(
self,
parent_schema: type,
check_index: int,
check_obj: Dict[str, Union[pd.Series, pd.DataFrame]]
) -> bool:
"""Perform a vectorized check on a series.
:param parent_schema: class of schema being validated.
:param check_index: The validator to check the series for
:param check_obj: a dictionary of pd.Series to be used by
`_check_fn` and `_vectorized_check`
:returns: True if pandas DataFramf or Series is valid.
"""
val_result = self.fn(check_obj)
if isinstance(val_result, pd.Series):
if not val_result.dtype == PandasDtype.Bool.value:
raise TypeError(
"validator %d: %s must return bool or Series of type "
"bool, found %s" %
(check_index, self.fn.__name__, val_result.dtype))
if val_result.all():
return True
if isinstance(check_obj, dict) or \
check_obj.shape[0] != val_result.shape[0] or \
(check_obj.index != val_result.index).all():
raise errors.SchemaError(
self._generic_error_message(parent_schema, check_index))
raise errors.SchemaError(self._vectorized_error_message(
parent_schema, check_index, check_obj[~val_result]))
if val_result:
return True
raise errors.SchemaError(
self._generic_error_message(parent_schema, check_index))

def __call__(
self,
parent_schema: type,
check_index: int,
check_obj: Union[pd.Series, pd.DataFrame]) -> bool:
df_or_series: Union[pd.DataFrame, pd.Series],
column: str = None,
) -> CheckResult:
"""Validate pandas DataFrame or Series.
:param parent_schema: class of schema being validated.
:check_index: index of check that is being validated.
:check_obj: pandas DataFrame of Series to validate.
:returns: True if check passes.
:df_or_series: pandas DataFrame of Series to validate.
:column: apply the check function to this column.
:returns: CheckResult tuple containing checked object,
check validation result, and failure cases from the checked object.
"""
if column is not None \
and isinstance(df_or_series, pd.DataFrame):
column_dataframe_context = df_or_series.drop(
column, axis="columns")
df_or_series = df_or_series[column].copy()
else:
column_dataframe_context = None

# prepare check object
if isinstance(df_or_series, pd.Series):
check_obj = self._prepare_series_input(
df_or_series, column_dataframe_context)
elif isinstance(df_or_series, pd.DataFrame):
check_obj = self._prepare_dataframe_input(df_or_series)
else:
raise ValueError(
"object of type %s not supported. Must be a "
"Series, a dictionary of Series, or DataFrame" %
df_or_series)

# apply check function to check object
if self.element_wise:
val_result = check_obj.apply(self.fn, axis=1) if \
check_result = check_obj.apply(self.fn, axis=1) if \
isinstance(check_obj, pd.DataFrame) else check_obj.map(self.fn)
if val_result.all():
return True
raise errors.SchemaError(self._vectorized_error_message(
parent_schema, check_index, check_obj[~val_result]))
if isinstance(check_obj, (pd.Series, dict, pd.DataFrame)):
return self._vectorized_check(
parent_schema, check_index, check_obj)
raise ValueError(
"check_obj type %s not supported. Must be a "
"Series, a dictionary of Series, or DataFrame" % check_obj)
else:
# vectorized check function case
check_result = self.fn(check_obj)

# failure cases only apply when the check function returns a boolean
# series that matches the shape and index of the check_obj
if isinstance(check_obj, dict) or \
isinstance(check_result, bool) or \
not isinstance(check_result, pd.Series) or \
check_obj.shape[0] != check_result.shape[0] or \
(check_obj.index != check_result.index).all():
failure_cases = None
else:
failure_cases = check_obj[~check_result]

check_passed = check_result.all() if \
isinstance(check_result, pd.Series) else check_result

return CheckResult(check_passed, check_obj, failure_cases)

def __eq__(self, other):
are_fn_objects_equal = self.__dict__["fn"].__code__.co_code == \
Expand All @@ -380,3 +281,8 @@ def __eq__(self, other):
{i: other.__dict__[i] for i in other.__dict__ if i != 'fn'}

return are_fn_objects_equal and are_all_other_check_attributes_equal

def __repr__(self):
name = getattr(self.fn, '__name__', self.fn.__class__.__name__)
return "<Check %s: %s>" % (name, self.error) \
if self.error is not None else "<Check %s>" % name
105 changes: 105 additions & 0 deletions pandera/error_formatters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
"""Make schema error messages human-friendly."""

from typing import Union

import pandas as pd

from .checks import Check


def format_generic_error_message(
parent_schema,
check: Check,
check_index: int,
) -> str:
"""Construct an error message when a check validator fails.
:param parent_schema: class of schema being validated.
:param check: check that generated error.
:param check_index: The validator that failed.
"""
return "%s failed series validator %d: %s" % \
(parent_schema, check_index, check)


def format_vectorized_error_message(
parent_schema,
check: Check,
check_index: int,
failure_cases: pd.Series) -> str:
"""Construct an error message when a validator fails.
:param parent_schema: class of schema being validated.
:param check: check that generated error.
:param check_index: The validator that failed.
:param failure_cases: The failure cases encountered by the element-wise
or vectorized validator.
"""
return (
"%s failed element-wise validator %d:\n"
"%s\nfailure cases:\n%s" % (
parent_schema,
check_index,
check,
format_failure_cases(failure_cases, check.n_failure_cases)
)
)


def format_failure_cases(
failure_cases: Union[pd.DataFrame, pd.Series],
n_cases: int) -> pd.DataFrame:
"""Construct readable error messages for vectorized_error_message.
:param failure_cases: The failure cases encountered by the element-wise
or vectorized validator.
:returns: DataFrame where index contains failure cases, the "index"
column contains a list of integer indexes in the validation
DataFrame that caused the failure, and a "count" column
representing how many failures of that case occurred.
"""
if hasattr(failure_cases, "index") and \
isinstance(failure_cases.index, pd.MultiIndex):
index_name = failure_cases.index.name
failure_cases = (
failure_cases
.rename("failure_case")
.reset_index()
.assign(
index=lambda df: (
df.apply(tuple, axis=1).astype(str)
)
)
)
elif isinstance(failure_cases, pd.DataFrame):
index_name = failure_cases.index.name
failure_cases = (
failure_cases
.pipe(lambda df: pd.Series(
df.itertuples()).map(lambda x: x.__repr__()))
.rename("failure_case")
.reset_index()
)
elif isinstance(failure_cases, pd.Series):
index_name = failure_cases.index.name
failure_cases = (
failure_cases
.rename("failure_case")
.reset_index()
)
else:
raise TypeError(
"type of failure_cases argument not understood: %s" %
type(failure_cases))

index_name = "index" if index_name is None else index_name
failure_cases = (
failure_cases
.groupby("failure_case")[index_name].agg([list, len])
.rename(columns={"list": index_name, "len": "count"})
.sort_values("count", ascending=False)
)

return failure_cases.head(n_cases)

0 comments on commit 7bf2fe8

Please sign in to comment.