bugfix: lazy validation handles check_fn returning scalar False value, io can handle null index (#217)

* bugfix: lazy validation handles check returning scalar False value

* add unit tests for all schema and schema components
cosmicBboy committed Jun 14, 2020
1 parent 00cd964 commit 3bf8e72
Showing 6 changed files with 52 additions and 19 deletions.
1 change: 1 addition & 0 deletions .envrc
@@ -0,0 +1 @@
source activate pandera
6 changes: 2 additions & 4 deletions pandera/errors.py
@@ -88,10 +88,8 @@ def failure_cases(x):
schema_errors
.fillna({"column": "<NA>"})
.groupby(["schema_context", "column", "check"])
.failure_case.agg([failure_cases, len])
.rename(columns={
"len": "n_failure_cases",
})
.failure_case.agg([failure_cases])
.assign(n_failure_cases=lambda df: df.failure_cases.map(len))
.sort_index(
level=["schema_context", "column"],
ascending=[False, True],
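The change above derives n_failure_cases from the aggregated failure-case lists via assign rather than aggregating with len directly. A minimal standalone sketch of that pandas pattern, using a toy schema_errors frame and an illustrative failure_cases helper (the real helper is defined in pandera/errors.py above; its body is assumed here to collect the failing values):

import pandas as pd

# toy stand-in for the error summary frame built in pandera/errors.py
schema_errors = pd.DataFrame({
    "schema_context": ["Column", "Column", "Column"],
    "column": ["a", "a", "b"],
    "check": ["greater_than(0)", "greater_than(0)", "less_than(10)"],
    "failure_case": [-1, -2, 100],
})

def failure_cases(x):
    # assumed behaviour: collect the failing values for each group
    return x.tolist()

summary = (
    schema_errors
    .groupby(["schema_context", "column", "check"])
    .failure_case.agg([failure_cases])
    # count failures from the aggregated list instead of agg-ing with len
    .assign(n_failure_cases=lambda df: df.failure_cases.map(len))
)
print(summary)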
3 changes: 2 additions & 1 deletion pandera/io.py
@@ -309,7 +309,8 @@ def to_script(dataframe_schema, path_or_buf=None):
)
columns[colname] = column_code.strip()

index = _format_index(statistics["index"])
index = None if statistics["index"] is None else \
_format_index(statistics["index"])

column_str = ", ".join("'{}': {}".format(k, v) for k, v in columns.items())

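With the guard above, to_script no longer assumes that the schema statistics contain an index. A hedged usage sketch (assuming pandera's optional IO dependencies are installed and that pandera.io is importable as below, as in tests/test_io.py):

import pandera as pa
from pandera import io

# a schema with no explicit index; before this fix, serializing it to a
# script assumed statistics["index"] was always populated
schema = pa.DataFrameSchema(columns={"column": pa.Column(pa.Int)})
print(io.to_script(schema))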
18 changes: 13 additions & 5 deletions pandera/schemas.py
@@ -433,7 +433,8 @@ def __call__(
head: Optional[int] = None,
tail: Optional[int] = None,
sample: Optional[int] = None,
random_state: Optional[int] = None):
random_state: Optional[int] = None,
lazy: bool = False):
"""Alias for :func:`DataFrameSchema.validate` method.
:param pd.DataFrame dataframe: the dataframe to be validated.
@@ -445,8 +446,12 @@ def __call__(
:type tail: int
:param sample: validate a random sample of n rows. Rows overlapping
with `head` or `tail` are de-duplicated.
:param random_state: random seed for the ``sample`` argument.
:param lazy: if True, lazily evaluates dataframe against all validation
checks and raises a ``SchemaErrors``. Otherwise, raise
``SchemaError`` as soon as one occurs.
"""
return self.validate(dataframe)
return self.validate(dataframe, head, tail, sample, random_state, lazy)

def __repr__(self):
"""Represent string for logging."""
@@ -913,9 +918,10 @@ def __call__(
tail: Optional[int] = None,
sample: Optional[int] = None,
random_state: Optional[int] = None,
lazy: bool = False,
) -> Union[pd.DataFrame, pd.Series]:
"""Alias for ``validate`` method."""
return self.validate(check_obj, head, tail, sample, random_state)
return self.validate(check_obj, head, tail, sample, random_state, lazy)

def __eq__(self, other):
return self.__dict__ == other.__dict__
@@ -996,9 +1002,10 @@ def __call__(
tail: Optional[int] = None,
sample: Optional[int] = None,
random_state: Optional[int] = None,
lazy: bool = False,
) -> pd.Series:
"""Alias for :func:`SeriesSchema.validate` method."""
return self.validate(check_obj)
return self.validate(check_obj, head, tail, sample, random_state, lazy)

def __eq__(self, other):
return self.__dict__ == other.__dict__
@@ -1039,7 +1046,8 @@ def _handle_check_results(
check_result = check(check_obj, *check_args)
if not check_result.check_passed:
if check_result.failure_cases is None:
failure_cases = None
# encode scalar False values explicitly
failure_cases = scalar_failure_case(check_result.check_passed)
error_msg = format_generic_error_message(
schema, check, check_index)
else:
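Taken together, the schemas.py changes thread the lazy flag through the __call__ aliases and encode a scalar False check result as an explicit failure case instead of None. A minimal usage sketch mirroring the new unit test below (the failure_cases attribute on SchemaErrors is assumed here):

import pandas as pd
import pandera as pa
from pandera import errors

# a dataframe-level check whose function returns a bare scalar False
# rather than a boolean Series
schema = pa.DataFrameSchema(
    columns={"column": pa.Column(pa.Int)},
    checks=pa.Check(
        check_fn=lambda _: False,
        element_wise=False,
        error="failing check",
    ),
)

try:
    schema(pd.DataFrame({"column": [1]}), lazy=True)
except errors.SchemaErrors as err:
    # the scalar False result now surfaces as an explicit failure case
    print(err.failure_cases)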
16 changes: 9 additions & 7 deletions tests/test_io.py
@@ -15,16 +15,18 @@
PYYAML_VERSION = version.parse(yaml.__version__) # type: ignore


def _create_schema(multi_index=False):
def _create_schema(index="single"):

if multi_index:
if index == "multi":
index = pa.MultiIndex([
pa.Index(pa.Int, name="int_index0"),
pa.Index(pa.Int, name="int_index1"),
pa.Index(pa.Int, name="int_index2"),
])
else:
elif index == "single":
index = pa.Index(pa.Int, name="int_index")
else:
index = None

return pa.DataFrameSchema(
columns={
@@ -199,12 +201,12 @@ def test_io_yaml():
assert schema_from_yaml == schema


@pytest.mark.parametrize("multi_index", [
[True], [False]
@pytest.mark.parametrize("index", [
"single", "multi", None
])
def test_to_script(multi_index):
def test_to_script(index):
"""Test writing DataFrameSchema to a script."""
schema_to_write = _create_schema(multi_index)
schema_to_write = _create_schema(index)
script = io.to_script(schema_to_write)

local_dict = {}
27 changes: 25 additions & 2 deletions tests/test_schemas.py
@@ -1,11 +1,12 @@
"""Testing creation and manipulation of DataFrameSchema objects."""

import copy
from functools import partial

import numpy as np
import pandas as pd
import pytest


from pandera import (
Column, DataFrameSchema, Index, MultiIndex, SeriesSchema, Bool, Category,
Check, DateTime, Float, Int, Object, String, Timedelta, errors)
@@ -746,6 +747,28 @@ def test_lazy_dataframe_validation_nullable():
lambda df: df.column == col, "index"].iloc[0] == index


@pytest.mark.parametrize("schema_cls, data", [
[DataFrameSchema, pd.DataFrame({"column": [1]})],
[SeriesSchema, pd.Series([1, 2, 3])],
[partial(Column, name="column"), pd.DataFrame({"column": [1]})],
[
partial(Index, name="index"),
pd.DataFrame(index=pd.Index([1, 2, 3], name="index"))
],
])
def test_lazy_dataframe_scalar_false_check(schema_cls, data):
"""Lazy validation handles checks returning scalar False values."""
# define a check that always returns a scalar False value
check = Check(
check_fn=lambda _: False,
element_wise=False,
error="failing check"
)
schema = schema_cls(checks=check)
with pytest.raises(errors.SchemaErrors):
schema(data, lazy=True)


@pytest.mark.parametrize("schema, data, expectation", [
[
SeriesSchema(Int, checks=Check.greater_than(0)),
@@ -830,7 +853,7 @@ def test_lazy_series_validation_error(schema, data, expectation):
try:
schema.validate(data, lazy=True)
except errors.SchemaErrors as err:
# data in the caught exception should be equal to the dataframe
# data in the caught exception should be equal to the data
# passed into validate
assert err.data.equals(expectation["data"])
