gelu_accurate: nan gradient with fp16. add an inner cast to float32? #2235

Open · vadimkantorov opened this issue Jun 11, 2020 · 3 comments

The gelu_accurate implementation does not cast its inputs to float32, so with float16 it produces finite outputs (because tanh(+inf) = 1) but nan gradients, even for moderately large inputs (float16 cannot represent 256 ** 3, so the x ** 3 term overflows). I do understand that the half dtype has quite a low dynamic range, but it seems that gelu(...) has the cast to float32 precisely to work around this. Since GELU is supposed to be a non-saturating non-linearity, this is somewhat of a problem: large inputs are still not a good idea, but some tolerance to them would be good to have, especially during unstable training.

A question: was gelu_accurate added because it's more accurate, or because it's faster? (It used to be called gelu_fast.) Isn't it an approximation? (The original paper says "we can approximate gelu with: ...(gelu_accurate formula)...".)

Reusing code from https://github.com/pytorch/fairseq/blob/master/fairseq/modules/gelu.py:

import torch
import math

def gelu_accurate(x):
    if not hasattr(gelu_accurate, "_a"):
        gelu_accurate._a = math.sqrt(2 / math.pi)
    return (
        0.5 * x * (1 + torch.tanh(gelu_accurate._a * (x + 0.044715 * torch.pow(x, 3))))
    )

def gelu(x: torch.Tensor) -> torch.Tensor:
    return torch.nn.functional.gelu(x.float()).type_as(x)

x = torch.tensor([256.0], dtype=torch.float16, device='cuda', requires_grad=True)

# tanh approximation evaluated directly in float16: torch.pow(x, 3) overflows to inf
y1 = gelu_accurate(x)
g1 = torch.autograd.grad(y1, (x,))[0]

# exact gelu with an inner cast to float32 and a cast back
y2 = gelu(x)
g2 = torch.autograd.grad(y2, (x,))[0]

print('y1', float(y1), 'g1', float(g1))
# y1 256.0 g1 nan

print('y2', float(y2), 'g2', float(g2))
# y2 256.0 g2 1.0
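
For illustration, here is a minimal sketch of the inner cast proposed in the title, continuing the snippet above (the name gelu_accurate_fp32 is made up for this example): do the tanh approximation in float32 and cast the result back to the input dtype.

def gelu_accurate_fp32(x: torch.Tensor) -> torch.Tensor:
    # compute in float32 so that the x ** 3 term cannot overflow in float16,
    # then cast back to the input dtype
    a = math.sqrt(2 / math.pi)
    xf = x.float()
    return (0.5 * xf * (1 + torch.tanh(a * (xf + 0.044715 * torch.pow(xf, 3))))).type_as(x)

y3 = gelu_accurate_fp32(x)
g3 = torch.autograd.grad(y3, (x,))[0]

print('y3', float(y3), 'g3', float(g3))
# expected to match the gelu() variant above: y3 256.0 g3 1.0
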
myleott (Contributor) commented Jun 13, 2020

> A question: was gelu_accurate added because it's more accurate, or because it's faster? (It used to be called gelu_fast.) Isn't it an approximation? (The original paper says "we can approximate gelu with: ...(gelu_accurate formula)...".)

Nice observation! We don't actually use it anymore (preferring torch.nn.functional.gelu), but it's kept for backward compatibility.

> A question: was gelu_accurate added because it's more accurate, or because it's faster? (It used to be called gelu_fast.) Isn't it an approximation? (The original paper says "we can approximate gelu with: ...(gelu_accurate formula)...".)

Yeah that was poorly named on my part. I based it on the original README, where it's described as "slower but more accurate" than sigmoid(1.702 * x) * x, but it's still an approximation. Calling it gelu_approx would have been better.
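
For context, a small comparison sketch (reusing gelu_accurate from the snippet above; the helper names here are illustrative) of that sigmoid approximation against the tanh approximation and the exact erf-based definition:

def gelu_sigmoid_approx(x: torch.Tensor) -> torch.Tensor:
    # the cheaper sigmoid-based approximation from the original GELU README
    return x * torch.sigmoid(1.702 * x)

def gelu_exact(x: torch.Tensor) -> torch.Tensor:
    # exact erf-based definition (what torch.nn.functional.gelu computes by default)
    return 0.5 * x * (1 + torch.erf(x / math.sqrt(2)))

t = torch.linspace(-4, 4, steps=1001)
print((gelu_sigmoid_approx(t) - gelu_exact(t)).abs().max())  # noticeably larger max error
print((gelu_accurate(t) - gelu_exact(t)).abs().max())        # tanh form tracks the exact gelu more closely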

myleott self-assigned this Jun 13, 2020
myleott added the bug label Jun 13, 2020
vadimkantorov (Author) commented

About F.gelu: even without the x.float() cast and the cast back, it seems to work both for 256.0 and for torch.finfo(torch.float16).max, so the casts may be excessive if the reason for them was just dynamic range (and not precision).
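
A quick version of that check, as a sketch (assumes a CUDA device, like the snippet above):

x16 = torch.tensor([256.0, torch.finfo(torch.float16).max],
                   dtype=torch.float16, device='cuda', requires_grad=True)
y = torch.nn.functional.gelu(x16)
g = torch.autograd.grad(y, (x16,), grad_outputs=torch.ones_like(y))[0]
print(y, g)  # per the observation above: finite outputs and gradients of 1.0, no float32 cast needed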

vadimkantorov (Author) commented

In addition, the current gelu_accurate implementation cannot be TorchScripted because of the trick of caching sqrt(2 / pi) in an attribute of the function.
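
For example, a minimal TorchScript-friendly sketch (not the fairseq code, just an illustration) that drops the function-attribute caching, since math.sqrt on a Python float is cheap enough to recompute:

import math
import torch

@torch.jit.script
def gelu_accurate_scriptable(x: torch.Tensor) -> torch.Tensor:
    # same tanh approximation, but without caching sqrt(2 / pi) on the function object
    a = math.sqrt(2.0 / math.pi)
    return 0.5 * x * (1 + torch.tanh(a * (x + 0.044715 * torch.pow(x, 3))))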
