Support CSE on python UDFs #16637

Open
kszlim opened this issue May 31, 2024 · 10 comments
Labels
enhancement New feature or an improvement of an existing feature

Comments

@kszlim
Contributor

kszlim commented May 31, 2024

Description

I'm guessing common subexpression elimination (CSE) isn't supported for Python UDFs because they can potentially be stateful. Could the map_* methods on pl.Expr take an is_pure parameter that lets these expressions get CSE'd?
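
To illustrate why statefulness matters here, the following is a minimal pure-Python sketch, not Polars code. The names `make_stateful_udf` and `pure_udf` are hypothetical, and plain lists stand in for columns. It compares calling a UDF once per occurrence (no CSE) with calling it once and reusing the result (what CSE would do).

```python
import itertools

def make_stateful_udf():
    """Return a UDF whose output depends on how often it has been called."""
    counter = itertools.count(1)

    def udf(values):
        call_no = next(counter)  # hidden state: changes on every call
        return [v * call_no for v in values]

    return udf

def pure_udf(values):
    """A UDF with no hidden state: same input, same output."""
    return [v * 2 for v in values]

data = [1, 2, 3]

# No CSE: the expression occurs twice, so the UDF runs twice.
stateful = make_stateful_udf()
without_cse = (stateful(data), stateful(data))

# CSE: the UDF runs once and both occurrences reuse the result.
stateful = make_stateful_udf()
shared = stateful(data)
with_cse = (shared, shared)

# For the stateful UDF the two strategies disagree ...
stateful_differs = without_cse != with_cse

# ... while for the pure UDF they are interchangeable.
pure_shared = pure_udf(data)
pure_agrees = (pure_udf(data), pure_udf(data)) == (pure_shared, pure_shared)

print(without_cse, with_cse, stateful_differs, pure_agrees)
```

An optimizer cannot tell these two kinds of function apart by inspection, which is why the request is for an explicit opt-in from the user.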

kszlim added the enhancement label May 31, 2024
@ritchie46
Member

I am not sure about that. We will not call into Python to check functions for equality, and pointer checking has failed in the past.
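
As a general illustration of why comparing Python callables is awkward, here is a small pure-Python sketch. It only shows standard Python behaviour and does not claim to reproduce the pointer-checking failure mentioned above; `make_doubler` and `make_scaler` are hypothetical helpers.

```python
def make_doubler():
    return lambda x: x * 2

def make_scaler(k):
    return lambda x: x * k

# Two functions that behave identically ...
f = make_doubler()
g = make_doubler()
same_behaviour = f(21) == g(21)

# ... are still distinct objects, and == on functions is identity.
same_object = f is g
compare_equal = f == g

# Going the other way, comparing the compiled code is too lenient:
# closures built by one factory share a code object but can behave
# differently because they captured different values.
double = make_scaler(2)
triple = make_scaler(3)
share_code = double.__code__ is triple.__code__
behave_differently = double(1) != triple(1)

print(same_behaviour, same_object, compare_equal)
print(share_code, behave_differently)
```

Identity is too strict to detect equivalent functions, and code-object comparison is too loose to be safe, so neither gives a reliable equality test without actually reasoning about the function.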

@kszlim
Contributor Author

kszlim commented Jun 1, 2024

Can these be special-cased so that a user can declare that it's safe to CSE a given expression? The current behavior is a pretty big annoyance in some cases and makes certain programming patterns ugly.

@kszlim
Contributor Author

kszlim commented Jun 1, 2024

At work, I essentially provide a framework where users pass me expressions, and I apply them to a base table, adding an over to the user-provided expressions. The only way to keep UDFs from being recomputed is to reference them by name in a later context. That's okay if the UDF is cheap, but some of them are quite expensive, so having CSE work would be great.

@ritchie46
Member

I think this can create many bugs, which I don't want to have open at this point in time. We can look at enabling it for UDFs later.

@kszlim
Contributor Author

kszlim commented Jun 3, 2024

Even if it's completely opt-in? This is a bit of a blocker for me. Would it be possible to bring back the pl.Expr.cache method as an alternative?

@avimallu
Contributor

avimallu commented Jun 3, 2024

The only way to avoid UDFs from being recomputed would be by referencing them by name in a later context.

Does lru_cache not work for your case?

@kszlim
Contributor Author

kszlim commented Jun 3, 2024

The only way to avoid UDFs from being recomputed would be by referencing them by name in a later context.

Does lru_cache not work for your case?

@avimallu I don't think you understand the feature request.

import polars as pl
ldf = pl.LazyFrame({"a": [1, 2, 3]})
udf_expr = pl.col("a").map_batches(lambda x: x*2).alias("b")
derived_expr_0 = udf_expr.mul(2).alias("c")
derived_expr_1 = udf_expr.mul(3).alias("d")
ldf = ldf.with_columns(udf_expr, derived_expr_0, derived_expr_1)
print(ldf.explain())

Running this prints:

 WITH_COLUMNS:
 [col("a").python_udf().alias("b"), [(col("a").python_udf()) * (2.cast(Unknown(Any)))].alias("c"), [(col("a").python_udf()) * (3.cast(Unknown(Any)))].alias("d")], []
  DF ["a"]; PROJECT */1 COLUMNS; SELECTION: None

The plan contains col("a").python_udf() three times, meaning that the UDF gets evaluated 3x.

Contrast it with:

import polars as pl
ldf = pl.LazyFrame({"a": [1, 2, 3]})
udf_expr = pl.col("a").mul(2).alias("b")
derived_expr_0 = udf_expr.mul(2).alias("c")
derived_expr_1 = udf_expr.mul(3).alias("d")
ldf = ldf.with_columns(udf_expr, derived_expr_0, derived_expr_1)
print(ldf.explain())

This prints a plan in which the shared sub-expression is computed once, under a __POLARS_CSER_ name, and then reused:

 WITH_COLUMNS:
 [col("__POLARS_CSER_0xd39686281a38356a").alias("b"), [(col("__POLARS_CSER_0xd39686281a38356a")) * (2)].alias("c"), [(col("__POLARS_CSER_0xd39686281a38356a")) * (3)].alias("d")], [[(col("a")) * (2)].alias("__POLARS_CSER_0xd39686281a38356a")]
  DF ["a"]; PROJECT */1 COLUMNS; SELECTION: None

Getting the same result for UDFs either requires changes on the Polars side, or you have to explicitly write your query like this:

import polars as pl
ldf = pl.LazyFrame({"a": [1, 2, 3]})
udf_expr = pl.col("a").map_batches(lambda x: x*2).alias("b")
derived_expr_0 = pl.col("b").mul(2).alias("c")
derived_expr_1 = pl.col("b").mul(3).alias("d")
ldf = ldf.with_columns(udf_expr)
ldf = ldf.with_columns(derived_expr_0, derived_expr_1)
print(ldf.explain())

But in my use case, users build up trees of expressions which they pass to my framework to evaluate. Without CSE support, the explicit form above becomes very ugly and breaks the abstraction.
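
The requested behaviour can be sketched with a toy expression tree in pure Python. This is an illustration under stated assumptions, not how Polars implements CSE: the `Col`, `Mul`, and `Udf` node classes, the `is_pure` flag, and the `cse_key` and `evaluate` functions are all hypothetical, and lists stand in for columns. Sub-expressions are deduplicated by a structural key, and a UDF node only gets a key when the user has opted in.

```python
from dataclasses import dataclass
from typing import Callable

@dataclass(frozen=True)
class Col:
    name: str

@dataclass(frozen=True)
class Mul:
    input: object
    factor: int

@dataclass(frozen=True)
class Udf:
    input: object
    func: Callable
    is_pure: bool = False  # hypothetical opt-in flag

def cse_key(node):
    """Structural key for a node; None means 'never deduplicate'."""
    if isinstance(node, Col):
        return ("col", node.name)
    if isinstance(node, Mul):
        inner = cse_key(node.input)
        return None if inner is None else ("mul", inner, node.factor)
    if isinstance(node, Udf):
        inner = cse_key(node.input)
        if not node.is_pure or inner is None:
            return None
        # The node keeps func alive, so its id is stable during evaluation.
        return ("udf", id(node.func), inner)
    raise TypeError(f"unknown node: {node!r}")

def evaluate(node, frame, cache):
    """Evaluate a node, reusing cached results for nodes with equal keys."""
    key = cse_key(node)
    if key is not None and key in cache:
        return cache[key]
    if isinstance(node, Col):
        result = frame[node.name]
    elif isinstance(node, Mul):
        result = [v * node.factor for v in evaluate(node.input, frame, cache)]
    else:  # Udf
        result = node.func(evaluate(node.input, frame, cache))
    if key is not None:
        cache[key] = result
    return result

calls = []

def double(values):
    calls.append(1)  # record each invocation of the "expensive" UDF
    return [v * 2 for v in values]

def run(is_pure):
    calls.clear()
    b = Udf(Col("a"), double, is_pure=is_pure)
    exprs = {"b": b, "c": Mul(b, 2), "d": Mul(b, 3)}
    cache = {}
    out = {k: evaluate(e, {"a": [1, 2, 3]}, cache) for k, e in exprs.items()}
    return out, len(calls)

out_default, calls_default = run(is_pure=False)
out_pure, calls_pure = run(is_pure=True)
print(calls_default, calls_pure)
```

Both runs produce the same columns; only the number of UDF invocations differs. The design choice in this sketch is that purity is never inferred, so a UDF without the flag is always re-evaluated.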

@deanm0000
Collaborator

@avimallu I had the same thought about lru_cache, but it doesn't work because neither pl.Series nor np.ndarray is hashable, so they can't be used as keys in the underlying cache.
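
The hashability problem can be reproduced with the standard library alone. In this sketch a plain list stands in for pl.Series or np.ndarray, and `memoize_by_content` is a hypothetical helper that shows one possible way around it, by building a hashable key from the contents at the cost of one pass over the data per call.

```python
from functools import lru_cache

@lru_cache(maxsize=None)
def cached_double(values):
    return [v * 2 for v in values]

# lru_cache hashes its arguments, so an unhashable input fails up front.
try:
    cached_double([1, 2, 3])
    error = None
except TypeError as exc:
    error = exc

def memoize_by_content(func):
    """Cache results keyed on a tuple copy of the input (hypothetical helper)."""
    cache = {}

    def wrapper(values):
        key = tuple(values)  # O(n) to build, but hashable
        if key not in cache:
            cache[key] = func(values)
        return cache[key]

    return wrapper

call_count = 0

@memoize_by_content
def expensive_double(values):
    global call_count
    call_count += 1
    return [v * 2 for v in values]

first = expensive_double([1, 2, 3])
second = expensive_double([1, 2, 3])  # served from the cache

print(type(error).__name__, first, second, call_count)
```

Even where such a wrapper is acceptable, it only avoids recomputation inside the function; the query plan would still contain the UDF call once per occurrence.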

@ritchie46
Member

Ping me next week. I will see if I can put something behind an env var.

@kszlim
Contributor Author

kszlim commented Jun 10, 2024

@ritchie46 if you're not swamped, this would be great! Or, if you don't think it's too complicated and can give me some high-level guidance, I can try to take a crack at it.

4 participants