Skip to content

Commit

Permalink
fix(rust): don't set auto-explode in apply_multiple (#5265)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Oct 25, 2022
1 parent d4dd828 commit 857ffa1
Show file tree
Hide file tree
Showing 6 changed files with 72 additions and 46 deletions.
11 changes: 9 additions & 2 deletions polars/polars-lazy/polars-plan/src/dsl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2373,7 +2373,12 @@ where
///
/// * [`map_multiple`] should be used for operations that are independent of groups, e.g. `multiply * 2`, or `raise to the power`
/// * [`apply_multiple`] should be used for operations that work on a group of data, e.g. `sum`, `count`, etc.
pub fn apply_multiple<F, E>(function: F, expr: E, output_type: GetOutput) -> Expr
pub fn apply_multiple<F, E>(
function: F,
expr: E,
output_type: GetOutput,
returns_scalar: bool,
) -> Expr
where
F: Fn(&mut [Series]) -> PolarsResult<Series> + 'static + Send + Sync,
E: AsRef<[Expr]>,
Expand All @@ -2386,7 +2391,9 @@ where
output_type,
options: FunctionOptions {
collect_groups: ApplyOptions::ApplyGroups,
auto_explode: true,
// don't hard-code this to `true`:
// whether the result should be auto-exploded is for the caller to decide
auto_explode: returns_scalar,
fmt_str: "",
..Default::default()
},
Expand Down
40 changes: 0 additions & 40 deletions polars/polars-lazy/src/tests/queries.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1674,46 +1674,6 @@ fn test_groupby_rank() -> PolarsResult<()> {
Ok(())
}

/// Runs the same multi-column closure through `map_multiple` (in a `select`)
/// and through `apply_multiple` (in a `groupby_stable(..).agg(..)`) and checks
/// the resulting "A" column in both cases.
#[test]
fn test_apply_multiple_columns() -> PolarsResult<()> {
let df = fruits_cars();

// Computes `s[0] * s[0] * s[1]`: the first input squared, times the second input.
let multiply = |s: &mut [Series]| Ok(&(&s[0] * &s[0]) * &s[1]);

// NOTE(review): the output type is declared as Float64, yet the result is read
// back with `.i32()` below — confirm the declared dtype is intentional.
let out = df
.clone()
.lazy()
.select([map_multiple(
multiply,
[col("A"), col("B")],
GetOutput::from_type(DataType::Float64),
)])
.collect()?;
let out = out.column("A")?;
let out = out.i32()?;
// Expected values match A*A*B for A = 1..=5 and B = 5..=1
// (presumably the contents of `fruits_cars()` — verify against the fixture).
assert_eq!(
Vec::from(out),
&[Some(5), Some(16), Some(27), Some(32), Some(25)]
);

// Same closure, now evaluated per "cars" group via `apply_multiple`.
let out = df
.lazy()
.groupby_stable([col("cars")])
.agg([apply_multiple(
multiply,
[col("A"), col("B")],
GetOutput::from_type(DataType::Float64),
)])
.collect()?;

// The aggregated column is read as a list column; only the entry at index 1 is checked.
let out = out.column("A")?;
let out = out.list()?.get(1).unwrap();
let out = out.i32()?;

assert_eq!(Vec::from(out), &[Some(16)]);
Ok(())
}

