Skip to content
32 changes: 24 additions & 8 deletions test/prototype/safetensors/test_safetensors_support.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
from safetensors.torch import load_file, save_file
from torch.testing._internal.common_utils import (
TestCase,
instantiate_parametrized_tests,
parametrize,
run_tests,
)

Expand All @@ -15,10 +17,11 @@
unflatten_tensor_state_dict,
)
from torchao.quantization.granularity import PerRow
from torchao.quantization.quant_api import Float8DynamicActivationFloat8WeightConfig
from torchao.utils import (
is_sm_at_least_89,
from torchao.quantization.quant_api import (
Float8DynamicActivationFloat8WeightConfig,
Int4WeightOnlyConfig,
)
from torchao.utils import is_sm_at_least_89


def load_data(file_path: str, device: str):
Expand All @@ -36,13 +39,24 @@ def load_data(file_path: str, device: str):
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@unittest.skipIf(not is_sm_at_least_89(), "Need sm89+")
class TestSafeTensors(TestCase):
def test_safetensors(self):
config = Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())
@parametrize(
"config, act_pre_scale",
[
(Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()), False),
(Int4WeightOnlyConfig(), False),
(Int4WeightOnlyConfig(), True),
],
)
def test_safetensors(self, config, act_pre_scale=False):
model = torch.nn.Sequential(
torch.nn.Linear(32, 256, dtype=torch.bfloat16, device="cuda")
torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
)
quantize_(model, config)
example_inputs = (torch.randn(2, 32, dtype=torch.bfloat16, device="cuda"),)
if act_pre_scale:
model[0].weight.act_pre_scale = torch.ones(
(1), dtype=torch.bfloat16, device="cuda"
)
example_inputs = (torch.randn(2, 128, dtype=torch.bfloat16, device="cuda"),)
ref_output = model(*example_inputs)

with tempfile.NamedTemporaryFile() as f:
Expand All @@ -54,12 +68,14 @@ def test_safetensors(self):
)

model = torch.nn.Sequential(
torch.nn.Linear(32, 256, dtype=torch.bfloat16, device="cuda")
torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
)
model.load_state_dict(reconstructed_dict, assign=True)
output = model(*example_inputs)
assert torch.equal(output, ref_output)


# Must be called at module level for the @parametrize decorator used on
# TestSafeTensors.test_safetensors: it registers the parametrized variants on
# the class before run_tests() executes.
# NOTE(review): exact generated test names depend on the torch helper — confirm.
instantiate_parametrized_tests(TestSafeTensors)

if __name__ == "__main__":
run_tests()
20 changes: 13 additions & 7 deletions torchao/prototype/safetensors/safetensors_support.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@
import torch

from torchao.prototype.safetensors.safetensors_utils import (
Float8TensorAttributeJSONEncoder,
ALLOWED_TENSORS_SUBCLASSES,
TensorSubclassAttributeJSONEncoder,
object_from_dict,
)
from torchao.quantization import Float8Tensor

logger: logging.Logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -77,7 +77,7 @@ def unflatten_tensor_state_dict(
tensor_metadata = json.loads(metadata.get(tensor_name))
tensor_type = tensor_metadata.get("_type")

if tensor_type == Float8Tensor.__name__:
if tensor_type in ALLOWED_TENSORS_SUBCLASSES:
tensor_metadata["_data"].update(tensor_tensors)
result[tensor_name] = object_from_dict(tensor_metadata)
elif tensor_type == torch.Tensor.__name__:
Expand Down Expand Up @@ -140,12 +140,18 @@ def flatten_tensor_state_dict(
tensors_data_dict = {}

for tensor_name, tensor in tensors_dict.items():
if isinstance(tensor, Float8Tensor):
if tensor.__class__.__name__ in ALLOWED_TENSORS_SUBCLASSES:
tensor_dict = {}
for tensor_data_name in tensor.tensor_data_names:
tensor_dict[tensor_data_name] = getattr(tensor, tensor_data_name)

tensor_metadata = json.dumps(tensor, cls=Float8TensorAttributeJSONEncoder)
all_tensor_data = list(tensor.tensor_data_names) # create a copy
if hasattr(tensor, "optional_tensor_data_names"):
all_tensor_data += tensor.optional_tensor_data_names

for tensor_data_name in all_tensor_data:
if getattr(tensor, tensor_data_name) is not None:
tensor_dict[tensor_data_name] = getattr(tensor, tensor_data_name)

tensor_metadata = json.dumps(tensor, cls=TensorSubclassAttributeJSONEncoder)
elif type(tensor) is torch.Tensor:
tensor_dict = {"_data": tensor}
tensor_metadata = json.dumps({"_type": torch.Tensor.__name__})
Expand Down
21 changes: 14 additions & 7 deletions torchao/prototype/safetensors/safetensors_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,34 +6,41 @@
import torch

import torchao
from torchao.quantization import Float8Tensor
from torchao.quantization import Float8Tensor, Int4Tensor
from torchao.quantization.quantize_.common import KernelPreference
from torchao.quantization.quantize_.workflows import QuantizeTensorToFloat8Kwargs

# Name -> class lookup table. Every key is the class's own name as a string,
# covering the tensor subclasses (Float8Tensor, Int4Tensor) and the config /
# enum types that appear among their attributes.
# NOTE(review): presumably the allow-list consulted by object_from_dict when it
# turns a serialized "_type" name back into a class — confirm against that
# function, which is not visible here.
ALLOWED_CLASSES = {
"Float8Tensor": Float8Tensor,
"Int4Tensor": Int4Tensor,
"Float8MMConfig": torchao.float8.inference.Float8MMConfig,
"QuantizeTensorToFloat8Kwargs": QuantizeTensorToFloat8Kwargs,
"PerRow": torchao.quantization.PerRow,
"PerTensor": torchao.quantization.PerTensor,
"KernelPreference": KernelPreference,
}

ALLOWED_TENSORS = ["Float8Tensor", "Tensor"]
# Class names of the tensor subclasses supported by the safetensors path.
# Membership is tested with plain strings: against `obj.__class__.__name__`
# when flattening/encoding and against the serialized "_type" field when
# unflatten/validating. Plain torch.Tensor ("Tensor") is not listed here; the
# call sites check for it separately.
# NOTE(review): the validation in is_metadata_torchao is written as
# `tensor_type not in ALLOWED_TENSORS_SUBCLASSES or tensor_type != "Tensor"`.
# Because "Tensor" is not in this list, that expression is true for every
# value of tensor_type, so the check always takes its `return False` branch.
# `and` looks like the intended operator.
ALLOWED_TENSORS_SUBCLASSES = ["Float8Tensor", "Int4Tensor"]

__all__ = [
"Float8TensorAttributeJSONEncoder",
"TensorSubclassAttributeJSONEncoder",
"object_from_dict",
"is_metadata_torchao",
]


class Float8TensorAttributeJSONEncoder(json.JSONEncoder):
class TensorSubclassAttributeJSONEncoder(json.JSONEncoder):
def default(self, o):
if isinstance(o, Float8Tensor):
if o.__class__.__name__ in ALLOWED_TENSORS_SUBCLASSES:
tensor_attr_dict = {}
optional_tensor_attributes = (
o.optional_tensor_attribute_names
if hasattr(o, "optional_tensor_attribute_names")
else []
)

all_tensor_attributes = (
o.optional_tensor_attribute_names + o.tensor_attribute_names
optional_tensor_attributes + o.tensor_attribute_names
)

for tensor_attribute_name in all_tensor_attributes:
Expand Down Expand Up @@ -190,7 +197,7 @@ def is_metadata_torchao(metadata: Dict[str, Any]):

# returns None if _type not in tensor_dict
tensor_type = tensor_dict.get("_type")
if tensor_type not in ALLOWED_TENSORS:
if tensor_type not in ALLOWED_TENSORS_SUBCLASSES or tensor_type != "Tensor":
return False

return True
Loading