Skip to content

Commit

Permalink
fix[rust]: handle NaNs in argsort (#4497)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Aug 19, 2022
1 parent 90f1feb commit 5dc495c
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 8 deletions.
6 changes: 3 additions & 3 deletions polars/polars-arrow/src/kernels/rolling/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ type Len = usize;

pub fn compare_fn_nan_min<T>(a: &T, b: &T) -> Ordering
where
T: PartialOrd + IsFloat + NativeType,
T: PartialOrd + IsFloat,
{
if T::is_float() {
match (a.is_nan(), b.is_nan()) {
Expand All @@ -41,9 +41,9 @@ where
}
}

fn compare_fn_nan_max<T>(a: &T, b: &T) -> Ordering
pub fn compare_fn_nan_max<T>(a: &T, b: &T) -> Ordering
where
T: PartialOrd + IsFloat + NativeType,
T: PartialOrd + IsFloat,
{
if T::is_float() {
match (a.is_nan(), b.is_nan()) {
Expand Down
7 changes: 4 additions & 3 deletions polars/polars-core/src/chunked_array/ops/sort/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ use std::iter::FromIterator;
use arrow::{bitmap::MutableBitmap, buffer::Buffer};
use num::Float;
use polars_arrow::array::default_arrays::FromDataUtf8;
use polars_arrow::kernels::rolling::compare_fn_nan_max;
use polars_arrow::prelude::{FromData, ValueSize};
use polars_arrow::trusted_len::PushUnchecked;
use rayon::prelude::*;
Expand Down Expand Up @@ -82,14 +83,14 @@ fn sort_branch<T, Fd, Fr>(
#[cfg(feature = "private")]
pub fn argsort_no_nulls<Idx, T>(slice: &mut [(Idx, T)], reverse: bool)
where
T: PartialOrd + Send,
T: PartialOrd + Send + IsFloat,
Idx: PartialOrd + Send,
{
argsort_branch(
slice,
reverse,
|(_, a), (_, b)| a.partial_cmp(b).unwrap(),
|(_, a), (_, b)| b.partial_cmp(a).unwrap(),
|(_, a), (_, b)| compare_fn_nan_max(a, b),
|(_, a), (_, b)| compare_fn_nan_max(b, a),
);
}

Expand Down
4 changes: 2 additions & 2 deletions polars/polars-ops/src/series/ops/search_sorted.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::cmp::Ordering;

use polars_arrow::kernels::rolling::compare_fn_nan_min;
use polars_arrow::kernels::rolling::compare_fn_nan_max;
use polars_arrow::prelude::*;
use polars_core::downcast_as_macro_arg_physical;
use polars_core::export::num::NumCast;
Expand All @@ -27,7 +27,7 @@ where
// - `mid < size`: `mid` is limited by `[left; right)` bound.
let cmp = match unsafe { taker.get_unchecked(mid as usize) } {
None => Ordering::Less,
Some(value) => compare_fn_nan_min(&value, &search_value),
Some(value) => compare_fn_nan_max(&value, &search_value),
};

// The reason why we use if/else control flow rather than match
Expand Down
17 changes: 17 additions & 0 deletions py-polars/tests/test_sort.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,3 +209,20 @@ def test_sorted_fast_paths() -> None:
assert rev.to_list() == [None, 3, 2, 1]
assert rev.sort(reverse=True).to_list() == [None, 3, 2, 1]
assert rev.sort().to_list() == [None, 1, 2, 3]


def test_argsort_rank_nans() -> None:
assert (
pl.DataFrame(
{
"val": [1.0, float("NaN")],
}
)
.with_columns(
[
pl.col("val").rank().alias("rank"),
pl.col("val").argsort().alias("argsort"),
]
)
.select(["rank", "argsort"])
).to_dict(False) == {"rank": [1.0, 2.0], "argsort": [0, 1]}

0 comments on commit 5dc495c

Please sign in to comment.