In [1]:
import numpy as np
import time as t
import numba as nb
from multimodemodel import _cyclic_shift
import numba.core.runtime as rt

In [2]:
@nb.njit(inline='always')  # type: ignore
def _advection_momentum_u(
    i: int, j: int, k: int,
    m: int, n: int,
    ni: int, nj: int, nk: int,
    lbc: int,
    u: np.ndarray, v: np.ndarray, w: np.ndarray,
    mask_u: np.ndarray, mask_v: np.ndarray, mask_q: np.ndarray,
    dx_u: np.ndarray, dy_u: np.ndarray, dx_v: np.ndarray,
    ppp: np.ndarray, ppw: np.ndarray,
) -> float:  # pragma: no cover
    """Compute the (n, m) contribution to the advection of zonal momentum.

    Returns the contribution of one index pair ``(n, m)`` at point
    ``(k, j, i)``. The driver functions in this notebook sum this value over
    all ``n`` and ``m`` in ``range(nk)``.

    Parameters
    ----------
    i, j, k : int
        Column, row and output index of the point being evaluated.
    m, n : int
        Pair indices; ``n`` selects the ``u``/``v``/``w`` values used as
        face weights, ``m`` selects the ``u`` values they multiply.
    ni, nj, nk : int
        Array extents. ``ni``/``nj`` are passed to ``_cyclic_shift`` to
        obtain neighbour indices (presumably a periodic wrap). ``nk`` is
        unused here but kept so every caller shares one signature.
    lbc : int
        Weight of the centre ``u`` value on the ``j``-side (v row ``j``)
        face. It is used only where ``mask_q[k, j, i] == 0``; elsewhere the
        weight is 1.
    u, v, w, mask_u, mask_v, mask_q : np.ndarray
        3-D arrays indexed ``[n_or_m_or_k, j, i]``.
    dx_u, dy_u, dx_v : np.ndarray
        2-D metric arrays indexed ``[j, i]``.
    ppp, ppw : np.ndarray
        3-D coefficient arrays indexed ``[n, m, k]``.

    Returns
    -------
    float
        The contribution, or ``0.`` where ``mask_u[k, j, i] == 0``.
    """
    if mask_u[k, j, i] == 0:
        return 0.

    ip1 = _cyclic_shift(i, ni, 1)
    im1 = _cyclic_shift(i, ni, -1)
    jp1 = _cyclic_shift(j, nj, 1)
    jm1 = _cyclic_shift(j, nj, -1)

    # The caller-provided ``lbc`` only applies where mask_q is zero. The
    # former ``else``-less ``lbc = lbc`` branch was a no-op and is dropped.
    # NOTE(review): only the j-side face (v row j) gets this weight; the
    # j+1-side face always uses weight 1 -- confirm this is intended.
    if mask_q[k, j, i] != 0:
        lbc = 1

    return (
        ppp[n, m, k]
        * mask_u[k, j, i]
        * (
            # i+1 side face
            (
                dy_u[j, ip1] * mask_u[n, j, ip1] * u[n, j, ip1]
                + dy_u[j, i] * mask_u[n, j, i] * u[n, j, i]
            )
            * (mask_u[m, j, ip1] * u[m, j, ip1] + mask_u[m, j, i] * u[m, j, i])
            # i-1 side face
            - (
                dy_u[j, i] * mask_u[n, j, i] * u[n, j, i]
                + dy_u[j, im1] * mask_u[n, j, im1] * u[n, j, im1]
            )
            * (mask_u[m, j, i] * u[m, j, i] + mask_u[m, j, im1] * u[m, j, im1])
            # j+1 side face (v row jp1)
            + (
                dx_v[jp1, i] * mask_v[n, jp1, i] * v[n, jp1, i]
                + dx_v[jp1, im1] * mask_v[n, jp1, im1] * v[n, jp1, im1]
            )
            * (mask_u[m, jp1, i] * u[m, jp1, i] + mask_u[m, j, i] * u[m, j, i])
            # j side face (v row j), centre value weighted by lbc
            - (
                dx_v[j, i] * mask_v[n, j, i] * v[n, j, i]
                + dx_v[j, im1] * mask_v[n, j, im1] * v[n, j, im1]
            )
            * (lbc * mask_u[m, j, i] * u[m, j, i] + mask_u[m, jm1, i] * u[m, jm1, i])
        )
        / dx_u[j, i]
        / dy_u[j, i]
        / 4
        # NOTE(review): u[m, j, i] is not multiplied by mask_u[m, j, i] here,
        # unlike every u factor above -- confirm intended.
        + ppw[n, m, k] * mask_u[k, j, i] * u[m, j, i] * (w[n, j, i] + w[n, j, im1]) / 2
    )

