Skip to content

Commit

Permalink
fix(rust, python): convert panic to err in concat_list (#5637)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Nov 26, 2022
1 parent c432d5b commit 7360c66
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 6 deletions.
12 changes: 6 additions & 6 deletions polars/polars-lazy/polars-plan/src/dsl/function_expr/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,10 @@ impl FunctionExpr {
};

// map all dtypes
let map_dtypes = |func: &dyn Fn(&[&DataType]) -> DataType| {
let try_map_dtypes = |func: &dyn Fn(&[&DataType]) -> PolarsResult<DataType>| {
let mut fld = fields[0].clone();
let dtypes = fields.iter().map(|fld| fld.data_type()).collect::<Vec<_>>();
let new_type = func(&dtypes);
let new_type = func(&dtypes)?;
fld.coerce(new_type);
Ok(fld)
};
Expand Down Expand Up @@ -69,26 +69,26 @@ impl FunctionExpr {

// inner super type of lists
let inner_super_type_list = || {
map_dtypes(&|dts| {
try_map_dtypes(&|dts| {
let mut super_type_inner = None;

for dt in dts {
match dt {
DataType::List(inner) => match super_type_inner {
None => super_type_inner = Some(*inner.clone()),
Some(st_inner) => {
super_type_inner = try_get_supertype(&st_inner, inner).ok()
super_type_inner = Some(try_get_supertype(&st_inner, inner)?)
}
},
dt => match super_type_inner {
None => super_type_inner = Some((*dt).clone()),
Some(st_inner) => {
super_type_inner = try_get_supertype(&st_inner, dt).ok()
super_type_inner = Some(try_get_supertype(&st_inner, dt)?)
}
},
}
}
DataType::List(Box::new(super_type_inner.unwrap()))
Ok(DataType::List(Box::new(super_type_inner.unwrap())))
})
};

Expand Down
12 changes: 12 additions & 0 deletions py-polars/tests/unit/test_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,3 +293,15 @@ def test_invalid_sort_by() -> None:
match="The sortby operation produced a different length than the Series that has to be sorted.", # noqa: E501
):
df.select(pl.col("a").filter(pl.col("b") == "M").sort_by("c", True))


def test_concat_list_err_supertype() -> None:
df = pl.DataFrame({"nums": [1, 2, 3, 4], "letters": ["a", "b", "c", "d"]}).select(
[
pl.col("nums"),
pl.struct(["letters", "nums"]).alias("combo"),
pl.struct(["nums", "letters"]).alias("reverse_combo"),
]
)
with pytest.raises(pl.ComputeError, match="Failed to determine supertype"):
df.select(pl.concat_list(["combo", "reverse_combo"]))

0 comments on commit 7360c66

Please sign in to comment.