Skip to content

Commit

Permalink
adds equality operators for Checks, Columns, Schemas etc. (#143)
Browse files Browse the repository at this point in the history
  • Loading branch information
mastersplinter authored and cosmicBboy committed Jan 12, 2020
1 parent b49a4ab commit b62fac4
Show file tree
Hide file tree
Showing 6 changed files with 114 additions and 1 deletion.
11 changes: 10 additions & 1 deletion pandera/checks.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
"""Data validation checks."""

from typing import Union, Optional, List, Dict, Callable

import pandas as pd

from . import errors, constants
Expand Down Expand Up @@ -371,3 +370,13 @@ def __call__(
raise ValueError(
"check_obj type %s not supported. Must be a "
"Series, a dictionary of Series, or DataFrame" % check_obj)

def __eq__(self, other):
    """Return whether ``other`` is an equivalent check.

    Two checks are equal when every instance attribute other than ``fn``
    compares equal and the two ``fn`` callables are equivalent. Callables
    are equivalent when they are the same object, or when their compiled
    code matches in bytecode, constants and referenced global/attribute
    names and they carry equal default arguments and closure values.
    Argument and local variable names are deliberately ignored, so
    ``lambda x: x > 0`` equals ``lambda y: y > 0``.

    :param other: object to compare against.
    :returns: ``True`` or ``False`` when ``other`` is an instance of this
        object's class, ``NotImplemented`` otherwise so that Python can
        fall back to its default comparison.
    """
    if not isinstance(other, type(self)):
        return NotImplemented

    # Work on shallow copies so that popping ``fn`` never mutates either
    # check.
    self_attrs = dict(self.__dict__)
    other_attrs = dict(other.__dict__)
    self_fn = self_attrs.pop("fn", None)
    other_fn = other_attrs.pop("fn", None)

    if self_attrs != other_attrs:
        return False
    if self_fn is other_fn:
        return True

    self_code = getattr(self_fn, "__code__", None)
    other_code = getattr(other_fn, "__code__", None)
    if self_code is None or other_code is None:
        # Builtins, functools.partial objects and callable instances have
        # no bytecode to inspect; defer to their own equality.
        return bool(self_fn == other_fn)

    # ``co_code`` alone is not enough: constants and names are referenced
    # by index, so ``x > 1`` and ``x > 2`` can share identical bytecode.
    if (self_code.co_code != other_code.co_code
            or self_code.co_consts != other_code.co_consts
            or self_code.co_names != other_code.co_names):
        return False

    # The same code object can still behave differently when it is bound
    # to different defaults or closure values (closure cells compare by
    # their contents).
    return all(
        getattr(self_fn, attr, None) == getattr(other_fn, attr, None)
        for attr in ("__defaults__", "__kwdefaults__", "__closure__")
    )
9 changes: 9 additions & 0 deletions pandera/schema_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,9 @@ def __repr__(self):
dtype = self._pandas_dtype
return "<Schema Column: '%s' type=%s>" % (self._name, dtype)

def __eq__(self, other):
    """Compare this column with ``other`` attribute by attribute.

    :param other: object to compare against.
    :returns: ``True`` when ``other`` is an instance of this object's
        class holding equal instance attributes, ``False`` when the
        attributes differ, and ``NotImplemented`` for unrelated types so
        that ``column == None`` or ``column == 1`` evaluates to ``False``
        instead of raising ``AttributeError``.
    """
    if not isinstance(other, type(self)):
        return NotImplemented
    return self.__dict__ == other.__dict__


class Index(SeriesSchemaBase):
"""Extends SeriesSchemaBase with Index-specific options"""
Expand Down Expand Up @@ -156,6 +159,9 @@ def __repr__(self):
return "<Schema Index>"
return "<Schema Index: '%s'>" % self._name

def __eq__(self, other):
    """Compare this index schema with ``other`` attribute by attribute.

    :param other: object to compare against.
    :returns: ``True`` when ``other`` is an instance of this object's
        class holding equal instance attributes, ``False`` when the
        attributes differ, and ``NotImplemented`` for unrelated types so
        that comparing against ``None`` or a plain value evaluates to
        ``False`` instead of raising ``AttributeError``.
    """
    if not isinstance(other, type(self)):
        return NotImplemented
    return self.__dict__ == other.__dict__


class MultiIndex(DataFrameSchema):
"""Extends SeriesSchemaBase with Multi-index-specific options"""
Expand Down Expand Up @@ -234,3 +240,6 @@ def __call__(self, df: pd.DataFrame) -> bool:

def __repr__(self):
    """Return a short debug string built from iterating ``self.columns``."""
    column_names = list(self.columns)
    # Wrap the list in a one-element tuple so it is formatted as a single
    # ``%s`` value.
    return "<Schema MultiIndex: '%s'>" % (column_names,)

def __eq__(self, other):
    """Compare this multi-index schema with ``other`` by attributes.

    :param other: object to compare against.
    :returns: ``True`` when ``other`` is an instance of this object's
        class holding equal instance attributes, ``False`` when the
        attributes differ, and ``NotImplemented`` for unrelated types so
        that comparing against ``None`` or a plain value evaluates to
        ``False`` instead of raising ``AttributeError``.
    """
    if not isinstance(other, type(self)):
        return NotImplemented
    return self.__dict__ == other.__dict__
10 changes: 10 additions & 0 deletions pandera/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,10 @@ def __str__(self):
indent=_indent,
)

def __eq__(self, other):
    """Compare this schema with ``other`` attribute by attribute.

    :param other: object to compare against.
    :returns: ``True`` when ``other`` is an instance of this object's
        class holding equal instance attributes, ``False`` when the
        attributes differ, and ``NotImplemented`` for unrelated types so
        that comparing against ``None`` or a plain value evaluates to
        ``False`` instead of raising ``AttributeError``.
    """
    if not isinstance(other, type(self)):
        return NotImplemented
    return self.__dict__ == other.__dict__



class SeriesSchemaBase():
"""Base series validator object."""
Expand Down Expand Up @@ -432,6 +436,9 @@ def __call__(
check._prepare_series_input(series, dataframe_context)))
return all(val_results)

