In [1]:
import numpy as np
import time as t
import numba
from multimodemodel import _cyclic_shift

In [2]:
@numba.njit(inline='always')  # type: ignore
def _advection_momentum_u(
    i: int,
    j: int,
    k: int,
    m: int,
    n: int,
    ni: int,
    nj: int,
    nk: int,
    u: np.ndarray,
    v: np.ndarray,
    w: np.ndarray,
    mask_u: np.ndarray,
    mask_v: np.ndarray,
    mask_q: np.ndarray,
    dx_u: np.ndarray,
    dy_u: np.ndarray,
    dx_v: np.ndarray,
    lbc: int,
    ppp: np.ndarray,
    ppw: np.ndarray,
) -> float:  # pragma: no cover
    """Return the zonal-momentum advection contribution of one (n, m) pair.

    The value is evaluated at index ``(k, j, i)`` and is the sum of two terms:

    * ``ppp[n, m, k] * mask_u[k, j, i]`` times a signed combination of four
      products (two offset along ``i``, two offset along ``j``), divided in
      turn by ``dx_u[j, i]``, ``dy_u[j, i]`` and ``4``;
    * ``ppw[n, m, k]`` times the masked ``u[m, j, i]`` times the mean of
      ``w[n, j, i]`` and ``w`` at the previous ``i`` index.

    Parameters
    ----------
    i, j, k
        Indices of the point being evaluated. ``i`` and ``j`` index the last
        two axes of the 3-D arrays and both axes of ``dx_*`` / ``dy_*``.
    m, n
        Indices into the first axis of ``u``, ``v``, ``w`` and the masks, and
        into the second and first axes of ``ppp`` / ``ppw`` respectively.
    ni, nj
        Axis lengths handed to ``_cyclic_shift`` when computing neighbour
        indices.
    nk
        Not used by this function.
    u, v, w, mask_u, mask_v, mask_q
        Arrays indexed as ``[level, j, i]``.
    dx_u, dy_u, dx_v
        Arrays indexed as ``[j, i]``.
    lbc
        Multiplier applied to the masked ``u[m, j, i]`` inside the last
        product wherever ``mask_q[k, j, i] == 0``; a multiplier of ``1`` is
        used everywhere else.
        NOTE(review): presumably a lateral-boundary-condition factor -- confirm.
    ppp, ppw
        Arrays indexed as ``[n, m, k]`` that scale the two terms.
    """
    # Neighbour indices; wrap-around behaviour is whatever _cyclic_shift implements.
    i_next = _cyclic_shift(i, ni, 1)
    i_prev = _cyclic_shift(i, ni, -1)
    j_next = _cyclic_shift(j, nj, 1)
    j_prev = _cyclic_shift(j, nj, -1)

    # `lbc` only takes effect where mask_q vanishes; otherwise the weight is 1.
    lbc_weight = lbc if mask_q[k, j, i] == 0 else 1

    # Sub-expressions that occur in more than one product below.
    dyu_n_here = dy_u[j, i] * mask_u[n, j, i] * u[n, j, i]
    u_m_here = mask_u[m, j, i] * u[m, j, i]

    # Products offset along i.
    prod_i_next = (
        dy_u[j, i_next] * mask_u[n, j, i_next] * u[n, j, i_next] + dyu_n_here
    ) * (mask_u[m, j, i_next] * u[m, j, i_next] + u_m_here)
    prod_i_prev = (
        dyu_n_here + dy_u[j, i_prev] * mask_u[n, j, i_prev] * u[n, j, i_prev]
    ) * (u_m_here + mask_u[m, j, i_prev] * u[m, j, i_prev])

    # Products offset along j.
    prod_j_next = (
        dx_v[j_next, i] * mask_v[n, j_next, i] * v[n, j_next, i]
        + dx_v[j_next, i_prev] * mask_v[n, j_next, i_prev] * v[n, j_next, i_prev]
    ) * (mask_u[m, j_next, i] * u[m, j_next, i] + u_m_here)
    # The multiplication order of the lbc-weighted factor is kept as written,
    # so this factor is intentionally not built from `u_m_here`.
    prod_j_here = (
        dx_v[j, i] * mask_v[n, j, i] * v[n, j, i]
        + dx_v[j, i_prev] * mask_v[n, j, i_prev] * v[n, j, i_prev]
    ) * (
        lbc_weight * mask_u[m, j, i] * u[m, j, i]
        + mask_u[m, j_prev, i] * u[m, j_prev, i]
    )

    ppp_term = (
        ppp[n, m, k]
        * mask_u[k, j, i]
        * (prod_i_next - prod_i_prev + prod_j_next - prod_j_here)
        / dx_u[j, i]
        / dy_u[j, i]
        / 4
    )
    ppw_term = (
        ppw[n, m, k]
        * mask_u[m, j, i]
        * u[m, j, i]
        * (w[n, j, i] + w[n, j, i_prev])
        / 2
    )
    return ppp_term + ppw_term

