diff --git a/test/prototype/safetensors/test_safetensors_support.py b/test/prototype/safetensors/test_safetensors_support.py
index 1f6a031ab5..988cfe8582 100644
--- a/test/prototype/safetensors/test_safetensors_support.py
+++ b/test/prototype/safetensors/test_safetensors_support.py
@@ -45,6 +45,7 @@ class TestSafeTensors(TestCase):
             (Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()), False),
             (Int4WeightOnlyConfig(), False),
             (Int4WeightOnlyConfig(), True),
+            (Int4WeightOnlyConfig(int4_packing_format="tile_packed_to_4d"), False),
         ],
     )
     def test_safetensors(self, config, act_pre_scale=False):
diff --git a/torchao/prototype/safetensors/safetensors_utils.py b/torchao/prototype/safetensors/safetensors_utils.py
index 2b01d7f729..fe68ad4cb6 100644
--- a/torchao/prototype/safetensors/safetensors_utils.py
+++ b/torchao/prototype/safetensors/safetensors_utils.py
@@ -6,13 +6,18 @@
 import torch
 
 import torchao
-from torchao.quantization import Float8Tensor, Int4Tensor
+from torchao.quantization import (
+    Float8Tensor,
+    Int4Tensor,
+    Int4TilePackedTo4dTensor,
+)
 from torchao.quantization.quantize_.common import KernelPreference
 from torchao.quantization.quantize_.workflows import QuantizeTensorToFloat8Kwargs
 
 ALLOWED_CLASSES = {
     "Float8Tensor": Float8Tensor,
     "Int4Tensor": Int4Tensor,
+    "Int4TilePackedTo4dTensor": Int4TilePackedTo4dTensor,
     "Float8MMConfig": torchao.float8.inference.Float8MMConfig,
     "QuantizeTensorToFloat8Kwargs": QuantizeTensorToFloat8Kwargs,
     "PerRow": torchao.quantization.PerRow,
@@ -20,7 +25,11 @@
     "KernelPreference": KernelPreference,
 }
 
-ALLOWED_TENSORS_SUBCLASSES = ["Float8Tensor", "Int4Tensor"]
+ALLOWED_TENSORS_SUBCLASSES = [
+    "Float8Tensor",
+    "Int4Tensor",
+    "Int4TilePackedTo4dTensor",
+]
 
 __all__ = [
     "TensorSubclassAttributeJSONEncoder",