Skip to content

Commit

Permalink
Merge e147346 into 63dd0cd
Browse files Browse the repository at this point in the history
  • Loading branch information
sdaulton committed Feb 15, 2023
2 parents 63dd0cd + e147346 commit a10d9f2
Show file tree
Hide file tree
Showing 10 changed files with 219 additions and 50 deletions.
26 changes: 7 additions & 19 deletions botorch/acquisition/fixed_feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,11 @@

import torch
from botorch.acquisition.acquisition import AcquisitionFunction
from botorch.acquisition.wrapper import AbstractAcquisitionFunctionWrapper
from torch import Tensor
from torch.nn import Module


class FixedFeatureAcquisitionFunction(AcquisitionFunction):
class FixedFeatureAcquisitionFunction(AbstractAcquisitionFunctionWrapper):
"""A wrapper around AquisitionFunctions to fix a subset of features.
Example:
Expand Down Expand Up @@ -56,8 +56,7 @@ def __init__(
combination of `Tensor`s and numbers which can be broadcasted
to form a tensor with trailing dimension size of `d_f`.
"""
Module.__init__(self)
self.acq_func = acq_function
AbstractAcquisitionFunctionWrapper.__init__(self, acq_function=acq_function)
dtype = torch.float
device = torch.device("cpu")
self.d = d
Expand Down Expand Up @@ -126,24 +125,13 @@ def forward(self, X: Tensor):
X_full = self._construct_X_full(X)
return self.acq_func(X_full)

@property
def X_pending(self):
r"""Return the `X_pending` of the base acquisition function."""
try:
return self.acq_func.X_pending
except (ValueError, AttributeError):
raise ValueError(
f"Base acquisition function {type(self.acq_func).__name__} "
"does not have an `X_pending` attribute."
)

@X_pending.setter
def X_pending(self, X_pending: Optional[Tensor]):
def set_X_pending(self, X_pending: Optional[Tensor]):
r"""Sets the `X_pending` of the base acquisition function."""
if X_pending is not None:
self.acq_func.X_pending = self._construct_X_full(X_pending)
full_X_pending = self._construct_X_full(X_pending)
else:
self.acq_func.X_pending = X_pending
full_X_pending = None
self.acq_func.set_X_pending(full_X_pending)

