Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions crates/coglet-python/src/log_writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -615,11 +615,11 @@ mod tests {

let msg = rx.try_recv().unwrap();
match msg {
SlotResponse::Log { source, data } => {
SlotResponse::LogLine { source, data } => {
assert_eq!(source, LogSource::Stdout);
assert_eq!(data, "hello");
}
_ => panic!("expected Log message"),
_ => panic!("expected LogLine message"),
}
}

Expand Down
21 changes: 6 additions & 15 deletions crates/coglet/src/bridge/codec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -159,26 +159,17 @@ mod tests {
let mut codec = JsonCodec::<SlotResponse>::new();
let mut buf = BytesMut::new();

let resp = SlotResponse::Done {
id: "test".to_string(),
output: Some(serde_json::json!("result")),
predict_time: 1.5,
is_stream: false,
let resp = SlotResponse::OutputChunk {
output: serde_json::json!("result"),
index: 3,
};
codec.encode(resp, &mut buf).unwrap();
let decoded = codec.decode(&mut buf).unwrap().unwrap();

match decoded {
SlotResponse::Done {
id,
output,
predict_time,
is_stream,
} => {
assert_eq!(id, "test");
assert_eq!(output, Some(serde_json::json!("result")));
assert!((predict_time - 1.5).abs() < 0.001);
assert!(!is_stream);
SlotResponse::OutputChunk { output, index } => {
assert_eq!(output, serde_json::json!("result"));
assert_eq!(index, 3);
}
_ => panic!("wrong variant"),
}
Expand Down
69 changes: 59 additions & 10 deletions crates/coglet/src/bridge/protocol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -294,11 +294,27 @@ pub enum MetricMode {
Append,
}

/// Current slot response protocol version.
///
/// The response enum is already serde-tagged with `type`, so this constant and
/// the optional `ProtocolVersion` message give future protocol changes an
/// explicit marker without adding a second envelope around every message.
pub const SLOT_RESPONSE_PROTOCOL_VERSION: u32 = 1;

/// Messages from worker to parent on slot socket.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum SlotResponse {
Log {
/// Protocol version handshake message.
///
/// Intended to be sent by the worker when the slot connection opens so the
/// orchestrator can detect version mismatches and adjust behavior. Currently
/// nothing sends this; it is scaffolding for future protocol evolution.
ProtocolVersion {
version: u32,
},

LogLine {
source: LogSource,
data: String,
},
Expand All @@ -313,9 +329,10 @@ pub enum SlotResponse {
mime_type: Option<String>,
},

/// Streaming output chunk (for generators).
Output {
/// Streaming output chunk for generator and iterator output.
OutputChunk {
output: serde_json::Value,
index: u64,
},

/// User-emitted metric from the prediction.
Expand All @@ -338,7 +355,7 @@ pub enum SlotResponse {
output: Option<serde_json::Value>,
predict_time: f64,
/// Predictor signal: true when the output is a list, generator, or
/// iterator used as fallback when the schema Output type is `Any`
/// iterator, used as fallback when the schema Output type is `Any`
/// or unavailable.
#[serde(default, skip_serializing_if = "std::ops::Not::not")]
is_stream: bool,
Expand Down Expand Up @@ -499,20 +516,52 @@ mod tests {
}

#[test]
fn slot_log_serializes() {
let resp = SlotResponse::Log {
fn slot_log_line_serializes() {
let resp = SlotResponse::LogLine {
source: LogSource::Stdout,
data: "Processing...".to_string(),
};
insta::assert_json_snapshot!(resp);

assert_eq!(
serde_json::to_value(resp).unwrap(),
json!({
"type": "log_line",
"source": "stdout",
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The old snapshot files slot_log_serializes.snap and slot_output_serializes.snap are now orphaned since these tests were renamed and switched to inline assertions. They should be deleted.

"data": "Processing..."
})
);
}

#[test]
fn slot_output_serializes() {
let resp = SlotResponse::Output {
fn slot_output_chunk_serializes() {
let resp = SlotResponse::OutputChunk {
output: json!("chunk 1"),
index: 7,
};
insta::assert_json_snapshot!(resp);

assert_eq!(
serde_json::to_value(resp).unwrap(),
json!({
"type": "output_chunk",
"output": "chunk 1",
"index": 7
})
);
}

#[test]
fn slot_protocol_version_serializes() {
let resp = SlotResponse::ProtocolVersion {
version: SLOT_RESPONSE_PROTOCOL_VERSION,
};

assert_eq!(
serde_json::to_value(resp).unwrap(),
json!({
"type": "protocol_version",
"version": 1
})
);
}

#[test]
Expand Down

This file was deleted.

This file was deleted.

14 changes: 12 additions & 2 deletions crates/coglet/src/orchestrator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -953,7 +953,17 @@ async fn run_event_loop(

Some((slot_id, result)) = slot_msg_rx.recv() => {
match result {
Ok(SlotResponse::Log { source, data }) => {
Ok(SlotResponse::ProtocolVersion { version }) => {
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nothing in the worker sends ProtocolVersion yet -- this arm is dead code for now. A quick comment like // TODO: worker sends ProtocolVersion on slot connect (or wherever it'll be sent) would clarify the intent for future readers.

if version != crate::bridge::protocol::SLOT_RESPONSE_PROTOCOL_VERSION {
tracing::warn!(
%slot_id,
version,
expected = crate::bridge::protocol::SLOT_RESPONSE_PROTOCOL_VERSION,
"Worker reported unexpected slot response protocol version"
);
}
}
Ok(SlotResponse::LogLine { source, data }) => {
let (prediction_id, poisoned) = if let Some(pred) = predictions.get(&slot_id) {
if let Some(mut p) = try_lock_prediction(pred) {
p.append_log(&data);
Expand Down Expand Up @@ -1005,7 +1015,7 @@ async fn run_event_loop(
predictions.remove(&slot_id);
}
}
Ok(SlotResponse::Output { output }) => {
Ok(SlotResponse::OutputChunk { output, index: _ }) => {
let poisoned = if let Some(pred) = predictions.get(&slot_id) {
if let Some(mut p) = try_lock_prediction(pred) {
p.append_output(output);
Expand Down
32 changes: 26 additions & 6 deletions crates/coglet/src/worker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ use std::io;
use std::path::PathBuf;
use std::sync::Arc;
use std::sync::OnceLock;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};

use futures::{SinkExt, StreamExt};
use tokio::runtime::Handle;
Expand Down Expand Up @@ -133,7 +133,7 @@ fn init_worker_tracing(tx: mpsc::Sender<ControlResponse>) {
use crate::bridge::codec::JsonCodec;
use crate::bridge::protocol::{
ControlRequest, ControlResponse, FileOutputKind, LogSource, MAX_INLINE_IPC_SIZE, MetricMode,
SlotId, SlotOutcome, SlotRequest, SlotResponse,
SLOT_RESPONSE_PROTOCOL_VERSION, SlotId, SlotOutcome, SlotRequest, SlotResponse,
};
use crate::bridge::transport::{ChildTransportInfo, connect_transport};
use crate::orchestrator::HealthcheckResult;
Expand All @@ -151,6 +151,7 @@ pub struct SlotSender {
tx: mpsc::UnboundedSender<SlotResponse>,
output_dir: PathBuf,
file_counter: Arc<AtomicUsize>,
output_counter: Arc<AtomicU64>,
}

impl SlotSender {
Expand All @@ -159,9 +160,14 @@ impl SlotSender {
tx,
output_dir,
file_counter: Arc::new(AtomicUsize::new(0)),
output_counter: Arc::new(AtomicU64::new(0)),
}
}

fn next_output_index(&self) -> u64 {
self.output_counter.fetch_add(1, Ordering::Relaxed)
}

/// Generate a unique filename in the output dir.
fn next_output_path(&self, extension: &str) -> PathBuf {
let n = self.file_counter.fetch_add(1, Ordering::Relaxed);
Expand All @@ -173,7 +179,7 @@ impl SlotSender {
return Ok(());
}

let msg = SlotResponse::Log {
let msg = SlotResponse::LogLine {
source,
data: truncate_worker_log(data.to_string()),
};
Expand Down Expand Up @@ -232,7 +238,7 @@ impl SlotSender {

/// Send prediction output, either inline or spilled to disk if too large.
pub fn send_output(&self, output: serde_json::Value) -> io::Result<()> {
let msg = build_output_message(&self.output_dir, output)?;
let msg = build_output_message(&self.output_dir, output, self.next_output_index())?;
self.tx
.send(msg)
.map_err(|_| io::Error::new(io::ErrorKind::BrokenPipe, "slot channel closed"))
Expand All @@ -243,6 +249,7 @@ impl SlotSender {
fn build_output_message(
output_dir: &std::path::Path,
output: serde_json::Value,
index: u64,
) -> io::Result<SlotResponse> {
let serialized =
serde_json::to_vec(&output).map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
Expand All @@ -260,7 +267,7 @@ fn build_output_message(
mime_type: None,
})
} else {
Ok(SlotResponse::Output { output })
Ok(SlotResponse::OutputChunk { output, index })
}
}

Expand Down Expand Up @@ -652,6 +659,19 @@ pub async fn run_worker<H: PredictHandler>(
.map(|(id, w)| (id, Arc::new(tokio::sync::Mutex::new(w))))
.collect();

// Send protocol version on each slot so the orchestrator can detect mismatches
for (slot_id, writer) in &slot_writers {
let mut w = writer.lock().await;
if let Err(e) = w
.send(SlotResponse::ProtocolVersion {
version: SLOT_RESPONSE_PROTOCOL_VERSION,
})
.await
{
tracing::warn!(%slot_id, error = %e, "Failed to send protocol version");
}
}

// Main event loop
loop {
tokio::select! {
Expand Down Expand Up @@ -872,7 +892,7 @@ async fn run_prediction<H: PredictHandler>(
// Send output as a separate message (handles spilling for large values).
// Skip if null or empty array — those mean "already streamed" (generators).
if !output.is_null() && output != serde_json::Value::Array(vec![]) {
let output_msg = match build_output_message(&output_dir, output) {
let output_msg = match build_output_message(&output_dir, output, 0) {
Ok(msg) => msg,
Err(e) => {
tracing::error!(error = %e, "Failed to build output message");
Expand Down