
Could not jit compile custom extension in dataparallel mode #125403

Closed
daniil-lyakhov opened this issue May 2, 2024 · 0 comments
Labels
module: cpp-extensions Related to torch.utils.cpp_extension module: ddp Issues/PRs related distributed data parallel training triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module

daniil-lyakhov commented May 2, 2024

🐛 Describe the bug

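The functions torch.utils.cpp_extension.load_inline and torch.utils.cpp_extension.load do not work correctly when called in DataParallel mode, i.e. from the forward of a module wrapped in torch.nn.DataParallel, where each replica runs in its own thread: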

```python
import torch
from torch.utils.cpp_extension import load_inline

# The torch extensions cache (~/.cache/torch_extensions by default) must be cleared before running the test.
if not torch.cuda.is_available() or torch.cuda.device_count() < 2:
    raise RuntimeError("Wrong env for the reproducer.")

class TestModel(torch.nn.Module):
    def forward(self, x):
        code = "int f() {return 2;}"
        module = load_inline(
                name='jit_extension',
                cpp_sources=code,
                functions='f',
                verbose=True)
        return x * module.f()


model = torch.nn.DataParallel(TestModel().cuda())
output = model(torch.ones([10, 1, 1, 1], device="cuda"))
assert torch.all(output == 2.)
```

The reason is that while the first thread is building the extension, the second one asks JIT_EXTENSION_VERSIONER for the current version, and JIT_EXTENSION_VERSIONER returns the version already recorded for the first thread's (still unfinished) build:
https://github.com/pytorch/pytorch/blob/main/torch/utils/cpp_extension.py#L1677-L1696
The second thread then skips the lock file and tries to load an extension that does not exist yet, failing with the following error:

```
Traceback (most recent call last):
  File "/home/dlyakhov/miniconda3/lib/python3.12/runpy.py", line 198, in _run_module_as_main
    return _run_code(code, main_globals, None,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dlyakhov/miniconda3/lib/python3.12/runpy.py", line 88, in _run_code
    exec(code, run_globals)
  File "/home/dlyakhov/.vscode-server/extensions/ms-python.python-2023.8.0/pythonFiles/lib/python/debugpy/adapter/../../debugpy/launcher/../../debugpy/__main__.py", line 39, in <module>
    cli.main()
  File "/home/dlyakhov/.vscode-server/extensions/ms-python.python-2023.8.0/pythonFiles/lib/python/debugpy/adapter/../../debugpy/launcher/../../debugpy/../debugpy/server/cli.py", line 430, in main
    run()
  File "/home/dlyakhov/.vscode-server/extensions/ms-python.python-2023.8.0/pythonFiles/lib/python/debugpy/adapter/../../debugpy/launcher/../../debugpy/../debugpy/server/cli.py", line 284, in run_file
    runpy.run_path(target, run_name="__main__")
  File "/home/dlyakhov/.vscode-server/extensions/ms-python.python-2023.8.0/pythonFiles/lib/python/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 321, in run_path
    return _run_module_code(code, init_globals, run_name,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dlyakhov/.vscode-server/extensions/ms-python.python-2023.8.0/pythonFiles/lib/python/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 135, in _run_module_code
    _run_code(code, mod_globals, init_globals,
  File "/home/dlyakhov/.vscode-server/extensions/ms-python.python-2023.8.0/pythonFiles/lib/python/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 124, in _run_code
    exec(code, run_globals)
  File "/home/dlyakhov/Projects/pytorch/repro.py", line 20, in <module>
    output = model(torch.ones([10, 1, 1, 1], device="cuda"))
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dlyakhov/Projects/pytorch/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dlyakhov/Projects/pytorch/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dlyakhov/Projects/pytorch/torch/nn/parallel/data_parallel.py", line 185, in forward
    outputs = self.parallel_apply(replicas, inputs, module_kwargs)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dlyakhov/Projects/pytorch/torch/nn/parallel/data_parallel.py", line 200, in parallel_apply
    return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dlyakhov/Projects/pytorch/torch/nn/parallel/parallel_apply.py", line 108, in parallel_apply
    output.reraise()
  File "/home/dlyakhov/Projects/pytorch/torch/_utils.py", line 708, in reraise
    raise exception
ImportError: Caught ImportError in replica 1 on device 1.
Original Traceback (most recent call last):
  File "/home/dlyakhov/Projects/pytorch/torch/nn/parallel/parallel_apply.py", line 83, in _worker
    output = module(*input, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dlyakhov/Projects/pytorch/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dlyakhov/Projects/pytorch/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dlyakhov/Projects/pytorch/repro.py", line 11, in forward
    module = load_inline(
             ^^^^^^^^^^^^
  File "/home/dlyakhov/Projects/pytorch/torch/utils/cpp_extension.py", line 1644, in load_inline
    return _jit_compile(
           ^^^^^^^^^^^^^
  File "/home/dlyakhov/Projects/pytorch/torch/utils/cpp_extension.py", line 1745, in _jit_compile
    return _import_module_from_library(name, build_directory, is_python_module)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dlyakhov/Projects/pytorch/torch/utils/cpp_extension.py", line 2143, in _import_module_from_library
    module = importlib.util.module_from_spec(spec)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<frozen importlib._bootstrap>", line 813, in module_from_spec
  File "<frozen importlib._bootstrap_external>", line 1289, in create_module
  File "<frozen importlib._bootstrap>", line 488, in _call_with_frames_removed
ImportError: /home/dlyakhov/.cache/torch_extensions/py312_cu118/jit_extension/jit_extension.so: cannot open shared object file: No such file or directory
```
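The race described above can be reproduced without a compiler or GPUs. The following is a minimal, self-contained sketch under stated assumptions: `ToyVersioner`, `compile_stub`, `import_stub`, `racy_load`, `locked_load` and `run_two_threads` are all hypothetical stand-ins, not PyTorch's JIT_EXTENSION_VERSIONER or cpp_extension internals. The toy versioner records a new version before the build finishes; `racy_load` follows the failing sequence (check version, skip build, import a missing file), and `locked_load` shows one possible remedy, holding a single lock across the version check and the build. It is not necessarily the approach taken in #125404.

```python
import os
import tempfile
import threading
import time


class ToyVersioner:
    """Hypothetical per-name version registry (not PyTorch's JIT_EXTENSION_VERSIONER).

    Like the versioner described in the issue, it records a new version as
    soon as it is asked, i.e. before the matching build has finished.
    """

    def __init__(self):
        self._versions = {}  # name -> (version, source)
        self._lock = threading.Lock()

    def bump_if_changed(self, name, source):
        with self._lock:
            old = self._versions.get(name)
            if old is not None and old[1] == source:
                return old[0], False  # unchanged: caller assumes it is already built
            version = 0 if old is None else old[0] + 1
            self._versions[name] = (version, source)
            return version, True


def compile_stub(path, source, started):
    started.set()    # signal that a build is in progress
    time.sleep(0.5)  # simulate a slow compiler run
    with open(path, "w") as f:
        f.write(source)


def import_stub(path):
    if not os.path.exists(path):
        raise ImportError(f"{path}: cannot open shared object file")
    return path


def racy_load(versioner, build_dir, name, source, started):
    version, changed = versioner.bump_if_changed(name, source)
    path = os.path.join(build_dir, f"{name}_v{version}.so")
    if changed:
        compile_stub(path, source, started)
    # A caller that saw changed=False can reach this point before the file exists.
    return import_stub(path)


BUILD_LOCK = threading.Lock()


def locked_load(versioner, build_dir, name, source, started):
    # Version check and build share one lock, so a second caller waits for
    # the first build to finish instead of skipping it.
    with BUILD_LOCK:
        version, changed = versioner.bump_if_changed(name, source)
        path = os.path.join(build_dir, f"{name}_v{version}.so")
        if changed:
            compile_stub(path, source, started)
    return import_stub(path)


def run_two_threads(loader):
    """Start one builder, then a second caller once that build is underway."""
    versioner, started, errors = ToyVersioner(), threading.Event(), []
    build_dir = tempfile.mkdtemp()

    def worker():
        try:
            loader(versioner, build_dir, "jit_extension", "int f() {return 2;}", started)
        except ImportError as exc:
            errors.append(exc)

    first = threading.Thread(target=worker)
    first.start()
    started.wait()  # the first thread has bumped the version and is "compiling"
    second = threading.Thread(target=worker)
    second.start()
    first.join()
    second.join()
    return errors


racy_errors = run_two_threads(racy_load)      # the second caller skips the build and fails
locked_errors = run_two_threads(locked_load)  # the second caller waits and then succeeds
```

A single global lock is the simplest way to make "check version, then build" atomic; a per-extension lock would reduce contention when several different extensions are built concurrently.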

I would like to submit a PR with a fix: #125404

Versions

PyTorch version: 2.4.0a0+git8046de3
Is debug build: False
CUDA used to build PyTorch: 11.8
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.6 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.2) 9.4.0
Clang version: Could not collect
CMake version: version 3.26.4
Libc version: glibc-2.31

Python version: 3.12.2 | packaged by Anaconda, Inc. | (main, Feb 27 2024, 17:35:02) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-5.15.0-105-generic-x86_64-with-glibc2.31
Is CUDA available: True
CUDA runtime version: 11.8.89
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration:
GPU 0: NVIDIA GeForce RTX 3090
GPU 1: NVIDIA GeForce RTX 3090
GPU 2: NVIDIA GeForce RTX 3090

Nvidia driver version: 525.147.05
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Byte Order: Little Endian
Address sizes: 46 bits physical, 48 bits virtual
CPU(s): 36
On-line CPU(s) list: 0-35
Thread(s) per core: 2
Core(s) per socket: 18
Socket(s): 1
NUMA node(s): 1
Vendor ID: GenuineIntel
CPU family: 6
Model: 85
Model name: Intel(R) Core(TM) i9-10980XE CPU @ 3.00GHz
Stepping: 7
CPU MHz: 3000.000
CPU max MHz: 4800,0000
CPU min MHz: 1200,0000
BogoMIPS: 6000.00
Virtualization: VT-x
L1d cache: 576 KiB
L1i cache: 576 KiB
L2 cache: 18 MiB
L3 cache: 24,8 MiB
NUMA node0 CPU(s): 0-35
Vulnerability Gather data sampling: Mitigation; Microcode
Vulnerability Itlb multihit: KVM: Mitigation: VMX disabled
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Mitigation; Clear CPU buffers; SMT vulnerable
Vulnerability Retbleed: Mitigation; Enhanced IBRS
Vulnerability Spec rstack overflow: Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Enhanced IBRS, IBPB conditional, RSB filling, PBRSB-eIBRS SW sequence
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Mitigation; TSX disabled
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc art arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc cpuid aperfmperf pni pclmulqdq dtes64 monitor ds_cpl vmx est tm2 ssse3 sdbg fma cx16 xtpr pdcm pcid dca sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch cpuid_fault epb cat_l3 cdp_l3 invpcid_single ssbd mba ibrs ibpb stibp ibrs_enhanced tpr_shadow vnmi flexpriority ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid cqm mpx rdt_a avx512f avx512dq rdseed adx smap clflushopt clwb intel_pt avx512cd avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local dtherm ida arat pln pts hwp hwp_act_window hwp_epp hwp_pkg_req avx512_vnni md_clear flush_l1d arch_capabilities

Versions of relevant libraries:
[pip3] pytorch-triton==3.0.0+45fff310c8
[pip3] torch==2.4.0a0+git8046de3
[conda] magma-cuda110 2.5.2 1 pytorch
[conda] mkl-include 2024.1.0 intel_691 intel
[conda] mkl-static 2024.1.0 intel_691 intel
[conda] pytorch-triton 3.0.0+45fff310c8 pypi_0 pypi
[conda] torch 2.4.0a0+git8046de3 dev_0

cc @malfet @zou3519

drisspg added the module: cpp-extensions, triaged, and module: ddp labels May 3, 2024
alexsu52 pushed a commit to openvinotoolkit/nncf that referenced this issue May 3, 2024
…#2662)

### Changes

Model moved to cuda before the quantization in
`test_works_when_wrapped_with_dataparallel`

### Reason for changes

To build CUDA extensions in the main process, as PyTorch cannot build
extensions in DataParallel mode

### Related tickets

pytorch/pytorch#125403

### Tests

torch_nightly/204/ - Passed

The test started to fail because the order of the tests changed:

torch_nightly/197:
```
tests/torch/nas/test_sanity_sample.py ........                           [  0%]
[2024-04-25T00:01:42.305Z] tests/torch/quantization/test_functions.py ............sss...sss........ [  2%]
...
```
torch_nightly/198
```
[2024-04-26T00:03:47.228Z] tests/torch/nas/test_sanity_sample.py ........                           [  0%]
[2024-04-26T00:03:47.804Z] tests/torch/nncf_network/test_nncf_network.py .                          [  0%]
[2024-04-26T00:04:26.574Z] tests/torch/quantization/test_algo_quantization.py F                     [  0%]
```
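The nncf change above works around the bug by making sure the extensions are built in the main process before DataParallel starts its replica threads. The torch-free sketch below illustrates why that ordering helps; `LazyExtensionCache` and `run_replicas` are hypothetical stand-ins, not PyTorch or nncf APIs. With PyTorch itself, the analogous step would presumably be calling load_inline once outside forward (for example at module scope) so that the replicas only use the already-loaded module.

```python
import time
from concurrent.futures import ThreadPoolExecutor


class LazyExtensionCache:
    """Hypothetical lazily-built extension cache (not a PyTorch API).

    Like the versioner in the issue, it marks a name as known *before*
    the build finishes, so a concurrent first caller can skip the build
    and fail to find the artifact.
    """

    def __init__(self, build_seconds=0.3):
        self._known = set()
        self._built = {}
        self._build_seconds = build_seconds

    def load(self, name):
        if name not in self._known:
            self._known.add(name)            # recorded before the build completes
            time.sleep(self._build_seconds)  # simulate a slow compile
            self._built[name] = lambda: 2    # stands in for the compiled f()
        if name not in self._built:
            raise ImportError(f"{name}: cannot open shared object file")
        return self._built[name]


def run_replicas(cache, n_replicas=2):
    """Call cache.load from several threads, like DataParallel replicas."""
    with ThreadPoolExecutor(max_workers=n_replicas) as pool:
        futures = [pool.submit(cache.load, "jit_extension") for _ in range(n_replicas)]
    results = []
    for fut in futures:
        try:
            results.append(fut.result()())
        except ImportError:
            results.append(None)  # this replica hit the race
    return results


# Workaround: trigger the build once in the main thread before fanning out.
cache = LazyExtensionCache()
cache.load("jit_extension")
print(run_replicas(cache))  # every replica finds the finished build: [2, 2]
```

Without the warm-up call, the first concurrent calls to `load` on a fresh cache can fail the same way the replica in the traceback did, depending on thread timing; with it, the replicas never enter the build path at all.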