In [1]:
import copy
from argparse import Namespace
import numpy as np
import matplotlib.pyplot as plt

from bax.models.simple_gp import SimpleGp
from bax.alg.algorithms import Dijkstras
from bax.acq.acqoptimize import AcqOptimizer

from bax.util.graph import Vertex, make_vertices, make_edges, farthest_pair

import neatplot
neatplot.set_style('fonts')

In [2]:
# Seed NumPy's legacy global RNG once, up front, so NumPy-based randomness in
# later cells is reproducible on Restart & Run All (assumes bax draws from the
# global np.random state — TODO confirm).
seed = 0
np.random.seed(seed)

In [3]:
# Build a g x g lattice on [-1, 1]^2. The meshgrid is flattened row-major, so
# node k sits in row k // g and column k % g of the lattice.
g = 10
x1, x2 = np.meshgrid(np.linspace(-1, 1, g), np.linspace(-1, 1, g))
positions = np.stack([x1.flatten(), x2.flatten()], axis=-1)
n = len(positions)

# Upper-triangular adjacency: has_edge[k, m] is True only for k < m, linking
# each node to its right-hand neighbour and to the node directly below it.
has_edge = np.zeros((n, n), dtype=bool)
nodes = np.arange(n)
# Horizontal links: nodes outside the last column connect to k + 1
# (this excludes wrap-around into the next row).
left = nodes[nodes % g != g - 1]
has_edge[left, left + 1] = True
# Vertical links: nodes outside the last row connect to k + g.
upper = nodes[nodes + g < n]
has_edge[upper, upper + g] = True

In [4]:
# Sanity-check the lattice adjacency before using it. Edges are stored only
# for k < m, so nothing may sit on or below the diagonal. A g x g grid has
# g*(g-1) horizontal and g*(g-1) vertical edges.
assert not np.tril(has_edge).any(), "has_edge should be strictly upper-triangular"
assert int(has_edge.sum()) == 2 * g * (g - 1), "unexpected number of grid edges"
has_edge

array([[False,  True, False, ..., False, False, False],
       [False, False,  True, ..., False, False, False],
       [False, False, False, ..., False, False, False],
       ...,
       [False, False, False, ..., False,  True, False],
       [False, False, False, ..., False, False,  True],
       [False, False, False, ..., False, False, False]])

In [5]:
# Turn the grid points plus adjacency matrix into Vertex objects.
# The Dijkstras log further down shows symmetric neighbour lists (vertex 1
# lists 0) even though has_edge only marks k < m. Presumably make_vertices
# mirrors each edge — TODO confirm in bax.util.graph.
vertices = make_vertices(positions, has_edge)

In [6]:
# Peek at one vertex. Its repr appears to be (index, neighbour indices),
# e.g. (0, [1, 10]) for the bottom-left corner.
vertices[0]

(0, [1, 10])

In [7]:
# Presumably derives an edge list from the vertices' neighbour lists.
# NOTE(review): `edges` is not referenced by any later cell in this notebook —
# confirm it is needed or drop the cell.
edges = make_edges(vertices)

In [8]:
# Pick the start/goal pair for the shortest-path problem via farthest_pair
# under Euclidean distance (the Dijkstras log shows vertices 0 and 99).
def l2_dist(u: Vertex, v: Vertex):
    """Return the Euclidean (L2) distance between ``u.position`` and ``v.position``."""
    squared_gap = (u.position - v.position) ** 2
    # Keep the sqrt(sum of squares) form: tie-breaking between equidistant
    # corner pairs depends on exact float equality.
    return np.sqrt(squared_gap.sum())


start, goal = farthest_pair(vertices, distance_func=l2_dist)

In [9]:
# Black-box objective queried by BAX: squared Euclidean norm of a 2-D point.
def f(x):
    """Return ``x[0]**2 + x[1]**2`` for an indexable point ``x`` (list, tuple or array)."""
    return x[0]**2 + x[1]**2

# Start with an empty dataset; the BAX loop appends (x, f(x)) observations.
data = Namespace(x=[])
data.y = list(map(f, data.x))

# GP hyperparameters as consumed by SimpleGp; fit the prior model on `data`.
gp_params = dict(ls=1.0, alpha=1.0, sigma=1e-2)
model = SimpleGp(gp_params)
model.set_data(data)

# The algorithm whose output BAX reasons about: Dijkstra's search from
# `start` to `goal` over the grid vertices.
algo = Dijkstras(dict(start=start, goal=goal, vertices=vertices))

# Candidate query points for the acquisition optimizer: one entry per grid
# point, each a length-2 array. (The previous `[positions]` was a
# one-element list holding the whole (n, 2) grid, i.e. a single "candidate".)
x_test = list(positions)

