[ROCm] add support to ROCm 6.0 and MI300 #2274

Merged · 17 commits · Jan 26, 2024
Changes from 7 commits
36 changes: 31 additions & 5 deletions Dockerfile.rocm
@@ -1,4 +1,24 @@
FROM rocm/pytorch:rocm5.7_ubuntu22.04_py3.10_pytorch_2.0.1
# default base image
ARG BASE_IMAGE="rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1"

FROM $BASE_IMAGE

ARG BASE_IMAGE="rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1"
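# Note: BASE_IMAGE must be re-declared after FROM because ARG values defined before FROM are not visible to later instructions.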

RUN echo "Base image is $BASE_IMAGE"

# BASE_IMAGE for ROCm_5.7: "rocm/pytorch:rocm5.7_ubuntu22.04_py3.10_pytorch_2.0.1"
# BASE_IMAGE for ROCm_6.0: "rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1"

# This does not work on all ROCm versions
RUN LLVM_GFX_ARCH=$(/opt/rocm/llvm/bin/amdgpu-offload-arch) && \
echo "LLVM_GFX_ARCH is $LLVM_GFX_ARCH"

ARG FA_GFX_ARCHS="gfx90a;gfx942"
RUN echo "FA_GFX_ARCHS is $FA_GFX_ARCHS"

ARG FA_BRANCH="3d2b6f5"
RUN echo "FA_BRANCH is $FA_BRANCH"

# Install some basic utilities
RUN apt-get update && apt-get install python3 python3-pip -y
@@ -37,17 +57,23 @@ RUN mkdir libs \
&& cd libs \
&& git clone https://github.com/ROCmSoftwarePlatform/flash-attention.git \
&& cd flash-attention \
&& git checkout 3d2b6f5 \
&& git checkout ${FA_BRANCH} \
&& git submodule update --init \
&& export GPU_ARCHS=$(/opt/rocm/llvm/bin/amdgpu-offload-arch) \
&& patch /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/utils/hipify/hipify_python.py hipify_patch.patch \
&& export GPU_ARCHS=${FA_GFX_ARCHS} \
&& if [ "$BASE_IMAGE" = "rocm/pytorch:rocm5.7_ubuntu22.04_py3.10_pytorch_2.0.1" ]; then \
patch /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/utils/hipify/hipify_python.py hipify_patch.patch; fi \
&& python3 setup.py install \
&& cd ..

COPY ./ /app/vllm

RUN python3 -m pip install --upgrade pip
RUN pip install xformers==0.0.23 --no-deps
RUN python3 -m pip install xformers==0.0.23 --no-deps

# Work around a broken numpy 1.20.3 installation in the ROCm 6.0 base image: its dist-info directory has no METADATA, only an extra LICENSES_bundled.txt.
# Remove it manually so that the numpy upgrade in later steps can proceed.
RUN if [ "$BASE_IMAGE" = "rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1" ]; then \
rm -rf /opt/conda/envs/py_3.9/lib/python3.9/site-packages/numpy-1.20.3.dist-info/; fi

RUN cd /app \
&& cd vllm \
3 changes: 2 additions & 1 deletion README.md
@@ -17,7 +17,8 @@ Easy, fast, and cheap LLM serving for everyone
---

*Latest News* 🔥
- [2023/12] Added ROCm support to vLLM.
- [2024/01] Added ROCm 6.0 support to vLLM.
- [2023/12] Added ROCm 5.7 support to vLLM.
- [2023/10] We hosted [the first vLLM meetup](https://lu.ma/first-vllm-meetup) in SF! Please find the meetup slides [here](https://docs.google.com/presentation/d/1QL-XPFXiFpDBh86DbEegFXBXFXjix4v032GhShbKf3s/edit?usp=sharing).
- [2023/09] We created our [Discord server](https://discord.gg/jz7wjKhh6g)! Join us to discuss vLLM and LLM serving! We will also post the latest announcements and updates there.
- [2023/09] We released our [PagedAttention paper](https://arxiv.org/abs/2309.06180) on arXiv!
30 changes: 27 additions & 3 deletions docs/source/getting_started/amd-installation.rst
@@ -11,10 +11,10 @@ Requirements
------------

* OS: Linux
* Python: 3.8 -- 3.11 (Verified on 3.10)
* GPU: MI200s
* Python: 3.8 -- 3.11
* GPU: MI200s (gfx90a), MI300 (gfx942)
* PyTorch 2.0.1 / 2.1.1 / 2.2
* ROCm 5.7
* ROCm 5.7 (verified on Python 3.10) or ROCm 6.0 (verified on Python 3.9)

Installation options:

@@ -27,6 +27,8 @@ Installation options:
(Recommended) Option 1: Quick start with vLLM pre-installed in Docker Image
---------------------------------------------------------------------------

This option is for ROCm 5.7 only:

.. code-block:: console

$ docker pull embeddedllminfo/vllm-rocm:vllm-v0.2.4
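
You can then launch a container from the pulled image. A typical invocation is sketched below; the device and IPC flags follow standard ROCm container usage, and `<path/to/model>` is a placeholder for an optional model mount:

.. code-block:: console

$ docker run -it --network=host --ipc=host --group-add=video \
--device /dev/kfd --device /dev/dri \
-v <path/to/model>:/app/model \
embeddedllminfo/vllm-rocm:vllm-v0.2.4 bash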
@@ -95,6 +97,23 @@ You can build and install vLLM from source:

Build a Docker image from `Dockerfile.rocm`, and launch a Docker container.

`Dockerfile.rocm` is designed to support both ROCm 5.7 and ROCm 6.0 and later. It provides the flexibility to customize how the Docker image is built using the following arguments:

* `BASE_IMAGE`: specifies the base image used when running `docker build`, specifically the PyTorch-on-ROCm base image. We have tested ROCm 5.7 and ROCm 6.0. The default is `rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1`.
* `FA_GFX_ARCHS`: specifies the GFX architectures used to build flash-attention, for example `gfx90a;gfx942` for MI200 and MI300. The default is `gfx90a;gfx942`.
* `FA_BRANCH`: specifies the branch of `ROCmSoftwarePlatform's flash-attention repo <https://github.com/ROCmSoftwarePlatform/flash-attention>`_ used to build flash-attention. The default is `3d2b6f5`.

Their values can be passed in when running `docker build` with `--build-arg` options.

For example, to build a Docker image for vLLM on ROCm 5.7, you can run:

.. code-block:: console

$ docker build --build-arg BASE_IMAGE="rocm/pytorch:rocm5.7_ubuntu22.04_py3.10_pytorch_2.0.1" \
-f Dockerfile.rocm -t vllm-rocm .

To build vLLM on ROCm 6.0, you can use the defaults:

.. code-block:: console

$ docker build -f Dockerfile.rocm -t vllm-rocm .
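
The other build arguments can be overridden the same way. For instance, here is a sketch that restricts the flash-attention build to MI300 (gfx942) while keeping the default base image; the values shown are illustrative:

.. code-block:: console

$ docker build --build-arg FA_GFX_ARCHS="gfx942" \
--build-arg FA_BRANCH="3d2b6f5" \
-f Dockerfile.rocm -t vllm-rocm .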
@@ -142,3 +161,8 @@ Alternatively, if you plan to install vLLM-ROCm on a local machine or start from
$ cd vllm
$ pip install -U -r requirements-rocm.txt
$ python setup.py install # This may take 5-10 minutes.

.. note::

- You may need to turn on the `--enforce-eager` flag if you experience a process hang when running the `run_benchmark.py` script to test your installation.
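
If you are validating the installation through the Python API instead, the same behavior can be requested when constructing the engine. A minimal sketch (the model name is only an illustration):

.. code-block:: python

from vllm import LLM, SamplingParams

# enforce_eager=True corresponds to the --enforce-eager command-line flag.
llm = LLM(model="facebook/opt-125m", enforce_eager=True)
outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=16))
print(outputs[0].outputs[0].text)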

3 changes: 3 additions & 0 deletions setup.py
@@ -39,6 +39,9 @@ def _is_cuda() -> bool:
"Cannot find ROCM_HOME. ROCm must be available to build the package."
)
NVCC_FLAGS += ["-DUSE_ROCM"]
# Undefine these HIP macros so that half-precision conversions and operators are available to the kernels.
NVCC_FLAGS += ["-U__HIP_NO_HALF_CONVERSIONS__"]
NVCC_FLAGS += ["-U__HIP_NO_HALF_OPERATORS__"]


if _is_cuda() and CUDA_HOME is None:
raise RuntimeError(
11 changes: 11 additions & 0 deletions vllm/utils.py
@@ -38,8 +38,19 @@ def get_max_shared_memory_bytes(gpu: int = 0) -> int:
"""Returns the maximum shared memory per thread block in bytes."""
# https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html
cudaDevAttrMaxSharedMemoryPerBlockOptin = 97 if not is_hip() else 74

max_shared_mem = cuda_utils.get_device_attribute(
cudaDevAttrMaxSharedMemoryPerBlockOptin, gpu)
if max_shared_mem == 0 and is_hip():
# Attribute 74 sometimes returns 0 on certain ROCm versions with torch 2.0.1.
print(
"ROCm get_max_shared_memory_bytes got 0, trying to use value 97 instead"
)
cudaDevAttrMaxSharedMemoryPerBlockOptin = 97
max_shared_mem = cuda_utils.get_device_attribute(
cudaDevAttrMaxSharedMemoryPerBlockOptin, gpu)
# A value of 0 would make MAX_SEQ_LEN negative and cause test_attention.py to fail.
assert max_shared_mem > 0, "max_shared_mem cannot be zero"
return int(max_shared_mem)


23 changes: 21 additions & 2 deletions vllm/worker/worker.py
@@ -65,8 +65,27 @@ def init_model(self) -> None:

# This env var set by Ray causes exceptions with graph building.
os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)
self.device = torch.device(f"cuda:{self.local_rank}")
torch.cuda.set_device(self.device)

try:
self.device = torch.device(f"cuda:{self.local_rank}")
torch.cuda.set_device(self.device)
except RuntimeError as re:
# On certain ROCm versions, ranks other than 0 hit a RuntimeError when running with the tensor-parallel option.
# For example, with -tp 2, calling torch.cuda.set_device(self.device) for device 1 throws:
#     HIP error: invalid device ordinal
# Debugging showed CUDA_VISIBLE_DEVICES=0,1 while torch.cuda.device_count() was 1 and HIP_VISIBLE_DEVICES was unset.
# The block below works around this so that initialization can continue.
device_count = torch.cuda.device_count()
print(
f"RuntimeError {re} in cuda.set_device {self.device}, device_count={device_count}. "
)
if device_count > 0:
self.device = torch.device("cuda:0")
print(f"Trying to work around by setting the device to {self.device}")
torch.cuda.set_device(self.device)
else:
# No workaround is available; re-raise the original error.
raise

_check_if_gpu_supports_dtype(self.model_config.dtype)
