
Support for compute capability <7.0 #963

Closed
andersRuge opened this issue Sep 6, 2023 · 22 comments · May be fixed by #4409

Comments

@andersRuge

Hi,

How tightly coupled is the requirement for compute capability 7.0 or higher? Is it possible to disable some features and run on, e.g., 6.0, like a P100?

Maybe this is totally unfeasible, but I am limited in my GPU options.

@WoosukKwon
Collaborator

Hi @andersRuge, I believe there's no technical reason for the restriction. We just thought that Pascal GPUs were not very popular these days, and we didn't have Pascal GPUs to test vLLM on. vLLM may work without any modification: just try installing it from source.

git clone https://github.com/vllm-project/vllm.git
cd vllm
pip install -e .
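
To double-check what compute capability your GPUs actually report before building, a quick PyTorch one-liner (a minimal sketch; it assumes torch is already installed) is:

python3 -c "import torch; print([torch.cuda.get_device_capability(i) for i in range(torch.cuda.device_count())])"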

@klebster2

klebster2 commented Dec 19, 2023

Hi there,

I can confirm that it's working with Pascal architecture (Quadro P2000, CUDA version 12.3) when built from source.

Tested on a Dell Precision 5530 Laptop (with Python 3.9.18)
VLLM git hash: 671af2b
VLLM version: 0.2.6

git clone https://github.com/vllm-project/vllm.git
cd vllm
# The setup doesn't allow compute capability below 7.0 (lines 149-151 enforce this by raising an error for older GPUs)
mv setup.py _setup.py
# Use awk to recreate the file without that Python 'if' block
awk '!(NR == 151 || NR == 150 || NR == 149)' ./_setup.py  > ./setup.py
pip install -e .

Below are the lines removed from setup.py, indicated with the '# REMOVE' comment at the end of each line.

# First, check the TORCH_CUDA_ARCH_LIST environment variable.
compute_capabilities = get_torch_arch_list()
if _is_cuda() and not compute_capabilities:
    # If TORCH_CUDA_ARCH_LIST is not defined or empty, target all available
    # GPUs on the current machine.
    device_count = torch.cuda.device_count()
    for i in range(device_count):
        major, minor = torch.cuda.get_device_capability(i)
        if major < 7:  # REMOVE
            raise RuntimeError(  # REMOVE
                "GPUs with compute capability below 7.0 are not supported.")  # REMOVE
        compute_capabilities.add(f"{major}.{minor}")
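
If the line numbers have shifted in a newer checkout, a pattern-based removal is an alternative (a sketch only; it assumes GNU sed and that the check still reads 'major < 7'):

# Delete the 'if major < 7' line plus the two lines of the RuntimeError it raises
sed -i '/major < 7/,+2d' setup.py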

Image of vLLM working on the machine mentioned:

[Screenshot: vLLM running on the Pascal Quadro P2000]

@jasonacox
Contributor

Thanks for this tip @klebster2 ! That is exactly what I needed.

I was able to get vLLM to work with the current version (v0.2.7 / 220a476) in a Docker container. My test rig is Ubuntu 22.04 with CUDA 12.1; I started with a GTX 1060 and then tested on 4 x P100s for a larger model.

Based on your notes, here is my How-to:

  1. Pull (git clone) the latest vllm, then add the 6.x compute capabilities to NVIDIA_SUPPORTED_ARCHS in setup.py - patch diff:
--- _setup.py	2024-01-27 18:44:45.509406538 +0000
+++ setup.py	2024-01-28 00:02:23.581639719 +0000
@@ -18,7 +18,7 @@
 MAIN_CUDA_VERSION = "12.1"
 
 # Supported NVIDIA GPU architectures.
-NVIDIA_SUPPORTED_ARCHS = {"7.0", "7.5", "8.0", "8.6", "8.9", "9.0"}
+NVIDIA_SUPPORTED_ARCHS = {"6.0", "6.1", "6.2", "7.0", "7.5", "8.0", "8.6", "8.9", "9.0"}
 ROCM_SUPPORTED_ARCHS = {"gfx90a", "gfx908", "gfx906", "gfx1030", "gfx1100"}
 # SUPPORTED_ARCHS = NVIDIA_SUPPORTED_ARCHS.union(ROCM_SUPPORTED_ARCHS)
 
@@ -184,9 +184,9 @@
     device_count = torch.cuda.device_count()
     for i in range(device_count):
         major, minor = torch.cuda.get_device_capability(i)
-        if major < 7:
+        if major < 6:
             raise RuntimeError(
-                "GPUs with compute capability below 7.0 are not supported.")
+                "GPUs with compute capability below 6.0 are not supported.")
         compute_capabilities.add(f"{major}.{minor}")
 
 ext_modules = []
  2. Create a Dockerfile (build from the vllm directory and add entrypoint.sh; a sketch of a possible entrypoint.sh follows after step 3):
FROM nvidia/cuda:12.1.0-devel-ubuntu22.04
RUN apt-get update -y \
     && apt-get install -y python3-pip
WORKDIR /app
COPY . .
RUN python3 -m pip install -e .
EXPOSE 8001
COPY entrypoint.sh /usr/local/bin/
CMD [ "entrypoint.sh" ]
  3. Run with --dtype float (since bfloat16 is not available on the Pascal architecture), add --shm-size, and store the downloaded model in a persistent volume (change /path/to/models to where you want to keep the HF models):
nvidia-docker run -d -p 8001:8001 --gpus=all --shm-size=10.24gb \
  -e MODEL=mistralai/Mistral-7B-Instruct-v0.1 \
  -e PORT=8001 \
  -e HF_HOME=/app/models \
  -e NUM_GPU=4 \
  -e EXTRA_ARGS="--dtype float --max-model-len 20000" \
  -v /path/to/models:/app/models \
  --name vllm \
  vllm 
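
The entrypoint.sh referenced in the Dockerfile is not shown above; a minimal sketch of what it might contain, assuming vLLM's OpenAI-compatible server and the MODEL/PORT/NUM_GPU/EXTRA_ARGS variables passed in via docker run, is:

#!/bin/bash
# Hypothetical entrypoint.sh: launch vLLM's OpenAI-compatible API server
# using the environment variables supplied by docker run.
python3 -m vllm.entrypoints.openai.api_server \
  --host 0.0.0.0 \
  --port "${PORT:-8001}" \
  --model "$MODEL" \
  --tensor-parallel-size "${NUM_GPU:-1}" \
  $EXTRA_ARGS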

Additional details here: https://github.com/jasonacox/TinyLLM/tree/main/vllm#running-vllm-on-pascal

INFO 01-27 23:52:57 llm_engine.py:871] Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 41.7 tokens/s, Running: 1 reqs, Swapped: 0 reqs, Pending: 0 reqs, GPU KV cache usage: 0.4%, CPU KV cache usage: 0.0%
INFO 01-27 23:53:02 llm_engine.py:871] Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 41.1 tokens/s, Running: 1 reqs, Swapped: 0 reqs, Pending: 0 reqs, GPU KV cache usage: 0.8%, CPU KV cache usage: 0.0%

@klebster2

Great to hear it helped!

@Nero10578

I was able to get vLLM to work with the current version (v0.2.7 / 220a476) in a Docker container. [...]

What's the max tokens/s you can get on 4x P100s?

@jasonacox
Contributor

Running a simple benchmark.py I get:

Min TPS: 27.6, Max TPS: 40.1, Avg TPS: 33.9

@Nero10578

Running a simple benchmark.py I get:

Min TPS: 27.6, Max TPS: 40.1, Avg TPS: 33.9

Considering vLLM is a batched inference server, wouldn't you still be able to get more tokens/s by running multiple copies of the benchmark script? Unless it's already maxing out the P100s? Honestly, 34 TPS is kinda low; I was hoping it would be a bit higher so it could be viable as a cheap inference GPU...

@jasonacox
Contributor

jasonacox commented Feb 1, 2024

You are right, I'm only running one thread with this test. I ran again with 3 benchmarks running simultaneously and all 3 stayed above 30 TPS (so that could be considered 90 TPS?), but I don't know where the ceiling would be. If anyone has a good test script that pushes batching, I would love to try it (a rough sketch of one approach follows below).

UPDATE - I ran a multi-threaded benchmark with 10 threads and got 208 tokens/s. Each concurrent thread was seeing about 21 TPS.

Average TPS across all 10 threads: 208.3 - Individual Threads: Min TPS: 20.8, Max TPS: 20.8
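
As a starting point, here is a rough sketch of a concurrent benchmark against the OpenAI-compatible endpoint from the Docker setup above (the URL, model name, prompt, and thread count are assumptions; it only measures completion-token throughput per request):

import time
import threading
import requests  # assumes the requests package is installed

URL = "http://localhost:8001/v1/completions"      # endpoint from the Docker setup above
MODEL = "mistralai/Mistral-7B-Instruct-v0.1"      # whatever model the server was started with
THREADS = 10

def worker(results, idx):
    payload = {"model": MODEL, "prompt": "Write a short story about a robot.", "max_tokens": 256}
    start = time.time()
    resp = requests.post(URL, json=payload, timeout=600).json()
    tokens = resp["usage"]["completion_tokens"]
    results[idx] = tokens / (time.time() - start)

results = [0.0] * THREADS
threads = [threading.Thread(target=worker, args=(results, i)) for i in range(THREADS)]
for t in threads:
    t.start()
for t in threads:
    t.join()
print(f"Aggregate TPS: {sum(results):.1f} - Min TPS: {min(results):.1f}, Max TPS: {max(results):.1f}")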

@Nero10578

UPDATE - I ran a multi-threaded benchmark with 10 threads and got 208 tokens/s. Each concurrent thread was seeing about 21 TPS.

Average TPS across all 10 threads: 208.3 - Individual Threads: Min TPS: 20.8, Max TPS: 20.8

Those are decent speeds out of old cards, actually. Did you have time to test larger models on it too?

@jasonacox
Contributor

I would be happy to, any suggestions?

On Mistral 7B Instruct, I'm running the full 32k context using float16 (bfloat16 is not available on Pascal?), and it is filling up most of the VRAM on all four P100s (each with 16 GB). Now for the other bit... there are 7 GPUs in this system, but vLLM will only split the model across 4 (to get an even split of the 32 layers, I assume). Of course, I'm not wasting those GPUs; I'm using the other 3 for text2vec transformers and smaller models. But it would be fun to test a model on all 7. 😀

@jasonacox
Contributor

Would anyone here be willing to review this PR? #2635

@Nero10578

What is the minimum VRAM required for running 7B models on the P100? If 2 cards are enough, I might try to get 2x P100 to experiment with and try that PR. Also, if you could test larger 34B models on more than 4 cards, that would be awesome too, since I can't even run them with vLLM on 2x 3090s.

@jasonacox
Contributor

I would only go with P100s if they are cheap. They only have 16 GB of VRAM and the Pascal architecture is at the bottom edge of CUDA support. The 7B model with a 32k context (--max-model-len) needs ~40 GB of VRAM on the 4 x P100 setup I have, but with a smaller context it may fit within the 32 GB limit of having two. I can try with a 4k context (--max-model-len 4096) to see what that would do. For your 2 x 3090 setup with the 34B model, try turning down the context to see if that helps it fit in the 48 GB limit.
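
For a rough sense of where that VRAM goes, here is a back-of-the-envelope estimate (a sketch only; it assumes Mistral-7B's published config of 32 layers, 8 KV heads, and head dim 128, and ignores activations and framework overhead):

# Rough KV-cache and weight sizing for a Mistral-7B-class model
layers, kv_heads, head_dim = 32, 8, 128
bytes_per_elem = 2              # fp16; double for --dtype float (fp32)
ctx = 32768                     # --max-model-len

kv_per_token = 2 * layers * kv_heads * head_dim * bytes_per_elem    # K and V
print(f"KV cache for one {ctx}-token sequence: ~{kv_per_token * ctx / 2**30:.1f} GiB")
print(f"Model weights (~7B params): ~{7e9 * bytes_per_elem / 2**30:.1f} GiB")
# vLLM also pre-allocates KV-cache space up to --gpu-memory-utilization, so
# observed usage across the 4 x P100s will sit well above this single-sequence estimate.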

@Nero10578

Nero10578 commented Mar 13, 2024

I would only go with P100s if they are cheap. They only have 16 GB of VRAM and the Pascal architecture is at the bottom edge of CUDA support. [...]

So I have experimented with vLLM some more, and I can run 70B AWQ 4-bit models on my 2x 3090 with up to --max-model-len 16384 and --gpu-memory-utilization 0.98. Have you tried AWQ 4-bit models? I bet you can fit much more on the P100s, since the AWQ 4-bit models use so much less VRAM than the full FP16 models do.
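
For anyone trying the same, a quantized model can be served along these lines (a sketch; the model name is just an example from this thread, combined with the flags mentioned above and vLLM's AWQ quantization switch):

python3 -m vllm.entrypoints.openai.api_server \
  --model TheBloke/dolphin-2.2-70B-AWQ \
  --quantization awq \
  --dtype float16 \
  --tensor-parallel-size 2 \
  --max-model-len 16384 \
  --gpu-memory-utilization 0.98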

In terms of performance I am getting this:

2x RTX 3090 + Intel Xeon E5 2679 v4 3.2GHz 20-cores + 256GB DDR4 2400MHz RDIMM

Mistral-Dolphin-2.6-7B
Completed 29 prompts and produced 28350 tokens in 39.061 seconds.
Average TPS across all 20 threads: 725.8 - Individual Threads: Min TPS: 35.9, Max TPS: 36.4

Mythalion-13B-AWQ
Completed 29 prompts and produced 306291 tokens in 857.697 seconds.
Average TPS across all 40 threads: 357.1 - Individual Threads: Min TPS: 8.1, Max TPS: 9.2

dolphin-2.2-70B-AWQ
Completed 29 prompts and produced 138811 tokens in 734.203 seconds.
Average TPS across all 20 threads: 189.1 - Individual Threads: Min TPS: 9.4, Max TPS: 9.5

Senku-70B-AWQ 16384ctx
Completed 29 prompts and produced 106804 tokens in 698.438 seconds.
Average TPS across all 20 threads: 152.9 - Individual Threads: Min TPS: 6.4, Max TPS: 10.0

RTX 3060 + Intel i3 7350K 5.1GHz 2-cores + 32GB DDR4 3200MHz

Mistral-Dolphin-2.6-7B-AWQ
Completed 29 prompts and produced 123820 tokens in 237.819 seconds.
Average TPS across all 40 threads: 520.6 - Individual Threads: Min TPS: 13.0, Max TPS: 13.0

Mythalion-13B-AWQ
Completed 29 prompts and produced 153243 tokens in 962.322 seconds.
Average TPS across all 20 threads: 159.2 - Individual Threads: Min TPS: 7.9, Max TPS: 8.0

It seems to me like the limitation becomes the single-core performance of my CPU when running 7B on my 3090s, since it's not much faster than my 3060 machine, which has a CPU with far fewer cores but a much higher clock speed.

@jasonacox
Contributor

Thanks @Nero10578 ! I'll give it a try.

@cduk

cduk commented Apr 16, 2024

What is the minimum VRAM required for running 7B models on the P100? If 2 cards are enough I might try and get 2x P100 to experiment with and try that PR. Also if you could test larger 34b models on more than 4 cards that would be awesome too since I can't even run it on vllm on 2x 3090s.

Just to add, I've tested vLLM with a P100 and it works very well. If you want to use only a single card, you can limit the context size or, better still, use a quantized model, which will also run very fast as a bonus.

@jasonacox
Contributor

Just to add, I've tested vLLM with P100

I believe the latest release added optional support without needing to patch. Did you need to do anything special to get it to work?
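
For anyone building from source in the meantime, one way to target Pascal explicitly is to pin the architecture list before installing (a sketch only; depending on the setup.py version, the list may still be validated against NVIDIA_SUPPORTED_ARCHS, in which case the patches above are still needed):

export TORCH_CUDA_ARCH_LIST="6.0 6.1 7.0 7.5 8.0 8.6 8.9 9.0+PTX"
pip install -e .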

@Nero10578

Just to add, I've tested vLLM with P100

I believe the latest release added optional support without needing to patch. Did you need to do anything special to get it to work?

Oh awesome, I might get some of the P100s then. I was also asking the Aphrodite Engine devs whether the P100 would work on it too, since Aphrodite is an awesome fork of vLLM that supports more quantizations.

@cduk

cduk commented Apr 17, 2024

Just to add, I've tested vLLM with P100

I believe the latest release added optional support without needing to patch. Did you need to do anything special to get it to work?

Yes, I had to patch the source code. I can send a pull request.

@lee-b

lee-b commented Apr 20, 2024

Please merge Pascal support. Many of us are running P40 / P100 rigs (some even built recently for the purpose) because they're a very good VRAM/$ deal given eBay prices and the limited number of PCIe slots available on consumer rigs.

@jasonacox
Contributor

Yes, I had to patch the source code. I can send a pull request.

Thanks! Or post it here?

@cduk

cduk commented Apr 23, 2024

I posted the pull request: #4290
