Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PT_BREAK] Implement Tanh Gelu Approximation #3039

Merged
merged 7 commits into from
Feb 14, 2022

Conversation

rdspring1
Copy link
Contributor

@rdspring1 rdspring1 commented Jul 14, 2021

Enables Tanh Gelu approximation via integer argument
This is to stay in consistency with PyTorch core changes that are introducing in pytorch/pytorch#61439

@JackCaoG
Copy link
Collaborator

Thanks Ryan! Feel free to @ me for review when upstream pr is ready and this pr is ready to be reviewed.

@rdspring1 rdspring1 changed the title Implement Tanh Gelu Approximation [PT_BREAK] Implement Tanh Gelu Approximation Oct 7, 2021
@rdspring1 rdspring1 marked this pull request as ready for review October 8, 2021 00:40
facebook-github-bot pushed a commit to pytorch/nestedtensor that referenced this pull request Jan 28, 2022
Summary:
1. Implements pytorch/pytorch#39853
2. Adds approximate boolean flag to Gelu
3. Enables Tanh Gelu approximation
4. Adds double backward support for Gelu
5. Enable Tanh Gelu in NvFuser

```
def gelu(x, approximate : str = 'none'):
    if approximate == 'tanh':
        # sqrt(2/pi) = 0.7978845608028654
        return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * (x + 0.044715 * torch.pow(x, 3.0))))
    else:
        return x * normcdf(x)
```

Linking XLA PR - pytorch/xla#3039

Pull Request resolved: pytorch/pytorch#61439

Reviewed By: mikaylagawarecki

Differential Revision: D33744717

Pulled By: jbschlosser

fbshipit-source-id: d64532a562ed53247bb4fa52bb16722634d5c187
facebook-github-bot pushed a commit to pytorch/glow that referenced this pull request Jan 28, 2022
Summary:
1. Implements pytorch/pytorch#39853
2. Adds approximate boolean flag to Gelu
3. Enables Tanh Gelu approximation
4. Adds double backward support for Gelu
5. Enable Tanh Gelu in NvFuser

```
def gelu(x, approximate : str = 'none'):
    if approximate == 'tanh':
        # sqrt(2/pi) = 0.7978845608028654
        return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * (x + 0.044715 * torch.pow(x, 3.0))))
    else:
        return x * normcdf(x)
```

Linking XLA PR - pytorch/xla#3039

Pull Request resolved: pytorch/pytorch#61439

Reviewed By: mikaylagawarecki

Differential Revision: D33744717

Pulled By: jbschlosser

fbshipit-source-id: d64532a562ed53247bb4fa52bb16722634d5c187
facebook-github-bot pushed a commit to pytorch/pytorch that referenced this pull request Jan 28, 2022
Summary:
1. Implements #39853
2. Adds approximate boolean flag to Gelu
3. Enables Tanh Gelu approximation
4. Adds double backward support for Gelu
5. Enable Tanh Gelu in NvFuser

```
def gelu(x, approximate : str = 'none'):
    if approximate == 'tanh':
        # sqrt(2/pi) = 0.7978845608028654
        return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * (x + 0.044715 * torch.pow(x, 3.0))))
    else:
        return x * normcdf(x)
```

Linking XLA PR - pytorch/xla#3039

Pull Request resolved: #61439

Reviewed By: mikaylagawarecki

Differential Revision: D33744717

Pulled By: jbschlosser

fbshipit-source-id: d64532a562ed53247bb4fa52bb16722634d5c187
pytorchmergebot pushed a commit to pytorch/pytorch that referenced this pull request Jan 28, 2022
Summary:
1. Implements #39853
2. Adds approximate boolean flag to Gelu
3. Enables Tanh Gelu approximation
4. Adds double backward support for Gelu
5. Enable Tanh Gelu in NvFuser

```
def gelu(x, approximate : str = 'none'):
    if approximate == 'tanh':
        # sqrt(2/pi) = 0.7978845608028654
        return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * (x + 0.044715 * torch.pow(x, 3.0))))
    else:
        return x * normcdf(x)
```

Linking XLA PR - pytorch/xla#3039

Pull Request resolved: #61439

Reviewed By: mikaylagawarecki

Differential Revision: D33744717

Pulled By: jbschlosser

fbshipit-source-id: d64532a562ed53247bb4fa52bb16722634d5c187
(cherry picked from commit 4713dd9)
@rdspring1
Copy link
Contributor Author

@JackCaoG The upstream gelu PR was just merged, so this PR is ready for review.

@JackCaoG
Copy link
Collaborator

@rdspring1 Wait, does it mean if we don't merge this pr right now, CI will start fail for pytorch/xla?

@rdspring1
Copy link
Contributor Author

rdspring1 commented Jan 28, 2022

Well, the PR was just reverted. I need to rebase the latest changes. We've been working through some internal CI failures.

@JackCaoG
Copy link
Collaborator

@rdspring1 Please let me finish the review of this pr before merging the pytorch one. Ideally we should have both pr reviewed and merge at the same time to avoid the breakage.

Our head is broken due to the change

