From bc1f00d48e495ee7282bdcce086bd7a5298ccbba Mon Sep 17 00:00:00 2001
From: Michael Suo
Date: Wed, 5 Oct 2022 16:11:47 -0700
Subject: [PATCH] [dynamo] Introduce `get_real_value` API to TensorVariable

Right now, example_value is doing two jobs:
- We use it to propagate metadata (e.g. return type, shapes, etc.) throughout
  the graph.
- We use it to satisfy queries for the actual value (e.g. torch.cond,
  `assume_constant_result`).

This is further complicated by the fact that we have two modes: one where
`example_value` is a fake tensor, and one where it is a real tensor (this is
the `fake_tensor_propagation` config flag). This leads to scenarios where we
don't support every combination of job + mode, e.g. if
`fake_tensor_propagation=False`, `assume_constant_result` is broken. This is
made worse by the fact that "fake tensor mode" is the default and is required
if you want dynamic shapes to work.

So, this PR introduces a `get_real_value` API that just runs the graph up to
`node` in order to get a concrete value. This API is orthogonal to
`example_value`, so it doesn't care about `fake_tensor_propagation`.

When `fake_tensor_propagation=True`: `example_value` is a fake tensor, so you
must use the `get_real_value` API to get a concrete value. This will be the
only configuration in the future.

When `fake_tensor_propagation=False`: `example_value` and `get_real_value`
produce the same value. This is redundant, but we will be removing this config
soon.
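
As an illustration of the intended usage, a call site that needs a concrete
tensor (e.g. `assume_constant_result` via `invoke_and_store_as_constant`)
switches from reading `example_value` off the node to calling the new API.
Roughly (a sketch of the call-site pattern; see the functions.py hunk below
for the actual change):

    def convert(x):
        if isinstance(x, variables.TensorVariable):
            # Old: x.proxy.node.meta["example_value"], which is a FakeTensor
            # when fake_tensor_propagation=True, so not usable as a constant.
            # New: run the traced graph up to x's node on the real inputs.
            return x.get_real_value()
        return x.as_python_constant()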

To support this, I introduce a cache for computed real values, to memoize the
work involved when real values are requested repeatedly. I attached this state
to `OutputGraph` because it appears to be what has historically managed
`example_value` lifetimes, though that choice is open to debate.
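
The cache is a plain dict keyed on the FX node; the sketch below is a
simplified view of `_get_real_value` in this patch (module handling and error
handling omitted, and the nnmodule argument of `_run_node` elided as None):

    def _get_real_value(node, output_graph):
        cache = output_graph.real_value_cache
        if node in cache:
            return cache[node]
        # Recursively materialize real values for the node's inputs...
        args, kwargs = torch.fx.node.map_arg(
            (node.args, node.kwargs),
            lambda n: _get_real_value(n, output_graph),
        )
        # ...then execute the node on them and memoize the result.
        real_value = _run_node(node, args, kwargs, None)
        cache[node] = real_value
        return real_value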
+ """ + op = node.op + fake_wrapper = functools.partial(wrap_to_fake_tensor, fake_mode=tx.fake_mode) + from ..utils import wrap_fake_exception + + def visit(n: torch.fx.Node): + return n.meta["example_value"] + + args, kwargs = torch.fx.node.map_arg((node.args, node.kwargs), visit) + args = tree_map(fake_wrapper, args) + kwargs = tree_map(fake_wrapper, kwargs) + + nnmodule = None + if op == "call_module": + nnmodule = tx.output.nn_modules[node.target] + + if not is_lazy_module(nnmodule): + nnmodule = deepcopy_to_fake_tensor(nnmodule, tx.fake_mode) + + def context(): + if hasattr(py_dispatch, "enable_torch_dispatch_mode"): + return py_dispatch.enable_torch_dispatch_mode(tx.fake_mode) + else: + return tx.fake_mode + + if op == "call_module" and is_lazy_module(nnmodule): + assert nnmodule is not None + # In the case of a lazy module, we want to run + # the pre-hooks which initialize it + nnmodule(*args, **kwargs) + try: + with context(): + return wrap_fake_exception(lambda: _run_node(node, args, kwargs, nnmodule)) + except Unsupported: + raise + except RuntimeError as e: + if isinstance(e, DataDependentOutputException): + if config.capture_scalar_outputs and node.target == "item": + return torch.zeros(size=(), dtype=args[0].dtype).item() + else: + unimplemented(f"data dependent operator: {e.func}") + elif isinstance(e, DynamicOutputShapeException): + unimplemented(f"dynamic shape operator: {e.func}") + else: + raise TorchRuntimeError() from e + + +def _clone_input(value): + if isinstance(value, torch.Tensor): + use_fake_tensors = fake_tensors_available and config.fake_tensor_propagation + # tensor subclasses will not be converted to FakeTensors and need to be cloned + if not use_fake_tensors or not isinstance(value, FakeTensor): + # NB: ensure strides are preserved + value = clone_input(value) + + return value + + class TensorVariable(VariableTracker): """A torch.Tensor input or an intermediate value in the FX graph""" @@ -61,27 +167,18 @@ class TensorVariable(VariableTracker): "is_contiguous", ] - @staticmethod - def propagate_args_kwargs(node): - def visit(n: torch.fx.Node): - return n.meta["example_value"] - - return torch.fx.node.map_arg((node.args, node.kwargs), visit) + def get_real_value(self): + """ + Get the actual value represented by this variable if computation is run + using the user-provided inputs. - @staticmethod - def run_proxy(proxy, args, kwargs, nnmodule): - op = proxy.node.op - if op == "call_function": - return proxy.node.target(*args, **kwargs) - elif op == "call_method": - return getattr(args[0], proxy.node.target)(*args[1:], **kwargs) - elif op == "call_module": - assert nnmodule is not None - return nnmodule(*args, **kwargs) - assert False, op + NOTE: this runs actual tensor computation and may be + slow and memory-intensive. 
+ """ + return _get_real_value(self.proxy.node, self.proxy.tracer) @classmethod - def create(cls, tx, proxy, example_value=None, nnmodule=None, **options): + def create(cls, tx, proxy, example_value=None, **options): if "guards" in options and options["guards"] is not None: tx.output.guards.update(options["guards"]) @@ -92,82 +189,29 @@ def create(cls, tx, proxy, example_value=None, nnmodule=None, **options): return cls(proxy, **options) use_fake_tensors = fake_tensors_available and config.fake_tensor_propagation - if use_fake_tensors: - fake_wrapper = functools.partial( - wrap_to_fake_tensor, fake_mode=tx.fake_mode - ) - # python errors if the import isnt here - from ..utils import wrap_fake_exception - else: - def wrap_fake_exception(func): - return func() - - args = kwargs = None initial_example_value = example_value with preserve_rng_state(): if example_value is None: - op = proxy.node.op - args, kwargs = cls.propagate_args_kwargs(proxy.node) if use_fake_tensors: - args = tree_map(fake_wrapper, args) - kwargs = tree_map(fake_wrapper, kwargs) - if op == "call_module" and not is_lazy_module(nnmodule): - nnmodule = deepcopy_to_fake_tensor(nnmodule, tx.fake_mode) - - def context(): - if hasattr(py_dispatch, "enable_torch_dispatch_mode"): - return py_dispatch.enable_torch_dispatch_mode(tx.fake_mode) - else: - return tx.fake_mode - + example_value = _get_fake_value(proxy.node, tx) else: - context = contextlib.nullcontext - if op == "call_module" and not is_lazy_module(nnmodule): - nnmodule = copy.deepcopy(nnmodule) - - if op == "call_module" and is_lazy_module(nnmodule): - assert nnmodule is not None - # In the case of a lazy module, we want to run - # the pre-hooks which initialize it - example_value = nnmodule(*args, **kwargs) - try: - with context(): - example_value = wrap_fake_exception( - lambda: cls.run_proxy(proxy, args, kwargs, nnmodule) - ) - except Unsupported: - raise - except RuntimeError as e: - if use_fake_tensors and isinstance(e, DataDependentOutputException): - if ( - config.capture_scalar_outputs - and proxy.node.target == "item" - ): - example_value = torch.zeros( - size=(), dtype=args[0].dtype - ).item() - else: - unimplemented(f"data dependent operator: {e.func}") - elif use_fake_tensors and isinstance( - e, DynamicOutputShapeException - ): - unimplemented(f"dynamic shape operator: {e.func}") - else: - raise TorchRuntimeError() from e + example_value = _get_real_value(proxy.node, tx.output) + else: + proxy.tracer.real_value_cache[proxy.node] = _clone_input(example_value) if use_fake_tensors: + fake_wrapper = functools.partial( + wrap_to_fake_tensor, fake_mode=tx.fake_mode + ) example_value = fake_wrapper(example_value) if isinstance(example_value, torch.Tensor): is_parameter = isinstance(example_value, torch.nn.Parameter) parameter_value = initial_example_value if is_parameter else None - # tensor subclasses will not be converted to FakeTensors and need to be cloned - if not use_fake_tensors or not isinstance(example_value, FakeTensor): - # NB: ensure strides are preserved - example_value = clone_input(example_value) + example_value = _clone_input(example_value) proxy.node.meta["example_value"] = example_value specialized_props = cls.specialize(example_value) if use_fake_tensors and isinstance(example_value, FakeTensor): diff --git a/torchdynamo/variables/torch.py b/torchdynamo/variables/torch.py index 2b596ab681..a4dc0b78b2 100644 --- a/torchdynamo/variables/torch.py +++ b/torchdynamo/variables/torch.py @@ -551,7 +551,7 @@ def call_function( def unwrap_real(arg): if 
diff --git a/torchdynamo/variables/torch.py b/torchdynamo/variables/torch.py
index 2b596ab681..a4dc0b78b2 100644
--- a/torchdynamo/variables/torch.py
+++ b/torchdynamo/variables/torch.py
@@ -551,7 +551,7 @@ def call_function(
 
         def unwrap_real(arg):
             if isinstance(arg, TensorVariable):
-                return arg.as_proxy().node.meta["example_value"]
+                return arg.get_real_value()
             if isinstance(arg, UserFunctionVariable):
                 return arg.fn
             if arg.has_unpack_var_sequence(tx):