From 59f3ece2b6e237a92aa036a7e6f23d24ff24f363 Mon Sep 17 00:00:00 2001 From: Ali Hashemi Date: Wed, 21 May 2025 20:19:58 -0300 Subject: [PATCH 1/5] refactor: introduce middleware , sessionId generation --- crates/rust-mcp-sdk/src/hyper_servers.rs | 1 + .../src/hyper_servers/app_state.rs | 1 + .../src/hyper_servers/middlewares.rs | 1 + .../middlewares/session_id_gen.rs | 23 ++++++++ .../src/hyper_servers/routes/sse_routes.rs | 33 +++++++----- .../rust-mcp-sdk/src/hyper_servers/server.rs | 15 +++++- .../src/hyper_servers/session_store.rs | 4 +- .../src/mcp_runtimes/server_runtime.rs | 4 +- .../server_runtime/mcp_server_runtime.rs | 2 +- .../rust-mcp-sdk/tests/test_client_runtime.rs | 5 +- crates/rust-mcp-sdk/tests/test_server_sse.rs | 5 ++ crates/rust-mcp-transport/src/lib.rs | 3 ++ crates/rust-mcp-transport/src/sse.rs | 8 ++- crates/rust-mcp-transport/src/utils.rs | 54 ++++++++++++++++++- 14 files changed, 133 insertions(+), 26 deletions(-) create mode 100644 crates/rust-mcp-sdk/src/hyper_servers/middlewares.rs create mode 100644 crates/rust-mcp-sdk/src/hyper_servers/middlewares/session_id_gen.rs create mode 100644 crates/rust-mcp-sdk/tests/test_server_sse.rs diff --git a/crates/rust-mcp-sdk/src/hyper_servers.rs b/crates/rust-mcp-sdk/src/hyper_servers.rs index ad1e2cd..9a58b04 100644 --- a/crates/rust-mcp-sdk/src/hyper_servers.rs +++ b/crates/rust-mcp-sdk/src/hyper_servers.rs @@ -2,6 +2,7 @@ mod app_state; pub mod error; pub mod hyper_server; pub mod hyper_server_core; +mod middlewares; mod routes; mod server; mod session_store; diff --git a/crates/rust-mcp-sdk/src/hyper_servers/app_state.rs b/crates/rust-mcp-sdk/src/hyper_servers/app_state.rs index 3276802..af572dd 100644 --- a/crates/rust-mcp-sdk/src/hyper_servers/app_state.rs +++ b/crates/rust-mcp-sdk/src/hyper_servers/app_state.rs @@ -18,5 +18,6 @@ pub struct AppState { pub server_details: Arc, pub handler: Arc, pub ping_interval: Duration, + pub sse_message_endpoint: String, pub transport_options: Arc, } diff --git a/crates/rust-mcp-sdk/src/hyper_servers/middlewares.rs b/crates/rust-mcp-sdk/src/hyper_servers/middlewares.rs new file mode 100644 index 0000000..612510e --- /dev/null +++ b/crates/rust-mcp-sdk/src/hyper_servers/middlewares.rs @@ -0,0 +1 @@ +pub(crate) mod session_id_gen; diff --git a/crates/rust-mcp-sdk/src/hyper_servers/middlewares/session_id_gen.rs b/crates/rust-mcp-sdk/src/hyper_servers/middlewares/session_id_gen.rs new file mode 100644 index 0000000..b68b325 --- /dev/null +++ b/crates/rust-mcp-sdk/src/hyper_servers/middlewares/session_id_gen.rs @@ -0,0 +1,23 @@ +use std::sync::Arc; + +use axum::{ + extract::{Request, State}, + middleware::Next, + response::Response, +}; +use hyper::StatusCode; +use rust_mcp_transport::SessionId; + +use crate::hyper_servers::app_state::AppState; + +// Middleware to generate and attach a session ID +pub async fn generate_session_id( + State(state): State>, + mut request: Request, + next: Next, +) -> Result { + let session_id: SessionId = state.id_generator.generate(); + request.extensions_mut().insert(session_id); + // Proceed to the next middleware or handler + Ok(next.run(request).await) +} diff --git a/crates/rust-mcp-sdk/src/hyper_servers/routes/sse_routes.rs b/crates/rust-mcp-sdk/src/hyper_servers/routes/sse_routes.rs index 2efe3be..8085474 100644 --- a/crates/rust-mcp-sdk/src/hyper_servers/routes/sse_routes.rs +++ b/crates/rust-mcp-sdk/src/hyper_servers/routes/sse_routes.rs @@ -1,21 +1,25 @@ use crate::{ error::McpSdkError, - hyper_servers::{app_state::AppState, error::TransportServerResult}, + hyper_servers::{ + app_state::AppState, error::TransportServerResult, + middlewares::session_id_gen::generate_session_id, + }, mcp_server::{server_runtime, ServerRuntime}, mcp_traits::mcp_handler::McpServerHandler, McpServer, }; use axum::{ extract::State, + middleware, response::{ sse::{Event, KeepAlive}, IntoResponse, Sse, }, routing::get, - Router, + Extension, Router, }; use futures::stream::{self}; -use rust_mcp_transport::{error::TransportError, SseTransport}; +use rust_mcp_transport::{error::TransportError, SessionId, SseTransport}; use std::{convert::Infallible, sync::Arc, time::Duration}; use tokio::{ io::{duplex, AsyncBufReadExt, BufReader}, @@ -37,10 +41,8 @@ const DUPLEX_BUFFER_SIZE: usize = 8192; /// /// # Returns /// * `Result` - The constructed SSE event, infallible -fn initial_event(session_id: &str) -> Result { - Ok(Event::default() - .event("endpoint") - .data(format!("{SSE_MESSAGES_PATH}?sessionId={session_id}"))) +fn initial_event(endpoint: &str) -> Result { + Ok(Event::default().event("endpoint").data(endpoint)) } /// Configures the SSE routes for the application @@ -53,8 +55,13 @@ fn initial_event(session_id: &str) -> Result { /// /// # Returns /// * `Router>` - An Axum router configured with the SSE route -pub fn routes(_state: Arc, sse_endpoint: &str) -> Router> { - Router::new().route(sse_endpoint, get(handle_sse)) +pub fn routes(state: Arc, sse_endpoint: &str) -> Router> { + Router::new() + .route(sse_endpoint, get(handle_sse)) + .route_layer(middleware::from_fn_with_state( + state.clone(), + generate_session_id, + )) } /// Handles Server-Sent Events (SSE) connections @@ -68,15 +75,17 @@ pub fn routes(_state: Arc, sse_endpoint: &str) -> Router /// # Returns /// * `TransportServerResult` - The SSE response stream or an error pub async fn handle_sse( + Extension(session_id): Extension, State(state): State>, ) -> TransportServerResult { + let messages_endpoint = + SseTransport::message_endpoint(&state.sse_message_endpoint, &session_id); + // readable stream of string to be used in transport let (read_tx, read_rx) = duplex(DUPLEX_BUFFER_SIZE); // writable stream to deliver message to the client let (write_tx, write_rx) = duplex(DUPLEX_BUFFER_SIZE); - // generate a session id, and keep it in the server state - let session_id = state.id_generator.generate(); state .session_store .set(session_id.to_owned(), read_tx) @@ -140,7 +149,7 @@ pub async fn handle_sse( }); // Initial SSE message to inform the client about the server's endpoint - let initial_event = stream::once(async move { initial_event(&session_id) }); + let initial_event = stream::once(async move { initial_event(&messages_endpoint) }); // Construct SSE stream for sending MCP messages to the server let reader = BufReader::new(write_rx); diff --git a/crates/rust-mcp-sdk/src/hyper_servers/server.rs b/crates/rust-mcp-sdk/src/hyper_servers/server.rs index f0770e1..b7394a7 100644 --- a/crates/rust-mcp-sdk/src/hyper_servers/server.rs +++ b/crates/rust-mcp-sdk/src/hyper_servers/server.rs @@ -23,6 +23,8 @@ const DEFAULT_CLIENT_PING_INTERVAL: Duration = Duration::from_secs(12); // Default Server-Sent Events (SSE) endpoint path const DEFAULT_SSE_ENDPOINT: &str = "/sse"; +// Default MCP Messages endpoint path +const DEFAULT_MESSAGES_ENDPOINT: &str = "/messages"; /// Configuration struct for the Hyper server /// Used to configure the HyperServer instance. @@ -31,9 +33,10 @@ pub struct HyperServerOptions { pub host: String, /// Hostname or IP address the server will bind to (default: "localhost") pub port: u16, - /// Optional custom path for the Server-Sent Events (SSE) endpoint. - /// If `None`, the default path `/sse` will be used. + /// Optional custom path for the Server-Sent Events (SSE) endpoint (default: `/sse`) pub custom_sse_endpoint: Option, + /// Optional custom path for the MCP messages endpoint (default: `/messages`) + pub custom_messages_endpoint: Option, /// Interval between automatic ping messages sent to clients to detect disconnects pub ping_interval: Duration, /// Enables SSL/TLS if set to `true` @@ -121,6 +124,12 @@ impl HyperServerOptions { .as_deref() .unwrap_or(DEFAULT_SSE_ENDPOINT) } + + pub fn sse_messages_endpoint(&self) -> &str { + self.custom_messages_endpoint + .as_deref() + .unwrap_or(DEFAULT_MESSAGES_ENDPOINT) + } } /// Default implementation for HyperServerOptions @@ -133,6 +142,7 @@ impl Default for HyperServerOptions { host: "127.0.0.1".to_string(), port: 8080, custom_sse_endpoint: None, + custom_messages_endpoint: None, ping_interval: DEFAULT_CLIENT_PING_INTERVAL, transport_options: Default::default(), enable_ssl: false, @@ -172,6 +182,7 @@ impl HyperServer { server_details: Arc::new(server_details), handler, ping_interval: server_options.ping_interval, + sse_message_endpoint: server_options.sse_messages_endpoint().to_owned(), transport_options: Arc::clone(&server_options.transport_options), }); let app = app_routes(Arc::clone(&state), &server_options); diff --git a/crates/rust-mcp-sdk/src/hyper_servers/session_store.rs b/crates/rust-mcp-sdk/src/hyper_servers/session_store.rs index da25000..998be01 100644 --- a/crates/rust-mcp-sdk/src/hyper_servers/session_store.rs +++ b/crates/rust-mcp-sdk/src/hyper_servers/session_store.rs @@ -3,15 +3,13 @@ use std::sync::Arc; use async_trait::async_trait; pub use in_memory::*; +use rust_mcp_transport::SessionId; use tokio::{io::DuplexStream, sync::Mutex}; use uuid::Uuid; // Type alias for the server-side duplex stream used in sessions pub type TxServer = DuplexStream; -// Type alias for session identifier, represented as a String -pub type SessionId = String; - /// Trait defining the interface for session storage operations /// /// This trait provides asynchronous methods for managing session data, diff --git a/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime.rs b/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime.rs index d325305..5f22a43 100644 --- a/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime.rs +++ b/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime.rs @@ -12,10 +12,10 @@ use std::sync::{Arc, RwLock}; use tokio::io::AsyncWriteExt; use crate::error::SdkResult; -#[cfg(feature = "hyper-server")] -use crate::hyper_servers::SessionId; use crate::mcp_traits::mcp_handler::McpServerHandler; use crate::mcp_traits::mcp_server::McpServer; +#[cfg(feature = "hyper-server")] +use rust_mcp_transport::SessionId; /// Struct representing the runtime core of the MCP server, handling transport and client details pub struct ServerRuntime { diff --git a/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime/mcp_server_runtime.rs b/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime/mcp_server_runtime.rs index dd9e98f..51eba77 100644 --- a/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime/mcp_server_runtime.rs +++ b/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime/mcp_server_runtime.rs @@ -12,7 +12,7 @@ use rust_mcp_transport::Transport; use super::ServerRuntime; #[cfg(feature = "hyper-server")] -use crate::hyper_servers::SessionId; +use rust_mcp_transport::SessionId; use crate::{ error::SdkResult, diff --git a/crates/rust-mcp-sdk/tests/test_client_runtime.rs b/crates/rust-mcp-sdk/tests/test_client_runtime.rs index c7804e5..c8b3b17 100644 --- a/crates/rust-mcp-sdk/tests/test_client_runtime.rs +++ b/crates/rust-mcp-sdk/tests/test_client_runtime.rs @@ -1,8 +1,7 @@ -use common::{test_client_info, TestClientHandler, NPX_SERVER_EVERYTHING}; -use rust_mcp_sdk::{mcp_client::client_runtime, McpClient, StdioTransport, TransportOptions}; - #[cfg(unix)] use common::UVX_SERVER_GIT; +use common::{test_client_info, TestClientHandler, NPX_SERVER_EVERYTHING}; +use rust_mcp_sdk::{mcp_client::client_runtime, McpClient, StdioTransport, TransportOptions}; #[path = "common/common.rs"] pub mod common; diff --git a/crates/rust-mcp-sdk/tests/test_server_sse.rs b/crates/rust-mcp-sdk/tests/test_server_sse.rs new file mode 100644 index 0000000..eee57fc --- /dev/null +++ b/crates/rust-mcp-sdk/tests/test_server_sse.rs @@ -0,0 +1,5 @@ +#[path = "common/common.rs"] +pub mod common; + +#[tokio::test] +async fn tets_server_sse() {} diff --git a/crates/rust-mcp-transport/src/lib.rs b/crates/rust-mcp-transport/src/lib.rs index a2ec12b..31d810d 100644 --- a/crates/rust-mcp-transport/src/lib.rs +++ b/crates/rust-mcp-transport/src/lib.rs @@ -16,3 +16,6 @@ pub use message_dispatcher::*; pub use sse::*; pub use stdio::*; pub use transport::*; + +// Type alias for session identifier, represented as a String +pub type SessionId = String; diff --git a/crates/rust-mcp-transport/src/sse.rs b/crates/rust-mcp-transport/src/sse.rs index 4d8b100..554826d 100644 --- a/crates/rust-mcp-transport/src/sse.rs +++ b/crates/rust-mcp-transport/src/sse.rs @@ -12,8 +12,8 @@ use crate::error::{TransportError, TransportResult}; use crate::mcp_stream::MCPStream; use crate::message_dispatcher::MessageDispatcher; use crate::transport::Transport; -use crate::utils::CancellationTokenSource; -use crate::{IoStream, McpDispatch, TransportOptions}; +use crate::utils::{endpoint_with_session_id, CancellationTokenSource}; +use crate::{IoStream, McpDispatch, SessionId, TransportOptions}; pub struct SseTransport { shutdown_source: tokio::sync::RwLock>, @@ -47,6 +47,10 @@ impl SseTransport { is_shut_down: Mutex::new(false), }) } + + pub fn message_endpoint(endpoint: &str, session_id: &SessionId) -> String { + endpoint_with_session_id(endpoint, session_id) + } } #[async_trait] diff --git a/crates/rust-mcp-transport/src/utils.rs b/crates/rust-mcp-transport/src/utils.rs index 68bc604..beff27d 100644 --- a/crates/rust-mcp-transport/src/utils.rs +++ b/crates/rust-mcp-transport/src/utils.rs @@ -13,7 +13,10 @@ pub(crate) use writable_channel::*; use rust_mcp_schema::schema_utils::SdkError; use tokio::time::{timeout, Duration}; -use crate::error::{TransportError, TransportResult}; +use crate::{ + error::{TransportError, TransportResult}, + SessionId, +}; pub async fn await_timeout(operation: F, timeout_duration: Duration) -> TransportResult where @@ -41,6 +44,55 @@ pub fn extract_origin(url: &str) -> Option { Some(format!("{}://{}", scheme, host_port)) } +/// Adds a session ID as a query parameter to a given endpoint URL. +/// +/// # Arguments +/// * `endpoint` - The base URL or endpoint (e.g., "/messages") +/// * `session_id` - The session ID to append as a query parameter +/// +/// # Returns +/// A String containing the endpoint with the session ID added as a query parameter +/// +/// # Examples +/// ``` +/// assert_eq!(endpoint_with_session_id("/messages", "AAA"), "/messages?sessionId=AAA"); +/// assert_eq!(endpoint_with_session_id("/messages?foo=bar&baz=qux", "AAA"), "/messages?foo=bar&baz=qux&sessionId=AAA"); +/// assert_eq!(endpoint_with_session_id("/messages#section1", "AAA"), "/messages?sessionId=AAA#section1"); +/// assert_eq!(endpoint_with_session_id("/messages?key=value#section2", "AAA"), "/messages?key=value&sessionId=AAA#section2"); +/// assert_eq!(endpoint_with_session_id("/", "AAA"), "/?sessionId=AAA"); +/// assert_eq!(endpoint_with_session_id("", "AAA"), "/?sessionId=AAA"); +/// ``` +pub fn endpoint_with_session_id(endpoint: &str, session_id: &SessionId) -> String { + // Handle empty endpoint + let base = if endpoint.is_empty() { "/" } else { endpoint }; + + // Split fragment if it exists + let (path_and_query, fragment) = if let Some((p, f)) = base.split_once('#') { + (p, Some(f)) + } else { + (base, None) + }; + + // Split path and query + let (path, query) = if let Some((p, q)) = path_and_query.split_once('?') { + (p, Some(q)) + } else { + (path_and_query, None) + }; + + // Build the query string + let new_query = match query { + Some(q) if !q.is_empty() => format!("{}&sessionId={}", q, session_id), + _ => format!("sessionId={}", session_id), + }; + + // Construct final URL + match fragment { + Some(f) => format!("{}?{}#{}", path, new_query, f), + None => format!("{}?{}", path, new_query), + } +} + #[cfg(test)] mod tests { use super::*; From 7f933bd0d3891577809bb5fa6e24717da67917c2 Mon Sep 17 00:00:00 2001 From: Ali Hashemi Date: Wed, 21 May 2025 20:44:50 -0300 Subject: [PATCH 2/5] chore: dynamic route for messages --- .../hyper_servers/routes/messages_routes.rs | 18 ++++++---- .../src/hyper_servers/routes/sse_routes.rs | 1 - crates/rust-mcp-sdk/src/utils.rs | 33 +++++++++++++++++++ 3 files changed, 44 insertions(+), 8 deletions(-) diff --git a/crates/rust-mcp-sdk/src/hyper_servers/routes/messages_routes.rs b/crates/rust-mcp-sdk/src/hyper_servers/routes/messages_routes.rs index 10e2eb9..55d15b1 100644 --- a/crates/rust-mcp-sdk/src/hyper_servers/routes/messages_routes.rs +++ b/crates/rust-mcp-sdk/src/hyper_servers/routes/messages_routes.rs @@ -1,6 +1,9 @@ -use crate::hyper_servers::{ - app_state::AppState, - error::{TransportServerError, TransportServerResult}, +use crate::{ + hyper_servers::{ + app_state::AppState, + error::{TransportServerError, TransportServerResult}, + }, + utils::remove_query_and_hash, }; use axum::{ extract::{Query, State}, @@ -11,10 +14,11 @@ use axum::{ use std::{collections::HashMap, sync::Arc}; use tokio::io::AsyncWriteExt; -const SSE_MESSAGES_PATH: &str = "/messages"; - -pub fn routes(_state: Arc) -> Router> { - Router::new().route(SSE_MESSAGES_PATH, post(handle_messages)) +pub fn routes(state: Arc) -> Router> { + Router::new().route( + remove_query_and_hash(&state.sse_message_endpoint).as_str(), + post(handle_messages), + ) } pub async fn handle_messages( diff --git a/crates/rust-mcp-sdk/src/hyper_servers/routes/sse_routes.rs b/crates/rust-mcp-sdk/src/hyper_servers/routes/sse_routes.rs index 8085474..b6e98f0 100644 --- a/crates/rust-mcp-sdk/src/hyper_servers/routes/sse_routes.rs +++ b/crates/rust-mcp-sdk/src/hyper_servers/routes/sse_routes.rs @@ -27,7 +27,6 @@ use tokio::{ }; use tokio_stream::StreamExt; -const SSE_MESSAGES_PATH: &str = "/messages"; const CLIENT_PING_TIMEOUT: Duration = Duration::from_secs(2); const DUPLEX_BUFFER_SIZE: usize = 8192; diff --git a/crates/rust-mcp-sdk/src/utils.rs b/crates/rust-mcp-sdk/src/utils.rs index 85ad72e..04bed61 100644 --- a/crates/rust-mcp-sdk/src/utils.rs +++ b/crates/rust-mcp-sdk/src/utils.rs @@ -22,3 +22,36 @@ pub fn format_assertion_message(entity: &str, capability: &str, method_name: &st entity, capability, method_name ) } + +/// Removes query string and hash fragment from a URL, returning the base path. +/// +/// # Arguments +/// * `endpoint` - The URL or endpoint to process (e.g., "/messages?foo=bar#section1") +/// +/// # Returns +/// A String containing the base path without query parameters or fragment +/// +/// # Examples +/// ``` +/// assert_eq!(remove_query_and_hash("/messages"), "/messages"); +/// assert_eq!(remove_query_and_hash("/messages?foo=bar&baz=qux"), "/messages"); +/// assert_eq!(remove_query_and_hash("/messages#section1"), "/messages"); +/// assert_eq!(remove_query_and_hash("/messages?key=value#section2"), "/messages"); +/// assert_eq!(remove_query_and_hash("/"), "/"); +/// ``` +pub fn remove_query_and_hash(endpoint: &str) -> String { + // Split off fragment (if any) and take the first part + let without_fragment = endpoint.split_once('#').map_or(endpoint, |(path, _)| path); + + // Split off query string (if any) and take the first part + let without_query = without_fragment + .split_once('?') + .map_or(without_fragment, |(path, _)| path); + + // Return the base path + if without_query.is_empty() { + "/".to_string() + } else { + without_query.to_string() + } +} From bc61a85ce274121e560d46b00491811c10508892 Mon Sep 17 00:00:00 2001 From: Ali Hashemi Date: Thu, 22 May 2025 08:24:34 -0300 Subject: [PATCH 3/5] chore add tests --- crates/rust-mcp-sdk/src/utils.rs | 29 +++++++++++++------- crates/rust-mcp-transport/src/utils.rs | 37 +++++++++++++++++++------- 2 files changed, 47 insertions(+), 19 deletions(-) diff --git a/crates/rust-mcp-sdk/src/utils.rs b/crates/rust-mcp-sdk/src/utils.rs index 04bed61..6b27d9b 100644 --- a/crates/rust-mcp-sdk/src/utils.rs +++ b/crates/rust-mcp-sdk/src/utils.rs @@ -30,16 +30,8 @@ pub fn format_assertion_message(entity: &str, capability: &str, method_name: &st /// /// # Returns /// A String containing the base path without query parameters or fragment -/// -/// # Examples -/// ``` -/// assert_eq!(remove_query_and_hash("/messages"), "/messages"); -/// assert_eq!(remove_query_and_hash("/messages?foo=bar&baz=qux"), "/messages"); -/// assert_eq!(remove_query_and_hash("/messages#section1"), "/messages"); -/// assert_eq!(remove_query_and_hash("/messages?key=value#section2"), "/messages"); -/// assert_eq!(remove_query_and_hash("/"), "/"); /// ``` -pub fn remove_query_and_hash(endpoint: &str) -> String { +pub(crate) fn remove_query_and_hash(endpoint: &str) -> String { // Split off fragment (if any) and take the first part let without_fragment = endpoint.split_once('#').map_or(endpoint, |(path, _)| path); @@ -55,3 +47,22 @@ pub fn remove_query_and_hash(endpoint: &str) -> String { without_query.to_string() } } + +#[cfg(test)] +mod tests { + use super::*; + #[test] + fn tets_remove_query_and_hash() { + assert_eq!(remove_query_and_hash("/messages"), "/messages"); + assert_eq!( + remove_query_and_hash("/messages?foo=bar&baz=qux"), + "/messages" + ); + assert_eq!(remove_query_and_hash("/messages#section1"), "/messages"); + assert_eq!( + remove_query_and_hash("/messages?key=value#section2"), + "/messages" + ); + assert_eq!(remove_query_and_hash("/"), "/"); + } +} diff --git a/crates/rust-mcp-transport/src/utils.rs b/crates/rust-mcp-transport/src/utils.rs index beff27d..4e8a2d7 100644 --- a/crates/rust-mcp-transport/src/utils.rs +++ b/crates/rust-mcp-transport/src/utils.rs @@ -53,16 +53,7 @@ pub fn extract_origin(url: &str) -> Option { /// # Returns /// A String containing the endpoint with the session ID added as a query parameter /// -/// # Examples -/// ``` -/// assert_eq!(endpoint_with_session_id("/messages", "AAA"), "/messages?sessionId=AAA"); -/// assert_eq!(endpoint_with_session_id("/messages?foo=bar&baz=qux", "AAA"), "/messages?foo=bar&baz=qux&sessionId=AAA"); -/// assert_eq!(endpoint_with_session_id("/messages#section1", "AAA"), "/messages?sessionId=AAA#section1"); -/// assert_eq!(endpoint_with_session_id("/messages?key=value#section2", "AAA"), "/messages?key=value&sessionId=AAA#section2"); -/// assert_eq!(endpoint_with_session_id("/", "AAA"), "/?sessionId=AAA"); -/// assert_eq!(endpoint_with_session_id("", "AAA"), "/?sessionId=AAA"); -/// ``` -pub fn endpoint_with_session_id(endpoint: &str, session_id: &SessionId) -> String { +pub(crate) fn endpoint_with_session_id(endpoint: &str, session_id: &SessionId) -> String { // Handle empty endpoint let base = if endpoint.is_empty() { "/" } else { endpoint }; @@ -131,4 +122,30 @@ mod tests { fn test_extract_origin_empty_string() { assert_eq!(extract_origin(""), None); } + + #[test] + fn test_endpoint_with_session_id() { + let session_id: SessionId = "AAA".to_string(); + assert_eq!( + endpoint_with_session_id("/messages", &session_id), + "/messages?sessionId=AAA" + ); + assert_eq!( + endpoint_with_session_id("/messages?foo=bar&baz=qux", &session_id), + "/messages?foo=bar&baz=qux&sessionId=AAA" + ); + assert_eq!( + endpoint_with_session_id("/messages#section1", &session_id), + "/messages?sessionId=AAA#section1" + ); + assert_eq!( + endpoint_with_session_id("/messages?key=value#section2", &session_id), + "/messages?key=value&sessionId=AAA#section2" + ); + assert_eq!( + endpoint_with_session_id("/", &session_id), + "/?sessionId=AAA" + ); + assert_eq!(endpoint_with_session_id("", &session_id), "/?sessionId=AAA"); + } } From 48fdf0f1c22509732da0a2c2eb782e2cc120a746 Mon Sep 17 00:00:00 2001 From: Ali Hashemi Date: Thu, 22 May 2025 11:34:29 -0300 Subject: [PATCH 4/5] feat: accpt custom session id generator --- Cargo.lock | 1 + crates/rust-mcp-sdk/Cargo.toml | 1 + .../rust-mcp-sdk/src/hyper_servers/server.rs | 14 +- crates/rust-mcp-sdk/tests/common/common.rs | 10 + .../rust-mcp-sdk/tests/common/test_server.rs | 118 +++++++++++ crates/rust-mcp-sdk/tests/test_server_sse.rs | 191 +++++++++++++++++- 6 files changed, 330 insertions(+), 5 deletions(-) create mode 100644 crates/rust-mcp-sdk/tests/common/test_server.rs diff --git a/Cargo.lock b/Cargo.lock index a80f199..8a833ba 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1631,6 +1631,7 @@ dependencies = [ "axum-server", "futures", "hyper 1.6.0", + "reqwest", "rust-mcp-macros", "rust-mcp-schema", "rust-mcp-transport", diff --git a/crates/rust-mcp-sdk/Cargo.toml b/crates/rust-mcp-sdk/Cargo.toml index ab88674..b4d7663 100644 --- a/crates/rust-mcp-sdk/Cargo.toml +++ b/crates/rust-mcp-sdk/Cargo.toml @@ -32,6 +32,7 @@ tracing.workspace = true hyper = { version = "1.6.0" } [dev-dependencies] +reqwest = { workspace = true, features = ["stream"] } tracing-subscriber = { workspace = true, features = [ "env-filter", "std", diff --git a/crates/rust-mcp-sdk/src/hyper_servers/server.rs b/crates/rust-mcp-sdk/src/hyper_servers/server.rs index b7394a7..cd005a0 100644 --- a/crates/rust-mcp-sdk/src/hyper_servers/server.rs +++ b/crates/rust-mcp-sdk/src/hyper_servers/server.rs @@ -12,7 +12,7 @@ use super::{ app_state::AppState, error::{TransportServerError, TransportServerResult}, routes::app_routes, - InMemorySessionStore, UuidGenerator, + IdGenerator, InMemorySessionStore, UuidGenerator, }; use axum::Router; use rust_mcp_schema::InitializeResult; @@ -31,7 +31,7 @@ const DEFAULT_MESSAGES_ENDPOINT: &str = "/messages"; pub struct HyperServerOptions { /// Hostname or IP address the server will bind to (default: "localhost") pub host: String, - /// Hostname or IP address the server will bind to (default: "localhost") + /// Hostname or IP address the server will bind to (default: "8080") pub port: u16, /// Optional custom path for the Server-Sent Events (SSE) endpoint (default: `/sse`) pub custom_sse_endpoint: Option, @@ -49,6 +49,8 @@ pub struct HyperServerOptions { pub ssl_key_path: Option, /// Shared transport configuration used by the server pub transport_options: Arc, + /// Optional thread-safe session id generator to generate unique session IDs. + pub session_id_generator: Option>, } impl HyperServerOptions { @@ -148,6 +150,7 @@ impl Default for HyperServerOptions { enable_ssl: false, ssl_cert_path: None, ssl_key_path: None, + session_id_generator: None, } } } @@ -174,11 +177,14 @@ impl HyperServer { pub(crate) fn new( server_details: InitializeResult, handler: Arc, - server_options: HyperServerOptions, + mut server_options: HyperServerOptions, ) -> Self { let state: Arc = Arc::new(AppState { session_store: Arc::new(InMemorySessionStore::new()), - id_generator: Arc::new(UuidGenerator {}), + id_generator: server_options + .session_id_generator + .take() + .map_or(Arc::new(UuidGenerator {}), |g| Arc::clone(&g)), server_details: Arc::new(server_details), handler, ping_interval: server_options.ping_interval, diff --git a/crates/rust-mcp-sdk/tests/common/common.rs b/crates/rust-mcp-sdk/tests/common/common.rs index d896a56..6746270 100644 --- a/crates/rust-mcp-sdk/tests/common/common.rs +++ b/crates/rust-mcp-sdk/tests/common/common.rs @@ -1,8 +1,10 @@ +mod test_server; use async_trait::async_trait; use rust_mcp_schema::{ ClientCapabilities, Implementation, InitializeRequestParams, JSONRPC_VERSION, }; use rust_mcp_sdk::mcp_client::ClientHandler; +pub use test_server::*; pub const NPX_SERVER_EVERYTHING: &str = "@modelcontextprotocol/server-everything"; @@ -24,3 +26,11 @@ pub struct TestClientHandler; #[async_trait] impl ClientHandler for TestClientHandler {} + +pub fn sse_event(sse_raw: &str) -> String { + sse_raw.replace("event: ", "") +} + +pub fn sse_data(sse_raw: &str) -> String { + sse_raw.replace("data: ", "") +} diff --git a/crates/rust-mcp-sdk/tests/common/test_server.rs b/crates/rust-mcp-sdk/tests/common/test_server.rs new file mode 100644 index 0000000..dcb4e1b --- /dev/null +++ b/crates/rust-mcp-sdk/tests/common/test_server.rs @@ -0,0 +1,118 @@ +use async_trait::async_trait; +use tokio_stream::StreamExt; + +use rust_mcp_schema::{ + Implementation, InitializeResult, ServerCapabilities, ServerCapabilitiesTools, + LATEST_PROTOCOL_VERSION, +}; +use rust_mcp_sdk::{ + mcp_server::{hyper_server, HyperServer, HyperServerOptions, IdGenerator, ServerHandler}, + McpServer, SessionId, +}; +use std::sync::RwLock; +use std::time::Duration; +use tokio::time::timeout; + +pub const INITIALIZE_REQUEST: &str = r#"{"jsonrpc":"2.0","id":0,"method":"initialize","params":{"protocolVersion":"2.0","capabilities":{"sampling":{},"roots":{"listChanged":true}},"clientInfo":{"name":"reqwest-test","version":"0.1.0"}}}"#; +pub const PING_REQUEST: &str = r#"{"jsonrpc":"2.0","id":1,"method":"ping"}"#; + +pub fn test_server_details() -> InitializeResult { + InitializeResult { + // server name and version + server_info: Implementation { + name: "Test MCP Server".to_string(), + version: "0.1.0".to_string(), + }, + capabilities: ServerCapabilities { + // indicates that server support mcp tools + tools: Some(ServerCapabilitiesTools { list_changed: None }), + ..Default::default() // Using default values for other fields + }, + meta: None, + instructions: Some("server instructions...".to_string()), + protocol_version: LATEST_PROTOCOL_VERSION.to_string(), + } +} + +pub struct TestServerHandler; + +#[async_trait] +impl ServerHandler for TestServerHandler { + async fn on_server_started(&self, runtime: &dyn McpServer) { + let _ = runtime + .stderr_message("Server started successfully".into()) + .await; + } +} + +pub fn create_test_server(options: HyperServerOptions) -> HyperServer { + hyper_server::create_server(test_server_details(), TestServerHandler {}, options) +} + +// Tests the session ID generator, ensuring it returns a sequence of predefined session IDs. +pub struct TestIdGenerator { + constant_ids: Vec, + generated: RwLock, +} + +impl TestIdGenerator { + pub fn new(constant_ids: Vec) -> Self { + TestIdGenerator { + constant_ids, + generated: RwLock::new(0), + } + } +} + +impl IdGenerator for TestIdGenerator { + fn generate(&self) -> SessionId { + let mut lock = self.generated.write().unwrap(); + *lock += 1; + if *lock > self.constant_ids.len() { + *lock = 1; + } + self.constant_ids[*lock - 1].to_owned() + } +} + +pub async fn collect_sse_lines( + response: reqwest::Response, + line_count: usize, + read_timeout: Duration, +) -> Result, Box> { + let mut collected_lines = Vec::new(); + let mut stream = response.bytes_stream(); + + let result = timeout(read_timeout, async { + while let Some(chunk) = stream.next().await { + let chunk = chunk.map_err(|e| Box::new(e) as Box)?; + let chunk_str = String::from_utf8_lossy(&chunk); + + // Split the chunk into lines + let lines: Vec<&str> = chunk_str.lines().collect(); + + // Add each line to the collected_lines vector + for line in lines { + collected_lines.push(line.to_string()); + + // Check if we have collected 5 lines + if collected_lines.len() >= line_count { + return Ok(collected_lines); + } + } + } + // If the stream ends before collecting 5 lines, return what we have + Ok(collected_lines) + }) + .await; + + // Handle timeout or stream result + match result { + Ok(Ok(lines)) => Ok(lines), + Ok(Err(e)) => Err(e), + Err(_) => Err(Box::new(std::io::Error::new( + std::io::ErrorKind::TimedOut, + "Timed out waiting for 5 lines", + ))), + } +} diff --git a/crates/rust-mcp-sdk/tests/test_server_sse.rs b/crates/rust-mcp-sdk/tests/test_server_sse.rs index eee57fc..4196b73 100644 --- a/crates/rust-mcp-sdk/tests/test_server_sse.rs +++ b/crates/rust-mcp-sdk/tests/test_server_sse.rs @@ -1,5 +1,194 @@ +use std::{sync::Arc, time::Duration}; + +use common::{ + collect_sse_lines, create_test_server, sse_data, sse_event, TestIdGenerator, INITIALIZE_REQUEST, +}; +use reqwest::Client; +use rust_mcp_schema::{ + schema_utils::{ResultFromServer, ServerMessage}, + ServerResult, +}; +use rust_mcp_sdk::mcp_server::HyperServerOptions; +use tokio::time::sleep; + #[path = "common/common.rs"] pub mod common; #[tokio::test] -async fn tets_server_sse() {} +async fn tets_sse_endpoint_event_default() { + let server_options = HyperServerOptions { + session_id_generator: Some(Arc::new(TestIdGenerator::new(vec![ + "AAA-BBB-CCC".to_string() + ]))), + ..Default::default() + }; + + let base_url = format!("http://{}:{}", server_options.host, server_options.port); + + let server_endpoint = format!("{}{}", base_url, server_options.sse_endpoint()); + + tokio::spawn(async move { + let server = create_test_server(server_options); + server.start().await.unwrap(); + }); + + sleep(Duration::from_millis(750)).await; + + let client = Client::new(); + println!("connecting to : {}", server_endpoint); + // Act: Connect to the SSE endpoint and read the event stream + let response = client + .get(server_endpoint) + .header("Accept", "text/event-stream") + .send() + .await + .expect("Failed to connect to SSE endpoint"); + + assert_eq!( + response.headers().get("content-type").map(|v| v.as_bytes()), + Some(b"text/event-stream" as &[u8]), + "Response content-type should be text/event-stream" + ); + + let lines = collect_sse_lines(response, 2, Duration::from_secs(5)) + .await + .unwrap(); + + assert_eq!(sse_event(&lines[0]), "endpoint"); + assert_eq!(sse_data(&lines[1]), "/messages?sessionId=AAA-BBB-CCC"); + + let message_endpoint = format!("{}{}", base_url, sse_data(&lines[1])); + let res = client + .post(message_endpoint) + .header("Content-Type", "application/json") + .body(INITIALIZE_REQUEST.to_string()) + .send() + .await + .unwrap(); + assert!(res.status().is_success()); +} + +#[tokio::test] +async fn tets_sse_message_endpoint_query_hash() { + let server_options = HyperServerOptions { + custom_messages_endpoint: Some( + "/custom-msg-endpoint?something=true&otherthing=false#section-59".to_string(), + ), + session_id_generator: Some(Arc::new(TestIdGenerator::new(vec![ + "AAA-BBB-CCC".to_string() + ]))), + ..Default::default() + }; + + let base_url = format!("http://{}:{}", server_options.host, server_options.port); + + let server_endpoint = format!("{}{}", base_url, server_options.sse_endpoint()); + + tokio::spawn(async move { + let server = create_test_server(server_options); + server.start().await.unwrap(); + }); + + sleep(Duration::from_millis(750)).await; + + let client = Client::new(); + println!("connecting to : {}", server_endpoint); + // Act: Connect to the SSE endpoint and read the event stream + let response = client + .get(server_endpoint) + .header("Accept", "text/event-stream") + .send() + .await + .expect("Failed to connect to SSE endpoint"); + + assert_eq!( + response.headers().get("content-type").map(|v| v.as_bytes()), + Some(b"text/event-stream" as &[u8]), + "Response content-type should be text/event-stream" + ); + + let lines = collect_sse_lines(response, 2, Duration::from_secs(5)) + .await + .unwrap(); + + assert_eq!(sse_event(&lines[0]), "endpoint"); + assert_eq!( + sse_data(&lines[1]), + "/custom-msg-endpoint?something=true&otherthing=false&sessionId=AAA-BBB-CCC#section-59" + ); + + let message_endpoint = format!("{}{}", base_url, sse_data(&lines[1])); + let res = client + .post(message_endpoint) + .header("Content-Type", "application/json") + .body(INITIALIZE_REQUEST.to_string()) + .send() + .await + .unwrap(); + assert!(res.status().is_success()); +} + +#[tokio::test] +async fn tets_sse_custom_message_endpoint() { + let server_options = HyperServerOptions { + custom_messages_endpoint: Some( + "/custom-msg-endpoint?something=true&otherthing=false#section-59".to_string(), + ), + session_id_generator: Some(Arc::new(TestIdGenerator::new(vec![ + "AAA-BBB-CCC".to_string() + ]))), + ..Default::default() + }; + + let base_url = format!("http://{}:{}", server_options.host, server_options.port); + + let server_endpoint = format!("{}{}", base_url, server_options.sse_endpoint()); + + tokio::spawn(async move { + let server = create_test_server(server_options); + server.start().await.unwrap(); + }); + + sleep(Duration::from_millis(750)).await; + + let client = Client::new(); + println!("connecting to : {}", server_endpoint); + // Act: Connect to the SSE endpoint and read the event stream + let response = client + .get(server_endpoint) + .header("Accept", "text/event-stream") + .send() + .await + .expect("Failed to connect to SSE endpoint"); + + assert_eq!( + response.headers().get("content-type").map(|v| v.as_bytes()), + Some(b"text/event-stream" as &[u8]), + "Response content-type should be text/event-stream" + ); + + let message_endpoint = format!( + "{}{}", + base_url, + "/custom-msg-endpoint?something=true&otherthing=false&sessionId=AAA-BBB-CCC#section-59" + ); + let res = client + .post(message_endpoint) + .header("Content-Type", "application/json") + .body(INITIALIZE_REQUEST.to_string()) + .send() + .await + .unwrap(); + assert!(res.status().is_success()); + + let lines = collect_sse_lines(response, 5, Duration::from_secs(5)) + .await + .unwrap(); + + let init_response = sse_data(&lines[3]); + let result = serde_json::from_str::(&init_response).unwrap(); + + assert!(matches!(result, ServerMessage::Response(response) + if matches!(&response.result, ResultFromServer::ServerResult(server_result) + if matches!(server_result, ServerResult::InitializeResult(init_result) if init_result.server_info.name == "Test MCP Server".to_string())))); +} From e101bd1c51bfd593c0f70dc1e577e0f3827b695c Mon Sep 17 00:00:00 2001 From: Ali Hashemi Date: Fri, 23 May 2025 18:21:19 -0300 Subject: [PATCH 5/5] feat: implement shutdown signal to server - improved sse architecture --- Cargo.lock | 36 +-- Cargo.toml | 2 +- Makefile.toml | 2 +- crates/rust-mcp-macros/Cargo.toml | 13 + crates/rust-mcp-macros/README.md | 12 + crates/rust-mcp-macros/src/lib.rs | 272 ++++++++++++------ crates/rust-mcp-macros/tests/macro_test.rs | 55 ++++ .../rust-mcp-sdk/src/hyper_servers/server.rs | 57 +++- .../src/hyper_servers/session_store.rs | 4 + .../hyper_servers/session_store/in_memory.rs | 9 + crates/rust-mcp-sdk/src/utils.rs | 1 + crates/rust-mcp-sdk/tests/test_server_sse.rs | 31 +- .../src/utils/readable_channel.rs | 1 - examples/hello-world-mcp-server/src/tools.rs | 12 +- 14 files changed, 386 insertions(+), 121 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 8a833ba..060ec84 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -239,9 +239,9 @@ dependencies = [ [[package]] name = "bitflags" -version = "2.9.0" +version = "2.9.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5c8214115b7bf84099f1309324e63141d4c5d7cc26862f97a0a857dbefe165bd" +checksum = "1b8e56985ec62d17e9c1001dc89c88ecd7dc08e47eba5ec7c29c7b5eeecde967" [[package]] name = "bumpalo" @@ -257,9 +257,9 @@ checksum = "d71b6127be86fdcfddb610f7182ac57211d4b18a3e9c82eb2d17662f2227ad6a" [[package]] name = "cc" -version = "1.2.22" +version = "1.2.23" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "32db95edf998450acc7881c932f94cd9b05c87b4b2599e8bab064753da4acfd1" +checksum = "5f4ac86a9e5bc1e2b3449ab9d7d3a6a405e3d1bb28d7b9be8614f55846ae3766" dependencies = [ "jobserver", "libc", @@ -400,9 +400,9 @@ checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f" [[package]] name = "errno" -version = "0.3.11" +version = "0.3.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "976dd42dc7e85965fe702eb8164f21f450704bdde31faefd6471dba214cb594e" +checksum = "cea14ef9355e3beab063703aa9dab15afd25f0667c341310c1e5274bb1d0da18" dependencies = [ "libc", "windows-sys 0.59.0", @@ -917,9 +917,9 @@ dependencies = [ [[package]] name = "hyper-util" -version = "0.1.11" +version = "0.1.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "497bbc33a26fdd4af9ed9c70d63f61cf56a938375fbb32df34db9b1cd6d643f2" +checksum = "cf9f1e950e0d9d1d3c47184416723cf29c0d1f93bd8cccf37e4beb6b44f31710" dependencies = [ "bytes", "futures-channel", @@ -984,9 +984,9 @@ checksum = "00210d6893afc98edb752b664b8890f0ef174c8adbb8d0be9710fa66fbbf72d3" [[package]] name = "icu_properties" -version = "2.0.0" +version = "2.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2549ca8c7241c82f59c80ba2a6f415d931c5b58d24fb8412caa1a1f02c49139a" +checksum = "016c619c1eeb94efb86809b015c58f479963de65bdb6253345c1a1276f22e32b" dependencies = [ "displaydoc", "icu_collections", @@ -1000,9 +1000,9 @@ dependencies = [ [[package]] name = "icu_properties_data" -version = "2.0.0" +version = "2.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8197e866e47b68f8f7d95249e172903bec06004b18b2937f1095d40a0c57de04" +checksum = "298459143998310acd25ffe6810ed544932242d3f07083eee1084d83a71bd632" [[package]] name = "icu_provider" @@ -1614,9 +1614,9 @@ dependencies = [ [[package]] name = "rust-mcp-schema" -version = "0.4.0" +version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "868d31d0ae0376ba45786eac9058771da06839e83bb961ac7e5997ab3910f086" +checksum = "49212f1da431236217031807377e6296db06a270224698c426afa94e5dacd8e7" dependencies = [ "serde", "serde_json", @@ -1749,9 +1749,9 @@ dependencies = [ [[package]] name = "rustversion" -version = "1.0.20" +version = "1.0.21" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eded382c5f5f786b989652c49544c4877d9f015cc22e145a5ea8ea66c2921cd2" +checksum = "8a0d197bd2c9dc6e53b84da9556a69ba4cdfab8619eb41a8bd1cc2027a0f6b1d" [[package]] name = "ryu" @@ -2515,9 +2515,9 @@ dependencies = [ [[package]] name = "windows-result" -version = "0.3.3" +version = "0.3.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4b895b5356fc36103d0f64dd1e94dfa7ac5633f1c9dd6e80fe9ec4adef69e09d" +checksum = "56f42bd332cc6c8eac5af113fc0c1fd6a8fd2aa08a0119358686e5160d0586c6" dependencies = [ "windows-link", ] diff --git a/Cargo.toml b/Cargo.toml index 986e877..f9e897b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -21,7 +21,7 @@ rust-mcp-sdk = { path = "crates/rust-mcp-sdk", default-features = false } rust-mcp-macros = { version = "0.2.1", path = "crates/rust-mcp-macros" } # External crates -rust-mcp-schema = { version = "0.4" } +rust-mcp-schema = { version = "0.5" } futures = { version = "0.3" } tokio = { version = "1.4", features = ["full"] } serde = { version = "1.0", features = ["derive", "serde_derive"] } diff --git a/Makefile.toml b/Makefile.toml index 8d11b2d..f76879a 100644 --- a/Makefile.toml +++ b/Makefile.toml @@ -8,7 +8,7 @@ args = ["fmt", "--all", "--", "--check"] [tasks.clippy] command = "cargo" -args = ["clippy", "--workspace", "--all-targets", "--all-features"] +args = ["clippy", "--workspace", "--all-targets"] [tasks.test] install_crate = "nextest" diff --git a/crates/rust-mcp-macros/Cargo.toml b/crates/rust-mcp-macros/Cargo.toml index e63f2ae..2b272c8 100644 --- a/crates/rust-mcp-macros/Cargo.toml +++ b/crates/rust-mcp-macros/Cargo.toml @@ -28,3 +28,16 @@ workspace = true [lib] proc-macro = true + + +[features] +# defalt features +default = ["2025_03_26"] # Default features + +# activates the latest MCP schema version, this will be updated once a new version of schema is published +latest = ["2025_03_26"] + +# enabled mcp schema version 2025_03_26 +2025_03_26 = ["rust-mcp-schema/2025_03_26"] +# enabled mcp schema version 2024_11_05 +2024_11_05 = ["rust-mcp-schema/2024_11_05"] diff --git a/crates/rust-mcp-macros/README.md b/crates/rust-mcp-macros/README.md index 5246a5b..6f6c956 100644 --- a/crates/rust-mcp-macros/README.md +++ b/crates/rust-mcp-macros/README.md @@ -19,6 +19,10 @@ The `mcp_tool` macro generates an implementation for the annotated struct that i #[mcp_tool( name = "write_file", description = "Create a new file or completely overwrite an existing file with new content." + destructive_hint = false + idempotent_hint = false + open_world_hint = false + read_only_hint = false )] #[derive(rust_mcp_macros::JsonSchema)] pub struct WriteFileTool { @@ -60,3 +64,11 @@ fn main() { Check out [rust-mcp-sdk](https://github.com/rust-mcp-stack/rust-mcp-sdk) , a high-performance, asynchronous toolkit for building MCP servers and clients. Focus on your app's logic while [rust-mcp-sdk](https://github.com/rust-mcp-stack/rust-mcp-sdk) takes care of the rest! --- + + +**Note**: The following attributes are available only in version `2025_03_26` and later of the MCP Schema, and their values will be used in the [annotations](https://github.com/rust-mcp-stack/rust-mcp-schema/blob/main/src/generated_schema/2025_03_26/mcp_schema.rs#L5557) attribute of the *[Tool struct](https://github.com/rust-mcp-stack/rust-mcp-schema/blob/main/src/generated_schema/2025_03_26/mcp_schema.rs#L5554-L5566). + +- `destructive_hint` +- `idempotent_hint` +- `open_world_hint` +- `read_only_hint` diff --git a/crates/rust-mcp-macros/src/lib.rs b/crates/rust-mcp-macros/src/lib.rs index d123b1e..3f35b03 100644 --- a/crates/rust-mcp-macros/src/lib.rs +++ b/crates/rust-mcp-macros/src/lib.rs @@ -19,11 +19,37 @@ use utils::{is_option, renamed_field, type_to_json_schema}; /// * `name` - An optional string representing the tool's name. /// * `description` - An optional string describing the tool. /// +#[cfg(feature = "2024_11_05")] struct McpToolMacroAttributes { name: Option, description: Option, } +/// Represents the attributes for the `mcp_tool` procedural macro. +/// +/// This struct parses and validates the `name` and `description` attributes provided +/// to the `mcp_tool` macro. Both attributes are required and must not be empty strings. +/// +/// # Fields +/// * `name` - An optional string representing the tool's name. +/// * `description` - An optional string describing the tool. +/// * `destructive_hint` - Optional boolean for `ToolAnnotations::destructive_hint`. +/// * `idempotent_hint` - Optional boolean for `ToolAnnotations::idempotent_hint`. +/// * `open_world_hint` - Optional boolean for `ToolAnnotations::open_world_hint`. +/// * `read_only_hint` - Optional boolean for `ToolAnnotations::read_only_hint`. +/// * `title` - Optional string for `ToolAnnotations::title`. +/// +#[cfg(feature = "2025_03_26")] +struct McpToolMacroAttributes { + name: Option, + description: Option, + destructive_hint: Option, + idempotent_hint: Option, + open_world_hint: Option, + read_only_hint: Option, + title: Option, +} + use syn::parse::ParseStream; struct ExprList { @@ -51,59 +77,102 @@ impl Parse for McpToolMacroAttributes { fn parse(attributes: syn::parse::ParseStream) -> syn::Result { let mut name = None; let mut description = None; + let mut destructive_hint = None; + let mut idempotent_hint = None; + let mut open_world_hint = None; + let mut read_only_hint = None; + let mut title = None; + let meta_list: Punctuated = Punctuated::parse_terminated(attributes)?; for meta in meta_list { if let Meta::NameValue(meta_name_value) = meta { let ident = meta_name_value.path.get_ident().unwrap(); let ident_str = ident.to_string(); - let value = match &meta_name_value.value { - Expr::Lit(ExprLit { - lit: Lit::Str(lit_str), - .. - }) => lit_str.value(), - - Expr::Macro(expr_macro) => { - let mac = &expr_macro.mac; - if mac.path.is_ident("concat") { - let args: ExprList = syn::parse2(mac.tokens.clone())?; - let mut result = String::new(); - - for expr in args.exprs { - if let Expr::Lit(ExprLit { - lit: Lit::Str(lit_str), - .. - }) = expr - { - result.push_str(&lit_str.value()); + match ident_str.as_str() { + "name" | "description" => { + let value = match &meta_name_value.value { + Expr::Lit(ExprLit { + lit: Lit::Str(lit_str), + .. + }) => lit_str.value(), + Expr::Macro(expr_macro) => { + let mac = &expr_macro.mac; + if mac.path.is_ident("concat") { + let args: ExprList = syn::parse2(mac.tokens.clone())?; + let mut result = String::new(); + for expr in args.exprs { + if let Expr::Lit(ExprLit { + lit: Lit::Str(lit_str), + .. + }) = expr + { + result.push_str(&lit_str.value()); + } else { + return Err(Error::new_spanned( + expr, + "Only string literals are allowed inside concat!()", + )); + } + } + result } else { return Err(Error::new_spanned( - expr, - "Only string literals are allowed inside concat!()", + expr_macro, + "Only concat!(...) is supported here", )); } } - - result - } else { - return Err(Error::new_spanned( - expr_macro, - "Only concat!(...) is supported here", - )); + _ => { + return Err(Error::new_spanned( + &meta_name_value.value, + "Expected a string literal or concat!(...)", + )); + } + }; + match ident_str.as_str() { + "name" => name = Some(value), + "description" => description = Some(value), + _ => {} } } - - _ => { - return Err(Error::new_spanned( - &meta_name_value.value, - "Expected a string literal or concat!(...)", - )); + "destructive_hint" | "idempotent_hint" | "open_world_hint" + | "read_only_hint" => { + let value = match &meta_name_value.value { + Expr::Lit(ExprLit { + lit: Lit::Bool(lit_bool), + .. + }) => lit_bool.value, + _ => { + return Err(Error::new_spanned( + &meta_name_value.value, + "Expected a boolean literal", + )); + } + }; + match ident_str.as_str() { + "destructive_hint" => destructive_hint = Some(value), + "idempotent_hint" => idempotent_hint = Some(value), + "open_world_hint" => open_world_hint = Some(value), + "read_only_hint" => read_only_hint = Some(value), + _ => {} + } + } + "title" => { + let value = match &meta_name_value.value { + Expr::Lit(ExprLit { + lit: Lit::Str(lit_str), + .. + }) => lit_str.value(), + _ => { + return Err(Error::new_spanned( + &meta_name_value.value, + "Expected a string literal", + )); + } + }; + title = Some(value); } - }; - - match ident_str.as_str() { - "name" => name = Some(value), - "description" => description = Some(value), _ => {} } } @@ -116,7 +185,6 @@ impl Parse for McpToolMacroAttributes { "The 'name' attribute is required and must not be empty.", )); } - if description .as_ref() .map(|s| s.trim().is_empty()) @@ -128,7 +196,21 @@ impl Parse for McpToolMacroAttributes { )); } - Ok(Self { name, description }) + #[cfg(feature = "2024_11_05")] + let instance = Self { name, description }; + + #[cfg(feature = "2025_03_26")] + let instance = Self { + name, + description, + destructive_hint, + idempotent_hint, + open_world_hint, + read_only_hint, + title, + }; + + Ok(instance) } } @@ -148,7 +230,7 @@ impl Parse for McpToolMacroAttributes { /// /// # Example /// ```rust -/// #[rust_mcp_macros::mcp_tool(name = "example_tool", description = "An example tool")] +/// #[rust_mcp_macros::mcp_tool(name = "example_tool", description = "An example tool", idempotent_hint=true )] /// #[derive(rust_mcp_macros::JsonSchema)] /// struct ExampleTool { /// field1: String, @@ -159,6 +241,7 @@ impl Parse for McpToolMacroAttributes { /// let tool : rust_mcp_schema::Tool = ExampleTool::tool(); /// assert_eq!(tool.name , "example_tool"); /// assert_eq!(tool.description.unwrap() , "An example tool"); +/// assert_eq!(tool.annotations.unwrap().idempotent_hint.unwrap() , true); /// /// let schema_properties = tool.input_schema.properties.unwrap(); /// assert_eq!(schema_properties.len() , 2); @@ -176,6 +259,62 @@ pub fn mcp_tool(attributes: TokenStream, input: TokenStream) -> TokenStream { let tool_name = macro_attributes.name.unwrap_or_default(); let tool_description = macro_attributes.description.unwrap_or_default(); + #[cfg(feature = "2025_03_26")] + let some_annotations = macro_attributes.destructive_hint.is_some() + || macro_attributes.idempotent_hint.is_some() + || macro_attributes.open_world_hint.is_some() + || macro_attributes.read_only_hint.is_some() + || macro_attributes.title.is_some(); + + #[cfg(feature = "2025_03_26")] + let annotations = if some_annotations { + let destructive_hint = macro_attributes + .destructive_hint + .map_or(quote! {None}, |v| quote! {Some(#v)}); + + let idempotent_hint = macro_attributes + .idempotent_hint + .map_or(quote! {None}, |v| quote! {Some(#v)}); + let open_world_hint = macro_attributes + .open_world_hint + .map_or(quote! {None}, |v| quote! {Some(#v)}); + let read_only_hint = macro_attributes + .read_only_hint + .map_or(quote! {None}, |v| quote! {Some(#v)}); + let title = macro_attributes + .title + .map_or(quote! {None}, |v| quote! {Some(#v)}); + quote! { + Some(rust_mcp_schema::ToolAnnotations { + destructive_hint: #destructive_hint, + idempotent_hint: #idempotent_hint, + open_world_hint: #open_world_hint, + read_only_hint: #read_only_hint, + title: #title, + }), + } + } else { + quote! {None} + }; + + #[cfg(feature = "2025_03_26")] + let tool_token = quote! { + rust_mcp_schema::Tool { + name: #tool_name.to_string(), + description: Some(#tool_description.to_string()), + input_schema: rust_mcp_schema::ToolInputSchema::new(required, properties), + annotations: #annotations + } + }; + #[cfg(feature = "2024_11_05")] + let tool_token = quote! { + rust_mcp_schema::Tool { + name: #tool_name.to_string(), + description: Some(#tool_description.to_string()), + input_schema: rust_mcp_schema::ToolInputSchema::new(required, properties), + } + }; + let output = quote! { impl #input_ident { /// Returns the name of the tool as a string. @@ -222,54 +361,7 @@ pub fn mcp_tool(attributes: TokenStream, input: TokenStream) -> TokenStream { .collect() }); - rust_mcp_schema::Tool { - name: #tool_name.to_string(), - description: Some(#tool_description.to_string()), - input_schema: rust_mcp_schema::ToolInputSchema::new(required, properties), - } - } - - #[deprecated(since = "0.2.0", note = "Use `tool()` instead.")] - pub fn get_tool()-> rust_mcp_schema::Tool - { - let json_schema = &#input_ident::json_schema(); - - let required: Vec<_> = match json_schema.get("required").and_then(|r| r.as_array()) { - Some(arr) => arr - .iter() - .filter_map(|item| item.as_str().map(String::from)) - .collect(), - None => Vec::new(), // Default to an empty vector if "required" is missing or not an array - }; - - let properties: Option< - std::collections::HashMap>, - > = json_schema - .get("properties") - .and_then(|v| v.as_object()) // Safely extract "properties" as an object. - .map(|properties| { - properties - .iter() - .filter_map(|(key, value)| { - serde_json::to_value(value) - .ok() // If serialization fails, return None. - .and_then(|v| { - if let serde_json::Value::Object(obj) = v { - Some(obj) - } else { - None - } - }) - .map(|obj| (key.to_string(), obj)) // Return the (key, value) tuple - }) - .collect() - }); - - rust_mcp_schema::Tool { - name: #tool_name.to_string(), - description: Some(#tool_description.to_string()), - input_schema: rust_mcp_schema::ToolInputSchema::new(required, properties), - } + #tool_token } } // Retain the original item (struct definition) diff --git a/crates/rust-mcp-macros/tests/macro_test.rs b/crates/rust-mcp-macros/tests/macro_test.rs index 3a23c87..4f24c9e 100644 --- a/crates/rust-mcp-macros/tests/macro_test.rs +++ b/crates/rust-mcp-macros/tests/macro_test.rs @@ -31,3 +31,58 @@ fn test_rename() { let properties = schema.get("properties").unwrap().as_object().unwrap(); assert_eq!(properties.len(), 2); } + +#[test] +#[cfg(feature = "2025_03_26")] +fn test_mcp_tool() { + #[rust_mcp_macros::mcp_tool( + name = "example_tool", + description = "An example tool", + idempotent_hint = true, + destructive_hint = true, + open_world_hint = true, + read_only_hint = true + )] + #[derive(rust_mcp_macros::JsonSchema)] + #[allow(unused)] + struct ExampleTool { + field1: String, + field2: i32, + } + + assert_eq!(ExampleTool::tool_name(), "example_tool"); + let tool: rust_mcp_schema::Tool = ExampleTool::tool(); + assert_eq!(tool.name, "example_tool"); + assert_eq!(tool.description.unwrap(), "An example tool"); + assert!(tool.annotations.as_ref().unwrap().idempotent_hint.unwrap(),); + assert!(tool.annotations.as_ref().unwrap().destructive_hint.unwrap(),); + assert!(tool.annotations.as_ref().unwrap().open_world_hint.unwrap(),); + assert!(tool.annotations.as_ref().unwrap().read_only_hint.unwrap(),); + + let schema_properties = tool.input_schema.properties.unwrap(); + assert_eq!(schema_properties.len(), 2); + assert!(schema_properties.contains_key("field1")); + assert!(schema_properties.contains_key("field2")); +} + +#[test] +#[cfg(feature = "2024_11_05")] +fn test_mcp_tool() { + #[rust_mcp_macros::mcp_tool(name = "example_tool", description = "An example tool")] + #[derive(rust_mcp_macros::JsonSchema)] + #[allow(unused)] + struct ExampleTool { + field1: String, + field2: i32, + } + + assert_eq!(ExampleTool::tool_name(), "example_tool"); + let tool: rust_mcp_schema::Tool = ExampleTool::tool(); + assert_eq!(tool.name, "example_tool"); + assert_eq!(tool.description.unwrap(), "An example tool"); + + let schema_properties = tool.input_schema.properties.unwrap(); + assert_eq!(schema_properties.len(), 2); + assert!(schema_properties.contains_key("field1")); + assert!(schema_properties.contains_key("field2")); +} diff --git a/crates/rust-mcp-sdk/src/hyper_servers/server.rs b/crates/rust-mcp-sdk/src/hyper_servers/server.rs index cd005a0..94a867a 100644 --- a/crates/rust-mcp-sdk/src/hyper_servers/server.rs +++ b/crates/rust-mcp-sdk/src/hyper_servers/server.rs @@ -1,12 +1,14 @@ use crate::mcp_traits::mcp_handler::McpServerHandler; #[cfg(feature = "ssl")] use axum_server::tls_rustls::RustlsConfig; +use axum_server::Handle; use std::{ net::{SocketAddr, ToSocketAddrs}, path::Path, sync::Arc, time::Duration, }; +use tokio::signal; use super::{ app_state::AppState, @@ -20,7 +22,7 @@ use rust_mcp_transport::TransportOptions; // Default client ping interval (12 seconds) const DEFAULT_CLIENT_PING_INTERVAL: Duration = Duration::from_secs(12); - +const GRACEFUL_SHUTDOWN_TMEOUT_SECS: u64 = 30; // Default Server-Sent Events (SSE) endpoint path const DEFAULT_SSE_ENDPOINT: &str = "/sse"; // Default MCP Messages endpoint path @@ -160,6 +162,7 @@ pub struct HyperServer { app: Router, state: Arc, options: HyperServerOptions, + handle: Handle, } impl HyperServer { @@ -196,6 +199,7 @@ impl HyperServer { app, state, options: server_options, + handle: Handle::new(), } } @@ -280,12 +284,25 @@ impl HyperServer { tracing::info!("{}", self.server_info(Some(addr)).await?); + // Spawn a task to trigger shutdown on signal + let handle_clone = self.handle.clone(); + tokio::spawn(async move { + shutdown_signal(handle_clone).await; + }); + + let handle_clone = self.handle.clone(); axum_server::bind_rustls(addr, config) + .handle(handle_clone) .serve(self.app.into_make_service()) .await .map_err(|err| TransportServerError::ServerStartError(err.to_string())) } + /// Returns server handle that could be used for graceful shutdown + pub fn server_handle(&self) -> Handle { + self.handle.clone() + } + /// Starts the server without SSL /// /// # Arguments @@ -296,7 +313,15 @@ impl HyperServer { async fn start_http(self, addr: SocketAddr) -> TransportServerResult<()> { tracing::info!("{}", self.server_info(Some(addr)).await?); + // Spawn a task to trigger shutdown on signal + let handle_clone = self.handle.clone(); + tokio::spawn(async move { + shutdown_signal(handle_clone).await; + }); + + let handle_clone = self.handle.clone(); axum_server::bind(addr) + .handle(handle_clone) .serve(self.app.into_make_service()) .await .map_err(|err| TransportServerError::ServerStartError(err.to_string())) @@ -327,3 +352,33 @@ impl HyperServer { } } } + +// Shutdown signal handler +async fn shutdown_signal(handle: Handle) { + // Wait for a Ctrl+C or SIGTERM signal + let ctrl_c = async { + signal::ctrl_c() + .await + .expect("Failed to install Ctrl+C handler"); + }; + + #[cfg(unix)] + let terminate = async { + signal::unix::signal(signal::unix::SignalKind::terminate()) + .expect("Failed to install signal handler") + .recv() + .await; + }; + + #[cfg(not(unix))] + let terminate = std::future::pending::<()>(); + + tokio::select! { + _ = ctrl_c => {}, + _ = terminate => {}, + } + + tracing::info!("Signal received, starting graceful shutdown"); + // Trigger graceful shutdown with a timeout + handle.graceful_shutdown(Some(Duration::from_secs(GRACEFUL_SHUTDOWN_TMEOUT_SECS))); +} diff --git a/crates/rust-mcp-sdk/src/hyper_servers/session_store.rs b/crates/rust-mcp-sdk/src/hyper_servers/session_store.rs index 998be01..b0716b8 100644 --- a/crates/rust-mcp-sdk/src/hyper_servers/session_store.rs +++ b/crates/rust-mcp-sdk/src/hyper_servers/session_store.rs @@ -37,6 +37,10 @@ pub trait SessionStore: Send + Sync { async fn delete(&self, key: &SessionId); /// Clears all sessions from the store async fn clear(&self); + + async fn keys(&self) -> Vec; + + async fn values(&self) -> Vec>>; } /// Trait for generating session identifiers diff --git a/crates/rust-mcp-sdk/src/hyper_servers/session_store/in_memory.rs b/crates/rust-mcp-sdk/src/hyper_servers/session_store/in_memory.rs index 7c5755d..342d232 100644 --- a/crates/rust-mcp-sdk/src/hyper_servers/session_store/in_memory.rs +++ b/crates/rust-mcp-sdk/src/hyper_servers/session_store/in_memory.rs @@ -3,6 +3,7 @@ use super::{SessionStore, TxServer}; use async_trait::async_trait; use std::collections::HashMap; use std::sync::Arc; +use tokio::io::DuplexStream; use tokio::sync::Mutex; use tokio::sync::RwLock; @@ -54,4 +55,12 @@ impl SessionStore for InMemorySessionStore { let mut store = self.store.write().await; store.clear(); } + async fn keys(&self) -> Vec { + let store = self.store.read().await; + store.keys().cloned().collect::>() + } + async fn values(&self) -> Vec>> { + let store = self.store.read().await; + store.values().cloned().collect::>() + } } diff --git a/crates/rust-mcp-sdk/src/utils.rs b/crates/rust-mcp-sdk/src/utils.rs index 6b27d9b..13dc579 100644 --- a/crates/rust-mcp-sdk/src/utils.rs +++ b/crates/rust-mcp-sdk/src/utils.rs @@ -31,6 +31,7 @@ pub fn format_assertion_message(entity: &str, capability: &str, method_name: &st /// # Returns /// A String containing the base path without query parameters or fragment /// ``` +#[allow(unused)] pub(crate) fn remove_query_and_hash(endpoint: &str) -> String { // Split off fragment (if any) and take the first part let without_fragment = endpoint.split_once('#').map_or(endpoint, |(path, _)| path); diff --git a/crates/rust-mcp-sdk/tests/test_server_sse.rs b/crates/rust-mcp-sdk/tests/test_server_sse.rs index 4196b73..ba7df51 100644 --- a/crates/rust-mcp-sdk/tests/test_server_sse.rs +++ b/crates/rust-mcp-sdk/tests/test_server_sse.rs @@ -17,6 +17,7 @@ pub mod common; #[tokio::test] async fn tets_sse_endpoint_event_default() { let server_options = HyperServerOptions { + port: 8081, session_id_generator: Some(Arc::new(TestIdGenerator::new(vec![ "AAA-BBB-CCC".to_string() ]))), @@ -27,9 +28,11 @@ async fn tets_sse_endpoint_event_default() { let server_endpoint = format!("{}{}", base_url, server_options.sse_endpoint()); - tokio::spawn(async move { - let server = create_test_server(server_options); + let server = create_test_server(server_options); + let handle = server.server_handle(); + let server_task = tokio::spawn(async move { server.start().await.unwrap(); + eprintln!("Server 1 is down"); }); sleep(Duration::from_millis(750)).await; @@ -66,11 +69,14 @@ async fn tets_sse_endpoint_event_default() { .await .unwrap(); assert!(res.status().is_success()); + handle.graceful_shutdown(Some(Duration::from_millis(1))); + server_task.await.unwrap(); } #[tokio::test] async fn tets_sse_message_endpoint_query_hash() { let server_options = HyperServerOptions { + port: 8082, custom_messages_endpoint: Some( "/custom-msg-endpoint?something=true&otherthing=false#section-59".to_string(), ), @@ -84,9 +90,12 @@ async fn tets_sse_message_endpoint_query_hash() { let server_endpoint = format!("{}{}", base_url, server_options.sse_endpoint()); - tokio::spawn(async move { - let server = create_test_server(server_options); + let server = create_test_server(server_options); + let handle = server.server_handle(); + + let server_task = tokio::spawn(async move { server.start().await.unwrap(); + eprintln!("Server 2 is down"); }); sleep(Duration::from_millis(750)).await; @@ -126,11 +135,14 @@ async fn tets_sse_message_endpoint_query_hash() { .await .unwrap(); assert!(res.status().is_success()); + handle.graceful_shutdown(Some(Duration::from_millis(1))); + server_task.await.unwrap(); } #[tokio::test] async fn tets_sse_custom_message_endpoint() { let server_options = HyperServerOptions { + port: 8083, custom_messages_endpoint: Some( "/custom-msg-endpoint?something=true&otherthing=false#section-59".to_string(), ), @@ -144,9 +156,12 @@ async fn tets_sse_custom_message_endpoint() { let server_endpoint = format!("{}{}", base_url, server_options.sse_endpoint()); - tokio::spawn(async move { - let server = create_test_server(server_options); + let server = create_test_server(server_options); + let handle = server.server_handle(); + + let server_task = tokio::spawn(async move { server.start().await.unwrap(); + eprintln!("Server 3 is down"); }); sleep(Duration::from_millis(750)).await; @@ -190,5 +205,7 @@ async fn tets_sse_custom_message_endpoint() { assert!(matches!(result, ServerMessage::Response(response) if matches!(&response.result, ResultFromServer::ServerResult(server_result) - if matches!(server_result, ServerResult::InitializeResult(init_result) if init_result.server_info.name == "Test MCP Server".to_string())))); + if matches!(server_result, ServerResult::InitializeResult(init_result) if init_result.server_info.name == "Test MCP Server")))); + handle.graceful_shutdown(Some(Duration::from_millis(1))); + server_task.await.unwrap(); } diff --git a/crates/rust-mcp-transport/src/utils/readable_channel.rs b/crates/rust-mcp-transport/src/utils/readable_channel.rs index d73697f..d07ca63 100644 --- a/crates/rust-mcp-transport/src/utils/readable_channel.rs +++ b/crates/rust-mcp-transport/src/utils/readable_channel.rs @@ -26,7 +26,6 @@ impl AsyncRead for ReadableChannel { /// /// # Returns /// * `Poll>` - Ready with Ok if data is read, Ready with Err if the channel is closed, or Pending if no data is available - fn poll_read( mut self: Pin<&mut Self>, cx: &mut Context<'_>, diff --git a/examples/hello-world-mcp-server/src/tools.rs b/examples/hello-world-mcp-server/src/tools.rs index 26a89cb..270cb0f 100644 --- a/examples/hello-world-mcp-server/src/tools.rs +++ b/examples/hello-world-mcp-server/src/tools.rs @@ -9,7 +9,11 @@ use rust_mcp_sdk::{ //****************// #[mcp_tool( name = "say_hello", - description = "Accepts a person's name and says a personalized \"Hello\" to that person" + description = "Accepts a person's name and says a personalized \"Hello\" to that person", + idempotent_hint = false, + destructive_hint = false, + open_world_hint = false, + read_only_hint = false )] #[derive(Debug, ::serde::Deserialize, ::serde::Serialize, JsonSchema)] pub struct SayHelloTool { @@ -29,7 +33,11 @@ impl SayHelloTool { //******************// #[mcp_tool( name = "say_goodbye", - description = "Accepts a person's name and says a personalized \"Goodbye\" to that person." + description = "Accepts a person's name and says a personalized \"Goodbye\" to that person.", + idempotent_hint = false, + destructive_hint = false, + open_world_hint = false, + read_only_hint = false )] #[derive(Debug, ::serde::Deserialize, ::serde::Serialize, JsonSchema)] pub struct SayGoodbyeTool {