diff --git a/test/prototype/safetensors/test_safetensors_support.py b/test/prototype/safetensors/test_safetensors_support.py index b67bf2bf0c..1f6a031ab5 100644 --- a/test/prototype/safetensors/test_safetensors_support.py +++ b/test/prototype/safetensors/test_safetensors_support.py @@ -6,6 +6,8 @@ from safetensors.torch import load_file, save_file from torch.testing._internal.common_utils import ( TestCase, + instantiate_parametrized_tests, + parametrize, run_tests, ) @@ -15,10 +17,11 @@ unflatten_tensor_state_dict, ) from torchao.quantization.granularity import PerRow -from torchao.quantization.quant_api import Float8DynamicActivationFloat8WeightConfig -from torchao.utils import ( - is_sm_at_least_89, +from torchao.quantization.quant_api import ( + Float8DynamicActivationFloat8WeightConfig, + Int4WeightOnlyConfig, ) +from torchao.utils import is_sm_at_least_89 def load_data(file_path: str, device: str): @@ -36,13 +39,24 @@ def load_data(file_path: str, device: str): @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") @unittest.skipIf(not is_sm_at_least_89(), "Need sm89+") class TestSafeTensors(TestCase): - def test_safetensors(self): - config = Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()) + @parametrize( + "config, act_pre_scale", + [ + (Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()), False), + (Int4WeightOnlyConfig(), False), + (Int4WeightOnlyConfig(), True), + ], + ) + def test_safetensors(self, config, act_pre_scale=False): model = torch.nn.Sequential( - torch.nn.Linear(32, 256, dtype=torch.bfloat16, device="cuda") + torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda") ) quantize_(model, config) - example_inputs = (torch.randn(2, 32, dtype=torch.bfloat16, device="cuda"),) + if act_pre_scale: + model[0].weight.act_pre_scale = torch.ones( + (1), dtype=torch.bfloat16, device="cuda" + ) + example_inputs = (torch.randn(2, 128, dtype=torch.bfloat16, device="cuda"),) ref_output = model(*example_inputs) 
with tempfile.NamedTemporaryFile() as f: @@ -54,12 +68,14 @@ def test_safetensors(self): ) model = torch.nn.Sequential( - torch.nn.Linear(32, 256, dtype=torch.bfloat16, device="cuda") + torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda") ) model.load_state_dict(reconstructed_dict, assign=True) output = model(*example_inputs) assert torch.equal(output, ref_output) +instantiate_parametrized_tests(TestSafeTensors) + if __name__ == "__main__": run_tests() diff --git a/torchao/prototype/safetensors/safetensors_support.py b/torchao/prototype/safetensors/safetensors_support.py index 19943e4b4a..61457ea328 100644 --- a/torchao/prototype/safetensors/safetensors_support.py +++ b/torchao/prototype/safetensors/safetensors_support.py @@ -5,10 +5,10 @@ import torch from torchao.prototype.safetensors.safetensors_utils import ( - Float8TensorAttributeJSONEncoder, + ALLOWED_TENSORS_SUBCLASSES, + TensorSubclassAttributeJSONEncoder, object_from_dict, ) -from torchao.quantization import Float8Tensor logger: logging.Logger = logging.getLogger(__name__) @@ -77,7 +77,7 @@ def unflatten_tensor_state_dict( tensor_metadata = json.loads(metadata.get(tensor_name)) tensor_type = tensor_metadata.get("_type") - if tensor_type == Float8Tensor.__name__: + if tensor_type in ALLOWED_TENSORS_SUBCLASSES: tensor_metadata["_data"].update(tensor_tensors) result[tensor_name] = object_from_dict(tensor_metadata) elif tensor_type == torch.Tensor.__name__: @@ -140,12 +140,18 @@ def flatten_tensor_state_dict( tensors_data_dict = {} for tensor_name, tensor in tensors_dict.items(): - if isinstance(tensor, Float8Tensor): + if tensor.__class__.__name__ in ALLOWED_TENSORS_SUBCLASSES: tensor_dict = {} - for tensor_data_name in tensor.tensor_data_names: - tensor_dict[tensor_data_name] = getattr(tensor, tensor_data_name) - tensor_metadata = json.dumps(tensor, cls=Float8TensorAttributeJSONEncoder) + all_tensor_data = list(tensor.tensor_data_names) # create a copy + if hasattr(tensor, 
"optional_tensor_data_names"): + all_tensor_data += tensor.optional_tensor_data_names + + for tensor_data_name in all_tensor_data: + if getattr(tensor, tensor_data_name) is not None: + tensor_dict[tensor_data_name] = getattr(tensor, tensor_data_name) + + tensor_metadata = json.dumps(tensor, cls=TensorSubclassAttributeJSONEncoder) elif type(tensor) is torch.Tensor: tensor_dict = {"_data": tensor} tensor_metadata = json.dumps({"_type": torch.Tensor.__name__}) diff --git a/torchao/prototype/safetensors/safetensors_utils.py b/torchao/prototype/safetensors/safetensors_utils.py index eb0258a505..2b01d7f729 100644 --- a/torchao/prototype/safetensors/safetensors_utils.py +++ b/torchao/prototype/safetensors/safetensors_utils.py @@ -6,12 +6,13 @@ import torch import torchao -from torchao.quantization import Float8Tensor +from torchao.quantization import Float8Tensor, Int4Tensor from torchao.quantization.quantize_.common import KernelPreference from torchao.quantization.quantize_.workflows import QuantizeTensorToFloat8Kwargs ALLOWED_CLASSES = { "Float8Tensor": Float8Tensor, + "Int4Tensor": Int4Tensor, "Float8MMConfig": torchao.float8.inference.Float8MMConfig, "QuantizeTensorToFloat8Kwargs": QuantizeTensorToFloat8Kwargs, "PerRow": torchao.quantization.PerRow, @@ -19,21 +20,27 @@ "KernelPreference": KernelPreference, } -ALLOWED_TENSORS = ["Float8Tensor", "Tensor"] +ALLOWED_TENSORS_SUBCLASSES = ["Float8Tensor", "Int4Tensor"] __all__ = [ - "Float8TensorAttributeJSONEncoder", + "TensorSubclassAttributeJSONEncoder", "object_from_dict", "is_metadata_torchao", ] -class Float8TensorAttributeJSONEncoder(json.JSONEncoder): +class TensorSubclassAttributeJSONEncoder(json.JSONEncoder): def default(self, o): - if isinstance(o, Float8Tensor): + if o.__class__.__name__ in ALLOWED_TENSORS_SUBCLASSES: tensor_attr_dict = {} + optional_tensor_attributes = ( + o.optional_tensor_attribute_names + if hasattr(o, "optional_tensor_attribute_names") + else [] + ) + all_tensor_attributes = ( - 
o.optional_tensor_attribute_names + o.tensor_attribute_names + optional_tensor_attributes + o.tensor_attribute_names ) for tensor_attribute_name in all_tensor_attributes: @@ -190,7 +197,7 @@ def is_metadata_torchao(metadata: Dict[str, Any]): # returns None if _type not in tensor_dict tensor_type = tensor_dict.get("_type") - if tensor_type not in ALLOWED_TENSORS: + if tensor_type not in ALLOWED_TENSORS_SUBCLASSES and tensor_type != "Tensor": return False return True