Arm backend bug: FoldAndAnnotateQParamsPass folds qdomain changes through aten.cat #18999

🐛 Describe the bug

Summary

The Arm backend currently lets FoldAndAnnotateQParamsPass fold an explicit
qdomain change through pass-through ops such as aten.cat.default.

That produces a graph where the folded metadata claims:

  • cat input qparams are int8 affine, zp=102
  • cat output qparams are int16 symmetric, zp=0
  • but there is no surviving requant node after cat

cat does not change integer codes, so that folded state is internally
inconsistent. Later lowering passes then reason about the tensor as if it were
already in the consumer qdomain, even though the live codes are still in the
producer qdomain. On the repro's sample input, interpreting the folded graph
literally turns the expected add output of approximately
[[[[0.3, 0.5], [0.7, 0.9]], [[-0.1, 0.1], [0.3, 0.5]]]] into an effective
consumer-domain result of approximately
[[[[0.1032, 0.1032], [0.1033, 0.1034]], [[0.1031, 0.1031], [0.1032, 0.1032]]]].
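
As a back-of-envelope check (plain Python using only the fixed qparams from
the repro), misreading a producer-domain code under the consumer's qparams
collapses the value:

# Producer (int8 affine) and consumer (int16 symmetric) qparams from the repro.
S8, Z8 = 0.1, 102
S16, Z16 = 1.0 / 32767.0, 0

code = 104  # live int8 code that encodes the float value 0.2
print((code - Z8) * S8)    # 0.2      -- correct producer-domain reading
print((code - Z16) * S16)  # ~0.00317 -- what a consumer that trusts the folded
                           #             int16 metadata effectively sees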

The underlying bug is small and can be reproduced with a standalone
ExecuTorch/PyTorch script.

Impact

This creates an invalid intermediate representation before lowering:

  • producer codes are still in the pre-cat domain
  • cat.output_qparams describes a different post-cat domain
  • downstream passes have no explicit requant boundary left to materialize

This can cause downstream lowering to treat raw producer-domain codes as if
they were already in the consumer domain, leading to incorrect results. In the
sample below, the maximum absolute error at the add consumer is about
0.796655.

This is not just a metadata inconsistency: the same pattern was lowered to TOSA
and executed with the TOSA reference model, and the delegated model produced a
genuinely wrong output tensor.

Standalone repro

Properties of the repro:

  • no project-specific helper imports
  • only uses torch, torchao, and executorch APIs
  • works from a plain ExecuTorch checkout; if executorch is not installed in
    editable mode, run it with PYTHONPATH=<repo>/src
  • builds a tiny graph:
input -> slice -> mul ----\
                           cat -> add(const)
input -> slice -> mul ----/
  • forces:
    • both mul outputs to fixed int8 affine qparams with zp=102
    • cat output to fixed int16 symmetric qparams
    • both add inputs/output to fixed int16 symmetric qparams

Copy-paste and run the following script:

from __future__ import annotations

import argparse
import copy
import json

import torch
import torch.nn as nn
from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
    DQ_OPS,
    Q_OPS,
    FoldAndAnnotateQParamsPass,
    QuantArgs,
)
from executorch.backends.arm.quantizer import TOSAQuantizer
from executorch.backends.arm.quantizer.arm_quantizer import QuantizationSpec
from executorch.backends.arm.quantizer.arm_quantizer_utils import (
    mark_node_as_annotated,
)
from executorch.backends.arm.quantizer.quantization_config import QuantizationConfig
from executorch.backends.arm.tosa import TosaSpecification
from executorch.exir import EdgeCompileConfig, to_edge
from torchao.quantization.pt2e.observer import FixedQParamsObserver
from torchao.quantization.pt2e.quantizer import QuantizationAnnotation


def fixed_qspec(*, scale, zero_point, dtype, qscheme, quant_min, quant_max):
    """Build a QuantizationSpec pinned to fixed qparams via FixedQParamsObserver."""
    # FixedQParamsObserver wants torch.qint8 for the 8-bit case.
    observer_dtype = torch.qint8 if dtype is torch.int8 else torch.int16
    return QuantizationSpec(
        dtype=dtype,
        observer_or_fake_quant_ctr=FixedQParamsObserver.with_args(
            scale=scale,
            zero_point=zero_point,
            dtype=observer_dtype,
            qscheme=qscheme,
            quant_min=quant_min,
            quant_max=quant_max,
        ),
        quant_min=quant_min,
        quant_max=quant_max,
        qscheme=qscheme,
        is_dynamic=False,
    )


# Producer qdomain: int8 affine with a deliberately non-zero zero point.
Q8 = fixed_qspec(
    scale=0.1,
    zero_point=102,
    dtype=torch.int8,
    qscheme=torch.per_tensor_affine,
    quant_min=-128,
    quant_max=127,
)
# Consumer qdomain: int16 symmetric, zero point 0.
Q16 = fixed_qspec(
    scale=1.0 / 32767.0,
    zero_point=0,
    dtype=torch.int16,
    qscheme=torch.per_tensor_symmetric,
    quant_min=-32767,
    quant_max=32767,
)
QCONFIG = QuantizationConfig(
    input_activation=Q8,
    output_activation=Q8,
    weight=None,
    bias=None,
)


class Probe(nn.Module):
    """Tiny module that lowers to: input -> slice -> mul (x2) -> cat -> add."""

    def __init__(self) -> None:
        super().__init__()
        self.register_buffer(
            "mul_weight",
            torch.tensor([[[[1.0]], [[2.0]]]], dtype=torch.float32),
        )
        self.register_buffer(
            "add_bias",
            torch.full((1, 2, 2, 2), 0.1, dtype=torch.float32),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        a = x[:, 0:1] * self.mul_weight[:, 0:1]
        b = x[:, 1:2] * self.mul_weight[:, 1:2]
        return torch.cat([a, b], dim=1) + self.add_bias


def annotate(node, *, input_nodes=None, output_qspec=None):
    """Attach or extend a QuantizationAnnotation on an FX node."""
    qa = node.meta.get("quantization_annotation", QuantizationAnnotation())
    if input_nodes is not None:
        qa.input_qspec_map = {inp: input_nodes[inp] for inp in input_nodes}
    if output_qspec is not None:
        qa.output_qspec = output_qspec
    node.meta["quantization_annotation"] = qa
    mark_node_as_annotated(node)


class ProbeQuantizer(TOSAQuantizer):
    def __init__(self) -> None:
        super().__init__(TosaSpecification.create_from_string("TOSA-1.0+INT"))
        self.set_global(QCONFIG)
        self.set_io(QCONFIG)

    def _annotate_for_static_quantization_config(self, model):
        # Run the default TOSA annotation first, then pin mul/cat/add to the
        # fixed Q8/Q16 qdomains that set up the qdomain-changing cat.
        model = super()._annotate_for_static_quantization_config(model)

        muls = []
        cat = None
        add = None
        for node in model.graph.nodes:
            if node.op != "call_function":
                continue
            if node.target == torch.ops.aten.mul.Tensor:
                muls.append(node)
            elif node.target == torch.ops.aten.cat.default:
                cat = node
            elif node.target == torch.ops.aten.add.Tensor:
                add = node

        if len(muls) != 2 or cat is None or add is None:
            raise RuntimeError(
                f"Unexpected graph shape: muls={len(muls)} cat={cat is not None} add={add is not None}"
            )

        for mul in muls:
            annotate(
                mul,
                input_nodes={arg: Q8 for arg in mul.args if isinstance(arg, torch.fx.Node)},
                output_qspec=Q8,
            )

        annotate(
            cat,
            input_nodes={arg: Q8 for arg in cat.args[0] if isinstance(arg, torch.fx.Node)},
            output_qspec=Q16,
        )
        annotate(
            add,
            input_nodes={arg: Q16 for arg in add.args if isinstance(arg, torch.fx.Node)},
            output_qspec=Q16,
        )
        return model


def qargs_from_qdq(node):
    """Extract QuantArgs from a quantize/dequantize node, or return None."""
    if node.target not in (*Q_OPS, *DQ_OPS):
        return None
    args = node.args
    return QuantArgs.from_operator(node.target, (args[0], args[1], args[2], *args[3:]))


def apply_patch():
    """Monkey-patch is_foldable with the pass-through guard proposed below."""
    if getattr(FoldAndAnnotateQParamsPass, "_cat_qdomain_fold_repro_patch_applied", False):
        return

    original = FoldAndAnnotateQParamsPass.is_foldable

    def should_preserve(node):
        # True for a cat whose DQ inputs all share one qdomain while a user Q
        # node re-quantizes into a different one: that requant must survive.
        if node.op != "call_function":
            return False
        if "aten.cat.default" not in str(node.target):
            return False
        if len(node.args) == 0 or not isinstance(node.args[0], list):
            return False

        input_domain = None
        for arg in node.args[0]:
            if not isinstance(arg, torch.fx.Node) or arg.target not in DQ_OPS:
                return False
            q = qargs_from_qdq(arg)
            if q is None:
                return False
            if input_domain is None:
                input_domain = q
            elif input_domain != q:
                return False

        if input_domain is None:
            return False

        for user in node.users:
            if user.target not in Q_OPS:
                continue
            q = qargs_from_qdq(user)
            if q is not None and q != input_domain:
                return True
        return False

    def patched(node):
        return original(node) and not should_preserve(node)

    FoldAndAnnotateQParamsPass.is_foldable = staticmethod(patched)
    FoldAndAnnotateQParamsPass._cat_qdomain_fold_repro_patch_applied = True


def to_json_qargs_map(qargs_map):
    return {
        str(i): {
            "scale": float(q.get_scale_per_tensor()),
            "zero_point": int(q.get_zp_per_tensor()),
            "dtype": str(q.dtype),
            "qmin": int(q.qmin),
            "qmax": int(q.qmax),
        }
        for i, q in qargs_map.items()
    }


def run_case(*, patched: bool):
    """Export, quantize, fold, and report the qparam metadata around cat/add."""
    if patched:
        apply_patch()

    model = Probe().eval()
    example_input = (
        torch.tensor(
            [[[[0.2, 0.4], [0.6, 0.8]], [[-0.1, 0.0], [0.1, 0.2]]]],
            dtype=torch.float32,
        ),
    )
    exported = torch.export.export(model, example_input)
    quantized = ProbeQuantizer().quantize_with_submodules(
        exported.module(check_guards=False),
        calibration_samples=[example_input],
        is_qat=False,
    )
    quantized_exported = torch.export.export(quantized, example_input)
    edge = to_edge(
        quantized_exported,
        compile_config=EdgeCompileConfig(_check_ir_validity=False),
    ).exported_program()

    folded = FoldAndAnnotateQParamsPass(exported_program=edge)(
        copy.deepcopy(edge.graph_module)
    ).graph_module

    cat = None
    add = None
    q_after_cat = []
    for node in folded.graph.nodes:
        if node.op != "call_function":
            continue
        target = str(node.target)
        if "aten.cat.default" in target:
            cat = node
        elif "aten.add.Tensor" in target:
            add = node
        elif node.target in Q_OPS:
            parent = node.args[0] if len(node.args) > 0 else None
            if isinstance(parent, torch.fx.Node) and "aten.cat.default" in str(parent.target):
                q_after_cat.append(node.name)

    if cat is None or add is None:
        raise RuntimeError("Failed to find cat/add after fold.")

    cat_input = to_json_qargs_map(cat.meta.get("input_qparams", {}))
    cat_output = to_json_qargs_map(cat.meta.get("output_qparams", {}))
    add_input = to_json_qargs_map(add.meta.get("input_qparams", {}))

    return {
        "patched": patched,
        "cat_input_qparams": cat_input,
        "cat_output_qparams": cat_output,
        "add_input_qparams": add_input,
        "cat_input_output_domain_match": next(iter(cat_input.values()), None)
        == next(iter(cat_output.values()), None),
        "surviving_q_after_cat": q_after_cat,
    }


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--mode", choices=("upstream", "patched", "both"), default="both")
    args = parser.parse_args()

    results = []
    if args.mode in {"upstream", "both"}:
        results.append(run_case(patched=False))
    if args.mode in {"patched", "both"}:
        results.append(run_case(patched=True))
    print(json.dumps({"results": results}, indent=2))


if __name__ == "__main__":
    main()

check_guards=False is important here: newer torch.export builds materialize
an auxiliary _guards_fn call_module, and the current Arm quantization
prepare pipeline rejects that node before this repro reaches the fold pass.
That is a separate issue from the qdomain-fold bug being reported here.

Run from the ExecuTorch repo root:

PYTHONPATH=$PWD/src python repro.py --mode both

On Windows PowerShell:

$env:PYTHONPATH = "$PWD/src"
python .\repro.py --mode both

Actual result

With the current upstream folding logic, the repro prints (trimmed to the
relevant fields):

{
  "cat_input_qparams": {
    "0": {
      "scale": 0.10000000149011612,
      "zero_point": 102,
      "dtype": "torch.int8"
    }
  },
  "cat_output_qparams": {
    "0": {
      "scale": 3.0518509447574615e-05,
      "zero_point": 0,
      "dtype": "torch.int16"
    }
  },
  "add_input_qparams": {
    "0": {
      "scale": 3.0518509447574615e-05,
      "zero_point": 0,
      "dtype": "torch.int16"
    },
    "1": {
      "scale": 3.0518509447574615e-05,
      "zero_point": 0,
      "dtype": "torch.int16"
    }
  },
  "cat_input_output_domain_match": false,
  "surviving_q_after_cat": []
}

So the explicit qdomain change is gone, but cat still advertises an
int16 zp=0 output domain that it never actually produced.

Numeric consequence

The same standalone repro was also evaluated on real tensors using the sample
input embedded in the script. The key tensors are:

{
  "cat_live_codes_after_cat_int8": [[[[104, 106], [108, 110]], [[100, 102], [104, 106]]]],
  "cat_correct_codes_in_add_domain_int16": [[[[6553, 13107], [19660, 26214]], [[-6553, 0], [6553, 13107]]]],
  "expected_add_output_float": [[[[0.3, 0.5], [0.7, 0.9]], [[-0.1, 0.1], [0.3, 0.5]]]],
  "effective_add_output_float": [[[[0.1032, 0.1032], [0.1033, 0.1034]], [[0.1031, 0.1031], [0.1032, 0.1032]]]],
  "effective_matches_expected": false,
  "effective_max_abs_error": 0.796655,
  "effective_mae": 0.348334
}

Interpretation:

  • after the bad fold, the live cat codes are still the producer-domain
    int8 values 104, 106, ...
  • the consumer side expects post-cat int16 codes such as
    6553, 13107, ...
  • if a downstream pass trusts the folded metadata and materializes the consumer
    view from that state, the add result is numerically wrong

This turns the issue from a metadata inconsistency into a concrete tensor-level
miscompile witness.
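
The figures above can be re-derived with a few lines of plain torch. This is a
sketch of the arithmetic only; it uses the fixed qparams from the repro and no
Arm backend APIs:

import torch

S8, Z8 = 0.1, 102
S16 = 1.0 / 32767.0

x = torch.tensor([[[[0.2, 0.4], [0.6, 0.8]], [[-0.1, 0.0], [0.1, 0.2]]]])
w = torch.tensor([[[[1.0]], [[2.0]]]])
cat_float = x * w  # the slice/mul/cat chain reduces to this elementwise product

live_codes = torch.round(cat_float / S8) + Z8  # producer-domain int8 codes
expected = cat_float + 0.1                     # correct add output
effective = live_codes * S16 + 0.1             # live codes misread as Q16
print((expected - effective).abs().max())      # ~0.796655
print((expected - effective).abs().mean())     # ~0.348334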

Actual lowered execution

The same graph shape was also lowered to TOSA-1.0+INT+int16 and executed with
tosa_reference_model.

A control case, in which cat stays in the same qdomain as both its producer
and its consumer, executes correctly:

{
  "actual_output_float": [[[[0.3, 0.5], [0.7, 0.9]], [[-0.1, 0.1], [0.3, 0.5]]]],
  "expected_output_float": [[[[0.3, 0.5], [0.7, 0.9]], [[-0.1, 0.1], [0.3, 0.5]]]],
  "max_abs_error": 0.0,
  "mae": 0.0
}

The qdomain-changing cat -> add case from this report lowers and runs, but
produces the wrong result:

{
  "actual_output_float": [[[[0.1, 0.1], [0.1, 0.1]], [[0.1, 0.1], [0.1, 0.1]]]],
  "expected_output_float": [[[[0.3, 0.5], [0.7, 0.9]], [[-0.1, 0.1], [0.3, 0.5]]]],
  "max_abs_error": 0.8,
  "mae": 0.35
}

This is a real execution failure of the lowered delegated graph, not just a
hand-derived interpretation of the folded metadata.

Worked example

Using the exact qparams from the repro:

  • producer domain Q8: scale=0.1, zp=102
  • consumer domain Q16: scale=1/32767, zp=0

and the exact sample input from the script, the cat tensor before the final
add is:

channel 0: [[ 0.2, 0.4],
            [ 0.6, 0.8]]

channel 1: [[-0.2, 0.0],
            [ 0.2, 0.4]]

The add bias is 0.1 everywhere.

Consider the top-left element of channel 0, which is 0.2.

Correct execution:

  1. Producer-side Q8 code is:
    round(0.2 / 0.1) + 102 = 104
  2. The explicit post-cat requant to Q16 should then produce:
    round(0.2 / (1/32767)) = 6553
  3. The add computes:
    0.2 + 0.1 = 0.3
  4. Correct Q16 add-output code is:
    round(0.3 / (1/32767)) = 9830
  5. The outer Q8 output is:
    round(0.3 / 0.1) + 102 = 105
  6. Dequantizing that gives the expected final value:
    (105 - 102) * 0.1 = 0.3

Buggy folded execution:

  1. After the bad fold, cat still physically produces the old producer-domain
    code 104
  2. But the folded metadata says the tensor is already in consumer-domain Q16
  3. Interpreting the live code 104 as Q16 gives:
    104 * (1/32767) = 0.00317
    instead of 0.2
  4. The add then computes:
    0.00317 + 0.1 = 0.10317
  5. Quantizing that back to the final Q8 output gives:
    round(0.10317 / 0.1) + 102 = 103
  6. Dequantizing that final code gives:
    (103 - 102) * 0.1 = 0.1

So this one element alone turns from the correct final result 0.3 into the
wrong final result 0.1.

The same collapse happens across the tensor because the live cat codes remain
small producer-domain Q8 values around 100..110. If those are misread as
consumer-domain Q16, they become tiny floats around 0.0031..0.0034; after
adding 0.1, they all quantize back to essentially the same final Q8 bucket,
which is why the lowered delegated execution returns a near-flat output of
0.1.
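
The element-level walk-through can be replayed with the same arithmetic in
plain Python (a sketch mirroring the steps above, not the actual lowered
kernels):

S8, Z8 = 0.1, 102
S16 = 1.0 / 32767.0

v = 0.2
code8 = round(v / S8) + Z8  # step 1: producer Q8 code, 104

# Correct path: the explicit post-cat requant to Q16 survives.
code16 = round(v / S16)                           # step 2: 6553
add_out = code16 * S16 + 0.1                      # step 3: ~0.3
final = (round(add_out / S8 + Z8) - Z8) * S8      # steps 4-6: 0.3

# Buggy path: the live Q8 code is misread as a Q16 code.
bad_out = code8 * S16 + 0.1                       # ~0.10317
final_bad = (round(bad_out / S8 + Z8) - Z8) * S8  # 0.1

print(final, final_bad)  # ~0.3 vs 0.1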

Expected result

For pass-through ops such as cat/concatenate/stack, one of these must be
true:

  1. folding is only allowed when producer and consumer qdomains are identical, or
  2. the explicit Q after the pass-through op must survive so a real requant can
    be lowered later

The graph should never claim that cat itself changed the integer code domain.

Why this is a FoldAndAnnotateQParamsPass bug

The problem should be fixed at fold time, not in a later rescale insertion pass.

Current behavior:

  1. FoldAndAnnotateQParamsPass sees:
    • DQ(int8 zp=102) -> cat -> Q(int16 zp=0)
  2. It folds away both the DQ inputs and the Q output.
  3. It stores:
    • cat.input_qparams = int8 zp=102
    • cat.output_qparams = int16 zp=0
  4. But cat is a pass-through op and has no arithmetic that could realize that
    qdomain change.

At that point the graph metadata is already wrong. Any later pass that trusts
cat.output_qparams as the live code-domain can generate a wrong rescale.
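
The premise that cat leaves integer codes untouched is easy to sanity-check in
plain torch, independent of any backend:

import torch

a = torch.tensor([[104, 106], [108, 110]], dtype=torch.int8)
b = torch.tensor([[100, 102], [104, 106]], dtype=torch.int8)
c = torch.cat([a, b], dim=0)
# cat only rearranges codes; no arithmetic rescales them into another qdomain,
# so any claimed qdomain change must come from a separate requant node.
assert torch.equal(c[:2], a) and torch.equal(c[2:], b)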

Proposed fix

Do not fold qdomain-changing boundaries through pass-through ops.

Concretely, in FoldAndAnnotateQParamsPass.is_foldable():

  • detect pass-through ops:
    • aten.cat.default
    • aten.concatenate.default
    • aten.stack.default
  • if all list inputs come from DQ nodes with the same qparams
  • and the op has a user Q node with different qparams
  • then return False

That preserves the explicit post-pass-through requant, which later lowering
passes can materialize correctly.

Proposed patch

diff --git a/backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py b/backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py
index XXXXXXX..YYYYYYY 100644
--- a/backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py
+++ b/backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py
@@
 class FoldAndAnnotateQParamsPass(ArmPass):
@@
     @staticmethod
     def is_foldable(node: Node) -> bool:
         if node.op != "call_function":
             return False
         # Don't fold chains of quant-ops into each other.
         if node.target in (*Q_OPS, *DQ_OPS):
             return False
+
+        # Pass-through ops like cat/concatenate/stack do not change integer
+        # codes. If their inputs are dequantized from one qdomain and the
+        # output is immediately re-quantized into a different qdomain, we must
+        # preserve that explicit requant boundary. Otherwise the node ends up
+        # carrying output_qparams for a code-domain it never actually produced.
+        if node.target in (
+            exir_ops.edge.aten.cat.default,
+            exir_ops.edge.aten.concatenate.default,
+            exir_ops.edge.aten.stack.default,
+        ):
+            input_nodes = node.args[0] if len(node.args) > 0 and isinstance(node.args[0], list) else None
+            if input_nodes:
+                input_qparams = None
+                all_inputs_are_same_dq_domain = True
+                for arg in input_nodes:
+                    if not isinstance(arg, Node) or arg.target not in DQ_OPS:
+                        all_inputs_are_same_dq_domain = False
+                        break
+                    args = arg.args
+                    arg_qparams = QuantArgs.from_operator(
+                        arg.target, (args[0], args[1], args[2], *args[3:])
+                    )
+                    if input_qparams is None:
+                        input_qparams = arg_qparams
+                    elif input_qparams != arg_qparams:
+                        all_inputs_are_same_dq_domain = False
+                        break
+
+                if all_inputs_are_same_dq_domain and input_qparams is not None:
+                    for user in node.users:
+                        if user.target not in Q_OPS:
+                            continue
+                        args = user.args
+                        output_qparams = QuantArgs.from_operator(
+                            user.target, (args[0], args[1], args[2], *args[3:])
+                        )
+                        if output_qparams != input_qparams:
+                            return False
 
         # Always fold q-dq into constant ops.
         if node.target in (
             exir_ops.edge.aten.full_like.default,
             *ComputeConstantOpsAOTPass.targeted_ops,
         ):
             return True

Validation of the proposed fix

The same repro was rerun with the above logic monkey-patched.

Patched result (trimmed the same way):

{
  "cat_input_qparams": {},
  "cat_output_qparams": {},
  "add_input_qparams": {
    "0": {
      "scale": 3.0518509447574615e-05,
      "zero_point": 0,
      "dtype": "torch.int16"
    },
    "1": {
      "scale": 3.0518509447574615e-05,
      "zero_point": 0,
      "dtype": "torch.int16"
    }
  },
  "cat_input_output_domain_match": true,
  "surviving_q_after_cat": [
    "quantized_decomposed_quantize_per_tensor_default_1"
  ]
}

This is the correct shape:

  • cat no longer pretends it changed qdomain
  • the explicit requant after cat survives
  • downstream lowering still sees the intended int16 domain at the add
  • the numeric witness returns the expected add output with zero error

Environment used to validate

  • local ExecuTorch checkout at 12c3e33c9b
  • torch==2.10.0+cpu
  • torchao==0.15.0
  • tosa_reference_model available from a local tosa-tools build
  • validated on 2026-04-20

Files involved upstream

  • executorch/backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py
  • executorch/backends/arm/_passes/insert_rescales_pass.py
  • executorch/backends/arm/quantizer/quantization_annotator.py

The bug is introduced by the fold pass; later rescale passes only expose the
incorrect metadata more obviously.

Versions

local ExecuTorch checkout at 12c3e33c9b
torch==2.10.0+cpu
torchao==0.15.0
tosa_reference_model available from a local tosa-tools build

cc @digantdesai @freddan80 @per @zingo @oscarandersson8218 @mansnils @Sebastian-Larsson @robell
