Skip to content

Commit

Permalink
Optionally order categorical with lexical ordering (#2726)
Browse files Browse the repository at this point in the history
* lexical_cats

* impl CategoricalTakeRandom

* expose python
  • Loading branch information
ritchie46 committed Feb 22, 2022
1 parent cb3cf56 commit 0e1d524
Show file tree
Hide file tree
Showing 24 changed files with 679 additions and 224 deletions.
2 changes: 1 addition & 1 deletion polars/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ dtype-i8 = ["polars-core/dtype-i8", "polars-lazy/dtype-i8"]
dtype-i16 = ["polars-core/dtype-i16", "polars-lazy/dtype-i16"]
dtype-u8 = ["polars-core/dtype-u8", "polars-lazy/dtype-u8"]
dtype-u16 = ["polars-core/dtype-u16", "polars-lazy/dtype-u16"]
dtype-categorical = ["polars-core/dtype-categorical", "polars-io/dtype-categorical"]
dtype-categorical = ["polars-core/dtype-categorical", "polars-io/dtype-categorical", "polars-lazy/dtype-categorical"]

docs-selection = [
"csv-file",
Expand Down
18 changes: 18 additions & 0 deletions polars/polars-core/src/chunked_array/logical/categorical/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,14 @@ mod series;
use super::*;
use crate::prelude::*;
pub use builder::*;
pub(crate) use ops::{CategoricalTakeRandomGlobal, CategoricalTakeRandomLocal};

/// Categorical array: a logical array of categories plus a small set of bit flags.
#[derive(Clone)]
pub struct CategoricalChunked {
/// The logical array (the categories), declared with `UInt32Type` as its second type parameter.
logical: Logical<CategoricalType, UInt32Type>,
/// Bit flags controlling behavior of this array:
/// 1st bit: original local categorical
/// meaning that n_unique is the same as the cat map length
/// 2nd bit: use lexical sorting
bit_settings: u8,
}

Expand All @@ -30,6 +32,10 @@ impl CategoricalChunked {
self.logical.len()
}

/// Returns the name reported by the wrapped logical array.
pub(crate) fn name(&self) -> &str {
self.logical.name()
}

/// Get a reference to the logical array (the categories).
pub(crate) fn logical(&self) -> &UInt32Chunked {
&self.logical
Expand All @@ -56,6 +62,18 @@ impl CategoricalChunked {
}
}

/// Switches the lexical-sort flag on (`toggle == true`) or off (`toggle == false`).
///
/// The flag lives in the second bit of `bit_settings`; all other bits are left untouched.
pub fn set_lexical_sorted(&mut self, toggle: bool) {
    let lexical_bit = 1u8 << 1;
    self.bit_settings = if toggle {
        self.bit_settings | lexical_bit
    } else {
        self.bit_settings & !lexical_bit
    };
}

/// Reports whether the lexical-sort flag (second bit of `bit_settings`) is set.
pub(crate) fn use_lexical_sort(&self) -> bool {
    let lexical_bit = 1u8 << 1;
    (self.bit_settings & lexical_bit) != 0
}

pub(crate) fn from_cats_and_rev_map(idx: UInt32Chunked, rev_map: Arc<RevMapping>) -> Self {
let mut logical = Logical::<UInt32Type, _>::new_logical::<CategoricalType>(idx);
logical.2 = Some(DataType::Categorical(Some(rev_map)));
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
mod append;
mod full;
mod take_random;
mod unique;
#[cfg(feature = "zip_with")]
mod zip;

use super::*;
pub(crate) use take_random::{CategoricalTakeRandomGlobal, CategoricalTakeRandomLocal};
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
use crate::prelude::compare_inner::PartialOrdInner;
use crate::prelude::{
CategoricalChunked, IntoTakeRandom, NumTakeRandomChunked, NumTakeRandomCont,
NumTakeRandomSingleChunk, PlHashMap, RevMapping, TakeRandBranch3, TakeRandom,
};
use arrow::array::Utf8Array;
use std::cmp::Ordering;

/// Random-access accessor over `u32` category values: a `TakeRandBranch3` over
/// three `NumTakeRandom*` implementations.
type TakeCats<'a> = TakeRandBranch3<
NumTakeRandomCont<'a, u32>,
NumTakeRandomSingleChunk<'a, u32>,
NumTakeRandomChunked<'a, u32>,
>;

/// Comparator state for a categorical array whose reverse mapping is `RevMapping::Local`.
pub(crate) struct CategoricalTakeRandomLocal<'a> {
/// String values of the categories; indexed directly by category id when comparing.
rev_map: &'a Utf8Array<i64>,
/// Random access to the `u32` category stored at each index.
cats: TakeCats<'a>,
}

impl<'a> CategoricalTakeRandomLocal<'a> {
    /// Builds the comparator state from a categorical array with a local reverse mapping.
    ///
    /// # Panics
    /// Panics when `ca` does not consist of exactly one chunk, or when its reverse
    /// mapping is not `RevMapping::Local`.
    pub(crate) fn new(ca: &'a CategoricalChunked) -> Self {
        // should be rechunked upstream
        assert_eq!(ca.logical.chunks.len(), 1, "implementation error");
        match &**ca.get_rev_map() {
            RevMapping::Local(rev_map) => Self {
                rev_map,
                cats: ca.logical().take_rand(),
            },
            _ => unreachable!(),
        }
    }
}

impl PartialOrdInner for CategoricalTakeRandomLocal<'_> {
    /// Compares the elements at `idx_a` and `idx_b` by the string value of their category.
    ///
    /// A missing element (`None`) orders before any present value, following `Option`'s ordering.
    ///
    /// # Safety
    /// Neither index nor the category ids are bounds-checked here; the caller
    /// presumably has to guarantee both indices are valid for the array.
    unsafe fn cmp_element_unchecked(&self, idx_a: usize, idx_b: usize) -> Ordering {
        // Resolve an element index to the string of its category (if the element is set).
        let string_at = |idx: usize| {
            self.cats
                .get_unchecked(idx)
                .map(|cat| self.rev_map.value_unchecked(cat as usize))
        };
        let lhs = string_at(idx_a);
        let rhs = string_at(idx_b);
        // `Option<&str>` is totally ordered, so this equals `partial_cmp(..).unwrap()`.
        lhs.cmp(&rhs)
    }
}

/// Comparator state for a categorical array whose reverse mapping is `RevMapping::Global`.
pub(crate) struct CategoricalTakeRandomGlobal<'a> {
/// Maps a category id to an index into `rev_map_part_2`.
rev_map_part_1: &'a PlHashMap<u32, u32>,
/// String values, indexed by the value looked up in `rev_map_part_1`.
rev_map_part_2: &'a Utf8Array<i64>,
/// Random access to the `u32` category stored at each index.
cats: TakeCats<'a>,
}
impl<'a> CategoricalTakeRandomGlobal<'a> {
    /// Builds the comparator state from a categorical array with a global reverse mapping.
    ///
    /// # Panics
    /// Panics when `ca` does not consist of exactly one chunk, or when its reverse
    /// mapping is not `RevMapping::Global`.
    pub(crate) fn new(ca: &'a CategoricalChunked) -> Self {
        // should be rechunked upstream
        assert_eq!(ca.logical.chunks.len(), 1, "implementation error");
        match &**ca.get_rev_map() {
            RevMapping::Global(rev_map_part_1, rev_map_part_2, _) => Self {
                rev_map_part_1,
                rev_map_part_2,
                cats: ca.logical().take_rand(),
            },
            _ => unreachable!(),
        }
    }
}

