-
-
Notifications
You must be signed in to change notification settings - Fork 299
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
fix test_schemas, test_checks unit tests
- Loading branch information
1 parent
68aae18
commit 0473280
Showing
35 changed files
with
999 additions
and
272 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
"""Register dask accessor for pandera schema metadata."""

from dask.dataframe.extensions import (
    register_dataframe_accessor,
    register_series_accessor,
)

from pandera._accessors.pandas_accessor import (
    PanderaDataFrameAccessor,
    PanderaSeriesAccessor,
)

# Importing this module attaches the shared pandas accessor classes to dask
# DataFrame/Series objects under the ``.pandera`` attribute namespace.
register_dataframe_accessor("pandera")(PanderaDataFrameAccessor)
register_series_accessor("pandera")(PanderaSeriesAccessor)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,98 @@ | ||
"""Custom accessor functionality for modin. | ||
Source code adapted from pyspark.pandas implementation: | ||
https://spark.apache.org/docs/3.2.0/api/python/reference/pyspark.pandas/api/pyspark.pandas.extensions.register_dataframe_accessor.html?highlight=register_dataframe_accessor#pyspark.pandas.extensions.register_dataframe_accessor | ||
""" | ||
|
||
import warnings | ||
|
||
from pandera._accessors.pandas_accessor import ( | ||
PanderaDataFrameAccessor, | ||
PanderaSeriesAccessor, | ||
) | ||
|
||
|
||
# pylint: disable=too-few-public-methods
class CachedAccessor:
    """Property-like descriptor that lazily builds and caches an accessor.

    :param name: Namespace that accessor's methods, properties, etc will be
        accessed under, e.g. "foo" for a dataframe accessor yields the
        accessor ``df.foo``
    :param accessor: Class with the extension methods. Its ``__init__`` is
        expected to take the wrapped ``Series``/``DataFrame``/``Index``
        instance as its only argument.
    """

    def __init__(self, name, accessor):
        self._name = name
        self._accessor = accessor

    def __get__(self, obj, cls):
        # Accessed on the class itself: expose the raw accessor class.
        if obj is None:  # pragma: no cover
            return self._accessor
        wrapped = self._accessor(obj)
        # Cache on the instance so the accessor is constructed only once;
        # object.__setattr__ sidesteps any custom __setattr__ on obj.
        object.__setattr__(obj, self._name, wrapped)
        return wrapped
|
||
|
||
def _register_accessor(name, cls):
    """
    Register a custom accessor on ``cls`` objects.

    :param name: Name under which the accessor should be registered. A warning
        is issued if this name conflicts with a preexisting attribute.
    :param cls: Class (e.g. modin ``DataFrame`` or ``Series``) on which the
        accessor is installed.
    :returns: A class decorator callable.
    """

    def decorator(accessor):
        if hasattr(cls, name):
            # BUG FIX: the middle fragment was previously a plain string, so
            # the literal text "{cls.__name__}" appeared in the warning
            # instead of the actual class name.
            msg = (
                f"registration of accessor {accessor} under name '{name}' for "
                f"type {cls.__name__} is overriding a preexisting attribute "
                "with the same name."
            )

            warnings.warn(
                msg,
                UserWarning,
                stacklevel=2,
            )
        # CachedAccessor defers construction of the accessor until first use.
        setattr(cls, name, CachedAccessor(name, accessor))
        return accessor

    return decorator
|
||
|
||
def register_dataframe_accessor(name):
    """
    Create a decorator that registers an accessor on modin DataFrames.

    :param name: name used when calling the accessor after its registered
    :returns: a class decorator callable.
    """
    # Imported lazily so modin is only required when actually used.
    # pylint: disable=import-outside-toplevel
    from modin.pandas import DataFrame as _DataFrame

    return _register_accessor(name, _DataFrame)
|
||
|
||
def register_series_accessor(name):
    """
    Create a decorator that registers an accessor on modin Series objects.

    :param name: name used when calling the accessor after its registered
    :returns: a callable class decorator
    """
    # Imported lazily so modin is only required when actually used.
    # pylint: disable=import-outside-toplevel
    from modin.pandas import Series as _Series

    return _register_accessor(name, _Series)
|
||
|
||
# Importing this module attaches the shared pandas accessor classes to modin
# DataFrame/Series objects under the ``.pandera`` attribute namespace.
register_dataframe_accessor("pandera")(PanderaDataFrameAccessor)
register_series_accessor("pandera")(PanderaSeriesAccessor)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
"""Register pandas accessor for pandera schema metadata.""" | ||
|
||
from typing import Optional, Union | ||
|
||
import pandas as pd | ||
|
||
from pandera.core.pandas.array import SeriesSchema | ||
from pandera.core.pandas.container import DataFrameSchema | ||
|
||
# Type alias: either flavor of schema the pandera accessors can hold.
Schemas = Union[DataFrameSchema, SeriesSchema]
|
||
|
||
class PanderaAccessor:
    """Base pandera accessor for pandas objects.

    Carries an optional schema alongside the wrapped pandas object so that
    validation metadata can travel with the data.
    """

    def __init__(self, pandas_obj):
        """Wrap ``pandas_obj`` with no schema attached initially."""
        self._pandas_obj = pandas_obj
        self._schema: Optional[Schemas] = None

    @property
    def schema(self) -> Optional[Schemas]:
        """Schema metadata currently attached, or ``None``."""
        return self._schema

    @staticmethod
    def check_schema_type(schema: Schemas):
        """Validate the schema's type; concrete subclasses must override."""
        raise NotImplementedError

    def add_schema(self, schema):
        """Attach ``schema`` (after type-checking) and return the pandas object."""
        self.check_schema_type(schema)
        self._schema = schema
        return self._pandas_obj
