Skip to content

Commit

Permalink
feat[python,rust]: Add nulls_last option to Expr.arg_sort (#4600)
Browse files Browse the repository at this point in the history
  • Loading branch information
stinodego committed Aug 28, 2022
1 parent bd91e1d commit 579d737
Show file tree
Hide file tree
Showing 8 changed files with 100 additions and 43 deletions.
10 changes: 2 additions & 8 deletions polars/polars-lazy/src/dsl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -503,21 +503,15 @@ impl Expr {
}

/// Get the index values that would sort this expression.
pub fn arg_sort(self, reverse: bool) -> Self {
pub fn arg_sort(self, sort_options: SortOptions) -> Self {
let options = FunctionOptions {
collect_groups: ApplyOptions::ApplyGroups,
fmt_str: "arg_sort",
..Default::default()
};

self.function_with_options(
move |s: Series| {
Ok(s.argsort(SortOptions {
descending: reverse,
..Default::default()
})
.into_series())
},
move |s: Series| Ok(s.argsort(sort_options).into_series()),
GetOutput::from_type(IDX_DTYPE),
options,
)
Expand Down
40 changes: 35 additions & 5 deletions polars/polars-lazy/src/tests/aggregations.rs
Original file line number Diff line number Diff line change
Expand Up @@ -467,7 +467,14 @@ fn take_aggregations() -> Result<()> {
.agg([
// keep the head as it test slice correctness
col("book")
.take(col("count").arg_sort(true).head(Some(2)))
.take(
col("count")
.arg_sort(SortOptions {
descending: true,
nulls_last: false,
})
.head(Some(2)),
)
.alias("ordered"),
])
.sort("user", Default::default())
Expand Down Expand Up @@ -498,7 +505,12 @@ fn test_take_consistency() -> Result<()> {
let out = df
.clone()
.lazy()
.select([col("A").arg_sort(true).take(lit(0))])
.select([col("A")
.arg_sort(SortOptions {
descending: true,
nulls_last: false,
})
.take(lit(0))])
.collect()?;

let a = out.column("A")?;
Expand All @@ -509,7 +521,12 @@ fn test_take_consistency() -> Result<()> {
.clone()
.lazy()
.groupby_stable([col("cars")])
.agg([col("A").arg_sort(true).take(lit(0))])
.agg([col("A")
.arg_sort(SortOptions {
descending: true,
nulls_last: false,
})
.take(lit(0))])
.collect()?;

let out = out.column("A")?;
Expand All @@ -522,9 +539,22 @@ fn test_take_consistency() -> Result<()> {
.groupby_stable([col("cars")])
.agg([
col("A"),
col("A").arg_sort(true).take(lit(0)).alias("1"),
col("A")
.take(col("A").arg_sort(true).take(lit(0)))
.arg_sort(SortOptions {
descending: true,
nulls_last: false,
})
.take(lit(0))
.alias("1"),
col("A")
.take(
col("A")
.arg_sort(SortOptions {
descending: true,
nulls_last: false,
})
.take(lit(0)),
)
.alias("2"),
])
.collect()?;
Expand Down
9 changes: 8 additions & 1 deletion polars/polars-lazy/src/tests/queries.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1831,7 +1831,14 @@ fn test_single_group_result() -> Result<()> {

let out = df
.lazy()
.select([col("a").arg_sort(false).list().over([col("a")]).flatten()])
.select([col("a")
.arg_sort(SortOptions {
descending: false,
nulls_last: false,
})
.list()
.over([col("a")])
.flatten()])
.collect()?;

let a = out.column("a")?.idx()?;
Expand Down
1 change: 1 addition & 0 deletions py-polars/docs/source/reference/series.rst
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,7 @@ Manipulation/ selection

Series.alias
Series.append
Series.arg_sort
Series.argsort
Series.cast
Series.ceil
Expand Down
31 changes: 23 additions & 8 deletions py-polars/polars/internals/expr/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -1761,20 +1761,21 @@ def sort(self, reverse: bool = False, nulls_last: bool = False) -> Expr:
"""
return wrap_expr(self._pyexpr.sort_with(reverse, nulls_last))

def arg_sort(self, reverse: bool = False) -> Expr:
def arg_sort(self, reverse: bool = False, nulls_last: bool = False) -> Expr:
"""
Get the index values that would sort this column.
Parameters
----------
reverse
False -> order from small to large.
True -> order from large to small.
Sort in reverse (descending) order.
nulls_last
Place null values last instead of first.
Returns
-------
out
Series of type UInt32
Expr
Series of dtype UInt32.
Examples
--------
Expand All @@ -1798,7 +1799,7 @@ def arg_sort(self, reverse: bool = False) -> Expr:
└─────┘
"""
return wrap_expr(self._pyexpr.arg_sort(reverse))
return wrap_expr(self._pyexpr.arg_sort(reverse, nulls_last))

def arg_max(self) -> Expr:
"""
Expand Down Expand Up @@ -4553,10 +4554,24 @@ def abs(self) -> Expr:
"""
return wrap_expr(self._pyexpr.abs())

def argsort(self, reverse: bool = False) -> Expr:
def argsort(self, reverse: bool = False, nulls_last: bool = False) -> Expr:
"""
Get the index values that would sort this column.
Alias for `arg_sort`.
Parameters
----------
reverse
Sort in reverse (descending) order.
nulls_last
Place null values last instead of first.
Returns
-------
Expr
Series of dtype UInt32.
Examples
--------
>>> df = pl.DataFrame(
Expand All @@ -4579,7 +4594,7 @@ def argsort(self, reverse: bool = False) -> Expr:
└─────┘
"""
return self.arg_sort(reverse)
return self.arg_sort(reverse, nulls_last)

def rank(self, method: RankMethod = "average", reverse: bool = False) -> Expr:
"""
Expand Down
32 changes: 23 additions & 9 deletions py-polars/polars/internals/series/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -1579,21 +1579,21 @@ def sort(self, reverse: bool = False, *, in_place: bool = False) -> Series | Non
else:
return wrap_s(self._s.sort(reverse))

def argsort(self, reverse: bool = False, nulls_last: bool = False) -> Series:
def arg_sort(self, reverse: bool = False, nulls_last: bool = False) -> Series:
"""
Index location of the sorted variant of this Series.
Get the index values that would sort this Series.
Returns
-------
indexes
Indexes that can be used to sort this array.
Parameters
----------
reverse
Sort in reverse (descending) order.
nulls_last
Place null values last.
Place null values last instead of first.
Examples
--------
>>> s = pl.Series("a", [5, 3, 4, 1, 2])
>>> s.argsort()
>>> s.arg_sort()
shape: (5,)
Series: 'a' [u32]
[
Expand All @@ -1605,7 +1605,21 @@ def argsort(self, reverse: bool = False, nulls_last: bool = False) -> Series:
]
"""
return wrap_s(self._s.argsort(reverse, nulls_last))

def argsort(self, reverse: bool = False, nulls_last: bool = False) -> Series:
"""
Get the index values that would sort this Series.
Alias for :func:`arg_sort`.
Parameters
----------
reverse
Sort in reverse (descending) order.
nulls_last
Place null values last instead of first.
"""

def arg_unique(self) -> Series:
"""Get unique index as Series."""
Expand Down
10 changes: 8 additions & 2 deletions py-polars/src/lazy/dsl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -210,8 +210,14 @@ impl PyExpr {
.into()
}

pub fn arg_sort(&self, reverse: bool) -> PyExpr {
self.clone().inner.arg_sort(reverse).into()
pub fn arg_sort(&self, reverse: bool, nulls_last: bool) -> PyExpr {
self.clone()
.inner
.arg_sort(SortOptions {
descending: reverse,
nulls_last,
})
.into()
}
pub fn arg_max(&self) -> PyExpr {
self.clone().inner.arg_max().into()
Expand Down
10 changes: 0 additions & 10 deletions py-polars/src/series.rs
Original file line number Diff line number Diff line change
Expand Up @@ -540,16 +540,6 @@ impl PySeries {
PySeries::new(self.series.sort(reverse))
}

pub fn argsort(&self, reverse: bool, nulls_last: bool) -> Self {
self.series
.argsort(SortOptions {
descending: reverse,
nulls_last,
})
.into_series()
.into()
}

pub fn value_counts(&self, sorted: bool) -> PyResult<PyDataFrame> {
let df = self
.series
Expand Down

0 comments on commit 579d737

Please sign in to comment.