
Conversation

maleksan85
Contributor

@maleksan85 maleksan85 commented Sep 18, 2025

Perf is the same:
upstream tp1

============ Serving Benchmark Result ============
Successful requests:                     320
Maximum request concurrency:             64
Benchmark duration (s):                  413.07
Total input tokens:                      326905
Total generated tokens:                  327680
Request throughput (req/s):              0.77
Output token throughput (tok/s):         793.27
Total Token throughput (tok/s):          1584.66
---------------Time to First Token----------------
Mean TTFT (ms):                          5240.27
Median TTFT (ms):                        5150.59
P99 TTFT (ms):                           13542.02
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          75.58
Median TPOT (ms):                        76.05
P99 TPOT (ms):                           79.35
---------------Inter-token Latency----------------
Mean ITL (ms):                           75.58
Median ITL (ms):                         66.27
P99 ITL (ms):                            73.04
----------------End-to-end Latency----------------
Mean E2EL (ms):                          82558.95
Median E2EL (ms):                        82994.12
P99 E2EL (ms):                           89413.31
==================================================

355_wip tp1

============ Serving Benchmark Result ============
Successful requests:                     320
Maximum request concurrency:             64
Benchmark duration (s):                  415.51
Total input tokens:                      326905
Total generated tokens:                  327680
Request throughput (req/s):              0.77
Output token throughput (tok/s):         788.62
Total Token throughput (tok/s):          1575.37
---------------Time to First Token----------------
Mean TTFT (ms):                          5262.51
Median TTFT (ms):                        5156.42
P99 TTFT (ms):                           14143.78
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          76.03
Median TPOT (ms):                        76.44
P99 TPOT (ms):                           79.86
---------------Inter-token Latency----------------
Mean ITL (ms):                           76.03
Median ITL (ms):                         66.64
P99 ITL (ms):                            69.13
----------------End-to-end Latency----------------
Mean E2EL (ms):                          83044.74
Median E2EL (ms):                        83387.81
P99 E2EL (ms):                           90413.53
==================================================
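For quick comparison, the deltas between the two runs above work out to: output token throughput 793.27 vs 788.62 tok/s (~0.6%), mean TPOT 75.58 vs 76.03 ms (~0.6%), and mean TTFT 5240.27 vs 5262.51 ms (~0.4%), consistent with the claim that performance is unchanged.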

Command:

HIP_VISIBLE_DEVICES=7 \
VLLM_DISABLE_COMPILE_CACHE=1 \
USE_FASTSAFETENSOR=1 \
SAFETENSORS_FAST_GPU=1 \
VLLM_USE_V1=1 \
AMDGCN_USE_BUFFER_OPS=1 \
TRITON_HIP_ASYNC_COPY_BYPASS_PERMUTE=1 \
TRITON_HIP_USE_ASYNC_COPY=1 \
TRITON_HIP_USE_BLOCK_PINGPONG=1 \
TRITON_HIP_ASYNC_FAST_SWIZZLE=1 \
VLLM_ROCM_USE_AITER=1 \
VLLM_ROCM_USE_AITER_MHA=0 \
VLLM_V1_USE_PREFILL_DECODE_ATTENTION=1 \
VLLM_USE_AITER_UNIFIED_ATTENTION=0 \
VLLM_ROCM_USE_TRITON_ROPE=1 \
VLLM_ROCM_USE_AITER_RMSNORM=1 \
VLLM_ROCM_USE_AITER_FP4_ASM_GEMM=1 \
vllm serve /data/models/Llama-3.1-405B-Instruct-MXFP4-Preview \
  --host localhost \
  --port 30000 \
  --swap-space 64 \
  --disable-log-requests \
  --dtype auto \
  --max-model-len 8192 \
  --tensor-parallel-size 1 \
  --max-num-seqs 64 \
  --distributed-executor-backend mp \
  --trust-remote-code \
  --kv-cache-dtype fp8 \
  --gpu-memory-utilization 0.85 \
  --max-seq-len-to-capture 8192 \
  --no-enable-prefix-caching \
  --async-scheduling \
  --max-num-batched-tokens 8192 \
  --compilation-config='{"pass_config":{"enable_attn_fusion":true,"enable_noop":true,"enable_fusion":true},"cudagraph_mode":"FULL","custom_ops":["+rms_norm","+silu_and_mul","+quant_fp8"],"splitting_ops":[]}'
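
For readability, here is the same --compilation-config payload expanded as a commented Python dict; the comments are my reading of these knobs, not an authoritative description:

compilation_config = {
    "pass_config": {
        "enable_noop": True,         # eliminate no-op reshape/view ops
        "enable_fusion": True,       # fuse rms_norm / silu_and_mul with fp8 quant
        "enable_attn_fusion": True,  # fuse attention output with fp8 quant
    },
    "cudagraph_mode": "FULL",        # capture full graphs rather than piecewise
    "custom_ops": ["+rms_norm", "+silu_and_mul", "+quant_fp8"],  # force custom impls so fusion patterns can match
    "splitting_ops": [],             # do not split the graph around attention
}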

Run the client benchmark:

vllm bench serve \
  --host localhost \
  --port 30000 \
  --model /data/models/Llama-3.1-405B-Instruct-MXFP4-Preview \
  --dataset-name random \
  --random-input-len 1024 \
  --random-output-len 1024 \
  --max-concurrency 64 \
  --num-prompts 320 \
  --percentile-metrics ttft,tpot,itl,e2el \
  --ignore-eos

Correctness: the model produces reasonable answers with the following command:

HIP_VISIBLE_DEVICES=7 \
VLLM_DISABLE_COMPILE_CACHE=1 \
USE_FASTSAFETENSOR=1 \
SAFETENSORS_FAST_GPU=1 \
VLLM_USE_V1=1 \
VLLM_V1_USE_PREFILL_DECODE_ATTENTION=1 \
AMDGCN_USE_BUFFER_OPS=1 \
VLLM_USE_AITER_TRITON_ROPE=1 \
TRITON_HIP_ASYNC_COPY_BYPASS_PERMUTE=1 \
TRITON_HIP_USE_ASYNC_COPY=1 \
TRITON_HIP_USE_BLOCK_PINGPONG=1 \
TRITON_HIP_ASYNC_FAST_SWIZZLE=1 \
VLLM_ROCM_USE_AITER=1 \
VLLM_ROCM_USE_AITER_MHA=0 \
VLLM_ROCM_USE_AITER_RMSNORM=1 \
VLLM_TRITON_FP4_GEMM_USE_ASM=1 \
python /data/vllm-scripts/llm_test.py \
  --model /data/models/Llama-3.1-405B-Instruct-MXFP4-Preview \
  --dataset-path /data/models/ShareGPT_Vicuna_unfiltered/ShareGPT_V3_unfiltered_cleaned_split.json \
  --batch-size 32 \
  --swap-space 64 \
  --dtype auto \
  --max-model-len 8192 \
  --tensor-parallel-size 1 \
  --max-num-seqs 1024 \
  --kv-cache-dtype fp8 \
  --gpu-memory-utilization 0.92 \
  --max-seq-len-to-capture 8192 \
  --no-enable-prefix-caching \
  --max-num-batched-tokens 8192

Aleksandr Malyshev added 3 commits September 18, 2025 03:01
Signed-off-by: Aleksandr Malyshev <maleksan@amd.com>
Signed-off-by: Aleksandr Malyshev <maleksan@amd.com>
Signed-off-by: Aleksandr Malyshev <maleksan@amd.com>
@mergify mergify bot added the llama Related to Llama models label Sep 18, 2025
@maleksan85 maleksan85 marked this pull request as ready for review September 18, 2025 04:21
Aleksandr Malyshev added 3 commits September 18, 2025 18:33
Signed-off-by: Aleksandr Malyshev <maleksan@amd.com>
Signed-off-by: Aleksandr Malyshev <maleksan@amd.com>
Signed-off-by: Aleksandr Malyshev <maleksan@amd.com>
@maleksan85 maleksan85 requested a review from gshtras as a code owner September 18, 2025 23:26
@mergify mergify bot added the v1 label Sep 18, 2025

mergify bot commented Sep 18, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @maleksan85.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Sep 18, 2025
Signed-off-by: Aleksandr Malyshev <maleksan@amd.com>
@maleksan85 maleksan85 force-pushed the llamas_changes_upstr_from_355_wip branch from 798e475 to f9626ee Compare September 18, 2025 23:29
Signed-off-by: Aleksandr Malyshev <maleksan@amd.com>
@gshtras gshtras added the rocm Related to AMD ROCm label Sep 19, 2025
@mergify mergify bot removed the needs-rebase label Sep 19, 2025
@gshtras gshtras added the ready ONLY add when PR is ready to merge/full CI is needed label Sep 19, 2025

mergify bot commented Sep 19, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @maleksan85.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Sep 19, 2025
Signed-off-by: Aleksandr Malyshev <maleksan@amd.com>
Signed-off-by: Aleksandr Malyshev <maleksan@amd.com>
@mergify mergify bot removed the needs-rebase label Sep 22, 2025
@wuhuikx

wuhuikx commented Sep 23, 2025

Signed-off-by: Aleksandr Malyshev <maleksan@amd.com>
Contributor

@SageMoore SageMoore left a comment


Can we get some unit tests for the batched_rotary_embedding kernel?

Signed-off-by: Aleksandr Malyshev <maleksan@amd.com>
@maleksan85
Contributor Author

Can we get some unit tests for the batched_rotary_embedding kernel?

Removed the batched RoPE kernel for now to speed up landing this PR.
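
For reference, a unit test along the lines requested could look like the sketch below. It checks a rotary-embedding implementation against a plain PyTorch NeoX-style reference; since the batched kernel was dropped from this PR, only the reference is registered in IMPLS, and the real kernel wrapper (whatever its final entry point is) would be added there once it lands.

# Sketch of a correctness test for a (batched) rotary-embedding kernel.
import pytest
import torch


def rope_reference(q, cos, sin):
    # NeoX-style rotary embedding: rotate the two halves of each head.
    q1, q2 = q.chunk(2, dim=-1)
    return torch.cat((q1 * cos - q2 * sin, q2 * cos + q1 * sin), dim=-1)


IMPLS = {"reference": rope_reference}  # append the kernel wrapper here


@pytest.mark.parametrize("impl", IMPLS.keys())
@pytest.mark.parametrize("num_tokens", [1, 7, 256])
@pytest.mark.parametrize("num_heads,head_dim", [(8, 128), (32, 64)])
def test_rope_matches_reference(impl, num_tokens, num_heads, head_dim):
    torch.manual_seed(0)
    dtype, device = torch.float32, "cpu"  # the real kernel would need CUDA/HIP

    q = torch.randn(num_tokens, num_heads, head_dim, dtype=dtype, device=device)
    positions = torch.randint(0, 4096, (num_tokens,), device=device)

    # Build the cos/sin cache from the standard 10000-base inverse frequencies.
    inv_freq = 1.0 / (10000.0 ** (torch.arange(0, head_dim, 2, device=device) / head_dim))
    freqs = positions[:, None].to(dtype) * inv_freq[None, :]
    cos = freqs.cos()[:, None, :]  # [tokens, 1, head_dim // 2]
    sin = freqs.sin()[:, None, :]

    expected = rope_reference(q, cos, sin)
    out = IMPLS[impl](q, cos, sin)
    torch.testing.assert_close(out, expected, rtol=1e-3, atol=1e-3)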

Contributor

@fxmarty-amd fxmarty-amd left a comment


Great that the CDNA4 MXFP4 GEMM gets upstreamed!

Signed-off-by: Aleksandr Malyshev <maleksan@amd.com>
@SageMoore
Contributor

CC @mgoin

Signed-off-by: Aleksandr Malyshev <maleksan@amd.com>
@maleksan85
Contributor Author

maleksan85 commented Sep 25, 2025

The model-executor-test CI job also fails on main at f552d5e578077574276aa9d83139b91e1d5ae163, which this branch is based on. Please force merge this PR. Thanks.

Member

@mgoin mgoin left a comment


Let's remove the x_quant_scales change to the linear layer

Signed-off-by: Aleksandr Malyshev <maleksan@amd.com>
@maleksan85 maleksan85 requested a review from mgoin September 25, 2025 22:00
@maleksan85
Contributor Author

Let's remove the x_quant_scales change to the linear layer

Removed, please take another look.

@maleksan85
Contributor Author

The CI test that failed passes locally:
python -m pytest -svx tests/v1/e2e/test_spec_decode.py::test_eagle_correctness[FLASH_ATTN-llama3_eagle]

Member

@mgoin mgoin left a comment


LGTM now, thank you!

Contributor

@zejunchen-zejun zejunchen-zejun left a comment


LGTM

@mgoin mgoin merged commit 53a3084 into vllm-project:main Sep 26, 2025
47 checks passed
Comment on lines +132 to +166
        if self.emulate:
            # Emulation path: dequantize the MX-FP4 weights to out_dtype with
            # amd-quark's real quantizer, then drop the packed scales.
            layer.weight_scale = torch.nn.Parameter(layer.weight_scale.data,
                                                    requires_grad=False)
            try:
                from quark.torch.export.nn.modules import realquantizer
                from quark.torch.quantization.config.config import (
                    QuantizationSpec)
            except ImportError as err:
                raise ImportError(
                    "The package `amd-quark` is required to use AMD Quark "
                    "MX-FP4 models. Please install it with `pip install "
                    "amd-quark`.") from err

            weight_quant_spec = QuantizationSpec.from_dict(
                self.weight_quant_spec)

            weight_quantizer = realquantizer.get_real_quantizer(
                qspec=weight_quant_spec,
                quantizer=None,
                real_quantized=True,
                reorder=False,
                float_dtype=self.out_dtype,
                scale_shape=layer.weight_scale.shape,
                zero_point_shape=None,
            )
            weight_quantizer.scale.data = layer.weight_scale.data

            layer.weight = torch.nn.Parameter(
                weight_quantizer(layer.weight.data).to(self.out_dtype),
                requires_grad=False,
            )
            layer.weight_scale = None

            # This call is necessary to release the scales memory.
            torch.cuda.empty_cache()
Contributor


I insist that this is unnecessary: https://github.com/vllm-project/vllm/pull/25135/files#r2378191214. Unfortunately, I was not able to reopen the closed thread.

yewentao256 pushed a commit that referenced this pull request Oct 3, 2025
Signed-off-by: Aleksandr Malyshev <maleksan@amd.com>
Co-authored-by: Aleksandr Malyshev <maleksan@amd.com>
Co-authored-by: Doug Lehr <douglehr@amd.com>
Signed-off-by: yewentao256 <zhyanwentao@126.com>