Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FX][2/2] Make docstrings pretty when rendered #48871

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
21 changes: 1 addition & 20 deletions torch/fx/graph.py
Expand Up @@ -148,26 +148,7 @@ def forward(self, x):
%topk_1 : [#users=1] = call_function[target=<built-in method topk of type object at 0x7ff2da9dc300>](args = (%sum_1, 3), kwargs = {}) # noqa: B950
return topk_1

The Node semantics are as follows:

- ``placeholder`` represents a function input. The ``name`` attribute specifies the name this value will take on.
``target`` is similarly the name of the argument. ``args`` holds either: 1) nothing, or 2) a single argument
denoting the default parameter of the function input. ``kwargs`` is don't-care. Placeholders correspond to
the function parameters (e.g. ``x``) in the graph printout.
- ``get_attr`` retrieves a parameter from the module hierarchy. ``name`` is similarly the name the result of the
fetch is assigned to. ``target`` is the fully-qualified name of the parameter's position in the module hierarchy.
``args`` and ``kwargs`` are don't-care
- ``call_function`` applies a free function to some values. ``name`` is similarly the name of the value to assign
to. ``target`` is the function to be applied. ``args`` and ``kwargs`` represent the arguments to the function,
following the Python calling convention
- ``call_module`` applies a module in the module hierarchy's ``forward()`` method to given arguments. ``name`` is
as previous. ``target`` is the fully-qualified name of the module in the module hierarchy to call.
``args`` and ``kwargs`` represent the arguments to invoke the module on, *including the self argument*.
- ``call_method`` calls a method on a value. ``name`` is as similar. ``target`` is the string name of the method
to apply to the ``self`` argument. ``args`` and ``kwargs`` represent the arguments to invoke the module on,
*including the self argument*
- ``output`` contains the output of the traced function in its ``args[0]`` attribute. This corresponds to the "return" statement
in the Graph printout.
For the semantics of operations represented in the ``Graph``, please see :class:`Node`.
"""
def __init__(self):
"""
Expand Down
94 changes: 76 additions & 18 deletions torch/fx/node.py
Expand Up @@ -21,8 +21,34 @@
]]

