Skip to content

Commit

Permalink
Document FX debugging (#51530)
Browse files Browse the repository at this point in the history
Summary: Pull Request resolved: #51530

Test Plan: Imported from OSS

Reviewed By: jamesr66a

Differential Revision: D26192641

Pulled By: ansley

fbshipit-source-id: c69ab1bb2451d8ee5a729445f52bccc66e6f431b
  • Loading branch information
Ansley Ussery authored and facebook-github-bot committed Feb 3, 2021
1 parent f7313b3 commit ab4623d
Show file tree
Hide file tree
Showing 4 changed files with 246 additions and 4 deletions.
243 changes: 240 additions & 3 deletions docs/source/fx.rst
Expand Up @@ -7,15 +7,250 @@ Overview
--------
.. automodule:: torch.fx

.. _Writing Transformations:

Writing Transformations
-----------------------

TODO

Debugging Transformations
-------------------------

TODO
Debugging
-----------

Introduction
^^^^^^^^^^^^^^^^

After symbolically tracing an ``nn.Module`` and performing some number
of transformations on the resulting GraphModule, we'll want to verify
that the proper semantics were preserved after those transforms. If they
weren't, we may need to do some debugging. The key is to work
backwards: first, check the results of the generated module, then debug
the generated code, then debug the process of transformations that lead
to the generated code.

If you’re not familiar with debuggers, please see the auxiliary section
:ref:`Available Debuggers`.

Debugging the Generated Code
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Because FX generates the ``forward()`` function on GraphModules, using
traditional debugging techniques like ``print`` statements or ``pdb`` is
not as straightfoward. Luckily, we have several techniques we can use
for debugging the generated code.

Use ``pdb``
~~~~~~~~~~~~~
Invoke ``pdb`` to step into the running program. Although the code that
represents the FX graph is not in any source file, we can still step
into it manually using ``pdb`` when the forward pass is invoked.

::

def my_pass(in: torch.nn.Module) -> torch.nn.Module:
traced = torch.fx.symbolic_trace(in)
# Transformation logic here
return traced

# When this line is executed at runtime, we will be dropped into an
# interactive `pdb` prompt. We can use the `step` or `s` command to
# step into the execution of the next line
import pdb; pdb.set_trace()

my_pass(my_module)

.. _Print the Generated Code:

Print the Generated Code
~~~~~~~~~~~~~~~~~~~~~~~~~~~
If you’d like to run the same code multiple times, then it can be
a bit tedious to step to the right code with ``pdb``. In that case, one
approach is to simply copy-paste the generated ``forward`` pass into
your code and examine it from there.

::

# Assume that `traced` is a GraphModule that has undergone some
# number of transforms

# Print the code generated from symbolic tracing. This outputs:
"""
def forward(self, y):
x = self.x
add_1 = x + y; x = y = None
return add_1
"""
# Copy this code for later
print(traced)

# Subclass the original Module
class SubclassM(M):
def __init__(self):
super().__init__()

# Paste the generated `forward` function (the one we printed and
# copied on line 22) here
def forward(self, y):
x = self.x
add_1 = x + y; x = y = None
return add_1

# Create an instance of the original, untraced Module. Then, create an
# instance of the Module with the copied `forward` function. We can
# now compare the output of both the original and the traced version.
pre_trace = M()
post_trace = SubclassM()

Use the ``to_folder`` Function From ``GraphModule``
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
:meth:`GraphModule.to_folder` is a method in ``GraphModule`` that allows
you to dump out the generated FX code to a folder. Although copying the
forward pass into the code often suffices as in :ref:`Print the Generated Code`,
it may be easier to examine modules and parameters using ``to_folder``.

::

m = symbolic_trace(M())
m.to_folder("foo", "Bar")
from foo import Bar
y = Bar()

After running the above example, we can then look at the code within
``foo/module.py`` and modify it as desired (e.g. adding ``print``
statements or using ``pdb``) to debug the generated code.

Debugging the Transformation
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Now that we've identified that a transformation is creating incorrect
code, it's time to debug the transformation itself. First, we'll check
the :ref:`Limitations of Symbolic Tracing` section in the documentation.
Once we verify that ``symbolic_trace`` is working as expected, the goal
becomes figuring out what went wrong during our ``GraphModule``
transformation. There may be a quick answer in
:ref:`Writing Transformations`, but, if not, there are several ways to
examine our traced module:

::

# Sample Module
class M(torch.nn.Module):
def forward(self, x, y):
return x + y

# Create an instance of `M`
m = M()

# Symbolically trace an instance of `M` (returns a GraphModule). In
# this example, we'll only be discussing how to inspect a
# GraphModule, so we aren't showing any sample transforms for the
# sake of brevity.
traced = symbolic_trace(m)

# Print the code produced by tracing the module. The generated `forward`
# function is:
"""
def forward(self, x, y):
add_1 = x + y; x = y = None
return add_1
"""
print(traced)

# Print the internal Graph. This representation returns:
"""
graph(x, y):
%add_1 : [#users=1] = call_function[target=<built-in function add>](args = (%x, %y), kwargs = {})
return add_1
"""
print(traced.graph)

# Print a tabular representation of the internal Graph. This gives us:
"""
opcode name target args kwargs
------------- ------ ----------------------- -------- --------
placeholder x x () {}
placeholder y y () {}
call_function add_1 <built-in function add> (x, y) {}
"""
traced.graph.print_tabular()

Using the utility functions above, we can compare our traced Module
before and after we've apply our transformations. Sometimes, a
simple visual comparison is enough to trace down a bug. If it's still
not clear what's going wrong, a debugger like ``pdb`` can be a good
next step.

Going off of the example above, consider the following code:

::

# Sample user-defined function
def transform_graph(gm: GraphModule) -> None:

# Get the Graph from our traced Module
g = gm.graph

"""
Transformations on `g` go here
"""

# Recompile the GraphModule. This must be called after editing
# the Graph `g`, otherwise the generated code will still reflect
# the old Graph before any transforms
gm.recompile()

# Transform the Graph
transform_graph(traced)

# Print the new code after our transforms. Check to see if it was
# what we expected
print(traced)

Using the above example, let’s say that the call to ``print(traced)``
showed us that there was an error in our transforms. We want to find
what goes wrong using a debugger. We start a ``pdb`` session. We can see
what’s happening during the transform by breaking on
``transform_graph(traced)``, then pressing ``s`` to “step into” the call
to ``transform_graph(traced)``.

We may also have good luck by editing the ``print_tabular`` method to print
different attributes of the Nodes in the Graph. (For example, we might
want to see the Node’s ``input_nodes`` and ``users``.)

.. _Available Debuggers:

Available Debuggers
^^^^^^^^^^^^^^^^^^^^^^

The most common Python debugger is
`pdb <https://docs.python.org/3/library/pdb.html>`__. You can start
your program in “debug mode” with ``pdb`` by typing
``python -m pdb FILENAME.py`` into the command line, where ``FILENAME``
is the name of the file you want to debug. After that, you can use the
``pdb`` `debugger commands
<https://docs.python.org/3/library/pdb.html#debugger-commands>`__
to move through your running program stepwise. It’s common to set a
breakpoint (``b LINE-NUMBER``) when you start ``pdb``, then call ``c`` to
run the program until that point. This prevents you from having to step
through each line of execution (using ``s`` or ``n``) to get to the part
of the code you want to examine. Alternatively, you can write
``import pdb; pdb.set_trace()`` before the line you want to break at.
If you add ``pdb.set_trace()``, your program will automatically start
in debug mode when you run it. (In other words, you can just type
``python FILENAME.py`` into the command line instead of
``python -m pdb FILENAME.py``.) Once you're running your file in
debug mode, you can step through the code and examine your program's
internal state using certain commands. There are many excellent
tutorials on ``pdb`` online, including RealPython’s
`“Python Debugging With Pdb” <https://realpython.com/python-debugging-pdb/>`__.

IDEs like PyCharm or VSCode usually have a debugger built in. In your
IDE, you can choose to either a) use ``pdb`` by pulling up a terminal
window in your IDE (e.g. View → Terminal in VSCode), or b) use the
built-in debugger (usually a graphical wrapper around ``pdb``).

