From e212a84646ee861700fb3a6131be557a0f53351f Mon Sep 17 00:00:00 2001 From: Angel Li Date: Wed, 24 Sep 2025 14:09:45 -0700 Subject: [PATCH 01/12] add int4tensor support for safetensors [ghstack-poisoned] --- .../safetensors/test_safetensors_support.py | 15 +++++++++------ .../safetensors/safetensors_support.py | 18 +++++++++--------- .../prototype/safetensors/safetensors_utils.py | 12 +++++++----- 3 files changed, 25 insertions(+), 20 deletions(-) diff --git a/test/prototype/safetensors/test_safetensors_support.py b/test/prototype/safetensors/test_safetensors_support.py index b67bf2bf0c..a2dcb52ba2 100644 --- a/test/prototype/safetensors/test_safetensors_support.py +++ b/test/prototype/safetensors/test_safetensors_support.py @@ -7,6 +7,8 @@ from torch.testing._internal.common_utils import ( TestCase, run_tests, + instantiate_parametrized_tests, + parametrize, ) from torchao import quantize_ @@ -15,7 +17,7 @@ unflatten_tensor_state_dict, ) from torchao.quantization.granularity import PerRow -from torchao.quantization.quant_api import Float8DynamicActivationFloat8WeightConfig +from torchao.quantization.quant_api import Float8DynamicActivationFloat8WeightConfig, Int4WeightOnlyConfig from torchao.utils import ( is_sm_at_least_89, ) @@ -36,13 +38,13 @@ def load_data(file_path: str, device: str): @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") @unittest.skipIf(not is_sm_at_least_89(), "Need sm89+") class TestSafeTensors(TestCase): - def test_safetensors(self): - config = Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()) + @parametrize("config", [Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()), Int4WeightOnlyConfig()]) + def test_safetensors(self, config): model = torch.nn.Sequential( - torch.nn.Linear(32, 256, dtype=torch.bfloat16, device="cuda") + torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda") ) quantize_(model, config) - example_inputs = (torch.randn(2, 32, dtype=torch.bfloat16, device="cuda"),) + example_inputs = (torch.randn(2, 128, dtype=torch.bfloat16, device="cuda"),) ref_output = model(*example_inputs) with tempfile.NamedTemporaryFile() as f: @@ -54,12 +56,13 @@ def test_safetensors(self): ) model = torch.nn.Sequential( - torch.nn.Linear(32, 256, dtype=torch.bfloat16, device="cuda") + torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda") ) model.load_state_dict(reconstructed_dict, assign=True) output = model(*example_inputs) assert torch.equal(output, ref_output) +instantiate_parametrized_tests(TestSafeTensors) if __name__ == "__main__": run_tests() diff --git a/torchao/prototype/safetensors/safetensors_support.py b/torchao/prototype/safetensors/safetensors_support.py index 19943e4b4a..649f35b6dd 100644 --- a/torchao/prototype/safetensors/safetensors_support.py +++ b/torchao/prototype/safetensors/safetensors_support.py @@ -7,8 +7,9 @@ from torchao.prototype.safetensors.safetensors_utils import ( Float8TensorAttributeJSONEncoder, object_from_dict, + ALLOWED_TENSORS ) -from torchao.quantization import Float8Tensor +from torchao.quantization import Float8Tensor, Int4Tensor logger: logging.Logger = logging.getLogger(__name__) @@ -76,12 +77,11 @@ def unflatten_tensor_state_dict( tensor_metadata = json.loads(metadata.get(tensor_name)) tensor_type = tensor_metadata.get("_type") - - if tensor_type == Float8Tensor.__name__: + if tensor_type == torch.Tensor.__name__: + result[tensor_name] = tensor_tensors["_data"] + elif tensor_type in ALLOWED_TENSORS: tensor_metadata["_data"].update(tensor_tensors) result[tensor_name] = object_from_dict(tensor_metadata) - elif tensor_type == torch.Tensor.__name__: - result[tensor_name] = tensor_tensors["_data"] else: raise ValueError(f"Unsupported tensor type: {tensor_type}") @@ -140,15 +140,15 @@ def flatten_tensor_state_dict( tensors_data_dict = {} for tensor_name, tensor in tensors_dict.items(): - if isinstance(tensor, Float8Tensor): + if type(tensor) is torch.Tensor: + tensor_dict = {"_data": tensor} + tensor_metadata = json.dumps({"_type": torch.Tensor.__name__}) + elif tensor.__class__.__name__ in ALLOWED_TENSORS: tensor_dict = {} for tensor_data_name in tensor.tensor_data_names: tensor_dict[tensor_data_name] = getattr(tensor, tensor_data_name) tensor_metadata = json.dumps(tensor, cls=Float8TensorAttributeJSONEncoder) - elif type(tensor) is torch.Tensor: - tensor_dict = {"_data": tensor} - tensor_metadata = json.dumps({"_type": torch.Tensor.__name__}) else: raise ValueError(f"Unsupported tensor type: {type(tensor)}") diff --git a/torchao/prototype/safetensors/safetensors_utils.py b/torchao/prototype/safetensors/safetensors_utils.py index eb0258a505..4ca17197fb 100644 --- a/torchao/prototype/safetensors/safetensors_utils.py +++ b/torchao/prototype/safetensors/safetensors_utils.py @@ -6,12 +6,13 @@ import torch import torchao -from torchao.quantization import Float8Tensor +from torchao.quantization import Float8Tensor, Int4Tensor from torchao.quantization.quantize_.common import KernelPreference from torchao.quantization.quantize_.workflows import QuantizeTensorToFloat8Kwargs ALLOWED_CLASSES = { "Float8Tensor": Float8Tensor, + "Int4Tensor": Int4Tensor, "Float8MMConfig": torchao.float8.inference.Float8MMConfig, "QuantizeTensorToFloat8Kwargs": QuantizeTensorToFloat8Kwargs, "PerRow": torchao.quantization.PerRow, @@ -19,7 +20,7 @@ "KernelPreference": KernelPreference, } -ALLOWED_TENSORS = ["Float8Tensor", "Tensor"] +ALLOWED_TENSORS = ["Float8Tensor", "Int4Tensor", "Tensor"] __all__ = [ "Float8TensorAttributeJSONEncoder", @@ -27,13 +28,14 @@ "is_metadata_torchao", ] - class Float8TensorAttributeJSONEncoder(json.JSONEncoder): def default(self, o): - if isinstance(o, Float8Tensor): + if o.__class__.__name__ in ALLOWED_TENSORS: tensor_attr_dict = {} + optional_tensor_attributes = o.optional_tensor_attribute_names if hasattr(o, "optional_tensor_attribute_names") else [] + all_tensor_attributes = ( - o.optional_tensor_attribute_names + o.tensor_attribute_names + optional_tensor_attributes + o.tensor_attribute_names ) for tensor_attribute_name in all_tensor_attributes: From 7c82d4d35d51913b7c22dda38ee66e05909572d1 Mon Sep 17 00:00:00 2001 From: Angel Li Date: Wed, 24 Sep 2025 14:13:19 -0700 Subject: [PATCH 02/12] Update on "add int4tensor support for safetensors" [ghstack-poisoned] --- .../safetensors/test_safetensors_support.py | 16 +++++++++++++--- .../prototype/safetensors/safetensors_support.py | 3 +-- .../prototype/safetensors/safetensors_utils.py | 7 ++++++- 3 files changed, 20 insertions(+), 6 deletions(-) diff --git a/test/prototype/safetensors/test_safetensors_support.py b/test/prototype/safetensors/test_safetensors_support.py index a2dcb52ba2..36d808dc8a 100644 --- a/test/prototype/safetensors/test_safetensors_support.py +++ b/test/prototype/safetensors/test_safetensors_support.py @@ -6,9 +6,9 @@ from safetensors.torch import load_file, save_file from torch.testing._internal.common_utils import ( TestCase, - run_tests, instantiate_parametrized_tests, parametrize, + run_tests, ) from torchao import quantize_ @@ -17,7 +17,10 @@ unflatten_tensor_state_dict, ) from torchao.quantization.granularity import PerRow -from torchao.quantization.quant_api import Float8DynamicActivationFloat8WeightConfig, Int4WeightOnlyConfig +from torchao.quantization.quant_api import ( + Float8DynamicActivationFloat8WeightConfig, + Int4WeightOnlyConfig, +) from torchao.utils import ( is_sm_at_least_89, ) @@ -38,7 +41,13 @@ def load_data(file_path: str, device: str): @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") @unittest.skipIf(not is_sm_at_least_89(), "Need sm89+") class TestSafeTensors(TestCase): - @parametrize("config", [Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()), Int4WeightOnlyConfig()]) + @parametrize( + "config", + [ + Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()), + Int4WeightOnlyConfig(), + ], + ) def test_safetensors(self, config): model = torch.nn.Sequential( torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda") @@ -62,6 +71,7 @@ def test_safetensors(self, config): output = model(*example_inputs) assert torch.equal(output, ref_output) + instantiate_parametrized_tests(TestSafeTensors) if __name__ == "__main__": diff --git a/torchao/prototype/safetensors/safetensors_support.py b/torchao/prototype/safetensors/safetensors_support.py index 649f35b6dd..c884a2679b 100644 --- a/torchao/prototype/safetensors/safetensors_support.py +++ b/torchao/prototype/safetensors/safetensors_support.py @@ -5,11 +5,10 @@ import torch from torchao.prototype.safetensors.safetensors_utils import ( + ALLOWED_TENSORS, Float8TensorAttributeJSONEncoder, object_from_dict, - ALLOWED_TENSORS ) -from torchao.quantization import Float8Tensor, Int4Tensor logger: logging.Logger = logging.getLogger(__name__) diff --git a/torchao/prototype/safetensors/safetensors_utils.py b/torchao/prototype/safetensors/safetensors_utils.py index 4ca17197fb..cb14dde762 100644 --- a/torchao/prototype/safetensors/safetensors_utils.py +++ b/torchao/prototype/safetensors/safetensors_utils.py @@ -28,11 +28,16 @@ "is_metadata_torchao", ] + class Float8TensorAttributeJSONEncoder(json.JSONEncoder): def default(self, o): if o.__class__.__name__ in ALLOWED_TENSORS: tensor_attr_dict = {} - optional_tensor_attributes = o.optional_tensor_attribute_names if hasattr(o, "optional_tensor_attribute_names") else [] + optional_tensor_attributes = ( + o.optional_tensor_attribute_names + if hasattr(o, "optional_tensor_attribute_names") + else [] + ) all_tensor_attributes = ( optional_tensor_attributes + o.tensor_attribute_names From bbeaf52cc44f403fae99d2f854a259cf97751627 Mon Sep 17 00:00:00 2001 From: Angel Li Date: Wed, 24 Sep 2025 14:28:31 -0700 Subject: [PATCH 03/12] Update on "add int4tensor support for safetensors" [ghstack-poisoned] --- .../safetensors/test_safetensors_support.py | 1 + .../safetensors/safetensors_support.py | 21 ++++++++++--------- .../safetensors/safetensors_utils.py | 10 ++++----- 3 files changed, 17 insertions(+), 15 deletions(-) diff --git a/test/prototype/safetensors/test_safetensors_support.py b/test/prototype/safetensors/test_safetensors_support.py index 36d808dc8a..c2902c6a70 100644 --- a/test/prototype/safetensors/test_safetensors_support.py +++ b/test/prototype/safetensors/test_safetensors_support.py @@ -20,6 +20,7 @@ from torchao.quantization.quant_api import ( Float8DynamicActivationFloat8WeightConfig, Int4WeightOnlyConfig, + Float8DynamicActivationInt4WeightConfig, ) from torchao.utils import ( is_sm_at_least_89, diff --git a/torchao/prototype/safetensors/safetensors_support.py b/torchao/prototype/safetensors/safetensors_support.py index c884a2679b..9283493e17 100644 --- a/torchao/prototype/safetensors/safetensors_support.py +++ b/torchao/prototype/safetensors/safetensors_support.py @@ -5,8 +5,8 @@ import torch from torchao.prototype.safetensors.safetensors_utils import ( - ALLOWED_TENSORS, - Float8TensorAttributeJSONEncoder, + ALLOWED_TENSORS_SUBCLASSES, + TensorSubclassAttributeJSONEncoder, object_from_dict, ) @@ -76,11 +76,12 @@ def unflatten_tensor_state_dict( tensor_metadata = json.loads(metadata.get(tensor_name)) tensor_type = tensor_metadata.get("_type") - if tensor_type == torch.Tensor.__name__: - result[tensor_name] = tensor_tensors["_data"] - elif tensor_type in ALLOWED_TENSORS: + + if tensor_type in ALLOWED_TENSORS_SUBCLASSES: tensor_metadata["_data"].update(tensor_tensors) result[tensor_name] = object_from_dict(tensor_metadata) + elif tensor_type == torch.Tensor.__name__: + result[tensor_name] = tensor_tensors["_data"] else: raise ValueError(f"Unsupported tensor type: {tensor_type}") @@ -139,15 +140,15 @@ def flatten_tensor_state_dict( tensors_data_dict = {} for tensor_name, tensor in tensors_dict.items(): - if type(tensor) is torch.Tensor: - tensor_dict = {"_data": tensor} - tensor_metadata = json.dumps({"_type": torch.Tensor.__name__}) - elif tensor.__class__.__name__ in ALLOWED_TENSORS: + if tensor.__class__.__name__ in ALLOWED_TENSORS_SUBCLASSES: tensor_dict = {} for tensor_data_name in tensor.tensor_data_names: tensor_dict[tensor_data_name] = getattr(tensor, tensor_data_name) - tensor_metadata = json.dumps(tensor, cls=Float8TensorAttributeJSONEncoder) + tensor_metadata = json.dumps(tensor, cls=TensorSubclassAttributeJSONEncoder) + elif type(tensor) is torch.Tensor: + tensor_dict = {"_data": tensor} + tensor_metadata = json.dumps({"_type": torch.Tensor.__name__}) else: raise ValueError(f"Unsupported tensor type: {type(tensor)}") diff --git a/torchao/prototype/safetensors/safetensors_utils.py b/torchao/prototype/safetensors/safetensors_utils.py index cb14dde762..2b01d7f729 100644 --- a/torchao/prototype/safetensors/safetensors_utils.py +++ b/torchao/prototype/safetensors/safetensors_utils.py @@ -20,18 +20,18 @@ "KernelPreference": KernelPreference, } -ALLOWED_TENSORS = ["Float8Tensor", "Int4Tensor", "Tensor"] +ALLOWED_TENSORS_SUBCLASSES = ["Float8Tensor", "Int4Tensor"] __all__ = [ - "Float8TensorAttributeJSONEncoder", + "TensorSubclassAttributeJSONEncoder", "object_from_dict", "is_metadata_torchao", ] -class Float8TensorAttributeJSONEncoder(json.JSONEncoder): +class TensorSubclassAttributeJSONEncoder(json.JSONEncoder): def default(self, o): - if o.__class__.__name__ in ALLOWED_TENSORS: + if o.__class__.__name__ in ALLOWED_TENSORS_SUBCLASSES: tensor_attr_dict = {} optional_tensor_attributes = ( o.optional_tensor_attribute_names @@ -197,7 +197,7 @@ def is_metadata_torchao(metadata: Dict[str, Any]): # returns None if _type not in tensor_dict tensor_type = tensor_dict.get("_type") - if tensor_type not in ALLOWED_TENSORS: + if tensor_type not in ALLOWED_TENSORS_SUBCLASSES or tensor_type != "Tensor": return False return True From 0200ad1a8e8f425bb74b31081dac221a7caf8afd Mon Sep 17 00:00:00 2001 From: Angel Li Date: Wed, 24 Sep 2025 14:30:09 -0700 Subject: [PATCH 04/12] Update on "add int4tensor support for safetensors" [ghstack-poisoned] --- test/prototype/safetensors/test_safetensors_support.py | 1 - 1 file changed, 1 deletion(-) diff --git a/test/prototype/safetensors/test_safetensors_support.py b/test/prototype/safetensors/test_safetensors_support.py index c2902c6a70..36d808dc8a 100644 --- a/test/prototype/safetensors/test_safetensors_support.py +++ b/test/prototype/safetensors/test_safetensors_support.py @@ -20,7 +20,6 @@ from torchao.quantization.quant_api import ( Float8DynamicActivationFloat8WeightConfig, Int4WeightOnlyConfig, - Float8DynamicActivationInt4WeightConfig, ) from torchao.utils import ( is_sm_at_least_89, From 60c5de93574157e801278c184ba3375f56257405 Mon Sep 17 00:00:00 2001 From: Angel Li Date: Wed, 24 Sep 2025 15:43:06 -0700 Subject: [PATCH 05/12] Update on "add int4tensor support for safetensors" [ghstack-poisoned] --- .../safetensors/test_safetensors_support.py | 18 +++++++++++------- .../safetensors/safetensors_support.py | 10 ++++++++-- 2 files changed, 19 insertions(+), 9 deletions(-) diff --git a/test/prototype/safetensors/test_safetensors_support.py b/test/prototype/safetensors/test_safetensors_support.py index 36d808dc8a..4bb41e1e83 100644 --- a/test/prototype/safetensors/test_safetensors_support.py +++ b/test/prototype/safetensors/test_safetensors_support.py @@ -21,9 +21,7 @@ Float8DynamicActivationFloat8WeightConfig, Int4WeightOnlyConfig, ) -from torchao.utils import ( - is_sm_at_least_89, -) +from torchao.utils import is_sm_at_least_89 def load_data(file_path: str, device: str): @@ -42,17 +40,23 @@ def load_data(file_path: str, device: str): @unittest.skipIf(not is_sm_at_least_89(), "Need sm89+") class TestSafeTensors(TestCase): @parametrize( - "config", + "config, act_pre_scale", [ - Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()), - Int4WeightOnlyConfig(), + (Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()), None), + (Int4WeightOnlyConfig(), None), + ( + Int4WeightOnlyConfig(), + torch.ones((1), dtype=torch.bfloat16, device="cuda"), + ), ], ) - def test_safetensors(self, config): + def test_safetensors(self, config, act_pre_scale=None): model = torch.nn.Sequential( torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda") ) quantize_(model, config) + if act_pre_scale is not None: + model[0].weight.act_pre_scale = act_pre_scale example_inputs = (torch.randn(2, 128, dtype=torch.bfloat16, device="cuda"),) ref_output = model(*example_inputs) diff --git a/torchao/prototype/safetensors/safetensors_support.py b/torchao/prototype/safetensors/safetensors_support.py index 9283493e17..61457ea328 100644 --- a/torchao/prototype/safetensors/safetensors_support.py +++ b/torchao/prototype/safetensors/safetensors_support.py @@ -142,8 +142,14 @@ def flatten_tensor_state_dict( for tensor_name, tensor in tensors_dict.items(): if tensor.__class__.__name__ in ALLOWED_TENSORS_SUBCLASSES: tensor_dict = {} - for tensor_data_name in tensor.tensor_data_names: - tensor_dict[tensor_data_name] = getattr(tensor, tensor_data_name) + + all_tensor_data = list(tensor.tensor_data_names) # create a copy + if hasattr(tensor, "optional_tensor_data_names"): + all_tensor_data += tensor.optional_tensor_data_names + + for tensor_data_name in all_tensor_data: + if getattr(tensor, tensor_data_name) is not None: + tensor_dict[tensor_data_name] = getattr(tensor, tensor_data_name) tensor_metadata = json.dumps(tensor, cls=TensorSubclassAttributeJSONEncoder) elif type(tensor) is torch.Tensor: From 6e316830f470819eded3919944ae0ee17ffe663d Mon Sep 17 00:00:00 2001 From: Angel Li Date: Wed, 24 Sep 2025 15:56:17 -0700 Subject: [PATCH 06/12] Update on "add int4tensor support for safetensors" **Summary** adding `Int4Tensor` support for safetensors (`Int4WeightOnlyConfig`) **Test plan** modified unit test to include `Int4WeightOnlyConfig` `python test/prototype/safetensors/test_safetensors_support.py` [ghstack-poisoned] --- test/prototype/safetensors/test_safetensors_support.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/prototype/safetensors/test_safetensors_support.py b/test/prototype/safetensors/test_safetensors_support.py index 4bb41e1e83..543a14cb6c 100644 --- a/test/prototype/safetensors/test_safetensors_support.py +++ b/test/prototype/safetensors/test_safetensors_support.py @@ -46,7 +46,7 @@ class TestSafeTensors(TestCase): (Int4WeightOnlyConfig(), None), ( Int4WeightOnlyConfig(), - torch.ones((1), dtype=torch.bfloat16, device="cuda"), + torch.ones((1), dtype=torch.bfloat16), ), ], ) @@ -56,6 +56,7 @@ def test_safetensors(self, config, act_pre_scale=None): ) quantize_(model, config) if act_pre_scale is not None: + act_pre_scale = act_pre_scale.to("cuda") model[0].weight.act_pre_scale = act_pre_scale example_inputs = (torch.randn(2, 128, dtype=torch.bfloat16, device="cuda"),) ref_output = model(*example_inputs) From 86d7fcf7a3dc4690eff0b3d960db1e11a3375262 Mon Sep 17 00:00:00 2001 From: Angel Li Date: Wed, 24 Sep 2025 16:04:50 -0700 Subject: [PATCH 07/12] Update on "add int4tensor support for safetensors" **Summary** adding `Int4Tensor` support for safetensors (`Int4WeightOnlyConfig`) **Test plan** modified unit test to include `Int4WeightOnlyConfig` `python test/prototype/safetensors/test_safetensors_support.py` [ghstack-poisoned] --- .../safetensors/test_safetensors_support.py | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/test/prototype/safetensors/test_safetensors_support.py b/test/prototype/safetensors/test_safetensors_support.py index 543a14cb6c..1f6a031ab5 100644 --- a/test/prototype/safetensors/test_safetensors_support.py +++ b/test/prototype/safetensors/test_safetensors_support.py @@ -42,22 +42,20 @@ class TestSafeTensors(TestCase): @parametrize( "config, act_pre_scale", [ - (Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()), None), - (Int4WeightOnlyConfig(), None), - ( - Int4WeightOnlyConfig(), - torch.ones((1), dtype=torch.bfloat16), - ), + (Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()), False), + (Int4WeightOnlyConfig(), False), + (Int4WeightOnlyConfig(), True), ], ) - def test_safetensors(self, config, act_pre_scale=None): + def test_safetensors(self, config, act_pre_scale=False): model = torch.nn.Sequential( torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda") ) quantize_(model, config) - if act_pre_scale is not None: - act_pre_scale = act_pre_scale.to("cuda") - model[0].weight.act_pre_scale = act_pre_scale + if act_pre_scale: + model[0].weight.act_pre_scale = torch.ones( + (1), dtype=torch.bfloat16, device="cuda" + ) example_inputs = (torch.randn(2, 128, dtype=torch.bfloat16, device="cuda"),) ref_output = model(*example_inputs) From 9395b3fb9bf6fad4174bcca65ca5cc64f2958009 Mon Sep 17 00:00:00 2001 From: Angel Li Date: Thu, 25 Sep 2025 06:50:52 -0700 Subject: [PATCH 08/12] add int4preshuffledtensor to safetensors [ghstack-poisoned] --- test/prototype/safetensors/test_safetensors_support.py | 2 ++ torchao/prototype/safetensors/safetensors_utils.py | 5 +++-- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/test/prototype/safetensors/test_safetensors_support.py b/test/prototype/safetensors/test_safetensors_support.py index 1f6a031ab5..a5fb5bf65b 100644 --- a/test/prototype/safetensors/test_safetensors_support.py +++ b/test/prototype/safetensors/test_safetensors_support.py @@ -20,6 +20,7 @@ from torchao.quantization.quant_api import ( Float8DynamicActivationFloat8WeightConfig, Int4WeightOnlyConfig, + Float8DynamicActivationInt4WeightConfig ) from torchao.utils import is_sm_at_least_89 @@ -45,6 +46,7 @@ class TestSafeTensors(TestCase): (Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()), False), (Int4WeightOnlyConfig(), False), (Int4WeightOnlyConfig(), True), + (Float8DynamicActivationInt4WeightConfig(), False) ], ) def test_safetensors(self, config, act_pre_scale=False): diff --git a/torchao/prototype/safetensors/safetensors_utils.py b/torchao/prototype/safetensors/safetensors_utils.py index 2b01d7f729..a880f01545 100644 --- a/torchao/prototype/safetensors/safetensors_utils.py +++ b/torchao/prototype/safetensors/safetensors_utils.py @@ -6,13 +6,14 @@ import torch import torchao -from torchao.quantization import Float8Tensor, Int4Tensor +from torchao.quantization import Float8Tensor, Int4Tensor, Int4PreshuffledTensor from torchao.quantization.quantize_.common import KernelPreference from torchao.quantization.quantize_.workflows import QuantizeTensorToFloat8Kwargs ALLOWED_CLASSES = { "Float8Tensor": Float8Tensor, "Int4Tensor": Int4Tensor, + "Int4PreshuffledTensor": Int4PreshuffledTensor, "Float8MMConfig": torchao.float8.inference.Float8MMConfig, "QuantizeTensorToFloat8Kwargs": QuantizeTensorToFloat8Kwargs, "PerRow": torchao.quantization.PerRow, @@ -20,7 +21,7 @@ "KernelPreference": KernelPreference, } -ALLOWED_TENSORS_SUBCLASSES = ["Float8Tensor", "Int4Tensor"] +ALLOWED_TENSORS_SUBCLASSES = ["Float8Tensor", "Int4Tensor", "Int4PreshuffledTensor"] __all__ = [ "TensorSubclassAttributeJSONEncoder", From 585afe25d6b9117caf27f8640fb32a754a6f4b15 Mon Sep 17 00:00:00 2001 From: Angel Li Date: Thu, 25 Sep 2025 06:55:53 -0700 Subject: [PATCH 09/12] Update on "add int4preshuffledtensor to safetensors" adding `Int4PreshuffledTensor` to safetensors modified unit test, `python test/prototype/safetensors/test_safetensors_support.py` [ghstack-poisoned] --- test/prototype/safetensors/test_safetensors_support.py | 4 ++-- torchao/prototype/safetensors/safetensors_utils.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/test/prototype/safetensors/test_safetensors_support.py b/test/prototype/safetensors/test_safetensors_support.py index a5fb5bf65b..63d07fd9e9 100644 --- a/test/prototype/safetensors/test_safetensors_support.py +++ b/test/prototype/safetensors/test_safetensors_support.py @@ -19,8 +19,8 @@ from torchao.quantization.granularity import PerRow from torchao.quantization.quant_api import ( Float8DynamicActivationFloat8WeightConfig, + Float8DynamicActivationInt4WeightConfig, Int4WeightOnlyConfig, - Float8DynamicActivationInt4WeightConfig ) from torchao.utils import is_sm_at_least_89 @@ -46,7 +46,7 @@ class TestSafeTensors(TestCase): (Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()), False), (Int4WeightOnlyConfig(), False), (Int4WeightOnlyConfig(), True), - (Float8DynamicActivationInt4WeightConfig(), False) + (Float8DynamicActivationInt4WeightConfig(), False), ], ) def test_safetensors(self, config, act_pre_scale=False): diff --git a/torchao/prototype/safetensors/safetensors_utils.py b/torchao/prototype/safetensors/safetensors_utils.py index a880f01545..dcc291662b 100644 --- a/torchao/prototype/safetensors/safetensors_utils.py +++ b/torchao/prototype/safetensors/safetensors_utils.py @@ -6,7 +6,7 @@ import torch import torchao -from torchao.quantization import Float8Tensor, Int4Tensor, Int4PreshuffledTensor +from torchao.quantization import Float8Tensor, Int4PreshuffledTensor, Int4Tensor from torchao.quantization.quantize_.common import KernelPreference from torchao.quantization.quantize_.workflows import QuantizeTensorToFloat8Kwargs From 87d2e184361caff846a81a976f242afa48b0ac13 Mon Sep 17 00:00:00 2001 From: Angel Li Date: Thu, 25 Sep 2025 07:09:09 -0700 Subject: [PATCH 10/12] add int4tilepackedto4dtensor subclass to safetensors [ghstack-poisoned] --- .../safetensors/test_safetensors_support.py | 1 + .../prototype/safetensors/safetensors_utils.py | 15 +++++++++++++-- 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/test/prototype/safetensors/test_safetensors_support.py b/test/prototype/safetensors/test_safetensors_support.py index 63d07fd9e9..f3b8d914d0 100644 --- a/test/prototype/safetensors/test_safetensors_support.py +++ b/test/prototype/safetensors/test_safetensors_support.py @@ -46,6 +46,7 @@ class TestSafeTensors(TestCase): (Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()), False), (Int4WeightOnlyConfig(), False), (Int4WeightOnlyConfig(), True), + (Int4WeightOnlyConfig(int4_packing_format="tile_packed_to_4d"), True), (Float8DynamicActivationInt4WeightConfig(), False), ], ) diff --git a/torchao/prototype/safetensors/safetensors_utils.py b/torchao/prototype/safetensors/safetensors_utils.py index dcc291662b..8748db61b5 100644 --- a/torchao/prototype/safetensors/safetensors_utils.py +++ b/torchao/prototype/safetensors/safetensors_utils.py @@ -6,7 +6,12 @@ import torch import torchao -from torchao.quantization import Float8Tensor, Int4PreshuffledTensor, Int4Tensor +from torchao.quantization import ( + Float8Tensor, + Int4PreshuffledTensor, + Int4Tensor, + Int4TilePackedTo4dTensor, +) from torchao.quantization.quantize_.common import KernelPreference from torchao.quantization.quantize_.workflows import QuantizeTensorToFloat8Kwargs @@ -14,6 +19,7 @@ "Float8Tensor": Float8Tensor, "Int4Tensor": Int4Tensor, "Int4PreshuffledTensor": Int4PreshuffledTensor, + "Int4TilePackedTo4dTensor": Int4TilePackedTo4dTensor, "Float8MMConfig": torchao.float8.inference.Float8MMConfig, "QuantizeTensorToFloat8Kwargs": QuantizeTensorToFloat8Kwargs, "PerRow": torchao.quantization.PerRow, @@ -21,7 +27,12 @@ "KernelPreference": KernelPreference, } -ALLOWED_TENSORS_SUBCLASSES = ["Float8Tensor", "Int4Tensor", "Int4PreshuffledTensor"] +ALLOWED_TENSORS_SUBCLASSES = [ + "Float8Tensor", + "Int4Tensor", + "Int4PreshuffledTensor", + "Int4TilePackedTo4dTensor", +] __all__ = [ "TensorSubclassAttributeJSONEncoder", From 6007af3f95740978a9b2d1261a169b021e283000 Mon Sep 17 00:00:00 2001 From: Angel Li Date: Thu, 25 Sep 2025 07:10:41 -0700 Subject: [PATCH 11/12] Update on "add int4tilepackedto4dtensor subclass to safetensors" [ghstack-poisoned] --- test/prototype/safetensors/test_safetensors_support.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/prototype/safetensors/test_safetensors_support.py b/test/prototype/safetensors/test_safetensors_support.py index f3b8d914d0..48f6d432e5 100644 --- a/test/prototype/safetensors/test_safetensors_support.py +++ b/test/prototype/safetensors/test_safetensors_support.py @@ -46,7 +46,7 @@ class TestSafeTensors(TestCase): (Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()), False), (Int4WeightOnlyConfig(), False), (Int4WeightOnlyConfig(), True), - (Int4WeightOnlyConfig(int4_packing_format="tile_packed_to_4d"), True), + (Int4WeightOnlyConfig(int4_packing_format="tile_packed_to_4d"), False), (Float8DynamicActivationInt4WeightConfig(), False), ], ) From 8820dee5f5dc7ec43920619cb0f923284d1f1af2 Mon Sep 17 00:00:00 2001 From: Angel Li Date: Thu, 25 Sep 2025 12:16:35 -0700 Subject: [PATCH 12/12] Update base for Update on "add int4tilepackedto4dtensor subclass to safetensors" adding `Int4TilePackedTo4dTensor` to safetensors (`Int4WeightOnlyConfig(int4_packing_format="tile_packed_to_4d")`) modified unit test, `python test/prototype/safetensors/test_safetensors_support.py` [ghstack-poisoned] --- test/prototype/safetensors/test_safetensors_support.py | 2 -- torchao/prototype/safetensors/safetensors_utils.py | 5 ++--- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/test/prototype/safetensors/test_safetensors_support.py b/test/prototype/safetensors/test_safetensors_support.py index 63d07fd9e9..1f6a031ab5 100644 --- a/test/prototype/safetensors/test_safetensors_support.py +++ b/test/prototype/safetensors/test_safetensors_support.py @@ -19,7 +19,6 @@ from torchao.quantization.granularity import PerRow from torchao.quantization.quant_api import ( Float8DynamicActivationFloat8WeightConfig, - Float8DynamicActivationInt4WeightConfig, Int4WeightOnlyConfig, ) from torchao.utils import is_sm_at_least_89 @@ -46,7 +45,6 @@ class TestSafeTensors(TestCase): (Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()), False), (Int4WeightOnlyConfig(), False), (Int4WeightOnlyConfig(), True), - (Float8DynamicActivationInt4WeightConfig(), False), ], ) def test_safetensors(self, config, act_pre_scale=False): diff --git a/torchao/prototype/safetensors/safetensors_utils.py b/torchao/prototype/safetensors/safetensors_utils.py index dcc291662b..2b01d7f729 100644 --- a/torchao/prototype/safetensors/safetensors_utils.py +++ b/torchao/prototype/safetensors/safetensors_utils.py @@ -6,14 +6,13 @@ import torch import torchao -from torchao.quantization import Float8Tensor, Int4PreshuffledTensor, Int4Tensor +from torchao.quantization import Float8Tensor, Int4Tensor from torchao.quantization.quantize_.common import KernelPreference from torchao.quantization.quantize_.workflows import QuantizeTensorToFloat8Kwargs ALLOWED_CLASSES = { "Float8Tensor": Float8Tensor, "Int4Tensor": Int4Tensor, - "Int4PreshuffledTensor": Int4PreshuffledTensor, "Float8MMConfig": torchao.float8.inference.Float8MMConfig, "QuantizeTensorToFloat8Kwargs": QuantizeTensorToFloat8Kwargs, "PerRow": torchao.quantization.PerRow, @@ -21,7 +20,7 @@ "KernelPreference": KernelPreference, } -ALLOWED_TENSORS_SUBCLASSES = ["Float8Tensor", "Int4Tensor", "Int4PreshuffledTensor"] +ALLOWED_TENSORS_SUBCLASSES = ["Float8Tensor", "Int4Tensor"] __all__ = [ "TensorSubclassAttributeJSONEncoder",