torch.compile is not usable on MacOS out of box #122705

Closed
malfet opened this issue Mar 26, 2024 · 7 comments
Assignees
Labels
high priority module: binaries Anything related to official binaries that we release to users module: macos Mac OS related issues module: openmp Related to OpenMP (omp) support in PyTorch oncall: pt2 triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module
Milestone

Comments

@malfet
Contributor

malfet commented Mar 26, 2024

🐛 Describe the bug

Any attempt to run torch.compile'd code fails out of the box with the 2.3.0 wheels, because it tries to link against the wrong copy of libomp.dylib:

OMP: Error #15: Initializing libomp.dylib, but found libomp.dylib already initialized.
OMP: Hint This means that multiple copies of the OpenMP runtime have been linked into the program. That is dangerous, since it can degrade performance or cause incorrect results. The best thing to do is to ensure that only a single OpenMP runtime is linked into the process, e.g. by avoiding static linking of the OpenMP runtime in any library. As an unsafe, unsupported, undocumented workaround you can set the environment variable KMP_DUPLICATE_LIB_OK=TRUE to allow the program to continue to execute, but that may cause crashes or silently produce incorrect results. For more information, please see http://openmp.llvm.org/

This happens because delocate embeds libomp.dylib into the torch/.dylibs folder, while another copy is shipped in torch/lib.
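
To see whether an installed wheel is affected, one can search the torch package directory for bundled copies of the OpenMP runtime. The following is a minimal sketch, not part of PyTorch: the helper names are hypothetical, and it assumes the runtime files are named libomp*.dylib and that the package can be located with importlib without importing it.

```python
from __future__ import annotations

from importlib.util import find_spec
from pathlib import Path
from typing import Optional


def find_libomp_copies(package_dir: Path) -> list[Path]:
    """Return every libomp*.dylib found anywhere under package_dir.

    pathlib's rglob also descends into dot-directories such as .dylibs.
    """
    return sorted(p for p in package_dir.rglob("libomp*.dylib") if p.is_file())


def torch_package_dir() -> Optional[Path]:
    """Locate the installed torch package without importing it (None if absent)."""
    spec = find_spec("torch")
    if spec is None or not spec.submodule_search_locations:
        return None
    return Path(next(iter(spec.submodule_search_locations)))


if __name__ == "__main__":
    root = torch_package_dir()
    if root is None:
        print("torch is not installed in this environment")
    else:
        copies = find_libomp_copies(root)
        for path in copies:
            print(path.relative_to(root))
        if len(copies) > 1:
            print("more than one libomp.dylib is bundled with this wheel")
```

On an affected wheel this would be expected to list one entry under .dylibs and one under lib.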

Versions

2.3.0

cc @ezyang @gchanan @zou3519 @kadeng @seemethere @osalpekar @atalman @albanD @msaroufim @bdhirsh @anijain2305 @chauhang

@malfet malfet added high priority module: binaries Anything related to official binaries that we release to users module: macos Mac OS related issues module: openmp Related to OpenMP (omp) support in PyTorch oncall: pt2 labels Mar 26, 2024
@malfet malfet added this to the 2.3.0 milestone Mar 26, 2024
@yf225
Contributor

yf225 commented Apr 2, 2024

Maybe a better error message, such as "please install OpenMP via conda", would help.

@desertfire
Contributor

desertfire commented Apr 2, 2024

OK, I tried downloading a nightly locally, and I did get a warning message:

warning: overriding currently unsupported use of floating point exceptions on this target [-Wunsupported-floating-point-opt]
In file included from /var/folders/np/d5kv7dqs1s50chyql3nx1nnm0000gn/T/torchinductor_binbao/uj/cujc4esel7544h7dr4wb35xl5dcay5k5kkfkqjaanqdw2gex4gql.cpp:2:
/var/folders/np/d5kv7dqs1s50chyql3nx1nnm0000gn/T/torchinductor_binbao/nd/cndd7co72iqjtof53ikp4l7yibmqrbjkni3cu6xj5p7hywloe5yg.h:8:10: fatal error: 'omp.h' file not found
#include <omp.h>
         ^~~~~~~
1 warning and 1 error generated.


OpenMP support not found. Please try one of the following solutions:
(1) Set the `CXX` environment variable to a compiler other than Apple clang++/g++ that has builtin OpenMP support;
(2) install OpenMP via conda: `conda install llvm-openmp`;
(3) install libomp via brew: `brew install libomp`;
(4) manually setup OpenMP and set the `OMP_PREFIX` environment variable to point to a path with `include/omp.h` under it.
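
Option (4) can be checked before compiling anything. This is a minimal sketch that relies only on what the message states, namely that the prefix must contain include/omp.h. The helper name is hypothetical, treating CONDA_PREFIX as a second candidate is an assumption made here, and this is not necessarily how Inductor itself searches for the header.

```python
from __future__ import annotations

import os
from pathlib import Path


def prefixes_with_omp_header(candidates: list[str]) -> list[Path]:
    """Return the candidate prefixes that contain include/omp.h (empty strings are skipped)."""
    return [Path(c) for c in candidates if c and (Path(c) / "include" / "omp.h").is_file()]


if __name__ == "__main__":
    # OMP_PREFIX comes from the error message; CONDA_PREFIX is an assumed extra candidate.
    env_candidates = [os.environ.get(name, "") for name in ("OMP_PREFIX", "CONDA_PREFIX")]
    usable = prefixes_with_omp_header(env_candidates)
    if usable:
        for prefix in usable:
            print(f"omp.h found under {prefix}")
    else:
        print("no candidate prefix contains include/omp.h")
```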

Not sure how to reproduce @malfet's problem.

@desertfire desertfire added triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module and removed triage review labels Apr 2, 2024
@malfet
Contributor Author

malfet commented Apr 2, 2024

Hmm, omp.h should have been bundled with the torch wheel, see

omp_cflags = get_cmake_cache_vars()["OpenMP_C_FLAGS"]

But indeed, I don't see it there:

% python -c "import torch;from pathlib import Path;print(torch.__version__, torch.backends.openmp.is_available(), '\n'.join(str(x) for x in (Path(torch.utils.cmake_prefix_path).parent.parent / 'include').glob('*.h')))"
2.4.0.dev20240401 False /Users/nshulga/miniforge3/envs/py311/lib/python3.11/site-packages/torch/include/sleef.h
/Users/nshulga/miniforge3/envs/py311/lib/python3.11/site-packages/torch/include/cpuinfo.h
/Users/nshulga/miniforge3/envs/py311/lib/python3.11/site-packages/torch/include/nnpack.h
/Users/nshulga/miniforge3/envs/py311/lib/python3.11/site-packages/torch/include/xnnpack.h
/Users/nshulga/miniforge3/envs/py311/lib/python3.11/site-packages/torch/include/libshm.h
/Users/nshulga/miniforge3/envs/py311/lib/python3.11/site-packages/torch/include/fp16.h
/Users/nshulga/miniforge3/envs/py311/lib/python3.11/site-packages/torch/include/qnnpack_func.h
/Users/nshulga/miniforge3/envs/py311/lib/python3.11/site-packages/torch/include/pthreadpool.h
/Users/nshulga/miniforge3/envs/py311/lib/python3.11/site-packages/torch/include/clog.h
/Users/nshulga/miniforge3/envs/py311/lib/python3.11/site-packages/torch/include/psimd.h
/Users/nshulga/miniforge3/envs/py311/lib/python3.11/site-packages/torch/include/experiments-config.h
/Users/nshulga/miniforge3/envs/py311/lib/python3.11/site-packages/torch/include/fxdiv.h

But the same works for 2.3.0-RC1:

% python -c "import torch;from pathlib import Path;print(torch.__version__, torch.backends.openmp.is_available(), '\n'.join(str(x) for x in (Path(torch.utils.cmake_prefix_path).parent.parent / 'include').glob('*.h')))"
2.3.0 True /Users/nshulga/miniforge3/envs/py311/lib/python3.11/site-packages/torch/include/sleef.h
/Users/nshulga/miniforge3/envs/py311/lib/python3.11/site-packages/torch/include/cpuinfo.h
/Users/nshulga/miniforge3/envs/py311/lib/python3.11/site-packages/torch/include/nnpack.h
/Users/nshulga/miniforge3/envs/py311/lib/python3.11/site-packages/torch/include/xnnpack.h
/Users/nshulga/miniforge3/envs/py311/lib/python3.11/site-packages/torch/include/libshm.h
/Users/nshulga/miniforge3/envs/py311/lib/python3.11/site-packages/torch/include/fp16.h
/Users/nshulga/miniforge3/envs/py311/lib/python3.11/site-packages/torch/include/qnnpack_func.h
/Users/nshulga/miniforge3/envs/py311/lib/python3.11/site-packages/torch/include/pthreadpool.h
/Users/nshulga/miniforge3/envs/py311/lib/python3.11/site-packages/torch/include/clog.h
/Users/nshulga/miniforge3/envs/py311/lib/python3.11/site-packages/torch/include/omp.h
/Users/nshulga/miniforge3/envs/py311/lib/python3.11/site-packages/torch/include/psimd.h
/Users/nshulga/miniforge3/envs/py311/lib/python3.11/site-packages/torch/include/experiments-config.h
/Users/nshulga/miniforge3/envs/py311/lib/python3.11/site-packages/torch/include/fxdiv.h
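
The one-liner used for both runs above is easier to reuse as a small script. This sketch inspects the same directory (the parent of the parent of torch.utils.cmake_prefix_path, plus include) and reports only whether omp.h is among the top-level headers; the helper name is hypothetical.

```python
from __future__ import annotations

from pathlib import Path


def top_level_headers(include_dir: Path) -> list[str]:
    """Names of the *.h files directly inside include_dir, sorted."""
    return sorted(p.name for p in include_dir.glob("*.h"))


if __name__ == "__main__":
    try:
        import torch
    except ImportError:
        print("torch is not installed in this environment")
    else:
        # Same directory the one-liner inspects: <site-packages>/torch/include
        include_dir = Path(torch.utils.cmake_prefix_path).parent.parent / "include"
        headers = top_level_headers(include_dir)
        print(torch.__version__, "OpenMP available:", torch.backends.openmp.is_available())
        print("omp.h bundled:", "omp.h" in headers)
```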

malfet added a commit to pytorch/torchchat that referenced this issue Apr 3, 2024
Also, move to M1 mac (by adding macos-14 to the matrix), but right now it's blocked by pytorch/pytorch#123225 and pytorch/pytorch#122705 (though the latter is occluded by pipe redirect)
@malfet malfet assigned malfet and unassigned desertfire Apr 3, 2024
@malfet
Contributor Author

malfet commented Apr 3, 2024

Grabbing this for myself, since right now it's a packaging problem: we ship libomp.dylib twice.

@malfet
Contributor Author

malfet commented Apr 4, 2024

@atalman fyi

% shasum ~/miniforge3/envs/py311-torch230/lib/python3.11/site-packages/torch/.dylibs/libomp.dylib ~/miniforge3/envs/py311-torch230/lib/python3.11/site-packages/torch/lib/libomp.dylib
3e7bbc2948a3c5ab0ed09fcb8912611e958ba1be  /Users/nshulga/miniforge3/envs/py311-torch230/lib/python3.11/site-packages/torch/.dylibs/libomp.dylib
cbce7ccb089ed3598d5fbd9a204f8d5c52b91f7f  /Users/nshulga/miniforge3/envs/py311-torch230/lib/python3.11/site-packages/torch/lib/libomp.dylib
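
The two digests differ, so the bundled files are not byte-identical copies of one another. The same comparison can be scripted; shasum defaults to SHA-1, which hashlib also provides. This is a sketch with hypothetical helper names, and the caller supplies the paths to compare.

```python
from __future__ import annotations

import hashlib
from collections import defaultdict
from pathlib import Path


def sha1_of(path: Path) -> str:
    """SHA-1 hex digest of a file, read in 1 MiB chunks."""
    digest = hashlib.sha1()
    with path.open("rb") as fh:
        for chunk in iter(lambda: fh.read(1 << 20), b""):
            digest.update(chunk)
    return digest.hexdigest()


def group_by_digest(paths: list[Path]) -> dict[str, list[Path]]:
    """Map each digest to the files that share it."""
    groups: dict[str, list[Path]] = defaultdict(list)
    for path in paths:
        groups[sha1_of(path)].append(path)
    return dict(groups)
```

If group_by_digest returns more than one key for the libomp.dylib paths, the wheel carries distinct builds of the runtime rather than one file stored twice.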

@atalman
Contributor

atalman commented Apr 4, 2024

cc @malfet

Looks like it's looking in a completely different path:
miniconda3/envs/py310/lib/libomp.dylib
rather than torch/lib or the relocated lib

  File "/Users/atalman/miniconda3/envs/py310/lib/python3.10/site-packages/torch/_inductor/codecache.py", line 2112, in _load_library_inner
    module = importlib.util.module_from_spec(spec)
  File "<frozen importlib._bootstrap>", line 571, in module_from_spec
  File "<frozen importlib._bootstrap_external>", line 1176, in create_module
  File "<frozen importlib._bootstrap>", line 241, in _call_with_frames_removed
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
ImportError: dlopen(/var/folders/c_/2yk7dry97cs9xg36yv7977hm0000gn/T/torchinductor_atalman/l5/cl5ix6nduljulc7mkdk4kuai3ezvnt6z6kpspq7x5drbduw3gmbl.so, 0x0002): Library not loaded: @rpath/libomp.dylib
  Referenced from: <526A8617-346C-3DD1-987B-9B31B4F9A70F> /private/var/folders/c_/2yk7dry97cs9xg36yv7977hm0000gn/T/torchinductor_atalman/l5/cl5ix6nduljulc7mkdk4kuai3ezvnt6z6kpspq7x5drbduw3gmbl.so
  Reason: tried: '/Users/atalman/miniconda3/envs/py310/bin/../lib/libomp.dylib' (no such file), '/Users/atalman/miniconda3/envs/py310/bin/../lib/libomp.dylib' (no such file)

Running otool -L on the generated .so:

otool -L /var/folders/c_/2yk7dry97cs9xg36yv7977hm0000gn/T/torchinductor_atalman/l5/cl5ix6nduljulc7mkdk4kuai3ezvnt6z6kpspq7x5drbduw3gmbl.so
/var/folders/c_/2yk7dry97cs9xg36yv7977hm0000gn/T/torchinductor_atalman/l5/cl5ix6nduljulc7mkdk4kuai3ezvnt6z6kpspq7x5drbduw3gmbl.so:
	/var/folders/c_/2yk7dry97cs9xg36yv7977hm0000gn/T/torchinductor_atalman/l5/cl5ix6nduljulc7mkdk4kuai3ezvnt6z6kpspq7x5drbduw3gmbl.so (compatibility version 0.0.0, current version 0.0.0)
	@rpath/libomp.dylib (compatibility version 5.0.0, current version 5.0.0)
	@rpath/libc10.dylib (compatibility version 0.0.0, current version 0.0.0)
	/usr/lib/libc++.1.dylib (compatibility version 1.0.0, current version 1700.255.0)
	/usr/lib/libSystem.B.dylib (compatibility version 1.0.0, current version 1345.100.2)
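
The entries that matter in this listing are the ones starting with @rpath/, since those are resolved at load time against the run-path search list, and the dlopen error above shows that lookup failing for libomp.dylib. A small sketch for pulling them out of captured otool -L text follows. The function name is hypothetical, and it assumes the output layout shown above: a header line naming the file, then one dependency per line.

```python
from __future__ import annotations


def rpath_dependencies(otool_output: str) -> list[str]:
    """Extract the @rpath/... install names from `otool -L` text output."""
    deps = []
    for line in otool_output.splitlines()[1:]:  # first line names the inspected file
        name = line.strip().split(" (compatibility version", 1)[0]
        if name.startswith("@rpath/"):
            deps.append(name)
    return deps


if __name__ == "__main__":
    # Hypothetical sample with the same shape as the listing above.
    sample = (
        "/tmp/example.so:\n"
        "\t/tmp/example.so (compatibility version 0.0.0, current version 0.0.0)\n"
        "\t@rpath/libomp.dylib (compatibility version 5.0.0, current version 5.0.0)\n"
        "\t/usr/lib/libSystem.B.dylib (compatibility version 1.0.0, current version 1345.100.2)\n"
    )
    print(rpath_dependencies(sample))
```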

malfet added a commit that referenced this issue Apr 5, 2024
To prevent delocate from double-packing it, which makes Torch wheels
unusable with torch.compile out of the box

Fixes #122705
@atalman
Contributor

atalman commented Apr 5, 2024

Please note: the double packaging of libomp.dylib currently exists only in trunk. Release 2.3.0 does not seem to be affected. These look like two different problems:
1 - path
2 - packaging of libomp twice

pytorchbot pushed a commit that referenced this issue Apr 5, 2024
To prevent delocate from double-packing it, which makes Torch wheels
unusable with torch.compile out of the box

Fixes #122705

Pull Request resolved: #123417
Approved by: https://github.com/atalman

(cherry picked from commit 5b0ce8f)
atalman pushed a commit to atalman/pytorch that referenced this issue Apr 5, 2024
To prevent delocate from double-packing it, which makes Torch wheels
unusable with torch.compile out of the box

Fixes pytorch#122705

Pull Request resolved: pytorch#123417
Approved by: https://github.com/atalman
atalman added a commit that referenced this issue Apr 5, 2024
To prevent delocate from double-packing it, which makes Torch wheels
unusable with torch.compile out of the box

Fixes #122705

Pull Request resolved: #123417
Approved by: https://github.com/atalman

Co-authored-by: Nikita Shulga <nikita.shulga@gmail.com>
sanketpurandare pushed a commit to sanketpurandare/pytorch that referenced this issue Apr 22, 2024
To prevent delocate from double-packing it, which makes Torch wheels
unusable with torch.compile out of the box

Fixes pytorch#122705

Pull Request resolved: pytorch#123417
Approved by: https://github.com/atalman
malfet added a commit to pytorch/torchchat that referenced this issue Jul 17, 2024
Also, move to M1 mac (by adding macos-14 to the matrix), but right now it's blocked by pytorch/pytorch#123225 and pytorch/pytorch#122705 (though the latter is occluded by pipe redirect)
4 participants