Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add pandas pyarrow backend support #1628

Merged
merged 7 commits into from
May 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions pandera/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
INT64,
PANDAS_1_2_0_PLUS,
PANDAS_1_3_0_PLUS,
PANDAS_2_0_0_PLUS,
STRING,
UINT8,
UINT16,
Expand Down Expand Up @@ -136,7 +137,9 @@
"INT16",
"INT32",
"INT64",
"PANDAS_1_2_0_PLUS",
"PANDAS_1_3_0_PLUS",
"PANDAS_2_0_0_PLUS",
"STRING",
"UINT8",
"UINT16",
Expand Down
261 changes: 261 additions & 0 deletions pandera/engines/pandas_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@

# Version feature flags: each is True when the release tuple reported by
# ``pandas_version()`` is at least the version named in the flag.
# NOTE(review): presumably ``pandas_version()`` reports the installed pandas
# version -- confirm against its definition.
PANDAS_1_2_0_PLUS = pandas_version().release >= (1, 2, 0)
PANDAS_1_3_0_PLUS = pandas_version().release >= (1, 3, 0)
PANDAS_2_0_0_PLUS = pandas_version().release >= (2, 0, 0)


# register different TypedDict type depending on python version
Expand All @@ -101,6 +102,16 @@
)


def is_pyarrow_dtype(
    pd_dtype: PandasDataType,
) -> Union[bool, Iterable[bool]]:
    """Check if a value is an instance of a pandas pyarrow dtype.

    :param pd_dtype: value to inspect.
    :returns: ``True`` if ``pd_dtype`` is an instance of
        :class:`pandas.ArrowDtype`. ``False`` otherwise, including when
        pyarrow is not installed or the available pandas does not expose
        ``ArrowDtype``.
    """
    # This predicate is used as an ``elif`` guard during dtype resolution.
    # Raising here when pyarrow is missing would prevent the pandas fallback
    # branch from ever running, so "cannot be a pyarrow dtype" is reported
    # as ``False`` instead of an exception.
    arrow_dtype_cls = getattr(pd, "ArrowDtype", None)
    if not PYARROW_INSTALLED or arrow_dtype_cls is None:
        return False
    return isinstance(pd_dtype, arrow_dtype_cls)


@immutable(init=True)
class DataType(dtypes.DataType):
"""Base `DataType` for boxing Pandas data types."""
Expand Down Expand Up @@ -220,6 +231,8 @@
"Usage Tip: Use an instance or a string "
"representation."
) from None
elif is_pyarrow_dtype(data_type):
np_or_pd_dtype = data_type.pyarrow_dtype
else:
# let pandas transform any acceptable value
# into a numpy or pandas dtype.
Expand Down Expand Up @@ -1570,3 +1583,251 @@

def __str__(self) -> str:
return str(NamedTuple.__name__)


###############################################################################
# pyarrow types
###############################################################################

if PYARROW_INSTALLED and PANDAS_2_0_0_PLUS:

@Engine.register_dtype(
    equivalents=[
        "bool[pyarrow]",
        pyarrow.bool_,
        pd.ArrowDtype(pyarrow.bool_()),
    ]
)
@immutable
class ArrowBool(BOOL):
    """Boolean data type backed by :class:`pyarrow.bool_`.

    Registered under the ``"bool[pyarrow]"`` alias, the ``pyarrow.bool_``
    factory and the corresponding :class:`pandas.ArrowDtype` instance.
    """

    # pandas extension dtype wrapping the pyarrow boolean type.
    type = pd.ArrowDtype(pyarrow.bool_())

@Engine.register_dtype(
    equivalents=[
        "int64[pyarrow]",
        pyarrow.int64,
        pd.ArrowDtype(pyarrow.int64()),
    ]
)
@immutable
class ArrowInt64(DataType, dtypes.Int):
    """64-bit integer data type backed by :class:`pyarrow.int64`.

    Registered under the ``"int64[pyarrow]"`` alias, the ``pyarrow.int64``
    factory and the corresponding :class:`pandas.ArrowDtype` instance.
    """

    # pandas extension dtype wrapping the pyarrow 64-bit integer type.
    type = pd.ArrowDtype(pyarrow.int64())
    bit_width: int = 64