impl PartialOrdInner for CategoricalTakeRandomGlobal<'_> {
    /// Compares the elements at `idx_a` and `idx_b` by the string value of their category.
    ///
    /// Each category id is first translated through `rev_map_part_1` and the result is
    /// used to index `rev_map_part_2`. A missing element (`None`) orders before any
    /// present value, following `Option`'s ordering.
    ///
    /// # Panics
    /// Panics when a category id has no entry in `rev_map_part_1`.
    ///
    /// # Safety
    /// Neither index is bounds-checked here; the caller presumably has to guarantee
    /// both indices are valid for the array.
    unsafe fn cmp_element_unchecked(&self, idx_a: usize, idx_b: usize) -> Ordering {
        // Resolve an element index to the string of its category (if the element is set).
        let string_at = |idx: usize| {
            self.cats.get_unchecked(idx).map(|cat| {
                let str_idx = *self.rev_map_part_1.get(&cat).unwrap();
                self.rev_map_part_2.value_unchecked(str_idx as usize)
            })
        };
        let lhs = string_at(idx_a);
        let rhs = string_at(idx_b);
        // `Option<&str>` is totally ordered, so this equals `partial_cmp(..).unwrap()`.
        lhs.cmp(&rhs)
    }
}
30 changes: 7 additions & 23 deletions polars/polars-core/src/chunked_array/ops/compare_inner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ where
};
Box::new(t)
} else {
let t = NumTakeRandomSingleChunk::<'_, T::Native> { arr };
let t = NumTakeRandomSingleChunk::<'_, T::Native>::new(arr);
Box::new(t)
}
} else {
Expand Down Expand Up @@ -159,19 +159,6 @@ impl<'a> IntoPartialEqInner<'a> for &'a BooleanChunked {
}
}

// Placeholder implementation for list arrays: element-wise equality is not
// provided, and calling `into_partial_eq_inner` panics via `unimplemented!()`.
impl<'a> IntoPartialEqInner<'a> for &'a ListChunked {
fn into_partial_eq_inner(self) -> Box<dyn PartialEqInner> {
unimplemented!()
}
}

