Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
084059a
initial commit
inkcherry Nov 19, 2025
e60238f
update gitignore
inkcherry Nov 19, 2025
f34fbf4
fix with new main branch
inkcherry Nov 19, 2025
7a2b576
fix dp router
inkcherry Nov 19, 2025
4fac8b7
format
inkcherry Nov 19, 2025
88f7b39
add proxy example
inkcherry Nov 19, 2025
e3a7c9c
refine
inkcherry Nov 20, 2025
be0fdea
fix dp proxy
inkcherry Nov 20, 2025
b0b980b
refine code
inkcherry Nov 20, 2025
a4ae8ce
refine
inkcherry Nov 20, 2025
a17e2b8
refine
inkcherry Nov 20, 2025
b4ca43b
format
inkcherry Nov 20, 2025
773e071
format
inkcherry Nov 20, 2025
8682732
remove port
inkcherry Nov 20, 2025
00eb53b
format
inkcherry Nov 20, 2025
652dbe4
refine
inkcherry Nov 20, 2025
7d2f43d
fix format error
inkcherry Nov 20, 2025
af8dd1d
update
inkcherry Nov 20, 2025
2dfe859
fix mypy
inkcherry Nov 20, 2025
519fd04
fix mypy and tp test pass
inkcherry Nov 20, 2025
e32456a
fix all mypy
inkcherry Nov 20, 2025
fe262d9
break long line
inkcherry Nov 21, 2025
0af2637
fix format
inkcherry Nov 21, 2025
4d7a37b
fix all commit
inkcherry Nov 21, 2025
6e699f4
refine
inkcherry Nov 21, 2025
7731c76
update
inkcherry Nov 21, 2025
fcf0590
update
inkcherry Nov 21, 2025
ae3f6e0
update
inkcherry Nov 21, 2025
23b59a4
refine code
inkcherry Nov 21, 2025
1c0e39c
remove handle_proxy_request
inkcherry Nov 21, 2025
4e88ec1
more
inkcherry Nov 21, 2025
7622000
update
inkcherry Nov 21, 2025
2a56356
tp write single pass
inkcherry Nov 21, 2025
ad7b00d
update finished request collection
inkcherry Nov 21, 2025
5daf6ce
format
inkcherry Nov 24, 2025
03bc080
update lock
inkcherry Nov 24, 2025
3a5ea09
update proxy path
inkcherry Nov 24, 2025
ae6d9c3
format
inkcherry Nov 24, 2025
8478bb4
update
inkcherry Nov 24, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -227,3 +227,4 @@ ep_kernels_workspace/

# Allow tracked library source folders under submodules (e.g., benchmarks/lib)
!vllm/benchmarks/lib/
examples/online_serving/disaggregated_serving_p2p_moriio_xpyd/
Original file line number Diff line number Diff line change
@@ -0,0 +1,305 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import copy
import logging
import os
import re
import socket
import threading
import uuid

import aiohttp
import msgpack
import zmq
from quart import Quart, make_response, request

logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
# Registered prefill/decode vLLM instances; populated by the ZMQ listener
# thread, consumed by the request handler for round-robin selection.
prefill_instances: list[dict] = []
decode_instances: list[dict] = []
# Monotonic request counter used for round-robin instance/DP-rank selection.
request_nums = 0
app = Quart(__name__)

# Extracts ("ip", "port") from URLs of the form "http://1.2.3.4:8000/...".
IP_PORT_PATTERN = re.compile(r"//(\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}):(\d+)")


# KV transfer mode reported by the first registered instance (e.g. "READ");
# all later registrations must agree (enforced in _append_whole_dict_unique).
TRANSFER_TYPE = None


