# Arakawa Scheme
> Implementing Stencils efficiently in Python
- toc: True

By Akio Arakawa from "Computational Design for Long-Term Numerical Integration of the Equations of Fluid Motion: Two-Dimensional Incompressible Flow"

https://www.sciencedirect.com/science/article/pii/S0021999197956977

## Introduction

$$ 
\frac{\partial \zeta}{\partial t} + \mathbf{v} \cdot \nabla \zeta = 0
$$
with: $\mathbf{v = k} \times \nabla \psi, \quad \zeta=\mathbf{k\cdot\nabla\times v \equiv \nabla^2} \psi$

### Rewriting

$$ 
\frac{\partial \zeta}{\partial t} = J(\zeta, \psi) 
\quad or \quad 
\frac{\partial \nabla^2 \psi}{\partial t} = J(\nabla^2 \psi, \psi)
$$

## Explicit Calculation

### Equations

$$\frac{\mathbb{J}^{++} + \mathbb{J}^{+\times} + \mathbb{J}^{\times +}}{3}$$

$\mathbb{J}^{++}(\zeta, \psi) 
= \frac{1}{4d^2} \left[ 
    \left( \zeta_{i+1, j}   - \zeta_{i-1, j}   \right) \left( \psi_{i,   j+1} - \psi_{i,   j-1} \right)
   -\left( \zeta_{i,   j+1} - \zeta_{i,   j-1} \right) \left( \psi_{i+1, j}   - \psi_{i-1, j}   \right)
\right]
$

$\mathbb{J}^{+\times}(\zeta, \psi)
= \frac{1}{4d^2} \left[
    \zeta_{i+1, j}   \left( \psi_{i+1, j+1} - \psi_{i+1, j-1} \right)
   -\zeta_{i-1, j}   \left( \psi_{i-1, j+1} - \psi_{i-1, j-1} \right)
   -\zeta_{i,   j+1} \left( \psi_{i+1, j+1} - \psi_{i-1, j+1} \right)
   +\zeta_{i,   j-1} \left( \psi_{i+1, j-1} - \psi_{i-1, j-1} \right)
\right]
$

$\mathbb{J}^{+\times}(\zeta, \psi)
= \frac{1}{4d^2} \left[
    \zeta_{i+1, j+1} \left( \psi_{i,   j+1} - \psi_{i+1, j}   \right)
   -\zeta_{i-1, j-1} \left( \psi_{i-1, j}   - \psi_{i, j-1} \right)
   -\zeta_{i-1, j+1} \left( \psi_{i,   j+1} - \psi_{i-1, j}   \right)
   +\zeta_{i+1, j-1} \left( \psi_{i+1, j}   - \psi_{i,   j-1} \right)
\right]
$

### In Python

In [1]:
def jpp(zeta, psi, d, i, j):
    return ((zeta[i+1, j  ] - zeta[i-1, j  ])*(psi[i,   j+1] - psi[i,   j-1])
           -(zeta[i,   j+1] - zeta[i,   j-1])*(psi[i+1, j  ] - psi[i-1, j  ]))/(4*d**2)

def jpx(zeta, psi, d, i, j):
    return (zeta[i+1, j  ]*(psi[i+1, j+1] - psi[i+1, j-1])
           -zeta[i-1, j  ]*(psi[i-1, j+1] - psi[i-1, j-1])
           -zeta[i,   j+1]*(psi[i+1, j+1] - psi[i-1, j+1])
           +zeta[i,   j-1]*(psi[i+1, j-1] - psi[i-1, j-1]))/(4*d**2)

def jxp(zeta, psi, d, i, j):
    return (zeta[i+1, j+1]*(psi[i,   j+1] - psi[i+1, j  ])
           -zeta[i-1, j-1]*(psi[i-1, j  ] - psi[i,   j-1])
           -zeta[i-1, j+1]*(psi[i,   j+1] - psi[i-1, j  ])
           +zeta[i+1, j-1]*(psi[i+1, j  ] - psi[i,   j-1]))/(4*d**2)

