Modifies qMFKG.evaluate() to work with project, expand and cost aware utility (#594)

Summary:

## Motivation

Modifies `qMFKG.evaluate()` to work with `project`, `expand` and `cost_aware_utility`. Partially fixes #587.

- Introduces a `ProjectedAcquisitionFunction` that wraps the `value_function` and applies the `project` operator in its `forward` call.
- Changes the `evaluate()` signature to use `X` instead of `X_actual`. With the old name, the `@concatenate_pending_points` and `@t_batch_mode_transform` decorators raise an exception when the method is called as `evaluate(X_actual=...)`.

Note: The treatment of `cost_aware_utility` assumes that it is monotone non-decreasing in `deltas` (as is the case for, e.g., `InverseCostWeightedUtility`). Otherwise, optimizing the inner problem first and then passing the resulting values through `cost_aware_utility` may not produce the correct output.
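
For concreteness, a minimal usage sketch of the updated `evaluate()`. Here `model`, `X_cand`, and `bounds` are placeholders for a fitted multi-fidelity GP, a candidate tensor, and the optimization bounds; the cost model and projection are illustrative choices that treat the last input column as the fidelity:

```python
import torch
from botorch.acquisition import qMultiFidelityKnowledgeGradient
from botorch.acquisition.cost_aware import InverseCostWeightedUtility
from botorch.models.cost import AffineFidelityCostModel

# InverseCostWeightedUtility divides `deltas` by a (positive) cost, so it is
# monotone non-decreasing in `deltas`, satisfying the assumption noted above.
cost_model = AffineFidelityCostModel(fidelity_weights={-1: 1.0}, fixed_cost=5.0)

qMFKG = qMultiFidelityKnowledgeGradient(
    model=model,  # placeholder: a fitted multi-fidelity GP
    num_fantasies=64,
    cost_aware_utility=InverseCostWeightedUtility(cost_model=cost_model),
    # project candidates to the target fidelity before evaluating the value fn
    project=lambda X: torch.cat([X[..., :-1], torch.ones_like(X[..., -1:])], -1),
)
# evaluate() now takes `X` (not `X_actual`) and applies expand/project as well
# as the cost-aware utility when computing the KG value of the candidates.
val = qMFKG.evaluate(X=X_cand, bounds=bounds, num_restarts=10, raw_samples=512)
```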

### Have you read the [Contributing Guidelines on pull requests](https://github.com/pytorch/botorch/blob/master/CONTRIBUTING.md#pull-requests)?

Yes

Pull Request resolved: #594

Test Plan: Added mock unit tests. Verified the expected behavior in additional offline tests.

Reviewed By: qingfeng10

Differential Revision: D25173613

Pulled By: Balandat

fbshipit-source-id: 3ba0f196a622a84c951fdc3526a53cb6905e85d2
saitcakmak authored and facebook-github-bot committed Dec 8, 2020
1 parent 519b18b commit 439c9ef
Showing 4 changed files with 153 additions and 15 deletions.
6 changes: 5 additions & 1 deletion botorch/acquisition/__init__.py
@@ -23,7 +23,10 @@
     InverseCostWeightedUtility,
 )
 from botorch.acquisition.fixed_feature import FixedFeatureAcquisitionFunction
-from botorch.acquisition.knowledge_gradient import qKnowledgeGradient
+from botorch.acquisition.knowledge_gradient import (
+    qKnowledgeGradient,
+    qMultiFidelityKnowledgeGradient,
+)
 from botorch.acquisition.max_value_entropy_search import (
     qMaxValueEntropy,
     qMultiFidelityMaxValueEntropy,
@@ -62,6 +65,7 @@
     "UpperConfidenceBound",
     "qExpectedImprovement",
     "qKnowledgeGradient",
+    "qMultiFidelityKnowledgeGradient",
     "qMaxValueEntropy",
     "qMultiFidelityMaxValueEntropy",
     "qNoisyExpectedImprovement",
64 changes: 54 additions & 10 deletions botorch/acquisition/knowledge_gradient.py
@@ -189,12 +189,12 @@ def forward(self, X: Tensor) -> Tensor:
 
     @concatenate_pending_points
     @t_batch_mode_transform()
-    def evaluate(self, X_actual: Tensor, bounds: Tensor, **kwargs: Any) -> Tensor:
+    def evaluate(self, X: Tensor, bounds: Tensor, **kwargs: Any) -> Tensor:
         r"""Evaluate qKnowledgeGradient on the candidate set `X_actual` by
         solving the inner optimization problem.
 
         Args:
-            X_actual: A `b x q x d` Tensor with `b` t-batches of `q` design points
+            X: A `b x q x d` Tensor with `b` t-batches of `q` design points
                 each. Unlike `forward()`, this does not include solutions of the
                 inner optimization problem.
             bounds: A `2 x d` tensor of lower and upper bounds for each column of
@@ -206,18 +206,24 @@ def evaluate(self, X_actual: Tensor, bounds: Tensor, **kwargs: Any) -> Tensor:
 
         Returns:
             A Tensor of shape `b`. For t-batch b, the q-KG value of the design
-                `X_actual[b]` is averaged across the fantasy models.
+                `X[b]` is averaged across the fantasy models.
             NOTE: If `current_value` is not provided, then this is not the
-                true KG value of `X_actual[b]`.
+                true KG value of `X[b]`.
         """
+        if hasattr(self, "expand"):
+            X = self.expand(X)
+
         # construct the fantasy model of shape `num_fantasies x b`
         fantasy_model = self.model.fantasize(
-            X=X_actual, sampler=self.sampler, observation_noise=True
+            X=X, sampler=self.sampler, observation_noise=True
         )
 
         # get the value function
         value_function = _get_value_function(
-            model=fantasy_model, objective=self.objective, sampler=self.inner_sampler
+            model=fantasy_model,
+            objective=self.objective,
+            sampler=self.inner_sampler,
+            project=getattr(self, "project", None),
         )
 
         from botorch.generation.gen import gen_candidates_scipy
@@ -246,6 +252,10 @@ def evaluate(self, X_actual: Tensor, bounds: Tensor, **kwargs: Any) -> Tensor:
         if self.current_value is not None:
             values = values - self.current_value
 
+        if hasattr(self, "cost_aware_utility"):
+            values = self.cost_aware_utility(
+                X=X, deltas=values, sampler=self.cost_sampler
+            )
         # return average over the fantasy samples
         return values.mean(dim=0)
 
@@ -409,13 +419,16 @@ def forward(self, X: Tensor) -> Tensor:
 
         # get the value function
         value_function = _get_value_function(
-            model=fantasy_model, objective=self.objective, sampler=self.inner_sampler
+            model=fantasy_model,
+            objective=self.objective,
+            sampler=self.inner_sampler,
+            project=self.project,
         )
 
         # make sure to propagate gradients to the fantasy model train inputs
-        # project the fantasy points
         with settings.propagate_grads(True):
-            values = value_function(X=self.project(X_fantasies))  # num_fantasies x b
+            values = value_function(X=X_fantasies)  # num_fantasies x b
 
         if self.current_value is not None:
             values = values - self.current_value
@@ -429,16 +442,47 @@ def forward(self, X: Tensor) -> Tensor:
         return values.mean(dim=0)
 
 
+class ProjectedAcquisitionFunction(AcquisitionFunction):
+    r"""
+    Defines a wrapper around an `AcquisitionFunction` that incorporates the project
+    operator. Typically used to handle value functions in look-ahead methods.
+    """
+
+    def __init__(
+        self,
+        base_value_function: AcquisitionFunction,
+        project: Callable[[Tensor], Tensor],
+    ) -> None:
+        super().__init__(base_value_function.model)
+        self.base_value_function = base_value_function
+        self.project = project
+        self.objective = base_value_function.objective
+        self.sampler = getattr(base_value_function, "sampler", None)
+
+    def forward(self, X: Tensor) -> Tensor:
+        return self.base_value_function(self.project(X))
+
+
 def _get_value_function(
     model: Model,
     objective: Optional[Union[MCAcquisitionObjective, ScalarizedObjective]] = None,
     sampler: Optional[MCSampler] = None,
+    project: Optional[Callable[[Tensor], Tensor]] = None,
 ) -> AcquisitionFunction:
     r"""Construct value function (i.e. inner acquisition function)."""
     if isinstance(objective, MCAcquisitionObjective):
-        return qSimpleRegret(model=model, sampler=sampler, objective=objective)
+        base_value_function = qSimpleRegret(
+            model=model, sampler=sampler, objective=objective
+        )
     else:
-        return PosteriorMean(model=model, objective=objective)
+        base_value_function = PosteriorMean(model=model, objective=objective)
+    if project is None:
+        return base_value_function
+    else:
+        return ProjectedAcquisitionFunction(
+            base_value_function=base_value_function,
+            project=project,
+        )
 
 
 def _split_fantasy_points(X: Tensor, n_f: int) -> Tuple[Tensor, Tensor]:
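
To make the new `_get_value_function` behavior concrete, a small sketch (here `fantasy_model` is a placeholder for a model produced by `fantasize()`, and the projection is an illustrative choice that pins the last column to the target fidelity):

```python
import torch
from botorch.acquisition.knowledge_gradient import _get_value_function

# Without `project`, the bare value function is returned, as before.
vf = _get_value_function(model=fantasy_model)  # -> PosteriorMean

# With `project`, the same value function is wrapped so that `forward()`
# first applies the projection.
vf_proj = _get_value_function(
    model=fantasy_model,
    project=lambda X: torch.cat([X[..., :-1], torch.ones_like(X[..., -1:])], -1),
)
# vf_proj is a ProjectedAcquisitionFunction with vf_proj(X) == vf(project(X))
```
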
7 changes: 3 additions & 4 deletions botorch/optim/initializers.py
@@ -16,7 +16,6 @@
     _get_value_function,
     qKnowledgeGradient,
 )
-from botorch.acquisition.monte_carlo import MCAcquisitionFunction
 from botorch.acquisition.utils import is_nonnegative
 from botorch.exceptions.warnings import BadInitialCandidatesWarning, SamplingWarning
 from botorch.models.model import Model
@@ -210,6 +209,7 @@ def gen_one_shot_kg_initial_conditions(
         model=acq_function.model,
         objective=acq_function.objective,
         sampler=acq_function.inner_sampler,
+        project=getattr(acq_function, "project", None),
     )
     from botorch.optim.optimize import optimize_acqf
 
@@ -304,9 +304,8 @@ def gen_value_function_initial_conditions(
     value_function = _get_value_function(
         model=current_model,
         objective=acq_function.objective,
-        sampler=acq_function.sampler
-        if isinstance(acq_function, MCAcquisitionFunction)
-        else None,
+        sampler=getattr(acq_function, "sampler", None),
+        project=getattr(acq_function, "project", None),
     )
     from botorch.optim.optimize import optimize_acqf
 
91 changes: 91 additions & 0 deletions test/acquisition/test_knowledge_gradient.py
@@ -4,6 +4,7 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 
+from contextlib import ExitStack
 from unittest import mock
 
 import torch
@@ -14,6 +15,7 @@
     _split_fantasy_points,
     qKnowledgeGradient,
     qMultiFidelityKnowledgeGradient,
+    ProjectedAcquisitionFunction,
 )
 from botorch.acquisition.monte_carlo import qSimpleRegret
 from botorch.acquisition.objective import GenericMCObjective, ScalarizedObjective
@@ -399,6 +401,72 @@ def test_evaluate_q_multi_fidelity_knowledge_gradient(self):
             self.assertTrue(torch.allclose(val, val_exp, atol=1e-4))
             self.assertTrue(torch.equal(qMFKG.extract_candidates(X), X[..., :-n_f, :]))
 
+    def test_evaluate_qMFKG(self):
+        # mock test qMFKG.evaluate() with expand, project & cost aware utility
+        for dtype in (torch.float, torch.double):
+            mean = torch.zeros(1, 1, 1, device=self.device, dtype=dtype)
+            mm = MockModel(MockPosterior(mean=mean))
+            mm._input_batch_shape = torch.Size([1])
+            cau = GenericCostAwareUtility(mock_util)
+            n_f = 4
+            mean = torch.rand(n_f, 2, 1, 1, device=self.device, dtype=dtype)
+            variance = torch.rand(n_f, 2, 1, 1, device=self.device, dtype=dtype)
+            mfm = MockModel(MockPosterior(mean=mean, variance=variance))
+            mfm._input_batch_shape = torch.Size([n_f, 2])
+            with ExitStack() as es:
+                patch_f = es.enter_context(
+                    mock.patch.object(MockModel, "fantasize", return_value=mfm)
+                )
+                mock_num_outputs = es.enter_context(
+                    mock.patch(NO, new_callable=mock.PropertyMock)
+                )
+                es.enter_context(
+                    mock.patch(
+                        "botorch.optim.optimize.optimize_acqf",
+                        return_value=(
+                            torch.ones(1, 1, 1, device=self.device, dtype=dtype),
+                            torch.ones(1, device=self.device, dtype=dtype),
+                        ),
+                    ),
+                )
+                es.enter_context(
+                    mock.patch(
+                        "botorch.generation.gen.gen_candidates_scipy",
+                        return_value=(
+                            torch.ones(1, 1, 1, device=self.device, dtype=dtype),
+                            torch.ones(1, device=self.device, dtype=dtype),
+                        ),
+                    ),
+                )
+
+                mock_num_outputs.return_value = 1
+                qMFKG = qMultiFidelityKnowledgeGradient(
+                    model=mm,
+                    num_fantasies=n_f,
+                    X_pending=torch.rand(1, 1, 1, device=self.device, dtype=dtype),
+                    current_value=torch.zeros(1, device=self.device, dtype=dtype),
+                    cost_aware_utility=cau,
+                    project=lambda X: torch.zeros_like(X),
+                    expand=lambda X: torch.ones_like(X),
+                )
+                val = qMFKG.evaluate(
+                    X=torch.zeros(1, 1, 1, device=self.device, dtype=dtype),
+                    bounds=torch.tensor([[0.0], [1.0]]),
+                    num_restarts=1,
+                    raw_samples=1,
+                )
+                patch_f.assert_called_once()
+                cargs, ckwargs = patch_f.call_args
+                self.assertTrue(
+                    torch.equal(
+                        ckwargs["X"],
+                        torch.ones(1, 2, 1, device=self.device, dtype=dtype),
+                    )
+                )
+                self.assertEqual(
+                    val, cau(None, torch.ones(1, device=self.device, dtype=dtype))
+                )
+
 
 class TestKGUtils(BotorchTestCase):
     def test_get_value_function(self):
Expand All @@ -416,6 +484,29 @@ def test_get_value_function(self):
self.assertIsInstance(vf, qSimpleRegret)
self.assertEqual(vf.objective, obj)
self.assertEqual(vf.sampler, sampler)
# test with project
mock_project = mock.Mock(
return_value=torch.ones(1, 1, 1, device=self.device)
)
vf = _get_value_function(
model=mm,
objective=obj,
sampler=sampler,
project=mock_project,
)
self.assertIsInstance(vf, ProjectedAcquisitionFunction)
self.assertEqual(vf.objective, obj)
self.assertEqual(vf.sampler, sampler)
self.assertEqual(vf.project, mock_project)
test_X = torch.rand(1, 1, 1, device=self.device)
with mock.patch.object(
vf, "base_value_function", __class__=torch.nn.Module, return_value=None
) as patch_bvf:
vf(test_X)
mock_project.assert_called_once_with(test_X)
patch_bvf.assert_called_once_with(
torch.ones(1, 1, 1, device=self.device)
)

def test_split_fantasy_points(self):
for dtype in (torch.float, torch.double):
Expand Down
