[Kernel] Full Tensor Parallelism for LoRA Layers #3524

Merged
merged 18 commits into vllm-project:main on Apr 27, 2024

Conversation

FurtherAI
Contributor

@FurtherAI FurtherAI commented Mar 20, 2024

The current multi-LoRA layers run only one of the two LoRA projections with tensor parallelism. Here, I add layers that are fully sharded, i.e. both LoRA projections run with tensor parallelism. This performs better at higher sequence length, rank, tensor parallel size, and so on, with similar performance at smaller scales.
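
For orientation, a minimal sketch (not the actual implementation) of what fully sharded means here for a column-parallel layer, assuming LoRA A is split along the rank dimension and LoRA B along the output dimension:

import torch


def lora_column_parallel_shard(x, lora_a_shard, lora_b_shard, all_gather):
    """Fully sharded LoRA path for a column-parallel base layer (sketch).

    lora_a_shard: this rank's slice of A, shape (hidden, r // tp_size)
    lora_b_shard: this rank's slice of B, shape (r, out_features // tp_size)
    Both LoRA matmuls run on weight shards; only the small rank-r
    intermediate is gathered across ranks.
    """
    buf = x @ lora_a_shard     # (tokens, r // tp_size) on each rank
    buf = all_gather(buf)      # (tokens, r) after gathering the shards
    return buf @ lora_b_shard  # (tokens, out_features // tp_size)


# Single-rank smoke test: with tp_size == 1, all_gather is the identity.
out = lora_column_parallel_shard(torch.randn(4, 16), torch.randn(16, 8),
                                 torch.randn(8, 32), lambda t: t)
assert out.shape == (4, 32)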

One thing that still needs to be added is a way for the user to select the fully sharded versions (or for vLLM to select them based on some heuristic). The tests I wrote also need to be converted to pytest; they should essentially be a copy of the existing LoRA tests with the layers replaced by the fully sharded versions, or added to those tests. I had trouble running the tests myself.

I did not add versions for the sampler, since the improvement there was mixed, or for the vocab parallel embedding, since it did not look like it would benefit much.

ADD TO #1804

@FurtherAI FurtherAI changed the title Full Tensor Parallelism for LoRA Layers [Kernel] Full Tensor Parallelism for LoRA Layers Mar 20, 2024
@Yard1 Yard1 self-requested a review March 20, 2024 17:27
@Yard1
Collaborator

Yard1 commented Mar 20, 2024

@FurtherAI Thank you for your contribution! I will be happy to help you get the tests converted to pytest.

@Yard1
Collaborator

Yard1 commented Mar 25, 2024

@FurtherAI Thanks for the PR. General feedback:

  1. Would it be possible to reduce code duplication between the sharded and non-sharded LoRA layers, e.g. by making the sharded ones subclasses of the non-sharded ones and moving some logic into private methods so it can be called from both? Feel free to do whatever you need to reduce code duplication (see the sketch after this list).
  2. Likewise, it would be great if we could reduce code duplication in C-land (perhaps by modifying/merging some macros)
  3. I took a look at the test. I will need to think a little about how to rewrite it to use pytest.
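
A rough sketch of the subclassing idea (the class and method names here are simplified placeholders, not the actual vLLM classes):

class LoRALayerBase:
    """Setup and bookkeeping shared by the sharded and non-sharded variants."""

    def create_lora_weights(self, max_loras: int, rank: int) -> None:
        # allocation / slicing logic needed by both variants lives here
        ...

    def set_lora(self, index: int, lora_a, lora_b) -> None:
        # weight-loading logic, also shared through the base class
        ...

    def apply_weights(self, x):
        raise NotImplementedError  # each variant supplies its own compute


class FullyShardedLoRALayer(LoRALayerBase):
    """Fully sharded variant: inherits everything except the compute."""

    def apply_weights(self, x):
        # both LoRA matmuls run on weight shards; only this method differs
        ...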

@FurtherAI
Contributor Author

  1. Yeah, it should be possible to move the functions other than apply_weights() into a shared class. apply_weights() is unique, though, and not similar between layers aside from the Merged and QKV variants.
  2. I don't see a good way to make the C code more compact. I'm pretty sure the expand kernel won't allow any feat_in less than 8, because it divides by vector_size, which is 8, and that would launch a grid of size 0. So we need to choose sizes that exercise the shrink kernel but not the expand kernel. Alternatively, it is possible to limit the degree of tensor parallelism based on the rank, but I think that would be awkward.
  3. Is it not as simple as adding the sharded versions of the layers to the create_random_*_layer() functions and then parametrizing the test with something like @pytest.mark.parametrize("shard_level", ["normal", "full"])? See the sketch at the end of this comment. Anyway, thanks, I appreciate that!

I'll try to reduce the duplicate code following what I mentioned and add that by next week.
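
A rough sketch of the parametrization in item 3 (the helper name is illustrative, not the exact test_layers.py API):

import pytest


def create_random_linear_parallel_layer(shard_level: str):
    """Stand-in for the existing helper, extended to build either variant."""
    if shard_level == "full":
        return "fully sharded LoRA layer"    # placeholder object
    return "partially sharded LoRA layer"    # placeholder object


@pytest.mark.parametrize("shard_level", ["normal", "full"])
def test_lora_linear(shard_level):
    layer = create_random_linear_parallel_layer(shard_level)
    # the existing assertions run unchanged against whichever variant is built
    assert layer is not None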

@FurtherAI
Contributor Author

@Yard1 Check out the latest commit. It minimizes the duplicate code, allows the user to select the implementation, and modifies test_layers.py to run the tests on the fully sharded layers (though, as I mentioned, I can't get pytest to work locally, so that part is only a guess).
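
For example, the fully sharded implementation can be selected through the engine arguments, using the fully_sharded_loras flag that also appears in the reproducer further down this thread (the model and sizes here are illustrative):

from vllm import EngineArgs, LLMEngine

engine_args = EngineArgs(model="meta-llama/Llama-2-7b-hf",
                         enable_lora=True,
                         tensor_parallel_size=2,
                         max_lora_rank=8,
                         # opt in to the fully sharded LoRA layers; leaving
                         # this off keeps the original partially sharded path
                         fully_sharded_loras=True)
engine = LLMEngine.from_engine_args(engine_args)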

Collaborator

@Yard1 Yard1 left a comment


Thanks, this looks good! I think it should be straightforward to convert the sharded test to use pytest. I am a bit worried about its complexity - it would be good to add some more comments to explain it and what happens in the Worker class.

Could we also add low-level unit tests for the INST dimensions in the test_punica.py file?

Review comments were left on vllm/lora/layers.py, vllm/lora/fully_sharded_layers.py, and tests/lora/test_layers.py.
@FurtherAI
Contributor Author

I think we should drop my test_sharded.py completely; it was only meant as a way for me to run the layers locally and it's a mess. The fully sharded layers are a drop-in replacement for the normal LoRA layers, which are tested in test_layers.py, and it is easy to test the fully sharded layers there as well, so we should do that.

I have already proposed the updated test_layers.py which handles this, so it just needs a run to double-check that it works. Since I'm having trouble running pytest locally, it'd be great if you could run it.

@Yard1
Collaborator

Yard1 commented Apr 15, 2024

@FurtherAI Ok, that sounds good, let's drop it. Can you merge master into your branch to resolve conflicts?

@FurtherAI
Contributor Author

Yes, I'll merge master in the next commit.

@FurtherAI
Contributor Author

FurtherAI commented Apr 16, 2024

Forgot about this. The error on the last test is an OOM error; I'm not sure if it is related to the changes. I'll merge main, add the tests, and see if it occurs again.

> Could we also add low-level unit tests for the INST dimensions in the test_punica.py file.

@FurtherAI
Contributor Author

FurtherAI commented Apr 20, 2024

Merged main and added the extra Punica tests.
The only errors that seem to be left are from the formatter and a typo from quantization that I just fixed.

The formatter is still a little mad about two things:

  • I import all the layers in vllm/vllm/lora/utils.py so they are detected for layer replacement. There might be a cleaner way to handle this, but what is your preference?
  • The lines in vllm/csrc/punica/bgmv/generator.py which are slightly too long.

@Yard1
Collaborator

Yard1 commented Apr 25, 2024

Thanks @FurtherAI, this is looking good! I think we can merge after the conflicts are resolved, and then we can work on integrating the new layers in a follow-up. Thank you for your patience!

> I import all the layers in vllm/vllm/lora/utils.py so they are detected for layer replacement. There might be a cleaner way to handle this, but what is your preference?

Honestly, let's just define them explicitly in a list or something. This auto-detection is too magical.
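
Something like this (a sketch; the class names are stand-ins for the real layer classes in vllm/lora/layers.py and vllm/lora/fully_sharded_layers.py):

# Stand-ins for the real LoRA layer classes.
class ColumnParallelLinearWithLoRA: ...
class ColumnParallelLinearWithShardedLoRA(ColumnParallelLinearWithLoRA): ...

# An explicit registry that layer replacement iterates over, instead of
# auto-detecting whichever classes happen to be imported into utils.py.
_all_lora_classes = [
    ColumnParallelLinearWithLoRA,
    ColumnParallelLinearWithShardedLoRA,
]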

> The lines in vllm/csrc/punica/bgmv/generator.py which are slightly too long.

Add # noqa: E501 to the end of both lines.
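
That is, in generator.py (the line content here is just a placeholder):

TEMPLATE = "a generated-code template string that unavoidably exceeds the line-length limit enforced by the formatter"  # noqa: E501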

@FurtherAI
Contributor Author

@Yard1, could you take a look at this? I can't figure out what formatting the checker wants for utils.py.

@Yard1
Collaborator

Yard1 commented Apr 27, 2024

@FurtherAI should be good now

@FurtherAI
Contributor Author

Tyvm.

Collaborator

@Yard1 Yard1 left a comment


Thanks! Let's work on a follow-up to use the new layers.

@Yard1 Yard1 merged commit eefeb16 into vllm-project:main Apr 27, 2024
48 checks passed
@exceedzhang

@FurtherAI
When I was running the Qwen1.5-7B LoRA model, I encountered the following error when using multiple GPUs for inference:
[screenshot of the error]
Single-GPU inference works normally; I suspect this may be a bug.

@FurtherAI
Contributor Author

Yes, sorry about that. I made a new PR to fix it. I was able to reproduce and fix your error, so it should be correct now.

robertgshaw2-neuralmagic pushed a commit to neuralmagic/nm-vllm that referenced this pull request May 6, 2024
Co-authored-by: Antoni Baum <antoni.baum@protonmail.com>
z103cb pushed a commit to z103cb/opendatahub_vllm that referenced this pull request May 7, 2024
Co-authored-by: Antoni Baum <antoni.baum@protonmail.com>
@sfc-gh-ybsat

@FurtherAI it seems like there may be a bug with llama3-70b-instruct. I'm getting the following error on startup of vLLM (during engine initialization, e.g. when determining the number of available blocks) when using tensor parallelism 8 with A100 GPUs.
TLDR error:
RuntimeError: x must be contiguous
Exception raised from dispatch_bgmv_low_level at csrc/punica/punica_ops.cu:326

Full error stack:

ERROR 05-17 21:49:30 worker_base.py:145] Error executing method determine_num_available_blocks. This might cause deadlock in distributed execution.
ERROR 05-17 21:49:30 worker_base.py:145] Traceback (most recent call last):
ERROR 05-17 21:49:30 worker_base.py:145]   File "/home/corvo/vllm-project/vllm/worker/worker_base.py", line 137, in execute_method
ERROR 05-17 21:49:30 worker_base.py:145]     return executor(*args, **kwargs)
ERROR 05-17 21:49:30 worker_base.py:145]   File "/home/corvo/.local/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
ERROR 05-17 21:49:30 worker_base.py:145]     return func(*args, **kwargs)
ERROR 05-17 21:49:30 worker_base.py:145]   File "/home/corvo/vllm-project/vllm/worker/worker.py", line 153, in determine_num_available_blocks
ERROR 05-17 21:49:30 worker_base.py:145]     self.model_runner.profile_run()
ERROR 05-17 21:49:30 worker_base.py:145]   File "/home/corvo/.local/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
ERROR 05-17 21:49:30 worker_base.py:145]     return func(*args, **kwargs)
ERROR 05-17 21:49:30 worker_base.py:145]   File "/home/corvo/vllm-project/vllm/worker/model_runner.py", line 892, in profile_run
ERROR 05-17 21:49:30 worker_base.py:145]     self.execute_model(seqs, kv_caches)
ERROR 05-17 21:49:30 worker_base.py:145]   File "/home/corvo/.local/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
ERROR 05-17 21:49:30 worker_base.py:145]     return func(*args, **kwargs)
ERROR 05-17 21:49:30 worker_base.py:145]   File "/home/corvo/vllm-project/vllm/worker/model_runner.py", line 811, in execute_model
ERROR 05-17 21:49:30 worker_base.py:145]     hidden_states = model_executable(**execute_model_kwargs)
ERROR 05-17 21:49:30 worker_base.py:145]   File "/home/corvo/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
ERROR 05-17 21:49:30 worker_base.py:145]     return self._call_impl(*args, **kwargs)
ERROR 05-17 21:49:30 worker_base.py:145]   File "/home/corvo/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
ERROR 05-17 21:49:30 worker_base.py:145]     return forward_call(*args, **kwargs)
ERROR 05-17 21:49:30 worker_base.py:145]   File "/home/corvo/vllm-project/vllm/model_executor/models/llama.py", line 368, in forward
ERROR 05-17 21:49:30 worker_base.py:145]     hidden_states = self.model(input_ids, positions, kv_caches,
ERROR 05-17 21:49:30 worker_base.py:145]   File "/home/corvo/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
ERROR 05-17 21:49:30 worker_base.py:145]     return self._call_impl(*args, **kwargs)
ERROR 05-17 21:49:30 worker_base.py:145]   File "/home/corvo/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
ERROR 05-17 21:49:30 worker_base.py:145]     return forward_call(*args, **kwargs)
ERROR 05-17 21:49:30 worker_base.py:145]   File "/home/corvo/vllm-project/vllm/model_executor/models/llama.py", line 293, in forward
ERROR 05-17 21:49:30 worker_base.py:145]     hidden_states, residual = layer(
ERROR 05-17 21:49:30 worker_base.py:145]   File "/home/corvo/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
ERROR 05-17 21:49:30 worker_base.py:145]     return self._call_impl(*args, **kwargs)
ERROR 05-17 21:49:30 worker_base.py:145]   File "/home/corvo/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
ERROR 05-17 21:49:30 worker_base.py:145]     return forward_call(*args, **kwargs)
ERROR 05-17 21:49:30 worker_base.py:145]   File "/home/corvo/vllm-project/vllm/model_executor/models/llama.py", line 235, in forward
ERROR 05-17 21:49:30 worker_base.py:145]     hidden_states = self.self_attn(
ERROR 05-17 21:49:30 worker_base.py:145]   File "/home/corvo/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
ERROR 05-17 21:49:30 worker_base.py:145]     return self._call_impl(*args, **kwargs)
ERROR 05-17 21:49:30 worker_base.py:145]   File "/home/corvo/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
ERROR 05-17 21:49:30 worker_base.py:145]     return forward_call(*args, **kwargs)
ERROR 05-17 21:49:30 worker_base.py:145]   File "/home/corvo/vllm-project/vllm/model_executor/models/llama.py", line 165, in forward
ERROR 05-17 21:49:30 worker_base.py:145]     qkv, _ = self.qkv_proj(hidden_states)
ERROR 05-17 21:49:30 worker_base.py:145]   File "/home/corvo/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
ERROR 05-17 21:49:30 worker_base.py:145]     return self._call_impl(*args, **kwargs)
ERROR 05-17 21:49:30 worker_base.py:145]   File "/home/corvo/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
ERROR 05-17 21:49:30 worker_base.py:145]     return forward_call(*args, **kwargs)
ERROR 05-17 21:49:30 worker_base.py:145]   File "/home/corvo/vllm-project/vllm/lora/layers.py", line 465, in forward
ERROR 05-17 21:49:30 worker_base.py:145]     output_parallel = self.apply(input_, bias)
ERROR 05-17 21:49:30 worker_base.py:145]   File "/home/corvo/vllm-project/vllm/lora/fully_sharded_layers.py", line 194, in apply
ERROR 05-17 21:49:30 worker_base.py:145]     return _mcp_apply(x, bias, self)
ERROR 05-17 21:49:30 worker_base.py:145]   File "/home/corvo/vllm-project/vllm/lora/fully_sharded_layers.py", line 117, in _mcp_apply
ERROR 05-17 21:49:30 worker_base.py:145]     dispatch_bgmv_low_level(output, buffers[idx],
ERROR 05-17 21:49:30 worker_base.py:145]   File "/home/corvo/vllm-project/vllm/lora/punica.py", line 82, in dispatch_bgmv_low_level
ERROR 05-17 21:49:30 worker_base.py:145]     punica_kernels.dispatch_bgmv_low_level(
ERROR 05-17 21:49:30 worker_base.py:145] RuntimeError: x must be contiguous
ERROR 05-17 21:49:30 worker_base.py:145] Exception raised from dispatch_bgmv_low_level at /home/corvo/vllm-project/csrc/punica/punica_ops.cu:326 (most recent call first):
ERROR 05-17 21:49:30 worker_base.py:145] frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x57 (0x7fc0ceecf897 in /home/corvo/.local/lib/python3.10/site-packages/torch/lib/libc10.so)
ERROR 05-17 21:49:30 worker_base.py:145] frame #1: c10::detail::torchCheckFail(char const*, char const*, unsigned int, char const*) + 0x68 (0x7fc0cee7fbee in /home/corvo/.local/lib/python3.10/site-packages/torch/lib/libc10.so)
ERROR 05-17 21:49:30 worker_base.py:145] frame #2: dispatch_bgmv_low_level(at::Tensor, at::Tensor, at::Tensor, at::Tensor, long, float, long, long, long) + 0x2ae9 (0x7fb3e7e1cb19 in /home/corvo/vllm-project/vllm/_punica_C.cpython-310-x86_64-linux-gnu.so)
ERROR 05-17 21:49:30 worker_base.py:145] frame #3: <unknown function> + 0x452b13 (0x7fb3e7e6db13 in /home/corvo/vllm-project/vllm/_punica_C.cpython-310-x86_64-linux-gnu.so)
ERROR 05-17 21:49:30 worker_base.py:145] frame #4: <unknown function> + 0x44f679 (0x7fb3e7e6a679 in /home/corvo/vllm-project/vllm/_punica_C.cpython-310-x86_64-linux-gnu.so)
ERROR 05-17 21:49:30 worker_base.py:145] <omitting python frames>
ERROR 05-17 21:49:30 worker_base.py:145] 

Let me know if you are able to reproduce.
Thanks in advance for your help!

@FurtherAI
Contributor Author

FurtherAI commented May 19, 2024

@sfc-gh-ybsat Thanks for finding this. I'm waiting on access to Llama 3, or a way to skip that, before I can reproduce it. It seems to be the buffer that is not contiguous. A quick fix is to just do buffers = tensor_model_parallel_all_gather(buffers).contiguous().

I'm not sure why the buffer can be non-contiguous, though, since the buffers should be contiguous when they're created and the all-gather shouldn't change that.

Update: I was not able to reproduce this, at least in eager mode; graph capture was either taking too long or stuck waiting on NCCL. Running the example script multilora_inference.py with tensor parallel size 8, fully sharded LoRAs, dummy weights, and eager mode was successful.
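
For reference, the kind of guard I have in mind (a sketch, not the final fix; dispatch_bgmv_low_level is the kernel from the traceback above):

import torch

def gathered_buffer_for_kernel(buffers: torch.Tensor) -> torch.Tensor:
    """dispatch_bgmv_low_level requires contiguous inputs, so copy if needed.

    An all-gathered tensor (or a view of it) can end up with strides that do
    not describe one contiguous block of memory.
    """
    if not buffers.is_contiguous():
        buffers = buffers.contiguous()  # materialize a contiguous copy
    return buffers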

@FurtherAI
Contributor Author

@sfc-gh-ybsat Do you have some minimal code that can reproduce the error?

@sfc-gh-ybsat

Thanks @FurtherAI for the follow-up.
I tried adding .contiguous(), but that causes the startup to hang and eventually fail after a RayWorker timeout.
Here is a reproducible example. You will need to download llama3-70b-instruct to some local path (in my case /models/llama3-70b-instruct) before running it. The LoRA adapter I created is publicly accessible on Hugging Face, so no extra step is needed beyond downloading llama3.

Example is adjusted from https://docs.vllm.ai/en/latest/getting_started/examples/multilora_inference.html

from typing import List, Optional, Tuple

from huggingface_hub import snapshot_download

from vllm import EngineArgs, LLMEngine, RequestOutput, SamplingParams
from vllm.lora.request import LoRARequest


def create_test_prompts(
        lora_path: str
) -> List[Tuple[str, SamplingParams, Optional[LoRARequest]]]:
    """Create a list of test prompts with their sampling parameters.

    2 requests for base model, 4 requests for the LoRA. We define 2
    different LoRA adapters (using the same model for demo purposes).
    Since we also set `max_loras=1`, the expectation is that the requests
    with the second LoRA adapter will be ran after all requests with the
    first adapter have finished.
    """
    return [
        (
            "Hi how are you?", SamplingParams(temperature=0.0,
                                              max_tokens=128),
            LoRARequest("sql-lora", 1, lora_path)),
    ]


def process_requests(engine: LLMEngine,
                     test_prompts: List[Tuple[str, SamplingParams,
                                              Optional[LoRARequest]]]):
    """Continuously process a list of prompts and handle the outputs."""
    request_id = 0

    while test_prompts or engine.has_unfinished_requests():
        if test_prompts:
            prompt, sampling_params, lora_request = test_prompts.pop(0)
            engine.add_request(str(request_id),
                               prompt,
                               sampling_params,
                               lora_request=lora_request)
            request_id += 1

        request_outputs: List[RequestOutput] = engine.step()

        for request_output in request_outputs:
            if request_output.finished:
                print(request_output)


def initialize_engine() -> LLMEngine:
    """Initialize the LLMEngine."""
    engine_args = EngineArgs(model="/models/meta-llama/Meta-Llama-3-70B-Instruct",
                             enable_lora=True,
                             tensor_parallel_size=8,
                             fully_sharded_loras=True,
                             max_loras=8,
                             max_lora_rank=8,
                             max_cpu_loras=16,
                             max_num_seqs=256)
    return LLMEngine.from_engine_args(engine_args)


def main():
    """Main function that sets up and runs the prompt processing."""
    engine = initialize_engine()
    lora_path = snapshot_download(repo_id="ybybybyb97/llama3-70b-lora-zero-weight")
    test_prompts = create_test_prompts(lora_path)
    process_requests(engine, test_prompts)


if __name__ == '__main__':
    main()

@FurtherAI
Contributor Author

FurtherAI commented May 23, 2024

@sfc-gh-ybsat Thanks, I can reproduce it. I think calling .contiguous() is the correct fix for that error, but there's another error during graph capture. For now, you should be able to run it in eager mode with the line changed as below.

@Yard1, I think it is throwing an error during the all-gather. Do you know what might be wrong with this, or with changing it to call .contiguous()?
Btw, vLLM is using NCCL 2.18.1.

buffers = tensor_model_parallel_all_gather(buffers).contiguous()

The error during graph capture:
dgx-h100-04:733673:733673 [0] include/alloc.h:103 NCCL WARN Cuda failure 1 'invalid argument'
dgx-h100-04:733673:733673 [0] NCCL INFO include/alloc.h:155 -> 1

dgx-h100-04:733673:733673 [0] include/alloc.h:161 NCCL WARN Failed to CUDA malloc 512 bytes
dgx-h100-04:733673:733673 [0] NCCL INFO enqueue.cc:1064 -> 1
dgx-h100-04:733673:733673 [0] NCCL INFO enqueue.cc:1306 -> 1
dgx-h100-04:733673:733673 [0] NCCL INFO group.cc:161 -> 1
dgx-h100-04:733673:733673 [0] NCCL INFO group.cc:339 -> 1
dgx-h100-04:733673:733673 [0] NCCL INFO group.cc:418 -> 1
dgx-h100-04:733673:733673 [0] NCCL INFO enqueue.cc:1981 -> 1
^C(raylet) The autoscaler failed with the following error:
Terminated with signal 15
  File "/home/auvesel/miniconda3/envs/clora/lib/python3.10/site-packages/ray/autoscaler/_private/monitor.py", line 709, in <module>
    monitor.run()
  File "/home/auvesel/miniconda3/envs/clora/lib/python3.10/site-packages/ray/autoscaler/_private/monitor.py", line 584, in run
    self._run()
  File "/home/auvesel/miniconda3/envs/clora/lib/python3.10/site-packages/ray/autoscaler/_private/monitor.py", line 438, in _run
    time.sleep(AUTOSCALER_UPDATE_INTERVAL_S)

(RayWorkerWrapper pid=747629) INFO 05-23 15:32:49 model_runner.py:824] Capturing the model for CUDA graphs. This may lead to unexpected consequences if the model is not static. To run the model in eager mode, set 'enforce_eager=True' or use '--enforce-eager' in the CLI. [repeated 6x across cluster]
(RayWorkerWrapper pid=747629) INFO 05-23 15:32:49 model_runner.py:828] CUDA graphs can take additional 1~3 GiB memory per GPU. If you are running out of memory, consider decreasing `gpu_memory_utilization` or enforcing eager mode. You can also reduce the `max_num_seqs` as needed to decrease memory usage. [repeated 6x across cluster]
[rank0]: Traceback (most recent call last):
[rank0]:   File "/home/auvesel/miniconda3/envs/clora/lib/python3.10/site-packages/torch/distributed/c10d_logger.py", line 75, in wrapper
[rank0]:     return func(*args, **kwargs)
[rank0]:   File "/home/auvesel/miniconda3/envs/clora/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py", line 2948, in all_gather_into_tensor
[rank0]:     work = group._allgather_base(output_tensor, input_tensor, opts)
[rank0]: torch.distributed.DistBackendError: NCCL error in: ../torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp:2395, unhandled cuda error (run with NCCL_DEBUG=INFO for details), NCCL version 2.20.5
[rank0]: ncclUnhandledCudaError: Call to CUDA function failed.
[rank0]: Last error:
[rank0]: Failed to CUDA malloc 512 bytes

Temirulan pushed a commit to Temirulan/vllm-whisper that referenced this pull request Sep 6, 2024
Co-authored-by: Antoni Baum <antoni.baum@protonmail.com>