From f179ca547a5e185ba4b4657ee415695427c66b84 Mon Sep 17 00:00:00 2001 From: Angel Li Date: Thu, 13 Nov 2025 12:00:03 -0800 Subject: [PATCH] changes --- .../safetensors/test_safetensors_support.py | 44 ++++++++++++++++++- .../safetensors/safetensors_support.py | 36 +++++++++++---- .../safetensors/safetensors_utils.py | 18 +++++++- 3 files changed, 88 insertions(+), 10 deletions(-) diff --git a/test/prototype/safetensors/test_safetensors_support.py b/test/prototype/safetensors/test_safetensors_support.py index bfec170fd0..6892a0ca22 100644 --- a/test/prototype/safetensors/test_safetensors_support.py +++ b/test/prototype/safetensors/test_safetensors_support.py @@ -74,9 +74,10 @@ def test_safetensors(self, config, act_pre_scale=False): save_file(tensors_data_dict, f.name, metadata=metadata) tensors_data_dict, metadata = load_data(file_path=f.name, device="cuda") - reconstructed_dict = unflatten_tensor_state_dict( + reconstructed_dict, leftover_tensor_data_dict = unflatten_tensor_state_dict( tensors_data_dict, metadata ) + assert not leftover_tensor_data_dict model = torch.nn.Sequential( torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda") @@ -85,6 +86,47 @@ def test_safetensors(self, config, act_pre_scale=False): output = model(*example_inputs) assert torch.equal(output, ref_output) + @parametrize( + "config, act_pre_scale", + [ + (Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()), False), + (Int4WeightOnlyConfig(), False), + (Int4WeightOnlyConfig(), True), + (Int4WeightOnlyConfig(int4_packing_format="tile_packed_to_4d"), False), + (IntxWeightOnlyConfig(), False), + (Int8DynamicActivationIntxWeightConfig(), False), + ], + ) + def test_safetensors_sharded(self, config, act_pre_scale=False): + model = torch.nn.Sequential( + torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda") + ) + quantize_(model, config) + if act_pre_scale: + model[0].weight.act_pre_scale = torch.ones( + (1), dtype=torch.bfloat16, device="cuda" + ) + + with 
tempfile.NamedTemporaryFile() as f: + tensors_data_dict, metadata = flatten_tensor_state_dict(model.state_dict()) + save_file(tensors_data_dict, f.name, metadata=metadata) + tensors_data_dict, metadata = load_data(file_path=f.name, device="cuda") + + # simulate missing info on future file + if act_pre_scale: + del tensors_data_dict["0._weight_act_pre_scale"] # optional tensor data + else: + del tensors_data_dict["0._weight_qdata"] + + reconstructed_dict, leftover_tensor_data_dict = unflatten_tensor_state_dict( + tensors_data_dict, metadata + ) + + # since qdata is missing, layer 0 should not have been processed + for key in tensors_data_dict.keys(): + if key.startswith("0._weight_"): + assert key in leftover_tensor_data_dict + instantiate_parametrized_tests(TestSafeTensors) diff --git a/torchao/prototype/safetensors/safetensors_support.py b/torchao/prototype/safetensors/safetensors_support.py index 63623dcb15..08040f17b6 100644 --- a/torchao/prototype/safetensors/safetensors_support.py +++ b/torchao/prototype/safetensors/safetensors_support.py @@ -34,7 +34,8 @@ def unflatten_tensor_state_dict( '_data': { 'block_size': [1,32], ... 
- } + }, + '_tensor_data_names': ['qdata', 'scale'] } '0.bias': { '_type': 'torch.Tensor', @@ -66,12 +67,15 @@ def unflatten_tensor_state_dict( tensor_names = json.loads(metadata["tensor_names"]) result = {} - + leftover_state_dict = tensors_data_dict.copy() for tensor_name in tensor_names: + processed_tensors = [] + module_fqn, weight_name = tensor_name.rsplit(".", 1) prefix = f"{module_fqn}._{weight_name}_" tensor_tensors = {} + for key, value in combined_data.items(): if key.startswith(prefix): # Remove the prefix @@ -79,20 +83,35 @@ def unflatten_tensor_state_dict( tensor_metadata = json.loads(metadata.get(tensor_name)) tensor_type = tensor_metadata.get("_type") + complete_tensor_data_names = tensor_metadata.get("_tensor_data_names") if tensor_type in ALLOWED_TENSORS_SUBCLASSES: - if not tensor_tensors: - # we allow the option of loading in state_dict info for a single tensor - # if tensor state dict info is not loaded in yet, we wait for it to be provided - # in a future call + # if not all tensor data is present (ie missing qdata) we wait for it + # to be loaded in from a future call + if len(tensor_tensors) != len(complete_tensor_data_names): continue tensor_metadata["_data"].update(tensor_tensors) result[tensor_name] = object_from_dict(tensor_metadata) + + for suffix in complete_tensor_data_names: + processed_tensors.append(prefix + suffix) elif tensor_type == torch.Tensor.__name__: + # we allow the option of loading in state_dict info for a single tensor + # if tensor state dict info is not loaded in yet, we wait for it to be provided + # in a future call + if tensor_name not in tensors_data_dict.keys(): + continue result[tensor_name] = tensors_data_dict[tensor_name] + processed_tensors.append( + tensor_name + ) # add here because key for torch.Tensor has no prefix else: raise ValueError(f"Unsupported tensor type: {tensor_type}") - return result + + for tensor_name in processed_tensors: + del leftover_state_dict[tensor_name] + + return result, 
leftover_state_dict def flatten_tensor_state_dict( @@ -125,7 +144,8 @@ def flatten_tensor_state_dict( '_data': { 'block_size': [1,32], ... - } + }, + '_tensor_data_names': ['qdata', 'scale'] } '0.bias': { '_type': 'torch.Tensor', diff --git a/torchao/prototype/safetensors/safetensors_utils.py b/torchao/prototype/safetensors/safetensors_utils.py index 3b1d032f67..9630515039 100644 --- a/torchao/prototype/safetensors/safetensors_utils.py +++ b/torchao/prototype/safetensors/safetensors_utils.py @@ -60,7 +60,23 @@ def default(self, o): encoded_attribute = self.encode_value(attribute) tensor_attr_dict[tensor_attribute_name] = encoded_attribute - return {"_type": o.__class__.__name__, "_data": tensor_attr_dict} + optional_tensor_data_names = ( + o.optional_tensor_data_names + if hasattr(o, "optional_tensor_data_names") + else [] + ) + all_tensor_data_names = optional_tensor_data_names + o.tensor_data_names + + _tensor_data_names = [] + for tensor_data_name in all_tensor_data_names: + if getattr(o, tensor_data_name) is not None: + _tensor_data_names.append(tensor_data_name) + + return { + "_type": o.__class__.__name__, + "_data": tensor_attr_dict, + "_tensor_data_names": _tensor_data_names, + } if hasattr(o, "_fields") and hasattr( o, "_asdict"