In the following notebook I explore code optimization with [Numba](http://numba.pydata.org/numba-doc/latest/index.html). Have a look at the documentation page for further details.

In [1]:
import os

import numba as nb
import numpy as np

Set the `NUMBA_DISABLE_JIT` environment variable to `'1'` to disable JIT compilation entirely, or to `'0'` to leave it enabled. When JIT is disabled, the `jit()` decorator acts as a no-op, and calling a decorated function runs the original Python function instead of a compiled version. This is useful if you want to run the Python debugger over your code. Numba reads the variable when it is imported, so it should be set before the `import numba` line runs.

In [2]:
# os.environ['NUMBA_DISABLE_JIT'] = '1'

To measure the performance improvement achieved by the JIT compiler, we sum all the elements of a 2D array in a nested `for` loop.

In [3]:
size = int(1e3)
a = np.random.randint(-10, 10, (size, size))

This is the pure Python implementation of the function we are going to optimize, followed by its timing:

In [4]:
def sum2d_python(arr):
    M, N = arr.shape
    result = 0.0
    for i in range(M):
        for j in range(N):
            result += arr[i, j]
    return result

In [5]:
%%timeit
sum2d_python(a)

10 loops, best of 3: 177 ms per loop
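The `%%timeit` magic only exists inside IPython/Jupyter. Outside a notebook, a comparable measurement can be sketched with the standard-library `timeit` module. The helper `best_time_ms` and the smaller array `small` are assumptions made for this illustration, not part of the original notebook:

```python
import timeit

import numpy as np


def sum2d_python(arr):
    M, N = arr.shape
    result = 0.0
    for i in range(M):
        for j in range(N):
            result += arr[i, j]
    return result


def best_time_ms(func, *args, repeat=3, number=1):
    """Best wall-clock time per call in milliseconds (hypothetical helper)."""
    timings = timeit.repeat(lambda: func(*args), repeat=repeat, number=number)
    return min(timings) / number * 1e3


# Smaller than the notebook's array, only to keep the demo quick.
small = np.random.randint(-10, 10, (200, 200))
print(f"{best_time_ms(sum2d_python, small):.2f} ms per call")
```

Taking the minimum over several repeats mirrors the "best of 3" figure that `%%timeit` reports.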


Numba's `jit()` (*just-in-time*) decorator compiles pure Python code to optimized machine code using the LLVM compiler infrastructure. The `@jit` decorator tells Numba to compile the function that follows it. Numba infers the argument types when the function is first called.

In [6]:
@nb.jit
def sum2d_serial(arr):
    M, N = arr.shape
    result = 0.0
    for i in range(M):
        for j in range(N):
            result += arr[i, j]
    return result

In [7]:
%%timeit
sum2d_serial(a)

The slowest run took 96.84 times longer than the fastest. This could mean that an intermediate result is being cached.
1000 loops, best of 3: 1.29 ms per loop


The compiled function is faster by a factor of more than **X130** (177 ms / 1.29 ms). The function is compiled on its first call (*just-in-time*), which adds a one-off overhead (`The slowest run took 96.84 times longer than the fastest.`). Subsequent calls are much faster because the same compiled function is reused.

We can optimize this function further by parallelizing it. Setting the `parallel` option of `jit()` enables a Numba feature that attempts to automatically parallelize (parts of) a function and to apply other optimizations. In the loops, `nb.prange` replaces `range` to mark them as candidates for parallel execution.

In [8]:
@nb.jit(nopython=True,
        nogil=True,
        parallel=True)
def sum2d_parallel(arr):
    M, N = arr.shape
    result = 0.0
    for i in nb.prange(M):
        for j in nb.prange(N):
            result += arr[i, j]
    return result

In [9]:
%%timeit
sum2d_parallel(a)

The slowest run took 679.95 times longer than the fastest. This could mean that an intermediate result is being cached.
1 loop, best of 3: 420 µs per loop


Parallelizing the code improves performance by an additional factor of about **X3** (1.29 ms / 0.420 ms). Again, there is a one-off overhead (`The slowest run took 679.95 times longer than the fastest.`) because the function is compiled on its first call.

A further improvement is possible if the data types in the function's signature are declared explicitly. In this example, `a` is a 2D array of dtype `int64`, and we expect the function to return a single `int64` value.

In [10]:
sig = nb.int64(nb.int64[:, :])
@nb.jit(sig,
        nopython=True,
        nogil=True,
        parallel=True)
def sum2d_parallel2(arr):
    M, N = arr.shape
    result = 0.0
    for i in nb.prange(M):
        for j in nb.prange(N):
            result += arr[i, j]
    return result


sum2d_parallel2.recompile()

In [11]:
%%timeit
sum2d_parallel2(a)

1000 loops, best of 3: 236 µs per loop


Because the data types are declared, the function can be compiled ahead of its first call, so the compilation overhead is paid right after the function is defined. Calling the function now carries no additional overhead, and performance improves by a further factor of about **X1.8** (420 µs / 236 µs).

Comparing the results:

In [12]:
print(sum2d_python(a),
      sum2d_serial(a),
      sum2d_parallel(a),
      sum2d_parallel2(a))

(-501399.0, -501399.0, -501399.0, -501399.0)
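Instead of comparing the printed values by eye, the agreement can be asserted against NumPy's own reduction. A minimal sketch with a hypothetical helper `check_against_numpy`. Only the pure Python version is registered here, to keep the block self-contained, and the jitted variants from the cells above could be added to the dictionary in the same way:

```python
import numpy as np


def sum2d_python(arr):
    M, N = arr.shape
    result = 0.0
    for i in range(M):
        for j in range(N):
            result += arr[i, j]
    return result


def check_against_numpy(funcs, arr):
    """Assert that every function in `funcs` reproduces arr.sum() (hypothetical helper)."""
    reference = arr.sum()
    for name, func in funcs.items():
        assert func(arr) == reference, f"{name} disagrees with arr.sum()"
    return reference


a = np.random.randint(-10, 10, (300, 300))  # smaller array to keep the check quick
reference = check_against_numpy({"sum2d_python": sum2d_python}, a)
print("all implementations agree:", reference)
```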


Overall performance boost: **X750** (177 ms / 0.236 ms)
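The speed-up factors can be recomputed directly from the timings reported by `%%timeit` in the cells above. The dictionary below simply restates those measurements in milliseconds:

```python
# Timings copied from the %%timeit outputs above, in milliseconds.
timings_ms = {
    "sum2d_python": 177.0,
    "sum2d_serial": 1.29,
    "sum2d_parallel": 0.420,
    "sum2d_parallel2": 0.236,
}

baseline = timings_ms["sum2d_python"]
for name, t in timings_ms.items():
    # Speed-up of each version relative to the pure Python implementation.
    print(f"{name:16s} {t:8.3f} ms   speed-up vs. Python: {baseline / t:7.1f}x")
```

Dividing the first timing by the last gives the overall factor without compounding the rounding of the intermediate steps.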