Skip to content

Commit

Permalink
Bug fix: _filter_kwargs was erroring when provided a function witho…
Browse files Browse the repository at this point in the history
…ut a `__name__` attribute (#1678)

Summary:
Pull Request resolved: #1678

See #1667

Reviewed By: danielrjiang

Differential Revision: D43286116

fbshipit-source-id: 3da3e6ff23b517f5379ee90f407dc04d4f2ad06e
  • Loading branch information
esantorella authored and facebook-github-bot committed Feb 15, 2023
1 parent ad38736 commit 63dd0cd
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 12 deletions.
7 changes: 6 additions & 1 deletion botorch/optim/utils/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,14 @@ def _filter_kwargs(function: Callable, **kwargs: Any) -> Any:
allowed_params = signature(function).parameters
removed = {k for k in kwargs.keys() if k not in allowed_params}
if len(removed) > 0:
fn_descriptor = (
f" for function {function.__name__}"
if hasattr(function, "__name__")
else ""
)
warn(
f"Keyword arguments {list(removed)} will be ignored because they are"
f" not allowed parameters for function {function.__name__}. Allowed "
f" not allowed parameters{fn_descriptor}. Allowed "
f"parameters are {list(allowed_params.keys())}."
)
return {k: v for k, v in kwargs.items() if k not in removed}
Expand Down
6 changes: 0 additions & 6 deletions test/optim/test_optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,13 +115,10 @@ def test_optimize_acqf_joint(
mock_gen_candidates_scipy,
mock_gen_candidates_torch,
):
# Mocks don't have a __name__ attribute.
# Set the attribute, since it is needed for testing _filter_kwargs
if mock_gen_candidates == mock_gen_candidates_torch:
mock_signature.return_value = signature(gen_candidates_torch)
else:
mock_signature.return_value = signature(gen_candidates_scipy)
mock_gen_candidates.__name__ = "gen_candidates"

mock_gen_batch_initial_conditions.return_value = torch.zeros(
num_restarts, q, 3, device=self.device, dtype=dtype
Expand Down Expand Up @@ -835,13 +832,10 @@ def nlc(x):
mock_gen_candidates_torch,
mock_gen_candidates_scipy,
):
# Mocks don't have a __name__ attribute.
# Set the attribute, since it is needed for testing _filter_kwargs
if mock_gen_candidates == mock_gen_candidates_torch:
mock_signature.return_value = signature(gen_candidates_torch)
else:
mock_signature.return_value = signature(gen_candidates_scipy)
mock_gen_candidates.__name__ = "gen_candidates"
for dtype in (torch.float, torch.double):

mock_acq_function = MockAcquisitionFunction()
Expand Down
19 changes: 14 additions & 5 deletions test/optim/utils/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,23 @@ def mock_adam(params, lr: float = 0.001) -> None:
return # pragma: nocover

kwargs = {"lr": 0.01, "maxiter": 3000}
with catch_warnings(record=True) as ws:
expected_msg = (
r"Keyword arguments \['maxiter'\] will be ignored because they are "
r"not allowed parameters for function mock_adam. Allowed parameters "
r"are \['params', 'lr'\]."
)

with self.assertWarnsRegex(Warning, expected_msg):
valid_kwargs = _filter_kwargs(mock_adam, **kwargs)
self.assertEqual(set(valid_kwargs.keys()), {"lr"})

mock_partial = partial(mock_adam, lr=2.0)
expected_msg = (
"Keyword arguments ['maxiter'] will be ignored because they are not"
" allowed parameters for function mock_adam. Allowed parameters are "
"['params', 'lr']."
r"Keyword arguments \['maxiter'\] will be ignored because they are "
r"not allowed parameters. Allowed parameters are \['params', 'lr'\]."
)
self.assertEqual(expected_msg, str(ws[0].message))
with self.assertWarnsRegex(Warning, expected_msg):
valid_kwargs = _filter_kwargs(mock_partial, **kwargs)
self.assertEqual(set(valid_kwargs.keys()), {"lr"})

def test_handle_numerical_errors(self):
Expand Down

0 comments on commit 63dd0cd

Please sign in to comment.