Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix split module interaction with dead code #104554

Closed
wants to merge 1 commit into from

Conversation

benghaem
Copy link
Contributor

@benghaem benghaem commented Jul 3, 2023

Summary:
This change fixes split_module's interaction with dead code. Previously if a dead region was split out, split module would throw an error while attempting to access the outputs for the partition even though the partition has no outputs.

This change adds a new unit test to cover the dead code case and changes the output check to allow no output. The split module with no output will now output None like a normal python function

Unit Test Added:
test_split_module_dead_code

A module with dead code:

class ModWithDeadCode(torch.nn.Module):
            def forward(self, x):
                output = x * 2 # we want this
                dead_line = x + 2 # this is dead
                return output

Before:

torch/fx/passes/split_module.py, line 357, in split_module
base_mod_env[list(partition.outputs)[0]] = output_val
IndexError: list index out of range

After:

class GraphModule(torch.nn.Module):
    def forward(self, x):
        # No stacktrace found for following nodes
        submod_2 = self.submod_2(x)
        submod_1 = self.submod_1(x);  x = None
        return submod_1

    class GraphModule(torch.nn.Module):
        def forward(self, x):
            # No stacktrace found for following nodes
            add = x + 2;  x = None
            return None

    class GraphModule(torch.nn.Module):
        def forward(self, x):
            # No stacktrace found for following nodes
            mul = x * 2;  x = None
            return mul

Submod 2 is correctly extracted

Test Plan: Tested with new unit test

Differential Revision: D47196732

@pytorch-bot
Copy link

pytorch-bot bot commented Jul 3, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/104554

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 02c3a48:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@linux-foundation-easycla
Copy link

linux-foundation-easycla bot commented Jul 3, 2023

CLA Signed

The committers listed above are authorized under a signed CLA.

  • ✅ login: benghaem / name: Benjamin Ghaemmaghami (02c3a48)

@pytorch-bot pytorch-bot bot added the release notes: fx release notes category label Jul 3, 2023
@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D47196732

1 similar comment
@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D47196732

benghaem added a commit to benghaem/pytorch that referenced this pull request Jul 4, 2023
Summary:
Pull Request resolved: pytorch#104554

This change fixes split_module's interaction with dead code. Previously if a dead region was split out, split module would throw an error while attempting to access the outputs for the partition even though the partition has no outputs.

This change adds a new unit test to cover the dead code case and changes the output check to allow no output. The split module with no output will now output None like a normal python function

Unit Test Added:
test_split_module_dead_code

A module with dead code:
```
class ModWithDeadCode(torch.nn.Module):
            def forward(self, x):
                output = x * 2 # we want this
                dead_line = x + 2 # this is dead
                return output
```

Before:
```
torch/fx/passes/split_module.py, line 357, in split_module
base_mod_env[list(partition.outputs)[0]] = output_val
IndexError: list index out of range
```

After:
```
class GraphModule(torch.nn.Module):
    def forward(self, x):
        # No stacktrace found for following nodes
        submod_2 = self.submod_2(x)
        submod_1 = self.submod_1(x);  x = None
        return submod_1

    class GraphModule(torch.nn.Module):
        def forward(self, x):
            # No stacktrace found for following nodes
            add = x + 2;  x = None
            return None

    class GraphModule(torch.nn.Module):
        def forward(self, x):
            # No stacktrace found for following nodes
            mul = x * 2;  x = None
            return mul
```
Submod 2 is correctly extracted

Test Plan: Tested with new unit test

Reviewed By: mustafaozdal

Differential Revision: D47196732

fbshipit-source-id: 2d9d22f0432fb762ef024a44c0ddbf16804928ec
@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D47196732

benghaem added a commit to benghaem/pytorch that referenced this pull request Jul 12, 2023
Summary:
Pull Request resolved: pytorch#104554

This change fixes split_module's interaction with dead code. Previously if a dead region was split out, split module would throw an error while attempting to access the outputs for the partition even though the partition has no outputs.

This change adds a new unit test to cover the dead code case and changes the output check to allow no output. The split module with no output will now output None like a normal python function

Unit Test Added:
test_split_module_dead_code

A module with dead code:
```
class ModWithDeadCode(torch.nn.Module):
            def forward(self, x):
                output = x * 2 # we want this
                dead_line = x + 2 # this is dead
                return output
```

Before:
```
torch/fx/passes/split_module.py, line 357, in split_module
base_mod_env[list(partition.outputs)[0]] = output_val
IndexError: list index out of range
```

After:
```
class GraphModule(torch.nn.Module):
    def forward(self, x):
        # No stacktrace found for following nodes
        submod_2 = self.submod_2(x)
        submod_1 = self.submod_1(x);  x = None
        return submod_1

    class GraphModule(torch.nn.Module):
        def forward(self, x):
            # No stacktrace found for following nodes
            add = x + 2;  x = None
            return None

    class GraphModule(torch.nn.Module):
        def forward(self, x):
            # No stacktrace found for following nodes
            mul = x * 2;  x = None
            return mul
```
Submod 2 is correctly extracted

Test Plan: Tested with new unit test

Reviewed By: mustafaozdal

Differential Revision: D47196732

fbshipit-source-id: b9b06ee8b2eaf0cc236e0f5a00e17d42f9971499
benghaem added a commit to benghaem/pytorch that referenced this pull request Jul 17, 2023
Summary:
Pull Request resolved: pytorch#104554

This change fixes split_module's interaction with dead code. Previously if a dead region was split out, split module would throw an error while attempting to access the outputs for the partition even though the partition has no outputs.

This change adds a new unit test to cover the dead code case and changes the output check to allow no output. The split module with no output will now output None like a normal python function

Unit Test Added:
test_split_module_dead_code

A module with dead code:
```
class ModWithDeadCode(torch.nn.Module):
            def forward(self, x):
                output = x * 2 # we want this
                dead_line = x + 2 # this is dead
                return output
```

Before:
```
torch/fx/passes/split_module.py, line 357, in split_module
base_mod_env[list(partition.outputs)[0]] = output_val
IndexError: list index out of range
```

After:
```
class GraphModule(torch.nn.Module):
    def forward(self, x):
        # No stacktrace found for following nodes
        submod_2 = self.submod_2(x)
        submod_1 = self.submod_1(x);  x = None
        return submod_1

    class GraphModule(torch.nn.Module):
        def forward(self, x):
            # No stacktrace found for following nodes
            add = x + 2;  x = None
            return None

    class GraphModule(torch.nn.Module):
        def forward(self, x):
            # No stacktrace found for following nodes
            mul = x * 2;  x = None
            return mul
```
Submod 2 is correctly extracted

Test Plan: Tested with new unit test

Reviewed By: mustafaozdal

Differential Revision: D47196732

fbshipit-source-id: 228e56d9378978bc3463069ca8ccd2a9bc31df58
@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D47196732

benghaem added a commit to benghaem/pytorch that referenced this pull request Jul 17, 2023
Summary:
Pull Request resolved: pytorch#104554

This change fixes split_module's interaction with dead code. Previously if a dead region was split out, split module would throw an error while attempting to access the outputs for the partition even though the partition has no outputs.

This change adds a new unit test to cover the dead code case and changes the output check to allow no output. The split module with no output will now output None like a normal python function

Unit Test Added:
test_split_module_dead_code

A module with dead code:
```
class ModWithDeadCode(torch.nn.Module):
            def forward(self, x):
                output = x * 2 # we want this
                dead_line = x + 2 # this is dead
                return output
```

Before:
```
torch/fx/passes/split_module.py, line 357, in split_module
base_mod_env[list(partition.outputs)[0]] = output_val
IndexError: list index out of range
```

