
Commit

[reland][export] make aot_export_module use dynamo's fake_mode (#114009)

Retry landing #113681

Fixes #110100.

Pull Request resolved: #114009
Approved by: https://github.com/angelayi
ydwu4 authored and pytorchmergebot committed Nov 18, 2023
1 parent 310e306 commit 46542f6
Showing 2 changed files with 129 additions and 23 deletions.
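
As context for the diff below, here is a minimal sketch (not part of this commit; the module name, shapes, and final print are illustrative) of the workflow the new test exercises: building a module and its inputs under a FakeTensorMode and exporting them directly.

import torch

# Tensors created under this mode carry no real storage.
fake_mode = torch._subclasses.fake_tensor.FakeTensorMode()

class Tiny(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(2, 2)

    def forward(self, x):
        return self.linear(x)

# Both the parameters and the input become fake (meta-device) tensors.
with fake_mode, torch.device("meta"):
    model = Tiny()
    x = torch.rand(5, 2, 2)

# With this change, export reuses the fake_mode that dynamo attaches to the
# graph placeholders instead of creating a fresh one, so tracing succeeds.
ep = torch.export.export(model, (x,))
print(ep.graph_module.graph)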
81 changes: 81 additions & 0 deletions test/export/test_export.py
@@ -1612,6 +1612,87 @@ def forward(self, x):
inp = (torch.randn(2, 8),)
ep = export(M(), inp) # This errors because dynamo adds an extra input

def test_export_with_fake_tensor_inputs(self):
fake_mode = torch._subclasses.fake_tensor.FakeTensorMode()

class Model(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.linear = torch.nn.Linear(2, 2)

def forward(self, x):
out = self.linear(x)
return out

# Put the inputs on a device
with fake_mode, torch.device('meta'):
x = torch.rand(5, 2, 2)
model = Model()

def check_device_and_fake_mode():
exported_program = torch.export.export(model, (x,))
export_res = exported_program(x)
exp_res = model(x)
all_meta_val = [node.meta["val"] for node in exported_program.graph_module.graph.nodes if 'val' in node.meta]
self.assertTrue(export_res.size() == exp_res.size())
self.assertTrue(all(val.device == x.device for val in all_meta_val))
self.assertTrue(all(val.fake_mode is all_meta_val[0].fake_mode for val in all_meta_val))

check_device_and_fake_mode()

def test_export_with_fake_tensor_inputs_on_cuda_devices(self):
fake_mode = torch._subclasses.fake_tensor.FakeTensorMode()

class Model(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.linear = torch.nn.Linear(2, 2)

def forward(self, x):
out = self.linear(x)
return out

# Put the inputs on a device
with fake_mode, torch.device('meta'):
x = torch.rand(5, 2, 2)
model = Model()

# Manually set the fake_device of the fake tensors.
x.fake_device = torch.device('cuda:0')
for n, p in model.named_parameters():
p.fake_device = torch.device('cuda:0')

# Need to set requires_grad of all the tensors to False, because a fake_tensor with a CUDA device
# doesn't quite work with aot_autograd right now: some logic fails
# the check in the call to getDeviceGuardImpl in InputMetadata.
x.requires_grad = False
for n, p in model.named_parameters():
p.requires_grad = False


def check_device_and_fake_mode():
exported_program = torch.export.export(model, (x,))
export_res = exported_program(x)
exp_res = model(x)
all_meta_val = [node.meta["val"] for node in exported_program.graph_module.graph.nodes if 'val' in node.meta]
self.assertTrue(export_res.size() == exp_res.size())
self.assertTrue(all(val.device == x.device for val in all_meta_val))
self.assertTrue(all(val.fake_mode is all_meta_val[0].fake_mode for val in all_meta_val))

check_device_and_fake_mode()


def test_export_graph_with_no_inputs(self):
# We saw this pattern when users want to export
# a graph that initializes the states of a model.
def f():
return torch.randn(3, 4), torch.randn(3, 4)

ep = torch.export.export(f, ())
a, b = ep()
self.assertEqual(a.size(), torch.Size([3, 4]))
self.assertEqual(b.size(), torch.Size([3, 4]))


if __name__ == '__main__':
run_tests()
71 changes: 48 additions & 23 deletions torch/_export/__init__.py
@@ -1,5 +1,6 @@
import copy
import dataclasses
import functools
import io
import json
import pathlib
@@ -54,7 +55,7 @@
)
from torch.fx import traceback as fx_traceback
from torch.fx._compatibility import compatibility
from torch.fx.experimental.proxy_tensor import make_fx
from torch.fx.experimental.proxy_tensor import make_fx, maybe_disable_fake_tensor_mode
from torch.fx.experimental.symbolic_shapes import (
ConstraintViolationError,
GuardOnDataDependentSymNode,
@@ -376,7 +377,7 @@ def _train(self, mode: bool = True):
def _eval(self, mode: bool = True):
raise NotImplementedError("Calling eval() is not supported yet.")

_, _, fake_mode = _convert_input_to_fake(m, args, kwargs)
_, _, _, fake_mode = _convert_input_to_fake(m, args, kwargs)

m.meta["inline_constraints"] = {
k: v
@@ -397,15 +398,11 @@ def _eval(self, mode: bool = True):


def _convert_input_to_fake(gm, args, kwargs):
fake_inps: List[torch.Tensor] = []
fake_mode = FakeTensorMode(
allow_fallback_kernels=False,
allow_non_fake_inputs=True,
shape_env=ShapeEnv(
assume_static_by_default=True,
),
)
if len(args) == 0 and len(kwargs) == 0 and len(dict(gm.named_parameters())) == 0 and len(dict(gm.named_buffers())) == 0:
return [], {}, {}, None

fake_inps: List[torch.Tensor] = []
fake_mode = None
for node in gm.graph.nodes:
if node.op == "placeholder" and "val" in node.meta:
fake_val = node.meta["val"]
@@ -415,6 +412,8 @@ def _convert_input_to_fake(gm, args, kwargs):
if detected_fake_mode := detect_fake_mode(fake_inps):
fake_mode = detected_fake_mode

assert fake_mode is not None, "Cannot find fake_mode attached to the graph's placeholders."

count = 0

def convert_to_fake(x):
@@ -426,7 +425,11 @@ def convert_to_fake(x):
fake_args = pytree.tree_map_only(torch.Tensor, convert_to_fake, args)
# TODO properly use the cached fake tensor
fake_kwargs = pytree.tree_map_only(torch.Tensor, fake_mode.from_tensor, kwargs)
return fake_args, fake_kwargs, fake_mode
fake_params_buffers = pytree.tree_map_only(torch.Tensor,
fake_mode.from_tensor,
{**dict(gm.named_parameters(remove_duplicate=False)),
**dict(gm.named_buffers(remove_duplicate=False))})
return fake_args, fake_kwargs, fake_params_buffers, fake_mode


def _replace_param_buffer_names(param_buffer_table, sig):
@@ -554,7 +557,17 @@ def export(
preserve_module_call_signature=preserve_module_call_signature,
)

def _disable_prexisiting_fake_mode(fn):

@functools.wraps(fn)
def wrapper(*args, **kwargs):
with maybe_disable_fake_tensor_mode():
return fn(*args, **kwargs)

return wrapper


@_disable_prexisiting_fake_mode
def _export(
f: Callable,
args: Tuple[Any, ...],
@@ -601,7 +614,8 @@ def _export(
for name, buffer in gm_torch_level.named_buffers(remove_duplicate=False):
params_buffers[name] = buffer

fake_args, fake_kwargs, fake_mode = _convert_input_to_fake(gm_torch_level, args, kwargs)
# We detect the fake_mode by looking at gm_torch_level's placeholders; this is the fake_mode created in dynamo.
fake_args, fake_kwargs, fake_params_buffers, dynamo_fake_mode = _convert_input_to_fake(gm_torch_level, args, kwargs)

# First, we want to pass through the graph to try populating
# val field for getattr if there is anything missing.
@@ -612,7 +626,10 @@ def _export(
attr = getattr(gm_torch_level, node.target)
# Checks if it is not a HigherOrderOp branch or a module
if not isinstance(attr, torch.nn.Module):
node.meta["val"] = fake_mode.from_tensor(attr, static_shapes=True)
assert dynamo_fake_mode is not None, (
"Cannot find dynamo_fake_mode. This could be due to the exported graph module have no placeholders."
)
node.meta["val"] = dynamo_fake_mode.from_tensor(attr, static_shapes=True)

# When aot_export lifts the params, we lose the nn_module_stack
# and source_fn from the param nodes as they are treated as fresh inputs
@@ -686,11 +703,16 @@ def _export(

# Note: aot_export_module doesn't accept kwargs, we'd like to reorder the kwargs as an OrderedDict
# to follow the order in orig_args and correctly call gm_torch_level
gm, graph_signature = aot_export_module(
gm_torch_level,
(*fake_args, *_reorder_kwargs_by_names(orig_args, fake_args, fake_kwargs).values()),
trace_joint=False
)

# This _reparametrize_module makes sure the inputs and gm_torch_level's params/buffers have the same fake_mode;
# otherwise aot_export_module will error out because it sees a mix of fake_modes.
# We also want aot_export_module to use dynamo's fake tensor mode, to keep the pipeline easy to reason about.
with torch.nn.utils.stateless._reparametrize_module(gm_torch_level, fake_params_buffers):
gm, graph_signature = aot_export_module(
gm_torch_level,
(*fake_args, *_reorder_kwargs_by_names(orig_args, fake_args, fake_kwargs).values()),
trace_joint=False
)

def to_str_list(sig_component: List[Any]):
return [str(v) for v in sig_component]
@@ -719,11 +741,14 @@ def to_str_dict(sig_component: Dict[Any, Any]):

# The unbacked symint symbols are updated in aot_export
# so we serialize them here instead of inside dynamo
gm.meta["inline_constraints"] = {
k: v
for k, v in fake_mode.shape_env.runtime_var_to_range.items()
if re.match(r"^[if]\d+$", str(k))
}

# dynamo_fake_mode can be None if there's no placeholder in gm_torch_level
if dynamo_fake_mode:
gm.meta["inline_constraints"] = {
k: v
for k, v in dynamo_fake_mode.shape_env.runtime_var_to_range.items()
if re.match(r"^[if]\d+$", str(k))
}

# After aot_export, set the param/buffer metadata back into placeholders
# Technically, users can still construct this data from param names
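
The heart of the change is the _reparametrize_module block above: before aot_export_module runs, the module's parameters and buffers are swapped for fake tensors created by the same fake_mode that dynamo attached to the graph placeholders, so aot_export_module never sees a mix of fake modes. Below is a standalone sketch of that pattern (not part of this commit; the module and shapes are illustrative, and _reparametrize_module is a private utility whose usage here is assumed to match the call in the diff).

import torch
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.utils import _pytree as pytree

fake_mode = FakeTensorMode(allow_non_fake_inputs=True)
mod = torch.nn.Linear(4, 4)  # a module with real parameters

# Fakeify the params/buffers with the same fake_mode that will fakeify the inputs.
fake_params_buffers = pytree.tree_map_only(
    torch.Tensor,
    fake_mode.from_tensor,
    {**dict(mod.named_parameters(remove_duplicate=False)),
     **dict(mod.named_buffers(remove_duplicate=False))},
)
fake_x = fake_mode.from_tensor(torch.randn(2, 4))

# Inside this context the module runs with the fake params/buffers, so the
# forward pass sees a single, consistent fake_mode end to end.
with torch.nn.utils.stateless._reparametrize_module(mod, fake_params_buffers):
    out = mod(fake_x)
    print(type(out), out.shape)  # a FakeTensor of shape (2, 4)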
