Skip to content

Commit

Permalink
fix[rust, python]: make cut work on integer types (#4945)
Browse files (browse the repository at this point in the history)
  • Loading branch information
ritchie46 committed Sep 23, 2022
1 parent a3e772b commit bec22c4
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 9 deletions.
19 changes: 11 additions & 8 deletions polars/polars-algo/src/algo.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,14 +56,17 @@ pub fn cut(
.with_columns([col(category_str).cast(DataType::Categorical(None))])
.collect()?;

s.sort(false).into_frame().join_asof(
&cuts,
var_name,
breakpoint_str,
AsofStrategy::Forward,
None,
None,
)
s.cast(&DataType::Float64)?
.sort(false)
.into_frame()
.join_asof(
&cuts,
var_name,
breakpoint_str,
AsofStrategy::Forward,
None,
None,
)
}

#[test]
Expand Down
3 changes: 2 additions & 1 deletion py-polars/polars/internals/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,7 +369,8 @@ def cut(
cuts_df = cuts_df.with_column(pli.col(category_label).cast(Categorical))

result = (
s.sort()
s.cast(Float64)
.sort()
.to_frame()
.join_asof(
cuts_df,
Expand Down
5 changes: 5 additions & 0 deletions py-polars/tests/unit/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,11 @@ def test_cut() -> None:
],
}

# Regression test for #4939: cut must also work on integer Series.
df = pl.DataFrame({"a": list(range(5))})
ser = df.select("a").to_series()
assert pl.cut(ser, bins=[-1, 1]).shape == (5, 3)


def test_null_handling_correlation() -> None:
df = pl.DataFrame({"a": [1, 2, 3, None, 4], "b": [1, 2, 3, 10, 4]})
Expand Down

0 comments on commit bec22c4

Please sign in to comment.