From c1e7c06faa8fbb7f723b5a5c1f61a6c03ef73356 Mon Sep 17 00:00:00 2001 From: Niels Bantilan Date: Fri, 19 Apr 2024 11:00:59 -0400 Subject: [PATCH] implement timezone agnostic polars_engine.DateTime type (#1589) Signed-off-by: cosmicBboy --- pandera/engines/polars_engine.py | 43 +++++++++++++++++++++++++-- tests/polars/test_polars_container.py | 36 ++++++++++++++++++++++ tests/polars/test_polars_dtypes.py | 40 +++++++++++++++++++++++++ 3 files changed, 116 insertions(+), 3 deletions(-) diff --git a/pandera/engines/polars_engine.py b/pandera/engines/polars_engine.py index 07d30eac5..f73954427 100644 --- a/pandera/engines/polars_engine.py +++ b/pandera/engines/polars_engine.py @@ -5,7 +5,16 @@ import decimal import inspect import warnings -from typing import Any, Union, Optional, Iterable, Literal, Sequence, Tuple +from typing import ( + Any, + Union, + Optional, + Iterable, + Literal, + Sequence, + Tuple, + Type, +) import polars as pl @@ -416,16 +425,26 @@ class Date(DataType, dtypes.Date): class DateTime(DataType, dtypes.DateTime): """Polars datetime data type.""" - type = pl.Datetime + type: Type[pl.Datetime] = pl.Datetime + time_zone_agnostic: bool = False def __init__( # pylint:disable=super-init-not-called self, time_zone: Optional[str] = None, time_unit: Optional[str] = None, + time_zone_agnostic: bool = False, ) -> None: + + _kwargs = {} + if time_unit is not None: + # avoid deprecated warning when initializing pl.Datetime: + # passing time_unit=None is deprecated. + _kwargs["time_unit"] = time_unit + object.__setattr__( - self, "type", pl.Datetime(time_zone=time_zone, time_unit=time_unit) + self, "type", pl.Datetime(time_zone=time_zone, **_kwargs) ) + object.__setattr__(self, "time_zone_agnostic", time_zone_agnostic) @classmethod def from_parametrized_dtype(cls, polars_dtype: pl.Datetime): @@ -435,6 +454,24 @@ def from_parametrized_dtype(cls, polars_dtype: pl.Datetime): time_zone=polars_dtype.time_zone, time_unit=polars_dtype.time_unit ) + def check( + self, + pandera_dtype: dtypes.DataType, + data_container: Optional[PolarsDataContainer] = None, + ) -> Union[bool, Iterable[bool]]: + try: + pandera_dtype = Engine.dtype(pandera_dtype) + except TypeError: + return False + + if self.time_zone_agnostic: + return ( + isinstance(pandera_dtype.type, pl.Datetime) + and pandera_dtype.type.time_unit == self.type.time_unit + ) + + return self.type == pandera_dtype.type and super().check(pandera_dtype) + @Engine.register_dtype( equivalents=[ diff --git a/tests/polars/test_polars_container.py b/tests/polars/test_polars_container.py index dbcb4b2f3..afa9623d4 100644 --- a/tests/polars/test_polars_container.py +++ b/tests/polars/test_polars_container.py @@ -11,9 +11,14 @@ import polars as pl import pytest +from hypothesis import given +from hypothesis import strategies as st +from polars.testing.parametric import dataframes, column + import pandera as pa from pandera import Check as C from pandera.api.polars.types import PolarsData +from pandera.engines import polars_engine as pe from pandera.polars import Column, DataFrameSchema, DataFrameModel @@ -528,3 +533,34 @@ class Config: lf_with_nested_types, lazy=True ) assert validated_lf.collect().equals(validated_lf.collect()) + + +@pytest.mark.parametrize( + "time_zone", + [ + None, + "UTC", + "GMT", + "EST", + ], +) +@given(st.data()) +def test_dataframe_schema_with_tz_agnostic_dates(time_zone, data): + strategy = dataframes( + column("datetime_col", dtype=pl.Datetime()), + lazy=True, + size=10, + ) + lf = data.draw(strategy) + lf = lf.cast({"datetime_col": pl.Datetime(time_zone=time_zone)}) + schema_tz_agnostic = DataFrameSchema( + {"datetime_col": Column(pe.DateTime(time_zone_agnostic=True))} + ) + schema_tz_agnostic.validate(lf) + + schema_tz_sensitive = DataFrameSchema( + {"datetime_col": Column(pe.DateTime(time_zone_agnostic=False))} + ) + if time_zone: + with pytest.raises(pa.errors.SchemaError): + schema_tz_sensitive.validate(lf) diff --git a/tests/polars/test_polars_dtypes.py b/tests/polars/test_polars_dtypes.py index cbb451af7..7d0de216a 100644 --- a/tests/polars/test_polars_dtypes.py +++ b/tests/polars/test_polars_dtypes.py @@ -1,4 +1,6 @@ """Polars dtype tests.""" + +import datetime import decimal from decimal import Decimal from typing import Union, Tuple, Sequence @@ -403,3 +405,41 @@ def test_polars_nested_dtypes_try_coercion( pe.Engine.dtype(noncoercible_dtype).try_coerce(PolarsData(data)) except pandera.errors.ParserError as exc: assert exc.failure_cases.equals(data.collect()) + + +@pytest.mark.parametrize( + "dtype", + [ + "datetime", + datetime.datetime, + pl.Datetime, + pl.Datetime(), + pl.Datetime(time_unit="ns"), + pl.Datetime(time_unit="us"), + pl.Datetime(time_unit="ms"), + pl.Datetime(time_zone="UTC"), + ], +) +def test_datetime_time_zone_agnostic(dtype): + + tz_agnostic = pe.DateTime(time_zone_agnostic=True) + dtype = pe.Engine.dtype(dtype) + + if tz_agnostic.type.time_unit == getattr(dtype.type, "time_unit", "us"): + # timezone agnostic pandera dtype should pass regardless of timezone + assert tz_agnostic.check(dtype) + else: + # but fail if the time units don't match + assert not tz_agnostic.check(dtype) + + tz_sensitive = pe.DateTime() + if getattr(dtype.type, "time_zone", None) is not None: + assert not tz_sensitive.check(dtype) + + tz_sensitive_utc = pe.DateTime(time_zone="UTC") + if getattr( + dtype.type, "time_zone", None + ) is None and tz_sensitive_utc.type.time_zone != getattr( + dtype.type, "time_zone", None + ): + assert not tz_sensitive_utc.check(dtype)