# Optimize Forward Step

Use Numba's `jit`, Cython, and C++ to optimize the forward step of the variational autoencoder. Hopefully the performance will get close to that of the pure TensorFlow implementation.

In [1]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import tensorflow as tf
import time
from tensorflow.python.client import timeline
import matplotlib.pyplot as plt
%matplotlib inline

### Initialize Parameters

In [None]:
import sys
sys.path.append('../')
from vae_sta663 import *
from misc_sta663 import *

In [None]:
import tensorflow as tf
import numpy as np

# Layer sizes handed to init_weights, keyed by layer name.
config = {
    'x_in': 784,
    'encoder_1': 500,
    'encoder_2': 500,
    'decoder_1': 500,
    'decoder_2': 500,
    'z': 20,
}

# Only the first value returned by init_weights is kept; the second is discarded.
encoder_weights, _ = init_weights(config)

In [None]:
# Evaluate every encoder weight tensor in a session to obtain NumPy arrays.
init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init)

encoder_weights_np = {
    weight_name: sess.run(encoder_weights[weight_name])
    for weight_name in ('h1', 'h2', 'mu', 'sigma', 'b1', 'b2', 'bias_mu', 'bias_sigma')
}

In [None]:
# mnist_loader returns a pair: the dataset object and the number of samples.
mnist, n_samples = mnist_loader()

In [None]:
# Keep the first element returned by next_batch(100) as the sample input;
# the second element is discarded (presumably the labels — TODO confirm).
x_sample, _ = mnist.train.next_batch(100)

### Use Tensorflow

In [None]:
def forward_z(x, encoder_weights):
    """
    Build the TensorFlow ops for the encoder forward pass.

    Two softplus layers (weights 'h1'/'b1', then 'h2'/'b2') are applied to x,
    followed by two separate affine maps on the second layer's output:
    'mu'/'bias_mu' for the mean and 'sigma'/'bias_sigma' for sigma.

    Returns the tuple (z_mean, z_sigma) of output tensors.
    """
    def affine(inputs, weight_key, bias_key):
        # inputs @ W + b, expressed with TensorFlow ops
        return tf.add(tf.matmul(inputs, encoder_weights[weight_key]), encoder_weights[bias_key])

    hidden_1 = tf.nn.softplus(affine(x, 'h1', 'b1'))
    hidden_2 = tf.nn.softplus(affine(hidden_1, 'h2', 'b2'))
    z_mean = affine(hidden_2, 'mu', 'bias_mu')
    z_sigma = affine(hidden_2, 'sigma', 'bias_sigma')

    return (z_mean, z_sigma)

In [None]:
# Wrap the sample batch in a TensorFlow constant so it can be passed to forward_z.
x_sample_tf = tf.constant(x_sample)

In [None]:
%timeit -n10 -r3 sess.run(forward_z(x_sample_tf, encoder_weights))

### Use Numpy without Optimization

In [None]:
def forward_z_raw(x, encoder_weights):
    """
    Compute mean and sigma of z using plain NumPy (no compilation).

    Parameters
    ----------
    x : ndarray
        Input batch; must be matmul-compatible with encoder_weights['h1'].
    encoder_weights : mapping of str -> ndarray
        Must provide the keys 'h1', 'b1', 'h2', 'b2', 'mu', 'bias_mu',
        'sigma' and 'bias_sigma'.

    Returns
    -------
    tuple
        (z_mean, z_sigma)
    """
    # softplus(a) = log(1 + exp(a)); logaddexp(0, a) evaluates the same quantity
    # without overflowing in exp() for large activations.
    layer_1 = np.logaddexp(0.0, x @ encoder_weights['h1'] + encoder_weights['b1'])
    layer_2 = np.logaddexp(0.0, layer_1 @ encoder_weights['h2'] + encoder_weights['b2'])
    z_mean = layer_2 @ encoder_weights['mu'] + encoder_weights['bias_mu']
    z_sigma = layer_2 @ encoder_weights['sigma'] + encoder_weights['bias_sigma']

    return (z_mean, z_sigma)

In [None]:
# Time the plain-NumPy forward pass: 10 loops per run, 3 runs.
%timeit -n10 -r3 forward_z_raw(x_sample, encoder_weights_np)

In [None]:
# The NumPy forward pass must agree with the TensorFlow result to 5 decimal places.
np.testing.assert_almost_equal(
    actual=sess.run(forward_z(x_sample_tf, encoder_weights)),
    desired=forward_z_raw(x_sample, encoder_weights_np),
    decimal=5,
)

### Use Numpy with Numba

