Skip to content

Commit

Permalink
fix categorical sorts
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Dec 14, 2021
1 parent bfe17da commit d1d5661
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 8 deletions.
15 changes: 15 additions & 0 deletions polars/polars-core/src/chunked_array/categorical/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ pub enum RevMapping {

#[allow(clippy::len_without_is_empty)]
impl RevMapping {
/// Get the length of the [`RevMapping`]
pub fn len(&self) -> usize {
match self {
Self::Global(_, a, _) => a.len(),
Expand All @@ -66,6 +67,20 @@ impl RevMapping {
Self::Local(a) => a.value(idx as usize),
}
}

/// Categorical to str
///
/// # Safety:
/// This doesn't do any bound checking
pub unsafe fn get_unchecked(&self, idx: u32) -> &str {
match self {
Self::Global(map, a, _) => {
let idx = *map.get(&idx).unwrap();
a.value_unchecked(idx as usize)
}
Self::Local(a) => a.value_unchecked(idx as usize),
}
}
/// Check if the categoricals are created under the same global string cache.
pub fn same_src(&self, other: &Self) -> bool {
match (self, other) {
Expand Down
10 changes: 7 additions & 3 deletions polars/polars-core/src/chunked_array/categorical/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,9 +100,13 @@ impl<'a> Iterator for CatIter<'a> {
type Item = Option<&'a str>;

fn next(&mut self) -> Option<Self::Item> {
self.iter
.next()
.map(|item| item.map(|idx| self.rev.get(idx)))
self.iter.next().map(|item| {
item.map(|idx| {
// Safety:
// all categories are in bound
unsafe { self.rev.get_unchecked(idx) }
})
})
}

fn size_hint(&self) -> (usize, Option<usize>) {
Expand Down
66 changes: 61 additions & 5 deletions polars/polars-core/src/chunked_array/ops/sort.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,6 @@ use rayon::prelude::*;
use std::cmp::Ordering;
use std::hint::unreachable_unchecked;
use std::iter::FromIterator;
#[cfg(feature = "dtype-categorical")]
use std::ops::Deref;

/// # Safety
/// only may produce true, for f32/f64::NaN
Expand Down Expand Up @@ -629,15 +627,58 @@ impl ChunkSort<Utf8Type> for Utf8Chunked {
#[cfg(feature = "dtype-categorical")]
impl ChunkSort<CategoricalType> for CategoricalChunked {
fn sort_with(&self, options: SortOptions) -> ChunkedArray<CategoricalType> {
self.deref().sort_with(options).into()
assert!(
!options.nulls_last,
"null last not yet supported for categorical dtype"
);
let mut vals = self
.into_iter()
.zip(self.iter_str())
.trust_my_length(self.len())
.collect_trusted::<Vec<_>>();

argsort_branch(
vals.as_mut_slice(),
options.descending,
|(_, a), (_, b)| order_default_null(a, b),
|(_, a), (_, b)| order_reverse_null(a, b),
);
let arr: UInt32Array = vals.into_iter().map(|(idx, _v)| idx).collect_trusted();
let mut ca = self.clone();
ca.chunks = vec![Arc::new(arr)];

ca
}

fn sort(&self, reverse: bool) -> Self {
self.deref().sort(reverse).into()
self.sort_with(SortOptions {
nulls_last: false,
descending: reverse,
})
}

fn argsort(&self, reverse: bool) -> UInt32Chunked {
self.deref().argsort(reverse)
let mut count: u32 = 0;
let mut vals = self
.iter_str()
.map(|s| {
let i = count;
count += 1;
(i, s)
})
.trust_my_length(self.len())
.collect_trusted::<Vec<_>>();

argsort_branch(
vals.as_mut_slice(),
reverse,
|(_, a), (_, b)| order_default_null(a, b),
|(_, a), (_, b)| order_reverse_null(a, b),
);
let ca: NoNull<UInt32Chunked> = vals.into_iter().map(|(idx, _v)| idx).collect_trusted();
let mut ca = ca.into_inner();
ca.rename(self.name());
ca
}
}

Expand Down Expand Up @@ -883,4 +924,19 @@ mod test {
let expected = &[Some("c"), Some("b"), Some("a")];
assert_eq!(Vec::from(&out), expected);
}

#[test]
#[cfg(feature = "dtype-categorical")]
fn test_sort_categorical() {
let ca = Utf8Chunked::new("a", &[Some("a"), None, Some("c"), None, Some("b")]);
let ca = ca.cast(&DataType::Categorical).unwrap();
let ca = ca.categorical().unwrap();
let out = ca.sort_with(SortOptions {
descending: false,
nulls_last: false,
});
let out = out.iter_str().collect::<Vec<_>>();
let expected = &[None, None, Some("a"), Some("b"), Some("c")];
assert_eq!(out, expected);
}
}

0 comments on commit d1d5661

Please sign in to comment.