
Make DPMultiheadAttention drop-in compatible with nn.MultiheadAttention #529

Status: Closed · wants to merge 1 commit


Conversation


@Wei-1 (Author) commented Oct 25, 2022

Summary: This PR aims to resolve #123 on GitHub by adding a renaming mechanism to match the state_dict structure of nn.MultiheadAttention.

Differential Revision: D40671870
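
For context, a minimal sketch of the kind of renaming involved (not the PR's actual implementation): DPMultiheadAttention keeps separate query/key/value projection layers, so producing nn.MultiheadAttention-style keys mostly means renaming and, for the fused case, concatenating the three projections into in_proj_weight. The qlinear/klinear/vlinear key names below are assumptions for illustration only.

    import torch

    def to_vanilla_keys(dp_state_dict):
        """Sketch only: map DP-style keys to nn.MultiheadAttention-style keys.

        Assumes separate projections stored under qlinear/klinear/vlinear
        (illustrative names, not necessarily the module's real attributes).
        """
        sd = dict(dp_state_dict)
        out = {}
        # Fuse the three separate projection weights into one in_proj_weight.
        out["in_proj_weight"] = torch.cat(
            [sd.pop("qlinear.weight"), sd.pop("klinear.weight"), sd.pop("vlinear.weight")]
        )
        if "qlinear.bias" in sd:
            out["in_proj_bias"] = torch.cat(
                [sd.pop("qlinear.bias"), sd.pop("klinear.bias"), sd.pop("vlinear.bias")]
            )
        # Remaining keys (e.g. out_proj.weight, out_proj.bias) already match.
        out.update(sd)
        return out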

@facebook-github-bot added the CLA Signed and fb-exported labels on Oct 25, 2022
@facebook-github-bot (Contributor)

This pull request was exported from Phabricator. Differential Revision: D40671870


@Wei-1 (Author) commented Oct 25, 2022

You can test the edited part by running pytest within the opacus folder.
And you should see the following expected result:

============= 162 passed, 41 skipped, 4411 warnings in 243.18s (0:04:03) =============

@facebook-github-bot (Contributor)

This pull request was exported from Phabricator. Differential Revision: D40671870

@facebook-github-bot (Contributor)

This pull request was exported from Phabricator. Differential Revision: D40671870


@ffuuugor (Contributor) left a comment

Thanks for the great work! Getting this right is intricate, and your efforts are hugely appreciated.

@alexandresablayrolles @karthikprasad What are your thoughts on having state_dict not represent the actual physical structure of the model? Loading is not a problem, since DPMultiheadAttention can load both DP and vanilla state dictionaries.

On the one hand, it makes it easy to interoperate with vanilla nn.MultiheadAttention - you can train a model with DP, save its state_dict, and then load this dict into a non-DP model.

On the other hand, I can potentially see issues with an unexpected state_dict: keys not matching parameter names. That said, I don't see any immediate problems, but I might be missing something.
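
To make the interoperability point concrete, here is a minimal sketch of the round trip described above, assuming the renamed state_dict exactly matches nn.MultiheadAttention's keys (the constructor arguments are illustrative):

    import torch.nn as nn
    from opacus.layers import DPMultiheadAttention

    embed_dim, num_heads = 16, 4
    dp_attn = DPMultiheadAttention(embed_dim, num_heads)
    # ... DP training would happen here ...
    state = dp_attn.state_dict()  # with this PR, keys follow nn.MultiheadAttention

    attn = nn.MultiheadAttention(embed_dim, num_heads)
    attn.load_state_dict(state)  # should succeed once the key names match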

Comment on lines +164 to +259
if "in_proj_bias" in dp_attn.state_dict():
dp_attn._register_state_dict_hook(remove_in_proj_bias_hook)
self.assertFalse("in_proj_bias" in dp_attn.state_dict())
Contributor:

I'm not sure I understand the point of this assertion

Author:

The state_dict function is copied from the original nn.Module, and one capability it has is to apply hooks that modify the state_dict. This test validates that hooks still work with the new state_dict function in our code.
The reason I didn't use the hook mechanism as the core method to modify the state_dict is that I think it might cause confusion when people try to add, modify, or remove other hooks in their network.
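
For reference, a minimal sketch of what a hook like remove_in_proj_bias_hook in the test above could look like, based on PyTorch's state_dict hook convention (module, state_dict, prefix, local_metadata); the exact hook used in the test may differ:

    def remove_in_proj_bias_hook(module, state_dict, prefix, local_metadata):
        # State_dict hooks run after the dict is built and may modify it in
        # place (or return a replacement). Here we simply drop one key, as in
        # the test above.
        state_dict.pop(prefix + "in_proj_bias", None)
        return state_dict

    # Registered via the same private API used in the test:
    # dp_attn._register_state_dict_hook(remove_in_proj_bias_hook)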

Contributor:

Oh, I see, it is to check if state dict hooks still work, gotcha. Can you maybe add a comment explaining this?

@@ -126,3 +130,37 @@ def test_attn(
need_weights=True,
attn_mask=None,
)

attn.load_state_dict(dp_attn.state_dict())
Contributor:

I would put this in a separate test with independently initialized modules. First, this would make it conceptually clearer and easier to understand potential failures. Second, attn and dp_attn have been initialized with the same weights (line 96), which defeats the purpose of this test.
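
A rough sketch of the kind of separate test being suggested, with independently initialized modules (names, shapes, and tolerances are illustrative, not the PR's actual test; it assumes the test module's existing imports of torch, torch.nn as nn, and DPMultiheadAttention):

    def test_load_dp_state_dict_into_vanilla(self):
        embed_dim, num_heads = 16, 4
        dp_attn = DPMultiheadAttention(embed_dim, num_heads)
        attn = nn.MultiheadAttention(embed_dim, num_heads)  # independently initialized

        attn.load_state_dict(dp_attn.state_dict())

        q = k = v = torch.randn(5, 1, embed_dim)  # (seq_len, batch, embed_dim)
        out_dp, _ = dp_attn(q, k, v)
        out_nn, _ = attn(q, k, v)
        self.assertTrue(torch.allclose(out_dp, out_nn, atol=1e-6))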

Author:

Makes sense! I will make the modification tomorrow!

Comment on lines 394 to 395
def named_parameters(self, prefix: str = '', recurse: bool = True):
return self.state_dict(prefix = prefix).items()
Contributor:

Tbh, I would not change named_parameters - having this method not match the actual structure of the model would be very confusing and would lead to unexpected behaviour (see the sketch after this list):

  • named_parameters and parameters should return the same set of parameters
  • sometimes you need to modify the output of these methods, and you want them to point to the actual objects used in the model
  • it would mess with calling model attributes directly
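
For illustration only, a small sketch of the mismatch the first bullet describes, assuming a dp_attn instance with the proposed named_parameters override in place (this is not code from the PR, and nn refers to torch.nn):

    # With the override, named_parameters() yields entries built by state_dict()
    # (renamed, and possibly concatenated), not the module's actual Parameter
    # objects, so the two views below can disagree.
    override_names = sorted(name for name, _ in dp_attn.named_parameters())
    actual_names = sorted(name for name, _ in nn.Module.named_parameters(dp_attn))
    print(override_names)  # keys shaped like nn.MultiheadAttention's state_dict
    print(actual_names)    # the module's real parameter names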

Author:

I think that is a valid concern.
Currently, in opacus/tests/dp_layers/common.py (line 220), we check whether the two named_parameters match.

        nn_params = dict(nn_module.named_parameters())
        dp_params = dict(dp_module.named_parameters())

How about we just change them to:

        nn_params = nn_module.state_dict()
        dp_params = dp_module.state_dict()

since getting named_parameters and then converting them back to a dict is a bit odd as well.

Contributor:

Yeah, that makes sense.

@Wei-1 (Author) commented Oct 25, 2022

On the other hand, I can potentially see issues with an unexpected state_dict: keys not matching parameter names. That said, I don't see any immediate problems, but I might be missing something.

As we try to cover the parameter naming logic of both nn.MultiheadAttention and DPMultiheadAttention, I think the major problem might be maintenance. Since the entire transformation logic is rule-based, things will most likely break when the naming scheme of nn.MultiheadAttention or DPMultiheadAttention changes.

@facebook-github-bot (Contributor)

This pull request was exported from Phabricator. Differential Revision: D40671870

@Wei-1 (Author) commented Oct 26, 2022

@ffuuugor Changes have been made to address the concerns. Please let me know if this makes sense to you!

@Wei-1 requested a review from ffuuugor on October 27, 2022
@facebook-github-bot (Contributor)

This pull request was exported from Phabricator. Differential Revision: D40671870


@ffuuugor (Contributor)

Thanks for addressing the comments!
I believe this PR is close to landing, but we need to sort out one thing first. Due to a bug in how CircleCI works with Phabricator, our main testing pipeline is not being triggered on this PR.

I can see that it won't pass the linter check. Please refer to our Contributor's guide and run the isort/black/flake8 commands to check that your code is formatted properly.

@facebook-github-bot (Contributor)

This pull request was exported from Phabricator. Differential Revision: D40671870

@Wei-1 (Author) commented Oct 28, 2022

Thanks, @ffuuugor! An update has been made to address lint consistency with black/isort/flake8.

@ffuuugor (Contributor) commented Nov 1, 2022

Hey,
thanks for taking care of this.
One last thing: sometimes isort gives different recommendations depending on the version. I have mine set up exactly as CircleCI does, and it gives the following:

--- a/opacus/layers/dp_multihead_attention.py
+++ b/opacus/layers/dp_multihead_attention.py
@@ -14,14 +14,13 @@
 # limitations under the License.

 import warnings
+from collections import OrderedDict

 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from torch.nn.parameter import Parameter

-from collections import OrderedDict
-

 class SequenceBias(nn.Module):

Can you please make the change to keep the linter happy?

And I'm really sorry for the back and forth on this. Tests not triggering for fbcode-exported PRs is painful, and we're investigating.

Make DPMultiheadAttention drop-in compatible with nn.MultiheadAttention (#529)

Summary:
Pull Request resolved: #529

This PR aims to resolve #123 on GitHub by adding a renaming mechanism to match the `state_dict` structure of `nn.MultiheadAttention`.

GitHub Issue Link: #123

Differential Revision: D40671870

fbshipit-source-id: b1e2a4526bde8c53fc01e30c65a33e8615e6ecce
@facebook-github-bot (Contributor)

This pull request was exported from Phabricator. Differential Revision: D40671870

@ffuuugor (Contributor) left a comment

Thank you!

@Wei-1 (Author) commented Nov 1, 2022

I just pushed a new version to address this issue! Please let me know if everything is good!

Labels
CLA Signed, fb-exported
Projects
None yet
Development

Successfully merging this pull request may close these issues.

DPMultiheadAttention is not drop-in compatible with nn.MultiheadAttention
3 participants