nn.Linear weight initialization - uniform or kaiming_uniform? #57109
Comments
The overall logic simplifies to this equivalent form: https://github.com/pytorch/pytorch/blob/v0.4.1/torch/nn/modules/linear.py#L48-L52
Relevant PR: #9038 (review)
A uniform distribution's stdv gets multiplied by sqrt(3) to obtain the sampling bound: https://physics.nist.gov/cuu/Uncertainty/typeb.html and https://physics.stackexchange.com/questions/110242/why-is-uncertainty-divided-by-sqrt3. One thing worth noticing in my "equivalent logic" snippet is that, by default, we don't multiply the stdv by sqrt(3).
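To make the sqrt(3) relationship concrete, here is a small torch-free sketch (the helper name `uniform_bound_for_std` is hypothetical, introduced only for illustration): for U(-b, b) the variance is (2b)^2/12 = b^2/3, so the bound for a target stdv is stdv * sqrt(3).

```python
import math
import random

# For U(-b, b): variance = (2b)^2 / 12 = b^2 / 3, so std = b / sqrt(3).
# Equivalently, the bound needed to hit a target std is std * sqrt(3) --
# the sqrt(3) factor discussed above. (Hypothetical helper name.)
def uniform_bound_for_std(std: float) -> float:
    return math.sqrt(3.0) * std

# Empirical sanity check (illustrative, not PyTorch code):
random.seed(0)
target_std = 0.5
b = uniform_bound_for_std(target_std)
xs = [random.uniform(-b, b) for _ in range(200_000)]
mean = sum(xs) / len(xs)
emp_std = math.sqrt(sum((x - mean) ** 2 for x in xs) / len(xs))
assert abs(emp_std - target_std) < 0.01
```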
Just for others coming to this issue: there's a typo, a "1" is missing in
They're the same |
The source code should be clear and understandable. Right now the code really does this:
Much clearer would be:
And an explanation of why we don't actually use Kaiming He's initialization but values that are sqrt(3)/sqrt(6) smaller (with gain from 'linear'/'relu'), even though the user may think so from reading the source code.
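A torch-free sketch of the "much clearer" equivalent the comment is asking for (the function name `reset_parameters_simplified` is hypothetical): for a weight of shape (out_features, in_features), every entry is effectively drawn from uniform(-1/sqrt(fan_in), 1/sqrt(fan_in)).

```python
import math
import random

# Hypothetical, torch-free sketch of the simplified initialization the
# comment proposes: draw every weight entry from
# uniform(-1/sqrt(fan_in), 1/sqrt(fan_in)), where fan_in = in_features.
def reset_parameters_simplified(out_features: int, in_features: int):
    bound = 1.0 / math.sqrt(in_features)
    return [[random.uniform(-bound, bound) for _ in range(in_features)]
            for _ in range(out_features)]

w = reset_parameters_simplified(4, 16)
assert all(abs(x) <= 1.0 / math.sqrt(16) for row in w for x in row)
```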
IMHO there is a discrepancy between the docs and the code of nn.Linear when it comes to initialization.
The documentation says that the weights are initialized from
uniform(-1/sqrt(in_features), 1/sqrt(in_features)):
pytorch/torch/nn/modules/linear.py
Lines 53 to 56 in 0df5740
The code says that the weights are initialized from
kaiming_uniform
pytorch/torch/nn/modules/linear.py
Lines 88 to 89 in 77721ee
and that includes factors of sqrt(3), gain based on 'a', and 'fan':
pytorch/torch/nn/init.py
Lines 390 to 395 in 77721ee
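For reference, the bound computed by those lines can be sketched torch-free (the function name `kaiming_uniform_bound` is hypothetical; the gain formula mirrors `calculate_gain('leaky_relu', a)`). With nn.Linear's default a = sqrt(5), the factors cancel and the bound collapses to exactly 1/sqrt(fan_in), which is what the documentation states.

```python
import math

# Torch-free sketch of the bound kaiming_uniform_ uses for
# nonlinearity='leaky_relu' (names mirror torch.nn.init; hypothetical
# standalone function for illustration):
def kaiming_uniform_bound(fan_in: int, a: float) -> float:
    gain = math.sqrt(2.0 / (1.0 + a ** 2))  # calculate_gain('leaky_relu', a)
    std = gain / math.sqrt(fan_in)
    return math.sqrt(3.0) * std             # uniform bound = sqrt(3) * std

# With a = sqrt(5): sqrt(3) * sqrt(2/6) / sqrt(fan_in) = 1 / sqrt(fan_in),
# matching the documented uniform(-1/sqrt(in_features), 1/sqrt(in_features)).
for fan_in in (1, 64, 784):
    assert math.isclose(kaiming_uniform_bound(fan_in, math.sqrt(5)),
                        1.0 / math.sqrt(fan_in))
```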
Is that an error or am I missing something?
cc @ezyang @gchanan @zou3519 @bdhirsh @jbschlosser @anjali411 @brianjo @mruberry @albanD