diff --git a/test/dtypes/test_affine_quantized.py b/test/dtypes/test_affine_quantized.py
index 9679f6975a..20cf1c6311 100644
--- a/test/dtypes/test_affine_quantized.py
+++ b/test/dtypes/test_affine_quantized.py
@@ -39,6 +39,7 @@
 from torchao.utils import (
     check_cpu_version,
     check_xpu_version,
+    get_current_accelerator_device,
     is_fbcode,
     is_ROCM,
     is_sm_at_least_89,
@@ -47,10 +48,11 @@
 is_cusparselt_available = (
     hasattr(torch.backends, "cusparselt") and torch.backends.cusparselt.is_available()
 )
+_DEVICE = get_current_accelerator_device()


 def get_quantization_functions(
-    do_sparse: bool, do_int4: bool, device: str = "cuda", int4_zp_int: bool = False
+    do_sparse: bool, do_int4: bool, device: str = _DEVICE, int4_zp_int: bool = False
 ):
     base_functions = [
         Int8WeightOnlyConfig(),
@@ -105,9 +107,9 @@ class TestAffineQuantized(TestCase):
         ["xpu"] if torch.xpu.is_available() else []
     )

-    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @unittest.skipIf(not torch.accelerator.is_available(), "Need GPU available")
     def test_tensor_core_layout_transpose(self):
-        linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
+        linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device=_DEVICE)
         t = linear.weight
         shape = t.shape
         apply_int4_weight_only_quant = Int4WeightOnlyConfig(group_size=32, version=1)
@@ -169,7 +171,7 @@ def _apply(module, config_or_subclass_inserter):
         ql = _apply(linear, apply_quant)
         ql.to(device)

-    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @unittest.skipIf(not torch.accelerator.is_available(), "Need GPU available")
    def test_register_new_dispatch(self):
         from torchao.dtypes import AffineQuantizedTensor
         from torchao.dtypes.affine_quantized_tensor_ops import (
@@ -206,10 +208,10 @@ def apply_uint6_weight_only_quant(linear):
             )
             return linear

-        linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
+        linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device=_DEVICE)
         apply_uint6_weight_only_quant(linear)

-        example_input = torch.randn(1, 128, dtype=torch.bfloat16, device="cuda")
+        example_input = torch.randn(1, 128, dtype=torch.bfloat16, device=_DEVICE)
         with self.assertRaisesRegex(
             AssertionError, "dispatching to my impl for uint6 weight only quant"
         ):
@@ -234,11 +236,11 @@ def test_print_quantized_module(self):

     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     @common_utils.parametrize(
-        "apply_quant", get_quantization_functions(False, True, "cuda", False)
+        "apply_quant", get_quantization_functions(False, True, _DEVICE, False)
     )
     def test_test_copy__apply(self, apply_quant):
-        linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
-        linear2 = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
+        linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device=_DEVICE)
+        linear2 = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device=_DEVICE)

         if isinstance(apply_quant, AOBaseConfig):
             quantize_(linear, apply_quant)
@@ -249,20 +251,20 @@ def test_test_copy__apply(self, apply_quant):
             ql = apply_quant(linear)
             ql2 = apply_quant(linear2)

-        example_input = torch.randn(1, 128, dtype=torch.bfloat16, device="cuda")
+        example_input = torch.randn(1, 128, dtype=torch.bfloat16, device=_DEVICE)
         output = ql(example_input)
         ql2.weight.copy_(ql.weight)
         ql2.bias = ql.bias
         output2 = ql2(example_input)
         self.assertEqual(output, output2)

-    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @unittest.skipIf(not torch.accelerator.is_available(), "Need GPU available")
     @common_utils.parametrize(
-        "apply_quant", get_quantization_functions(False, True, "cuda", False)
+        "apply_quant", get_quantization_functions(False, True, _DEVICE, False)
     )
     def test_copy__mismatch_metadata(self, apply_quant):
-        linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
-        linear2 = torch.nn.Linear(128, 512, dtype=torch.bfloat16, device="cuda")
+        linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device=_DEVICE)
+        linear2 = torch.nn.Linear(128, 512, dtype=torch.bfloat16, device=_DEVICE)

         if isinstance(apply_quant, AOBaseConfig):
             quantize_(linear, apply_quant)
@@ -336,7 +338,7 @@ def test_alias(self, device, dtype):
         quantize_(dummy, Int8DynamicActivationInt8WeightConfig())
         _ = dummy.weight[...]

-    @common_utils.parametrize("device", ["cuda"])
+    @common_utils.parametrize("device", [_DEVICE])
     @common_utils.parametrize("dtype", [torch.bfloat16])
     @skip_if_no_cuda()
     @skip_if_rocm("ROCm enablement in progress")
@@ -350,9 +352,9 @@ def test_slice_int4wo(self, device, dtype):
         _ = dummy.weight.narrow(0, 0, 64)
         _ = dummy.weight.narrow(1, 0, 128)

-    @common_utils.parametrize("device", ["cuda"])
+    @common_utils.parametrize("device", [_DEVICE])
     @common_utils.parametrize("dtype", [torch.float16, torch.bfloat16])
-    @skip_if_no_cuda()
+    @unittest.skipIf(not torch.accelerator.is_available(), "Need GPU available")
     @skip_if_no_gemlite()
     def test_slice_gemlite(self, device, dtype):
         # in_feature not divisible by 1024
@@ -433,7 +435,7 @@ def dequant(input_layer, in_features, orig_shape):
         )
         self.assertEqual((W_slice_ref - W_slice).abs().mean().item(), 0)

-    @common_utils.parametrize("device", ["cuda"])
+    @common_utils.parametrize("device", [_DEVICE])
     @common_utils.parametrize("dtype", [torch.bfloat16])
     def test_matmul(self, device, dtype):
         x = torch.randn(53, 2048)
@@ -450,14 +452,14 @@ def test_matmul(self, device, dtype):
         # make sure it runs
         torch.matmul(x, w.t())

-    @common_utils.parametrize("device", ["cuda"])
+    @common_utils.parametrize("device", [_DEVICE])
     @common_utils.parametrize("dtype", [torch.bfloat16])
     @skip_if_no_cuda()
     @skip_if_rocm("ROCm enablement in progress")
     def test_slice_and_copy_int4wo(self, device, dtype):
-        l = torch.nn.Linear(1024, 1024).to("cuda").to(torch.bfloat16)
+        l = torch.nn.Linear(1024, 1024).to(_DEVICE).to(torch.bfloat16)
         l.weight = torch.nn.Parameter(
-            torch.zeros(1024, 1024, dtype=torch.bfloat16, device="cuda")
+            torch.zeros(1024, 1024, dtype=torch.bfloat16, device=_DEVICE)
         )
         quantize_(l, Int4WeightOnlyConfig(version=1))
         param = l.weight
@@ -474,7 +476,7 @@ def test_slice_and_copy_int4wo(self, device, dtype):
         assert param.data.dequantize()[0][0] == 0

         # dummy_l has random input (shouldn't be 0)
-        dummy_l = torch.nn.Linear(1024, 1024).to("cuda").to(torch.bfloat16)
+        dummy_l = torch.nn.Linear(1024, 1024).to(_DEVICE).to(torch.bfloat16)
         quantize_(dummy_l, Int4WeightOnlyConfig(version=1))
         quantized = dummy_l.weight
         quantized = quantized.narrow(0, 0, 512)
@@ -484,9 +486,9 @@ def test_slice_and_copy_int4wo(self, device, dtype):
         # making sure param.data is updated
         assert param.data.dequantize()[0][0] != 0

-    @common_utils.parametrize("device", ["cuda"])
+    @common_utils.parametrize("device", [_DEVICE])
     @common_utils.parametrize("dtype", [torch.bfloat16])
-    @skip_if_no_cuda()
+    @unittest.skipIf(not torch.accelerator.is_available(), "Need GPU available")
     @skip_if_rocm("ROCm enablement in progress")
     def test_mm_int4wo(self, device, dtype):
         weight = torch.randn(512, 1024).to(device).to(dtype)
diff --git a/test/dtypes/test_affine_quantized_float.py b/test/dtypes/test_affine_quantized_float.py
index 738e9b6164..54820cb5b3 100644
--- a/test/dtypes/test_affine_quantized_float.py
+++ b/test/dtypes/test_affine_quantized_float.py
@@ -39,6 +39,7 @@
 )
 from torchao.quantization.quantize_.common import KernelPreference
 from torchao.utils import (
+    get_current_accelerator_device,
     is_sm_at_least_89,
     is_sm_at_least_90,
     is_sm_version,
@@ -46,6 +47,7 @@

 random.seed(0)
 torch.manual_seed(0)
+_DEVICE = get_current_accelerator_device()


 class ToyLinearModel(torch.nn.Module):
@@ -61,7 +63,7 @@ def forward(self, x):


 class TestAffineQuantizedFloat8Compile(InductorTestCase):
-    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @unittest.skipIf(not torch.accelerator.is_available(), "Need GPU available")
     @unittest.skipIf(
         not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
     )
@@ -97,7 +99,7 @@ def test_fp8_linear_variants(

         with error_context:
             M, N, K = sizes
-            input_tensor = torch.randn(*M, K, dtype=dtype, device="cuda")
+            input_tensor = torch.randn(*M, K, dtype=dtype, device=_DEVICE)
             # Get a "reasonable" scale for the input tensor even though
             # we use the same scale for multiple activations
             scale, _ = choose_qparams_affine(
@@ -122,7 +124,7 @@ def test_fp8_linear_variants(
             }

             # Create a linear layer with bfloat16 dtype
-            model = ToyLinearModel(K, N).eval().to(dtype).to("cuda")
+            model = ToyLinearModel(K, N).eval().to(dtype).to(_DEVICE)
             quantized_model = copy.deepcopy(model)

             factory = mode_map[mode]()
@@ -140,14 +142,16 @@ def test_fp8_linear_variants(
         )

     @unittest.skipIf(
-        not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
+        _DEVICE == "cuda" and not is_sm_at_least_89(),
+        "Requires GPU with compute capability >= 8.9",
     )
     def test_invalid_granularity(self):
         with pytest.raises(ValueError, match="Invalid granularity specification"):
             Float8DynamicActivationFloat8WeightConfig(granularity="invalid")

     @unittest.skipIf(
-        not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
+        _DEVICE == "cuda" and not is_sm_at_least_89(),
+        "Requires GPU with compute capability >= 8.9",
     )
     def test_mismatched_granularity(self):
         with pytest.raises(
@@ -159,7 +163,8 @@ def test_mismatched_granularity(self):
         )

     @unittest.skipIf(
-        not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
+        _DEVICE == "cuda" and not is_sm_at_least_89(),
+        "Requires GPU with compute capability >= 8.9",
     )
     def test_unsupported_granularity(self):
         class UnsupportedGranularity:
@@ -179,7 +184,7 @@ def test_per_row_with_float32(self):
             AssertionError,
             match="PerRow quantization only works for bfloat16 precision",
         ):
-            model = ToyLinearModel(64, 64).eval().to(torch.float32).to("cuda")
+            model = ToyLinearModel(64, 64).eval().to(torch.float32).to(_DEVICE)
             quantize_(
                 model,
                 Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()),
@@ -192,7 +197,7 @@ def test_per_row_with_float32(self):
     @common_utils.parametrize("mode", ["dynamic", "weight-only", "static"])
     def test_serialization(self, mode: str):
         # Create and quantize the model
-        model = ToyLinearModel(16, 32).to(device="cuda")
+        model = ToyLinearModel(16, 32).to(device=_DEVICE)

         mode_map = {
             "dynamic": partial(
@@ -203,7 +208,7 @@ def test_serialization(self, mode: str):
             "weight-only": partial(Float8WeightOnlyConfig, version=1),
             "static": partial(
                 Float8StaticActivationFloat8WeightConfig,
-                scale=torch.tensor(1.0, dtype=torch.float32, device="cuda"),
+                scale=torch.tensor(1.0, dtype=torch.float32, device=_DEVICE),
                 granularity=PerTensor(),
             ),
         }
@@ -266,7 +271,7 @@ def test_serialization(self, mode: str):
     )
     def test_fp8_weight_dimension_warning(self):
         # Create model with incompatible dimensions (not multiples of 16)
-        model = ToyLinearModel(10, 25).cuda()  # 10x25 and 25x10 weights
+        model = ToyLinearModel(10, 25).to(_DEVICE)  # 10x25 and 25x10 weights

         # Set up logging capture
         with self.assertLogs(
@@ -312,7 +317,7 @@ def test_fp8_weight_dimension_warning(self):
     def test_mm_float8dq_per_row(
         self, in_features, out_features, leading_shape, bias: bool
     ):
-        device = "cuda"
+        device = _DEVICE
         dtype = torch.bfloat16

         input_shape = leading_shape + (in_features,)
@@ -353,15 +358,16 @@ def test_mm_float8dq_per_row(
         error = compute_error(ref_output, quant_output)
         assert error > 20, f"Quantization error is too high got a SQNR of {error}"

-    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @unittest.skipIf(not torch.accelerator.is_available(), "Need GPU available")
     @unittest.skipIf(
-        not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
+        _DEVICE == "cuda" and not is_sm_at_least_89(),
+        "Requires GPU with compute capability >= 8.9",
     )
     @common_utils.parametrize("float8_dtype", [torch.float8_e4m3fn, torch.float8_e5m2])
     @common_utils.parametrize("output_dtype", [torch.float32, torch.bfloat16])
     def test_choose_scale_float8_bounds(self, float8_dtype, output_dtype):
         block_size = ()
-        device = "cuda"
+        device = _DEVICE
         input_tensor = torch.randn(8, 64, device=device, dtype=torch.float32)

         # testing upper bounds
@@ -396,9 +402,10 @@ def test_choose_scale_float8_bounds(self, float8_dtype, output_dtype):
         # since scale = abs_max / quant_max, larger abs_max means scale is larger
         self.assertTrue(scale_ref < scale_with_lb)

-    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @unittest.skipIf(not torch.accelerator.is_available(), "Need GPU available")
     @unittest.skipIf(
-        not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
+        _DEVICE == "cuda" and not is_sm_at_least_89(),
+        "Requires GPU with compute capability >= 8.9",
     )
     @common_utils.parametrize("float8_dtype", [torch.float8_e4m3fn, torch.float8_e5m2])
     @common_utils.parametrize("output_dtype", [torch.float32, torch.bfloat16])
@@ -406,7 +413,7 @@ def test_dequantize_affine_float8(self, float8_dtype, output_dtype, block_size):
         """Test _dequantize_affine_float8 with various configurations"""
-        device = "cuda"
+        device = _DEVICE
         input_tensor = torch.randn(8, 64, device=device, dtype=torch.float32)

         # Choose quantization parameters
@@ -429,13 +436,14 @@ def test_dequantize_affine_float8(self, float8_dtype, output_dtype, block_size):
         error = torch.abs(input_tensor.to(output_dtype) - dequantized).mean()
         self.assertLess(error, 0.1, "Quantization error too high")

-    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @unittest.skipIf(not torch.accelerator.is_available(), "Need GPU available")
     @unittest.skipIf(
-        not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
+        _DEVICE == "cuda" and not is_sm_at_least_89(),
+        "Requires GPU with compute capability >= 8.9",
     )
     def test_dequantize_affine_float8_scale_broadcasting(self):
         """Test that scale broadcasting works correctly for block-wise quantization"""
-        device = "cuda"
+        device = _DEVICE
         # Create input tensor with known block structure
         input_tensor = torch.randn(4, 32, device=device, dtype=torch.float32)
         block_size = (2, 16)  # 2x2 blocks in first dim, 2x16 blocks in second dim
@@ -468,7 +476,7 @@ def test_dequantize_affine_float8_scale_broadcasting(self):
     @common_utils.parametrize("granularity", [PerTensor(), PerRow()])
     def test_float8_tensor_slicing_basic(self, granularity):
         """Test basic slicing operations on Float8 tensors"""
-        device = "cuda"
+        device = _DEVICE
         dtype = torch.bfloat16

         # Create and quantize a model
@@ -505,7 +513,7 @@ def test_float8_tensor_slicing_basic(self, granularity):
     )
     def test_float8_tensor_slicing_per_tensor(self):
         """Test slicing with per-tensor quantization (scale should not change)"""
-        device = "cuda"
+        device = _DEVICE
         dtype = torch.bfloat16

         # Create and quantize with per-tensor granularity
@@ -539,7 +547,7 @@ def test_float8_tensor_slicing_per_tensor(self):
     )
     def test_float8_tensor_slicing_per_row(self):
         """Test slicing with per-row quantization (scale should be sliced appropriately)"""
-        device = "cuda"
+        device = _DEVICE
         dtype = torch.bfloat16

         # Create and quantize with per-row granularity
@@ -578,7 +586,7 @@ def test_float8_tensor_slicing_per_row(self):
     )
     def test_float8_tensor_slicing_edge_cases(self):
         """Test edge cases in slicing"""
-        device = "cuda"
+        device = _DEVICE
         dtype = torch.bfloat16

         # Create and quantize a model
@@ -615,7 +623,7 @@ def test_float8_tensor_slicing_edge_cases(self):
     )
     def test_float8_tensor_slicing_functional_correctness(self, granularity):
         """Test that sliced tensors produce correct results in computations"""
-        device = "cuda"
+        device = _DEVICE
         dtype = torch.bfloat16

         # Create reference and quantized models with dimensions that are multiples of 16
@@ -791,7 +799,7 @@ def test_expected_kernels_on_gpu(self, granularity, float8_config_version):
         M, K, N = 128, 256, 512

         m = torch.nn.Sequential(
-            torch.nn.Linear(K, N, device="cuda", dtype=torch.bfloat16)
+            torch.nn.Linear(K, N, device=_DEVICE, dtype=torch.bfloat16)
         )
         if float8_config_version == 1:
             config = Float8DynamicActivationFloat8WeightConfig(
@@ -810,7 +818,7 @@ def test_expected_kernels_on_gpu(self, granularity, float8_config_version):
         )

         m = torch.compile(m)
-        x = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
+        x = torch.randn(M, K, device=_DEVICE, dtype=torch.bfloat16)
         out, code = run_and_get_code(m, x)

         # triton kernel call looks like:
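Both test files import `get_current_accelerator_device` from `torchao.utils`, but the helper's implementation is not part of this patch. Below is a minimal hypothetical sketch of what such a helper could look like, assuming it simply wraps the device-agnostic `torch.accelerator` API (the same API the updated skip conditions call) and falls back to `"cpu"`; the actual torchao implementation may differ.

```python
# Hypothetical sketch -- NOT the actual torchao.utils.get_current_accelerator_device.
import torch


def get_current_accelerator_device() -> str:
    # Resolve the current accelerator backend ("cuda", "xpu", ...) via the
    # device-agnostic torch.accelerator API when one is present; otherwise
    # fall back to "cpu" so _DEVICE is always a usable device string.
    if hasattr(torch, "accelerator") and torch.accelerator.is_available():
        return torch.accelerator.current_accelerator().type
    return "cpu"
```

With a CPU fallback like this, decorators such as `@common_utils.parametrize("device", [_DEVICE])` still produce a valid parametrization on machines without an accelerator, while the `torch.accelerator.is_available()` guards keep the GPU-only tests skipped there.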