[Bug] SingleTaskVariationalGP raises a warning when using input_transform #1824
Thanks for reporting this -- that is a concerning warning! I believe the warning is erroneous and the input transforms are being applied appropriately.

```python
import torch
from botorch.models import SingleTaskGP, SingleTaskVariationalGP
from botorch.models.transforms import Normalize
from matplotlib import pyplot as plt
from gpytorch.mlls import VariationalELBO, ExactMarginalLogLikelihood
from botorch.fit import fit_gpytorch_mll

train_X = torch.linspace(1, 3, 10, dtype=torch.double)[:, None]
y = -3 * train_X + 5
test_X = torch.linspace(1, 5, 10, dtype=torch.double)[:, None]

# Exact GP baseline with the same input transform
model = SingleTaskGP(train_X=train_X, train_Y=y, input_transform=Normalize(1))
mll = ExactMarginalLogLikelihood(model.likelihood, model)
fit_gpytorch_mll(mll)
post = model.posterior(test_X)

# Variational GP: the ELBO is built on the wrapped approximate GP (model.model)
model = SingleTaskVariationalGP(train_X=train_X, train_Y=y, input_transform=Normalize(1))
mll = VariationalELBO(
    model.likelihood, model.model, num_data=train_X.shape[-2]
)
fit_gpytorch_mll(mll)
post_var = model.posterior(test_X)  # Warning

# Compare the two posterior means side by side
fig, axes = plt.subplots(1, 2, sharex=True, sharey=True)
axes[0].scatter(train_X, y, label="train data")
axes[0].plot(test_X, post.mean.detach().numpy(), label="posterior mean")
axes[0].legend()
axes[0].set_title("SingleTaskGP")
axes[1].scatter(train_X, y, label="train data")
axes[1].plot(test_X, post_var.mean.detach().numpy(), label="posterior mean")
axes[1].legend()
axes[1].set_title("VariationalGP")
for ax in axes:
    ax.set_xlabel("X")
axes[0].set_ylabel("y")
plt.show()
```

Variational GPs deal with input transforms differently than most models. When …
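To make the scaling in the reproduction concrete: assuming `Normalize(1)` learns its bounds from the training inputs (its behavior when no `bounds` argument is passed), `train_X` running from 1 to 3 is mapped onto [0, 1], and the same learned bounds map `test_X` running from 1 to 5 onto [0, 2]. The pure-Python sketch below only reproduces that min-max arithmetic; the helper names are made up for illustration and this is not BoTorch's implementation.

```python
def linspace(start, stop, num):
    """Evenly spaced points from start to stop inclusive (like torch.linspace)."""
    return [start + (stop - start) * i / (num - 1) for i in range(num)]


def learn_bounds(xs):
    """Hypothetical helper: min and max of the training inputs."""
    return min(xs), max(xs)


def normalize(xs, lo, hi):
    """Min-max scale xs using bounds learned from the training data."""
    return [(x - lo) / (hi - lo) for x in xs]


train_X = linspace(1.0, 3.0, 10)
test_X = linspace(1.0, 5.0, 10)
lo, hi = learn_bounds(train_X)
train_scaled = normalize(train_X, lo, hi)
test_scaled = normalize(test_X, lo, hi)
print(train_scaled[0], train_scaled[-1])  # 0.0 1.0
print(test_scaled[0], test_scaled[-1])    # 0.0 2.0
```

Test points above 3 land outside the unit interval, so both posteriors in the plot are extrapolating over the right half of `test_X`.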
cc @saitcakmak this is relevant for the proposed transforms refactor in cornellius-gp/gpytorch#2114
Great, glad it's a false alarm. Thanks for the quick reply.
…GPs (#1826)

Summary:

## Motivation

The BoTorch base `Model` class warns if an input transform has been provided, the `eval` method is called, and the object has no `train_inputs` attribute. This is not appropriate for `ApproximateGPyTorchModel`s; see #1824.

This PR gives `ApproximateGPyTorchModel` the `train` and `eval` methods from `torch.nn.Module`. These are the same as the methods it had been inheriting from `Model`, minus the irrelevant input transform logic.

A nicer fix would be to remove the input transform logic from `Model` and keep it only in the subclasses it applies to, so that subclasses like `ApproximateGPyTorchModel` would not need to do anything special to avoid inheriting it. I think this all applies to `EnsembleModel`s as well as `ApproximateGPyTorchModel`s; looking into this now.

### Have you read the [Contributing Guidelines on pull requests](https://github.com/pytorch/botorch/blob/main/CONTRIBUTING.md#pull-requests)?

Pull Request resolved: #1826

Test Plan: Existing unit tests for `ApproximateGPyTorchModel` look good.

Reviewed By: Balandat

Differential Revision: D45782048

Pulled By: esantorella

fbshipit-source-id: 2091956a5a0cb6680f4c7292c0951f9079975ffb
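The override pattern described in this PR can be sketched without BoTorch: a base class whose `eval` warns when an input transform is set but there is no `train_inputs` attribute, and a subclass that routes `train`/`eval` straight to the plain-module implementation. All classes below are hypothetical stand-ins (`ModuleStub` for `torch.nn.Module`, `BaseModelStub` for `Model`, `ApproximateModelStub` for `ApproximateGPyTorchModel`), and the warning text is invented; only the method-resolution idea mirrors the PR description.

```python
import warnings


class ModuleStub:
    """Hypothetical stand-in for torch.nn.Module: only tracks the training flag."""

    def __init__(self):
        self.training = True

    def train(self, mode=True):
        self.training = mode
        return self

    def eval(self):
        return self.train(False)


class BaseModelStub(ModuleStub):
    """Hypothetical stand-in for a base model that checks input transforms on eval."""

    def __init__(self, input_transform=None):
        super().__init__()
        self.input_transform = input_transform

    def eval(self):
        # Warn when a transform is set but there are no training inputs to
        # apply it to (illustrative message, not BoTorch's actual text).
        if self.input_transform is not None and not hasattr(self, "train_inputs"):
            warnings.warn("input_transform set but model has no train_inputs")
        return super().eval()


class ApproximateModelStub(BaseModelStub):
    """Hypothetical approximate GP that bypasses BaseModelStub's transform check."""

    def train(self, mode=True):
        # Call the plain-module implementation directly, skipping the base model.
        return ModuleStub.train(self, mode)

    def eval(self):
        return ModuleStub.eval(self)


def count_eval_warnings(model):
    """Return how many warnings model.eval() emits."""
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        model.eval()
    return len(caught)


print(count_eval_warnings(BaseModelStub(input_transform="normalize")))         # 1
print(count_eval_warnings(ApproximateModelStub(input_transform="normalize")))  # 0
```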
I vaguely recall some issue with inducing points sometimes getting transformed and sometimes not with ApproximateGP. Looking at my old notes, I found this:
So, there might be some truth to the warning here. The proper solution would be to push through #1372 and fix this for good. I'll leave this open just in case.
🐛 Bug
Calling `SingleTaskVariationalGP.posterior` with an `input_transform` raises a warning, whereas the equivalent call with `SingleTaskGP` does not. I'm not sure if `input_transform` works correctly with `SingleTaskVariationalGP` or if I can safely interpret the resulting posterior. It also seems a little odd to me that this would be a warning and not an exception.

To reproduce
**Code snippet to reproduce**

**Stack trace/error message**
System information

- botorch version = 0.8.5
- gpytorch version = 1.10
- torch version = 1.13.1