Skip to content

Commit

Permalink
handle nulls separately in argsort
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Nov 27, 2021
1 parent c8e8210 commit 37d3c54
Showing 1 changed file with 78 additions and 16 deletions.
94 changes: 78 additions & 16 deletions polars/polars-core/src/chunked_array/ops/sort.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use crate::prelude::compare_inner::PartialOrdInner;
use crate::prelude::*;
use crate::utils::{CustomIterTools, NoNull};
use arrow::bitmap::MutableBitmap;
use arrow::{bitmap::MutableBitmap, buffer::Buffer};
use itertools::Itertools;
use polars_arrow::prelude::ValueSize;
use polars_arrow::trusted_len::PushUnchecked;
Expand Down Expand Up @@ -153,14 +153,14 @@ where
if self.has_validity() {
// if the nulls are already last we can clone
if options.nulls_last && self.get(self.len() - 1).is_none() ||
// if the nulls are already fist we can clone
// if the nulls are already first we can clone
self.get(0).is_none()
{
return self.clone();
}
// nulls are not at the right place
// continue w/ sorting
// TODO: we can optimize here and just put the null as the correct place
// TODO: we can optimize here and just put the null at the correct place
} else {
return self.clone();
}
Expand Down Expand Up @@ -248,7 +248,7 @@ where
}

fn argsort(&self, reverse: bool) -> UInt32Chunked {
let ca: NoNull<UInt32Chunked> = if !self.has_validity() {
if !self.has_validity() {
let mut vals = Vec::with_capacity(self.len());
let mut count: u32 = 0;
self.downcast_iter().for_each(|arr| {
Expand All @@ -268,30 +268,64 @@ where
|(_, a), (_, b)| b.partial_cmp(a).unwrap(),
);

vals.into_iter().map(|(idx, _v)| idx).collect_trusted()
let ca: NoNull<UInt32Chunked> = vals.into_iter().map(|(idx, _v)| idx).collect_trusted();
let mut ca = ca.into_inner();
ca.rename(self.name());
ca
} else {
let mut vals = Vec::with_capacity(self.len());
let null_count = self.null_count();
let len = self.len();
let mut vals = Vec::with_capacity(len - null_count);

// if we sort reverse, the nulls are last
// and need to be extended to the indices in reverse order
let null_cap = if reverse {
null_count
// if we sort normally, the nulls are first
// and can be extended with the sorted indices
} else {
len
};
let mut nulls_idx = Vec::with_capacity(null_cap);
let mut count: u32 = 0;
self.downcast_iter().for_each(|arr| {
let iter = arr.iter().map(|v| {
let iter = arr.iter().filter_map(|v| {
let i = count;
count += 1;
(i, v.copied())
match v {
Some(v) => Some((i, *v)),
None => {
// Safety:
// we allocated enough
unsafe { nulls_idx.push_unchecked(i) };
None
}
}
});
vals.extend_trusted_len(iter);
vals.extend(iter);
});

argsort_branch(
vals.as_mut_slice(),
reverse,
|(_, a), (_, b)| order_default_null(a, b),
|(_, a), (_, b)| order_reverse_null(a, b),
|(_, a), (_, b)| a.partial_cmp(b).unwrap(),
|(_, a), (_, b)| b.partial_cmp(a).unwrap(),
);
vals.into_iter().map(|(idx, _v)| idx).collect_trusted()
};
let mut ca = ca.into_inner();
ca.rename(self.name());
ca

let iter = vals.into_iter().map(|(idx, _v)| idx);
let idx = if reverse {
let mut idx = Vec::with_capacity(len);
idx.extend(iter);
idx.extend(nulls_idx.into_iter().rev());
idx
} else {
nulls_idx.extend(iter);
nulls_idx
};

let arr = UInt32Array::from_data(ArrowDataType::UInt32, Buffer::from_vec(idx), None);
UInt32Chunked::new_from_chunks(self.name(), vec![Arc::new(arr)])
}
}

#[cfg(feature = "sort_multiple")]
Expand Down Expand Up @@ -539,6 +573,34 @@ pub(crate) fn prepare_argsort(
mod test {
use crate::prelude::*;

#[test]
fn test_argsort() {
let a = Int32Chunked::new_from_opt_slice(
"a",
&[
Some(1), // 0
Some(5), // 1
None, // 2
Some(1), // 3
None, // 4
Some(4), // 5
Some(3), // 6
Some(1), // 7
],
);
let idx = a.argsort(false);
let idx = idx.cont_slice().unwrap();

let expected = [2, 4, 0, 3, 7, 6, 5, 1];
assert_eq!(idx, expected);

let idx = a.argsort(true);
let idx = idx.cont_slice().unwrap();
// the duplicates are in reverse order of appearance, so we cannot reverse expected
let expected = [1, 5, 6, 0, 3, 7, 4, 2];
assert_eq!(idx, expected);
}

#[test]
fn test_sort() {
let a = Int32Chunked::new_from_opt_slice(
Expand Down

0 comments on commit 37d3c54

Please sign in to comment.