Skip to content

Commit

Permalink
fix(rust, python): fix explicit list + sort aggregation in groupby co… (
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Oct 24, 2022
1 parent fba2697 commit 7e0e400
Show file tree
Hide file tree
Showing 18 changed files with 65 additions and 63 deletions.
3 changes: 0 additions & 3 deletions polars/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,6 @@ lazy_regex = ["polars-lazy/regex"]
cum_agg = ["polars-core/cum_agg", "polars-core/cum_agg"]
rolling_window = ["polars-core/rolling_window", "polars-lazy/rolling_window", "polars-time/rolling_window"]
interpolate = ["polars-core/interpolate", "polars-lazy/interpolate"]
list = ["polars-lazy/list", "polars-ops/list"]
rank = ["polars-core/rank", "polars-lazy/rank"]
diff = ["polars-core/diff", "polars-lazy/diff", "polars-ops/diff"]
pct_change = ["polars-core/pct_change", "polars-lazy/pct_change"]
Expand Down Expand Up @@ -136,7 +135,6 @@ test = [
"private",
"rolling_window",
"rank",
"list",
"round_series",
"csv-file",
"dtype-categorical",
Expand Down Expand Up @@ -254,7 +252,6 @@ docs-selection = [
"interpolate",
"diff",
"rank",
"list",
"arange",
"diagonal_concat",
"horizontal_concat",
Expand Down
4 changes: 1 addition & 3 deletions polars/polars-lazy/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ timezones = ["polars-plan/timezones"]
true_div = ["polars-plan/true_div"]

# operations
is_in = ["polars-plan/is_in", "list"]
is_in = ["polars-plan/is_in"]
repeat_by = ["polars-plan/repeat_by"]
round_series = ["polars-plan/round_series"]
is_first = ["polars-plan/is_first"]
Expand All @@ -77,7 +77,6 @@ rank = ["polars-plan/rank"]
diff = ["polars-plan/diff", "polars-plan/diff"]
pct_change = ["polars-plan/pct_change"]
moment = ["polars-plan/moment"]
list = ["polars-plan/list"]
abs = ["polars-plan/abs"]
random = ["polars-plan/random"]
dynamic_groupby = ["polars-plan/dynamic_groupby", "polars-time", "temporal"]
Expand Down Expand Up @@ -116,7 +115,6 @@ test = [
"private",
"rolling_window",
"rank",
"list",
"round_series",
"csv-file",
"dtype-categorical",
Expand Down
1 change: 0 additions & 1 deletion polars/polars-lazy/polars-plan/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,6 @@ rank = ["polars-core/rank"]
diff = ["polars-core/diff", "polars-ops/diff"]
pct_change = ["polars-core/pct_change"]
moment = ["polars-core/moment"]
list = ["polars-ops/list"]
abs = ["polars-core/abs"]
random = ["polars-core/random"]
dynamic_groupby = ["polars-core/dynamic_groupby"]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ impl Display for ListFunction {

let name = match self {
Concat => "concat",
#[cfg(feature = "is_in")]
Contains => "contains",
Slice => "slice",
};
Expand Down
5 changes: 0 additions & 5 deletions polars/polars-lazy/polars-plan/src/dsl/function_expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ mod dispatch;
mod fill_null;
#[cfg(feature = "is_in")]
mod is_in;
#[cfg(any(feature = "is_in", feature = "list"))]
mod list;
mod nan;
mod pow;
Expand All @@ -33,7 +32,6 @@ mod trigonometry;

use std::fmt::{Display, Formatter};

#[cfg(feature = "list")]
pub(super) use list::ListFunction;
use polars_core::prelude::*;
#[cfg(feature = "serde")]
Expand Down Expand Up @@ -91,7 +89,6 @@ pub enum FunctionExpr {
min: Option<AnyValue<'static>>,
max: Option<AnyValue<'static>>,
},
#[cfg(feature = "list")]
ListExpr(ListFunction),
#[cfg(feature = "dtype-struct")]
StructExpr(StructFunction),
Expand Down Expand Up @@ -147,7 +144,6 @@ impl Display for FunctionExpr {
(Some(_), None) => "clip_min",
_ => unreachable!(),
},
#[cfg(feature = "list")]
ListExpr(func) => return write!(f, "{}", func),
#[cfg(feature = "dtype-struct")]
StructExpr(func) => return write!(f, "{}", func),
Expand Down Expand Up @@ -301,7 +297,6 @@ impl From<FunctionExpr> for SpecialEq<Arc<dyn SeriesUdf>> {
Clip { min, max } => {
map_owned!(clip::clip, min.clone(), max.clone())
}
#[cfg(feature = "list")]
ListExpr(lf) => {
use ListFunction::*;
match lf {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ impl FunctionExpr {
};

// map all dtypes
#[cfg(feature = "list")]
let map_dtypes = |func: &dyn Fn(&[&DataType]) -> DataType| {
let mut fld = fields[0].clone();
let dtypes = fields.iter().map(|fld| fld.data_type()).collect::<Vec<_>>();
Expand Down Expand Up @@ -58,7 +57,6 @@ impl FunctionExpr {
};

// inner super type of lists
#[cfg(feature = "list")]
let inner_super_type_list = || {
map_dtypes(&|dts| {
let mut super_type_inner = None;
Expand Down Expand Up @@ -157,7 +155,6 @@ impl FunctionExpr {
Nan(n) => n.get_field(fields),
#[cfg(feature = "round_series")]
Clip { .. } => same_type(),
#[cfg(feature = "list")]
ListExpr(l) => {
use ListFunction::*;
match l {
Expand Down
3 changes: 0 additions & 3 deletions polars/polars-lazy/polars-plan/src/dsl/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ use polars_core::utils::get_supertype;

#[cfg(feature = "arg_where")]
use crate::dsl::function_expr::FunctionExpr;
#[cfg(feature = "list")]
use crate::dsl::function_expr::ListFunction;
#[cfg(feature = "strings")]
use crate::dsl::function_expr::StringFunction;
Expand Down Expand Up @@ -300,8 +299,6 @@ pub fn format_str<E: AsRef<[Expr]>>(format: &str, args: E) -> PolarsResult<Expr>
}

/// Concat lists entries.
#[cfg(feature = "list")]
#[cfg_attr(docsrs, doc(cfg(feature = "list")))]
pub fn concat_lst<E: AsRef<[IE]>, IE: Into<Expr> + Clone>(s: E) -> Expr {
let s = s.as_ref().iter().map(|e| e.clone().into()).collect();

Expand Down
5 changes: 3 additions & 2 deletions polars/polars-lazy/polars-plan/src/dsl/list.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use polars_core::prelude::*;
#[cfg(feature = "diff")]
use polars_core::series::ops::NullBehavior;
use polars_ops::prelude::*;

Expand Down Expand Up @@ -83,10 +84,10 @@ impl ListNameSpace {
}

/// Sort every sublist.
pub fn sort(self, reverse: bool) -> Expr {
pub fn sort(self, options: SortOptions) -> Expr {
self.0
.map(
move |s| Ok(s.list()?.lst_sort(reverse).into_series()),
move |s| Ok(s.list()?.lst_sort(options).into_series()),
GetOutput::same_type(),
)
.with_fmt("arr.sort")
Expand Down
3 changes: 0 additions & 3 deletions polars/polars-lazy/polars-plan/src/dsl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ mod from;
pub(crate) mod function_expr;
#[cfg(feature = "compile")]
pub mod functions;
#[cfg(feature = "list")]
mod list;
#[cfg(feature = "meta")]
mod meta;
Expand All @@ -28,7 +27,6 @@ use std::sync::Arc;
pub use expr::*;
pub use function_expr::*;
pub use functions::*;
#[cfg(feature = "list")]
pub use list::*;
pub use options::*;
use polars_arrow::prelude::QuantileInterpolOptions;
Expand Down Expand Up @@ -2253,7 +2251,6 @@ impl Expr {
pub fn dt(self) -> dt::DateLikeNameSpace {
dt::DateLikeNameSpace(self)
}
#[cfg(feature = "list")]
pub fn arr(self) -> list::ListNameSpace {
list::ListNameSpace(self)
}
Expand Down
5 changes: 5 additions & 0 deletions polars/polars-lazy/src/dsl/list.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
#[cfg(feature = "list_eval")]
use std::sync::Mutex;

#[cfg(feature = "list_eval")]
use polars_arrow::utils::CustomIterTools;
#[cfg(feature = "list_eval")]
use polars_core::prelude::*;
#[cfg(feature = "list_eval")]
use polars_plan::dsl::*;
#[cfg(feature = "list_eval")]
use rayon::prelude::*;

use crate::prelude::*;
Expand Down
2 changes: 0 additions & 2 deletions polars/polars-lazy/src/dsl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,13 @@
mod eval;
pub mod functions;
mod into;
#[cfg(feature = "list")]
mod list;

#[cfg(feature = "cumulative_eval")]
pub use eval::*;
pub use functions::*;
#[cfg(feature = "cumulative_eval")]
use into::IntoExpr;
#[cfg(feature = "list")]
pub use list::*;
pub use polars_plan::dsl::*;
pub use polars_plan::logical_plan::UdfSchema;
66 changes: 37 additions & 29 deletions polars/polars-lazy/src/physical_plan/expressions/sort.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,38 +65,46 @@ impl PhysicalExpr for SortExpr {
state: &ExecutionState,
) -> PolarsResult<AggregationContext<'a>> {
let mut ac = self.physical_expr.evaluate_on_groups(df, groups, state)?;
let series = ac.flat_naive().into_owned();
match ac.agg_state() {
AggState::AggregatedList(s) => {
let ca = s.list().unwrap();
let out = ca.lst_sort(self.options);
ac.with_series(out.into_series(), true);
}
_ => {
let series = ac.flat_naive().into_owned();

let groups = match ac.groups().as_ref() {
GroupsProxy::Idx(groups) => {
groups
.iter()
.map(|(first, idx)| {
// Safety:
// Group tuples are always in bounds
let group = unsafe {
series.take_iter_unchecked(&mut idx.iter().map(|i| *i as usize))
};
let groups = match ac.groups().as_ref() {
GroupsProxy::Idx(groups) => {
groups
.iter()
.map(|(first, idx)| {
// Safety:
// Group tuples are always in bounds
let group = unsafe {
series.take_iter_unchecked(&mut idx.iter().map(|i| *i as usize))
};

let sorted_idx = group.argsort(self.options);
let new_idx = map_sorted_indices_to_group_idx(&sorted_idx, idx);
(new_idx.first().copied().unwrap_or(first), new_idx)
})
.collect()
let sorted_idx = group.argsort(self.options);
let new_idx = map_sorted_indices_to_group_idx(&sorted_idx, idx);
(new_idx.first().copied().unwrap_or(first), new_idx)
})
.collect()
}
GroupsProxy::Slice { groups, .. } => groups
.iter()
.map(|&[first, len]| {
let group = series.slice(first as i64, len as usize);
let sorted_idx = group.argsort(self.options);
let new_idx = map_sorted_indices_to_group_slice(&sorted_idx, first);
(new_idx.first().copied().unwrap_or(first), new_idx)
})
.collect(),
};
let groups = GroupsProxy::Idx(groups);
ac.with_groups(groups);
}
GroupsProxy::Slice { groups, .. } => groups
.iter()
.map(|&[first, len]| {
let group = series.slice(first as i64, len as usize);
let sorted_idx = group.argsort(self.options);
let new_idx = map_sorted_indices_to_group_slice(&sorted_idx, first);
(new_idx.first().copied().unwrap_or(first), new_idx)
})
.collect(),
};
let groups = GroupsProxy::Idx(groups);

ac.with_groups(groups);
}

Ok(ac)
}
Expand Down
3 changes: 1 addition & 2 deletions polars/polars-ops/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,7 @@ propagate_nans = []

# ops
to_dummies = []
list_to_struct = ["polars-core/dtype-struct", "list"]
list = []
list_to_struct = ["polars-core/dtype-struct"]
diff = ["polars-core/diff"]
strings = ["polars-core/strings"]
string_justify = ["polars-core/strings"]
Expand Down
3 changes: 0 additions & 3 deletions polars/polars-ops/src/chunked_array/list/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,10 @@ use polars_core::prelude::*;

#[cfg(feature = "hash")]
pub(crate) mod hash;
#[cfg(feature = "list")]
#[cfg_attr(docsrs, doc(cfg(feature = "list")))]
mod namespace;
#[cfg(feature = "list_to_struct")]
mod to_struct;

#[cfg(feature = "list")]
pub use namespace::*;
#[cfg(feature = "list_to_struct")]
pub use to_struct::*;
Expand Down
5 changes: 3 additions & 2 deletions polars/polars-ops/src/chunked_array/list/namespace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use std::fmt::Write;
use polars_arrow::kernels::list::sublist_get;
use polars_arrow::prelude::ValueSize;
use polars_core::chunked_array::builder::get_list_builder;
#[cfg(feature = "diff")]
use polars_core::series::ops::NullBehavior;
use polars_core::utils::{try_get_supertype, CustomIterTools};

Expand Down Expand Up @@ -133,9 +134,9 @@ pub trait ListNameSpaceImpl: AsList {
}

#[must_use]
fn lst_sort(&self, reverse: bool) -> ListChunked {
fn lst_sort(&self, options: SortOptions) -> ListChunked {
let ca = self.as_list();
ca.apply_amortized(|s| s.as_ref().sort(reverse))
ca.apply_amortized(|s| s.as_ref().sort_with(options))
}

#[must_use]
Expand Down
1 change: 0 additions & 1 deletion py-polars/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,6 @@ features = [
"cum_agg",
"rolling_window",
"interpolate",
"list",
"rank",
"diff",
"moment",
Expand Down
5 changes: 4 additions & 1 deletion py-polars/src/lazy/dsl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1302,7 +1302,10 @@ impl PyExpr {
self.inner
.clone()
.arr()
.sort(reverse)
.sort(SortOptions {
descending: reverse,
..Default::default()
})
.with_fmt("arr.sort")
.into()
}
Expand Down
10 changes: 10 additions & 0 deletions py-polars/tests/unit/test_sort.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,3 +278,13 @@ def test_sort_slice_fast_path_5245() -> None:
assert df.sort("foo").limit(1).select("foo").collect().to_dict(False) == {
"foo": ["a"]
}


def test_explicit_list_agg_sort_in_groupby() -> None:
df = pl.DataFrame({"A": ["a", "a", "a", "b", "b", "a"], "B": [1, 2, 3, 4, 5, 6]})
assert (
df.groupby("A")
.agg(pl.col("B").list().sort(reverse=True))
.sort("A")
.frame_equal(df.groupby("A").agg(pl.col("B").sort(reverse=True)).sort("A"))
)

0 comments on commit 7e0e400

Please sign in to comment.