In [1]:
from __future__ import absolute_import, division
from __future__ import print_function, unicode_literals
import pints.toy as toy
import pints
import numpy as np
import logging
import math
import matplotlib.ticker as mticker
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import sys
from numpy import inf
import copy 
import pickle
import time
import CMA as CMA

In [2]:
# Objective names accepted by ``CMA_on_Slice``'s ``eval_fun`` argument, in the
# order they are checked (the first one present in ``eval_fun`` is used).
_SUPPORTED_EVAL_FUNS = ('KL', 'KL-ITER', 'ESS')


def _run_chains(log_pdf, sampler, sampler_x0, hyper_parameters, n_chains,
                chain_size, need_sensitivities):
    """Run ``n_chains`` chains of ``sampler`` with fixed hyper-parameters.

    Returns ``(chains, function_evaluations)``: ``chains`` is a numpy array
    built from one list of ``chain_size`` samples per chain, and
    ``function_evaluations`` is the total number of ask/tell cycles (i.e.
    log-pdf evaluations) over all chains.
    """
    function_evaluations = 0
    chains = []
    for i in range(n_chains):
        mcmc = sampler(sampler_x0[i])
        mcmc.set_hyper_parameters(hyper_parameters)
        samples = []
        # Step until ``chain_size`` samples are collected; tell() returns None
        # on cycles that do not yield a new sample.
        while len(samples) < chain_size:
            point = mcmc.ask()
            if need_sensitivities:
                fx, grad = log_pdf.evaluateS1(point)
                sample = mcmc.tell((fx, grad))
            else:
                # Only the log-density is needed; don't compute gradients.
                sample = mcmc.tell(log_pdf(point))
            function_evaluations += 1
            if sample is not None:
                samples.append(sample)
        chains.append(samples)
    return np.array(chains), function_evaluations


def _score_chains(log_pdf, chains, function_evaluations, eval_fun):
    """Score one hyper-parameter configuration from its chains (lower is better)."""
    n = len(chains)
    avg_iteration_count = function_evaluations / n
    if 'KL' in eval_fun:
        # Average KL divergence of the chains from the target.
        return sum(log_pdf.kl_divergence(chain) for chain in chains) / n
    if 'KL-ITER' in eval_fun:
        # Average KL weighted by the average evaluation cost per chain.
        avg_kl = sum(log_pdf.kl_divergence(chain) for chain in chains) / n
        return avg_kl * avg_iteration_count
    if 'ESS' in eval_fun:
        # Evaluations per effective sample, using the worst parameter's
        # chain-averaged ESS.
        ess = np.zeros(chains[0].shape[1])
        for chain in chains:
            ess += np.array(pints._diagnostics.effective_sample_size(chain))
        ess /= n
        return avg_iteration_count / np.min(ess)
    raise ValueError(
        'eval_fun must contain one of {}; got {!r}'.format(
            _SUPPORTED_EVAL_FUNS, eval_fun))


