Skip to content

Commit

Permalink
check nan in sort by single column (#3742)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Jun 19, 2022
1 parent 6cafa8b commit 8bbff3a
Show file tree
Hide file tree
Showing 5 changed files with 40 additions and 25 deletions.
8 changes: 5 additions & 3 deletions polars/polars-arrow/src/data_types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ pub unsafe trait IsFloat: private::Sealed {
}

#[allow(clippy::wrong_self_convention)]
fn is_nan(self) -> bool
fn is_nan(&self) -> bool
where
Self: Sized,
{
Expand All @@ -25,6 +25,7 @@ unsafe impl IsFloat for u16 {}
unsafe impl IsFloat for u32 {}
unsafe impl IsFloat for u64 {}
unsafe impl IsFloat for &str {}
unsafe impl IsFloat for bool {}
unsafe impl<T: IsFloat> IsFloat for Option<T> {}

mod private {
Expand All @@ -40,6 +41,7 @@ mod private {
impl Sealed for f32 {}
impl Sealed for f64 {}
impl Sealed for &str {}
impl Sealed for bool {}
impl<T: Sealed> Sealed for Option<T> {}
}

Expand All @@ -50,8 +52,8 @@ macro_rules! impl_is_float {
true
}

fn is_nan(self) -> bool {
<$tp>::is_nan(self)
fn is_nan(&self) -> bool {
<$tp>::is_nan(*self)
}
}
};
Expand Down
16 changes: 9 additions & 7 deletions polars/polars-core/src/chunked_array/ops/sort/argsort.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
use super::*;

fn default_order<T: PartialOrd>(a: &(IdxSize, T), b: &(IdxSize, T)) -> Ordering {
a.1.partial_cmp(&b.1).unwrap()
#[inline]
fn default_order<T: PartialOrd + IsFloat>(a: &(IdxSize, T), b: &(IdxSize, T)) -> Ordering {
sort_cmp(&a.1, &b.1)
}

fn reverse_order<T: PartialOrd>(a: &(IdxSize, T), b: &(IdxSize, T)) -> Ordering {
b.1.partial_cmp(&a.1).unwrap()
#[inline]
fn reverse_order<T: PartialOrd + IsFloat>(a: &(IdxSize, T), b: &(IdxSize, T)) -> Ordering {
sort_cmp(&b.1, &a.1)
}

pub(super) fn argsort<I, J, K>(
pub(super) fn argsort<I, J, T>(
name: &str,
iters: I,
options: SortOptions,
Expand All @@ -17,8 +19,8 @@ pub(super) fn argsort<I, J, K>(
) -> IdxCa
where
I: IntoIterator<Item = J>,
J: IntoIterator<Item = Option<K>>,
K: PartialOrd + Send + Sync,
J: IntoIterator<Item = Option<T>>,
T: PartialOrd + Send + Sync + IsFloat,
{
let reverse = options.descending;
let nulls_last = options.nulls_last;
Expand Down
15 changes: 0 additions & 15 deletions polars/polars-core/src/chunked_array/ops/sort/argsort_multiple.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,21 +24,6 @@ pub(crate) fn args_validate<T: PolarsDataType>(
Ok(())
}

fn sort_cmp<T: PartialOrd + IsFloat + Copy>(a: &T, b: &T) -> Ordering {
if T::is_float() {
match (a.is_nan(), b.is_nan()) {
// safety: we checked nans
(false, false) => unsafe { a.partial_cmp(b).unwrap_unchecked() },
(true, true) => Ordering::Equal,
(true, false) => Ordering::Greater,
(false, true) => Ordering::Less,
}
} else {
// no floats, so we can compare unchecked
unsafe { a.partial_cmp(b).unwrap_unchecked() }
}
}

pub(crate) fn argsort_multiple_impl<T: PartialOrd + Send + IsFloat + Copy>(
mut vals: Vec<(IdxSize, T)>,
other: &[Series],
Expand Down
16 changes: 16 additions & 0 deletions polars/polars-core/src/chunked_array/ops/sort/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,22 @@ use std::cmp::Ordering;
use std::hint::unreachable_unchecked;
use std::iter::FromIterator;

#[inline]
fn sort_cmp<T: PartialOrd + IsFloat>(a: &T, b: &T) -> Ordering {
if T::is_float() {
match (a.is_nan(), b.is_nan()) {
// safety: we checked nans
(false, false) => unsafe { a.partial_cmp(b).unwrap_unchecked() },
(true, true) => Ordering::Equal,
(true, false) => Ordering::Greater,
(false, true) => Ordering::Less,
}
} else {
// no floats, so we can compare unchecked
unsafe { a.partial_cmp(b).unwrap_unchecked() }
}
}

/// Reverse sorting when there are no nulls
fn order_reverse<T: Ord>(a: &T, b: &T) -> Ordering {
b.cmp(a)
Expand Down
10 changes: 10 additions & 0 deletions py-polars/tests/test_sort.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,3 +110,13 @@ def test_argsort_window_functions() -> None:
assert (
out["arg_sort"].to_list() == out["argsort_by"].to_list() == [0, 1, 0, 1, 0, 1]
)


def test_sort_nans_3740() -> None:
df = pl.DataFrame(
{
"key": [1, 2, 3, 4, 5],
"val": [0.0, None, float("nan"), float("-inf"), float("inf")],
}
)
assert df.sort("val")["key"].to_list() == [2, 4, 1, 5, 3]

0 comments on commit 8bbff3a

Please sign in to comment.