Extend Posterior API to support torch distributions & overhaul MCSampler API (#1254)

Summary:
X-link: facebook/Ax#1254

Pull Request resolved: facebookresearch#193

X-link: pytorch/botorch#1486

The main goal here is to broadly support non-Gaussian posteriors.
- Adds a generic `TorchPosterior` which wraps a Torch `Distribution`. This defines a few properties that we commonly expect, and delegates to the `distribution` for the rest.
- For a unified plotting API, this shifts away from mean & variance to a quantile function. Most torch distributions implement an inverse CDF (`icdf`), which is used as the quantile function. For others, the user should implement it at either the distribution or the posterior level.
- Hands off the burden of base sample handling from the posterior to the samplers. Using a dispatcher-based `get_sampler` method, we can support SAA with mixed posteriors without having to shuffle base samples in a `PosteriorList`, as long as all base distributions have a corresponding sampler and support base samples.
- Adds `ListSampler` for sampling from `PosteriorList`.
- Adds `ForkedRNGSampler` and `StochasticSampler` for sampling from posteriors without base samples.
- Adds `rsample_from_base_samples` for sampling with `base_samples` / with a `sampler`.
- Absorbs `FullyBayesianPosteriorList` into `PosteriorList`.
- For MC acqfs, introduces a `get_posterior_samples` for sampling from the posterior with base samples / a sampler. If a sampler was not specified, this constructs the appropriate sampler for the posterior using `get_sampler`, eliminating the need to construct a sampler in `__init__`, which we used to do under the assumption of Gaussian posteriors.
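
The dispatcher-based `get_sampler` lookup described above can be illustrated with a small, self-contained sketch. This is not BoTorch's implementation: every class below is a hypothetical stand-in, and the standard library's `functools.singledispatch` plays the role of the dispatcher. It only shows how a list of mixed posteriors can resolve to one sampler per component.

```python
from dataclasses import dataclass
from functools import singledispatch
from typing import Tuple


# Hypothetical stand-ins for posterior types (not BoTorch classes).
@dataclass
class GaussianPosterior:
    mean: float
    stddev: float


@dataclass
class ExponentialPosterior:
    rate: float


@dataclass
class PosteriorList:
    posteriors: Tuple[object, ...]


# Hypothetical stand-ins for samplers, one per posterior type.
@dataclass
class NormalSampler:
    sample_shape: Tuple[int, ...]


@dataclass
class ExponentialSampler:
    sample_shape: Tuple[int, ...]


@dataclass
class ListSampler:
    samplers: Tuple[object, ...]


@singledispatch
def get_sampler(posterior, sample_shape):
    # Fallback: no sampler has been registered for this posterior type.
    raise NotImplementedError(
        f"No sampler registered for {type(posterior).__name__}"
    )


@get_sampler.register(GaussianPosterior)
def _get_normal_sampler(posterior, sample_shape):
    return NormalSampler(sample_shape)


@get_sampler.register(ExponentialPosterior)
def _get_exponential_sampler(posterior, sample_shape):
    return ExponentialSampler(sample_shape)


@get_sampler.register(PosteriorList)
def _get_list_sampler(posterior, sample_shape):
    # Dispatch again on each component, so mixed lists work as long as
    # every component type has a registered sampler.
    return ListSampler(
        tuple(get_sampler(p, sample_shape) for p in posterior.posteriors)
    )


mixed = PosteriorList((GaussianPosterior(0.0, 1.0), ExponentialPosterior(2.0)))
sampler = get_sampler(mixed, (4,))
```

In this sketch an acquisition function that received no sampler could call `get_sampler(posterior, sample_shape)` lazily at evaluation time, which mirrors the reason given above for no longer constructing a sampler in `__init__`.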

TODOs:
- Relax the Gaussian assumption in acquisition functions & utilities. Some of this might be addressed in a follow-up diff.
- Update the website, docs, and tutorials to remove the remaining Gaussian assumptions and introduce the new relaxed API. Likely a follow-up diff.

Other notables:
- See D39760855 for usage of TorchDistribution in SkewGP.
- TransformedPosterior could serve as the fallback option for derived posteriors.
- MC samplers no longer support `resample` or `collapse_batch_dims=False`. These use cases can be handled by (i) not using base samples, or (ii) calling `torch.fork_rng` and sampling without base samples inside it. Samplers are only meant to support SAA. Introduces `ForkedRNGSampler` and `StochasticSampler` as convenience samplers for these use cases.
- Introduces `batch_range_override` for the sampler to support edge cases where we may want to override `posterior.batch_range` (needed in `qMultiStepLookahead`).
- Removes the unused sampling utilities `construct_base_samples(_from_posterior)`, which assume a Gaussian posterior.
- Moves the main logic of the `_set_sampler` method of CachedCholesky subclasses to an `_update_base_samples` method on samplers, and simplifies these classes a bit more.
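
The two sampling modes without base samples mentioned above can be sketched with the standard library alone. This is an analogue under stated assumptions, not the BoTorch samplers: Python's `random` module stands in for the torch RNG, and `fork_rng`, `forked_rng_sample`, and `stochastic_sample` are hypothetical names. The forked variant seeds the RNG inside a scope and restores the outer state afterwards, so draws are reproducible without disturbing the caller; the stochastic variant simply draws fresh samples each time.

```python
import random
from contextlib import contextmanager


@contextmanager
def fork_rng(seed):
    # Save the global RNG state, seed for the duration of the block,
    # then restore the saved state so the caller's stream is unaffected.
    state = random.getstate()
    random.seed(seed)
    try:
        yield
    finally:
        random.setstate(state)


def forked_rng_sample(n, seed):
    # Reproducible draws: the same seed always yields the same samples.
    with fork_rng(seed):
        return [random.gauss(0.0, 1.0) for _ in range(n)]


def stochastic_sample(n):
    # Fresh draws on every call; nothing is fixed or cached.
    return [random.gauss(0.0, 1.0) for _ in range(n)]
```

Calling `forked_rng_sample(3, seed=0)` twice returns identical lists, while repeated calls to `stochastic_sample(3)` advance the global RNG and generally differ.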

Reviewed By: Balandat

Differential Revision: D39759489

fbshipit-source-id: d6f2d48c1019370f9727acb3fc2652f048e302a0
saitcakmak authored and facebook-github-bot committed Nov 17, 2022
1 parent c9a29a1 commit 5e5a0a7
Showing 7 changed files with 11 additions and 16 deletions.
2 changes: 1 addition & 1 deletion aepsych/acquisition/lookahead_utils.py
@@ -45,7 +45,7 @@ def posterior_at_xstar_xq(
  mu = posterior.mean[..., :, 0]
  Mu_s = mu[..., 0].unsqueeze(-1)
  Mu_q = mu[..., 1:]
- Cov = posterior.mvn.covariance_matrix
+ Cov = posterior.distribution.covariance_matrix
  Sigma2_s = Cov[..., 0, 0].unsqueeze(-1)
  Sigma2_q = torch.diagonal(Cov[..., 1:, 1:], dim1=-1, dim2=-2)
  Sigma_sq = Cov[..., 0, 1:]
2 changes: 1 addition & 1 deletion aepsych/acquisition/lse.py
@@ -41,7 +41,7 @@ def __init__(
  sampler: The sampler used for drawing MC samples.
  """
  if sampler is None:
- sampler = SobolQMCNormalSampler(num_samples=512, collapse_batch_dims=True)
+ sampler = SobolQMCNormalSampler(sample_shape=torch.Size([512]))
  if objective is None:
  objective = ProbitObjective()
  super().__init__(model=model, sampler=sampler, objective=None, X_pending=None)
5 changes: 3 additions & 2 deletions aepsych/acquisition/mc_posterior_variance.py
@@ -13,7 +13,8 @@
  from botorch.acquisition.monte_carlo import MCAcquisitionFunction
  from botorch.acquisition.objective import MCAcquisitionObjective
  from botorch.models.model import Model
- from botorch.sampling.samplers import MCSampler, SobolQMCNormalSampler
+ from botorch.sampling.base import MCSampler
+ from botorch.sampling.normal import SobolQMCNormalSampler
  from botorch.utils.transforms import t_batch_mode_transform
  from torch import Tensor

@@ -54,7 +55,7 @@ def __init__(
  sampler: The sampler used for drawing MC samples.
  """
  if sampler is None:
- sampler = SobolQMCNormalSampler(num_samples=512, collapse_batch_dims=True)
+ sampler = SobolQMCNormalSampler(sample_shape=torch.Size([512]))
  if objective is None:
  objective = ProbitObjective()
  super().__init__(model=model, sampler=sampler, objective=None, X_pending=None)
3 changes: 2 additions & 1 deletion aepsych/acquisition/mutual_information.py
@@ -17,7 +17,8 @@
  from botorch.acquisition.monte_carlo import MCAcquisitionFunction
  from botorch.acquisition.objective import MCAcquisitionObjective
  from botorch.models.model import Model
- from botorch.sampling.samplers import MCSampler, SobolQMCNormalSampler
+ from botorch.sampling.base import MCSampler
+ from botorch.sampling.normal import SobolQMCNormalSampler
  from botorch.utils.transforms import t_batch_mode_transform
  from torch import Tensor
  from torch.distributions.bernoulli import Bernoulli
11 changes: 2 additions & 9 deletions aepsych/acquisition/rejection_sampler.py
@@ -9,7 +9,7 @@

  import torch
  from botorch.posteriors import Posterior
- from botorch.sampling.samplers import MCSampler
+ from botorch.sampling.base import MCSampler
  from torch import Tensor


@@ -38,14 +38,7 @@ def __init__(
  self.num_samples = num_samples
  self.num_rejection_samples = num_rejection_samples
  self.constrained_idx = constrained_idx
- self._sample_shape = torch.Size([num_samples])
- super().__init__()
-
- def _get_base_sample_shape(self, posterior: Posterior) -> torch.Size:
- return torch.Size([])
-
- def _construct_base_samples(self, posterior: Posterior, shape: torch.Size) -> None:
- self.base_samples = None
+ super().__init__(sample_shape=torch.Size([num_samples]))

  def forward(self, posterior: Posterior) -> Tensor:
  """Run the rejection sampler.
2 changes: 1 addition & 1 deletion aepsych/models/monotonic_projection_gp.py
@@ -158,7 +158,7 @@ def posterior(
  # Adjust the whole covariance matrix to accomadate the projected marginals
  with torch.no_grad():
  post = super().posterior(X=X)
- R = cov2corr(post.mvn.covariance_matrix.squeeze().numpy())
+ R = cov2corr(post.distribution.covariance_matrix.squeeze().numpy())
  S_proj = torch.tensor(corr2cov(R, sigma_proj.numpy()), dtype=X.dtype)
  mvn_proj = gpytorch.distributions.MultivariateNormal(
  mu_proj.unsqueeze(0),
2 changes: 1 addition & 1 deletion tests/test_lookahead.py
@@ -64,7 +64,7 @@ def setUp(self):
  model = MockModel(
  MockPosterior(mean=f[:, None], variance=torch.diag(covar)[None, :, None])
  )
- model._posterior.mvn = mvn
+ model._posterior.distribution = mvn
  self.model, self.f, self.covar = model, f, covar

  def test_posterior_extraction(self):
