From f02b81fc23a902504a0056f121d82a6bf2180fef Mon Sep 17 00:00:00 2001
From: Jack Zhang <32371937+jackzhxng@users.noreply.github.com>
Date: Wed, 10 Sep 2025 18:39:21 -0700
Subject: [PATCH 1/2] Fix get_max_seq_len metadata method not found

---
 extension/llm/export/builder.py | 45 +++++++++++++--------------------
 1 file changed, 18 insertions(+), 27 deletions(-)

diff --git a/extension/llm/export/builder.py b/extension/llm/export/builder.py
index 4fa220b0565..1b15486f11f 100644
--- a/extension/llm/export/builder.py
+++ b/extension/llm/export/builder.py
@@ -93,7 +93,7 @@ def __init__(
         calibration_data: Optional[str] = None,
         tokenizer_path: Optional[str] = None,
         verbose: bool = False,
-        metadata: Optional[dict] = None,
+        metadata: dict = {},
         dynamic_shapes: Optional[Any] = None,
         save_exported_program: bool = False,
         generate_etrecord: bool = False,
@@ -132,13 +132,23 @@ def __init__(
         self.output_dir = "."
         self._saved_pte_filename = None
 
-    def __post_init__(self):
-        """
-        Post init function to update metadata based on dynamic shape
-        """
-        dynamic_shape = self._get_dynamic_shape()
-        if dynamic_shape is not None:
-            token_dim = dynamic_shape[0][1]
+        # Try to resolve dynamic shapes if not specified explicitly.
+        if not self.dynamic_shapes and self.enable_dynamic_shape:
+            if not self.use_kv_cache:
+                # Only one input argument: tokens
+                # Here we -1 due to export limitation: https://gist.github.com/larryliu0820/419022a57e24d5e64150e325a685eaad
+                self.dynamic_shapes = (
+                    {1: torch.export.Dim("token_dim", max=self.max_seq_len - 1)},
+                )
+            else:
+                # Two input arguments: tokens and input_pos but input_pos is static shape
+                self.dynamic_shapes = (
+                    {1: torch.export.Dim("token_dim", max=self.max_seq_len)},
+                    {"input_pos": {0: 1}},
+                )
+
+        if self.dynamic_shapes is not None:
+            token_dim = self.dynamic_shapes[0][1]
             if self.verbose:
                 logging.info(
                     f"Metadata 'get_max_seq_len' is being updated to match torch.export's dynamic shape max: {token_dim.max}"
@@ -189,25 +199,6 @@ def source_transform(
         return self
 
     def _get_dynamic_shape(self) -> Any:
-        if self.dynamic_shapes:
-            return self.dynamic_shapes
-
-        if self.enable_dynamic_shape:
-            if not self.use_kv_cache:
-                # Only one input argument: tokens
-                # Here we -1 due to export limitation: https://gist.github.com/larryliu0820/419022a57e24d5e64150e325a685eaad
-                self.dynamic_shapes = (
-                    {1: torch.export.Dim("token_dim", max=self.max_seq_len - 1)},
-                )
-            else:
-                # Two input arguments: tokens and input_pos but input_pos is static shape
-                self.dynamic_shapes = (
-                    {1: torch.export.Dim("token_dim", max=self.max_seq_len)},
-                    {"input_pos": {0: 1}},
-                )
-        else:
-            # Two input arguments: tokens and input_pos but both are of static shape
-            self.dynamic_shapes = None
         return self.dynamic_shapes
 
     def _get_edge_config(self) -> EdgeCompileConfig:

From dd8fc329ee078b15d0eaed77de7ccd161d1ae23c Mon Sep 17 00:00:00 2001
From: Jack Zhang <32371937+jackzhxng@users.noreply.github.com>
Date: Thu, 11 Sep 2025 07:26:55 -0700
Subject: [PATCH 2/2] PR review

---
 examples/models/llava/export_llava.py |  1 +
 extension/llm/export/builder.py       | 13 +++----------
 2 files changed, 4 insertions(+), 10 deletions(-)

diff --git a/examples/models/llava/export_llava.py b/examples/models/llava/export_llava.py
index 9ba6a510736..f903e0f2ecf 100644
--- a/examples/models/llava/export_llava.py
+++ b/examples/models/llava/export_llava.py
@@ -242,6 +242,7 @@ def export_all(llava_model: LlavaModel):
                 XnnpackPartitioner(),
             ],
         },
+        constant_methods={"get_max_seq_len": llava_model.max_seq_len},
         compile_config=EdgeCompileConfig(_check_ir_validity=False),
     )
 
diff --git a/extension/llm/export/builder.py b/extension/llm/export/builder.py
index 1b15486f11f..01000f3564c 100644
--- a/extension/llm/export/builder.py
+++ b/extension/llm/export/builder.py
@@ -93,7 +93,7 @@ def __init__(
         calibration_data: Optional[str] = None,
         tokenizer_path: Optional[str] = None,
         verbose: bool = False,
-        metadata: dict = {},
+        metadata: Optional[dict] = None,
         dynamic_shapes: Optional[Any] = None,
         save_exported_program: bool = False,
         generate_etrecord: bool = False,
@@ -114,7 +114,8 @@ def __init__(
         self.calibration_data = calibration_data
         self.tokenizer_path = tokenizer_path
         self.verbose = verbose
-        self.metadata = metadata
+        self.metadata = metadata if metadata is not None else {}
+        self.metadata["get_max_seq_len"] = max_seq_len
         self.dynamic_shapes = dynamic_shapes
         self.save_exported_program = save_exported_program
         self.generate_etrecord = generate_etrecord
@@ -147,14 +148,6 @@ def __init__(
                 {"input_pos": {0: 1}},
             )
 
-        if self.dynamic_shapes is not None:
-            token_dim = self.dynamic_shapes[0][1]
-            if self.verbose:
-                logging.info(
-                    f"Metadata 'get_max_seq_len' is being updated to match torch.export's dynamic shape max: {token_dim.max}"
-                )
-            self.metadata["get_max_seq_len"] = token_dim.max
-
    def set_output_dir(self, output_dir: str) -> "LLMEdgeManager":
        """
        Set the directory where the .pte file will be saved.
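Context for the fix: Python only invokes `__post_init__` automatically on `@dataclass` classes, and `LLMEdgeManager` defines a plain `__init__`, which is consistent with the hook never running and `get_max_seq_len` never landing in the metadata — the "method not found" in the subject line. The first commit inlines that logic into `__init__`; the second populates `get_max_seq_len` directly from `max_seq_len`, so the metadata method exists even when no dynamic shapes are resolved.

Below is a minimal, self-contained sketch of the `torch.export` dynamic-shapes pattern the builder relies on. `TinyEmbed`, its vocabulary size, and `MAX_SEQ_LEN` are hypothetical stand-ins for illustration, not part of the patch; the real builder applies the same spec to the exported LLM's `tokens` input.

```python
import torch

MAX_SEQ_LEN = 128  # hypothetical stand-in for LLMEdgeManager's max_seq_len


class TinyEmbed(torch.nn.Module):
    """Toy model with a single `tokens` input, like the builder's no-KV-cache path."""

    def __init__(self) -> None:
        super().__init__()
        self.embed = torch.nn.Embedding(1000, 16)  # arbitrary vocab size / width

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        return self.embed(tokens)


# Same spec shape the patch builds for the no-KV-cache case: a one-element
# tuple (one positional input) marking dim 1 of `tokens` as dynamic, capped
# at max_seq_len - 1.
dynamic_shapes = ({1: torch.export.Dim("token_dim", max=MAX_SEQ_LEN - 1)},)

example_tokens = torch.randint(0, 1000, (1, 8))  # batch of 1, sequence length 8
exported = torch.export.export(
    TinyEmbed(), (example_tokens,), dynamic_shapes=dynamic_shapes
)

# dynamic_shapes[0][1] is the Dim object itself, which is why the removed
# __post_init__ code could read the cap back via token_dim.max.
token_dim = dynamic_shapes[0][1]
print(token_dim.max)  # 127
print(exported.range_constraints)  # sequence dim constrained to the Dim's range
```

The `dynamic_shapes[0][1]` indexing mirrors the deleted metadata-update code: the outer tuple is positional inputs, the inner dict maps tensor dimension 1 to its `Dim`, and `Dim.max` recovers the bound that used to be written into `get_max_seq_len`.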