Skip to content

Commit

Permalink
Fix argsort (#2946)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Mar 22, 2022
1 parent c71b56e commit bcb220f
Show file tree
Hide file tree
Showing 9 changed files with 54 additions and 41 deletions.
8 changes: 4 additions & 4 deletions polars/polars-lazy/src/dsl/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -167,11 +167,11 @@ pub fn argsort_by<E: AsRef<[Expr]>>(by: E, reverse: &[bool]) -> Expr {
Expr::Function {
input: by.as_ref().to_vec(),
function,
output_type: GetOutput::from_type(DataType::UInt32),
output_type: GetOutput::from_type(IDX_DTYPE),
options: FunctionOptions {
collect_groups: ApplyOptions::ApplyFlat,
input_wildcard_expansion: false,
auto_explode: true,
collect_groups: ApplyOptions::ApplyGroups,
input_wildcard_expansion: true,
auto_explode: false,
fmt_str: "argsort_by",
},
}
Expand Down
2 changes: 1 addition & 1 deletion polars/polars-lazy/src/dsl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -842,7 +842,7 @@ impl Expr {
);
let options = FunctionOptions {
collect_groups: ApplyOptions::ApplyGroups,
input_wildcard_expansion: false,
input_wildcard_expansion: true,
auto_explode: false,
fmt_str: "arg_sort",
};
Expand Down
3 changes: 3 additions & 0 deletions polars/polars-lazy/src/logical_plan/options.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,10 +63,13 @@ pub struct DistinctOptions {
/// Selects whether, and how, groups are collected before a function is applied.
pub enum ApplyOptions {
/// Collect groups to a list and apply the function over the groups.
/// This can be important in aggregation context.
///
/// e.g. `[g1, g1, g2] -> [[g1, g1], g2]`
ApplyGroups,
/// Collect groups to a list and then apply the function.
///
/// e.g. `[g1, g1, g2] -> list([g1, g1, g2])`
ApplyList,
/// Do not collect before applying the function.
///
/// e.g. `[g1, g1, g2] -> [g1, g1, g2]`
ApplyFlat,
}

Expand Down
2 changes: 0 additions & 2 deletions polars/polars-time/src/windows/groupby.rs
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,6 @@ pub fn groupby_windows(
}

let first = i as IdxSize;
dbg!(start_offset);
dbg!(first);

while i < time.len() {
let t = time[i];
Expand Down
3 changes: 2 additions & 1 deletion py-polars/docs/source/reference/expression.rst
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,7 @@ Manipulation/ selection
Expr.cast
Expr.sort
Expr.arg_sort
Expr.argsort
Expr.sort_by
Expr.take
Expr.shift
Expand All @@ -195,7 +196,7 @@ Manipulation/ selection
Expr.drop_nulls
Expr.drop_nans
Expr.interpolate
Expr.argsort
Expr.arg_sort
Expr.clip
Expr.lower_bound
Expr.upper_bound
Expand Down
54 changes: 24 additions & 30 deletions py-polars/polars/internals/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -970,6 +970,28 @@ def arg_sort(self, reverse: bool = False) -> "Expr":
-------
out
Series of type UInt32
Examples
--------
>>> df = pl.DataFrame(
... {
... "a": [20, 10, 30],
... }
... )
>>> df.select(pl.col("a").arg_sort())
shape: (3, 1)
┌─────┐
│ a │
│ --- │
│ u32 │
╞═════╡
│ 1 │
├╌╌╌╌╌┤
│ 0 │
├╌╌╌╌╌┤
│ 2 │
└─────┘
"""
return wrap_expr(self._pyexpr.arg_sort(reverse))

Expand Down Expand Up @@ -2211,37 +2233,9 @@ def abs(self) -> "Expr":

def argsort(self, reverse: bool = False) -> "Expr":
"""
Index location of the sorted variant of this Series.
Parameters
----------
reverse
Reverse the ordering. Default is from low to high.
Examples
--------
>>> df = pl.DataFrame(
... {
... "a": [20, 10, 30],
... }
... )
>>> df.select(pl.col("a").arg_sort())
shape: (3, 1)
┌─────┐
│ a │
│ --- │
│ u32 │
╞═════╡
│ 1 │
├╌╌╌╌╌┤
│ 0 │
├╌╌╌╌╌┤
│ 2 │
└─────┘
alias for `arg_sort`
"""
return pli.argsort_by([self], [reverse])
return self.arg_sort(reverse)

def rank(self, method: str = "average", reverse: bool = False) -> "Expr":
"""
Expand Down
5 changes: 4 additions & 1 deletion py-polars/polars/internals/lazy_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -966,7 +966,8 @@ def arange(


def argsort_by(
exprs: List[Union["pli.Expr", str]], reverse: Union[List[bool], bool] = False
exprs: Union[Union["pli.Expr", str], Sequence[Union["pli.Expr", str]]],
reverse: Union[List[bool], bool] = False,
) -> "pli.Expr":
"""
Find the indexes that would sort the columns.
Expand All @@ -982,6 +983,8 @@ def argsort_by(
reverse
Default is ascending.
"""
if not isinstance(exprs, list):
exprs = [exprs] # type: ignore
if not isinstance(reverse, list):
reverse = [reverse] * len(exprs)
exprs = pli.selection_to_pyexpr_list(exprs)
Expand Down
4 changes: 2 additions & 2 deletions py-polars/tests/test_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -1032,10 +1032,10 @@ def test_argsort() -> None:
s = pl.Series("a", [5, 3, 4, 1, 2])
expected = pl.Series("a", [3, 4, 1, 2, 0], dtype=UInt32)

verify_series_and_expr_api(s, expected, "argsort")
assert s.argsort().series_equal(expected)

expected_reverse = pl.Series("a", [0, 2, 1, 4, 3], dtype=UInt32)
verify_series_and_expr_api(s, expected_reverse, "argsort", True)
assert s.argsort(True).series_equal(expected_reverse)


def test_arg_min_and_arg_max() -> None:
Expand Down
14 changes: 14 additions & 0 deletions py-polars/tests/test_sort.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,3 +95,17 @@ def test_argsort_nulls() -> None:
]
with pytest.raises(ValueError):
a.to_frame().sort(by=["a", "b"], nulls_last=True)


def test_argsort_window_functions() -> None:
    """`arg_sort` and `argsort_by` must give the same result when windowed over "Id"."""
    frame = pl.DataFrame({"Id": [1, 1, 2, 2, 3, 3], "Age": [1, 2, 3, 4, 5, 6]})
    arg_sort_expr = pl.col("Age").arg_sort().over("Id").alias("arg_sort")
    argsort_by_expr = pl.argsort_by("Age").over("Id").alias("argsort_by")
    result = frame.select([arg_sort_expr, argsort_by_expr])

    # "Age" is already ascending inside every "Id" group, so the expected
    # indices restart at 0 for each group.
    via_arg_sort = result["arg_sort"].to_list()
    via_argsort_by = result["argsort_by"].to_list()
    assert via_arg_sort == via_argsort_by
    assert via_argsort_by == [0, 1, 0, 1, 0, 1]

0 comments on commit bcb220f

Please sign in to comment.