-
Notifications
You must be signed in to change notification settings - Fork 21.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[JIT] Allow freezing modules that contain mutable interfaces #86039
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/86039
Note: Links to docs will display an error until the docs builds have been completed. ✅ No FailuresAs of commit a3055b2: This comment was automatically generated by Dr. CI and updates every 15 minutes. |
@davidberard98 has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator. |
3 similar comments
@davidberard98 has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator. |
@davidberard98 has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator. |
@davidberard98 has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator. |
/easycla As part of the transition to the PyTorch Foundation, this project now requires contributions be covered under the new CLA. See #85559 for additional details. This comment will trigger a new check of this PR. If you are already covered, you will simply see a new "EasyCLA" check that passes. If you are not covered, a bot will leave a new comment with a link to sign. |
e0f054e
to
853cd4c
Compare
@davidberard98 has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
looks good ! one question
@@ -113,18 +126,23 @@ class AttributePropagator { | |||
LowerSimpleTuples(subgraph); | |||
}; | |||
|
|||
std::unordered_map<std::string, std::unordered_set<std::string>> | |||
interfacesToRetype; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: spell out interfacesToRealType
- within context i can guess that's what this means but seeing ReType
i also think abt Return Type
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I meant interfacesToReassignType
haha... so yeah I guess I should rename that :)
x = torch.rand((2, 2)) | ||
|
||
m_frozen(x) | ||
self.assertEqual(m_frozen.impl.sum, x.relu()) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nice!
# I don't think there's any way to create a plain python object that | ||
# contains a torch.nn.Module inside it, but just in case... I'm not | ||
# sure freezing would handle this case correctly, so marking as xfail | ||
# so that if this ever _does_ start working someone will need to |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nice!
for (auto function : preservedMethods_) { | ||
GRAPH_DEBUG("Analyzing function: " + function->name()); | ||
auto graph = toGraphFunction(*function).graph(); | ||
optimizeSubGraphs(graph, applyInline); | ||
if (freezeInterfaces_) { | ||
inlineInterfaceCalls(graph); | ||
inlineInterfaceCalls(graph, interfacesToRetype); | ||
} | ||
// Record Attributes that are explicitly set in the module. | ||
// They cannot be folded. | ||
recordMutableAttrs(graph); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Shouldn't we be recording mutable attrs after reassigning Interface Types?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@eellison suppose we put reassignInterfaceTypes immediately after inlineInterfaceCalls. Then we can have the following situation:
- preservedMethod1 does getattr(self, some_interface_module). It registers it as an interface to reassign type
- we reassign type for some_interface_module, so now the module object labels it as an interface type
- preservedMethod2 does setattr(self, some_interface_module). But when it looks it up in the module, it sees that some_interface_module is not an interface (because the type was reassigned already). So it doesn't error out.
Is there anything that can go wrong with reassigning interface types after recordMutableAttrs?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yea, good point. Sorry the important part is that the calls themselves are inlined (and mutation is observed) not that the actual types on the modules are re-assigned.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actually just to be safe I reordered it anyway, just added an extra loop to do the recordMutableAttrs afterwards.
|
||
self.assertEqual(expected, actual) | ||
|
||
def test_freeze_recursive_interfaces_same_name(self): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you add a test where an interface does an interface reassignment?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks great!! Great tests too!!
853cd4c
to
273bf1b
Compare
@davidberard98 has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator. |
@pytorchbot rebase -s |
@pytorchbot successfully started a rebase job. Check the current status here |
Successfully rebased |
273bf1b
to
2c95496
Compare
@davidberard98 has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator. |
@pytorchbot rebase -s |
@pytorchbot successfully started a rebase job. Check the current status here |
- add tests (recursive interface swap & 2-method interface swap) - reassign types _before_ checking for mutability
Successfully rebased |
2c95496
to
a3055b2
Compare
@davidberard98 has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator. |
@pytorchbot merge |
Merge startedYour change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
Hey @davidberard98. |
…#86039) Summary: This PR allows freezing modules like the one below: ```python # Ex. 1 torch.jit.interface class ModuleInterface(torch.nn.Module): def forward(self, inp: torch.Tensor) -> torch.Tensor: pass class ImplementsInterface(torch.nn.Module): def __init__(self): super(ImplementsInterface, self).__init__() self.sum = torch.zeros((2, 2)) def forward(self, inp: torch.Tensor) -> torch.Tensor: self.sum += inp.relu() # this makes the interface-implementing module mutable # and previously this would prevent freezing return self.sum class WrapperModule(torch.nn.Module): impl: ModuleInterface def __init__(self): super().__init__() self.impl = ImplementsInterface() def forward(self, x: torch.Tensor) -> torch.Tensor: return self.impl.forward(x) ``` Previously during freezing, we handle interfaces as shown below: 1. we inline interfaces in any preserved method graphs 2. during `cleanupFrozenModule`, we try to simplify the module data structure (<- this part is unrelated to freezing so far). During this step, if we found that a interface type was mutable, we'd error out; because of the possibility of a module that _swaps out the value of an interface-typed attribute at runtime_. Below is an example of a module that swaps out the value of an interface-typed attribute at runtime: ```python # Ex. 2 class MyBadModule(torch.nn.Module): impl: MyInterface option1: IfaceImpl1 option2: IfaceImpl2 .... def forward(self, x): if x > 0: self.impl = self.option1 else: self.impl = self.option2 .... ``` ^ this type of situation cannot be supported by freezing (or at least would be difficult to do correctly) because it greatly complicates the details of handling types and simplifying the module data structure. But we can still support the first example without _too_ much work: 1. inline the interface code as before 2. check to see if we have any setattrs on interface types; if so, error out 3. otherwise, replace the type of the interface types with the concrete type implementation 4. continue simplifying the module data structure as if we never had any interfaces. Pull Request resolved: #86039 Approved by: https://github.com/eellison Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/bac26155e7e5bca949e986fa36ab37e042b8ad53 Original Phabricator Test Plan: Imported from GitHub, without a `Test Plan:` line. Reviewed By: seemethere, eellison, hzh0512 Differential Revision: D39990822 fbshipit-source-id: 469375b8c7a43adb1e7b7401768527ba81b71719
This PR allows freezing modules like the one below:
Previously during freezing, we handle interfaces as shown below:
cleanupFrozenModule
, we try to simplify the module data structure (<- this part is unrelated to freezing so far). During this step, if we found that a interface type was mutable, we'd error out; because of the possibility of a module that swaps out the value of an interface-typed attribute at runtime.Below is an example of a module that swaps out the value of an interface-typed attribute at runtime:
^ this type of situation cannot be supported by freezing (or at least would be difficult to do correctly) because it greatly complicates the details of handling types and simplifying the module data structure.
But we can still support the first example without too much work: