Skip to content

Commit

Permalink
Add BotorchTestCase.assertAllClose (#1618)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #1618

`BotorchTestCase.assertAllClose` is a thin wrapper around `torch.testing.assert_close`, designed to replace usages of
`self.assertTrue(torch.allclose(...))`.

Using  has several advantages over `torch.allclose`:
* Checks that shapes are equal, not just values
* More configurability and better defaults, such as higher tolerances for single precision, if we choose to use them in the future. for the time being, I've set up this wrapper so that numerical checks remain exactly the same as they used to be.
* More informative test output, showing what was put in and why the test failed:

Old test output:
```AssertionError: False is not true```

New test output:
```
1) AssertionError: Scalars are not close!

Absolute difference: 1.0000034868717194 (up to 0.0001 allowed)
Relative difference: 0.8348668001940709 (up to 1e-05 allowed)
```

Reviewed By: Balandat

Differential Revision: D42402142

fbshipit-source-id: 8d29fde720b24a1c6c9f22797585db5962d01d7b
  • Loading branch information
esantorella authored and facebook-github-bot committed Jan 9, 2023
1 parent bb5fc4c commit 05d93e6
Showing 1 changed file with 31 additions and 1 deletion.
32 changes: 31 additions & 1 deletion botorch/utils/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import math
import warnings
from collections import OrderedDict
from typing import Any, List, Optional, Tuple
from typing import Any, List, Optional, Tuple, Union
from unittest import TestCase

import torch
Expand Down Expand Up @@ -51,6 +51,36 @@ def setUp(self):
category=UserWarning,
)

def assertAllClose(
self,
input: torch.Tensor,
other: Union[torch.Tensor, float],
rtol: float = 1e-05,
atol: float = 1e-08,
equal_nan: bool = False,
) -> None:
r"""
Calls torch.testing.assert_close, using the signature and default behavior
of torch.allclose.
Example output:
AssertionError: Scalars are not close!
Absolute difference: 1.0000034868717194 (up to 0.0001 allowed)
Relative difference: 0.8348668001940709 (up to 1e-05 allowed)
"""
# Why not just use the signature and behavior of `torch.testing.assert_close`?
# Because we used `torch.allclose` for testing in the past, and the two don't
# behave exactly the same. In particular, `assert_close` requires both `atol`
# and `rtol` to be set if either one is.
torch.testing.assert_close(
input,
other,
rtol=rtol,
atol=atol,
equal_nan=equal_nan,
)


class BaseTestProblemBaseTestCase:

Expand Down

0 comments on commit 05d93e6

Please sign in to comment.