Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH add *args support for numba apply #58767

Merged
merged 12 commits into from
Jun 11, 2024
Merged

Conversation

auderson
Copy link
Contributor

@auderson auderson commented May 18, 2024

added *args to nb_looper().
use get_jit_arguments() to disallow **kwargs in raw + numba engine.

@auderson
Copy link
Contributor Author

@lithomas1
I notice that *args support for raw=False + numba engine can be also added, should I open a another PR or just put them together?

@lithomas1
Copy link
Member

@lithomas1 I notice that *args support for raw=False + numba engine can be also added, should I open a another PR or just put them together?

If it's not too involved to add support there, you're welcome to put it in the same PR.

Thanks for the quick PR btw.

@auderson auderson changed the title add *args for raw numba apply add *args support for numba apply May 19, 2024
@auderson
Copy link
Contributor Author

It's all done. Could you take a look when you have time?

@rhshadrach rhshadrach added Enhancement Apply Apply, Aggregate, Transform, Map numba numba-accelerated operations labels May 19, 2024
@lithomas1
Copy link
Member

Does this also handle **kwargs by e.g. passing them as *args?

It might be worth exploring that.

This LGTM otherwise.

@auderson
Copy link
Contributor Author

That's a good point, I'll have a try later when I have time.

@auderson auderson changed the title add *args support for numba apply ENH add *args support for numba apply May 20, 2024
@auderson
Copy link
Contributor Author

BTW, since this PR was labeled as enhancement, should I move it to the enhancement section in whatsnew?

@auderson
Copy link
Contributor Author

@lithomas1
Hi, I added an util function as you suggested. It would be nice if you spare some time to have a look at it.

