# ICON Stencil
This stencil is based on `model/atmosphere/dycore/src/icon4py/model/atmosphere/dycore/mo_solve_nonhydro_stencil_20.py`.
It is very simple and most faithfull implementation.

A reference `NumPy` implementation can be found in `model/atmosphere/dycore/tests/stencil_tests/test_mo_velocity_advection_stencil_20.py`.


In [1]:
import os
ncpu=1
os.environ['XLA_FLAGS'] = f"--xla_force_host_platform_device_count={ncpu}"
os.environ["XLA_FLAGS"] = ("--xla_cpu_multi_thread_eigen=false "
                           "intra_op_parallelism_threads=1")
os.environ['JAX_PLATFORMS'] = "cpu"

import numpy as np
import jax
import sys
import jax.numpy as jnp
from jax import jit, grad, vmap
from jax import random

# This must be enabled when `make_jaxpr` is called, because otherwhise we get problems.
jax.config.update("jax_enable_x64", True)

import dace


In [2]:
from JaxprToSDFG import  JaxprToSDFG
t = JaxprToSDFG()

### Demo Input

### Grid Dimensions
Since we miss the GT4Py Dimension mechanism, we will make some tricks.
- `{K, Cell, Edge, Vertex}DimN` is the number of either Vertex, Cell, Edges or vertical levels.
- Valid indexes always life within the range `[0, N-1]` where `N` is the above entity.
- However, they do not denote at which position they are.



In [3]:
KDimN = 5     # Vertical levels
CellDimN = 6   # How many cells we have
EdgeDimN = 12
VertexDimN = 100


# `E2C` is the offset map that transforms _edges_ to _cells_.
#   Since a edge is adijacent to at most two cells this dimension is 2 but at least 1.
E2CDimN = 2


### Fields
These are the fields we are using as an inoput material

In [4]:
# Field defined on the edges of the grid and different in every hight level.
w_con_e = np.random.rand(EdgeDimN, KDimN).astype(np.float64)

# This is a field, which is the same for every height level.
#  It is defined for every edge but for each edge it has two values (thus `E2CDimN`) which is the result
#  of an cell to edge mapping.
c_lin_e = np.random.rand(EdgeDimN, E2CDimN).astype(np.float64)

# This is a field that defines for every cell in every height level one value.
z_w_con_c_full = np.random.rand(CellDimN, KDimN).astype(np.float64)

ddqz_z_full_e = np.random.rand(EdgeDimN, KDimN).astype(np.float64)

cfl_w_limit = np.full((1), 4.0)
scalfac_exdiff = np.full((1), 6.0)
dtime = np.full((1), 2.0)


In [5]:
A = np.ones((1, 1, 5))


### Offset Providers
Here we define the offset providors, i.e. the connectivity.
It is important that they have to be read in a certain way.

In [6]:
# This is the level mask, the `+ 1` is taken from the test but it makes sense, at least a bit.
#  For `levelmask[Koff]` use `levelmask[:-1]`
#  and for `levelmask[Koff[+1]]` use `levelmask[1:]`.
levelmask = (np.random.rand(KDimN + 1) < 0.5).astype(np.bool_)

# As stated above, each edge is either adijacent to 1 cell (at the boundary) or 2 (in the inner part)
#  This table describes which cell are adijacent to a certain edge, in case it is undefined we have `-1`.  
edge_to_cell_table = np.array([
    [0, -1], # edge 0 (neighbours: cell 0)
    [2, -1], # edge 1
    [2, -1], # edge 2
    [3, -1], # edge 3
    [4, -1], # edge 4
    [5, -1], # edge 5
    [0, 5],  # edge 6 (neighbours: cell 0, cell 5)
    [0, 1],  # edge 7
    [1, 2],  # edge 8
    [1, 3],  # edge 9
    [3, 4],  # edge 10
    [4, 5]   # edge 11
]).astype(np.int32)

# We have a triangular grid, this means a cell is bounded by three edges.
#  In unstructured it could be, we have many geometrical forms, in this case
#  we would allocate the matrix such that it fits the largest one and fill up with `-1` if we do not need them.
cell_to_edge_table = np.array([
    [0, 6, 7],   # cell 0 (neighbors: edge 0, edge 6, edge 7)
    [7, 8, 9],   # cell 1
    [1, 2, 8],   # cell 2
    [3, 9, 10],  # cell 3
    [4, 10, 11], # cell 4
    [5, 6, 11],  # cell 5
]).astype(np.int32)


assert edge_to_cell_table.shape[0] == EdgeDimN
assert cell_to_edge_table.shape[0] == CellDimN


# Implementation

### Python Implementation

