Skip to content

Commit

Permalink
improve performance of sorting by multiple columns
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Jul 5, 2021
1 parent d803ae1 commit a541573
Show file tree
Hide file tree
Showing 4 changed files with 178 additions and 127 deletions.
143 changes: 142 additions & 1 deletion polars/polars-core/src/chunked_array/ops/compare_inner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,20 @@ use super::take_random::{
NumTakeRandomSingleChunk, Utf8TakeRandom, Utf8TakeRandomSingleChunk,
};
use crate::prelude::*;
use std::cmp::PartialEq;
use std::cmp::{Ordering, PartialEq};

pub trait PartialEqInner: Send + Sync {
/// Safety:
/// Does not do any bound checks
unsafe fn eq_element_unchecked(&self, idx_a: usize, idx_b: usize) -> bool;
}

pub trait PartialOrdInner: Send + Sync {
/// Safety:
/// Does not do any bound checks
unsafe fn cmp_element_unchecked(&self, idx_a: usize, idx_b: usize) -> Ordering;
}

macro_rules! impl_traits {
($struct:ty) => {
impl PartialEqInner for $struct {
Expand All @@ -23,6 +29,14 @@ macro_rules! impl_traits {
self.get(idx_a) == self.get(idx_b)
}
}
impl PartialOrdInner for $struct {
#[inline]
unsafe fn cmp_element_unchecked(&self, idx_a: usize, idx_b: usize) -> Ordering {
let a = self.get(idx_a);
let b = self.get(idx_b);
a.partial_cmp(&b).unwrap_or_else(|| fallback(a))
}
}
};
($struct:ty, $T:tt) => {
impl<$T> PartialEqInner for $struct
Expand All @@ -35,6 +49,20 @@ macro_rules! impl_traits {
self.get(idx_a) == self.get(idx_b)
}
}

impl<$T> PartialOrdInner for $struct
where
$T: PolarsNumericType + Sync,
$T::Native: Copy + PartialOrd + Sync,
{
#[inline]
unsafe fn cmp_element_unchecked(&self, idx_a: usize, idx_b: usize) -> Ordering {
// nulls so we can not do unchecked
let a = self.get(idx_a);
let b = self.get(idx_b);
a.partial_cmp(&b).unwrap_or_else(|| fallback(a))
}
}
};
}

Expand Down Expand Up @@ -143,3 +171,116 @@ impl<'a> IntoPartialEqInner<'a> for &'a CategoricalChunked {
unimplemented!()
}
}

// Partial ordering implementations

fn fallback<T: PartialEq>(a: T) -> Ordering {
// nan != nan
// this is a simple way to check if it is nan
// without convincing the compiler we deal with floats
#[allow(clippy::eq_op)]
if a != a {
Ordering::Less
} else {
Ordering::Greater
}
}

impl<T> PartialOrdInner for NumTakeRandomCont<'_, T>
where
T: Copy + PartialOrd + Sync,
{
unsafe fn cmp_element_unchecked(&self, idx_a: usize, idx_b: usize) -> Ordering {
// no nulls so we can do unchecked
let a = self.get_unchecked(idx_a);
let b = self.get_unchecked(idx_b);
a.partial_cmp(&b).unwrap_or_else(|| fallback(a))
}
}
/// Create a type that implements PartialOrdInner
pub(crate) trait IntoPartialOrdInner<'a> {
/// Create a type that implements `TakeRandom`.
fn into_partial_ord_inner(self) -> Box<dyn PartialOrdInner + 'a>;
}
/// We use a trait object because we want to call this from Series and cannot use a typed enum.
impl<'a, T> IntoPartialOrdInner<'a> for &'a ChunkedArray<T>
where
T: PolarsNumericType,
T::Native: PartialOrd,
{
fn into_partial_ord_inner(self) -> Box<dyn PartialOrdInner + 'a> {
let mut chunks = self.downcast_iter();

if self.chunks.len() == 1 {
let arr = chunks.next().unwrap();

if self.null_count() == 0 {
let t = NumTakeRandomCont {
slice: arr.values(),
};
Box::new(t)
} else {
let t = NumTakeRandomSingleChunk::<'_, T> { arr };
Box::new(t)
}
} else {
let t = NumTakeRandomChunked::<'_, T> {
chunks: chunks.collect(),
chunk_lens: self.chunks.iter().map(|a| a.len() as u32).collect(),
};
Box::new(t)
}
}
}

