From 809bf74197fafc067dca8b57868f63eedd28dade Mon Sep 17 00:00:00 2001 From: James Reed Date: Fri, 4 Dec 2020 18:03:54 -0800 Subject: [PATCH] [FX][2/2] Make docstrings pretty when rendered [ghstack-poisoned] --- torch/fx/graph.py | 21 +-------- torch/fx/node.py | 94 ++++++++++++++++++++++++++++++-------- torch/fx/symbolic_trace.py | 72 ++++++++++++++++++++++------- 3 files changed, 132 insertions(+), 55 deletions(-) diff --git a/torch/fx/graph.py b/torch/fx/graph.py index 072aef6e3b93..ca4b8d64bb0e 100644 --- a/torch/fx/graph.py +++ b/torch/fx/graph.py @@ -148,26 +148,7 @@ def forward(self, x): %topk_1 : [#users=1] = call_function[target=](args = (%sum_1, 3), kwargs = {}) # noqa: B950 return topk_1 - The Node semantics are as follows: - - - ``placeholder`` represents a function input. The ``name`` attribute specifies the name this value will take on. - ``target`` is similarly the name of the argument. ``args`` holds either: 1) nothing, or 2) a single argument - denoting the default parameter of the function input. ``kwargs`` is don't-care. Placeholders correspond to - the function parameters (e.g. ``x``) in the graph printout. - - ``get_attr`` retrieves a parameter from the module hierarchy. ``name`` is similarly the name the result of the - fetch is assigned to. ``target`` is the fully-qualified name of the parameter's position in the module hierarchy. - ``args`` and ``kwargs`` are don't-care - - ``call_function`` applies a free function to some values. ``name`` is similarly the name of the value to assign - to. ``target`` is the function to be applied. ``args`` and ``kwargs`` represent the arguments to the function, - following the Python calling convention - - ``call_module`` applies a module in the module hierarchy's ``forward()`` method to given arguments. ``name`` is - as previous. ``target`` is the fully-qualified name of the module in the module hierarchy to call. - ``args`` and ``kwargs`` represent the arguments to invoke the module on, *including the self argument*. - - ``call_method`` calls a method on a value. ``name`` is as similar. ``target`` is the string name of the method - to apply to the ``self`` argument. ``args`` and ``kwargs`` represent the arguments to invoke the module on, - *including the self argument* - - ``output`` contains the output of the traced function in its ``args[0]`` attribute. This corresponds to the "return" statement - in the Graph printout. + For the semantics of operations represented in the ``Graph``, please see :class:`Node`. """ def __init__(self): """ diff --git a/torch/fx/node.py b/torch/fx/node.py index 1cc94be83e7e..fd8a4bc1377c 100644 --- a/torch/fx/node.py +++ b/torch/fx/node.py @@ -21,8 +21,34 @@ ]] class Node: - def __init__(self, graph: 'Graph', name: str, op: str, target: Target, - args: Tuple[Argument, ...], kwargs: Dict[str, Argument], + """ + ``Node`` is the data structure that represents individual operations within + a ``Graph``. For the most part, Nodes represent callsites to various entities, + such as operators, methods, and Modules (some exceptions include nodes that + specify function inputs and outputs). Each ``Node`` has a function specified + by its ``op`` property. The ``Node`` semantics for each value of ``op`` are as follows: + + - ``placeholder`` represents a function input. The ``name`` attribute specifies the name this value will take on. + ``target`` is similarly the name of the argument. ``args`` holds either: 1) nothing, or 2) a single argument + denoting the default parameter of the function input. ``kwargs`` is don't-care. Placeholders correspond to + the function parameters (e.g. ``x``) in the graph printout. + - ``get_attr`` retrieves a parameter from the module hierarchy. ``name`` is similarly the name the result of the + fetch is assigned to. ``target`` is the fully-qualified name of the parameter's position in the module hierarchy. + ``args`` and ``kwargs`` are don't-care + - ``call_function`` applies a free function to some values. ``name`` is similarly the name of the value to assign + to. ``target`` is the function to be applied. ``args`` and ``kwargs`` represent the arguments to the function, + following the Python calling convention + - ``call_module`` applies a module in the module hierarchy's ``forward()`` method to given arguments. ``name`` is + as previous. ``target`` is the fully-qualified name of the module in the module hierarchy to call. + ``args`` and ``kwargs`` represent the arguments to invoke the module on, *including the self argument*. + - ``call_method`` calls a method on a value. ``name`` is as similar. ``target`` is the string name of the method + to apply to the ``self`` argument. ``args`` and ``kwargs`` represent the arguments to invoke the module on, + *including the self argument* + - ``output`` contains the output of the traced function in its ``args[0]`` attribute. This corresponds to the "return" statement + in the Graph printout. + """ + def __init__(self, graph: 'Graph', name: str, op: str, target: 'Target', + args: Tuple['Argument', ...], kwargs: Dict[str, 'Argument'], type : Optional[Any] = None) -> None: self.graph = graph self.name = name # unique name of value being created @@ -60,23 +86,33 @@ def __init__(self, graph: 'Graph', name: str, op: str, target: Target, @property def next(self) -> 'Node': """ - Get the next node in the linked list + Returns the next ``Node`` in the linked list of Nodes. + + Returns: + + The next ``Node`` in the linked list of Nodes. """ return self._next @property def prev(self) -> 'Node': """ - Get the previous node in the linked list + Returns the previous ``Node`` in the linked list of Nodes. + + Returns: + + The previous ``Node`` in the linked list of Nodes. """ return self._prev - def prepend(self, x: 'Node'): - """Insert x before this node in the list of nodes in the graph. - Before: p -> self - bx -> x -> ax - After: p -> x -> self - bx -> ax + def prepend(self, x: 'Node') -> None: + """ + Insert x before this node in the list of nodes in the graph. Example:: + + Before: p -> self + bx -> x -> ax + After: p -> x -> self + bx -> ax Args: x (Node): The node to put before this node. Must be a member of the same graph. @@ -87,8 +123,9 @@ def prepend(self, x: 'Node'): p._next, x._prev = x, p x._next, self._prev = self, x - def append(self, x: 'Node'): - """Insert x after this node in the list of nodes in the graph. + def append(self, x: 'Node') -> None: + """ + Insert x after this node in the list of nodes in the graph. Equvalent to ``self.next.prepend(x)`` Args: @@ -103,9 +140,12 @@ def _remove_from_list(self): @property def args(self) -> Tuple[Argument, ...]: """ - Return the tuple of arguments to this Node. The interpretation of arguments - depends on the node's opcode. See the ``fx.Graph`` docstring for more + The tuple of arguments to this ``Node``. The interpretation of arguments + depends on the node's opcode. See the :class:`Node` docstring for more information. + + Assignment to this property is allowed. All accounting of uses and users + is updated automatically on assignment. """ return self._args @@ -121,9 +161,12 @@ def args(self, a : Tuple[Argument, ...]): @property def kwargs(self) -> Dict[str, Argument]: """ - Return the dict of kwargs to this Node. The interpretation of arguments - depends on the node's opcode. See the ``fx.Graph`` docstring for more + The dict of keyword arguments to this ``Node``. The interpretation of arguments + depends on the node's opcode. See the :class:`Node` docstring for more information. + + Assignment to this property is allowed. All accounting of uses and users + is updated automatically on assignment. """ return self._kwargs @@ -141,7 +184,12 @@ def all_input_nodes(self) -> List['Node']: """ Return all Nodes that are inputs to this Node. This is equivalent to iterating over ``args`` and ``kwargs`` and only collecting the values that - are Nodes + are Nodes. + + Returns: + + List of ``Nodes`` that appear in the ``args`` and ``kwargs`` of this + ``Node``, in that order. """ all_nodes : List['Node'] = [] map_arg(self.args, lambda n: all_nodes.append(n)) @@ -149,6 +197,9 @@ def all_input_nodes(self) -> List['Node']: return all_nodes def _update_args_kwargs(self, new_args : Tuple[Argument, ...], new_kwargs : Dict[str, Argument]): + """ + This API is internal. Do *not* call it directly. + """ self._args = new_args self._kwargs = new_kwargs @@ -168,7 +219,14 @@ def __repr__(self) -> str: def replace_all_uses_with(self, replace_with : 'Node') -> List['Node']: """ Replace all uses of ``self`` in the Graph with the Node ``replace_with``. - Returns the list of nodes on which this change was made. + + Args: + + replace_with (Node): The node to replace all uses of ``self`` with. + + Returns: + + The list of Nodes on which this change was made. """ to_process = list(self.users) for use_node in to_process: diff --git a/torch/fx/symbolic_trace.py b/torch/fx/symbolic_trace.py index fe59c2a11e17..874f4b1bf89c 100644 --- a/torch/fx/symbolic_trace.py +++ b/torch/fx/symbolic_trace.py @@ -1,6 +1,6 @@ import inspect from types import CodeType, FunctionType -from typing import Any, Dict, Optional, List, Callable, Union +from typing import Any, Dict, Optional, Tuple, List, Callable, Union import torch from .node import Argument @@ -50,21 +50,31 @@ class Tracer(TracerBase): def __init__(self): super().__init__() - def create_arg(self, a: Any) -> Argument: + def create_arg(self, a: Any) -> 'Argument': """ A method to specify the behavior of tracing when preparing values to be used as arguments to nodes in the ``Graph``. By default, the behavior includes: - - Iterate through collection types (e.g. tuple, list, dict) and recursively - call ``create_args`` on the elements. - - Given a Proxy object, return a reference to the underlying IR ``Node`` - - Given a non-Proxy Tensor object, emit IR for various cases: - - For a Parameter, emit a ``get_attr`` node referring to that Parameter - - For a non-Parameter Tensor, store the Tensor away in a special - attribute referring to that attribute. + #. Iterate through collection types (e.g. tuple, list, dict) and recursively + call ``create_args`` on the elements. + #. Given a Proxy object, return a reference to the underlying IR ``Node`` + #. Given a non-Proxy Tensor object, emit IR for various cases: + + * For a Parameter, emit a ``get_attr`` node referring to that Parameter + * For a non-Parameter Tensor, store the Tensor away in a special + attribute referring to that attribute. This method can be overridden to support more types. + + Args: + + a (Any): The value to be emitted as an ``Argument`` in the ``Graph``. + + + Returns: + + The value ``a`` converted into the appropriate ``Argument`` """ # The base tracer is used to construct Graphs when there is no associated # module hierarchy, so it can never create parameter references. @@ -114,28 +124,32 @@ def is_leaf_module(self, m: torch.nn.Module, module_qualified_name : str) -> boo their constituent ops are recorded, unless specified otherwise via this parameter. - Args - m - The module itself - module_qualified_name - The path to root of this module. For example, - if you have a module hierarchy where submodule ``foo`` contains - submodule ``bar``, which contains submodule ``baz``, that module will - appear with the qualified name ``foo.bar.baz`` here. + Args: + m (Module): The module being queried about + module_qualified_name (str): The path to root of this module. For example, + if you have a module hierarchy where submodule ``foo`` contains + submodule ``bar``, which contains submodule ``baz``, that module will + appear with the qualified name ``foo.bar.baz`` here. """ return m.__module__.startswith('torch.nn') and not isinstance(m, torch.nn.Sequential) - def path_of_module(self, mod) -> str: + def path_of_module(self, mod : torch.nn.Module) -> str: """ Helper method to find the qualified name of ``mod`` in the Module hierarchy of ``root``. For example, if ``root`` has a submodule named ``foo``, which has a submodule named ``bar``, passing ``bar`` into this function will return the string "foo.bar". + + Args: + + mod (str): The ``Module`` to retrieve the qualified name for. """ for n, p in self.root.named_modules(): if mod is p: return n raise NameError('module is not installed as a submodule') - def call_module(self, m: torch.nn.Module, forward: Callable[..., Any], args, kwargs): + def call_module(self, m: torch.nn.Module, forward: Callable[..., Any], args : Tuple[Any, ...], kwargs : Dict[str, Any]) -> Any: """ Method that specifies the behavior of this ``Tracer`` when it encounters a call to an ``nn.Module`` instance. @@ -148,6 +162,20 @@ def call_module(self, m: torch.nn.Module, forward: Callable[..., Any], args, kwa This method can be overridden to--for example--create nested traced GraphModules, or any other behavior you would want while tracing across ``Module`` boundaries. + ``Module`` boundaries. + + Args: + + m (Module): The module for which a call is being emitted + forward (Callable): The forward() method of the ``Module`` to be invoked + args (Tuple): args of the module callsite + kwargs (Dict): kwargs of the module callsite + + Return: + + The return value from the Module call. In the case that a ``call_module`` + node was emitted, this is a ``Proxy`` value. Otherwise, it is whatever + value was returned from the ``Module`` invocation. """ module_qualified_name = self.path_of_module(m) if not self.is_leaf_module(m, module_qualified_name): @@ -204,6 +232,16 @@ def trace(self, root: Union[torch.nn.Module, Callable]) -> Graph: """ Trace ``root`` and return the corresponding FX ``Graph`` representation. ``root`` can either be an ``nn.Module`` instance or a Python callable. + + + Args: + + root (Union[Module, Callable]): Either a ``Module`` or a function to be + traced through. + + Returns: + + A ``Graph`` representing the semantics of the passed-in ``root``. """ if isinstance(root, torch.nn.Module): self.root = root