From 5a9a4b5b240c7df047c37f4b54c816049c1ad869 Mon Sep 17 00:00:00 2001
From: Bin Bao
Date: Mon, 5 Aug 2024 19:48:13 -0700
Subject: [PATCH] [AOTI] Set sdpa_kernel context when exporting

Summary: This improves average tokens/sec from 33.43 to 72.63 on A100 for AOTI.

```
python3 torchchat.py export llama3 --quantize '{"precision": {"dtype":"bfloat16"}, "executor":{"accelerator":"cuda"}}' --output-dso-path /tmp/model16.so && python3 torchchat.py generate llama3 --dso-path /tmp/model16.so --prompt "Once upon a time," --max-new-tokens 256 --device cuda --num-samples 3
```
---
 export.py | 13 +++++++------
 1 file changed, 7 insertions(+), 6 deletions(-)

diff --git a/export.py b/export.py
index 1c9328da5..b0bd5be68 100644
--- a/export.py
+++ b/export.py
@@ -68,12 +68,13 @@ def export_for_server(
     )
     dynamic_shapes = None

-    so = torch._export.aot_compile(
-        model,
-        args=input,
-        options={"aot_inductor.output_path": output_path},
-        dynamic_shapes=dynamic_shapes,
-    )
+    with torch.nn.attention.sdpa_kernel([torch.nn.attention.SDPBackend.MATH]):
+        so = torch._export.aot_compile(
+            model,
+            args=input,
+            options={"aot_inductor.output_path": output_path},
+            dynamic_shapes=dynamic_shapes,
+        )
     print(f"The generated DSO model can be found at: {so}")

     return so