Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement Tanh Gelu Approximation #61439

Closed
wants to merge 64 commits into from
Closed

Conversation

rdspring1
Copy link
Contributor

@rdspring1 rdspring1 commented Jul 9, 2021

  1. Implements [proposal] Add approx variant option to F.gelu #39853
  2. Adds approximate boolean flag to Gelu
  3. Enables Tanh Gelu approximation
  4. Adds double backward support for Gelu
  5. Enable Tanh Gelu in NvFuser
def gelu(x, approximate : str = 'none'):
    if approximate == 'tanh':
        # sqrt(2/pi) = 0.7978845608028654
        return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * (x + 0.044715 * torch.pow(x, 3.0))))
    else:
        return x * normcdf(x)

Linking XLA PR - pytorch/xla#3039

Add fast-gelu implementation for CPU and CUDA
@facebook-github-bot facebook-github-bot added oncall: jit Add this issue/PR to JIT oncall triage queue cla signed labels Jul 9, 2021
@facebook-github-bot
Copy link
Contributor

facebook-github-bot commented Jul 9, 2021

🔗 Helpful links

💊 CI failures summary and remediations

As of commit fbd5e62 (more details on the Dr. CI page):


  • 1/1 failures introduced in this PR

🕵️ 1 new failure recognized by patterns

The following CI failures do not appear to be due to upstream breakages:

See GitHub Actions build linux-bionic-py3.7-clang9 / test (xla, 1, 1, linux.2xlarge) (1/1)

Step: "Test" (full log | diagnosis details | 🔁 rerun)

