
DeepseekMoE support with Fused MoE kernel #2453

Merged: 20 commits, Jan 30, 2024

Conversation

@zwd003 (Contributor) commented Jan 16, 2024

Adding support for DeepseekMoE, as described here.

This work was partly done by @esmeetu and DeepSeek-AI.

We have fixed some bugs in @esmeetu's code and added support for expert parallelism and a fused MoE kernel.

Test code:

from vllm import LLM, SamplingParams

prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

llm = LLM(model="deepseek-ai/deepseek-moe-16b-base", trust_remote_code=True, tensor_parallel_size=8)
outputs = llm.generate(prompts, sampling_params)

# Print the outputs.
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

Output:

Prompt: 'Hello, my name is', Generated text: ' Mr. Jacobson, and I teach 6th Grade Mathematics and Physical Science'
Prompt: 'The president of the United States is', Generated text: ' not only the most powerful man on the planet, he is also one of the'
Prompt: 'The capital of France is', Generated text: ' among the most visited cities in Europe and the world, so it’s worth'
Prompt: 'The future of AI is', Generated text: ' here.\n- Develop your programming skills to use in AI.\n- Use'

Update 2024-01-19

Performance Benchmarking for Fused MoE Improvements

This PR introduces significant performance improvements compared to the current MoE method used for Mixtral and other baselines. Below is a summary of the benchmarks conducted:

Benchmark Details:

  • Hardware Configuration: 2 x A100 40G PCIe
  • Prompt Count: Varies per setting (32 or 256)
  • Average Prompt Length: 166 tokens
  • Max Tokens per Request: 64 or 256
  • Max Batch Size: 1 or 256

Results Summary:

| Setting | Model/Method | Prompts | Max Tokens | Max Batch Size | Time Cost (Baseline) | Time Cost (This PR) | Speedup |
|---------|--------------|---------|------------|----------------|----------------------|---------------------|---------|
| 1 | Baseline vs. with_fused_moe | 256 | 64 | 256 | 25s | 9s | 2.8x |
| 2 | Baseline vs. with_fused_moe | 32 | 64 | 1 | 104s | 33s (fix) | 3.2x (fix) |
| 3 | Llama2-13b vs. deepseek-16b-moe (with_fused_moe) | 256 | 256 | 256 | 51s | 23s | 2.2x |
| 4 | Llama2-13b vs. deepseek-16b-moe (with_fused_moe) | 32 | 256 | 1 | 396s | 140s | 2.8x |
| 5* | deepseek-16b-moe with PR #2293 vs. with_fused_moe | 256 | 64 | 256 | 26s | 9s | 2.9x |
| 6 | deepseek-16b-moe vs. llama2-7b | 256 | 64 | 256 | 8s | 9s | 0.9x |

*Updated on 2024-01-20: Added comparison with PR #2293.
Updated on 2024-01-21: Implemented the align_block_size function in C++ for a 10% performance improvement. deepseek-moe-16b can now achieve speeds almost identical to a 7B dense model.

Updated on 2024-01-22: I found that it greatly exceeds the speed of llama2-7b, with deepseek-moe-16b at 16587.82 tokens/s and llama2-7b at 10978.67 tokens/s. The speed bottleneck in the results from the table above (8s vs 9s) is not due to the model's computation speed. The following commands were used to test this:

python benchmarks/benchmark_throughput.py --model=meta-llama/Llama-2-7b \
    --input-len 1000 --output-len 64 -tp 2 --num-prompts 256
python benchmarks/benchmark_throughput.py --model=deepseek-ai/deepseek-moe-16b-base \
    --input-len 1000 --output-len 64 -tp 2 --num-prompts 256 --trust-remote-code

Future Work

I believe it is possible to achieve higher performance and surpass the speed of the 7B dense model; we would need to do the following:

  • Fuse the computations of shared_expert and routed_expert.
  • Fuse the gate, softmax, and topk operations.

This work may be done by the community in the future (in another PR); a minimal, unfused sketch of the routing path targeted by the second item is shown below.
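For reference, here is a minimal, unfused PyTorch sketch of that gate → softmax → top-k routing path (function and variable names are illustrative, not taken from this PR); a fused kernel would combine these three ops into a single launch:

```python
import torch

def route_tokens(hidden_states: torch.Tensor,
                 gate_weight: torch.Tensor,
                 top_k: int):
    """Unfused routing: the three separate ops a fused gate/softmax/topk kernel would merge."""
    # hidden_states: [num_tokens, hidden_size]; gate_weight: [num_experts, hidden_size]
    router_logits = hidden_states @ gate_weight.t()                    # gate projection
    routing_probs = torch.softmax(router_logits, dim=-1)               # softmax over experts
    topk_weights, topk_ids = torch.topk(routing_probs, top_k, dim=-1)  # select top-k experts
    return topk_weights, topk_ids
```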

@esmeetu (Collaborator) commented Jan 16, 2024

@zwd003 Thanks for fixing my last PR! But have you noticed that there seems to be no speed boost after adopting expert parallelism?

@zhuohan123 (Collaborator) commented:

@esmeetu Can you help test and review this PR?

@esmeetu mentioned this pull request Jan 17, 2024
@esmeetu (Collaborator) commented Jan 17, 2024

Hi @zwd003, I refactored based on this PR; please refer to #2467.
@zhuohan123 Sure. After testing this, I found a more consistent model style that aligns with the Mixtral implementation, but it's difficult to make those changes here, so I committed a new PR based on this one.

@zwd003 (Contributor, Author) commented Jan 18, 2024

> @zwd003 Thanks for fixing my last PR! But have you noticed that there seems to be no speed boost after adopting expert parallelism?

In my setting (8x A100 40G, tp=8, max_tokens=256, number of prompts=256, max_batch_size=256, with an average of 168 input tokens per prompt), this implementation is faster.

| code | enforce_eager | speed (it/s) |
|------|---------------|--------------|
| baseline (original code) | True | 1.87 |
| this (20240116) | True | 7.04 |
| this (fused_moe) | False | 10.73 |

@esmeetu (Collaborator) commented Jan 18, 2024

Hi @zwd003, could you help benchmark #2467 against this PR? I want to see how large the performance difference is.

@zwd003 changed the title from "DeepseekMoE support" to "DeepseekMoE support with Fused MoE kernel" on Jan 18, 2024
@zwd003 (Contributor, Author) commented Jan 19, 2024

I developed a fused MoE kernel that achieves faster speeds (different from #2293). The code is ready for review @esmeetu @zhuohan123.

@esmeetu (Collaborator) commented Jan 19, 2024

@zwd003 Running on a T4 GPU with float16 is not working.

RuntimeError: Internal Triton PTX codegen error:
ptxas /tmp/compile-ptx-src-f6f8c2, line 505; error : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-f6f8c2, line 505; error : Feature 'cvt.bf16.f32' requires .target sm_80 or higher

@zwd003 (Contributor, Author) commented Jan 19, 2024

> @zwd003 Running on a T4 GPU with float16 is not working.
>
> RuntimeError: Internal Triton PTX codegen error: ptxas /tmp/compile-ptx-src-f6f8c2, line 505; error : Feature '.bf16' requires .target sm_80 or higher ptxas /tmp/compile-ptx-src-f6f8c2, line 505; error : Feature 'cvt.bf16.f32' requires .target sm_80 or higher

It can now run with fp16.

@esmeetu (Collaborator) commented Jan 19, 2024

LGTM!
For generation speed, I got 12 t/s at batch size 1, which is 2x faster than before. However, this is still not as good as a 13B model, which gets 27 t/s on my machine with TP4.

@zwd003 (Contributor, Author) commented Jan 19, 2024

> LGTM! For generation speed, I got 12 t/s at batch size 1, which is 2x faster than before. However, this is still not as good as a 13B model, which gets 27 t/s on my machine with TP4.

In my setting (TP=2, A100 40G PCIe, max_batch_size=256, max_tokens=256, 256 prompts with average input length 166), llama2-13b takes 51s and deepseek-16b-moe takes 25s. For smaller models, a smaller TP yields better acceleration relative to a dense model. See the benchmarks above for more results.

@scv119 (Contributor) commented Jan 20, 2024

Pretty cool implementation! I do feel this PR will perform better compared to my implementation in #2293. Wondering if you have benchmarked that yet?

@simon-mo mentioned this pull request Jan 22, 2024
@SinanAkkoyun commented Jan 22, 2024

Hi, thank you very much for all the great work!
Why is it being compared to Llama 13B? If I didn't misunderstand, the MoE only has ~3B active parameters; what is holding the current implementation back from achieving Llama-3B-level t/s?

@zwd003 (Contributor, Author) commented Jan 22, 2024

> Hi, thank you very much for all the great work! Why is it being compared to Llama 13B? If I didn't misunderstand, the MoE only has ~3B active parameters; what is holding the current implementation back from achieving Llama-3B-level t/s?

MoE models require more optimization to match dense models (dense models are mostly dense matrix multiplication, and cuBLAS can almost reach the hardware's peak computation speed). For instance, the current Mixtral 8x7B is also slower than a 14B dense model. Additionally, MoE models involve many more kernel launches, such as the gate (softmax) and topk. If these kernels continue to be fused together, there is potential for further speed improvements.

@arnavdantuluri commented:
This looks really good!
I'm curious, though, whether this is applicable out of the box to Mixtral 8x7B as well. If not, what changes would be necessary to port this code to Mixtral 8x7B?

@zwd003 (Contributor, Author) commented Jan 22, 2024

> This looks really good! I'm curious, though, whether this is applicable out of the box to Mixtral 8x7B as well. If not, what changes would be necessary to port this code to Mixtral 8x7B?

This is also applicable to Mixtral 8x7B. You only need to modify the parameter names in Mixtral (packing the parameters from different experts together; see def pack_params() in the DeepSeek model), but I haven't yet tested the acceleration effect on Mixtral. Some Triton kernel hyperparameters (such as BLOCK_SIZE_M) may need to be adjusted to achieve the best performance.
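To illustrate the packing idea (a rough sketch only; the gate_proj/up_proj/down_proj attribute names are assumptions for a generic HF-style expert module, not the exact vLLM parameter names), stacking per-expert weights into the fused layout might look like:

```python
import torch

def pack_expert_weights(experts):
    """Stack per-expert weights into the [num_experts, ...] layout a fused MoE kernel expects."""
    # w13: [num_experts, 2 * intermediate_size, hidden_size] (gate and up projections concatenated)
    w13 = torch.stack(
        [torch.cat([e.gate_proj.weight, e.up_proj.weight], dim=0) for e in experts],
        dim=0,
    )
    # w2: [num_experts, hidden_size, intermediate_size] (down projection)
    w2 = torch.stack([e.down_proj.weight for e in experts], dim=0)
    return w13, w2
```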

@pcmoritz mentioned this pull request Jan 22, 2024
@casper-hansen (Contributor) commented:

@zwd003 I have a question related to quantization.

How can we apply this optimization to quantized models? Do we need to dequantize weights before/during running the kernel to achieve the speedup?

@zwd003 (Contributor, Author) commented Jan 22, 2024

> @zwd003 I have a question related to quantization.
>
> How can we apply this optimization to quantized models? Do we need to dequantize weights before/during running the kernel to achieve the speedup?

It needs a new kernel supporting quantized matmul (we would need to rewrite the kernel in CUDA/C++, but the main idea is unchanged: compute each block for its corresponding expert).
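As a rough reference only (the tensor names and the symmetric int8 layout here are illustrative assumptions, not an existing vLLM kernel), the per-expert dequantize-then-matmul that such a quantized kernel would fuse looks like:

```python
import torch

def dequant_expert_matmul(a_block: torch.Tensor,
                          qweight: torch.Tensor,
                          scales: torch.Tensor,
                          expert_id: int) -> torch.Tensor:
    """One block of the fused loop: dequantize this expert's int8 weight on the fly,
    then run the matmul for the tokens routed to it."""
    # a_block: [block_m, hidden_size] activations routed to one expert
    # qweight: [num_experts, intermediate_size, hidden_size] int8 weights
    # scales:  [num_experts, intermediate_size] per-output-channel scales
    w = qweight[expert_id].to(a_block.dtype) * scales[expert_id].unsqueeze(1)
    return a_block @ w.t()
```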

@casper-hansen (Contributor) commented:

> It needs a new kernel supporting quantized matmul (we would need to rewrite the kernel in CUDA/C++, but the main idea is unchanged: compute each block for its corresponding expert).

I do believe we have Triton kernels for both GPTQ and AWQ. Perhaps you would need to create separate quantized Triton kernels based on the linked kernels for quantized matmul.

If this is of interest to DeepSeek, I could implement quantization in AWQ of the DeepSeek-MoE model.

intermediate_size: int,
hidden_act: str,
linear_method: Optional[LinearMethodBase] = None,
reduce_results=True,
Collaborator (review comment on the code above):

It would help remove ambiguity if we defined a separate DeepseekExpertMLP that does not reduce results; then we could remove this reduce_results parameter. If we keep it, adding a type annotation for reduce_results would look nicer. Both choices are OK.

Contributor Author (reply):

OK, you're right, I have fixed it.

@zwd003 (Contributor, Author) commented Jan 22, 2024

> Hi, thank you very much for all the great work! Why is it being compared to Llama 13B? If I didn't misunderstand, the MoE only has ~3B active parameters; what is holding the current implementation back from achieving Llama-3B-level t/s?

Even with the current code, I found that it is indeed 1.6 times faster than the dense 7B model. The performance bottleneck in the test results from the table is not in the model's computation. Please see the new test results.

@SinanAkkoyun commented:
> Even with the current code, I found that it is indeed 1.6 times faster than the dense 7B model. The performance bottleneck in the test results from the table is not in the model's computation. Please see the new test results.

That's so great to hear, thanks!

@pcmoritz (Collaborator) commented Jan 29, 2024

Unfortunately, there seem to be some correctness issues with the kernel (this was prompted by the user feedback #2542 (comment)). It happens in all kinds of configurations in smaller ways but can be reproduced in a pretty major way with the following diff:

diff --git a/tests/kernels/test_fused_moe.py b/tests/kernels/test_fused_moe.py
index f68e84f4f9..1b8e6321e6 100644
--- a/tests/kernels/test_fused_moe.py
+++ b/tests/kernels/test_fused_moe.py
@@ -22,11 +22,11 @@ def torch_moe(a, w1, w2, topk_weight, topk_ids):
             topk_weight.view(B, -1, 1)).sum(dim=1)
 
 
-@pytest.mark.parametrize("m", [512, 222, 33, 1])
-@pytest.mark.parametrize("n", [2048, 256, 1024])
-@pytest.mark.parametrize("k", [128, 511, 1024])
-@pytest.mark.parametrize("e", [8, 64])
-@pytest.mark.parametrize("topk", [2, 6])
+@pytest.mark.parametrize("m", [1])
+@pytest.mark.parametrize("n", [2048, 256, 1024, 8192])
+@pytest.mark.parametrize("k", [4096])
+@pytest.mark.parametrize("e", [8])
+@pytest.mark.parametrize("topk", [2])
 @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
 def test_fused_moe(
     m: int,
@@ -46,4 +46,4 @@ def test_fused_moe(
 
     triton_output = fused_moe(a, w1, w2, topk_weight, topk_ids, False)
     torch_output = torch_moe(a, w1, w2, topk_weight, topk_ids)
-    assert torch.allclose(triton_output, torch_output, atol=1e-2, rtol=0)
+    assert torch.allclose(triton_output, torch_output, atol=1e-3, rtol=0), torch.max(abs(triton_output - torch_output))

This gives an error of 0.0312 for the test_fused_moe[dtype1-2-8-4096-8192-1] setting and I have actually seen even larger divergences. I suspect there is a bug in the fused_moe kernel somewhere.

@zwd003 (Contributor, Author) commented Jan 30, 2024

> Unfortunately, there seem to be some correctness issues with the kernel [...] This gives an error of 0.0312 for the test_fused_moe[dtype1-2-8-4096-8192-1] setting and I have actually seen even larger divergences. I suspect there is a bug in the fused_moe kernel somewhere.

This issue might be due to numerical precision; after switching to fp32, the difference is very small.

# assert hidden_states.dtype in [torch.float16, torch.bfloat16]  # disable the dtype assert
accumulator += tl.dot(a, b, allow_tf32=False)  # disable tf32 in the kernel
# accumulator = accumulator.to(compute_type)  # disable the conversion to bf16/fp16
test_fused_moe(1, 8192, 4096, 8, 2, torch.float32)

outputs:

fused_moe = tensor([[ 0.5336,  0.6163,  1.0068,  ...,  0.6273,  0.4510, -0.8778]],
       device='cuda:0')
torch_moe = tensor([[ 0.5336,  0.6163,  1.0068,  ...,  0.6273,  0.4510, -0.8778]],
       device='cuda:0')
(fused_moe - torch_moe).abs().max() = tensor(6.5863e-06, device='cuda:0')
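For anyone reproducing this, a small standalone check (illustrative only, not part of this PR) shows how much reduced-precision tensor-core accumulation alone can move a plain matmul of this size:

```python
import torch

# Compare an fp32 matmul with and without TF32 tensor-core math (Ampere or newer).
a = torch.randn(1, 4096, device="cuda")
b = torch.randn(4096, 8192, device="cuda")

torch.backends.cuda.matmul.allow_tf32 = True
out_tf32 = a @ b
torch.backends.cuda.matmul.allow_tf32 = False
out_fp32 = a @ b

# The max abs difference is typically far above fp32 round-off at this size, which is
# why a tight atol only holds when both TF32 and the fp16/bf16 conversion are disabled.
print((out_tf32 - out_fp32).abs().max())
```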

@pcmoritz (Collaborator) commented:
@zwd003 You are right, thanks a lot for checking. I had only set the dtype to float32 or used allow_tf32=False, but not both at the same time. Great catch! Let's merge the PR then :)

size_t numel) {
const size_t tokens_per_thread = CEILDIV(numel, blockDim.x);
const size_t start_idx = threadIdx.x * tokens_per_thread;
__shared__ int32_t tokens_cnts[NUM_MAX_EXPERTS + 1][NUM_MAX_EXPERTS];
Collaborator (review comment on the code above):

nit: Actually, this part doesn't need to be static. Shared memory size can be configured dynamically at kernel launch time. However, I think we can fix this in a later PR.

@WoosukKwon (Collaborator) left a review comment:

@zwd003 LGTM. Thanks for the amazing work! We're really excited to have this model and the Triton fused MoE kernel.

@esmeetu Thanks for your original PR and your continuous contribution to vLLM!

@pcmoritz Many thanks for your review! It helped so much.

@WoosukKwon merged commit 5d60def into vllm-project:main on Jan 30, 2024
10 of 12 checks passed
@SinanAkkoyun commented:
When will this merge be pushed to the latest OpenAI-compatible Docker image?

Thank you so much for the work!

@casper-hansen (Contributor) commented:

In my tests, (fused_moe - torch_moe).abs().max() gives 0.37 when compared to the Hugging Face implementation of Mixtral, with both implementations set to float16. @zwd003 are you sure the implemented test corresponds to the original implementation? It seems like a large difference when keeping the precision constant.

https://github.com/casper-hansen/AutoAWQ/blob/mixtral_fused/tests/test_fused_moe.py

@pcmoritz (Collaborator) commented Jan 30, 2024

Thanks @casper-hansen, I'm digging into this some more and I'm also planning to add a test to the repo about it :)

In the test you posted (which is very nice btw), I see that all the states_fused are zero after running the forward pass (this is after porting it to the MixtralMoE that is now in master, so it might be different in your setting). I'm digging more into this (e.g. at the moment I'm not sure if the gate is being loaded).

@pcmoritz (Collaborator) commented:

I think I've figured out what is going on. There were two adaptations I needed to make so your script works with the Mixtral MoE: load the gate, and set inplace=False so the hidden state doesn't get overwritten. This is the updated script:

import torch

torch.cuda.manual_seed(0)
torch.cuda.manual_seed_all(0)

import time
from vllm.model_executor.models.mixtral import MixtralMoE
from transformers.models.mixtral.configuration_mixtral import MixtralConfig
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock

config = MixtralConfig()

block = MixtralSparseMoeBlock(config).float().to("cuda")
fused = MixtralMoE(
    num_experts=config.num_local_experts,
    top_k=config.num_experts_per_tok,
    hidden_size=config.hidden_size,
    intermediate_size=config.intermediate_size,
)

fused.gate.linear_weights["weight"][:] = block.gate.weight.data
for i in range(config.num_local_experts):
    fused.ws[i][:] = torch.cat((
        block.experts[i].w1.weight.data,
        block.experts[i].w3.weight.data,
    ), dim=0).to("cuda")
    fused.w2s[i][:] = block.experts[i].w2.weight.data

def _run_profile(fn, inputs, rounds=2):
    start_time = time.perf_counter()
    torch.cuda.synchronize()

    for _ in range(rounds):
        states, router_logits = fn(inputs)

    torch.cuda.synchronize()
    end_time = time.perf_counter()

    return (end_time - start_time) / rounds, states, router_logits

# [batch_size, seq_len, hidden_dim]
inputs = torch.randn((1, 64, config.hidden_size)).to("cuda")

block_time, states_block, router_block = _run_profile(block.forward, inputs)
fused_time, states_fused, router_fused = _run_profile(fused.forward, inputs)

print(block_time, fused_time, block_time / fused_time)
print("states_fused", states_fused)
print("states_block", states_block)
print("diff1", (states_fused - states_block).mean().abs())
print("diff2", (states_fused - states_block).abs().max())

And this is the diff to the repo (mostly to make sure the MoE layer can run in the same process and also make sure it doesn't use lower precision tensor core arithmetic):

diff --git a/vllm/model_executor/layers/fused_moe.py b/vllm/model_executor/layers/fused_moe.py
index 998062d82d..99d8de7ccb 100644
--- a/vllm/model_executor/layers/fused_moe.py
+++ b/vllm/model_executor/layers/fused_moe.py
@@ -105,7 +105,7 @@ def fused_moe_kernel(
                     mask=offs_k[:, None] < K - k * BLOCK_SIZE_K,
                     other=0.0)
         # We accumulate along the K dimension.
-        accumulator += tl.dot(a, b)
+        accumulator += tl.dot(a, b, allow_tf32=False)
         # Advance the ptrs to the next K block.
         a_ptrs += BLOCK_SIZE_K * stride_ak
         b_ptrs += BLOCK_SIZE_K * stride_bk
@@ -235,7 +235,7 @@ def fused_moe(hidden_states: torch.Tensor,
     assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
     assert w1.is_contiguous(), "Expert weights1 must be contiguous"
     assert w2.is_contiguous(), "Expert weights2 must be contiguous"
-    assert hidden_states.dtype in [torch.float16, torch.bfloat16]
+    assert hidden_states.dtype in [torch.float16, torch.bfloat16, torch.float32]
     M, _ = hidden_states.shape
     E, N, _ = w1.shape
 
diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py
index f36c35fd27..480de2b8bf 100644
--- a/vllm/model_executor/models/mixtral.py
+++ b/vllm/model_executor/models/mixtral.py
@@ -72,7 +72,7 @@ class MixtralMoE(nn.Module):
         params_dtype: Optional[torch.dtype] = None,
     ):
         super().__init__()
-        tp_size = get_tensor_model_parallel_world_size()
+        tp_size = 1 # get_tensor_model_parallel_world_size()
         self.num_total_experts = num_experts
         self.top_k = top_k
         self.hidden_size = hidden_size
@@ -93,13 +93,15 @@ class MixtralMoE(nn.Module):
                         2 * self.intermediate_size,
                         self.hidden_size,
                         device="cuda",
-                        dtype=self.params_dtype))
+                        dtype=self.params_dtype),
+            requires_grad=False)
         self.w2s = nn.Parameter(
             torch.empty(self.num_total_experts,
                         self.hidden_size,
                         self.intermediate_size,
                         device="cuda",
-                        dtype=self.params_dtype))
+                        dtype=self.params_dtype),
+            requires_grad=False)
 
         set_weight_attrs(self.ws, {
             "weight_loader": self.weight_loader,
@@ -139,13 +141,13 @@ class MixtralMoE(nn.Module):
                                         self.w2s,
                                         routing_weights,
                                         selected_experts,
-                                        inplace=True)
+                                        inplace=False)
 
-        final_hidden_states = tensor_model_parallel_all_reduce(
-            final_hidden_states)
+        # final_hidden_states = tensor_model_parallel_all_reduce(
+        #     final_hidden_states)
 
         return final_hidden_states.view(batch_size, sequence_length,
-                                        hidden_size)
+                                        hidden_size), None
 
 
 class MixtralAttention(nn.Module):
@@ -160,7 +162,7 @@ class MixtralAttention(nn.Module):
                  sliding_window: Optional[int] = None) -> None:
         super().__init__()
         self.hidden_size = hidden_size
-        tp_size = get_tensor_model_parallel_world_size()
+        tp_size = 1 # get_tensor_model_parallel_world_size()
         self.total_num_heads = num_heads
         assert self.total_num_heads % tp_size == 0
         self.num_heads = self.total_num_heads // tp_size

With those modifications, the results are very accurate even if I set the number of rounds to a high value; for example, for rounds = 100, I'm getting:

diff1 tensor(4.0595e-08, device='cuda:0', grad_fn=<AbsBackward0>)
diff2 tensor(0.0002, device='cuda:0', grad_fn=<MaxBackward1>)

I'll convert this to a test that can be committed into the repo next! Thanks for looking into this, I'm similarly interested in making sure the model quality is as high as possible :)

@pcmoritz (Collaborator) commented:

@casper-hansen Unit test added in #2677

@zhuohan123 mentioned this pull request Jan 31, 2024
NikolaBorisov pushed a commit to deepinfra/vllm that referenced this pull request Jan 31, 2024
@chu-tianxiang mentioned this pull request Feb 5, 2024
hongxiayang pushed a commit to hongxiayang/vllm that referenced this pull request Feb 13, 2024
alexm-neuralmagic pushed a commit to neuralmagic/nm-vllm that referenced this pull request Feb 13, 2024