def CMA_on_Slice(opt,
                 log_pdf,
                 sampler,
                 sampler_x0,
                 n_chains=1,
                 chain_size=100,
                 need_sensitivities=False,
                 eval_fun=('ESS',),
                 n_iterations=100):
    """Tune slice-sampler hyper-parameters with an ask/tell optimiser.

    For each of ``n_iterations`` optimiser iterations, every candidate
    hyper-parameter vector proposed by ``opt.ask()`` is scored by running
    ``n_chains`` chains of ``chain_size`` samples. The scores are passed back
    with ``opt.tell()``; lower scores are treated as better.

    Parameters
    ----------
    opt
        Optimiser exposing ``ask()``, ``tell(fxs)``, ``fbest()`` and
        ``xbest()`` (e.g. ``CMA.CMAES``). Each candidate it proposes is laid
        out as ``[width_0, ..., width_{d-1}, expansion_steps,
        prob_overrelaxed, bisection_steps]``.
    log_pdf
        Target log-density. It is called as ``log_pdf(x)``, or via
        ``evaluateS1(x)`` when ``need_sensitivities`` is true. The ``'KL'``
        and ``'KL-ITER'`` objectives also need ``kl_divergence(chain)``.
    sampler
        Sampler class. It is constructed as ``sampler(x0)`` and driven via
        ``set_hyper_parameters``, ``ask`` and ``tell``.
    sampler_x0
        Starting points, one per chain. It must have at least ``n_chains``
        entries.
    n_chains, chain_size
        Number of chains per candidate and samples per chain.
    need_sensitivities
        If true, pass ``(fx, grad)`` to ``tell``; otherwise pass ``fx``.
    eval_fun
        Objective selector (a collection of names, or a single name):
        ``'KL'`` is the mean KL divergence; ``'KL-ITER'`` is the mean KL times
        the mean evaluations per chain; ``'ESS'`` is the mean evaluations per
        chain divided by the minimum mean ESS.
    n_iterations
        Number of optimiser ask/tell iterations (default 100).

    Returns
    -------
    (list, list)
        ``opt.fbest()`` and ``opt.xbest()`` recorded after each iteration.
        The overall best score and hyper-parameters are also printed.

    Raises
    ------
    ValueError
        If ``eval_fun`` names none of ``'KL'``, ``'KL-ITER'`` or ``'ESS'``.
    """
    # A bare string would be matched by substring ('KL' in 'KL-ITER').
    if isinstance(eval_fun, str):
        eval_fun = [eval_fun]
    # Fail fast instead of silently scoring every candidate as 0.
    if not any(name in eval_fun for name in _SUPPORTED_EVAL_FUNS):
        raise ValueError(
            'eval_fun must contain one of {}; got {!r}'.format(
                _SUPPORTED_EVAL_FUNS, eval_fun))

    optimizer_best_fxs = []
    optimizer_best_xs = []
    optimizer_best_fx = np.inf
    optimizer_best_x = None

    for _ in range(n_iterations):
        optimizer_xs = opt.ask()
        optimizer_fxs = []
        for candidate in optimizer_xs:
            # The sampler expects the per-dimension widths grouped in a list,
            # followed by the three scalar hyper-parameters.
            hyper_parameters = [list(candidate[:-3]),
                                candidate[-3], candidate[-2], candidate[-1]]
            chains, function_evaluations = _run_chains(
                log_pdf, sampler, sampler_x0, hyper_parameters,
                n_chains, chain_size, need_sensitivities)
            optimizer_fxs.append(
                _score_chains(log_pdf, chains, function_evaluations, eval_fun))
        opt.tell(optimizer_fxs)

        optimizer_best_fxs.append(opt.fbest())
        optimizer_best_xs.append(opt.xbest())

        if opt.fbest() < optimizer_best_fx:
            optimizer_best_fx = opt.fbest()
            optimizer_best_x = opt.xbest()

    print(optimizer_best_fx)
    print(optimizer_best_x)

    return optimizer_best_fxs, optimizer_best_xs

In [3]:
# Seed numpy's global RNG so that Restart & Run All reproduces the runs.
# The pints samplers draw from it; NOTE(review): confirm CMA.CMAES does too.
SEED = 1234
np.random.seed(SEED)

# Target: 2-D Gaussian with mean [2, 4] and covariance [[1, 0], [0, 3]].
log_pdf = pints.toy.GaussianLogPDF([2, 4], [[1, 0], [0, 3]])

# One starting point per chain (n_chains=3 below).
sampler_x0 = [
    [2, 4],
    [3, 3],
    [5, 4],
]

# Start the optimiser from the slice sampler's default hyper-parameters,
# ordered [width_0, width_1, expansion_steps, prob_overrelaxed, bisection_steps].
dummy = pints.SliceStepoutMCMC([2, 4])
optimizer_x0 = [
    dummy.width()[0],
    dummy.width()[1],
    dummy.expansion_steps(),
    dummy.prob_overrelaxed(),
    dummy.bisection_steps(),
]