In [3]:
@nb.njit(inline='always')  # type: ignore
def _advection_momentum_u_alt(
    i: int,
    j: int,
    k: int,
    ni: int,
    nj: int,
    nk: int,
    lbc: int,
    u: np.ndarray,
    v: np.ndarray,
    w: np.ndarray,
    mask_u: np.ndarray,
    mask_v: np.ndarray,
    mask_q: np.ndarray,
    dx_u: np.ndarray,
    dy_u: np.ndarray,
    dx_v: np.ndarray,
    ppp: np.ndarray,
    ppw: np.ndarray,
) -> float:  # pragma: no cover
    """Compute the advection of zonal momentum summed over all (n, m) pairs.

    Same per-pair expression as ``_advection_momentum_u``, but the double
    sum over ``n`` and ``m`` in ``range(nk)`` is done here. Face terms that
    depend only on ``n`` are computed once per ``n``, outside the ``m`` loop.
    The ``(k, j, i)``-only prefactors are computed once per call.

    Parameters
    ----------
    i, j, k : int
        Column, row and output index of the point being evaluated.
    ni, nj, nk : int
        Array extents. ``ni``/``nj`` are passed to ``_cyclic_shift``
        (presumably a periodic wrap), and ``nk`` bounds both sums.
    lbc : int
        Weight of the centre ``u`` value on the ``j``-side face. It is used
        only where ``mask_q[k, j, i] == 0``; elsewhere the weight is 1.
    u, v, w, mask_u, mask_v, mask_q : np.ndarray
        3-D arrays indexed ``[n_or_m_or_k, j, i]``.
    dx_u, dy_u, dx_v : np.ndarray
        2-D metric arrays indexed ``[j, i]``.
    ppp, ppw : np.ndarray
        3-D coefficient arrays indexed ``[n, m, k]``.

    Returns
    -------
    float
        The summed value, or ``0.`` where ``mask_u[k, j, i] == 0``.
    """
    if mask_u[k, j, i] == 0:
        return 0.

    ip1 = _cyclic_shift(i, ni, 1)
    im1 = _cyclic_shift(i, ni, -1)
    jp1 = _cyclic_shift(j, nj, 1)
    jm1 = _cyclic_shift(j, nj, -1)

    # The caller-provided ``lbc`` only applies where mask_q is zero. The
    # former ``lbc = lbc`` branch was a no-op and is dropped.
    if mask_q[k, j, i] != 0:
        lbc = 1

    # Prefactors that depend only on (k, j, i), hoisted out of both loops.
    mask_fac_ppp = mask_u[k, j, i] / dx_u[j, i] / dy_u[j, i] / 4
    mask_fac_ppw = 0.5 * mask_u[k, j, i]

    result = 0.

    for n in range(nk):
        # n-only face weights, reused for every m below.
        u_eta_n_ij = (
            dy_u[j, ip1] * mask_u[n, j, ip1] * u[n, j, ip1]
            + dy_u[j, i] * mask_u[n, j, i] * u[n, j, i]
        )
        u_eta_n_im1j = (
            dy_u[j, i] * mask_u[n, j, i] * u[n, j, i]
            + dy_u[j, im1] * mask_u[n, j, im1] * u[n, j, im1]
        )
        v_q_n_ijp1 = (
            dx_v[jp1, i] * mask_v[n, jp1, i] * v[n, jp1, i]
            + dx_v[jp1, im1] * mask_v[n, jp1, im1] * v[n, jp1, im1]
        )
        v_q_n_ij = (
            dx_v[j, i] * mask_v[n, j, i] * v[n, j, i]
            + dx_v[j, im1] * mask_v[n, j, im1] * v[n, j, im1]
        )
        w_u_n_ij = w[n, j, i] + w[n, j, im1]

        for m in range(nk):
            result += (
                ppp[n, m, k]
                * mask_fac_ppp
                * (
                    u_eta_n_ij
                    * (mask_u[m, j, ip1] * u[m, j, ip1] + mask_u[m, j, i] * u[m, j, i])
                    - u_eta_n_im1j
                    * (mask_u[m, j, i] * u[m, j, i] + mask_u[m, j, im1] * u[m, j, im1])
                    + v_q_n_ijp1
                    * (mask_u[m, jp1, i] * u[m, jp1, i] + mask_u[m, j, i] * u[m, j, i])
                    - v_q_n_ij
                    * (lbc * mask_u[m, j, i] * u[m, j, i] + mask_u[m, jm1, i] * u[m, jm1, i])
                )
                # NOTE(review): u[m, j, i] is unmasked here, matching
                # _advection_momentum_u -- confirm intended.
                + ppw[n, m, k] * mask_fac_ppw * u[m, j, i] * w_u_n_ij
            )
    return result

