## Optimization

This notebook describes our work on optimizing the SGHMC sampler with `Numba` and `Cython`.

In [51]:
import numpy as np
import numpy.linalg as la
import cython
import timeit
import warnings
warnings.filterwarnings('ignore')

### Plain Python Implementation

In [32]:
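def sghmc(theta0, gradU, V, alpha=0.1, epoch=500, step=1, lr=0.01):
    '''Simplified SGHMC sampler; returns an (epoch, p) array of samples.'''

    np.random.seed(1234)  # fixed seed so that runs are reproducible

    p = theta0.shape[0]
    beta = 0.5 * lr * V  # noise-model term built from the gradient-noise estimate V
    # Cholesky factor of the injected-noise covariance
    disp = la.cholesky(2 * lr * (alpha * np.eye(p) - 2 * beta))

    samples = np.zeros((epoch, p))
    samples[0] = theta0

    for i in range(epoch - 1):
        # theta is a view of row i, so the in-place update below also overwrites samples[i]
        theta = samples[i]
        r = np.random.randn(p)  # resample the momentum

        for j in range(step):
            theta += lr * r  # position update
            # momentum update: noisy gradient, friction, injected noise
            r += -lr * gradU(theta) - alpha * r + disp @ np.random.randn(p)

        samples[i + 1] = theta

    return samples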
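
A note on `theta = samples[i]` in the loop above: basic indexing of a NumPy array returns a view, so the in-place update `theta += lr * r` writes through to `samples[i]` as well. The following minimal sketch, using a small hypothetical array rather than the sampler itself, shows the difference between a view and a copy:

```python
import numpy as np

samples = np.zeros((3, 2))
samples[0] = [1.0, 2.0]

# Basic indexing returns a view: the in-place add also changes samples[0].
theta = samples[0]
theta += 10.0
aliased_row = samples[0].copy()

# With an explicit copy, the stored row is left untouched.
samples[0] = [1.0, 2.0]
theta = samples[0].copy()
theta += 10.0
kept_row = samples[0].copy()

print(aliased_row)  # [11. 12.]
print(kept_row)     # [1. 2.]
```

If the intent is for `samples[0]` to keep the initial value `theta0`, using `samples[i].copy()` in the sampler would be one way to achieve it; the results reported in this notebook were produced with the view-based version.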

In [33]:
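# set the parameters

U = lambda x: -2 * x**2 + x**4  # double-well potential
gradU = lambda x: -4 * x + 4 * x**3 + 2 * np.random.randn()  # gradient of U plus one scalar N(0, 4) draw
theta0 = np.zeros(2)  # initial position
V = 4 * np.eye(2)  # gradient-noise variance estimate passed to sghmc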
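
As a quick sanity check on these definitions, the noise-free part of `gradU` can be compared against a central finite difference of `U`. This is only a sketch: `grad_exact`, the grid `x`, and the step `h` are names and values chosen here for illustration.

```python
import numpy as np

U = lambda x: -2 * x**2 + x**4
grad_exact = lambda x: -4 * x + 4 * x**3  # gradU without the random term

x = np.linspace(-1.5, 1.5, 7)
h = 1e-5  # finite-difference step (illustrative choice)

# Central difference approximation of dU/dx on the grid.
grad_fd = (U(x + h) - U(x - h)) / (2 * h)
max_err = np.max(np.abs(grad_fd - grad_exact(x)))

print(max_err)
print(grad_exact(np.array([-1.0, 1.0])))  # gradient vanishes at the two wells
```

The gradient vanishes at `x = -1` and `x = 1`, the two minima of the double-well potential, so the samples are expected to concentrate around those values.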

### Profiling part

In [34]:
%prun -q -D sghmc.prof sghmc(theta0, gradU, V)

 
*** Profile stats marshalled to file 'sghmc.prof'. 


In [35]:
import pstats

p = pstats.Stats('sghmc.prof')
p.print_stats()
pass

Sun Apr 25 00:52:30 2021    sghmc.prof

         2025 function calls in 0.027 seconds

   Random listing order was used

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1    0.000    0.000    0.000    0.000 {method 'get' of 'dict' objects}
        1    0.000    0.000    0.027    0.027 {built-in method builtins.exec}
        1    0.000    0.000    0.000    0.000 {built-in method builtins.getattr}
        3    0.000    0.000    0.000    0.000 {built-in method builtins.issubclass}
        1    0.013    0.013    0.027    0.027 <ipython-input-32-ee9660aead69>:1(sghmc)
        1    0.000    0.000    0.027    0.027 <string>:1(<module>)
      499    0.008    0.000    0.009    0.000 <ipython-input-33-4658113d012f>:4(<lambda>)
        1    0.000    0.000    0.000    0.000 {method 'seed' of 'numpy.random.mtrand.RandomState' objects}
     1497    0.006    0.000    0.006    0.000 {method 'randn' of 'numpy.random.mtrand.RandomState' objects}
        1    0.000    0.00

The profile shows that `np.random.randn` (1497 calls) and the `gradU` lambda (499 calls) are the most frequently called functions inside `sghmc`. Together with the body of `sghmc` itself (0.013 s of the 0.027 s total), they account for essentially all of the run time, so these are the parts to focus on when optimizing. Before optimizing with `numba` and `cython`, we store the baseline samples and measure the baseline run time for later comparison.
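
The listing above is printed in random order, which makes the hot spots harder to read. A minimal sketch of sorting profile statistics with `pstats` follows; `work` and `helper` are hypothetical stand-ins for `sghmc` and `gradU`, and the statistics are written to an in-memory buffer so the example does not depend on a saved profile file.

```python
import cProfile
import io
import pstats


def helper(i):
    # Hypothetical stand-in for a cheap function that is called many times.
    return i ** 0.5


def work(n):
    # Hypothetical stand-in for the sampler loop.
    total = 0.0
    for i in range(n):
        total += helper(i)
    return total


profiler = cProfile.Profile()
profiler.enable()
work(10_000)
profiler.disable()

# Sort by time spent inside each function and show at most five entries.
buf = io.StringIO()
stats = pstats.Stats(profiler, stream=buf)
stats.strip_dirs().sort_stats('tottime').print_stats(5)
report = buf.getvalue()
print(report)
```

The same `sort_stats('tottime').print_stats(5)` chain should also work on a `pstats.Stats` object loaded from a saved profile file such as `sghmc.prof`.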

In [36]:
samples = sghmc(theta0, gradU, V)

In [37]:
%timeit sghmc(theta0, gradU, V)

12.6 ms ± 1.81 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)
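
`%timeit` is an IPython magic. Outside a notebook, a comparable measurement can be sketched with the standard `timeit` module imported at the top; `workload` below is a hypothetical stand-in for the call `sghmc(theta0, gradU, V)`, and the run and loop counts simply mirror the ones reported above.

```python
import timeit

import numpy as np


def workload():
    # Hypothetical stand-in for sghmc(theta0, gradU, V).
    return np.linalg.cholesky(0.0012 * np.eye(2))


repeat, number = 7, 100  # 7 runs of 100 loops each, as in the output above
totals = timeit.repeat(workload, repeat=repeat, number=number)
per_loop = np.array(totals) / number  # seconds per call, one value per run

print(f"{per_loop.mean() * 1e3:.3f} ms ± {per_loop.std() * 1e3:.3f} ms per loop "
      f"(mean ± std. dev. of {repeat} runs, {number} loops each)")
```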


### Numba

We apply `numba.jit` as a function decorator to `sghmc` and compare its output and run time against the plain Python baseline.

In [48]:
import numba
from numba import jit

In [52]:
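@jit
def sghmc_numba(theta0, gradU, V, alpha=0.1, epoch=500, step=1, lr=0.01):
    '''SGHMC sampler decorated with numba.jit; same body as sghmc.'''

    np.random.seed(1234)  # same seed as the plain Python version

    p = theta0.shape[0]
    beta = 0.5 * lr * V
    # Cholesky factor of the injected-noise covariance
    disp = la.cholesky(2 * lr * (alpha * np.eye(p) - 2 * beta))

    samples = np.zeros((epoch, p))
    samples[0] = theta0

    for i in range(epoch - 1):
        theta = samples[i]  # view of row i, as in the plain version
        r = np.random.randn(p)  # resample the momentum

        for j in range(step):
            theta += lr * r  # position update
            # momentum update: noisy gradient, friction, injected noise
            r += -lr * gradU(theta) - alpha * r + disp @ np.random.randn(p)

        samples[i + 1] = theta

    return samples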

In [53]:
print(np.allclose(samples, sghmc_numba(theta0, gradU, V)))

True


In [54]:
%timeit sghmc_numba(theta0, gradU, V)

12.6 ms ± 241 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
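
The timing is essentially the same as the 12.6 ms baseline. A likely explanation, which we did not verify here, is that `gradU` is a plain Python lambda that Numba cannot compile, so a bare `@jit` in Numba releases of that period falls back to object mode and runs the loop through the interpreter. The sketch below shows one possible way to let the whole loop compile in nopython mode. The names `grad_u` and `sghmc_nopython`, the choice to precompute `disp` outside the compiled function, and the explicit loops replacing `disp @ ...` are all assumptions made for this illustration. It also uses `samples[i].copy()`, unlike the functions above.

```python
import numpy as np

try:
    from numba import njit
except ImportError:
    # Assumption for portability: without Numba, run the same code uncompiled.
    def njit(func):
        return func


@njit
def grad_u(x):
    # Same noisy double-well gradient as the gradU lambda, as a compilable function.
    return -4 * x + 4 * x**3 + 2 * np.random.randn()


@njit
def sghmc_nopython(theta0, grad, disp, alpha=0.1, epoch=500, step=1, lr=0.01):
    np.random.seed(1234)  # seeds Numba's own generator when compiled
    p = theta0.shape[0]
    samples = np.zeros((epoch, p))
    samples[0] = theta0
    noise = np.zeros(p)
    for i in range(epoch - 1):
        theta = samples[i].copy()  # copy, so samples[i] is left untouched
        r = np.random.randn(p)
        for j in range(step):
            theta += lr * r
            z = np.random.randn(p)
            # Explicit loops computing disp @ z, avoiding any dependence on
            # Numba's BLAS-backed matrix product.
            for a in range(p):
                acc = 0.0
                for b in range(p):
                    acc += disp[a, b] * z[b]
                noise[a] = acc
            r += -lr * grad(theta) - alpha * r + noise
        samples[i + 1] = theta
    return samples


theta0 = np.zeros(2)
V = 4 * np.eye(2)
lr, alpha = 0.01, 0.1
beta = 0.5 * lr * V
disp = np.linalg.cholesky(2 * lr * (alpha * np.eye(2) - 2 * beta))

samples_np = sghmc_nopython(theta0, grad_u, disp, alpha, 500, 1, lr)
print(samples_np.shape)  # (500, 2)
```

When compiled, Numba keeps its own random state separate from NumPy's, so the draws will generally differ from the plain Python run even with the same seed. An `np.allclose` comparison against `samples` is therefore not expected to hold for this variant, and its correctness would have to be judged on the distribution of the samples instead.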
