Skip to content
This repository was archived by the owner on Sep 10, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .ci/scripts/validate.sh
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,7 @@ function eval_model_sanity_check() {
echo "******** INT4 group-wise quantized (AOTI) *******"
echo "*************************************************"
if [ "$DTYPE" != "float16" ]; then
python3 -W ignore export.py --dtype ${DTYPE} --quant "$QUANT_OPTIONS" --checkpoint-path "$CHECKPOINT_PATH" --output-dso-path ${MODEL_DIR}/${MODEL_NAME}.so --device "$TARGET_DEVICE" || exit 1
python3 -W ignore export.py --dtype ${DTYPE} --quant "$QUANT_OPTIONS" --checkpoint-path "$CHECKPOINT_PATH" --output-dso-path ${MODEL_DIR}/${MODEL_NAME}.so --dynamic-shapes --device "$TARGET_DEVICE" || exit 1
python3 -W ignore eval.py --dtype ${DTYPE} --checkpoint-path "$CHECKPOINT_PATH" --dso-path ${MODEL_DIR}/${MODEL_NAME}.so --device "$TARGET_DEVICE" --limit 5 > "$MODEL_DIR/output_eval_aoti" || exit 1
cat "$MODEL_DIR/output_eval_aoti"
fi;
Expand Down
2 changes: 2 additions & 0 deletions build/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ class BuilderArgs:
use_distributed: bool = False
is_chat_model: bool = False
prefill_possible: bool = False
dynamic_shapes: bool = False

def __post_init__(self):
if self.device is None:
Expand Down Expand Up @@ -157,6 +158,7 @@ def from_args(cls, args): # -> BuilderArgs:
setup_caches=(output_dso_path or output_pte_path),
use_distributed=args.distributed,
is_chat_model=is_chat_model,
dynamic_shapes=getattr(args, "dynamic_shapes", False),
)

@classmethod
Expand Down
5 changes: 5 additions & 0 deletions cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,11 @@ def _add_export_output_path_args(parser) -> None:
default=None,
help="Output to the specified AOT Inductor .dso model file",
)
parser.add_argument(
"--dynamic-shapes",
action="store_true",
help="Call torch.export with dynamic shapes",
)


# Add CLI Args representing user provided exported model files
Expand Down
34 changes: 24 additions & 10 deletions export.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,10 @@


def export_for_server(
model: nn.Module, device: Optional[str] = "cpu", output_path: str = "model.dso"
model: nn.Module,
device: Optional[str] = "cpu",
output_path: str = "model.dso",
dynamic_shapes: bool = False,
) -> str:
"""
Export the model using AOT Compile to get a .dso for server use cases.
Expand All @@ -49,16 +52,22 @@ def export_for_server(
Returns:
The path to the exported model.
"""
input = (
torch.tensor([[1, 9038, 2501, 263, 931]], dtype=torch.int, device=device),
torch.tensor([0, 1, 2, 3, 4], dtype=torch.int, device=device),
)
if dynamic_shapes:
input = (
torch.tensor([[1, 9038, 2501, 263, 931]], dtype=torch.int, device=device),
torch.tensor([0, 1, 2, 3, 4], dtype=torch.int, device=device),
)

seq = Dim("seq", min=1, max=model.config.max_seq_length)
# Mark the sequence-length dimension as dynamic (dim 1 of idx, dim 0 of input_pos)
dynamic_shapes = {"idx": {1: seq}, "input_pos": {0: seq}}
seq = Dim("seq", min=1, max=model.config.max_seq_length)
# Mark the sequence-length dimension as dynamic (dim 1 of idx, dim 0 of input_pos);
# the batch dimension (dim 0 of idx) stays static at 1
dynamic_shapes = {"idx": {1: seq}, "input_pos": {0: seq}}
else:
input = (
torch.tensor([[1]], dtype=torch.int, device=device),
torch.tensor([0], dtype=torch.int, device=device),
)
dynamic_shapes = None

model.to(device)
so = torch._export.aot_compile(
model,
args=input,
Expand Down Expand Up @@ -143,7 +152,12 @@ def main(args):
if output_dso_path:
output_dso_path = str(os.path.abspath(output_dso_path))
print(f"Exporting model using AOT Inductor to {output_dso_path}")
export_for_server(model_to_dso, builder_args.device, output_dso_path)
export_for_server(
model_to_dso,
builder_args.device,
output_dso_path,
builder_args.dynamic_shapes,
)


if __name__ == "__main__":
Expand Down