diff --git a/docs/source/polars.md b/docs/source/polars.md index 584cb15c..7f832fa0 100644 --- a/docs/source/polars.md +++ b/docs/source/polars.md @@ -532,7 +532,11 @@ from pandera.engines.polars_engine import DateTime class DateTimeModel(pa.DataFrameModel): - created_at: Annotated[DateTime, True] + created_at: Annotated[DateTime, True, "us", None] +``` +. +```{note} +For `Annotated` types, you need to pass in all positional and keyword arguments. ``` ::: diff --git a/pandera/api/polars/model.py b/pandera/api/polars/model.py index ca8b0eb5..25608dff 100644 --- a/pandera/api/polars/model.py +++ b/pandera/api/polars/model.py @@ -1,5 +1,6 @@ """Class-based api for polars models.""" +import inspect from typing import Dict, List, Tuple, Type import pandas as pd @@ -47,10 +48,16 @@ def _build_columns( # pylint:disable=too-many-locals field_name = field.name check_name = getattr(field, "check_name", None) - engine_dtype = None try: engine_dtype = pe.Engine.dtype(annotation.raw_annotation) - dtype = engine_dtype.type + if inspect.isclass(annotation.raw_annotation) and issubclass( + annotation.raw_annotation, pe.DataType + ): + # use the raw annotation as the dtype if it's a native + # pandera polars datatype + dtype = annotation.raw_annotation + else: + dtype = engine_dtype.type except (TypeError, ValueError) as exc: if annotation.metadata: if field.dtype_kwargs: @@ -64,13 +71,13 @@ def _build_columns( # pylint:disable=too-many-locals elif annotation.default_dtype: dtype = annotation.default_dtype else: - dtype = annotation.arg + dtype = annotation.arg # type: ignore if ( annotation.origin is None or isinstance(annotation.origin, pl.datatypes.DataTypeClass) or annotation.origin is Series - or engine_dtype + or dtype ): if check_name is False: raise SchemaInitError( diff --git a/tests/polars/test_polars_model.py b/tests/polars/test_polars_model.py index d420b528..b88481f1 100644 --- a/tests/polars/test_polars_model.py +++ b/tests/polars/test_polars_model.py @@ -3,9 +3,18 @@ import sys from typing import Optional +try: # python 3.9+ + from typing import Annotated # type: ignore +except ImportError: + from typing_extensions import Annotated # type: ignore + import polars as pl import pytest +from hypothesis import given +from hypothesis import strategies as st +from polars.testing.parametric import column, dataframes +import pandera.engines.polars_engine as pe from pandera.errors import SchemaError from pandera.polars import ( Column, @@ -211,3 +220,53 @@ class ModelWithNestedDtypes(DataFrameModel): schema = ModelWithNestedDtypes.to_schema() assert schema_with_list_type == schema + + +@pytest.mark.parametrize( + "time_zone", + [ + None, + "UTC", + "GMT", + "EST", + ], +) +@given(st.data()) +def test_dataframe_schema_with_tz_agnostic_dates(time_zone, data): + strategy = dataframes( + column("datetime_col", dtype=pl.Datetime()), + lazy=True, + size=10, + ) + lf = data.draw(strategy) + lf = lf.cast({"datetime_col": pl.Datetime(time_zone=time_zone)}) + + class ModelTZAgnosticKwargs(DataFrameModel): + datetime_col: pe.DateTime = Field( + dtype_kwargs={"time_zone_agnostic": True} + ) + + class ModelTZSensitiveKwargs(DataFrameModel): + datetime_col: pe.DateTime = Field( + dtype_kwargs={"time_zone_agnostic": False} + ) + + class ModelTZAgnosticAnnotated(DataFrameModel): + datetime_col: Annotated[pe.DateTime, True, "us", None] + + class ModelTZSensitiveAnnotated(DataFrameModel): + datetime_col: Annotated[pe.DateTime, False, "us", None] + + for tz_agnostic_model in ( + ModelTZAgnosticKwargs, + ModelTZAgnosticAnnotated, + ): + tz_agnostic_model.validate(lf) + + for tz_sensitive_model in ( + ModelTZSensitiveKwargs, + ModelTZSensitiveAnnotated, + ): + if time_zone: + with pytest.raises(SchemaError): + tz_sensitive_model.validate(lf)