From 1847f3ed87828779ea8f1c8bccc88f0ff933bc46 Mon Sep 17 00:00:00 2001 From: "Zeng, Xiangdong" Date: Thu, 13 Nov 2025 10:19:23 +0800 Subject: [PATCH 1/5] add common code for xpu --- .github/scripts/ci_test_xpu.sh | 14 +++++++++++- torchao/testing/utils.py | 41 ++++++++++++++++++++++++++++++++++ torchao/utils.py | 7 ++++++ 3 files changed, 61 insertions(+), 1 deletion(-) diff --git a/.github/scripts/ci_test_xpu.sh b/.github/scripts/ci_test_xpu.sh index d765696b40..5f59eae00b 100644 --- a/.github/scripts/ci_test_xpu.sh +++ b/.github/scripts/ci_test_xpu.sh @@ -14,4 +14,16 @@ python3 -c "import torch; import torchao; print(f'Torch version: {torch.__versio pip install pytest expecttest parameterized accelerate hf_transfer 'modelscope!=1.15.0' -pytest -v -s torchao/test/quantization/quantize_/workflows/int4/test_int4_plain_int32_tensor.py +pytest -v -s torchao/test/quantization/quantize_/workflows/int4/test_int4_plain_int32_tensor.py + +pytest -v -s torchao/test/quantization/ + +pytest -v -s torchao/test/dtypes/ + +pytest -v -s torchao/test/float8/ + +pytest -v -s torchao/test/integration/test_integration.py + +pytest -v -s torchao/test/prototype/ + +pytest -v -s torchao/test/test_ao_models.py diff --git a/torchao/testing/utils.py b/torchao/testing/utils.py index b371b21f06..4a4a8d7aec 100644 --- a/torchao/testing/utils.py +++ b/torchao/testing/utils.py @@ -98,6 +98,47 @@ def wrapper(*args, **kwargs): return decorator +def skip_if_no_xpu(message=None): + """Decorator to skip tests on ROCm platform with custom message. + + Args: + message (str, optional): Additional information about why the test is skipped. + """ + import unittest + + def decorator(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + if not torch.xpu.is_available(): + skip_message = "Skipping the test in XPU" + if message: + skip_message += f": {message}" + unittest.skip(skip_message) + return func(*args, **kwargs) + + return wrapper + + return decorator + + +def skip_if_xpu(message=None): + """ + Decorator to skip tests if XPU is available. + + Args: + message (str, optional): Additional information about why the test is skipped. + """ + + def decorator(func): + reason = "Skipping the test on XPU" + if message: + reason += f": {message}" + + return unittest.skipIf(torch.xpu.is_available(), reason)(func) + + return decorator + + def skip_if_no_cuda(): import unittest diff --git a/torchao/utils.py b/torchao/utils.py index e123dfe891..b610fd969d 100644 --- a/torchao/utils.py +++ b/torchao/utils.py @@ -137,6 +137,13 @@ def get_available_devices(): return devices +def auto_detect_device(): + if torch.accelerator.is_available(): + return torch.accelerator.current_accelerator() + else: + return None + + def get_compute_capability(): if torch.cuda.is_available(): capability = torch.cuda.get_device_capability() From aab4cc603e7dbf273868c9eb5b2fd88e2253ad6d Mon Sep 17 00:00:00 2001 From: "Zeng, Xiangdong" Date: Thu, 13 Nov 2025 11:13:06 +0800 Subject: [PATCH 2/5] fix format issue --- torchao/testing/utils.py | 4 ++-- torchao/utils.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/torchao/testing/utils.py b/torchao/testing/utils.py index 4a4a8d7aec..c00b48f558 100644 --- a/torchao/testing/utils.py +++ b/torchao/testing/utils.py @@ -100,7 +100,7 @@ def wrapper(*args, **kwargs): def skip_if_no_xpu(message=None): """Decorator to skip tests on ROCm platform with custom message. - + Args: message (str, optional): Additional information about why the test is skipped. """ @@ -124,7 +124,7 @@ def wrapper(*args, **kwargs): def skip_if_xpu(message=None): """ Decorator to skip tests if XPU is available. - + Args: message (str, optional): Additional information about why the test is skipped. """ diff --git a/torchao/utils.py b/torchao/utils.py index b610fd969d..607d1b8542 100644 --- a/torchao/utils.py +++ b/torchao/utils.py @@ -142,8 +142,8 @@ def auto_detect_device(): return torch.accelerator.current_accelerator() else: return None - - + + def get_compute_capability(): if torch.cuda.is_available(): capability = torch.cuda.get_device_capability() From b4933aee96936ae485b1e53d3346d4e08754acbe Mon Sep 17 00:00:00 2001 From: "Zeng, Xiangdong" Date: Thu, 13 Nov 2025 16:17:53 +0800 Subject: [PATCH 3/5] remove case --- .github/scripts/ci_test_xpu.sh | 2 -- 1 file changed, 2 deletions(-) diff --git a/.github/scripts/ci_test_xpu.sh b/.github/scripts/ci_test_xpu.sh index 5f59eae00b..05089db7c8 100644 --- a/.github/scripts/ci_test_xpu.sh +++ b/.github/scripts/ci_test_xpu.sh @@ -14,8 +14,6 @@ python3 -c "import torch; import torchao; print(f'Torch version: {torch.__versio pip install pytest expecttest parameterized accelerate hf_transfer 'modelscope!=1.15.0' -pytest -v -s torchao/test/quantization/quantize_/workflows/int4/test_int4_plain_int32_tensor.py - pytest -v -s torchao/test/quantization/ pytest -v -s torchao/test/dtypes/ From f4eed651d84c4129924bd62c70a527fece96cb36 Mon Sep 17 00:00:00 2001 From: "Zeng, Xiangdong" Date: Fri, 14 Nov 2025 13:49:33 +0800 Subject: [PATCH 4/5] refine the xpu skip func --- torchao/testing/utils.py | 52 ++++++++++++++++++++++++++++------------ 1 file changed, 37 insertions(+), 15 deletions(-) diff --git a/torchao/testing/utils.py b/torchao/testing/utils.py index c00b48f558..bf91f7675d 100644 --- a/torchao/testing/utils.py +++ b/torchao/testing/utils.py @@ -98,22 +98,24 @@ def wrapper(*args, **kwargs): return decorator -def skip_if_no_xpu(message=None): - """Decorator to skip tests on ROCm platform with custom message. +def skip_if_no_xpu(): + try: + import pytest - Args: - message (str, optional): Additional information about why the test is skipped. - """ - import unittest + has_pytest = True + except ImportError: + has_pytest = False + import unittest def decorator(func): @functools.wraps(func) def wrapper(*args, **kwargs): if not torch.xpu.is_available(): - skip_message = "Skipping the test in XPU" - if message: - skip_message += f": {message}" - unittest.skip(skip_message) + skip_message = "No XPU available" + if has_pytest: + pytest.skip(skip_message) + else: + unittest.skip(skip_message) return func(*args, **kwargs) return wrapper @@ -123,19 +125,39 @@ def wrapper(*args, **kwargs): def skip_if_xpu(message=None): """ - Decorator to skip tests if XPU is available. + Decorator to skip tests on XPU platform with custom message. Args: message (str, optional): Additional information about why the test is skipped. """ + try: + import pytest + + has_pytest = True + except ImportError: + has_pytest = False + import unittest def decorator(func): - reason = "Skipping the test on XPU" - if message: - reason += f": {message}" + @functools.wraps(func) + def wrapper(*args, **kwargs): + if torch.xpu.is_available(): + skip_message = "Skipping the test in XPU" + if message: + skip_message += f": {message}" + if has_pytest: + pytest.skip(skip_message) + else: + unittest.skip(skip_message) + return func(*args, **kwargs) - return unittest.skipIf(torch.xpu.is_available(), reason)(func) + return wrapper + # Handle both @skip_if_xpu and @skip_if_xpu() syntax + if callable(message): + func = message + message = None + return decorator(func) return decorator From da10ddb8d0f7c599eaeec937e3ca0d7559297ae0 Mon Sep 17 00:00:00 2001 From: "Zhang, Liangang" Date: Tue, 18 Nov 2025 06:37:32 +0800 Subject: [PATCH 5/5] change auto_device_check to get_current_accelerator_device --- torchao/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchao/utils.py b/torchao/utils.py index 607d1b8542..875383a064 100644 --- a/torchao/utils.py +++ b/torchao/utils.py @@ -137,7 +137,7 @@ def get_available_devices(): return devices -def auto_detect_device(): +def get_current_accelerator_device(): if torch.accelerator.is_available(): return torch.accelerator.current_accelerator() else: