Skip to content

Commit

Permalink
fix(rust, python): fix incorrect duration dtype (#5226)
Browse files (browse the repository at this point in the history)
  • Loading branch information
ritchie46 committed Oct 17, 2022
1 parent 31ecab0 commit f32b44a
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 10 deletions.
2 changes: 1 addition & 1 deletion polars/polars-lazy/polars-plan/src/dsl/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -579,7 +579,7 @@ pub fn duration(args: DurationArgs) -> Expr {
args.weeks.unwrap_or_else(|| lit(0i64)),
],
function,
output_type: GetOutput::from_type(DataType::Datetime(TimeUnit::Microseconds, None)),
output_type: GetOutput::from_type(DataType::Duration(TimeUnit::Nanoseconds)),
options: FunctionOptions {
collect_groups: ApplyOptions::ApplyFlat,
input_wildcard_expansion: true,
Expand Down
15 changes: 10 additions & 5 deletions py-polars/polars/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ def _to_python_datetime(
dt += timedelta(seconds=value * 3600 * 24)
return dt.date()
elif dtype == Datetime:
if tz is None or tz == "" or tz == "UTC":
if tz is None or tz == "":
if tu == "ns":
# nanoseconds to seconds
dt = EPOCH + timedelta(microseconds=value / 1000)
Expand All @@ -225,18 +225,23 @@ def _to_python_datetime(
else:
raise ValueError(f"tu must be one of {{'ns', 'us', 'ms'}}, got {tu}")
else:
tzinfo = zoneinfo.ZoneInfo("UTC")
if tu == "ns":
# nanoseconds to seconds
dt = datetime.fromtimestamp(0) + timedelta(microseconds=value / 1000)
dt = datetime.fromtimestamp(0, tz=tzinfo) + timedelta(
microseconds=value / 1000
)
elif tu == "us":
dt = datetime.fromtimestamp(0) + timedelta(microseconds=value)
dt = datetime.fromtimestamp(0, tz=tzinfo) + timedelta(
microseconds=value
)
elif tu == "ms":
# milliseconds to seconds
dt = datetime.fromtimestamp(value / 1000)
dt = datetime.fromtimestamp(value / 1000, tz=tzinfo)
else:
raise ValueError(f"tu must be one of {{'ns', 'us', 'ms'}}, got {tu}")

return _localize(dt, tz)

return dt
else:
raise NotImplementedError # pragma: no cover
Expand Down
3 changes: 2 additions & 1 deletion py-polars/src/conversion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -469,11 +469,12 @@ impl ToPyObject for Wrap<&DatetimeChunked> {
let py_date_dtype = pl.getattr("Datetime").unwrap();

let tu = Wrap(self.0.time_unit()).to_object(py);
let tz = self.0.time_zone().to_object(py);

let iter = self
.0
.into_iter()
.map(|opt_v| opt_v.map(|v| convert.call1((v, py_date_dtype, &tu)).unwrap()));
.map(|opt_v| opt_v.map(|v| convert.call1((v, py_date_dtype, &tu, &tz)).unwrap()));
PyList::new(py, iter).into_py(py)
}
}
Expand Down
27 changes: 24 additions & 3 deletions py-polars/tests/unit/test_datelike.py
Original file line number Diff line number Diff line change
Expand Up @@ -701,7 +701,8 @@ def test_read_utc_times_parquet() -> None:
df.to_parquet(f)
f.seek(0)
df_in = pl.read_parquet(f)
assert df_in["Timestamp"][0] == datetime(2022, 1, 1, 0, 0)
tz = zoneinfo.ZoneInfo("UTC")
assert df_in["Timestamp"][0] == datetime(2022, 1, 1, 0, 0, tzinfo=tz)


def test_epoch() -> None:
Expand Down Expand Up @@ -1607,14 +1608,16 @@ def test_invalid_date_parsing_4898() -> None:


def test_cast_timezone() -> None:
utc = zoneinfo.ZoneInfo("UTC")
ny = zoneinfo.ZoneInfo("America/New_York")
assert pl.DataFrame({"a": [datetime(2022, 9, 25, 14)]}).with_column(
pl.col("a")
.dt.with_time_zone("America/New_York")
.dt.cast_time_zone("UTC")
.alias("b")
).to_dict(False) == {
"a": [datetime(2022, 9, 25, 14, 0)],
"b": [datetime(2022, 9, 25, 18, 0)],
"b": [datetime(2022, 9, 25, 18, 0, tzinfo=utc)],
}
assert pl.DataFrame({"a": [datetime(2022, 9, 25, 18)]}).with_column(
pl.col("a")
Expand All @@ -1623,7 +1626,7 @@ def test_cast_timezone() -> None:
.alias("b")
).to_dict(False) == {
"a": [datetime(2022, 9, 25, 18, 0)],
"b": [datetime(2022, 9, 25, 14, 0)],
"b": [datetime(2022, 9, 25, 10, 0, tzinfo=ny)],
}


Expand All @@ -1633,3 +1636,21 @@ def test_tz_aware_get_idx_5010() -> None:
)
a = pa.array([when]).cast(pa.timestamp("s", tz="Asia/Shanghai"))
assert int(pl.from_arrow(a)[0].timestamp()) == when # type: ignore[union-attr]


def test_tz_datetime_duration_arithm_5221() -> None:
    """Check that a UTC-typed Datetime column converts back to aware datetimes.

    Two offset-aware datetimes (``+00:00``) are loaded into a frame whose
    ``run_datetime`` column is declared as ``pl.Datetime(time_zone="UTC")``.
    Converting the frame with ``to_dict(False)`` must yield the same instants
    carrying a ``ZoneInfo("UTC")`` tzinfo.

    NOTE(review): the ``5221`` suffix presumably refers to the upstream issue
    number -- confirm against the tracker.
    """
    # Source values: ISO-8601 strings with an explicit +00:00 offset.
    iso_stamps = ("2022-01-01T00:00:00+00:00", "2022-01-02T00:00:00+00:00")
    aware_inputs = [datetime.fromisoformat(stamp) for stamp in iso_stamps]

    frame = pl.DataFrame(
        data={"run_datetime": aware_inputs},
        columns=[("run_datetime", pl.Datetime(time_zone="UTC"))],
    )

    # Expected values are built independently of the inputs, using zoneinfo.
    utc_zone = zoneinfo.ZoneInfo("UTC")
    expected = {
        "run_datetime": [
            datetime(2022, 1, day, 0, 0, tzinfo=utc_zone) for day in (1, 2)
        ]
    }
    assert frame.to_dict(False) == expected

0 comments on commit f32b44a

Please sign in to comment.