// Placeholder implementation for categorical arrays (only compiled with the
// `dtype-categorical` feature): calling `into_partial_eq_inner` panics via
// `unimplemented!()`.
#[cfg(feature = "dtype-categorical")]
impl<'a> IntoPartialEqInner<'a> for &'a CategoricalChunked {
fn into_partial_eq_inner(self) -> Box<dyn PartialEqInner> {
unimplemented!()
}
}

// Partial ordering implementations

fn fallback<T: PartialEq>(a: T) -> Ordering {
Expand Down Expand Up @@ -219,7 +206,7 @@ where
};
Box::new(t)
} else {
let t = NumTakeRandomSingleChunk::<'_, T::Native> { arr };
let t = NumTakeRandomSingleChunk::<'_, T::Native>::new(arr);
Box::new(t)
}
} else {
Expand Down Expand Up @@ -272,16 +259,13 @@ impl<'a> IntoPartialOrdInner<'a> for &'a BooleanChunked {
}
}

// Placeholder implementation for list arrays: element-wise ordering is not
// provided, and calling `into_partial_ord_inner` panics via `unimplemented!()`.
impl<'a> IntoPartialOrdInner<'a> for &'a ListChunked {
fn into_partial_ord_inner(self) -> Box<dyn PartialOrdInner> {
unimplemented!()
}
}

#[cfg(feature = "dtype-categorical")]
impl<'a> IntoPartialOrdInner<'a> for &'a CategoricalChunked {
fn into_partial_ord_inner(self) -> Box<dyn PartialOrdInner> {
unimplemented!()
fn into_partial_ord_inner(self) -> Box<dyn PartialOrdInner + 'a> {
match &**self.get_rev_map() {
RevMapping::Local(_) => Box::new(CategoricalTakeRandomLocal::new(self)),
RevMapping::Global(_, _, _) => Box::new(CategoricalTakeRandomGlobal::new(self)),
}
}
}

Expand Down
54 changes: 54 additions & 0 deletions polars/polars-core/src/chunked_array/ops/sort/argsort_multiple.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
use super::*;

/// Validates the arguments of a multi-column argsort.
///
/// `ca` is counted as one sort column and `other` holds the remaining columns, so
/// `reverse` must contain exactly `other.len() + 1` ordering flags.
///
/// # Errors
/// Returns `PolarsError::ValueError` when `reverse.len() != other.len() + 1`
/// (including the case of an empty `reverse` slice).
///
/// # Panics
/// Panics when a series in `other` does not have the same length as `ca`.
pub(crate) fn args_validate<T: PolarsDataType>(
    ca: &ChunkedArray<T>,
    other: &[Series],
    reverse: &[bool],
) -> Result<()> {
    for s in other {
        assert_eq!(ca.len(), s.len());
    }
    // Compare against `other.len() + 1` instead of computing `reverse.len() - 1`:
    // the subtraction underflows for an empty `reverse` slice (panic in debug
    // builds, wrap-around in release builds).
    if other.len() + 1 != reverse.len() {
        return Err(PolarsError::ValueError(
            format!(
                "The amount of ordering booleans: {} does not match that no. of Series: {}",
                reverse.len(),
                other.len() + 1
            )
            .into(),
        ));
    }

    Ok(())
}

/// Sorts `vals` (pairs of row index and optional value of the first column) and
/// returns the row indices in sorted order.
///
/// Rows are ordered by the first column using `sort_with_nulls`; `reverse[0]` flips
/// that ordering. Ties are resolved with `ordering_other_columns` over the columns
/// in `other`, using the flags in `reverse[1..]`. The returned array is marked via
/// `set_sorted(reverse[0])`.
///
/// # Panics
/// Panics when `reverse` is empty.
pub(crate) fn argsort_multiple_impl<T: PartialOrd>(
    mut vals: Vec<(IdxSize, Option<T>)>,
    other: &[Series],
    reverse: &[bool],
) -> Result<IdxCa> {
    let compare_inner: Vec<_> = other
        .iter()
        .map(|s| s.into_partial_ord_inner())
        .collect_trusted();

    let first_reverse = reverse[0];
    let other_reverse = &reverse[1..];

    vals.sort_by(|lhs, rhs| match sort_with_nulls(&lhs.1, &rhs.1) {
        // if ordering is equal, we check the other arrays until we find a non-equal ordering
        // if we have exhausted all arrays, we keep the equal ordering.
        Ordering::Equal => ordering_other_columns(
            &compare_inner,
            other_reverse,
            lhs.0 as usize,
            rhs.0 as usize,
        ),
        // a reversed first column swaps `Less` and `Greater`
        ord if first_reverse => ord.reverse(),
        ord => ord,
    });

    let sorted_idx: NoNull<IdxCa> = vals.into_iter().map(|(idx, _v)| idx).collect_trusted();
    let mut sorted_idx = sorted_idx.into_inner();
    sorted_idx.set_sorted(first_reverse);
    Ok(sorted_idx)
}

0 comments on commit 0e1d524

Please sign in to comment.