Skip to content

Commit 4ef4390

Browse files
SYChen123, slin1237, and CatherineSue
authored
bugfix: multi-model routing for /generate api (#12979)
Co-authored-by: Simo Lin <linsimo.mark@gmail.com>
Co-authored-by: Chang Su <chang.s.su@oracle.com>
1 parent d646cf6 commit 4ef4390

File tree

5 files changed, with 40 additions and 28 deletions.

sgl-router/benches/request_processing.rs

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -33,6 +33,7 @@ fn get_bootstrap_info(worker: &BasicWorker) -> (String, Option<u16>) {
3333
fn default_generate_request() -> GenerateRequest {
3434
GenerateRequest {
3535
text: None,
36+
model: None,
3637
input_ids: None,
3738
input_embeds: None,
3839
image_data: None,

sgl-router/src/protocols/generate.rs

Lines changed: 8 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -21,6 +21,8 @@ pub struct GenerateRequest {
2121
#[serde(skip_serializing_if = "Option::is_none")]
2222
pub text: Option<String>,
2323

24+
pub model: Option<String>,
25+
2426
/// Input IDs for tokenized input
2527
#[serde(skip_serializing_if = "Option::is_none")]
2628
pub input_ids: Option<InputIds>,
@@ -201,8 +203,12 @@ impl GenerationRequest for GenerateRequest {
201203
}
202204

203205
fn get_model(&self) -> Option<&str> {
204-
// Generate requests typically don't have a model field
205-
None
206+
// Generate requests have an optional model field
207+
if let Some(s) = &self.model {
208+
Some(s.as_str())
209+
} else {
210+
None
211+
}
206212
}
207213

208214
fn extract_text_for_routing(&self) -> String {

sgl-router/src/routers/router_manager.rs

Lines changed: 14 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -350,12 +350,12 @@ impl RouterTrait for RouterManager {
350350
&self,
351351
headers: Option<&HeaderMap>,
352352
body: &GenerateRequest,
353-
_model_id: Option<&str>,
353+
model_id: Option<&str>,
354354
) -> Response {
355-
let router = self.select_router_for_request(headers, None);
355+
let router = self.select_router_for_request(headers, model_id);
356356

357357
if let Some(router) = router {
358-
router.route_generate(headers, body, None).await
358+
router.route_generate(headers, body, model_id).await
359359
} else {
360360
(
361361
StatusCode::NOT_FOUND,
@@ -369,12 +369,12 @@ impl RouterTrait for RouterManager {
369369
&self,
370370
headers: Option<&HeaderMap>,
371371
body: &ChatCompletionRequest,
372-
_model_id: Option<&str>,
372+
model_id: Option<&str>,
373373
) -> Response {
374-
let router = self.select_router_for_request(headers, Some(&body.model));
374+
let router = self.select_router_for_request(headers, model_id);
375375

376376
if let Some(router) = router {
377-
router.route_chat(headers, body, Some(&body.model)).await
377+
router.route_chat(headers, body, model_id).await
378378
} else {
379379
(
380380
StatusCode::NOT_FOUND,
@@ -388,14 +388,12 @@ impl RouterTrait for RouterManager {
388388
&self,
389389
headers: Option<&HeaderMap>,
390390
body: &CompletionRequest,
391-
_model_id: Option<&str>,
391+
model_id: Option<&str>,
392392
) -> Response {
393-
let router = self.select_router_for_request(headers, Some(&body.model));
393+
let router = self.select_router_for_request(headers, model_id);
394394

395395
if let Some(router) = router {
396-
router
397-
.route_completion(headers, body, Some(&body.model))
398-
.await
396+
router.route_completion(headers, body, model_id).await
399397
} else {
400398
(
401399
StatusCode::NOT_FOUND,
@@ -487,14 +485,12 @@ impl RouterTrait for RouterManager {
487485
&self,
488486
headers: Option<&HeaderMap>,
489487
body: &EmbeddingRequest,
490-
_model_id: Option<&str>,
488+
model_id: Option<&str>,
491489
) -> Response {
492-
let router = self.select_router_for_request(headers, Some(&body.model));
490+
let router = self.select_router_for_request(headers, model_id);
493491

494492
if let Some(router) = router {
495-
router
496-
.route_embeddings(headers, body, Some(&body.model))
497-
.await
493+
router.route_embeddings(headers, body, model_id).await
498494
} else {
499495
(
500496
StatusCode::NOT_FOUND,
@@ -510,7 +506,7 @@ impl RouterTrait for RouterManager {
510506
body: &RerankRequest,
511507
model_id: Option<&str>,
512508
) -> Response {
513-
let router = self.select_router_for_request(headers, None);
509+
let router = self.select_router_for_request(headers, model_id);
514510

515511
if let Some(router) = router {
516512
router.route_rerank(headers, body, model_id).await
@@ -529,7 +525,7 @@ impl RouterTrait for RouterManager {
529525
body: &ClassifyRequest,
530526
model_id: Option<&str>,
531527
) -> Response {
532-
let router = self.select_router_for_request(headers, Some(&body.model));
528+
let router = self.select_router_for_request(headers, model_id);
533529

534530
if let Some(router) = router {
535531
router.route_classify(headers, body, model_id).await

sgl-router/src/server.rs

Lines changed: 16 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -136,9 +136,10 @@ async fn generate(
136136
headers: http::HeaderMap,
137137
Json(body): Json<GenerateRequest>,
138138
) -> Response {
139+
let model_id = body.model.as_deref();
139140
state
140141
.router
141-
.route_generate(Some(&headers), &body, None)
142+
.route_generate(Some(&headers), &body, model_id)
142143
.await
143144
}
144145

@@ -147,7 +148,10 @@ async fn v1_chat_completions(
147148
headers: http::HeaderMap,
148149
ValidatedJson(body): ValidatedJson<ChatCompletionRequest>,
149150
) -> Response {
150-
state.router.route_chat(Some(&headers), &body, None).await
151+
state
152+
.router
153+
.route_chat(Some(&headers), &body, Some(&body.model))
154+
.await
151155
}
152156

153157
async fn v1_completions(
@@ -157,7 +161,7 @@ async fn v1_completions(
157161
) -> Response {
158162
state
159163
.router
160-
.route_completion(Some(&headers), &body, None)
164+
.route_completion(Some(&headers), &body, Some(&body.model))
161165
.await
162166
}
163167

@@ -166,17 +170,21 @@ async fn rerank(
166170
headers: http::HeaderMap,
167171
ValidatedJson(body): ValidatedJson<RerankRequest>,
168172
) -> Response {
169-
state.router.route_rerank(Some(&headers), &body, None).await
173+
state
174+
.router
175+
.route_rerank(Some(&headers), &body, Some(&body.model))
176+
.await
170177
}
171178

172179
async fn v1_rerank(
173180
State(state): State<Arc<AppState>>,
174181
headers: http::HeaderMap,
175182
Json(body): Json<V1RerankReqInput>,
176183
) -> Response {
184+
let rerank_body = &body.into();
177185
state
178186
.router
179-
.route_rerank(Some(&headers), &body.into(), None)
187+
.route_rerank(Some(&headers), rerank_body, Some(&rerank_body.model))
180188
.await
181189
}
182190

@@ -187,7 +195,7 @@ async fn v1_responses(
187195
) -> Response {
188196
state
189197
.router
190-
.route_responses(Some(&headers), &body, None)
198+
.route_responses(Some(&headers), &body, Some(&body.model))
191199
.await
192200
}
193201

@@ -198,7 +206,7 @@ async fn v1_embeddings(
198206
) -> Response {
199207
state
200208
.router
201-
.route_embeddings(Some(&headers), &body, None)
209+
.route_embeddings(Some(&headers), &body, Some(&body.model))
202210
.await
203211
}
204212

@@ -209,7 +217,7 @@ async fn v1_classify(
209217
) -> Response {
210218
state
211219
.router
212-
.route_classify(Some(&headers), &body, None)
220+
.route_classify(Some(&headers), &body, Some(&body.model))
213221
.await
214222
}
215223

sgl-router/tests/test_openai_routing.rs

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -602,6 +602,7 @@ async fn test_unsupported_endpoints() {
602602

603603
let generate_request = GenerateRequest {
604604
text: Some("Hello world".to_string()),
605+
model: None,
605606
input_ids: None,
606607
input_embeds: None,
607608
image_data: None,

0 commit comments

Comments
 (0)