# Numerical Python Speed Tests

As we know, Python can be *very* slow when computing over a large number of elements. This notebook explores a *few* ways we can speed up Python processing. In particular, we seek to speed up numerical processing by comparing the following methods:

1. Looping over every element in an array.
2. Using logical indexing to compute outputs.
3. Using the [numpy ufunc extension](https://numpy.org/doc/stable/reference/generated/numpy.frompyfunc.html?highlight=frompyfunc#numpy.frompyfunc) to intelligently loop over the array.
4. Using [numba's vectorization](https://numba.readthedocs.io/en/stable/user/vectorize.html) method.
5. Using [numba's just-in-time (JIT)](https://numba.pydata.org/numba-doc/latest/user/jit.html) compiler.

## Problem Setup

In my doctorate, I use a complicated function that I have to evaluate over millions of data points. In essence, I am trying to quantify the extinction of light at certain wavelengths (see the extinction equations in [this paper](http://articles.adsabs.harvard.edu/pdf/1989ApJ...345..245C)).

## Imports

In [None]:
# Numpy Imports
import numpy as np
from numpy import frompyfunc as vectorize

# Numba Imports
from numba import jit, vectorize as nbvec, float64, prange

## Functions

Next, let's define the functions.

In [None]:
# Loop over elements
def ext_loop(wave, ebv=1.0, rvp=3.1):
    
    # Create the output array
    k = np.empty_like(wave)
    
    # Loop over wave elements
    for i, w in enumerate(wave):
        
        # Get wave number
        x = 1/wave[i]
        
        # Calculate a, b
        if x < 1.1:
            
            a =  0.574 * x**1.61
            b = -0.527 * x**1.61
            
        elif x < 3.3:
            
            y = x - 1.82
            a = 1 + 0.17699*y - 0.50447*y**2 - 0.02427*y**3 + 0.72085*y**4 + \
                0.01979*y**5 - 0.7753*y**6 + 0.32999*y**7
            b = 1.41338*y + 2.28305*y**2 + 1.07233*y**3 - 5.38434*y**4 - \
                0.62251*y**5 + 5.3026*y**6 - 2.09002*y**7
        
        elif x < 8:
            
            if x > 5.9:
                
                y  = x - 5.9
                fA = -0.0447*y**2 - 0.009779*y**3
                fB =  0.2130*y**2 + 0.1207*y**3
            
            else:
                
                fA = fB = 0
                
            
            a = fA + 1.752 - 0.316*x - 0.104/(0.341 + (x - 4.67)**2)
            b = fB - 3.090 + 1.825*x + 1.206/(0.263 + (x - 4.62)**2)
        
        else:
            
            y = x - 8
            a = -1.073 - 0.628*y + 0.137*y**2 - 0.070*y**3
            b = 13.670 + 4.257*y - 0.420*y**2 + 0.374*y**3
        
        # Calc k
        k[i] = rvp*a + b
        
    # Return k
    return k

In [None]:
def ext_ind(wave, ebv=1.0, rvp=3.1):
    
    def getk(a, b):
        return rvp*a + b
    
    # Setup the Output Array
    k = np.empty_like(wave)
    
    # Get wave number
    x = 1/wave
    
    # Do First Cut
    inds = (x < 1.1)
    a =  0.574 * x[inds]**1.61
    b = -0.527 * x[inds]**1.61
    k[inds] = getk(a, b)
    
    # Do Second Cut
    inds = ((1.1 <= x) & (x < 3.3))
    y = x[inds] - 1.82
    a = 1 + 0.17699*y - 0.50447*y**2 - 0.02427*y**3 + 0.72085*y**4 + \
        0.01979*y**5 - 0.7753*y**6 + 0.32999*y**7
    b = 1.41338*y + 2.28305*y**2 + 1.07233*y**3 - 5.38434*y**4 - \
        0.62251*y**5 + 5.3026*y**6 - 2.09002*y**7
    k[inds] = getk(a, b)
    
    # Do Third Cut
    inds = ((3.3 <= x) & (x < 5.9))
    fA = fB = 0
    a = fA + 1.752 - 0.316*x[inds] - 0.104/(0.341 + (x[inds] - 4.67)**2)
    b = fB - 3.090 + 1.825*x[inds] + 1.206/(0.263 + (x[inds] - 4.62)**2)
    k[inds] = getk(a, b)
    
    # Do Fourth Cut
    inds = ((5.9 <= x) & (x < 8))
    y  = x[inds] - 5.9
    fA = -0.0447*y**2 - 0.009779*y**3
    fB =  0.2130*y**2 + 0.1207*y**3
    a = fA + 1.752 - 0.316*x[inds] - 0.104/(0.341 + (x[inds] - 4.67)**2)
    b = fB - 3.090 + 1.825*x[inds] + 1.206/(0.263 + (x[inds] - 4.62)**2)
    k[inds] = getk(a, b)
    
    # Do Fifth Cut
    inds = (8 <= x)
    y = x[inds] - 8
    a = -1.073 - 0.628*y + 0.137*y**2 - 0.070*y**3
    b = 13.670 + 4.257*y - 0.420*y**2 + 0.374*y**3
    k[inds] = getk(a, b)
    
    # Return k
    return k

In [None]:
# Written as Scalar for ufunc
def ext_sclr(wave, ebv=1.0, rvp=3.1):
    """Return ``rvp*a + b`` for a single wavelength.

    The inverse ``1/wave`` selects which piecewise pair of coefficients
    ``(a, b)`` is used. ``ebv`` is accepted but does not enter the result.
    """
    inv = 1/wave

    # Lowest interval: both coefficients are multiples of the same power
    if inv < 1.1:
        powered = inv**1.61
        coef_a =  0.574 * powered
        coef_b = -0.527 * powered
        return rvp*coef_a + coef_b

    # Second interval: 7th-order polynomials in the offset from 1.82
    if inv < 3.3:
        off = inv - 1.82
        coef_a = 1 + 0.17699*off - 0.50447*off**2 - 0.02427*off**3 + 0.72085*off**4 + \
            0.01979*off**5 - 0.7753*off**6 + 0.32999*off**7
        coef_b = 1.41338*off + 2.28305*off**2 + 1.07233*off**3 - 5.38434*off**4 - \
            0.62251*off**5 + 5.3026*off**6 - 2.09002*off**7
        return rvp*coef_a + coef_b

    # Highest interval: everything that is not below 8 (NaN included)
    if not (inv < 8):
        off = inv - 8
        coef_a = -1.073 - 0.628*off + 0.137*off**2 - 0.070*off**3
        coef_b = 13.670 + 4.257*off - 0.420*off**2 + 0.374*off**3
        return rvp*coef_a + coef_b

    # Remaining interval 3.3 <= inv < 8; extra terms switch on above 5.9
    extra_a = extra_b = 0
    if inv > 5.9:
        off = inv - 5.9
        extra_a = -0.0447*off**2 - 0.009779*off**3
        extra_b =  0.2130*off**2 + 0.1207*off**3

    coef_a = extra_a + 1.752 - 0.316*inv - 0.104/(0.341 + (inv - 4.67)**2)
    coef_b = extra_b - 3.090 + 1.825*inv + 1.206/(0.263 + (inv - 4.62)**2)
    return rvp*coef_a + coef_b


# Create the ufunc
def ext_vec(wave, ebv=1.0, rvp=3.1):
    """Apply ``ext_sclr`` element by element through a ufunc.

    The ufunc is built on every call with ``vectorize`` (the file imports
    ``numpy.frompyfunc`` under that name), taking three inputs and giving
    one output; ``ebv`` and ``rvp`` are passed to it alongside ``wave``.

    NOTE(review): per the NumPy documentation, ``frompyfunc`` ufuncs return
    object-dtype results -- cast with ``.astype(float)`` if a numeric array
    is needed downstream.
    """
    elementwise = vectorize(ext_sclr, 3, 1)
    return elementwise(wave, ebv, rvp)

In [None]:
# Written as Scalar for numba ufunc
@nbvec(
    [float64(float64, float64, float64)],
    nopython=True,
    target='parallel'
)
def ext_nbvec(wave, ebv=1.0, rvp=3.1):
    """Scalar kernel returning ``rvp*a + b``, compiled by numba's vectorize.

    The inverse ``1/wave`` selects which piecewise pair of coefficients
    ``(a, b)`` is used. ``ebv`` is accepted but does not enter the result.

    NOTE(review): the notebook always calls this with all three arguments;
    the ``ebv``/``rvp`` defaults may not be honoured by the compiled ufunc
    -- confirm before relying on them.
    """
    inv = 1/wave

    # Lowest interval: both coefficients are multiples of the same power
    if inv < 1.1:
        powered = inv**1.61
        coef_a =  0.574 * powered
        coef_b = -0.527 * powered
        return rvp*coef_a + coef_b

    # Second interval: 7th-order polynomials in the offset from 1.82
    if inv < 3.3:
        off = inv - 1.82
        coef_a = 1 + 0.17699*off - 0.50447*off**2 - 0.02427*off**3 + 0.72085*off**4 + \
            0.01979*off**5 - 0.7753*off**6 + 0.32999*off**7
        coef_b = 1.41338*off + 2.28305*off**2 + 1.07233*off**3 - 5.38434*off**4 - \
            0.62251*off**5 + 5.3026*off**6 - 2.09002*off**7
        return rvp*coef_a + coef_b

    # Highest interval: everything that is not below 8 (NaN included)
    if not (inv < 8):
        off = inv - 8
        coef_a = -1.073 - 0.628*off + 0.137*off**2 - 0.070*off**3
        coef_b = 13.670 + 4.257*off - 0.420*off**2 + 0.374*off**3
        return rvp*coef_a + coef_b

    # Remaining interval 3.3 <= inv < 8; extra terms switch on above 5.9
    if inv > 5.9:
        off = inv - 5.9
        extra_a = -0.0447*off**2 - 0.009779*off**3
        extra_b =  0.2130*off**2 + 0.1207*off**3
    else:
        extra_a = extra_b = 0

    coef_a = extra_a + 1.752 - 0.316*inv - 0.104/(0.341 + (inv - 4.67)**2)
    coef_b = extra_b - 3.090 + 1.825*inv + 1.206/(0.263 + (inv - 4.62)**2)
    return rvp*coef_a + coef_b

In [None]:
# Create the JIT version
# The explicit signature (Numba notation) declares a contiguous 1-D float64
# array plus two float64 scalars as input and a contiguous 1-D float64 array
# as output; nopython and parallel modes are both requested.
@jit(
    float64[::1](float64[::1], float64, float64),
    nopython=True,
    parallel=True
)
def ext_jit(wave, ebv=1.0, rvp=3.1):
    """Evaluate ``rvp*a + b`` for every wavelength in a numba-compiled loop.

    Same piecewise calculation as the plain-Python loop version, but the
    element loop uses ``prange`` inside a ``parallel=True`` jitted function.

    Parameters
    ----------
    wave : 1-D float64 array
        Wavelengths; each element is inverted to ``x = 1/wave[i]``.
    ebv : float
        Accepted for signature parity; not used in the calculation.
    rvp : float
        Factor multiplying the ``a`` coefficient.

    Returns
    -------
    1-D float64 array
        One value per input element.

    NOTE(review): the notebook always passes all three arguments; whether
    the defaults are usable together with the explicit signature should be
    confirmed.
    """
    
    # Create the output array
    k = np.empty_like(wave)
    
    # Loop over wave elements; each iteration writes only its own k[i]
    for i in prange(len(wave)):
        
        # Get wave number
        x = 1/wave[i]
        
        # Calculate a, b
        if x < 1.1:
            
            a =  0.574 * x**1.61
            b = -0.527 * x**1.61
            
        elif x < 3.3:
            
            # 7th-order polynomials in the offset from 1.82
            y = x - 1.82
            a = 1 + 0.17699*y - 0.50447*y**2 - 0.02427*y**3 + 0.72085*y**4 + \
                0.01979*y**5 - 0.7753*y**6 + 0.32999*y**7
            b = 1.41338*y + 2.28305*y**2 + 1.07233*y**3 - 5.38434*y**4 - \
                0.62251*y**5 + 5.3026*y**6 - 2.09002*y**7
        
        elif x < 8:
            
            # Extra fA/fB terms only contribute above x = 5.9
            if x > 5.9:
                
                y  = x - 5.9
                fA = -0.0447*y**2 - 0.009779*y**3
                fB =  0.2130*y**2 + 0.1207*y**3
            
            else:
                
                fA = fB = 0
                
            
            a = fA + 1.752 - 0.316*x - 0.104/(0.341 + (x - 4.67)**2)
            b = fB - 3.090 + 1.825*x + 1.206/(0.263 + (x - 4.62)**2)
        
        else:
            
            # Reached for x >= 8 (and for NaN, where every comparison is False)
            y = x - 8
            a = -1.073 - 0.628*y + 0.137*y**2 - 0.070*y**3
            b = 13.670 + 4.257*y - 0.420*y**2 + 0.374*y**3
        
        # Calc k
        k[i] = rvp*a + b
        
    # Return k
    return k

## Create the Data

We will create 10 million random wavelengths to test over.

In [None]:
# Seeded generator, so every run benchmarks the same sample
rng = np.random.default_rng(seed=0)

# Draw ten million wavelengths uniformly between 1/10 and 1/0.3
waves = rng.uniform(1/10, 1/0.3, 10**7)

## Test the Implementations

### Loop Method

In [None]:
%%timeit
_ = ext_loop(waves)

### Indexing Method

In [None]:
%%timeit
_ = ext_ind(waves)

### UFunc Method

In [None]:
%%timeit
_ = ext_vec(waves)

### Numba Vectorization

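Numba's JIT takes time to compile, but the compiled function then runs faster. The `%%time` cell below measures a single first call, and the `%%timeit` cell after it measures repeated calls.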

In [None]:
%%time
_ = ext_nbvec(waves, 1., 3.1)

In [None]:
%%timeit
_ = ext_nbvec(waves, 1., 3.1)

### JIT Method

In [None]:
%%time
_ = ext_jit(waves, 1., 3.1)

In [None]:
%%timeit
_ = ext_jit(waves, 1., 3.1)