Skip to content

Commit

Permalink
fix duration filters with different time units (#3179)
Browse files Browse the repository at this point in the history
  • Loading branch information
marcvanheerden committed Apr 19, 2022
1 parent 6220dcd commit 3ba7106
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 1 deletion.
16 changes: 15 additions & 1 deletion polars/polars-core/src/series/arithmetic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -345,7 +345,8 @@ pub(crate) fn coerce_lhs_rhs<'a>(
Ok((left, right))
}

// Handle (Date | Datetime) +/- (Duration) | (Duration) +/- (Date | Datetime)
// Handle (Date | Datetime) +/- (Duration) | (Duration) +/- (Date | Datetime) | (Duration) +/-
// (Duration)
// Time arithmetic is only implemented on the date / datetime so ensure that's on left

fn coerce_time_units<'a>(
Expand All @@ -365,6 +366,19 @@ fn coerce_time_units<'a>(
Cow::Owned(rhs.cast(&DataType::Duration(units))?)
};
Ok((left, right))
} else if let (DataType::Duration(lu), DataType::Duration(ru)) = (lhs.dtype(), rhs.dtype()) {
let units = get_time_units(lu, ru);
let left = if *lu == units {
Cow::Borrowed(lhs)
} else {
Cow::Owned(lhs.cast(&DataType::Duration(units))?)
};
let right = if *ru == units {
Cow::Borrowed(rhs)
} else {
Cow::Owned(rhs.cast(&DataType::Duration(units))?)
};
Ok((left, right))
} else if let (DataType::Date, DataType::Duration(units)) = (lhs.dtype(), rhs.dtype()) {
let left = Cow::Owned(lhs.cast(&DataType::Datetime(*units, None))?);
Ok((left, Cow::Borrowed(rhs)))
Expand Down
12 changes: 12 additions & 0 deletions py-polars/tests/test_datelike.py
Original file line number Diff line number Diff line change
Expand Up @@ -815,3 +815,15 @@ def test_timelike_init() -> None:
for ts in [durations, dates, datetimes]:
s = pl.Series(ts)
assert s.to_list() == ts


def test_duration_filter() -> None:
    """Check filtering a computed duration column against a ``timedelta``.

    ``time_passed`` is ``end_date - start_date`` for three rows, giving spans
    of 6, 50 and 365 days. The frame is then filtered on both sides of a
    30-day ``timedelta`` threshold and the surviving row counts are checked.
    """
    starts = [date(2022, 1, 1)] * 3
    ends = [date(2022, 1, 7), date(2022, 2, 20), date(2023, 1, 1)]
    elapsed = (pl.col("end_date") - pl.col("start_date")).alias("time_passed")

    frame = pl.DataFrame({"start_date": starts, "end_date": ends})
    date_df = frame.with_column(elapsed)

    # Only the 6-day span falls strictly below the 30-day threshold.
    shorter = date_df.filter(pl.col("time_passed") < timedelta(days=30))
    assert shorter.shape[0] == 1

    # The 50-day and 365-day spans are at or above the threshold.
    longer = date_df.filter(pl.col("time_passed") >= timedelta(days=30))
    assert longer.shape[0] == 2

0 comments on commit 3ba7106

Please sign in to comment.