The gelu_accurate implementation does not cast its inputs to float32, so with float16 it produces finite outputs (because tanh(+inf) = 1) but NaN gradients even for moderately large inputs (because float16 cannot represent 256 ** 3). I understand that the half dtype has quite a low dynamic range, but it seems that gelu(...) casts to float32 precisely to work around that. Since GELU is supposed to be a non-saturating non-linearity, this is somewhat of a problem: large inputs are still not a good idea, but some tolerance to them would be useful, especially during unstable training.

A question: was gelu_accurate added because it's more accurate, or because it's faster? (It used to be called gelu_fast.) Isn't it an approximation? The original paper says "we can approximate gelu with: ...(the gelu_accurate formula)...".

Reusing code from https://github.com/pytorch/fairseq/blob/master/fairseq/modules/gelu.py :
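For reference, the relevant definitions look roughly like this (a sketch from memory of the linked file, showing the standard tanh approximation and the float32-casting gelu wrapper; see the link above for the exact code):

```python
import math

import torch
import torch.nn.functional as F


def gelu_accurate(x):
    # tanh approximation of GELU, computed in the input dtype; for float16 the
    # x**3 term overflows to inf once |x| is around 41 (float16 max is 65504)
    a = math.sqrt(2 / math.pi)
    return 0.5 * x * (1 + torch.tanh(a * (x + 0.044715 * torch.pow(x, 3))))


def gelu(x: torch.Tensor) -> torch.Tensor:
    # exact (erf-based) GELU, upcast to float32 and cast back to the input dtype
    return F.gelu(x.float()).type_as(x)
```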
Nice observation! We don't actually use it anymore (preferring torch.nn.functional.gelu), but it's kept for backward compatibility.
Yeah that was poorly named on my part. I based it on the original README, where it's described as "slower but more accurate" than sigmoid(1.702 * x) * x, but it's still an approximation. Calling it gelu_approx would have been better.
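To make "more accurate but still an approximation" concrete, here is a rough comparison against the exact erf-based GELU (illustrative sketch only; the error magnitudes are approximate):

```python
import math
import torch

x = torch.linspace(-5, 5, 1001, dtype=torch.float64)
exact = 0.5 * x * (1 + torch.erf(x / math.sqrt(2)))          # "true" GELU
tanh_approx = 0.5 * x * (1 + torch.tanh(                     # gelu_accurate
    math.sqrt(2 / math.pi) * (x + 0.044715 * x ** 3)))
sigmoid_approx = x * torch.sigmoid(1.702 * x)                # the "fast" sigmoid variant

print((tanh_approx - exact).abs().max().item())     # on the order of 1e-4
print((sigmoid_approx - exact).abs().max().item())  # on the order of 1e-2
```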
About F.gelu: even without x.float() and the cast back, it seems to work both for 256.0 and for torch.finfo(torch.float16).max, so the casts may be excessive if the reason was just dynamic range (and not precision).
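Roughly the check I ran (a sketch; assumes a PyTorch build with float16 gelu kernels for the chosen device, e.g. CUDA):

```python
import torch
import torch.nn.functional as F

device = "cuda" if torch.cuda.is_available() else "cpu"
for val in (256.0, torch.finfo(torch.float16).max):
    x = torch.tensor([float(val)], dtype=torch.float16, device=device, requires_grad=True)
    y = F.gelu(x)                     # no manual float32 upcast here
    y.backward(torch.ones_like(y))
    # output stays ~= val and the gradient stays ~= 1.0, with no inf/nan,
    # unlike gelu_accurate, whose x**3 term overflows float16 at these values
    print(val, y.item(), x.grad.item())
```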