diff --git a/examples/models/llava/export_llava.py b/examples/models/llava/export_llava.py
index 9ba6a510736..f903e0f2ecf 100644
--- a/examples/models/llava/export_llava.py
+++ b/examples/models/llava/export_llava.py
@@ -242,6 +242,7 @@ def export_all(llava_model: LlavaModel):
                 XnnpackPartitioner(),
             ],
         },
+        constant_methods={"get_max_seq_len": llava_model.max_seq_len},
         compile_config=EdgeCompileConfig(_check_ir_validity=False),
     )

diff --git a/extension/llm/export/builder.py b/extension/llm/export/builder.py
index 4fa220b0565..01000f3564c 100644
--- a/extension/llm/export/builder.py
+++ b/extension/llm/export/builder.py
@@ -114,7 +114,8 @@ def __init__(
         self.calibration_data = calibration_data
         self.tokenizer_path = tokenizer_path
         self.verbose = verbose
-        self.metadata = metadata
+        self.metadata = metadata if metadata is not None else {}
+        self.metadata["get_max_seq_len"] = max_seq_len
         self.dynamic_shapes = dynamic_shapes
         self.save_exported_program = save_exported_program
         self.generate_etrecord = generate_etrecord
@@ -132,18 +133,20 @@ def __init__(
         self.output_dir = "."
         self._saved_pte_filename = None

-    def __post_init__(self):
-        """
-        Post init function to update metadata based on dynamic shape
-        """
-        dynamic_shape = self._get_dynamic_shape()
-        if dynamic_shape is not None:
-            token_dim = dynamic_shape[0][1]
-            if self.verbose:
-                logging.info(
-                    f"Metadata 'get_max_seq_len' is being updated to match torch.export's dynamic shape max: {token_dim.max}"
+        # Try to resolve dynamic shapes if not specified explicitly.
+        if not self.dynamic_shapes and self.enable_dynamic_shape:
+            if not self.use_kv_cache:
+                # Only one input argument: tokens
+                # Here we -1 due to export limitation: https://gist.github.com/larryliu0820/419022a57e24d5e64150e325a685eaad
+                self.dynamic_shapes = (
+                    {1: torch.export.Dim("token_dim", max=self.max_seq_len - 1)},
+                )
+            else:
+                # Two input arguments: tokens and input_pos but input_pos is static shape
+                self.dynamic_shapes = (
+                    {1: torch.export.Dim("token_dim", max=self.max_seq_len)},
+                    {"input_pos": {0: 1}},
                 )
-            self.metadata["get_max_seq_len"] = token_dim.max

     def set_output_dir(self, output_dir: str) -> "LLMEdgeManager":
         """
@@ -189,25 +192,6 @@ def source_transform(
         return self

     def _get_dynamic_shape(self) -> Any:
-        if self.dynamic_shapes:
-            return self.dynamic_shapes
-
-        if self.enable_dynamic_shape:
-            if not self.use_kv_cache:
-                # Only one input argument: tokens
-                # Here we -1 due to export limitation: https://gist.github.com/larryliu0820/419022a57e24d5e64150e325a685eaad
-                self.dynamic_shapes = (
-                    {1: torch.export.Dim("token_dim", max=self.max_seq_len - 1)},
-                )
-            else:
-                # Two input arguments: tokens and input_pos but input_pos is static shape
-                self.dynamic_shapes = (
-                    {1: torch.export.Dim("token_dim", max=self.max_seq_len)},
-                    {"input_pos": {0: 1}},
-                )
-        else:
-            # Two input arguments: tokens and input_pos but both are of static shape
-            self.dynamic_shapes = None
        return self.dynamic_shapes

     def _get_edge_config(self) -> EdgeCompileConfig:
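
For context, the builder change above constructs `dynamic_shapes` tuples in `__init__` using `torch.export.Dim` with a `max` bound tied to `max_seq_len`, instead of deriving the metadata from the shapes afterward. Below is a minimal, self-contained sketch (not part of the patch) of how such a spec is consumed by `torch.export`; the toy model, its names, and the `max_seq_len` value are assumptions for illustration only, but the `Dim`/`export` calls mirror the ones in the diff.

```python
import torch


class ToyTokenModel(torch.nn.Module):
    """Hypothetical stand-in for the exported text model (illustration only)."""

    def __init__(self, vocab_size: int = 32, dim: int = 8):
        super().__init__()
        self.emb = torch.nn.Embedding(vocab_size, dim)

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        # tokens: [batch, seq_len]; seq_len (dim 1) is the axis marked dynamic.
        return self.emb(tokens).sum(dim=1)


max_seq_len = 16  # assumed value; the builder takes this from its constructor
model = ToyTokenModel()
example_tokens = torch.zeros(1, 4, dtype=torch.long)

# Same shape spec the builder now builds for the no-KV-cache case:
# dim 1 (sequence length) is dynamic with an upper bound of max_seq_len - 1.
dynamic_shapes = ({1: torch.export.Dim("token_dim", max=max_seq_len - 1)},)

ep = torch.export.export(model, (example_tokens,), dynamic_shapes=dynamic_shapes)
print(ep)
```

The point of resolving this in `__init__` is that both the export shapes and the `get_max_seq_len` metadata come from the same `max_seq_len` argument, rather than the metadata being back-filled from whatever `torch.export` recorded for the dynamic dimension.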