diff --git a/test/test_print.py b/test/test_print.py
index cba34ee9e..fdfc1bb4e 100644
--- a/test/test_print.py
+++ b/test/test_print.py
@@ -38,8 +38,8 @@ def run_kernel_and_capture_output(self, kernel_fn, args):
         code, result = code_and_output(kernel_fn, args)
 
         # Wait for any device prints to reach the host
-        if hasattr(result, "device") and result.device.type == "cuda":
-            torch.cuda.synchronize()
+        if hasattr(result, "device") and result.device.type == DEVICE.type:
+            torch.accelerator.synchronize()
 
         # Grab what pytest captured: stdout + stderr
         out, err = self._capfd.readouterr()
@@ -69,8 +69,8 @@ def run_kernel_and_capture_output(self, kernel_fn, args):
         code, result = code_and_output(kernel_fn, args)
 
         # Force GPU synchronization to ensure all device prints complete
-        if hasattr(result, "device") and result.device.type == "cuda":
-            torch.cuda.synchronize()
+        if hasattr(result, "device") and result.device.type == DEVICE.type:
+            torch.accelerator.synchronize()
 
         # Ensure all output is flushed
         sys.stdout.flush()
diff --git a/test/test_tensor_descriptor.py b/test/test_tensor_descriptor.py
index f3fb7da62..3774df1f9 100644
--- a/test/test_tensor_descriptor.py
+++ b/test/test_tensor_descriptor.py
@@ -240,7 +240,7 @@ def matmul(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
         y = torch.randn((64, 64), device=DEVICE, dtype=torch.float16)
 
         code, result = code_and_output(matmul, (x, y))
-        torch.cuda.synchronize()
+        torch.accelerator.synchronize()
 
         expected = torch.matmul(x, y)
         torch.testing.assert_close(result, expected, atol=1e-2, rtol=1e-2)
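
Note: for context, a minimal sketch of the device-agnostic guard these hunks converge on, assuming PyTorch 2.6+ (where the torch.accelerator namespace is available) and using a hypothetical DEVICE stand-in for the test suite's shared device constant:

import torch

# Hypothetical stand-in for the suite's shared DEVICE constant
# (assumption: an accelerator backend is present, as on the CI runners;
# torch.accelerator.synchronize() would raise on a CPU-only build).
DEVICE = (
    torch.accelerator.current_accelerator()
    if torch.accelerator.is_available()
    else torch.device("cpu")
)

def sync_if_on_device(result) -> None:
    # Backend-agnostic replacement for torch.cuda.synchronize(): wait on
    # the current accelerator (CUDA, XPU, MPS, ...) only when the kernel's
    # result actually lives on the suite's device.
    if hasattr(result, "device") and result.device.type == DEVICE.type:
        torch.accelerator.synchronize()

On ROCm builds result.device.type still reports "cuda", so the old string literal already covered that case; the change matters for non-CUDA backends such as XPU, where comparing against DEVICE.type keeps the guard correct.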