Skip to content

Commit

Permalink
ENH add *args support for numba apply (#58767)
Browse files Browse the repository at this point in the history
* add *args for raw numba apply

* add whatsnew

* fix test_case

* fix pre-commit

* fix test case

* add *args for raw=False as well; merge tests together

* add prepare_function_arguments

* fix mypy

* update get_jit_arguments

* add nopython test in `test_apply_args`

* fix test

* fix pre-commit
  • Loading branch information
auderson committed Jun 11, 2024
1 parent 42f785f commit bbe0e53
Show file tree
Hide file tree
Showing 6 changed files with 136 additions and 27 deletions.
1 change: 1 addition & 0 deletions doc/source/whatsnew/v3.0.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -598,6 +598,7 @@ Other
- Bug in :class:`DataFrame` when passing a ``dict`` with a NA scalar and ``columns`` that would always return ``np.nan`` (:issue:`57205`)
- Bug in :func:`eval` where the names of the :class:`Series` were not preserved when using ``engine="numexpr"``. (:issue:`10239`)
- Bug in :func:`unique` on :class:`Index` not always returning :class:`Index` (:issue:`57043`)
- Bug in :meth:`DataFrame.apply` where passing ``engine="numba"`` ignored ``args`` passed to the applied function (:issue:`58712`)
- Bug in :meth:`DataFrame.eval` and :meth:`DataFrame.query` which caused an exception when using NumPy attributes via ``@`` notation, e.g., ``df.eval("@np.floor(a)")``. (:issue:`58041`)
- Bug in :meth:`DataFrame.eval` and :meth:`DataFrame.query` which did not allow to use ``tan`` function. (:issue:`55091`)
- Bug in :meth:`DataFrame.sort_index` when passing ``axis="columns"`` and ``ignore_index=True`` and ``ascending=False`` not returning a :class:`RangeIndex` columns (:issue:`57293`)
Expand Down
12 changes: 7 additions & 5 deletions pandas/core/_numba/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,19 @@

from pandas.compat._optional import import_optional_dependency

from pandas.core.util.numba_ import jit_user_function


@functools.cache
def generate_apply_looper(func, nopython=True, nogil=True, parallel=False):
if TYPE_CHECKING:
import numba
else:
numba = import_optional_dependency("numba")
nb_compat_func = numba.extending.register_jitable(func)
nb_compat_func = jit_user_function(func)

@numba.jit(nopython=nopython, nogil=nogil, parallel=parallel)
def nb_looper(values, axis):
def nb_looper(values, axis, *args):
# Operate on the first row/col in order to get
# the output shape
if axis == 0:
Expand All @@ -33,7 +35,7 @@ def nb_looper(values, axis):
else:
first_elem = values[0]
dim0 = values.shape[0]
res0 = nb_compat_func(first_elem)
res0 = nb_compat_func(first_elem, *args)
# Use np.asarray to get shape for
# https://github.com/numba/numba/issues/4202#issuecomment-1185981507
buf_shape = (dim0,) + np.atleast_1d(np.asarray(res0)).shape
Expand All @@ -44,11 +46,11 @@ def nb_looper(values, axis):
if axis == 1:
buff[0] = res0
for i in numba.prange(1, values.shape[0]):
buff[i] = nb_compat_func(values[i])
buff[i] = nb_compat_func(values[i], *args)
else:
buff[:, 0] = res0
for j in numba.prange(1, values.shape[1]):
buff[:, j] = nb_compat_func(values[:, j])
buff[:, j] = nb_compat_func(values[:, j], *args)
return buff

return nb_looper
Expand Down
36 changes: 23 additions & 13 deletions pandas/core/apply.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,10 @@
from pandas.core._numba.executor import generate_apply_looper
import pandas.core.common as com
from pandas.core.construction import ensure_wrapped_if_datetimelike
from pandas.core.util.numba_ import (
get_jit_arguments,
prepare_function_arguments,
)

if TYPE_CHECKING:
from collections.abc import (
Expand All @@ -70,7 +74,6 @@
from pandas.core.resample import Resampler
from pandas.core.window.rolling import BaseWindow


ResType = dict[int, Any]


Expand Down Expand Up @@ -997,17 +1000,20 @@ def wrapper(*args, **kwargs):
return wrapper

if engine == "numba":
engine_kwargs = {} if engine_kwargs is None else engine_kwargs

args, kwargs = prepare_function_arguments(
self.func, # type: ignore[arg-type]
self.args,
self.kwargs,
)
# error: Argument 1 to "__call__" of "_lru_cache_wrapper" has
# incompatible type "Callable[..., Any] | str | list[Callable
# [..., Any] | str] | dict[Hashable,Callable[..., Any] | str |
# list[Callable[..., Any] | str]]"; expected "Hashable"
nb_looper = generate_apply_looper(
self.func, # type: ignore[arg-type]
**engine_kwargs,
**get_jit_arguments(engine_kwargs, kwargs),
)
result = nb_looper(self.values, self.axis)
result = nb_looper(self.values, self.axis, *args)
# If we made the result 2-D, squeeze it back to 1-D
result = np.squeeze(result)
else:
Expand Down Expand Up @@ -1148,21 +1154,23 @@ def generate_numba_apply_func(
# Currently the parallel argument doesn't get passed through here
# (it's disabled) since the dicts in numba aren't thread-safe.
@numba.jit(nogil=nogil, nopython=nopython, parallel=parallel)
def numba_func(values, col_names, df_index):
def numba_func(values, col_names, df_index, *args):
results = {}
for j in range(values.shape[1]):
# Create the series
ser = Series(
values[:, j], index=df_index, name=maybe_cast_str(col_names[j])
)
results[j] = jitted_udf(ser)
results[j] = jitted_udf(ser, *args)
return results

return numba_func

def apply_with_numba(self) -> dict[int, Any]:
func = cast(Callable, self.func)
args, kwargs = prepare_function_arguments(func, self.args, self.kwargs)
nb_func = self.generate_numba_apply_func(
cast(Callable, self.func), **self.engine_kwargs
func, **get_jit_arguments(self.engine_kwargs, kwargs)
)
from pandas.core._numba.extensions import set_numba_data

Expand All @@ -1177,7 +1185,7 @@ def apply_with_numba(self) -> dict[int, Any]:
# Convert from numba dict to regular dict
# Our isinstance checks in the df constructor don't pass for numbas typed dict
with set_numba_data(index) as index, set_numba_data(columns) as columns:
res = dict(nb_func(self.values, columns, index))
res = dict(nb_func(self.values, columns, index, *args))
return res

@property
Expand Down Expand Up @@ -1285,7 +1293,7 @@ def generate_numba_apply_func(
jitted_udf = numba.extending.register_jitable(func)

@numba.jit(nogil=nogil, nopython=nopython, parallel=parallel)
def numba_func(values, col_names_index, index):
def numba_func(values, col_names_index, index, *args):
results = {}
# Currently the parallel argument doesn't get passed through here
# (it's disabled) since the dicts in numba aren't thread-safe.
Expand All @@ -1297,15 +1305,17 @@ def numba_func(values, col_names_index, index):
index=col_names_index,
name=maybe_cast_str(index[i]),
)
results[i] = jitted_udf(ser)
results[i] = jitted_udf(ser, *args)

return results

return numba_func

def apply_with_numba(self) -> dict[int, Any]:
func = cast(Callable, self.func)
args, kwargs = prepare_function_arguments(func, self.args, self.kwargs)
nb_func = self.generate_numba_apply_func(
cast(Callable, self.func), **self.engine_kwargs
func, **get_jit_arguments(self.engine_kwargs, kwargs)
)

from pandas.core._numba.extensions import set_numba_data
Expand All @@ -1316,7 +1326,7 @@ def apply_with_numba(self) -> dict[int, Any]:
set_numba_data(self.obj.index) as index,
set_numba_data(self.columns) as columns,
):
res = dict(nb_func(self.values, columns, index))
res = dict(nb_func(self.values, columns, index, *args))

return res

Expand Down
56 changes: 53 additions & 3 deletions pandas/core/util/numba_.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from __future__ import annotations

import inspect
import types
from typing import (
TYPE_CHECKING,
Expand Down Expand Up @@ -54,10 +55,15 @@ def get_jit_arguments(
engine_kwargs = {}

nopython = engine_kwargs.get("nopython", True)
if kwargs and nopython:
if kwargs:
# Note: in case numba supports keyword-only arguments in
# a future version, we should remove this check. But this
# seems unlikely to happen soon.

raise NumbaUtilError(
"numba does not support kwargs with nopython=True: "
"https://github.com/numba/numba/issues/2916"
"numba does not support keyword-only arguments"
"https://github.com/numba/numba/issues/2916, "
"https://github.com/numba/numba/issues/6846"
)
nogil = engine_kwargs.get("nogil", False)
parallel = engine_kwargs.get("parallel", False)
Expand Down Expand Up @@ -97,3 +103,47 @@ def jit_user_function(func: Callable) -> Callable:
numba_func = numba.extending.register_jitable(func)

return numba_func


_sentinel = object()


def prepare_function_arguments(
func: Callable, args: tuple, kwargs: dict
) -> tuple[tuple, dict]:
"""
Prepare arguments for jitted function. As numba functions do not support kwargs,
we try to move kwargs into args if possible.
Parameters
----------
func : function
user defined function
args : tuple
user input positional arguments
kwargs : dict
user input keyword arguments
Returns
-------
tuple[tuple, dict]
args, kwargs
"""
if not kwargs:
return args, kwargs

# the udf should have this pattern: def udf(value, *args, **kwargs):...
signature = inspect.signature(func)
arguments = signature.bind(_sentinel, *args, **kwargs)
arguments.apply_defaults()
# Ref: https://peps.python.org/pep-0362/
# Arguments which could be passed as part of either *args or **kwargs
# will be included only in the BoundArguments.args attribute.
args = arguments.args
kwargs = arguments.kwargs

assert args[0] is _sentinel
args = args[1:]

return args, kwargs
54 changes: 49 additions & 5 deletions pandas/tests/apply/test_frame_apply.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,16 +63,60 @@ def test_apply(float_frame, engine, request):

@pytest.mark.parametrize("axis", [0, 1])
@pytest.mark.parametrize("raw", [True, False])
def test_apply_args(float_frame, axis, raw, engine, request):
if engine == "numba":
mark = pytest.mark.xfail(reason="numba engine doesn't support args")
request.node.add_marker(mark)
@pytest.mark.parametrize("nopython", [True, False])
def test_apply_args(float_frame, axis, raw, engine, nopython):
engine_kwargs = {"nopython": nopython}
result = float_frame.apply(
lambda x, y: x + y, axis, args=(1,), raw=raw, engine=engine
lambda x, y: x + y,
axis,
args=(1,),
raw=raw,
engine=engine,
engine_kwargs=engine_kwargs,
)
expected = float_frame + 1
tm.assert_frame_equal(result, expected)

# GH:58712
result = float_frame.apply(
lambda x, a, b: x + a + b,
args=(1,),
b=2,
raw=raw,
engine=engine,
engine_kwargs=engine_kwargs,
)
expected = float_frame + 3
tm.assert_frame_equal(result, expected)

if engine == "numba":
# keyword-only arguments are not supported in numba
with pytest.raises(
pd.errors.NumbaUtilError,
match="numba does not support keyword-only arguments",
):
float_frame.apply(
lambda x, a, *, b: x + a + b,
args=(1,),
b=2,
raw=raw,
engine=engine,
engine_kwargs=engine_kwargs,
)

with pytest.raises(
pd.errors.NumbaUtilError,
match="numba does not support keyword-only arguments",
):
float_frame.apply(
lambda *x, b: x[0] + x[1] + b,
args=(1,),
b=2,
raw=raw,
engine=engine,
engine_kwargs=engine_kwargs,
)


def test_apply_categorical_func():
# GH 9573
Expand Down
4 changes: 3 additions & 1 deletion pandas/tests/window/test_numba.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,9 @@ def f(x):

@td.skip_if_no("numba")
def test_invalid_kwargs_nopython():
with pytest.raises(NumbaUtilError, match="numba does not support kwargs with"):
with pytest.raises(
NumbaUtilError, match="numba does not support keyword-only arguments"
):
Series(range(1)).rolling(1).apply(
lambda x: x, kwargs={"a": 1}, engine="numba", raw=True
)
Expand Down

0 comments on commit bbe0e53

Please sign in to comment.