diff --git a/botorch/fit.py b/botorch/fit.py index bd35d1900b..bcbd75e060 100644 --- a/botorch/fit.py +++ b/botorch/fit.py @@ -10,95 +10,78 @@ import logging from contextlib import nullcontext -from re import compile, Pattern -from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Type, Union -from warnings import catch_warnings, simplefilter, warn, WarningMessage +from functools import partial +from itertools import filterfalse +from typing import Any, Callable, Dict, Iterable, Optional, Sequence, Tuple, Type, Union +from warnings import catch_warnings, simplefilter, warn, warn_explicit, WarningMessage from botorch.exceptions.errors import ModelFittingError, UnsupportedError from botorch.exceptions.warnings import BotorchWarning, OptimizationWarning +from botorch.models.approximate_gp import ApproximateGPyTorchModel from botorch.models.converter import batched_to_model_list, model_list_to_batched from botorch.models.fully_bayesian import SaasFullyBayesianSingleTaskGP from botorch.models.fully_bayesian_multitask import SaasFullyBayesianMultiTaskGP - from botorch.models.gpytorch import BatchedMultiOutputGPyTorchModel from botorch.models.model_list_gp_regression import ModelListGP -from botorch.optim.fit import fit_gpytorch_scipy +from botorch.optim.closures import get_loss_closure_with_grads +from botorch.optim.core import _LBFGSB_MAXITER_MAXFUN_REGEX +from botorch.optim.fit import fit_gpytorch_mll_scipy, fit_gpytorch_mll_torch from botorch.optim.utils import ( + _warning_handler_template, allclose_mll, + get_parameters, + sample_all_priors, +) +from botorch.settings import debug +from botorch.utils.context_managers import ( del_attribute_ctx, + module_rollback_ctx, parameter_rollback_ctx, requires_grad_ctx, - sample_all_priors, - state_rollback_ctx, - Tkwargs, + TensorCheckpoint, +) +from botorch.utils.dispatcher import ( + Dispatcher, + MDNotImplementedError, + type_bypassing_encoder, ) -from botorch.settings import debug -from botorch.utils.dispatcher import Dispatcher, MDNotImplementedError from gpytorch.likelihoods import Likelihood +from gpytorch.mlls._approximate_mll import _ApproximateMarginalLogLikelihood from gpytorch.mlls.marginal_log_likelihood import MarginalLogLikelihood from gpytorch.mlls.sum_marginal_log_likelihood import SumMarginalLogLikelihood from linear_operator.utils.errors import NotPSDError from pyro.infer.mcmc import MCMC, NUTS from torch import device, mean, Tensor +from torch.nn import Parameter +from torch.utils.data import DataLoader -OptimizerType = Callable[[MarginalLogLikelihood], Tuple[MarginalLogLikelihood, Any]] -DEFAULT_LOGGING_PATTERNS: Dict[int, Pattern] = { - logging.DEBUG: compile( # catch warning corresponding to `maxiter` and `maxfun` - "TOTAL NO. of (ITERATIONS REACHED LIMIT|f AND g EVALUATIONS EXCEEDS LIMIT)" - ) -} - - -def DEFAULT_WARNING_FILTER( - w: WarningMessage, - logging_patterns: Dict[int, Pattern] = DEFAULT_LOGGING_PATTERNS, -) -> bool: - r"""Default warning resolution policy: retry upon encountering an - OptimizationWarning that does not match any logging pattern. - - Args: - w: Candidate for filtering. - logging_patterns: Dictionary mapping logging levels to regular expressions. - Warning messages are compared against these expressions and matches are - awarded first-come-first-serve when iterating through the dictionary. - - Returns: - Boolean indicating whether the warning is unresolved. - """ - for level, pattern in logging_patterns.items(): - if pattern.search(str(w.message)): - logging.log(level, w.message) - return False - - # Rethrow OptimizationWarnings but mark them as resolved - if not issubclass(w.category, OptimizationWarning): - warn(w.message, w.category) - return False - - return True - - -# Dispatcher for `fit_gpytorch_mll` -def _type_bypassing_encoder(arg: Any) -> Type: - # Allow type variables to be passed as pre-encoded arguments - return arg if isinstance(arg, type) else type(arg) - -dispatcher = Dispatcher("fit_gpytorch_mll", encoder=_type_bypassing_encoder) +DEFAULT_WARNING_HANDLER = partial( + _warning_handler_template, + debug=lambda w: _LBFGSB_MAXITER_MAXFUN_REGEX.search(str(w.message)), + rethrow=lambda w: not issubclass(w.category, OptimizationWarning), +) +FitGPyTorchMLL = Dispatcher("fit_gpytorch_mll", encoder=type_bypassing_encoder) def fit_gpytorch_mll( mll: MarginalLogLikelihood, + closure: Optional[Callable[[], Tuple[Tensor, Sequence[Optional[Tensor]]]]] = None, optimizer: Optional[Callable] = None, - optimizer_kwargs: Optional[dict] = None, + closure_kwargs: Optional[Dict[str, Any]] = None, + optimizer_kwargs: Optional[Dict[str, Any]] = None, **kwargs: Any, ) -> MarginalLogLikelihood: r"""Clearing house for fitting models passed as GPyTorch MarginalLogLikelihoods. Args: mll: A GPyTorch MarginalLogLikelihood instance. + closure: Forward-backward closure for obtaining objective values and gradients. + Responsible for setting parameters' `grad` attributes. If no closure is + provided, one will be obtained by calling `get_loss_closure_with_grads`. optimizer: User specified optimization algorithm. When `optimizer is None`, this keyword argument is omitted when calling the dispatcher. + closure_kwargs: Keyword arguments passed when calling `closure`. optimizer_kwargs: A dictionary of keyword arguments passed when calling `optimizer`. **kwargs: Keyword arguments passed down through the dispatcher to @@ -111,10 +94,12 @@ def fit_gpytorch_mll( if optimizer is not None: # defer to per-method defaults kwargs["optimizer"] = optimizer - return dispatcher( + return FitGPyTorchMLL( mll, type(mll.likelihood), type(mll.model), + closure=closure, + closure_kwargs=closure_kwargs, optimizer_kwargs=optimizer_kwargs, **kwargs, ) @@ -122,7 +107,7 @@ def fit_gpytorch_mll( def fit_gpytorch_model( mll: MarginalLogLikelihood, - optimizer: Optional[OptimizerType] = None, + optimizer: Optional[Callable] = None, optimizer_kwargs: Optional[dict] = None, exclude: Optional[Iterable[str]] = None, max_retries: Optional[int] = None, @@ -136,6 +121,7 @@ def fit_gpytorch_model( optimizer: User specified optimization algorithm. When `optimizer is None`, this keyword argument is omitted when calling the dispatcher from inside `fit_gpytorch_mll`. + optimizer_kwargs: Keyword arguments passed to `optimizer`. exclude: Legacy argument for specifying parameters `x` that should be held fixed during optimization. Internally, used to temporarily set `x.requires_grad` to False. @@ -151,7 +137,7 @@ def fit_gpytorch_model( kwargs["max_attempts"] = max_retries optimizer_kwargs = {} if optimizer_kwargs is None else optimizer_kwargs - for key in ("bounds", "options", "track_iterations", "approx_mll"): + for key in ("bounds", "options"): if key not in kwargs: continue @@ -179,16 +165,18 @@ def fit_gpytorch_model( return mll -@dispatcher.register(MarginalLogLikelihood, object, object) +@FitGPyTorchMLL.register(MarginalLogLikelihood, object, object) def _fit_fallback( mll: MarginalLogLikelihood, _: Type[object], __: Type[object], *, - optimizer: Optional[Callable] = fit_gpytorch_scipy, - optimizer_kwargs: Optional[dict] = None, + closure: Optional[Callable[[], Tuple[Tensor, Sequence[Optional[Tensor]]]]] = None, + optimizer: Optional[Callable] = fit_gpytorch_mll_scipy, + closure_kwargs: Optional[Dict[str, Any]] = None, + optimizer_kwargs: Optional[Dict[str, Any]] = None, max_attempts: int = 5, - warning_filter: Callable[[WarningMessage], bool] = DEFAULT_WARNING_FILTER, + warning_handler: Callable[[WarningMessage], bool] = DEFAULT_WARNING_HANDLER, caught_exception_types: Tuple[Type[BaseException], ...] = (NotPSDError,), **ignore: Any, ) -> MarginalLogLikelihood: @@ -200,8 +188,12 @@ def _fit_fallback( by resampling tunable parameters. Args: + closure: Forward-backward closure for obtaining objective values and gradients. + Responsible for setting parameters' `grad` attributes. If no closure is + provided, one will be obtained by calling `get_loss_closure_with_grads`. optimizer: The underlying optimization algorithm to run. - optimizer_kwargs: Keyword arguments passed when calling `optimizer`. + closure_kwargs: Keyword arguments passed to `closure`. + optimizer_kwargs: Keyword arguments passed to `optimizer`. max_attempts: The maximum number of fit attempts allowed. The attempt budget is NOT shared between calls to this method. warning_filter: A function used to filter warnings produced when calling @@ -215,34 +207,46 @@ def _fit_fallback( The `mll` instance. If fitting succeeded, then `mll` will be in evaluation mode, i.e. `mll.training == False`. Otherwise, `mll` will be in training mode. """ - ckpt: Dict[str, Tuple[Tensor, Tkwargs]] = None # lazy CPU-based checkpoint - ckpt_nograd: Dict[str, Tuple[Tensor, Tkwargs]] = None # subset for fixed parameters + # Setup optimizer_kwargs = {} if optimizer_kwargs is None else optimizer_kwargs + params_nograd: Dict[str, Parameter] = None # pyre-ignore [9] + ckpt_nograd: Dict[str, TensorCheckpoint] = None # pyre-ignore [9] + ckpt: Dict[str, TensorCheckpoint] = None # pyre-ignore [9] + # Build closure mll.train() + if closure is None: + closure = get_loss_closure_with_grads( + mll, parameters=get_parameters(mll, requires_grad=True) + ) + if closure_kwargs is not None: + closure = partial(closure, **closure_kwargs) + + # Attempt to fit the model for attempt in range(1, 1 + max_attempts): - # Wrap with rollback contextmanager so each loop iteration reloads the original - # state_dict upon exiting (unless `ckpt` is cleared). - with state_rollback_ctx(mll, checkpoint=ckpt, device=device("cpu")) as ckpt: - if ckpt_nograd is None: - ckpt_nograd = { # reuse cached values from primary checkpoint - k: ckpt[k] for k, v in mll.named_parameters() if not v.requires_grad - } - - if attempt > 1: # maybe resample parameters that require gradients - with parameter_rollback_ctx(mll, checkpoint=ckpt_nograd): + # Wrap with rollback contextmanager so that each loop iteration reloads the + # original state_dict upon exiting (unless we clear `ckpt`). + with module_rollback_ctx(mll, checkpoint=ckpt, device=device("cpu")) as ckpt: + if attempt > 1: # resample free parameters + if params_nograd is None: + params_nograd = get_parameters(mll, requires_grad=False) + + if ckpt_nograd is None: # reuse primary checkpoint + ckpt_nograd = {name: ckpt[name] for name in params_nograd} + + with parameter_rollback_ctx(params_nograd, checkpoint=ckpt_nograd): sample_all_priors(mll.model) try: # Fit the model with catch_warnings(record=True) as warning_list, debug(True): simplefilter("always", category=OptimizationWarning) - mll, _ = optimizer(mll, **optimizer_kwargs) + optimizer(mll, closure=closure, **optimizer_kwargs) - # Resolve warning messages and determine whether or not to retry + # Resolved warnings and determine whether or not to retry done = True - for unresolved_warning in filter(warning_filter, warning_list): - warn(unresolved_warning.message, unresolved_warning.category) + for w in filterfalse(warning_handler, warning_list): + warn_explicit(str(w.message), w.category, w.filename, w.lineno) done = False if done: @@ -264,10 +268,14 @@ def _fit_fallback( f"{err}", ) - raise ModelFittingError("All attempts to fit the model have failed.") + msg = "All attempts to fit the model have failed." + if debug.off(): + msg = msg + " For more information, try enabling botorch.settings.debug mode." + + raise ModelFittingError(msg) -@dispatcher.register(SumMarginalLogLikelihood, Likelihood, ModelListGP) +@FitGPyTorchMLL.register(SumMarginalLogLikelihood, object, ModelListGP) def _fit_list( mll: SumMarginalLogLikelihood, _: Type[Likelihood], @@ -291,18 +299,26 @@ def _fit_list( return mll.eval() if not any(sub_mll.training for sub_mll in mll.mlls) else mll -@dispatcher.register(MarginalLogLikelihood, Likelihood, BatchedMultiOutputGPyTorchModel) +@FitGPyTorchMLL.register( + (MarginalLogLikelihood, _ApproximateMarginalLogLikelihood), + object, + BatchedMultiOutputGPyTorchModel, +) def _fit_multioutput_independent( mll: MarginalLogLikelihood, _: Type[Likelihood], __: Type[BatchedMultiOutputGPyTorchModel], *, + closure: Optional[Callable[[], Tuple[Tensor, Sequence[Optional[Tensor]]]]] = None, sequential: bool = True, **kwargs: Any, ) -> MarginalLogLikelihood: r"""Fitting routine for multioutput Gaussian processes. Args: + closure: Forward-backward closure for obtaining objective values and gradients. + Responsible for setting parameters' `grad` attributes. If no closure is + provided, one will be obtained by calling `get_loss_closure_with_grads`. sequential: Boolean specifying whether or not to an attempt should be made to fit the model as a collection of independent GPs. Only relevant for certain types of GPs with independent outputs, see `batched_to_model_list`. @@ -314,6 +330,7 @@ def _fit_multioutput_independent( """ if ( # incompatible models not sequential + or closure is not None or mll.model.num_outputs == 1 or mll.likelihood is not getattr(mll.model, "likelihood", None) ): @@ -340,7 +357,7 @@ def _fit_multioutput_independent( # Repackage submodels and copy over state_dict repacked_model = model_list_to_batched(unpacked_mll.model.train()) repacked_mll = type(mll)(repacked_model.likelihood, repacked_model) - with state_rollback_ctx(mll, device=device("cpu")) as ckpt: + with module_rollback_ctx(mll, device=device("cpu")) as ckpt: mll.load_state_dict(repacked_mll.state_dict()) if not allclose_mll(a=mll, b=repacked_mll): raise RuntimeError( # validate model repacking @@ -359,6 +376,55 @@ def _fit_multioutput_independent( raise MDNotImplementedError +@FitGPyTorchMLL.register(_ApproximateMarginalLogLikelihood, object, object) +def _fit_fallback_approximate( + mll: _ApproximateMarginalLogLikelihood, + _: Type[Likelihood], + __: Type[ApproximateGPyTorchModel], + *, + closure: Optional[Callable[[], Tuple[Tensor, Sequence[Optional[Tensor]]]]] = None, + data_loader: Optional[DataLoader] = None, + optimizer: Optional[Callable] = None, + full_batch_limit: int = 1024, # TODO: To be determined. + **kwargs: Any, +) -> _ApproximateMarginalLogLikelihood: + r"""Fallback method for fitting approximate Gaussian processes. + + Args: + closure: Forward-backward closure for obtaining objective values and gradients. + Responsible for setting parameters' `grad` attributes. If no closure is + provided, one will be obtained by calling `get_loss_closure_with_grads`. + optimizer: The underlying optimization algorithm to run. Default to + `fit_gpytorch_mll_scipy` when `closure=None` and the model's internal + training set has no more than `full_batch_cutoff` observations; otherwise, + defaults to `fit_gpytorch_mll_torch`. + data_loader: An optional DataLoader to pass to `get_loss_closure_with_grads`. + May only be provided when `closure=None`. + full_batch_limit: Threshold for determining the default choice of `optimizer` + when `closure=None`. + **kwargs: Keyword arguments passed to `_fit_fallback`. + """ + if data_loader is not None: + if closure is not None: + raise UnsupportedError( + "Only one of `data_loader` or `closure` may be passed." + ) + closure = get_loss_closure_with_grads( + mll=mll, + data_loader=data_loader, + parameters=get_parameters(mll, requires_grad=True), + ) + + if optimizer is None: + optimizer = ( + fit_gpytorch_mll_scipy + if closure is None and len(mll.model.train_targets) <= full_batch_limit + else fit_gpytorch_mll_torch + ) + + return _fit_fallback(mll, _, __, closure=closure, optimizer=optimizer, **kwargs) + + def fit_fully_bayesian_model_nuts( model: Union[SaasFullyBayesianSingleTaskGP, SaasFullyBayesianMultiTaskGP], max_tree_depth: int = 6, diff --git a/botorch/models/pairwise_gp.py b/botorch/models/pairwise_gp.py index 4c726bfb34..5be462451f 100644 --- a/botorch/models/pairwise_gp.py +++ b/botorch/models/pairwise_gp.py @@ -818,7 +818,7 @@ def forward(self, datapoints: Tensor) -> MultivariateNormal: # self.utility might be None if exception was raised and _update # was failed to be called during hyperparameter optimization - # procedures (e.g., fit_gpytorch_scipy) + # procedures (e.g., fit_gpytorch_mll_scipy) if self.utility is None: self._update(transformed_dp) diff --git a/botorch/optim/__init__.py b/botorch/optim/__init__.py index a1d914582a..540752d1e0 100644 --- a/botorch/optim/__init__.py +++ b/botorch/optim/__init__.py @@ -4,6 +4,17 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from botorch.optim.closures import ( + ForwardBackwardClosure, + get_loss_closure, + get_loss_closure_with_grads, +) +from botorch.optim.core import ( + OptimizationResult, + OptimizationStatus, + scipy_minimize, + torch_minimize, +) from botorch.optim.initializers import initialize_q_batch, initialize_q_batch_nonneg from botorch.optim.numpy_converter import module_to_array, set_params_with_array from botorch.optim.optimize import ( @@ -18,15 +29,22 @@ __all__ = [ + "ForwardBackwardClosure", + "get_loss_closure", + "get_loss_closure_with_grads", "gen_batch_initial_conditions", "initialize_q_batch", "initialize_q_batch_nonneg", + "OptimizationResult", + "OptimizationStatus", "optimize_acqf", "optimize_acqf_cyclic", "optimize_acqf_discrete", "optimize_acqf_discrete_local_search", "optimize_acqf_mixed", "module_to_array", + "scipy_minimize", "set_params_with_array", + "torch_minimize", "ExpMAStoppingCriterion", ] diff --git a/botorch/optim/closures/__init__.py b/botorch/optim/closures/__init__.py new file mode 100644 index 0000000000..e9a63b4235 --- /dev/null +++ b/botorch/optim/closures/__init__.py @@ -0,0 +1,22 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from botorch.optim.closures.core import ( + ForwardBackwardClosure, + NdarrayOptimizationClosure, +) +from botorch.optim.closures.model_closures import ( + get_loss_closure, + get_loss_closure_with_grads, +) + + +__all__ = [ + "ForwardBackwardClosure", + "get_loss_closure", + "get_loss_closure_with_grads", + "NdarrayOptimizationClosure", +] diff --git a/botorch/optim/closures/core.py b/botorch/optim/closures/core.py new file mode 100644 index 0000000000..d2363ad1fb --- /dev/null +++ b/botorch/optim/closures/core.py @@ -0,0 +1,184 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +"""Core methods for building closures in torch and interfacing with numpy.""" + +from __future__ import annotations + +from functools import partial +from typing import Any, Callable, Dict, Optional, Sequence, Tuple + +import torch +from botorch.optim.utils import ( + _handle_numerical_errors, + get_tensors_as_ndarray_1d, + set_tensors_from_ndarray_1d, +) +from botorch.optim.utils.numpy_utils import as_ndarray +from botorch.utils.context_managers import zero_grad_ctx +from numpy import float64 as np_float64, full as np_full, ndarray, zeros as np_zeros +from torch import Tensor + + +class ForwardBackwardClosure: + r"""Wrapper for fused forward and backward closures.""" + + def __init__( + self, + forward: Callable[[], Tensor], + parameters: Dict[str, Tensor], + backward: Callable[[Tensor], None] = Tensor.backward, + reducer: Optional[Callable[[Tensor], Tensor]] = torch.sum, + callback: Optional[Callable[[Tensor, Sequence[Optional[Tensor]]], None]] = None, + context_manager: Callable = None, # pyre-ignore [9] + ) -> None: + r"""Initializes a ForwardBackwardClosure instance. + + Args: + closure: Callable that returns a tensor. + parameters: A dictionary of tensors whose `grad` fields are to be returned. + backward: Callable that takes the (reduced) output of `forward` and sets the + `grad` attributes of tensors in `parameters`. + reducer: Optional callable used to reduce the output of the forward pass. + callback: Optional callable that takes the reduced output of `forward` and + the gradients of `parameters` as positional arguments. + context_manager: A ContextManager used to wrap each forward-backward call. + When passed as `None`, `context_manager` defaults to a `zero_grad_ctx` + that zeroes the gradients of `parameters` upon entry. + """ + if context_manager is None: + context_manager = partial(zero_grad_ctx, parameters) + + self.forward = forward + self.backward = backward + self.parameters = parameters + self.reducer = reducer + self.callback = callback + self.context_manager = context_manager + + def __call__(self, **kwargs: Any) -> Tuple[Tensor, Tuple[Optional[Tensor], ...]]: + with self.context_manager(): + values = self.forward(**kwargs) + value = values if self.reducer is None else self.reducer(values) + self.backward(value) + + grads = tuple(param.grad for param in self.parameters.values()) + if self.callback: + self.callback(value, grads) + + return value, grads + + +class NdarrayOptimizationClosure: + r"""Adds stateful behavior and a numpy.ndarray-typed API to a closure with an + expected return type Tuple[Tensor, Union[Tensor, Sequence[Optional[Tensor]]]].""" + + def __init__( + self, + closure: Callable[[], Tuple[Tensor, Sequence[Optional[Tensor]]]], + parameters: Dict[str, Tensor], + as_array: Callable[[Tensor], ndarray] = None, # pyre-ignore [9] + as_tensor: Callable[[ndarray], Tensor] = torch.as_tensor, + get_state: Callable[[], ndarray] = None, # pyre-ignore [9] + set_state: Callable[[ndarray], None] = None, # pyre-ignore [9] + fill_value: float = 0.0, + persistent: bool = True, + ) -> None: + r"""Initializes a NdarrayOptimizationClosure instance. + + Args: + closure: A ForwardBackwardClosure instance. + parameters: A dictionary of tensors representing the closure's state. + Expected to correspond with the first `len(parameters)` optional + gradient tensors returned by `closure`. + as_array: Callable used to convert tensors to ndarrays. + as_tensor: Callable used to convert ndarrays to tensors. + get_state: Callable that returns the closure's state as an ndarray. When + passed as `None`, defaults to calling `get_tensors_as_ndarray_1d` + on `closure.parameters` while passing `as_array` (if given by the user). + set_state: Callable that takes a 1-dimensional ndarray and sets the + closure's state. When passed as `None`, `set_state` defaults to + calling `set_tensors_from_ndarray_1d` with `closure.parameters` and + a given ndarray while passing `as_tensor`. + fill_value: Fill value for parameters whose gradients are None. In most + cases, `fill_value` should either be zero or NaN. + persistent: Boolean specifying whether an ndarray should be retained + as a persistent buffer for gradients. + """ + if get_state is None: + # Note: Numpy supports copying data between ndarrays with different dtypes. + # Hence, our default behavior need not coerce the ndarray represenations of + # tensors in `parameters` to float64 when copying over data. + _as_array = as_ndarray if as_array is None else as_array + get_state = partial( + get_tensors_as_ndarray_1d, parameters, as_array=_as_array + ) + + if as_array is None: # per the note, do this after resolving `get_state` + as_array = partial(as_ndarray, dtype=np_float64) + + if set_state is None: + set_state = partial( + set_tensors_from_ndarray_1d, parameters, as_tensor=as_tensor + ) + + self.closure = closure + self.parameters = parameters + + self.as_array = as_ndarray + self.as_tensor = as_tensor + self._get_state = get_state + self._set_state = set_state + + self.fill_value = fill_value + self.persistent = persistent + self._gradient_ndarray: Optional[ndarray] = None + + def __call__( + self, state: Optional[ndarray] = None, **kwargs: Any + ) -> Tuple[ndarray, ndarray]: + if state is not None: + self.state = state + + try: + value_tensor, grad_tensors = self.closure(**kwargs) + value = self.as_array(value_tensor) + grads = self._get_gradient_ndarray(fill_value=self.fill_value) + index = 0 + for param, grad in zip(self.parameters.values(), grad_tensors): + size = param.numel() + if grad is not None: + grads[index : index + size] = self.as_array(grad.view(-1)) + index += size + except RuntimeError as e: + value, grads = _handle_numerical_errors(error=e, x=self.state) + + return value, grads + + @property + def state(self) -> ndarray: + return self._get_state() + + @state.setter + def state(self, state: ndarray) -> None: + self._set_state(state) + + def _get_gradient_ndarray(self, fill_value: Optional[float] = None) -> ndarray: + if self.persistent and self._gradient_ndarray is not None: + if fill_value is not None: + self._gradient_ndarray.fill(fill_value) + return self._gradient_ndarray + + size = sum(param.numel() for param in self.parameters.values()) + array = ( + np_zeros(size) + if fill_value is None or fill_value == 0.0 + else np_full(size, fill_value) + ) + if self.persistent: + self._gradient_ndarray = array + + return array diff --git a/botorch/optim/closures/model_closures.py b/botorch/optim/closures/model_closures.py new file mode 100644 index 0000000000..8e4c39a0f2 --- /dev/null +++ b/botorch/optim/closures/model_closures.py @@ -0,0 +1,200 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +r"""Utilities for building model-based closures.""" + +from __future__ import annotations + +from itertools import chain, repeat +from typing import Any, Callable, Dict, Optional, Sequence, Tuple + +from botorch.optim.closures.core import ForwardBackwardClosure +from botorch.optim.utils import TNone +from botorch.utils.dispatcher import Dispatcher, type_bypassing_encoder +from gpytorch.mlls import ( + ExactMarginalLogLikelihood, + MarginalLogLikelihood, + SumMarginalLogLikelihood, +) +from torch import Tensor +from torch.utils.data import DataLoader + +GetLossClosure = Dispatcher("get_loss_closure", encoder=type_bypassing_encoder) +GetLossClosureWithGrads = Dispatcher( + "get_loss_closure_with_grads", encoder=type_bypassing_encoder +) + + +def get_loss_closure( + mll: MarginalLogLikelihood, + data_loader: Optional[DataLoader] = None, + **kwargs: Any, +) -> Callable[[], Tensor]: + r"""Public API for GetLossClosure dispatcher. + + This method, and the dispatcher that powers it, acts as a clearing house + for factory functions that define how `mll` is evaluated. + + Users may specify custom evaluation routines by registering a factory function + with GetLossClosure. These factories should be registered using the type signature + + `Type[MarginalLogLikeLihood], Type[Likelihood], Type[Model], Type[DataLoader]`. + + The final argument, Type[DataLoader], is optional. Evaluation routines that obtain + training data from, e.g., `mll.model` should register this argument as `type(None)`. + + Args: + mll: A MarginalLogLikelihood instance whose negative defines the loss. + data_loader: An optional DataLoader instance for cases where training + data is passed in rather than obtained from `mll.model`. + + Returns: + A closure that takes zero positional arguments and returns the negated + value of `mll`. + """ + return GetLossClosure( + mll, type(mll.likelihood), type(mll.model), data_loader, **kwargs + ) + + +def get_loss_closure_with_grads( + mll: MarginalLogLikelihood, + parameters: Dict[str, Tensor], + data_loader: Optional[DataLoader] = None, + backward: Callable[[Tensor], None] = Tensor.backward, + reducer: Optional[Callable[[Tensor], Tensor]] = Tensor.sum, + context_manager: Optional[Callable] = None, + **kwargs: Any, +) -> Callable[[], Tuple[Tensor, Tuple[Tensor, ...]]]: + r"""Public API for GetLossClosureWithGrads dispatcher. + + In most cases, this method simply adds a backward pass to a loss closure obtained by + calling `get_loss_closure`. For further details, see `get_loss_closure`. + + Args: + mll: A MarginalLogLikelihood instance whose negative defines the loss. + parameters: A dictionary of tensors whose `grad` fields are to be returned. + reducer: Optional callable used to reduce the output of the forward pass. + data_loader: An optional DataLoader instance for cases where training + data is passed in rather than obtained from `mll.model`. + context_manager: An optional ContextManager used to wrap each forward-backward + pass. Defaults to a `zero_grad_ctx` that zeroes the gradients of + `parameters` upon entry. None may be passed as an alias for `nullcontext`. + + Returns: + A closure that takes zero positional arguments and returns the reduced and + negated value of `mll` along with the gradients of `parameters`. + """ + return GetLossClosureWithGrads( + mll, + type(mll.likelihood), + type(mll.model), + data_loader, + parameters=parameters, + reducer=reducer, + backward=backward, + context_manager=context_manager, + **kwargs, + ) + + +@GetLossClosureWithGrads.register(object, object, object, object) +def _get_loss_closure_with_grads_fallback( + mll: MarginalLogLikelihood, + _: object, + __: object, + data_loader: Optional[DataLoader], + parameters: Dict[str, Tensor], + reducer: Callable[[Tensor], Tensor] = Tensor.sum, + backward: Callable[[Tensor], None] = Tensor.backward, + context_manager: Callable = None, # pyre-ignore [9] + **kwargs: Any, +) -> ForwardBackwardClosure: + r"""Wraps a `loss_closure` with a ForwardBackwardClosure.""" + loss_closure = get_loss_closure(mll, data_loader=data_loader, **kwargs) + return ForwardBackwardClosure( + forward=loss_closure, + backward=backward, + parameters=parameters, + reducer=reducer, + context_manager=context_manager, + ) + + +@GetLossClosure.register(MarginalLogLikelihood, object, object, DataLoader) +def _get_loss_closure_fallback_external( + mll: MarginalLogLikelihood, + _: object, + __: object, + data_loader: DataLoader, + **ignore: Any, +) -> Callable[[], Tensor]: + r"""Fallback loss closure with externally provided data.""" + batch_generator = chain.from_iterable(iter(data_loader) for _ in repeat(None)) + + def closure(**kwargs: Any) -> Tensor: + batch = next(batch_generator) + if not isinstance(batch, Sequence): + raise TypeError( + "Expected `data_loader` to generate a batch of tensors, " + f"but found {type(batch)}." + ) + + num_inputs = len(mll.model.train_inputs) + model_output = mll.model(*batch[:num_inputs]) + log_likelihood = mll(model_output, *batch[num_inputs:], **kwargs) + return -log_likelihood + + return closure + + +@GetLossClosure.register(MarginalLogLikelihood, object, object, TNone) +def _get_loss_closure_fallback_internal( + mll: MarginalLogLikelihood, _: object, __: object, ___: TNone, **ignore: Any +) -> Callable[[], Tensor]: + r"""Fallback loss closure with internally managed data.""" + + def closure(**kwargs: Any) -> Tensor: + model_output = mll.model(*mll.model.train_inputs) + log_likelihood = mll(model_output, mll.model.train_targets, **kwargs) + return -log_likelihood + + return closure + + +@GetLossClosure.register(ExactMarginalLogLikelihood, object, object, TNone) +def _get_loss_closure_exact_internal( + mll: ExactMarginalLogLikelihood, _: object, __: object, ___: TNone, **ignore: Any +) -> Callable[[], Tensor]: + r"""ExactMarginalLogLikelihood loss closure with internally managed data.""" + + def closure(**kwargs: Any) -> Tensor: + model_output = mll.model(*mll.model.train_inputs) + log_likelihood = mll( + model_output, mll.model.train_targets, *mll.model.train_inputs, **kwargs + ) + return -log_likelihood + + return closure + + +@GetLossClosure.register(SumMarginalLogLikelihood, object, object, TNone) +def _get_loss_closure_sum_internal( + mll: SumMarginalLogLikelihood, _: object, __: object, ___: TNone, **ignore: Any +) -> Callable[[], Tensor]: + r"""SumMarginalLogLikelihood loss closure with internally managed data.""" + + def closure(**kwargs: Any) -> Tensor: + model_output = mll.model(*mll.model.train_inputs) + log_likelihood = mll( + model_output, + mll.model.train_targets, + *map(list, mll.model.train_inputs), + **kwargs, + ) + return -log_likelihood + + return closure diff --git a/botorch/optim/core.py b/botorch/optim/core.py new file mode 100644 index 0000000000..9110312fb3 --- /dev/null +++ b/botorch/optim/core.py @@ -0,0 +1,221 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +r"""Core abstractions and generic optimizers.""" + +from __future__ import annotations + +import re +from dataclasses import dataclass, replace +from enum import auto, Enum +from itertools import count +from sys import maxsize +from time import monotonic +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union + +from botorch.optim.closures import NdarrayOptimizationClosure +from botorch.optim.utils import get_bounds_as_ndarray +from numpy import asarray, ndarray +from scipy.optimize import minimize +from torch import Tensor +from torch.optim.adam import Adam +from torch.optim.optimizer import Optimizer + +try: + from torch.optim.lr_scheduler import LRScheduler +except ImportError: # pragma: no cover + from torch.optim.lr_scheduler import _LRScheduler as LRScheduler # pragma: no cover + + +_LBFGSB_MAXITER_MAXFUN_REGEX = re.compile( # regex for maxiter and maxfun messages + "TOTAL NO. of (ITERATIONS REACHED LIMIT|f AND g EVALUATIONS EXCEEDS LIMIT)" +) + + +class OptimizationStatus(int, Enum): + RUNNING = auto() # incomplete + SUCCESS = auto() # optimizer converged + FAILURE = auto() # terminated abnormally + STOPPED = auto() # stopped due to user provided criterion + + +@dataclass +class OptimizationResult: + step: int + fval: Union[float, int] + status: OptimizationStatus + runtime: Optional[float] = None + message: Optional[str] = None + + +def scipy_minimize( + closure: Union[ + Callable[[], Tuple[Tensor, Sequence[Optional[Tensor]]]], + NdarrayOptimizationClosure, + ], + parameters: Dict[str, Tensor], + bounds: Optional[Dict[str, Tuple[Optional[float], Optional[float]]]] = None, + callback: Optional[Callable[[Dict[str, Tensor], OptimizationResult], None]] = None, + x0: Optional[ndarray] = None, + method: str = "L-BFGS-B", + options: Optional[Dict[str, Any]] = None, +) -> OptimizationResult: + r"""Generic scipy.optimize.minimize-based optimization routine. + + Args: + closure: Callable that returns a tensor and an iterable of gradient tensors or + NdarrayOptimizationClosure instance. + parameters: A dictionary of tensors to be optimized. + bounds: A dictionary mapping parameter names to lower and upper bounds. + callback: A callable taking `parameters` and an OptimizationResult as arguments. + x0: An optional initialization vector passed to scipy.optimize.minimize. + method: Solver type, passed along to scipy.minimize. + options: Dictionary of solver options, passed along to scipy.minimize. + + Returns: + An OptimizationResult summarizing the final state of the run. + """ + start_time = monotonic() + wrapped_closure = ( + closure + if isinstance(closure, NdarrayOptimizationClosure) + else NdarrayOptimizationClosure(closure, parameters) + ) + if bounds is None: + bounds_np = None + else: + bounds_np = get_bounds_as_ndarray(parameters, bounds) + + if callback is None: + wrapped_callback = None + else: + call_counter = count(1) # callbacks are typically made at the end of each iter + + def wrapped_callback(x: ndarray): + result = OptimizationResult( + step=next(call_counter), + fval=float(wrapped_closure(x)[0]), + status=OptimizationStatus.RUNNING, + runtime=monotonic() - start_time, + ) + return callback(parameters, result) # pyre-ignore [29] + + raw = minimize( + wrapped_closure, + wrapped_closure.state if x0 is None else x0, + jac=True, + bounds=bounds_np, + method=method, + options=options, + callback=wrapped_callback, + ) + + # Post-processing and outcome handling + wrapped_closure.state = asarray(raw.x) # set parameter state to optimal values + msg = raw.message if isinstance(raw.message, str) else raw.message.decode("ascii") + if raw.success: + status = OptimizationStatus.SUCCESS + else: + status = ( # Check whether we stopped due to reaching maxfun or maxiter + OptimizationStatus.STOPPED + if _LBFGSB_MAXITER_MAXFUN_REGEX.search(msg) + else OptimizationStatus.FAILURE + ) + + return OptimizationResult( + fval=raw.fun, + step=raw.nit, + status=status, + message=msg, + runtime=monotonic() - start_time, + ) + + +def torch_minimize( + closure: Callable[[], Tuple[Tensor, Sequence[Optional[Tensor]]]], + parameters: Dict[str, Tensor], + bounds: Optional[Dict[str, Tuple[Optional[float], Optional[float]]]] = None, + callback: Optional[Callable[[Dict[str, Tensor], OptimizationResult], None]] = None, + optimizer: Union[Optimizer, Callable[[List[Tensor]], Optimizer]] = Adam, + scheduler: Optional[Union[LRScheduler, Callable[[Optimizer], LRScheduler]]] = None, + step_limit: Optional[int] = None, + stopping_criterion: Optional[Callable[[Tensor], bool]] = None, +) -> OptimizationResult: + r"""Generic torch.optim-based optimization routine. + + Args: + closure: Callable that returns a tensor and an iterable of gradient tensors. + Responsible for setting relevant parameters' `grad` attributes. + parameters: A dictionary of tensors to be optimized. + bounds: An optional dictionary of bounds for elements of `parameters`. + callback: A callable taking `parameters` and an OptimizationResult as arguments. + step_limit: Integer specifying a maximum number of optimization steps. + One of `step_limit` or `stopping_criterion` must be passed. + stopping_criterion: A StoppingCriterion for the optimization loop. + optimizer: A `torch.optim.Optimizer` instance or a factory that takes + a list of parameters and returns an `Optimizer` instance. + scheduler: A `torch.optim.lr_scheduler._LRScheduler` instance or a factory + that takes a `Optimizer` instance and returns a `_LRSchedule` instance. + + Returns: + An OptimizationResult summarizing the final state of the run. + """ + start_time = monotonic() + if step_limit is None: + if stopping_criterion is None: + raise RuntimeError("No termination conditions were given.") + step_limit = maxsize + + if not isinstance(optimizer, Optimizer): + optimizer = optimizer(list(parameters.values())) + + if not (scheduler is None or isinstance(scheduler, LRScheduler)): + scheduler = scheduler(optimizer) + + _bounds = ( + {} + if bounds is None + else {name: limits for name, limits in bounds.items() if name in parameters} + ) + result: OptimizationResult + for step in range(step_limit): + fval, _ = closure() + result = OptimizationResult( + step=step, + fval=fval.detach().cpu().item(), + status=OptimizationStatus.RUNNING, + runtime=monotonic() - start_time, + ) + + # TODO: Update stopping_criterion API to return a message. + if stopping_criterion and stopping_criterion(fval): + result.status = OptimizationStatus.STOPPED + result.message = "`torch_minimize` stopped due to `stopping_criterion`." + + if callback: + callback(parameters, result) + + if result.status != OptimizationStatus.RUNNING: + break + + optimizer.step() + for name, (lower, upper) in _bounds.items(): + parameters[name].data = parameters[name].clamp(min=lower, max=upper) + + if scheduler: + scheduler.step() + + if result.status != OptimizationStatus.RUNNING: + return replace(result, runtime=monotonic() - start_time) + + # Account for final parameter update when stopping due to step_limit + return OptimizationResult( + step=step + 1, + fval=closure()[0].detach().cpu().item(), + status=OptimizationStatus.STOPPED, + runtime=monotonic() - start_time, + message=f"`torch_minimize` stopped after reaching step_limit={step_limit}.", + ) diff --git a/botorch/optim/fit.py b/botorch/optim/fit.py index aded76727f..8f507a75bb 100644 --- a/botorch/optim/fit.py +++ b/botorch/optim/fit.py @@ -4,15 +4,12 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -r""" -Tools for model fitting. -""" +r"""Tools for model fitting.""" from __future__ import annotations -import warnings +from functools import partial from itertools import filterfalse -from re import Pattern from time import monotonic from typing import ( Any, @@ -20,56 +17,186 @@ Dict, Iterator, List, - NamedTuple, Optional, + Pattern, + Sequence, Set, Tuple, Union, ) +from warnings import warn -import numpy as np from botorch.exceptions.warnings import OptimizationWarning +from botorch.optim.closures import get_loss_closure_with_grads +from botorch.optim.core import ( + OptimizationResult, + OptimizationStatus, + scipy_minimize, + torch_minimize, +) from botorch.optim.numpy_converter import ( + _scipy_objective_and_grad, module_to_array, set_params_with_array, - TorchAttr, ) from botorch.optim.stopping import ExpMAStoppingCriterion from botorch.optim.utils import ( _filter_kwargs, _get_extra_mll_args, - _scipy_objective_and_grad, - create_name_filter, + DEFAULT, + get_name_filter, + get_parameters_and_bounds, + TorchAttr, ) -from gpytorch import settings as gpt_settings +from botorch.optim.utils.model_utils import get_parameters from gpytorch.mlls.marginal_log_likelihood import MarginalLogLikelihood +from gpytorch.settings import fast_computations +from numpy import ndarray from scipy.optimize import Bounds, minimize from torch import Tensor from torch.nn import Module from torch.optim.adam import Adam +from torch.optim.lr_scheduler import _LRScheduler from torch.optim.optimizer import Optimizer - -ParameterBounds = Dict[str, Tuple[Optional[float], Optional[float]]] +TBoundsDict = Dict[str, Tuple[Optional[float], Optional[float]]] TScipyObjective = Callable[ - [np.ndarray, MarginalLogLikelihood, Dict[str, TorchAttr]], Tuple[float, np.ndarray] + [ndarray, MarginalLogLikelihood, Dict[str, TorchAttr]], Tuple[float, ndarray] ] TModToArray = Callable[ - [Module, Optional[ParameterBounds], Optional[Set[str]]], - Tuple[np.ndarray, Dict[str, TorchAttr], Optional[np.ndarray]], + [Module, Optional[TBoundsDict], Optional[Set[str]]], + Tuple[ndarray, Dict[str, TorchAttr], Optional[ndarray]], ] -TArrayToMod = Callable[[Module, np.ndarray, Dict[str, TorchAttr]], Module] +TArrayToMod = Callable[[Module, ndarray, Dict[str, TorchAttr]], Module] + + +def fit_gpytorch_mll_scipy( + mll: MarginalLogLikelihood, + parameters: Optional[Dict[str, Tensor]] = None, + bounds: Optional[Dict[str, Tuple[Optional[float], Optional[float]]]] = None, + closure: Optional[Callable[[], Tuple[Tensor, Sequence[Optional[Tensor]]]]] = None, + closure_kwargs: Optional[Dict[str, Any]] = None, + method: str = "L-BFGS-B", + options: Optional[Dict[str, Any]] = None, + callback: Optional[Callable[[Dict[str, Tensor], OptimizationResult], None]] = None, +) -> OptimizationResult: + r"""Generic scipy.optimized-based fitting routine for GPyTorch MLLs. + + The model and likelihood in mll must already be in train mode. + + Args: + mll: MarginalLogLikelihood to be maximized. + parameters: Optional dictionary of parameters to be optimized. Defaults + to all parameters of `mll` that require gradients. + bounds: A dictionary of user-specified bounds for `parameters`. Used to update + default parameter bounds obtained from `mll`. + closure: Callable that returns a tensor and an iterable of gradient tensors. + Responsible for setting the `grad` attributes of `parameters`. If no closure + is provided, one will be obtained by calling `get_loss_closure_with_grads`. + closure_kwargs: Keyword arguments passed to `closure`. + method: Solver type, passed along to scipy.minimize. + options: Dictionary of solver options, passed along to scipy.minimize. + callback: Optional callback taking `parameters` and an OptimizationResult as its + sole arguments. + + Returns: + The final OptimizationResult. + """ + # Resolve `parameters` and update default bounds + _parameters, _bounds = get_parameters_and_bounds(mll) + bounds = _bounds if bounds is None else {**_bounds, **bounds} + if parameters is None: + parameters = {n: p for n, p in _parameters.items() if p.requires_grad} + + if closure is None: + closure = get_loss_closure_with_grads(mll, parameters=parameters) + + if closure_kwargs is not None: + closure = partial(closure, **closure_kwargs) + + result = scipy_minimize( + closure=closure, + parameters=parameters, + bounds=bounds, + method=method, + options=options, + callback=callback, + ) + if result.status != OptimizationStatus.SUCCESS: + warn( + f"`scipy_minimize` terminated with status {result.status}, displaying" + f" original message from `scipy.optimize.minimize`: {result.message}", + OptimizationWarning, + ) + + return result + + +def fit_gpytorch_mll_torch( + mll: MarginalLogLikelihood, + parameters: Optional[Dict[str, Tensor]] = None, + bounds: Optional[Dict[str, Tuple[Optional[float], Optional[float]]]] = None, + closure: Optional[Callable[[], Tuple[Tensor, Sequence[Optional[Tensor]]]]] = None, + closure_kwargs: Optional[Dict[str, Any]] = None, + step_limit: Optional[int] = None, + stopping_criterion: Optional[Callable[[Tensor], bool]] = DEFAULT, # pyre-ignore [9] + optimizer: Union[Optimizer, Callable[..., Optimizer]] = Adam, + scheduler: Optional[Union[_LRScheduler, Callable[..., _LRScheduler]]] = None, + callback: Optional[Callable[[Dict[str, Tensor], OptimizationResult], None]] = None, +) -> OptimizationResult: + r"""Generic torch.optim-based fitting routine for GPyTorch MLLs. + + Args: + mll: MarginalLogLikelihood to be maximized. + parameters: Optional dictionary of parameters to be optimized. Defaults + to all parameters of `mll` that require gradients. + bounds: A dictionary of user-specified bounds for `parameters`. Used to update + default parameter bounds obtained from `mll`. + closure: Callable that returns a tensor and an iterable of gradient tensors. + Responsible for setting the `grad` attributes of `parameters`. If no closure + is provided, one will be obtained by calling `get_loss_closure_with_grads`. + closure_kwargs: Keyword arguments passed to `closure`. + step_limit: Optional upper bound on the number of optimization steps. + stopping_criterion: A StoppingCriterion for the optimization loop. + optimizer: A `torch.optim.Optimizer` instance or a factory that takes + a list of parameters and returns an `Optimizer` instance. + scheduler: A `torch.optim.lr_scheduler._LRScheduler` instance or a factory + that takes an `Optimizer` instance and returns an `_LRSchedule`. + callback: Optional callback taking `parameters` and an OptimizationResult as its + sole arguments. + + Returns: + The final OptimizationResult. + """ + if stopping_criterion == DEFAULT: + stopping_criterion = ExpMAStoppingCriterion() + + # Resolve `parameters` and update default bounds + param_dict, bounds_dict = get_parameters_and_bounds(mll) + if parameters is None: + parameters = {n: p for n, p in param_dict.items() if p.requires_grad} + if closure is None: + closure = get_loss_closure_with_grads(mll, parameters) -class OptimizationIteration(NamedTuple): - itr: int - fun: float - time: float + if closure_kwargs is not None: + closure = partial(closure, **closure_kwargs) + + return torch_minimize( + closure=closure, + parameters=parameters, + bounds=bounds_dict if bounds is None else {**bounds_dict, **bounds}, + optimizer=optimizer, + scheduler=scheduler, + step_limit=step_limit, + stopping_criterion=stopping_criterion, + callback=callback, + ) def fit_gpytorch_scipy( mll: MarginalLogLikelihood, - bounds: Optional[ParameterBounds] = None, + bounds: Optional[Dict[str, Tuple[Optional[float], Optional[float]]]] = None, method: str = "L-BFGS-B", options: Optional[Dict[str, Any]] = None, track_iterations: bool = False, @@ -77,20 +204,19 @@ def fit_gpytorch_scipy( scipy_objective: TScipyObjective = _scipy_objective_and_grad, module_to_array_func: TModToArray = module_to_array, module_from_array_func: TArrayToMod = set_params_with_array, -) -> Tuple[MarginalLogLikelihood, Dict[str, Union[float, List[OptimizationIteration]]]]: - r"""Fit a gpytorch model by maximizing MLL with a scipy optimizer. + **kwargs: Any, +) -> Tuple[MarginalLogLikelihood, Dict[str, Union[float, List[OptimizationResult]]]]: + r"""Legacy method for scipy-based fitting of gpytorch models. - The model and likelihood in mll must already be in train mode. - This method requires that the model has `train_inputs` and `train_targets`. + The model and likelihood in mll must already be in train mode. This method requires + that the model has `train_inputs` and `train_targets`. Args: mll: MarginalLogLikelihood to be maximized. bounds: A dictionary mapping parameter names to tuples of lower and upper bounds. - method: Solver type, passed along to scipy.minimize. - options: Dictionary of solver options, passed along to scipy.minimize. - track_iterations: Track the function values and wall time for each - iteration. + method: Solver type, passed along to scipy.optimize.minimize. + options: Dictionary of solver options, passed along to scipy.optimize.minimize. approx_mll: If True, use gpytorch's approximate MLL computation. This is disabled by default since the stochasticity is an issue for determistic optimizers). Enabling this is only recommended when @@ -102,97 +228,93 @@ def fit_gpytorch_scipy( - Dictionary with the following key/values: "fopt": Best mll value. "wall_time": Wall time of fitting. - "iterations": List of OptimizationIteration objects with information on each + "iterations": List of OptimizationResult objects with information on each iteration. If track_iterations is False, will be empty. "OptimizeResult": The result returned by `scipy.optim.minimize`. - - Example: - >>> gp = SingleTaskGP(train_X, train_Y) - >>> mll = ExactMarginalLogLikelihood(gp.likelihood, gp) - >>> mll.train() - >>> fit_gpytorch_scipy(mll) - >>> mll.eval() """ + warn( + "`fit_gpytorch_scipy` is marked for deprecation, consider using " + "`scipy_minimize` or its model fitting helper `fit_gpytorch_mll_scipy`.", + DeprecationWarning, + ) + start_time = monotonic() + iterations: List[OptimizationResult] = [] + options = {} if options is None else options.copy() exclude: Iterator[Union[Pattern, str]] = options.pop("exclude", None) if exclude: exclude, _ = zip( # get the qualified names of excluded parameters - *filterfalse(create_name_filter(exclude), mll.named_parameters()) + *filterfalse(get_name_filter(exclude), mll.named_parameters()) ) x0, property_dict, bounds = module_to_array_func( - module=mll, - bounds=bounds, - exclude=exclude, + module=mll, exclude=exclude, bounds=bounds ) - x0 = x0.astype(np.float64) if bounds is not None: bounds = Bounds(lb=bounds[0], ub=bounds[1], keep_feasible=True) - xs = [] - ts = [] - t1 = monotonic() + def wrapper(x: ndarray) -> Tuple[float, ndarray]: + with fast_computations(log_prob=approx_mll): + return scipy_objective(x=x, mll=mll, property_dict=property_dict) def store_iteration(xk): - xs.append(xk.copy()) - ts.append(monotonic() - t1) - - cb = store_iteration if track_iterations else None - with gpt_settings.fast_computations(log_prob=approx_mll): - res = minimize( - scipy_objective, - x0, - args=(mll, property_dict), - bounds=bounds, - method=method, - jac=True, - options=options, - callback=cb, + iterations.append( + OptimizationResult( + step=len(iterations), + fval=float(wrapper(xk)[0]), + status=OptimizationStatus.RUNNING, + runtime=monotonic() - start_time, + ) ) - iterations = [] - if track_iterations: - for i, xk in enumerate(xs): - obj, _ = scipy_objective(x=xk, mll=mll, property_dict=property_dict) - iterations.append(OptimizationIteration(i, obj, ts[i])) - # Construct info dict + result = minimize( + wrapper, + x0, + bounds=bounds, + method=method, + jac=True, + options=options, + callback=store_iteration if track_iterations else None, + ) + info_dict = { - "fopt": float(res.fun), - "wall_time": monotonic() - t1, + "fopt": float(result.fun), + "wall_time": monotonic() - start_time, "iterations": iterations, - "OptimizeResult": res, + "OptimizeResult": result, } - if not res.success: + if not result.success: try: - # Some res.message are bytes - msg = res.message.decode("ascii") + # Some result.message are bytes + msg = result.message.decode("ascii") except AttributeError: # Others are str - msg = res.message - warnings.warn( + msg = result.message + warn( f"Fitting failed with the optimizer reporting '{msg}'", OptimizationWarning ) + # Set to optimum - mll = module_from_array_func(mll, res.x, property_dict) + mll = module_from_array_func(mll, result.x, property_dict) return mll, info_dict def fit_gpytorch_torch( mll: MarginalLogLikelihood, - bounds: Optional[ParameterBounds] = None, + bounds: Optional[Dict[str, Tuple[Optional[float], Optional[float]]]] = None, optimizer_cls: Optimizer = Adam, options: Optional[Dict[str, Any]] = None, track_iterations: bool = False, approx_mll: bool = False, -) -> Tuple[MarginalLogLikelihood, Dict[str, Union[float, List[OptimizationIteration]]]]: - r"""Fit a gpytorch model by maximizing MLL with a torch optimizer. +) -> Tuple[MarginalLogLikelihood, Dict[str, Union[float, List[OptimizationResult]]]]: + r"""Legacy method for torch-based fitting of gpytorch models. The model and likelihood in mll must already be in train mode. Note: this method requires that the model has `train_inputs` and `train_targets`. Args: mll: MarginalLogLikelihood to be maximized. - bounds: A ParameterBounds dictionary mapping parameter names to tuples + bounds: An optional dictionary mapping parameter names to tuples of lower and upper bounds. Bounds specified here take precedence over bounds on the same parameters specified in the constraints registered with the module. @@ -201,12 +323,6 @@ def fit_gpytorch_torch( the `optimizer_cls`. Additionally, options can include: "disp" to specify whether to display model fitting diagnostics and "maxiter" to specify the maximum number of iterations. - track_iterations: Track the function values and wall time for each - iteration. - approx_mll: If True, use gpytorch's approximate MLL computation ( - according to the gpytorch defaults based on the training at size). - Unlike for the deterministic algorithms used in fit_gpytorch_scipy, - this is not an issue for stochastic optimizers. Returns: 2-element tuple containing @@ -214,7 +330,7 @@ def fit_gpytorch_torch( - Dictionary with the following key/values: "fopt": Best mll value. "wall_time": Wall time of fitting. - "iterations": List of OptimizationIteration objects with information on each + "iterations": List of OptimizationResult objects with information on each iteration. If track_iterations is False, will be empty. Example: @@ -224,74 +340,51 @@ def fit_gpytorch_torch( >>> fit_gpytorch_torch(mll) >>> mll.eval() """ - optim_options = {"maxiter": 100, "disp": True, "lr": 0.05} - optim_options.update(options or {}) - exclude = optim_options.pop("exclude", None) - if exclude is None: - mll_params = list(mll.parameters()) - else: - mll_params = [ - v for k, v in filter(create_name_filter(exclude), mll.named_parameters()) - ] + warn( + "`fit_gpytorch_torch` is marked for deprecation, consider using " + "`torch_minimize` or its model fitting helper `fit_gpytorch_mll_torch`.", + DeprecationWarning, + ) + _options = {"maxiter": 100, "disp": True, "lr": 0.05} + _options.update(options or {}) + exclude = _options.pop("exclude", None) + parameters = get_parameters( + mll, + requires_grad=True, + name_filter=None if exclude is None else get_name_filter(exclude), + ) optimizer = optimizer_cls( - params=[{"params": mll_params}], - **_filter_kwargs(optimizer_cls, **optim_options), + params=list(parameters.values()), **_filter_kwargs(optimizer_cls, **_options) + ) + iterations: List[OptimizationResult] = [] + stopping_criterion = ExpMAStoppingCriterion( + **_filter_kwargs(ExpMAStoppingCriterion, **_options) ) - # get bounds specified in model (if any) - bounds_: ParameterBounds = {} - if hasattr(mll, "named_parameters_and_constraints"): - for param_name, _, constraint in mll.named_parameters_and_constraints(): - if constraint is not None and not constraint.enforced: - bounds_[param_name] = constraint.lower_bound, constraint.upper_bound + def closure() -> Tuple[Tensor, Tuple[Tensor, ...]]: + optimizer.zero_grad() + with fast_computations(log_prob=approx_mll): + out = mll.model(*mll.model.train_inputs) + loss = -mll(out, mll.model.train_targets, *_get_extra_mll_args(mll)).sum() + loss.backward() - # update with user-supplied bounds (overwrites if already exists) - if bounds is not None: - bounds_.update(bounds) + return loss, tuple(param.grad for param in parameters.values()) - iterations = [] - t1 = monotonic() + def store_iteration(parameters: Dict[str, Tensor], result: OptimizationResult): + iterations.append(result) - param_trajectory: Dict[str, List[Tensor]] = { - name: [] for name, param in mll.named_parameters() - } - loss_trajectory: List[float] = [] - i = 0 - stop = False - stopping_criterion = ExpMAStoppingCriterion( - **_filter_kwargs(ExpMAStoppingCriterion, **optim_options) + result = fit_gpytorch_mll_torch( + mll=mll, + closure=closure, + bounds=bounds, + parameters=parameters, + optimizer=optimizer, + stopping_criterion=stopping_criterion, + callback=store_iteration if track_iterations else None, ) - train_inputs, train_targets = mll.model.train_inputs, mll.model.train_targets - while not stop: - optimizer.zero_grad() - with gpt_settings.fast_computations(log_prob=approx_mll): - output = mll.model(*train_inputs) - # we sum here to support batch mode - args = [output, train_targets] + _get_extra_mll_args(mll) - loss = -mll(*args).sum() - loss.backward() - loss_trajectory.append(loss.item()) - for name, param in mll.named_parameters(): - param_trajectory[name].append(param.detach().clone()) - if optim_options["disp"] and ( - (i + 1) % 10 == 0 or i == (optim_options["maxiter"] - 1) - ): - print(f"Iter {i + 1}/{optim_options['maxiter']}: {loss.item()}") - if track_iterations: - iterations.append(OptimizationIteration(i, loss.item(), monotonic() - t1)) - - optimizer.step() - # project onto bounds: - if bounds_: - for pname, param in mll.named_parameters(): - if pname in bounds_: - param.data = param.data.clamp(*bounds_[pname]) - i += 1 - stop = stopping_criterion.evaluate(fvals=loss.detach()) - info_dict = { - "fopt": loss_trajectory[-1], - "wall_time": monotonic() - t1, + return mll, { + "fopt": result.fval, + "wall_time": result.runtime, "iterations": iterations, } - return mll, info_dict diff --git a/botorch/optim/numpy_converter.py b/botorch/optim/numpy_converter.py index 5c0dce3b36..91aa103d3a 100644 --- a/botorch/optim/numpy_converter.py +++ b/botorch/optim/numpy_converter.py @@ -15,120 +15,25 @@ from collections import OrderedDict from math import inf from numbers import Number -from re import Pattern -from typing import ( - Any, - Callable, - Dict, - Iterator, - List, - NamedTuple, - Optional, - Set, - Tuple, - Union, -) +from typing import Dict, List, Optional, Set, Tuple +from warnings import warn import numpy as np import torch -from torch.nn import Module, Parameter - -ParameterBounds = Dict[str, Tuple[Optional[float], Optional[float]]] - - -class TorchAttr(NamedTuple): - shape: torch.Size - dtype: torch.dtype - device: torch.device - - -def create_name_filter( - patterns: Iterator[Union[Pattern, str]] -) -> Callable[[Union[str, Tuple[str, Any, ...]]], bool]: - r"""Returns a binary function that filters strings (or iterables whose first - element is a string) according to a bank of excluded patterns. Typically, used - in conjunction with generators such as `module.named_parameters()`. - - Args: - patterns: A collection of regular expressions or strings that - define the set of names to be excluded. - - Returns: - A binary function indicating whether or not an item should be filtered. - """ - names = set() - _patterns = set() - for pattern in patterns: - if isinstance(pattern, str): - names.add(pattern) - elif isinstance(pattern, Pattern): - _patterns.add(pattern) - else: - raise TypeError - - def name_filter(item: Union[str, Tuple[str, Any, ...]]) -> bool: - name = item if isinstance(item, str) else next(iter(item)) - if name in names: - return False - - for pattern in _patterns: - if pattern.search(name): - return False - - return True - - return name_filter - - -def get_parameters_and_bounds( - module: Module, - name_filter: Optional[Callable[[str], bool]] = None, - requires_grad: Optional[bool] = None, - default_bounds: Tuple[float, float] = (-float("inf"), float("inf")), -) -> Tuple[Dict[str, Parameter], Dict[str, ParameterBounds]]: - r"""Helper method for extracting parameters and feasible ranges thereof. - - Args: - module: The target module from which parameters are to be extracted. - name_filter: Optional Boolean function used to filter parameters by name. - requires_grad: Optional Boolean used to filter parameters based on whether - or not their require_grad attribute matches the user provided value. - default_bounds: Default lower and upper bounds for constrained parameters - with `None` typed bounds. - - Returns: - 0: Dictionary mapping names to Parameters. - 1: Dictionary mapping names of constrained parameters to ParameterBounds. - """ - if hasattr(module, "named_parameters_and_constraints"): - bounds = {} - params = {} - for name, param, constraint in module.named_parameters_and_constraints(): - if (requires_grad is None or (param.requires_grad == requires_grad)) and ( - name_filter is None or name_filter(name) - ): - params[name] = param - if constraint is None: - continue - - bounds[name] = tuple( - default if bound is None else constraint.inverse_transform(bound) - for (bound, default) in zip(constraint, default_bounds) - ) - else: - bounds = {} - params = { - name: param - for name, param in module.named_parameters() - if name_filter is None or name_filter(name) - } - - return params, bounds +from botorch.optim.utils import ( + _get_extra_mll_args, + _handle_numerical_errors, + get_name_filter, + get_parameters_and_bounds, + TorchAttr, +) +from gpytorch.mlls import MarginalLogLikelihood +from torch.nn import Module def module_to_array( module: Module, - bounds: Optional[ParameterBounds] = None, + bounds: Optional[Dict[str, Tuple[Optional[float], Optional[float]]]] = None, exclude: Optional[Set[str]] = None, ) -> Tuple[np.ndarray, Dict[str, TorchAttr], Optional[np.ndarray]]: r"""Extract named parameters from a module into a numpy array. @@ -138,7 +43,7 @@ def module_to_array( Args: module: A module with parameters. May specify parameter constraints in a `named_parameters_and_constraints` method. - bounds: A ParameterBounds dictionary mapping parameter names to tuples + bounds: A dictionary mapping parameter names t lower and upper bounds. of lower and upper bounds. Bounds specified here take precedence over bounds on the same parameters specified in the constraints registered with the module. @@ -156,9 +61,15 @@ def module_to_array( >>> mll = ExactMarginalLogLikelihood(model.likelihood, model) >>> parameter_array, property_dict, bounds_out = module_to_array(mll) """ + warn( + "`module_to_array` is marked for deprecation, consider using " + "`get_parameters_and_bounds`, `get_parameters_as_ndarray_1d`, or " + "`get_bounds_as_ndarray` instead.", + DeprecationWarning, + ) param_dict, bounds_dict = get_parameters_and_bounds( module=module, - name_filter=None if exclude is None else create_name_filter(exclude), + name_filter=None if exclude is None else get_name_filter(exclude), requires_grad=True, ) if bounds is not None: @@ -220,6 +131,11 @@ def set_params_with_array( >>> parameter_array += 0.1 # perturb parameters (for example only) >>> mll = set_params_with_array(mll, parameter_array, property_dict) """ + warn( + "`_set_params_with_array` is marked for deprecation, consider using " + "`set_parameters_from_ndarray_1d` instead.", + DeprecationWarning, + ) param_dict = OrderedDict(module.named_parameters()) start_idx = 0 for p_name, attrs in property_dict.items(): @@ -240,3 +156,46 @@ def set_params_with_array( param_dict[p_name].copy_(new_data) param_dict[p_name].requires_grad_(True) return module + + +def _scipy_objective_and_grad( + x: np.ndarray, mll: MarginalLogLikelihood, property_dict: Dict[str, TorchAttr] +) -> Tuple[float, np.ndarray]: + r"""Get objective and gradient in format that scipy expects. + + Args: + x: The (flattened) input parameters. + mll: The MarginalLogLikelihood module to evaluate. + property_dict: The property dictionary required to "unflatten" the input + parameter vector, as generated by `module_to_array`. + + Returns: + 2-element tuple containing + + - The objective value. + - The gradient of the objective. + """ + warn("`_scipy_objective_and_grad` is marked for deprecation.", DeprecationWarning) + mll = set_params_with_array(mll, x, property_dict) + train_inputs, train_targets = mll.model.train_inputs, mll.model.train_targets + mll.zero_grad() + try: # catch linear algebra errors in gpytorch + output = mll.model(*train_inputs) + args = [output, train_targets] + _get_extra_mll_args(mll) + loss = -mll(*args).sum() + except RuntimeError as e: + return _handle_numerical_errors(error=e, x=x) + loss.backward() + + i = 0 + param_dict = OrderedDict(mll.named_parameters()) + grad = np.zeros(sum([tattr.shape.numel() for tattr in property_dict.values()])) + for p_name in property_dict: + t = param_dict[p_name] + size = t.numel() + if t.requires_grad and t.grad is not None: + grad[i : i + size] = t.grad.detach().view(-1).cpu().double().clone().numpy() + i += size + + mll.zero_grad() + return loss.item(), grad diff --git a/botorch/optim/stopping.py b/botorch/optim/stopping.py index 68363d2288..bbeeee34e1 100644 --- a/botorch/optim/stopping.py +++ b/botorch/optim/stopping.py @@ -37,6 +37,9 @@ def evaluate(self, fvals: Tensor) -> bool: """ pass # pragma: no cover + def __call__(self, fvals: Tensor) -> bool: + return self.evaluate(fvals) + class ExpMAStoppingCriterion(StoppingCriterion): r"""Exponential moving average stopping criterion. diff --git a/botorch/optim/utils.py b/botorch/optim/utils.py deleted file mode 100644 index 849fa6b575..0000000000 --- a/botorch/optim/utils.py +++ /dev/null @@ -1,490 +0,0 @@ -#!/usr/bin/env python3 -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. - -r""" -Utilities for optimization. -""" - -from __future__ import annotations - -import warnings -from collections import OrderedDict -from contextlib import contextmanager -from inspect import signature -from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union - -import numpy as np -import torch -from botorch.acquisition.acquisition import AcquisitionFunction -from botorch.exceptions.errors import BotorchError -from botorch.exceptions.warnings import BotorchWarning -from botorch.models.gpytorch import GPyTorchModel, ModelListGPyTorchModel -from botorch.optim.numpy_converter import ( # noqa F401 - create_name_filter, - get_parameters_and_bounds, - set_params_with_array, - TorchAttr, -) -from gpytorch.mlls.exact_marginal_log_likelihood import ExactMarginalLogLikelihood -from gpytorch.mlls.marginal_log_likelihood import MarginalLogLikelihood -from gpytorch.mlls.sum_marginal_log_likelihood import SumMarginalLogLikelihood -from linear_operator.utils.errors import NanError, NotPSDError -from torch import Tensor -from torch.nn import Module - -ParameterBounds = Dict[str, Tuple[Optional[float], Optional[float]]] -Tkwargs = Dict[str, Union[torch.device, torch.dtype]] - - -def sample_all_priors(model: GPyTorchModel, max_retries: int = 100) -> None: - r"""Sample from hyperparameter priors (in-place). - - Args: - model: A GPyTorchModel. - """ - for _, module, prior, closure, setting_closure in model.named_priors(): - if setting_closure is None: - raise RuntimeError( - "Must provide inverse transform to be able to sample from prior." - ) - for i in range(max_retries): - try: - setting_closure(module, prior.sample(closure(module).shape)) - break - except NotImplementedError: - warnings.warn( - f"`rsample` not implemented for {type(prior)}. Skipping.", - BotorchWarning, - ) - break - except RuntimeError as e: - if "out of bounds of its current constraints" in str(e): - if i == max_retries - 1: - raise RuntimeError( - "Failed to sample a feasible parameter value " - f"from the prior after {max_retries} attempts." - ) - else: - raise e - - -def columnwise_clamp( - X: Tensor, - lower: Optional[Union[float, Tensor]] = None, - upper: Optional[Union[float, Tensor]] = None, - raise_on_violation: bool = False, -) -> Tensor: - r"""Clamp values of a Tensor in column-wise fashion (with support for t-batches). - - This function is useful in conjunction with optimizers from the torch.optim - package, which don't natively handle constraints. If you apply this after - a gradient step you can be fancy and call it "projected gradient descent". - This funtion is also useful for post-processing candidates generated by the - scipy optimizer that satisfy bounds only up to numerical accuracy. - - Args: - X: The `b x n x d` input tensor. If 2-dimensional, `b` is assumed to be 1. - lower: The column-wise lower bounds. If scalar, apply bound to all columns. - upper: The column-wise upper bounds. If scalar, apply bound to all columns. - raise_on_violation: If `True`, raise an exception when the elments in `X` - are out of the specified bounds (up to numerical accuracy). This is - useful for post-processing candidates generated by optimizers that - satisfy imposed bounds only up to numerical accuracy. - - Returns: - The clamped tensor. - """ - min_bounds = _expand_bounds(lower, X) - max_bounds = _expand_bounds(upper, X) - if min_bounds is not None and max_bounds is not None: - if torch.any(min_bounds > max_bounds): - raise ValueError("Minimum values must be <= maximum values") - Xout = X - if min_bounds is not None: - Xout = Xout.max(min_bounds) - if max_bounds is not None: - Xout = Xout.min(max_bounds) - if raise_on_violation and not torch.allclose(Xout, X): - raise BotorchError("Original value(s) are out of bounds.") - return Xout - - -def fix_features( - X: Tensor, fixed_features: Optional[Dict[int, Optional[float]]] = None -) -> Tensor: - r"""Fix feature values in a Tensor. - - The fixed features will have zero gradient in downstream calculations. - - Args: - X: input Tensor with shape `... x p`, where `p` is the number of features - fixed_features: A dictionary with keys as column indices and values - equal to what the feature should be set to in `X`. If the value is - None, that column is just considered fixed. Keys should be in the - range `[0, p - 1]`. - - Returns: - The tensor X with fixed features. - """ - if fixed_features is None: - return X - else: - return torch.cat( - [ - X[..., i].unsqueeze(-1) - if i not in fixed_features - else _fix_feature(X[..., i].unsqueeze(-1), fixed_features[i]) - for i in range(X.shape[-1]) - ], - dim=-1, - ) - - -def _fix_feature(Z: Tensor, value: Optional[float]) -> Tensor: - r"""Helper function returns a Tensor like `Z` filled with `value` if provided.""" - if value is None: - return Z.detach() - return torch.full_like(Z, value) - - -def _expand_bounds( - bounds: Optional[Union[float, Tensor]], X: Tensor -) -> Optional[Tensor]: - r"""Expands a tensor representing bounds. - - Expand the dimension of bounds if necessary such that the dimension of bounds - is the same as the dimension of `X`. - - Args: - bounds: a bound (either upper or lower) of each entry of `X`. If this is a - single float, then all entries have the same bound. Different sizes of - tensors can be used to specify custom bounds. E.g., a `d`-dim tensor can - be used to specify bounds for each column (last dimension) of `X`, or a - tensor with same shape as `X` can be used to specify a different bound - for each entry of `X`. - X: `... x d` tensor - - Returns: - A tensor of bounds expanded to the size of `X` if bounds is not None, - and None if bounds is None. - """ - if bounds is not None: - if not torch.is_tensor(bounds): - bounds = torch.tensor(bounds) - try: - ebounds = bounds.expand_as(X) - except RuntimeError: - raise RuntimeError("Bounds must be broadcastable to X!") - return ebounds.to(dtype=X.dtype, device=X.device) - else: - return None - - -def _get_extra_mll_args( - mll: MarginalLogLikelihood, -) -> Union[List[Tensor], List[List[Tensor]]]: - r"""Obtain extra arguments for MarginalLogLikelihood objects. - - Get extra arguments (beyond the model output and training targets) required - for the particular type of MarginalLogLikelihood for a forward pass. - - Args: - mll: The MarginalLogLikelihood module. - - Returns: - Extra arguments for the MarginalLogLikelihood. - Returns an empty list if the mll type is unknown. - """ - if isinstance(mll, ExactMarginalLogLikelihood): - return list(mll.model.train_inputs) - elif isinstance(mll, SumMarginalLogLikelihood): - return [list(x) for x in mll.model.train_inputs] - return [] - - -def _filter_kwargs(function: Callable, **kwargs: Any) -> Any: - r"""Filter out kwargs that are not applicable for a given function. - Return a copy of given kwargs dict with only the required kwargs.""" - return {k: v for k, v in kwargs.items() if k in signature(function).parameters} - - -def _scipy_objective_and_grad( - x: np.ndarray, mll: MarginalLogLikelihood, property_dict: Dict[str, TorchAttr] -) -> Tuple[float, np.ndarray]: - r"""Get objective and gradient in format that scipy expects. - - Args: - x: The (flattened) input parameters. - mll: The MarginalLogLikelihood module to evaluate. - property_dict: The property dictionary required to "unflatten" the input - parameter vector, as generated by `module_to_array`. - - Returns: - 2-element tuple containing - - - The objective value. - - The gradient of the objective. - """ - mll = set_params_with_array(mll, x, property_dict) - train_inputs, train_targets = mll.model.train_inputs, mll.model.train_targets - mll.zero_grad() - try: # catch linear algebra errors in gpytorch - output = mll.model(*train_inputs) - args = [output, train_targets] + _get_extra_mll_args(mll) - loss = -mll(*args).sum() - except RuntimeError as e: - return _handle_numerical_errors(error=e, x=x) - loss.backward() - - i = 0 - param_dict = OrderedDict(mll.named_parameters()) - grad = np.zeros(sum([tattr.shape.numel() for tattr in property_dict.values()])) - for p_name in property_dict: - t = param_dict[p_name] - size = t.numel() - if t.requires_grad and t.grad is not None: - grad[i : i + size] = t.grad.detach().view(-1).cpu().double().clone().numpy() - i += size - - mll.zero_grad() - return loss.item(), grad - - -def _handle_numerical_errors( - error: RuntimeError, x: np.ndarray -) -> Tuple[float, np.ndarray]: - if isinstance(error, NotPSDError): - raise error - error_message = error.args[0] if len(error.args) > 0 else "" - if ( - isinstance(error, NanError) - or "singular" in error_message # old pytorch message - or "input is not positive-definite" in error_message # since pytorch #63864 - ): - return float("nan"), np.full_like(x, "nan") - raise error # pragma: nocover - - -def get_X_baseline(acq_function: AcquisitionFunction) -> Optional[Tensor]: - r"""Extract X_baseline from an acquisition function. - - This tries to find the baseline set of points. First, this checks if the - acquisition function has an `X_baseline` attribute. If it does not, - then this method attempts to use the model's `train_inputs` as `X_baseline`. - - Args: - acq_function: The acquisition function. - - Returns - An optional `n x d`-dim tensor of baseline points. This is None if no - baseline points are found. - """ - try: - X = acq_function.X_baseline - # if there are no baseline points, use training points - if X.shape[0] == 0: - raise BotorchError - except (BotorchError, AttributeError): - try: - # for entropy MOO methods - model = acq_function.mo_model - except AttributeError: - try: - # some acquisition functions do not have a model attribute - # e.g. FixedFeatureAcquisitionFunction - model = acq_function.model - except AttributeError: - warnings.warn("Failed to extract X_baseline.", BotorchWarning) - return - try: - # Make sure we get the original train inputs. - m = model.models[0] if isinstance(model, ModelListGPyTorchModel) else model - if m._has_transformed_inputs: - X = m._original_train_inputs - else: - X = m.train_inputs[0] - except (BotorchError, AttributeError): - warnings.warn("Failed to extract X_baseline.", BotorchWarning) - return - # just use one batch - while X.ndim > 2: - X = X[0] - return X - - -@contextmanager -def del_attribute_ctx( - instance: object, *attrs: str, enforce_hasattr: bool = False -) -> Generator[None, None, None]: - r"""Contextmanager for temporarily deleting attributes.""" - try: - cache = {} - for key in attrs: - if hasattr(instance, key): - cache[key] = getattr(instance, key) - delattr(instance, key) - elif enforce_hasattr: - raise ValueError( - f"Attribute {key} missing from {type(instance)} instance." - ) - yield - finally: - for key, cached_val in cache.items(): - setattr(instance, key, cached_val) - - -@contextmanager -def requires_grad_ctx( - module: Module, assignments: Dict[str, bool] -) -> Generator[None, None, None]: - r"""Contextmanager for temporarily setting the requires_grad field of a module's - parameters.""" - try: - cache = {} - for name, mode in assignments.items(): - parameter = module.get_parameter(name) - cache[name] = parameter.requires_grad - parameter.requires_grad_(mode) - yield - finally: - for name, mode in cache.items(): - module.get_parameter(name).requires_grad_(mode) - - -@contextmanager -def parameter_rollback_ctx( - module: Module, - name_filter: Optional[Callable[[str], bool]] = None, - requires_grad: Optional[bool] = None, - checkpoint: Optional[Dict[str, Tuple[Tensor, Tkwargs]]] = None, - **tkwargs: Any, -) -> Generator[Dict[str, Tensor], None, None]: - r"""Contextmanager that exits by rolling back parameter values. - - Args: - module: Module instance. - name_filter: Optional Boolean function used to filter parameters by name. - requires_grad: Optional Boolean used to filter parameters based on whether - or not their require_grad attribute matches the user provided value. - checkpoint: Optional cache of values and tensor metadata specifying the rollback - state for the module (or some subset thereof). - **tkwargs: Keyword arguments passed to `torch.Tensor.to` when copying data from - each tensor in `module.state_dict()` to the internally created checkpoint. - Only adhered to when the `checkpoint` argument is None. - - Yields: - A checkpoint dictionary for the module, mapping qualified names to cached values - and tensor metadata. Any in-places changes to the checkpoint will be observed at - rollback time. If the checkpoint is cleared, no rollback will occur. - """ - # Create copies of the orginal values - if checkpoint is None: - checkpoint = {} - for name, param in module.named_parameters(): - if (requires_grad is None or (param.requires_grad == requires_grad)) and ( - name_filter is None or name_filter(name) - ): - checkpoint[name]: Tuple[Tensor, Tkwargs] = ( - param.detach().to(**tkwargs).clone(), - {"device": param.device, "dtype": param.dtype}, - ) - - try: # yield the checkpoint to the user - yield checkpoint - finally: # restore original values of tracked parameters - for name, (values, _tkwargs) in checkpoint.items(): - param = module.get_parameter(name) - param.data[...] = values.to(**_tkwargs) - - -@contextmanager -def state_rollback_ctx( - module: Module, - name_filter: Optional[Callable[[str], bool]] = None, - checkpoint: Optional[Dict[str, Tuple[Tensor, Tkwargs]]] = None, - **tkwargs: Any, -) -> Generator[Dict[str, Tuple[Tensor, Tkwargs]], None, None]: - r"""Contextmanager that exits by rolling back a module's state_dict. - - Args: - module: Module instance. - name_filter: Optional Boolean function used to filter items by name. - checkpoint: Optional cache of values and tensor metadata specifying the rollback - state for the module (or some subset thereof). - **tkwargs: Keyword arguments passed to `torch.Tensor.to` when copying data from - each tensor in `module.state_dict()` to the internally created checkpoint. - Only adhered to when the `checkpoint` argument is None. - - Yields: - A checkpoint dictionary for the module, mapping qualified names to cached values - and tensor metadata. Any in-places changes to the checkpoint will be observed at - rollback time. If the checkpoint is cleared, no rollback will occur. - """ - # Create copies of the orginal values - if checkpoint is None: - checkpoint: Dict[str, Tuple[Tensor, Tkwargs]] = { - name: ( - data.detach().to(**tkwargs).clone(), - {"device": data.device, "dtype": data.dtype}, - ) - for name, data in module.state_dict().items() - if name_filter is None or name_filter(name) - } - - try: # yield the checkpoint dictionary to the user - yield checkpoint - finally: # restore original values of tracked parameters - if checkpoint: - state_dict = module.state_dict() - for key, (values, _tkwargs) in checkpoint.items(): - tnsr = state_dict.get(key) - if tnsr is None: - state_dict[key] = values.to(**_tkwargs) - else: - tnsr[...] = values.to(**_tkwargs) - module.load_state_dict(state_dict) - - -def allclose_mll( - a: MarginalLogLikelihood, - b: MarginalLogLikelihood, - transform_a: Optional[Callable[[Tensor], Tensor]] = None, - transform_b: Optional[Callable[[Tensor], Tensor]] = None, - rtol: float = 1e-05, - atol: float = 1e-08, -) -> bool: - r"""Convenience method for testing whether the log likelihoods produced by different - MarginalLogLikelihood instances, when evaluated on their respective models' training - sets, are allclose. - - Args: - a: A MarginalLogLikelihood instance. - b: A second MarginalLogLikelihood instance. - transform_a: Optional callable used to post-transform log likelihoods under `a`. - transform_b: Optional callable used to post-transform log likelihoods under `b`. - rtol: Relative tolerance. - atol: Absolute tolerance. - - Returns: - Boolean result of the allclose test. - """ - values_a = a( - a.model(*a.model.train_inputs), - a.model.train_targets, - *_get_extra_mll_args(a), - ) - if transform_a: - values_a = transform_a(values_a) - - values_b = b( - b.model(*b.model.train_inputs), - b.model.train_targets, - *_get_extra_mll_args(b), - ) - if transform_b: - values_b = transform_b(values_b) - - return values_a.allclose(values_b, rtol=rtol, atol=atol) diff --git a/botorch/optim/utils/__init__.py b/botorch/optim/utils/__init__.py new file mode 100644 index 0000000000..ddd1e9d72f --- /dev/null +++ b/botorch/optim/utils/__init__.py @@ -0,0 +1,57 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from botorch.optim.utils.acquisition_utils import ( + columnwise_clamp, + fix_features, + get_X_baseline, +) +from botorch.optim.utils.common import ( + _filter_kwargs, + _handle_numerical_errors, + _warning_handler_template, + DEFAULT, + TNone, +) +from botorch.optim.utils.model_utils import ( + _get_extra_mll_args, + allclose_mll, + get_data_loader, + get_name_filter, + get_parameters, + get_parameters_and_bounds, + sample_all_priors, + TorchAttr, +) +from botorch.optim.utils.numpy_utils import ( + as_ndarray, + get_bounds_as_ndarray, + get_tensors_as_ndarray_1d, + set_tensors_from_ndarray_1d, +) + +__all__ = [ + "_filter_kwargs", + "_get_extra_mll_args", + "_handle_numerical_errors", + "_warning_handler_template", + "allclose_mll", + "as_ndarray", + "columnwise_clamp", + "DEFAULT", + "fix_features", + "get_name_filter", + "get_bounds_as_ndarray", + "get_data_loader", + "get_parameters", + "get_parameters_and_bounds", + "get_tensors_as_ndarray_1d", + "get_X_baseline", + "sample_all_priors", + "set_tensors_from_ndarray_1d", + "TorchAttr", + "TNone", +] diff --git a/botorch/optim/utils/acquisition_utils.py b/botorch/optim/utils/acquisition_utils.py new file mode 100644 index 0000000000..c8df213590 --- /dev/null +++ b/botorch/optim/utils/acquisition_utils.py @@ -0,0 +1,140 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +r"""Utilities for maximizing acquisition functions.""" + +from __future__ import annotations + +from typing import Dict, Optional, Union +from warnings import warn + +import torch +from botorch.acquisition.acquisition import AcquisitionFunction +from botorch.exceptions.errors import BotorchError +from botorch.exceptions.warnings import BotorchWarning +from botorch.models.gpytorch import ModelListGPyTorchModel +from torch import Tensor + + +def columnwise_clamp( + X: Tensor, + lower: Optional[Union[float, Tensor]] = None, + upper: Optional[Union[float, Tensor]] = None, + raise_on_violation: bool = False, +) -> Tensor: + r"""Clamp values of a Tensor in column-wise fashion (with support for t-batches). + + This function is useful in conjunction with optimizers from the torch.optim + package, which don't natively handle constraints. If you apply this after + a gradient step you can be fancy and call it "projected gradient descent". + This funtion is also useful for post-processing candidates generated by the + scipy optimizer that satisfy bounds only up to numerical accuracy. + + Args: + X: The `b x n x d` input tensor. If 2-dimensional, `b` is assumed to be 1. + lower: The column-wise lower bounds. If scalar, apply bound to all columns. + upper: The column-wise upper bounds. If scalar, apply bound to all columns. + raise_on_violation: If `True`, raise an exception when the elments in `X` + are out of the specified bounds (up to numerical accuracy). This is + useful for post-processing candidates generated by optimizers that + satisfy imposed bounds only up to numerical accuracy. + + Returns: + The clamped tensor. + """ + if lower is None and upper is None: + return X + + if lower is not None: + lower = torch.as_tensor(lower).expand_as(X).to(X) + + if upper is not None: + upper = torch.as_tensor(upper).expand_as(X).to(X) + if lower is not None and (lower > upper).any(): + raise ValueError("Lower bounds cannot exceed upper bounds.") + + out = X.clamp(lower, upper) + if raise_on_violation and not X.allclose(out): + raise BotorchError("Original value(s) are out of bounds.") + + return out + + +def fix_features( + X: Tensor, fixed_features: Optional[Dict[int, Optional[float]]] = None +) -> Tensor: + r"""Fix feature values in a Tensor. + + The fixed features will have zero gradient in downstream calculations. + + Args: + X: input Tensor with shape `... x p`, where `p` is the number of features + fixed_features: A dictionary with keys as column indices and values + equal to what the feature should be set to in `X`. If the value is + None, that column is just considered fixed. Keys should be in the + range `[0, p - 1]`. + + Returns: + The tensor X with fixed features. + """ + if fixed_features is None: + return X + + columns = list(X.unbind(dim=-1)) + for index, value in fixed_features.items(): + if value is None: + columns[index] = columns[index].detach() + else: + columns[index] = torch.full_like(columns[index], value) + + return torch.stack(columns, dim=-1) + + +def get_X_baseline(acq_function: AcquisitionFunction) -> Optional[Tensor]: + r"""Extract X_baseline from an acquisition function. + + This tries to find the baseline set of points. First, this checks if the + acquisition function has an `X_baseline` attribute. If it does not, + then this method attempts to use the model's `train_inputs` as `X_baseline`. + + Args: + acq_function: The acquisition function. + + Returns + An optional `n x d`-dim tensor of baseline points. This is None if no + baseline points are found. + """ + try: + X = acq_function.X_baseline + # if there are no baseline points, use training points + if X.shape[0] == 0: + raise BotorchError + except (BotorchError, AttributeError): + try: + # for entropy MOO methods + model = acq_function.mo_model + except AttributeError: + try: + # some acquisition functions do not have a model attribute + # e.g. FixedFeatureAcquisitionFunction + model = acq_function.model + except AttributeError: + warn("Failed to extract X_baseline.", BotorchWarning) + return + try: + # Make sure we get the original train inputs. + m = model.models[0] if isinstance(model, ModelListGPyTorchModel) else model + if m._has_transformed_inputs: + X = m._original_train_inputs + else: + X = m.train_inputs[0] + except (BotorchError, AttributeError): + warn("Failed to extract X_baseline.", BotorchWarning) + return + # just use one batch + while X.ndim > 2: + X = X[0] + return X diff --git a/botorch/optim/utils/common.py b/botorch/optim/utils/common.py new file mode 100644 index 0000000000..fbb5a20252 --- /dev/null +++ b/botorch/optim/utils/common.py @@ -0,0 +1,75 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +r"""General-purpose optimization utilities.""" + +from __future__ import annotations + +from inspect import signature +from logging import debug as logging_debug +from typing import Any, Callable, Optional, Tuple +from warnings import warn_explicit, WarningMessage + +import numpy as np +from linear_operator.utils.errors import NanError, NotPSDError + +TNone = type(None) + + +class _TDefault: + pass + + +DEFAULT = _TDefault() + + +def _filter_kwargs(function: Callable, **kwargs: Any) -> Any: + r"""Filter out kwargs that are not applicable for a given function. + Return a copy of given kwargs dict with only the required kwargs.""" + return {k: v for k, v in kwargs.items() if k in signature(function).parameters} + + +def _handle_numerical_errors( + error: RuntimeError, x: np.ndarray +) -> Tuple[np.ndarray, np.ndarray]: + if isinstance(error, NotPSDError): + raise error + error_message = error.args[0] if len(error.args) > 0 else "" + if ( + isinstance(error, NanError) + or "singular" in error_message # old pytorch message + or "input is not positive-definite" in error_message # since pytorch #63864 + ): + return np.full((), "nan", dtype=x.dtype), np.full_like(x, "nan") + raise error # pragma: nocover + + +def _warning_handler_template( + w: WarningMessage, + debug: Optional[Callable[[WarningMessage], bool]] = None, + rethrow: Optional[Callable[[WarningMessage], bool]] = None, +) -> bool: + r"""Helper for making basic warning handlers. Typically used with functools.partial. + + Args: + w: The WarningMessage to be resolved and filtered out or returned unresolved. + debug: Optional callable used to specify that a warning should be + resolved as a logging statement at the DEBUG level. + rethrow: Optional callable used to specify that a warning should be + resolved by rethrowing the warning. + + Returns: + Boolean indicating whether or not the warning message was resolved. + """ + if debug and debug(w): + logging_debug(str(w.message)) + return True + + if rethrow and rethrow(w): + warn_explicit(str(w.message), w.category, w.filename, w.lineno) + return True + + return False diff --git a/botorch/optim/utils/model_utils.py b/botorch/optim/utils/model_utils.py new file mode 100644 index 0000000000..70a17b6b1b --- /dev/null +++ b/botorch/optim/utils/model_utils.py @@ -0,0 +1,260 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +r"""Utilities for fitting and manipulating models.""" + +from __future__ import annotations + +from re import Pattern +from typing import ( + Any, + Callable, + Dict, + Iterator, + List, + NamedTuple, + Optional, + Tuple, + Union, +) +from warnings import warn + +import torch +from botorch.exceptions.warnings import BotorchWarning +from botorch.models.gpytorch import GPyTorchModel +from gpytorch.mlls.exact_marginal_log_likelihood import ExactMarginalLogLikelihood +from gpytorch.mlls.marginal_log_likelihood import MarginalLogLikelihood +from gpytorch.mlls.sum_marginal_log_likelihood import SumMarginalLogLikelihood +from torch import Tensor +from torch.nn import Module +from torch.utils.data import DataLoader, TensorDataset + + +class TorchAttr(NamedTuple): + shape: torch.Size + dtype: torch.dtype + device: torch.device + + +def _get_extra_mll_args( + mll: MarginalLogLikelihood, +) -> Union[List[Tensor], List[List[Tensor]]]: + r"""Obtain extra arguments for MarginalLogLikelihood objects. + + Get extra arguments (beyond the model output and training targets) required + for the particular type of MarginalLogLikelihood for a forward pass. + + Args: + mll: The MarginalLogLikelihood module. + + Returns: + Extra arguments for the MarginalLogLikelihood. + Returns an empty list if the mll type is unknown. + """ + warn("`_get_extra_mll_args` is marked for deprecation.", DeprecationWarning) + if isinstance(mll, ExactMarginalLogLikelihood): + return list(mll.model.train_inputs) + elif isinstance(mll, SumMarginalLogLikelihood): + return [list(x) for x in mll.model.train_inputs] + return [] + + +def get_data_loader( + model: GPyTorchModel, batch_size: int = 1024, **kwargs: Any +) -> DataLoader: + dataset = TensorDataset(*model.train_inputs, model.train_targets) + return DataLoader( + dataset=dataset, batch_size=min(batch_size, len(model.train_targets)), **kwargs + ) + + +def get_parameters( + module: Module, + requires_grad: Optional[bool] = None, + name_filter: Optional[Callable[[str], bool]] = None, +) -> Dict[str, Tensor]: + r"""Helper method for obtaining a module's parameters and their respective ranges. + + Args: + module: The target module from which parameters are to be extracted. + requires_grad: Optional Boolean used to filter parameters based on whether + or not their require_grad attribute matches the user provided value. + name_filter: Optional Boolean function used to filter parameters by name. + + Returns: + A dictionary of parameters. + """ + parameters = {} + for name, param in module.named_parameters(): + if requires_grad is not None and param.requires_grad != requires_grad: + continue + + if name_filter and not name_filter(name): + continue + + parameters[name] = param + + return parameters + + +def get_parameters_and_bounds( + module: Module, + requires_grad: Optional[bool] = None, + name_filter: Optional[Callable[[str], bool]] = None, + default_bounds: Tuple[float, float] = (-float("inf"), float("inf")), +) -> Tuple[Dict[str, Tensor], Dict[str, Tuple[Optional[float], Optional[float]]]]: + r"""Helper method for obtaining a module's parameters and their respective ranges. + + Args: + module: The target module from which parameters are to be extracted. + name_filter: Optional Boolean function used to filter parameters by name. + requires_grad: Optional Boolean used to filter parameters based on whether + or not their require_grad attribute matches the user provided value. + default_bounds: Default lower and upper bounds for constrained parameters + with `None` typed bounds. + + Returns: + A dictionary of parameters and a dictionary of parameter bounds. + """ + if hasattr(module, "named_parameters_and_constraints"): + bounds = {} + params = {} + for name, param, constraint in module.named_parameters_and_constraints(): + if (requires_grad is None or (param.requires_grad == requires_grad)) and ( + name_filter is None or name_filter(name) + ): + params[name] = param + if constraint is None: + continue + + bounds[name] = tuple( + default if bound is None else constraint.inverse_transform(bound) + for (bound, default) in zip(constraint, default_bounds) + ) + + return params, bounds + + params = get_parameters( + module, requires_grad=requires_grad, name_filter=name_filter + ) + return params, {} + + +def get_name_filter( + patterns: Iterator[Union[Pattern, str]] +) -> Callable[[Union[str, Tuple[str, Any, ...]]], bool]: + r"""Returns a binary function that filters strings (or iterables whose first + element is a string) according to a bank of excluded patterns. Typically, used + in conjunction with generators such as `module.named_parameters()`. + + Args: + patterns: A collection of regular expressions or strings that + define the set of names to be excluded. + + Returns: + A binary function indicating whether or not an item should be filtered. + """ + names = set() + _patterns = set() + for pattern in patterns: + if isinstance(pattern, str): + names.add(pattern) + elif isinstance(pattern, Pattern): + _patterns.add(pattern) + else: + raise TypeError( + "Expected `patterns` to contain `str` or `re.Pattern` typed elements, " + f"but found {type(pattern)}." + ) + + def name_filter(item: Union[str, Tuple[str, Any, ...]]) -> bool: + name = item if isinstance(item, str) else next(iter(item)) + if name in names: + return False + + for pattern in _patterns: + if pattern.search(name): + return False + + return True + + return name_filter + + +def sample_all_priors(model: GPyTorchModel, max_retries: int = 100) -> None: + r"""Sample from hyperparameter priors (in-place). + + Args: + model: A GPyTorchModel. + """ + for _, module, prior, closure, setting_closure in model.named_priors(): + if setting_closure is None: + raise RuntimeError( + "Must provide inverse transform to be able to sample from prior." + ) + for i in range(max_retries): + try: + setting_closure(module, prior.sample(closure(module).shape)) + break + except NotImplementedError: + warn( + f"`rsample` not implemented for {type(prior)}. Skipping.", + BotorchWarning, + ) + break + except RuntimeError as e: + if "out of bounds of its current constraints" in str(e): + if i == max_retries - 1: + raise RuntimeError( + "Failed to sample a feasible parameter value " + f"from the prior after {max_retries} attempts." + ) + else: + raise e + + +def allclose_mll( + a: MarginalLogLikelihood, + b: MarginalLogLikelihood, + transform_a: Optional[Callable[[Tensor], Tensor]] = None, + transform_b: Optional[Callable[[Tensor], Tensor]] = None, + rtol: float = 1e-05, + atol: float = 1e-08, +) -> bool: + r"""Convenience method for testing whether the log likelihoods produced by different + MarginalLogLikelihood instances, when evaluated on their respective models' training + sets, are allclose. + + Args: + a: A MarginalLogLikelihood instance. + b: A second MarginalLogLikelihood instance. + transform_a: Optional callable used to post-transform log likelihoods under `a`. + transform_b: Optional callable used to post-transform log likelihoods under `b`. + rtol: Relative tolerance. + atol: Absolute tolerance. + + Returns: + Boolean result of the allclose test. + """ + warn("`allclose_mll` is marked for deprecation.", DeprecationWarning) + + values_a = a( + a.model(*a.model.train_inputs), + a.model.train_targets, + *_get_extra_mll_args(a), + ) + if transform_a: + values_a = transform_a(values_a) + + values_b = b( + b.model(*b.model.train_inputs), + b.model.train_targets, + *_get_extra_mll_args(b), + ) + if transform_b: + values_b = transform_b(values_b) + + return values_a.allclose(values_b, rtol=rtol, atol=atol) diff --git a/botorch/optim/utils/numpy_utils.py b/botorch/optim/utils/numpy_utils.py new file mode 100644 index 0000000000..052b58ec69 --- /dev/null +++ b/botorch/optim/utils/numpy_utils.py @@ -0,0 +1,174 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +r"""Utilities for interfacing Numpy and Torch.""" + +from __future__ import annotations + +from itertools import tee +from math import prod +from typing import Callable, Dict, Iterator, Optional, Tuple, Union + +import numpy as np +import torch +from numpy import ndarray +from torch import Tensor + +# Dictionaries mapping numpy to torch dtypes and vice-versa +numpy_to_torch_dtype_dict = { + np.bool: torch.bool, + np.uint8: torch.uint8, + np.int8: torch.int8, + np.int16: torch.int16, + np.int32: torch.int32, + np.int64: torch.int64, + np.float16: torch.float16, + np.float32: torch.float32, + np.float64: torch.float64, + np.complex64: torch.complex64, + np.complex128: torch.complex128, +} + +torch_to_numpy_dtype_dict = { + value: key for (key, value) in numpy_to_torch_dtype_dict.items() +} + + +def as_ndarray( + values: Tensor, dtype: Optional[np.dtype] = None, inplace: bool = True +) -> ndarray: + r"""Helper for going from torch.Tensor to numpy.ndarray. + + Args: + values: Tensor to be converted to ndarray. + dtype: Optional numpy.dtype for the converted tensor. + inplace: Boolean indicating whether memory should be shared if possible. + + Returns: + An ndarray with the same data as `values`. + """ + with torch.no_grad(): + out = values.cpu() # maybe transfer to cpu + + # Determine whether or not to `clone` + if ( + # cond 1: are we not in `inplace` mode? + not inplace + # cond 2: did we already copy when calling `cpu` above? + and out.device == values.device + # cond 3: will we copy when calling `astype` below? + and (dtype is None or out.dtype == torch_to_numpy_dtype_dict[dtype]) + ): + out = out.clone() + + # Convert to ndarray and maybe cast to `dtype` + out = out.numpy() + return out if (dtype is None or dtype == out.dtype) else out.astype(dtype) + + +def get_tensors_as_ndarray_1d( + tensors: Union[Iterator[Tensor], Dict[str, Tensor]], + out: Optional[ndarray] = None, + dtype: Optional[Union[np.dtype, str]] = None, + as_array: Callable[[Tensor], ndarray] = as_ndarray, +) -> ndarray: + # Create a pair of iterators, one for setup and one for data transfer + named_tensors_iter, named_tensors_iter2 = tee( + iter(tensors.items()) if isinstance(tensors, dict) else enumerate(tensors), 2 + ) + + # Use `named_tensors_iter` to get size of `out` and `dtype` when None + try: + name, tnsr = next(named_tensors_iter) + except StopIteration: + raise RuntimeError(f"Argument `tensors` with type {type(tensors)} is empty.") + size = tnsr.numel() + sum(tnsr.numel() for _, tnsr in named_tensors_iter) + dtype = torch_to_numpy_dtype_dict[tnsr.dtype] if dtype is None else dtype + + # Preallocate or validate `out` + if out is None: # use first tensor as a reference when `dtype` is None + out = np.empty([size], dtype=dtype) + elif out.ndim != 1: + raise ValueError(f"Expected a vector for `out`, but out.shape={out.shape}.") + elif out.size != size: + raise ValueError( + f"Size of `parameters` ({size}) does not match size of `out` ({out.size})." + ) + + # Use `named_tensors_iter2` to transfer data from `tensors` to `out` + index = 0 + for name, tnsr in named_tensors_iter2: + try: + size = tnsr.numel() + out[index : index + size] = as_array(tnsr.view(-1)) + index += size + except Exception as e: + raise RuntimeError( + "`get_tensors_as_ndarray_1d` failed while copying values from " + f"tensor {name}; rethrowing original exception." + ) from e + + return out + + +def set_tensors_from_ndarray_1d( + tensors: Union[Iterator[Tensor], Dict[str, Tensor]], + array: ndarray, + as_tensor: Callable[[ndarray], Tensor] = torch.as_tensor, +) -> None: + r"""Sets the values of one more tensors based off of a vector of assignments.""" + named_tensors_iter = ( + iter(tensors.items()) if isinstance(tensors, dict) else enumerate(tensors) + ) + with torch.no_grad(): + index = 0 + for name, tnsr in named_tensors_iter: + try: + size = tnsr.numel() + vals = array[index : index + size] if tnsr.ndim else array[index] + tnsr.copy_(as_tensor(vals).to(tnsr).view(tnsr.shape).to(tnsr)) + index += size + except Exception as e: + raise RuntimeError( + "`set_tensors_from_ndarray_1d` failed while copying values to " + f"tensor {name}; rethrowing original exception." + ) from e + + +def get_bounds_as_ndarray( + parameters: Dict[str, Tensor], + bounds: Dict[str, Tuple[Optional[float], Optional[float]]], +) -> Optional[np.ndarray]: + r"""Helper method for extracting a module's parameters and their respective + ranges. + + Args: + module: The target module from which parameters are to be extracted. + name_filter: Optional Boolean function used to filter parameters by name. + requires_grad: Optional Boolean used to filter parameters based on whether + or not their require_grad attribute matches the user provided value. + default_bounds: Default lower and upper bounds for constrained parameters + with `None` typed bounds. + + Returns: + A dictionary of parameters and a dictionary of parameter bounds. + """ + inf = float("inf") + out = None + index = 0 + for name, param in parameters.items(): + size = prod(param.shape) + if name in bounds: + lower, upper = bounds[name] + lower = -inf if lower is None else lower + upper = inf if upper is None else upper + if lower != -inf or upper != inf: + if out is None: + full_size = sum(prod(param.shape) for param in parameters.values()) + out = np.full((full_size, 2), (-inf, inf)) + out[index : index + size] = (lower, upper) + index = index + size + return out diff --git a/botorch/utils/context_managers.py b/botorch/utils/context_managers.py new file mode 100644 index 0000000000..6257239793 --- /dev/null +++ b/botorch/utils/context_managers.py @@ -0,0 +1,178 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +r""" +Utilities for optimization. +""" + +from __future__ import annotations + +from contextlib import contextmanager +from typing import Any, Callable, Dict, Generator, Iterable, NamedTuple, Optional, Union + +from torch import device as Device, dtype as Dtype, Tensor +from torch.nn import Module + + +class TensorCheckpoint(NamedTuple): + values: Tensor + device: Optional[Device] = None + dtype: Optional[Dtype] = None + + +@contextmanager +def del_attribute_ctx( + instance: object, *attrs: str, enforce_hasattr: bool = False +) -> Generator[None, None, None]: + r"""Contextmanager for temporarily deleting attributes.""" + try: + cache = {} + for key in attrs: + if hasattr(instance, key): + cache[key] = getattr(instance, key) + delattr(instance, key) + elif enforce_hasattr: + raise ValueError( + f"Attribute {key} missing from {type(instance)} instance." + ) + yield + finally: + for key, cached_val in cache.items(): + setattr(instance, key, cached_val) + + +@contextmanager +def requires_grad_ctx( + module: Module, assignments: Dict[str, bool] +) -> Generator[None, None, None]: + r"""Contextmanager for temporarily setting the requires_grad field of a module's + parameters.""" + try: + cache = {} + for name, mode in assignments.items(): + parameter = module.get_parameter(name) + cache[name] = parameter.requires_grad + parameter.requires_grad_(mode) + yield + finally: + for name, mode in cache.items(): + module.get_parameter(name).requires_grad_(mode) + + +@contextmanager +def parameter_rollback_ctx( + parameters: Dict[str, Tensor], + checkpoint: Optional[Dict[str, TensorCheckpoint]] = None, + **tkwargs: Any, +) -> Generator[Dict[str, TensorCheckpoint], None, None]: + r"""Contextmanager that exits by rolling back a module's state_dict. + + Args: + module: Module instance. + name_filter: Optional Boolean function used to filter items by name. + checkpoint: Optional cache of values and tensor metadata specifying the rollback + state for the module (or some subset thereof). + **tkwargs: Keyword arguments passed to `torch.Tensor.to` when copying data from + each tensor in `module.state_dict()` to the internally created checkpoint. + Only adhered to when the `checkpoint` argument is None. + + Yields: + A dictionary of TensorCheckpoints for the module's state_dict. Any in-places + changes to the checkpoint will be observed at rollback time. If the checkpoint + is cleared, no rollback will occur. + """ + # Create copies of the orginal values + if checkpoint is None: + checkpoint = { + name: TensorCheckpoint( + values=param.detach().to(**tkwargs).clone(), + device=param.device, + dtype=param.dtype, + ) + for name, param in parameters.items() + } + + try: # yield the checkpoint dictionary to the user + yield checkpoint + finally: # restore original values of tracked parameters + if checkpoint: + for name, param in parameters.items(): + if name in checkpoint: + values, device, dtype = checkpoint[name] + param.data.copy_(values.to(device=device, dtype=dtype)) + + +@contextmanager +def module_rollback_ctx( + module: Module, + name_filter: Optional[Callable[[str], bool]] = None, + checkpoint: Optional[Dict[str, TensorCheckpoint]] = None, + **tkwargs: Any, +) -> Generator[Dict[str, TensorCheckpoint], None, None]: + r"""Contextmanager that exits by rolling back a module's state_dict. + + Args: + module: Module instance. + name_filter: Optional Boolean function used to filter items by name. + checkpoint: Optional cache of values and tensor metadata specifying the rollback + state for the module (or some subset thereof). + **tkwargs: Keyword arguments passed to `torch.Tensor.to` when copying data from + each tensor in `module.state_dict()` to the internally created checkpoint. + Only adhered to when the `checkpoint` argument is None. + + Yields: + A dictionary of TensorCheckpoints for the module's state_dict. Any in-places + changes to the checkpoint will be observed at rollback time. If the checkpoint + is cleared, no rollback will occur. + """ + # Create copies of the orginal values + if checkpoint is None: + checkpoint = { + name: TensorCheckpoint( + values=values.detach().to(**tkwargs).clone(), + device=values.device, + dtype=values.dtype, + ) + for name, values in module.state_dict().items() + if name_filter is None or name_filter(name) + } + + try: # yield the checkpoint dictionary to the user + yield checkpoint + finally: # restore original values of tracked parameters + if checkpoint: + state_dict = module.state_dict() + for key, (values, device, dtype) in checkpoint.items(): + tnsr = state_dict.get(key) + if tnsr is None: + state_dict[key] = values.to(device=device, dtype=dtype) + else: + tnsr[...] = values.to(device=device, dtype=dtype) + + module.load_state_dict(state_dict) + + +@contextmanager +def zero_grad_ctx( + parameters: Union[Dict[str, Tensor], Iterable[Tensor]], + zero_on_enter: bool = True, + zero_on_exit: bool = False, +) -> Generator[None, None, None]: + def zero_() -> None: + for param in ( + parameters.values() if isinstance(parameters, dict) else parameters + ): + if param.grad is not None: + param.grad.zero_() + + if zero_on_enter: + zero_() + + try: + yield + finally: + if zero_on_exit: + zero_() diff --git a/botorch/utils/dispatcher.py b/botorch/utils/dispatcher.py index 285085ae44..7e4fe2e113 100644 --- a/botorch/utils/dispatcher.py +++ b/botorch/utils/dispatcher.py @@ -16,6 +16,11 @@ ) +def type_bypassing_encoder(arg: Any) -> Type: + # Allow type variables to be passed as pre-encoded arguments + return arg if isinstance(arg, type) else type(arg) + + class Dispatcher(MDDispatcher): r"""Clearing house for multiple dispatch functionality. This class extends `` by: (i) generalizing the argument encoding diff --git a/sphinx/source/optim.rst b/sphinx/source/optim.rst index a3295f81d2..bd8ac5bab6 100644 --- a/sphinx/source/optim.rst +++ b/sphinx/source/optim.rst @@ -10,6 +10,11 @@ botorch.optim Optimization ------------------------------------------- +Core +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +.. automodule:: botorch.optim.core + :members: + Acquisition Function Optimization ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. automodule:: botorch.optim.optimize @@ -31,21 +36,49 @@ Stopping Criteria .. automodule:: botorch.optim.stopping :members: +Closures +------------------------------------------- + +Core +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +.. automodule:: botorch.optim.closures.core + :members: + +Model Fitting Closures +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +.. automodule:: botorch.optim.closures.model_closures + :members: + Utilities ------------------------------------------- +General Optimization Utilities +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +.. automodule:: botorch.optim.utils.common + :members: + +Acquisition Optimization Utilities +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +.. automodule:: botorch.optim.utils.acquisition_utils + :members: + +Model Fitting Utilities +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +.. automodule:: botorch.optim.utils.model_utils + :members: + Numpy - Torch Conversion Tools ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -.. automodule:: botorch.optim.numpy_converter +.. automodule:: botorch.optim.utils.numpy_utils :members: -Parameter Constraint Utilities +Numpy - Torch Conversion Tools (OLD) ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -.. automodule:: botorch.optim.parameter_constraints +.. automodule:: botorch.optim.numpy_converter :members: -General Optimization Utilities +Parameter Constraint Utilities ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -.. automodule:: botorch.optim.utils +.. automodule:: botorch.optim.parameter_constraints :members: diff --git a/sphinx/source/utils.rst b/sphinx/source/utils.rst index 11cff4e181..4e8440b862 100644 --- a/sphinx/source/utils.rst +++ b/sphinx/source/utils.rst @@ -17,6 +17,11 @@ Containers .. automodule:: botorch.utils.containers :members: +Context Managers +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +.. automodule:: botorch.utils.context_managers + :members: + Datasets ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. automodule:: botorch.utils.datasets diff --git a/test/models/test_higher_order_gp.py b/test/models/test_higher_order_gp.py index 46d2f9fb22..466398296a 100644 --- a/test/models/test_higher_order_gp.py +++ b/test/models/test_higher_order_gp.py @@ -13,7 +13,7 @@ from botorch.models.higher_order_gp import FlattenedStandardize from botorch.models.transforms.input import Normalize from botorch.models.transforms.outcome import Standardize -from botorch.optim.fit import fit_gpytorch_torch +from botorch.optim.fit import fit_gpytorch_mll_torch from botorch.posteriors import GPyTorchPosterior, TransformedPosterior from botorch.sampling import IIDNormalSampler from botorch.utils.testing import BotorchTestCase @@ -59,7 +59,7 @@ def setUp(self): for m in [self.model, model_2, model_3]: mll = ExactMarginalLogLikelihood(m.likelihood, m) - fit_gpytorch_torch(mll, options={"maxiter": 1, "disp": False}) + fit_gpytorch_mll_torch(mll, step_limit=1) def test_num_output_dims(self): for dtype in [torch.float, torch.double]: @@ -137,7 +137,7 @@ def test_transforms(self): outcome_transform=FlattenedStandardize(train_y.shape[1:]), ) mll = ExactMarginalLogLikelihood(model.likelihood, model) - fit_gpytorch_torch(mll, options={"maxiter": 1, "disp": False}) + fit_gpytorch_mll_torch(mll, step_limit=1) test_x = torch.rand(2, 5, 3, device=self.device, dtype=dtype) test_y = torch.randn(2, 5, 4, 5, device=self.device, dtype=dtype) diff --git a/test/optim/closures/__init__.py b/test/optim/closures/__init__.py new file mode 100644 index 0000000000..4b87eb9e4d --- /dev/null +++ b/test/optim/closures/__init__.py @@ -0,0 +1,5 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. diff --git a/test/optim/closures/test_core.py b/test/optim/closures/test_core.py new file mode 100644 index 0000000000..6ce4f78a58 --- /dev/null +++ b/test/optim/closures/test_core.py @@ -0,0 +1,189 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from contextlib import nullcontext +from functools import partial +from typing import Dict +from unittest.mock import MagicMock + +import numpy as np +import torch +from botorch.optim.closures.core import ( + ForwardBackwardClosure, + get_tensors_as_ndarray_1d, + NdarrayOptimizationClosure, +) +from botorch.optim.utils import as_ndarray +from botorch.utils.context_managers import zero_grad_ctx +from botorch.utils.testing import BotorchTestCase +from linear_operator.utils.errors import NanError, NotPSDError +from torch.nn import Module, Parameter + + +class ToyModule(Module): + def __init__(self, w: Parameter, b: Parameter, x: Parameter, dummy: Parameter): + r"""Toy module for unit testing.""" + super().__init__() + self.w = w + self.b = b + self.x = x + self.dummy = dummy + + def forward(self) -> torch.Tensor: + return self.w * self.x + self.b + + @property + def free_parameters(self) -> Dict[str, torch.Tensor]: + return {n: p for n, p in self.named_parameters() if p.requires_grad} + + +class TestForwardBackwardClosure(BotorchTestCase): + def setUp(self): + super().setUp() + module = ToyModule( + w=Parameter(torch.tensor(2.0)), + b=Parameter(torch.tensor(3.0), requires_grad=False), + x=Parameter(torch.tensor(4.0)), + dummy=Parameter(torch.tensor(5.0)), + ).to(self.device) + self.modules = {} + for dtype in ("float32", "float64"): + self.modules[dtype] = module.to(dtype=getattr(torch, dtype)) + + def test_main(self): + for module in self.modules.values(): + closure = ForwardBackwardClosure(module, module.free_parameters) + + # Test __init__ + closure = ForwardBackwardClosure(module, module.free_parameters) + self.assertEqual(module.free_parameters, closure.parameters) + self.assertIsInstance(closure.context_manager, partial) + self.assertEqual(closure.context_manager.func, zero_grad_ctx) + + # Test return values + value, (dw, dx, dd) = closure() + self.assertTrue(value.equal(module())) + self.assertTrue(dw.equal(module.x)) + self.assertTrue(dx.equal(module.w)) + self.assertEqual(dd, None) + + # Test `callback`` and `reducer`` + closure = ForwardBackwardClosure(module, module.free_parameters) + mock_reducer = MagicMock(return_value=closure.forward()) + mock_callback = MagicMock() + closure = ForwardBackwardClosure( + forward=module, + parameters=module.free_parameters, + reducer=mock_reducer, + callback=mock_callback, + ) + value, grads = closure() + mock_reducer.assert_called_once_with(value) + mock_callback.assert_called_once_with(value, grads) + + # Test `backward`` and `context_manager` + closure = ForwardBackwardClosure( + forward=module, + parameters=module.free_parameters, + backward=partial(torch.Tensor.backward, retain_graph=True), + context_manager=nullcontext, + ) + _, (dw, dx, dd) = closure() # x2 because `grad` is no longer zeroed + self.assertTrue(dw.equal(2 * module.x)) + self.assertTrue(dx.equal(2 * module.w)) + self.assertEqual(dd, None) + + +class TestNdarrayOptimizationClosure(BotorchTestCase): + def setUp(self): + super().setUp() + self.module = ToyModule( + w=Parameter(torch.tensor(2.0)), + b=Parameter(torch.tensor(3.0), requires_grad=False), + x=Parameter(torch.tensor(4.0)), + dummy=Parameter(torch.tensor(5.0)), + ).to(self.device) + + self.wrappers = {} + for dtype in ("float32", "float64"): + module = self.module.to(dtype=getattr(torch, dtype)) + closure = ForwardBackwardClosure(module, module.free_parameters) + wrapper = NdarrayOptimizationClosure(closure, closure.parameters) + self.wrappers[dtype] = wrapper + + def test_main(self): + for wrapper in self.wrappers.values(): + # Test setter/getter + state = get_tensors_as_ndarray_1d(wrapper.closure.parameters) + other = np.random.randn(*state.shape).astype(state.dtype) + + wrapper.state = other + self.assertTrue(np.allclose(other, wrapper.state)) + + index = 0 + for param in wrapper.closure.parameters.values(): + size = param.numel() + self.assertTrue( + np.allclose( + other[index : index + size], wrapper.as_array(param.view(-1)) + ) + ) + index += size + + wrapper.state = state + self.assertTrue(np.allclose(state, wrapper.state)) + + # Test __call__ + value, grads = wrapper(other) + self.assertTrue(np.allclose(other, wrapper.state)) + self.assertIsInstance(value, np.ndarray) + self.assertIsInstance(grads, np.ndarray) + + # Test return values + value_tensor, grad_tensors = wrapper.closure() # get raw Tensor equivalents + self.assertTrue(np.allclose(value, wrapper.as_array(value_tensor))) + index = 0 + for x, dx in zip(wrapper.parameters.values(), grad_tensors): + size = x.numel() + grad = grads[index : index + size] + if dx is None: + self.assertTrue((grad == wrapper.fill_value).all()) + else: + self.assertTrue(np.allclose(grad, wrapper.as_array(dx))) + index += size + + module = wrapper.closure.forward + self.assertTrue(np.allclose(grads[0], as_ndarray(module.x))) + self.assertTrue(np.allclose(grads[1], as_ndarray(module.w))) + self.assertEqual(grads[2], wrapper.fill_value) + + # Test persistent buffers + for mode in (False, True): + wrapper.persistent = mode + self.assertEqual( + mode, + wrapper._get_gradient_ndarray() is wrapper._get_gradient_ndarray(), + ) + + def test_exceptions(self): + for wrapper in self.wrappers.values(): + mock_closure = MagicMock(return_value=wrapper.closure()) + mock_wrapper = NdarrayOptimizationClosure( + mock_closure, wrapper.closure.parameters + ) + with self.assertRaisesRegex(NotPSDError, "foo"): + mock_wrapper.closure.side_effect = NotPSDError("foo") + mock_wrapper() + + for exception in ( + NanError("foo"), + RuntimeError("singular"), + RuntimeError("input is not positive-definite"), + ): + mock_wrapper.closure.side_effect = exception + value, grads = mock_wrapper() + self.assertTrue(np.isnan(value).all()) + self.assertTrue(np.isnan(grads).all()) diff --git a/test/optim/closures/test_model_closures.py b/test/optim/closures/test_model_closures.py new file mode 100644 index 0000000000..188dc93e8a --- /dev/null +++ b/test/optim/closures/test_model_closures.py @@ -0,0 +1,88 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from itertools import zip_longest +from math import pi + +import torch +from botorch.models import ModelListGP, SingleTaskGP +from botorch.models.transforms.input import Normalize +from botorch.models.transforms.outcome import Standardize +from botorch.optim.closures.model_closures import ( + get_loss_closure, + get_loss_closure_with_grads, +) +from botorch.utils.testing import BotorchTestCase +from gpytorch import settings as gpytorch_settings +from gpytorch.mlls import ExactMarginalLogLikelihood, SumMarginalLogLikelihood +from torch.utils.data import DataLoader, TensorDataset + + +class TestLossClosures(BotorchTestCase): + def setUp(self): + super().setUp() + with torch.random.fork_rng(): + torch.manual_seed(0) + train_X = torch.linspace(0, 1, 10).unsqueeze(-1) + train_Y = torch.sin((2 * pi) * train_X) + train_Y = train_Y + 0.1 * torch.randn_like(train_Y) + + self.mlls = {} + model = SingleTaskGP( + train_X=train_X, + train_Y=train_Y, + input_transform=Normalize(d=1), + outcome_transform=Standardize(m=1), + ) + mll = ExactMarginalLogLikelihood(model.likelihood, model) + self.mlls[type(mll), type(model.likelihood), type(model)] = mll.to(self.device) + + model = ModelListGP(model, model) + mll = SumMarginalLogLikelihood(model.likelihood, model) + self.mlls[type(mll), type(model.likelihood), type(model)] = mll.to(self.device) + + def test_main(self): + for mll in self.mlls.values(): + out = mll.model(*mll.model.train_inputs) + loss = -mll(out, mll.model.train_targets).sum() + loss.backward() + params = {n: p for n, p in mll.named_parameters() if p.requires_grad} + grads = [ + torch.zeros_like(p) if p.grad is None else p.grad + for p in params.values() + ] + + closure = get_loss_closure(mll) + self.assertTrue(loss.equal(closure())) + + closure = get_loss_closure_with_grads(mll, params) + _loss, _grads = closure() + self.assertTrue(loss.equal(_loss)) + self.assertTrue(all(a.equal(b) for a, b in zip_longest(grads, _grads))) + + def test_data_loader(self): + for mll in self.mlls.values(): + if type(mll) != ExactMarginalLogLikelihood: + continue + + dataset = TensorDataset(*mll.model.train_inputs, mll.model.train_targets) + loader = DataLoader(dataset, batch_size=len(mll.model.train_targets)) + params = {n: p for n, p in mll.named_parameters() if p.requires_grad} + A = get_loss_closure_with_grads(mll, params) + (a, das) = A() + + B = get_loss_closure_with_grads(mll, params, data_loader=loader) + with gpytorch_settings.debug(False): # disables GPyTorch's internal check + (b, dbs) = B() + + self.assertTrue(a.allclose(b)) + for da, db in zip_longest(das, dbs): + self.assertTrue(da.allclose(db)) + + loader = DataLoader(mll.model.train_targets, len(mll.model.train_targets)) + closure = get_loss_closure_with_grads(mll, params, data_loader=loader) + with self.assertRaisesRegex(TypeError, "Expected .* a batch of tensors"): + closure() diff --git a/test/optim/test_core.py b/test/optim/test_core.py new file mode 100644 index 0000000000..df55583c70 --- /dev/null +++ b/test/optim/test_core.py @@ -0,0 +1,268 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from functools import partial +from typing import Dict +from unittest.mock import MagicMock, patch + +import torch +from botorch.optim import core +from botorch.optim.closures import ForwardBackwardClosure, NdarrayOptimizationClosure +from botorch.optim.core import ( + OptimizationResult, + OptimizationStatus, + scipy_minimize, + torch_minimize, +) +from botorch.utils.testing import BotorchTestCase +from numpy import allclose +from scipy.optimize import OptimizeResult +from torch import Tensor +from torch.nn import Module, Parameter +from torch.optim.sgd import SGD + +try: + from torch.optim.lr_scheduler import LRScheduler +except ImportError: # pragma: no cover + from torch.optim.lr_scheduler import _LRScheduler as LRScheduler # pragma: no cover + + +class ToyModule(Module): + def __init__(self, b: Parameter, x: Parameter, dummy: Parameter): + r"""Toy module for unit testing.""" + super().__init__() + self.x = x + self.b = b + self.dummy = dummy + + def forward(self) -> Tensor: + return (self.x - self.b).square().sum() + + @property + def free_parameters(self) -> Dict[str, Tensor]: + return {n: p for n, p in self.named_parameters() if p.requires_grad} + + +class TestScipyMinimize(BotorchTestCase): + def setUp(self): + super().setUp() + module = ToyModule( + x=Parameter(torch.tensor(0.5, device=self.device)), + b=Parameter(torch.tensor(0.0, device=self.device), requires_grad=False), + dummy=Parameter(torch.tensor(1.0, device=self.device)), + ).to(self.device) + + self.closures = {} + for dtype in ("float32", "float64"): + m = module.to(dtype=getattr(torch, dtype)) + self.closures[dtype] = ForwardBackwardClosure(m, m.free_parameters) + + def test_basic(self): + x = Parameter(torch.rand([])) + + def closure(): + if x.grad is not None: + x.grad.zero_() + + loss = x.square().sum() + loss.backward() + return loss, [x.grad] + + result = scipy_minimize(closure, {"x": x}) + self.assertEqual(result.status, OptimizationStatus.SUCCESS) + self.assertTrue(allclose(result.fval, 0.0)) + + def test_main(self): + def _callback(parameters, result, out) -> None: + out.append(result) + + for closure in self.closures.values(): + for with_wrapper in (True, False): + with torch.no_grad(): + cache = {} # cache random starting values + for name, param in closure.parameters.items(): + init = cache[name] = torch.rand_like(param) + param.data.copy_(init) + + closure_arg = ( + NdarrayOptimizationClosure(closure, closure.parameters) + if with_wrapper + else closure + ) + result = scipy_minimize( + closure=closure_arg, + parameters=closure.parameters, + bounds={"x": (0, 1)}, + ) + self.assertIsInstance(result, OptimizationResult) + self.assertEqual(result.status, OptimizationStatus.SUCCESS) + self.assertTrue(allclose(result.fval, 0.0)) + self.assertTrue(closure.parameters["dummy"].equal(cache["dummy"])) + self.assertFalse(closure.parameters["x"].equal(cache["x"])) + + # Test `bounds` and `callback` + with torch.no_grad(): # closure.forward is a ToyModule instance + closure.forward.b.fill_(0.0) + closure.forward.x.fill_(0.5) + + step_results = [] + result = scipy_minimize( + closure=closure, + parameters=closure.parameters, + bounds={"x": (0.1, 1.0)}, + callback=partial(_callback, out=step_results), + ) + self.assertTrue(allclose(0.01, result.fval)) + self.assertTrue(allclose(0.1, closure.forward.x.detach().cpu().item())) + + self.assertEqual(result.step, len(step_results)) + self.assertEqual(result.step, step_results[-1].step) + self.assertEqual(result.fval, step_results[-1].fval) + + def test_post_processing(self): + closure = next(iter(self.closures.values())) + wrapper = NdarrayOptimizationClosure(closure, closure.parameters) + with patch.object(core, "minimize") as mock_minimize: + for status, msg in ( + (OptimizationStatus.FAILURE, b"ABNORMAL_TERMINATION_IN_LNSRCH"), + (OptimizationStatus.STOPPED, "TOTAL NO. of ITERATIONS REACHED LIMIT"), + ): + mock_minimize.return_value = OptimizeResult( + x=wrapper.state, + fun=1.0, + nit=3, + success=False, + message=msg, + ) + result = core.scipy_minimize(wrapper, closure.parameters) + self.assertEqual(result.status, status) + self.assertEqual(result.fval, mock_minimize.return_value.fun) + self.assertEqual( + result.message, msg if isinstance(msg, str) else msg.decode("ascii") + ) + + +class TestTorchMinimize(BotorchTestCase): + def setUp(self): + super().setUp() + module = ToyModule( + x=Parameter(torch.tensor(0.5, device=self.device)), + b=Parameter(torch.tensor(0.0, device=self.device), requires_grad=False), + dummy=Parameter(torch.tensor(1.0, device=self.device)), + ).to(self.device) + + self.closures = {} + for dtype in ("float32", "float64"): + m = module.to(dtype=getattr(torch, dtype)) + self.closures[dtype] = ForwardBackwardClosure(m, m.free_parameters) + + def test_basic(self): + x = Parameter(torch.tensor([0.02])) + + def closure(): + if x.grad is not None: + x.grad.zero_() + + loss = x.square().sum() + loss.backward() + return loss, [x.grad] + + result = torch_minimize(closure, {"x": x}, step_limit=100) + self.assertEqual(result.status, OptimizationStatus.STOPPED) + self.assertTrue(allclose(result.fval, 0.0)) + + def test_main(self): + def _callback(parameters, result, out) -> None: + out.append(result) + + for closure in self.closures.values(): + # Test that we error out if no termination conditions are given + with self.assertRaisesRegex(RuntimeError, "No termination conditions"): + torch_minimize(closure=closure, parameters=closure.parameters) + + # Test single step behavior + for optimizer in ( + SGD(params=list(closure.parameters.values()), lr=0.1), # instance + partial(SGD, lr=0.1), # factory + ): + cache = {n: p.detach().clone() for n, p in closure.parameters.items()} + grads = [g if g is None else g.detach().clone() for g in closure()[1]] + result = torch_minimize( + closure=closure, + parameters=closure.parameters, + optimizer=optimizer, + step_limit=1, + ) + self.assertIsInstance(result, OptimizationResult) + self.assertEqual(result.fval, closure()[0]) + self.assertEqual(result.step, 1) + self.assertEqual(result.status, OptimizationStatus.STOPPED) + self.assertTrue(closure.parameters["dummy"].equal(cache["dummy"])) + self.assertFalse(closure.parameters["x"].equal(cache["x"])) + for (name, param), g in zip(closure.parameters.items(), grads): + self.assertTrue( + param.allclose(cache[name] - (0 if g is None else 0.1 * g)) + ) + + # Test local convergence + with torch.no_grad(): # closure.forward is a ToyModule instance + closure.forward.b.fill_(0.0) + closure.forward.x.fill_(0.02) + + result = torch_minimize(closure, closure.parameters, step_limit=100) + self.assertTrue(allclose(0.0, result.fval)) + self.assertEqual(result.step, 100) + + # Test `bounds` and `callback` + with torch.no_grad(): # closure.forward is a ToyModule instance + closure.forward.b.fill_(0.0) + closure.forward.x.fill_(0.11) + + step_results = [] + result = torch_minimize( + closure=closure, + parameters=closure.parameters, + bounds={"x": (0.1, 1.0)}, + callback=partial(_callback, out=step_results), + step_limit=100, + ) + self.assertTrue(allclose(0.01, result.fval)) + self.assertEqual(result.step, len(step_results)) + + # Test `stopping_criterion` + stopping_decisions = iter((False, False, True, False)) + result = torch_minimize( + closure=closure, + parameters=closure.parameters, + stopping_criterion=lambda fval: next(stopping_decisions), + ) + self.assertEqual(result.step, 2) + self.assertEqual(result.status, OptimizationStatus.STOPPED) + + # Test passing `scheduler` + mock_scheduler = MagicMock(spec=LRScheduler) + mock_scheduler.step = MagicMock(side_effect=RuntimeError("foo")) + with self.assertRaisesRegex(RuntimeError, "foo"): + torch_minimize( + closure=closure, + parameters=closure.parameters, + scheduler=mock_scheduler, + step_limit=1, + ) + mock_scheduler.step.assert_called_once() + + # Test passing `scheduler` as a factory + optimizer = SGD(list(closure.parameters.values()), lr=1e-3) + mock_factory = MagicMock(side_effect=RuntimeError("foo")) + with self.assertRaisesRegex(RuntimeError, "foo"): + torch_minimize( + closure=closure, + parameters=closure.parameters, + optimizer=optimizer, + scheduler=mock_factory, + step_limit=1, + ) + mock_factory.assert_called_once_with(optimizer) diff --git a/test/optim/test_fit.py b/test/optim/test_fit.py index cdbd72f7c1..3abe79a4f3 100644 --- a/test/optim/test_fit.py +++ b/test/optim/test_fit.py @@ -5,24 +5,208 @@ # LICENSE file in the root directory of this source tree. import math -from re import compile -from unittest.mock import patch -from warnings import catch_warnings, warn +import re +from unittest.mock import MagicMock, patch +from warnings import catch_warnings import torch from botorch.exceptions.warnings import OptimizationWarning from botorch.models import SingleTaskGP from botorch.models.transforms.input import Normalize from botorch.models.transforms.outcome import Standardize -from botorch.optim import fit -from botorch.optim.fit import OptimizationIteration -from botorch.optim.utils import state_rollback_ctx +from botorch.optim import core, fit + +from botorch.optim.core import OptimizationResult from botorch.settings import debug +from botorch.utils.context_managers import module_rollback_ctx, TensorCheckpoint from botorch.utils.testing import BotorchTestCase from gpytorch.mlls.exact_marginal_log_likelihood import ExactMarginalLogLikelihood from scipy.optimize import OptimizeResult +class TestFitGPyTorchMLLScipy(BotorchTestCase): + def setUp(self): + self.mlls = {} + with torch.random.fork_rng(): + torch.manual_seed(0) + train_X = torch.linspace(0, 1, 10).unsqueeze(-1) + train_Y = torch.sin((2 * math.pi) * train_X) + train_Y = train_Y + 0.1 * torch.randn_like(train_Y) + + model = SingleTaskGP( + train_X=train_X, + train_Y=train_Y, + input_transform=Normalize(d=1), + outcome_transform=Standardize(m=1), + ) + self.mlls[SingleTaskGP, 1] = ExactMarginalLogLikelihood(model.likelihood, model) + + def test_fit_gpytorch_mll_scipy(self): + for mll in self.mlls.values(): + for dtype in (torch.float32, torch.float64): + self._test_fit_gpytorch_mll_scipy(mll.to(dtype=dtype)) + + def _test_fit_gpytorch_mll_scipy(self, mll): + options = {"disp": False, "maxiter": 2} + ckpt = { + k: TensorCheckpoint(v.detach().clone(), v.device, v.dtype) + for k, v in mll.state_dict().items() + } + with self.subTest("main"), module_rollback_ctx(mll, checkpoint=ckpt): + with catch_warnings(record=True) as ws, debug(True): + result = fit.fit_gpytorch_mll_scipy(mll, options=options) + + # Test only parameters requiring gradients have changed + self.assertTrue( + all( + param.equal(ckpt[name].values) != param.requires_grad + for name, param in mll.named_parameters() + ) + ) + + # Test maxiter warning message + self.assertTrue(any("TOTAL NO. of" in str(w.message) for w in ws)) + self.assertTrue( + any(issubclass(w.category, OptimizationWarning) for w in ws) + ) + + # Test iteration tracking + self.assertIsInstance(result, OptimizationResult) + self.assertLessEqual(result.step, options["maxiter"]) + self.assertEqual(sum(1 for w in ws if "TOTAL NO. of" in str(w.message)), 1) + + # Test that user provided bounds are respected + with self.subTest("bounds"), module_rollback_ctx(mll, checkpoint=ckpt): + fit.fit_gpytorch_mll_scipy( + mll, + bounds={"likelihood.noise_covar.raw_noise": (123, 456)}, + options=options, + ) + + self.assertTrue( + mll.likelihood.noise_covar.raw_noise >= 123 + and mll.likelihood.noise_covar.raw_noise <= 456 + ) + + for name, param in mll.named_parameters(): + self.assertNotEqual(param.requires_grad, param.equal(ckpt[name].values)) + + # Test handling of scipy optimization failures and parameter assignments + mock_x = [] + assignments = {} + for name, param in mll.named_parameters(): + if not param.requires_grad: + continue # pragma: no cover + + values = assignments[name] = torch.rand_like(param) + mock_x.append(values.view(-1)) + + with module_rollback_ctx(mll, checkpoint=ckpt), patch.object( + core, "minimize" + ) as mock_minimize: + mock_minimize.return_value = OptimizeResult( + x=torch.concat(mock_x).tolist(), + success=False, + status=0, + fun=float("nan"), + jac=None, + nfev=1, + njev=1, + nhev=1, + nit=1, + message="ABNORMAL_TERMINATION_IN_LNSRCH".encode(), + ) + with catch_warnings(record=True) as ws, debug(True): + fit.fit_gpytorch_mll_scipy(mll, options=options) + + # Test that warning gets raised + self.assertTrue( + any("ABNORMAL_TERMINATION_IN_LNSRCH" in str(w.message) for w in ws) + ) + + # Test that parameter values get assigned correctly + self.assertTrue( + all( + param.equal(assignments[name]) + for name, param in mll.named_parameters() + if param.requires_grad + ) + ) + + # Test `closure_kwargs` + with self.subTest("closure_kwargs"): + mock_closure = MagicMock(side_effect=StopIteration("foo")) + with self.assertRaisesRegex(StopIteration, "foo"): + fit.fit_gpytorch_mll_scipy( + mll, closure=mock_closure, closure_kwargs={"ab": "cd"} + ) + mock_closure.assert_called_once_with(ab="cd") + + +class TestFitGPyTorchMLLTorch(BotorchTestCase): + def setUp(self): + self.mlls = {} + with torch.random.fork_rng(): + torch.manual_seed(0) + train_X = torch.linspace(0, 1, 10).unsqueeze(-1) + train_Y = torch.sin((2 * math.pi) * train_X) + train_Y = train_Y + 0.1 * torch.randn_like(train_Y) + + model = SingleTaskGP( + train_X=train_X, + train_Y=train_Y, + input_transform=Normalize(d=1), + outcome_transform=Standardize(m=1), + ) + self.mlls[SingleTaskGP, 1] = ExactMarginalLogLikelihood(model.likelihood, model) + + def test_fit_gpytorch_mll_torch(self): + for mll in self.mlls.values(): + for dtype in (torch.float32, torch.float64): + self._test_fit_gpytorch_mll_torch(mll.to(dtype=dtype)) + + def _test_fit_gpytorch_mll_torch(self, mll): + ckpt = { + k: TensorCheckpoint(v.detach().clone(), v.device, v.dtype) + for k, v in mll.state_dict().items() + } + with self.subTest("main"), module_rollback_ctx(mll, checkpoint=ckpt): + with catch_warnings(record=True) as _, debug(True): + result = fit.fit_gpytorch_mll_torch(mll, step_limit=2) + + self.assertIsInstance(result, OptimizationResult) + self.assertLessEqual(result.step, 2) + + # Test only parameters requiring gradients have changed + self.assertTrue( + all( + param.requires_grad != param.equal(ckpt[name].values) + for name, param in mll.named_parameters() + ) + ) + + # Test that user provided bounds are respected + with self.subTest("bounds"), module_rollback_ctx(mll, checkpoint=ckpt): + fit.fit_gpytorch_mll_torch( + mll, + bounds={"likelihood.noise_covar.raw_noise": (123, 456)}, + ) + + self.assertTrue( + mll.likelihood.noise_covar.raw_noise >= 123 + and mll.likelihood.noise_covar.raw_noise <= 456 + ) + + # Test `closure_kwargs` + with self.subTest("closure_kwargs"): + mock_closure = MagicMock(side_effect=StopIteration("foo")) + with self.assertRaisesRegex(StopIteration, "foo"): + fit.fit_gpytorch_mll_torch( + mll, closure=mock_closure, closure_kwargs={"ab": "cd"} + ) + mock_closure.assert_called_once_with(ab="cd") + + class TestFitGPyTorchScipy(BotorchTestCase): def setUp(self): self.mlls = {} @@ -47,8 +231,11 @@ def test_fit_gpytorch_scipy(self): def _test_fit_gpytorch_scipy(self, mll): options = {"disp": False, "maxiter": 3, "maxfun": 2} - ckpt = {k: (v.detach().clone(), {}) for k, v in mll.state_dict().items()} - with self.subTest("main"), state_rollback_ctx(mll, checkpoint=ckpt): + ckpt = { + k: TensorCheckpoint(v.detach().clone(), v.device, v.dtype) + for k, v in mll.state_dict().items() + } + with self.subTest("main"), module_rollback_ctx(mll, checkpoint=ckpt): with catch_warnings(record=True) as ws, debug(True): _, info_dict = fit.fit_gpytorch_scipy( mll, track_iterations=True, options=options @@ -70,14 +257,14 @@ def _test_fit_gpytorch_scipy(self, mll): # Test iteration tracking self.assertLessEqual(len(info_dict["iterations"]), options["maxiter"]) - self.assertIsInstance(info_dict["iterations"][0], OptimizationIteration) + self.assertIsInstance(info_dict["iterations"][0], OptimizationResult) self.assertTrue("fopt" in info_dict) self.assertTrue("wall_time" in info_dict) self.assertEqual(sum(1 for w in ws if "TOTAL NO. of" in str(w.message)), 1) # Test that user provided bounds and `exclude` argument are respected - exclude = "model.mean_module.constant", compile("raw_lengthscale$") - with self.subTest("bounds"), state_rollback_ctx(mll, checkpoint=ckpt): + exclude = "model.mean_module.constant", re.compile("raw_lengthscale$") + with self.subTest("bounds"), module_rollback_ctx(mll, checkpoint=ckpt): fit.fit_gpytorch_scipy( mll, bounds={"likelihood.noise_covar.raw_noise": (123, 456)}, @@ -103,7 +290,7 @@ def _test_fit_gpytorch_scipy(self, mll): self.assertFalse(param.equal(ckpt[name][0])) # Test use of `approx_mll` flag - with self.subTest("approx_mll"), state_rollback_ctx(mll, checkpoint=ckpt): + with self.subTest("approx_mll"), module_rollback_ctx(mll, checkpoint=ckpt): fit.fit_gpytorch_scipy(mll, approx_mll=True, options=options) self.assertTrue( all( @@ -122,7 +309,7 @@ def _test_fit_gpytorch_scipy(self, mll): values = assignments[name] = torch.rand_like(param) mock_x.append(values.view(-1)) - with state_rollback_ctx(mll, checkpoint=ckpt), patch.object( + with module_rollback_ctx(mll, checkpoint=ckpt), patch.object( fit, "minimize" ) as mock_minimize: mock_minimize.return_value = OptimizeResult( @@ -179,9 +366,12 @@ def test_fit_gpytorch_torch(self): def _test_fit_gpytorch_torch(self, mll): options = {"disp": False, "maxiter": 3} - ckpt = {k: (v.detach().clone(), {}) for k, v in mll.state_dict().items()} - with self.subTest("main"), state_rollback_ctx(mll, checkpoint=ckpt): - with catch_warnings(record=True) as ws, debug(True): + ckpt = { + k: TensorCheckpoint(v.detach().clone(), v.device, v.dtype) + for k, v in mll.state_dict().items() + } + with self.subTest("main"), module_rollback_ctx(mll, checkpoint=ckpt): + with catch_warnings(record=True), debug(True): _, info_dict = fit.fit_gpytorch_torch( mll, track_iterations=True, options=options ) @@ -196,13 +386,13 @@ def _test_fit_gpytorch_torch(self, mll): # Test iteration tracking self.assertEqual(len(info_dict["iterations"]), options["maxiter"]) - self.assertIsInstance(info_dict["iterations"][0], OptimizationIteration) + self.assertIsInstance(info_dict["iterations"][0], OptimizationResult) self.assertTrue("fopt" in info_dict) self.assertTrue("wall_time" in info_dict) # Test that user provided bounds and `exclude` argument are respected - exclude = "model.mean_module.constant", compile("raw_lengthscale$") - with self.subTest("bounds"), state_rollback_ctx(mll, checkpoint=ckpt): + exclude = "model.mean_module.constant", re.compile("raw_lengthscale$") + with self.subTest("bounds"), module_rollback_ctx(mll, checkpoint=ckpt): fit.fit_gpytorch_torch( mll, bounds={"likelihood.noise_covar.raw_noise": (123, 456)}, @@ -228,7 +418,7 @@ def _test_fit_gpytorch_torch(self, mll): self.assertFalse(param.equal(ckpt[name][0])) # Test use of `approx_mll` flag - with self.subTest("approx_mll"), state_rollback_ctx(mll, checkpoint=ckpt): + with self.subTest("approx_mll"), module_rollback_ctx(mll, checkpoint=ckpt): fit.fit_gpytorch_torch(mll, approx_mll=True, options=options) self.assertTrue( all( @@ -236,7 +426,3 @@ def _test_fit_gpytorch_torch(self, mll): for name, param in mll.named_parameters() ) ) - - with patch.object(fit, "print", new=warn), catch_warnings(record=True) as ws: - fit.fit_gpytorch_torch(mll, options={"disp": True, "maxiter": 11}) - self.assertEqual(len(ws), 2) diff --git a/test/optim/test_numpy_converter.py b/test/optim/test_numpy_converter.py index 7a309330e8..a3b6c4ba54 100644 --- a/test/optim/test_numpy_converter.py +++ b/test/optim/test_numpy_converter.py @@ -4,15 +4,16 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from re import compile -from string import ascii_lowercase -from unittest.mock import MagicMock +from math import pi +from unittest.mock import MagicMock, patch +from warnings import catch_warnings, simplefilter import numpy as np import torch +from botorch.models import SingleTaskGP +from botorch.optim import numpy_converter from botorch.optim.numpy_converter import ( - create_name_filter, - get_parameters_and_bounds, + _scipy_objective_and_grad, module_to_array, set_params_with_array, ) @@ -34,43 +35,6 @@ def _get_index(property_dict, parameter_name): return idx -class TestCreateNameFilter(BotorchTestCase): - def test_create_name_filter(self): - with self.assertRaises(TypeError): - create_name_filter(("foo", compile("bar"), 1)) - - names = ascii_lowercase - name_filter = create_name_filter(iter(names[1::2])) - self.assertEqual(names[::2], "".join(filter(name_filter, names))) - - items = tuple(zip(names, range(len(names)))) - self.assertEqual(items[::2], tuple(filter(name_filter, items))) - - -class TestGetParametersAndBounds(BotorchTestCase): - def setUp(self): - self.module = GaussianLikelihood( - noise_constraint=GreaterThan(1e-6, initial_value=0.123), - ) - - def test_get_parameters_and_bounds(self): - module = GaussianLikelihood( - noise_constraint=GreaterThan(1e-6, initial_value=0.123), - ) - param_dict, bounds_dict = get_parameters_and_bounds(module) - self.assertTrue(1 == len(param_dict) == len(bounds_dict)) - - name, bounds = next(iter(bounds_dict.items())) - self.assertEqual(name, "noise_covar.raw_noise") - self.assertEqual(bounds, (-float("inf"), float("inf"))) - - mock_module = torch.nn.Module() - mock_module.named_parameters = MagicMock(return_value=module.named_parameters()) - param_dict2, bounds_dict2 = get_parameters_and_bounds(mock_module) - self.assertEqual(param_dict, param_dict2) - self.assertTrue(len(bounds_dict2) == 0) - - class TestModuleToArray(BotorchTestCase): def test_basic(self): for dtype in (torch.float, torch.double): @@ -84,7 +48,9 @@ def test_basic(self): model.to(device=self.device, dtype=dtype) mll = ExactMarginalLogLikelihood(likelihood, model) # test the basic case - x, pdict, bounds = module_to_array(module=mll) + with catch_warnings(): + simplefilter("ignore", category=DeprecationWarning) + x, pdict, bounds = module_to_array(module=mll) self.assertTrue(np.array_equal(x, np.zeros(5))) expected_sizes = { "likelihood.noise_covar.raw_noise": torch.Size([1]), @@ -110,9 +76,11 @@ def test_exclude(self): model.to(device=self.device, dtype=dtype) mll = ExactMarginalLogLikelihood(likelihood, model) # test the basic case - x, pdict, bounds = module_to_array( - module=mll, exclude={"model.mean_module.raw_constant"} - ) + with catch_warnings(): + simplefilter("ignore", category=DeprecationWarning) + x, pdict, bounds = module_to_array( + module=mll, exclude={"model.mean_module.raw_constant"} + ) self.assertTrue(np.array_equal(x, np.zeros(4))) expected_sizes = { "likelihood.noise_covar.raw_noise": torch.Size([1]), @@ -137,9 +105,12 @@ def test_manual_bounds(self): model.to(device=self.device, dtype=dtype) mll = ExactMarginalLogLikelihood(likelihood, model) # test the basic case - x, pdict, bounds = module_to_array( - module=mll, bounds={"model.covar_module.raw_lengthscale": (0.1, None)} - ) + with catch_warnings(): + simplefilter("ignore", category=DeprecationWarning) + x, pdict, bounds = module_to_array( + module=mll, + bounds={"model.covar_module.raw_lengthscale": (0.1, None)}, + ) self.assertTrue(np.array_equal(x, np.zeros(5))) expected_sizes = { "likelihood.noise_covar.raw_noise": torch.Size([1]), @@ -160,13 +131,15 @@ def test_manual_bounds(self): self.assertTrue(np.equal(bounds[0], lower_exp).all()) self.assertTrue(np.equal(bounds[1], np.full_like(x, np.inf)).all()) - x, pdict, bounds = module_to_array( - module=mll, - bounds={ - key: (-float("inf"), float("inf")) - for key, _ in mll.named_parameters() - }, - ) + with catch_warnings(): + simplefilter("ignore", category=DeprecationWarning) + x, pdict, bounds = module_to_array( + module=mll, + bounds={ + key: (-float("inf"), float("inf")) + for key, _ in mll.named_parameters() + }, + ) self.assertIsNone(bounds) def test_module_bounds(self): @@ -183,9 +156,12 @@ def test_module_bounds(self): model.to(device=self.device, dtype=dtype) mll = ExactMarginalLogLikelihood(likelihood, model) # test the basic case - x, pdict, bounds = module_to_array( - module=mll, bounds={"model.covar_module.raw_lengthscale": (0.1, None)} - ) + with catch_warnings(): + simplefilter("ignore", category=DeprecationWarning) + x, pdict, bounds = module_to_array( + module=mll, + bounds={"model.covar_module.raw_lengthscale": (0.1, None)}, + ) self.assertTrue(np.array_equal(x, np.zeros(5))) expected_sizes = { "likelihood.noise_covar.raw_noise": torch.Size([1]), @@ -216,12 +192,17 @@ def test_set_parameters(self): model.mean_module = ConstantMean() model.to(device=self.device, dtype=dtype) mll = ExactMarginalLogLikelihood(likelihood, model) - # get parameters - x, pdict, bounds = module_to_array(module=mll) - # Set parameters - mll = set_params_with_array(mll, np.array([1.0, 2.0, 3.0, 4.0, 5.0]), pdict) - z = dict(mll.named_parameters()) + with catch_warnings(): + # Get parameters + simplefilter("ignore", category=DeprecationWarning) + x, pdict, bounds = module_to_array(module=mll) + + # Set parameters + mll = set_params_with_array( + mll, np.array([1.0, 2.0, 3.0, 4.0, 5.0]), pdict + ) + z = dict(mll.named_parameters()) self.assertTrue( torch.equal( z["likelihood.noise_covar.raw_noise"], @@ -242,5 +223,48 @@ def test_set_parameters(self): ) # Extract again - x2, pdict2, bounds2 = module_to_array(module=mll) + with catch_warnings(): + simplefilter("ignore", category=DeprecationWarning) + x2, pdict2, bounds2 = module_to_array(module=mll) self.assertTrue(np.array_equal(x2, np.array([1.0, 2.0, 3.0, 4.0, 5.0]))) + + +class TestScipyObjectiveAndGrad(BotorchTestCase): + def setUp(self): + with torch.random.fork_rng(): + torch.manual_seed(0) + train_X = torch.linspace(0, 1, 10).unsqueeze(-1) + train_Y = torch.sin((2 * pi) * train_X) + train_Y = train_Y + 0.1 * torch.randn_like(train_Y) + + model = SingleTaskGP(train_X=train_X, train_Y=train_Y) + self.mll = ExactMarginalLogLikelihood(model.likelihood, model) + + def test_scipy_objective_and_grad(self): + with catch_warnings(): + simplefilter("ignore", category=DeprecationWarning) + x, property_dict, bounds = module_to_array(module=self.mll) + loss, grad = _scipy_objective_and_grad(x, self.mll, property_dict) + + _dist = self.mll.model(*self.mll.model.train_inputs) + _loss = -self.mll(_dist, self.mll.model.train_targets) + _loss.sum().backward() + _grad = torch.concat( + [self.mll.get_parameter(name).grad.view(-1) for name in property_dict] + ) + self.assertEqual(loss, _loss.detach().sum().item()) + self.assertTrue(np.allclose(grad, _grad.detach().numpy())) + + def _getter(*args, **kwargs): + raise RuntimeError("foo") + + _handler = MagicMock() + + with catch_warnings(), patch.multiple( + numpy_converter, + _get_extra_mll_args=_getter, + _handle_numerical_errors=_handler, + ): + simplefilter("ignore", category=DeprecationWarning) + _scipy_objective_and_grad(x, self.mll, property_dict) + self.assertEqual(_handler.call_count, 1) diff --git a/test/optim/test_utils.py b/test/optim/test_utils.py deleted file mode 100644 index c4a6d63344..0000000000 --- a/test/optim/test_utils.py +++ /dev/null @@ -1,672 +0,0 @@ -#!/usr/bin/env python3 -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. - -from __future__ import annotations - -import math -import warnings -from copy import deepcopy -from itertools import product -from string import ascii_lowercase -from unittest.mock import MagicMock, patch - -import numpy as np -import torch -from botorch import settings -from botorch.acquisition.fixed_feature import FixedFeatureAcquisitionFunction -from botorch.acquisition.monte_carlo import ( - qExpectedImprovement, - qNoisyExpectedImprovement, -) -from botorch.acquisition.multi_objective.max_value_entropy_search import ( - qMultiObjectiveMaxValueEntropy, -) -from botorch.acquisition.multi_objective.monte_carlo import ( - qExpectedHypervolumeImprovement, - qNoisyExpectedHypervolumeImprovement, -) -from botorch.exceptions import BotorchError -from botorch.exceptions.warnings import BotorchWarning -from botorch.models import ModelListGP, SingleTaskGP -from botorch.models.transforms.input import Warp -from botorch.optim import utils -from botorch.optim.numpy_converter import module_to_array -from botorch.optim.utils import ( - _expand_bounds, - _get_extra_mll_args, - _handle_numerical_errors, - _scipy_objective_and_grad, - allclose_mll, - columnwise_clamp, - del_attribute_ctx, - fix_features, - get_X_baseline, - parameter_rollback_ctx, - requires_grad_ctx, - sample_all_priors, - state_rollback_ctx, -) -from botorch.utils.multi_objective.box_decompositions.non_dominated import ( - FastNondominatedPartitioning, -) -from botorch.utils.testing import BotorchTestCase, MockModel, MockPosterior -from gpytorch.constraints import GreaterThan -from gpytorch.kernels.matern_kernel import MaternKernel -from gpytorch.kernels.scale_kernel import ScaleKernel -from gpytorch.mlls.exact_marginal_log_likelihood import ExactMarginalLogLikelihood -from gpytorch.mlls.marginal_log_likelihood import MarginalLogLikelihood -from gpytorch.mlls.sum_marginal_log_likelihood import SumMarginalLogLikelihood -from gpytorch.priors import UniformPrior -from gpytorch.priors.prior import Prior -from gpytorch.priors.torch_priors import GammaPrior -from linear_operator.utils.errors import NanError, NotPSDError -from torch.nn import Module, Parameter - - -class DummyPrior(Prior): - arg_constraints = {} - - def rsample(self, sample_shape=torch.Size()): # noqa: B008 - raise NotImplementedError - - -class DummyPriorRuntimeError(Prior): - arg_constraints = {} - - def rsample(self, sample_shape=torch.Size()): # noqa: B008 - raise RuntimeError("Another runtime error.") - - -class TestColumnWiseClamp(BotorchTestCase): - def setUp(self): - super().setUp() - self.X = torch.tensor([[-2, 1], [0.5, -0.5]], device=self.device) - self.X_expected = torch.tensor([[-1, 0.5], [0.5, -0.5]], device=self.device) - - def test_column_wise_clamp_scalars(self): - X, X_expected = self.X, self.X_expected - with self.assertRaises(ValueError): - X_clmp = columnwise_clamp(X, 1, -1) - X_clmp = columnwise_clamp(X, -1, 0.5) - self.assertTrue(torch.equal(X_clmp, X_expected)) - X_clmp = columnwise_clamp(X, -3, 3) - self.assertTrue(torch.equal(X_clmp, X)) - - def test_column_wise_clamp_scalar_tensors(self): - X, X_expected = self.X, self.X_expected - with self.assertRaises(ValueError): - X_clmp = columnwise_clamp(X, torch.tensor(1), torch.tensor(-1)) - X_clmp = columnwise_clamp(X, torch.tensor(-1), torch.tensor(0.5)) - self.assertTrue(torch.equal(X_clmp, X_expected)) - X_clmp = columnwise_clamp(X, torch.tensor(-3), torch.tensor(3)) - self.assertTrue(torch.equal(X_clmp, X)) - - def test_column_wise_clamp_tensors(self): - X, X_expected = self.X, self.X_expected - with self.assertRaises(ValueError): - X_clmp = columnwise_clamp(X, torch.ones(2), torch.zeros(2)) - with self.assertRaises(RuntimeError): - X_clmp = columnwise_clamp(X, torch.zeros(3), torch.ones(3)) - X_clmp = columnwise_clamp(X, torch.tensor([-1, -1]), torch.tensor([0.5, 0.5])) - self.assertTrue(torch.equal(X_clmp, X_expected)) - X_clmp = columnwise_clamp(X, torch.tensor([-3, -3]), torch.tensor([3, 3])) - self.assertTrue(torch.equal(X_clmp, X)) - - def test_column_wise_clamp_full_dim_tensors(self): - X = torch.tensor([[[-1, 2, 0.5], [0.5, 3, 1.5]], [[0.5, 1, 0], [2, -2, 3]]]) - lower = torch.tensor([[[0, 0.5, 1], [0, 2, 2]], [[0, 2, 0], [1, -1, 0]]]) - upper = torch.tensor([[[1, 1.5, 1], [1, 4, 3]], [[1, 3, 0.5], [3, 1, 2.5]]]) - X_expected = torch.tensor( - [[[0, 1.5, 1], [0.5, 3, 2]], [[0.5, 2, 0], [2, -1, 2.5]]] - ) - X_clmp = columnwise_clamp(X, lower, upper) - self.assertTrue(torch.equal(X_clmp, X_expected)) - X_clmp = columnwise_clamp(X, lower - 5, upper + 5) - self.assertTrue(torch.equal(X_clmp, X)) - with self.assertRaises(ValueError): - X_clmp = columnwise_clamp(X, torch.ones_like(X), torch.zeros_like(X)) - with self.assertRaises(RuntimeError): - X_clmp = columnwise_clamp(X, lower.unsqueeze(-3), upper.unsqueeze(-3)) - - def test_column_wise_clamp_raise_on_violation(self): - X = self.X - with self.assertRaises(BotorchError): - X_clmp = columnwise_clamp( - X, torch.zeros(2), torch.ones(2), raise_on_violation=True - ) - X_clmp = columnwise_clamp( - X, torch.tensor([-3, -3]), torch.tensor([3, 3]), raise_on_violation=True - ) - self.assertTrue(torch.equal(X_clmp, X)) - - -class TestFixFeatures(BotorchTestCase): - def _getTensors(self): - X = torch.tensor([[-2, 1, 3], [0.5, -0.5, 1.0]], device=self.device) - X_null_two = torch.tensor([[-2, 1, 3], [0.5, -0.5, 1.0]], device=self.device) - X_expected = torch.tensor([[-1, 1, -2], [-1, -0.5, -2]], device=self.device) - X_expected_null_two = torch.tensor( - [[-1, 1, 3], [-1, -0.5, 1.0]], device=self.device - ) - return X, X_null_two, X_expected, X_expected_null_two - - def test_fix_features(self): - X, X_null_two, X_expected, X_expected_null_two = self._getTensors() - X.requires_grad_(True) - X_null_two.requires_grad_(True) - - X_fix = fix_features(X, {0: -1, 2: -2}) - X_fix_null_two = fix_features(X_null_two, {0: -1, 2: None}) - - self.assertTrue(torch.equal(X_fix, X_expected)) - self.assertTrue(torch.equal(X_fix_null_two, X_expected_null_two)) - - def f(X): - return X.sum() - - f(X).backward() - self.assertTrue(torch.equal(X.grad, torch.ones_like(X))) - X.grad.zero_() - - f(X_fix).backward() - self.assertTrue( - torch.equal( - X.grad, - torch.tensor([[0.0, 1.0, 0.0], [0.0, 1.0, 0.0]], device=self.device), - ) - ) - - f(X_null_two).backward() - self.assertTrue(torch.equal(X_null_two.grad, torch.ones_like(X))) - X_null_two.grad.zero_() - f(X_fix_null_two).backward() - self.assertTrue( - torch.equal( - X_null_two.grad, - torch.tensor([[0.0, 1.0, 0.0], [0.0, 1.0, 0.0]], device=self.device), - ) - ) - - -class TestGetExtraMllArgs(BotorchTestCase): - def test_get_extra_mll_args(self): - train_X = torch.rand(3, 5) - train_Y = torch.rand(3, 1) - model = SingleTaskGP(train_X=train_X, train_Y=train_Y) - - # test ExactMarginalLogLikelihood - exact_mll = ExactMarginalLogLikelihood(model.likelihood, model) - exact_extra_args = _get_extra_mll_args(mll=exact_mll) - self.assertEqual(len(exact_extra_args), 1) - self.assertTrue(torch.equal(exact_extra_args[0], train_X)) - - # test SumMarginalLogLikelihood - model2 = ModelListGP(model) - sum_mll = SumMarginalLogLikelihood(model2.likelihood, model2) - sum_mll_extra_args = _get_extra_mll_args(mll=sum_mll) - self.assertEqual(len(sum_mll_extra_args), 1) - self.assertEqual(len(sum_mll_extra_args[0]), 1) - self.assertTrue(torch.equal(sum_mll_extra_args[0][0], train_X)) - - # test unsupported MarginalLogLikelihood type - unsupported_mll = MarginalLogLikelihood(model.likelihood, model) - unsupported_mll_extra_args = _get_extra_mll_args(mll=unsupported_mll) - self.assertEqual(unsupported_mll_extra_args, []) - - -class TestExpandBounds(BotorchTestCase): - def test_expand_bounds(self): - X = torch.zeros(2, 3) - expected_bounds = torch.zeros(2, 3) - # bounds is float - bounds = 0.0 - expanded_bounds = _expand_bounds(bounds=bounds, X=X) - self.assertTrue(torch.equal(expected_bounds, expanded_bounds)) - # bounds is 0-d - bounds = torch.tensor(0.0) - expanded_bounds = _expand_bounds(bounds=bounds, X=X) - self.assertTrue(torch.equal(expected_bounds, expanded_bounds)) - # bounds is 1-d - bounds = torch.zeros(3) - expanded_bounds = _expand_bounds(bounds=bounds, X=X) - self.assertTrue(torch.equal(expected_bounds, expanded_bounds)) - # bounds is 2-d - bounds = torch.zeros(1, 3) - expanded_bounds = _expand_bounds(bounds=bounds, X=X) - self.assertTrue(torch.equal(expected_bounds, expanded_bounds)) - # bounds is > 2-d - bounds = torch.zeros(1, 1, 3) - with self.assertRaises(RuntimeError): - # X does not have a t-batch - expanded_bounds = _expand_bounds(bounds=bounds, X=X) - X = torch.zeros(4, 2, 3) - expanded_bounds = _expand_bounds(bounds=bounds, X=X) - self.assertTrue(torch.equal(expanded_bounds, torch.zeros_like(X))) - with self.assertRaises(RuntimeError): - # bounds is not broadcastable to X - expanded_bounds = _expand_bounds(bounds=torch.zeros(2, 1, 3), X=X) - # bounds is None - expanded_bounds = _expand_bounds(bounds=None, X=X) - self.assertIsNone(expanded_bounds) - - -class TestSampleAllPriors(BotorchTestCase): - def test_sample_all_priors(self): - for dtype in (torch.float, torch.double): - train_X = torch.rand(3, 5, device=self.device, dtype=dtype) - train_Y = torch.rand(3, 1, device=self.device, dtype=dtype) - model = SingleTaskGP(train_X=train_X, train_Y=train_Y) - mll = ExactMarginalLogLikelihood(model.likelihood, model) - mll.to(device=self.device, dtype=dtype) - original_state_dict = dict(deepcopy(mll.model.state_dict())) - sample_all_priors(model) - - # make sure one of the hyperparameters changed - self.assertTrue( - dict(model.state_dict())["likelihood.noise_covar.raw_noise"] - != original_state_dict["likelihood.noise_covar.raw_noise"] - ) - # check that lengthscales are all different - ls = model.covar_module.base_kernel.raw_lengthscale.view(-1).tolist() - self.assertTrue(all(ls[0] != ls[i]) for i in range(1, len(ls))) - - # change one of the priors to a dummy prior that does not support sampling - model.covar_module = ScaleKernel( - MaternKernel( - nu=2.5, - ard_num_dims=model.train_inputs[0].shape[-1], - batch_shape=model._aug_batch_shape, - lengthscale_prior=DummyPrior(), - ), - batch_shape=model._aug_batch_shape, - outputscale_prior=GammaPrior(2.0, 0.15), - ) - original_state_dict = dict(deepcopy(mll.model.state_dict())) - with warnings.catch_warnings(record=True) as ws, settings.debug(True): - sample_all_priors(model) - self.assertEqual(len(ws), 1) - self.assertTrue("rsample" in str(ws[0].message)) - - # change to dummy prior that raises an unrecognized RuntimeError - model.covar_module = ScaleKernel( - MaternKernel( - nu=2.5, - ard_num_dims=model.train_inputs[0].shape[-1], - batch_shape=model._aug_batch_shape, - lengthscale_prior=DummyPriorRuntimeError(), - ), - batch_shape=model._aug_batch_shape, - outputscale_prior=GammaPrior(2.0, 0.15), - ) - with self.assertRaises(RuntimeError): - sample_all_priors(model) - - # the lengthscale should not have changed because sampling is - # not implemented for DummyPrior - self.assertTrue( - torch.equal( - dict(model.state_dict())[ - "covar_module.base_kernel.raw_lengthscale" - ], - original_state_dict["covar_module.base_kernel.raw_lengthscale"], - ) - ) - - # set setting_closure to None and make sure RuntimeError is raised - prior_tuple = model.likelihood.noise_covar._priors["noise_prior"] - model.likelihood.noise_covar._priors["noise_prior"] = ( - prior_tuple[0], - prior_tuple[1], - None, - ) - with self.assertRaises(RuntimeError): - sample_all_priors(model) - - # test for error when sampling violates constraint - model = SingleTaskGP(train_X=train_X, train_Y=train_Y) - mll = ExactMarginalLogLikelihood(model.likelihood, model) - mll.to(device=self.device, dtype=dtype) - model.covar_module = ScaleKernel( - MaternKernel( - nu=2.5, - ard_num_dims=model.train_inputs[0].shape[-1], - batch_shape=model._aug_batch_shape, - lengthscale_prior=GammaPrior(3.0, 6.0), - ), - batch_shape=model._aug_batch_shape, - outputscale_prior=UniformPrior(1.0, 2.0), - outputscale_constraint=GreaterThan(3.0), - ) - original_state_dict = dict(deepcopy(mll.model.state_dict())) - with self.assertRaises(RuntimeError): - sample_all_priors(model) - - -class TestHelpers(BotorchTestCase): - def test_handle_numerical_errors(self): - x = np.zeros(1) - - with self.assertRaisesRegex(NotPSDError, "foo"): - _handle_numerical_errors(error=NotPSDError("foo"), x=x) - - for error in ( - NanError(), - RuntimeError("singular"), - RuntimeError("input is not positive-definite"), - ): - fake_loss, fake_grad = _handle_numerical_errors(error=error, x=x) - self.assertTrue(math.isnan(fake_loss)) - self.assertEqual(fake_grad.shape, x.shape) - self.assertTrue(np.isnan(fake_grad).all()) - - with self.assertRaisesRegex(RuntimeError, "foo"): - _handle_numerical_errors(error=RuntimeError("foo"), x=x) - - -class TestGetXBaseline(BotorchTestCase): - def test_get_X_baseline(self): - tkwargs = {"device": self.device} - for dtype in (torch.float, torch.double): - tkwargs["dtype"] = dtype - X_train = torch.rand(20, 2, **tkwargs) - model = MockModel( - MockPosterior(mean=(2 * X_train + 1).sum(dim=-1, keepdim=True)) - ) - # test NEI with X_baseline - acqf = qNoisyExpectedImprovement( - model, X_baseline=X_train[:2], cache_root=False - ) - X = get_X_baseline(acq_function=acqf) - self.assertTrue(torch.equal(X, acqf.X_baseline)) - # test EI without X_baseline - acqf = qExpectedImprovement(model, best_f=0.0) - - with warnings.catch_warnings(record=True) as w, settings.debug(True): - - X_rnd = get_X_baseline( - acq_function=acqf, - ) - self.assertEqual(len(w), 1) - self.assertTrue(issubclass(w[-1].category, BotorchWarning)) - self.assertIsNone(X_rnd) - - # set train inputs - model.train_inputs = (X_train,) - X = get_X_baseline( - acq_function=acqf, - ) - self.assertTrue(torch.equal(X, X_train)) - # test that we fail back to train_inputs if X_baseline is an empty tensor - acqf.register_buffer("X_baseline", X_train[:0]) - X = get_X_baseline( - acq_function=acqf, - ) - self.assertTrue(torch.equal(X, X_train)) - - # test acquisition function without X_baseline or model - acqf = FixedFeatureAcquisitionFunction(acqf, d=2, columns=[0], values=[0]) - with warnings.catch_warnings(record=True) as w, settings.debug(True): - X_rnd = get_X_baseline( - acq_function=acqf, - ) - self.assertEqual(len(w), 1) - self.assertTrue(issubclass(w[-1].category, BotorchWarning)) - self.assertIsNone(X_rnd) - - Y_train = 2 * X_train[:2] + 1 - moo_model = MockModel(MockPosterior(mean=Y_train, samples=Y_train)) - ref_point = torch.zeros(2, **tkwargs) - # test NEHVI with X_baseline - acqf = qNoisyExpectedHypervolumeImprovement( - moo_model, - ref_point=ref_point, - X_baseline=X_train[:2], - cache_root=False, - ) - X = get_X_baseline( - acq_function=acqf, - ) - self.assertTrue(torch.equal(X, acqf.X_baseline)) - # test qEHVI without train_inputs - acqf = qExpectedHypervolumeImprovement( - moo_model, - ref_point=ref_point, - partitioning=FastNondominatedPartitioning( - ref_point=ref_point, - Y=Y_train, - ), - ) - # test extracting train_inputs from model list GP - model_list = ModelListGP( - SingleTaskGP(X_train, Y_train[:, :1]), - SingleTaskGP(X_train, Y_train[:, 1:]), - ) - acqf = qExpectedHypervolumeImprovement( - model_list, - ref_point=ref_point, - partitioning=FastNondominatedPartitioning( - ref_point=ref_point, - Y=Y_train, - ), - ) - X = get_X_baseline( - acq_function=acqf, - ) - self.assertTrue(torch.equal(X, X_train)) - - # test MESMO for which we need to use - # `acqf.mo_model` - batched_mo_model = SingleTaskGP(X_train, Y_train) - acqf = qMultiObjectiveMaxValueEntropy( - batched_mo_model, - sample_pareto_frontiers=lambda model: torch.rand(10, 2, **tkwargs), - ) - X = get_X_baseline( - acq_function=acqf, - ) - self.assertTrue(torch.equal(X, X_train)) - # test that if there is an input transform that is applied - # to the train_inputs when the model is in eval mode, we - # extract the untransformed train_inputs - model = SingleTaskGP( - X_train, Y_train[:, :1], input_transform=Warp(indices=[0, 1]) - ) - model.eval() - self.assertFalse(torch.equal(model.train_inputs[0], X_train)) - acqf = qExpectedImprovement(model, best_f=0.0) - X = get_X_baseline( - acq_function=acqf, - ) - self.assertTrue(torch.equal(X, X_train)) - - -class TestAllcloseMLL(BotorchTestCase): - def setUp(self): - with torch.random.fork_rng(): - torch.manual_seed(0) - train_X = torch.linspace(0, 1, 10).unsqueeze(-1) - train_Y = torch.sin((2 * math.pi) * train_X) - train_Y = train_Y + 0.1 * torch.randn_like(train_Y) - - self.mlls = [] - for nu in (1.5, 2.5): - model = SingleTaskGP(train_X=train_X, train_Y=train_Y) - model.covar_module.base_kernel.nu = nu - self.mlls.append(ExactMarginalLogLikelihood(model.likelihood, model)) - - def test_allclose_mll(self): - self.assertTrue(allclose_mll(a=self.mlls[0], b=self.mlls[0])) - for transform_a, transform_b in product( - *(2 * [(None, lambda vals: torch.zeros_like(vals))]) - ): - out = allclose_mll( - a=self.mlls[0], - b=self.mlls[1], - transform_a=transform_a, - transform_b=transform_b, - ) - self.assertEqual(out, transform_a is not None and transform_b is not None) - - -class TestContextManagers(BotorchTestCase): - def setUp(self): - module = self.module = Module() - for i, name in enumerate(ascii_lowercase[:3], start=1): - values = torch.rand(2).to(torch.float16) - param = Parameter(values.to(torch.float64), requires_grad=bool(i % 2)) - module.register_parameter(name, param) - - def test_del_attribute_ctx(self): - # Test temporary removal of attributes - a = self.module.a - b = self.module.b - with del_attribute_ctx(self.module, "a", "b"): - self.assertIsNone(getattr(self.module, "a", None)) - self.assertIsNone(getattr(self.module, "b", None)) - self.assertTrue(self.module.c is not None) - - # Test that removed attributes get restored - self.assertTrue(self.module.a.equal(a)) - self.assertTrue(self.module.b.equal(b)) - - with self.assertRaisesRegex(ValueError, "Attribute .* missing"): - with del_attribute_ctx(self.module, "z", enforce_hasattr=True): - pass # pragma: no cover - - def test_requires_grad_ctx(self): - # Test temporary setting of requires_grad field - with requires_grad_ctx(self.module, assignments={"a": False, "b": True}): - self.assertTrue(not self.module.a.requires_grad) - self.assertTrue(self.module.b.requires_grad) - self.assertTrue(self.module.c.requires_grad) - - # Test that requires_grad fields get restored - self.assertTrue(self.module.a.requires_grad) - self.assertTrue(not self.module.b.requires_grad) - self.assertTrue(self.module.c.requires_grad) - - def test_parameter_rollback_ctx(self): - # Test that only unfiltered parameters get rolled back - a = self.module.a.detach().clone() - b = self.module.b.detach().clone() - c = self.module.c.detach().clone() - with parameter_rollback_ctx( - module=self.module, - name_filter=lambda name: name in ("a", "b"), - requires_grad=True, - dtype=torch.float16, - ) as ckpt: - for (tnsr, _) in ckpt.values(): # test whether dtype is obeyed - self.assertEqual(torch.float16, tnsr.dtype) - - self.module.a.data[...] = 0 - self.module.b.data[...] = 0 - self.module.c.data[...] = 0 - - self.assertTrue(self.module.a.equal(a)) - self.assertTrue(self.module.b.eq(0).all()) - self.assertTrue(self.module.c.eq(0).all()) - - # Test that changes to checkpoint dict are reflected in rollback state - with parameter_rollback_ctx(self.module) as ckpt: - self.module.a.data[...] = 1 - self.module.b.data[...] = 1 - self.module.c.data[...] = 1 - del ckpt["a"] - - self.assertTrue(self.module.a.eq(1).all()) - self.assertTrue(self.module.b.eq(0).all()) - self.assertTrue(self.module.c.eq(0).all()) - - # Test rolling back to a user-provided checkpoint - checkpoint = {"a": (a, {}), "b": (b, {}), "c": (c, {})} - with parameter_rollback_ctx(module=self.module, checkpoint=checkpoint): - pass - - self.assertTrue(self.module.a.equal(a)) - self.assertTrue(self.module.b.equal(b)) - self.assertTrue(self.module.c.equal(c)) - - def test_state_rollback_ctx(self): - # Test that only unfiltered objects get rolled back - a = self.module.a.detach().clone() - b = self.module.b.detach().clone() - c = self.module.c.detach().clone() - with state_rollback_ctx( - self.module, lambda name: name == "a", dtype=torch.float16 - ) as ckpt: - for (tnsr, _) in ckpt.values(): # test whether dtype is obeyed - self.assertEqual(torch.float16, tnsr.dtype) - - self.module.a.data[...] = 0 - self.module.b.data[...] = 0 - self.module.c.data[...] = 0 - - self.assertTrue(self.module.a.equal(a)) - self.assertTrue(self.module.b.eq(0).all()) - self.assertTrue(self.module.c.eq(0).all()) - - # Test that changes to checkpoint dict are reflected in rollback state - with state_rollback_ctx(self.module) as ckpt: - self.module.a.data[...] = 1 - self.module.b.data[...] = 1 - self.module.c.data[...] = 1 - del ckpt["a"] - - self.assertTrue(self.module.a.eq(1).all()) - self.assertTrue(self.module.b.eq(0).all()) - self.assertTrue(self.module.c.eq(0).all()) - - # Test rolling back to a user-provided checkpoint - checkpoint = {"a": (a, {}), "b": (b, {}), "c": (c, {})} - with state_rollback_ctx(module=self.module, checkpoint=checkpoint): - pass - self.assertTrue(self.module.a.equal(a)) - self.assertTrue(self.module.b.equal(b)) - self.assertTrue(self.module.c.equal(c)) - - # Test that items in checkpoint get inserted into state_dict - with del_attribute_ctx(self.module, "a"): - with self.assertRaisesRegex( # should fail when attempting to rollback - RuntimeError, r'Unexpected key\(s\) in state_dict: "a"' - ): - with state_rollback_ctx(module=self.module, checkpoint=checkpoint): - pass - - -class TestScipyObjectiveAndGrad(BotorchTestCase): - def setUp(self): - with torch.random.fork_rng(): - torch.manual_seed(0) - train_X = torch.linspace(0, 1, 10).unsqueeze(-1) - train_Y = torch.sin((2 * math.pi) * train_X) - train_Y = train_Y + 0.1 * torch.randn_like(train_Y) - - model = SingleTaskGP(train_X=train_X, train_Y=train_Y) - self.mll = ExactMarginalLogLikelihood(model.likelihood, model) - - def test_scipy_objective_and_grad(self): - x, property_dict, bounds = module_to_array(module=self.mll) - loss, grad = _scipy_objective_and_grad(x, self.mll, property_dict) - - _dist = self.mll.model(*self.mll.model.train_inputs) - _loss = -self.mll(_dist, self.mll.model.train_targets) - _loss.sum().backward() - _grad = torch.concat( - [self.mll.get_parameter(name).grad.view(-1) for name in property_dict] - ) - self.assertEqual(loss, _loss.detach().sum().item()) - self.assertTrue(np.allclose(grad, _grad.detach().numpy())) - - def _getter(*args, **kwargs): - raise RuntimeError("foo") - - _handler = MagicMock() - with patch.multiple( - utils, _get_extra_mll_args=_getter, _handle_numerical_errors=_handler - ): - _scipy_objective_and_grad(x, self.mll, property_dict) - self.assertEqual(_handler.call_count, 1) diff --git a/test/optim/utils/__init__.py b/test/optim/utils/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/test/optim/utils/test_acquisition_utils.py b/test/optim/utils/test_acquisition_utils.py new file mode 100644 index 0000000000..f11ca79435 --- /dev/null +++ b/test/optim/utils/test_acquisition_utils.py @@ -0,0 +1,261 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +import warnings + +import torch +from botorch import settings +from botorch.acquisition.fixed_feature import FixedFeatureAcquisitionFunction +from botorch.acquisition.monte_carlo import ( + qExpectedImprovement, + qNoisyExpectedImprovement, +) +from botorch.acquisition.multi_objective.max_value_entropy_search import ( + qMultiObjectiveMaxValueEntropy, +) +from botorch.acquisition.multi_objective.monte_carlo import ( + qExpectedHypervolumeImprovement, + qNoisyExpectedHypervolumeImprovement, +) +from botorch.exceptions import BotorchError +from botorch.exceptions.warnings import BotorchWarning +from botorch.models import ModelListGP, SingleTaskGP +from botorch.models.transforms.input import Warp +from botorch.optim.utils import columnwise_clamp, fix_features, get_X_baseline +from botorch.utils.multi_objective.box_decompositions.non_dominated import ( + FastNondominatedPartitioning, +) +from botorch.utils.testing import BotorchTestCase, MockModel, MockPosterior + + +class TestColumnWiseClamp(BotorchTestCase): + def setUp(self): + super().setUp() + self.X = torch.tensor([[-2, 1], [0.5, -0.5]], device=self.device) + self.X_expected = torch.tensor([[-1, 0.5], [0.5, -0.5]], device=self.device) + + def test_column_wise_clamp_scalars(self): + X, X_expected = self.X, self.X_expected + with self.assertRaises(ValueError): + X_clmp = columnwise_clamp(X, 1, -1) + X_clmp = columnwise_clamp(X, -1, 0.5) + self.assertTrue(torch.equal(X_clmp, X_expected)) + X_clmp = columnwise_clamp(X, -3, 3) + self.assertTrue(torch.equal(X_clmp, X)) + + def test_column_wise_clamp_scalar_tensors(self): + X, X_expected = self.X, self.X_expected + with self.assertRaises(ValueError): + X_clmp = columnwise_clamp(X, torch.tensor(1), torch.tensor(-1)) + X_clmp = columnwise_clamp(X, torch.tensor(-1), torch.tensor(0.5)) + self.assertTrue(torch.equal(X_clmp, X_expected)) + X_clmp = columnwise_clamp(X, torch.tensor(-3), torch.tensor(3)) + self.assertTrue(torch.equal(X_clmp, X)) + + def test_column_wise_clamp_tensors(self): + X, X_expected = self.X, self.X_expected + with self.assertRaises(ValueError): + X_clmp = columnwise_clamp(X, torch.ones(2), torch.zeros(2)) + with self.assertRaises(RuntimeError): + X_clmp = columnwise_clamp(X, torch.zeros(3), torch.ones(3)) + X_clmp = columnwise_clamp(X, torch.tensor([-1, -1]), torch.tensor([0.5, 0.5])) + self.assertTrue(torch.equal(X_clmp, X_expected)) + X_clmp = columnwise_clamp(X, torch.tensor([-3, -3]), torch.tensor([3, 3])) + self.assertTrue(torch.equal(X_clmp, X)) + + def test_column_wise_clamp_full_dim_tensors(self): + X = torch.tensor([[[-1, 2, 0.5], [0.5, 3, 1.5]], [[0.5, 1, 0], [2, -2, 3]]]) + lower = torch.tensor([[[0, 0.5, 1], [0, 2, 2]], [[0, 2, 0], [1, -1, 0]]]) + upper = torch.tensor([[[1, 1.5, 1], [1, 4, 3]], [[1, 3, 0.5], [3, 1, 2.5]]]) + X_expected = torch.tensor( + [[[0, 1.5, 1], [0.5, 3, 2]], [[0.5, 2, 0], [2, -1, 2.5]]] + ) + X_clmp = columnwise_clamp(X, lower, upper) + self.assertTrue(torch.equal(X_clmp, X_expected)) + X_clmp = columnwise_clamp(X, lower - 5, upper + 5) + self.assertTrue(torch.equal(X_clmp, X)) + with self.assertRaises(ValueError): + X_clmp = columnwise_clamp(X, torch.ones_like(X), torch.zeros_like(X)) + with self.assertRaises(RuntimeError): + X_clmp = columnwise_clamp(X, lower.unsqueeze(-3), upper.unsqueeze(-3)) + + def test_column_wise_clamp_raise_on_violation(self): + X = self.X + with self.assertRaises(BotorchError): + X_clmp = columnwise_clamp( + X, torch.zeros(2), torch.ones(2), raise_on_violation=True + ) + X_clmp = columnwise_clamp( + X, torch.tensor([-3, -3]), torch.tensor([3, 3]), raise_on_violation=True + ) + self.assertTrue(torch.equal(X_clmp, X)) + + +class TestFixFeatures(BotorchTestCase): + def _getTensors(self): + X = torch.tensor([[-2, 1, 3], [0.5, -0.5, 1.0]], device=self.device) + X_null_two = torch.tensor([[-2, 1, 3], [0.5, -0.5, 1.0]], device=self.device) + X_expected = torch.tensor([[-1, 1, -2], [-1, -0.5, -2]], device=self.device) + X_expected_null_two = torch.tensor( + [[-1, 1, 3], [-1, -0.5, 1.0]], device=self.device + ) + return X, X_null_two, X_expected, X_expected_null_two + + def test_fix_features(self): + X, X_null_two, X_expected, X_expected_null_two = self._getTensors() + X.requires_grad_(True) + X_null_two.requires_grad_(True) + + X_fix = fix_features(X, {0: -1, 2: -2}) + X_fix_null_two = fix_features(X_null_two, {0: -1, 2: None}) + + self.assertTrue(torch.equal(X_fix, X_expected)) + self.assertTrue(torch.equal(X_fix_null_two, X_expected_null_two)) + + def f(X): + return X.sum() + + f(X).backward() + self.assertTrue(torch.equal(X.grad, torch.ones_like(X))) + X.grad.zero_() + + f(X_fix).backward() + self.assertTrue( + torch.equal( + X.grad, + torch.tensor([[0.0, 1.0, 0.0], [0.0, 1.0, 0.0]], device=self.device), + ) + ) + + f(X_null_two).backward() + self.assertTrue(torch.equal(X_null_two.grad, torch.ones_like(X))) + X_null_two.grad.zero_() + f(X_fix_null_two).backward() + self.assertTrue( + torch.equal( + X_null_two.grad, + torch.tensor([[0.0, 1.0, 0.0], [0.0, 1.0, 0.0]], device=self.device), + ) + ) + + +class TestGetXBaseline(BotorchTestCase): + def test_get_X_baseline(self): + tkwargs = {"device": self.device} + for dtype in (torch.float, torch.double): + tkwargs["dtype"] = dtype + X_train = torch.rand(20, 2, **tkwargs) + model = MockModel( + MockPosterior(mean=(2 * X_train + 1).sum(dim=-1, keepdim=True)) + ) + # test NEI with X_baseline + acqf = qNoisyExpectedImprovement( + model, X_baseline=X_train[:2], cache_root=False + ) + X = get_X_baseline(acq_function=acqf) + self.assertTrue(torch.equal(X, acqf.X_baseline)) + # test EI without X_baseline + acqf = qExpectedImprovement(model, best_f=0.0) + + with warnings.catch_warnings(record=True) as w, settings.debug(True): + + X_rnd = get_X_baseline( + acq_function=acqf, + ) + self.assertEqual(len(w), 1) + self.assertTrue(issubclass(w[-1].category, BotorchWarning)) + self.assertIsNone(X_rnd) + + # set train inputs + model.train_inputs = (X_train,) + X = get_X_baseline( + acq_function=acqf, + ) + self.assertTrue(torch.equal(X, X_train)) + # test that we fail back to train_inputs if X_baseline is an empty tensor + acqf.register_buffer("X_baseline", X_train[:0]) + X = get_X_baseline( + acq_function=acqf, + ) + self.assertTrue(torch.equal(X, X_train)) + + # test acquisition function without X_baseline or model + acqf = FixedFeatureAcquisitionFunction(acqf, d=2, columns=[0], values=[0]) + with warnings.catch_warnings(record=True) as w, settings.debug(True): + X_rnd = get_X_baseline( + acq_function=acqf, + ) + self.assertEqual(len(w), 1) + self.assertTrue(issubclass(w[-1].category, BotorchWarning)) + self.assertIsNone(X_rnd) + + Y_train = 2 * X_train[:2] + 1 + moo_model = MockModel(MockPosterior(mean=Y_train, samples=Y_train)) + ref_point = torch.zeros(2, **tkwargs) + # test NEHVI with X_baseline + acqf = qNoisyExpectedHypervolumeImprovement( + moo_model, + ref_point=ref_point, + X_baseline=X_train[:2], + cache_root=False, + ) + X = get_X_baseline( + acq_function=acqf, + ) + self.assertTrue(torch.equal(X, acqf.X_baseline)) + # test qEHVI without train_inputs + acqf = qExpectedHypervolumeImprovement( + moo_model, + ref_point=ref_point, + partitioning=FastNondominatedPartitioning( + ref_point=ref_point, + Y=Y_train, + ), + ) + # test extracting train_inputs from model list GP + model_list = ModelListGP( + SingleTaskGP(X_train, Y_train[:, :1]), + SingleTaskGP(X_train, Y_train[:, 1:]), + ) + acqf = qExpectedHypervolumeImprovement( + model_list, + ref_point=ref_point, + partitioning=FastNondominatedPartitioning( + ref_point=ref_point, + Y=Y_train, + ), + ) + X = get_X_baseline( + acq_function=acqf, + ) + self.assertTrue(torch.equal(X, X_train)) + + # test MESMO for which we need to use + # `acqf.mo_model` + batched_mo_model = SingleTaskGP(X_train, Y_train) + acqf = qMultiObjectiveMaxValueEntropy( + batched_mo_model, + sample_pareto_frontiers=lambda model: torch.rand(10, 2, **tkwargs), + ) + X = get_X_baseline( + acq_function=acqf, + ) + self.assertTrue(torch.equal(X, X_train)) + # test that if there is an input transform that is applied + # to the train_inputs when the model is in eval mode, we + # extract the untransformed train_inputs + model = SingleTaskGP( + X_train, Y_train[:, :1], input_transform=Warp(indices=[0, 1]) + ) + model.eval() + self.assertFalse(torch.equal(model.train_inputs[0], X_train)) + acqf = qExpectedImprovement(model, best_f=0.0) + X = get_X_baseline( + acq_function=acqf, + ) + self.assertTrue(torch.equal(X, X_train)) diff --git a/test/optim/utils/test_common.py b/test/optim/utils/test_common.py new file mode 100644 index 0000000000..713a7acf15 --- /dev/null +++ b/test/optim/utils/test_common.py @@ -0,0 +1,54 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +from functools import partial +from warnings import catch_warnings, warn + +import numpy as np +from botorch.optim.utils import _handle_numerical_errors, _warning_handler_template +from botorch.utils.testing import BotorchTestCase +from linear_operator.utils.errors import NanError, NotPSDError + + +class TestUtilsCommon(BotorchTestCase): + def test_handle_numerical_errors(self): + x = np.zeros(1) + + with self.assertRaisesRegex(NotPSDError, "foo"): + _handle_numerical_errors(error=NotPSDError("foo"), x=x) + + for error in ( + NanError(), + RuntimeError("singular"), + RuntimeError("input is not positive-definite"), + ): + fake_loss, fake_grad = _handle_numerical_errors(error=error, x=x) + self.assertTrue(np.isnan(fake_loss)) + self.assertEqual(fake_grad.shape, x.shape) + self.assertTrue(np.isnan(fake_grad).all()) + + with self.assertRaisesRegex(RuntimeError, "foo"): + _handle_numerical_errors(error=RuntimeError("foo"), x=x) + + def test_warning_handler_template(self): + with catch_warnings(record=True) as ws: + warn(DeprecationWarning("foo")) + warn(RuntimeWarning("bar")) + + self.assertFalse(any(_warning_handler_template(w) for w in ws)) + handler = partial( + _warning_handler_template, + debug=lambda w: issubclass(w.category, DeprecationWarning), + rethrow=lambda w: True, + ) + with self.assertLogs(level="DEBUG") as logs, catch_warnings(record=True) as _ws: + self.assertTrue(all(handler(w) for w in ws)) + self.assertEqual(1, len(logs.output)) + self.assertTrue("foo" in logs.output[0]) + self.assertEqual(1, len(_ws)) + self.assertEqual("bar", str(_ws[0].message)) diff --git a/test/optim/utils/test_model_utils.py b/test/optim/utils/test_model_utils.py new file mode 100644 index 0000000000..87d88ebadf --- /dev/null +++ b/test/optim/utils/test_model_utils.py @@ -0,0 +1,283 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +import math +import re +import warnings +from copy import deepcopy +from itertools import product +from string import ascii_lowercase +from unittest.mock import MagicMock, patch + +import torch +from botorch import settings +from botorch.models import ModelListGP, SingleTaskGP +from botorch.optim.utils import ( + _get_extra_mll_args, + allclose_mll, + get_data_loader, + get_name_filter, + get_parameters, + get_parameters_and_bounds, + model_utils, + sample_all_priors, +) +from botorch.utils.testing import BotorchTestCase +from gpytorch.constraints import GreaterThan +from gpytorch.kernels.matern_kernel import MaternKernel +from gpytorch.kernels.scale_kernel import ScaleKernel +from gpytorch.likelihoods import GaussianLikelihood +from gpytorch.mlls.exact_marginal_log_likelihood import ExactMarginalLogLikelihood +from gpytorch.mlls.marginal_log_likelihood import MarginalLogLikelihood +from gpytorch.mlls.sum_marginal_log_likelihood import SumMarginalLogLikelihood +from gpytorch.priors import UniformPrior +from gpytorch.priors.prior import Prior +from gpytorch.priors.torch_priors import GammaPrior + + +class DummyPrior(Prior): + arg_constraints = {} + + def rsample(self, sample_shape=torch.Size()): # noqa: B008 + raise NotImplementedError + + +class DummyPriorRuntimeError(Prior): + arg_constraints = {} + + def rsample(self, sample_shape=torch.Size()): # noqa: B008 + raise RuntimeError("Another runtime error.") + + +class TestGetExtraMllArgs(BotorchTestCase): + def test_get_extra_mll_args(self): + train_X = torch.rand(3, 5) + train_Y = torch.rand(3, 1) + model = SingleTaskGP(train_X=train_X, train_Y=train_Y) + + # test ExactMarginalLogLikelihood + exact_mll = ExactMarginalLogLikelihood(model.likelihood, model) + with warnings.catch_warnings(): + warnings.simplefilter("ignore", category=DeprecationWarning) + exact_extra_args = _get_extra_mll_args(mll=exact_mll) + self.assertEqual(len(exact_extra_args), 1) + self.assertTrue(torch.equal(exact_extra_args[0], train_X)) + + # test SumMarginalLogLikelihood + model2 = ModelListGP(model) + sum_mll = SumMarginalLogLikelihood(model2.likelihood, model2) + with warnings.catch_warnings(): + warnings.simplefilter("ignore", category=DeprecationWarning) + sum_mll_extra_args = _get_extra_mll_args(mll=sum_mll) + self.assertEqual(len(sum_mll_extra_args), 1) + self.assertEqual(len(sum_mll_extra_args[0]), 1) + self.assertTrue(torch.equal(sum_mll_extra_args[0][0], train_X)) + + # test unsupported MarginalLogLikelihood type + unsupported_mll = MarginalLogLikelihood(model.likelihood, model) + with warnings.catch_warnings(): + warnings.simplefilter("ignore", category=DeprecationWarning) + unsupported_mll_extra_args = _get_extra_mll_args(mll=unsupported_mll) + self.assertEqual(unsupported_mll_extra_args, []) + + +class TestGetDataLoader(BotorchTestCase): + def setUp(self): + super().setUp() + with torch.random.fork_rng(): + torch.random.manual_seed(0) + train_X = torch.rand(3, 5, device=self.device) + train_Y = torch.rand(3, 1, device=self.device) + + self.model = SingleTaskGP(train_X=train_X, train_Y=train_Y).to(torch.float64) + + def test_get_data_loader(self): + data_loader = get_data_loader(self.model) + self.assertEqual(data_loader.batch_size, len(self.model.train_targets)) + + train_X, train_Y = next(iter(data_loader)) + self.assertTrue(self.model.train_inputs[0].equal(train_X)) + self.assertTrue(self.model.train_targets.equal(train_Y)) + + _TensorDataset = MagicMock(return_value="foo") + _DataLoader = MagicMock() + with patch.multiple( + model_utils, TensorDataset=_TensorDataset, DataLoader=_DataLoader + ): + model_utils.get_data_loader(self.model, batch_size=2, shuffle=True) + _DataLoader.assert_called_once_with( + dataset="foo", + batch_size=2, + shuffle=True, + ) + + +class TestGetParameters(BotorchTestCase): + def setUp(self): + self.module = GaussianLikelihood( + noise_constraint=GreaterThan(1e-6, initial_value=0.123), + ) + + def test_get_parameters(self): + self.assertEqual(0, len(get_parameters(self.module, requires_grad=False))) + + params = get_parameters(self.module) + self.assertTrue(1 == len(params)) + self.assertEqual(next(iter(params)), "noise_covar.raw_noise") + self.assertTrue( + self.module.noise_covar.raw_noise.equal(next(iter(params.values()))) + ) + + def test_get_parameters_and_bounds(self): + param_dict, bounds_dict = get_parameters_and_bounds(self.module) + self.assertTrue(1 == len(param_dict) == len(bounds_dict)) + + name, bounds = next(iter(bounds_dict.items())) + self.assertEqual(name, "noise_covar.raw_noise") + self.assertEqual(bounds, (-float("inf"), float("inf"))) + + mock_module = torch.nn.Module() + mock_module.named_parameters = MagicMock( + return_value=self.module.named_parameters() + ) + param_dict2, bounds_dict2 = get_parameters_and_bounds(mock_module) + self.assertEqual(param_dict, param_dict2) + self.assertTrue(len(bounds_dict2) == 0) + + +class TestGetNameFilter(BotorchTestCase): + def test_get_name_filter(self): + with self.assertRaisesRegex(TypeError, "Expected `patterns` to contain"): + get_name_filter(("foo", re.compile("bar"), 1)) + + names = ascii_lowercase + name_filter = get_name_filter(iter(names[1::2])) + self.assertEqual(names[::2], "".join(filter(name_filter, names))) + + items = tuple(zip(names, range(len(names)))) + self.assertEqual(items[::2], tuple(filter(name_filter, items))) + + +class TestSampleAllPriors(BotorchTestCase): + def test_sample_all_priors(self): + for dtype in (torch.float, torch.double): + train_X = torch.rand(3, 5, device=self.device, dtype=dtype) + train_Y = torch.rand(3, 1, device=self.device, dtype=dtype) + model = SingleTaskGP(train_X=train_X, train_Y=train_Y) + mll = ExactMarginalLogLikelihood(model.likelihood, model) + mll.to(device=self.device, dtype=dtype) + original_state_dict = dict(deepcopy(mll.model.state_dict())) + sample_all_priors(model) + + # make sure one of the hyperparameters changed + self.assertTrue( + dict(model.state_dict())["likelihood.noise_covar.raw_noise"] + != original_state_dict["likelihood.noise_covar.raw_noise"] + ) + # check that lengthscales are all different + ls = model.covar_module.base_kernel.raw_lengthscale.view(-1).tolist() + self.assertTrue(all(ls[0] != ls[i]) for i in range(1, len(ls))) + + # change one of the priors to a dummy prior that does not support sampling + model.covar_module = ScaleKernel( + MaternKernel( + nu=2.5, + ard_num_dims=model.train_inputs[0].shape[-1], + batch_shape=model._aug_batch_shape, + lengthscale_prior=DummyPrior(), + ), + batch_shape=model._aug_batch_shape, + outputscale_prior=GammaPrior(2.0, 0.15), + ) + original_state_dict = dict(deepcopy(mll.model.state_dict())) + with warnings.catch_warnings(record=True) as ws, settings.debug(True): + sample_all_priors(model) + self.assertEqual(len(ws), 1) + self.assertTrue("rsample" in str(ws[0].message)) + + # change to dummy prior that raises an unrecognized RuntimeError + model.covar_module = ScaleKernel( + MaternKernel( + nu=2.5, + ard_num_dims=model.train_inputs[0].shape[-1], + batch_shape=model._aug_batch_shape, + lengthscale_prior=DummyPriorRuntimeError(), + ), + batch_shape=model._aug_batch_shape, + outputscale_prior=GammaPrior(2.0, 0.15), + ) + with self.assertRaises(RuntimeError): + sample_all_priors(model) + + # the lengthscale should not have changed because sampling is + # not implemented for DummyPrior + self.assertTrue( + torch.equal( + dict(model.state_dict())[ + "covar_module.base_kernel.raw_lengthscale" + ], + original_state_dict["covar_module.base_kernel.raw_lengthscale"], + ) + ) + + # set setting_closure to None and make sure RuntimeError is raised + prior_tuple = model.likelihood.noise_covar._priors["noise_prior"] + model.likelihood.noise_covar._priors["noise_prior"] = ( + prior_tuple[0], + prior_tuple[1], + None, + ) + with self.assertRaises(RuntimeError): + sample_all_priors(model) + + # test for error when sampling violates constraint + model = SingleTaskGP(train_X=train_X, train_Y=train_Y) + mll = ExactMarginalLogLikelihood(model.likelihood, model) + mll.to(device=self.device, dtype=dtype) + model.covar_module = ScaleKernel( + MaternKernel( + nu=2.5, + ard_num_dims=model.train_inputs[0].shape[-1], + batch_shape=model._aug_batch_shape, + lengthscale_prior=GammaPrior(3.0, 6.0), + ), + batch_shape=model._aug_batch_shape, + outputscale_prior=UniformPrior(1.0, 2.0), + outputscale_constraint=GreaterThan(3.0), + ) + original_state_dict = dict(deepcopy(mll.model.state_dict())) + with self.assertRaises(RuntimeError): + sample_all_priors(model) + + +class TestAllcloseMLL(BotorchTestCase): + def setUp(self): + with torch.random.fork_rng(): + torch.manual_seed(0) + train_X = torch.linspace(0, 1, 10).unsqueeze(-1) + train_Y = torch.sin((2 * math.pi) * train_X) + train_Y = train_Y + 0.1 * torch.randn_like(train_Y) + + self.mlls = [] + for nu in (1.5, 2.5): + model = SingleTaskGP(train_X=train_X, train_Y=train_Y) + model.covar_module.base_kernel.nu = nu + self.mlls.append(ExactMarginalLogLikelihood(model.likelihood, model)) + + def test_allclose_mll(self): + self.assertTrue(allclose_mll(a=self.mlls[0], b=self.mlls[0])) + for transform_a, transform_b in product( + *(2 * [(None, lambda vals: torch.zeros_like(vals))]) + ): + out = allclose_mll( + a=self.mlls[0], + b=self.mlls[1], + transform_a=transform_a, + transform_b=transform_b, + ) + self.assertEqual(out, transform_a is not None and transform_b is not None) diff --git a/test/optim/utils/test_numpy_utils.py b/test/optim/utils/test_numpy_utils.py new file mode 100644 index 0000000000..cde7a9a748 --- /dev/null +++ b/test/optim/utils/test_numpy_utils.py @@ -0,0 +1,111 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +from unittest.mock import MagicMock + +import numpy as np +import torch +from botorch.optim.closures.core import ( + as_ndarray, + get_tensors_as_ndarray_1d, + set_tensors_from_ndarray_1d, +) +from botorch.optim.utils import get_bounds_as_ndarray +from botorch.utils.testing import BotorchTestCase +from torch.nn import Parameter + + +class TestNumpyUtils(BotorchTestCase): + def setUp(self): + super().setUp() + self.parameters = {"foo": torch.rand(2), "bar": Parameter(torch.rand(3))} + + def test_as_ndarray(self): + base = np.random.randn(3) + tnsr = torch.from_numpy(base) + + # Test inplace conversion + result = as_ndarray(tnsr) + self.assertTrue(np.shares_memory(base, result)) + + # Test conversion with memory allocation + result = as_ndarray(tnsr, inplace=False) + self.assertTrue(np.allclose(base, result)) + self.assertFalse(np.shares_memory(base, result)) + + result = as_ndarray(tnsr, dtype=np.float32) + self.assertTrue(np.allclose(base, result)) + self.assertFalse(np.shares_memory(base, result)) + self.assertEqual(result.dtype, np.float32) + + # Test that `clone` does not get called on non-CPU tensors + mock_tensor = MagicMock() + mock_tensor.cpu.return_value = mock_tensor + mock_tensor.device.return_value = "foo" + mock_tensor.clone.return_value = mock_tensor + + as_ndarray(mock_tensor) + mock_tensor.cpu.assert_called_once() + mock_tensor.clone.assert_not_called() + mock_tensor.numpy.assert_called_once() + + def test_get_tensors_as_ndarray_1d(self): + with self.assertRaisesRegex(RuntimeError, "Argument `tensors` .* is empty"): + get_tensors_as_ndarray_1d(()) + + values = get_tensors_as_ndarray_1d(self.parameters) + self.assertTrue( + np.allclose(values, get_tensors_as_ndarray_1d(self.parameters.values())) + ) + n = 0 + for param in self.parameters.values(): + k = param.numel() + self.assertTrue( + np.allclose(values[n : n + k], param.view(-1).detach().cpu().numpy()) + ) + n += k + + with self.assertRaisesRegex(ValueError, "Expected a vector for `out`"): + get_tensors_as_ndarray_1d(self.parameters, out=np.empty((1, 1))) + + with self.assertRaisesRegex(ValueError, "Size of `parameters` .* not match"): + get_tensors_as_ndarray_1d(self.parameters, out=np.empty(values.size - 1)) + + with self.assertRaisesRegex(RuntimeError, "failed while copying values .* foo"): + get_tensors_as_ndarray_1d( + self.parameters, + out=np.empty(values.size), + as_array=MagicMock(side_effect=RuntimeError("foo")), + ) + + def test_set_tensors_from_ndarray_1d(self): + values = get_tensors_as_ndarray_1d(self.parameters) + others = np.random.rand(*values.shape).astype(values.dtype) + with self.assertRaisesRegex(RuntimeError, "failed while copying values to"): + set_tensors_from_ndarray_1d(self.parameters, np.empty([1])) + + set_tensors_from_ndarray_1d(self.parameters, others) + n = 0 + for param in self.parameters.values(): + k = param.numel() + self.assertTrue( + np.allclose(others[n : n + k], param.view(-1).detach().cpu().numpy()) + ) + n += k + + def test_get_bounds_as_ndarray(self): + params = {"a": torch.rand(1), "b": torch.rand(1), "c": torch.rand(2)} + bounds = {"a": (None, 1), "c": (0, None)} + + test = np.full((4, 2), (-float("inf"), float("inf"))) + test[0, 1] = 1 + test[2, 0] = 0 + test[3, 0] = 0 + + array = get_bounds_as_ndarray(parameters=params, bounds=bounds) + self.assertTrue(np.array_equal(test, array)) diff --git a/test/test_fit.py b/test/test_fit.py index 04d8f630bc..09ffbd1044 100644 --- a/test/test_fit.py +++ b/test/test_fit.py @@ -8,9 +8,8 @@ import warnings from contextlib import nullcontext from copy import deepcopy -from itertools import product -from typing import Iterable, Optional -from unittest import mock +from itertools import filterfalse, product +from typing import Callable, Iterable, Optional from unittest.mock import MagicMock, patch from warnings import catch_warnings, warn, WarningMessage @@ -19,44 +18,52 @@ from botorch.exceptions.errors import ModelFittingError, UnsupportedError from botorch.exceptions.warnings import BotorchWarning, OptimizationWarning from botorch.fit import fit_gpytorch_mll -from botorch.models import FixedNoiseGP, HeteroskedasticSingleTaskGP, SingleTaskGP +from botorch.models import ( + FixedNoiseGP, + HeteroskedasticSingleTaskGP, + SingleTaskGP, + SingleTaskVariationalGP, +) from botorch.models.converter import batched_to_model_list from botorch.models.transforms.input import Normalize from botorch.models.transforms.outcome import Standardize -from botorch.optim.utils import ( - allclose_mll, + +from botorch.optim.closures import get_loss_closure_with_grads +from botorch.optim.fit import fit_gpytorch_mll_scipy, fit_gpytorch_mll_torch +from botorch.optim.utils import allclose_mll, get_data_loader +from botorch.settings import debug +from botorch.utils.context_managers import ( del_attribute_ctx, + module_rollback_ctx, requires_grad_ctx, - state_rollback_ctx, + TensorCheckpoint, ) -from botorch.settings import debug from botorch.utils.dispatcher import MDNotImplementedError from botorch.utils.testing import BotorchTestCase from gpytorch.kernels import MaternKernel -from gpytorch.mlls.exact_marginal_log_likelihood import ExactMarginalLogLikelihood +from gpytorch.mlls import ExactMarginalLogLikelihood, VariationalELBO from linear_operator.utils.errors import NotPSDError MAX_ITER_MSG = "TOTAL NO. of ITERATIONS REACHED LIMIT" -MAX_RETRY_MSG = "All attempts to fit the model have failed." class MockOptimizer: def __init__( self, randomize_requires_grad: bool = True, - thrown_warnings: Iterable[WarningMessage] = (), - thrown_exception: Optional[BaseException] = None, + warnings: Iterable[WarningMessage] = (), + exception: Optional[BaseException] = None, ): r"""Class used to mock `optimizer` argument to `fit_gpytorch_mll.""" self.randomize_requires_grad = randomize_requires_grad - self.thrown_warnings = thrown_warnings - self.thrown_exception = thrown_exception + self.warnings = warnings + self.exception = exception self.call_count = 0 - def __call__(self, mll): + def __call__(self, mll, closure: Optional[Callable] = None): self.call_count += 1 - for w in self.thrown_warnings: - warn(w.message, w.category) + for w in self.warnings: + warn(str(w.message), w.category) if self.randomize_requires_grad: with torch.no_grad(): @@ -64,8 +71,8 @@ def __call__(self, mll): if param.requires_grad: param[...] = torch.rand_like(param) - if self.thrown_exception is not None: - raise self.thrown_exception + if self.exception is not None: + raise self.exception return mll, None @@ -182,7 +189,10 @@ def setUp(self): key = model_type, output_dim self.mlls[key] = mll.to(dtype=dtype) self.checkpoints[key] = { - k: (v.detach().clone(), {}) for k, v in mll.state_dict().items() + k: TensorCheckpoint( + values=v.detach().clone(), device=v.device, dtype=v.dtype + ) + for k, v in mll.state_dict().items() } def test_main(self): @@ -200,14 +210,14 @@ def test_exceptions(self): def _test_main(self, mll, ckpt): r"""Main test for `_fit_fallback`.""" optimizer = MockOptimizer() - optimizer.thrown_warnings = [ + optimizer.warnings = [ WarningMessage("test_runtime_warning", RuntimeWarning, __file__, 0), ] for should_fail in (True, False): optimizer.call_count = 0 with catch_warnings(), requires_grad_ctx( module=mll, assignments={"model.mean_module.constant": False} - ), state_rollback_ctx(mll, checkpoint=ckpt): + ), module_rollback_ctx(mll, checkpoint=ckpt): try: fit._fit_fallback( mll, @@ -215,7 +225,7 @@ def _test_main(self, mll, ckpt): None, max_attempts=2, optimizer=optimizer, - warning_filter=lambda w: should_fail, + warning_handler=lambda w: not should_fail, ) except ModelFittingError: failed = True @@ -230,30 +240,39 @@ def _test_main(self, mll, ckpt): self.assertEqual(failed, mll.training) for key, vals in mll.state_dict().items(): if failed: - self.assertTrue(vals.equal(ckpt[key][0])) + self.assertTrue(vals.equal(ckpt[key].values)) else: try: param = mll.get_parameter(key) self.assertNotEqual( - param.equal(ckpt[key][0]), param.requires_grad + param.equal(ckpt[key].values), param.requires_grad ) except AttributeError: pass + # Test `closure_kwargs` + with self.subTest("closure_kwargs"): + mock_closure = MagicMock(side_effect=StopIteration("foo")) + with self.assertRaisesRegex(StopIteration, "foo"): + fit._fit_fallback( + mll, None, None, closure=mock_closure, closure_kwargs={"ab": "cd"} + ) + mock_closure.assert_called_once_with(ab="cd") + def _test_warnings(self, mll, ckpt): r"""Test warning handling for `_fit_fallback`.""" optimizer = MockOptimizer(randomize_requires_grad=False) - optimizer.thrown_warnings = [ + optimizer.warnings = [ WarningMessage("test_runtime_warning", RuntimeWarning, __file__, 0), WarningMessage(MAX_ITER_MSG, OptimizationWarning, __file__, 0), ] - warning_filters = { - "default": fit.DEFAULT_WARNING_FILTER, - "none": lambda w: True, - "all": lambda w: False, + warning_handlers = { + "default": fit.DEFAULT_WARNING_HANDLER, + "none": lambda w: False, + "all": lambda w: True, } - for case, warning_filter in warning_filters.items(): + for case, warning_handler in warning_handlers.items(): with ( self.assertLogs(level="DEBUG") if case == "default" else nullcontext() ) as logs, catch_warnings(record=True) as ws, debug(True): @@ -264,7 +283,7 @@ def _test_warnings(self, mll, ckpt): None, max_attempts=2, optimizer=optimizer, - warning_filter=warning_filter, + warning_handler=warning_handler, ) except ModelFittingError: failed = True @@ -274,22 +293,22 @@ def _test_warnings(self, mll, ckpt): # Test that warnings were resolved in the expected fashion self.assertEqual(failed, case == "none") with catch_warnings(record=True) as rethrown: - unresolved = list(filter(warning_filter, optimizer.thrown_warnings)) + unresolved = list(filterfalse(warning_handler, optimizer.warnings)) self.assertEqual(failed, len(unresolved) > 0) self.assertEqual( {str(w.message) for w in ws}, {str(w.message) for w in rethrown + unresolved}, ) - if logs: # test that default filter logs certain warnings self.assertTrue(any(MAX_ITER_MSG in log for log in logs.output)) # Test default of retrying upon encountering an uncaught OptimizationWarning - optimizer.thrown_warnings.append( + optimizer.warnings.append( WarningMessage("test_optim_warning", OptimizationWarning, __file__, 0) ) - with self.assertRaisesRegex(ModelFittingError, MAX_RETRY_MSG), catch_warnings(): + + with self.assertRaises(ModelFittingError), catch_warnings(): fit._fit_fallback( mll, None, @@ -300,11 +319,11 @@ def _test_warnings(self, mll, ckpt): def _test_exceptions(self, mll, ckpt): r"""Test exception handling for `_fit_fallback`.""" - optimizer = MockOptimizer(thrown_exception=NotPSDError("not_psd")) + optimizer = MockOptimizer(exception=NotPSDError("not_psd")) with catch_warnings(): # Test behavior when encountering a caught exception - with self.assertLogs(level="DEBUG") as logs, self.assertRaisesRegex( - ModelFittingError, MAX_RETRY_MSG + with self.assertLogs(level="DEBUG") as logs, self.assertRaises( + ModelFittingError ): fit._fit_fallback( mll, @@ -316,7 +335,7 @@ def _test_exceptions(self, mll, ckpt): self.assertTrue(any("not_psd" in log for log in logs.output)) self.assertTrue( # test state rollback - all(v.equal(ckpt[k][0]) for k, v in mll.state_dict().items()) + all(v.equal(ckpt[k].values) for k, v in mll.state_dict().items()) ) # Test behavior when encountering an uncaught exception @@ -331,7 +350,115 @@ def _test_exceptions(self, mll, ckpt): ) self.assertTrue( # test state rollback - all(v.equal(ckpt[k][0]) for k, v in mll.state_dict().items()) + all(v.equal(ckpt[k].values) for k, v in mll.state_dict().items()) + ) + + +class TestFitFallbackAppoximate(BotorchTestCase): + def setUp(self): + with torch.random.fork_rng(): + torch.manual_seed(0) + train_X = torch.linspace(0, 1, 10).unsqueeze(-1) + train_F = torch.sin(2 * math.pi * train_X) + train_Y = train_F + 0.1 * torch.randn_like(train_F) + + model = SingleTaskVariationalGP( + train_X=train_X, + train_Y=train_Y, + input_transform=Normalize(d=1), + outcome_transform=Standardize(m=1), + ) + self.mll = mll = VariationalELBO(model.likelihood, model.model, num_data=10) + self.data_loader = get_data_loader(mll.model, batch_size=1) + self.closure = get_loss_closure_with_grads( + mll=mll, + parameters={n: p for n, p in mll.named_parameters() if p.requires_grad}, + data_loader=self.data_loader, + ) + + def test_main(self): + # Test parameter updates + with module_rollback_ctx(self.mll) as ckpt: + fit._fit_fallback_approximate( + self.mll, + None, + None, + closure=self.closure, + optimizer_kwargs={"step_limit": 3}, + ) + for name, param in self.mll.named_parameters(): + self.assertFalse(param.equal(ckpt[name].values)) + + # Test dispatching pattern + kwargs = {"full_batch_limit": float("inf")} + with patch.object(fit, "_fit_fallback") as mock_fallback: + fit._fit_fallback_approximate(self.mll, None, None, full_batch_limit=1) + mock_fallback.assert_called_once_with( + self.mll, + None, + None, + closure=None, + optimizer=fit_gpytorch_mll_torch, + ) + + with patch.object(fit, "_fit_fallback") as mock_fallback: + fit._fit_fallback_approximate(self.mll, None, None, **kwargs) + mock_fallback.assert_called_once_with( + self.mll, + None, + None, + closure=None, + optimizer=fit_gpytorch_mll_scipy, + ) + + with patch.object(fit, "_fit_fallback") as mock_fallback: + fit._fit_fallback_approximate( + self.mll, None, None, closure=self.closure, **kwargs + ) + + mock_fallback.assert_called_once_with( + self.mll, + None, + None, + closure=self.closure, + optimizer=fit_gpytorch_mll_torch, + ) + + with patch.object(fit, "_fit_fallback") as mock_fallback, patch.object( + fit, "get_loss_closure_with_grads" + ) as mock_get_closure: + mock_get_closure.return_value = "foo" + fit._fit_fallback_approximate( + self.mll, + None, + None, + data_loader=self.data_loader, + **kwargs, + ) + params = {n: p for n, p in self.mll.named_parameters() if p.requires_grad} + mock_get_closure.assert_called_once_with( + mll=self.mll, + data_loader=self.data_loader, + parameters=params, + ) + mock_fallback.assert_called_once_with( + self.mll, + None, + None, + closure="foo", + optimizer=fit_gpytorch_mll_torch, + ) + + # Test exception handling + with self.assertRaisesRegex( + UnsupportedError, "Only one of `data_loader` or `closure` may be passed." + ): + fit._fit_fallback_approximate( + self.mll, + None, + None, + closure=self.closure, + data_loader=self.data_loader, ) @@ -369,7 +496,10 @@ def setUp(self): key = model_type, output_dim self.mlls[key] = mll.to(dtype=dtype).train() self.checkpoints[key] = { - k: (v.detach().clone(), {}) for k, v in mll.state_dict().items() + k: TensorCheckpoint( + values=v.detach().clone(), device=v.device, dtype=v.dtype + ) + for k, v in mll.state_dict().items() } if output_dim > 1: with del_attribute_ctx(mll.model, "outcome_transform"): @@ -413,28 +543,29 @@ def _test_main(self, mll, ckpt): return optimizer = MockOptimizer() - with state_rollback_ctx(mll, checkpoint=ckpt), debug( + with module_rollback_ctx(mll, checkpoint=ckpt), debug( True ), warnings.catch_warnings(record=True) as ws: warnings.simplefilter("always", BotorchWarning) + warnings.simplefilter("ignore", DeprecationWarning) try: fit._fit_multioutput_independent( mll, None, None, optimizer=optimizer, - warning_filter=lambda w: False, # filter all warnings + warning_handler=lambda w: True, # mark all warnings as resolved max_attempts=1, ) except Exception: pass # exception handling tested separately else: - self.assertEqual(len(ws), 0) # Model repacking did not fail. + self.assertEqual(0, len(ws)) self.assertFalse(mll.training) self.assertEqual(optimizer.call_count, mll.model.num_outputs) self.assertTrue( all( - v.equal(ckpt[k][0]) != v.requires_grad + v.equal(ckpt[k].values) != v.requires_grad for k, v in mll.named_parameters() ) ) @@ -457,7 +588,7 @@ def _test_unpack(self, mll, ckpt, bad_mll): self.assertEqual(converter.call_count, 1) self.assertEqual(optimizer.call_count, 0) # should fail beforehand self.assertTrue( - all(v.equal(ckpt[k][0]) for k, v in mll.state_dict().items()) + all(v.equal(ckpt[k].values) for k, v in mll.state_dict().items()) ) self.assertTrue(any("unpacked model differs" in str(w.message) for w in ws)) @@ -476,7 +607,7 @@ def _test_repack(self, mll, ckpt, bad_mll): fit._fit_multioutput_independent(mll, None, None, max_attempts=1) self.assertTrue( - all(v.equal(ckpt[k][0]) for k, v in mll.state_dict().items()) + all(v.equal(ckpt[k].values) for k, v in mll.state_dict().items()) ) self.assertTrue(any("repacked model differs" in str(w.message) for w in ws)) @@ -500,7 +631,7 @@ def mock_fit_gpytorch_mll(*args, **kwargs): model_list_to_batched=converter, # should not get called fit_gpytorch_mll=mock_fit_gpytorch_mll, SumMarginalLogLikelihood=type(mll), - state_rollback_ctx=lambda *args, **kwargs: nullcontext({}), + module_rollback_ctx=lambda *args, **kwargs: nullcontext({}), ): fit._fit_multioutput_independent(mll, None, None) except MDNotImplementedError: @@ -521,7 +652,7 @@ def test_fit_with_converter(self): intf = Normalize(2) model = SingleTaskGP(X, Y, input_transform=intf) mll = ExactMarginalLogLikelihood(model.likelihood, model) - with mock.patch( + with patch( f"{fit_gpytorch_mll.__module__}.batched_to_model_list", wraps=batched_to_model_list, ) as wrapped_converter, warnings.catch_warnings(record=True) as ws: diff --git a/test/utils/test_context_managers.py b/test/utils/test_context_managers.py new file mode 100644 index 0000000000..a9e3d141d3 --- /dev/null +++ b/test/utils/test_context_managers.py @@ -0,0 +1,148 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +from string import ascii_lowercase + +import torch +from botorch.utils.context_managers import ( + del_attribute_ctx, + module_rollback_ctx, + parameter_rollback_ctx, + requires_grad_ctx, + TensorCheckpoint, + zero_grad_ctx, +) +from botorch.utils.testing import BotorchTestCase +from torch.nn import Module, Parameter + + +class TestContextManagers(BotorchTestCase): + def setUp(self): + module = self.module = Module() + for i, name in enumerate(ascii_lowercase[:3], start=1): + values = torch.rand(2).to(torch.float16) + param = Parameter(values.to(torch.float64), requires_grad=bool(i % 2)) + module.register_parameter(name, param) + + def test_del_attribute_ctx(self): + # Test temporary removal of attributes + a = self.module.a + b = self.module.b + with del_attribute_ctx(self.module, "a", "b"): + self.assertIsNone(getattr(self.module, "a", None)) + self.assertIsNone(getattr(self.module, "b", None)) + self.assertTrue(self.module.c is not None) + + # Test that removed attributes get restored + self.assertTrue(self.module.a.equal(a)) + self.assertTrue(self.module.b.equal(b)) + + with self.assertRaisesRegex(ValueError, "Attribute .* missing"): + with del_attribute_ctx(self.module, "z", enforce_hasattr=True): + pass # pragma: no cover + + def test_requires_grad_ctx(self): + # Test temporary setting of requires_grad field + with requires_grad_ctx(self.module, assignments={"a": False, "b": True}): + self.assertTrue(not self.module.a.requires_grad) + self.assertTrue(self.module.b.requires_grad) + self.assertTrue(self.module.c.requires_grad) + + # Test that requires_grad fields get restored + self.assertTrue(self.module.a.requires_grad) + self.assertTrue(not self.module.b.requires_grad) + self.assertTrue(self.module.c.requires_grad) + + def test_parameter_rollback_ctx(self): + # Test that only unfiltered parameters get rolled back + a = self.module.a.detach().clone() + b = self.module.b.detach().clone() + c = self.module.c.detach().clone() + parameters = dict(self.module.named_parameters()) + with parameter_rollback_ctx(parameters, dtype=torch.float16) as ckpt: + for (tnsr, _, __) in ckpt.values(): # test whether dtype is obeyed + self.assertEqual(torch.float16, tnsr.dtype) + + self.module.a.data[...] = 0 + self.module.b.data[...] = 0 + self.module.c.data[...] = 0 + del ckpt["c"] # test whether changes to checkpoint dict are respected + + self.assertTrue(self.module.a.equal(a)) + self.assertTrue(self.module.b.equal(b)) + self.assertTrue(self.module.c.eq(0).all()) + + # Test rolling back to a user-provided checkpoint + with parameter_rollback_ctx( + parameters, checkpoint={"c": TensorCheckpoint(c, c.device, c.dtype)} + ): + pass + self.assertTrue(self.module.c.equal(c)) + + def test_module_rollback_ctx(self): + # Test that only unfiltered objects get rolled back + a = self.module.a.detach().clone() + b = self.module.b.detach().clone() + c = self.module.c.detach().clone() + with module_rollback_ctx( + self.module, lambda name: name == "a", dtype=torch.float16 + ) as ckpt: + for (tnsr, _, __) in ckpt.values(): # test whether dtype is obeyed + self.assertEqual(torch.float16, tnsr.dtype) + + self.module.a.data[...] = 0 + self.module.b.data[...] = 0 + self.module.c.data[...] = 0 + + self.assertTrue(self.module.a.equal(a)) + self.assertTrue(self.module.b.eq(0).all()) + self.assertTrue(self.module.c.eq(0).all()) + + # Test that changes to checkpoint dict are reflected in rollback state + with module_rollback_ctx(self.module) as ckpt: + self.module.a.data[...] = 1 + self.module.b.data[...] = 1 + self.module.c.data[...] = 1 + del ckpt["a"] + + self.assertTrue(self.module.a.eq(1).all()) + self.assertTrue(self.module.b.eq(0).all()) + self.assertTrue(self.module.c.eq(0).all()) + + # Test rolling back to a user-provided checkpoint + checkpoint = { + "a": TensorCheckpoint(a, a.device, a.dtype), + "b": TensorCheckpoint(b, b.device, b.dtype), + "c": TensorCheckpoint(c, c.device, c.dtype), + } + with module_rollback_ctx(module=self.module, checkpoint=checkpoint): + pass + self.assertTrue(self.module.a.equal(a)) + self.assertTrue(self.module.b.equal(b)) + self.assertTrue(self.module.c.equal(c)) + + # Test that items in checkpoint get inserted into state_dict + with del_attribute_ctx(self.module, "a"): + with self.assertRaisesRegex( # should fail when attempting to rollback + RuntimeError, r'Unexpected key\(s\) in state_dict: "a"' + ): + with module_rollback_ctx(module=self.module, checkpoint=checkpoint): + pass + + def test_zero_grad_ctx(self): + params = (Parameter(torch.rand(1)), Parameter(torch.rand(1))) + sum(params).backward() + with zero_grad_ctx(params, zero_on_enter=False, zero_on_exit=True): + self.assertFalse(any(x.grad.eq(0).all() for x in params)) + self.assertTrue(all(x.grad.eq(0).all() for x in params)) + + sum(params).backward() + with zero_grad_ctx(params, zero_on_enter=True, zero_on_exit=False): + self.assertTrue(all(x.grad.eq(0).all() for x in params)) + sum(params).backward() + self.assertFalse(any(x.grad.eq(0).all() for x in params))