Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 26 additions & 4 deletions test/prototype/safetensors/test_safetensors_support.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
import json
import tempfile
import unittest

import torch
from safetensors.torch import load_file, save_file
from torch.testing._internal.common_utils import (
TestCase,
run_tests,
)

from torchao import quantize_
from torchao.prototype.safetensors.safetensors_support import (
load_tensor_state_dict,
save_tensor_state_dict,
flatten_tensor_state_dict,
unflatten_tensor_state_dict,
)
from torchao.quantization.granularity import PerRow
from torchao.quantization.quant_api import Float8DynamicActivationFloat8WeightConfig
Expand All @@ -19,6 +21,18 @@
)


def load_data(file_path: str, device: str):
    """Load tensors and user metadata from a safetensors file.

    The tensors are read with ``load_file``. The metadata is recovered by
    parsing the file header by hand: the first 8 bytes are unpacked as a
    little-endian unsigned 64-bit integer giving the header length, and that
    many following bytes are decoded as JSON.

    Args:
        file_path: Path of the safetensors file to read.
        device: Device argument forwarded to ``load_file``.

    Returns:
        A ``(tensors, metadata)`` tuple, where ``metadata`` is the header's
        ``"__metadata__"`` entry, or an empty dict when the header has none.
    """
    import struct

    tensors = load_file(file_path, device)

    with open(file_path, "rb") as fh:
        # The length prefix tells us exactly how many bytes of JSON follow.
        (json_length,) = struct.unpack("<Q", fh.read(8))
        parsed_header = json.loads(fh.read(json_length))

    return tensors, parsed_header.get("__metadata__", {})


@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@unittest.skipIf(not is_sm_at_least_89(), "Need sm89+")
class TestSafeTensors(TestCase):
Expand All @@ -32,8 +46,16 @@ def test_safetensors(self):
ref_output = model(*example_inputs)

with tempfile.NamedTemporaryFile() as f:
save_tensor_state_dict(model.state_dict(), f.name)
reconstructed_dict = load_tensor_state_dict(f.name, device="cuda")
tensors_data_dict, metadata_dict = flatten_tensor_state_dict(
model.state_dict()
)
save_file(tensors_data_dict, f.name, metadata=metadata_dict)
tensors_data_dict, metadata_dict = load_data(
file_path=f.name, device="cuda"
)
reconstructed_dict = unflatten_tensor_state_dict(
tensors_data_dict, metadata_dict
)

