Skip to content

Commit

Permalink
Make get_infeasible_cost return a cost value for each outcome. (#1191)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #1191

Current implementation returns a single `M` value. This modifies it to return an `m`-dim tensor of infeasible cost values, each corresponding to one of the `m` outcomes.

Reviewed By: Balandat

Differential Revision: D35847505

fbshipit-source-id: 39b3cd27b53b3e4536130709faecd7e4f1b87715
  • Loading branch information
saitcakmak authored and facebook-github-bot committed Apr 25, 2022
1 parent 92112b7 commit 46fc326
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 21 deletions.
4 changes: 2 additions & 2 deletions botorch/acquisition/objective.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import inspect
import warnings
from abc import ABC, abstractmethod
from typing import Callable, List, Optional
from typing import Union, Callable, List, Optional

import torch
from botorch.exceptions.errors import UnsupportedError
Expand Down Expand Up @@ -441,7 +441,7 @@ def __init__(
self,
objective: Callable[[Tensor, Optional[Tensor]], Tensor],
constraints: List[Callable[[Tensor], Tensor]],
infeasible_cost: float = 0.0,
infeasible_cost: Union[Tensor, float] = 0.0,
eta: float = 1e-3,
) -> None:
r"""Feasibility-weighted objective.
Expand Down
20 changes: 12 additions & 8 deletions botorch/acquisition/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,22 +209,22 @@ def get_infeasible_cost(
model: Model,
objective: Optional[Callable[[Tensor, Optional[Tensor]], Tensor]] = None,
posterior_transform: Optional[PosteriorTransform] = None,
) -> float:
) -> Tensor:
r"""Get infeasible cost for a model and objective.
Computes an infeasible cost `M` such that `-M < min_x f(x)` almost always,
so that feasible points are preferred.
For each outcome, computes an infeasible cost `M` such that
`-M < min_x f(x)` almost always, so that feasible points are preferred.
Args:
X: A `n x d` Tensor of `n` design points to use in evaluating the
minimum. These points should cover the design space well. The more
points the better the estimate, at the expense of added computation.
model: A fitted botorch model.
model: A fitted botorch model with `m` outcomes.
objective: The objective with which to evaluate the model output.
posterior_transform: A PosteriorTransform (optional).
Returns:
The infeasible cost `M` value.
An `m`-dim tensor of infeasible cost values.
Example:
>>> model = SingleTaskGP(train_X, train_Y)
Expand All @@ -237,9 +237,13 @@ def objective(Y: Tensor, X: Optional[Tensor] = None):
return Y.squeeze(-1)

posterior = model.posterior(X, posterior_transform=posterior_transform)
lb = objective(posterior.mean - 6 * posterior.variance.clamp_min(0).sqrt()).min()
M = -(lb.clamp_max(0.0))
return M.item()
lb = objective(posterior.mean - 6 * posterior.variance.clamp_min(0).sqrt())
if lb.ndim < posterior.mean.ndim:
lb = lb.unsqueeze(-1)
# Take outcome-wise min. Looping in to handle batched models.
while lb.dim() > 1:
lb = lb.min(dim=-2).values
return -(lb.clamp_max(0.0))


def is_nonnegative(acq_function: AcquisitionFunction) -> bool:
Expand Down
4 changes: 2 additions & 2 deletions test/acquisition/test_objective.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,7 @@ def test_constrained_mc_objective(self):
obj=constrained_obj,
constraints=[feasible_con, infeasible_con],
samples=samples,
infeasible_cost=0.0,
infeasible_cost=torch.tensor([0.0], device=self.device, dtype=dtype),
)
self.assertTrue(torch.equal(obj(samples), constrained_obj))
# one feasible, one infeasible, infeasible_cost
Expand All @@ -342,7 +342,7 @@ def test_constrained_mc_objective(self):
obj = ConstrainedMCObjective(
objective=generic_obj,
constraints=[feasible_con, infeasible_con],
infeasible_cost=5.0,
infeasible_cost=torch.tensor([5.0], device=self.device, dtype=dtype),
)
samples = torch.randn(4, 3, 2, device=self.device, dtype=dtype)
constrained_obj = generic_obj(samples)
Expand Down
29 changes: 20 additions & 9 deletions test/acquisition/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -585,23 +585,34 @@ def test_GetUnknownAcquisitionFunction(self):
class TestGetInfeasibleCost(BotorchTestCase):
def test_get_infeasible_cost(self):
for dtype in (torch.float, torch.double):
X = torch.zeros(5, 1, device=self.device, dtype=dtype)
means = torch.tensor(
[1.0, 2.0, 3.0, 4.0, 5.0], device=self.device, dtype=dtype
)
variances = torch.tensor(
[0.09, 0.25, 0.36, 0.25, 0.09], device=self.device, dtype=dtype
tkwargs = {"dtype": dtype, "device": self.device}
X = torch.zeros(5, 1, **tkwargs)
means = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0], **tkwargs).view(-1, 1)
variances = torch.tensor([0.09, 0.25, 0.36, 0.25, 0.09], **tkwargs).view(
-1, 1
)
mm = MockModel(MockPosterior(mean=means, variance=variances))
# means - 6 * std = [-0.8, -1, -0.6, 1, 3.2]. After applying the
# objective, the minimum becomes -6.0, so 6.0 should be returned.
M = get_infeasible_cost(
X=X, model=mm, objective=lambda Y: Y.squeeze(-1) - 5.0
)
self.assertEqual(M, 6.0)
# test default objective (squeeze last dim)
self.assertTrue(torch.allclose(M, torch.tensor([6.0], **tkwargs)))
# Test default objective (squeeze last dim).
M2 = get_infeasible_cost(X=X, model=mm)
self.assertEqual(M2, 1.0)
self.assertTrue(torch.allclose(M2, torch.tensor([1.0], **tkwargs)))
# Test multi-output.
m_ = means.repeat(1, 2)
m_[:, 1] -= 10
mm = MockModel(MockPosterior(mean=m_, variance=variances.expand(-1, 2)))
M3 = get_infeasible_cost(X=X, model=mm)
self.assertTrue(torch.allclose(M3, torch.tensor([1.0, 11.0], **tkwargs)))
# With a batched model.
means = means.expand(2, 4, -1, -1)
variances = variances.expand(2, 4, -1, -1)
mm = MockModel(MockPosterior(mean=means, variance=variances))
M4 = get_infeasible_cost(X=X, model=mm)
self.assertTrue(torch.allclose(M4, torch.tensor([1.0], **tkwargs)))


class TestPruneInferiorPoints(BotorchTestCase):
Expand Down

0 comments on commit 46fc326

Please sign in to comment.