Skip to content

Commit

Permalink
add *args for raw=False as well; merge tests together
Browse files Browse the repository at this point in the history
  • Loading branch information
auderson committed May 19, 2024
1 parent c026845 commit 96581a3
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 32 deletions.
2 changes: 1 addition & 1 deletion doc/source/whatsnew/v3.0.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -498,7 +498,7 @@ Other
- Bug in :class:`DataFrame` when passing a ``dict`` with a NA scalar and ``columns`` that would always return ``np.nan`` (:issue:`57205`)
- Bug in :func:`eval` where the names of the :class:`Series` were not preserved when using ``engine="numexpr"``. (:issue:`10239`)
- Bug in :func:`unique` on :class:`Index` not always returning :class:`Index` (:issue:`57043`)
- Bug in :meth:`DataFrame.apply` where passing ``raw=True`` and ``engine="numba"`` ignored ``args`` passed to the applied function (:issue:`58712`)
- Bug in :meth:`DataFrame.apply` where passing ``engine="numba"`` ignored ``args`` passed to the applied function (:issue:`58712`)
- Bug in :meth:`DataFrame.eval` and :meth:`DataFrame.query` which caused an exception when using NumPy attributes via ``@`` notation, e.g., ``df.eval("@np.floor(a)")``. (:issue:`58041`)
- Bug in :meth:`DataFrame.eval` and :meth:`DataFrame.query` which did not allow to use ``tan`` function. (:issue:`55091`)
- Bug in :meth:`DataFrame.sort_index` when passing ``axis="columns"`` and ``ignore_index=True`` and ``ascending=False`` not returning a :class:`RangeIndex` columns (:issue:`57293`)
Expand Down
18 changes: 10 additions & 8 deletions pandas/core/apply.py
Original file line number Diff line number Diff line change
Expand Up @@ -1122,21 +1122,22 @@ def generate_numba_apply_func(
# Currently the parallel argument doesn't get passed through here
# (it's disabled) since the dicts in numba aren't thread-safe.
@numba.jit(nogil=nogil, nopython=nopython, parallel=parallel)
def numba_func(values, col_names, df_index):
def numba_func(values, col_names, df_index, *args):
results = {}
for j in range(values.shape[1]):
# Create the series
ser = Series(
values[:, j], index=df_index, name=maybe_cast_str(col_names[j])
)
results[j] = jitted_udf(ser)
results[j] = jitted_udf(ser, *args)
return results

return numba_func

def apply_with_numba(self) -> dict[int, Any]:
nb_func = self.generate_numba_apply_func(
cast(Callable, self.func), **self.engine_kwargs
cast(Callable, self.func),
**get_jit_arguments(self.engine_kwargs, self.kwargs),
)
from pandas.core._numba.extensions import set_numba_data

Expand All @@ -1151,7 +1152,7 @@ def apply_with_numba(self) -> dict[int, Any]:
# Convert from numba dict to regular dict
# Our isinstance checks in the df constructor don't pass for numbas typed dict
with set_numba_data(index) as index, set_numba_data(columns) as columns:
res = dict(nb_func(self.values, columns, index))
res = dict(nb_func(self.values, columns, index, *self.args))
return res

@property
Expand Down Expand Up @@ -1259,7 +1260,7 @@ def generate_numba_apply_func(
jitted_udf = numba.extending.register_jitable(func)

@numba.jit(nogil=nogil, nopython=nopython, parallel=parallel)
def numba_func(values, col_names_index, index):
def numba_func(values, col_names_index, index, *args):
results = {}
# Currently the parallel argument doesn't get passed through here
# (it's disabled) since the dicts in numba aren't thread-safe.
Expand All @@ -1271,15 +1272,16 @@ def numba_func(values, col_names_index, index):
index=col_names_index,
name=maybe_cast_str(index[i]),
)
results[i] = jitted_udf(ser)
results[i] = jitted_udf(ser, *args)

return results

return numba_func

def apply_with_numba(self) -> dict[int, Any]:
nb_func = self.generate_numba_apply_func(
cast(Callable, self.func), **self.engine_kwargs
cast(Callable, self.func),
**get_jit_arguments(self.engine_kwargs, self.kwargs),
)

from pandas.core._numba.extensions import set_numba_data
Expand All @@ -1290,7 +1292,7 @@ def apply_with_numba(self) -> dict[int, Any]:
set_numba_data(self.obj.index) as index,
set_numba_data(self.columns) as columns,
):
res = dict(nb_func(self.values, columns, index))
res = dict(nb_func(self.values, columns, index, *self.args))

return res

Expand Down
34 changes: 11 additions & 23 deletions pandas/tests/apply/test_frame_apply.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,16 +63,23 @@ def test_apply(float_frame, engine, request):

@pytest.mark.parametrize("axis", [0, 1])
@pytest.mark.parametrize("raw", [True, False])
def test_apply_args(float_frame, axis, raw, engine, request):
if engine == "numba" and raw is False:
mark = pytest.mark.xfail(reason="numba engine doesn't support args")
request.node.add_marker(mark)
def test_apply_args(float_frame, axis, raw, engine):
# GH:58712
result = float_frame.apply(
lambda x, y: x + y, axis, args=(1,), raw=raw, engine=engine
)
expected = float_frame + 1
tm.assert_frame_equal(result, expected)

if engine == "numba":
with pytest.raises(
pd.errors.NumbaUtilError,
match="numba does not support kwargs with nopython=True",
):
float_frame.apply(
lambda x, a, b: x + a + b, args=(1,), b=2, engine=engine, raw=raw
)


def test_apply_categorical_func():
# GH 9573
Expand Down Expand Up @@ -1718,22 +1725,3 @@ def test_agg_dist_like_and_nonunique_columns():
result = df.agg({"A": "count"})
expected = df["A"].count()
tm.assert_series_equal(result, expected)


def test_numba_raw_apply_with_args(engine):
if engine == "numba":
# GH:58712
df = DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]})
result = df.apply(
lambda x, a, b: x + a + b, args=(1, 2), engine=engine, raw=True
)
# note: result is always float dtype,
# see core._numba.executor.py:generate_apply_looper
expected = df + 3.0
tm.assert_frame_equal(result, expected)

with pytest.raises(
pd.errors.NumbaUtilError,
match="numba does not support kwargs with nopython=True",
):
df.apply(lambda x, a, b: x + a + b, args=(1,), b=2, engine=engine, raw=True)

0 comments on commit 96581a3

Please sign in to comment.