.. _Limitations of Symbolic Tracing:

Limitations of Symbolic Tracing
-------------------------------
Expand Down Expand Up @@ -328,3 +563,5 @@ API Reference

.. autoclass:: torch.fx.Transformer
:members:

.. autofunction:: torch.fx.replace_pattern
1 change: 1 addition & 0 deletions torch/fx/__init__.py
Expand Up @@ -88,3 +88,4 @@ def forward(self, x):
from .node import Node, map_arg
from .proxy import Proxy
from .interpreter import Interpreter as Interpreter, Transformer as Transformer
from .subgraph_rewriter import replace_pattern
1 change: 1 addition & 0 deletions torch/fx/__init__.pyi
Expand Up @@ -4,3 +4,4 @@ from .node import Node as Node, map_arg as map_arg
from .proxy import Proxy as Proxy
from .symbolic_trace import Tracer as Tracer, symbolic_trace as symbolic_trace, wrap as wrap
from .interpreter import Interpreter as Interpreter, Transformer as Transformer
from .subgraph_rewriter import replace_pattern as replace_pattern
5 changes: 4 additions & 1 deletion torch/fx/subgraph_rewriter.py
@@ -1,4 +1,7 @@
from torch.fx import Graph, GraphModule, Node, symbolic_trace
from .graph_module import GraphModule
from .graph import Graph
from .node import Node
from .symbolic_trace import symbolic_trace

import copy
from typing import Callable, Dict, List, NamedTuple, Set
Expand Down

0 comments on commit ab4623d

Please sign in to comment.