diff --git a/test/test_jit_fuser.py b/test/test_jit_fuser.py index a3f5ca00519f..98ffdca5de7a 100644 --- a/test/test_jit_fuser.py +++ b/test/test_jit_fuser.py @@ -59,7 +59,7 @@ def f(x, y): z1, z2 = (x + y).chunk(2, dim=1) return z1 * z2 - x = torch.randn(4, 4, dtype=torch.float, device='cuda') + x = torch.randn(4,4,dtype=torch.float, device='cuda') y = torch.randn(4, 4, dtype=torch.float, device='cuda') traced_f = torch.jit.trace(f, (x, y,)) self.assertEqual(traced_f(x.t().contiguous(), y), traced_f(x.t(), y))