diff --git a/test/test_dot.py b/test/test_dot.py index 43520e5a9..1642bd899 100644 --- a/test/test_dot.py +++ b/test/test_dot.py @@ -16,6 +16,7 @@ from helion._testing import is_cuda from helion._testing import skipIfRefEager from helion._testing import skipIfRocm +from helion._testing import skipIfXPU import helion.language as hl @@ -86,9 +87,23 @@ def make_test_function(input_dtype, acc_dtype, static_shapes_option): @skipIfRocm("Core dumps with rocm -- https://github.com/pytorch/helion/issues/445") def test_impl(self): # Skip FP8 tests if GPU doesn't support it + def _is_cuda_fp8_supported(): + if not is_cuda(): + return False + return torch.cuda.get_device_capability(0)[0] >= 9 + + def _is_xpu_fp8_supported(): + if not torch.xpu.is_available(): + return False + + from packaging import version + + return version.parse(triton.__version__) >= version.parse("3.5") + + is_fp8_supported = _is_cuda_fp8_supported() or _is_xpu_fp8_supported() if ( input_dtype in (torch.float8_e4m3fn, torch.float8_e5m2) - and torch.cuda.get_device_capability(0)[0] < 9 + and not is_fp8_supported ): self.skipTest(f"FP8 dtype {input_dtype} not supported on this GPU") @@ -243,6 +258,7 @@ def bmm(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor: # torch.baddbmm codegen shape is covered indirectly by broader matmul tests; skipping a brittle code-inspection here @skipIfRefEager("Debug dtype codegen checks rely on compiled code") + @skipIfXPU("Failed on XPU - https://github.com/pytorch/helion/issues/772") def test_baddbmm_pipeline_debug_dtype_asserts(self): # Reproduces scripts/repro512.py within the test suite and asserts # the kernel compiles and runs with debug dtype asserts enabled.