@Engine.register_dtype(
    equivalents=[
        "int32[pyarrow]",
        pyarrow.int32,
        pd.ArrowDtype(pyarrow.int32()),
    ]
)
@immutable
class ArrowInt32(ArrowInt64):
    """32-bit integer data type backed by :class:`pyarrow.int32`.

    Subclasses :class:`ArrowInt64`, overriding ``type`` and ``bit_width``.
    """

    # pandas extension dtype wrapping the pyarrow 32-bit integer type.
    type = pd.ArrowDtype(pyarrow.int32())
    bit_width: int = 32

@Engine.register_dtype(
    equivalents=[
        "int16[pyarrow]",
        pyarrow.int16,
        pd.ArrowDtype(pyarrow.int16()),
    ]
)
@immutable
class ArrowInt16(ArrowInt32):
    """16-bit integer data type backed by :class:`pyarrow.int16`.

    Subclasses :class:`ArrowInt32`, overriding ``type`` and ``bit_width``.
    """

    # pandas extension dtype wrapping the pyarrow 16-bit integer type.
    type = pd.ArrowDtype(pyarrow.int16())
    bit_width: int = 16

@Engine.register_dtype(
    equivalents=[
        "int8[pyarrow]",
        pyarrow.int8,
        pd.ArrowDtype(pyarrow.int8()),
    ]
)
@immutable
class ArrowInt8(ArrowInt16):
    """8-bit integer data type backed by :class:`pyarrow.int8`.

    Subclasses :class:`ArrowInt16`, overriding ``type`` and ``bit_width``.
    """

    # pandas extension dtype wrapping the pyarrow 8-bit integer type.
    type = pd.ArrowDtype(pyarrow.int8())
    bit_width: int = 8

@Engine.register_dtype(
    equivalents=[
        pyarrow.string,
        # Register the ArrowDtype instance as well, like the other
        # non-parametrized Arrow dtypes, so that
        # ``pd.ArrowDtype(pyarrow.string())`` resolves to this class.
        pd.ArrowDtype(pyarrow.string()),
    ]
)
@immutable
class ArrowString(DataType, dtypes.String):
    """String data type backed by :class:`pyarrow.string`.

    The ``"string[pyarrow]"`` alias is intentionally not registered here.
    NOTE(review): pandas appears to parse that alias to its own ``string``
    dtype rather than to ``ArrowDtype`` -- confirm before adding it.
    """

    # pandas extension dtype wrapping the pyarrow string type.
    type = pd.ArrowDtype(pyarrow.string())

@Engine.register_dtype(
    equivalents=[
        "uint64[pyarrow]",
        pyarrow.uint64,
        pd.ArrowDtype(pyarrow.uint64()),
    ]
)
@immutable
class ArrowUInt64(DataType, dtypes.UInt):
    """64-bit unsigned integer data type backed by :class:`pyarrow.uint64`.

    Registered under the ``"uint64[pyarrow]"`` alias, the ``pyarrow.uint64``
    factory and the corresponding :class:`pandas.ArrowDtype` instance.
    """

    # pandas extension dtype wrapping the pyarrow unsigned 64-bit type.
    type = pd.ArrowDtype(pyarrow.uint64())
    bit_width: int = 64

@Engine.register_dtype(
    equivalents=[
        "uint32[pyarrow]",
        pyarrow.uint32,
        pd.ArrowDtype(pyarrow.uint32()),
    ]
)
@immutable
class ArrowUInt32(ArrowUInt64):
    """32-bit unsigned integer data type backed by :class:`pyarrow.uint32`.

    Subclasses :class:`ArrowUInt64`, overriding ``type`` and ``bit_width``.
    """

    # pandas extension dtype wrapping the pyarrow unsigned 32-bit type.
    type = pd.ArrowDtype(pyarrow.uint32())
    bit_width: int = 32

@Engine.register_dtype(
    equivalents=[
        "uint16[pyarrow]",
        pyarrow.uint16,
        pd.ArrowDtype(pyarrow.uint16()),
    ]
)
@immutable
class ArrowUInt16(ArrowUInt32):
    """16-bit unsigned integer data type backed by :class:`pyarrow.uint16`.

    Subclasses :class:`ArrowUInt32`, overriding ``type`` and ``bit_width``.
    """

    # pandas extension dtype wrapping the pyarrow unsigned 16-bit type.
    type = pd.ArrowDtype(pyarrow.uint16())
    bit_width: int = 16

