.. currentmodule:: torch.fx

torch.fx
========

Overview
--------

.. automodule:: torch.fx

Writing Transformations
------------------------

TODO

Debugging Transformations
--------------------------

TODO

Limitations of Symbolic Tracing
--------------------------------

FX uses a system of symbolic tracing (a.k.a. symbolic execution) to capture the semantics of programs in a transformable/analyzable form. The system is *tracing* in that it executes the program (really an ``nn.Module`` or function) to gather this information. It is *symbolic* in that the data flowing through the program during this execution is not real data, but rather symbols (``Proxy`` in FX parlance).
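
For a concrete sense of what this looks like, here is a minimal sketch (the function below is illustrative, not taken from FX itself) of tracing a small function and inspecting what was recorded:

import torch
import torch.fx

def to_trace(x):
    # During tracing, `x` is a Proxy: each operation on it is recorded
    # as a node in the Graph rather than computed on real data.
    return torch.relu(x) + 1.0

traced = torch.fx.symbolic_trace(to_trace)
print(traced.graph)  # the recorded operations
print(traced.code)   # Python source regenerated from the graph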

Although symbolic tracing works for most neural net code, it has some limitations.

Dynamic Control Flow
^^^^^^^^^^^^^^^^^^^^

The main limitation of symbolic tracing is that it does not currently support *dynamic control flow*, that is, loops or ``if`` statements where the condition may depend on the input values of the program.

For example, let’s examine the following program:

def func_to_trace(x):
    dim0 = x.shape[0]
    if dim0 == 3:
        return torch.relu(x)
    else:
        return torch.neg(x)

traced = torch.fx.symbolic_trace(func_to_trace)
"""
  <...>
  File "dyn.py", line 6, in func_to_trace
    if dim0 == 3:
  File "pytorch/torch/fx/proxy.py", line 155, in __bool__
    return self.tracer.to_bool(self)
  File "pytorch/torch/fx/proxy.py", line 85, in to_bool
    raise TraceError('symbolically traced variables cannot be used as inputs to control flow')
torch.fx.proxy.TraceError: symbolically traced variables cannot be used as inputs to control flow
"""

The condition of the ``if`` statement relies on the value of ``dim0``, which eventually relies on the value of ``x``, a function input. Since ``x`` can change (i.e. if you pass a new input tensor to the traced function), this is dynamic control flow. The traceback walks back up through your code to show where this situation occurs.

Static Control Flow
^^^^^^^^^^^^^^^^^^^

On the other hand, so-called *static* control flow is supported. Static control flow is loops or ``if`` statements whose condition cannot change across invocations. Typically, in PyTorch programs, this control flow arises from code that makes decisions about a model's architecture based on hyperparameters. As a concrete example:

import torch
import torch.fx

class MyModule(torch.nn.Module):
    def __init__(self, do_activation : bool = False):
        super().__init__()
        self.do_activation = do_activation
        self.linear = torch.nn.Linear(512, 512)

    def forward(self, x):
        x = self.linear(x)
        # This if-statement is so-called static control flow.
        # Its condition does not depend on any input values
        if self.do_activation:
            x = torch.relu(x)
        return x

without_activation = MyModule(do_activation=False)
with_activation = MyModule(do_activation=True)

traced_without_activation = torch.fx.symbolic_trace(without_activation)
print(traced_without_activation.code)
"""
def forward(self, x):
    linear_1 = self.linear(x);  x = None
    return linear_1
"""

traced_with_activation = torch.fx.symbolic_trace(with_activation)
print(traced_with_activation.code)
"""
import torch
def forward(self, x):
    linear_1 = self.linear(x);  x = None
    relu_1 = torch.relu(linear_1);  linear_1 = None
    return relu_1
"""

The if-statement ``if self.do_activation`` does not depend on any function inputs, and is thus static. ``do_activation`` can be considered a hyperparameter, and the traces of different instances of ``MyModule`` with different values for that parameter have different code. This is a valid pattern that symbolic tracing supports.

Many instances of dynamic control flow are semantically static control flow. These instances can be made to support symbolic tracing by removing the data dependencies on input values, for example by moving values to Module attributes or by passing constant values during symbolic tracing:

def f(x, flag):
    if flag: return x
    else: return x*2

torch.fx.symbolic_trace(f) # Fails!

def g(flag):
    return lambda x: f(x, flag)

new_f = g(flag=True)
torch.fx.symbolic_trace(new_f)  # Succeeds: `flag` is bound to a constant
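
The example above binds ``flag`` to a constant through a closure. A sketch of the other technique mentioned above, moving the value onto a ``Module`` attribute so the branch becomes static control flow (the module name here is illustrative):

class FlagModule(torch.nn.Module):
    def __init__(self, flag: bool):
        super().__init__()
        # `flag` is a hyperparameter stored on the module,
        # not a traced input value
        self.flag = flag

    def forward(self, x):
        if self.flag:
            return x
        else:
            return x * 2

traced = torch.fx.symbolic_trace(FlagModule(flag=True))
print(traced.code)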

In the case of truly dynamic control flow, the sections of the program that contain this code can be traced as calls to a method (see :ref:`Customizing Tracing`) or function (see :func:`wrap`) rather than being traced through.
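
For instance, a sketch of the function approach (the helper name here is illustrative): marking the helper with :func:`wrap` records it as a single call in the graph, so the tracer never sees the data-dependent branch:

def relu_or_neg(x):
    # Truly dynamic control flow: the branch depends on the data in `x`.
    if x.sum() > 0:
        return torch.relu(x)
    return torch.neg(x)

# Record calls to this function as single nodes instead of tracing into it.
# (Must be called at module scope.)
torch.fx.wrap('relu_or_neg')

def uses_dynamic_branch(x):
    return relu_or_neg(x) + 1.0

traced = torch.fx.symbolic_trace(uses_dynamic_branch)
print(traced.code)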

Non-torch Functions
^^^^^^^^^^^^^^^^^^^

FX uses ``__torch_function__`` as the mechanism by which it intercepts calls (see the technical overview for more information about this). Some functions, such as builtin Python functions or those in the ``math`` module, are not covered by ``__torch_function__``, but we would still like to capture them in symbolic tracing. For example:

from math import sqrt

def normalize(x):
    """
    Normalize `x` by the size of the batch dimension
    """
    return x / sqrt(len(x))

# It's valid Python code
normalize(torch.rand(3, 4))

traced = torch.fx.symbolic_trace(normalize)
"""
  <...>
  File "sqrt.py", line 9, in normalize
    return x / sqrt(len(x))
  File "pytorch/torch/fx/proxy.py", line 161, in __len__
    raise RuntimeError("'len' is not supported in symbolic tracing by default. If you want "
RuntimeError: 'len' is not supported in symbolic tracing by default. If you want this call to be recorded, please call torch.fx.wrap('len') at module scope
"""

The error tells us that the built-in function ``len`` is not supported. Using the :func:`wrap` API, we can record functions like this in the trace as direct calls:

torch.fx.wrap('len')
torch.fx.wrap('sqrt')

traced = torch.fx.symbolic_trace(normalize)

print(traced.code)
"""
import math
def forward(self, x):
    len_1 = len(x)
    sqrt_1 = math.sqrt(len_1);  len_1 = None
    truediv = x / sqrt_1;  x = sqrt_1 = None
    return truediv
"""

.. _Customizing Tracing:

Customizing Tracing with the Tracer class
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

The :class:`Tracer` class underlies the implementation of :func:`symbolic_trace`. The behavior of tracing can be customized by subclassing ``Tracer``, like so:

class MyCustomTracer(torch.fx.Tracer):
    # Inside here you can override various methods
    # to customize tracing. See the `Tracer` API
    # reference
    pass


# Let's use this custom tracer to trace through this module
class MyModule(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x) + torch.ones(3, 4)

mod = MyModule()

traced_graph = MyCustomTracer().trace(mod)
# trace() returns a Graph. Let's wrap it up in a
# GraphModule to make it runnable
traced = torch.fx.GraphModule(mod, traced_graph)
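
The resulting ``GraphModule`` can be inspected and run like the original module (a usage sketch):

print(traced.code)
out = traced(torch.rand(3, 4))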

Leaf Modules
~~~~~~~~~~~~

Leaf modules are the modules that appear as calls in the symbolic trace rather than being traced through. The default set of leaf modules is the set of standard ``torch.nn`` module instances. For example:

class MySpecialSubmodule(torch.nn.Module):
    def forward(self, x):
        return torch.neg(x)

class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(3, 4)
        self.submod = MySpecialSubmodule()

    def forward(self, x):
        return self.submod(self.linear(x))

traced = torch.fx.symbolic_trace(MyModule())
print(traced.code)
# `linear` is preserved as a call, yet `submod` is traced through.
# This is because the default set of "Leaf Modules" includes all
# standard `torch.nn` modules.
"""
import torch
def forward(self, x):
    linear_1 = self.linear(x);  x = None
    neg_1 = torch.neg(linear_1);  linear_1 = None
    return neg_1
"""

The set of leaf modules can be customized by overriding :meth:`Tracer.is_leaf_module`.
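
For example, a minimal sketch (the tracer name is illustrative) that keeps ``MySpecialSubmodule`` from the example above as a call instead of tracing through it:

class LeafTracer(torch.fx.Tracer):
    def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool:
        # Treat our custom submodule as a leaf, in addition to the defaults
        # (standard `torch.nn` modules).
        if isinstance(m, MySpecialSubmodule):
            return True
        return super().is_leaf_module(m, module_qualified_name)

mod = MyModule()
graph = LeafTracer().trace(mod)
traced = torch.fx.GraphModule(mod, graph)
print(traced.code)  # `submod` now appears as a call rather than being traced through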

Miscellanea
^^^^^^^^^^^

* Tensor constructors (e.g. ``torch.zeros``, ``torch.ones``, ``torch.rand``, ``torch.randn``, ``torch.sparse_coo_tensor``) are currently not traceable.

  * The deterministic constructors (``zeros``, ``ones``) can be used and the value they produce will be embedded in the trace as a constant. This is only problematic if the arguments to these constructors refer to dynamic input sizes. In this case, ``ones_like`` or ``zeros_like`` may be a viable substitute (see the sketch after this list).
  * Nondeterministic constructors (``rand``, ``randn``) will have a single random value embedded in the trace. This is likely not the intended behavior.
  * This behavior may be fixed in a future release.

* Type annotations

  * Python 3-style type annotations (e.g. ``func(x : torch.Tensor, y : int) -> torch.Tensor``) are supported and will be preserved by symbolic tracing.
  * Python 2-style comment type annotations (``# type: (torch.Tensor, int) -> torch.Tensor``) are not currently supported.
  * Annotations on local names within a function are not currently supported.
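
A sketch of the ``zeros_like`` substitution mentioned above (the function name is illustrative): the size of the result stays tied to the input tensor rather than being baked into the trace as a constant:

def make_mask(x):
    # torch.zeros with a size derived from the input is the problematic
    # case described above; zeros_like keeps the result's size tied to
    # `x` and traces cleanly.
    return torch.zeros_like(x)

traced = torch.fx.symbolic_trace(make_mask)
print(traced.code)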

API Reference
-------------

.. autofunction:: torch.fx.symbolic_trace

.. autofunction:: torch.fx.wrap

.. autoclass:: torch.fx.GraphModule
  :members:

  .. automethod:: __init__

.. autoclass:: torch.fx.Graph
  :members:

  .. automethod:: __init__

.. autoclass:: torch.fx.Node
  :members:

.. autoclass:: torch.fx.Tracer
  :members:

.. autoclass:: torch.fx.Proxy

.. autoclass:: torch.fx.Interpreter
  :members:

.. autoclass:: torch.fx.Transformer
  :members: