In [139]:
import numba
import numpy as np
from typing import Callable, Tuple, Any
from inspect import signature

import llvmlite.binding as llvm
# Hand LLVM the `--debug-only=loop-vectorize` flag so that compiling the jitted
# functions below prints loop-vectorizer diagnostics.
# NOTE(review): per the Numba FAQ this flag needs an llvmlite build linked
# against an LLVM with assertions enabled — confirm for the environment in use.
llvm.set_option('', '--debug-only=loop-vectorize')

In [102]:
# Define the test data passed to every timed call in this notebook.
# nx, ny: extent of the grid along the first and second array axis.
nx, ny = 10000, 10000
# dx and g are the scalar factors the loop bodies use in -g * (difference) / dx.
dx = 1.
g = 1.
# Random input field of shape (nx, ny). No seed is set in this cell, so the
# values differ between runs.
eta = np.random.rand(nx, ny)

In [103]:
# Different numba options
@numba.njit
def _iterate_over_grid_2D_plain(
    loop_body: Callable[..., float], ni: int, nj: int, args: Tuple[Any, ...]
) -> np.ndarray:
    """Evaluate ``loop_body`` at every index of an ``(ni, nj)`` grid.

    Compiled with ``numba.njit`` using its default options.

    Parameters
    ----------
    loop_body
        Callable invoked as ``loop_body(*args, i, j, ni, nj)``; its return
        value is stored in ``result[i, j]``.
    ni, nj
        Extent of the grid along the first and second axis.
    args
        Tuple of leading arguments forwarded unchanged to ``loop_body``.

    Returns
    -------
    np.ndarray
        Array of shape ``(ni, nj)`` holding one ``loop_body`` result per index.
    """
    result = np.empty((ni, nj))
    for i in range(ni):
        for j in range(nj):
            result[i, j] = loop_body(*args, i, j, ni, nj)
    return result


@numba.njit
def _zonal_pressure_gradient_loop_body_plain(
    eta: np.ndarray, g: float, dx: float, i: int, j: int, ni: int, nj: int
) -> float:
    """Return ``-g * (eta[ip1, j] - eta[i, j]) / dx`` for grid point ``(i, j)``.

    ``ip1`` is ``i + 1`` wrapped back to ``0`` when it reaches ``ni``, i.e. the
    difference is periodic along the first index. ``nj`` is accepted so the
    signature matches what the grid iterator passes, but it is not used here.
    Compiled with ``numba.njit`` using its default options.
    """
    # Neighbour index along the first axis, wrapping around at ni.
    ip1 = (i + 1) % ni
    return -g * (eta[ip1, j] - eta[i, j]) / dx


@numba.njit(fastmath=True)
def _iterate_over_grid_2D_fastmath(
    loop_body: Callable[..., float], ni: int, nj: int, args: Tuple[Any, ...]
) -> np.ndarray:
    """Evaluate ``loop_body`` at every index of an ``(ni, nj)`` grid.

    Same loop as ``_iterate_over_grid_2D_plain``; the only difference is that
    this variant is compiled with numba's ``fastmath=True`` option.

    Parameters
    ----------
    loop_body
        Callable invoked as ``loop_body(*args, i, j, ni, nj)``; its return
        value is stored in ``result[i, j]``.
    ni, nj
        Extent of the grid along the first and second axis.
    args
        Tuple of leading arguments forwarded unchanged to ``loop_body``.

    Returns
    -------
    np.ndarray
        Array of shape ``(ni, nj)`` holding one ``loop_body`` result per index.
    """
    result = np.empty((ni, nj))
    for i in range(ni):
        for j in range(nj):
            result[i, j] = loop_body(*args, i, j, ni, nj)
    return result


@numba.njit(fastmath=True)
def _zonal_pressure_gradient_loop_body_fastmath(
    eta: np.ndarray, g: float, dx: float, i: int, j: int, ni: int, nj: int
) -> float:
    """Return ``-g * (eta[ip1, j] - eta[i, j]) / dx`` for grid point ``(i, j)``.

    ``ip1`` is ``i + 1`` wrapped back to ``0`` when it reaches ``ni``. ``nj``
    is accepted but not used. Identical to the plain loop body except that it
    is compiled with numba's ``fastmath=True`` option.
    """
    # Neighbour index along the first axis, wrapping around at ni.
    ip1 = (i + 1) % ni
    return -g * (eta[ip1, j] - eta[i, j]) / dx


