diff --git a/build/builder.py b/build/builder.py index ee1865f07..5f0876aab 100644 --- a/build/builder.py +++ b/build/builder.py @@ -45,6 +45,7 @@ class BuilderArgs: is_chat_model: bool = False prefill_possible: bool = False dynamic_shapes: bool = False + max_seq_length: Optional[int] = None def __post_init__(self): if self.device is None: @@ -159,6 +160,7 @@ def from_args(cls, args): # -> BuilderArgs: use_distributed=args.distributed, is_chat_model=is_chat_model, dynamic_shapes=getattr(args, "dynamic_shapes", False), + max_seq_length=getattr(args, "max_seq_length", None), ) @classmethod @@ -437,6 +439,7 @@ def _initialize_model( builder_args, quantize, tokenizer=None, + max_seq_length=None, ): print("Loading model...") @@ -513,7 +516,7 @@ def _initialize_model( if builder_args.setup_caches: with torch.device(builder_args.device): model.setup_caches( - max_batch_size=1, max_seq_length=model.config.max_seq_length + max_batch_size=1, max_seq_length=max_seq_length or model.config.max_seq_length ) model.to(dtype=builder_args.precision) diff --git a/cli.py b/cli.py index 9b9d7e22d..04718e6e7 100644 --- a/cli.py +++ b/cli.py @@ -68,6 +68,7 @@ def add_arguments_for_verb(parser, verb: str) -> None: _add_generation_args(parser, verb) if verb == "export": _add_export_output_path_args(parser) + _add_export_args(parser) if verb == "eval": _add_exported_input_path_args(parser) _add_evaluation_args(parser) @@ -185,11 +186,20 @@ def _add_export_output_path_args(parser) -> None: default=None, help="Output to the specified AOT Inductor .dso model file", ) + + +def _add_export_args(parser) -> None: parser.add_argument( "--dynamic-shapes", action="store_true", help="Call torch.export with dynamic shapes", ) + parser.add_argument( + "--max-seq-length", + type=int, + default=None, + help="Set maximum length sequence when before calling torch.export", + ) # Add CLI Args representing user provided exported model files diff --git a/export.py b/export.py index b0bd5be68..6068abb4d 100644 --- a/export.py +++ b/export.py @@ -113,10 +113,19 @@ def main(args): except: tokenizer = None + if ( + output_dso_path is not None + and builder_args.max_seq_length is None + and not builder_args.dynamic_shapes + ): + print("Setting max_seq_length to 300 for DSO export.") + builder_args.max_seq_length = 300 + model = _initialize_model( builder_args, quantize, tokenizer, + max_seq_length=builder_args.max_seq_length, ) model_to_pte = model model_to_dso = model