In [1]:
import numpy as np
from numpy import vstack
from numpy import argmax
from sklearn.gaussian_process import GaussianProcessRegressor
from warnings import catch_warnings
from warnings import simplefilter
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from scipy.stats import norm

In [2]:
# objective function
def objective(x1, x2, x3, noise=0.1):
    noise = np.random.normal(loc=0, scale=noise)
    result = ((x1 * 10e+0) / (x2 * x3))
    if (x2==0 or x3==0):
      return 0 + noise
    return result + noise

In [3]:
# surrogate or approximation for the objective function
def surrogate(model, X):
    # catch any warning generated when making a prediction
    with catch_warnings():
        # ignore generated warnings
        simplefilter("ignore")
        return model.predict(X, return_std=True)

In [18]:
# probability of improvement acquisition function
def acquisition(X, Xsamples, model):
    # calculate the best surrogate score found so far
    yhat, _ = surrogate(model, X)
    best = max(yhat)
    # calculate mean and stdev via surrogate function
    mu, std = surrogate(model, Xsamples)
    # calculate the probability of improvement
    probs = norm.cdf((mu - best) / (std+1E-9))
    return probs

In [10]:
# optimize the acquisition function
def opt_acquisition(X, y, model):
    # random search, generate random samples
    Xsamples = np.zeros((100, 3))
    Xsamples[:, 0] = np.linspace(0,20, num=100 )
    Xsamples[:, 1] = np.linspace(0,20, num=100 )
    Xsamples[:, 2] = np.linspace(0,13, num=100 )
    # calculate the acquisition function for each sample
    scores = acquisition(X, Xsamples, model)
    # locate the index of the largest scores
    ix = argmax(scores)
    return Xsamples[ix, :]

In [24]:
# input data
X = np.array([[2, 5, 13], [20, 15, 6.5], [2, 5, 6.5], [2, 15, 13], [10, 5, 13], [10, 15, 6.5], [10, 5, 6.5], [10, 15, 13], [20, 5, 13], [2, 15, 6.5], [20, 5, 6.5], [20, 15, 13], [0, 0, 0]])
y = np.array([20.9, 27.4, 29.7, 34.2, 30.1, 29.6, 41.7, 21.8, 28.1, 42, 32.2, 53.8, 18.7]).reshape(-1, 1)
# y = np.array([0.11, 0.21, 0.57, 0.62, 0.58, 0.54, 0.85, 0.14, 0.41, 0.9, 0.71, 0.95, 0.1]).reshape(-1, 1)

In [25]:
# define the model
model = GaussianProcessRegressor()

# fit the model
model.fit(X, y)

