TORCH_DISTRIBUTED_DEBUG not effective #70667

Closed
cicirori opened this issue Jan 5, 2022 · 10 comments
Labels: better-engineering, high priority, module: c10d, oncall: distributed

Comments

@cicirori commented Jan 5, 2022

🐛 Describe the bug

I'm testing TORCH_DISTRIBUTED_DEBUG according to this document: https://pytorch.org/docs/master/distributed.html#debugging-torch-distributed-applications

import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


class TwoLinLayerNet(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.a = torch.nn.Linear(10, 10, bias=False)
        self.b = torch.nn.Linear(10, 1, bias=False)

    def forward(self, x):
        a = self.a(x)
        b = self.b(x)
        return (a, b)


def worker(rank):
    dist.init_process_group("nccl", rank=rank, world_size=2)
    torch.cuda.set_device(rank)
    print("init model")
    model = TwoLinLayerNet().cuda()
    print("init ddp")
    ddp_model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[rank])

    inp = torch.randn(10, 10).cuda()
    print("train")

    for _ in range(20):
        output = ddp_model(inp)
        loss = output[0] + output[1]
        loss.sum().backward()


if __name__ == "__main__":
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "29501"
    os.environ[
        "TORCH_DISTRIBUTED_DEBUG"
    ] = "DETAIL"  # set to DETAIL for runtime logging.
    mp.spawn(worker, nprocs=2, args=())

The code above is copied verbatim from that document, but the output contains no DDP debug information:

init model
init model
init ddp
init ddp
i39a12200:2427:2427 [0] NCCL INFO Bootstrap : Using bond0:11.164.100.100<0>
i39a12200:2427:2427 [0] NCCL INFO NET/Plugin : No plugin found (libnccl-net.so), using internal implementation

i39a12200:2427:2427 [0] misc/ibvwrap.cc:63 NCCL WARN Failed to open libibverbs.so[.1]
i39a12200:2427:2427 [0] NCCL INFO NET/Socket : Using [0]bond0:11.164.100.100<0>
i39a12200:2427:2427 [0] NCCL INFO Using network Socket
NCCL version 2.10.3+cuda11.3
i39a12200:2428:2428 [1] NCCL INFO Bootstrap : Using bond0:11.164.100.100<0>
i39a12200:2428:2428 [1] NCCL INFO NET/Plugin : No plugin found (libnccl-net.so), using internal implementation

i39a12200:2428:2428 [1] misc/ibvwrap.cc:63 NCCL WARN Failed to open libibverbs.so[.1]
i39a12200:2428:2428 [1] NCCL INFO NET/Socket : Using [0]bond0:11.164.100.100<0>
i39a12200:2428:2428 [1] NCCL INFO Using network Socket
i39a12200:2427:2577 [0] NCCL INFO NCCL_MAX_NCHANNELS set by environment to 2.
i39a12200:2428:2578 [1] NCCL INFO NCCL_MAX_NCHANNELS set by environment to 2.
i39a12200:2427:2577 [0] NCCL INFO NCCL_MIN_NCHANNELS set by environment to 2.
i39a12200:2427:2577 [0] NCCL INFO Channel 00/02 :    0   1
i39a12200:2428:2578 [1] NCCL INFO NCCL_MIN_NCHANNELS set by environment to 2.
i39a12200:2428:2578 [1] NCCL INFO Trees [0] -1/-1/-1->1->0 [1] -1/-1/-1->1->0
i39a12200:2427:2577 [0] NCCL INFO Channel 01/02 :    0   1
i39a12200:2428:2578 [1] NCCL INFO Setting affinity for GPU 1 to ff,ffffffff,ffffffff,ffffffff
i39a12200:2427:2577 [0] NCCL INFO Trees [0] 1/-1/-1->0->-1 [1] 1/-1/-1->0->-1
i39a12200:2427:2577 [0] NCCL INFO Setting affinity for GPU 0 to ff,ffffffff,ffffffff,ffffffff
i39a12200:2428:2578 [1] NCCL INFO Channel 00 : 1[57000] -> 0[52000] via P2P/IPC/read
i39a12200:2427:2577 [0] NCCL INFO Channel 00 : 0[52000] -> 1[57000] via P2P/IPC/read
i39a12200:2428:2578 [1] NCCL INFO Channel 01 : 1[57000] -> 0[52000] via P2P/IPC/read
i39a12200:2427:2577 [0] NCCL INFO Channel 01 : 0[52000] -> 1[57000] via P2P/IPC/read
i39a12200:2428:2578 [1] NCCL INFO Connected all rings
i39a12200:2428:2578 [1] NCCL INFO Connected all trees
i39a12200:2428:2578 [1] NCCL INFO threadThresholds 8/8/64 | 16/8/64 | 8/8/512
i39a12200:2428:2578 [1] NCCL INFO 2 coll channels, 2 p2p channels, 2 p2p channels per peer
i39a12200:2427:2577 [0] NCCL INFO Connected all rings
i39a12200:2427:2577 [0] NCCL INFO Connected all trees
i39a12200:2427:2577 [0] NCCL INFO threadThresholds 8/8/64 | 16/8/64 | 8/8/512
i39a12200:2427:2577 [0] NCCL INFO 2 coll channels, 2 p2p channels, 2 p2p channels per peer
i39a12200:2428:2578 [1] NCCL INFO comm 0x7fdb5c002f70 rank 1 nranks 2 cudaDev 1 busId 57000 - Init COMPLETE
i39a12200:2427:2577 [0] NCCL INFO comm 0x7fbfdc002f70 rank 0 nranks 2 cudaDev 0 busId 52000 - Init COMPLETE
i39a12200:2427:2427 [0] NCCL INFO Launch mode Parallel
train
train

Do I need to build with USE_GLOG enabled, or set CAFFE2_LOG_THRESHOLD so that INFO messages are printed, in order to see these messages? If so, I suggest refining the documentation here.

The PyTorch build used:

pip3 install torch==1.10.1+cu113 torchvision==0.11.2+cu113 torchaudio==0.10.1+cu113 -f https://download.pytorch.org/whl/cu113/torch_stable.html

Versions

PyTorch version: 1.10.1+cu113
Is debug build: False
CUDA used to build PyTorch: 11.3
ROCM used to build PyTorch: N/A

OS: Ubuntu 18.04.5 LTS (x86_64)
GCC version: (Ubuntu 7.5.0-3ubuntu1~18.04) 7.5.0
Clang version: Could not collect
CMake version: version 3.19.1
Libc version: glibc-2.9

Python version: 3.6.12 |Anaconda, Inc.| (default, Sep  8 2020, 23:10:56)  [GCC 7.3.0] (64-bit runtime)
Python platform: Linux-3.10.0-327.ali2019.alios7.x86_64-x86_64-with-debian-buster-sid
Is CUDA available: True
CUDA runtime version: 11.3.109
GPU models and configuration:
GPU 0: NVIDIA A100-SXM4-40GB
GPU 1: NVIDIA A100-SXM4-40GB
GPU 2: NVIDIA A100-SXM4-40GB
GPU 3: NVIDIA A100-SXM4-40GB
GPU 4: NVIDIA A100-SXM4-40GB
GPU 5: NVIDIA A100-SXM4-40GB
GPU 6: NVIDIA A100-SXM4-40GB
GPU 7: NVIDIA A100-SXM4-40GB

Nvidia driver version: 470.82.01
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.2.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8.2.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8.2.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8.2.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8.2.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.2.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.8.2.0
HIP runtime version: N/A
MIOpen runtime version: N/A

Versions of relevant libraries:
[pip3] numpy==1.16.6
[pip3] torch==1.10.1+cu113
[pip3] torch-cluster==1.5.9
[pip3] torch-geometric==2.0.3
[pip3] torch-scatter==2.0.9
[pip3] torch-sparse==0.6.12
[pip3] torch-spline-conv==1.2.1
[pip3] torch-xla==1.10
[pip3] torchaudio==0.10.1+cu113
[pip3] torchtext==0.11.1
[pip3] torchvision==0.11.2+cu113
[conda] magma-cuda113             2.5.2                         1    pytorch
[conda] mkl                       2021.4.0           h06a4308_640    defaults
[conda] mkl-include               2022.0.1           h8d4b97c_803    conda-forge
[conda] numpy                     1.16.6                   pypi_0    pypi
[conda] torch                     1.10.1+cu113             pypi_0    pypi
[conda] torch-cluster             1.5.9                    pypi_0    pypi
[conda] torch-geometric           2.0.3                    pypi_0    pypi
[conda] torch-scatter             2.0.9                    pypi_0    pypi
[conda] torch-sparse              0.6.12                   pypi_0    pypi
[conda] torch-spline-conv         1.2.1                    pypi_0    pypi
[conda] torch-xla                 1.10                     pypi_0    pypi
[conda] torchaudio                0.10.1+cu113             pypi_0    pypi
[conda] torchtext                 0.11.1                   pypi_0    pypi
[conda] torchvision               0.11.2+cu113             pypi_0    pypi

cc @ezyang @gchanan @zou3519 @pietern @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @SciPioneer @H-Huang

facebook-github-bot added the oncall: distributed label on Jan 5, 2022
@pritamdamania87 (Contributor) commented:

I think you might have to compile with USE_GLOG=1 to enable this. @rohan-varma, I was wondering if we tested this with the OSS PyTorch binaries. It seems like USE_GLOG is off by default (https://github.com/pytorch/pytorch/blob/master/CMakeLists.txt#L217), so I was wondering whether this additional logging would be missing if users just install the PyTorch binaries.

pritamdamania87 added the module: c10d label on Jan 5, 2022
rohan-varma self-assigned this on Jan 6, 2022
@jbuckman commented Jan 9, 2022

I'm having this same issue, but somehow getting even less information. Running the exact code above, no debug info is printed, only:

init model
init model
init ddp
init ddp
train
train

Version info:

PyTorch version: 1.10.1+cu113
Is debug build: False
CUDA used to build PyTorch: 11.3
ROCM used to build PyTorch: N/A

OS: Ubuntu 18.04.5 LTS (x86_64)
GCC version: (Ubuntu 7.5.0-3ubuntu1~18.04) 7.5.0
Clang version: Could not collect
CMake version: version 3.15.5
Libc version: glibc-2.27

Python version: 3.9.4 (default, Apr  9 2021, 16:34:09)  [GCC 7.3.0] (64-bit runtime)
Python platform: Linux-5.4.0-1064-azure-x86_64-with-glibc2.27
Is CUDA available: True
CUDA runtime version: 10.1.243
GPU models and configuration: 
GPU 0: Tesla P100-PCIE-16GB
GPU 1: Tesla P100-PCIE-16GB
GPU 2: Tesla P100-PCIE-16GB
GPU 3: Tesla P100-PCIE-16GB

Nvidia driver version: 460.91.03
cuDNN version: /usr/lib/x86_64-linux-gnu/libcudnn.so.7.6.5
HIP runtime version: N/A
MIOpen runtime version: N/A

Versions of relevant libraries:
[pip3] numpy==1.20.1
[pip3] numpy-ml==0.1.2
[pip3] pytorch-lightning==1.4.0.dev0
[pip3] torch==1.10.1+cu113
[pip3] torchaudio==0.10.1+cu113
[pip3] torchmetrics==0.3.2
[pip3] torchvision==0.11.2+cu113
[conda] blas                      1.0                         mkl  
[conda] cudatoolkit               11.1.1               h6406543_8    conda-forge
[conda] mkl                       2021.2.0           h726a3e6_389    conda-forge
[conda] mkl-service               2.3.0            py39h27cfd23_1  
[conda] mkl_fft                   1.3.0            py39h42c9631_2  
[conda] mkl_random                1.2.1            py39ha9443f7_2  
[conda] numpy                     1.20.1           py39h93e21f0_0  
[conda] numpy-base                1.20.1           py39h7d8b39e_0  
[conda] numpy-ml                  0.1.2                    pypi_0    pypi
[conda] pytorch-lightning         1.4.0.dev0                dev_0    <develop>
[conda] torch                     1.10.1+cu113             pypi_0    pypi
[conda] torchaudio                0.10.1+cu113             pypi_0    pypi
[conda] torchmetrics              0.3.2                    pypi_0    pypi
[conda] torchvision               0.11.2+cu113             pypi_0    pypi

@cicirori (Author) commented:

(quoting @jbuckman's report above)

Set NCCL_DEBUG=INFO and you should get exactly the same output as mine.
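
For reference, a minimal way to enable that in the repro above (an assumption about placement: NCCL reads the variable when the process group is initialized, so it has to be set before init_process_group runs, e.g. next to the other os.environ assignments in __main__; spawned workers inherit the parent's environment):

    os.environ["NCCL_DEBUG"] = "INFO"  # NCCL's own logging, separate from TORCH_DISTRIBUTED_DEBUG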

@rohan-varma (Member) commented:

Thanks for pointing this out! It indeed looks like USE_GLOG is off in the prebuilt PyTorch binaries, and thus calls like LOG(INFO) don't produce any output. Here are a couple of similar discussions:

@malfet - is there any context on why the prebuilt PyTorch binaries have USE_GLOG, and thus LOG calls, disabled with no way to turn them on? Is there a workaround, or could we consider building the prebuilt binaries with USE_GLOG on?

In the short term, we can resolve this by logging to stderr, although writing these logs with LOG(INFO) would be preferable.

@rohan-varma (Member) commented:

cc @cbalioglu

@cbalioglu (Contributor) commented:

See #68226 as well.

@malfet (Contributor) commented Jan 21, 2022

@malfet - is there any context on why the prebuilt PyTorch binaries have USE_GLOG, and thus LOG calls, disabled with no way to turn them on? Is there a workaround, or could we consider building the prebuilt binaries with USE_GLOG on?

GLOG was disabled by #16789; you can probably ask @soumith for the reason why, but my guess is that it introduces lots of runtime dependencies (glog, gflags, and protobuf are all notoriously hard to integrate without conflicting with other installations of those libraries on the system).
Also, it feels a bit like overkill, since https://github.com/pytorch/pytorch/blob/master/c10/util/logging_is_not_google_glog.h should provide a similar abstraction, and I believe it is translated back to Python logging (though I couldn't find the code off the top of my head).

It would be a great BE (better-engineering) effort to unify PyTorch's multiple logging primitives, such as base c10 logging and systems like jit/runtime/logging.h and jit/jit_log.h, to name a few.

Please note that debug-level logging is likely compiled out of release builds for performance reasons, as even a dummy function call is expensive in code that is called billions of times.

@malfet (Contributor) commented Jan 25, 2022

It looks like the easiest solution for the TORCH_DISTRIBUTED_DEBUG environment variable is to add a call to c10::ShowLogInfoToStderr(), defined here:

C10_API void ShowLogInfoToStderr();

to, say, this spot:

if (parseDistDebugLevel() == DistributedDebugLevel::DETAIL) {
  LOG(INFO) << *this;
}

@ShawnZhong (Contributor) commented:

It seems that you need to set TORCH_CPP_LOG_LEVEL=INFO as well; cf. #71746 and #73361.
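
For reference, a minimal sketch of how the __main__ block of the repro above would look with this workaround applied (assuming a PyTorch build new enough to include #71746; the 1.10.1 binaries used above predate it):

if __name__ == "__main__":
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "29501"
    # Both variables are needed: TORCH_CPP_LOG_LEVEL raises the C++ log threshold to INFO,
    # and TORCH_DISTRIBUTED_DEBUG selects how much c10d/DDP detail is logged at that level.
    os.environ["TORCH_CPP_LOG_LEVEL"] = "INFO"
    os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL"
    mp.spawn(worker, nprocs=2, args=())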

rohan-varma added the better-engineering label and removed the triage review label on May 3, 2022
@rohan-varma (Member) commented:

@cbalioglu has fixed this with #71746

pytorchmergebot pushed a commit that referenced this issue May 3, 2022
Fixes #70667

`TORCH_CPP_LOG_LEVEL=INFO` is needed for `TORCH_DISTRIBUTED_DEBUG` to be effective.

For reference, #71746 introduced the environment variable `TORCH_CPP_LOG_LEVEL` and #73361 documented it.

Pull Request resolved: #76625
Approved by: https://github.com/rohan-varma
rohan-varma added a commit that referenced this issue May 4, 2022
Fixes #70667

`TORCH_CPP_LOG_LEVEL=INFO` is needed for `TORCH_DISTRIBUTED_DEBUG` to be effective.

For reference, #71746 introduced the environment variable `TORCH_CPP_LOG_LEVEL` and #73361 documented it.
Approved by: https://github.com/rohan-varma

[ghstack-poisoned]
facebook-github-bot pushed a commit that referenced this issue May 4, 2022
Summary:
Fixes #70667

`TORCH_CPP_LOG_LEVEL=INFO` is needed for `TORCH_DISTRIBUTED_DEBUG` to be effective.

For reference, #71746 introduced the environment variable `TORCH_CPP_LOG_LEVEL` and #73361 documented it.

Pull Request resolved: #76625
Approved by: https://github.com/rohan-varma

Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/9c902f4749a7b023cb5a6f97b943ba65796418c3

Reviewed By: malfet

Differential Revision: D36134083

fbshipit-source-id: 15fb58706412d3af029d5544654c70cb85670d6f