Extend Posterior API to support torch distributions & overhaul MCSampler API (facebook#1486)

Summary:
X-link: pytorch/botorch#1486

The main goal here is to broadly support non-Gaussian posteriors.
- Adds a generic `TorchPosterior` which wraps a Torch `Distribution`. This defines a few properties that we commonly expect, and delegates to the wrapped `distribution` for the rest.
- For a unified plotting API, this shifts away from mean & variance to a quantile function. Most torch distributions implement an inverse CDF (`icdf`), which is used as the quantile function. For others, the user should implement it at either the distribution or the posterior level.
- Shifts responsibility for base sample handling from the posteriors to the samplers (see the sketch after this list). Using a dispatcher-based `get_sampler` method, we can support SAA with mixed posteriors without having to shuffle base samples in a `PosteriorList`, as long as all base distributions have a corresponding sampler and support base samples.
- Adds `ListSampler` for sampling from `PosteriorList`.
- Adds `ForkedRNGSampler` and `StochasticSampler` for sampling from posteriors without base samples.
- Adds `rsample_from_base_samples` for sampling with `base_samples` / with a `sampler`.
- Absorbs `FullyBayesianPosteriorList` into `PosteriorList`.
- For MC acqfs, introduces `get_posterior_samples` for sampling from the posterior with base samples / a sampler. If a sampler was not specified, this constructs the appropriate sampler for the posterior using `get_sampler`, eliminating the need to construct a sampler in `__init__`, which we previously did under the assumption of Gaussian posteriors.
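
To make the new flow concrete, below is a minimal sketch of how these pieces are intended to fit together. Only the `botorch.sampling.normal` import is taken from the diff in this commit; the `TorchPosterior` and `get_sampler` module paths, constructor arguments, and the `icdf`-based quantile call are assumptions for illustration and may differ from the final API.

```python
# Minimal sketch of the reworked sampler / posterior flow (illustrative only).
import torch
from botorch.models import SingleTaskGP
from botorch.posteriors.torch import TorchPosterior  # assumed module path
from botorch.sampling.get_sampler import get_sampler  # assumed module path
from botorch.sampling.normal import SobolQMCNormalSampler
from torch.distributions import Exponential

# A Gaussian posterior from a GP model.
train_X = torch.rand(10, 2, dtype=torch.double)
train_Y = train_X.sum(dim=-1, keepdim=True)
model = SingleTaskGP(train_X, train_Y)
posterior = model.posterior(torch.rand(5, 2, dtype=torch.double))

# Samplers are constructed with a `sample_shape` and own the base samples;
# calling the sampler on the posterior draws the SAA samples.
sampler = SobolQMCNormalSampler(sample_shape=torch.Size([128]))
samples = sampler(posterior)  # shape: 128 x 5 x 1

# The dispatcher picks a compatible sampler from the posterior type, so MC
# acquisition functions no longer need to assume a Gaussian posterior.
auto_sampler = get_sampler(posterior, sample_shape=torch.Size([128]))

# A non-Gaussian posterior wrapping a plain torch distribution; quantiles come
# from the distribution's inverse CDF where it is implemented.
torch_posterior = TorchPosterior(distribution=Exponential(rate=torch.rand(5) + 0.1))
q90 = torch_posterior.distribution.icdf(torch.tensor(0.9))
```

The same pattern is what `get_posterior_samples` relies on internally when no sampler is passed to an MC acqf.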

TODOs:
- Relax the Gaussian assumption in acquisition functions & utilities. Some of this might be addressed in a follow-up diff.
- Update the website / docs & tutorials to clear up some of the Gaussian assumptions and introduce the new, relaxed API. Likely a follow-up diff.
- Some more listed in T134364907
- Test fixes and new unit tests

Other notables:
- See D39760855 for usage of TorchDistribution in SkewGP.
- TransformedPosterior could serve as the fallback option for derived posteriors.
- MC samplers no longer support `resample` or `collapse_batch_dims(=False)`. These use cases can be handled by i) not using base samples, or ii) using `torch.fork_rng` and sampling without base samples within that context. Samplers are only meant to support SAA. Introduces `ForkedRNGSampler` and `StochasticSampler` as convenience samplers for these use cases (a brief sketch follows this list).
- Introduces `batch_range_override` for the sampler to support edge cases where we may want to override `posterior.batch_range` (needed in `qMultiStepLookahead`).
- Removes unused sampling utilities `construct_base_samples(_from_posterior)`, which assume Gaussian posterior.
- Moves the main logic of the `_set_sampler` method of `CachedCholesky` subclasses to an `_update_base_samples` method on the samplers, and simplifies these classes a bit further.
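
As a rough illustration of the convenience samplers mentioned above, here is a short sketch; the module path and constructor arguments are assumptions based on the class names in this summary.

```python
# Illustrative sketch of the samplers that do not use base samples.
import torch
from botorch.sampling.stochastic_samplers import (  # assumed module path
    ForkedRNGSampler,
    StochasticSampler,
)

# StochasticSampler: plain `rsample` draws with no base samples, i.e. fresh
# randomness on every call (no sample average approximation).
stochastic = StochasticSampler(sample_shape=torch.Size([64]))

# ForkedRNGSampler: draws within `torch.random.fork_rng` using a fixed seed, so
# repeated calls are reproducible without maintaining base samples.
forked = ForkedRNGSampler(sample_shape=torch.Size([64]), seed=0)

# Either sampler is called on a posterior (e.g. `model.posterior(X)`) in the
# same way as the normal samplers in the earlier sketch.
```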

Reviewed By: Balandat

Differential Revision: D39759489

fbshipit-source-id: ffec0cc4594e8ff8ae2777928a4acf1ee1259038
saitcakmak authored and facebook-github-bot committed Nov 11, 2022
1 parent a67302f commit 4fbff72
Showing 9 changed files with 9 additions and 9 deletions.
2 changes: 1 addition & 1 deletion ax/models/tests/test_botorch_kg.py
@@ -23,7 +23,7 @@
 )
 from botorch.exceptions.errors import UnsupportedError
 from botorch.models.transforms.input import Warp
-from botorch.sampling.samplers import IIDNormalSampler, SobolQMCNormalSampler
+from botorch.sampling.normal import IIDNormalSampler, SobolQMCNormalSampler
 from botorch.utils.datasets import FixedNoiseDataset

2 changes: 1 addition & 1 deletion ax/models/tests/test_botorch_mes.py
@@ -19,7 +19,7 @@
 )
 from botorch.exceptions.errors import UnsupportedError
 from botorch.models.transforms.input import Warp
-from botorch.sampling.samplers import SobolQMCNormalSampler
+from botorch.sampling.normal import SobolQMCNormalSampler
 from botorch.utils.datasets import FixedNoiseDataset

2 changes: 1 addition & 1 deletion ax/models/torch/botorch_defaults.py
@@ -253,7 +253,7 @@ def _get_acquisition_func(
     # construct Objective module
     if kwargs.get("chebyshev_scalarization", False):
         with torch.no_grad():
-            Y = model.posterior(X_observed).mean
+            Y = model.posterior(X_observed).mean  # pyre-ignore [16]
         obj_tf = get_chebyshev_scalarization(weights=objective_weights, Y=Y)
     else:
         obj_tf = get_objective_weights_transform(objective_weights)

2 changes: 1 addition & 1 deletion ax/models/torch/botorch_kg.py
@@ -36,7 +36,7 @@
 from botorch.models.model import Model
 from botorch.optim.initializers import gen_one_shot_kg_initial_conditions
 from botorch.optim.optimize import optimize_acqf
-from botorch.sampling.samplers import IIDNormalSampler, SobolQMCNormalSampler
+from botorch.sampling.normal import IIDNormalSampler, SobolQMCNormalSampler
 from torch import Tensor

2 changes: 1 addition & 1 deletion ax/models/torch/tests/test_model.py
@@ -37,7 +37,7 @@
 from botorch.models.gp_regression import FixedNoiseGP, SingleTaskGP
 from botorch.models.gp_regression_fidelity import FixedNoiseMultiFidelityGP
 from botorch.models.model_list_gp_regression import ModelListGP
-from botorch.sampling.samplers import SobolQMCNormalSampler
+from botorch.sampling.normal import SobolQMCNormalSampler
 from botorch.utils.datasets import FixedNoiseDataset, SupervisedDataset

2 changes: 1 addition & 1 deletion ax/models/torch/tests/test_surrogate.py
@@ -27,7 +27,7 @@
 from botorch.models.model import Model
 from botorch.models.transforms.input import InputPerturbation, Normalize
 from botorch.models.transforms.outcome import Standardize
-from botorch.sampling.samplers import SobolQMCNormalSampler
+from botorch.sampling.normal import SobolQMCNormalSampler
 from botorch.utils.datasets import SupervisedDataset
 from gpytorch.constraints import Interval
 from gpytorch.kernels import Kernel, RBFKernel, ScaleKernel  # noqa: F401

2 changes: 1 addition & 1 deletion ax/models/torch/utils.py
@@ -43,7 +43,7 @@
 from botorch.models import ModelListGP, SingleTaskGP
 from botorch.models.model import Model
 from botorch.posteriors.fully_bayesian import FullyBayesianPosterior
-from botorch.sampling.samplers import IIDNormalSampler, SobolQMCNormalSampler
+from botorch.sampling.normal import IIDNormalSampler, SobolQMCNormalSampler
 from botorch.utils.constraints import get_outcome_constraint_transforms
 from botorch.utils.datasets import FixedNoiseDataset, SupervisedDataset
 from botorch.utils.objective import get_objective_weights_transform

2 changes: 1 addition & 1 deletion ax/utils/sensitivity/derivative_measures.py
@@ -11,7 +11,7 @@
 from ax.utils.sensitivity.derivative_gp import posterior_derivative
 from botorch.models.model import Model
 from botorch.posteriors.posterior import Posterior
-from botorch.sampling.samplers import SobolQMCNormalSampler
+from botorch.sampling.normal import SobolQMCNormalSampler
 from botorch.utils.sampling import draw_sobol_samples
 from botorch.utils.transforms import unnormalize
 from gpytorch.distributions import MultivariateNormal

2 changes: 1 addition & 1 deletion ax/utils/sensitivity/sobol_measures.py
@@ -8,7 +8,7 @@

 import torch
 from botorch.models.model import Model
-from botorch.sampling.samplers import SobolQMCNormalSampler
+from botorch.sampling.normal import SobolQMCNormalSampler
 from botorch.utils.sampling import draw_sobol_samples
 from botorch.utils.transforms import unnormalize
 from torch._tensor import Tensor