In [104]:
%timeit _iterate_over_grid_2D_plain(_zonal_pressure_gradient_loop_body_plain, nx, ny, (eta, g, dx))

612 ms ± 6.62 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [105]:
%timeit _iterate_over_grid_2D_fastmath(_zonal_pressure_gradient_loop_body_fastmath, nx, ny, (eta, g, dx))

662 ms ± 11.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [155]:
# impact of inlining
@numba.njit
def _iterate_over_grid_2D_manual_inlined(
    ni: int, nj: int, eta: np.ndarray, g: float, dx: float
) -> np.ndarray:
    """Fill an ``(ni, nj)`` array with ``-g * (eta[ip1, j] - eta[i, j]) / dx``.

    The per-point expression is written directly inside the double loop
    instead of being delegated to a separate loop-body function.

    Parameters
    ----------
    ni, nj
        Extent of the grid along the first and second axis.
    eta
        2D input field, indexed as ``eta[i, j]``.
    g, dx
        Scalar factors of the expression.

    Returns
    -------
    np.ndarray
        Array of shape ``(ni, nj)``.
    """
    result = np.empty((ni, nj))
    for i in range(ni):
        for j in range(nj):
            # Neighbour index along the first axis, wrapping around at ni.
            ip1 = (i + 1) % ni
            result[i, j] = -g * (eta[ip1, j] - eta[i, j]) / dx
    return result

%timeit _iterate_over_grid_2D_manual_inlined(nx, ny, eta, g, dx)

474 ms ± 6.21 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [107]:
@numba.njit
def _iterate_over_grid_2D_default_inline(
    ni: int, nj: int, eta: np.ndarray, g: float, dx: float
) -> np.ndarray:
    """Fill an ``(ni, nj)`` array by calling the default-inline loop body.

    Each element is the return value of
    ``_zonal_pressure_gradient_loop_body_default_inline(eta, g, dx, i, j, ni, nj)``,
    a jitted function declared without an explicit ``inline`` option.

    Parameters
    ----------
    ni, nj
        Extent of the grid along the first and second axis.
    eta, g, dx
        Forwarded unchanged to the loop body.

    Returns
    -------
    np.ndarray
        Array of shape ``(ni, nj)``.
    """
    result = np.empty((ni, nj))
    for i in range(ni):
        for j in range(nj):
            result[i, j] = _zonal_pressure_gradient_loop_body_default_inline(eta, g, dx, i, j, ni, nj)
    return result

@numba.njit
def _zonal_pressure_gradient_loop_body_default_inline(
    eta: np.ndarray, g: float, dx: float, i: int, j: int, ni: int, nj: int
) -> float:
    """Return ``-g * (eta[ip1, j] - eta[i, j]) / dx`` for grid point ``(i, j)``.

    ``ip1`` is ``i + 1`` wrapped back to ``0`` when it reaches ``ni``. ``nj``
    is accepted but not used. No ``inline`` option is passed to ``numba.njit``.
    """
    # Neighbour index along the first axis, wrapping around at ni.
    ip1 = (i + 1) % ni
    return -g * (eta[ip1, j] - eta[i, j]) / dx

%timeit _iterate_over_grid_2D_default_inline(nx, ny, eta, g, dx)

619 ms ± 4.11 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [108]:
@numba.njit()
def _iterate_over_grid_2D_always_inline(
    ni: int, nj: int, eta: np.ndarray, g: float, dx: float
) -> np.ndarray:
    """Fill an ``(ni, nj)`` array by calling the always-inline loop body.

    Each element is the return value of
    ``_zonal_pressure_gradient_loop_body_always_inline(eta, g, dx, i, j, ni, nj)``,
    which is declared with ``inline='always'``.

    Parameters
    ----------
    ni, nj
        Extent of the grid along the first and second axis.
    eta, g, dx
        Forwarded unchanged to the loop body.

    Returns
    -------
    np.ndarray
        Array of shape ``(ni, nj)``.
    """
    result = np.empty((ni, nj))
    for i in range(ni):
        for j in range(nj):
            result[i, j] = _zonal_pressure_gradient_loop_body_always_inline(eta, g, dx, i, j, ni, nj)
    return result

@numba.njit(inline='always')
def _zonal_pressure_gradient_loop_body_always_inline(
    eta: np.ndarray, g: float, dx: float, i: int, j: int, ni: int, nj: int
) -> float:
    """Return ``-g * (eta[ip1, j] - eta[i, j]) / dx`` for grid point ``(i, j)``.

    ``ip1`` is ``i + 1`` wrapped back to ``0`` when it reaches ``ni``. ``nj``
    is accepted but not used. Declared with ``inline='always'``, which asks
    numba to inline this function into jitted callers.
    """
    # Neighbour index along the first axis, wrapping around at ni.
    ip1 = (i + 1) % ni
    return -g * (eta[ip1, j] - eta[i, j]) / dx

%timeit _iterate_over_grid_2D_always_inline(nx, ny, eta, g, dx)

464 ms ± 2.67 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [109]:
@numba.njit
def _iterate_over_grid_2D_callback(
    loop_body: Callable[..., float],
    ni: int, nj: int, eta: np.ndarray, g: float, dx: float
) -> np.ndarray:
    """Fill an ``(ni, nj)`` array using a loop body received as an argument.

    Unlike the variants that reference a loop body by its global name, the
    function to call is passed in at call time as ``loop_body``.

    Parameters
    ----------
    loop_body
        Callable invoked as ``loop_body(eta, g, dx, i, j, ni, nj)``; its
        return value is stored in ``result[i, j]``.
    ni, nj
        Extent of the grid along the first and second axis.
    eta, g, dx
        Forwarded unchanged to ``loop_body``.

    Returns
    -------
    np.ndarray
        Array of shape ``(ni, nj)``.
    """
    result = np.empty((ni, nj))
    for i in range(ni):
        for j in range(nj):
            result[i, j] = loop_body(eta, g, dx, i, j, ni, nj)
    return result

%timeit _iterate_over_grid_2D_callback(_zonal_pressure_gradient_loop_body_always_inline, nx, ny, eta, g, dx)

622 ms ± 4.13 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [110]:
def _iterate_over_grid_2D_decorator(func):
    """Return a jitted grid iterator that has ``func`` bound via closure.

    Parameters
    ----------
    func
        Loop body invoked as ``func(eta, g, dx, i, j, ni, nj)``.

    Returns
    -------
    callable
        ``wrapper(ni, nj, eta, g, dx)``, which returns an ``(ni, nj)`` array
        holding one ``func`` result per index.
    """
    @numba.njit
    def wrapper(ni: int, nj: int, eta: np.ndarray, g: float, dx: float) -> np.ndarray:
        # `func` is captured from the enclosing scope rather than passed in as
        # a runtime argument.
        result = np.empty((ni, nj))
        for i in range(ni):
            for j in range(nj):
                result[i, j] = func(eta, g, dx, i, j, ni, nj)
        return result
    return wrapper

# Grid iterator specialised to the always-inline loop body.
_iterate_over_grid_2D_decorated = _iterate_over_grid_2D_decorator(_zonal_pressure_gradient_loop_body_always_inline)

In [111]:
%timeit _iterate_over_grid_2D_decorated(nx, ny, eta, g, dx)

469 ms ± 5.67 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [126]:

# NOTE: `_make_2D_grid_iterator_vararg` is defined again in later cells of this
# notebook; whichever definition was executed last is the one in effect.
def _make_2D_grid_iterator_vararg(func):
    """Return a jitted grid iterator that takes the loop-body inputs as ``*args``.

    This variant builds the full argument tuple explicitly and star-unpacks it
    into ``func``.

    Parameters
    ----------
    func
        Loop body invoked as ``func(args[0], args[1], args[2], i, j, ni, nj)``.

    Returns
    -------
    callable
        ``f(ni, nj, *args)`` returning an ``(ni, nj)`` array. Exactly the first
        three entries of ``args`` are forwarded to ``func``.
    """
    @numba.njit
    def _interate_over_grid_2D(ni: int, nj: int, *args: Any):
        result = np.empty((ni, nj))
        for i in range(ni):
            for j in range(nj):
                # Leading user arguments followed by the grid indices/extents.
                f_args = (args[0], args[1], args[2], i, j, ni, nj)
                result[i, j] = func(*f_args)
        return result
    return _interate_over_grid_2D

