# ICON Stencil
This stencil is based on `model/atmosphere/dycore/src/icon4py/model/atmosphere/dycore/mo_velocity_advection_stencil_20.py`.
It is a very simple and the most faithful implementation.

A reference `NumPy` implementation can be found in `model/atmosphere/dycore/tests/stencil_tests/test_mo_velocity_advection_stencil_20.py`.


In [1]:
import os, copy
#ncpu=1
#os.environ['XLA_FLAGS'] = f"--xla_force_host_platform_device_count={ncpu}"
#os.environ["XLA_FLAGS"] = ("--xla_cpu_multi_thread_eigen=false " "intra_op_parallelism_threads=1")
#os.environ['JAX_PLATFORMS'] = "cpu"

from datetime import datetime
from pickle import dump

import numpy as np
import jax
import sys
import jax.numpy as jnp
from jax import jit, grad, vmap
from jax import random

# 64-bit mode must be enabled before `make_jaxpr` is called; otherwise JAX traces the float64/int64 inputs as 32-bit types.
jax.config.update("jax_enable_x64", True)

import dace


In [2]:
from JaxprToSDFG import  JaxprToSDFG
# Translator object; it is used below as `t(jaxpr, simplify=..., auto_opt=..., ret_by_arg=...)` to turn a Jaxpr into an SDFG.
t = JaxprToSDFG()

In [3]:
# Backend switches for the comparison below; setting a switch to `True` skips that backend.
noITIR = False   # If True: do not use the GT4Py (ITIR -> SDFG) based translator or its results.
noSDFG = False    # If True: do not use the JAX -> SDFG based translator.

In [4]:

# DaCe build type used whenever an SDFG is compiled in this notebook.
_build_type = "Release"
# Remove `-ffast-math` from the DaCe default compiler args to support the isfinite/isinf/isnan built-ins;
#  additionally enable OpenMP and aligned `new`.
_cpu_args = dace.config.Config.get("compiler", "cpu", "args").replace('-ffast-math', '') + ' -fopenmp' + ' -faligned-new'

### Demo Input

### Grid Dimensions
Since we lack the GT4Py dimension mechanism, we use some tricks.
- `{K, Cell, Edge, Vertex}DimN` is the number of vertical levels, cells, edges or vertices, respectively.
- Valid indexes always lie within the range `[0, N-1]`, where `N` is the corresponding count above.
- However, an index does not encode the (geometric) position of the entity it refers to.



In [5]:
from icondims import SimpleGridData

### Offset Providers
Here we define the offset providers, i.e. the connectivity tables.
They must be read in a specific way: each row belongs to a source entity, each column holds one of its neighbours, and `-1` marks a missing neighbour.

In [6]:


# Each edge is either adjacent to 1 cell (at the boundary) or 2 (in the inner part).
#  This table describes which cells are adjacent to a certain edge; an undefined neighbour is stored as `-1`.
edge_to_cell_table = np.array(SimpleGridData.e2c_table).astype(np.int32)

# We have a triangular grid, this means a cell is bounded by three edges.
#  In a general unstructured grid, cells could have different geometrical shapes; in that case
#  the matrix would be allocated to fit the largest one and unused entries filled with `-1`.
cell_to_edge_table = np.array(SimpleGridData.c2e_table).astype(np.int32)

# Edge -> neighbouring edges connectivity (`E2C2EO`), used to gather `vn` in the stencil.
#  Presumably the edges of the adjacent cells including the edge itself ("O") — TODO confirm in `icondims`.
E2C2EO_table = np.array(SimpleGridData.e2c2eO_table).astype(np.int32)

# Edge -> vertex connectivity; the stencil uses columns 0 and 1 (the two end points of an edge).
E2V_table = np.array(SimpleGridData.e2v_table).astype(np.int32)


In [7]:
KDimN = 50     # Number of vertical levels.
CellDimN = cell_to_edge_table.shape[0]   # Number of cells (rows of the cell -> edge table).
EdgeDimN = edge_to_cell_table.shape[0]   # Number of edges (rows of the edge -> cell table).
VertexDimN = SimpleGridData.v2c_table.shape[0]   # Number of vertices (rows of the vertex -> cell table).


# `E2C` is the offset map that transforms _edges_ to _cells_.
#   Since an edge is adjacent to at most two cells (and at least one), this dimension is 2.
E2CDimN = 2

# Maximal number of neighbours in the `E2C2EO` connectivity (columns of its table).
E2C2EODimN = E2C2EO_table.shape[1]

# Number of vertices per edge in the `E2V` connectivity (columns of its table).
E2VDimN = E2V_table.shape[1]

In [8]:
# This is the level mask (randomly about half `True`). It has `KDimN + 1` entries (the `+ 1` is taken
#  from the ICON test) because the stencil also reads the mask one level further (`Koff[1]`).
#  For `levelmask[Koff[0]]` use `levelmask[:-1]`
#  and for `levelmask[Koff[1]]` use `levelmask[1:]`.
levelmask = (np.random.rand(KDimN + 1) < 0.5).astype(np.bool_)

### Fields
These are the fields we use as input data.

In [9]:
# Field defined on the edges of the grid and different in every height level.
#  NOTE: `TestStencil` recomputes `w_con_e` internally, so these values do not influence its result.
w_con_e = np.random.rand(EdgeDimN, KDimN).astype(np.float64)

# This is a field, which is the same for every height level.
#  It is defined for every edge but for each edge it has two values (thus `E2CDimN`); they are the weights
#  multiplied with the values of the adjacent cells and summed up (cell to edge interpolation).
c_lin_e = np.random.rand(EdgeDimN, E2CDimN).astype(np.float64)

