@@ -61,6 +61,55 @@ def regenerate_rids(self):
6161 return self .rids
6262
6363
@dataclass
class RequestTimingMetricsMixin:
    """Request-level timing metrics shared by the batch output types.

    Batch output dataclasses inherit from this mixin so that the timing
    fields are declared in exactly one place instead of being repeated on
    every output type.

    Every field is declared without a default, so each one is a required
    constructor argument of any dataclass that mixes this class in.  Each
    field is either ``None`` or a list whose entries may themselves be
    ``None``.
    NOTE(review): presumably each list holds one entry per request in the
    batch -- confirm against the code that populates these fields.
    """

    # Queue duration: time a request spent waiting in the queue before it
    # was scheduled.
    queue_time: Optional[List[Optional[float]]]

    # Forward entry time: timestamp at which the request enters the forward
    # pass stage.  Corresponds to `forward_entry_time` in TimeStats.
    # Meaning per deployment mode:
    #   - Unified / PD-colocate: forward computation begins
    #     (covers prefill + decode)
    #   - Prefill instance (P): prefill forward pass begins
    #   - Decode instance (D): decode forward pass begins
    # This is NOT the same as prefill_start_time; there may be a gap between
    # the two, which is what `prefill_delay` reports.
    forward_entry_time: Optional[List[Optional[float]]]

    # Prefill delay: wait between entering the forward stage and the moment
    # prefill computation actually begins.
    # Calculated as: prefill_start_time - forward_entry_time
    prefill_delay: Optional[List[Optional[float]]]

    # Prefill latency: time spent in prefill computation.
    # Calculated as: prefill_end_time - prefill_start_time
    prefill_latency: Optional[List[Optional[float]]]
95+
96+
@dataclass
class SpeculativeDecodingMetricsMixin:
    """Speculative decoding metrics shared by the batch output types.

    Batch output dataclasses that support speculative decoding inherit from
    this mixin so the counters are declared once rather than duplicated on
    each output type.

    Both fields are declared without a default, so each is a required
    constructor argument of any dataclass that mixes this class in.
    NOTE(review): presumably each list holds one entry per request in the
    batch -- confirm against the code that populates these fields.
    """

    # Verify count: number of verification forward passes.
    spec_verify_ct: List[int]

    # Accepted tokens: number of tokens accepted during speculative decoding.
    spec_accepted_tokens: List[int]
111+
112+
64113# Parameters for a session
65114@dataclass
66115class SessionParams :
@@ -148,6 +197,9 @@ class GenerateReqInput(BaseReq):
148197 bootstrap_room : Optional [Union [List [int ], int ]] = None
149198 bootstrap_pair_key : Optional [Union [List [str ], str ]] = None
150199
200+ # Validation step duration
201+ validation_time : Optional [float ] = None
202+
151203 # For data parallel rank routing
152204 data_parallel_rank : Optional [int ] = None
153205
@@ -564,6 +616,7 @@ def __getitem__(self, i):
564616 if self .bootstrap_pair_key is not None
565617 else None
566618 ),
619+ validation_time = self .validation_time ,
567620 data_parallel_rank = (
568621 self .data_parallel_rank if self .data_parallel_rank is not None else None
569622 ),
@@ -684,6 +737,8 @@ class EmbeddingReqInput(BaseReq):
684737 log_metrics : bool = True
685738 # The modalities of the image data [image, multi-images, video]
686739 modalities : Optional [List [str ]] = None
740+ # Validation step duration
741+ validation_time : Optional [float ] = None
687742 # For cross-encoder requests
688743 is_cross_encoder_request : bool = False
689744 # Priority for the request
@@ -774,6 +829,7 @@ def __getitem__(self, i):
774829 video_data = self .video_data [i ] if self .video_data is not None else None ,
775830 sampling_params = self .sampling_params [i ],
776831 rid = self .rid [i ],
832+ validation_time = self .validation_time ,
777833 dimensions = self .dimensions ,
778834 http_worker_ipc = self .http_worker_ipc ,
779835 )
@@ -815,7 +871,9 @@ def __iter__(self):
815871
816872
817873@dataclass
818- class BatchTokenIDOutput (BaseBatchReq ):
874+ class BatchTokenIDOutput (
875+ BaseBatchReq , RequestTimingMetricsMixin , SpeculativeDecodingMetricsMixin
876+ ):
819877 # The finish reason
820878 finished_reasons : List [BaseFinishReason ]
821879 # For incremental decoding
@@ -833,8 +891,6 @@ class BatchTokenIDOutput(BaseBatchReq):
833891 prompt_tokens : List [int ]
834892 completion_tokens : List [int ]
835893 cached_tokens : List [int ]
836- spec_verify_ct : List [int ]
837- spec_accepted_tokens : List [int ]
838894
839895 # Logprobs
840896 input_token_logprobs_val : List [float ]
@@ -868,7 +924,7 @@ class BatchTokenIDOutput(BaseBatchReq):
868924
869925
870926@dataclass
871- class BatchMultimodalDecodeReq (BaseBatchReq ):
927+ class BatchMultimodalDecodeReq (BaseBatchReq , RequestTimingMetricsMixin ):
872928 decoded_ids : List [int ]
873929 input_token_logprobs_val : List [float ]
874930 input_token_logprobs_idx : List [int ]
@@ -900,7 +956,9 @@ class BatchMultimodalDecodeReq(BaseBatchReq):
900956
901957
902958@dataclass
903- class BatchStrOutput (BaseBatchReq ):
959+ class BatchStrOutput (
960+ BaseBatchReq , RequestTimingMetricsMixin , SpeculativeDecodingMetricsMixin
961+ ):
904962 # The finish reason
905963 finished_reasons : List [dict ]
906964 # The output decoded strings
@@ -912,8 +970,6 @@ class BatchStrOutput(BaseBatchReq):
912970 prompt_tokens : List [int ]
913971 completion_tokens : List [int ]
914972 cached_tokens : List [int ]
915- spec_verify_ct : List [int ]
916- spec_accepted_tokens : List [int ]
917973
918974 # Logprobs
919975 input_token_logprobs_val : List [float ]
@@ -947,7 +1003,7 @@ class BatchStrOutput(BaseBatchReq):
9471003
9481004
9491005@dataclass
950- class BatchMultimodalOutput (BaseBatchReq ):
1006+ class BatchMultimodalOutput (BaseBatchReq , RequestTimingMetricsMixin ):
9511007 # The finish reason
9521008 finished_reasons : List [dict ]
9531009 decoded_ids : List [List [int ]]
@@ -972,7 +1028,7 @@ class BatchMultimodalOutput(BaseBatchReq):
9721028
9731029
9741030@dataclass
975- class BatchEmbeddingOutput (BaseBatchReq ):
1031+ class BatchEmbeddingOutput (BaseBatchReq , RequestTimingMetricsMixin ):
9761032 # The finish reason
9771033 finished_reasons : List [BaseFinishReason ]
9781034 # The output embedding
0 commit comments