Skip to content

Fix inefficient recursive update in ShardedTensor.state_dict hook #68806

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed

Conversation

awaelchli
Copy link
Contributor

@awaelchli awaelchli commented Nov 23, 2021

Fixes #68805

The bug is described in the linked issue. This PR is an attempt to make the functions _recurse_update_dict and _recurse_update_module more efficient in how they iterate over the submodules. The previous implementation was suboptimal, as it recursively called the update method on the submodules returned by module.named_modules(), while module.named_modules() already returned all submodules including nested ones.

cc @pietern @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @SciPioneer @H-Huang

@pytorch-probot
Copy link

pytorch-probot bot commented Nov 23, 2021

CI Flow Status

⚛️ CI Flow

Ruleset - Version: v1
Ruleset - File: https://github.com/awaelchli/pytorch/blob/ec204a6eea139d37e3f35a5d0983291915a76908/.github/generated-ciflow-ruleset.json
PR ciflow labels: ciflow/default

Workflows Labels (bold enabled) Status
Triggered Workflows
linux-bionic-py3.6-clang9 ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/noarch, ciflow/xla ✅ triggered
linux-docs ciflow/all, ciflow/cpu, ciflow/default, ciflow/docs, ciflow/linux ✅ triggered
linux-vulkan-bionic-py3.6-clang9 ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/vulkan ✅ triggered
linux-xenial-cuda11.3-py3.6-gcc7 ciflow/all, ciflow/cuda, ciflow/default, ciflow/linux ✅ triggered
linux-xenial-py3-clang5-mobile-build ciflow/all, ciflow/default, ciflow/linux, ciflow/mobile ✅ triggered
linux-xenial-py3-clang5-mobile-custom-build-static ciflow/all, ciflow/default, ciflow/linux, ciflow/mobile ✅ triggered
linux-xenial-py3.6-clang7-asan ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/sanitizers ✅ triggered
linux-xenial-py3.6-clang7-onnx ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/onnx ✅ triggered
linux-xenial-py3.6-gcc5.4 ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux ✅ triggered
linux-xenial-py3.6-gcc7 ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux ✅ triggered
linux-xenial-py3.6-gcc7-bazel-test ciflow/all, ciflow/bazel, ciflow/cpu, ciflow/default, ciflow/linux ✅ triggered
pytorch-linux-xenial-py3-clang5-android-ndk-r19c-gradle-custom-build-single ciflow/all, ciflow/android, ciflow/cpu, ciflow/default, ciflow/linux ✅ triggered
pytorch-linux-xenial-py3-clang5-android-ndk-r19c-gradle-custom-build-single-full-jit ciflow/all, ciflow/android, ciflow/cpu, ciflow/default, ciflow/linux ✅ triggered
win-vs2019-cpu-py3 ciflow/all, ciflow/cpu, ciflow/default, ciflow/win ✅ triggered
win-vs2019-cuda11.3-py3 ciflow/all, ciflow/cuda, ciflow/default, ciflow/win ✅ triggered
Skipped Workflows
caffe2-linux-xenial-py3.6-gcc5.4 ciflow/all, ciflow/cpu, ciflow/linux 🚫 skipped
docker-builds ciflow/all 🚫 skipped
ios-12-5-1-arm64 ciflow/all, ciflow/ios, ciflow/macos 🚫 skipped
ios-12-5-1-arm64-coreml ciflow/all, ciflow/ios, ciflow/macos 🚫 skipped
ios-12-5-1-arm64-custom-ops ciflow/all, ciflow/ios, ciflow/macos 🚫 skipped
ios-12-5-1-arm64-full-jit ciflow/all, ciflow/ios, ciflow/macos 🚫 skipped
ios-12-5-1-arm64-metal ciflow/all, ciflow/ios, ciflow/macos 🚫 skipped
ios-12-5-1-x86-64 ciflow/all, ciflow/ios, ciflow/macos 🚫 skipped
ios-12-5-1-x86-64-coreml ciflow/all, ciflow/ios, ciflow/macos 🚫 skipped
ios-12-5-1-x86-64-full-jit ciflow/all, ciflow/ios, ciflow/macos 🚫 skipped
libtorch-linux-xenial-cuda10.2-py3.6-gcc7 ciflow/all, ciflow/cuda, ciflow/libtorch, ciflow/linux 🚫 skipped
libtorch-linux-xenial-cuda11.3-py3.6-gcc7 ciflow/all, ciflow/cuda, ciflow/libtorch, ciflow/linux 🚫 skipped
linux-bionic-cuda10.2-py3.9-gcc7 ciflow/all, ciflow/cuda, ciflow/linux, ciflow/slow 🚫 skipped
linux-docs-push ciflow/all, ciflow/cpu, ciflow/linux, ciflow/scheduled 🚫 skipped
macos-10-15-py3-arm64 ciflow/all, ciflow/macos 🚫 skipped
macos-10-15-py3-lite-interpreter-x86-64 ciflow/all, ciflow/macos 🚫 skipped
macos-11-py3-x86-64 ciflow/all, ciflow/macos 🚫 skipped
parallelnative-linux-xenial-py3.6-gcc5.4 ciflow/all, ciflow/cpu, ciflow/linux 🚫 skipped
periodic-libtorch-linux-bionic-cuda11.5-py3.6-gcc7 ciflow/all, ciflow/cuda, ciflow/libtorch, ciflow/linux, ciflow/scheduled 🚫 skipped
periodic-libtorch-linux-xenial-cuda11.1-py3.6-gcc7 ciflow/all, ciflow/cuda, ciflow/libtorch, ciflow/linux, ciflow/scheduled 🚫 skipped
periodic-linux-bionic-cuda11.5-py3.6-gcc7 ciflow/all, ciflow/cuda, ciflow/linux, ciflow/scheduled 🚫 skipped
periodic-linux-xenial-cuda10.2-py3-gcc7-slow-gradcheck ciflow/all, ciflow/cuda, ciflow/linux, ciflow/scheduled, ciflow/slow, ciflow/slow-gradcheck 🚫 skipped
periodic-linux-xenial-cuda11.1-py3.6-gcc7-debug ciflow/all, ciflow/cuda, ciflow/linux, ciflow/scheduled 🚫 skipped
periodic-win-vs2019-cuda11.1-py3 ciflow/all, ciflow/cuda, ciflow/scheduled, ciflow/win 🚫 skipped