2022-02-10T00:06:50.3252013Z /var/lib/jenkins/w... at::Tensor&, const at::Tensor&, c10::string_view)
2022-02-10T00:06:50.3245349Z  at::Tensor XLANativeFunctions::gelu(const at::Tensor& self) {
2022-02-10T00:06:50.3245747Z             ^~~~~~~~~~~~~~~~~~
2022-02-10T00:06:50.3246216Z In file included from /var/lib/jenkins/workspace/xla/torch_xla/csrc/aten_xla_type.cpp:13:0:
2022-02-10T00:06:50.3246991Z /var/lib/jenkins/workspace/xla/torch_xla/csrc/XLANativeFunctions.h:182:19: error: candidate is: static at::Tensor torch_xla::XLANativeFunctions::gelu(const at::Tensor&, c10::string_view)
2022-02-10T00:06:50.3247733Z  static at::Tensor gelu(const at::Tensor & self, c10::string_view approximate);
2022-02-10T00:06:50.3248137Z                    ^~~~
2022-02-10T00:06:50.3249472Z /var/lib/jenkins/workspace/xla/torch_xla/csrc/aten_xla_type.cpp:1514:12: error: prototype for ‘at::Tensor torch_xla::XLANativeFunctions::gelu_backward(const at::Tensor&, const at::Tensor&)’ does not match any in class ‘torch_xla::XLANativeFunctions’
2022-02-10T00:06:50.3250337Z  at::Tensor XLANativeFunctions::gelu_backward(const at::Tensor& grad,
2022-02-10T00:06:50.3250740Z             ^~~~~~~~~~~~~~~~~~
2022-02-10T00:06:50.3251193Z In file included from /var/lib/jenkins/workspace/xla/torch_xla/csrc/aten_xla_type.cpp:13:0:
2022-02-10T00:06:50.3252013Z /var/lib/jenkins/workspace/xla/torch_xla/csrc/XLANativeFunctions.h:183:19: error: candidate is: static at::Tensor torch_xla::XLANativeFunctions::gelu_backward(const at::Tensor&, const at::Tensor&, c10::string_view)
2022-02-10T00:06:50.3252991Z  static at::Tensor gelu_backward(const at::Tensor & grad_output, const at::Tensor & self, c10::string_view approximate);
2022-02-10T00:06:50.3253446Z                    ^~~~~~~~~~~~~
2022-02-10T00:06:51.2681160Z [34/178] c++ -MMD -MF /var/lib/jenkins/workspace/xla/build/temp.linux-x86_64-3.7/torch_xla/csrc/tensor_methods.o.d -pthread -B /opt/conda/compiler_compat -Wl,--sysroot=/ -Wsign-compare -DNDEBUG -g -fwrapv -O3 -Wall -Wstrict-prototypes -fPIC -I/var/lib/jenkins/workspace/xla -I/var/lib/jenkins/workspace/xla/third_party/tensorflow/bazel-tensorflow -I/var/lib/jenkins/workspace/xla/third_party/tensorflow/bazel-bin -I/var/lib/jenkins/workspace/xla/third_party/tensorflow/bazel-tensorflow/external/protobuf_archive/src -I/var/lib/jenkins/workspace/xla/third_party/tensorflow/bazel-tensorflow/external/com_google_protobuf/src -I/var/lib/jenkins/workspace/xla/third_party/tensorflow/bazel-tensorflow/external/eigen_archive -I/var/lib/jenkins/workspace/xla/third_party/tensorflow/bazel-tensorflow/external/com_google_absl -I/var/lib/jenkins/workspace -I/var/lib/jenkins/workspace/torch/csrc -I/var/lib/jenkins/workspace/torch/lib/tmp_install/include -I/opt/conda/lib/python3.7/site-packages/torch/include -I/opt/conda/lib/python3.7/site-packages/torch/include/torch/csrc/api/include -I/opt/conda/lib/python3.7/site-packages/torch/include/TH -I/opt/conda/lib/python3.7/site-packages/torch/include/THC -I/opt/conda/include/python3.7m -c -c /var/lib/jenkins/workspace/xla/torch_xla/csrc/tensor_methods.cpp -o /var/lib/jenkins/workspace/xla/build/temp.linux-x86_64-3.7/torch_xla/csrc/tensor_methods.o -std=c++14 -Wno-sign-compare -Wno-deprecated-declarations -Wno-return-type -DNDEBUG -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_clang"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1002"' -DTORCH_EXTENSION_NAME=_XLAC -D_GLIBCXX_USE_CXX11_ABI=1
2022-02-10T00:06:51.2684947Z cc1plus: warning: command line option ‘-Wstrict-prototypes’ is valid for C/ObjC but not for C++
2022-02-10T00:06:51.2685762Z In file included from /var/lib/jenkins/workspace/c10/util/Logging.h:28:0,
2022-02-10T00:06:51.2686191Z                  from /var/lib/jenkins/workspace/c10/core/TensorImpl.h:14,
2022-02-10T00:06:51.2686792Z                  from /opt/conda/lib/python3.7/site-packages/torch/include/ATen/core/TensorBody.h:21,
2022-02-10T00:06:51.2687398Z                  from /opt/conda/lib/python3.7/site-packages/torch/include/ATen/Tensor.h:3,
2022-02-10T00:06:51.2687875Z                  from /var/lib/jenkins/workspace/torch/csrc/autograd/function_hook.h:5,
2022-02-10T00:06:51.2688320Z                  from /var/lib/jenkins/workspace/torch/csrc/autograd/variable.h:7,

This comment was automatically generated by Dr. CI (expand for details).

Please report bugs/suggestions to the (internal) Dr. CI Users group.

Click here to manually regenerate this comment.

@vadimkantorov
Copy link
Contributor

vadimkantorov commented Jul 9, 2021

From what I understood, there are now multiple approximations in use in third-party frameworks. Should the arg then be called approx_tanh/tanh_approx/approx='tanh'? (extensible if the other approx variant gets supported in the future)

Add approximate argument to JIT symbolic script + testing

Update Gelu documenation
@vadimkantorov
Copy link
Contributor

vadimkantorov commented Jul 12, 2021

From what I understood, there are now multiple approximations in use in third-party frameworks. Should the arg then be called approx_tanh/tanh_approx/approx='tanh'? (extensible if the other approx variant gets supported in the future)

@rwightman @hendrycks @ptrblck Could you please comment on this?

@hendrycks
Copy link

Possibilities:
It could be approx=True, which would then default to a sigmoid approximation.
approx=True -> sigmoid approximation

Or there could be both the sigmoid (for lowest memory usage) and tanh approximation (for cross-compatibility with tensorflow).
approx='tanh', approx='sigmoid', approx=True (equivalent to approx='sigmoid')

@jjsjann123
Copy link
Collaborator

FYI, backward_compatibility_check_test is just a warning that you are changing the signature and need to add the function to an allow list.

@rwightman
Copy link

rwightman commented Jul 14, 2021

@vadimkantorov @hendrycks re the arg, if it is going to be an arg I think the string approach is good as there are different users/use cases of both approx out there, probably don't need both the bool + string overload...

However, personally I don't see much point in squishing it into one fn. All 3 (original, sigmoid approx, tanh approx) aren't numerically compatible with each other (ie you can't just change the flag and expect a trained network to work the same, you have to fine-tune). I think it's more clear in that respect if they are separate activation instances. Having it as an arg isn't going to save one code/effort as it'll require a partial wrap anyways, might as well make it gelu, gelu_approx_sigmoid, gelu_approx_tanh ....

@hendrycks
Copy link

hendrycks commented Jul 14, 2021

However, personally I don't see much point in squishing it into one fn.

Tensorflow/keras put it in one function. In other modules you can customize them to get slightly different behavior (e.g., nn.Linear(..., bias=False), torch.svd_lowrank(A, q=6, niter=2)). The other advantage is that it doesn't clutter the PyTorch namespace as much. I don't have a strong feeling either way, but argument options feel cleaner to me.

probably don't need both the bool + string overload

Yes, so it could be approx='sigmoid', approx='tanh' (and maybe approx='True' depending on what others think).

cyyever pushed a commit to cyyever/pytorch_private that referenced this pull request Feb 17, 2022
Summary:
Things changed in this PR that requires review:
1. aten/src/ATen/core/interned_strings.h
2. torch/csrc/jit/ir/alias_analysis.h : exposing createValue to allow efficient mutation
3. torch/csrc/jit/runtime/symbolic_shape_registry.cpp : added gelu/tanh/erf in registry
4. torch/jit/_script.py : throws scripting model sees autocast as decorator since it's not supported

nvfuser code update:
1. codegen improvements and performance tuning
2. integration bug fixes for shape expression logic
3. kernel segmentation update to address perf regression from horizontal fusion
4. scalar cpu tensor promotion to support inter-device operation between cpu scalar tensor and cuda tensor

Things reverted from local changes:
aten::gelu with approximation (tracked in PR: pytorch/pytorch#61439)

Pull Request resolved: pytorch/pytorch#72127

Reviewed By: HamidShojanazeri

Differential Revision: D34113233

Pulled By: jbschlosser

fbshipit-source-id: b82cde32b71e324eca0ea57cb8c9f9647278ca74
(cherry picked from commit e009bc5c4e943211c4953e6fdf7c9913fa66b3c9)
cyyever pushed a commit to cyyever/pytorch_private that referenced this pull request Feb 17, 2022
Summary:
1. Implements pytorch/pytorch#39853
2. Adds approximate boolean flag to Gelu
3. Enables Tanh Gelu approximation
4. Adds double backward support for Gelu
5. Enable Tanh Gelu in NvFuser

```
def gelu(x, approximate : str = 'none'):
    if approximate == 'tanh':
        # sqrt(2/pi) = 0.7978845608028654
        return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * (x + 0.044715 * torch.pow(x, 3.0))))
    else:
        return x * normcdf(x)
```

Linking XLA PR - pytorch/xla#3039

Pull Request resolved: pytorch/pytorch#61439

Reviewed By: VitalyFedyunin

Differential Revision: D33894937

Pulled By: jbschlosser

fbshipit-source-id: b65e8fb6ea66168af8f34f45ed50e92737a33851
(cherry picked from commit 6e986f91a958dd73514b4e64984c0b149157dc6f)
cyyever pushed a commit to cyyever/pytorch_private that referenced this pull request Feb 17, 2022
Summary:
Things changed in this PR that requires review:
1. aten/src/ATen/core/interned_strings.h
2. torch/csrc/jit/ir/alias_analysis.h : exposing createValue to allow efficient mutation
3. torch/csrc/jit/runtime/symbolic_shape_registry.cpp : added gelu/tanh/erf in registry
4. torch/jit/_script.py : throws scripting model sees autocast as decorator since it's not supported

nvfuser code update:
1. codegen improvements and performance tuning
2. integration bug fixes for shape expression logic
3. kernel segmentation update to address perf regression from horizontal fusion
4. scalar cpu tensor promotion to support inter-device operation between cpu scalar tensor and cuda tensor

Things reverted from local changes:
aten::gelu with approximation (tracked in PR: pytorch/pytorch#61439)

Pull Request resolved: pytorch/pytorch#72127

Reviewed By: HamidShojanazeri

Differential Revision: D34113233

Pulled By: jbschlosser

fbshipit-source-id: b82cde32b71e324eca0ea57cb8c9f9647278ca74
(cherry picked from commit e009bc5c4e943211c4953e6fdf7c9913fa66b3c9)
cyyever pushed a commit to cyyever/pytorch_private that referenced this pull request Feb 17, 2022
Summary:
1. Implements pytorch/pytorch#39853
2. Adds approximate boolean flag to Gelu
3. Enables Tanh Gelu approximation
4. Adds double backward support for Gelu
5. Enable Tanh Gelu in NvFuser

```
def gelu(x, approximate : str = 'none'):
    if approximate == 'tanh':
        # sqrt(2/pi) = 0.7978845608028654
        return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * (x + 0.044715 * torch.pow(x, 3.0))))
    else:
        return x * normcdf(x)
```

Linking XLA PR - pytorch/xla#3039

Pull Request resolved: pytorch/pytorch#61439

Reviewed By: VitalyFedyunin

Differential Revision: D33894937

Pulled By: jbschlosser

fbshipit-source-id: b65e8fb6ea66168af8f34f45ed50e92737a33851
(cherry picked from commit 6e986f91a958dd73514b4e64984c0b149157dc6f)
cyyever pushed a commit to cyyever/pytorch_private that referenced this pull request Feb 17, 2022
Summary:
Things changed in this PR that requires review:
1. aten/src/ATen/core/interned_strings.h
2. torch/csrc/jit/ir/alias_analysis.h : exposing createValue to allow efficient mutation
3. torch/csrc/jit/runtime/symbolic_shape_registry.cpp : added gelu/tanh/erf in registry
4. torch/jit/_script.py : throws scripting model sees autocast as decorator since it's not supported

nvfuser code update:
1. codegen improvements and performance tuning
2. integration bug fixes for shape expression logic
3. kernel segmentation update to address perf regression from horizontal fusion
4. scalar cpu tensor promotion to support inter-device operation between cpu scalar tensor and cuda tensor

Things reverted from local changes:
aten::gelu with approximation (tracked in PR: pytorch/pytorch#61439)

Pull Request resolved: pytorch/pytorch#72127

Reviewed By: HamidShojanazeri

Differential Revision: D34113233

Pulled By: jbschlosser

fbshipit-source-id: b82cde32b71e324eca0ea57cb8c9f9647278ca74
(cherry picked from commit e009bc5c4e943211c4953e6fdf7c9913fa66b3c9)
cyyever pushed a commit to cyyever/pytorch_private that referenced this pull request Feb 17, 2022
Summary:
1. Implements pytorch/pytorch#39853
2. Adds approximate boolean flag to Gelu
3. Enables Tanh Gelu approximation
4. Adds double backward support for Gelu
5. Enable Tanh Gelu in NvFuser

```
def gelu(x, approximate : str = 'none'):
    if approximate == 'tanh':
        # sqrt(2/pi) = 0.7978845608028654
        return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * (x + 0.044715 * torch.pow(x, 3.0))))
    else:
        return x * normcdf(x)
```

Linking XLA PR - pytorch/xla#3039

Pull Request resolved: pytorch/pytorch#61439

Reviewed By: VitalyFedyunin

Differential Revision: D33894937

Pulled By: jbschlosser

fbshipit-source-id: b65e8fb6ea66168af8f34f45ed50e92737a33851
(cherry picked from commit 6e986f91a958dd73514b4e64984c0b149157dc6f)
cyyever pushed a commit to cyyever/pytorch_private that referenced this pull request Feb 17, 2022
Summary:
Things changed in this PR that requires review:
1. aten/src/ATen/core/interned_strings.h
2. torch/csrc/jit/ir/alias_analysis.h : exposing createValue to allow efficient mutation
3. torch/csrc/jit/runtime/symbolic_shape_registry.cpp : added gelu/tanh/erf in registry
4. torch/jit/_script.py : throws scripting model sees autocast as decorator since it's not supported

nvfuser code update:
1. codegen improvements and performance tuning
2. integration bug fixes for shape expression logic
3. kernel segmentation update to address perf regression from horizontal fusion
4. scalar cpu tensor promotion to support inter-device operation between cpu scalar tensor and cuda tensor

Things reverted from local changes:
aten::gelu with approximation (tracked in PR: pytorch/pytorch#61439)

Pull Request resolved: pytorch/pytorch#72127

Reviewed By: HamidShojanazeri

Differential Revision: D34113233

Pulled By: jbschlosser

fbshipit-source-id: b82cde32b71e324eca0ea57cb8c9f9647278ca74
(cherry picked from commit e009bc5c4e943211c4953e6fdf7c9913fa66b3c9)
cyyever pushed a commit to cyyever/pytorch_private that referenced this pull request Feb 17, 2022
Summary:
1. Implements pytorch/pytorch#39853
2. Adds approximate boolean flag to Gelu
3. Enables Tanh Gelu approximation
4. Adds double backward support for Gelu
5. Enable Tanh Gelu in NvFuser

```
def gelu(x, approximate : str = 'none'):
    if approximate == 'tanh':
        # sqrt(2/pi) = 0.7978845608028654
        return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * (x + 0.044715 * torch.pow(x, 3.0))))
    else:
        return x * normcdf(x)
```

Linking XLA PR - pytorch/xla#3039

Pull Request resolved: pytorch/pytorch#61439

Reviewed By: VitalyFedyunin

Differential Revision: D33894937

Pulled By: jbschlosser

fbshipit-source-id: b65e8fb6ea66168af8f34f45ed50e92737a33851
(cherry picked from commit 6e986f91a958dd73514b4e64984c0b149157dc6f)
cyyever pushed a commit to cyyever/pytorch_private that referenced this pull request Feb 17, 2022
Summary:
Things changed in this PR that requires review:
1. aten/src/ATen/core/interned_strings.h
2. torch/csrc/jit/ir/alias_analysis.h : exposing createValue to allow efficient mutation
3. torch/csrc/jit/runtime/symbolic_shape_registry.cpp : added gelu/tanh/erf in registry
4. torch/jit/_script.py : throws scripting model sees autocast as decorator since it's not supported

nvfuser code update:
1. codegen improvements and performance tuning
2. integration bug fixes for shape expression logic
3. kernel segmentation update to address perf regression from horizontal fusion
4. scalar cpu tensor promotion to support inter-device operation between cpu scalar tensor and cuda tensor

Things reverted from local changes:
aten::gelu with approximation (tracked in PR: pytorch/pytorch#61439)

Pull Request resolved: pytorch/pytorch#72127

Reviewed By: HamidShojanazeri

Differential Revision: D34113233

Pulled By: jbschlosser

fbshipit-source-id: b82cde32b71e324eca0ea57cb8c9f9647278ca74
(cherry picked from commit e009bc5c4e943211c4953e6fdf7c9913fa66b3c9)
cyyever pushed a commit to cyyever/pytorch_private that referenced this pull request Feb 20, 2022
Summary:
1. Implements pytorch/pytorch#39853
2. Adds approximate boolean flag to Gelu
3. Enables Tanh Gelu approximation
4. Adds double backward support for Gelu
5. Enable Tanh Gelu in NvFuser

```
def gelu(x, approximate : str = 'none'):
    if approximate == 'tanh':
        # sqrt(2/pi) = 0.7978845608028654
        return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * (x + 0.044715 * torch.pow(x, 3.0))))
    else:
        return x * normcdf(x)
```

Linking XLA PR - pytorch/xla#3039

Pull Request resolved: pytorch/pytorch#61439

Reviewed By: VitalyFedyunin

Differential Revision: D33894937

Pulled By: jbschlosser

fbshipit-source-id: b65e8fb6ea66168af8f34f45ed50e92737a33851
(cherry picked from commit 6e986f91a958dd73514b4e64984c0b149157dc6f)
cyyever pushed a commit to cyyever/pytorch_private that referenced this pull request Feb 20, 2022
Summary:
Things changed in this PR that requires review:
1. aten/src/ATen/core/interned_strings.h
2. torch/csrc/jit/ir/alias_analysis.h : exposing createValue to allow efficient mutation
3. torch/csrc/jit/runtime/symbolic_shape_registry.cpp : added gelu/tanh/erf in registry
4. torch/jit/_script.py : throws scripting model sees autocast as decorator since it's not supported

nvfuser code update:
1. codegen improvements and performance tuning
2. integration bug fixes for shape expression logic
3. kernel segmentation update to address perf regression from horizontal fusion
4. scalar cpu tensor promotion to support inter-device operation between cpu scalar tensor and cuda tensor

Things reverted from local changes:
aten::gelu with approximation (tracked in PR: pytorch/pytorch#61439)

Pull Request resolved: pytorch/pytorch#72127

Reviewed By: HamidShojanazeri

Differential Revision: D34113233

Pulled By: jbschlosser

fbshipit-source-id: b82cde32b71e324eca0ea57cb8c9f9647278ca74
(cherry picked from commit e009bc5c4e943211c4953e6fdf7c9913fa66b3c9)
cyyever pushed a commit to cyyever/pytorch_private that referenced this pull request Feb 20, 2022
Summary:
1. Implements pytorch/pytorch#39853
2. Adds approximate boolean flag to Gelu
3. Enables Tanh Gelu approximation
4. Adds double backward support for Gelu
5. Enable Tanh Gelu in NvFuser

```
def gelu(x, approximate : str = 'none'):
    if approximate == 'tanh':
        # sqrt(2/pi) = 0.7978845608028654
        return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * (x + 0.044715 * torch.pow(x, 3.0))))
    else:
        return x * normcdf(x)
```

Linking XLA PR - pytorch/xla#3039

Pull Request resolved: pytorch/pytorch#61439

Reviewed By: VitalyFedyunin

Differential Revision: D33894937

Pulled By: jbschlosser

fbshipit-source-id: b65e8fb6ea66168af8f34f45ed50e92737a33851
(cherry picked from commit 6e986f91a958dd73514b4e64984c0b149157dc6f)
cyyever pushed a commit to cyyever/pytorch_private that referenced this pull request Feb 20, 2022
Summary:
Things changed in this PR that requires review:
1. aten/src/ATen/core/interned_strings.h
2. torch/csrc/jit/ir/alias_analysis.h : exposing createValue to allow efficient mutation
3. torch/csrc/jit/runtime/symbolic_shape_registry.cpp : added gelu/tanh/erf in registry
4. torch/jit/_script.py : throws scripting model sees autocast as decorator since it's not supported

nvfuser code update:
1. codegen improvements and performance tuning
2. integration bug fixes for shape expression logic
3. kernel segmentation update to address perf regression from horizontal fusion
4. scalar cpu tensor promotion to support inter-device operation between cpu scalar tensor and cuda tensor

Things reverted from local changes:
aten::gelu with approximation (tracked in PR: pytorch/pytorch#61439)

Pull Request resolved: pytorch/pytorch#72127

Reviewed By: HamidShojanazeri

Differential Revision: D34113233

Pulled By: jbschlosser

fbshipit-source-id: b82cde32b71e324eca0ea57cb8c9f9647278ca74
(cherry picked from commit e009bc5c4e943211c4953e6fdf7c9913fa66b3c9)
cyyever pushed a commit to cyyever/pytorch_private that referenced this pull request Feb 20, 2022
Summary:
1. Implements pytorch/pytorch#39853
2. Adds approximate boolean flag to Gelu
3. Enables Tanh Gelu approximation
4. Adds double backward support for Gelu
5. Enable Tanh Gelu in NvFuser

```
def gelu(x, approximate : str = 'none'):
    if approximate == 'tanh':
        # sqrt(2/pi) = 0.7978845608028654
        return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * (x + 0.044715 * torch.pow(x, 3.0))))
    else:
        return x * normcdf(x)
```

Linking XLA PR - pytorch/xla#3039

Pull Request resolved: pytorch/pytorch#61439

Reviewed By: VitalyFedyunin

Differential Revision: D33894937

Pulled By: jbschlosser

fbshipit-source-id: b65e8fb6ea66168af8f34f45ed50e92737a33851
(cherry picked from commit 6e986f91a958dd73514b4e64984c0b149157dc6f)
cyyever pushed a commit to cyyever/pytorch_private that referenced this pull request Feb 20, 2022
Summary:
Things changed in this PR that requires review:
1. aten/src/ATen/core/interned_strings.h
2. torch/csrc/jit/ir/alias_analysis.h : exposing createValue to allow efficient mutation
3. torch/csrc/jit/runtime/symbolic_shape_registry.cpp : added gelu/tanh/erf in registry
4. torch/jit/_script.py : throws scripting model sees autocast as decorator since it's not supported

nvfuser code update:
1. codegen improvements and performance tuning
2. integration bug fixes for shape expression logic
3. kernel segmentation update to address perf regression from horizontal fusion
4. scalar cpu tensor promotion to support inter-device operation between cpu scalar tensor and cuda tensor

Things reverted from local changes:
aten::gelu with approximation (tracked in PR: pytorch/pytorch#61439)

Pull Request resolved: pytorch/pytorch#72127

Reviewed By: HamidShojanazeri

Differential Revision: D34113233

Pulled By: jbschlosser

fbshipit-source-id: b82cde32b71e324eca0ea57cb8c9f9647278ca74
(cherry picked from commit e009bc5c4e943211c4953e6fdf7c9913fa66b3c9)
cyyever pushed a commit to cyyever/pytorch_private that referenced this pull request Feb 21, 2022
Summary:
1. Implements pytorch/pytorch#39853
2. Adds approximate boolean flag to Gelu
3. Enables Tanh Gelu approximation
4. Adds double backward support for Gelu
5. Enable Tanh Gelu in NvFuser

```
def gelu(x, approximate : str = 'none'):
    if approximate == 'tanh':
        # sqrt(2/pi) = 0.7978845608028654
        return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * (x + 0.044715 * torch.pow(x, 3.0))))
    else:
        return x * normcdf(x)
```

Linking XLA PR - pytorch/xla#3039

Pull Request resolved: pytorch/pytorch#61439

Reviewed By: VitalyFedyunin

Differential Revision: D33894937

Pulled By: jbschlosser

fbshipit-source-id: b65e8fb6ea66168af8f34f45ed50e92737a33851
(cherry picked from commit 6e986f91a958dd73514b4e64984c0b149157dc6f)
cyyever pushed a commit to cyyever/pytorch_private that referenced this pull request Feb 21, 2022
Summary:
Things changed in this PR that requires review:
1. aten/src/ATen/core/interned_strings.h
2. torch/csrc/jit/ir/alias_analysis.h : exposing createValue to allow efficient mutation
3. torch/csrc/jit/runtime/symbolic_shape_registry.cpp : added gelu/tanh/erf in registry
4. torch/jit/_script.py : throws scripting model sees autocast as decorator since it's not supported

nvfuser code update:
1. codegen improvements and performance tuning
2. integration bug fixes for shape expression logic
3. kernel segmentation update to address perf regression from horizontal fusion
4. scalar cpu tensor promotion to support inter-device operation between cpu scalar tensor and cuda tensor

Things reverted from local changes:
aten::gelu with approximation (tracked in PR: pytorch/pytorch#61439)

Pull Request resolved: pytorch/pytorch#72127

Reviewed By: HamidShojanazeri

Differential Revision: D34113233

Pulled By: jbschlosser

fbshipit-source-id: b82cde32b71e324eca0ea57cb8c9f9647278ca74
(cherry picked from commit e009bc5c4e943211c4953e6fdf7c9913fa66b3c9)
leezu pushed a commit to leezu/pytorch that referenced this pull request Mar 30, 2022
Summary:
1. Implements pytorch#39853
2. Adds approximate boolean flag to Gelu
3. Enables Tanh Gelu approximation
4. Adds double backward support for Gelu
5. Enable Tanh Gelu in NvFuser

```
def gelu(x, approximate : str = 'none'):
    if approximate == 'tanh':
        # sqrt(2/pi) = 0.7978845608028654
        return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * (x + 0.044715 * torch.pow(x, 3.0))))
    else:
        return x * normcdf(x)
```

Linking XLA PR - pytorch/xla#3039

Pull Request resolved: pytorch#61439

Reviewed By: VitalyFedyunin

Differential Revision: D33894937

Pulled By: jbschlosser

fbshipit-source-id: b65e8fb6ea66168af8f34f45ed50e92737a33851
leezu pushed a commit to leezu/pytorch that referenced this pull request Mar 30, 2022
Summary:
Things changed in this PR that requires review:
1. aten/src/ATen/core/interned_strings.h
2. torch/csrc/jit/ir/alias_analysis.h : exposing createValue to allow efficient mutation
3. torch/csrc/jit/runtime/symbolic_shape_registry.cpp : added gelu/tanh/erf in registry
4. torch/jit/_script.py : throws scripting model sees autocast as decorator since it's not supported

nvfuser code update:
1. codegen improvements and performance tuning
2. integration bug fixes for shape expression logic
3. kernel segmentation update to address perf regression from horizontal fusion
4. scalar cpu tensor promotion to support inter-device operation between cpu scalar tensor and cuda tensor

Things reverted from local changes:
aten::gelu with approximation (tracked in PR: pytorch#61439)

Pull Request resolved: pytorch#72127

Reviewed By: HamidShojanazeri

Differential Revision: D34113233

Pulled By: jbschlosser

fbshipit-source-id: b82cde32b71e324eca0ea57cb8c9f9647278ca74
(cherry picked from commit e009bc5)
jjsjann123 pushed a commit to jjsjann123/nvfuser that referenced this pull request Oct 29, 2022
Summary:
1. Implements pytorch/pytorch#39853
2. Adds approximate boolean flag to Gelu
3. Enables Tanh Gelu approximation
4. Adds double backward support for Gelu
5. Enable Tanh Gelu in NvFuser

```
def gelu(x, approximate : str = 'none'):
    if approximate == 'tanh':
        # sqrt(2/pi) = 0.7978845608028654
        return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * (x + 0.044715 * torch.pow(x, 3.0))))
    else:
        return x * normcdf(x)
```

Linking XLA PR - pytorch/xla#3039

Pull Request resolved: pytorch/pytorch#61439

Reviewed By: mikaylagawarecki

Differential Revision: D33744717

Pulled By: jbschlosser

fbshipit-source-id: d64532a562ed53247bb4fa52bb16722634d5c187
(cherry picked from commit 4713dd9ccaa8983422bf3aa7b73df8d9ebd8cc02)
jjsjann123 pushed a commit to jjsjann123/nvfuser that referenced this pull request Oct 29, 2022
Summary:
1. Implements pytorch/pytorch#39853
2. Adds approximate boolean flag to Gelu
3. Enables Tanh Gelu approximation
4. Adds double backward support for Gelu
5. Enable Tanh Gelu in NvFuser

```
def gelu(x, approximate : str = 'none'):
    if approximate == 'tanh':
        # sqrt(2/pi) = 0.7978845608028654
        return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * (x + 0.044715 * torch.pow(x, 3.0))))
    else:
        return x * normcdf(x)
```

Linking XLA PR - pytorch/xla#3039

Pull Request resolved: pytorch/pytorch#61439

Reviewed By: cpuhrsch

Differential Revision: D33850228

Pulled By: jbschlosser

fbshipit-source-id: 3cc33fb298e480d7ecc5c67716da019d60c6ab33
(cherry picked from commit 3a53b3e94fd58190d1261efd3cf41b53506fb96e)
jjsjann123 pushed a commit to jjsjann123/nvfuser that referenced this pull request Oct 29, 2022
Summary:
1. Implements pytorch/pytorch#39853
2. Adds approximate boolean flag to Gelu
3. Enables Tanh Gelu approximation
4. Adds double backward support for Gelu
5. Enable Tanh Gelu in NvFuser

```
def gelu(x, approximate : str = 'none'):
    if approximate == 'tanh':
        # sqrt(2/pi) = 0.7978845608028654
        return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * (x + 0.044715 * torch.pow(x, 3.0))))
    else:
        return x * normcdf(x)
```

Linking XLA PR - pytorch/xla#3039

Pull Request resolved: pytorch/pytorch#61439

Reviewed By: VitalyFedyunin

Differential Revision: D33894937

Pulled By: jbschlosser

fbshipit-source-id: b65e8fb6ea66168af8f34f45ed50e92737a33851
(cherry picked from commit 6e986f91a958dd73514b4e64984c0b149157dc6f)
jjsjann123 added a commit to jjsjann123/nvfuser that referenced this pull request Oct 29, 2022
Summary:
Things changed in this PR that requires review:
1. aten/src/ATen/core/interned_strings.h
2. torch/csrc/jit/ir/alias_analysis.h : exposing createValue to allow efficient mutation
3. torch/csrc/jit/runtime/symbolic_shape_registry.cpp : added gelu/tanh/erf in registry
4. torch/jit/_script.py : throws scripting model sees autocast as decorator since it's not supported

nvfuser code update:
1. codegen improvements and performance tuning
2. integration bug fixes for shape expression logic
3. kernel segmentation update to address perf regression from horizontal fusion
4. scalar cpu tensor promotion to support inter-device operation between cpu scalar tensor and cuda tensor

Things reverted from local changes:
aten::gelu with approximation (tracked in PR: pytorch/pytorch#61439)

Pull Request resolved: pytorch/pytorch#72127

Reviewed By: HamidShojanazeri

Differential Revision: D34113233

Pulled By: jbschlosser

fbshipit-source-id: b82cde32b71e324eca0ea57cb8c9f9647278ca74
(cherry picked from commit e009bc5c4e943211c4953e6fdf7c9913fa66b3c9)
jjsjann123 pushed a commit to jjsjann123/nvfuser that referenced this pull request Nov 10, 2022
Summary:
1. Implements pytorch/pytorch#39853
2. Adds approximate boolean flag to Gelu
3. Enables Tanh Gelu approximation
4. Adds double backward support for Gelu
5. Enable Tanh Gelu in NvFuser

```
def gelu(x, approximate : str = 'none'):
    if approximate == 'tanh':
        # sqrt(2/pi) = 0.7978845608028654
        return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * (x + 0.044715 * torch.pow(x, 3.0))))
    else:
        return x * normcdf(x)
```

Linking XLA PR - pytorch/xla#3039

Pull Request resolved: pytorch/pytorch#61439

Reviewed By: mikaylagawarecki

Differential Revision: D33744717

Pulled By: jbschlosser

fbshipit-source-id: d64532a562ed53247bb4fa52bb16722634d5c187
(cherry picked from commit 4713dd9ccaa8983422bf3aa7b73df8d9ebd8cc02)
jjsjann123 pushed a commit to jjsjann123/nvfuser that referenced this pull request Nov 10, 2022
Summary:
1. Implements pytorch/pytorch#39853
2. Adds approximate boolean flag to Gelu
3. Enables Tanh Gelu approximation
4. Adds double backward support for Gelu
5. Enable Tanh Gelu in NvFuser

```
def gelu(x, approximate : str = 'none'):
    if approximate == 'tanh':
        # sqrt(2/pi) = 0.7978845608028654
        return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * (x + 0.044715 * torch.pow(x, 3.0))))
    else:
        return x * normcdf(x)
```

Linking XLA PR - pytorch/xla#3039

Pull Request resolved: pytorch/pytorch#61439

Reviewed By: cpuhrsch

Differential Revision: D33850228

Pulled By: jbschlosser

fbshipit-source-id: 3cc33fb298e480d7ecc5c67716da019d60c6ab33
(cherry picked from commit 3a53b3e94fd58190d1261efd3cf41b53506fb96e)
jjsjann123 pushed a commit to jjsjann123/nvfuser that referenced this pull request Nov 10, 2022
Summary:
1. Implements pytorch/pytorch#39853
2. Adds approximate boolean flag to Gelu
3. Enables Tanh Gelu approximation
4. Adds double backward support for Gelu
5. Enable Tanh Gelu in NvFuser

```
def gelu(x, approximate : str = 'none'):
    if approximate == 'tanh':
        # sqrt(2/pi) = 0.7978845608028654
        return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * (x + 0.044715 * torch.pow(x, 3.0))))
    else:
        return x * normcdf(x)
```

Linking XLA PR - pytorch/xla#3039

Pull Request resolved: pytorch/pytorch#61439

Reviewed By: VitalyFedyunin

Differential Revision: D33894937

Pulled By: jbschlosser

fbshipit-source-id: b65e8fb6ea66168af8f34f45ed50e92737a33851
(cherry picked from commit 6e986f91a958dd73514b4e64984c0b149157dc6f)
jjsjann123 added a commit to jjsjann123/nvfuser that referenced this pull request Nov 10, 2022
Summary:
Things changed in this PR that requires review:
1. aten/src/ATen/core/interned_strings.h
2. torch/csrc/jit/ir/alias_analysis.h : exposing createValue to allow efficient mutation
3. torch/csrc/jit/runtime/symbolic_shape_registry.cpp : added gelu/tanh/erf in registry
4. torch/jit/_script.py : throws scripting model sees autocast as decorator since it's not supported

nvfuser code update:
1. codegen improvements and performance tuning
2. integration bug fixes for shape expression logic
3. kernel segmentation update to address perf regression from horizontal fusion
4. scalar cpu tensor promotion to support inter-device operation between cpu scalar tensor and cuda tensor

Things reverted from local changes:
aten::gelu with approximation (tracked in PR: pytorch/pytorch#61439)

Pull Request resolved: pytorch/pytorch#72127

Reviewed By: HamidShojanazeri

Differential Revision: D34113233

Pulled By: jbschlosser

fbshipit-source-id: b82cde32b71e324eca0ea57cb8c9f9647278ca74
(cherry picked from commit e009bc5c4e943211c4953e6fdf7c9913fa66b3c9)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
cla signed oncall: jit Add this issue/PR to JIT oncall triage queue open source release notes: nn release notes category Reverted topic: improvements topic category triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet