diff --git a/hf_torchao_vllm/convert_torchao_checkpoint_to_compressed_tensors.py b/hf_torchao_vllm/convert_torchao_checkpoint_to_compressed_tensors.py
index a8b964d..cac7a41 100644
--- a/hf_torchao_vllm/convert_torchao_checkpoint_to_compressed_tensors.py
+++ b/hf_torchao_vllm/convert_torchao_checkpoint_to_compressed_tensors.py
@@ -1,6 +1,7 @@
 import copy
 import filecmp
 import json
+import os
 import pathlib
 import shutil
 import subprocess
@@ -16,6 +17,8 @@
 from safetensors import safe_open
 from safetensors.torch import save_file
 
+from utils import convert_pt_statedict_to_safetensors, convert_pt_multifile_index_to_safetensors
+
 def ao_config_to_compressed_tensors_config(aobaseconfig: AOBaseConfig) -> Dict[str, Any]:
     # for now, allowlist of recipes we know how to convert and hand convert
     # them here
@@ -55,6 +58,10 @@ def run(
     dir_validation: str = 'data/llmcompressor/fp8-opt-125m',
     skip_conversion: bool = False,
 ):
+    dir_source = dir_source.rstrip('/')
+    dir_target = dir_target.rstrip('/')
+    dir_validation = dir_validation.rstrip('/')
+
     config_name_source = f"{dir_source}/config.json"
     config_name_target = f"{dir_target}/config.json"
     config_name_validation = f"{dir_validation}/config.json"
@@ -62,13 +69,19 @@ def run(
     weights_name_target = f"{dir_target}/model.safetensors"
     weights_name_validation = f"{dir_validation}/model.safetensors"
 
+    # create new dir if not yet exists
+    os.makedirs(dir_target, exist_ok=True)
+
     if not skip_conversion:
+        source_converted_filenames = set()
+
         #
         # convert config.json
         #
         with open(config_name_source, 'r') as f:
             config_source = json.load(f)
+        print(json.dumps(config_source, indent=2))
 
         # get torchao config format
         # example: https://www.internalfb.com/phabricator/paste/view/P1975688376
@@ -78,6 +91,11 @@
         fqn_to_serialized_aobaseconfig = old_hf_quantization_config["quant_type"]
         assert len(fqn_to_serialized_aobaseconfig) == 1, "unsupported"
 
+        if fqn_to_serialized_aobaseconfig['default']['_type'] == 'ModuleFqnToConfig':
+            fqn_to_serialized_aobaseconfig = \
+                fqn_to_serialized_aobaseconfig['default']['_data']['module_fqn_to_config']
+
+
         new_hf_quantization_config = {
             "config_groups": {},
             "format": "float-quantized",
@@ -90,13 +108,14 @@
         }
 
         for fqn, serialized_aobaseconfig in fqn_to_serialized_aobaseconfig.items():
-            print(fqn, serialized_aobaseconfig)
+            if serialized_aobaseconfig is None:
+                new_hf_quantization_config['ignore'].append(fqn)
+                continue
+
             aobaseconfig = config_from_dict(serialized_aobaseconfig)
-            print(aobaseconfig)
             ct_config = ao_config_to_compressed_tensors_config(aobaseconfig)
-            print(json.dumps(ct_config, indent=2))
-            assert fqn == "default", "unsupported"
+            assert fqn in ("default", "_default"), "unsupported"
             new_hf_quantization_config["config_groups"]["group_0"] = ct_config
 
         # for now, modify config_source inplace
@@ -106,6 +125,8 @@
         with open(config_name_target, 'w') as f:
             json.dump(config_source, f, indent=2)
 
+        source_converted_filenames.add(config_name_source)
+
         #
         # convert the checkpoint
         #
@@ -113,39 +134,49 @@
         # not sure why I still need this
         torch.serialization.add_safe_globals([getattr])
 
-        old_state_dict = torch.load(weights_name_source, weights_only=True)
-        new_state_dict = {}
-
-        for k, v in old_state_dict.items():
-            print(k, v.shape, type(v))
-            if type(v) == torch.Tensor:
-
-                if "lm_head" in k:
-                    # work around issues detailed in
-                    # https://huggingface.co/docs/safetensors/torch_shared_tensors
-                    v = copy.deepcopy(v)
-
-                new_state_dict[k] = v
-            elif type(v) == Float8Tensor:
-                new_state_dict[k] = v.qdata
-                # for now, manually cast scale to bfloat16 to match current
-                # llm-compressor script
-                # TODO(future): prob needs to be user controllable
-                new_state_dict[k + '_scale'] = v.scale.bfloat16()
-            else:
-                raise AssertionError(f'unsupported type {type(v)}')
-        save_file(new_state_dict, weights_name_target)
+        is_single_chunk = os.path.isfile(f'{dir_source}/pytorch_model.bin')
+        if is_single_chunk:
+            convert_pt_statedict_to_safetensors(weights_name_source, weights_name_target)
+            source_converted_filenames.add(weights_name_source)
+        else:
+            # convert each model state_dict file
+            model_part_filenames = []
+            for file_path in pathlib.Path(dir_source).iterdir():
+                if not file_path.is_file():
+                    continue
+                if not (('pytorch_model') in str(file_path) and str(file_path).endswith('bin')):
+                    continue
+                pt_sd_filename = str(file_path)
+                # dir_source/pytorch_model-00001-of-00004.bin -> dir_target/model-00001-of-00004.safetensors
+                safetensors_sd_filename = pt_sd_filename.replace(dir_source, dir_target)
+                safetensors_sd_filename = safetensors_sd_filename.replace('pytorch_model', 'model')
+                safetensors_sd_filename = safetensors_sd_filename.replace('.bin', '.safetensors')
+                model_part_filenames.append(safetensors_sd_filename)
+                print(pt_sd_filename, safetensors_sd_filename)
+                convert_pt_statedict_to_safetensors(pt_sd_filename, safetensors_sd_filename)
+                source_converted_filenames.add(pt_sd_filename)
+
+            # convert pytorch_model.bin.index.json
+            convert_pt_multifile_index_to_safetensors(
+                f'{dir_source}/pytorch_model.bin.index.json',
+                f'{dir_target}/model.safetensors.index.json',
+                model_part_filenames,
+            )
+            source_converted_filenames.add(f'{dir_source}/pytorch_model.bin.index.json')
+
+        print(source_converted_filenames)
 
         # move all the other files over
         for dir_and_file_path in pathlib.Path(dir_source).iterdir():
             if not dir_and_file_path.is_file():
                 continue
-            file_path = dir_and_file_path.parts[-1]
-            if file_path in ('config.json', 'pytorch_model.bin'):
+            if str(dir_and_file_path) in source_converted_filenames:
                 # these are converted in custom logic elsewhere in this script
                 continue
             # if we got here, we just need to copy the file over without any changes
+            file_path = dir_and_file_path.parts[-1]
             target_file_path = f"{dir_target}/{str(file_path)}"
+            print(f'copying {dir_and_file_path} to {target_file_path}')
             shutil.copyfile(dir_and_file_path, target_file_path)
 
         # validate target_dir vs validation_dir
@@ -165,9 +196,11 @@
             # this will always fail, for now, as we are not perfectly matching
             print(e.stderr)
 
+            # TODO(future, as needed): also validate the other files, they are unlikely to match
+            # exactly for any model with >1 chunk of state dict files since we are not
+            # trying to enforce that the same tensors live in the same chunks.
+
         elif file_path_target == 'model.safetensors':
-            # TODO implement me
-            pass
 
             with safe_open(dir_and_file_path, framework='pt') as f_target:
                 with safe_open(dir_and_file_path_validation, framework='pt') as f_validation:
diff --git a/hf_torchao_vllm/utils.py b/hf_torchao_vllm/utils.py
index 50a7210..365b26f 100644
--- a/hf_torchao_vllm/utils.py
+++ b/hf_torchao_vllm/utils.py
@@ -1,10 +1,15 @@
+import copy
 import json
 import os
+from typing import List
 import pathlib
 
 import safetensors
+from safetensors.torch import save_file
 import torch
 
+from torchao.quantization.quantize_.workflows.float8.float8_tensor import Float8Tensor
+
 torch.serialization.add_safe_globals([getattr])
@@ -57,3 +62,86 @@ def inspect_model_state_dict(dir_name, model_name, model_extension) -> None:
             continue
         print(file_path)
         _inspect_state_dict_file(file_path)
+
+def convert_pt_statedict_to_safetensors(
+    pt_statedict_filename,
+    safetensors_statedict_filename,
+) -> None:
+    old_state_dict = torch.load(pt_statedict_filename, weights_only=True)
+    new_state_dict = {}
+
+    for k, v in old_state_dict.items():
+        print(k, v.shape, type(v))
+        if type(v) == torch.Tensor:
+
+            if "lm_head" in k:
+                # work around issues detailed in
+                # https://huggingface.co/docs/safetensors/torch_shared_tensors
+                v = copy.deepcopy(v)
+
+            new_state_dict[k] = v
+        elif type(v) == Float8Tensor:
+            new_state_dict[k] = v.qdata
+            # for now, manually cast scale to bfloat16 to match current
+            # llm-compressor script
+            # TODO(future): prob needs to be user controllable
+            new_state_dict[k + '_scale'] = v.scale.bfloat16()
+        else:
+            raise AssertionError(f'unsupported type {type(v)}')
+    save_file(new_state_dict, safetensors_statedict_filename)
+
+def convert_pt_multifile_index_to_safetensors(
+    source_filename: str,
+    target_filename: str,
+    model_part_filenames: List[str],
+) -> None:
+    """
+    Source format
+
+    {
+        "metadata": {...},
+        "weight_map": {
+            "foo": "pytorch_model-00001-of-00004.bin",
+            "bar": "pytorch_model-00002-of-00004.bin",
+            ...
+        }
+    }
+
+    Target format
+
+    {
+        "metadata": {...},
+        "weight_map": {
+            # weight already in high precision
+            "foo": "model-00001-of-00004.safetensors",
+            # weight originally stored as a tensor subclass, but now decomposed
+            # into qdata and scale
+            "bar": "model-00002-of-00004.safetensors",
+            "bar_scale": "model-00002-of-00004.safetensors",
+            ...
+        }
+    }
+
+    For now, metadata is not updated.
+    """
+
+    # generate the new fqn to weight location map from the new safetensors files
+    new_weight_map = {}
+    for model_part_filename in model_part_filenames:
+        # print(model_part_filename)
+
+        # get the file_name from dir_name/file_name
+        basename = os.path.basename(model_part_filename)
+        # print(basename)
+
+        with safetensors.safe_open(model_part_filename, framework='pt', device='cpu') as f:
+            for k in f.keys():
+                new_weight_map[k] = basename
+
+    # save the updated mapping
+    with open(source_filename, 'r') as f:
+        source_mapping = json.load(f)
+    source_mapping['weight_map'] = new_weight_map
+    # print(json.dumps(source_mapping, indent=2))
+    with open(target_filename, 'w') as f:
+        json.dump(source_mapping, f, indent=2)
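
Usage sketch (not part of the diff above): the multi-chunk branch added to convert_torchao_checkpoint_to_compressed_tensors.py boils down to calling the two new helpers from hf_torchao_vllm/utils.py once per shard and then once for the index. The snippet below shows that flow in isolation; the directory names are hypothetical examples, only the helper names and signatures come from the diff, and it assumes it is run from the hf_torchao_vllm directory so that `from utils import ...` resolves.

import os
import pathlib

from utils import (
    convert_pt_statedict_to_safetensors,
    convert_pt_multifile_index_to_safetensors,
)

dir_source = 'data/torchao/fp8-opt-125m'            # hypothetical sharded torchao checkpoint
dir_target = 'data/torchao-converted/fp8-opt-125m'  # hypothetical output location
os.makedirs(dir_target, exist_ok=True)

# convert each pytorch_model-*.bin shard to a model-*.safetensors shard,
# decomposing any Float8Tensor into qdata + scale along the way
model_part_filenames = []
for file_path in sorted(pathlib.Path(dir_source).iterdir()):
    name = str(file_path)
    if not (file_path.is_file() and 'pytorch_model' in name and name.endswith('.bin')):
        continue
    out_name = (
        name.replace(dir_source, dir_target)
        .replace('pytorch_model', 'model')
        .replace('.bin', '.safetensors')
    )
    convert_pt_statedict_to_safetensors(name, out_name)
    model_part_filenames.append(out_name)

# rebuild the weight_map so every key (including the new *_scale keys)
# points at the safetensors shard that now holds it
convert_pt_multifile_index_to_safetensors(
    f'{dir_source}/pytorch_model.bin.index.json',
    f'{dir_target}/model.safetensors.index.json',
    model_part_filenames,
)

For a single-file checkpoint the converted script skips the loop and the index rewrite entirely and converts dir_source/pytorch_model.bin directly via convert_pt_statedict_to_safetensors.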