diff --git a/test/integration/test_integration.py b/test/integration/test_integration.py index 4461355fb3..c3e483fe5b 100644 --- a/test/integration/test_integration.py +++ b/test/integration/test_integration.py @@ -73,7 +73,7 @@ from parameterized import parameterized import itertools import logging -from torchao.utils import TORCH_VERSION_AFTER_2_3, TORCH_VERSION_AFTER_2_4 +from torchao.utils import TORCH_VERSION_AFTER_2_3, TORCH_VERSION_AFTER_2_4, is_fbcode logger = logging.getLogger("INFO") @@ -760,6 +760,7 @@ def test_int8_dynamic_quant_subclass_api(self, device, dtype): ) @parameterized.expand(COMMON_DEVICE_DTYPE) + @unittest.skipIf(is_fbcode(), "broken in fbcode") def test_int8_weight_only_quant_subclass_api(self, device, dtype): self._test_lin_weight_subclass_api_impl( change_linear_weights_to_int8_woqtensors, device, 40, test_dtype=dtype @@ -935,6 +936,7 @@ def forward(self, x): self.assertTrue(torch.equal(ref_q, test)) @parameterized.expand(COMMON_DEVICE_DTYPE) + @unittest.skipIf(is_fbcode(), "'PlainAQTLayout' object has no attribute 'int_data'") @torch.no_grad() def test_save_load_dqtensors(self, device, dtype): if device == "cpu": @@ -943,6 +945,7 @@ def test_save_load_dqtensors(self, device, dtype): @parameterized.expand(COMMON_DEVICE_DTYPE) @torch.no_grad() + @unittest.skipIf(is_fbcode(), "broken in fbcode") def test_save_load_int8woqtensors(self, device, dtype): self._test_handle_save_load_meta_impl(change_linear_weights_to_int8_woqtensors, device, test_dtype=dtype) @@ -1025,6 +1028,7 @@ def test_non_dynamically_quantizable_linear(self): self.assertTrue(isinstance(model[0], SmoothFakeDynamicallyQuantizedLinear)) @torch.inference_mode() + @unittest.skipIf(is_fbcode(), "can't load tokenizer") def test_on_dummy_distilbert(self): # https://huggingface.co/distilbert-base-uncased#how-to-use from transformers import ( # type: ignore[import-untyped] diff --git a/test/quantization/test_quant_primitives.py b/test/quantization/test_quant_primitives.py index a942c6a743..0e5388c301 100644 --- a/test/quantization/test_quant_primitives.py +++ b/test/quantization/test_quant_primitives.py @@ -27,6 +27,7 @@ from torchao.utils import ( TORCH_VERSION_AFTER_2_3, TORCH_VERSION_AFTER_2_4, + is_fbcode, ) _SEED = 1234 @@ -179,6 +180,7 @@ def test_choose_qparams_group_sym(self): self.assertTrue(torch.equal(zero_point, zp_ref)) @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "skipping when torch version is 2.3 or lower") + @unittest.skipIf(is_fbcode(), "broken in fbcode") def test_choose_qparams_token_asym(self): input = torch.randn(10, 10) mapping_type = MappingType.ASYMMETRIC @@ -193,6 +195,7 @@ def test_choose_qparams_token_asym(self): torch.testing.assert_close(scale, scale_ref, atol=10e-3, rtol=10e-3) self.assertTrue(torch.equal(zero_point, zp_ref)) + @unittest.skipIf(is_fbcode(), "broken in fbcode") def test_choose_qparams_tensor_asym(self): input = torch.randn(10, 10) mapping_type = MappingType.ASYMMETRIC @@ -211,6 +214,7 @@ def test_choose_qparams_tensor_asym(self): self.assertTrue(torch.equal(scale, scale_ref)) self.assertTrue(torch.equal(zero_point, zp_ref)) + @unittest.skipIf(is_fbcode(), "broken in fbcode") def test_choose_qparams_tensor_sym(self): input = torch.randn(10, 10) mapping_type = MappingType.SYMMETRIC @@ -271,6 +275,7 @@ def test_quantize_activation_per_token_abs_max_dtype(self): @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower") + @unittest.skipIf(is_fbcode(), "broken in fbcode") def test_quantize_dequantize_group_sym(self): input = torch.randn(10, 10) mapping_type = MappingType.SYMMETRIC @@ -295,6 +300,7 @@ def test_quantize_dequantize_group_sym(self): self.assertTrue(torch.equal(dequantized, dequantized_ref)) @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower") + @unittest.skipIf(is_fbcode(), "broken in fbcode") def test_quantize_dequantize_channel_asym(self): input = torch.randn(10, 10) mapping_type = MappingType.ASYMMETRIC @@ -318,6 +324,7 @@ def test_quantize_dequantize_channel_asym(self): self.assertTrue(torch.equal(dequantized, dequantized_ref)) @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower") + @unittest.skipIf(is_fbcode(), "broken in fbcode") def test_quantize_dequantize_tensor_asym(self): input = torch.randn(10, 10) mapping_type = MappingType.ASYMMETRIC @@ -341,6 +348,7 @@ def test_quantize_dequantize_tensor_asym(self): self.assertTrue(torch.equal(dequantized, dequantized_ref)) @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower") + @unittest.skipIf(is_fbcode(), "broken in fbcode") def test_quantize_dequantize_channel_asym_4d(self): input = torch.randn(3, 3, 10, 10) mapping_type = MappingType.ASYMMETRIC diff --git a/test/sparsity/test_fast_sparse_training.py b/test/sparsity/test_fast_sparse_training.py index 081f0e4d2f..2bd0d1878c 100644 --- a/test/sparsity/test_fast_sparse_training.py +++ b/test/sparsity/test_fast_sparse_training.py @@ -12,7 +12,7 @@ swap_semi_sparse_linear_with_linear, SemiSparseLinear ) -from torchao.utils import TORCH_VERSION_AFTER_2_4 +from torchao.utils import TORCH_VERSION_AFTER_2_4, is_fbcode class TestModel(nn.Module): def __init__(self): @@ -30,6 +30,7 @@ class TestRuntimeSemiStructuredSparsity(TestCase): @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "pytorch 2.4+ feature") @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") + @unittest.skipIf(is_fbcode(), "broken in fbcode") def test_runtime_weight_sparsification(self): # need this import inside to not break 2.2 tests from torch.sparse import SparseSemiStructuredTensorCUSPARSELT @@ -70,6 +71,7 @@ def test_runtime_weight_sparsification(self): @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "pytorch 2.4+ feature") @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") + @unittest.skipIf(is_fbcode(), "broken in fbcode") def test_runtime_weight_sparsification_compile(self): # need this import inside to not break 2.2 tests from torch.sparse import SparseSemiStructuredTensorCUSPARSELT diff --git a/torchao/utils.py b/torchao/utils.py index 27650dae1c..2a19993e4d 100644 --- a/torchao/utils.py +++ b/torchao/utils.py @@ -154,3 +154,6 @@ def unwrap_tensor_subclass(model, filter_fn=None): TORCH_VERSION_AFTER_2_2 = True else: TORCH_VERSION_AFTER_2_2 = False + +def is_fbcode(): + return not hasattr(torch.version, "git_version")