def _append_whole_dict_unique(target_list, data_dict):
    """Append ``data_dict`` to ``target_list`` unless an equal entry exists.

    Equality ignores the "index" key. On the first successful append the
    cluster-wide ``TRANSFER_TYPE`` is latched from the instance's reported
    ``transfer_mode``; later registrations must report the same mode.

    Returns:
        True if the dict was appended, False if it was a duplicate.

    Raises:
        ValueError: if the instance's transfer mode disagrees with the
            already-chosen ``TRANSFER_TYPE``.
    """
    global TRANSFER_TYPE

    new_filtered = {k: v for k, v in data_dict.items() if k != "index"}
    for existed in target_list:
        existed_filtered = {k: v for k, v in existed.items() if k != "index"}
        if existed_filtered == new_filtered:
            return False

    # Use the module logger rather than a bare print for consistency.
    logger.info("Registering new instance: %s", data_dict)
    target_list.append(data_dict)

    transfer_mode = data_dict.get("transfer_mode", "unknown")
    if TRANSFER_TYPE is None:
        TRANSFER_TYPE = transfer_mode
        logger.info("SET TRANSFER TYPE TO %s", TRANSFER_TYPE)
    elif transfer_mode != TRANSFER_TYPE:
        raise ValueError(f"mismatched transfer mode {TRANSFER_TYPE} vs {transfer_mode}")

    return True


# Guards the instance lists shared between the ZMQ listener thread and the
# request handlers.
_list_lock = threading.RLock()


def _listen_for_register(hostname, port):
    """Run forever, accepting instance registrations over a ZMQ ROUTER socket.

    Each message is a msgpack dict; "register" messages with role "P"/"D"
    are added to the prefill/decode instance lists. "HELLO" (and any other
    type) is ignored. Intended to run in a daemon thread.
    """
    context = zmq.Context()
    router_socket = context.socket(zmq.ROUTER)
    router_socket.bind(f"tcp://{hostname}:{port}")
    poller = zmq.Poller()
    poller.register(router_socket, zmq.POLLIN)

    while True:
        socks = dict(poller.poll())
        if router_socket not in socks:
            continue
        _remote_addr, msg = router_socket.recv_multipart()
        data = msgpack.loads(msg)
        if data["type"] != "register":
            # e.g. "HELLO" keep-alive messages — nothing to do.
            continue
        # NOTE: the original membership test compared a string address
        # against a list of dicts (always true); dedup is handled inside
        # _append_whole_dict_unique, so it is dropped here.
        if data["role"] == "P":
            with _list_lock:
                _append_whole_dict_unique(prefill_instances, data)
        elif data["role"] == "D":
            with _list_lock:
                _append_whole_dict_unique(decode_instances, data)


def start_service_discovery(hostname, port):
    """Start the ZMQ registration listener in a daemon thread.

    Falls back to the local hostname when ``hostname`` is empty.

    Returns:
        The started listener thread (daemon, never joins on its own).

    Raises:
        ValueError: if ``port`` is 0.
    """
    host = hostname if hostname else socket.gethostname()
    if port == 0:
        raise ValueError("Port cannot be 0")

    listener = threading.Thread(
        target=_listen_for_register, args=(host, port), daemon=True
    )
    listener.start()
    return listener


async def send_request_to_prefill(
    endpoint, req_data, request_id, p_endpoint, pip, pports, selected_prefill_dp_rank
):
    """POST the one-token prefill request to ``endpoint`` and return its JSON.

    The payload is rewritten so the prefill instance generates exactly one
    token and hands its KV cache off to the decode instance described by
    ``p_endpoint`` (reachable at ``pip:pports``). When
    ``selected_prefill_dp_rank`` is not None it is forwarded via the
    ``X-data-parallel-rank`` header.

    Raises:
        RuntimeError: on a non-200 response.
    """
    # Real copy so the caller's dict is not mutated as a side effect
    # (the original code aliased it despite the "_copy" name).
    req_data_copy = copy.deepcopy(req_data)

    req_data_copy["kv_transfer_params"].update(
        {
            "do_remote_decode": True,
            "do_remote_prefill": False,
            "remote_handshake_port": p_endpoint["handshake_port"],
            "remote_notify_port": p_endpoint["notify_port"],
            "remote_engine_id": None,
            "remote_block_ids": None,
            "remote_host": pip,
            "remote_port": pports,
        }
    )
    # Prefill emits a single token; decode streams the rest.
    req_data_copy["stream"] = False
    req_data_copy["max_tokens"] = 1
    if "max_completion_tokens" in req_data_copy:
        req_data_copy["max_completion_tokens"] = 1
    if "stream_options" in req_data_copy:
        del req_data_copy["stream_options"]

    async with aiohttp.ClientSession(
        timeout=aiohttp.ClientTimeout(total=6 * 6000 * 6000)
    ) as session:
        headers = {
            "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
            "X-Request-Id": request_id,
        }
        if selected_prefill_dp_rank is not None:
            headers["X-data-parallel-rank"] = str(selected_prefill_dp_rank)
        async with session.post(
            url=endpoint, json=req_data_copy, headers=headers
        ) as response:
            if response.status != 200:
                # Single formatted message (the original passed two args to
                # RuntimeError, yielding a garbled tuple message).
                raise RuntimeError(
                    f"send_request_to_prefill failed: status={response.status}"
                )
            return await response.json()