In [7]:
# Note that this is only the first part of the stencil.
def TestStencil(c_lin_e,
                w_con_e,
                z_w_con_c_full,
                levelmask,
                edge_to_cell_table,
                
                ddqz_z_full_e,    # Ver 2
                cfl_w_limit,
                scalfac_exdiff,
                dtime,
):
    c_lin_e = jnp.expand_dims(c_lin_e, axis=-1)   # Adding a new dimensions of shape 1 at the end for propper broadcasting (taken from the test).
    
    levelmask_offset_0 = levelmask[:-1]   # This is `levelmask[Koff[0]]`, i.e. the normal field
    levelmask_offset_1 = levelmask[1:]    # This is `levelmask[Koff[1]]`, i.e. the field at the top.

    (EdgeDimN, E2CDimN, _) = c_lin_e.shape
    (CellDimN, KDimN)      = z_w_con_c_full.shape

    if(False):
        # This is more general but it requieres that a `scatter` operation is present.
        z_w_con_c_full_b = jnp.zeros((EdgeDimN, E2CDimN, KDimN))
        z_w_con_c_full_b = z_w_con_c_full_b.at[:, 0, :].set(z_w_con_c_full[edge_to_cell_table[:, 0], :]) 
        z_w_con_c_full_b = z_w_con_c_full_b.at[:, 1, :].set(z_w_con_c_full[edge_to_cell_table[:, 1], :])  # The `-1` in the `edge_to_cell_table` will be resolved as some bullshit values, that will be removed further down in the where.
    else:
        z_w_0_ = z_w_con_c_full[edge_to_cell_table[:, 0], :]
        z_w_1_ = z_w_con_c_full[edge_to_cell_table[:, 1], :]
        z_w_con_c_full_0 = jnp.expand_dims(z_w_0_, axis=1)
        z_w_con_c_full_1 = jnp.expand_dims(z_w_1_, axis=1)
        z_w_con_c_full_b = jnp.concatenate((z_w_con_c_full_0, z_w_con_c_full_1), axis=1)
    #
    MaskZ = jnp.expand_dims(edge_to_cell_table != -1, axis=-1)
    Z = jnp.where(
                MaskZ, #(grid.connectivities[E2CDim] != -1)[:, :, np.newaxis],
                c_lin_e * z_w_con_c_full_b,    #z_w_con_c_full[grid.connectivities[E2CDim]],
                0,
    )
    Y = jnp.sum(Z, axis=1)
    Mask = (levelmask_offset_0) | (levelmask_offset_1)
    w_con_e = jnp.where(Mask, Y, w_con_e)

    difcoef = jnp.zeros_like(w_con_e)

    difcoef = jnp.where(
        ((levelmask_offset_0) | (levelmask_offset_1))
        & (jnp.abs(w_con_e) > cfl_w_limit * ddqz_z_full_e),
        scalfac_exdiff
        * jnp.minimum(
            0.85 - cfl_w_limit * dtime,
            jnp.abs(w_con_e) * dtime / ddqz_z_full_e - cfl_w_limit * dtime,
        ),
        difcoef,
    )


    
    return difcoef
#

In [8]:
with jax.disable_jit(disable=True):
    TestStencil_jaxpr = jax.make_jaxpr(TestStencil)(c_lin_e, w_con_e, z_w_con_c_full, levelmask, edge_to_cell_table, ddqz_z_full_e, cfl_w_limit, scalfac_exdiff, dtime)
print(TestStencil_jaxpr)

