diff --git a/py-polars/polars/functions/lit.py b/py-polars/polars/functions/lit.py
index b636d5f2544a6..6c2add492465c 100644
--- a/py-polars/polars/functions/lit.py
+++ b/py-polars/polars/functions/lit.py
@@ -77,25 +77,43 @@ def lit(
     time_unit: TimeUnit
 
     if isinstance(value, datetime):
+        # a Date dtype was requested: keep only the calendar date of the value
+        if dtype == Date:
+            dt_int = date_to_int(value.date())
+            return lit(dt_int).cast(Date)
+
+        # parse time unit
         if dtype is not None and (tu := getattr(dtype, "time_unit", "us")) is not None:
             time_unit = tu  # type: ignore[assignment]
         else:
             time_unit = "us"
 
-        time_zone: str | None = getattr(dtype, "time_zone", None)
-        if (tzinfo := value.tzinfo) is not None:
-            tzinfo_str = str(tzinfo)
-            if time_zone is not None and time_zone != tzinfo_str:
-                msg = f"time zone of dtype ({time_zone!r}) differs from time zone of value ({tzinfo!r})"
+        # parse time zone
+        dtype_tz = getattr(dtype, "time_zone", None)
+        value_tz = value.tzinfo
+        if value_tz is None:
+            tz = dtype_tz
+        else:
+            if dtype_tz is None:
+                # value has time zone, but dtype does not: keep value time zone
+                tz = str(value_tz)
+            elif str(value_tz) == dtype_tz:
+                # dtype and value both have same time zone
+                tz = str(value_tz)
+            else:
+                # value has time zone that differs from dtype time zone
+                msg = (
+                    f"time zone of dtype ({dtype_tz!r}) differs from time zone of "
+                    f"value ({value_tz!r})"
+                )
                 raise TypeError(msg)
-            time_zone = tzinfo_str
 
         dt_utc = value.replace(tzinfo=timezone.utc)
         dt_int = datetime_to_int(dt_utc, time_unit)
         expr = lit(dt_int).cast(Datetime(time_unit))
-        if time_zone is not None:
+        if tz is not None:
             expr = expr.dt.replace_time_zone(
-                time_zone, ambiguous="earliest" if value.fold == 0 else "latest"
+                tz, ambiguous="earliest" if value.fold == 0 else "latest"
             )
 
         return expr
@@ -113,8 +131,18 @@ def lit(
         return lit(time_int).cast(Time)
 
     elif isinstance(value, date):
-        date_int = date_to_int(value)
-        return lit(date_int).cast(Date)
+        if dtype == Datetime:
+            time_unit = getattr(dtype, "time_unit", "us") or "us"
+            # a date has no time-of-day, so build the datetime at midnight
+            dt_utc = datetime(value.year, value.month, value.day)
+            dt_int = datetime_to_int(dt_utc, time_unit)
+            expr = lit(dt_int).cast(Datetime(time_unit))
+            if (time_zone := getattr(dtype, "time_zone", None)) is not None:
+                expr = expr.dt.replace_time_zone(str(time_zone))
+            return expr
+        else:
+            date_int = date_to_int(value)
+            return lit(date_int).cast(Date)
 
     elif isinstance(value, pl.Series):
         value = value._s
diff --git a/py-polars/tests/unit/namespaces/test_datetime.py b/py-polars/tests/unit/namespaces/test_datetime.py
index 5307a2e7e129e..43e1f322667f8 100644
--- a/py-polars/tests/unit/namespaces/test_datetime.py
+++ b/py-polars/tests/unit/namespaces/test_datetime.py
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+from collections import OrderedDict
 from datetime import date, datetime, time, timedelta
 from typing import TYPE_CHECKING
 
@@ -13,6 +14,7 @@
 if TYPE_CHECKING:
     from zoneinfo import ZoneInfo
 
+    from polars.datatypes import PolarsDataType
     from polars.type_aliases import TemporalLiteral, TimeUnit
 else:
     from polars._utils.convert import string_to_zoneinfo as ZoneInfo
@@ -1310,3 +1312,81 @@ def test_agg_median_expr() -> None:
     )
 
     assert_frame_equal(df.select(pl.all().median()), expected)
+
+
+@pytest.mark.parametrize(
+    "dtype",
+    [
+        pl.Date,
+        pl.Datetime("ms"),
+        pl.Datetime("ms", "EST"),
+        pl.Datetime("us"),
+        pl.Datetime("us", "EST"),
+        pl.Datetime("ns"),
+        pl.Datetime("ns", "EST"),
+    ],
+)
+@pytest.mark.parametrize(
+    "value",
+    [
+        date(1677, 9, 22),
+        date(1970, 1, 1),
+        date(2024, 2, 29),
+        date(2262, 4, 11),
+    ],
+)
+def test_literal_from_date(
+    value: date,
+    dtype: PolarsDataType,
+) -> None:
+    """Check that a `date` literal honours an explicit Date/Datetime dtype."""
+    out = pl.select(pl.lit(value, dtype=dtype))
+    assert out.schema == OrderedDict({"literal": dtype})
+    if dtype == pl.Datetime:
+        tz = ZoneInfo(dtype.time_zone) if dtype.time_zone is not None else None  # type: ignore[union-attr]
+        value = datetime(value.year, value.month, value.day, tzinfo=tz)
+    assert out.item() == value
+
+
+@pytest.mark.parametrize(
+    "dtype",
+    [
+        pl.Date,
+        pl.Datetime("ms"),
+        pl.Datetime("ms", "EST"),
+        pl.Datetime("us"),
+        pl.Datetime("us", "EST"),
+        pl.Datetime("ns"),
+        pl.Datetime("ns", "EST"),
+    ],
+)
+@pytest.mark.parametrize(
+    "value",
+    [
+        datetime(1677, 9, 22),
+        datetime(1677, 9, 22, tzinfo=ZoneInfo("EST")),
+        datetime(1970, 1, 1),
+        datetime(1970, 1, 1, tzinfo=ZoneInfo("EST")),
+        datetime(2024, 2, 29),
+        datetime(2024, 2, 29, tzinfo=ZoneInfo("EST")),
+        datetime(2262, 4, 11),
+        datetime(2262, 4, 11, tzinfo=ZoneInfo("EST")),
+    ],
+)
+def test_literal_from_datetime(
+    value: datetime,
+    dtype: pl.Date | pl.Datetime,
+) -> None:
+    """Check that a `datetime` literal honours an explicit Date/Datetime dtype."""
+    out = pl.select(pl.lit(value, dtype=dtype))
+    if dtype == pl.Date:
+        value = value.date()  # type: ignore[assignment]
+    elif dtype.time_zone is None and value.tzinfo is not None:  # type: ignore[union-attr]
+        # update the dtype with the supplied time zone in the value
+        dtype = pl.Datetime(dtype.time_unit, str(value.tzinfo))  # type: ignore[arg-type, union-attr]
+    elif dtype.time_zone is not None and value.tzinfo is None:  # type: ignore[union-attr]
+        # cast from dt without tz to dtype with tz
+        value = value.replace(tzinfo=ZoneInfo(dtype.time_zone))  # type: ignore[union-attr]
+
+    assert out.schema == OrderedDict({"literal": dtype})
+    assert out.item() == value