[proposal] Add approx variant option to F.gelu #39853
Comments
I used the tanh approximation simply because the error function erf was slow in tensorflow some years ago. If the exact version is fast enough now and does not have numerical issues, I do not see a reason to use an inexact version.
To echo @hendrycks' comment:

```python
import torch

torch._C._jit_set_profiling_executor(False)
torch._C._jit_set_profiling_mode(False)
torch._C._jit_override_can_fuse_on_gpu(True)
torch._C._jit_set_texpr_fuser_enabled(False)

@torch.jit.script
def gelu(x):
    # exact: x * Phi(x), with 0.7071067811865476 = 1/sqrt(2)
    return x * (0.5 + torch.erf(x * 0.7071067811865476) * 0.5)

@torch.jit.script
def fast_gelu_1(x):
    # sqrt(2/pi) = 0.7978845608028654
    return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * (x + 0.044715 * torch.pow(x, 3.0))))

@torch.jit.script
def fast_gelu_2(x):
    return 0.5 * x * (1.0 + torch.tanh(x * 0.7978845608 * (1.0 + 0.044715 * x * x)))

x = torch.randn(32, 128, device="cuda")

def time_me(fn):
    for i in range(100):
        fn(x)
    torch.cuda.synchronize()

time_me(gelu)          # warm up / compile before timing
%timeit time_me(gelu)
time_me(fast_gelu_1)
%timeit time_me(fast_gelu_1)
time_me(fast_gelu_2)
%timeit time_me(fast_gelu_2)
```

gives me (on a Radeon VII where I happened to have a running Jupyter notebook)
Now, the fuser might not have produced kernels, but the difference between them is only the compute. Except when people need to reproduce weird technical decisions relating to ancient versions of other frameworks, one should not use the approximate versions.
The "fast" variants are actually quite slow, right? I stumbled on this only during reproducing huggingface models in fairseq (and reading huggingface code). I agree this is very specific, but if people keep using those old variants of GELU for running original BERT models (like they did with CrossMapLRN2d because reproducing AlexNet was important), then maybe it's worth to put in an approx variant as well - for compat reasons specifically. But I agree, this may not be a sufficient reason |
The speed difference probably doesn't matter much. The HuggingFace BERT used in the tutorials of their transformers library apparently does not use it anymore (this is how I found that the translation was not quite accurate). I would probably consider that to be the leading BERT implementation for PyTorch (and possibly beyond), so I would think that the use case is too limited.
(GPT-2 used gelu_new (gelu approx) by default: https://github.com/huggingface/transformers/blob/master/src/transformers/configuration_gpt2.py#L125 )
Yes, though GPT-2 was in tensorflow, which did not have any optimized versions. I'll suggest to tensorflow not to default to the approximate version.
It seems that, in the end, TensorFlow kept both options: exact by default and approximate as an option...
#39853 found the approximate version to be 1.75x faster on TPUv2, and it now supports both approx and exact variants.
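For reference, TensorFlow exposes this as a flag on the op; a minimal sketch (assuming a TF version where tf.nn.gelu accepts approximate):

```python
import tensorflow as tf

x = tf.constant([-1.0, 0.0, 1.0])
y_exact = tf.nn.gelu(x)                     # erf-based, the default
y_approx = tf.nn.gelu(x, approximate=True)  # tanh approximation
```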
I now support adding this option, after seeing that more people are manually adding the approximation in their PyTorch model implementations. Adding something like https://github.com/huggingface/transformers/blob/master/src/transformers/activations.py#L56
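For context, the linked helper is a tanh-based variant along these lines (a sketch from memory, not a verbatim copy of the transformers code):

```python
import math
import torch

def gelu_new(x: torch.Tensor) -> torch.Tensor:
    # tanh approximation used by the original BERT / GPT-2 code
    return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
```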
Since this is active again: I've been chatting (complaining) a bit to @ptrblck about the speed of GELU in PyTorch w/ CUDA. In float32 networks, GELU and (native) SiLU are comparable. However, with AMP enabled, GELU is quite a bit slower. AMP autocasts the op to float32, so that's a big part of the slowdown. I've done some measurements in pure float16 and there seem to be some PyTorch-version/hardware-specific differences where sometimes it's equivalent and sometimes slower (but not as bad as AMP). I'm not sure if that discrepancy is a different issue? Popular networks in my focus such as NFNet, Vision Transformers, and Vision MLP models all use GELU by default (those were all developed on TPU w/ JAX, where this is not much of an issue). To put a few numbers to it that I collected, on an RTX 3090 w/ AMP enabled:

ViT Small:

NFNet-F0:

The ViT number ratios are similar for other ViT and MLP models using GELU (most of them) ... in the 15-20% range. NFNets in the 25-35% range. The python API quick_gelu
The other approximations are not exact, but slight fine-tuning would probably fix any slight degradation. While there is a tanh cubic approximation and a sigmoid approximation, the sigmoid might be preferable, largely because the OpenAI people chose that one over the tanh cubic approximation (their algorithms team keeps track of efficiency and knows what they're doing). Looks like a native GELU approximation would be really useful. Thanks for sharing those numbers @rwightman!
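The sigmoid approximation referenced here (OpenAI's quick_gelu) is usually written as follows; a minimal sketch:

```python
import torch

def quick_gelu(x: torch.Tensor) -> torch.Tensor:
    # sigmoid approximation: x * sigmoid(1.702 * x)
    return x * torch.sigmoid(1.702 * x)
```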
While chatting with @rwightman we've also talked about other approximations such as:

```python
import torch

@torch.jit.script
def bias_gelu(x):
    # exact GELU via erf, with 0.70710678 = 1/sqrt(2)
    return x * 0.5 * (1.0 + torch.erf(x * 0.70710678))

@torch.jit.script
def bias_gelu_back(g, x):
    # gradient of the exact GELU; 0.3989423 = 1/sqrt(2*pi)
    ff = 0.5 * (1. + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x)
    return ff * g
```

One concern I've raised was the potential saturation of erf for x <= -4 and x >= 4, but I don't know if this would cause issues for commonly used models (and if one approximation would be preferred over the other one). A potential approach would be to use an argument such as approximate=True/False (or a different op name).
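A quick way to probe that saturation concern could look like this (hypothetical check, not from the thread):

```python
import torch

# Compare exact GELU and the tanh approximation, plus their gradients,
# at large |x| where erf is effectively saturated at +/-1.
x = torch.tensor([-6.0, -4.0, 4.0, 6.0], requires_grad=True)

exact = x * 0.5 * (1.0 + torch.erf(x * 0.70710678))
approx = 0.5 * x * (1.0 + torch.tanh(0.7978845608 * (x + 0.044715 * x ** 3)))

print((exact - approx).abs().max())
print(torch.autograd.grad(exact.sum(), x))
print(torch.autograd.grad(approx.sum(), x))
```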
My assumption is that the erf is relatively slow. An approximate flag sounds easiest and is what tensorflow does. Tensorflow uses a tanh approximation, but I assume the sigmoid approximation is cheapest.
In the timing above, it also seems that AMP is computing the approximate version in half precision while the original is not.
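A minimal sketch to check which dtype each path actually runs in under autocast (hypothetical probe; assumes a CUDA device, with the input in float16 as it would typically arrive from a preceding Linear layer):

```python
import torch
import torch.nn.functional as F

x = torch.randn(32, 128, device="cuda", dtype=torch.float16)

with torch.cuda.amp.autocast():
    y_native = F.gelu(x)  # check whether this gets promoted to float32
    y_manual = 0.5 * x * (1.0 + torch.tanh(0.7978845608 * (x + 0.044715 * x ** 3)))

print(y_native.dtype, y_manual.dtype)
```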
#59639 should help with this
Summary:
1. Implements pytorch/pytorch#39853
2. Adds approximate boolean flag to Gelu
3. Enables Tanh Gelu approximation
4. Adds double backward support for Gelu
5. Enable Tanh Gelu in NvFuser

```python
def gelu(x, approximate: str = 'none'):
    if approximate == 'tanh':
        # sqrt(2/pi) = 0.7978845608028654
        return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * (x + 0.044715 * torch.pow(x, 3.0))))
    else:
        return x * normcdf(x)
```

Linking XLA PR - pytorch/xla#3039
Pull Request resolved: pytorch/pytorch#61439
Reviewed By: mikaylagawarecki
Differential Revision: D33744717
Pulled By: jbschlosser
fbshipit-source-id: d64532a562ed53247bb4fa52bb16722634d5c187
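With that change, the option is exposed roughly like this (assuming a PyTorch build that includes the PR, i.e. the string approximate argument):

```python
import torch
import torch.nn.functional as F

x = torch.randn(8, 16)

y_exact = F.gelu(x)                      # approximate='none' (default, erf-based)
y_tanh = F.gelu(x, approximate='tanh')   # tanh approximation

act = torch.nn.GELU(approximate='tanh')  # module form
y_mod = act(x)
```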
Seems something may be off in HF testing at huggingface/transformers#15397 ...
This change removes the tanh GeLU approximation. This gives us the benefit of better accuracy, better performance and better standard conformance, since we no longer need any compiler-specific tricks.

Here's the last lines of train_gpt2 output before this change:

```
step 37: train loss 3.739647 (took 598.548076 ms)
step 38: train loss 4.611735 (took 596.626145 ms)
step 39: train loss 3.970751 (took 598.439552 ms)
val loss 4.016658
generating:
---
Come Running Away, Greater conquer With the Imperial blood the heaviest host of the gods into this wondrous world beyond. I will not back thee, for how sweet after birth Netflix against repounder, will not flourish against the earlocks of Allay
---
step 40: train loss 4.377756 (took 592.704936 ms)
```

Here's the last lines of train_gpt2 output after this change:

```
step 37: train loss 3.731596 (took 594.893995 ms)
step 38: train loss 4.561646 (took 600.064035 ms)
step 39: train loss 3.933512 (took 599.666173 ms)
val loss 4.014135
generating:
---
Whether Hipocrates, Bigon Nicinius, or rep'd With Thy fair winter-tail your outraged hand, The richness of the good smour Nine years by turns covered my Member. Thou art Nay, I fear be; but Lets o' thee know, if it
---
step 40: train loss 4.358461 (took 597.594065 ms)
```

This change has the disadvantage of diverging from PyTorch. I view this as being justified and worthwhile, for numerous reasons, e.g.

"I used the tanh approximation simply because the error function erf was slow in tensorflow some years ago. If the exact version is fast enough now and does not have numerical issues, I do not see a reason to use an inexact version." ──Quoth Dan Hendrycks

See pytorch/pytorch#39853
Approx variants of gelu are apparently used in Google's BERT / GPT-2 (`gelu_new` in huggingface and `gelu_accurate` in fairseq). The versions in fairseq and in huggingface do not have stable gradients at large inputs for fp16. So maybe it's worth adding it to core in C++ for memory saving / fusion and error-avoidance. One way could be an `approx=True` option to F.gelu / nn.GELU.

This is commonly used now, but I don't know why (maybe for simplicity of impl). And it can be deemed too case-specific for adding into core (otoh gelu itself is quite specific).
cc @ezyang @gchanan @zou3519 @bdhirsh @jbschlosser @albanD @mruberry @walterddr @anjali411 @Varal7