Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RFC] Move input transforms to GPyTorch #1372

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

saitcakmak
Copy link
Contributor

@saitcakmak saitcakmak commented Aug 31, 2022

Summary:
This diff presents a minimal implementation of input transforms in GPyTorch. See cornellius-gp/gpytorch#2114 for GPyTorch side of these changes.

What this does:

  • Moves the transform_inputs from BoTorch Model to GPyTorch GP class, with some modifications to explicitly identify whether given inputs are train or test inputs.
  • Modifies the InputTransform.forward call to use is_training_input argument instead of self.training check to apply the transforms that have transform_on_train=True.
  • Removes preprocess_transform method since this is no-longer needed.
  • For ExactGP models, it transforms both train and test inputs in __call__. For train_inputs it always uses is_training_input=True. For generic inputs, it uses is_training_input=self.training which signals that these are training inputs when the model is in train mode, and that these are test inputs when the model is in eval mode.
  • For ApproximateGP models, it applies the transform to inputs in __call__ using is_training_input=self.training. This again signifies whether the given inputs are train or test inputs based on the mode of the model. Note that this NEVER transforms inducing_points, thus fixes the previous bug with inducing_points getting transformed in train but not getting transformed in eval. It is expected that the user will define inducing points in the appropriate space (mostly the normalized space / unit cube).
  • For BoTorch SingleTaskVariationalGP, it moves the input_transform attribute down to _SingleTaskVariationalGP, which is the actual ApproximateGP instance. This makes the transform accessible from GPyTorch.

What this doesn't do:

  • It doesn't do anything about DeterministicModels. Those will still need to deal with their own transforms, which is not implemented here. If we make Model inherit from GP, we can keep the existing setup with very minimal changes.
  • It does not clean up the call sites for self.transform_inputs. This is just made into a no-op and the clean-up is left for later.
  • It does not upstream the abstract InputTransform classes to GPyTorch. That'll be done if we decide to go forward with this design.
  • It does not touch PairwiseGP. PairwiseGP has some non-standard use of input transforms, so it needs an audit to make sure things still work fine.
  • I didn't look into ApproximateGP.fantasize. This may need some changes similar to ExactGP.get_fantasy_model.
  • It does not support PyroGP and DeepGP.

Differential Revision: D39147547

Summary:
This diff presents a minimal implementation of input transforms in GPyTorch.

What this does:
* Moves the `transform_inputs` from BoTorch `Model` to GPyTorch `GP` class, with some modifications to explicitly identify whether given inputs are train or test inputs.
* Modifies the `InputTransform.forward` call to use `is_training_input` argument instead of `self.training` check to apply the transforms that have `transform_on_train=True`.
* Removes `preprocess_transform` method since this is no-longer needed.
* For `ExactGP` models, it transforms both train and test inputs in `__call__`. For `train_inputs` it always uses `is_training_input=True`. For generic `inputs`, it uses `is_training_input=self.training` which signals that these are training inputs when the model is in `train` mode, and that these are test inputs when the model is in `eval` mode.
* For `ApproximateGP` models, it applies the transform to `inputs` in `__call__` using `is_training_input=self.training`. This again signifies whether the given inputs are train or test inputs based on the mode of the model. Note that this NEVER transforms `inducing_points`, thus fixes the previous bug with `inducing_points` getting transformed in `train` but not getting transformed in `eval`. It is expected that the user will define inducing points in the appropriate space (mostly the normalized space / unit cube).
* For BoTorch `SingleTaskVariationalGP`, it moves the `input_transform` attribute down to `_SingleTaskVariationalGP`, which is the actual `ApproximateGP` instance. This makes the transform accessible from GPyTorch.

What this doesn't do:
* It doesn't do anything about `DeterministicModel`s. Those will still need to deal with their own transforms, which is not implemented here. If we make `Model` inherit from `GP`, we can keep the existing setup with very minimal changes.
* It does not clean up the call sites for `self.transform_inputs`. This is just made into a no-op and the clean-up is left for later.
* It does not upstream the abstract `InputTransform` classes to GPyTorch. That'll be done if we decide to go forward with this design.
* It does not touch `PairwiseGP`. `PairwiseGP` has some non-standard use of input transforms, so it needs an audit to make sure things still work fine.
* I didn't look into `ApproximateGP.fantasize`. This may need some changes similar to `ExactGP.get_fantasy_model`.
* It does not support `PyroGP` and `DeepGP`.

Differential Revision: D39147547

fbshipit-source-id: ed2745b0ff666a13764759e1511a139c228d1d39
@facebook-github-bot facebook-github-bot added CLA Signed Do not delete this pull request or issue due to inactivity. fb-exported labels Aug 31, 2022
@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D39147547

@saitcakmak saitcakmak changed the title Move input transforms to GPyTorch [RFC] Move input transforms to GPyTorch Aug 31, 2022
@Balandat
Copy link
Contributor

Balandat commented Sep 6, 2022

I like this design!

It doesn't do anything about DeterministicModels. Those will still need to deal with their own transforms, which is not implemented here. If we make Model inherit from GP, we can keep the existing setup with very minimal changes.

I think we probably want to have sth like a gpytorch BaseModel that implements the transform handling in a general sense and then have GP inherit from that (so we don't have deterministic models suddenly be GPs...).

It does not upstream the abstract InputTransform classes to GPyTorch. That'll be done if we decide to go forward with this design.

@gpleiss do you have any high-level feedback on the transform setup (https://github.com/pytorch/botorch/tree/main/botorch/models/transforms) that we'd want to incorporate when upstreaming those?

One point that @j-wilson had brought up is that if the transforms are expensive and not learnable (e.g. a pre-fit NN feature extractor) then repeatedly applying it to the same inputs during training (for the full batch case of exact GPs anyway) could be quite wasteful. Is there an elegant solution to this by means of caching the transformed values of the training data and evicting that cache when they are reset?

@gpleiss
Copy link
Contributor

gpleiss commented Sep 9, 2022

@Balandat I really like the botorch API, and this would be super useful to have upstream in GPyTorch!

One point that @j-wilson had brought up is that if the transforms are expensive and not learnable (e.g. a pre-fit NN feature extractor) then repeatedly applying it to the same inputs during training (for the full batch case of exact GPs anyway) could be quite wasteful. Is there an elegant solution to this by means of caching the transformed values of the training data and evicting that cache when they are reset?

There probably is an elegant way to do this, but nothing really comes to mind. We should circle back to this at some point, but at the very least a power user could (e.g.) apply a pre-trained NN to the inputs without using the transforms API.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
CLA Signed Do not delete this pull request or issue due to inactivity. fb-exported
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

4 participants