In [3]:
@numba.njit(parallel=True)  # type: ignore
def _numba_double_sum_parallel_over_nm_preallocated(args):
    """Sum the advection term over all (n, m) pairs via one large buffer.

    Every ``(n, m)`` pair writes its full ``(nk, nj, ni)`` field into its own
    slab of a ``(nk * nk, nk, nj, ni)`` array; the ``prange`` loop runs over
    the flattened pair index. The slabs are summed along the first axis once
    the loop has finished.

    Parameters
    ----------
    args
        Tuple laid out as ``(ni, nj, nk, u, v, w, mask_u, mask_v, mask_q,
        dx_u, dy_u, dx_v, lbc, ppp, ppw)``; entries 3-14 are forwarded
        positionally to ``_advection_momentum_u``.

    Returns
    -------
    Array of shape ``(nk, nj, ni)``.
    """
    ni, nj, nk = args[0], args[1], args[2]
    u, v, w = args[3], args[4], args[5]
    mask_u, mask_v, mask_q = args[6], args[7], args[8]
    dx_u, dy_u, dx_v = args[9], args[10], args[11]
    lbc, ppp, ppw = args[12], args[13], args[14]

    n_pairs = nk * nk
    per_pair = np.empty((n_pairs, nk, nj, ni))
    for pair in numba.prange(n_pairs):
        # Unflatten the pair index; the results are cast explicitly to int
        # before being used as array indices.
        n, m = divmod(pair, nk)
        n = int(n)
        m = int(m)
        for k in range(nk):
            for j in range(nj):
                for i in range(ni):
                    per_pair[pair, k, j, i] = _advection_momentum_u(
                        i, j, k, m, n,
                        ni, nj, nk,
                        u, v, w,
                        mask_u, mask_v, mask_q,
                        dx_u, dy_u, dx_v,
                        lbc, ppp, ppw,
                    )

    return per_pair.sum(axis=0)

In [4]:
@numba.njit(parallel=True)  # type: ignore
def _numba_double_sum_parallel_over_nm(args):
    """Sum the advection term over all (n, m) pairs with a per-pair buffer.

    The ``prange`` loop runs over the flattened ``(n, m)`` index. Each
    iteration allocates its own ``(nk, nj, ni)`` scratch array, fills it
    completely and adds it to the running total with a whole-array ``+=``.

    Parameters
    ----------
    args
        Tuple laid out as ``(ni, nj, nk, u, v, w, mask_u, mask_v, mask_q,
        dx_u, dy_u, dx_v, lbc, ppp, ppw)``; entries 3-14 are forwarded
        positionally to ``_advection_momentum_u``.

    Returns
    -------
    Array of shape ``(nk, nj, ni)``.
    """
    ni, nj, nk = args[0], args[1], args[2]
    u, v, w = args[3], args[4], args[5]
    mask_u, mask_v, mask_q = args[6], args[7], args[8]
    dx_u, dy_u, dx_v = args[9], args[10], args[11]
    lbc, ppp, ppw = args[12], args[13], args[14]

    total = np.zeros((nk, nj, ni))
    for pair in numba.prange(nk * nk):
        # Unflatten the pair index; the results are cast explicitly to int
        # before being used as array indices.
        n, m = divmod(pair, nk)
        n = int(n)
        m = int(m)
        pair_field = np.empty((nk, nj, ni))
        for k in range(nk):
            for j in range(nj):
                for i in range(ni):
                    pair_field[k, j, i] = _advection_momentum_u(
                        i, j, k, m, n,
                        ni, nj, nk,
                        u, v, w,
                        mask_u, mask_v, mask_q,
                        dx_u, dy_u, dx_v,
                        lbc, ppp, ppw,
                    )
        # Whole-array in-place add inside prange; kept in exactly this form
        # (see the Numba documentation on prange reductions).
        total += pair_field
    return total

In [5]:
@numba.njit(parallel=True)  # type: ignore
def _numba_double_sum_parallel_over_k(args):
    """Sum the advection term over all (n, m) pairs, parallelising over k.

    The ``n`` and ``m`` loops are ordinary ``range`` loops; only the
    innermost-but-two loop over ``k`` is a ``prange``. One ``(nk, nj, ni)``
    scratch array is allocated up front and completely overwritten for every
    ``(n, m)`` pair before being added to the running total.

    Parameters
    ----------
    args
        Tuple laid out as ``(ni, nj, nk, u, v, w, mask_u, mask_v, mask_q,
        dx_u, dy_u, dx_v, lbc, ppp, ppw)``; entries 3-14 are forwarded
        positionally to ``_advection_momentum_u``.

    Returns
    -------
    Array of shape ``(nk, nj, ni)``.
    """
    ni, nj, nk = args[0], args[1], args[2]
    u, v, w = args[3], args[4], args[5]
    mask_u, mask_v, mask_q = args[6], args[7], args[8]
    dx_u, dy_u, dx_v = args[9], args[10], args[11]
    lbc, ppp, ppw = args[12], args[13], args[14]

    total = np.zeros((nk, nj, ni))
    scratch = np.empty((nk, nj, ni))
    for n in range(nk):
        for m in range(nk):
            # Each k iteration writes only its own scratch[k, :, :] slab.
            for k in numba.prange(nk):
                for j in range(nj):
                    for i in range(ni):
                        scratch[k, j, i] = _advection_momentum_u(
                            i, j, k, m, n,
                            ni, nj, nk,
                            u, v, w,
                            mask_u, mask_v, mask_q,
                            dx_u, dy_u, dx_v,
                            lbc, ppp, ppw,
                        )
            # Accumulate once per (n, m) pair, after the parallel loop is done.
            total += scratch
    return total

