144 changes: 99 additions & 45 deletions torch/_dynamo/output_graph.py
@@ -7,6 +7,7 @@
import re
import sys
import traceback
import weakref
from dataclasses import dataclass
from typing import Any, Dict, List, NamedTuple, Optional, OrderedDict, Set, Union

@@ -185,7 +186,7 @@ def __call__(self, gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor])
self.restore()


class OutputGraph(fx.Tracer, Checkpointable[OutputGraphState]):
class OutputGraph(Checkpointable[OutputGraphState]):
"""
Wrapper class to hold outputs of InstructionTranslator. Mainly the
generated fx.Graph.
Expand All @@ -202,7 +203,7 @@ def __init__(
frame_state,
):
super().__init__()
self.graph = torch.fx.Graph()
self.tracers = [SubgraphTracer(self)]
# Map from graph input's `Source` to its `VariableTracker` to
# de-duplicate graph inputs by source and reuse the tracker
self.input_source_to_var: Dict[Source, VariableTracker] = {}
@@ -252,8 +253,6 @@ def __init__(
# used to track nodes that are added between calls of copy_graphstate
# and restore_graphstate
self.timestamp = 0
# Node => computed real value (see utils.get_real_value)
self.real_value_cache: Dict[fx.Node, torch.Tensor] = {}

# Not checkpointed
self.compiler_fn: CompilerFn = compiler_fn
@@ -267,10 +266,45 @@ def __init__(
self.random_values_var = None
self.unspec_variable_map: Dict[str, UnspecializedPythonVariable] = {}

# Map from graph input name to its placeholder proxy object, where the
# map's keys give all current placeholder node names and can be used to
# create unique node names
self.input_name_to_proxy: OrderedDict[str, fx.Proxy] = collections.OrderedDict()
@property
def root_tracer(self):
return self.tracers[0]

@property
def current_tracer(self):
return self.tracers[-1]

@property
def graph(self):
return self.current_tracer.graph

# TODO(rzou): can delete after we refactor speculate_subgraph to use nested GraphTracer.
@graph.setter
def graph(self, value):
self.current_tracer.graph = value

@property
def input_name_to_proxy(self):
return self.current_tracer.input_name_to_proxy

@property
def real_value_cache(self):
return self.current_tracer.real_value_cache

def create_graph_input(self, *args, **kwargs):
return self.current_tracer.create_graph_input(*args, **kwargs)

def create_proxy(self, *args, **kwargs):
return self.current_tracer.create_proxy(*args, **kwargs)

def create_node(self, *args, **kwargs):
return self.current_tracer.create_node(*args, **kwargs)

def remove_node(self, *args, **kwargs):
return self.current_tracer.remove_node(*args, **kwargs)

def create_arg(self, *args, **kwargs):
return self.current_tracer.create_arg(*args, **kwargs)

@property
def output(self):
@@ -396,40 +430,6 @@ def get_submodule(self, keys):
obj = getattr(obj, k)
return obj

# when before=True, we will insert this input before the most recent
# inserted proxy. This is a hack to get around an ordering problem,
# where we first insert a tensor argument, and then insert bindings
# for SymInts that may occur in the tensor argument.
# Remove this if https://github.com/pytorch/pytorch/issues/99007 gets
# fixed.
def create_graph_input(self, name, type_expr=None, before=False):
# unique
if name in self.input_name_to_proxy:
for i in itertools.count():
candidate_name = f"{name}_{i}"
if candidate_name not in self.input_name_to_proxy:
name = candidate_name
break

if self.input_name_to_proxy:
prev_name = next(reversed(self.input_name_to_proxy))
node = self.input_name_to_proxy[prev_name].node
if before:
ctx = self.graph.inserting_before(node)
else:
ctx = self.graph.inserting_after(node)
else:
ctx = self.graph.inserting_before(None)
with ctx:
proxy = self.create_proxy("placeholder", name, (), {}, type_expr=type_expr)
if self.input_name_to_proxy and before:
k, v = self.input_name_to_proxy.popitem()
self.input_name_to_proxy[name] = proxy
self.input_name_to_proxy[k] = v
else:
self.input_name_to_proxy[name] = proxy
return proxy

def new_var(self, name="tmp"):
existing = set(self.code_options["co_varnames"])
for i in itertools.count():
@@ -570,7 +570,7 @@ def register_leaf_name(leaf_name):
raise AssertionError("unreachable")

def compile_subgraph(
self, tx, partial_convert=False, reason: GraphCompileReason = None
self, tx, partial_convert=False, reason: Optional[GraphCompileReason] = None
):
"""
Generate a subgraph to continue execution on user code.
@@ -923,6 +923,26 @@ def cleanup(self) -> None:
self.input_name_to_proxy.clear()
self.side_effects.clear()


class SubgraphTracer(fx.Tracer):
"""
Holds an FX graph that is being traced. OutputGraph owns a SubgraphTracer;
the SubgraphTracer is responsible for building the graph, while OutputGraph
is responsible for compiling and executing it.
"""

def __init__(self, output_graph):
super(SubgraphTracer, self).__init__()
self.output_graph = weakref.proxy(output_graph)
self.graph = torch.fx.Graph()
# Map from graph input name to its placeholder proxy object, where the
# map's keys give all current placeholder node names and can be used to
# create unique node names
self.input_name_to_proxy: OrderedDict[str, fx.Proxy] = collections.OrderedDict()
# Node => computed real value (see utils.get_real_value)
self.real_value_cache: Dict[fx.Node, torch.Tensor] = {}

def create_proxy(
self,
kind,
@@ -938,7 +958,7 @@ def create_proxy(
)

# append stack trace to fx node
tx = self.current_tx
tx = self.output_graph.current_tx

nn_module_stack = tx.nn_module_stack
if nn_module_stack:
@@ -965,11 +985,45 @@

def create_node(self, *args, **kwargs):
node = super().create_node(*args, **kwargs)
node.meta["creation_timestamp"] = self.timestamp
node.meta["creation_timestamp"] = self.output_graph.timestamp
return node

# Note: we did not override erase_node since
# we call self.graph.erase_node elsewhere
def remove_node(self, node):
self.graph.erase_node(node)
self.input_name_to_proxy.pop(node.name, None)

# when before=True, we will insert this input before the most recent
# inserted proxy. This is a hack to get around an ordering problem,
# where we first insert a tensor argument, and then insert bindings
# for SymInts that may occur in the tensor argument.
# Remove this if https://github.com/pytorch/pytorch/issues/99007 gets
# fixed.
def create_graph_input(self, name, type_expr=None, before=False):
# unique
if name in self.input_name_to_proxy:
for i in itertools.count():
candidate_name = f"{name}_{i}"
if candidate_name not in self.input_name_to_proxy:
name = candidate_name
break

if self.input_name_to_proxy:
prev_name = next(reversed(self.input_name_to_proxy))
node = self.input_name_to_proxy[prev_name].node
if before:
ctx = self.graph.inserting_before(node)
else:
ctx = self.graph.inserting_after(node)
else:
ctx = self.graph.inserting_before(None)
with ctx:
proxy = self.create_proxy("placeholder", name, (), {}, type_expr=type_expr)
if self.input_name_to_proxy and before:
k, v = self.input_name_to_proxy.popitem()
self.input_name_to_proxy[name] = proxy
self.input_name_to_proxy[k] = v
else:
self.input_name_to_proxy[name] = proxy
return proxy
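
Taken together, the output_graph.py changes move graph construction out of OutputGraph (which previously was an fx.Tracer itself) into a stack of SubgraphTracer objects, with OutputGraph forwarding graph-building calls to the innermost tracer. Below is a minimal, self-contained sketch of that delegation pattern; the class names (Owner, Tracer) and the simplified placeholder bookkeeping are illustrative stand-ins, not the actual Dynamo implementation.

import collections
import weakref


class Tracer:
    """Builds a graph; the owner is held weakly to avoid a reference cycle."""

    def __init__(self, owner):
        self.owner = weakref.proxy(owner)
        # Maps placeholder name -> proxy object (stand-in for an fx.Proxy).
        self.input_name_to_proxy = collections.OrderedDict()

    def create_graph_input(self, name):
        # Uniquify the name against existing placeholders, as the PR does.
        if name in self.input_name_to_proxy:
            i = 0
            while f"{name}_{i}" in self.input_name_to_proxy:
                i += 1
            name = f"{name}_{i}"
        self.input_name_to_proxy[name] = object()
        return name


class Owner:
    """Owns a stack of tracers; graph-building calls go to the innermost one."""

    def __init__(self):
        self.tracers = [Tracer(self)]

    @property
    def current_tracer(self):
        return self.tracers[-1]

    def create_graph_input(self, *args, **kwargs):
        return self.current_tracer.create_graph_input(*args, **kwargs)


owner = Owner()
print(owner.create_graph_input("x"))  # x
print(owner.create_graph_input("x"))  # x_0

Holding the owner through weakref.proxy, as the PR does in SubgraphTracer.__init__, avoids a reference cycle between the tracer and the OutputGraph that owns it.
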
16 changes: 8 additions & 8 deletions torch/_dynamo/utils.py
@@ -1263,7 +1263,7 @@ def visit(n: torch.fx.Node):
raise TorchRuntimeError() from e


def run_node(output_graph, node, args, kwargs, nnmodule):
def run_node(tracer, node, args, kwargs, nnmodule):
"""
Runs a given node, with the given args and kwargs.

@@ -1272,7 +1272,7 @@ def run_node(output_graph, node, args, kwargs, nnmodule):
run_node is useful for extracting real values out of nodes.
See get_real_value for more info on common usage.

Note: The output_graph arg is only used for 'get_attr' ops
Note: The tracer arg is only used for 'get_attr' ops
Note: The nnmodule arg is only used for 'call_module' ops

Nodes that are not call_function, call_method, call_module, or get_attr will
@@ -1288,7 +1288,7 @@ def run_node(output_graph, node, args, kwargs, nnmodule):
assert nnmodule is not None
return nnmodule(*args, **kwargs)
elif op == "get_attr":
return output_graph.get_submodule(node.target)
return tracer.get_submodule(node.target)
elif op == "placeholder":
assert "example_value" in node.meta
return node.meta["example_value"]
@@ -1299,23 +1299,23 @@
raise AssertionError(op)


def get_real_value(node, output_graph):
def get_real_value(node, tracer):
"""
Run the actual computation represented by `node` and return the result.
This will execute any dependent nodes in the graph as well.
"""
cache = output_graph.real_value_cache
cache = tracer.real_value_cache
if node in cache:
return cache[node]

op = node.op
args, kwargs = torch.fx.node.map_arg(
(node.args, node.kwargs),
lambda n: get_real_value(n, output_graph),
lambda n: get_real_value(n, tracer),
)

if op == "call_module":
nn_module = output_graph.nn_modules[node.target]
nn_module = tracer.output_graph.nn_modules[node.target]
if not is_lazy_module(nn_module):
nn_module = copy.deepcopy(nn_module)
else:
@@ -1326,7 +1326,7 @@ def get_real_value(node, output_graph):
nn_module = None

try:
real_value = run_node(output_graph, node, args, kwargs, nn_module)
real_value = run_node(tracer, node, args, kwargs, nn_module)
cache[node] = real_value
except RuntimeError as e:
raise TorchRuntimeError() from e
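
The utils.py side of the change threads the tracer, rather than the whole OutputGraph, through run_node and get_real_value, so the real-value cache lives on the tracer and module lookups go through tracer.output_graph. A rough sketch of the evaluate-with-memoization pattern these helpers follow; FakeNode and evaluate are illustrative names, not part of the PyTorch API.

import operator


class FakeNode:
    """Simplified stand-in for an fx.Node: an op plus its inputs."""

    def __init__(self, op, fn=None, args=(), example_value=None):
        self.op = op                      # "placeholder" or "call_function"
        self.fn = fn
        self.args = args
        self.example_value = example_value


def evaluate(node, cache):
    """Return the concrete value for `node`, memoizing into `cache`.

    Mirrors the recursive structure of get_real_value: dependencies are
    evaluated first, and results are cached per node (cf. real_value_cache).
    """
    if node in cache:
        return cache[node]
    if node.op == "placeholder":
        value = node.example_value
    else:
        args = [evaluate(a, cache) for a in node.args]
        value = node.fn(*args)
    cache[node] = value
    return value


x = FakeNode("placeholder", example_value=3)
y = FakeNode("placeholder", example_value=4)
z = FakeNode("call_function", fn=operator.add, args=(x, y))
cache = {}
print(evaluate(z, cache))  # 7
print(z in cache)          # True: memoized, like tracer.real_value_cache

Because the cache is keyed by node and stored per tracer, re-evaluating an already-computed node is a dictionary hit, which is how real_value_cache avoids re-running upstream nodes.
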