|
||
|
||
@pd.api.extensions.register_dataframe_accessor("pandera")
class PanderaDataFrameAccessor(PanderaAccessor):
    """Pandera accessor for pandas DataFrame."""

    @staticmethod
    def check_schema_type(schema):
        # Only DataFrameSchema instances may be attached to a DataFrame.
        if isinstance(schema, DataFrameSchema):
            return
        raise TypeError(
            f"schema arg must be a DataFrameSchema, found {type(schema)}"
        )
|
||
|
||
@pd.api.extensions.register_series_accessor("pandera")
class PanderaSeriesAccessor(PanderaAccessor):
    """Pandera accessor for pandas Series."""

    @staticmethod
    def check_schema_type(schema):
        # Only SeriesSchema instances may be attached to a Series.
        if isinstance(schema, SeriesSchema):
            return
        raise TypeError(
            f"schema arg must be a SeriesSchema, found {type(schema)}"
        )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
# pylint: skip-file
# NOTE: skip file since py=3.10 yields these errors:
# https://github.com/pandera-dev/pandera/runs/4998710717?check_suite_focus=true
"""Register pyspark accessor for pandera schema metadata."""

from pyspark.pandas.extensions import (
    register_dataframe_accessor,
    register_series_accessor,
)

from pandera._accessors.pandas_accessor import (
    PanderaDataFrameAccessor,
    PanderaSeriesAccessor,
)

# Importing this module attaches the shared pandas accessor classes to
# pyspark.pandas DataFrame/Series objects under the ``.pandera`` namespace.
register_dataframe_accessor("pandera")(PanderaDataFrameAccessor)
register_series_accessor("pandera")(PanderaSeriesAccessor)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,105 @@ | ||
"""Module for inferring dataframe/series schema.""" | ||
|
||
from typing import overload | ||
|
||
import pandas as pd | ||
|
||
from pandera._schema_statistics.pandas import ( | ||
infer_dataframe_statistics, | ||
infer_series_statistics, | ||
parse_check_statistics, | ||
) | ||
from pandera.core.pandas.array import SeriesSchema | ||
from pandera.core.pandas.components import Column, Index, MultiIndex | ||
from pandera.core.pandas.container import DataFrameSchema | ||
|
||
|
||
# Typing overloads only: a Series input yields a SeriesSchema and a DataFrame
# input yields a DataFrameSchema. These stubs exist for static type checkers
# and are never executed.
@overload
def infer_schema(
    pandas_obj: pd.Series,
) -> SeriesSchema:  # pragma: no cover
    ...


@overload
def infer_schema(  # type: ignore[misc]
    pandas_obj: pd.DataFrame,
) -> DataFrameSchema:  # pragma: no cover
    ...
||
|
||
def infer_schema(pandas_obj):
    """Infer schema for pandas DataFrame or Series object.

    :param pandas_obj: DataFrame or Series object to infer.
    :returns: DataFrameSchema or SeriesSchema
    :raises: TypeError if pandas_obj is not expected type.
    """
    # Dispatch on the concrete pandas type; guard clauses avoid an
    # if/elif/else pyramid.
    if isinstance(pandas_obj, pd.DataFrame):
        return infer_dataframe_schema(pandas_obj)
    if isinstance(pandas_obj, pd.Series):
        return infer_series_schema(pandas_obj)
    raise TypeError(
        "pandas_obj type not recognized. Expected a pandas DataFrame or "
        f"Series, found {type(pandas_obj)}"
    )
|
||
|
||
def _create_index(index_statistics):
    """Build an ``Index`` or ``MultiIndex`` component from inferred statistics.

    :param index_statistics: list of per-level property dicts with ``dtype``,
        ``checks``, ``nullable`` and ``name`` entries.
    :returns: a single ``Index`` for one level, otherwise a ``MultiIndex``.
    """
    levels = [
        Index(
            stats["dtype"],
            checks=parse_check_statistics(stats["checks"]),
            nullable=stats["nullable"],
            name=stats["name"],
        )
        for stats in index_statistics
    ]
    # Exactly one level collapses to a plain Index; any other count
    # (including zero) is wrapped in a MultiIndex, matching prior behavior.
    if len(levels) == 1:
        return levels[0]
    return MultiIndex(levels)
|
||
|
||
def infer_dataframe_schema(df: pd.DataFrame) -> DataFrameSchema:
    """Infer a DataFrameSchema from a pandas DataFrame.

    :param df: DataFrame object to infer.
    :returns: DataFrameSchema
    """
    statistics = infer_dataframe_statistics(df)
    columns = {}
    for colname, properties in statistics["columns"].items():
        columns[colname] = Column(
            properties["dtype"],
            checks=parse_check_statistics(properties["checks"]),
            nullable=properties["nullable"],
        )
    schema = DataFrameSchema(
        columns=columns,
        index=_create_index(statistics["index"]),
        coerce=True,
    )
    # Flag the schema as machine-inferred rather than user-authored.
    schema._is_inferred = True  # pylint: disable=protected-access
    return schema
|
||
|
||
def infer_series_schema(series) -> SeriesSchema:
    """Infer a SeriesSchema from a pandas Series.

    :param series: Series object to infer.
    :returns: SeriesSchema
    """
    stats = infer_series_statistics(series)
    schema = SeriesSchema(
        dtype=stats["dtype"],
        checks=parse_check_statistics(stats["checks"]),
        nullable=stats["nullable"],
        name=stats["name"],
        coerce=True,
    )
    # Flag the schema as machine-inferred rather than user-authored.
    schema._is_inferred = True  # pylint: disable=protected-access
    return schema
Oops, something went wrong.