# Numba minimal introduction

Playing with saxpy (single precision A X plus Y)

$x_i = a x_i + y_i, \forall i \in [0, N-1]$

In [None]:
import numpy as np

In [None]:
import time

from numba import cuda

def measure_saxpy_bandwidth(func, N, useCuda=False):
    """Time a single saxpy call on freshly built arrays and report bandwidth.

    The inputs are created here: ``x = arange(N)`` as float32,
    ``y = 2*x + 1`` and ``a = 0.5``. ``func`` is run exactly once and the
    elapsed time is converted into an effective memory bandwidth.

    Parameters
    ----------
    func : callable
        On the CPU path it is called as ``func(x, y, a)`` with NumPy arrays.
        On the CUDA path it is launched as
        ``func[nBlocks, nThreads](x_d, y_d, a)`` with device arrays.
    N : int
        Number of elements in ``x`` and ``y``.
    useCuda : bool, optional
        Select the CUDA launch/timing path instead of the CPU one.

    Returns
    -------
    float
        Bandwidth in GBytes/s, assuming 2 reads and 1 write of ``4*N`` bytes
        each. ``inf`` is returned if the measured duration is zero.

    Notes
    -----
    No warm-up call is made, so any one-time cost incurred by the first
    invocation of ``func`` is part of the measurement. On the CUDA path the
    host-to-device copies are issued before the start event is recorded and
    are therefore not timed.
    """
    x = np.arange(N, dtype=np.float32)
    y = 2*x + 1
    a = 0.5
    if useCuda:
        nThreads = 256
        # ceil(N / nThreads): enough blocks to cover a partially filled last one
        nBlocks = (N + nThreads - 1) // nThreads
        start = cuda.event()
        stop = cuda.event()
        # copy to device
        x_d = cuda.to_device(x)
        y_d = cuda.to_device(y)
        start.record()
        func[nBlocks, nThreads](x_d, y_d, a)
        stop.record()
        # wait for the stop event so the elapsed time covers the whole kernel
        stop.synchronize()
        # event_elapsed_time is in milliseconds; convert to seconds
        duration = cuda.event_elapsed_time(start, stop) / 1000
    else:
        start = time.perf_counter()
        func(x, y, a)
        stop = time.perf_counter()
        duration = stop - start

    # There are 3 memory operations: 2 reads and 1 write.
    # Each memory operation involves N * sizeof(float32) = 4*N bytes.
    nbytes = 3 * (N * 4)
    # Guard against a zero reading from the timer instead of raising
    # ZeroDivisionError.
    bandwidth = nbytes * 1e-9 / duration if duration > 0 else float("inf")
    print("Bandwidth : {} GBytes/s".format(bandwidth))
    return bandwidth

## serial version : pure python

In [None]:
# Small problem size for the first micro-benchmarks.
N = 1_000
# x holds 0, 1, ..., N-1 in single precision; y is derived from it as 2*x + 1.
x = np.arange(N, dtype=np.float32)
y = x * 2 + 1
# Scalar multiplier applied to x.
a = 0.5
def saxpy(x,y,a):
    """Overwrite ``x`` with ``a*x + y``, one element at a time.

    The update is done in place with an explicit Python loop over the first
    axis of ``x``; nothing is returned.
    """
    count = x.shape[0]
    for k in range(count):
        scaled = a * x[k]
        x[k] = scaled + y[k]

In [None]:
# Time the pure-Python loop. saxpy writes into x, so every timed call
# operates on the values left by the previous one.
%timeit saxpy(x,y,a)


In [None]:
# Bandwidth of the pure-Python loop on one million elements.
measure_saxpy_bandwidth(saxpy, 1_000_000)

## serial version : just in time compiled

In [None]:
import numba

@numba.njit
def saxpy_jit(x,y,a):
    """Overwrite ``x`` with ``a*x + y`` element by element (nopython mode).

    Same loop as the pure-Python version, compiled by Numba. The update is
    done in place and nothing is returned.
    """
    count = x.shape[0]
    for k in range(count):
        scaled = a*x[k]
        x[k] = scaled + y[k]

