Skip to content

Commit

Permalink
fix(rust, python): adhere to schema in arr.eval of empty list (#5947)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Dec 30, 2022
1 parent 24d4209 commit e80eaff
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 30 deletions.
61 changes: 32 additions & 29 deletions polars/polars-lazy/src/dsl/list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,9 @@ pub trait ListNameSpaceExtension: IntoListNameSpace + Sized {
let func = move |s: Series| {
let lst = s.list()?;
if lst.is_empty() {
return Ok(s);
// ensure we get the new schema
let fld = field_to_dtype(lst.ref_field(), &expr);
return Ok(Series::new_empty(s.name(), fld.data_type()));
}

let phys_expr =
Expand Down Expand Up @@ -98,37 +100,38 @@ pub trait ListNameSpaceExtension: IntoListNameSpace + Sized {
this.0
.map(
func,
GetOutput::map_field(move |f| {
// dummy df to determine output dtype
let dtype = f
.data_type()
.inner_dtype()
.cloned()
.unwrap_or_else(|| f.data_type().clone());

let df = Series::new_empty("", &dtype).into_frame();

#[cfg(feature = "python")]
let out = {
use pyo3::Python;
Python::with_gil(|py| {
py.allow_threads(|| df.lazy().select([expr2.clone()]).collect())
})
};
#[cfg(not(feature = "python"))]
let out = { df.lazy().select([expr2.clone()]).collect() };

match out {
Ok(out) => {
let dtype = out.get_columns()[0].dtype();
Field::new(f.name(), DataType::List(Box::new(dtype.clone())))
}
Err(_) => Field::new(f.name(), DataType::Null),
}
}),
GetOutput::map_field(move |f| field_to_dtype(f, &expr2)),
)
.with_fmt("eval")
}
}

/// Derives the output [`Field`] of a list `eval` by running `expr` on an empty dummy frame.
///
/// The dummy frame has a single column named `""` whose dtype is the inner dtype of `f`
/// (or the dtype of `f` itself when it has no inner dtype). The dtype of the first column
/// of the collected result is wrapped in `DataType::List` and returned under `f`'s name.
///
/// If evaluating the expression fails, the error is discarded and a field of
/// `DataType::Null` with `f`'s name is returned instead.
#[cfg(feature = "list_eval")]
fn field_to_dtype(f: &Field, expr: &Expr) -> Field {
    // Element dtype the expression will see; fall back to the field's own dtype.
    let element_dtype = match f.data_type().inner_dtype() {
        Some(inner) => inner.clone(),
        None => f.data_type().clone(),
    };

    // Zero-row frame: only the resulting schema matters, not any values.
    let dummy = Series::new_empty("", &element_dtype).into_frame();

    // With the `python` feature the collect runs inside `with_gil` + `allow_threads`.
    // NOTE(review): presumably so Python callbacks in `expr` can take the GIL — confirm.
    #[cfg(feature = "python")]
    let collected = {
        use pyo3::Python;
        Python::with_gil(|py| py.allow_threads(|| dummy.lazy().select([expr.clone()]).collect()))
    };
    #[cfg(not(feature = "python"))]
    let collected = dummy.lazy().select([expr.clone()]).collect();

    collected
        .map(|evaluated| {
            let result_dtype = evaluated.get_columns()[0].dtype().clone();
            Field::new(f.name(), DataType::List(Box::new(result_dtype)))
        })
        .unwrap_or_else(|_| Field::new(f.name(), DataType::Null))
}

impl ListNameSpaceExtension for ListNameSpace {}
7 changes: 6 additions & 1 deletion py-polars/polars/internals/dataframe/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -2517,7 +2517,12 @@ def insert_at_idx(self: DF, index: int, series: pli.Series) -> DF:

def filter(
self,
predicate: pli.Expr | str | pli.Series | list[bool] | np.ndarray[Any, Any],
predicate: pli.Expr
| str
| pli.Series
| list[bool]
| np.ndarray[Any, Any]
| bool,
) -> DataFrame:
"""
Filter the rows in the DataFrame based on a predicate expression.
Expand Down
7 changes: 7 additions & 0 deletions py-polars/tests/unit/test_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,3 +299,10 @@ def test_all_null_cast_5826() -> None:
out = df.with_column(pl.col("a").cast(pl.Boolean))
assert out.dtypes == [pl.Boolean]
assert out.item() is None


def test_emtpy_list_eval_schema_5734() -> None:
    """Check that ``arr.eval`` reports the evaluated dtype on a filtered-out frame.

    The column starts as a list of structs; after ``filter(False)`` the
    ``arr.eval`` expression extracting struct field ``b`` must still yield a
    schema of ``List(Int64)`` for column ``a``.
    """
    frame = pl.DataFrame({"a": [[{"b": 1, "c": 2}]]})
    # filter(False) is used here to get a frame with no matching rows.
    filtered = frame.filter(False)
    extract_b = pl.col("a").arr.eval(pl.element().struct.field("b"))
    result = filtered.select(extract_b)
    assert result.schema == {"a": pl.List(pl.Int64)}

0 comments on commit e80eaff

Please sign in to comment.