This repository was archived by the owner on Sep 10, 2025. It is now read-only.
Merged
5 changes: 4 additions & 1 deletion build/builder.py
@@ -45,6 +45,7 @@ class BuilderArgs:
is_chat_model: bool = False
prefill_possible: bool = False
dynamic_shapes: bool = False
max_seq_length: Optional[int] = None

def __post_init__(self):
if self.device is None:
@@ -159,6 +160,7 @@ def from_args(cls, args): # -> BuilderArgs:
use_distributed=args.distributed,
is_chat_model=is_chat_model,
dynamic_shapes=getattr(args, "dynamic_shapes", False),
max_seq_length=getattr(args, "max_seq_length", None),
)

@classmethod
@@ -437,6 +439,7 @@ def _initialize_model(
builder_args,
quantize,
tokenizer=None,
max_seq_length=None,
):
print("Loading model...")

@@ -513,7 +516,7 @@ if builder_args.setup_caches:
if builder_args.setup_caches:
with torch.device(builder_args.device):
model.setup_caches(
- max_batch_size=1, max_seq_length=model.config.max_seq_length
+ max_batch_size=1, max_seq_length=max_seq_length or model.config.max_seq_length
)

model.to(dtype=builder_args.precision)
10 changes: 10 additions & 0 deletions cli.py
@@ -68,6 +68,7 @@ def add_arguments_for_verb(parser, verb: str) -> None:
_add_generation_args(parser, verb)
if verb == "export":
_add_export_output_path_args(parser)
_add_export_args(parser)
if verb == "eval":
_add_exported_input_path_args(parser)
_add_evaluation_args(parser)
@@ -185,11 +186,20 @@ def _add_export_output_path_args(parser) -> None:
default=None,
help="Output to the specified AOT Inductor .dso model file",
)


def _add_export_args(parser) -> None:
parser.add_argument(
"--dynamic-shapes",
action="store_true",
help="Call torch.export with dynamic shapes",
)
parser.add_argument(
"--max-seq-length",
type=int,
default=None,
help="Set maximum length sequence when before calling torch.export",

Contributor:

set the default to 300 and update the help string so that it's clear what the default is (300)

Contributor Author:

That was my initial implementation. Then there was an issue when running eval.

When running eval, we use --dynamic-shapes, which uses a larger max_seq_length, i.e. model.config.max_seq_length. But in theory, we should not stop the user from calling eval with both options, something like --dynamic-shapes --max-seq-length 1000. When that happens, if args.max_seq_length had a non-None default, we would have no way to distinguish whether args.max_seq_length came from the default or from an intentional user override.

)


# Add CLI Args representing user provided exported model files
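To make the None-as-default argument from the review thread above concrete, here is a minimal, self-contained argparse sketch of the sentinel pattern. It is illustrative only: the flag names mirror the ones added in this PR, but the resolution values (300, the "model config" fallback) are simplified stand-ins for what the real code does.

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--dynamic-shapes", action="store_true")
# Keeping the default at None (instead of 300) preserves the distinction
# between "flag not given" and "flag explicitly given".
parser.add_argument("--max-seq-length", type=int, default=None)

args = parser.parse_args(["--dynamic-shapes", "--max-seq-length", "1000"])

if args.max_seq_length is not None:
    # An explicit user value wins, even when --dynamic-shapes is also passed.
    max_seq_length = args.max_seq_length
elif args.dynamic_shapes:
    # No explicit cap: defer to the model's configured maximum later on
    # (model.config.max_seq_length in the real code).
    max_seq_length = None
else:
    # Nothing specified: fall back to the fixed DSO-export default of 300.
    max_seq_length = 300

print(max_seq_length)  # prints 1000 for the arguments above
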
9 changes: 9 additions & 0 deletions export.py
@@ -113,10 +113,19 @@ def main(args):
except:
tokenizer = None

if (
output_dso_path is not None
and builder_args.max_seq_length is None
and not builder_args.dynamic_shapes
):
print("Setting max_seq_length to 300 for DSO export.")
builder_args.max_seq_length = 300

Contributor:

you shouldn't need this if you set the default in the other file

Contributor Author:

I can add more details to the printout if that helps.


model = _initialize_model(
builder_args,
quantize,
tokenizer,
max_seq_length=builder_args.max_seq_length,
)
model_to_pte = model
model_to_dso = model
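For context, a hedged example of how an export invocation might exercise the new flag. The model alias, output path, entry point, and the --output-dso-path spelling are assumptions for illustration and are not shown in this diff; only --max-seq-length and --dynamic-shapes come from the changes above.

# Hypothetical invocation; names and paths are illustrative.
python3 torchchat.py export llama3 --output-dso-path exportedModels/llama3.so --max-seq-length 512

# Omitting both --max-seq-length and --dynamic-shapes for a DSO export
# falls back to the 300-token default added in export.py above.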