New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix split module interaction with dead code #104554
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/104554
Note: Links to docs will display an error until the docs builds have been completed. ✅ No FailuresAs of commit 02c3a48: This comment was automatically generated by Dr. CI and updates every 15 minutes. |
|
This pull request was exported from Phabricator. Differential Revision: D47196732 |
1 similar comment
This pull request was exported from Phabricator. Differential Revision: D47196732 |
9ca440f
to
3b9f2df
Compare
Summary: Pull Request resolved: pytorch#104554 This change fixes split_module's interaction with dead code. Previously if a dead region was split out, split module would throw an error while attempting to access the outputs for the partition even though the partition has no outputs. This change adds a new unit test to cover the dead code case and changes the output check to allow no output. The split module with no output will now output None like a normal python function Unit Test Added: test_split_module_dead_code A module with dead code: ``` class ModWithDeadCode(torch.nn.Module): def forward(self, x): output = x * 2 # we want this dead_line = x + 2 # this is dead return output ``` Before: ``` torch/fx/passes/split_module.py, line 357, in split_module base_mod_env[list(partition.outputs)[0]] = output_val IndexError: list index out of range ``` After: ``` class GraphModule(torch.nn.Module): def forward(self, x): # No stacktrace found for following nodes submod_2 = self.submod_2(x) submod_1 = self.submod_1(x); x = None return submod_1 class GraphModule(torch.nn.Module): def forward(self, x): # No stacktrace found for following nodes add = x + 2; x = None return None class GraphModule(torch.nn.Module): def forward(self, x): # No stacktrace found for following nodes mul = x * 2; x = None return mul ``` Submod 2 is correctly extracted Test Plan: Tested with new unit test Reviewed By: mustafaozdal Differential Revision: D47196732 fbshipit-source-id: 2d9d22f0432fb762ef024a44c0ddbf16804928ec
This pull request was exported from Phabricator. Differential Revision: D47196732 |
3b9f2df
to
5b73e9c
Compare
Summary: Pull Request resolved: pytorch#104554 This change fixes split_module's interaction with dead code. Previously if a dead region was split out, split module would throw an error while attempting to access the outputs for the partition even though the partition has no outputs. This change adds a new unit test to cover the dead code case and changes the output check to allow no output. The split module with no output will now output None like a normal python function Unit Test Added: test_split_module_dead_code A module with dead code: ``` class ModWithDeadCode(torch.nn.Module): def forward(self, x): output = x * 2 # we want this dead_line = x + 2 # this is dead return output ``` Before: ``` torch/fx/passes/split_module.py, line 357, in split_module base_mod_env[list(partition.outputs)[0]] = output_val IndexError: list index out of range ``` After: ``` class GraphModule(torch.nn.Module): def forward(self, x): # No stacktrace found for following nodes submod_2 = self.submod_2(x) submod_1 = self.submod_1(x); x = None return submod_1 class GraphModule(torch.nn.Module): def forward(self, x): # No stacktrace found for following nodes add = x + 2; x = None return None class GraphModule(torch.nn.Module): def forward(self, x): # No stacktrace found for following nodes mul = x * 2; x = None return mul ``` Submod 2 is correctly extracted Test Plan: Tested with new unit test Reviewed By: mustafaozdal Differential Revision: D47196732 fbshipit-source-id: b9b06ee8b2eaf0cc236e0f5a00e17d42f9971499
5b73e9c
to
995ba97
Compare
Summary: Pull Request resolved: pytorch#104554 This change fixes split_module's interaction with dead code. Previously if a dead region was split out, split module would throw an error while attempting to access the outputs for the partition even though the partition has no outputs. This change adds a new unit test to cover the dead code case and changes the output check to allow no output. The split module with no output will now output None like a normal python function Unit Test Added: test_split_module_dead_code A module with dead code: ``` class ModWithDeadCode(torch.nn.Module): def forward(self, x): output = x * 2 # we want this dead_line = x + 2 # this is dead return output ``` Before: ``` torch/fx/passes/split_module.py, line 357, in split_module base_mod_env[list(partition.outputs)[0]] = output_val IndexError: list index out of range ``` After: ``` class GraphModule(torch.nn.Module): def forward(self, x): # No stacktrace found for following nodes submod_2 = self.submod_2(x) submod_1 = self.submod_1(x); x = None return submod_1 class GraphModule(torch.nn.Module): def forward(self, x): # No stacktrace found for following nodes add = x + 2; x = None return None class GraphModule(torch.nn.Module): def forward(self, x): # No stacktrace found for following nodes mul = x * 2; x = None return mul ``` Submod 2 is correctly extracted Test Plan: Tested with new unit test Reviewed By: mustafaozdal Differential Revision: D47196732 fbshipit-source-id: 228e56d9378978bc3463069ca8ccd2a9bc31df58
This pull request was exported from Phabricator. Differential Revision: D47196732 |
Summary: Pull Request resolved: pytorch#104554 This change fixes split_module's interaction with dead code. Previously if a dead region was split out, split module would throw an error while attempting to access the outputs for the partition even though the partition has no outputs. This change adds a new unit test to cover the dead code case and changes the output check to allow no output. The split module with no output will now output None like a normal python function Unit Test Added: test_split_module_dead_code A module with dead code: ``` class ModWithDeadCode(torch.nn.Module): def forward(self, x): output = x * 2 # we want this dead_line = x + 2 # this is dead return output ``` Before: ``` torch/fx/passes/split_module.py, line 357, in split_module base_mod_env[list(partition.outputs)[0]] = output_val IndexError: list index out of range ``` After: ``` class GraphModule(torch.nn.Module): def forward(self, x): # No stacktrace found for following nodes submod_2 = self.submod_2(x) submod_1 = self.submod_1(x); x = None return submod_1 class GraphModule(torch.nn.Module): def forward(self, x): # No stacktrace found for following nodes add = x + 2; x = None return None class GraphModule(torch.nn.Module): def forward(self, x): # No stacktrace found for following nodes mul = x * 2; x = None return mul ``` Submod 2 is correctly extracted Test Plan: Tested with new unit test Reviewed By: mustafaozdal Differential Revision: D47196732 fbshipit-source-id: e15ebbfbdaa3d9a9f9b43f8038e9c2f57063e687
995ba97
to
2a29363
Compare
This pull request was exported from Phabricator. Differential Revision: D47196732 |
Summary: Pull Request resolved: pytorch#104554 This change fixes split_module's interaction with dead code. Previously if a dead region was split out, split module would throw an error while attempting to access the outputs for the partition even though the partition has no outputs. This change adds a new unit test to cover the dead code case and changes the output check to allow no output. The split module with no output will now output None like a normal python function Unit Test Added: test_split_module_dead_code A module with dead code: ``` class ModWithDeadCode(torch.nn.Module): def forward(self, x): output = x * 2 # we want this dead_line = x + 2 # this is dead return output ``` Before: ``` torch/fx/passes/split_module.py, line 357, in split_module base_mod_env[list(partition.outputs)[0]] = output_val IndexError: list index out of range ``` After: ``` class GraphModule(torch.nn.Module): def forward(self, x): # No stacktrace found for following nodes submod_2 = self.submod_2(x) submod_1 = self.submod_1(x); x = None return submod_1 class GraphModule(torch.nn.Module): def forward(self, x): # No stacktrace found for following nodes add = x + 2; x = None return None class GraphModule(torch.nn.Module): def forward(self, x): # No stacktrace found for following nodes mul = x * 2; x = None return mul ``` Submod 2 is correctly extracted Test Plan: Tested with new unit test Reviewed By: mustafaozdal Differential Revision: D47196732 fbshipit-source-id: fecfba49f8c89b3cdf6007add53119adcaf0cf93
2a29363
to
bf8174e
Compare
This pull request was exported from Phabricator. Differential Revision: D47196732 |
bf8174e
to
65c1a9e
Compare
Summary: Pull Request resolved: pytorch#104554 This change fixes split_module's interaction with dead code. Previously if a dead region was split out, split module would throw an error while attempting to access the outputs for the partition even though the partition has no outputs. This change adds a new unit test to cover the dead code case and changes the output check to allow no output. The split module with no output will now output None like a normal python function Unit Test Added: test_split_module_dead_code A module with dead code: ``` class ModWithDeadCode(torch.nn.Module): def forward(self, x): output = x * 2 # we want this dead_line = x + 2 # this is dead return output ``` Before: ``` torch/fx/passes/split_module.py, line 357, in split_module base_mod_env[list(partition.outputs)[0]] = output_val IndexError: list index out of range ``` After: ``` class GraphModule(torch.nn.Module): def forward(self, x): # No stacktrace found for following nodes submod_2 = self.submod_2(x) submod_1 = self.submod_1(x); x = None return submod_1 class GraphModule(torch.nn.Module): def forward(self, x): # No stacktrace found for following nodes add = x + 2; x = None return None class GraphModule(torch.nn.Module): def forward(self, x): # No stacktrace found for following nodes mul = x * 2; x = None return mul ``` Submod 2 is correctly extracted Test Plan: Tested with new unit test Reviewed By: mustafaozdal Differential Revision: D47196732 fbshipit-source-id: 6bbed6778ea24fd0cc2f1127fd4d5762c8d3c30b
65c1a9e
to
dc98de2
Compare
This pull request was exported from Phabricator. Differential Revision: D47196732 |
Summary: Pull Request resolved: pytorch#104554 This change fixes split_module's interaction with dead code. Previously if a dead region was split out, split module would throw an error while attempting to access the outputs for the partition even though the partition has no outputs. This change adds a new unit test to cover the dead code case and changes the output check to allow no output. The split module with no output will now output None like a normal python function Unit Test Added: test_split_module_dead_code A module with dead code: ``` class ModWithDeadCode(torch.nn.Module): def forward(self, x): output = x * 2 # we want this dead_line = x + 2 # this is dead return output ``` Before: ``` torch/fx/passes/split_module.py, line 357, in split_module base_mod_env[list(partition.outputs)[0]] = output_val IndexError: list index out of range ``` After: ``` class GraphModule(torch.nn.Module): def forward(self, x): # No stacktrace found for following nodes submod_2 = self.submod_2(x) submod_1 = self.submod_1(x); x = None return submod_1 class GraphModule(torch.nn.Module): def forward(self, x): # No stacktrace found for following nodes add = x + 2; x = None return None class GraphModule(torch.nn.Module): def forward(self, x): # No stacktrace found for following nodes mul = x * 2; x = None return mul ``` Submod 2 is correctly extracted Test Plan: Tested with new unit test Reviewed By: mustafaozdal Differential Revision: D47196732 fbshipit-source-id: 86a6ab33d292735d07b0efd25dc93e0763f4b552
Summary: Pull Request resolved: pytorch#104554 This change fixes split_module's interaction with dead code. Previously if a dead region was split out, split module would throw an error while attempting to access the outputs for the partition even though the partition has no outputs. This change adds a new unit test to cover the dead code case and changes the output check to allow no output. The split module with no output will now output None like a normal python function Unit Test Added: test_split_module_dead_code A module with dead code: ``` class ModWithDeadCode(torch.nn.Module): def forward(self, x): output = x * 2 # we want this dead_line = x + 2 # this is dead return output ``` Before: ``` torch/fx/passes/split_module.py, line 357, in split_module base_mod_env[list(partition.outputs)[0]] = output_val IndexError: list index out of range ``` After: ``` class GraphModule(torch.nn.Module): def forward(self, x): # No stacktrace found for following nodes submod_2 = self.submod_2(x) submod_1 = self.submod_1(x); x = None return submod_1 class GraphModule(torch.nn.Module): def forward(self, x): # No stacktrace found for following nodes add = x + 2; x = None return None class GraphModule(torch.nn.Module): def forward(self, x): # No stacktrace found for following nodes mul = x * 2; x = None return mul ``` Submod 2 is correctly extracted Test Plan: Tested with new unit test Reviewed By: mustafaozdal Differential Revision: D47196732 fbshipit-source-id: 0e9c852e8fe5ac5fa6597f3fa81bde71c56bae82
dc98de2
to
54b1f11
Compare
This pull request was exported from Phabricator. Differential Revision: D47196732 |
1 similar comment
This pull request was exported from Phabricator. Differential Revision: D47196732 |
54b1f11
to
89ff0e6
Compare
Summary: Pull Request resolved: pytorch#104554 This change fixes split_module's interaction with dead code. Previously if a dead region was split out, split module would throw an error while attempting to access the outputs for the partition even though the partition has no outputs. This change adds a new unit test to cover the dead code case and changes the output check to allow no output. The split module with no output will now output None like a normal python function Unit Test Added: test_split_module_dead_code A module with dead code: ``` class ModWithDeadCode(torch.nn.Module): def forward(self, x): output = x * 2 # we want this dead_line = x + 2 # this is dead return output ``` Before: ``` torch/fx/passes/split_module.py, line 357, in split_module base_mod_env[list(partition.outputs)[0]] = output_val IndexError: list index out of range ``` After: ``` class GraphModule(torch.nn.Module): def forward(self, x): # No stacktrace found for following nodes submod_2 = self.submod_2(x) submod_1 = self.submod_1(x); x = None return submod_1 class GraphModule(torch.nn.Module): def forward(self, x): # No stacktrace found for following nodes add = x + 2; x = None return None class GraphModule(torch.nn.Module): def forward(self, x): # No stacktrace found for following nodes mul = x * 2; x = None return mul ``` Submod 2 is correctly extracted Test Plan: Tested with new unit test Reviewed By: mustafaozdal Differential Revision: D47196732 fbshipit-source-id: 81bb570788a4baf6ba8953688d663dfc8fdd8d4d
This pull request was exported from Phabricator. Differential Revision: D47196732 |
89ff0e6
to
daf9a51
Compare
Summary: Pull Request resolved: pytorch#104554 This change fixes split_module's interaction with dead code. Previously if a dead region was split out, split module would throw an error while attempting to access the outputs for the partition even though the partition has no outputs. This change adds a new unit test to cover the dead code case and changes the output check to allow no output. The split module with no output will now output None like a normal python function Unit Test Added: test_split_module_dead_code A module with dead code: ``` class ModWithDeadCode(torch.nn.Module): def forward(self, x): output = x * 2 # we want this dead_line = x + 2 # this is dead return output ``` Before: ``` torch/fx/passes/split_module.py, line 357, in split_module base_mod_env[list(partition.outputs)[0]] = output_val IndexError: list index out of range ``` After: ``` class GraphModule(torch.nn.Module): def forward(self, x): # No stacktrace found for following nodes submod_2 = self.submod_2(x) submod_1 = self.submod_1(x); x = None return submod_1 class GraphModule(torch.nn.Module): def forward(self, x): # No stacktrace found for following nodes add = x + 2; x = None return None class GraphModule(torch.nn.Module): def forward(self, x): # No stacktrace found for following nodes mul = x * 2; x = None return mul ``` Submod 2 is correctly extracted Test Plan: Tested with new unit test Reviewed By: mustafaozdal Differential Revision: D47196732 fbshipit-source-id: fd39dc68e612e78cd3945a313154fdbf32bb7383
Summary: Pull Request resolved: pytorch#104554 This change fixes split_module's interaction with dead code. Previously if a dead region was split out, split module would throw an error while attempting to access the outputs for the partition even though the partition has no outputs. This change adds a new unit test to cover the dead code case and changes the output check to allow no output. The split module with no output will now output None like a normal python function Unit Test Added: test_split_module_dead_code A module with dead code: ``` class ModWithDeadCode(torch.nn.Module): def forward(self, x): output = x * 2 # we want this dead_line = x + 2 # this is dead return output ``` Before: ``` torch/fx/passes/split_module.py, line 357, in split_module base_mod_env[list(partition.outputs)[0]] = output_val IndexError: list index out of range ``` After: ``` class GraphModule(torch.nn.Module): def forward(self, x): # No stacktrace found for following nodes submod_2 = self.submod_2(x) submod_1 = self.submod_1(x); x = None return submod_1 class GraphModule(torch.nn.Module): def forward(self, x): # No stacktrace found for following nodes add = x + 2; x = None return None class GraphModule(torch.nn.Module): def forward(self, x): # No stacktrace found for following nodes mul = x * 2; x = None return mul ``` Submod 2 is correctly extracted Test Plan: Tested with new unit test Reviewed By: mustafaozdal Differential Revision: D47196732 fbshipit-source-id: 0503440f9941edf577689b08164cbd1a2b1881d6
daf9a51
to
c826fc3
Compare
This pull request was exported from Phabricator. Differential Revision: D47196732 |
c826fc3
to
9846615
Compare
Summary: Pull Request resolved: pytorch#104554 This change fixes split_module's interaction with dead code. Previously if a dead region was split out, split module would throw an error while attempting to access the outputs for the partition even though the partition has no outputs. This change adds a new unit test to cover the dead code case and changes the output check to allow no output. The split module with no output will now output None like a normal python function Unit Test Added: test_split_module_dead_code A module with dead code: ``` class ModWithDeadCode(torch.nn.Module): def forward(self, x): output = x * 2 # we want this dead_line = x + 2 # this is dead return output ``` Before: ``` torch/fx/passes/split_module.py, line 357, in split_module base_mod_env[list(partition.outputs)[0]] = output_val IndexError: list index out of range ``` After: ``` class GraphModule(torch.nn.Module): def forward(self, x): # No stacktrace found for following nodes submod_2 = self.submod_2(x) submod_1 = self.submod_1(x); x = None return submod_1 class GraphModule(torch.nn.Module): def forward(self, x): # No stacktrace found for following nodes add = x + 2; x = None return None class GraphModule(torch.nn.Module): def forward(self, x): # No stacktrace found for following nodes mul = x * 2; x = None return mul ``` Submod 2 is correctly extracted Test Plan: Tested with new unit test Reviewed By: mustafaozdal, RenfeiChen-FB Differential Revision: D47196732 fbshipit-source-id: e03013ef6a09c09981a2702c584a97909ea2024b
This pull request was exported from Phabricator. Differential Revision: D47196732 |
Summary: Pull Request resolved: pytorch#104554 This change fixes split_module's interaction with dead code. Previously if a dead region was split out, split module would throw an error while attempting to access the outputs for the partition even though the partition has no outputs. This change adds a new unit test to cover the dead code case and changes the output check to allow no output. The split module with no output will now output None like a normal python function Unit Test Added: test_split_module_dead_code A module with dead code: ``` class ModWithDeadCode(torch.nn.Module): def forward(self, x): output = x * 2 # we want this dead_line = x + 2 # this is dead return output ``` Before: ``` torch/fx/passes/split_module.py, line 357, in split_module base_mod_env[list(partition.outputs)[0]] = output_val IndexError: list index out of range ``` After: ``` class GraphModule(torch.nn.Module): def forward(self, x): # No stacktrace found for following nodes submod_2 = self.submod_2(x) submod_1 = self.submod_1(x); x = None return submod_1 class GraphModule(torch.nn.Module): def forward(self, x): # No stacktrace found for following nodes add = x + 2; x = None return None class GraphModule(torch.nn.Module): def forward(self, x): # No stacktrace found for following nodes mul = x * 2; x = None return mul ``` Submod 2 is correctly extracted Test Plan: Tested with new unit test Reviewed By: mustafaozdal, RenfeiChen-FB Differential Revision: D47196732 fbshipit-source-id: 5c937491d2db8e9c8de9f1e00b9fe5baae920528
9846615
to
02c3a48
Compare
This pull request was exported from Phabricator. Differential Revision: D47196732 |
@pytorchbot merge (Initiating merge automatically since Phabricator Diff has merged) |
Merge startedYour change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
Summary:
This change fixes split_module's interaction with dead code. Previously if a dead region was split out, split module would throw an error while attempting to access the outputs for the partition even though the partition has no outputs.
This change adds a new unit test to cover the dead code case and changes the output check to allow no output. The split module with no output will now output None like a normal python function
Unit Test Added:
test_split_module_dead_code
A module with dead code:
Before:
After:
Submod 2 is correctly extracted
Test Plan: Tested with new unit test
Differential Revision: D47196732