## A notebook for students to get hands-on with Numba and Cython

Reminder: run `conda install numba` and `conda install cython` if they aren't already installed on your system.

In [1]:
def swap_min_max(arr):
    max_val = arr[0]
    max_ind = 0
    min_val = arr[0]
    min_ind = 0
    for i in range(1, arr.shape[0]):
        if arr[i] > max_val:
            max_val = arr[i]
            max_ind = i
        if arr[i] < min_val:
            min_val = arr[i]
            min_ind = i
    arr[min_ind] = arr[max_ind]
    arr[max_ind] = min_val

In [2]:
import numpy as np
X = np.arange(int(1e8)) #100 million numbers

In [3]:
%timeit -n 1 -r 1 swap_min_max(X)
# This takes a long time, so the -n 1 -r 1 options make %timeit run the computation only once.
# Usually it's best to run it several times to get a more accurate estimate of the expected runtime.

24.3 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
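
The idea of repeating a measurement can also be sketched with the standard-library `timeit` module. This is a minimal sketch, not part of the original notebook: the name `X_small` and the size of 100,000 elements are assumptions chosen so that several repeats finish quickly.

```python
import timeit

import numpy as np


def swap_min_max(arr):
    """Pure-Python version from the cell above: swap min and max in place."""
    max_val = arr[0]
    max_ind = 0
    min_val = arr[0]
    min_ind = 0
    for i in range(1, arr.shape[0]):
        if arr[i] > max_val:
            max_val = arr[i]
            max_ind = i
        if arr[i] < min_val:
            min_val = arr[i]
            min_ind = i
    arr[min_ind] = arr[max_ind]
    arr[max_ind] = min_val


# Hypothetical smaller input so that repeated runs stay cheap.
X_small = np.arange(100_000)

# Five independent measurements, one call each.
times = timeit.repeat(lambda: swap_min_max(X_small), repeat=5, number=1)
print(f"best of {len(times)} runs: {min(times):.4f} s")
```

Because the function modifies its argument in place, each run swaps the two extreme values back and forth. The amount of work per run stays the same, so the timings remain comparable.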


### In-Class Exercise: 

Rewrite `swap_min_max` using Numba. Then rerun the timing from the previous cell to compare its performance with the pure-Python version.

In [4]:
import numba

def swap_min_max_numba(arr):
    '''
    Here's my wonderful function to swap the 
    min and max values of an array using numba!
    '''
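
One possible solution can be sketched as follows. It is a minimal sketch rather than the intended answer: it decorates the unchanged loop with Numba's `njit` and assumes a 1-D integer array. The `try`/`except` fallback and the small array `X_demo` are additions for illustration; the fallback only keeps the cell runnable when Numba is missing and gives no speedup.

```python
import numpy as np

try:
    from numba import njit
except ImportError:
    # Fallback (an assumption for portability): run as plain Python, no speedup.
    def njit(func):
        return func


@njit
def swap_min_max_numba(arr):
    """Swap the minimum and maximum values of a 1-D array in place."""
    max_val = arr[0]
    max_ind = 0
    min_val = arr[0]
    min_ind = 0
    for i in range(1, arr.shape[0]):
        if arr[i] > max_val:
            max_val = arr[i]
            max_ind = i
        if arr[i] < min_val:
            min_val = arr[i]
            min_ind = i
    arr[min_ind] = arr[max_ind]
    arr[max_ind] = min_val


# Hypothetical small input to check the behaviour before timing the large array.
X_demo = np.array([3, 1, 4, 1, 5, 9, 2, 6])
swap_min_max_numba(X_demo)
print(X_demo)
```

The first call includes the JIT compilation time, so call the function once before timing it with `%timeit`.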

### In-Class Exercise:

Rewrite `swap_min_max` using Cython. Here's an initial implementation. Add type declarations and other Cython syntax to improve the running time as much as possible.

In [5]:
%load_ext cython

In [6]:
%%cython -a
def swap_min_max_cython(arr):
    max_val = arr[0]
    max_ind = 0
    min_val = arr[0]
    min_ind = 0
    for i in range(1, arr.shape[0]):
        if arr[i] > max_val:
            max_val = arr[i]
            max_ind = i
        if arr[i] < min_val:
            min_val = arr[i]
            min_ind = i
    arr[min_ind] = arr[max_ind]
    arr[max_ind] = min_val
    
    
    

In [12]:
%%cython -a
cimport numpy as np
cimport cython

@cython.boundscheck(False) #don't check if the array indices are valid
@cython.wraparound(False) #don't allow negative indices
def swap_min_max_cython(np.ndarray[ndim=1, dtype=np.int64_t] arr):
    cdef:
        np.int64_t min_val = arr[0]
        np.int64_t max_val = arr[0]  # typed as well, so the comparisons stay in C
        int min_ind = 0
        int max_ind = 0
        int i
        int n = arr.shape[0]
    for i in range(1, n):
        if arr[i] > max_val:
            max_val = arr[i]
            max_ind = i
        if arr[i] < min_val:
            min_val = arr[i]
            min_ind = i
    arr[min_ind] = arr[max_ind]
    arr[max_ind] = min_val




In [7]:
%timeit swap_min_max_cython(X)

17.3 s ± 2.48 s per loop (mean ± std. dev. of 7 runs, 1 loop each)
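
Before comparing timings, it can help to confirm that each variant still produces the same result as the original. The helper below is a sketch of one way to do that, not part of the original notebook: `check_swap` is a hypothetical name, and the reference result is built with NumPy's `argmin` and `argmax`, which return the first occurrence just as the loop does.

```python
import numpy as np


def swap_min_max(arr):
    """Pure-Python reference from the first cell."""
    max_val = arr[0]
    max_ind = 0
    min_val = arr[0]
    min_ind = 0
    for i in range(1, arr.shape[0]):
        if arr[i] > max_val:
            max_val = arr[i]
            max_ind = i
        if arr[i] < min_val:
            min_val = arr[i]
            min_ind = i
    arr[min_ind] = arr[max_ind]
    arr[max_ind] = min_val


def check_swap(func, n=1000, seed=0):
    """Return True if func swaps the first min and first max of an array in place."""
    rng = np.random.default_rng(seed)
    # int64 matches the dtype declared in the typed Cython signature.
    arr = rng.integers(-10_000, 10_000, size=n, dtype=np.int64)
    expected = arr.copy()
    i_min, i_max = np.argmin(arr), np.argmax(arr)
    expected[i_min], expected[i_max] = arr[i_max], arr[i_min]
    func(arr)  # the function under test modifies arr in place
    return bool(np.array_equal(arr, expected))


print(check_swap(swap_min_max))
```

The same call should work with `swap_min_max_numba` or `swap_min_max_cython` once those are defined in the session.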
