Skip to content

Commit

Permalink
sort_with dispatch
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Dec 2, 2021
1 parent af8e7a7 commit 762aad3
Show file tree
Hide file tree
Showing 18 changed files with 97 additions and 45 deletions.
6 changes: 3 additions & 3 deletions polars/polars-core/src/chunked_array/ops/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -496,10 +496,10 @@ pub trait ToDummies<T>: ChunkUnique<T> {
}
}

#[derive(Default)]
#[derive(Default, Copy, Clone, Eq, PartialEq, Debug)]
pub struct SortOptions {
descending: bool,
nulls_last: bool,
pub descending: bool,
pub nulls_last: bool,
}

/// Sort operations on `ChunkedArray`.
Expand Down
4 changes: 2 additions & 2 deletions polars/polars-core/src/series/implementations/boolean.rs
Original file line number Diff line number Diff line change
Expand Up @@ -306,8 +306,8 @@ impl SeriesTrait for SeriesWrap<BooleanChunked> {
self.0.get_any_value_unchecked(index)
}

fn sort(&self, reverse: bool) -> Series {
ChunkSort::sort(&self.0, reverse).into_series()
fn sort_with(&self, options: SortOptions) -> Series {
ChunkSort::sort_with(&self.0, options).into_series()
}

fn argsort(&self, reverse: bool) -> UInt32Chunked {
Expand Down
4 changes: 2 additions & 2 deletions polars/polars-core/src/series/implementations/categorical.rs
Original file line number Diff line number Diff line change
Expand Up @@ -286,8 +286,8 @@ impl SeriesTrait for SeriesWrap<CategoricalChunked> {
self.0.get_any_value_unchecked(index)
}

fn sort(&self, reverse: bool) -> Series {
ChunkSort::sort(&self.0, reverse).into_series()
fn sort_with(&self, options: SortOptions) -> Series {
ChunkSort::sort_with(&self.0, options).into_series()
}

fn argsort(&self, reverse: bool) -> UInt32Chunked {
Expand Down
4 changes: 2 additions & 2 deletions polars/polars-core/src/series/implementations/dates_time.rs
Original file line number Diff line number Diff line change
Expand Up @@ -467,8 +467,8 @@ macro_rules! impl_dyn_series {
self.0.get_any_value_unchecked(index).$into_logical()
}

fn sort(&self, reverse: bool) -> Series {
self.0.sort(reverse).$into_logical().into_series()
fn sort_with(&self, options: SortOptions) -> Series {
self.0.sort_with(options).$into_logical().into_series()
}

fn argsort(&self, reverse: bool) -> UInt32Chunked {
Expand Down
4 changes: 2 additions & 2 deletions polars/polars-core/src/series/implementations/floats.rs
Original file line number Diff line number Diff line change
Expand Up @@ -451,8 +451,8 @@ macro_rules! impl_dyn_series {
self.0.get_any_value_unchecked(index)
}

fn sort(&self, reverse: bool) -> Series {
ChunkSort::sort(&self.0, reverse).into_series()
fn sort_with(&self, options: SortOptions) -> Series {
ChunkSort::sort_with(&self.0, options).into_series()
}

fn argsort(&self, reverse: bool) -> UInt32Chunked {
Expand Down
4 changes: 2 additions & 2 deletions polars/polars-core/src/series/implementations/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -616,8 +616,8 @@ macro_rules! impl_dyn_series {
self.0.get_any_value_unchecked(index)
}

fn sort(&self, reverse: bool) -> Series {
ChunkSort::sort(&self.0, reverse).into_series()
fn sort_with(&self, options: SortOptions) -> Series {
ChunkSort::sort_with(&self.0, options).into_series()
}

fn argsort(&self, reverse: bool) -> UInt32Chunked {
Expand Down
4 changes: 2 additions & 2 deletions polars/polars-core/src/series/implementations/utf8.rs
Original file line number Diff line number Diff line change
Expand Up @@ -298,8 +298,8 @@ impl SeriesTrait for SeriesWrap<Utf8Chunked> {
self.0.get_any_value_unchecked(index)
}

fn sort(&self, reverse: bool) -> Series {
ChunkSort::sort(&self.0, reverse).into_series()
fn sort_with(&self, options: SortOptions) -> Series {
ChunkSort::sort_with(&self.0, options).into_series()
}

fn argsort(&self, reverse: bool) -> UInt32Chunked {
Expand Down
7 changes: 7 additions & 0 deletions polars/polars-core/src/series/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,13 @@ impl Series {
Ok(self)
}

pub fn sort(&self, reverse: bool) -> Self {
self.sort_with(SortOptions {
descending: reverse,
..Default::default()
})
}

/// Only implemented for numeric types
pub fn as_single_ptr(&mut self) -> Result<usize> {
self.get_inner_mut().as_single_ptr()
Expand Down
2 changes: 1 addition & 1 deletion polars/polars-core/src/series/series_trait.rs
Original file line number Diff line number Diff line change
Expand Up @@ -661,7 +661,7 @@ pub trait SeriesTrait:
invalid_operation_panic!(self)
}

fn sort(&self, _reverse: bool) -> Series {
fn sort_with(&self, _options: SortOptions) -> Series {
invalid_operation_panic!(self)
}

Expand Down
19 changes: 14 additions & 5 deletions polars/polars-lazy/src/dsl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ pub enum Expr {
},
Sort {
expr: Box<Expr>,
reverse: bool,
options: SortOptions,
},
Take {
expr: Box<Expr>,
Expand Down Expand Up @@ -387,7 +387,7 @@ impl fmt::Debug for Expr {
Not(expr) => write!(f, "NOT {:?}", expr),
IsNull(expr) => write!(f, "{:?} IS NULL", expr),
IsNotNull(expr) => write!(f, "{:?} IS NOT NULL", expr),
Sort { expr, reverse } => match reverse {
Sort { expr, options } => match options.descending {
true => write!(f, "{:?} DESC", expr),
false => write!(f, "{:?} ASC", expr),
},
Expand Down Expand Up @@ -840,12 +840,21 @@ impl Expr {
}

/// Sort in increasing order, or in decreasing order when `reverse` is `true`. See [the eager implementation](polars_core::series::Series::sort).
///
/// Can be used in `default` and `aggregation` context.
pub fn sort(self, reverse: bool) -> Self {
Expr::Sort {
expr: Box::new(self),
reverse,
options: SortOptions {
descending: reverse,
..Default::default()
},
}
}

/// Sort with given options.
pub fn sort_with(self, options: SortOptions) -> Self {
Expr::Sort {
expr: Box::new(self),
options,
}
}

Expand Down
4 changes: 2 additions & 2 deletions polars/polars-lazy/src/logical_plan/aexpr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ pub enum AExpr {
},
Sort {
expr: Node,
reverse: bool,
options: SortOptions,
},
Take {
expr: Node,
Expand Down Expand Up @@ -368,7 +368,7 @@ impl AExpr {
(Literal(left), Literal(right)) => left == right,
(BinaryExpr { op: l, .. }, BinaryExpr { op: r, .. }) => l == r,
(Cast { data_type: l, .. }, Cast { data_type: r, .. }) => l == r,
(Sort { reverse: l, .. }, Sort { reverse: r, .. }) => l == r,
(Sort { options: l, .. }, Sort { options: r, .. }) => l == r,
(SortBy { reverse: l, .. }, SortBy { reverse: r, .. }) => l == r,
(Shift { periods: l, .. }, Shift { periods: r, .. }) => l == r,
(
Expand Down
8 changes: 4 additions & 4 deletions polars/polars-lazy/src/logical_plan/conversion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,9 @@ pub(crate) fn to_aexpr(expr: Expr, arena: &mut Arena<AExpr>) -> Node {
expr: to_aexpr(*expr, arena),
idx: to_aexpr(*idx, arena),
},
Expr::Sort { expr, reverse } => AExpr::Sort {
Expr::Sort { expr, options } => AExpr::Sort {
expr: to_aexpr(*expr, arena),
reverse,
options,
},
Expr::SortBy { expr, by, reverse } => AExpr::SortBy {
expr: to_aexpr(*expr, arena),
Expand Down Expand Up @@ -450,11 +450,11 @@ pub(crate) fn node_to_exp(node: Node, expr_arena: &Arena<AExpr>) -> Expr {
strict,
}
}
AExpr::Sort { expr, reverse } => {
AExpr::Sort { expr, options } => {
let exp = node_to_exp(expr, expr_arena);
Expr::Sort {
expr: Box::new(exp),
reverse,
options,
}
}
AExpr::Take { expr, idx } => {
Expand Down
12 changes: 8 additions & 4 deletions polars/polars-lazy/src/logical_plan/optimizer/simplify_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -345,10 +345,14 @@ impl OptimizationRule for SimplifyExprRule {
AExpr::Reverse(expr) => {
let input = expr_arena.get(*expr);
match input {
AExpr::Sort { expr, reverse } => Some(AExpr::Sort {
expr: *expr,
reverse: !*reverse,
}),
AExpr::Sort { expr, options } => {
let mut options = *options;
options.descending = !options.descending;
Some(AExpr::Sort {
expr: *expr,
options,
})
}
AExpr::SortBy { expr, by, reverse } => Some(AExpr::SortBy {
expr: *expr,
by: by.clone(),
Expand Down
12 changes: 6 additions & 6 deletions polars/polars-lazy/src/physical_plan/expressions/sort.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,15 @@ use std::sync::Arc;

pub struct SortExpr {
pub(crate) physical_expr: Arc<dyn PhysicalExpr>,
pub(crate) reverse: bool,
pub(crate) options: SortOptions,
expr: Expr,
}

impl SortExpr {
pub fn new(physical_expr: Arc<dyn PhysicalExpr>, reverse: bool, expr: Expr) -> Self {
pub fn new(physical_expr: Arc<dyn PhysicalExpr>, options: SortOptions, expr: Expr) -> Self {
Self {
physical_expr,
reverse,
options,
expr,
}
}
Expand All @@ -27,7 +27,7 @@ impl PhysicalExpr for SortExpr {

fn evaluate(&self, df: &DataFrame, state: &ExecutionState) -> Result<Series> {
let series = self.physical_expr.evaluate(df, state)?;
Ok(series.sort(self.reverse))
Ok(series.sort_with(self.options))
}

#[allow(clippy::ptr_arg)]
Expand All @@ -49,7 +49,7 @@ impl PhysicalExpr for SortExpr {
let group =
unsafe { series.take_iter_unchecked(&mut idx.iter().map(|i| *i as usize)) };

let sorted_idx = group.argsort(self.reverse);
let sorted_idx = group.argsort(self.options.descending);

let new_idx: Vec<_> = sorted_idx
.cont_slice()
Expand Down Expand Up @@ -90,7 +90,7 @@ impl PhysicalAggregation for SortExpr {
let agg_s = agg_s
.list()
.unwrap()
.apply_amortized(|s| s.as_ref().sort(self.reverse))
.apply_amortized(|s| s.as_ref().sort_with(self.options))
.into_series();
Ok(Some(agg_s))
}
Expand Down
4 changes: 2 additions & 2 deletions polars/polars-lazy/src/physical_plan/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -531,11 +531,11 @@ impl DefaultPlanner {
column,
node_to_exp(expression, expr_arena),
))),
Sort { expr, reverse } => {
Sort { expr, options } => {
let phys_expr = self.create_physical_expr(expr, ctxt, expr_arena)?;
Ok(Arc::new(SortExpr::new(
phys_expr,
reverse,
options,
node_to_exp(expression, expr_arena),
)))
}
Expand Down
20 changes: 17 additions & 3 deletions py-polars/polars/internals/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -594,6 +594,11 @@ def cumsum(self, reverse: bool = False) -> "Expr":
----------
reverse
Reverse the operation.
Notes
-----
Dtypes in {Int8, UInt8, Int16, UInt16} are cast to
Int64 before summing to prevent overflow issues.
"""
return wrap_expr(self._pyexpr.cumsum(reverse))

Expand All @@ -605,6 +610,11 @@ def cumprod(self, reverse: bool = False) -> "Expr":
----------
reverse
Reverse the operation.
Notes
-----
Dtypes in {Int8, UInt8, Int16, UInt16} are cast to
Int64 before multiplying to prevent overflow issues.
"""
return wrap_expr(self._pyexpr.cumprod(reverse))

Expand Down Expand Up @@ -693,7 +703,7 @@ def cast(self, dtype: Type[Any], strict: bool = True) -> "Expr":
dtype = py_type_to_dtype(dtype)
return wrap_expr(self._pyexpr.cast(dtype, strict))

def sort(self, reverse: bool = False) -> "Expr":
def sort(self, reverse: bool = False, nulls_last: bool = False) -> "Expr":
"""
Sort this column. In projection/selection context the whole column is sorted.
If used in a groupby context, the groups are sorted.
Expand All @@ -703,8 +713,10 @@ def sort(self, reverse: bool = False) -> "Expr":
reverse
False -> order from small to large.
True -> order from large to small.
nulls_last
If True, nulls are considered to be larger than any valid value.
"""
return wrap_expr(self._pyexpr.sort(reverse))
return wrap_expr(self._pyexpr.sort_with(reverse, nulls_last))

def arg_sort(self, reverse: bool = False) -> "Expr":
"""
Expand Down Expand Up @@ -894,7 +906,9 @@ def sum(self) -> "Expr":
"""
Get sum value.
Note that dtypes in {Int8, UInt8, Int16, UInt16} are cast to
Notes
-----
Dtypes in {Int8, UInt8, Int16, UInt16} are cast to
Int64 before summing to prevent overflow issues.
"""
return wrap_expr(self._pyexpr.sum())
Expand Down
14 changes: 13 additions & 1 deletion py-polars/polars/internals/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -700,7 +700,9 @@ def sum(self) -> Union[int, float]:
"""
Reduce this Series to the sum value.
Note that dtypes in {Int8, UInt8, Int16, UInt16} are cast to
Notes
-----
Dtypes in {Int8, UInt8, Int16, UInt16} are cast to
Int64 before summing to prevent overflow issues.
Examples
Expand Down Expand Up @@ -944,6 +946,11 @@ def cumsum(self, reverse: bool = False) -> "Series":
reverse
reverse the operation.
Notes
-----
Dtypes in {Int8, UInt8, Int16, UInt16} are cast to
Int64 before summing to prevent overflow issues.
Examples
--------
>>> s = pl.Series("a", [1, 2, 3])
Expand Down Expand Up @@ -1016,6 +1023,11 @@ def cumprod(self, reverse: bool = False) -> "Series":
reverse
reverse the operation.
Notes
-----
Dtypes in {Int8, UInt8, Int16, UInt16} are cast to
Int64 before multiplying to prevent overflow issues.
Examples
--------
>>> s = pl.Series("a", [1, 2, 3])
Expand Down
10 changes: 8 additions & 2 deletions py-polars/src/lazy/dsl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -157,8 +157,14 @@ impl PyExpr {
};
expr.into()
}
pub fn sort(&self, reverse: bool) -> PyExpr {
self.clone().inner.sort(reverse).into()
pub fn sort_with(&self, descending: bool, nulls_last: bool) -> PyExpr {
self.clone()
.inner
.sort_with(SortOptions {
descending,
nulls_last,
})
.into()
}

pub fn arg_sort(&self, reverse: bool) -> PyExpr {
Expand Down

0 comments on commit 762aad3

Please sign in to comment.