[reland][export] make aot_export_module uses dynamo's fake_mode #114009

Closed
81 changes: 81 additions & 0 deletions test/export/test_export.py
@@ -1612,6 +1612,87 @@ def forward(self, x):
inp = (torch.randn(2, 8),)
ep = export(M(), inp) # This errors because dynamo adds an extra input

def test_export_with_fake_tensor_inputs(self):
fake_mode = torch._subclasses.fake_tensor.FakeTensorMode()

class Model(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.linear = torch.nn.Linear(2, 2)

def forward(self, x):
out = self.linear(x)
return out

# Put the inputs on a device
with fake_mode, torch.device('meta'):
x = torch.rand(5, 2, 2)
model = Model()

def check_device_and_fake_mode():
exported_program = torch.export.export(model, (x,))
export_res = exported_program(x)
exp_res = model(x)
all_meta_val = [node.meta["val"] for node in exported_program.graph_module.graph.nodes if 'val' in node.meta]
self.assertTrue(export_res.size() == exp_res.size())
self.assertTrue(all(val.device == x.device for val in all_meta_val))
self.assertTrue(all(val.fake_mode is all_meta_val[0].fake_mode for val in all_meta_val))

check_device_and_fake_mode()

def test_export_with_fake_tensor_inputs_on_cuda_devices(self):
fake_mode = torch._subclasses.fake_tensor.FakeTensorMode()

class Model(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.linear = torch.nn.Linear(2, 2)

def forward(self, x):
out = self.linear(x)
return out

# Put the inputs on a device
with fake_mode, torch.device('meta'):
x = torch.rand(5, 2, 2)
model = Model()

# Manually set the fake_device of the fake tensors.
x.fake_device = torch.device('cuda:0')
for n, p in model.named_parameters():
p.fake_device = torch.device('cuda:0')

# Need to set requires_grad of all tensors to False, because a fake tensor on a CUDA device
# doesn't work well with aot_autograd right now: some logic fails the getDeviceGuardImpl
# check in InputMetadata.
x.requires_grad = False
for n, p in model.named_parameters():
p.requires_grad = False


def check_device_and_fake_mode():
exported_program = torch.export.export(model, (x,))
export_res = exported_program(x)
exp_res = model(x)
all_meta_val = [node.meta["val"] for node in exported_program.graph_module.graph.nodes if 'val' in node.meta]
self.assertTrue(export_res.size() == exp_res.size())
self.assertTrue(all(val.device == x.device for val in all_meta_val))
self.assertTrue(all(val.fake_mode is all_meta_val[0].fake_mode for val in all_meta_val))

check_device_and_fake_mode()
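# Illustrative sketch, not part of this PR's diff: the CUDA test above relies on a FakeTensor
# created on "meta" reporting whatever fake_device it is later assigned, without touching a
# real GPU. The helper name below is hypothetical.
def _sketch_fake_device_override():
    with torch._subclasses.fake_tensor.FakeTensorMode(), torch.device('meta'):
        t = torch.rand(2, 2)
    t.fake_device = torch.device('cuda:0')  # no real CUDA allocation happens
    assert t.device == torch.device('cuda:0')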


def test_export_graph_with_no_inputs(self):
# We saw this pattern when users want to export
# a graph that initializes the state of a model.
def f():
return torch.randn(3, 4), torch.randn(3, 4)

ep = torch.export.export(f, ())
a, b = ep()
self.assertEqual(a.size(), torch.Size([3, 4]))
self.assertEqual(b.size(), torch.Size([3, 4]))


if __name__ == '__main__':
run_tests()
71 changes: 48 additions & 23 deletions torch/_export/__init__.py
@@ -1,5 +1,6 @@
import copy
import dataclasses
import functools
import io
import json
import pathlib
@@ -54,7 +55,7 @@
)
from torch.fx import traceback as fx_traceback
from torch.fx._compatibility import compatibility
from torch.fx.experimental.proxy_tensor import make_fx
from torch.fx.experimental.proxy_tensor import make_fx, maybe_disable_fake_tensor_mode
from torch.fx.experimental.symbolic_shapes import (
ConstraintViolationError,
GuardOnDataDependentSymNode,
@@ -376,7 +377,7 @@ def _train(self, mode: bool = True):
def _eval(self, mode: bool = True):
raise NotImplementedError("Calling eval() is not supported yet.")

_, _, fake_mode = _convert_input_to_fake(m, args, kwargs)
_, _, _, fake_mode = _convert_input_to_fake(m, args, kwargs)

m.meta["inline_constraints"] = {
k: v
@@ -397,15 +398,11 @@ def _eval(self, mode: bool = True):


def _convert_input_to_fake(gm, args, kwargs):
fake_inps: List[torch.Tensor] = []
fake_mode = FakeTensorMode(
allow_fallback_kernels=False,
allow_non_fake_inputs=True,
shape_env=ShapeEnv(
assume_static_by_default=True,
),
)
if len(args) == 0 and len(kwargs) == 0 and len(dict(gm.named_parameters())) == 0 and len(dict(gm.named_buffers())) == 0:
return [], {}, {}, None

fake_inps: List[torch.Tensor] = []
fake_mode = None
for node in gm.graph.nodes:
if node.op == "placeholder" and "val" in node.meta:
fake_val = node.meta["val"]
@@ -415,6 +412,8 @@ def _convert_input_to_fake(gm, args, kwargs):
if detected_fake_mode := detect_fake_mode(fake_inps):
fake_mode = detected_fake_mode

assert fake_mode is not None, "Cannot find fake_mode attached to the graph's placeholders."

count = 0

def convert_to_fake(x):
@@ -426,7 +425,11 @@ def convert_to_fake(x):
fake_args = pytree.tree_map_only(torch.Tensor, convert_to_fake, args)
# TODO properly use the cached fake tensor
fake_kwargs = pytree.tree_map_only(torch.Tensor, fake_mode.from_tensor, kwargs)
return fake_args, fake_kwargs, fake_mode
fake_params_buffers = pytree.tree_map_only(torch.Tensor,
fake_mode.from_tensor,
{**dict(gm.named_parameters(remove_duplicate=False)),
**dict(gm.named_buffers(remove_duplicate=False))})
return fake_args, fake_kwargs, fake_params_buffers, fake_mode
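# Illustrative sketch, not part of this PR's diff: _convert_input_to_fake assumes all fake
# tensors hanging off dynamo's placeholders were created under a single FakeTensorMode,
# which detect_fake_mode can then recover. The helper name below is hypothetical.
def _sketch_detect_shared_fake_mode():
    fake_mode = torch._subclasses.fake_tensor.FakeTensorMode()
    with fake_mode:
        a = torch.rand(2, 2)
        b = torch.rand(3)
    # Tensors created under one mode all report that mode, so detection is unambiguous.
    assert detect_fake_mode([a, b]) is fake_mode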


def _replace_param_buffer_names(param_buffer_table, sig):
@@ -554,7 +557,17 @@ def export(
preserve_module_call_signature=preserve_module_call_signature,
)

def _disable_prexisiting_fake_mode(fn):

@functools.wraps(fn)
def wrapper(*args, **kwargs):
with maybe_disable_fake_tensor_mode():
return fn(*args, **kwargs)

return wrapper
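# Illustrative sketch, not part of this PR's diff: maybe_disable_fake_tensor_mode temporarily
# pops any active fake tensor mode, so a function wrapped with _disable_prexisiting_fake_mode
# sees real tensors even when its caller is already tracing under a FakeTensorMode. The
# helper name below is hypothetical.
def _sketch_disable_preexisting_fake_mode():
    with torch._subclasses.fake_tensor.FakeTensorMode():
        with maybe_disable_fake_tensor_mode():
            t = torch.ones(2)  # created while the outer fake mode is suspended
        assert not isinstance(t, torch._subclasses.fake_tensor.FakeTensor)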


@_disable_prexisiting_fake_mode
def _export(
f: Callable,
args: Tuple[Any, ...],
@@ -601,7 +614,8 @@ def _export(
for name, buffer in gm_torch_level.named_buffers(remove_duplicate=False):
params_buffers[name] = buffer

fake_args, fake_kwargs, fake_mode = _convert_input_to_fake(gm_torch_level, args, kwargs)
# We detect the fake_mode by looking at gm_torch_level's placeholders; this is the fake_mode created in dynamo.
fake_args, fake_kwargs, fake_params_buffers, dynamo_fake_mode = _convert_input_to_fake(gm_torch_level, args, kwargs)

# First, we want to pass through the graph to try populating
# val field for getattr if there is anything missing.
@@ -612,7 +626,10 @@ def _export(
attr = getattr(gm_torch_level, node.target)
# Checks if it is not a HigherOrderOp branch or a module
if not isinstance(attr, torch.nn.Module):
node.meta["val"] = fake_mode.from_tensor(attr, static_shapes=True)
assert dynamo_fake_mode is not None, (
"Cannot find dynamo_fake_mode. This could be due to the exported graph module have no placeholders."
)
node.meta["val"] = dynamo_fake_mode.from_tensor(attr, static_shapes=True)

# When aot_export lifts the params, we lose the nn_module_stack
# and source_fn from the param nodes as they are treated as fresh inputs
@@ -686,11 +703,16 @@ def _export(

# Note: aot_export_module doesn't accept kwargs, so we reorder the kwargs into an OrderedDict
# to follow the order in orig_args and correctly call gm_torch_level
gm, graph_signature = aot_export_module(
gm_torch_level,
(*fake_args, *_reorder_kwargs_by_names(orig_args, fake_args, fake_kwargs).values()),
trace_joint=False
)

# This _reparametrize_module makes sure the inputs and gm_torch_level's params/buffers share the same fake_mode;
# otherwise aot_export_module errors out because it sees a mix of fake_modes.
# We also want aot_export_module to use dynamo's fake_tensor mode so the pipeline stays easy to reason about.
with torch.nn.utils.stateless._reparametrize_module(gm_torch_level, fake_params_buffers):
gm, graph_signature = aot_export_module(
gm_torch_level,
(*fake_args, *_reorder_kwargs_by_names(orig_args, fake_args, fake_kwargs).values()),
trace_joint=False
)
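# Illustrative sketch, not part of this PR's diff: _reparametrize_module temporarily swaps a
# module's real parameters/buffers for the provided (here fake) tensors, which is how the
# traced module and dynamo's fake inputs end up under one fake_mode. The helper name below
# is hypothetical.
def _sketch_reparametrize_with_fakes(mod: torch.nn.Module, fake_mode) -> None:
    fakes = {
        name: fake_mode.from_tensor(t)
        for name, t in {**dict(mod.named_parameters()), **dict(mod.named_buffers())}.items()
    }
    with torch.nn.utils.stateless._reparametrize_module(mod, fakes):
        # Inside the context, mod's parameters are the fake tensors from `fakes`.
        assert all(isinstance(p, torch._subclasses.fake_tensor.FakeTensor) for p in mod.parameters())
    # On exit, the original real parameters and buffers are restored.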

def to_str_list(sig_component: List[Any]):
return [str(v) for v in sig_component]
@@ -719,11 +741,14 @@ def to_str_dict(sig_component: Dict[Any, Any]):

# The unbacked symint symbols are updated in aot_export
# so we serialize them here instead of inside dynamo
gm.meta["inline_constraints"] = {
k: v
for k, v in fake_mode.shape_env.runtime_var_to_range.items()
if re.match(r"^[if]\d+$", str(k))
}

# dynamo_fake_mode can be None if there are no placeholders in gm_torch_level
if dynamo_fake_mode:
gm.meta["inline_constraints"] = {
k: v
for k, v in dynamo_fake_mode.shape_env.runtime_var_to_range.items()
if re.match(r"^[if]\d+$", str(k))
}
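# Illustrative check, not part of this PR's diff: the regex above keeps only unbacked
# symbols such as "i0" or "f12" and skips backed shape symbols such as "s0".
assert re.match(r"^[if]\d+$", "i0") and re.match(r"^[if]\d+$", "f12")
assert re.match(r"^[if]\d+$", "s0") is None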

# After aot_export, set the param/buffer metadata back into placeholders
# Technically, users can still construct this data from param names