From 3d364330869f479e0f85a18b46eaa4cd1c1ea7d0 Mon Sep 17 00:00:00 2001
From: Chen Lai
Date: Sat, 7 Sep 2024 16:33:38 -0700
Subject: [PATCH] Reland add proper calibration for pt2e flow (#5152)

Summary:
Pull Request resolved: https://github.com/pytorch/executorch/pull/5152

See discussion in https://github.com/pytorch/executorch/pull/5095
Reland because of internal failure

Differential Revision: D62323396
---
 examples/models/llama2/eval_llama_lib.py     |  82 ++++++++++++---
 examples/models/llama2/export_llama_lib.py   |  29 +++++-
 examples/models/llama2/tokenizer/targets.bzl |  16 +++
 extension/llm/export/TARGETS                 |   1 +
 extension/llm/export/builder.py              | 101 ++++++++++++++++++-
 extension/llm/tokenizer/targets.bzl          |  22 +---
 6 files changed, 212 insertions(+), 39 deletions(-)

diff --git a/examples/models/llama2/eval_llama_lib.py b/examples/models/llama2/eval_llama_lib.py
index bd650fab1ad..2d10f5edc0a 100644
--- a/examples/models/llama2/eval_llama_lib.py
+++ b/examples/models/llama2/eval_llama_lib.py
@@ -29,6 +29,51 @@
 )
 
 
+class GraphModuleEvalWrapper(EagerEvalWrapper):
+    """
+    A wrapper class for ExecuTorch py-binded integration with the
+    lm-evaluation-harness library.
+    """
+
+    def __init__(
+        self,
+        model: torch.fx.GraphModule,
+        tokenizer: Union[SentencePieceTokenizer, Tiktoken],
+        max_seq_length: Optional[int] = None,
+        use_kv_cache: bool = False,
+        enable_dynamic_shape: bool = True,
+    ):
+        super().__init__(
+            model=model, tokenizer=tokenizer, max_seq_length=max_seq_length
+        )
+        self._model = model.to(self.device)
+        self._use_kv_cache = use_kv_cache
+        self._enable_dynamic_shape = enable_dynamic_shape
+
+    def _model_call(self, inps):
+        if self._use_kv_cache:
+            if not self._enable_dynamic_shape:
+                # graph module exported without dynamic shape won't work with a different shape.
+                # And we have to do single token prefill here.
+                result_logits = []
+                for pos in range(inps.shape[-1]):
+                    pos_tensor = torch.tensor([pos], dtype=torch.int64)
+                    logits = self._model(inps[:, pos : pos + 1], pos_tensor)
+                    result_logits.append(logits)
+                return torch.cat(result_logits, dim=1)
+            else:
+                pos_tensor = torch.tensor([0], dtype=torch.int64, device=self.device)
+                # Batch process the whole sequence.
+                logits = self._model(inps[:, : self._max_seq_length], pos_tensor)
+                return logits
+
+        else:
+            return self._model(inps)
+
+    def _model_generate(self, context, max_length, eos_token_id):
+        raise Exception("unimplemented")
+
+
 class ETPybindEvalWrapper(EagerEvalWrapper):
     """
     A wrapper class for ExecuTorch py-binded integration with the
@@ -148,6 +193,13 @@ def gen_eval_wrapper(
             if torch.cuda.is_available()
             else manager.pre_autograd_graph_module.to(device="cpu")
         )
+        return GraphModuleEvalWrapper(
+            model=model,
+            tokenizer=tokenizer,
+            max_seq_length=args.max_seq_length,
+            use_kv_cache=args.use_kv_cache,
+            enable_dynamic_shape=args.enable_dynamic_shape,
+        )
     else:
         # TODO: use manager.pre_autograd_graph_module for the eval to remove the if-else branch
         # for quantizers. Currently capture_pre_autograd_graph only works with --kv_cache, but
@@ -158,21 +210,21 @@ def gen_eval_wrapper(
             else manager.model.eval().to(device="cpu")
         )
 
-    # Save the checkpoint after the eager model preparation is done.
-    # The reason for this option is that the checkpoint can be used
-    # to do evaluations in other evaluation platforms, or with data
-    # that is not available in this eval_llama. We save the checkpoint
-    # here for consistency with eval_llama. The accuracy results we
-    # get from eval_llama can be used as a reference to other evaluations.
-    if args.output_eager_checkpoint_file is not None:
-        torch.save(model, args.output_eager_checkpoint_file)
-
-    return EagerEvalWrapper(
-        model=model,
-        tokenizer=tokenizer,
-        max_seq_length=args.max_seq_length,
-        use_kv_cache=args.use_kv_cache,
-    )
+        # Save the checkpoint after the eager model preparation is done.
+        # The reason for this option is that the checkpoint can be used
+        # to do evaluations in other evaluation platforms, or with data
+        # that is not available in this eval_llama. We save the checkpoint
+        # here for consistency with eval_llama. The accuracy results we
+        # get from eval_llama can be used as a reference to other evaluations.
+        if args.output_eager_checkpoint_file is not None:
+            torch.save(model, args.output_eager_checkpoint_file)
+
+        return EagerEvalWrapper(
+            model=model,
+            tokenizer=tokenizer,
+            max_seq_length=args.max_seq_length,
+            use_kv_cache=args.use_kv_cache,
+        )
 
 
 def build_args_parser() -> argparse.ArgumentParser:
diff --git a/examples/models/llama2/export_llama_lib.py b/examples/models/llama2/export_llama_lib.py
index e56d2fe848b..f6abc3aaf4e 100644
--- a/examples/models/llama2/export_llama_lib.py
+++ b/examples/models/llama2/export_llama_lib.py
@@ -16,7 +16,7 @@
 from enum import Enum
 from json import JSONDecodeError
 from pathlib import Path
-from typing import Optional, Union
+from typing import List, Optional, Union
 
 import pkg_resources
 
@@ -166,19 +166,25 @@ def build_args_parser() -> argparse.ArgumentParser:
         nargs="+",
         type=str,
         default=None,
-        help="Tasks for GPTQ calibration",
+        help="Tasks for GPTQ calibration from lm_eval",
     )
     parser.add_argument(
         "--calibration_limit",
         type=int,
         default=None,
-        help="number of samples used for calibration",
+        help="number of samples used for calibration from lm_eval",
     )
     parser.add_argument(
         "--calibration_seq_length",
         type=int,
         default=None,
-        help="Sequence length for GPTQ calibration",
+        help="Sequence length for GPTQ calibration from lm_eval",
+    )
+    parser.add_argument(
+        "--calibration_data",
+        type=str,
+        default="Once upon a time",
+        help="Calibration prompts from users",
     )
     parser.add_argument(
         "-t",
@@ -420,6 +426,11 @@ def _prepare_for_llama_export(modelname: str, args) -> LLMEdgeManager:
         generate_full_logits=args.generate_full_logits,
         weight_type=weight_type,
         enable_dynamic_shape=args.enable_dynamic_shape,
+        calibration_tasks=args.calibration_tasks,
+        calibration_limit=args.calibration_limit,
+        calibration_seq_length=args.calibration_seq_length,
+        calibration_data=args.calibration_data,
+        tokenizer_path=args.tokenizer_path,
         verbose=args.verbose,
         max_seq_len=args.max_seq_length,
         metadata_str=args.metadata,
@@ -630,6 +641,11 @@ def _load_llama_model(
     generate_full_logits: bool = False,
     weight_type: WeightType = WeightType.LLAMA,
     enable_dynamic_shape: bool = False,
+    calibration_tasks: Optional[List[str]] = None,
+    calibration_limit: Optional[int] = None,
+    calibration_seq_length: Optional[int] = None,
+    calibration_data: Optional[str] = None,
+    tokenizer_path: Optional[str] = None,
     verbose: bool = False,
     max_seq_len: int = 128,
     metadata_str: Optional[str] = None,
@@ -686,6 +702,11 @@ def _load_llama_model(
         use_kv_cache=use_kv_cache,
         example_inputs=example_inputs,
         enable_dynamic_shape=enable_dynamic_shape,
+        calibration_tasks=calibration_tasks,
+        calibration_limit=calibration_limit,
+        calibration_seq_length=calibration_seq_length,
+        calibration_data=calibration_data,
+        tokenizer_path=tokenizer_path,
         verbose=verbose,
         metadata=_load_llama_model_metadata(
             weight_type,
diff --git a/examples/models/llama2/tokenizer/targets.bzl b/examples/models/llama2/tokenizer/targets.bzl
index 70318740d6a..40f8f29ac1e 100644
--- a/examples/models/llama2/tokenizer/targets.bzl
+++ b/examples/models/llama2/tokenizer/targets.bzl
@@ -21,3 +21,19 @@ def define_common_targets():
             "@EXECUTORCH_CLIENTS",
         ],
     )
+
+    runtime.python_library(
+        name = "tiktoken_py",
+        srcs = [
+            "tiktoken.py",
+        ],
+        _is_external_target = True,
+        visibility = [
+            "//bento/...",
+            "//bento_kernels/...",
+            "//executorch/...",
+        ],
+        deps = [
+            "fbsource//third-party/pypi/tiktoken:tiktoken",
+        ],
+    )
diff --git a/extension/llm/export/TARGETS b/extension/llm/export/TARGETS
index 75f5cf937e8..be9bc183dbe 100644
--- a/extension/llm/export/TARGETS
+++ b/extension/llm/export/TARGETS
@@ -33,5 +33,6 @@ runtime.python_library(
         "//executorch/exir:lib",
         "//executorch/exir/backend:backend_details",
         "//executorch/extension/export_util:export_util",
+        "//executorch/extension/llm/tokenizer:tokenizer_py_lib",
     ],
 )
diff --git a/extension/llm/export/builder.py b/extension/llm/export/builder.py
index 6eecebb9466..bc64ae869fc 100644
--- a/extension/llm/export/builder.py
+++ b/extension/llm/export/builder.py
@@ -27,6 +27,7 @@
 from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass
 from executorch.extension.export_util.utils import export_to_edge, save_pte_program
 
+from executorch.extension.llm.tokenizer.utils import get_tokenizer
 from torch._export import capture_pre_autograd_graph
 from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
 from torch.ao.quantization.quantizer import Quantizer
@@ -68,6 +69,11 @@ def __init__(
         example_inputs,
         args: Optional[Any] = None,
         enable_dynamic_shape: bool = False,
+        calibration_tasks: Optional[List[str]] = None,
+        calibration_limit: Optional[int] = None,
+        calibration_seq_length: Optional[int] = None,
+        calibration_data: Optional[str] = None,
+        tokenizer_path: Optional[str] = None,
         verbose: bool = False,
         metadata: Optional[dict] = None,
         dynamic_shapes: Optional[Any] = None,
@@ -90,6 +96,11 @@ def __init__(
         self.dynamic_shapes = dynamic_shapes
         self._saved_pte_filename = None
         self.args = args
+        self.calibration_tasks = calibration_tasks
+        self.calibration_limit = calibration_limit
+        self.calibration_seq_length = calibration_seq_length
+        self.calibration_data = calibration_data
+        self.tokenizer_path = tokenizer_path
 
     def set_output_dir(self, output_dir: str) -> "LLMEdgeManager":
         """
@@ -181,6 +192,69 @@ def capture_pre_autograd_graph(self) -> "LLMEdgeManager":
 
         return self
 
+    def pt2e_calibrate(
+        self,
+        prepared_module,
+        calibration_tasks,
+        calibration_limit,
+        calibration_seq_length,
+        calibration_data,
+        tokenizer_path,
+    ):
+        logging.info("Run calibration...")
+        try:
+            from executorch.examples.models.llama2.eval_llama_lib import (
+                GraphModuleEvalWrapper,
+            )
+            from executorch.examples.models.llama2.evaluate import evaluate_model
+        except ImportError:
+            raise ImportError(
+                "Please install the llm eval dependency via examples/models/llama2/install_requirements.sh"
+            )
+
+        tokenizer = get_tokenizer(tokenizer_path)
+
+        def calibrate_template(
+            module: torch.fx.GraphModule, tokenizer, prompts: str, max_len: int
+        ):
+            # TODO: change criteria & support batch inputs if necessary
+            pos = torch.tensor(0, dtype=torch.int64)
+            token_list = tokenizer.encode(prompts, bos=True, eos=False)
+
+            with torch.no_grad():
+                while token_list[-1] != tokenizer.eos_id and pos < max_len:
+                    logits = module(
+                        torch.full((1, 1), token_list[pos]),
+                        torch.tensor((pos,)),
+                    )
+                    pos += 1
+                    if pos >= len(token_list):
+                        token_list.append(torch.argmax(logits[:], dim=-1).item())
+
+        calibrate_template(
+            module=prepared_module,
+            tokenizer=tokenizer,
+            prompts=calibration_data,
+            max_len=calibration_seq_length,
+        )
+
+        eval_wrapper = GraphModuleEvalWrapper(
+            model=prepared_module,
+            tokenizer=tokenizer,
+            max_seq_length=calibration_seq_length,
+            use_kv_cache=self.use_kv_cache,
+            enable_dynamic_shape=self.enable_dynamic_shape,
+        )
+        eval_results = evaluate_model(
+            eval_wrapper,
+            calibration_tasks,
+            calibration_limit,
+        )
+
+        for task, res in eval_results["results"].items():
+            print(f"{task}: {res}")
+        logging.info("Calibration finish...")
+
     def pt2e_quantize(self, quantizers: Optional[List[Quantizer]]) -> "LLMEdgeManager":
         """
         Quantize the model via pt2e flow and retrieve LLMEdgeManager including the quantized model.
@@ -203,8 +277,33 @@ def pt2e_quantize(self, quantizers: Optional[List[Quantizer]]) -> "LLMEdgeManage
                     self.pre_autograd_graph_module is not None
                 ), "Please run capture_pre_autograd_graph first"
                 m = prepare_pt2e(self.pre_autograd_graph_module, composed_quantizer)
+                logging.info(
+                    f"Calibrating with tasks: {self.calibration_tasks}, limit: {self.calibration_limit}, calibration_data: {self.calibration_data}, tokenizer_path: {self.tokenizer_path}, seq_length: {self.calibration_seq_length}"
+                )
                 # Calibrate
-                m(*self.example_inputs)
+                if (
+                    self.calibration_tasks is not None
+                    and self.calibration_limit is not None
+                    and self.calibration_seq_length is not None
+                    and self.calibration_data is not None
+                    and self.tokenizer_path is not None
+                ):
+                    logging.info(
+                        f"Calibrating with tasks: {self.calibration_tasks}, limit: {self.calibration_limit}, calibration_data: {self.calibration_data}, tokenizer_path: {self.tokenizer_path}, seq_length: {self.calibration_seq_length}"
+                    )
+                    self.pt2e_calibrate(
+                        prepared_module=m,
+                        calibration_tasks=self.calibration_tasks,
+                        calibration_limit=self.calibration_limit,
+                        calibration_seq_length=self.calibration_seq_length,
+                        calibration_data=self.calibration_data,
+                        tokenizer_path=self.tokenizer_path,
+                    )
+                else:
+                    logging.info(
+                        "No calibration provided, using dummy input to calibrate..."
+                    )
+                    m(*self.example_inputs)
                 m = convert_pt2e(m)
                 DuplicateDynamicQuantChainPass()(m)
                 self.pre_autograd_graph_module = m
diff --git a/extension/llm/tokenizer/targets.bzl b/extension/llm/tokenizer/targets.bzl
index f8e4df095ca..fa6cc915c4b 100644
--- a/extension/llm/tokenizer/targets.bzl
+++ b/extension/llm/tokenizer/targets.bzl
@@ -11,36 +11,20 @@ def define_common_targets():
         srcs = [
             "__init__.py",
             "tokenizer.py",
+            "utils.py",
         ],
         base_module = "executorch.extension.llm.tokenizer",
         visibility = [
             "//executorch/examples/...",
             "//executorch/extension/llm/tokenizer/...",
+            "//executorch/extension/llm/export/...",
             "//bento/...",
             "//bento_kernels/...",
         ],
         _is_external_target = True,
-        external_deps = [
-            "sentencepiece-py",
-        ],
-    )
-
-    runtime.python_library(
-        name = "utils",
-        srcs = [
-            "utils.py",
-        ],
-        base_module = "executorch.extension.llm.utils",
-        visibility = [
-            "//executorch/examples/...",
-            "//executorch/extension/llm/tokenizer/...",
-            "//bento/...",
-            "//bento_kernels/...",
-        ],
         deps = [
-            "//executorch/examples/models/llama2/tokenizer:tiktoken",
+            "//executorch/examples/models/llama2/tokenizer:tiktoken_py",
         ],
-        _is_external_target = True,
         external_deps = [
             "sentencepiece-py",
         ],
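
Usage note (not part of the patch): a minimal sketch of how the new calibration fields added above are expected to drive LLMEdgeManager.pt2e_quantize(). Only the field names and the "Once upon a time" default come from the diff; the task name, limit, sequence length, tokenizer path, and the helper function itself are illustrative assumptions.

# Illustrative sketch only; values below are assumptions for the example.
from executorch.extension.llm.export.builder import LLMEdgeManager


def quantize_with_calibration(manager: LLMEdgeManager, quantizers):
    # export_llama_lib.py now forwards these values from the --calibration_*
    # and --tokenizer_path flags; if any of them is None, pt2e_quantize()
    # falls back to the old dummy-input calibration, m(*self.example_inputs).
    manager.calibration_tasks = ["wikitext"]  # assumed lm_eval task name
    manager.calibration_limit = 5  # assumed number of samples
    manager.calibration_seq_length = 128  # assumed sequence length
    manager.calibration_data = "Once upon a time"  # default of --calibration_data
    manager.tokenizer_path = "/path/to/tokenizer.model"  # placeholder path
    return manager.pt2e_quantize(quantizers)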