[proposal] Add approx variant option to F.gelu #39853

Closed
vadimkantorov opened this issue Jun 11, 2020 · 21 comments
Labels
actionable feature A request for a proper, new feature. high priority module: nn Related to torch.nn triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Comments

@vadimkantorov
Contributor

vadimkantorov commented Jun 11, 2020

Approximate variants of GELU are apparently used in Google's BERT / GPT-2 (gelu_new in huggingface and gelu_accurate in fairseq). The versions in fairseq and in huggingface do not have stable gradients at large inputs in fp16. So maybe it's worth adding this to core in C++ for memory savings / fusion and to avoid those errors. One way could be an approx=True option to F.gelu / nn.GELU.

This is commonly used now, but I don't know why (maybe for simplicity of implementation). And it can be deemed too case-specific to add to core (on the other hand, GELU itself is quite specific).
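For reference, the interface that eventually landed (see the commit message quoted near the end of this thread) takes a string argument rather than a boolean; a minimal usage sketch, assuming a PyTorch version that already has the `approximate` keyword:

```python
import torch
import torch.nn.functional as F

x = torch.randn(8, 16)

y_exact = F.gelu(x)                      # exact erf-based GELU (approximate='none')
y_tanh = F.gelu(x, approximate='tanh')   # tanh approximation used by the original BERT / GPT-2 code

act = torch.nn.GELU(approximate='tanh')  # module form
y_mod = act(x)
```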

cc @ezyang @gchanan @zou3519 @bdhirsh @jbschlosser @albanD @mruberry @walterddr @anjali411 @Varal7

@mruberry mruberry added feature A request for a proper, new feature. module: nn Related to torch.nn triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module labels Jun 12, 2020
@hendrycks

I used the tanh approximation simply because the error function erf was slow in tensorflow some years ago. If the exact version is fast enough now and does not have numerical issues, I do not see a reason to use an inexact version.

@t-vi
Collaborator

t-vi commented Jul 15, 2020

To echo @hendrycks' comment:
I think putting them in PyTorch proper paints the wrong picture about efficiency:

import torch

# switch to the legacy JIT executor and its GPU fuser (and disable the
# tensor-expression fuser) so the scripted functions below get fused
torch._C._jit_set_profiling_executor(False)
torch._C._jit_set_profiling_mode(False)
torch._C._jit_override_can_fuse_on_gpu(True)
torch._C._jit_set_texpr_fuser_enabled(False)

@torch.jit.script
def gelu(x):
    # exact GELU; 0.7071067811865476 = 1/sqrt(2)
    return x * (0.5 + torch.erf(x * 0.7071067811865476) * 0.5)


@torch.jit.script
def fast_gelu_1(x):
    # tanh approximation; sqrt(2/pi) = 0.7978845608028654
    return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * (x + 0.044715 * torch.pow(x, 3.0))))


@torch.jit.script
def fast_gelu_2(x):
    # same approximation with the cubic factored differently
    return 0.5 * x * (1.0 + torch.tanh(x * 0.7978845608 * (1.0 + 0.044715 * x * x)))


x = torch.randn(32, 128, device="cuda")
def time_me(fn):
    for i in range(100):
        fn(x)
    torch.cuda.synchronize()

time_me(gelu)
%timeit time_me(gelu)
time_me(fast_gelu_1)
%timeit time_me(fast_gelu_1)
time_me(fast_gelu_2)
%timeit time_me(fast_gelu_2)

gives me (on a Radeon VII where I happened to have a running Jupyter notebook):

787 µs ± 9.54 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
1.04 ms ± 9.4 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
841 µs ± 9.69 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

Now, the fuser might not have produced optimal kernels, but the difference between them is only the compute.

Except when people need to reproduce weird technical decisions tied to ancient versions of other frameworks, one should not use anything but F.gelu. If you put the approximations in PyTorch, people will start "optimizing".

@vadimkantorov
Contributor Author

vadimkantorov commented Jul 15, 2020

The "fast" variants are actually quite slow, right?

I stumbled on this only while reproducing huggingface models in fairseq (and reading huggingface code). I agree this is very specific, but if people keep using those old variants of GELU for running original BERT models (like they did with CrossMapLRN2d because reproducing AlexNet was important), then maybe it's worth putting in an approximate variant as well, specifically for compatibility reasons. But I agree this may not be a sufficient reason.

@t-vi
Collaborator

t-vi commented Jul 15, 2020

The speed difference probably doesn't matter much.

The HuggingFace BERT used from the tutorials in their transformers library apparently does not use it anymore (this is how I found that the translation was not quite accurate). I would probably consider that to be the leading BERT implementation for PyTorch (and possibly beyond), so I would think that the use case is too limited.

@vadimkantorov
Contributor Author

@hendrycks

Yes, though GPT-2 was in TensorFlow, which did not have any optimized versions. I'll suggest to TensorFlow not to default to the approximate version.

@vadimkantorov
Contributor Author

It seems that in the end TensorFlow kept both options: exact by default and approximate as an option...

@vadimkantorov
Contributor Author

#39853 found that the approximate version is 1.75x faster on TPUv2 and now supports both the approximate and exact variants

@hendrycks

hendrycks commented Jun 8, 2021

I now support adding this option after seeing that more people are manually adding the approximation in their PyTorch model implementations. Adding something like https://github.com/huggingface/transformers/blob/master/src/transformers/activations.py#L56 to PyTorch as an option would be good.
I noticed iGPT and Jukebox use that version too (https://github.com/openai/image-gpt/blob/master/src/model.py#L29, https://github.com/openai/jukebox/blob/master/jukebox/transformer/transformer.py#L91).

@rwightman

Since this is active again: I've been chatting (complaining) a bit to @ptrblck about the speed of GELU in PyTorch w/ CUDA. In float32, networks using GELU vs (native) SiLU are comparable. However, w/ AMP enabled, GELU is quite a bit slower. AMP autocasts the op to float32, so that's a big part of the slowdown. I've done some measurements in pure float16 and there seem to be some PyTorch version / hardware specific differences where sometimes it's equivalent and sometimes slower (but not as bad as AMP). I'm not sure if that discrepancy is a different issue?

Popular networks in my focus such as NFNet, Vision Transformers, and Vision MLP models all use GELU by default (those were all developed on TPU w/ JAX, where this is not much of an issue).

To put a few numbers to it that I collected on an RTX 3090 w/ AMP enabled:

ViT Small:

| act | inf img/s | train img/s |
| --- | --- | --- |
| silu | 2440 | 835 |
| quick_gelu | 2260 | 755 |
| gelu | 2090 | 742 |

NFNet-F0:

| act | inf img/s | train img/s |
| --- | --- | --- |
| silu | 1014 | 334 |
| quick_gelu | 840 | 280 |
| gelu | 746 | 267 |

The ViT number ratios are similar for other ViT and MLP models using GELU (most of them use it), in the 15-20% range. NFNets are in the 25-35% range.

The Python-level quick_gelu, x * torch.sigmoid(1.702 * x), is an improvement but not ideal. It chews up more memory and is slower than a native op (much like SiLU before it was properly added to PyTorch, when all EfficientNets were slower than they could have been). torchscript and custom autograd can be used to get some performance back in quick_gelu, but those are a pain to maintain. Also, quick_gelu is definitely not a lossless replacement for a network trained with exact GELU. Not sure if any other approximations are?
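As an illustration of the custom-autograd route mentioned above, here is a minimal sketch (not PyTorch's implementation; the `QuickGELU` name is made up for this example) that recomputes the sigmoid in backward instead of keeping the intermediate activation alive:

```python
import torch

class QuickGELU(torch.autograd.Function):
    """x * sigmoid(1.702 * x) with a hand-written backward.
    Only the input is saved; the sigmoid is recomputed in backward,
    trading a little extra compute for lower activation memory."""

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x * torch.sigmoid(1.702 * x)

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        s = torch.sigmoid(1.702 * x)
        # d/dx [x * sigmoid(a*x)] = sigmoid(a*x) + a*x*sigmoid(a*x)*(1 - sigmoid(a*x))
        return grad_out * (s + 1.702 * x * s * (1.0 - s))

quick_gelu = QuickGELU.apply
```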

@hendrycks

hendrycks commented Jun 10, 2021

> Not sure if any other approximations are?

The other approximations are not, but slight fine-tuning would probably fix any slight degradation. While there is a tanh cubic approximation and a sigmoid approximation, the sigmoid might be preferable, largely because the OpenAI people chose that one over the tanh cubic approximation (their algorithms team keeps track of efficiency and knows what they're doing).

Looks like a native GELU approximation would be really useful. Thanks for sharing those numbers @rwightman!
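For a rough sense of how close the two approximations are to the exact function, a quick numerical check is easy to run (a sketch; the input range is arbitrary):

```python
import torch

x = torch.linspace(-6.0, 6.0, 10001, dtype=torch.double)

exact = x * 0.5 * (1.0 + torch.erf(x * 0.7071067811865476))
tanh_approx = 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * (x + 0.044715 * x ** 3)))
sigmoid_approx = x * torch.sigmoid(1.702 * x)

print('tanh cubic max abs error:', (tanh_approx - exact).abs().max().item())
print('sigmoid max abs error:   ', (sigmoid_approx - exact).abs().max().item())
```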

@ptrblck
Collaborator

ptrblck commented Jun 10, 2021

While chatting with @rwightman we've also talked about other approximations such as:

import torch

@torch.jit.script
def bias_gelu(x):
    # exact erf-based GELU; 0.70710678 ~= 1/sqrt(2)
    return x * 0.5 * (1.0 + torch.erf(x * 0.70710678))

@torch.jit.script
def bias_gelu_back(g, x):
    # hand-written backward: d/dx gelu(x) = Phi(x) + x * phi(x),
    # with 0.3989423 ~= 1/sqrt(2*pi)
    ff = 0.5 * (1. + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x)
    return ff * g

One concern I raised was the potential saturation of erf for x <= -4 and x >= 4, but I don't know if this would cause issues for commonly used models (or whether one approximation would be preferred over the other).
A potential approach would be to use an argument such as approximate=True/False (or a different op name).
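One way to sanity-check a hand-written backward like the pair above is torch.autograd.gradcheck; a minimal sketch (the `ErfGELU` wrapper and the inlined helpers are for illustration only and just repeat the formulas from the snippet above):

```python
import torch

def gelu_fwd(x):
    # same formula as bias_gelu above
    return x * 0.5 * (1.0 + torch.erf(x * 0.70710678))

def gelu_bwd(g, x):
    # same formula as bias_gelu_back above
    ff = 0.5 * (1.0 + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x)
    return ff * g

class ErfGELU(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return gelu_fwd(x)

    @staticmethod
    def backward(ctx, g):
        (x,) = ctx.saved_tensors
        return gelu_bwd(g, x)

# compares the hand-written backward against numerical gradients (double precision)
x = torch.randn(16, dtype=torch.double, requires_grad=True)
assert torch.autograd.gradcheck(ErfGELU.apply, (x,))
```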

@hendrycks

hendrycks commented Jun 10, 2021 via email

@albanD
Collaborator

albanD commented Jun 11, 2021

In the timings above, it also seems that AMP is computing the approximate version in half precision while the original is not.
Do you see the same results after updating AMP's custom list of ops that can run in half precision?
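A quick way to see which dtype F.gelu actually runs in under autocast (a sketch; the cast behavior depends on the PyTorch version and its autocast op lists):

```python
import torch
import torch.nn.functional as F

x = torch.randn(32, 128, device='cuda', dtype=torch.float16)
with torch.cuda.amp.autocast():
    print(F.gelu(x).dtype)  # float32 if gelu is on the fp32 autocast list, as reported above
    print(F.silu(x).dtype)  # float16 if silu runs natively in half precision
```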

@ezyang
Contributor

ezyang commented Jun 14, 2021

#59639 should help with this

@ezyang
Contributor

ezyang commented Jun 14, 2021

cc @mcarilli @Fuzzkatt

facebook-github-bot pushed a commit to pytorch/nestedtensor that referenced this issue Jan 28, 2022
Summary:
1. Implements pytorch/pytorch#39853
2. Adds approximate boolean flag to Gelu
3. Enables Tanh Gelu approximation
4. Adds double backward support for Gelu
5. Enable Tanh Gelu in NvFuser

```
def gelu(x, approximate : str = 'none'):
    if approximate == 'tanh':
        # sqrt(2/pi) = 0.7978845608028654
        return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * (x + 0.044715 * torch.pow(x, 3.0))))
    else:
        return x * normcdf(x)
```

Linking XLA PR - pytorch/xla#3039

Pull Request resolved: pytorch/pytorch#61439

Reviewed By: mikaylagawarecki

Differential Revision: D33744717

Pulled By: jbschlosser

fbshipit-source-id: d64532a562ed53247bb4fa52bb16722634d5c187
@vadimkantorov
Contributor Author

vadimkantorov commented Mar 22, 2022

Seems something may be off in HF testing at huggingface/transformers#15397 ...

jart added a commit to jart/llm.c that referenced this issue May 23, 2024
This change removes the tanh GeLU approximation. This gives us the
benefit of better accuracy, roughly equal perf and strict standard
conformance, since we no longer need any compiler-specific tricks.

Here's the last lines of train_gpt2 output before this change:

    step 37: train loss 3.739647 (took 598.548076 ms)
    step 38: train loss 4.611735 (took 596.626145 ms)
    step 39: train loss 3.970751 (took 598.439552 ms)
    val loss 4.016658
    generating:
    ---
    Come Running Away,
    Greater conquer
    With the Imperial blood
    the heaviest host of the gods
    into this wondrous world beyond.
    I will not back thee, for how sweet after birth
    Netflix against repounder,
    will not
    flourish against the earlocks of
    Allay
    ---
    step 40: train loss 4.377756 (took 592.704936 ms)

Here's the last lines of train_gpt2 output after this change:

    step 37: train loss 3.731596 (took 594.893995 ms)
    step 38: train loss 4.561646 (took 600.064035 ms)
    step 39: train loss 3.933512 (took 599.666173 ms)
    val loss 4.014135
    generating:
    ---
    Whether Hipocrates,
    Bigon Nicinius, or rep'd
    With Thy fair winter-tail your outraged hand,
    The richness of the good smour
    Nine years by turns covered my Member. Thou art
    Nay, I fear be; but
    Lets o' thee know, if it
    ---
    step 40: train loss 4.358461 (took 597.594065 ms)

This change has the disadvantage of diverging from PyTorch. I view
this as being justified and worthwhile, for numerous reasons, e.g.

  "I used the tanh approximation simply because the error function
   erf was slow in tensorflow some years ago. If the exact version
   is fast enough now and does not have numerical issues, I do not
   see a reason to use an inexact version."  ──Quoth Dan Hendrycks

See pytorch/pytorch#39853