Skip to content

Commit

Permalink
feat(python): enable "inefficient apply" warnings from Series (#10104)
Browse files Browse the repository at this point in the history
Co-authored-by: MarcoGorelli <33491632+marcogorelli@users.noreply.github.com>
Co-authored-by: Marco Edward Gorelli <marcogorelli@protonmail.com>
  • Loading branch information
3 people committed Jul 27, 2023
1 parent e50e1d9 commit db901b4
Show file tree
Hide file tree
Showing 5 changed files with 175 additions and 47 deletions.
8 changes: 3 additions & 5 deletions py-polars/polars/series/series.py
Expand Up @@ -4576,16 +4576,14 @@ def apply(
Series
"""
# TODO:
# from polars.utils.udfs import warn_on_inefficient_apply
# warn_on_inefficient_apply(
# function, columns=[self.name], apply_target="series"
# )
from polars.utils.udfs import warn_on_inefficient_apply

if return_dtype is None:
pl_return_dtype = None
else:
pl_return_dtype = py_type_to_dtype(return_dtype)

warn_on_inefficient_apply(function, columns=[self.name], apply_target="series")
return self._from_pyseries(
self._s.apply_lambda(function, pl_return_dtype, skip_nulls)
)
Expand Down
93 changes: 73 additions & 20 deletions py-polars/polars/utils/udfs.py
Expand Up @@ -2,13 +2,14 @@
from __future__ import annotations

import dis
import re
import sys
import warnings
from bisect import bisect_left
from collections import defaultdict
from dis import get_instructions
from inspect import signature
from itertools import zip_longest
from itertools import count, zip_longest
from typing import TYPE_CHECKING, Any, Callable, Iterator, Literal, NamedTuple, Union

if TYPE_CHECKING:
Expand Down Expand Up @@ -82,13 +83,15 @@ class OpNames:
UNARY_VALUES = frozenset(UNARY.values())


# numpy functions that we can map to a native expression
# numpy functions that we can map to native expressions
_NUMPY_MODULE_ALIASES = frozenset(("np", "numpy"))
_NUMPY_FUNCTIONS = frozenset(
("cbrt", "cos", "cosh", "sin", "sinh", "sqrt", "tan", "tanh")
)
# python function that we can map to a native expression

# python functions that we can map to native expressions
_PYTHON_CASTS_MAP = {"float": "Float64", "int": "Int64", "str": "Utf8"}
_PYTHON_BUILTINS = frozenset(_PYTHON_CASTS_MAP) | {"abs"}
_PYTHON_METHODS_MAP = {
"lower": "str.to_lowercase",
"title": "str.to_titlecase",
Expand All @@ -100,6 +103,7 @@ class BytecodeParser:
"""Introspect UDF bytecode and determine if we can rewrite as native expression."""

_can_rewrite: dict[str, bool]
_apply_target_name: str | None = None

def __init__(self, function: Callable[[Any], Any], apply_target: ApplyTarget):
try:
Expand Down Expand Up @@ -159,8 +163,8 @@ def _inject_nesting(
# combine logical '&' blocks (and update start/block_offsets)
prev = block_offsets[bisect_left(block_offsets, start) - 1]
expression_blocks[prev] += f" & {expression_blocks.pop(start)}"
block_offsets = list(expression_blocks.keys())
combined_offset_idxs.add(i - 1)
block_offsets.remove(start)
start = prev

if logical_op == "|":
Expand All @@ -178,6 +182,31 @@ def _inject_nesting(

return sorted(expression_blocks.items())

def _get_target_name(self, col: str, expression: str) -> str:
"""The name of the object against which the 'apply' is being invoked."""
if self._apply_target_name is not None:
return self._apply_target_name
else:
col_expr = f'pl.col("{col}")'
if self._apply_target == "expr":
return col_expr
elif self._apply_target == "series":
# note: handle overlapping name from global variables; fallback
# through "s", "srs", "series" and (finally) srs0 -> srsN...
search_expr = expression.replace(col_expr, "")
for name in ("s", "srs", "series"):
if not re.search(rf"\b{name}\b", search_expr):
self._apply_target_name = name
return name
n = count()
while True:
name = f"srs{next(n)}"
if not re.search(rf"\b{name}\b", search_expr):
self._apply_target_name = name
return name

raise NotImplementedError(f"TODO: apply_target = {self._apply_target!r}")

@property
def apply_target(self) -> ApplyTarget:
"""The apply target, eg: one of 'expr', 'frame', or 'series'."""
Expand Down Expand Up @@ -229,8 +258,9 @@ def rewritten_instructions(self) -> list[Instruction]:
"""The rewritten bytecode instructions from the function we are parsing."""
return list(self._rewritten_instructions)

def to_expression(self, col: str, as_repr: bool = True) -> str | None:
def to_expression(self, col: str) -> str | None:
"""Translate postfix bytecode instructions to polars expression/string."""
self._apply_target_name = None
if not self.can_rewrite() or self._param_name is None:
return None

Expand All @@ -246,7 +276,7 @@ def to_expression(self, col: str, as_repr: bool = True) -> str | None:
control_flow_blocks[jump_offset].append(inst)

# convert each block to a polars expression string
expression_blocks = self._inject_nesting(
expression_strings = self._inject_nesting(
{
offset: InstructionTranslator(
instructions=ops,
Expand All @@ -260,14 +290,19 @@ def to_expression(self, col: str, as_repr: bool = True) -> str | None:
},
logical_instructions,
)
polars_expr = " ".join(expr for _offset, expr in expression_blocks)
polars_expr = " ".join(expr for _offset, expr in expression_strings)

# note: if no 'pl.col' in the expression, it likely represents a compound
# constant value (e.g. `lambda x: CONST + 123`), so we don't want to warn
if "pl.col(" not in polars_expr:
return None

return polars_expr if as_repr else eval(polars_expr, globals())
elif self._apply_target == "series":
return polars_expr.replace(
f'pl.col("{col}")',
self._get_target_name(col, polars_expr),
)
else:
return polars_expr

def warn(
self,
Expand All @@ -284,7 +319,9 @@ def warn(
)

suggested_expression = suggestion_override or self.to_expression(col)

if suggested_expression is not None:
target_name = self._get_target_name(col, suggested_expression)
func_name = udf_override or self._function.__name__ or "..."
if func_name == "<lambda>":
func_name = f"lambda {self._param_name}: ..."
Expand All @@ -294,21 +331,28 @@ def warn(
if 'pl.col("")' in suggested_expression
else ""
)
if self._apply_target == "expr":
apitype = "expressions"
clsname = "Expr"
else:
apitype = "series"
clsname = "Series"

before_after_suggestion = (
(
f' \033[31m- pl.col("{col}").apply({func_name})\033[0m\n'
f" \033[31m- {target_name}.apply({func_name})\033[0m\n"
f" \033[32m+ {suggested_expression}\033[0m\n{addendum}"
)
if in_terminal_that_supports_colour()
else (
f' - pl.col("{col}").apply({func_name})\n'
f" - {target_name}.apply({func_name})\n"
f" + {suggested_expression}\n{addendum}"
)
)
warnings.warn(
"\nExpr.apply is significantly slower than the native expressions API.\n"
f"\n{clsname}.apply is significantly slower than the native {apitype} API.\n"
"Only use if you absolutely CANNOT implement your logic otherwise.\n"
"In this case, you can replace your `apply` with an expression:\n"
"In this case, you can replace your `apply` with the following:\n"
f"{before_after_suggestion}",
PolarsInefficientApplyWarning,
stacklevel=find_stacklevel(),
Expand Down Expand Up @@ -353,6 +397,12 @@ def _expr(cls, value: StackEntry, col: str, param_name: str, depth: int) -> str:
e1 = cls._expr(value.left_operand, col, param_name, depth + 1)
if value.operator_arity == 1:
if op not in OpNames.UNARY_VALUES:
if not e1.startswith("pl.col("):
# support use of consts as numpy/builtin params, eg:
# "np.sin(3) + np.cos(x)", or "len('const_string') + len(x)"
pfx = "np." if op in _NUMPY_FUNCTIONS else ""
return f"{pfx}{op}({e1})"

call = "" if op.endswith(")") else "()"
return f"{e1}.{op}{call}"
return f"{op}{e1}"
Expand Down Expand Up @@ -381,7 +431,7 @@ def _to_intermediate_stack(
self, instructions: list[Instruction], apply_target: ApplyTarget
) -> StackEntry:
"""Take postfix bytecode and convert to an intermediate natural-order stack."""
if apply_target == "expr":
if apply_target in ("expr", "series"):
stack: list[StackEntry] = []
for inst in instructions:
stack.append(
Expand Down Expand Up @@ -508,15 +558,18 @@ def _rewrite_builtins(
"""Replace builtin function calls with a synthetic POLARS_EXPRESSION op."""
if matching_instructions := self._matches(
idx,
opnames=["LOAD_GLOBAL", "LOAD_FAST", OpNames.CALL],
argvals=[_PYTHON_CASTS_MAP],
opnames=["LOAD_GLOBAL", "LOAD_*", OpNames.CALL],
argvals=[_PYTHON_BUILTINS],
):
inst1, inst2 = matching_instructions[:2]
dtype = _PYTHON_CASTS_MAP[inst1.argval]
if (argval := inst1.argval) in _PYTHON_CASTS_MAP:
dtype = _PYTHON_CASTS_MAP[argval]
argval = f"cast(pl.{dtype})"

synthetic_call = inst1._replace(
opname="POLARS_EXPRESSION",
argval=f"cast(pl.{dtype})",
argrepr=f"cast(pl.{dtype})",
argval=argval,
argrepr=argval,
offset=inst2.offset,
)
# POLARS_EXPRESSION is mapped as a unary op, so switch instruction order
Expand Down Expand Up @@ -623,7 +676,7 @@ def warn_on_inefficient_apply(
The target of the ``apply`` call. One of ``"expr"``, ``"frame"``,
or ``"series"``.
"""
if apply_target in ("frame", "series"):
if apply_target == "frame":
raise NotImplementedError("TODO: 'frame' and 'series' apply-function parsing")

# note: we only consider simple functions with a single col/param
Expand Down
22 changes: 21 additions & 1 deletion py-polars/tests/test_udfs.py
Expand Up @@ -30,7 +30,7 @@
("a", lambda x: x // 1 % 2, '(pl.col("a") // 1) % 2'),
("a", lambda x: x & True, 'pl.col("a") & True'),
("a", lambda x: x | False, 'pl.col("a") | False'),
("a", lambda x: x != 3, 'pl.col("a") != 3'),
("a", lambda x: abs(x) != 3, 'pl.col("a").abs() != 3'),
("a", lambda x: int(x) > 1, 'pl.col("a").cast(pl.Int64) > 1'),
("a", lambda x: not (x > 1) or x == 2, '~(pl.col("a") > 1) | (pl.col("a") == 2)'),
("a", lambda x: x is None, 'pl.col("a") is None'),
Expand Down Expand Up @@ -65,6 +65,11 @@
("a", lambda x: MY_CONSTANT + x, 'MY_CONSTANT + pl.col("a")'),
("a", lambda x: 0 + numpy.cbrt(x), '0 + pl.col("a").cbrt()'),
("a", lambda x: np.sin(x) + 1, 'pl.col("a").sin() + 1'),
(
"a", # note: functions operate on consts
lambda x: np.sin(3.14159265358979) + (x - 1) + abs(-3),
'(np.sin(3.14159265358979) + (pl.col("a") - 1)) + abs(-3)',
),
(
"a",
lambda x: (float(x) * int(x)) // 2,
Expand Down Expand Up @@ -113,6 +118,12 @@
("c", lambda x: json.loads(x), 'pl.col("c").str.json_extract()'),
]

NOOP_TEST_CASES = [
lambda x: x,
lambda x, y: x + y,
lambda x: x[0] + 1,
]


@pytest.mark.parametrize(
("col", "func", "expected"),
Expand All @@ -125,3 +136,12 @@ def test_bytecode_parser_expression(
bytecode_parser = udfs.BytecodeParser(func, apply_target="expr")
result = bytecode_parser.to_expression(col)
assert result == expected


@pytest.mark.parametrize(
"func",
NOOP_TEST_CASES,
)
def test_bytecode_parser_expression_noop(func: Callable[[Any], Any]) -> None:
udfs = pytest.importorskip("udfs")
assert not udfs.BytecodeParser(func, apply_target="expr").can_rewrite()

0 comments on commit db901b4

Please sign in to comment.