93 changes: 63 additions & 30 deletions hf_torchao_vllm/convert_torchao_checkpoint_to_compressed_tensors.py
@@ -1,6 +1,7 @@
import copy
import filecmp
import json
import os
import pathlib
import shutil
import subprocess
@@ -16,6 +17,8 @@
from safetensors import safe_open
from safetensors.torch import save_file

from utils import convert_pt_statedict_to_safetensors, convert_pt_multifile_index_to_safetensors

def ao_config_to_compressed_tensors_config(aobaseconfig: AOBaseConfig) -> Dict[str, Any]:
# for now, allowlist of recipes we know how to convert and hand convert
# them here
@@ -55,20 +58,30 @@ def run(
dir_validation: str = 'data/llmcompressor/fp8-opt-125m',
skip_conversion: bool = False,
):
dir_source = dir_source.rstrip('/')
dir_target = dir_target.rstrip('/')
dir_validation = dir_validation.rstrip('/')

config_name_source = f"{dir_source}/config.json"
config_name_target = f"{dir_target}/config.json"
config_name_validation = f"{dir_validation}/config.json"
weights_name_source = f"{dir_source}/pytorch_model.bin"
weights_name_target = f"{dir_target}/model.safetensors"
weights_name_validation = f"{dir_validation}/model.safetensors"

# create new dir if not yet exists
os.makedirs(dir_target, exist_ok=True)

if not skip_conversion:
source_converted_filenames = set()

#
# convert config.json
#

with open(config_name_source, 'r') as f:
config_source = json.load(f)
print(json.dumps(config_source, indent=2))

# get torchao config format
# example: https://www.internalfb.com/phabricator/paste/view/P1975688376
@@ -78,6 +91,11 @@ def run(
fqn_to_serialized_aobaseconfig = old_hf_quantization_config["quant_type"]
assert len(fqn_to_serialized_aobaseconfig) == 1, "unsupported"

if fqn_to_serialized_aobaseconfig['default']['_type'] == 'ModuleFqnToConfig':
fqn_to_serialized_aobaseconfig = \
fqn_to_serialized_aobaseconfig['default']['_data']['module_fqn_to_config']


new_hf_quantization_config = {
"config_groups": {},
"format": "float-quantized",
@@ -90,13 +108,14 @@
}

for fqn, serialized_aobaseconfig in fqn_to_serialized_aobaseconfig.items():
print(fqn, serialized_aobaseconfig)
if serialized_aobaseconfig is None:
new_hf_quantization_config['ignore'].append(fqn)
continue

aobaseconfig = config_from_dict(serialized_aobaseconfig)
print(aobaseconfig)
ct_config = ao_config_to_compressed_tensors_config(aobaseconfig)
print(json.dumps(ct_config, indent=2))

assert fqn == "default", "unsupported"
assert fqn in ("default", "_default"), "unsupported"
new_hf_quantization_config["config_groups"]["group_0"] = ct_config

# for now, modify config_source inplace
@@ -106,46 +125,58 @@ def run(
with open(config_name_target, 'w') as f:
json.dump(config_source, f, indent=2)

source_converted_filenames.add(config_name_source)

#
# convert the checkpoint
#

# not sure why I still need this
torch.serialization.add_safe_globals([getattr])

old_state_dict = torch.load(weights_name_source, weights_only=True)
new_state_dict = {}

for k, v in old_state_dict.items():
print(k, v.shape, type(v))
if type(v) == torch.Tensor:

if "lm_head" in k:
# work around issues detailed in
# https://huggingface.co/docs/safetensors/torch_shared_tensors
v = copy.deepcopy(v)

new_state_dict[k] = v
elif type(v) == Float8Tensor:
new_state_dict[k] = v.qdata
# for now, manually cast scale to bfloat16 to match current
# llm-compressor script
# TODO(future): prob needs to be user controllable
new_state_dict[k + '_scale'] = v.scale.bfloat16()
else:
raise AssertionError(f'unsupported type {type(v)}')
save_file(new_state_dict, weights_name_target)
is_single_chunk = os.path.isfile(f'{dir_source}/pytorch_model.bin')
if is_single_chunk:
convert_pt_statedict_to_safetensors(weights_name_source, weights_name_target)
source_converted_filenames.add(weights_name_source)
else:
# convert each model state_dict file
model_part_filenames = []
for file_path in pathlib.Path(dir_source).iterdir():
if not file_path.is_file():
continue
if not (('pytorch_model') in str(file_path) and str(file_path).endswith('bin')):
continue
pt_sd_filename = str(file_path)
# dir_source/pytorch_model-00001-of-00004.bin -> dir_target/model-00001-of-00004.safetensors
safetensors_sd_filename = pt_sd_filename.replace(dir_source, dir_target)
safetensors_sd_filename = safetensors_sd_filename.replace('pytorch_model', 'model')
safetensors_sd_filename = safetensors_sd_filename.replace('.bin', '.safetensors')
model_part_filenames.append(safetensors_sd_filename)
print(pt_sd_filename, safetensors_sd_filename)
convert_pt_statedict_to_safetensors(pt_sd_filename, safetensors_sd_filename)
source_converted_filenames.add(pt_sd_filename)

# convert pytorch_model.bin.index.json
convert_pt_multifile_index_to_safetensors(
f'{dir_source}/pytorch_model.bin.index.json',
f'{dir_target}/model.safetensors.index.json',
model_part_filenames,
)
source_converted_filenames.add(f'{dir_source}/pytorch_model.bin.index.json')

print(source_converted_filenames)

# move all the other files over
for dir_and_file_path in pathlib.Path(dir_source).iterdir():
if not dir_and_file_path.is_file():
continue
file_path = dir_and_file_path.parts[-1]
if file_path in ('config.json', 'pytorch_model.bin'):
if str(dir_and_file_path) in source_converted_filenames:
# these are converted in custom logic elsewhere in this script
continue
# if we got here, we just need to copy the file over without any changes
file_path = dir_and_file_path.parts[-1]
target_file_path = f"{dir_target}/{str(file_path)}"
print(f'copying {dir_and_file_path} to {target_file_path}')
shutil.copyfile(dir_and_file_path, target_file_path)

# validate target_dir vs validation_dir
@@ -165,9 +196,11 @@ def run(
# this will always fail, for now, as we are not perfectly matching
print(e.stderr)

# TODO(future, as needed): also validate the other files, they are unlikely to match
# exactly for any model with >1 chunk of state dict files since we are not
# trying to enforce that the same tensors live in the same chunks.

elif file_path_target == 'model.safetensors':
# TODO implement me
pass

with safe_open(dir_and_file_path, framework='pt') as f_target:
with safe_open(dir_and_file_path_validation, framework='pt') as f_validation:
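For reference, a minimal usage sketch of the conversion entry point. The parameter names come from the diff above; the directory values and the idea of calling run() directly from Python are assumptions (the defaults for dir_source and dir_target are not visible in this diff):

from convert_torchao_checkpoint_to_compressed_tensors import run

# hypothetical paths: dir_source holds a torchao checkpoint saved as pytorch_model*.bin,
# dir_target receives the compressed-tensors style config.json plus model*.safetensors,
# dir_validation points at a reference llm-compressor checkpoint to diff against
run(
    dir_source='data/torchao/fp8-opt-125m',
    dir_target='data/torchao-converted/fp8-opt-125m',
    dir_validation='data/llmcompressor/fp8-opt-125m',
    skip_conversion=False,
)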
88 changes: 88 additions & 0 deletions hf_torchao_vllm/utils.py
@@ -1,10 +1,15 @@
import copy
import json
import os
from typing import List
import pathlib

import safetensors
from safetensors.torch import save_file

import torch
from torchao.quantization.quantize_.workflows.float8.float8_tensor import Float8Tensor


torch.serialization.add_safe_globals([getattr])

@@ -57,3 +62,86 @@ def inspect_model_state_dict(dir_name, model_name, model_extension) -> None:
continue
print(file_path)
_inspect_state_dict_file(file_path)

def convert_pt_statedict_to_safetensors(
pt_statedict_filename,
safetensors_statedict_filename,
) -> None:
old_state_dict = torch.load(pt_statedict_filename, weights_only=True)
new_state_dict = {}

for k, v in old_state_dict.items():
print(k, v.shape, type(v))
if type(v) == torch.Tensor:

if "lm_head" in k:
# work around issues detailed in
# https://huggingface.co/docs/safetensors/torch_shared_tensors
v = copy.deepcopy(v)

new_state_dict[k] = v
elif type(v) == Float8Tensor:
new_state_dict[k] = v.qdata
# for now, manually cast scale to bfloat16 to match current
# llm-compressor script
# TODO(future): prob needs to be user controllable
new_state_dict[k + '_scale'] = v.scale.bfloat16()
else:
raise AssertionError(f'unsupported type {type(v)}')
save_file(new_state_dict, safetensors_statedict_filename)

def convert_pt_multifile_index_to_safetensors(
source_filename: str,
target_filename: str,
model_part_filenames: List[str],
) -> None:
"""
Source format

{
"metadata": {...},
"weight_map": {
"foo": "pytorch_model-00001-of-00004.bin",
"bar": "pytorch_model-00002-of-00004.bin",
...
}
}

Target format

{
"metadata": {...},
"weight_map": {
# weight already in high precision, now pointing at the renamed safetensors file
"foo": "model-00001-of-00004.safetensors",
# weight originally stored as a tensor subclass, but now decomposed
# into qdata and scale
"bar": "model-00002-of-00004.safetensors",
"bar_scale": "model-00002-of-00004.safetensors",
...
}
}

For now, metadata is not updated.
"""

# generate the new fqn to weight location map from the new safetensors files
new_weight_map = {}
for model_part_filename in model_part_filenames:
# print(model_part_filename)

# get the file_name from dir_name/file_name
basename = os.path.basename(model_part_filename)
# print(basename)

with safetensors.safe_open(model_part_filename, framework='pt', device='cpu') as f:
for k in f.keys():
new_weight_map[k] = basename

# save the updated mapping
with open(source_filename, 'r') as f:
source_mapping = json.load(f)
source_mapping['weight_map'] = new_weight_map
# print(json.dumps(source_mapping, indent=2))
with open(target_filename, 'w') as f:
json.dump(source_mapping, f, indent=2)
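For completeness, a minimal sketch of driving the two new helpers directly on a sharded checkpoint. The function signatures are taken from the diff above; the directories, shard names, and shard count are hypothetical, and the source .bin shards are assumed to already exist:

from utils import (
    convert_pt_statedict_to_safetensors,
    convert_pt_multifile_index_to_safetensors,
)

src = 'data/torchao/fp8-opt-125m'            # hypothetical torchao checkpoint dir
dst = 'data/torchao-converted/fp8-opt-125m'  # hypothetical output dir (must exist)

part_filenames = []
for i in (1, 2):  # assuming two shards
    pt_name = f'{src}/pytorch_model-0000{i}-of-00002.bin'
    st_name = f'{dst}/model-0000{i}-of-00002.safetensors'
    convert_pt_statedict_to_safetensors(pt_name, st_name)
    part_filenames.append(st_name)

# rebuild the index so every key (including the new *_scale entries created by
# the Float8Tensor decomposition) points at the safetensors shard that holds it
convert_pt_multifile_index_to_safetensors(
    f'{src}/pytorch_model.bin.index.json',
    f'{dst}/model.safetensors.index.json',
    part_filenames,
)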