
Slow (20-50x) RNN tutorial/example when torch is installed using pip compared to conda installation #29722

Open
CeadeS opened this issue Nov 13, 2019 · 17 comments
Labels
module: binaries - Anything related to official binaries that we release to users
module: mkl - Related to our MKL support
module: performance - Issues related to performance, either of kernel code or framework glue
module: rnn - Issues related to RNN support (LSTM, GRU, etc.)
triaged - This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Comments

@CeadeS

CeadeS commented Nov 13, 2019

❓ Questions and Help

The example https://pytorch.org/tutorials/intermediate/char_rnn_classification_tutorial.html
runs slowly on pip-installed PyTorch 1.3.x/1.2.x/1.0.x builds (with or without CUDA).
I first noticed this behavior on a really old version of torch, then updated. The example ran 50 times faster in a Colab notebook and on a Windows machine than on my workstation. Then I tried different environments in Docker and found that the behavior occurs in all environments, with and without NVIDIA CUDA and across different distributions, whenever torch is installed with pip.
Conda installations run fine.

Did I do something wrong?
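For reference, a minimal CPU timing sketch (not from the original report; the layer type and sizes only loosely follow the tutorial) that can be run unchanged under both installations to quantify the gap:

import time
import torch

# Small LSTM roughly matching the char-rnn tutorial dimensions (57 input letters)
rnn = torch.nn.LSTM(input_size=57, hidden_size=128)
x = torch.randn(20, 1, 57)  # (seq_len, batch, input_size)

for _ in range(10):  # warm-up
    rnn(x)

start = time.time()
for _ in range(100):
    rnn(x)
print("100 forward passes: %.3f s" % (time.time() - start))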

cc @ezyang @VitalyFedyunin @ngimel @mruberry @zou3519

@peterjc123
Collaborator

peterjc123 commented Nov 14, 2019

Are you running PyTorch 1.3.0/1.3.1 with CUDA 10.1? If yes, then it is because we didn't pack all the CUDA DLLs in it. Please wait for #29356 to be fixed.

@CeadeS
Author

CeadeS commented Nov 14, 2019

Are you running PyTorch 1.3.0/1.3.1 with CUDA 10.1? If yes, then it is because we didn't pack all the CUDA DLLs in it. Please wait for #29356 to be fixed.

I guess it's the MKL.

@peterjc123
Collaborator

We link against the MKL static libraries, so they should not rely on the user's environment.

@vincentqb vincentqb added the module: binaries, triaged and module: rnn labels Nov 14, 2019
@vincentqb
Contributor

Have you also tried compiling from source?

@zou3519 zou3519 added the module: performance label Nov 14, 2019
@vincentqb vincentqb added the module: mkl label Nov 14, 2019
@CeadeS
Author

CeadeS commented Nov 15, 2019

I compiled it from source now but got the same results, with and without CUDA.

@CeadeS
Author

CeadeS commented Nov 18, 2019

I compiled it with TBB and it runs much faster now. Not as fast as the conda build, but I can work with that.

@peterjc123
Collaborator

Why? We didn't use the MKL that is installed with conda. Also, we didn't build with TBB in either the conda builds or the wheel builds. That's pretty weird.

@CeadeS
Author

CeadeS commented Nov 18, 2019

Why? We didn't use the MKL that is installed with conda. Also, we didn't build with TBB in either the conda builds or the wheel builds. That's pretty weird.

I am too new to this to be able to comment on that. The only thing I know is that I compiled it over and over with different settings and ended up with one that kind of works. Maybe it has something to do with OpenMP?
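One way to see which threading backend and math libraries a given binary was actually built with (a sketch, not part of the original exchange; these functions are available in recent 1.x releases):

import torch

# Reports the ATen parallel backend (OpenMP vs TBB), the OpenMP runtime in use,
# and the MKL / MKL-DNN versions the binary was compiled against.
print(torch.__config__.parallel_info())
print(torch.__config__.show())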

@peterjc123
Collaborator

peterjc123 commented Nov 18, 2019

By comparing the build scripts and the logs (https://dev.azure.com/pytorch/PyTorch/_build/results?buildId=17796 vs https://dev.azure.com/pytorch/PyTorch/_build/results?buildId=17796, and https://github.com/pytorch/builder/blob/master/conda/pytorch-nightly/bld.bat vs https://github.com/pytorch/builder/blob/master/windows/build_pytorch.bat), I don't think OpenMP is causing this issue, and I couldn't find the actual reason that might lead to it. However, your finding may also be helpful for #19106, because if that is the case, then libtorch should be affected as well, since it uses the same build script.

@peterjc123
Collaborator

peterjc123 commented Nov 18, 2019

BTW, does replacing some of the DLLs with those in the conda package help? If yes, then we can actually compare the compile and link commands.

@pytorch pytorch deleted a comment from Rustin333 Nov 18, 2019
@CeadeS
Author

CeadeS commented Nov 18, 2019

Installing PyTorch with pip in an activated conda environment leads to a fast version. Only the installation outside of a conda environment is slow. Is there a missing dependency or a wrong version that is omitted by pip but installed correctly by Anaconda?
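A quick way to check whether a given install was built with MKL/MKL-DNN support at all (build-time flags, independent of which shared libraries end up loaded at runtime); a sketch, assuming a recent 1.x release:

import torch

# True if the binary was compiled with MKL / MKL-DNN support
print(torch.backends.mkl.is_available())
print(torch.backends.mkldnn.is_available())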

@peterjc123
Collaborator

@CeadeS Interesting. Would you please do the following to print out the DLLs that were loaded in the two environments:

# pip install psutil first
import psutil, os
import torch  # import torch first so that its shared libraries get loaded
p = psutil.Process(os.getpid())
for dll in p.memory_maps():
    print(dll.path)
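To narrow that listing down, a small filter over the same psutil output could help (a hypothetical variation, not part of the request above):

import os
import psutil
import torch  # import torch first so its shared libraries are loaded

p = psutil.Process(os.getpid())
# Keep only the libraries relevant here: MKL, MKL-DNN and the OpenMP
# runtimes (libgomp from GCC vs libiomp5 from Intel)
keywords = ("mkl", "gomp", "iomp")
for m in p.memory_maps():
    if any(k in os.path.basename(m.path).lower() for k in keywords):
        print(m.path)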

@CeadeS
Author

CeadeS commented Nov 19, 2019

@peterjc123 I generated the printouts for the different environments; they can be found here: https://github.com/CeadeS/notes/tree/master/torch_perf
I noticed that the non-working pip installation does not load anything MKL-related, but the conda installs do. The source builds, both the working and the non-working one, do not load anything MKL-related either.

@peterjc123
Collaborator

cc @soumith

@CeadeS
Author

CeadeS commented Nov 20, 2019

I replaced libgomp-7c85b1e2.so.1 with libiomp5.so.
The performance improved; I was able to reproduce the behavior. Do not ask me why.

@CeadeS
Author

CeadeS commented Nov 20, 2019

I guess it's FindOpenMP.cmake or FindOpenMKL.cmake; libiomp is not linked correctly even with the latest Intel build environment installed.

CeadeS added a commit to CeadeS/pytorch that referenced this issue Nov 21, 2019
MKL-DNN determines by itself whether Intel OpenMP is present or not. Setting this variable here prevents it from using Intel OpenMP. This results in the behavior described in pytorch#29722
@CeadeS
Author

CeadeS commented Nov 21, 2019

I found a working solution. FindMKLDNN.cmake enforces SET(MKLDNN_THREADING "OMP:COMP" CACHE STRING "") if no MKLDNN threading mode is defined. This prevents MKL-DNN from finding libiomp5 and disables Intel OpenMP threading in MKL-DNN. The default should be "OMP" so that MKL-DNN is able to look up libiomp5 by itself.
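A sketch of the change described above in FindMKLDNN.cmake (exact file location and the final upstream fix may differ):

# Before: forces the compiler's own OpenMP, so MKL-DNN never looks for libiomp5
SET(MKLDNN_THREADING "OMP:COMP" CACHE STRING "")

# After: let MKL-DNN detect the OpenMP runtime (including libiomp5) by itself
SET(MKLDNN_THREADING "OMP" CACHE STRING "")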