@Engine.register_dtype(
    equivalents=[
        "uint8[pyarrow]",
        pyarrow.uint8,
        pd.ArrowDtype(pyarrow.uint8()),
    ]
)
@immutable
class ArrowUInt8(ArrowUInt16):
    """8-bit unsigned integer data type backed by :class:`pyarrow.uint8`.

    Subclasses :class:`ArrowUInt16`, overriding ``type`` and ``bit_width``.
    """

    # pandas extension dtype wrapping the pyarrow unsigned 8-bit type.
    type = pd.ArrowDtype(pyarrow.uint8())
    bit_width: int = 8

@Engine.register_dtype(
    equivalents=[
        "double[pyarrow]",
        pyarrow.float64,
        pd.ArrowDtype(pyarrow.float64()),
    ]
)
@immutable
class ArrowFloat64(DataType, dtypes.Float):
    """64-bit floating point data type backed by :class:`pyarrow.float64`.

    Registered under the ``"double[pyarrow]"`` alias, the
    ``pyarrow.float64`` factory and the corresponding
    :class:`pandas.ArrowDtype` instance.
    """

    # pandas extension dtype wrapping the pyarrow 64-bit float type.
    type = pd.ArrowDtype(pyarrow.float64())
    bit_width: int = 64

@Engine.register_dtype(
    equivalents=[
        "float[pyarrow]",
        pyarrow.float32,
        pd.ArrowDtype(pyarrow.float32()),
    ]
)
@immutable
class ArrowFloat32(ArrowFloat64):
    """32-bit floating point data type backed by :class:`pyarrow.float32`.

    Subclasses :class:`ArrowFloat64`, overriding ``type`` and ``bit_width``.
    """

    # pandas extension dtype wrapping the pyarrow 32-bit float type.
    type = pd.ArrowDtype(pyarrow.float32())
    bit_width: int = 32

@Engine.register_dtype(
    equivalents=[pyarrow.decimal128, pyarrow.Decimal128Type]
)
@immutable(init=True)
class ArrowDecimal128(DataType, dtypes.Decimal):
    """Decimal data type backed by :class:`pyarrow.decimal128`.

    ``type`` is not an ``__init__`` argument; it is derived from
    ``precision`` and ``scale`` after initialization.
    """

    type: Optional[pd.ArrowDtype] = dataclasses.field(
        default=None, init=False
    )
    precision: int = 28
    scale: int = 0

    def __post_init__(self) -> None:
        """Build the pandas ``ArrowDtype`` from ``precision``/``scale``."""
        arrow_type = pyarrow.decimal128(self.precision, self.scale)
        # ``object.__setattr__`` bypasses regular attribute assignment;
        # presumably required because ``immutable`` freezes instances.
        object.__setattr__(self, "type", pd.ArrowDtype(arrow_type))

    @classmethod
    def from_parametrized_dtype(
        cls,
        pyarrow_dtype: pyarrow.Decimal128Type,
    ):
        """Create an instance from a parametrized pyarrow decimal type.

        :param pyarrow_dtype: pyarrow type providing ``precision`` and
            ``scale``.
        :returns: instance of this class with matching parameters.
        """
        precision = pyarrow_dtype.precision
        scale = pyarrow_dtype.scale
        return cls(precision=precision, scale=scale)  # type: ignore

@Engine.register_dtype(
    equivalents=[pyarrow.timestamp, pyarrow.TimestampType]
)
@immutable(init=True)
class ArrowTimestamp(DataType, dtypes.Timestamp):
    """Timestamp data type backed by :class:`pyarrow.timestamp`.

    ``type`` is not an ``__init__`` argument; it is derived from ``unit``
    and ``tz`` after initialization.
    """

    type: Optional[pd.ArrowDtype] = dataclasses.field(
        default=None, init=False
    )
    unit: Optional[str] = "ns"
    tz: Optional[datetime.tzinfo] = None

    def __post_init__(self):
        """Build the pandas ``ArrowDtype`` from ``unit`` and ``tz``."""
        arrow_type = pyarrow.timestamp(self.unit, self.tz)
        # ``object.__setattr__`` bypasses regular attribute assignment;
        # presumably required because ``immutable`` freezes instances.
        object.__setattr__(self, "type", pd.ArrowDtype(arrow_type))

    @classmethod
    def from_parametrized_dtype(cls, pyarrow_dtype: pyarrow.TimestampType):
        """Create an instance from a parametrized pyarrow timestamp type.

        :param pyarrow_dtype: pyarrow type providing ``unit`` and ``tz``.
        :returns: instance of this class with matching parameters.
        """
        unit = pyarrow_dtype.unit
        tz = pyarrow_dtype.tz
        return cls(unit=unit, tz=tz)  # type: ignore

