From 5c7cd995388ec05a49d599771e444b9d42e0924a Mon Sep 17 00:00:00 2001
From: vasiliy
Date: Mon, 6 Oct 2025 04:30:21 -0700
Subject: [PATCH] extend llmcompressor script and inspection scripts to handle
 Qwen-1.5 MoE

Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
---
 .../inspect_llm_compressor_output.py          | 20 ++++---
 hf_torchao_vllm/inspect_torchao_output.py     | 21 +++++--
 .../quantize_hf_model_with_llm_compressor.py  | 26 +++++---
 hf_torchao_vllm/utils.py                      | 59 +++++++++++++++++++
 4 files changed, 105 insertions(+), 21 deletions(-)
 create mode 100644 hf_torchao_vllm/utils.py

diff --git a/hf_torchao_vllm/inspect_llm_compressor_output.py b/hf_torchao_vllm/inspect_llm_compressor_output.py
index a991a1e..9416593 100644
--- a/hf_torchao_vllm/inspect_llm_compressor_output.py
+++ b/hf_torchao_vllm/inspect_llm_compressor_output.py
@@ -5,6 +5,8 @@ import json
 
 import fire
 
+from utils import inspect_model_state_dict
+
 def run(
     dir_name: str = 'data/llmcompressor/fp8-opt-125m',
 ):
@@ -14,13 +16,17 @@ def run(
     # TODO: pretty print
     print(json.dumps(data, indent=2))
 
-    # inpect the model, saved in safetensors format
-    model_name = f'{dir_name}/model.safetensors'
-    with safetensors.safe_open(model_name, framework='pt', device='cpu') as f:
-        print(f.metadata())
-        for k in f.keys():
-            t = f.get_tensor(k)
-            print(k, t.shape, t.dtype)
+    model_name, model_extension = 'model', 'safetensors'
+    inspect_model_state_dict(dir_name, model_name, model_extension)
+
+    if False:
+        # inspect the model, saved in safetensors format
+        model_name = f'{dir_name}/model.safetensors'
+        with safetensors.safe_open(model_name, framework='pt', device='cpu') as f:
+            print(f.metadata())
+            for k in f.keys():
+                t = f.get_tensor(k)
+                print(k, t.shape, t.dtype)
 
 if __name__ == '__main__':
     fire.Fire(run)
diff --git a/hf_torchao_vllm/inspect_torchao_output.py b/hf_torchao_vllm/inspect_torchao_output.py
index edd4369..16d6f05 100644
--- a/hf_torchao_vllm/inspect_torchao_output.py
+++ b/hf_torchao_vllm/inspect_torchao_output.py
@@ -2,10 +2,14 @@
 # via the `torchao_hf_script.py` script
 
 import json
+import os
+import pathlib
 import torch
 import torchao  # this is needed to run torch.serialization.add_safe_globals([torchao.quantization.Float8Tensor])
 import fire
 
+from utils import inspect_model_state_dict
+
 # not sure why I still need this
 torch.serialization.add_safe_globals([getattr])
 
@@ -15,14 +19,21 @@ def run(dir_name: str = 'data/torchao/fp8-opt-125m'):
     # inspect the config
     with open(json_config_name, 'r') as f:
         data = json.load(f)
-    # TODO: pretty print
     print(json.dumps(data, indent=2))
 
     # inspect the data
-    model_name = f'{dir_name}/pytorch_model.bin'
-    state_dict = torch.load(model_name, weights_only=True)
-    for k, v in state_dict.items():
-        print(k, v.shape, type(v))
+    #
+    # if there is a single chunk, the state dict is named `pytorch_model.bin`
+    #
+    # if there are multiple chunks, the state dict is spread across multiple files:
+    #
+    #   pytorch_model-00001-of-00004.bin
+    #   ...
+    #   pytorch_model-00004-of-00004.bin
+    #   pytorch_model.bin.index.json
+    #
+    model_name, model_extension = 'pytorch_model', 'bin'
+    inspect_model_state_dict(dir_name, model_name, model_extension)
 
 if __name__ == '__main__':
     fire.Fire(run)
diff --git a/hf_torchao_vllm/quantize_hf_model_with_llm_compressor.py b/hf_torchao_vllm/quantize_hf_model_with_llm_compressor.py
index 4e67ce4..fe5ff3d 100644
--- a/hf_torchao_vllm/quantize_hf_model_with_llm_compressor.py
+++ b/hf_torchao_vllm/quantize_hf_model_with_llm_compressor.py
@@ -10,21 +10,29 @@
 
 import fire
 
-def run():
-
-    # MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct"
-    MODEL_ID = "facebook/opt-125m"
-
+def run(model_name: str = 'facebook/opt-125m'):
     # Load model.
-    model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype=torch.bfloat16)
-    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
+    model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)
+    print(model)
+    tokenizer = AutoTokenizer.from_pretrained(model_name)
 
     # Configure the quantization algorithm and scheme.
     # In this case, we:
     #   * quantize the weights to fp8 with per channel via ptq
     #   * quantize the activations to fp8 with dynamic per token
     recipe = QuantizationModifier(
-        targets="Linear", scheme="FP8_DYNAMIC", ignore=["lm_head"]
+        targets="Linear",
+        scheme="FP8_DYNAMIC",
+        ignore=[
+            "lm_head",
+            # for Qwen MoE, but ok to just hardcode here for now
+            # https://github.com/vllm-project/llm-compressor/blob/33ef5f497a9801893764c6a2c880cb1f560067fa/examples/quantizing_moe/qwen_example.py#L10
+            "re:.*mlp.gate$",
+            "re:.*mlp.shared_expert_gate$",
+            # also skip attention and shared expert, to focus on MoE for now
+            "re:.*self_attn.*",
+            "re:.*shared_expert.*",
+        ],
     )
 
     # Apply quantization.
@@ -41,7 +49,7 @@
     print("==========================================")
 
     # Save to disk in compressed-tensors format.
-    SAVE_DIR = "data/llmcompressor/" + "fp8-" + MODEL_ID.rstrip("/").split("/")[-1]
+    SAVE_DIR = "data/llmcompressor/" + "fp8-" + model_name.rstrip("/").split("/")[-1]
     model.save_pretrained(SAVE_DIR)
     tokenizer.save_pretrained(SAVE_DIR)
 
diff --git a/hf_torchao_vllm/utils.py b/hf_torchao_vllm/utils.py
new file mode 100644
index 0000000..50a7210
--- /dev/null
+++ b/hf_torchao_vllm/utils.py
@@ -0,0 +1,59 @@
+import json
+import os
+import pathlib
+
+import safetensors
+
+import torch
+
+torch.serialization.add_safe_globals([getattr])
+
+def _inspect_state_dict_file(model_name):
+    if str(model_name).endswith('safetensors'):
+        # safetensors format
+        with safetensors.safe_open(model_name, framework='pt', device='cpu') as f:
+            print(f.metadata())
+            for k in f.keys():
+                t = f.get_tensor(k)
+                print(k, type(t), t.shape, t.dtype)
+    else:
+        # pytorch format
+        state_dict = torch.load(model_name, weights_only=True)
+        for k, v in state_dict.items():
+            print(k, type(v), v.shape, v.dtype)
+
+def inspect_model_state_dict(dir_name, model_name, model_extension) -> None:
+    """
+    Inspect the model state_dict from HuggingFace and print data to stdout.
+    For example, if model_name == `pytorch_model` and model_extension == `bin`,
+    1. if there is a single chunk, the state dict is named `pytorch_model.bin`
+    2. if there are multiple chunks, the state dict is spread across multiple
+       files:
+
+       pytorch_model-00001-of-00004.bin
+       ...
+       pytorch_model-00004-of-00004.bin
+       pytorch_model.bin.index.json
+    """
+    is_single_chunk = os.path.isfile(f'{dir_name}/{model_name}.{model_extension}')
+    if is_single_chunk:
+        print('single state dict file')
+        model_name = f'{dir_name}/{model_name}.{model_extension}'
+        _inspect_state_dict_file(model_name)
+    else:
+        print('multiple state dict files')
+
+        index_name = f'{dir_name}/{model_name}.{model_extension}.index.json'
+        print(index_name)
+        with open(index_name, 'r') as f:
+            data = json.load(f)
+        print(json.dumps(data, indent=2))
+
+        # iterate through each file
+        for file_path in pathlib.Path(dir_name).iterdir():
+            if not file_path.is_file():
+                continue
+            if not (model_name in str(file_path) and str(file_path).endswith(model_extension)):
+                continue
+            print(file_path)
+            _inspect_state_dict_file(file_path)
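
Note (not part of the patch): below is a minimal sketch of which Qwen-1.5 MoE
submodules the new `re:`-prefixed ignore entries in the QuantizationModifier
recipe are meant to exclude, leaving roughly the routed expert Linear layers to
be quantized. The module names are assumptions based on the Qwen2-MoE layout in
transformers, and plain re.match is only an approximation of how
llm-compressor / compressed-tensors resolve "re:" patterns; the recipe in the
diff above is the source of truth.

import re

# the four regexes from the recipe, with the "re:" prefix stripped
ignore_regexes = [
    r".*mlp.gate$",
    r".*mlp.shared_expert_gate$",
    r".*self_attn.*",
    r".*shared_expert.*",
]

# illustrative (assumed) Qwen-1.5 MoE module names
example_module_names = [
    "model.layers.0.self_attn.q_proj",           # attention -> skip
    "model.layers.0.mlp.gate",                   # MoE router -> skip
    "model.layers.0.mlp.shared_expert.up_proj",  # shared expert -> skip
    "model.layers.0.mlp.shared_expert_gate",     # shared expert gate -> skip
    "model.layers.0.mlp.experts.7.down_proj",    # routed expert -> quantize
    "lm_head",                                   # ignored by exact name -> skip
]

for name in example_module_names:
    skipped = name == "lm_head" or any(re.match(p, name) for p in ignore_regexes)
    print(f"{name}: {'skip' if skipped else 'quantize'}")

Assuming the Hugging Face model id Qwen/Qwen1.5-MoE-A2.7B (the patch does not
pin one down), the updated script can be pointed at it via
`python quantize_hf_model_with_llm_compressor.py --model_name Qwen/Qwen1.5-MoE-A2.7B`,
and the resulting data/llmcompressor/fp8-Qwen1.5-MoE-A2.7B directory can then be
inspected with inspect_llm_compressor_output.py.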