model = torch.nn.Sequential(
torch.nn.Linear(32, 256, dtype=torch.bfloat16, device="cuda")
Expand Down
172 changes: 97 additions & 75 deletions torchao/prototype/safetensors/safetensors_support.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
import json
import logging
from typing import Dict
from typing import Any, Dict

import torch
from safetensors.torch import load_file, save_file

from torchao.prototype.safetensors.safetensors_serialization import (
Float8TensorAttributeJSONEncoder,
Expand All @@ -14,55 +13,67 @@
logger: logging.Logger = logging.getLogger(__name__)


def load_tensor_state_dict(file_path: str, device: str):
def unflatten_tensor_state_dict(
tensors_data_dict: Dict[str, Any],
metadata_dict: Dict[str, Any],
):
"""
Load a dictionary of tensor subclasses from a safetensors file.

For torch.Tensors, we load:
- _data: the tensor data
- _type: the tensor type

For Float8Tensor, we load:
- tensor_data: qdata and scale
- tensor_attributes:
- block_size
- mm_config
- hp_value_lb
- hp_value_ub
- act_quant_kwargs
- kernel_preference
- dtype
Reconstructs tensor subclass state dict from provided torch.Tensor data and metadata
This function is used after loading in previously saved model state dict (using safetensors.save_file) to reconstruct tensor subclass structure

For example, given a previously flattened tensors_data_dict and metadata_dict:
tensors_data_dict = {
'0.weight:qdata': torch.Tensor(...),
'0.weight:scale': torch.Tensor(...),
'0.bias:_data': torch.Tensor(...),
}
metadata_dict = {
'0.weight': {
'_type': 'Float8Tensor',
'_data': {
'block_size': [1,32],
...
}
}
'0.bias': {
'_type': 'torch.Tensor',
}
'tensor_names': ['0.weight', '0.bias']
}

We recover the structure of the original state dict:
tensor_dict = {
'0.weight': Float8Tensor(
qdata=torch.Tensor(...),
scale=torch.Tensor(...),
block_size=[1,32],
...),
'0.bias': torch.Tensor(...),
}

Args:
file_path: Path to the safetensors file
tensors_data_dict: a dictionary from "tensor_name:tensor_data_attribute_name" to flattened torch.Tensor data for tensor subclass instance
metadata_dict: a dictionary from "tensor_name" to another dictionary that contains type and attributes for tensor subclass instance

Returns:
Dictionary of reconstructed tensor subclasses
"""
loaded_tensors = load_file(file_path, device)

with open(file_path, "rb") as f:
import struct

header_size = struct.unpack("<Q", f.read(8))[0]
header_bytes = f.read(header_size)
header = json.loads(header_bytes)
metadata = header.get("__metadata__", {})
combined_data = {**tensors_data_dict, **metadata_dict}

if "tensor_names" not in metadata:
if "tensor_names" not in metadata_dict:
raise ValueError("No tensors found")

tensor_names = json.loads(metadata["tensor_names"])
tensor_names = json.loads(metadata_dict["tensor_names"])
result = {}

for tensor_name in tensor_names:
tensor_tensors = {}
for key, value in loaded_tensors.items():
for key, value in combined_data.items():
if key.startswith(f"{tensor_name}:"):
# Remove the prefix
tensor_tensors[key[len(tensor_name) + 1 :]] = value

tensor_metadata = json.loads(metadata.get(tensor_name))
tensor_metadata = json.loads(metadata_dict.get(tensor_name))
tensor_type = tensor_metadata.get("_type")

if tensor_type == Float8Tensor.__name__:
Expand All @@ -73,54 +84,69 @@ def load_tensor_state_dict(file_path: str, device: str):
else:
raise ValueError(f"Unsupported tensor type: {tensor_type}")

logger.info(
f"Loaded {len(tensor_names)} tensor subclasses from {file_path} with metadata"
)
return result


def save_tensor_state_dict(
tensor_dict: Dict[str, Dict[str, torch.Tensor]],
file_path: str,
def flatten_tensor_state_dict(
tensors_dict: Dict[str, Dict[str, torch.Tensor]],
):
"""
Save a dictionary of tensor subclasses with appropriate metadata.

For torch.Tensors, we save:
- _data: the tensor data
- _type: the tensor type

For Float8Tensor, we save:
- tensor_data:
- qdata
- scale
- tensor_attributes:
- block_size
- mm_config
- hp_value_lb
- hp_value_ub
- act_quant_kwargs
- kernel_preference
- dtype
Flattens a dictionary of tensor subclasses so that it is compatible with safetensors.save_file
We deconstruct the tensor subclass structure into torch.Tensor data and metadata

For example, given something like:
tensor_dict = {
'0.weight': Float8Tensor(
qdata=torch.Tensor(...),
scale=torch.Tensor(...),
block_size=[1,32],
...),
'0.bias': torch.Tensor(...),
}

We flatten this to:
tensors_data = {
'0.weight:qdata': torch.Tensor(...),
'0.weight:scale': torch.Tensor(...),
'0.bias:_data': torch.Tensor(...),
}
metadata = {
'0.weight': {
'_type': 'Float8Tensor',
'_data': {
'block_size': [1,32],
...
}
}
'0.bias': {
'_type': 'torch.Tensor',
}
'tensor_names': ['0.weight', '0.bias']
}

Args:
tensors_dict: Dictionary of tensor subclasses to flatten, with keys as tensor names
file_path: Path where to save the tensors

Returns:
A tuple of (tensors_data, metadata) where
tensors_data: Dict[str, torch.Tensor] contains the tensor data
metadata: Dict[str, str] contains accompanying metadata from tensor subclass
This structure is compatible with safetensors.save_file
"""

combined_metadata = {}
combined_tensors_dict = {}
metadata_dict = {}
tensors_data_dict = {}

for tensor_name, tensor in tensor_dict.items():
for tensor_name, tensor in tensors_dict.items():
if isinstance(tensor, Float8Tensor):
tensors_dict = {}
tensor_dict = {}
for tensor_data_name in tensor.tensor_data_names:
tensors_dict[tensor_data_name] = getattr(tensor, tensor_data_name)
tensor_dict[tensor_data_name] = getattr(tensor, tensor_data_name)

metadata = json.dumps(tensor, cls=Float8TensorAttributeJSONEncoder)
tensor_metadata = json.dumps(tensor, cls=Float8TensorAttributeJSONEncoder)
elif type(tensor) is torch.Tensor:
tensors_dict = {"_data": tensor}
metadata = json.dumps({"_type": torch.Tensor.__name__})
tensor_dict = {"_data": tensor}
tensor_metadata = json.dumps({"_type": torch.Tensor.__name__})
else:
raise ValueError(f"Unsupported tensor type: {type(tensor)}")

Expand All @@ -129,15 +155,11 @@ def save_tensor_state_dict(
f"{tensor_name}:{key}": (
value.detach().clone() if isinstance(value, torch.Tensor) else value
)
for key, value in tensors_dict.items()
for key, value in tensor_dict.items()
}

combined_metadata[tensor_name] = metadata
combined_tensors_dict.update(prefixed_tensors_dict)

combined_metadata["tensor_names"] = json.dumps(list(tensor_dict.keys()))
metadata_dict[tensor_name] = tensor_metadata
tensors_data_dict.update(prefixed_tensors_dict)

save_file(combined_tensors_dict, file_path, metadata=combined_metadata)
logger.info(
f"Saved {len(tensor_dict)} tensor subclasses to {file_path} with metadata"
)
metadata_dict["tensor_names"] = json.dumps(list(tensors_dict.keys()))
return tensors_data_dict, metadata_dict
Loading