In [49]:
def arakawa(zeta, psi, d):
    val = np.empty_like(zeta)
    for i in range(zeta.shape[0]):
        for j in range(zeta.shape[1]):
            val += (jpp(zeta, psi, d, i, j) + jpx(zeta, psi, d, i, j) + jxp(zeta, psi, d, i, j))
    return val/3

### As Stencil

https://numba.pydata.org/numba-doc/dev/user/stencil.html

In [2]:
from numba import stencil

In [3]:
@stencil
def jpp(zeta, psi, d):
    return ((zeta[1, 0] - zeta[-1, 0])*(psi[0, 1] - psi[0, -1])
           -(zeta[0, 1] - zeta[0, -1])*(psi[1, 0] - psi[-1, 0]))/(4*d**2)

@stencil
def jpx(zeta, psi, d):
    return (zeta[1, 0]*(psi[1, 1] - psi[1, -1])
           -zeta[-1, 0]*(psi[-1, 1] - psi[-1, -1])
           -zeta[0, 1]*(psi[1, 1] - psi[-1, 1])
           +zeta[0, -1]*(psi[1, -1] - psi[-1, -1]))/(4*d**2)

@stencil
def jxp(zeta, psi, d):
    return (zeta[ 1,  1]*(psi[ 0, 1] - psi[ 1,  0])
           -zeta[-1, -1]*(psi[-1, 0] - psi[ 0, -1])
           -zeta[-1,  1]*(psi[ 0, 1] - psi[-1,  0])
           +zeta[ 1, -1]*(psi[ 1, 0] - psi[ 0, -1]))/(4*d**2)

In [4]:
from numba import jit

In [5]:
@jit
def arakawa(zeta, psi, d):
    return (jpp(zeta, psi, d) + jpx(zeta, psi, d) + jxp(zeta, psi, d))/3

### Periodic Boundary Conditions

In [6]:
def periodic_boundary(A):
    A_ = np.zeros((A.shape[0]+2, A.shape[1]+2))
    A_[1:-1, 1:-1] = A
    A_[0, 1:-1]  = A[-1, :]
    A_[-1, 1:-1] = A[0, :]
    A_[1:-1, -1] = A[:, 1]
    A_[1:-1, 0]  = A[:, -1]
    A_[0,0] = A[-1,-1]
    A_[-1,-1] = A[0,0]
    return A_

In [8]:
import numpy as np

In [14]:
N = 4
A = np.arange(N*N).reshape((N,N))
A_ = periodic_boundary(A)
print(A)
print(A_)

[[ 0  1  2  3]
 [ 4  5  6  7]
 [ 8  9 10 11]
 [12 13 14 15]]
[[15. 12. 13. 14. 15.  0.]
 [ 3.  0.  1.  2.  3.  1.]
 [ 7.  4.  5.  6.  7.  5.]
 [11.  8.  9. 10. 11.  9.]
 [15. 12. 13. 14. 15. 13.]
 [ 0.  0.  1.  2.  3.  0.]]


## How do stencils work

https://numba.pydata.org/numba-doc/dev/user/stencil.html

Here, we will use a simple example of a row wise smoothing over the left, center, and right element in each. The code should be self explanatory.

In [15]:
import numpy as np
from numba import stencil

In [16]:
N = 5

In [24]:
A = np.arange(N*N).reshape((N,N))
A_ = periodic_boundary(A)

In [25]:
val = np.empty_like(A, dtype=float)
for i in range(1,A.shape[0]+1):
    for j in range(1, A.shape[1]+1):
        val[i-1][j-1] = (A_[i, j-1] + A_[i, j] + A_[i, j+1])/3

In [26]:
@stencil
def simple_stencil(A):
    return (A[0, -1] + A[0, 0] + A[0, 1])/3

In [28]:
print(A_[1:-1, 1:-1])
print(val)
print(simple_stencil(A_)[1:-1,1:-1])

[[ 0.  1.  2.  3.  4.]
 [ 5.  6.  7.  8.  9.]
 [10. 11. 12. 13. 14.]
 [15. 16. 17. 18. 19.]
 [20. 21. 22. 23. 24.]]