def __eq__(self, other):
    """Compare this series schema base with ``other`` by attributes.

    :param other: object to compare against.
    :returns: ``True`` when ``other`` is an instance of this object's
        class holding equal instance attributes, ``False`` when the
        attributes differ, and ``NotImplemented`` for unrelated types so
        that comparing against ``None`` or a plain value evaluates to
        ``False`` instead of raising ``AttributeError``.
    """
    if not isinstance(other, type(self)):
        return NotImplemented
    return self.__dict__ == other.__dict__


class SeriesSchema(SeriesSchemaBase):
"""Series validator."""
Expand Down Expand Up @@ -513,3 +520,6 @@ def validate(self, series: pd.Series) -> pd.Series:

assert super(SeriesSchema, self).__call__(series)
return series

def __eq__(self, other):
    """Compare this series schema with ``other`` attribute by attribute.

    :param other: object to compare against.
    :returns: ``True`` when ``other`` is an instance of this object's
        class holding equal instance attributes, ``False`` when the
        attributes differ, and ``NotImplemented`` for unrelated types so
        that comparing against ``None`` or a plain value evaluates to
        ``False`` instead of raising ``AttributeError``.
    """
    if not isinstance(other, type(self)):
        return NotImplemented
    return self.__dict__ == other.__dict__
