Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion .github/scripts/ci_test_xpu.sh
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,14 @@ python3 -c "import torch; import torchao; print(f'Torch version: {torch.__versio

pip install pytest expecttest parameterized accelerate hf_transfer 'modelscope!=1.15.0'

pytest -v -s torchao/test/quantization/quantize_/workflows/int4/test_int4_plain_int32_tensor.py
pytest -v -s torchao/test/quantization/
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do these tests run on XPU?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These UTs will eventually run on XPU. The current TorchAO CI only covers the CUDA device and skips on XPU. We plan to first add these test folders to the XPU CI, which is triggered by ciflow/xpu, and then enable the XPU UT files one by one to reduce the review effort.


# Additional torchao test suites run on the XPU CI runner.
# Each suite is invoked separately so a failure in one folder does not
# prevent the remaining suites from being exercised in the same job.
pytest -v -s torchao/test/dtypes/

pytest -v -s torchao/test/float8/

pytest -v -s torchao/test/integration/test_integration.py

pytest -v -s torchao/test/prototype/

pytest -v -s torchao/test/test_ao_models.py
63 changes: 63 additions & 0 deletions torchao/testing/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,69 @@ def wrapper(*args, **kwargs):
return decorator


def skip_if_no_xpu():
try:
import pytest

has_pytest = True
except ImportError:
has_pytest = False
import unittest

def decorator(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
if not torch.xpu.is_available():
skip_message = "No XPU available"
if has_pytest:
pytest.skip(skip_message)
else:
unittest.skip(skip_message)
return func(*args, **kwargs)

return wrapper

return decorator


def skip_if_xpu(message=None):
"""
Decorator to skip tests on XPU platform with custom message.

Args:
message (str, optional): Additional information about why the test is skipped.
"""
try:
import pytest

has_pytest = True
except ImportError:
has_pytest = False
import unittest

def decorator(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
if torch.xpu.is_available():
skip_message = "Skipping the test in XPU"
if message:
skip_message += f": {message}"
if has_pytest:
pytest.skip(skip_message)
else:
unittest.skip(skip_message)
return func(*args, **kwargs)

return wrapper

# Handle both @skip_if_xpu and @skip_if_xpu() syntax
if callable(message):
func = message
message = None
return decorator(func)
return decorator


def skip_if_no_cuda():
import unittest

Expand Down
7 changes: 7 additions & 0 deletions torchao/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,13 @@ def get_available_devices():
return devices


def get_current_accelerator_device():
    """Return the current accelerator device, or ``None`` when no
    accelerator backend is available."""
    if not torch.accelerator.is_available():
        return None
    return torch.accelerator.current_accelerator()


def get_compute_capability():
if torch.cuda.is_available():
capability = torch.cuda.get_device_capability()
Expand Down
Loading