# This is a field that defines for every cell in every height level one value.
z_w_con_c_full = np.random.rand(CellDimN, KDimN).astype(np.float64)

# Field on every edge and level; the stencil also uses it as a divisor.
ddqz_z_full_e = np.random.rand(EdgeDimN, KDimN).astype(np.float64)

# This is the pseudoargument that we use to handle the return statement
#  (the JAX -> SDFG translation uses `ret_by_arg=True` and writes its result into it).
_out = np.random.rand(EdgeDimN, KDimN).astype(np.float64)


# Remaining inputs of the stencil: per-edge, per-(vertex, level), per-(edge, neighbour) and per-(edge, level) fields.
area_edge = np.random.rand(EdgeDimN).astype(np.float64)
tangent_orientation = np.random.rand(EdgeDimN).astype(np.float64)
inv_primal_edge_length = np.random.rand(EdgeDimN).astype(np.float64)
zeta = np.random.rand(VertexDimN, KDimN).astype(np.float64)
geofac_grdiv = np.random.rand(EdgeDimN, E2C2EODimN).astype(np.float64)
vn = np.random.rand(EdgeDimN, KDimN).astype(np.float64)
ddt_vn_apc = np.random.rand(EdgeDimN, KDimN).astype(np.float64)


# Scalar parameters stored as 1-element arrays (`(1)` is the integer 1, not a tuple);
#  the stencils are called with the extracted element, e.g. `cfl_w_limit[0]`.
cfl_w_limit = np.full((1), 4.0)
scalfac_exdiff = np.full((1), 6.0)
dtime = np.full((1), 2.0)

# Plain scalar copies of the parameters above (the visible cells keep using `cfl_w_limit[0]` etc. instead).
cfl_w_limit_, scalfac_exdiff_, dtime_ = cfl_w_limit[0], scalfac_exdiff[0], dtime[0]


In [10]:
# Pristine copy of the input, used later to check that no implementation modified `ddt_vn_apc` in place.
ddt_vn_apc_org = copy.deepcopy(ddt_vn_apc)

# Implementation

### Python Implementation

In [11]:
# NOTE: originally only the first part of the stencil was implemented here; the body below now computes
#  `w_con_e`, `difcoef` and the updated `ddt_vn_apc`, like the GT4Py field operator further below.
def TestStencil(c_lin_e: dace.float64[EdgeDimN, E2CDimN],
                w_con_e: dace.float64[EdgeDimN, KDimN],
                z_w_con_c_full: dace.float64[CellDimN, KDimN],
                levelmask: dace.bool_[KDimN+1],
                edge_to_cell_table: dace.int32[EdgeDimN, E2CDimN],
                
                ddqz_z_full_e: dace.float64[EdgeDimN, KDimN],    # Ver 2 (a field, not a scalar)
                cfl_w_limit: dace.float64,
                scalfac_exdiff: dace.float64,
                dtime: dace.float64,

                area_edge: dace.float64[EdgeDimN],                   # Ver 3
                tangent_orientation: dace.float64[EdgeDimN],
                inv_primal_edge_length: dace.float64[EdgeDimN],
                zeta: dace.float64[VertexDimN, KDimN],
                geofac_grdiv: dace.float64[EdgeDimN, E2C2EODimN],
                vn: dace.float64[EdgeDimN, KDimN],
                ddt_vn_apc: dace.float64[EdgeDimN, KDimN],

                E2C2EO_table: dace.int32[EdgeDimN, E2C2EODimN],
                E2V_table: dace.int32[EdgeDimN, E2VDimN],
                
):
    """Array (JAX) implementation of the ICON `mo_velocity_advection_stencil_20`.

    Steps:
    1. `w_con_e` is recomputed as the `c_lin_e`-weighted sum of `z_w_con_c_full` over the adjacent
       cells, on levels where `levelmask[k] | levelmask[k + 1]` holds, and 0 elsewhere.
       The `w_con_e` argument is therefore ignored.
    2. `difcoef` is computed where that mask holds and `|w_con_e| > cfl_w_limit * ddqz_z_full_e`.
    3. On those points `ddt_vn_apc` is increased by `difcoef * area_edge * (S + T)`, where `S` is the
       `geofac_grdiv`-weighted sum of `vn` over the `E2C2EO` neighbours and
       `T = tangent_orientation * inv_primal_edge_length * (zeta[E2V[:, 1]] - zeta[E2V[:, 0]])`.

    Missing neighbours (`-1` in the connectivity tables) are masked out with `jnp.where`; the gather
    itself wraps `-1` around to the last entry, but those values are discarded.

    Returns:
        The updated `ddt_vn_apc`, shape `(EdgeDimN, KDimN)`; the input arrays are not modified.
    """
    c_lin_e = jnp.expand_dims(c_lin_e, axis=-1)   # Adding a new dimension of size 1 at the end for proper broadcasting (taken from the test).
    #difcoef = jnp.zeros_like(w_con_e)
    
    levelmask_offset_0  = levelmask[:-1]   # This is `levelmask[Koff[0]]`, i.e. the normal field
    levelmask_offset_1  = levelmask[1:]    # This is `levelmask[Koff[1]]`, i.e. the field one level further.
    levelmask_offset_OR = levelmask_offset_0 | levelmask_offset_1   # This is essentially manual common subexpression elimination.
    
    # Gather the cell values of the adjacent cells: shape (EdgeDimN, E2CDimN, KDimN).
    z_w_con_c_full_b = z_w_con_c_full[edge_to_cell_table]

    Z = jnp.where(
                jnp.expand_dims(edge_to_cell_table != -1, axis=-1), #(grid.connectivities[E2CDim] != -1)[:, :, np.newaxis],
                c_lin_e * z_w_con_c_full_b,    #z_w_con_c_full[grid.connectivities[E2CDim]],
                0,
    )
    Y = jnp.sum(Z, axis=1)   # Sum over the neighbour axis -> (EdgeDimN, KDimN).

    # The recomputed `w_con_e` is local to this function; it is not returned.
    w_con_e = jnp.where(levelmask_offset_OR, Y, 0.0)

    difcoef = jnp.where(
        levelmask_offset_OR & (jnp.abs(w_con_e) > cfl_w_limit * ddqz_z_full_e),
        scalfac_exdiff * jnp.minimum(
            0.85 - cfl_w_limit * dtime,
            jnp.abs(w_con_e) * dtime / ddqz_z_full_e - cfl_w_limit * dtime,
        ),
        0,
    )

    # Values at the two end points of every edge: `zeta_e2v` has shape (EdgeDimN, E2VDimN, KDimN).
    zeta_e2v = zeta[E2V_table]
    zeta_e2v_0 = zeta_e2v[:, 0]
    zeta_e2v_1 = zeta_e2v[:, 1]
    # Trailing size-1 axes so the per-edge fields broadcast against the level axis.
    area_edge_ = jnp.expand_dims(area_edge, axis=-1)
    geofac_grdiv_ = jnp.expand_dims(geofac_grdiv, axis=-1)

    ddt_vn_apc_ = jnp.where(
        (levelmask_offset_OR)
        & (jnp.abs(w_con_e) > cfl_w_limit * ddqz_z_full_e),
        ddt_vn_apc
        + difcoef
        * area_edge_
        * (
           # NOTE(original author): this sum reportedly had to be deactivated for the ITIR->SDFG path
           #  to produce valid code, although the ICON test works; the reason is unknown.
            jnp.sum(
                jnp.where(
                    jnp.expand_dims(E2C2EO_table != -1, axis=-1),      #(grid.connectivities[E2C2EODim] != -1)[:, :, np.newaxis],
                    geofac_grdiv_ * vn[E2C2EO_table],  #vn[grid.connectivities[E2C2EODim]],
                    0,
                ),
                axis=1,
           ) +
            jnp.expand_dims(tangent_orientation * inv_primal_edge_length, axis=-1)
            * (
                zeta_e2v_1 #zeta[ grid.connectivities[E2VDim]][:, 1]
                - zeta_e2v_0 #- zeta[grid.connectivities[E2VDim]][:, 0]
            )
        ),
        ddt_vn_apc,
    )

    return ddt_vn_apc_
