diff --git a/test/test_jit_fuser.py b/test/test_jit_fuser.py index a3f5ca00519f..9d019248c08e 100644 --- a/test/test_jit_fuser.py +++ b/test/test_jit_fuser.py @@ -59,7 +59,7 @@ def f(x, y): z1, z2 = (x + y).chunk(2, dim=1) return z1 * z2 - x = torch.randn(4, 4, dtype=torch.float, device='cuda') + x = torch.randn(4,4,dtype=torch.float, device='cuda') y = torch.randn(4, 4, dtype=torch.float, device='cuda') traced_f = torch.jit.trace(f, (x, y,)) self.assertEqual(traced_f(x.t().contiguous(), y), traced_f(x.t(), y)) @@ -83,7 +83,7 @@ def scaleshift(x, scale, shift): @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") @unittest.skipIf(not RUN_CUDA_HALF, "no half support") def test_cuda_half(self): - x = torch.randn(4, 4, dtype=torch.half, device='cuda') + x = torch.randn(4,4,dtype=torch.half, device='cuda') y = torch.randn(4, 4, dtype=torch.half, device='cuda') funcs = [