diff --git a/export.py b/export.py index 1c9328da5..b0bd5be68 100644 --- a/export.py +++ b/export.py @@ -68,12 +68,13 @@ def export_for_server( ) dynamic_shapes = None - so = torch._export.aot_compile( - model, - args=input, - options={"aot_inductor.output_path": output_path}, - dynamic_shapes=dynamic_shapes, - ) + with torch.nn.attention.sdpa_kernel([torch.nn.attention.SDPBackend.MATH]): + so = torch._export.aot_compile( + model, + args=input, + options={"aot_inductor.output_path": output_path}, + dynamic_shapes=dynamic_shapes, + ) print(f"The generated DSO model can be found at: {so}") return so