diff --git a/Cargo.toml b/Cargo.toml index 26fb067..718c9a4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -99,3 +99,6 @@ todo = "deny" [workspace.metadata.cargo-machete] ignored = ["bindgen", "cbindgen", "prost_build", "serde"] + +[workspace.metadata.typos] +default.extend-ignore-re = ["clonable"] diff --git a/crates/rust-mcp-extra/src/id_generator/nano_id_generator.rs b/crates/rust-mcp-extra/src/id_generator/nano_id_generator.rs index a50ec2b..1ad5697 100644 --- a/crates/rust-mcp-extra/src/id_generator/nano_id_generator.rs +++ b/crates/rust-mcp-extra/src/id_generator/nano_id_generator.rs @@ -64,7 +64,7 @@ mod tests { for _ in 0..1000 { let id: String = generator.generate(); - assert!(seen.insert(id.clone()), "Duplicate ID: {}", id); + assert!(seen.insert(id.clone()), "Duplicate ID: {id}"); } } } diff --git a/crates/rust-mcp-extra/src/id_generator/snow_flake_id_generator.rs b/crates/rust-mcp-extra/src/id_generator/snow_flake_id_generator.rs index 39942ec..5ab2cb8 100644 --- a/crates/rust-mcp-extra/src/id_generator/snow_flake_id_generator.rs +++ b/crates/rust-mcp-extra/src/id_generator/snow_flake_id_generator.rs @@ -133,9 +133,7 @@ mod tests { let current_id: u64 = id.parse().expect("ID should be a valid u64"); assert!( current_id > prev_id, - "ID not strictly increasing: {} <= {}", - current_id, - prev_id + "ID not strictly increasing: {current_id} <= {prev_id}" ); prev_id = current_id; } diff --git a/crates/rust-mcp-sdk/Cargo.toml b/crates/rust-mcp-sdk/Cargo.toml index 0ecc527..27f8d6f 100644 --- a/crates/rust-mcp-sdk/Cargo.toml +++ b/crates/rust-mcp-sdk/Cargo.toml @@ -109,3 +109,6 @@ macros = ["rust-mcp-macros/sdk"] [lints] workspace = true + +[package.metadata.typos] +default.extend-ignore-re = ["clonable"] diff --git a/crates/rust-mcp-sdk/src/hyper_servers/routes.rs b/crates/rust-mcp-sdk/src/hyper_servers/routes.rs index 4ae274b..fcaa290 100644 --- a/crates/rust-mcp-sdk/src/hyper_servers/routes.rs +++ b/crates/rust-mcp-sdk/src/hyper_servers/routes.rs @@ -4,10 +4,9 @@ pub mod messages_routes; pub mod sse_routes; pub mod streamable_http_routes; -use crate::mcp_http::McpAppState; - use super::HyperServerOptions; -use axum::Router; +use crate::mcp_http::{McpAppState, McpHttpHandler}; +use axum::{Extension, Router}; use std::sync::Arc; /// Constructs the Axum router with all application routes @@ -21,7 +20,11 @@ use std::sync::Arc; /// /// # Returns /// * `Router` - An Axum router configured with all application routes and state -pub fn app_routes(state: Arc, server_options: &HyperServerOptions) -> Router { +pub fn app_routes( + state: Arc, + server_options: &HyperServerOptions, + http_handler: McpHttpHandler, +) -> Router { let router: Router = Router::new() .merge(streamable_http_routes::routes( server_options.streamable_http_endpoint(), @@ -42,7 +45,8 @@ pub fn app_routes(state: Arc, server_options: &HyperServerOptions) r }) .with_state(state) - .merge(fallback_routes::routes()); + .merge(fallback_routes::routes()) + .layer(Extension(Arc::new(http_handler))); router } diff --git a/crates/rust-mcp-sdk/src/hyper_servers/routes/messages_routes.rs b/crates/rust-mcp-sdk/src/hyper_servers/routes/messages_routes.rs index 65490a3..cc85254 100644 --- a/crates/rust-mcp-sdk/src/hyper_servers/routes/messages_routes.rs +++ b/crates/rust-mcp-sdk/src/hyper_servers/routes/messages_routes.rs @@ -3,7 +3,7 @@ use crate::{ mcp_http::{McpAppState, McpHttpHandler}, utils::remove_query_and_hash, }; -use axum::{extract::State, response::IntoResponse, routing::post, Router}; +use axum::{extract::State, response::IntoResponse, routing::post, Extension, Router}; use http::{HeaderMap, Method, Uri}; use std::sync::Arc; @@ -18,10 +18,13 @@ pub async fn handle_messages( uri: Uri, headers: HeaderMap, State(state): State>, + Extension(http_handler): Extension>, message: String, ) -> TransportServerResult { let request = McpHttpHandler::create_request(Method::POST, uri, headers, Some(&message)); - let generic_response = McpHttpHandler::handle_sse_message(request, state.clone()).await?; + let generic_response = http_handler + .handle_sse_message(request, state.clone()) + .await?; let (parts, body) = generic_response.into_parts(); let resp = axum::response::Response::from_parts(parts, axum::body::Body::new(body)); Ok(resp) diff --git a/crates/rust-mcp-sdk/src/hyper_servers/routes/sse_routes.rs b/crates/rust-mcp-sdk/src/hyper_servers/routes/sse_routes.rs index e13c724..c85d81f 100644 --- a/crates/rust-mcp-sdk/src/hyper_servers/routes/sse_routes.rs +++ b/crates/rust-mcp-sdk/src/hyper_servers/routes/sse_routes.rs @@ -36,11 +36,13 @@ pub fn routes(sse_endpoint: &str, sse_message_endpoint: &str) -> Router` - The SSE response stream or an error pub async fn handle_sse( Extension(sse_message_endpoint): Extension, + Extension(http_handler): Extension>, State(state): State>, ) -> TransportServerResult { let SseMessageEndpoint(sse_message_endpoint) = sse_message_endpoint; - let generic_response = - McpHttpHandler::handle_sse_connection(state.clone(), Some(&sse_message_endpoint)).await?; + let generic_response = http_handler + .handle_sse_connection(state.clone(), Some(&sse_message_endpoint)) + .await?; let (parts, body) = generic_response.into_parts(); let resp = axum::response::Response::from_parts(parts, axum::body::Body::new(body)); Ok(resp) diff --git a/crates/rust-mcp-sdk/src/hyper_servers/routes/streamable_http_routes.rs b/crates/rust-mcp-sdk/src/hyper_servers/routes/streamable_http_routes.rs index 6f2e470..69287d4 100644 --- a/crates/rust-mcp-sdk/src/hyper_servers/routes/streamable_http_routes.rs +++ b/crates/rust-mcp-sdk/src/hyper_servers/routes/streamable_http_routes.rs @@ -1,6 +1,7 @@ use crate::hyper_servers::error::TransportServerResult; use crate::mcp_http::{McpAppState, McpHttpHandler}; use axum::routing::get; +use axum::Extension; use axum::{ extract::{Query, State}, response::IntoResponse, @@ -24,9 +25,10 @@ pub async fn handle_streamable_http_get( headers: HeaderMap, uri: Uri, State(state): State>, + Extension(http_handler): Extension>, ) -> TransportServerResult { let request = McpHttpHandler::create_request(Method::GET, uri, headers, None); - let generic_res = McpHttpHandler::handle_streamable_http(request, state).await?; + let generic_res = http_handler.handle_streamable_http(request, state).await?; let (parts, body) = generic_res.into_parts(); let resp = axum::response::Response::from_parts(parts, axum::body::Body::new(body)); Ok(resp) @@ -36,12 +38,13 @@ pub async fn handle_streamable_http_post( headers: HeaderMap, uri: Uri, State(state): State>, + Extension(http_handler): Extension>, Query(_params): Query>, payload: String, ) -> TransportServerResult { let request = McpHttpHandler::create_request(Method::POST, uri, headers, Some(payload.as_str())); - let generic_res = McpHttpHandler::handle_streamable_http(request, state).await?; + let generic_res = http_handler.handle_streamable_http(request, state).await?; let (parts, body) = generic_res.into_parts(); let resp = axum::response::Response::from_parts(parts, axum::body::Body::new(body)); Ok(resp) @@ -51,9 +54,10 @@ pub async fn handle_streamable_http_delete( headers: HeaderMap, uri: Uri, State(state): State>, + Extension(http_handler): Extension>, ) -> TransportServerResult { let request = McpHttpHandler::create_request(Method::DELETE, uri, headers, None); - let generic_res = McpHttpHandler::handle_streamable_http(request, state).await?; + let generic_res = http_handler.handle_streamable_http(request, state).await?; let (parts, body) = generic_res.into_parts(); let resp = axum::response::Response::from_parts(parts, axum::body::Body::new(body)); Ok(resp) diff --git a/crates/rust-mcp-sdk/src/hyper_servers/server.rs b/crates/rust-mcp-sdk/src/hyper_servers/server.rs index 4cd8eb6..f3e0983 100644 --- a/crates/rust-mcp-sdk/src/hyper_servers/server.rs +++ b/crates/rust-mcp-sdk/src/hyper_servers/server.rs @@ -5,7 +5,7 @@ use crate::{ utils::{ DEFAULT_MESSAGES_ENDPOINT, DEFAULT_SSE_ENDPOINT, DEFAULT_STREAMABLE_HTTP_ENDPOINT, }, - McpAppState, + McpAppState, McpHttpHandler, }, mcp_server::hyper_runtime::HyperRuntime, mcp_traits::{mcp_handler::McpServerHandler, IdGenerator}, @@ -275,7 +275,9 @@ impl HyperServer { dns_rebinding_protection: server_options.dns_rebinding_protection, event_store: server_options.event_store.as_ref().map(Arc::clone), }); - let app = app_routes(Arc::clone(&state), &server_options); + + let http_handler = McpHttpHandler::new(); //TODO: add auth handlers + let app = app_routes(Arc::clone(&state), &server_options, http_handler); Self { app, state, diff --git a/crates/rust-mcp-sdk/src/mcp_http.rs b/crates/rust-mcp-sdk/src/mcp_http.rs index 3f443d5..2e5d8fd 100644 --- a/crates/rust-mcp-sdk/src/mcp_http.rs +++ b/crates/rust-mcp-sdk/src/mcp_http.rs @@ -2,8 +2,11 @@ mod app_state; mod mcp_http_handler; pub(crate) mod mcp_http_utils; +mod mcp_http_middleware; //TODO: + pub use app_state::*; pub use mcp_http_handler::*; +pub use mcp_http_middleware::Middleware; pub(crate) mod utils { pub use super::mcp_http_utils::*; diff --git a/crates/rust-mcp-sdk/src/mcp_http/mcp_http_handler.rs b/crates/rust-mcp-sdk/src/mcp_http/mcp_http_handler.rs index 8b7efcf..c60b4dc 100644 --- a/crates/rust-mcp-sdk/src/mcp_http/mcp_http_handler.rs +++ b/crates/rust-mcp-sdk/src/mcp_http/mcp_http_handler.rs @@ -1,8 +1,11 @@ #[cfg(feature = "sse")] use super::utils::handle_sse_connection; +use crate::mcp_http::mcp_http_middleware::MiddlewareChain; use crate::mcp_http::utils::{ - accepts_event_stream, error_response, query_param, validate_mcp_protocol_version_header, + accepts_event_stream, empty_response, error_response, query_param, + validate_mcp_protocol_version_header, }; +use crate::mcp_http::Middleware; use crate::mcp_runtimes::server_runtime::DEFAULT_STREAM_ID; use crate::mcp_server::error::TransportServerError; use crate::schema::schema_utils::SdkError; @@ -19,26 +22,32 @@ use crate::{ mcp_server::error::TransportServerResult, utils::valid_initialize_method, }; -use bytes::Bytes; use http::{self, HeaderMap, Method, StatusCode, Uri}; -use http_body_util::{BodyExt, Full}; use rust_mcp_transport::{SessionId, MCP_LAST_EVENT_ID_HEADER, MCP_SESSION_ID_HEADER}; use std::sync::Arc; -pub struct McpHttpHandler {} +#[derive(Clone)] +pub struct McpHttpHandler { + middleware_chain: MiddlewareChain, +} + +impl Default for McpHttpHandler { + fn default() -> Self { + Self::new() + } +} impl McpHttpHandler { - /// Creates a new HTTP request with the given method, URI, headers, and optional body. - /// - /// # Arguments - /// - /// * `method` - The HTTP method to use (e.g., GET, POST). - /// * `uri` - The target URI for the request. - /// * `headers` - A map of optional header keys and their corresponding values. - /// * `body` - An optional string slice representing the request body. - /// - /// # Returns - /// + pub fn new() -> Self { + McpHttpHandler { + middleware_chain: MiddlewareChain::new(), + } + } + + pub fn add_middleware(&mut self, middleware: M) { + self.middleware_chain.add_middleware(middleware); + } + /// An `http::Request<&str>` initialized with the specified method, URI, headers, and body. /// If the `body` is `None`, an empty string is used as the default. /// @@ -77,6 +86,7 @@ impl McpHttpHandler { /// This function is only available when the `sse` feature is enabled. #[cfg(feature = "sse")] pub async fn handle_sse_connection( + &self, state: Arc, sse_message_endpoint: Option<&str>, ) -> TransportServerResult> { @@ -104,6 +114,7 @@ impl McpHttpHandler { /// - `StreamIoError`: if an error occurs while writing to the stream. /// - `HttpError`: if constructing the HTTP response fails. pub async fn handle_sse_message( + &self, request: http::Request<&str>, state: Arc, ) -> TransportServerResult> { @@ -124,13 +135,9 @@ impl McpHttpHandler { TransportServerError::StreamIoError(err.to_string()) })?; - let body = Full::new(Bytes::new()) - .map_err(|err| TransportServerError::HttpError(err.to_string())) - .boxed(); - http::Response::builder() .status(StatusCode::ACCEPTED) - .body(body) + .body(empty_response()) .map_err(|err| TransportServerError::HttpError(err.to_string())) } @@ -156,9 +163,16 @@ impl McpHttpHandler { /// * A `TransportServerResult` wrapping an HTTP response indicating success or failure of the operation. /// pub async fn handle_streamable_http( + &self, request: http::Request<&str>, state: Arc, ) -> TransportServerResult> { + let request = self + .middleware_chain + .process_request(request) + .await + .map_err(|e| TransportServerError::HttpError(e.to_string()))?; + // Enforces DNS rebinding protection if required by state. // If protection fails, respond with HTTP 403 Forbidden. if state.needs_dns_protection() { @@ -168,24 +182,36 @@ impl McpHttpHandler { } let method = request.method(); - match method { - &http::Method::GET => return Self::handle_http_get(request, state).await, - &http::Method::POST => return Self::handle_http_post(request, state).await, - &http::Method::DELETE => return Self::handle_http_delete(request, state).await, + let response = match method { + &http::Method::GET => return self.handle_http_get(request, state).await, + &http::Method::POST => return self.handle_http_post(request, state).await, + &http::Method::DELETE => return self.handle_http_delete(request, state).await, other => { let error = SdkError::bad_request().with_message(&format!( "'{other}' is not a valid HTTP method for StreamableHTTP transport." )); error_response(StatusCode::METHOD_NOT_ALLOWED, error) } - } + }; + + self.middleware_chain + .process_response(response?) + .await + .map_err(|e| TransportServerError::HttpError(e.to_string())) } /// Processes POST requests for the Streamable HTTP Protocol async fn handle_http_post( + &self, request: http::Request<&str>, state: Arc, ) -> TransportServerResult> { + let request = self + .middleware_chain + .process_request(request) + .await + .map_err(|e| TransportServerError::HttpError(e.to_string()))?; + let headers = request.headers(); if !valid_streaming_http_accept_header(headers) { @@ -213,7 +239,7 @@ impl McpHttpHandler { let payload = *request.body(); - match session_id { + let response = match session_id { // has session-id => write to the existing stream Some(id) => { if state.enable_json_response { @@ -232,14 +258,26 @@ impl McpHttpHandler { error_response(StatusCode::BAD_REQUEST, error) } }, - } + }; + + self.middleware_chain + .process_response(response?) + .await + .map_err(|e| TransportServerError::HttpError(e.to_string())) } /// Processes GET requests for the Streamable HTTP Protocol async fn handle_http_get( + &self, request: http::Request<&str>, state: Arc, ) -> TransportServerResult> { + let request = self + .middleware_chain + .process_request(request) + .await + .map_err(|e| TransportServerError::HttpError(e.to_string()))?; + let headers = request.headers(); if !accepts_event_stream(headers) { @@ -264,7 +302,7 @@ impl McpHttpHandler { .and_then(|value| value.to_str().ok()) .map(|s| s.to_string()); - match session_id { + let response = match session_id { Some(session_id) => { let res = create_standalone_stream(session_id, last_event_id, state).await; res @@ -273,14 +311,26 @@ impl McpHttpHandler { let error = SdkError::bad_request().with_message("Bad request: session not found"); error_response(StatusCode::BAD_REQUEST, error) } - } + }; + + self.middleware_chain + .process_response(response?) + .await + .map_err(|e| TransportServerError::HttpError(e.to_string())) } /// Processes DELETE requests for the Streamable HTTP Protocol async fn handle_http_delete( + &self, request: http::Request<&str>, state: Arc, ) -> TransportServerResult> { + let request = self + .middleware_chain + .process_request(request) + .await + .map_err(|e| TransportServerError::HttpError(e.to_string()))?; + let headers = request.headers(); if let Err(parse_error) = validate_mcp_protocol_version_header(headers) { @@ -294,12 +344,17 @@ impl McpHttpHandler { .and_then(|value| value.to_str().ok()) .map(|s| s.to_string()); - match session_id { + let response = match session_id { Some(id) => delete_session(id, state).await, None => { let error = SdkError::bad_request().with_message("Bad Request: Session not found"); error_response(StatusCode::BAD_REQUEST, error) } - } + }; + + self.middleware_chain + .process_response(response?) + .await + .map_err(|e| TransportServerError::HttpError(e.to_string())) } } diff --git a/crates/rust-mcp-sdk/src/mcp_http/mcp_http_middleware.rs b/crates/rust-mcp-sdk/src/mcp_http/mcp_http_middleware.rs new file mode 100644 index 0000000..22027d7 --- /dev/null +++ b/crates/rust-mcp-sdk/src/mcp_http/mcp_http_middleware.rs @@ -0,0 +1,389 @@ +use crate::mcp_http::utils::GenericBody; +use crate::mcp_server::error::TransportServerResult; +use http::{Request, Response}; +use std::future::Future; +use std::pin::Pin; +use std::sync::Arc; + +/// Defines a middleware trait for processing HTTP requests and responses. +/// +/// Implementors of this trait can define custom logic to modify or inspect HTTP +/// requests before they reach the handler and HTTP responses before they are sent +/// back to the client. Middleware must be thread-safe (`Send + Sync`) and have a +/// static lifetime. +pub trait Middleware: Send + Sync + 'static { + /// Processes an incoming HTTP request. + /// + /// This method takes a request, applies middleware-specific logic, and returns + /// a future that resolves to a `TransportServerResult` containing the modified + /// request or an error. + /// + /// # Arguments + /// * `request` - The incoming HTTP request with a string body reference. + /// + /// # Returns + /// A pinned boxed future resolving to a `TransportServerResult` containing the + /// processed request. + fn process_request<'a, 'b>( + &'a self, + request: Request<&'b str>, + ) -> Pin>> + Send + 'a>> + where + 'b: 'a; // Ensure the request's lifetime outlives the future + + /// Processes an outgoing HTTP response. + /// + /// This method takes a response, applies middleware-specific logic, and returns + /// a future that resolves to a `TransportServerResult` containing the modified + /// response or an error. + /// + /// # Arguments + /// * `response` - The HTTP response with a `GenericBody`. + /// + /// # Returns + /// A pinned boxed future resolving to a `TransportServerResult` containing the + /// processed response. + fn process_response<'a, 'b>( + &'a self, + response: Response, + ) -> Pin>> + Send + 'a>> + where + 'b: 'a; // Optional, included for consistency +} + +/// A chain of middleware to process HTTP requests and responses sequentially. +/// +/// `MiddlewareChain` allows multiple middleware instances to be registered and +/// executed in order for requests (forward order) and responses (reverse order). +#[derive(Clone)] +pub struct MiddlewareChain { + middlewares: Vec>, +} + +impl MiddlewareChain { + /// Creates a new, empty middleware chain. + /// + /// # Returns + /// A new `MiddlewareChain` instance with no middleware registered. + pub fn new() -> Self { + MiddlewareChain { + middlewares: Vec::new(), + } + } + + /// Adds a middleware to the chain. + /// + /// The middleware is wrapped in an `Arc` to ensure thread-safety and shared + /// ownership. Middleware will be executed in the order they are added for + /// requests and in reverse order for responses. + /// + /// # Arguments + /// * `middleware` - The middleware to add to the chain. + pub fn add_middleware(&mut self, middleware: M) { + self.middlewares.push(Arc::new(middleware)); + } + + /// Processes an HTTP request through all registered middleware. + /// + /// Each middleware's `process_request` method is called in the order they + /// were added. If any middleware returns an error, processing stops and the + /// error is returned. + /// + /// # Arguments + /// * `request` - The HTTP request to process. + /// + /// # Returns + /// A `TransportServerResult` containing the processed request or an error. + pub async fn process_request<'a>( + &self, + request: http::Request<&'a str>, + ) -> TransportServerResult> { + let mut request = request; + for middleware in &self.middlewares { + request = middleware.process_request(request).await?; + } + Ok(request) + } + + /// Processes an HTTP response through all registered middleware. + /// + /// Each middleware's `process_response` method is called in the reverse order + /// of their addition. If any middleware returns an error, processing stops and + /// the error is returned. + /// + /// # Arguments + /// * `response` - The HTTP response to process. + /// + /// # Returns + /// A `TransportServerResult` containing the processed response or an error. + pub async fn process_response( + &self, + response: http::Response, + ) -> TransportServerResult> { + let mut response = response; + for middleware in self.middlewares.iter().rev() { + response = middleware.process_response(response).await?; + } + Ok(response) + } +} + +// Sample Middleware +pub struct LoggingMiddleware; + +impl Middleware for LoggingMiddleware { + fn process_request<'a, 'b>( + &'a self, + request: http::Request<&'b str>, + ) -> Pin>> + Send + 'a>> + where + 'b: 'a, + { + Box::pin(async move { + tracing::info!("Request: {} {}", request.method(), request.uri()); + Ok(request) + }) + } + + fn process_response<'a, 'b>( + &'a self, + response: http::Response, + ) -> Pin>> + Send + 'a>> + where + 'b: 'a, + { + Box::pin(async move { + tracing::info!("Response: {}", response.status()); + Ok(response) + }) + } +} + +#[cfg(test)] +mod tests { + use crate::{mcp_http::utils::empty_response, mcp_server::error::TransportServerError}; + + use super::*; + use async_trait::async_trait; + use bytes::Bytes; + use http::{Request, Response}; + use http_body_util::{BodyExt, Full}; + use std::sync::Mutex; + use thiserror::Error; + + /// Custom error type for test middleware. + #[derive(Error, Debug)] + enum TestMiddlewareError { + #[error("Request processing failed: {0}")] + RequestError(String), + #[error("Response processing failed: {0}")] + ResponseError(String), + } + + /// A test middleware that records its interactions with requests and responses. + struct TestMiddleware { + /// Tracks request calls with their input bodies. + request_calls: Arc>>, + /// Tracks response calls with their status codes. + response_calls: Arc>>, + /// Optional error to simulate failure in request processing. + request_error: Option, + /// Optional error to simulate failure in response processing. + response_error: Option, + } + + impl TestMiddleware { + fn new() -> Self { + TestMiddleware { + request_calls: Arc::new(Mutex::new(Vec::new())), + response_calls: Arc::new(Mutex::new(Vec::new())), + request_error: None, + response_error: None, + } + } + + fn with_errors(request_error: Option, response_error: Option) -> Self { + TestMiddleware { + request_calls: Arc::new(Mutex::new(Vec::new())), + response_calls: Arc::new(Mutex::new(Vec::new())), + request_error, + response_error, + } + } + } + + #[async_trait] + impl Middleware for TestMiddleware { + fn process_request<'a, 'b>( + &'a self, + request: Request<&'b str>, + ) -> Pin>> + Send + 'a>> + where + 'b: 'a, + { + Box::pin(async move { + if let Some(err) = &self.request_error { + return Err(TransportServerError::HttpError(err.to_string())); + } + self.request_calls + .lock() + .unwrap() + .push(request.body().to_string()); + Ok(request) + }) + } + + fn process_response<'a, 'b>( + &'a self, + response: Response, + ) -> Pin>> + Send + 'a>> + where + 'b: 'a, + { + Box::pin(async move { + if let Some(err) = &self.response_error { + return Err(TransportServerError::HttpError(err.to_string())); + } + self.response_calls + .lock() + .unwrap() + .push(response.status().as_u16()); + Ok(response) + }) + } + } + + #[tokio::test] + async fn test_empty_middleware_chain() { + let chain = MiddlewareChain::new(); + let request = Request::builder().body("test").unwrap(); + + let response = Response::builder() + .status(200) + .body(empty_response()) + .unwrap(); + + let result_request = chain.process_request(request).await.unwrap(); + let result_response = chain.process_response(response).await.unwrap(); + + assert_eq!(result_request.body().to_ascii_lowercase(), "test"); + assert_eq!(result_response.status(), 200); + } + + #[tokio::test] + async fn test_single_middleware() { + let mut chain = MiddlewareChain::new(); + let middleware = TestMiddleware::new(); + let request_calls = middleware.request_calls.clone(); + let response_calls = middleware.response_calls.clone(); + + chain.add_middleware(middleware); + + let request = Request::builder().body("test").unwrap(); + let response = Response::builder() + .status(200) + .body(empty_response()) + .unwrap(); + + let result_request = chain.process_request(request).await.unwrap(); + let result_response = chain.process_response(response).await.unwrap(); + + assert_eq!(result_request.body().to_ascii_lowercase(), "test"); + assert_eq!(result_response.status(), 200); + assert_eq!(request_calls.lock().unwrap().as_slice(), &["test"]); + assert_eq!(response_calls.lock().unwrap().as_slice(), &[200]); + } + + #[tokio::test] + async fn test_multiple_middlewares_request_order() { + let mut chain = MiddlewareChain::new(); + let middleware1 = TestMiddleware::new(); + let middleware2 = TestMiddleware::new(); + let request_calls1 = middleware1.request_calls.clone(); + let request_calls2 = middleware2.request_calls.clone(); + + chain.add_middleware(middleware1); + chain.add_middleware(middleware2); + + let request = Request::builder().body("test").unwrap(); + + let result = chain.process_request(request).await.unwrap(); + assert_eq!(result.body().to_ascii_lowercase(), "test"); + + // Check order of execution + assert_eq!(request_calls1.lock().unwrap().as_slice(), &["test"]); + assert_eq!(request_calls2.lock().unwrap().as_slice(), &["test"]); + } + + #[tokio::test] + async fn test_multiple_middlewares_response_reverse_order() { + let mut chain = MiddlewareChain::new(); + let middleware1 = TestMiddleware::new(); + let middleware2 = TestMiddleware::new(); + let response_calls1 = middleware1.response_calls.clone(); + let response_calls2 = middleware2.response_calls.clone(); + + chain.add_middleware(middleware1); + chain.add_middleware(middleware2); + + let response = Response::builder() + .status(200) + .body(empty_response()) + .unwrap(); + + let result = chain.process_response(response).await.unwrap(); + assert_eq!(result.status(), 200); + + // Check reverse order of execution + assert_eq!(response_calls2.lock().unwrap().as_slice(), &[200]); + assert_eq!(response_calls1.lock().unwrap().as_slice(), &[200]); + } + + #[tokio::test] + async fn test_middleware_request_error() { + let mut chain = MiddlewareChain::new(); + let middleware = TestMiddleware::with_errors(Some("request error".to_string()), None); + chain.add_middleware(middleware); + + let request = Request::builder().body("test").unwrap(); + + let result = chain.process_request(request).await; + assert!(result.is_err()); + assert_eq!(result.unwrap_err().to_string(), "request error"); + } + + #[tokio::test] + async fn test_middleware_response_error() { + let mut chain = MiddlewareChain::new(); + let middleware = TestMiddleware::with_errors(None, Some("response error".to_string())); + chain.add_middleware(middleware); + + let response = Response::builder() + .status(200) + .body(empty_response()) + .unwrap(); + + let result = chain.process_response(response).await; + assert!(result.is_err()); + assert_eq!(result.unwrap_err().to_string(), "response error"); + } + + #[tokio::test] + async fn test_middleware_chain_clone() { + let mut chain = MiddlewareChain::new(); + let middleware = TestMiddleware::new(); + let request_calls = middleware.request_calls.clone(); + + chain.add_middleware(middleware); + let chain_clone = chain.clone(); + + let request = Request::builder().body("test").unwrap(); + + // Process on original and clone + chain.process_request(request.clone()).await.unwrap(); + chain_clone.process_request(request).await.unwrap(); + + // Both should have processed the request + assert_eq!(request_calls.lock().unwrap().as_slice(), &["test", "test"]); + } +} diff --git a/crates/rust-mcp-sdk/src/mcp_http/mcp_http_utils.rs b/crates/rust-mcp-sdk/src/mcp_http/mcp_http_utils.rs index 6d003b9..06020d1 100644 --- a/crates/rust-mcp-sdk/src/mcp_http/mcp_http_utils.rs +++ b/crates/rust-mcp-sdk/src/mcp_http/mcp_http_utils.rs @@ -34,6 +34,17 @@ const DUPLEX_BUFFER_SIZE: usize = 8192; pub type GenericBody = BoxBody; +/// Creates an empty HTTP response body. +/// +/// This function constructs a `GenericBody` containing an empty `Bytes` buffer, +/// The body is wrapped in a `BoxBody` to ensure type erasure and compatibility +/// with the HTTP framework. +pub fn empty_response() -> GenericBody { + Full::new(Bytes::new()) + .map_err(|err| TransportServerError::HttpError(err.to_string())) + .boxed() +} + /// Creates an initial SSE event that returns the messages endpoint /// /// Constructs an SSE event containing the messages endpoint URL with the session ID.