In [None]:
# Time the JIT-compiled loop on the same x, y, a as the pure-Python run.
# saxpy_jit also writes into x on every call.
%timeit saxpy_jit(x,y,a)

In [None]:
# Bandwidth of the JIT-compiled serial loop on one million elements.
measure_saxpy_bandwidth(saxpy_jit, 1_000_000)

## using numba generalized universal functions (gufunc)


In [None]:
from numba import guvectorize, float32

In [None]:
@guvectorize([(float32[:], float32[:], float32, float32[:])], '(n),(n),()->(n)')
def saxpy_vectorized(x, y, a, z):
    """Generalized ufunc kernel: write ``a*x + y`` into ``z``.

    Per the layout string, ``x`` and ``y`` are 1-D float32 inputs of equal
    length ``n``, ``a`` is a float32 scalar and ``z`` is the length-``n``
    output that receives the result.
    """
    length = x.shape[0]
    for k in range(length):
        scaled = a*x[k]
        z[k] = scaled + y[k]


In [None]:
# Time the gufunc version. x is passed a second time as the fourth argument
# (the output z), so the result is written back into x.
%timeit saxpy_vectorized(x,y,a,x)

In [None]:
# Bandwidth of the gufunc version on one million elements.
measure_saxpy_bandwidth(saxpy_vectorized, 1_000_000)

## CPU parallel version : multithreading

In [None]:
from numba import config, njit, threading_layer

# Set the threading layer before any parallel target compilation.
# Alternative setting, kept for reference:
#config.THREADING_LAYER = 'threadsafe'
# Request the 'tbb' threading layer.
# NOTE(review): presumably this needs the TBB runtime to be installed - confirm.
config.THREADING_LAYER = 'tbb'

In [None]:
import numba
from numba import prange

@numba.njit(parallel=True)
def saxpy_jitp(x,y,a):
    """Overwrite ``x`` with ``a*x + y`` using a ``prange`` loop.

    Compiled in nopython mode with ``parallel=True``. Each iteration touches
    only its own index of ``x`` and ``y``. The update is done in place and
    nothing is returned.
    """
    count = x.shape[0]
    for k in prange(count):
        scaled = a*x[k]
        x[k] = scaled + y[k]

In [None]:
# threading_layer() can only report a layer after a parallel function has
# actually been executed, so run saxpy_jitp once first (this also triggers
# its compilation). The call updates x in place.
saxpy_jitp(x,y,a)
print("Threading layer chosen: %s" % threading_layer())

In [None]:
%timeit axpy_jitp(x,y,a)

In [None]:
# Larger problem (ten million elements) for comparing serial and parallel runs.
N = 10_000_000
# Rebuild the inputs at the new size: x = 0..N-1 in single precision, y = 2*x + 1.
x = np.arange(N, dtype=np.float32)
y = x * 2 + 1
# Scalar multiplier applied to x.
a = 0.5

In [None]:
%timeit axpy_jit(x,y,a)
%timeit axpy_jitp(x,y,a)

## GPU parallel version with numba/cuda

In [None]:
from numba import cuda

@cuda.jit('void(float32[:], float32[:], float32)')
def saxpy_jit_cuda(x,y,a):
    """CUDA kernel: each thread overwrites one element of ``x`` with ``a*x + y``.

    The element handled by a thread is given by ``cuda.grid(1)``. Threads
    whose index is not below ``x.shape[0]`` do nothing.
    """
    tid = cuda.grid(1)
    # The launch may contain more threads than elements; skip the surplus.
    if tid >= x.shape[0]:
        return
    x[tid] = a*x[tid] + y[tid]

In [None]:
# Bandwidth of the CUDA kernel on one hundred million elements.
measure_saxpy_bandwidth(saxpy_jit_cuda, 100_000_000, useCuda=True)