Skip to content

Commit

Permalink
improve Expr::apply
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Jan 9, 2022
1 parent 3909e78 commit 01469d4
Show file tree
Hide file tree
Showing 7 changed files with 114 additions and 129 deletions.
41 changes: 40 additions & 1 deletion polars/polars-lazy/src/dsl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,9 @@ pub struct FunctionOptions {
///
/// this also accounts for regex expansion
pub(crate) input_wildcard_expansion: bool,

/// automatically explode on unit length it ran as final aggregation.
pub(crate) auto_explode: bool,
}

#[derive(PartialEq, Clone)]
Expand Down Expand Up @@ -841,9 +844,16 @@ impl Expr {
!has_expr(&self, |e| matches!(e, Expr::Wildcard)),
"wildcard not supported in unique expr"
);
self.apply(
let options = FunctionOptions {
collect_groups: ApplyOptions::ApplyGroups,
input_wildcard_expansion: false,
auto_explode: false,
};

self.function_with_options(
move |s: Series| Ok(s.argsort(reverse).into_series()),
GetOutput::from_type(DataType::UInt32),
options,
)
}

Expand Down Expand Up @@ -920,6 +930,7 @@ impl Expr {
options: FunctionOptions {
collect_groups: ApplyOptions::ApplyFlat,
input_wildcard_expansion: false,
auto_explode: false,
},
}
}
Expand All @@ -941,6 +952,7 @@ impl Expr {
options: FunctionOptions {
collect_groups: ApplyOptions::ApplyFlat,
input_wildcard_expansion: false,
auto_explode: false,
},
}
}
Expand All @@ -965,10 +977,31 @@ impl Expr {
options: FunctionOptions {
collect_groups: ApplyOptions::ApplyList,
input_wildcard_expansion: false,
auto_explode: false,
},
}
}

/// A function that cannot be expressed with `map` or `apply` and requires extra settings.
pub fn function_with_options<F>(
self,
function: F,
output_type: GetOutput,
options: FunctionOptions,
) -> Self
where
F: Fn(Series) -> Result<Series> + 'static + Send + Sync,
{
let f = move |s: &mut [Series]| function(std::mem::take(&mut s[0]));

Expr::Function {
input: vec![self],
function: NoEq::new(Arc::new(f)),
output_type,
options,
}
}

/// Apply a function/closure over the groups. This should only be used in a groupby aggregation.
///
/// It is the responsibility of the caller that the schema is correct by giving
Expand All @@ -991,6 +1024,7 @@ impl Expr {
options: FunctionOptions {
collect_groups: ApplyOptions::ApplyGroups,
input_wildcard_expansion: false,
auto_explode: true,
},
}
}
Expand All @@ -1012,6 +1046,7 @@ impl Expr {
options: FunctionOptions {
collect_groups: ApplyOptions::ApplyGroups,
input_wildcard_expansion: false,
auto_explode: true,
},
}
}
Expand Down Expand Up @@ -2165,6 +2200,7 @@ where
options: FunctionOptions {
collect_groups: ApplyOptions::ApplyFlat,
input_wildcard_expansion: true,
auto_explode: true,
},
}
} else {
Expand Down Expand Up @@ -2336,6 +2372,7 @@ where
options: FunctionOptions {
collect_groups: ApplyOptions::ApplyFlat,
input_wildcard_expansion: false,
auto_explode: false,
},
}
}
Expand All @@ -2361,6 +2398,7 @@ where
options: FunctionOptions {
collect_groups: ApplyOptions::ApplyList,
input_wildcard_expansion: false,
auto_explode: true,
},
}
}
Expand Down Expand Up @@ -2388,6 +2426,7 @@ where
options: FunctionOptions {
collect_groups: ApplyOptions::ApplyGroups,
input_wildcard_expansion: false,
auto_explode: true,
},
}
}
Expand Down
4 changes: 4 additions & 0 deletions polars/polars-lazy/src/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ pub fn argsort_by<E: AsRef<[Expr]>>(by: E, reverse: &[bool]) -> Expr {
options: FunctionOptions {
collect_groups: ApplyOptions::ApplyFlat,
input_wildcard_expansion: false,
auto_explode: true,
},
}
}
Expand All @@ -114,6 +115,7 @@ pub fn concat_str(s: Vec<Expr>, sep: &str) -> Expr {
options: FunctionOptions {
collect_groups: ApplyOptions::ApplyFlat,
input_wildcard_expansion: true,
auto_explode: false,
},
}
}
Expand Down Expand Up @@ -142,6 +144,7 @@ pub fn concat_lst(s: Vec<Expr>) -> Expr {
options: FunctionOptions {
collect_groups: ApplyOptions::ApplyFlat,
input_wildcard_expansion: true,
auto_explode: true,
},
}
}
Expand Down Expand Up @@ -304,6 +307,7 @@ pub fn datetime(
options: FunctionOptions {
collect_groups: ApplyOptions::ApplyFlat,
input_wildcard_expansion: true,
auto_explode: false,
},
}
.alias("datetime")
Expand Down
150 changes: 24 additions & 126 deletions polars/polars-lazy/src/physical_plan/expressions/apply.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ pub struct ApplyExpr {
pub function: NoEq<Arc<dyn SeriesUdf>>,
pub expr: Expr,
pub collect_groups: ApplyOptions,
pub auto_explode: bool,
}

