Make DPMultiheadAttention drop-in compatible with nn.MultiheadAttention #529
Conversation
This pull request was exported from Phabricator. Differential Revision: D40671870
You can test the edited part by running
Thanks for the great work! Making this work is very intricate, and your efforts are hugely appreciated.

@alexandresablayrolles @karthikprasad What are your thoughts on having `state_dict` not represent the actual physical structure of the model? Loading is not a problem, since `DPMultiheadAttention` can load both DP and vanilla state dictionaries.

On the one hand, it makes it easy to interoperate with vanilla `nn.MultiheadAttention`: you can train a model with DP, save its `state_dict`, and then load this dict into a non-DP model.

On the other hand, I can potentially see issues with an unexpected `state_dict`, i.e. keys not matching parameter names. That said, I don't see any immediate problems, but I might be missing something.
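For illustration, a minimal sketch of the round trip described above, assuming the drop-in-compatible `state_dict` layout this PR introduces (dimensions are arbitrary):

```python
import torch.nn as nn
from opacus.layers import DPMultiheadAttention

embed_dim, num_heads = 16, 4

# The DP variant exports a state_dict using the vanilla key layout...
dp_attn = DPMultiheadAttention(embed_dim, num_heads)
attn = nn.MultiheadAttention(embed_dim, num_heads)

# ...so DP-trained weights load straight into the non-DP module,
attn.load_state_dict(dp_attn.state_dict())

# and the DP module can load a vanilla state_dict as well.
dp_attn.load_state_dict(attn.state_dict())
```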
if "in_proj_bias" in dp_attn.state_dict(): | ||
dp_attn._register_state_dict_hook(remove_in_proj_bias_hook) | ||
self.assertFalse("in_proj_bias" in dp_attn.state_dict()) |
I'm not sure I understand the point of this assertion
The `state_dict` function is copied from the original `nn.Module`, and one capability it has is applying hooks that modify the `state_dict`. This test is basically validating that hooks also work with the new `state_dict` function in our code.

The reason I didn't directly use the hook mechanism as the core method for modifying `state_dict` is that I think it might cause confusion when people are trying to add, modify, or remove other hooks in their network.
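As a sketch of what that looks like in practice (the hook name and the removed key mirror the test above; the stand-in module is an assumption):

```python
import torch.nn as nn

def remove_in_proj_bias_hook(module, state_dict, prefix, local_metadata):
    # A state_dict hook receives the assembled state_dict and may modify
    # it (or return a replacement) before it is handed back to the caller.
    state_dict.pop(prefix + "in_proj_bias", None)
    return state_dict

attn = nn.MultiheadAttention(16, 4)
attn._register_state_dict_hook(remove_in_proj_bias_hook)
assert "in_proj_bias" not in attn.state_dict()
```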
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oh, I see, it is to check if state dict hooks still work, gotcha. Can you maybe add a comment explaining this?
```diff
@@ -126,3 +130,37 @@ def test_attn(
            need_weights=True,
            attn_mask=None,
        )

        attn.load_state_dict(dp_attn.state_dict())
```
I would put this in a separate test with independently initialized modules. First, this would make it conceptually clearer and make potential failures easier to understand. Second, `attn` and `dp_attn` have been initialized with the same weights (line 96), which defeats the purpose of this test.
Makes sense! I will make the modification tomorrow!
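A sketch of what such a separate test could look like (names and tolerances are assumptions; the key point is that the two modules start from independent random initializations):

```python
import torch
import torch.nn as nn
from opacus.layers import DPMultiheadAttention

def test_load_state_dict_into_vanilla():
    embed_dim, num_heads = 16, 4
    # Independently initialized modules -- no weight copying up front.
    dp_attn = DPMultiheadAttention(embed_dim, num_heads)
    attn = nn.MultiheadAttention(embed_dim, num_heads)

    attn.load_state_dict(dp_attn.state_dict())

    # After loading, both modules should produce matching outputs.
    q = k = v = torch.randn(5, 1, embed_dim)
    out_dp, _ = dp_attn(q, k, v)
    out_nn, _ = attn(q, k, v)
    assert torch.allclose(out_dp, out_nn, atol=1e-6)
```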
```python
def named_parameters(self, prefix: str = '', recurse: bool = True):
    return self.state_dict(prefix=prefix).items()
```
Tbh, I would not change `named_parameters` - having this method not match the actual structure of the model would be very confusing and would lead to unexpected behaviour:

- `named_parameters` and `parameters` should return the same set of parameters (see the sketch below)
- sometimes you need to modify the output of these methods, and you want them to point to the actual objects used in the model
- it would mess with calling model attributes directly
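For instance, the first bullet describes an invariant that generic PyTorch code tends to rely on (an illustrative check, not code from this PR):

```python
import torch.nn as nn

module = nn.MultiheadAttention(16, 4)  # any nn.Module works here

# named_parameters() and parameters() should yield the very same tensor
# objects, just with names attached; overriding one without the other
# (or detaching them from the real submodules) breaks this invariant.
named = {id(p) for _, p in module.named_parameters()}
unnamed = {id(p) for p in module.parameters()}
assert named == unnamed
```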
I think that is a valid concern.

Currently, in `opacus/tests/dp_layers/common.py` line 220, we check whether the two `named_parameters` results match:

```python
nn_params = dict(nn_module.named_parameters())
dp_params = dict(dp_module.named_parameters())
```

How about we just change them into:

```python
nn_params = nn_module.state_dict()
dp_params = dp_module.state_dict()
```

since getting `named_parameters` and then converting it back to a `dict` is a bit weird as well.
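If the check should also compare values, not just key sets, a hedged sketch of how that comparison might look (the helper name is an assumption):

```python
import torch

def assert_state_dicts_match(nn_params, dp_params, atol=1e-8):
    # Same key set first, then element-wise equality per tensor.
    assert nn_params.keys() == dp_params.keys()
    for key in nn_params:
        assert torch.allclose(nn_params[key], dp_params[key], atol=atol), key
```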
Yeah, that makes sense
As we try to cover the parameter naming logic in
@ffuuugor Changes have been made to address the concerns. Please let me know if this makes sense to you!
Thanks for addressing the comments! I can see that it won't pass the linter check. Please refer to our Contributor's guide and run the isort/black/flake8 commands to check that your code is formatted properly.
Thanks, @ffuuugor! An update has been made to address lint consistency with black/isort/flake8.
Hey, can you please make this change to keep the linter happy?

```diff
--- a/opacus/layers/dp_multihead_attention.py
+++ b/opacus/layers/dp_multihead_attention.py
@@ -14,14 +14,13 @@
 # limitations under the License.

 import warnings
+from collections import OrderedDict

 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from torch.nn.parameter import Parameter

-from collections import OrderedDict
-

 class SequenceBias(nn.Module):
```

And I'm really sorry for the back and forth on this. Tests not triggering for fbcode-exported PRs is painful and we're investigating.
Make DPMultiheadAttention drop-in compatible with nn.MultiheadAttention (#529)

Summary:
Pull Request resolved: #529

This PR aims to resolve #123 on GitHub by adding a renaming mechanism to match the `state_dict` structure of `nn.MultiheadAttention`.

GitHub Issue Link: #123

Differential Revision: D40671870

fbshipit-source-id: b1e2a4526bde8c53fc01e30c65a33e8615e6ecce
Thank you!
I just pushed a new version to address this issue! Please let me know if everything is good!
Summary: This PR aims to resolve #123 on GitHub by adding a renaming mechanism to match the `state_dict` structure of `nn.MultiheadAttention`.

Differential Revision: D40671870
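For readers curious how such a renaming mechanism can work, here is a minimal hypothetical sketch; the internal q/k/v projection key names and the concatenation order are assumptions, not the actual Opacus implementation:

```python
import torch

def dp_to_vanilla_state_dict(dp_sd):
    # Map separate q/k/v projection keys (hypothetical internal names) onto
    # the concatenated in_proj layout that nn.MultiheadAttention expects.
    sd = dict(dp_sd)
    sd["in_proj_weight"] = torch.cat(
        [sd.pop("qlinear.weight"), sd.pop("klinear.weight"), sd.pop("vlinear.weight")],
        dim=0,
    )
    if "qlinear.bias" in sd:
        sd["in_proj_bias"] = torch.cat(
            [sd.pop("qlinear.bias"), sd.pop("klinear.bias"), sd.pop("vlinear.bias")],
            dim=0,
        )
    return sd
```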