{ [34m[22m[1mlambda [39m[22m[22m; a[35m:f64[12,2][39m b[35m:f64[12,5][39m c[35m:f64[6,5][39m d[35m:bool[6][39m e[35m:i32[12,2][39m f[35m:f64[12,5][39m g[35m:f64[1][39m
    h[35m:f64[1][39m i[35m:f64[1][39m. [34m[22m[1mlet
    [39m[22m[22mj[35m:f64[12,2,1][39m = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(12, 2, 1)
    ] a
    k[35m:bool[5][39m = slice[limit_indices=(5,) start_indices=(0,) strides=None] d
    l[35m:bool[5][39m = slice[limit_indices=(6,) start_indices=(1,) strides=None] d
    m[35m:i32[12,1][39m = slice[limit_indices=(12, 1) start_indices=(0, 0) strides=None] e
    n[35m:i32[12][39m = squeeze[dimensions=(1,)] m
    o[35m:bool[12][39m = lt n 0
    p[35m:i32[12][39m = add n 6
    q[35m:i32[12][39m = select_n o n p
    r[35m:i32[12,1][39m = broadcast_in_dim[broadcast_dimensions=(0,) shape=(12, 1)] q
    s[35m:f64[12,5][39m = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(1,), collaps

In [9]:
TestStencil_sdfg = t(TestStencil_jaxpr, simplify=True, auto_opt=5)

In [10]:
resExp = TestStencil(c_lin_e, w_con_e, z_w_con_c_full, levelmask, edge_to_cell_table, ddqz_z_full_e, cfl_w_limit, scalfac_exdiff, dtime)
resDC  = TestStencil_sdfg(c_lin_e, w_con_e, z_w_con_c_full, levelmask, edge_to_cell_table, ddqz_z_full_e, cfl_w_limit, scalfac_exdiff, dtime)

assert np.allclose(resExp, resDC)


In [11]:
TestStencil_sdfg.name

'jax_139734330774336'

In [12]:
TestStencil_sdfg

In [13]:
TestStencil_sdfg.view()

File saved at /tmp/tmp4mpxhg4z.sdfg.html


### GT4Py Implementation

In [14]:
import gt4py
from gt4py.next.common import GridType
from gt4py.next.ffront.decorator import field_operator, program
from gt4py.next.ffront.fbuiltins import (  # noqa: A004 # import gt4py builtin
    Field,
    abs,
    broadcast,
    int32,
    minimum,
    neighbor_sum,
    where,
)

from icondims import (
    E2C,
    E2C2EO,
    E2V,
    CellDim,
    E2C2EODim,
    E2CDim,
    EdgeDim,
    KDim,
    Koff,
    VertexDim,
)

E2C_offset_provider = gt4py.next.NeighborTableOffsetProvider(edge_to_cell_table, EdgeDim, CellDim, 2)


@field_operator
def _mo_velocity_advection_stencil_20(
    levelmask: Field[[KDim], bool],
    c_lin_e: Field[[EdgeDim, E2CDim], float],
    z_w_con_c_full: Field[[CellDim, KDim], float],
    
    ddqz_z_full_e: Field[[EdgeDim, KDim], float],
    cfl_w_limit: float,
    scalfac_exdiff: float,
    dtime: float,
) -> Field[[EdgeDim, KDim], float]:
    w_con_e = broadcast(0.0, (EdgeDim, KDim))
    difcoef = broadcast(0.0, (EdgeDim, KDim))
    w_con_e = where(
        levelmask | levelmask(Koff[1]),
        neighbor_sum(c_lin_e * z_w_con_c_full(E2C), axis=E2CDim),
        w_con_e,
    )

    difcoef = where(
        (levelmask | levelmask(Koff[1])) & (abs(w_con_e) > cfl_w_limit * ddqz_z_full_e),
        scalfac_exdiff
        * minimum(
            0.85 - cfl_w_limit * dtime,
            abs(w_con_e) * dtime / ddqz_z_full_e - cfl_w_limit * dtime,
        ),
        difcoef,
    )
    
    return difcoef
#


@program(grid_type=GridType.UNSTRUCTURED)
def mo_velocity_advection_stencil_20(
    levelmask: Field[[KDim], bool],
    c_lin_e: Field[[EdgeDim, E2CDim], float],
    z_w_con_c_full: Field[[CellDim, KDim], float],

    ddqz_z_full_e: Field[[EdgeDim, KDim], float],
    cfl_w_limit: float,
    scalfac_exdiff: float,
    dtime: float,
    
    _ret_: Field[[EdgeDim, KDim], float],
    horizontal_start: int32,
    horizontal_end: int32,
    vertical_start: int32,
    vertical_end: int32,
):
    _mo_velocity_advection_stencil_20(
        levelmask,
        c_lin_e,
        z_w_con_c_full,
        ddqz_z_full_e, cfl_w_limit, scalfac_exdiff, dtime,
        out=_ret_,
        domain={
            EdgeDim: (horizontal_start, horizontal_end),
            KDim: (vertical_start, vertical_end),
        },
    )

In [15]:
from gt4py.next.program_processors.runners.dace_iterator import run_dace_iterator

levelmask_F = gt4py.next.as_field([KDim], levelmask)
c_lin_e_F = gt4py.next.as_field([EdgeDim, E2CDim], c_lin_e)
z_w_con_c_full_F = gt4py.next.as_field([CellDim, KDim], z_w_con_c_full)

ddqz_z_full_e_F = gt4py.next.as_field([EdgeDim, KDim], ddqz_z_full_e)

_ret_ = gt4py.next.as_field([EdgeDim, E2CDim], np.zeros_like(resExp))


In [16]:
TestStencil_itir_sdfg = run_dace_iterator(
    mo_velocity_advection_stencil_20.itir,
    levelmask_F, c_lin_e_F, z_w_con_c_full_F,

    ddqz_z_full_e_F, cfl_w_limit[0], scalfac_exdiff[0], dtime[0],

    _ret_,
    0, EdgeDimN, 0, KDimN,

    return_sdfg=True, run_sdfg=True,
    offset_provider={'E2C':E2C_offset_provider, "Koff": KDim})

Will return the generated SDFG.


In [17]:
TestStencil_itir_sdfg

In [18]:
TestStencil_itir_sdfg.view()

File saved at /tmp/tmp0qx8hxtd.sdfg.html
