Skip to content

Commit

Permalink
add prepare_function_arguments
Browse files Browse the repository at this point in the history
  • Loading branch information
auderson committed May 21, 2024
1 parent 96581a3 commit 2aae933
Show file tree
Hide file tree
Showing 4 changed files with 78 additions and 10 deletions.
4 changes: 3 additions & 1 deletion pandas/core/_numba/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,16 @@

from pandas.compat._optional import import_optional_dependency

from pandas.core.util.numba_ import jit_user_function


@functools.cache
def generate_apply_looper(func, nopython=True, nogil=True, parallel=False):
if TYPE_CHECKING:
import numba
else:
numba = import_optional_dependency("numba")
nb_compat_func = numba.extending.register_jitable(func)
nb_compat_func = jit_user_function(func)

@numba.jit(nopython=nopython, nogil=nogil, parallel=parallel)
def nb_looper(values, axis, *args):
Expand Down
20 changes: 13 additions & 7 deletions pandas/core/apply.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,10 @@
from pandas.core._numba.executor import generate_apply_looper
import pandas.core.common as com
from pandas.core.construction import ensure_wrapped_if_datetimelike
from pandas.core.util.numba_ import get_jit_arguments
from pandas.core.util.numba_ import (
get_jit_arguments,
prepare_function_arguments,
)

if TYPE_CHECKING:
from collections.abc import (
Expand Down Expand Up @@ -973,15 +976,16 @@ def wrapper(*args, **kwargs):
return wrapper

if engine == "numba":
args, kwargs = prepare_function_arguments(self.func, self.args, self.kwargs)
# error: Argument 1 to "__call__" of "_lru_cache_wrapper" has
# incompatible type "Callable[..., Any] | str | list[Callable
# [..., Any] | str] | dict[Hashable,Callable[..., Any] | str |
# list[Callable[..., Any] | str]]"; expected "Hashable"
nb_looper = generate_apply_looper(
self.func, # type: ignore[arg-type]
**get_jit_arguments(engine_kwargs, self.kwargs),
**get_jit_arguments(engine_kwargs, kwargs),
)
result = nb_looper(self.values, self.axis, *self.args)
result = nb_looper(self.values, self.axis, *args)
# If we made the result 2-D, squeeze it back to 1-D
result = np.squeeze(result)
else:
Expand Down Expand Up @@ -1135,9 +1139,10 @@ def numba_func(values, col_names, df_index, *args):
return numba_func

def apply_with_numba(self) -> dict[int, Any]:
args, kwargs = prepare_function_arguments(self.func, self.args, self.kwargs)
nb_func = self.generate_numba_apply_func(
cast(Callable, self.func),
**get_jit_arguments(self.engine_kwargs, self.kwargs),
**get_jit_arguments(self.engine_kwargs, kwargs),
)
from pandas.core._numba.extensions import set_numba_data

Expand All @@ -1152,7 +1157,7 @@ def apply_with_numba(self) -> dict[int, Any]:
# Convert from numba dict to regular dict
# Our isinstance checks in the df constructor don't pass for numbas typed dict
with set_numba_data(index) as index, set_numba_data(columns) as columns:
res = dict(nb_func(self.values, columns, index, *self.args))
res = dict(nb_func(self.values, columns, index, *args))
return res

@property
Expand Down Expand Up @@ -1279,9 +1284,10 @@ def numba_func(values, col_names_index, index, *args):
return numba_func

def apply_with_numba(self) -> dict[int, Any]:
args, kwargs = prepare_function_arguments(self.func, self.args, self.kwargs)
nb_func = self.generate_numba_apply_func(
cast(Callable, self.func),
**get_jit_arguments(self.engine_kwargs, self.kwargs),
**get_jit_arguments(self.engine_kwargs, kwargs),
)

from pandas.core._numba.extensions import set_numba_data
Expand All @@ -1292,7 +1298,7 @@ def apply_with_numba(self) -> dict[int, Any]:
set_numba_data(self.obj.index) as index,
set_numba_data(self.columns) as columns,
):
res = dict(nb_func(self.values, columns, index, *self.args))
res = dict(nb_func(self.values, columns, index, *args))

return res

Expand Down
45 changes: 45 additions & 0 deletions pandas/core/util/numba_.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from __future__ import annotations

import inspect
import types
from typing import (
TYPE_CHECKING,
Expand Down Expand Up @@ -97,3 +98,47 @@ def jit_user_function(func: Callable) -> Callable:
numba_func = numba.extending.register_jitable(func)

return numba_func


_sentinel = object()


def prepare_function_arguments(
func: Callable, args: tuple, kwargs: dict
) -> tuple[tuple, dict]:
"""
Prepare arguments for jitted function. As numba functions do not support kwargs,
we try to move kwargs into args if possible.
Parameters
----------
func : function
user defined function
args : tuple
user input positional arguments
kwargs : dict
user input keyword arguments
Returns
-------
tuple[tuple, dict]
args, kwargs
"""
if not kwargs:
return args, kwargs

# the udf should have this pattern: def udf(value, *args, **kwargs):...
signature = inspect.signature(func)
arguments = signature.bind(_sentinel, *args, **kwargs)
arguments.apply_defaults()
# Ref: https://peps.python.org/pep-0362/
# Arguments which could be passed as part of either *args or **kwargs
# will be included only in the BoundArguments.args attribute.
args = arguments.args
kwargs = arguments.kwargs

assert args[0] is _sentinel
args = args[1:]

return args, kwargs
19 changes: 17 additions & 2 deletions pandas/tests/apply/test_frame_apply.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,20 +64,35 @@ def test_apply(float_frame, engine, request):
@pytest.mark.parametrize("axis", [0, 1])
@pytest.mark.parametrize("raw", [True, False])
def test_apply_args(float_frame, axis, raw, engine):
# GH:58712
result = float_frame.apply(
lambda x, y: x + y, axis, args=(1,), raw=raw, engine=engine
)
expected = float_frame + 1
tm.assert_frame_equal(result, expected)

# GH:58712
result = float_frame.apply(
lambda x, a, b: x + a + b, args=(1,), b=2, engine=engine, raw=raw
)
expected = float_frame + 3
tm.assert_frame_equal(result, expected)

if engine == "numba":
# keyword-only arguments are not supported in numba
with pytest.raises(
pd.errors.NumbaUtilError,
match="numba does not support kwargs with nopython=True",
):
float_frame.apply(
lambda x, a, *, b: x + a + b, args=(1,), b=2, engine=engine, raw=raw
)

with pytest.raises(
pd.errors.NumbaUtilError,
match="numba does not support kwargs with nopython=True",
):
float_frame.apply(
lambda x, a, b: x + a + b, args=(1,), b=2, engine=engine, raw=raw
lambda *x, b: x[0] + x[1] + b, args=(1,), b=2, engine=engine, raw=raw
)


Expand Down

0 comments on commit 2aae933

Please sign in to comment.