
[FX][1/2] Make docstrings pretty when rendered
ghstack-source-id: ed63675e1df90a7b70c276a3c9f8582011c22c15
Pull Request resolved: #48738
James Reed committed Dec 3, 2020
1 parent 22c3ae8 commit 413428d
Showing 8 changed files with 472 additions and 284 deletions.
30 changes: 30 additions & 0 deletions docs/source/fx.rst
@@ -0,0 +1,30 @@
.. currentmodule:: torch.fx

torch.fx
=============

Overview
--------
.. automodule:: torch.fx


API Reference
-------------

.. autofunction:: torch.fx.symbolic_trace

.. autoclass:: torch.fx.GraphModule
  :members:

  .. automethod:: __init__

.. autoclass:: torch.fx.Graph
  :members:

  .. automethod:: __init__

.. autoclass:: torch.fx.Node
  :members:

.. autoclass:: torch.fx.Tracer
  :members:
1 change: 1 addition & 0 deletions docs/source/index.rst
@@ -61,6 +61,7 @@ Features described in this documentation are classified by release status:
   torch.distributions <distributions>
   torch.fft <fft>
   futures
   fx
   torch.hub <hub>
   torch.jit <jit>
   torch.linalg <linalg>
140 changes: 71 additions & 69 deletions torch/fx/__init__.py
@@ -4,84 +4,86 @@
FX is a toolkit for capturing and transforming functional PyTorch programs. It
consists of GraphModule and a corresponding intermediate representation (IR). When GraphModule is constructed
with an ``nn.Module`` instance as its argument, GraphModule will trace through the computation of that Module's
``forward`` method symbolically and record those operations in the FX intermediate representation.

.. code-block:: python

    import torch
    import torch.fx

    class MyModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.param = torch.nn.Parameter(torch.rand(3, 4))
            self.linear = torch.nn.Linear(4, 5)

        def forward(self, x):
            return torch.topk(torch.sum(self.linear(x + self.linear.weight).relu(), dim=-1), 3)

    m = MyModule()
    gm = torch.fx.symbolic_trace(m)
The Intermediate Representation centers around a 5-opcode format::

    print(gm.graph)

.. code-block:: text

    graph(x):
        %linear_weight : [#users=1] = self.linear.weight
        %add_1 : [#users=1] = call_function[target=<built-in function add>](args = (%x, %linear_weight), kwargs = {})
        %linear_1 : [#users=1] = call_module[target=linear](args = (%add_1,), kwargs = {})
        %relu_1 : [#users=1] = call_method[target=relu](args = (%linear_1,), kwargs = {})
        %sum_1 : [#users=1] = call_function[target=<built-in method sum of type object at 0x7ff2da9dc300>](args = (%relu_1,), kwargs = {dim: -1})  # noqa: B950
        %topk_1 : [#users=1] = call_function[target=<built-in method topk of type object at 0x7ff2da9dc300>](args = (%sum_1, 3), kwargs = {})  # noqa: B950
        return topk_1
The Node semantics are as follows (a short node-inspection sketch follows this list):

- ``placeholder`` represents a function input. The ``name`` attribute specifies the name this value will take on.
  ``target`` is similarly the name of the argument. ``args`` holds either: 1) nothing, or 2) a single argument
  denoting the default parameter of the function input. ``kwargs`` is don't-care. Placeholders correspond to
  the function parameters (e.g. ``x``) in the graph printout.
- ``get_attr`` retrieves a parameter from the module hierarchy. ``name`` is similarly the name the result of the
  fetch is assigned to. ``target`` is the fully-qualified name of the parameter's position in the module hierarchy.
  ``args`` and ``kwargs`` are don't-care.
- ``call_function`` applies a free function to some values. ``name`` is similarly the name of the value to assign
  to. ``target`` is the function to be applied. ``args`` and ``kwargs`` represent the arguments to the function,
  following the Python calling convention.
- ``call_module`` applies a module in the module hierarchy's ``forward()`` method to given arguments. ``name`` is
  as previous. ``target`` is the fully-qualified name of the module in the module hierarchy to call.
  ``args`` and ``kwargs`` represent the arguments to invoke the module on, *excluding the self argument*
  (the module itself is identified by ``target``).
- ``call_method`` calls a method on a value. ``name`` is as previous. ``target`` is the string name of the method
  to apply to the ``self`` argument. ``args`` and ``kwargs`` represent the arguments to invoke the method on,
  *including the self argument*.
- ``output`` contains the output of the traced function in its ``args[0]`` attribute. This corresponds to the
  "return" statement in the Graph printout.
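
Each of these opcodes is exposed directly on the ``Node`` objects. As a minimal sketch (assuming the
``gm`` traced above), the graph can be inspected by iterating its nodes and reading the ``op``, ``name``,
``target``, ``args``, and ``kwargs`` attributes:

.. code-block:: python

    # Walk the traced graph; each Node carries the opcode and call signature
    # described in the list above. Assumes `gm` from the earlier example.
    for node in gm.graph.nodes:
        print(node.op, node.name, node.target, node.args, node.kwargs)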
GraphModule automatically generates Python code for the operations it symbolically observed::

    print(gm.code)

.. code-block:: python

    import torch
    def forward(self, x):
        linear_weight = self.linear.weight
        add_1 = x + linear_weight
        x = linear_weight = None
        linear_1 = self.linear(add_1)
        add_1 = None
        relu_1 = linear_1.relu()
        linear_1 = None
        sum_1 = torch.sum(relu_1, dim = -1)
        relu_1 = None
        topk_1 = torch.topk(sum_1, 3)
        sum_1 = None
        return topk_1
        topk_1 = None

(The ``... = None`` assignments drop references to intermediate values as soon as they are no longer
needed, matching eager mode's memory behavior.)
Because this code is valid PyTorch code, the resulting ``GraphModule`` can be used in any context another
``nn.Module`` can be used, including in TorchScript tracing/compilation.
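
As a hypothetical smoke test (reusing ``m`` and ``gm`` from above; the input shape is an assumption
chosen to satisfy the broadcast against ``self.linear.weight``), the traced module can be called directly
and handed to TorchScript tracing:

.. code-block:: python

    # `gm` is callable like any nn.Module and computes the same result as `m`.
    inp = torch.rand(5, 4)
    assert torch.equal(gm(inp).values, m(inp).values)

    # The generated forward() is ordinary Python, so TorchScript tracing accepts it.
    jit_module = torch.jit.trace(gm, inp)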
'''

from .graph_module import GraphModule