#[test]
pub fn test_select_by_dtypes() -> PolarsResult<()> {
let df = df![
Expand Down
41 changes: 41 additions & 0 deletions polars/tests/it/lazy/queries.rs
Original file line number Diff line number Diff line change
Expand Up @@ -186,3 +186,44 @@ fn test_unknown_supertype_ignore() -> PolarsResult<()> {
assert_eq!(out.shape(), (4, 2));
Ok(())
}

/// Runs the same multi-column closure through `map_multiple` (in a `select`)
/// and through `apply_multiple` (in a `groupby_stable(..).agg(..)`) and checks
/// the resulting "A" column in both cases.
#[test]
fn test_apply_multiple_columns() -> PolarsResult<()> {
let df = fruits_cars();

// Computes `s[0] * s[0] * s[1]`: the first input squared, times the second input.
let multiply = |s: &mut [Series]| Ok(&(&s[0] * &s[0]) * &s[1]);

// NOTE(review): the output type is declared as Float64, yet the result is read
// back with `.i32()` below — confirm the declared dtype is intentional.
let out = df
.clone()
.lazy()
.select([map_multiple(
multiply,
[col("A"), col("B")],
GetOutput::from_type(DataType::Float64),
)])
.collect()?;
let out = out.column("A")?;
let out = out.i32()?;
// Expected values match A*A*B for A = 1..=5 and B = 5..=1
// (presumably the contents of `fruits_cars()` — verify against the fixture).
assert_eq!(
Vec::from(out),
&[Some(5), Some(16), Some(27), Some(32), Some(25)]
);

// Same closure, now evaluated per "cars" group via `apply_multiple`.
// The trailing `true` is the fourth positional argument of `apply_multiple`.
// NOTE(review): presumably this is the `returns_scalar` flag; the result is still
// read as a list column below — confirm `true` is the intended value here.
let out = df
.lazy()
.groupby_stable([col("cars")])
.agg([apply_multiple(
multiply,
[col("A"), col("B")],
GetOutput::from_type(DataType::Float64),
true,
)])
.collect()?;

// The aggregated column is read as a list column; only the entry at index 1 is checked.
let out = out.column("A")?;
let out = out.list()?.get(1).unwrap();
let out = out.i32()?;

assert_eq!(Vec::from(out), &[Some(16)]);
Ok(())
}
13 changes: 11 additions & 2 deletions py-polars/polars/internals/lazy_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -968,13 +968,16 @@ def map(
"""
exprs = pli.selection_to_pyexpr_list(exprs)
return pli.wrap_expr(_map_mul(exprs, f, return_dtype, apply_groups=False))
return pli.wrap_expr(
_map_mul(exprs, f, return_dtype, apply_groups=False, returns_scalar=False)
)


def apply(
exprs: Sequence[str | pli.Expr],
f: Callable[[Sequence[pli.Series]], pli.Series | Any],
return_dtype: type[DataType] | None = None,
returns_scalar: bool = True,
) -> pli.Expr:
"""
Apply a custom/user-defined function (UDF) in a GroupBy context.
Expand All @@ -995,14 +998,20 @@ def apply(
Function to apply over the input
return_dtype
dtype of the output Series
returns_scalar
Whether the function returns a single scalar as output.

Returns
-------
Expr
"""
exprs = pli.selection_to_pyexpr_list(exprs)
return pli.wrap_expr(_map_mul(exprs, f, return_dtype, apply_groups=True))
return pli.wrap_expr(
_map_mul(
exprs, f, return_dtype, apply_groups=True, returns_scalar=returns_scalar
)
)


def fold(
Expand Down
3 changes: 2 additions & 1 deletion py-polars/src/lazy/apply.rs
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,7 @@ pub fn map_mul(
lambda: PyObject,
output_type: Option<Wrap<DataType>>,
apply_groups: bool,
returns_scalar: bool,
) -> PyExpr {
// get the pypolars module
// do the import outside of the function to prevent import side effects in a hot loop.
Expand All @@ -212,7 +213,7 @@ pub fn map_mul(
None => fld.clone(),
});
if apply_groups {
polars::lazy::dsl::apply_multiple(function, exprs, output_map).into()
polars::lazy::dsl::apply_multiple(function, exprs, output_map, returns_scalar).into()
} else {
polars::lazy::dsl::map_multiple(function, exprs, output_map).into()
}
Expand Down
10 changes: 9 additions & 1 deletion py-polars/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -442,8 +442,16 @@ pub fn map_mul(
lambda: PyObject,
output_type: Option<Wrap<DataType>>,
apply_groups: bool,
returns_scalar: bool,
) -> PyExpr {
lazy::map_mul(&pyexpr, py, lambda, output_type, apply_groups)
lazy::map_mul(
&pyexpr,
py,
lambda,
output_type,
apply_groups,
returns_scalar,
)
}

#[pyfunction]
Expand Down

0 comments on commit 857ffa1

Please sign in to comment.