Skip to content

Commit

Permalink
Add benchmark for numba elemwise
Browse files Browse the repository at this point in the history
  • Loading branch information
aseyboldt committed Dec 29, 2022
1 parent c0710c9 commit 683479c
Showing 1 changed file with 20 additions and 1 deletion.
21 changes: 20 additions & 1 deletion tests/link/numba/test_elemwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import pytensor.tensor as at
import pytensor.tensor.inplace as ati
import pytensor.tensor.math as aem
from pytensor import config
from pytensor import config, function
from pytensor.compile.ops import deep_copy_op
from pytensor.compile.sharedvalue import SharedVariable
from pytensor.graph.basic import Constant
Expand Down Expand Up @@ -117,6 +117,25 @@ def test_Elemwise(inputs, input_vals, output_fn, exc):
compare_numba_and_py(out_fg, input_vals)


def test_elemwise_speed(benchmark):
x = at.dmatrix("y")
y = at.dvector("z")

out = np.exp(2 * x * y + y)

rng = np.random.default_rng(42)

x_val = rng.normal(size=(200, 500))
y_val = rng.normal(size=500)

func = function([x, y], out, mode="NUMBA")
func = func.vm.jit_fn
(out,) = func(x_val, y_val)
np.testing.assert_allclose(np.exp(2 * x_val * y_val + y_val), out)

benchmark(func, x_val, y_val)


@pytest.mark.parametrize(
"v, new_order",
[
Expand Down

0 comments on commit 683479c

Please sign in to comment.