Merged
test/test_dot.py: 18 changes (17 additions, 1 deletion)
@@ -16,6 +16,7 @@
 from helion._testing import is_cuda
 from helion._testing import skipIfRefEager
 from helion._testing import skipIfRocm
+from helion._testing import skipIfXPU
 import helion.language as hl


@@ -86,9 +87,23 @@ def make_test_function(input_dtype, acc_dtype, static_shapes_option):
     @skipIfRocm("Core dumps with rocm -- https://github.com/pytorch/helion/issues/445")
     def test_impl(self):
         # Skip FP8 tests if GPU doesn't support it
+        def _is_cuda_fp8_supported():
+            if not is_cuda():
+                return False
+            return torch.cuda.get_device_capability(0)[0] >= 9
+
+        def _is_xpu_fp8_supported():
+            if not torch.xpu.is_available():
+                return False
+
+            from packaging import version
+
+            return version.parse(triton.__version__) >= version.parse("3.5")
+
+        is_fp8_supported = _is_cuda_fp8_supported() or _is_xpu_fp8_supported()
         if (
             input_dtype in (torch.float8_e4m3fn, torch.float8_e5m2)
-            and torch.cuda.get_device_capability(0)[0] < 9
+            and not is_fp8_supported
         ):
             self.skipTest(f"FP8 dtype {input_dtype} not supported on this GPU")
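
For reference, the device-support logic the new helpers encode can be read as a single standalone check: FP8 dot products are treated as supported on CUDA only for compute capability 9.0 or newer, and on XPU only when the installed Triton is at least 3.5. The sketch below is illustrative only; the function name and module-level imports are not part of the test file, which defines the helpers inline.

```python
# Minimal sketch of the combined FP8 capability check, assuming torch,
# triton, and packaging are importable; not part of test_dot.py itself.
import torch
import triton
from packaging import version


def fp8_matmul_supported() -> bool:
    # CUDA: FP8 matmul requires Hopper-class hardware (compute capability >= 9).
    if torch.cuda.is_available():
        return torch.cuda.get_device_capability(0)[0] >= 9
    # XPU: treat FP8 as supported once the bundled Triton is 3.5 or newer.
    if torch.xpu.is_available():
        return version.parse(triton.__version__) >= version.parse("3.5")
    return False
```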

@@ -243,6 +258,7 @@ def bmm(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
     # torch.baddbmm codegen shape is covered indirectly by broader matmul tests; skipping a brittle code-inspection here

     @skipIfRefEager("Debug dtype codegen checks rely on compiled code")
+    @skipIfXPU("Failed on XPU - https://github.com/pytorch/helion/issues/772")
     def test_baddbmm_pipeline_debug_dtype_asserts(self):
         # Reproduces scripts/repro512.py within the test suite and asserts
         # the kernel compiles and runs with debug dtype asserts enabled.
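
The skipIfXPU helper imported and applied above comes from helion._testing; its implementation is not part of this diff. A hypothetical minimal version of such a decorator could be built on unittest.skipIf, as sketched below.

```python
# Hypothetical sketch of a skipIfXPU-style helper; the real one lives in
# helion._testing and may be implemented differently.
import unittest

import torch


def skipIfXPU(reason: str):
    # Delegating to unittest.skipIf keeps the skip visible in test reports
    # and lets the decorator wrap both test methods and TestCase classes.
    return unittest.skipIf(torch.xpu.is_available(), reason)
```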