In [None]:
import numba
from numba import jit, vectorize, float32, float64

In [None]:
# The return type is float64: the result buffer comes from np.zeros, whose default dtype is float64.
@jit('float64[:,:](float64[:,:],float64[:,:])')
def mat_mul(A, B):
    """
    Multiply two 2-D float64 matrices with explicit loops compiled by Numba.

    A has shape (m, n) and B has shape (n, p); the result has shape (m, p).
    Raises ValueError when the inner dimensions of A and B differ.
    """
    m, n = A.shape
    n_b, p = B.shape
    # Without this check a mismatch would make the loops index B out of range.
    if n != n_b:
        raise ValueError("mat_mul: inner dimensions of A and B do not match")
    C = np.zeros((m, p))
    for i in range(m):
        for j in range(p):
            for k in range(n):
                C[i,j] += A[i,k] * B[k,j]
    return C

In [None]:
# Element-wise softplus compiled by Numba as a ufunc with the 'parallel' target.
@vectorize([float64(float64)], target='parallel')
def soft_plus(x):
    """
    Softplus of a float64 scalar: log(1 + exp(x)).

    Evaluated as max(x, 0) + log1p(exp(-|x|)), which is algebraically the same
    but does not overflow in exp() for large positive x and keeps precision
    for very negative x.
    """
    return max(x, 0.0) + np.log1p(np.exp(-abs(x)))

In [None]:
# Signature: x and the four weight matrices are 2-D float64 arrays, the four biases are 1-D float64 arrays.
@jit('UniTuple(float64[:,:], 2)(float64[:,:],float64[:,:],float64[:,:],float64[:],float64[:],float64[:,:],float64[:],float64[:,:],float64[:])')
def forward_z_numba(x, encoder_weights_h1, encoder_weights_h2, encoder_weights_b1, encoder_weights_b2, encoder_weights_mu, 
                  encoder_weights_bias_mu, encoder_weights_sigma, encoder_weights_bias_sigma):
    """
    Compute mean and sigma of z using the Numba-compiled helpers mat_mul and soft_plus.

    Every argument must be a float64 array: the declared signature accepts no
    other dtype, so float32 inputs have to be cast before calling.

    Returns the tuple (z_mean, z_sigma) of 2-D float64 arrays.
    """
    layer_1 = soft_plus(mat_mul(x, encoder_weights_h1) + encoder_weights_b1)
    layer_2 = soft_plus(mat_mul(layer_1, encoder_weights_h2) + encoder_weights_b2)
    z_mean = (mat_mul(layer_2, encoder_weights_mu) + encoder_weights_bias_mu)
    z_sigma = (mat_mul(layer_2, encoder_weights_sigma) + encoder_weights_bias_sigma)
    
    return(z_mean, z_sigma)

In [None]:
%%timeit -n10 -r3 
# NOTE(review): forward_z_numba is declared with float64 array arguments, while the
# float64 casts of x_sample / encoder_weights_np appear in a later cell — confirm the
# inputs are already float64 when this cell runs on a fresh kernel.
forward_z_numba(x_sample, encoder_weights_np['h1'], encoder_weights_np['h2'], encoder_weights_np['b1'], 
              encoder_weights_np['b2'], encoder_weights_np['mu'], encoder_weights_np['bias_mu'], 
              encoder_weights_np['sigma'], encoder_weights_np['bias_sigma'])

In [None]:
# The Numba forward pass must agree with the TensorFlow result to 5 decimal places.
np.testing.assert_almost_equal(
    actual=sess.run(forward_z(x_sample_tf, encoder_weights)),
    desired=forward_z_numba(
        x_sample,
        encoder_weights_np['h1'],
        encoder_weights_np['h2'],
        encoder_weights_np['b1'],
        encoder_weights_np['b2'],
        encoder_weights_np['mu'],
        encoder_weights_np['bias_mu'],
        encoder_weights_np['sigma'],
        encoder_weights_np['bias_sigma'],
    ),
    decimal=5,
)

### Use Cython

In [None]:
%load_ext Cython

In [None]:
%%cython -a
cimport cython
import numpy as np
from libc.math cimport exp, log

@cython.wraparound(False)
@cython.boundscheck(False)
cdef double[:,:] mat_mul_cython(double[:,:] A, double[:,:] B):
    """Return the matrix product of A and B, computed with typed triple loops."""
    cdef int rows = A.shape[0]
    cdef int inner = A.shape[1]
    cdef int cols = B.shape[1]
    cdef int r, c, t
    # Zero-initialised result buffer that the innermost loop accumulates into.
    cdef double[:,:] product = np.zeros((rows, cols))
    for r in range(rows):
        for c in range(cols):
            for t in range(inner):
                product[r,c] += A[r,t] * B[t,c]
    return product