@Engine.register_dtype(
    equivalents=[pyarrow.dictionary, pyarrow.DictionaryType]
)
@immutable(init=True)
class ArrowDictionary(DataType, dtypes.Category):
    """Dictionary-encoded data type backed by :class:`pyarrow.dictionary`.

    ``type`` is not an ``__init__`` argument; it is derived from
    ``index_type``, ``value_type`` and ``ordered`` after initialization.
    """

    type: Optional[pd.ArrowDtype] = dataclasses.field(
        default=None, init=False
    )
    index_type: Optional[pyarrow.DataType] = pyarrow.int64()
    value_type: Optional[pyarrow.DataType] = pyarrow.int64()
    ordered: bool = False

    def __post_init__(self):
        """Build the pandas ``ArrowDtype`` from the dictionary parameters."""
        arrow_type = pyarrow.dictionary(
            self.index_type,
            self.value_type,
            self.ordered,
        )
        # ``object.__setattr__`` bypasses regular attribute assignment;
        # presumably required because ``immutable`` freezes instances.
        object.__setattr__(self, "type", pd.ArrowDtype(arrow_type))

    @classmethod
    def from_parametrized_dtype(
        cls, pyarrow_dtype: pyarrow.DictionaryType
    ):
        """Create an instance from a parametrized pyarrow dictionary type.

        :param pyarrow_dtype: pyarrow type providing ``index_type``,
            ``value_type`` and ``ordered``.
        :returns: instance of this class with matching parameters.
        """
        index_type = pyarrow_dtype.index_type
        value_type = pyarrow_dtype.value_type
        ordered = pyarrow_dtype.ordered
        return cls(
            index_type=index_type,  # type: ignore
            value_type=value_type,  # type: ignore
            ordered=ordered,  # type: ignore
        )
14 changes: 13 additions & 1 deletion tests/core/test_pandas_engine.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Test pandas engine."""

from datetime import date
from typing import Any, Set

import hypothesis
import hypothesis.extra.pandas as pd_st
Expand All @@ -14,9 +15,20 @@
from pandera.engines import pandas_engine
from pandera.errors import ParserError

# Registered dtype classes left out of the parametrized data type test.
UNSUPPORTED_DTYPE_CLS: Set[Any] = set()

# `string[pyarrow]` gets parsed to type `string` by pandas
_arrow_dtypes_available = (
    pandas_engine.PYARROW_INSTALLED and pandas_engine.PANDAS_2_0_0_PLUS
)
if _arrow_dtypes_available:
    UNSUPPORTED_DTYPE_CLS.add(pandas_engine.ArrowString)


@pytest.mark.parametrize(
"data_type", list(pandas_engine.Engine.get_registered_dtypes())
"data_type",
[
data_type
for data_type in pandas_engine.Engine.get_registered_dtypes()
if data_type not in UNSUPPORTED_DTYPE_CLS
],
)
def test_pandas_data_type(data_type):
"""Test numpy engine DataType base class."""
Expand Down
22 changes: 22 additions & 0 deletions tests/strategies/test_strategies.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,28 @@
pandas_engine.PythonNamedTuple,
]
)

# The pyarrow-backed dtypes are only defined when pyarrow is installed and
# pandas is 2.0.0 or newer; add them to the unsupported set in that case.
if pandas_engine.PYARROW_INSTALLED and pandas_engine.PANDAS_2_0_0_PLUS:
    _ARROW_DTYPE_CLS = (
        pandas_engine.ArrowBool,
        pandas_engine.ArrowDecimal128,
        pandas_engine.ArrowDictionary,
        pandas_engine.ArrowFloat32,
        pandas_engine.ArrowFloat64,
        pandas_engine.ArrowInt8,
        pandas_engine.ArrowInt16,
        pandas_engine.ArrowInt32,
        pandas_engine.ArrowInt64,
        pandas_engine.ArrowString,
        pandas_engine.ArrowTimestamp,
        pandas_engine.ArrowUInt8,
        pandas_engine.ArrowUInt16,
        pandas_engine.ArrowUInt32,
        pandas_engine.ArrowUInt64,
    )
    UNSUPPORTED_DTYPE_CLS.update(_ARROW_DTYPE_CLS)

SUPPORTED_DTYPES = set()
for data_type in pandas_engine.Engine.get_registered_dtypes():
if (
Expand Down
Loading