After:
```
class GraphModule(torch.nn.Module):
    def forward(self, x):
        # No stacktrace found for following nodes
        submod_2 = self.submod_2(x)
        submod_1 = self.submod_1(x);  x = None
        return submod_1

    class GraphModule(torch.nn.Module):
        def forward(self, x):
            # No stacktrace found for following nodes
            add = x + 2;  x = None
            return None

    class GraphModule(torch.nn.Module):
        def forward(self, x):
            # No stacktrace found for following nodes
            mul = x * 2;  x = None
            return mul
```
Submod 2 is correctly extracted

Test Plan: Tested with new unit test

Reviewed By: mustafaozdal

Differential Revision: D47196732

fbshipit-source-id: e15ebbfbdaa3d9a9f9b43f8038e9c2f57063e687
@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D47196732

benghaem added a commit to benghaem/pytorch that referenced this pull request Jul 17, 2023
Summary:
Pull Request resolved: pytorch#104554

This change fixes split_module's interaction with dead code. Previously if a dead region was split out, split module would throw an error while attempting to access the outputs for the partition even though the partition has no outputs.

This change adds a new unit test to cover the dead code case and changes the output check to allow no output. The split module with no output will now output None like a normal python function

Unit Test Added:
test_split_module_dead_code

A module with dead code:
```
class ModWithDeadCode(torch.nn.Module):
            def forward(self, x):
                output = x * 2 # we want this
                dead_line = x + 2 # this is dead
                return output
```

Before:
```
torch/fx/passes/split_module.py, line 357, in split_module
base_mod_env[list(partition.outputs)[0]] = output_val
IndexError: list index out of range
```

After:
```
class GraphModule(torch.nn.Module):
    def forward(self, x):
        # No stacktrace found for following nodes
        submod_2 = self.submod_2(x)
        submod_1 = self.submod_1(x);  x = None
        return submod_1

    class GraphModule(torch.nn.Module):
        def forward(self, x):
            # No stacktrace found for following nodes
            add = x + 2;  x = None
            return None

    class GraphModule(torch.nn.Module):
        def forward(self, x):
            # No stacktrace found for following nodes
            mul = x * 2;  x = None
            return mul
```
Submod 2 is correctly extracted

Test Plan: Tested with new unit test

Reviewed By: mustafaozdal

Differential Revision: D47196732

fbshipit-source-id: fecfba49f8c89b3cdf6007add53119adcaf0cf93
@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D47196732

benghaem added a commit to benghaem/pytorch that referenced this pull request Jul 20, 2023
Summary:
Pull Request resolved: pytorch#104554

This change fixes split_module's interaction with dead code. Previously if a dead region was split out, split module would throw an error while attempting to access the outputs for the partition even though the partition has no outputs.

This change adds a new unit test to cover the dead code case and changes the output check to allow no output. The split module with no output will now output None like a normal python function

Unit Test Added:
test_split_module_dead_code

A module with dead code:
```
class ModWithDeadCode(torch.nn.Module):
            def forward(self, x):
                output = x * 2 # we want this
                dead_line = x + 2 # this is dead
                return output
```

Before:
```
torch/fx/passes/split_module.py, line 357, in split_module
base_mod_env[list(partition.outputs)[0]] = output_val
IndexError: list index out of range
```

After:
```
class GraphModule(torch.nn.Module):
    def forward(self, x):
        # No stacktrace found for following nodes
        submod_2 = self.submod_2(x)
        submod_1 = self.submod_1(x);  x = None
        return submod_1

    class GraphModule(torch.nn.Module):
        def forward(self, x):
            # No stacktrace found for following nodes
            add = x + 2;  x = None
            return None

    class GraphModule(torch.nn.Module):
        def forward(self, x):
            # No stacktrace found for following nodes
            mul = x * 2;  x = None
            return mul
```
Submod 2 is correctly extracted

Test Plan: Tested with new unit test

Reviewed By: mustafaozdal

Differential Revision: D47196732

fbshipit-source-id: 6bbed6778ea24fd0cc2f1127fd4d5762c8d3c30b
@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D47196732

benghaem added a commit to benghaem/pytorch that referenced this pull request Aug 1, 2023
Summary:
Pull Request resolved: pytorch#104554

This change fixes split_module's interaction with dead code. Previously if a dead region was split out, split module would throw an error while attempting to access the outputs for the partition even though the partition has no outputs.

This change adds a new unit test to cover the dead code case and changes the output check to allow no output. The split module with no output will now output None like a normal python function

Unit Test Added:
test_split_module_dead_code

A module with dead code:
```
class ModWithDeadCode(torch.nn.Module):
            def forward(self, x):
                output = x * 2 # we want this
                dead_line = x + 2 # this is dead
                return output
```

Before:
```
torch/fx/passes/split_module.py, line 357, in split_module
base_mod_env[list(partition.outputs)[0]] = output_val
IndexError: list index out of range
```

After:
```
class GraphModule(torch.nn.Module):
    def forward(self, x):
        # No stacktrace found for following nodes
        submod_2 = self.submod_2(x)
        submod_1 = self.submod_1(x);  x = None
        return submod_1

    class GraphModule(torch.nn.Module):
        def forward(self, x):
            # No stacktrace found for following nodes
            add = x + 2;  x = None
            return None

    class GraphModule(torch.nn.Module):
        def forward(self, x):
            # No stacktrace found for following nodes
            mul = x * 2;  x = None
            return mul
```
Submod 2 is correctly extracted

Test Plan: Tested with new unit test

Reviewed By: mustafaozdal

Differential Revision: D47196732

fbshipit-source-id: 86a6ab33d292735d07b0efd25dc93e0763f4b552
benghaem added a commit to benghaem/pytorch that referenced this pull request Aug 1, 2023
Summary:
Pull Request resolved: pytorch#104554

This change fixes split_module's interaction with dead code. Previously if a dead region was split out, split module would throw an error while attempting to access the outputs for the partition even though the partition has no outputs.

This change adds a new unit test to cover the dead code case and changes the output check to allow no output. The split module with no output will now output None like a normal python function

Unit Test Added:
test_split_module_dead_code

A module with dead code:
```
class ModWithDeadCode(torch.nn.Module):
            def forward(self, x):
                output = x * 2 # we want this
                dead_line = x + 2 # this is dead
                return output
```

Before:
```
torch/fx/passes/split_module.py, line 357, in split_module
base_mod_env[list(partition.outputs)[0]] = output_val
IndexError: list index out of range
```

After:
```
class GraphModule(torch.nn.Module):
    def forward(self, x):
        # No stacktrace found for following nodes
        submod_2 = self.submod_2(x)
        submod_1 = self.submod_1(x);  x = None
        return submod_1

    class GraphModule(torch.nn.Module):
        def forward(self, x):
            # No stacktrace found for following nodes
            add = x + 2;  x = None
            return None

    class GraphModule(torch.nn.Module):
        def forward(self, x):
            # No stacktrace found for following nodes
            mul = x * 2;  x = None
            return mul
```
Submod 2 is correctly extracted

Test Plan: Tested with new unit test

Reviewed By: mustafaozdal

Differential Revision: D47196732

fbshipit-source-id: 0e9c852e8fe5ac5fa6597f3fa81bde71c56bae82
@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D47196732

1 similar comment
@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D47196732

benghaem added a commit to benghaem/pytorch that referenced this pull request Aug 2, 2023
Summary:
Pull Request resolved: pytorch#104554

This change fixes split_module's interaction with dead code. Previously if a dead region was split out, split module would throw an error while attempting to access the outputs for the partition even though the partition has no outputs.

This change adds a new unit test to cover the dead code case and changes the output check to allow no output. The split module with no output will now output None like a normal python function

Unit Test Added:
test_split_module_dead_code

A module with dead code:
```
class ModWithDeadCode(torch.nn.Module):
            def forward(self, x):
                output = x * 2 # we want this
                dead_line = x + 2 # this is dead
                return output
```

Before:
```
torch/fx/passes/split_module.py, line 357, in split_module
base_mod_env[list(partition.outputs)[0]] = output_val
IndexError: list index out of range
```

After:
```
class GraphModule(torch.nn.Module):
    def forward(self, x):
        # No stacktrace found for following nodes
        submod_2 = self.submod_2(x)
        submod_1 = self.submod_1(x);  x = None
        return submod_1

    class GraphModule(torch.nn.Module):
        def forward(self, x):
            # No stacktrace found for following nodes
            add = x + 2;  x = None
            return None

    class GraphModule(torch.nn.Module):
        def forward(self, x):
            # No stacktrace found for following nodes
            mul = x * 2;  x = None
            return mul
```
Submod 2 is correctly extracted

Test Plan: Tested with new unit test

Reviewed By: mustafaozdal

Differential Revision: D47196732

fbshipit-source-id: 81bb570788a4baf6ba8953688d663dfc8fdd8d4d
@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D47196732

benghaem added a commit to benghaem/pytorch that referenced this pull request Aug 2, 2023
Summary:
Pull Request resolved: pytorch#104554

This change fixes split_module's interaction with dead code. Previously if a dead region was split out, split module would throw an error while attempting to access the outputs for the partition even though the partition has no outputs.

This change adds a new unit test to cover the dead code case and changes the output check to allow no output. The split module with no output will now output None like a normal python function

Unit Test Added:
test_split_module_dead_code

A module with dead code:
```
class ModWithDeadCode(torch.nn.Module):
            def forward(self, x):
                output = x * 2 # we want this
                dead_line = x + 2 # this is dead
                return output
```

Before:
```
torch/fx/passes/split_module.py, line 357, in split_module
base_mod_env[list(partition.outputs)[0]] = output_val
IndexError: list index out of range
```

After:
```
class GraphModule(torch.nn.Module):
    def forward(self, x):
        # No stacktrace found for following nodes
        submod_2 = self.submod_2(x)
        submod_1 = self.submod_1(x);  x = None
        return submod_1

    class GraphModule(torch.nn.Module):
        def forward(self, x):
            # No stacktrace found for following nodes
            add = x + 2;  x = None
            return None

    class GraphModule(torch.nn.Module):
        def forward(self, x):
            # No stacktrace found for following nodes
            mul = x * 2;  x = None
            return mul
```
Submod 2 is correctly extracted

Test Plan: Tested with new unit test

Reviewed By: mustafaozdal

Differential Revision: D47196732

fbshipit-source-id: fd39dc68e612e78cd3945a313154fdbf32bb7383
benghaem added a commit to benghaem/pytorch that referenced this pull request Aug 2, 2023
Summary:
Pull Request resolved: pytorch#104554

This change fixes split_module's interaction with dead code. Previously if a dead region was split out, split module would throw an error while attempting to access the outputs for the partition even though the partition has no outputs.

This change adds a new unit test to cover the dead code case and changes the output check to allow no output. The split module with no output will now output None like a normal python function

Unit Test Added:
test_split_module_dead_code

A module with dead code:
```
class ModWithDeadCode(torch.nn.Module):
            def forward(self, x):
                output = x * 2 # we want this
                dead_line = x + 2 # this is dead
                return output
```

Before:
```
torch/fx/passes/split_module.py, line 357, in split_module
base_mod_env[list(partition.outputs)[0]] = output_val
IndexError: list index out of range
```

After:
```
class GraphModule(torch.nn.Module):
    def forward(self, x):
        # No stacktrace found for following nodes
        submod_2 = self.submod_2(x)
        submod_1 = self.submod_1(x);  x = None
        return submod_1

    class GraphModule(torch.nn.Module):
        def forward(self, x):
            # No stacktrace found for following nodes
            add = x + 2;  x = None
            return None

    class GraphModule(torch.nn.Module):
        def forward(self, x):
            # No stacktrace found for following nodes
            mul = x * 2;  x = None
            return mul
```
Submod 2 is correctly extracted

