Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for optional name validation of single-index #326

Merged
merged 2 commits into from
Nov 22, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 45 additions & 19 deletions pandera/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ class BaseConfig: # pylint:disable=R0903
name: Optional[str] = None #: name of schema
coerce: bool = False #: coerce types of all schema components
strict: bool = False #: make sure all specified columns are in dataframe
check_index_name: bool = False
multiindex_name: Optional[str] = None #: name of multiindex

#: coerce types of all MultiIndex components
Expand Down Expand Up @@ -145,7 +146,7 @@ def validate(
)

@classmethod
def _build_columns_index(
def _build_columns_index( # pylint:disable=too-many-locals
cls,
checks: Dict[str, List[Check]],
annotations: Dict[str, Any],
Expand All @@ -154,39 +155,59 @@ def _build_columns_index(
Dict[str, schema_components.Column],
Optional[Union[schema_components.Index, schema_components.MultiIndex]],
]:
annotations = {
field_name: (parse_annotation(raw_annotation), raw_annotation)
for field_name, raw_annotation in annotations.items()
}
index_count = sum(
annotation.origin is Index
for annotation, _ in annotations.values()
)

columns: Dict[str, schema_components.Column] = {}
indices: List[schema_components.Index] = []
for field_name, raw_annotation in annotations.items():
annotation_info = parse_annotation(raw_annotation)
for field_name, (annotation, raw_annotation) in annotations.items():

field = getattr(cls, field_name, None)
if field is not None and not isinstance(field, FieldInfo):
raise SchemaInitError(
f"'{field_name}' can only be assigned a 'Field', "
+ f"not a '{field.__class__}.'"
)
field: FieldInfo = getattr(cls, field_name, None)
_check_fieldinfo(field, field_name)

field_checks = checks.get(field_name, [])
if annotation_info.origin is Series:
check_name = getattr(field, "check_name", None)

if annotation.origin is Series:
col_constructor = (
field.to_column if field else schema_components.Column
)

if check_name is False:
raise SchemaInitError(
f"'check_name' is not supported for {field_name}."
)

columns[field_name] = col_constructor( # type: ignore
annotation_info.arg,
required=not annotation_info.optional,
annotation.arg,
required=not annotation.optional,
checks=field_checks,
name=field_name,
)
elif annotation_info.origin is Index:
if annotation_info.optional:
elif annotation.origin is Index:
if annotation.optional:
raise SchemaInitError(
f"Index '{field_name}' cannot be Optional."
)

if check_name is False or (
# default single index
check_name is None
and index_count == 1
):
field_name = None # type:ignore

index_constructor = (
field.to_index if field else schema_components.Index
)
index = index_constructor( # type: ignore
annotation_info.arg, checks=field_checks, name=field_name
annotation.arg, checks=field_checks, name=field_name
)
indices.append(index)
else:
Expand Down Expand Up @@ -293,15 +314,12 @@ def _not_routine(member: Any) -> bool:


def _build_schema_index(
indices: List[schema_components.Index],
**multiindex_kwargs: Any,
indices: List[schema_components.Index], **multiindex_kwargs: Any
) -> Optional[SchemaIndex]:
index: Optional[SchemaIndex] = None
if indices:
if len(indices) == 1:
index = indices[0]
# don't force name on single index
index._name = None # pylint:disable=W0212
else:
index = schema_components.MultiIndex(indices, **multiindex_kwargs)
return index
Expand All @@ -314,3 +332,11 @@ def _regex_filter(seq: Iterable, regexps: Iterable[str]) -> Set[str]:
pattern = re.compile(regex)
matched.update(filter(pattern.match, seq))
return matched


def _check_fieldinfo(field: FieldInfo, field_name: str) -> None:
    """Ensure a model attribute is either ``None`` or a ``FieldInfo``.

    :param field: value to check; ``None`` is accepted as-is.
    :param field_name: attribute name, used only in the error message.
    :raises SchemaInitError: if ``field`` is neither ``None`` nor a
        ``FieldInfo`` instance.
    """
    if field is None or isinstance(field, FieldInfo):
        return
    message = (
        f"'{field_name}' can only be assigned a 'Field', "
        f"not a '{field.__class__}.'"
    )
    raise SchemaInitError(message)
17 changes: 16 additions & 1 deletion pandera/model_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,14 @@ class FieldInfo:
*new in 0.5.0*
"""

__slots__ = ("checks", "nullable", "allow_duplicates", "coerce", "regex")
__slots__ = (
"checks",
"nullable",
"allow_duplicates",
"coerce",
"regex",
"check_name",
)

def __init__(
self,
Expand All @@ -53,12 +60,14 @@ def __init__(
allow_duplicates: bool = True,
coerce: bool = False,
regex: bool = False,
check_name: bool = None,
) -> None:
self.checks = _to_checklist(checks)
self.nullable = nullable
self.allow_duplicates = allow_duplicates
self.coerce = coerce
self.regex = regex
self.check_name = check_name

def _to_schema_component(
self,
Expand Down Expand Up @@ -131,13 +140,18 @@ def Field(
ignore_na: bool = True,
raise_warning: bool = False,
n_failure_cases: int = 10,
check_name: bool = None,
) -> Any:
"""Used to provide extra information about a field of a SchemaModel.

*new in 0.5.0*

Some arguments apply only to number dtypes and some apply only to ``str``.
See the :ref:`User Guide <schema_models>` for more.

:param check_name: Whether to check the name of the column/index during validation.
``None`` is the default behavior, which translates to ``True`` for columns and
multi-index, and to ``False`` for a single index.
"""
# pylint:disable=C0103,W0613,R0914
check_kwargs = {
Expand All @@ -163,6 +177,7 @@ def Field(
allow_duplicates=allow_duplicates,
coerce=coerce,
regex=regex,
check_name=check_name,
)


Expand Down
2 changes: 1 addition & 1 deletion pandera/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -943,7 +943,7 @@ def validate(
if not inplace:
series = series.copy()

if series.name != self._name:
if self.name is not None and series.name != self._name:
msg = "Expected %s to have name '%s', found '%s'" % (
type(self),
self._name,
Expand Down
68 changes: 68 additions & 0 deletions tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,74 @@ class Schema(pa.SchemaModel):
assert expected == Schema.to_schema()


def test_column_check_name():
    """Disabling the name check on a column must fail schema creation."""

    class Schema(pa.SchemaModel):
        a: Series[int] = pa.Field(check_name=False)

    expected_error = pa.errors.SchemaInitError
    with pytest.raises(expected_error):
        Schema.to_schema()


def test_single_index_check_name():
    """A lone index only has its name checked with ``check_name=True``."""
    animal_index = pd.Index(["cat", "dog"], name="animal")
    df = pd.DataFrame(index=animal_index)

    class DefaultSchema(pa.SchemaModel):
        a: Index[str]

    class DefaultFieldSchema(pa.SchemaModel):
        a: Index[str] = pa.Field(check_name=None)

    class NotCheckNameSchema(pa.SchemaModel):
        a: Index[str] = pa.Field(check_name=False)

    class SchemaNamedIndex(pa.SchemaModel):
        a: Index[str] = pa.Field(check_name=True)

    # The index is named "animal" while the field is "a"; these three
    # variants are all expected to validate regardless.
    for lenient in (DefaultSchema, DefaultFieldSchema, NotCheckNameSchema):
        assert isinstance(lenient.validate(df), pd.DataFrame)

    # Opting in turns the same mismatch into a validation error.
    with pytest.raises(
        pa.errors.SchemaError, match="name 'a', found 'animal'"
    ):
        SchemaNamedIndex.validate(df)


def test_multiindex_check_name():
    """Test MultiIndex level names with and without ``check_name``."""
    levels = [["foo", "bar"], [0, 1]]
    named_df = pd.DataFrame(
        index=pd.MultiIndex.from_arrays(levels, names=["a", "b"])
    )
    unnamed_df = pd.DataFrame(index=pd.MultiIndex.from_arrays(levels))

    class DefaultSchema(pa.SchemaModel):
        a: Index[str]
        b: Index[int]

    class CheckNameSchema(pa.SchemaModel):
        a: Index[str] = pa.Field(check_name=True)
        b: Index[int] = pa.Field(check_name=True)

    class NotCheckNameSchema(pa.SchemaModel):
        a: Index[str] = pa.Field(check_name=False)
        b: Index[int] = pa.Field(check_name=False)

    # Level names match the field names, so both variants validate.
    for schema in (DefaultSchema, CheckNameSchema):
        assert isinstance(schema.validate(named_df), pd.DataFrame)

    # With the check disabled, unnamed levels validate too.
    assert isinstance(NotCheckNameSchema.validate(unnamed_df), pd.DataFrame)


def test_check_validate_method():
"""Test validate method on valid data."""

Expand Down