Skip to content

Commit

Permalink
[ONNX] More debug logging from fx to onnx
Browse files Browse the repository at this point in the history
Summary:
- Log fx graph name for 'fx-graph-to-onnx' diagnostic.
- Log fx graph and onnx graph under DEBUG verbosity level for 'fx-graph-to-onnx' diagnostic.
- Adjust unittest to run with diagnostics verbosity level logging.DEBUG.

ghstack-source-id: 2a6c0a52e50f81d028a8f6c92d61c57e9bf24a6f
Pull Request resolved: #107654
  • Loading branch information
BowenBao committed Aug 23, 2023
1 parent 478be31 commit 1e9bed2
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 2 deletions.
7 changes: 6 additions & 1 deletion test/onnx/onnx_test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import copy
import dataclasses
import io
import logging
import os
import unittest
import warnings
Expand All @@ -32,6 +33,7 @@
import torch
from torch.onnx import _constants, verification
from torch.onnx._internal import _beartype
from torch.onnx._internal.fx import diagnostics
from torch.testing._internal.opinfo import core as opinfo_core
from torch.types import Number

Expand Down Expand Up @@ -274,13 +276,16 @@ def run_test_with_fx_to_onnx_exporter_and_onnx_runtime(
export_options=torch.onnx.ExportOptions(
op_level_debug=self.op_level_debug,
dynamic_shapes=self.dynamic_shapes,
diagnostic_options=torch.onnx.DiagnosticOptions(
verbosity_level=logging.DEBUG
),
),
)
except torch.onnx.OnnxExporterError as e:
export_error = e
export_output = e.export_output

if verbose:
if verbose and diagnostics.is_onnx_diagnostics_log_artifact_enabled():
export_output.save_diagnostics(
f"test_report_{self._testMethodName}"
f"_op_level_debug_{self.op_level_debug}"
Expand Down
6 changes: 6 additions & 0 deletions torch/onnx/_internal/fx/_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,12 @@ def run(self, *args, **kwargs) -> torch.fx.GraphModule:
diagnostic = self.diagnostic_context.inflight_diagnostic(
rule=diagnostics.rules.fx_pass
)
diagnostic.info(
"For detailed logging of graph modifications by this pass, either set "
"`DiagnosticOptions.verbosity_level` to `logging.DEBUG` or use the environment variable "
"`TORCH_LOGS='onnx_diagnostics'`."
)

# Gather graph information before transform.
graph_diff_log_level = logging.DEBUG
if diagnostic.logger.isEnabledFor(graph_diff_log_level):
Expand Down
26 changes: 25 additions & 1 deletion torch/onnx/_internal/fx/fx_onnx_interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,17 @@ def _fx_node_to_onnx_message_formatter(
return f"FX Node: {node.op}:{node.target}[name={node.name}]. "


@_beartype.beartype
def _fx_graph_to_onnx_message_formatter(
fn: Callable,
self,
fx_graph_module: torch.fx.GraphModule,
*args,
**kwargs,
) -> str:
return f"FX Graph: {fx_graph_module._get_name()}. "


def _location_from_fx_stack_trace(
node_stack_trace: str,
) -> Optional[diagnostics.infra.Location]:
Expand Down Expand Up @@ -440,7 +451,10 @@ def run_node(
raise RuntimeError(f"Found node type not defined in torch.fx: {node.op}")

@_beartype.beartype
@diagnostics.diagnose_call(diagnostics.rules.fx_graph_to_onnx)
@diagnostics.diagnose_call(
diagnostics.rules.fx_graph_to_onnx,
diagnostic_message_formatter=_fx_graph_to_onnx_message_formatter,
)
def run(
self,
fx_graph_module: torch.fx.GraphModule,
Expand All @@ -460,6 +474,13 @@ def run(
`fx_graph_module` is a submodule. If not provided,
`fx_graph_module` is assumed to be the root module.
"""
diagnostic = self.diagnostic_context.inflight_diagnostic()
with diagnostic.log_section(logging.DEBUG, "FX Graph:"):
diagnostic.debug(
"```\n%s\n```",
diagnostics.LazyString(fx_graph_module.print_readable, False),
)

if parent_onnxscript_graph is not None:
# If parent_onnxscript_graph is provided, we assume fx_graph_module is a
# submodule representing a forward call of an nn.Module.
Expand Down Expand Up @@ -520,6 +541,9 @@ def run(
fx_name_to_onnxscript_value,
)

with diagnostic.log_section(logging.DEBUG, "ONNX Graph:"):
diagnostic.debug("```\n%s\n```", onnxscript_graph.torch_graph)

return onnxscript_graph

@_beartype.beartype
Expand Down

0 comments on commit 1e9bed2

Please sign in to comment.