
Commit 58095cb

cicirori and scottjlee authored
Add timing metrics for requests (#12646)
Co-authored-by: Scott Lee <scottjlee@users.noreply.github.com>
1 parent fd3034d commit 58095cb

File tree: 9 files changed, +334 −52 lines changed

9 files changed

+334
-52
lines changed

python/sglang/srt/entrypoints/openai/serving_base.py
Lines changed: 9 additions & 1 deletion

@@ -2,6 +2,7 @@
 
 import json
 import logging
+import time
 import uuid
 from abc import ABC, abstractmethod
 from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union
@@ -84,17 +85,23 @@ def _validate_lora_enabled(self, adapter_name: str) -> None:
     async def handle_request(
         self, request: OpenAIServingRequest, raw_request: Request
     ) -> Union[Any, StreamingResponse, ErrorResponse]:
-        """Handle the specific request type with common pattern"""
+        """Handle the specific request type with common pattern
+        If you want to override this method, you should be careful to record the validation time.
+        """
         try:
             # Validate request
+            validation_start = time.perf_counter()
             error_msg = self._validate_request(request)
+            validation_time = time.perf_counter() - validation_start
             if error_msg:
                 return self.create_error_response(error_msg)
 
             # Convert to internal format
             adapted_request, processed_request = self._convert_to_internal_request(
                 request, raw_request
             )
+            if hasattr(adapted_request, "validation_time"):
+                adapted_request.validation_time = validation_time
 
             # Note(Xinyuan): raw_request below is only used for detecting the connection of the client
             if hasattr(request, "stream") and request.stream:
@@ -157,6 +164,7 @@ def _convert_to_internal_request(
         self,
         request: OpenAIServingRequest,
         raw_request: Request = None,
+        validation_time: float = None,
     ) -> tuple[GenerateReqInput, OpenAIServingRequest]:
         """Convert OpenAI request to internal format"""
         pass
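
The pattern above brackets the validation stage with time.perf_counter() and attaches the elapsed duration to the converted request only when the target type declares the field. A minimal standalone sketch of the same pattern (the class and toy validator here are illustrative stand-ins, not SGLang's API):

```python
import time
from dataclasses import dataclass
from typing import Optional


@dataclass
class AdaptedRequest:
    # Stand-in for GenerateReqInput: any type that declares
    # `validation_time` can opt in to receiving the measurement.
    prompt: str
    validation_time: Optional[float] = None


def handle_request(prompt: str) -> AdaptedRequest:
    # Bracket validation with a monotonic clock.
    validation_start = time.perf_counter()
    error_msg = None if prompt.strip() else "empty prompt"  # toy validator
    validation_time = time.perf_counter() - validation_start

    if error_msg:
        raise ValueError(error_msg)

    adapted = AdaptedRequest(prompt=prompt)
    # Duck-typed hand-off: only set the field if the type defines it.
    if hasattr(adapted, "validation_time"):
        adapted.validation_time = validation_time
    return adapted
```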

python/sglang/srt/grpc/grpc_request_manager.py
Lines changed: 9 additions & 0 deletions

@@ -80,6 +80,10 @@ class GrpcReqState:
     last_time: float = 0.0
     last_completion_tokens: int = 1
 
+    # perf_counter equivalents for accurate time calculations
+    finished_time_perf: float = 0.0
+    first_token_time_perf: float = 0.0
+
     # Streaming state
     stream_finished: bool = False
     input_logprobs_sent: bool = False  # Track if input logprobs were sent in streaming
@@ -536,6 +540,7 @@ async def _handle_batch_output(self, batch_out: BatchTokenIDOutput):
         put_tasks = []
         cleanup_tasks = []
         now = time.time()
+        now_perf_counter = time.perf_counter()
 
         # Process each request in the batch
         for i, rid in enumerate(batch_out.rids):
@@ -552,6 +557,7 @@ async def _handle_batch_output(self, batch_out: BatchTokenIDOutput):
             # Update metrics
             if state.first_token_time == 0.0:
                 state.first_token_time = now
+                state.first_token_time_perf = now_perf_counter
             state.last_time = now
 
             # Extract output for this request
@@ -650,6 +656,7 @@ def get_part(attr_name):
             if output_data["finished"]:
                 state.finished = True
                 state.finished_time = now
+                state.finished_time_perf = now_perf_counter
                 state.stream_finished = True
                 state.event.set()
 
@@ -691,6 +698,7 @@ async def _handle_embedding_output(self, batch_out: BatchEmbeddingOutput):
         # Mark as finished
         state.finished = True
         state.finished_time = time.time()
+        state.finished_time_perf = time.perf_counter()
         state.event.set()
 
     async def _handle_health_check_output(self, health_out: HealthCheckOutput):
@@ -723,6 +731,7 @@ async def _handle_health_check_output(self, health_out: HealthCheckOutput):
         # Mark as finished
         state.finished = True
         state.finished_time = time.time()
+        state.finished_time_perf = time.perf_counter()
         state.event.set()
 
     async def _handle_abort_req(self, recv_obj: AbortReq):
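
Keeping a perf_counter twin next to each wall-clock timestamp is deliberate: time.time() values are loggable and comparable across processes, but they can jump when the system clock is adjusted, while time.perf_counter() is monotonic and therefore safe for duration arithmetic. A minimal sketch of the dual-clock pattern (the class below is an illustrative stand-in, not GrpcReqState itself):

```python
import time
from dataclasses import dataclass


@dataclass
class ReqState:
    # Wall-clock timestamps: loggable, cross-process comparable.
    first_token_time: float = 0.0
    finished_time: float = 0.0
    # Monotonic twins: used only for computing durations.
    first_token_time_perf: float = 0.0
    finished_time_perf: float = 0.0


def on_token(state: ReqState) -> None:
    if state.first_token_time == 0.0:
        state.first_token_time = time.time()
        state.first_token_time_perf = time.perf_counter()


def on_finish(state: ReqState) -> float:
    state.finished_time = time.time()
    state.finished_time_perf = time.perf_counter()
    # Derive the decode duration from the monotonic clock,
    # immune to NTP steps and daylight-saving jumps.
    return state.finished_time_perf - state.first_token_time_perf
```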

python/sglang/srt/managers/detokenizer_manager.py
Lines changed: 8 additions & 0 deletions

@@ -277,6 +277,10 @@ def handle_batch_token_id_out(self, recv_obj: BatchTokenIDOutput):
             placeholder_tokens_val=None,
             retraction_counts=recv_obj.retraction_counts,
             token_steps=recv_obj.token_steps,
+            queue_time=recv_obj.queue_time,
+            forward_entry_time=recv_obj.forward_entry_time,
+            prefill_delay=recv_obj.prefill_delay,
+            prefill_latency=recv_obj.prefill_latency,
         )
 
     def handle_multimodal_decode_req(self, recv_obj: BatchMultimodalDecodeReq):
@@ -291,6 +295,10 @@ def handle_multimodal_decode_req(self, recv_obj: BatchMultimodalDecodeReq):
             cached_tokens=recv_obj.cached_tokens,
             placeholder_tokens_idx=None,
             placeholder_tokens_val=None,
+            queue_time=recv_obj.queue_time,
+            forward_entry_time=recv_obj.forward_entry_time,
+            prefill_delay=recv_obj.prefill_delay,
+            prefill_latency=recv_obj.prefill_latency,
         )
 
     def handle_freeze_gc_req(self, recv_req: FreezeGCReq):
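
The detokenizer forwards all four timing fields verbatim because the mixin (introduced in io_struct.py below) declares them without defaults: once an output type inherits RequestTimingMetricsMixin, every construction site must supply them or fail at call time. A toy reproduction of that constraint, with illustrative class names:

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class TimingMixin:
    queue_time: Optional[List[Optional[float]]]  # no default on purpose


@dataclass
class StrOutput(TimingMixin):
    texts: List[str]


StrOutput(queue_time=None, texts=["ok"])  # fine: field supplied explicitly
try:
    StrOutput(texts=["ok"])  # forgetting the mixin field is a hard error
except TypeError as e:
    print(e)  # missing 1 required positional argument: 'queue_time'
```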

python/sglang/srt/managers/io_struct.py
Lines changed: 65 additions & 9 deletions

@@ -61,6 +61,55 @@ def regenerate_rids(self):
         return self.rids
 
 
+@dataclass
+class RequestTimingMetricsMixin:
+    """
+    Mixin class containing common request-level timing metrics.
+
+    This class consolidates the timing metrics that are shared across all batch output types
+    to avoid code duplication and ensure consistency.
+    """
+
+    # Queue duration: time spent waiting in queue before request is scheduled.
+    queue_time: Optional[List[Optional[float]]]
+
+    # Forward entry time: timestamp when the request enters the forward pass stage.
+    # This corresponds to `forward_entry_time` in TimeStats.
+    # In different modes:
+    # - Unified/PD-colocate: timestamp when forward computation begins (covers prefill + decode)
+    # - Prefill instance (P): timestamp when prefill forward pass begins
+    # - Decode instance (D): timestamp when decode forward pass begins
+    # Note: This is NOT the same as prefill_start_time. There may be a delay between
+    # forward_entry_time and prefill_start_time (see prefill_delay).
+    forward_entry_time: Optional[List[Optional[float]]]
+
+    # Prefill delay: time spent waiting between forward entry and prefill start.
+    # Calculated as: prefill_start_time - forward_entry_time
+    # This represents the delay between when the request enters the forward stage
+    # and when prefill computation actually begins.
+    prefill_delay: Optional[List[Optional[float]]]
+
+    # Prefill latency: time spent during prefill computation.
+    # Calculated as: prefill_end_time - prefill_start_time
+    prefill_latency: Optional[List[Optional[float]]]
+
+
+@dataclass
+class SpeculativeDecodingMetricsMixin:
+    """
+    Mixin class containing speculative decoding metrics.
+
+    This class consolidates speculative decoding metrics that are shared across
+    batch output types that support speculative decoding to avoid code duplication.
+    """
+
+    # Verify count: number of verification forward passes
+    spec_verify_ct: List[int]
+
+    # Accepted tokens: number of accepted tokens during speculative decoding
+    spec_accepted_tokens: List[int]
+
+
 # Parameters for a session
 @dataclass
 class SessionParams:
@@ -148,6 +197,9 @@ class GenerateReqInput(BaseReq):
     bootstrap_room: Optional[Union[List[int], int]] = None
     bootstrap_pair_key: Optional[Union[List[str], str]] = None
 
+    # Validation step duration
+    validation_time: Optional[float] = None
+
     # For data parallel rank routing
     data_parallel_rank: Optional[int] = None
 
@@ -564,6 +616,7 @@ def __getitem__(self, i):
                 if self.bootstrap_pair_key is not None
                 else None
             ),
+            validation_time=self.validation_time,
            data_parallel_rank=(
                 self.data_parallel_rank if self.data_parallel_rank is not None else None
             ),
@@ -684,6 +737,8 @@ class EmbeddingReqInput(BaseReq):
     log_metrics: bool = True
     # The modalities of the image data [image, multi-images, video]
     modalities: Optional[List[str]] = None
+    # Validation step duration
+    validation_time: Optional[float] = None
     # For cross-encoder requests
     is_cross_encoder_request: bool = False
     # Priority for the request
@@ -774,6 +829,7 @@ def __getitem__(self, i):
             video_data=self.video_data[i] if self.video_data is not None else None,
             sampling_params=self.sampling_params[i],
             rid=self.rid[i],
+            validation_time=self.validation_time,
             dimensions=self.dimensions,
             http_worker_ipc=self.http_worker_ipc,
         )
@@ -815,7 +871,9 @@ def __iter__(self):
 
 
 @dataclass
-class BatchTokenIDOutput(BaseBatchReq):
+class BatchTokenIDOutput(
+    BaseBatchReq, RequestTimingMetricsMixin, SpeculativeDecodingMetricsMixin
+):
     # The finish reason
     finished_reasons: List[BaseFinishReason]
     # For incremental decoding
@@ -833,8 +891,6 @@ class BatchTokenIDOutput(BaseBatchReq):
     prompt_tokens: List[int]
     completion_tokens: List[int]
     cached_tokens: List[int]
-    spec_verify_ct: List[int]
-    spec_accepted_tokens: List[int]
 
     # Logprobs
     input_token_logprobs_val: List[float]
@@ -868,7 +924,7 @@ class BatchTokenIDOutput(BaseBatchReq):
 
 
 @dataclass
-class BatchMultimodalDecodeReq(BaseBatchReq):
+class BatchMultimodalDecodeReq(BaseBatchReq, RequestTimingMetricsMixin):
     decoded_ids: List[int]
     input_token_logprobs_val: List[float]
     input_token_logprobs_idx: List[int]
@@ -900,7 +956,9 @@ class BatchMultimodalDecodeReq(BaseBatchReq):
 
 
 @dataclass
-class BatchStrOutput(BaseBatchReq):
+class BatchStrOutput(
+    BaseBatchReq, RequestTimingMetricsMixin, SpeculativeDecodingMetricsMixin
+):
     # The finish reason
     finished_reasons: List[dict]
     # The output decoded strings
@@ -912,8 +970,6 @@ class BatchStrOutput(BaseBatchReq):
     prompt_tokens: List[int]
     completion_tokens: List[int]
     cached_tokens: List[int]
-    spec_verify_ct: List[int]
-    spec_accepted_tokens: List[int]
 
     # Logprobs
     input_token_logprobs_val: List[float]
@@ -947,7 +1003,7 @@ class BatchStrOutput(BaseBatchReq):
 
 
 @dataclass
-class BatchMultimodalOutput(BaseBatchReq):
+class BatchMultimodalOutput(BaseBatchReq, RequestTimingMetricsMixin):
     # The finish reason
     finished_reasons: List[dict]
     decoded_ids: List[List[int]]
@@ -972,7 +1028,7 @@ class BatchMultimodalOutput(BaseBatchReq):
 
 
 @dataclass
-class BatchEmbeddingOutput(BaseBatchReq):
+class BatchEmbeddingOutput(BaseBatchReq, RequestTimingMetricsMixin):
     # The finish reason
     finished_reasons: List[BaseFinishReason]
     # The output embedding
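
One subtlety with these dataclass mixins: fields in the generated __init__ are collected in reverse MRO order, so the last-listed mixin's fields come first and the subclass's own fields last. That makes keyword construction (as seen in multi_tokenizer_mixin.py below) much safer than positional. A minimal sketch with illustrative stand-in classes, not the actual sglang types:

```python
from dataclasses import dataclass, fields
from typing import List, Optional


@dataclass
class TimingMixin:
    queue_time: Optional[List[Optional[float]]]
    prefill_latency: Optional[List[Optional[float]]]


@dataclass
class SpecMixin:
    spec_verify_ct: List[int]


@dataclass
class BaseBatch:
    rids: List[str]


@dataclass
class TokenOutput(BaseBatch, TimingMixin, SpecMixin):
    finished_reasons: List[dict]


# Reverse MRO ordering: SpecMixin's fields first, own fields last.
print([f.name for f in fields(TokenOutput)])
# ['spec_verify_ct', 'queue_time', 'prefill_latency', 'rids', 'finished_reasons']
```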

python/sglang/srt/managers/multi_tokenizer_mixin.py
Lines changed: 40 additions & 11 deletions

@@ -91,6 +91,26 @@ def _handle_output_by_index(output, i):
     if isinstance(output, BatchTokenIDOutput):
         new_output = BatchTokenIDOutput(
             rids=[output.rids[i]],
+            spec_verify_ct=(
+                [output.spec_verify_ct[i]] if len(output.spec_verify_ct) > i else None
+            ),
+            spec_accepted_tokens=(
+                [output.spec_accepted_tokens[i]]
+                if len(output.spec_accepted_tokens) > i
+                else None
+            ),
+            queue_time=[output.queue_time[i]] if len(output.queue_time) > i else None,
+            forward_entry_time=(
+                [output.forward_entry_time[i]]
+                if len(output.forward_entry_time) > i
+                else None
+            ),
+            prefill_delay=(
+                [output.prefill_delay[i]] if len(output.prefill_delay) > i else None
+            ),
+            prefill_latency=(
+                [output.prefill_latency[i]] if len(output.prefill_latency) > i else None
+            ),
             finished_reasons=(
                 [output.finished_reasons[i]]
                 if len(output.finished_reasons) > i
@@ -132,9 +152,6 @@ def _handle_output_by_index(output, i):
             cached_tokens=(
                 [output.cached_tokens[i]] if len(output.cached_tokens) > i else None
             ),
-            spec_verify_ct=(
-                [output.spec_verify_ct[i]] if len(output.spec_verify_ct) > i else None
-            ),
             input_token_logprobs_val=(
                 [output.input_token_logprobs_val[i]]
                 if output.input_token_logprobs_val
@@ -230,6 +247,26 @@ def _handle_output_by_index(output, i):
     elif isinstance(output, BatchStrOutput):
         new_output = BatchStrOutput(
             rids=[output.rids[i]],
+            spec_verify_ct=(
+                [output.spec_verify_ct[i]] if len(output.spec_verify_ct) > i else None
+            ),
+            spec_accepted_tokens=(
+                [output.spec_accepted_tokens[i]]
+                if len(output.spec_accepted_tokens) > i
+                else None
+            ),
+            queue_time=[output.queue_time[i]] if len(output.queue_time) > i else None,
+            forward_entry_time=(
+                [output.forward_entry_time[i]]
+                if len(output.forward_entry_time) > i
+                else None
+            ),
+            prefill_delay=(
+                [output.prefill_delay[i]] if len(output.prefill_delay) > i else None
+            ),
+            prefill_latency=(
+                [output.prefill_latency[i]] if len(output.prefill_latency) > i else None
+            ),
             finished_reasons=(
                 [output.finished_reasons[i]]
                 if len(output.finished_reasons) > i
@@ -254,14 +291,6 @@ def _handle_output_by_index(output, i):
             cached_tokens=(
                 [output.cached_tokens[i]] if len(output.cached_tokens) > i else None
             ),
-            spec_verify_ct=(
-                [output.spec_verify_ct[i]] if len(output.spec_verify_ct) > i else None
-            ),
-            spec_accepted_tokens=(
-                [output.spec_accepted_tokens[i]]
-                if len(output.spec_accepted_tokens) > i
-                else None
-            ),
             input_token_logprobs_val=(
                 [output.input_token_logprobs_val[i]]
                 if output.input_token_logprobs_val
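
Every new field above repeats the same `[xs[i]] if len(xs) > i else None` guard: it re-wraps the i-th element as a single-item batch, and degrades to None when the source list was never populated (for example, the speculative-decoding counters when spec decoding is disabled). A hypothetical helper that captures the pattern the diff inlines per field:

```python
from typing import List, Optional, TypeVar

T = TypeVar("T")


def slice_or_none(values: List[T], i: int) -> Optional[List[T]]:
    # Keep the batched shape (a one-element list) for request i,
    # or return None when the metric list is empty or too short.
    return [values[i]] if len(values) > i else None


assert slice_or_none([0.12, 0.34], 1) == [0.34]
assert slice_or_none([], 1) is None
```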

python/sglang/srt/managers/scheduler.py
Lines changed: 15 additions & 1 deletion

@@ -152,6 +152,7 @@
 from sglang.srt.mem_cache.mamba_radix_cache import MambaRadixCache
 from sglang.srt.mem_cache.radix_cache import RadixCache
 from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
+from sglang.srt.model_executor.forward_batch_info import ForwardMode
 from sglang.srt.multiplex.multiplexing_mixin import SchedulerMultiplexMixin
 from sglang.srt.parser.reasoning_parser import ReasoningParser
 from sglang.srt.server_args import PortArgs, ServerArgs, get_global_server_args
@@ -1952,6 +1953,12 @@ def run_batch(
             logger.info(f"Scheduler.run_batch sleep {self.forward_sleep_time}s")
             time.sleep(self.forward_sleep_time)
 
+        # Capture prefill start time for EXTEND mode
+        if batch.forward_mode == ForwardMode.EXTEND:
+            current_time = time.perf_counter()
+            for req in batch.reqs:
+                req.time_stats.prefill_start_time = current_time
+
         # Run forward
         if self.is_generation:
             batch_or_worker_batch = batch
@@ -2045,11 +2052,18 @@ def run_batch(
                 batch_result.extend_logprob_start_len_per_req = (
                     extend_logprob_start_len_per_req
                 )
-            return batch_result
+            ret = batch_result
         else:  # embedding or reward model
             model_worker_batch = batch.get_model_worker_batch()
             embeddings = self.tp_worker.forward_batch_embedding(model_worker_batch)
             ret = EmbeddingBatchResult(embeddings=embeddings)
+
+        # Capture prefill end time for EXTEND mode
+        if batch.forward_mode == ForwardMode.EXTEND:
+            current_time = time.perf_counter()
+            for req in batch.reqs:
+                req.time_stats.prefill_end_time = current_time
+
         return ret
 
     def launch_batch_sample_if_needed(
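
With prefill_start_time and prefill_end_time now captured around the EXTEND forward pass, the derived metrics documented in io_struct.py fall out as simple differences between perf_counter timestamps. A sketch under the assumption that TimeStats holds the timestamps named in this PR (the queue-entry field name below is an assumption for illustration, not taken from the diff):

```python
from dataclasses import dataclass


@dataclass
class TimeStats:
    # perf_counter timestamps along the request lifecycle
    queue_entry_time: float = 0.0    # assumed name: when the request was queued
    forward_entry_time: float = 0.0  # entered the forward stage
    prefill_start_time: float = 0.0  # prefill computation began
    prefill_end_time: float = 0.0    # prefill computation finished


def derived_metrics(ts: TimeStats) -> dict:
    return {
        # Waiting in the scheduler queue before being scheduled
        "queue_time": ts.forward_entry_time - ts.queue_entry_time,
        # Gap between entering the forward stage and prefill actually starting
        "prefill_delay": ts.prefill_start_time - ts.forward_entry_time,
        # Pure prefill computation time
        "prefill_latency": ts.prefill_end_time - ts.prefill_start_time,
    }
```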