In [26]:
# perform the optimization process
for i in range(100):
    # select the next point to sample
    x = opt_acquisition(X, y, model)
    # sample the point
    actual = objective(x[0], x[1], x[2])
    # summarize the finding
    est, _ = surrogate(model, [x])
    # print(est[0])
    print(' itr number = %.0f, x=%.3f, y=%.3f, z=%.3f, f()=%.10f, actual=%.10f, est[0] = %.10f' % (i, x[0], x[1], x[2], est, actual, est[0]))
    # add the data to the dataset
    X = vstack((X, x))
    y = vstack((y, [actual]))
    # update the model
    model.fit(X, y)

  result = ((x1 * 10e+0) / (x2 * x3))
  result = ((x1 * 10e+0) / (x2 * x3))
  result = ((x1 * 10e+0) / (x2 * x3))
  result = ((x1 * 10e+0) / (x2 * x3))
  result = ((x1 * 10e+0) / (x2 * x3))
  result = ((x1 * 10e+0) / (x2 * x3))
  result = ((x1 * 10e+0) / (x2 * x3))
  result = ((x1 * 10e+0) / (x2 * x3))
  result = ((x1 * 10e+0) / (x2 * x3))
  result = ((x1 * 10e+0) / (x2 * x3))
  result = ((x1 * 10e+0) / (x2 * x3))
  result = ((x1 * 10e+0) / (x2 * x3))
  result = ((x1 * 10e+0) / (x2 * x3))


 itr number = 0, x=0.000, y=0.000, z=0.000, f()=18.6999999981, actual=-0.0259742404, est[0] = 18.6999999981
 itr number = 1, x=0.000, y=0.000, z=0.000, f()=9.3370208740, actual=0.1797249013, est[0] = 9.3370208740
 itr number = 2, x=0.000, y=0.000, z=0.000, f()=6.2845916748, actual=-0.0211899823, est[0] = 6.2845916748
 itr number = 3, x=0.000, y=0.000, z=0.000, f()=4.7081298828, actual=-0.0867687513, est[0] = 4.7081298828
 itr number = 4, x=0.000, y=0.000, z=0.000, f()=3.7491378784, actual=0.1021079318, est[0] = 3.7491378784
 itr number = 5, x=0.000, y=0.000, z=0.000, f()=3.1413383484, actual=0.0954059373, est[0] = 3.1413383484
 itr number = 6, x=0.000, y=0.000, z=0.000, f()=2.7061538696, actual=-0.0156270674, est[0] = 2.7061538696
 itr number = 7, x=0.000, y=0.000, z=0.000, f()=2.3659744263, actual=-0.0035895496, est[0] = 2.3659744263
 itr number = 8, x=0.000, y=0.000, z=0.000, f()=2.1026687622, actual=-0.2135110246, est[0] = 2.1026687622
 itr number = 9, x=0.000, y=0.000, z=0.000, f()

  result = ((x1 * 10e+0) / (x2 * x3))
  result = ((x1 * 10e+0) / (x2 * x3))
  result = ((x1 * 10e+0) / (x2 * x3))
  result = ((x1 * 10e+0) / (x2 * x3))
  result = ((x1 * 10e+0) / (x2 * x3))
  result = ((x1 * 10e+0) / (x2 * x3))
  result = ((x1 * 10e+0) / (x2 * x3))
  result = ((x1 * 10e+0) / (x2 * x3))
  result = ((x1 * 10e+0) / (x2 * x3))
  result = ((x1 * 10e+0) / (x2 * x3))
  result = ((x1 * 10e+0) / (x2 * x3))
  result = ((x1 * 10e+0) / (x2 * x3))
  result = ((x1 * 10e+0) / (x2 * x3))
  result = ((x1 * 10e+0) / (x2 * x3))
  result = ((x1 * 10e+0) / (x2 * x3))
  result = ((x1 * 10e+0) / (x2 * x3))


 itr number = 13, x=0.000, y=0.000, z=0.000, f()=1.3353195190, actual=0.0621910068, est[0] = 1.3353195190
 itr number = 14, x=0.000, y=0.000, z=0.000, f()=1.2504653931, actual=-0.0729570082, est[0] = 1.2504653931
 itr number = 15, x=0.000, y=0.000, z=0.000, f()=1.1676864624, actual=0.2021099697, est[0] = 1.1676864624
 itr number = 16, x=0.000, y=0.000, z=0.000, f()=1.1109580994, actual=0.0339530033, est[0] = 1.1109580994
 itr number = 17, x=0.000, y=0.000, z=0.000, f()=1.0510444641, actual=-0.0421356828, est[0] = 1.0510444641
 itr number = 18, x=0.000, y=0.000, z=0.000, f()=0.9935379028, actual=-0.0766021775, est[0] = 0.9935379028
 itr number = 19, x=0.000, y=0.000, z=0.000, f()=0.9400939941, actual=0.0098159805, est[0] = 0.9400939941
 itr number = 20, x=0.000, y=0.000, z=0.000, f()=0.8957424164, actual=0.0476735383, est[0] = 0.8957424164
 itr number = 21, x=0.000, y=0.000, z=0.000, f()=0.8572072983, actual=-0.0493529036, est[0] = 0.8572072983
 itr number = 22, x=0.000, y=0.000, z=0.00

  result = ((x1 * 10e+0) / (x2 * x3))
  result = ((x1 * 10e+0) / (x2 * x3))
  result = ((x1 * 10e+0) / (x2 * x3))
  result = ((x1 * 10e+0) / (x2 * x3))
  result = ((x1 * 10e+0) / (x2 * x3))
  result = ((x1 * 10e+0) / (x2 * x3))
  result = ((x1 * 10e+0) / (x2 * x3))
  result = ((x1 * 10e+0) / (x2 * x3))
  result = ((x1 * 10e+0) / (x2 * x3))
  result = ((x1 * 10e+0) / (x2 * x3))
  result = ((x1 * 10e+0) / (x2 * x3))
  result = ((x1 * 10e+0) / (x2 * x3))
  result = ((x1 * 10e+0) / (x2 * x3))
  result = ((x1 * 10e+0) / (x2 * x3))
  result = ((x1 * 10e+0) / (x2 * x3))
  result = ((x1 * 10e+0) / (x2 * x3))
  result = ((x1 * 10e+0) / (x2 * x3))


 itr number = 45, x=0.000, y=0.000, z=0.000, f()=0.4240994453, actual=-0.1255850028, est[0] = 0.4240994453
 itr number = 46, x=0.000, y=0.000, z=0.000, f()=0.4124355316, actual=0.0591897484, est[0] = 0.4124355316
 itr number = 47, x=0.000, y=0.000, z=0.000, f()=0.4050464630, actual=0.0426960101, est[0] = 0.4050464630
 itr number = 48, x=0.000, y=0.000, z=0.000, f()=0.3976411819, actual=0.0733620416, est[0] = 0.3976411819
 itr number = 49, x=0.000, y=0.000, z=0.000, f()=0.3911786079, actual=0.0545681372, est[0] = 0.3911786079
 itr number = 50, x=0.000, y=0.000, z=0.000, f()=0.3845939636, actual=0.1000633442, est[0] = 0.3845939636
 itr number = 51, x=0.000, y=0.000, z=0.000, f()=0.3790588379, actual=-0.1412152825, est[0] = 0.3790588379
 itr number = 52, x=0.000, y=0.000, z=0.000, f()=0.3692340851, actual=0.0268304899, est[0] = 0.3692340851
 itr number = 53, x=0.000, y=0.000, z=0.000, f()=0.3628883362, actual=-0.1458266635, est[0] = 0.3628883362
 itr number = 54, x=0.000, y=0.000, z=0.000

  result = ((x1 * 10e+0) / (x2 * x3))
  result = ((x1 * 10e+0) / (x2 * x3))
  result = ((x1 * 10e+0) / (x2 * x3))
  result = ((x1 * 10e+0) / (x2 * x3))
  result = ((x1 * 10e+0) / (x2 * x3))
  result = ((x1 * 10e+0) / (x2 * x3))
  result = ((x1 * 10e+0) / (x2 * x3))
  result = ((x1 * 10e+0) / (x2 * x3))
  result = ((x1 * 10e+0) / (x2 * x3))
  result = ((x1 * 10e+0) / (x2 * x3))
  result = ((x1 * 10e+0) / (x2 * x3))
  result = ((x1 * 10e+0) / (x2 * x3))


 itr number = 57, x=0.000, y=0.000, z=0.000, f()=0.3343315125, actual=0.0447606742, est[0] = 0.3343315125
 itr number = 58, x=0.000, y=0.000, z=0.000, f()=0.3294625282, actual=0.2069675779, est[0] = 0.3294625282
 itr number = 59, x=0.000, y=0.000, z=0.000, f()=0.3273050785, actual=-0.1282411785, est[0] = 0.3273050785
 itr number = 60, x=0.000, y=0.000, z=0.000, f()=0.3198442459, actual=0.0051834230, est[0] = 0.3198442459
 itr number = 61, x=0.000, y=0.000, z=0.000, f()=0.3148274422, actual=-0.1223043391, est[0] = 0.3148274422
 itr number = 62, x=0.000, y=0.000, z=0.000, f()=0.3078727722, actual=-0.0367881058, est[0] = 0.3078727722
 itr number = 63, x=0.000, y=0.000, z=0.000, f()=0.3024606705, actual=-0.0602855748, est[0] = 0.3024606705
 itr number = 64, x=0.000, y=0.000, z=0.000, f()=0.2968764305, actual=-0.0302644928, est[0] = 0.2968764305
 itr number = 65, x=0.000, y=0.000, z=0.000, f()=0.2919287682, actual=-0.0386627640, est[0] = 0.2919287682
 itr number = 66, x=0.000, y=0.000, z=0.

  result = ((x1 * 10e+0) / (x2 * x3))
  result = ((x1 * 10e+0) / (x2 * x3))
  result = ((x1 * 10e+0) / (x2 * x3))
  result = ((x1 * 10e+0) / (x2 * x3))
  result = ((x1 * 10e+0) / (x2 * x3))
  result = ((x1 * 10e+0) / (x2 * x3))
  result = ((x1 * 10e+0) / (x2 * x3))
  result = ((x1 * 10e+0) / (x2 * x3))
  result = ((x1 * 10e+0) / (x2 * x3))
  result = ((x1 * 10e+0) / (x2 * x3))
  result = ((x1 * 10e+0) / (x2 * x3))
  result = ((x1 * 10e+0) / (x2 * x3))
  result = ((x1 * 10e+0) / (x2 * x3))
  result = ((x1 * 10e+0) / (x2 * x3))
  result = ((x1 * 10e+0) / (x2 * x3))
  result = ((x1 * 10e+0) / (x2 * x3))
  result = ((x1 * 10e+0) / (x2 * x3))
  result = ((x1 * 10e+0) / (x2 * x3))
  result = ((x1 * 10e+0) / (x2 * x3))
  result = ((x1 * 10e+0) / (x2 * x3))
  result = ((x1 * 10e+0) / (x2 * x3))


 itr number = 67, x=0.000, y=0.000, z=0.000, f()=0.2831420898, actual=-0.0623403169, est[0] = 0.2831420898
 itr number = 68, x=0.000, y=0.000, z=0.000, f()=0.2781271935, actual=0.0200706209, est[0] = 0.2781271935
 itr number = 69, x=0.000, y=0.000, z=0.000, f()=0.2745223045, actual=-0.0326410533, est[0] = 0.2745223045
 itr number = 70, x=0.000, y=0.000, z=0.000, f()=0.2701811790, actual=0.0528029090, est[0] = 0.2701811790
 itr number = 71, x=0.000, y=0.000, z=0.000, f()=0.2670528889, actual=-0.0344147654, est[0] = 0.2670528889
 itr number = 72, x=0.000, y=0.000, z=0.000, f()=0.2629890442, actual=-0.1639553171, est[0] = 0.2629890442
 itr number = 73, x=0.000, y=0.000, z=0.000, f()=0.2572088242, actual=0.0878229137, est[0] = 0.2572088242
 itr number = 74, x=0.000, y=0.000, z=0.000, f()=0.2549798489, actual=0.0409241278, est[0] = 0.2549798489
 itr number = 75, x=0.000, y=0.000, z=0.000, f()=0.2520608902, actual=-0.0501650968, est[0] = 0.2520608902
 itr number = 76, x=0.000, y=0.000, z=0.0

  result = ((x1 * 10e+0) / (x2 * x3))
  result = ((x1 * 10e+0) / (x2 * x3))
  result = ((x1 * 10e+0) / (x2 * x3))
  result = ((x1 * 10e+0) / (x2 * x3))
  result = ((x1 * 10e+0) / (x2 * x3))
  result = ((x1 * 10e+0) / (x2 * x3))
  result = ((x1 * 10e+0) / (x2 * x3))
  result = ((x1 * 10e+0) / (x2 * x3))
  result = ((x1 * 10e+0) / (x2 * x3))
  result = ((x1 * 10e+0) / (x2 * x3))
  result = ((x1 * 10e+0) / (x2 * x3))


 itr number = 79, x=0.000, y=0.000, z=0.000, f()=0.2386822701, actual=0.0954670098, est[0] = 0.2386822701
 itr number = 80, x=0.000, y=0.000, z=0.000, f()=0.2369010448, actual=0.0367176567, est[0] = 0.2369010448
 itr number = 81, x=0.000, y=0.000, z=0.000, f()=0.2344644070, actual=-0.0451650267, est[0] = 0.2344644070
 itr number = 82, x=0.000, y=0.000, z=0.000, f()=0.2311077118, actual=-0.0584117150, est[0] = 0.2311077118
 itr number = 83, x=0.000, y=0.000, z=0.000, f()=0.2277069092, actual=0.0330014618, est[0] = 0.2277069092
 itr number = 84, x=0.000, y=0.000, z=0.000, f()=0.2253513336, actual=-0.1752646177, est[0] = 0.2253513336
 itr number = 85, x=0.000, y=0.000, z=0.000, f()=0.2207431793, actual=-0.0125344110, est[0] = 0.2207431793
 itr number = 86, x=0.000, y=0.000, z=0.000, f()=0.2180395126, actual=0.0726863108, est[0] = 0.2180395126
 itr number = 87, x=0.000, y=0.000, z=0.000, f()=0.2163944244, actual=0.0051248732, est[0] = 0.2163944244
 itr number = 88, x=0.000, y=0.000, z=0.00

  result = ((x1 * 10e+0) / (x2 * x3))
  result = ((x1 * 10e+0) / (x2 * x3))
  result = ((x1 * 10e+0) / (x2 * x3))
  result = ((x1 * 10e+0) / (x2 * x3))
  result = ((x1 * 10e+0) / (x2 * x3))
  result = ((x1 * 10e+0) / (x2 * x3))
  result = ((x1 * 10e+0) / (x2 * x3))
  result = ((x1 * 10e+0) / (x2 * x3))
  result = ((x1 * 10e+0) / (x2 * x3))


 itr number = 90, x=0.000, y=0.000, z=0.000, f()=0.2057771683, actual=-0.0602799470, est[0] = 0.2057771683
 itr number = 91, x=0.000, y=0.000, z=0.000, f()=0.2028985023, actual=-0.1051639543, est[0] = 0.2028985023
 itr number = 92, x=0.000, y=0.000, z=0.000, f()=0.1995873451, actual=0.0494848541, est[0] = 0.1995873451
 itr number = 93, x=0.000, y=0.000, z=0.000, f()=0.1979978085, actual=0.0425641263, est[0] = 0.1979978085
 itr number = 94, x=0.000, y=0.000, z=0.000, f()=0.1962997913, actual=0.0411710290, est[0] = 0.1962997913
 itr number = 95, x=0.000, y=0.000, z=0.000, f()=0.1947813034, actual=-0.0057735888, est[0] = 0.1947813034
 itr number = 96, x=0.000, y=0.000, z=0.000, f()=0.1927211285, actual=0.0120216444, est[0] = 0.1927211285
 itr number = 97, x=0.000, y=0.000, z=0.000, f()=0.1908268929, actual=-0.0321286024, est[0] = 0.1908268929
 itr number = 98, x=0.000, y=0.000, z=0.000, f()=0.1885962486, actual=-0.0432362487, est[0] = 0.1885962486
 itr number = 99, x=0.000, y=0.000, z=0.0

  result = ((x1 * 10e+0) / (x2 * x3))
