diff --git a/test/quantization/test_qat.py b/test/quantization/test_qat.py
index 27d6d8bb85..1ca0daee9a 100644
--- a/test/quantization/test_qat.py
+++ b/test/quantization/test_qat.py
@@ -96,6 +96,7 @@
     get_groupwise_affine_qparams,
     groupwise_affine_quantize_tensor,
 )
+from torchao.testing.utils import skip_if_xpu
 from torchao.utils import (
     _is_fbgemm_gpu_genai_available,
     get_current_accelerator_device,
@@ -695,10 +696,7 @@ def test_qat_4w_quantizer_gradients(self):
         self._test_qat_quantized_gradients(quantizer)
 
     @unittest.skipIf(_DEVICE is None, "skipping when GPU is not available")
-    @unittest.skipIf(
-        _DEVICE is torch.device("xpu"),
-        "skipped due to https://github.com/intel/torch-xpu-ops/issues/1770",
-    )
+    @skip_if_xpu("skipped due to https://github.com/intel/torch-xpu-ops/issues/1770")
     def test_qat_4w_quantizer(self):
         from torchao.quantization.GPTQ import Int4WeightOnlyQuantizer
         from torchao.quantization.qat import Int4WeightOnlyQATQuantizer
@@ -2015,6 +2013,7 @@ def test_quantize_api_int8_intx(self, weight_dtype, weight_granularity, dtype):
     )
 
     @unittest.skipIf(_DEVICE is None, "skipping when GPU is not available")
+    @skip_if_xpu("XPU enablement in progress")
     @parametrize(
         "weight_dtype, granularity, dtype, module_type",
         [
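The diff imports skip_if_xpu from torchao.testing.utils but does not show its body. For context, here is a minimal sketch of what such a test decorator might look like, assuming it raises unittest.SkipTest whenever an Intel XPU is the active accelerator; the actual helper in torchao/testing/utils.py may differ in signature and logic.

# Hypothetical sketch of a skip_if_xpu decorator; illustration only, not the
# real torchao implementation.
import functools
import unittest

import torch


def skip_if_xpu(reason):
    """Skip the decorated test when an Intel XPU device is available."""

    def decorator(test_func):
        @functools.wraps(test_func)
        def wrapper(*args, **kwargs):
            # torch.xpu exists on recent PyTorch builds; guard with getattr so
            # the decorator is a no-op on builds without XPU support.
            xpu = getattr(torch, "xpu", None)
            if xpu is not None and xpu.is_available():
                raise unittest.SkipTest(reason)
            return test_func(*args, **kwargs)

        return wrapper

    return decorator

Centralizing the check in one decorator, rather than repeating @unittest.skipIf(_DEVICE is torch.device("xpu"), ...) at each test, keeps the skip reason (e.g. the linked torch-xpu-ops issue) in a single argument and makes the XPU-specific skips easy to find and remove once enablement lands.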