#

In [12]:
# Trace `TestStencil` with the sample inputs (JIT disabled) into a Jaxpr; this is the input of the JAX -> SDFG translator.
with jax.disable_jit(disable=True):
    TestStencil_jaxpr = jax.make_jaxpr(TestStencil)(c_lin_e, w_con_e,
                                                    z_w_con_c_full,
                                                    levelmask, edge_to_cell_table, ddqz_z_full_e, cfl_w_limit[0], scalfac_exdiff[0], dtime[0],
                                                    area_edge, tangent_orientation, inv_primal_edge_length, zeta, geofac_grdiv, vn, ddt_vn_apc,
                                                    E2C2EO_table, E2V_table)
print(TestStencil_jaxpr)

{ [34m[22m[1mlambda [39m[22m[22m; a[35m:f64[27,2][39m b[35m:f64[27,50][39m c[35m:f64[18,50][39m d[35m:bool[51][39m e[35m:i32[27,2][39m f[35m:f64[27,50][39m
    g[35m:f64[][39m h[35m:f64[][39m i[35m:f64[][39m j[35m:f64[27][39m k[35m:f64[27][39m l[35m:f64[27][39m m[35m:f64[9,50][39m n[35m:f64[27,5][39m
    o[35m:f64[27,50][39m p[35m:f64[27,50][39m q[35m:i32[27,5][39m r[35m:i32[27,2][39m. [34m[22m[1mlet
    [39m[22m[22ms[35m:f64[27,2,1][39m = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(27, 2, 1)
    ] a
    t[35m:bool[50][39m = slice[limit_indices=(50,) start_indices=(0,) strides=None] d
    u[35m:bool[50][39m = slice[limit_indices=(51,) start_indices=(1,) strides=None] d
    v[35m:bool[50][39m = or t u
    w[35m:bool[27,2][39m = lt e 0
    x[35m:i32[27,2][39m = add e 18
    y[35m:i32[27,2][39m = select_n w e x
    z[35m:i32[27,2,1][39m = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(27

In [13]:
# Reference result: run the JAX implementation eagerly.
resExp = TestStencil(c_lin_e, w_con_e, z_w_con_c_full, levelmask, edge_to_cell_table, ddqz_z_full_e, cfl_w_limit[0], scalfac_exdiff[0], dtime[0], area_edge, tangent_orientation, inv_primal_edge_length, zeta, geofac_grdiv, vn, ddt_vn_apc, E2C2EO_table, E2V_table)

if(not noSDFG):
    # Translate the Jaxpr to an SDFG; with `ret_by_arg=True` the result is written into the trailing `_out` argument.
    TestStencil_sdfg = t(TestStencil_jaxpr, simplify=True, auto_opt=5, ret_by_arg=True)

    TestStencil_sdfg(c_lin_e, w_con_e, z_w_con_c_full, levelmask, edge_to_cell_table, ddqz_z_full_e, cfl_w_limit[0], scalfac_exdiff[0], dtime[0],
                     area_edge, tangent_orientation, inv_primal_edge_length, zeta, geofac_grdiv, vn, ddt_vn_apc, E2C2EO_table, E2V_table, _out)
    resDC = _out.copy()
    # The translated SDFG must reproduce the JAX result.
    assert np.allclose(resExp, resDC)
#

In [14]:
# The implementations must not have modified the `ddt_vn_apc` input in place.
assert np.allclose(ddt_vn_apc_org, ddt_vn_apc)

In [15]:
#TestStencil_sdfg.view()

In [16]:
# Set to `True` to store the JAX -> SDFG translation on disk (e.g. for inspection in the SDFG viewer).
#  Replaces the former `if(True): pass else: ...` construct, whose saving branch could never run.
saveSDFG = False
if saveSDFG and not noSDFG:
    TestStencil_sdfg.save("./SDFG_of_stencil.json")


### GT4Py Implementation

In [17]:
import gt4py
from gt4py.next.common import GridType
from gt4py.next.ffront.decorator import field_operator, program
from gt4py.next.ffront.fbuiltins import (  # noqa: A004 # import gt4py builtin
    Field,
    abs,
    astype,
    broadcast,
    int32,
    minimum,
    neighbor_sum,
    where,
)

import os
from typing import TypeAlias

from gt4py.next.ffront.fbuiltins import float32, float64


# Working precision: always double.
wpfloat: TypeAlias = float64

# Variable precision, selected (case-insensitively) via the `FLOAT_PRECISION` environment variable:
#  "double" (default) -> float64, "mixed" -> float32; any other value is rejected.
precision = os.environ.get("FLOAT_PRECISION", "double").lower()
if precision == "double":
    vpfloat = wpfloat
elif precision == "mixed":
    vpfloat: TypeAlias = float32
else:
    raise ValueError("Only 'double' and 'mixed' precision are supported.")


from icondims import (
    E2C,
    E2C2EO,
    CellDim,
    E2C2EODim,
    E2CDim,
    EdgeDim,
    KDim,
    E2V,
    E2VDim,
    Koff,
    VertexDim,
)

# Wrap the NumPy connectivity tables into GT4Py offset providers
#  (arguments: table, source dimension, target dimension, maximal number of neighbours).
E2C_offset_provider = gt4py.next.NeighborTableOffsetProvider(edge_to_cell_table, EdgeDim, CellDim, E2CDimN)
E2C2EO_offset_provider = gt4py.next.NeighborTableOffsetProvider(E2C2EO_table, EdgeDim, EdgeDim, E2C2EODimN)
E2V_offset_provider = gt4py.next.NeighborTableOffsetProvider(E2V_table, EdgeDim, VertexDim, E2VDimN)



# GT4Py field operator of `mo_velocity_advection_stencil_20`: returns the (conditionally) updated
#  `ddt_vn_apc`, cast to `vpfloat`. Only `#` comments are used here because GT4Py parses this source.
@field_operator
def _mo_velocity_advection_stencil_20(
    levelmask: Field[[KDim], bool],
    c_lin_e: Field[[EdgeDim, E2CDim], float],
    z_w_con_c_full: Field[[CellDim, KDim], float],
    
    ddqz_z_full_e: Field[[EdgeDim, KDim], float],
    cfl_w_limit: float,
    scalfac_exdiff: float,
    dtime: float,

    area_edge: Field[[EdgeDim], float],
    tangent_orientation: Field[[EdgeDim], float],
    inv_primal_edge_length: Field[[EdgeDim], float],
    zeta: Field[[VertexDim, KDim], float],
    geofac_grdiv: Field[[EdgeDim, E2C2EODim], float],
    vn: Field[[EdgeDim, KDim], float],
    ddt_vn_apc: Field[[EdgeDim, KDim], float],
) -> Field[[EdgeDim, KDim], float]:
    # Cast the inputs to working precision.
    z_w_con_c_full_wp, ddqz_z_full_e_wp, ddt_vn_apc_wp, cfl_w_limit_wp = astype(
        (z_w_con_c_full, ddqz_z_full_e, ddt_vn_apc, cfl_w_limit), wpfloat
    )

    # Zero defaults, used wherever the conditions below do not hold.
    w_con_e = broadcast(wpfloat("0.0"), (EdgeDim, KDim))
    difcoef = broadcast(wpfloat("0.0"), (EdgeDim, KDim))
    ddt_vn_apc_wp_ = broadcast(wpfloat("0.0"), (EdgeDim, KDim))   # Different from the original kernel, but for compatibility
    
    # Cell -> edge interpolation of `z_w_con_c_full` with weights `c_lin_e`, on levels where `levelmask[k] | levelmask[k+1]`.
    w_con_e = where(
        levelmask | levelmask(Koff[1]),
        neighbor_sum(c_lin_e * z_w_con_c_full_wp(E2C), axis=E2CDim),
        w_con_e,
    )
    # Extra coefficient, only where additionally |w_con_e| > cfl_w_limit * ddqz_z_full_e.
    difcoef = where(
        (levelmask | levelmask(Koff[1]))
        & (abs(w_con_e) > astype(cfl_w_limit * ddqz_z_full_e, wpfloat)),
        scalfac_exdiff
        * minimum(
            wpfloat("0.85") - cfl_w_limit_wp * dtime,
            abs(w_con_e) * dtime / ddqz_z_full_e_wp - cfl_w_limit_wp * dtime,
        ),
        difcoef,
    )
    # Under the same condition, add `difcoef * area_edge * (...)` to `ddt_vn_apc`; elsewhere keep it unchanged.
    ddt_vn_apc_wp_ = where(
        (levelmask | levelmask(Koff[1]))
        & (abs(w_con_e) > astype(cfl_w_limit * ddqz_z_full_e, wpfloat)),
        ddt_vn_apc_wp
        + difcoef
        * area_edge
        * (
            # NOTE(original author): deactivating this neighbour sum was reportedly needed to avoid a
            #  compilation error; it is active here. Why the ICON version works is unknown.
            neighbor_sum(geofac_grdiv * vn(E2C2EO), axis=E2C2EODim) +
            tangent_orientation
            * inv_primal_edge_length
            * astype(zeta(E2V[1]) - zeta(E2V[0]), wpfloat)
        ),
        ddt_vn_apc_wp,
    )
    return astype(ddt_vn_apc_wp_, vpfloat)
#


# GT4Py program wrapping the field operator: writes the result into `_ret_` on the domain
#  [horizontal_start, horizontal_end) x [vertical_start, vertical_end).
@program(grid_type=GridType.UNSTRUCTURED)
def mo_velocity_advection_stencil_20(
    levelmask: Field[[KDim], bool],
    c_lin_e: Field[[EdgeDim, E2CDim], float],
    z_w_con_c_full: Field[[CellDim, KDim], float],

    ddqz_z_full_e: Field[[EdgeDim, KDim], float],
    cfl_w_limit: float,
    scalfac_exdiff: float,
    dtime: float,
    
    area_edge: Field[[EdgeDim], float],
    tangent_orientation: Field[[EdgeDim], float],
    inv_primal_edge_length: Field[[EdgeDim], float],
    zeta: Field[[VertexDim, KDim], float],
    geofac_grdiv: Field[[EdgeDim, E2C2EODim], float],
    vn: Field[[EdgeDim, KDim], float],
    ddt_vn_apc: Field[[EdgeDim, KDim], float],

    _ret_: Field[[EdgeDim, KDim], float],   # Output field.
    horizontal_start: int32,
    horizontal_end: int32,
    vertical_start: int32,
    vertical_end: int32,
):
    _mo_velocity_advection_stencil_20(
        levelmask,
        c_lin_e,
        z_w_con_c_full,
        ddqz_z_full_e, cfl_w_limit, scalfac_exdiff, dtime,

        area_edge, tangent_orientation, inv_primal_edge_length, zeta, geofac_grdiv, vn, ddt_vn_apc,

        out=_ret_,
        domain={
            EdgeDim: (horizontal_start, horizontal_end),
            KDim: (vertical_start, vertical_end),
        },
    )

In [18]:
from gt4py.next.program_processors.runners.dace_iterator import run_dace_iterator

# Wrap the NumPy inputs into GT4Py fields with named dimensions for the ITIR program.
levelmask_F = gt4py.next.as_field([KDim], levelmask)
c_lin_e_F = gt4py.next.as_field([EdgeDim, E2CDim], c_lin_e)
z_w_con_c_full_F = gt4py.next.as_field([CellDim, KDim], z_w_con_c_full)

ddqz_z_full_e_F = gt4py.next.as_field([EdgeDim, KDim], ddqz_z_full_e)



area_edge_F = gt4py.next.as_field([EdgeDim], area_edge)

tangent_orientation_F = gt4py.next.as_field([EdgeDim], tangent_orientation)

inv_primal_edge_length_F = gt4py.next.as_field([EdgeDim], inv_primal_edge_length)
zeta_F = gt4py.next.as_field([VertexDim, KDim], zeta)
geofac_grdiv_F = gt4py.next.as_field([EdgeDim, E2C2EODim], geofac_grdiv)
vn_F = gt4py.next.as_field([EdgeDim, KDim], vn)
ddt_vn_apc_F = gt4py.next.as_field([EdgeDim, KDim], ddt_vn_apc)

# Zero-initialised output field of the GT4Py program.
_ret_ = gt4py.next.as_field([EdgeDim, KDim], np.zeros((EdgeDimN, KDimN)))


In [19]:
if(not noITIR):
    # Translate the GT4Py program (ITIR) to an SDFG, run it once and return the SDFG with its arguments.
    TestStencil_itir_sdfg, TestStencil_itir_sdfg_args = run_dace_iterator(
        mo_velocity_advection_stencil_20.itir,
        
        levelmask_F, c_lin_e_F, z_w_con_c_full_F,
    
        ddqz_z_full_e_F, cfl_w_limit[0], scalfac_exdiff[0], dtime[0],
        area_edge_F, tangent_orientation_F, inv_primal_edge_length_F, zeta_F, geofac_grdiv_F, vn_F, ddt_vn_apc_F,

        _ret_,
        # horizontal_start, horizontal_end, vertical_start, vertical_end: the full domain.
        0, EdgeDimN, 0, KDimN,
    
        return_sdfg_args=True, run_sdfg=True,
        auto_optimize=True,
        cpu_args=_cpu_args,
        offset_provider={'E2C': E2C_offset_provider, "Koff": KDim, 'E2C2EO': E2C2EO_offset_provider, 'E2V': E2V_offset_provider}
    )

    # The input must be unchanged and the ITIR result must match the JAX reference.
    assert np.allclose(ddt_vn_apc_org, ddt_vn_apc)
    assert np.allclose(TestStencil_itir_sdfg_args['_ret_'], resExp)
#



## "Performance" / Runtime

In [20]:
import timeit
# `timeit` settings: every measurement runs the function `number` times; `repeat` measurements are taken.
repeat, number = 3000, 50

In [21]:
# JIT-compiled JAX implementation used in the runtime comparison.
TestStencil_jax = jax.jit(TestStencil)

In [22]:
# Compile both SDFGs with the chosen build type and compiler flags; the config changes are scoped to this `with` block.
with dace.config.temporary_config():
    dace.config.Config.set("compiler", "build_type", value=_build_type)
    if(len(_cpu_args) > 0):
        dace.config.Config.set("compiler", "cpu", "args", value=_cpu_args)
    TestStencil_itir_prog = None if noITIR else TestStencil_itir_sdfg.compile()
    TestStencil_prog = None if noSDFG else TestStencil_sdfg.compile()


In [23]:
# Call the compiled programs once so that they are loaded before timing; the timing cells reuse
#  `_lastargs` (presumably the arguments of this last regular call — TODO confirm in DaCe).
if(not noITIR):
    with dace.config.temporary_config():
        dace.Config.set('compiler', 'allow_view_arguments', value=True)
        _ = TestStencil_itir_prog(**TestStencil_itir_sdfg_args)

In [24]:
# Call the compiled JAX -> SDFG program once so that it is loaded before timing
#  (the former stray `pass` in front of the call had no effect and is removed).
if(not noSDFG):
    _ = TestStencil_prog(c_lin_e, w_con_e, z_w_con_c_full, levelmask, edge_to_cell_table, ddqz_z_full_e, cfl_w_limit[0], scalfac_exdiff[0], dtime[0],
                         area_edge, tangent_orientation, inv_primal_edge_length, zeta, geofac_grdiv, vn, ddt_vn_apc, E2C2EO_table, E2V_table, _out)

In [25]:
# Call the jitted function once: the first call triggers tracing and compilation. JAX dispatches the
#  execution asynchronously, so wait for it; otherwise it could overlap with the first timed repetition.
_ = TestStencil_jax(c_lin_e, w_con_e, z_w_con_c_full, levelmask, edge_to_cell_table, ddqz_z_full_e, cfl_w_limit[0], scalfac_exdiff[0], dtime[0],
                     area_edge, tangent_orientation, inv_primal_edge_length, zeta, geofac_grdiv, vn, ddt_vn_apc, E2C2EO_table, E2V_table).block_until_ready()

In [26]:
#%timeit TestStencil_itir_prog(**TestStencil_itir_sdfg_args_)
if(not noITIR):
    # Time the compiled ITIR program via `_fast_call` on `_lastargs`, i.e. re-using the arguments of a
    #  previous call (presumably skipping argument processing — TODO confirm in DaCe).
    def TestStencil_itir_prog_TI():
        TestStencil_itir_prog._fast_call(*TestStencil_itir_prog._lastargs)
    #
    with dace.config.temporary_config():
        dace.Config.set('compiler', 'allow_view_arguments', value=True)
        t_ = timeit.Timer(TestStencil_itir_prog_TI)
        r = np.array(t_.repeat(repeat, number))
    print(f"Run the function {number} times inside a loop and repeated the loop {repeat} times:\n- min: {r.min()}s\n- max: {r.max()}s\n- mean: {r.mean()}s\n- std: {r.std()}s")


Run the function 3 times inside a loop and repeated the loop 2 times:
- min: 0.00044409399561118335s
- max: 0.0006593979996978305s
- mean: 0.0005517459976545069s
- std: 0.00010765200204332359s


In [27]:
#%timeit TestStencil_prog(c_lin_e, w_con_e, z_w_con_c_full, levelmask, edge_to_cell_table, ddqz_z_full_e, cfl_w_limit[0], scalfac_exdiff[0], dtime[0])
if(not noSDFG):
    # Time the compiled JAX -> SDFG program via `_fast_call`, re-using the arguments stored in `_lastargs`.
    def TestStencil_prog_TI():
        TestStencil_prog._fast_call(*TestStencil_prog._lastargs)
    #

    t_ = timeit.Timer(TestStencil_prog_TI)
    r = np.array(t_.repeat(repeat, number))
    print(f"Run the function {number} times inside a loop and repeated the loop {repeat} times:\n- min: {r.min()}s\n- max: {r.max()}s\n- mean: {r.mean()}s\n- std: {r.std()}s")


Run the function 3 times inside a loop and repeated the loop 2 times:
- min: 6.949700036784634e-05s
- max: 0.0004733749956358224s
- mean: 0.0002714359980018344s
- std: 0.00020193899763398804s


In [28]:
# Time the jitted JAX implementation. JAX dispatches asynchronously, so every call must wait for its
#  result with `block_until_ready()`; otherwise only the dispatch overhead would be measured
#  (this also matches how JAX is timed in the benchmark loop below).
_JAX_ARGS = [c_lin_e, w_con_e, z_w_con_c_full, levelmask, edge_to_cell_table, ddqz_z_full_e, cfl_w_limit[0], scalfac_exdiff[0], dtime[0],
              area_edge, tangent_orientation, inv_primal_edge_length, zeta, geofac_grdiv, vn, ddt_vn_apc, E2C2EO_table, E2V_table]
def TestStencil_jax_TI():
    TestStencil_jax(*_JAX_ARGS).block_until_ready()
#

t_ = timeit.Timer(TestStencil_jax_TI)
r = np.array(t_.repeat(repeat, number))

print(f"Run the function {number} times inside a loop and repeated the loop {repeat} times:\n- min: {r.min()}s\n- max: {r.max()}s\n- mean: {r.mean()}s\n- std: {r.std()}s")


Run the function 3 times inside a loop and repeated the loop 2 times:
- min: 0.0001045500030159019s
- max: 0.00032244299654848874s
- mean: 0.00021349649978219531s
- std: 0.00010894649676629342s


In [29]:
resTime_sdfg = []
resTime_jax  = []
resTime_itir = []
KDimNs       = [50] + [100 * i  for i in range(1, 21, 3)]

# Benchmark all three implementations for a growing number of vertical levels `KDimN_`.
#  For every size all K-dependent inputs, translations and compiled programs are regenerated.
for KDimN_ in KDimNs:
    print(f"Start with KDimN = {KDimN_}", flush=True)
    w_con_e = np.random.rand(EdgeDimN, KDimN_).astype(np.float64)
    z_w_con_c_full = np.random.rand(CellDimN, KDimN_).astype(np.float64)
    ddqz_z_full_e = np.random.rand(EdgeDimN, KDimN_).astype(np.float64)
    levelmask = (np.random.rand(KDimN_ + 1) < 0.5).astype(np.bool_)

    area_edge = np.random.rand(EdgeDimN).astype(np.float64)
    tangent_orientation = np.random.rand(EdgeDimN).astype(np.float64)
    inv_primal_edge_length = np.random.rand(EdgeDimN).astype(np.float64)
    zeta = np.random.rand(VertexDimN, KDimN_).astype(np.float64)
    geofac_grdiv = np.random.rand(EdgeDimN, E2C2EODimN).astype(np.float64)
    vn = np.random.rand(EdgeDimN, KDimN_).astype(np.float64)
    ddt_vn_apc = np.random.rand(EdgeDimN, KDimN_).astype(np.float64)
    
    _out = np.random.rand(EdgeDimN, KDimN_).astype(np.float64)   # This is the argument for returning the result in the `JAX -> SDFG` Translation

    # Generating JAX Source
    with jax.disable_jit(disable=True):
        TestStencil_jaxpr = jax.make_jaxpr(TestStencil)(c_lin_e, w_con_e, z_w_con_c_full, levelmask, edge_to_cell_table, ddqz_z_full_e, cfl_w_limit[0], scalfac_exdiff[0], dtime[0],
                                                        area_edge, tangent_orientation, inv_primal_edge_length, zeta, geofac_grdiv, vn, ddt_vn_apc, E2C2EO_table, E2V_table)
    #

    # Generating the ITIR Based translation
    if(not noITIR):
        # Rebuild every field from this iteration's arrays; re-using the `*_F` fields of the setup cells
        #  would feed stale, `KDimN`-sized data (e.g. `zeta_F`, `vn_F`, `ddt_vn_apc_F`) to the program.
        levelmask_F = gt4py.next.as_field([KDim], levelmask)
        c_lin_e_F = gt4py.next.as_field([EdgeDim, E2CDim], c_lin_e)
        z_w_con_c_full_F = gt4py.next.as_field([CellDim, KDim], z_w_con_c_full)
        ddqz_z_full_e_F = gt4py.next.as_field([EdgeDim, KDim], ddqz_z_full_e)
        area_edge_F = gt4py.next.as_field([EdgeDim], area_edge)
        tangent_orientation_F = gt4py.next.as_field([EdgeDim], tangent_orientation)
        inv_primal_edge_length_F = gt4py.next.as_field([EdgeDim], inv_primal_edge_length)
        zeta_F = gt4py.next.as_field([VertexDim, KDim], zeta)
        geofac_grdiv_F = gt4py.next.as_field([EdgeDim, E2C2EODim], geofac_grdiv)
        vn_F = gt4py.next.as_field([EdgeDim, KDim], vn)
        ddt_vn_apc_F = gt4py.next.as_field([EdgeDim, KDim], ddt_vn_apc)
        _ret_ = gt4py.next.as_field([EdgeDim, KDim], np.zeros_like(ddqz_z_full_e))
        TestStencil_itir_sdfg, TestStencil_itir_sdfg_args = run_dace_iterator(
                mo_velocity_advection_stencil_20.itir, 
                levelmask_F, c_lin_e_F, z_w_con_c_full_F, ddqz_z_full_e_F, cfl_w_limit[0], scalfac_exdiff[0], dtime[0],
                area_edge_F, tangent_orientation_F, inv_primal_edge_length_F, zeta_F, geofac_grdiv_F, vn_F, ddt_vn_apc_F,
                _ret_,
                # Full domain of this iteration: all edges and all `KDimN_` levels (not the global `KDimN`).
                0, EdgeDimN, 0, KDimN_,
                return_sdfg_args=True, run_sdfg=False, auto_optimize=True,
                cpu_args=_cpu_args,
                offset_provider={'E2C': E2C_offset_provider, "Koff": KDim, 'E2C2EO': E2C2EO_offset_provider, 'E2V': E2V_offset_provider}
        )
        #
    #

    # Translating from JAX to SDFG
    if(not noSDFG):
        TestStencil_sdfg = t(TestStencil_jaxpr, simplify=True, auto_opt=5, ret_by_arg=True)

    # Create the computations
    with dace.config.temporary_config():
        dace.config.Config.set("compiler", "build_type", value=_build_type)
        if(len(_cpu_args) > 0):
            dace.config.Config.set("compiler", "cpu", "args", value=_cpu_args)
        if(not noSDFG):
            TestStencil_prog = TestStencil_sdfg.compile()
            # Ensure that everything is initialized
            _ = TestStencil_prog(c_lin_e, w_con_e, z_w_con_c_full, levelmask, edge_to_cell_table, ddqz_z_full_e, cfl_w_limit[0], scalfac_exdiff[0], dtime[0],
                                 area_edge, tangent_orientation, inv_primal_edge_length, zeta, geofac_grdiv, vn, ddt_vn_apc, E2C2EO_table, E2V_table, _out)
        if(not noITIR):
            TestStencil_itir_prog = TestStencil_itir_sdfg.compile()
            # Ensure that everything is initialized
            with dace.config.temporary_config():
                dace.Config.set('compiler', 'allow_view_arguments', value=True)
                _ = TestStencil_itir_prog(**TestStencil_itir_sdfg_args)
    #

    # Timed callables re-use the arguments of the initialisation calls above via `_lastargs`.
    if(not noITIR):
        TestStencil_itir_prog_TI = lambda: TestStencil_itir_prog._fast_call(*TestStencil_itir_prog._lastargs)
    #
    if(not noSDFG):
        TestStencil_prog_TI = lambda: TestStencil_prog._fast_call(*TestStencil_prog._lastargs)
    #

    # Make JAX block to prevent asynchronous dispatch from distorting the timing
    #  (see `https://jax.readthedocs.io/en/latest/async_dispatch.html`).
    TestStencil_jax_TI = lambda: TestStencil_jax(c_lin_e, w_con_e, z_w_con_c_full, levelmask, edge_to_cell_table, ddqz_z_full_e, cfl_w_limit[0], scalfac_exdiff[0], dtime[0],
                                                 area_edge, tangent_orientation, inv_primal_edge_length, zeta, geofac_grdiv, vn, ddt_vn_apc, E2C2EO_table, E2V_table).block_until_ready()
    TestStencil_jax_TI()    # Call it once so that the new input shapes are traced and compiled before timing.
    
    # Test ITIR Translation
    if(not noITIR):
        with dace.config.temporary_config():
            dace.Config.set('compiler', 'allow_view_arguments', value=True)
            t_ = timeit.Timer(TestStencil_itir_prog_TI)
            r = np.array(t_.repeat(repeat, number))
    else:
        r = np.zeros(repeat)
    resTime_itir.append(r)

    # Test SDFG
    if(not noSDFG):
        t_ = timeit.Timer(TestStencil_prog_TI)
        r = np.array(t_.repeat(repeat, number))
    else:
        r = np.zeros(repeat)
    resTime_sdfg.append(r)

    # Test JAX
    t_ = timeit.Timer(TestStencil_jax_TI)
    r = np.array(t_.repeat(repeat, number))
    resTime_jax.append(r)
#

# Store the raw timings (one array of `repeat` samples per size and backend) in a time-stamped pickle.
Now = datetime.now()
with open(f'ICON__{Now.year}-{Now.month}-{Now.day}_{Now.hour}-{Now.minute}-{Now.second}.pickle', 'wb') as F:
    dump((KDimNs, resTime_itir, resTime_jax, resTime_sdfg), F)
#

Start with KDimN = 50


KeyboardInterrupt: 

In [None]:
# Condense every backend's list of timing samples (one array per `KDimN`) into per-size statistics.
#  The tuple order is always (JAX -> SDFG, JAX, ITIR -> SDFG).
resTime_sdfg_min, resTime_jax_min, resTime_itir_min = (
    np.array([samples.min() for samples in runs])
    for runs in (resTime_sdfg, resTime_jax, resTime_itir)
)

resTime_sdfg_mean, resTime_jax_mean, resTime_itir_mean = (
    np.array([samples.mean() for samples in runs])
    for runs in (resTime_sdfg, resTime_jax, resTime_itir)
)

resTime_sdfg_std, resTime_jax_std, resTime_itir_std = (
    np.array([samples.std() for samples in runs])
    for runs in (resTime_sdfg, resTime_jax, resTime_itir)
)


In [None]:
import matplotlib as mpl
from matplotlib import pyplot as plt

In [None]:
# Design matrix for the linear least-squares fit `t(k) = a * k + b`.
X = np.vstack([KDimNs, np.ones_like(KDimNs)]).T

# Minimal runtime per call in microseconds for every backend, plus the fitted line (dashed).
#  Slope and intercept in the legend are both given in microseconds, matching the plotted values
#  (the intercept was previously scaled by 1e3, i.e. shown in milliseconds).
for ID, Time in {'JAX -> SDFG':  resTime_sdfg_min,
                 'ITIR -> SDFG': resTime_itir_min,
                 'JAX':          resTime_jax_min, }.items():
    plt.plot(KDimNs, Time * 1000000 / number, label=ID)
    col = plt.gca().lines[-1].get_color()
    a, b = np.linalg.lstsq(X, Time / number, rcond=None)[0]   # Seconds per call.

    Times_ = a * np.array(KDimNs) + b
    plt.plot(
        KDimNs, Times_ * 1000000,
        c=col, linestyle='--', alpha=0.8, 
        label=f'${int(a * (1e6*1e4))/1e4} \\cdot k + {int(b * (1e6*1e4))/1e4}$',
    )
#


plt.legend()
plt.xlabel('KDim')
plt.ylabel('$t_{min run}$ [$\\mu$s]')
plt.grid(visible=True)
#plt.yscale('log')

plt.title("Minimal Runtime")

plt.show()


In [None]:
# Design matrix for the linear least-squares fit `t(k) = a * k + b`.
X = np.vstack([KDimNs, np.ones_like(KDimNs)]).T

# Runtime per call relative to the smallest size, in milliseconds. The matching fitted model is
#  t(k) - t(k_0) = a * (k - k_0): the intercept `b` cancels, so it is neither plotted nor shown in the label.
for ID, Time in {'JAX -> SDFG':  resTime_sdfg_min,
                 'ITIR -> SDFG': resTime_itir_min,
                 'JAX':          resTime_jax_min, }.items():
    plt.plot(KDimNs, (Time - Time[0]) * 1000 / number, label=ID)
    col = plt.gca().lines[-1].get_color()
    a, _ = np.linalg.lstsq(X, Time / number, rcond=None)[0]   # Seconds per call.

    Times_ = a * (np.array(KDimNs) - KDimNs[0])
    plt.plot(
        KDimNs, Times_ * 1000,
        c=col, linestyle='--', alpha=0.8, 
        label=f'${int(a * (1e3*1e4))/1e4} \\cdot (k - {KDimNs[0]})$',
    )

#


plt.legend()
plt.xlabel('KDim')
plt.ylabel('$\\hat{t}_{min run}$ [ms]')
plt.grid(visible=True)
plt.yscale('log')

plt.title("Relative Runtime")

plt.show()
