Fix bug in graph partitioner and update graph signature after partitioning. #125133

Closed
wants to merge 1 commit
76 changes: 72 additions & 4 deletions test/export/test_export.py
@@ -92,6 +92,11 @@
"(Tensor x) -> (Tensor)",
tags=torch.Tag.pt2_compliant_tag,
)
torch.library.define(
"testlib::foo_unbacked",
"(Scalar x) -> (Tensor)",
tags=torch.Tag.pt2_compliant_tag,
)


@torch.library.impl("testlib::returns_tensor_symint", "cpu")
@@ -125,6 +130,15 @@ def foo_functional(x):
return a.cos()


@torch.library.impl("testlib::foo_unbacked", "CompositeImplicitAutograd")
def foo_unbacked(x):
if x > 2:
return torch.ones(4, 4)
if x < 6:
return torch.ones(4, 4)
return torch.ones(4, 4)


@dataclass
class Inp:
x: Tensor
@@ -2380,6 +2394,7 @@ def forward(self, x, y):

ep = export(M(), (torch.tensor(1), torch.ones(4, 5)))

# This is because we insert sym_constrain_range in the graph now
if is_non_strict_test(self._testMethodName):
error_msg = "Invalid value range"
else:
@@ -4043,16 +4058,16 @@ def forward(self, b_pred, b_t, x, y):
"""\
def forward(self, b_t, x, y):
submod_3 = self.submod_1
add_1 = torch._higher_order_ops.wrap.wrap_with_set_grad_enabled(True, submod_3, b_t, x, y); submod_3 = b_t = x = y = None
add_1 = torch._higher_order_ops.wrap.wrap_with_set_grad_enabled(True, submod_3, x, b_t, y); submod_3 = x = b_t = y = None
return (add_1,)""",
)

self.assertExpectedInline(
str(exported_program.graph_module.true_graph_0.submod_1.code.strip()),
"""\
def forward(self, b_t, x, y):
sub = torch.ops.aten.sub.Tensor(b_t, 1); b_t = None
add = torch.ops.aten.add.Tensor(sub, x); sub = x = None
def forward(self, x, b_t, y):
sub = torch.ops.aten.sub.Tensor(x, 1); x = None
add = torch.ops.aten.add.Tensor(sub, b_t); sub = b_t = None
add_1 = torch.ops.aten.add.Tensor(add, y); add = y = None
return add_1""",
)
@@ -4552,6 +4567,59 @@ def forward(self, x, y, div="floor"):
self.assertEqual(div_spec.arg.name, "div")
self.assertEqual(div_spec.arg.value, "floor")

def test_unbacked_deferred_runtime_retrace(self):
class Foo(torch.nn.Module):
def forward(self, x, y):
y_sum = y.sin().sum()
with torch.no_grad():
a = x.item()
torch._check_is_size(a)
torch._check(a > 2)
torch._check(a < 6)
unbacked_shape = torch.ops.testlib.foo_unbacked(a)
return y + y_sum + unbacked_shape.sum()

inps = (torch.tensor(4), torch.randn(5, 5))
from torch.export import _trace
ep_pre = _trace._export(Foo(), inps, pre_dispatch=True, strict=False)
self.assertExpectedInline(str(ep_pre.graph_module.submod_1.code).strip(), """\
def forward(self, x):
item = torch.ops.aten.item.default(x); x = None
sym_constrain_range_for_size_default = torch.ops.aten.sym_constrain_range_for_size.default(item)
sym_constrain_range_default = torch.ops.aten.sym_constrain_range.default(item, min = 3, max = 5)
mul = -1 * item
le = mul <= 0; mul = None
_assert_scalar_default = torch.ops.aten._assert_scalar.default(le, "Runtime assertion failed for expression -u1 <= 0 on node 'le'\\nMore context: %mul : [num_users=1] = call_function[target=operator.mul](args = (-1, %item), kwargs = {})\\n%le : [num_users=0] = call_function[target=operator.le](args = (%mul, 0), kwargs = {})"); le = None
mul_1 = -1 * item
lt = mul_1 < -2; mul_1 = None
_assert_scalar_default_1 = torch.ops.aten._assert_scalar.default(lt, "Runtime assertion failed for expression -u1 < -2 on node 'lt'\\nMore context: %mul_1 : [num_users=1] = call_function[target=operator.mul](args = (-1, %item), kwargs = {})\\n%lt : [num_users=0] = call_function[target=operator.lt](args = (%mul_1, -2), kwargs = {})"); lt = None
lt_1 = item < 6
_assert_scalar_default_2 = torch.ops.aten._assert_scalar.default(lt_1, "Runtime assertion failed for expression u1 < 6 on node 'lt_1'\\nMore context: %_assert_scalar_default_1 : [num_users=0] = call_function[target=torch.ops.aten._assert_scalar.default](args = (%lt, Runtime assertion failed for expression -u1 < -2 on node 'lt'\\nMore context: %mul_1 : [num_users=1] = call_function[target=operator.mul](args = (-1, %item), kwargs = {})\\n%lt : [num_users=0] = call_function[target=operator.lt](args = (%mul_1, -2), kwargs = {})), kwargs = {})\\n%lt_1 : [num_users=0] = call_function[target=operator.lt](args = (%item, 6), kwargs = {})"); lt_1 = None
foo_unbacked = torch.ops.testlib.foo_unbacked.default(item); item = None
return foo_unbacked""")
ep_aot = ep_pre.run_decompositions()
self.assertExpectedInline(str(ep_aot.graph_module.code).strip(), """\
def forward(self, x, y):
sin = torch.ops.aten.sin.default(y)
sum_1 = torch.ops.aten.sum.dim_IntList(sin, []); sin = None
_local_scalar_dense = torch.ops.aten._local_scalar_dense.default(x); x = None
sym_constrain_range_for_size = torch.ops.aten.sym_constrain_range_for_size.default(_local_scalar_dense)
sym_constrain_range = torch.ops.aten.sym_constrain_range.default(_local_scalar_dense, min = 3, max = 5)
mul = -1 * _local_scalar_dense
le = mul <= 0; mul = None
_assert_scalar = torch.ops.aten._assert_scalar.default(le, "Runtime assertion failed for expression -u1 <= 0 on node 'le'\\nMore context: %mul : [num_users=1] = call_function[target=operator.mul](args = (-1, %item), kwargs = {})\\n%le : [num_users=0] = call_function[target=operator.le](args = (%mul, 0), kwargs = {})"); le = None
mul_1 = -1 * _local_scalar_dense
lt = mul_1 < -2; mul_1 = None
_assert_scalar_1 = torch.ops.aten._assert_scalar.default(lt, "Runtime assertion failed for expression -u1 < -2 on node 'lt'\\nMore context: %mul_1 : [num_users=1] = call_function[target=operator.mul](args = (-1, %item), kwargs = {})\\n%lt : [num_users=0] = call_function[target=operator.lt](args = (%mul_1, -2), kwargs = {})"); lt = None
lt_1 = _local_scalar_dense < 6; _local_scalar_dense = None
_assert_scalar_2 = torch.ops.aten._assert_scalar.default(lt_1, "Runtime assertion failed for expression u1 < 6 on node 'lt_1'\\nMore context: %_assert_scalar_default_1 : [num_users=0] = call_function[target=torch.ops.aten._assert_scalar.default](args = (%lt, Runtime assertion failed for expression -u1 < -2 on node 'lt'\\nMore context: %mul_1 : [num_users=1] = call_function[target=operator.mul](args = (-1, %item), kwargs = {})\\n%lt : [num_users=0] = call_function[target=operator.lt](args = (%mul_1, -2), kwargs = {})), kwargs = {})\\n%lt_1 : [num_users=0] = call_function[target=operator.lt](args = (%item, 6), kwargs = {})"); lt_1 = None
full = torch.ops.aten.full.default([4, 4], 1, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
add = torch.ops.aten.add.Tensor(y, sum_1); y = sum_1 = None
sum_2 = torch.ops.aten.sum.dim_IntList(full, []); full = None
add_1 = torch.ops.aten.add.Tensor(add, sum_2); add = sum_2 = None
return (add_1,)""")


def test_nested_dynamic_shapes_spec(self):
class Foo(torch.nn.Module):
def forward(self, x):
46 changes: 27 additions & 19 deletions torch/_export/passes/replace_set_grad_with_hop_pass.py
@@ -1,3 +1,4 @@
import contextlib
import copy

import torch
@@ -125,7 +126,9 @@ def _remove_set_grad_and_inline(node: torch.fx.Node):
node_inline_(node)


def _sequential_split_and_maybe_inline_subgraphs(gm: torch.fx.GraphModule):
def _sequential_split_and_maybe_inline_subgraphs(
gm: torch.fx.GraphModule, graph_signature
):
"""
Helper function for replace_set_grad_with_hop_pass().
Split the graph module into multiple subgraphs based on the set_grad_enabled nodes.
@@ -141,35 +144,40 @@ def _sequential_split_and_maybe_inline_subgraphs(gm: torch.fx.GraphModule):
if need_replacing:
new_gm = sequential_split(gm, _is_set_grad_enabled_node)

def _maybe_inline_or_replace_with_hop(node: torch.fx.Node):
if _is_set_grad_enabled_sub_mod(node, omit_if_same_with_ambient=True):
_replace_with_hop(node)
else:
_remove_set_grad_and_inline(node)

nodes_map(
list(new_gm.graph.nodes),
lambda node: (
_maybe_inline_or_replace_with_hop(node)
if node.op == "call_module"
else node
),
)
replace_ctx = contextlib.nullcontext()
if graph_signature is not None:
replace_ctx = new_gm._set_replace_hook(graph_signature.get_replace_hook()) # type: ignore[assignment]

with replace_ctx:

def _maybe_inline_or_replace_with_hop(node: torch.fx.Node):
if _is_set_grad_enabled_sub_mod(node, omit_if_same_with_ambient=True):
_replace_with_hop(node)
else:
_remove_set_grad_and_inline(node)
Comment on lines +153 to +157 — Contributor:

we can define this function out of the context manager.
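For illustration only (not part of the PR), a rough sketch of that suggestion, reusing the helpers already defined in this module (`_is_set_grad_enabled_sub_mod`, `_replace_with_hop`, `_remove_set_grad_and_inline`, `nodes_map`): the inner function is hoisted above the `with` block so that only the graph traversal runs under the replace hook.

```python
# Sketch of the reviewer's suggestion; names mirror the diff above.
def _maybe_inline_or_replace_with_hop(node: torch.fx.Node):
    # Replace set_grad submodules with the HOP, otherwise inline them back.
    if _is_set_grad_enabled_sub_mod(node, omit_if_same_with_ambient=True):
        _replace_with_hop(node)
    else:
        _remove_set_grad_and_inline(node)


replace_ctx = contextlib.nullcontext()
if graph_signature is not None:
    # Keep the graph signature in sync while nodes are replaced.
    replace_ctx = new_gm._set_replace_hook(graph_signature.get_replace_hook())

with replace_ctx:
    # Only the traversal itself needs to run under the replace hook.
    nodes_map(
        list(new_gm.graph.nodes),
        lambda node: (
            _maybe_inline_or_replace_with_hop(node)
            if node.op == "call_module"
            else node
        ),
    )
```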


nodes_map(
list(new_gm.graph.nodes),
lambda node: (
_maybe_inline_or_replace_with_hop(node)
if node.op == "call_module"
else node
),
)
return new_gm

return gm


def replace_set_grad_with_hop_pass(gm: torch.fx.GraphModule):
new_gm = _sequential_split_and_maybe_inline_subgraphs(gm)

def replace_set_grad_with_hop_pass(gm: torch.fx.GraphModule, graph_signature):
ydwu4 (Contributor) — Apr 30, 2024:

nit: maybe a type annotation?

Contributor:

yeah adding a type annotation here is better.
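A possible annotated signature, sketched under the assumption that the signature object is an `ExportGraphSignature` from `torch.export.graph_signature`; the parameter has to be optional because the recursive call below passes `None`:

```python
from typing import Optional

import torch.fx
from torch.export.graph_signature import ExportGraphSignature


def replace_set_grad_with_hop_pass(
    gm: torch.fx.GraphModule,
    graph_signature: Optional[ExportGraphSignature],
) -> torch.fx.GraphModule:
    # Body as in the diff; only the annotations are sketched here.
    ...
```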

new_gm = _sequential_split_and_maybe_inline_subgraphs(gm, graph_signature)
# recursively call
for node in new_gm.graph.nodes:
if node.op == "get_attr":
subgm = getattr(new_gm, node.target)
if not isinstance(subgm, torch.fx.GraphModule):
continue
new_subgm = replace_set_grad_with_hop_pass(subgm)
new_subgm = replace_set_grad_with_hop_pass(subgm, None)
setattr(new_gm, node.target, new_subgm)

new_gm.recompile()
70 changes: 35 additions & 35 deletions torch/export/_trace.py
@@ -508,41 +508,6 @@ def _compiling_state_context():
if isinstance(mod, torch.fx.GraphModule) and hasattr(mod, "meta"):
gm.meta.update(mod.meta)

if pre_dispatch:
from torch._export.passes.replace_set_grad_with_hop_pass import (
replace_set_grad_with_hop_pass,
)

gm = replace_set_grad_with_hop_pass(gm)

# Remove nn_module_stack, stack_trace metadata from all placeholders/inputs nodes.
for _mod in gm.modules():
if not isinstance(_mod, torch.fx.GraphModule):
continue
for node in _mod.graph.nodes:
if node.op in ["placeholder", "output"]:
node.meta.pop("nn_module_stack", None)
node.meta.pop("stack_trace", None)

# NOTE: aot_export adds symint metadata for placeholders with int values;
# since these become specialized, we replace such metadata with the original values
flat_args = pytree.tree_leaves((fake_args, fake_kwargs))
index = 0
total_non_user_inputs = (
len(graph_signature.parameters)
+ len(graph_signature.buffers)
+ len(graph_signature.input_tokens)
)
for node in gm.graph.nodes:
if node.op == "placeholder":
if index >= total_non_user_inputs:
user_arg = flat_args[index - total_non_user_inputs]
if not isinstance(user_arg, torch.Tensor):
node.meta["val"] = user_arg
index += 1

is_joint = graph_signature.backward_signature is not None

def make_argument_spec(i, node) -> ArgumentSpec:
if isinstance(node, (int, bool, float, type(None))):
# For const outputs we just directly return this
@@ -571,6 +536,25 @@ def make_argument_spec(i, node) -> ArgumentSpec:
f"while writing the metadata for exported program"
)

is_joint = graph_signature.backward_signature is not None

# NOTE: aot_export adds symint metadata for placeholders with int values;
# since these become specialized, we replace such metadata with the original values
flat_args = pytree.tree_leaves((fake_args, fake_kwargs))
index = 0
total_non_user_inputs = (
len(graph_signature.parameters)
+ len(graph_signature.buffers)
+ len(graph_signature.input_tokens)
)
for node in gm.graph.nodes:
if node.op == "placeholder":
if index >= total_non_user_inputs:
user_arg = flat_args[index - total_non_user_inputs]
if not isinstance(user_arg, torch.Tensor):
node.meta["val"] = user_arg
index += 1

input_specs, output_specs = _sig_to_specs(
user_inputs=set(graph_signature.user_inputs),
inputs_to_parameters=graph_signature.inputs_to_parameters, # type: ignore[arg-type]
@@ -599,6 +583,22 @@ def make_argument_spec(i, node) -> ArgumentSpec:
input_specs=input_specs, output_specs=output_specs
)

if pre_dispatch:
from torch._export.passes.replace_set_grad_with_hop_pass import (
replace_set_grad_with_hop_pass,
)

gm = replace_set_grad_with_hop_pass(gm, export_graph_signature)

# Remove nn_module_stack, stack_trace metadata from all placeholders/inputs nodes.
for _mod in gm.modules():
if not isinstance(_mod, torch.fx.GraphModule):
continue
for node in _mod.graph.nodes:
if node.op in ["placeholder", "output"]:
node.meta.pop("nn_module_stack", None)
node.meta.pop("stack_trace", None)

constants = rewrite_script_object_meta(gm)
constants.update(lift_constants_pass(gm, export_graph_signature, constant_attrs))

14 changes: 0 additions & 14 deletions torch/export/exported_program.py
@@ -663,20 +663,6 @@ def update_arg(old_arg, new_ph):

_replace_sym_size_ops_pass(gm)

if len(new_range_constraints) > 0:
stack_trace = (
'File "torch/_export/passes/add_runtime_assertions_for_constraints_pass.py", line 46, '
"in _AddRuntimeAssertionsForInlineConstraintsPass"
)
with gm._set_create_node_hook(
functools.partial(_node_metadata_hook, stack_trace=stack_trace)
):
res = _AddRuntimeAssertionsForInlineConstraintsPass(
new_range_constraints
)(gm)
assert res is not None
gm = res.graph_module

exported_program = ExportedProgram(
root=gm,
graph=gm.graph,