Skip to content

Commit

Permalink
feat[rust, python]: top_k expression (#4556)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Aug 29, 2022
1 parent 85eeaec commit dee5d5f
Show file tree
Hide file tree
Showing 18 changed files with 161 additions and 22 deletions.
1 change: 1 addition & 0 deletions polars/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ date_offset = ["polars-lazy/date_offset"]
trigonometry = ["polars-lazy/trigonometry"]
sign = ["polars-lazy/sign"]
pivot = ["polars-lazy/pivot"]
top_k = ["polars-lazy/top_k"]

test = [
"lazy",
Expand Down
2 changes: 2 additions & 0 deletions polars/polars-arrow/src/kernels/rolling/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ type Idx = usize;
type WindowSize = usize;
type Len = usize;

#[inline]
pub fn compare_fn_nan_min<T>(a: &T, b: &T) -> Ordering
where
T: PartialOrd + IsFloat,
Expand All @@ -41,6 +42,7 @@ where
}
}

#[inline]
pub fn compare_fn_nan_max<T>(a: &T, b: &T) -> Ordering
where
T: PartialOrd + IsFloat,
Expand Down
4 changes: 2 additions & 2 deletions polars/polars-core/src/chunked_array/ops/sort/argsort.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@ use super::*;

#[inline]
fn default_order<T: PartialOrd + IsFloat>(a: &(IdxSize, T), b: &(IdxSize, T)) -> Ordering {
sort_cmp(&a.1, &b.1)
compare_fn_nan_max(&a.1, &b.1)
}

#[inline]
fn reverse_order<T: PartialOrd + IsFloat>(a: &(IdxSize, T), b: &(IdxSize, T)) -> Ordering {
sort_cmp(&b.1, &a.1)
compare_fn_nan_max(&b.1, &a.1)
}

pub(super) fn argsort<I, J, T>(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@ pub(crate) fn argsort_multiple_impl<T: PartialOrd + Send + IsFloat + Copy>(
.collect_trusted();

let first_reverse = reverse[0];
vals.par_sort_by(
|tpl_a, tpl_b| match (first_reverse, sort_cmp(&tpl_a.1, &tpl_b.1)) {
vals.par_sort_by(|tpl_a, tpl_b| {
match (first_reverse, compare_fn_nan_max(&tpl_a.1, &tpl_b.1)) {
// if ordering is equal, we check the other arrays until we find a non-equal ordering
// if we have exhausted all arrays, we keep the equal ordering.
(_, Ordering::Equal) => {
Expand All @@ -51,8 +51,8 @@ pub(crate) fn argsort_multiple_impl<T: PartialOrd + Send + IsFloat + Copy>(
(true, Ordering::Less) => Ordering::Greater,
(true, Ordering::Greater) => Ordering::Less,
(_, ord) => ord,
},
);
}
});
let ca: NoNull<IdxCa> = vals.into_iter().map(|(idx, _v)| idx).collect_trusted();
let mut ca = ca.into_inner();
ca.set_sorted(reverse[0]);
Expand Down
16 changes: 0 additions & 16 deletions polars/polars-core/src/chunked_array/ops/sort/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,22 +22,6 @@ use crate::prelude::sort::argsort_multiple::{args_validate, argsort_multiple_imp
use crate::prelude::*;
use crate::utils::{CustomIterTools, NoNull};

/// Compares two values for sorting, placing float NaNs after every other value.
///
/// For float types a NaN compares `Greater` than any non-NaN value and `Equal`
/// to another NaN. For all remaining cases the result of `partial_cmp` is used.
#[inline]
fn sort_cmp<T: PartialOrd + IsFloat>(a: &T, b: &T) -> Ordering {
    if T::is_float() {
        let lhs_nan = a.is_nan();
        let rhs_nan = b.is_nan();
        if lhs_nan || rhs_nan {
            // `false < true`, so comparing the flags yields exactly
            // "the NaN side is the greater one" (and `Equal` for two NaNs).
            return lhs_nan.cmp(&rhs_nan);
        }
    }
    // SAFETY: either `T` is not a float type, or both operands were checked to be
    // non-NaN above; the code relies on `partial_cmp` returning `Some` in both cases.
    unsafe { a.partial_cmp(b).unwrap_unchecked() }
}

/// Reverse sorting when there are no nulls
fn order_reverse<T: Ord>(a: &T, b: &T) -> Ordering {
b.cmp(a)
Expand Down
1 change: 1 addition & 0 deletions polars/polars-lazy/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ arg_where = []
search_sorted = ["polars-ops/search_sorted"]
meta = []
pivot = ["polars-core/rows"]
top_k = ["polars-ops/top_k"]

# no guarantees whatsoever
private = ["polars-time/private"]
Expand Down
11 changes: 11 additions & 0 deletions polars/polars-lazy/src/dsl/function_expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,11 @@ pub enum FunctionExpr {
ListExpr(ListFunction),
#[cfg(feature = "dtype-struct")]
StructExpr(StructFunction),
#[cfg(feature = "top_k")]
TopK {
k: usize,
reverse: bool,
},
}

#[cfg(feature = "trigonometry")]
Expand Down Expand Up @@ -258,6 +263,8 @@ impl FunctionExpr {
}
}
}
#[cfg(feature = "top_k")]
TopK { .. } => same_type(),
}
}
}
Expand Down Expand Up @@ -409,6 +416,10 @@ impl From<FunctionExpr> for SpecialEq<Arc<dyn SeriesUdf>> {
FieldByName(name) => map_with_args!(struct_::get_by_name, name.clone()),
}
}
#[cfg(feature = "top_k")]
TopK { k, reverse } => {
map_with_args!(top_k, k, reverse)
}
}
}
}
Expand Down
8 changes: 8 additions & 0 deletions polars/polars-lazy/src/dsl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -579,6 +579,14 @@ impl Expr {
}
}

