From 74be58cf0f7cf97c1ed1b7c0f5a6491695f8be95 Mon Sep 17 00:00:00 2001
From: Neeraj Malhotra <52220398+NeerajMalhotra-QB@users.noreply.github.com>
Date: Fri, 12 May 2023 20:20:50 -0700
Subject: [PATCH] Support for Native PySpark (#1185)

* init
* init
* init structure
* disable imports
* adding structure for pyspark
* setting dependency
* update class import
* fixing datatype cls
* adding dtypes for pyspark
* keep only bool type
* remove pydantic schema
* register pyspark data types
* add check method
* updating equivalents to native types
* update column schema
* refactor array to column
* rename array to column_schema
* remove pandas imports
* remove index and multiindex functionality
* adding pydantic schema class
* adding model components
* add model config
* define pyspark BaseConfig class
* removing index and multi-indexes
* remove modify schema
* Pyspark backend components, base, container, accessor, test file for accessor
* Pyspark backend components, base, container, accessor, test file for accessor
* Pyspark backend components, base, container, accessor, test file for accessor
* Pyspark backend components, base, container, accessor, test file for accessor
* add pyspark model components and types
* remove hypothesis
* remove synthesis and hypothesis
* Pyspark backend components, base, container, accessor, test file for accessor
* test for pyspark dataframeschema class
* test schema with alias types
* ensure dataframes are treated as table types
* update container for pyspark dataframe
* adding negative test flow
* removing series and index on pyspark dataframes
* remove series
* revert series from pyspark.pandas
* adding checks for pyspark
* registering pysparkCheckBackend
* cleaning base
* Fixing the broken type cast check, validation of schema fix.
* define spark level schema
* fixing check flow
* setting apply fn
* add sub sample functionality
* adjusting test case against common attributes
* need apply for column level check
* adding builtin checks for pyspark
* adding checks for pyspark df
* getting check registered
* fixing a bug in error handling for schema check
* check_name validation fixed
* implementing dtype checks for pyspark
* updating error msg
* fixing dtype reason_code
* updating builtin checks for pyspark
* registration
* Implementation of checks import and spark columns information check
* enhancing __call__, checks classes and builtin_checks
* delete junk files
* Changes to fix the implementation of checks. Changed the apply function to send a list with the dataframe and column name; the builtin function registers functions with lists that include the dataframe
* extending pyspark checks
* Fixed builtin check bug and added test for supported builtin checks for pyspark
* add todos
* by default validate all checks
* fixing issue with sqlctx
* add dtypes pytests
* setting up schema
* add negative and positive tests
* add fixtures and refactor tests
* generalize spark_df func
* refactor to use conftest
* use conftest
* add support for decimal dtype and fixing other types
* Added new Datatypes support for pyspark, test cases for dtypes pyspark, created test file for error
* refactor ArraySchema
* rename array to column.py
* 1) Changes in test cases to look for summarised error raise instead of fast fail, since default behaviour is changed to summarised. 2) Added functionality to accept and check the precision and scale in Decimal Datatypes.
* add neg test
* add custom ErrorHandler
* Added functionality to DayTimeIntervalType datatype to accept parameters
* Added functionality to DayTimeIntervalType datatype to accept parameters
* return summarized error report
* replace dataframe with dict for return obj
* Changed checks input datatype to custom named tuple from the existing list. Also started changing the pyspark checks to include more datatypes
* refactor
* introduce error categories
* rename error categories
* fixing bug in schema.dtype.check
* fixing error category to be dynamic
* Added checks for each datatype in test cases. Reduced code redundancy in the test file. Refactored the name of the custom datatype object for checks.
* error_handler pass through
* add ErrorHandler to column api
* removed SchemaErrors since we now aggregate in errorHandler
* fixing dict keys
* Added decorator to raise TypeError in case of unexpected input type for the check function.
* replace validator with report_errors
* cleaning debugs
* Support DataModels and Field
* Added decorator to raise TypeError in case of unexpected input type for the check function. Merged with Develop
* Fix to run using the class schema type
* use alias types
* clean up
* add new typing for pyspark.sql
* Added decorator to raise TypeError in case of unexpected input type for the check function. Merged with Develop
* Added changes to support raising an error for use of a datatype not supported by the check, and support for map and array types.
* support bare dtypes for DataFrameModel
* remove resolved TODOs and breakpoints
* change to bare types
* use spark types instead of bare types
* using SchemaErrorReason instead of hardcode in container
* fixing an issue with error reason codes
* minor fix
* fixing checks and errors in pyspark
* Changes include the following: 1) Updated dtypes test functionality to make it more readable 2) Changes in accessor tests to support the new functionality 3) Changes in engine class to conform to the check class everywhere else
* enhancing dataframeschema and model classes
* Changes to remove the pandas dependency
* Refactoring of the checks test functions
* Fixing the breaking test case
* Isort and Black formatting
* Container test function failure
* Isort and black linting
* Changes to remove the pandas dependency
* Refactoring of the checks test functions
* Isort and black linting
* Added changes to refactor the checks class. Fixes to some test case failures.
* Removing breakpoint
* fixing raise error
* adding metadata dict
* Removing the reference of pandas from docstrings
* Removing redundant code block in utils
* Changes to return dataframe with errors property
* add accessor for errorHandler
* support errors access on pyspark.sql
* updating pyspark error tcs
* fixing model test cases
* adjusting errors to use pandera.errors
* use accessor instead of dict
* revert to develop
* Removal of imports which are not needed and improved test case.
* setting independent pyspark import
* pyspark imports
* revert comments
* store and retrieve metadata at schema levels
* adding metadata support
* Added changes to support parameter based run.
1) Added parameters.yaml file to hold the configurations 2) Added code in utility to read the config 3) Updated the test cases to support the parameter based run 4) Moved pyspark decorators to a new file decorators.py in backend 5) Type fix in get_matadata property in container.py file * Changing the default value in config * change to consistent interface * cleaning api/pyspark * backend and tests * adding setter on errors accessors for pyspark * reformatting error dict * doc * run black linter Signed-off-by: Niels Bantilan * fix lint Signed-off-by: Niels Bantilan * update pylintrc Signed-off-by: Niels Bantilan --------- Signed-off-by: Niels Bantilan Co-authored-by: jaskaransinghsidana Co-authored-by: jaskaransinghsidana <112083212+jaskaransinghsidana@users.noreply.github.com> Co-authored-by: Niels Bantilan --- .pre-commit-config.yaml | 2 +- .pylintrc | 6 +- conf/pyspark/parameters.yaml | 3 + mypy.ini | 8 +- pandera/__init__.py | 2 +- pandera/accessors/pyspark_sql_accessor.py | 136 ++ pandera/api/base/checks.py | 3 +- pandera/api/base/model.py | 1 - pandera/api/base/model_components.py | 6 + pandera/api/base/schema.py | 19 +- pandera/api/checks.py | 3 +- pandera/api/extensions.py | 9 +- pandera/api/hypotheses.py | 1 - pandera/api/pandas/array.py | 6 +- pandera/api/pandas/components.py | 2 +- pandera/api/pandas/container.py | 1 - pandera/api/pandas/model.py | 2 +- pandera/api/pandas/types.py | 11 +- pandera/api/pyspark/__init__.py | 3 + pandera/api/pyspark/column_schema.py | 193 +++ pandera/api/pyspark/components.py | 174 ++ pandera/api/pyspark/container.py | 605 +++++++ pandera/api/pyspark/error_handler.py | 98 ++ pandera/api/pyspark/model.py | 544 +++++++ pandera/api/pyspark/model_components.py | 303 ++++ pandera/api/pyspark/model_config.py | 65 + pandera/api/pyspark/types.py | 100 ++ pandera/backends/base/builtin_checks.py | 1 - pandera/backends/pandas/__init__.py | 2 - pandera/backends/pandas/base.py | 4 +- pandera/backends/pandas/builtin_checks.py | 4 +- pandera/backends/pandas/builtin_hypotheses.py | 3 +- pandera/backends/pandas/checks.py | 8 +- pandera/backends/pandas/components.py | 2 +- pandera/backends/pandas/container.py | 4 +- pandera/backends/pandas/hypotheses.py | 5 +- pandera/backends/pyspark/__init__.py | 19 + pandera/backends/pyspark/base.py | 120 ++ pandera/backends/pyspark/builtin_checks.py | 326 ++++ pandera/backends/pyspark/checks.py | 117 ++ pandera/backends/pyspark/column.py | 241 +++ pandera/backends/pyspark/components.py | 163 ++ pandera/backends/pyspark/container.py | 546 +++++++ pandera/backends/pyspark/decorators.py | 96 ++ pandera/backends/pyspark/error_formatters.py | 26 + pandera/backends/pyspark/utils.py | 60 + pandera/decorators.py | 3 +- pandera/dtypes.py | 13 + pandera/engines/engine.py | 4 + pandera/engines/numpy_engine.py | 3 +- pandera/engines/pandas_engine.py | 3 +- pandera/engines/pyspark_engine.py | 568 +++++++ pandera/engines/type_aliases.py | 2 + pandera/engines/utils.py | 4 +- pandera/errors.py | 5 + pandera/extensions.py | 2 +- pandera/io/__init__.py | 24 +- pandera/io/pandas_io.py | 5 +- pandera/mypy.py | 6 +- pandera/pyspark.py | 108 ++ pandera/schema_inference/pandas.py | 6 +- pandera/schema_statistics/__init__.py | 8 +- pandera/strategies/base_strategies.py | 1 - pandera/strategies/pandas_strategies.py | 1 - pandera/typing/__init__.py | 29 +- pandera/typing/common.py | 49 +- pandera/typing/dask.py | 3 +- pandera/typing/geopandas.py | 1 - pandera/typing/modin.py | 3 +- pandera/typing/pandas.py | 2 +- pandera/typing/pyspark.py | 10 +- 
pandera/typing/pyspark_sql.py | 40 + tests/modin/__init__.py | 0 tests/pyspark/__init__.py | 0 tests/pyspark/conftest.py | 155 ++ tests/pyspark/test_pyspark_accessor.py | 65 + tests/pyspark/test_pyspark_check.py | 1395 +++++++++++++++++ tests/pyspark/test_pyspark_container.py | 116 ++ tests/pyspark/test_pyspark_dtypes.py | 388 +++++ tests/pyspark/test_pyspark_engine.py | 32 + tests/pyspark/test_pyspark_error.py | 153 ++ tests/pyspark/test_pyspark_model.py | 298 ++++ ...k.py => test_schemas_on_pyspark_pandas.py} | 2 +- 83 files changed, 7455 insertions(+), 105 deletions(-) create mode 100644 conf/pyspark/parameters.yaml create mode 100644 pandera/accessors/pyspark_sql_accessor.py create mode 100644 pandera/api/pyspark/__init__.py create mode 100644 pandera/api/pyspark/column_schema.py create mode 100644 pandera/api/pyspark/components.py create mode 100644 pandera/api/pyspark/container.py create mode 100644 pandera/api/pyspark/error_handler.py create mode 100644 pandera/api/pyspark/model.py create mode 100644 pandera/api/pyspark/model_components.py create mode 100644 pandera/api/pyspark/model_config.py create mode 100644 pandera/api/pyspark/types.py create mode 100644 pandera/backends/pyspark/__init__.py create mode 100644 pandera/backends/pyspark/base.py create mode 100644 pandera/backends/pyspark/builtin_checks.py create mode 100644 pandera/backends/pyspark/checks.py create mode 100644 pandera/backends/pyspark/column.py create mode 100644 pandera/backends/pyspark/components.py create mode 100644 pandera/backends/pyspark/container.py create mode 100644 pandera/backends/pyspark/decorators.py create mode 100644 pandera/backends/pyspark/error_formatters.py create mode 100644 pandera/backends/pyspark/utils.py create mode 100644 pandera/engines/pyspark_engine.py create mode 100644 pandera/pyspark.py create mode 100644 pandera/typing/pyspark_sql.py create mode 100644 tests/modin/__init__.py create mode 100644 tests/pyspark/__init__.py create mode 100644 tests/pyspark/conftest.py create mode 100644 tests/pyspark/test_pyspark_accessor.py create mode 100644 tests/pyspark/test_pyspark_check.py create mode 100644 tests/pyspark/test_pyspark_container.py create mode 100644 tests/pyspark/test_pyspark_dtypes.py create mode 100644 tests/pyspark/test_pyspark_engine.py create mode 100644 tests/pyspark/test_pyspark_error.py create mode 100644 tests/pyspark/test_pyspark_model.py rename tests/pyspark/{test_schemas_on_pyspark.py => test_schemas_on_pyspark_pandas.py} (99%) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 580cacfb6..45ce90e35 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -13,7 +13,7 @@ repos: - id: check-yaml description: Attempts to load all yaml files to verify syntax - id: debug-statements - description: Check for debugger imports and py37+ breakpoint() calls in python source + description: Check for debugger imports and py37+ calls in python source - id: end-of-file-fixer description: Makes sure files end in a newline and only a newline - id: trailing-whitespace diff --git a/.pylintrc b/.pylintrc index 6bed87754..6ba161540 100644 --- a/.pylintrc +++ b/.pylintrc @@ -1,5 +1,6 @@ [BASIC] -ignore=mypy.py,noxfile.py +ignore=mypy.py,noxfile.py,pandera/accessors/pyspark_sql_accessor.py,pandera/engines/pyspark_engine.py,pandera/pyspark.py,pandera/typing/pyspark_sql.py, +ignore-patterns=pandera/api/pyspark/*,tests/pyspark/* good-names= T, F, @@ -45,4 +46,5 @@ disable= function-redefined, arguments-differ, unnecessary-dunder-call, - use-dict-literal + 
use-dict-literal, + invalid-name diff --git a/conf/pyspark/parameters.yaml b/conf/pyspark/parameters.yaml new file mode 100644 index 000000000..9bf14b83b --- /dev/null +++ b/conf/pyspark/parameters.yaml @@ -0,0 +1,3 @@ +# Params are case sensitive use only upper case +VALIDATION: ENABLE # Supported Value [ENABLE/DISABLE] +DEPTH: SCHEMA_AND_DATA #[Supported values: SCHEMA_ONLY, DATA_ONLY, SCHEMA_AND_DATA] \ No newline at end of file diff --git a/mypy.ini b/mypy.ini index a346783ab..ebec0d32e 100644 --- a/mypy.ini +++ b/mypy.ini @@ -5,4 +5,10 @@ allow_redefinition = True warn_return_any = False warn_unused_configs = True show_error_codes = True -exclude = tests/mypy/modules +exclude=(?x)( + ^tests/mypy/modules + | ^pandera/engines/pyspark_engine + | ^pandera/api/pyspark + | ^pandera/backends/pyspark + | ^tests/pyspark + ) diff --git a/pandera/__init__.py b/pandera/__init__.py index aefbb6241..30194b9a4 100644 --- a/pandera/__init__.py +++ b/pandera/__init__.py @@ -1,6 +1,7 @@ """A flexible and expressive pandas validation library.""" import platform +import pandera.backends from pandera import errors, external_config, typing from pandera.accessors import pandas_accessor from pandera.api import extensions @@ -85,7 +86,6 @@ except ImportError: pass - try: import modin.pandas diff --git a/pandera/accessors/pyspark_sql_accessor.py b/pandera/accessors/pyspark_sql_accessor.py new file mode 100644 index 000000000..954fef4f6 --- /dev/null +++ b/pandera/accessors/pyspark_sql_accessor.py @@ -0,0 +1,136 @@ +"""Custom accessor functionality for PySpark.Sql. +""" + +import warnings +from functools import wraps +from typing import Optional, Union + + +from pandera.api.pyspark.container import DataFrameSchema +from pandera.api.pyspark.error_handler import ErrorHandler + +"""Register pyspark accessor for pandera schema metadata.""" + + +Schemas = Union[DataFrameSchema] +Errors = Union[ErrorHandler] + + +# Todo Refactor to create a seperate module for panderaAccessor +class PanderaAccessor: + """Pandera accessor for pyspark object.""" + + def __init__(self, pyspark_obj): + """Initialize the pandera accessor.""" + self._pyspark_obj = pyspark_obj + self._schema: Optional[Schemas] = None + self._errors: Optional[Errors] = None + + @staticmethod + def check_schema_type(schema: Schemas): + """Abstract method for checking the schema type.""" + raise NotImplementedError + + def add_schema(self, schema): + """Add a schema to the pyspark object.""" + self.check_schema_type(schema) + self._schema = schema + return self._pyspark_obj + + @property + def schema(self) -> Optional[Schemas]: + """Access schema metadata.""" + return self._schema + + @property + def errors(self) -> Optional[Errors]: + """Access errors data.""" + return self._errors + + @errors.setter + def errors(self, value: dict): + """Set errors data.""" + self._errors = value + + +class CachedAccessor: + """ + Custom property-like object. + + A descriptor for caching accessors: + + :param name: Namespace that accessor's methods, properties, etc will be + accessed under, e.g. "foo" for a dataframe accessor yields the accessor + ``df.foo`` + :param cls: Class with the extension methods. + + For accessor, the class's __init__ method assumes that you are registering + an accessor for one of ``Series``, ``DataFrame``, or ``Index``. 
+ """ + + def __init__(self, name, accessor): + self._name = name + self._accessor = accessor + + def __get__(self, obj, cls): + if obj is None: # pragma: no cover + return self._accessor + accessor_obj = self._accessor(obj) + object.__setattr__(obj, self._name, accessor_obj) + return accessor_obj + + +def _register_accessor(name, cls): + """ + Register a custom accessor on {class} objects. + + :param name: Name under which the accessor should be registered. A warning + is issued if this name conflicts with a preexisting attribute. + :returns: A class decorator callable. + """ + + def decorator(accessor): + if hasattr(cls, name): + msg = ( + f"registration of accessor {accessor} under name '{name}' for " + "type {cls.__name__} is overriding a preexisting attribute " + "with the same name." + ) + + warnings.warn( + msg, + UserWarning, + stacklevel=2, + ) + setattr(cls, name, CachedAccessor(name, accessor)) + return accessor + + return decorator + + +def register_dataframe_accessor(name): + """ + Register a custom accessor with a DataFrame + + :param name: name used when calling the accessor after its registered + :returns: a class decorator callable. + """ + # pylint: disable=import-outside-toplevel + from pyspark.sql import DataFrame + + return _register_accessor(name, DataFrame) + + +class PanderaDataFrameAccessor(PanderaAccessor): + """Pandera accessor for pyspark DataFrame.""" + + @staticmethod + def check_schema_type(schema): + if not isinstance(schema, DataFrameSchema): + raise TypeError( + f"schema arg must be a DataFrameSchema, found {type(schema)}" + ) + + +register_dataframe_accessor("pandera")(PanderaDataFrameAccessor) +# register_series_accessor("pandera")(PanderaSeriesAccessor) diff --git a/pandera/api/base/checks.py b/pandera/api/base/checks.py index 9c584e4b1..bd9ca8486 100644 --- a/pandera/api/base/checks.py +++ b/pandera/api/base/checks.py @@ -15,6 +15,7 @@ Union, no_type_check, ) + import pandas as pd from multimethod import multidispatch as _multidispatch @@ -173,7 +174,6 @@ def from_builtin_check_name( # by the check object if statistics is None: statistics = check_kwargs - return cls( cls.get_builtin_check_fn(name), statistics=statistics, @@ -188,6 +188,7 @@ def register_backend(cls, type_: Type, backend: Type[BaseCheckBackend]): @classmethod def get_backend(cls, check_obj: Any) -> Type[BaseCheckBackend]: """Get the backend associated with the type of ``check_obj`` .""" + check_obj_cls = type(check_obj) classes = inspect.getmro(check_obj_cls) for _class in classes: diff --git a/pandera/api/base/model.py b/pandera/api/base/model.py index 0859df59e..4b60feefc 100644 --- a/pandera/api/base/model.py +++ b/pandera/api/base/model.py @@ -18,7 +18,6 @@ from pandera.api.checks import Check from pandera.typing import AnnotationInfo - TBaseModel = TypeVar("TBaseModel", bound="BaseModel") diff --git a/pandera/api/base/model_components.py b/pandera/api/base/model_components.py index 49472e876..df63f0f8b 100644 --- a/pandera/api/base/model_components.py +++ b/pandera/api/base/model_components.py @@ -43,6 +43,7 @@ class BaseFieldInfo: "title", "description", "default", + "metadata", ) def __init__( @@ -58,6 +59,7 @@ def __init__( title: Optional[str] = None, description: Optional[str] = None, default: Optional[Any] = None, + metadata: Optional[dict] = None, ) -> None: self.checks = to_checklist(checks) self.nullable = nullable @@ -71,6 +73,7 @@ def __init__( self.title = title self.description = description self.default = default + self.metadata = metadata @property def name(self) -> 
str: @@ -107,6 +110,9 @@ def __ne__(self, other): def __set__(self, instance: Any, value: Any) -> None: # pragma: no cover raise AttributeError(f"Can't set the {self.original_name} field.") + def __get_metadata__(self): + return self.metadata + class BaseCheckInfo: # pylint:disable=too-few-public-methods """Captures extra information about a Check.""" diff --git a/pandera/api/base/schema.py b/pandera/api/base/schema.py index b5db52310..f843e1c62 100644 --- a/pandera/api/base/schema.py +++ b/pandera/api/base/schema.py @@ -8,7 +8,7 @@ import inspect from abc import ABC from functools import wraps -from typing import Any, Dict, Tuple, Type, Union +from typing import Any, Dict, Tuple, Type, Optional, Union from pandera.backends.base import BaseSchemaBackend from pandera.errors import BackendNotFoundError @@ -32,6 +32,7 @@ def __init__( name=None, title=None, description=None, + metadata=None, ): """Abstract base schema initializer.""" self.dtype = dtype @@ -40,6 +41,7 @@ def __init__( self.name = name self.title = title self.description = description + self.metadata = metadata def validate( self, @@ -69,9 +71,20 @@ def register_backend(cls, type_: Type, backend: Type[BaseSchemaBackend]): cls.BACKEND_REGISTRY[(cls, type_)] = backend @classmethod - def get_backend(cls, check_obj: Any) -> BaseSchemaBackend: + def get_backend( + cls, + check_obj: Optional[Any] = None, + check_type: Optional[Type] = None, + ) -> BaseSchemaBackend: """Get the backend associated with the type of ``check_obj`` .""" - check_obj_cls = type(check_obj) + if check_obj is not None: + check_obj_cls = type(check_obj) + elif check_type is not None: + check_obj_cls = check_type + else: + raise ValueError( + "Must pass in one of `check_obj` or `check_type`." + ) classes = inspect.getmro(check_obj_cls) for _class in classes: try: diff --git a/pandera/api/checks.py b/pandera/api/checks.py index a8b056ff8..5ac463725 100644 --- a/pandera/api/checks.py +++ b/pandera/api/checks.py @@ -18,7 +18,6 @@ from pandera.api.base.checks import BaseCheck, CheckResult from pandera.strategies import SearchStrategy - T = TypeVar("T") @@ -165,7 +164,6 @@ def __init__( """ super().__init__(name=name, error=error) - if element_wise and groupby is not None: raise errors.SchemaInitError( "Cannot use groupby when element_wise=True." @@ -480,6 +478,7 @@ def str_startswith(cls, string: str, **kwargs) -> "Check": :param string: String all values should start with :param kwargs: key-word arguments passed into the `Check` initializer. """ + return cls.from_builtin_check_name( "str_startswith", kwargs, diff --git a/pandera/api/extensions.py b/pandera/api/extensions.py index b6b278c89..18cc2ff38 100644 --- a/pandera/api/extensions.py +++ b/pandera/api/extensions.py @@ -136,7 +136,10 @@ def register_check_method( check_fn=None, *, statistics: Optional[List[str]] = None, - supported_types: Union[type, Tuple, List] = (pd.DataFrame, pd.Series), + supported_types: Union[type, Tuple, List] = ( + pd.DataFrame, + pd.Series, + ), check_type: Union[CheckType, str] = "vectorized", strategy=None, ): @@ -156,8 +159,8 @@ def register_check_method( which serve as the statistics needed to serialize/de-serialize the check and generate data if a ``strategy`` function is provided. :param supported_types: the pandas type(s) supported by the check function. - Valid values are ``pd.DataFrame``, ``pd.Series``, or a list/tuple of - ``(pa.DataFrame, pa.Series)`` if both types are supported. 
+ Valid values are ``pd.DataFrame``, ``pd.Series``, ``ps.DataFrame``, or a list/tuple of + ``(pa.DataFrame, pa.Series, ps.DataFrame)`` if both types are supported. :param check_type: the expected input of the check function. Valid values are :class:`~pandera.extensions.CheckType` enums or ``{"vectorized", "element_wise", "groupby"}``. The input signature of diff --git a/pandera/api/hypotheses.py b/pandera/api/hypotheses.py index 0fb794245..2c1b46a5a 100644 --- a/pandera/api/hypotheses.py +++ b/pandera/api/hypotheses.py @@ -6,7 +6,6 @@ from pandera.api.checks import Check from pandera.strategies import SearchStrategy - DEFAULT_ALPHA = 0.01 diff --git a/pandera/api/pandas/array.py b/pandera/api/pandas/array.py index e10e04afe..f4c714639 100644 --- a/pandera/api/pandas/array.py +++ b/pandera/api/pandas/array.py @@ -11,11 +11,7 @@ from pandera.api.base.schema import BaseSchema, inferred_schema_guard from pandera.api.checks import Check from pandera.api.hypotheses import Hypothesis -from pandera.api.pandas.types import ( - CheckList, - PandasDtypeInputTypes, - is_field, -) +from pandera.api.pandas.types import CheckList, PandasDtypeInputTypes, is_field from pandera.dtypes import DataType, UniqueSettings from pandera.engines import pandas_engine diff --git a/pandera/api/pandas/components.py b/pandera/api/pandas/components.py index 1012f1899..b4ca9a358 100644 --- a/pandera/api/pandas/components.py +++ b/pandera/api/pandas/components.py @@ -181,7 +181,7 @@ def get_regex_columns( from pandera.backends.pandas.components import ColumnBackend return cast( - ColumnBackend, self.get_backend(pd.DataFrame()) + ColumnBackend, self.get_backend(check_type=pd.DataFrame) ).get_regex_columns(self, columns) def __eq__(self, other): diff --git a/pandera/api/pandas/container.py b/pandera/api/pandas/container.py index 2bbb4e47f..a17946499 100644 --- a/pandera/api/pandas/container.py +++ b/pandera/api/pandas/container.py @@ -358,7 +358,6 @@ def _validate( lazy: bool = False, inplace: bool = False, ) -> pd.DataFrame: - if self._is_inferred: warnings.warn( f"This {type(self)} is an inferred schema that hasn't been " diff --git a/pandera/api/pandas/model.py b/pandera/api/pandas/model.py index 11d2d985c..55ec2dcbc 100644 --- a/pandera/api/pandas/model.py +++ b/pandera/api/pandas/model.py @@ -23,7 +23,6 @@ import pandas as pd -from pandera.strategies import pandas_strategies as st from pandera.api.base.model import BaseModel from pandera.api.checks import Check from pandera.api.pandas.components import Column, Index, MultiIndex @@ -38,6 +37,7 @@ ) from pandera.api.pandas.model_config import BaseConfig from pandera.errors import SchemaInitError +from pandera.strategies import pandas_strategies as st from pandera.typing import INDEX_TYPES, SERIES_TYPES, AnnotationInfo from pandera.typing.common import DataFrameBase diff --git a/pandera/api/pandas/types.py b/pandera/api/pandas/types.py index dbfa390ec..258fa7316 100644 --- a/pandera/api/pandas/types.py +++ b/pandera/api/pandas/types.py @@ -3,11 +3,6 @@ from functools import lru_cache from typing import List, NamedTuple, Tuple, Type, Union -try: - from typing import Literal -except ImportError: - from typing_extensions import Literal # type: ignore [misc] - import numpy as np import pandas as pd @@ -15,6 +10,12 @@ from pandera.api.hypotheses import Hypothesis from pandera.dtypes import DataType +try: + from typing import Literal +except ImportError: + from typing_extensions import Literal # type: ignore [misc] + + CheckList = Union[Check, List[Union[Check, Hypothesis]]] 
PandasDtypeInputTypes = Union[ diff --git a/pandera/api/pyspark/__init__.py b/pandera/api/pyspark/__init__.py new file mode 100644 index 000000000..af113c389 --- /dev/null +++ b/pandera/api/pyspark/__init__.py @@ -0,0 +1,3 @@ +"""PySpark native core.""" +from pandera.api.pyspark.components import Column +from pandera.api.pyspark.container import DataFrameSchema diff --git a/pandera/api/pyspark/column_schema.py b/pandera/api/pyspark/column_schema.py new file mode 100644 index 000000000..de3368e39 --- /dev/null +++ b/pandera/api/pyspark/column_schema.py @@ -0,0 +1,193 @@ +"""Core pyspark array specification.""" + +import copy +from typing import Any, List, Optional, TypeVar, cast + +import pyspark.sql as ps + +from pandera.api.base.schema import BaseSchema, inferred_schema_guard +from pandera.api.checks import Check +from pandera.api.pyspark.error_handler import ErrorHandler +from pandera.api.pyspark.types import CheckList, PySparkDtypeInputTypes +from pandera.dtypes import DataType +from pandera.engines import pyspark_engine + +TColumnSchemaBase = TypeVar("TColumnSchemaBase", bound="ColumnSchema") + + +class ColumnSchema(BaseSchema): + """Base column validator object.""" + + def __init__( + self, + dtype: Optional[PySparkDtypeInputTypes] = None, + checks: Optional[CheckList] = None, + nullable: bool = False, + coerce: bool = False, + name: Any = None, + title: Optional[str] = None, + description: Optional[str] = None, + metadata: Optional[dict] = None, + ) -> None: + """Initialize array schema. + + :param dtype: datatype of the column. + :param checks: If element_wise is True, then callable signature should + be: + + ``Callable[Any, bool]`` where the ``Any`` input is a scalar element + in the column. Otherwise, the input is assumed to be a + dataframe object. + :param nullable: Whether or not column can contain null values. + :param coerce: If True, when schema.validate is called the column will + be coerced into the specified dtype. This has no effect on columns + where ``dtype=None``. + :param name: column name in dataframe to validate. + :param title: A human-readable label for the series. + :param description: An arbitrary textual description of the series. + :param metadata: An optional key-value data. + :type nullable: bool + """ + + super().__init__( + dtype=dtype, + checks=checks, + coerce=coerce, + name=name, + title=title, + description=description, + metadata=metadata, + ) + + if checks is None: + checks = [] + if isinstance(checks, Check): + checks = [checks] + self.checks = checks + self.nullable = nullable + self.title = title + self.description = description + self.metadata = metadata + self._dtype = None + + @property + def dtype(self) -> DataType: + """Get the pyspark dtype""" + return self._dtype # type: ignore + + @dtype.setter + def dtype(self, value: Optional[PySparkDtypeInputTypes]) -> None: + """Set the pyspark dtype""" + # this is a pylint false positive + # pylint: disable=no-value-for-parameter + self._dtype = pyspark_engine.Engine.dtype(value) if value else None + + def validate( + self, + check_obj, + head: Optional[int] = None, + tail: Optional[int] = None, + sample: Optional[int] = None, + random_state: Optional[int] = None, + lazy: bool = False, + inplace: bool = False, + error_handler: ErrorHandler = None, + ): + # pylint: disable=too-many-locals,too-many-branches,too-many-statements + """Validate a series or specific column in dataframe. + + :check_obj: pyspark DataFrame to validate. + :param head: validate the first n rows. 
Rows overlapping with `tail` or + `sample` are de-duplicated. + :param tail: validate the last n rows. Rows overlapping with `head` or + `sample` are de-duplicated. + :param sample: validate a random sample of n rows. Rows overlapping + with `head` or `tail` are de-duplicated. + :param random_state: random seed for the ``sample`` argument. + :param lazy: if True, lazily evaluates dataframe against all validation + checks and raises a ``SchemaErrors``. Otherwise, raise + ``SchemaError`` as soon as one occurs. + :param inplace: if True, applies coercion to the object of validation, + otherwise creates a copy of the data. + :returns: validated DataFrame or Series. + + """ + return self.get_backend(check_obj).validate( + check_obj, + schema=self, + head=head, + tail=tail, + sample=sample, + random_state=random_state, + lazy=lazy, + inplace=inplace, + error_handler=error_handler, + ) + + def __call__( + self, + check_obj: ps.DataFrame, + head: Optional[int] = None, + tail: Optional[int] = None, + sample: Optional[int] = None, + random_state: Optional[int] = None, + lazy: bool = False, + inplace: bool = False, + ): + """Alias for ``validate`` method.""" + return self.validate( + check_obj, head, tail, sample, random_state, lazy, inplace + ) + + def __eq__(self, other): + return self.__dict__ == other.__dict__ + + @classmethod + def __get_validators__(cls): + yield cls._pydantic_validate + + @classmethod + def _pydantic_validate( # type: ignore + cls: TColumnSchemaBase, schema: Any + ) -> TColumnSchemaBase: + """Verify that the input is a compatible Schema.""" + if not isinstance(schema, cls): # type: ignore + raise TypeError(f"{schema} is not a {cls}.") + + return cast(TColumnSchemaBase, schema) + + ############################# + # Schema Transforms Methods # + ############################# + + @inferred_schema_guard + def update_checks(self, checks: List[Check]): + """Create a new SeriesSchema with a new set of Checks + + :param checks: checks to set on the new schema + :returns: a new SeriesSchema with a new set of checks + """ + schema_copy = cast(ColumnSchema, copy.deepcopy(self)) + schema_copy.checks = checks + return schema_copy + + def set_checks(self, checks: CheckList): + """Create a new SeriesSchema with a new set of Checks + + .. 
caution:: + This method will be deprecated in favor of ``update_checks`` in + v0.15.0 + + :param checks: checks to set on the new schema + :returns: a new SeriesSchema with a new set of checks + """ + return self.update_checks(checks) + + def __repr__(self): + return ( + f"" + ) + + def __str__(self): + return f"column '{self.name}' with type {self.dtype}" diff --git a/pandera/api/pyspark/components.py b/pandera/api/pyspark/components.py new file mode 100644 index 000000000..2455458a9 --- /dev/null +++ b/pandera/api/pyspark/components.py @@ -0,0 +1,174 @@ +"""Core pyspark schema component specifications.""" + +from typing import Any, Dict, Iterable, Optional, Tuple, Union + +import pyspark.sql as ps +from pandera.api.pyspark.column_schema import ColumnSchema +from pandera.api.pyspark.error_handler import ErrorHandler +from pandera.api.pyspark.types import CheckList, PySparkDtypeInputTypes + + +class Column(ColumnSchema): + """Validate types and properties of DataFrame columns.""" + + def __init__( + self, + dtype: PySparkDtypeInputTypes = None, + checks: Optional[CheckList] = None, + nullable: bool = False, + coerce: bool = False, + required: bool = True, + name: Union[str, Tuple[str, ...], None] = None, + regex: bool = False, + title: Optional[str] = None, + description: Optional[str] = None, + metadata: Optional[dict] = None, + ) -> None: + """Create column validator object. + + :param dtype: datatype of the column. The datatype for type-checking + a dataframe. If a string is specified, then assumes + one of the valid pyspark string values: + https://spark.apache.org/docs/latest/sql-ref-datatypes.html + :param checks: checks to verify validity of the column + :param nullable: Whether or not column can contain null values. + :param coerce: If True, when schema.validate is called the column will + be coerced into the specified dtype. This has no effect on columns + where ``dtype=None``. + :param required: Whether or not column is allowed to be missing + :param name: column name in dataframe to validate. + :param regex: whether the ``name`` attribute should be treated as a + regex pattern to apply to multiple columns in a dataframe. + :param title: A human-readable label for the column. + :param description: An arbitrary textual description of the column. + :param metadata: An optional key value data. + + :raises SchemaInitError: if impossible to build schema from parameters + + :example: + + >>> import pyspark as ps + >>> import pandera as pa + >>> + >>> + >>> schema = pa.DataFrameSchema({ + ... "column": pa.Column(str) + ... }) + >>> + >>> schema.validate(spark.createDataFrame([{"column": "foo"},{ "column":"bar"}])) + column + foo + bar + + See :ref:`here` for more usage details. 
+ """ + super().__init__( + dtype=dtype, + checks=checks, + nullable=nullable, + coerce=coerce, + name=name, + title=title, + description=description, + metadata=metadata, + ) + if name is not None and not isinstance(name, str) and regex: + raise ValueError( + "You cannot specify a non-string name when setting regex=True" + ) + self.required = required + self.name = name + self.regex = regex + self.metadata = metadata + + @property + def _allow_groupby(self) -> bool: + """Whether the schema or schema component allows groupby operations.""" + return True + + @property + def properties(self) -> Dict[str, Any]: + """Get column properties.""" + return { + "dtype": self.dtype, + "checks": self.checks, + "nullable": self.nullable, + "coerce": self.coerce, + "required": self.required, + "name": self.name, + "regex": self.regex, + "title": self.title, + "description": self.description, + "metadata": self.metadata, + } + + def set_name(self, name: str): + """Used to set or modify the name of a column object. + + :param str name: the name of the column object + + """ + self.name = name + return self + + def validate( + self, + check_obj: ps.DataFrame, + head: Optional[int] = None, + tail: Optional[int] = None, + sample: Optional[int] = None, + random_state: Optional[int] = None, + lazy: bool = True, + inplace: bool = False, + error_handler: ErrorHandler = None, + ) -> ps.DataFrame: + """Validate a Column in a DataFrame object. + + :param check_obj: pyspark DataFrame to validate. + :param head: validate the first n rows. Rows overlapping with `tail` or + `sample` are de-duplicated. + :param tail: validate the last n rows. Rows overlapping with `head` or + `sample` are de-duplicated. + :param sample: validate a random sample of n rows. Rows overlapping + with `head` or `tail` are de-duplicated. + :param random_state: random seed for the ``sample`` argument. + :param lazy: if True, lazily evaluates dataframe against all validation + checks and raises a ``SchemaErrors``. Otherwise, raise + ``SchemaError`` as soon as one occurs. + :param inplace: if True, applies coercion to the object of validation, + otherwise creates a copy of the data. + :returns: validated DataFrame. + """ + return self.get_backend(check_obj).validate( + check_obj=check_obj, + schema=self, + head=head, + tail=tail, + sample=sample, + random_state=random_state, + lazy=lazy, + inplace=inplace, + error_handler=error_handler, + ) + + def get_regex_columns(self, columns: Any) -> Iterable: + """Get matching column names based on regex column name pattern. 
+ + :param columns: columns to regex pattern match + :returns: matchin columns + """ + return self.get_backend(check_type=ps.DataFrame).get_regex_columns( + self, columns + ) + + def __eq__(self, other): + if not isinstance(other, self.__class__): + return NotImplemented + + def _compare_dict(obj): + return { + k: v if k != "_checks" else set(v) + for k, v in obj.__dict__.items() + } + + return _compare_dict(self) == _compare_dict(other) diff --git a/pandera/api/pyspark/container.py b/pandera/api/pyspark/container.py new file mode 100644 index 000000000..e40ec423b --- /dev/null +++ b/pandera/api/pyspark/container.py @@ -0,0 +1,605 @@ +"""Core pyspark dataframe container specification.""" + +from __future__ import annotations + +import copy +import os +import warnings +from pathlib import Path +from typing import Any, Dict, List, Optional, Union, cast, overload + +from pyspark.sql import DataFrame + +from pandera import errors +from pandera.api.base.schema import BaseSchema +from pandera.api.checks import Check +from pandera.api.pyspark.error_handler import ErrorHandler +from pandera.api.pyspark.types import ( + CheckList, + PySparkDtypeInputTypes, + StrictType, +) +from pandera.backends.pyspark.container import DataFrameSchemaBackend +from pandera.dtypes import DataType, UniqueSettings +from pandera.engines import pyspark_engine + +N_INDENT_SPACES = 4 + + +class DataFrameSchema(BaseSchema): # pylint: disable=too-many-public-methods + """A light-weight PySpark DataFrame validator.""" + + BACKEND = DataFrameSchemaBackend() + + def __init__( + self, + columns: Optional[ # type: ignore [name-defined] + Dict[Any, "pandera.api.pyspark.components.Column"] # type: ignore [name-defined] + ] = None, + checks: Optional[CheckList] = None, + dtype: PySparkDtypeInputTypes = None, + coerce: bool = False, + strict: StrictType = False, + name: Optional[str] = None, + ordered: bool = False, + unique: Optional[Union[str, List[str]]] = None, + report_duplicates: UniqueSettings = "all", + unique_column_names: bool = False, + title: Optional[str] = None, + description: Optional[str] = None, + metadata: Optional[dict] = None, + ) -> None: + """Initialize DataFrameSchema validator. + + :param columns: a dict where keys are column names and values are + Column objects specifying the datatypes and properties of a + particular column. + :type columns: mapping of column names and column schema component. + :param checks: dataframe-wide checks. + :param index: specify the datatypes and properties of the index. + :param dtype: datatype of the dataframe. This overrides the data + types specified in any of the columns. If a string is specified, + then assumes one of the valid pyspark string values: + https://spark.apache.org/docs/latest/sql-ref-datatypes.html. + :param coerce: whether or not to coerce all of the columns on + validation. This has no effect on columns where + ``dtype=None`` + :param strict: ensure that all and only the columns defined in the + schema are present in the dataframe. If set to 'filter', + only the columns in the schema will be passed to the validated + dataframe. If set to filter and columns defined in the schema + are not present in the dataframe, will throw an error. + :param name: name of the schema. + :param ordered: whether or not to validate the columns order. + :param unique: a list of columns that should be jointly unique. 
+ :param report_duplicates: how to report unique errors + - `exclude_first`: report all duplicates except first occurence + - `exclude_last`: report all duplicates except last occurence + - `all`: (default) report all duplicates + :param unique_column_names: whether or not column names must be unique. + :param title: A human-readable label for the schema. + :param description: An arbitrary textual description of the schema. + :param metadata: An optional key-value data. + + :raises SchemaInitError: if impossible to build schema from parameters + + :examples: + + >>> import pandera as pa + >>> + >>> schema = pa.DataFrameSchema({ + ... "str_column": pa.Column(str), + ... "float_column": pa.Column(float), + ... "int_column": pa.Column(int), + ... "date_column": pa.Column(pa.DateTime), + ... }) + + Use the pyspark API to define checks, which takes a function with + the signature: ``ps.Dataframe -> Union[bool]`` where the + output contains boolean values. + + >>> schema_withchecks = pa.DataFrameSchema({ + ... "probability": pa.Column( + ... float, pa.Check(lambda s: (s >= 0) & (s <= 1))), + ... + ... # check that the "category" column contains a few discrete + ... # values, and the majority of the entries are dogs. + ... "category": pa.Column( + ... str, [ + ... pa.Check(lambda s: s.isin(["dog", "cat", "duck"])), + ... pa.Check(lambda s: (s == "dog").mean() > 0.5), + ... ]), + ... }) + + See :ref:`here` for more usage details. + + """ + + if columns is None: + columns = {} + _validate_columns(columns) + columns = _columns_renamed(columns) + + if checks is None: + checks = [] + if isinstance(checks, (Check)): + checks = [checks] + + super().__init__( + dtype=dtype, + checks=checks, + name=name, + title=title, + description=description, + metadata=metadata, + ) + + self.columns: Dict[Any, "pandera.api.pyspark.components.Column"] = ( # type: ignore [name-defined] + {} if columns is None else columns + ) + + if strict not in ( + False, + True, + "filter", + ): + raise errors.SchemaInitError( + "strict parameter must equal either `True`, `False`, " + "or `'filter'`." + ) + + self.strict: Union[bool, str] = strict + self._coerce = coerce + self.ordered = ordered + self._unique = unique + self.report_duplicates = report_duplicates + self.unique_column_names = unique_column_names + + # this attribute is not meant to be accessed by users and is explicitly + # set to True in the case that a schema is created by infer_schema. 
+ self._IS_INFERRED = False + self.metadata = metadata + + @property + def coerce(self) -> bool: + """Whether to coerce series to specified type.""" + if isinstance(self.dtype, DataType): + return self.dtype.auto_coerce or self._coerce + return self._coerce + + @coerce.setter + def coerce(self, value: bool) -> None: + """Set coerce attribute""" + self._coerce = value + + @property + def unique(self): + """List of columns that should be jointly unique.""" + return self._unique + + @unique.setter + def unique(self, value: Optional[Union[str, List[str]]]) -> None: + """Set unique attribute.""" + self._unique = [value] if isinstance(value, str) else value + + # the _is_inferred getter and setter methods are not public + @property + def _is_inferred(self) -> bool: + return self._IS_INFERRED + + @_is_inferred.setter + def _is_inferred(self, value: bool) -> None: + self._IS_INFERRED = value + + @property + def dtypes(self) -> Dict[str, DataType]: + # pylint:disable=anomalous-backslash-in-string + """ + A dict where the keys are column names and values are + :class:`~pandera.dtypes.DataType` s for the column. Excludes columns + where `regex=True`. + + :returns: dictionary of columns and their associated dtypes. + """ + regex_columns = [ + name for name, col in self.columns.items() if col.regex + ] + if regex_columns: + warnings.warn( + "Schema has columns specified as regex column names: " + f"{regex_columns}. Use the `get_dtypes` to get the datatypes " + "for these columns.", + UserWarning, + ) + return {n: c.dtype for n, c in self.columns.items() if not c.regex} + + @property + def get_metadata(self) -> Optional[dict]: + """Provide metadata for columns and schema level""" + res = {"columns": {}} + for k in self.columns.keys(): + res["columns"][k] = self.columns[k].properties["metadata"] + + res["dataframe"] = self.metadata + + meta = {} + meta[self.name] = res + return meta + + def get_dtypes(self, dataframe: DataFrame) -> Dict[str, DataType]: + """ + Same as the ``dtype`` property, but expands columns where + ``regex == True`` based on the supplied dataframe. + + :returns: dictionary of columns and their associated dtypes. + """ + regex_dtype = {} + for _, column in self.columns.items(): + if column.regex: + regex_dtype.update( + { + c: column.dtype + for c in column.BACKEND.get_regex_columns( + column, + dataframe.columns, + ) + } + ) + return { + **{n: c.dtype for n, c in self.columns.items() if not c.regex}, + **regex_dtype, + } + + @property + def dtype( + self, + ) -> DataType: + """Get the dtype property.""" + return self._dtype # type: ignore + + @dtype.setter + def dtype(self, value: PySparkDtypeInputTypes) -> None: + """Set the pyspark dtype property.""" + # this is a pylint false positive + # pylint: disable=no-value-for-parameter + self._dtype = pyspark_engine.Engine.dtype(value) if value else None + + def coerce_dtype(self, check_obj: DataFrame) -> DataFrame: + return self.get_backend(check_obj).coerce_dtype(check_obj, schema=self) + + def validate( + self, + check_obj: DataFrame, + head: Optional[int] = None, + tail: Optional[int] = None, + sample: Optional[int] = None, + random_state: Optional[int] = None, + lazy: bool = True, + inplace: bool = False, + ): + """Check if all columns in a dataframe have a column in the Schema. + + :param pd.DataFrame check_obj: the dataframe to be validated. + :param head: validate the first n rows. Rows overlapping with `tail` or + `sample` are de-duplicated. + :param tail: validate the last n rows. 
Rows overlapping with `head` or + `sample` are de-duplicated. + :param sample: validate a random sample of n rows. Rows overlapping + with `head` or `tail` are de-duplicated. + :param random_state: random seed for the ``sample`` argument. + :param lazy: if True, lazily evaluates dataframe against all validation + checks and raises a ``SchemaErrors``. Otherwise, raise + ``SchemaError`` as soon as one occurs. + :param inplace: if True, applies coercion to the object of validation, + otherwise creates a copy of the data. + :returns: validated ``DataFrame`` + + :raises SchemaError: when ``DataFrame`` violates built-in or custom + checks. + + :example: + + Calling ``schema.validate`` returns the dataframe. + + + >>> import pandera as pa + >>> + >>> df = spark.createDataFrame([(0.1, 'dog'), (0.4, 'dog'), (0.52, 'cat'), (0.23, 'duck'), + ... (0.8, 'dog'), (0.76, 'dog')],schema=['probability','category']) + >>> + >>> schema_withchecks = pa.DataFrameSchema({ + ... "probability": pa.Column( + ... float, pa.Check(lambda s: (s >= 0) & (s <= 1))), + ... + ... # check that the "category" column contains a few discrete + ... # values, and the majority of the entries are dogs. + ... "category": pa.Column( + ... str, [ + ... pa.Check(lambda s: s.isin(["dog", "cat", "duck"])), + ... pa.Check(lambda s: (s == "dog").mean() > 0.5), + ... ]), + ... }) + >>> + >>> schema_withchecks.validate(df)[["probability", "category"]] + probability category + 0.10 dog + 0.40 dog + 0.52 cat + 0.23 duck + 0.80 dog + 0.76 dog + """ + error_handler = ErrorHandler(lazy) + + return self._validate( + check_obj=check_obj, + head=head, + tail=tail, + sample=sample, + random_state=random_state, + lazy=lazy, + inplace=inplace, + error_handler=error_handler, + ) + + def _validate( + self, + check_obj: DataFrame, + head: Optional[int] = None, + tail: Optional[int] = None, + sample: Optional[int] = None, + random_state: Optional[int] = None, + lazy: bool = False, + inplace: bool = False, + error_handler: ErrorHandler = None, + ): + if self._is_inferred: + warnings.warn( + f"This {type(self)} is an inferred schema that hasn't been " + "modified. It's recommended that you refine the schema " + "by calling `add_columns`, `remove_columns`, or " + "`update_columns` before using it to validate data.", + UserWarning, + ) + + return self.get_backend(check_obj).validate( + check_obj=check_obj, + schema=self, + head=head, + tail=tail, + sample=sample, + random_state=random_state, + lazy=lazy, + inplace=inplace, + error_handler=error_handler, + ) + + def __call__( + self, + dataframe: DataFrame, + head: Optional[int] = None, + tail: Optional[int] = None, + sample: Optional[int] = None, + random_state: Optional[int] = None, + lazy: bool = True, + inplace: bool = False, + ): + """Alias for :func:`DataFrameSchema.validate` method. + + :param pd.DataFrame dataframe: the dataframe to be validated. + :param head: validate the first n rows. Rows overlapping with `tail` or + `sample` are de-duplicated. + :type head: int + :param tail: validate the last n rows. Rows overlapping with `head` or + `sample` are de-duplicated. + :type tail: int + :param sample: validate a random sample of n rows. Rows overlapping + with `head` or `tail` are de-duplicated. + :param random_state: random seed for the ``sample`` argument. + :param lazy: if True, lazily evaluates dataframe against all validation + checks and raises a ``SchemaErrors``. Otherwise, raise + ``SchemaError`` as soon as one occurs. 
+ :param inplace: if True, applies coercion to the object of validation, + otherwise creates a copy of the data. + """ + return self.validate( + dataframe, head, tail, sample, random_state, lazy, inplace + ) + + def __repr__(self) -> str: + """Represent string for logging.""" + return ( + f"" + ) + + def __str__(self) -> str: + """Represent string for user inspection.""" + + def _format_multiline(json_str, arg): + return "\n".join( + f"{indent}{line}" if i != 0 else f"{indent}{arg}={line}" + for i, line in enumerate(json_str.split("\n")) + ) + + indent = " " * N_INDENT_SPACES + if self.columns: + columns_str = f"{indent}columns={{\n" + for colname, col in self.columns.items(): + columns_str += f"{indent * 2}'{colname}': {col}\n" + columns_str += f"{indent}}}" + else: + columns_str = f"{indent}columns={{}}" + + if self.checks: + checks_str = f"{indent}checks=[\n" + for check in self.checks: + checks_str += f"{indent * 2}{check}\n" + checks_str += f"{indent}]" + else: + checks_str = f"{indent}checks=[]" + + return ( + f"" + ) + + def __eq__(self, other: object) -> bool: + if not isinstance(other, type(self)): + return NotImplemented + + def _compare_dict(obj): + return { + k: v for k, v in obj.__dict__.items() if k != "_IS_INFERRED" + } + + return _compare_dict(self) == _compare_dict(other) + + @classmethod + def __get_validators__(cls): + yield cls._pydantic_validate + + @classmethod + def _pydantic_validate(cls, schema: Any) -> "DataFrameSchema": + """Verify that the input is a compatible DataFrameSchema.""" + if not isinstance(schema, cls): # type: ignore + raise TypeError(f"{schema} is not a {cls}.") + + return cast("DataFrameSchema", schema) + + ##################### + # Schema IO Methods # + ##################### + + def to_script(self, fp: Union[str, Path] = None) -> "DataFrameSchema": + """Create DataFrameSchema from yaml file. + + :param path: str, Path to write script + :returns: dataframe schema. + """ + # pylint: disable=import-outside-toplevel,cyclic-import,redefined-outer-name + import pandera.io + + return pandera.io.to_script(self, fp) + + @classmethod + def from_yaml(cls, yaml_schema) -> "DataFrameSchema": + """Create DataFrameSchema from yaml file. + + :param yaml_schema: str, Path to yaml schema, or serialized yaml + string. + :returns: dataframe schema. + """ + # pylint: disable=import-outside-toplevel,cyclic-import,redefined-outer-name + import pandera.io + + return pandera.io.from_yaml(yaml_schema) + + @overload + def to_yaml(self, stream: None = None) -> str: # pragma: no cover + ... + + @overload + def to_yaml(self, stream: os.PathLike) -> None: # pragma: no cover + ... + + def to_yaml(self, stream: Optional[os.PathLike] = None) -> Optional[str]: + """Write DataFrameSchema to yaml file. + + :param stream: file stream to write to. If None, dumps to string. + :returns: yaml string if stream is None, otherwise returns None. + """ + # pylint: disable=import-outside-toplevel,cyclic-import,redefined-outer-name + import pandera.io + + return pandera.io.to_yaml(self, stream=stream) + + @classmethod + def from_json(cls, source) -> "DataFrameSchema": + """Create DataFrameSchema from json file. + + :param source: str, Path to json schema, or serialized yaml + string. + :returns: dataframe schema. + """ + # pylint: disable=import-outside-toplevel,cyclic-import,redefined-outer-name + import pandera.io + + return pandera.io.from_json(source) + + @overload + def to_json( + self, target: None = None, **kwargs + ) -> str: # pragma: no cover + ... 
+ + @overload + def to_json( + self, target: os.PathLike, **kwargs + ) -> None: # pragma: no cover + ... + + def to_json( + self, target: Optional[os.PathLike] = None, **kwargs + ) -> Optional[str]: + """Write DataFrameSchema to json file. + + :param target: file target to write to. If None, dumps to string. + :returns: json string if target is None, otherwise returns None. + """ + # pylint: disable=import-outside-toplevel,cyclic-import,redefined-outer-name + import pandera.io + + return pandera.io.to_json(self, target, **kwargs) + + +def _validate_columns( + column_dict: dict[Any, "pandera.api.pyspark.components.Column"], # type: ignore [name-defined] +) -> None: + for column_name, column in column_dict.items(): + for check in column.checks: + if check.groupby is None or callable(check.groupby): + continue + nonexistent_groupby_columns = [ + c for c in check.groupby if c not in column_dict + ] + if nonexistent_groupby_columns: + raise errors.SchemaInitError( + f"groupby argument {nonexistent_groupby_columns} in " + f"Check for Column {column_name} not " + "specified in the DataFrameSchema." + ) + + +def _columns_renamed( + columns: dict[Any, "pandera.api.pyspark.components.Column"], # type: ignore [name-defined] +) -> dict[Any, "pandera.api.pyspark.components.Column"]: # type: ignore [name-defined] + def renamed(column, new_name): + column = copy.deepcopy(column) + column.set_name(new_name) + return column + + return { + column_name: renamed(column, column_name) + for column_name, column in columns.items() + } diff --git a/pandera/api/pyspark/error_handler.py b/pandera/api/pyspark/error_handler.py new file mode 100644 index 000000000..04f0bd0fd --- /dev/null +++ b/pandera/api/pyspark/error_handler.py @@ -0,0 +1,98 @@ +"""Handle schema errors.""" + +from collections import defaultdict +from enum import Enum +from typing import Dict, List, Union + +from pandera.errors import SchemaError, SchemaErrorReason +from pandera.api.checks import Check + + +class ErrorCategory(Enum): + """Error category codes""" + + DATA = "data-failures" + SCHEMA = "schema-failures" + DTYPE_COERCION = "dtype-coercion-failures" + + +class ErrorHandler: + """Handler for Schema & Data level errors during validation.""" + + def __init__(self, lazy: bool) -> None: + """Initialize ErrorHandler. + + :param lazy: if True, lazily evaluates all checks before raising the exception. + """ + self._lazy = lazy + self._collected_errors = [] # type: ignore + self._summarized_errors = defaultdict(lambda: defaultdict(list)) + + @property + def lazy(self) -> bool: + """Whether or not the schema error handler raises errors immediately.""" + return self._lazy + + def collect_error( + self, + type: ErrorCategory, + reason_code: SchemaErrorReason, + schema_error: SchemaError, + original_exc: BaseException = None, + ): + """Collect schema error, raising exception if lazy is False. + + :param type: type of error + :param reason_code: string representing reason for error + :param schema_error: ``SchemaError`` object. + """ + if not self._lazy: + raise schema_error from original_exc + + # delete data of validated object from SchemaError object to prevent + # storing copies of the validated DataFrame/Series for every + # SchemaError collected. 
+ del schema_error.data + schema_error.data = None + + self._collected_errors.append( + { + "type": type, + "column": schema_error.schema.name, + "check": schema_error.check, + "reason_code": reason_code, + "error": schema_error, + } + ) + + @property + def collected_errors(self) -> List[Dict[str, Union[SchemaError, str]]]: + """Retrieve SchemaError objects collected during lazy validation.""" + return self._collected_errors + + def summarize(self, schema): + """Collect schema error, raising exception if lazy is False. + + :param schema: schema object + """ + + for e in self._collected_errors: + category = e["type"].name + subcategory = e["reason_code"].name + error = e["error"] + + if isinstance(error.check, Check): + check = error.check.error + else: + check = error.check + + self._summarized_errors[category][subcategory].append( + { + "schema": schema.name, + "column": e["column"], + "check": check, + "error": error.__str__(), + } + ) + + return self._summarized_errors diff --git a/pandera/api/pyspark/model.py b/pandera/api/pyspark/model.py new file mode 100644 index 000000000..aa4e8b8f2 --- /dev/null +++ b/pandera/api/pyspark/model.py @@ -0,0 +1,544 @@ +"""Class-based api for pyspark models.""" + +import copy +import inspect +import os +import re +import typing +from typing import ( + Any, + Callable, + Dict, + Iterable, + List, + Mapping, + Optional, + Set, + Tuple, + Type, + TypeVar, + Union, + cast, +) + +import pyspark.sql as ps + +from pandera.api.base.model import BaseModel +from pandera.api.checks import Check +from pandera.api.pyspark.components import Column +from pandera.api.pyspark.container import DataFrameSchema +from pandera.api.pyspark.model_components import ( + CHECK_KEY, + DATAFRAME_CHECK_KEY, + CheckInfo, + Field, + FieldCheckInfo, + FieldInfo, +) +from pandera.api.pyspark.model_config import BaseConfig +from pandera.errors import SchemaInitError +from pandera.typing import INDEX_TYPES, SERIES_TYPES, AnnotationInfo +from pandera.typing.common import DataFrameBase + +try: + from typing_extensions import get_type_hints +except ImportError: + from typing import get_type_hints # type: ignore + +try: + from pydantic.fields import ModelField # pylint:disable=unused-import + + HAS_PYDANTIC = True +except ImportError: + HAS_PYDANTIC = False + + +_CONFIG_KEY = "Config" +MODEL_CACHE: Dict[Type["DataFrameModel"], DataFrameSchema] = {} +GENERIC_SCHEMA_CACHE: Dict[ + Tuple[Type["DataFrameModel"], Tuple[Type[Any], ...]], + Type["DataFrameModel"], +] = {} + +F = TypeVar("F", bound=Callable) +TDataFrameModel = TypeVar("TDataFrameModel", bound="DataFrameModel") + + +# def docstring_substitution(*args: Any, **kwargs: Any) -> Callable[[F], F]: +# """Typed wrapper around pd.util.Substitution.""" + +# def decorator(func: F) -> F: +# return cast(F, pd.util.Substitution(*args, **kwargs)(func)) + +# return decorator + + +def _is_field(name: str) -> bool: + """Ignore private and reserved keywords.""" + return not name.startswith("_") and name != _CONFIG_KEY + + +_config_options = [attr for attr in vars(BaseConfig) if _is_field(attr)] + + +def _extract_config_options_and_extras( + config: Any, +) -> Tuple[Dict[str, Any], Dict[str, Any]]: + config_options, extras = {}, {} + for name, value in vars(config).items(): + if name in _config_options: + config_options[name] = value + elif _is_field(name): + extras[name] = value + # drop private/reserved keywords + + return config_options, extras + + +def _convert_extras_to_checks(extras: Dict[str, Any]) -> List[Check]: + """ + New in GH#383. 
+ Any key not in BaseConfig keys is interpreted as defining a dataframe check. This function + defines this conversion as follows: + - Look up the key name in Check + - If value is + - tuple: interpret as args + - dict: interpret as kwargs + - anything else: interpret as the only argument to pass to Check + """ + checks = [] + for name, value in extras.items(): + if isinstance(value, tuple): + args, kwargs = value, {} + elif isinstance(value, dict): + args, kwargs = (), value + else: + args, kwargs = (value,), {} + + # dispatch directly to getattr to raise the correct exception + checks.append(Check.__getattr__(name)(*args, **kwargs)) + + return checks + + +class DataFrameModel(BaseModel): + """Definition of a :class:`~pandera.api.pyspark.container.DataFrameSchema`. + + *new in 0.5.0* + + .. important:: + + This class is the new name for ``SchemaModel``, which will be deprecated + in pandera version ``0.20.0``. + + See the :ref:`User Guide ` for more. + """ + + Config: Type[BaseConfig] = BaseConfig + __extras__: Optional[Dict[str, Any]] = None + __schema__: Optional[DataFrameSchema] = None + __config__: Optional[Type[BaseConfig]] = None + + #: Key according to `FieldInfo.name` + __fields__: Mapping[str, Tuple[AnnotationInfo, FieldInfo]] = {} + __checks__: Dict[str, List[Check]] = {} + __root_checks__: List[Check] = [] + + # @docstring_substitution(validate_doc=DataFrameSchema.validate.__doc__) + def __new__(cls, *args, **kwargs) -> DataFrameBase[TDataFrameModel]: # type: ignore [misc] + """%(validate_doc)s""" + return cast( + DataFrameBase[TDataFrameModel], cls.validate(*args, **kwargs) + ) + + def __init_subclass__(cls, **kwargs): + """Ensure :class:`~pandera.api.pyspark.model_components.FieldInfo` instances.""" + if "Config" in cls.__dict__: + cls.Config.name = ( + cls.Config.name + if hasattr(cls.Config, "name") + else cls.__name__ + ) + else: + cls.Config = type("Config", (BaseConfig,), {"name": cls.__name__}) + super().__init_subclass__(**kwargs) + # pylint:disable=no-member + subclass_annotations = cls.__dict__.get("__annotations__", {}) + for field_name in subclass_annotations.keys(): + if _is_field(field_name) and field_name not in cls.__dict__: + # Field omitted + field = Field() + field.__set_name__(cls, field_name) + setattr(cls, field_name, field) + + cls.__config__, cls.__extras__ = cls._collect_config_and_extras() + + def __class_getitem__( + cls: Type[TDataFrameModel], + params: Union[Type[Any], Tuple[Type[Any], ...]], + ) -> Type[TDataFrameModel]: + """Parameterize the class's generic arguments with the specified types""" + if not hasattr(cls, "__parameters__"): + raise TypeError( + f"{cls.__name__} must inherit from typing.Generic before being parameterized" + ) + # pylint: disable=no-member + __parameters__: Tuple[TypeVar, ...] 
= cls.__parameters__ # type: ignore + + if not isinstance(params, tuple): + params = (params,) + if len(params) != len(__parameters__): + raise ValueError( + f"Expected {len(__parameters__)} generic arguments but found {len(params)}" + ) + if (cls, params) in GENERIC_SCHEMA_CACHE: + return typing.cast( + Type[TDataFrameModel], GENERIC_SCHEMA_CACHE[(cls, params)] + ) + + param_dict: Dict[TypeVar, Type[Any]] = dict( + zip(__parameters__, params) + ) + extra: Dict[str, Any] = {"__annotations__": {}} + for field, (annot_info, field_info) in cls._collect_fields().items(): + if isinstance(annot_info.arg, TypeVar): + if annot_info.arg in param_dict: + raw_annot = annot_info.origin[param_dict[annot_info.arg]] # type: ignore + if annot_info.optional: + raw_annot = Optional[raw_annot] + extra["__annotations__"][field] = raw_annot + extra[field] = copy.deepcopy(field_info) + + parameterized_name = ( + f"{cls.__name__}[{', '.join(p.__name__ for p in params)}]" + ) + parameterized_cls = type(parameterized_name, (cls,), extra) + GENERIC_SCHEMA_CACHE[(cls, params)] = parameterized_cls + return parameterized_cls + + @classmethod + def to_schema(cls) -> DataFrameSchema: + """Create :class:`~pandera.DataFrameSchema` from the :class:`.DataFrameModel`.""" + + if cls in MODEL_CACHE: + return MODEL_CACHE[cls] + + cls.__fields__ = cls._collect_fields() + + for field, (annot_info, _) in cls.__fields__.items(): + if isinstance(annot_info.arg, TypeVar): + raise SchemaInitError(f"Field {field} has a generic data type") + + check_infos = typing.cast( + List[FieldCheckInfo], cls._collect_check_infos(CHECK_KEY) + ) + + cls.__checks__ = cls._extract_checks( + check_infos, field_names=list(cls.__fields__.keys()) + ) + + df_check_infos = cls._collect_check_infos(DATAFRAME_CHECK_KEY) + df_custom_checks = cls._extract_df_checks(df_check_infos) + df_registered_checks = _convert_extras_to_checks( + {} if cls.__extras__ is None else cls.__extras__ + ) + cls.__root_checks__ = df_custom_checks + df_registered_checks + + columns = cls._build_columns_index(cls.__fields__, cls.__checks__) + + kwargs = {} + if cls.__config__ is not None: + kwargs = { + "dtype": cls.__config__.dtype, + "coerce": cls.__config__.coerce, + "strict": cls.__config__.strict, + "name": cls.__config__.name, + "ordered": cls.__config__.ordered, + "unique": cls.__config__.unique, + "title": cls.__config__.title, + "description": cls.__config__.description or cls.__doc__, + "unique_column_names": cls.__config__.unique_column_names, + } + cls.__schema__ = DataFrameSchema( + columns, + checks=cls.__root_checks__, # type: ignore + **kwargs, # type: ignore + ) + + if cls not in MODEL_CACHE: + MODEL_CACHE[cls] = cls.__schema__ # type: ignore + return cls.__schema__ # type: ignore + + @classmethod + def to_yaml(cls, stream: Optional[os.PathLike] = None): + """ + Convert `Schema` to yaml using `io.to_yaml`. 
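+
+        A minimal usage sketch (``MySchema`` is a hypothetical model; this
+        assumes the shared ``pandera.io`` YAML serializer supports the
+        generated schema)::
+
+            yaml_str = MySchema.to_yaml()  # dump the schema to a YAML string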
+ """ + return cls.to_schema().to_yaml(stream) + + @classmethod + # @docstring_substitution(validate_doc=DataFrameSchema.validate.__doc__) + def validate( + cls: Type[TDataFrameModel], + check_obj: ps.DataFrame, + head: Optional[int] = None, + tail: Optional[int] = None, + sample: Optional[int] = None, + random_state: Optional[int] = None, + lazy: bool = True, + inplace: bool = False, + ) -> DataFrameBase[TDataFrameModel]: + """%(validate_doc)s""" + return cast( + DataFrameBase[TDataFrameModel], + cls.to_schema().validate( + check_obj, head, tail, sample, random_state, lazy, inplace + ), + ) + + @classmethod + def _build_columns_index( # pylint:disable=too-many-locals + cls, + fields: Dict[str, Tuple[AnnotationInfo, FieldInfo]], + checks: Dict[str, List[Check]], + ) -> Dict[str, Column]: + columns: Dict[str, Column] = {} + for field_name, (annotation, field) in fields.items(): + field_checks = checks.get(field_name, []) + field_name = field.name + check_name = getattr(field, "check_name", None) + + if annotation.metadata: + if field.dtype_kwargs: + raise TypeError( + "Cannot specify redundant 'dtype_kwargs' " + + f"for {annotation.raw_annotation}." + + "\n Usage Tip: Drop 'typing.Annotated'." + ) + dtype_kwargs = _get_dtype_kwargs(annotation) + dtype = annotation.arg(**dtype_kwargs) # type: ignore + elif annotation.default_dtype: + dtype = annotation.default_dtype + else: + dtype = annotation.arg + + dtype = None if dtype is Any else dtype + + if annotation.origin is None: + col_constructor = field.to_column if field else Column + + if check_name is False: + raise SchemaInitError( + f"'check_name' is not supported for {field_name}." + ) + + columns[field_name] = col_constructor( # type: ignore + dtype, + required=not annotation.optional, + checks=field_checks, + name=field_name, + ) + else: + raise SchemaInitError( + f"Invalid annotation '{field_name}: " + f"{annotation.raw_annotation}'" + ) + + return columns + + @classmethod + def _get_model_attrs(cls) -> Dict[str, Any]: + """Return all attributes. + Similar to inspect.get_members but bypass descriptors __get__. 
+ """ + bases = inspect.getmro(cls)[:-1] # bases -> DataFrameModel -> object + attrs = {} + for base in reversed(bases): + if issubclass(base, DataFrameModel): + attrs.update(base.__dict__) + return attrs + + @classmethod + def _collect_fields(cls) -> Dict[str, Tuple[AnnotationInfo, FieldInfo]]: + """Centralize publicly named fields and their corresponding annotations.""" + + annotations = get_type_hints( # pylint:disable=unexpected-keyword-arg + cls, include_extras=True # type: ignore [call-arg] + ) + attrs = cls._get_model_attrs() + + missing = [] + for name, attr in attrs.items(): + if inspect.isroutine(attr): + continue + if not _is_field(name): + annotations.pop(name, None) + elif name not in annotations: + missing.append(name) + + if missing: + raise SchemaInitError(f"Found missing annotations: {missing}") + + fields = {} + for field_name, annotation in annotations.items(): + field = attrs[field_name] # __init_subclass__ guarantees existence + if not isinstance(field, FieldInfo): + raise SchemaInitError( + f"'{field_name}' can only be assigned a 'Field', " + + f"not a '{type(field)}.'" + ) + fields[field.name] = (AnnotationInfo(annotation), field) + + return fields + + @classmethod + def _collect_config_and_extras( + cls, + ) -> Tuple[Type[BaseConfig], Dict[str, Any]]: + """Collect config options from bases, splitting off unknown options.""" + bases = inspect.getmro(cls)[:-1] + bases = tuple( + base for base in bases if issubclass(base, DataFrameModel) + ) + root_model, *models = reversed(bases) + + options, extras = _extract_config_options_and_extras(root_model.Config) + + for model in models: + config = getattr(model, _CONFIG_KEY, {}) + base_options, base_extras = _extract_config_options_and_extras( + config + ) + options.update(base_options) + extras.update(base_extras) + + return type("Config", (BaseConfig,), options), extras + + @classmethod + def _collect_check_infos(cls, key: str) -> List[CheckInfo]: + """Collect inherited check metadata from bases. + Inherited classmethods are not in cls.__dict__, that's why we need to + walk the inheritance tree. + """ + bases = inspect.getmro(cls)[:-2] # bases -> DataFrameModel -> object + bases = tuple( + base for base in bases if issubclass(base, DataFrameModel) + ) + + method_names = set() + check_infos = [] + for base in bases: + for attr_name, attr_value in vars(base).items(): + check_info = getattr(attr_value, key, None) + if not isinstance(check_info, CheckInfo): + continue + if attr_name in method_names: # check overridden by subclass + continue + method_names.add(attr_name) + check_infos.append(check_info) + return check_infos + + @classmethod + def _extract_checks( + cls, check_infos: List[FieldCheckInfo], field_names: List[str] + ) -> Dict[str, List[Check]]: + """Collect field annotations from bases in mro reverse order.""" + checks: Dict[str, List[Check]] = {} + for check_info in check_infos: + check_info_fields = { + field.name if isinstance(field, FieldInfo) else field + for field in check_info.fields + } + if check_info.regex: + matched = _regex_filter(field_names, check_info_fields) + else: + matched = check_info_fields + + check_ = check_info.to_check(cls) + + for field in matched: + if field not in field_names: + raise SchemaInitError( + f"Check {check_.name} is assigned to a non-existing field '{field}'." 
+ ) + if field not in checks: + checks[field] = [] + checks[field].append(check_) + return checks + + @classmethod + def _extract_df_checks(cls, check_infos: List[CheckInfo]) -> List[Check]: + """Collect field annotations from bases in mro reverse order.""" + return [check_info.to_check(cls) for check_info in check_infos] + + @classmethod + def __get_validators__(cls): + yield cls.pydantic_validate + + @classmethod + def pydantic_validate(cls, schema_model: Any) -> "DataFrameModel": + """Verify that the input is a compatible dataframe model.""" + if not inspect.isclass(schema_model): # type: ignore + raise TypeError(f"{schema_model} is not a pandera.DataFrameModel") + + if not issubclass(schema_model, cls): # type: ignore + raise TypeError(f"{schema_model} does not inherit {cls}.") + + try: + schema_model.to_schema() + except SchemaInitError as exc: + raise ValueError( + f"Cannot use {cls} as a pydantic type as its " + "DataFrameModel cannot be converted to a DataFrameSchema.\n" + f"Please revisit the model to address the following errors:" + f"\n{exc}" + ) from exc + + return cast("DataFrameModel", schema_model) + + @classmethod + def get_metadata(self) -> Optional[dict]: + """Provide metadata for columns and schema level""" + res = {"columns": {}} + columns = self._collect_fields() + + for k, (_, v) in columns.items(): + res["columns"][k] = v.properties["metadata"] + + res["dataframe"] = self.Config.metadata + + meta = {} + meta[self.Config.name] = res + return meta + + +SchemaModel = DataFrameModel +""" +Alias for DataFrameModel. + +.. warning:: + + This subclass is necessary for backwards compatibility, and will be + deprecated in pandera version ``0.20.0`` in favor of + :py:class:`~pandera.api.pyspark.model.DataFrameModel` +""" + + +def _regex_filter(seq: Iterable, regexps: Iterable[str]) -> Set[str]: + """Filter items matching at least one of the regexes.""" + matched: Set[str] = set() + for regex in regexps: + pattern = re.compile(regex) + matched.update(filter(pattern.match, seq)) + return matched + + +def _get_dtype_kwargs(annotation: AnnotationInfo) -> Dict[str, Any]: + sig = inspect.signature(annotation.arg) # type: ignore + dtype_arg_names = list(sig.parameters.keys()) + if len(annotation.metadata) != len(dtype_arg_names): # type: ignore + raise TypeError( + f"Annotation '{annotation.arg.__name__}' requires " # type: ignore + + f"all positional arguments {dtype_arg_names}." + ) + return dict(zip(dtype_arg_names, annotation.metadata)) # type: ignore diff --git a/pandera/api/pyspark/model_components.py b/pandera/api/pyspark/model_components.py new file mode 100644 index 000000000..832a4626a --- /dev/null +++ b/pandera/api/pyspark/model_components.py @@ -0,0 +1,303 @@ +"""DataFrameModel components""" +from typing import ( + Any, + Callable, + Dict, + Iterable, + Optional, + Set, + Tuple, + Type, + TypeVar, + Union, + cast, +) + +from pandera.api.base.model_components import ( + BaseCheckInfo, + BaseFieldInfo, + CheckArg, + to_checklist, +) +from pandera.api.checks import Check +from pandera.api.pyspark.column_schema import ColumnSchema +from pandera.api.pyspark.components import Column +from pandera.api.pyspark.types import PySparkDtypeInputTypes +from pandera.errors import SchemaInitError + +AnyCallable = Callable[..., Any] +SchemaComponent = TypeVar("SchemaComponent", bound=ColumnSchema) + +CHECK_KEY = "__check_config__" +DATAFRAME_CHECK_KEY = "__dataframe_check_config__" + + +class FieldInfo(BaseFieldInfo): + """Captures extra information about a field. 
+ + *new in 0.5.0* + """ + + def _to_schema_component( + self, + dtype: PySparkDtypeInputTypes, + component: Type[SchemaComponent], + checks: CheckArg = None, + **kwargs: Any, + ) -> SchemaComponent: + if self.dtype_kwargs: + dtype = dtype(**self.dtype_kwargs) # type: ignore + checks = self.checks + to_checklist(checks) + return component(dtype, checks=checks, **kwargs) # type: ignore + + def to_column( + self, + dtype: PySparkDtypeInputTypes, + checks: CheckArg = None, + required: bool = True, + name: str = None, + ) -> Column: + """Create a schema_components.Column from a field.""" + return self._to_schema_component( + dtype, + Column, + nullable=self.nullable, + coerce=self.coerce, + regex=self.regex, + required=required, + name=name, + checks=checks, + title=self.title, + description=self.description, + metadata=self.metadata, + ) + + @property + def properties(self) -> Dict[str, Any]: + """Get column properties.""" + + return { + "dtype": self.dtype_kwargs, + "checks": self.checks, + "nullable": self.nullable, + "coerce": self.coerce, + "name": self.name, + "regex": self.regex, + "title": self.title, + "description": self.description, + "metadata": self.metadata, + } + + +def Field( + *, + eq: Any = None, + ne: Any = None, + gt: Any = None, + ge: Any = None, + lt: Any = None, + le: Any = None, + in_range: Dict[str, Any] = None, + isin: Iterable = None, + notin: Iterable = None, + str_contains: Optional[str] = None, + str_endswith: Optional[str] = None, + str_length: Optional[Dict[str, Any]] = None, + str_matches: Optional[str] = None, + str_startswith: Optional[str] = None, + nullable: bool = False, + unique: bool = False, + coerce: bool = False, + regex: bool = False, + ignore_na: bool = True, + raise_warning: bool = False, + n_failure_cases: int = None, + alias: Any = None, + check_name: Optional[bool] = None, + dtype_kwargs: Optional[Dict[str, Any]] = None, + title: Optional[str] = None, + description: Optional[str] = None, + metadata: Optional[dict] = None, + **kwargs, +) -> Any: + """Used to provide extra information about a field of a DataFrameModel. + + *new in 0.5.0* + + Some arguments apply only to numeric dtypes and some apply only to ``str``. + See the :ref:`User Guide ` for more information. + + The keyword-only arguments from ``eq`` to ``str_startswith`` are dispatched + to the built-in :py:class:`~pandera.api.checks.Check` methods. + + :param nullable: Whether or not the column/index can contain null values. + :param unique: Whether column values should be unique. + :param coerce: coerces the data type if ``True``. + :param regex: whether or not the field name or alias is a regex pattern. + :param ignore_na: whether or not to ignore null values in the checks. + :param raise_warning: raise a warning instead of an Exception. + :param n_failure_cases: report the first n unique failure cases. If None, + report all failure cases. + :param alias: The public name of the column/index. + :param check_name: Whether to check the name of the column/index during + validation. `None` is the default behavior, which translates to `True` + for columns and multi-index, and to `False` for a single index. + :param dtype_kwargs: The parameters to be forwarded to the type of the + field. + :param title: A human-readable label for the field. + :param description: An arbitrary textual description of the field. + :param metadata: An optional key-value data. + :param kwargs: Specify custom checks that have been registered with the + :class:`~pandera.extensions.register_check_method` decorator. 
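+
+    A minimal sketch of how ``Field`` is typically used on a
+    :class:`~pandera.api.pyspark.model.DataFrameModel` (the model and column
+    names here are hypothetical; checks shown map to the pyspark built-in
+    checks)::
+
+        import pyspark.sql.types as T
+
+        from pandera.api.pyspark.model import DataFrameModel
+        from pandera.api.pyspark.model_components import Field
+
+        class ProductSchema(DataFrameModel):
+            id: T.IntegerType() = Field(gt=0)
+            name: T.StringType() = Field(str_startswith="B", nullable=True)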
+ """ + # pylint:disable=C0103,W0613,R0914 + check_kwargs = { + "ignore_na": ignore_na, + "raise_warning": raise_warning, + "n_failure_cases": n_failure_cases, + } + args = locals() + checks = [] + + check_dispatch = _check_dispatch() + for key in kwargs: + if key not in check_dispatch: + raise SchemaInitError( + f"custom check '{key}' is not available. Make sure you use " + "pandera.extensions.register_check_method decorator to " + "register your custom check method." + ) + + for arg_name, check_constructor in check_dispatch.items(): + arg_value = args.get(arg_name, kwargs.get(arg_name)) + if arg_value is None: + continue + if isinstance(arg_value, dict): + check_ = check_constructor(**arg_value, **check_kwargs) + else: + check_ = check_constructor(arg_value, **check_kwargs) + checks.append(check_) + + return FieldInfo( + checks=checks or None, + nullable=nullable, + unique=unique, + coerce=coerce, + regex=regex, + check_name=check_name, + alias=alias, + title=title, + description=description, + dtype_kwargs=dtype_kwargs, + metadata=metadata, + ) + + +def _check_dispatch(): + return { + "eq": Check.equal_to, + "ne": Check.not_equal_to, + "gt": Check.greater_than, + "ge": Check.greater_than_or_equal_to, + "lt": Check.less_than, + "le": Check.less_than_or_equal_to, + "in_range": Check.in_range, + "isin": Check.isin, + "notin": Check.notin, + "str_contains": Check.str_contains, + "str_endswith": Check.str_endswith, + "str_matches": Check.str_matches, + "str_length": Check.str_length, + "str_startswith": Check.str_startswith, + **Check.REGISTERED_CUSTOM_CHECKS, + } + + +class CheckInfo(BaseCheckInfo): # pylint:disable=too-few-public-methods + """Captures extra information about a Check.""" + + ... + + +class FieldCheckInfo(CheckInfo): # pylint:disable=too-few-public-methods + """Captures extra information about a Check assigned to a field.""" + + def __init__( + self, + fields: Set[Union[str, FieldInfo]], + check_fn: AnyCallable, + regex: bool = False, + **check_kwargs: Any, + ) -> None: + super().__init__(check_fn, **check_kwargs) + self.fields = fields + self.regex = regex + + +def _to_function_and_classmethod( + fn: Union[AnyCallable, classmethod] +) -> Tuple[AnyCallable, classmethod]: + if isinstance(fn, classmethod): + fn, method = fn.__func__, cast(classmethod, fn) + else: + method = classmethod(fn) + return fn, method + + +ClassCheck = Callable[[Union[classmethod, AnyCallable]], classmethod] + + +def check(*fields, regex: bool = False, **check_kwargs) -> ClassCheck: + """Decorator to make DataFrameModel method a column check function. + + *new in 0.5.0* + + This indicates that the decorated method should be used to validate a field + (column). The method will be converted to a classmethod. Therefore + its signature must start with `cls` followed by regular check arguments. + See the :ref:`User Guide ` for more. + + :param _fn: Method to decorate. + :param check_kwargs: Keywords arguments forwarded to Check. + """ + + def _wrapper(fn: Union[classmethod, AnyCallable]) -> classmethod: + check_fn, check_method = _to_function_and_classmethod(fn) + check_kwargs.setdefault("description", fn.__doc__) + setattr( + check_method, + CHECK_KEY, + FieldCheckInfo(set(fields), check_fn, regex, **check_kwargs), + ) + return check_method + + return _wrapper + + +def dataframe_check(_fn=None, **check_kwargs) -> ClassCheck: + """Decorator to make DataFrameModel method a dataframe-wide check function. 
+ + *new in 0.5.0* + + Decorate a method on the DataFrameModel indicating that it should be used to + validate the DataFrame. The method will be converted to a classmethod. + Therefore its signature must start with `cls` followed by regular check + arguments. See the :ref:`User Guide ` for + more. + + :param check_kwargs: Keywords arguments forwarded to Check. + """ + + def _wrapper(fn: Union[classmethod, AnyCallable]) -> classmethod: + check_fn, check_method = _to_function_and_classmethod(fn) + check_kwargs.setdefault("description", fn.__doc__) + setattr( + check_method, + DATAFRAME_CHECK_KEY, + CheckInfo(check_fn, **check_kwargs), + ) + return check_method + + if _fn: + return _wrapper(_fn) # type: ignore + return _wrapper diff --git a/pandera/api/pyspark/model_config.py b/pandera/api/pyspark/model_config.py new file mode 100644 index 000000000..eec1ce81f --- /dev/null +++ b/pandera/api/pyspark/model_config.py @@ -0,0 +1,65 @@ +"""Class-based dataframe model API configuration for pyspark.""" + +from typing import Any, Callable, Dict, List, Optional, Union + +from pandera.api.base.model_config import BaseModelConfig +from pandera.api.pyspark.types import PySparkDtypeInputTypes, StrictType +from pandera.typing.formats import Format + + +class BaseConfig(BaseModelConfig): # pylint:disable=R0903 + """Define DataFrameSchema-wide options. + + *new in 0.5.0* + """ + + #: datatype of the dataframe. This overrides the data types specified in + #: any of the fields. + dtype: Optional[PySparkDtypeInputTypes] = None + + name: Optional[str] = None #: name of schema + title: Optional[str] = None #: human-readable label for schema + description: Optional[str] = None #: arbitrary textual description + coerce: bool = False #: coerce types of all schema components + + #: make sure certain column combinations are unique + unique: Optional[Union[str, List[str]]] = None + + #: make sure all specified columns are in the validated dataframe - + #: if ``"filter"``, removes columns not specified in the schema + strict: StrictType = False + + ordered: bool = False #: validate columns order + + #: make sure dataframe column names are unique + unique_column_names: bool = False + + #: data format before validation. This option only applies to + #: schemas used in the context of the pandera type constructor + #: ``pa.typing.DataFrame[Schema](data)``. If None, assumes a data structure + #: compatible with the ``pyspark.sql.DataFrame`` constructor. + from_format: Optional[Union[Format, Callable]] = None + + #: a dictionary keyword arguments to pass into the reader function that + #: converts the object of type ``from_format`` to a pandera-validate-able + #: data structure. The reader function is implemented in the pandera.typing + #: generic types via the ``from_format`` and ``to_format`` methods. + from_format_kwargs: Optional[Dict[str, Any]] = None + + #: data format to serialize into after validation. This option only applies + #: to schemas used in the context of the pandera type constructor + #: ``pa.typing.DataFrame[Schema](data)``. If None, returns a dataframe. + to_format: Optional[Union[Format, Callable]] = None + + #: Buffer to be provided when to_format is a custom callable. See docs for + #: example of how to implement an example of a to format function. + to_format_buffer: Optional[Union[str, Callable]] = None + + #: a dictionary keyword arguments to pass into the writer function that + #: converts the pandera-validate-able object to type ``to_format``. 
+ #: The writer function is implemented in the pandera.typing + #: generic types via the ``from_format`` and ``to_format`` methods. + to_format_kwargs: Optional[Dict[str, Any]] = None + + #: a dictionary object to store key-value data at schema level + metadata: Optional[dict] = None diff --git a/pandera/api/pyspark/types.py b/pandera/api/pyspark/types.py new file mode 100644 index 000000000..0663685a1 --- /dev/null +++ b/pandera/api/pyspark/types.py @@ -0,0 +1,100 @@ +"""Utility functions for pyspark validation.""" + +from functools import lru_cache +from typing import List, NamedTuple, Tuple, Type, Union + +import pyspark.sql.types as pst +from pyspark.sql import DataFrame + +from pandera.api.checks import Check +from pandera.dtypes import DataType + +try: + from typing import Literal, NamedTuple +except ImportError: + from typing_extensions import Literal # type: ignore [misc] + + +CheckList = Union[Check, List[Check]] + +PysparkDefaultTypes = Union[ + pst.BooleanType, + pst.StringType, + pst.IntegerType, + pst.DecimalType, + pst.FloatType, + pst.DateType, + pst.TimestampType, + pst.DoubleType, + pst.ShortType, + pst.ByteType, + pst.LongType, + pst.DayTimeIntervalType, + pst.BinaryType, +] + +PySparkDtypeInputTypes = Union[ + str, + int, + float, + bool, + type, + DataType, + Type, + pst.BooleanType, + pst.StringType, + pst.IntegerType, + pst.DecimalType, + pst.FloatType, + pst.DateType, + pst.TimestampType, + pst.DoubleType, + pst.ShortType, + pst.ByteType, + pst.LongType, + pst.DayTimeIntervalType, + pst.BinaryType, +] + +StrictType = Union[bool, Literal["filter"]] + +SupportedTypes = NamedTuple( + "SupportedTypes", + (("table_types", Tuple[type, ...]),), +) + + +class PysparkDataframeColumnObject(NamedTuple): + dataframe: DataFrame + column_name: str + + +@lru_cache(maxsize=None) +def supported_types() -> SupportedTypes: + """Get the types supported by pandera schemas.""" + # pylint: disable=import-outside-toplevel + table_types = [DataFrame] + + try: + table_types.append(DataFrame) + + except ImportError: + pass + + return SupportedTypes( + tuple(table_types), + ) + + +def is_table(obj): + """Verifies whether an object is table-like. + + Where a table is a 2-dimensional data matrix of rows and columns, which + can be indexed in multiple different ways. 
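+
+    A small sketch (assumes an active ``SparkSession`` named ``spark``)::
+
+        df = spark.createDataFrame([(1, "a")], ["id", "name"])
+        assert is_table(df)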
+ """ + return isinstance(obj, supported_types().table_types) + + +def is_bool(x): + """Verifies whether an object is a boolean type.""" + return isinstance(x, (bool, pst.BooleanType())) diff --git a/pandera/backends/base/builtin_checks.py b/pandera/backends/base/builtin_checks.py index dc9ba8846..11d23827f 100644 --- a/pandera/backends/base/builtin_checks.py +++ b/pandera/backends/base/builtin_checks.py @@ -13,7 +13,6 @@ from pandera.api.checks import Check - T = TypeVar("T") diff --git a/pandera/backends/pandas/__init__.py b/pandera/backends/pandas/__init__.py index 65de3a9eb..141f1f93a 100644 --- a/pandera/backends/pandas/__init__.py +++ b/pandera/backends/pandas/__init__.py @@ -65,13 +65,11 @@ for t in dataframe_datatypes: DataFrameSchema.register_backend(t, DataFrameSchemaBackend) - # SeriesSchema.register_backend(t, SeriesSchemaBackend) Column.register_backend(t, ColumnBackend) MultiIndex.register_backend(t, MultiIndexBackend) Index.register_backend(t, IndexBackend) for t in series_datatypes: - # DataFrameSchema.register_backend(t, DataFrameSchemaBackend) SeriesSchema.register_backend(t, SeriesSchemaBackend) Column.register_backend(t, ColumnBackend) Index.register_backend(t, IndexBackend) diff --git a/pandera/backends/pandas/base.py b/pandera/backends/pandas/base.py index 36194d8b7..7ab55b616 100644 --- a/pandera/backends/pandas/base.py +++ b/pandera/backends/pandas/base.py @@ -16,12 +16,12 @@ from pandera.api.base.checks import CheckResult from pandera.backends.base import BaseSchemaBackend, CoreCheckResult from pandera.backends.pandas.error_formatters import ( + consolidate_failure_cases, format_generic_error_message, format_vectorized_error_message, - consolidate_failure_cases, - summarize_failure_cases, reshape_failure_cases, scalar_failure_case, + summarize_failure_cases, ) from pandera.errors import FailureCaseMetadata, SchemaError, SchemaErrorReason diff --git a/pandera/backends/pandas/builtin_checks.py b/pandera/backends/pandas/builtin_checks.py index 5ba87fdb6..4e12ab429 100644 --- a/pandera/backends/pandas/builtin_checks.py +++ b/pandera/backends/pandas/builtin_checks.py @@ -2,17 +2,15 @@ import operator import re -from typing import cast, Any, Iterable, TypeVar, Union +from typing import Any, Iterable, TypeVar, Union, cast import pandas as pd import pandera.strategies as st from pandera.api.extensions import register_builtin_check - from pandera.typing.modin import MODIN_INSTALLED from pandera.typing.pyspark import PYSPARK_INSTALLED - if MODIN_INSTALLED and not PYSPARK_INSTALLED: # pragma: no cover import modin.pandas as mpd diff --git a/pandera/backends/pandas/builtin_hypotheses.py b/pandera/backends/pandas/builtin_hypotheses.py index 2766a0240..e29fbc444 100644 --- a/pandera/backends/pandas/builtin_hypotheses.py +++ b/pandera/backends/pandas/builtin_hypotheses.py @@ -3,10 +3,9 @@ from typing import Tuple +from pandera.api.extensions import register_builtin_hypothesis from pandera.backends.pandas.builtin_checks import PandasData from pandera.backends.pandas.hypotheses import HAS_SCIPY -from pandera.api.extensions import register_builtin_hypothesis - if HAS_SCIPY: from scipy import stats diff --git a/pandera/backends/pandas/checks.py b/pandera/backends/pandas/checks.py index 886b4e7c2..dffc5953f 100644 --- a/pandera/backends/pandas/checks.py +++ b/pandera/backends/pandas/checks.py @@ -4,17 +4,17 @@ from typing import Dict, List, Optional, Union, cast import pandas as pd -from multimethod import overload, DispatchError +from multimethod import DispatchError, overload -from 
pandera.backends.base import BaseCheckBackend from pandera.api.base.checks import CheckResult, GroupbyObject from pandera.api.checks import Check from pandera.api.pandas.types import ( - is_table, + is_bool, is_field, + is_table, is_table_or_field, - is_bool, ) +from pandera.backends.base import BaseCheckBackend class PandasCheckBackend(BaseCheckBackend): diff --git a/pandera/backends/pandas/components.py b/pandera/backends/pandas/components.py index 66c98b083..3a7e2815e 100644 --- a/pandera/backends/pandas/components.py +++ b/pandera/backends/pandas/components.py @@ -18,7 +18,7 @@ ) from pandera.backends.pandas.error_formatters import scalar_failure_case from pandera.error_handlers import SchemaErrorHandler -from pandera.errors import SchemaError, SchemaErrors, SchemaErrorReason +from pandera.errors import SchemaError, SchemaErrorReason, SchemaErrors class ColumnBackend(ArraySchemaBackend): diff --git a/pandera/backends/pandas/container.py b/pandera/backends/pandas/container.py index 27a4520b1..ea43d707f 100644 --- a/pandera/backends/pandas/container.py +++ b/pandera/backends/pandas/container.py @@ -20,10 +20,10 @@ from pandera.error_handlers import SchemaErrorHandler from pandera.errors import ( ParserError, - SchemaError, - SchemaErrors, SchemaDefinitionError, + SchemaError, SchemaErrorReason, + SchemaErrors, ) diff --git a/pandera/backends/pandas/hypotheses.py b/pandera/backends/pandas/hypotheses.py index d8ee5c217..6344e1371 100644 --- a/pandera/backends/pandas/hypotheses.py +++ b/pandera/backends/pandas/hypotheses.py @@ -1,16 +1,15 @@ """Hypothesis backend for pandas.""" from functools import partial -from typing import cast, Any, Callable, Dict, Union +from typing import Any, Callable, Dict, Union, cast import pandas as pd from multimethod import overload from pandera import errors -from pandera.backends.pandas.checks import PandasCheckBackend from pandera.api.hypotheses import Hypothesis from pandera.api.pandas.types import is_field, is_table - +from pandera.backends.pandas.checks import PandasCheckBackend try: from scipy import stats # pylint: disable=unused-import diff --git a/pandera/backends/pyspark/__init__.py b/pandera/backends/pyspark/__init__.py new file mode 100644 index 000000000..218bfb951 --- /dev/null +++ b/pandera/backends/pyspark/__init__.py @@ -0,0 +1,19 @@ +"""PySpark native backend implementation for schemas and checks.""" + +import pyspark.sql as pst + +from pandera.api.checks import Check +from pandera.api.pyspark.column_schema import ColumnSchema +from pandera.api.pyspark.components import Column +from pandera.api.pyspark.container import DataFrameSchema +from pandera.backends.pyspark import builtin_checks +from pandera.backends.pyspark.checks import PySparkCheckBackend +from pandera.backends.pyspark.column import ColumnSchemaBackend +from pandera.backends.pyspark.components import ColumnBackend +from pandera.backends.pyspark.container import DataFrameSchemaBackend + +for t in [pst.DataFrame]: + Check.register_backend(t, PySparkCheckBackend) + ColumnSchema.register_backend(t, ColumnSchemaBackend) + Column.register_backend(t, ColumnBackend) + DataFrameSchema.register_backend(t, DataFrameSchemaBackend) diff --git a/pandera/backends/pyspark/base.py b/pandera/backends/pyspark/base.py new file mode 100644 index 000000000..f0f1fb7dc --- /dev/null +++ b/pandera/backends/pyspark/base.py @@ -0,0 +1,120 @@ +"""pyspark Parsing, Validation, and Error Reporting Backends.""" + +import warnings +from typing import ( + Any, + Dict, + FrozenSet, + Iterable, + List, + 
NamedTuple, + Optional, + TypeVar, + Union, +) + +from pyspark.sql import DataFrame +from pyspark.sql.functions import col + +from pandera.backends.base import BaseSchemaBackend +from pandera.backends.pyspark.error_formatters import ( # format_vectorized_error_message,; consolidate_failure_cases,; summarize_failure_cases,; reshape_failure_cases, + format_generic_error_message, + scalar_failure_case, +) +from pandera.backends.pyspark.utils import ConfigParams +from pandera.errors import FailureCaseMetadata, SchemaError + + +class ColumnInfo(NamedTuple): + """Column metadata used during validation.""" + + sorted_column_names: Iterable + expanded_column_names: FrozenSet + destuttered_column_names: List + absent_column_names: List + lazy_exclude_column_names: List + + +FieldCheckObj = Union[col, DataFrame] + +T = TypeVar( + "T", + col, + DataFrame, + FieldCheckObj, + covariant=True, +) + + +class PysparkSchemaBackend(BaseSchemaBackend): + """Base backend for pyspark schemas.""" + + try: + params = ConfigParams("pyspark", "parameters.yaml") + except Exception as err: + raise err + + def subsample( + self, + check_obj: DataFrame, + sample: Optional[float] = None, + seed: Optional[int] = None, + ): + if sample is not None: + return check_obj.sample( + withReplacement=False, fraction=sample, seed=seed + ) + return check_obj + + def run_check( + self, + check_obj, + schema, + check, + check_index: int, + *args, + ) -> bool: + """Handle check results, raising SchemaError on check failure. + + :param check_index: index of check in the schema component check list. + :param check: Check object used to validate pyspark object. + :param check_args: arguments to pass into check object. + :returns: True if check results pass or check.raise_warning=True, otherwise + False. 
+ """ + + check_result = check(check_obj, *args) + if not check_result.check_passed: + if check_result.failure_cases is None: + # encode scalar False values explicitly + failure_cases = scalar_failure_case(check_result.check_passed) + error_msg = format_generic_error_message(schema, check) + + # raise a warning without exiting if the check is specified to do so + if check.raise_warning: + warnings.warn(error_msg, UserWarning) + return True + + raise SchemaError( + schema, + check_obj, + error_msg, + failure_cases=failure_cases, + check=check, + check_index=check_index, + check_output=check_result.check_output, + ) + return check_result.check_passed + + def failure_cases_metadata( + self, + schema_name: str, + schema_errors: List[Dict[str, Any]], + ) -> FailureCaseMetadata: + """Create failure cases metadata required for SchemaErrors exception.""" + + return FailureCaseMetadata( + failure_cases=None, + message=schema_errors, + error_counts=None, + ) diff --git a/pandera/backends/pyspark/builtin_checks.py b/pandera/backends/pyspark/builtin_checks.py new file mode 100644 index 000000000..de31bb0e5 --- /dev/null +++ b/pandera/backends/pyspark/builtin_checks.py @@ -0,0 +1,326 @@ +"""PySpark implementation of built-in checks""" + +import re +from typing import Any, Iterable, TypeVar, Union + +import pyspark.sql.types as pst +from pyspark.sql.functions import col + +import pandera.strategies as st +from pandera.api.extensions import register_builtin_check +from pandera.api.pyspark.types import PysparkDataframeColumnObject +from pandera.backends.pyspark.utils import convert_to_list +from pandera.backends.pyspark.decorators import register_input_datatypes + +T = TypeVar("T") +ALL_NUMERIC_TYPE = [ + pst.LongType, + pst.IntegerType, + pst.ByteType, + pst.ShortType, + pst.DoubleType, + pst.DecimalType, + pst.FloatType, +] +ALL_DATE_TYPE = [pst.DateType, pst.TimestampType] +BOLEAN_TYPE = pst.BooleanType +BINARY_TYPE = pst.BinaryType +STRING_TYPE = pst.StringType +DAYTIMEINTERVAL_TYPE = pst.DayTimeIntervalType + + +@register_builtin_check( + aliases=["eq"], + error="equal_to({value})", +) +@register_input_datatypes( + acceptable_datatypes=convert_to_list( + ALL_NUMERIC_TYPE, ALL_DATE_TYPE, STRING_TYPE, BINARY_TYPE, BOLEAN_TYPE + ) +) +def equal_to(data: PysparkDataframeColumnObject, value: Any) -> bool: + """Ensure all elements of a data container equal a certain value. + + :param data: PysparkDataframeColumnObject column object which is a contains dataframe and column name to do the check + :param value: values in this DataFrame data structure must be + equal to this value. + """ + cond = col(data.column_name) == value + return data.dataframe.filter(~cond).limit(1).count() == 0 + + +@register_builtin_check( + aliases=["ne"], + strategy=st.ne_strategy, + error="not_equal_to({value})", +) +@register_input_datatypes( + acceptable_datatypes=convert_to_list( + ALL_NUMERIC_TYPE, ALL_DATE_TYPE, STRING_TYPE, BINARY_TYPE, BOLEAN_TYPE + ) +) +def not_equal_to(data: PysparkDataframeColumnObject, value: Any) -> bool: + """Ensure no elements of a data container equals a certain value. 
+ :param data: PysparkDataframeColumnObject column object which is a contains dataframe and column name to do the check + :param value: This value must not occur in the checked + """ + cond = col(data.column_name) != value + return data.dataframe.filter(~cond).limit(1).count() == 0 + + +@register_builtin_check( + aliases=["gt"], + error="greater_than({min_value})", +) +@register_input_datatypes( + acceptable_datatypes=convert_to_list(ALL_NUMERIC_TYPE, ALL_DATE_TYPE) +) +def greater_than(data: PysparkDataframeColumnObject, min_value: Any) -> bool: + """ + Ensure values of a data container are strictly greater than a minimum + value. + :param data: PysparkDataframeColumnObject column object which is a contains dataframe and column name to do the check + :param min_value: Lower bound to be exceeded. + """ + cond = col(data.column_name) > min_value + return data.dataframe.filter(~cond).limit(1).count() == 0 + + +@register_builtin_check( + aliases=["ge"], + strategy=st.ge_strategy, + error="greater_than_or_equal_to({min_value})", +) +@register_input_datatypes( + acceptable_datatypes=convert_to_list(ALL_NUMERIC_TYPE, ALL_DATE_TYPE) +) +def greater_than_or_equal_to( + data: PysparkDataframeColumnObject, min_value: Any +) -> bool: + """Ensure all values are greater or equal a certain value. + :param data: PysparkDataframeColumnObject column object which is a contains dataframe and column name to do the check + :param min_value: Allowed minimum value for values of a series. Must be + a type comparable to the dtype of the column datatype of pyspark + """ + cond = col(data.column_name) >= min_value + return data.dataframe.filter(~cond).limit(1).count() == 0 + + +@register_builtin_check( + aliases=["lt"], + strategy=st.lt_strategy, + error="less_than({max_value})", +) +@register_input_datatypes( + acceptable_datatypes=convert_to_list(ALL_NUMERIC_TYPE, ALL_DATE_TYPE) +) +def less_than(data: PysparkDataframeColumnObject, max_value: Any) -> bool: + """Ensure values of a series are strictly below a maximum value. + + :param data: PysparkDataframeColumnObject column object which is a contains dataframe and column name to do the check + :param max_value: All elements of a series must be strictly smaller + than this. Must be a type comparable to the dtype of the column datatype of pyspark + """ + if max_value is None: + raise ValueError("max_value must not be None") + cond = col(data.column_name) < max_value + return data.dataframe.filter(~cond).limit(1).count() == 0 + + +@register_builtin_check( + aliases=["le"], + strategy=st.le_strategy, + error="less_than_or_equal_to({max_value})", +) +@register_input_datatypes( + acceptable_datatypes=convert_to_list(ALL_NUMERIC_TYPE, ALL_DATE_TYPE) +) +def less_than_or_equal_to( + data: PysparkDataframeColumnObject, max_value: Any +) -> bool: + """Ensure values of a series are strictly below a maximum value. + :param data: PysparkDataframeColumnObject column object which is a contains dataframe and column name to do the check + :param max_value: Upper bound not to be exceeded. 
Must be + a type comparable to the dtype of the column datatype of pyspark + """ + if max_value is None: + raise ValueError("max_value must not be None") + cond = col(data.column_name) <= max_value + return data.dataframe.filter(~cond).limit(1).count() == 0 + + +@register_builtin_check( + aliases=["between"], + strategy=st.in_range_strategy, + error="in_range({min_value}, {max_value})", +) +@register_input_datatypes( + acceptable_datatypes=convert_to_list(ALL_NUMERIC_TYPE, ALL_DATE_TYPE) +) +def in_range( + data: PysparkDataframeColumnObject, + min_value: T, + max_value: T, + include_min: bool = True, + include_max: bool = True, +): + """Ensure all values of a column are within an interval. + + Both endpoints must be a type comparable to the dtype of the + :class:`pyspark.sql.function.col` to be validated. + + :param data: PysparkDataframeColumnObject column object which is a contains dataframe and column name to do the check + :param min_value: Left / lower endpoint of the interval. + :param max_value: Right / upper endpoint of the interval. Must not be + smaller than min_value. + :param include_min: Defines whether min_value is also an allowed value + (the default) or whether all values must be strictly greater than + min_value. + :param include_max: Defines whether min_value is also an allowed value + (the default) or whether all values must be strictly smaller than + max_value. + """ + # Using functions from operator module to keep conditions out of the + # closure + cond_right = ( + col(data.column_name) >= min_value + if include_min + else col(data.column_name) > min_value + ) + cond_left = ( + col(data.column_name) <= max_value + if include_max + else col(data.column_name) < max_value + ) + return data.dataframe.filter(~(cond_right & cond_left)).limit(1).count() == 0 # type: ignore + + +@register_builtin_check( + strategy=st.isin_strategy, + error="isin({allowed_values})", +) +@register_input_datatypes( + acceptable_datatypes=convert_to_list( + ALL_NUMERIC_TYPE, ALL_DATE_TYPE, STRING_TYPE, BINARY_TYPE + ) +) +def isin(data: PysparkDataframeColumnObject, allowed_values: Iterable) -> bool: + """Ensure only allowed values occur within a series. + + Remember it can be a compute intensive check on large dataset. So, use it with caution. + + This checks whether all elements of a :class:`pyspark.sql.function.col` + are part of the set of elements of allowed values. If allowed + values is a string, the set of elements consists of all distinct + characters of the string. Thus only single characters which occur + in allowed_values at least once can meet this condition. If you + want to check for substrings use :meth:`Check.str_contains`. + + :param data: PysparkDataframeColumnObject column object which is a contains dataframe and column name to do the check + :param allowed_values: The set of allowed values. May be any iterable. + :param kwargs: key-word arguments passed into the `Check` initializer. + """ + return ( + data.dataframe.filter( + ~col(data.column_name).isin(list(allowed_values)) + ) + .limit(1) + .count() + == 0 + ) + + +@register_builtin_check( + strategy=st.notin_strategy, + error="notin({forbidden_values})", +) +@register_input_datatypes( + acceptable_datatypes=convert_to_list( + ALL_NUMERIC_TYPE, ALL_DATE_TYPE, STRING_TYPE, BINARY_TYPE + ) +) +def notin( + data: PysparkDataframeColumnObject, forbidden_values: Iterable +) -> bool: + """Ensure some defined values don't occur within a series. + + Remember it can be a compute intensive check on large dataset. 
So, use it with caution. + + Like :meth:`Check.isin` this check operates on single characters if + it is applied on strings. If forbidden_values is a string, it is understood + as set of prohibited characters. Any string of length > 1 can't be in it by + design. + + :param data: PysparkDataframeColumnObject column object which is a contains dataframe and column name to do the check + :param forbidden_values: The set of values which should not occur. May + be any iterable. + :param raise_warning: if True, check raises UserWarning instead of + SchemaError on validation. + """ + return ( + data.dataframe.filter( + col(data.column_name).isin(list(forbidden_values)) + ) + .limit(1) + .count() + == 0 + ) + + +@register_builtin_check( + strategy=st.str_contains_strategy, + error="str_contains('{pattern}')", +) +@register_input_datatypes(acceptable_datatypes=convert_to_list(STRING_TYPE)) +def str_contains( + data: PysparkDataframeColumnObject, pattern: Union[str, re.Pattern] +) -> bool: + """Ensure that a pattern can be found within each row. + + Remember it can be a compute intensive check on large dataset. So, use it with caution. + + :param data: PysparkDataframeColumnObject column object which is a contains dataframe and column name to do the check + :param pattern: Regular expression pattern to use for searching + :param kwargs: key-word arguments passed into the `Check` initializer. + """ + + return ( + data.dataframe.filter(~col(data.column_name).rlike(pattern.pattern)) + .limit(1) + .count() + == 0 + ) + + +@register_builtin_check( + error="str_startswith('{string}')", +) +@register_input_datatypes(acceptable_datatypes=convert_to_list(STRING_TYPE)) +def str_startswith(data: PysparkDataframeColumnObject, string: str) -> bool: + """Ensure that all values start with a certain string. + + Remember it can be a compute intensive check on large dataset. So, use it with caution. + + :param data: PysparkDataframeColumnObject column object which is a contains dataframe and column name to do the check + :param string: String all values should start with + :param kwargs: key-word arguments passed into the `Check` initializer. + """ + cond = col(data.column_name).startswith(string) + return data.dataframe.filter(~cond).limit(1).count() == 0 + + +@register_builtin_check( + strategy=st.str_endswith_strategy, error="str_endswith('{string}')" +) +@register_input_datatypes(acceptable_datatypes=convert_to_list(STRING_TYPE)) +def str_endswith(data: PysparkDataframeColumnObject, string: str) -> bool: + """Ensure that all values end with a certain string. + + Remember it can be a compute intensive check on large dataset. So, use it with caution. + + :param data: PysparkDataframeColumnObject column object which is a contains dataframe and column name to do the check + :param string: String all values should end with + :param kwargs: key-word arguments passed into the `Check` initializer. 
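+
+    At the schema level this check is normally attached through the ``Check``
+    API rather than called directly, e.g. (the suffix is illustrative)::
+
+        Check.str_endswith("_id")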
+ """ + cond = col(data.column_name).endswith(string) + return data.dataframe.filter(~cond).limit(1).count() == 0 diff --git a/pandera/backends/pyspark/checks.py b/pandera/backends/pyspark/checks.py new file mode 100644 index 000000000..16fba5a7c --- /dev/null +++ b/pandera/backends/pyspark/checks.py @@ -0,0 +1,117 @@ +"""Check backend for pyspark.""" + +from functools import partial +from typing import Any, Dict, List, Optional + +from multimethod import DispatchError, overload +from pyspark.sql import DataFrame + +from pandera.api.base.checks import CheckResult, GroupbyObject +from pandera.api.checks import Check +from pandera.api.pyspark.types import ( + PysparkDataframeColumnObject, + is_bool, + is_table, +) +from pandera.backends.base import BaseCheckBackend + + +class PySparkCheckBackend(BaseCheckBackend): + """Check backend for PySpark.""" + + def __init__(self, check: Check): + """Initializes a check backend object.""" + super().__init__(check) + assert check._check_fn is not None, "Check._check_fn must be set." + self.check = check + self.check_fn = partial(check._check_fn, **check._check_kwargs) + + def groupby(self, check_obj: DataFrame): + """Implements groupby behavior for check object.""" + assert self.check.groupby is not None, "Check.groupby must be set." + if isinstance(self.check.groupby, (str, list)): + return check_obj.groupby(self.check.groupby) + return self.check.groupby(check_obj) + + def query(self, check_obj): + """Implements querying behavior to produce subset of check object.""" + raise NotImplementedError + + def aggregate(self, check_obj): + """Implements aggregation behavior for check object.""" + raise NotImplementedError + + @staticmethod + def _format_groupby_input( + groupby_obj: GroupbyObject, + groups: Optional[List[str]], + ) -> Dict[str, DataFrame]: + pass + + @overload # type: ignore [no-redef] + def preprocess( + self, + check_obj: DataFrame, # type: ignore [valid-type] + key: str, + ) -> DataFrame: + return check_obj + + @overload + def apply(self, check_obj): + """Apply the check function to a check object.""" + raise NotImplementedError + + @overload # type: ignore [no-redef] + def apply(self, check_obj: DataFrame): + return self.check_fn(check_obj) + + @overload # type: ignore [no-redef] + def apply(self, check_obj: is_table): # type: ignore [valid-type] + return self.check_fn(check_obj) + + @overload # type: ignore [no-redef] + def apply(self, check_obj: DataFrame, column_name: str, kwargs: dict): # type: ignore [valid-type] + # kwargs['column_name'] = column_name + # return self.check._check_fn(check_obj, *list(kwargs.values())) + check_obj_and_col_name = PysparkDataframeColumnObject( + check_obj, column_name + ) + return self.check._check_fn(check_obj_and_col_name, **kwargs) + + @overload + def postprocess(self, check_obj, check_output): + """Postprocesses the result of applying the check function.""" + raise TypeError( + f"output type of check_fn not recognized: {type(check_output)}" + ) + + @overload # type: ignore [no-redef] + def postprocess( + self, + check_obj, + check_output: is_bool, # type: ignore [valid-type] + ) -> CheckResult: + """Postprocesses the result of applying the check function.""" + return CheckResult( + check_output=check_output, + check_passed=check_output, + checked_object=check_obj, + failure_cases=None, + ) + + def __call__( + self, + check_obj: DataFrame, + key: Optional[str] = None, + ) -> CheckResult: + check_obj = self.preprocess(check_obj, key) + try: + check_output = self.apply(check_obj, key, 
self.check._check_kwargs) + + except DispatchError as exc: + if exc.__cause__ is not None: + raise exc.__cause__ + raise exc + except TypeError as err: + raise err + return self.postprocess(check_obj, check_output) diff --git a/pandera/backends/pyspark/column.py b/pandera/backends/pyspark/column.py new file mode 100644 index 000000000..5d0894d07 --- /dev/null +++ b/pandera/backends/pyspark/column.py @@ -0,0 +1,241 @@ +"""Pandera array backends.""" + +import traceback +from typing import Iterable, NamedTuple, Optional, cast + +from multimethod import DispatchError +from pyspark.sql import DataFrame +from pyspark.sql.functions import col + +from pandera.api.pyspark.error_handler import ErrorCategory, ErrorHandler +from pandera.backends.pyspark.base import PysparkSchemaBackend +from pandera.backends.pyspark.error_formatters import scalar_failure_case +from pandera.backends.pyspark.decorators import validate_params +from pandera.engines.pyspark_engine import Engine +from pandera.errors import ParserError, SchemaError, SchemaErrorReason + + +class CoreCheckResult(NamedTuple): + """Namedtuple for holding results of core checks.""" + + check: str + reason_code: SchemaErrorReason + passed: bool + message: Optional[str] = None + failure_cases: Optional[Iterable] = None + + +class ColumnSchemaBackend(PysparkSchemaBackend): + """Backend for pyspark arrays.""" + + def preprocess(self, check_obj, inplace: bool = False): + return check_obj + + @validate_params(params=PysparkSchemaBackend.params, scope="SCHEMA") + def _core_checks(self, check_obj, schema, error_handler): + """This function runs the core checks""" + # run the core checks + for core_check in ( + self.check_name, + self.check_dtype, + self.check_nullable, + ): + check_result = core_check(check_obj, schema) + if not check_result.passed: + error_handler.collect_error( + ErrorCategory.SCHEMA, + check_result.reason_code, + SchemaError( + schema=schema, + data=check_obj, + message=check_result.message, + failure_cases=check_result.failure_cases, + check=check_result.check, + reason_code=check_result.reason_code, + ), + ) + + def validate( + self, + check_obj, + schema, + *, + head: Optional[int] = None, + tail: Optional[int] = None, + sample: Optional[int] = None, + random_state: Optional[int] = None, + lazy: bool = False, + inplace: bool = False, + error_handler: ErrorHandler, + ): + # pylint: disable=too-many-locals + check_obj = self.preprocess(check_obj, inplace) + + if schema.coerce: + try: + check_obj = self.coerce_dtype( + check_obj, schema=schema, error_handler=error_handler + ) + except SchemaError as exc: + error_handler.collect_error( + ErrorCategory.SCHEMA, exc.reason_code, exc + ) + + self._core_checks(check_obj, schema, error_handler) + + self.run_checks(check_obj, schema, error_handler, lazy) + + return check_obj + + @validate_params(params=PysparkSchemaBackend.params, scope="SCHEMA") + def coerce_dtype( + self, + check_obj, + *, + schema=None, + # pylint: disable=unused-argument + ): + """Coerce type of a pyspark.sql.function.col by type specified in dtype. + + :param Dataframe: Pyspark DataFrame + :returns: ``DataFrame`` with coerced data type + """ + assert schema is not None, "The `schema` argument must be provided." 
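+        # nothing to coerce when the column has no target dtype or when
+        # coercion is disabled on the schema component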
+ if schema.dtype is None or not schema.coerce: + return check_obj + + try: + return schema.dtype.try_coerce(check_obj) + except ParserError as exc: + raise SchemaError( + schema=schema, + data=check_obj, + message=( + f"Error while coercing '{schema.name}' to type " + f"{schema.dtype}: {exc}:\n{exc.failure_cases}" + ), + failure_cases=exc.failure_cases, + check=f"coerce_dtype('{schema.dtype}')", + ) from exc + + @validate_params(params=PysparkSchemaBackend.params, scope="SCHEMA") + def check_nullable(self, check_obj: DataFrame, schema): + isna = ( + check_obj.filter(col(schema.name).isNull()).limit(1).count() == 0 + ) + passed = schema.nullable or isna + return CoreCheckResult( + check="not_nullable", + reason_code=SchemaErrorReason.SERIES_CONTAINS_NULLS, + passed=cast(bool, passed), + message=(f"non-nullable column '{schema.name}' contains null"), + failure_cases=scalar_failure_case(schema.name), + ) + + @validate_params(params=PysparkSchemaBackend.params, scope="SCHEMA") + def check_name(self, check_obj: DataFrame, schema): + column_found = not ( + schema.name is None or schema.name not in check_obj.columns + ) + return CoreCheckResult( + check=f"field_name('{schema.name}')", + reason_code=SchemaErrorReason.WRONG_FIELD_NAME + if not column_found + else SchemaErrorReason.NO_ERROR, + passed=column_found, + message=( + f"Expected {type(check_obj)} to have column named: '{schema.name}', " + f"but found columns '{check_obj.columns}'" + if not column_found + else f"column check_name validation passed." + ), + failure_cases=scalar_failure_case(schema.name) + if not column_found + else None, + ) + + @validate_params(params=PysparkSchemaBackend.params, scope="SCHEMA") + def check_dtype(self, check_obj: DataFrame, schema): + passed = True + failure_cases = None + msg = None + + if schema.dtype is not None: + dtype_check_results = schema.dtype.check( + Engine.dtype(check_obj.schema[schema.name].dataType), + ) + + if isinstance(dtype_check_results, bool): + passed = dtype_check_results + failure_cases = scalar_failure_case( + str(Engine.dtype(check_obj.schema[schema.name].dataType)) + ) + msg = ( + f"expected column '{schema.name}' to have type " + f"{schema.dtype}, got {Engine.dtype(check_obj.schema[schema.name].dataType)}" + if not passed + else f"column type matched with expected '{schema.dtype}'" + ) + reason_code = ( + SchemaErrorReason.WRONG_DATATYPE + if not dtype_check_results + else SchemaErrorReason.NO_ERROR + ) + + return CoreCheckResult( + check=f"dtype('{schema.dtype}')", + reason_code=reason_code, + passed=passed, + message=msg, + failure_cases=failure_cases, + ) + + @validate_params(params=PysparkSchemaBackend.params, scope="DATA") + # pylint: disable=unused-argument + def run_checks(self, check_obj, schema, error_handler, lazy): + check_results = [] + for check_index, check in enumerate(schema.checks): + check_args = [schema.name] + try: + check_results.append( + self.run_check( + check_obj, + schema, + check, + check_index, + *check_args, + ) + ) + except SchemaError as err: + error_handler.collect_error( + ErrorCategory.DATA, + SchemaErrorReason.DATAFRAME_CHECK, + err, + ) + except Exception as err: # pylint: disable=broad-except + # catch other exceptions that may occur when executing the Check + if isinstance(err, DispatchError): + # if the error was raised by a check registered via + # multimethod, get the underlying __cause__ + err = err.__cause__ + + err_msg = f'"{err.args[0]}"' if len(err.args) > 0 else "" + err_str = f"{err.__class__.__name__}({ err_msg})" + 
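
As a quick illustration of the null-check idiom used by `check_nullable` above (filter to null rows, stop at the first match, pass when none are found), a hedged standalone sketch:

    from pyspark.sql import DataFrame, SparkSession
    from pyspark.sql.functions import col


    def has_no_nulls(df: DataFrame, column_name: str) -> bool:
        """True when the column contains no nulls; limit(1) stops the scan at the first null."""
        return df.filter(col(column_name).isNull()).limit(1).count() == 0


    spark = SparkSession.builder.getOrCreate()
    df = spark.createDataFrame([(1, "x"), (2, None)], ["id", "name"])
    print(has_no_nulls(df, "id"))    # True
    print(has_no_nulls(df, "name"))  # False
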
error_handler.collect_error( + ErrorCategory.DATA, + SchemaErrorReason.CHECK_ERROR, + SchemaError( + schema=schema, + data=check_obj, + message=( + f"Error while executing check function: {err_str}\n" + + traceback.format_exc() + ), + failure_cases=scalar_failure_case(err_str), + check=check, + check_index=check_index, + ), + original_exc=err, + ) + + return check_results diff --git a/pandera/backends/pyspark/components.py b/pandera/backends/pyspark/components.py new file mode 100644 index 000000000..3f3a5f2fb --- /dev/null +++ b/pandera/backends/pyspark/components.py @@ -0,0 +1,163 @@ +"""Backend implementation for pyspark schema components.""" + +import re +import traceback +from copy import copy +from typing import Iterable, Optional + +from pyspark.sql import DataFrame +from pyspark.sql.functions import cast + +from pandera.api.pyspark.error_handler import ErrorCategory, ErrorHandler +from pandera.backends.pyspark.column import ColumnSchemaBackend +from pandera.backends.pyspark.error_formatters import scalar_failure_case +from pandera.backends.pyspark.decorators import validate_params +from pandera.errors import SchemaError, SchemaErrorReason + + +class ColumnBackend(ColumnSchemaBackend): + """Backend implementation for pyspark dataframe columns.""" + + def validate( + self, + check_obj: DataFrame, + schema, + *, + head: Optional[int] = None, + tail: Optional[int] = None, + sample: Optional[int] = None, + random_state: Optional[int] = None, + lazy: bool = False, + inplace: bool = False, + error_handler: ErrorHandler, + ) -> DataFrame: + """Validation backend implementation for pyspark dataframe columns..""" + + if schema.name is None: + raise SchemaError( + schema, + check_obj, + "column name is set to None. Pass the ``name` argument when " + "initializing a Column object, or use the ``set_name`` " + "method.", + ) + + def validate_column(check_obj, column_name): + try: + # pylint: disable=super-with-arguments + super(ColumnBackend, self).validate( + check_obj, + copy(schema).set_name(column_name), + head=head, + tail=tail, + sample=sample, + random_state=random_state, + lazy=lazy, + inplace=inplace, + error_handler=error_handler, + ) + + except SchemaError as err: + error_handler.collect_error( + ErrorCategory.DATA, err.reason_code, err + ) + + column_keys_to_check = ( + self.get_regex_columns(schema, check_obj.columns, check_obj) + if schema.regex + else [schema.name] + ) + + for column_name in column_keys_to_check: + if schema.coerce: + check_obj = self.coerce_dtype( + check_obj, + schema=schema, + error_handler=error_handler, + ) + validate_column(check_obj, column_name) + + return check_obj + + def get_regex_columns(self, schema, columns) -> Iterable: + """Get matching column names based on regex column name pattern. + + :param schema: schema specification to use + :param columns: columns to regex pattern match + :returns: matchin columns + """ + pattern = re.compile(schema.name) + column_keys_to_check = [ + col_name for col_name in columns if pattern.match(col_name) + ] + if len(column_keys_to_check) == 0: + raise SchemaError( + schema=schema, + data=columns, + message=( + f"Column regex name='{schema.name}' did not match any " + "columns in the dataframe. 
Update the regex pattern so " + f"that it matches at least one column:\n{columns.tolist()}", + ), + failure_cases=scalar_failure_case(str(columns)), + check=f"no_regex_column_match('{schema.name}')", + ) + + return column_keys_to_check + + @validate_params(params=ColumnSchemaBackend.params, scope="SCHEMA") + def coerce_dtype( + self, + check_obj: DataFrame, + *, + schema=None, + ) -> DataFrame: + """Coerce dtype of a column, handling duplicate column names.""" + # pylint: disable=super-with-arguments + # pylint: disable=fixme + + check_obj = check_obj.withColumn(schema.name, cast(schema.dtype)) + + return check_obj + + @validate_params(params=ColumnSchemaBackend.params, scope="DATA") + def run_checks(self, check_obj, schema, error_handler, lazy): + check_results = [] + for check_index, check in enumerate(schema.checks): + check_args = [schema.name] + try: + check_results.append( + self.run_check( + check_obj, schema, check, check_index, *check_args + ) + ) + except SchemaError as err: + error_handler.collect_error( + type=ErrorCategory.DATA, + reason_code=SchemaErrorReason.DATAFRAME_CHECK, + schema_error=err, + ) + except TypeError as err: + raise err + except Exception as err: # pylint: disable=broad-except + # catch other exceptions that may occur when executing the Check + err_msg = f'"{err.args[0]}"' if len(err.args) > 0 else "" + err_str = f"{err.__class__.__name__}({ err_msg})" + + error_handler.collect_error( + type=ErrorCategory.DATA, + reason_code=SchemaErrorReason.CHECK_ERROR, + schema_error=SchemaError( + schema=schema, + data=check_obj, + message=( + f"Error while executing check function: {err_str}\n" + + traceback.format_exc() + ), + failure_cases=scalar_failure_case(err_str), + check=check, + check_index=check_index, + ), + original_exc=err, + ) + return check_results diff --git a/pandera/backends/pyspark/container.py b/pandera/backends/pyspark/container.py new file mode 100644 index 000000000..7429c53b5 --- /dev/null +++ b/pandera/backends/pyspark/container.py @@ -0,0 +1,546 @@ +"""Pyspark Parsing, Validation, and Error Reporting Backends.""" + +import copy +import traceback +from typing import Any, List, Optional + +from pyspark.sql import DataFrame +from pyspark.sql.functions import col + +from pandera.api.pyspark.error_handler import ErrorCategory, ErrorHandler +from pandera.api.pyspark.types import is_table +from pandera.backends.pyspark.base import ColumnInfo, PysparkSchemaBackend +from pandera.backends.pyspark.error_formatters import scalar_failure_case +from pandera.backends.pyspark.decorators import validate_params + +from pandera.errors import ( + ParserError, + SchemaDefinitionError, + SchemaError, + SchemaErrorReason, + SchemaErrors, +) +import warnings + + +class DataFrameSchemaBackend(PysparkSchemaBackend): + """Backend for pyspark DataFrameSchema.""" + + def preprocess(self, check_obj: DataFrame, inplace: bool = False): + """Preprocesses a check object before applying check functions.""" + return check_obj + + @validate_params(params=PysparkSchemaBackend.params, scope="SCHEMA") + def _column_checks( + self, + check_obj: DataFrame, + schema, + column_info: ColumnInfo, + error_handler: ErrorHandler, + ): + """run the checks related to columns presence, uniqueness and filter column if neccesary""" + + # check the container metadata, e.g. 
field names + try: + self.check_column_names_are_unique(check_obj, schema) + except SchemaError as exc: + error_handler.collect_error( + type=ErrorCategory.SCHEMA, + reason_code=exc.reason_code, + schema_error=exc, + ) + + try: + self.check_column_presence(check_obj, schema, column_info) + except SchemaErrors as exc: + for schema_error in exc.schema_errors: + error_handler.collect_error( + type=ErrorCategory.SCHEMA, + reason_code=schema_error["reason_code"], + schema_error=schema_error["error"], + ) + + # strictness check and filter + try: + check_obj = self.strict_filter_columns( + check_obj, schema, column_info, error_handler + ) + except SchemaError as exc: + error_handler.collect_error( + type=ErrorCategory.SCHEMA, + reason_code=exc.reason_code, + schema_error=exc, + ) + # try to coerce datatypes + check_obj = self.coerce_dtype( + check_obj, + schema=schema, + error_handler=error_handler, + ) + + return check_obj + + def validate( + self, + check_obj: DataFrame, + schema, + *, + head: Optional[int] = None, + tail: Optional[int] = None, + sample: Optional[int] = None, + random_state: Optional[int] = None, + lazy: bool = False, + inplace: bool = False, + error_handler: ErrorHandler = None, + ): + """ + Parse and validate a check object, returning type-coerced and validated + object. + """ + if self.params["VALIDATION"] == "DISABLE": + warnings.warn( + "Skipping the validation checks as validation is disabled" + ) + return check_obj + if not is_table(check_obj): + raise TypeError( + f"expected a pyspark DataFrame, got {type(check_obj)}" + ) + + check_obj = self.preprocess(check_obj, inplace=inplace) + if hasattr(check_obj, "pandera"): + check_obj = check_obj.pandera.add_schema(schema) + column_info = self.collect_column_info(check_obj, schema, lazy) + + # validate the columns of the dataframe + check_obj = self._column_checks( + check_obj, schema, column_info, error_handler + ) + + # collect schema components and prepare check object to be validated + schema_components = self.collect_schema_components( + check_obj, schema, column_info + ) + check_obj_subsample = self.subsample(check_obj, sample, random_state) + try: + # TODO: need to create apply at column level + self.run_schema_component_checks( + check_obj_subsample, schema_components, lazy, error_handler + ) + except SchemaError as exc: + error_handler.collect_error( + type=ErrorCategory.SCHEMA, + reason_code=exc.reason_code, + schema_error=exc, + ) + try: + self.run_checks(check_obj_subsample, schema, error_handler) + except SchemaError as exc: + error_handler.collect_error( + type=ErrorCategory.DATA, + reason_code=exc.reason_code, + schema_error=exc, + ) + + error_dicts = {} + + if error_handler.collected_errors: + error_dicts = error_handler.summarize(schema=schema) + + check_obj.pandera.errors = error_dicts + return check_obj + + def run_schema_component_checks( + self, + check_obj: DataFrame, + schema_components: List, + lazy: bool, + error_handler: ErrorHandler, + ): + """Run checks for all schema components.""" + check_results = [] + # schema-component-level checks + for schema_component in schema_components: + try: + result = schema_component.validate( + check_obj=check_obj, + lazy=lazy, + inplace=True, + error_handler=error_handler, + ) + check_results.append(is_table(result)) + except SchemaError as err: + error_handler.collect_error( + ErrorCategory.SCHEMA, + err.reason_code, + err, + ) + assert all(check_results) + + @validate_params(params=PysparkSchemaBackend.params, scope="DATA") + def run_checks(self, check_obj: 
DataFrame, schema, error_handler): + """Run a list of checks on the check object.""" + # dataframe-level checks + check_results = [] + for check_index, check in enumerate( + schema.checks + ): # schama.checks is null + try: + check_results.append( + self.run_check(check_obj, schema, check, check_index) + ) + except SchemaError as err: + error_handler.collect_error( + ErrorCategory.DATA, err.reason_code, err + ) + except SchemaDefinitionError: + raise + except Exception as err: # pylint: disable=broad-except + # catch other exceptions that may occur when executing the check + err_msg = f'"{err.args[0]}"' if len(err.args) > 0 else "" + err_str = f"{err.__class__.__name__}({ err_msg})" + msg = ( + f"Error while executing check function: {err_str}\n" + + traceback.format_exc() + ) + + error_handler.collect_error( + ErrorCategory.DATA, + SchemaErrorReason.CHECK_ERROR, + SchemaError( + self, + check_obj, + msg, + failure_cases=scalar_failure_case(err_str), + check=check, + check_index=check_index, + ), + original_exc=err, + ) + + def collect_column_info( + self, + check_obj: DataFrame, + schema, + lazy: bool, + ) -> ColumnInfo: + """Collect column metadata.""" + column_names: List[Any] = [] + absent_column_names: List[Any] = [] + lazy_exclude_column_names: List[Any] = [] + + for col_name, col_schema in schema.columns.items(): + if ( + not col_schema.regex + and col_name not in check_obj.columns + and col_schema.required + ): + absent_column_names.append(col_name) + if lazy: + # NOTE: remove this since we can just use + # absent_column_names in the collect_schema_components + # method + lazy_exclude_column_names.append(col_name) + + if col_schema.regex: + try: + column_names.extend( + col_schema.BACKEND.get_regex_columns( + col_schema, check_obj.columns + ) + ) + except SchemaError: + pass + elif col_name in check_obj.columns: + column_names.append(col_name) + + # drop adjacent duplicated column names + + destuttered_column_names = list(set(check_obj.columns)) + + return ColumnInfo( + sorted_column_names=dict.fromkeys(column_names), + expanded_column_names=frozenset(column_names), + destuttered_column_names=destuttered_column_names, + absent_column_names=absent_column_names, + lazy_exclude_column_names=lazy_exclude_column_names, + ) + + def collect_schema_components( + self, + check_obj: DataFrame, + schema, + column_info: ColumnInfo, + ): + """Collects all schema components to use for validation.""" + schema_components = [] + for col_name, col in schema.columns.items(): + if ( + col.required or col_name in check_obj + ) and col_name not in column_info.lazy_exclude_column_names: + col = copy.deepcopy(col) + if schema.dtype is not None: + # override column dtype with dataframe dtype + col.dtype = schema.dtype + + # disable coercion at the schema component level since the + # dataframe-level schema already coerced it. 
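
Putting the container backend together, here is a hedged end-to-end sketch of how this validation flow is meant to be driven from user code. The exact constructor signatures are assumed from the public `pandera.pyspark` API introduced later in this patch; as `validate` above shows, collected errors are surfaced on the `pandera` accessor rather than raised eagerly.

    import pyspark.sql.types as T
    from pyspark.sql import SparkSession

    import pandera.pyspark as pa

    schema = pa.DataFrameSchema(
        columns={
            "product": pa.Column(T.StringType()),
            # Spark infers LongType for Python ints, so declare the column as long.
            "price": pa.Column(T.LongType(), pa.Check.gt(0)),
        },
    )

    spark = SparkSession.builder.getOrCreate()
    df = spark.createDataFrame([("laptop", 1500), ("mouse", 25)], ["product", "price"])

    # Validation returns the (possibly coerced) dataframe; any collected errors
    # are summarized on the pandera accessor instead of being raised eagerly.
    validated = schema.validate(df)
    print(validated.pandera.errors)
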
+ col.coerce = False + schema_components.append(col) + + return schema_components + + ########### + # Parsers # + ########### + @validate_params(params=PysparkSchemaBackend.params, scope="SCHEMA") + def strict_filter_columns( + self, + check_obj: DataFrame, + schema, + column_info: ColumnInfo, + error_handler: ErrorHandler, + ): + """Filters columns that aren't specified in the schema.""" + # dataframe strictness check makes sure all columns in the dataframe + # are specified in the dataframe schema + if not (schema.strict or schema.ordered): + return check_obj + + filter_out_columns = [] + + sorted_column_names = iter(column_info.sorted_column_names) + for column in column_info.destuttered_column_names: + is_schema_col = column in column_info.expanded_column_names + if schema.strict is True and not is_schema_col: + error_handler.collect_error( + ErrorCategory.SCHEMA, + SchemaErrorReason.COLUMN_NOT_IN_SCHEMA, + SchemaError( + schema=schema, + data=check_obj, + message=( + f"column '{column}' not in {schema.__class__.__name__}" + f" {schema.columns}" + ), + failure_cases=scalar_failure_case(column), + check="column_in_schema", + reason_code=SchemaErrorReason.COLUMN_NOT_IN_SCHEMA, + ), + ) + if schema.strict == "filter" and not is_schema_col: + filter_out_columns.append(column) + if schema.ordered and is_schema_col: + try: + next_ordered_col = next(sorted_column_names) + except StopIteration: + pass + if next_ordered_col != column: + error_handler.collect_error( + ErrorCategory.SCHEMA, + SchemaErrorReason.COLUMN_NOT_ORDERED, + SchemaError( + schema=schema, + data=check_obj, + message=f"column '{column}' out-of-order", + failure_cases=scalar_failure_case(column), + check="column_ordered", + reason_code=SchemaErrorReason.COLUMN_NOT_ORDERED, + ), + ) + + if schema.strict == "filter": + check_obj = check_obj.drop(*filter_out_columns) + + return check_obj + + @validate_params(params=PysparkSchemaBackend.params, scope="SCHEMA") + def coerce_dtype( + self, + check_obj: DataFrame, + *, + schema=None, + error_handler: ErrorHandler = None, + ): + """Coerces check object to the expected type.""" + assert schema is not None, "The `schema` argument must be provided." + + if not ( + schema.coerce or any(col.coerce for col in schema.columns.values()) + ): + return check_obj + + try: + check_obj = self._coerce_dtype(check_obj, schema) + + except SchemaErrors as err: + for schema_error_dict in err.schema_errors: + if not error_handler.lazy: + # raise the first error immediately if not doing lazy + # validation + raise schema_error_dict["error"] + error_handler.collect_error( + ErrorCategory.DTYPE_COERCION, + SchemaErrorReason.CHECK_ERROR, + schema_error_dict["error"], + ) + except SchemaError as err: + if not error_handler.lazy: + raise err + error_handler.collect_error( + ErrorCategory.SCHEMA, err.reason_code, err + ) + + return check_obj + + def _coerce_dtype( + self, + obj: DataFrame, + schema, + ) -> DataFrame: + """Coerce dataframe to the type specified in dtype. + + :param obj: dataframe to coerce. + :returns: dataframe with coerced dtypes + """ + # NOTE: clean up the error handling! + error_handler = ErrorHandler(lazy=True) + + def _coerce_df_dtype(obj: DataFrame) -> DataFrame: + if schema.dtype is None: + raise ValueError( + "dtype argument is None. 
Must specify this argument " + "to coerce dtype" + ) + + try: + return schema.dtype.try_coerce(obj) + except ParserError as exc: + raise SchemaError( + schema=schema, + data=obj, + message=( + f"Error while coercing '{schema.name}' to type " + f"{schema.dtype}: {exc}\n{exc.failure_cases}" + ), + failure_cases=exc.failure_cases, + check=f"coerce_dtype('{schema.dtype}')", + ) from exc + + def _try_coercion(obj, colname, col_schema): + try: + schema = obj.pandera.schema + + obj = obj.withColumn( + colname, col(colname).cast(col_schema.dtype.type) + ) + obj.pandera.add_schema(schema) + return obj + + except SchemaError as exc: + error_handler.collect_error( + ErrorCategory.DTYPE_COERCION, exc.reason_code, exc + ) + return obj + + for colname, col_schema in schema.columns.items(): + if col_schema.regex: + try: + matched_columns = col_schema.BACKEND.get_regex_columns( + col_schema, obj.columns + ) + except SchemaError: + matched_columns = None + + for matched_colname in matched_columns: + if col_schema.coerce or schema.coerce: + obj = _try_coercion( + obj, + matched_colname, + col_schema + # col_schema.coerce_dtype, obj[matched_colname] + ) + + elif ( + (col_schema.coerce or schema.coerce) + and schema.dtype is None + and colname in obj.columns + ): + _col_schema = copy.deepcopy(col_schema) + _col_schema.coerce = True + obj = _try_coercion(obj, colname, col_schema) + + if schema.dtype is not None: + obj = _try_coercion(_coerce_df_dtype, obj) + + if error_handler.collected_errors: + raise SchemaErrors( + schema=schema, + schema_errors=error_handler.collected_errors, + data=obj, + ) + + return obj + + ########## + # Checks # + ########## + + @validate_params(params=PysparkSchemaBackend.params, scope="SCHEMA") + def check_column_names_are_unique(self, check_obj: DataFrame, schema): + """Check for column name uniquness.""" + if not schema.unique_column_names: + return + column_count_dict = {} + failed = [] + for column_name in check_obj.columns: + if column_count_dict.get(column_name): + # Insert to the list only once + if column_count_dict[column_name] == 1: + failed.append(column_name) + column_count_dict[column_name] += 1 + + else: + column_count_dict[column_name] = 0 + + if failed: + raise SchemaError( + schema=schema, + data=check_obj, + message=( + "dataframe contains multiple columns with label(s): " + f"{failed}" + ), + failure_cases=scalar_failure_case(failed), + check="dataframe_column_labels_unique", + reason_code=SchemaErrorReason.DUPLICATE_COLUMN_LABELS, + ) + + @validate_params(params=PysparkSchemaBackend.params, scope="SCHEMA") + def check_column_presence( + self, check_obj: DataFrame, schema, column_info: ColumnInfo + ): + """Check for presence of specified columns in the data object.""" + if column_info.absent_column_names: + reason_code = SchemaErrorReason.COLUMN_NOT_IN_DATAFRAME + raise SchemaErrors( + schema=schema, + schema_errors=[ + { + "reason_code": reason_code, + "error": SchemaError( + schema=schema, + data=check_obj, + message=( + f"column '{colname}' not in dataframe" + f"\n{check_obj.head()}" + ), + failure_cases=scalar_failure_case(colname), + check="column_in_dataframe", + reason_code=reason_code, + ), + } + for colname in column_info.absent_column_names + ], + data=check_obj, + ) diff --git a/pandera/backends/pyspark/decorators.py b/pandera/backends/pyspark/decorators.py new file mode 100644 index 000000000..31aa19f5d --- /dev/null +++ b/pandera/backends/pyspark/decorators.py @@ -0,0 +1,96 @@ +import warnings +import functools + +import pyspark.sql + +from 
pandera.errors import SchemaError
+from typing import List, Type
+from pandera.api.pyspark.types import PysparkDefaultTypes
+
+
+def register_input_datatypes(
+    acceptable_datatypes: List[Type[PysparkDefaultTypes]] = None,
+):
+    """
+    This decorator registers the input datatypes accepted by a check.
+    An error is raised if the input type is not in the list of acceptable types.
+
+    :param acceptable_datatypes: list of pyspark datatypes for which the function is applicable
+    """
+
+    def wrapper(func):
+        @functools.wraps(func)
+        def _wrapper(*args, **kwargs):
+            # Get the pyspark object from arguments
+            pyspark_object = args[0]
+            validation_df = pyspark_object.dataframe
+            validation_column = pyspark_object.column_name
+            pandera_schema_datatype = validation_df.pandera.schema.get_dtypes(
+                validation_df
+            )[validation_column].type.typeName
+            # Compare type names so parameterized values are ignored:
+            # only the type is checked, not its parameters
+            valid_datatypes = [i.typeName for i in acceptable_datatypes]
+            current_datatype = (
+                validation_df.select(validation_column)
+                .schema[0]
+                .dataType.typeName
+            )
+            if pandera_schema_datatype != current_datatype:
+                raise SchemaError(
+                    schema=validation_df.pandera.schema,
+                    data=validation_df,
+                    message=f'The check with name "{func.__name__}" only accepts the following datatypes \n'
+                    f"{[i.typeName() for i in acceptable_datatypes]} but got {current_datatype()} from the input. \n"
+                    f" This error is usually caused by the actual data type differing from the type defined in"
+                    f" the pandera schema",
+                )
+            if current_datatype in valid_datatypes:
+                return func(*args, **kwargs)
+            else:
+                raise TypeError(
+                    f'The check with name "{func.__name__}" only supports the following datatypes '
+                    f'{[i.typeName() for i in acceptable_datatypes]} and not the given "{current_datatype()}" '
+                    f"datatype"
+                )
+
+        return _wrapper
+
+    return wrapper
+
+
+def validate_params(params, scope):
+    def _wrapper(func):
+        @functools.wraps(func)
+        def wrapper(self, *args, **kwargs):
+            if scope == "SCHEMA":
+                if (params["DEPTH"] == "SCHEMA_AND_DATA") or (
+                    params["DEPTH"] == "SCHEMA_ONLY"
+                ):
+                    return func(self, *args, **kwargs)
+                else:
+                    warnings.warn(
+                        "Skipping execution of function as parameters are set to DATA_ONLY "
+                    )
+                    if kwargs:
+                        for key, value in kwargs.items():
+                            if isinstance(value, pyspark.sql.DataFrame):
+                                return value
+                    if args:
+                        for value in args:
+                            if isinstance(value, pyspark.sql.DataFrame):
+                                return value
+
+            elif scope == "DATA":
+                if (params["DEPTH"] == "SCHEMA_AND_DATA") or (
+                    params["DEPTH"] == "DATA_ONLY"
+                ):
+                    return func(self, *args, **kwargs)
+                else:
+                    warnings.warn(
+                        "Skipping execution of function as parameters are set to SCHEMA_ONLY "
+                    )
+
+        return wrapper
+
+    return _wrapper
diff --git a/pandera/backends/pyspark/error_formatters.py b/pandera/backends/pyspark/error_formatters.py
new file mode 100644
index 000000000..a3586613d
--- /dev/null
+++ b/pandera/backends/pyspark/error_formatters.py
@@ -0,0 +1,26 @@
+"""Make schema error messages human-friendly."""
+
+
+def format_generic_error_message(
+    parent_schema,
+    check,
+) -> str:
+    """Construct an error message when a check validator fails.
+
+    :param parent_schema: class of schema being validated.
+    :param check: check that generated error.
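
To make the DEPTH gating above concrete, the following hedged sketch applies `validate_params` to a toy backend; with DEPTH set to SCHEMA_ONLY the SCHEMA-scoped method runs, while the DATA-scoped one only emits the warning and returns nothing.

    from pandera.backends.pyspark.decorators import validate_params


    class DemoBackend:
        """Toy stand-in for a pyspark schema backend."""

        params = {"VALIDATION": "ENABLE", "DEPTH": "SCHEMA_ONLY"}

        @validate_params(params=params, scope="SCHEMA")
        def schema_checks(self):
            return "schema checks ran"

        @validate_params(params=params, scope="DATA")
        def data_checks(self):
            return "data checks ran"


    backend = DemoBackend()
    print(backend.schema_checks())  # "schema checks ran"
    print(backend.data_checks())    # None, after a "skipping execution" warning
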
+ """ + return f"{parent_schema} failed validation " f"{check.error}" + + +def scalar_failure_case(x) -> dict: + """Construct failure case from a scalar value. + + :param x: a scalar value representing failure case. + :returns: DataFrame used for error reporting with ``SchemaErrors``. + """ + return { + "index": [None], + "failure_case": [x], + } diff --git a/pandera/backends/pyspark/utils.py b/pandera/backends/pyspark/utils.py new file mode 100644 index 000000000..de949f8f4 --- /dev/null +++ b/pandera/backends/pyspark/utils.py @@ -0,0 +1,60 @@ +"""pyspark backend utilities.""" +import yaml +import os +import pandera + + +def convert_to_list(*args): + converted_list = [] + for arg in args: + if isinstance(arg, list): + converted_list.extend(arg) + else: + converted_list.append(arg) + + return converted_list + + +class ConfigParams(dict): + def __init__(self, module_name, config_name): + self.module_name = module_name + self.config_name = config_name + self.config = self.fetch_yaml(self.module_name, self.config_name) + self.validate_params(self.config) + super().__init__(self.config) + + @staticmethod + def fetch_yaml(module_name, config_file): + root_dir = os.path.abspath( + os.path.join(os.path.dirname(pandera.__file__), "..") + ) + path = os.path.join(root_dir, "conf", module_name, config_file) + with open(path) as file: + return yaml.safe_load(file) + + @staticmethod + def validate_params(config): + if not config.get("VALIDATION"): + raise ValueError( + 'Parameter "VALIDATION" not found in config, ensure the parameter value is in upper case' + ) + else: + if config.get("VALIDATION") not in ["ENABLE", "DISABLE"]: + raise ValueError( + "Parameter 'VALIDATION' only supports 'ON' or 'OFF' as valid values." + "Ensure the value is in upper case only" + ) + if not config.get("DEPTH"): + raise ValueError( + 'Parameter "DEPTH" not found in config, ensure the parameter value is in upper case' + ) + else: + if config.get("DEPTH") not in [ + "SCHEMA_ONLY", + "DATA_ONLY", + "SCHEMA_AND_DATA", + ]: + raise ValueError( + "Parameter 'VALIDATION' only supports 'ON' or 'OFF' as valid values." 
+ "Ensure the value is in upper case only" + ) diff --git a/pandera/decorators.py b/pandera/decorators.py index 2ff8f4045..ce2814e9f 100644 --- a/pandera/decorators.py +++ b/pandera/decorators.py @@ -2,8 +2,8 @@ import functools import inspect import sys -import typing import types +import typing from collections import OrderedDict from typing import ( Any, @@ -703,7 +703,6 @@ def validate_inputs( args: Tuple[Any, ...], kwargs: Dict[str, Any], ) -> Tuple[Dict[str, Any], Dict[str, Any]]: - if instance is not None: # If the wrapped function is a method -> add "self" as the first positional arg args = (instance, *args) diff --git a/pandera/dtypes.py b/pandera/dtypes.py index 5a30a5ac6..13354a77b 100644 --- a/pandera/dtypes.py +++ b/pandera/dtypes.py @@ -530,6 +530,14 @@ def __str__(self) -> str: return "timedelta" +@immutable +class Binary(DataType): + """Semantic representation of a delta time data type.""" + + def __str__(self) -> str: + return "binary" + + ############################################################################### # Utilities ############################################################################### @@ -596,6 +604,11 @@ def is_timedelta(pandera_dtype: Union[DataType, Type[DataType]]) -> bool: return is_subdtype(pandera_dtype, Timedelta) +def is_binary(pandera_dtype: Union[DataType, Type[DataType]]) -> bool: + """Return True if :class:`pandera.dtypes.DataType` is a timedelta.""" + return is_subdtype(pandera_dtype, Binary) + + UniqueSettings = Union[ # Report all unique errors except the first Literal["exclude_first"], diff --git a/pandera/engines/engine.py b/pandera/engines/engine.py index ab198f705..75acadded 100644 --- a/pandera/engines/engine.py +++ b/pandera/engines/engine.py @@ -188,6 +188,8 @@ def _wrapper(pandera_dtype_cls: Type[_DataType]) -> Type[_DataType]: ) if equivalents: + # pylint: disable=fixme + # Todo - Need changes to this function to support uninitialised object cls._register_equivalents(pandera_dtype_cls, *equivalents) if "from_parametrized_dtype" in pandera_dtype_cls.__dict__: @@ -211,6 +213,8 @@ def dtype(cls: _EngineType, data_type: Any) -> DataType: data_type, cls._base_pandera_dtypes ): try: + # pylint: disable=fixme + # Todo - check if we can move to validate without initialization return data_type() except (TypeError, AttributeError) as err: raise TypeError( diff --git a/pandera/engines/numpy_engine.py b/pandera/engines/numpy_engine.py index 4f150fbed..e50a635e5 100644 --- a/pandera/engines/numpy_engine.py +++ b/pandera/engines/numpy_engine.py @@ -12,9 +12,9 @@ from pandera import dtypes, errors from pandera.dtypes import immutable -from pandera.system import FLOAT_128_AVAILABLE from pandera.engines import engine, utils from pandera.engines.type_aliases import PandasObject +from pandera.system import FLOAT_128_AVAILABLE @immutable(init=True) @@ -166,7 +166,6 @@ def _build_number_equivalents( @Engine.register_dtype(equivalents=_int_equivalents[64]) @immutable class Int64(DataType, dtypes.Int64): - type = np.dtype("int64") bit_width: int = 64 diff --git a/pandera/engines/pandas_engine.py b/pandera/engines/pandas_engine.py index 35fae1194..6b1ecac4e 100644 --- a/pandera/engines/pandas_engine.py +++ b/pandera/engines/pandas_engine.py @@ -33,7 +33,6 @@ from pandera import dtypes, errors from pandera.dtypes import immutable -from pandera.system import FLOAT_128_AVAILABLE from pandera.engines import engine, numpy_engine, utils from pandera.engines.type_aliases import ( PandasDataType, @@ -41,6 +40,7 @@ PandasObject, ) from pandera.engines.utils 
import pandas_version +from pandera.system import FLOAT_128_AVAILABLE try: import pyarrow # pylint: disable=unused-import @@ -750,7 +750,6 @@ class _BaseDateTime(DataType): @staticmethod def _get_to_datetime_fn(obj: Any) -> Callable: - # NOTE: this is a hack to support pyspark.pandas. This needs to be # thoroughly tested, right now pyspark.pandas returns NA when a # dtype value can't be coerced into the target dtype. diff --git a/pandera/engines/pyspark_engine.py b/pandera/engines/pyspark_engine.py new file mode 100644 index 000000000..a29c03179 --- /dev/null +++ b/pandera/engines/pyspark_engine.py @@ -0,0 +1,568 @@ +"""PySpark engine and data types.""" +# pylint:disable=too-many-ancestors + +# docstrings are inherited +# pylint:disable=missing-class-docstring + +# pylint doesn't know about __init__ generated with dataclass +# pylint:disable=unexpected-keyword-arg,no-value-for-parameter + +import dataclasses +import inspect +import re +import warnings +from typing import Any, Iterable, Union + +import pyspark.sql.types as pst +from pyspark.sql.types import DecimalType + +from pandera import dtypes, errors +from pandera.dtypes import immutable +from pandera.engines import engine +from pandera.engines.engine import Engine +from pandera.engines.type_aliases import PysparkObject + +try: + import pyarrow # pylint:disable=unused-import + + PYARROW_INSTALLED = True +except ImportError: + PYARROW_INSTALLED = False + +try: + from typing import Literal # type: ignore +except ImportError: + from typing_extensions import Literal # type: ignore + +DEFAULT_PYSPARK_PREC = DecimalType().precision +DEFAULT_PYSPARK_SCALE = DecimalType().scale + + +@immutable(init=True) +class DataType(dtypes.DataType): + """Base `DataType` for boxing PySpark data types.""" + + type: Any = dataclasses.field(repr=False, init=False) + """Native pyspark dtype boxed by the data type.""" + + def __init__(self, dtype: Any): + super().__init__() + # Pyspark str() doesnot return equivalent string using the below code to convert the datatype to class + try: + if isinstance(dtype, str): + dtype = eval("pst." 
+ dtype) + except AttributeError: + pass + except TypeError: + pass + + object.__setattr__(self, "type", dtype) + dtype_cls = dtype if inspect.isclass(dtype) else dtype.__class__ + warnings.warn( + f"'{dtype_cls}' support is not guaranteed.\n" + + "Usage Tip: Consider writing a custom " + + "pandera.dtypes.DataType or opening an issue at " + + "https://github.com/pandera-dev/pandera" + ) + + def __post_init__(self): + # this method isn't called if __init__ is defined + object.__setattr__(self, "type", self.type) # pragma: no cover + + def check( + self, + pandera_dtype: dtypes.DataType, + ) -> Union[bool, Iterable[bool]]: + try: + return self.type == pandera_dtype.type + except TypeError: + return False + + def __str__(self) -> str: + return str(self.type) + + def __repr__(self) -> str: + return f"DataType({self})" + + def coerce(self, data_container: PysparkObject) -> PysparkObject: + """Pure coerce without catching exceptions.""" + coerced = data_container.astype(self.type) + if type(data_container).__module__.startswith("modin.pandas"): + # NOTE: this is a hack to enable catching of errors in modin + coerced.__str__() + return coerced + + def try_coerce(self, data_container: PysparkObject) -> PysparkObject: + try: + coerced = self.coerce(data_container) + if type(data_container).__module__.startswith("pyspark.pandas"): + # NOTE: this is a hack to enable catching of errors in modin + coerced.__str__() + except Exception as exc: # pylint:disable=broad-except + if isinstance(exc, errors.ParserError): + raise + else: + type_alias = str(self) + raise errors.ParserError( + f"Could not coerce {type(data_container)} data_container " + f"into type {type_alias}", + failure_cases=None, + ) from exc + + return coerced + + +class Engine( # pylint:disable=too-few-public-methods + metaclass=engine.Engine, + base_pandera_dtypes=(DataType), +): + """PySpark data type engine.""" + + @classmethod + def dtype(cls, data_type: Any) -> dtypes.DataType: + """Convert input into a pyspark-compatible + Pandera :class:`~pandera.dtypes.DataType` object.""" + try: + if isinstance(data_type, str): + regex = r"(\(\d.*?\b\))" + subst = "()" + # You can manually specify the number of replacements by changing the 4th argument + data_type = re.sub(regex, subst, data_type, 0, re.MULTILINE) + return engine.Engine.dtype(cls, data_type) + except TypeError: + raise + + +############################################################################### +# boolean +############################################################################### + + +@Engine.register_dtype( + equivalents=[ + bool, + "bool", + "BooleanType()", + pst.BooleanType(), + pst.BooleanType, + ], +) +@immutable +class Bool(DataType, dtypes.Bool): + """Semantic representation of a :class:`pyspark.sql.types.BooleanType`.""" + + type = pst.BooleanType() + _bool_like = frozenset({True, False}) + + def coerce_value(self, value: Any) -> Any: + """Coerce an value to specified boolean type.""" + if value not in self._bool_like: + raise TypeError( + f"value {value} cannot be coerced to type {self.type}" + ) + return super().coerce_value(value) + + +############################################################################### +# string +############################################################################### + + +@Engine.register_dtype( + equivalents=[str, "str", "string", "StringType()", pst.StringType(), pst.StringType], # type: ignore +) +@immutable +class String(DataType, dtypes.String): # type: ignore + """Semantic representation of a 
:class:`pyspark.sql.types.StringType`.""" + + type = pst.StringType() # type: ignore + + +############################################################################### +# integer +############################################################################### + + +@Engine.register_dtype( + equivalents=[int, "int", "IntegerType()", pst.IntegerType(), pst.IntegerType], # type: ignore +) +@immutable +class Int(DataType, dtypes.Int): # type: ignore + """Semantic representation of a :class:`pyspark.sql.types.IntegerType`.""" + + type = pst.IntegerType() # type: ignore + + +############################################################################### +# float +############################################################################### + + +@Engine.register_dtype( + equivalents=[float, "float", "FloatType()", pst.FloatType(), pst.FloatType], # type: ignore +) +@immutable +class Float(DataType, dtypes.Float): # type: ignore + """Semantic representation of a :class:`pyspark.sql.types.FloatType`.""" + + type = pst.FloatType() # type: ignore + + +############################################################################### +# bigint or long +############################################################################### + + +@Engine.register_dtype( + equivalents=["bigint", "long", "LongType()", pst.LongType(), pst.LongType], # type: ignore +) +@immutable +class BigInt(DataType, dtypes.Int64): # type: ignore + """Semantic representation of a :class:`pyspark.sql.types.LongType`.""" + + type = pst.LongType() # type: ignore + + +############################################################################### +# smallint +############################################################################### + + +@Engine.register_dtype( + equivalents=["smallint", "short", "ShortType()", pst.ShortType(), pst.ShortType], # type: ignore +) +@immutable +class ShortInt(DataType, dtypes.Int16): # type: ignore + """Semantic representation of a :class:`pyspark.sql.types.ShortType`.""" + + type = pst.ShortType() # type: ignore + + +############################################################################### +# tinyint +############################################################################### + + +@Engine.register_dtype( + equivalents=[bytes, "tinyint", "bytes", "ByteType()", pst.ByteType(), pst.ByteType], # type: ignore +) +@immutable +class ByteInt(DataType, dtypes.Int8): # type: ignore + """Semantic representation of a :class:`pyspark.sql.types.ByteType`.""" + + type = pst.ByteType() # type: ignore + + +############################################################################### +# decimal +############################################################################### + + +@Engine.register_dtype( + equivalents=["decimal", "DecimalType()", pst.DecimalType(), pst.DecimalType], # type: ignore +) +@immutable(init=True) +class Decimal(DataType, dtypes.Decimal): # type: ignore + """Semantic representation of a :class:`pyspark.sql.types.DecimalType`.""" + + type: pst.DecimalType = dataclasses.field(default=pst.DecimalType, init=False) # type: ignore[assignment] # noqa + + # precision: int = dataclasses.field(default=DEFAULT_PYSPARK_PREC, init=False) + # scale: int = dataclasses.field(default=DEFAULT_PYSPARK_SCALE, init=False) + def __init__( # pylint:disable=super-init-not-called + self, + precision: int = DEFAULT_PYSPARK_PREC, + scale: int = DEFAULT_PYSPARK_SCALE, + ) -> None: + dtypes.Decimal.__init__(self, precision, scale, None) + object.__setattr__( + self, + "type", + pst.DecimalType(self.precision, 
self.scale), # type: ignore + ) + + def __post_init__(self): + object.__setattr__( + self, + "type", + pst.DecimalType(precision=self.precision, scale=self.scale), + ) + + @classmethod + def from_parametrized_dtype(cls, ps_dtype: pst.DecimalType): + """Convert a :class:`pyspark.sql.types.DecimalType` to + a Pandera :class:`pandera.engines.pyspark_engine.Decimal`.""" + return cls(precision=ps_dtype.precision, scale=ps_dtype.scale) # type: ignore + + def check( + self, + pandera_dtype: dtypes.DataType, + data_container: Any = None, + ) -> Union[bool, Iterable[bool]]: + try: + pandera_dtype = Engine.dtype(pandera_dtype) + except TypeError: + return False + + # attempts to compare pyspark native type if possible + # to let subclass inherit check + # (super will compare that DataType classes are exactly the same) + try: + return ( + (self.type == pandera_dtype.type) + & (self.scale == pandera_dtype.scale) + & (self.precision == pandera_dtype.precision) + ) # or super().check(pandera_dtype) + + except TypeError: + return super().check(pandera_dtype) + + +############################################################################### +# double +############################################################################### + + +@Engine.register_dtype( + equivalents=["double", "DoubleType()", pst.DoubleType(), pst.DoubleType], # type: ignore +) +@immutable +class Double(DataType, dtypes.Float): # type: ignore + """Semantic representation of a :class:`pyspark.sql.types.DoubleType`.""" + + type = pst.DoubleType() + + +############################################################################### +# date +############################################################################### + + +@Engine.register_dtype( + equivalents=["date", "DateType()", pst.DateType(), pst.DateType], # type: ignore +) +@immutable +class Date(DataType, dtypes.Date): # type: ignore + """Semantic representation of a :class:`pyspark.sql.types.DateType`.""" + + type = pst.DateType() # type: ignore + + +############################################################################### +# timestamp +############################################################################### + + +@Engine.register_dtype( + equivalents=["datetime", "timestamp", "TimestampType()", pst.TimestampType(), pst.TimestampType], # type: ignore +) +@immutable +class Timestamp(DataType, dtypes.Timestamp): # type: ignore + """Semantic representation of a :class:`pyspark.sql.types.TimestampType`.""" + + type = pst.TimestampType() # type: ignore + + +############################################################################### +# binary +############################################################################### + + +@Engine.register_dtype( + equivalents=["binary", "BinaryType()", pst.BinaryType(), pst.BinaryType], # type: ignore +) +@immutable +class Binary(DataType, dtypes.Binary): # type: ignore + """Semantic representation of a :class:`pyspark.sql.types.BinaryType`.""" + + type = pst.BinaryType() # type: ignore + + +############################################################################### +# timedelta +############################################################################### + + +@Engine.register_dtype( + equivalents=[ + "timedelta", + "DayTimeIntervalType()", + pst.DayTimeIntervalType(), + pst.DayTimeIntervalType, + ] +) +@immutable(init=True) +class TimeDelta(DataType): + """Semantic representation of a :class:`pyspark.sql.types.DayTimeIntervalType`.""" + + type: pst.DayTimeIntervalType = dataclasses.field( + 
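
A short, hedged sketch of how the parametrized Decimal type above round-trips through the engine: `from_parametrized_dtype` preserves precision and scale when a native `pyspark.sql.types.DecimalType` instance is resolved (the printed representations are indicative, not verified output).

    import pyspark.sql.types as pst

    from pandera.engines import pyspark_engine

    # Resolving a parametrized native type keeps its precision and scale.
    dec = pyspark_engine.Engine.dtype(pst.DecimalType(10, 2))
    print(dec.precision, dec.scale)  # 10 2
    print(dec.type)                  # DecimalType(10,2)
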
default=pst.DayTimeIntervalType, init=False + ) + + def __init__( # pylint:disable=super-init-not-called + self, + startField: int = 0, + endField: int = 3, + ) -> None: + # super().__init__(self) + object.__setattr__(self, "startField", startField) + object.__setattr__(self, "endField", endField) + + object.__setattr__( + self, + "type", + pst.DayTimeIntervalType(self.startField, self.endField), # type: ignore + ) + + @classmethod + def from_parametrized_dtype(cls, ps_dtype: pst.DayTimeIntervalType): + """Convert a :class:`pyspark.sql.types.DayTimeIntervalType` to + a Pandera :class:`pandera.engines.pyspark_engine.TimeDelta`.""" + return cls(startField=ps_dtype.startField, endField=ps_dtype.endField) # type: ignore + + def check( + self, + pandera_dtype: dtypes.DataType, + data_container: Any = None, + ) -> Union[bool, Iterable[bool]]: + try: + pandera_dtype = Engine.dtype(pandera_dtype) + except TypeError: + return False + + # attempts to compare pyspark native type if possible + # to let subclass inherit check + # (super will compare that DataType classes are exactly the same) + try: + return ( + (self.type == pandera_dtype.type) + & (self.type.DAY == pandera_dtype.type.DAY) + & (self.type.HOUR == pandera_dtype.type.HOUR) + & (self.type.MINUTE == pandera_dtype.type.MINUTE) + & (self.type.SECOND == pandera_dtype.type.SECOND) + ) + + except TypeError: + return super().check(pandera_dtype) + + +############################################################################### +# array +############################################################################### + + +@Engine.register_dtype(equivalents=[pst.ArrayType]) +@immutable(init=True) +class ArrayType(DataType): + """Semantic representation of a :class:`pyspark.sql.types.ArrayType`.""" + + type: pst.ArrayType = dataclasses.field(default=pst.ArrayType, init=False) + + def __init__( # pylint:disable=super-init-not-called + self, + elementType: Any = pst.StringType(), + containsNull: bool = True, + ) -> None: + # super().__init__(self) + object.__setattr__(self, "elementType", elementType) + object.__setattr__(self, "containsNull", containsNull) + + object.__setattr__( + self, + "type", + pst.ArrayType(self.elementType, self.containsNull), # type: ignore + ) + + @classmethod + def from_parametrized_dtype(cls, ps_dtype: pst.ArrayType): + """Convert a :class:`pyspark.sql.types.ArrayType` to + a Pandera :class:`pandera.engines.pyspark_engine.ArrayType`.""" + return cls(elementType=ps_dtype.elementType, containsNull=ps_dtype.containsNull) # type: ignore + + def check( + self, + pandera_dtype: dtypes.DataType, + data_container: Any = None, + ) -> Union[bool, Iterable[bool]]: + try: + pandera_dtype = Engine.dtype(pandera_dtype) + except TypeError: + return False + # attempts to compare pyspark native type if possible + # to let subclass inherit check + # (super will compare that DataType classes are exactly the same) + try: + return ( + (self.type == pandera_dtype.type) + & (self.type.elementType == pandera_dtype.type.elementType) + & (self.type.containsNull == pandera_dtype.type.containsNull) + ) + + except TypeError: + return super().check(pandera_dtype) + + +############################################################################### +# map +############################################################################### + + +@Engine.register_dtype(equivalents=[pst.MapType]) +@immutable(init=True) +class MapType(DataType): + """Semantic representation of a :class:`pyspark.sql.types.MapType`.""" + + type: pst.MapType = 
dataclasses.field(default=pst.MapType, init=False) + + def __init__( # pylint:disable=super-init-not-called + self, + keyType: Any = pst.StringType(), + valueType: Any = pst.StringType(), + valueContainsNull: bool = True, + ) -> None: + # super().__init__(self) + object.__setattr__(self, "keyType", keyType) + object.__setattr__(self, "valueType", valueType) + object.__setattr__(self, "valueContainsNull", valueContainsNull) + + object.__setattr__( + self, + "type", + pst.MapType(self.keyType, self.valueType, self.valueContainsNull), # type: ignore + ) + + @classmethod + def from_parametrized_dtype(cls, ps_dtype: pst.MapType): + """Convert a :class:`pyspark.sql.types.MapType` to + a Pandera :class:`pandera.engines.pyspark_engine.MapType`.""" + return cls( + keyType=ps_dtype.keyType, + valueType=ps_dtype.valueType, + valueContainsNull=ps_dtype.valueContainsNull, + ) # type: ignore + + def check( + self, + pandera_dtype: dtypes.DataType, + data_container: Any = None, + ) -> Union[bool, Iterable[bool]]: + try: + pandera_dtype = Engine.dtype(pandera_dtype) + except TypeError: + return False + # attempts to compare pyspark native type if possible + # to let subclass inherit check + # (super will compare that DataType classes are exactly the same) + try: + return ( + (self.type == pandera_dtype.type) + & (self.type.valueType == pandera_dtype.type.valueType) + & (self.type.keyType == pandera_dtype.type.keyType) + & ( + self.type.valueContainsNull + == pandera_dtype.type.valueContainsNull + ) + ) + + except TypeError: + return super().check(pandera_dtype) diff --git a/pandera/engines/type_aliases.py b/pandera/engines/type_aliases.py index 1aaeca918..3131bc81d 100644 --- a/pandera/engines/type_aliases.py +++ b/pandera/engines/type_aliases.py @@ -4,7 +4,9 @@ import numpy as np import pandas as pd +from pyspark.sql import DataFrame PandasObject = Union[pd.Series, pd.DataFrame] PandasExtensionType = pd.core.dtypes.base.ExtensionDtype PandasDataType = Union[pd.core.dtypes.base.ExtensionDtype, np.dtype, type] +PysparkObject = Union[DataFrame] diff --git a/pandera/engines/utils.py b/pandera/engines/utils.py index 7dd342342..bc239d992 100644 --- a/pandera/engines/utils.py +++ b/pandera/engines/utils.py @@ -46,11 +46,11 @@ def numpy_pandas_coerce_failure_cases( into particular data type. 
""" # pylint: disable=import-outside-toplevel,cyclic-import - from pandera.engines import pandas_engine from pandera.api.checks import Check - from pandera.api.pandas.types import is_index, is_field, is_table + from pandera.api.pandas.types import is_field, is_index, is_table from pandera.backends.pandas import error_formatters from pandera.backends.pandas.checks import PandasCheckBackend + from pandera.engines import pandas_engine data_type = pandas_engine.Engine.dtype(type_) diff --git a/pandera/errors.py b/pandera/errors.py index c46fd1d5e..a8cca8264 100644 --- a/pandera/errors.py +++ b/pandera/errors.py @@ -153,6 +153,7 @@ class SchemaErrorReason(Enum): SERIES_CHECK = "series_check" WRONG_DATATYPE = "wrong_dtype" INDEX_CHECK = "index_check" + NO_ERROR = "no_errors" class SchemaErrors(ReducedPickleExceptionBase): @@ -180,3 +181,7 @@ def __init__( self.error_counts = failure_cases_metadata.error_counts self.failure_cases = failure_cases_metadata.failure_cases super().__init__(failure_cases_metadata.message) + + +class PysparkSchemaError(ReducedPickleExceptionBase): + """Raised when pyspark schema are collected into one error.""" diff --git a/pandera/extensions.py b/pandera/extensions.py index b6c9e4376..86872ef5d 100644 --- a/pandera/extensions.py +++ b/pandera/extensions.py @@ -2,9 +2,9 @@ # pylint: disable=unused-import from pandera.api.extensions import ( + CheckType, register_builtin_check, register_builtin_hypothesis, - CheckType, register_check_method, register_check_statistics, ) diff --git a/pandera/io/__init__.py b/pandera/io/__init__.py index 7cd0d6eeb..91bfd55e1 100644 --- a/pandera/io/__init__.py +++ b/pandera/io/__init__.py @@ -1,21 +1,21 @@ """Subpackage for serializing/deserializing pandera schemas to other formats.""" from pandera.io.pandas_io import ( - serialize_schema, - deserialize_schema, - from_yaml, - to_yaml, - from_json, - to_json, - to_script, - from_frictionless_schema, - _get_dtype_string_alias, - _serialize_check_stats, - _serialize_dataframe_stats, - _serialize_component_stats, _deserialize_check_stats, _deserialize_component_stats, _format_checks, _format_index, _format_script, + _get_dtype_string_alias, + _serialize_check_stats, + _serialize_component_stats, + _serialize_dataframe_stats, + deserialize_schema, + from_frictionless_schema, + from_json, + from_yaml, + serialize_schema, + to_json, + to_script, + to_yaml, ) diff --git a/pandera/io/pandas_io.py b/pandera/io/pandas_io.py index 8bf16aa83..c69bb1330 100644 --- a/pandera/io/pandas_io.py +++ b/pandera/io/pandas_io.py @@ -10,11 +10,10 @@ import pandas as pd import pandera.errors - from pandera import dtypes -from pandera.api.pandas.container import DataFrameSchema -from pandera.api.pandas.components import Column from pandera.api.checks import Check +from pandera.api.pandas.components import Column +from pandera.api.pandas.container import DataFrameSchema from pandera.engines import pandas_engine from pandera.schema_statistics import get_dataframe_schema_statistics diff --git a/pandera/mypy.py b/pandera/mypy.py index 4774a96a2..99c2c38c6 100644 --- a/pandera/mypy.py +++ b/pandera/mypy.py @@ -2,11 +2,7 @@ from typing import Callable, Optional, Union, cast -from mypy.nodes import ( - FuncBase, - SymbolNode, - TypeInfo, -) +from mypy.nodes import FuncBase, SymbolNode, TypeInfo from mypy.plugin import ( ClassDefContext, FunctionSigContext, diff --git a/pandera/pyspark.py b/pandera/pyspark.py new file mode 100644 index 000000000..bf37e52be --- /dev/null +++ b/pandera/pyspark.py @@ -0,0 +1,108 @@ +try: + 
import pyspark.sql + + from pandera.accessors import pyspark_sql_accessor + from pandera.api.pyspark import Column, DataFrameSchema + from pandera.api.pyspark.model import DataFrameModel, SchemaModel + from pandera.api.pyspark.model_components import ( + Field, + check, + dataframe_check, + ) + from pandera.api.checks import Check + from pandera.typing import pyspark_sql + from pandera.errors import PysparkSchemaError, SchemaInitError + from pandera.decorators import ( + check_input, + check_io, + check_output, + check_types, + ) + from pandera.backends.pyspark.utils import ConfigParams + from pandera.dtypes import ( + Bool, + Category, + Complex, + Complex64, + Complex128, + Complex256, + DataType, + Date, + DateTime, + Decimal, + Float, + Float16, + Float32, + Float64, + Float128, + Int, + Int8, + Int16, + Int32, + Int64, + String, + Timedelta, + Timestamp, + UInt, + UInt8, + UInt16, + UInt32, + UInt64, + ) + from pandera.schema_inference.pandas import infer_schema + from pandera.version import __version__ + +except ImportError: + pass + +__all__ = [ + # dtypes + "Bool", + "Category", + "Complex", + "Complex64", + "Complex128", + "Complex256", + "DataType", + "DateTime", + "Float", + "Float16", + "Float32", + "Float64", + "Float128", + "Int", + "Int8", + "Int16", + "Int32", + "Int64", + "String", + "Timedelta", + "Timestamp", + "UInt", + "UInt8", + "UInt16", + "UInt32", + "UInt64", + # checks + "Check", + # decorators + "check_input", + "check_io", + "check_output", + "check_types", + # model + "DataFrameModel", + "SchemaModel", + # model_components + "Field", + "check", + "dataframe_check", + # schema_components + "Column", + # schema_inference + "infer_schema", + # schemas + "DataFrameSchema", + # version + "__version__", +] diff --git a/pandera/schema_inference/pandas.py b/pandera/schema_inference/pandas.py index 5bdc6258c..515e8b6ae 100644 --- a/pandera/schema_inference/pandas.py +++ b/pandera/schema_inference/pandas.py @@ -4,14 +4,14 @@ import pandas as pd +from pandera.api.pandas.array import SeriesSchema +from pandera.api.pandas.components import Column, Index, MultiIndex +from pandera.api.pandas.container import DataFrameSchema from pandera.schema_statistics.pandas import ( infer_dataframe_statistics, infer_series_statistics, parse_check_statistics, ) -from pandera.api.pandas.array import SeriesSchema -from pandera.api.pandas.components import Column, Index, MultiIndex -from pandera.api.pandas.container import DataFrameSchema @overload diff --git a/pandera/schema_statistics/__init__.py b/pandera/schema_statistics/__init__.py index b6c0fb69c..0fa0b1ae7 100644 --- a/pandera/schema_statistics/__init__.py +++ b/pandera/schema_statistics/__init__.py @@ -1,12 +1,12 @@ """Module to extract schema statsitics from schema objects.""" from pandera.schema_statistics.pandas import ( - infer_dataframe_statistics, - infer_series_statistics, - infer_index_statistics, - parse_check_statistics, get_dataframe_schema_statistics, get_index_schema_statistics, get_series_schema_statistics, + infer_dataframe_statistics, + infer_index_statistics, + infer_series_statistics, + parse_check_statistics, parse_checks, ) diff --git a/pandera/strategies/base_strategies.py b/pandera/strategies/base_strategies.py index 9004beb94..05c04b4bc 100644 --- a/pandera/strategies/base_strategies.py +++ b/pandera/strategies/base_strategies.py @@ -2,7 +2,6 @@ from typing import Callable, Dict, Tuple, Type - # This strategy registry maps (check_name, data_type) -> strategy_function # For example: ("greater_than", 
pd.DataFrame) -> () STRATEGY_DISPATCHER: Dict[Tuple[str, Type], Callable] = {} diff --git a/pandera/strategies/pandas_strategies.py b/pandera/strategies/pandas_strategies.py index 6c63bf0ad..0a279fba3 100644 --- a/pandera/strategies/pandas_strategies.py +++ b/pandera/strategies/pandas_strategies.py @@ -50,7 +50,6 @@ import hypothesis.strategies as st from hypothesis.strategies import SearchStrategy, composite except ImportError: # pragma: no cover - # pylint: disable=too-few-public-methods class SearchStrategy: # type: ignore """placeholder type.""" diff --git a/pandera/typing/__init__.py b/pandera/typing/__init__.py index d28d50522..16ffc098e 100644 --- a/pandera/typing/__init__.py +++ b/pandera/typing/__init__.py @@ -6,13 +6,30 @@ from typing import Set, Type -from pandera.typing import dask, fastapi, geopandas, modin, pyspark +from pandera.typing import ( + dask, + fastapi, + geopandas, + modin, + pyspark, + pyspark_sql, +) from pandera.typing.common import ( BOOL, INT8, INT16, INT32, INT64, + PYSPARK_BINARY, + PYSPARK_BYTEINT, + PYSPARK_DATE, + PYSPARK_DECIMAL, + PYSPARK_FLOAT, + PYSPARK_INT, + PYSPARK_LONGINT, + PYSPARK_SHORTINT, + PYSPARK_STRING, + PYSPARK_TIMESTAMP, STRING, UINT8, UINT16, @@ -42,6 +59,7 @@ UInt64, ) from pandera.typing.pandas import DataFrame, Index, Series +from pandera.typing.pyspark_sql import Column DATAFRAME_TYPES: Set[Type] = {DataFrame} SERIES_TYPES: Set[Type] = {Series} @@ -62,14 +80,13 @@ SERIES_TYPES.update({pyspark.Series}) INDEX_TYPES.update({pyspark.Index}) # type: ignore [arg-type] +if pyspark_sql.PYSPARK_SQL_INSTALLED: + DATAFRAME_TYPES.update({pyspark_sql.DataFrame}) + COLUMN_TYPES: Set[Type] = {Column} if geopandas.GEOPANDAS_INSTALLED: DATAFRAME_TYPES.update({geopandas.GeoDataFrame}) SERIES_TYPES.update({geopandas.GeoSeries}) -__all__ = [ - "DataFrame", - "Series", - "Index", -] +__all__ = ["DataFrame", "Series", "Index", "Column"] diff --git a/pandera/typing/common.py b/pandera/typing/common.py index 1145fefc2..611588b7e 100644 --- a/pandera/typing/common.py +++ b/pandera/typing/common.py @@ -1,5 +1,5 @@ """Common typing functionality.""" -# pylint:disable=abstract-method,disable=too-many-ancestors +# pylint:disable=abstract-method,too-many-ancestors,invalid-name import inspect from typing import TYPE_CHECKING, Any, Generic, Optional, Type, TypeVar, Union @@ -8,7 +8,7 @@ import typing_inspect from pandera import dtypes -from pandera.engines import numpy_engine, pandas_engine +from pandera.engines import numpy_engine, pandas_engine, pyspark_engine Bool = dtypes.Bool #: ``"bool"`` numpy dtype Date = dtypes.Date #: ``datetime.date`` object dtype @@ -43,6 +43,17 @@ #: fall back on the str-as-object-array representation. 
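# Illustrative sketch (assumed usage, not taken verbatim from this patch): the
# pyspark.sql dtype aliases registered here back the object-based API exposed by
# the new pandera.pyspark module. The schema, sample data, and accessor calls
# below mirror the test fixtures further down in this diff; a local SparkSession
# is assumed.
import pyspark.sql.types as T
from pyspark.sql import SparkSession

import pandera.pyspark as pa
from pandera.pyspark import Column, DataFrameSchema

spark = SparkSession.builder.getOrCreate()

schema = DataFrameSchema(
    {
        "product": Column(T.StringType()),
        "price": Column(T.IntegerType(), pa.Check.gt(5)),
    }
)
spark_schema = T.StructType(
    [
        T.StructField("product", T.StringType(), False),
        T.StructField("price", T.IntegerType(), False),
    ]
)
df = spark.createDataFrame([("Bread", 9), ("Butter", 15)], schema=spark_schema)

df_out = schema.validate(df)  # validation returns a pyspark.sql.DataFrame
if df_out.pandera.errors:     # summarized error report exposed via the accessor
    print(df_out.pandera.errors)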
STRING = pandas_engine.STRING #: ``"str"`` numpy dtype BOOL = pandas_engine.BOOL #: ``"str"`` numpy dtype +PYSPARK_STRING = pyspark_engine.String +PYSPARK_INT = pyspark_engine.Int +PYSPARK_LONGINT = pyspark_engine.BigInt +PYSPARK_SHORTINT = pyspark_engine.ShortInt +PYSPARK_BYTEINT = pyspark_engine.ByteInt +PYSPARK_DOUBLE = pyspark_engine.Double +PYSPARK_FLOAT = pyspark_engine.Float +PYSPARK_DECIMAL = pyspark_engine.Decimal +PYSPARK_DATE = pyspark_engine.Date +PYSPARK_TIMESTAMP = pyspark_engine.Timestamp +PYSPARK_BINARY = pyspark_engine.Binary try: Geometry = pandas_engine.Geometry # : ``"geometry"`` geopandas dtype @@ -90,6 +101,16 @@ String, STRING, Geometry, + pyspark_engine.String, + pyspark_engine.Int, + pyspark_engine.BigInt, + pyspark_engine.ShortInt, + pyspark_engine.ByteInt, + pyspark_engine.Float, + pyspark_engine.Decimal, + pyspark_engine.Date, + pyspark_engine.Timestamp, + pyspark_engine.Binary, ], ) else: @@ -131,6 +152,16 @@ Object, String, STRING, + pyspark_engine.String, + pyspark_engine.Int, + pyspark_engine.BigInt, + pyspark_engine.ShortInt, + pyspark_engine.ByteInt, + pyspark_engine.Float, + pyspark_engine.Decimal, + pyspark_engine.Date, + pyspark_engine.Timestamp, + pyspark_engine.Binary, ], ) @@ -205,6 +236,20 @@ def __get__( raise AttributeError("Indexes should resolve to pa.Index-s") +class ColumnBase(Generic[GenericDtype]): + """Representation of pandas.Index, only used for type annotation. + + *new in 0.5.0* + """ + + default_dtype: Optional[Type] = None + + def __get__( + self, instance: object, owner: Type + ) -> str: # pragma: no cover + raise AttributeError("column should resolve to pyspark.sql.Column-s") + + class AnnotationInfo: # pylint:disable=too-few-public-methods """Captures extra information about an annotation. diff --git a/pandera/typing/dask.py b/pandera/typing/dask.py index 0f204d852..e873796da 100644 --- a/pandera/typing/dask.py +++ b/pandera/typing/dask.py @@ -3,7 +3,7 @@ from typing import TYPE_CHECKING, Generic, TypeVar from pandera.typing.common import DataFrameBase, IndexBase, SeriesBase -from pandera.typing.pandas import GenericDtype, DataFrameModel +from pandera.typing.pandas import DataFrameModel, GenericDtype try: import dask.dataframe as dd @@ -21,7 +21,6 @@ if DASK_INSTALLED: - # pylint: disable=too-few-public-methods,abstract-method class DataFrame(DataFrameBase, dd.DataFrame, Generic[T]): """ diff --git a/pandera/typing/geopandas.py b/pandera/typing/geopandas.py index f24868111..ad84bff2b 100644 --- a/pandera/typing/geopandas.py +++ b/pandera/typing/geopandas.py @@ -2,7 +2,6 @@ from typing import TYPE_CHECKING, Generic, TypeVar from pandera.typing.common import DataFrameBase, SeriesBase - from pandera.typing.pandas import DataFrameModel try: diff --git a/pandera/typing/modin.py b/pandera/typing/modin.py index 95f452ae5..a6d1f1f3b 100644 --- a/pandera/typing/modin.py +++ b/pandera/typing/modin.py @@ -5,7 +5,7 @@ from packaging import version from pandera.typing.common import DataFrameBase, IndexBase, SeriesBase -from pandera.typing.pandas import GenericDtype, DataFrameModel +from pandera.typing.pandas import DataFrameModel, GenericDtype try: import modin @@ -29,7 +29,6 @@ def modin_version(): if MODIN_INSTALLED: - # pylint: disable=too-few-public-methods class DataFrame(DataFrameBase, mpd.DataFrame, Generic[T]): """ diff --git a/pandera/typing/pandas.py b/pandera/typing/pandas.py index d1c6b9cd5..2e31d6948 100644 --- a/pandera/typing/pandas.py +++ b/pandera/typing/pandas.py @@ -20,9 +20,9 @@ from pandera.errors import SchemaError, 
SchemaInitError from pandera.typing.common import ( DataFrameBase, + DataFrameModel, GenericDtype, IndexBase, - DataFrameModel, SeriesBase, ) from pandera.typing.formats import Formats diff --git a/pandera/typing/pyspark.py b/pandera/typing/pyspark.py index 8fdb2d975..5f0934cec 100644 --- a/pandera/typing/pyspark.py +++ b/pandera/typing/pyspark.py @@ -2,8 +2,13 @@ from typing import TYPE_CHECKING, Generic, TypeVar -from pandera.typing.common import DataFrameBase, IndexBase, SeriesBase -from pandera.typing.pandas import GenericDtype, DataFrameModel, _GenericAlias +from pandera.typing.common import ( + DataFrameBase, + GenericDtype, + IndexBase, + SeriesBase, +) +from pandera.typing.pandas import DataFrameModel, _GenericAlias try: import pyspark.pandas as ps @@ -21,7 +26,6 @@ if PYSPARK_INSTALLED: - # pylint: disable=too-few-public-methods,arguments-renamed class DataFrame(DataFrameBase, ps.DataFrame, Generic[T]): """ diff --git a/pandera/typing/pyspark_sql.py b/pandera/typing/pyspark_sql.py new file mode 100644 index 000000000..8f439d61b --- /dev/null +++ b/pandera/typing/pyspark_sql.py @@ -0,0 +1,40 @@ +"""Pandera type annotations for pyspark.sql.""" + +from typing import TYPE_CHECKING, Generic, TypeVar + +import pyspark.sql + +from pandera.typing.common import ColumnBase, DataFrameBase, GenericDtype +from pandera.typing.pandas import DataFrameModel, _GenericAlias + +try: + import pyspark.sql as ps + + PYSPARK_SQL_INSTALLED = True +except ImportError: # pragma: no cover + PYSPARK_SQL_INSTALLED = False + + +# pylint:disable=invalid-name +if TYPE_CHECKING: + T = TypeVar("T") # pragma: no cover +else: + T = DataFrameModel + + +if PYSPARK_SQL_INSTALLED: + # pylint: disable=too-few-public-methods,arguments-renamed + class DataFrame(DataFrameBase, ps.DataFrame, Generic[T]): + """ + Representation of pyspark.sql.DataFrame, only used for type + annotation.
+ + *new in 0.8.0* + """ + + def __class_getitem__(cls, item): + """Define this to override's pyspark.pandas generic type.""" + return _GenericAlias(cls, item) + + class Column(ColumnBase, pyspark.sql.Column, Generic[GenericDtype]): # type: ignore [misc] # noqa + """Representation of pyspark.sql.Column, only used for type annotation.""" diff --git a/tests/modin/__init__.py b/tests/modin/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/pyspark/__init__.py b/tests/pyspark/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/pyspark/conftest.py b/tests/pyspark/conftest.py new file mode 100644 index 000000000..6b76158c4 --- /dev/null +++ b/tests/pyspark/conftest.py @@ -0,0 +1,155 @@ +"""Registers common pyspark.sql fixtures""" +# pylint: disable=redefined-outer-name +import datetime + +import pyspark.sql.types as T +import pytest +from pyspark.sql import SparkSession + +from pandera.backends.pyspark.utils import ConfigParams + + +@pytest.fixture(scope="session") +def spark() -> SparkSession: + """ + creates spark session + """ + return SparkSession.builder.getOrCreate() + + +@pytest.fixture(scope="session") +def sample_data(): + """ + provides sample data + """ + return [("Bread", 9), ("Butter", 15)] + + +@pytest.fixture(scope="session") +def sample_spark_schema(): + """ + provides spark schema for sample data + """ + return T.StructType( + [ + T.StructField("product", T.StringType(), True), + T.StructField("price", T.IntegerType(), True), + ], + ) + + +def spark_df(spark, data: list, spark_schema: T.StructType): + """Create a spark dataframe for testing""" + return spark.createDataFrame( + data=data, schema=spark_schema, verifySchema=False + ) + + +@pytest.fixture(scope="session") +def sample_date_object(spark): + """Creates a spark dataframe with date data.""" + sample_data = [ + ( + datetime.date(2022, 10, 1), + datetime.datetime(2022, 10, 1, 5, 32, 0), + datetime.timedelta(45), + datetime.timedelta(45), + ), + ( + datetime.date(2022, 11, 5), + datetime.datetime(2022, 11, 5, 15, 34, 0), + datetime.timedelta(30), + datetime.timedelta(45), + ), + ] + sample_spark_schema = T.StructType( + [ + T.StructField("purchase_date", T.DateType(), False), + T.StructField("purchase_datetime", T.TimestampType(), False), + T.StructField("expiry_time", T.DayTimeIntervalType(), False), + T.StructField("expected_time", T.DayTimeIntervalType(2, 3), False), + ], + ) + df = spark_df( + spark=spark, spark_schema=sample_spark_schema, data=sample_data + ) + return df + + +@pytest.fixture(scope="session") +def sample_string_binary_object(spark): + """Creates a spark dataframe with string binary data.""" + sample_data = [ + ( + "test1", + "Bread", + ), + ("test2", "Butter"), + ] + sample_spark_schema = T.StructType( + [ + T.StructField("purchase_info", T.StringType(), False), + T.StructField("product", T.StringType(), False), + ], + ) + df = spark_df( + spark=spark, spark_schema=sample_spark_schema, data=sample_data + ) + df = df.withColumn( + "purchase_info", df["purchase_info"].cast(T.BinaryType()) + ) + return df + + +@pytest.fixture(scope="session") +def sample_complex_data(spark): + """Creates a spark dataframe datetimes, strings, and array types.""" + sample_data = [ + ( + datetime.date(2022, 10, 1), + [["josh"], ["27"]], + {"product_bought": "bread"}, + ), + ( + datetime.date(2022, 11, 5), + [["Adam"], ["22"]], + {"product_bought": "bread"}, + ), + ] + + sample_spark_schema = T.StructType( + [ + T.StructField("purchase_date", T.DateType(), False), + 
T.StructField( + "customer_details", + T.ArrayType( + T.ArrayType(T.StringType()), + ), + False, + ), + T.StructField( + "product_details", + T.MapType(T.StringType(), T.StringType()), + False, + ), + ], + ) + return spark_df(spark, sample_data, sample_spark_schema) + + +# pylint: disable=unused-argument +@pytest.fixture(scope="session") +def sample_check_data(spark): + """Creates a dictionary of sample data for checks.""" + + return { + "test_pass_data": [("foo", 30), ("bar", 30)], + "test_fail_data": [("foo", 30), ("bar", 31)], + "test_expression": 30, + } + + +@pytest.fixture(scope="session") +def config_params(): + """Configuration for pyspark.""" + return ConfigParams("pyspark", "parameters.yaml") diff --git a/tests/pyspark/test_pyspark_accessor.py b/tests/pyspark/test_pyspark_accessor.py new file mode 100644 index 000000000..3f63506ce --- /dev/null +++ b/tests/pyspark/test_pyspark_accessor.py @@ -0,0 +1,65 @@ +"""Unit tests for the pyspark_sql_accessor module.""" +from typing import Union + +from pyspark.sql import DataFrame, SparkSession +from pyspark.sql.functions import col +import pytest + +import pandera.pyspark as pa +from pandera.pyspark import pyspark_sql_accessor + + +spark = SparkSession.builder.getOrCreate() + + +@pytest.mark.parametrize( + "schema1, schema2, data, invalid_data", + [ + [ + pa.DataFrameSchema({"col": pa.Column("long")}, coerce=True), + pa.DataFrameSchema({"col": pa.Column("float")}, coerce=False), + spark.createDataFrame([{"col": 1}, {"col": 2}, {"col": 3}]), + spark.createDataFrame([{"col": 1}, {"col": 2}, {"col": 3}]), + ], + ], +) +def test_dataframe_series_add_schema( + schema1: pa.DataFrameSchema, + schema2: pa.DataFrameSchema, + data: Union[DataFrame, col], + invalid_data: Union[DataFrame, col], + config_params: pa.ConfigParams, +) -> None: + """ + Test that the pyspark.sql DataFrame contains schema metadata after pandera validation.
+ """ + validated_data_1 = schema1(data) # type: ignore[arg-type] + + assert data.pandera.schema == schema1 + assert isinstance(schema1.validate(data), DataFrame) + assert isinstance(schema1(data), DataFrame) + if config_params["DEPTH"] != "DATA_ONLY": + assert dict(schema2(invalid_data).pandera.errors["SCHEMA"]) == { + "WRONG_DATATYPE": [ + { + "schema": None, + "column": "col", + "check": "dtype('FloatType()')", + "error": "expected column 'col' to have type FloatType(), got LongType()", + } + ] + } # type: ignore[arg-type] + + +class CustomAccessor: + """Mock accessor class""" + + def __init__(self, obj): + self._obj = obj + + +def test_modin_accessor_warning(): + """Test that modin accessor raises warning when name already exists.""" + pyspark_sql_accessor.register_dataframe_accessor("foo")(CustomAccessor) + with pytest.warns(UserWarning): + pyspark_sql_accessor.register_dataframe_accessor("foo")(CustomAccessor) diff --git a/tests/pyspark/test_pyspark_check.py b/tests/pyspark/test_pyspark_check.py new file mode 100644 index 000000000..c4f381a89 --- /dev/null +++ b/tests/pyspark/test_pyspark_check.py @@ -0,0 +1,1395 @@ +"""Unit tests for pyspark container.""" +import datetime + +from pyspark.sql.types import ( + LongType, + StringType, + StructField, + StructType, + IntegerType, + ByteType, + ShortType, + TimestampType, + DateType, + DecimalType, + DoubleType, + BooleanType, + FloatType, + ArrayType, + MapType, +) +import decimal + +import pytest +import pandera.pyspark as pa +from pandera.backends.pyspark.utils import ConfigParams +from pandera.backends.pyspark.decorators import validate_params +from pandera.pyspark import DataFrameSchema, Column +from pandera.errors import PysparkSchemaError + + +class TestDecorator: + params = ConfigParams("pyspark", "parameters.yaml") + + @validate_params(params=params, scope="DATA") + def test_datatype_check_decorator(self, spark): + schema = DataFrameSchema( + { + "product": Column(StringType()), + "code": Column(StringType(), pa.Check.str_startswith("B")), + } + ) + + spark_schema = StructType( + [ + StructField("product", StringType(), False), + StructField("code", StringType(), False), + ], + ) + pass_case_data = [["foo", "B1"], ["bar", "B2"]] + df = spark.createDataFrame(data=pass_case_data, schema=spark_schema) + df_out = schema.validate(df) + if df_out.pandera.errors: + print(df_out.pandera.errors) + raise PysparkSchemaError + + fail_schema = DataFrameSchema( + { + "product": Column(StringType()), + "code": Column(IntegerType(), pa.Check.str_startswith("B")), + } + ) + + spark_schema = StructType( + [ + StructField("product", StringType(), False), + StructField("code", IntegerType(), False), + ], + ) + fail_case_data = [["foo", 1], ["bar", 2]] + df = spark.createDataFrame(data=fail_case_data, schema=spark_schema) + df_out = schema.validate(df) + if not df_out.pandera.errors: + print(df_out.pandera.errors) + raise PysparkSchemaError + df = spark.createDataFrame(data=fail_case_data, schema=spark_schema) + try: + df_out = fail_schema.validate(df) + except TypeError as err: + assert ( + err.__str__() + == 'The check with name "str_startswith" only supports the following datatypes [\'string\'] and not the given "integer" datatype' + ) + + +class BaseClass: + params = ConfigParams("pyspark", "parameters.yaml") + + def __int__(self, params=None): + pass + + sample_string_data = { + "test_pass_data": [("foo", "b"), ("bar", "c")], + "test_expression": "a", + } + + sample_array_data = { + "test_pass_data": [("foo", ["a"]), ("bar", ["a"])], + 
"test_expression": "a", + } + + sample_map_data = { + "test_pass_data": [("foo", {"a": "a"}), ("bar", {"b": "b"})], + "test_expression": "b", + } + + sample_bolean_data = { + "test_pass_data": [("foo", True), ("bar", True)], + "test_expression": False, + } + + def pytest_generate(self, metafunc): + raise NotImplementedError + + @staticmethod + def convert_value(sample_data, conversion_datatype): + data_dict = {} + for key, value in sample_data.items(): + if key == "test_expression": + if not isinstance(value, list): + data_dict[key] = conversion_datatype(value) + else: + data_dict[key] = [conversion_datatype(i) for i in value] + + else: + if not isinstance(value[0][1], list): + data_dict[key] = [ + (i[0], conversion_datatype(i[1])) for i in value + ] + else: + final_val = [] + for row in value: + data_val = [] + for column in row[1]: + data_val.append(conversion_datatype(column)) + final_val.append((row[0], data_val)) + data_dict[key] = final_val + return data_dict + + @staticmethod + def convert_numeric_data(sample_data, convert_type): + if (convert_type == "double") or (convert_type == "float"): + data_dict = BaseClass.convert_value(sample_data, float) + + if convert_type == "decimal": + data_dict = BaseClass.convert_value(sample_data, decimal.Decimal) + + return data_dict + + @staticmethod + def convert_timestamp_to_date(sample_data): + data_dict = {} + for key, value in sample_data.items(): + if key == "test_expression": + if not isinstance(value, list): + data_dict[key] = value.date() + else: + data_dict[key] = [i.date() for i in value] + + else: + if not isinstance(value[0][1], list): + data_dict[key] = [(i[0], i[1].date()) for i in value] + else: + final_val = [] + for row in value: + data_val = [] + for column in row[1]: + data_val.append(column.date()) + final_val.append((row[0], data_val)) + data_dict[key] = final_val + return data_dict + + @staticmethod + def check_function( + spark, + check_fn, + pass_case_data, + fail_case_data, + data_types, + function_args, + skip_fail_case=False, + ): + schema = DataFrameSchema( + { + "product": Column(StringType()), + "code": Column(data_types, check_fn(*function_args)) + if isinstance(function_args, tuple) + else Column(data_types, check_fn(function_args)), + } + ) + spark_schema = StructType( + [ + StructField("product", StringType(), False), + StructField("code", data_types, False), + ], + ) + df = spark.createDataFrame(data=pass_case_data, schema=spark_schema) + df_out = schema.validate(df) + if df_out.pandera.errors: + print(df_out.pandera.errors) + raise PysparkSchemaError + if not skip_fail_case: + with pytest.raises(PysparkSchemaError): + df_fail = spark.createDataFrame( + data=fail_case_data, schema=spark_schema + ) + df_out = schema.validate(df_fail) + if df_out.pandera.errors: + raise PysparkSchemaError + + +class TestEqualToCheck(BaseClass): + sample_numeric_data = { + "test_pass_data": [("foo", 30), ("bar", 30)], + "test_fail_data": [("foo", 30), ("bar", 31)], + "test_expression": 30, + } + + sample_timestamp_data = { + "test_pass_data": [ + ("foo", datetime.datetime(2020, 10, 1, 10, 0)), + ("bar", datetime.datetime(2020, 10, 1, 10, 0)), + ], + "test_fail_data": [ + ("foo", datetime.datetime(2020, 10, 2, 10, 0)), + ("bar", datetime.datetime(2020, 10, 2, 10, 0)), + ], + "test_expression": datetime.datetime(2020, 10, 1, 10, 0), + } + + sample_string_data = { + "test_pass_data": [("foo", "a"), ("bar", "a")], + "test_fail_data": [("foo", "a"), ("bar", "b")], + "test_expression": "a", + } + + sample_bolean_data = { + 
"test_pass_data": [("foo", True), ("bar", True)], + "test_fail_data": [("foo", False), ("bar", False)], + "test_expression": True, + } + + def pytest_generate_tests(self, metafunc): + # called once per each test function + funcarglist = self.get_data_param()[metafunc.function.__name__] + argnames = sorted(funcarglist[0]) + metafunc.parametrize( + argnames, + [ + [funcargs[name] for name in argnames] + for funcargs in funcarglist + ], + ) + + def get_data_param(self): + return { + "test_equal_to_check": [ + {"datatype": LongType(), "data": self.sample_numeric_data}, + {"datatype": IntegerType(), "data": self.sample_numeric_data}, + {"datatype": ByteType(), "data": self.sample_numeric_data}, + {"datatype": ShortType(), "data": self.sample_numeric_data}, + {"datatype": StringType(), "data": self.sample_string_data}, + { + "datatype": DoubleType(), + "data": self.convert_numeric_data( + self.sample_numeric_data, "double" + ), + }, + { + "datatype": TimestampType(), + "data": self.sample_timestamp_data, + }, + { + "datatype": DateType(), + "data": self.convert_timestamp_to_date( + self.sample_timestamp_data + ), + }, + { + "datatype": DecimalType(), + "data": self.convert_numeric_data( + self.sample_numeric_data, "decimal" + ), + }, + { + "datatype": FloatType(), + "data": self.convert_numeric_data( + self.sample_numeric_data, "float" + ), + }, + {"datatype": BooleanType(), "data": self.sample_bolean_data}, + ], + "test_failed_unaccepted_datatypes": [ + { + "datatype": ArrayType(StringType()), + "data": self.sample_array_data, + }, + { + "datatype": MapType(StringType(), StringType()), + "data": self.sample_map_data, + }, + ], + } + + @validate_params(params=BaseClass.params, scope="DATA") + @pytest.mark.parametrize("check_fn", [pa.Check.equal_to, pa.Check.eq]) + def test_equal_to_check(self, spark, check_fn, datatype, data) -> None: + """Test the Check to see if all the values are equal to defined value""" + self.check_function( + spark, + check_fn, + data["test_pass_data"], + data["test_fail_data"], + datatype, + data["test_expression"], + ) + + @validate_params(params=BaseClass.params, scope="DATA") + @pytest.mark.parametrize("check_fn", [pa.Check.equal_to, pa.Check.eq]) + def test_failed_unaccepted_datatypes( + self, spark, check_fn, datatype, data + ) -> None: + """Test the Check to see if all the values are equal to defined value""" + with pytest.raises(TypeError): + self.check_function( + spark, + check_fn, + data["test_pass_data"], + None, + datatype, + data["test_expression"], + ) + + +class TestNotEqualToCheck(BaseClass): + sample_numeric_data = { + "test_pass_data": [("foo", 31), ("bar", 32)], + "test_fail_data": [("foo", 30), ("bar", 31)], + "test_expression": 30, + } + + sample_timestamp_data = { + "test_pass_data": [ + ("foo", datetime.datetime(2020, 10, 1, 10, 0)), + ("bar", datetime.datetime(2020, 10, 2, 10, 0)), + ], + "test_fail_data": [ + ("foo", datetime.datetime(2020, 10, 3, 10, 0)), + ("bar", datetime.datetime(2020, 10, 2, 10, 0)), + ], + "test_expression": datetime.datetime(2020, 10, 3, 10, 0), + } + + sample_string_data = { + "test_pass_data": [("foo", "b"), ("bar", "c")], + "test_fail_data": [("foo", "a"), ("bar", "b")], + "test_expression": "a", + } + + sample_bolean_data = { + "test_pass_data": [("foo", True), ("bar", True)], + "test_fail_data": [("foo", False), ("bar", True)], + "test_expression": False, + } + + def pytest_generate_tests(self, metafunc): + # called once per each test function + funcarglist = self.get_data_param()[metafunc.function.__name__] + 
argnames = sorted(funcarglist[0]) + metafunc.parametrize( + argnames, + [ + [funcargs[name] for name in argnames] + for funcargs in funcarglist + ], + ) + + def get_data_param(self): + return { + "test_not_equal_to_check": [ + {"datatype": LongType(), "data": self.sample_numeric_data}, + {"datatype": IntegerType(), "data": self.sample_numeric_data}, + {"datatype": ByteType(), "data": self.sample_numeric_data}, + {"datatype": ShortType(), "data": self.sample_numeric_data}, + {"datatype": StringType(), "data": self.sample_string_data}, + { + "datatype": DoubleType(), + "data": self.convert_numeric_data( + self.sample_numeric_data, "double" + ), + }, + { + "datatype": TimestampType(), + "data": self.sample_timestamp_data, + }, + { + "datatype": DateType(), + "data": self.convert_timestamp_to_date( + self.sample_timestamp_data + ), + }, + { + "datatype": DecimalType(), + "data": self.convert_numeric_data( + self.sample_numeric_data, "decimal" + ), + }, + { + "datatype": FloatType(), + "data": self.convert_numeric_data( + self.sample_numeric_data, "float" + ), + }, + {"datatype": BooleanType(), "data": self.sample_bolean_data}, + ], + "test_failed_unaccepted_datatypes": [ + { + "datatype": ArrayType(StringType()), + "data": self.sample_array_data, + }, + { + "datatype": MapType(StringType(), StringType()), + "data": self.sample_map_data, + }, + ], + } + + @validate_params(params=BaseClass.params, scope="DATA") + @pytest.mark.parametrize("check_fn", [pa.Check.not_equal_to, pa.Check.ne]) + def test_not_equal_to_check(self, spark, check_fn, datatype, data) -> None: + """Test the Check to see if all the values are equal to defined value""" + self.check_function( + spark, + check_fn, + data["test_pass_data"], + data["test_fail_data"], + datatype, + data["test_expression"], + ) + + @validate_params(params=BaseClass.params, scope="DATA") + @pytest.mark.parametrize("check_fn", [pa.Check.not_equal_to, pa.Check.ne]) + def test_failed_unaccepted_datatypes( + self, spark, check_fn, datatype, data + ) -> None: + """Test the Check to see if all the values are equal to defined value""" + with pytest.raises(TypeError): + self.check_function( + spark, + check_fn, + data["test_pass_data"], + None, + datatype, + data["test_expression"], + ) + + +class TestGreaterThanCheck(BaseClass): + sample_numeric_data = { + "test_pass_data": [("foo", 31), ("bar", 32)], + "test_fail_data": [("foo", 30), ("bar", 31)], + "test_expression": 30, + } + + sample_timestamp_data = { + "test_pass_data": [ + ("foo", datetime.datetime(2020, 10, 1, 11, 0)), + ("bar", datetime.datetime(2020, 10, 2, 11, 0)), + ], + "test_fail_data": [ + ("foo", datetime.datetime(2020, 10, 1, 10, 0)), + ("bar", datetime.datetime(2020, 10, 2, 11, 0)), + ], + "test_expression": datetime.datetime(2020, 10, 1, 10, 0), + } + sample_date_data = { + "test_pass_data": [ + ("foo", datetime.date(2020, 10, 2)), + ("bar", datetime.date(2020, 10, 3)), + ], + "test_fail_data": [ + ("foo", datetime.date(2020, 10, 2)), + ("bar", datetime.date(2020, 10, 1)), + ], + "test_expression": datetime.date(2020, 10, 1), + } + + def pytest_generate_tests(self, metafunc): + # called once per each test function + funcarglist = self.get_data_param()[metafunc.function.__name__] + argnames = sorted(funcarglist[0]) + metafunc.parametrize( + argnames, + [ + [funcargs[name] for name in argnames] + for funcargs in funcarglist + ], + ) + + def get_data_param(self): + return { + "test_greater_than_check": [ + {"datatype": LongType(), "data": self.sample_numeric_data}, + {"datatype": 
IntegerType(), "data": self.sample_numeric_data}, + {"datatype": ByteType(), "data": self.sample_numeric_data}, + {"datatype": ShortType(), "data": self.sample_numeric_data}, + { + "datatype": DoubleType(), + "data": self.convert_numeric_data( + self.sample_numeric_data, "double" + ), + }, + { + "datatype": TimestampType(), + "data": self.sample_timestamp_data, + }, + {"datatype": DateType(), "data": self.sample_date_data}, + { + "datatype": DecimalType(), + "data": self.convert_numeric_data( + self.sample_numeric_data, "decimal" + ), + }, + { + "datatype": FloatType(), + "data": self.convert_numeric_data( + self.sample_numeric_data, "float" + ), + }, + ], + "test_failed_unaccepted_datatypes": [ + {"datatype": StringType(), "data": self.sample_string_data}, + {"datatype": BooleanType(), "data": self.sample_bolean_data}, + { + "datatype": ArrayType(StringType()), + "data": self.sample_array_data, + }, + { + "datatype": MapType(StringType(), StringType()), + "data": self.sample_map_data, + }, + ], + } + + @validate_params(params=BaseClass.params, scope="DATA") + @pytest.mark.parametrize("check_fn", [pa.Check.greater_than, pa.Check.gt]) + def test_greater_than_check(self, spark, check_fn, datatype, data) -> None: + """Test the Check to see if all the values are equal to defined value""" + self.check_function( + spark, + check_fn, + data["test_pass_data"], + data["test_fail_data"], + datatype, + data["test_expression"], + ) + + @validate_params(params=BaseClass.params, scope="DATA") + @pytest.mark.parametrize("check_fn", [pa.Check.greater_than, pa.Check.gt]) + def test_failed_unaccepted_datatypes( + self, spark, check_fn, datatype, data + ) -> None: + """Test the Check to see if all the values are equal to defined value""" + with pytest.raises(TypeError): + self.check_function( + spark, + check_fn, + data["test_pass_data"], + None, + datatype, + data["test_expression"], + ) + + +class TestGreaterThanEqualToCheck(BaseClass): + sample_numeric_data = { + "test_pass_data": [("foo", 31), ("bar", 32)], + "test_fail_data": [("foo", 30), ("bar", 31)], + "test_expression": 31, + } + + sample_timestamp_data = { + "test_pass_data": [ + ("foo", datetime.datetime(2020, 10, 1, 11, 0)), + ("bar", datetime.datetime(2020, 10, 2, 11, 0)), + ], + "test_fail_data": [ + ("foo", datetime.datetime(2020, 10, 1, 10, 0)), + ("bar", datetime.datetime(2020, 10, 2, 11, 0)), + ], + "test_expression": datetime.datetime(2020, 10, 1, 11, 0), + } + sample_date_data = { + "test_pass_data": [ + ("foo", datetime.date(2020, 10, 2)), + ("bar", datetime.date(2020, 10, 3)), + ], + "test_fail_data": [ + ("foo", datetime.date(2020, 10, 2)), + ("bar", datetime.date(2020, 10, 1)), + ], + "test_expression": datetime.date(2020, 10, 2), + } + + def pytest_generate_tests(self, metafunc): + # called once per each test function + funcarglist = self.get_data_param()[metafunc.function.__name__] + argnames = sorted(funcarglist[0]) + metafunc.parametrize( + argnames, + [ + [funcargs[name] for name in argnames] + for funcargs in funcarglist + ], + ) + + def get_data_param(self): + return { + "test_greater_than_or_equal_to_check": [ + {"datatype": LongType(), "data": self.sample_numeric_data}, + {"datatype": IntegerType(), "data": self.sample_numeric_data}, + {"datatype": ByteType(), "data": self.sample_numeric_data}, + {"datatype": ShortType(), "data": self.sample_numeric_data}, + { + "datatype": DoubleType(), + "data": self.convert_numeric_data( + self.sample_numeric_data, "double" + ), + }, + { + "datatype": TimestampType(), + "data": 
self.sample_timestamp_data, + }, + {"datatype": DateType(), "data": self.sample_date_data}, + { + "datatype": DecimalType(), + "data": self.convert_numeric_data( + self.sample_numeric_data, "decimal" + ), + }, + { + "datatype": FloatType(), + "data": self.convert_numeric_data( + self.sample_numeric_data, "float" + ), + }, + ], + "test_failed_unaccepted_datatypes": [ + {"datatype": StringType(), "data": self.sample_string_data}, + {"datatype": BooleanType(), "data": self.sample_bolean_data}, + { + "datatype": ArrayType(StringType()), + "data": self.sample_array_data, + }, + { + "datatype": MapType(StringType(), StringType()), + "data": self.sample_map_data, + }, + ], + } + + @validate_params(params=BaseClass.params, scope="DATA") + @pytest.mark.parametrize( + "check_fn", [pa.Check.greater_than_or_equal_to, pa.Check.ge] + ) + def test_greater_than_or_equal_to_check( + self, spark, check_fn, datatype, data + ) -> None: + """Test the Check to see if all the values are equal to defined value""" + self.check_function( + spark, + check_fn, + data["test_pass_data"], + data["test_fail_data"], + datatype, + data["test_expression"], + ) + + @validate_params(params=BaseClass.params, scope="DATA") + @pytest.mark.parametrize( + "check_fn", [pa.Check.greater_than_or_equal_to, pa.Check.ge] + ) + def test_failed_unaccepted_datatypes( + self, spark, check_fn, datatype, data + ) -> None: + """Test the Check to see if all the values are equal to defined value""" + with pytest.raises(TypeError): + self.check_function( + spark, + check_fn, + data["test_pass_data"], + None, + datatype, + data["test_expression"], + ) + + +class TestLessThanCheck(BaseClass): + sample_numeric_data = { + "test_pass_data": [("foo", 31), ("bar", 32)], + "test_fail_data": [("foo", 33), ("bar", 31)], + "test_expression": 33, + } + + sample_timestamp_data = { + "test_pass_data": [ + ("foo", datetime.datetime(2020, 10, 1, 11, 0)), + ("bar", datetime.datetime(2020, 10, 2, 11, 0)), + ], + "test_fail_data": [ + ("foo", datetime.datetime(2020, 10, 1, 10, 0)), + ("bar", datetime.datetime(2020, 10, 2, 12, 0)), + ], + "test_expression": datetime.datetime(2020, 10, 2, 12, 0), + } + sample_date_data = { + "test_pass_data": [ + ("foo", datetime.date(2020, 10, 2)), + ("bar", datetime.date(2020, 10, 1)), + ], + "test_fail_data": [ + ("foo", datetime.date(2020, 10, 2)), + ("bar", datetime.date(2020, 10, 3)), + ], + "test_expression": datetime.date(2020, 10, 3), + } + + def pytest_generate_tests(self, metafunc): + # called once per each test function + funcarglist = self.get_data_param()[metafunc.function.__name__] + argnames = sorted(funcarglist[0]) + metafunc.parametrize( + argnames, + [ + [funcargs[name] for name in argnames] + for funcargs in funcarglist + ], + ) + + def get_data_param(self): + return { + "test_less_than_check": [ + {"datatype": LongType(), "data": self.sample_numeric_data}, + {"datatype": IntegerType(), "data": self.sample_numeric_data}, + {"datatype": ByteType(), "data": self.sample_numeric_data}, + {"datatype": ShortType(), "data": self.sample_numeric_data}, + { + "datatype": DoubleType(), + "data": self.convert_numeric_data( + self.sample_numeric_data, "double" + ), + }, + { + "datatype": TimestampType(), + "data": self.sample_timestamp_data, + }, + {"datatype": DateType(), "data": self.sample_date_data}, + { + "datatype": DecimalType(), + "data": self.convert_numeric_data( + self.sample_numeric_data, "decimal" + ), + }, + { + "datatype": FloatType(), + "data": self.convert_numeric_data( + self.sample_numeric_data, "float" + ), 
+ }, + ], + "test_failed_unaccepted_datatypes": [ + {"datatype": StringType(), "data": self.sample_string_data}, + {"datatype": BooleanType(), "data": self.sample_bolean_data}, + { + "datatype": ArrayType(StringType()), + "data": self.sample_array_data, + }, + { + "datatype": MapType(StringType(), StringType()), + "data": self.sample_map_data, + }, + ], + } + + @validate_params(params=BaseClass.params, scope="DATA") + @pytest.mark.parametrize("check_fn", [pa.Check.less_than, pa.Check.lt]) + def test_less_than_check(self, spark, check_fn, datatype, data) -> None: + """Test the Check to see if all the values are equal to defined value""" + self.check_function( + spark, + check_fn, + data["test_pass_data"], + data["test_fail_data"], + datatype, + data["test_expression"], + ) + + @validate_params(params=BaseClass.params, scope="DATA") + @pytest.mark.parametrize("check_fn", [pa.Check.less_than, pa.Check.lt]) + def test_failed_unaccepted_datatypes( + self, spark, check_fn, datatype, data + ) -> None: + """Test the Check to see if all the values are equal to defined value""" + with pytest.raises(TypeError): + self.check_function( + spark, + check_fn, + data["test_pass_data"], + None, + datatype, + data["test_expression"], + ) + + +class TestLessThanOrEqualToCheck(BaseClass): + sample_numeric_data = { + "test_pass_data": [("foo", 31), ("bar", 33)], + "test_fail_data": [("foo", 34), ("bar", 31)], + "test_expression": 33, + } + + sample_timestamp_data = { + "test_pass_data": [ + ("foo", datetime.datetime(2020, 10, 1, 11, 0)), + ("bar", datetime.datetime(2020, 10, 2, 11, 0)), + ], + "test_fail_data": [ + ("foo", datetime.datetime(2020, 10, 1, 10, 0)), + ("bar", datetime.datetime(2020, 10, 2, 12, 0)), + ], + "test_expression": datetime.datetime(2020, 10, 2, 11, 0), + } + sample_date_data = { + "test_pass_data": [ + ("foo", datetime.date(2020, 10, 2)), + ("bar", datetime.date(2020, 10, 1)), + ], + "test_fail_data": [ + ("foo", datetime.date(2020, 10, 2)), + ("bar", datetime.date(2020, 10, 3)), + ], + "test_expression": datetime.date(2020, 10, 2), + } + + def pytest_generate_tests(self, metafunc): + # called once per each test function + funcarglist = self.get_data_param()[metafunc.function.__name__] + argnames = sorted(funcarglist[0]) + metafunc.parametrize( + argnames, + [ + [funcargs[name] for name in argnames] + for funcargs in funcarglist + ], + ) + + def get_data_param(self): + return { + "test_less_than_or_equal_to_check": [ + {"datatype": LongType(), "data": self.sample_numeric_data}, + {"datatype": IntegerType(), "data": self.sample_numeric_data}, + {"datatype": ByteType(), "data": self.sample_numeric_data}, + {"datatype": ShortType(), "data": self.sample_numeric_data}, + { + "datatype": DoubleType(), + "data": self.convert_numeric_data( + self.sample_numeric_data, "double" + ), + }, + { + "datatype": TimestampType(), + "data": self.sample_timestamp_data, + }, + {"datatype": DateType(), "data": self.sample_date_data}, + { + "datatype": DecimalType(), + "data": self.convert_numeric_data( + self.sample_numeric_data, "decimal" + ), + }, + { + "datatype": FloatType(), + "data": self.convert_numeric_data( + self.sample_numeric_data, "float" + ), + }, + ], + "test_failed_unaccepted_datatypes": [ + {"datatype": StringType(), "data": self.sample_string_data}, + {"datatype": BooleanType(), "data": self.sample_bolean_data}, + { + "datatype": ArrayType(StringType()), + "data": self.sample_array_data, + }, + { + "datatype": MapType(StringType(), StringType()), + "data": self.sample_map_data, + }, + ], + } 
+ + @validate_params(params=BaseClass.params, scope="DATA") + @pytest.mark.parametrize( + "check_fn", [pa.Check.less_than_or_equal_to, pa.Check.le] + ) + def test_less_than_or_equal_to_check( + self, spark, check_fn, datatype, data + ) -> None: + """Test the Check to see if all the values are equal to defined value""" + self.check_function( + spark, + check_fn, + data["test_pass_data"], + data["test_fail_data"], + datatype, + data["test_expression"], + ) + + @validate_params(params=BaseClass.params, scope="DATA") + @pytest.mark.parametrize( + "check_fn", [pa.Check.less_than_or_equal_to, pa.Check.le] + ) + def test_failed_unaccepted_datatypes( + self, spark, check_fn, datatype, data + ) -> None: + """Test the Check to see if all the values are equal to defined value""" + with pytest.raises(TypeError): + self.check_function( + spark, + check_fn, + data["test_pass_data"], + None, + datatype, + data["test_expression"], + ) + + +class TestIsInCheck(BaseClass): + sample_numeric_data = { + "test_pass_data": [("foo", 31), ("bar", 32)], + "test_fail_data": [("foo", 30), ("bar", 31)], + "test_expression": [31, 32], + } + + sample_timestamp_data = { + "test_pass_data": [ + ("foo", datetime.datetime(2020, 10, 1, 10, 0)), + ("bar", datetime.datetime(2020, 10, 2, 10, 0)), + ], + "test_fail_data": [ + ("foo", datetime.datetime(2020, 10, 3, 10, 0)), + ("bar", datetime.datetime(2020, 10, 2, 10, 0)), + ], + "test_expression": [ + datetime.datetime(2020, 10, 1, 10, 0), + datetime.datetime(2020, 10, 2, 10, 0), + ], + } + sample_date_data = { + "test_pass_data": [ + ("foo", datetime.date(2020, 10, 1)), + ("bar", datetime.date(2020, 10, 2)), + ], + "test_fail_data": [ + ("foo", datetime.date(2020, 10, 2)), + ("bar", datetime.date(2020, 10, 3)), + ], + "test_expression": [ + datetime.date(2020, 10, 1), + datetime.date(2020, 10, 2), + ], + } + + sample_string_data = { + "test_pass_data": [("foo", "b"), ("bar", "c")], + "test_fail_data": [("foo", "a"), ("bar", "b")], + "test_expression": ["b", "c"], + } + sample_bolean_data = { + "test_pass_data": [("foo", [True]), ("bar", [True])], + "test_expression": [False], + } + + def pytest_generate_tests(self, metafunc): + # called once per each test function + funcarglist = self.get_data_param()[metafunc.function.__name__] + argnames = sorted(funcarglist[0]) + metafunc.parametrize( + argnames, + [ + [funcargs[name] for name in argnames] + for funcargs in funcarglist + ], + ) + + def get_data_param(self): + return { + "test_isin_check": [ + {"datatype": LongType(), "data": self.sample_numeric_data}, + {"datatype": IntegerType(), "data": self.sample_numeric_data}, + {"datatype": ByteType(), "data": self.sample_numeric_data}, + {"datatype": ShortType(), "data": self.sample_numeric_data}, + { + "datatype": DoubleType(), + "data": self.convert_numeric_data( + self.sample_numeric_data, "double" + ), + }, + { + "datatype": TimestampType(), + "data": self.sample_timestamp_data, + }, + {"datatype": DateType(), "data": self.sample_date_data}, + { + "datatype": DecimalType(), + "data": self.convert_numeric_data( + self.sample_numeric_data, "decimal" + ), + }, + { + "datatype": FloatType(), + "data": self.convert_numeric_data( + self.sample_numeric_data, "float" + ), + }, + {"datatype": StringType(), "data": self.sample_string_data}, + ], + "test_failed_unaccepted_datatypes": [ + {"datatype": BooleanType(), "data": self.sample_bolean_data}, + { + "datatype": ArrayType(StringType()), + "data": self.sample_array_data, + }, + { + "datatype": MapType(StringType(), StringType()), + 
"data": self.sample_map_data, + }, + ], + } + + @validate_params(params=BaseClass.params, scope="DATA") + def test_isin_check(self, spark, datatype, data) -> None: + """Test the Check to see if all the values are equal to defined value""" + self.check_function( + spark, + pa.Check.isin, + data["test_pass_data"], + data["test_fail_data"], + datatype, + data["test_expression"], + ) + + @validate_params(params=BaseClass.params, scope="DATA") + def test_failed_unaccepted_datatypes(self, spark, datatype, data) -> None: + """Test the Check to see if all the values are equal to defined value""" + with pytest.raises(TypeError): + self.check_function( + spark, + pa.Check.isin, + data["test_pass_data"], + None, + datatype, + data["test_expression"], + ) + + +class TestNotInCheck(BaseClass): + sample_numeric_data = { + "test_pass_data": [("foo", 31), ("bar", 32)], + "test_fail_data": [("foo", 30), ("bar", 31)], + "test_expression": [30, 33], + } + + sample_timestamp_data = { + "test_pass_data": [ + ("foo", datetime.datetime(2020, 10, 1, 10, 0)), + ("bar", datetime.datetime(2020, 10, 2, 10, 0)), + ], + "test_fail_data": [ + ("foo", datetime.datetime(2020, 10, 3, 10, 0)), + ("bar", datetime.datetime(2020, 10, 2, 10, 0)), + ], + "test_expression": [ + datetime.datetime(2020, 10, 3, 10, 0), + datetime.datetime(2020, 10, 4, 10, 0), + ], + } + + sample_string_data = { + "test_pass_data": [("foo", "b"), ("bar", "c")], + "test_fail_data": [("foo", "a"), ("bar", "b")], + "test_expression": ["a", "d"], + } + + sample_bolean_data = { + "test_pass_data": [("foo", [True]), ("bar", [True])], + "test_expression": [False], + } + + def pytest_generate_tests(self, metafunc): + # called once per each test function + funcarglist = self.get_data_param()[metafunc.function.__name__] + argnames = sorted(funcarglist[0]) + metafunc.parametrize( + argnames, + [ + [funcargs[name] for name in argnames] + for funcargs in funcarglist + ], + ) + + def get_data_param(self): + return { + "test_notin_check": [ + {"datatype": LongType(), "data": self.sample_numeric_data}, + {"datatype": IntegerType(), "data": self.sample_numeric_data}, + {"datatype": ByteType(), "data": self.sample_numeric_data}, + {"datatype": ShortType(), "data": self.sample_numeric_data}, + { + "datatype": DoubleType(), + "data": self.convert_numeric_data( + self.sample_numeric_data, "double" + ), + }, + { + "datatype": TimestampType(), + "data": self.sample_timestamp_data, + }, + { + "datatype": DateType(), + "data": self.convert_timestamp_to_date( + self.sample_timestamp_data + ), + }, + { + "datatype": DecimalType(), + "data": self.convert_numeric_data( + self.sample_numeric_data, "decimal" + ), + }, + { + "datatype": FloatType(), + "data": self.convert_numeric_data( + self.sample_numeric_data, "float" + ), + }, + {"datatype": StringType(), "data": self.sample_string_data}, + ], + "test_failed_unaccepted_datatypes": [ + {"datatype": BooleanType(), "data": self.sample_bolean_data}, + { + "datatype": ArrayType(StringType()), + "data": self.sample_array_data, + }, + { + "datatype": MapType(StringType(), StringType()), + "data": self.sample_map_data, + }, + ], + } + + @validate_params(params=BaseClass.params, scope="DATA") + def test_notin_check(self, spark, datatype, data) -> None: + """Test the Check to see if all the values are equal to defined value""" + + self.check_function( + spark, + pa.Check.notin, + data["test_pass_data"], + data["test_fail_data"], + datatype, + data["test_expression"], + ) + + @validate_params(params=BaseClass.params, scope="DATA") + def 
test_failed_unaccepted_datatypes(self, spark, datatype, data) -> None: + """Test the Check to see if all the values are equal to defined value""" + with pytest.raises(TypeError): + self.check_function( + spark, + pa.Check.notin, + data["test_pass_data"], + None, + datatype, + data["test_expression"], + ) + + +class TestStringType(BaseClass): + @validate_params(params=BaseClass.params, scope="DATA") + def test_str_startswith_check(self, spark) -> None: + """Test the Check to see if any value is not in the specified value""" + check_func = pa.Check.str_startswith + check_value = "B" + + pass_data = [("Bal", "Bread"), ("Bal", "Butter")] + fail_data = [("Bal", "Test"), ("Bal", "Butter")] + BaseClass.check_function( + spark, check_func, pass_data, fail_data, StringType(), check_value + ) + + @validate_params(params=BaseClass.params, scope="DATA") + def test_str_endswith_check(self, spark) -> None: + """Test the Check to see if any value is not in the specified value""" + check_func = pa.Check.str_endswith + check_value = "d" + + pass_data = [("Bal", "Bread"), ("Bal", "Bad")] + fail_data = [("Bal", "Test"), ("Bal", "Bad")] + BaseClass.check_function( + spark, check_func, pass_data, fail_data, StringType(), check_value + ) + + @validate_params(params=BaseClass.params, scope="DATA") + def test_str_contains_check(self, spark) -> None: + """Test the Check to see if any value is not in the specified value""" + check_func = pa.Check.str_contains + check_value = "Ba" + + pass_data = [("Bal", "Bat!"), ("Bal", "Bat78")] + fail_data = [("Bal", "Cs"), ("Bal", "Jam!")] + BaseClass.check_function( + spark, check_func, pass_data, fail_data, StringType(), check_value + ) + + +class TestInRangeCheck(BaseClass): + sample_numeric_data = { + "test_pass_data": [("foo", 31), ("bar", 33)], + "test_fail_data": [("foo", 35), ("bar", 31)], + } + + sample_timestamp_data = { + "test_pass_data": [ + ("foo", datetime.datetime(2020, 10, 1, 11, 0)), + ("bar", datetime.datetime(2020, 10, 2, 11, 0)), + ], + "test_fail_data": [ + ("foo", datetime.datetime(2020, 10, 1, 10, 0)), + ("bar", datetime.datetime(2020, 10, 5, 12, 0)), + ], + } + + sample_bolean_data = { + "test_pass_data": [("foo", [True]), ("bar", [True])], + "test_expression": [False], + } + + def pytest_generate_tests(self, metafunc): + # called once per each test function + funcarglist = self.get_data_param()[metafunc.function.__name__] + argnames = sorted(funcarglist[0]) + metafunc.parametrize( + argnames, + [ + [funcargs[name] for name in argnames] + for funcargs in funcarglist + ], + ) + + def create_min_max(self, data_dictionary): + value_dict = [value[1] for value in data_dictionary["test_pass_data"]] + min_val = min(value_dict) + max_val = max(value_dict) + if isinstance(min_val, datetime.datetime): + add_value = datetime.timedelta(1) + elif isinstance(min_val, datetime.date): + add_value = datetime.timedelta(1) + else: + add_value = 1 + return min_val, max_val, add_value + + def get_data_param(self): + param_vals = [ + {"datatype": LongType(), "data": self.sample_numeric_data}, + {"datatype": IntegerType(), "data": self.sample_numeric_data}, + {"datatype": ByteType(), "data": self.sample_numeric_data}, + {"datatype": ShortType(), "data": self.sample_numeric_data}, + { + "datatype": DoubleType(), + "data": self.convert_numeric_data( + self.sample_numeric_data, "double" + ), + }, + {"datatype": TimestampType(), "data": self.sample_timestamp_data}, + { + "datatype": DateType(), + "data": self.convert_timestamp_to_date( + self.sample_timestamp_data + ), + }, + { + 
"datatype": DecimalType(), + "data": self.convert_numeric_data( + self.sample_numeric_data, "decimal" + ), + }, + { + "datatype": FloatType(), + "data": self.convert_numeric_data( + self.sample_numeric_data, "float" + ), + }, + ] + + return { + "test_inrange_exclude_min_max_check": param_vals, + "test_inrange_exclude_min_only_check": param_vals, + "test_inrange_exclude_max_only_check": param_vals, + "test_inrange_include_min_max_check": param_vals, + "test_failed_unaccepted_datatypes": [ + {"datatype": StringType(), "data": self.sample_string_data}, + {"datatype": BooleanType(), "data": self.sample_bolean_data}, + { + "datatype": ArrayType(StringType()), + "data": self.sample_array_data, + }, + { + "datatype": MapType(StringType(), StringType()), + "data": self.sample_map_data, + }, + ], + } + + @validate_params(params=BaseClass.params, scope="DATA") + def test_inrange_exclude_min_max_check( + self, spark, datatype, data + ) -> None: + """Test the Check to see if any value is not in the specified value""" + min_val, max_val, add_value = self.create_min_max(data) + self.check_function( + spark, + pa.Check.in_range, + data["test_pass_data"], + data["test_fail_data"], + datatype, + (min_val - add_value, max_val + add_value, False, False), + ) + + @validate_params(params=BaseClass.params, scope="DATA") + def test_inrange_exclude_min_only_check( + self, spark, datatype, data + ) -> None: + """Test the Check to see if any value is not in the specified value""" + min_val, max_val, add_value = self.create_min_max(data) + self.check_function( + spark, + pa.Check.in_range, + data["test_pass_data"], + data["test_fail_data"], + datatype, + (min_val, max_val + add_value, True, False), + ) + + @validate_params(params=BaseClass.params, scope="DATA") + def test_inrange_exclude_max_only_check( + self, spark, datatype, data + ) -> None: + """Test the Check to see if any value is not in the specified value""" + min_val, max_val, add_value = self.create_min_max(data) + self.check_function( + spark, + pa.Check.in_range, + data["test_pass_data"], + data["test_fail_data"], + datatype, + (min_val - add_value, max_val, False, True), + ) + + @validate_params(params=BaseClass.params, scope="DATA") + def test_inrange_include_min_max_check( + self, spark, datatype, data + ) -> None: + """Test the Check to see if any value is not in the specified value""" + min_val, max_val, add_value = self.create_min_max(data) + self.check_function( + spark, + pa.Check.in_range, + data["test_pass_data"], + data["test_fail_data"], + datatype, + (min_val, max_val, True, True), + ) + + @validate_params(params=BaseClass.params, scope="DATA") + def test_failed_unaccepted_datatypes(self, spark, datatype, data) -> None: + """Test the Check to see if all the values are equal to defined value""" + with pytest.raises(TypeError): + self.check_function( + spark, + pa.Check.in_range, + data["test_pass_data"], + None, + datatype, + data["test_expression"], + ) diff --git a/tests/pyspark/test_pyspark_container.py b/tests/pyspark/test_pyspark_container.py new file mode 100644 index 000000000..f30f6ed6a --- /dev/null +++ b/tests/pyspark/test_pyspark_container.py @@ -0,0 +1,116 @@ +"""Unit tests for pyspark container.""" + +from pyspark.sql import SparkSession +import pyspark.sql.types as T +import pytest +import pandera.pyspark as pa +import pandera.errors +from pandera.pyspark import DataFrameSchema, Column + +spark = SparkSession.builder.getOrCreate() + + +def test_pyspark_dataframeschema(): + """ + Test creating a pyspark DataFrameSchema object + 
""" + + schema = DataFrameSchema( + { + "name": Column(T.StringType()), + "age": Column(T.IntegerType(), coerce=True, nullable=True), + } + ) + + data = [("Neeraj", 35), ("Jask", 30)] + + df = spark.createDataFrame(data=data, schema=["name", "age"]) + df_out = schema.validate(df) + + assert df_out.pandera.errors != None + + data = [("Neeraj", "35"), ("Jask", "a")] + + df2 = spark.createDataFrame(data=data, schema=["name", "age"]) + + df_out = schema.validate(df2) + + assert not df_out.pandera.errors + + +def test_pyspark_dataframeschema_with_alias_types(config_params): + """ + Test creating a pyspark DataFrameSchema object + """ + + schema = DataFrameSchema( + columns={ + "product": Column("str", checks=pa.Check.str_startswith("B")), + "price": Column("int", checks=pa.Check.gt(5)), + }, + name="product_schema", + description="schema for product info", + title="ProductSchema", + ) + + data = [("Bread", 9), ("Butter", 15)] + + spark_schema = T.StructType( + [ + T.StructField("product", T.StringType(), False), + T.StructField("price", T.IntegerType(), False), + ], + ) + + df = spark.createDataFrame(data=data, schema=spark_schema) + + df_out = schema.validate(df) + + assert not df_out.pandera.errors + if config_params["DEPTH"] in ["SCHEMA_AND_DATA", "DATA_ONLY"]: + with pytest.raises(pandera.errors.PysparkSchemaError): + data_fail = [("Bread", 3), ("Butter", 15)] + + df_fail = spark.createDataFrame( + data=data_fail, schema=spark_schema + ) + + fail_df = schema.validate(df_fail) + if fail_df.pandera.errors: + raise pandera.errors.PysparkSchemaError + + +def test_pyspark_column_metadata(): + """ + Test creating a pyspark Column object with metadata + """ + + schema = DataFrameSchema( + columns={ + "product": Column( + "str", + checks=pa.Check.str_startswith("B"), + metadata={"usecase": "product_pricing", "type": ["t1", "t2"]}, + ), + "price": Column("int", checks=pa.Check.gt(5)), + }, + name="product_schema", + description="schema for product info", + title="ProductSchema", + metadata={"category": "product"}, + ) + + expected = { + "product_schema": { + "columns": { + "product": { + "usecase": "product_pricing", + "type": ["t1", "t2"], + }, + "price": None, + }, + "dataframe": {"category": "product"}, + } + } + + assert schema.get_metadata == expected diff --git a/tests/pyspark/test_pyspark_dtypes.py b/tests/pyspark/test_pyspark_dtypes.py new file mode 100644 index 000000000..533ee5e5a --- /dev/null +++ b/tests/pyspark/test_pyspark_dtypes.py @@ -0,0 +1,388 @@ +"""Unit tests for pyspark container.""" + +import pyspark.sql.types as T +import pytest + +from pandera.pyspark import DataFrameSchema, Column +from tests.pyspark.conftest import spark_df +from pandera.backends.pyspark.utils import ConfigParams +from pandera.backends.pyspark.decorators import validate_params +from pyspark.sql import DataFrame + + +class BaseClass: + params = ConfigParams("pyspark", "parameters.yaml") + + def validate_datatype(self, df, pandera_schema): + df_out = pandera_schema(df) + + assert df.pandera.schema == pandera_schema + assert isinstance(pandera_schema.validate(df), DataFrame) + assert isinstance(pandera_schema(df), DataFrame) + return df_out + + def pytest_generate_tests(self, metafunc): + # called once per each test function + funcarglist = metafunc.cls.params[metafunc.function.__name__] + argnames = sorted(funcarglist[0]) + metafunc.parametrize( + argnames, + [ + [funcargs[name] for name in argnames] + for funcargs in funcarglist + ], + ) + + def validate_data( + self, df, pandera_equivalent, column_name, 
return_error=False + ): + pandera_schema = DataFrameSchema( + columns={ + column_name: Column(pandera_equivalent), + }, + ) + df_out = self.validate_datatype(df, pandera_schema) + if df_out.pandera.errors: + if return_error: + return df_out.pandera.errors + else: + print(df_out.pandera.errors) + raise Exception(df_out.pandera.errors) + + +class TestAllNumericTypes(BaseClass): + # a map specifying multiple argument sets for a test method + params = { + "test_pyspark_all_float_types": [ + {"pandera_equivalent": float}, + {"pandera_equivalent": "FloatType()"}, + {"pandera_equivalent": T.FloatType()}, + {"pandera_equivalent": T.FloatType}, + {"pandera_equivalent": "float"}, + ], + "test_pyspark_decimal_default_types": [ + {"pandera_equivalent": "decimal"}, + {"pandera_equivalent": "DecimalType()"}, + {"pandera_equivalent": T.DecimalType}, + {"pandera_equivalent": T.DecimalType()}, + ], + "test_pyspark_decimal_parameterized_types": [ + { + "pandera_equivalent": { + "parameter_match": T.DecimalType(20, 5), + "parameter_mismatch": T.DecimalType(20, 3), + } + } + ], + "test_pyspark_all_double_types": [ + {"pandera_equivalent": T.DoubleType()}, + {"pandera_equivalent": T.DoubleType}, + {"pandera_equivalent": "double"}, + {"pandera_equivalent": "DoubleType()"}, + ], + "test_pyspark_all_int_types": [ + {"pandera_equivalent": int}, + {"pandera_equivalent": "int"}, + {"pandera_equivalent": "IntegerType()"}, + {"pandera_equivalent": T.IntegerType()}, + {"pandera_equivalent": T.IntegerType}, + ], + "test_pyspark_all_longint_types": [ + {"pandera_equivalent": "bigint"}, + {"pandera_equivalent": "long"}, + {"pandera_equivalent": T.LongType}, + {"pandera_equivalent": T.LongType()}, + {"pandera_equivalent": "LongType()"}, + ], + "test_pyspark_all_shortint_types": [ + {"pandera_equivalent": "ShortType()"}, + {"pandera_equivalent": T.ShortType}, + {"pandera_equivalent": T.ShortType()}, + {"pandera_equivalent": "short"}, + {"pandera_equivalent": "smallint"}, + ], + "test_pyspark_all_byteint_types": [ + {"pandera_equivalent": "ByteType()"}, + {"pandera_equivalent": T.ByteType}, + {"pandera_equivalent": T.ByteType()}, + {"pandera_equivalent": "bytes"}, + {"pandera_equivalent": "tinyint"}, + ], + } + + def create_schema(self, column_name, datatype): + spark_schema = T.StructType( + [ + T.StructField(column_name, datatype, False), + ], + ) + return spark_schema + + def test_pyspark_all_float_types( + self, spark, sample_data, pandera_equivalent + ): + """ + Test float dtype column + """ + column_name = "price" + spark_schema = self.create_schema(column_name, T.FloatType()) + df = spark_df(spark, sample_data, spark_schema) + self.validate_data(df, pandera_equivalent, column_name) + + def test_pyspark_all_double_types( + self, spark, sample_data, pandera_equivalent + ): + """ + Test double dtype column + """ + column_name = "price" + spark_schema = self.create_schema(column_name, T.DoubleType()) + df = spark_df(spark, sample_data, spark_schema) + self.validate_data(df, pandera_equivalent, column_name) + + def test_pyspark_decimal_default_types( + self, spark, sample_data, pandera_equivalent + ): + """ + Test decimal dtype column with default precision and scale + """ + column_name = "price" + spark_schema = self.create_schema(column_name, T.DecimalType()) + df = spark_df(spark, sample_data, spark_schema) + self.validate_data(df, pandera_equivalent, column_name) + + @validate_params(params=BaseClass.params, scope="SCHEMA") + def test_pyspark_decimal_parameterized_types( + self, spark, sample_data, pandera_equivalent + ): + """ + Test decimal dtype column with explicit precision and scale + """ + column_name = 
"price" + spark_schema = self.create_schema(column_name, T.DecimalType(20, 5)) + df = spark_df(spark, sample_data, spark_schema) + self.validate_data( + df, pandera_equivalent["parameter_match"], column_name + ) + errors = self.validate_data( + df, pandera_equivalent["parameter_mismatch"], column_name, True + ) + assert dict(errors["SCHEMA"]) == { + "WRONG_DATATYPE": [ + { + "schema": None, + "column": "price", + "check": "dtype('DecimalType(20,3)')", + "error": "expected column 'price' to have type DecimalType(20,3), " + "got DecimalType(20,5)", + } + ] + } + + def test_pyspark_all_int_types( + self, spark, sample_data, pandera_equivalent + ): + """ + Test int dtype column + """ + column_name = "price" + spark_schema = self.create_schema(column_name, T.IntegerType()) + df = spark_df(spark, sample_data, spark_schema) + self.validate_data(df, pandera_equivalent, column_name) + + def test_pyspark_all_longint_types( + self, spark, sample_data, pandera_equivalent + ): + """ + Test long dtype column + """ + column_name = "price" + spark_schema = self.create_schema(column_name, T.LongType()) + df = spark_df(spark, sample_data, spark_schema) + self.validate_data(df, pandera_equivalent, column_name) + + def test_pyspark_all_shortint_types( + self, spark, sample_data, pandera_equivalent + ): + """ + Test short dtype column + """ + column_name = "price" + spark_schema = self.create_schema(column_name, T.ShortType()) + df = spark_df(spark, sample_data, spark_schema) + self.validate_data(df, pandera_equivalent, column_name) + + def test_pyspark_all_byteint_types( + self, spark, sample_data, pandera_equivalent + ): + """ + Test byte dtype column + """ + column_name = "price" + spark_schema = self.create_schema(column_name, T.ByteType()) + df = spark_df(spark, sample_data, spark_schema) + self.validate_data(df, pandera_equivalent, column_name) + + +class TestAllDatetimeTypes(BaseClass): + # a map specifying multiple argument sets for a test method + params = { + "test_pyspark_all_date_types": [ + {"pandera_equivalent": T.DateType}, + {"pandera_equivalent": "DateType()"}, + {"pandera_equivalent": T.DateType()}, + {"pandera_equivalent": "date"}, + ], + "test_pyspark_all_datetime_types": [ + {"pandera_equivalent": T.TimestampType}, + {"pandera_equivalent": "TimestampType()"}, + {"pandera_equivalent": T.TimestampType()}, + {"pandera_equivalent": "datetime"}, + {"pandera_equivalent": "timestamp"}, + ], + "test_pyspark_all_daytimeinterval_types": [ + {"pandera_equivalent": T.DayTimeIntervalType}, + {"pandera_equivalent": "timedelta"}, + {"pandera_equivalent": T.DayTimeIntervalType()}, + {"pandera_equivalent": "DayTimeIntervalType()"}, + ], + "test_pyspark_daytimeinterval_param_mismatch": [ + {"pandera_equivalent": T.DayTimeIntervalType(1, 3)}, + ], + } + + def test_pyspark_all_date_types( + self, pandera_equivalent, sample_date_object + ): + column_name = "purchase_date" + df = sample_date_object.select(column_name) + self.validate_data(df, pandera_equivalent, column_name) + + def test_pyspark_all_datetime_types( + self, pandera_equivalent, sample_date_object + ): + column_name = "purchase_datetime" + df = sample_date_object.select(column_name) + self.validate_data(df, pandera_equivalent, column_name) + + def test_pyspark_all_daytimeinterval_types( + self, pandera_equivalent, sample_date_object + ): + column_name = "expiry_time" + df = sample_date_object.select(column_name) + self.validate_data(df, pandera_equivalent, column_name) + + @validate_params(params=BaseClass.params, scope="SCHEMA") + def 
test_pyspark_daytimeinterval_param_mismatch( + self, pandera_equivalent, sample_date_object + ): + column_name = "expected_time" + df = sample_date_object.select(column_name) + errors = self.validate_data(df, pandera_equivalent, column_name, True) + assert dict(errors["SCHEMA"]) == { + "WRONG_DATATYPE": [ + { + "schema": None, + "column": "expected_time", + "check": "dtype('DayTimeIntervalType(1, 3)')", + "error": "expected column 'expected_time' to have type DayTimeIntervalType(1, 3), " + "got DayTimeIntervalType(2, 3)", + } + ] + } + + +class TestBinaryStringTypes(BaseClass): + # a map specifying multiple argument sets for a test method + params = { + "test_pyspark_all_binary_types": [ + {"pandera_equivalent": "binary"}, + {"pandera_equivalent": "BinaryType()"}, + {"pandera_equivalent": T.BinaryType()}, + {"pandera_equivalent": T.BinaryType}, + ], + "test_pyspark_all_string_types": [ + {"pandera_equivalent": str}, + {"pandera_equivalent": "string"}, + {"pandera_equivalent": "StringType()"}, + {"pandera_equivalent": T.StringType()}, + {"pandera_equivalent": T.StringType}, + ], + } + + def test_pyspark_all_binary_types( + self, pandera_equivalent, sample_string_binary_object + ): + column_name = "purchase_info" + df = sample_string_binary_object.select(column_name) + self.validate_data(df, pandera_equivalent, column_name) + + def test_pyspark_all_string_types( + self, pandera_equivalent, sample_string_binary_object + ): + column_name = "product" + df = sample_string_binary_object.select(column_name) + self.validate_data(df, pandera_equivalent, column_name) + + +class TestComplexType(BaseClass): + params = { + "test_pyspark_array_type": [ + { + "pandera_equivalent": { + "schema_match": T.ArrayType(T.ArrayType(T.StringType())), + "schema_mismatch": T.ArrayType( + T.ArrayType(T.IntegerType()) + ), + } + } + ], + "test_pyspark_map_type": [ + { + "pandera_equivalent": { + "schema_match": T.MapType(T.StringType(), T.StringType()), + "schema_mismatch": T.MapType( + T.StringType(), T.IntegerType() + ), + } + } + ], + } + + @validate_params(params=BaseClass.params, scope="SCHEMA") + def test_pyspark_array_type(self, sample_complex_data, pandera_equivalent): + column_name = "customer_details" + df = sample_complex_data.select(column_name) + self.validate_data(df, pandera_equivalent["schema_match"], column_name) + errors = self.validate_data( + df, pandera_equivalent["schema_mismatch"], column_name, True + ) + assert dict(errors["SCHEMA"]) == { + "WRONG_DATATYPE": [ + { + "schema": None, + "column": "customer_details", + "check": "dtype('ArrayType(ArrayType(IntegerType(), True), True)')", + "error": "expected column 'customer_details' to have type ArrayType(ArrayType(IntegerType(), True), True), got ArrayType(ArrayType(StringType(), True), True)", + } + ] + } + + @validate_params(params=BaseClass.params, scope="SCHEMA") + def test_pyspark_map_type(self, sample_complex_data, pandera_equivalent): + column_name = "product_details" + df = sample_complex_data.select(column_name) + self.validate_data(df, pandera_equivalent["schema_match"], column_name) + errors = self.validate_data( + df, pandera_equivalent["schema_mismatch"], column_name, True + ) + assert dict(errors["SCHEMA"]) == { + "WRONG_DATATYPE": [ + { + "schema": None, + "column": "product_details", + "check": "dtype('MapType(StringType(), IntegerType(), True)')", + "error": "expected column 'product_details' to have type MapType(StringType(), IntegerType(), True), got MapType(StringType(), StringType(), True)", + } + ] + } diff --git 
a/tests/pyspark/test_pyspark_engine.py b/tests/pyspark/test_pyspark_engine.py new file mode 100644 index 000000000..d4b19a19c --- /dev/null +++ b/tests/pyspark/test_pyspark_engine.py @@ -0,0 +1,32 @@ +"""Tests Engine subclassing and registering DataTypes.""" +# pylint:disable=redefined-outer-name,unused-argument +# pylint:disable=missing-function-docstring,missing-class-docstring +"""Test pyspark engine.""" + +import pytest + +from pandera.engines import pyspark_engine + + +@pytest.mark.parametrize( + "data_type", list(pyspark_engine.Engine.get_registered_dtypes()) +) +def test_pyspark_data_type(data_type): + """Test pyspark engine DataType base class.""" + if data_type.type is None: + # don't test data types that require parameters e.g. Category + return + parameterized_datatypes = ["daytimeinterval", "decimal", "array", "map"] + + pyspark_engine.Engine.dtype(data_type) + pyspark_engine.Engine.dtype(data_type.type) + if data_type.type.typeName() not in parameterized_datatypes: + print(data_type.type.typeName()) + pyspark_engine.Engine.dtype(str(data_type.type)) + + with pytest.warns(UserWarning): + pd_dtype = pyspark_engine.DataType(data_type.type) + if data_type.type.typeName() not in parameterized_datatypes: + with pytest.warns(UserWarning): + pd_dtype_from_str = pyspark_engine.DataType(str(data_type.type)) + assert pd_dtype == pd_dtype_from_str diff --git a/tests/pyspark/test_pyspark_error.py b/tests/pyspark/test_pyspark_error.py new file mode 100644 index 000000000..afe2a4397 --- /dev/null +++ b/tests/pyspark/test_pyspark_error.py @@ -0,0 +1,153 @@ +"""Unit tests for pyspark error reporting.""" +from typing import Union + +from pyspark.sql import DataFrame, SparkSession +from pyspark.sql.functions import col +import pyspark.sql.types as T +import pytest +import pandera.pyspark as pa + +from pyspark.sql.types import StringType +from pandera.pyspark import DataFrameSchema, Column, DataFrameModel, Field +from tests.pyspark.conftest import spark_df + + +spark = SparkSession.builder.getOrCreate() + + +@pytest.mark.parametrize( + "schema, invalid_data", + [ + [ + pa.DataFrameSchema( + { + "product": pa.Column(StringType()), + "code": pa.Column(StringType(), pa.Check.not_equal_to(30)), + } + ), + spark.createDataFrame( + data=[("23", 31), ("34", 35)], schema=["product", "code"] + ), + ], + ], +) +def test_dataframe_add_schema( + schema: pa.DataFrameSchema, + invalid_data: Union[DataFrame, col], +) -> None: + """ + Test that the pyspark dataframe contains schema metadata after pandera validation. 
+ """ + schema(invalid_data, lazy=True) # type: ignore[arg-type] + + +def test_pyspark_check_eq(spark, sample_spark_schema): + """ + Test built-in checks on a pyspark DataFrameSchema object + """ + + pandera_schema = DataFrameSchema( + columns={ + "product": Column("str", checks=pa.Check.str_startswith("B")), + "price": Column("int", checks=pa.Check.gt(5)), + }, + name="product_schema", + description="schema for product info", + title="ProductSchema", + ) + + data_fail = [("Bread", 5), ("Cutter", 15)] + df_fail = spark_df(spark, data_fail, sample_spark_schema) + errors = pandera_schema.validate(check_obj=df_fail) + print(errors) + + +def test_pyspark_check_nullable(spark, sample_spark_schema): + """ + Test nullable column handling in a pyspark DataFrameSchema + """ + + pandera_schema = DataFrameSchema( + columns={ + "product": Column("str", checks=pa.Check.str_startswith("B")), + "price": Column("int", nullable=False), + } + ) + + data_fail = [("Bread", None), ("Cutter", 15)] + sample_spark_schema = T.StructType( + [ + T.StructField("product", T.StringType(), False), + T.StructField("price", T.IntegerType(), True), + ], + ) + df_fail = spark_df(spark, data_fail, sample_spark_schema) + errors = pandera_schema.validate(check_obj=df_fail) + print(errors.pandera.errors) + + +def test_pyspark_schema_data_checks(spark): + """ + Test schema and data level checks + """ + + pandera_schema = DataFrameSchema( + columns={ + "product": Column("str", checks=pa.Check.str_startswith("B")), + "price": Column("int", checks=pa.Check.gt(5)), + "id": Column(T.ArrayType(StringType())), + }, + name="product_schema", + description="schema for product info", + title="ProductSchema", + ) + + data_fail = [("Bread", 5, ["Food"]), ("Cutter", 15, ["99"])] + + spark_schema = T.StructType( + [ + T.StructField("product", T.StringType(), False), + T.StructField("price", T.IntegerType(), False), + T.StructField("id", T.ArrayType(StringType()), False), + ], + ) + + df_fail = spark_df(spark, data_fail, spark_schema) + errors = pandera_schema.validate(check_obj=df_fail) + print(errors) + + +def test_pyspark_fields(spark): + """ + Test defining a DataFrameModel with pyspark field types + """ + + class pandera_schema(DataFrameModel): + product: T.StringType = Field(str_startswith="B") + price: T.IntegerType = Field(gt=5) + id: T.DecimalType(20, 5) = Field() + id2: T.ArrayType(StringType()) = Field() + product_info: T.MapType(StringType(), StringType()) + + data_fail = [ + ("Bread", 5, 44.4, ["val"], {"product_category": "dairy"}), + ("Cutter", 15, 99.0, ["val2"], {"product_category": "bakery"}), + ] + + spark_schema = T.StructType( + [ + T.StructField("product", T.StringType(), False), + T.StructField("price", T.IntegerType(), False), + T.StructField("id", T.DecimalType(20, 5), False), + T.StructField("id2", T.ArrayType(T.StringType()), False), + T.StructField( + "product_info", + T.MapType(T.StringType(), T.StringType(), False), + False, + ), + ], + ) + df_fail = spark_df(spark, data_fail, spark_schema) + df_out = pandera_schema.validate(check_obj=df_fail) + + print(df_out) diff --git a/tests/pyspark/test_pyspark_model.py b/tests/pyspark/test_pyspark_model.py new file mode 100644 index 000000000..a9442ad4c --- /dev/null +++ b/tests/pyspark/test_pyspark_model.py @@ -0,0 +1,298 @@ +"""Unit tests for DataFrameModel module.""" + +from pyspark.sql import DataFrame +import pyspark.sql.types as T +import pytest +import pandera.pyspark as pa +from pandera.pyspark import DataFrameModel, DataFrameSchema, Field +from tests.pyspark.conftest import spark_df + + +def 
test_schema_with_bare_types(): + """ + Test that DataFrameModel can be defined without generics. + """ + + class Model(DataFrameModel): + a: int + b: str + c: float + + expected = pa.DataFrameSchema( + name="Model", + columns={ + "a": pa.Column(int), + "b": pa.Column(str), + "c": pa.Column(float), + }, + ) + + assert expected == Model.to_schema() + + +def test_schema_with_bare_types_and_field(): + """ + Test that DataFrameModel can be defined with bare types and Field(). + """ + + class Model(DataFrameModel): + a: int = Field() + b: str = Field() + c: float = Field() + + expected = DataFrameSchema( + name="Model", + columns={ + "a": pa.Column(int), + "b": pa.Column(str), + "c": pa.Column(float), + }, + ) + + assert expected == Model.to_schema() + + +def test_schema_with_bare_types_field_and_checks(spark): + """ + Test that DataFrameModel supports bare types with Field checks. + """ + + class Model(DataFrameModel): + a: str = Field(str_startswith="B") + b: int = Field(gt=6) + c: float = Field() + + expected = DataFrameSchema( + name="Model", + columns={ + "a": pa.Column(str, checks=pa.Check.str_startswith("B")), + "b": pa.Column(int, checks=pa.Check.gt(6)), + "c": pa.Column(float), + }, + ) + + assert expected == Model.to_schema() + + data_fail = [("Bread", 5, "Food"), ("Cutter", 15, 99.99)] + + spark_schema = T.StructType( + [ + T.StructField("a", T.StringType(), False), # should fail + T.StructField("b", T.IntegerType(), False), # should fail + T.StructField("c", T.FloatType(), False), + ], + ) + + df_fail = spark_df(spark, data_fail, spark_schema) + df_out = Model.validate(check_obj=df_fail) + assert df_out.pandera.errors is not None + + +def test_schema_with_bare_types_field_type(spark): + """ + Test DataFrameModel validation when column data types don't match the model. + """ + + class Model(DataFrameModel): + a: str = Field(str_startswith="B") + b: int = Field(gt=6) + c: float = Field() + + data_fail = [("Bread", 5, "Food"), ("Cutter", 15, 99.99)] + + spark_schema = T.StructType( + [ + T.StructField("a", T.StringType(), False), # should fail + T.StructField("b", T.IntegerType(), False), # should fail + T.StructField("c", T.StringType(), False), # should fail + ], + ) + + df_fail = spark_df(spark, data_fail, spark_schema) + df_out = Model.validate(check_obj=df_fail) + assert df_out.pandera.errors is not None + + +def test_pyspark_bare_fields(spark): + """ + Test DataFrameModel fields with pyspark native data types + """ + + class pandera_schema(DataFrameModel): + id: T.IntegerType() = Field(gt=5) + product_name: T.StringType() = Field(str_startswith="B") + price: T.DecimalType(20, 5) = Field() + description: T.ArrayType(T.StringType()) = Field() + meta: T.MapType(T.StringType(), T.StringType()) = Field() + + data_fail = [ + ( + 5, + "Bread", + 44.4, + ["description of product"], + {"product_category": "dairy"}, + ), + ( + 15, + "Butter", + 99.0, + ["more details here"], + {"product_category": "bakery"}, + ), + ] + + spark_schema = T.StructType( + [ + T.StructField("id", T.IntegerType(), False), + T.StructField("product", T.StringType(), False), + T.StructField("price", T.DecimalType(20, 5), False), + T.StructField( + "description", T.ArrayType(T.StringType(), False), False + ), + T.StructField( + "meta", T.MapType(T.StringType(), T.StringType(), False), False + ), + ], + ) + df_fail = spark_df(spark, data_fail, spark_schema) + df_out = pandera_schema.validate(check_obj=df_fail) + assert df_out.pandera.errors is not None + + +def test_dataframe_schema_strict(spark) -> None: + """ + Checks that schema errors are raised when strict=True and the dataframe + columns do not 
match the schema columns. + """ + schema = DataFrameSchema( + { + "a": pa.Column("long", nullable=True), + "b": pa.Column("int", nullable=True), + }, + strict=True, + ) + df = spark.createDataFrame( + [[1, 1, 1, 1], [2, 2, 2, 2], [3, 3, 3, 3]], ["a", "b", "c", "d"] + ) + + df_out = schema.validate(df.select(["a", "b"])) + + assert isinstance(df_out, DataFrame) + with pytest.raises(pa.PysparkSchemaError): + df_out = schema.validate(df) + print(df_out.pandera.errors) + if df_out.pandera.errors: + raise pa.PysparkSchemaError + + schema.strict = "filter" + assert isinstance(schema.validate(df), DataFrame) + + assert list(schema.validate(df).columns) == ["a", "b"] + # + with pytest.raises(pa.SchemaInitError): + DataFrameSchema( + { + "a": pa.Column(int, nullable=True), + "b": pa.Column(int, nullable=True), + }, + strict="foobar", # type: ignore[arg-type] + ) + + with pytest.raises(pa.PysparkSchemaError): + df_out = schema.validate(df.select("a")) + if df_out.pandera.errors: + raise pa.PysparkSchemaError + with pytest.raises(pa.PysparkSchemaError): + df_out = schema.validate(df.select(["a", "c"])) + if df_out.pandera.errors: + raise pa.PysparkSchemaError + + +def test_pyspark_fields_metadata(spark): + """ + Test schema and metadata on field + """ + + class pandera_schema(DataFrameModel): + id: T.IntegerType() = Field( + gt=5, + metadata={ + "usecase": ["telco", "retail"], + "category": "product_pricing", + }, + ) + product_name: T.StringType() = Field(str_startswith="B") + price: T.DecimalType(20, 5) = Field() + + class Config: + name = "product_info" + strict = True + coerce = True + metadata = {"category": "product-details"} + + expected = { + "product_info": { + "columns": { + "id": { + "usecase": ["telco", "retail"], + "category": "product_pricing", + }, + "product_name": None, + "price": None, + }, + "dataframe": {"category": "product-details"}, + } + } + assert pandera_schema.get_metadata() == expected + + +def test_dataframe_schema_strict_config(spark, config_params) -> None: + """ + Checks that schema errors are raised when strict=True and the dataframe + columns do not match the schema columns. 
+ """ + if config_params["DEPTH"] != "DATA_ONLY": + schema = DataFrameSchema( + { + "a": pa.Column("long", nullable=True), + "b": pa.Column("int", nullable=True), + }, + strict=True, + ) + df = spark.createDataFrame( + [[1, 1, 1, 1], [2, 2, 2, 2], [3, 3, 3, 3]], ["a", "b", "c", "d"] + ) + + df_out = schema.validate(df.select(["a", "b"])) + + assert isinstance(df_out, DataFrame) + + with pytest.raises(pa.PysparkSchemaError): + df_out = schema.validate(df) + print(df_out.pandera.errors) + if df_out.pandera.errors: + raise pa.PysparkSchemaError + + schema.strict = "filter" + assert isinstance(schema.validate(df), DataFrame) + + assert list(schema.validate(df).columns) == ["a", "b"] + # + with pytest.raises(pa.SchemaInitError): + DataFrameSchema( + { + "a": pa.Column(int, nullable=True), + "b": pa.Column(int, nullable=True), + }, + strict="foobar", # type: ignore[arg-type] + ) + + with pytest.raises(pa.PysparkSchemaError): + df_out = schema.validate(df.select("a")) + if df_out.pandera.errors: + raise pa.PysparkSchemaError + with pytest.raises(pa.PysparkSchemaError): + df_out = schema.validate(df.select(["a", "c"])) + if df_out.pandera.errors: + raise pa.PysparkSchemaError diff --git a/tests/pyspark/test_schemas_on_pyspark.py b/tests/pyspark/test_schemas_on_pyspark_pandas.py similarity index 99% rename from tests/pyspark/test_schemas_on_pyspark.py rename to tests/pyspark/test_schemas_on_pyspark_pandas.py index 401444993..8e0672c32 100644 --- a/tests/pyspark/test_schemas_on_pyspark.py +++ b/tests/pyspark/test_schemas_on_pyspark_pandas.py @@ -62,7 +62,7 @@ pandas_engine.Date, } -SPARK_VERSION = version.parse(SparkContext().version) +SPARK_VERSION = version.parse(SparkContext.getOrCreate().version) if SPARK_VERSION < version.parse("3.3.0"): PYSPARK_PANDAS_UNSUPPORTED.add(numpy_engine.Timedelta64)
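Note on usage: the tests added above exercise the pyspark validation API roughly as in the sketch below. This is an illustrative sketch only and not part of the patch; the column names and sample rows are invented for the example, and it simply restates how a pandera.pyspark DataFrameSchema is built and validated in these tests.

# Illustrative sketch (assumed usage, not part of the patch): validate a pyspark
# DataFrame against a pandera.pyspark DataFrameSchema and inspect collected errors.
import pandera.pyspark as pa
import pyspark.sql.types as T
from pandera.pyspark import Column, DataFrameSchema
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

schema = DataFrameSchema(
    columns={
        "product": Column(T.StringType(), checks=pa.Check.str_startswith("B")),
        "price": Column(T.IntegerType(), checks=pa.Check.gt(5)),
    },
)

# Invented sample data; the second row violates the str_startswith check.
spark_schema = T.StructType(
    [
        T.StructField("product", T.StringType(), False),
        T.StructField("price", T.IntegerType(), False),
    ]
)
df = spark.createDataFrame([("Bread", 9), ("Cutter", 15)], spark_schema)

# Validation returns the dataframe; failures are collected on the `pandera`
# accessor (df_out.pandera.errors) instead of being raised eagerly.
df_out = schema.validate(check_obj=df)
print(df_out.pandera.errors)

The dtype tests rely on this same deferred reporting: validate_data() inspects df_out.pandera.errors and the parameterized tests assert on the summarized errors["SCHEMA"] dictionary.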