torchao_hf_vllm/torchao_hf_script.py (49 additions, 22 deletions)
@@ -65,9 +65,25 @@ def get_quantization_config(args):
         case "autoquant":
             return TorchAoConfig("autoquant", min_sqnr=args.min_sqnr)
         case "fp8":
-            return TorchAoConfig(
-                Float8DynamicActivationFloat8WeightConfig(granularity=gran)
-            )
+            single_config = Float8DynamicActivationFloat8WeightConfig(granularity=gran)
+            if args.experts_only_qwen_1_5_moe_a_2_7b:
+                from torchao.quantization import ModuleFqnToConfig
+                expert_fqn_to_config = {}
+                # TODO(future PR): this is annoying, I should be able to use a regex here
+                for layer_idx in range(24):
+                    for expert_idx in range(60):
+                        expert_fqn_to_config[f"model.layers.{layer_idx}.mlp.experts.{expert_idx}.gate_proj"] = single_config
+                        expert_fqn_to_config[f"model.layers.{layer_idx}.mlp.experts.{expert_idx}.up_proj"] = single_config
+                        expert_fqn_to_config[f"model.layers.{layer_idx}.mlp.experts.{expert_idx}.down_proj"] = single_config
+                module_fqn_to_config = ModuleFqnToConfig({
+                    "_default": None,
+                    **expert_fqn_to_config,
+                })
+                return TorchAoConfig(
+                    quant_type=module_fqn_to_config,
+                )
+            else:
+                return TorchAoConfig(single_config)
         case "int4_weight_only":
             return TorchAoConfig(Int4WeightOnlyConfig(group_size=128))
         case "int8_weight_only":
@@ -148,12 +164,14 @@ def main(
     benchmark: bool = False,
     bench_tokens: int = 100,
     device_map: str = "cuda",
+    experts_only_qwen_1_5_moe_a_2_7b: bool = False,
+    save_model_to_disk: bool = True,
 ):
     """
     Quantize a model with TorchAO and test its performance.
 
     Args:
-        model_name: Model to quantize (e.g., meta-llama/Meta-Llama-3-8B, facebook/opt-125m)
+        model_name: Model to quantize (e.g., meta-llama/Meta-Llama-3-8B, facebook/opt-125m, Qwen/Qwen1.5-MoE-A2.7B)
         output_dir: Directory to save the quantized model
         push_to_hub: HF Hub repo name to push the model (e.g., 'your-username/model-name')
         quant_type: Quantization type to use
@@ -163,6 +181,8 @@
         benchmark: Run benchmarking comparison
         bench_tokens: Number of tokens to generate for benchmarking
        device_map: Device mapping strategy
+        experts_only_qwen_1_5_moe_a_2_7b: if True, quantizes experts only for Qwen1.5-MoE-A2.7B model
+        save_model_to_disk: if True, saves quantized model to local disk
     """
     # Set seed before creating the model
     set_seed(42)
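For reference, a hypothetical direct call that exercises the two new flags (the import path below is an assumption, not part of the diff, and the remaining parameters are left at their defaults).

from torchao_hf_script import main  # hypothetical import path

main(
    model_name="Qwen/Qwen1.5-MoE-A2.7B",
    quant_type="fp8",
    experts_only_qwen_1_5_moe_a_2_7b=True,  # quantize only the MoE expert projections
    save_model_to_disk=False,               # skip writing the checkpoint to disk
)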
@@ -184,9 +204,13 @@
         benchmark=benchmark,
         bench_tokens=bench_tokens,
         device_map=device_map,
+        experts_only_qwen_1_5_moe_a_2_7b=experts_only_qwen_1_5_moe_a_2_7b,
+        save_model_to_disk=save_model_to_disk,
     )
-    print(f"Using Model name: {args.model_name}")
-    print(f"Quantization type: {args.quant_type}")
+    print(f"{args=}")
 
+    if args.experts_only_qwen_1_5_moe_a_2_7b:
+        assert args.quant_type == "fp8", "unsupported"
+
     # Create output directory
     output_dir = Path(args.output_dir)
@@ -228,10 +252,11 @@ def main(
         generated_text = tokenizer.decode(output, skip_special_tokens=True)
         print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
 
-    # Save quantized model
-    print(f"\n📁Saving quantized model to: {output_dir}")
-    quantized_model.save_pretrained(output_dir, safe_serialization=False)
-    tokenizer.save_pretrained(output_dir)
+    if args.save_model_to_disk:
+        # Save quantized model
+        print(f"\nSaving quantized model to: {output_dir}")
+        quantized_model.save_pretrained(output_dir, safe_serialization=False)
+        tokenizer.save_pretrained(output_dir)
 
     # Push to HuggingFace hub if requested
     if args.push_to_hub:
@@ -242,22 +267,24 @@
         quantized_model.push_to_hub(model_name, safe_serialization=False)
         tokenizer.push_to_hub(model_name)
 
-    # Load saved model to verify
-    print("\nLoading saved quantized model to verify...")
-    # TODO: do we really need `weights_only=False` here?
-    loaded_model = AutoModelForCausalLM.from_pretrained(
-        output_dir, device_map=args.device_map, torch_dtype="auto", weights_only=False,
-    )
+    if args.save_model_to_disk:
+        # Load saved model to verify
+        print("\nLoading saved quantized model to verify...")
+        # TODO: do we really need `weights_only=False` here?
+        loaded_model = AutoModelForCausalLM.from_pretrained(
+            output_dir, device_map=args.device_map, torch_dtype="auto", weights_only=False,
+        )
 
-    # Test loaded model with first prompt
-    test_prompt = prompts[0]
-    input_ids = tokenizer(test_prompt, return_tensors="pt").to(loaded_model.device)
-    output = loaded_model.generate(**input_ids, max_new_tokens=args.max_new_tokens)
-    generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
-    print(f"Verification - Prompt: {test_prompt!r}, Generated text: {generated_text!r}")
+        # Test loaded model with first prompt
+        test_prompt = prompts[0]
+        input_ids = tokenizer(test_prompt, return_tensors="pt").to(loaded_model.device)
+        output = loaded_model.generate(**input_ids, max_new_tokens=args.max_new_tokens)
+        generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
+        print(f"Verification - Prompt: {test_prompt!r}, Generated text: {generated_text!r}")
 
     # Benchmark if requested
     if args.benchmark:
+        assert args.save_model_to_disk, "unsupported"
         print("\nBenchmarking models...")
         # Benchmark quantized model
         print("Benchmarking quantized model:")