
Commit 244ebc0

mcremon-meta authored and facebook-github-bot committed
Fix SDPA decomp problem
Summary: As titled. The new `_safe_softmax` function is meant to avoid NaN issues, mostly in training. For inference we shouldn't need it, so we swap it with the regular softmax, which prevents the decomposition that introduces the unsupported ops (`eq`, `logical_not`, and `any`). See https://www.internalfb.com/code/fbsource/fbcode/caffe2/torch/_decomp/decompositions.py?lines=425.

Note that this needed some changes to `run_and_verify`, since we now make some aten IR changes. I will fix that in another diff, where `run_and_verify` will use a nop quantizer instead; this way the code path will be the same for fp32 and quantized models. But let's make CI green first!

We will also need to better formalize how to apply passes on the initial graph module (aten IR passes, as opposed to edge IR passes). It seems like lifted constants and other things like that can create issues, but unless we see errors, let's wait until the IR changes from PT/ET are in.

Reviewed By: hsharma35

Differential Revision: D61639074
1 parent 87b38cf commit 244ebc0
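
For context, here is a rough, hedged sketch in plain PyTorch (not the exact upstream decomposition referenced above) of what `_safe_softmax` computes and why decomposing it pulls comparison/reduction ops into the graph:

import torch

# Approximate sketch of _safe_softmax semantics: rows that are entirely
# -inf (fully masked) would produce NaNs under a plain softmax, so they
# are detected and forced to zero instead. That detection is what brings
# ops like aten.eq, aten.logical_not, and aten.any into the decomposed
# graph; the exact upstream decomposition may differ from this sketch.
def safe_softmax_sketch(x: torch.Tensor, dim: int) -> torch.Tensor:
    out = torch.softmax(x, dim=dim)
    fully_masked = torch.eq(x, float("-inf")).all(dim=dim, keepdim=True)
    return torch.where(fully_masked, torch.zeros_like(out), out)

x = torch.full((2, 4), float("-inf"))
x[0] = torch.randn(4)
print(torch.softmax(x, dim=-1))        # second row is all NaN
print(safe_softmax_sketch(x, dim=-1))  # second row is all zeros

For inference inputs where no row is fully masked, the extra masking is dead weight, which is why the regular softmax is sufficient.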

File tree

backends/cadence/aot/compiler.py
backends/cadence/aot/passes.py
backends/cadence/aot/utils.py

3 files changed: +50 −5 lines changed


backends/cadence/aot/compiler.py

Lines changed: 13 additions & 5 deletions
@@ -18,12 +18,13 @@
     ReplaceLogicalNotBooleanWhereWithWherePass,
     ReplacePT2DequantWithCadenceDequantPass,
     ReplacePT2QuantWithCadenceQuantPass,
+    ReplaceSafeSoftmaxWithSoftmax,
     ReplaceScalarTensorWithFullPass,
     ReplaceSqueezeAndUnsqueezeWithViewPass,
 )
 from executorch.backends.cadence.aot.quantizer.fusion_pass import QuantFusion
 from executorch.backends.cadence.aot.quantizer.quantizer import CadenceQuantizer
-from executorch.backends.cadence.aot.utils import model_is_quantized
+from executorch.backends.cadence.aot.utils import model_gm_has_SDPA, model_is_quantized
 from executorch.backends.transforms.decompose_sdpa import (
     DecomposeScaledDotProductAttention,
 )

@@ -57,13 +58,20 @@ def convert_pt2(
     """

     # Export with dynamo
-    model_exp = capture_pre_autograd_graph(model, inputs)
+    model_gm = capture_pre_autograd_graph(model, inputs)

-    # Decompose SDPA
-    DecomposeScaledDotProductAttention(False)(model_exp)
+    if model_gm_has_SDPA(model_gm):
+        # Decompose SDPA
+        DecomposeScaledDotProductAttention(False)(model_gm)
+
+        # Swap _safe_softmax with _softmax (see https://github.com/pytorch/pytorch/pull/133882
+        # for details).
+        result = ReplaceSafeSoftmaxWithSoftmax()(model_gm)
+        assert result is not None
+        model_gm = result.graph_module

     # Prepare
-    prepared_model = prepare_pt2e(model_exp, quantizer)
+    prepared_model = prepare_pt2e(model_gm, quantizer)

     # Calibrate
     prepared_model(*inputs)
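
As a sanity check on the swap done above, a small hedged sketch (it assumes a PyTorch build recent enough to ship `aten._safe_softmax`) showing that for ordinary, non-fully-masked inputs the two ops agree, with `False` filling the `half_to_float` argument of `_softmax`:

import torch

# Hedged sketch: for inputs where no row is entirely -inf, _safe_softmax
# and _softmax produce the same values; the trailing False is the
# half_to_float argument that the new pass appends.
x = torch.randn(2, 4)
safe = torch.ops.aten._safe_softmax.default(x, -1)
plain = torch.ops.aten._softmax.default(x, -1, False)
print(torch.allclose(safe, plain))  # expected: True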

backends/cadence/aot/passes.py

Lines changed: 26 additions & 0 deletions
@@ -266,3 +266,29 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
         result = SpecPropPass()(graph_module)
         assert result is not None
         return result
+
+
+class ReplaceSafeSoftmaxWithSoftmax(ExportPass):
+    """
+    Replace _safe_softmax with _softmax
+    """
+
+    def call_operator(
+        self,
+        op,  # pyre-ignore
+        args: tuple[Argument, ...],
+        kwargs: dict[str, Argument],
+        meta: NodeMetadata,
+    ) -> ProxyValue:
+        if op != torch.ops.aten._safe_softmax.default:
+            return super().call_operator(op, args, kwargs, meta)
+
+        # Add False for the half_to_float argument of softmax
+        softmax_args = list(args) + [False]
+
+        return super().call_operator(
+            torch.ops.aten._softmax.default,
+            tuple(softmax_args),
+            kwargs,
+            meta,
+        )
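
A hedged usage sketch of the new pass, mirroring how `convert_pt2` applies it above. `CallsSafeSoftmax` is a made-up module used only to get a graph containing the op; whether `_safe_softmax` survives `torch.export.export` untouched depends on the PyTorch version, so treat this as a sketch rather than a guaranteed repro. It also assumes an ExecuTorch install that provides the Cadence AoT passes.

import torch
from executorch.backends.cadence.aot.passes import ReplaceSafeSoftmaxWithSoftmax

# Hypothetical module that calls the op directly, just to obtain a graph
# containing aten._safe_softmax.default.
class CallsSafeSoftmax(torch.nn.Module):
    def forward(self, x):
        return torch.ops.aten._safe_softmax.default(x, -1)

ep = torch.export.export(CallsSafeSoftmax(), (torch.randn(2, 4),))
result = ReplaceSafeSoftmaxWithSoftmax()(ep.graph_module)
assert result is not None
# After the pass, no _safe_softmax nodes should remain in the graph.
assert not any(
    node.op == "call_function"
    and node.target == torch.ops.aten._safe_softmax.default
    for node in result.graph_module.graph.nodes
)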

backends/cadence/aot/utils.py

Lines changed: 11 additions & 0 deletions
@@ -177,3 +177,14 @@ def print_ops_info(
             tablefmt="outline",
         )
     )
+
+
+def model_gm_has_SDPA(model_gm: torch.fx.GraphModule) -> bool:
+    for node in model_gm.graph.nodes:
+        if node.op == "call_function":
+            if (
+                node.target
+                == torch.ops.aten.scaled_dot_product_attention.default
+            ):
+                return True
+    return False
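
A hedged usage sketch of the new helper. `TinySDPA` and `NoSDPA` are made-up modules for illustration; the sketch assumes an ExecuTorch install that provides the Cadence AoT utils, and that export keeps `aten.scaled_dot_product_attention` as a single node (typical for recent PyTorch versions, but not guaranteed across releases).

import torch
import torch.nn.functional as F
from executorch.backends.cadence.aot.utils import model_gm_has_SDPA

# Hypothetical modules: one that calls SDPA, one that doesn't.
class TinySDPA(torch.nn.Module):
    def forward(self, q, k, v):
        return F.scaled_dot_product_attention(q, k, v)

class NoSDPA(torch.nn.Module):
    def forward(self, x):
        return torch.softmax(x, dim=-1)

q = k = v = torch.randn(1, 2, 4, 8)
gm_sdpa = torch.export.export(TinySDPA(), (q, k, v)).graph_module
gm_plain = torch.export.export(NoSDPA(), (torch.randn(2, 4),)).graph_module
print(model_gm_has_SDPA(gm_sdpa))   # expected: True
print(model_gm_has_SDPA(gm_plain))  # False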
