diff --git a/docs/source/fx.rst b/docs/source/fx.rst new file mode 100644 index 000000000000..21c4268ecda3 --- /dev/null +++ b/docs/source/fx.rst @@ -0,0 +1,30 @@ +.. currentmodule:: torch.fx + +torch.fx +============= + +Overview +-------- +.. automodule:: torch.fx + + +API Reference +------------- + +.. autofunction:: torch.fx.symbolic_trace + +.. autoclass:: torch.fx.GraphModule + :members: + + .. automethod:: __init__ + +.. autoclass:: torch.fx.Graph + :members: + + .. automethod:: __init__ + +.. autoclass:: torch.fx.Node + :members: + +.. autoclass:: torch.fx.Tracer + :members: diff --git a/docs/source/index.rst b/docs/source/index.rst index a8893d956a3d..3cbe8fc07178 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -61,6 +61,7 @@ Features described in this documentation are classified by release status: torch.distributions torch.fft futures + fx torch.hub torch.jit torch.linalg diff --git a/torch/fx/__init__.py b/torch/fx/__init__.py index f3804c515612..7a3eb03de1ef 100644 --- a/torch/fx/__init__.py +++ b/torch/fx/__init__.py @@ -4,84 +4,86 @@ FX is a toolkit for capturing and transforming functional PyTorch programs. It consists of GraphModule and a corresponding intermediate representation (IR). When GraphModule is constructed -with an `nn.Module` instance as its argument, GraphModule will trace through the computation of that Module's -`forward` method symbolically and record those operations in the FX intermediate representation. +with an ``nn.Module`` instance as its argument, GraphModule will trace through the computation of that Module's +``forward`` method symbolically and record those operations in the FX intermediate representation. -``` -import torch -from torch.fx import symbolic_trace +.. 
code-block:: python -class MyModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.param = torch.nn.Parameter(torch.rand(3, 4)) - self.linear = torch.nn.Linear(4, 5) + import torch + import torch.fx - def forward(self, x): - return torch.topk(torch.sum(self.linear(x + self.linear.weight).relu(), dim=-1), 3) - -m = MyModule() -gm = symbolic_trace(m) -``` - -The Intermediate Representation centers around a 5-opcode format: - -``` -print(gm.graph) -``` - -``` -graph(x): - %linear_weight : [uses=1] = self.linear.weight - %add_1 : [uses=1] = call_function[target=](args = (%x, %linear_weight), kwargs = {}) - %linear_1 : [uses=1] = call_module[target=linear](args = (%add_1,), kwargs = {}) - %relu_1 : [uses=1] = call_method[target=relu](args = (%linear_1,), kwargs = {}) - %sum_1 : [uses=1] = call_function[target=](args = (%relu_1,), kwargs = {dim: -1}) # noqa: B950 - %topk_1 : [uses=1] = call_function[target=](args = (%sum_1, 3), kwargs = {}) # noqa: B950 - return topk_1 -``` - -The semantics are as follows: - -- `placeholder` represents a function input. The `name` attribute specifies the name this value will take on. - `target` is similarly the name of the argument. `args` holds either: 1) nothing, or 2) a single argument - denoting the default parameter of the function input. `kwargs` is don't-care. Placeholders correspond to - the function parameters (e.g. `x`) in the graph printout. -- `get_attr` retrieves a parameter from the module hierarchy. `name` is similarly the name the result of the - fetch is assigned to. `target` is the fully-qualified name of the parameter's position in the module hierarchy. - `args` and `kwargs` are don't-care -- `call_function` applies a free function to some values. `name` is similarly the name of the value to assign - to. `target` is the function to be applied. 
`args` and `kwargs` represent the arguments to the function, + class MyModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.param = torch.nn.Parameter(torch.rand(3, 4)) + self.linear = torch.nn.Linear(4, 5) + + def forward(self, x): + return torch.topk(torch.sum(self.linear(x + self.linear.weight).relu(), dim=-1), 3) + + m = MyModule() + gm = torch.fx.symbolic_trace(m) + +The Intermediate Representation centers around a 5-opcode format:: + + print(gm.graph) + +.. code-block:: text + + graph(x): + %linear_weight : [#users=1] = self.linear.weight + %add_1 : [#users=1] = call_function[target=](args = (%x, %linear_weight), kwargs = {}) + %linear_1 : [#users=1] = call_module[target=linear](args = (%add_1,), kwargs = {}) + %relu_1 : [#users=1] = call_method[target=relu](args = (%linear_1,), kwargs = {}) + %sum_1 : [#users=1] = call_function[target=](args = (%relu_1,), kwargs = {dim: -1}) # noqa: B950 + %topk_1 : [#users=1] = call_function[target=](args = (%sum_1, 3), kwargs = {}) # noqa: B950 + return topk_1 + +The Node semantics are as follows: + +- ``placeholder`` represents a function input. The ``name`` attribute specifies the name this value will take on. + ``target`` is similarly the name of the argument. ``args`` holds either: 1) nothing, or 2) a single argument + denoting the default parameter of the function input. ``kwargs`` is don't-care. Placeholders correspond to + the function parameters (e.g. ``x``) in the graph printout. +- ``get_attr`` retrieves a parameter from the module hierarchy. ``name`` is similarly the name the result of the + fetch is assigned to. ``target`` is the fully-qualified name of the parameter's position in the module hierarchy. + ``args`` and ``kwargs`` are don't-care +- ``call_function`` applies a free function to some values. ``name`` is similarly the name of the value to assign + to. ``target`` is the function to be applied. 
``args`` and ``kwargs`` represent the arguments to the function, following the Python calling convention -- `call_module` applies a module in the module hierarchy's `forward()` method to given arguments. `name` is - as previous. `target` is the fully-qualified name of the module in the module hierarchy to call. - `args` and `kwargs` represent the arguments to invoke the module on, _including the self argument_. -- `call_method` calls a method on a value. `name` is as similar. `target` is the string name of the method - to apply to the `self` argument. `args` and `kwargs` represent the arguments to invoke the module on, - _including the self argument_. -- `output` contains the output of the traced function in its `args[0]` attribute. This corresponds to the "return" statement +- ``call_module`` applies a module in the module hierarchy's ``forward()`` method to given arguments. ``name`` is + as previous. ``target`` is the fully-qualified name of the module in the module hierarchy to call. + ``args`` and ``kwargs`` represent the arguments to invoke the module on, *including the self argument*. +- ``call_method`` calls a method on a value. ``name`` is as similar. ``target`` is the string name of the method + to apply to the ``self`` argument. ``args`` and ``kwargs`` represent the arguments to invoke the module on, + *including the self argument* +- ``output`` contains the output of the traced function in its ``args[0]`` attribute. This corresponds to the "return" statement in the Graph printout. -GraphModule automatically generates Python code for the operations it symbolically observed: +GraphModule automatically generates Python code for the operations it symbolically observed:: -``` -print(gm.code) -``` + print(gm.code) -``` -def forward(self, x): - linear_weight = self.linear.weight - add_1 = x + linear_weight - linear_1 = self.linear(add_1) - relu_1 = linear_1.relu() - sum_1 = torch.sum(relu_1, dim = -1) - topk_1 = torch.topk(sum_1, 3) - return topk_1 +.. 
code-block:: python -``` - -Because this code is valid PyTorch code, the resulting `GraphModule` can be used in any context another -`nn.Module` can be used, including in TorchScript tracing/compilation. + import torch + def forward(self, x): + linear_weight = self.linear.weight + add_1 = x + linear_weight + x = linear_weight = None + linear_1 = self.linear(add_1) + add_1 = None + relu_1 = linear_1.relu() + linear_1 = None + sum_1 = torch.sum(relu_1, dim = -1) + relu_1 = None + topk_1 = torch.topk(sum_1, 3) + sum_1 = None + return topk_1 + topk_1 = None + +Because this code is valid PyTorch code, the resulting ``GraphModule`` can be used in any context another +``nn.Module`` can be used, including in TorchScript tracing/compilation. ''' from .graph_module import GraphModule diff --git a/torch/fx/graph.py b/torch/fx/graph.py index 65db80f8d919..072aef6e3b93 100644 --- a/torch/fx/graph.py +++ b/torch/fx/graph.py @@ -62,9 +62,9 @@ def _type_repr(obj): typically enough to uniquely identify a type. For everything else, we fall back on repr(obj). """ - # HACK: In Python 3.6, type aliases from `typing` are instances of `type`, but in - # later Python versions, type aliases are not instances of `type`!! We want - # all type aliases to fall through to `repr`, so if we have a type that is + # HACK: In Python 3.6, type aliases from ``typing`` are instances of ``type``, but in + # later Python versions, type aliases are not instances of ``type``!! We want + # all type aliases to fall through to ``repr``, so if we have a type that is # in the module typing, don't go down this path. if isinstance(obj, type) and obj.__module__ != 'typing': if obj.__module__ == 'builtins': @@ -109,67 +109,65 @@ def __reversed__(self): class Graph: """ - `Graph` is the main data structure used in the FX Intermediate Representation. - It consists of a series of `Node`s, each representing callsites (or other - syntactic constructs). 
The list of `Node`s, taken together, constitute a + ``Graph`` is the main data structure used in the FX Intermediate Representation. + It consists of a series of ``Node`` s, each representing callsites (or other + syntactic constructs). The list of ``Node`` s, taken together, constitute a valid Python function. For example, the following code - ``` - import torch - from torch.fx import symbolic_trace + .. code-block:: python - class MyModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.param = torch.nn.Parameter(torch.rand(3, 4)) - self.linear = torch.nn.Linear(4, 5) + import torch + import torch.fx - def forward(self, x): - return torch.topk(torch.sum(self.linear(x + self.linear.weight).relu(), dim=-1), 3) + class MyModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.param = torch.nn.Parameter(torch.rand(3, 4)) + self.linear = torch.nn.Linear(4, 5) - m = MyModule() - gm = symbolic_trace(m) - ``` + def forward(self, x): + return torch.topk(torch.sum(self.linear(x + self.linear.weight).relu(), dim=-1), 3) - Will produce the following Graph: + m = MyModule() + gm = torch.fx.symbolic_trace(m) - ``` - print(gm.graph) - ``` + Will produce the following Graph:: - ``` - graph(x): - %linear_weight : [uses=1] = self.linear.weight - %add_1 : [uses=1] = call_function[target=](args = (%x, %linear_weight), kwargs = {}) - %linear_1 : [uses=1] = call_module[target=linear](args = (%add_1,), kwargs = {}) - %relu_1 : [uses=1] = call_method[target=relu](args = (%linear_1,), kwargs = {}) - %sum_1 : [uses=1] = call_function[target=](args = (%relu_1,), kwargs = {dim: -1}) # noqa: B950 - %topk_1 : [uses=1] = call_function[target=](args = (%sum_1, 3), kwargs = {}) # noqa: B950 - return topk_1 - ``` + print(gm.graph) + + .. 
code-block:: text + + graph(x): + %linear_weight : [#users=1] = self.linear.weight + %add_1 : [#users=1] = call_function[target=](args = (%x, %linear_weight), kwargs = {}) + %linear_1 : [#users=1] = call_module[target=linear](args = (%add_1,), kwargs = {}) + %relu_1 : [#users=1] = call_method[target=relu](args = (%linear_1,), kwargs = {}) + %sum_1 : [#users=1] = call_function[target=](args = (%relu_1,), kwargs = {dim: -1}) # noqa: B950 + %topk_1 : [#users=1] = call_function[target=](args = (%sum_1, 3), kwargs = {}) # noqa: B950 + return topk_1 The Node semantics are as follows: - - `placeholder` represents a function input. The `name` attribute specifies the name this value will take on. - `target` is similarly the name of the argument. `args` holds either: 1) nothing, or 2) a single argument - denoting the default parameter of the function input. `kwargs` is don't-care. Placeholders correspond to - the function parameters (e.g. `x`) in the graph printout. - - `get_attr` retrieves a parameter from the module hierarchy. `name` is similarly the name the result of the - fetch is assigned to. `target` is the fully-qualified name of the parameter's position in the module hierarchy. - `args` and `kwargs` are don't-care - - `call_function` applies a free function to some values. `name` is similarly the name of the value to assign - to. `target` is the function to be applied. `args` and `kwargs` represent the arguments to the function, - following the Python calling convention - - `call_module` applies a module in the module hierarchy's `forward()` method to given arguments. `name` is - as previous. `target` is the fully-qualified name of the module in the module hierarchy to call. - `args` and `kwargs` represent the arguments to invoke the module on, _including the self argument_. - - `call_method` calls a method on a value. `name` is as similar. `target` is the string name of the method - to apply to the `self` argument. 
`args` and `kwargs` represent the arguments to invoke the module on, - _including the self argument_. - - `output` contains the output of the traced function in its `args[0]` attribute. This corresponds to the "return" statement - in the Graph printout. + - ``placeholder`` represents a function input. The ``name`` attribute specifies the name this value will take on. + ``target`` is similarly the name of the argument. ``args`` holds either: 1) nothing, or 2) a single argument + denoting the default parameter of the function input. ``kwargs`` is don't-care. Placeholders correspond to + the function parameters (e.g. ``x``) in the graph printout. + - ``get_attr`` retrieves a parameter from the module hierarchy. ``name`` is similarly the name the result of the + fetch is assigned to. ``target`` is the fully-qualified name of the parameter's position in the module hierarchy. + ``args`` and ``kwargs`` are don't-care + - ``call_function`` applies a free function to some values. ``name`` is similarly the name of the value to assign + to. ``target`` is the function to be applied. ``args`` and ``kwargs`` represent the arguments to the function, + following the Python calling convention + - ``call_module`` applies a module in the module hierarchy's ``forward()`` method to given arguments. ``name`` is + as previous. ``target`` is the fully-qualified name of the module in the module hierarchy to call. + ``args`` and ``kwargs`` represent the arguments to invoke the module on, *including the self argument*. + - ``call_method`` calls a method on a value. ``name`` is as similar. ``target`` is the string name of the method + to apply to the ``self`` argument. ``args`` and ``kwargs`` represent the arguments to invoke the module on, + *including the self argument* + - ``output`` contains the output of the traced function in its ``args[0]`` attribute. This corresponds to the "return" statement + in the Graph printout. 
""" def __init__(self): """ @@ -183,19 +181,34 @@ def __init__(self): @property def nodes(self) -> _node_list: """ - Get the list of `Node`s that constitute this Graph. + Get the list of Nodes that constitute this Graph. - Note that this `Node` list representation is a doubly-linked list. Mutations + Note that this ``Node`` list representation is a doubly-linked list. Mutations during iteration (e.g. delete a Node, add a Node) are safe. + + Returns: + + A doubly-linked list of Nodes. Note that ``reversed`` can be called on + this list to switch iteration order. """ return _node_list(self) - def graph_copy(self, g : 'Graph', val_map : Dict[Node, Node]) -> Optional[Argument]: + def graph_copy(self, g : 'Graph', val_map : Dict[Node, Node]) -> 'Optional[Argument]': """ - Append all nodes from graph `g` to this graph. `val_map` should be a dictionary - that maps nodes in `g` to nodes in `self. `val_map` will be populated with more - items by this function. Returns the equivalent output value of `g` with - Nodes switched to refer to nodes in `self`. + Copy all nodes from a given graph into ``self``. + + Args: + + g (Graph): The source graph from which to copy Nodes. + + val_map (Dict[Node, Node]): a dictionary that will be populated with a mapping + from nodes in ``g`` to nodes in ``self``. Note that ``val_map`` can be passed + in with values in it already to override copying of certain values. + + Returns: + + The value in ``self`` that is now equivalent to the output value in ``g``, + if ``g`` had an ``output`` node. ``None`` otherwise. 
""" for node in g.nodes: if node in val_map: @@ -220,25 +233,35 @@ def __deepcopy__(self, memo=None) -> 'Graph': g.output(output_val) return g - def create_node(self, op: str, target: Target, - args: Optional[Tuple[Argument, ...]] = None, - kwargs: Optional[Dict[str, Argument]] = None, + def create_node(self, op: str, target: 'Target', + args: Optional[Tuple['Argument', ...]] = None, + kwargs: Optional[Dict[str, 'Argument']] = None, name: Optional[str] = None, type_expr: Optional[Any] = None) -> Node: """ - Create a `Node` and add it to the `Graph` at the current insert-point. - Note that the current insert-point can be set via `Graph.inserting_before` - and `Graph.inserting_after`. + Create a ``Node`` and add it to the ``Graph`` at the current insert-point. + Note that the current insert-point can be set via :meth:`Graph.inserting_before` + and :meth:`Graph.inserting_after`. + + Args: + op (str): the opcode for this Node. One of 'call_function', 'call_method', 'get_attr', + 'call_module', 'placeholder', or 'output'. The semantics of these opcodes are + described in the ``Graph`` docstring. + + args (Optional[Tuple[Argument, ...]]): is a tuple of arguments to this node. + + kwargs (Optional[Dict[str, Argument]]): the kwargs of this Node + + name (Optional[str]): an optional string name for the ``Node``. + This will influence the name of the value assigned to in the + Python generated code. + + type_expr (Optional[Any]): an optional type annotation representing the + Python type the output of this node will have. + + Returns: - - op is the opcode for this Node. One of 'call_function', 'call_method', 'get_attr', - 'call_module', 'placeholder', or 'output'. The semantics of these opcodes are - described in the `Graph` docstring. - - args is a tuple of arguments to this node. - - kwargs is a dict from string to argument, representing the kwargs of this Node - - name is an optional string name for the `Node`. 
This will influence the name - of the value assigned to in the Python generated code. - - type_expr is an optional type annotation representing the Python type - the output of this node will have. + The newly-created and inserted node. """ assert op in ('call_function', 'call_method', 'get_attr', 'call_module', 'placeholder', 'output') args = () if args is None else args @@ -249,10 +272,14 @@ def create_node(self, op: str, target: Target, self._len += 1 return n - def erase_node(self, to_erase : Node): + def erase_node(self, to_erase : Node) -> None: """ - Erases the node `to_erase` from the `Graph`. Throws an exception if - there are still users of that node in the `Graph`. + Erases a ``Node`` from the ``Graph``. Throws an exception if + there are still users of that node in the ``Graph``. + + Args: + + to_erase (Node): The ``Node`` to erase from the ``Graph``. """ if len(to_erase.users) > 0: raise RuntimeError(f'Tried to erase Node {to_erase} but it still had {len(to_erase.users)} ' @@ -263,7 +290,7 @@ def erase_node(self, to_erase : Node): self._len -= 1 # Null out this Node's argument nodes so that the Nodes referred to - # can update their `users` accordingly + # can update their ``users`` accordingly new_args = map_arg(to_erase.args, lambda n: None) assert isinstance(new_args, tuple) to_erase.args = new_args @@ -274,7 +301,7 @@ def erase_node(self, to_erase : Node): def inserting_before(self, n: Optional[Node] = None): """Set the point at which create_node and companion methods will insert into the graph. When used within a 'with' statement, this will temporary set the insert point and - then restore it when the with statement exits: + then restore it when the with statement exits:: with g.inserting_before(n): ... # inserting before node n @@ -286,7 +313,7 @@ def inserting_before(self, n: Optional[Node] = None): the beginning of the entire graph. Returns: - A resource manager that will restore the insert point on `__exit__`. 
+ A resource manager that will restore the insert point on ``__exit__``. """ if n is None: return self.inserting_after(self._root) @@ -296,7 +323,7 @@ def inserting_before(self, n: Optional[Node] = None): def inserting_after(self, n: Optional[Node] = None): """Set the point at which create_node and companion methods will insert into the graph. When used within a 'with' statement, this will temporary set the insert point and - then restore it when the with statement exits: + then restore it when the with statement exits:: with g.inserting_after(n): ... # inserting after node n @@ -308,7 +335,7 @@ def inserting_after(self, n: Optional[Node] = None): the beginning of the entire graph. Returns: - A resource manager that will restore the insert point on `__exit__`. + A resource manager that will restore the insert point on ``__exit__``. """ if n is None: return self.inserting_before(self._root) @@ -318,97 +345,178 @@ def inserting_after(self, n: Optional[Node] = None): # sugar for create_node when you know the op def placeholder(self, name: str, type_expr: Optional[Any] = None) -> Node: """ - Insert a `placeholder` node into the Graph. A `placeholder` represents - a function input. This function takes a string `name` for the input - value as well as an optional `type_expr`, which is a type expression - describing the type of value this input will take. The type expression - is needed in some cases for proper code generation. + Insert a ``placeholder`` node into the Graph. A ``placeholder`` represents + a function input. + + Args: - The same insertion point rules apply for this method as `Graph.create_node`. + name (str): A name for the input value. This corresponds to the name + of the positional argument to the function this ``Graph`` represents. + + type_expr (Optional[Any]): an optional type annotation representing the + Python type the output of this node will have. This is needed in some + cases for proper code generation (e.g. 
when the function is used + subsequently in TorchScript compilation). + + .. note:: + The same insertion point and type expression rules apply for this method + as ``Graph.create_node``. """ return self.create_node('placeholder', name, type_expr=type_expr) def get_attr(self, qualified_name: str, type_expr: Optional[Any] = None) -> Node: """ - Insert a `get_attr` node into the Graph. A `get_attr` `Node` represents the - fetch of an attribute from the `Module` hierarchy. `qualified_name` is the - fully-qualified name of the attribute to be retrieved. For example, if - the traced Module has a submodule named `foo`, which has a submodule named - `bar`, which has an attribute named `baz`, the qualified name `foo.bar.baz` - should be passed as `qualified_name`. + Insert a ``get_attr`` node into the Graph. A ``get_attr`` ``Node`` represents the + fetch of an attribute from the ``Module`` hierarchy. + + Args: + + qualified_name (str): the fully-qualified name of the attribute to be retrieved. + For example, if the traced Module has a submodule named ``foo``, which has a + submodule named ``bar``, which has an attribute named ``baz``, the qualified + name ``foo.bar.baz`` should be passed as ``qualified_name``. - The same insertion point and type expression rules apply for this method - as `Graph.create_node`. + type_expr (Optional[Any]): an optional type annotation representing the + Python type the output of this node will have. + + + Returns: + + The newly-created and inserted ``get_attr`` node. + + .. note:: + The same insertion point and type expression rules apply for this method + as ``Graph.create_node``. 
""" return self.create_node('get_attr', qualified_name, type_expr=type_expr) def call_module(self, module_name: str, - args: Optional[Tuple[Argument, ...]] = None, - kwargs: Optional[Dict[str, Argument]] = None, + args: Optional[Tuple['Argument', ...]] = None, + kwargs: Optional[Dict[str, 'Argument']] = None, type_expr: Optional[Any] = None) -> Node: """ - Insert a `call_module` `Node` into the `Graph`. A `call_module` node - represents a call to the forward() function of a `Module` in the `Module` - hierarchy. For example, if the traced `Module` has a submodule named `foo`, - which has a submodule named `bar`, the qualified name `foo.bar` should - be passed as `module_name` to call that module. + Insert a ``call_module`` ``Node`` into the ``Graph``. A ``call_module`` node + represents a call to the forward() function of a ``Module`` in the ``Module`` + hierarchy. + + Args: + + module_name (str): The qualified name of the ``Module`` in the ``Module`` + hierarchy to be called. For example, if the traced ``Module`` has a + submodule named ``foo``, which has a submodule named ``bar``, the + qualified name ``foo.bar`` should be passed as ``module_name`` to + call that module. + + args (Optional[Tuple[Argument, ...]]): The positional arguments to be passed + to the called method. Note that this should *not* include a ``self`` argument. - `args` and `kwargs` represent the args and kwargs passed to the called - `Module`, respectively. + kwargs (Optional[Dict[str, Argument]]): The keyword arguments to be passed + to the called method - The same insertion point and type expression rules apply for this method - as `Graph.create_node`. + type_expr (Optional[Any]): an optional type annotation representing the + Python type the output of this node will have. + + Returns: + + The newly-created and inserted ``call_module`` node. + + .. note:: + The same insertion point and type expression rules apply for this method + as :meth:`Graph.create_node`. 
""" return self.create_node('call_module', module_name, args, kwargs, type_expr=type_expr) def call_method(self, method_name: str, - args: Optional[Tuple[Argument, ...]] = None, - kwargs: Optional[Dict[str, Argument]] = None, + args: Optional[Tuple['Argument', ...]] = None, + kwargs: Optional[Dict[str, 'Argument']] = None, type_expr: Optional[Any] = None) -> Node: """ - Insert a `call_method` `Node` into the `Graph`. A `call_method` node + Insert a ``call_method`` ``Node`` into the ``Graph``. A ``call_method`` node represents a call to a given method on the 0th element of `args. - For example, if args[0] is a `Node` representing a `Tensor`, then to call - `relu()` on that `Tensor`, pass `relu` to `method_name`. - `args` and `kwargs` represent the args and kwargs passed to the called - method, respectively. + Args: + + method_name (str): The name of the method to apply to the self argument. + For example, if args[0] is a ``Node`` representing a ``Tensor``, + then to call ``relu()`` on that ``Tensor``, pass ``relu`` to ``method_name``. + + args (Optional[Tuple[Argument, ...]]): The positional arguments to be passed + to the called method. Note that this *should* include a ``self`` argument. + + kwargs (Optional[Dict[str, Argument]]): The keyword arguments to be passed + to the called method - The same insertion point and type expression rules apply for this method - as `Graph.create_node`. + type_expr (Optional[Any]): an optional type annotation representing the + Python type the output of this node will have. + + Returns: + + The newly created and inserted ``call_method`` node. + + .. note:: + The same insertion point and type expression rules apply for this method + as :meth:`Graph.create_node`. 
""" return self.create_node('call_method', method_name, args, kwargs, type_expr=type_expr) def call_function(self, the_function: Callable[..., Any], - args: Optional[Tuple[Argument, ...]] = None, - kwargs: Optional[Dict[str, Argument]] = None, + args: Optional[Tuple['Argument', ...]] = None, + kwargs: Optional[Dict[str, 'Argument']] = None, type_expr: Optional[Any] = None) -> Node: """ - Insert a `call_function` `Node` into the `Graph`. A `call_function` node - represents a call to a Python callable, specified by `the_function`. `the_function` - can be any PyTorch operator, Python function, or member of the `builtins` - or `operator` namespaces. + Insert a ``call_function`` ``Node`` into the ``Graph``. A ``call_function`` node + represents a call to a Python callable, specified by ``the_function``. ``the_function`` + can be + + Args: + + the_function (Callable[..., Any]): The function to be called. Can be any PyTorch + operator, Python function, or member of the ``builtins`` or ``operator`` + namespaces. + + args (Optional[Tuple[Argument, ...]]): The positional arguments to be passed + to the called function. + + kwargs (Optional[Dict[str, Argument]]): The keyword arguments to be passed + to the called function + + type_expr (Optional[Any]): an optional type annotation representing the + Python type the output of this node will have. + + Returns - `args` and `kwargs` represent the args and kwargs passed to the called - method, respectively. + The newly created and inserted ``call_function`` node. - The same insertion point and type expression rules apply for this method - as `Graph.create_node`. + .. note:: + The same insertion point and type expression rules apply for this method + as :meth:`Graph.create_node`. """ return self.create_node('call_function', the_function, args, kwargs, type_expr=type_expr) - def node_copy(self, node: Node, arg_transform: Callable[[Node], Argument] = lambda x: x) -> Node: - """ Copy a node from one graph into another. 
arg_transform needs to transform arguments from the graph of node - to the graph of self. Example: + def node_copy(self, node: Node, arg_transform: Callable[[Node], 'Argument'] = lambda x: x) -> Node: + """ + Copy a node from one graph into another. ``arg_transform`` needs to transform arguments from + the graph of node to the graph of self. Example:: + # Copying all the nodes in `g` into `new_graph` g : torch.fx.Graph = ... new_graph = torch.fx.graph() value_remap = {} for node in g.nodes: value_remap[node] = new_graph.node_copy(node, lambda n : value_remap[n]) + + Args: + + node (Node): The node to copy into ``self``. + + arg_transform (Callable[[Node], Argument]): A function that transforms + ``Node`` arguments in node's ``args`` and ``kwargs`` into the + equivalent argument in ``self``. In the simplest case, this should + retrieve a value out of a table mapping Nodes in the original + graph to ``self``. """ args = map_arg(node.args, arg_transform) kwargs = map_arg(node.kwargs, arg_transform) @@ -416,14 +524,23 @@ def node_copy(self, node: Node, arg_transform: Callable[[Node], Argument] = lamb assert isinstance(kwargs, dict) return self.create_node(node.op, node.target, args, kwargs, node.name, node.type) - def output(self, result: Argument, type_expr: Optional[Any] = None): + def output(self, result: 'Argument', type_expr: Optional[Any] = None): """ - Insert an `output` `Node` into the `Graph`. An `output` node represents - a `return` statement in the Python code. `result` is the value that should + Insert an ``output`` ``Node`` into the ``Graph``. An ``output`` node represents + a ``return`` statement in Python code. ``result`` is the value that should be returned. - The same insertion point and type expression rules apply for this method - as `Graph.create_node`. + Args: + + result (Argument): The value to be returned. + + type_expr (Optional[Any]): an optional type annotation representing the + Python type the output of this node will have. + + .. 
note:: + + The same insertion point and type expression rules apply for this method + as ``Graph.create_node``. """ return self.create_node(op='output', target='output', args=(result,), type_expr=type_expr) @@ -463,7 +580,16 @@ def illegal_shadowing_name(name : str) -> bool: def python_code(self, root_module: str) -> str: """ - Turn this `Graph` into valid Python code. + Turn this ``Graph`` into valid Python code. + + Args: + + root_module (str): The name of the root module on which to look-up + qualified name targets. This is usually 'self'. + + Returns: + + The string source code generated from this ``Graph``. """ free_vars: List[str] = [] modules_used : Set[str] = set() @@ -569,7 +695,7 @@ def emit_node(node : Node): delete_unused_values(node) # repr() for inf and nan floating point values aren't parseable by - # python as literals. Explicitly import the names from the `math` module. + # python as literals. Explicitly import the names from the ``math`` module. import_strs = [f'import {name}' for name in sorted(modules_used)] import_block = '\n'.join(import_strs) @@ -589,7 +715,7 @@ def __str__(self) -> str: of this Graph """ placeholder_names : List[str] = [] - # This is a one-element array just so `format_node` can modify the closed + # This is a one-element array just so ``format_node`` can modify the closed # over value maybe_return_typename : List[str] = [''] @@ -644,7 +770,13 @@ def lint(self, root : Optional[torch.nn.Module] = None): particular: - Checks Nodes have correct ownership (owned by this graph) - Checks Nodes appear in topological order - - If `root` is provided, checks that `target`s exist in `root` + - If ``root`` is provided, checks that targets exist in ``root`` + + Args: + + root (Optional[torch.nn.Module]): The root module with which to check + for targets. This is equivalent to the ``root`` argument that is + passed when constructing a ``GraphModule``. 
""" # Check topo order @@ -653,7 +785,7 @@ def check_arg(arg : Node, n : Optional[Node] = None) -> None: if arg.graph is not self: raise RuntimeError(f'Argument \'{arg}\'{context_str}does not belong to this Graph, ' f'but was used as an argument! If you are copying nodes from another graph, make ' - f'sure to use `arg_transform` on node_copy() to remap values\n{self}') + f'sure to use ``arg_transform`` on node_copy() to remap values\n{self}') if arg not in seen_values: raise RuntimeError(f'Argument \'{arg}\'{context_str}was used before it has been ' f'defined! Please check that Nodes in the graph are topologically ordered\n{self}') diff --git a/torch/fx/graph_module.py b/torch/fx/graph_module.py index c593734eea4c..4254506334d8 100644 --- a/torch/fx/graph_module.py +++ b/torch/fx/graph_module.py @@ -50,7 +50,7 @@ def __init__(self, body): super().__init__() self.__dict__ = body - CodeOnlyModule.forward = _forward_from_src(body['code']) + CodeOnlyModule.forward = _forward_from_src(body['_code']) from .symbolic_trace import Tracer @@ -107,17 +107,17 @@ def _assign_attr(from_obj: Any, to_module: torch.nn.Module, target: str): class GraphModule(torch.nn.Module): """ - GraphModule is an nn.Module generated from an fx.Graph. GraphModule has - important attributes: + GraphModule is an nn.Module generated from an fx.Graph. Graphmodule has a + ``graph`` attribute, as well as ``code`` and ``forward`` attributes generated + from that ``graph``. - graph : The graph from which this GraphModule was generated - code : The Python source code for the function generated from `graph` - forward : The Python method generated from `graph` + .. warning:: + + When ``graph`` is reassigned, ``code`` and ``forward`` will be automatically + regenerated. However, if you edit the contents of the ``graph`` without reassigning + the ``graph`` attribute itself, you must call ``recompile()`` to update the generated + code. 
- Note that when `graph` is reassigned, `code` and `forward` will be automatically - regenerated. However, if you edit the contents of the `graph` without reassigning - the `graph` attribute itself, you must call `recompile()` to update the generated - code. """ def __new__(cls: 'Type[GraphModule]', *args, **kwargs): # each instance of a graph module needs its own forward method @@ -132,14 +132,20 @@ class GraphModuleImpl(cls): # type: ignore def __init__(self, root: Union[torch.nn.Module, Dict[str, Any]], graph: Graph): """ Construct a GraphModule. - root - `root` can either be an nn.Module instance or a Dict mapping strings to any attribute type. - - In the case that `root` is a Module, any references to Module-based objects (via qualified - name) in the Graph's Nodes' `target` field will be copied over from the respective place - within `root`'s Module hierarchy into the GraphModule's module hierarchy. - - In the case that `root` is a dict, the qualified name found in a Node's `target` will be - looked up directly in the dict's keys. The object mapped to by the Dict will be copied - over into the appropriate place within the GraphModule's module hierarchy. - graph - `graph` contains the nodes this GraphModule should use for code generation + + Args: + + root (Union[torch.nn.Module, Dict[str, Any]]): + ``root`` can either be an nn.Module instance or a Dict mapping strings to any attribute type. + In the case that ``root`` is a Module, any references to Module-based objects (via qualified + name) in the Graph's Nodes' ``target`` field will be copied over from the respective place + within ``root``'s Module hierarchy into the GraphModule's module hierarchy. + In the case that ``root`` is a dict, the qualified name found in a Node's ``target`` will be + looked up directly in the dict's keys. The object mapped to by the Dict will be copied + over into the appropriate place within the GraphModule's module hierarchy. 
+ + graph (Graph): ``graph`` contains the nodes this GraphModule should use for code generation + """ super().__init__() if isinstance(root, torch.nn.Module): @@ -156,14 +162,14 @@ def __init__(self, root: Union[torch.nn.Module, Dict[str, Any]], graph: Graph): assert isinstance(node.target, str) if node.target not in root: raise RuntimeError('Node ' + str(node) + ' referenced target ' + node.target + - ' but that target was not provided in `root`!') + ' but that target was not provided in `root`!') targets_to_copy.append(node.target) # Sort targets in ascending order of the # of atoms. # This will ensure that less deeply nested attributes are assigned # before more deeply nested attributes. For example, foo.bar # will be assigned before foo.bar.baz. Otherwise, we might assign - # the user-provided `foo.bar` and wipe out the previously-assigned - # `foo.bar.baz` + # the user-provided ``foo.bar`` and wipe out the previously-assigned + # ``foo.bar.baz`` targets_to_copy.sort(key=lambda t: t.count('.')) for target_to_copy in targets_to_copy: _assign_attr(root[target_to_copy], self, target_to_copy) @@ -178,31 +184,41 @@ def __init__(self, root: Union[torch.nn.Module, Dict[str, Any]], graph: Graph): __jit_unused_properties__ = ['graph'] @property - def graph(self): + def graph(self) -> Graph: """ - Return the `Graph` underlying this `GraphModule` + Return the ``Graph`` underlying this ``GraphModule`` """ return self._graph @graph.setter def graph(self, g) -> None: """ - Set the underlying `Graph` for this `GraphModule`. 
This will internally + recompile the ``GraphModule`` so that the generated ``forward()`` function + corresponds to ``g`` """ self._graph = g self.recompile() + @property + def code(self) -> str: + """ + Return the Python code generated from the ``Graph`` underlying this + ``GraphModule``. + """ + if not hasattr(self, '_code'): + raise RuntimeError('Code has not been generated! Please report a bug to PyTorch') + return self._code + def recompile(self) -> None: """ - Recompile this GraphModule from its `graph` attribute. This should be - called after editing the contained `graph`, otherwise the generated - code of this `GraphModule` will be out of date. + Recompile this GraphModule from its ``graph`` attribute. This should be + called after editing the contained ``graph``, otherwise the generated + code of this ``GraphModule`` will be out of date. """ - self.code = self._graph.python_code(root_module='self') + self._code = self._graph.python_code(root_module='self') cls = type(self) - cls.forward = _forward_from_src(self.code) + cls.forward = _forward_from_src(self._code) cls_call = cls.__call__ @@ -221,10 +237,10 @@ def wrapped_call(self, *args, **kwargs): def __reduce__(self): """ Serialization of GraphModule. We serialize only the generated code, not - the underlying `Graph`. This is because `Graph` does not have on-disk + the underlying ``Graph``. This is because ``Graph`` does not have on-disk backward-compatibility guarantees, whereas Python source code does. 
On the deserialization side, we symbolically trace through the generated - code to regenerate the underlying `Graph` + code to regenerate the underlying ``Graph`` """ dict_without_graph = self.__dict__.copy() del dict_without_graph['_graph'] @@ -243,7 +259,7 @@ def __copy__(self): def __str__(self) -> str: orig_str = super().__str__() - return '\n'.join([orig_str, self.code]) + return '\n'.join([orig_str, self._code]) # workarounds for issues in __torch_function__ diff --git a/torch/fx/node.py b/torch/fx/node.py index 8c484e0ab421..1cc94be83e7e 100644 --- a/torch/fx/node.py +++ b/torch/fx/node.py @@ -37,7 +37,7 @@ def __init__(self, graph: 'Graph', name: str, op: str, target: Target, self._update_args_kwargs(map_arg(args, lambda x: x), map_arg(kwargs, lambda x: x)) # type: ignore # All of the nodes that use the value produced by this Node - # Note one user may correspond to several uses, e.g. the node fo `x + x` + # Note one user may correspond to several uses, e.g. the node for ``x + x`` # would appear once here, but represents two uses. # # Is a dict to act as an "ordered set". Keys are significant, value dont-care @@ -49,9 +49,9 @@ def __init__(self, graph: 'Graph', name: str, op: str, target: Target, # For placeholder nodes, this value will be used to type-annotate the # generated function parameters. # For the return ndoe, this value will be used to type-annotate the - # generated function return type. (Note this is a special case. `return` + # generated function return type. (Note this is a special case. ``return`` # does not produce a value, it's more of a notation. Thus, this value - # describes the type of args[0] in the `return` node. + # describes the type of args[0] in the ``return`` node. self.type : Optional[Any] = type self._prev = self self._next = self @@ -89,7 +89,7 @@ def prepend(self, x: 'Node'): def append(self, x: 'Node'): """Insert x after this node in the list of nodes in the graph. 
- Equvalent to `self.next.prepend(x)` + Equivalent to ``self.next.prepend(x)`` Args: x (Node): The node to put after this node. Must be a member of the same graph. @@ -104,7 +104,7 @@ def _remove_from_list(self): def args(self) -> Tuple[Argument, ...]: """ Return the tuple of arguments to this Node. The interpretation of arguments - depends on the node's opcode. See the `fx.Graph` docstring for more + depends on the node's opcode. See the ``fx.Graph`` docstring for more information. """ return self._args @@ -113,7 +113,7 @@ def args(self) -> Tuple[Argument, ...]: def args(self, a : Tuple[Argument, ...]): """ Set the tuple of arguments to this Node. The interpretation of arguments - depends on the node's opcode. See the `fx.Graph` docstring for more + depends on the node's opcode. See the ``fx.Graph`` docstring for more information. """ self._update_args_kwargs(map_arg(a, lambda x: x), self._kwargs) # type: ignore @@ -122,7 +122,7 @@ def args(self, a : Tuple[Argument, ...]): def kwargs(self) -> Dict[str, Argument]: """ Return the dict of kwargs to this Node. The interpretation of arguments - depends on the node's opcode. See the `fx.Graph` docstring for more + depends on the node's opcode. See the ``fx.Graph`` docstring for more information. """ return self._kwargs @@ -131,7 +131,7 @@ def kwargs(self) -> Dict[str, Argument]: def kwargs(self, k : Dict[str, Argument]): """ Set the dict of kwargs to this Node. The interpretation of arguments - depends on the node's opcode. See the `fx.Graph` docstring for more + depends on the node's opcode. See the ``fx.Graph`` docstring for more information. """ self._update_args_kwargs(self._args, map_arg(k, lambda x: x)) # type: ignore @@ -140,7 +140,7 @@ def kwargs(self, k : Dict[str, Argument]): def all_input_nodes(self) -> List['Node']: """ Return all Nodes that are inputs to this Node. 
This is equivalent to - iterating over `args` and `kwargs` and only collecting the values that + iterating over ``args`` and ``kwargs`` and only collecting the values that are Nodes """ all_nodes : List['Node'] = [] @@ -167,7 +167,7 @@ def __repr__(self) -> str: def replace_all_uses_with(self, replace_with : 'Node') -> List['Node']: """ - Replace all uses of `self` in the Graph with the Node `replace_with`. + Replace all uses of ``self`` in the Graph with the Node ``replace_with``. Returns the list of nodes on which this change was made. """ to_process = list(self.users) diff --git a/torch/fx/proxy.py b/torch/fx/proxy.py index da005f0640b0..f8c4aa8d8366 100644 --- a/torch/fx/proxy.py +++ b/torch/fx/proxy.py @@ -33,8 +33,8 @@ def create_proxy(self, kind: str, target: Target, args: Tuple[Any, ...], kwargs: If kind = 'placeholder', then we're creating a Node that represents the parameter of a function. If we need to encode - a default parameter, we use the `args` tuple. `args` is - otherwise empty for `placeholder` Nodes. + a default parameter, we use the ``args`` tuple. ``args`` is + otherwise empty for ``placeholder`` Nodes. ''' args_ = self.create_arg(args) kwargs_ = self.create_arg(kwargs) diff --git a/torch/fx/symbolic_trace.py b/torch/fx/symbolic_trace.py index b2e5b0961114..fe59c2a11e17 100644 --- a/torch/fx/symbolic_trace.py +++ b/torch/fx/symbolic_trace.py @@ -39,9 +39,9 @@ def _patch_function(fn: FunctionType, nargs: int) -> FunctionType: class Tracer(TracerBase): """ - `Tracer` is the class that implements the symbolic tracing functionality - of `torch.fx.symbolic_trace`. A call to `symbolic_trace(m)` is equivalent - to `Tracer().trace(m)`. + ``Tracer`` is the class that implements the symbolic tracing functionality + of ``torch.fx.symbolic_trace``. A call to ``symbolic_trace(m)`` is equivalent + to ``Tracer().trace(m)``. Tracer can be subclassed to override various behaviors of the tracing process. 
The different behaviors that can be overridden are described @@ -53,14 +53,14 @@ def __init__(self): def create_arg(self, a: Any) -> Argument: """ A method to specify the behavior of tracing when preparing values to - be used as arguments to nodes in the `Graph`. + be used as arguments to nodes in the ``Graph``. By default, the behavior includes: - Iterate through collection types (e.g. tuple, list, dict) and recursively - call `create_args` on the elements. - - Given a Proxy object, return a reference to the underlying IR `Node` + call ``create_args`` on the elements. + - Given a Proxy object, return a reference to the underlying IR ``Node`` - Given a non-Proxy Tensor object, emit IR for various cases: - - For a Parameter, emit a `get_attr` node referring to that Parameter + - For a Parameter, emit a ``get_attr`` node referring to that Parameter - For a non-Parameter Tensor, store the Tensor away in a special attribute referring to that attribute. @@ -105,10 +105,10 @@ def create_arg(self, a: Any) -> Argument: def is_leaf_module(self, m: torch.nn.Module, module_qualified_name : str) -> bool: """ - A method to specify whether a given `nn.Module` is a "leaf" module. + A method to specify whether a given ``nn.Module`` is a "leaf" module. Leaf modules are the atomic units that appear in - the IR, referenced by `call_module` calls. By default, + the IR, referenced by ``call_module`` calls. By default, Modules in the PyTorch standard library namespace (torch.nn) are leaf modules. All other modules are traced through and their constituent ops are recorded, unless specified otherwise @@ -117,17 +117,17 @@ def is_leaf_module(self, m: torch.nn.Module, module_qualified_name : str) -> boo Args m - The module itself module_qualified_name - The path to root of this module. For example, - if you have a module hierarchy where submodule `foo` contains - submodule `bar`, which contains submodule `baz`, that module will - appear with the qualified name `foo.bar.baz` here. 
+ if you have a module hierarchy where submodule ``foo`` contains + submodule ``bar``, which contains submodule ``baz``, that module will + appear with the qualified name ``foo.bar.baz`` here. """ return m.__module__.startswith('torch.nn') and not isinstance(m, torch.nn.Sequential) def path_of_module(self, mod) -> str: """ - Helper method to find the qualified name of `mod` in the Module hierarchy - of `root`. For example, if `root` has a submodule named `foo`, which has - a submodule named `bar`, passing `bar` into this function will return + Helper method to find the qualified name of ``mod`` in the Module hierarchy + of ``root``. For example, if ``root`` has a submodule named ``foo``, which has + a submodule named ``bar``, passing ``bar`` into this function will return the string "foo.bar". """ for n, p in self.root.named_modules(): @@ -137,17 +137,17 @@ def path_of_module(self, mod) -> str: def call_module(self, m: torch.nn.Module, forward: Callable[..., Any], args, kwargs): """ - Method that specifies the behavior of this `Tracer` when it encounters - a call to an `nn.Module` instance. + Method that specifies the behavior of this ``Tracer`` when it encounters + a call to an ``nn.Module`` instance. By default, the behavior is to check if the called module is a leaf module - via `is_leaf_module`. If it is, emit a `call_module` node referring to - `m` in the `Graph`. Otherwise, call the `Module` normally, tracing through - the operations in its `forward` function. + via ``is_leaf_module``. If it is, emit a ``call_module`` node referring to + ``m`` in the ``Graph``. Otherwise, call the ``Module`` normally, tracing through + the operations in its ``forward`` function. This method can be overridden to--for example--create nested traced GraphModules, or any other behavior you would want while tracing across - `Module` boundaries. + ``Module`` boundaries. 
""" module_qualified_name = self.path_of_module(m) if not self.is_leaf_module(m, module_qualified_name): @@ -156,12 +156,12 @@ def call_module(self, m: torch.nn.Module, forward: Callable[..., Any], args, kwa def create_args_for_root(self, root_fn, is_module): """ - Create `placeholder` nodes corresponding to the signature of the `root` - Module. This method introspects `root`'s signature and emits those + Create ``placeholder`` nodes corresponding to the signature of the ``root`` + Module. This method introspects ``root``'s signature and emits those nodes accordingly, also supporting *args and **kwargs. """ # In some cases, a function or method has been decorated with a wrapper - # defined via `functools.wraps`. In this case, the outer code object + # defined via ``functools.wraps``. In this case, the outer code object # will likely not contain the actual parameters we care about, so unwrap # the function to get to the innermost callable. fn_for_analysis = inspect.unwrap(root_fn) @@ -172,7 +172,7 @@ def create_args_for_root(self, root_fn, is_module): skip_arg_idx = 0 if is_module: if total_args == 0: - raise RuntimeError('`self` argument cannot be part of *args expansion!') + raise RuntimeError('``self`` argument cannot be part of *args expansion!') skip_arg_idx = 1 next(names_iter) # skip self args.append(self.root) @@ -202,8 +202,8 @@ def proxy_placeholder(name: str): def trace(self, root: Union[torch.nn.Module, Callable]) -> Graph: """ - Trace `root` and return the corresponding FX `Graph` representation. `root` - can either be an `nn.Module` instance or a Python callable. + Trace ``root`` and return the corresponding FX ``Graph`` representation. ``root`` + can either be an ``nn.Module`` instance or a Python callable. 
""" if isinstance(root, torch.nn.Module): self.root = root @@ -268,10 +268,17 @@ def forward(*args, **kwargs): def symbolic_trace(root : Union[torch.nn.Module, Callable]) -> GraphModule: - """ - Symbolic tracing API + """Symbolic tracing API + + Given an ``nn.Module`` or function instance ``root``, this function will return a ``GraphModule`` + constructed by recording operations seen while tracing through ``root``. + + Args: + root (Union[torch.nn.Module, Callable]): Module or function to be traced and converted + into a Graph representation. + + Returns: + GraphModule: a Module created from the recorded operations from ``root``. - Given an `nn.Module` or function instance `root`, this function will return a `GraphModule` - constructed by recording operations seen while tracing through `root`. """ return GraphModule(root if isinstance(root, torch.nn.Module) else torch.nn.Module(), Tracer().trace(root))