@cython.wraparound(False)
@cython.boundscheck(False)
cdef double[:,:] mat_add_cython(double[:,:] A, double[:] B):
    """Add the vector B to every row of the matrix A and return the new matrix."""
    cdef int rows = A.shape[0]
    cdef int cols = A.shape[1]
    cdef int r, c
    cdef double[:,:] total = np.zeros((rows, cols))
    for r in range(rows):
        for c in range(cols):
            # B is indexed by column only, so the same vector is applied to each row.
            total[r,c] = A[r,c] + B[c]
    return total

@cython.wraparound(False)
@cython.boundscheck(False)
cdef double[:,:] soft_plus_cython(double[:,:] x):
    """Apply softplus, log(exp(v) + 1), to every element of x and return a new matrix."""
    cdef int m = x.shape[0]
    cdef int n = x.shape[1]
    # Loop indices are declared as C ints, matching mat_mul_cython and mat_add_cython.
    cdef int i, j
    cdef double[:,:] y = np.zeros((m, n))
    for i in range(m):
        for j in range(n):
            y[i,j] = log(exp(x[i,j])+1)
    return y

@cython.wraparound(False)
@cython.boundscheck(False)
def forward_z_cython(double[:,:] x, double[:,:] encoder_weights_h1, double[:,:] encoder_weights_h2, 
                     double[:] encoder_weights_b1, double[:] encoder_weights_b2, double [:,:] encoder_weights_mu, 
                     double[:] encoder_weights_bias_mu, double[:,:] encoder_weights_sigma, 
                     double[:] encoder_weights_bias_sigma):
    """
    Compute mean and sigma of z with the Cython helpers defined in this cell.

    Matrices are 2-D double memoryviews and biases are 1-D double memoryviews.
    Returns the tuple (z_mean, z_sigma) converted to NumPy arrays.
    """
    # First hidden layer: softplus(x @ h1 + b1)
    cdef double[:,:] pre_1 = mat_add_cython(mat_mul_cython(x, encoder_weights_h1), encoder_weights_b1)
    cdef double[:,:] layer_1 = soft_plus_cython(pre_1)

    # Second hidden layer: softplus(layer_1 @ h2 + b2)
    cdef double[:,:] pre_2 = mat_add_cython(mat_mul_cython(layer_1, encoder_weights_h2), encoder_weights_b2)
    cdef double[:,:] layer_2 = soft_plus_cython(pre_2)

    # Two affine output heads that share layer_2 as input.
    cdef double[:,:] z_mean = mat_add_cython(mat_mul_cython(layer_2, encoder_weights_mu), encoder_weights_bias_mu)
    cdef double[:,:] z_sigma = mat_add_cython(mat_mul_cython(layer_2, encoder_weights_sigma), encoder_weights_bias_sigma)

    return (np.array(z_mean), np.array(z_sigma))

In [None]:
# forward_z_cython takes double-typed memoryviews, so cast the sample batch and
# every encoder weight to float64.
x_sample = x_sample.astype(np.float64)
encoder_weights_np.update({
    weight_name: encoder_weights_np[weight_name].astype(np.float64)
    for weight_name in ('h1', 'h2', 'b1', 'b2', 'mu', 'bias_mu', 'sigma', 'bias_sigma')
})

In [None]:
%%timeit -n10 -r3 
# Time the Cython forward pass: 10 loops per run, 3 runs.
forward_z_cython(x_sample, encoder_weights_np['h1'], encoder_weights_np['h2'], encoder_weights_np['b1'], 
              encoder_weights_np['b2'], encoder_weights_np['mu'], encoder_weights_np['bias_mu'], 
              encoder_weights_np['sigma'], encoder_weights_np['bias_sigma'])

In [None]:
# The Cython forward pass must agree with the TensorFlow result to 5 decimal places.
np.testing.assert_almost_equal(
    actual=sess.run(forward_z(x_sample_tf, encoder_weights)),
    desired=forward_z_cython(
        x_sample,
        encoder_weights_np['h1'],
        encoder_weights_np['h2'],
        encoder_weights_np['b1'],
        encoder_weights_np['b2'],
        encoder_weights_np['mu'],
        encoder_weights_np['bias_mu'],
        encoder_weights_np['sigma'],
        encoder_weights_np['bias_sigma'],
    ),
    decimal=5,
)