impl ApplyExpr {
Expand Down Expand Up @@ -64,6 +65,8 @@ impl PhysicalExpr for ApplyExpr {
let mut container = [Default::default()];
let name = ac.series().name().to_string();

let mut all_unit_len = true;

let mut ca: ListChunked = ac
.aggregated()
.list()
Expand All @@ -74,9 +77,14 @@ impl PhysicalExpr for ApplyExpr {
let in_len = s.len();
container[0] = s;
self.function.call_udf(&mut container).ok().map(|s| {
if s.len() != in_len {
let len = s.len();
if len != in_len {
update_group_tuples = true;
};
if len != 1 {
all_unit_len = false;
}

s
})
})
Expand All @@ -85,6 +93,7 @@ impl PhysicalExpr for ApplyExpr {

ca.rename(&name);
ac.with_series(ca.into_series(), true);
ac.with_all_unit_len(all_unit_len);
ac.with_update_groups(UpdateGroups::WithSeriesLen);
Ok(ac)
}
Expand Down Expand Up @@ -127,6 +136,7 @@ impl PhysicalExpr for ApplyExpr {

// length of the items to iterate over
let len = lists[0].len();
let mut all_unit_len = true;

let mut ca: ListChunked = (0..len)
.map(|_| {
Expand All @@ -137,12 +147,18 @@ impl PhysicalExpr for ApplyExpr {
Some(s) => container.push(s),
}
}
self.function.call_udf(&mut container).ok()
self.function.call_udf(&mut container).ok().map(|s| {
if s.len() != 1 {
all_unit_len = false;
}
s
})
})
.collect();
ca.rename(&name);
let mut ac = acs.pop().unwrap();
ac.with_series(ca.into_series(), true);
ac.with_all_unit_len(all_unit_len);
Ok(ac)
}
ApplyOptions::ApplyFlat => {
Expand All @@ -154,7 +170,7 @@ impl PhysicalExpr for ApplyExpr {
let s = self.function.call_udf(&mut s)?;
let mut ac = acs.pop().unwrap();
ac.with_update_groups(UpdateGroups::WithGroupsLen);
ac.with_series(s, true);
ac.with_series(s, false);
Ok(ac)
}
ApplyOptions::ApplyList => {
Expand Down Expand Up @@ -187,129 +203,11 @@ impl PhysicalAggregation for ApplyExpr {
groups: &GroupTuples,
state: &ExecutionState,
) -> Result<Option<Series>> {
if self.inputs.len() == 1 {
let mut ac = self.inputs[0].evaluate_on_groups(df, groups, state)?;

match self.collect_groups {
ApplyOptions::ApplyGroups => {
let mut container = [Default::default()];
let name = ac.series().name().to_string();

let mut ca: ListChunked = ac
.aggregated()
.list()
.unwrap()
.into_iter()
.map(|opt_s| {
opt_s.and_then(|s| {
container[0] = s;
self.function.call_udf(&mut container).ok()
})
})
.collect();
ca.rename(&name);
Ok(Some(ca.into_series()))
}
ApplyOptions::ApplyFlat => {
// the function needs to be called on a flat series
// but the series may be flat or aggregated
// if its flat, we just apply and return
// if not flat, the flattening sorts by group, so we must create new group tuples
// and again aggregate.
let out = self.function.call_udf(&mut [ac.flat_naive().into_owned()]);

if ac.is_not_aggregated() || !matches!(ac.series().dtype(), DataType::List(_)) {
out.map(Some)
} else {
// TODO! maybe just apply over list?
ac.with_update_groups(UpdateGroups::WithGroupsLen);
Ok(out?.agg_list(ac.groups()))
}
}
ApplyOptions::ApplyList => self
.function
.call_udf(&mut [ac.aggregated().into_owned()])
.map(Some),
}
} else {
let mut acs = self.prepare_multiple_inputs(df, groups, state)?;

match self.collect_groups {
ApplyOptions::ApplyGroups => {
let mut container = vec![Default::default(); acs.len()];
let name = acs[0].series().name().to_string();

// Don't ever try to be smart here.
// Every argument needs to be aggregated; period.
// We only work on groups in the groupby context.
// If the argument is a literal use `map`
let owned_series = acs.iter_mut().map(|ac| ac.aggregated()).collect::<Vec<_>>();

// now we make the iterators
let mut iters = owned_series
.iter()
.map(|s| {
let ca = s.list().unwrap();
ca.into_iter()
})
.collect::<Vec<_>>();

// length of the items to iterate over
let len = groups.len();

let mut ca: ListChunked = (0..len)
.map(|_| {
container.clear();
for iter in &mut iters {
match iter.next().unwrap() {
None => return None,
Some(s) => container.push(s),
}
}
self.function.call_udf(&mut container).ok()
})
.collect();
ca.rename(&name);
Ok(Some(ca.into_series()))
}
ApplyOptions::ApplyFlat => {
// the function needs to be called on a flat series
// but the series may be flat or aggregated
// if its flat, we just apply and return
// if not flat, the flattening sorts by group, so we must create new group tuples
// and again aggregate.
let name = acs[0].series().name().to_string();

// get the flat representation of the aggregation contexts
let mut container = acs
.iter_mut()
.map(|ac| {
// this is hard because the flattening sorts by group
assert!(
ac.is_not_aggregated(),
"flat apply on any expression that is already \
in aggregated state is not yet suported"
);
ac.flat_naive().into_owned()
})
.collect::<Vec<_>>();

let out = self.function.call_udf(&mut container)?;
let out = out.agg_list(acs[0].groups().as_ref()).map(|mut out| {
out.rename(&name);
out
});

Ok(out)
}
ApplyOptions::ApplyList => {
let mut s = acs
.iter_mut()
.map(|ac| ac.aggregated().into_owned())
.collect::<Vec<_>>();
self.function.call_udf(&mut s).map(Some)
}
}
let mut ac = self.evaluate_on_groups(df, groups, state)?;
let mut s = ac.aggregated().into_owned();
if ac.is_all_unit_len() && self.auto_explode {
s = s.explode().unwrap();
}
Ok(Some(s))
}
}

0 comments on commit 01469d4

Please sign in to comment.