class Node:
def __init__(self, graph: 'Graph', name: str, op: str, target: Target,
args: Tuple[Argument, ...], kwargs: Dict[str, Argument],
"""
``Node`` is the data structure that represents individual operations within
a ``Graph``. For the most part, Nodes represent callsites to various entities,
such as operators, methods, and Modules (some exceptions include nodes that
specify function inputs and outputs). Each ``Node`` has a function specified
by its ``op`` property. The ``Node`` semantics for each value of ``op`` are as follows:

- ``placeholder`` represents a function input. The ``name`` attribute specifies the name this value will take on.
``target`` is similarly the name of the argument. ``args`` holds either: 1) nothing, or 2) a single argument
denoting the default parameter of the function input. ``kwargs`` is don't-care. Placeholders correspond to
the function parameters (e.g. ``x``) in the graph printout.
- ``get_attr`` retrieves a parameter from the module hierarchy. ``name`` is similarly the name the result of the
fetch is assigned to. ``target`` is the fully-qualified name of the parameter's position in the module hierarchy.
``args`` and ``kwargs`` are don't-care
- ``call_function`` applies a free function to some values. ``name`` is similarly the name of the value to assign
to. ``target`` is the function to be applied. ``args`` and ``kwargs`` represent the arguments to the function,
following the Python calling convention
- ``call_module`` applies a module in the module hierarchy's ``forward()`` method to given arguments. ``name`` is
as previous. ``target`` is the fully-qualified name of the module in the module hierarchy to call.
``args`` and ``kwargs`` represent the arguments to invoke the module on, *including the self argument*.
- ``call_method`` calls a method on a value. ``name`` is as similar. ``target`` is the string name of the method
to apply to the ``self`` argument. ``args`` and ``kwargs`` represent the arguments to invoke the module on,
*including the self argument*
- ``output`` contains the output of the traced function in its ``args[0]`` attribute. This corresponds to the "return" statement
in the Graph printout.
"""
def __init__(self, graph: 'Graph', name: str, op: str, target: 'Target',
args: Tuple['Argument', ...], kwargs: Dict[str, 'Argument'],
type : Optional[Any] = None) -> None:
self.graph = graph
self.name = name # unique name of value being created
Expand Down Expand Up @@ -60,23 +86,33 @@ def __init__(self, graph: 'Graph', name: str, op: str, target: Target,
@property
def next(self) -> 'Node':
"""
Get the next node in the linked list
Returns the next ``Node`` in the linked list of Nodes.

Returns:

The next ``Node`` in the linked list of Nodes.
"""
return self._next

@property
def prev(self) -> 'Node':
"""
Get the previous node in the linked list
Returns the previous ``Node`` in the linked list of Nodes.

Returns:

The previous ``Node`` in the linked list of Nodes.
"""
return self._prev

def prepend(self, x: 'Node'):
"""Insert x before this node in the list of nodes in the graph.
Before: p -> self
bx -> x -> ax
After: p -> x -> self
bx -> ax
def prepend(self, x: 'Node') -> None:
"""
Insert x before this node in the list of nodes in the graph. Example::

Before: p -> self
bx -> x -> ax
After: p -> x -> self
bx -> ax

Args:
x (Node): The node to put before this node. Must be a member of the same graph.
Expand All @@ -87,8 +123,9 @@ def prepend(self, x: 'Node'):
p._next, x._prev = x, p
x._next, self._prev = self, x

def append(self, x: 'Node'):
"""Insert x after this node in the list of nodes in the graph.
def append(self, x: 'Node') -> None:
"""
Insert x after this node in the list of nodes in the graph.
Equvalent to ``self.next.prepend(x)``

Args:
Expand All @@ -103,9 +140,12 @@ def _remove_from_list(self):
@property
def args(self) -> Tuple[Argument, ...]:
"""
Return the tuple of arguments to this Node. The interpretation of arguments
depends on the node's opcode. See the ``fx.Graph`` docstring for more
The tuple of arguments to this ``Node``. The interpretation of arguments
depends on the node's opcode. See the :class:`Node` docstring for more
information.

Assignment to this property is allowed. All accounting of uses and users
is updated automatically on assignment.
"""
return self._args

Expand All @@ -121,9 +161,12 @@ def args(self, a : Tuple[Argument, ...]):
@property
def kwargs(self) -> Dict[str, Argument]:
"""
Return the dict of kwargs to this Node. The interpretation of arguments
depends on the node's opcode. See the ``fx.Graph`` docstring for more
The dict of keyword arguments to this ``Node``. The interpretation of arguments
depends on the node's opcode. See the :class:`Node` docstring for more
information.

Assignment to this property is allowed. All accounting of uses and users
is updated automatically on assignment.
"""
return self._kwargs

Expand All @@ -141,14 +184,22 @@ def all_input_nodes(self) -> List['Node']:
"""
Return all Nodes that are inputs to this Node. This is equivalent to
iterating over ``args`` and ``kwargs`` and only collecting the values that
are Nodes
are Nodes.

Returns:

List of ``Nodes`` that appear in the ``args`` and ``kwargs`` of this
``Node``, in that order.
"""
all_nodes : List['Node'] = []
map_arg(self.args, lambda n: all_nodes.append(n))
map_arg(self.kwargs, lambda n: all_nodes.append(n))
return all_nodes

def _update_args_kwargs(self, new_args : Tuple[Argument, ...], new_kwargs : Dict[str, Argument]):
"""
This API is internal. Do *not* call it directly.
"""
self._args = new_args
self._kwargs = new_kwargs

Expand All @@ -168,7 +219,14 @@ def __repr__(self) -> str:
def replace_all_uses_with(self, replace_with : 'Node') -> List['Node']:
"""
Replace all uses of ``self`` in the Graph with the Node ``replace_with``.
Returns the list of nodes on which this change was made.

Args:

replace_with (Node): The node to replace all uses of ``self`` with.

Returns:

The list of Nodes on which this change was made.
"""
to_process = list(self.users)
for use_node in to_process:
Expand Down
72 changes: 55 additions & 17 deletions torch/fx/symbolic_trace.py
@@ -1,6 +1,6 @@
import inspect
from types import CodeType, FunctionType
from typing import Any, Dict, Optional, List, Callable, Union
from typing import Any, Dict, Optional, Tuple, List, Callable, Union
import torch
from torch._C import ScriptObject # type: ignore

Expand Down Expand Up @@ -51,21 +51,31 @@ class Tracer(TracerBase):
def __init__(self):
super().__init__()

def create_arg(self, a: Any) -> Argument:
def create_arg(self, a: Any) -> 'Argument':
"""
A method to specify the behavior of tracing when preparing values to
be used as arguments to nodes in the ``Graph``.

By default, the behavior includes:
- Iterate through collection types (e.g. tuple, list, dict) and recursively
call ``create_args`` on the elements.
- Given a Proxy object, return a reference to the underlying IR ``Node``
- Given a non-Proxy Tensor object, emit IR for various cases:
- For a Parameter, emit a ``get_attr`` node referring to that Parameter
- For a non-Parameter Tensor, store the Tensor away in a special
attribute referring to that attribute.

#. Iterate through collection types (e.g. tuple, list, dict) and recursively
call ``create_args`` on the elements.
#. Given a Proxy object, return a reference to the underlying IR ``Node``
#. Given a non-Proxy Tensor object, emit IR for various cases:

* For a Parameter, emit a ``get_attr`` node referring to that Parameter
* For a non-Parameter Tensor, store the Tensor away in a special
attribute referring to that attribute.
This method can be overridden to support more types.

Args:

a (Any): The value to be emitted as an ``Argument`` in the ``Graph``.


Returns:

The value ``a`` converted into the appropriate ``Argument``
"""
# The base tracer is used to construct Graphs when there is no associated
# module hierarchy, so it can never create parameter references.
Expand Down Expand Up @@ -115,28 +125,32 @@ def is_leaf_module(self, m: torch.nn.Module, module_qualified_name : str) -> boo
their constituent ops are recorded, unless specified otherwise
via this parameter.

Args
m - The module itself
module_qualified_name - The path to root of this module. For example,
if you have a module hierarchy where submodule ``foo`` contains
submodule ``bar``, which contains submodule ``baz``, that module will
appear with the qualified name ``foo.bar.baz`` here.
Args:
m (Module): The module being queried about
module_qualified_name (str): The path to root of this module. For example,
if you have a module hierarchy where submodule ``foo`` contains
submodule ``bar``, which contains submodule ``baz``, that module will
appear with the qualified name ``foo.bar.baz`` here.
"""
return m.__module__.startswith('torch.nn') and not isinstance(m, torch.nn.Sequential)

def path_of_module(self, mod) -> str:
def path_of_module(self, mod : torch.nn.Module) -> str:
"""
Helper method to find the qualified name of ``mod`` in the Module hierarchy
of ``root``. For example, if ``root`` has a submodule named ``foo``, which has
a submodule named ``bar``, passing ``bar`` into this function will return
the string "foo.bar".

Args:

mod (str): The ``Module`` to retrieve the qualified name for.
"""
for n, p in self.root.named_modules():
if mod is p:
return n
raise NameError('module is not installed as a submodule')

def call_module(self, m: torch.nn.Module, forward: Callable[..., Any], args, kwargs):
def call_module(self, m: torch.nn.Module, forward: Callable[..., Any], args : Tuple[Any, ...], kwargs : Dict[str, Any]) -> Any:
"""
Method that specifies the behavior of this ``Tracer`` when it encounters
a call to an ``nn.Module`` instance.
Expand All @@ -149,6 +163,20 @@ def call_module(self, m: torch.nn.Module, forward: Callable[..., Any], args, kwa
This method can be overridden to--for example--create nested traced
GraphModules, or any other behavior you would want while tracing across
``Module`` boundaries.
``Module`` boundaries.

Args:

m (Module): The module for which a call is being emitted
forward (Callable): The forward() method of the ``Module`` to be invoked
args (Tuple): args of the module callsite
kwargs (Dict): kwargs of the module callsite

Return:

The return value from the Module call. In the case that a ``call_module``
node was emitted, this is a ``Proxy`` value. Otherwise, it is whatever
value was returned from the ``Module`` invocation.
"""
module_qualified_name = self.path_of_module(m)
if not self.is_leaf_module(m, module_qualified_name):
Expand Down Expand Up @@ -205,6 +233,16 @@ def trace(self, root: Union[torch.nn.Module, Callable]) -> Graph:
"""
Trace ``root`` and return the corresponding FX ``Graph`` representation. ``root``
can either be an ``nn.Module`` instance or a Python callable.


Args:

root (Union[Module, Callable]): Either a ``Module`` or a function to be
traced through.

Returns:

A ``Graph`` representing the semantics of the passed-in ``root``.
"""
if isinstance(root, torch.nn.Module):
self.root = root
Expand Down