Skip to content

Commit

Permalink
add multi-ordering in multilevel sort and reduce compiler bloat
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed May 8, 2021
1 parent 99bfc4d commit 9447def
Show file tree
Hide file tree
Showing 8 changed files with 106 additions and 50 deletions.
2 changes: 1 addition & 1 deletion polars/polars-core/src/chunked_array/ops/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -588,7 +588,7 @@ pub trait ChunkSort<T> {
fn argsort(&self, reverse: bool) -> UInt32Chunked;

/// Retrieve the indexes need to sort this and the other arrays.
fn argsort_multiple(&self, _other: &[Series], _reverse: bool) -> Result<UInt32Chunked> {
fn argsort_multiple(&self, _other: &[Series], _reverse: &[bool]) -> Result<UInt32Chunked> {
Err(PolarsError::InvalidOperation(
"argsort_multiple not implemented for this dtype".into(),
))
Expand Down
80 changes: 45 additions & 35 deletions polars/polars-core/src/chunked_array/ops/sort.rs
Original file line number Diff line number Diff line change
Expand Up @@ -204,30 +204,20 @@ where
///
/// This function is very opinionated.
/// We assume that all numeric `Series` are of the same type, if not it will panic
fn argsort_multiple(&self, other: &[Series], reverse: bool) -> Result<UInt32Chunked> {
fn argsort_multiple(&self, other: &[Series], reverse: &[bool]) -> Result<UInt32Chunked> {
for ca in other {
assert_eq!(self.len(), ca.len());
}
assert_eq!(other.len(), reverse.len() - 1);
let mut count: u32 = 0;
let mut vals: Vec<_> = match reverse {
true => self
.into_iter()
.rev()
.map(|v| {
let i = count;
count += 1;
(i, v)
})
.collect(),
false => self
.into_iter()
.map(|v| {
let i = count;
count += 1;
(i, v)
})
.collect(),
};
let mut vals: Vec<_> = self
.into_iter()
.map(|v| {
let i = count;
count += 1;
(i, v)
})
.collect();

vals.sort_by(|tpl_a, tpl_b| match sort_with_nulls(&tpl_a.1, &tpl_b.1) {
// if ordering is equal, we check the other arrays until we find a non-equal ordering
Expand All @@ -237,11 +227,18 @@ where
let idx_b = tpl_b.0 as usize;

macro_rules! partial_ord_by_idx {
($ca: ident) => {{
($ca: ident, $reverse: expr) => {{
// Safety:
// Indexes are in bounds, we asserted equal lengths above
let a = unsafe { $ca.get_unchecked(idx_a) };
let b = unsafe { $ca.get_unchecked(idx_b) };
let a;
let b;
if $reverse {
b = unsafe { $ca.get_unchecked(idx_a) };
a = unsafe { $ca.get_unchecked(idx_b) };
} else {
a = unsafe { $ca.get_unchecked(idx_a) };
b = unsafe { $ca.get_unchecked(idx_b) };
}

match (&a).partial_cmp(&b).unwrap() {
// also equal, try next array
Expand All @@ -253,17 +250,17 @@ where
}

// series should be matching type or utf8
for s in other {
for (s, reverse) in other.iter().zip(&reverse[1..]) {
match s.dtype() {
DataType::Utf8 => {
let ca = s.utf8().unwrap();
partial_ord_by_idx!(ca)
partial_ord_by_idx!(ca, *reverse)
}
_ => {
let ca = self
.unpack_series_matching_type(s)
.expect("should be same type");
partial_ord_by_idx!(ca)
partial_ord_by_idx!(ca, *reverse)
}
}
}
Expand Down Expand Up @@ -324,12 +321,18 @@ impl ChunkSort<Utf8Type> for Utf8Chunked {
///
/// In this case we assume that all numeric `Series` are `f64` types. The caller needs to
/// uphold this contract. If not, it will panic.
fn argsort_multiple(&self, other: &[Series], reverse: bool) -> Result<UInt32Chunked> {
///
fn argsort_multiple(&self, other: &[Series], reverse: &[bool]) -> Result<UInt32Chunked> {
for ca in other {
assert_eq!(self.len(), ca.len());
if self.len() != ca.len() {
return Err(PolarsError::ShapeMisMatch(
"sort column should have equal length".into(),
));
}
}
assert_eq!(other.len(), reverse.len() - 1);
let mut count: u32 = 0;
let mut vals: Vec<_> = match reverse {
let mut vals: Vec<_> = match reverse[0] {
true => self
.into_iter()
.rev()
Expand Down Expand Up @@ -357,11 +360,18 @@ impl ChunkSort<Utf8Type> for Utf8Chunked {
let idx_b = tpl_b.0 as usize;

macro_rules! partial_ord_by_idx {
($ca: ident) => {{
($ca: ident, $reverse: expr) => {{
// Safety:
// Indexes are in bounds, we asserted equal lengths above
let a = unsafe { $ca.get_unchecked(idx_a) };
let b = unsafe { $ca.get_unchecked(idx_b) };
let a;
let b;
if $reverse {
b = unsafe { $ca.get_unchecked(idx_a) };
a = unsafe { $ca.get_unchecked(idx_b) };
} else {
a = unsafe { $ca.get_unchecked(idx_a) };
b = unsafe { $ca.get_unchecked(idx_b) };
}

match (&a).partial_cmp(&b).unwrap() {
// also equal, try next array
Expand All @@ -373,15 +383,15 @@ impl ChunkSort<Utf8Type> for Utf8Chunked {
}

// series should be matching type or utf8
for s in other {
for (s, reverse) in other.iter().zip(&reverse[1..]) {
match s.dtype() {
DataType::Utf8 => {
let ca = s.utf8().unwrap();
partial_ord_by_idx!(ca)
partial_ord_by_idx!(ca, *reverse)
}
_ => {
let ca = s.f64().expect("cast to f64 before calling this method");
partial_ord_by_idx!(ca)
partial_ord_by_idx!(ca, *reverse)
}
}
}
Expand Down
51 changes: 40 additions & 11 deletions polars/polars-core/src/frame/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -760,17 +760,14 @@ impl DataFrame {
Ok(self)
}

/// Return a sorted clone of this DataFrame.
pub fn sort<'a, S, J>(&self, by_column: S, reverse: bool) -> Result<Self>
where
S: Selection<'a, J>,
{
let take = match by_column.single() {
Some(by_column) => {
let s = self.column(by_column)?;
s.argsort(reverse)
/// This is the dispatch of Self::sort, and exists to reduce compile bloat by monomorphization.
fn sort_impl(&self, by_column: Vec<&str>, mut reverse: Vec<bool>) -> Result<Self> {
let take = match by_column.len() {
1 => {
let s = self.column(by_column[0])?;
s.argsort(reverse[0])
}
None => {
n_cols => {
#[cfg(feature = "sort_multiple")]
{
let mut columns = self.select_series(by_column)?;
Expand Down Expand Up @@ -800,12 +797,19 @@ impl DataFrame {
})
.collect::<Vec<_>>();

// broadcast ordering
if n_cols > reverse.len() && reverse.len() == 1 {
while n_cols != reverse.len() {
reverse.push(reverse[0]);
}
}

if !matches!(first.dtype(), DataType::Utf8) {
first = first.cast_with_dtype(&dtype)?;
}
}

first.argsort_multiple(&columns, reverse)?
first.argsort_multiple(&columns, &reverse)?
}
#[cfg(not(feature = "sort_multiple"))]
{
Expand All @@ -816,6 +820,31 @@ impl DataFrame {
Ok(self.take(&take))
}

/// Return a sorted clone of this DataFrame.
///
/// # Example
///
/// ```
/// use polars_core::prelude::*;
///
/// fn sort_example(df: &DataFrame, reverse: &[bool]) -> Result<DataFrame> {
/// df.sort("a", reverse)
/// }
///
/// fn sort_by_multiple_columns_example(df: &DataFrame) -> Result<DataFrame> {
/// df.sort(&["a", "b"], false)
/// }
/// ```
pub fn sort<'a, S, J>(&self, by_column: S, reverse: impl IntoVec<bool>) -> Result<Self>
where
S: Selection<'a, J>,
{
// we do this heap allocation and dispatch to reduce monomorphization bloat
let by_column = by_column.to_selection_vec();
let reverse = reverse.into_vec();
self.sort_impl(by_column, reverse)
}

/// Replace a column with a series.
pub fn replace<S: IntoSeries>(&mut self, column: &str, new_col: S) -> Result<&mut Self> {
self.apply(column, |_| new_col.into_series())
Expand Down
1 change: 1 addition & 0 deletions polars/polars-core/src/prelude.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ pub use crate::{
IntoSeries, NamedFrom, Series, SeriesTrait,
},
testing::*,
utils::IntoVec,
vector_hasher::VecHash,
};
pub use arrow::datatypes::{ArrowPrimitiveType, Field as ArrowField, Schema as ArrowSchema};
Expand Down
2 changes: 1 addition & 1 deletion polars/polars-core/src/series/implementations/dates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ macro_rules! impl_dyn_series {
cast_and_apply!(self, group_tuples, multithreaded)
}
#[cfg(feature = "sort_multiple")]
fn argsort_multiple(&self, by: &[Series], reverse: bool) -> Result<UInt32Chunked> {
fn argsort_multiple(&self, by: &[Series], reverse: &[bool]) -> Result<UInt32Chunked> {
let phys_type = self.0.physical_type();
let s = self.cast_with_dtype(&phys_type).unwrap();

Expand Down
2 changes: 1 addition & 1 deletion polars/polars-core/src/series/implementations/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ macro_rules! impl_dyn_series {
}

#[cfg(feature = "sort_multiple")]
fn argsort_multiple(&self, by: &[Series], reverse: bool) -> Result<UInt32Chunked> {
fn argsort_multiple(&self, by: &[Series], reverse: &[bool]) -> Result<UInt32Chunked> {
self.0.argsort_multiple(by, reverse)
}
}
Expand Down
2 changes: 1 addition & 1 deletion polars/polars-core/src/series/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ pub(crate) mod private {
unimplemented!()
}
#[cfg(feature = "sort_multiple")]
fn argsort_multiple(&self, _by: &[Series], _reverse: bool) -> Result<UInt32Chunked> {
fn argsort_multiple(&self, _by: &[Series], _reverse: &[bool]) -> Result<UInt32Chunked> {
Err(PolarsError::InvalidOperation(
"argsort_multiple is not implemented for this Series".into(),
))
Expand Down
16 changes: 16 additions & 0 deletions polars/polars-core/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -868,3 +868,19 @@ where
}
}
}

pub trait IntoVec<T> {
fn into_vec(self) -> Vec<T>;
}

impl IntoVec<bool> for bool {
fn into_vec(self) -> Vec<bool> {
vec![self]
}
}

impl<T> IntoVec<T> for Vec<T> {
fn into_vec(self) -> Self {
self
}
}

0 comments on commit 9447def

Please sign in to comment.