Skip to content

Commit

Permalink
fix(rust, python): determine supertype of datetimes with timezones an… (
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Oct 17, 2022
1 parent 07749ec commit 95c8a77
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 2 deletions.
6 changes: 5 additions & 1 deletion polars/polars-core/src/utils/supertype.rs
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,11 @@ pub fn get_supertype(l: &DataType, r: &DataType) -> Option<DataType> {
// None and Some("<tz>") timezones
// we cast from more precision to higher precision as that always fits with occasional loss of precision
#[cfg(feature = "dtype-datetime")]
(Datetime(tu_l, tz_l), Datetime(tu_r, tz_r)) if tz_l.is_none() && tz_r.is_some() => {
(Datetime(tu_l, tz_l), Datetime(tu_r, tz_r)) if
// both are none
tz_l.is_none() && tz_r.is_some()
// both have the same time zone
|| (tz_l.is_some() && (tz_l == tz_r)) => {
let tu = get_time_units(tu_l, tu_r);
Some(Datetime(tu, tz_r.clone()))
}
Expand Down
24 changes: 23 additions & 1 deletion py-polars/tests/unit/test_queries.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from __future__ import annotations

from datetime import datetime
from datetime import datetime, timedelta

import numpy as np
import pandas as pd

import polars as pl

Expand Down Expand Up @@ -392,3 +393,24 @@ def test_none_comparison_4773() -> None:
).filter(pl.col("x") != pl.col("y"))
assert df.shape == (3, 2)
assert df.rows() == [(0, 1), (1, 2), (2, 3)]


def test_datetime_supertype_5236() -> None:
df = pd.DataFrame(
{
"StartDateTime": [
pd.Timestamp(datetime.utcnow(), tz="UTC"),
pd.Timestamp(datetime.utcnow(), tz="UTC"),
],
"EndDateTime": [
pd.Timestamp(datetime.utcnow(), tz="UTC"),
pd.Timestamp(datetime.utcnow(), tz="UTC"),
],
}
)
out = pl.from_pandas(df).filter(
pl.col("StartDateTime")
< (pl.col("EndDateTime").dt.truncate("1d").max() - timedelta(days=1))
)
assert out.shape == (0, 2)
assert out.dtypes == [pl.Datetime("ns", "UTC")] * 2

0 comments on commit 95c8a77

Please sign in to comment.