
nn.Linear weight initialization - uniform or kaiming_uniform? #57109

Closed

adrianstaniec opened this issue Apr 28, 2021 · 7 comments
Labels
high priority module: docs Related to our documentation, both in docs/ and docblocks module: nn Related to torch.nn triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Comments

adrianstaniec commented Apr 28, 2021

IMHO there is a discrepancy between the docs and the code of nn.Linear when it comes to initialization.

The documentation says that the weights are initialized from uniform(-1/sqrt(in_features), 1/sqrt(in_features)):

weight: the learnable weights of the module of shape
:math:`(\text{out\_features}, \text{in\_features})`. The values are
initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`, where
:math:`k = \frac{1}{\text{in\_features}}`

The code says that the weights are initialized with kaiming_uniform_:

def reset_parameters(self) -> None:
init.kaiming_uniform_(self.weight, a=math.sqrt(5))

and that includes factors of sqrt(3), gain based on 'a', and 'fan':

pytorch/torch/nn/init.py, lines 390 to 395 in 77721ee:

fan = _calculate_correct_fan(tensor, mode)
gain = calculate_gain(nonlinearity, a)
std = gain / math.sqrt(fan)
bound = math.sqrt(3.0) * std  # Calculate uniform bounds from standard deviation
with torch.no_grad():
    return tensor.uniform_(-bound, bound)

Is that an error or am I missing something?

cc @ezyang @gchanan @zou3519 @bdhirsh @jbschlosser @anjali411 @brianjo @mruberry @albanD

@zou3519 zou3519 added high priority module: docs Related to our documentation, both in docs/ and docblocks labels Apr 28, 2021
zou3519 (Contributor) commented Apr 28, 2021

Hi-pri because the docs say something different from the code. We should probably document the factor of sqrt(3) that we use. I remember that there was some justification for it (maybe @gchanan or @soumith remembers?)

soumith (Member) commented Apr 28, 2021

the kaiming_init is used for convenience, but basically the sqrt(5) goes in and gets simplified in the formula for gain as sqrt(2 / (sqrt(5) * sqrt(5)) which is sqrt(1/3).

The overall logic simplifies to this equivalent logic:

https://github.com/pytorch/pytorch/blob/v0.4.1/torch/nn/modules/linear.py#L48-L52
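Spelled out numerically (a minimal sketch with a hypothetical fan_in, not the actual reset_parameters code), the two paths give the same bound:

import math

fan_in = 64  # hypothetical in_features

# what kaiming_uniform_(a=math.sqrt(5)) effectively computes
gain = math.sqrt(2.0 / (1.0 + math.sqrt(5) ** 2))          # = sqrt(1/3)
bound_kaiming = math.sqrt(3.0) * gain / math.sqrt(fan_in)  # sqrt(3) * sqrt(1/3) = 1

# the v0.4.1 "equivalent logic": stdv = 1 / sqrt(fan_in), uniform_(-stdv, stdv)
bound_old = 1.0 / math.sqrt(fan_in)

print(bound_kaiming, bound_old)  # both 0.125 for fan_in = 64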

soumith (Member) commented Apr 28, 2021

relevant PR: #9038 (review)

soumith (Member) commented Apr 28, 2021

We should probably document the factor of sqrt(3) that we use. I remember that there was some justification for it

The uniform distribution's stdv is basically multiplied by sqrt(3) when you draw (the bound of U(-b, b) is sqrt(3) times its standard deviation): https://physics.nist.gov/cuu/Uncertainty/typeb.html and https://physics.stackexchange.com/questions/110242/why-is-uncertainty-divided-by-sqrt3

If you look at my "equivalent logic" snippet, one thing worth noticing is that by default we don't multiply the stdv by sqrt(3).
The reason we don't is that Collobert et al., at some point in the historical past, figured out that this not-multiplying and the resulting slight gain in the uniform distribution heuristically works better. This is not recorded in the literature anywhere, but it has been carried in code since Lush, Senna, Torch5, Torch7 and now PyTorch, which is somewhat unfortunate. A detailed discussion of this took place on Google Plus, which is now defunct, and that discussion has been erased from the internet. However, I've revived a copy of it here: https://soumith.ch/files/20141213_gplus_nninit_discussion.htm

@albanD albanD added module: nn Related to torch.nn triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module and removed triage review labels May 3, 2021
ciaochiaociao commented Jun 15, 2021

the kaiming_init is used for convenience, but basically the sqrt(5) goes in and gets simplified in the formula for gain as sqrt(2 / (sqrt(5) * sqrt(5)) which is sqrt(1/3).

The overall logic simplifies to this equivalent logic:

https://github.com/pytorch/pytorch/blob/v0.4.1/torch/nn/modules/linear.py#L48-L52

Just for others coming to this issue: there's a typo above, with a "1" missing; it should be sqrt(2 / (1 + sqrt(5) * sqrt(5))), as in https://pytorch.org/docs/stable/nn.init.html

erhuliu commented May 7, 2022

They're the same

matthijsvk commented Nov 10, 2022

The source code should be clear and understandable.

Right now the code really does this:

fan_in, _ = nn.init._calculate_fan_in_and_fan_out(m.weight)
# generate sqrt(1/3) in a complicated way: sqrt(2 / (1 + 5)) = sqrt(1/3)
gain = torch.nn.init.calculate_gain("leaky_relu", math.sqrt(5))
# multiply by sqrt(3) because the uniform bound cancels the sqrt(1/3) again
bound = math.sqrt(3) * gain / math.sqrt(fan_in)
nn.init.uniform_(m.weight, -bound, bound)

Much clearer would be:

fan_in, _ = nn.init._calculate_fan_in_and_fan_out(m.weight)
bound = 1 / math.sqrt(fan_in)
nn.init.uniform_(m.weight, -bound, bound)

And an explanation of why we don't actually use Kaiming He's initialization but values that are sqrt(3)/sqrt(6) smaller (with the gain for 'linear'/'relu'), even though the user may think so from reading the source code.
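For reference, a small check of those factors (a sketch with a hypothetical fan_in, using calculate_gain only for comparison; nn.Linear does not actually apply these gains):

import math
import torch.nn as nn

fan_in = 64  # hypothetical

default_bound = 1.0 / math.sqrt(fan_in)  # what nn.Linear actually uses

# Kaiming-uniform bounds a reader might expect from the source code
linear_bound = math.sqrt(3.0) * nn.init.calculate_gain("linear") / math.sqrt(fan_in)  # gain = 1
relu_bound = math.sqrt(3.0) * nn.init.calculate_gain("relu") / math.sqrt(fan_in)      # gain = sqrt(2)

print(linear_bound / default_bound)  # sqrt(3) ~= 1.73
print(relu_bound / default_bound)    # sqrt(6) ~= 2.45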
