Extend Posterior API to support torch distributions & overhaul MCSampler API (#1254)

Summary:
X-link: facebook/Ax#1254

X-link: facebookresearch/aepsych#193

Pull Request resolved: #1486

The main goal here is to broadly support non-Gaussian posteriors.
- Adds a generic `TorchPosterior`, which wraps a torch `Distribution`. It defines the handful of properties we commonly expect and forwards everything else to the wrapped `distribution`.
- For a unified plotting API, this shifts away from mean & variance to a quantile function. Most torch distributions implement the inverse CDF (`icdf`), which is used as the quantile; for the rest, the user should implement it at either the distribution or the posterior level (see the sketch after this list).
- Moves the burden of base-sample handling from the posterior to the samplers. Using a dispatcher-based `get_sampler` method, we can support SAA with mixed posteriors without having to shuffle base samples in a `PosteriorList`, as long as every base distribution has a corresponding sampler and supports base samples.
- Adds `ListSampler` for sampling from `PosteriorList`.
- Adds `ForkedRNGSampler` and `StochasticSampler` for sampling from posteriors without base samples.
- Adds `rsample_from_base_samples` for sampling from a posterior with explicitly provided `base_samples` / with a `sampler`.
- Absorbs `FullyBayesianPosteriorList` into `PosteriorList`.
- For MC acqfs, introduces a `get_posterior_samples` method for sampling from the posterior with base samples / a sampler (a usage sketch follows the `acquisition.py` diff below). If no sampler is specified, this constructs the appropriate sampler for the posterior using `get_sampler`, eliminating the need to construct a sampler in `__init__`, which we used to do under the assumption of Gaussian posteriors.
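
To make the new posterior API concrete, here is a minimal sketch (assuming a BoTorch build that includes this commit; the `Exponential` distribution and its rate are arbitrary illustrative choices):

```python
import torch
from botorch.posteriors.torch import TorchPosterior
from torch.distributions import Exponential

# Wrap an arbitrary torch distribution in the generic posterior.
posterior = TorchPosterior(distribution=Exponential(rate=torch.tensor(1.0)))

# Commonly used properties are defined on the posterior; anything else is
# forwarded to the wrapped distribution.
mean = posterior.mean

# Plotting now relies on quantiles rather than mean & variance; where the
# distribution implements `icdf`, it serves as the quantile function.
quantiles = posterior.quantile(torch.tensor([0.05, 0.5, 0.95]))

# Reparameterized samples can be drawn directly from the posterior.
samples = posterior.rsample(sample_shape=torch.Size([16]))
```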

TODOs:
- Relax the Gaussian assumption in acquisition functions & utilities. Some of this might be addressed in a follow-up diff.
- Update website / docs & tutorials to clear up some of the Gaussian assumptions and introduce the new, relaxed API. Likely a follow-up diff.
- Some more are listed in T134364907.
- Test fixes and new unit tests.

Other notables:
- See D39760855 for usage of TorchDistribution in SkewGP.
- TransformedPosterior could serve as the fallback option for derived posteriors.
- MC samplers no longer support `resample` or `collapse_batch_dims=False`. These use cases can be handled by i) not using base samples, or ii) using `torch.fork_rng` and sampling without base samples inside that context. Samplers are now only meant to support SAA. Introduces `ForkedRNGSampler` and `StochasticSampler` as convenience samplers for these use cases (see the sketch after this list).
- Introduces `batch_range_override` for the sampler to support edge cases where we may want to override `posterior.batch_range` (needed in `qMultiStepLookahead`).
- Removes the unused sampling utilities `construct_base_samples(_from_posterior)`, which assume a Gaussian posterior.
- Moves the main logic of the `_set_sampler` method of `CachedCholesky` subclasses to an `_update_base_samples` method on the samplers, and simplifies these classes a bit more.
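
A hedged sketch of the sampler migration (module paths as introduced by this commit, with `stochastic_samplers` as the assumed home of the new convenience samplers; sample shapes and the seed are arbitrary):

```python
import torch
from botorch.sampling.normal import SobolQMCNormalSampler
from botorch.sampling.stochastic_samplers import ForkedRNGSampler, StochasticSampler

# Before: SobolQMCNormalSampler(num_samples=128, resample=False, collapse_batch_dims=True)
# After: samplers are SAA-only and are configured via `sample_shape`.
saa_sampler = SobolQMCNormalSampler(sample_shape=torch.Size([128]))

# In place of `resample=True`: draw fresh samples with no base samples.
stochastic_sampler = StochasticSampler(sample_shape=torch.Size([128]))

# For reproducible draws without base samples, fork the RNG under a fixed seed.
forked_sampler = ForkedRNGSampler(sample_shape=torch.Size([128]), seed=0)
```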

Reviewed By: Balandat

Differential Revision: D39759489

fbshipit-source-id: 59fa663777555ff6d528dab53d124665ae5e75e7
saitcakmak authored and facebook-github-bot committed Nov 16, 2022
1 parent 24ee799 commit a41ad1b
Showing 101 changed files with 2,778 additions and 2,889 deletions.
36 changes: 36 additions & 0 deletions botorch/acquisition/acquisition.py
@@ -12,9 +12,12 @@
from abc import ABC, abstractmethod
from typing import Callable, Optional

import torch
from botorch.exceptions import BotorchWarning, UnsupportedError
from botorch.models.model import Model
from botorch.posteriors.posterior import Posterior
from botorch.sampling.base import MCSampler
from botorch.sampling.get_sampler import get_sampler
from torch import Tensor
from torch.nn import Module

@@ -132,3 +135,36 @@ def extract_candidates(self, X_full: Tensor) -> Tensor:
A `b x q x d`-dim Tensor with `b` t-batches of `q` design points each.
"""
pass # pragma: no cover


class MCSamplerMixin(ABC):
r"""A mix-in for adding sampler functionality into an acquisition function class.
Attributes:
_default_sample_shape: The `sample_shape` for the default sampler.
:meta private:
"""

_default_sample_shape = torch.Size([512])

def __init__(self, sampler: Optional[MCSampler] = None) -> None:
r"""Register the sampler on the acquisition function.
Args:
sampler: The sampler used to draw base samples for MC-based acquisition
functions. If `None`, a sampler is generated using `get_sampler`.
"""
self.sampler = sampler

def get_posterior_samples(self, posterior: Posterior) -> Tensor:
r"""Sample from the posterior using the sampler.
Args:
posterior: The posterior to sample from.
"""
if self.sampler is None:
self.sampler = get_sampler(
posterior=posterior, sample_shape=self._default_sample_shape
)
return self.sampler(posterior=posterior)
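
As a hypothetical illustration (not part of this commit) of how an MC acquisition function can use the mixin above — `MyMCAcquisitionFunction` and its sample reduction are placeholders:

```python
from typing import Optional

from botorch.acquisition.acquisition import AcquisitionFunction, MCSamplerMixin
from botorch.models.model import Model
from botorch.sampling.base import MCSampler
from torch import Tensor


class MyMCAcquisitionFunction(AcquisitionFunction, MCSamplerMixin):
    def __init__(self, model: Model, sampler: Optional[MCSampler] = None) -> None:
        super().__init__(model=model)
        MCSamplerMixin.__init__(self, sampler=sampler)

    def forward(self, X: Tensor) -> Tensor:
        posterior = self.model.posterior(X)
        # If no sampler was passed to __init__, `get_posterior_samples`
        # lazily builds one matching this posterior type via `get_sampler`.
        samples = self.get_posterior_samples(posterior)
        # Placeholder reduction: average over MC samples, max over the q-batch.
        return samples.mean(dim=0).squeeze(-1).amax(dim=-1)
```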
12 changes: 5 additions & 7 deletions botorch/acquisition/active_learning.py
@@ -25,12 +25,14 @@

from typing import Optional

import torch
from botorch import settings
from botorch.acquisition.analytic import AnalyticAcquisitionFunction
from botorch.acquisition.monte_carlo import MCAcquisitionFunction
from botorch.acquisition.objective import MCAcquisitionObjective, PosteriorTransform
from botorch.models.model import Model
from botorch.sampling.samplers import MCSampler, SobolQMCNormalSampler
from botorch.sampling.base import MCSampler
from botorch.sampling.normal import SobolQMCNormalSampler
from botorch.utils.transforms import concatenate_pending_points, t_batch_mode_transform
from torch import Tensor

@@ -81,9 +83,7 @@ def __init__(
# variance does not depend on the samples y (only on x), which is true for
# standard GP models, but not in general (e.g. for other likelihoods or
# heteroskedastic GPs using a separate noise model fit on data).
sampler = SobolQMCNormalSampler(
num_samples=1, resample=False, collapse_batch_dims=True
)
sampler = SobolQMCNormalSampler(sample_shape=torch.Size([1]))
self.sampler = sampler
self.X_pending = X_pending
self.register_buffer("mc_points", mc_points)
@@ -150,8 +150,6 @@ def __init__(
two samples. Can be implemented via GenericMCObjective.
sampler: The sampler used for drawing MC samples.
"""
if sampler is None:
sampler = SobolQMCNormalSampler(num_samples=512, collapse_batch_dims=True)
super().__init__(
model=model, sampler=sampler, objective=objective, X_pending=None
)
@@ -175,7 +173,7 @@ def forward(self, X: Tensor) -> Tensor:
# The output is of shape batch_shape x 2 x d
# For PairwiseGP, d = 1
post = self.model.posterior(X)
samples = self.sampler(post) # num_samples x batch_shape x 2 x d
samples = self.get_posterior_samples(post) # num_samples x batch_shape x 2 x d

# The output is of shape num_samples x batch_shape x q/2 x d
# assuming the comparison is made between the 2 * i and 2 * i + 1 elements
5 changes: 3 additions & 2 deletions botorch/acquisition/analytic.py
@@ -22,7 +22,6 @@
from botorch.models.gp_regression import FixedNoiseGP
from botorch.models.gpytorch import GPyTorchModel
from botorch.models.model import Model
from botorch.sampling.samplers import SobolQMCNormalSampler
from botorch.utils.transforms import convert_to_target_pre_hook, t_batch_mode_transform
from torch import Tensor
from torch.distributions import Normal
@@ -561,9 +560,11 @@ def __init__(
"Only FixedNoiseGPs are currently supported for fantasy NEI"
)
# sample fantasies
from botorch.sampling.normal import SobolQMCNormalSampler

with torch.no_grad():
posterior = model.posterior(X=X_observed)
sampler = SobolQMCNormalSampler(num_fantasies)
sampler = SobolQMCNormalSampler(sample_shape=torch.Size([num_fantasies]))
Y_fantasized = sampler(posterior).squeeze(-1)
batch_X_observed = X_observed.expand(num_fantasies, *X_observed.shape)
# The fantasy model will operate in batch mode
86 changes: 39 additions & 47 deletions botorch/acquisition/cached_cholesky.py
@@ -12,18 +12,15 @@

import warnings
from abc import ABC
from typing import Optional

import torch
from botorch.exceptions.errors import UnsupportedError
from botorch.exceptions.warnings import BotorchWarning
from botorch.models import HigherOrderGP
from botorch.models.deterministic import DeterministicModel
from botorch.models.gpytorch import GPyTorchModel
from botorch.models.higher_order_gp import HigherOrderGP
from botorch.models.model import Model, ModelList
from botorch.models.multitask import KroneckerMultiTaskGP, MultiTaskGP
from botorch.posteriors.gpytorch import GPyTorchPosterior
from botorch.posteriors.posterior import Posterior
from botorch.sampling.samplers import MCSampler
from botorch.utils.low_rank import extract_batch_covar, sample_cached_cholesky
from gpytorch import settings as gpt_settings
from gpytorch.distributions.multitask_multivariate_normal import (
@@ -43,58 +40,32 @@ class CachedCholeskyMCAcquisitionFunction(ABC):
:meta private:
"""

def _check_sampler(self) -> None:
r"""Check compatibility of sampler and model with a cached Cholesky."""
if not self.sampler.collapse_batch_dims:
raise UnsupportedError(
"Expected sampler to use `collapse_batch_dims=True`."
)
elif self.sampler.base_samples is not None:
warnings.warn(
message=(
"sampler.base_samples is not None. The base_samples must be "
"initialized to None. Resetting sampler.base_samples to None."
),
category=BotorchWarning,
)
self.sampler.base_samples = None
elif self._uses_matheron and self.sampler.batch_range != (0, -1):
raise RuntimeError(
"sampler.batch_range is not (0, -1). This check requires that the "
"sampler.batch_range is (0, -1) with GPs that use Matheron's rule "
"for sampling, in order to properly collapse batch dimensions. "
)

def _setup(
self,
model: Model,
sampler: Optional[MCSampler] = None,
cache_root: bool = False,
check_sampler: bool = False,
) -> None:
r"""Set class attributes and perform compatibility checks.
Args:
model: A model.
sampler: A sampler.
cache_root: A boolean indicating whether to cache the Cholesky.
This might be overridden in the model is not compatible.
check_sampler: A boolean indicating whether to check the sampler.
The sampler is always checked if cache_root is True.
"""
models = model.models if isinstance(model, ModelList) else [model]
self._is_mt = any(
if any(
isinstance(m, (MultiTaskGP, KroneckerMultiTaskGP, HigherOrderGP))
or not isinstance(m, GPyTorchModel)
for m in models
)
self._is_deterministic = any(isinstance(m, DeterministicModel) for m in models)
self._uses_matheron = any(
isinstance(m, (KroneckerMultiTaskGP, HigherOrderGP)) for m in models
)
if check_sampler or cache_root:
self._check_sampler()
if self._is_deterministic or self._is_mt:
cache_root = False
):
if cache_root:
warnings.warn(
"`cache_root` is only supported for GPyTorchModels (with the "
f"exception of MultiTask models). Got model={model}. Setting "
"`cache_root = False",
RuntimeWarning,
)
cache_root = False
self._cache_root = cache_root

def _compute_root_decomposition(
@@ -118,10 +89,10 @@ def _compute_root_decomposition(
Args:
posterior: The posterior over f(X_baseline).
"""
if isinstance(posterior.mvn, MultitaskMultivariateNormal):
lazy_covar = extract_batch_covar(posterior.mvn)
if isinstance(posterior.distribution, MultitaskMultivariateNormal):
lazy_covar = extract_batch_covar(posterior.distribution)
else:
lazy_covar = posterior.mvn.lazy_covariance_matrix
lazy_covar = posterior.distribution.lazy_covariance_matrix
with gpt_settings.fast_computations.covar_root_decomposition(False):
lazy_covar_root = lazy_covar.root_decomposition()
return lazy_covar_root.root.to_dense()
@@ -142,7 +113,7 @@ def _get_f_X_samples(self, posterior: GPyTorchPosterior, q_in: int) -> Tensor:
# cached covariance (and box decompositions) and the new block.
# But recomputing box decompositions every time the jitter changes would
# be quite slow.
if not self._is_mt and self._cache_root and hasattr(self, "_baseline_L"):
if self._cache_root and hasattr(self, "_baseline_L"):
try:
return sample_cached_cholesky(
posterior=posterior,
@@ -160,7 +131,7 @@ def _get_f_X_samples(self, posterior: GPyTorchPosterior, q_in: int) -> Tensor:
)

# TODO: improve efficiency for multi-task models
samples = self.sampler(posterior)
samples = self.get_posterior_samples(posterior)
if isinstance(self.model, HigherOrderGP):
# Select the correct q-batch dimension for HOGP.
q_dim = -self.model._num_dimensions
Expand All @@ -170,3 +141,24 @@ def _get_f_X_samples(self, posterior: GPyTorchPosterior, q_in: int) -> Tensor:
return samples.index_select(q_dim, q_idcs)
else:
return samples[..., -q_in:, :]

def _set_sampler(
self,
q_in: int,
posterior: Posterior,
) -> None:
r"""Update the sampler to use the original base samples for X_baseline.
Args:
q_in: The effective input batch size. This is typically equal to the
q-batch size of `X`. However, if using a one-to-many input transform,
e.g., `InputPerturbation` with `n_w` perturbations, the posterior will
have `n_w` points on the q-batch for each point on the q-batch of `X`.
In which case, `q_in = q * n_w` is used.
posterior: The posterior.
"""
if self.q_in != q_in and self.base_sampler is not None:
self.sampler._update_base_samples(
posterior=posterior, base_sampler=self.base_sampler
)
self.q_in = q_in
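
A hedged sketch of the resulting flow — the toy `model`, `X_baseline`, and the two samplers are hypothetical stand-ins for the attributes a CachedCholesky-based acquisition function would hold; `_update_base_samples` is the sampler method referenced above:

```python
import torch
from botorch.models import SingleTaskGP
from botorch.sampling.normal import SobolQMCNormalSampler

# Hypothetical stand-ins for `self.model`, `X_baseline`, `self.base_sampler`,
# and `self.sampler` on a CachedCholesky-style acquisition function.
model = SingleTaskGP(torch.rand(8, 2, dtype=torch.float64),
                     torch.rand(8, 1, dtype=torch.float64))
X_baseline = torch.rand(5, 2, dtype=torch.float64)
base_sampler = SobolQMCNormalSampler(sample_shape=torch.Size([32]))
sampler = SobolQMCNormalSampler(sample_shape=torch.Size([32]))

posterior = model.posterior(X_baseline)
_ = base_sampler(posterior=posterior)  # draws and caches base samples
# The sampler, rather than the acquisition function, now re-aligns its
# base samples with those of the reference `base_sampler`:
sampler._update_base_samples(posterior=posterior, base_sampler=base_sampler)
```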
2 changes: 1 addition & 1 deletion botorch/acquisition/cost_aware.py
@@ -20,7 +20,7 @@
from botorch.acquisition.objective import IdentityMCObjective, MCAcquisitionObjective
from botorch.exceptions.warnings import CostAwareWarning
from botorch.models.model import Model
from botorch.sampling.samplers import MCSampler
from botorch.sampling.base import MCSampler
from torch import Tensor
from torch.nn import Module

16 changes: 9 additions & 7 deletions botorch/acquisition/input_constructors.py
@@ -83,9 +83,11 @@
from botorch.exceptions.errors import UnsupportedError
from botorch.models.cost import AffineFidelityCostModel
from botorch.models.deterministic import FixedSingleSampleModel
from botorch.models.gpytorch import GPyTorchModel
from botorch.models.model import Model
from botorch.optim.optimize import optimize_acqf
from botorch.sampling.samplers import IIDNormalSampler, MCSampler, SobolQMCNormalSampler
from botorch.sampling.base import MCSampler
from botorch.sampling.normal import IIDNormalSampler, SobolQMCNormalSampler
from botorch.utils.constraints import get_outcome_constraint_transforms
from botorch.utils.containers import BotorchContainer
from botorch.utils.datasets import BotorchDataset, SupervisedDataset
@@ -643,10 +645,10 @@ def construct_inputs_qUCB(
def _get_sampler(mc_samples: int, qmc: bool) -> MCSampler:
"""Set up MC sampler for q(N)EHVI."""
# initialize the sampler
seed = int(torch.randint(1, 10000, (1,)).item())
shape = torch.Size([mc_samples])
if qmc:
return SobolQMCNormalSampler(num_samples=mc_samples, seed=seed)
return IIDNormalSampler(num_samples=mc_samples, seed=seed)
return SobolQMCNormalSampler(sample_shape=shape)
return IIDNormalSampler(sample_shape=shape)


@acqf_input_constructor(ExpectedHypervolumeImprovement)
@@ -756,7 +758,7 @@ def construct_inputs_qEHVI(
)

sampler = kwargs.get("sampler")
if sampler is None:
if sampler is None and isinstance(model, GPyTorchModel):
sampler = _get_sampler(
mc_samples=kwargs.get("mc_samples", 128), qmc=kwargs.get("qmc", True)
)
@@ -806,7 +808,7 @@ def construct_inputs_qNEHVI(
cons_tfs = get_outcome_constraint_transforms(outcome_constraints)

sampler = kwargs.get("sampler")
if sampler is None:
if sampler is None and isinstance(model, GPyTorchModel):
sampler = _get_sampler(
mc_samples=kwargs.get("mc_samples", 128), qmc=kwargs.get("qmc", True)
)
@@ -1175,7 +1177,7 @@ def optimize_objective(
model=model,
objective=objective,
posterior_transform=posterior_transform,
sampler=sampler_cls(num_samples=mc_samples, seed=seed_inner),
sampler=sampler_cls(sample_shape=torch.Size([mc_samples]), seed=seed_inner),
)
else:
acq_function = PosteriorMean(
14 changes: 6 additions & 8 deletions botorch/acquisition/knowledge_gradient.py
@@ -33,6 +33,7 @@
from botorch import settings
from botorch.acquisition.acquisition import (
AcquisitionFunction,
MCSamplerMixin,
OneShotAcquisitionFunction,
)
from botorch.acquisition.analytic import PosteriorMean
@@ -41,7 +42,8 @@
from botorch.acquisition.objective import MCAcquisitionObjective, PosteriorTransform
from botorch.exceptions.errors import UnsupportedError
from botorch.models.model import Model
from botorch.sampling.samplers import MCSampler, SobolQMCNormalSampler
from botorch.sampling.base import MCSampler
from botorch.sampling.normal import SobolQMCNormalSampler
from botorch.utils.transforms import (
concatenate_pending_points,
match_batch_shape,
@@ -108,9 +110,7 @@ def __init__(
"Must specify `num_fantasies` if no `sampler` is provided."
)
# base samples should be fixed for joint optimization over X, X_fantasies
sampler = SobolQMCNormalSampler(
num_samples=num_fantasies, resample=False, collapse_batch_dims=True
)
sampler = SobolQMCNormalSampler(sample_shape=torch.Size([num_fantasies]))
elif num_fantasies is not None:
if sampler.sample_shape != torch.Size([num_fantasies]):
raise ValueError(
@@ -119,11 +119,10 @@
else:
num_fantasies = sampler.sample_shape[0]
super(MCAcquisitionFunction, self).__init__(model=model)
MCSamplerMixin.__init__(self, sampler=sampler)
# if not explicitly specified, we use the posterior mean for linear objs
if isinstance(objective, MCAcquisitionObjective) and inner_sampler is None:
inner_sampler = SobolQMCNormalSampler(
num_samples=128, resample=False, collapse_batch_dims=True
)
inner_sampler = SobolQMCNormalSampler(sample_shape=torch.Size([128]))
elif objective is not None and not isinstance(
objective, MCAcquisitionObjective
):
@@ -150,7 +149,6 @@ def __init__(
"If using a multi-output model without an objective, "
"posterior_transform must scalarize the output."
)
self.sampler: MCSampler = sampler
self.objective = objective
self.posterior_transform = posterior_transform
self.set_X_pending(X_pending)