Skip to content

Commit

Permalink
Fix split module interaction with dead code
Browse files Browse the repository at this point in the history
Summary:
This change fixes split_module's interaction with dead code. Previously if a dead region was split out, split module would throw an error while attempting to access the outputs for the partition even though the partition has no outputs.

This change adds a new unit test to cover the dead code case and changes the output check to allow no output. The split module with no output will now output None like a normal python function

Unit Test Added:
test_split_module_dead_code

A module with dead code:
```
class ModWithDeadCode(torch.nn.Module):
            def forward(self, x):
                output = x * 2 # we want this
                dead_line = x + 2 # this is dead
                return output
```

Before:
```
torch/fx/passes/split_module.py, line 357, in split_module
base_mod_env[list(partition.outputs)[0]] = output_val
IndexError: list index out of range
```

After:
```
class GraphModule(torch.nn.Module):
    def forward(self, x):
        # No stacktrace found for following nodes
        submod_2 = self.submod_2(x)
        submod_1 = self.submod_1(x);  x = None
        return submod_1

    class GraphModule(torch.nn.Module):
        def forward(self, x):
            # No stacktrace found for following nodes
            add = x + 2;  x = None
            return None

    class GraphModule(torch.nn.Module):
        def forward(self, x):
            # No stacktrace found for following nodes
            mul = x * 2;  x = None
            return mul
```
Submod 2 is correctly extracted

Test Plan: Tested with new unit test

Differential Revision: D47196732

fbshipit-source-id: 82d7627ffc0b38b4c4a5d68dcfce4816d96ffddd
  • Loading branch information
benghaem authored and facebook-github-bot committed Jul 3, 2023
1 parent d7b5cd7 commit 9ca440f
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 3 deletions.
31 changes: 31 additions & 0 deletions test/test_fx_experimental.py
Original file line number Diff line number Diff line change
Expand Up @@ -785,6 +785,37 @@ def mod_partition(node: Node):

self.assertEqual(orig_out, submodules_out)

def test_split_module_dead_code(self):
class ModWithDeadCode(torch.nn.Module):
def forward(self, x):
output = x * 2 # we want this
dead_line = x + 2 # this is dead
return output

mod = ModWithDeadCode()
traced = torch.fx.symbolic_trace(mod)

# split into before (0), target (1), and after(2)
saw_mul = False
def split_callback(n):
nonlocal saw_mul
if n.target == operator.mul:
saw_mul = True
return 1

if not saw_mul:
return 0
if saw_mul:
return 2

split = split_module(traced, mod, split_callback)

x = torch.randn((5,))
torch.testing.assert_close(
split(x), traced(x)
)


def test_split_module_kwargs_expansion(self):
class ModuleWithKwargsExpansion(torch.nn.Module):
def forward(self, x, **kwargs):
Expand Down
12 changes: 9 additions & 3 deletions torch/fx/passes/split_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,11 @@ def record_cross_partition_use(
output_vals = tuple(
partition.environment[orig_nodes[name]] for name in partition.outputs
)
output_vals = output_vals[0] if len(output_vals) == 1 else output_vals # type: ignore[assignment]
num_output_vals = len(output_vals)
if num_output_vals == 0:
output_vals = None
elif num_output_vals == 1:
output_vals = output_vals[0]
partition.graph.output(output_vals)

if keep_original_order:
Expand All @@ -346,12 +350,14 @@ def record_cross_partition_use(
partition.submod_name,
tuple(base_mod_env[name] for name in partition.inputs),
)
if len(partition.outputs) > 1:

num_outputs = len(partition.outputs)
if num_outputs > 1:
# Unpack multiple return values from submodule
output_val_proxy = torch.fx.proxy.Proxy(output_val)
for i, output_name in enumerate(partition.outputs):
base_mod_env[output_name] = output_val_proxy[i].node # type: ignore[index]
else:
elif num_outputs == 1:
base_mod_env[list(partition.outputs)[0]] = output_val

for node in m.graph.nodes:
Expand Down

0 comments on commit 9ca440f

Please sign in to comment.