Skip to content

Commit

Permalink
Convert date/datetime in lit construction
Browse files Browse the repository at this point in the history
  • Loading branch information
mcrumiller committed May 3, 2024
1 parent a87a8f3 commit 849292c
Show file tree
Hide file tree
Showing 2 changed files with 115 additions and 10 deletions.
47 changes: 37 additions & 10 deletions py-polars/polars/functions/lit.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,25 +77,43 @@ def lit(
time_unit: TimeUnit

if isinstance(value, datetime):
if dtype == Date:
dt_int = date_to_int(value.date())
return lit(dt_int).cast(Date)

# parse time unit
if dtype is not None and (tu := getattr(dtype, "time_unit", "us")) is not None:
time_unit = tu # type: ignore[assignment]
else:
time_unit = "us"

time_zone: str | None = getattr(dtype, "time_zone", None)
if (tzinfo := value.tzinfo) is not None:
tzinfo_str = str(tzinfo)
if time_zone is not None and time_zone != tzinfo_str:
msg = f"time zone of dtype ({time_zone!r}) differs from time zone of value ({tzinfo!r})"
# parse time zone
dtype_tz = getattr(dtype, "time_zone", None)
value_tz = value.tzinfo
if value_tz is None:
tz = dtype_tz
else:
if dtype_tz is None:
# value has time zone, but dtype does not: keep value time zone
tz = str(value_tz)
elif str(value_tz) == dtype_tz:
# dtype and value both have same time zone
tz = str(value_tz)
else:
# value has time zone that differs from dtype time zone
msg = (
f"time zone of dtype ({dtype_tz!r}) differs from time zone of "
f"value ({value_tz!r})"
)
raise TypeError(msg)
time_zone = tzinfo_str

dt_utc = value.replace(tzinfo=timezone.utc)
dt_int = datetime_to_int(dt_utc, time_unit)
expr = lit(dt_int).cast(Datetime(time_unit))
if time_zone is not None:
if tz is not None:
print(f"tz is {tz}")
expr = expr.dt.replace_time_zone(
time_zone, ambiguous="earliest" if value.fold == 0 else "latest"
tz, ambiguous="earliest" if value.fold == 0 else "latest"
)
return expr

Expand All @@ -113,8 +131,17 @@ def lit(
return lit(time_int).cast(Time)

elif isinstance(value, date):
date_int = date_to_int(value)
return lit(date_int).cast(Date)
if dtype == Datetime:
time_unit = getattr(dtype, "time_unit", "us") or "us"
dt_utc = datetime(value.year, value.month, value.day)
dt_int = datetime_to_int(dt_utc, time_unit)
expr = lit(dt_int).cast(Datetime(time_unit))
if (time_zone := getattr(dtype, "time_zone", None)) is not None:
expr = expr.dt.replace_time_zone(str(time_zone))
return expr
else:
date_int = date_to_int(value)
return lit(date_int).cast(Date)

elif isinstance(value, pl.Series):
value = value._s
Expand Down
78 changes: 78 additions & 0 deletions py-polars/tests/unit/namespaces/test_datetime.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

from collections import OrderedDict
from datetime import date, datetime, time, timedelta
from typing import TYPE_CHECKING

Expand All @@ -13,6 +14,7 @@
if TYPE_CHECKING:
from zoneinfo import ZoneInfo

from polars.datatypes import PolarsDataType
from polars.type_aliases import TemporalLiteral, TimeUnit
else:
from polars._utils.convert import string_to_zoneinfo as ZoneInfo
Expand Down Expand Up @@ -1310,3 +1312,79 @@ def test_agg_median_expr() -> None:
)

assert_frame_equal(df.select(pl.all().median()), expected)


@pytest.mark.parametrize(
"dtype",
[
pl.Date,
pl.Datetime("ms"),
pl.Datetime("ms", "EST"),
pl.Datetime("us"),
pl.Datetime("us", "EST"),
pl.Datetime("ns"),
pl.Datetime("ns", "EST"),
],
)
@pytest.mark.parametrize(
"value",
[
date(1677, 9, 22),
date(1970, 1, 1),
date(2024, 2, 29),
date(2262, 4, 11),
],
)
def test_literal_from_date(
value: date,
dtype: PolarsDataType,
) -> None:
out = pl.select(pl.lit(value, dtype=dtype))
assert out.schema == OrderedDict({"literal": dtype})
if dtype == pl.Datetime:
tz = ZoneInfo(dtype.time_zone) if dtype.time_zone is not None else None # type: ignore[union-attr]
value = datetime(value.year, value.month, value.day, tzinfo=tz)
assert out.item() == value


@pytest.mark.parametrize(
"dtype",
[
pl.Date,
pl.Datetime("ms"),
pl.Datetime("ms", "EST"),
pl.Datetime("us"),
pl.Datetime("us", "EST"),
pl.Datetime("ns"),
pl.Datetime("ns", "EST"),
],
)
@pytest.mark.parametrize(
"value",
[
datetime(1677, 9, 22),
datetime(1677, 9, 22, tzinfo=ZoneInfo("EST")),
datetime(1970, 1, 1),
datetime(1970, 1, 1, tzinfo=ZoneInfo("EST")),
datetime(2024, 2, 29),
datetime(2024, 2, 29, tzinfo=ZoneInfo("EST")),
datetime(2262, 4, 11),
datetime(2262, 4, 11, tzinfo=ZoneInfo("EST")),
],
)
def test_literal_from_datetime(
value: datetime,
dtype: pl.Date | pl.Datetime,
) -> None:
out = pl.select(pl.lit(value, dtype=dtype))
if dtype == pl.Date:
value = value.date() # type: ignore[assignment]
elif dtype.time_zone is None and value.tzinfo is not None: # type: ignore[union-attr]
# update the dtype with the supplied time zone in the value
dtype = pl.Datetime(dtype.time_unit, str(value.tzinfo)) # type: ignore[arg-type, union-attr]
elif dtype.time_zone is not None and value.tzinfo is None: # type: ignore[union-attr]
# cast from dt without tz to dtype with tz
value = value.replace(tzinfo=ZoneInfo(dtype.time_zone)) # type: ignore[union-attr]

assert out.schema == OrderedDict({"literal": dtype})
assert out.item() == value

0 comments on commit 849292c

Please sign in to comment.