
[ROCm] enable cupy in order to enable cudagraph mode for AMD GPUs #3123

Merged: 4 commits merged into vllm-project:main on Mar 5, 2024

Conversation

@hongxiayang (Contributor) commented Feb 29, 2024:

This pull request enables cupy for the ROCm backend so that the throughput benchmarking script can run successfully in cudagraph (hipgraph) mode.

[Reason]:
vLLM currently needs cupy in order to run cudagraph mode successfully, as mentioned in this comment:

> We use CuPy all-reduce instead of torch.distributed.all_reduce when capturing CUDA graphs, because torch.distributed.all_reduce causes errors when capturing CUDA graphs.
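
To illustrate the pattern being described, here is a conceptual sketch only; the names all_reduce_graph_safe, capturing_graph, and cupy_nccl_all_reduce are illustrative and not vLLM's actual internals. Collectives issued while a CUDA/HIP graph is being captured are routed through a CuPy NCCL communicator, while everything else keeps using torch.distributed:

# Conceptual sketch only: all_reduce_graph_safe, capturing_graph, and
# cupy_nccl_all_reduce are illustrative names, not vLLM's actual internals.
import torch
import torch.distributed as dist

def all_reduce_graph_safe(tensor: torch.Tensor,
                          capturing_graph: bool,
                          cupy_nccl_all_reduce=None) -> torch.Tensor:
    if capturing_graph and cupy_nccl_all_reduce is not None:
        # torch.distributed.all_reduce errors out during CUDA/HIP graph
        # capture, so route the collective through a CuPy NCCL communicator,
        # which can be recorded into the graph.
        cupy_nccl_all_reduce(tensor)
    else:
        dist.all_reduce(tensor)
    return tensor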

It also allowed multi-GPU (mGPU) throughput benchmarking to run successfully on the gfx1100 and gfx90a architectures in cudagraph mode.

(Note: we will upgrade the base Docker image to the latest version once it is public; that image is needed to run cudagraph mode backed by cupy successfully. For now, eager mode should still be used for tp > 1.)
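
As a usage sketch (assuming the standard vllm.LLM constructor arguments, which are not part of this PR): cudagraph mode is the default, and eager mode is opted into with enforce_eager.

# Usage sketch; assumes the standard vllm.LLM constructor arguments.
from vllm import LLM, SamplingParams

# Cudagraph (hipgraph) mode: the default when enforce_eager is False.
llm = LLM(model="facebook/opt-125m", tensor_parallel_size=1, enforce_eager=False)

# Eager mode: what tp > 1 on ROCm should still use until the newer base
# docker image mentioned above is available.
# llm = LLM(model="facebook/opt-125m", tensor_parallel_size=2, enforce_eager=True)

outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=16))
print(outputs[0].outputs[0].text)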

@hongxiayang hongxiayang marked this pull request as ready for review February 29, 2024 20:26
@hongxiayang (Contributor, Author):

cc @lcskrishna who co-authored this change.

@lcskrishna (Contributor) left a comment:

LGTM!

@WoosukKwon (Collaborator):

@hongxiayang Thanks for the PR! Just a quick question: to my understanding, this PR itself doesn't let vLLM use cupy, because of this line?

and not is_hip()):

Did you intend to not fix this line in the PR?

@hongxiayang hongxiayang marked this pull request as draft March 1, 2024 15:33
@hongxiayang hongxiayang changed the title [ROCm] enable cupy for AMD GPUs [ROCm] enable cupy in order to enable cudagraph mode for AMD GPUs Mar 1, 2024
@hongxiayang hongxiayang marked this pull request as ready for review March 1, 2024 22:57
@hongxiayang (Contributor, Author):

> @hongxiayang Thanks for the PR! Just a quick question: to my understanding, this PR itself doesn't let vLLM use cupy, because of this line?
>
> and not is_hip()):
>
> Did you intend to not fix this line in the PR?

This is fixed. Thank you for pointing it out.
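
For readers following the thread, the guard being discussed looks roughly like the following. Only the `not is_hip()` fragment is quoted from the discussion above; the surrounding names and arguments are illustrative.

# Paraphrased sketch; only the `not is_hip()` check comes from the discussion
# above, the surrounding names and arguments are illustrative.
from vllm.utils import is_hip  # True when running on ROCm/HIP

def should_init_cupy_before_fix(world_size: int, cupy_port) -> bool:
    # The old guard skipped the CuPy/NCCL setup entirely on ROCm.
    return world_size > 1 and cupy_port is not None and not is_hip()

def should_init_cupy_after_fix(world_size: int, cupy_port) -> bool:
    # With `not is_hip()` removed, ROCm takes the same CuPy path as CUDA,
    # which is what cudagraph (hipgraph) mode needs.
    return world_size > 1 and cupy_port is not None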

@WoosukKwon (Collaborator) left a comment:

@hongxiayang Thanks for the PR! Happy to see such contributions from AMD 😄 Please check out my question.

&& cd libs \
&& git clone -b hipgraph_enablement --recursive https://github.com/ROCm/cupy.git \
&& cd cupy \
&& pip install mpi4py-mpich \
Collaborator:

Do we need MPI for CuPy? For NVIDIA GPUs, we use TCP store instead of MPI.

Collaborator:

vLLM uses a hack that terminates the TCP store used by cupy right after the cupy nccl backend is initialized:

# Stop the TCP store to prevent the deadlock issues at termination time.
# FIXME(woosuk): This is hacky. Find a more robust solution.
if rank == 0 and hasattr(_NCCL_BACKEND, "_store"):
    _NCCL_BACKEND._store.stop()

I did this because I found that otherwise the worker processes hang when they are terminated. If ROCm cupy uses MPI, then vLLM might need a similar hack to prevent deadlocks at termination time.

Contributor:

Hi @WoosukKwon, we prefer to use MPI for ROCm CuPy. Is there any specific reason to choose the TCP store instead of MPI on the vLLM side?

Contributor Author:

> vLLM uses a hack that terminates the TCP store used by cupy right after the cupy nccl backend is initialized:
>
> # Stop the TCP store to prevent the deadlock issues at termination time.
> # FIXME(woosuk): This is hacky. Find a more robust solution.
> if rank == 0 and hasattr(_NCCL_BACKEND, "_store"):
>     _NCCL_BACKEND._store.stop()
>
> I did this because I found that otherwise the worker processes hang when they are terminated. If ROCm cupy uses MPI, then vLLM might need a similar hack to prevent deadlocks at termination time.

@WoosukKwon (1) Can you give more context about this deadlock issue at process termination? We will need to reproduce it to determine whether we need the "stop" hack, and to verify the patch afterwards if it is needed. (2) Why is the TCP store used instead of MPI? Did you observe any performance issue with MPI? As @lcskrishna mentioned, we have done more testing on the MPI path.

Collaborator:

Hi @lcskrishna @hongxiayang, we used the TCP store just to reduce dependencies. We can use MPI if you prefer it over the TCP store. Although I haven't tested MPI + CuPy on NVIDIA GPUs, I believe it works.

The deadlock issue is that, when the main process is terminated, it hangs waiting for the other process spawned by the cupy TCP store. The _NCCL_BACKEND._store.stop() hack is to avoid this.

Traceback (most recent call last):
  File "/usr/lib/python3.10/multiprocessing/util.py", line 357, in _exit_function
Process ExceptionAwareProcess-1:
    p.join()
  File "/usr/local/lib/python3.10/dist-packages/cupyx/distributed/_store.py", line 38, in join
    super().join()
  File "/usr/lib/python3.10/multiprocessing/process.py", line 149, in join
    res = self._popen.wait(timeout)
  File "/usr/lib/python3.10/multiprocessing/popen_fork.py", line 43, in wait
    return self.poll(os.WNOHANG if timeout == 0.0 else 0)
  File "/usr/lib/python3.10/multiprocessing/popen_fork.py", line 27, in poll
    pid, sts = os.waitpid(self.pid, flag)
KeyboardInterrupt: 
Traceback (most recent call last):
  File "/usr/lib/python3.10/multiprocessing/process.py", line 314, in _bootstrap
    self.run()
  File "/usr/local/lib/python3.10/dist-packages/cupyx/distributed/_store.py", line 32, in run
    super().run()
  File "/usr/lib/python3.10/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/usr/local/lib/python3.10/dist-packages/cupyx/distributed/_store.py", line 87, in _server_loop
    c_socket, addr = s.accept()
  File "/usr/lib/python3.10/socket.py", line 293, in accept
    fd, addr = self._accept()

I'm not sure whether this also happens for the MPI backend. Could you please test it out and see whether it happens?
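
For reference, the two bootstrap paths being compared map onto CuPy's distributed API roughly as follows. This is a minimal sketch assuming cupyx.distributed.init_process_group and its use_mpi flag; vLLM's own wrapper code is not reproduced here.

# Minimal sketch of the two CuPy NCCL bootstrap paths under discussion.
# Assumes cupyx.distributed.init_process_group and its use_mpi flag;
# vLLM's actual wrapper code is not reproduced here.
import cupyx.distributed

def init_cupy_nccl(world_size: int, rank: int, use_mpi: bool):
    if use_mpi:
        # MPI path: rank and bootstrap info come from the MPI runtime
        # (e.g. the mpi4py-mpich package installed in the ROCm docker image),
        # so no TCP store process is spawned and there is nothing extra to
        # stop at termination time.
        return cupyx.distributed.init_process_group(
            world_size, rank, backend="nccl", use_mpi=True)
    # TCP store path: rank 0 spawns a small store process, which is what the
    # _NCCL_BACKEND._store.stop() hack shuts down right after initialization.
    return cupyx.distributed.init_process_group(
        world_size, rank, backend="nccl", use_mpi=False)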

Contributor Author:

> Hi @lcskrishna @hongxiayang, we used the TCP store just to reduce dependencies. We can use MPI if you prefer it over the TCP store. Although I haven't tested MPI + CuPy on NVIDIA GPUs, I believe it works.
>
> The deadlock issue is that, when the main process is terminated, it hangs waiting for the other process spawned by the cupy TCP store. The _NCCL_BACKEND._store.stop() hack is to avoid this.
>
> I'm not sure whether this also happens for the MPI backend. Could you please test it out and see whether it happens?

@WoosukKwon Quick question for verification: regarding "when the main process is terminated", do you mean it was killed manually before the throughput benchmarking script completed?

Collaborator:

@hongxiayang Not really. Without the two-line hack on cupy TCP store, the process hangs when it normally terminates (e.g., after running python examples/llm_engine_example.py -tp 2).

Contributor Author:

> @hongxiayang Not really. Without the two-line hack on cupy TCP store, the process hangs when it normally terminates (e.g., after running python examples/llm_engine_example.py -tp 2).

@WoosukKwon It seems that we do not need to do anything for this situation.
(1) I tested examples/llm_engine_example.py -tp 2; it completed fine without deadlock or hang, and without any patch to call mpi_comm.Abort(). (2) I also tested the throughput benchmarking script and pressed Ctrl-C in the middle of the run; the script stopped cleanly. (3) I discussed this with @lcskrishna earlier, and he also thought that the MPI path might not need any additional hack.

Collaborator:

Got it. Thanks for testing!

@WoosukKwon (Collaborator) left a comment:

LGTM! Thanks for submitting the PR!

@WoosukKwon WoosukKwon merged commit 05af6da into vllm-project:main Mar 5, 2024
19 checks passed
@hongxiayang (Contributor, Author):

> LGTM! Thanks for submitting the PR!

Thank you for your review!

dtransposed pushed a commit to afeldman-nm/vllm that referenced this pull request Mar 26, 2024