Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions crates/coglet-python/src/predictor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -747,6 +747,22 @@ impl PythonPredictor {
return Ok(PredictionOutput::Single(serde_json::Value::Null));
}

// List/tuple output — iterate items so file outputs (Path, IOBase)
// go through the FileOutput IPC path for upload instead of being
// base64-encoded inline by process_output.
if let Ok(list) = result.cast::<pyo3::types::PyList>() {
for item in list.iter() {
send_output_item(py, &item, json_module, slot_sender)?;
}
return Ok(PredictionOutput::Stream(vec![]));
}
if let Ok(tuple) = result.cast::<pyo3::types::PyTuple>() {
for item in tuple.iter() {
send_output_item(py, &item, json_module, slot_sender)?;
}
return Ok(PredictionOutput::Stream(vec![]));
}

// Non-file output — process normally
let processed = output::process_output(py, result, None)
.map_err(|e| PredictionError::Failed(format!("Failed to process output: {}", e)))?;
Expand Down
7 changes: 6 additions & 1 deletion crates/coglet-python/src/worker_bridge.rs
Original file line number Diff line number Diff line change
Expand Up @@ -554,7 +554,12 @@ impl PredictHandler for PythonPredictHandler {

match result {
Ok(r) => {
PredictResult::success(output_to_json(r.output), start.elapsed().as_secs_f64())
let is_stream = r.output.is_stream();
PredictResult::success(
output_to_json(r.output),
start.elapsed().as_secs_f64(),
is_stream,
)
}
Err(e) => {
if matches!(e, coglet_core::PredictionError::Cancelled) {
Expand Down
3 changes: 3 additions & 0 deletions crates/coglet/src/bridge/codec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,7 @@ mod tests {
id: "test".to_string(),
output: Some(serde_json::json!("result")),
predict_time: 1.5,
is_stream: false,
};
codec.encode(resp, &mut buf).unwrap();
let decoded = codec.decode(&mut buf).unwrap().unwrap();
Expand All @@ -165,10 +166,12 @@ mod tests {
id,
output,
predict_time,
is_stream,
} => {
assert_eq!(id, "test");
assert_eq!(output, Some(serde_json::json!("result")));
assert!((predict_time - 1.5).abs() < 0.001);
assert!(!is_stream);
}
_ => panic!("wrong variant"),
}
Expand Down
6 changes: 6 additions & 0 deletions crates/coglet/src/bridge/protocol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,11 @@ pub enum SlotResponse {
#[serde(skip_serializing_if = "Option::is_none")]
output: Option<serde_json::Value>,
predict_time: f64,
/// Predictor signal: true when the output is a list, generator, or
/// iterator — used as fallback when the schema Output type is `Any`
/// or unavailable.
#[serde(default, skip_serializing_if = "std::ops::Not::not")]
is_stream: bool,
},

Failed {
Expand Down Expand Up @@ -436,6 +441,7 @@ mod tests {
id: "pred_123".to_string(),
output: Some(json!("final result")),
predict_time: 1.234,
is_stream: false,
};
insta::assert_json_snapshot!(resp);
}
Expand Down
233 changes: 217 additions & 16 deletions crates/coglet/src/orchestrator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,35 @@ fn try_lock_prediction(
}
}

/// Wrap collected output items into the correct `PredictionOutput` variant.
///
/// Priority:
/// 1. Schema says `"type": "array"` (`output_is_array = true`) → always `Stream`
/// 2. Predictor signals `is_stream` (list/generator return) → always `Stream`
/// 3. Otherwise → `Single` for one item, `Stream` for multiple
///
/// This ensures `List[Path]` with a single element returns `["url"]` not `"url"`.
fn wrap_outputs(
    mut outputs: Vec<serde_json::Value>,
    output_is_array: bool,
    is_stream: bool,
) -> PredictionOutput {
    // Either signal forces Stream, regardless of how many items arrived —
    // including zero (an empty list output must serialize as `[]`, not null).
    if output_is_array || is_stream {
        return PredictionOutput::Stream(outputs);
    }

    match outputs.len() {
        // Scalar output with nothing collected → JSON null.
        0 => PredictionOutput::Single(serde_json::Value::Null),
        // Take ownership of the lone item instead of cloning it out of the Vec.
        1 => PredictionOutput::Single(outputs.pop().expect("len checked == 1")),
        // Multiple items with neither flag set shouldn't happen for scalar
        // returns, but Stream is the safe fallback.
        _ => PredictionOutput::Stream(outputs),
    }
}

fn emit_worker_log(target: &str, level: &str, msg: &str) {
use std::collections::HashMap;
use std::sync::OnceLock;
Expand Down Expand Up @@ -543,6 +572,25 @@ pub async fn spawn_worker(
tracing::trace!(target: "coglet::schema", schema = %json, "OpenAPI schema");
}

// Determine whether the output type is an array from the schema so the
// event loop can correctly wrap single-element list returns as Stream
// instead of collapsing them to Single.
let output_is_array = schema
.as_ref()
.and_then(|s| s.get("components"))
.and_then(|c| c.get("schemas"))
.and_then(|schemas| {
let key = if config.is_train {
"TrainingOutput"
} else {
"Output"
};
schemas.get(key)
})
.and_then(|output| output.get("type"))
.and_then(|t| t.as_str())
.is_some_and(|t| t == "array");

let pool = Arc::new(PermitPool::new(num_slots));
let sockets = transport.drain_sockets();

Expand Down Expand Up @@ -585,6 +633,7 @@ pub async fn spawn_worker(
cancel_rx,
pool_for_loop,
upload_url,
output_is_array,
)
.await;
});
Expand Down Expand Up @@ -616,6 +665,10 @@ async fn run_event_loop(
mut cancel_rx: mpsc::Receiver<String>,
pool: Arc<PermitPool>,
upload_url: Option<String>,
// Schema says Output is "type": "array" — always wrap as Stream.
// When false, the schema was unavailable or Output type is Any; fall
// back to the predictor's is_stream flag on the Done message.
output_is_array: bool,
) {
let mut predictions: HashMap<SlotId, Arc<StdMutex<Prediction>>> = HashMap::new();
let mut idle_senders: HashMap<SlotId, tokio::sync::oneshot::Sender<SlotIdleToken>> =
Expand Down Expand Up @@ -997,11 +1050,13 @@ async fn run_event_loop(
}
}
}
Ok(SlotResponse::Done { id, output: _, predict_time }) => {
Ok(SlotResponse::Done { id, output: _, predict_time, is_stream }) => {
tracing::info!(
target: "coglet::prediction",
prediction_id = %id,
predict_time,
is_stream,
output_is_array,
"Prediction succeeded"
);
let uploads = pending_uploads.remove(&slot_id).unwrap_or_default();
Expand All @@ -1013,11 +1068,11 @@ async fn run_event_loop(
// registered waiters; spawning a task can fire the
// notification before the service registers its waiter.
if let Some(mut p) = try_lock_prediction(&pred) {
let pred_output = match p.take_outputs().as_slice() {
[] => PredictionOutput::Single(serde_json::Value::Null),
[single] => PredictionOutput::Single(single.clone()),
many => PredictionOutput::Stream(many.to_vec()),
};
let pred_output = wrap_outputs(
p.take_outputs(),
output_is_array,
is_stream,
);
p.set_succeeded(pred_output);
}
} else {
Expand All @@ -1040,11 +1095,6 @@ async fn run_event_loop(
prediction_id = %upload_pred_id,
"Aborting in-flight uploads due to cancellation"
);
// JoinAll drops the JoinHandles when it goes out of
// scope at the end of this branch, but JoinHandle::drop
// does NOT abort the spawned task. The upload tasks
// were already aborted by the cancel handler in the
// event loop (cancel_rx arm), so they will terminate.
if let Some(mut p) = try_lock_prediction(&pred) {
p.set_canceled();
}
Expand All @@ -1057,11 +1107,11 @@ async fn run_event_loop(
}
}
if let Some(mut p) = try_lock_prediction(&pred) {
let pred_output = match p.take_outputs().as_slice() {
[] => PredictionOutput::Single(serde_json::Value::Null),
[single] => PredictionOutput::Single(single.clone()),
many => PredictionOutput::Stream(many.to_vec()),
};
let pred_output = wrap_outputs(
p.take_outputs(),
output_is_array,
is_stream,
);
p.set_succeeded(pred_output);
}
});
Expand Down Expand Up @@ -1121,3 +1171,154 @@ async fn run_event_loop(

tracing::info!("Event loop exiting");
}