[[ 1.66666667  1.          2.          3.          2.66666667]
 [ 6.66666667  6.          7.          8.          7.66666667]
 [11.66666667 11.         12.         13.         12.66666667]
 [16.66666667 16.         17.         18.         17.66666667]
 [21.66666667 21.         22.         23.         22.66666667]]
[[ 1.66666667  1.          2.          3.          2.66666667]
 [ 6.66666667  6.          7.          8.          7.66666667]
 [11.66666667 11.         12.         13.         12.66666667]
 [16.66666667 16.         17.         18.         17.66666667]
 [21.66666667 21.         22.         23.         22.66666667]]


## Applying Arakawa Scheme

### Simple shapes that should result in 0

In [29]:
A = np.random.random(N)
A = np.array([A for i in range(N)])
B = A
A_ = periodic_boundary(A)
B_ = periodic_boundary(B)

#### The three components

In [34]:
print(A_[1:-1, 1:-1])
print(jpp(A_, B_, 1)[1:-1,1:-1])
print(jpx(A_, B_, 1)[1:-1,1:-1])
print(jxp(A_, B_, 1)[1:-1,1:-1])

[[0.90962163 0.56786158 0.90072913 0.52656847 0.24296998]
 [0.90962163 0.56786158 0.90072913 0.52656847 0.24296998]
 [0.90962163 0.56786158 0.90072913 0.52656847 0.24296998]
 [0.90962163 0.56786158 0.90072913 0.52656847 0.24296998]
 [0.90962163 0.56786158 0.90072913 0.52656847 0.24296998]]
[[0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]]
[[ 0.          0.          0.          0.         -0.03074891]
 [ 0.          0.          0.          0.          0.        ]
 [ 0.          0.          0.          0.          0.        ]
 [ 0.          0.          0.          0.          0.        ]
 [ 0.02699606  0.          0.          0.         -0.01850583]]
