Skip to content

Commit

Permalink
fix lazy validation with nullable columns (#213)
Browse files Browse the repository at this point in the history
  • Loading branch information
cosmicBboy authored May 25, 2020
1 parent 923a197 commit f5051fa
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 36 deletions.
7 changes: 5 additions & 2 deletions pandera/error_formatters.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,11 +68,14 @@ def scalar_failure_case(x) -> pd.DataFrame:


def reshape_failure_cases(
failure_cases: Union[pd.DataFrame, pd.Series]) -> pd.DataFrame:
failure_cases: Union[pd.DataFrame, pd.Series],
ignore_na: bool = True
) -> pd.DataFrame:
"""Construct readable error messages for vectorized_error_message.
:param failure_cases: The failure cases encountered by the element-wise
or vectorized validator.
:param ignore_na: whether or not to ignore null failure cases.
:returns: DataFrame where index contains failure cases, the "index"
column contains a list of integer indexes in the validation
DataFrame that caused the failure, and a "count" column
Expand Down Expand Up @@ -112,4 +115,4 @@ def reshape_failure_cases(
"type of failure_cases argument not understood: %s" %
type(failure_cases))

return failure_cases.dropna()
return failure_cases.dropna() if ignore_na else failure_cases
40 changes: 15 additions & 25 deletions pandera/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -793,15 +793,15 @@ def validate(
check="column_name('%s')" % self._name,
)

_dtype = self.dtype

if self._nullable:
# currently, to handle null cases drop null values before passing
# into checks.
series = series.dropna()
if _dtype in dtypes.NUMPY_NONNULLABLE_INT_DTYPES:
_series = series.astype(_dtype)
if self.dtype in dtypes.NUMPY_NONNULLABLE_INT_DTYPES:
_series = series.astype(self.dtype)
if (_series != series).any():
# in case where dtype is meant to be int, make sure that
# casting to int results in the same values.
# casting to int results in equal values.
msg = (
"after dropping null values, expected values in "
"series '%s' to be int, found: %s" %
Expand All @@ -817,22 +817,8 @@ def validate(
)
series = _series

nulls = series.isnull()
nulls = series.isna()
if sum(nulls) > 0:
if series.dtype != _dtype:
msg = (
"expected series '%s' to have type %s, got %s" %
(series.name, self.dtype, series.dtype)
)
error_handler.collect_error(
"wrong_pandas_dtype",
errors.SchemaError(
self, check_obj, msg,
failure_cases=scalar_failure_case(self.dtype),
check="pandas_dtype('%s')" % self.dtype,
)
)

msg = (
"non-nullable series '%s' contains null values: %s" %
(series.name,
Expand All @@ -843,7 +829,9 @@ def validate(
"series_contains_nulls",
errors.SchemaError(
self, check_obj, msg,
failure_cases=series[nulls],
failure_cases=reshape_failure_cases(
series[nulls], ignore_na=False
),
check="not_nullable",
)
)
Expand All @@ -862,22 +850,24 @@ def validate(
"series_contains_duplicates",
errors.SchemaError(
self, check_obj, msg,
failure_cases=series[duplicates],
failure_cases=reshape_failure_cases(
series[duplicates]
),
check="no_duplicates",
)
)

if _dtype is not None and str(series.dtype) != _dtype:
if self.dtype is not None and str(series.dtype) != self.dtype:
msg = (
"expected series '%s' to have type %s, got %s" %
(series.name, _dtype, str(series.dtype))
(series.name, self.dtype, str(series.dtype))
)
error_handler.collect_error(
"wrong_pandas_dtype",
errors.SchemaError(
self, check_obj, msg,
failure_cases=scalar_failure_case(str(series.dtype)),
check="pandas_dtype('%s')" % _dtype,
check="pandas_dtype('%s')" % self.dtype,
)
)

Expand Down
42 changes: 33 additions & 9 deletions tests/test_schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,15 +158,6 @@ def test_series_schema():
SeriesSchema(Float, nullable=False).validate(
pd.Series([1.1, 2.3, 5.5, np.nan]))

# when series contains null values when schema is not nullable in addition
# to having the wrong data type
with pytest.raises(
errors.SchemaError,
match=(
r"^expected series '.+' to have type .+, got .+")):
SeriesSchema(Int, nullable=False).validate(
pd.Series([1.1, 2.3, 5.5, np.nan]))


def test_series_schema_multiple_validators():
"""Tests how multiple Checks on a Series Schema are handled both
Expand Down Expand Up @@ -720,6 +711,39 @@ def test_lazy_dataframe_validation_error():
)


def test_lazy_dataframe_validation_nullable():
    """
    Test that non-nullable column failure cases are correctly processed during
    lazy validation.

    Each column contains exactly one null value; lazy validation should
    collect one ``not_nullable`` failure case per column, with the failure
    case itself being the null value and the "index" column pointing at the
    offending row.
    """
    schema = DataFrameSchema(
        columns={
            "int_column": Column(Int, nullable=False),
            "float_column": Column(Float, nullable=False),
            "str_column": Column(String, nullable=False),
        },
        strict=True
    )

    df = pd.DataFrame({
        "int_column": [1, None, 3],
        "float_column": [0.1, 1.2, None],
        "str_column": [None, "foo", "bar"],
    })

    try:
        schema.validate(df, lazy=True)
    except errors.SchemaErrors as err:
        # all collected failure cases are the null values themselves
        assert err.schema_errors.failure_case.isna().all()
        for col, index in [
                ("int_column", 1),
                ("float_column", 2),
                ("str_column", 0)]:
            # pylint: disable=cell-var-from-loop
            assert err.schema_errors.loc[
                lambda df: df.column == col, "index"].iloc[0] == index
    else:
        # the original try/except silently passed when validate() did not
        # raise, asserting nothing; fail explicitly in that case.
        raise AssertionError(
            "schema.validate should have raised SchemaErrors for "
            "null values in non-nullable columns"
        )


@pytest.mark.parametrize("schema, data, expectation", [
[
SeriesSchema(Int, checks=Check.greater_than(0)),
Expand Down

0 comments on commit f5051fa

Please sign in to comment.