Botorch closures #1439

Closed
wants to merge 1 commit into from

j-wilson
Contributor

@j-wilson j-wilson commented Oct 3, 2022

Summary:
Changelog:

  • Enable user-defined loss closures.
  • fit_gpytorch_torch rewrite
  • Add fit_gpytorch_mll dispatch for ApproximateGPs
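A loss closure here is a zero-argument callable that captures whatever it needs (parameters, data) and recomputes the loss each time it is called; the diff types it as `TLossClosure = Callable[[], Tensor]`. The plain-Python sketch below only illustrates that contract. `make_quadratic_closure` and the quadratic loss are hypothetical, and floats stand in for tensors so it runs without torch.

```python
from typing import Callable, List, Tuple

# Hypothetical analogue of `TLossClosure = Callable[[], Tensor]`:
# zero arguments in, (loss, gradients) out, with floats instead of tensors.
LossClosure = Callable[[], Tuple[float, List[float]]]


def make_quadratic_closure(params: List[float], targets: List[float]) -> LossClosure:
    """Return a closure computing sum_i (params[i] - targets[i])**2 and its gradient.

    The closure reads `params` at call time, so in-place updates made by an
    optimizer between calls are picked up automatically.
    """

    def closure() -> Tuple[float, List[float]]:
        residuals = [p - t for p, t in zip(params, targets)]
        loss = sum(r * r for r in residuals)
        grads = [2.0 * r for r in residuals]
        return loss, grads

    return closure


params = [0.0, 0.0]
closure = make_quadratic_closure(params, targets=[1.0, -2.0])
loss_before, _ = closure()  # residuals (-1, 2), so loss 5.0
params[0] = 1.0             # e.g. an optimizer step, applied in place
loss_after, _ = closure()   # residuals (0, 2), so loss 4.0
```

The fitting routine never needs to know how the loss is built; it only calls `closure()`.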

Differential Revision: D39101211

@facebook-github-bot facebook-github-bot added CLA Signed fb-exported labels Oct 3, 2022
@facebook-github-bot
Contributor

This pull request was exported from Phabricator. Differential Revision: D39101211

botorch/fit.py Outdated
data_loader: Convenience keyword for passing in a DataLoader instance or dict of
keyword arguments passed to `get_data_loader` to obtain one. May only be
Contributor

It feels a bit weird that `data_loader` can be either an instance or a set of kwargs passed to a factory function.
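For concreteness, here is roughly what accepting "an instance or a dict of kwargs" implies inside the function: one `isinstance` branch that either passes the object through or forwards the dict to a factory. `Loader`, `get_loader`, and `resolve_loader` are hypothetical stand-ins (for `DataLoader`, `get_data_loader`, and the fitting routine's argument handling), not BoTorch code.

```python
from dataclasses import dataclass
from typing import Any, Dict, Optional, Sequence, Union


@dataclass
class Loader:
    """Hypothetical stand-in for torch.utils.data.DataLoader."""

    dataset: Sequence[Any]
    batch_size: int = 1
    shuffle: bool = False


def get_loader(dataset: Sequence[Any], **kwargs: Any) -> Loader:
    """Hypothetical factory playing the role of `get_data_loader`."""
    return Loader(dataset, **kwargs)


def resolve_loader(
    dataset: Sequence[Any],
    data_loader: Optional[Union[Loader, Dict[str, Any]]] = None,
) -> Loader:
    """Accept either a ready-made loader or a dict of factory kwargs."""
    if isinstance(data_loader, Loader):
        return data_loader  # caller built it; use it as-is
    # Otherwise treat the argument as factory kwargs (None means defaults).
    return get_loader(dataset, **(data_loader or {}))


data = [1, 2, 3]
prebuilt = Loader(data, batch_size=3)
from_instance = resolve_loader(data, prebuilt)
from_kwargs = resolve_loader(data, {"batch_size": 2, "shuffle": True})
```

The branch is small, but the parameter's type becomes a union whose two halves behave differently, which is the ambiguity the comment points at.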

botorch/fit.py Outdated

dispatcher = Dispatcher("get_loss_closure")
NoneType = type(None)
TLossClosure = Callable[[], Tensor]
Contributor

Is this the right type if you're forwarding kwargs to the MLL?

Contributor Author

Not really, no. If we end up allowing kwargs to be passed to loss closures (debatable since torch.jit.script does not support this), we'll make TLossClosure a typing.Protocol.
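If kwargs forwarding were allowed, a `typing.Protocol` with a `__call__` method could describe the closure type, since `Callable[[...], R]` has no way to express keyword parameters. A minimal sketch; `KwargLossClosure` and `ScaledLoss` are hypothetical names, and floats stand in for tensors.

```python
from typing import Any, Protocol, runtime_checkable


@runtime_checkable
class KwargLossClosure(Protocol):
    """Hypothetical callback protocol: a loss closure that accepts keyword arguments."""

    def __call__(self, **kwargs: Any) -> float:
        ...


class ScaledLoss:
    """Toy closure returning `base * scale`; `scale` is an optional keyword argument."""

    def __init__(self, base: float) -> None:
        self.base = base

    def __call__(self, **kwargs: Any) -> float:
        return self.base * kwargs.get("scale", 1.0)


loss_fn: KwargLossClosure = ScaledLoss(2.0)
```

Note that `runtime_checkable` only checks that `__call__` exists, not its signature; the real signature check comes from a static type checker.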

botorch/optim/closures.py Outdated
Comment on lines 233 to 341
if isinstance(torch_optimizer, (type, partial)):
torch_optimizer = torch_optimizer(params=list(param_dict.values()))
Contributor

I am not the biggest fan of the pattern of allowing this to be either an instance or a factory function. Not sure if there is an elegant alternative though.

Contributor Author

Would factory-only be better or worse? Note that instances can still be passed as `lambda **_: instance`, but the typing is now more explicit.
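A sketch of the factory-only variant: the fitting code always calls `factory(params=...)`, a `functools.partial` configures a fresh optimizer, and a pre-built instance is passed as `lambda **_: instance`. `ToyOptimizer` and `build_optimizer` are hypothetical stand-ins for a torch optimizer and the relevant lines of `fit_gpytorch_torch`.

```python
from functools import partial
from typing import Callable, List


class ToyOptimizer:
    """Hypothetical stand-in for a torch optimizer: remembers its params and lr."""

    def __init__(self, params: List[float], lr: float = 0.1) -> None:
        self.params = params
        self.lr = lr


# Factory-only contract: anything callable as `factory(params=...)`.
OptimizerFactory = Callable[..., ToyOptimizer]


def build_optimizer(factory: OptimizerFactory, params: List[float]) -> ToyOptimizer:
    return factory(params=params)


params = [0.0, 1.0]

# Configure a new optimizer over the routine's own params.
fresh = build_optimizer(partial(ToyOptimizer, lr=0.5), params)

# Reuse an existing instance; the `params` passed in are ignored.
prebuilt = ToyOptimizer([3.0], lr=0.01)
reused = build_optimizer(lambda **_: prebuilt, params)
```

The lambda silently ignores the parameters the routine passes in, so an instance built over parameters that do not live on the MLL still works, at the cost of a slightly opaque call site.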

My $0.02: being able to pass an Instance is nicer than the factory, if only because then I don't necessarily need all the parameters I want to optimize to live somewhere on mll. There's even a reasonable argument to be made for accepting lists of Optimizers and stepping all of them, to support e.g. NGD or different Optimizers / learning rates for the GP vs a DNN

Contributor Author

being able to pass an Instance is nicer than the factory, if only because then I don't necessarily need all the parameters I want to optimize to live somewhere on mll.

So, here's what we can do. We can split up fit_gpytorch_torch into a pair of methods: i) a generic torch-based gradient descent method and ii) a wrapper that provides mll-based model fitting conveniences. This seems to address your first point.
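A sketch of that split, under simplifying assumptions: (i) `fit_with_gradient_descent` needs only a flat list of parameters and a zero-argument closure, while (ii) `fit_toy_mll` pulls both off an MLL-like object and delegates. All names are hypothetical, plain gradient descent replaces a torch optimizer, and floats stand in for tensors.

```python
from typing import Any, Callable, List, Tuple

Closure = Callable[[], Tuple[float, List[float]]]


def fit_with_gradient_descent(
    params: List[float], closure: Closure, lr: float = 0.1, step_limit: int = 100
) -> float:
    """(i) Generic loop: knows only a parameter list and a zero-argument closure."""
    loss = float("nan")
    for _ in range(step_limit):
        loss, grads = closure()
        for i, g in enumerate(grads):
            params[i] -= lr * g  # in place, so the closure sees the update
    return loss


class ToyMLL:
    """Hypothetical MLL stand-in that owns its parameters and defines a loss."""

    def __init__(self, targets: List[float]) -> None:
        self.targets = targets
        self.params = [0.0] * len(targets)

    def loss_and_grad(self) -> Tuple[float, List[float]]:
        residuals = [p - t for p, t in zip(self.params, self.targets)]
        return sum(r * r for r in residuals), [2.0 * r for r in residuals]


def fit_toy_mll(mll: ToyMLL, **kwargs: Any) -> ToyMLL:
    """(ii) Convenience wrapper: pulls params and the closure off the MLL, then delegates."""
    fit_with_gradient_descent(mll.params, mll.loss_and_grad, **kwargs)
    return mll


mll = fit_toy_mll(ToyMLL(targets=[1.0, -2.0]))

# Parameters that live outside any MLL can use (i) directly.
free = [5.0]
fit_with_gradient_descent(free, lambda: ((free[0] - 3.0) ** 2, [2.0 * (free[0] - 3.0)]))
```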

There's even a reasonable argument to be made for accepting lists of Optimizers and stepping all of them, to support e.g. NGD or different Optimizers / learning rates for the GP vs a DNN

Interesting idea. I like that this eliminates the need for step_limit. If you wanted to go further, you could pass in a list of OptimizationStep objects with slots for closure, optimizer, etc. Unclear to me whether we run into any scheduling issues here.
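One way to picture the list-of-steps idea, assuming plain gradient steps: each hypothetical `OptimizationStep` bundles a parameter group, a closure, and a learning rate, and the loop steps every entry once per iteration in list order. Fixing that in-order schedule is exactly the kind of choice the scheduling question would have to settle.

```python
from dataclasses import dataclass
from typing import Callable, List, Tuple

Closure = Callable[[], Tuple[float, List[float]]]


@dataclass
class OptimizationStep:
    """Hypothetical bundle of what one optimizer needs: params, closure, step size."""

    params: List[float]
    closure: Closure
    lr: float


def run_steps(steps: List[OptimizationStep], num_iters: int) -> None:
    """Each iteration, step every entry once, in list order."""
    for _ in range(num_iters):
        for step in steps:
            _, grads = step.closure()
            for i, g in enumerate(grads):
                step.params[i] -= step.lr * g


# Two parameter groups (say, GP hyperparameters and a DNN) with different learning rates.
gp_params, dnn_params = [0.0], [0.0]
steps = [
    OptimizationStep(gp_params, lambda: ((gp_params[0] - 1.0) ** 2, [2.0 * (gp_params[0] - 1.0)]), lr=0.1),
    OptimizationStep(dnn_params, lambda: ((dnn_params[0] + 1.0) ** 2, [2.0 * (dnn_params[0] + 1.0)]), lr=0.01),
]
run_steps(steps, num_iters=1000)
```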

botorch/optim/fit.py Outdated
torch_scheduler.step()

loss = None if loss is None else loss.detach().cpu().item()
return mll, {"fopt": loss, "wall_time": monotonic() - start_time}
Contributor

should we include iterations here for backwards compatibility? I guess this is sufficiently deep in the stack that we don't need to worry too much about this

Contributor Author

Callbacks Are All You Need. Wilson et al., 2022.

Jokes aside: Is it just a backward compatibility thing or are you thinking that some conveniences would be useful here?
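If per-step bookkeeping moves into callbacks, callers can recover the old `iterations` information themselves. A sketch assuming a hypothetical `callback(iteration, loss)` hook (not BoTorch's actual callback signature), with a return dict mirroring the `fopt`/`wall_time` keys in the snippet above.

```python
from time import monotonic
from typing import Callable, Dict, List, Optional, Tuple

Closure = Callable[[], Tuple[float, List[float]]]
Callback = Callable[[int, float], None]  # hypothetical: (iteration index, loss)


def fit(
    params: List[float],
    closure: Closure,
    lr: float = 0.1,
    step_limit: int = 50,
    callback: Optional[Callback] = None,
) -> Dict[str, float]:
    start_time = monotonic()
    loss = float("nan")
    for i in range(step_limit):
        loss, grads = closure()
        for j, g in enumerate(grads):
            params[j] -= lr * g
        if callback is not None:
            callback(i, loss)
    # Same keys as the reviewed snippet; no "iterations" entry.
    return {"fopt": loss, "wall_time": monotonic() - start_time}


class LossHistory:
    """Callback that records every loss; its length is the iteration count."""

    def __init__(self) -> None:
        self.losses: List[float] = []

    def __call__(self, iteration: int, loss: float) -> None:
        self.losses.append(loss)


params = [0.0]
history = LossHistory()
result = fit(params, lambda: ((params[0] - 1.0) ** 2, [2.0 * (params[0] - 1.0)]), callback=history)
```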

botorch/optim/numpy_converter.py Outdated
torch_closure: Optional[TLossClosure] = None,
parameter_setter: Callable[[np.ndarray], Any] = set_params_with_array,
) -> Callable[[], Tuple[np.ndarray, np.ndarray]]:
if torch_closure is None:
Contributor

add a docstring here?
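A docstring'd sketch of what such a helper does, assuming the usual pattern: the returned function takes a flat parameter vector, writes it into the model (the job of `parameter_setter`, e.g. `set_params_with_array`), runs the torch closure, and returns `(value, gradient)`, which is the shape `scipy.optimize.minimize` expects when called with `jac=True`. Lists stand in for NumPy arrays and tensors, and `make_flat_closure` is a hypothetical name.

```python
from typing import Callable, List, Tuple

ZeroArgClosure = Callable[[], Tuple[float, List[float]]]
FlatClosure = Callable[[List[float]], Tuple[float, List[float]]]


def make_flat_closure(params: List[float], closure: ZeroArgClosure) -> FlatClosure:
    """Adapt a zero-argument closure to an ``f(x) -> (value, grad)`` function.

    Args:
        params: Mutable parameter storage that ``closure`` reads from.
        closure: Zero-argument callable returning ``(loss, grads)`` for the
            current contents of ``params``.

    Returns:
        A function that copies ``x`` into ``params`` and then evaluates
        ``closure``, suitable for optimizers expecting ``(f, g)`` pairs.
    """

    def flat_closure(x: List[float]) -> Tuple[float, List[float]]:
        params[:] = x  # stands in for `parameter_setter(x)`
        return closure()

    return flat_closure


params = [0.0, 0.0]
f = make_flat_closure(params, lambda: (sum(p * p for p in params), [2.0 * p for p in params]))
value, grad = f([1.0, 2.0])
```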

botorch/optim/numpy_converter.py Outdated
botorch/optim/fit.py Outdated
@jacobrgardner

This looks like a pretty reasonable solution to me! I actually like the ability to roll my own loss closures and optimizers here, and think it's worth the extra engineering. @nataliemaus has code that sometimes wants to train a GP end to end with a VAE included in the ELBO, and sometimes only wants to train the GP, so this should let us switch even more code into using BoTorch routines.

@codecov

codecov bot commented Oct 10, 2022

Codecov Report

Merging #1439 (4f6c4ef) into main (7613cd2) will not change coverage.
The diff coverage is 100.00%.

@@            Coverage Diff             @@
##              main     #1439    +/-   ##
==========================================
  Coverage   100.00%   100.00%            
==========================================
  Files          134       143     +9     
  Lines        12402     12755   +353     
==========================================
+ Hits         12402     12755   +353     

| Impacted Files | Coverage Δ |
| --- | --- |
| `botorch/models/pairwise_gp.py` | `100.00% <ø> (ø)` |
| `botorch/fit.py` | `100.00% <100.00%> (ø)` |
| `botorch/optim/__init__.py` | `100.00% <100.00%> (ø)` |
| `botorch/optim/closures/__init__.py` | `100.00% <100.00%> (ø)` |
| `botorch/optim/closures/core.py` | `100.00% <100.00%> (ø)` |
| `botorch/optim/closures/model_closures.py` | `100.00% <100.00%> (ø)` |
| `botorch/optim/core.py` | `100.00% <100.00%> (ø)` |
| `botorch/optim/fit.py` | `100.00% <100.00%> (ø)` |
| `botorch/optim/numpy_converter.py` | `100.00% <100.00%> (ø)` |
| `botorch/optim/stopping.py` | `100.00% <100.00%> (ø)` |

... and 7 more

@j-wilson j-wilson changed the title from "Loss closures, fit_gpytorch_torch, fit_gyptorch_mll dispatch for ApproximateGPs" to "Botorch closures" Oct 20, 2022
Summary:
X-link: facebook/Ax#1191

Pull Request resolved: #1439

This diff acts as a follow-up to the recent model fitting refactor. The previous update focused on the high-level logic used to determine which fitting routines to use for which MLLs. This diff refactors the internal machinery used to evaluate forward-backward passes (producing losses and gradients, respectively) during optimization.

The solution we have opted for is to abstract away the evaluation process by relying on closures. In most cases, these closures are automatically constructed by composing simpler, multiply-dispatched base functions.
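To make "multiply-dispatched base functions" concrete, here is a toy dispatcher that picks an implementation from the types of all arguments, used to choose a closure factory per (MLL, model) pair. It is a hypothetical, simplified stand-in for the `Dispatcher("get_loss_closure")` seen in the diff: resolution just tries type combinations in MRO order (first argument varying slowest) and does no ambiguity checking. The class names are placeholders, not the GPyTorch types.

```python
from itertools import product
from typing import Any, Callable, Dict, Tuple


class MultiDispatcher:
    """Toy multiple dispatch keyed on the types of all positional arguments."""

    def __init__(self, name: str) -> None:
        self.name = name
        self.registry: Dict[Tuple[type, ...], Callable[..., Any]] = {}

    def register(self, *types: type) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
        def decorator(fn: Callable[..., Any]) -> Callable[..., Any]:
            self.registry[types] = fn
            return fn

        return decorator

    def __call__(self, *args: Any) -> Any:
        # Try (most specific ... object) combinations; the last argument varies fastest.
        for signature in product(*(type(arg).__mro__ for arg in args)):
            fn = self.registry.get(signature)
            if fn is not None:
                return fn(*args)
        names = tuple(type(arg).__name__ for arg in args)
        raise NotImplementedError(f"{self.name}: no implementation for {names}")


# Placeholder class hierarchy (not the real GPyTorch classes).
class MLL: ...
class ExactMLL(MLL): ...
class Model: ...
class ExactGP(Model): ...
class ApproximateGP(Model): ...


get_loss_closure = MultiDispatcher("get_loss_closure")


@get_loss_closure.register(MLL, Model)
def _generic(mll: MLL, model: Model) -> Callable[[], str]:
    return lambda: "generic closure"


@get_loss_closure.register(ExactMLL, ExactGP)
def _exact(mll: ExactMLL, model: ExactGP) -> Callable[[], str]:
    return lambda: "exact closure"
```

Dedicated libraries such as `multipledispatch` resolve by specificity and warn about ambiguous registrations; this sketch only shows the shape of the idea.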

Reviewed By: Balandat

Differential Revision: D39101211

fbshipit-source-id: f4ec341228f9f16a327307ff398c3fb8839a3de2

facebook-github-bot pushed a commit to facebook/Ax that referenced this pull request Nov 11, 2022
Summary:
Pull Request resolved: #1191

X-link: pytorch/botorch#1439

This diff acts as a follow-up to the recent model fitting refactor. The previous update focused on the high-level logic used to determine which fitting routines to use for which MLLs. This diff refactors the internal machinery used to evaluate forward-backward passes (producing losses and gradients, respectively) during optimization.

The solution we have opted for is to abstract away the evaluation process by relying on closures. In most cases, these closures are automatically constructed by composing simpler, multiply-dispatched base functions.

Reviewed By: Balandat

Differential Revision: D39101211

fbshipit-source-id: c2058a387fd74058073cfe73c9404d2df2f9b55a
esantorella added a commit to esantorella/botorch that referenced this pull request Oct 12, 2023
Summary:
## Motivation

Removes everything deprecated in pytorch#1439, in version 0.8.0, except for `_get_extra_mll_args`. Removal was straightforward since the functionality was not used and relevant unit tests were clearly labeled and self-contained.

### Have you read the [Contributing Guidelines on pull requests](https://github.com/pytorch/botorch/blob/main/CONTRIBUTING.md#pull-requests)?

Yes

Pull Request resolved: pytorch#1995

Test Plan:
Existing units; make sure codecov has not regressed from deleting tests.

## Related PRs

pytorch#1439

Reviewed By: Balandat

Differential Revision: D48738275

Pulled By: esantorella

fbshipit-source-id: 38f39d185c0cc843f4be7ccd420c0430d3fa1fcd
facebook-github-bot pushed a commit that referenced this pull request Oct 13, 2023
…rch (#1995)

Summary:
## Motivation

Removes most of what was deprecated in #1439, in version 0.8.0, except for `_get_extra_mll_args` and `fit_gpytorch_model`. Removal was straightforward since the functionality was not used and relevant unit tests were clearly labeled and self-contained.

### Have you read the [Contributing Guidelines on pull requests](https://github.com/pytorch/botorch/blob/main/CONTRIBUTING.md#pull-requests)?

Yes

Pull Request resolved: #1995

Pull Request resolved: #2041

Test Plan:
Existing units; make sure codecov has not regressed from deleting tests.

## Related PRs

#1439

Reviewed By: saitcakmak

Differential Revision: D50276821

Pulled By: esantorella

fbshipit-source-id: 45c8f082cdd00cb3bb78a342e38e3f0e751cf56f

4 participants