
Unsupported FX Nodes: {'call_function': ['aten.quantized_gru.input', 'quantized.linear_dynamic.default']} #2074

Open
@Sukriti-Mehrotra


Hello,

I am trying to export a model quantized with torch.ao.quantization (consisting of Linear, GRU layers, etc.) to ONNX, but I run into the error: Unsupported FX nodes: {'call_function': ['aten.quantized_gru.input', 'quantized.linear_dynamic.default']}.

Post-Training Quantization (using torch.ao.quantization.quantize_fx)

The quantization method is post-training dynamic int8 quantization in FX graph mode: weights are quantized to int8 ahead of time, and activations are quantized on the fly at runtime.
Here is the snippet that quantizes the model and saves it as a .pth file:

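```python
import torch
from torch.ao.quantization import QConfigMapping, quantize_fx

# model_to_quantize and batch_size are defined earlier in my script
input_tensor = torch.randn(batch_size, 3840)
qconfig_mapping = QConfigMapping().set_global(torch.ao.quantization.default_dynamic_qconfig)
# prepare_fx takes the example inputs as a tuple
model_prepared = quantize_fx.prepare_fx(model_to_quantize, qconfig_mapping, (input_tensor,))
model_quantized = quantize_fx.convert_fx(model_prepared)
torch.save(model_quantized, "fx_quant.pth")
```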
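For readers less familiar with the scheme described above, dynamic quantization roughly means: weights are converted to int8 once, while the scale and zero point for each input are computed from that input at call time. The following is a simplified pure-Python sketch under stated assumptions (symmetric per-tensor int8 weights, affine per-tensor uint8 activations, lists instead of tensors). It is not PyTorch's fbgemm/qnnpack kernel, and all function names are hypothetical.

```python
"""Simplified illustration of dynamic int8 quantization for a linear layer.

Assumptions (not PyTorch's actual kernel): symmetric per-tensor int8 weights,
affine per-tensor uint8 activations, plain Python lists instead of tensors.
"""


def quantize_weights_symmetric(w):
    """Quantize a weight matrix once, ahead of time, to int8 in [-127, 127]."""
    max_abs = max(abs(v) for row in w for v in row) or 1.0
    scale = max_abs / 127.0
    q = [[max(-127, min(127, round(v / scale))) for v in row] for row in w]
    return q, scale


def quantize_activation_affine(x):
    """Pick scale/zero_point from this input's own range at runtime (the 'dynamic' part)."""
    lo, hi = min(min(x), 0.0), max(max(x), 0.0)
    scale = (hi - lo) / 255.0 or 1.0  # fall back to 1.0 for an all-zero input
    zero_point = round(-lo / scale)
    q = [max(0, min(255, round(v / scale) + zero_point)) for v in x]
    return q, scale, zero_point


def dynamic_linear(x, q_w, w_scale, bias):
    """y = W x + b using integer products, then rescaled back to float."""
    q_x, x_scale, x_zp = quantize_activation_affine(x)
    out = []
    for row, b in zip(q_w, bias):
        acc = sum(w_i * (x_i - x_zp) for w_i, x_i in zip(row, q_x))  # integer accumulate
        out.append(acc * w_scale * x_scale + b)
    return out
```

Only the activation quantization runs per call; this is why no calibration data is needed, and also why the exporter sees a dedicated op (quantized.linear_dynamic) rather than a plain float linear.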

Conversion to ONNX

When I export the quantized model to ONNX:

```python
onnx_program = torch.onnx.dynamo_export(model_quantized, input_tensor)
```

I get the following error:

```
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/onnx/_internal/_exporter_legacy.py", line 1222, in dynamo_export
    ).export()
  File "/usr/local/lib/python3.10/dist-packages/torch/onnx/_internal/_exporter_legacy.py", line 976, in export
    graph_module = self.options.fx_tracer.generate_fx(
  File "/usr/local/lib/python3.10/dist-packages/torch/onnx/_internal/fx/dynamo_graph_extractor.py", line 217, in generate_fx
    return self.pre_export_passes(options, model, graph_module, updated_model_args)  # type: ignore[return-value]
  File "/usr/local/lib/python3.10/dist-packages/torch/onnx/_internal/fx/dynamo_graph_extractor.py", line 226, in pre_export_passes
    return _exporter_legacy.common_pre_export_passes(
  File "/usr/local/lib/python3.10/dist-packages/torch/onnx/_internal/_exporter_legacy.py", line 1275, in common_pre_export_passes
    ).analyze(infra.levels.ERROR)
  File "/usr/local/lib/python3.10/dist-packages/torch/onnx/_internal/fx/analysis/unsupported_nodes.py", line 85, in analyze
    self._lint(analysis_result, diagnostic_level)
  File "/usr/local/lib/python3.10/dist-packages/torch/onnx/_internal/fx/analysis/unsupported_nodes.py", line 37, in _lint
    self.diagnostic_context.log_and_raise_if_error(diagnostic)
  File "/usr/local/lib/python3.10/dist-packages/torch/onnx/_internal/diagnostics/infra/context.py", line 356, in log_and_raise_if_error
    raise RuntimeErrorWithDiagnostic(diagnostic)
torch.onnx._internal.diagnostics.infra.context.RuntimeErrorWithDiagnostic: Unsupported FX nodes: {'call_function': ['aten.quantized_gru.input', 'quantized.linear_dynamic.default']}.
```
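The traceback shows the export failing in an analysis step (unsupported_nodes.py) that runs before any ONNX is emitted. As a rough mental model, that step walks the traced graph and collects every call_function target with no registered ONNX translation. The toy sketch below reproduces the shape of the error message; the node list and the SUPPORTED set are hypothetical, and the real exporter consults its ONNX function registry rather than a hard-coded set.

```python
"""Toy model of the 'unsupported FX nodes' check that raised the error above.

Hypothetical: NODES and SUPPORTED are made up for illustration; the real
exporter looks targets up in its ONNX function registry.
"""
from collections import defaultdict

# (op kind, target) pairs, in the spirit of fx.Node.op / fx.Node.target
NODES = [
    ("placeholder", "l_x_"),
    ("call_function", "aten.unsqueeze.default"),
    ("call_function", "aten.convolution.default"),
    ("call_function", "aten.relu.default"),
    ("call_function", "quantized.linear_dynamic.default"),
    ("call_function", "aten.quantized_gru.input"),
    ("output", "output"),
]

# Hypothetical subset of targets that have an ONNX translation registered.
SUPPORTED = {"aten.unsqueeze.default", "aten.convolution.default", "aten.relu.default"}


def find_unsupported(nodes, supported):
    """Group call_function targets that lack a registered translation by op kind."""
    missing = defaultdict(list)
    for op, target in nodes:
        if op == "call_function" and target not in supported and target not in missing[op]:
            missing[op].append(target)
    return {op: sorted(targets) for op, targets in missing.items() if targets}
```

With this toy input, `find_unsupported(NODES, SUPPORTED)` yields the same dictionary as the error message. The practical consequence is that each listed target needs its own registered translation (or has to disappear from the graph) before export can proceed.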

Notes

  1. Exporting the original (unquantized) model to ONNX succeeds.
  2. I can see that the op aten_quantized_gru_cell is supported. Is it possible to make use of it here, and if so, how?
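Regarding note 2: structurally, a sequence-level GRU (what aten.quantized_gru.input computes) is a loop that applies one GRU-cell step per time step and carries the hidden state forward, so in principle a cell-level function could be looped to cover the sequence op. The sketch below shows only that structure in plain float arithmetic, with gate equations as documented for torch.nn.GRU. It is not the torchlib aten_quantized_gru_cell function, it ignores how quantized weights are packed, and whether the two ops can actually be wired together this way is not confirmed here. All names are hypothetical.

```python
"""Sketch: a GRU layer over a sequence is a loop of GRU-cell steps.

Illustrative only (plain Python floats). Gate equations follow torch.nn.GRU:
r = sigmoid(W_ir x + b_ir + W_hr h + b_hr)
z = sigmoid(W_iz x + b_iz + W_hz h + b_hz)
n = tanh(W_in x + b_in + r * (W_hn h + b_hn))
h' = (1 - z) * n + z * h
"""
import math


def _sigmoid(v):
    return 1.0 / (1.0 + math.exp(-v))


def _affine(w, v, b):
    """Return W @ v + b for a list-of-lists W."""
    return [sum(wi * vi for wi, vi in zip(row, v)) + bi for row, bi in zip(w, b)]


def gru_cell(x, h, p):
    """One GRU step; p maps names like 'W_ir' or 'b_hn' to weights/biases."""
    r = [_sigmoid(a + c) for a, c in zip(_affine(p["W_ir"], x, p["b_ir"]), _affine(p["W_hr"], h, p["b_hr"]))]
    z = [_sigmoid(a + c) for a, c in zip(_affine(p["W_iz"], x, p["b_iz"]), _affine(p["W_hz"], h, p["b_hz"]))]
    hn = _affine(p["W_hn"], h, p["b_hn"])
    n = [math.tanh(a + ri * c) for a, ri, c in zip(_affine(p["W_in"], x, p["b_in"]), r, hn)]
    return [(1.0 - zi) * ni + zi * hi for zi, ni, hi in zip(z, n, h)]


def gru_sequence(xs, h0, p):
    """Sequence-level GRU: apply gru_cell at every time step, carrying h."""
    h, outputs = h0, []
    for x in xs:
        h = gru_cell(x, h, p)
        outputs.append(h)
    return outputs, h


def zero_params(input_size, hidden_size):
    """Hypothetical helper: all-zero weights, handy for checking the loop."""
    p = {}
    for g in "rzn":
        p[f"W_i{g}"] = [[0.0] * input_size for _ in range(hidden_size)]
        p[f"W_h{g}"] = [[0.0] * hidden_size for _ in range(hidden_size)]
        p[f"b_i{g}"] = [0.0] * hidden_size
        p[f"b_h{g}"] = [0.0] * hidden_size
    return p
```

With all-zero parameters every gate sits at 0.5 and n is 0, so each step simply halves the hidden state, which makes the loop easy to check by hand.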

report_dynamo_export.sarif

{
 "runs":[
  {
   "tool":{
    "driver":{
     "name":"torch.onnx.dynamo_export",
     "contents":[
      "localizedData",
      "nonLocalizedData"
     ],
     "language":"en-US",
     "rules":[
      {
       "id":"FXE0012",
       "fullDescription":{
        "text":"Result from FX graph analysis to reveal unsupported FX nodes.",
        "markdown":"This error indicates that an FX graph contains one or more unsupported nodes. The error message\nis typically accompanied by a list of the unsupported nodes found during analysis.\n\nTo resolve this error, you can try resolving each individual unsupported node error by following\nthe suggestions by its diagnostic. Typically, options include:\n\n- If exists, apply the auto-fix suggested by the diagnostic. TODO: this part is not available yet.\n- Rewrite the model using only supported PyTorch operators or functions.\n- Follow this [guide](https://pytorch.org/docs/stable/onnx.html#onnx-script-functions) to write and\n  register a custom symbolic function for the unsupported call_function FX node.\n"
       },
       "name":"unsupported-fx-node-analysis",
       "shortDescription":{
        "text":"Result from FX graph analysis to reveal unsupported FX nodes."
       }
      },
      {
       "id":"FXE0015",
       "fullDescription":{
        "text":"Determine if type promotion is required for the FX node. Insert cast nodes if needed.",
        "markdown":"This diagnostic monitors the node-level type promotion insertion process. In PyTorch, there is an automatic process called implicit type promotion,\nwhere the input types of an operator are promoted to a common type. The determination of the common type is based on the type promotion rule specific to each operator.\nTo learn more about PyTorch's type promotion rules, refer to the [elementwise_dtypes doc](https://github.com/pytorch/pytorch/blob/f044613f78df713fb57f70c608483c9f10ad332e/torch/_prims_common/__init__.py#L1252-L1335)\nand [torch._refs ops](https://github.com/pytorch/pytorch/blob/a475ea4542dfe961c9d097e33ab5041f61c8c17f/torch/_refs/__init__.py#L484).\n\nHowever, implicit type promotion is not supported in ONNX. Therefore, to replicate the PyTorch behavior, we need to explicitly insert cast nodes.\nThis diagnostic tracks the process of node-level type promotion insertion.\n\nThe type promotion rules used by this process can be found in `torch/onnx/_internal/fx/passes/type_promotion.py.`\nTo update or add new type promotion rules, please refer to the [Note: Update type promotion rule] section.\n"
       },
       "name":"fx-node-insert-type-promotion",
       "shortDescription":{
        "text":"Determine if type promotion is required for the FX node. Insert cast nodes if needed."
       }
      },
      {
       "id":"FXE0010",
       "fullDescription":{
        "text":"FX graph transformation during ONNX export before converting from FX IR to ONNX IR.",
        "markdown":"This diagnostic tracks the FX passes executed during the ONNX export process prior\nto converting from FX IR (Intermediate Representation) to ONNX IR.\n\nUnder the scope of ONNX export, an FX pass refers to a specific transformation applied to the FX GraphModule.\nThe primary aim of these passes is to streamline the graph into a format that aligns more with the ONNX IR.\nMoreover, these passes work to substitute unsupported FX IR features with those recognized and endorsed by\nONNX IR. Common transformations include, but aren't limited to, decomposition, functionalization and\ntype promotion.\n\nFor those who are interested in a comprehensive log detailing the modifications made during these passes,\nthere are a couple of options:\n\n- Set DiagnosticOptions.verbosity_level to logging.DEBUG.\n- Activate the environment variable TORCH_LOGS='onnx_diagnostics'.\n\nHowever, it's noteworthy that by default, such detailed logging is turned off. The primary reason being\nits considerable impact on performance.\n\nFor an in-depth understanding of each specific pass, please refer to the directory: torch/onnx/_internal/fx/passes.\n"
       },
       "name":"fx-pass",
       "shortDescription":{
        "text":"FX graph transformation during ONNX export before converting from FX IR to ONNX IR."
       }
      }
     ],
     "version":"2.5.1+cu124"
    }
   },
   "language":"en-US",
   "newlineSequences":[
    "\r\n",
    "\n"
   ],
   "results":[
    {
     "message":{
      "markdown":"Running Decompose pass. \n\n## Additional Message:\n\n## Function Signature\n### Function Signature Transform.run\n- self: <class 'torch.onnx._internal.fx.passes.decomp.Decompose'>\n- args: Tuple[length=1](\nTensor(f32[1, 3840]),\n)\nFor detailed logging of graph modifications by this pass, either set `DiagnosticOptions.verbosity_level` to `logging.DEBUG` or use the environment variable `TORCH_LOGS='onnx_diagnostics'`.\n## Return values\ntorch.fx.GraphModule(<lambda>)",
      "text":"Running Decompose pass. "
     },
     "codeFlows":[
      {
       "threadFlows":[
        {
         "locations":[]
        }
       ]
      }
     ],
     "graphs":[],
     "kind":"informational",
     "level":"none",
     "locations":[
      {
       "message":{
        "text":"Transform.run"
       },
       "physicalLocation":{
        "artifactLocation":{
         "uri":"/usr/local/lib/python3.10/dist-packages/torch/onnx/_internal/fx/_pass.py"
        },
        "region":{
         "snippet":{
          "text":"@diagnostics.diagnose_call("
         },
         "startLine":240
        }
       }
      }
     ],
     "properties":{
      "tags":[]
     },
     "ruleId":"FXE0010",
     "stacks":[]
    },
    {
     "message":{
      "markdown":"Running Functionalize pass. \n\n## Additional Message:\n\n## Function Signature\n### Function Signature Transform.run\n- self: <class 'torch.onnx._internal.fx.passes.functionalization.Functionalize'>\n- args: Tuple[length=1](\nTensor(f32[1, 3840]),\n)\nFor detailed logging of graph modifications by this pass, either set `DiagnosticOptions.verbosity_level` to `logging.DEBUG` or use the environment variable `TORCH_LOGS='onnx_diagnostics'`.\n## Return values\ntorch.fx.GraphModule(<lambda>)",
      "text":"Running Functionalize pass. "
     },
     "codeFlows":[
      {
       "threadFlows":[
        {
         "locations":[]
        }
       ]
      }
     ],
     "graphs":[],
     "kind":"informational",
     "level":"none",
     "locations":[
      {
       "message":{
        "text":"Transform.run"
       },
       "physicalLocation":{
        "artifactLocation":{
         "uri":"/usr/local/lib/python3.10/dist-packages/torch/onnx/_internal/fx/_pass.py"
        },
        "region":{
         "snippet":{
          "text":"@diagnostics.diagnose_call("
         },
         "startLine":240
        }
       }
      }
     ],
     "properties":{
      "tags":[]
     },
     "ruleId":"FXE0010",
     "stacks":[]
    },
    {
     "message":{
      "markdown":"Running RemoveInputMutation pass. \n\n## Additional Message:\n\n## Function Signature\n### Function Signature Transform.run\n- self: <class 'torch.onnx._internal.fx.passes.functionalization.RemoveInputMutation'>\n- args: Tuple[length=1](\nTensor(f32[1, 3840]),\n)\nFor detailed logging of graph modifications by this pass, either set `DiagnosticOptions.verbosity_level` to `logging.DEBUG` or use the environment variable `TORCH_LOGS='onnx_diagnostics'`.\n## Return values\ntorch.fx.GraphModule(<lambda>)",
      "text":"Running RemoveInputMutation pass. "
     },
     "codeFlows":[
      {
       "threadFlows":[
        {
         "locations":[]
        }
       ]
      }
     ],
     "graphs":[],
     "kind":"informational",
     "level":"none",
     "locations":[
      {
       "message":{
        "text":"Transform.run"
       },
       "physicalLocation":{
        "artifactLocation":{
         "uri":"/usr/local/lib/python3.10/dist-packages/torch/onnx/_internal/fx/_pass.py"
        },
        "region":{
         "snippet":{
          "text":"@diagnostics.diagnose_call("
         },
         "startLine":240
        }
       }
      }
     ],
     "properties":{
      "tags":[]
     },
     "ruleId":"FXE0010",
     "stacks":[]
    },
    {
     "message":{
      "markdown":"Skipped l_x_: not a call_function.\n\n## Additional Message:\n\n## Function Signature\n### Function Signature _TypePromotionInterpreter.run_node\n- self: <class 'torch.onnx._internal.fx.passes.type_promotion._TypePromotionInterpreter'>\n- node: fx.Node(arg0)[placeholder]:Tensor(f32[1, 3840])\n## Return values\nTensor(f32[1, 3840])",
      "text":"Skipped l_x_: not a call_function."
     },
     "codeFlows":[
      {
       "threadFlows":[
        {
         "locations":[]
        }
       ]
      }
     ],
     "graphs":[],
     "kind":"informational",
     "level":"none",
     "locations":[
      {
       "message":{
        "text":"_TypePromotionInterpreter.run_node"
       },
       "physicalLocation":{
        "artifactLocation":{
         "uri":"/usr/local/lib/python3.10/dist-packages/torch/onnx/_internal/fx/passes/type_promotion.py"
        },
        "region":{
         "snippet":{
          "text":"@diagnostics.diagnose_call("
         },
         "startLine":1607
        }
       }
      }
     ],
     "properties":{
      "tags":[]
     },
     "ruleId":"FXE0015",
     "stacks":[]
    },
    {
     "message":{
      "markdown":"Skipped for fx.Node(aten.unsqueeze.default)[call_function]:Tensor(f32[1, 1, 3840]): Cannot find type promotion rule for op: aten.unsqueeze.default\n\n## Additional Message:\n\n## Function Signature\n### Function Signature _TypePromotionInterpreter.run_node\n- self: <class 'torch.onnx._internal.fx.passes.type_promotion._TypePromotionInterpreter'>\n- node: fx.Node(aten.unsqueeze.default)[call_function]:Tensor(f32[1, 1, 3840])\n## Return values\nTensor(f32[1, 1, 3840])",
      "text":"Skipped for fx.Node(aten.unsqueeze.default)[call_function]:Tensor(f32[1, 1, 3840]): Cannot find type promotion rule for op: aten.unsqueeze.default"
     },
     "codeFlows":[
      {
       "threadFlows":[
        {
         "locations":[]
        }
       ]
      }
     ],
     "graphs":[],
     "kind":"informational",
     "level":"none",
     "locations":[
      {
       "message":{
        "text":"_TypePromotionInterpreter.run_node"
       },
       "physicalLocation":{
        "artifactLocation":{
         "uri":"/usr/local/lib/python3.10/dist-packages/torch/onnx/_internal/fx/passes/type_promotion.py"
        },
        "region":{
         "snippet":{
          "text":"@diagnostics.diagnose_call("
         },
         "startLine":1607
        }
       }
      }
     ],
     "properties":{
      "tags":[]
     },
     "ruleId":"FXE0015",
     "stacks":[]
    },
    {
     "message":{
      "markdown":"Skipped _param_constant0: not a call_function.\n\n## Additional Message:\n\n## Function Signature\n### Function Signature _TypePromotionInterpreter.run_node\n- self: <class 'torch.onnx._internal.fx.passes.type_promotion._TypePromotionInterpreter'>\n- node: fx.Node(_param_constant0)[get_attr]:None\n## Return values\nParameter(Tensor(f32[1280, 1, 960]))",
      "text":"Skipped _param_constant0: not a call_function."
     },
     "codeFlows":[
      {
       "threadFlows":[
        {
         "locations":[]
        }
       ]
      }
     ],
     "graphs":[],
     "kind":"informational",
     "level":"none",
     "locations":[
      {
       "message":{
        "text":"_TypePromotionInterpreter.run_node"
       },
       "physicalLocation":{
        "artifactLocation":{
         "uri":"/usr/local/lib/python3.10/dist-packages/torch/onnx/_internal/fx/passes/type_promotion.py"
        },
        "region":{
         "snippet":{
          "text":"@diagnostics.diagnose_call("
         },
         "startLine":1607
        }
       }
      }
     ],
     "properties":{
      "tags":[]
     },
     "ruleId":"FXE0015",
     "stacks":[]
    },
    {
     "message":{
      "markdown":"Skipped for fx.Node(aten.convolution.default)[call_function]:Tensor(f32[1, 1280, 7]): Cannot find type promotion rule for op: aten.convolution.default\n\n## Additional Message:\n\n## Function Signature\n### Function Signature _TypePromotionInterpreter.run_node\n- self: <class 'torch.onnx._internal.fx.passes.type_promotion._TypePromotionInterpreter'>\n- node: fx.Node(aten.convolution.default)[call_function]:Tensor(f32[1, 1280, 7])\n## Return values\nTensor(f32[1, 1280, 7])",
      "text":"Skipped for fx.Node(aten.convolution.default)[call_function]:Tensor(f32[1, 1280, 7]): Cannot find type promotion rule for op: aten.convolution.default"
     },
     "codeFlows":[
      {
       "threadFlows":[
        {
         "locations":[]
        }
       ]
      }
     ],
     "graphs":[],
     "kind":"informational",
     "level":"none",
     "locations":[
      {
       "message":{
        "text":"_TypePromotionInterpreter.run_node"
       },
       "physicalLocation":{
        "artifactLocation":{
         "uri":"/usr/local/lib/python3.10/dist-packages/torch/onnx/_internal/fx/passes/type_promotion.py"
        },
        "region":{
         "snippet":{
          "text":"@diagnostics.diagnose_call("
         },
         "startLine":1607
        }
       }
      }
     ],
     "properties":{
      "tags":[]
     },
     "ruleId":"FXE0015",
     "stacks":[]
    },
    {
     "message":{
      "markdown":"Type promotion not needed for relu. \n\n## Additional Message:\n\n## Function Signature\n### Function Signature _TypePromotionInterpreter.run_node\n- self: <class 'torch.onnx._internal.fx.passes.type_promotion._TypePromotionInterpreter'>\n- node: fx.Node(aten.relu.default)[call_function]:Tensor(f32[1, 1280, 7])\nFound type promotion rule: ElementwiseTypePromotionRule('aten', 'relu', [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT)\nArgument convolution is not promoted. Already torch.float32.\n## Return values\nTensor(f32[1, 1280, 7])",
      "text":"Type promotion not needed for relu. "
     },
     "codeFlows":[
      {
       "threadFlows":[
        {
         "locations":[]
        }
       ]
      }
     ],
     "graphs":[],
     "kind":"informational",
     "level":"none",
     "locations":[
      {
       "message":{
        "text":"_TypePromotionInterpreter.run_node"
       },
       "physicalLocation":{
        "artifactLocation":{
         "uri":"/usr/local/lib/python3.10/dist-packages/torch/onnx/_internal/fx/passes/type_promotion.py"
        },
        "region":{
         "snippet":{
          "text":"@diagnostics.diagnose_call("
         },
         "startLine":1607
        }
       }
      }
     ],
     "properties":{
      "tags":[]
     },
     "ruleId":"FXE0015",
     "stacks":[]
    },
    {
     "message":{
      "markdown":"Skipped for fx.Node(aten.detach.default)[call_function]:Tensor(f32[1, 1280, 7]): Cannot find type promotion rule for op: aten.detach.default\n\n## Additional Message:\n\n## Function Signature\n### Function Signature _TypePromotionInterpreter.run_node\n- self: <class 'torch.onnx._internal.fx.passes.type_promotion._TypePromotionInterpreter'>\n- node: fx.Node(aten.detach.default)[call_function]:Tensor(f32[1, 1280, 7])\n## Return values\nTensor(f32[1, 1280, 7])",
      "text":"Skipped for fx.Node(aten.detach.default)[call_function]:Tensor(f32[1, 1280, 7]): Cannot find type promotion rule for op: aten.detach.default"
     },
 ... (truncated: too long to paste here)
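Since the full report is too long to paste, one option is to extract only the high-severity findings. The sketch below assumes the standard SARIF layout visible above (runs[].results[] with ruleId, level and message.text). The FXE0012 entry is presumably recorded with level "error", but that part of the report is truncated here, so treat the filter value as an assumption; the function names are hypothetical.

```python
"""Pull only selected findings out of a SARIF report such as
report_dynamo_export.sarif so the relevant part fits in an issue.

Assumes the SARIF layout seen above: runs[].results[] with 'ruleId',
'level' and 'message.text'.
"""
import json


def sarif_findings(report, levels=("error",)):
    """Return (ruleId, level, text) for every result whose level is in `levels`."""
    found = []
    for run in report.get("runs", []):
        for result in run.get("results", []):
            if result.get("level") in levels:
                found.append((result.get("ruleId"), result["level"],
                              result.get("message", {}).get("text", "")))
    return found


def load_sarif(path):
    """Read a SARIF file from disk as a dict."""
    with open(path, encoding="utf-8") as f:
        return json.load(f)
```

Usage: `for rule, level, text in sarif_findings(load_sarif("report_dynamo_export.sarif")): print(rule, level, text)`.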

Are there any guidelines on how to solve this problem, or on implementing support for these operations?
Thank you, and sorry for the long post.

Labels: module: torchlib (Related to the torch/aten function lib in development)