You can add a comment to the PR and tag @pytorchbot with the following commands:
# ciflow rerun, "ciflow/default" will always be added automatically
@pytorchbot ciflow rerun

# ciflow rerun with additional labels "-l <ciflow/label_name>", which is equivalent to adding these labels manually and trigger the rerun
@pytorchbot ciflow rerun -l ciflow/scheduled -l ciflow/slow

For more information, please take a look at the CI Flow Wiki.

@facebook-github-bot
Copy link
Contributor

Hi @awaelchli!

Thank you for your pull request and welcome to our community.

Action Required

In order to merge any pull request (code, docs, etc.), we require contributors to sign our Contributor License Agreement, and we don't seem to have one on file for you.

Process

In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (eg your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA.

Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged with CLA signed. The tagging process may take up to 1 hour after signing. Please give it that time before contacting us about it.

If you have received this in error or have any questions, please contact us at cla@fb.com. Thanks!

@facebook-github-bot
Copy link
Contributor

facebook-github-bot commented Nov 23, 2021

🔗 Helpful links

💊 CI failures summary and remediations

As of commit ec204a6 (more details on the Dr. CI page):


💚 💚 Looks good so far! There are no failures yet. 💚 💚


This comment was automatically generated by Dr. CI (expand for details).

Please report bugs/suggestions to the (internal) Dr. CI Users group.

Click here to manually regenerate this comment.

@facebook-github-bot facebook-github-bot added the oncall: distributed Add this issue/PR to distributed oncall triage queue label Nov 23, 2021
@facebook-github-bot
Copy link
Contributor

Thank you for signing our Contributor License Agreement. We can now accept your code for this (and any) Facebook open source project. Thanks!

Copy link
Contributor

@pritamdamania87 pritamdamania87 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for looking into this issue!

if submodule_name:
_recurse_update_module(submodule, state_dict, key + '.')

for attr_name, attr in module.__dict__.items():
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't this be submodule.__dict__.items()? I'm wondering why this test didn't catch this issue: https://github.com/pytorch/pytorch/blob/master/test/distributed/_sharded_tensor/test_sharded_tensor.py#L965? Can we enhance that test so that it does catch this issue (maybe we need multiple levels of nesting)?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You are right, it should be submodule.
I'm not yet sure why the test had passed, but I'll give it a try.

Copy link
Contributor Author

@awaelchli awaelchli Nov 25, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@pritamdamania87 I'm a little bit stuck here.
The test will pass no matter what. I'm running

python test/distributed/_sharded_tensor/test_sharded_tensor.py TestShardedTensorChunked.test_state_dict

I even added an assert False directly inside the test case and it still passed. The output is always:

.
----------------------------------------------------------------------
Ran 1 test in 1.523s

OK

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is indeed weird, can you try pytest test/distributed/_sharded_tensor/test_sharded_tensor.py -k test_state_dict?

cc @janeyx99 I was wondering if there might be something wrong here with our test infra?

Copy link
Contributor Author

@awaelchli awaelchli Nov 30, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the suggestion @pritamdamania87
The same phenomenon happens with your command. Output looks normal, as if all tests passed.
However, I found out that it is because of the @with_comms decorator. This one makes the test pass without running it. By removing it, my self.assertEqual(1, 2) triggers as expected.

Any intuition how this decorator was meant to be used?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I signed the CLA when I opened the PR against master. The PR is now pointing to wanchaol/192/head, this must have reset the CLA bot. I submitted CLA once again a few hours ago, but the bot hasn't update the message.
Should I point back to master?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@awaelchli Ah sorry for the confusion, I meant just checking locally with #69493 patched to see if the unit tests work as expected. I think we should wait for #69493 to land, rebase this PR on top of that on master and then have this PR against master :)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Regarding on some test failures, this issue unfortunately hided some silent failures that we couldn't capture earlier, just updated the PR and should fixed all issues.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

regarding CLA, as long as the state_dict related tests passed, I think you can re-point to master, i will try to merge #69493 as soon as i can. I do notice that test_load_state_dict_errors failed both on master and on my fix (as it's without with_comms wrapper, looks like a RPC related error). We can follow up on this separately I think.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds good. Thanks! Yes the state dict tests passed when rebased on top of #69493

@awaelchli awaelchli marked this pull request as ready for review November 27, 2021 04:28
@soulitzer soulitzer added the triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module label Nov 29, 2021
@awaelchli awaelchli force-pushed the bugfix/sharded-dict-update branch 2 times, most recently from ca2e618 to 7c7698c Compare December 6, 2021 23:25
@awaelchli awaelchli changed the base branch from master to gh/wanchaol/192/head December 6, 2021 23:26
@awaelchli awaelchli force-pushed the bugfix/sharded-dict-update branch from 7c7698c to f0607d3 Compare December 7, 2021 01:49
Copy link
Collaborator

@wanchaol wanchaol left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lgtm, have one small nit

@@ -360,26 +360,18 @@ def pre_load_state_dict_hook(module, state_dict, prefix, local_metadata, strict,
_recurse_update_module(module, state_dict, prefix)

def _recurse_update_module(module, state_dict, prefix):
for attr_name, attr in module.__dict__.items():
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: given that we don't need the recursion anymore, shall we remove these two functions and put the main logic in state_dict_hook and pre_load_state_dict_hook?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds good to me. @pritamdamania87 do you agree?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@awaelchli Yes this makes sense

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

wanchaol and others added 3 commits December 8, 2021 22:52
When added `with_comms` decorator with arguments, we added an `with_comms_decorator` inner function, `with_comms()` will refer to a function object, the added parentheses was necessary to use in test cases.

This PR fixes the `with_comms` wrapper behavior, to allow we both specify with/without arguments in test cases:
```
@with_comms
def test_case:
    ...
```
or
```
@with_comms(backend="gloo")
def test_case:
    ...
```

Differential Revision: [D32897555](https://our.internmc.facebook.com/intern/diff/D32897555/)

[ghstack-poisoned]
@awaelchli awaelchli force-pushed the bugfix/sharded-dict-update branch from f0607d3 to ba89913 Compare December 9, 2021 03:54
@awaelchli awaelchli changed the base branch from gh/wanchaol/192/head to master December 9, 2021 03:54
@albanD albanD removed request for albanD and bowangbj December 10, 2021 16:17
@facebook-github-bot
Copy link
Contributor

@wanchaol has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

Copy link
Collaborator

@wanchaol wanchaol left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lgtm, thanks for fixing this!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
cla signed oncall: distributed Add this issue/PR to distributed oncall triage queue open source triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Recursive update in ShardedTensor.state_dict hook is inefficient
6 participants