Skip to content

Commit

Permalink
Fix TransformedPosterior missing batch shape error in _update_base_sa…
Browse files Browse the repository at this point in the history
…mples (pytorch#1625)

Summary:
Pull Request resolved: pytorch#1625

See pytorch#1623. This is the only use case for `posterior.batch_shape`, so fixing it locally makes sense to me.

Reviewed By: Balandat

Differential Revision: D42421494

fbshipit-source-id: def6180343b363ddce01aede2bbc5892fb2cc99a
  • Loading branch information
saitcakmak authored and facebook-github-bot committed Jan 11, 2023
1 parent a79071d commit 3596b12
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 1 deletion.
8 changes: 7 additions & 1 deletion botorch/sampling/normal.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from botorch.posteriors import Posterior
from botorch.posteriors.higher_order import HigherOrderGPPosterior
from botorch.posteriors.multitask import MultitaskGPPosterior
from botorch.posteriors.transformed import TransformedPosterior
from botorch.sampling.base import MCSampler
from botorch.utils.sampling import draw_sobol_normal_samples, manual_seed
from torch import Tensor
Expand Down Expand Up @@ -112,8 +113,13 @@ def _update_base_samples(
..., -n_train_samples:
]
else:
batch_shape = (
posterior._posterior.batch_shape
if isinstance(posterior, TransformedPosterior)
else posterior.batch_shape
)
single_output = (
len(posterior.base_sample_shape) - len(posterior.batch_shape)
len(posterior.base_sample_shape) - len(batch_shape)
) == 1
if single_output:
self.base_samples[
Expand Down
30 changes: 30 additions & 0 deletions test/acquisition/multi_objective/test_monte_carlo.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@
from botorch.models.gp_regression import SingleTaskGP
from botorch.models.transforms.input import InputPerturbation
from botorch.models.transforms.outcome import Standardize
from botorch.posteriors.posterior_list import PosteriorList
from botorch.posteriors.transformed import TransformedPosterior
from botorch.sampling.list_sampler import ListSampler
from botorch.sampling.normal import IIDNormalSampler, SobolQMCNormalSampler
from botorch.utils.low_rank import sample_cached_cholesky
from botorch.utils.multi_objective.box_decompositions.dominated import (
Expand Down Expand Up @@ -1624,3 +1627,30 @@ def get_acqf(model, matheron):
rtol=1e-4,
)
)

def test_with_transformed(self):
# Verify that _set_sampler works with transformed posteriors.
mm = MockModel(
posterior=PosteriorList(
TransformedPosterior(
MockPosterior(samples=torch.rand(2, 3, 1)), lambda X: X
),
TransformedPosterior(
MockPosterior(samples=torch.rand(2, 3, 1)), lambda X: X
),
)
)
sampler = ListSampler(
IIDNormalSampler(sample_shape=torch.Size([2])),
IIDNormalSampler(sample_shape=torch.Size([2])),
)
# This calls _set_sampler which used to error out in
# NormalMCSampler._update_base_samples with TransformedPosterior
# due to the missing batch_shape (fixed in #1625).
qNoisyExpectedHypervolumeImprovement(
model=mm,
ref_point=torch.tensor([0.0, 0.0]),
X_baseline=torch.rand(3, 2),
sampler=sampler,
cache_root=False,
)

0 comments on commit 3596b12

Please sign in to comment.