Skip to content

Commit

Permalink
Enable no-cpython-wrapper in numba where possible (#765)
Browse files Browse the repository at this point in the history
* Enable no-cpython-wrapper in numba where possible

* Fix test with no_cpython_wrapper

* Add docstring to numba_funcify
  • Loading branch information
aseyboldt committed May 13, 2024
1 parent 15b90be commit 3ed2c49
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 5 deletions.
9 changes: 8 additions & 1 deletion pytensor/link/numba/dispatch/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ def global_numba_func(func):

def numba_njit(*args, **kwargs):
kwargs.setdefault("cache", config.numba__cache)
kwargs.setdefault("no_cpython_wrapper", True)
kwargs.setdefault("no_cfunc_wrapper", True)

# Supress caching warnings
warnings.filterwarnings(
Expand Down Expand Up @@ -419,7 +421,12 @@ def perform(*inputs):

@singledispatch
def numba_funcify(op, node=None, storage_map=None, **kwargs):
"""Generate a numba function for a given op and apply node."""
"""Generate a numba function for a given op and apply node.
The resulting function will usually use the `no_cpython_wrapper`
argument in numba, so it can not be called directly from python,
but only from other jit functions.
"""
return generate_fallback_impl(op, node, storage_map, **kwargs)


Expand Down
13 changes: 11 additions & 2 deletions pytensor/link/numba/dispatch/elemwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -470,7 +470,9 @@ def axis_apply_fn(x):
"afn", # Approximate functions
"reassoc",
"nsz", # TODO Do we want this one?
}
},
"no_cpython_wrapper": True,
"no_cfunc_wrapper": True,
}


Expand Down Expand Up @@ -698,7 +700,14 @@ def elemwise(*inputs):
return tuple(outputs_summed)
return outputs_summed[0]

@overload(elemwise)
@overload(
elemwise,
jit_options={
"fastmath": flags,
"no_cpython_wrapper": True,
"no_cfunc_wrapper": True,
},
)
def ov_elemwise(*inputs):
return elemwise_wrapper

Expand Down
2 changes: 1 addition & 1 deletion pytensor/link/numba/linker.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def fgraph_convert(self, fgraph, **kwargs):
def jit_compile(self, fn):
from pytensor.link.numba.dispatch.basic import numba_njit

jitted_fn = numba_njit(fn)
jitted_fn = numba_njit(fn, no_cpython_wrapper=False, no_cfunc_wrapper=False)
return jitted_fn

def create_thunk_inputs(self, storage_map):
Expand Down
9 changes: 8 additions & 1 deletion tests/link/numba/test_tensor_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,6 +386,8 @@ def test_ExtractDiag(val, offset):
)
@pytest.mark.parametrize("reverse_axis", (False, True))
def test_ExtractDiag_exhaustive(k, axis1, axis2, reverse_axis):
from pytensor.link.numba.dispatch.basic import numba_njit

if reverse_axis:
axis1, axis2 = axis2, axis1

Expand All @@ -394,7 +396,12 @@ def test_ExtractDiag_exhaustive(k, axis1, axis2, reverse_axis):
x_test = np.arange(np.prod(x_shape)).reshape(x_shape)
out = pt.diagonal(x, k, axis1, axis2)
numba_fn = numba_funcify(out.owner.op, out.owner)
np.testing.assert_allclose(numba_fn(x_test), np.diagonal(x_test, k, axis1, axis2))

@numba_njit(no_cpython_wrapper=False)
def wrap(x):
return numba_fn(x)

np.testing.assert_allclose(wrap(x_test), np.diagonal(x_test, k, axis1, axis2))


@pytest.mark.parametrize(
Expand Down

0 comments on commit 3ed2c49

Please sign in to comment.