From 6f99023dc0ace08dc0bb8fa98286467573bbd61c Mon Sep 17 00:00:00 2001
From: Sam Daulton
Date: Sat, 14 Jan 2023 17:27:12 -0800
Subject: [PATCH] fix chebyshev scalarization (#1616)

Summary:
Pull Request resolved: https://github.com/pytorch/botorch/pull/1616

See https://github.com/pytorch/botorch/issues/1614

Differential Revision: D42373368

fbshipit-source-id: 828dee51734cc54bf3f01c1876af05f9ae8b7201
---
 .../utils/multi_objective/scalarization.py    | 49 +++++++++----
 .../multi_objective/test_scalarization.py     | 73 ++++++++++++-------
 2 files changed, 79 insertions(+), 43 deletions(-)

diff --git a/botorch/utils/multi_objective/scalarization.py b/botorch/utils/multi_objective/scalarization.py
index 9ca69ebf9f..c72eb423dc 100644
--- a/botorch/utils/multi_objective/scalarization.py
+++ b/botorch/utils/multi_objective/scalarization.py
@@ -19,7 +19,7 @@
 from typing import Callable, Optional
 
 import torch
-from botorch.exceptions.errors import BotorchTensorDimensionError
+from botorch.exceptions.errors import BotorchTensorDimensionError, UnsupportedError
 from botorch.utils.transforms import normalize
 from torch import Tensor
 
@@ -29,16 +29,18 @@ def get_chebyshev_scalarization(
 ) -> Callable[[Tensor, Optional[Tensor]], Tensor]:
     r"""Construct an augmented Chebyshev scalarization.
 
-    Augmented Chebyshev scalarization:
-        objective(y) = min(w * y) + alpha * sum(w * y)
+    The augmented Chebyshev scalarization is given by
+        g(y) = max_i(w_i * y_i) + alpha * sum_i(w_i * y_i)
 
-    Outcomes are first normalized to [0,1] for maximization (or [-1,0] for minimization)
-    and then an augmented Chebyshev scalarization is applied.
+    where the goal is to minimize g(y) in the setting where all objectives y_i are
+    to be minimized. Since the default in BoTorch is to maximize all objectives,
+    this method constructs a Chebyshev scalarization where the inputs are first
+    multiplied by -1, so that all objectives are to be minimized. Then, it computes
+    g(y) (which should be minimized), and returns -g(y), which should be maximized.
 
-    Note: this assumes maximization of the augmented Chebyshev scalarization.
-    Minimizing/Maximizing an objective is supported by passing a negative/positive
-    weight for that objective. To make all w * y's have positive sign
-    such that they are comparable when computing min(w * y), outcomes of minimization
+    Minimizing an objective is supported by passing a negative
+    weight for that objective. To make all w * y's have the same sign
+    such that they are comparable when computing max(w * y), outcomes of minimization
     objectives are shifted from [0,1] to [-1,0].
 
     See [Knowles2005]_ for details.
@@ -61,6 +63,9 @@
         >>> weights = torch.tensor([0.75, -0.25])
        >>> transform = get_chebyshev_scalarization(weights, Y)
     """
+    # the chebyshev_obj assumes all objectives should be minimized, so
+    # multiply Y by -1
+    Y = -Y
     if weights.shape != Y.shape[-1:]:
         raise BotorchTensorDimensionError(
             "weights must be an `m`-dim tensor where Y is `... x m`."
@@ -71,11 +76,24 @@ def chebyshev_obj(Y: Tensor, X: Optional[Tensor] = None) -> Tensor:
         product = weights * Y
-        return product.min(dim=-1).values + alpha * product.sum(dim=-1)
+        return product.max(dim=-1).values + alpha * product.sum(dim=-1)
 
+    # A boolean mask indicating if minimizing an objective
+    minimize = weights < 0
     if Y.shape[-2] == 0:
+        if minimize.any():
+            raise UnsupportedError(
+                "negative weights (for minimization) are only supported if "
+                "Y is provided."
+            )
         # If there are no observations, we do not need to normalize the objectives
-        return chebyshev_obj
+
+        def obj(Y: Tensor, X: Optional[Tensor] = None) -> Tensor:
+            # multiply the scalarization by -1, so that the scalarization should
+            # be maximized
+            return -chebyshev_obj(Y=-Y)
+
+        return obj
     if Y.shape[-2] == 1:
         # If there is only one observation, set the bounds to be
         # [min(Y_m), min(Y_m) + 1] for each objective m. This ensures we do not
@@ -85,15 +103,14 @@ def chebyshev_obj(Y: Tensor, X: Optional[Tensor] = None) -> Tensor:
         # Set the bounds to be [min(Y_m), max(Y_m)], for each objective m
         Y_bounds = torch.stack([Y.min(dim=-2).values, Y.max(dim=-2).values])
 
-    # A boolean mask indicating if minimizing an objective
-    minimize = weights < 0
-
     def obj(Y: Tensor, X: Optional[Tensor] = None) -> Tensor:
         # scale to [0,1]
-        Y_normalized = normalize(Y, bounds=Y_bounds)
+        Y_normalized = normalize(-Y, bounds=Y_bounds)
         # If minimizing an objective, convert Y_normalized values to [-1,0],
-        # such that min(w*y) makes sense, we want all w*y's to be positive
+        # such that max(w*y) makes sense; we want all w*y's to be positive
         Y_normalized[..., minimize] = Y_normalized[..., minimize] - 1
-        return chebyshev_obj(Y=Y_normalized)
+        # multiply the scalarization by -1, so that the scalarization should
+        # be maximized
+        return -chebyshev_obj(Y=Y_normalized)
 
     return obj
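
For concreteness, the behavior the patched helper implements can be exercised end to end. The following is a minimal sketch, not part of the patch, assuming the patched BoTorch is installed; `weights`, `Y_test`, and `alpha` are illustrative values. With an empty `Y` (and nonnegative weights), no normalization takes place, so the returned callable computes exactly -g(-y):

import torch
from botorch.utils.multi_objective.scalarization import get_chebyshev_scalarization

alpha = 0.05  # default augmentation weight of get_chebyshev_scalarization
weights = torch.tensor([0.75, 0.25])  # both objectives maximized
Y_test = torch.rand(5, 2)

# empty Y -> the no-normalization branch of the patched helper
scalarize = get_chebyshev_scalarization(weights=weights, Y=Y_test[:0])

# direct evaluation of -g(-y), with g(y) = max(w * y) + alpha * sum(w * y)
product = weights * (-Y_test)
expected = -(product.max(dim=-1).values + alpha * product.sum(dim=-1))

assert torch.equal(scalarize(Y_test), expected)

Because the same tensor operations run in the same order, the comparison holds with exact equality, mirroring the `torch.equal` assertions in the updated tests below.
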
diff --git a/test/utils/multi_objective/test_scalarization.py b/test/utils/multi_objective/test_scalarization.py
index 63341d2f3e..97c6ee0819 100644
--- a/test/utils/multi_objective/test_scalarization.py
+++ b/test/utils/multi_objective/test_scalarization.py
@@ -7,7 +7,7 @@
 from __future__ import annotations
 
 import torch
-from botorch.exceptions.errors import BotorchTensorDimensionError
+from botorch.exceptions.errors import BotorchTensorDimensionError, UnsupportedError
 from botorch.utils.multi_objective.scalarization import get_chebyshev_scalarization
 from botorch.utils.testing import BotorchTestCase
 from botorch.utils.transforms import normalize
@@ -17,10 +17,11 @@ class TestGetChebyshevScalarization(BotorchTestCase):
     def test_get_chebyshev_scalarization(self):
         tkwargs = {"device": self.device}
         Y_train = torch.rand(4, 2, **tkwargs)
-        Y_bounds = torch.stack(
+        neg_Y_train = -Y_train
+        neg_Y_bounds = torch.stack(
             [
-                Y_train.min(dim=-2, keepdim=True).values,
-                Y_train.max(dim=-2, keepdim=True).values,
+                neg_Y_train.min(dim=-2, keepdim=True).values,
+                neg_Y_train.max(dim=-2, keepdim=True).values,
             ],
             dim=0,
         )
@@ -28,9 +29,10 @@ def test_get_chebyshev_scalarization(self):
             for batch_shape in (torch.Size([]), torch.Size([3])):
                 tkwargs["dtype"] = dtype
                 Y_test = torch.rand(batch_shape + torch.Size([5, 2]), **tkwargs)
+                neg_Y_test = -Y_test
                 Y_train = Y_train.to(**tkwargs)
-                Y_bounds = Y_bounds.to(**tkwargs)
-                normalized_Y_test = normalize(Y_test, Y_bounds)
+                neg_Y_bounds = neg_Y_bounds.to(**tkwargs)
+                normalized_neg_Y_test = normalize(neg_Y_test, neg_Y_bounds)
                 # test wrong shape
                 with self.assertRaises(BotorchTensorDimensionError):
                     get_chebyshev_scalarization(
@@ -45,18 +47,27 @@ def test_get_chebyshev_scalarization(self):
                     weights=weights, Y=Y_train
                 )
                 Y_transformed = objective_transform(Y_test)
-                expected_Y_transformed = normalized_Y_test.min(
-                    dim=-1
-                ).values + 0.05 * normalized_Y_test.sum(dim=-1)
+                expected_Y_transformed = -(
+                    normalized_neg_Y_test.max(dim=-1).values
+                    + 0.05 * normalized_neg_Y_test.sum(dim=-1)
+                )
                 self.assertTrue(torch.equal(Y_transformed, expected_Y_transformed))
+                # check that using negative objectives and negative weights
+                # yields an equivalent scalarized outcome
+                objective_transform2 = get_chebyshev_scalarization(
+                    weights=-weights, Y=-Y_train
+                )
+                Y_transformed2 = objective_transform2(-Y_test)
+                self.assertAllClose(Y_transformed, Y_transformed2)
                 # test different alpha
                 objective_transform = get_chebyshev_scalarization(
                     weights=weights, Y=Y_train, alpha=1.0
                 )
                 Y_transformed = objective_transform(Y_test)
-                expected_Y_transformed = normalized_Y_test.min(
-                    dim=-1
-                ).values + normalized_Y_test.sum(dim=-1)
+                expected_Y_transformed = -(
+                    normalized_neg_Y_test.max(dim=-1).values
+                    + normalized_neg_Y_test.sum(dim=-1)
+                )
                 self.assertTrue(torch.equal(Y_transformed, expected_Y_transformed))
                 # Test different weights
                 weights = torch.tensor([0.3, 0.7], **tkwargs)
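
The equivalence check added in the hunk above relies on a sign-flip symmetry: negating both the weights and the outcomes swaps each objective between maximization and minimization, and the two scalarizations agree up to floating-point error (hence `assertAllClose` rather than `torch.equal`). A minimal sketch of that identity, with illustrative tensors, assuming the patched BoTorch:

import torch
from botorch.utils.multi_objective.scalarization import get_chebyshev_scalarization

torch.manual_seed(0)
Y_train = torch.rand(4, 2)  # illustrative observed outcomes
Y_test = torch.rand(5, 2)
weights = torch.tensor([0.75, -0.25])  # maximize objective 0, minimize objective 1

# scalarize outcomes directly ...
s1 = get_chebyshev_scalarization(weights=weights, Y=Y_train)(Y_test)
# ... and scalarize the negated outcomes with negated weights
s2 = get_chebyshev_scalarization(weights=-weights, Y=-Y_train)(-Y_test)

assert torch.allclose(s1, s2)
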
@@ -64,9 +75,10 @@ def test_get_chebyshev_scalarization(self):
                     weights=weights, Y=Y_train
                 )
                 Y_transformed = objective_transform(Y_test)
-                expected_Y_transformed = (weights * normalized_Y_test).min(
-                    dim=-1
-                ).values + 0.05 * (weights * normalized_Y_test).sum(dim=-1)
+                expected_Y_transformed = -(
+                    (weights * normalized_neg_Y_test).max(dim=-1).values
+                    + 0.05 * (weights * normalized_neg_Y_test).sum(dim=-1)
+                )
                 self.assertTrue(torch.equal(Y_transformed, expected_Y_transformed))
                 # test that when minimizing an objective (i.e. with a negative weight),
                 # normalized Y values are shifted from [0,1] to [-1,0]
@@ -75,10 +87,11 @@ def test_get_chebyshev_scalarization(self):
                     weights=weights, Y=Y_train
                 )
                 Y_transformed = objective_transform(Y_test)
-                normalized_Y_test[..., -1] = normalized_Y_test[..., -1] - 1
-                expected_Y_transformed = (weights * normalized_Y_test).min(
-                    dim=-1
-                ).values + 0.05 * (weights * normalized_Y_test).sum(dim=-1)
+                normalized_neg_Y_test[..., -1] = normalized_neg_Y_test[..., -1] - 1
+                expected_Y_transformed = -(
+                    (weights * normalized_neg_Y_test).max(dim=-1).values
+                    + 0.05 * (weights * normalized_neg_Y_test).sum(dim=-1)
+                )
                 self.assertTrue(torch.equal(Y_transformed, expected_Y_transformed))
                 # test that with no observations there is no normalization
                 weights = torch.tensor([0.3, 0.7], **tkwargs)
@@ -86,18 +99,24 @@ def test_get_chebyshev_scalarization(self):
                     weights=weights, Y=Y_train[:0]
                 )
                 Y_transformed = objective_transform(Y_test)
-                expected_Y_transformed = (weights * Y_test).min(
-                    dim=-1
-                ).values + 0.05 * (weights * Y_test).sum(dim=-1)
+                expected_Y_transformed = -(
+                    (weights * neg_Y_test).max(dim=-1).values
+                    + 0.05 * (weights * neg_Y_test).sum(dim=-1)
+                )
                 self.assertTrue(torch.equal(Y_transformed, expected_Y_transformed))
-                # test that with one observation, we normalize by subtracting Y_train
+                # test that error is raised with negative weights and empty Y
+                with self.assertRaises(UnsupportedError):
+                    get_chebyshev_scalarization(weights=-weights, Y=Y_train[:0])
+                # test that with one observation, we normalize by subtracting
+                # neg_Y_train
                 single_Y_train = Y_train[:1]
                 objective_transform = get_chebyshev_scalarization(
                     weights=weights, Y=single_Y_train
                 )
                 Y_transformed = objective_transform(Y_test)
-                normalized_Y_test = Y_test - single_Y_train
-                expected_Y_transformed = (weights * normalized_Y_test).min(
-                    dim=-1
-                ).values + 0.05 * (weights * normalized_Y_test).sum(dim=-1)
+                normalized_neg_Y_test = neg_Y_test + single_Y_train
+                expected_Y_transformed = -(
+                    (weights * normalized_neg_Y_test).max(dim=-1).values
+                    + 0.05 * (weights * normalized_neg_Y_test).sum(dim=-1)
+                )
                 self.assertAllClose(Y_transformed, expected_Y_transformed)
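
Downstream, the returned callable is typically wrapped as a Monte Carlo objective, e.g. in a ParEGO-style loop with randomly drawn scalarization weights. A short usage sketch with illustrative data; `GenericMCObjective` and `sample_simplex` are existing BoTorch utilities not touched by this patch:

import torch
from botorch.acquisition.objective import GenericMCObjective
from botorch.utils.multi_objective.scalarization import get_chebyshev_scalarization
from botorch.utils.sampling import sample_simplex

train_Y = torch.rand(10, 2)  # illustrative observed outcomes (all maximized)

# draw random positive weights on the simplex, as in ParEGO
weights = sample_simplex(d=2).squeeze()

# scalarization normalized with respect to the observed outcomes
scalarization = get_chebyshev_scalarization(weights=weights, Y=train_Y)

# wrap for use with Monte Carlo acquisition functions, which then
# maximize the scalarized objective
objective = GenericMCObjective(scalarization)

print(scalarization(train_Y).shape)  # torch.Size([10])
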