/// Select the `k` largest elements of this expression.
///
/// When `reverse` is `true` the `k` smallest elements are selected instead.
///
/// This has time complexity `O(n + k log(n))`.
#[cfg(feature = "top_k")]
pub fn top_k(self, k: usize, reverse: bool) -> Self {
    let function = FunctionExpr::TopK { k, reverse };
    self.apply_private(function, "top_k")
}

/// Reverse column
pub fn reverse(self) -> Self {
Expr::Reverse(Box::new(self))
Expand Down
1 change: 1 addition & 0 deletions polars/polars-ops/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,3 +35,4 @@ hash = []
rolling_window = ["polars-core/rolling_window"]
moment = ["polars-core/moment"]
search_sorted = []
top_k = []
4 changes: 4 additions & 0 deletions polars/polars-ops/src/chunked_array/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ mod set;
mod strings;
#[cfg(feature = "to_dummies")]
mod to_dummies;
#[cfg(feature = "top_k")]
mod top_k;

pub use list::*;
#[allow(unused_imports)]
Expand All @@ -11,6 +13,8 @@ pub use set::ChunkedSet;
pub use strings::*;
#[cfg(feature = "to_dummies")]
pub use to_dummies::*;
#[cfg(feature = "top_k")]
pub use top_k::*;

#[allow(unused_imports)]
use crate::prelude::*;
71 changes: 71 additions & 0 deletions polars/polars-ops/src/chunked_array/top_k.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
use std::cmp::Ordering;
use std::collections::binary_heap::BinaryHeap;

use polars_arrow::kernels::rolling::compare_fn_nan_max;
use polars_core::downcast_as_macro_arg_physical;
use polars_core::export::num::NumCast;
use polars_core::prelude::*;

/// Wrapper that gives a numeric value a total order so it can be stored in a
/// `BinaryHeap`, which requires `Ord`.
///
/// Every comparison is delegated to `compare_fn_nan_max`. `Ord::cmp` is the single
/// source of truth; `PartialOrd` and `PartialEq` are derived from it so the three
/// can never disagree and no `unsafe` is needed.
#[repr(transparent)]
struct Compare<T>(T);

impl<T: PartialOrd + IsFloat> PartialEq for Compare<T> {
    fn eq(&self, other: &Self) -> bool {
        matches!(self.cmp(other), Ordering::Equal)
    }
}

impl<T: PartialOrd + IsFloat> Eq for Compare<T> {}

impl<T: PartialOrd + IsFloat> PartialOrd for Compare<T> {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        // Canonical form: the order is total, so this is always `Some`.
        Some(self.cmp(other))
    }
}

impl<T: PartialOrd + IsFloat> Ord for Compare<T> {
    fn cmp(&self, other: &Self) -> Ordering {
        compare_fn_nan_max(&self.0, &other.0)
    }
}

fn top_k_impl<T>(ca: &ChunkedArray<T>, k: usize, mult_order: T::Native) -> Result<ChunkedArray<T>>
where
T: PolarsNumericType,
{
// mult_order should be -1 / +1 to determine the order of the heap

let mut heap = BinaryHeap::with_capacity(ca.len());

for arr in ca.downcast_iter() {
for v in arr {
heap.push(v.map(|v| Compare(*v * mult_order)));
}
}
let mut out: ChunkedArray<_> = (0..k)
.map(|_| {
heap.pop()
.unwrap()
.map(|compare_struct| compare_struct.0 * mult_order)
})
.collect();
out.rename(ca.name());
Ok(out)
}

