diff --git a/Cargo.lock b/Cargo.lock index 0acb30d..51d9cb0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1658,7 +1658,11 @@ dependencies = [ "axum", "axum-server", "base64 0.22.1", + "bytes", "futures", + "http 1.3.1", + "http-body 1.0.1", + "http-body-util", "hyper 1.7.0", "reqwest", "rust-mcp-macros", diff --git a/crates/rust-mcp-sdk/Cargo.toml b/crates/rust-mcp-sdk/Cargo.toml index 8bba7c7..be70e07 100644 --- a/crates/rust-mcp-sdk/Cargo.toml +++ b/crates/rust-mcp-sdk/Cargo.toml @@ -29,9 +29,13 @@ tokio-stream = { workspace = true, optional = true } axum-server = { version = "0.7", features = [], optional = true } tracing.workspace = true base64.workspace = true +bytes.workspace = true # rustls = { workspace = true, optional = true } hyper = { version = "1.6.0", optional = true } +http = { version ="1.3", optional = true } +http-body-util = { version ="0.1", optional = true } +http-body = { version ="1.0", optional = true } [dev-dependencies] wiremock = "0.5" @@ -61,13 +65,13 @@ default = [ "2025_06_18", ] # All features enabled by default -sse = ["rust-mcp-transport/sse"] -streamable-http = ["rust-mcp-transport/streamable-http"] +sse = ["rust-mcp-transport/sse","http","http-body","http-body-util"] +streamable-http = ["rust-mcp-transport/streamable-http","http","http-body","http-body-util"] stdio = ["rust-mcp-transport/stdio"] server = [] # Server feature client = [] # Client feature -hyper-server = ["axum", "axum-server", "hyper", "server", "tokio-stream"] +hyper-server = ["axum", "axum-server", "hyper", "server", "tokio-stream","http","http-body","http-body-util"] ssl = ["axum-server/tls-rustls"] tls-no-provider = ["axum-server/tls-rustls-no-provider"] macros = ["rust-mcp-macros/sdk"] diff --git a/crates/rust-mcp-sdk/src/hyper_servers.rs b/crates/rust-mcp-sdk/src/hyper_servers.rs index f18c428..87307c0 100644 --- a/crates/rust-mcp-sdk/src/hyper_servers.rs +++ b/crates/rust-mcp-sdk/src/hyper_servers.rs @@ -1,12 +1,8 @@ -mod app_state; pub mod error; pub mod hyper_runtime; pub mod hyper_server; pub mod hyper_server_core; -mod middlewares; mod routes; mod server; -mod session_store; pub use server::*; -pub use session_store::*; diff --git a/crates/rust-mcp-sdk/src/hyper_servers/error.rs b/crates/rust-mcp-sdk/src/hyper_servers/error.rs index 74cbcd1..dd55d8f 100644 --- a/crates/rust-mcp-sdk/src/hyper_servers/error.rs +++ b/crates/rust-mcp-sdk/src/hyper_servers/error.rs @@ -15,6 +15,8 @@ pub enum TransportServerError { StreamIoError(String), #[error("{0}")] AddrParseError(#[from] AddrParseError), + #[error("{0}")] + HttpError(String), #[error("Server start error: {0}")] ServerStartError(String), #[error("Invalid options: {0}")] diff --git a/crates/rust-mcp-sdk/src/hyper_servers/hyper_runtime.rs b/crates/rust-mcp-sdk/src/hyper_servers/hyper_runtime.rs index 85cf791..92eed79 100644 --- a/crates/rust-mcp-sdk/src/hyper_servers/hyper_runtime.rs +++ b/crates/rust-mcp-sdk/src/hyper_servers/hyper_runtime.rs @@ -1,6 +1,7 @@ use std::{sync::Arc, time::Duration}; use crate::{ + mcp_http::McpAppState, mcp_server::HyperServer, schema::{ schema_utils::{NotificationFromServer, RequestFromServer, ResultFromClient}, @@ -18,7 +19,6 @@ use tokio::{sync::Mutex, task::JoinHandle}; use crate::{ error::SdkResult, - hyper_servers::app_state::AppState, mcp_server::{ error::{TransportServerError, TransportServerResult}, ServerRuntime, @@ -26,7 +26,7 @@ use crate::{ }; pub struct HyperRuntime { - pub(crate) state: Arc, + pub(crate) state: Arc, pub(crate) server_task: JoinHandle>, pub(crate) server_handle: Handle, } diff --git a/crates/rust-mcp-sdk/src/hyper_servers/middlewares.rs b/crates/rust-mcp-sdk/src/hyper_servers/middlewares.rs deleted file mode 100644 index 0222952..0000000 --- a/crates/rust-mcp-sdk/src/hyper_servers/middlewares.rs +++ /dev/null @@ -1,2 +0,0 @@ -pub(crate) mod protect_dns_rebinding; -pub(crate) mod session_id_gen; diff --git a/crates/rust-mcp-sdk/src/hyper_servers/middlewares/protect_dns_rebinding.rs b/crates/rust-mcp-sdk/src/hyper_servers/middlewares/protect_dns_rebinding.rs deleted file mode 100644 index 5674e87..0000000 --- a/crates/rust-mcp-sdk/src/hyper_servers/middlewares/protect_dns_rebinding.rs +++ /dev/null @@ -1,66 +0,0 @@ -use crate::hyper_servers::app_state::AppState; -use crate::schema::schema_utils::SdkError; -use axum::{ - extract::{Request, State}, - middleware::Next, - response::IntoResponse, - Json, -}; -use hyper::{ - header::{HOST, ORIGIN}, - HeaderMap, StatusCode, -}; -use std::sync::Arc; - -// Middleware to protect against DNS rebinding attacks by validating Host and Origin headers. -pub async fn protect_dns_rebinding( - headers: HeaderMap, - State(state): State>, - request: Request, - next: Next, -) -> impl IntoResponse { - if !state.needs_dns_protection() { - // If protection is not needed, pass the request to the next handler - return next.run(request).await.into_response(); - } - - if let Some(allowed_hosts) = state.allowed_hosts.as_ref() { - if !allowed_hosts.is_empty() { - let Some(host) = headers.get(HOST).and_then(|h| h.to_str().ok()) else { - let error = SdkError::bad_request().with_message("Invalid Host header: [unknown] "); - return (StatusCode::FORBIDDEN, Json(error)).into_response(); - }; - - if !allowed_hosts - .iter() - .any(|allowed| allowed.eq_ignore_ascii_case(host)) - { - let error = SdkError::bad_request() - .with_message(format!("Invalid Host header: \"{host}\" ").as_str()); - return (StatusCode::FORBIDDEN, Json(error)).into_response(); - } - } - } - - if let Some(allowed_origins) = state.allowed_origins.as_ref() { - if !allowed_origins.is_empty() { - let Some(origin) = headers.get(ORIGIN).and_then(|h| h.to_str().ok()) else { - let error = - SdkError::bad_request().with_message("Invalid Origin header: [unknown] "); - return (StatusCode::FORBIDDEN, Json(error)).into_response(); - }; - - if !allowed_origins - .iter() - .any(|allowed| allowed.eq_ignore_ascii_case(origin)) - { - let error = SdkError::bad_request() - .with_message(format!("Invalid Origin header: \"{origin}\" ").as_str()); - return (StatusCode::FORBIDDEN, Json(error)).into_response(); - } - } - } - - // If all checks pass, proceed to the next handler in the chain - next.run(request).await -} diff --git a/crates/rust-mcp-sdk/src/hyper_servers/middlewares/session_id_gen.rs b/crates/rust-mcp-sdk/src/hyper_servers/middlewares/session_id_gen.rs deleted file mode 100644 index b68b325..0000000 --- a/crates/rust-mcp-sdk/src/hyper_servers/middlewares/session_id_gen.rs +++ /dev/null @@ -1,23 +0,0 @@ -use std::sync::Arc; - -use axum::{ - extract::{Request, State}, - middleware::Next, - response::Response, -}; -use hyper::StatusCode; -use rust_mcp_transport::SessionId; - -use crate::hyper_servers::app_state::AppState; - -// Middleware to generate and attach a session ID -pub async fn generate_session_id( - State(state): State>, - mut request: Request, - next: Next, -) -> Result { - let session_id: SessionId = state.id_generator.generate(); - request.extensions_mut().insert(session_id); - // Proceed to the next middleware or handler - Ok(next.run(request).await) -} diff --git a/crates/rust-mcp-sdk/src/hyper_servers/routes.rs b/crates/rust-mcp-sdk/src/hyper_servers/routes.rs index b1b15fc..4ae274b 100644 --- a/crates/rust-mcp-sdk/src/hyper_servers/routes.rs +++ b/crates/rust-mcp-sdk/src/hyper_servers/routes.rs @@ -1,10 +1,12 @@ pub mod fallback_routes; -mod hyper_utils; pub mod messages_routes; +#[cfg(feature = "sse")] pub mod sse_routes; pub mod streamable_http_routes; -use super::{app_state::AppState, HyperServerOptions}; +use crate::mcp_http::McpAppState; + +use super::HyperServerOptions; use axum::Router; use std::sync::Arc; @@ -19,21 +21,23 @@ use std::sync::Arc; /// /// # Returns /// * `Router` - An Axum router configured with all application routes and state -pub fn app_routes(state: Arc, server_options: &HyperServerOptions) -> Router { +pub fn app_routes(state: Arc, server_options: &HyperServerOptions) -> Router { let router: Router = Router::new() .merge(streamable_http_routes::routes( - state.clone(), server_options.streamable_http_endpoint(), )) .merge({ let mut r = Router::new(); + #[cfg(feature = "sse")] if server_options.sse_support { r = r .merge(sse_routes::routes( - state.clone(), server_options.sse_endpoint(), + server_options.sse_messages_endpoint(), + )) + .merge(messages_routes::routes( + server_options.sse_messages_endpoint(), )) - .merge(messages_routes::routes(state.clone())) } r }) diff --git a/crates/rust-mcp-sdk/src/hyper_servers/routes/hyper_utils.rs b/crates/rust-mcp-sdk/src/hyper_servers/routes/hyper_utils.rs deleted file mode 100644 index 7101a73..0000000 --- a/crates/rust-mcp-sdk/src/hyper_servers/routes/hyper_utils.rs +++ /dev/null @@ -1,502 +0,0 @@ -use crate::{ - error::SdkResult, - hyper_servers::{ - app_state::AppState, - error::{TransportServerError, TransportServerResult}, - }, - mcp_runtimes::server_runtime::DEFAULT_STREAM_ID, - mcp_server::{server_runtime, ServerRuntime}, - mcp_traits::{mcp_handler::McpServerHandler, IdGenerator}, - utils::validate_mcp_protocol_version, -}; - -use crate::schema::schema_utils::{ClientMessage, SdkError}; - -use axum::{http::HeaderValue, response::IntoResponse}; -use axum::{ - response::{ - sse::{Event, KeepAlive}, - Sse, - }, - Json, -}; -use futures::stream; -use hyper::{header, HeaderMap, StatusCode}; -use rust_mcp_transport::{ - EventId, McpDispatch, SessionId, SseTransport, StreamId, ID_SEPARATOR, - MCP_PROTOCOL_VERSION_HEADER, MCP_SESSION_ID_HEADER, -}; -use std::{sync::Arc, time::Duration}; -use tokio::io::{duplex, AsyncBufReadExt, BufReader}; - -const DUPLEX_BUFFER_SIZE: usize = 8192; - -async fn create_sse_stream( - runtime: Arc, - session_id: SessionId, - state: Arc, - payload: Option<&str>, - standalone: bool, - last_event_id: Option, -) -> TransportServerResult> { - let payload_string = payload.map(|p| p.to_string()); - - // TODO: this logic should be moved out after refactoing the mcp_stream.rs - let payload_contains_request = payload_string - .as_ref() - .map(|json_str| contains_request(json_str)) - .unwrap_or(Ok(false)); - let Ok(payload_contains_request) = payload_contains_request else { - return Ok((StatusCode::BAD_REQUEST, Json(SdkError::parse_error())).into_response()); - }; - - // readable stream of string to be used in transport - let (read_tx, read_rx) = duplex(DUPLEX_BUFFER_SIZE); - // writable stream to deliver message to the client - let (write_tx, write_rx) = duplex(DUPLEX_BUFFER_SIZE); - - let session_id = Arc::new(session_id); - let stream_id: Arc = if standalone { - Arc::new(DEFAULT_STREAM_ID.to_string()) - } else { - Arc::new(state.stream_id_gen.generate()) - }; - - let event_store = state.event_store.as_ref().map(Arc::clone); - let resumability_enabled = event_store.is_some(); - - let mut transport = SseTransport::::new( - read_rx, - write_tx, - read_tx, - Arc::clone(&state.transport_options), - ) - .map_err(|err| TransportServerError::TransportError(err.to_string()))?; - if let Some(event_store) = event_store.clone() { - transport.make_resumable((*session_id).clone(), (*stream_id).clone(), event_store); - } - let transport = Arc::new(transport); - - let ping_interval = state.ping_interval; - let runtime_clone = Arc::clone(&runtime); - let stream_id_clone = stream_id.clone(); - let transport_clone = transport.clone(); - - //Start the server runtime - tokio::spawn(async move { - match runtime_clone - .start_stream( - transport_clone, - &stream_id_clone, - ping_interval, - payload_string, - ) - .await - { - Ok(_) => tracing::trace!("stream {} exited gracefully.", &stream_id_clone), - Err(err) => tracing::info!("stream {} exited with error : {}", &stream_id_clone, err), - } - let _ = runtime.remove_transport(&stream_id_clone).await; - }); - - // Construct SSE stream - let reader = BufReader::new(write_rx); - - // send outgoing messages from server to the client over the sse stream - let message_stream = stream::unfold(reader, move |mut reader| { - async move { - let mut line = String::new(); - - match reader.read_line(&mut line).await { - Ok(0) => None, // EOF - Ok(_) => { - let trimmed_line = line.trim_end_matches('\n').to_owned(); - - // empty sse comment to keep-alive - if is_empty_sse_message(&trimmed_line) { - return Some((Ok(Event::default()), reader)); - } - - let (event_id, message) = match ( - resumability_enabled, - trimmed_line.split_once(char::from(ID_SEPARATOR)), - ) { - (true, Some((id, msg))) => (Some(id.to_string()), msg.to_string()), - _ => (None, trimmed_line), - }; - - let event = match event_id { - Some(id) => Event::default().data(message).id(id), - None => Event::default().data(message), - }; - - Some((Ok(event), reader)) - } - Err(e) => Some((Err(e), reader)), - } - } - }); - - let sse_stream = - Sse::new(message_stream).keep_alive(KeepAlive::new().interval(Duration::from_secs(10))); - - // Return SSE response with keep-alive - // Create a Response and set headers - let mut response = sse_stream.into_response(); - response.headers_mut().insert( - MCP_SESSION_ID_HEADER, - HeaderValue::from_str(&session_id).unwrap(), - ); - - // if last_event_id exists we replay messages from the event-store - tokio::spawn(async move { - if let Some(last_event_id) = last_event_id { - if let Some(event_store) = state.event_store.as_ref() { - if let Some(events) = event_store.events_after(last_event_id).await { - for message_payload in events.messages { - // skip storing replay messages - let error = transport.write_str(&message_payload, true).await; - if let Err(error) = error { - tracing::trace!("Error replaying message: {error}") - } - } - } - } - } - }); - - if !payload_contains_request { - *response.status_mut() = StatusCode::ACCEPTED; - } - Ok(response) -} - -// TODO: this function will be removed after refactoring the readable stream of the transports -// so we would deserialize the string syncronousely and have more control over the flow -// this function may incur a slight runtime cost which could be avoided after refactoring -fn contains_request(json_str: &str) -> Result { - let value: serde_json::Value = serde_json::from_str(json_str)?; - match value { - serde_json::Value::Object(obj) => Ok(obj.contains_key("id") && obj.contains_key("method")), - serde_json::Value::Array(arr) => Ok(arr.iter().any(|item| { - item.as_object() - .map(|obj| obj.contains_key("id") && obj.contains_key("method")) - .unwrap_or(false) - })), - _ => Ok(false), - } -} - -fn is_result(json_str: &str) -> Result { - let value: serde_json::Value = serde_json::from_str(json_str)?; - match value { - serde_json::Value::Object(obj) => Ok(obj.contains_key("result")), - serde_json::Value::Array(arr) => Ok(arr.iter().all(|item| { - item.as_object() - .map(|obj| obj.contains_key("result")) - .unwrap_or(false) - })), - _ => Ok(false), - } -} - -pub async fn create_standalone_stream( - session_id: SessionId, - last_event_id: Option, - state: Arc, -) -> TransportServerResult> { - let runtime = state.session_store.get(&session_id).await.ok_or( - TransportServerError::SessionIdInvalid(session_id.to_string()), - )?; - let runtime = runtime.lock().await.to_owned(); - - if runtime.stream_id_exists(DEFAULT_STREAM_ID).await { - let error = - SdkError::bad_request().with_message("Only one SSE stream is allowed per session"); - return Ok((StatusCode::CONFLICT, Json(error)).into_response()); - } - - if let Some(last_event_id) = last_event_id.as_ref() { - tracing::trace!( - "SSE stream re-connected with last-event-id: {}", - last_event_id - ); - } - - let mut response = create_sse_stream( - runtime.clone(), - session_id.clone(), - state.clone(), - None, - true, - last_event_id, - ) - .await?; - *response.status_mut() = StatusCode::OK; - Ok(response) -} - -pub async fn start_new_session( - state: Arc, - payload: &str, -) -> TransportServerResult> { - let session_id: SessionId = state.id_generator.generate(); - - let h: Arc = state.handler.clone(); - // create a new server instance with unique session_id and - let runtime: Arc = server_runtime::create_server_instance( - Arc::clone(&state.server_details), - h, - session_id.to_owned(), - ); - - tracing::info!("a new client joined : {}", &session_id); - - let response = create_sse_stream( - runtime.clone(), - session_id.clone(), - state.clone(), - Some(payload), - false, - None, - ) - .await; - - if response.is_ok() { - state - .session_store - .set(session_id.to_owned(), runtime.clone()) - .await; - } - response -} - -async fn single_shot_stream( - runtime: Arc, - session_id: SessionId, - state: Arc, - payload: Option<&str>, - standalone: bool, -) -> TransportServerResult> { - // readable stream of string to be used in transport - let (read_tx, read_rx) = duplex(DUPLEX_BUFFER_SIZE); - // writable stream to deliver message to the client - let (write_tx, write_rx) = duplex(DUPLEX_BUFFER_SIZE); - - let transport = SseTransport::::new( - read_rx, - write_tx, - read_tx, - Arc::clone(&state.transport_options), - ) - .map_err(|err| TransportServerError::TransportError(err.to_string()))?; - - let stream_id = if standalone { - DEFAULT_STREAM_ID.to_string() - } else { - state.id_generator.generate() - }; - let ping_interval = state.ping_interval; - let runtime_clone = Arc::clone(&runtime); - - let payload_string = payload.map(|p| p.to_string()); - - tokio::spawn(async move { - match runtime_clone - .start_stream( - Arc::new(transport), - &stream_id, - ping_interval, - payload_string, - ) - .await - { - Ok(_) => tracing::info!("stream {} exited gracefully.", &stream_id), - Err(err) => tracing::info!("stream {} exited with error : {}", &stream_id, err), - } - let _ = runtime.remove_transport(&stream_id).await; - }); - - let mut reader = BufReader::new(write_rx); - let mut line = String::new(); - let response = match reader.read_line(&mut line).await { - Ok(0) => None, // EOF - Ok(_) => { - let trimmed_line = line.trim_end_matches('\n').to_owned(); - Some(Ok(trimmed_line)) - } - Err(e) => Some(Err(e)), - }; - - let mut headers = HeaderMap::new(); - headers.insert( - header::CONTENT_TYPE, - HeaderValue::from_static("application/json"), - ); - headers.insert( - MCP_SESSION_ID_HEADER, - HeaderValue::from_str(&session_id).unwrap(), - ); - - match response { - Some(response_result) => match response_result { - Ok(response_str) => { - Ok((StatusCode::OK, headers, response_str.to_string()).into_response()) - } - Err(err) => Ok(( - StatusCode::INTERNAL_SERVER_ERROR, - headers, - Json(err.to_string()), - ) - .into_response()), - }, - None => Ok(( - StatusCode::UNPROCESSABLE_ENTITY, - headers, - Json("End of the transport stream reached."), - ) - .into_response()), - } -} - -pub async fn process_incoming_message_return( - session_id: SessionId, - state: Arc, - payload: &str, -) -> TransportServerResult { - match state.session_store.get(&session_id).await { - Some(runtime) => { - let runtime = runtime.lock().await.to_owned(); - - single_shot_stream( - runtime.clone(), - session_id.clone(), - state.clone(), - Some(payload), - false, - ) - .await - // Ok(StatusCode::OK.into_response()) - } - None => { - let error = SdkError::session_not_found(); - Ok((StatusCode::NOT_FOUND, Json(error)).into_response()) - } - } -} - -pub async fn process_incoming_message( - session_id: SessionId, - state: Arc, - payload: &str, -) -> TransportServerResult { - match state.session_store.get(&session_id).await { - Some(runtime) => { - let runtime = runtime.lock().await.to_owned(); - // when receiving a result in a streamable_http server, that means it was sent by the standalone sse transport - // it should be processed by the same transport , therefore no need to call create_sse_stream - let Ok(is_result) = is_result(payload) else { - return Ok((StatusCode::BAD_REQUEST, Json(SdkError::parse_error())).into_response()); - }; - - if is_result { - match runtime - .consume_payload_string(DEFAULT_STREAM_ID, payload) - .await - { - Ok(()) => Ok((StatusCode::ACCEPTED, Json(())).into_response()), - Err(err) => Ok(( - StatusCode::BAD_REQUEST, - Json(SdkError::internal_error().with_message(err.to_string().as_ref())), - ) - .into_response()), - } - } else { - create_sse_stream( - runtime.clone(), - session_id.clone(), - state.clone(), - Some(payload), - false, - None, - ) - .await - } - } - None => { - let error = SdkError::session_not_found(); - Ok((StatusCode::NOT_FOUND, Json(error)).into_response()) - } - } -} - -pub fn is_empty_sse_message(sse_payload: &str) -> bool { - sse_payload.is_empty() || sse_payload.trim() == ":" -} - -pub async fn delete_session( - session_id: SessionId, - state: Arc, -) -> TransportServerResult { - match state.session_store.get(&session_id).await { - Some(runtime) => { - let runtime = runtime.lock().await.to_owned(); - runtime.shutdown().await; - state.session_store.delete(&session_id).await; - tracing::info!("client disconnected : {}", &session_id); - Ok((StatusCode::OK, Json("ok")).into_response()) - } - None => { - let error = SdkError::session_not_found(); - Ok((StatusCode::NOT_FOUND, Json(error)).into_response()) - } - } -} - -pub fn acceptable_content_type(headers: &HeaderMap) -> bool { - let accept_header = headers - .get("content-type") - .and_then(|val| val.to_str().ok()) - .unwrap_or(""); - accept_header - .split(',') - .any(|val| val.trim().starts_with("application/json")) -} - -pub fn validate_mcp_protocol_version_header(headers: &HeaderMap) -> SdkResult<()> { - let protocol_version_header = headers - .get(MCP_PROTOCOL_VERSION_HEADER) - .and_then(|val| val.to_str().ok()) - .unwrap_or(""); - - // requests without protocol version header are acceptable - if protocol_version_header.is_empty() { - return Ok(()); - } - - validate_mcp_protocol_version(protocol_version_header) -} - -pub fn accepts_event_stream(headers: &HeaderMap) -> bool { - let accept_header = headers - .get("accept") - .and_then(|val| val.to_str().ok()) - .unwrap_or(""); - - accept_header - .split(',') - .any(|val| val.trim().starts_with("text/event-stream")) -} - -pub fn valid_streaming_http_accept_header(headers: &HeaderMap) -> bool { - let accept_header = headers - .get("accept") - .and_then(|val| val.to_str().ok()) - .unwrap_or(""); - - let types: Vec<_> = accept_header.split(',').map(|v| v.trim()).collect(); - - let has_event_stream = types.iter().any(|v| v.starts_with("text/event-stream")); - let has_json = types.iter().any(|v| v.starts_with("application/json")); - has_event_stream && has_json -} diff --git a/crates/rust-mcp-sdk/src/hyper_servers/routes/messages_routes.rs b/crates/rust-mcp-sdk/src/hyper_servers/routes/messages_routes.rs index 44b671f..65490a3 100644 --- a/crates/rust-mcp-sdk/src/hyper_servers/routes/messages_routes.rs +++ b/crates/rust-mcp-sdk/src/hyper_servers/routes/messages_routes.rs @@ -1,54 +1,28 @@ use crate::{ - hyper_servers::{ - app_state::AppState, - error::{TransportServerError, TransportServerResult}, - }, - mcp_runtimes::server_runtime::DEFAULT_STREAM_ID, + hyper_servers::error::TransportServerResult, + mcp_http::{McpAppState, McpHttpHandler}, utils::remove_query_and_hash, }; -use axum::{ - extract::{Query, State}, - response::IntoResponse, - routing::post, - Router, -}; -use std::{collections::HashMap, sync::Arc}; +use axum::{extract::State, response::IntoResponse, routing::post, Router}; +use http::{HeaderMap, Method, Uri}; +use std::sync::Arc; -pub fn routes(state: Arc) -> Router> { +pub fn routes(sse_message_endpoint: &str) -> Router> { Router::new().route( - remove_query_and_hash(&state.sse_message_endpoint).as_str(), + remove_query_and_hash(sse_message_endpoint).as_str(), post(handle_messages), ) } pub async fn handle_messages( - State(state): State>, - Query(params): Query>, + uri: Uri, + headers: HeaderMap, + State(state): State>, message: String, ) -> TransportServerResult { - let session_id = params - .get("sessionId") - .ok_or(TransportServerError::SessionIdMissing)?; - - // transmit to the readable stream, that transport is reading from - let transmit = - state - .session_store - .get(session_id) - .await - .ok_or(TransportServerError::SessionIdInvalid( - session_id.to_string(), - ))?; - - let transmit = transmit.lock().await; - - transmit - .consume_payload_string(DEFAULT_STREAM_ID, &message) - .await - .map_err(|err| { - tracing::trace!("{}", err); - TransportServerError::StreamIoError(err.to_string()) - })?; - - Ok(axum::http::StatusCode::ACCEPTED) + let request = McpHttpHandler::create_request(Method::POST, uri, headers, Some(&message)); + let generic_response = McpHttpHandler::handle_sse_message(request, state.clone()).await?; + let (parts, body) = generic_response.into_parts(); + let resp = axum::response::Response::from_parts(parts, axum::body::Body::new(body)); + Ok(resp) } diff --git a/crates/rust-mcp-sdk/src/hyper_servers/routes/sse_routes.rs b/crates/rust-mcp-sdk/src/hyper_servers/routes/sse_routes.rs index 27a16b2..e13c724 100644 --- a/crates/rust-mcp-sdk/src/hyper_servers/routes/sse_routes.rs +++ b/crates/rust-mcp-sdk/src/hyper_servers/routes/sse_routes.rs @@ -1,47 +1,10 @@ -use crate::mcp_server::error::TransportServerError; -use crate::schema::schema_utils::ClientMessage; -use crate::{ - hyper_servers::{ - app_state::AppState, - error::TransportServerResult, - middlewares::{ - protect_dns_rebinding::protect_dns_rebinding, session_id_gen::generate_session_id, - }, - }, - mcp_runtimes::server_runtime::DEFAULT_STREAM_ID, - mcp_server::{server_runtime, ServerRuntime}, - mcp_traits::mcp_handler::McpServerHandler, -}; -use axum::{ - extract::State, - middleware, - response::{ - sse::{Event, KeepAlive}, - IntoResponse, Sse, - }, - routing::get, - Extension, Router, -}; -use futures::stream::{self}; -use rust_mcp_transport::{SessionId, SseTransport}; -use std::{convert::Infallible, sync::Arc, time::Duration}; -use tokio::io::{duplex, AsyncBufReadExt, BufReader}; -use tokio_stream::StreamExt; +use crate::hyper_servers::error::TransportServerResult; +use crate::mcp_http::{McpAppState, McpHttpHandler}; +use axum::{extract::State, response::IntoResponse, routing::get, Extension, Router}; +use std::sync::Arc; -const DUPLEX_BUFFER_SIZE: usize = 8192; - -/// Creates an initial SSE event that returns the messages endpoint -/// -/// Constructs an SSE event containing the messages endpoint URL with the session ID. -/// -/// # Arguments -/// * `session_id` - The session identifier for the client -/// -/// # Returns -/// * `Result` - The constructed SSE event, infallible -fn initial_event(endpoint: &str) -> Result { - Ok(Event::default().event("endpoint").data(endpoint)) -} +#[derive(Clone)] +pub struct SseMessageEndpoint(pub String); /// Configures the SSE routes for the application /// @@ -52,18 +15,13 @@ fn initial_event(endpoint: &str) -> Result { /// * `sse_endpoint` - The path for the SSE endpoint /// /// # Returns -/// * `Router>` - An Axum router configured with the SSE route -pub fn routes(state: Arc, sse_endpoint: &str) -> Router> { - Router::new() - .route(sse_endpoint, get(handle_sse)) - .route_layer(middleware::from_fn_with_state( - state.clone(), - generate_session_id, - )) - .route_layer(middleware::from_fn_with_state( - state.clone(), - protect_dns_rebinding, - )) +/// * `Router>` - An Axum router configured with the SSE route +pub fn routes(sse_endpoint: &str, sse_message_endpoint: &str) -> Router> { + let sse_message_endpoint = SseMessageEndpoint(sse_message_endpoint.to_string()); + Router::new().route( + sse_endpoint, + get(handle_sse).layer(Extension(sse_message_endpoint)), + ) } /// Handles Server-Sent Events (SSE) connections @@ -77,91 +35,13 @@ pub fn routes(state: Arc, sse_endpoint: &str) -> Router> /// # Returns /// * `TransportServerResult` - The SSE response stream or an error pub async fn handle_sse( - Extension(session_id): Extension, - State(state): State>, + Extension(sse_message_endpoint): Extension, + State(state): State>, ) -> TransportServerResult { - let messages_endpoint = - SseTransport::::message_endpoint(&state.sse_message_endpoint, &session_id); - - // readable stream of string to be used in transport - // writing string to read_tx will be received as messages inside the transport and messages will be processed - let (read_tx, read_rx) = duplex(DUPLEX_BUFFER_SIZE); - - // writable stream to deliver message to the client - let (write_tx, write_rx) = duplex(DUPLEX_BUFFER_SIZE); - - // create a transport for sending/receiving messages - let Ok(transport) = SseTransport::new( - read_rx, - write_tx, - read_tx, - Arc::clone(&state.transport_options), - ) else { - return Err(TransportServerError::TransportError( - "Failed to create SSE transport".to_string(), - )); - }; - - let h: Arc = state.handler.clone(); - // create a new server instance with unique session_id and - let server: Arc = server_runtime::create_server_instance( - Arc::clone(&state.server_details), - h, - session_id.to_owned(), - ); - - state - .session_store - .set(session_id.to_owned(), server.clone()) - .await; - - tracing::info!("A new client joined : {}", session_id.to_owned()); - - // Start the server - tokio::spawn(async move { - match server - .start_stream( - Arc::new(transport), - DEFAULT_STREAM_ID, - state.ping_interval, - None, - ) - .await - { - Ok(_) => tracing::info!("server {} exited gracefully.", session_id.to_owned()), - Err(err) => tracing::info!( - "server {} exited with error : {}", - session_id.to_owned(), - err - ), - }; - - state.session_store.delete(&session_id).await; - }); - - // Initial SSE message to inform the client about the server's endpoint - let initial_event = stream::once(async move { initial_event(&messages_endpoint) }); - - // Construct SSE stream - let reader = BufReader::new(write_rx); - - let message_stream = stream::unfold(reader, |mut reader| async move { - let mut line = String::new(); - - match reader.read_line(&mut line).await { - Ok(0) => None, // EOF - Ok(_) => { - let trimmed_line = line.trim_end_matches('\n').to_owned(); - Some((Ok(Event::default().data(trimmed_line)), reader)) - } - Err(_) => None, // Err(e) => Some((Err(e), reader)), - } - }); - - let stream = initial_event.chain(message_stream); - let sse_stream = - Sse::new(stream).keep_alive(KeepAlive::new().interval(Duration::from_secs(10))); - - // Return SSE response with keep-alive - Ok(sse_stream) + let SseMessageEndpoint(sse_message_endpoint) = sse_message_endpoint; + let generic_response = + McpHttpHandler::handle_sse_connection(state.clone(), Some(&sse_message_endpoint)).await?; + let (parts, body) = generic_response.into_parts(); + let resp = axum::response::Response::from_parts(parts, axum::body::Body::new(body)); + Ok(resp) } diff --git a/crates/rust-mcp-sdk/src/hyper_servers/routes/streamable_http_routes.rs b/crates/rust-mcp-sdk/src/hyper_servers/routes/streamable_http_routes.rs index 67f8679..6f2e470 100644 --- a/crates/rust-mcp-sdk/src/hyper_servers/routes/streamable_http_routes.rs +++ b/crates/rust-mcp-sdk/src/hyper_servers/routes/streamable_http_routes.rs @@ -1,32 +1,16 @@ -use super::hyper_utils::start_new_session; -use crate::schema::schema_utils::SdkError; -use crate::{ - error::McpSdkError, - hyper_servers::{ - app_state::AppState, - error::TransportServerResult, - middlewares::protect_dns_rebinding::protect_dns_rebinding, - routes::hyper_utils::{ - acceptable_content_type, accepts_event_stream, create_standalone_stream, - delete_session, process_incoming_message, process_incoming_message_return, - valid_streaming_http_accept_header, validate_mcp_protocol_version_header, - }, - }, - utils::valid_initialize_method, -}; +use crate::hyper_servers::error::TransportServerResult; +use crate::mcp_http::{McpAppState, McpHttpHandler}; use axum::routing::get; use axum::{ extract::{Query, State}, - middleware, response::IntoResponse, routing::{delete, post}, - Json, Router, + Router, }; -use hyper::{HeaderMap, StatusCode}; -use rust_mcp_transport::{SessionId, MCP_LAST_EVENT_ID_HEADER, MCP_SESSION_ID_HEADER}; +use http::{HeaderMap, Method, Uri}; use std::{collections::HashMap, sync::Arc}; -pub fn routes(state: Arc, streamable_http_endpoint: &str) -> Router> { +pub fn routes(streamable_http_endpoint: &str) -> Router> { Router::new() .route(streamable_http_endpoint, get(handle_streamable_http_get)) .route(streamable_http_endpoint, post(handle_streamable_http_post)) @@ -34,129 +18,43 @@ pub fn routes(state: Arc, streamable_http_endpoint: &str) -> Router>, + uri: Uri, + State(state): State>, ) -> TransportServerResult { - if !accepts_event_stream(&headers) { - let error = SdkError::bad_request().with_message(r#"Client must accept text/event-stream"#); - return Ok((StatusCode::NOT_ACCEPTABLE, Json(error)).into_response()); - } - - if let Err(parse_error) = validate_mcp_protocol_version_header(&headers) { - let error = - SdkError::bad_request().with_message(format!(r#"Bad Request: {parse_error}"#).as_str()); - return Ok((StatusCode::BAD_REQUEST, Json(error)).into_response()); - } - - let session_id: Option = headers - .get(MCP_SESSION_ID_HEADER) - .and_then(|value| value.to_str().ok()) - .map(|s| s.to_string()); - - let last_event_id: Option = headers - .get(MCP_LAST_EVENT_ID_HEADER) - .and_then(|value| value.to_str().ok()) - .map(|s| s.to_string()); - - match session_id { - Some(session_id) => { - let res = create_standalone_stream(session_id, last_event_id, state).await?; - Ok(res.into_response()) - } - None => { - let error = SdkError::bad_request().with_message("Bad request: session not found"); - Ok((StatusCode::BAD_REQUEST, Json(error)).into_response()) - } - } + let request = McpHttpHandler::create_request(Method::GET, uri, headers, None); + let generic_res = McpHttpHandler::handle_streamable_http(request, state).await?; + let (parts, body) = generic_res.into_parts(); + let resp = axum::response::Response::from_parts(parts, axum::body::Body::new(body)); + Ok(resp) } pub async fn handle_streamable_http_post( headers: HeaderMap, - State(state): State>, + uri: Uri, + State(state): State>, Query(_params): Query>, payload: String, ) -> TransportServerResult { - if !valid_streaming_http_accept_header(&headers) { - let error = SdkError::bad_request() - .with_message(r#"Client must accept both application/json and text/event-stream"#); - return Ok((StatusCode::NOT_ACCEPTABLE, Json(error)).into_response()); - } - - if !acceptable_content_type(&headers) { - let error = SdkError::bad_request() - .with_message(r#"Unsupported Media Type: Content-Type must be application/json"#); - return Ok((StatusCode::UNSUPPORTED_MEDIA_TYPE, Json(error)).into_response()); - } - - if let Err(parse_error) = validate_mcp_protocol_version_header(&headers) { - let error = - SdkError::bad_request().with_message(format!(r#"Bad Request: {parse_error}"#).as_str()); - return Ok((StatusCode::BAD_REQUEST, Json(error)).into_response()); - } - - let session_id: Option = headers - .get(MCP_SESSION_ID_HEADER) - .and_then(|value| value.to_str().ok()) - .map(|s| s.to_string()); - - //TODO: validate reconnect after disconnect - - match session_id { - // has session-id => write to the existing stream - Some(id) => { - if state.enable_json_response { - let res = process_incoming_message_return(id, state, &payload).await?; - Ok(res.into_response()) - } else { - let res = process_incoming_message(id, state, &payload).await?; - Ok(res.into_response()) - } - } - None => match valid_initialize_method(&payload) { - Ok(_) => { - return start_new_session(state, &payload).await; - } - Err(McpSdkError::SdkError(error)) => { - Ok((StatusCode::BAD_REQUEST, Json(error)).into_response()) - } - Err(error) => { - let error = SdkError::bad_request().with_message(&error.to_string()); - Ok((StatusCode::BAD_REQUEST, Json(error)).into_response()) - } - }, - } + let request = + McpHttpHandler::create_request(Method::POST, uri, headers, Some(payload.as_str())); + let generic_res = McpHttpHandler::handle_streamable_http(request, state).await?; + let (parts, body) = generic_res.into_parts(); + let resp = axum::response::Response::from_parts(parts, axum::body::Body::new(body)); + Ok(resp) } pub async fn handle_streamable_http_delete( headers: HeaderMap, - State(state): State>, + uri: Uri, + State(state): State>, ) -> TransportServerResult { - if let Err(parse_error) = validate_mcp_protocol_version_header(&headers) { - let error = - SdkError::bad_request().with_message(format!(r#"Bad Request: {parse_error}"#).as_str()); - return Ok((StatusCode::BAD_REQUEST, Json(error)).into_response()); - } - - let session_id: Option = headers - .get(MCP_SESSION_ID_HEADER) - .and_then(|value| value.to_str().ok()) - .map(|s| s.to_string()); - - match session_id { - Some(id) => { - let res = delete_session(id, state).await; - Ok(res.into_response()) - } - None => { - let error = SdkError::bad_request().with_message("Bad Request: Session not found"); - Ok((StatusCode::BAD_REQUEST, Json(error)).into_response()) - } - } + let request = McpHttpHandler::create_request(Method::DELETE, uri, headers, None); + let generic_res = McpHttpHandler::handle_streamable_http(request, state).await?; + let (parts, body) = generic_res.into_parts(); + let resp = axum::response::Response::from_parts(parts, axum::body::Body::new(body)); + Ok(resp) } diff --git a/crates/rust-mcp-sdk/src/hyper_servers/server.rs b/crates/rust-mcp-sdk/src/hyper_servers/server.rs index 71bccee..881d4b3 100644 --- a/crates/rust-mcp-sdk/src/hyper_servers/server.rs +++ b/crates/rust-mcp-sdk/src/hyper_servers/server.rs @@ -1,6 +1,12 @@ use crate::{ error::SdkResult, id_generator::{FastIdGenerator, UuidGenerator}, + mcp_http::{ + utils::{ + DEFAULT_MESSAGES_ENDPOINT, DEFAULT_SSE_ENDPOINT, DEFAULT_STREAMABLE_HTTP_ENDPOINT, + }, + InMemorySessionStore, McpAppState, + }, mcp_server::hyper_runtime::HyperRuntime, mcp_traits::{mcp_handler::McpServerHandler, IdGenerator}, }; @@ -16,10 +22,8 @@ use std::{ use tokio::signal; use super::{ - app_state::AppState, error::{TransportServerError, TransportServerResult}, routes::app_routes, - InMemorySessionStore, }; use crate::schema::InitializeResult; use axum::Router; @@ -28,12 +32,6 @@ use rust_mcp_transport::{event_store::EventStore, SessionId, TransportOptions}; // Default client ping interval (12 seconds) const DEFAULT_CLIENT_PING_INTERVAL: Duration = Duration::from_secs(12); const GRACEFUL_SHUTDOWN_TMEOUT_SECS: u64 = 5; -// Default Server-Sent Events (SSE) endpoint path -const DEFAULT_SSE_ENDPOINT: &str = "/sse"; -// Default MCP Messages endpoint path -const DEFAULT_MESSAGES_ENDPOINT: &str = "/messages"; -// Default Streamable HTTP endpoint path -const DEFAULT_STREAMABLE_HTTP_ENDPOINT: &str = "/mcp"; /// Configuration struct for the Hyper server /// Used to configure the HyperServer instance. @@ -237,7 +235,7 @@ impl Default for HyperServerOptions { /// Hyper server struct for managing the Axum-based web server pub struct HyperServer { app: Router, - state: Arc, + state: Arc, pub(crate) options: HyperServerOptions, handle: Handle, } @@ -259,7 +257,7 @@ impl HyperServer { handler: Arc, mut server_options: HyperServerOptions, ) -> Self { - let state: Arc = Arc::new(AppState { + let state: Arc = Arc::new(McpAppState { session_store: Arc::new(InMemorySessionStore::new()), id_generator: server_options .session_id_generator @@ -269,8 +267,6 @@ impl HyperServer { server_details: Arc::new(server_details), handler, ping_interval: server_options.ping_interval, - sse_message_endpoint: server_options.sse_messages_endpoint().to_owned(), - http_streamable_endpoint: server_options.streamable_http_endpoint().to_owned(), transport_options: Arc::clone(&server_options.transport_options), enable_json_response: server_options.enable_json_response.unwrap_or(false), allowed_hosts: server_options.allowed_hosts.take(), @@ -290,8 +286,8 @@ impl HyperServer { /// Returns a shared reference to the application state /// /// # Returns - /// * `Arc` - Shared application state - pub fn state(&self) -> Arc { + /// * `Arc` - Shared application state + pub fn state(&self) -> Arc { Arc::clone(&self.state) } @@ -451,7 +447,7 @@ impl HyperServer { } // Shutdown signal handler -async fn shutdown_signal(handle: Handle, state: Arc) { +async fn shutdown_signal(handle: Handle, state: Arc) { // Wait for a Ctrl+C or SIGTERM signal let ctrl_c = async { signal::ctrl_c() diff --git a/crates/rust-mcp-sdk/src/lib.rs b/crates/rust-mcp-sdk/src/lib.rs index a33f889..2f88673 100644 --- a/crates/rust-mcp-sdk/src/lib.rs +++ b/crates/rust-mcp-sdk/src/lib.rs @@ -2,6 +2,8 @@ pub mod error; #[cfg(feature = "hyper-server")] mod hyper_servers; mod mcp_handlers; +#[cfg(feature = "hyper-server")] +pub(crate) mod mcp_http; mod mcp_macros; mod mcp_runtimes; mod mcp_traits; diff --git a/crates/rust-mcp-sdk/src/mcp_http.rs b/crates/rust-mcp-sdk/src/mcp_http.rs new file mode 100644 index 0000000..59cfedc --- /dev/null +++ b/crates/rust-mcp-sdk/src/mcp_http.rs @@ -0,0 +1,12 @@ +mod app_state; +mod mcp_http_handler; +pub(crate) mod mcp_http_utils; +mod session_store; + +pub(crate) use app_state::*; +pub use mcp_http_handler::*; +pub use session_store::*; + +pub(crate) mod utils { + pub use super::mcp_http_utils::*; +} diff --git a/crates/rust-mcp-sdk/src/hyper_servers/app_state.rs b/crates/rust-mcp-sdk/src/mcp_http/app_state.rs similarity index 91% rename from crates/rust-mcp-sdk/src/hyper_servers/app_state.rs rename to crates/rust-mcp-sdk/src/mcp_http/app_state.rs index f96b261..95ae297 100644 --- a/crates/rust-mcp-sdk/src/hyper_servers/app_state.rs +++ b/crates/rust-mcp-sdk/src/mcp_http/app_state.rs @@ -8,20 +8,18 @@ use rust_mcp_transport::event_store::EventStore; use rust_mcp_transport::{SessionId, TransportOptions}; -/// Application state struct for the Hyper server +/// Application state struct for the Hyper ser /// /// Holds shared, thread-safe references to session storage, ID generator, /// server details, handler, ping interval, and transport options. #[derive(Clone)] -pub struct AppState { +pub struct McpAppState { pub session_store: Arc, pub id_generator: Arc>, pub stream_id_gen: Arc, pub server_details: Arc, pub handler: Arc, pub ping_interval: Duration, - pub sse_message_endpoint: String, - pub http_streamable_endpoint: String, pub transport_options: Arc, pub enable_json_response: bool, /// List of allowed host header values for DNS rebinding protection. @@ -38,7 +36,7 @@ pub struct AppState { pub event_store: Option>, } -impl AppState { +impl McpAppState { pub fn needs_dns_protection(&self) -> bool { self.dns_rebinding_protection && (self.allowed_hosts.is_some() || self.allowed_origins.is_some()) diff --git a/crates/rust-mcp-sdk/src/mcp_http/mcp_http_handler.rs b/crates/rust-mcp-sdk/src/mcp_http/mcp_http_handler.rs new file mode 100644 index 0000000..fb830ae --- /dev/null +++ b/crates/rust-mcp-sdk/src/mcp_http/mcp_http_handler.rs @@ -0,0 +1,306 @@ +#[cfg(feature = "sse")] +use super::utils::handle_sse_connection; +use crate::mcp_http::utils::{ + accepts_event_stream, error_response, query_param, validate_mcp_protocol_version_header, +}; +use crate::mcp_runtimes::server_runtime::DEFAULT_STREAM_ID; +use crate::mcp_server::error::TransportServerError; +use crate::schema::schema_utils::SdkError; +use crate::{ + error::McpSdkError, + mcp_http::{ + utils::{ + acceptable_content_type, create_standalone_stream, delete_session, + process_incoming_message, process_incoming_message_return, protect_dns_rebinding, + start_new_session, valid_streaming_http_accept_header, GenericBody, + }, + McpAppState, + }, + mcp_server::error::TransportServerResult, + utils::valid_initialize_method, +}; +use bytes::Bytes; +use http::{self, HeaderMap, Method, StatusCode, Uri}; +use http_body_util::{BodyExt, Full}; +use rust_mcp_transport::{SessionId, MCP_LAST_EVENT_ID_HEADER, MCP_SESSION_ID_HEADER}; +use std::sync::Arc; + +pub struct McpHttpHandler {} + +impl McpHttpHandler { + /// Creates a new HTTP request with the given method, URI, headers, and optional body. + /// + /// # Arguments + /// + /// * `method` - The HTTP method to use (e.g., GET, POST). + /// * `uri` - The target URI for the request. + /// * `headers` - A map of optional header keys and their corresponding values. + /// * `body` - An optional string slice representing the request body. + /// + /// # Returns + /// + /// An `http::Request<&str>` initialized with the specified method, URI, headers, and body. + /// If the `body` is `None`, an empty string is used as the default. + /// + pub fn create_request( + method: Method, + uri: Uri, + headers: HeaderMap, + body: Option<&str>, + ) -> http::Request<&str> { + let mut request = http::Request::default(); + *request.method_mut() = method; + *request.uri_mut() = uri; + *request.body_mut() = body.unwrap_or_default(); + let req_headers = request.headers_mut(); + for (key, value) in headers { + if let Some(k) = key { + req_headers.insert(k, value); + } + } + request + } +} + +impl McpHttpHandler { + /// Handles an MCP connection using the SSE (Server-Sent Events) transport. + /// + /// This function serves as the entry point for initializing and managing a client connection + /// over SSE when the `sse` feature is enabled. + /// + /// # Arguments + /// * `state` - Shared application state required to manage the MCP session. + /// * `sse_message_endpoint` - Optional message endpoint to override the default SSE route (default: `/messages` ). + /// + /// + /// # Features + /// This function is only available when the `sse` feature is enabled. + #[cfg(feature = "sse")] + pub async fn handle_sse_connection( + state: Arc, + sse_message_endpoint: Option<&str>, + ) -> TransportServerResult> { + handle_sse_connection(state, sse_message_endpoint).await + } + + /// Handles incoming MCP messages from the client after an SSE connection is established. + /// + /// This function processes a message sent by the client as part of an active SSE session. It: + /// - Extracts the `sessionId` from the request query parameters. + /// - Locates the corresponding session's transmit channel. + /// - Forwards the incoming message payload to the MCP transport stream for consumption. + /// # Arguments + /// * `request` - The HTTP request containing the message body and query parameters (including `sessionId`). + /// * `state` - Shared application state, including access to the session store. + /// + /// # Returns + /// * `TransportServerResult>`: + /// - Returns a `202 Accepted` HTTP response if the message is successfully forwarded. + /// - Returns an error if the session ID is missing, invalid, or if any I/O issues occur while processing the message. + /// + /// # Errors + /// - `SessionIdMissing`: if the `sessionId` query parameter is not present. + /// - `SessionIdInvalid`: if the session ID does not map to a valid session in the session store. + /// - `StreamIoError`: if an error occurs while writing to the stream. + /// - `HttpError`: if constructing the HTTP response fails. + pub async fn handle_sse_message( + request: http::Request<&str>, + state: Arc, + ) -> TransportServerResult> { + let session_id = + query_param(&request, "sessionId").ok_or(TransportServerError::SessionIdMissing)?; + + // transmit to the readable stream, that transport is reading from + let transmit = state.session_store.get(&session_id).await.ok_or( + TransportServerError::SessionIdInvalid(session_id.to_string()), + )?; + + let transmit = transmit.lock().await; + let message = *request.body(); + transmit + .consume_payload_string(DEFAULT_STREAM_ID, message) + .await + .map_err(|err| { + tracing::trace!("{}", err); + TransportServerError::StreamIoError(err.to_string()) + })?; + + let body = Full::new(Bytes::new()) + .map_err(|err| TransportServerError::HttpError(err.to_string())) + .boxed(); + + http::Response::builder() + .status(StatusCode::ACCEPTED) + .body(body) + .map_err(|err| TransportServerError::HttpError(err.to_string())) + } + + /// Handles incoming MCP messages over the StreamableHTTP transport. + /// + /// It supports `GET`, `POST`, and `DELETE` methods for handling streaming operations, and performs optional + /// DNS rebinding protection if it is configured. + /// + /// # Arguments + /// * `request` - The HTTP request from the client, including method, headers, and optional body. + /// * `state` - Shared application state, including configuration and session management. + /// + /// # Behavior + /// - If DNS rebinding protection is enabled via the app state, the function checks the request headers. + /// If dns protection fails, a `403 Forbidden` response is returned. + /// - Dispatches the request to method-specific handlers based on the HTTP method: + /// - `GET` → `handle_http_get` + /// - `POST` → `handle_http_post` + /// - `DELETE` → `handle_http_delete` + /// - Returns `405 Method Not Allowed` for unsupported methods. + /// + /// # Returns + /// * A `TransportServerResult` wrapping an HTTP response indicating success or failure of the operation. + /// + pub async fn handle_streamable_http( + request: http::Request<&str>, + state: Arc, + ) -> TransportServerResult> { + // Enforces DNS rebinding protection if required by state. + // If protection fails, respond with HTTP 403 Forbidden. + if state.needs_dns_protection() { + if let Err(error) = protect_dns_rebinding(request.headers(), state.clone()).await { + return error_response(StatusCode::FORBIDDEN, error); + } + } + + let method = request.method(); + match method { + &http::Method::GET => return Self::handle_http_get(request, state).await, + &http::Method::POST => return Self::handle_http_post(request, state).await, + &http::Method::DELETE => return Self::handle_http_delete(request, state).await, + other => { + let error = SdkError::bad_request().with_message(&format!( + "'{other}' is not a valid HTTP method for StreamableHTTP transport." + )); + error_response(StatusCode::METHOD_NOT_ALLOWED, error) + } + } + } + + /// Processes POST requests for the Streamable HTTP Protocol + async fn handle_http_post( + request: http::Request<&str>, + state: Arc, + ) -> TransportServerResult> { + let headers = request.headers(); + + if !valid_streaming_http_accept_header(headers) { + let error = SdkError::bad_request() + .with_message(r#"Client must accept both application/json and text/event-stream"#); + return error_response(StatusCode::NOT_ACCEPTABLE, error); + } + + if !acceptable_content_type(headers) { + let error = SdkError::bad_request() + .with_message(r#"Unsupported Media Type: Content-Type must be application/json"#); + return error_response(StatusCode::UNSUPPORTED_MEDIA_TYPE, error); + } + + if let Err(parse_error) = validate_mcp_protocol_version_header(headers) { + let error = SdkError::bad_request() + .with_message(format!(r#"Bad Request: {parse_error}"#).as_str()); + return error_response(StatusCode::BAD_REQUEST, error); + } + + let session_id: Option = headers + .get(MCP_SESSION_ID_HEADER) + .and_then(|value| value.to_str().ok()) + .map(|s| s.to_string()); + + let payload = *request.body(); + + match session_id { + // has session-id => write to the existing stream + Some(id) => { + if state.enable_json_response { + process_incoming_message_return(id, state, payload).await + } else { + process_incoming_message(id, state, payload).await + } + } + None => match valid_initialize_method(payload) { + Ok(_) => { + return start_new_session(state, payload).await; + } + Err(McpSdkError::SdkError(error)) => error_response(StatusCode::BAD_REQUEST, error), + Err(error) => { + let error = SdkError::bad_request().with_message(&error.to_string()); + error_response(StatusCode::BAD_REQUEST, error) + } + }, + } + } + + /// Processes GET requests for the Streamable HTTP Protocol + async fn handle_http_get( + request: http::Request<&str>, + state: Arc, + ) -> TransportServerResult> { + let headers = request.headers(); + + if !accepts_event_stream(headers) { + let error = + SdkError::bad_request().with_message(r#"Client must accept text/event-stream"#); + return error_response(StatusCode::NOT_ACCEPTABLE, error); + } + + if let Err(parse_error) = validate_mcp_protocol_version_header(headers) { + let error = SdkError::bad_request() + .with_message(format!(r#"Bad Request: {parse_error}"#).as_str()); + return error_response(StatusCode::BAD_REQUEST, error); + } + + let session_id: Option = headers + .get(MCP_SESSION_ID_HEADER) + .and_then(|value| value.to_str().ok()) + .map(|s| s.to_string()); + + let last_event_id: Option = headers + .get(MCP_LAST_EVENT_ID_HEADER) + .and_then(|value| value.to_str().ok()) + .map(|s| s.to_string()); + + match session_id { + Some(session_id) => { + let res = create_standalone_stream(session_id, last_event_id, state).await; + res + } + None => { + let error = SdkError::bad_request().with_message("Bad request: session not found"); + error_response(StatusCode::BAD_REQUEST, error) + } + } + } + + /// Processes DELETE requests for the Streamable HTTP Protocol + async fn handle_http_delete( + request: http::Request<&str>, + state: Arc, + ) -> TransportServerResult> { + let headers = request.headers(); + + if let Err(parse_error) = validate_mcp_protocol_version_header(headers) { + let error = SdkError::bad_request() + .with_message(format!(r#"Bad Request: {parse_error}"#).as_str()); + return error_response(StatusCode::BAD_REQUEST, error); + } + + let session_id: Option = headers + .get(MCP_SESSION_ID_HEADER) + .and_then(|value| value.to_str().ok()) + .map(|s| s.to_string()); + + match session_id { + Some(id) => delete_session(id, state).await, + None => { + let error = SdkError::bad_request().with_message("Bad Request: Session not found"); + error_response(StatusCode::BAD_REQUEST, error) + } + } + } +} diff --git a/crates/rust-mcp-sdk/src/mcp_http/mcp_http_utils.rs b/crates/rust-mcp-sdk/src/mcp_http/mcp_http_utils.rs new file mode 100644 index 0000000..06443e9 --- /dev/null +++ b/crates/rust-mcp-sdk/src/mcp_http/mcp_http_utils.rs @@ -0,0 +1,759 @@ +use crate::schema::schema_utils::{ClientMessage, SdkError}; +use crate::{ + error::SdkResult, + hyper_servers::error::{TransportServerError, TransportServerResult}, + mcp_http::McpAppState, + mcp_runtimes::server_runtime::DEFAULT_STREAM_ID, + mcp_server::{server_runtime, ServerRuntime}, + mcp_traits::{mcp_handler::McpServerHandler, IdGenerator}, + utils::validate_mcp_protocol_version, +}; +use axum::http::HeaderValue; +use bytes::Bytes; +use futures::stream; +use http::header::{ACCEPT, CONNECTION, CONTENT_TYPE, HOST, ORIGIN}; +use http_body::Frame; +use http_body_util::StreamBody; +use http_body_util::{combinators::BoxBody, BodyExt, Full}; +use hyper::{HeaderMap, StatusCode}; +use rust_mcp_transport::{ + EventId, McpDispatch, SessionId, SseEvent, SseTransport, StreamId, ID_SEPARATOR, + MCP_PROTOCOL_VERSION_HEADER, MCP_SESSION_ID_HEADER, +}; +use std::sync::Arc; +use tokio::io::{duplex, AsyncBufReadExt, BufReader}; +use tokio_stream::StreamExt; + +// Default Server-Sent Events (SSE) endpoint path +pub(crate) const DEFAULT_SSE_ENDPOINT: &str = "/sse"; +// Default MCP Messages endpoint path +pub(crate) const DEFAULT_MESSAGES_ENDPOINT: &str = "/messages"; +// Default Streamable HTTP endpoint path +pub(crate) const DEFAULT_STREAMABLE_HTTP_ENDPOINT: &str = "/mcp"; +const DUPLEX_BUFFER_SIZE: usize = 8192; + +pub type GenericBody = BoxBody; + +/// Creates an initial SSE event that returns the messages endpoint +/// +/// Constructs an SSE event containing the messages endpoint URL with the session ID. +/// +/// # Arguments +/// * `session_id` - The session identifier for the client +/// +/// # Returns +/// * `Result` - The constructed SSE event, infallible +fn initial_sse_event(endpoint: &str) -> Result { + Ok(SseEvent::default() + .with_event("endpoint") + .with_data(endpoint.to_string()) + .as_bytes()) +} + +async fn create_sse_stream( + runtime: Arc, + session_id: SessionId, + state: Arc, + payload: Option<&str>, + standalone: bool, + last_event_id: Option, +) -> TransportServerResult> { + let payload_string = payload.map(|p| p.to_string()); + + // TODO: this logic should be moved out after refactoing the mcp_stream.rs + let payload_contains_request = payload_string + .as_ref() + .map(|json_str| contains_request(json_str)) + .unwrap_or(Ok(false)); + let Ok(payload_contains_request) = payload_contains_request else { + return error_response(StatusCode::BAD_REQUEST, SdkError::parse_error()); + }; + + // readable stream of string to be used in transport + let (read_tx, read_rx) = duplex(DUPLEX_BUFFER_SIZE); + // writable stream to deliver message to the client + let (write_tx, write_rx) = duplex(DUPLEX_BUFFER_SIZE); + + let session_id = Arc::new(session_id); + let stream_id: Arc = if standalone { + Arc::new(DEFAULT_STREAM_ID.to_string()) + } else { + Arc::new(state.stream_id_gen.generate()) + }; + + let event_store = state.event_store.as_ref().map(Arc::clone); + let resumability_enabled = event_store.is_some(); + + let mut transport = SseTransport::::new( + read_rx, + write_tx, + read_tx, + Arc::clone(&state.transport_options), + ) + .map_err(|err| TransportServerError::TransportError(err.to_string()))?; + if let Some(event_store) = event_store.clone() { + transport.make_resumable((*session_id).clone(), (*stream_id).clone(), event_store); + } + let transport = Arc::new(transport); + + let ping_interval = state.ping_interval; + let runtime_clone = Arc::clone(&runtime); + let stream_id_clone = stream_id.clone(); + let transport_clone = transport.clone(); + + //Start the server runtime + tokio::spawn(async move { + match runtime_clone + .start_stream( + transport_clone, + &stream_id_clone, + ping_interval, + payload_string, + ) + .await + { + Ok(_) => tracing::trace!("stream {} exited gracefully.", &stream_id_clone), + Err(err) => tracing::info!("stream {} exited with error : {}", &stream_id_clone, err), + } + let _ = runtime.remove_transport(&stream_id_clone).await; + }); + + // Construct SSE stream + let reader = BufReader::new(write_rx); + + // send outgoing messages from server to the client over the sse stream + let message_stream = stream::unfold(reader, move |mut reader| { + async move { + let mut line = String::new(); + + match reader.read_line(&mut line).await { + Ok(0) => None, // EOF + Ok(_) => { + let trimmed_line = line.trim_end_matches('\n').to_owned(); + + // empty sse comment to keep-alive + if is_empty_sse_message(&trimmed_line) { + return Some((Ok(SseEvent::default().as_bytes()), reader)); + } + + let (event_id, message) = match ( + resumability_enabled, + trimmed_line.split_once(char::from(ID_SEPARATOR)), + ) { + (true, Some((id, msg))) => (Some(id.to_string()), msg.to_string()), + _ => (None, trimmed_line), + }; + + let event = match event_id { + Some(id) => SseEvent::default() + .with_data(message) + .with_id(id) + .as_bytes(), + None => SseEvent::default().with_data(message).as_bytes(), + }; + + Some((Ok(event), reader)) + } + Err(e) => Some((Err(e), reader)), + } + } + }); + + // create a stream body + let streaming_body: GenericBody = + http_body_util::BodyExt::boxed(StreamBody::new(message_stream.map(|res| { + res.map(Frame::data) + .map_err(|err: std::io::Error| TransportServerError::HttpError(err.to_string())) + }))); + + let session_id_value = HeaderValue::from_str(&session_id) + .map_err(|err| TransportServerError::HttpError(err.to_string()))?; + + let status_code = if !payload_contains_request { + StatusCode::ACCEPTED + } else { + StatusCode::OK + }; + + let response = http::Response::builder() + .status(status_code) + .header(CONTENT_TYPE, "text/event-stream") + .header(MCP_SESSION_ID_HEADER, session_id_value) + .header(CONNECTION, "keep-alive") + .body(streaming_body) + .map_err(|err| TransportServerError::HttpError(err.to_string()))?; + + // if last_event_id exists we replay messages from the event-store + tokio::spawn(async move { + if let Some(last_event_id) = last_event_id { + if let Some(event_store) = state.event_store.as_ref() { + if let Some(events) = event_store.events_after(last_event_id).await { + for message_payload in events.messages { + // skip storing replay messages + let error = transport.write_str(&message_payload, true).await; + if let Err(error) = error { + tracing::trace!("Error replaying message: {error}") + } + } + } + } + } + }); + + Ok(response) +} + +// TODO: this function will be removed after refactoring the readable stream of the transports +// so we would deserialize the string syncronousely and have more control over the flow +// this function may incur a slight runtime cost which could be avoided after refactoring +fn contains_request(json_str: &str) -> Result { + let value: serde_json::Value = serde_json::from_str(json_str)?; + match value { + serde_json::Value::Object(obj) => Ok(obj.contains_key("id") && obj.contains_key("method")), + serde_json::Value::Array(arr) => Ok(arr.iter().any(|item| { + item.as_object() + .map(|obj| obj.contains_key("id") && obj.contains_key("method")) + .unwrap_or(false) + })), + _ => Ok(false), + } +} + +fn is_result(json_str: &str) -> Result { + let value: serde_json::Value = serde_json::from_str(json_str)?; + match value { + serde_json::Value::Object(obj) => Ok(obj.contains_key("result")), + serde_json::Value::Array(arr) => Ok(arr.iter().all(|item| { + item.as_object() + .map(|obj| obj.contains_key("result")) + .unwrap_or(false) + })), + _ => Ok(false), + } +} + +pub async fn create_standalone_stream( + session_id: SessionId, + last_event_id: Option, + state: Arc, +) -> TransportServerResult> { + let runtime = state.session_store.get(&session_id).await.ok_or( + TransportServerError::SessionIdInvalid(session_id.to_string()), + )?; + let runtime = runtime.lock().await.to_owned(); + + if runtime.stream_id_exists(DEFAULT_STREAM_ID).await { + let error = + SdkError::bad_request().with_message("Only one SSE stream is allowed per session"); + return error_response(StatusCode::CONFLICT, error) + .map_err(|err| TransportServerError::HttpError(err.to_string())); + } + + if let Some(last_event_id) = last_event_id.as_ref() { + tracing::trace!( + "SSE stream re-connected with last-event-id: {}", + last_event_id + ); + } + + let mut response = create_sse_stream( + runtime.clone(), + session_id.clone(), + state.clone(), + None, + true, + last_event_id, + ) + .await?; + *response.status_mut() = StatusCode::OK; + Ok(response) +} + +pub async fn start_new_session( + state: Arc, + payload: &str, +) -> TransportServerResult> { + let session_id: SessionId = state.id_generator.generate(); + + let h: Arc = state.handler.clone(); + // create a new server instance with unique session_id and + let runtime: Arc = server_runtime::create_server_instance( + Arc::clone(&state.server_details), + h, + session_id.to_owned(), + ); + + tracing::info!("a new client joined : {}", &session_id); + + let response = create_sse_stream( + runtime.clone(), + session_id.clone(), + state.clone(), + Some(payload), + false, + None, + ) + .await; + + if response.is_ok() { + state + .session_store + .set(session_id.to_owned(), runtime.clone()) + .await; + } + response +} +async fn single_shot_stream( + runtime: Arc, + session_id: SessionId, + state: Arc, + payload: Option<&str>, + standalone: bool, +) -> TransportServerResult> { + // readable stream of string to be used in transport + let (read_tx, read_rx) = duplex(DUPLEX_BUFFER_SIZE); + // writable stream to deliver message to the client + let (write_tx, write_rx) = duplex(DUPLEX_BUFFER_SIZE); + + let transport = SseTransport::::new( + read_rx, + write_tx, + read_tx, + Arc::clone(&state.transport_options), + ) + .map_err(|err| TransportServerError::TransportError(err.to_string()))?; + + let stream_id = if standalone { + DEFAULT_STREAM_ID.to_string() + } else { + state.id_generator.generate() + }; + let ping_interval = state.ping_interval; + let runtime_clone = Arc::clone(&runtime); + + let payload_string = payload.map(|p| p.to_string()); + + tokio::spawn(async move { + match runtime_clone + .start_stream( + Arc::new(transport), + &stream_id, + ping_interval, + payload_string, + ) + .await + { + Ok(_) => tracing::info!("stream {} exited gracefully.", &stream_id), + Err(err) => tracing::info!("stream {} exited with error : {}", &stream_id, err), + } + let _ = runtime.remove_transport(&stream_id).await; + }); + + let mut reader = BufReader::new(write_rx); + let mut line = String::new(); + let response = match reader.read_line(&mut line).await { + Ok(0) => None, // EOF + Ok(_) => { + let trimmed_line = line.trim_end_matches('\n').to_owned(); + Some(Ok(trimmed_line)) + } + Err(e) => Some(Err(e)), + }; + + let session_id_value = HeaderValue::from_str(&session_id) + .map_err(|err| TransportServerError::HttpError(err.to_string()))?; + + match response { + Some(response_result) => match response_result { + Ok(response_str) => { + let body = Full::new(Bytes::from(response_str)) + .map_err(|err| TransportServerError::HttpError(err.to_string())) + .boxed(); + + http::Response::builder() + .status(StatusCode::OK) + .header(CONTENT_TYPE, "application/json") + .header(MCP_SESSION_ID_HEADER, session_id_value) + .body(body) + .map_err(|err| TransportServerError::HttpError(err.to_string())) + } + Err(err) => { + let body = Full::new(Bytes::from(err.to_string())) + .map_err(|err| TransportServerError::HttpError(err.to_string())) + .boxed(); + http::Response::builder() + .status(StatusCode::INTERNAL_SERVER_ERROR) + .header(CONTENT_TYPE, "application/json") + .body(body) + .map_err(|err| TransportServerError::HttpError(err.to_string())) + } + }, + None => { + let body = Full::new(Bytes::from( + "End of the transport stream reached.".to_string(), + )) + .map_err(|err| TransportServerError::HttpError(err.to_string())) + .boxed(); + http::Response::builder() + .status(StatusCode::UNPROCESSABLE_ENTITY) + .header(CONTENT_TYPE, "application/json") + .body(body) + .map_err(|err| TransportServerError::HttpError(err.to_string())) + } + } +} + +pub async fn process_incoming_message_return( + session_id: SessionId, + state: Arc, + payload: &str, +) -> TransportServerResult> { + match state.session_store.get(&session_id).await { + Some(runtime) => { + let runtime = runtime.lock().await.to_owned(); + + single_shot_stream( + runtime.clone(), + session_id, + state.clone(), + Some(payload), + false, + ) + .await + // Ok(StatusCode::OK.into_response()) + } + None => { + let error = SdkError::session_not_found(); + error_response(StatusCode::NOT_FOUND, error) + .map_err(|err| TransportServerError::HttpError(err.to_string())) + } + } +} + +pub async fn process_incoming_message( + session_id: SessionId, + state: Arc, + payload: &str, +) -> TransportServerResult> { + match state.session_store.get(&session_id).await { + Some(runtime) => { + let runtime = runtime.lock().await.to_owned(); + // when receiving a result in a streamable_http server, that means it was sent by the standalone sse transport + // it should be processed by the same transport , therefore no need to call create_sse_stream + let Ok(is_result) = is_result(payload) else { + return error_response(StatusCode::BAD_REQUEST, SdkError::parse_error()); + }; + + if is_result { + match runtime + .consume_payload_string(DEFAULT_STREAM_ID, payload) + .await + { + Ok(()) => { + let body = Full::new(Bytes::new()) + .map_err(|err| TransportServerError::HttpError(err.to_string())) + .boxed(); + http::Response::builder() + .status(200) + .header("Content-Type", "application/json") + .body(body) + .map_err(|err| TransportServerError::HttpError(err.to_string())) + } + Err(err) => { + let error = + SdkError::internal_error().with_message(err.to_string().as_ref()); + error_response(StatusCode::BAD_REQUEST, error) + } + } + } else { + create_sse_stream( + runtime.clone(), + session_id.clone(), + state.clone(), + Some(payload), + false, + None, + ) + .await + } + } + None => { + let error = SdkError::session_not_found(); + error_response(StatusCode::NOT_FOUND, error) + } + } +} + +pub fn is_empty_sse_message(sse_payload: &str) -> bool { + sse_payload.is_empty() || sse_payload.trim() == ":" +} + +pub async fn delete_session( + session_id: SessionId, + state: Arc, +) -> TransportServerResult> { + match state.session_store.get(&session_id).await { + Some(runtime) => { + let runtime = runtime.lock().await.to_owned(); + runtime.shutdown().await; + state.session_store.delete(&session_id).await; + tracing::info!("client disconnected : {}", &session_id); + + let body = Full::new(Bytes::from("ok")) + .map_err(|err| TransportServerError::HttpError(err.to_string())) + .boxed(); + http::Response::builder() + .status(200) + .header("Content-Type", "application/json") + .body(body) + .map_err(|err| TransportServerError::HttpError(err.to_string())) + } + None => { + let error = SdkError::session_not_found(); + error_response(StatusCode::NOT_FOUND, error) + } + } +} + +pub fn acceptable_content_type(headers: &HeaderMap) -> bool { + let accept_header = headers + .get("content-type") + .and_then(|val| val.to_str().ok()) + .unwrap_or(""); + accept_header + .split(',') + .any(|val| val.trim().starts_with("application/json")) +} + +pub fn validate_mcp_protocol_version_header(headers: &HeaderMap) -> SdkResult<()> { + let protocol_version_header = headers + .get(MCP_PROTOCOL_VERSION_HEADER) + .and_then(|val| val.to_str().ok()) + .unwrap_or(""); + + // requests without protocol version header are acceptable + if protocol_version_header.is_empty() { + return Ok(()); + } + + validate_mcp_protocol_version(protocol_version_header) +} + +pub fn accepts_event_stream(headers: &HeaderMap) -> bool { + let accept_header = headers + .get(ACCEPT) + .and_then(|val| val.to_str().ok()) + .unwrap_or(""); + + accept_header + .split(',') + .any(|val| val.trim().starts_with("text/event-stream")) +} + +pub fn valid_streaming_http_accept_header(headers: &HeaderMap) -> bool { + let accept_header = headers + .get(ACCEPT) + .and_then(|val| val.to_str().ok()) + .unwrap_or(""); + + let types: Vec<_> = accept_header.split(',').map(|v| v.trim()).collect(); + + let has_event_stream = types.iter().any(|v| v.starts_with("text/event-stream")); + let has_json = types.iter().any(|v| v.starts_with("application/json")); + has_event_stream && has_json +} + +pub fn error_response( + status_code: StatusCode, + error: SdkError, +) -> TransportServerResult> { + let error_string = serde_json::to_string(&error).unwrap_or_default(); + let body = Full::new(Bytes::from(error_string)) + .map_err(|err| TransportServerError::HttpError(err.to_string())) + .boxed(); + + http::Response::builder() + .status(status_code) + .header(CONTENT_TYPE, "application/json") + .body(body) + .map_err(|err| TransportServerError::HttpError(err.to_string())) +} + +// Protect against DNS rebinding attacks by validating Host and Origin headers. +pub(crate) async fn protect_dns_rebinding( + headers: &http::HeaderMap, + state: Arc, +) -> Result<(), SdkError> { + if !state.needs_dns_protection() { + // If protection is not needed, pass the request to the next handler + return Ok(()); + } + + if let Some(allowed_hosts) = state.allowed_hosts.as_ref() { + if !allowed_hosts.is_empty() { + let Some(host) = headers.get(HOST).and_then(|h| h.to_str().ok()) else { + return Err(SdkError::bad_request().with_message("Invalid Host header: [unknown] ")); + }; + + if !allowed_hosts + .iter() + .any(|allowed| allowed.eq_ignore_ascii_case(host)) + { + return Err(SdkError::bad_request() + .with_message(format!("Invalid Host header: \"{host}\" ").as_str())); + } + } + } + + if let Some(allowed_origins) = state.allowed_origins.as_ref() { + if !allowed_origins.is_empty() { + let Some(origin) = headers.get(ORIGIN).and_then(|h| h.to_str().ok()) else { + return Err( + SdkError::bad_request().with_message("Invalid Origin header: [unknown] ") + ); + }; + + if !allowed_origins + .iter() + .any(|allowed| allowed.eq_ignore_ascii_case(origin)) + { + return Err(SdkError::bad_request() + .with_message(format!("Invalid Origin header: \"{origin}\" ").as_str())); + } + } + } + + Ok(()) +} + +/// Extracts the value of a query parameter from an HTTP request by key. +/// +/// This function parses the query string from the request URI and searches +/// for the specified key. If found, it returns the corresponding value as a `String`. +/// +/// # Arguments +/// * `request` - The HTTP request containing the URI with the query string. +/// * `key` - The name of the query parameter to retrieve. +/// +/// # Returns +/// * `Some(String)` containing the value of the query parameter if found. +/// * `None` if the query string is missing or the key is not present. +/// +pub fn query_param(request: &http::Request<&str>, key: &str) -> Option { + request.uri().query().and_then(|query| { + for pair in query.split('&') { + let mut split = pair.splitn(2, '='); + let k = split.next()?; + let v = split.next().unwrap_or(""); + if k == key { + return Some(v.to_string()); + } + } + None + }) +} + +#[cfg(feature = "sse")] +pub(crate) async fn handle_sse_connection( + state: Arc, + sse_message_endpoint: Option<&str>, +) -> TransportServerResult> { + let session_id: SessionId = state.id_generator.generate(); + + let sse_message_endpoint = sse_message_endpoint.unwrap_or(DEFAULT_MESSAGES_ENDPOINT); + let messages_endpoint = + SseTransport::::message_endpoint(sse_message_endpoint, &session_id); + + // readable stream of string to be used in transport + // writing string to read_tx will be received as messages inside the transport and messages will be processed + let (read_tx, read_rx) = duplex(DUPLEX_BUFFER_SIZE); + + // writable stream to deliver message to the client + let (write_tx, write_rx) = duplex(DUPLEX_BUFFER_SIZE); + + // / create a transport for sending/receiving messages + let Ok(transport) = SseTransport::new( + read_rx, + write_tx, + read_tx, + Arc::clone(&state.transport_options), + ) else { + return Err(TransportServerError::TransportError( + "Failed to create SSE transport".to_string(), + )); + }; + + let h: Arc = state.handler.clone(); + // create a new server instance with unique session_id and + let server: Arc = server_runtime::create_server_instance( + Arc::clone(&state.server_details), + h, + session_id.to_owned(), + ); + + state + .session_store + .set(session_id.to_owned(), server.clone()) + .await; + + tracing::info!("A new client joined : {}", session_id.to_owned()); + + // Start the server + tokio::spawn(async move { + match server + .start_stream( + Arc::new(transport), + DEFAULT_STREAM_ID, + state.ping_interval, + None, + ) + .await + { + Ok(_) => tracing::info!("server {} exited gracefully.", session_id.to_owned()), + Err(err) => tracing::info!( + "server {} exited with error : {}", + session_id.to_owned(), + err + ), + }; + + state.session_store.delete(&session_id).await; + }); + + // Initial SSE message to inform the client about the server's endpoint + let initial_sse_event = stream::once(async move { initial_sse_event(&messages_endpoint) }); + + // Construct SSE stream + let reader = BufReader::new(write_rx); + + let message_stream = stream::unfold(reader, |mut reader| async move { + let mut line = String::new(); + + match reader.read_line(&mut line).await { + Ok(0) => None, // EOF + Ok(_) => { + let trimmed_line = line.trim_end_matches('\n').to_owned(); + Some(( + Ok(SseEvent::default().with_data(trimmed_line).as_bytes()), + reader, + )) + } + Err(_) => None, // Err(e) => Some((Err(e), reader)), + } + }); + + let stream = initial_sse_event.chain(message_stream); + + // create a stream body + let streaming_body: GenericBody = + http_body_util::BodyExt::boxed(StreamBody::new(stream.map(|res| res.map(Frame::data)))); + + let response = http::Response::builder() + .status(StatusCode::OK) + .header(CONTENT_TYPE, "text/event-stream") + .header(CONNECTION, "keep-alive") + .body(streaming_body) + .map_err(|err| TransportServerError::HttpError(err.to_string()))?; + + Ok(response) +} diff --git a/crates/rust-mcp-sdk/src/hyper_servers/session_store.rs b/crates/rust-mcp-sdk/src/mcp_http/session_store.rs similarity index 100% rename from crates/rust-mcp-sdk/src/hyper_servers/session_store.rs rename to crates/rust-mcp-sdk/src/mcp_http/session_store.rs diff --git a/crates/rust-mcp-sdk/src/hyper_servers/session_store/in_memory.rs b/crates/rust-mcp-sdk/src/mcp_http/session_store/in_memory.rs similarity index 100% rename from crates/rust-mcp-sdk/src/hyper_servers/session_store/in_memory.rs rename to crates/rust-mcp-sdk/src/mcp_http/session_store/in_memory.rs diff --git a/crates/rust-mcp-sdk/src/utils.rs b/crates/rust-mcp-sdk/src/utils.rs index 16fe7c7..2d80f1e 100644 --- a/crates/rust-mcp-sdk/src/utils.rs +++ b/crates/rust-mcp-sdk/src/utils.rs @@ -1,6 +1,5 @@ -use crate::schema::schema_utils::{ClientMessages, SdkError}; - use crate::error::{McpSdkError, ProtocolErrorKind, SdkResult}; +use crate::schema::schema_utils::{ClientMessages, SdkError}; use crate::schema::ProtocolVersion; use std::cmp::Ordering; diff --git a/crates/rust-mcp-transport/src/lib.rs b/crates/rust-mcp-transport/src/lib.rs index d21e5dd..7566290 100644 --- a/crates/rust-mcp-transport/src/lib.rs +++ b/crates/rust-mcp-transport/src/lib.rs @@ -31,6 +31,9 @@ pub use sse::*; pub use stdio::*; pub use transport::*; +#[cfg(any(feature = "sse", feature = "streamable-http"))] +pub use utils::SseEvent; + // Type alias for session identifier, represented as a String pub type SessionId = String; // Type alias for stream identifier (that will be used at the transport scope), represented as a String diff --git a/crates/rust-mcp-transport/src/utils.rs b/crates/rust-mcp-transport/src/utils.rs index 034f062..36977a2 100644 --- a/crates/rust-mcp-transport/src/utils.rs +++ b/crates/rust-mcp-transport/src/utils.rs @@ -1,40 +1,57 @@ mod cancellation_token; + #[cfg(any(feature = "sse", feature = "streamable-http"))] mod http_utils; + #[cfg(any(feature = "sse", feature = "streamable-http"))] mod readable_channel; + +#[cfg(any(feature = "sse", feature = "streamable-http"))] +mod sse_event; + #[cfg(any(feature = "sse", feature = "streamable-http"))] mod sse_parser; + #[cfg(feature = "sse")] mod sse_stream; + #[cfg(feature = "streamable-http")] mod streamable_http_stream; + +mod time_utils; + #[cfg(any(feature = "sse", feature = "streamable-http"))] mod writable_channel; +use crate::error::{TransportError, TransportResult}; +use crate::schema::schema_utils::SdkError; pub(crate) use cancellation_token::*; + +#[cfg(any(feature = "sse", feature = "streamable-http"))] +use crate::SessionId; + #[cfg(any(feature = "sse", feature = "streamable-http"))] pub(crate) use http_utils::*; + #[cfg(any(feature = "sse", feature = "streamable-http"))] pub(crate) use readable_channel::*; + +#[cfg(any(feature = "sse", feature = "streamable-http"))] +pub use sse_event::*; + #[cfg(any(feature = "sse", feature = "streamable-http"))] pub(crate) use sse_parser::*; + #[cfg(feature = "sse")] pub(crate) use sse_stream::*; + #[cfg(feature = "streamable-http")] pub(crate) use streamable_http_stream::*; -#[cfg(any(feature = "sse", feature = "streamable-http"))] -pub(crate) use writable_channel::*; -mod time_utils; -pub use time_utils::*; -use crate::schema::schema_utils::SdkError; +pub use time_utils::*; use tokio::time::{timeout, Duration}; - -use crate::error::{TransportError, TransportResult}; - #[cfg(any(feature = "sse", feature = "streamable-http"))] -use crate::SessionId; +pub(crate) use writable_channel::*; pub async fn await_timeout(operation: F, timeout_duration: Duration) -> TransportResult where diff --git a/crates/rust-mcp-transport/src/utils/sse_event.rs b/crates/rust-mcp-transport/src/utils/sse_event.rs new file mode 100644 index 0000000..5837807 --- /dev/null +++ b/crates/rust-mcp-transport/src/utils/sse_event.rs @@ -0,0 +1,122 @@ +use bytes::Bytes; +use core::fmt; + +/// Represents a single Server-Sent Event (SSE) as defined in the SSE protocol. +/// +/// Contains the event type, data payload, and optional event ID. +#[derive(Clone, Default)] +pub struct SseEvent { + /// The optional event type (e.g., "message"). + pub event: Option, + /// The optional data payload of the event, stored as bytes. + pub data: Option, + /// The optional event ID for reconnection or tracking purposes. + pub id: Option, + /// Optional reconnection retry interval (in milliseconds). + pub retry: Option, +} + +impl SseEvent { + /// Creates a new `SseEvent` with the given string data. + pub fn new>(data: T) -> Self { + Self { + event: None, + data: Some(Bytes::from(data.into())), + id: None, + retry: None, + } + } + + /// Sets the event name (e.g., "message"). + pub fn with_event>(mut self, event: T) -> Self { + self.event = Some(event.into()); + self + } + + /// Sets the ID of the event. + pub fn with_id>(mut self, id: T) -> Self { + self.id = Some(id.into()); + self + } + + /// Sets the retry interval (in milliseconds). + pub fn with_retry(mut self, retry: u64) -> Self { + self.retry = Some(retry); + self + } + + /// Sets the data as bytes. + pub fn with_data_bytes(mut self, data: Bytes) -> Self { + self.data = Some(data); + self + } + + /// Sets the data. + pub fn with_data(mut self, data: String) -> Self { + self.data = Some(Bytes::from(data)); + self + } + + /// Converts the event into a string in SSE format (ready for HTTP body). + pub fn to_sse_string(&self) -> String { + self.to_string() + } + + pub fn as_bytes(&self) -> Bytes { + Bytes::from(self.to_string()) + } +} + +impl std::fmt::Display for SseEvent { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + // Emit retry interval + if let Some(retry) = self.retry { + writeln!(f, "retry: {retry}")?; + } + + // Emit ID + if let Some(id) = &self.id { + writeln!(f, "id: {id}")?; + } + + // Emit event type + if let Some(event) = &self.event { + writeln!(f, "event: {event}")?; + } + + // Emit data lines + if let Some(data) = &self.data { + match std::str::from_utf8(data) { + Ok(text) => { + for line in text.lines() { + writeln!(f, "data: {line}")?; + } + } + Err(_) => { + writeln!(f, "data: [binary data]")?; + } + } + } + + writeln!(f)?; // Trailing newline for SSE message end, separates events + Ok(()) + } +} + +impl fmt::Debug for SseEvent { + /// Formats the `SseEvent` for debugging, converting the `data` field to a UTF-8 string + /// (with lossy conversion if invalid UTF-8 is encountered). + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let data_str = self + .data + .as_ref() + .map(|b| String::from_utf8_lossy(b).to_string()); + + f.debug_struct("SseEvent") + .field("event", &self.event) + .field("data", &data_str) + .field("id", &self.id) + .field("retry", &self.retry) + .finish() + } +} diff --git a/crates/rust-mcp-transport/src/utils/sse_parser.rs b/crates/rust-mcp-transport/src/utils/sse_parser.rs index 5933726..3074e9f 100644 --- a/crates/rust-mcp-transport/src/utils/sse_parser.rs +++ b/crates/rust-mcp-transport/src/utils/sse_parser.rs @@ -1,66 +1,9 @@ -use core::fmt; +use bytes::{Bytes, BytesMut}; use std::collections::HashMap; -use bytes::{Bytes, BytesMut}; +use super::SseEvent; const BUFFER_CAPACITY: usize = 1024; -/// Represents a single Server-Sent Event (SSE) as defined in the SSE protocol. -/// -/// Contains the event type, data payload, and optional event ID. -pub struct SseEvent { - /// The optional event type (e.g., "message"). - pub event: Option, - /// The optional data payload of the event, stored as bytes. - pub data: Option, - /// The optional event ID for reconnection or tracking purposes. - pub id: Option, -} - -impl std::fmt::Display for SseEvent { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - if let Some(id) = &self.id { - writeln!(f, "id: {id}")?; - } - - if let Some(event) = &self.event { - writeln!(f, "event: {event}")?; - } - - if let Some(data) = &self.data { - match std::str::from_utf8(data) { - Ok(text) => { - for line in text.lines() { - writeln!(f, "data: {line}")?; - } - } - Err(_) => { - writeln!(f, "data: [binary data]")?; - } - } - } - - writeln!(f)?; // Trailing newline for SSE message end - Ok(()) - } -} - -impl fmt::Debug for SseEvent { - /// Formats the `SseEvent` for debugging, converting the `data` field to a UTF-8 string - /// (with lossy conversion if invalid UTF-8 is encountered). - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let data_str = self - .data - .as_ref() - .map(|b| String::from_utf8_lossy(b).to_string()); - - f.debug_struct("SseEvent") - .field("event", &self.event) - .field("data", &data_str) - .field("id", &self.id) - .finish() - } -} - /// A parser for Server-Sent Events (SSE) that processes incoming byte chunks into `SseEvent`s. /// This Parser is specifically designed for MCP messages and with no multi-line data support /// @@ -193,11 +136,15 @@ impl SseParser { // Get event (default to None) let event = fields.get("event").cloned(); let id = fields.get("id").cloned(); + let retry = fields + .get("retry") + .and_then(|r| r.trim().parse::().ok()); Some(SseEvent { event, data: Some(data), id, + retry, }) } } @@ -317,4 +264,20 @@ mod tests { Some(Bytes::from("second\n").as_ref()) ); } + + #[test] + fn test_basic_sse_event() { + let mut parser = SseParser::new(); + let input = Bytes::from("event: message\ndata: Hello\nid: 1\nretry: 5000\n\n"); + + let events = parser.process_new_chunk(input); + + assert_eq!(events.len(), 1); + + let event = &events[0]; + assert_eq!(event.event.as_deref(), Some("message")); + assert_eq!(event.data.as_deref(), Some(Bytes::from("Hello\n").as_ref())); + assert_eq!(event.id.as_deref(), Some("1")); + assert_eq!(event.retry, Some(5000)); + } }