Explicit attribute setting for pruning and weight_norm upon reparam removal #34170
Conversation
Are these properly detected as parameters when you do `.parameters()` after doing this?
The parameters do show up correctly. Take the example in the test I added for pruning; the same thing holds for weight norm. Take the test case here (https://github.com/pytorch/pytorch/pull/34170/files#diff-d89baec73022f5f511c5beb5ce6498dfR2468): after applying weight norm to the LSTM param and then removing it, the parameter is still returned by `.parameters()`. This, indeed, had always been the case. The only difference is how this affects the updating of `_flat_weights`. Is your requested change to include this explicitly in the tests?
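For illustration, a minimal sketch of that check (this is not the test from the PR; the parameter name `weight_hh_l0` and the older `torch.nn.utils.weight_norm` interface are just used as a concrete example):

```python
import torch
import torch.nn as nn

lstm = nn.LSTM(input_size=3, hidden_size=3)

# Apply and then remove the weight_norm reparametrization on one LSTM weight.
lstm = torch.nn.utils.weight_norm(lstm, name="weight_hh_l0")
lstm = torch.nn.utils.remove_weight_norm(lstm, name="weight_hh_l0")

# The weight is registered as a Parameter again and shows up in .parameters().
assert isinstance(lstm.weight_hh_l0, nn.Parameter)
assert any(p is lstm.weight_hh_l0 for p in lstm.parameters())
```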
Another possible alternative is to override `register_parameter`.
My understanding is as follows: modules that are not rnn don't need this change, so it feels like the rnn-specific handling belongs in the rnn code rather than in the generic pruning and weight norm utilities. Not sure what I'm missing here? Note in your tests that you traverse …
@ngimel Probably true! The only thing is that I'm not familiar enough with the design decisions that went into the RNN code, so I'm not sure if there is any particular reason why things are exactly the way they are, and what else relies on them being exactly this way. On the other hand, I am pretty confident that changing pruning and weight norm won't have negative repercussions on other parts of the codebase. Happy with either, though :)
@albanD we do put the parameter back into `module._parameters`. In case I'm not understanding correctly, do you have an explicit case you'd like me to check? You have a fair point that the tests don't touch this aspect but only check that …
Perhaps more clearly... `module._parameters[self._tensor_name] = None` […] and `setattr(module, self._tensor_name, orig)` [as currently suggested in this PR] have exactly the same effect. What changes is the effect they have on how `_flat_weights` gets updated.
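A minimal sketch of the distinction being made here; it pokes at the private `_flat_weights` attribute of the RNN modules (an implementation detail that may change across versions), and `new_w` is just an illustrative replacement parameter:

```python
import torch
import torch.nn as nn

lstm = nn.LSTM(input_size=3, hidden_size=3)
new_w = nn.Parameter(torch.randn_like(lstm.weight_hh_l0))

# Writing into the _parameters dict directly does update .parameters() ...
lstm._parameters["weight_hh_l0"] = new_w
assert any(p is new_w for p in lstm.parameters())
# ... but it bypasses the module's attribute-setting machinery, so the RNN's
# cached _flat_weights list is not refreshed.
assert not any(w is new_w for w in lstm._flat_weights)

# setattr ends up registering the Parameter as well, and the RNN code also
# refreshes the corresponding _flat_weights entry along the way.
setattr(lstm, "weight_hh_l0", new_w)
assert any(p is new_w for p in lstm.parameters())
assert any(w is new_w for w in lstm._flat_weights)
```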
So, explicitly adding a test for this will render this discussion moot, and thus is a reasonable thing to do :-) However, consider overriding `register_parameter` in RNN.
Ok, after some testing, I thought that … But given that this is rnn-specific logic (in the sense that it's only a trick to make rnn work), I would agree with Natalia that it might be better to move this to the rnn code by changing `register_parameter` there.
I looked into it and I think you're right: I think we don't need to explicitly create the placeholder in … I also added the recommended test.

Other than that, changes to rnn are something that would benefit from an rnn maintainer thinking them through properly, as a quick pass over the rnn file shows that such a change may have a large blast radius. For example, the code in `_apply`:

```python
# Resets _flat_weights
# Note: be v. careful before removing this, as 3rd party device types
# likely rely on this behavior to properly .to() modules like LSTM.
self._flat_weights = [(lambda wn: getattr(self, wn) if hasattr(self, wn) else None)(wn) for wn in self._flat_weights_names]
```
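For what it's worth, a tiny sketch of what that refresh means in practice (again relying on the private `_flat_weights` / `_flat_weights_names` attributes, so this is version-dependent):

```python
import torch.nn as nn

lstm = nn.LSTM(input_size=3, hidden_size=3)

# Anything that goes through Module._apply (.to(), .float(), .cuda(), ...)
# hits the snippet above and rebuilds _flat_weights from the current
# attributes, falling back to None for any name that is currently missing.
lstm = lstm.float()
assert all(
    w is (getattr(lstm, name) if hasattr(lstm, name) else None)
    for name, w in zip(lstm._flat_weights_names, lstm._flat_weights)
)
```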
One thing that stands out is that basically now this PR simply amounts to substituting the direct assignment into `module._parameters` with a call to `setattr`. So perhaps what would need to change, instead, is `register_parameter` in rnn.py, e.g. by having it go through `setattr` so that `_flat_weights` gets updated too.
But keep in mind that `setattr()` itself calls `register_parameter`. So you would go into infinite recursion :D
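Purely to illustrate the alternative being debated (this is a hypothetical sketch, not what the PR implements and not PyTorch's actual `RNNBase`): an overridden `register_parameter` could refresh the cached `_flat_weights` list directly, rather than calling `setattr`, which sidesteps the recursion:

```python
import torch
import torch.nn as nn

class ToyRNN(nn.Module):
    """Hypothetical sketch: just enough _flat_weights bookkeeping to
    illustrate refreshing the cache from register_parameter."""

    def __init__(self):
        super().__init__()
        self._flat_weights_names = ["weight_hh_l0"]
        self._flat_weights = [None]
        self.weight_hh_l0 = nn.Parameter(torch.randn(3, 3))

    def register_parameter(self, name, param):
        # Calling setattr() here would re-enter register_parameter (the
        # infinite recursion noted above), so update the cache directly.
        super().register_parameter(name, param)
        names = getattr(self, "_flat_weights_names", None)
        if names is not None and name in names:
            self._flat_weights[names.index(name)] = param

rnn = ToyRNN()
assert rnn._flat_weights[0] is rnn.weight_hh_l0

# Re-registering the weight (as a reparametrization's remove() might do)
# keeps the cached list pointing at the live Parameter.
new_w = nn.Parameter(torch.randn(3, 3))
rnn.register_parameter("weight_hh_l0", new_w)
assert rnn._flat_weights[0] is new_w
```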
Thanks for the update. Looks good to me.
@mickypaganini is landing this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
@mickypaganini merged this pull request in d37a486.
To address one of the problems with RNNs that emerged in #33618, I modified the `remove` methods in `torch.nn.utils.prune` and `torch.nn.utils.weight_norm` to make an explicit call to `setattr`, which, in `rnn.py`, directly modifies `_flat_weights` (https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/rnn.py#L96) to include the new element. This is important so that `_flat_weights` can reflect the presence of the `Parameter` after the (pruning or weight norm) reparametrization is removed. Without this, the weight in `_flat_weights` would remain a plain tensor, as originally set by the reparametrization. Simple testing is added, which depends on the current naming scheme for the LSTM module.
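To make the summary concrete, a rough end-to-end sketch of the behavior the PR is after (checking the private `_flat_weights` attribute, and using `l1_unstructured` pruning on `weight_hh_l0` purely as an example; the PR's own test may differ):

```python
import torch
import torch.nn as nn
from torch.nn.utils import prune

lstm = nn.LSTM(input_size=3, hidden_size=3)

# Reparametrize one LSTM weight with pruning, then remove the reparametrization.
prune.l1_unstructured(lstm, name="weight_hh_l0", amount=0.5)
prune.remove(lstm, "weight_hh_l0")

# After removal, the weight should be a Parameter again, both as an attribute
# and inside the RNN's cached _flat_weights list, rather than a plain tensor
# left over from the reparametrization.
assert isinstance(lstm.weight_hh_l0, nn.Parameter)
assert any(w is lstm.weight_hh_l0 for w in lstm._flat_weights)
```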