diff --git a/test/test_jit_fuser.py b/test/test_jit_fuser.py index a3f5ca00519f..1abbfe0fdb1d 100644 --- a/test/test_jit_fuser.py +++ b/test/test_jit_fuser.py @@ -83,7 +83,7 @@ def scaleshift(x, scale, shift): @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") @unittest.skipIf(not RUN_CUDA_HALF, "no half support") def test_cuda_half(self): - x = torch.randn(4, 4, dtype=torch.half, device='cuda') + x = torch.randn(4,4,dtype=torch.half, device='cuda') y = torch.randn(4, 4, dtype=torch.half, device='cuda') funcs = [