Skip to content

Commit

Permalink
extend test case to verify bool -> int64
Browse files Browse the repository at this point in the history
  • Loading branch information
Kiyosora committed Dec 16, 2020
1 parent db1a49a commit 21bd03a
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions test/test_reductions.py
Expand Up @@ -817,8 +817,12 @@ def test_prod(self, device, dtype):
def test_prod_bool(self, device):
vals = [[True, True], [True, False], [False, False], []]
for val in vals:
result = torch.prod(torch.tensor(val, device=device), dtype=torch.bool)
expect = torch.tensor(np.prod(np.array(val), dtype=np.bool), device=device)
result = torch.prod(torch.tensor(val, device=device), dtype=torch.bool).item()
expect = np.prod(np.array(val), dtype=np.bool)
self.assertEqual(result, expect)

result = torch.prod(torch.tensor(val, device=device)).item()
expect = np.prod(np.array(val))
self.assertEqual(result, expect)

@onlyCPU
Expand Down

0 comments on commit 21bd03a

Please sign in to comment.