Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

nn.linear module weight initialization fix #19526

Closed
wants to merge 3 commits into from

Conversation

interesaaat
Copy link
Contributor

This PR fixes the module weight initialization in nn.Linear by matching what state in the module documentation.

Fixes #19376.

@pytorchbot pytorchbot added the module: nn Related to torch.nn label Apr 19, 2019
@ezyang
Copy link
Contributor

ezyang commented Apr 22, 2019

@pytorchbot rebase this please

@ezyang
Copy link
Contributor

ezyang commented Apr 22, 2019

Thanks for the PR. Can we please get a test for this?

@interesaaat
Copy link
Contributor Author

Thanks for the PR. Can we please get a test for this?

Let me look into this. (Any hint on where the test should go?)

@smessmer smessmer added the triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module label Apr 22, 2019
@soumith
Copy link
Member

soumith commented Apr 22, 2019

@interesaaat i took a closer look. Mathematically, init.kaiming_uniform_(self.weight, a=math.sqrt(5)) is doing the same as a uniform with 1.0 / std::sqrt(weight.size(1)), from what I can tell.

If a is math.sqrt(5), then kaiming init's bound becomes: sqrt(6 / (6 * weight.size(1)) which is 1.0 / sqrt(weight.size(1)). They are both exactly doing the same thing.

@soumith soumith closed this Apr 22, 2019
@interesaaat
Copy link
Contributor Author

I still believe that this is sort of confusing (using kaiming_uniform for initializing the weight in the linear module, but not the bias, nor the weight in the bilinear module).

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
module: nn Related to torch.nn open source triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module
Projects
None yet
Development

Successfully merging this pull request may close these issues.

nn.Linear module weight initialization does not match the documentation
5 participants