Skip to content

Commit e719bb0

Browse files
authored
[1/2] Refactor multi-tokenizer manager (#10074)
1 parent 0672468 commit e719bb0

File tree

6 files changed

+424
-488
lines changed

6 files changed

+424
-488
lines changed

python/sglang/srt/entrypoints/engine.py

Lines changed: 24 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -704,6 +704,24 @@ def launch_phase_sigquit_handler(signum, frame):
704704
mp.set_start_method("spawn", force=True)
705705

706706

707+
def _init_tokenizer_manager(
    server_args: ServerArgs, port_args: PortArgs
) -> Tuple[TokenizerManager, TemplateManager]:
    """Create the tokenizer manager and its associated template manager.

    Args:
        server_args: Global server configuration.
        port_args: IPC/port configuration for inter-process communication.

    Returns:
        A ``(tokenizer_manager, template_manager)`` pair.
    """
    # Construct the tokenizer manager for this process.
    tokenizer_manager = TokenizerManager(server_args, port_args)

    # Initialize chat/completion templates against the loaded tokenizer.
    template_manager = TemplateManager()
    template_manager.initialize_templates(
        tokenizer_manager=tokenizer_manager,
        model_path=server_args.model_path,
        chat_template=server_args.chat_template,
        completion_template=server_args.completion_template,
    )

    # BUG FIX: the original annotated the return type as `TokenizerManager`,
    # but the function returns a 2-tuple (matching the caller's unpacking in
    # `_launch_subprocesses`); annotate it as Tuple[...] accordingly.
    return tokenizer_manager, template_manager
723+
724+
707725
def _launch_subprocesses(
708726
server_args: ServerArgs, port_args: Optional[PortArgs] = None
709727
) -> Tuple[TokenizerManager, TemplateManager, Dict]:
@@ -816,23 +834,15 @@ def _launch_subprocesses(
816834
),
817835
)
818836
detoken_proc.start()
837+
838+
# Init tokenizer manager first, as the bootstrap server is initialized here
819839
if server_args.tokenizer_worker_num > 1:
820840
# Launch multi-tokenizer router
821841
tokenizer_manager = MultiTokenizerRouter(server_args, port_args)
822-
823-
# Initialize templates
824842
template_manager = None
825843
else:
826-
# Launch tokenizer process
827-
tokenizer_manager = TokenizerManager(server_args, port_args)
828-
829-
# Initialize templates
830-
template_manager = TemplateManager()
831-
template_manager.initialize_templates(
832-
tokenizer_manager=tokenizer_manager,
833-
model_path=server_args.model_path,
834-
chat_template=server_args.chat_template,
835-
completion_template=server_args.completion_template,
844+
tokenizer_manager, template_manager = _init_tokenizer_manager(
845+
server_args, port_args
836846
)
837847

838848
# Wait for the model to finish loading
@@ -856,5 +866,7 @@ def _launch_subprocesses(
856866

857867
# Assume all schedulers have the same scheduler_info
858868
scheduler_info = scheduler_infos[0]
869+
859870
tokenizer_manager.max_req_input_len = scheduler_info["max_req_input_len"]
871+
860872
return tokenizer_manager, template_manager, scheduler_info

python/sglang/srt/entrypoints/http_server.py

Lines changed: 26 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,6 @@
9292
)
9393
from sglang.srt.managers.multi_tokenizer_mixin import (
9494
MultiTokenizerManager,
95-
deserialize_data,
9695
get_main_process_id,
9796
read_from_shared_memory,
9897
write_data_for_multi_tokenizer,
@@ -136,33 +135,22 @@ def set_global_state(global_state: _GlobalState):
136135
_global_state = global_state
137136

138137

139-
# Function to set up all middlewares for multi-tokenizer compatibility
140-
def setup_middlewares(api_key: Optional[str], enable_metrics: bool):
141-
"""Setup all middlewares for both single and multi-process modes"""
142-
worker_pid = os.getpid()
143-
144-
if api_key:
145-
add_api_key_middleware(app, api_key)
146-
logger.info(f"Worker {worker_pid} added API key middleware")
147-
148-
if enable_metrics:
149-
add_prometheus_middleware(app)
150-
enable_func_timer()
151-
logger.info(f"Worker {worker_pid} added prometheus middleware")
152-
153-
154138
async def init_multi_tokenizer() -> ServerArgs:
155139
"""Read args information from shm and init tokenizer manager for current process"""
156140
pid = os.getpid()
157141
main_pid = get_main_process_id()
158142
logger.info(f"current worker_id: {pid}, main processID: {main_pid}")
159143

160144
# Read configuration from shared memory
161-
port_args_data = read_from_shared_memory(f"port_args_{main_pid}")
162-
server_args_data = read_from_shared_memory(f"server_args_{main_pid}")
163-
scheduler_info_data = read_from_shared_memory(f"scheduler_info_{main_pid}")
164-
port_args, server_args = deserialize_data(port_args_data, server_args_data)
165-
scheduler_info = scheduler_info_data
145+
port_args, server_args, scheduler_info = read_from_shared_memory(
146+
f"multi_tokenizer_args_{main_pid}"
147+
)
148+
server_args: ServerArgs
149+
150+
# API key authentication is not supported in multi-tokenizer mode
151+
assert (
152+
server_args.api_key is None
153+
), "API key is not supported in multi-tokenizer mode"
166154

167155
port_args.tokenizer_ipc_name = (
168156
f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}"
@@ -193,13 +181,17 @@ async def init_multi_tokenizer() -> ServerArgs:
193181

194182
@asynccontextmanager
195183
async def lifespan(fast_api_app: FastAPI):
196-
server_args = getattr(fast_api_app, "server_args", None)
197-
if server_args is None:
184+
if not getattr(fast_api_app, "is_single_tokenizer_mode", False):
198185
# Initialize multi-tokenizer support for worker processes
199-
fast_api_app.server_args = await init_multi_tokenizer()
200-
setup_middlewares(
201-
fast_api_app.server_args.api_key, fast_api_app.server_args.enable_metrics
202-
)
186+
fast_api_app.server_args: ServerArgs = await init_multi_tokenizer()
187+
188+
# only metrics middleware is supported in multi-tokenizer mode
189+
worker_pid = os.getpid()
190+
if fast_api_app.server_args.enable_metrics:
191+
add_prometheus_middleware(app)
192+
enable_func_timer()
193+
194+
logger.info(f"Worker {worker_pid} added prometheus middleware")
203195
fast_api_app.warmup_thread = threading.Thread(
204196
target=_wait_and_warmup,
205197
args=(
@@ -1187,12 +1179,10 @@ def launch_server(
11871179
)
11881180

11891181
if server_args.tokenizer_worker_num > 1:
1190-
port_args_shm, server_args_shm, scheduler_info_shm = (
1191-
write_data_for_multi_tokenizer(
1192-
port_args,
1193-
server_args,
1194-
scheduler_info,
1195-
)
1182+
multi_tokenizer_args_shm = write_data_for_multi_tokenizer(
1183+
port_args,
1184+
server_args,
1185+
scheduler_info,
11961186
)
11971187
else:
11981188
# Add api key authorization
@@ -1239,6 +1229,7 @@ def launch_server(
12391229
workers=server_args.tokenizer_worker_num,
12401230
)
12411231
else:
1232+
app.is_single_tokenizer_mode = True
12421233
uvicorn.run(
12431234
app,
12441235
host=server_args.host,
@@ -1249,10 +1240,8 @@ def launch_server(
12491240
)
12501241
finally:
12511242
if server_args.tokenizer_worker_num > 1:
1252-
port_args_shm.unlink()
1253-
server_args_shm.unlink()
1254-
scheduler_info_shm.unlink()
1255-
_global_state.tokenizer_manager.clear_tokenizer_mapping()
1243+
multi_tokenizer_args_shm.unlink()
1244+
_global_state.tokenizer_manager.socket_mapping.clear_all_sockets()
12561245
else:
12571246
warmup_thread.join()
12581247

python/sglang/srt/managers/detokenizer_manager.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
FreezeGCReq,
3535
MultiTokenizerRegisterReq,
3636
)
37-
from sglang.srt.managers.multi_tokenizer_mixin import MultiTokenizerMixin
37+
from sglang.srt.managers.multi_tokenizer_mixin import MultiHttpWorkerDetokenizerMixin
3838
from sglang.srt.server_args import PortArgs, ServerArgs
3939
from sglang.srt.utils import (
4040
configure_logger,
@@ -69,7 +69,7 @@ class DecodeStatus:
6969
sent_offset: int = 0
7070

7171

72-
class DetokenizerManager(MultiTokenizerMixin):
72+
class DetokenizerManager(MultiHttpWorkerDetokenizerMixin):
7373
"""DetokenizerManager is a process that detokenizes the token ids."""
7474

7575
def __init__(
@@ -289,11 +289,11 @@ def run_detokenizer_process(
289289
try:
290290
manager = DetokenizerManager(server_args, port_args)
291291
if server_args.tokenizer_worker_num > 1:
292-
manager.multi_tokenizer_manager_event_loop()
292+
manager.multi_http_worker_event_loop()
293293
else:
294294
manager.event_loop()
295295
except Exception:
296-
manager.clear_tokenizer_mapping()
296+
manager.socket_mapping.clear_all_sockets()
297297
traceback = get_exception_traceback()
298298
logger.error(f"DetokenizerManager hit an exception: {traceback}")
299299
parent_process.send_signal(signal.SIGQUIT)
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
"""Start bootstrap/kv-store-related server"""
2+
3+
import os
4+
from typing import Type
5+
6+
from sglang.srt.disaggregation.base import BaseKVBootstrapServer
7+
from sglang.srt.disaggregation.utils import (
8+
DisaggregationMode,
9+
KVClassType,
10+
TransferBackend,
11+
get_kv_class,
12+
)
13+
from sglang.srt.server_args import ServerArgs
14+
15+
16+
def start_disagg_service(
    server_args: ServerArgs,
):
    """Start disaggregation-related services for this process.

    On the prefill side this launches the KV bootstrap server (and, for the
    Ascend backend, creates the shared mf config store on node rank 0).
    For other disaggregation modes nothing is started.

    Args:
        server_args: Global server configuration.

    Returns:
        The constructed ``BaseKVBootstrapServer`` on prefill, otherwise ``None``.

    Raises:
        RuntimeError: If creating the Ascend mf config store fails.
    """
    disagg_mode = DisaggregationMode(server_args.disaggregation_mode)
    transfer_backend = TransferBackend(server_args.disaggregation_transfer_backend)

    # Explicit default so the non-prefill path returns None rather than
    # relying on an implicit fall-through (avoids any NameError risk).
    bootstrap_server = None

    if disagg_mode == DisaggregationMode.PREFILL:
        # Only start the bootstrap server on the prefill tokenizer manager.
        kv_bootstrap_server_class: Type[BaseKVBootstrapServer] = get_kv_class(
            transfer_backend, KVClassType.BOOTSTRAP_SERVER
        )
        bootstrap_server = kv_bootstrap_server_class(
            host=server_args.host,
            port=server_args.disaggregation_bootstrap_port,
        )

        # The Ascend transfer backend additionally needs a shared config
        # store, created exactly once by node rank 0.
        is_create_store = (
            server_args.node_rank == 0 and transfer_backend == TransferBackend.ASCEND
        )
        if is_create_store:
            try:
                from mf_adapter import create_config_store

                # NOTE(review): presumably a required env var for the Ascend
                # backend; create_config_store receives None if unset — confirm.
                ascend_url = os.getenv("ASCEND_MF_STORE_URL")
                create_config_store(ascend_url)
            except Exception as e:
                # BUG FIX: the original did `raise error_message` with a str,
                # which itself raises "TypeError: exceptions must derive from
                # BaseException". Raise a real exception, chained to the cause.
                raise RuntimeError(
                    f"Failed create mf store, invalid ascend_url. With exception {e}"
                ) from e

    return bootstrap_server

0 commit comments

Comments
 (0)