diff --git a/.ci/scripts/validate.sh b/.ci/scripts/validate.sh
index 55d37d3b5..650610a0e 100644
--- a/.ci/scripts/validate.sh
+++ b/.ci/scripts/validate.sh
@@ -285,7 +285,7 @@ function eval_model_sanity_check() {
         echo "******** INT4 group-wise quantized (AOTI) *******"
         echo "*************************************************"
         if [ "$DTYPE" != "float16" ]; then
-            python3 -W ignore export.py --dtype ${DTYPE} --quant "$QUANT_OPTIONS" --checkpoint-path "$CHECKPOINT_PATH" --output-dso-path ${MODEL_DIR}/${MODEL_NAME}.so --device "$TARGET_DEVICE" || exit 1
+            python3 -W ignore export.py --dtype ${DTYPE} --quant "$QUANT_OPTIONS" --checkpoint-path "$CHECKPOINT_PATH" --output-dso-path ${MODEL_DIR}/${MODEL_NAME}.so --dynamic-shapes --device "$TARGET_DEVICE" || exit 1
             python3 -W ignore eval.py --dtype ${DTYPE} --checkpoint-path "$CHECKPOINT_PATH" --dso-path ${MODEL_DIR}/${MODEL_NAME}.so --device "$TARGET_DEVICE" --limit 5 > "$MODEL_DIR/output_eval_aoti" || exit 1
             cat "$MODEL_DIR/output_eval_aoti"
         fi;
diff --git a/build/builder.py b/build/builder.py
index 2cc6e3ecc..ee1865f07 100644
--- a/build/builder.py
+++ b/build/builder.py
@@ -44,6 +44,7 @@ class BuilderArgs:
     use_distributed: bool = False
     is_chat_model: bool = False
     prefill_possible: bool = False
+    dynamic_shapes: bool = False
 
     def __post_init__(self):
         if self.device is None:
@@ -157,6 +158,7 @@ def from_args(cls, args): # -> BuilderArgs:
             setup_caches=(output_dso_path or output_pte_path),
             use_distributed=args.distributed,
             is_chat_model=is_chat_model,
+            dynamic_shapes=getattr(args, "dynamic_shapes", False),
         )
 
     @classmethod
diff --git a/cli.py b/cli.py
index d2093b9c2..9b9d7e22d 100644
--- a/cli.py
+++ b/cli.py
@@ -185,6 +185,11 @@ def _add_export_output_path_args(parser) -> None:
         default=None,
         help="Output to the specified AOT Inductor .dso model file",
     )
+    parser.add_argument(
+        "--dynamic-shapes",
+        action="store_true",
+        help="Call torch.export with dynamic shapes",
+    )
 
 
 # Add CLI Args representing user provided exported model files
diff --git a/export.py b/export.py
index 0d1285eab..1c9328da5 100644
--- a/export.py
+++ b/export.py
@@ -37,7 +37,10 @@
 
 
 def export_for_server(
-    model: nn.Module, device: Optional[str] = "cpu", output_path: str = "model.dso"
+    model: nn.Module,
+    device: Optional[str] = "cpu",
+    output_path: str = "model.dso",
+    dynamic_shapes: bool = False,
 ) -> str:
     """
     Export the model using AOT Compile to get a .dso for server use cases.
@@ -49,16 +52,22 @@ def export_for_server(
     Returns:
         The path to the exported model.
     """
-    input = (
-        torch.tensor([[1, 9038, 2501, 263, 931]], dtype=torch.int, device=device),
-        torch.tensor([0, 1, 2, 3, 4], dtype=torch.int, device=device),
-    )
+    if dynamic_shapes:
+        input = (
+            torch.tensor([[1, 9038, 2501, 263, 931]], dtype=torch.int, device=device),
+            torch.tensor([0, 1, 2, 3, 4], dtype=torch.int, device=device),
+        )
 
-    seq = Dim("seq", min=1, max=model.config.max_seq_length)
-    # Specify that the first dimension of each input is that batch size
-    dynamic_shapes = {"idx": {1: seq}, "input_pos": {0: seq}}
+        seq = Dim("seq", min=1, max=model.config.max_seq_length)
+        # Mark the sequence dimension (dim 1 of idx, dim 0 of input_pos) as dynamic
+        dynamic_shapes = {"idx": {1: seq}, "input_pos": {0: seq}}
+    else:
+        input = (
+            torch.tensor([[1]], dtype=torch.int, device=device),
+            torch.tensor([0], dtype=torch.int, device=device),
+        )
+        dynamic_shapes = None
 
-    model.to(device)
     so = torch._export.aot_compile(
         model,
         args=input,
@@ -143,7 +152,12 @@ def main(args):
     if output_dso_path:
         output_dso_path = str(os.path.abspath(output_dso_path))
         print(f"Exporting model using AOT Inductor to {output_dso_path}")
-        export_for_server(model_to_dso, builder_args.device, output_dso_path)
+        export_for_server(
+            model_to_dso,
+            builder_args.device,
+            output_dso_path,
+            builder_args.dynamic_shapes,
+        )
 
 
 if __name__ == "__main__":