[FX] Update overview docstring (#50896)
Summary: Pull Request resolved: #50896

Test Plan: Imported from OSS

Reviewed By: ansley

Differential Revision: D26002067

Pulled By: jamesr66a

fbshipit-source-id: 3b4d4b96017d16739a31f25a306f55b6f96324dc
James Reed authored and facebook-github-bot committed Jan 22, 2021
1 parent eb0fe70 commit 5016637
Showing 1 changed file with 56 additions and 59 deletions.
115 changes: 56 additions & 59 deletions torch/fx/__init__.py
@@ -2,82 +2,79 @@
r'''
**This feature is under a Beta release and its API may change.**

-FX is a toolkit for capturing and transforming functional PyTorch programs. It
-consists of GraphModule and a corresponding intermediate representation (IR). When GraphModule is constructed
-with an ``nn.Module`` instance as its argument, GraphModule will trace through the computation of that Module's
-``forward`` method symbolically and record those operations in the FX intermediate representation.
+FX is a toolkit for developers to use to transform ``nn.Module``
+instances. FX consists of three main components: a **symbolic tracer,**
+an **intermediate representation**, and **Python code generation**. A
+demonstration of these components in action:

-.. code-block:: python
+::

    import torch
-    import torch.fx
-
+    # Simple module for demonstration
    class MyModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.param = torch.nn.Parameter(torch.rand(3, 4))
            self.linear = torch.nn.Linear(4, 5)

-        def forward(self, x):
-            return torch.topk(torch.sum(self.linear(x + self.linear.weight).relu(), dim=-1), 3)
-
-    m = MyModule()
-    gm = torch.fx.symbolic_trace(m)
-
-The Intermediate Representation centers around a 5-opcode format::
-
-    print(gm.graph)
-
-.. code-block:: text
+        def forward(self, x):
+            return self.linear(x + self.param).clamp(min=0.0, max=1.0)
+
+    module = MyModule()
+
+    from torch.fx import symbolic_trace
+    # Symbolic tracing frontend - captures the semantics of the module
+    symbolic_traced : torch.fx.GraphModule = symbolic_trace(module)
+
+    # High-level intermediate representation (IR) - Graph representation
+    print(symbolic_traced.graph)
+    """
    graph(x):
-        %linear_weight : [#users=1] = self.linear.weight
-        %add_1 : [#users=1] = call_function[target=operator.add](args = (%x, %linear_weight), kwargs = {})
+        %param : [#users=1] = self.param
+        %add_1 : [#users=1] = call_function[target=<built-in function add>](args = (%x, %param), kwargs = {})
        %linear_1 : [#users=1] = call_module[target=linear](args = (%add_1,), kwargs = {})
-        %relu_1 : [#users=1] = call_method[target=relu](args = (%linear_1,), kwargs = {})
-        %sum_1 : [#users=1] = call_function[target=torch.sum](args = (%relu_1,), kwargs = {dim: -1})
-        %topk_1 : [#users=1] = call_function[target=torch.topk](args = (%sum_1, 3), kwargs = {})
-        return topk_1
-
-The Node semantics are as follows:
-
-- ``placeholder`` represents a function input. The ``name`` attribute specifies the name this value will take on.
-  ``target`` is similarly the name of the argument. ``args`` holds either: 1) nothing, or 2) a single argument
-  denoting the default parameter of the function input. ``kwargs`` is don't-care. Placeholders correspond to
-  the function parameters (e.g. ``x``) in the graph printout.
-- ``get_attr`` retrieves a parameter from the module hierarchy. ``name`` is similarly the name the result of the
-  fetch is assigned to. ``target`` is the fully-qualified name of the parameter's position in the module hierarchy.
-  ``args`` and ``kwargs`` are don't-care
-- ``call_function`` applies a free function to some values. ``name`` is similarly the name of the value to assign
-  to. ``target`` is the function to be applied. ``args`` and ``kwargs`` represent the arguments to the function,
-  following the Python calling convention
-- ``call_module`` applies a module in the module hierarchy's ``forward()`` method to given arguments. ``name`` is
-  as previous. ``target`` is the fully-qualified name of the module in the module hierarchy to call.
-  ``args`` and ``kwargs`` represent the arguments to invoke the module on, *excluding the self argument*.
-- ``call_method`` calls a method on a value. ``name`` is as similar. ``target`` is the string name of the method
-  to apply to the ``self`` argument. ``args`` and ``kwargs`` represent the arguments to invoke the module on,
-  *including the self argument*
-- ``output`` contains the output of the traced function in its ``args[0]`` attribute. This corresponds to the
-  "return" statement in the Graph printout.
+        %clamp_1 : [#users=1] = call_method[target=clamp](args = (%linear_1,), kwargs = {min: 0.0, max: 1.0})
+        return clamp_1
+    """
-
-GraphModule automatically generates Python code for the operations it symbolically observed::
-
-    print(gm.code)
-
-.. code-block:: python
-
-    import torch
+
+    # Code generation - valid Python code
+    print(symbolic_traced.code)
+    """
    def forward(self, x):
-        linear_weight = self.linear.weight
-        add_1 = x + linear_weight; x = linear_weight = None
+        param = self.param
+        add_1 = x + param; x = param = None
        linear_1 = self.linear(add_1); add_1 = None
-        relu_1 = linear_1.relu(); linear_1 = None
-        sum_1 = torch.sum(relu_1, dim = -1); relu_1 = None
-        topk_1 = torch.topk(sum_1, 3); sum_1 = None
-        return topk_1
-
-Because this code is valid PyTorch code, the resulting ``GraphModule`` can be used in any context another
-``nn.Module`` can be used, including in TorchScript tracing/compilation.
+        clamp_1 = linear_1.clamp(min = 0.0, max = 1.0); linear_1 = None
+        return clamp_1
+    """
+
+The **symbolic tracer** performs “abstract interpretation” of the Python
+code. It feeds fake values, called Proxies, through the code. Operations
+on these Proxies are recorded. More information about symbolic tracing
+can be found in the
+`symbolic\_trace <https://pytorch.org/docs/master/fx.html#torch.fx.symbolic_trace>`__
+and `Tracer <https://pytorch.org/docs/master/fx.html#torch.fx.Tracer>`__
+documentation.
+
+The **intermediate representation** is the container for the operations
+that were recorded during symbolic tracing. It consists of a list of
+Nodes that represent function inputs, callsites (to functions, methods,
+or ``nn.Module`` instances), and return values. More information about
+the IR can be found in the documentation for
+`Graph <https://pytorch.org/docs/master/fx.html#torch.fx.Graph>`__. The
+IR is the format on which transformations are applied.
+
+**Python code generation** is what makes FX a Python-to-Python (or
+Module-to-Module) transformation toolkit. For each Graph IR, we can
+create valid Python code matching the Graph’s semantics. This
+functionality is wrapped up in
+`GraphModule <https://pytorch.org/docs/master/fx.html#torch.fx.GraphModule>`__,
+which is an ``nn.Module`` instance that holds a ``Graph`` as well as a
+``forward`` method generated from the Graph.
+
+Taken together, this pipeline of components (symbolic tracing →
+intermediate representation → transforms → Python code generation)
+constitutes the Python-to-Python transformation pipeline of FX.
'''

from .graph_module import GraphModule
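As a reader's companion to the new overview, here is a minimal sketch of the symbolic tracer in action. It assumes a PyTorch build that ships ``torch.fx`` (as of this commit) and re-declares ``MyModule`` from the updated docstring so the snippet stands alone:

    import torch
    import torch.fx

    # The demonstration module from the updated docstring
    class MyModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.param = torch.nn.Parameter(torch.rand(3, 4))
            self.linear = torch.nn.Linear(4, 5)

        def forward(self, x):
            return self.linear(x + self.param).clamp(min=0.0, max=1.0)

    # symbolic_trace feeds Proxy values through forward() and records
    # each operation on them into a Graph owned by the GraphModule.
    gm = torch.fx.symbolic_trace(MyModule())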
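The opcode taxonomy from the removed section (``placeholder``, ``get_attr``, ``call_function``, ``call_module``, ``call_method``, ``output``) still lives on each ``Node``. Continuing with ``gm`` from the sketch above, the recorded IR can be walked directly:

    # Every Node carries an opcode, an assigned name, a call target,
    # and the arguments it was invoked with.
    for node in gm.graph.nodes:
        print(node.op, node.name, node.target, node.args)

    # Prints one line per node, e.g.:
    #   call_function add_1 <built-in function add> (x, param)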

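What makes FX a Python-to-Python toolkit, as the new prose says, is that the Graph can be edited and the code regenerated. A sketch of the classic rewrite on the same ``gm``, using ``graph.lint()`` and ``recompile()`` as the consistency-check and regeneration hooks:

    import operator

    # Retarget the call_function node recorded from `x + self.param`
    # so the generated code multiplies instead of adds.
    for node in gm.graph.nodes:
        if node.op == 'call_function' and node.target is operator.add:
            node.target = operator.mul

    gm.graph.lint()  # check the edited Graph is still well-formed
    gm.recompile()   # regenerate gm.code and gm.forward from the Graph
    print(gm.code)   # the recorded add now prints as `x * param`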
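Finally, the removed closing sentence made a point that still holds for the new example: the generated code is plain PyTorch, so a ``GraphModule`` can stand in wherever an ``nn.Module`` is expected, including TorchScript compilation. A small check on a fresh, untransformed trace (input shape chosen to broadcast with the ``(3, 4)`` parameter):

    m = MyModule()
    gm = torch.fx.symbolic_trace(m)

    x = torch.rand(3, 4)
    assert torch.allclose(gm(x), m(x))        # same semantics as the original

    scripted = torch.jit.script(gm)           # TorchScript compilation
    assert torch.allclose(scripted(x), m(x))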