clang++-8 -MMD -MF /tmp/pytorch/xla/build/temp.linux-x86_64-3.7/torch_xla/csrc/aten_xla_type.o.d -Wsign-compare -DNDEBUG -g -fwrapv -O3 -Wall -Wstrict-prototypes -D_GLIBCXX_USE_CXX11_ABI=1 -fPIC -I/tmp/pytorch/xla -I/tmp/pytorch/xla/third_party/tensorflow/bazel-tensorflow -I/tmp/pytorch/xla/third_party/tensorflow/bazel-bin -I/tmp/pytorch/xla/third_party/tensorflow/bazel-tensorflow/external/protobuf_archive/src -I/tmp/pytorch/xla/third_party/tensorflow/bazel-tensorflow/external/com_google_protobuf/src -I/tmp/pytorch/xla/third_party/tensorflow/bazel-tensorflow/external/eigen_archive -I/tmp/pytorch/xla/third_party/tensorflow/bazel-tensorflow/external/com_google_absl -I/tmp/pytorch -I/tmp/pytorch/torch/csrc -I/tmp/pytorch/torch/lib/tmp_install/include -I/opt/conda/envs/pytorch/lib/python3.7/site-packages/torch/include -I/opt/conda/envs/pytorch/lib/python3.7/site-packages/torch/include/torch/csrc/api/include -I/opt/conda/envs/pytorch/lib/python3.7/site-packages/torch/include/TH -I/opt/conda/envs/pytorch/lib/python3.7/site-packages/torch/include/THC -I/opt/conda/envs/pytorch/include/python3.7m -c -c /tmp/pytorch/xla/torch_xla/csrc/aten_xla_type.cpp -o /tmp/pytorch/xla/build/temp.linux-x86_64-3.7/torch_xla/csrc/aten_xla_type.o -std=c++14 -Wno-sign-compare -Wno-deprecated-declarations -Wno-return-type -Wno-macro-redefined -Wno-return-std-move -DNDEBUG -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_clang"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1002"' -DTORCH_EXTENSION_NAME=_XLAC -D_GLIBCXX_USE_CXX11_ABI=1
/tmp/pytorch/xla/torch_xla/csrc/aten_xla_type.cpp:1511:32: error: out-of-line definition of 'gelu' does not match any declaration in 'torch_xla::XLANativeFunctions'
at::Tensor XLANativeFunctions::gelu(const at::Tensor& self) {
                               ^~~~
/tmp/pytorch/xla/torch_xla/csrc/aten_xla_type.cpp:1516:32: error: out-of-line definition of 'gelu_backward' does not match any declaration in 'torch_xla::XLANativeFunctions'
at::Tensor XLANativeFunctions::gelu_backward(const at::Tensor& grad

and I am triggering a new one since it is reverted.

Copy link
Collaborator

@JackCaoG JackCaoG left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Mostly LGTM, minor nits

if (approximate == kTanh) {
// inner = math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(input, 3))
// input * 0.5 * (1.0 + torch.tanh(inner))
const float kBeta = M_SQRT2 * M_2_SQRTPI * 0.5;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe better to mark it as static?

const int64_t kNone = 0;
const int64_t kTanh = 1;
if (approximate == kTanh) {
const float kBeta = M_SQRT2 * M_2_SQRTPI * 0.5;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto


NodePtr tanh_derivative = one - tanh_inner * tanh_inner;
NodePtr inner_derivative = beta * (one + three * kappa * Pow(input, two));
NodePtr right_derivative = left * tanh_derivative * inner_derivative;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hmm, it really should be lower as a node class and lowering function if it gets this complicated. I will let this one go since upstream pr needs to merge and fix it latter.

facebook-github-bot pushed a commit to pytorch/nestedtensor that referenced this pull request Jan 31, 2022
Summary:
1. Implements pytorch/pytorch#39853
2. Adds approximate boolean flag to Gelu
3. Enables Tanh Gelu approximation
4. Adds double backward support for Gelu
5. Enable Tanh Gelu in NvFuser

```
def gelu(x, approximate : str = 'none'):
    if approximate == 'tanh':
        # sqrt(2/pi) = 0.7978845608028654
        return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * (x + 0.044715 * torch.pow(x, 3.0))))
    else:
        return x * normcdf(x)
```

Linking XLA PR - pytorch/xla#3039

Pull Request resolved: pytorch/pytorch#61439

Reviewed By: cpuhrsch

Differential Revision: D33850228

Pulled By: jbschlosser

fbshipit-source-id: 3cc33fb298e480d7ecc5c67716da019d60c6ab33
facebook-github-bot pushed a commit to pytorch/glow that referenced this pull request Jan 31, 2022
Summary:
1. Implements pytorch/pytorch#39853
2. Adds approximate boolean flag to Gelu
3. Enables Tanh Gelu approximation
4. Adds double backward support for Gelu
5. Enable Tanh Gelu in NvFuser

```
def gelu(x, approximate : str = 'none'):
    if approximate == 'tanh':
        # sqrt(2/pi) = 0.7978845608028654
        return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * (x + 0.044715 * torch.pow(x, 3.0))))
    else:
        return x * normcdf(x)
```

Linking XLA PR - pytorch/xla#3039

Pull Request resolved: pytorch/pytorch#61439

Reviewed By: cpuhrsch

Differential Revision: D33850228

Pulled By: jbschlosser

fbshipit-source-id: 3cc33fb298e480d7ecc5c67716da019d60c6ab33
facebook-github-bot pushed a commit to pytorch/pytorch that referenced this pull request Jan 31, 2022
Summary:
1. Implements #39853
2. Adds approximate boolean flag to Gelu
3. Enables Tanh Gelu approximation
4. Adds double backward support for Gelu
5. Enable Tanh Gelu in NvFuser

```
def gelu(x, approximate : str = 'none'):
    if approximate == 'tanh':
        # sqrt(2/pi) = 0.7978845608028654
        return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * (x + 0.044715 * torch.pow(x, 3.0))))
    else:
        return x * normcdf(x)
```

Linking XLA PR - pytorch/xla#3039

Pull Request resolved: #61439

Reviewed By: cpuhrsch

Differential Revision: D33850228

Pulled By: jbschlosser

fbshipit-source-id: 3cc33fb298e480d7ecc5c67716da019d60c6ab33
pytorchmergebot pushed a commit to pytorch/pytorch that referenced this pull request Jan 31, 2022
Summary:
1. Implements #39853
2. Adds approximate boolean flag to Gelu
3. Enables Tanh Gelu approximation
4. Adds double backward support for Gelu
5. Enable Tanh Gelu in NvFuser

```
def gelu(x, approximate : str = 'none'):
    if approximate == 'tanh':
        # sqrt(2/pi) = 0.7978845608028654
        return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * (x + 0.044715 * torch.pow(x, 3.0))))
    else:
        return x * normcdf(x)
```

Linking XLA PR - pytorch/xla#3039

Pull Request resolved: #61439

Reviewed By: cpuhrsch

Differential Revision: D33850228

Pulled By: jbschlosser

fbshipit-source-id: 3cc33fb298e480d7ecc5c67716da019d60c6ab33
(cherry picked from commit 3a53b3e)
cpuhrsch pushed a commit to cpuhrsch/nestedtensor that referenced this pull request Feb 2, 2022
Summary:
1. Implements pytorch/pytorch#39853
2. Adds approximate boolean flag to Gelu
3. Enables Tanh Gelu approximation
4. Adds double backward support for Gelu
5. Enable Tanh Gelu in NvFuser

```
def gelu(x, approximate : str = 'none'):
    if approximate == 'tanh':
        # sqrt(2/pi) = 0.7978845608028654
        return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * (x + 0.044715 * torch.pow(x, 3.0))))
    else:
        return x * normcdf(x)
```

Linking XLA PR - pytorch/xla#3039

Pull Request resolved: pytorch/pytorch#61439

Reviewed By: mikaylagawarecki

Differential Revision: D33744717

Pulled By: jbschlosser

fbshipit-source-id: d64532a562ed53247bb4fa52bb16722634d5c187
cpuhrsch pushed a commit to cpuhrsch/nestedtensor that referenced this pull request Feb 2, 2022
Summary:
1. Implements pytorch/pytorch#39853
2. Adds approximate boolean flag to Gelu
3. Enables Tanh Gelu approximation
4. Adds double backward support for Gelu
5. Enable Tanh Gelu in NvFuser

```
def gelu(x, approximate : str = 'none'):
    if approximate == 'tanh':
        # sqrt(2/pi) = 0.7978845608028654
        return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * (x + 0.044715 * torch.pow(x, 3.0))))
    else:
        return x * normcdf(x)
```

Linking XLA PR - pytorch/xla#3039

Pull Request resolved: pytorch/pytorch#61439

Reviewed By: cpuhrsch

Differential Revision: D33850228

Pulled By: jbschlosser

fbshipit-source-id: 3cc33fb298e480d7ecc5c67716da019d60c6ab33
@JackCaoG
Copy link
Collaborator

JackCaoG commented Feb 2, 2022

@rdspring1 Will the upstream pytorch pr make it to the pytorch 1.11?

@rdspring1
Copy link
Contributor Author

rdspring1 commented Feb 3, 2022

No, I don't think so. There is some issue with the mobile builds.

facebook-github-bot pushed a commit to pytorch/glow that referenced this pull request Feb 14, 2022
Summary:
1. Implements pytorch/pytorch#39853
2. Adds approximate boolean flag to Gelu
3. Enables Tanh Gelu approximation
4. Adds double backward support for Gelu
5. Enable Tanh Gelu in NvFuser

```
def gelu(x, approximate : str = 'none'):
    if approximate == 'tanh':
        # sqrt(2/pi) = 0.7978845608028654
        return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * (x + 0.044715 * torch.pow(x, 3.0))))
    else:
        return x * normcdf(x)
```

Linking XLA PR - pytorch/xla#3039

Pull Request resolved: pytorch/pytorch#61439

Reviewed By: VitalyFedyunin

Differential Revision: D33894937

Pulled By: jbschlosser

fbshipit-source-id: b65e8fb6ea66168af8f34f45ed50e92737a33851
facebook-github-bot pushed a commit to pytorch/pytorch that referenced this pull request Feb 14, 2022
Summary:
1. Implements #39853
2. Adds approximate boolean flag to Gelu
3. Enables Tanh Gelu approximation
4. Adds double backward support for Gelu
5. Enable Tanh Gelu in NvFuser

```
def gelu(x, approximate : str = 'none'):
    if approximate == 'tanh':
        # sqrt(2/pi) = 0.7978845608028654
        return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * (x + 0.044715 * torch.pow(x, 3.0))))
    else:
        return x * normcdf(x)
```

Linking XLA PR - pytorch/xla#3039

Pull Request resolved: #61439

Reviewed By: VitalyFedyunin

Differential Revision: D33894937

Pulled By: jbschlosser

fbshipit-source-id: b65e8fb6ea66168af8f34f45ed50e92737a33851
pytorchmergebot pushed a commit to pytorch/pytorch that referenced this pull request Feb 14, 2022
Summary:
1. Implements #39853
2. Adds approximate boolean flag to Gelu
3. Enables Tanh Gelu approximation
4. Adds double backward support for Gelu
5. Enable Tanh Gelu in NvFuser

```
def gelu(x, approximate : str = 'none'):
    if approximate == 'tanh':
        # sqrt(2/pi) = 0.7978845608028654
        return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * (x + 0.044715 * torch.pow(x, 3.0))))
    else:
        return x * normcdf(x)
```

Linking XLA PR - pytorch/xla#3039

Pull Request resolved: #61439

Reviewed By: VitalyFedyunin

Differential Revision: D33894937

Pulled By: jbschlosser

fbshipit-source-id: b65e8fb6ea66168af8f34f45ed50e92737a33851
(cherry picked from commit 6e986f9)
@JackCaoG JackCaoG merged commit 82a110f into pytorch:master Feb 14, 2022
jbschlosser pushed a commit to jbschlosser/pytorch that referenced this pull request Feb 15, 2022
Summary:
1. Implements pytorch#39853
2. Adds approximate boolean flag to Gelu
3. Enables Tanh Gelu approximation
4. Adds double backward support for Gelu
5. Enable Tanh Gelu in NvFuser

```
def gelu(x, approximate : str = 'none'):
    if approximate == 'tanh':
        # sqrt(2/pi) = 0.7978845608028654
        return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * (x + 0.044715 * torch.pow(x, 3.0))))
    else:
        return x * normcdf(x)
```

Linking XLA PR - pytorch/xla#3039

Pull Request resolved: pytorch#61439

Reviewed By: VitalyFedyunin

Differential Revision: D33894937

Pulled By: jbschlosser

fbshipit-source-id: b65e8fb6ea66168af8f34f45ed50e92737a33851
(cherry picked from commit 6e986f9)
cyyever pushed a commit to cyyever/pytorch_private that referenced this pull request Feb 15, 2022
Summary:
1. Implements pytorch/pytorch#39853
2. Adds approximate boolean flag to Gelu
3. Enables Tanh Gelu approximation
4. Adds double backward support for Gelu
5. Enable Tanh Gelu in NvFuser

```
def gelu(x, approximate : str = 'none'):
    if approximate == 'tanh':
        # sqrt(2/pi) = 0.7978845608028654
        return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * (x + 0.044715 * torch.pow(x, 3.0))))
    else:
        return x * normcdf(x)
```

Linking XLA PR - pytorch/xla#3039

Pull Request resolved: pytorch/pytorch#61439

Reviewed By: VitalyFedyunin

Differential Revision: D33894937

Pulled By: jbschlosser

fbshipit-source-id: b65e8fb6ea66168af8f34f45ed50e92737a33851
(cherry picked from commit 6e986f91a958dd73514b4e64984c0b149157dc6f)
cyyever pushed a commit to cyyever/pytorch_private that referenced this pull request Feb 15, 2022
Summary:
1. Implements pytorch/pytorch#39853
2. Adds approximate boolean flag to Gelu
3. Enables Tanh Gelu approximation
4. Adds double backward support for Gelu
5. Enable Tanh Gelu in NvFuser

```
def gelu(x, approximate : str = 'none'):
    if approximate == 'tanh':
        # sqrt(2/pi) = 0.7978845608028654
        return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * (x + 0.044715 * torch.pow(x, 3.0))))
    else:
        return x * normcdf(x)
```

Linking XLA PR - pytorch/xla#3039

Pull Request resolved: pytorch/pytorch#61439

Reviewed By: VitalyFedyunin

Differential Revision: D33894937

Pulled By: jbschlosser

fbshipit-source-id: b65e8fb6ea66168af8f34f45ed50e92737a33851
(cherry picked from commit 6e986f91a958dd73514b4e64984c0b149157dc6f)
cyyever pushed a commit to cyyever/pytorch_private that referenced this pull request Feb 15, 2022
Summary:
1. Implements pytorch/pytorch#39853
2. Adds approximate boolean flag to Gelu
3. Enables Tanh Gelu approximation
4. Adds double backward support for Gelu
5. Enable Tanh Gelu in NvFuser

```
def gelu(x, approximate : str = 'none'):
    if approximate == 'tanh':
        # sqrt(2/pi) = 0.7978845608028654
        return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * (x + 0.044715 * torch.pow(x, 3.0))))
    else:
        return x * normcdf(x)
```

Linking XLA PR - pytorch/xla#3039

Pull Request resolved: pytorch/pytorch#61439

Reviewed By: VitalyFedyunin

Differential Revision: D33894937

Pulled By: jbschlosser

fbshipit-source-id: b65e8fb6ea66168af8f34f45ed50e92737a33851
(cherry picked from commit 6e986f91a958dd73514b4e64984c0b149157dc6f)
cyyever pushed a commit to cyyever/pytorch_private that referenced this pull request Feb 16, 2022
Summary:
1. Implements pytorch/pytorch#39853
2. Adds approximate boolean flag to Gelu
3. Enables Tanh Gelu approximation
4. Adds double backward support for Gelu
5. Enable Tanh Gelu in NvFuser

```
def gelu(x, approximate : str = 'none'):
    if approximate == 'tanh':
        # sqrt(2/pi) = 0.7978845608028654
        return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * (x + 0.044715 * torch.pow(x, 3.0))))
    else:
        return x * normcdf(x)
```

Linking XLA PR - pytorch/xla#3039

Pull Request resolved: pytorch/pytorch#61439

Reviewed By: VitalyFedyunin

Differential Revision: D33894937

Pulled By: jbschlosser

fbshipit-source-id: b65e8fb6ea66168af8f34f45ed50e92737a33851
(cherry picked from commit 6e986f91a958dd73514b4e64984c0b149157dc6f)
cyyever pushed a commit to cyyever/pytorch_private that referenced this pull request Feb 16, 2022
Summary:
1. Implements pytorch/pytorch#39853
2. Adds approximate boolean flag to Gelu
3. Enables Tanh Gelu approximation
4. Adds double backward support for Gelu
5. Enable Tanh Gelu in NvFuser

```
def gelu(x, approximate : str = 'none'):
    if approximate == 'tanh':
        # sqrt(2/pi) = 0.7978845608028654
        return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * (x + 0.044715 * torch.pow(x, 3.0))))
    else:
        return x * normcdf(x)
```

Linking XLA PR - pytorch/xla#3039

Pull Request resolved: pytorch/pytorch#61439

Reviewed By: VitalyFedyunin

Differential Revision: D33894937

Pulled By: jbschlosser

fbshipit-source-id: b65e8fb6ea66168af8f34f45ed50e92737a33851
(cherry picked from commit 6e986f91a958dd73514b4e64984c0b149157dc6f)
cyyever pushed a commit to cyyever/pytorch_private that referenced this pull request Feb 17, 2022
Summary:
1. Implements pytorch/pytorch#39853
2. Adds approximate boolean flag to Gelu
3. Enables Tanh Gelu approximation
4. Adds double backward support for Gelu
5. Enable Tanh Gelu in NvFuser

```
def gelu(x, approximate : str = 'none'):
    if approximate == 'tanh':
        # sqrt(2/pi) = 0.7978845608028654
        return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * (x + 0.044715 * torch.pow(x, 3.0))))
    else:
        return x * normcdf(x)
```

Linking XLA PR - pytorch/xla#3039

Pull Request resolved: pytorch/pytorch#61439

Reviewed By: VitalyFedyunin

Differential Revision: D33894937

Pulled By: jbschlosser

fbshipit-source-id: b65e8fb6ea66168af8f34f45ed50e92737a33851
(cherry picked from commit 6e986f91a958dd73514b4e64984c0b149157dc6f)
cyyever pushed a commit to cyyever/pytorch_private that referenced this pull request Feb 17, 2022
Summary:
1. Implements pytorch/pytorch#39853
2. Adds approximate boolean flag to Gelu
3. Enables Tanh Gelu approximation
4. Adds double backward support for Gelu
5. Enable Tanh Gelu in NvFuser

```
def gelu(x, approximate : str = 'none'):
    if approximate == 'tanh':
        # sqrt(2/pi) = 0.7978845608028654
        return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * (x + 0.044715 * torch.pow(x, 3.0))))
    else:
        return x * normcdf(x)
```

Linking XLA PR - pytorch/xla#3039

Pull Request resolved: pytorch/pytorch#61439

Reviewed By: VitalyFedyunin

Differential Revision: D33894937

Pulled By: jbschlosser

fbshipit-source-id: b65e8fb6ea66168af8f34f45ed50e92737a33851
(cherry picked from commit 6e986f91a958dd73514b4e64984c0b149157dc6f)
cyyever pushed a commit to cyyever/pytorch_private that referenced this pull request Feb 17, 2022
Summary:
1. Implements pytorch/pytorch#39853
2. Adds approximate boolean flag to Gelu
3. Enables Tanh Gelu approximation
4. Adds double backward support for Gelu
5. Enable Tanh Gelu in NvFuser

```
def gelu(x, approximate : str = 'none'):
    if approximate == 'tanh':
        # sqrt(2/pi) = 0.7978845608028654
        return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * (x + 0.044715 * torch.pow(x, 3.0))))
    else:
        return x * normcdf(x)
```

Linking XLA PR - pytorch/xla#3039

Pull Request resolved: pytorch/pytorch#61439

Reviewed By: VitalyFedyunin

Differential Revision: D33894937

Pulled By: jbschlosser

fbshipit-source-id: b65e8fb6ea66168af8f34f45ed50e92737a33851
(cherry picked from commit 6e986f91a958dd73514b4e64984c0b149157dc6f)
cyyever pushed a commit to cyyever/pytorch_private that referenced this pull request Feb 17, 2022
Summary:
1. Implements pytorch/pytorch#39853
2. Adds approximate boolean flag to Gelu
3. Enables Tanh Gelu approximation
4. Adds double backward support for Gelu
5. Enable Tanh Gelu in NvFuser

```
def gelu(x, approximate : str = 'none'):
    if approximate == 'tanh':
        # sqrt(2/pi) = 0.7978845608028654
        return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * (x + 0.044715 * torch.pow(x, 3.0))))
    else:
        return x * normcdf(x)
```

Linking XLA PR - pytorch/xla#3039

Pull Request resolved: pytorch/pytorch#61439

Reviewed By: VitalyFedyunin

Differential Revision: D33894937

Pulled By: jbschlosser

fbshipit-source-id: b65e8fb6ea66168af8f34f45ed50e92737a33851
(cherry picked from commit 6e986f91a958dd73514b4e64984c0b149157dc6f)
cyyever pushed a commit to cyyever/pytorch_private that referenced this pull request Feb 17, 2022
Summary:
1. Implements pytorch/pytorch#39853
2. Adds approximate boolean flag to Gelu
3. Enables Tanh Gelu approximation
4. Adds double backward support for Gelu
5. Enable Tanh Gelu in NvFuser

```
def gelu(x, approximate : str = 'none'):
    if approximate == 'tanh':
        # sqrt(2/pi) = 0.7978845608028654
        return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * (x + 0.044715 * torch.pow(x, 3.0))))
    else:
        return x * normcdf(x)
```

Linking XLA PR - pytorch/xla#3039

Pull Request resolved: pytorch/pytorch#61439

Reviewed By: VitalyFedyunin

Differential Revision: D33894937

Pulled By: jbschlosser

fbshipit-source-id: b65e8fb6ea66168af8f34f45ed50e92737a33851
(cherry picked from commit 6e986f91a958dd73514b4e64984c0b149157dc6f)
cyyever pushed a commit to cyyever/pytorch_private that referenced this pull request Feb 17, 2022
Summary:
1. Implements pytorch/pytorch#39853
2. Adds approximate boolean flag to Gelu
3. Enables Tanh Gelu approximation
4. Adds double backward support for Gelu
5. Enable Tanh Gelu in NvFuser

```
def gelu(x, approximate : str = 'none'):
    if approximate == 'tanh':
        # sqrt(2/pi) = 0.7978845608028654
        return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * (x + 0.044715 * torch.pow(x, 3.0))))
    else:
        return x * normcdf(x)
```

Linking XLA PR - pytorch/xla#3039

Pull Request resolved: pytorch/pytorch#61439

Reviewed By: VitalyFedyunin

Differential Revision: D33894937

Pulled By: jbschlosser

fbshipit-source-id: b65e8fb6ea66168af8f34f45ed50e92737a33851
(cherry picked from commit 6e986f91a958dd73514b4e64984c0b149157dc6f)
cyyever pushed a commit to cyyever/pytorch_private that referenced this pull request Feb 17, 2022
Summary:
1. Implements pytorch/pytorch#39853
2. Adds approximate boolean flag to Gelu
3. Enables Tanh Gelu approximation
4. Adds double backward support for Gelu
5. Enable Tanh Gelu in NvFuser

```
def gelu(x, approximate : str = 'none'):
    if approximate == 'tanh':
        # sqrt(2/pi) = 0.7978845608028654
        return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * (x + 0.044715 * torch.pow(x, 3.0))))
    else:
        return x * normcdf(x)
```

Linking XLA PR - pytorch/xla#3039

Pull Request resolved: pytorch/pytorch#61439

Reviewed By: VitalyFedyunin

Differential Revision: D33894937

Pulled By: jbschlosser

fbshipit-source-id: b65e8fb6ea66168af8f34f45ed50e92737a33851
(cherry picked from commit 6e986f91a958dd73514b4e64984c0b149157dc6f)
cyyever pushed a commit to cyyever/pytorch_private that referenced this pull request Feb 20, 2022
Summary:
1. Implements pytorch/pytorch#39853
2. Adds approximate boolean flag to Gelu
3. Enables Tanh Gelu approximation
4. Adds double backward support for Gelu
5. Enable Tanh Gelu in NvFuser

```
def gelu(x, approximate : str = 'none'):
    if approximate == 'tanh':
        # sqrt(2/pi) = 0.7978845608028654
        return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * (x + 0.044715 * torch.pow(x, 3.0))))
    else:
        return x * normcdf(x)
```

Linking XLA PR - pytorch/xla#3039

Pull Request resolved: pytorch/pytorch#61439

Reviewed By: VitalyFedyunin

Differential Revision: D33894937

Pulled By: jbschlosser

fbshipit-source-id: b65e8fb6ea66168af8f34f45ed50e92737a33851
(cherry picked from commit 6e986f91a958dd73514b4e64984c0b149157dc6f)
cyyever pushed a commit to cyyever/pytorch_private that referenced this pull request Feb 20, 2022
Summary:
1. Implements pytorch/pytorch#39853
2. Adds approximate boolean flag to Gelu
3. Enables Tanh Gelu approximation
4. Adds double backward support for Gelu
5. Enable Tanh Gelu in NvFuser

```
def gelu(x, approximate : str = 'none'):
    if approximate == 'tanh':
        # sqrt(2/pi) = 0.7978845608028654
        return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * (x + 0.044715 * torch.pow(x, 3.0))))
    else:
        return x * normcdf(x)
```

Linking XLA PR - pytorch/xla#3039

Pull Request resolved: pytorch/pytorch#61439

Reviewed By: VitalyFedyunin

Differential Revision: D33894937

Pulled By: jbschlosser

fbshipit-source-id: b65e8fb6ea66168af8f34f45ed50e92737a33851
(cherry picked from commit 6e986f91a958dd73514b4e64984c0b149157dc6f)
cyyever pushed a commit to cyyever/pytorch_private that referenced this pull request Feb 20, 2022
Summary:
1. Implements pytorch/pytorch#39853
2. Adds approximate boolean flag to Gelu
3. Enables Tanh Gelu approximation
4. Adds double backward support for Gelu
5. Enable Tanh Gelu in NvFuser

```
def gelu(x, approximate : str = 'none'):
    if approximate == 'tanh':
        # sqrt(2/pi) = 0.7978845608028654
        return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * (x + 0.044715 * torch.pow(x, 3.0))))
    else:
        return x * normcdf(x)
```

Linking XLA PR - pytorch/xla#3039

Pull Request resolved: pytorch/pytorch#61439

Reviewed By: VitalyFedyunin

Differential Revision: D33894937

Pulled By: jbschlosser

fbshipit-source-id: b65e8fb6ea66168af8f34f45ed50e92737a33851
(cherry picked from commit 6e986f91a958dd73514b4e64984c0b149157dc6f)
cyyever pushed a commit to cyyever/pytorch_private that referenced this pull request Feb 21, 2022
Summary:
1. Implements pytorch/pytorch#39853
2. Adds approximate boolean flag to Gelu
3. Enables Tanh Gelu approximation
4. Adds double backward support for Gelu
5. Enable Tanh Gelu in NvFuser

```
def gelu(x, approximate : str = 'none'):
    if approximate == 'tanh':
        # sqrt(2/pi) = 0.7978845608028654
        return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * (x + 0.044715 * torch.pow(x, 3.0))))
    else:
        return x * normcdf(x)
```

Linking XLA PR - pytorch/xla#3039

Pull Request resolved: pytorch/pytorch#61439

Reviewed By: VitalyFedyunin

Differential Revision: D33894937

Pulled By: jbschlosser

fbshipit-source-id: b65e8fb6ea66168af8f34f45ed50e92737a33851
(cherry picked from commit 6e986f91a958dd73514b4e64984c0b149157dc6f)
leezu pushed a commit to leezu/pytorch that referenced this pull request Mar 30, 2022
Summary:
1. Implements pytorch#39853
2. Adds approximate boolean flag to Gelu
3. Enables Tanh Gelu approximation
4. Adds double backward support for Gelu
5. Enable Tanh Gelu in NvFuser

```
def gelu(x, approximate : str = 'none'):
    if approximate == 'tanh':
        # sqrt(2/pi) = 0.7978845608028654
        return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * (x + 0.044715 * torch.pow(x, 3.0))))
    else:
        return x * normcdf(x)
```

Linking XLA PR - pytorch/xla#3039

Pull Request resolved: pytorch#61439

Reviewed By: VitalyFedyunin

Differential Revision: D33894937

Pulled By: jbschlosser

fbshipit-source-id: b65e8fb6ea66168af8f34f45ed50e92737a33851
jjsjann123 pushed a commit to jjsjann123/nvfuser that referenced this pull request Oct 29, 2022
Summary:
1. Implements pytorch/pytorch#39853
2. Adds approximate boolean flag to Gelu
3. Enables Tanh Gelu approximation
4. Adds double backward support for Gelu
5. Enable Tanh Gelu in NvFuser

```
def gelu(x, approximate : str = 'none'):
    if approximate == 'tanh':
        # sqrt(2/pi) = 0.7978845608028654
        return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * (x + 0.044715 * torch.pow(x, 3.0))))
    else:
        return x * normcdf(x)
```

Linking XLA PR - pytorch/xla#3039

Pull Request resolved: pytorch/pytorch#61439

Reviewed By: mikaylagawarecki

Differential Revision: D33744717

Pulled By: jbschlosser

fbshipit-source-id: d64532a562ed53247bb4fa52bb16722634d5c187
(cherry picked from commit 4713dd9ccaa8983422bf3aa7b73df8d9ebd8cc02)
jjsjann123 pushed a commit to jjsjann123/nvfuser that referenced this pull request Oct 29, 2022
Summary:
1. Implements pytorch/pytorch#39853
2. Adds approximate boolean flag to Gelu
3. Enables Tanh Gelu approximation
4. Adds double backward support for Gelu
5. Enable Tanh Gelu in NvFuser

```
def gelu(x, approximate : str = 'none'):
    if approximate == 'tanh':
        # sqrt(2/pi) = 0.7978845608028654
        return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * (x + 0.044715 * torch.pow(x, 3.0))))
    else:
        return x * normcdf(x)
```

Linking XLA PR - pytorch/xla#3039

Pull Request resolved: pytorch/pytorch#61439

Reviewed By: cpuhrsch

Differential Revision: D33850228

Pulled By: jbschlosser

fbshipit-source-id: 3cc33fb298e480d7ecc5c67716da019d60c6ab33
(cherry picked from commit 3a53b3e94fd58190d1261efd3cf41b53506fb96e)
jjsjann123 pushed a commit to jjsjann123/nvfuser that referenced this pull request Oct 29, 2022
Summary:
1. Implements pytorch/pytorch#39853
2. Adds approximate boolean flag to Gelu
3. Enables Tanh Gelu approximation
4. Adds double backward support for Gelu
5. Enable Tanh Gelu in NvFuser

```
def gelu(x, approximate : str = 'none'):
    if approximate == 'tanh':
        # sqrt(2/pi) = 0.7978845608028654
        return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * (x + 0.044715 * torch.pow(x, 3.0))))
    else:
        return x * normcdf(x)
```

Linking XLA PR - pytorch/xla#3039

Pull Request resolved: pytorch/pytorch#61439

Reviewed By: VitalyFedyunin

Differential Revision: D33894937

Pulled By: jbschlosser

fbshipit-source-id: b65e8fb6ea66168af8f34f45ed50e92737a33851
(cherry picked from commit 6e986f91a958dd73514b4e64984c0b149157dc6f)
jjsjann123 pushed a commit to jjsjann123/nvfuser that referenced this pull request Nov 10, 2022
Summary:
1. Implements pytorch/pytorch#39853
2. Adds approximate boolean flag to Gelu
3. Enables Tanh Gelu approximation
4. Adds double backward support for Gelu
5. Enable Tanh Gelu in NvFuser

```
def gelu(x, approximate : str = 'none'):
    if approximate == 'tanh':
        # sqrt(2/pi) = 0.7978845608028654
        return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * (x + 0.044715 * torch.pow(x, 3.0))))
    else:
        return x * normcdf(x)
```

Linking XLA PR - pytorch/xla#3039

Pull Request resolved: pytorch/pytorch#61439

Reviewed By: mikaylagawarecki

Differential Revision: D33744717

Pulled By: jbschlosser

fbshipit-source-id: d64532a562ed53247bb4fa52bb16722634d5c187
(cherry picked from commit 4713dd9ccaa8983422bf3aa7b73df8d9ebd8cc02)
jjsjann123 pushed a commit to jjsjann123/nvfuser that referenced this pull request Nov 10, 2022
Summary:
1. Implements pytorch/pytorch#39853
2. Adds approximate boolean flag to Gelu
3. Enables Tanh Gelu approximation
4. Adds double backward support for Gelu
5. Enable Tanh Gelu in NvFuser

```
def gelu(x, approximate : str = 'none'):
    if approximate == 'tanh':
        # sqrt(2/pi) = 0.7978845608028654
        return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * (x + 0.044715 * torch.pow(x, 3.0))))
    else:
        return x * normcdf(x)
```

Linking XLA PR - pytorch/xla#3039

Pull Request resolved: pytorch/pytorch#61439

Reviewed By: cpuhrsch

Differential Revision: D33850228

Pulled By: jbschlosser

fbshipit-source-id: 3cc33fb298e480d7ecc5c67716da019d60c6ab33
(cherry picked from commit 3a53b3e94fd58190d1261efd3cf41b53506fb96e)
jjsjann123 pushed a commit to jjsjann123/nvfuser that referenced this pull request Nov 10, 2022
Summary:
1. Implements pytorch/pytorch#39853
2. Adds approximate boolean flag to Gelu
3. Enables Tanh Gelu approximation
4. Adds double backward support for Gelu
5. Enable Tanh Gelu in NvFuser

```
def gelu(x, approximate : str = 'none'):
    if approximate == 'tanh':
        # sqrt(2/pi) = 0.7978845608028654
        return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * (x + 0.044715 * torch.pow(x, 3.0))))
    else:
        return x * normcdf(x)
```

Linking XLA PR - pytorch/xla#3039

Pull Request resolved: pytorch/pytorch#61439

Reviewed By: VitalyFedyunin

Differential Revision: D33894937

Pulled By: jbschlosser

fbshipit-source-id: b65e8fb6ea66168af8f34f45ed50e92737a33851
(cherry picked from commit 6e986f91a958dd73514b4e64984c0b149157dc6f)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

2 participants