In [6]:
from  multimodemodel import StaggeredGrid

# Length of the ``z`` coordinate handed to the grid factory.
nmodes = 16

# Both horizontal axes use the same number of points (10 * 4 + 1 == 41).
n_horizontal = 10 * 4 + 1

# Staggered grid covering longitudes and latitudes from -5.0 to 5.0.
c_grid = StaggeredGrid.regular_lat_lon_c_grid(
    lon_start=-5.0,
    lon_end=5.0,
    lat_start=-5.0,
    lat_end=5.0,
    nx=n_horizontal,
    ny=n_horizontal,
    z=np.arange(nmodes),
)

In [7]:
from multimodemodel import State, Variable
# Every variable shares one timestamp and is created without data (``None``).
t0 = np.datetime64("2000-01-01")

state_zero = State(
    **{
        name: Variable(None, getattr(c_grid, name), t0)
        for name in ("u", "v", "eta", "q")
    }
)

# ``w`` is registered as a diagnostic variable and lives on the eta grid.
state_zero.set_diagnostic_variable(w=Variable(None, c_grid.eta, t0))

# Extents of the u grid along its x, y and z dimensions.
u_grid = c_grid.u
n_x = u_grid.shape[u_grid.dim_x]
n_y = u_grid.shape[u_grid.dim_y]
n_z = u_grid.shape[u_grid.dim_z]

u_var = state_zero.variables["u"]
v_var = state_zero.variables["v"]
q_var = state_zero.variables["q"]
w_var = state_zero.diagnostic_variables["w"]

# Positional layout expected by the benchmark kernels:
# (ni, nj, nk, u, v, w, mask_u, mask_v, mask_q, dx_u, dy_u, dx_v, lbc, ppp, ppw)
args = (
    n_x,
    n_y,
    n_z,
    u_var.safe_data,
    v_var.safe_data,
    w_var.safe_data,
    u_var.grid.mask,
    v_var.grid.mask,
    q_var.grid.mask,
    u_var.grid.dx,
    u_var.grid.dy,
    v_var.grid.dx,
    0,  # lbc
    np.ones((n_z, n_z, n_z)),  # ppp
    np.ones((n_z, n_z, n_z)),  # ppw
)

In [8]:
# Use up to 8 threads for the benchmarks below. The request is capped at the
# number of threads Numba was launched with (numba.config.NUMBA_NUM_THREADS),
# because set_num_threads raises ValueError for anything larger, which would
# break a fresh run on a machine with fewer than 8 threads.
numba.set_num_threads(min(8, numba.config.NUMBA_NUM_THREADS))

In [9]:
# Call each variant once (return value discarded) and print Numba's runtime
# allocation counters after every call. The counters are cumulative, so the
# cost of one variant is the difference between consecutive lines.
# NOTE(review): on a fresh kernel these are presumably also the calls that
# trigger JIT compilation, keeping it out of the timing cells below.
for variant in (
    _numba_double_sum_parallel_over_k,
    _numba_double_sum_parallel_over_nm,
    _numba_double_sum_parallel_over_nm_preallocated,
):
    variant(args)
    print(numba.core.runtime.rtsys.get_allocation_stats())

nrt_mstats(alloc=19, free=16, mi_alloc=19, mi_free=16)
nrt_mstats(alloc=49, free=43, mi_alloc=49, mi_free=43)
nrt_mstats(alloc=64, free=58, mi_alloc=63, mi_free=57)


In [10]:
%%timeit -n10 -r20 -o 
_numba_double_sum_parallel_over_k(args)

55.5 ms ± 6.94 ms per loop (mean ± std. dev. of 20 runs, 10 loops each)


<TimeitResult : 55.5 ms ± 6.94 ms per loop (mean ± std. dev. of 20 runs, 10 loops each)>

In [12]:
%%timeit -n10 -r20 -o 
_numba_double_sum_parallel_over_nm(args)

32.1 ms ± 3.34 ms per loop (mean ± std. dev. of 20 runs, 10 loops each)


<TimeitResult : 32.1 ms ± 3.34 ms per loop (mean ± std. dev. of 20 runs, 10 loops each)>

In [13]:
%%timeit -n10 -r20 -o 
_numba_double_sum_parallel_over_nm_preallocated(args)

68 ms ± 11.9 ms per loop (mean ± std. dev. of 20 runs, 10 loops each)


<TimeitResult : 68 ms ± 11.9 ms per loop (mean ± std. dev. of 20 runs, 10 loops each)>