diff --git a/.github/scripts/ci_test_xpu.sh b/.github/scripts/ci_test_xpu.sh
index d765696b40..05089db7c8 100644
--- a/.github/scripts/ci_test_xpu.sh
+++ b/.github/scripts/ci_test_xpu.sh
@@ -14,4 +14,14 @@ python3 -c "import torch; import torchao; print(f'Torch version: {torch.__versio
 
 pip install pytest expecttest parameterized accelerate hf_transfer 'modelscope!=1.15.0'
 
-pytest -v -s torchao/test/quantization/quantize_/workflows/int4/test_int4_plain_int32_tensor.py
+pytest -v -s torchao/test/quantization/
+
+pytest -v -s torchao/test/dtypes/
+
+pytest -v -s torchao/test/float8/
+
+pytest -v -s torchao/test/integration/test_integration.py
+
+pytest -v -s torchao/test/prototype/
+
+pytest -v -s torchao/test/test_ao_models.py
diff --git a/torchao/testing/utils.py b/torchao/testing/utils.py
index b371b21f06..bf91f7675d 100644
--- a/torchao/testing/utils.py
+++ b/torchao/testing/utils.py
@@ -98,6 +98,69 @@ def wrapper(*args, **kwargs):
     return decorator
 
 
+def skip_if_no_xpu():
+    try:
+        import pytest
+
+        has_pytest = True
+    except ImportError:
+        has_pytest = False
+        import unittest
+
+    def decorator(func):
+        @functools.wraps(func)
+        def wrapper(*args, **kwargs):
+            if not torch.xpu.is_available():
+                skip_message = "No XPU available"
+                if has_pytest:
+                    pytest.skip(skip_message)
+                else:
+                    raise unittest.SkipTest(skip_message)
+            return func(*args, **kwargs)
+
+        return wrapper
+
+    return decorator
+
+
+def skip_if_xpu(message=None):
+    """
+    Decorator to skip tests on the XPU platform with a custom message.
+
+    Args:
+        message (str, optional): Additional information about why the test is skipped.
+    """
+    try:
+        import pytest
+
+        has_pytest = True
+    except ImportError:
+        has_pytest = False
+        import unittest
+
+    def decorator(func):
+        @functools.wraps(func)
+        def wrapper(*args, **kwargs):
+            if torch.xpu.is_available():
+                skip_message = "Skipping the test on XPU"
+                if message:
+                    skip_message += f": {message}"
+                if has_pytest:
+                    pytest.skip(skip_message)
+                else:
+                    raise unittest.SkipTest(skip_message)
+            return func(*args, **kwargs)
+
+        return wrapper
+
+    # Handle both @skip_if_xpu and @skip_if_xpu() syntax
+    if callable(message):
+        func = message
+        message = None
+        return decorator(func)
+    return decorator
+
+
 def skip_if_no_cuda():
     import unittest
 
diff --git a/torchao/utils.py b/torchao/utils.py
index e123dfe891..875383a064 100644
--- a/torchao/utils.py
+++ b/torchao/utils.py
@@ -137,6 +137,13 @@ def get_available_devices():
     return devices
 
 
+def get_current_accelerator_device():
+    if torch.accelerator.is_available():
+        return torch.accelerator.current_accelerator()
+    else:
+        return None
+
+
 def get_compute_capability():
     if torch.cuda.is_available():
         capability = torch.cuda.get_device_capability()
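
Note on the ci_test_xpu.sh change: the script now exercises whole test trees
instead of the single int4 plain-int32 test file. A rough local equivalent of
that selection, sketched via pytest.main (the suite paths are copied from the
script; running from a checkout with the same torchao/test/ layout is an
assumption, and unlike the script this runs everything in one pytest process):

    import sys

    import pytest

    # Mirrors the pytest invocations added to ci_test_xpu.sh above.
    suites = [
        "torchao/test/quantization/",
        "torchao/test/dtypes/",
        "torchao/test/float8/",
        "torchao/test/integration/test_integration.py",
        "torchao/test/prototype/",
        "torchao/test/test_ao_models.py",
    ]
    sys.exit(pytest.main(["-v", "-s", *suites]))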
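
Usage sketch for the two new decorators (the test class and test names below
are hypothetical). skip_if_no_xpu is a decorator factory and is applied with
parentheses; skip_if_xpu additionally accepts both the bare and the called
form, as handled by its callable(message) branch:

    import unittest

    import torch

    from torchao.testing.utils import skip_if_no_xpu, skip_if_xpu


    class XPUSmokeTest(unittest.TestCase):  # hypothetical test class
        @skip_if_no_xpu()
        def test_runs_only_with_xpu(self):
            x = torch.randn(8, device="xpu")
            self.assertEqual(x.device.type, "xpu")

        @skip_if_xpu  # bare form: skipped on XPU with the default message
        def test_skipped_on_xpu(self):
            self.assertTrue(True)

        @skip_if_xpu("float64 path not yet supported on XPU")  # with a reason
        def test_skipped_on_xpu_with_message(self):
            self.assertTrue(True)


    if __name__ == "__main__":
        unittest.main()

Under pytest the wrappers raise via pytest.skip; under plain unittest the
fallback raises unittest.SkipTest, so the skip is reported rather than the
test body being run.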
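
get_current_accelerator_device wraps the torch.accelerator API (present on
recent PyTorch builds that ship that module) and returns a torch.device such
as device(type='xpu') or device(type='cuda'), or None when no accelerator is
available. A minimal usage sketch:

    import torch

    from torchao.utils import get_current_accelerator_device

    device = get_current_accelerator_device()
    if device is not None:
        # e.g. device(type='xpu') on an XPU machine, device(type='cuda') on CUDA
        x = torch.randn(4, 4, device=device)
    else:
        x = torch.randn(4, 4)  # CPU fallback when no accelerator is present
    print(x.device)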