*[INFO] SimpleGp with params=Namespace(alpha=1.0, kernel=<function kern_exp_quad at 0x7f7a89e9d430>, ls=1.0, name='SimpleGp', sigma=0.01)
*[INFO] Dijkstras with params=Namespace(goal=(99, [89, 98]), name='Dijkstras', start=(0, [1, 10]), vertices=[(0, [1, 10]), (1, [0, 2, 11]), (2, [1, 3, 12]), (3, [2, 4, 13]), (4, [3, 5, 14]), (5, [4, 6, 15]), (6, [5, 7, 16]), (7, [6, 8, 17]), (8, [7, 9, 18]), (9, [8, 19]), (10, [0, 11, 20]), (11, [1, 10, 12, 21]), (12, [2, 11, 13, 22]), (13, [3, 12, 14, 23]), (14, [4, 13, 15, 24]), (15, [5, 14, 16, 25]), (16, [6, 15, 17, 26]), (17, [7, 16, 18, 27]), (18, [8, 17, 19, 28]), (19, [9, 18, 29]), (20, [10, 21, 30]), (21, [11, 20, 22, 31]), (22, [12, 21, 23, 32]), (23, [13, 22, 24, 33]), (24, [14, 23, 25, 34]), (25, [15, 24, 26, 35]), (26, [16, 25, 27, 36]), (27, [17, 26, 28, 37]), (28, [18, 27, 29, 38]), (29, [19, 28, 39]), (30, [20, 31, 40]), (31, [21, 30, 32, 41]), (32, [22, 31, 33, 42]), (33, [23, 32, 34, 43]), (34, [24, 33, 35, 44]), (35, [25, 34, 36, 4

In [10]:
# BAX iterations: choose the next query by maximizing the acquisition function
# over x_test (acq_str='exe' per the AcqOptimizer log), evaluate f there, and
# refit the GP on the enlarged dataset.
n_iter = 40

for i in range(n_iter):
    # Optimize acquisition function (a fresh optimizer each iteration)
    acqopt = AcqOptimizer({'x_test': x_test,
                           'n_path': 1})
    x_next = acqopt.optimize(model, algo)
    print(f'Acq optimizer x_next = {x_next}')

    # Query function, update data
    y_next = f(x_next)
    data.x.append(x_next)
    data.y.append(y_next)

    # Update model: rebuild the GP from scratch on the grown dataset
    model = SimpleGp(gp_params)
    model.set_data(data)

    # Report completion only once the query and model refit have happened
    print(f'Finished iter i = {i}')

*[INFO] AcqOptimizer with params=Namespace(acq_str='exe', n_path=1, name='AcqOptimizer', opt_str='rs', viz_acq=True)
best_cost 0
best_cost 0.04472832271810501
best_cost 0.09574603773266976
best_cost 0.10945535304845877
best_cost 0.14771738652166277
best_cost 0.1542679906442097
best_cost 0.18728640306672473
best_cost 0.19706692142641002
best_cost 0.19744396924736574
best_cost 0.2058545504805307
best_cost 0.23290835279650013
best_cost 0.23654958589248198
best_cost 0.23920854994319418
best_cost 0.2516493114296898
best_cost 0.2627425523654201
best_cost 0.2804927072607508
best_cost 0.28177298140029605
best_cost 0.2892865998970269
best_cost 0.29354708188165013
best_cost 0.302034499030132
best_cost 0.30548328535926994
best_cost 0.3076614029161422
best_cost 0.31543874598752275
best_cost 0.32237466659647795
best_cost 0.33076346098016307
best_cost 0.33336917065003124
best_cost 0.3342650516760024
best_cost 0.3422328855469199
best_cost 0.3464792139510475
best_cost 0.3510378637510756
best_cost 0.35

Process ForkPoolWorker-1:
Traceback (most recent call last):
  File "/home/alex/miniconda3/envs/dev/lib/python3.8/multiprocessing/process.py", line 315, in _bootstrap
    self.run()
  File "/home/alex/miniconda3/envs/dev/lib/python3.8/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/home/alex/miniconda3/envs/dev/lib/python3.8/multiprocessing/pool.py", line 114, in worker
    task = get()
  File "/home/alex/miniconda3/envs/dev/lib/python3.8/multiprocessing/queues.py", line 356, in get
    res = self._reader.recv_bytes()
  File "/home/alex/miniconda3/envs/dev/lib/python3.8/multiprocessing/connection.py", line 216, in recv_bytes
    buf = self._recv_bytes(maxlength)
  File "/home/alex/miniconda3/envs/dev/lib/python3.8/multiprocessing/connection.py", line 414, in _recv_bytes
    buf = self._recv(4)
  File "/home/alex/miniconda3/envs/dev/lib/python3.8/multiprocessing/connection.py", line 379, in _recv
    chunk = read(handle, remaining)
Ke

*[TIME] [Sample 1 execution paths] Elapsed: 228.39 seconds


KeyboardInterrupt: 