# Just-in-time Compilation with [Numba](http://numba.pydata.org/) 

<div align="center"><img src="https://raw.githubusercontent.com/numba/numba/main/docs/_static/numba-blue-icon-rgb.svg" width="300"/></div>

## Numba is a JIT compiler that translates Python code into native machine code

* Python functions marked with Numba's decorators are compiled on the fly to machine code using LLVM
* Numba is compatible with Numpy arrays which are the basis of many scientific packages in Python
* It can parallelize the compiled code so that all CPU cores are used

In [None]:
import math
import numpy as np
import matplotlib.pyplot as plt
import numba

In [None]:
# Shell command: print Numba's system information report (versions, hardware, threading layers)
!numba -s

### Create a new [Random Generator](https://numpy.org/doc/stable/reference/random/generator.html)

In [None]:
# A fresh NumPy Generator backed by the default PCG64 bit generator (unseeded),
# equivalent to np.random.default_rng()
rng = np.random.Generator(np.random.PCG64())

## Using `numba.jit`

Numba offers `jit`, a decorator that compiles Python functions to machine code.

In [None]:
def is_prime(n):
    """Return True if the integer ``n`` is prime, using trial division.

    Even numbers are rejected up front. After that, only odd candidate
    divisors from 3 up to ``int(sqrt(n))`` are tried.

    Raises:
        ArithmeticError: if ``n <= 1`` (message: ``'<n>' <= 1``).
    """
    if n <= 1:
        raise ArithmeticError(f"'{n}' <= 1")
    if n == 2:
        return True
    if n % 2 == 0:
        return False
    # A composite n always has a divisor no larger than sqrt(n).
    n_sqrt = int(math.sqrt(n))
    for i in range(3, n_sqrt + 1, 2):
        if n % i == 0:
            return False
    return True

In [None]:
n = rng.integers(2, 10000000) # Random integer in [2, 10000000) -- the upper bound is exclusive
print(n, is_prime(n))

In [None]:
# Time a single pure-Python primality test
%time is_prime(n)

In [None]:
@numba.jit
def is_prime_jitted(n):
    """Numba-compiled (``numba.jit``) trial-division primality test.

    Same algorithm as the pure-Python ``is_prime``: reject even numbers,
    then try odd divisors up to ``int(sqrt(n))``.

    Raises:
        ArithmeticError: if ``n <= 1`` (message: ``'<n>' <= 1``).
    """
    if n <= 1:
        raise ArithmeticError(f"'{n}' <= 1")
    if n == 2:
        return True
    if n % 2 == 0:
        return False
    # A composite n always has a divisor no larger than sqrt(n).
    n_sqrt = int(math.sqrt(n))
    for i in range(3, n_sqrt + 1, 2):
        if n % i == 0:
            return False
    return True

In [None]:
# One million random integers in [2, 100000)
numbers = rng.integers(2, 100000, size=1000000)
%time p1 = [is_prime(n) for n in numbers]
# This is the first call of is_prime_jitted, so its timing also includes JIT compilation
%time p2 = [is_prime_jitted(n) for n in numbers]

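## Using `numba.jit` with `nopython=True`

In nopython mode, Numba compiles the function so that it runs without involving the Python interpreter. Since Numba 0.59, `nopython=True` is the default for `numba.jit`. On recent versions, `is_prime_jitted` and `is_prime_njitted` are therefore compiled the same way.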

In [None]:
@numba.jit(nopython=True)
def is_prime_njitted(n):
    """Trial-division primality test compiled in Numba's nopython mode.

    Same algorithm as the pure-Python ``is_prime``: reject even numbers,
    then try odd divisors up to ``int(sqrt(n))``.

    Raises:
        ArithmeticError: if ``n <= 1`` (message: ``'<n>' <= 1``).
    """
    if n <= 1:
        raise ArithmeticError(f"'{n}' <= 1")
    if n == 2:
        return True
    if n % 2 == 0:
        return False
    # A composite n always has a divisor no larger than sqrt(n).
    n_sqrt = int(math.sqrt(n))
    for i in range(3, n_sqrt + 1, 2):
        if n % i == 0:
            return False
    return True

In [None]:
numbers = rng.integers(2, 100000, size=1000000)
# is_prime_jitted was already compiled for these integer inputs in an earlier cell
%time p1 = [is_prime_jitted(n) for n in numbers]
# First call of is_prime_njitted: this timing includes its JIT compilation
%time p2 = [is_prime_njitted(n) for n in numbers]

## Using `@numba.jit(nopython=True)` is equivalent to using `@numba.njit`

In [None]:
@numba.njit
def is_prime_njitted(n):
    """Trial-division primality test compiled with ``numba.njit``.

    ``numba.njit`` is shorthand for ``numba.jit(nopython=True)``. The
    algorithm matches the pure-Python ``is_prime``.

    Raises:
        ArithmeticError: if ``n <= 1`` (message: ``'<n>' <= 1``).
    """
    if n <= 1:
        raise ArithmeticError(f"'{n}' <= 1")
    if n == 2:
        return True
    if n % 2 == 0:
        return False
    # A composite n always has a divisor no larger than sqrt(n).
    n_sqrt = int(math.sqrt(n))
    for i in range(3, n_sqrt + 1, 2):
        if n % i == 0:
            return False
    return True

In [None]:
# A single input: is_prime_njitted was just redefined above, so this call
# triggers a fresh JIT compilation, while is_prime_jitted is already compiled
numbers = rng.integers(2, 100000, size=1)
%time p = [is_prime_jitted(n) for n in numbers]
%time p = [is_prime_njitted(n) for n in numbers]

## Use `cache=True` to cache the compiled function

In [None]:
import math
from numba import njit

@njit(cache=True)
def is_prime_njitted_cached(n):
    """Nopython trial-division primality test whose compiled code is cached on disk.

    With ``cache=True``, the compiled machine code is written to a file cache.
    A later session can then load it instead of recompiling.

    Raises:
        ArithmeticError: if ``n <= 1`` (message: ``'<n>' <= 1``).
    """
    if n <= 1:
        raise ArithmeticError(f"'{n}' <= 1")
    if n == 2:
        return True
    if n % 2 == 0:
        return False
    # A composite n always has a divisor no larger than sqrt(n).
    n_sqrt = int(math.sqrt(n))
    for i in range(3, n_sqrt + 1, 2):
        if n % i == 0:
            return False
    return True

In [None]:
numbers = rng.integers(2, 100000, size=10000)
%time p = [is_prime_njitted(n) for n in numbers]
# First call here: compiled now, or loaded from the on-disk cache if a previous
# session already compiled it with cache=True
%time p = [is_prime_njitted_cached(n) for n in numbers]

## Vector Triad Benchmark: Python vs NumPy vs Numba

In [None]:
from time import perf_counter


def vecTriad(a, b, c, d):
    """Pure-Python vector triad: ``a[j] = b[j] + c[j] * d[j]`` for every j."""
    for j in range(a.shape[0]):
        a[j] = b[j] + c[j] * d[j]


def vecTriadNumpy(a, b, c, d):
    """Vectorized NumPy triad; writes the result into ``a`` in place."""
    a[:] = b + c * d


@numba.njit
def vecTriadNumba(a, b, c, d):
    """Numba-compiled version of the element-wise triad loop."""
    for j in range(a.shape[0]):
        a[j] = b[j] + c[j] * d[j]


def _report_triad_mflops(label, kernel, a, b, c, d, repeats):
    """Time ``repeats`` calls of ``kernel(a, b, c, d)``, print and return MFLOP/s.

    Each element costs 2 floating-point operations (one multiply, one add).
    """
    start = perf_counter()
    for _ in range(repeats):
        kernel(a, b, c, d)
    elapsed = perf_counter() - start
    mflops = 2.0 * repeats * a.shape[0] / (elapsed * 1.0e6)
    print(f'{label}: Mflops/sec: {mflops}')
    return mflops


# Initialize Vectors
n = 100000  # Vector size
r = 100  # Iterations
a = np.zeros(n, dtype=np.float64)
b = np.full_like(a, 1.0)
c = np.full_like(a, 1.0)
d = np.full_like(a, 1.0)

mflops = _report_triad_mflops('Python', vecTriad, a, b, c, d, r)
mflops = _report_triad_mflops('Numpy', vecTriadNumpy, a, b, c, d, r)

# Run once first so the timed loop does not include the JIT compilation overhead
vecTriadNumba(a, b, c, d)
mflops = _report_triad_mflops('Numba', vecTriadNumba, a, b, c, d, r)

In [None]:
# inspect_asm() returns a dict mapping each compiled signature to its assembly
# listing. Print the listings so their newlines render instead of the escaped dict repr.
for sig, asm in vecTriadNumba.inspect_asm().items():
    print(sig)
    print(asm)

## Eager compilation using function signatures

### Compilation overhead can be avoided by "informing" Numba of the supported argument/return types 

In [None]:
import math
from numba import njit

@njit(['boolean(int64)', 'boolean(int32)'])
def is_prime_njitted_eager(n):
    """Trial-division primality test compiled eagerly for int64 and int32 inputs.

    Because explicit signatures are given, compilation happens when the
    decorator runs rather than on the first call.

    Raises:
        ArithmeticError: if ``n <= 1`` (message: ``'<n>' <= 1``).
    """
    if n <= 1:
        raise ArithmeticError(f"'{n}' <= 1")
    if n == 2:
        return True
    if n % 2 == 0:
        return False
    # A composite n always has a divisor no larger than sqrt(n).
    n_sqrt = int(math.sqrt(n))
    for i in range(3, n_sqrt + 1, 2):
        if n % i == 0:
            return False
    return True

In [None]:
numbers = rng.integers(2, 1000000, size=1000)

# Run twice: with explicit signatures the function was already compiled when it
# was decorated, so the first run should not be noticeably slower than the second
%time p1 = [is_prime_njitted_eager(n) for n in numbers]
%time p2 = [is_prime_njitted_eager(n) for n in numbers]

### With eager compilation, Numba loses the "freedom" to compile new versions on the fly for argument types not listed in the signatures

In [None]:
# The eager version only has the int64/int32 signatures it was declared with and
# cannot compile new ones. If Numba cannot match an argument to one of them, it
# raises TypeError. Catch it so the comparison below still runs.
try:
    p1 = [is_prime_njitted_eager(n) for n in numbers.astype(np.int16)]
except TypeError as exc:
    print(f'is_prime_njitted_eager rejected int16 input: {exc}')

# The lazily compiled version instead tries to compile a new specialization for float32
p2 = [is_prime_njitted(n) for n in numbers.astype(np.float32)]

In [None]:
# List the compiled nopython signatures; for the eager version these are the ones passed to @njit
is_prime_njitted_eager.nopython_signatures

#### The following demonstrates how Numba compiles different versions of the same function depending on the types of the arguments

In [None]:
from numba import njit

@njit
def myfunc(n):
    """Return ``n // 2``; Numba compiles one specialization per argument type."""
    half = n // 2
    return half

In [None]:
# Call once with a float and once with an int so two specializations get compiled
myfunc(2.0);
myfunc(1);

In [None]:
# One compiled signature per argument type used above (a float64 and an int64 version)
myfunc.nopython_signatures

## Calculating and plotting the [Mandelbrot set](https://en.wikipedia.org/wiki/Mandelbrot_set)

In [None]:
X, Y = np.meshgrid(np.linspace(-2.0, 1, 1000), np.linspace(-1.0, 1.0, 1000))

def mandelbrot(X, Y, radius2, itermax):
    """Escape-iteration counts of z -> z**2 + c on a grid of complex points.

    ``X`` and ``Y`` hold the real and imaginary parts of c (2-D, same shape).
    Each orbit starts at z = 0 and runs until |z|**2 >= ``radius2`` or
    ``itermax`` iterations are reached. Returns an int32 array of counts.
    """
    n_rows = X.shape[0]
    n_cols = X.shape[1]
    counts = np.empty(shape=X.shape, dtype=np.int32)
    for row in range(n_rows):
        for col in range(n_cols):
            c_re = X[row, col]
            c_im = Y[row, col]
            z_re = 0.0
            z_im = 0.0
            n_iter = 0
            while z_re * z_re + z_im * z_im < radius2 and n_iter < itermax:
                z_re, z_im = z_re * z_re - z_im * z_im + c_re, 2.0 * z_re * z_im + c_im
                n_iter += 1
            counts[row, col] = n_iter
    return counts

In [None]:
fig = plt.figure(figsize=(15, 10))
ax = fig.add_subplot(111)

%time m = mandelbrot(X, Y, 4.0, 100)
    
ax.imshow(np.log(1 + m), extent=[-2.0, 1, -1.0, 1.0]);
ax.set_aspect('equal')
ax.set_ylabel('Im[c]')
ax.set_xlabel('Re[c]');

In [None]:
X, Y = np.meshgrid(np.linspace(-2.0, 1, 10000), np.linspace(-1.0, 1.0, 10000))

@njit
def mandelbrot_jitted(X, Y, radius2, itermax):
    """Numba-compiled escape-iteration counts of z -> z**2 + c on a 2-D grid.

    Same algorithm as the pure-Python ``mandelbrot``: each orbit starts at
    z = 0 and runs until |z|**2 >= ``radius2`` or ``itermax`` iterations.
    Returns an int32 array with the shape of ``X``.
    """
    n_rows = X.shape[0]
    n_cols = X.shape[1]
    counts = np.empty(shape=X.shape, dtype=np.int32)
    for row in range(n_rows):
        for col in range(n_cols):
            c_re = X[row, col]
            c_im = Y[row, col]
            z_re = 0.0
            z_im = 0.0
            n_iter = 0
            while z_re * z_re + z_im * z_im < radius2 and n_iter < itermax:
                z_re, z_im = z_re * z_re - z_im * z_im + c_re, 2.0 * z_re * z_im + c_im
                n_iter += 1
            counts[row, col] = n_iter
    return counts

In [None]:
m = mandelbrot_jitted(X, Y, 4.0, 100) # Warmup 
%time m = mandelbrot_jitted(X, Y, 4.0, 100)
    
ax.imshow(np.log(1 + m), extent=[-2.0, 1, -1.0, 1.0]);
ax.set_aspect('equal')
ax.set_ylabel('Im[c]')
ax.set_xlabel('Re[c]');

### Loops can be parallelized by a combination of `parallel=True` in the `numba.njit` decorator and `numba.prange`

In [None]:
@numba.njit(parallel=True)
def mandelbrot_parallel_jitted(X, Y, radius2, itermax):
    """Parallel escape-iteration counts of z -> z**2 + c on a 2-D grid.

    Produces the same int32 counts as ``mandelbrot``/``mandelbrot_jitted``,
    with the rows of the grid distributed across threads via ``numba.prange``.
    """
    mandel = np.empty(shape=X.shape, dtype=np.int32)
    # Only the outermost prange is parallelized; Numba runs nested prange loops
    # serially, so the inner loop is written as a plain range.
    for i in numba.prange(X.shape[0]):
        for j in range(X.shape[1]):
            it = 0
            cx = X[i, j]
            cy = Y[i, j]
            # Start the orbit at z = 0 (as the serial versions do) so the
            # escape counts match them exactly.
            x = 0.0
            y = 0.0
            while x * x + y * y < radius2 and it < itermax:
                x, y = x * x - y * y + cx, 2.0 * x * y + cy
                it += 1
            mandel[i, j] = it

    return mandel

In [None]:
fig = plt.figure(figsize=(15, 10))
ax = fig.add_subplot(111)

m = mandelbrot_parallel_jitted(X, Y, 4.0, 100) # Warmup
%time m = mandelbrot_parallel_jitted(X, Y, 4.0, 100)
    
ax.imshow(np.log(1 + m), extent=[-2.0, 1, -1.0, 1.0]);
ax.set_aspect('equal')
ax.set_ylabel('Im[c]')
ax.set_xlabel('Re[c]');

#### Getting parallelization information

In [None]:
# Print Numba's report on how the parallel loops were transformed;
# level ranges from 1 (least verbose) to 4 (most verbose)
mandelbrot_parallel_jitted.parallel_diagnostics(level=3)

### Controlling the number of parallel threads

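#### By default, Numba uses `numba.config.NUMBA_NUM_THREADS` threads for parallel computation. This is also the maximum: a smaller number can be selected at runtime with `numba.set_num_threads(n)`, and the current value is returned by `numba.get_num_threads()`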

In [None]:
# NUMBA_NUM_THREADS is the default (and maximum) thread count; get_num_threads()
# reports the count currently in effect, which numba.set_num_threads(k) can lower.
print(f'The default number of threads is: {numba.config.NUMBA_NUM_THREADS}')
print(f'The current number of threads is: {numba.get_num_threads()}')

#### <mark>Exercise</mark> Test and time the parallelized Mandelbrot set calculation using different numbers of threads

## Creating `ufuncs` using `numba.vectorize`

In [None]:
from math import sin
from numba import float64, int64

def my_numpy_sin(a, b):
    """Reference implementation using NumPy's vectorized ``np.sin`` on whole arrays."""
    return np.sin(a) + np.sin(b)

@np.vectorize
def my_sin(a, b):
    """Scalar ``math.sin`` wrapped with ``np.vectorize`` (a Python-level loop)."""
    return sin(a) + sin(b)

# Both loops return float64. Declaring an int64 output for integer inputs would
# truncate sin(a) + sin(b), a value in [-2, 2], to an integer.
@numba.vectorize([float64(float64, float64), float64(int64, int64)], target='parallel')
def my_sin_numba(a, b):
    """Parallel Numba ufunc computing ``sin(a) + sin(b)`` element-wise."""
    return np.sin(a) + np.sin(b)

In [None]:
# Two arrays of 90 million random integers in [0, 100) (int64, rng.integers' default dtype)
x = rng.integers(0, 100, size=90000000)
y = rng.integers(0, 100, size=90000000)

# Compare NumPy's ufunc, np.vectorize (Python loop) and the parallel Numba ufunc
%time _ = my_numpy_sin(x, y)
%time _ = my_sin(x, y)
%time _ = my_sin_numba(x, y)

### Vectorize the testing of prime numbers 

In [None]:
@numba.vectorize('boolean(int64)')
def is_prime_v(n):
    """Element-wise trial-division primality test, compiled as a ufunc for int64.

    Raises ArithmeticError for any element ``<= 1``.
    """
    if n <= 1:
        raise ArithmeticError(f"'{n}' <= 1")
    if n == 2:
        return True
    if n % 2 == 0:
        return False
    # Try odd candidate divisors up to floor(sqrt(n))
    limit = int(math.sqrt(n))
    for divisor in range(3, limit + 1, 2):
        if n % divisor == 0:
            return False
    return True

In [None]:
# rng.integers returns int64 by default, matching is_prime_v's only signature
numbers = rng.integers(2, 10000000000, size=100000)
%time p = is_prime_v(numbers)

### Parallelize the vectorized function

In [None]:
@numba.vectorize(['boolean(int64)', 'boolean(int32)'],
                 target='parallel')
def is_prime_vp(n):
    """Parallel element-wise trial-division primality test for int64/int32 arrays.

    Raises ArithmeticError for any element ``<= 1``.
    """
    if n <= 1:
        raise ArithmeticError(f"'{n}' <= 1")
    if n == 2:
        return True
    if n % 2 == 0:
        return False
    # Try odd candidate divisors up to floor(sqrt(n))
    limit = int(math.sqrt(n))
    for divisor in range(3, limit + 1, 2):
        if n % divisor == 0:
            return False
    return True

In [None]:
# Compare the single-threaded ufunc with the parallel one on the same int64 input
numbers = rng.integers(2, 10000000000, dtype=np.int64, size=1000000)
%time p1 = is_prime_v(numbers)
%time p2 = is_prime_vp(numbers)

In [None]:
# Print the ten largest primes between 1 million and 10 million (inclusive)
numbers = np.arange(1000000, 10000001, dtype=np.int32)
%time p1 = is_prime_vp(numbers)
primes = numbers[p1]  # p1 is a boolean mask, so this keeps only the primes, in ascending order

for n in primes[-10:]:
    print(n)