From 080a0979359b443eaef8aed40065b244dfef33bd Mon Sep 17 00:00:00 2001 From: Ansley Ussery Date: Mon, 11 Jan 2021 13:46:10 -0800 Subject: [PATCH] Add docstring for Proxy (#50145) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/50145 Test Plan: Imported from OSS Reviewed By: pbelevich Differential Revision: D25854281 Pulled By: ansley fbshipit-source-id: d7af6fd6747728ef04e86fbcdeb87cb0508e1fd8 --- docs/source/fx.rst | 2 ++ torch/fx/proxy.py | 18 +++++++++++------- 2 files changed, 13 insertions(+), 7 deletions(-) diff --git a/docs/source/fx.rst b/docs/source/fx.rst index 21c4268ecda3..cd7c2738031e 100644 --- a/docs/source/fx.rst +++ b/docs/source/fx.rst @@ -28,3 +28,5 @@ API Reference .. autoclass:: torch.fx.Tracer :members: + +.. autoclass:: torch.fx.Proxy diff --git a/torch/fx/proxy.py b/torch/fx/proxy.py index a9bb6cd0a1ad..669a762aea8d 100644 --- a/torch/fx/proxy.py +++ b/torch/fx/proxy.py @@ -112,17 +112,21 @@ def __init__(self, graph: Graph): class TraceError(ValueError): pass -# Proxy objects are stand-in values for normal values in a PyTorch computation. -# Instead of performing compute they record computation into Graph. -# Each proxy wraps the Node instance that represents the expression that define the -# value. class Proxy: + """ + ``Proxy`` objects are ``Node`` wrappers that flow through the + program during symbolic tracing and record all the operations + (``torch`` function calls, method calls, operators) that they touch + into the growing FX Graph. + + If you're doing graph transforms, you can wrap your own ``Proxy`` + method around a raw ``Node`` so that you can use the overloaded + operators to add additional things to a ``Graph``. + """ def __init__(self, node: Node, tracer: 'Optional[TracerBase]' = None): if tracer is None: - # this allows you to create a proxy object around a raw node - # so that if you are doing graph transforms you can use the overloaded operators - # to add additional things to a graph. + # This allows you to create a Proxy object around a raw Node tracer = GraphAppendingTracer(node.graph) self.tracer = tracer self.node = node