In [4]:
@nb.njit(parallel=True)  # type: ignore
def _numba_double_sum_parallel_over_nm_preallocated(args):
    """Sum ``_advection_momentum_u`` over (n, m), one prange task per pair.

    Each (n, m) pair fills its own ``(nk, nj, ni)`` slab of a scratch array
    of shape ``(nk * nk, nk, nj, ni)``. No two tasks write the same element,
    and the slabs are summed along axis 0 at the end.

    ``args`` is ``(ni, nj, nk, lbc, u, v, w, mask_u, mask_v, mask_q,
    dx_u, dy_u, dx_v, ppp, ppw)``.
    """
    ni = args[0]
    nj = args[1]
    nk = args[2]
    lbc = args[3]
    u, v, w = args[4], args[5], args[6]
    mask_u, mask_v, mask_q = args[7], args[8], args[9]
    dx_u, dy_u, dx_v = args[10], args[11], args[12]
    ppp, ppw = args[13], args[14]

    slabs = np.empty((nk * nk, nk, nj, ni))
    for pair in nb.prange(nk * nk):
        # Flat pair index -> (n, m). The explicit int() casts mirror the
        # original; presumably they guard against a non-int64 prange index.
        n = int(pair // nk)
        m = int(pair % nk)
        for k in range(nk):
            for j in range(nj):
                for i in range(ni):
                    slabs[pair, k, j, i] = _advection_momentum_u(
                        i, j, k, m, n, ni, nj, nk, lbc,
                        u, v, w, mask_u, mask_v, mask_q,
                        dx_u, dy_u, dx_v, ppp, ppw,
                    )

    return slabs.sum(axis=0)

In [5]:
@nb.njit(parallel=True)  # type: ignore
def _numba_double_sum_parallel_over_nmk(args):
    """Sum ``_advection_momentum_u`` over (n, m), one prange task per (n, m, k).

    Each thread accumulates into its own slice ``partial[tid]``, so
    concurrent ``+=`` updates never touch the same element. The per-thread
    slices are summed along axis 0 at the end.

    ``args`` is ``(ni, nj, nk, lbc, u, v, w, mask_u, mask_v, mask_q,
    dx_u, dy_u, dx_v, ppp, ppw)``.
    """
    ni, nj, nk = args[0], args[1], args[2]
    lbc = args[3]
    u, v, w = args[4], args[5], args[6]
    mask_u, mask_v, mask_q = args[7], args[8], args[9]
    dx_u, dy_u, dx_v = args[10], args[11], args[12]
    ppp, ppw = args[13], args[14]

    # Assumes every thread id returned below is < get_num_threads().
    partial = np.zeros((nb.get_num_threads(), nk, nj, ni))
    for flat in nb.prange(nk * nk * nk):
        # NOTE(review): private Numba helper. Newer Numba releases expose a
        # public ``numba.get_thread_id()``; check availability before switching.
        tid = nb.np.ufunc.parallel._get_thread_id()
        # Flat index -> (n, m, k), with n the slowest-varying.
        n = int(flat // (nk * nk))
        residual = flat % (nk * nk)
        m = int(residual // nk)
        k = int(residual % nk)
        for j in range(nj):
            for i in range(ni):
                partial[tid, k, j, i] += _advection_momentum_u(
                    i, j, k, m, n, ni, nj, nk, lbc,
                    u, v, w, mask_u, mask_v, mask_q,
                    dx_u, dy_u, dx_v, ppp, ppw,
                )
    return partial.sum(axis=0)

In [6]:
@nb.njit(parallel=True)  # type: ignore
def _numba_double_sum_parallel_over_k(args):
    """Sum ``_advection_momentum_u`` over (n, m), parallelising over k only.

    Each prange iteration owns one ``k`` slab of the output, so the ``+=``
    accumulation is race-free. Only ``nk`` parallel tasks are created.

    ``args`` is ``(ni, nj, nk, lbc, u, v, w, mask_u, mask_v, mask_q,
    dx_u, dy_u, dx_v, ppp, ppw)``.
    """
    ni, nj, nk = args[0], args[1], args[2]
    lbc = args[3]
    u, v, w = args[4], args[5], args[6]
    mask_u, mask_v, mask_q = args[7], args[8], args[9]
    dx_u, dy_u, dx_v = args[10], args[11], args[12]
    ppp, ppw = args[13], args[14]

    out = np.zeros((nk, nj, ni))
    for k in nb.prange(nk):
        # n outer, m inner: the same summation order as before.
        for n in range(nk):
            for m in range(nk):
                for j in range(nj):
                    for i in range(ni):
                        out[k, j, i] += _advection_momentum_u(
                            i, j, k, m, n, ni, nj, nk, lbc,
                            u, v, w, mask_u, mask_v, mask_q,
                            dx_u, dy_u, dx_v, ppp, ppw,
                        )
    return out

In [7]:
@nb.njit(parallel=True)  # type: ignore
def _numba_double_sum_parallel_over_kji(args):
    """Sum ``_advection_momentum_u`` over (n, m), one prange task per point.

    The flat prange index enumerates every ``(k, j, i)`` exactly once, and
    each task accumulates the full (m, n) double sum for its own point.

    ``args`` is ``(ni, nj, nk, lbc, u, v, w, mask_u, mask_v, mask_q,
    dx_u, dy_u, dx_v, ppp, ppw)``.
    """
    ni, nj, nk = args[0], args[1], args[2]
    lbc = args[3]
    u, v, w = args[4], args[5], args[6]
    mask_u, mask_v, mask_q = args[7], args[8], args[9]
    dx_u, dy_u, dx_v = args[10], args[11], args[12]
    ppp, ppw = args[13], args[14]

    out = np.zeros((nk, nj, ni))
    plane = nj * ni
    for flat in nb.prange(nk * plane):
        # Flat index -> (k, j, i), with k the slowest-varying.
        k = int(flat // plane)
        residual = flat % plane
        j = int(residual // ni)
        i = int(residual % ni)
        # m outer, n inner: the same summation order as before.
        for m in range(nk):
            for n in range(nk):
                out[k, j, i] += _advection_momentum_u(
                    i, j, k, m, n, ni, nj, nk, lbc,
                    u, v, w, mask_u, mask_v, mask_q,
                    dx_u, dy_u, dx_v, ppp, ppw,
                )
    return out

In [8]:
@nb.njit(parallel=True)  # type: ignore
def _numba_double_sum_parallel_over_kji_alt(args):
    """Evaluate ``_advection_momentum_u_alt`` at every point in parallel.

    The kernel does the (n, m) double sum itself. The flat prange index
    visits each ``(k, j, i)`` exactly once, so an uninitialised output
    (``np.empty``) is fully overwritten.

    ``args`` is ``(ni, nj, nk, lbc, u, v, w, mask_u, mask_v, mask_q,
    dx_u, dy_u, dx_v, ppp, ppw)``.
    """
    ni, nj, nk = args[0], args[1], args[2]
    lbc = args[3]
    u, v, w = args[4], args[5], args[6]
    mask_u, mask_v, mask_q = args[7], args[8], args[9]
    dx_u, dy_u, dx_v = args[10], args[11], args[12]
    ppp, ppw = args[13], args[14]

    out = np.empty((nk, nj, ni))
    plane = nj * ni
    for flat in nb.prange(nk * plane):
        # Flat index -> (k, j, i), with k the slowest-varying.
        k = int(flat // plane)
        residual = flat % plane
        j = int(residual // ni)
        i = int(residual % ni)
        out[k, j, i] = _advection_momentum_u_alt(
            i, j, k, ni, nj, nk, lbc,
            u, v, w, mask_u, mask_v, mask_q,
            dx_u, dy_u, dx_v, ppp, ppw,
        )
    return out

In [9]:
# Benchmark inputs. A seeded Generator makes the random fields (and hence
# the cross-check at the end) reproducible on Restart & Run All.
SEED = 42
rng = np.random.default_rng(SEED)

nk, nj, ni = 16, 41, 41

# Tuple layout consumed by the drivers: slots 3..14 map onto the kernel
# parameters lbc, u, v, w, mask_u, mask_v, mask_q, dx_u, dy_u, dx_v, ppp, ppw.
args = (
    ni, nj, nk,
    0,  # lbc
    rng.standard_normal((nk, nj, ni)),  # u
    rng.standard_normal((nk, nj, ni)),  # v
    rng.standard_normal((nk, nj, ni)),  # w
    rng.choice([0, 1], (nk, nj, ni)),  # mask_u
    rng.choice([0, 1], (nk, nj, ni)),  # mask_v
    rng.choice([0, 1], (nk, nj, ni)),  # mask_q
    np.ones((nj, ni)),  # dx_u
    np.ones((nj, ni)),  # dy_u
    rng.standard_normal((nj, ni)),  # dx_v (random benchmark data, may be < 0)
    np.ones((nk, nk, nk)),  # ppp
    np.ones((nk, nk, nk)),  # ppw
)

In [10]:
def allocation_difference(func, args):
    """Return the net change in Numba NRT live allocations across one call.

    The NRT counters ``alloc - free`` are read before and after calling
    ``func`` once. A non-zero result means memory allocated through NRT
    during the call was still alive afterwards. The call's return value is
    discarded.

    NOTE(review): newer Numba versions may require NRT statistics to be
    enabled (e.g. ``NUMBA_NRT_STATS=1``) -- confirm for the installed version.

    Parameters
    ----------
    func : callable
        Function to measure. NumPy ufuncs are called as ``func(*args)``;
        anything else is called as ``func(args)``, which matches the drivers
        in this notebook (they take the whole argument tuple).
    args : tuple
        Arguments forwarded to ``func`` as described above.

    Returns
    -------
    int
        ``(alloc - free)`` after the call minus ``(alloc - free)`` before it.
    """
    def _live_allocations():
        # Take one stats snapshot so alloc and free come from the same read.
        stats = rt.rtsys.get_allocation_stats()
        return stats.alloc - stats.free

    before = _live_allocations()
    if isinstance(func, np.ufunc):
        func(*args)
    else:
        func(args)
    return _live_allocations() - before

In [11]:
# Net NRT (alloc - free) change across one call; 0 means nothing allocated
# during the call is still alive afterwards. This is also the first call,
# so it triggers JIT compilation before the timing cells below.
allocation_difference(_numba_double_sum_parallel_over_k, args)

0

In [12]:
# Net NRT (alloc - free) change across one call of the per-point driver;
# this first call also JIT-compiles it ahead of the timing cells.
allocation_difference(_numba_double_sum_parallel_over_kji, args)

0

In [13]:
# Net NRT (alloc - free) change across one call of the variant that uses
# _advection_momentum_u_alt; this first call also JIT-compiles it.
allocation_difference(_numba_double_sum_parallel_over_kji_alt, args)

0

In [14]:
# Net NRT (alloc - free) change across one call of the per-thread-partials
# driver; this first call also JIT-compiles it.
allocation_difference(_numba_double_sum_parallel_over_nmk, args)

0

In [15]:
# Net NRT (alloc - free) change across one call of the scratch-array driver;
# its (nk*nk, nk, nj, ni) scratch array should be released by the time the
# call returns. This first call also JIT-compiles it.
allocation_difference(_numba_double_sum_parallel_over_nm_preallocated, args)

0

In [16]:
%%timeit -n20 -r20 -o
# Parallel over k only (nk prange tasks); compiled by the allocation cell above.
_numba_double_sum_parallel_over_k(args)

21 ms ± 7.47 ms per loop (mean ± std. dev. of 20 runs, 20 loops each)


<TimeitResult : 21 ms ± 7.47 ms per loop (mean ± std. dev. of 20 runs, 20 loops each)>

In [17]:
%%timeit -n20 -r20 -o
# One prange task per (k, j, i) point; compiled by the allocation cell above.
_numba_double_sum_parallel_over_kji(args)

18.1 ms ± 1.45 ms per loop (mean ± std. dev. of 20 runs, 20 loops each)


<TimeitResult : 18.1 ms ± 1.45 ms per loop (mean ± std. dev. of 20 runs, 20 loops each)>

In [18]:
%%timeit -n20 -r20 -o
# Per-point tasks using the kernel that hoists n-only terms out of the m loop.
_numba_double_sum_parallel_over_kji_alt(args)

5.46 ms ± 731 µs per loop (mean ± std. dev. of 20 runs, 20 loops each)


<TimeitResult : 5.46 ms ± 731 µs per loop (mean ± std. dev. of 20 runs, 20 loops each)>

In [19]:
%%timeit -n20 -r20 -o
# One prange task per (n, m, k); per-thread partial sums reduced at the end.
_numba_double_sum_parallel_over_nmk(args)

20.6 ms ± 1.22 ms per loop (mean ± std. dev. of 20 runs, 20 loops each)


<TimeitResult : 20.6 ms ± 1.22 ms per loop (mean ± std. dev. of 20 runs, 20 loops each)>

In [20]:
%%timeit -n20 -r20 -o
# One prange task per (n, m); allocates a (nk*nk, nk, nj, ni) scratch array per call.
_numba_double_sum_parallel_over_nm_preallocated(args)

40.7 ms ± 2.45 ms per loop (mean ± std. dev. of 20 runs, 20 loops each)


<TimeitResult : 40.7 ms ± 2.45 ms per loop (mean ± std. dev. of 20 runs, 20 loops each)>

In [21]:
print(f"Threading layer chosen: {nb.threading_layer()}")

Threading layer chosen: omp


In [22]:
# Cross-check every variant against the (n, m)-parallel scratch-array driver.
# rtol/atol/equal_nan reproduce np.allclose's defaults, which tolerate the
# variants' different floating-point summation orders. Unlike a bare
# `assert np.allclose(...)`, a failure names the variant and reports the
# mismatching elements, and the check is not stripped under `python -O`.
oracle = _numba_double_sum_parallel_over_nm_preallocated(args)

for candidate in (
    _numba_double_sum_parallel_over_k,
    _numba_double_sum_parallel_over_kji,
    _numba_double_sum_parallel_over_kji_alt,
    _numba_double_sum_parallel_over_nmk,
):
    np.testing.assert_allclose(
        candidate(args), oracle,
        rtol=1e-5, atol=1e-8, equal_nan=False,
        err_msg=candidate.py_func.__name__,
    )