# Lower / upper bounds per hyper-parameter, same order as optimizer_x0.
boundary = pints.RectangularBoundaries([0, 0, 1, 0, 0], [20, 20, 100, 1, 20])

sampler = pints.SliceStepoutMCMC

# Repeat the whole optimisation to gauge run-to-run variability; keep every
# run's per-iteration history instead of only the last one.
N_RUNS = 5
run_fxs = []
run_xs = []
for run in range(N_RUNS):
    cma = CMA.CMAES(optimizer_x0, boundaries=boundary)
    fxs, xs = CMA_on_Slice(cma, log_pdf, sampler, sampler_x0,
                           n_chains=3, chain_size=200, eval_fun=['KL-ITER'])
    run_fxs.append(fxs)
    run_xs.append(xs)

10.160995525005125
[10.05568633 11.8977937  46.7997411   0.14908052  3.12818519]
8.25374592573952
[1.12560582e+01 1.19847092e+01 4.54795842e+01 2.19475684e-02
 2.84147789e+00]
12.437020021564996
[ 2.30880826  7.43313196 22.78730271  0.05498331 12.53900359]
7.056619206855843
[ 8.30730142  9.56967433 65.84522946  0.07669645  3.5547955 ]
6.689095794958215
[ 0.98937105  7.29121975 16.57297685  0.0181846   9.00185092]


In [4]:
# ESS results: chain_size=200, 3 chains, 100 CMA iterations, one entry per
# independent CMA-ES run. These are output pasted from an earlier execution
# with the ESS objective; they are stored as data (rather than bare numbers,
# which are a SyntaxError) so the cell runs.
# Each entry is (best score, best hyper-parameters), with hyper-parameters
# ordered [width_0, width_1, expansion_steps, prob_overrelaxed, bisection_steps].
ESS_RESULTS = [
    (29.575758203500015,
     [5.93891036e+00, 5.78794726e+00, 5.60323480e+01, 1.98715464e-03, 1.32408025e+01]),
    (26.091875651686852,
     [3.31232037, 5.52222501, 63.73590805, 0.9745463, 0.28910198]),
    (31.825255138874912,
     [5.40921947e+00, 2.37674752e+00, 4.73645141e+01, 5.38065397e-03, 7.08325469e+00]),
    (29.453012035357805,
     [2.92869195e+00, 7.13914181e+00, 5.07695965e+01, 3.46576075e-03, 7.44098164e+00]),
    (32.114467539845265,
     [3.96738955e+00, 2.00726807e+00, 5.15796589e+01, 3.20265408e-03, 9.93199278e+00]),
]

# Show the run with the lowest (best) score.
min(ESS_RESULTS, key=lambda run: run[0])

In [None]:
# KL results: chain_size=200, 3 chains, 100 CMA iterations, one entry per
# independent CMA-ES run. These are output pasted from an earlier execution
# with the KL objective; they are stored as data (rather than bare numbers,
# which are a SyntaxError) so the cell runs.
# Each entry is (best score, best hyper-parameters), with hyper-parameters
# ordered [width_0, width_1, expansion_steps, prob_overrelaxed, bisection_steps].
KL_RESULTS = [
    (0.0038305189772738224,
     [1.06433389, 0.41413352, 49.58176233, 0.11197374, 9.78366365]),
    (0.005828319499047978,
     [0.98104766, 0.92576308, 50.27667629, 0.34128915, 9.98614531]),
    (0.0038313192662691917,
     [8.52127491, 0.38114917, 49.98250439, 0.13758487, 9.99825935]),
    (0.004461861632446921,
     [1.13585322, 2.81089916, 49.9038746, 0.25019024, 10.37224217]),
    (0.004810017042027992,
     [0.32881397, 0.32462586, 50.07846645, 0.39483721, 10.37920012]),
]

# Show the run with the lowest (best) score.
min(KL_RESULTS, key=lambda run: run[0])