19 changes: 19 additions & 0 deletions tests/test_checks.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Tests the way Columns are Checked"""

import copy
import pandas as pd
import pytest

Expand Down Expand Up @@ -265,3 +266,21 @@ def test_format_failure_case_exceptions():
for data in [1, "foobar", 1.0, {"key": "value"}, list(range(10))]:
with pytest.raises(TypeError):
check._format_failure_cases(data)


def test_check_equality_operators():
    """A Check equals its own deep copy and differs from an unrelated one."""
    grouped_check = Check(
        lambda g: g["foo"]["col1"].iat[0] == 1,
        groupby="col3",
    )
    unrelated_check = Check(lambda x: x.isna().sum() == 0)

    duplicate = copy.deepcopy(grouped_check)
    assert grouped_check == duplicate
    assert grouped_check != unrelated_check


def test_equality_operators_functional_equivalence():
    """Checks whose callables share an implementation compare equal.

    The two lambdas differ only in the name of their argument.
    """
    shared_options = {"groupby": "col3"}
    main_check = Check(
        lambda g: g["foo"]["col1"].iat[0] == 1, **shared_options)
    same_check = Check(
        lambda h: h["foo"]["col1"].iat[0] == 1, **shared_options)

    assert main_check == same_check
27 changes: 27 additions & 0 deletions tests/test_schema_components.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
"""Testing the components of the Schema objects."""

import copy
import pandas as pd
import pytest


from pandera import errors
from pandera import (
Column, DataFrameSchema, Index, MultiIndex, Check, DateTime, Float, Int,
Expand Down Expand Up @@ -103,3 +105,28 @@ def test_multi_index_index():
def test_column_dtype_property(pandas_dtype, expected):
"""Tests that the dtypes provided by Column match pandas dtypes"""
assert Column(pandas_dtype).dtype == expected

def test_schema_component_equality_operators():
    """Column, Index and MultiIndex each equal their deep copy and differ
    from an unrelated DataFrameSchema."""
    first_level = Index(
        Int, Check(lambda s: (s < 5) & (s >= 0)), name="index0")
    second_level = Index(
        String, Check(lambda s: s.isin(["foo", "bar"])), name="index1")

    components = (
        Column(Int, Check(lambda s: s >= 0)),
        Index(Int, [Check(lambda x: 1 <= x <= 11, element_wise=True)]),
        MultiIndex(indexes=[first_level, second_level]),
    )
    not_equal_schema = DataFrameSchema({
        "col1": Column(Int, Check(lambda s: s >= 0)),
    })

    # Same assertion pair, in the same order, for every component type.
    for component in components:
        assert component == copy.deepcopy(component)
        assert component != not_equal_schema
39 changes: 39 additions & 0 deletions tests/test_schemas.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
"""Testing creation and manipulation of DataFrameSchema objects."""

import copy
import numpy as np
import pandas as pd
import pytest


from pandera import (
Column, DataFrameSchema, Index, SeriesSchema, Bool, Category, Check,
DateTime, Float, Int, Object, String, Timedelta, errors)
from pandera.schemas import SeriesSchemaBase
from tests.test_dtypes import TESTABLE_DTYPES


Expand Down Expand Up @@ -447,3 +450,39 @@ def test_dataframe_schema_dtype_property():
def test_series_schema_dtype_property(pandas_dtype, expected):
"""Tests every type of allowed dtype."""
assert SeriesSchema(pandas_dtype).dtype == expected


def test_schema_equality_operators():
    """DataFrameSchema, SeriesSchema and SeriesSchemaBase support ``==``.

    Each schema must equal its deep copy and differ from an unrelated
    schema; a DataFrameSchema must also equal one that declares the same
    columns in a different order.
    """
    def series_options():
        # Build fresh keyword arguments (including a new Check) per schema.
        return {
            "checks": [Check(lambda s: s.str.startswith("foo"))],
            "nullable": False,
            "allow_duplicates": True,
            "name": "my_series",
        }

    df_schema = DataFrameSchema(
        {
            "col1": Column(Int, Check(lambda s: s >= 0)),
            "col2": Column(String, Check(lambda s: s >= 2)),
        },
        strict=True,
    )
    reordered_df_schema = DataFrameSchema(
        {
            "col2": Column(String, Check(lambda s: s >= 2)),
            "col1": Column(Int, Check(lambda s: s >= 0)),
        },
        strict=True,
    )
    series_schema = SeriesSchema(String, **series_options())
    series_schema_base = SeriesSchemaBase(String, **series_options())
    not_equal_schema = DataFrameSchema({"col1": Column(String)}, strict=False)

    assert df_schema == copy.deepcopy(df_schema)
    assert df_schema != not_equal_schema
    assert df_schema == reordered_df_schema

    for schema in (series_schema, series_schema_base):
        assert schema == copy.deepcopy(schema)
        assert schema != not_equal_schema

0 comments on commit b62fac4

Please sign in to comment.