def _construct_X_full(self, X: Tensor) -> Tensor:
r"""Constructs the full input for the base acquisition function.
Expand Down
24 changes: 5 additions & 19 deletions botorch/acquisition/penalized.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,8 @@

import torch
from botorch.acquisition.acquisition import AcquisitionFunction
from botorch.acquisition.analytic import AnalyticAcquisitionFunction
from botorch.acquisition.objective import GenericMCObjective
from botorch.exceptions import UnsupportedError
from botorch.acquisition.wrapper import AbstractAcquisitionFunctionWrapper
from torch import Tensor


Expand Down Expand Up @@ -139,7 +138,7 @@ def forward(self, X: Tensor) -> Tensor:
return regularization_term


class PenalizedAcquisitionFunction(AcquisitionFunction):
class PenalizedAcquisitionFunction(AbstractAcquisitionFunctionWrapper):
r"""Single-outcome acquisition function regularized by the given penalty.
The usage is similar to:
Expand All @@ -161,29 +160,16 @@ def __init__(
penalty_func: The regularization function.
regularization_parameter: Regularization parameter used in optimization.
"""
super().__init__(model=raw_acqf.model)
self.raw_acqf = raw_acqf
AcquisitionFunction.__init__(self, model=raw_acqf.model)
AbstractAcquisitionFunctionWrapper.__init__(self, acq_function=raw_acqf)
self.penalty_func = penalty_func
self.regularization_parameter = regularization_parameter

def forward(self, X: Tensor) -> Tensor:
raw_value = self.raw_acqf(X=X)
raw_value = self.acq_func(X=X)
penalty_term = self.penalty_func(X)
return raw_value - self.regularization_parameter * penalty_term

@property
def X_pending(self) -> Optional[Tensor]:
return self.raw_acqf.X_pending

def set_X_pending(self, X_pending: Optional[Tensor] = None) -> None:
if not isinstance(self.raw_acqf, AnalyticAcquisitionFunction):
self.raw_acqf.set_X_pending(X_pending=X_pending)
else:
raise UnsupportedError(
"The raw acquisition function is Analytic and does not account "
"for X_pending yet."
)


def group_lasso_regularizer(X: Tensor, groups: List[List[int]]) -> Tensor:
r"""Computes the group lasso regularization function for the given point.
Expand Down
15 changes: 10 additions & 5 deletions botorch/acquisition/proximal.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@

import torch
from botorch.acquisition import AcquisitionFunction

from botorch.acquisition.wrapper import AbstractAcquisitionFunctionWrapper
from botorch.exceptions.errors import UnsupportedError
from botorch.models import ModelListGP
from botorch.models.gpytorch import BatchedMultiOutputGPyTorchModel
Expand All @@ -25,7 +27,7 @@
from torch.nn import Module


class ProximalAcquisitionFunction(AcquisitionFunction):
class ProximalAcquisitionFunction(AbstractAcquisitionFunctionWrapper):
"""A wrapper around AcquisitionFunctions to add proximal weighting of the
acquisition function. The acquisition function is
weighted via a squared exponential centered at the last training point,
Expand Down Expand Up @@ -70,17 +72,14 @@ def __init__(
beta: If not None, apply a softplus transform to the base acquisition
function, allows negative base acquisition function values.
"""
Module.__init__(self)

self.acq_func = acq_function
AbstractAcquisitionFunctionWrapper.__init__(self, acq_function=acq_function)
model = self.acq_func.model

if hasattr(acq_function, "X_pending"):
if acq_function.X_pending is not None:
raise UnsupportedError(
"Proximal acquisition function requires `X_pending` to be None."
)
self.X_pending = acq_function.X_pending

self.register_buffer("proximal_weights", proximal_weights)
self.register_buffer(
Expand All @@ -91,6 +90,12 @@ def __init__(

_validate_model(model, proximal_weights)

def set_X_pending(self, X_pending: Optional[Tensor]) -> None:
r"""Sets the `X_pending` of the base acquisition function."""
raise UnsupportedError(
"Proximal acquisition function does not support `X_pending`."
)

@t_batch_mode_transform(expected_q=1, assert_output_shape=False)
def forward(self, X: Tensor) -> Tensor:
r"""Evaluate base acquisition function with proximal weighting.
Expand Down
17 changes: 15 additions & 2 deletions botorch/acquisition/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from __future__ import annotations

import math
from typing import Callable, Dict, List, Optional, Union
from typing import Any, Callable, Dict, List, Optional, Union

import torch
from botorch.acquisition import analytic, monte_carlo, multi_objective # noqa F401
Expand All @@ -22,6 +22,7 @@
MCAcquisitionObjective,
PosteriorTransform,
)
from botorch.acquisition.wrapper import AbstractAcquisitionFunctionWrapper
from botorch.exceptions.errors import UnsupportedError
from botorch.models.fully_bayesian import MCMC_DIM
from botorch.models.model import Model
Expand Down Expand Up @@ -253,6 +254,18 @@ def objective(Y: Tensor, X: Optional[Tensor] = None):
return -(lb.clamp_max(0.0))


def isinstance_af(
__obj: object,
__class_or_tuple: Union[type, tuple[Union[type, tuple[Any, ...]], ...]],
) -> bool:
r"""A variant of isinstance first checks for the acq_func attribute on wrapped acquisition functions."""
if isinstance(__obj, AbstractAcquisitionFunctionWrapper):
isinstance_base_af = isinstance(__obj.acq_func, __class_or_tuple)
else:
isinstance_base_af = False
return isinstance_base_af or isinstance(__obj, __class_or_tuple)


def is_nonnegative(acq_function: AcquisitionFunction) -> bool:
r"""Determine whether a given acquisition function is non-negative.
Expand All @@ -267,7 +280,7 @@ def is_nonnegative(acq_function: AcquisitionFunction) -> bool:
>>> qEI = qExpectedImprovement(model, best_f=0.1)
>>> is_nonnegative(qEI) # returns True
"""
return isinstance(
return isinstance_af(
acq_function,
(
analytic.ExpectedImprovement,
Expand Down
55 changes: 55 additions & 0 deletions botorch/acquisition/wrapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

r"""
A wrapper classes around AcquisitionFunctions to modify inputs and outputs.
"""

from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Optional

from botorch.acquisition.acquisition import AcquisitionFunction
from torch import Tensor
from torch.nn import Module


class AbstractAcquisitionFunctionWrapper(AcquisitionFunction, ABC):
r"""Abstract acquisition wrapper."""

def __init__(self, acq_function: AcquisitionFunction) -> None:
Module.__init__(self)
self.acq_func = acq_function

@property
def X_pending(self) -> Optional[Tensor]:
r"""Return the `X_pending` of the base acquisition function."""
try:
return self.acq_func.X_pending
except (ValueError, AttributeError):
raise ValueError(
f"Base acquisition function {type(self.acq_func).__name__} "
"does not have an `X_pending` attribute."
)

def set_X_pending(self, X_pending: Optional[Tensor]) -> None:
r"""Sets the `X_pending` of the base acquisition function."""
self.acq_func.set_X_pending(X_pending)

@abstractmethod
def forward(self, X: Tensor) -> Tensor:
r"""Evaluate the wrapped acquisition function on the candidate set X.
Args:
X: A `(b) x q x d`-dim Tensor of `(b)` t-batches with `q` `d`-dim
design points each.
Returns:
A `(b)`-dim Tensor of acquisition function values at the given
design points `X`.
"""
pass # pragma: no cover
9 changes: 7 additions & 2 deletions sphinx/source/acquisition.rst
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,11 @@ Analytic Acquisition Function API
.. autoclass:: AnalyticAcquisitionFunction
:members:

Acquisition Function Wrapper API
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. automodule:: botorch.acquisition.wrapper
:members:

Cached Cholesky Acquisition Function API
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. automodule:: botorch.acquisition.cached_cholesky
Expand Down Expand Up @@ -65,7 +70,7 @@ Multi-Objective Analytic Acquisition Functions
.. automodule:: botorch.acquisition.multi_objective.analytic
:members:
:exclude-members: MultiObjectiveAnalyticAcquisitionFunction

Multi-Objective Joint Entropy Search Acquisition Functions
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. automodule:: botorch.acquisition.multi_objective.joint_entropy_search
Expand All @@ -86,7 +91,7 @@ Multi-Objective Multi-Fidelity Acquisition Functions
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. automodule:: botorch.acquisition.multi_objective.multi_fidelity
:members:

Multi-Objective Predictive Entropy Search Acquisition Functions
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. automodule:: botorch.acquisition.multi_objective.predictive_entropy_search
Expand Down
2 changes: 1 addition & 1 deletion test/acquisition/test_fixed_feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def test_fixed_features(self):
qEI_ff.set_X_pending(X_pending[..., :-1])
self.assertAllClose(qEI.X_pending, X_pending)
# test setting to None
qEI_ff.X_pending = None
qEI_ff.set_X_pending(None)
self.assertIsNone(qEI_ff.X_pending)

# test gradient
Expand Down
8 changes: 7 additions & 1 deletion test/acquisition/test_proximal.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,9 +209,15 @@ def test_proximal(self):

# test for x_pending points
pending_acq = DummyAcquisitionFunction(model)
pending_acq.set_X_pending(torch.rand(3, 3, device=self.device, dtype=dtype))
X_pending = torch.rand(3, 3, device=self.device, dtype=dtype)
pending_acq.set_X_pending(X_pending)
with self.assertRaises(UnsupportedError):
ProximalAcquisitionFunction(pending_acq, proximal_weights)
# test setting pending points
pending_acq.set_X_pending(None)
af = ProximalAcquisitionFunction(pending_acq, proximal_weights)
with self.assertRaises(UnsupportedError):
af.set_X_pending(X_pending)

# test model with multi-batch training inputs
train_X = torch.rand(5, 2, 3, device=self.device, dtype=dtype)
Expand Down
61 changes: 60 additions & 1 deletion test/acquisition/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
from unittest import mock

import torch
from botorch.acquisition import monte_carlo
from botorch.acquisition import analytic, monte_carlo, multi_objective
from botorch.acquisition.fixed_feature import FixedFeatureAcquisitionFunction
from botorch.acquisition.multi_objective import (
MCMultiOutputObjective,
monte_carlo as moo_monte_carlo,
Expand All @@ -18,10 +19,13 @@
MCAcquisitionObjective,
ScalarizedPosteriorTransform,
)
from botorch.acquisition.proximal import ProximalAcquisitionFunction
from botorch.acquisition.utils import (
expand_trace_observations,
get_acquisition_function,
get_infeasible_cost,
is_nonnegative,
isinstance_af,
project_to_sample_points,
project_to_target_fidelity,
prune_inferior_points,
Expand Down Expand Up @@ -606,6 +610,61 @@ def test_get_infeasible_cost(self):
self.assertAllClose(M4, torch.tensor([1.0], **tkwargs))


class TestIsNonnegative(BotorchTestCase):
def test_is_nonnegative(self):
nonneg_afs = (
analytic.ExpectedImprovement,
analytic.ConstrainedExpectedImprovement,
analytic.ProbabilityOfImprovement,
analytic.NoisyExpectedImprovement,
monte_carlo.qExpectedImprovement,
monte_carlo.qNoisyExpectedImprovement,
monte_carlo.qProbabilityOfImprovement,
multi_objective.analytic.ExpectedHypervolumeImprovement,
multi_objective.monte_carlo.qExpectedHypervolumeImprovement,
multi_objective.monte_carlo.qNoisyExpectedHypervolumeImprovement,
)
mm = MockModel(
MockPosterior(
mean=torch.rand(1, 1, device=self.device),
variance=torch.ones(1, 1, device=self.device),
)
)
acq_func = analytic.ExpectedImprovement(model=mm, best_f=-1.0)
with mock.patch(
"botorch.acquisition.utils.isinstance_af", return_value=True
) as mock_isinstance_af:
self.assertTrue(is_nonnegative(acq_function=acq_func))
mock_isinstance_af.assert_called_once()
cargs, _ = mock_isinstance_af.call_args
self.assertIs(cargs[0], acq_func)
self.assertEqual(cargs[1], nonneg_afs)
acq_func = analytic.UpperConfidenceBound(model=mm, beta=2.0)
self.assertFalse(is_nonnegative(acq_function=acq_func))


class TestIsinstanceAf(BotorchTestCase):
def test_isinstance_af(self):
mm = MockModel(
MockPosterior(
mean=torch.rand(1, 1, device=self.device),
variance=torch.ones(1, 1, device=self.device),
)
)
acq_func = analytic.ExpectedImprovement(model=mm, best_f=-1.0)
self.assertTrue(isinstance_af(acq_func, analytic.ExpectedImprovement))
self.assertFalse(isinstance_af(acq_func, analytic.UpperConfidenceBound))
wrapped_af = FixedFeatureAcquisitionFunction(
acq_function=acq_func, d=2, columns=[1], values=[0.0]
)
# test base af class
self.assertTrue(isinstance_af(wrapped_af, analytic.ExpectedImprovement))
self.assertFalse(isinstance_af(wrapped_af, analytic.UpperConfidenceBound))
# test wrapper class
self.assertTrue(isinstance_af(wrapped_af, FixedFeatureAcquisitionFunction))
self.assertFalse(isinstance_af(wrapped_af, ProximalAcquisitionFunction))


class TestPruneInferiorPoints(BotorchTestCase):
def test_prune_inferior_points(self):
for dtype in (torch.float, torch.double):
Expand Down

0 comments on commit a10d9f2

Please sign in to comment.