From 21bd03abc8c554f768413fb9f0cbf2ac4ebafc9e Mon Sep 17 00:00:00 2001 From: kiyosora Date: Tue, 15 Dec 2020 15:43:35 +0800 Subject: [PATCH] extend test case to verify bool -> int64 --- test/test_reductions.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/test/test_reductions.py b/test/test_reductions.py index 6c0eeb0f0bbb..762ddaa2763c 100644 --- a/test/test_reductions.py +++ b/test/test_reductions.py @@ -817,8 +817,12 @@ def test_prod(self, device, dtype): def test_prod_bool(self, device): vals = [[True, True], [True, False], [False, False], []] for val in vals: - result = torch.prod(torch.tensor(val, device=device), dtype=torch.bool) - expect = torch.tensor(np.prod(np.array(val), dtype=np.bool), device=device) + result = torch.prod(torch.tensor(val, device=device), dtype=torch.bool).item() + expect = np.prod(np.array(val), dtype=np.bool) + self.assertEqual(result, expect) + + result = torch.prod(torch.tensor(val, device=device)).item() + expect = np.prod(np.array(val)) self.assertEqual(result, expect) @onlyCPU