Merged · Changes from all commits (27 commits)
benchmarks/kernels/benchmark_w8a8_block_fp8.py — 2 changes: 1 addition & 1 deletion

@@ -11,13 +11,13 @@
 from typing import Any

 import torch
-import triton
 from tqdm import tqdm

 from vllm.model_executor.layers.quantization.utils.fp8_utils import (
     _w8a8_block_fp8_matmul,
 )
 from vllm.platforms import current_platform
+from vllm.triton_utils import triton
 from vllm.utils import FlexibleArgumentParser

 mp.set_start_method("spawn", force=True)
examples/others/tensorize_vllm_model.py — 9 changes: 1 addition & 8 deletions

@@ -1,8 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

-import argparse
-import dataclasses
 import json
 import logging
 import os
@@ -327,12 +325,7 @@ def main():


     if args.command == "serialize":
-        eng_args_dict = {f.name: getattr(args, f.name) for f in
-                         dataclasses.fields(EngineArgs)}
-
-        engine_args = EngineArgs.from_cli_args(
-            argparse.Namespace(**eng_args_dict)
-        )
+        engine_args = EngineArgs.from_cli_args(args)

         input_dir = tensorizer_dir.rstrip('/')
         suffix = args.suffix if args.suffix else uuid.uuid4().hex
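Note: this simplification works together with the `EngineArgs.from_cli_args` change in vllm/engine/arg_utils.py below, which now skips dataclass fields that are missing from the given namespace. The manual round-trip through `dataclasses.fields(EngineArgs)` and a fresh `argparse.Namespace` is therefore no longer needed.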
tests/entrypoints/test_api_server_process_manager.py — 2 changes: 1 addition & 1 deletion

@@ -60,7 +60,7 @@ def test_api_server_process_manager_init(api_server_args, with_stats_update):
     global WORKER_RUNTIME_SECONDS
     WORKER_RUNTIME_SECONDS = 0.5

-    # Copy the args to avoid mutating the
+    # Copy the args to avoid mutating them
     args = api_server_args.copy()

     if not with_stats_update:
tests/v1/test_external_lb_dp.py — 52 changes: 48 additions & 4 deletions

@@ -9,6 +9,7 @@
 import openai  # use the official client for correctness check
 import pytest
 import pytest_asyncio
+import requests

 from tests.utils import RemoteOpenAIServer
 from vllm.platforms import current_platform
@@ -70,6 +71,8 @@ def start_server(r: int, sargs: list[str]):
                 sargs,
                 auto_port=False,
                 env_dict={
+                    "VLLM_SERVER_DEV_MODE":
+                    "1",
                     current_platform.device_control_env_var:
                     ",".join(
                         str(
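Note: `VLLM_SERVER_DEV_MODE=1` enables vLLM's development-mode HTTP endpoints on the API servers; the `/server_info` endpoint queried by the new tests below is only served in this mode.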
@@ -127,11 +130,19 @@ def default_server_args():


 @pytest.fixture(scope="module", params=[1, 4])
-def servers(request, default_server_args):
+def server_manager(request, default_server_args):
     api_server_count = request.param
-    with ExternalLBServerManager(MODEL_NAME, DP_SIZE, api_server_count,
-                                 default_server_args) as server_list:
-        yield server_list
+    server_manager = ExternalLBServerManager(MODEL_NAME, DP_SIZE,
+                                             api_server_count,
+                                             default_server_args)
+
+    with server_manager:
+        yield server_manager
+
+
+@pytest.fixture
+def servers(server_manager):
+    return server_manager.servers

Review thread on the `server_manager = ExternalLBServerManager(...)` line:

Member: Looks a bit weird to have the same variable name as the method name here (any particular reason for introducing an intermediate variable here?)

Member (author): It is because `__enter__` returns the server list rather than the manager itself.

Member (author): I can change the return value of `__enter__` if you want.
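The intermediate variable discussed in this thread is forced by the managers' context protocol: since `__enter__` hands back the server list, `with ExternalLBServerManager(...) as x` would bind `x` to the list, losing access to `.api_server_count`. A minimal sketch of the assumed shape of these manager classes (names and bodies are illustrative assumptions, not the actual implementation):

    class ServerManagerSketch:
        # Hypothetical stand-in for ExternalLBServerManager and friends.
        def __init__(self, model_name, dp_size, api_server_count, server_args):
            self.api_server_count = api_server_count
            self.servers = []  # populated with (server, env) tuples on entry

        def __enter__(self):
            # Starts the DP servers, then returns the *list*, not `self`.
            return self.servers

        def __exit__(self, exc_type, exc, tb):
            for server, _ in self.servers:
                server.__exit__(exc_type, exc, tb)

Keeping a reference to the manager object itself is therefore the only way for the new fixtures to expose both `.servers` and `.api_server_count` to the tests.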


 @pytest_asyncio.fixture
@@ -144,6 +155,39 @@ async def clients(servers: list[tuple[RemoteOpenAIServer, list[str]]]):
     ]


+def _get_parallel_config(server: RemoteOpenAIServer):
+    response = requests.get(server.url_for("server_info?config_format=json"))
+    response.raise_for_status()
+
+    vllm_config = response.json()["vllm_config"]
+    return vllm_config["parallel_config"]
+
+
+def test_external_lb_server_info(server_manager):
+    servers = server_manager.servers
+    api_server_count = server_manager.api_server_count
+
+    for i, (server, _) in enumerate(servers):
+        print(f"Testing {i=}")
+
+        # Each request will hit one of the API servers
+        # `n_reqs` is set so that there is a good chance each server
+        # receives at least one request
+        n_reqs = 2 * api_server_count * api_server_count
+        parallel_configs = [
+            _get_parallel_config(server) for _ in range(n_reqs)
+        ]
+        api_process_counts = [
+            c["_api_process_count"] for c in parallel_configs
+        ]
+        api_process_ranks = [c["_api_process_rank"] for c in parallel_configs]
+
+        assert all(c == api_server_count
+                   for c in api_process_counts), api_process_counts
+        assert all(0 <= r < api_server_count
+                   for r in api_process_ranks), api_process_ranks
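Side note on the `n_reqs` heuristic above (also used in the hybrid and multinode tests below): assuming requests are spread over the `api_server_count = c` API servers roughly uniformly at random, a given server misses all `2 * c * c` requests with probability (1 - 1/c)^(2c^2) ≈ e^(-2c), so by a union bound the chance that some server goes unexercised is at most about c * e^(-2c), under 0.2% for c = 4.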


 @pytest.mark.asyncio
 @pytest.mark.parametrize(
     "model_name",
tests/v1/test_hybrid_lb_dp.py — 54 changes: 49 additions & 5 deletions

@@ -9,6 +9,7 @@
 import openai  # use the official client for correctness check
 import pytest
 import pytest_asyncio
+import requests

 from tests.utils import RemoteOpenAIServer
 from tests.v1.test_utils import check_request_balancing
@@ -92,6 +93,8 @@ def start_server(node: int, sargs: list[str]):
                 sargs,
                 auto_port=False,
                 env_dict={
+                    "VLLM_SERVER_DEV_MODE":
+                    "1",
                     current_platform.device_control_env_var:
                     ",".join(
                         str(
@@ -150,12 +153,20 @@ def default_server_args():


 @pytest.fixture(scope="module", params=[1, 4])
-def servers(request, default_server_args):
+def server_manager(request, default_server_args):
     api_server_count = request.param
-    with HybridLBServerManager(MODEL_NAME, DP_SIZE, api_server_count,
-                               default_server_args, DP_SIZE_LOCAL,
-                               TP_SIZE) as server_list:
-        yield server_list
+    server_manager = HybridLBServerManager(MODEL_NAME, DP_SIZE,
+                                           api_server_count,
+                                           default_server_args, DP_SIZE_LOCAL,
+                                           TP_SIZE)
+
+    with server_manager:
+        yield server_manager
+
+
+@pytest.fixture
+def servers(server_manager):
+    return server_manager.servers

Review thread on the `server_manager = HybridLBServerManager(...)` line:

Member: Same comment (see the naming question on test_external_lb_dp.py above).


 @pytest_asyncio.fixture
@@ -168,6 +179,39 @@ async def clients(servers: list[tuple[RemoteOpenAIServer, list[str]]]):
     ]


+def _get_parallel_config(server: RemoteOpenAIServer):
+    response = requests.get(server.url_for("server_info?config_format=json"))
+    response.raise_for_status()
+
+    vllm_config = response.json()["vllm_config"]
+    return vllm_config["parallel_config"]
+
+
+def test_hybrid_dp_server_info(server_manager):
+    servers = server_manager.servers
+    api_server_count = server_manager.api_server_count
+
+    for i, (server, _) in enumerate(servers):
+        print(f"Testing {i=}")
+
+        # Each request will hit one of the API servers
+        # `n_reqs` is set so that there is a good chance each server
+        # receives at least one request
+        n_reqs = 2 * api_server_count * api_server_count
+        parallel_configs = [
+            _get_parallel_config(server) for _ in range(n_reqs)
+        ]
+        api_process_counts = [
+            c["_api_process_count"] for c in parallel_configs
+        ]
+        api_process_ranks = [c["_api_process_rank"] for c in parallel_configs]
+
+        assert all(c == api_server_count
+                   for c in api_process_counts), api_process_counts
+        assert all(0 <= r < api_server_count
+                   for r in api_process_ranks), api_process_ranks


 @pytest.mark.asyncio
 @pytest.mark.parametrize(
     "model_name",
tests/v1/test_internal_lb_dp.py — 57 changes: 49 additions & 8 deletions

@@ -10,6 +10,7 @@
 import openai  # use the official client for correctness check
 import pytest
 import pytest_asyncio
+import requests

 from tests.utils import RemoteOpenAIServer
 from tests.v1.test_utils import check_request_balancing
@@ -101,6 +102,8 @@ def start_server(sidx: int, r: int, sargs: list[str]):
                 sargs,
                 auto_port=False,
                 env_dict={
+                    "VLLM_SERVER_DEV_MODE":
+                    "1",
                     current_platform.device_control_env_var:
                     ",".join(
                         str(
@@ -214,7 +217,10 @@ def start_api_server():
                 self.model_name,
                 api_server_args,
                 auto_port=False,
-                env_dict={})  # No GPUs needed for API-only server
+                env_dict={
+                    "VLLM_SERVER_DEV_MODE": "1",
+                    # No GPUs needed for API-only server
+                })
             server.__enter__()
             print(f"API-only server started successfully with "
                   f"{self.api_server_count} API servers")
@@ -293,14 +299,21 @@ def default_server_args():


 @pytest.fixture(scope="module", params=[1, 4])
-def servers(request, default_server_args):
+def server_manager(request, default_server_args):
     api_server_count = request.param
-    with MultinodeInternalLBServerManager(MODEL_NAME, DP_SIZE,
-                                          api_server_count,
-                                          default_server_args,
-                                          DP_SIZE // NUM_NODES,
-                                          TP_SIZE) as server_list:
-        yield server_list
+    server_manager = MultinodeInternalLBServerManager(MODEL_NAME, DP_SIZE,
+                                                      api_server_count,
+                                                      default_server_args,
+                                                      DP_SIZE // NUM_NODES,
+                                                      TP_SIZE)
+
+    with server_manager:
+        yield server_manager
+
+
+@pytest.fixture
+def servers(server_manager):
+    return server_manager.servers


 @pytest.fixture(scope="module", params=[1, 4])
@@ -331,6 +344,34 @@ async def api_only_client(api_only_servers: list[tuple[RemoteOpenAIServer,
     yield client


+def _get_parallel_config(server: RemoteOpenAIServer):
+    response = requests.get(server.url_for("server_info?config_format=json"))
+    response.raise_for_status()
+
+    vllm_config = response.json()["vllm_config"]
+    return vllm_config["parallel_config"]
+
+
+def test_multinode_dp_server_info(server_manager):
+    head_server = server_manager.servers[0][0]
+    api_server_count = server_manager.api_server_count
+
+    # Each request will hit one of the API servers
+    # `n_reqs` is set so that there is a good chance each server
+    # receives at least one request
+    n_reqs = 2 * api_server_count * api_server_count
+    parallel_configs = [
+        _get_parallel_config(head_server) for _ in range(n_reqs)
+    ]
+    api_process_counts = [c["_api_process_count"] for c in parallel_configs]
+    api_process_ranks = [c["_api_process_rank"] for c in parallel_configs]
+
+    assert all(c == api_server_count
+               for c in api_process_counts), api_process_counts
+    assert all(0 <= r < api_server_count
+               for r in api_process_ranks), api_process_ranks
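Note: unlike the external and hybrid LB tests above, this test only queries the head node's server, since in the multinode internal LB deployment that is where the client-facing API servers run; the `_api_process_count`/`_api_process_rank` checks are otherwise identical.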


 @pytest.mark.asyncio
 @pytest.mark.parametrize(
     "model_name",
vllm/config/parallel.py — 25 changes: 25 additions & 0 deletions

@@ -193,6 +193,25 @@ class is dynamically inherited by the worker class. This is used to inject
     not change by dcp, it simply reuse the GPUs of TP group, and tp_size
     needs to be divisible by dcp_size."""

+    _api_process_count: int = 1
+    """
+    The number of API processes initialized.
+
+    Note:
+        This is an internal config that is only valid for and
+        should only be set by API server scale-out.
+    """
+
+    _api_process_rank: int = 0
+    """
+    The rank of this API process, or `-1` for engine core processes
+    under API server scale-out.
+
+    Note:
+        This is an internal config that is only valid for and
+        should only be set by API server scale-out.
+    """
+
     @property
     def world_size_across_dp(self) -> int:
         """world_size_across_dp is TPxPPxDP, it is the size of the world
@@ -428,6 +447,12 @@ def __post_init__(self) -> None:
         if self.distributed_executor_backend is None and self.world_size == 1:
             self.distributed_executor_backend = "uni"

+        if not -1 <= self._api_process_rank < self._api_process_count:
+            raise ValueError(
+                "Invalid value of `_api_process_rank`. "
+                f"Expected to be `-1` or `[0, {self._api_process_count})`, "
+                f"but found: {self._api_process_rank}")
+
     @property
     def use_ray(self) -> bool:
         return self.distributed_executor_backend == "ray" or (
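To make the accepted range concrete, a small illustration of the `__post_init__` check above (hypothetical usage; assumes `ParallelConfig` is importable from `vllm.config` and constructible with defaults):

    from vllm.config import ParallelConfig

    ParallelConfig(_api_process_count=4, _api_process_rank=3)   # ok: rank in [0, 4)
    ParallelConfig(_api_process_count=4, _api_process_rank=-1)  # ok: engine-core sentinel
    ParallelConfig(_api_process_count=4, _api_process_rank=4)   # raises ValueError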
vllm/engine/arg_utils.py — 9 changes: 8 additions & 1 deletion

@@ -333,6 +333,8 @@ class EngineArgs:
     enable_eplb: bool = ParallelConfig.enable_eplb
     expert_placement_strategy: ExpertPlacementStrategy = \
         ParallelConfig.expert_placement_strategy
+    _api_process_count: int = ParallelConfig._api_process_count
+    _api_process_rank: int = ParallelConfig._api_process_rank
     num_redundant_experts: int = EPLBConfig.num_redundant_experts
     eplb_window_size: int = EPLBConfig.window_size
     eplb_step_interval: int = EPLBConfig.step_interval
@@ -951,7 +953,10 @@ def from_cli_args(cls, args: argparse.Namespace):
         # Get the list of attributes of this dataclass.
         attrs = [attr.name for attr in dataclasses.fields(cls)]
         # Set the attributes from the parsed arguments.
-        engine_args = cls(**{attr: getattr(args, attr) for attr in attrs})
+        engine_args = cls(**{
+            attr: getattr(args, attr)
+            for attr in attrs if hasattr(args, attr)
+        })
         return engine_args

     def create_model_config(self) -> ModelConfig:
@@ -1364,6 +1369,8 @@ def create_engine_config(
             worker_cls=self.worker_cls,
             worker_extension_cls=self.worker_extension_cls,
             decode_context_parallel_size=self.decode_context_parallel_size,
+            _api_process_count=self._api_process_count,
+            _api_process_rank=self._api_process_rank,
         )

         speculative_config = self.create_speculative_config(
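A minimal sketch of the behavior this enables, using an illustrative stand-in class rather than the real `EngineArgs`: dataclass fields absent from the namespace now simply fall back to their defaults, which is what lets examples/others/tensorize_vllm_model.py above pass its parsed namespace to `from_cli_args` directly.

    import argparse
    import dataclasses

    @dataclasses.dataclass
    class MiniEngineArgs:  # hypothetical stand-in for EngineArgs
        model: str = "facebook/opt-125m"
        _api_process_count: int = 1  # private field, never added to the CLI parser

        @classmethod
        def from_cli_args(cls, args: argparse.Namespace):
            attrs = [f.name for f in dataclasses.fields(cls)]
            # Fields missing from `args` are skipped and keep their defaults.
            return cls(**{a: getattr(args, a) for a in attrs if hasattr(args, a)})

    ns = argparse.Namespace(model="my-model", unrelated_flag=True)
    print(MiniEngineArgs.from_cli_args(ns))
    # MiniEngineArgs(model='my-model', _api_process_count=1)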