diff --git a/docs/source/fx.rst b/docs/source/fx.rst index 21c4268ecda3..cd7c2738031e 100644 --- a/docs/source/fx.rst +++ b/docs/source/fx.rst @@ -28,3 +28,5 @@ API Reference .. autoclass:: torch.fx.Tracer :members: + +.. autoclass:: torch.fx.Proxy diff --git a/torch/fx/proxy.py b/torch/fx/proxy.py index a9bb6cd0a1ad..669a762aea8d 100644 --- a/torch/fx/proxy.py +++ b/torch/fx/proxy.py @@ -112,17 +112,21 @@ def __init__(self, graph: Graph): class TraceError(ValueError): pass -# Proxy objects are stand-in values for normal values in a PyTorch computation. -# Instead of performing compute they record computation into Graph. -# Each proxy wraps the Node instance that represents the expression that define the -# value. class Proxy: + """ + ``Proxy`` objects are ``Node`` wrappers that flow through the + program during symbolic tracing and record all the operations + (``torch`` function calls, method calls, operators) that they touch + into the growing FX Graph. + + If you're doing graph transforms, you can wrap your own ``Proxy`` + method around a raw ``Node`` so that you can use the overloaded + operators to add additional things to a ``Graph``. + """ def __init__(self, node: Node, tracer: 'Optional[TracerBase]' = None): if tracer is None: - # this allows you to create a proxy object around a raw node - # so that if you are doing graph transforms you can use the overloaded operators - # to add additional things to a graph. + # This allows you to create a Proxy object around a raw Node tracer = GraphAppendingTracer(node.graph) self.tracer = tracer self.node = node