Skip to content

Commit

Permalink
fix(rust): Fix lazy cumsum and cumprod result types (#5792)
Browse files Browse the repository at this point in the history
  • Loading branch information
rzhang-at-hrt committed Dec 13, 2022
1 parent 136a186 commit d29fa30
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 1 deletion.
15 changes: 14 additions & 1 deletion polars/polars-lazy/polars-plan/src/dsl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -876,7 +876,18 @@ impl Expr {
pub fn cumsum(self, reverse: bool) -> Self {
self.apply(
move |s: Series| Ok(s.cumsum(reverse)),
GetOutput::same_type(),
GetOutput::map_dtype(|dt| {
use DataType::*;
match dt {
Boolean => UInt32,
Int32 => Int32,
UInt32 => UInt32,
UInt64 => UInt64,
Float32 => Float32,
Float64 => Float64,
_ => Int64,
}
}),
)
.with_fmt("cumsum")
}
Expand All @@ -889,6 +900,8 @@ impl Expr {
GetOutput::map_dtype(|dt| {
use DataType::*;
match dt {
Boolean => Int64,
UInt64 => UInt64,
Float32 => Float32,
Float64 => Float64,
_ => Int64,
Expand Down
25 changes: 25 additions & 0 deletions py-polars/tests/unit/test_lazy.py
Original file line number Diff line number Diff line change
Expand Up @@ -1616,3 +1616,28 @@ def test_from_epoch_seq_input() -> None:
expected = pl.Series([datetime(2006, 5, 17, 15, 34, 4)])
result = pl.from_epoch(seq_input)
assert_series_equal(result, expected)


def test_cumagg_types() -> None:
lf = pl.DataFrame({"a": [1, 2], "b": [True, False], "c": [1.3, 2.4]}).lazy()
cumsum_lf = lf.select(
[pl.col("a").cumsum(), pl.col("b").cumsum(), pl.col("c").cumsum()]
)
assert cumsum_lf.schema["a"] == pl.Int64
assert cumsum_lf.schema["b"] == pl.UInt32
assert cumsum_lf.schema["c"] == pl.Float64
collected_cumsum_lf = cumsum_lf.collect()
assert collected_cumsum_lf.schema == cumsum_lf.schema

cumprod_lf = lf.select(
[
pl.col("a").cast(pl.UInt64).cumprod(),
pl.col("b").cumprod(),
pl.col("c").cumprod(),
]
)
assert cumprod_lf.schema["a"] == pl.UInt64
assert cumprod_lf.schema["b"] == pl.Int64
assert cumprod_lf.schema["c"] == pl.Float64
collected_cumprod_lf = cumprod_lf.collect()
assert collected_cumprod_lf.schema == cumprod_lf.schema

0 comments on commit d29fa30

Please sign in to comment.