
Explicit attribute setting for pruning and weight_norm upon reparam removal #34170


Closed
wants to merge 3 commits

Conversation

mickypaganini
Contributor

To address one of the problems with RNNs that emerged in #33618, I modified the remove methods in torch.nn.utils.prune and torch.nn.utils.weight_norm to make an explicit call to setattr, which, in rnn.py, directly modifies _flat_weights (https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/rnn.py#L96) to include the new element.

This is important so that _flat_weights can reflect the presence of the Parameter after the (pruning or weight norm) reparametrization is removed. Without this, the weight in _flat_weights would remain a tensor, as originally set by the reparametrization.

A simple test is added; it depends on the current parameter naming scheme of the LSTM module.
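For context, here is a minimal, paraphrased sketch of the idea (not the exact diff in this PR; the helper name is made up for illustration): the remove path re-attaches the restored weight through setattr rather than by writing into module._parameters directly, so module types that hook attribute assignment, such as the RNN modules with their _flat_weights bookkeeping, see the restored Parameter.

import torch
import torch.nn as nn

# Paraphrased sketch (hypothetical helper, not the exact code in this PR):
# re-attach the restored weight as a Parameter via setattr so that modules
# which hook attribute assignment (e.g. LSTM and its _flat_weights) pick it up.
def restore_parameter(module: nn.Module, name: str, value: torch.Tensor) -> None:
    # drop the plain tensor left behind by the reparametrization, if present
    if hasattr(module, name) and name not in module._parameters:
        delattr(module, name)
    # key point: setattr routes through Module.__setattr__ / register_parameter,
    # unlike a bare `module._parameters[name] = ...` assignment
    setattr(module, name, nn.Parameter(value.detach().clone()))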

@mickypaganini mickypaganini requested a review from apaszke as a code owner March 3, 2020 22:57
@dr-ci

dr-ci bot commented Mar 3, 2020

💊 Build failures summary and remediations

As of commit 16ac08b (more details on the Dr. CI page):


  • 1/1 failures possibly* introduced in this PR
    • 1/1 non-CircleCI failure(s)

ci.pytorch.org: 1 failed



Collaborator

@albanD albanD left a comment


Are these properly detected as parameters when you do .parameters() after doing this??

@mickypaganini
Contributor Author

mickypaganini commented Apr 27, 2020

The parameters do show up correctly.

Take the example in the test I added to test_nn.py for this change (https://github.com/pytorch/pytorch/pull/34170/files#diff-d89baec73022f5f511c5beb5ce6498dfR2449). First, we prune the 'weight_ih_l0' parameter in an LSTM. The parameters at this stage, after pruning, should be ['weight_hh_l0', 'bias_ih_l0', 'bias_hh_l0', 'weight_ih_l0_orig'], and that is indeed the case [note the last parameter, created by the reparametrization]. Then we call remove, which is the subject of the change in this PR. Your question, in this case, is whether 'weight_ih_l0' appears back in the list of parameters. The answer is yes: .named_parameters() will contain the following keys (and corresponding parameters): ['weight_hh_l0', 'bias_ih_l0', 'bias_hh_l0', 'weight_ih_l0'].

Same thing for weight norm. Take the test case here (https://github.com/pytorch/pytorch/pull/34170/files#diff-d89baec73022f5f511c5beb5ce6498dfR2468). After applying weight norm to the LSTM param 'weight_ih_l0', the new .named_parameters() will include ['weight_hh_l0', 'bias_ih_l0', 'bias_hh_l0', 'weight_ih_l0_g', 'weight_ih_l0_v']. After calling .remove_weight_norm(), not only is l._flat_weights now finally correct, but also .named_parameters() remains correct and contains the following params: ['weight_hh_l0', 'bias_ih_l0', 'bias_hh_l0', 'weight_ih_l0'].

This, indeed, had always been the case. The only difference is how this affects the updating of ._flat_weights in an LSTM, which, because of its implementation, requires an explicit call to setattr.
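To make the sequence above concrete, here is a condensed sketch of the pruning case (hedged: it assumes a PyTorch build that includes this fix and the standard single-layer LSTM naming, with weight_ih_l0 as the first flat weight; the exact test code lives in test_nn.py):

import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

l = nn.LSTM(input_size=4, hidden_size=3)  # sizes are arbitrary

# Pruning replaces 'weight_ih_l0' with the reparametrized pair
# ('weight_ih_l0_orig' parameter + 'weight_ih_l0_mask' buffer).
prune.l1_unstructured(l, name="weight_ih_l0", amount=0.5)
assert "weight_ih_l0_orig" in dict(l.named_parameters())
assert "weight_ih_l0" not in dict(l.named_parameters())

# Removing the reparametrization brings 'weight_ih_l0' back as a Parameter;
# with this fix it is also what l._flat_weights[0] points to.
prune.remove(l, "weight_ih_l0")
assert "weight_ih_l0" in dict(l.named_parameters())
assert isinstance(l._flat_weights[0], nn.Parameter)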

Is your requested change to include this explicitly in the tests?

@ngimel
Collaborator

ngimel commented Apr 27, 2020

Another possible alternative is to override register_parameter in RNN module, so that it calls setattr - that would fix not only weight norm and pruning, but also other reparameterizations that users might have that use similar approaches.
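A rough sketch of that alternative (hypothetical, not what this PR implements; the class name is made up, while _flat_weights and _flat_weights_names are the existing RNN attributes): the RNN module could override register_parameter so that any (re)registration also refreshes the matching _flat_weights slot.

import torch.nn as nn

# Hypothetical sketch of the suggested alternative (not part of this PR):
# refresh _flat_weights whenever one of the flat weights is (re)registered,
# so any reparametrization that goes through register_parameter stays in sync.
class FlatWeightSyncLSTM(nn.LSTM):
    def register_parameter(self, name, param):
        super().register_parameter(name, param)
        # both attributes only exist once the base __init__ has set them up
        if (hasattr(self, "_flat_weights_names") and hasattr(self, "_flat_weights")
                and name in self._flat_weights_names):
            self._flat_weights[self._flat_weights_names.index(name)] = param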

@albanD
Collaborator

albanD commented Apr 27, 2020

My understanding is as follows: for modules that are not RNNs, it feels like named_parameters() checks the ._parameters field.
So if you set nn.Parameters without putting them there, they won't be detected.

Not sure what I'm missing here?

Note that in your tests you traverse ._flat_weights and not .parameters() (which is what users usually use).

@mickypaganini
Contributor Author

@ngimel Probably true! The only thing is that I'm not familiar enough with the design decisions that went into the RNN code, so I'm not sure if there is any particular reason why things are exactly the way they are, and what else relies on them being exactly this way. On the other hand, I am pretty confident that changing pruning and weight norm won't have negative repercussions on other parts of the codebase. Happy with either though :)

@mickypaganini
Contributor Author

@albanD we do put the parameter back into ._parameters explicitly with the line module._parameters[self._tensor_name] = None, and then we fill it. This is true for RNNs as well as other non-RNN Modules (as far as I can tell from the examples I tried).

In case I'm not understanding correctly, do you have an explicit case you'd like me to check?

You have a fair point that the tests don't touch this aspect but only check that ._flat_weights is set correctly, so I can add an assertion about the parameters as well.

@mickypaganini
Contributor Author

Perhaps more clearly...
My understanding is that, as far as ._parameters and .named_parameters() are concerned, doing module.register_parameter(self._tensor_name, orig) [as previously found in the code], or doing

module._parameters[self._tensor_name] = None
setattr(module, self._tensor_name, orig)

[as currently suggested in this PR], has exactly the same effect.

What changes is the effect they have on how _flat_weights is set in an RNN, because the RNN module expects a call to setattr.
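To illustrate why only the setattr path touches _flat_weights, here is a toy module that mimics, in paraphrase, the kind of attribute-assignment hook the RNN base class uses (this is not the actual rnn.py source):

import torch
import torch.nn as nn

# Toy module mimicking (in paraphrase) how the RNN base class mirrors
# attribute assignment into _flat_weights. Not the upstream rnn.py code.
class ToyFlatWeights(nn.Module):
    def __init__(self):
        super().__init__()
        self._flat_weights_names = ["weight"]
        self.weight = nn.Parameter(torch.randn(3, 3))
        self._flat_weights = [self.weight]

    def __setattr__(self, attr, value):
        # mirror assignments to flat-weight names into _flat_weights
        if attr in getattr(self, "_flat_weights_names", []) and hasattr(self, "_flat_weights"):
            self._flat_weights[self._flat_weights_names.index(attr)] = value
        super().__setattr__(attr, value)

m = ToyFlatWeights()

# setattr goes through __setattr__, so _flat_weights is updated ...
m.weight = nn.Parameter(torch.randn(3, 3))
assert m._flat_weights[0] is m.weight

# ... while writing into ._parameters directly bypasses __setattr__ and
# leaves _flat_weights stale -- the failure mode this PR avoids.
m._parameters["weight"] = nn.Parameter(torch.randn(3, 3))
assert m._flat_weights[0] is not m.weight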

@ngimel
Collaborator

ngimel commented Apr 27, 2020

So, explicitly adding a test for this will render this discussion moot, and thus is a reasonable thing to do :-) However, consider overriding register_parameter in RNN.

@albanD
Collaborator

albanD commented Apr 27, 2020

Ok, after some testing: I thought that setattr was bypassing our custom logic to add parameters, but it does not.
Last question: Why do you need to do module._parameters[self._tensor_name] = None then? The content here will be overwritten by the setattr call anyway.

But given that this is rnn-specific logic (in the sense that it's only a trick to make RNNs work), I would agree with Natalia that it might be better to move this to the rnn code by changing register_parameter there.

@mickypaganini
Contributor Author

I looked into it and I think you're right: we don't need to explicitly create the placeholder in ._parameters. As long as the attribute we are adding with setattr is a Parameter, setattr should automatically put it into ._parameters. So I removed that line.
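A minimal check of that point (hypothetical snippet, not the test added in this PR): assigning a Parameter via setattr is enough to register it, with no placeholder needed in ._parameters.

import torch
import torch.nn as nn

m = nn.Linear(2, 2)
del m._parameters["weight"]          # simulate the post-reparametrization state
# no `_parameters[name] = None` placeholder needed: assigning a Parameter via
# setattr routes through Module.__setattr__, which registers it automatically
setattr(m, "weight", nn.Parameter(torch.randn(2, 2)))
assert "weight" in dict(m.named_parameters())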

I also added the recommended test.

Other than that, changes to rnn.py are something that would benefit from an RNN maintainer thinking them through properly, since a quick pass over the file suggests such a change may have a large blast radius. The code in rnn.py literally says:

# Resets _flat_weights
# Note: be v. careful before removing this, as 3rd party device types
# likely rely on this behavior to properly .to() modules like LSTM.
self._flat_weights = [(lambda wn: getattr(self, wn) if hasattr(self, wn) else None)(wn) for wn in self._flat_weights_names]

@mickypaganini
Contributor Author

One thing that stands out is that this PR now basically amounts to substituting module.register_parameter(name, param) with setattr(module, name, param).

So perhaps what would need to change, instead, is register_parameter in module.py? That could use the setattr line above instead of the current self._parameters[name] = param. But god knows what can of worms that change is going to open up 😨

Collaborator

@albanD albanD left a comment


But keep in mind that setattr() itself calls register_parameter. So you would go into infinite recursion :D

Thanks for the update. Looks good to me.
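To spell out the recursion concern mentioned above (a hypothetical illustration, not code from this PR or from PyTorch): nn.Module.__setattr__ hands Parameter assignments to register_parameter, so a register_parameter that in turn calls setattr would bounce between the two indefinitely.

import torch
import torch.nn as nn

# Hypothetical illustration of the recursion warned about above -- do not do this.
# Module.__setattr__ calls register_parameter for Parameter values, so calling
# setattr from inside register_parameter loops forever.
class Recursive(nn.Module):
    def register_parameter(self, name, param):
        setattr(self, name, param)   # __setattr__ -> register_parameter -> ...

m = Recursive()
try:
    m.w = nn.Parameter(torch.randn(2))
except RecursionError:
    print("RecursionError, as expected")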

Contributor

@facebook-github-bot facebook-github-bot left a comment


@mickypaganini is landing this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@facebook-github-bot
Contributor

@mickypaganini merged this pull request in d37a486.
