Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(python): Convert date and datetime in literal construction #16018

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 36 additions & 10 deletions py-polars/polars/functions/lit.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,25 +78,42 @@ def lit(
time_unit: TimeUnit

if isinstance(value, datetime):
if dtype == Date:
dt_int = date_to_int(value.date())
return lit(dt_int).cast(Date)

# parse time unit
if dtype is not None and (tu := getattr(dtype, "time_unit", "us")) is not None:
time_unit = tu # type: ignore[assignment]
else:
time_unit = "us"

time_zone: str | None = getattr(dtype, "time_zone", None)
if (tzinfo := value.tzinfo) is not None:
tzinfo_str = str(tzinfo)
if time_zone is not None and time_zone != tzinfo_str:
msg = f"time zone of dtype ({time_zone!r}) differs from time zone of value ({tzinfo!r})"
# parse time zone
dtype_tz = getattr(dtype, "time_zone", None)
value_tz = value.tzinfo
if value_tz is None:
tz = dtype_tz
else:
if dtype_tz is None:
# value has time zone, but dtype does not: keep value time zone
tz = str(value_tz)
elif str(value_tz) == dtype_tz:
# dtype and value both have same time zone
tz = str(value_tz)
else:
# value has time zone that differs from dtype time zone
msg = (
f"time zone of dtype ({dtype_tz!r}) differs from time zone of "
f"value ({value_tz!r})"
)
raise TypeError(msg)
time_zone = tzinfo_str

dt_utc = value.replace(tzinfo=timezone.utc)
dt_int = datetime_to_int(dt_utc, time_unit)
expr = lit(dt_int).cast(Datetime(time_unit))
if time_zone is not None:
if tz is not None:
mcrumiller marked this conversation as resolved.
Show resolved Hide resolved
expr = expr.dt.replace_time_zone(
time_zone, ambiguous="earliest" if value.fold == 0 else "latest"
tz, ambiguous="earliest" if value.fold == 0 else "latest"
)
return expr

Expand All @@ -114,8 +131,17 @@ def lit(
return lit(time_int).cast(Time)

elif isinstance(value, date):
date_int = date_to_int(value)
return lit(date_int).cast(Date)
if dtype == Datetime:
time_unit = getattr(dtype, "time_unit", "us") or "us"
dt_utc = datetime(value.year, value.month, value.day)
dt_int = datetime_to_int(dt_utc, time_unit)
expr = lit(dt_int).cast(Datetime(time_unit))
if (time_zone := getattr(dtype, "time_zone", None)) is not None:
expr = expr.dt.replace_time_zone(str(time_zone))
return expr
else:
date_int = date_to_int(value)
return lit(date_int).cast(Date)

elif isinstance(value, pl.Series):
value = value._s
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

from collections import OrderedDict
from datetime import date, datetime, time, timedelta
from typing import TYPE_CHECKING

Expand All @@ -15,7 +16,7 @@
if TYPE_CHECKING:
from zoneinfo import ZoneInfo

from polars._typing import TemporalLiteral, TimeUnit
from polars._typing import PolarsDataType, TemporalLiteral, TimeUnit
else:
from polars._utils.convert import string_to_zoneinfo as ZoneInfo

Expand Down Expand Up @@ -1350,3 +1351,79 @@ def test_dt_mean_deprecated() -> None:
with pytest.deprecated_call():
result = s.dt.mean()
assert result == s.mean()


@pytest.mark.parametrize(
"dtype",
[
pl.Date,
pl.Datetime("ms"),
pl.Datetime("ms", "EST"),
pl.Datetime("us"),
pl.Datetime("us", "EST"),
pl.Datetime("ns"),
pl.Datetime("ns", "EST"),
],
)
@pytest.mark.parametrize(
"value",
[
date(1677, 9, 22),
date(1970, 1, 1),
date(2024, 2, 29),
date(2262, 4, 11),
],
)
def test_literal_from_date(
value: date,
dtype: PolarsDataType,
) -> None:
out = pl.select(pl.lit(value, dtype=dtype))
assert out.schema == OrderedDict({"literal": dtype})
if dtype == pl.Datetime:
tz = ZoneInfo(dtype.time_zone) if dtype.time_zone is not None else None # type: ignore[union-attr]
value = datetime(value.year, value.month, value.day, tzinfo=tz)
assert out.item() == value


@pytest.mark.parametrize(
"dtype",
[
pl.Date,
pl.Datetime("ms"),
pl.Datetime("ms", "EST"),
pl.Datetime("us"),
pl.Datetime("us", "EST"),
pl.Datetime("ns"),
pl.Datetime("ns", "EST"),
],
)
@pytest.mark.parametrize(
"value",
[
datetime(1677, 9, 22),
datetime(1677, 9, 22, tzinfo=ZoneInfo("EST")),
datetime(1970, 1, 1),
datetime(1970, 1, 1, tzinfo=ZoneInfo("EST")),
datetime(2024, 2, 29),
datetime(2024, 2, 29, tzinfo=ZoneInfo("EST")),
datetime(2262, 4, 11),
datetime(2262, 4, 11, tzinfo=ZoneInfo("EST")),
],
)
def test_literal_from_datetime(
value: datetime,
dtype: pl.Date | pl.Datetime,
) -> None:
out = pl.select(pl.lit(value, dtype=dtype))
if dtype == pl.Date:
value = value.date() # type: ignore[assignment]
elif dtype.time_zone is None and value.tzinfo is not None: # type: ignore[union-attr]
# update the dtype with the supplied time zone in the value
dtype = pl.Datetime(dtype.time_unit, str(value.tzinfo)) # type: ignore[union-attr]
elif dtype.time_zone is not None and value.tzinfo is None: # type: ignore[union-attr]
# cast from dt without tz to dtype with tz
value = value.replace(tzinfo=ZoneInfo(dtype.time_zone)) # type: ignore[union-attr]

assert out.schema == OrderedDict({"literal": dtype})
assert out.item() == value