diff --git a/test/prototype/safetensors/test_safetensors_support.py b/test/prototype/safetensors/test_safetensors_support.py
index 988cfe8582..7d05eaf309 100644
--- a/test/prototype/safetensors/test_safetensors_support.py
+++ b/test/prototype/safetensors/test_safetensors_support.py
@@ -20,6 +20,8 @@ from torchao.quantization.quant_api import (
 from torchao.quantization.quant_api import (
     Float8DynamicActivationFloat8WeightConfig,
     Int4WeightOnlyConfig,
+    Int8DynamicActivationIntxWeightConfig,
+    IntxWeightOnlyConfig,
 )
 from torchao.utils import is_sm_at_least_89
 
@@ -46,6 +48,8 @@ class TestSafeTensors(TestCase):
             (Int4WeightOnlyConfig(), False),
             (Int4WeightOnlyConfig(), True),
             (Int4WeightOnlyConfig(int4_packing_format="tile_packed_to_4d"), False),
+            (IntxWeightOnlyConfig(), False),
+            (Int8DynamicActivationIntxWeightConfig(), False),
         ],
     )
     def test_safetensors(self, config, act_pre_scale=False):
diff --git a/torchao/prototype/safetensors/safetensors_utils.py b/torchao/prototype/safetensors/safetensors_utils.py
index fe68ad4cb6..c3e85bd4fb 100644
--- a/torchao/prototype/safetensors/safetensors_utils.py
+++ b/torchao/prototype/safetensors/safetensors_utils.py
@@ -10,6 +10,7 @@
     Float8Tensor,
     Int4Tensor,
     Int4TilePackedTo4dTensor,
+    IntxUnpackedToInt8Tensor,
 )
 from torchao.quantization.quantize_.common import KernelPreference
 from torchao.quantization.quantize_.workflows import QuantizeTensorToFloat8Kwargs
@@ -18,6 +19,7 @@
     "Float8Tensor": Float8Tensor,
     "Int4Tensor": Int4Tensor,
     "Int4TilePackedTo4dTensor": Int4TilePackedTo4dTensor,
+    "IntxUnpackedToInt8Tensor": IntxUnpackedToInt8Tensor,
     "Float8MMConfig": torchao.float8.inference.Float8MMConfig,
     "QuantizeTensorToFloat8Kwargs": QuantizeTensorToFloat8Kwargs,
     "PerRow": torchao.quantization.PerRow,
@@ -29,6 +31,7 @@
     "Float8Tensor",
     "Int4Tensor",
     "Int4TilePackedTo4dTensor",
+    "IntxUnpackedToInt8Tensor",
 ]
 
 __all__ = [