diff --git a/torch/ao/quantization/pt2e/port_metadata_pass.py b/torch/ao/quantization/pt2e/port_metadata_pass.py index b0fe72f678740..5f535ca686031 100644 --- a/torch/ao/quantization/pt2e/port_metadata_pass.py +++ b/torch/ao/quantization/pt2e/port_metadata_pass.py @@ -22,7 +22,6 @@ __all__ = ["PortNodeMetaForQDQ"] _METADATA_TO_PORT = [ - "nn_module_stack", "stack_trace", "quantization_tag", ] @@ -167,6 +166,12 @@ class PortNodeMetaForQDQ(_ExportPassBase): - Quantized [Q-> DQ -> Conv -> Q -> DQ -> AvgPool -> Q -> DQ -> choose_params -> Q -> DQ -> Linear] - Quantized [Q-> [DQ -> Conv -> Q] -> [DQ -> AvgPool -> Q] -> DQ -> [choose_params -> Q -> DQ -> Linear]] - Note first Q does not inherit metadata from any nodes + NB: + - The best place for porting metadata is during observer conversion to q/dq. This is because it precisely + knows which quantization spec is converted to q/dq and thus from where the metadata should be ported. + However, since FX and PT2E quant workflow are on a common code-base, this hurts readability quite a bit. + Doing it via a separate pass helps readability of the code. Once we are able to refactor PT2E quant + code, this pass should likely be integrated into the refactored variant of the "convert" step. 
""" def call(self, graph_module: torch.fx.GraphModule) -> PassResult: diff --git a/torch/ao/quantization/pt2e/utils.py b/torch/ao/quantization/pt2e/utils.py index d58dd11136722..171f1426abe2b 100644 --- a/torch/ao/quantization/pt2e/utils.py +++ b/torch/ao/quantization/pt2e/utils.py @@ -14,6 +14,7 @@ from torch.ao.quantization.quantizer import QuantizationAnnotation + __all__ = [ "fold_bn_weights_into_conv_node", "get_aten_graph_module", @@ -34,15 +35,19 @@ ] -def _is_connected(next_node: torch.fx.Node, target: torch.fx.Node) -> bool: - if target.op == "output": - return False - if next_node == target: - return True - for n in next_node.users.keys(): - if _is_connected(n, target): - return True - return False +def _is_connected(source: torch.fx.Node, dest: torch.fx.Node) -> bool: + """ + Assuming dest is one of the ops inserted by quant workflow, this function + finds if source and dest are connected. Assumption is that only quant workflow + inserted ops exist between source and dest + """ + quant_workflow_ops = _QUANTIZE_OPS + _DEQUANTIZE_OPS + quant_workflow_ops.append(torch.ops.quantized_decomposed.choose_qparams.tensor) + while dest.target in quant_workflow_ops: + if not isinstance(dest.args[0], torch.fx.Node): + raise ValueError(f"expected arg[0] of quant workflow ops to be a node but found {dest.args[0]}") + dest = dest.args[0] + return (dest == source) def _find_q_dq_node_for_user(