diff --git a/extension/llm/export/builder.py b/extension/llm/export/builder.py
index 264e1e95ad3..30679ce057a 100644
--- a/extension/llm/export/builder.py
+++ b/extension/llm/export/builder.py
@@ -27,6 +27,7 @@
 from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass
 from executorch.extension.export_util.utils import export_to_edge, save_pte_program
+from executorch.extension.llm.utils import get_tokenizer
 
 from torch._export import capture_pre_autograd_graph
 from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
 from torch.ao.quantization.quantizer import Quantizer
@@ -87,6 +88,10 @@ def __init__(
         self.output_dir = "."
         self.dynamic_shapes = dynamic_shapes
         self._saved_pte_filename = None
+        self.calibration_tasks = None
+        self.calibration_limit = None
+        self.calibration_seq_length = None
+        self.tokenizer_path = None
 
     def set_output_dir(self, output_dir: str) -> "LLMEdgeManager":
         """
@@ -166,6 +171,37 @@ def capture_pre_autograd_graph(self) -> "LLMEdgeManager":
             )
         return self
 
+    def calibrate(
+        self,
+        prepared_module,
+        calibration_tasks,
+        calibration_limit,
+        calibration_seq_length,
+        tokenizer_path,
+    ):
+        logging.info("Run calibration...")
+        try:
+            from executorch.examples.models.llama2.evaluate import EagerEvalWrapper, evaluate_model
+        except ImportError:
+            raise ImportError(
+                "Please install the llm eval dependency via examples/models/llama2/install_requirements.sh"
+            )
+
+        tokenizer = get_tokenizer(tokenizer_path)
+        eval_wrapper = EagerEvalWrapper(
+            model=prepared_module.to(device="cuda"),
+            tokenizer=tokenizer,
+            max_seq_length=calibration_seq_length,
+            use_kv_cache=self.use_kv_cache,
+        )
+        eval_results = evaluate_model(
+            eval_wrapper,
+            tasks=calibration_tasks,
+            limit=calibration_limit,
+        )
+        for task, res in eval_results["results"].items():
+            print(f"{task}: {res}")
+
     def pt2e_quantize(self, quantizers: Optional[List[Quantizer]]) -> "LLMEdgeManager":
         """
         Quantize the model via pt2e flow and retrieve LLMEdgeManager including the quantized model.
@@ -189,7 +225,24 @@ def pt2e_quantize(self, quantizers: Optional[List[Quantizer]]) -> "LLMEdgeManager":
             ), "Please run capture_pre_autograd_graph first"
             m = prepare_pt2e(self.pre_autograd_graph_module, composed_quantizer)
             # Calibrate
-            m(*self.example_inputs)
+            if (
+                self.calibration_tasks is not None
+                and self.calibration_limit is not None
+                and self.calibration_seq_length is not None
+                and self.tokenizer_path is not None
+            ):
+                self.calibrate(
+                    prepared_module=m,
+                    calibration_tasks=self.calibration_tasks,
+                    calibration_limit=self.calibration_limit,
+                    calibration_seq_length=self.calibration_seq_length,
+                    tokenizer_path=self.tokenizer_path,
+                )
+            else:
+                logging.info(
+                    "No calibration provided, using dummy input to calibrate..."
+                )
+                m(*self.example_inputs)
             m = convert_pt2e(m)
             DuplicateDynamicQuantChainPass()(m)
             self.pre_autograd_graph_module = m
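
For reviewers, a minimal usage sketch of the calibration path this diff enables. The `manager` instance, the XNNPACK quantizer setup, and the tokenizer path are illustrative assumptions, not part of the diff; only the four calibration attributes and the `pt2e_quantize` fallback behavior come from the change above.

```python
# Hypothetical sketch: exercising the new calibration path. Assumes an
# already-constructed LLMEdgeManager instance `manager`; the quantizer
# setup below uses the stock PT2E XNNPACKQuantizer, which is not part
# of this diff.
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)

quantizer = XNNPACKQuantizer()
quantizer.set_global(get_symmetric_quantization_config(is_per_channel=True))

# The diff stores calibration settings as plain attributes, so a caller
# sets them directly before quantizing. All four must be non-None for
# calibrate() to run; otherwise pt2e_quantize falls back to running the
# dummy example inputs, preserving the previous behavior.
manager.calibration_tasks = ["wikitext"]
manager.calibration_limit = 5
manager.calibration_seq_length = 128
manager.tokenizer_path = "/path/to/tokenizer.model"  # placeholder path

manager.capture_pre_autograd_graph().pt2e_quantize([quantizer])
```

One design note: because the attributes default to None in `__init__`, existing callers that never touch them keep the old dummy-input calibration unchanged.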