20 changes: 13 additions & 7 deletions hf_torchao_vllm/inspect_llm_compressor_output.py
@@ -5,6 +5,8 @@
import json
import fire

from utils import inspect_model_state_dict

def run(
dir_name: str = 'data/llmcompressor/fp8-opt-125m',
):
@@ -14,13 +16,17 @@ def run(
# TODO: pretty print
print(json.dumps(data, indent=2))

# inspect the model, saved in safetensors format
model_name = f'{dir_name}/model.safetensors'
with safetensors.safe_open(model_name, framework='pt', device='cpu') as f:
print(f.metadata())
for k in f.keys():
t = f.get_tensor(k)
print(k, t.shape, t.dtype)
model_name, model_extension = 'model', 'safetensors'
inspect_model_state_dict(dir_name, model_name, model_extension)

if False:
# inspect the model, saved in safetensors format
model_name = f'{dir_name}/model.safetensors'
with safetensors.safe_open(model_name, framework='pt', device='cpu') as f:
print(f.metadata())
for k in f.keys():
t = f.get_tensor(k)
print(k, t.shape, t.dtype)

if __name__ == '__main__':
fire.Fire(run)
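Since the script exposes run through fire.Fire, the output directory can be passed in rather than relying on the default; a minimal sketch, assuming the module is importable from the working directory:

    from inspect_llm_compressor_output import run

    # the path below is the script's default; any llm-compressor output directory works
    run(dir_name='data/llmcompressor/fp8-opt-125m')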
21 changes: 16 additions & 5 deletions hf_torchao_vllm/inspect_torchao_output.py
@@ -2,10 +2,14 @@
# via the `torchao_hf_script.py` script

import json
import os
import pathlib
import torch
import torchao # this is needed to run torch.serialization.add_safe_globals([torchao.quantization.Float8Tensor])
import fire

from utils import inspect_model_state_dict

# not sure why I still need this
torch.serialization.add_safe_globals([getattr])

@@ -15,14 +19,21 @@ def run(dir_name: str = 'data/torchao/fp8-opt-125m'):
# inspect the config
with open(json_config_name, 'r') as f:
data = json.load(f)
# TODO: pretty print
print(json.dumps(data, indent=2))

# inspect the data
model_name = f'{dir_name}/pytorch_model.bin'
state_dict = torch.load(model_name, weights_only=True)
for k, v in state_dict.items():
print(k, v.shape, type(v))
#
# if there is a single chunk, the state dict is named `pytorch_model.bin`
#
# if there are multiple chunks, the state dict is spread across multiple files:
#
# pytorch_model-00001-of-00004.bin
# ...
# pytorch_model-00004-of-00004.bin
# pytorch_model.bin.index.json
#
model_name, model_extension = 'pytorch_model', 'bin'
inspect_model_state_dict(dir_name, model_name, model_extension)

if __name__ == '__main__':
fire.Fire(run)
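When the checkpoint is sharded, pytorch_model.bin.index.json maps each parameter name to the shard file that holds it. A sketch of the typical layout of that index, with illustrative parameter names and sizes:

    # sketch of a sharded checkpoint index; names and sizes are illustrative
    index = {
        "metadata": {"total_size": 250000000},
        "weight_map": {
            "model.decoder.layers.0.self_attn.q_proj.weight": "pytorch_model-00001-of-00004.bin",
            "model.decoder.layers.11.fc2.weight": "pytorch_model-00004-of-00004.bin",
        },
    }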
26 changes: 17 additions & 9 deletions hf_torchao_vllm/quantize_hf_model_with_llm_compressor.py
@@ -10,21 +10,29 @@

import fire

def run():

# MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct"
MODEL_ID = "facebook/opt-125m"

def run(model_name: str = 'facebook/opt-125m'):
# Load model.
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype=torch.bfloat16)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)
print(model)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Configure the quantization algorithm and scheme.
# In this case, we:
# * quantize the weights to fp8 with per-channel scales via PTQ
# * quantize the activations to fp8 with dynamic per-token scales
recipe = QuantizationModifier(
targets="Linear", scheme="FP8_DYNAMIC", ignore=["lm_head"]
targets="Linear",
scheme="FP8_DYNAMIC",
ignore=[
"lm_head",
# for Qwen MoE, but ok to just hardcode here for now
# https://github.com/vllm-project/llm-compressor/blob/33ef5f497a9801893764c6a2c880cb1f560067fa/examples/quantizing_moe/qwen_example.py#L10
"re:.*mlp.gate$",
"re:.*mlp.shared_expert_gate$",
# also skip attention and shared expert, to focus on MoE for now
"re:.*self_attn.*",
"re:.*shared_expert.*",
],
)

# Apply quantization.
@@ -41,7 +49,7 @@ def run():
print("==========================================")

# Save to disk in compressed-tensors format.
SAVE_DIR = "data/llmcompressor/" + "fp8-" + MODEL_ID.rstrip("/").split("/")[-1]
SAVE_DIR = "data/llmcompressor/" + "fp8-" + model_name.rstrip("/").split("/")[-1]
model.save_pretrained(SAVE_DIR)
tokenizer.save_pretrained(SAVE_DIR)

59 changes: 59 additions & 0 deletions hf_torchao_vllm/utils.py
@@ -0,0 +1,59 @@
import json
import os
import pathlib

import safetensors

import torch

torch.serialization.add_safe_globals([getattr])

def _inspect_state_dict_file(model_name):
if str(model_name).endswith('safetensors'):
# safetensors format
with safetensors.safe_open(model_name, framework='pt', device='cpu') as f:
print(f.metadata())
for k in f.keys():
t = f.get_tensor(k)
print(k, type(t), t.shape, t.dtype)
else:
# pytorch format
state_dict = torch.load(model_name, weights_only=True)
for k, v in state_dict.items():
print(k, type(v), v.shape, v.dtype)

def inspect_model_state_dict(dir_name, model_name, model_extension) -> None:
"""
Inspect the model state_dict from HuggingFace and print data to stdout.
For example, if model_name == `pytorch_model` and model_extension == `bin`:
1. if there is a single chunk, the state dict is named `pytorch_model.bin`
2. if there are multiple chunks, the state dict is spread across multiple
files:

pytorch_model-00001-of-00004.bin
...
pytorch_model-00004-of-00004.bin
pytorch_model.bin.index.json
"""
is_single_chunk = os.path.isfile(f'{dir_name}/{model_name}.{model_extension}')
if is_single_chunk:
print('single state dict file')
model_name = f'{dir_name}/{model_name}.{model_extension}'
_inspect_state_dict_file(model_name)
else:
print('multiple state dict files')

index_name = f'{dir_name}/{model_name}.{model_extension}.index.json'
print(index_name)
with open(index_name, 'r') as f:
data = json.load(f)
print(json.dumps(data, indent=2))

# iterate through each file
for file_path in pathlib.Path(dir_name).iterdir():
if not file_path.is_file():
continue
if not (model_name in str(file_path) and str(file_path).endswith(model_extension)):
continue
print(file_path)
_inspect_state_dict_file(file_path)
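For reference, the two inspection scripts in this PR call the helper as follows:

    from utils import inspect_model_state_dict

    # llm-compressor output, saved in safetensors format
    inspect_model_state_dict('data/llmcompressor/fp8-opt-125m', 'model', 'safetensors')

    # torchao output, saved via torch.save in one or more .bin chunks
    inspect_model_state_dict('data/torchao/fp8-opt-125m', 'pytorch_model', 'bin')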