Save/Load OptimizedModule #101651
torch/_dynamo/eval_frame.py
Outdated
    def __new__(cls, *args, **kwargs):
        instance = super().__new__(cls)
        super(OptimizedModule, instance).__init__()
You should definitely not do that haha
Silently converting one type to another in pickle load is confusing, let's not do that.
@@ -1435,6 +1438,19 @@ def test_recursion(self):
        opt_mod(torch.randn(10, 10))
        self.assertEqual(cnt.frame_count, 1)

    def test_save_and_load(self):
        mod = MockModule()
        cnt = torch._dynamo.testing.CompileCounter()
note from @albanD: check if it works with inductor as well. It's possible that this backend has a simple state
Looks like this PR hasn't been updated in a while so we're going to go ahead and mark this as
Closing this for now. I couldn't figure out why this change introduces extra graph breaks. Might revisit
## save&load support for OptimizedModule

[Issue Description](#101651)

English is not my native language; please excuse typing errors.

This pr is based on commit b958810\
I'll do something with the merge conflicts later

### test result for test/dynamo

Conclusion:\
It performs the same as before as far as I can see.

ENV(CPU only):\
platform linux -- Python 3.10.14, pytest-7.3.2, pluggy-1.5.0\
configfile: pytest.ini\
plugins: anyio-3.7.1, cpp-2.3.0, flakefinder-1.1.0, xdist-3.3.1, xdoctest-1.1.0, metadata-3.1.1, html-4.1.1, hypothesis-5.35.1, rerunfailures-14.0

#### before this pr:

[before](https://github.com/pytorch/pytorch/files/15329370/before.md)

#### after this pr:

[after](https://github.com/pytorch/pytorch/files/15329376/after.md)

### some changes

1. add test_save_and_load to test/dynamo/test_modules.py with & without "backend='inductor'"
2. add \_\_reduce\_\_ function to OptimizedModule and derived classes of _TorchDynamoContext for pickling & unpickling
3. change the wrappers into wrapper classes ( including convert_frame_assert, convert_frame, catch_errors_wrapper in torch/_dynamo/convert_frame.py & wrap_backend_debug in torch/_dynamo/repro/after_dynamo.py )
4. change self.output.compiler_fn into innermost_fn(self.output.compiler_fn) in torch/_dynamo/symbolic_convert.py to get the origin compiler_fn and to avoid the "compiler_fn is not eager" condition

Pull Request resolved: #126374
Approved by: https://github.com/msaroufim, https://github.com/jansel
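Changes 2 and 4 in the commit message above can be illustrated without PyTorch. The following is a minimal, torch-free sketch, not the actual Dynamo code: `BackendWrapper`, `eager_backend`, `innermost`, and the `_orig_callable` attribute are all hypothetical names chosen for illustration. It shows a wrapper class that defines `__reduce__` so it survives a pickle round trip, and an unwrapping helper that recovers the original callable so an identity check against it succeeds.

```python
import pickle


def eager_backend(x):
    """Stand-in for a compiler backend (hypothetical); module-level, so it pickles by reference."""
    return x


class BackendWrapper:
    """Hypothetical callable wrapper class used in place of a closure-based wrapper."""

    def __init__(self, orig_callable):
        # Hypothetical attribute name used to walk back to the wrapped callable.
        self._orig_callable = orig_callable

    def __call__(self, *args, **kwargs):
        return self._orig_callable(*args, **kwargs)

    def __reduce__(self):
        # Tell pickle how to rebuild this object: call the class with the wrapped callable.
        return (self.__class__, (self._orig_callable,))


def innermost(fn):
    """Peel off wrapper layers until the original callable is reached."""
    while hasattr(fn, "_orig_callable"):
        fn = fn._orig_callable
    return fn


# Two wrapper layers around the backend, then a pickle round trip.
wrapped = BackendWrapper(BackendWrapper(eager_backend))
restored = pickle.loads(pickle.dumps(wrapped))

# A wrapper is never identical to the function it wraps; the unwrapped one is.
print(wrapped is eager_backend)             # False
print(innermost(restored) is eager_backend)  # True
```

The design point is that a check such as "is the backend the eager one" has to be made against the unwrapped callable; comparing the wrapper itself would always fail once wrappers become class instances.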
This enables us to `torch.save(opt_mod)` by just unwrapping and saving the `_orig_mod`.
`torch.load("opt_model.pt")` is semantically equivalent to loading an uncompiled model; we can make this better by saving a cache or using AOT Inductor.
So `torch.save(mod) == torch.save(opt_mod)` and `torch.load(opt_model) == torch.load(mod)`.
So users don't need to write code like this
Idea 0 (rejected)
Don't bother saving an optimized model; just throw an error, but make it human readable, telling people they need to `torch.save(opt_model.state_dict())` (#101997)
Idea 1 (rejected)
Instantiate an `nn.Module` instead of an `OptimizedModule` on load. It works, but it's confusing and breaks deepcopying
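Why this is confusing can be shown with a small torch-free sketch. `PlainModule`, `CompiledWrapper`, and `_rebuild_plain` are hypothetical stand-ins, not PyTorch classes; the assumption is only that the wrapper's `__reduce__` hands back the wrapped module. Because `copy.deepcopy` falls back to the same reduce protocol as pickle, both the loaded object and the deep copy silently change type.

```python
import copy
import pickle


class PlainModule:
    """Stand-in for an uncompiled nn.Module (hypothetical)."""

    def __init__(self, weight):
        self.weight = weight


def _rebuild_plain(orig_mod):
    # Reconstruction hook: returns the wrapped module itself, dropping the wrapper.
    return orig_mod


class CompiledWrapper:
    """Stand-in for OptimizedModule: keeps the wrapped module under _orig_mod."""

    def __init__(self, orig_mod):
        self._orig_mod = orig_mod

    def __reduce__(self):
        # Idea 1 in miniature: serialize only the wrapped module.
        return (_rebuild_plain, (self._orig_mod,))


wrapper = CompiledWrapper(PlainModule(weight=1.5))
loaded = pickle.loads(pickle.dumps(wrapper))
copied = copy.deepcopy(wrapper)

# Both results are the plain type, not the wrapper that went in.
print(type(loaded).__name__, type(copied).__name__)  # PlainModule PlainModule
```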
Idea 2 (this PR now)
A lot more work but we can use wrapper classes around functions instead so they become picklable. Need to merge this fast though cause the merge conflicts make me sad
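The picklability difference between a closure and a wrapper class can be sketched in plain Python. This is an illustration under assumptions, not the Dynamo wrappers themselves: `backend`, `wrap_with_closure`, `WrapWithClass`, and `can_pickle` are hypothetical names. A function defined inside another function cannot be pickled by reference, while an instance of a module-level class holding a module-level function can.

```python
import pickle


def backend(x):
    """Stand-in for the function being wrapped (hypothetical)."""
    return x * 2


def wrap_with_closure(fn):
    # The returned function is local to this call, so pickle cannot look it up by name.
    def inner(*args, **kwargs):
        return fn(*args, **kwargs)

    return inner


class WrapWithClass:
    """Same behavior as the closure, but as an instance of a module-level class."""

    def __init__(self, fn):
        self.fn = fn

    def __call__(self, *args, **kwargs):
        return self.fn(*args, **kwargs)


def can_pickle(obj):
    """Return True if pickle.dumps succeeds on obj."""
    try:
        pickle.dumps(obj)
        return True
    except (AttributeError, pickle.PicklingError):
        # The exception type for local objects differs across Python versions.
        return False


closure_ok = can_pickle(wrap_with_closure(backend))
class_ok = can_pickle(WrapWithClass(backend))
roundtrip = pickle.loads(pickle.dumps(WrapWithClass(backend)))

print(closure_ok, class_ok)  # False True
print(roundtrip(21))         # 42
```

The class version stores the wrapped function as ordinary instance state, which is what lets the default pickling machinery handle it.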
Tests pass on modules: `test/dynamo/test_modules.py`
Failing test tracker
cc @voznesenskym @penguinwu @anijain2305 @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @aakhundov @desertfire