async def start_decode_request(endpoint, req_data, request_id):
    """Open a POST to the decode instance without closing the session.

    Returns:
        A ``(session, response)`` pair. The caller owns both and must close
        the session after consuming the response (see stream_decode_response).
    """
    auth_headers = {
        "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
        "X-Request-Id": request_id,
    }
    client = aiohttp.ClientSession(
        timeout=aiohttp.ClientTimeout(total=6 * 6000 * 6000)
    )
    resp = await client.post(url=endpoint, json=req_data, headers=auth_headers)
    return client, resp


async def stream_decode_response(session, response, request_id):
    """Yield the decode response body in 1 KiB chunks, then close the session.

    The session is closed in all cases (normal completion, error status,
    or the consumer abandoning the generator).

    Raises:
        RuntimeError: on a non-200 response status.
    """
    try:
        # Guard clause: bail out early on an error status.
        if response.status != 200:
            raise RuntimeError(
                f"decode response.status != 200, status = {response.status}"
            )
        async for piece in response.content.iter_chunked(1024):
            yield piece
    finally:
        await session.close()


async def send_request_to_decode(endpoint, req_data, request_id):
    """POST ``req_data`` to the decode instance, yielding 1 KiB body chunks.

    NOTE(review): currently unused — handle_request uses
    start_decode_request + stream_decode_response instead, so the session
    can outlive this coroutine's scope.

    Raises:
        RuntimeError: on a non-200 response.
    """
    async with aiohttp.ClientSession(
        timeout=aiohttp.ClientTimeout(total=6 * 6000 * 6000)
    ) as session:
        headers = {
            "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
            "X-Request-Id": request_id,
        }
        async with session.post(
            url=endpoint, json=req_data, headers=headers
        ) as response:
            if response.status != 200:
                # Single formatted message (the original passed two args to
                # RuntimeError and misspelled "status").
                raise RuntimeError(
                    f"send_request_to_decode failed: status={response.status}"
                )
            async for chunk_bytes in response.content.iter_chunked(1024):
                yield chunk_bytes


def example_round_robin_dp_loader(request_number, dp_size):
    """Return the DP rank for ``request_number`` under round-robin selection.

    Fix: the original ignored its ``request_number`` parameter and used the
    module-level ``request_nums`` counter instead, so callers passing a
    per-instance request index got a wrong (global) rank.
    """
    return request_number % dp_size


@app.route("/v1/completions", methods=["POST"])
@app.route("/v1/chat/completions", methods=["POST"])
async def handle_request():
    """Proxy an OpenAI-style request across disaggregated prefill/decode.

    Flow: pick a prefill and a decode instance round-robin, send a
    one-token prefill request, then stream the decode response back to the
    client. In "READ" transfer mode the prefill must finish first so its
    engine/block ids can be forwarded to decode; otherwise prefill runs
    concurrently (fire-and-forget).
    """
    global request_nums
    try:
        request_nums += 1

        def extract_ip_port_fast(url):
            # Pull ("ip", "port") out of a "http://ip:port/..." address.
            match = IP_PORT_PATTERN.search(url)
            if not match:
                raise ValueError(f"Invalid URL format: {url}")
            return match.groups()

        req_data = await request.get_json()
        request_id = str(uuid.uuid4())

        # Round-robin instance selection under the lock; also guards against
        # the ZeroDivisionError the original hit before any registration.
        with _list_lock:
            if not prefill_instances or not decode_instances:
                return {"error": "no prefill/decode instance registered"}, 503
            prefill_instance_endpoint = prefill_instances[
                request_nums % len(prefill_instances)
            ]
            decode_instance_endpoint = decode_instances[
                request_nums % len(decode_instances)
            ]

        selected_prefill_dp_rank = None
        if prefill_instance_endpoint["dp_size"] > 1:
            # Fix: divide by the number of prefill instances (each instance
            # sees every len(prefill_instances)-th request); the original
            # took len() of the endpoint dict itself.
            selected_prefill_dp_rank = example_round_robin_dp_loader(
                request_nums // len(prefill_instances),
                prefill_instance_endpoint["dp_size"],
            )

        dip, dport = extract_ip_port_fast(decode_instance_endpoint["request_address"])
        pip, pport = extract_ip_port_fast(prefill_instance_endpoint["request_address"])

        # Prefill gets its own copy carrying the decode side's parallel sizes.
        req_data_to_prefill = copy.deepcopy(req_data)
        req_data_to_prefill["kv_transfer_params"] = {
            "remote_dp_size": decode_instance_endpoint["dp_size"],
            "remote_tp_size": decode_instance_endpoint["tp_size"],
        }

        send_prefill_task = asyncio.create_task(
            send_request_to_prefill(
                prefill_instance_endpoint["request_address"],
                req_data_to_prefill,
                request_id,
                decode_instance_endpoint,
                dip,
                dport,
                selected_prefill_dp_rank,
            )
        )

        # Prefill produced exactly one token, so decode generates one fewer
        # (guarded: chat requests may omit max_tokens entirely).
        if "max_tokens" in req_data:
            req_data["max_tokens"] -= 1

        req_data["kv_transfer_params"] = {
            "do_remote_decode": False,
            "do_remote_prefill": True,
            "remote_handshake_port": prefill_instance_endpoint["handshake_port"],
            "remote_notify_port": prefill_instance_endpoint["notify_port"],
            "remote_engine_id": None,
            "remote_block_ids": None,
            "remote_host": pip,
            "remote_port": pport,
        }
        if TRANSFER_TYPE == "READ":
            # In read mode, prefill and decode are executed serially: decode
            # needs prefill's engine/block ids before it can pull KV blocks.
            prefill_response = await send_prefill_task
            prefill_kv = prefill_response["kv_transfer_params"]
            req_data["kv_transfer_params"]["remote_engine_id"] = prefill_kv[
                "remote_engine_id"
            ]
            req_data["kv_transfer_params"]["remote_block_ids"] = prefill_kv[
                "remote_block_ids"
            ]

        req_data["kv_transfer_params"]["remote_dp_size"] = prefill_instance_endpoint[
            "dp_size"
        ]
        req_data["kv_transfer_params"]["remote_tp_size"] = prefill_instance_endpoint[
            "tp_size"
        ]

        if selected_prefill_dp_rank is not None:
            req_data["kv_transfer_params"]["remote_dp_rank"] = selected_prefill_dp_rank

        session, decode_response = await start_decode_request(
            decode_instance_endpoint["request_address"], req_data, request_id
        )
        stream_generator = stream_decode_response(session, decode_response, request_id)
        return await make_response(stream_generator)
    except Exception:
        # Fix: the original printed the exception and fell through to an
        # implicit None return, which Quart rejects. Log and return a 500.
        logger.exception("proxy failed to handle request %d", request_nums)
        return {"error": "proxy failed to handle request"}, 500


if __name__ == "__main__":
    # Port 36367: ZMQ endpoint that prefill/decode instances register against.
    t = start_service_discovery("0.0.0.0", 36367)
    app.debug = True
    # Very generous timeouts: long prompts / large KV transfers can be slow.
    app.config["BODY_TIMEOUT"] = 360000
    app.config["RESPONSE_TIMEOUT"] = 360000

    # Port 10001: the OpenAI-compatible proxy endpoint clients talk to.
    app.run(host="0.0.0.0", port=10001)
    t.join()
6 changes: 6 additions & 0 deletions vllm/distributed/kv_transfer/kv_connector/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,12 @@ def get_connector_class(
"MultiConnector",
)

KVConnectorFactory.register_connector(
"MoRIIOConnector",
"vllm.distributed.kv_transfer.kv_connector.v1.moriio_connector",
"MoRIIOConnector",
)

KVConnectorFactory.register_connector(
"OffloadingConnector",
"vllm.distributed.kv_transfer.kv_connector.v1.offloading_connector",
Expand Down
Loading
Loading