Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix split module interaction with dead code #104554

Closed
wants to merge 1 commit into from

Commits on Aug 3, 2023

  1. Fix split module interaction with dead code (pytorch#104554)

    Summary:
    Pull Request resolved: pytorch#104554
    
    This change fixes split_module's interaction with dead code. Previously if a dead region was split out, split module would throw an error while attempting to access the outputs for the partition even though the partition has no outputs.
    
    This change adds a new unit test to cover the dead code case and changes the output check to allow no output. The split module with no output will now output None like a normal python function
    
    Unit Test Added:
    test_split_module_dead_code
    
    A module with dead code:
    ```
    class ModWithDeadCode(torch.nn.Module):
                def forward(self, x):
                    output = x * 2 # we want this
                    dead_line = x + 2 # this is dead
                    return output
    ```
    
    Before:
    ```
    torch/fx/passes/split_module.py, line 357, in split_module
    base_mod_env[list(partition.outputs)[0]] = output_val
    IndexError: list index out of range
    ```
    
    After:
    ```
    class GraphModule(torch.nn.Module):
        def forward(self, x):
            # No stacktrace found for following nodes
            submod_2 = self.submod_2(x)
            submod_1 = self.submod_1(x);  x = None
            return submod_1
    
        class GraphModule(torch.nn.Module):
            def forward(self, x):
                # No stacktrace found for following nodes
                add = x + 2;  x = None
                return None
    
        class GraphModule(torch.nn.Module):
            def forward(self, x):
                # No stacktrace found for following nodes
                mul = x * 2;  x = None
                return mul
    ```
    Submod 2 is correctly extracted
    
    Test Plan: Tested with new unit test
    
    Reviewed By: mustafaozdal, RenfeiChen-FB
    
    Differential Revision: D47196732
    
    fbshipit-source-id: 5c937491d2db8e9c8de9f1e00b9fe5baae920528
    benghaem authored and facebook-github-bot committed Aug 3, 2023
    Configuration menu
    Copy the full SHA
    02c3a48 View commit details
    Browse the repository at this point in the history