extension/llm/export/builder.py: 54 additions & 1 deletion
@@ -27,6 +27,7 @@
from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass

from executorch.extension.export_util.utils import export_to_edge, save_pte_program
from executorch.extension.llm.utils import get_tokenizer
from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer import Quantizer
@@ -87,6 +88,10 @@ def __init__(
        self.output_dir = "."
        self.dynamic_shapes = dynamic_shapes
        self._saved_pte_filename = None
        self.calibration_tasks = None
        self.calibration_limit = None
        self.calibration_seq_length = None
        self.tokenizer_path = None

    def set_output_dir(self, output_dir: str) -> "LLMEdgeManager":
        """
Expand Down Expand Up @@ -166,6 +171,37 @@ def capture_pre_autograd_graph(self) -> "LLMEdgeManager":
        )
        return self

    def calibrate(
        self,
        prepared_module,
        calibration_tasks,
        calibration_limit,
        calibration_seq_length,
        tokenizer_path,
    ):
        logging.info("run calibration...")
        try:
            from executorch.examples.models.llama2.evaluate import EagerEvalWrapper
        except ImportError:
            raise ImportError(
                "Please install the llm eval dependency via examples/models/llama2/install_requirements.sh"
            )

        tokenizer = get_tokenizer(tokenizer_path)
        eval_wrapper = EagerEvalWrapper(
[Review thread on this line]

shewu-quic (Collaborator), Aug 1, 2024:
Hi @cccclai,
Thanks for this change. It seems there is a problem when I capture with seq_len = 1 in kv_cache mode but then calibrate with the whole sequence in a batch process. From my understanding, after we capture the graph, some variables become fixed, such as `batch, seqlen, _ = x.shape`. If I am mistaken, please correct me.

cccclai (Contributor, PR author):
Hi, thank you for checking out the PR! I pulled some changes from another local patch and haven't tested this PR properly. What you mentioned is true... I vaguely remember modifying the code there to use the KV cache instead; however, it was too slow to calibrate (as you observed too).

            model=prepared_module.to(device="cuda"),
            tokenizer=tokenizer,
            max_seq_length=calibration_seq_length,
            use_kv_cache=self.use_kv_cache,
        )
        eval_results = eval(
            eval_wrapper,
            tasks=calibration_tasks,
            limit=calibration_limit,
        )
        for task, res in eval_results["results"].items():
            print(f"{task}: {res}")
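
Regarding the review thread above: the mismatch between a graph captured with seq_len = 1 in kv_cache mode and whole-sequence batch calibration can, in principle, be avoided by feeding the prepared module one token at a time, at the cost of the slow calibration the PR author mentions. Below is a minimal, hypothetical sketch of such a loop; the (tokens, input_pos) calling convention, the encode(text, bos, eos) tokenizer method, and all names are assumptions for illustration, not part of this PR.

# Hypothetical sketch, not part of this diff: per-token calibration for a module
# captured with use_kv_cache=True and seq_len=1.
import torch


def calibrate_with_kv_cache(prepared_module, tokenizer, prompts, max_seq_length):
    for prompt in prompts:
        # Assumed llama-style tokenizer API.
        token_ids = tokenizer.encode(prompt, bos=True, eos=False)[:max_seq_length]
        for pos, token in enumerate(token_ids):
            # One token per call matches the captured (batch=1, seqlen=1) shape;
            # the module's internal KV cache carries the earlier context forward.
            prepared_module(
                torch.tensor([[token]], dtype=torch.long),
                torch.tensor([pos], dtype=torch.long),
            )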

    def pt2e_quantize(self, quantizers: Optional[List[Quantizer]]) -> "LLMEdgeManager":
        """
        Quantize the model via pt2e flow and retrieve LLMEdgeManager including the quantized model.
@@ -189,7 +225,24 @@ def pt2e_quantize(self, quantizers: Optional[List[Quantizer]]) -> "LLMEdgeManager":
), "Please run capture_pre_autograd_graph first"
m = prepare_pt2e(self.pre_autograd_graph_module, composed_quantizer)
# Calibrate
m(*self.example_inputs)
if (
self.calibration_tasks is not None
and self.calibration_limit is not None
and self.calibration_seq_length is not None
and self.tokenizer_path is not None
):
self.calibrate(
prepared_module=m,
calibration_tasks=self.calibration_tasks,
calibration_limit=self.calibration_limit,
calibration_seq_length=self.calibration_seq_length,
tokenizer_path=self.tokenizer_path,
)
else:
logging.info(
"No calibration provided, using dummy input to calibrate..."
)
m(*self.example_inputs)
m = convert_pt2e(m)
DuplicateDynamicQuantChainPass()(m)
self.pre_autograd_graph_module = m
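
For context, a minimal usage sketch of the calibration knobs this diff adds; `manager` is assumed to be an already-constructed LLMEdgeManager (its constructor arguments are elided here) and `quantizers` a list of pt2e Quantizer objects, so the surrounding setup is illustrative rather than prescriptive.

# Hypothetical usage sketch: enable dataset-based calibration before pt2e quantization.
def quantize_with_calibration(manager, quantizers):
    # The four fields below are the attributes this diff adds to __init__;
    # the concrete values are placeholders.
    manager.calibration_tasks = ["wikitext"]
    manager.calibration_limit = 5
    manager.calibration_seq_length = 128
    manager.tokenizer_path = "/path/to/tokenizer.model"
    manager.capture_pre_autograd_graph()
    # With all four fields set, pt2e_quantize() calls calibrate(...) instead of
    # running the dummy example inputs.
    return manager.pt2e_quantize(quantizers)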