
[AWS EC2 P3DN, EFA is enabled] Torch RPC tensorpipe/common/ibv.h:172 "": Operation not supported #65022

Open
chaoyanghe opened this issue Sep 14, 2021 · 12 comments
Labels
module: rpc (Related to RPC, distributed autograd, RRef, and distributed optimizer)
module: tensorpipe (Related to Tensorpipe RPC Agent)
oncall: distributed (Add this issue/PR to distributed oncall triage queue)
pipeline parallelism (Issues related to https://pytorch.org/docs/master/pipeline.html)
triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Comments

@chaoyanghe

chaoyanghe commented Sep 14, 2021

🐛 Bug

I got the following error when using sync.Pipe and initializing RPC.

To Reproduce

I have put together a small GitHub repo with a script that reproduces this issue (https://github.com/chaoyanghe/pytorch_bug_reproduce). A full error log is also maintained there.

Expected behavior

Successfully run the Pipe demo.

Environment

I run my source code on an AWS EC2 P3DN GPU server with EFA enabled.

Collecting environment information...
PyTorch version: N/A
Is debug build: N/A
CUDA used to build PyTorch: N/A
ROCM used to build PyTorch: N/A

OS: Amazon Linux 2 (x86_64)
GCC version: (GCC) 7.3.1 20180712 (Red Hat 7.3.1-9)
Clang version: 7.0.1 (Amazon Linux 2 7.0.1-1.amzn2.0.2)
CMake version: version 3.18.2
Libc version: glibc-2.2.5

Python version: 2.7.18 (default, Aug 27 2020, 21:22:52) [GCC 7.3.1 20180712 (Red Hat 7.3.1-9)] (64-bit runtime)
Python platform: Linux-4.14.200-155.322.amzn2.x86_64-x86_64-with-glibc2.2.5
Is CUDA available: N/A
CUDA runtime version: 11.0.221
GPU models and configuration:
GPU 0: Tesla V100-SXM2-32GB
GPU 1: Tesla V100-SXM2-32GB
GPU 2: Tesla V100-SXM2-32GB
GPU 3: Tesla V100-SXM2-32GB
GPU 4: Tesla V100-SXM2-32GB
GPU 5: Tesla V100-SXM2-32GB
GPU 6: Tesla V100-SXM2-32GB
GPU 7: Tesla V100-SXM2-32GB

Nvidia driver version: 450.80.02
cuDNN version: Probably one of the following:
/usr/local/cuda-10.1/targets/x86_64-linux/lib/libcudnn.so.7.6.5
/usr/local/cuda-10.2/targets/x86_64-linux/lib/libcudnn.so.7.6.5
/usr/local/cuda-11.0/targets/x86_64-linux/lib/libcudnn.so.8.0.4
/usr/local/cuda-11.0/targets/x86_64-linux/lib/libcudnn_adv_infer.so.8.0.4
/usr/local/cuda-11.0/targets/x86_64-linux/lib/libcudnn_adv_train.so.8.0.4
/usr/local/cuda-11.0/targets/x86_64-linux/lib/libcudnn_cnn_infer.so.8.0.4
/usr/local/cuda-11.0/targets/x86_64-linux/lib/libcudnn_cnn_train.so.8.0.4
/usr/local/cuda-11.0/targets/x86_64-linux/lib/libcudnn_ops_infer.so.8.0.4
/usr/local/cuda-11.0/targets/x86_64-linux/lib/libcudnn_ops_train.so.8.0.4
HIP runtime version: N/A
MIOpen runtime version: N/A

Versions of relevant libraries:
[pip] numpy==1.16.6
[conda] Could not collect

cc @pietern @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @SciPioneer @H-Huang @cbalioglu @gcramer23 @jjlilley @mrzzd @lw @beauby

@facebook-github-bot facebook-github-bot added the oncall: distributed Add this issue/PR to distributed oncall triage queue label Sep 14, 2021
@chaoyanghe chaoyanghe changed the title Torch RPC tensorpipe/common/ibv.h:172 "": Operation not supported [AWS EC2 P3DN, EFA is enabled] Torch RPC tensorpipe/common/ibv.h:172 "": Operation not supported Sep 14, 2021
@chaoyanghe
Author

After I changed the backend to ProcessGroup, it works. However, I want to use the TensorPipe backend to improve performance.

    def _init_rpc(self):
        # https://github.com/pytorch/pytorch/issues/55615
        # [BC-Breaking][RFC] Retire ProcessGroup Backend for RPC #55615
        str_init_method = "tcp://" + str(self.master_addr) + ":10000"
        logging.info("str_init_method = {}".format(str_init_method))
        options = rpc.ProcessGroupRpcBackendOptions(
            num_send_recv_threads=4, rpc_timeout=0.0, init_method=str_init_method
        )
        rpc.init_rpc(
            "worker:" + str(self.global_rank),
            backend=dist.rpc.BackendType.PROCESS_GROUP,
            rank=self.global_rank,
            world_size=self.world_size,
            rpc_backend_options=options,
        )
        # torch.distributed.rpc.init_rpc('worker', rank=self.global_rank, world_size=self.world_size)
        logging.info("init_rpc finished.")

@pritamdamania87 pritamdamania87 added the module: tensorpipe Related to Tensorpipe RPC Agent label Sep 14, 2021
@pritamdamania87
Contributor

@lw Was wondering if you could help out here; this seems to be coming from TensorPipe's ibv support.

@pritamdamania87 pritamdamania87 added module: rpc Related to RPC, distributed autograd, RRef, and distributed optimizer pipeline parallelism Issues related to https://pytorch.org/docs/master/pipeline.html labels Sep 14, 2021
@chaoyanghe
Author

I also have another question:
#65038

❓ Questions and Help

For the PyTorch pipeline, I think there are two ways to pass a tensor past intermediate layers to a later one (a skip connection). Which way has higher performance? (The imports both snippets assume are sketched after the second example.)

  1. passing tensor layer by layer
    class Layer1(nn.Module):
        def __init__(self, hidden_size):
            super().__init__()
            self.fc1 = nn.Linear(hidden_size, hidden_size).cuda(local_rank)

        def forward(self, input):
            original_input = input
            return self.fc1(input), original_input

    class Layer2(nn.Module):
        def __init__(self, hidden_size):
            super().__init__()
            self.fc2 = nn.Linear(hidden_size, hidden_size).cuda(local_rank)

        def forward(self, input, original_input):
            return self.fc2(input), original_input

    class Layer3(nn.Module):
        def __init__(self, hidden_size):
            super().__init__()
            self.fc3 = nn.Linear(hidden_size, hidden_size).cuda(local_rank)

        def forward(self, input, original_input):
            return self.fc3(input) + original_input

layer1 = Layer1(hidden_size)
layer2 = Layer2(hidden_size)
layer3 = Layer3(hidden_size)
pipeline_model = nn.Sequential(layer1, layer2, layer3)

#  build Pipe (torch.distributed.pipeline.sync.Pipe)
self.pipeline_model = PipeModelWrapper(Pipe(pipeline_model, chunks=4, checkpoint="never"))
  2. using the skip-connection API
    @skippable(stash=['1to3'])
    class Layer1(nn.Module):
        def __init__(self, hidden_size):
            super().__init__()
            self.fc1 = nn.Linear(hidden_size, hidden_size).cuda(local_rank)

        def forward(self, input):
            yield stash('1to3', input)
            return self.fc1(input)

    class Layer2(nn.Module):
        def __init__(self, hidden_size):
            super().__init__()
            self.fc2 = nn.Linear(hidden_size, hidden_size).cuda(local_rank)

        def forward(self, input):
            return self.fc2(input)

    @skippable(pop=['1to3'])
    class Layer3(nn.Module):
        def __init__(self, hidden_size):
            super().__init__()
            self.fc3 = nn.Linear(hidden_size, hidden_size).cuda(local_rank)

        def forward(self, input):
            skip_1to3 = yield pop('1to3')
            return self.fc3(input) + skip_1to3

layer1 = Layer1(hidden_size)
layer2 = Layer2(hidden_size)
layer3 = Layer3(hidden_size)
pipeline_model = nn.Sequential(layer1, layer2, layer3)

#  build Pipe (torch.distributed.pipeline.sync.Pipe)
self.pipeline_model = PipeModelWrapper(Pipe(pipeline_model, chunks=4, checkpoint="never"))
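
For completeness, here is a minimal sketch of the imports both snippets assume (PipeModelWrapper is my own wrapper and is not shown; the skip API path below matches recent PyTorch releases, so adjust it if your version differs):

import torch.nn as nn
from torch.distributed.pipeline.sync import Pipe
from torch.distributed.pipeline.sync.skip import skippable, stash, pop

Note that Pipe also requires RPC to be initialized (even in a single process) before the Pipe wrapper is constructed.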

@pritamdamania87
Contributor

@chaoyanghe Can you try the following workaround for the tensorpipe issue:

options = rpc.TensorPipeRpcBackendOptions(
        init_method="file://{}".format(tmpfile.name),
        _transports=["uv"],
)

@chaoyanghe
Author

@pritamdamania87 Is this what you expect?

    def _init_rpc_with_torchpipe(self):
        # https://github.com/pytorch/pytorch/issues/55615
        # [BC-Breaking][RFC] Retire ProcessGroup Backend for RPC #55615
        str_init_method = "tcp://" + str(self.master_addr) + ":10000"
        logging.info("str_init_method = {}".format(str_init_method))
        options = rpc.TensorPipeRpcBackendOptions(
            num_worker_threads=16, rpc_timeout=20, init_method=str_init_method, _transports=["uv"]
        )
        rpc.init_rpc(
            "worker:" + str(self.global_rank),
            backend=rpc.BackendType.TENSORPIPE,
            rank=self.global_rank,
            world_size=self.world_size,
            rpc_backend_options=options,
        )
        # torch.distributed.rpc.init_rpc('worker', rank=self.global_rank, world_size=self.world_size)
        logging.info("init_rpc finished.")

@chaoyanghe
Author

@pritamdamania87 I tested by just adding _transports=["uv"], and it works now. Thank you! It would be better if the PyTorch APIs handled this automatically.
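
Until that happens, a crude stopgap I could imagine adding to our init code is to fall back to the "uv" transport whenever a verbs device is visible (on P3DN that device is the EFA adapter). A rough sketch, where the /sys/class/infiniband check and the option values are only illustrative:

import glob

from torch.distributed import rpc


def make_rpc_options(init_method):
    # Illustrative heuristic: the EFA adapter registers a verbs device even
    # though it lacks the features TensorPipe's ibv transport needs, so if
    # any device is listed we force the TCP-based "uv" transport instead.
    kwargs = dict(num_worker_threads=16, rpc_timeout=20, init_method=init_method)
    if glob.glob("/sys/class/infiniband/*"):
        kwargs["_transports"] = ["uv"]
    return rpc.TensorPipeRpcBackendOptions(**kwargs)

On a machine with real InfiniBand this check would also disable ibv, so it is only a per-cluster workaround, not a general fix.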

@chaoyanghe
Author

@pritamdamania87 After I integrated this demo into our project, I ran into this issue:

10.0.87.143: 10.0.92.5: [W tensorpipe_agent.cpp:843] RPC agent for worker:17 encountered error when reading incoming request from worker:0: EOF: end of file (this error originated at tensorpipe/transport/uv/connection_impl.cc:132)
10.0.87.143: 10.0.89.236: [W tensorpipe_agent.cpp:843] RPC agent for worker:59 encountered error when reading incoming request from worker:0: EOF: end of file (this error originated at tensorpipe/transport/uv/connection_impl.cc:132)
10.0.87.143: 10.0.95.41: [W tensorpipe_agent.cpp:843] RPC agent for worker:9 encountered error when reading incoming request from worker:0: EOF: end of file (this error originated at tensorpipe/transport/uv/connection_impl.cc:132)
10.0.87.143: 10.0.92.5: [W tensorpipe_agent.cpp:843] RPC agent for worker:18 encountered error when reading incoming request from worker:0: EOF: end of file (this error originated at tensorpipe/transport/uv/connection_impl.cc:132)
10.0.87.143: 10.0.78.113: [W tensorpipe_agent.cpp:843] RPC agent for worker:30 encountered error when reading incoming request from worker:0: EOF: end of file (this error originated at tensorpipe/transport/uv/connection_impl.cc:132)

@chaoyanghe
Author

Another bug:
#65093

@lw
Contributor

lw commented Sep 20, 2021

@chaoyanghe I cannot find the error log in the repo you pointed to.

@lw
Contributor

lw commented Sep 20, 2021

Never mind, I went back in the history and found it. The logs are all garbled because multiple processes wrote to the same file without synchronization, but the error seems to be this:

In operator() at tensorpipe/common/ibv.h:172 "": Operation not supported

This is something I've seen before: on AWS, the EFA card presents itself as capable of carrying InfiniBand traffic, which "tricks" TensorPipe into trying to use it, but it then turns out not to support some of the features TensorPipe needs. We could add more nuanced detection logic that probes for these features earlier, but I haven't gotten to it yet (I need to figure out how to use AWS). For now, the workaround proposed by @pritamdamania87 is the best I can offer.

@pritamdamania87
Contributor

@pritamdamania87 After I integrated this demo into our project, I ran into this issue:

@chaoyanghe Do you have a repro for this issue? @lw It looks like something was still failing even after _transports=["uv"] was set.

@lw
Contributor

lw commented Sep 20, 2021

I see, those EOF errors. In my experience these tend to just be "side effects" of another worker abruptly crashing, hence it may help to search the logs of the mentioned workers for the real root cause.
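
One thing that makes that search easier is giving each rank its own log file instead of letting every process write to a shared one (which is what garbled the logs above). A minimal sketch, where the directory and filename scheme are just placeholders:

import logging
import os


def setup_per_rank_logging(log_dir, global_rank):
    # One log file per worker, so a crash on e.g. worker:0 can be read on its
    # own instead of being interleaved with the output of the other ranks.
    os.makedirs(log_dir, exist_ok=True)
    logging.basicConfig(
        filename=os.path.join(log_dir, "worker_{}.log".format(global_rank)),
        level=logging.INFO,
        format="%(asctime)s rank{} %(levelname)s %(message)s".format(global_rank),
    )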

@rohan-varma rohan-varma added the triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module) label Mar 8, 2022
hyunwoongko pushed a commit to EleutherAI/oslo that referenced this issue Jul 24, 2023
## Title

Fix rpc bug on AWS


## Description

- rpc.TensorPipeRpcBackendOptions returns an error when run on AWS.


![image](https://github.com/EleutherAI/oslo/assets/26476095/4bb98124-1e0a-4d02-b473-cbe3ddaf7610)

related issues
- pytorch/pytorch#65022
- pytorch/tensorpipe#413
- pytorch/pytorch#65093

## Linked Issues

- resolved #00