diff --git a/apps/skit/src/mcp/mod.rs b/apps/skit/src/mcp/mod.rs index 37719c3e..0ab01ea7 100644 --- a/apps/skit/src/mcp/mod.rs +++ b/apps/skit/src/mcp/mod.rs @@ -248,6 +248,12 @@ pub struct TuneNodeArgs { pub message: streamkit_core::control::NodeControlMessage, } +#[derive(Debug, Deserialize, schemars::JsonSchema)] +pub struct GetNodeDefinitionArgs { + /// Node kind to look up (e.g., "audio::gain", "core::passthrough"). + pub kind: String, +} + // --------------------------------------------------------------------------- // StreamKit MCP service // --------------------------------------------------------------------------- @@ -693,6 +699,66 @@ impl StreamKitMcp { let result = serde_json::json!({ "success": true }); json_tool_result(&result) } + + // -- list_plugins ------------------------------------------------------ + + #[tool( + description = "List installed StreamKit plugins with their kind, version, type (native/wasm), and categories." + )] + async fn list_plugins( + &self, + ctx: RequestContext, + ) -> Result { + let (_role_name, perms) = extract_auth(&ctx, &self.app_state)?; + + let mut plugins = self.app_state.plugin_manager.lock().await.list_plugins(); + plugins.retain(|plugin| perms.is_plugin_allowed(&plugin.kind)); + + info!(count = plugins.len(), "MCP list_plugins"); + + json_tool_result(&plugins) + } + + // -- get_node_definition ----------------------------------------------- + + #[tool( + description = "Get the full definition (schema, pins, categories) for a specific node kind. Use this when you need the param schema or pin layout for a particular node." + )] + async fn get_node_definition( + &self, + Parameters(args): Parameters, + ctx: RequestContext, + ) -> Result { + let (_role_name, perms) = extract_auth(&ctx, &self.app_state)?; + + if !perms.is_node_allowed(&args.kind) { + return Err(McpError::invalid_request( + format!("Permission denied: node kind '{}' is not allowed", args.kind), + None, + )); + } + + if args.kind.starts_with("plugin::") && !perms.is_plugin_allowed(&args.kind) { + return Err(McpError::invalid_request( + format!("Permission denied: plugin '{}' is not allowed", args.kind), + None, + )); + } + + let definitions = filtered_node_definitions(&self.app_state, &perms)?; + let definition = definitions.into_iter().find(|d| d.kind == args.kind); + + let Some(definition) = definition else { + return Err(McpError::invalid_params( + format!("Node kind '{}' not found", args.kind), + None, + )); + }; + + info!(kind = %args.kind, "MCP get_node_definition"); + + json_tool_result(&definition) + } } // --------------------------------------------------------------------------- @@ -706,7 +772,9 @@ impl ServerHandler for StreamKitMcp { let capabilities = ServerCapabilities::builder().enable_tools().enable_prompts().build(); let mut info = ServerInfo::new(capabilities).with_instructions( "StreamKit MCP server. Use list_nodes to discover available \ - processing nodes, validate_pipeline to check YAML, and \ + processing nodes, get_node_definition to look up a specific \ + node's schema/pins/categories, list_plugins to see installed \ + plugins, validate_pipeline to check YAML, and \ create_session / list_sessions / get_pipeline / destroy_session \ to manage dynamic pipeline sessions. Use validate_batch and \ apply_batch to mutate a running session's graph as a validated batch, \ diff --git a/apps/skit/tests/mcp_integration_test.rs b/apps/skit/tests/mcp_integration_test.rs index 2b72107e..001403ca 100644 --- a/apps/skit/tests/mcp_integration_test.rs +++ b/apps/skit/tests/mcp_integration_test.rs @@ -1302,6 +1302,108 @@ async fn mcp_modify_sessions_permission_denied() { ); } +// --------------------------------------------------------------------------- +// Plugin & node-definition tools +// --------------------------------------------------------------------------- + +#[tokio::test] +async fn mcp_list_plugins_returns_results() { + let _ = tracing_subscriber::fmt::try_init(); + + let (addr, _handle, token, _dir) = start_mcp_server().await; + let client = reqwest::Client::new(); + let session_id = init_mcp_session(&client, addr, &token).await; + + let list = json!({ + "jsonrpc": "2.0", + "id": 2, + "method": "tools/call", + "params": { + "name": "list_plugins", + "arguments": {} + } + }); + let res = mcp_post_with_session(&client, addr, &list, &token, &session_id).await; + assert_eq!(res.status(), StatusCode::OK); + + let body_text = res.text().await.unwrap(); + let body = extract_sse_json(&body_text); + let result = &body["result"]; + assert!(!result.is_null(), "expected result, got: {body}"); + + let text = result["content"][0]["text"].as_str().expect("expected text content"); + let plugins: Vec = serde_json::from_str(text).expect("expected JSON array"); + // No plugins loaded in test server, but should return an empty array without error. + assert!(plugins.is_empty(), "expected empty plugin list in test server, got: {plugins:?}"); +} + +#[tokio::test] +async fn mcp_get_node_definition_found() { + let _ = tracing_subscriber::fmt::try_init(); + + let (addr, _handle, token, _dir) = start_mcp_server().await; + let client = reqwest::Client::new(); + let session_id = init_mcp_session(&client, addr, &token).await; + + let get_def = json!({ + "jsonrpc": "2.0", + "id": 2, + "method": "tools/call", + "params": { + "name": "get_node_definition", + "arguments": { + "kind": "core::passthrough" + } + } + }); + let res = mcp_post_with_session(&client, addr, &get_def, &token, &session_id).await; + assert_eq!(res.status(), StatusCode::OK); + + let body_text = res.text().await.unwrap(); + let body = extract_sse_json(&body_text); + let result = &body["result"]; + assert!(!result.is_null(), "expected result, got: {body}"); + + let text = result["content"][0]["text"].as_str().expect("expected text content"); + let def: serde_json::Value = serde_json::from_str(text).expect("expected JSON object"); + assert_eq!(def["kind"], "core::passthrough"); + assert!(def["inputs"].is_array(), "expected inputs array in definition"); + assert!(def["outputs"].is_array(), "expected outputs array in definition"); +} + +#[tokio::test] +async fn mcp_get_node_definition_not_found() { + let _ = tracing_subscriber::fmt::try_init(); + + let (addr, _handle, token, _dir) = start_mcp_server().await; + let client = reqwest::Client::new(); + let session_id = init_mcp_session(&client, addr, &token).await; + + let get_def = json!({ + "jsonrpc": "2.0", + "id": 2, + "method": "tools/call", + "params": { + "name": "get_node_definition", + "arguments": { + "kind": "nonexistent::node_kind" + } + } + }); + let res = mcp_post_with_session(&client, addr, &get_def, &token, &session_id).await; + assert_eq!(res.status(), StatusCode::OK); + + let body_text = res.text().await.unwrap(); + let body = extract_sse_json(&body_text); + let error = &body["error"]; + assert!(!error.is_null(), "expected error for nonexistent kind, got: {body}"); + assert!( + error["message"].as_str().unwrap_or("").contains("not found"), + "expected 'not found' error message, got: {}", + error["message"] + ); +} + // --------------------------------------------------------------------------- // STDIO transport tests // ---------------------------------------------------------------------------