Skip to content

Commit

Permalink
python: ensure type information is passed in custom apply expression (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Aug 9, 2022
1 parent a61bc54 commit 23d27ef
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 2 deletions.
2 changes: 1 addition & 1 deletion py-polars/polars/internals/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -2828,7 +2828,7 @@ def apply(
def wrap_f(x: pli.Series) -> pli.Series: # pragma: no cover
return x.apply(f, return_dtype=return_dtype)

return self.map(wrap_f, agg_list=True)
return self.map(wrap_f, agg_list=True, return_dtype=return_dtype)

def flatten(self) -> Expr:
"""
Expand Down
2 changes: 1 addition & 1 deletion py-polars/polars/internals/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -6980,6 +6980,6 @@ def apply(
for name in self.selection:
s = df.drop_in_place(name + "_agg_list").apply(func, return_dtype)
s.rename(name, in_place=True)
df[name] = s
df.with_column(s)

return df
24 changes: 24 additions & 0 deletions py-polars/tests/test_apply.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,3 +199,27 @@ def test_apply_all_types() -> None:
# test we don't panic
for dtype in dtypes:
pl.Series([1, 2, 3, 4, 5], dtype=dtype).apply(lambda x: x)


def test_apply_type_propagation() -> None:
assert (
pl.from_dict(
{
"a": [1, 2, 3],
"b": [{"c": 1, "d": 2}, {"c": 2, "d": 3}, {"c": None, "d": None}],
}
)
.groupby("a", maintain_order=True)
.agg(
[
pl.when(pl.col("b").null_count() == 0)
.then(
pl.col("b").apply(
lambda s: s[0]["c"],
return_dtype=pl.Float64,
)
)
.otherwise(None)
]
)
).to_dict(False) == {"a": [1, 2, 3], "b": [1.0, 2.0, None]}

0 comments on commit 23d27ef

Please sign in to comment.