6 changes: 6 additions & 0 deletions torchdynamo/output_graph.py
@@ -99,6 +99,8 @@ def __init__(
self.side_effects = SideEffects()
self.code_options = dict(code_options)
self.output_instructions = []
# Node => computed real value (see TensorVariable.get_real_value)
self.real_value_cache = {}

# Not checkpointed
self.compiler_fn = compiler_fn
@@ -146,6 +148,7 @@ def restore_graphstate(self, state):
if "example_value" in node.meta:
del node.meta["example_value"]
self.graph.erase_node(node)
self.real_value_cache.pop(node, None)

def count_calls(self):
return count_calls(self.graph)
@@ -387,6 +390,7 @@ def compile_and_call_fx_graph(self, tx, rv, root):
for node in self.graph.nodes:
if "example_value" in node.meta:
del node.meta["example_value"]
self.real_value_cache.clear()

gm = fx.GraphModule(root, self.graph)
gm.recompile()
@@ -459,6 +463,7 @@ def remove_unused_graphargs(self):
if "example_value" in node.meta:
del node.meta["example_value"]
self.graph.erase_node(node)
self.real_value_cache.pop(node, None)

self.graphargs = [arg for arg in self.graphargs if arg.uses > 0]

@@ -493,6 +498,7 @@ def cleanup(self):
for node in self.graph.nodes:
if "example_value" in node.meta:
del node.meta["example_value"]
self.real_value_cache.clear()

def create_proxy(
self,
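For orientation, here is a standalone sketch (not part of the diff; the toy function f and the placeholder cache values are hypothetical) of the lifecycle the new real_value_cache follows in the hunks above: entries are keyed by fx nodes, popped individually whenever a node is erased, and cleared wholesale whenever all "example_value" metadata is wiped.

import torch
import torch.fx


def f(x):
    return x.relu() + 1


gm = torch.fx.symbolic_trace(f)
real_value_cache = {}

# Pretend real values were computed and cached for every non-output node.
for node in gm.graph.nodes:
    if node.op != "output":
        real_value_cache[node] = None  # stand-in for a computed tensor

# Erasing a node must also drop its cache entry, mirroring
# restore_graphstate() and remove_unused_graphargs() above.
add_node = next(n for n in gm.graph.nodes if n.op == "call_function")
output_node = next(n for n in gm.graph.nodes if n.op == "output")
output_node.replace_input_with(add_node, add_node.args[0])  # detach the add first
gm.graph.erase_node(add_node)
real_value_cache.pop(add_node, None)

# A full cleanup drops everything at once, mirroring cleanup() and
# compile_and_call_fx_graph() above.
real_value_cache.clear()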
2 changes: 1 addition & 1 deletion torchdynamo/variables/functions.py
@@ -269,7 +269,7 @@ def call_function(
def invoke_and_store_as_constant(tx, fn, name, options, args, kwargs):
def convert(x):
if isinstance(x, variables.TensorVariable):
return x.proxy.node.meta["example_value"]
return x.get_real_value()
return x.as_python_constant()

args = [convert(x) for x in args]
1 change: 0 additions & 1 deletion torchdynamo/variables/nn_module.py
@@ -204,7 +204,6 @@ def record_nn_module_stack():
*proxy_args_kwargs(args, kwargs),
current_tx=tx,
),
nnmodule=mod,
**options,
)
else:
204 changes: 124 additions & 80 deletions torchdynamo/variables/tensor.py
@@ -1,4 +1,3 @@
import contextlib
import copy
import functools
import itertools
@@ -46,6 +45,113 @@
from .lists import SizeVariable


def _run_node(node, args, kwargs, nnmodule):
op = node.op
if op == "call_function":
return node.target(*args, **kwargs)
elif op == "call_method":
return getattr(args[0], node.target)(*args[1:], **kwargs)
elif op == "call_module":
assert nnmodule is not None
return nnmodule(*args, **kwargs)
assert False, op


def _get_real_value(node, output_graph):
"""
Run the actual computation represented by `node` and return the result.
This will execute any dependent nodes in the graph as well.
"""
cache = output_graph.real_value_cache
if node in cache:
return cache[node]

op = node.op
args, kwargs = torch.fx.node.map_arg(
(node.args, node.kwargs),
lambda n: _get_real_value(n, output_graph),
)

if op == "call_module":
nn_module = output_graph.nn_modules[node.target]
if not is_lazy_module(nn_module):
nn_module = copy.deepcopy(nn_module)
else:
# In the case of a lazy module, we want to run
# the pre-hooks which initialize it
nn_module(*args, **kwargs)
else:
nn_module = None

try:
real_value = _run_node(node, args, kwargs, nn_module)
cache[node] = real_value
except RuntimeError as e:
raise TorchRuntimeError() from e
return real_value


def _get_fake_value(node, tx):
"""
Run the computation represented by `node` using fake tensors and return the result.
"""
op = node.op
fake_wrapper = functools.partial(wrap_to_fake_tensor, fake_mode=tx.fake_mode)
from ..utils import wrap_fake_exception

def visit(n: torch.fx.Node):
return n.meta["example_value"]

args, kwargs = torch.fx.node.map_arg((node.args, node.kwargs), visit)
args = tree_map(fake_wrapper, args)
kwargs = tree_map(fake_wrapper, kwargs)

nnmodule = None
if op == "call_module":
nnmodule = tx.output.nn_modules[node.target]

if not is_lazy_module(nnmodule):
nnmodule = deepcopy_to_fake_tensor(nnmodule, tx.fake_mode)

def context():
if hasattr(py_dispatch, "enable_torch_dispatch_mode"):
return py_dispatch.enable_torch_dispatch_mode(tx.fake_mode)
else:
return tx.fake_mode

if op == "call_module" and is_lazy_module(nnmodule):
assert nnmodule is not None
# In the case of a lazy module, we want to run
# the pre-hooks which initialize it
nnmodule(*args, **kwargs)
try:
with context():
return wrap_fake_exception(lambda: _run_node(node, args, kwargs, nnmodule))
except Unsupported:
raise
except RuntimeError as e:
if isinstance(e, DataDependentOutputException):
if config.capture_scalar_outputs and node.target == "item":
return torch.zeros(size=(), dtype=args[0].dtype).item()
else:
unimplemented(f"data dependent operator: {e.func}")
elif isinstance(e, DynamicOutputShapeException):
unimplemented(f"dynamic shape operator: {e.func}")
else:
raise TorchRuntimeError() from e


def _clone_input(value):
if isinstance(value, torch.Tensor):
use_fake_tensors = fake_tensors_available and config.fake_tensor_propagation
# tensor subclasses will not be converted to FakeTensors and need to be cloned
if not use_fake_tensors or not isinstance(value, FakeTensor):
# NB: ensure strides are preserved
value = clone_input(value)

return value


class TensorVariable(VariableTracker):
"""A torch.Tensor input or an intermediate value in the FX graph"""

@@ -61,27 +167,18 @@ class TensorVariable(VariableTracker):
"is_contiguous",
]

@staticmethod
def propagate_args_kwargs(node):
def visit(n: torch.fx.Node):
return n.meta["example_value"]

return torch.fx.node.map_arg((node.args, node.kwargs), visit)
def get_real_value(self):
"""
Get the actual value represented by this variable if computation is run
using the user-provided inputs.

@staticmethod
def run_proxy(proxy, args, kwargs, nnmodule):
op = proxy.node.op
if op == "call_function":
return proxy.node.target(*args, **kwargs)
elif op == "call_method":
return getattr(args[0], proxy.node.target)(*args[1:], **kwargs)
elif op == "call_module":
assert nnmodule is not None
return nnmodule(*args, **kwargs)
assert False, op
NOTE: this runs actual tensor computation and may be
slow and memory-intensive.
"""
return _get_real_value(self.proxy.node, self.proxy.tracer)

@classmethod
def create(cls, tx, proxy, example_value=None, nnmodule=None, **options):
def create(cls, tx, proxy, example_value=None, **options):
if "guards" in options and options["guards"] is not None:
tx.output.guards.update(options["guards"])

@@ -92,82 +189,29 @@ def create(cls, tx, proxy, example_value=None, nnmodule=None, **options):
return cls(proxy, **options)

use_fake_tensors = fake_tensors_available and config.fake_tensor_propagation
if use_fake_tensors:
fake_wrapper = functools.partial(
wrap_to_fake_tensor, fake_mode=tx.fake_mode
)
# python errors if the import isnt here
from ..utils import wrap_fake_exception
else:

def wrap_fake_exception(func):
return func()

args = kwargs = None
initial_example_value = example_value

with preserve_rng_state():
if example_value is None:
op = proxy.node.op
args, kwargs = cls.propagate_args_kwargs(proxy.node)
if use_fake_tensors:
args = tree_map(fake_wrapper, args)
kwargs = tree_map(fake_wrapper, kwargs)
if op == "call_module" and not is_lazy_module(nnmodule):
nnmodule = deepcopy_to_fake_tensor(nnmodule, tx.fake_mode)

def context():
if hasattr(py_dispatch, "enable_torch_dispatch_mode"):
return py_dispatch.enable_torch_dispatch_mode(tx.fake_mode)
else:
return tx.fake_mode

example_value = _get_fake_value(proxy.node, tx)
else:
context = contextlib.nullcontext
if op == "call_module" and not is_lazy_module(nnmodule):
nnmodule = copy.deepcopy(nnmodule)

if op == "call_module" and is_lazy_module(nnmodule):
assert nnmodule is not None
# In the case of a lazy module, we want to run
# the pre-hooks which initialize it
example_value = nnmodule(*args, **kwargs)
try:
with context():
example_value = wrap_fake_exception(
lambda: cls.run_proxy(proxy, args, kwargs, nnmodule)
)
except Unsupported:
raise
except RuntimeError as e:
if use_fake_tensors and isinstance(e, DataDependentOutputException):
if (
config.capture_scalar_outputs
and proxy.node.target == "item"
):
example_value = torch.zeros(
size=(), dtype=args[0].dtype
).item()
else:
unimplemented(f"data dependent operator: {e.func}")
elif use_fake_tensors and isinstance(
e, DynamicOutputShapeException
):
unimplemented(f"dynamic shape operator: {e.func}")
else:
raise TorchRuntimeError() from e
example_value = _get_real_value(proxy.node, tx.output)

else:
proxy.tracer.real_value_cache[proxy.node] = _clone_input(example_value)
if use_fake_tensors:
fake_wrapper = functools.partial(
wrap_to_fake_tensor, fake_mode=tx.fake_mode
)
example_value = fake_wrapper(example_value)

if isinstance(example_value, torch.Tensor):
is_parameter = isinstance(example_value, torch.nn.Parameter)
parameter_value = initial_example_value if is_parameter else None

# tensor subclasses will not be converted to FakeTensors and need to be cloned
if not use_fake_tensors or not isinstance(example_value, FakeTensor):
# NB: ensure strides are preserved
example_value = clone_input(example_value)
example_value = _clone_input(example_value)
proxy.node.meta["example_value"] = example_value
specialized_props = cls.specialize(example_value)
if use_fake_tensors and isinstance(example_value, FakeTensor):
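To make the recursion in _get_real_value concrete, here is a standalone sketch (not part of the diff; evaluate, env, and the toy function f are hypothetical names) that resolves an fx node's argument nodes to real values first, memoizes per node, and then executes the node -- the same shape of traversal the cache-backed helper above performs.

import torch
import torch.fx


def evaluate(node, env, cache):
    # Return the concrete value for `node`, reusing `cache` across repeated nodes.
    if node in cache:
        return cache[node]
    # Resolve every fx.Node appearing in args/kwargs to its real value first.
    args, kwargs = torch.fx.node.map_arg(
        (node.args, node.kwargs), lambda n: evaluate(n, env, cache)
    )
    if node.op == "placeholder":
        value = env[node.target]
    elif node.op == "call_function":
        value = node.target(*args, **kwargs)
    elif node.op == "call_method":
        value = getattr(args[0], node.target)(*args[1:], **kwargs)
    else:
        raise NotImplementedError(node.op)
    cache[node] = value
    return value


def f(x):
    y = x + 1
    return y * y


graph = torch.fx.symbolic_trace(f).graph
output_node = next(n for n in graph.nodes if n.op == "output")
cache = {}
# The output node's single argument is the node that produces the return value.
result = evaluate(output_node.args[0], {"x": torch.ones(2)}, cache)
print(result)  # tensor([4., 4.])

The real helper additionally deep-copies (or pre-initializes) call_module targets and translates RuntimeError into TorchRuntimeError, which this sketch omits.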
2 changes: 1 addition & 1 deletion torchdynamo/variables/torch.py
@@ -551,7 +551,7 @@ def call_function(

def unwrap_real(arg):
if isinstance(arg, TensorVariable):
return arg.as_proxy().node.meta["example_value"]
return arg.get_real_value()
if isinstance(arg, UserFunctionVariable):
return arg.fn
if arg.has_unpack_var_sequence(tx):