# Bayesian optimal experimental design
This notebook implements Bayesian optimal experimental design (BOED) with Xopt's `BOEDGenerator`, whose model priors are specified with the probabilistic programming package [Pyro](https://pyro.ai). It is based on the Pyro example at https://pyro.ai/examples/working_memory.html.

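In this case we aim to choose measurement points $x$ that provide the most information about the model parameters $x_0$, $d$, $b$ of the function

$$f(x) = -\frac{d}{2}\left[1 + \tanh\left(\frac{x - x_0}{b}\right)\right],$$

a smooth downward step of depth $d$ centred at $x_0$ whose edge width is set by $b$.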

In [None]:
import matplotlib.pyplot as plt
import torch

# Model function and a noise-corrupted "measurement" of it.
def f(x, x0, d, b):
    """Smooth downward step used as the BOED model function.

    Computes ``-d / 2 * (1 + tanh((x - x0) / b))``. For ``b > 0`` the curve
    tends to 0 for ``x << x0`` and to ``-d`` for ``x >> x0``, and passes
    through ``-d / 2`` at ``x = x0``.

    Args:
        x: Tensor of input locations.
        x0: Midpoint of the edge.
        d: Depth of the step (the function drops from 0 to ``-d``).
        b: Edge width; smaller ``b`` gives a sharper transition.

    Returns:
        Tensor broadcast from ``x`` and the parameters.
    """
    return -d * 0.5 * (1 + torch.tanh((x - x0) / b))


def f_noisy(x, x0, d, b, noise_std=0.01):
    """Evaluate ``f`` and add i.i.d. Gaussian noise with std ``noise_std``.

    ``x`` may be a Python number, a sequence or a tensor. Non-floating inputs
    are cast to torch's default float dtype because ``torch.randn_like``
    cannot sample for integer tensors.

    Returns:
        ``f(x, x0, d, b)`` plus independent noise for every output element.
    """
    x = torch.as_tensor(x)
    if not x.is_floating_point():
        x = x.to(torch.get_default_dtype())
    clean = f(x, x0, d, b)
    # Draw noise with the *output* shape so broadcast (batched) parameters
    # get independent noise per element rather than one shared draw.
    return clean + torch.randn_like(clean) * noise_std


In [None]:
# Ground-truth parameters of the "unknown" function the experiment probes.
ground_truth_x0 = 2.0  # midpoint of the edge
ground_truth_d = 1.0  # step depth: f drops from 0 to -d
ground_truth_b = 0.1  # edge width: smaller b -> sharper edge

# Seed torch's RNG so the simulated noisy observations (and torch-based
# sampling later in the notebook) are reproducible under Restart & Run All.
torch.manual_seed(0)

test_x = torch.linspace(0, 3, 100)

fig, ax = plt.subplots()
ax.plot(
    test_x,
    f_noisy(test_x, x0=ground_truth_x0, d=ground_truth_d, b=ground_truth_b),
    label="noisy measurement",
)
ax.plot(
    test_x,
    f(test_x, ground_truth_x0, ground_truth_d, ground_truth_b),
    "k--",
    label="noise-free f",
)
ax.set(xlabel="x", ylabel="y", title="Ground-truth function")
ax.legend();

In [None]:
from xopt import Xopt, VOCS, Evaluator
from xopt.generators.bayesian.boed import BOEDGenerator
from xopt.numerical_optimizer import GridOptimizer
import pyro.distributions as dist

# One design variable `x` on [0, 3]; the measured quantity is `y`.
vocs = VOCS(
    variables={"x": [0.0, 3.0]},
    observables=["y"],
)

# Prior beliefs over the model parameters; the keys mirror the parameter
# names of `f` (presumably how the generator maps samples onto `f`).
# NOTE(review): these priors are centred on the ground-truth values above,
# which makes the demo easy; a real experiment would not know them.
priors = dict(
    x0=dist.Normal(2.0, 1.0),
    d=dist.Normal(1.0, 0.5),
    b=dist.Normal(0.1, 0.01),
)

generator = BOEDGenerator(
    vocs=vocs,
    model_priors=priors,
    model_function=f,
    measurement_noise=0.01,  # same value as f_noisy's default noise_std
    numerical_optimizer=GridOptimizer(n_grid_points=1000),  # grid search over x
    train_steps=3000,
    train_lr=0.01,
)


def _observe_ground_truth(inputs):
    """Simulate one noisy measurement of the ground-truth function at inputs["x"]."""
    x_value = torch.tensor(inputs["x"])
    measured = f_noisy(x_value, ground_truth_x0, ground_truth_d, ground_truth_b)
    return {"y": float(measured)}


evaluator = Evaluator(function=_observe_ground_truth)

X = Xopt(vocs=vocs, generator=generator, evaluator=evaluator)

# Seed the data set with a coarse 2-point grid over x before any BOED steps.
X.grid_evaluate(2)

In [None]:
# Fit the model to the initial observations and draw posterior-predictive samples.
X.generator.train_model(X.data)
predictive = X.generator.get_predictive()

test_x = torch.linspace(0, 3, 100)
pred = predictive(test_x)
# Samples are stacked along dim 0 (the quantiles below reduce over it).
# `.detach()` is defensive in case the samples carry autograd history.
y = pred["y"].detach()

# Posterior predictive: 5-95 % band, median, a few individual draws, and the data.
fig, ax = plt.subplots()
x_np = test_x.numpy()
lower, median, upper = (y.quantile(q, dim=0).numpy() for q in (0.05, 0.5, 0.95))
ax.fill_between(x_np, lower, upper, alpha=0.3, label="5-95% interval")
ax.plot(x_np, median, label="median")
for sample in y[:10]:  # slicing tolerates fewer than 10 samples
    ax.plot(x_np, sample.numpy(), color="gray", alpha=0.15)
ax.scatter(X.data["x"], X.data["y"], color="red", label="observations")
ax.set(xlabel="x", ylabel="y", title="Posterior predictive (initial data)")
ax.legend()

# Expected information gain (EIG) over candidate designs. `_get_acquisition`
# is a private generator method, so this relies on Xopt internals.
current_model = X.generator.model
acqf = X.generator._get_acquisition(current_model)
candidate_designs = torch.linspace(0, 3, 1000).unsqueeze(-1)
eig = acqf(candidate_designs).detach()

fig, ax = plt.subplots()
ax.plot(candidate_designs.squeeze(-1).numpy(), eig.numpy())
ax.set(xlabel="x", ylabel="EIG", title="Expected information gain")

# Marginal posterior of each model parameter compared with its prior.
# squeeze=False keeps a 2-D Axes array even when there is a single prior.
fig, axes = plt.subplots(1, len(priors), figsize=(15, 5), squeeze=False)
for ax, (name, prior) in zip(axes[0], priors.items()):
    samples = pred[name].detach().flatten()
    density, edges = torch.histogram(samples, bins=30, density=True)
    # bar() centres bars on x by default; align on the left edge so each
    # bar spans exactly its bin [edges[k], edges[k + 1]).
    ax.bar(
        edges[:-1].numpy(),
        density.numpy(),
        width=(edges[1] - edges[0]).item(),
        align="edge",
        alpha=0.5,
        label="Posterior",
    )
    ax.plot(edges.numpy(), prior.log_prob(edges).exp().numpy(), "r", label="Prior")
    ax.set(title=name, xlabel=name, ylabel="density")
    ax.legend()

In [None]:
# Run five BOED iterations. Each X.step() presumably proposes the design
# with the highest EIG and evaluates it; the results are appended to X.data.
for step_index in range(5):
    print(step_index)
    X.step()

In [None]:
# Accumulated observations: the initial grid points followed by the BOED-selected designs.
observations = X.data
observations

In [None]:
# Refit on all collected data (so the most recent observation is included)
# and draw posterior-predictive samples.
X.generator.train_model(X.data)
predictive = X.generator.get_predictive()

test_x = torch.linspace(0, 3, 100)
pred = predictive(test_x)
# Samples are stacked along dim 0 (the quantiles below reduce over it).
# `.detach()` is defensive in case the samples carry autograd history.
y = pred["y"].detach()

# Posterior predictive: 5-95 % band, median, a few individual draws, and the data.
fig, ax = plt.subplots()
x_np = test_x.numpy()
lower, median, upper = (y.quantile(q, dim=0).numpy() for q in (0.05, 0.5, 0.95))
ax.fill_between(x_np, lower, upper, alpha=0.3, label="5-95% interval")
ax.plot(x_np, median, label="median")
for sample in y[:10]:  # slicing tolerates fewer than 10 samples
    ax.plot(x_np, sample.numpy(), color="gray", alpha=0.15)
ax.scatter(X.data["x"], X.data["y"], color="red", label="observations")
ax.set(xlabel="x", ylabel="y", title="Posterior predictive (after BOED steps)")
ax.legend()

# Expected information gain (EIG) over candidate designs. `_get_acquisition`
# is a private generator method, so this relies on Xopt internals.
current_model = X.generator.model
acqf = X.generator._get_acquisition(current_model)
candidate_designs = torch.linspace(0, 3, 1000).unsqueeze(-1)
eig = acqf(candidate_designs).detach()

fig, ax = plt.subplots()
ax.plot(candidate_designs.squeeze(-1).numpy(), eig.numpy())
ax.set(xlabel="x", ylabel="EIG", title="Expected information gain (after BOED steps)")

# Marginal posterior of each model parameter compared with its prior.
# squeeze=False keeps a 2-D Axes array even when there is a single prior.
fig, axes = plt.subplots(1, len(priors), figsize=(15, 5), squeeze=False)
for ax, (name, prior) in zip(axes[0], priors.items()):
    samples = pred[name].detach().flatten()
    density, edges = torch.histogram(samples, bins=30, density=True)
    # bar() centres bars on x by default; align on the left edge so each
    # bar spans exactly its bin [edges[k], edges[k + 1]).
    ax.bar(
        edges[:-1].numpy(),
        density.numpy(),
        width=(edges[1] - edges[0]).item(),
        align="edge",
        alpha=0.5,
        label="Posterior",
    )
    ax.plot(edges.numpy(), prior.log_prob(edges).exp().numpy(), "r", label="Prior")
    ax.set(title=name, xlabel=name, ylabel="density")
    ax.legend()