impl<'a> IntoPartialOrdInner<'a> for &'a Utf8Chunked {
fn into_partial_ord_inner(self) -> Box<dyn PartialOrdInner + 'a> {
match self.chunks.len() {
1 => {
let arr = self.downcast_iter().next().unwrap();
let t = Utf8TakeRandomSingleChunk { arr };
Box::new(t)
}
_ => {
let chunks = self.downcast_chunks();
let t = Utf8TakeRandom {
chunks,
chunk_lens: self.chunks.iter().map(|a| a.len() as u32).collect(),
};
Box::new(t)
}
}
}
}

impl<'a> IntoPartialOrdInner<'a> for &'a BooleanChunked {
fn into_partial_ord_inner(self) -> Box<dyn PartialOrdInner + 'a> {
match self.chunks.len() {
1 => {
let arr = self.downcast_iter().next().unwrap();
let t = BoolTakeRandomSingleChunk { arr };
Box::new(t)
}
_ => {
let chunks = self.downcast_chunks();
let t = BoolTakeRandom {
chunks,
chunk_lens: self.chunks.iter().map(|a| a.len() as u32).collect(),
};
Box::new(t)
}
}
}
}

impl<'a> IntoPartialOrdInner<'a> for &'a ListChunked {
fn into_partial_ord_inner(self) -> Box<dyn PartialOrdInner> {
unimplemented!()
}
}

impl<'a> IntoPartialOrdInner<'a> for &'a CategoricalChunked {
fn into_partial_ord_inner(self) -> Box<dyn PartialOrdInner> {
unimplemented!()
}
}
151 changes: 27 additions & 124 deletions polars/polars-core/src/chunked_array/ops/sort.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use crate::prelude::compare_inner::PartialOrdInner;
use crate::prelude::*;
use crate::utils::NoNull;
use itertools::Itertools;
Expand Down Expand Up @@ -225,6 +226,9 @@ where
}

assert_eq!(other.len(), reverse.len() - 1);

let compare_inner: Vec<_> = other.iter().map(|s| s.into_partial_ord_inner()).collect();

let mut count: u32 = 0;
let mut vals: Vec<_> = self
.into_iter()
Expand All @@ -242,68 +246,7 @@ where
(_, Ordering::Equal) => {
let idx_a = tpl_a.0 as usize;
let idx_b = tpl_b.0 as usize;

macro_rules! partial_ord_by_idx {
($ca: ident, $reverse: expr) => {{
// Safety:
// Indexes are in bounds, we asserted equal lengths above
let a;
let b;
if $reverse {
b = unsafe { $ca.get_unchecked(idx_a) };
a = unsafe { $ca.get_unchecked(idx_b) };
} else {
a = unsafe { $ca.get_unchecked(idx_a) };
b = unsafe { $ca.get_unchecked(idx_b) };
}

match (&a).partial_cmp(&b).unwrap() {
// also equal, try next array
Ordering::Equal => continue,
// this array is not equal, return
ord => return ord,
}
}};
}

// series should be matching type or utf8
for (s, reverse) in other.iter().zip(&reverse[1..]) {
match s.dtype() {
DataType::Utf8 => {
let ca = s.utf8().unwrap();
partial_ord_by_idx!(ca, *reverse)
}
DataType::Float32 => {
let ca = s.f32().unwrap();
partial_ord_by_idx!(ca, *reverse)
}
DataType::Float64 => {
let ca = s.f64().unwrap();
partial_ord_by_idx!(ca, *reverse)
}
DataType::Int64 => {
let ca = s.i64().unwrap();
partial_ord_by_idx!(ca, *reverse)
}
DataType::Int32 => {
let ca = s.i32().unwrap();
partial_ord_by_idx!(ca, *reverse)
}
DataType::UInt32 => {
let ca = s.u32().unwrap();
partial_ord_by_idx!(ca, *reverse)
}
DataType::UInt64 => {
let ca = s.u64().unwrap();
partial_ord_by_idx!(ca, *reverse)
}
_ => {
unreachable!()
}
}
}
// all arrays exhausted, ordering equal it is.
Ordering::Equal
ordering_other_columns(&compare_inner, &reverse[1..], idx_a, idx_b)
}
(true, Ordering::Less) => Ordering::Greater,
(true, Ordering::Greater) => Ordering::Less,
Expand All @@ -316,6 +259,26 @@ where
}
}

