diff --git a/README.md b/README.md index c1e201c..51c3b49 100644 --- a/README.md +++ b/README.md @@ -153,6 +153,7 @@ let server = hyper_server::create_server( HyperServerOptions { host: "127.0.0.1".to_string(), sse_support: false, + event_store: Some(Arc::new(InMemoryEventStore::default())), // enable resumability ..Default::default() }, ); @@ -191,7 +192,6 @@ impl ServerHandler for MyServerHandler { } /// Handles requests to call a specific tool. - async fn handle_call_tool_request( &self, request: CallToolRequest, runtime: Arc ) -> Result { if request.tool_name() == SayHelloTool::tool_name() { @@ -416,6 +416,7 @@ server.start().await?; Here is a list of available options with descriptions for configuring the HyperServer: ```rs + pub struct HyperServerOptions { /// Hostname or IP address the server will bind to (default: "127.0.0.1") pub host: String, @@ -432,6 +433,10 @@ pub struct HyperServerOptions { /// Shared transport configuration used by the server pub transport_options: Arc, + /// Event store for resumability support + /// If provided, resumability will be enabled, allowing clients to reconnect and resume messages + pub event_store: Option>, + /// This setting only applies to streamable HTTP. /// If true, the server will return JSON responses instead of starting an SSE stream. /// This can be useful for simple request/response scenarios without streaming. @@ -500,8 +505,8 @@ The `rust-mcp-sdk` crate provides several features that can be enabled or disabl - `macros`: Provides procedural macros for simplifying the creation and manipulation of MCP Tool structures. - `sse`: Enables support for the `Server-Sent Events (SSE)` transport. - `streamable-http`: Enables support for the `Streamable HTTP` transport. -- `stdio`: Enables support for the `standard input/output (stdio)` transport. +- `stdio`: Enables support for the `standard input/output (stdio)` transport. - `tls-no-provider`: Enables TLS without a crypto provider. This is useful if you are already using a different crypto provider than the aws-lc default. #### MCP Protocol Versions with Corresponding Features diff --git a/crates/rust-mcp-sdk/README.md b/crates/rust-mcp-sdk/README.md index 8036022..51c3b49 100644 --- a/crates/rust-mcp-sdk/README.md +++ b/crates/rust-mcp-sdk/README.md @@ -153,6 +153,7 @@ let server = hyper_server::create_server( HyperServerOptions { host: "127.0.0.1".to_string(), sse_support: false, + event_store: Some(Arc::new(InMemoryEventStore::default())), // enable resumability ..Default::default() }, ); @@ -415,6 +416,7 @@ server.start().await?; Here is a list of available options with descriptions for configuring the HyperServer: ```rs + pub struct HyperServerOptions { /// Hostname or IP address the server will bind to (default: "127.0.0.1") pub host: String, @@ -431,6 +433,10 @@ pub struct HyperServerOptions { /// Shared transport configuration used by the server pub transport_options: Arc, + /// Event store for resumability support + /// If provided, resumability will be enabled, allowing clients to reconnect and resume messages + pub event_store: Option>, + /// This setting only applies to streamable HTTP. /// If true, the server will return JSON responses instead of starting an SSE stream. /// This can be useful for simple request/response scenarios without streaming. @@ -499,8 +505,8 @@ The `rust-mcp-sdk` crate provides several features that can be enabled or disabl - `macros`: Provides procedural macros for simplifying the creation and manipulation of MCP Tool structures. 
- `sse`: Enables support for the `Server-Sent Events (SSE)` transport. - `streamable-http`: Enables support for the `Streamable HTTP` transport. -- `stdio`: Enables support for the `standard input/output (stdio)` transport. +- `stdio`: Enables support for the `standard input/output (stdio)` transport. - `tls-no-provider`: Enables TLS without a crypto provider. This is useful if you are already using a different crypto provider than the aws-lc default. #### MCP Protocol Versions with Corresponding Features diff --git a/crates/rust-mcp-sdk/src/hyper_servers/app_state.rs b/crates/rust-mcp-sdk/src/hyper_servers/app_state.rs index ff6d5b2..e7f8793 100644 --- a/crates/rust-mcp-sdk/src/hyper_servers/app_state.rs +++ b/crates/rust-mcp-sdk/src/hyper_servers/app_state.rs @@ -3,6 +3,7 @@ use std::{sync::Arc, time::Duration}; use super::session_store::SessionStore; use crate::mcp_traits::mcp_handler::McpServerHandler; use crate::{id_generator::FastIdGenerator, mcp_traits::IdGenerator, schema::InitializeResult}; +use rust_mcp_transport::event_store::EventStore; use rust_mcp_transport::{SessionId, TransportOptions}; /// Application state struct for the Hyper server @@ -30,6 +31,9 @@ pub struct AppState { /// Enable DNS rebinding protection (requires allowedHosts and/or allowedOrigins to be configured). /// Default is false for backwards compatibility. pub dns_rebinding_protection: bool, + /// Event store for resumability support + /// If provided, resumability will be enabled, allowing clients to reconnect and resume messages + pub event_store: Option>, } impl AppState { diff --git a/crates/rust-mcp-sdk/src/hyper_servers/routes/hyper_utils.rs b/crates/rust-mcp-sdk/src/hyper_servers/routes/hyper_utils.rs index da69c67..7101a73 100644 --- a/crates/rust-mcp-sdk/src/hyper_servers/routes/hyper_utils.rs +++ b/crates/rust-mcp-sdk/src/hyper_servers/routes/hyper_utils.rs @@ -23,7 +23,8 @@ use axum::{ use futures::stream; use hyper::{header, HeaderMap, StatusCode}; use rust_mcp_transport::{ - SessionId, SseTransport, StreamId, MCP_PROTOCOL_VERSION_HEADER, MCP_SESSION_ID_HEADER, + EventId, McpDispatch, SessionId, SseTransport, StreamId, ID_SEPARATOR, + MCP_PROTOCOL_VERSION_HEADER, MCP_SESSION_ID_HEADER, }; use std::{sync::Arc, time::Duration}; use tokio::io::{duplex, AsyncBufReadExt, BufReader}; @@ -36,6 +37,7 @@ async fn create_sse_stream( state: Arc, payload: Option<&str>, standalone: bool, + last_event_id: Option, ) -> TransportServerResult> { let payload_string = payload.map(|p| p.to_string()); @@ -53,50 +55,85 @@ async fn create_sse_stream( // writable stream to deliver message to the client let (write_tx, write_rx) = duplex(DUPLEX_BUFFER_SIZE); - let transport = Arc::new( - SseTransport::::new( - read_rx, - write_tx, - read_tx, - Arc::clone(&state.transport_options), - ) - .map_err(|err| TransportServerError::TransportError(err.to_string()))?, - ); - - let stream_id: StreamId = if standalone { - DEFAULT_STREAM_ID.to_string() + let session_id = Arc::new(session_id); + let stream_id: Arc = if standalone { + Arc::new(DEFAULT_STREAM_ID.to_string()) } else { - state.stream_id_gen.generate() + Arc::new(state.stream_id_gen.generate()) }; + + let event_store = state.event_store.as_ref().map(Arc::clone); + let resumability_enabled = event_store.is_some(); + + let mut transport = SseTransport::::new( + read_rx, + write_tx, + read_tx, + Arc::clone(&state.transport_options), + ) + .map_err(|err| TransportServerError::TransportError(err.to_string()))?; + if let Some(event_store) = event_store.clone() { + 
transport.make_resumable((*session_id).clone(), (*stream_id).clone(), event_store); + } + let transport = Arc::new(transport); + let ping_interval = state.ping_interval; let runtime_clone = Arc::clone(&runtime); + let stream_id_clone = stream_id.clone(); + let transport_clone = transport.clone(); //Start the server runtime tokio::spawn(async move { match runtime_clone - .start_stream(transport, &stream_id, ping_interval, payload_string) + .start_stream( + transport_clone, + &stream_id_clone, + ping_interval, + payload_string, + ) .await { - Ok(_) => tracing::trace!("stream {} exited gracefully.", &stream_id), - Err(err) => tracing::info!("stream {} exited with error : {}", &stream_id, err), + Ok(_) => tracing::trace!("stream {} exited gracefully.", &stream_id_clone), + Err(err) => tracing::info!("stream {} exited with error : {}", &stream_id_clone, err), } - let _ = runtime.remove_transport(&stream_id).await; + let _ = runtime.remove_transport(&stream_id_clone).await; }); // Construct SSE stream let reader = BufReader::new(write_rx); - // outgoing messages from server to the client - let message_stream = stream::unfold(reader, |mut reader| async move { - let mut line = String::new(); - - match reader.read_line(&mut line).await { - Ok(0) => None, // EOF - Ok(_) => { - let trimmed_line = line.trim_end_matches('\n').to_owned(); - Some((Ok(Event::default().data(trimmed_line)), reader)) + // send outgoing messages from server to the client over the sse stream + let message_stream = stream::unfold(reader, move |mut reader| { + async move { + let mut line = String::new(); + + match reader.read_line(&mut line).await { + Ok(0) => None, // EOF + Ok(_) => { + let trimmed_line = line.trim_end_matches('\n').to_owned(); + + // empty sse comment to keep-alive + if is_empty_sse_message(&trimmed_line) { + return Some((Ok(Event::default()), reader)); + } + + let (event_id, message) = match ( + resumability_enabled, + trimmed_line.split_once(char::from(ID_SEPARATOR)), + ) { + (true, Some((id, msg))) => (Some(id.to_string()), msg.to_string()), + _ => (None, trimmed_line), + }; + + let event = match event_id { + Some(id) => Event::default().data(message).id(id), + None => Event::default().data(message), + }; + + Some((Ok(event), reader)) + } + Err(e) => Some((Err(e), reader)), } - Err(e) => Some((Err(e), reader)), } }); @@ -111,6 +148,23 @@ async fn create_sse_stream( HeaderValue::from_str(&session_id).unwrap(), ); + // if last_event_id exists we replay messages from the event-store + tokio::spawn(async move { + if let Some(last_event_id) = last_event_id { + if let Some(event_store) = state.event_store.as_ref() { + if let Some(events) = event_store.events_after(last_event_id).await { + for message_payload in events.messages { + // skip storing replay messages + let error = transport.write_str(&message_payload, true).await; + if let Err(error) = error { + tracing::trace!("Error replaying message: {error}") + } + } + } + } + } + }); + if !payload_contains_request { *response.status_mut() = StatusCode::ACCEPTED; } @@ -148,6 +202,7 @@ fn is_result(json_str: &str) -> Result { pub async fn create_standalone_stream( session_id: SessionId, + last_event_id: Option, state: Arc, ) -> TransportServerResult> { let runtime = state.session_store.get(&session_id).await.ok_or( @@ -161,12 +216,20 @@ pub async fn create_standalone_stream( return Ok((StatusCode::CONFLICT, Json(error)).into_response()); } + if let Some(last_event_id) = last_event_id.as_ref() { + tracing::trace!( + "SSE stream re-connected with last-event-id: {}", 
+ last_event_id + ); + } + let mut response = create_sse_stream( runtime.clone(), session_id.clone(), state.clone(), None, true, + last_event_id, ) .await?; *response.status_mut() = StatusCode::OK; @@ -195,6 +258,7 @@ pub async fn start_new_session( state.clone(), Some(payload), false, + None, ) .await; @@ -354,6 +418,7 @@ pub async fn process_incoming_message( state.clone(), Some(payload), false, + None, ) .await } @@ -365,6 +430,10 @@ pub async fn process_incoming_message( } } +pub fn is_empty_sse_message(sse_payload: &str) -> bool { + sse_payload.is_empty() || sse_payload.trim() == ":" +} + pub async fn delete_session( session_id: SessionId, state: Arc, diff --git a/crates/rust-mcp-sdk/src/hyper_servers/routes/streamable_http_routes.rs b/crates/rust-mcp-sdk/src/hyper_servers/routes/streamable_http_routes.rs index 00d46c0..67f8679 100644 --- a/crates/rust-mcp-sdk/src/hyper_servers/routes/streamable_http_routes.rs +++ b/crates/rust-mcp-sdk/src/hyper_servers/routes/streamable_http_routes.rs @@ -23,7 +23,7 @@ use axum::{ Json, Router, }; use hyper::{HeaderMap, StatusCode}; -use rust_mcp_transport::{SessionId, MCP_SESSION_ID_HEADER}; +use rust_mcp_transport::{SessionId, MCP_LAST_EVENT_ID_HEADER, MCP_SESSION_ID_HEADER}; use std::{collections::HashMap, sync::Arc}; pub fn routes(state: Arc, streamable_http_endpoint: &str) -> Router> { @@ -60,9 +60,14 @@ pub async fn handle_streamable_http_get( .and_then(|value| value.to_str().ok()) .map(|s| s.to_string()); + let last_event_id: Option = headers + .get(MCP_LAST_EVENT_ID_HEADER) + .and_then(|value| value.to_str().ok()) + .map(|s| s.to_string()); + match session_id { Some(session_id) => { - let res = create_standalone_stream(session_id, state).await?; + let res = create_standalone_stream(session_id, last_event_id, state).await?; Ok(res.into_response()) } None => { diff --git a/crates/rust-mcp-sdk/src/hyper_servers/server.rs b/crates/rust-mcp-sdk/src/hyper_servers/server.rs index 1c3b3cf..71bccee 100644 --- a/crates/rust-mcp-sdk/src/hyper_servers/server.rs +++ b/crates/rust-mcp-sdk/src/hyper_servers/server.rs @@ -23,7 +23,7 @@ use super::{ }; use crate::schema::InitializeResult; use axum::Router; -use rust_mcp_transport::{SessionId, TransportOptions}; +use rust_mcp_transport::{event_store::EventStore, SessionId, TransportOptions}; // Default client ping interval (12 seconds) const DEFAULT_CLIENT_PING_INTERVAL: Duration = Duration::from_secs(12); @@ -53,6 +53,10 @@ pub struct HyperServerOptions { /// Shared transport configuration used by the server pub transport_options: Arc, + /// Event store for resumability support + /// If provided, resumability will be enabled, allowing clients to reconnect and resume messages + pub event_store: Option>, + /// This setting only applies to streamable HTTP. /// If true, the server will return JSON responses instead of starting an SSE stream. /// This can be useful for simple request/response scenarios without streaming. 
@@ -225,6 +229,7 @@ impl Default for HyperServerOptions { allowed_hosts: None, allowed_origins: None, dns_rebinding_protection: false, + event_store: None, } } } @@ -271,6 +276,7 @@ impl HyperServer { allowed_hosts: server_options.allowed_hosts.take(), allowed_origins: server_options.allowed_origins.take(), dns_rebinding_protection: server_options.dns_rebinding_protection, + event_store: server_options.event_store.as_ref().map(Arc::clone), }); let app = app_routes(Arc::clone(&state), &server_options); Self { diff --git a/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime.rs b/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime.rs index 1b24b57..5502cee 100644 --- a/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime.rs +++ b/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime.rs @@ -368,16 +368,17 @@ impl ServerRuntime { Ok(()) } + //TODO: re-visit and simplify unnecessary hashmap pub(crate) async fn remove_transport(&self, stream_id: &str) -> SdkResult<()> { if stream_id != DEFAULT_STREAM_ID { return Ok(()); } - let mut transport_map = self.transport_map.write().await; + let transport_map = self.transport_map.read().await; tracing::trace!("removing transport for stream id : {}", stream_id); if let Some(transport) = transport_map.get(stream_id) { transport.shut_down().await?; } - transport_map.remove(stream_id); + // transport_map.remove(stream_id); Ok(()) } @@ -435,6 +436,7 @@ impl ServerRuntime { }; // in case there is a payload, we consume it by transport to get processed + // payload would be message payload coming from the client if let Some(payload) = payload { if let Err(err) = transport.consume_string_payload(&payload).await { let _ = self.remove_transport(stream_id).await; diff --git a/crates/rust-mcp-sdk/tests/common/common.rs b/crates/rust-mcp-sdk/tests/common/common.rs index f330dda..6b78895 100644 --- a/crates/rust-mcp-sdk/tests/common/common.rs +++ b/crates/rust-mcp-sdk/tests/common/common.rs @@ -11,9 +11,11 @@ use rust_mcp_sdk::mcp_client::ClientHandler; use rust_mcp_sdk::schema::{ClientCapabilities, Implementation, InitializeRequestParams}; use std::collections::HashMap; use std::process; +use std::sync::Once; use std::time::{Duration, SystemTime, UNIX_EPOCH}; use tokio::time::timeout; use tokio_stream::StreamExt; +use tracing_subscriber::EnvFilter; use wiremock::{MockServer, Request, ResponseTemplate}; pub use test_client::*; @@ -23,7 +25,17 @@ pub const NPX_SERVER_EVERYTHING: &str = "@modelcontextprotocol/server-everything #[cfg(unix)] pub const UVX_SERVER_GIT: &str = "mcp-server-git"; +static INIT: Once = Once::new(); +pub fn init_tracing() { + INIT.call_once(|| { + let filter = EnvFilter::try_from_default_env() + .or_else(|_| EnvFilter::try_new("tracing")) + .unwrap(); + + tracing_subscriber::fmt().with_env_filter(filter).init(); + }); +} #[mcp_tool( name = "say_hello", description = "Accepts a person's name and says a personalized \"Hello\" to that person", @@ -126,16 +138,18 @@ pub async fn send_get_request( ); } } + client.get(url).headers(headers).send().await } use futures::stream::Stream; // stream: &mut impl Stream>, +/// reads sse events and return them as (id, event, data) tuple pub async fn read_sse_event_from_stream( stream: &mut (impl Stream> + Unpin), event_count: usize, -) -> Option> { +) -> Option, Option, String)>> { let mut buffer = String::new(); let mut events = vec![]; @@ -146,27 +160,28 @@ pub async fn read_sse_event_from_stream( buffer.push_str(chunk_str); while let Some(pos) = buffer.find("\n\n") { - let data = { - // Scope to limit borrows - let 
(event_str, rest) = buffer.split_at(pos); - let mut data = None; - - // Process the event string - for line in event_str.lines() { - if line.starts_with("data:") { - data = Some(line.trim_start_matches("data:").trim().to_string()); - break; // Exit loop after finding data - } + let (event_str, rest) = buffer.split_at(pos); + let mut id = None; + let mut event = None; + let mut data = None; + + // Process the event string + for line in event_str.lines() { + if line.starts_with("id:") { + id = Some(line.trim_start_matches("id:").trim().to_string()); + } else if line.starts_with("event:") { + event = Some(line.trim_start_matches("event:").trim().to_string()); + } else if line.starts_with("data:") { + data = Some(line.trim_start_matches("data:").trim().to_string()); } + } - // Update buffer after processing - buffer = rest[2..].to_string(); // Skip "\n\n" - data - }; + // Update buffer after processing + buffer = rest[2..].to_string(); // Skip "\n\n" - // Return if data was found + // Only include events with data if let Some(data) = data { - events.push(data); + events.push((id, event, data)); if events.len().eq(&event_count) { return Some(events); } @@ -174,17 +189,26 @@ } } Err(_e) => { - // return Err(TransportServerError::HyperError(e)); return None; } } } - None + if !events.is_empty() { + Some(events) + } else { + None + } } -pub async fn read_sse_event(response: Response, event_count: usize) -> Option<Vec<String>> { +// returns SSE events as (id, event, data) tuples +pub async fn read_sse_event( + response: Response, + event_count: usize, +) -> Option<Vec<(Option<String>, Option<String>, String)>> { let mut stream = response.bytes_stream(); - read_sse_event_from_stream(&mut stream, event_count).await + let events = read_sse_event_from_stream(&mut stream, event_count).await; + // drop(stream); + events } pub fn test_client_info() -> InitializeRequestParams { @@ -280,9 +304,16 @@ pub fn random_port_old() -> u16 { } pub mod sample_tools { + use std::{sync::Arc, time::Duration}; + + use rust_mcp_schema::{LoggingMessageNotificationParams, TextContent}; #[cfg(feature = "2025_06_18")] use rust_mcp_sdk::macros::{mcp_tool, JsonSchema}; - use rust_mcp_sdk::schema::{schema_utils::CallToolError, CallToolResult}; + use rust_mcp_sdk::{ + schema::{schema_utils::CallToolError, CallToolResult}, + McpServer, + }; + use serde_json::json; //****************// // SayHelloTool // @@ -342,6 +373,43 @@ return Ok(CallToolResult::text_content(goodbye_message, None)); } } + + //****************************// + // StartNotificationStream // + //****************************// + #[mcp_tool( + name = "start-notification-stream", + description = "Sends a stream of notifications at a configurable interval, for testing SSE streams and resumability."
+ )] + #[derive(Debug, ::serde::Deserialize, ::serde::Serialize, JsonSchema)] + pub struct StartNotificationStream { + /// Interval in milliseconds between notifications + interval: u64, + /// Number of notifications to send + count: u32, + } + impl StartNotificationStream { + pub async fn call_tool( + &self, + runtime: Arc<dyn McpServer>, + ) -> Result<CallToolResult, CallToolError> { + for i in 0..self.count { + let _ = runtime + .send_logging_message(LoggingMessageNotificationParams { + data: json!({"id":format!("message {} of {}",i,self.count)}), + level: rust_mcp_sdk::schema::LoggingLevel::Emergency, + logger: None, + }) + .await; + tokio::time::sleep(Duration::from_millis(self.interval)).await; + } + + let message = format!("{} notifications sent", self.count); + Ok(CallToolResult::text_content(vec![TextContent::from( + message, + )])) + } + } } pub async fn wiremock_request(mock_server: &MockServer, index: usize) -> Request { diff --git a/crates/rust-mcp-sdk/tests/common/test_server.rs b/crates/rust-mcp-sdk/tests/common/test_server.rs index 769f8c6..d64244b 100644 --- a/crates/rust-mcp-sdk/tests/common/test_server.rs +++ b/crates/rust-mcp-sdk/tests/common/test_server.rs @@ -7,6 +7,7 @@ pub mod test_server_common { CallToolRequest, CallToolResult, ListToolsRequest, ListToolsResult, ProtocolVersion, RpcError, }; + use rust_mcp_sdk::event_store::EventStore; use rust_mcp_sdk::id_generator::IdGenerator; use rust_mcp_sdk::mcp_server::hyper_runtime::HyperRuntime; use rust_mcp_sdk::schema::{ @@ -31,6 +32,7 @@ pub mod test_server_common { pub streamable_url: String, pub sse_url: String, pub sse_message_url: String, + pub event_store: Option<Arc<dyn EventStore>>, } pub fn initialize_request() -> InitializeRequest { @@ -120,6 +122,7 @@ let sse_url = options.sse_url(); let sse_message_url = options.sse_message_url(); + let event_store_clone = options.event_store.clone(); let server = hyper_server::create_server(test_server_details(), TestServerHandler {}, options); @@ -132,6 +135,7 @@ streamable_url, sse_url, sse_message_url, + event_store: event_store_clone, } } diff --git a/crates/rust-mcp-sdk/tests/test_streamable_http_client.rs b/crates/rust-mcp-sdk/tests/test_streamable_http_client.rs index cb82ff5..1d273e5 100644 --- a/crates/rust-mcp-sdk/tests/test_streamable_http_client.rs +++ b/crates/rust-mcp-sdk/tests/test_streamable_http_client.rs @@ -350,6 +350,7 @@ async fn should_receive_server_initiated_messaged() { streamable_url, sse_url, sse_message_url, + event_store, } = create_start_server(server_options).await; let (client, message_history) = create_client(&streamable_url, None).await; diff --git a/crates/rust-mcp-sdk/tests/test_streamable_http_server.rs b/crates/rust-mcp-sdk/tests/test_streamable_http_server.rs index 4809d6d..af2dce6 100644 --- a/crates/rust-mcp-sdk/tests/test_streamable_http_server.rs +++ b/crates/rust-mcp-sdk/tests/test_streamable_http_server.rs @@ -12,7 +12,7 @@ use rust_mcp_schema::{ LoggingMessageNotificationParams, RequestId, RootsListChangedNotification, ServerNotification, ServerRequest, ServerResult, }; -use rust_mcp_sdk::mcp_server::HyperServerOptions; +use rust_mcp_sdk::{event_store::InMemoryEventStore, mcp_server::HyperServerOptions}; use serde_json::{json, Map, Value}; use crate::common::{ @@ -40,6 +40,8 @@ async fn initialize_server( "AAA-BBB-CCC".to_string() ]))), enable_json_response, + ping_interval: Duration::from_secs(1), + event_store: Some(Arc::new(InMemoryEventStore::default())), ..Default::default() }; @@ -169,7 +171,7 @@ async fn
should_handle_post_requests_via_sse_response_correctly() { assert_eq!(response.status(), StatusCode::OK); let events = read_sse_event(response, 1).await.unwrap(); - let message: ServerJsonrpcResponse = serde_json::from_str(&events[0]).unwrap(); + let message: ServerJsonrpcResponse = serde_json::from_str(&events[0].2).unwrap(); assert!(matches!(message.id, RequestId::Integer(1))); @@ -220,7 +222,7 @@ async fn should_call_a_tool_and_return_the_result() { assert_eq!(response.status(), StatusCode::OK); let events = read_sse_event(response, 1).await.unwrap(); - let message: ServerJsonrpcResponse = serde_json::from_str(&events[0]).unwrap(); + let message: ServerJsonrpcResponse = serde_json::from_str(&events[0].2).unwrap(); assert!(matches!(message.id, RequestId::Integer(1))); @@ -290,12 +292,20 @@ async fn should_reject_invalid_session_id() { server.hyper_runtime.await_server().await.unwrap() } -async fn get_standalone_stream(streamable_url: &str, session_id: &str) -> reqwest::Response { +async fn get_standalone_stream( + streamable_url: &str, + session_id: &str, + last_event_id: Option<&str>, +) -> reqwest::Response { let mut headers = HashMap::new(); headers.insert("Accept", "text/event-stream , application/json"); headers.insert("mcp-session-id", session_id); headers.insert("mcp-protocol-version", "2025-03-26"); + if let Some(last_event_id) = last_event_id.clone() { + headers.insert("last-event-id", last_event_id); + } + let response = send_get_request(streamable_url, Some(headers)) .await .unwrap(); @@ -306,7 +316,7 @@ async fn get_standalone_stream(streamable_url: &str, session_id: &str) -> reqwes #[tokio::test] async fn should_establish_standalone_stream_and_receive_server_messages() { let (server, session_id) = initialize_server(None).await.unwrap(); - let response = get_standalone_stream(&server.streamable_url, &session_id).await; + let response = get_standalone_stream(&server.streamable_url, &session_id, None).await; assert_eq!(response.status(), StatusCode::OK); @@ -345,7 +355,7 @@ async fn should_establish_standalone_stream_and_receive_server_messages() { .unwrap(); let events = read_sse_event(response, 1).await.unwrap(); - let message: ServerJsonrpcNotification = serde_json::from_str(&events[0]).unwrap(); + let message: ServerJsonrpcNotification = serde_json::from_str(&events[0].2).unwrap(); let NotificationFromServer::ServerNotification(ServerNotification::LoggingMessageNotification( notification, @@ -368,7 +378,7 @@ async fn should_establish_standalone_stream_and_receive_server_messages() { #[tokio::test] async fn should_establish_standalone_stream_and_receive_server_requests() { let (server, session_id) = initialize_server(None).await.unwrap(); - let response = get_standalone_stream(&server.streamable_url, &session_id).await; + let response = get_standalone_stream(&server.streamable_url, &session_id, None).await; assert_eq!(response.status(), StatusCode::OK); @@ -429,14 +439,14 @@ async fn should_establish_standalone_stream_and_receive_server_requests() { // read two events from the sse stream let events = read_sse_event(response, 2).await.unwrap(); - let message1: ServerJsonrpcRequest = serde_json::from_str(&events[0]).unwrap(); + let message1: ServerJsonrpcRequest = serde_json::from_str(&events[0].2).unwrap(); let RequestFromServer::ServerRequest(ServerRequest::ListRootsRequest(_)) = message1.request else { panic!("invalid message received!"); }; - let message2: ServerJsonrpcRequest = serde_json::from_str(&events[1]).unwrap(); + let message2: ServerJsonrpcRequest = 
serde_json::from_str(&events[1].2).unwrap(); let RequestFromServer::ServerRequest(ServerRequest::ListRootsRequest(_)) = message1.request else { @@ -453,7 +463,7 @@ async fn should_establish_standalone_stream_and_receive_server_requests() { #[tokio::test] async fn should_not_close_get_sse_stream() { let (server, session_id) = initialize_server(None).await.unwrap(); - let response = get_standalone_stream(&server.streamable_url, &session_id).await; + let response = get_standalone_stream(&server.streamable_url, &session_id, None).await; assert_eq!(response.status(), StatusCode::OK); @@ -472,7 +482,7 @@ async fn should_not_close_get_sse_stream() { let mut stream = response.bytes_stream(); let event = read_sse_event_from_stream(&mut stream, 1).await.unwrap()[0].clone(); - let message: ServerJsonrpcNotification = serde_json::from_str(&event).unwrap(); + let message: ServerJsonrpcNotification = serde_json::from_str(&event.2).unwrap(); let NotificationFromServer::ServerNotification(ServerNotification::LoggingMessageNotification( notification, @@ -501,7 +511,7 @@ async fn should_not_close_get_sse_stream() { .unwrap(); let event = read_sse_event_from_stream(&mut stream, 1).await.unwrap()[0].clone(); - let message: ServerJsonrpcNotification = serde_json::from_str(&event).unwrap(); + let message: ServerJsonrpcNotification = serde_json::from_str(&event.2).unwrap(); let NotificationFromServer::ServerNotification(ServerNotification::LoggingMessageNotification( notification_2, @@ -524,10 +534,10 @@ async fn should_not_close_get_sse_stream() { #[tokio::test] async fn should_reject_second_sse_stream_for_the_same_session() { let (server, session_id) = initialize_server(None).await.unwrap(); - let response = get_standalone_stream(&server.streamable_url, &session_id).await; + let response = get_standalone_stream(&server.streamable_url, &session_id, None).await; assert_eq!(response.status(), StatusCode::OK); - let second_response = get_standalone_stream(&server.streamable_url, &session_id).await; + let second_response = get_standalone_stream(&server.streamable_url, &session_id, None).await; assert_eq!(second_response.status(), StatusCode::CONFLICT); let error_data: SdkError = second_response.json().await.unwrap(); @@ -713,7 +723,7 @@ async fn should_send_response_messages_to_the_connection_that_sent_the_request() assert_eq!(response_2.status(), StatusCode::OK); let events = read_sse_event(response_2, 1).await.unwrap(); - let message: ServerJsonrpcResponse = serde_json::from_str(&events[0]).unwrap(); + let message: ServerJsonrpcResponse = serde_json::from_str(&events[0].2).unwrap(); assert!(matches!(message.id, RequestId::Integer(1))); @@ -729,7 +739,7 @@ async fn should_send_response_messages_to_the_connection_that_sent_the_request() ); let events = read_sse_event(response_1, 1).await.unwrap(); - let message: ServerJsonrpcResponse = serde_json::from_str(&events[0]).unwrap(); + let message: ServerJsonrpcResponse = serde_json::from_str(&events[0].2).unwrap(); assert!(matches!(message.id, RequestId::Integer(1))); @@ -1080,7 +1090,7 @@ async fn should_handle_batch_request_messages_with_sse_stream_for_responses() { ); let events = read_sse_event(response, 1).await.unwrap(); - let message: ServerMessages = serde_json::from_str(&events[0]).unwrap(); + let message: ServerMessages = serde_json::from_str(&events[0].2).unwrap(); let ServerMessages::Batch(mut messages) = message else { panic!("Invalid message type"); @@ -1358,5 +1368,177 @@ async fn should_skip_all_validations_when_false() { 
server.hyper_runtime.await_server().await.unwrap() } -//TODO: +// should store and include event IDs in server SSE messages +#[tokio::test] +async fn should_store_and_include_event_ids_in_server_sse_messages() { + common::init_tracing(); + let (server, session_id) = initialize_server(Some(true)).await.unwrap(); + let response = get_standalone_stream(&server.streamable_url, &session_id, None).await; + + assert_eq!(response.status(), StatusCode::OK); + + let _ = server + .hyper_runtime + .send_logging_message( + &session_id, + LoggingMessageNotificationParams { + data: json!("notification1"), + level: LoggingLevel::Info, + logger: None, + }, + ) + .await; + + let _ = server + .hyper_runtime + .send_logging_message( + &session_id, + LoggingMessageNotificationParams { + data: json!("notification2"), + level: LoggingLevel::Info, + logger: None, + }, + ) + .await; + + // read two events + let events = read_sse_event(response, 2).await.unwrap(); + assert_eq!(events.len(), 2); + // verify we got the notification with an event ID + let (first_id, _, data) = events[0].clone(); + let (second_id, _, _) = events[1].clone(); + + let message: ServerJsonrpcNotification = serde_json::from_str(&data).unwrap(); + + let NotificationFromServer::ServerNotification(ServerNotification::LoggingMessageNotification( + notification1, + )) = message.notification + else { + panic!("invalid message received!"); + }; + + assert_eq!(notification1.params.data.as_str().unwrap(), "notification1"); + + let first_id = first_id.unwrap(); + assert!(second_id.is_some()); + + // messages should be stored and accessible + let events = server + .event_store + .unwrap() + .events_after(first_id) + .await + .unwrap(); + assert_eq!(events.messages.len(), 1); + + // deserialize the message returned by event_store + let message: ServerJsonrpcNotification = serde_json::from_str(&events.messages[0]).unwrap(); + let NotificationFromServer::ServerNotification(ServerNotification::LoggingMessageNotification( + notification2, + )) = message.notification + else { + panic!("invalid message in store!"); + }; + assert_eq!(notification2.params.data.as_str().unwrap(), "notification2"); +} + +// should store and replay MCP server tool notifications +#[tokio::test] +async fn should_store_and_replay_mcp_server_tool_notifications() { + common::init_tracing(); + let (server, session_id) = initialize_server(Some(true)).await.unwrap(); + let response = get_standalone_stream(&server.streamable_url, &session_id, None).await; + assert_eq!(response.status(), StatusCode::OK); + + let _ = server + .hyper_runtime + .send_logging_message( + &session_id, + LoggingMessageNotificationParams { + data: json!("notification1"), + level: LoggingLevel::Info, + logger: None, + }, + ) + .await; + + let events = read_sse_event(response, 1).await.unwrap(); + assert_eq!(events.len(), 1); + // verify we got the notification with an event ID + let (first_id, _, data) = events[0].clone(); + + let message: ServerJsonrpcNotification = serde_json::from_str(&data).unwrap(); + + let NotificationFromServer::ServerNotification(ServerNotification::LoggingMessageNotification( + notification1, + )) = message.notification + else { + panic!("invalid message received!"); + }; + + assert_eq!(notification1.params.data.as_str().unwrap(), "notification1"); + + let first_id = first_id.unwrap(); + + // sse connection is closed in read_sse_event() + // wait so the server detects the disconnect, simulating a network error + tokio::time::sleep(Duration::from_secs(3)).await; + tokio::task::yield_now().await;
+ // we send another notification while SSE is disconnected + let _result = server + .hyper_runtime + .send_logging_message( + &session_id, + LoggingMessageNotificationParams { + data: json!("notification2"), + level: LoggingLevel::Info, + logger: None, + }, + ) + .await; + + // make a new standalone SSE connection to simulate a re-connection + let response = + get_standalone_stream(&server.streamable_url, &session_id, Some(&first_id)).await; + assert_eq!(response.status(), StatusCode::OK); + let events = read_sse_event(response, 1).await.unwrap(); + + assert_eq!(events.len(), 1); + let message: ServerJsonrpcNotification = serde_json::from_str(&events[0].2).unwrap(); + + let NotificationFromServer::ServerNotification(ServerNotification::LoggingMessageNotification( + notification1, + )) = message.notification + else { + panic!("invalid message received!"); + }; + + assert_eq!(notification1.params.data.as_str().unwrap(), "notification2"); +} + // should return 400 error for invalid JSON-RPC messages +// should keep stream open after sending server notifications +// NA: should reject second initialization request +// NA: should pass request info to tool callback +// NA: should reject second SSE stream even in stateless mode +// should reject requests to uninitialized server +// should accept requests with matching protocol version +// should accept when protocol version differs from negotiated version +// should call a tool with authInfo +// should calls tool without authInfo when it is optional +// should accept pre-parsed request body +// should handle pre-parsed batch messages +// should prefer pre-parsed body over request body +// should operate without session ID validation +// should handle POST requests with various session IDs in stateless mode +// should call onsessionclosed callback when session is closed via DELETE +// should not call onsessionclosed callback when not provided +// should not call onsessionclosed callback for invalid session DELETE +// should call onsessionclosed callback with correct session ID when multiple sessions exist +// should support async onsessioninitialized callback +// should support sync onsessioninitialized callback (backwards compatibility) +// should support async onsessionclosed callback +// should propagate errors from async onsessioninitialized callback +// should propagate errors from async onsessionclosed callback +// should handle both async callbacks together +// should validate both host and origin when both are configured diff --git a/crates/rust-mcp-transport/src/client_sse.rs b/crates/rust-mcp-transport/src/client_sse.rs index 8d55bd0..0a1e8f3 100644 --- a/crates/rust-mcp-transport/src/client_sse.rs +++ b/crates/rust-mcp-transport/src/client_sse.rs @@ -457,10 +457,10 @@ impl McpDispatch sender.send_batch(message, request_timeout).await } - async fn write_str(&self, payload: &str) -> TransportResult<()> { + async fn write_str(&self, payload: &str, skip_store: bool) -> TransportResult<()> { let sender = self.message_sender.read().await; let sender = sender.as_ref().ok_or(SdkError::connection_closed())?; - sender.write_str(payload).await + sender.write_str(payload, skip_store).await } } diff --git a/crates/rust-mcp-transport/src/client_streamable_http.rs b/crates/rust-mcp-transport/src/client_streamable_http.rs index c318649..edda062 100644 --- a/crates/rust-mcp-transport/src/client_streamable_http.rs +++ b/crates/rust-mcp-transport/src/client_streamable_http.rs @@ -496,10 +496,10 @@ impl McpDispatch sender.send_batch(message, 
request_timeout).await } - async fn write_str(&self, payload: &str) -> TransportResult<()> { + async fn write_str(&self, payload: &str, skip_store: bool) -> TransportResult<()> { let sender = self.message_sender.read().await; let sender = sender.as_ref().ok_or(SdkError::connection_closed())?; - sender.write_str(payload).await + sender.write_str(payload, skip_store).await } } diff --git a/crates/rust-mcp-transport/src/event_store.rs b/crates/rust-mcp-transport/src/event_store.rs new file mode 100644 index 0000000..fdc0734 --- /dev/null +++ b/crates/rust-mcp-transport/src/event_store.rs @@ -0,0 +1,27 @@ +mod in_memory_event_store; +use async_trait::async_trait; +pub use in_memory_event_store::*; + +use crate::{EventId, SessionId, StreamId}; + +#[derive(Debug, Clone)] +pub struct EventStoreMessages { + pub session_id: SessionId, + pub stream_id: StreamId, + pub messages: Vec<String>, +} + +#[async_trait] +pub trait EventStore: Send + Sync { + async fn store_event( + &self, + session_id: SessionId, + stream_id: StreamId, + time_stamp: u128, + message: String, + ) -> EventId; + async fn remove_by_session_id(&self, session_id: SessionId); + async fn remove_stream_in_session(&self, session_id: SessionId, stream_id: StreamId); + async fn clear(&self); + async fn events_after(&self, last_event_id: EventId) -> Option<EventStoreMessages>; +} diff --git a/crates/rust-mcp-transport/src/event_store/in_memory_event_store.rs b/crates/rust-mcp-transport/src/event_store/in_memory_event_store.rs new file mode 100644 index 0000000..66e738c --- /dev/null +++ b/crates/rust-mcp-transport/src/event_store/in_memory_event_store.rs @@ -0,0 +1,274 @@ +use async_trait::async_trait; +use std::collections::HashMap; +use std::collections::VecDeque; +use tokio::sync::RwLock; + +use crate::{ + event_store::{EventStore, EventStoreMessages}, + EventId, SessionId, StreamId, +}; + +const MAX_EVENTS_PER_SESSION: usize = 64; +const ID_SEPARATOR: &str = "-.-"; + +#[derive(Debug, Clone)] +struct EventEntry { + pub stream_id: StreamId, + pub time_stamp: u128, + pub message: String, +} + +#[derive(Debug)] +pub struct InMemoryEventStore { + max_events_per_session: usize, + storage_map: RwLock<HashMap<SessionId, VecDeque<EventEntry>>>, +} + +impl Default for InMemoryEventStore { + fn default() -> Self { + Self { + max_events_per_session: MAX_EVENTS_PER_SESSION, + storage_map: Default::default(), + } + } +} + +/// In-memory implementation of the `EventStore` trait for MCP's Streamable HTTP transport. +/// +/// Stores events in a `HashMap` of session IDs to `VecDeque`s of events, with a per-session limit. +/// Events are identified by `event_id` (format: `session-.-stream-.-timestamp`) and used for SSE resumption. +/// Thread-safe via `RwLock` for concurrent access. +impl InMemoryEventStore { + /// Creates a new `InMemoryEventStore` with an optional maximum events per session. + /// + /// # Arguments + /// - `max_events_per_session`: Maximum number of events per session. Defaults to `MAX_EVENTS_PER_SESSION` (64) if `None`. + /// + /// # Returns + /// A new `InMemoryEventStore` instance with an empty `HashMap` wrapped in a `RwLock`. + /// + /// # Example + /// ``` + /// use rust_mcp_transport::event_store::InMemoryEventStore; + /// let store = InMemoryEventStore::new(Some(10)); + /// ``` + pub fn new(max_events_per_session: Option<usize>) -> Self { + Self { + max_events_per_session: max_events_per_session.unwrap_or(MAX_EVENTS_PER_SESSION), + storage_map: RwLock::new(HashMap::new()), + } + } + + /// Generates an `event_id` string from session, stream, and timestamp components.
+ /// + /// Format: `session-.-stream-.-timestamp`, used as a resumption cursor in SSE (`Last-Event-ID`). + /// + /// # Arguments + /// - `session_id`: The session identifier. + /// - `stream_id`: The stream identifier. + /// - `time_stamp`: The event timestamp (u128). + /// + /// # Returns + /// A `String` in the format `session-.-stream-.-timestamp`. + fn generate_event_id( + &self, + session_id: &SessionId, + stream_id: &StreamId, + time_stamp: u128, + ) -> String { + format!("{session_id}{ID_SEPARATOR}{stream_id}{ID_SEPARATOR}{time_stamp}") + } + + /// Parses an event ID into its session, stream, and timestamp components. + /// + /// The event ID must follow the format `session-.-stream-.-timestamp`. + /// Returns `None` if the format is invalid, empty, or contains invalid characters (e.g., NULL). + /// + /// # Arguments + /// - `event_id`: The event ID string to parse. + /// + /// # Returns + /// An `Option` containing a tuple of `(session_id, stream_id, time_stamp)` as string slices, + /// or `None` if the format is invalid. + /// + /// # Example + /// ``` + /// let store = InMemoryEventStore::new(None); + /// let event_id = "session1-.-stream1-.-12345"; + /// assert_eq!( + /// store.parse_event_id(event_id), + /// Some(("session1", "stream1", "12345")) + /// ); + /// assert_eq!(store.parse_event_id("invalid"), None); + /// ``` + pub fn parse_event_id<'a>(&self, event_id: &'a str) -> Option<(&'a str, &'a str, &'a str)> { + // Check for empty input or invalid characters (e.g., NULL) + if event_id.is_empty() || event_id.contains('\0') { + return None; + } + + // Split into exactly three parts + let parts: Vec<&'a str> = event_id.split(ID_SEPARATOR).collect(); + if parts.len() != 3 { + return None; + } + + let session_id = parts[0]; + let stream_id = parts[1]; + let time_stamp = parts[2]; + + // Ensure no part is empty + if session_id.is_empty() || stream_id.is_empty() || time_stamp.is_empty() { + return None; + } + + Some((session_id, stream_id, time_stamp)) + } +} + +#[async_trait] +impl EventStore for InMemoryEventStore { + /// Stores an event for a given session and stream, returning its `event_id`. + /// + /// Adds the event to the session’s `VecDeque`, removing the oldest event if the session + /// reaches `max_events_per_session`. + /// + /// # Arguments + /// - `session_id`: The session identifier. + /// - `stream_id`: The stream identifier. + /// - `time_stamp`: The event timestamp (u128). + /// - `message`: The `ServerMessages` payload. + /// + /// # Returns + /// The generated `EventId` for the stored event. + async fn store_event( + &self, + session_id: SessionId, + stream_id: StreamId, + time_stamp: u128, + message: String, + ) -> EventId { + let event_id = self.generate_event_id(&session_id, &stream_id, time_stamp); + + let mut storage_map = self.storage_map.write().await; + + tracing::trace!( + "Storing event for session: {session_id}, stream_id: {stream_id}, message: '{message}', {time_stamp} ", + ); + + let session_map = storage_map + .entry(session_id) + .or_insert_with(|| VecDeque::with_capacity(self.max_events_per_session)); + + if session_map.len() == self.max_events_per_session { + session_map.pop_front(); // remove the oldest if full + } + + let entry = EventEntry { + stream_id, + time_stamp, + message, + }; + + session_map.push_back(entry); + + event_id + } + + /// Removes all events associated with a given stream ID within a specific session. + /// + /// Removes events matching `stream_id` from the specified `session_id`’s event queue. 
+ /// If the session’s queue becomes empty, it is removed from the store. + /// Idempotent if `session_id` or `stream_id` doesn’t exist. + /// + /// # Arguments + /// - `session_id`: The session identifier to target. + /// - `stream_id`: The stream identifier to remove. + async fn remove_stream_in_session(&self, session_id: SessionId, stream_id: StreamId) { + let mut storage_map = self.storage_map.write().await; + + // Check if session exists + if let Some(events) = storage_map.get_mut(&session_id) { + // Remove events with the given stream_id + events.retain(|event| event.stream_id != stream_id); + // Remove session if empty + if events.is_empty() { + storage_map.remove(&session_id); + } + } + // No action if session_id doesn’t exist (idempotent) + } + + /// Removes all events associated with a given session ID. + /// + /// Removes the entire session from the store. Idempotent if `session_id` doesn’t exist. + /// + /// # Arguments + /// - `session_id`: The session identifier to remove. + async fn remove_by_session_id(&self, session_id: SessionId) { + let mut storage_map = self.storage_map.write().await; + storage_map.remove(&session_id); + } + + /// Retrieves events after a given `event_id` for a specific session and stream. + /// + /// Parses `last_event_id` to extract `session_id`, `stream_id`, and `time_stamp`. + /// Returns events after the matching event in the session’s stream, sorted by timestamp + /// in ascending order (earliest to latest). Returns `None` if the `event_id` is invalid, + /// the session doesn’t exist, or the timestamp is non-numeric. + /// + /// # Arguments + /// - `last_event_id`: The event ID (format: `session-.-stream-.-timestamp`) to start after. + /// + /// # Returns + /// An `Option` containing `EventStoreMessages` with the session ID, stream ID, and sorted messages, + /// or `None` if no events are found or the input is invalid. + async fn events_after(&self, last_event_id: EventId) -> Option { + let Some((session_id, stream_id, time_stamp)) = self.parse_event_id(&last_event_id) else { + tracing::warn!("error parsing last event id: '{last_event_id}'"); + return None; + }; + + let storage_map = self.storage_map.read().await; + let Some(events) = storage_map.get(session_id) else { + tracing::warn!("could not find the session_id in the store : '{session_id}'"); + return None; + }; + + let Ok(time_stamp) = time_stamp.parse::() else { + tracing::warn!("could not parse the timestamp: '{time_stamp}'"); + return None; + }; + + let events = match events + .iter() + .position(|e| e.stream_id == stream_id && e.time_stamp == time_stamp) + { + Some(index) if index + 1 < events.len() => { + // Collect subsequent events that match the stream_id + let mut subsequent: Vec<_> = events + .range(index + 1..) 
+ .filter(|e| e.stream_id == stream_id) + .cloned() + .collect(); + + subsequent.sort_by(|a, b| a.time_stamp.cmp(&b.time_stamp)); + subsequent.iter().map(|e| e.message.clone()).collect() + } + _ => vec![], + }; + + tracing::trace!("{} messages after '{last_event_id}'", events.len()); + + Some(EventStoreMessages { + session_id: session_id.to_string(), + stream_id: stream_id.to_string(), + messages: events, + }) + } + + async fn clear(&self) { + let mut storage_map = self.storage_map.write().await; + storage_map.clear(); + } +} diff --git a/crates/rust-mcp-transport/src/lib.rs b/crates/rust-mcp-transport/src/lib.rs index 4a918db..d21e5dd 100644 --- a/crates/rust-mcp-transport/src/lib.rs +++ b/crates/rust-mcp-transport/src/lib.rs @@ -8,6 +8,7 @@ mod client_sse; mod client_streamable_http; mod constants; pub mod error; +pub mod event_store; mod mcp_stream; mod message_dispatcher; mod schema; diff --git a/crates/rust-mcp-transport/src/message_dispatcher.rs b/crates/rust-mcp-transport/src/message_dispatcher.rs index 7c7c93e..cd9727c 100644 --- a/crates/rust-mcp-transport/src/message_dispatcher.rs +++ b/crates/rust-mcp-transport/src/message_dispatcher.rs @@ -1,13 +1,20 @@ -use crate::schema::{ - schema_utils::{ - self, ClientMessage, ClientMessages, McpMessage, RpcMessage, ServerMessage, ServerMessages, +use crate::error::{TransportError, TransportResult}; +use crate::schema::{RequestId, RpcError}; +use crate::utils::{await_timeout, current_timestamp}; +use crate::McpDispatch; +use crate::{ + event_store::EventStore, + schema::{ + schema_utils::{ + self, ClientMessage, ClientMessages, McpMessage, RpcMessage, ServerMessage, + ServerMessages, + }, + JsonrpcError, }, - JsonrpcError, + SessionId, StreamId, }; -use crate::schema::{RequestId, RpcError}; use async_trait::async_trait; use futures::future::join_all; - use std::collections::HashMap; use std::pin::Pin; use std::sync::Arc; @@ -16,9 +23,7 @@ use tokio::io::AsyncWriteExt; use tokio::sync::oneshot::{self}; use tokio::sync::Mutex; -use crate::error::{TransportError, TransportResult}; -use crate::utils::await_timeout; -use crate::McpDispatch; +pub const ID_SEPARATOR: u8 = b'|'; /// Provides a dispatcher for sending MCP messages and handling responses. /// @@ -37,6 +42,10 @@ pub struct MessageDispatcher { )>, >, request_timeout: Duration, + // resumability support + session_id: Option, + stream_id: Option, + event_store: Option>, } impl MessageDispatcher { @@ -60,6 +69,9 @@ impl MessageDispatcher { writable_std: Some(writable_std), writable_tx: None, request_timeout, + session_id: None, + stream_id: None, + event_store: None, } } @@ -76,9 +88,25 @@ impl MessageDispatcher { writable_tx: Some(writable_tx), writable_std: None, request_timeout, + session_id: None, + stream_id: None, + event_store: None, } } + /// Supports resumability for streamable HTTP transports by setting the session ID, + /// stream ID, and event store. 
+ pub fn make_resumable( + &mut self, + session_id: SessionId, + stream_id: StreamId, + event_store: Arc, + ) { + self.session_id = Some(session_id); + self.stream_id = Some(stream_id); + self.event_store = Some(event_store); + } + async fn store_pending_request( &self, request_id: RequestId, @@ -141,7 +169,7 @@ impl McpDispatch crate::error::TransportError::JsonrpcError(RpcError::parse_error()) })?; - self.write_str(message_payload.as_str()).await?; + self.write_str(message_payload.as_str(), true).await?; if let Some(rx) = rx_response { // Wait for the response with timeout @@ -177,7 +205,7 @@ impl McpDispatch let message_payload = serde_json::to_string(&client_messages).map_err(|_| { crate::error::TransportError::JsonrpcError(RpcError::parse_error()) })?; - self.write_str(message_payload.as_str()).await?; + self.write_str(message_payload.as_str(), true).await?; // no request in the batch, no need to wait for the result if request_ids.is_empty() { @@ -233,7 +261,7 @@ impl McpDispatch /// Writes a string payload to the underlying asynchronous writable stream, /// appending a newline character and flushing the stream afterward. /// - async fn write_str(&self, payload: &str) -> TransportResult<()> { + async fn write_str(&self, payload: &str, _skip_store: bool) -> TransportResult<()> { if let Some(writable_std) = self.writable_std.as_ref() { let mut writable_std = writable_std.lock().await; writable_std.write_all(payload.as_bytes()).await?; @@ -289,7 +317,7 @@ impl McpDispatch crate::error::TransportError::JsonrpcError(RpcError::parse_error()) })?; - self.write_str(message_payload.as_str()).await?; + self.write_str(message_payload.as_str(), false).await?; if let Some(rx) = rx_response { match await_timeout(rx, request_timeout.unwrap_or(self.request_timeout)).await { @@ -317,7 +345,7 @@ impl McpDispatch crate::error::TransportError::JsonrpcError(RpcError::parse_error()) })?; - self.write_str(message_payload.as_str()).await?; + self.write_str(message_payload.as_str(), false).await?; // no request in the batch, no need to wait for the result if pending_tasks.is_empty() { @@ -375,9 +403,34 @@ impl McpDispatch /// Writes a string payload to the underlying asynchronous writable stream, /// appending a newline character and flushing the stream afterward. 
/// - async fn write_str(&self, payload: &str) -> TransportResult<()> { + async fn write_str(&self, payload: &str, skip_store: bool) -> TransportResult<()> { + let mut event_id = None; + + if !skip_store && !payload.trim().is_empty() { + if let (Some(session_id), Some(stream_id), Some(event_store)) = ( + self.session_id.as_ref(), + self.stream_id.as_ref(), + self.event_store.as_ref(), + ) { + event_id = Some( + event_store + .store_event( + session_id.clone(), + stream_id.clone(), + current_timestamp(), + payload.to_owned(), + ) + .await, + ) + }; + } + if let Some(writable_std) = self.writable_std.as_ref() { let mut writable_std = writable_std.lock().await; + if let Some(id) = event_id { + writable_std.write_all(id.as_bytes()).await?; + writable_std.write_all(&[ID_SEPARATOR]).await?; // separate id from message + } writable_std.write_all(payload.as_bytes()).await?; writable_std.write_all(b"\n").await?; // new line writable_std.flush().await?; diff --git a/crates/rust-mcp-transport/src/sse.rs b/crates/rust-mcp-transport/src/sse.rs index 09809e4..89ca67f 100644 --- a/crates/rust-mcp-transport/src/sse.rs +++ b/crates/rust-mcp-transport/src/sse.rs @@ -1,3 +1,4 @@ +use crate::event_store::EventStore; use crate::schema::schema_utils::{ ClientMessage, ClientMessages, MessageFromServer, SdkError, ServerMessage, ServerMessages, }; @@ -19,7 +20,7 @@ use crate::mcp_stream::MCPStream; use crate::message_dispatcher::MessageDispatcher; use crate::transport::Transport; use crate::utils::{endpoint_with_session_id, CancellationTokenSource}; -use crate::{IoStream, McpDispatch, SessionId, TransportDispatcher, TransportOptions}; +use crate::{IoStream, McpDispatch, SessionId, StreamId, TransportDispatcher, TransportOptions}; pub struct SseTransport where @@ -33,6 +34,10 @@ where message_sender: Arc>>>, error_stream: tokio::sync::RwLock>, pending_requests: Arc>>>, + // resumability support + session_id: Option, + stream_id: Option, + event_store: Option>, } /// Server-Sent Events (SSE) transport implementation @@ -67,6 +72,9 @@ where message_sender: Arc::new(tokio::sync::RwLock::new(None)), error_stream: tokio::sync::RwLock::new(None), pending_requests: Arc::new(Mutex::new(HashMap::new())), + session_id: None, + stream_id: None, + event_store: None, }) } @@ -86,6 +94,19 @@ where let mut lock = self.error_stream.write().await; *lock = Some(IoStream::Writable(error_stream)); } + + /// Supports resumability for streamable HTTP transports by setting the session ID, + /// stream ID, and event store. 
+ pub fn make_resumable( + &mut self, + session_id: SessionId, + stream_id: StreamId, + event_store: Arc, + ) { + self.session_id = Some(session_id); + self.stream_id = Some(stream_id); + self.event_store = Some(event_store); + } } #[async_trait] @@ -123,10 +144,10 @@ impl McpDispatch sender.send_batch(message, request_timeout).await } - async fn write_str(&self, payload: &str) -> TransportResult<()> { + async fn write_str(&self, payload: &str, skip_store: bool) -> TransportResult<()> { let sender = self.message_sender.read().await; let sender = sender.as_ref().ok_or(SdkError::connection_closed())?; - sender.write_str(payload).await + sender.write_str(payload, skip_store).await } } @@ -161,7 +182,7 @@ impl Transport( + let (stream, mut sender, error_stream) = MCPStream::create::( Box::pin(read_rx), Mutex::new(Box::pin(write_tx)), IoStream::Writable(Box::pin(tokio::io::stderr())), @@ -170,6 +191,18 @@ impl Transport {} Err(TransportError::Io(error)) => { if error.kind() == std::io::ErrorKind::BrokenPipe { diff --git a/crates/rust-mcp-transport/src/stdio.rs b/crates/rust-mcp-transport/src/stdio.rs index 11bd0a6..7678c65 100644 --- a/crates/rust-mcp-transport/src/stdio.rs +++ b/crates/rust-mcp-transport/src/stdio.rs @@ -348,10 +348,10 @@ impl McpDispatch sender.send_batch(message, request_timeout).await } - async fn write_str(&self, payload: &str) -> TransportResult<()> { + async fn write_str(&self, payload: &str, skip_store: bool) -> TransportResult<()> { let sender = self.message_sender.read().await; let sender = sender.as_ref().ok_or(SdkError::connection_closed())?; - sender.write_str(payload).await + sender.write_str(payload, skip_store).await } } @@ -400,10 +400,10 @@ impl McpDispatch sender.send_batch(message, request_timeout).await } - async fn write_str(&self, payload: &str) -> TransportResult<()> { + async fn write_str(&self, payload: &str, skip_store: bool) -> TransportResult<()> { let sender = self.message_sender.read().await; let sender = sender.as_ref().ok_or(SdkError::connection_closed())?; - sender.write_str(payload).await + sender.write_str(payload, skip_store).await } } diff --git a/crates/rust-mcp-transport/src/transport.rs b/crates/rust-mcp-transport/src/transport.rs index b8e3ddc..a9e7190 100644 --- a/crates/rust-mcp-transport/src/transport.rs +++ b/crates/rust-mcp-transport/src/transport.rs @@ -82,7 +82,7 @@ where /// Writes a string payload to the underlying asynchronous writable stream, /// appending a newline character and flushing the stream afterward. /// - async fn write_str(&self, payload: &str) -> TransportResult<()>; + async fn write_str(&self, payload: &str, skip_store: bool) -> TransportResult<()>; } /// A trait representing the transport layer for the MCP (Message Communication Protocol). 
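Note on the `EventStore` trait introduced above: this change set only ships the bundled `InMemoryEventStore`, but the trait can be implemented for custom persistence. The sketch below is a minimal illustration, not part of this diff; it assumes the paths shown above (`rust_mcp_transport::event_store`, plus `EventId`, `SessionId`, and `StreamId` re-exported at the crate root), that `EventId` is a plain string alias, and that the implementing crate depends on `async-trait`. `NoopEventStore` is a hypothetical name.

```rs
use async_trait::async_trait;
use rust_mcp_transport::event_store::{EventStore, EventStoreMessages};
use rust_mcp_transport::{EventId, SessionId, StreamId};

/// Hypothetical stub store: it acknowledges events but never retains them,
/// so `events_after` always reports that there is nothing to replay.
pub struct NoopEventStore;

#[async_trait]
impl EventStore for NoopEventStore {
    async fn store_event(
        &self,
        session_id: SessionId,
        stream_id: StreamId,
        time_stamp: u128,
        _message: String,
    ) -> EventId {
        // Mirror the `session-.-stream-.-timestamp` id layout used by InMemoryEventStore
        // (assuming EventId is a string alias).
        format!("{session_id}-.-{stream_id}-.-{time_stamp}")
    }

    async fn remove_by_session_id(&self, _session_id: SessionId) {}

    async fn remove_stream_in_session(&self, _session_id: SessionId, _stream_id: StreamId) {}

    async fn clear(&self) {}

    async fn events_after(&self, _last_event_id: EventId) -> Option<EventStoreMessages> {
        // Nothing is stored, so there is never anything to replay.
        None
    }
}
```

A real implementation (Redis, a database, etc.) would persist the payload in `store_event` and rebuild `EventStoreMessages` in `events_after`, much like the in-memory store added in this diff.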
diff --git a/crates/rust-mcp-transport/src/utils.rs b/crates/rust-mcp-transport/src/utils.rs index 82d7326..034f062 100644 --- a/crates/rust-mcp-transport/src/utils.rs +++ b/crates/rust-mcp-transport/src/utils.rs @@ -25,6 +25,8 @@ pub(crate) use sse_stream::*; pub(crate) use streamable_http_stream::*; #[cfg(any(feature = "sse", feature = "streamable-http"))] pub(crate) use writable_channel::*; +mod time_utils; +pub use time_utils::*; use crate::schema::schema_utils::SdkError; use tokio::time::{timeout, Duration}; diff --git a/crates/rust-mcp-transport/src/utils/time_utils.rs b/crates/rust-mcp-transport/src/utils/time_utils.rs new file mode 100644 index 0000000..25c4f5d --- /dev/null +++ b/crates/rust-mcp-transport/src/utils/time_utils.rs @@ -0,0 +1,8 @@ +use std::time::{SystemTime, UNIX_EPOCH}; + +pub fn current_timestamp() -> u128 { + SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("Invalid time") + .as_nanos() +} diff --git a/examples/hello-world-server-streamable-http-core/src/main.rs b/examples/hello-world-server-streamable-http-core/src/main.rs index 7b41c70..81a6ae5 100644 --- a/examples/hello-world-server-streamable-http-core/src/main.rs +++ b/examples/hello-world-server-streamable-http-core/src/main.rs @@ -1,7 +1,10 @@ mod handler; mod tools; +use std::sync::Arc; + use handler::MyServerHandler; +use rust_mcp_sdk::event_store::InMemoryEventStore; use rust_mcp_sdk::schema::{ Implementation, InitializeResult, ServerCapabilities, ServerCapabilitiesTools, LATEST_PROTOCOL_VERSION, @@ -48,6 +51,7 @@ async fn main() -> SdkResult<()> { handler, HyperServerOptions { sse_support: true, + event_store: Some(Arc::new(InMemoryEventStore::default())), // enable resumability ..Default::default() }, ); diff --git a/examples/hello-world-server-streamable-http/src/main.rs b/examples/hello-world-server-streamable-http/src/main.rs index cd8c658..3923a6d 100644 --- a/examples/hello-world-server-streamable-http/src/main.rs +++ b/examples/hello-world-server-streamable-http/src/main.rs @@ -1,8 +1,10 @@ mod handler; mod tools; +use std::sync::Arc; use std::time::Duration; +use rust_mcp_sdk::event_store::InMemoryEventStore; use rust_mcp_sdk::mcp_server::{hyper_server, HyperServerOptions}; use handler::MyServerHandler; @@ -57,6 +59,7 @@ async fn main() -> SdkResult<()> { HyperServerOptions { host: "127.0.0.1".to_string(), ping_interval: Duration::from_secs(5), + event_store: Some(Arc::new(InMemoryEventStore::default())), // enable resumability ..Default::default() }, );
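End-to-end, these changes mean a client that loses its standalone SSE stream can resume it by echoing the last `id:` field it received. Below is a rough client-side sketch using `reqwest` (the crate the tests already use); the helper name and the URL, session-id, and event-id values are placeholders, while the header names mirror those used by `get_standalone_stream` in the tests above.

```rs
use reqwest::Client;

/// Hypothetical helper: re-open the standalone SSE stream for an existing session and
/// ask the server to replay every stored event after `last_event_id`.
async fn resume_standalone_stream(
    streamable_url: &str, // e.g. "http://127.0.0.1:8080/mcp" (placeholder)
    session_id: &str,     // value of the `mcp-session-id` header returned at initialization
    last_event_id: &str,  // last SSE `id:` value seen before the connection dropped
) -> reqwest::Result<reqwest::Response> {
    Client::new()
        .get(streamable_url)
        .header("Accept", "text/event-stream")
        .header("mcp-session-id", session_id)
        .header("mcp-protocol-version", "2025-03-26")
        .header("last-event-id", last_event_id)
        .send()
        .await
}
```

On the server side, nothing beyond `event_store: Some(Arc::new(InMemoryEventStore::default()))` in `HyperServerOptions` is required; stored messages for that session are then replayed on the new stream before live traffic resumes.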