Fix decimal (#956)
* fix dtypes.Decimal bug

Signed-off-by: Niels Bantilan <niels.bantilan@gmail.com>

* update docs

Signed-off-by: Niels Bantilan <niels.bantilan@gmail.com>

* update dtypes docs

Signed-off-by: Niels Bantilan <niels.bantilan@gmail.com>

* update deps

Signed-off-by: Niels Bantilan <niels.bantilan@gmail.com>

* re-pin protobuf

Signed-off-by: Niels Bantilan <niels.bantilan@gmail.com>

Signed-off-by: Niels Bantilan <niels.bantilan@gmail.com>
cosmicBboy authored Oct 8, 2022
1 parent 4e9c997 commit 3063a0a
Showing 7 changed files with 109 additions and 29 deletions.
70 changes: 60 additions & 10 deletions docs/source/dtypes.rst
@@ -9,7 +9,7 @@ Pandera Data Types

*new in 0.7.0*

-.. _dtypes-into:
+.. _dtypes-intro:

Motivations
~~~~~~~~~~~
@@ -201,17 +201,19 @@ underlying numpy dtype to coerce an individual value. The ``pandas`` -native
datatypes like :class:`~pandas.CategoricalDtype` and :class:`~pandas.BooleanDtype`
are also supported.

As an example of a special-cased ``coerce_value`` implementation, see
:py:meth:`~pandera.engines.pandas_engine.Category.coerce_value`:
As an example of a special-cased ``coerce_value`` implementation, see the
source code for :meth:`pandera.engines.pandas_engine.Category.coerce_value`:

.. code-block:: python
.. literalinclude:: ../../pandera/engines/pandas_engine.py
:lines: 580-586

And :py:meth:`~pandera.engines.pandas_engine.BOOL.coerce_value`:
def coerce_value(self, value: Any) -> Any:
"""Coerce an value to a particular type."""
if value not in self.categories: # type: ignore
raise TypeError(
f"value {value} cannot be coerced to type {self.type}"
)
return value
.. literalinclude:: ../../pandera/engines/pandas_engine.py
:lines: 223-229
Logical data types
~~~~~~~~~~~~~~~~~~
@@ -224,7 +226,7 @@ e.g.: ``Int8``, ``Float32``, ``String``, etc., whereas logical types represent t
abstracted understanding of that data. e.g.: ``IPs``, ``URLs``, ``paths``, etc.

Validating a logical data type consists of validating the supporting physical data type
-(see :ref:`dtypes-into`) and a check on actual values. For example, an IP address data
+(see :ref:`dtypes-intro`) and a check on actual values. For example, an IP address data
type would validate that:

1. The data container type is a ``String``.
@@ -238,3 +240,51 @@ validated via the pandera DataType :class:`~pandera.dtypes.Decimal`.
To implement a logical data type, you just need to implement the method
:meth:`pandera.dtypes.DataType.check` and make use of the ``data_container`` argument to
perform checks on the values of the data.

+For example, you can create an ``IPAddress`` datatype that inherits from the numpy string
+physical type, thereby storing the values as strings, and checks whether the values actually
+match an IP address regular expression.
+
+.. testcode:: dtypes
+
+    import re
+    from typing import Optional, Iterable, Union
+
+    @pandas_engine.Engine.register_dtype
+    @dtypes.immutable
+    class IPAddress(pandas_engine.NpString):
+
+        def check(
+            self,
+            pandera_dtype: dtypes.DataType,
+            data_container: Optional[pd.Series] = None,
+        ) -> Union[bool, Iterable[bool]]:
+
+            # ensure that the data container's data type is a string,
+            # using the parent class's check implementation
+            correct_type = super().check(pandera_dtype)
+            if not correct_type:
+                return correct_type
+
+            # ensure the filepaths actually exist locally
+            exp = re.compile(r"(\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})")
+            return data_container.map(lambda x: exp.match(x) is not None)
+
+        def __str__(self) -> str:
+            return str(self.__class__.__name__)
+
+        def __repr__(self) -> str:
+            return f"DataType({self})"
+
+
+    schema = pa.DataFrameSchema(columns={"ips": pa.Column(IPAddress)})
+    schema.validate(pd.DataFrame({"ips": ["0.0.0.0", "0.0.0.1", "0.0.0.a"]}))
+
+.. testoutput:: dtypes
+
+    Traceback (most recent call last):
+    ...
+    pandera.errors.SchemaError: expected series 'ips' to have type IPAddress:
+    failure cases:
+       index failure_case
+    0      2      0.0.0.a
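
Note on the value check in the docs example: it tests each value with `re.match`, which anchors only at the start of the string and does not bound each octet to 0-255. A stricter per-value test could lean on the standard library's `ipaddress` module. The sketch below is illustrative only; `is_ipv4` is a hypothetical helper name and is not part of pandera or of this commit.

```python
import ipaddress
import re

# The pattern used in the docs example; `match` anchors only at the start.
exp = re.compile(r"(\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})")


def is_ipv4(value: str) -> bool:
    """Return True only if `value` parses as a complete, valid IPv4 address."""
    try:
        ipaddress.IPv4Address(value)
    except ValueError:
        # AddressValueError is a subclass of ValueError.
        return False
    return True


candidate = "999.0.0.1 trailing"
loose_match = exp.match(candidate) is not None  # the regex accepts the prefix
strict_match = is_ipv4(candidate)               # the parser rejects the string
```

Inside a `check` implementation, such a helper could presumably be applied with `data_container.map(is_ipv4)` in place of the regex lambda, at the cost of a Python-level call per value.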
1 change: 1 addition & 0 deletions docs/source/reference/dtypes.rst
@@ -65,6 +65,7 @@ Passing native pandas dtypes to pandera components is preferred.
pandera.engines.pandas_engine.DateTime
pandera.engines.pandas_engine.Date
pandera.engines.pandas_engine.Decimal
+pandera.engines.pandas_engine.Category

GeoPandas Dtypes
----------------
2 changes: 1 addition & 1 deletion environment.yml
@@ -29,7 +29,7 @@ dependencies:

# modin extra
- modin
-  - protobuf <= 3.20
+  - protobuf <= 3.20.3

# dask extra
- dask
31 changes: 30 additions & 1 deletion pandera/dtypes.py
@@ -3,6 +3,7 @@
 from __future__ import annotations

 import dataclasses
+import decimal
 import inspect
 from abc import ABC
 from typing import (
@@ -390,6 +391,11 @@ class Complex64(Complex128):
 DEFAULT_PYTHON_PREC = 28


+def _scale_to_exp(scale: int) -> decimal.Decimal:
+    scale_fmt = format(10**-scale, f".{scale}f")
+    return decimal.Decimal(scale_fmt)
+
+
 @immutable(init=True)
 class Decimal(_Number):
     """Semantic representation of a decimal data type."""
@@ -402,7 +408,21 @@ class Decimal(_Number):
     scale: int = 0  # default 0 is aligned with pyarrow and various databases.
     """The number of digits after the decimal point."""

-    def __init__(self, precision: int = DEFAULT_PYTHON_PREC, scale: int = 0):
+    # pylint: disable=line-too-long
+    rounding: str = dataclasses.field(
+        default_factory=lambda: decimal.getcontext().rounding
+    )
+    """
+    The `rounding mode <https://docs.python.org/3/library/decimal.html#rounding-modes>`__
+    supported by the Python :py:class:`decimal.Decimal` class.
+    """
+
+    def __init__(
+        self,
+        precision: int = DEFAULT_PYTHON_PREC,
+        scale: int = 0,
+        rounding: Optional[str] = None,
+    ):
         super().__init__()
         if precision <= 0:
             raise ValueError(
@@ -414,6 +434,15 @@ def __init__(self, precision: int = DEFAULT_PYTHON_PREC, scale: int = 0):
             )
         object.__setattr__(self, "precision", precision)
         object.__setattr__(self, "scale", scale)
+        object.__setattr__(self, "rounding", rounding)
+        object.__setattr__(
+            self, "_exp", _scale_to_exp(scale) if scale else None
+        )
+        object.__setattr__(
+            self,
+            "_ctx",
+            decimal.Context(prec=precision, rounding=self.rounding),
+        )

     def __str__(self) -> str:
         return f"{self.__class__.__name__}({self.precision}, {self.scale})"
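
For orientation, the `_exp` and `_ctx` attributes set above pair naturally with `decimal.Decimal.quantize`. The standalone sketch below uses only the standard library and shows what the exponent built by the `_scale_to_exp` formatting trick looks like for `scale=2`, and how a context with an explicit rounding mode would round a value to that scale. How pandera itself consumes these attributes is not shown in this diff, so the `quantize` call is an assumption about usage, and `scale_to_exp` here is a local copy rather than an import from pandera.

```python
import decimal


def scale_to_exp(scale: int) -> decimal.Decimal:
    """Build the quantization exponent for `scale` fractional digits."""
    # Same formatting trick as `_scale_to_exp` in the diff:
    # scale=2 -> "0.01" -> Decimal("0.01").
    return decimal.Decimal(format(10**-scale, f".{scale}f"))


exp = scale_to_exp(2)
ctx = decimal.Context(prec=5, rounding=decimal.ROUND_HALF_UP)

# Assumed usage: round a value to `scale` digits under the context's rules.
rounded = decimal.Decimal("1.005").quantize(exp, context=ctx)
print(exp, rounded)
```

Note that the diff stores `None` in `_exp` when `scale` is 0, so any consumer would need to skip quantization in that case.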
7 changes: 7 additions & 0 deletions pandera/engines/engine.py
@@ -203,6 +203,13 @@ def dtype(cls: _EngineType, data_type: Any) -> _DataType:
         equivalent_data_type = registry.equivalents.get(data_type)
         if equivalent_data_type is not None:
             return equivalent_data_type
+        elif isinstance(data_type, DataType):
+            # in the case where data_type is a parameterized dtypes.DataType instance that isn't
+            # in the equivalents registry, use its type to get the equivalent, and feed
+            # the parameters into the recognized data type class.
+            equivalent_data_type = registry.equivalents.get(type(data_type))
+            if equivalent_data_type is not None:
+                return type(equivalent_data_type)(**data_type.__dict__)

         try:
             return registry.dispatch(data_type)
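
The fallback added to `Engine.dtype` can be pictured with a toy registry. The sketch below is a simplified stand-in, not pandera's implementation: `AbstractDecimal`, `EngineDecimal`, `EQUIVALENTS`, and `resolve` are hypothetical names, and a plain dict replaces the engine's registry object. It shows the idea from the code comment above: when a parameterized instance is not itself a registered equivalent, look up its class instead and rebuild the engine-specific type with the instance's parameters.

```python
import dataclasses
from typing import Any


@dataclasses.dataclass(frozen=True)
class AbstractDecimal:
    """Hypothetical stand-in for an engine-agnostic, parameterized dtype."""

    precision: int = 28
    scale: int = 0


@dataclasses.dataclass(frozen=True)
class EngineDecimal(AbstractDecimal):
    """Hypothetical stand-in for the engine-specific dtype."""


# Toy registry: keys are "equivalents", values are default engine dtype instances.
EQUIVALENTS = {"decimal": EngineDecimal(), AbstractDecimal: EngineDecimal()}


def resolve(data_type: Any) -> EngineDecimal:
    found = EQUIVALENTS.get(data_type)
    if found is not None:
        return found
    if isinstance(data_type, AbstractDecimal):
        # A parameterized instance is not a registry key, but its class is:
        # rebuild the engine dtype from the instance's own parameters.
        found = EQUIVALENTS.get(type(data_type))
        if found is not None:
            return type(found)(**vars(data_type))
    raise TypeError(f"data type {data_type!r} not understood")


resolved = resolve(AbstractDecimal(precision=10, scale=2))
```

Without the second lookup, `AbstractDecimal(10, 2)` would fall through to the error path even though its class is registered, which appears to be the situation the commit addresses for `dtypes.Decimal`.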
25 changes: 9 additions & 16 deletions pandera/engines/pandas_engine.py
@@ -428,11 +428,6 @@ class FLOAT32(FLOAT64):
 ###############################################################################


-def _scale_to_exp(scale: int) -> decimal.Decimal:
-    scale_fmt = format(10**-scale, f".{scale}f")
-    return decimal.Decimal(scale_fmt)
-
-
 def _check_decimal(
     pandas_obj: pd.Series,
     precision: Optional[int] = None,
@@ -471,7 +466,9 @@ def _check_decimal(
     return is_valid.to_numpy()


-@Engine.register_dtype(equivalents=["decimal", decimal.Decimal])
+@Engine.register_dtype(
+    equivalents=["decimal", decimal.Decimal, dtypes.Decimal]
+)
 @immutable(init=True)
 class Decimal(DataType, dtypes.Decimal):
     # pylint:disable=line-too-long
@@ -488,6 +485,11 @@ class Decimal(DataType, dtypes.Decimal):
     rounding: str = dataclasses.field(
         default_factory=lambda: decimal.getcontext().rounding
     )
+    """
+    The `rounding mode <https://docs.python.org/3/library/decimal.html#rounding-modes>`__
+    supported by the Python :py:class:`decimal.Decimal` class.
+    """
+
     _exp: decimal.Decimal = dataclasses.field(init=False)
     _ctx: decimal.Context = dataclasses.field(init=False)

@@ -497,16 +499,7 @@ def __init__(  # pylint:disable=super-init-not-called
         scale: int = 0,
         rounding: Optional[str] = None,
     ) -> None:
-        dtypes.Decimal.__init__(self, precision, scale)
-        object.__setattr__(self, "rounding", rounding)
-        object.__setattr__(
-            self, "_exp", _scale_to_exp(scale) if scale else None
-        )
-        object.__setattr__(
-            self,
-            "_ctx",
-            decimal.Context(prec=precision, rounding=self.rounding),
-        )
+        dtypes.Decimal.__init__(self, precision, scale, rounding)

     def coerce_value(self, value: Any) -> decimal.Decimal:
         """Coerce an value to a particular type."""
2 changes: 1 addition & 1 deletion requirements-dev.txt
@@ -18,7 +18,7 @@ pandas-stubs <= 1.4.3.220807
pyspark-stubs
pyspark >= 3.2.0
modin
-protobuf <= 3.20
+protobuf <= 3.20.3
dask
distributed
geopandas