
torch.var_mean is slower than layer norm #83800

Open
quancs opened this issue Aug 20, 2022 · 3 comments
Labels
module: nn (Related to torch.nn)
module: performance (Issues related to performance, either of kernel code or framework glue)
needs research (We need to decide whether or not this merits inclusion, based on research world)
triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Comments

quancs commented Aug 20, 2022

🐛 Describe the bug

Layer norm has to compute the mean and variance of its input before it can normalize, so torch.var_mean, which does only that statistics step, should run at least as fast as LayerNorm. But when I time them, torch.var_mean runs much slower than LayerNorm on CPU.

from functools import partial
import timeit

import torch

x = torch.randn((257, 252, 192), dtype=torch.float32)
ln = torch.nn.LayerNorm(192)

ln.eval()
with torch.no_grad():
    # Both calls reduce over the last dimension (size 192).
    var_mean_time = timeit.timeit(partial(torch.var_mean, input=x, dim=(2,)), number=100)
    ln_time = timeit.timeit(partial(ln, input=x), number=100)
    print(var_mean_time, ln_time)  # 3.149209 1.2331005
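
For what it's worth, torch.utils.benchmark tends to give more stable numbers than plain timeit, since it warms the op up and fixes the thread count (which defaults to 1, unlike the timeit run above). A minimal sketch of the same comparison:

import torch
import torch.utils.benchmark as benchmark

x = torch.randn((257, 252, 192), dtype=torch.float32)
ln = torch.nn.LayerNorm(192).eval()

t_var_mean = benchmark.Timer(
    stmt="torch.var_mean(x, dim=(2,))",
    globals={"torch": torch, "x": x},
)
t_ln = benchmark.Timer(
    stmt="ln(x)",
    globals={"ln": ln, "x": x},
)
print(t_var_mean.timeit(100))  # each timeit() returns a Measurement with timing stats
print(t_ln.timeit(100))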

Versions

PyTorch version: 1.12.1
Is debug build: False
CUDA used to build PyTorch: 11.6
ROCM used to build PyTorch: N/A

OS: Microsoft Windows 11 Home China
GCC version: (x86_64-posix-seh, Built by strawberryperl.com project) 8.3.0
Clang version: Could not collect
CMake version: Could not collect
Libc version: N/A

Python version: 3.9.12 (main, Apr 4 2022, 05:22:27) [MSC v.1916 64 bit (AMD64)] (64-bit runtime)
Python platform: Windows-10-10.0.22622-SP0
Is CUDA available: True
CUDA runtime version: 11.7.64
GPU models and configuration: GPU 0: NVIDIA GeForce GTX 1080
Nvidia driver version: 516.01
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

Versions of relevant libraries:
[pip3] mypy==0.971
[pip3] mypy-extensions==0.4.3
[pip3] numpy==1.22.4
[pip3] pytorch-lightning==1.7.2
[pip3] pytorch-ranger==0.1.1
[pip3] torch==1.12.1
[pip3] torch-complex==0.4.3
[pip3] torch-optimizer==0.3.0
[pip3] torch-stoi==0.1.2
[pip3] torchaudio==0.12.1
[pip3] torchdata==0.4.1
[pip3] torchmetrics==0.9.3
[pip3] torchvision==0.13.1
[conda] blas 2.115 mkl conda-forge
[conda] blas-devel 3.9.0 15_win64_mkl conda-forge
[conda] cudatoolkit 11.6.0 hc0ea762_10 conda-forge
[conda] libblas 3.9.0 15_win64_mkl conda-forge
[conda] libcblas 3.9.0 15_win64_mkl conda-forge
[conda] liblapack 3.9.0 15_win64_mkl conda-forge
[conda] liblapacke 3.9.0 15_win64_mkl conda-forge
[conda] mkl 2022.1.0 pypi_0 pypi
[conda] mkl-devel 2022.1.0 h57928b3_875 conda-forge
[conda] mkl-fft 1.3.1 pypi_0 pypi
[conda] mkl-include 2022.1.0 h6a75c08_874 conda-forge
[conda] mkl-random 1.2.2 pypi_0 pypi
[conda] mkl-service 2.4.0 pypi_0 pypi
[conda] numpy 1.22.4 pypi_0 pypi
[conda] pytorch 1.12.1 py3.9_cuda11.6_cudnn8_0 pytorch
[conda] pytorch-lightning 1.7.2 pypi_0 pypi
[conda] pytorch-mutex 1.0 cuda pytorch
[conda] pytorch-ranger 0.1.1 pypi_0 pypi
[conda] torch-complex 0.4.3 pypi_0 pypi
[conda] torch-optimizer 0.3.0 pypi_0 pypi
[conda] torch-stoi 0.1.2 pypi_0 pypi
[conda] torchaudio 0.12.1 py39_cu116 pytorch
[conda] torchdata 0.4.1 pypi_0 pypi
[conda] torchmetrics 0.9.3 pypi_0 pypi
[conda] torchvision 0.13.1 py39_cu116 pytorch

cc @albanD @mruberry @jbschlosser @walterddr @kshitij12345 @saketh-are @VitalyFedyunin @ngimel

@mikaylagawarecki added the module: performance, module: nn, and triaged labels and removed the module: performance label on Aug 23, 2022
mikaylagawarecki (Contributor) commented

Hi! Thanks for the report. I was able to repro this on master; var_mean indeed runs much slower than LayerNorm on CPU. The issue seems to come from the variance computation, since torch.var alone also runs slower than LayerNorm. Will look into this further!

@jbschlosser added the needs research label on Jan 27, 2023
ngimel (Collaborator) commented Jan 28, 2023

On CPU, the var and mean reduction kernels are non-vectorized and thus pretty slow, whereas layer norm has an optimized implementation. Is var_mean actually a bottleneck for you, @quancs?
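
One workaround to experiment with (an illustrative sketch of the standard identity Var[x] = E[x^2] - E[x]^2, not how LayerNorm computes its statistics internally, and whether it actually helps depends on which reduction kernels hit a fast path, so it needs measuring) is to assemble var_mean from sum reductions. The usual caveat applies: this form can lose precision when the variance is tiny relative to the mean.

import torch

def var_mean_via_sums(x: torch.Tensor, dim: int):
    # Build the statistics from two sum reductions instead of the
    # dedicated var/mean kernels.
    n = x.shape[dim]
    s = x.sum(dim=dim)
    ss = (x * x).sum(dim=dim)
    mean = s / n
    # (sum(x^2) - sum(x) * mean) / (n - 1) is the unbiased sample
    # variance, matching torch.var_mean's default estimator.
    var = (ss - s * mean) / (n - 1)
    return var, mean

x = torch.randn((257, 252, 192))
var, mean = var_mean_via_sums(x, dim=2)
ref_var, ref_mean = torch.var_mean(x, dim=(2,))
print(torch.allclose(var, ref_var, atol=1e-5), torch.allclose(mean, ref_mean, atol=1e-6))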

quancs (Author) commented Jan 28, 2023

Yes. I implemented a new normalization method that relies on var_mean for computing its statistics.
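
For context, a hypothetical minimal version of such a layer (the class name and eps value are made up for illustration) keeps torch.var_mean on the hot path of every forward call:

import torch

class VarMeanNorm(torch.nn.Module):
    """Hypothetical normalization layer whose statistics come from torch.var_mean."""

    def __init__(self, dim: int = -1, eps: float = 1e-5):
        super().__init__()
        self.dim = dim
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # This var_mean call is the CPU bottleneck discussed in this issue.
        var, mean = torch.var_mean(x, dim=self.dim, keepdim=True)
        return (x - mean) / torch.sqrt(var + self.eps)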

@ngimel removed their assignment on Apr 14, 2023