New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Implement Tanh Gelu Approximation #61439
Conversation
Add fast-gelu implementation for CPU and CUDA
🔗 Helpful links
💊 CI failures summary and remediationsAs of commit fbd5e62 (more details on the Dr. CI page):
🕵️ 1 new failure recognized by patternsThe following CI failures do not appear to be due to upstream breakages: linux-bionic-py3.7-clang9 / test (xla, 1, 1, linux.2xlarge) (1/1)Step: "Test" (full log | diagnosis details | 🔁 rerun)
|
From what I understood, there are now multiple approximations in use in third-party frameworks. Should the arg then be called |
Add approximate argument to JIT symbolic script + testing Update Gelu documenation
@rwightman @hendrycks @ptrblck Could you please comment on this? |
Possibilities: Or there could be both the sigmoid (for lowest memory usage) and tanh approximation (for cross-compatibility with tensorflow). |
Tensor Expressions ignores the approximate flag
FYI, backward_compatibility_check_test is just a warning that you are changing the signature and need to add the function to an allow list. |
@vadimkantorov @hendrycks re the arg, if it is going to be an arg I think the string approach is good as there are different users/use cases of both approx out there, probably don't need both the bool + string overload... However, personally I don't see much point in squishing it into one fn. All 3 (original, sigmoid approx, tanh approx) aren't numerically compatible with each other (ie you can't just change the flag and expect a trained network to work the same, you have to fine-tune). I think it's more clear in that respect if they are separate activation instances. Having it as an arg isn't going to save one code/effort as it'll require a partial wrap anyways, might as well make it |
Tensorflow/keras put it in one function. In other modules you can customize them to get slightly different behavior (e.g., nn.Linear(..., bias=False), torch.svd_lowrank(A, q=6, niter=2)). The other advantage is that it doesn't clutter the PyTorch namespace as much. I don't have a strong feeling either way, but argument options feel cleaner to me.
Yes, so it could be approx='sigmoid', approx='tanh' (and maybe approx='True' depending on what others think). |
Summary: Things changed in this PR that requires review: 1. aten/src/ATen/core/interned_strings.h 2. torch/csrc/jit/ir/alias_analysis.h : exposing createValue to allow efficient mutation 3. torch/csrc/jit/runtime/symbolic_shape_registry.cpp : added gelu/tanh/erf in registry 4. torch/jit/_script.py : throws scripting model sees autocast as decorator since it's not supported nvfuser code update: 1. codegen improvements and performance tuning 2. integration bug fixes for shape expression logic 3. kernel segmentation update to address perf regression from horizontal fusion 4. scalar cpu tensor promotion to support inter-device operation between cpu scalar tensor and cuda tensor Things reverted from local changes: aten::gelu with approximation (tracked in PR: pytorch/pytorch#61439) Pull Request resolved: pytorch/pytorch#72127 Reviewed By: HamidShojanazeri Differential Revision: D34113233 Pulled By: jbschlosser fbshipit-source-id: b82cde32b71e324eca0ea57cb8c9f9647278ca74 (cherry picked from commit e009bc5c4e943211c4953e6fdf7c9913fa66b3c9)
Summary: 1. Implements pytorch/pytorch#39853 2. Adds approximate boolean flag to Gelu 3. Enables Tanh Gelu approximation 4. Adds double backward support for Gelu 5. Enable Tanh Gelu in NvFuser ``` def gelu(x, approximate : str = 'none'): if approximate == 'tanh': # sqrt(2/pi) = 0.7978845608028654 return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * (x + 0.044715 * torch.pow(x, 3.0)))) else: return x * normcdf(x) ``` Linking XLA PR - pytorch/xla#3039 Pull Request resolved: pytorch/pytorch#61439 Reviewed By: VitalyFedyunin Differential Revision: D33894937 Pulled By: jbschlosser fbshipit-source-id: b65e8fb6ea66168af8f34f45ed50e92737a33851 (cherry picked from commit 6e986f91a958dd73514b4e64984c0b149157dc6f)
Summary: Things changed in this PR that requires review: 1. aten/src/ATen/core/interned_strings.h 2. torch/csrc/jit/ir/alias_analysis.h : exposing createValue to allow efficient mutation 3. torch/csrc/jit/runtime/symbolic_shape_registry.cpp : added gelu/tanh/erf in registry 4. torch/jit/_script.py : throws scripting model sees autocast as decorator since it's not supported nvfuser code update: 1. codegen improvements and performance tuning 2. integration bug fixes for shape expression logic 3. kernel segmentation update to address perf regression from horizontal fusion 4. scalar cpu tensor promotion to support inter-device operation between cpu scalar tensor and cuda tensor Things reverted from local changes: aten::gelu with approximation (tracked in PR: pytorch/pytorch#61439) Pull Request resolved: pytorch/pytorch#72127 Reviewed By: HamidShojanazeri Differential Revision: D34113233 Pulled By: jbschlosser fbshipit-source-id: b82cde32b71e324eca0ea57cb8c9f9647278ca74 (cherry picked from commit e009bc5c4e943211c4953e6fdf7c9913fa66b3c9)
Summary: 1. Implements pytorch/pytorch#39853 2. Adds approximate boolean flag to Gelu 3. Enables Tanh Gelu approximation 4. Adds double backward support for Gelu 5. Enable Tanh Gelu in NvFuser ``` def gelu(x, approximate : str = 'none'): if approximate == 'tanh': # sqrt(2/pi) = 0.7978845608028654 return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * (x + 0.044715 * torch.pow(x, 3.0)))) else: return x * normcdf(x) ``` Linking XLA PR - pytorch/xla#3039 Pull Request resolved: pytorch/pytorch#61439 Reviewed By: VitalyFedyunin Differential Revision: D33894937 Pulled By: jbschlosser fbshipit-source-id: b65e8fb6ea66168af8f34f45ed50e92737a33851 (cherry picked from commit 6e986f91a958dd73514b4e64984c0b149157dc6f)
Summary: Things changed in this PR that requires review: 1. aten/src/ATen/core/interned_strings.h 2. torch/csrc/jit/ir/alias_analysis.h : exposing createValue to allow efficient mutation 3. torch/csrc/jit/runtime/symbolic_shape_registry.cpp : added gelu/tanh/erf in registry 4. torch/jit/_script.py : throws scripting model sees autocast as decorator since it's not supported nvfuser code update: 1. codegen improvements and performance tuning 2. integration bug fixes for shape expression logic 3. kernel segmentation update to address perf regression from horizontal fusion 4. scalar cpu tensor promotion to support inter-device operation between cpu scalar tensor and cuda tensor Things reverted from local changes: aten::gelu with approximation (tracked in PR: pytorch/pytorch#61439) Pull Request resolved: pytorch/pytorch#72127 Reviewed By: HamidShojanazeri Differential Revision: D34113233 Pulled By: jbschlosser fbshipit-source-id: b82cde32b71e324eca0ea57cb8c9f9647278ca74 (cherry picked from commit e009bc5c4e943211c4953e6fdf7c9913fa66b3c9)
Summary: 1. Implements pytorch/pytorch#39853 2. Adds approximate boolean flag to Gelu 3. Enables Tanh Gelu approximation 4. Adds double backward support for Gelu 5. Enable Tanh Gelu in NvFuser ``` def gelu(x, approximate : str = 'none'): if approximate == 'tanh': # sqrt(2/pi) = 0.7978845608028654 return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * (x + 0.044715 * torch.pow(x, 3.0)))) else: return x * normcdf(x) ``` Linking XLA PR - pytorch/xla#3039 Pull Request resolved: pytorch/pytorch#61439 Reviewed By: VitalyFedyunin Differential Revision: D33894937 Pulled By: jbschlosser fbshipit-source-id: b65e8fb6ea66168af8f34f45ed50e92737a33851 (cherry picked from commit 6e986f91a958dd73514b4e64984c0b149157dc6f)
Summary: Things changed in this PR that requires review: 1. aten/src/ATen/core/interned_strings.h 2. torch/csrc/jit/ir/alias_analysis.h : exposing createValue to allow efficient mutation 3. torch/csrc/jit/runtime/symbolic_shape_registry.cpp : added gelu/tanh/erf in registry 4. torch/jit/_script.py : throws scripting model sees autocast as decorator since it's not supported nvfuser code update: 1. codegen improvements and performance tuning 2. integration bug fixes for shape expression logic 3. kernel segmentation update to address perf regression from horizontal fusion 4. scalar cpu tensor promotion to support inter-device operation between cpu scalar tensor and cuda tensor Things reverted from local changes: aten::gelu with approximation (tracked in PR: pytorch/pytorch#61439) Pull Request resolved: pytorch/pytorch#72127 Reviewed By: HamidShojanazeri Differential Revision: D34113233 Pulled By: jbschlosser fbshipit-source-id: b82cde32b71e324eca0ea57cb8c9f9647278ca74 (cherry picked from commit e009bc5c4e943211c4953e6fdf7c9913fa66b3c9)
Summary: 1. Implements pytorch/pytorch#39853 2. Adds approximate boolean flag to Gelu 3. Enables Tanh Gelu approximation 4. Adds double backward support for Gelu 5. Enable Tanh Gelu in NvFuser ``` def gelu(x, approximate : str = 'none'): if approximate == 'tanh': # sqrt(2/pi) = 0.7978845608028654 return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * (x + 0.044715 * torch.pow(x, 3.0)))) else: return x * normcdf(x) ``` Linking XLA PR - pytorch/xla#3039 Pull Request resolved: pytorch/pytorch#61439 Reviewed By: VitalyFedyunin Differential Revision: D33894937 Pulled By: jbschlosser fbshipit-source-id: b65e8fb6ea66168af8f34f45ed50e92737a33851 (cherry picked from commit 6e986f91a958dd73514b4e64984c0b149157dc6f)
Summary: Things changed in this PR that requires review: 1. aten/src/ATen/core/interned_strings.h 2. torch/csrc/jit/ir/alias_analysis.h : exposing createValue to allow efficient mutation 3. torch/csrc/jit/runtime/symbolic_shape_registry.cpp : added gelu/tanh/erf in registry 4. torch/jit/_script.py : throws scripting model sees autocast as decorator since it's not supported nvfuser code update: 1. codegen improvements and performance tuning 2. integration bug fixes for shape expression logic 3. kernel segmentation update to address perf regression from horizontal fusion 4. scalar cpu tensor promotion to support inter-device operation between cpu scalar tensor and cuda tensor Things reverted from local changes: aten::gelu with approximation (tracked in PR: pytorch/pytorch#61439) Pull Request resolved: pytorch/pytorch#72127 Reviewed By: HamidShojanazeri Differential Revision: D34113233 Pulled By: jbschlosser fbshipit-source-id: b82cde32b71e324eca0ea57cb8c9f9647278ca74 (cherry picked from commit e009bc5c4e943211c4953e6fdf7c9913fa66b3c9)
Summary: 1. Implements pytorch/pytorch#39853 2. Adds approximate boolean flag to Gelu 3. Enables Tanh Gelu approximation 4. Adds double backward support for Gelu 5. Enable Tanh Gelu in NvFuser ``` def gelu(x, approximate : str = 'none'): if approximate == 'tanh': # sqrt(2/pi) = 0.7978845608028654 return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * (x + 0.044715 * torch.pow(x, 3.0)))) else: return x * normcdf(x) ``` Linking XLA PR - pytorch/xla#3039 Pull Request resolved: pytorch/pytorch#61439 Reviewed By: VitalyFedyunin Differential Revision: D33894937 Pulled By: jbschlosser fbshipit-source-id: b65e8fb6ea66168af8f34f45ed50e92737a33851 (cherry picked from commit 6e986f91a958dd73514b4e64984c0b149157dc6f)
Summary: Things changed in this PR that requires review: 1. aten/src/ATen/core/interned_strings.h 2. torch/csrc/jit/ir/alias_analysis.h : exposing createValue to allow efficient mutation 3. torch/csrc/jit/runtime/symbolic_shape_registry.cpp : added gelu/tanh/erf in registry 4. torch/jit/_script.py : throws scripting model sees autocast as decorator since it's not supported nvfuser code update: 1. codegen improvements and performance tuning 2. integration bug fixes for shape expression logic 3. kernel segmentation update to address perf regression from horizontal fusion 4. scalar cpu tensor promotion to support inter-device operation between cpu scalar tensor and cuda tensor Things reverted from local changes: aten::gelu with approximation (tracked in PR: pytorch/pytorch#61439) Pull Request resolved: pytorch/pytorch#72127 Reviewed By: HamidShojanazeri Differential Revision: D34113233 Pulled By: jbschlosser fbshipit-source-id: b82cde32b71e324eca0ea57cb8c9f9647278ca74 (cherry picked from commit e009bc5c4e943211c4953e6fdf7c9913fa66b3c9)
Summary: 1. Implements pytorch/pytorch#39853 2. Adds approximate boolean flag to Gelu 3. Enables Tanh Gelu approximation 4. Adds double backward support for Gelu 5. Enable Tanh Gelu in NvFuser ``` def gelu(x, approximate : str = 'none'): if approximate == 'tanh': # sqrt(2/pi) = 0.7978845608028654 return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * (x + 0.044715 * torch.pow(x, 3.0)))) else: return x * normcdf(x) ``` Linking XLA PR - pytorch/xla#3039 Pull Request resolved: pytorch/pytorch#61439 Reviewed By: VitalyFedyunin Differential Revision: D33894937 Pulled By: jbschlosser fbshipit-source-id: b65e8fb6ea66168af8f34f45ed50e92737a33851 (cherry picked from commit 6e986f91a958dd73514b4e64984c0b149157dc6f)
Summary: Things changed in this PR that requires review: 1. aten/src/ATen/core/interned_strings.h 2. torch/csrc/jit/ir/alias_analysis.h : exposing createValue to allow efficient mutation 3. torch/csrc/jit/runtime/symbolic_shape_registry.cpp : added gelu/tanh/erf in registry 4. torch/jit/_script.py : throws scripting model sees autocast as decorator since it's not supported nvfuser code update: 1. codegen improvements and performance tuning 2. integration bug fixes for shape expression logic 3. kernel segmentation update to address perf regression from horizontal fusion 4. scalar cpu tensor promotion to support inter-device operation between cpu scalar tensor and cuda tensor Things reverted from local changes: aten::gelu with approximation (tracked in PR: pytorch/pytorch#61439) Pull Request resolved: pytorch/pytorch#72127 Reviewed By: HamidShojanazeri Differential Revision: D34113233 Pulled By: jbschlosser fbshipit-source-id: b82cde32b71e324eca0ea57cb8c9f9647278ca74 (cherry picked from commit e009bc5c4e943211c4953e6fdf7c9913fa66b3c9)
Summary: 1. Implements pytorch/pytorch#39853 2. Adds approximate boolean flag to Gelu 3. Enables Tanh Gelu approximation 4. Adds double backward support for Gelu 5. Enable Tanh Gelu in NvFuser ``` def gelu(x, approximate : str = 'none'): if approximate == 'tanh': # sqrt(2/pi) = 0.7978845608028654 return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * (x + 0.044715 * torch.pow(x, 3.0)))) else: return x * normcdf(x) ``` Linking XLA PR - pytorch/xla#3039 Pull Request resolved: pytorch/pytorch#61439 Reviewed By: VitalyFedyunin Differential Revision: D33894937 Pulled By: jbschlosser fbshipit-source-id: b65e8fb6ea66168af8f34f45ed50e92737a33851 (cherry picked from commit 6e986f91a958dd73514b4e64984c0b149157dc6f)
Summary: Things changed in this PR that requires review: 1. aten/src/ATen/core/interned_strings.h 2. torch/csrc/jit/ir/alias_analysis.h : exposing createValue to allow efficient mutation 3. torch/csrc/jit/runtime/symbolic_shape_registry.cpp : added gelu/tanh/erf in registry 4. torch/jit/_script.py : throws scripting model sees autocast as decorator since it's not supported nvfuser code update: 1. codegen improvements and performance tuning 2. integration bug fixes for shape expression logic 3. kernel segmentation update to address perf regression from horizontal fusion 4. scalar cpu tensor promotion to support inter-device operation between cpu scalar tensor and cuda tensor Things reverted from local changes: aten::gelu with approximation (tracked in PR: pytorch/pytorch#61439) Pull Request resolved: pytorch/pytorch#72127 Reviewed By: HamidShojanazeri Differential Revision: D34113233 Pulled By: jbschlosser fbshipit-source-id: b82cde32b71e324eca0ea57cb8c9f9647278ca74 (cherry picked from commit e009bc5c4e943211c4953e6fdf7c9913fa66b3c9)
Summary: 1. Implements pytorch/pytorch#39853 2. Adds approximate boolean flag to Gelu 3. Enables Tanh Gelu approximation 4. Adds double backward support for Gelu 5. Enable Tanh Gelu in NvFuser ``` def gelu(x, approximate : str = 'none'): if approximate == 'tanh': # sqrt(2/pi) = 0.7978845608028654 return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * (x + 0.044715 * torch.pow(x, 3.0)))) else: return x * normcdf(x) ``` Linking XLA PR - pytorch/xla#3039 Pull Request resolved: pytorch/pytorch#61439 Reviewed By: VitalyFedyunin Differential Revision: D33894937 Pulled By: jbschlosser fbshipit-source-id: b65e8fb6ea66168af8f34f45ed50e92737a33851 (cherry picked from commit 6e986f91a958dd73514b4e64984c0b149157dc6f)
Summary: Things changed in this PR that requires review: 1. aten/src/ATen/core/interned_strings.h 2. torch/csrc/jit/ir/alias_analysis.h : exposing createValue to allow efficient mutation 3. torch/csrc/jit/runtime/symbolic_shape_registry.cpp : added gelu/tanh/erf in registry 4. torch/jit/_script.py : throws scripting model sees autocast as decorator since it's not supported nvfuser code update: 1. codegen improvements and performance tuning 2. integration bug fixes for shape expression logic 3. kernel segmentation update to address perf regression from horizontal fusion 4. scalar cpu tensor promotion to support inter-device operation between cpu scalar tensor and cuda tensor Things reverted from local changes: aten::gelu with approximation (tracked in PR: pytorch/pytorch#61439) Pull Request resolved: pytorch/pytorch#72127 Reviewed By: HamidShojanazeri Differential Revision: D34113233 Pulled By: jbschlosser fbshipit-source-id: b82cde32b71e324eca0ea57cb8c9f9647278ca74 (cherry picked from commit e009bc5c4e943211c4953e6fdf7c9913fa66b3c9)
Summary: 1. Implements pytorch#39853 2. Adds approximate boolean flag to Gelu 3. Enables Tanh Gelu approximation 4. Adds double backward support for Gelu 5. Enable Tanh Gelu in NvFuser ``` def gelu(x, approximate : str = 'none'): if approximate == 'tanh': # sqrt(2/pi) = 0.7978845608028654 return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * (x + 0.044715 * torch.pow(x, 3.0)))) else: return x * normcdf(x) ``` Linking XLA PR - pytorch/xla#3039 Pull Request resolved: pytorch#61439 Reviewed By: VitalyFedyunin Differential Revision: D33894937 Pulled By: jbschlosser fbshipit-source-id: b65e8fb6ea66168af8f34f45ed50e92737a33851
Summary: Things changed in this PR that requires review: 1. aten/src/ATen/core/interned_strings.h 2. torch/csrc/jit/ir/alias_analysis.h : exposing createValue to allow efficient mutation 3. torch/csrc/jit/runtime/symbolic_shape_registry.cpp : added gelu/tanh/erf in registry 4. torch/jit/_script.py : throws scripting model sees autocast as decorator since it's not supported nvfuser code update: 1. codegen improvements and performance tuning 2. integration bug fixes for shape expression logic 3. kernel segmentation update to address perf regression from horizontal fusion 4. scalar cpu tensor promotion to support inter-device operation between cpu scalar tensor and cuda tensor Things reverted from local changes: aten::gelu with approximation (tracked in PR: pytorch#61439) Pull Request resolved: pytorch#72127 Reviewed By: HamidShojanazeri Differential Revision: D34113233 Pulled By: jbschlosser fbshipit-source-id: b82cde32b71e324eca0ea57cb8c9f9647278ca74 (cherry picked from commit e009bc5)
Summary: 1. Implements pytorch/pytorch#39853 2. Adds approximate boolean flag to Gelu 3. Enables Tanh Gelu approximation 4. Adds double backward support for Gelu 5. Enable Tanh Gelu in NvFuser ``` def gelu(x, approximate : str = 'none'): if approximate == 'tanh': # sqrt(2/pi) = 0.7978845608028654 return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * (x + 0.044715 * torch.pow(x, 3.0)))) else: return x * normcdf(x) ``` Linking XLA PR - pytorch/xla#3039 Pull Request resolved: pytorch/pytorch#61439 Reviewed By: mikaylagawarecki Differential Revision: D33744717 Pulled By: jbschlosser fbshipit-source-id: d64532a562ed53247bb4fa52bb16722634d5c187 (cherry picked from commit 4713dd9ccaa8983422bf3aa7b73df8d9ebd8cc02)
Summary: 1. Implements pytorch/pytorch#39853 2. Adds approximate boolean flag to Gelu 3. Enables Tanh Gelu approximation 4. Adds double backward support for Gelu 5. Enable Tanh Gelu in NvFuser ``` def gelu(x, approximate : str = 'none'): if approximate == 'tanh': # sqrt(2/pi) = 0.7978845608028654 return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * (x + 0.044715 * torch.pow(x, 3.0)))) else: return x * normcdf(x) ``` Linking XLA PR - pytorch/xla#3039 Pull Request resolved: pytorch/pytorch#61439 Reviewed By: cpuhrsch Differential Revision: D33850228 Pulled By: jbschlosser fbshipit-source-id: 3cc33fb298e480d7ecc5c67716da019d60c6ab33 (cherry picked from commit 3a53b3e94fd58190d1261efd3cf41b53506fb96e)
Summary: 1. Implements pytorch/pytorch#39853 2. Adds approximate boolean flag to Gelu 3. Enables Tanh Gelu approximation 4. Adds double backward support for Gelu 5. Enable Tanh Gelu in NvFuser ``` def gelu(x, approximate : str = 'none'): if approximate == 'tanh': # sqrt(2/pi) = 0.7978845608028654 return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * (x + 0.044715 * torch.pow(x, 3.0)))) else: return x * normcdf(x) ``` Linking XLA PR - pytorch/xla#3039 Pull Request resolved: pytorch/pytorch#61439 Reviewed By: VitalyFedyunin Differential Revision: D33894937 Pulled By: jbschlosser fbshipit-source-id: b65e8fb6ea66168af8f34f45ed50e92737a33851 (cherry picked from commit 6e986f91a958dd73514b4e64984c0b149157dc6f)
Summary: Things changed in this PR that requires review: 1. aten/src/ATen/core/interned_strings.h 2. torch/csrc/jit/ir/alias_analysis.h : exposing createValue to allow efficient mutation 3. torch/csrc/jit/runtime/symbolic_shape_registry.cpp : added gelu/tanh/erf in registry 4. torch/jit/_script.py : throws scripting model sees autocast as decorator since it's not supported nvfuser code update: 1. codegen improvements and performance tuning 2. integration bug fixes for shape expression logic 3. kernel segmentation update to address perf regression from horizontal fusion 4. scalar cpu tensor promotion to support inter-device operation between cpu scalar tensor and cuda tensor Things reverted from local changes: aten::gelu with approximation (tracked in PR: pytorch/pytorch#61439) Pull Request resolved: pytorch/pytorch#72127 Reviewed By: HamidShojanazeri Differential Revision: D34113233 Pulled By: jbschlosser fbshipit-source-id: b82cde32b71e324eca0ea57cb8c9f9647278ca74 (cherry picked from commit e009bc5c4e943211c4953e6fdf7c9913fa66b3c9)
Summary: 1. Implements pytorch/pytorch#39853 2. Adds approximate boolean flag to Gelu 3. Enables Tanh Gelu approximation 4. Adds double backward support for Gelu 5. Enable Tanh Gelu in NvFuser ``` def gelu(x, approximate : str = 'none'): if approximate == 'tanh': # sqrt(2/pi) = 0.7978845608028654 return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * (x + 0.044715 * torch.pow(x, 3.0)))) else: return x * normcdf(x) ``` Linking XLA PR - pytorch/xla#3039 Pull Request resolved: pytorch/pytorch#61439 Reviewed By: mikaylagawarecki Differential Revision: D33744717 Pulled By: jbschlosser fbshipit-source-id: d64532a562ed53247bb4fa52bb16722634d5c187 (cherry picked from commit 4713dd9ccaa8983422bf3aa7b73df8d9ebd8cc02)
Summary: 1. Implements pytorch/pytorch#39853 2. Adds approximate boolean flag to Gelu 3. Enables Tanh Gelu approximation 4. Adds double backward support for Gelu 5. Enable Tanh Gelu in NvFuser ``` def gelu(x, approximate : str = 'none'): if approximate == 'tanh': # sqrt(2/pi) = 0.7978845608028654 return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * (x + 0.044715 * torch.pow(x, 3.0)))) else: return x * normcdf(x) ``` Linking XLA PR - pytorch/xla#3039 Pull Request resolved: pytorch/pytorch#61439 Reviewed By: cpuhrsch Differential Revision: D33850228 Pulled By: jbschlosser fbshipit-source-id: 3cc33fb298e480d7ecc5c67716da019d60c6ab33 (cherry picked from commit 3a53b3e94fd58190d1261efd3cf41b53506fb96e)
Summary: 1. Implements pytorch/pytorch#39853 2. Adds approximate boolean flag to Gelu 3. Enables Tanh Gelu approximation 4. Adds double backward support for Gelu 5. Enable Tanh Gelu in NvFuser ``` def gelu(x, approximate : str = 'none'): if approximate == 'tanh': # sqrt(2/pi) = 0.7978845608028654 return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * (x + 0.044715 * torch.pow(x, 3.0)))) else: return x * normcdf(x) ``` Linking XLA PR - pytorch/xla#3039 Pull Request resolved: pytorch/pytorch#61439 Reviewed By: VitalyFedyunin Differential Revision: D33894937 Pulled By: jbschlosser fbshipit-source-id: b65e8fb6ea66168af8f34f45ed50e92737a33851 (cherry picked from commit 6e986f91a958dd73514b4e64984c0b149157dc6f)
Summary: Things changed in this PR that requires review: 1. aten/src/ATen/core/interned_strings.h 2. torch/csrc/jit/ir/alias_analysis.h : exposing createValue to allow efficient mutation 3. torch/csrc/jit/runtime/symbolic_shape_registry.cpp : added gelu/tanh/erf in registry 4. torch/jit/_script.py : throws scripting model sees autocast as decorator since it's not supported nvfuser code update: 1. codegen improvements and performance tuning 2. integration bug fixes for shape expression logic 3. kernel segmentation update to address perf regression from horizontal fusion 4. scalar cpu tensor promotion to support inter-device operation between cpu scalar tensor and cuda tensor Things reverted from local changes: aten::gelu with approximation (tracked in PR: pytorch/pytorch#61439) Pull Request resolved: pytorch/pytorch#72127 Reviewed By: HamidShojanazeri Differential Revision: D34113233 Pulled By: jbschlosser fbshipit-source-id: b82cde32b71e324eca0ea57cb8c9f9647278ca74 (cherry picked from commit e009bc5c4e943211c4953e6fdf7c9913fa66b3c9)
Linking XLA PR - pytorch/xla#3039