Skip to content

Commit

Permalink
Moving penalized acqfn from botorch_fb to botorch (#585)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #585

Moved the development of the penalized acquisition function (acqfn) from botorch_fb to botorch so that it can be released as open source (OSS).

Reviewed By: Balandat

Differential Revision: D24508442

fbshipit-source-id: 54f0884e8e5a86296c6d0e58cc913bbf46323dbb
  • Loading branch information
Abbas Kazerouni authored and facebook-github-bot committed Oct 24, 2020
1 parent 4033c16 commit c5ec613
Show file tree
Hide file tree
Showing 3 changed files with 308 additions and 0 deletions.
172 changes: 172 additions & 0 deletions botorch/acquisition/penalized.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

r"""
Modules to add regularization to acquisition functions.
"""

from __future__ import annotations

import math
from typing import List, Optional

import torch
from botorch.acquisition.acquisition import AcquisitionFunction
from botorch.acquisition.analytic import AnalyticAcquisitionFunction
from botorch.exceptions import UnsupportedError
from torch import Tensor


class L2Penalty(torch.nn.Module):
    r"""L2 penalty class to be added to any arbitrary acquisition function.

    The penalty of a q-batch is the largest squared Euclidean distance between
    any of its points and the reference point `init_point`.
    """

    def __init__(self, init_point: Tensor):
        r"""Initializing L2 regularization.

        Args:
            init_point: The "1 x dim" reference point against which
                we want to regularize.
        """
        super().__init__()
        self.init_point = init_point

    def forward(self, X: Tensor) -> Tensor:
        r"""Evaluate the L2 penalty on the candidate set X.

        Args:
            X: A "batch_shape x q x dim" tensor representing the points to be
                evaluated.

        Returns:
            A tensor of size "batch_shape" representing the penalty (not the
            acquisition value) for each q-batch.
        """
        # Euclidean distance of every point to the reference point; the last
        # (feature) dimension is reduced away.
        distances = torch.norm(X - self.init_point, p=2, dim=-1)
        # Take the farthest point of each q-batch first, then square it.
        return distances.max(dim=-1).values ** 2


class GaussianPenalty(torch.nn.Module):
    r"""Gaussian penalty class to be added to any arbitrary acquisition function.

    For each point the term `exp(||x - init_point||^2 / (2 * sigma^2))` is
    computed. Note the positive exponent: the value is 1 at `init_point` and
    grows with the distance from it. The penalty of a q-batch is the largest
    such term among its points.
    """

    def __init__(self, init_point: Tensor, sigma: float):
        r"""Initializing Gaussian regularization.

        Args:
            init_point: The "1 x dim" reference point against which
                we want to regularize.
            sigma: The parameter used in gaussian function.
        """
        super().__init__()
        self.init_point = init_point
        self.sigma = sigma

    def forward(self, X: Tensor) -> Tensor:
        r"""Evaluate the Gaussian penalty on the candidate set X.

        Args:
            X: A "batch_shape x q x dim" tensor representing the points to be
                evaluated.

        Returns:
            A tensor of size "batch_shape" representing the penalty (not the
            acquisition value) for each q-batch.
        """
        # Squared Euclidean distance of every point to the reference point.
        sq_diff = torch.norm(X - self.init_point, p=2, dim=-1) ** 2
        # Positive exponent on purpose: this is the reciprocal of the
        # unnormalized Gaussian kernel, so far-away points are penalized more.
        per_point_penalty = torch.exp(sq_diff / 2 / self.sigma ** 2)
        # The q-batch is penalized by its worst (largest) point.
        return per_point_penalty.max(dim=-1).values


class GroupLassoPenalty(torch.nn.Module):
    r"""Group lasso penalty class to be added to any arbitrary acquisition function."""

    def __init__(self, init_point: Tensor, groups: List[List[int]]):
        r"""Initializing Group-Lasso regularization.

        Args:
            init_point: The "1 x dim" reference point against which we want
                to regularize.
            groups: Groups of indices used in group lasso.
        """
        super().__init__()
        self.init_point = init_point
        self.groups = groups

    def forward(self, X: Tensor) -> Tensor:
        r"""Evaluate the group lasso penalty of `X - init_point`.

        Evaluation for q-batches (q > 1) is not implemented yet.

        Args:
            X: A "batch_shape x 1 x dim" tensor representing the points to be
                evaluated.

        Returns:
            The group lasso norm of the offset of each point from `init_point`,
            as computed by `group_lasso_regularizer`.

        Raises:
            NotImplementedError: If the q-dimension `X.shape[-2]` is not 1.
        """
        if X.shape[-2] != 1:
            raise NotImplementedError(
                "group-lasso has not been implemented for q>1 yet."
            )

        # Drop the singleton q-dimension before measuring the offset from the
        # reference point.
        offset = X.squeeze(-2) - self.init_point
        return group_lasso_regularizer(X=offset, groups=self.groups)


class PenalizedAcquisitionFunction(AcquisitionFunction):
    r"""Single-outcome acquisition function regularized by the given penalty.

    The usage is similar to:
        raw_acqf = NoisyExpectedImprovement(...)
        penalty = GroupLassoPenalty(...)
        acqf = PenalizedAcquisitionFunction(raw_acqf, penalty, lmbda)
    """

    def __init__(
        self,
        raw_acqf: AcquisitionFunction,
        penalty_func: torch.nn.Module,
        regularization_parameter: float,
    ) -> None:
        r"""Initializing the penalized acquisition function.

        Args:
            raw_acqf: The raw acquisition function that is going to be regularized.
            penalty_func: The regularization function.
            regularization_parameter: Regularization parameter used in optimization.
        """
        # The wrapper shares the model of the acquisition function it wraps.
        super().__init__(model=raw_acqf.model)
        self.raw_acqf = raw_acqf
        self.penalty_func = penalty_func
        self.regularization_parameter = regularization_parameter

    def forward(self, X: Tensor) -> Tensor:
        r"""Evaluate the penalized acquisition function on X.

        Args:
            X: The candidate points, passed unchanged to both the raw
                acquisition function and the penalty function.

        Returns:
            The raw acquisition value minus `regularization_parameter` times
            the penalty term.
        """
        raw_value = self.raw_acqf(X=X)
        penalty_term = self.penalty_func(X)
        return raw_value - self.regularization_parameter * penalty_term

    @property
    def X_pending(self) -> Optional[Tensor]:
        r"""The pending points of the wrapped raw acquisition function."""
        return self.raw_acqf.X_pending

    def set_X_pending(self, X_pending: Optional[Tensor] = None) -> None:
        r"""Forward the pending points to the wrapped raw acquisition function.

        Args:
            X_pending: The pending points to set on the raw acquisition function.

        Raises:
            UnsupportedError: If the raw acquisition function is an
                `AnalyticAcquisitionFunction`.
        """
        if isinstance(self.raw_acqf, AnalyticAcquisitionFunction):
            raise UnsupportedError(
                "The raw acquisition function is Analytic and does not account "
                "for X_pending yet."
            )
        self.raw_acqf.set_X_pending(X_pending=X_pending)


def group_lasso_regularizer(X: Tensor, groups: List[List[int]]) -> Tensor:
    r"""Computes the group lasso regularization function for the given point.

    Each group contributes `sqrt(len(group))` times the L2 norm of the
    entries of X selected by that group; the contributions are summed.

    Args:
        X: A bxd tensor representing the points to evaluate the regularization at.
        groups: List of indices of different groups.

    Returns:
        Computed group lasso norm at the given points. With an empty `groups`
        list the result is a zero tensor with the shape of X minus its last
        dimension (the empty sum).
    """
    if not groups:
        # torch.stack cannot handle an empty list; an empty sum is zero.
        return X.new_zeros(X.shape[:-1])

    weighted_group_norms = [
        math.sqrt(len(g)) * torch.norm(X[..., g], p=2, dim=-1) for g in groups
    ]
    return torch.sum(torch.stack(weighted_group_norms, dim=-1), dim=-1)
5 changes: 5 additions & 0 deletions sphinx/source/acquisition.rst
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,11 @@ Fixed Feature Acquisition Function
.. automodule:: botorch.acquisition.fixed_feature
:members:

Penalized Acquisition Function Wrapper
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. automodule:: botorch.acquisition.penalized
    :members:

General Utilities for Acquisition Functions
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. automodule:: botorch.acquisition.utils
Expand Down
131 changes: 131 additions & 0 deletions test/acquisition/test_penalized.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
from botorch.acquisition.analytic import ExpectedImprovement
from botorch.acquisition.monte_carlo import qExpectedImprovement
from botorch.acquisition.penalized import (
GaussianPenalty,
GroupLassoPenalty,
L2Penalty,
PenalizedAcquisitionFunction,
group_lasso_regularizer,
)
from botorch.exceptions import UnsupportedError
from botorch.sampling.samplers import IIDNormalSampler
from botorch.utils.testing import BotorchTestCase, MockModel, MockPosterior


class TestL2Penalty(BotorchTestCase):
    def test_l2_penalty(self):
        # Checks L2Penalty against a direct computation of the largest
        # squared distance to the reference point.
        for dtype in (torch.float, torch.double):
            init_point = torch.tensor([1.0, 1.0, 1.0], device=self.device, dtype=dtype)
            l2_module = L2Penalty(init_point=init_point)

            # testing a batch of two points
            sample_point = torch.tensor(
                [[1.0, 2.0, 3.0], [2.0, 3.0, 4.0]], device=self.device, dtype=dtype
            )

            diff_norm_squared = (
                torch.norm((sample_point - init_point), p=2, dim=-1) ** 2
            )
            real_value = diff_norm_squared.max(dim=-1).values
            computed_value = l2_module(sample_point)
            self.assertEqual(computed_value.item(), real_value.item())


class TestGaussianPenalty(BotorchTestCase):
    def test_gaussian_penalty(self):
        # Checks GaussianPenalty against a direct computation of
        # exp(max squared distance / (2 * sigma^2)).
        for dtype in (torch.float, torch.double):
            init_point = torch.tensor([1.0, 1.0, 1.0], device=self.device, dtype=dtype)
            # sigma is chosen so that the exponent (14 / 2 / sigma^2 = 1.75)
            # stays representable in float32; a small sigma such as 0.1 gives
            # exp(700) = inf in single precision and makes the check vacuous.
            sigma = 2.0
            gaussian_module = GaussianPenalty(init_point=init_point, sigma=sigma)

            # testing a batch of two points
            sample_point = torch.tensor(
                [[1.0, 2.0, 3.0], [2.0, 3.0, 4.0]], device=self.device, dtype=dtype
            )

            diff_norm_squared = (
                torch.norm((sample_point - init_point), p=2, dim=-1) ** 2
            )
            max_l2_distance = diff_norm_squared.max(dim=-1).values
            real_value = torch.exp(max_l2_distance / 2 / sigma ** 2)
            computed_value = gaussian_module(sample_point)
            # Guard against the comparison degenerating into inf == inf.
            self.assertTrue(torch.isfinite(computed_value).item())
            self.assertEqual(computed_value.item(), real_value.item())


class TestGroupLassoPenalty(BotorchTestCase):
    def test_group_lasso_penalty(self):
        # Checks the module against the functional regularizer and verifies
        # that q > 1 inputs are rejected.
        for dtype in (torch.float, torch.double):
            tkwargs = {"device": self.device, "dtype": dtype}
            init_point = torch.tensor([0.5, 0.5, 0.5], **tkwargs)
            groups = [[0, 2], [1]]
            group_lasso_module = GroupLassoPenalty(init_point=init_point, groups=groups)

            # a single point (q = 1)
            single_point = torch.tensor([[1.0, 2.0, 3.0]], **tkwargs)
            # reference value is approximately 5.105551242828369
            expected = group_lasso_regularizer(single_point - init_point, groups)
            actual = group_lasso_module(single_point)
            self.assertEqual(actual.item(), expected.item())

            # two points (X.shape[-2] > 1) are not supported
            two_points = torch.tensor(
                [[1.0, 2.0, 3.0], [2.0, 3.0, 4.0]], **tkwargs
            )
            with self.assertRaises(NotImplementedError):
                group_lasso_module(two_points)


class TestPenalizedAcquisitionFunction(BotorchTestCase):
    def test_penalized_acquisition_function(self):
        # Checks the penalized value against a manual computation and the
        # X_pending handling for analytic and non-analytic raw acqfs.
        for dtype in (torch.float, torch.double):
            # Build the mock posterior on the same device/dtype as the inputs
            # so the test does not mix devices or precisions.
            mock_model = MockModel(
                MockPosterior(
                    mean=torch.tensor([1.0], device=self.device, dtype=dtype),
                    variance=torch.tensor([1.0], device=self.device, dtype=dtype),
                )
            )
            init_point = torch.tensor([0.5, 0.5, 0.5], device=self.device, dtype=dtype)
            groups = [[0, 2], [1]]
            raw_acqf = ExpectedImprovement(model=mock_model, best_f=1.0)
            penalty = GroupLassoPenalty(init_point=init_point, groups=groups)
            lmbda = 0.1
            acqf = PenalizedAcquisitionFunction(
                raw_acqf=raw_acqf, penalty_func=penalty, regularization_parameter=lmbda
            )

            sample_point = torch.tensor(
                [[1.0, 2.0, 3.0]], device=self.device, dtype=dtype
            )
            raw_value = raw_acqf(sample_point)
            penalty_value = penalty(sample_point)
            real_value = raw_value - lmbda * penalty_value
            computed_value = acqf(sample_point)
            self.assertTrue(torch.equal(real_value, computed_value))

            # testing X_pending for analytic raw_acqfn (EI)
            X_pending = torch.tensor([0.1, 0.2, 0.3], device=self.device, dtype=dtype)
            with self.assertRaises(UnsupportedError):
                acqf.set_X_pending(X_pending)

            # testing X_pending for non-analytic raw_acqfn (qEI)
            sampler = IIDNormalSampler(num_samples=2)
            raw_acqf_2 = qExpectedImprovement(
                model=mock_model, best_f=0, sampler=sampler
            )
            init_point = torch.tensor([1.0, 1.0, 1.0], device=self.device, dtype=dtype)
            l2_module = L2Penalty(init_point=init_point)
            acqf_2 = PenalizedAcquisitionFunction(
                raw_acqf=raw_acqf_2,
                penalty_func=l2_module,
                regularization_parameter=lmbda,
            )

            # The same pending points are forwarded to the wrapped qEI.
            acqf_2.set_X_pending(X_pending)
            self.assertTrue(torch.equal(acqf_2.X_pending, X_pending))

0 comments on commit c5ec613

Please sign in to comment.