Skip to content

Commit

Permalink
fix arr.eval type inference (#3203)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Apr 21, 2022
1 parent 5a46c37 commit b10233e
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 5 deletions.
2 changes: 1 addition & 1 deletion polars/polars-lazy/src/dsl/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ pub fn concat_lst(s: Vec<Expr>) -> Expr {
Expr::Function {
input: s,
function,
output_type: GetOutput::from_type(DataType::Utf8),
output_type: GetOutput::map_dtype(|dt| DataType::List(Box::new(dt.clone()))),
options: FunctionOptions {
collect_groups: ApplyOptions::ApplyFlat,
input_wildcard_expansion: true,
Expand Down
6 changes: 3 additions & 3 deletions polars/polars-lazy/src/dsl/list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ impl ListNameSpace {
move |s| s.list()?.lst_get(index),
GetOutput::map_field(|field| match field.data_type() {
DataType::List(inner) => Field::new(field.name(), *inner.clone()),
_ => panic!("should be a list type"),
dt => panic!("should be a list type, got {:?}", dt),
}),
)
}
Expand Down Expand Up @@ -295,11 +295,11 @@ impl ListNameSpace {
.cloned()
.unwrap_or_else(|| f.data_type().clone());

let df = Series::new_empty(f.name(), &dtype).into_frame();
let df = Series::new_empty("", &dtype).into_frame();
match df.lazy().select([expr2.clone()]).collect() {
Ok(out) => {
let dtype = out.get_columns()[0].dtype();
Field::new(f.name(), dtype.clone())
Field::new(f.name(), DataType::List(Box::new(dtype.clone())))
}
Err(_) => Field::new(f.name(), DataType::Null),
}
Expand Down
4 changes: 3 additions & 1 deletion py-polars/polars/internals/lazy_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1259,7 +1259,9 @@ def format(fstring: str, *args: Union["pli.Expr", str]) -> "pli.Expr":
return concat_str(exprs, sep="")


def concat_list(exprs: Sequence[Union[str, "pli.Expr", "pli.Series"]]) -> "pli.Expr":
def concat_list(
exprs: Union[Sequence[Union[str, "pli.Expr", "pli.Series"]], "pli.Expr"]
) -> "pli.Expr":
"""
Concat the arrays in a Series dtype List in linear time.
Expand Down
30 changes: 30 additions & 0 deletions py-polars/tests/test_lists.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,3 +184,33 @@ def test_cast_inner() -> None:
# this creates an inner null type
df = pl.from_pandas(pd.DataFrame(data=[[[]], [[]]], columns=["A"]))
assert df["A"].cast(pl.List(int)).dtype.inner == pl.Int64 # type: ignore


def test_list_eval_dtype_inference() -> None:
    """Regression test for the output dtype inferred by ``.arr.eval``.

    The grade columns are combined into a single list column, an expression
    is evaluated over that list with ``.arr.eval``, and ``.arr.first()`` is
    chained on the result. The chained call would fail if ``.arr.eval`` did
    not correctly infer the output type.
    """
    scores = {
        "student": ["bas", "laura", "tim", "jenny"],
        "arithmetic": [10, 5, 6, 8],
        "biology": [4, 6, 2, 7],
        "geography": [8, 4, 9, 7],
    }
    grades = pl.DataFrame(scores)

    # Expression handed to `.arr.eval`: descending rank divided by the count.
    rank_pct = pl.col("").rank(reverse=True) / pl.col("").count()

    # Gather every column except "student" into one list column.
    all_grades = pl.concat_list(pl.all().exclude("student")).alias("all_grades")
    with_lists = grades.with_column(all_grades)

    # `.arr.first()` is the call that depends on the inferred output type.
    first_rank = (
        pl.col("all_grades")
        .arr.eval(rank_pct, parallel=True)
        .alias("grades_rank")
        .arr.first()
    )
    result = with_lists.select([first_rank]).to_series().to_list()

    expected = [
        0.3333333432674408,
        0.6666666865348816,
        0.6666666865348816,
        0.3333333432674408,
    ]
    assert result == expected

0 comments on commit b10233e

Please sign in to comment.