Skip to content

Commit

Permalink
fix(rust, python): correct output dtype for cummin/cumsum/cummax (#6062)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Jan 5, 2023
1 parent 5bf8e31 commit a609f25
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 16 deletions.
29 changes: 13 additions & 16 deletions polars/polars-lazy/polars-plan/src/dsl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -880,14 +880,18 @@ impl Expr {
move |s: Series| Ok(s.cumsum(reverse)),
GetOutput::map_dtype(|dt| {
use DataType::*;
match dt {
Boolean => UInt32,
Int32 => Int32,
UInt32 => UInt32,
UInt64 => UInt64,
Float32 => Float32,
Float64 => Float64,
_ => Int64,
if dt.is_logical() {
dt.clone()
} else {
match dt {
Boolean => UInt32,
Int32 => Int32,
UInt32 => UInt32,
UInt64 => UInt64,
Float32 => Float32,
Float64 => Float64,
_ => Int64,
}
}
}),
)
Expand Down Expand Up @@ -928,14 +932,7 @@ impl Expr {
pub fn cummax(self, reverse: bool) -> Self {
self.apply(
move |s: Series| Ok(s.cummax(reverse)),
GetOutput::map_dtype(|dt| {
use DataType::*;
match dt {
Float32 => Float32,
Float64 => Float64,
_ => Int64,
}
}),
GetOutput::same_type(),
)
.with_fmt("cummax")
}
Expand Down
55 changes: 55 additions & 0 deletions py-polars/tests/unit/test_datelike.py
Original file line number Diff line number Diff line change
Expand Up @@ -2337,3 +2337,58 @@ def test_tz_aware_day_weekday() -> None:
"tyo_weekday": [1, 4, 7],
"ny_weekday": [7, 3, 6],
}


def test_datetime_cum_agg_schema() -> None:
df = pl.DataFrame(
{
"timestamp": [
datetime(2023, 1, 1),
datetime(2023, 1, 2),
datetime(2023, 1, 3),
]
}
)
# Exactly the same as above but with lazy() and collect() later
assert (
df.lazy()
.with_columns(
[
(pl.col("timestamp").cummin()).alias("cummin"),
(pl.col("timestamp").cummax()).alias("cummax"),
]
)
.with_columns(
[
(pl.col("cummin") + pl.duration(hours=24)).alias("cummin+24"),
(pl.col("cummax") + pl.duration(hours=24)).alias("cummax+24"),
]
)
.collect()
).to_dict(False) == {
"timestamp": [
datetime(2023, 1, 1, 0, 0),
datetime(2023, 1, 2, 0, 0),
datetime(2023, 1, 3, 0, 0),
],
"cummin": [
datetime(2023, 1, 1, 0, 0),
datetime(2023, 1, 1, 0, 0),
datetime(2023, 1, 1, 0, 0),
],
"cummax": [
datetime(2023, 1, 1, 0, 0),
datetime(2023, 1, 2, 0, 0),
datetime(2023, 1, 3, 0, 0),
],
"cummin+24": [
datetime(2023, 1, 2, 0, 0),
datetime(2023, 1, 2, 0, 0),
datetime(2023, 1, 2, 0, 0),
],
"cummax+24": [
datetime(2023, 1, 2, 0, 0),
datetime(2023, 1, 3, 0, 0),
datetime(2023, 1, 4, 0, 0),
],
}

0 comments on commit a609f25

Please sign in to comment.