Skip to content

Commit

Permalink
Update to PyTorch 1.2 (#192)
Browse files Browse the repository at this point in the history
  • Loading branch information
fritzo authored and neerajprad committed Aug 11, 2019
1 parent 90c6df4 commit 0322d4c
Show file tree
Hide file tree
Showing 4 changed files with 4 additions and 4 deletions.
2 changes: 1 addition & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ cache:

install:
- pip install -U pip
- pip install https://download.pytorch.org/whl/cpu/torch-1.1.0-cp36-cp36m-linux_x86_64.whl
- pip install torch==1.2.0+cpu -f https://download.pytorch.org/whl/torch_stable.html

# Keep track of Pyro dev branch
- pip install https://github.com/pyro-ppl/pyro/archive/dev.zip
Expand Down
2 changes: 1 addition & 1 deletion funsor/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def assert_close(actual, expected, atol=1e-6, rtol=1e-6):
elif isinstance(actual, torch.Tensor):
assert actual.dtype == expected.dtype, msg
assert actual.shape == expected.shape, msg
if actual.dtype in (torch.long, torch.uint8):
if actual.dtype in (torch.long, torch.uint8, torch.bool):
assert (actual == expected).all(), msg
else:
eq = (actual == expected)
Expand Down
2 changes: 1 addition & 1 deletion funsor/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -746,7 +746,7 @@ def _exp(x):

@ops.log.register(torch.Tensor)
def _log(x):
if x.dtype in (torch.uint8, torch.long):
if x.dtype in (torch.bool, torch.uint8, torch.long):
x = x.float()
return x.log()

Expand Down
2 changes: 1 addition & 1 deletion test/test_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@


@pytest.mark.parametrize('shape', [(), (4,), (3, 2)])
@pytest.mark.parametrize('dtype', [torch.float, torch.long, torch.uint8])
@pytest.mark.parametrize('dtype', [torch.float, torch.long, torch.uint8, torch.bool])
def test_to_funsor(shape, dtype):
t = torch.randn(shape).type(dtype)
f = funsor.to_funsor(t)
Expand Down

0 comments on commit 0322d4c

Please sign in to comment.