From 9c1de99e3ef56f8f0b75d76a6fb383e310286351 Mon Sep 17 00:00:00 2001 From: "Zeng, Xiangdong" Date: Tue, 18 Nov 2025 13:45:13 +0800 Subject: [PATCH 1/4] port 2 files to intel XPU --- test/quantization/test_moe_quant.py | 6 ++- test/quantization/test_qat.py | 57 +++++++++++++++-------------- 2 files changed, 34 insertions(+), 29 deletions(-) diff --git a/test/quantization/test_moe_quant.py b/test/quantization/test_moe_quant.py index 61000babc1..b973c7866e 100644 --- a/test/quantization/test_moe_quant.py +++ b/test/quantization/test_moe_quant.py @@ -33,7 +33,9 @@ quantize_, ) from torchao.quantization.utils import compute_error -from torchao.utils import is_sm_at_least_90 +from torchao.utils import is_sm_at_least_90, get_current_accelerator_device + +_DEVICE = get_current_accelerator_device() if torch.version.hip is not None: pytest.skip( @@ -54,7 +56,7 @@ def _test_impl_moe_quant( base_class=AffineQuantizedTensor, tensor_impl_class=None, dtype=torch.bfloat16, - device="cuda", + device=_DEVICE, fullgraph=False, ): """ diff --git a/test/quantization/test_qat.py b/test/quantization/test_qat.py index f523cb091c..4b36353013 100644 --- a/test/quantization/test_qat.py +++ b/test/quantization/test_qat.py @@ -100,10 +100,12 @@ _is_fbgemm_gpu_genai_available, is_fbcode, is_sm_at_least_89, + get_current_accelerator_device, ) # TODO: put this in a common test utils file -_CUDA_IS_AVAILABLE = torch.cuda.is_available() +_GPU_IS_AVAILABLE = torch.accelerator.is_available() +_DEVICE = get_current_accelerator_device() class Sub(torch.nn.Module): @@ -347,7 +349,7 @@ def _set_ptq_weight( group_size, ) q_weight = torch.ops.aten._convert_weight_to_int4pack( - q_weight.to("cuda"), + q_weight.to(_DEVICE), qat_linear.inner_k_tiles, ) ptq_linear.weight = q_weight @@ -600,13 +602,13 @@ def _assert_close_4w(self, val, ref): print(mean_err) self.assertTrue(mean_err < 0.05) - @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available") + @unittest.skipIf(not _GPU_IS_AVAILABLE, "skipping when gpu is not available") def test_qat_4w_primitives(self): n_bit = 4 group_size = 32 inner_k_tiles = 8 scales_precision = torch.bfloat16 - device = torch.device("cuda") + device = torch.device(_DEVICE) dtype = torch.bfloat16 torch.manual_seed(self.SEED) x = torch.randn(100, 256, dtype=dtype, device=device) @@ -651,13 +653,13 @@ def test_qat_4w_primitives(self): self._assert_close_4w(qat_out, ptq_out) - @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available") + @unittest.skipIf(not _GPU_IS_AVAILABLE, "skipping when gpu is not available") def test_qat_4w_linear(self): from torchao.quantization.GPTQ import WeightOnlyInt4Linear from torchao.quantization.qat.linear import Int4WeightOnlyQATLinear group_size = 128 - device = torch.device("cuda") + device = torch.device(_DEVICE) dtype = torch.bfloat16 torch.manual_seed(self.SEED) qat_linear = Int4WeightOnlyQATLinear( @@ -692,14 +694,14 @@ def test_qat_4w_quantizer_gradients(self): quantizer = Int4WeightOnlyQATQuantizer(groupsize=32, inner_k_tiles=8) self._test_qat_quantized_gradients(quantizer) - @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available") + @unittest.skipIf(not _GPU_IS_AVAILABLE, "skipping when gpu is not available") def test_qat_4w_quantizer(self): from torchao.quantization.GPTQ import Int4WeightOnlyQuantizer from torchao.quantization.qat import Int4WeightOnlyQATQuantizer group_size = 32 inner_k_tiles = 8 - device = torch.device("cuda") + device = torch.device(_DEVICE) dtype = torch.bfloat16 
torch.manual_seed(self.SEED) m = M().to(device).to(dtype) @@ -1870,6 +1872,7 @@ def test_float8_fake_quantize(self, granularity: Granularity): sqnr = compute_error(out, out_expected) self.assertGreater(sqnr, 16) + @unittest.skipIf(not _GPU_IS_AVAILABLE, "skipping when gpu is not available") def _test_quantize_api_against_ptq( self, base_config: AOBaseConfig, @@ -1891,12 +1894,12 @@ def _test_quantize_api_against_ptq( torch.manual_seed(self.SEED) if module_type == "linear": - m = M().to(dtype).cuda() - example_inputs = (m.example_inputs()[0].to(dtype).cuda(),) + m = M().to(dtype).to(_DEVICE) + example_inputs = (m.example_inputs()[0].to(dtype).to(_DEVICE),) filter_fn = lambda m, fqn: isinstance(m, torch.nn.Linear) elif module_type == "embedding": - m = M3().to(dtype).cuda() - example_inputs = (m.example_inputs()[0].cuda(),) + m = M3().to(dtype).to(_DEVICE) + example_inputs = (m.example_inputs()[0].to(_DEVICE),) filter_fn = lambda m, fqn: isinstance(m, torch.nn.Embedding) else: raise ValueError(f"Unknown module type {module_type}") @@ -1919,7 +1922,7 @@ def _test_quantize_api_against_ptq( self.assertGreaterEqual(convert_sqnr, target_convert_sqnr) @parametrize("granularity", [PerTensor(), PerRow()]) - @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available") + @unittest.skipIf(not _GPU_IS_AVAILABLE, "skipping when gpu is not available") @unittest.skipIf(not is_sm_at_least_89(), "Need sm89+") def test_quantize_api_fp8_fp8(self, granularity: Granularity): """ @@ -1933,7 +1936,7 @@ def test_quantize_api_fp8_fp8(self, granularity: Granularity): target_convert_sqnr=float("inf"), ) - @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available") + @unittest.skipIf(not _GPU_IS_AVAILABLE, "skipping when gpu is not available") @unittest.skipIf(not is_sm_at_least_89(), "Need sm89+") @unittest.skipIf( not _is_fbgemm_gpu_genai_available(), "Requires fbgemm-gpu-genai >= 1.2.0" @@ -1950,7 +1953,7 @@ def test_quantize_api_fp8_int4(self): target_convert_sqnr=float("inf"), ) - @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available") + @unittest.skipIf(not _GPU_IS_AVAILABLE, "skipping when gpu is not available") @unittest.skipIf( not _is_fbgemm_gpu_genai_available(), "Requires fbgemm-gpu-genai >= 1.2.0" ) @@ -1971,7 +1974,7 @@ def test_quantize_api_int4(self, version: int, packing_format: Int4PackingFormat target_convert_sqnr=float("inf"), ) - @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available") + @unittest.skipIf(not _GPU_IS_AVAILABLE, "skipping when gpu is not available") def test_quantize_api_int8_int4(self): """ Test the following: @@ -1984,7 +1987,7 @@ def test_quantize_api_int8_int4(self): target_convert_sqnr=float("inf"), ) - @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available") + @unittest.skipIf(not _GPU_IS_AVAILABLE, "skipping when gpu is not available") @parametrize( "weight_dtype, weight_granularity, dtype", [ @@ -2009,7 +2012,7 @@ def test_quantize_api_int8_intx(self, weight_dtype, weight_granularity, dtype): dtype=dtype, ) - @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available") + @unittest.skipIf(not _GPU_IS_AVAILABLE, "skipping when gpu is not available") @parametrize( "weight_dtype, granularity, dtype, module_type", [ @@ -2092,7 +2095,7 @@ def test_quantize_api_nvfp4(self, use_per_tensor_scale: bool): ) @unittest.skipIf(not is_sm_at_least_89(), "Need sm89+") - @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available") + 
@unittest.skipIf(not _GPU_IS_AVAILABLE, "skipping when gpu is not available") @parametrize("use_per_tensor_scale", [True, False]) def test_qat_nvfp4(self, use_per_tensor_scale: bool): """ @@ -2102,7 +2105,7 @@ def test_qat_nvfp4(self, use_per_tensor_scale: bool): from torchao.prototype.qat import NVFP4FakeQuantizeConfig torch.manual_seed(self.SEED) - m = M().cuda() + m = M().to(_DEVICE) baseline_model = copy.deepcopy(m) quantize_( baseline_model, @@ -2117,13 +2120,13 @@ def test_qat_nvfp4(self, use_per_tensor_scale: bool): # Compare prepared values torch.manual_seed(self.SEED) - x = m.example_inputs("cuda") + x = m.example_inputs(_DEVICE) out = m(*x) baseline_out = baseline_model(*x) sqnr = compute_error(out, baseline_out).item() self.assertGreaterEqual(sqnr, float("inf")) - @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available") + @unittest.skipIf(not _GPU_IS_AVAILABLE, "skipping when gpu is not available") @unittest.skipIf( not _is_fbgemm_gpu_genai_available(), "Requires fbgemm-gpu-genai >= 1.2.0" ) @@ -2141,7 +2144,7 @@ def test_fbgemm_fp8_primitives(self): _quantize_affine_float8, ) - x1 = torch.randn([128, 256], dtype=torch.bfloat16).cuda() + x1 = torch.randn([128, 256], dtype=torch.bfloat16).to(_DEVICE) x2 = copy.deepcopy(x1) # (1) Just call `quantize_fp8_row` @@ -2163,7 +2166,7 @@ def test_fbgemm_fp8_primitives(self): self.assertGreater(sqnr, 40) self.assertGreater(scale_sqnr, 50) - @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available") + @unittest.skipIf(not _GPU_IS_AVAILABLE, "skipping when gpu is not available") @unittest.skipIf( not _is_fbgemm_gpu_genai_available(), "Requires fbgemm-gpu-genai >= 1.2.0" ) @@ -2188,7 +2191,7 @@ def test_fbgemm_fp8_int4_preshuffled_primitives(self): ) group_size = 128 - x1 = torch.randn([128, 256], dtype=torch.bfloat16).cuda() + x1 = torch.randn([128, 256], dtype=torch.bfloat16).to(_DEVICE) x2 = copy.deepcopy(x1) x3 = copy.deepcopy(x1) @@ -2245,7 +2248,7 @@ def shuffle_and_pack(t: torch.Tensor, scale: torch.Tensor) -> torch.Tensor: ) self.assertGreater(sqnr_q1_q3_preshuffle, 32) - @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available") + @unittest.skipIf(not _GPU_IS_AVAILABLE, "skipping when gpu is not available") @unittest.skipIf( not _is_fbgemm_gpu_genai_available(), "Requires fbgemm-gpu-genai >= 1.2.0" ) @@ -2263,7 +2266,7 @@ def test_fbgemm_int4_weight_only_primitives(self): ) group_size = 128 - x1 = torch.randn([128, 256], dtype=torch.bfloat16).cuda() + x1 = torch.randn([128, 256], dtype=torch.bfloat16).to(_DEVICE) x2 = copy.deepcopy(x1) x3 = copy.deepcopy(x1) From c5cdd7ad597fa29cf4830c4895c94a63b0b175ad Mon Sep 17 00:00:00 2001 From: "Zeng, Xiangdong" Date: Tue, 18 Nov 2025 14:14:06 +0800 Subject: [PATCH 2/4] port 2 files to intel XPU --- test/quantization/test_moe_quant.py | 6 +- test/quantization/test_qat.py | 53 ++++---- test/quantization/test_quant_api.py | 196 ++++++++++------------------ 3 files changed, 102 insertions(+), 153 deletions(-) diff --git a/test/quantization/test_moe_quant.py b/test/quantization/test_moe_quant.py index b973c7866e..61000babc1 100644 --- a/test/quantization/test_moe_quant.py +++ b/test/quantization/test_moe_quant.py @@ -33,9 +33,7 @@ quantize_, ) from torchao.quantization.utils import compute_error -from torchao.utils import is_sm_at_least_90, get_current_accelerator_device - -_DEVICE = get_current_accelerator_device() +from torchao.utils import is_sm_at_least_90 if torch.version.hip is not None: pytest.skip( @@ -56,7 +54,7 @@ def 
_test_impl_moe_quant( base_class=AffineQuantizedTensor, tensor_impl_class=None, dtype=torch.bfloat16, - device=_DEVICE, + device="cuda", fullgraph=False, ): """ diff --git a/test/quantization/test_qat.py b/test/quantization/test_qat.py index 4b36353013..1aa8f85f93 100644 --- a/test/quantization/test_qat.py +++ b/test/quantization/test_qat.py @@ -98,12 +98,13 @@ ) from torchao.utils import ( _is_fbgemm_gpu_genai_available, + get_current_accelerator_device, is_fbcode, is_sm_at_least_89, - get_current_accelerator_device, ) # TODO: put this in a common test utils file +_CUDA_IS_AVAILABLE = torch.cuda.is_available() _GPU_IS_AVAILABLE = torch.accelerator.is_available() _DEVICE = get_current_accelerator_device() @@ -602,7 +603,9 @@ def _assert_close_4w(self, val, ref): print(mean_err) self.assertTrue(mean_err < 0.05) - @unittest.skipIf(not _GPU_IS_AVAILABLE, "skipping when gpu is not available") + @unittest.skipIf( + not _GPU_IS_AVAILABLE, "skipping when cuda or xpu is not available" + ) def test_qat_4w_primitives(self): n_bit = 4 group_size = 32 @@ -653,13 +656,13 @@ def test_qat_4w_primitives(self): self._assert_close_4w(qat_out, ptq_out) - @unittest.skipIf(not _GPU_IS_AVAILABLE, "skipping when gpu is not available") + @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available") def test_qat_4w_linear(self): from torchao.quantization.GPTQ import WeightOnlyInt4Linear from torchao.quantization.qat.linear import Int4WeightOnlyQATLinear group_size = 128 - device = torch.device(_DEVICE) + device = torch.device("cuda") dtype = torch.bfloat16 torch.manual_seed(self.SEED) qat_linear = Int4WeightOnlyQATLinear( @@ -694,7 +697,7 @@ def test_qat_4w_quantizer_gradients(self): quantizer = Int4WeightOnlyQATQuantizer(groupsize=32, inner_k_tiles=8) self._test_qat_quantized_gradients(quantizer) - @unittest.skipIf(not _GPU_IS_AVAILABLE, "skipping when gpu is not available") + @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available") def test_qat_4w_quantizer(self): from torchao.quantization.GPTQ import Int4WeightOnlyQuantizer from torchao.quantization.qat import Int4WeightOnlyQATQuantizer @@ -706,6 +709,7 @@ def test_qat_4w_quantizer(self): torch.manual_seed(self.SEED) m = M().to(device).to(dtype) m2 = copy.deepcopy(m) + qat_quantizer = Int4WeightOnlyQATQuantizer( groupsize=group_size, inner_k_tiles=inner_k_tiles, @@ -1872,7 +1876,6 @@ def test_float8_fake_quantize(self, granularity: Granularity): sqnr = compute_error(out, out_expected) self.assertGreater(sqnr, 16) - @unittest.skipIf(not _GPU_IS_AVAILABLE, "skipping when gpu is not available") def _test_quantize_api_against_ptq( self, base_config: AOBaseConfig, @@ -1894,12 +1897,12 @@ def _test_quantize_api_against_ptq( torch.manual_seed(self.SEED) if module_type == "linear": - m = M().to(dtype).to(_DEVICE) - example_inputs = (m.example_inputs()[0].to(dtype).to(_DEVICE),) + m = M().to(dtype).cuda() + example_inputs = (m.example_inputs()[0].to(dtype).cuda(),) filter_fn = lambda m, fqn: isinstance(m, torch.nn.Linear) elif module_type == "embedding": - m = M3().to(dtype).to(_DEVICE) - example_inputs = (m.example_inputs()[0].to(_DEVICE),) + m = M3().to(dtype).cuda() + example_inputs = (m.example_inputs()[0].cuda(),) filter_fn = lambda m, fqn: isinstance(m, torch.nn.Embedding) else: raise ValueError(f"Unknown module type {module_type}") @@ -1922,7 +1925,7 @@ def _test_quantize_api_against_ptq( self.assertGreaterEqual(convert_sqnr, target_convert_sqnr) @parametrize("granularity", [PerTensor(), PerRow()]) - @unittest.skipIf(not 
_GPU_IS_AVAILABLE, "skipping when gpu is not available") + @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available") @unittest.skipIf(not is_sm_at_least_89(), "Need sm89+") def test_quantize_api_fp8_fp8(self, granularity: Granularity): """ @@ -1936,7 +1939,7 @@ def test_quantize_api_fp8_fp8(self, granularity: Granularity): target_convert_sqnr=float("inf"), ) - @unittest.skipIf(not _GPU_IS_AVAILABLE, "skipping when gpu is not available") + @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available") @unittest.skipIf(not is_sm_at_least_89(), "Need sm89+") @unittest.skipIf( not _is_fbgemm_gpu_genai_available(), "Requires fbgemm-gpu-genai >= 1.2.0" @@ -1953,7 +1956,7 @@ def test_quantize_api_fp8_int4(self): target_convert_sqnr=float("inf"), ) - @unittest.skipIf(not _GPU_IS_AVAILABLE, "skipping when gpu is not available") + @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available") @unittest.skipIf( not _is_fbgemm_gpu_genai_available(), "Requires fbgemm-gpu-genai >= 1.2.0" ) @@ -1974,7 +1977,7 @@ def test_quantize_api_int4(self, version: int, packing_format: Int4PackingFormat target_convert_sqnr=float("inf"), ) - @unittest.skipIf(not _GPU_IS_AVAILABLE, "skipping when gpu is not available") + @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available") def test_quantize_api_int8_int4(self): """ Test the following: @@ -1987,7 +1990,7 @@ def test_quantize_api_int8_int4(self): target_convert_sqnr=float("inf"), ) - @unittest.skipIf(not _GPU_IS_AVAILABLE, "skipping when gpu is not available") + @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available") @parametrize( "weight_dtype, weight_granularity, dtype", [ @@ -2012,7 +2015,7 @@ def test_quantize_api_int8_intx(self, weight_dtype, weight_granularity, dtype): dtype=dtype, ) - @unittest.skipIf(not _GPU_IS_AVAILABLE, "skipping when gpu is not available") + @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available") @parametrize( "weight_dtype, granularity, dtype, module_type", [ @@ -2095,7 +2098,7 @@ def test_quantize_api_nvfp4(self, use_per_tensor_scale: bool): ) @unittest.skipIf(not is_sm_at_least_89(), "Need sm89+") - @unittest.skipIf(not _GPU_IS_AVAILABLE, "skipping when gpu is not available") + @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available") @parametrize("use_per_tensor_scale", [True, False]) def test_qat_nvfp4(self, use_per_tensor_scale: bool): """ @@ -2105,7 +2108,7 @@ def test_qat_nvfp4(self, use_per_tensor_scale: bool): from torchao.prototype.qat import NVFP4FakeQuantizeConfig torch.manual_seed(self.SEED) - m = M().to(_DEVICE) + m = M().cuda() baseline_model = copy.deepcopy(m) quantize_( baseline_model, @@ -2120,13 +2123,13 @@ def test_qat_nvfp4(self, use_per_tensor_scale: bool): # Compare prepared values torch.manual_seed(self.SEED) - x = m.example_inputs(_DEVICE) + x = m.example_inputs("cuda") out = m(*x) baseline_out = baseline_model(*x) sqnr = compute_error(out, baseline_out).item() self.assertGreaterEqual(sqnr, float("inf")) - @unittest.skipIf(not _GPU_IS_AVAILABLE, "skipping when gpu is not available") + @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available") @unittest.skipIf( not _is_fbgemm_gpu_genai_available(), "Requires fbgemm-gpu-genai >= 1.2.0" ) @@ -2144,7 +2147,7 @@ def test_fbgemm_fp8_primitives(self): _quantize_affine_float8, ) - x1 = torch.randn([128, 256], dtype=torch.bfloat16).to(_DEVICE) + x1 = torch.randn([128, 256], dtype=torch.bfloat16).cuda() 
x2 = copy.deepcopy(x1) # (1) Just call `quantize_fp8_row` @@ -2166,7 +2169,7 @@ def test_fbgemm_fp8_primitives(self): self.assertGreater(sqnr, 40) self.assertGreater(scale_sqnr, 50) - @unittest.skipIf(not _GPU_IS_AVAILABLE, "skipping when gpu is not available") + @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available") @unittest.skipIf( not _is_fbgemm_gpu_genai_available(), "Requires fbgemm-gpu-genai >= 1.2.0" ) @@ -2191,7 +2194,7 @@ def test_fbgemm_fp8_int4_preshuffled_primitives(self): ) group_size = 128 - x1 = torch.randn([128, 256], dtype=torch.bfloat16).to(_DEVICE) + x1 = torch.randn([128, 256], dtype=torch.bfloat16).cuda() x2 = copy.deepcopy(x1) x3 = copy.deepcopy(x1) @@ -2248,7 +2251,7 @@ def shuffle_and_pack(t: torch.Tensor, scale: torch.Tensor) -> torch.Tensor: ) self.assertGreater(sqnr_q1_q3_preshuffle, 32) - @unittest.skipIf(not _GPU_IS_AVAILABLE, "skipping when gpu is not available") + @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available") @unittest.skipIf( not _is_fbgemm_gpu_genai_available(), "Requires fbgemm-gpu-genai >= 1.2.0" ) @@ -2266,7 +2269,7 @@ def test_fbgemm_int4_weight_only_primitives(self): ) group_size = 128 - x1 = torch.randn([128, 256], dtype=torch.bfloat16).to(_DEVICE) + x1 = torch.randn([128, 256], dtype=torch.bfloat16).cuda() x2 = copy.deepcopy(x1) x3 = copy.deepcopy(x1) diff --git a/test/quantization/test_quant_api.py b/test/quantization/test_quant_api.py index e530babdb9..4c5f22d386 100644 --- a/test/quantization/test_quant_api.py +++ b/test/quantization/test_quant_api.py @@ -64,14 +64,17 @@ ) from torchao.quantization.quant_primitives import MappingType from torchao.quantization.utils import compute_error -from torchao.testing.utils import skip_if_rocm +from torchao.testing.utils import skip_if_rocm, skip_if_xpu from torchao.utils import ( + get_current_accelerator_device, is_sm_at_least_89, is_sm_at_least_90, torch_version_at_least, unwrap_tensor_subclass, ) +_DEVICE = get_current_accelerator_device() + try: import gemlite # noqa: F401 @@ -240,7 +243,7 @@ def api(model): torch.testing.assert_close(ref, res.cpu()) - @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") + @unittest.skipIf(not torch.accelerator.is_available(), "Need GPU available") def test_int8_wo_quant_save_load(self): m = ToyLinearModel().eval().cpu() @@ -261,8 +264,8 @@ def api(model): api(m2) m2.load_state_dict(state_dict) - m2 = m2.to(device="cuda") - example_inputs = map(lambda x: x.cuda(), example_inputs) + m2 = m2.to(_DEVICE) + example_inputs = map(lambda x: x.to(_DEVICE), example_inputs) res = m2(*example_inputs) # TODO: figure out why ROCm has a larger error @@ -294,12 +297,13 @@ def test_8da4w_quantizer_linear_bias(self): m(*example_inputs) @unittest.skip("skipping until we get checkpoints for gpt-fast") + @unittest.skipIf(not torch.accelerator.is_available(), "Need GPU available") def test_quantizer_int4_weight_only(self): from torchao._models._eval import TransformerEvalWrapper from torchao.quantization.linear_quant_modules import Int4WeightOnlyQuantizer precision = torch.bfloat16 - device = "cuda" + device = _DEVICE checkpoint_path = Path("../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth") model = Transformer.from_name(checkpoint_path.parent.name) checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True) @@ -316,7 +320,7 @@ def test_quantizer_int4_weight_only(self): quantizer = Int4WeightOnlyQuantizer( groupsize, ) - model = quantizer.quantize(model).cuda() + model = 
quantizer.quantize(model).to(_DEVICE) result = TransformerEvalWrapper( model, tokenizer, @@ -332,11 +336,12 @@ def test_quantizer_int4_weight_only(self): ) @unittest.skip("skipping until we get checkpoints for gpt-fast") + @unittest.skipIf(not torch.accelerator.is_available(), "Need GPU available") def test_eval_wrapper(self): from torchao._models._eval import TransformerEvalWrapper precision = torch.bfloat16 - device = "cuda" + device = _DEVICE checkpoint_path = Path("../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth") model = Transformer.from_name(checkpoint_path.parent.name) checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True) @@ -365,11 +370,12 @@ def test_eval_wrapper(self): # EVAL IS CURRENTLY BROKEN FOR LLAMA 3, VERY LOW ACCURACY @unittest.skip("skipping until we get checkpoints for gpt-fast") + @unittest.skipIf(not torch.accelerator.is_available(), "Need GPU available") def test_eval_wrapper_llama3(self): from torchao._models._eval import TransformerEvalWrapper precision = torch.bfloat16 - device = "cuda" + device = _DEVICE checkpoint_path = Path( ".../gpt-fast/checkpoints/meta-llama/Meta-Llama-3-8B/model.pth" ) @@ -438,7 +444,7 @@ def test_quantized_tensor_subclass_8da4w(self, mapping_type): ref = m_copy(*example_inputs) self.assertTrue(torch.equal(res, ref)) - @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") + @unittest.skipIf(not torch.accelerator.is_available(), "Need GPU available") def test_quantized_tensor_subclass_save_load(self): m = ToyLinearModel().eval().to(torch.bfloat16) m_copy = copy.deepcopy(m) @@ -456,7 +462,7 @@ def test_quantized_tensor_subclass_save_load(self): res = m_copy(*example_inputs) self.assertEqual(res, ref) - @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") + @unittest.skipIf(not torch.accelerator.is_available(), "Need GPU available") def test_int8wo_quantized_model_to_device(self): m = ToyLinearModel().eval().to(torch.bfloat16) example_inputs = m.example_inputs(dtype=torch.bfloat16, device="cpu") @@ -464,15 +470,15 @@ def test_int8wo_quantized_model_to_device(self): quantize_(m, Int8WeightOnlyConfig()) ref = m(*example_inputs) - example_inputs_cuda = (example_inputs[0].to("cuda"),) - m.to(device="cuda") + example_inputs_cuda = (example_inputs[0].to(_DEVICE),) + m.to(_DEVICE) cuda_res = m(*example_inputs_cuda) self.assertEqual(cuda_res.cpu(), ref) - @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") + @unittest.skipIf(not torch.accelerator.is_available(), "Need GPU available") def test_quantized_tensor_subclass_save_load_map_location(self): - m = ToyLinearModel().eval().to(dtype=torch.bfloat16, device="cuda") - example_inputs = m.example_inputs(dtype=torch.bfloat16, device="cuda") + m = ToyLinearModel().eval().to(dtype=torch.bfloat16, device=_DEVICE) + example_inputs = m.example_inputs(dtype=torch.bfloat16, device=_DEVICE) quantize_(m, Int8WeightOnlyConfig()) ref = m(*example_inputs) @@ -485,31 +491,33 @@ def test_quantized_tensor_subclass_save_load_map_location(self): m_copy = ToyLinearModel().eval() m_copy.load_state_dict(state_dict, assign=True) - m_copy.to(dtype=torch.bfloat16, device="cuda") + m_copy.to(dtype=torch.bfloat16, device=_DEVICE) res = m_copy(*example_inputs) self.assertEqual(res, ref) - @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") + @unittest.skipIf(not torch.accelerator.is_available(), "Need GPU available") def test_quantized_model_streaming(self): + device_module = torch.get_device_module(_DEVICE) + def 
reset_memory(): gc.collect() - torch.cuda.empty_cache() - torch.cuda.reset_peak_memory_stats() + device_module.empty_cache() + device_module.reset_peak_memory_stats() reset_memory() m = ToyLinearModel() - quantize_(m.to(device="cuda"), Int8WeightOnlyConfig()) - memory_baseline = torch.cuda.max_memory_allocated() + quantize_(m.to(device=_DEVICE), Int8WeightOnlyConfig()) + memory_baseline = device_module.max_memory_allocated() del m reset_memory() m = ToyLinearModel() - quantize_(m, Int8WeightOnlyConfig(), device="cuda") - memory_streaming = torch.cuda.max_memory_allocated() + quantize_(m, Int8WeightOnlyConfig(), device=_DEVICE) + memory_streaming = device_module.max_memory_allocated() for param in m.parameters(): - assert param.is_cuda + assert param.device.type == _DEVICE.type self.assertLess(memory_streaming, memory_baseline) @common_utils.parametrize("dtype", [torch.float, torch.bfloat16, torch.half]) @@ -538,7 +546,7 @@ def test_int4wo_cpu(self, dtype, x_dim, use_hqq): assert "aten.mm.default" not in code[0] # TODO(#1690): move to new config names - @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") + @unittest.skipIf(not torch.accelerator.is_available(), "Need GPU available") @common_utils.parametrize( "config", [ @@ -555,6 +563,7 @@ def test_int4wo_cpu(self, dtype, x_dim, use_hqq): UIntXWeightOnlyConfig(dtype=torch.uint4), ], ) + @skip_if_xpu("XPU enablement in progress") @skip_if_rocm("ROCm enablement in progress") def test_workflow_e2e_numerics(self, config): """ @@ -583,17 +592,17 @@ def test_workflow_e2e_numerics(self, config): # scale has to be moved to cuda here because the parametrization init # code happens before gating for cuda availability if isinstance(config, Float8StaticActivationFloat8WeightConfig): - config.scale = config.scale.to("cuda") + config.scale = config.scale.to(_DEVICE) dtype = torch.bfloat16 if isinstance(config, GemliteUIntXWeightOnlyConfig): dtype = torch.float16 # set up inputs - x = torch.randn(128, 128, device="cuda", dtype=dtype) + x = torch.randn(128, 128, device=_DEVICE, dtype=dtype) # TODO(future): model in float32 leads to error: https://gist.github.com/vkuzo/63b3bcd7818393021a6e3fb4ccf3c469 # is that expected? 
- m_ref = torch.nn.Sequential(torch.nn.Linear(128, 128)).cuda().to(dtype) + m_ref = torch.nn.Sequential(torch.nn.Linear(128, 128)).to(_DEVICE).to(dtype) m_q = copy.deepcopy(m_ref) # quantize @@ -606,13 +615,13 @@ def test_workflow_e2e_numerics(self, config): sqnr = compute_error(y_ref, y_q) assert sqnr >= 16.5, f"SQNR {sqnr} is too low" - @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") + @unittest.skipIf(not torch.accelerator.is_available(), "Need GPU available") def test_module_fqn_to_config_default(self): config1 = Int4WeightOnlyConfig(group_size=32, version=1) config2 = Int8WeightOnlyConfig() config = ModuleFqnToConfig({"_default": config1, "linear2": config2}) - model = ToyLinearModel().cuda().to(dtype=torch.bfloat16) - example_inputs = model.example_inputs(device="cuda", dtype=torch.bfloat16) + model = ToyLinearModel().to(_DEVICE).to(dtype=torch.bfloat16) + example_inputs = model.example_inputs(device=_DEVICE, dtype=torch.bfloat16) quantize_(model, config, filter_fn=None) model(*example_inputs) assert isinstance(model.linear1.weight, AffineQuantizedTensor) @@ -620,13 +629,13 @@ def test_module_fqn_to_config_default(self): assert isinstance(model.linear2.weight, AffineQuantizedTensor) assert isinstance(model.linear2.weight._layout, PlainLayout) - @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") + @unittest.skipIf(not torch.accelerator.is_available(), "Need GPU available") def test_module_fqn_to_config_module_name(self): config1 = Int4WeightOnlyConfig(group_size=32, version=1) config2 = Int8WeightOnlyConfig() config = ModuleFqnToConfig({"linear1": config1, "linear2": config2}) - model = ToyLinearModel().cuda().to(dtype=torch.bfloat16) - example_inputs = model.example_inputs(device="cuda", dtype=torch.bfloat16) + model = ToyLinearModel().to(_DEVICE).to(dtype=torch.bfloat16) + example_inputs = model.example_inputs(device=_DEVICE, dtype=torch.bfloat16) quantize_(model, config, filter_fn=None) model(*example_inputs) assert isinstance(model.linear1.weight, AffineQuantizedTensor) @@ -759,25 +768,25 @@ def test_module_fqn_to_config_embedding_linear(self): assert isinstance(model.emb.weight, IntxUnpackedToInt8Tensor) assert isinstance(model.linear.weight, IntxUnpackedToInt8Tensor) - @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") + @unittest.skipIf(not torch.accelerator.is_available(), "Need GPU available") def test_module_fqn_to_config_skip(self): config1 = Int4WeightOnlyConfig(group_size=32, version=1) config = ModuleFqnToConfig({"_default": config1, "linear2": None}) - model = ToyLinearModel().cuda().to(dtype=torch.bfloat16) - example_inputs = model.example_inputs(device="cuda", dtype=torch.bfloat16) + model = ToyLinearModel().to(_DEVICE).to(dtype=torch.bfloat16) + example_inputs = model.example_inputs(device=_DEVICE, dtype=torch.bfloat16) quantize_(model, config, filter_fn=None) model(*example_inputs) assert isinstance(model.linear1.weight, AffineQuantizedTensor) assert isinstance(model.linear1.weight._layout, TensorCoreTiledLayout) assert not isinstance(model.linear2.weight, AffineQuantizedTensor) - @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") + @unittest.skipIf(not torch.accelerator.is_available(), "Need GPU available") def test_int4wo_cuda_serialization(self): config = Int4WeightOnlyConfig(group_size=32, version=1) - model = ToyLinearModel().cuda().to(dtype=torch.bfloat16) + model = ToyLinearModel().to(_DEVICE).to(dtype=torch.bfloat16) # quantize in cuda quantize_(model, config) - example_inputs 
= model.example_inputs(device="cuda", dtype=torch.bfloat16) + example_inputs = model.example_inputs(device=_DEVICE, dtype=torch.bfloat16) model(*example_inputs) with tempfile.NamedTemporaryFile() as ckpt: # save checkpoint in cuda @@ -786,7 +795,7 @@ def test_int4wo_cuda_serialization(self): # This is what torchtune does: https://github.com/pytorch/torchtune/blob/v0.6.1/torchtune/training/checkpointing/_utils.py#L253 sd = torch.load(ckpt.name, weights_only=False, map_location="cpu") for k, v in sd.items(): - sd[k] = v.to("cuda") + sd[k] = v.to(_DEVICE) # load state_dict in cuda model.load_state_dict(sd, assign=True) @@ -826,31 +835,25 @@ def test_config_deprecation(self): uintx_weight_only: (torch.uint4,), } - # Call each deprecated API twice - for cls, args in deprecated_apis_to_args.items(): - with warnings.catch_warnings(record=True) as _warnings: + with warnings.catch_warnings(record=True) as _warnings: + # Call each deprecated API twice + for cls, args in deprecated_apis_to_args.items(): cls(*args) cls(*args) - # Each call should have at least one warning. - # Some of them can have two warnings - one for deprecation, - # one for moving to prototype - # 1 warning - just deprecation - # 2 warnings - deprecation and prototype warnings - self.assertTrue(len(_warnings) in (1, 2)) - found_deprecated = False - for w in _warnings: - if "is deprecated and will be removed in a future release" in str( - w.message - ): - found_deprecated = True - self.assertTrue(found_deprecated) + # Each call should trigger the warning only once + self.assertEqual(len(_warnings), len(deprecated_apis_to_args)) + for w in _warnings: + self.assertIn( + "is deprecated and will be removed in a future release", + str(w.message), + ) common_utils.instantiate_parametrized_tests(TestQuantFlow) -@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") +@unittest.skipIf(not torch.accelerator.is_available(), "Need CUDA available") @unittest.skipIf(not is_sm_at_least_90(), "Checkpoints are produced in SM90+") class TestFqnToConfig(TestCase): def test_quantize_param_fqn_exact(self): @@ -860,7 +863,7 @@ def test_quantize_param_fqn_exact(self): config = AutoConfig.from_pretrained( "unsloth/Llama-4-Scout-17B-16E-Instruct" ).text_config - model = Llama4TextMoe(config).to(torch.bfloat16).cuda() + model = Llama4TextMoe(config).to(torch.bfloat16).to(_DEVICE) quant_config = FqnToConfig( { @@ -1105,86 +1108,31 @@ def test_non_fqn_config_filter_fn_none(self): assert isinstance(model.weight, Float8Tensor) assert model.weight.scale.numel() == 1 - @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") + @unittest.skipIf(not torch.accelerator.is_available(), "Need GPU available") def test_quantized_model_streaming_fqn_config(self): + device_module = torch.get_device_module(_DEVICE) + def reset_memory(): gc.collect() - torch.cuda.empty_cache() - torch.cuda.reset_peak_memory_stats() + device_module.empty_cache() + device_module.reset_peak_memory_stats() quant_config = FqnToConfig({"_default": Int8WeightOnlyConfig()}) reset_memory() m = ToyLinearModel() - quantize_(m.to(device="cuda"), quant_config, filter_fn=None) - memory_baseline = torch.cuda.max_memory_allocated() + quantize_(m.to(device=_DEVICE), quant_config, filter_fn=None) + memory_baseline = device_module.max_memory_allocated() del m reset_memory() m = ToyLinearModel() - quantize_(m, quant_config, device="cuda", filter_fn=None) - memory_streaming = torch.cuda.max_memory_allocated() + quantize_(m, quant_config, device=_DEVICE, filter_fn=None) + 
memory_streaming = device_module.max_memory_allocated() for param in m.parameters(): - assert param.is_cuda + assert param.device.type == _DEVICE.type self.assertLess(memory_streaming, memory_baseline) - @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") - def test_fqn_config_quantized_nested_module(self): - class NestedModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.linear = torch.nn.Linear(16, 16) - - class TopLevelModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.nested = NestedModule() - self.linear1 = torch.nn.Linear(16, 16) - - m = TopLevelModule() - quant_config = FqnToConfig( - { - "nested.linear": Int8WeightOnlyConfig(), - "linear1": Int8WeightOnlyConfig(), - } - ) - quantize_(m, quant_config, filter_fn=None) - - assert isinstance(m.nested.linear.weight, AffineQuantizedTensor) - assert isinstance(m.linear1.weight, AffineQuantizedTensor) - - @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") - def test_fqn_config_quantized_nested_module_param(self): - class NestedModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.linear = torch.nn.Linear(16, 16) - - class TopLevelModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.nested = NestedModule() - self.linear1 = torch.nn.Linear(16, 16) - - m = TopLevelModule() - quant_config = FqnToConfig( - { - "nested.linear.weight": Int8WeightOnlyConfig(), - "linear1.weight": Int8WeightOnlyConfig(), - } - ) - quantize_(m, quant_config, filter_fn=None) - - assert isinstance(m.nested.linear.weight, AffineQuantizedTensor) - assert isinstance(m.linear1.weight, AffineQuantizedTensor) - - def test_fqn_config_module_config_and_fqn_config_both_specified(self): - with self.assertRaises(ValueError): - FqnToConfig( - fqn_to_config={"test": Float8WeightOnlyConfig()}, - module_fqn_to_config={"test2": Float8WeightOnlyConfig()}, - ) - if __name__ == "__main__": unittest.main() From 5b76453001e4ea9fa5ef9da12a8661f0f0b94500 Mon Sep 17 00:00:00 2001 From: "Zeng, Xiangdong" Date: Tue, 18 Nov 2025 14:20:42 +0800 Subject: [PATCH 3/4] update --- test/quantization/test_qat.py | 5 +- test/quantization/test_quant_api.py | 83 +++++++++++++++++++++++++---- 2 files changed, 74 insertions(+), 14 deletions(-) diff --git a/test/quantization/test_qat.py b/test/quantization/test_qat.py index 1aa8f85f93..60a4d13867 100644 --- a/test/quantization/test_qat.py +++ b/test/quantization/test_qat.py @@ -603,9 +603,7 @@ def _assert_close_4w(self, val, ref): print(mean_err) self.assertTrue(mean_err < 0.05) - @unittest.skipIf( - not _GPU_IS_AVAILABLE, "skipping when cuda or xpu is not available" - ) + @unittest.skipIf(not _GPU_IS_AVAILABLE, "skipping when gpu is not available") def test_qat_4w_primitives(self): n_bit = 4 group_size = 32 @@ -709,7 +707,6 @@ def test_qat_4w_quantizer(self): torch.manual_seed(self.SEED) m = M().to(device).to(dtype) m2 = copy.deepcopy(m) - qat_quantizer = Int4WeightOnlyQATQuantizer( groupsize=group_size, inner_k_tiles=inner_k_tiles, diff --git a/test/quantization/test_quant_api.py b/test/quantization/test_quant_api.py index 4c5f22d386..dc74948c40 100644 --- a/test/quantization/test_quant_api.py +++ b/test/quantization/test_quant_api.py @@ -835,19 +835,25 @@ def test_config_deprecation(self): uintx_weight_only: (torch.uint4,), } - with warnings.catch_warnings(record=True) as _warnings: - # Call each deprecated API twice - for cls, args in deprecated_apis_to_args.items(): + # Call each deprecated API twice + for cls, args in 
deprecated_apis_to_args.items(): + with warnings.catch_warnings(record=True) as _warnings: cls(*args) cls(*args) - # Each call should trigger the warning only once - self.assertEqual(len(_warnings), len(deprecated_apis_to_args)) - for w in _warnings: - self.assertIn( - "is deprecated and will be removed in a future release", - str(w.message), - ) + # Each call should have at least one warning. + # Some of them can have two warnings - one for deprecation, + # one for moving to prototype + # 1 warning - just deprecation + # 2 warnings - deprecation and prototype warnings + self.assertTrue(len(_warnings) in (1, 2)) + found_deprecated = False + for w in _warnings: + if "is deprecated and will be removed in a future release" in str( + w.message + ): + found_deprecated = True + self.assertTrue(found_deprecated) common_utils.instantiate_parametrized_tests(TestQuantFlow) @@ -1133,6 +1139,63 @@ def reset_memory(): assert param.device.type == _DEVICE.type self.assertLess(memory_streaming, memory_baseline) + @unittest.skipIf(not torch.accelerator.is_available(), "Need GPU available") + def test_fqn_config_quantized_nested_module(self): + class NestedModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(16, 16) + + class TopLevelModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.nested = NestedModule() + self.linear1 = torch.nn.Linear(16, 16) + + m = TopLevelModule() + quant_config = FqnToConfig( + { + "nested.linear": Int8WeightOnlyConfig(), + "linear1": Int8WeightOnlyConfig(), + } + ) + quantize_(m, quant_config, filter_fn=None) + + assert isinstance(m.nested.linear.weight, AffineQuantizedTensor) + assert isinstance(m.linear1.weight, AffineQuantizedTensor) + + @unittest.skipIf(not torch.accelerator.is_available(), "Need GPU available") + def test_fqn_config_quantized_nested_module_param(self): + class NestedModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(16, 16) + + class TopLevelModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.nested = NestedModule() + self.linear1 = torch.nn.Linear(16, 16) + + m = TopLevelModule() + quant_config = FqnToConfig( + { + "nested.linear.weight": Int8WeightOnlyConfig(), + "linear1.weight": Int8WeightOnlyConfig(), + } + ) + quantize_(m, quant_config, filter_fn=None) + + assert isinstance(m.nested.linear.weight, AffineQuantizedTensor) + assert isinstance(m.linear1.weight, AffineQuantizedTensor) + + def test_fqn_config_module_config_and_fqn_config_both_specified(self): + with self.assertRaises(ValueError): + FqnToConfig( + fqn_to_config={"test": Float8WeightOnlyConfig()}, + module_fqn_to_config={"test2": Float8WeightOnlyConfig()}, + ) + if __name__ == "__main__": unittest.main() From 079d4b3629241d2a10b64532647826e8971ec30e Mon Sep 17 00:00:00 2001 From: "Zeng, Xiangdong" Date: Tue, 18 Nov 2025 14:58:26 +0800 Subject: [PATCH 4/4] update --- test/quantization/test_qat.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/test/quantization/test_qat.py b/test/quantization/test_qat.py index 60a4d13867..db33561fa9 100644 --- a/test/quantization/test_qat.py +++ b/test/quantization/test_qat.py @@ -105,7 +105,6 @@ # TODO: put this in a common test utils file _CUDA_IS_AVAILABLE = torch.cuda.is_available() -_GPU_IS_AVAILABLE = torch.accelerator.is_available() _DEVICE = get_current_accelerator_device() @@ -603,7 +602,7 @@ def _assert_close_4w(self, val, ref): print(mean_err) self.assertTrue(mean_err < 0.05) - 
@unittest.skipIf(not _GPU_IS_AVAILABLE, "skipping when gpu is not available") + @unittest.skipIf(_DEVICE is None, "skipping when gpu is not available") def test_qat_4w_primitives(self): n_bit = 4 group_size = 32