pub fn top_k(s: &Series, k: usize, reverse: bool) -> Result<Series> {
let dtype = s.dtype();

let s = s.to_physical_repr();

macro_rules! dispatch {
($ca:expr) => {{
let mult_order = if reverse { -1 } else { 1 };
top_k_impl($ca, k, NumCast::from(mult_order).unwrap()).map(|ca| ca.into_series())
}};
}

downcast_as_macro_arg_physical!(&s, dispatch).and_then(|s| s.cast(dtype))
}
2 changes: 2 additions & 0 deletions py-polars/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ csv-file = ["polars/csv-file"]
object = ["polars/object"]
extract_jsonpath = ["polars/extract_jsonpath"]
pivot = ["polars/pivot"]
top_k = ["polars/top_k"]

all = [
"json",
Expand All @@ -80,6 +81,7 @@ all = [
"polars/timezones",
"object",
"pivot",
"top_k",
]

# we cannot conditionally activate simd
Expand Down
1 change: 1 addition & 0 deletions py-polars/docs/source/reference/expression.rst
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,7 @@ Manipulation/ selection
Expr.take
Expr.take_every
Expr.to_physical
Expr.top_k
Expr.upper_bound
Expr.where

Expand Down
1 change: 1 addition & 0 deletions py-polars/docs/source/reference/series.rst
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,7 @@ Manipulation/ selection
Series.take
Series.take_every
Series.to_dummies
Series.top_k
Series.view
Series.zip_with

Expand Down
20 changes: 20 additions & 0 deletions py-polars/polars/internals/expr/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -1762,6 +1762,26 @@ def sort(self, reverse: bool = False, nulls_last: bool = False) -> Expr:
"""
return wrap_expr(self._pyexpr.sort_with(reverse, nulls_last))

def top_k(self, k: int = 5, reverse: bool = False) -> Expr:
    r"""
    Return the `k` largest elements.

    If `reverse=True` the smallest elements will be given.

    This has time complexity:

    .. math:: O(n + k \log{}n - \frac{k}{2})

    Parameters
    ----------
    k
        Number of elements to return.
    reverse
        Return the smallest elements.

    """
    top_k_expr = self._pyexpr.top_k(k, reverse)
    return wrap_expr(top_k_expr)

def arg_sort(self, reverse: bool = False, nulls_last: bool = False) -> Expr:
"""
Get the index values that would sort this column.
Expand Down
19 changes: 19 additions & 0 deletions py-polars/polars/internals/series/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -1571,6 +1571,25 @@ def sort(self, reverse: bool = False, *, in_place: bool = False) -> Series | Non
else:
return wrap_s(self._s.sort(reverse))

def top_k(self, k: int = 5, reverse: bool = False) -> Series:
    r"""
    Return the `k` largest elements.

    If `reverse=True` the smallest elements will be given.

    This has time complexity:

    .. math:: O(n + k \log{}n - \frac{k}{2})

    Parameters
    ----------
    k
        Number of elements to return.
    reverse
        Return the smallest elements.

    """
    # NOTE(review): no statement follows the docstring in this view, so as written
    # the method returns None. Presumably the call is dispatched to the expression
    # of the same name by machinery outside this view -- confirm.

def arg_sort(self, reverse: bool = False, nulls_last: bool = False) -> Series:
"""
Get the index values that would sort this Series.
Expand Down
6 changes: 6 additions & 0 deletions py-polars/src/lazy/dsl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,12 @@ impl PyExpr {
})
.into()
}

// Binding for `top_k`: clones the wrapped expression, applies `top_k(k, reverse)`
// to the clone and converts the result back into a `PyExpr`.
#[cfg(feature = "top_k")]
pub fn top_k(&self, k: usize, reverse: bool) -> PyExpr {
    let expr = self.inner.clone();
    expr.top_k(k, reverse).into()
}

pub fn arg_max(&self) -> PyExpr {
self.clone().inner.arg_max().into()
}
Expand Down
7 changes: 7 additions & 0 deletions py-polars/tests/test_sort.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,3 +226,10 @@ def test_argsort_rank_nans() -> None:
)
.select(["rank", "argsort"])
).to_dict(False) == {"rank": [1.0, 2.0], "argsort": [0, 1]}


def test_top_k() -> None:
    """Check `Series.top_k` for the largest and, with `reverse`, the smallest values."""
    values = pl.Series([3, 1, 2, 5, 8])

    # Default order: the three largest values, largest first.
    largest_three = values.top_k(3).to_list()
    assert largest_three == [8, 5, 3]

    # Reversed order: the four smallest values, smallest first.
    smallest_four = values.top_k(4, reverse=True).to_list()
    assert smallest_four == [1, 2, 3, 5]

0 comments on commit dee5d5f

Please sign in to comment.