
Save/Load OptimizedModule #101651

Closed · wants to merge 14 commits

Conversation

@msaroufim (Member) commented May 17, 2023

This enables torch.save(opt_mod) to work by simply unwrapping and saving the underlying _orig_mod.

torch.load("opt_model.pt") is then semantically equivalent to loading the uncompiled model; we could make this better later by also saving a compilation cache or by using AOT Inductor.

The goal is that torch.save(opt_mod) behaves the same as torch.save(mod), and torch.load(opt_model) the same as torch.load(mod).

That way users don't need to write code like this:

if is_compiled_module(model):
    torch.save(model._orig_mod, "model.pt")
else:
    torch.save(model, "model.pt")
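
For illustration, a minimal sketch of the intended round trip under this PR (MyModule is a hypothetical stand-in for any user nn.Module, and the file name is arbitrary):

```python
import torch
import torch.nn as nn

class MyModule(nn.Module):  # hypothetical stand-in for a user model
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 10)

    def forward(self, x):
        return self.linear(x)

mod = MyModule()
opt_mod = torch.compile(mod)

# With this PR, saving the compiled wrapper saves the unwrapped _orig_mod ...
torch.save(opt_mod, "opt_model.pt")
# ... so loading behaves like loading the uncompiled model.
loaded = torch.load("opt_model.pt")
```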

Idea 0 (rejected)

Don't support saving an optimized model at all: just throw an error, but make it human readable so users know they should call torch.save(opt_model.state_dict()) instead (#101997).
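
For concreteness, a hypothetical sketch of what that rejected approach could have looked like (the class body and the error message here are illustrative, not the code from #101997):

```python
import torch

class OptimizedModule(torch.nn.Module):  # simplified stand-in for the real wrapper
    def __init__(self, orig_mod):
        super().__init__()
        self._orig_mod = orig_mod

    def __reduce__(self):
        # Refuse to pickle the wrapper and point users at the workaround.
        raise RuntimeError(
            "torch.save() on a compiled module is not supported; "
            "save opt_model.state_dict() or opt_model._orig_mod instead."
        )
```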

Idea 1 (rejected)

Instantiate a plain nn.Module instead of an OptimizedModule when unpickling. It works, but it's confusing and it breaks deepcopying:

    def __getstate__(self):
        # Pickle only the wrapped module.
        return {"_orig_mod": self._orig_mod}

    def __setstate__(self, state):
        # On unpickling, silently turn this object into the original module's
        # class -- the confusing part that also breaks deepcopy.
        orig_mod = state["_orig_mod"]
        self.__class__ = orig_mod.__class__
        self.__dict__.update(orig_mod.__dict__)

    def __new__(cls, *args, **kwargs):
        # Pre-initialize the nn.Module machinery so __setstate__ can run
        # even though pickle never calls __init__.
        instance = super().__new__(cls)
        super(OptimizedModule, instance).__init__()
        return instance

Idea 2 (this PR now)

A lot more work, but we can turn the function wrappers into wrapper classes instead so they become picklable. Need to merge this quickly though, because the merge conflicts make me sad.
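
Roughly, the point is that a closure returned by a decorator cannot be pickled, while an instance of a module-level class can, because pickle only needs the class reference plus the instance's __dict__. A hypothetical sketch of the pattern (names and signatures simplified; not the actual convert_frame code):

```python
import pickle

def make_wrapper_fn(compiler_fn):
    # Old style: a nested function capturing compiler_fn.
    # pickle.dumps(wrapped) fails because a local function is not importable.
    def wrapped(frame):
        return compiler_fn(frame)
    return wrapped

class WrapperClass:
    # New style: state lives in attributes, so instances pickle cleanly.
    def __init__(self, compiler_fn):
        self._compiler_fn = compiler_fn

    def __call__(self, frame):
        return self._compiler_fn(frame)

def my_backend(frame):  # stand-in for a real compiler function
    return frame

wrapper = WrapperClass(my_backend)
restored = pickle.loads(pickle.dumps(wrapper))   # works
assert restored("dummy frame") == "dummy frame"
# pickle.dumps(make_wrapper_fn(my_backend))      # would raise: can't pickle a local function
```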

Tests pass on test/dynamo/test_modules.py.

Failing test tracker

  dynamo/test_logging
  dynamo/test_dynamic_shapes
  dynamo/test_misc

cc @voznesenskym @penguinwu @anijain2305 @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @aakhundov @desertfire

@pytorch-bot commented May 17, 2023:

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/101651

❌ 21 New Failures as of commit 18cc651 (comment generated by Dr. CI).


    def __new__(cls, *args, **kwargs):
        instance = super().__new__(cls)
        super(OptimizedModule, instance).__init__()

Collaborator: You should definitely not do that haha

torch/_dynamo/eval_frame.py (review thread outdated, resolved)
@jansel (Contributor) left a comment:

Silently converting one type to another in pickle load is confusing, let's not do that.

@pytorch pytorch deleted a comment from pytorchmergebot Jun 1, 2023
@pytorch pytorch deleted a comment from pytorchmergebot Jun 1, 2023
@@ -1435,6 +1438,19 @@ def test_recursion(self):
         opt_mod(torch.randn(10, 10))
         self.assertEqual(cnt.frame_count, 1)

     def test_save_and_load(self):
         mod = MockModule()
         cnt = torch._dynamo.testing.CompileCounter()
@msaroufim (Member, Author) commented Jun 13, 2023:

Note from @albanD: check if it works with inductor as well. It's possible that this backend has a simple state.
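
For reference, a rough sketch of how such a round-trip check might look with and without inductor (MockModule here is a simplified stand-in for the one in test/dynamo/test_modules.py, and the assertions are illustrative rather than the exact test that landed):

```python
import io
import torch
import torch.nn as nn

class MockModule(nn.Module):  # simplified stand-in for the test suite's MockModule
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 10)

    def forward(self, x):
        return torch.relu(self.linear(x))

def check_save_and_load(backend):
    mod = MockModule()
    opt_mod = torch.compile(mod, backend=backend)
    inp = torch.randn(10, 10)
    expected = opt_mod(inp)

    buffer = io.BytesIO()
    torch.save(opt_mod, buffer)      # should go through the unwrapped _orig_mod
    buffer.seek(0)
    loaded = torch.load(buffer)

    torch.testing.assert_close(loaded(inp), expected)

check_save_and_load("eager")
check_save_and_load("inductor")      # per @albanD's note: verify inductor too
```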

@github-actions bot commented:

Looks like this PR hasn't been updated in a while so we're going to go ahead and mark this as Stale.
Feel free to remove the Stale label if you feel this was a mistake.
If you are unable to remove the Stale label please contact a maintainer in order to do so.
If you want the bot to never mark this PR stale again, add the no-stale label.
Stale pull requests will automatically be closed after 30 days of inactivity.

@github-actions github-actions bot added the Stale label Aug 12, 2023
@msaroufim (Member, Author) commented:

Closing this for now. I couldn't figure out why this change introduces extra graph breaks. Might revisit.

@msaroufim msaroufim closed this Aug 13, 2023
weiyusheng pushed a commit to weiyusheng/pytorch that referenced this pull request May 9, 2024
weiyusheng added a commit to weiyusheng/pytorch that referenced this pull request Jun 3, 2024
weiyusheng added a commit to weiyusheng/pytorch that referenced this pull request Jun 4, 2024
pytorchmergebot pushed a commit that referenced this pull request Jun 5, 2024
## save&load support for OptimizedModule

[Issue Description](#101651)

English is not my native language; please excuse typing errors.

This PR is based on commit b958810.\
I'll deal with the merge conflicts later.

### Test results for test/dynamo

Conclusion:\
It performs the same as before as far as I can see.

ENV(CPU only):\
platform linux -- Python 3.10.14, pytest-7.3.2, pluggy-1.5.0\
configfile: pytest.ini\
plugins: anyio-3.7.1, cpp-2.3.0, flakefinder-1.1.0, xdist-3.3.1, xdoctest-1.1.0, metadata-3.1.1, html-4.1.1, hypothesis-5.35.1, rerunfailures-14.0

#### before this pr:

[before](https://github.com/pytorch/pytorch/files/15329370/before.md)

#### after this pr:

[after](https://github.com/pytorch/pytorch/files/15329376/after.md)

### Some changes

1. Add test_save_and_load to test/dynamo/test_modules.py, with and without backend='inductor'.
2. Add a `__reduce__` method to OptimizedModule and to the derived classes of _TorchDynamoContext for pickling and unpickling (a sketch of the idea follows this list).
3. Change the function wrappers into wrapper classes (including convert_frame_assert, convert_frame, and catch_errors_wrapper in torch/_dynamo/convert_frame.py, and wrap_backend_debug in torch/_dynamo/repro/after_dynamo.py).
4. Change self.output.compiler_fn into innermost_fn(self.output.compiler_fn) in torch/_dynamo/symbolic_convert.py to get the original compiler_fn and avoid the "compiler_fn is not eager" condition.
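
For change 2 above, a hypothetical sketch of the shape such a `__reduce__` can take (simplified; the actual code that landed may differ):

```python
import torch

class OptimizedModule(torch.nn.Module):  # simplified stand-in
    def __init__(self, orig_mod, dynamo_ctx):
        super().__init__()
        self._orig_mod = orig_mod
        self.dynamo_ctx = dynamo_ctx

    def __reduce__(self):
        # Rebuild the wrapper at load time from the original module and the
        # (now picklable) dynamo context, instead of pickling compiled state.
        return (self.__class__, (self._orig_mod, self.dynamo_ctx))
```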

Pull Request resolved: #126374
Approved by: https://github.com/msaroufim, https://github.com/jansel