From 8708caf64c2fbaee8564acee512cbee0c705f731 Mon Sep 17 00:00:00 2001 From: Angel Li Date: Wed, 3 Sep 2025 16:49:19 -0700 Subject: [PATCH] refactoring functions for huggingface integration --- .../safetensors/test_safetensors_support.py | 30 ++- .../safetensors/safetensors_support.py | 172 ++++++++++-------- 2 files changed, 123 insertions(+), 79 deletions(-) diff --git a/test/prototype/safetensors/test_safetensors_support.py b/test/prototype/safetensors/test_safetensors_support.py index d21e2997e6..b755640fe0 100644 --- a/test/prototype/safetensors/test_safetensors_support.py +++ b/test/prototype/safetensors/test_safetensors_support.py @@ -1,7 +1,9 @@ +import json import tempfile import unittest import torch +from safetensors.torch import load_file, save_file from torch.testing._internal.common_utils import ( TestCase, run_tests, @@ -9,8 +11,8 @@ from torchao import quantize_ from torchao.prototype.safetensors.safetensors_support import ( - load_tensor_state_dict, - save_tensor_state_dict, + flatten_tensor_state_dict, + unflatten_tensor_state_dict, ) from torchao.quantization.granularity import PerRow from torchao.quantization.quant_api import Float8DynamicActivationFloat8WeightConfig @@ -19,6 +21,18 @@ ) +def load_data(file_path: str, device: str): + loaded_tensors = load_file(file_path, device) + with open(file_path, "rb") as f: + import struct + + header_size = struct.unpack("