From 03d6a2bdb693bcc92fe3d1f7632e8775944717a4 Mon Sep 17 00:00:00 2001 From: Bin Bao Date: Mon, 5 Aug 2024 07:58:24 -0700 Subject: [PATCH 1/4] [AOTI] Change export to use static shapes Summary: The inputs to model forward are with static shapes, so changing the export call to make sure more Inductor optimizations will take effect down the stream. This change by itself improves average tokens/sec from 29.60 to 33.43 on A100. Some following PRs will provide further perf gains. --- export.py | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/export.py b/export.py index 0d1285eab..7cd3a3aa2 100644 --- a/export.py +++ b/export.py @@ -23,8 +23,6 @@ from build.utils import set_backend, set_precision from cli import add_arguments_for_verb, arg_init, check_args -from torch.export import Dim - try: executorch_export_available = True from export_util.export_et import export_model as export_model_et @@ -50,20 +48,15 @@ def export_for_server( The path to the exported model. """ input = ( - torch.tensor([[1, 9038, 2501, 263, 931]], dtype=torch.int, device=device), - torch.tensor([0, 1, 2, 3, 4], dtype=torch.int, device=device), + torch.tensor([[1]], dtype=torch.int, device=device), + torch.tensor([0], dtype=torch.int, device=device), ) - seq = Dim("seq", min=1, max=model.config.max_seq_length) - # Specify that the first dimension of each input is that batch size - dynamic_shapes = {"idx": {1: seq}, "input_pos": {0: seq}} - model.to(device) so = torch._export.aot_compile( model, args=input, options={"aot_inductor.output_path": output_path}, - dynamic_shapes=dynamic_shapes, ) print(f"The generated DSO model can be found at: {so}") return so From 69388a21c240e005af1c1621359a02d169627501 Mon Sep 17 00:00:00 2001 From: Bin Bao Date: Mon, 5 Aug 2024 14:48:52 -0700 Subject: [PATCH 2/4] Add a dynamic-shapes option for export --- .ci/scripts/validate.sh | 2 +- build/builder.py | 1 + export.py | 35 ++++++++++++++++++++++++++++------- 3 files changed, 30 insertions(+), 8 deletions(-) diff --git a/.ci/scripts/validate.sh b/.ci/scripts/validate.sh index 55d37d3b5..650610a0e 100644 --- a/.ci/scripts/validate.sh +++ b/.ci/scripts/validate.sh @@ -285,7 +285,7 @@ function eval_model_sanity_check() { echo "******** INT4 group-wise quantized (AOTI) *******" echo "*************************************************" if [ "$DTYPE" != "float16" ]; then - python3 -W ignore export.py --dtype ${DTYPE} --quant "$QUANT_OPTIONS" --checkpoint-path "$CHECKPOINT_PATH" --output-dso-path ${MODEL_DIR}/${MODEL_NAME}.so --device "$TARGET_DEVICE" || exit 1 + python3 -W ignore export.py --dtype ${DTYPE} --quant "$QUANT_OPTIONS" --checkpoint-path "$CHECKPOINT_PATH" --output-dso-path ${MODEL_DIR}/${MODEL_NAME}.so --dynamic-shapes --device "$TARGET_DEVICE" || exit 1 python3 -W ignore eval.py --dtype ${DTYPE} --checkpoint-path "$CHECKPOINT_PATH" --dso-path ${MODEL_DIR}/${MODEL_NAME}.so --device "$TARGET_DEVICE" --limit 5 > "$MODEL_DIR/output_eval_aoti" || exit 1 cat "$MODEL_DIR/output_eval_aoti" fi; diff --git a/build/builder.py b/build/builder.py index 2cc6e3ecc..469a93249 100644 --- a/build/builder.py +++ b/build/builder.py @@ -44,6 +44,7 @@ class BuilderArgs: use_distributed: bool = False is_chat_model: bool = False prefill_possible: bool = False + dynamic_shapes: bool = False def __post_init__(self): if self.device is None: diff --git a/export.py b/export.py index 7cd3a3aa2..1c9328da5 100644 --- a/export.py +++ b/export.py @@ -23,6 +23,8 @@ from build.utils import set_backend, set_precision from cli import add_arguments_for_verb, arg_init, check_args +from torch.export import Dim + try: executorch_export_available = True from export_util.export_et import export_model as export_model_et @@ -35,7 +37,10 @@ def export_for_server( - model: nn.Module, device: Optional[str] = "cpu", output_path: str = "model.dso" + model: nn.Module, + device: Optional[str] = "cpu", + output_path: str = "model.dso", + dynamic_shapes: bool = False, ) -> str: """ Export the model using AOT Compile to get a .dso for server use cases. @@ -47,16 +52,27 @@ def export_for_server( Returns: The path to the exported model. """ - input = ( - torch.tensor([[1]], dtype=torch.int, device=device), - torch.tensor([0], dtype=torch.int, device=device), - ) + if dynamic_shapes: + input = ( + torch.tensor([[1, 9038, 2501, 263, 931]], dtype=torch.int, device=device), + torch.tensor([0, 1, 2, 3, 4], dtype=torch.int, device=device), + ) + + seq = Dim("seq", min=1, max=model.config.max_seq_length) + # Specify that the first dimension of each input is that batch size + dynamic_shapes = {"idx": {1: seq}, "input_pos": {0: seq}} + else: + input = ( + torch.tensor([[1]], dtype=torch.int, device=device), + torch.tensor([0], dtype=torch.int, device=device), + ) + dynamic_shapes = None - model.to(device) so = torch._export.aot_compile( model, args=input, options={"aot_inductor.output_path": output_path}, + dynamic_shapes=dynamic_shapes, ) print(f"The generated DSO model can be found at: {so}") return so @@ -136,7 +152,12 @@ def main(args): if output_dso_path: output_dso_path = str(os.path.abspath(output_dso_path)) print(f"Exporting model using AOT Inductor to {output_dso_path}") - export_for_server(model_to_dso, builder_args.device, output_dso_path) + export_for_server( + model_to_dso, + builder_args.device, + output_dso_path, + builder_args.dynamic_shapes, + ) if __name__ == "__main__": From 8d80acf258602aee68de13e819945c306ad73772 Mon Sep 17 00:00:00 2001 From: Bin Bao Date: Mon, 5 Aug 2024 17:05:18 -0700 Subject: [PATCH 3/4] Actually add --dynamic-shapes to CLI --- build/builder.py | 1 + cli.py | 5 +++++ 2 files changed, 6 insertions(+) diff --git a/build/builder.py b/build/builder.py index 469a93249..c653dc7f7 100644 --- a/build/builder.py +++ b/build/builder.py @@ -158,6 +158,7 @@ def from_args(cls, args): # -> BuilderArgs: setup_caches=(output_dso_path or output_pte_path), use_distributed=args.distributed, is_chat_model=is_chat_model, + dynamic_shapes=args.dynamic_shapes, ) @classmethod diff --git a/cli.py b/cli.py index d2093b9c2..9b9d7e22d 100644 --- a/cli.py +++ b/cli.py @@ -185,6 +185,11 @@ def _add_export_output_path_args(parser) -> None: default=None, help="Output to the specified AOT Inductor .dso model file", ) + parser.add_argument( + "--dynamic-shapes", + action="store_true", + help="Call torch.export with dynamic shapes", + ) # Add CLI Args representing user provided exported model files From 334654e9b9ee7fa6757eab656f025e0e97acd6a5 Mon Sep 17 00:00:00 2001 From: Bin Bao Date: Mon, 5 Aug 2024 17:15:50 -0700 Subject: [PATCH 4/4] Access args.dynamic_shapes correctly --- build/builder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/build/builder.py b/build/builder.py index c653dc7f7..ee1865f07 100644 --- a/build/builder.py +++ b/build/builder.py @@ -158,7 +158,7 @@ def from_args(cls, args): # -> BuilderArgs: setup_caches=(output_dso_path or output_pte_path), use_distributed=args.distributed, is_chat_model=is_chat_model, - dynamic_shapes=args.dynamic_shapes, + dynamic_shapes=getattr(args, "dynamic_shapes", False), ) @classmethod