Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion test/integration/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@
from parameterized import parameterized
import itertools
import logging
from torchao.utils import TORCH_VERSION_AFTER_2_3, TORCH_VERSION_AFTER_2_4
from torchao.utils import TORCH_VERSION_AFTER_2_3, TORCH_VERSION_AFTER_2_4, is_fbcode

logger = logging.getLogger("INFO")

Expand Down Expand Up @@ -760,6 +760,7 @@ def test_int8_dynamic_quant_subclass_api(self, device, dtype):
)

@parameterized.expand(COMMON_DEVICE_DTYPE)
@unittest.skipIf(is_fbcode(), "broken in fbcode")
def test_int8_weight_only_quant_subclass_api(self, device, dtype):
self._test_lin_weight_subclass_api_impl(
change_linear_weights_to_int8_woqtensors, device, 40, test_dtype=dtype
Expand Down Expand Up @@ -935,6 +936,7 @@ def forward(self, x):
self.assertTrue(torch.equal(ref_q, test))

@parameterized.expand(COMMON_DEVICE_DTYPE)
@unittest.skipIf(is_fbcode(), "'PlainAQTLayout' object has no attribute 'int_data'")
@torch.no_grad()
def test_save_load_dqtensors(self, device, dtype):
if device == "cpu":
Expand All @@ -943,6 +945,7 @@ def test_save_load_dqtensors(self, device, dtype):

@parameterized.expand(COMMON_DEVICE_DTYPE)
@torch.no_grad()
@unittest.skipIf(is_fbcode(), "broken in fbcode")
def test_save_load_int8woqtensors(self, device, dtype):
self._test_handle_save_load_meta_impl(change_linear_weights_to_int8_woqtensors, device, test_dtype=dtype)

Expand Down Expand Up @@ -1025,6 +1028,7 @@ def test_non_dynamically_quantizable_linear(self):
self.assertTrue(isinstance(model[0], SmoothFakeDynamicallyQuantizedLinear))

@torch.inference_mode()
@unittest.skipIf(is_fbcode(), "can't load tokenizer")
def test_on_dummy_distilbert(self):
# https://huggingface.co/distilbert-base-uncased#how-to-use
from transformers import ( # type: ignore[import-untyped]
Expand Down
8 changes: 8 additions & 0 deletions test/quantization/test_quant_primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from torchao.utils import (
TORCH_VERSION_AFTER_2_3,
TORCH_VERSION_AFTER_2_4,
is_fbcode,
)

_SEED = 1234
Expand Down Expand Up @@ -179,6 +180,7 @@ def test_choose_qparams_group_sym(self):
self.assertTrue(torch.equal(zero_point, zp_ref))

@unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "skipping when torch version is 2.3 or lower")
@unittest.skipIf(is_fbcode(), "broken in fbcode")
def test_choose_qparams_token_asym(self):
input = torch.randn(10, 10)
mapping_type = MappingType.ASYMMETRIC
Expand All @@ -193,6 +195,7 @@ def test_choose_qparams_token_asym(self):
torch.testing.assert_close(scale, scale_ref, atol=10e-3, rtol=10e-3)
self.assertTrue(torch.equal(zero_point, zp_ref))

@unittest.skipIf(is_fbcode(), "broken in fbcode")
def test_choose_qparams_tensor_asym(self):
input = torch.randn(10, 10)
mapping_type = MappingType.ASYMMETRIC
Expand All @@ -211,6 +214,7 @@ def test_choose_qparams_tensor_asym(self):
self.assertTrue(torch.equal(scale, scale_ref))
self.assertTrue(torch.equal(zero_point, zp_ref))

@unittest.skipIf(is_fbcode(), "broken in fbcode")
def test_choose_qparams_tensor_sym(self):
input = torch.randn(10, 10)
mapping_type = MappingType.SYMMETRIC
Expand Down Expand Up @@ -271,6 +275,7 @@ def test_quantize_activation_per_token_abs_max_dtype(self):


@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower")
@unittest.skipIf(is_fbcode(), "broken in fbcode")
def test_quantize_dequantize_group_sym(self):
input = torch.randn(10, 10)
mapping_type = MappingType.SYMMETRIC
Expand All @@ -295,6 +300,7 @@ def test_quantize_dequantize_group_sym(self):
self.assertTrue(torch.equal(dequantized, dequantized_ref))

@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower")
@unittest.skipIf(is_fbcode(), "broken in fbcode")
def test_quantize_dequantize_channel_asym(self):
input = torch.randn(10, 10)
mapping_type = MappingType.ASYMMETRIC
Expand All @@ -318,6 +324,7 @@ def test_quantize_dequantize_channel_asym(self):
self.assertTrue(torch.equal(dequantized, dequantized_ref))

@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower")
@unittest.skipIf(is_fbcode(), "broken in fbcode")
def test_quantize_dequantize_tensor_asym(self):
input = torch.randn(10, 10)
mapping_type = MappingType.ASYMMETRIC
Expand All @@ -341,6 +348,7 @@ def test_quantize_dequantize_tensor_asym(self):
self.assertTrue(torch.equal(dequantized, dequantized_ref))

@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower")
@unittest.skipIf(is_fbcode(), "broken in fbcode")
def test_quantize_dequantize_channel_asym_4d(self):
input = torch.randn(3, 3, 10, 10)
mapping_type = MappingType.ASYMMETRIC
Expand Down
4 changes: 3 additions & 1 deletion test/sparsity/test_fast_sparse_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
swap_semi_sparse_linear_with_linear,
SemiSparseLinear
)
from torchao.utils import TORCH_VERSION_AFTER_2_4
from torchao.utils import TORCH_VERSION_AFTER_2_4, is_fbcode

class TestModel(nn.Module):
def __init__(self):
Expand All @@ -30,6 +30,7 @@ class TestRuntimeSemiStructuredSparsity(TestCase):

@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "pytorch 2.4+ feature")
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@unittest.skipIf(is_fbcode(), "broken in fbcode")
def test_runtime_weight_sparsification(self):
# need this import inside to not break 2.2 tests
from torch.sparse import SparseSemiStructuredTensorCUSPARSELT
Expand Down Expand Up @@ -70,6 +71,7 @@ def test_runtime_weight_sparsification(self):

@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "pytorch 2.4+ feature")
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@unittest.skipIf(is_fbcode(), "broken in fbcode")
def test_runtime_weight_sparsification_compile(self):
# need this import inside to not break 2.2 tests
from torch.sparse import SparseSemiStructuredTensorCUSPARSELT
Expand Down
3 changes: 3 additions & 0 deletions torchao/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,3 +154,6 @@ def unwrap_tensor_subclass(model, filter_fn=None):
TORCH_VERSION_AFTER_2_2 = True
else:
TORCH_VERSION_AFTER_2_2 = False

def is_fbcode():
return not hasattr(torch.version, "git_version")