Skip to content

Commit

Permalink
check output length of all 'map' expressions (#3052)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Apr 3, 2022
1 parent e7414e4 commit 6bb03b5
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 4 deletions.
16 changes: 12 additions & 4 deletions polars/polars-lazy/src/physical_plan/expressions/apply.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,14 @@ fn all_unit_length(ca: &ListChunked) -> bool {
(offset[offset.len() - 1] as usize) == list_arr.len() as usize
}

/// Verifies that a 'map' function produced exactly as many values as it received.
///
/// Returns `Ok(())` when `input_len` equals `output_len`; otherwise returns a
/// `PolarsError::ComputeError` advising the caller to use 'apply' instead of 'map'.
fn check_map_output_len(input_len: usize, output_len: usize) -> Result<()> {
    // Happy path first: matching lengths need no further work.
    if input_len == output_len {
        return Ok(());
    }

    let msg = "A 'map' functions output length must be equal to that of the input length. Consider using 'apply' in favor of 'map'.";
    Err(PolarsError::ComputeError(msg.into()))
}

impl PhysicalExpr for ApplyExpr {
fn as_expression(&self) -> &Expr {
&self.expr
Expand Down Expand Up @@ -115,10 +123,7 @@ impl PhysicalExpr for ApplyExpr {
let input_len = input.len();
let s = self.function.call_udf(&mut [input])?;

if s.len() != input_len {
return Err(PolarsError::ComputeError("A map function may never return a Series of a different length than its input".into()));
}

check_map_output_len(input_len, s.len())?;
ac.with_series(s, false);
Ok(ac)
}
Expand Down Expand Up @@ -176,7 +181,10 @@ impl PhysicalExpr for ApplyExpr {
.map(|ac| ac.flat_naive().into_owned())
.collect::<Vec<_>>();

let input_len = s.iter().map(|s| s.len()).max().unwrap();
let s = self.function.call_udf(&mut s)?;
check_map_output_len(input_len, s.len())?;

let mut ac = acs.pop().unwrap();
if ac.is_aggregated() {
ac.with_update_groups(UpdateGroups::WithGroupsLen);
Expand Down
13 changes: 13 additions & 0 deletions py-polars/tests/test_errors.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import numpy as np
import pytest

import polars as pl
Expand All @@ -8,3 +9,15 @@ def test_error_on_empty_groupby() -> None:
pl.ComputeError, match="expected keys in groupby operation, got nothing"
):
pl.DataFrame(dict(x=[0, 0, 1, 1])).groupby([]).agg(pl.count())


def test_error_on_reducing_map() -> None:
    """A `pl.map` whose function changes the output length must raise `ComputeError`.

    `np.trapz` is applied through `pl.map` inside a groupby aggregation; the
    test expects the length-mismatch error rather than a silently wrong result.
    """
    frame = pl.DataFrame(
        {
            "id": [0, 0, 0, 1, 1, 1],
            "t": [2, 4, 5, 10, 11, 14],
            "y": [0, 1, 1, 2, 3, 4],
        }
    )

    # `match` is applied as a regular-expression search against the error text.
    expected_msg = "A 'map' functions output length must be equal to that of the input length. Consider using 'apply' in favor of 'map'."

    with pytest.raises(pl.ComputeError, match=expected_msg):
        frame.groupby("id").agg(pl.map(["t", "y"], np.trapz))

0 comments on commit 6bb03b5

Please sign in to comment.