From 4ed1264a182b32eadedfb16f6065abf5aee8b19e Mon Sep 17 00:00:00 2001 From: Michael Suo Date: Sat, 7 Sep 2019 00:09:59 -0700 Subject: [PATCH] Update test_jit_fuser.py --- test/test_jit_fuser.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/test_jit_fuser.py b/test/test_jit_fuser.py index a3f5ca00519f..9d019248c08e 100644 --- a/test/test_jit_fuser.py +++ b/test/test_jit_fuser.py @@ -59,7 +59,7 @@ def f(x, y): z1, z2 = (x + y).chunk(2, dim=1) return z1 * z2 - x = torch.randn(4, 4, dtype=torch.float, device='cuda') + x = torch.randn(4,4,dtype=torch.float, device='cuda') y = torch.randn(4, 4, dtype=torch.float, device='cuda') traced_f = torch.jit.trace(f, (x, y,)) self.assertEqual(traced_f(x.t().contiguous(), y), traced_f(x.t(), y)) @@ -83,7 +83,7 @@ def scaleshift(x, scale, shift): @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") @unittest.skipIf(not RUN_CUDA_HALF, "no half support") def test_cuda_half(self): - x = torch.randn(4, 4, dtype=torch.half, device='cuda') + x = torch.randn(4,4,dtype=torch.half, device='cuda') y = torch.randn(4, 4, dtype=torch.half, device='cuda') funcs = [