create a new accuracy eval script for official README.md eval accuracy #3449
Merged

Changes from all commits (2 commits)
benchmarks/quantization/eval_accuracy_for_readme.py (new file):

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import subprocess

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig

from torchao.quantization import (
    Float8DynamicActivationFloat8WeightConfig,
    Float8DynamicActivationInt4WeightConfig,
    Int4WeightOnlyConfig,
    Int8DynamicActivationInt8WeightConfig,
    Int8WeightOnlyConfig,
    PerRow,
)


def string_to_config(s):
    """Map a quantization recipe name to the corresponding torchao config."""
    if s is None:
        return None
    elif s == "float8_rowwise":
        return Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())
    elif s == "int4_groupwise_weight_float8_rowwise_activation":
        return Float8DynamicActivationInt4WeightConfig()
    elif s == "int4_groupwise_hqq_weight_only":
        return Int4WeightOnlyConfig(
            group_size=32,
            int4_packing_format="tile_packed_to_4d",
            int4_choose_qparams_algorithm="hqq",
        )
    elif s == "int8_rowwise_weight_only":
        return Int8WeightOnlyConfig()
    elif s == "int8_rowwise":
        return Int8DynamicActivationInt8WeightConfig()
    else:
        raise AssertionError(f"unsupported {s}")


def quantize_model_and_save(model_id, quant_config, output_dir="results"):
    """Quantize the model and save it to the output directory."""
    print("Quantizing model with config: ", quant_config)
    if quant_config is None:
        quantization_config = None
    else:
        quantization_config = TorchAoConfig(quant_type=quant_config)
    quantized_model = AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map="auto",
        dtype=torch.bfloat16,
        quantization_config=quantization_config,
    )
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    quantized_model.save_pretrained(output_dir, safe_serialization=False)
    tokenizer.save_pretrained(output_dir, safe_serialization=False)
    return quantized_model, tokenizer


def run_lm_eval(model_dir, tasks_list=["hellaswag"], device="cuda:0", batch_size=8):
    """Run the lm_eval command using subprocess."""
    tasks_str = ",".join(tasks_list)
    command = [
        "lm_eval",
        "--model",
        "hf",
        "--model_args",
        f"pretrained={model_dir}",
        "--tasks",
        f"{tasks_str}",
        "--device",
        f"{device}",
        "--batch_size",
        f"{batch_size}",
        "--output_path",
        f"{model_dir}/lm_eval_outputs/",
    ]
    subprocess.run(command, check=True)


def get_size_of_dir(model_output_dir):
    # get dir size from shell, to skip complexity of dealing with tensor
    # subclasses
    # note: `du -sb` requires GNU du; the -b flag is not available on
    # BSD/macOS du
    result = subprocess.run(
        ["du", "-sb", model_output_dir], capture_output=True, text=True
    )
    size = int(result.stdout.split()[0])
    return size


def run(
    model_id: str,
    quant_recipe_name: str | None,
    tasks,
    device,
    batch_size,
    model_output_dir,
):
    print(f"\nRunning {model_id=} with {quant_recipe_name=}\n")
    model_name = model_id.split("/")[-1]
    # note: this overrides the model_output_dir argument, so the quantized
    # checkpoint always lands under benchmarks/data/quantized_model/
    model_output_dir = (
        f"benchmarks/data/quantized_model/{model_name}-{quant_recipe_name}"
    )
    quant_config = string_to_config(quant_recipe_name)
    quantized_model, tokenizer = quantize_model_and_save(
        model_id, quant_config=quant_config, output_dir=model_output_dir
    )
    print(quantized_model)

    model_size = get_size_of_dir(model_output_dir) / 1e9
    print(f"checkpoint size: {model_size} GB")

    run_lm_eval(
        model_output_dir, tasks_list=tasks, device=device, batch_size=batch_size
    )
    print("done\n")


if __name__ == "__main__":
    try:
        import lm_eval  # noqa: F401
    except ImportError:
        print(
            "lm_eval is required to run this script. Please install it using `pip install lm-eval`."
        )
        exit(0)

    # Set up argument parser
    parser = argparse.ArgumentParser(
        description="Quantize a model and evaluate its accuracy."
    )
    parser.add_argument(
        "--model_id",
        type=str,
        default="meta-llama/Llama-3.1-8B",
        help="The model ID to use.",
    )
    parser.add_argument(
        "--quant_recipe_name",
        type=str,
        default=None,
        help="The quantization recipe to use.",
    )
    parser.add_argument(
        "--tasks",
        nargs="+",
        type=str,
        default=["wikitext"],
        help="List of lm_eval (EleutherAI) tasks to evaluate. Usage: --tasks task1 task2",
    )
    parser.add_argument(
        "--device", type=str, default="cuda:0", help="Device to run the model on."
    )
    parser.add_argument(
        "--batch_size", type=str, default="auto", help="Batch size for lm_eval."
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="quantized_models",
        help="Output directory for quantized model.",
    )
    args = parser.parse_args()

    # Use parsed arguments
    run(
        model_id=args.model_id,
        quant_recipe_name=args.quant_recipe_name,
        tasks=args.tasks,
        device=args.device,
        batch_size=args.batch_size,
        model_output_dir=args.output_dir,
    )
```
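For reference, lm_eval consumes the saved checkpoint via `pretrained={model_dir}`. Below is a minimal sketch (not part of the PR) of reloading the quantized checkpoint directly with transformers for a quick smoke test. The checkpoint path and prompt are illustrative placeholders, and depending on the transformers/torchao versions the reload may require additional flags:

```python
# Minimal sketch: reload a quantized checkpoint produced by
# quantize_model_and_save() and run a smoke-test generation.
# The path and prompt below are illustrative, not from the PR.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# e.g. default model_id with the float8_rowwise recipe
ckpt_dir = "benchmarks/data/quantized_model/Llama-3.1-8B-float8_rowwise"

model = AutoModelForCausalLM.from_pretrained(
    ckpt_dir,
    device_map="auto",
    dtype=torch.bfloat16,
)
tokenizer = AutoTokenizer.from_pretrained(ckpt_dir)

inputs = tokenizer("The capital of France is", return_tensors="pt").to(model.device)
with torch.no_grad():
    out = model.generate(**inputs, max_new_tokens=16)
print(tokenizer.decode(out[0], skip_special_tokens=True))
```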
The accompanying shell script that drives the evaluation across all recipes (new file):

```bash
#!/bin/bash

set -e

# Get model_id as first positional argument (optional)
MODEL_ID="${1:-meta-llama/Llama-3.1-8B}"

# Get log file as second positional argument (optional)
LOG_FILE="${2:-benchmarks/data/eval_accuracy_for_readme_log.txt}"

# Build the base command arguments
BASE_ARGS="--tasks wikitext winogrande"
if [[ -n "$MODEL_ID" ]]; then
    BASE_ARGS="--model_id $MODEL_ID $BASE_ARGS"
fi

# baseline
# note: the -u flag is to prevent python from buffering stdout and stderr,
# so that the output log file is in chronological order
time python -u benchmarks/quantization/eval_accuracy_for_readme.py $BASE_ARGS 2>&1 | tee "$LOG_FILE"

# quantized recipes
# note:
# * `int4_groupwise_hqq_weight_float8_rowwise_activation` doesn't work with dtype_map auto: https://gist.github.com/vkuzo/6b128681b628744d445c553cdeac8a85
# * `int4_groupwise_hqq_weight_only` only works on A100
for quant_recipe in float8_rowwise int4_groupwise_weight_float8_rowwise_activation int4_groupwise_hqq_weight_only int8_rowwise_weight_only int8_rowwise; do
    time python -u benchmarks/quantization/eval_accuracy_for_readme.py $BASE_ARGS --quant_recipe_name $quant_recipe 2>&1 | tee -a "$LOG_FILE"
done

# TODO(future PR): script to parse the log file instead of manual copy-paste
```
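The TODO at the end of the script mentions automating log parsing. As a rough illustration of what that could look like (not part of the PR), here is a hypothetical sketch that pulls task/metric/value triples out of the results tables that lm_eval prints. The regex and the sample line in the comment assume lm_eval's default markdown-style table output, which may differ across lm_eval versions:

```python
# Hypothetical log parser sketch for the TODO above -- not part of the PR.
# Assumes the log contains lm_eval's default results tables, with rows like:
# |wikitext  |      2|none  |     0|word_perplexity|9.7531|   |      |
# (illustrative sample; the exact format may vary across lm_eval versions)
import re
import sys

ROW_RE = re.compile(
    r"^\|\s*(?P<task>\w[\w-]*)\s*\|.*\|\s*(?P<metric>[\w_]+)\s*\|[^\d|]*(?P<value>\d+\.\d+)"
)


def parse_lm_eval_log(path):
    """Return (task, metric, value) triples found in an lm_eval log file."""
    results = []
    with open(path) as f:
        for line in f:
            m = ROW_RE.match(line.strip())
            if m:
                results.append(
                    (m.group("task"), m.group("metric"), float(m.group("value")))
                )
    return results


if __name__ == "__main__":
    # usage: python parse_log.py <log_file>
    for task, metric, value in parse_lm_eval_log(sys.argv[1]):
        print(f"{task}\t{metric}\t{value}")
```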
Review comment (nit): `int4_groupwise_weight_float8_rowwise_activation` --> `float8_rowwise_activation_int4_groupwise_weight`, to match the config name order?

Reply: I'm matching the order in https://github.com/pytorch/ao/tree/main?tab=readme-ov-file#stable-workflows. Overall, if we want to standardize this everywhere, that sounds reasonable; IMO let's do that in a separate "rename-only" PR?