# Grid iterator specialised to the always-inline loop body.
_iterate_over_grid_2D_decorated_vararg = _make_2D_grid_iterator_vararg(_zonal_pressure_gradient_loop_body_always_inline)

In [144]:
# Inspect how many parameters the loop body declares (displayed as the cell output).
len(signature(_zonal_pressure_gradient_loop_body_always_inline).parameters)

7

@numba.njit(inline='always')
def expand_7_args(func: Callable[..., float], args: Tuple[Any, ...]):
    """Call ``func`` with the first seven entries of ``args`` as positional arguments.

    Each entry is spelled out by constant index instead of star-unpacking the
    tuple. Declared with ``inline='always'``.

    Returns whatever ``func`` returns.
    """
    return func(args[0], args[1], args[2], args[3], args[4], args[5], args[6])

# NOTE: this name is also defined in other cells of this notebook; whichever
# definition was executed last is the one in effect.
def _make_2D_grid_iterator_vararg(func):
    """Return a jitted grid iterator that takes the loop-body inputs as ``*args``.

    This variant hands the assembled 7-tuple to ``expand_7_args``, which spreads
    it into positional arguments for ``func``.

    Parameters
    ----------
    func
        Loop body that ends up being invoked as
        ``func(args[0], args[1], args[2], i, j, ni, nj)``.

    Returns
    -------
    callable
        ``f(ni, nj, *args)`` returning an ``(ni, nj)`` array. Exactly the first
        three entries of ``args`` are forwarded to ``func``.
    """
    @numba.njit
    def _interate_over_grid_2D(ni: int, nj: int, *args: Any):
        result = np.empty((ni, nj))
        for i in range(ni):
            for j in range(nj):
                # Leading user arguments followed by the grid indices/extents.
                f_args = (args[0], args[1], args[2], i, j, ni, nj)
                result[i, j] = expand_7_args(func, f_args)
        return result
    return _interate_over_grid_2D

# Grid iterator specialised to the always-inline loop body.
_iterate_over_grid_2D_decorated_vararg = _make_2D_grid_iterator_vararg(_zonal_pressure_gradient_loop_body_always_inline)


In [158]:
%timeit _iterate_over_grid_2D_decorated_vararg(nx, ny, eta, g, dx)

485 ms ± 8.74 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [152]:
# NOTE: this name is also defined in earlier cells of this notebook; whichever
# definition was executed last is the one in effect.
def _make_2D_grid_iterator_vararg(func):
    """Return a jitted grid iterator whose loop body takes one tuple argument.

    This variant passes the assembled 7-tuple to ``func`` as a single argument,
    so ``func`` has to unpack it itself.

    Parameters
    ----------
    func
        Loop body invoked as ``func((args[0], args[1], args[2], i, j, ni, nj))``.

    Returns
    -------
    callable
        ``f(ni, nj, *args)`` returning an ``(ni, nj)`` array. Exactly the first
        three entries of ``args`` are forwarded to ``func``.
    """
    @numba.njit
    def _interate_over_grid_2D(ni: int, nj: int, *args: Any):
        result = np.empty((ni, nj))
        for i in range(ni):
            for j in range(nj):
                # Leading user arguments followed by the grid indices/extents.
                f_args = (args[0], args[1], args[2], i, j, ni, nj)
                result[i, j] = func(f_args)
        return result
    return _interate_over_grid_2D

@numba.njit(inline='always')
def _zonal_pressure_gradient_loop_body_inline_vararg(
    args: Tuple[Any, ...]
    # `args` packs (eta, g, dx, i, j, ni, nj), in that order.
) -> float:
    """Return ``-g * (eta[ip1, j] - eta[i, j]) / dx`` from a packed argument tuple.

    ``args`` is unpacked into ``eta, g, dx, i, j, ni, nj``. ``ip1`` is ``i + 1``
    wrapped back to ``0`` when it reaches ``ni``; ``nj`` is unpacked but not
    used. Declared with ``inline='always'``.
    """
    eta, g, dx, i, j, ni, nj = args
    # Neighbour index along the first axis, wrapping around at ni.
    ip1 = (i + 1) % ni
    return -g * (eta[ip1, j] - eta[i, j]) / dx