#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    // ── wrap_outputs: schema says array (output_is_array = true) ──

    #[test]
    fn wrap_outputs_schema_array_empty() {
        // List[Path] that returned no items → empty array
        let out = wrap_outputs(Vec::new(), true, true);
        assert!(out.is_stream());
        assert!(out.into_values().is_empty());
    }

    #[test]
    fn wrap_outputs_schema_array_single_item() {
        // List[Path] with num_outputs=1 → ["url"] not "url"
        let url = json!("https://example.com/img.png");
        let out = wrap_outputs(vec![url.clone()], true, true);
        assert!(out.is_stream());
        assert_eq!(out.into_values(), vec![url]);
    }

    #[test]
    fn wrap_outputs_schema_array_multiple_items() {
        // List[Path] with num_outputs=4
        let items: Vec<_> = (1..=4)
            .map(|i| json!(format!("https://example.com/{i}.png")))
            .collect();
        let out = wrap_outputs(items.clone(), true, true);
        assert!(out.is_stream());
        assert_eq!(out.into_values(), items);
    }

    #[test]
    fn wrap_outputs_schema_array_overrides_is_stream_false() {
        // Schema says array but predictor didn't set is_stream (shouldn't happen,
        // but schema is authoritative)
        assert!(wrap_outputs(vec![json!("url")], true, false).is_stream());
    }

    // ── wrap_outputs: predictor signal (is_stream = true, no schema) ──

    #[test]
    fn wrap_outputs_predictor_stream_empty() {
        // Generator that yielded nothing, no schema
        let out = wrap_outputs(Vec::new(), false, true);
        assert!(out.is_stream());
        assert!(out.into_values().is_empty());
    }

    #[test]
    fn wrap_outputs_predictor_stream_single_item() {
        // Any-typed list with one element, no schema
        let out = wrap_outputs(vec![json!("only_item")], false, true);
        assert!(out.is_stream());
        assert_eq!(out.into_values(), vec![json!("only_item")]);
    }

    #[test]
    fn wrap_outputs_predictor_stream_multiple_items() {
        // Generator yielding multiple, no schema
        let items = vec![json!("a"), json!("b"), json!("c")];
        let out = wrap_outputs(items.clone(), false, true);
        assert!(out.is_stream());
        assert_eq!(out.into_values(), items);
    }

    // ── wrap_outputs: scalar output (neither schema array nor predictor stream) ──

    #[test]
    fn wrap_outputs_scalar_empty() {
        // No items collected for a scalar return → JSON null
        let out = wrap_outputs(Vec::new(), false, false);
        assert!(!out.is_stream());
        assert_eq!(*out.final_value(), json!(null));
    }

    #[test]
    fn wrap_outputs_scalar_single() {
        // return Path("output.png") → single string
        let out = wrap_outputs(vec![json!("https://example.com/output.png")], false, false);
        assert!(!out.is_stream());
        assert_eq!(*out.final_value(), json!("https://example.com/output.png"));
    }

    #[test]
    fn wrap_outputs_scalar_multiple_falls_back_to_stream() {
        // Shouldn't happen for scalar returns, but if multiple items arrive
        // with neither flag set, Stream is the safe choice
        let items = vec![json!("a"), json!("b")];
        let out = wrap_outputs(items.clone(), false, false);
        assert!(out.is_stream());
        assert_eq!(out.into_values(), items);
    }

    // ── Serialization: is_stream field on Done message ──

    #[test]
    fn done_is_stream_false_omitted_from_json() {
        let encoded = serde_json::to_value(&SlotResponse::Done {
            id: "p1".into(),
            output: None,
            predict_time: 1.0,
            is_stream: false,
        })
        .unwrap();
        assert!(
            encoded.get("is_stream").is_none(),
            "is_stream=false should be omitted"
        );
    }

    #[test]
    fn done_is_stream_true_present_in_json() {
        let encoded = serde_json::to_value(&SlotResponse::Done {
            id: "p1".into(),
            output: None,
            predict_time: 1.0,
            is_stream: true,
        })
        .unwrap();
        assert_eq!(encoded.get("is_stream"), Some(&json!(true)));
    }

    #[test]
    fn done_without_is_stream_deserializes_as_false() {
        // Backward compat: old workers won't send is_stream
        let raw = json!({
            "type": "done",
            "id": "p1",
            "predict_time": 1.0
        });
        match serde_json::from_value::<SlotResponse>(raw).unwrap() {
            SlotResponse::Done { is_stream, .. } => assert!(!is_stream),
            _ => panic!("wrong variant"),
        }
    }
}
Loading