1 change: 1 addition & 0 deletions examples/models/llava/export_llava.py
@@ -242,6 +242,7 @@ def export_all(llava_model: LlavaModel):
XnnpackPartitioner(),
],
},
constant_methods={"get_max_seq_len": llava_model.max_seq_len},
compile_config=EdgeCompileConfig(_check_ir_validity=False),
)
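For context: `constant_methods` makes the lowered program expose extra named methods that simply return fixed values, so a runner can read the model's maximum sequence length straight from the .pte instead of relying on a side-channel config. Below is a minimal, hedged sketch of that mechanism using plain `to_edge` on a toy module (the toy model, the value 768, and the output filename are illustrative assumptions, not this repo's llava export flow):

```python
import torch
from executorch.exir import EdgeCompileConfig, to_edge


class Toy(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + 1


exported = torch.export.export(Toy(), (torch.ones(2),))

edge = to_edge(
    exported,
    # Each entry becomes an extra method in the final program that returns a
    # constant value, alongside the usual "forward".
    constant_methods={"get_max_seq_len": 768},
    compile_config=EdgeCompileConfig(_check_ir_validity=False),
)
et_program = edge.to_executorch()

with open("toy.pte", "wb") as f:
    f.write(et_program.buffer)
```

A runner can then invoke the `get_max_seq_len` method of the serialized program to recover the value at load time, rather than hard-coding it.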

46 changes: 15 additions & 31 deletions extension/llm/export/builder.py
@@ -114,7 +114,8 @@ def __init__(
self.calibration_data = calibration_data
self.tokenizer_path = tokenizer_path
self.verbose = verbose
self.metadata = metadata
self.metadata = metadata if metadata is not None else {}
self.metadata["get_max_seq_len"] = max_seq_len
self.dynamic_shapes = dynamic_shapes
self.save_exported_program = save_exported_program
self.generate_etrecord = generate_etrecord
@@ -132,18 +133,20 @@ def __init__(
self.output_dir = "."
self._saved_pte_filename = None

    def __post_init__(self):
        """
        Post init function to update metadata based on dynamic shape
        """
        dynamic_shape = self._get_dynamic_shape()
        if dynamic_shape is not None:
            token_dim = dynamic_shape[0][1]
            if self.verbose:
                logging.info(
                    f"Metadata 'get_max_seq_len' is being updated to match torch.export's dynamic shape max: {token_dim.max}"
                )
            self.metadata["get_max_seq_len"] = token_dim.max

        # Try to resolve dynamic shapes if not specified explicitly.
        if not self.dynamic_shapes and self.enable_dynamic_shape:
            if not self.use_kv_cache:
                # Only one input argument: tokens
                # Here we -1 due to export limitation: https://gist.github.com/larryliu0820/419022a57e24d5e64150e325a685eaad
                self.dynamic_shapes = (
                    {1: torch.export.Dim("token_dim", max=self.max_seq_len - 1)},
                )
            else:
                # Two input arguments: tokens and input_pos but input_pos is static shape
                self.dynamic_shapes = (
                    {1: torch.export.Dim("token_dim", max=self.max_seq_len)},
                    {"input_pos": {0: 1}},
                )

def set_output_dir(self, output_dir: str) -> "LLMEdgeManager":
"""
@@ -189,25 +192,6 @@ def source_transform(
return self

    def _get_dynamic_shape(self) -> Any:
        if self.dynamic_shapes:
            return self.dynamic_shapes

        if self.enable_dynamic_shape:
            if not self.use_kv_cache:
                # Only one input argument: tokens
                # Here we -1 due to export limitation: https://gist.github.com/larryliu0820/419022a57e24d5e64150e325a685eaad
                self.dynamic_shapes = (
                    {1: torch.export.Dim("token_dim", max=self.max_seq_len - 1)},
                )
            else:
                # Two input arguments: tokens and input_pos but input_pos is static shape
                self.dynamic_shapes = (
                    {1: torch.export.Dim("token_dim", max=self.max_seq_len)},
                    {"input_pos": {0: 1}},
                )
        else:
            # Two input arguments: tokens and input_pos but both are of static shape
            self.dynamic_shapes = None
        return self.dynamic_shapes
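The tuples built here follow `torch.export`'s dynamic-shapes convention: one entry per input, each a dict mapping a dimension index to a `torch.export.Dim` with an upper bound. A minimal, self-contained sketch of exporting with such a spec (the toy embedding model and sizes are illustrative assumptions, not this repo's model):

```python
import torch


class ToyTokenModel(torch.nn.Module):
    def __init__(self, vocab_size: int = 128, hidden: int = 16):
        super().__init__()
        self.embed = torch.nn.Embedding(vocab_size, hidden)

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        # tokens: [batch, seq_len] -> pooled embedding: [batch, hidden]
        return self.embed(tokens).sum(dim=1)


max_seq_len = 32
example_tokens = torch.zeros((1, 8), dtype=torch.long)

# Dimension 1 (sequence length) is dynamic, bounded like the no-kv-cache
# branch above (max_seq_len - 1 to stay within the export limitation).
dynamic_shapes = ({1: torch.export.Dim("token_dim", max=max_seq_len - 1)},)

exported = torch.export.export(
    ToyTokenModel(), (example_tokens,), dynamic_shapes=dynamic_shapes
)
print(exported)
```

In the kv-cache branch above, a second spec is supplied for `input_pos`, whose shape stays static.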

def _get_edge_config(self) -> EdgeCompileConfig: