Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[JIT] Allow freezing modules that contain mutable interfaces #86039

Closed
wants to merge 7 commits into from

Conversation

davidberard98
Copy link
Contributor

This PR allows freezing modules like the one below:

# Ex. 1
        @torch.jit.interface
        class ModuleInterface(torch.nn.Module):
            def forward(self, inp: torch.Tensor) -> torch.Tensor:
                pass

        class ImplementsInterface(torch.nn.Module):
            def __init__(self):
                super(ImplementsInterface, self).__init__()
                self.sum = torch.zeros((2, 2))

            def forward(self, inp: torch.Tensor) -> torch.Tensor:
                self.sum += inp.relu()  # this makes the interface-implementing module mutable
                                        # and previously this would prevent freezing
                return self.sum

        class WrapperModule(torch.nn.Module):
            impl: ModuleInterface

            def __init__(self):
                super().__init__()
                self.impl = ImplementsInterface()

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return self.impl.forward(x)

Previously during freezing, we handle interfaces as shown below:

  1. we inline interfaces in any preserved method graphs
  2. during cleanupFrozenModule, we try to simplify the module data structure (<- this part is unrelated to freezing so far). During this step, if we found that a interface type was mutable, we'd error out; because of the possibility of a module that swaps out the value of an interface-typed attribute at runtime.

Below is an example of a module that swaps out the value of an interface-typed attribute at runtime:

# Ex. 2
class MyBadModule(torch.nn.Module):
    impl: MyInterface
    option1: IfaceImpl1
    option2: IfaceImpl2
    ....
    def forward(self, x):
        if x > 0:
            self.impl = self.option1
        else:
            self.impl = self.option2
        ....

^ this type of situation cannot be supported by freezing (or at least would be difficult to do correctly) because it greatly complicates the details of handling types and simplifying the module data structure.

But we can still support the first example without too much work:

  1. inline the interface code as before
  2. check to see if we have any setattrs on interface types; if so, error out
  3. otherwise, replace the type of the interface types with the concrete type implementation
  4. continue simplifying the module data structure as if we never had any interfaces.

@pytorch-bot
Copy link

pytorch-bot bot commented Oct 1, 2022

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/86039

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit a3055b2:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added the release notes: jit release notes category label Oct 1, 2022
@facebook-github-bot facebook-github-bot added cla signed oncall: jit Add this issue/PR to JIT oncall triage queue labels Oct 1, 2022
@facebook-github-bot
Copy link
Contributor

@davidberard98 has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

3 similar comments
@facebook-github-bot
Copy link
Contributor

@davidberard98 has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@facebook-github-bot
Copy link
Contributor

@davidberard98 has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@facebook-github-bot
Copy link
Contributor

@davidberard98 has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@facebook-github-bot
Copy link
Contributor

/easycla

As part of the transition to the PyTorch Foundation, this project now requires contributions be covered under the new CLA. See #85559 for additional details.

This comment will trigger a new check of this PR. If you are already covered, you will simply see a new "EasyCLA" check that passes. If you are not covered, a bot will leave a new comment with a link to sign.

@davidberard98 davidberard98 force-pushed the davidberard98/freeze_mutable_interfaces branch from e0f054e to 853cd4c Compare October 6, 2022 00:25
@facebook-github-bot
Copy link
Contributor

@davidberard98 has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

Copy link
Contributor

@eellison eellison left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

looks good ! one question

@@ -113,18 +126,23 @@ class AttributePropagator {
LowerSimpleTuples(subgraph);
};

std::unordered_map<std::string, std::unordered_set<std::string>>
interfacesToRetype;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: spell out interfacesToRealType - within context i can guess that's what this means but seeing ReType i also think abt Return Type

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I meant interfacesToReassignType haha... so yeah I guess I should rename that :)

x = torch.rand((2, 2))

m_frozen(x)
self.assertEqual(m_frozen.impl.sum, x.relu())
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice!

# I don't think there's any way to create a plain python object that
# contains a torch.nn.Module inside it, but just in case... I'm not
# sure freezing would handle this case correctly, so marking as xfail
# so that if this ever _does_ start working someone will need to
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice!

for (auto function : preservedMethods_) {
GRAPH_DEBUG("Analyzing function: " + function->name());
auto graph = toGraphFunction(*function).graph();
optimizeSubGraphs(graph, applyInline);
if (freezeInterfaces_) {
inlineInterfaceCalls(graph);
inlineInterfaceCalls(graph, interfacesToRetype);
}
// Record Attributes that are explicitly set in the module.
// They cannot be folded.
recordMutableAttrs(graph);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't we be recording mutable attrs after reassigning Interface Types?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@eellison suppose we put reassignInterfaceTypes immediately after inlineInterfaceCalls. Then we can have the following situation:

  • preservedMethod1 does getattr(self, some_interface_module). It registers it as an interface to reassign type
  • we reassign type for some_interface_module, so now the module object labels it as an interface type
  • preservedMethod2 does setattr(self, some_interface_module). But when it looks it up in the module, it sees that some_interface_module is not an interface (because the type was reassigned already). So it doesn't error out.

Is there anything that can go wrong with reassigning interface types after recordMutableAttrs?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yea, good point. Sorry the important part is that the calls themselves are inlined (and mutation is observed) not that the actual types on the modules are re-assigned.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually just to be safe I reordered it anyway, just added an extra loop to do the recordMutableAttrs afterwards.


self.assertEqual(expected, actual)

def test_freeze_recursive_interfaces_same_name(self):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you add a test where an interface does an interface reassignment?

Copy link
Contributor

@eellison eellison left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks great!! Great tests too!!

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Oct 6, 2022
@davidberard98 davidberard98 force-pushed the davidberard98/freeze_mutable_interfaces branch from 853cd4c to 273bf1b Compare October 6, 2022 20:52
@facebook-github-bot
Copy link
Contributor

@davidberard98 has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@davidberard98
Copy link
Contributor Author

@pytorchbot rebase -s

@pytorchmergebot
Copy link
Collaborator

@pytorchbot successfully started a rebase job. Check the current status here

@pytorchmergebot
Copy link
Collaborator

Successfully rebased davidberard98/freeze_mutable_interfaces onto refs/remotes/origin/viable/strict, please pull locally before adding more changes (for example, via git checkout davidberard98/freeze_mutable_interfaces && git pull --rebase)

@pytorchmergebot pytorchmergebot force-pushed the davidberard98/freeze_mutable_interfaces branch from 273bf1b to 2c95496 Compare October 7, 2022 00:34
@facebook-github-bot
Copy link
Contributor

@davidberard98 has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@davidberard98
Copy link
Contributor Author

@pytorchbot rebase -s

@pytorchmergebot
Copy link
Collaborator

@pytorchbot successfully started a rebase job. Check the current status here

- add tests (recursive interface swap & 2-method interface swap)
- reassign types _before_ checking for mutability
@pytorchmergebot
Copy link
Collaborator

Successfully rebased davidberard98/freeze_mutable_interfaces onto refs/remotes/origin/viable/strict, please pull locally before adding more changes (for example, via git checkout davidberard98/freeze_mutable_interfaces && git pull --rebase)

@pytorchmergebot pytorchmergebot force-pushed the davidberard98/freeze_mutable_interfaces branch from 2c95496 to a3055b2 Compare October 7, 2022 21:24
@facebook-github-bot
Copy link
Contributor

@davidberard98 has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@davidberard98
Copy link
Contributor Author

@pytorchbot merge

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@github-actions
Copy link

github-actions bot commented Oct 8, 2022

Hey @davidberard98.
You've committed this PR, but it does not have both a 'release notes: ...' and 'topics: ...' label. Please add one of each to the PR. The 'release notes: ...' label should represent the part of PyTorch that this PR changes (fx, autograd, distributed, etc) and the 'topics: ...' label should represent the kind of PR it is (not user facing, new feature, bug fix, perf improvement, etc). The list of valid labels can be found here for the 'release notes: ...' and here for the 'topics: ...'.
For changes that are 'topic: not user facing' there is no need for a release notes label.

facebook-github-bot pushed a commit that referenced this pull request Oct 10, 2022
…#86039)

Summary:
This PR allows freezing modules like the one below:
```python
# Ex. 1
        torch.jit.interface
        class ModuleInterface(torch.nn.Module):
            def forward(self, inp: torch.Tensor) -> torch.Tensor:
                pass

        class ImplementsInterface(torch.nn.Module):
            def __init__(self):
                super(ImplementsInterface, self).__init__()
                self.sum = torch.zeros((2, 2))

            def forward(self, inp: torch.Tensor) -> torch.Tensor:
                self.sum += inp.relu()  # this makes the interface-implementing module mutable
                                        # and previously this would prevent freezing
                return self.sum

        class WrapperModule(torch.nn.Module):
            impl: ModuleInterface

            def __init__(self):
                super().__init__()
                self.impl = ImplementsInterface()

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return self.impl.forward(x)
```

Previously during freezing, we handle interfaces as shown below:
1. we inline interfaces in any preserved method graphs
2. during `cleanupFrozenModule`, we try to simplify the module data structure (<- this part is unrelated to freezing so far). During this step, if we found that a interface type was mutable, we'd error out; because of the possibility of a module that _swaps out the value of an interface-typed attribute at runtime_.

Below is an example of a module that swaps out the value of an interface-typed attribute at runtime:
```python
# Ex. 2
class MyBadModule(torch.nn.Module):
    impl: MyInterface
    option1: IfaceImpl1
    option2: IfaceImpl2
    ....
    def forward(self, x):
        if x > 0:
            self.impl = self.option1
        else:
            self.impl = self.option2
        ....
```

^ this type of situation cannot be supported by freezing (or at least would be difficult to do correctly) because it greatly complicates the details of handling types and simplifying the module data structure.

But we can still support the first example without _too_ much work:
1. inline the interface code as before
2. check to see if we have any setattrs on interface types; if so, error out
3. otherwise, replace the type of the interface types with the concrete type implementation
4. continue simplifying the module data structure as if we never had any interfaces.

Pull Request resolved: #86039
Approved by: https://github.com/eellison

Test Plan:
contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/bac26155e7e5bca949e986fa36ab37e042b8ad53

Original Phabricator Test Plan:
Imported from GitHub, without a `Test Plan:` line.

Reviewed By: seemethere, eellison, hzh0512

Differential Revision: D39990822

fbshipit-source-id: 469375b8c7a43adb1e7b7401768527ba81b71719
@github-actions github-actions bot deleted the davidberard98/freeze_mutable_interfaces branch April 8, 2024 01:51
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
ciflow/trunk Trigger trunk jobs on your pull request cla signed Merged oncall: jit Add this issue/PR to JIT oncall triage queue release notes: jit release notes category
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

4 participants