# Grid iterator specialised to the tuple-taking loop body. This rebinds a name
# that earlier cells also assign.
_iterate_over_grid_2D_decorated_vararg = _make_2D_grid_iterator_vararg(_zonal_pressure_gradient_loop_body_inline_vararg)

%timeit _iterate_over_grid_2D_decorated_vararg(nx, ny, eta, g, dx)

487 ms ± 4.85 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [167]:
@numba.njit(inline='always')
def expand_7_arguments(func, i, j, ni, nj, args):
    """Call ``func(i, j, ni, nj, args[0], args[1], args[2])`` and return its result.

    Expander for loop bodies declaring seven parameters: the four grid
    arguments followed by the first three entries of ``args``. Declared with
    ``inline='always'``.
    """
    return func(i, j, ni, nj, args[0], args[1], args[2])

@numba.njit(inline='always')
def expand_6_arguments(func, i, j, ni, nj, args):
    """Call ``func(i, j, ni, nj, args[0], args[1])`` and return its result.

    Expander for loop bodies declaring six parameters: the four grid
    arguments followed by the first two entries of ``args``. Declared with
    ``inline='always'``.
    """
    return func(i, j, ni, nj, args[0], args[1])

# Maps the total number of parameters a loop body declares (the four grid
# arguments i, j, ni, nj plus its extra inputs) to the expander that calls it.
_arg_expand_map = {
    6: expand_6_arguments,
    7: expand_7_arguments,
}

def _make_2D_grid_iterator(func):
    """Return a jitted function that evaluates ``func`` at every grid index.

    The expander matching the number of parameters declared by ``func`` is
    looked up in ``_arg_expand_map`` once, when the iterator is built, and is
    then used inside the jitted loop to spread the variadic inputs into
    positional arguments.

    Parameters
    ----------
    func
        Loop body invoked as ``func(i, j, ni, nj, *extra)``, where ``extra``
        are the leading entries of the ``*args`` given to the iterator.

    Returns
    -------
    callable
        ``f(ni, nj, *args)`` returning an ``(ni, nj)`` array with one ``func``
        result per index.

    Raises
    ------
    KeyError
        If no expander is registered in ``_arg_expand_map`` for the number of
        parameters ``func`` declares.
    """
    n_params = len(signature(func).parameters)
    try:
        exp_args = _arg_expand_map[n_params]
    except KeyError:
        # Keep the KeyError type, but say what went wrong instead of
        # surfacing only the bare parameter count.
        raise KeyError(
            f"no argument expander registered for a loop body with "
            f"{n_params} parameters; supported counts: {sorted(_arg_expand_map)}"
        ) from None

    @numba.njit
    def _iterate_over_grid_2D(ni: int, nj: int, *args: Any):
        result = np.empty((ni, nj))
        for i in range(ni):
            for j in range(nj):
                result[i, j] = exp_args(func, i, j, ni, nj, args)
        return result
    return _iterate_over_grid_2D

@numba.njit(inline='always')
def _loop_body(
    i: int, j: int, ni: int, nj: int, eta: np.ndarray, g: float, dx: float
) -> float:
    """Return ``-g * (eta[ip1, j] - eta[i, j]) / dx`` for grid point ``(i, j)``.

    The grid arguments ``i, j, ni, nj`` come first, followed by the extra
    inputs ``eta, g, dx``. ``ip1`` is ``i + 1`` wrapped back to ``0`` when it
    reaches ``ni``; ``nj`` is accepted but not used. Declared with
    ``inline='always'``.
    """
    # Neighbour index along the first axis, wrapping around at ni.
    ip1 = (i + 1) % ni
    return -g * (eta[ip1, j] - eta[i, j]) / dx

# Build the grid iterator specialised to `_loop_body`.
_grid_iterator = _make_2D_grid_iterator(_loop_body)

%timeit _grid_iterator(nx, ny, eta, g, dx)

459 ms ± 4.64 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
