diff --git a/apps/skit/src/mcp.rs b/apps/skit/src/mcp.rs index 7ebba8cb..89d15909 100644 --- a/apps/skit/src/mcp.rs +++ b/apps/skit/src/mcp.rs @@ -103,6 +103,32 @@ pub struct SessionIdArgs { pub session_id: String, } +#[derive(Debug, Deserialize, schemars::JsonSchema)] +pub struct ValidateBatchArgs { + /// Session ID or name. + pub session_id: String, + /// List of batch operations to validate. + pub operations: Vec, +} + +#[derive(Debug, Deserialize, schemars::JsonSchema)] +pub struct ApplyBatchArgs { + /// Session ID or name. + pub session_id: String, + /// List of batch operations to apply atomically. + pub operations: Vec, +} + +#[derive(Debug, Deserialize, schemars::JsonSchema)] +pub struct TuneNodeArgs { + /// Session ID or name. + pub session_id: String, + /// Node ID to send the control message to. + pub node_id: String, + /// The control message (e.g., UpdateParams with a JSON value). + pub message: streamkit_core::control::NodeControlMessage, +} + // --------------------------------------------------------------------------- // MCP prompt argument structs // --------------------------------------------------------------------------- @@ -623,6 +649,181 @@ impl StreamKitMcp { .map_err(|e| McpError::internal_error(format!("serialization error: {e}"), None))?, )])) } + + // -- validate_batch ---------------------------------------------------- + + #[tool( + description = "Validate a batch of graph mutations against a running session without applying them. Returns validation errors for any operations that would fail. Operations: addnode, removenode, connect, disconnect." + )] + async fn validate_batch( + &self, + Parameters(args): Parameters, + ctx: RequestContext, + ) -> Result { + let (role_name, perms) = extract_auth(&ctx, &self.app_state)?; + + if !perms.modify_sessions { + return Err(McpError::invalid_request( + "Permission denied: modify_sessions required", + None, + )); + } + + let session = { + let sm = self.app_state.session_manager.lock().await; + sm.get_session_by_name_or_id(&args.session_id) + }; + + let Some(session) = session else { + return Err(McpError::invalid_params( + format!("Session '{}' not found", args.session_id), + None, + )); + }; + + if !perms.access_all_sessions + && session.created_by.as_ref().is_some_and(|c| c != &role_name) + { + return Err(McpError::invalid_request( + "Permission denied: you do not own this session", + None, + )); + } + + let errors = crate::server::validate_batch_operations( + &session, + &args.operations, + &perms, + &self.app_state.config.security, + ) + .await; + + info!( + session_id = %args.session_id, + operation_count = args.operations.len(), + error_count = errors.len(), + "MCP validate_batch" + ); + + let json = serde_json::to_string_pretty(&errors) + .map_err(|e| McpError::internal_error(format!("serialization error: {e}"), None))?; + + Ok(CallToolResult::success(vec![Content::text(json)])) + } + + // -- apply_batch ------------------------------------------------------- + + #[tool( + description = "Apply a batch of graph mutations atomically to a running session. All operations succeed or all fail together. Operations: addnode, removenode, connect, disconnect." + )] + async fn apply_batch( + &self, + Parameters(args): Parameters, + ctx: RequestContext, + ) -> Result { + let (role_name, perms) = extract_auth(&ctx, &self.app_state)?; + + if !perms.modify_sessions { + return Err(McpError::invalid_request( + "Permission denied: modify_sessions required", + None, + )); + } + + let session = { + let sm = self.app_state.session_manager.lock().await; + sm.get_session_by_name_or_id(&args.session_id) + }; + + let Some(session) = session else { + return Err(McpError::invalid_params( + format!("Session '{}' not found", args.session_id), + None, + )); + }; + + if !perms.access_all_sessions + && session.created_by.as_ref().is_some_and(|c| c != &role_name) + { + return Err(McpError::invalid_request( + "Permission denied: you do not own this session", + None, + )); + } + + crate::server::apply_batch_operations( + &session, + args.operations, + &perms, + &self.app_state.config.security, + ) + .await + .map_err(|e| McpError::invalid_params(e, None))?; + + info!(session_id = %args.session_id, "MCP apply_batch"); + + let result = serde_json::json!({ "success": true }); + Ok(CallToolResult::success(vec![Content::text( + serde_json::to_string_pretty(&result) + .map_err(|e| McpError::internal_error(format!("serialization error: {e}"), None))?, + )])) + } + + // -- tune_node --------------------------------------------------------- + + #[tool( + description = "Send a control message to a specific node in a running session. Commonly used with UpdateParams to modify node parameters at runtime." + )] + async fn tune_node( + &self, + Parameters(args): Parameters, + ctx: RequestContext, + ) -> Result { + let (role_name, perms) = extract_auth(&ctx, &self.app_state)?; + + if !perms.tune_nodes { + return Err(McpError::invalid_request("Permission denied: tune_nodes required", None)); + } + + let session = { + let sm = self.app_state.session_manager.lock().await; + sm.get_session_by_name_or_id(&args.session_id) + }; + + let Some(session) = session else { + return Err(McpError::invalid_params( + format!("Session '{}' not found", args.session_id), + None, + )); + }; + + if !perms.access_all_sessions + && session.created_by.as_ref().is_some_and(|c| c != &role_name) + { + return Err(McpError::invalid_request( + "Permission denied: you do not own this session", + None, + )); + } + + crate::server::tune_session_node( + &session, + args.node_id.clone(), + args.message, + &self.app_state.config.security, + &self.app_state.event_tx, + ) + .await + .map_err(|e| McpError::invalid_params(e, None))?; + + info!(session_id = %args.session_id, node_id = %args.node_id, "MCP tune_node"); + + let result = serde_json::json!({ "success": true }); + Ok(CallToolResult::success(vec![Content::text( + serde_json::to_string_pretty(&result) + .map_err(|e| McpError::internal_error(format!("serialization error: {e}"), None))?, + )])) + } } // --------------------------------------------------------------------------- @@ -1082,10 +1283,13 @@ impl ServerHandler for StreamKitMcp { "StreamKit MCP server. Use list_nodes to discover available \ processing nodes, validate_pipeline to check YAML, \ create_session / list_sessions / get_pipeline / destroy_session \ - to manage dynamic pipeline sessions, and \ + to manage dynamic pipeline sessions, \ generate_oneshot_command to get a curl or skit-cli command for \ - batch processing via the HTTP data plane. Two built-in prompts \ - are available: design_pipeline (guided pipeline creation) and \ + batch processing via the HTTP data plane, validate_batch and \ + apply_batch to atomically mutate a running session's graph, \ + and tune_node to send control messages (e.g. UpdateParams) to \ + individual nodes at runtime. Two built-in prompts are available: \ + design_pipeline (guided pipeline creation) and \ debug_pipeline (session diagnostics).", ); info.server_info = rmcp::model::Implementation::new("streamkit", env!("CARGO_PKG_VERSION")); diff --git a/apps/skit/src/server/mod.rs b/apps/skit/src/server/mod.rs index 0880df23..e1349d9f 100644 --- a/apps/skit/src/server/mod.rs +++ b/apps/skit/src/server/mod.rs @@ -4004,6 +4004,286 @@ pub fn check_file_path_security( } } +/// Validate a batch of operations against a session's pipeline without applying. +/// +/// Returns a list of validation errors. An empty list means all operations +/// are valid. Callers must perform session-level permission and ownership +/// checks before calling this function. +pub async fn validate_batch_operations( + session: &crate::session::Session, + operations: &[streamkit_api::BatchOperation], + perms: &crate::permissions::Permissions, + security_config: &crate::config::SecurityConfig, +) -> Vec { + let mut errors: Vec = Vec::new(); + + // Pre-validate duplicate node_ids against the pipeline model, simulating + // the Add/Remove sequence so that Remove→Add for the same ID within the + // batch is allowed, but duplicate Adds are rejected. + let mut live_ids: std::collections::HashSet = + session.pipeline.lock().await.nodes.keys().cloned().collect(); + for op in operations { + match op { + streamkit_api::BatchOperation::AddNode { node_id, .. } => { + if !live_ids.insert(node_id.clone()) { + errors.push(streamkit_api::ValidationError { + error_type: streamkit_api::ValidationErrorType::Error, + message: format!( + "Batch rejected: node '{node_id}' already exists in the pipeline" + ), + node_id: Some(node_id.clone()), + connection_id: None, + }); + } + }, + streamkit_api::BatchOperation::RemoveNode { node_id } => { + live_ids.remove(node_id.as_str()); + }, + _ => {}, + } + } + + // Validate all AddNode operations against permission and security rules. + for op in operations { + if let streamkit_api::BatchOperation::AddNode { node_id, kind, params, .. } = op { + if let Some(message) = crate::websocket_handlers::validate_add_node_op( + kind, + params.as_ref(), + perms, + security_config, + ) { + errors.push(streamkit_api::ValidationError { + error_type: streamkit_api::ValidationErrorType::Error, + message, + node_id: Some(node_id.clone()), + connection_id: None, + }); + } + } + } + + errors +} + +/// Apply a batch of graph mutations atomically to a running session. +/// +/// Returns `Ok(())` on success, or `Err(message)` if pre-validation fails +/// (e.g. duplicate node IDs or forbidden node kinds). Callers must perform +/// session-level permission and ownership checks before calling this function. +/// +/// # Errors +/// +/// Returns an error string when a batch operation fails pre-validation +/// (duplicate node IDs or forbidden node kinds). +pub async fn apply_batch_operations( + session: &crate::session::Session, + operations: Vec, + perms: &crate::permissions::Permissions, + security_config: &crate::config::SecurityConfig, +) -> Result<(), String> { + // Pre-validate duplicate node_ids. + { + let mut live_ids: std::collections::HashSet = + session.pipeline.lock().await.nodes.keys().cloned().collect(); + for op in &operations { + match op { + streamkit_api::BatchOperation::AddNode { node_id, .. } => { + if !live_ids.insert(node_id.clone()) { + return Err(format!( + "Batch rejected: node '{node_id}' already exists in the pipeline" + )); + } + }, + streamkit_api::BatchOperation::RemoveNode { node_id } => { + live_ids.remove(node_id.as_str()); + }, + _ => {}, + } + } + } + + // Validate permissions for all AddNode operations. + for op in &operations { + if let streamkit_api::BatchOperation::AddNode { kind, params, .. } = op { + if let Some(message) = crate::websocket_handlers::validate_add_node_op( + kind, + params.as_ref(), + perms, + security_config, + ) { + return Err(message); + } + } + } + + // Apply all operations in order. + let mut engine_operations = Vec::new(); + { + let mut pipeline = session.pipeline.lock().await; + for op in operations { + match op { + streamkit_api::BatchOperation::AddNode { node_id, kind, params } => { + pipeline.nodes.insert( + node_id.clone(), + streamkit_api::Node { + kind: kind.clone(), + params: params.clone(), + state: None, + }, + ); + engine_operations.push( + streamkit_core::control::EngineControlMessage::AddNode { + node_id, + kind, + params, + }, + ); + }, + streamkit_api::BatchOperation::RemoveNode { node_id } => { + pipeline.nodes.shift_remove(&node_id); + pipeline + .connections + .retain(|conn| conn.from_node != node_id && conn.to_node != node_id); + engine_operations.push( + streamkit_core::control::EngineControlMessage::RemoveNode { node_id }, + ); + }, + streamkit_api::BatchOperation::Connect { + from_node, + from_pin, + to_node, + to_pin, + mode, + } => { + pipeline.connections.push(streamkit_api::Connection { + from_node: from_node.clone(), + from_pin: from_pin.clone(), + to_node: to_node.clone(), + to_pin: to_pin.clone(), + mode, + }); + let core_mode = match mode { + streamkit_api::ConnectionMode::Reliable => { + streamkit_core::control::ConnectionMode::Reliable + }, + streamkit_api::ConnectionMode::BestEffort => { + streamkit_core::control::ConnectionMode::BestEffort + }, + }; + engine_operations.push( + streamkit_core::control::EngineControlMessage::Connect { + from_node, + from_pin, + to_node, + to_pin, + mode: core_mode, + }, + ); + }, + streamkit_api::BatchOperation::Disconnect { + from_node, + from_pin, + to_node, + to_pin, + } => { + pipeline.connections.retain(|conn| { + !(conn.from_node == from_node + && conn.from_pin == from_pin + && conn.to_node == to_node + && conn.to_pin == to_pin) + }); + engine_operations.push( + streamkit_core::control::EngineControlMessage::Disconnect { + from_node, + from_pin, + to_node, + to_pin, + }, + ); + }, + } + } + drop(pipeline); + } + + // Send control messages to the engine. + for msg in engine_operations { + session.send_control_message(msg).await; + } + + Ok(()) +} + +/// Send a control message to a specific node in a running session. +/// +/// For `UpdateParams` messages, this function also validates file-path +/// security, updates the durable pipeline model, and broadcasts a +/// `NodeParamsChanged` event. Callers must perform session-level +/// permission and ownership checks before calling this function. +/// +/// # Errors +/// +/// Returns an error string when the security policy rejects the +/// `UpdateParams` payload. +pub async fn tune_session_node( + session: &crate::session::Session, + node_id: String, + message: streamkit_core::control::NodeControlMessage, + security_config: &crate::config::SecurityConfig, + event_tx: &tokio::sync::broadcast::Sender, +) -> Result<(), String> { + use streamkit_core::control::NodeControlMessage; + + if let NodeControlMessage::UpdateParams(ref params) = message { + let kind = { + let pipeline = session.pipeline.lock().await; + pipeline.nodes.get(&node_id).map(|n| n.kind.clone()) + }; + + if !crate::websocket_handlers::validate_update_params_security( + kind.as_deref(), + params, + security_config, + ) { + return Err("Security policy rejected the UpdateParams payload".to_string()); + } + + { + let mut durable_params = params.clone(); + if let serde_json::Value::Object(ref mut map) = durable_params { + map.retain(|k, _| !k.starts_with('_')); + } + let mut pipeline = session.pipeline.lock().await; + if let Some(node) = pipeline.nodes.get_mut(&node_id) { + node.params = Some(match node.params.take() { + Some(existing) => { + crate::websocket_handlers::deep_merge_json(existing, durable_params) + }, + None => durable_params, + }); + } + } + + let event = streamkit_api::Event { + message_type: streamkit_api::MessageType::Event, + correlation_id: None, + payload: streamkit_api::EventPayload::NodeParamsChanged { + session_id: session.id.clone(), + node_id: node_id.clone(), + params: params.clone(), + }, + }; + if let Err(e) = event_tx.send(crate::state::BroadcastEvent::to_all(event)) { + tracing::error!("Failed to broadcast NodeParamsChanged event: {}", e); + } + } + + let control_msg = streamkit_core::control::EngineControlMessage::TuneNode { node_id, message }; + session.send_control_message(control_msg).await; + + Ok(()) +} + /// Creates the Axum application with all routes and middleware. /// /// # Arguments diff --git a/apps/skit/src/websocket_handlers.rs b/apps/skit/src/websocket_handlers.rs index dc4f2d32..c68a8377 100644 --- a/apps/skit/src/websocket_handlers.rs +++ b/apps/skit/src/websocket_handlers.rs @@ -14,8 +14,7 @@ use crate::state::{AppState, BroadcastEvent}; use opentelemetry::global; use std::sync::Arc; use streamkit_api::{ - Event as ApiEvent, EventPayload, MessageType, RequestPayload, ResponsePayload, ValidationError, - ValidationErrorType, + Event as ApiEvent, EventPayload, MessageType, RequestPayload, ResponsePayload, }; use streamkit_core::control::{EngineControlMessage, NodeControlMessage}; use streamkit_core::registry::NodeDefinition; @@ -41,7 +40,7 @@ fn can_access_session(session: &Session, role_name: &str, perms: &Permissions) - /// Returns `Some(error_message)` if the operation is not allowed, `None` if it passes. /// This is the single source of truth for AddNode validation, used by `handle_add_node`, /// `handle_validate_batch`, and `handle_apply_batch`. -fn validate_add_node_op( +pub fn validate_add_node_op( kind: &str, params: Option<&serde_json::Value>, perms: &Permissions, @@ -776,18 +775,16 @@ async fn handle_tune_node( perms: &Permissions, role_name: &str, ) -> Option { - // Check permission to tune nodes if !perms.tune_nodes { return Some(ResponsePayload::Error { message: "Permission denied: cannot tune nodes".to_string(), }); } - // Get session with SHORT lock hold to avoid blocking other operations let session = { let session_manager = app_state.session_manager.lock().await; session_manager.get_session_by_name_or_id(&session_id) - }; // Session manager lock released here + }; let Some(session) = session else { return Some(ResponsePayload::Error { @@ -795,115 +792,30 @@ async fn handle_tune_node( }); }; - // Check ownership (session is cloned, doesn't need lock) if !can_access_session(&session, role_name, perms) { return Some(ResponsePayload::Error { message: "Permission denied: you do not own this session".to_string(), }); } - // Handle UpdateParams specially for event broadcasting (and validate file paths) - if let NodeControlMessage::UpdateParams(ref params) = message { - let (kind, file_path, script_path) = { - let pipeline = session.pipeline.lock().await; - let kind = pipeline.nodes.get(&node_id).map(|n| n.kind.clone()); - let file_path = - params.get("path").and_then(serde_json::Value::as_str).map(str::to_string); - let script_path = - params.get("script_path").and_then(serde_json::Value::as_str).map(str::to_string); - drop(pipeline); - (kind, file_path, script_path) - }; - - let file_path = file_path.as_deref(); - let script_path = script_path.as_deref(); - - if kind.as_deref() == Some("core::file_reader") { - let Some(path) = file_path else { - return Some(ResponsePayload::Error { - message: "Invalid file_reader params: expected params.path to be a string" - .to_string(), - }); - }; - if let Err(e) = file_security::validate_file_path(path, &app_state.config.security) { - return Some(ResponsePayload::Error { message: format!("Invalid file path: {e}") }); - } - } - - if kind.as_deref() == Some("core::file_writer") { - if let Some(path) = file_path { - if let Err(e) = file_security::validate_write_path(path, &app_state.config.security) - { - return Some(ResponsePayload::Error { - message: format!("Invalid write path: {e}"), - }); - } - } - } - - if kind.as_deref() == Some("core::script") { - if let Some(path) = script_path { - if !path.trim().is_empty() { - if let Err(e) = - file_security::validate_file_path(path, &app_state.config.security) - { - return Some(ResponsePayload::Error { - message: format!("Invalid script_path: {e}"), - }); - } - } - } - } - - { - // Store sanitized params: strip transient sync metadata - // (_sender, _rev, etc.) for consistency with the - // fire-and-forget handler. - let mut durable_params = params.clone(); - if let serde_json::Value::Object(ref mut map) = durable_params { - map.retain(|k, _| !k.starts_with('_')); - } - let mut pipeline = session.pipeline.lock().await; - if let Some(node) = pipeline.nodes.get_mut(&node_id) { - // Deep-merge the partial update into existing params so - // sibling keys are preserved (mirrors the async handler). - node.params = Some(match node.params.take() { - Some(existing) => deep_merge_json(existing, durable_params), - None => durable_params, - }); - } else { - warn!( - node_id = %node_id, - "Attempted to tune params for non-existent node in pipeline model" - ); - } - } // Lock released here - - // Broadcast event to all clients - let event = ApiEvent { - message_type: MessageType::Event, - correlation_id: None, - payload: EventPayload::NodeParamsChanged { - session_id: session.id.clone(), - node_id: node_id.clone(), - params: params.clone(), - }, - }; - if let Err(e) = app_state.event_tx.send(BroadcastEvent::to_all(event)) { - error!("Failed to broadcast NodeParamsChanged event: {}", e); - } + match crate::server::tune_session_node( + &session, + node_id, + message, + &app_state.config.security, + &app_state.event_tx, + ) + .await + { + Ok(()) => Some(ResponsePayload::Success), + Err(message) => Some(ResponsePayload::Error { message }), } - - // Now safe to do async operations without holding session_manager lock - let control_msg = EngineControlMessage::TuneNode { node_id, message }; - session.send_control_message(control_msg).await; - Some(ResponsePayload::Success) } /// Validate file/script paths in UpdateParams against security policy. /// /// Returns `true` if the params are allowed, `false` if they should be rejected. -fn validate_update_params_security( +pub fn validate_update_params_security( kind: Option<&str>, params: &serde_json::Value, security: &crate::config::SecurityConfig, @@ -960,9 +872,7 @@ async fn handle_tune_node_fire_and_forget( ) -> Option { let action_label = "TuneNodeAsync"; - // Check permission to tune nodes if !perms.tune_nodes { - // For async operations, we don't send a response but we should still log warn!("Permission denied: attempted to tune node without permission via {action_label}"); return None; } @@ -970,10 +880,9 @@ async fn handle_tune_node_fire_and_forget( let session = { let session_manager = app_state.session_manager.lock().await; session_manager.get_session_by_name_or_id(&session_id) - }; // Session manager lock released here + }; if let Some(session) = session { - // Check ownership if !can_access_session(&session, role_name, perms) { warn!( session_id = %session_id, @@ -983,71 +892,21 @@ async fn handle_tune_node_fire_and_forget( return None; } - // Handle UpdateParams specially for pipeline model updates and event broadcasting - if let NodeControlMessage::UpdateParams(ref params) = message { - let kind = { - let pipeline = session.pipeline.lock().await; - pipeline.nodes.get(&node_id).map(|n| n.kind.clone()) - }; - - if !validate_update_params_security(kind.as_deref(), params, &app_state.config.security) - { - return None; - } - - { - // Store sanitized params: strip transient sync metadata - // (_sender, _rev, etc.) from the durable pipeline model. - // Top-level keys prefixed with `_` are reserved for - // in-flight metadata and must not leak into persistence - // or GetPipeline responses. - let mut durable_params = params.clone(); - if let serde_json::Value::Object(ref mut map) = durable_params { - map.retain(|k, _| !k.starts_with('_')); - } - let mut pipeline = session.pipeline.lock().await; - if let Some(node) = pipeline.nodes.get_mut(&node_id) { - // Deep-merge the partial update into existing params so - // sibling keys are preserved. Without this, a partial - // nested update like `{ properties: { show: false } }` - // would overwrite the entire params, losing keys such - // as `fps`, `width`, or `properties.name`. - node.params = Some(match node.params.take() { - Some(existing) => deep_merge_json(existing, durable_params), - None => durable_params, - }); - } else { - warn!( - node_id = %node_id, - "Attempted to tune params for non-existent node in pipeline model via {action_label}" - ); - } - } // Lock released here - - // Broadcast the *partial delta* (not merged state) to all clients. - // Correct deep-merge on receive depends on each client having a - // valid base state, which is guaranteed because every client - // fetches the full pipeline on connect. - let event = ApiEvent { - message_type: MessageType::Event, - correlation_id: None, - payload: EventPayload::NodeParamsChanged { - session_id: session.id.clone(), - node_id: node_id.clone(), - params: params.clone(), - }, - }; - if let Err(e) = app_state.event_tx.send(BroadcastEvent::to_all(event)) { - error!("Failed to broadcast NodeParamsChanged event: {}", e); - } + if let Err(e) = crate::server::tune_session_node( + &session, + node_id, + message, + &app_state.config.security, + &app_state.event_tx, + ) + .await + { + warn!("Security policy rejected tune via {action_label}: {e}"); } - - let control_msg = EngineControlMessage::TuneNode { node_id, message }; - session.send_control_message(control_msg).await; } else { warn!("Could not tune non-existent session '{session_id}' via {action_label}"); } - None // Do not send a response + None } /// Handle async node tuning (fire-and-forget, broadcasts to all). @@ -1136,14 +995,12 @@ async fn handle_validate_batch( perms: &Permissions, role_name: &str, ) -> ResponsePayload { - // Validate that user has permission for modify_sessions if !perms.modify_sessions { return ResponsePayload::Error { message: "Permission denied: cannot modify sessions".to_string(), }; } - // Verify session exists let session = { let session_manager = app_state.session_manager.lock().await; session_manager.get_session_by_name_or_id(&session_id) @@ -1153,56 +1010,19 @@ async fn handle_validate_batch( return ResponsePayload::Error { message: format!("Session '{session_id}' not found") }; }; - // Check ownership if !can_access_session(&session, role_name, perms) { return ResponsePayload::Error { message: "Permission denied: you do not own this session".to_string(), }; } - // Collect all validation errors so the caller sees every problem at once. - let mut errors: Vec = Vec::new(); - - // Pre-validate duplicate node_ids against the pipeline model, mirroring - // the same simulation that handle_apply_batch performs. - let mut live_ids: std::collections::HashSet = - session.pipeline.lock().await.nodes.keys().cloned().collect(); - for op in operations { - match op { - streamkit_api::BatchOperation::AddNode { node_id, .. } => { - if !live_ids.insert(node_id.clone()) { - errors.push(ValidationError { - error_type: ValidationErrorType::Error, - message: format!( - "Batch rejected: node '{node_id}' already exists in the pipeline" - ), - node_id: Some(node_id.clone()), - connection_id: None, - }); - } - }, - streamkit_api::BatchOperation::RemoveNode { node_id } => { - live_ids.remove(node_id.as_str()); - }, - _ => {}, - } - } - - // Validate all AddNode operations against permission and security rules. - for op in operations { - if let streamkit_api::BatchOperation::AddNode { node_id, kind, params, .. } = op { - if let Some(message) = - validate_add_node_op(kind, params.as_ref(), perms, &app_state.config.security) - { - errors.push(ValidationError { - error_type: ValidationErrorType::Error, - message, - node_id: Some(node_id.clone()), - connection_id: None, - }); - } - } - } + let errors = crate::server::validate_batch_operations( + &session, + operations, + perms, + &app_state.config.security, + ) + .await; info!( operation_count = operations.len(), @@ -1212,7 +1032,6 @@ async fn handle_validate_batch( ResponsePayload::ValidationResult { errors } } -#[allow(clippy::significant_drop_tightening)] async fn handle_apply_batch( session_id: String, operations: Vec, @@ -1220,18 +1039,16 @@ async fn handle_apply_batch( perms: &Permissions, role_name: &str, ) -> Option { - // Check permission to modify sessions if !perms.modify_sessions { return Some(ResponsePayload::Error { message: "Permission denied: cannot modify sessions".to_string(), }); } - // Get session with SHORT lock hold to avoid blocking other operations let session = { let session_manager = app_state.session_manager.lock().await; session_manager.get_session_by_name_or_id(&session_id) - }; // Session manager lock released here + }; let Some(session) = session else { return Some(ResponsePayload::Error { @@ -1239,142 +1056,26 @@ async fn handle_apply_batch( }); }; - // Check ownership (session is cloned, doesn't need lock) if !can_access_session(&session, role_name, perms) { return Some(ResponsePayload::Error { message: "Permission denied: you do not own this session".to_string(), }); } - // Pre-validate duplicate node_ids against the pipeline model. - // Simulate the batch's Add/Remove sequence so that Remove→Add for - // the same ID within the batch is allowed, but duplicate Adds - // (without intervening Remove) are rejected before any mutation. - { - let pipeline = session.pipeline.lock().await; - let mut live_ids: std::collections::HashSet<&str> = - pipeline.nodes.keys().map(String::as_str).collect(); - for op in &operations { - match op { - streamkit_api::BatchOperation::AddNode { node_id, .. } => { - if !live_ids.insert(node_id.as_str()) { - return Some(ResponsePayload::Error { - message: format!( - "Batch rejected: node '{node_id}' already exists in the pipeline" - ), - }); - } - }, - streamkit_api::BatchOperation::RemoveNode { node_id } => { - live_ids.remove(node_id.as_str()); - }, - _ => {}, - } - } - } // Pipeline lock released after pre-validation - - // Validate permissions for all operations. - for op in &operations { - if let streamkit_api::BatchOperation::AddNode { kind, params, .. } = op { - if let Some(message) = - validate_add_node_op(kind, params.as_ref(), perms, &app_state.config.security) - { - return Some(ResponsePayload::Error { message }); - } - } - } - - // Apply all operations in order - let mut engine_operations = Vec::new(); - + match crate::server::apply_batch_operations( + &session, + operations, + perms, + &app_state.config.security, + ) + .await { - let mut pipeline = session.pipeline.lock().await; - - for op in operations { - match op { - streamkit_api::BatchOperation::AddNode { node_id, kind, params } => { - pipeline.nodes.insert( - node_id.clone(), - streamkit_api::Node { - kind: kind.clone(), - params: params.clone(), - state: None, - }, - ); - engine_operations.push(EngineControlMessage::AddNode { node_id, kind, params }); - }, - streamkit_api::BatchOperation::RemoveNode { node_id } => { - pipeline.nodes.shift_remove(&node_id); - pipeline - .connections - .retain(|conn| conn.from_node != node_id && conn.to_node != node_id); - engine_operations.push(EngineControlMessage::RemoveNode { node_id }); - }, - streamkit_api::BatchOperation::Connect { - from_node, - from_pin, - to_node, - to_pin, - mode, - } => { - pipeline.connections.push(streamkit_api::Connection { - from_node: from_node.clone(), - from_pin: from_pin.clone(), - to_node: to_node.clone(), - to_pin: to_pin.clone(), - mode, - }); - let core_mode = match mode { - streamkit_api::ConnectionMode::Reliable => { - streamkit_core::control::ConnectionMode::Reliable - }, - streamkit_api::ConnectionMode::BestEffort => { - streamkit_core::control::ConnectionMode::BestEffort - }, - }; - engine_operations.push(EngineControlMessage::Connect { - from_node, - from_pin, - to_node, - to_pin, - mode: core_mode, - }); - }, - streamkit_api::BatchOperation::Disconnect { - from_node, - from_pin, - to_node, - to_pin, - } => { - pipeline.connections.retain(|conn| { - !(conn.from_node == from_node - && conn.from_pin == from_pin - && conn.to_node == to_node - && conn.to_pin == to_pin) - }); - engine_operations.push(EngineControlMessage::Disconnect { - from_node, - from_pin, - to_node, - to_pin, - }); - }, - } - } - drop(pipeline); - } // Release pipeline lock - - // Now safe to do async operations without holding session_manager lock - for msg in engine_operations { - session.send_control_message(msg).await; + Ok(()) => { + info!(session_id = %session_id, "Applied batch operations successfully"); + Some(ResponsePayload::BatchApplied { success: true, errors: Vec::new() }) + }, + Err(message) => Some(ResponsePayload::Error { message }), } - - info!( - session_id = %session_id, - "Applied batch operations successfully" - ); - - Some(ResponsePayload::BatchApplied { success: true, errors: Vec::new() }) } fn handle_get_permissions(perms: &Permissions, role_name: &str) -> ResponsePayload { @@ -1385,7 +1086,7 @@ fn handle_get_permissions(perms: &Permissions, role_name: &str) -> ResponsePaylo /// Recursively deep-merges `source` into `target`, returning the merged value. /// Only JSON objects are merged recursively; arrays and scalars in `source` /// replace the corresponding value in `target`. -fn deep_merge_json(target: serde_json::Value, source: serde_json::Value) -> serde_json::Value { +pub fn deep_merge_json(target: serde_json::Value, source: serde_json::Value) -> serde_json::Value { match (target, source) { (serde_json::Value::Object(mut t_map), serde_json::Value::Object(s_map)) => { for (key, s_val) in s_map { diff --git a/apps/skit/tests/mcp_integration_test.rs b/apps/skit/tests/mcp_integration_test.rs index d630af13..92fad028 100644 --- a/apps/skit/tests/mcp_integration_test.rs +++ b/apps/skit/tests/mcp_integration_test.rs @@ -966,3 +966,337 @@ async fn mcp_get_prompt_unknown_returns_error() { let error = &body["error"]; assert!(!error.is_null(), "expected error for unknown prompt, got: {body}"); } + +// ----------------------------------------------------------------------- +// Batch & Tune tests +// ----------------------------------------------------------------------- + +/// Helper: create a StreamKit session via MCP and return its session_id. +async fn create_skit_session( + client: &reqwest::Client, + addr: SocketAddr, + token: &str, + mcp_session: &str, + yaml: &str, +) -> String { + let create = json!({ + "jsonrpc": "2.0", + "id": 100, + "method": "tools/call", + "params": { + "name": "create_session", + "arguments": { "yaml": yaml } + } + }); + let res = mcp_post_with_session(client, addr, &create, token, mcp_session).await; + assert_eq!(res.status(), StatusCode::OK); + let body = extract_sse_json(&res.text().await.unwrap()); + let text = body["result"]["content"][0]["text"].as_str().expect("create_session text"); + let parsed: serde_json::Value = serde_json::from_str(text).unwrap(); + parsed["session_id"].as_str().expect("session_id").to_string() +} + +#[tokio::test] +async fn mcp_validate_batch_valid_operations() { + let _ = tracing_subscriber::fmt::try_init(); + let (addr, _handle, token, _dir) = start_mcp_server().await; + let client = reqwest::Client::new(); + let mcp_session = init_mcp_session(&client, addr, &token).await; + + let skit_session = + create_skit_session(&client, addr, &token, &mcp_session, PASSTHROUGH_YAML).await; + + let validate = json!({ + "jsonrpc": "2.0", + "id": 2, + "method": "tools/call", + "params": { + "name": "validate_batch", + "arguments": { + "session_id": skit_session, + "operations": [ + { "action": "addnode", "node_id": "new_pass", "kind": "core::passthrough" } + ] + } + } + }); + let res = mcp_post_with_session(&client, addr, &validate, &token, &mcp_session).await; + assert_eq!(res.status(), StatusCode::OK); + + let body = extract_sse_json(&res.text().await.unwrap()); + let text = body["result"]["content"][0]["text"].as_str().expect("validate_batch text"); + let errors: Vec = serde_json::from_str(text).unwrap(); + assert!(errors.is_empty(), "expected no validation errors, got: {errors:?}"); +} + +#[tokio::test] +async fn mcp_validate_batch_invalid_duplicate_node() { + let _ = tracing_subscriber::fmt::try_init(); + let (addr, _handle, token, _dir) = start_mcp_server().await; + let client = reqwest::Client::new(); + let mcp_session = init_mcp_session(&client, addr, &token).await; + + let skit_session = + create_skit_session(&client, addr, &token, &mcp_session, PASSTHROUGH_YAML).await; + + // "pass" already exists in the pipeline — adding it again should fail. + let validate = json!({ + "jsonrpc": "2.0", + "id": 2, + "method": "tools/call", + "params": { + "name": "validate_batch", + "arguments": { + "session_id": skit_session, + "operations": [ + { "action": "addnode", "node_id": "pass", "kind": "core::passthrough" } + ] + } + } + }); + let res = mcp_post_with_session(&client, addr, &validate, &token, &mcp_session).await; + assert_eq!(res.status(), StatusCode::OK); + + let body = extract_sse_json(&res.text().await.unwrap()); + let text = body["result"]["content"][0]["text"].as_str().expect("validate_batch text"); + let errors: Vec = serde_json::from_str(text).unwrap(); + assert!(!errors.is_empty(), "expected validation errors for duplicate node"); + assert!( + errors[0]["message"].as_str().unwrap().contains("already exists"), + "expected 'already exists' error" + ); +} + +#[tokio::test] +async fn mcp_apply_batch_add_node_round_trip() { + let _ = tracing_subscriber::fmt::try_init(); + let (addr, _handle, token, _dir) = start_mcp_server().await; + let client = reqwest::Client::new(); + let mcp_session = init_mcp_session(&client, addr, &token).await; + + let skit_session = + create_skit_session(&client, addr, &token, &mcp_session, PASSTHROUGH_YAML).await; + + // Apply: add a new node + let apply = json!({ + "jsonrpc": "2.0", + "id": 2, + "method": "tools/call", + "params": { + "name": "apply_batch", + "arguments": { + "session_id": skit_session, + "operations": [ + { "action": "addnode", "node_id": "extra", "kind": "core::passthrough" } + ] + } + } + }); + let res = mcp_post_with_session(&client, addr, &apply, &token, &mcp_session).await; + assert_eq!(res.status(), StatusCode::OK); + + let body = extract_sse_json(&res.text().await.unwrap()); + let text = body["result"]["content"][0]["text"].as_str().expect("apply_batch text"); + let result: serde_json::Value = serde_json::from_str(text).unwrap(); + assert_eq!(result["success"], true); + + // Verify via get_pipeline that "extra" exists + let get = json!({ + "jsonrpc": "2.0", + "id": 3, + "method": "tools/call", + "params": { + "name": "get_pipeline", + "arguments": { "session_id": skit_session } + } + }); + let res = mcp_post_with_session(&client, addr, &get, &token, &mcp_session).await; + assert_eq!(res.status(), StatusCode::OK); + + let body = extract_sse_json(&res.text().await.unwrap()); + let text = body["result"]["content"][0]["text"].as_str().expect("get_pipeline text"); + let pipeline: serde_json::Value = serde_json::from_str(text).unwrap(); + assert!(pipeline["nodes"]["extra"].is_object(), "expected 'extra' node in pipeline"); + assert!(pipeline["nodes"]["pass"].is_object(), "expected original 'pass' node in pipeline"); + + // Clean up + let destroy = json!({ + "jsonrpc": "2.0", + "id": 4, + "method": "tools/call", + "params": { "name": "destroy_session", "arguments": { "session_id": skit_session } } + }); + let _ = mcp_post_with_session(&client, addr, &destroy, &token, &mcp_session).await; +} + +#[tokio::test] +async fn mcp_tune_node_update_params() { + let _ = tracing_subscriber::fmt::try_init(); + let (addr, _handle, token, _dir) = start_mcp_server().await; + let client = reqwest::Client::new(); + let mcp_session = init_mcp_session(&client, addr, &token).await; + + let skit_session = + create_skit_session(&client, addr, &token, &mcp_session, PASSTHROUGH_YAML).await; + + // Tune: send UpdateParams to the "pass" node + let tune = json!({ + "jsonrpc": "2.0", + "id": 2, + "method": "tools/call", + "params": { + "name": "tune_node", + "arguments": { + "session_id": skit_session, + "node_id": "pass", + "message": { "UpdateParams": { "gain": 0.5 } } + } + } + }); + let res = mcp_post_with_session(&client, addr, &tune, &token, &mcp_session).await; + assert_eq!(res.status(), StatusCode::OK); + + let body = extract_sse_json(&res.text().await.unwrap()); + let text = body["result"]["content"][0]["text"].as_str().expect("tune_node text"); + let result: serde_json::Value = serde_json::from_str(text).unwrap(); + assert_eq!(result["success"], true); + + // Verify params persisted in pipeline model + let get = json!({ + "jsonrpc": "2.0", + "id": 3, + "method": "tools/call", + "params": { + "name": "get_pipeline", + "arguments": { "session_id": skit_session } + } + }); + let res = mcp_post_with_session(&client, addr, &get, &token, &mcp_session).await; + assert_eq!(res.status(), StatusCode::OK); + + let body = extract_sse_json(&res.text().await.unwrap()); + let text = body["result"]["content"][0]["text"].as_str().expect("get_pipeline text"); + let pipeline: serde_json::Value = serde_json::from_str(text).unwrap(); + let pass_params = &pipeline["nodes"]["pass"]["params"]; + assert_eq!(pass_params["gain"], 0.5, "expected tuned gain param"); + + // Clean up + let destroy = json!({ + "jsonrpc": "2.0", + "id": 4, + "method": "tools/call", + "params": { "name": "destroy_session", "arguments": { "session_id": skit_session } } + }); + let _ = mcp_post_with_session(&client, addr, &destroy, &token, &mcp_session).await; +} + +#[tokio::test] +async fn mcp_modify_sessions_permission_denied() { + let _ = tracing_subscriber::fmt::try_init(); + + // Start a server where modify_sessions is disabled for admin + let listener = + TcpListener::bind("127.0.0.1:0").await.expect("Failed to bind test server listener"); + let addr = listener.local_addr().unwrap(); + let temp_dir = TempDir::new().unwrap(); + + let mut config = Config::default(); + config.mcp.enabled = true; + config.auth.mode = streamkit_server::config::AuthMode::Enabled; + config.auth.state_dir = temp_dir.path().to_string_lossy().to_string(); + + if let Some(admin_perms) = config.permissions.roles.get_mut("admin") { + admin_perms.modify_sessions = false; + admin_perms.tune_nodes = false; + } + + let auth_state = streamkit_server::auth::AuthState::new(&config.auth, true) + .await + .expect("Failed to init auth state"); + let auth_state = Arc::new(auth_state); + + let admin_token_path = temp_dir.path().join("admin.token"); + let admin_token = + tokio::fs::read_to_string(&admin_token_path).await.expect("Missing admin.token"); + let admin_token = admin_token.trim().to_string(); + + let (app, _state) = streamkit_server::server::create_app(config, Some(auth_state)); + let _server_handle = tokio::spawn(async move { + axum::serve(listener, app.into_make_service()).await.unwrap(); + }); + + wait_for_healthz(addr).await; + + let client = reqwest::Client::new(); + let mcp_session = init_mcp_session(&client, addr, &admin_token).await; + + // validate_batch should be denied + let validate = json!({ + "jsonrpc": "2.0", + "id": 2, + "method": "tools/call", + "params": { + "name": "validate_batch", + "arguments": { + "session_id": "any", + "operations": [] + } + } + }); + let res = mcp_post_with_session(&client, addr, &validate, &admin_token, &mcp_session).await; + assert_eq!(res.status(), StatusCode::OK); + let body = extract_sse_json(&res.text().await.unwrap()); + let error = &body["error"]; + assert!(!error.is_null(), "expected error for validate_batch, got: {body}"); + assert!( + error["message"].as_str().unwrap_or("").contains("Permission denied"), + "expected permission denied" + ); + + // apply_batch should be denied + let apply = json!({ + "jsonrpc": "2.0", + "id": 3, + "method": "tools/call", + "params": { + "name": "apply_batch", + "arguments": { + "session_id": "any", + "operations": [] + } + } + }); + let res = mcp_post_with_session(&client, addr, &apply, &admin_token, &mcp_session).await; + assert_eq!(res.status(), StatusCode::OK); + let body = extract_sse_json(&res.text().await.unwrap()); + let error = &body["error"]; + assert!(!error.is_null(), "expected error for apply_batch, got: {body}"); + assert!( + error["message"].as_str().unwrap_or("").contains("Permission denied"), + "expected permission denied" + ); + + // tune_node should be denied + let tune = json!({ + "jsonrpc": "2.0", + "id": 4, + "method": "tools/call", + "params": { + "name": "tune_node", + "arguments": { + "session_id": "any", + "node_id": "any", + "message": { "UpdateParams": {} } + } + } + }); + let res = mcp_post_with_session(&client, addr, &tune, &admin_token, &mcp_session).await; + assert_eq!(res.status(), StatusCode::OK); + let body = extract_sse_json(&res.text().await.unwrap()); + let error = &body["error"]; + assert!(!error.is_null(), "expected error for tune_node, got: {body}"); + assert!( + error["message"].as_str().unwrap_or("").contains("Permission denied"), + "expected permission denied" + ); +} diff --git a/crates/api/Cargo.toml b/crates/api/Cargo.toml index ea133ef5..1cc3a07b 100644 --- a/crates/api/Cargo.toml +++ b/crates/api/Cargo.toml @@ -18,6 +18,7 @@ serde = { version = "1.0.228", features = ["derive", "rc"] } serde_json = "1.0" serde-saphyr = "0.0.23" ts-rs = { version = "12.0.1" } +schemars = { version = "1.2.0", features = ["derive"] } indexmap = { version = "2.14", features = ["serde"] } [[bin]] diff --git a/crates/api/src/lib.rs b/crates/api/src/lib.rs index ec204f85..9d567730 100644 --- a/crates/api/src/lib.rs +++ b/crates/api/src/lib.rs @@ -234,7 +234,7 @@ pub enum RequestPayload { GetPermissions, } -#[derive(Serialize, Deserialize, Debug, Clone, TS)] +#[derive(Serialize, Deserialize, Debug, Clone, TS, schemars::JsonSchema)] #[ts(export)] #[serde(tag = "action")] #[serde(rename_all = "lowercase")] diff --git a/crates/core/src/control.rs b/crates/core/src/control.rs index f73e90dd..e36a085d 100644 --- a/crates/core/src/control.rs +++ b/crates/core/src/control.rs @@ -15,7 +15,7 @@ use serde::{Deserialize, Serialize}; use ts_rs::TS; /// A message sent to a specific, running node to tune its parameters or control its lifecycle. -#[derive(Debug, Deserialize, Serialize, TS)] +#[derive(Debug, Deserialize, Serialize, TS, schemars::JsonSchema)] #[ts(export)] pub enum NodeControlMessage { UpdateParams(#[ts(type = "JsonValue")] serde_json::Value), @@ -28,7 +28,9 @@ pub enum NodeControlMessage { } /// Specifies how a connection handles backpressure from slow consumers. -#[derive(Debug, Deserialize, Serialize, Clone, Copy, PartialEq, Eq, Default, TS)] +#[derive( + Debug, Deserialize, Serialize, Clone, Copy, PartialEq, Eq, Default, TS, schemars::JsonSchema, +)] #[ts(export)] #[serde(rename_all = "snake_case")] pub enum ConnectionMode {