fn ordering_other_columns<'a>(
compare_inner: &'a [Box<dyn PartialOrdInner + 'a>],
reverse: &[bool],
idx_a: usize,
idx_b: usize,
) -> Ordering {
for (cmp, reverse) in compare_inner.iter().zip(reverse) {
// Safety:
// indices are in bounds
let ordering = unsafe { cmp.cmp_element_unchecked(idx_a, idx_b) };
match (ordering, reverse) {
(Ordering::Equal, _) => continue,
(_, true) => return ordering.reverse(),
_ => return ordering,
}
}
// all arrays/columns exhausted, ordering equal it is.
Ordering::Equal
}

macro_rules! sort {
($self:ident, $reverse:ident) => {{
if $reverse {
Expand Down Expand Up @@ -381,6 +344,7 @@ impl ChunkSort<Utf8Type> for Utf8Chunked {
(i, v)
})
.collect();
let compare_inner: Vec<_> = other.iter().map(|s| s.into_partial_ord_inner()).collect();

vals.sort_by(
|tpl_a, tpl_b| match (reverse[0], sort_with_nulls(&tpl_a.1, &tpl_b.1)) {
Expand All @@ -389,68 +353,7 @@ impl ChunkSort<Utf8Type> for Utf8Chunked {
(_, Ordering::Equal) => {
let idx_a = tpl_a.0 as usize;
let idx_b = tpl_b.0 as usize;

macro_rules! partial_ord_by_idx {
($ca: ident, $reverse: expr) => {{
// Safety:
// Indexes are in bounds, we asserted equal lengths above
let a;
let b;
if $reverse {
b = unsafe { $ca.get_unchecked(idx_a) };
a = unsafe { $ca.get_unchecked(idx_b) };
} else {
a = unsafe { $ca.get_unchecked(idx_a) };
b = unsafe { $ca.get_unchecked(idx_b) };
}

match (&a).partial_cmp(&b).unwrap() {
// also equal, try next array
Ordering::Equal => continue,
// this array is not equal, return
ord => return ord,
}
}};
}

// series should be matching type or utf8
for (s, reverse) in other.iter().zip(&reverse[1..]) {
match s.dtype() {
DataType::Utf8 => {
let ca = s.utf8().unwrap();
partial_ord_by_idx!(ca, *reverse)
}
DataType::Float32 => {
let ca = s.f32().unwrap();
partial_ord_by_idx!(ca, *reverse)
}
DataType::Float64 => {
let ca = s.f64().unwrap();
partial_ord_by_idx!(ca, *reverse)
}
DataType::Int64 => {
let ca = s.i64().unwrap();
partial_ord_by_idx!(ca, *reverse)
}
DataType::Int32 => {
let ca = s.i32().unwrap();
partial_ord_by_idx!(ca, *reverse)
}
DataType::UInt32 => {
let ca = s.u32().unwrap();
partial_ord_by_idx!(ca, *reverse)
}
DataType::UInt64 => {
let ca = s.u64().unwrap();
partial_ord_by_idx!(ca, *reverse)
}
_ => {
unreachable!()
}
}
}
// all arrays exhausted, ordering equal it is.
Ordering::Equal
ordering_other_columns(&compare_inner, &reverse[1..], idx_a, idx_b)
}
(true, Ordering::Less) => Ordering::Greater,
(true, Ordering::Greater) => Ordering::Less,
Expand Down
5 changes: 4 additions & 1 deletion polars/polars-core/src/series/implementations/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ use crate::chunked_array::comparison::*;
use crate::chunked_array::{
ops::{
aggregate::{ChunkAggSeries, VarAggSeries},
compare_inner::{IntoPartialEqInner, PartialEqInner},
compare_inner::{IntoPartialEqInner, IntoPartialOrdInner, PartialEqInner, PartialOrdInner},
},
AsSinglePtr, ChunkIdIter,
};
Expand Down Expand Up @@ -76,6 +76,9 @@ macro_rules! impl_dyn_series {
fn into_partial_eq_inner<'a>(&'a self) -> Box<dyn PartialEqInner + 'a> {
(&self.0).into_partial_eq_inner()
}
fn into_partial_ord_inner<'a>(&'a self) -> Box<dyn PartialOrdInner + 'a> {
(&self.0).into_partial_ord_inner()
}

fn vec_hash(&self, random_state: RandomState) -> AlignedVec<u64> {
self.0.vec_hash(random_state)
Expand Down

0 comments on commit a541573

Please sign in to comment.