[[ 4.62592927e-18 -9.25185854e-18  0.00000000e+00  0.00000000e+00
   3.07489094e-02]
 [ 4.62592927e-18 -9.25185854e-18  0.00000000e+00  0.00000000e+00
   4.62592927e-18]
 [ 4.62592927e-18 -9.25185854e-18  0.00000000e+00  0.00000000e+00
   4.62592927e-18]
 [ 4.62592927e-18 -9.25185854e-18  0.00000000e+

All come out close to zero

#### The Full Average

In [94]:
@jit
def arakawa(zeta, psi, d):
    return (jpp(zeta, psi, d) + jpx(zeta, psi, d) + jxp(zeta, psi, d))/3

In [95]:
print(A_[1:-1,1:-1])
print(arakawa(periodic_boundary(A), periodic_boundary(B), 1)[1:-1,1:-1])

[[0.68652698 0.15561446 0.1491985  0.35296006]
 [0.68652698 0.15561446 0.1491985  0.35296006]
 [0.68652698 0.15561446 0.1491985  0.35296006]
 [0.68652698 0.15561446 0.1491985  0.35296006]]
[[ 0.00000000e+00  0.00000000e+00  1.80700362e-19  0.00000000e+00]
 [ 0.00000000e+00  0.00000000e+00  1.80700362e-19 -1.92747053e-19]
 [ 0.00000000e+00  0.00000000e+00  1.80700362e-19 -1.92747053e-19]
 [-1.15648232e-18  0.00000000e+00  1.80700362e-19  1.15648232e-18]]


Again, all cpme out close to zero when applying the arakawa scheme.

## Timing

In [68]:
N = 512

In [69]:
%%timeit
A = np.random.random((N,N))
B = np.random.random((N,N))
A_ = periodic_boundary(A)
B_ = periodic_boundary(B)

4.04 ms ± 182 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


### Pure Python (3.46s)

In [71]:
def jpp(zeta, psi, d, i, j):
    return ((zeta[i+1, j  ] - zeta[i-1, j  ])*(psi[i,   j+1] - psi[i,   j-1])
           -(zeta[i,   j+1] - zeta[i,   j-1])*(psi[i+1, j  ] - psi[i-1, j  ]))/(4*d**2)

def jpx(zeta, psi, d, i, j):
    return (zeta[i+1, j  ]*(psi[i+1, j+1] - psi[i+1, j-1])
           -zeta[i-1, j  ]*(psi[i-1, j+1] - psi[i-1, j-1])
           -zeta[i,   j+1]*(psi[i+1, j+1] - psi[i-1, j+1])
           +zeta[i,   j-1]*(psi[i+1, j-1] - psi[i-1, j-1]))/(4*d**2)

def jxp(zeta, psi, d, i, j):
    return (zeta[i+1, j+1]*(psi[i,   j+1] - psi[i+1, j  ])
           -zeta[i-1, j-1]*(psi[i-1, j  ] - psi[i,   j-1])
           -zeta[i-1, j+1]*(psi[i,   j+1] - psi[i-1, j  ])
           +zeta[i+1, j-1]*(psi[i+1, j  ] - psi[i,   j-1]))/(4*d**2)

In [72]:
def arakawa(zeta, psi, d):
    val = np.empty_like(zeta)
    for i in range(1, zeta.shape[0]-1):
        for j in range(1, zeta.shape[1]-1):
            val[i][j] = (jpp(zeta, psi, d, i, j) + jpx(zeta, psi, d, i, j) + jxp(zeta, psi, d, i, j))
    return val/3

In [None]:
A = np.random.random(4)
A = np.array([A for i in range(4)])
B = A
A_ = periodic_boundary(A)
B_ = periodic_boundary(B)

In [61]:
print(arakawa(A_, B_, d=1).dtype)
print(arakawa(A_, B_, d=1))

float64
[[ 0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
   0.00000000e+00  0.00000000e+00  0.00000000e+00]
 [ 0.00000000e+00  1.54197642e-18 -3.08395285e-18  0.00000000e+00
   0.00000000e+00  0.00000000e+00  0.00000000e+00]
 [ 0.00000000e+00  1.54197642e-18 -3.08395285e-18  0.00000000e+00
   0.00000000e+00  1.54197642e-18  0.00000000e+00]
 [ 0.00000000e+00  1.54197642e-18 -3.08395285e-18  0.00000000e+00
   0.00000000e+00  1.54197642e-18  0.00000000e+00]
 [ 0.00000000e+00  1.54197642e-18 -3.08395285e-18  0.00000000e+00
   0.00000000e+00  1.54197642e-18  0.00000000e+00]
 [ 0.00000000e+00  0.00000000e+00 -3.08395285e-18  0.00000000e+00
   0.00000000e+00  0.00000000e+00  0.00000000e+00]
 [ 0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
   0.00000000e+00  0.00000000e+00  0.00000000e+00]]


In [73]:
%%timeit
A = np.random.random((N,N))
B = np.random.random((N,N))
A_ = periodic_boundary(A)
B_ = periodic_boundary(B)
C = arakawa(A_, B_, d=1)

3.46 s ± 217 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


### Mixed (1.47s)

In [74]:
@stencil
def jpp(zeta, psi, d):
    return ((zeta[1, 0] - zeta[-1, 0])*(psi[0, 1] - psi[0, -1])
           -(zeta[0, 1] - zeta[0, -1])*(psi[1, 0] - psi[-1, 0]))/(4*d**2)

@stencil
def jpx(zeta, psi, d):
    return (zeta[1, 0]*(psi[1, 1] - psi[1, -1])
           -zeta[-1, 0]*(psi[-1, 1] - psi[-1, -1])
           -zeta[0, 1]*(psi[1, 1] - psi[-1, 1])
           +zeta[0, -1]*(psi[1, -1] - psi[-1, -1]))/(4*d**2)

@stencil
def jxp(zeta, psi, d):
    return (zeta[ 1,  1]*(psi[ 0, 1] - psi[ 1,  0])
           -zeta[-1, -1]*(psi[-1, 0] - psi[ 0, -1])
           -zeta[-1,  1]*(psi[ 0, 1] - psi[-1,  0])
           +zeta[ 1, -1]*(psi[ 1, 0] - psi[ 0, -1]))/(4*d**2)

In [75]:
def arakawa(zeta, psi, d):
    return (jpp(zeta, psi, d) + jpx(zeta, psi, d) + jxp(zeta, psi, d))/3

In [65]:
print(arakawa(A_, B_, d=1).dtype)
print(arakawa(A_, B_, d=1))

float64
[[ 0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
   0.00000000e+00  0.00000000e+00  0.00000000e+00]
 [ 0.00000000e+00  1.54197642e-18 -3.08395285e-18  0.00000000e+00
   0.00000000e+00  0.00000000e+00  0.00000000e+00]
 [ 0.00000000e+00  1.54197642e-18 -3.08395285e-18  0.00000000e+00
   0.00000000e+00  1.54197642e-18  0.00000000e+00]
 [ 0.00000000e+00  1.54197642e-18 -3.08395285e-18  0.00000000e+00
   0.00000000e+00  1.54197642e-18  0.00000000e+00]
 [ 0.00000000e+00  1.54197642e-18 -3.08395285e-18  0.00000000e+00
   0.00000000e+00  1.54197642e-18  0.00000000e+00]
 [ 0.00000000e+00  0.00000000e+00 -3.08395285e-18  0.00000000e+00
   0.00000000e+00  0.00000000e+00  0.00000000e+00]
 [ 0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
   0.00000000e+00  0.00000000e+00  0.00000000e+00]]


In [76]:
%%timeit
A = np.random.random((N,N))
B = np.random.random((N,N))
A_ = periodic_boundary(A)
B_ = periodic_boundary(B)
C = arakawa(A_, B_, 1)

1.47 s ± 215 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


### Fully Numba (0.012s, speedup$\,\geq\, 288x$)

In [77]:
@jit
def arakawa(zeta, psi, d):
    return (jpp(zeta, psi, d) + jpx(zeta, psi, d) + jxp(zeta, psi, d))/3

In [79]:
%%timeit
A = np.random.random((N,N))
B = np.random.random((N,N))
A_ = periodic_boundary(A)
B_ = periodic_boundary(B)
C = arakawa(A_, B_, 1)

12.1 ms ± 1.39 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [83]:
@jit(nopython=True, parallel=True, nogil=True, fastmath=True)
def arakawa(zeta, psi, d):
    return (jpp(zeta, psi, d) + jpx(zeta, psi, d) + jxp(zeta, psi, d))/3

In [84]:
%%timeit
A = np.random.random((N,N))
B = np.random.random((N,N))
A_ = periodic_boundary(A)
B_ = periodic_boundary(B)
C = arakawa(A_, B_, 1)

11.8 ms ± 543 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


### Reducing to one summation

#### Equations

$\frac{\mathbb{J}^{++} + \mathbb{J}^{+\times} + \mathbb{J}^{\times +}}{3}$

$=\frac{1}{3} 
\frac{1}{4d^2} \left[ 
    \left( \zeta_{i+1, j}   - \zeta_{i-1, j}   \right) \left( \psi_{i,   j+1} - \psi_{i,   j-1} \right)
   -\left( \zeta_{i,   j+1} - \zeta_{i,   j-1} \right) \left( \psi_{i+1, j}   - \psi_{i-1, j}   \right) \\
+
    \zeta_{i+1, j}   \left( \psi_{i+1, j+1} - \psi_{i+1, j-1} \right)
   -\zeta_{i-1, j}   \left( \psi_{i-1, j+1} - \psi_{i-1, j-1} \right)
   -\zeta_{i,   j+1} \left( \psi_{i+1, j+1} - \psi_{i-1, j+1} \right)
   +\zeta_{i,   j-1} \left( \psi_{i+1, j-1} - \psi_{i-1, j-1} \right) \\
+
    \zeta_{i+1, j+1} \left( \psi_{i,   j+1} - \psi_{i+1, j}   \right)
   -\zeta_{i-1, j-1} \left( \psi_{i-1, j}   - \psi_{i,   j-1} \right)
   -\zeta_{i-1, j+1} \left( \psi_{i,   j+1} - \psi_{i-1, j}   \right)
   +\zeta_{i+1, j-1} \left( \psi_{i+1, j}   - \psi_{i,   j-1} \right)
\right]
$

$=\frac{1}{3} 
\frac{1}{4d^2} \left[ 
    \zeta_{i+1, j}   \left( \psi_{i,   j+1} - \psi_{i,   j-1} \right)
   +\zeta_{i+1, j}   \left( \psi_{i+1, j+1} - \psi_{i+1, j-1} \right)
   -\zeta_{i-1, j}   \left( \psi_{i,   j+1} - \psi_{i,   j-1} \right)
   -\zeta_{i-1, j}   \left( \psi_{i-1, j+1} - \psi_{i-1, j-1} \right)
   -\zeta_{i,   j+1} \left( \psi_{i+1, j}   - \psi_{i-1, j}   \right) 
   -\zeta_{i,   j+1} \left( \psi_{i+1, j+1} - \psi_{i-1, j+1} \right)
   +\zeta_{i,   j-1} \left( \psi_{i+1, j}   - \psi_{i-1, j}   \right)
   +\zeta_{i,   j-1} \left( \psi_{i+1, j-1} - \psi_{i-1, j-1} \right)
   +\zeta_{i+1, j+1} \left( \psi_{i,   j+1} - \psi_{i+1, j}   \right)
   -\zeta_{i-1, j-1} \left( \psi_{i-1, j}   - \psi_{i,   j-1} \right)
   -\zeta_{i-1, j+1} \left( \psi_{i,   j+1} - \psi_{i-1, j}   \right)
   +\zeta_{i+1, j-1} \left( \psi_{i+1, j}   - \psi_{i,   j-1} \right)
\right]
$

$=\frac{1}{12 d^2} \left[ 
\;            \zeta_{i+1,\, j}   \left( \psi_{i,\,   j+1} - \psi_{i,\,   j-1} + \psi_{i+1,\, j+1} - \psi_{i+1,\, j-1} \right) \\
\quad\quad   -\zeta_{i-1,\, j}   \left( \psi_{i,\,   j+1} - \psi_{i,\,   j-1} + \psi_{i-1,\, j+1} - \psi_{i-1,\, j-1} \right) \\
\quad\quad   -\zeta_{i,\,   j+1} \left( \psi_{i+1,\, j}   - \psi_{i-1,\, j}   + \psi_{i+1,\, j+1} - \psi_{i-1,\, j+1} \right) \\
\quad\quad   +\zeta_{i,\,   j-1} \left( \psi_{i+1,\, j}   - \psi_{i-1,\, j}   + \psi_{i+1,\, j-1} - \psi_{i-1,\, j-1} \right) \\
\quad\quad   +\zeta_{i+1,\, j-1} \left( \psi_{i+1,\, j}   - \psi_{i,\,   j-1} \right) \\
\quad\quad   +\zeta_{i+1,\, j+1} \left( \psi_{i,\,   j+1} - \psi_{i+1,\, j}   \right) \\
\quad\quad   -\zeta_{i-1,\, j+1} \left( \psi_{i,\,   j+1} - \psi_{i-1,\, j}   \right) \\
\quad\quad   -\zeta_{i-1,\, j-1} \left( \psi_{i-1,\, j}   - \psi_{i,\,   j-1} \right)
\right]
$

#### Code

In [96]:
def arakawa_stencil(zeta, psi):
    return (zeta[ 1,  0] * (psi[ 0, 1] - psi[ 0, -1] + psi[ 1,  1] - psi[ 1, -1])
           -zeta[-1,  0] * (psi[ 0, 1] - psi[ 0, -1] + psi[-1,  1] - psi[-1, -1])
           -zeta[ 0,  1] * (psi[ 1, 0] - psi[-1,  0] + psi[ 1,  1] - psi[-1,  1])
           +zeta[ 0, -1] * (psi[ 1, 0] - psi[-1,  0] + psi[ 1, -1] - psi[-1, -1])
           +zeta[ 1, -1] * (psi[ 1, 0] - psi[ 0, -1])
           +zeta[ 1,  1] * (psi[ 0, 1] - psi[ 1,  0])
           -zeta[-1,  1] * (psi[ 0, 1] - psi[-1,  0])
           -zeta[-1, -1] * (psi[-1, 0] - psi[0, -1]))


def arakawa(zeta, psi, d):
    return arakawa_stencil(zeta, psi) / (12*(d**2))

In [80]:
A = np.random.random(4)
A = np.array([A for i in range(4)])
B = A
A_ = periodic_boundary(A)
B_ = periodic_boundary(B)

In [82]:
arakawa(A_, B_, 1)

array([[ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         0.00000000e+00,  0.00000000e+00,  0.00000000e+00],
       [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
        -2.71050543e-19, -2.89120579e-19,  0.00000000e+00],
       [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
        -2.71050543e-19,  0.00000000e+00,  0.00000000e+00],
       [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
        -2.71050543e-19,  0.00000000e+00,  0.00000000e+00],
       [ 0.00000000e+00, -3.46944695e-18,  0.00000000e+00,
        -2.71050543e-19,  1.44560290e-18,  0.00000000e+00],
       [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         0.00000000e+00,  0.00000000e+00,  0.00000000e+00]])

### Timing

#### Mixed

In [None]:
@stencil
def arakawa_stencil(zeta, psi):
    return (zeta[ 1,  0] * (psi[ 0, 1] - psi[ 0, -1] + psi[ 1,  1] - psi[ 1, -1])
           -zeta[-1,  0] * (psi[ 0, 1] - psi[ 0, -1] + psi[-1,  1] - psi[-1, -1])
           -zeta[ 0,  1] * (psi[ 1, 0] - psi[-1,  0] + psi[ 1,  1] - psi[-1,  1])
           +zeta[ 0, -1] * (psi[ 1, 0] - psi[-1,  0] + psi[ 1, -1] - psi[-1, -1])
           +zeta[ 1, -1] * (psi[ 1, 0] - psi[ 0, -1])
           +zeta[ 1,  1] * (psi[ 0, 1] - psi[ 1,  0])
           -zeta[-1,  1] * (psi[ 0, 1] - psi[-1,  0])
           -zeta[-1, -1] * (psi[-1, 0] - psi[0, -1]))


def arakawa(zeta, psi, d):
    return arakawa_stencil(zeta, psi) / (12*(d**2))

In [86]:
%%timeit
A = np.random.random((N,N))
B = np.random.random((N,N))
A_ = periodic_boundary(A)
B_ = periodic_boundary(B)
C = arakawa(A_, B_, 1)

1.98 s ± 275 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


#### Numba (0.00913s, speedup$\,\geq\, 370$)

In [87]:
@jit
def arakawa(zeta, psi, d):
    return arakawa_stencil(zeta, psi) / (12*(d**2))

In [91]:
arakawa(A_, B_, 1.)

array([[ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         0.00000000e+00,  0.00000000e+00,  0.00000000e+00],
       [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
        -2.71050543e-19, -2.89120579e-19,  0.00000000e+00],
       [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
        -2.71050543e-19,  0.00000000e+00,  0.00000000e+00],
       [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
        -2.71050543e-19,  0.00000000e+00,  0.00000000e+00],
       [ 0.00000000e+00, -3.46944695e-18,  0.00000000e+00,
        -2.71050543e-19,  1.44560290e-18,  0.00000000e+00],
       [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
         0.00000000e+00,  0.00000000e+00,  0.00000000e+00]])

In [92]:
%%timeit
A = np.random.random((N,N))
B = np.random.random((N,N))
A_ = periodic_boundary(A)
B_ = periodic_boundary(B)
C = arakawa(A_, B_, 1.)

9.13 ms ± 512 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


## Conclusion

Numba is an exceptional tool for speeding up numerical Python and should be a core tool for any scientist. It makes writing high performing code almost as easy as writing normal Python, especially for stencils. While some operations like boundary conditions still have to be explicitly implemented, it is worth the effort to learn.

The bottom line is, it enables the conversion from equations found in a paper, to a performative implementation without the need to use lower level languages that require compilation for each machine. While there definitely is still some speedup to be had for using more tools, a factor of $370$ over pure python with a few lines of extra code is worth it in many circumstance to remove it as a bottleneck in the calculation.