<a href="https://colab.research.google.com/github/semiGr/dl-projects/blob/master/EMhybrid.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

This is an implementation of the EM hybrid optimization algorithm described in Section 4.2 in this [paper](https://arxiv.org/pdf/1905.10474v8.pdf).

Why JAX? Here's a cool [video](https://www.youtube.com/watch?v=z-WSrQDXkuM)

In [5]:
!pip install jaxlib



In [6]:
# note that np will be the jax version, and the original numpy will be onp
# ------------------------------------------------------------------------
import jax.numpy as np
import numpy as onp
import sys
import time
from scipy.stats import multivariate_normal
from jax import grad, jit, vmap
from jax import random
from jax.ops import index, index_add, index_update
from jax import lax 
from jax import vmap
import plotly.express as px
import plotly.graph_objects as go
import jax 
from IPython.display import clear_output

In [7]:
# define some functions we use later
# ----------------------------------

def hypersphere(key, radius, d=10):
    ''' Reference: Foundations of Data Science - Avrim Blum, John Hopcroft, Ravindran Kannan;
        see chapter 2.5. '''
    key, subkey = random.split(key)
    sample = random.normal(subkey, shape=([d]))
    z = np.sqrt(np.sum([x**2 for x in sample]))
    return np.multiply(radius/z, sample)

def rastrigin(y, a=1):
    ''' Rastrigin function used to test optimisation algorithms
        Inputs can be d-dimensional vectors too, not just scalars'''    
    # clip the coordinates of y to [-5.12, 5.12]
    n = len(y)
    y_ = np.clip(y, -5.12, 5.12)
    a_ = [np.power(a, (i-1)/(n-1)) for i in range(1, n+1)]
    return 10.0*len(y) + np.sum([(a_[i]*y_[i])**2 - 10.0*np.cos(2.*np.pi*a_[i]*y_[i]) for i in range(len(y_))])    

def shapeFunct(x, mu, sigma): 
    ''' This is a transformation, denoted by W below, applied to the function we optimise; 
        details are in Section 4.2 in the article'''
    return (1+np.exp((x-mu)/sigma))**-1


def mn_dense(x, mu, sigma): 
    ''' Multivariate normal density - perhaps there is already a function for this in jax or in other packages that could be used
        ...will check later'''
    d = len(mu)
    det = np.exp(np.linalg.slogdet(sigma)[1])
    return np.power(2*np.pi, -d/2.)*np.power(det, -0.5)*np.exp(-0.5*(x-mu).dot(np.linalg.inv(sigma).dot(x-mu)))


def mn_logdense(x, mu, sigma):
    ''' Multivariate normal density - perhaps there is already a function for this in jax or in other packages that could be used
        ...will check later'''
    d = len(mu)
    det = np.exp(np.linalg.slogdet(sigma)[1])
    return np.log(np.power(2*np.pi, -d/2.)*np.power(det, -0.5)*np.exp(-0.5*(x-mu).dot(np.linalg.inv(sigma).dot(x-mu))))

def is_posDef(cov):
    for x in np.linalg.eigvals(cov):
      if x.real < 0:
          return False
    return True 

def matrixFix(cov):
    delta = np.linalg.eigvals(cov).real.min() - 0.5
    return cov-delta*np.eye(len(cov))

def symmetrize(cov):
    # making sure the covariance matrix is symmetric 
    for i in range(len(cov)-1):
        for j in range(i+1, len(cov)):
            cov = index_update(cov, index[i, j], cov[j][i])
    return cov

# save the grad functions we will use repeatedly
grad_fn_mu = jit(grad(mn_logdense, 1)) # this gives the gradient of the log-density of the multivariate normal wrt to mu
grad_fn_sigma = jit(grad(mn_logdense, 2)) # this gives the gradient of the log-density of the multivariate normal wrt to sigma

In [8]:
# define the two types of updates that the hybrid scheme uses:
# - MC-GD (Monte Carlo - Gradiate Descent) update as per eqn (26) in the article (more details to be added later) 
# - EDA update as per eqn (18) in the article (more details to be added later)
# ---------------------------------------------------------------------------------------------------------------

def MCGDupdate(samples, f_samples, mu, sigma, alpha):
    ''' function that performs the MC-GD update as per eqn (26) in the paper 
        --------------------------------------------------------------------
        inputs : samples; 
                 f_samples - function values evaluated at the sample points. Note! feed in the 
                             function values already transformed by the shape function; 
                 mu - mean vector of the Gaussian search model used to produce the samples; 
                 sigma - covariance matrix of the Gaussian search model used to produce the samples; 
                 alpha - learning rate; 
        outputs : mu_new - updated mean vector; 
                  sigma_new - updated covariance matrix;  
    '''
    
    #mu_new = mu + alpha*np.sum(np.array([np.multiply(f_samples[i], grad_fn_mu(samples[i], mu, sigma)) for i in range(len(samples))]), 0)
    #sigma_new = sigma + alpha*np.sum(np.array([np.multiply(f_samples[i], grad_fn_sigma(samples[i], mu, sigma)) for i in range(len(samples))]), 0)
    
    mu_new = mu + alpha*np.sum(vmap(np.multiply)(f_samples, vmap(grad_fn_mu, in_axes=(0, None, None))(samples, mu, sigma)), axis=0)
    sigma_new = sigma + alpha*np.sum(vmap(np.multiply)(f_samples, vmap(grad_fn_sigma, in_axes=(0, None, None))(samples, mu, sigma)), axis=0) 
    
    sigma_new = symmetrize(sigma_new)
    
    return mu_new, sigma_new

def EDAupdate(samples, f_samples):
    ''' function that performs the EDA update as per eqn (18) in the paper 
        -------------------------------------------------------------------
        inputs : samples; 
                 f_samples - function values evaluated at the sample points. Note! feed in the 
                             function values already transformed by the shape function; 
        outputs : mu_new - updated mean vector; 
                  sigma_new - updated covariance matrix;  
    '''
    mu_new = (1./sum(f_samples))*np.sum(vmap(np.multiply)(f_samples, samples), axis=0)
    sigma_new = (1./sum(f_samples))*np.sum(vmap(np.multiply)(f_samples, vmap(np.outer)(samples, samples)), axis=0)

    return mu_new, sigma_new

In [9]:
# setting 
# -------

N = 1000 # Number of samples
budget = 5000 # computational budget in terms of number of iteration

results = {}
key = random.PRNGKey(42)
key, subkey = random.split(key)
mu = hypersphere(key, 30)
sigma = np.identity(10)
tau = -11.0

threshold = 0.003 # if the np.power(determinant, 0.1) < threshold ---> convergence is achieved
numberOfRestarts = 0 # how many times has the algorithm restarted 

alpha = 0.001 # this is not used now, as the alpha is calculated dynamically by a bound
numberOfIter = 1000 # number of iteration
backup = {}
milestones = {} # to store some mu and sigma along the way
aux_min = sys.maxsize*2 + 1
aux_rnd = 0


jax.numpy reductions won't accept lists and tuples in future versions, only scalars and ndarrays



In [10]:
# optimisation algorithm
# ----------------------

for k in range(budget): #numberOfIter, 110):
    # Step 2: (*) sample z from the distribution with parameters mu and sigma
    #         (*) compute f(z), mu_f, sigma_f, W(f(z))
    # -----------------------------------------------------------------------
    #key, *subkeys = random.split(key, N+1)
    #samples = np.array([random.multivariate_normal(subkey, mu, sigma) for subkey in subkeys])
    key, subkey = random.split(key)
    samples = random.multivariate_normal(subkey, mu, sigma, shape=(N,))
    f_samples = vmap(rastrigin)(samples)

    # get the f_sample mean and variance 
    mu_sample = np.mean(f_samples)
    var_sample = np.var(f_samples, ddof=1)

    # compute W(f(z))
    weighted_f = shapeFunct(f_samples, mu_sample, np.sqrt(var_sample))

    # Step 3: (*) perform an update of mu and sigma using MC-GD or EDA depending on the entropy H_t > tau 
    # -------------------------------------------------------------------------------------------------
    H = 0.5*np.linalg.slogdet(2*np.pi*np.e*sigma)[1]
    H_sign = np.linalg.slogdet(2*np.pi*np.e*sigma)[0] 
    #print('Entropy: {}'.format(H))

    update_type = 'MC-GD' # 1 - MC-GD update; 0 - EDA update

    if H < tau:

        # calculate a bound for alpha; based on the smallest eigenvalue of sigma 
        a_bound = 1.8*((np.min(np.linalg.eigvals(sigma)).real)**2 /(np.sum(weighted_f))) #1.5 was the multiplier of alpha
        # print('alpha: {}'.format(a_bound))
        mu_new, sigma_new = MCGDupdate(samples, weighted_f, mu, sigma, a_bound) 
        sigma_new = symmetrize(sigma_new)
        if is_posDef(sigma_new): 
            mu, sigma = mu_new, sigma_new
            # print('Step: {}'.format(k), '- update: MC-GD')
        else:
            mu_new, sigma_new = MCGDupdate(samples, weighted_f, mu, sigma, 0.0001*a_bound)
            sigma_new = symmetrize(sigma_new) 
            if is_posDef(sigma_new):
                mu, sigma = mu_new, sigma_new
                # print('Step: {}'.format(k), '- update: MC-GD (reduced alpha)')
            else: 
                backup['mu'] = mu
                backup['sigma'] = sigma
                backup['iter'] = k
                print('Terminate the run... because of negative definite sigma')
                break
    else:
        update_type = 'EDA'
        mu, sigma = EDAupdate(samples, weighted_f) 
        if not is_posDef(sigma): 
        #    print('Step: {}'.format(k), '- update: EDA')
            backup['mu'] = mu
            backup['sigma'] = sigma
            backup['iter'] = k
            print('Terminate the run... because of negative definite sigma')
            break
            #sigma = matrixFix(sigma)
            #print('Step: {}'.format(k), '- update: EDA (fixed sigma)')

    # Step 4: (*) store the results
    #         (*) stop if the threshold for convergence is met
    # --------------------------------------------------------
    detSigma = np.exp(np.linalg.slogdet(sigma)[1])
    if np.isnan(detSigma):
         backup['mu'] = mu
         backup['sigma'] = sigma
         backup['iter'] = k
         print('Terminate the run... because of nan')
         break
    
    if aux_min > np.min(f_samples):
         aux_min = min(aux_min, np.min(f_samples))
         aux_rnd = k

    clear_output(wait=True)
    print('Current progress: {:.2f} %'.format(np.round(k/budget *100, 2)), flush=True)
    print('Number of Restarts: {}'.format(numberOfRestarts), flush=True)
    print('Update type: {} -- '.format(update_type), 'Entropy: {}'.format(H), flush=True)
    print('Determinant: {} -- '.format(np.power(detSigma, 0.1)), 'min: {}'.format(np.min(f_samples)), flush=True)
    print('Running minima: {} -- '.format(aux_min), 'achieved in round... {}'.format(aux_rnd)  , flush=True) 
    time.sleep(0.3)

    if np.power(detSigma, 0.1) < threshold:
        key = random.PRNGKey(k)
        key, subkey = random.split(key)
        mu = hypersphere(key, 30)
        sigma = np.identity(10)

        aux_min = sys.maxsize*2 + 1
        aux_rnd = k

        numberOfRestarts += 1

    #print(np.power(detSigma, 0.1), f_samples[0], f_samples[1], np.min(f_samples), flush=True)
    results[k] = [np.power(detSigma, 0.1), H, f_samples[0], f_samples[1], np.min(f_samples)]
    if k%100 == 0:
        milestones['step'] = k 
        milestones['mu'] = mu
        milestones['sigma'] = sigma

Current progress: 99.98 %
Number of Restarts: 4
Update type: MC-GD --  Entropy: -13.124946594238281
Determinant: 0.004261123947799206 --  min: 1.2752151489257812
Running minima: 0.45755767822265625 --  achieved in round... 4973


In [13]:
iterations = []
min_f = []
iterations.append(0)
min_f.append(results[0][4])

for i in range(1, len(results)):
    iterations.append(i)
    if (np.sign(results[i-1][1]) < 0) and (np.sign(results[i][1]) > 0): 
        min_f.append(results[i][4])
    else: 
        min_f.append(min(results[i][4], min_f[i-1]))

In [14]:
# Plot the result 
# ---------------

fig = go.Figure()
fig.add_trace(go.Scatter(x=iterations, y=min_f,
                    mode='lines',
                    name='lines'))


# Edit the layout
fig.update_layout(title='EM hybrid optimization on the 10-dimensional Rastrigin function',
                   xaxis_title='Iterations',
                   yaxis_title='Minimum f(z) achieved')
fig.show()