Skip to content

Commit

Permalink
fix: prevent panic with arg_sort_by
Browse files Browse the repository at this point in the history
  • Loading branch information
eitsupi committed Mar 23, 2024
1 parent 911abc3 commit 289d24f
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 9 deletions.
4 changes: 2 additions & 2 deletions crates/polars-lazy/src/tests/queries.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1048,7 +1048,7 @@ fn test_arg_sort_multiple() -> PolarsResult<()> {
let out = df
.clone()
.lazy()
.select([arg_sort_by([col("int"), col("flt")], &[true, false])])
.select([arg_sort_by([col("int"), col("flt")], &[true, false]).unwrap()])
.collect()?;

assert_eq!(
Expand All @@ -1063,7 +1063,7 @@ fn test_arg_sort_multiple() -> PolarsResult<()> {
// check if this runs
let _out = df
.lazy()
.select([arg_sort_by([col("str"), col("flt")], &[true, false])])
.select([arg_sort_by([col("str"), col("flt")], &[true, false]).unwrap()])
.collect()?;
Ok(())
}
Expand Down
12 changes: 7 additions & 5 deletions crates/polars-plan/src/dsl/functions/index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,14 @@ use super::*;
/// until duplicates are found. Once duplicates are found, the next `Series` will
/// be used and so on.
#[cfg(feature = "range")]
pub fn arg_sort_by<E: AsRef<[Expr]>>(by: E, descending: &[bool]) -> Expr {
pub fn arg_sort_by<E: AsRef<[Expr]>>(by: E, descending: &[bool]) -> PolarsResult<Expr> {
let e = &by.as_ref()[0];
let name = expr_output_name(e).unwrap();
int_range(lit(0 as IdxSize), len().cast(IDX_DTYPE), 1, IDX_DTYPE)
.sort_by(by, descending)
.alias(name.as_ref())
let name = expr_output_name(e)?;
Ok(
int_range(lit(0 as IdxSize), len().cast(IDX_DTYPE), 1, IDX_DTYPE)
.sort_by(by, descending)
.alias(name.as_ref()),
)
}

#[cfg(feature = "arg_where")]
Expand Down
5 changes: 3 additions & 2 deletions py-polars/src/functions/lazy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,10 @@ pub fn rolling_cov(
}

#[pyfunction]
pub fn arg_sort_by(by: Vec<PyExpr>, descending: Vec<bool>) -> PyExpr {
pub fn arg_sort_by(by: Vec<PyExpr>, descending: Vec<bool>) -> PyResult<PyExpr> {
let by = by.into_iter().map(|e| e.inner).collect::<Vec<Expr>>();
dsl::arg_sort_by(by, &descending).into()
let e = dsl::arg_sort_by(by, &descending).map_err(PyPolarsErr::from)?;
Ok(e.into())
}

#[pyfunction]
Expand Down

0 comments on commit 289d24f

Please sign in to comment.