Test Plan: Tested with new unit test

Reviewed By: mustafaozdal

Differential Revision: D47196732

fbshipit-source-id: 0503440f9941edf577689b08164cbd1a2b1881d6
@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D47196732

benghaem added a commit to benghaem/pytorch that referenced this pull request Aug 3, 2023
Summary:
Pull Request resolved: pytorch#104554

This change fixes split_module's interaction with dead code. Previously if a dead region was split out, split module would throw an error while attempting to access the outputs for the partition even though the partition has no outputs.

This change adds a new unit test to cover the dead code case and changes the output check to allow no output. The split module with no output will now output None like a normal python function

Unit Test Added:
test_split_module_dead_code

A module with dead code:
```
class ModWithDeadCode(torch.nn.Module):
            def forward(self, x):
                output = x * 2 # we want this
                dead_line = x + 2 # this is dead
                return output
```

Before:
```
torch/fx/passes/split_module.py, line 357, in split_module
base_mod_env[list(partition.outputs)[0]] = output_val
IndexError: list index out of range
```

After:
```
class GraphModule(torch.nn.Module):
    def forward(self, x):
        # No stacktrace found for following nodes
        submod_2 = self.submod_2(x)
        submod_1 = self.submod_1(x);  x = None
        return submod_1

    class GraphModule(torch.nn.Module):
        def forward(self, x):
            # No stacktrace found for following nodes
            add = x + 2;  x = None
            return None

    class GraphModule(torch.nn.Module):
        def forward(self, x):
            # No stacktrace found for following nodes
            mul = x * 2;  x = None
            return mul
```
Submod 2 is correctly extracted

Test Plan: Tested with new unit test

Reviewed By: mustafaozdal, RenfeiChen-FB

Differential Revision: D47196732

fbshipit-source-id: e03013ef6a09c09981a2702c584a97909ea2024b
@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D47196732

Summary:
Pull Request resolved: pytorch#104554

This change fixes split_module's interaction with dead code. Previously if a dead region was split out, split module would throw an error while attempting to access the outputs for the partition even though the partition has no outputs.

This change adds a new unit test to cover the dead code case and changes the output check to allow no output. The split module with no output will now output None like a normal python function

Unit Test Added:
test_split_module_dead_code

A module with dead code:
```
class ModWithDeadCode(torch.nn.Module):
            def forward(self, x):
                output = x * 2 # we want this
                dead_line = x + 2 # this is dead
                return output
```

Before:
```
torch/fx/passes/split_module.py, line 357, in split_module
base_mod_env[list(partition.outputs)[0]] = output_val
IndexError: list index out of range
```

After:
```
class GraphModule(torch.nn.Module):
    def forward(self, x):
        # No stacktrace found for following nodes
        submod_2 = self.submod_2(x)
        submod_1 = self.submod_1(x);  x = None
        return submod_1

    class GraphModule(torch.nn.Module):
        def forward(self, x):
            # No stacktrace found for following nodes
            add = x + 2;  x = None
            return None

    class GraphModule(torch.nn.Module):
        def forward(self, x):
            # No stacktrace found for following nodes
            mul = x * 2;  x = None
            return mul
```
Submod 2 is correctly extracted

Test Plan: Tested with new unit test

Reviewed By: mustafaozdal, RenfeiChen-FB

Differential Revision: D47196732

fbshipit-source-id: 5c937491d2db8e9c8de9f1e00b9fe5baae920528
@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D47196732

@facebook-github-bot
Copy link
Contributor

@pytorchbot merge

(Initiating merge automatically since Phabricator Diff has merged)

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Aug 3, 2023
@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
ciflow/trunk Trigger trunk jobs on your pull request fb-exported Merged release notes: fx release notes category
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

4 participants