match="numba does not support kwargs with nopython=True",
):
float_frame.apply(
lambda x, a, *, b: x + a + b, args=(1,), b=2, engine=engine, raw=raw
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this work if you pass nopython=False in the engine_kwargs?

Copy link
Contributor Author

@auderson auderson May 24, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The pandas test raises missing a required argument error when nopython=False.

FAILED
pandas/tests/apply/test_frame_apply.py:63 (test_apply_args[numba-False-True-0])
float_frame =                A         B         C         D
foo_0   0.189053 -0.522748 -0.413064 -2.441467
foo_1   1.799707  1.1441...47218  0.968478 -0.955145
foo_28  0.354112 -1.968397  0.899274 -0.158248
foo_29 -0.967681  1.678419  0.765355  0.045808
axis = 0, raw = True, engine = 'numba', nopython = False

    @pytest.mark.parametrize("axis", [0, 1])
    @pytest.mark.parametrize("raw", [True, False])
    @pytest.mark.parametrize("nopython", [True, False])
    def test_apply_args(float_frame, axis, raw, engine, nopython):
        engine_kwargs = {"nopython": nopython}
        result = float_frame.apply(
            lambda x, y: x + y, axis, args=(1,), raw=raw, engine=engine, engine_kwargs=engine_kwargs
        )
        expected = float_frame + 1
        tm.assert_frame_equal(result, expected)
    
        # GH:58712
        result = float_frame.apply(
            lambda x, a, b: x + a + b, args=(1,), b=2, raw=raw, engine=engine, engine_kwargs=engine_kwargs
        )
        expected = float_frame + 3
        tm.assert_frame_equal(result, expected)
    
        if engine == "numba":
            # keyword-only arguments are not supported in numba
            with pytest.raises(
                pd.errors.NumbaUtilError,
                match="numba does not support kwargs with nopython=True",
            ):
>               float_frame.apply(
                    lambda x, a, *, b: x + a + b, args=(1,), b=2, raw=raw, engine=engine, engine_kwargs=engine_kwargs
                )

test_frame_apply.py:88: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
../../core/frame.py:10353: in apply
    return op.apply().__finalize__(self, method="apply")
../../core/apply.py:886: in apply
    return self.apply_raw(engine=self.engine, engine_kwargs=self.engine_kwargs)
../../core/apply.py:991: in apply_raw
    result = nb_looper(self.values, self.axis, *args)
/home/auderson/miniconda3/envs/pandas-dev/lib/python3.10/site-packages/numba/core/dispatcher.py:468: in _compile_for_args
    error_rewrite(e, 'typing')
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

e = TypingError('Failed in nopython mode pipeline (step: nopython frontend)\n\x1b[1m\x1b[1m\x1b[1mNo implementation of fun...0 = values.shape[0]\n\x1b[1m        res0 = nb_compat_func(first_elem, *args)\n\x1b[0m        \x1b[1m^\x1b[0m\x1b[0m\n')
issue_type = 'typing'

    def error_rewrite(e, issue_type):
        """
        Rewrite and raise Exception `e` with help supplied based on the
        specified issue_type.
        """
        if config.SHOW_HELP:
            help_msg = errors.error_extras[issue_type]
            e.patch_message('\n'.join((str(e).rstrip(), help_msg)))
        if config.FULL_TRACEBACKS:
            raise e
        else:
>           raise e.with_traceback(None)
E           numba.core.errors.TypingError: Failed in nopython mode pipeline (step: nopython frontend)
E           No implementation of function Function(<function test_apply_args.<locals>.<lambda> at 0x7f0d2f6081f0>) found for signature:
E            
E            >>> <lambda>(readonly array(float64, 1d, A), int64)
E            
E           There are 2 candidate implementations:
E               - Of which 2 did not match due to:
E               Overload in function 'register_jitable.<locals>.wrap.<locals>.ov_wrap': File: numba/core/extending.py: Line 161.
E                 With argument(s): '(readonly array(float64, 1d, A), int64)':
E                Rejected as the implementation raised a specific error:
E                  TypeError: missing a required argument: 'a'
E             raised from /home/auderson/miniconda3/envs/pandas-dev/lib/python3.10/inspect.py:3101
E           
E           During: resolving callee type: Function(<function test_apply_args.<locals>.<lambda> at 0x7f0d2f6081f0>)
E           During: typing of call at /mnt/c/Users/auderson/Documents/Works/OpenSourcePackages/pandas/pandas/core/_numba/executor.py (38)
E           
E           
E           File "../../core/_numba/executor.py", line 38:
E               def nb_looper(values, axis, *args):
E                   <source elided>
E                       dim0 = values.shape[0]
E                   res0 = nb_compat_func(first_elem, *args)
E                   ^

/home/auderson/miniconda3/envs/pandas-dev/lib/python3.10/site-packages/numba/core/dispatcher.py:409: TypingError

An equivalent reproducer:

from numba import jit

@jit(nopython=False)
def foo(*args, kwarg=None):
    print(args)
    print(kwarg)

@jit(nopython=False)
def bar(a, *args):
    foo(a, *args, kwarg='foobar')

bar(1, 2, 3)
UnsupportedError: Failed in object mode pipeline (step: analyzing bytecode)

CALL_FUNCTION_EX with **kwargs not supported.
If you are not using **kwargs this may indicate that
you have a large number of kwargs and are using inlined control
flow. You can resolve this issue by moving the control flow out of
the function call. For example, if you have

    f(a=1 if flag else 0, ...)

Replace that with:

    a_val = 1 if flag else 0
    f(a=a_val, ...)

Copy link
Contributor Author

@auderson auderson May 24, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe here need to be modified:

if kwargs and nopython:
raise NumbaUtilError(
"numba does not support kwargs with nopython=True: "
"https://github.com/numba/numba/issues/2916"
)

Regardless nopython is True or False, the "numba does not support kwargs" should be raised:

 if kwargs: 
     raise NumbaUtilError( 
         "numba does not support keyword-only arguments" 
         "https://github.com/numba/numba/issues/2916" 
     ) 

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch, updating the error message is probably the right move here.

@auderson
Copy link
Contributor Author

We can also extend prepare_function_arguments by adding a option: num_required_args, so that all the numba functions' args & kwargs stuff can be handled by this function, instead of in get_jit_arguments .
For example in groupby, we have numba_func(group, group_index, *args), in this case num_required_args = 2:

@numba.jit(nopython=nopython, nogil=nogil, parallel=parallel)
def group_agg(
values: np.ndarray,
index: np.ndarray,
begin: np.ndarray,
end: np.ndarray,
num_columns: int,
*args: Any,
) -> np.ndarray:
assert len(begin) == len(end)
num_groups = len(begin)
result = np.empty((num_groups, num_columns))
for i in numba.prange(num_groups):
group_index = index[begin[i] : end[i]]
for j in numba.prange(num_columns):
group = values[begin[i] : end[i], j]
result[i, j] = numba_func(group, group_index, *args)
return result

@lithomas1
Copy link
Member

We can also extend prepare_function_arguments by adding a option: num_required_args, so that all the numba functions' args & kwargs stuff can be handled by this function, instead of in get_jit_arguments . For example in groupby, we have numba_func(group, group_index, *args), in this case num_required_args = 2:

@numba.jit(nopython=nopython, nogil=nogil, parallel=parallel)
def group_agg(
values: np.ndarray,
index: np.ndarray,
begin: np.ndarray,
end: np.ndarray,
num_columns: int,
*args: Any,
) -> np.ndarray:
assert len(begin) == len(end)
num_groups = len(begin)
result = np.empty((num_groups, num_columns))
for i in numba.prange(num_groups):
group_index = index[begin[i] : end[i]]
for j in numba.prange(num_columns):
group = values[begin[i] : end[i], j]
result[i, j] = numba_func(group, group_index, *args)
return result

Good idea.

Do you mind doing this in a separate PR after this?

@auderson
Copy link
Contributor Author

Good idea.

Do you mind doing this in a separate PR after this?

Sure, I'll do that after I finish this PR.

@auderson
Copy link
Contributor Author

Looks all green :) @lithomas1

@auderson
Copy link
Contributor Author

auderson commented Jun 5, 2024

@lithomas1 Hi, if you may have time, please take a look so that I will start the other PR which enhances prepare_function_arguments.

Copy link
Member

@lithomas1 lithomas1 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @auderson

@lithomas1 lithomas1 added this to the 3.0 milestone Jun 10, 2024
@mroeschke mroeschke merged commit bbe0e53 into pandas-dev:main Jun 11, 2024
47 checks passed
@mroeschke
Copy link
Member

Thanks @auderson

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Apply Apply, Aggregate, Transform, Map Enhancement numba numba-accelerated operations
Projects
None yet
Development

Successfully merging this pull request may close these issues.

BUG: No kwargs in df.apply(raw=True, engine="numba")
4 participants