Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove FixedNoiseMultiTaskGP #2323

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 1 addition & 6 deletions botorch/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,19 +27,14 @@
from botorch.models.higher_order_gp import HigherOrderGP
from botorch.models.model import ModelList
from botorch.models.model_list_gp_regression import ModelListGP
from botorch.models.multitask import (
FixedNoiseMultiTaskGP,
KroneckerMultiTaskGP,
MultiTaskGP,
)
from botorch.models.multitask import KroneckerMultiTaskGP, MultiTaskGP
from botorch.models.pairwise_gp import PairwiseGP, PairwiseLaplaceMarginalLogLikelihood

__all__ = [
"AffineDeterministicModel",
"AffineFidelityCostModel",
"ApproximateGPyTorchModel",
"FixedNoiseGP",
"FixedNoiseMultiTaskGP",
"SaasFullyBayesianSingleTaskGP",
"SaasFullyBayesianMultiTaskGP",
"GenericDeterministicModel",
Expand Down
83 changes: 6 additions & 77 deletions botorch/models/multitask.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
from __future__ import annotations

import math
import warnings
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
Expand Down Expand Up @@ -375,7 +374,12 @@ def construct_inputs(

# Call Model.construct_inputs to parse training data
base_inputs = super().construct_inputs(training_data=training_data)
if isinstance(training_data, MultiTaskDataset):
if (
isinstance(training_data, MultiTaskDataset)
# If task features are included in the data, all tasks will have
# some observations and they may have different task features.
and training_data.task_feature_index is None
):
all_tasks = list(range(len(training_data.datasets)))
base_inputs["all_tasks"] = all_tasks
if task_covar_prior is not None:
Expand All @@ -387,81 +391,6 @@ def construct_inputs(
return base_inputs


class FixedNoiseMultiTaskGP(MultiTaskGP):
    r"""Multi-Task GP model using an ICM kernel, with known observation noise.

    DEPRECATED: Please use `MultiTaskGP` with `train_Yvar` instead.
    Will be removed in a future release (~v0.10).
    """

    def __init__(
        self,
        train_X: Tensor,
        train_Y: Tensor,
        train_Yvar: Tensor,
        task_feature: int,
        covar_module: Optional[Module] = None,
        task_covar_prior: Optional[Prior] = None,
        output_tasks: Optional[List[int]] = None,
        rank: Optional[int] = None,
        input_transform: Optional[InputTransform] = None,
        outcome_transform: Optional[OutcomeTransform] = None,
    ) -> None:
        r"""Deprecated constructor; delegates to `MultiTaskGP` with `train_Yvar`.

        Args:
            train_X: A `n x (d + 1)` or `b x n x (d + 1)` (batch mode) tensor
                of training data. One of the columns should contain the task
                features (see `task_feature` argument).
            train_Y: A `n x 1` or `b x n x 1` (batch mode) tensor of training
                observations.
            train_Yvar: A `n` or `b x n` (batch mode) tensor of observed measurement
                noise.
            task_feature: The index of the task feature (`-d <= task_feature <= d`).
            covar_module: The module computing the data covariance matrix.
            task_covar_prior: A Prior on the task covariance matrix. Must operate
                on p.s.d. matrices. A common prior for this is the `LKJ` prior.
            output_tasks: A list of task indices for which to compute model
                outputs. If omitted, return outputs for all task indices.
            rank: The rank to be used for the index kernel. If omitted, use a
                full rank (i.e. number of tasks) kernel.
            input_transform: An input transform that is applied in the model's
                forward pass.
            outcome_transform: An outcome transform that is applied to the
                training data during instantiation and to the posterior during
                inference (that is, the `Posterior` obtained by calling
                `.posterior` on the model will be on the original scale).

        Example:
            >>> X1, X2 = torch.rand(10, 2), torch.rand(20, 2)
            >>> i1, i2 = torch.zeros(10, 1), torch.ones(20, 1)
            >>> train_X = torch.cat([
            >>>     torch.cat([X1, i1], -1), torch.cat([X2, i2], -1),
            >>> ], dim=0)
            >>> train_Y = torch.cat([f1(X1), f2(X2)])
            >>> train_Yvar = 0.1 + 0.1 * torch.rand_like(train_Y)
            >>> model = FixedNoiseMultiTaskGP(train_X, train_Y, train_Yvar, -1)
        """
        # Runtime deprecation notice; the message text is part of the public
        # behavior and must stay stable until the class is removed.
        warnings.warn(
            "`FixedNoiseMultiTaskGP` has been deprecated and will be removed in a "
            "future release. Please use the `MultiTaskGP` model instead. "
            "When `train_Yvar` is specified, `MultiTaskGP` behaves the same "
            "as the `FixedNoiseMultiTaskGP`.",
            DeprecationWarning,
            stacklevel=2,
        )
        # Pure delegation: `MultiTaskGP` with `train_Yvar` supplied implements
        # the fixed-noise behavior this class used to provide.
        super().__init__(
            train_X=train_X,
            train_Y=train_Y,
            train_Yvar=train_Yvar,
            covar_module=covar_module,
            task_feature=task_feature,
            output_tasks=output_tasks,
            rank=rank,
            task_covar_prior=task_covar_prior,
            input_transform=input_transform,
            outcome_transform=outcome_transform,
        )


class KroneckerMultiTaskGP(ExactGP, GPyTorchModel, FantasizeMixin):
"""Multi-task GP with Kronecker structure, using an ICM kernel.

Expand Down
8 changes: 0 additions & 8 deletions botorch/models/transforms/input.py
Original file line number Diff line number Diff line change
Expand Up @@ -781,7 +781,6 @@ def __init__(
transform_on_fantasize: bool = True,
approximate: bool = False,
tau: float = 1e-3,
**kwargs,
) -> None:
r"""Initialize transform.

Expand All @@ -800,13 +799,6 @@ def __init__(
rounding should be used. Default: False.
tau: The temperature parameter for approximate rounding.
"""
indices = kwargs.get("indices")
if indices is not None:
warn(
"`indices` is marked for deprecation in favor of `integer_indices`.",
DeprecationWarning,
)
integer_indices = indices
if approximate and categorical_features is not None:
raise NotImplementedError
super().__init__()
Expand Down
17 changes: 1 addition & 16 deletions botorch/models/transforms/outcome.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,9 @@

from __future__ import annotations

import warnings

from abc import ABC, abstractmethod
from collections import OrderedDict
from typing import Any, List, Mapping, Optional, Tuple, Union
from typing import List, Optional, Tuple, Union

import torch
from botorch.models.transforms.utils import (
Expand Down Expand Up @@ -256,19 +254,6 @@ def __init__(
self._batch_shape = batch_shape
self._min_stdv = min_stdv

def load_state_dict(
    self, state_dict: Mapping[str, Any], strict: bool = True
) -> None:
    r"""Load a state dict, backfilling the legacy `_is_trained` key.

    Older checkpoints were saved before `_is_trained` existed; for those,
    emit a deprecation warning and assume the transform was trained.
    """
    if "_is_trained" not in state_dict:
        warnings.warn(
            "Key '_is_trained' not found in state_dict. Setting to True. "
            "In a future release, this will result in an error.",
            DeprecationWarning,
        )
        # Work on a copy so the caller's mapping is left untouched.
        patched = dict(state_dict)
        patched["_is_trained"] = torch.tensor(True)
        state_dict = patched
    super().load_state_dict(state_dict, strict=strict)

def forward(
self, Y: Tensor, Yvar: Optional[Tensor] = None
) -> Tuple[Tensor, Optional[Tensor]]:
Expand Down
1 change: 0 additions & 1 deletion test/models/test_contextual_multioutput.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,6 @@ def test_construct_inputs(self) -> None:
model_inputs.pop("context_cat_feature"),
torch.tensor([[0.4], [0.5]]),
)
self.assertEqual(model_inputs.pop("all_tasks"), [0, 1])
self.assertEqual(model_inputs.pop("task_feature"), 0)
self.assertIsNone(model_inputs.pop("output_tasks"))
# Check that there are no unexpected inputs.
Expand Down
Loading
Loading