Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix split module interaction with dead code #104554

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
32 changes: 32 additions & 0 deletions test/test_fx_experimental.py
Original file line number Diff line number Diff line change
Expand Up @@ -785,6 +785,38 @@ def mod_partition(node: Node):

self.assertEqual(orig_out, submodules_out)

def test_split_module_dead_code(self):
class ModWithDeadCode(torch.nn.Module):
def forward(self, x):
output = x * 2 # we want this
dead_line = x + 2 # this is dead
return output

mod = ModWithDeadCode()
traced = torch.fx.symbolic_trace(mod)

# split into before (0), target (1), and after(2)
saw_mul = False

def split_callback(n):
nonlocal saw_mul
if n.target == operator.mul:
saw_mul = True
return 1

if not saw_mul:
return 0
if saw_mul:
return 2

split = split_module(traced, mod, split_callback)

x = torch.randn((5,))
torch.testing.assert_close(
split(x), traced(x)
)


def test_split_module_kwargs_expansion(self):
class ModuleWithKwargsExpansion(torch.nn.Module):
def forward(self, x, **kwargs):
Expand Down
15 changes: 11 additions & 4 deletions torch/fx/passes/split_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,8 +320,13 @@ def record_cross_partition_use(
output_vals = tuple(
partition.environment[orig_nodes[name]] for name in partition.outputs
)
output_vals = output_vals[0] if len(output_vals) == 1 else output_vals # type: ignore[assignment]
partition.graph.output(output_vals)

# skip output node generation if there are no output values
num_output_vals = len(output_vals)
if num_output_vals == 1:
partition.graph.output(output_vals[0])
elif num_output_vals > 1:
partition.graph.output(output_vals)

if keep_original_order:
# first get the attr nodes required by this partition
Expand All @@ -346,12 +351,14 @@ def record_cross_partition_use(
partition.submod_name,
tuple(base_mod_env[name] for name in partition.inputs),
)
if len(partition.outputs) > 1:

num_outputs = len(partition.outputs)
if num_outputs > 1:
# Unpack multiple return values from submodule
output_val_proxy = torch.fx.proxy.Proxy(output_val)
for i, output_name in enumerate(partition.outputs):
base_mod_env[output_name] = output_val_proxy[i].node # type: ignore[index]
else:
elif num_outputs == 1:
base_mod_env[list(partition.outputs)[0]] = output_val

for node in m.graph.nodes:
Expand Down