torch.var_mean is slower than layer norm #83800
Labels
module: nn
Related to torch.nn
module: performance
Issues related to performance, either of kernel code or framework glue
needs research
We need to decide whether or not this merits inclusion, based on research world
triaged
This issue has been looked at by a team member, and triaged and prioritized into an appropriate module
🐛 Describe the bug
Layer norm has to compute the mean and variance of its input, so one would expect torch.var_mean on its own to run at least as fast as LayerNorm. But when I time them, I find that torch.var_mean runs much slower than LayerNorm on CPU.

Versions
PyTorch version: 1.12.1
Is debug build: False
CUDA used to build PyTorch: 11.6
ROCM used to build PyTorch: N/A
OS: Microsoft Windows 11 Home China
GCC version: (x86_64-posix-seh, Built by strawberryperl.com project) 8.3.0
Clang version: Could not collect
CMake version: Could not collect
Libc version: N/A
Python version: 3.9.12 (main, Apr 4 2022, 05:22:27) [MSC v.1916 64 bit (AMD64)] (64-bit runtime)
Python platform: Windows-10-10.0.22622-SP0
Is CUDA available: True
CUDA runtime version: 11.7.64
GPU models and configuration: GPU 0: NVIDIA GeForce GTX 1080
Nvidia driver version: 516.01
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
Versions of relevant libraries:
[pip3] mypy==0.971
[pip3] mypy-extensions==0.4.3
[pip3] numpy==1.22.4
[pip3] pytorch-lightning==1.7.2
[pip3] pytorch-ranger==0.1.1
[pip3] torch==1.12.1
[pip3] torch-complex==0.4.3
[pip3] torch-optimizer==0.3.0
[pip3] torch-stoi==0.1.2
[pip3] torchaudio==0.12.1
[pip3] torchdata==0.4.1
[pip3] torchmetrics==0.9.3
[pip3] torchvision==0.13.1
[conda] blas 2.115 mkl conda-forge
[conda] blas-devel 3.9.0 15_win64_mkl conda-forge
[conda] cudatoolkit 11.6.0 hc0ea762_10 conda-forge
[conda] libblas 3.9.0 15_win64_mkl conda-forge
[conda] libcblas 3.9.0 15_win64_mkl conda-forge
[conda] liblapack 3.9.0 15_win64_mkl conda-forge
[conda] liblapacke 3.9.0 15_win64_mkl conda-forge
[conda] mkl 2022.1.0 pypi_0 pypi
[conda] mkl-devel 2022.1.0 h57928b3_875 conda-forge
[conda] mkl-fft 1.3.1 pypi_0 pypi
[conda] mkl-include 2022.1.0 h6a75c08_874 conda-forge
[conda] mkl-random 1.2.2 pypi_0 pypi
[conda] mkl-service 2.4.0 pypi_0 pypi
[conda] numpy 1.22.4 pypi_0 pypi
[conda] pytorch 1.12.1 py3.9_cuda11.6_cudnn8_0 pytorch
[conda] pytorch-lightning 1.7.2 pypi_0 pypi
[conda] pytorch-mutex 1.0 cuda pytorch
[conda] pytorch-ranger 0.1.1 pypi_0 pypi
[conda] torch-complex 0.4.3 pypi_0 pypi
[conda] torch-optimizer 0.3.0 pypi_0 pypi
[conda] torch-stoi 0.1.2 pypi_0 pypi
[conda] torchaudio 0.12.1 py39_cu116 pytorch
[conda] torchdata 0.4.1 pypi_0 pypi
[conda] torchmetrics 0.9.3 pypi_0 pypi
[conda] torchvision 0.13.1 py39_cu116 pytorch
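
The timing script behind the comparison described in this report is not included here. A minimal sketch of how such a CPU comparison could be set up is shown below. The input shape (64, 512), the iteration count, and the helper names run_var_mean and run_layer_norm are assumptions made for illustration, not values from the report.

```python
import timeit

import torch

torch.manual_seed(0)

# Assumed input shape; the report does not state the one it used.
x = torch.randn(64, 512)
layer_norm = torch.nn.LayerNorm(512)


def run_var_mean():
    # Biased variance (unbiased=False) is what layer norm uses internally.
    return torch.var_mean(x, dim=-1, unbiased=False, keepdim=True)


def run_layer_norm():
    return layer_norm(x)


with torch.no_grad():
    # Sanity check: normalizing by hand with var_mean should match LayerNorm
    # (default affine parameters are weight=1, bias=0).
    var, mean = run_var_mean()
    manual = (x - mean) / torch.sqrt(var + layer_norm.eps)
    assert torch.allclose(manual, run_layer_norm(), atol=1e-4)

    # Warm up once, then time each operation on CPU.
    run_var_mean()
    run_layer_norm()
    iterations = 200  # assumed; increase for more stable numbers
    t_var_mean = timeit.timeit(run_var_mean, number=iterations)
    t_layer_norm = timeit.timeit(run_layer_norm, number=iterations)

print(f"torch.var_mean: {t_var_mean / iterations * 1e6:.1f} us per call")
print(f"LayerNorm:      {t_layer_norm / iterations * 1e6:.1f} us per call")
```

Results will vary with the input shape, thread count, and PyTorch build, so the numbers from this sketch are only meaningful relative to each other on the same machine.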
cc @albanD @mruberry @jbschlosser @walterddr @kshitij12345 @saketh-are @VitalyFedyunin @ngimel