Skip to content

Commit

Permalink
implement coercion-on-initialization for DataFrame[SchemaModel] types (#772)
Browse files Browse the repository at this point in the history

* implement coercion-on-initialization

* pylint

* Update tests/core/test_model.py

Co-authored-by: Matt Richards <45483497+m-richards@users.noreply.github.com>

Co-authored-by: Matt Richards <45483497+m-richards@users.noreply.github.com>
  • Loading branch information
cosmicBboy and m-richards committed Apr 1, 2022
1 parent 5a48432 commit 2cef93b
Show file tree
Hide file tree
Showing 5 changed files with 45 additions and 30 deletions.
3 changes: 2 additions & 1 deletion .pylintrc
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,5 @@ disable=
no-else-return,
inconsistent-return-statements,
protected-access,
too-many-ancestors
too-many-ancestors,
too-many-lines
11 changes: 5 additions & 6 deletions pandera/typing/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,14 +159,13 @@ def __setattr__(self, name: str, value: Any) -> None:

# prevent the double validation problem by preventing checks for
# dataframes with a defined pandera.schema
pandera = getattr(self, "pandera")
pandera_accessor = getattr(self, "pandera")
if (
pandera.schema is None
or pandera.schema != schema_model.to_schema()
pandera_accessor.schema is None
or pandera_accessor.schema != schema_model.to_schema()
):
# pylint: disable=self-cls-assignment
self = schema_model.validate(self)
pandera.add_schema(schema_model.to_schema())
pandera_accessor.add_schema(schema_model.to_schema())
self.__dict__ = schema_model.validate(self).__dict__


# pylint:disable=too-few-public-methods
Expand Down
23 changes: 1 addition & 22 deletions pandera/typing/dask.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
"""Pandera type annotations for Dask."""

import inspect
from typing import TYPE_CHECKING, Any, Generic, TypeVar
from typing import TYPE_CHECKING, Generic, TypeVar

from .common import DataFrameBase, IndexBase, SeriesBase
from .pandas import GenericDtype, Schema
Expand Down Expand Up @@ -32,26 +31,6 @@ class DataFrame(DataFrameBase, dd.DataFrame, Generic[T]):
*new in 0.8.0*
"""

# Set the attribute as usual, then validate this dataframe against its schema
# model once the ``__orig_class__`` attribute is assigned.
# NOTE(review): presumably ``__orig_class__`` is set by ``typing`` when a
# subscripted ``DataFrame[Schema]`` is instantiated — confirm.
def __setattr__(self, name: str, value: Any) -> None:
# Perform the plain attribute assignment first.
object.__setattr__(self, name, value)
if name == "__orig_class__":
# Type arguments of the subscripted class, or None if it has none.
class_args = getattr(self.__orig_class__, "__args__", None)
# Only proceed when a class named "SchemaModel" appears in the MRO of
# the first type argument.
if class_args is not None and any(
x.__name__ == "SchemaModel"
for x in inspect.getmro(class_args[0])
):
schema_model = value.__args__[0]

# prevent the double validation problem by preventing checks
# for dataframes with a defined pandera.schema
if (
self.pandera.schema is None
or self.pandera.schema != schema_model.to_schema()
):
# Adopt the validated object's state by replacing this instance's
# ``__dict__`` with the one from the validation result.
# pylint: disable=self-cls-assignment
self.__dict__ = schema_model.validate(self).__dict__
# Register the schema on the accessor; presumably this sets
# ``pandera.schema`` so the check above skips re-validation — confirm.
self.pandera.add_schema(schema_model.to_schema())

# pylint:disable=too-few-public-methods
class Series(SeriesBase, dd.Series, Generic[GenericDtype]): # type: ignore
"""Representation of pandas.Series, only used for type annotation.
Expand Down
37 changes: 37 additions & 0 deletions tests/core/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -941,6 +941,8 @@ class Child(Base):


def test_column_access_regex() -> None:
"""Test that column regex alias is reflected in schema attribute."""

class Schema(pa.SchemaModel):
col_regex: Series[str] = pa.Field(alias="column_([0-9])+", regex=True)

Expand All @@ -965,3 +967,38 @@ class Bar(pa.SchemaModel):

assert Foo.Config.name == "foo"
assert Bar.Config.name == "Bar"


def test_validate_coerce_on_init():
    """Check that ``DataFrame[Schema](...)`` validates and coerces on creation.

    With ``coerce = True`` the frame built through ``DataFrame[Schema]`` must
    equal what ``Schema.validate`` returns for the same raw data. With
    ``coerce = False`` the integer ``price`` column must make construction
    raise a ``SchemaError``.
    """

    class Schema(pa.SchemaModel):
        state: Series[str]
        city: Series[str]
        price: Series[float] = pa.Field(
            in_range={"min_value": 5, "max_value": 20}
        )

        class Config:
            coerce = True

    class SchemaNoCoerce(Schema):
        class Config:
            coerce = False

    # ``price`` is deliberately given as integers although the schema
    # declares it as float.
    records = {
        "state": ["NY", "FL", "GA", "CA"],
        "city": ["New York", "Miami", "Atlanta", "San Francisco"],
        "price": [8, 12, 10, 16],
    }
    typed_frame = DataFrame[Schema](records)
    plain_frame = pd.DataFrame(records)

    expected_frame = Schema.validate(plain_frame)
    assert typed_frame.equals(expected_frame)
    assert isinstance(typed_frame, DataFrame)
    assert isinstance(plain_frame, pd.DataFrame)

    # Without coercion the same input must be rejected at construction time.
    dtype_error = "^expected series 'price' to have type float64, got int64$"
    with pytest.raises(pa.errors.SchemaError, match=dtype_error):
        DataFrame[SchemaNoCoerce](records)
1 change: 0 additions & 1 deletion tests/dask/test_dask.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
""" Tests that basic Pandera functionality works for Dask objects. """


import dask.dataframe as dd
import pandas as pd
import pytest
Expand Down

0 comments on commit 2cef93b

Please sign in to comment.