From 6b6c1cd10c4c8d47dc756dbf2b508f40fd6483bc Mon Sep 17 00:00:00 2001 From: Ali Hashemi Date: Tue, 1 Jul 2025 16:30:48 -0300 Subject: [PATCH] fix: address issue with improper server start failure handling --- crates/rust-mcp-sdk/src/error.rs | 12 ++++++++++- .../src/mcp_handlers/mcp_server_handler.rs | 20 ++++++++++++------- .../src/mcp_runtimes/server_runtime.rs | 9 ++++++++- .../hello-world-mcp-server-core/src/main.rs | 10 +++++++++- examples/hello-world-mcp-server/src/main.rs | 10 +++++++++- .../src/handler.rs | 12 ++++++++--- rust-toolchain.toml | 2 +- 7 files changed, 60 insertions(+), 15 deletions(-) diff --git a/crates/rust-mcp-sdk/src/error.rs b/crates/rust-mcp-sdk/src/error.rs index 63bb1b5..788e31a 100644 --- a/crates/rust-mcp-sdk/src/error.rs +++ b/crates/rust-mcp-sdk/src/error.rs @@ -22,9 +22,19 @@ pub enum McpSdkError { #[cfg(feature = "hyper-server")] #[error("{0}")] TransportServerError(#[from] TransportServerError), - #[error("Incompatible mcp protocol version!\n client:{0}\nserver:{1}")] + #[error("Incompatible mcp protocol version: client:{0} server:{1}")] IncompatibleProtocolVersion(String, String), } +impl McpSdkError { + /// Returns the RPC error message if the error is of type `McpSdkError::RpcError`. + pub fn rpc_error_message(&self) -> Option<&String> { + if let McpSdkError::RpcError(rpc_error) = self { + return Some(&rpc_error.message); + } + None + } +} + #[deprecated(since = "0.2.0", note = "Use `McpSdkError` instead.")] pub type MCPSdkError = McpSdkError; diff --git a/crates/rust-mcp-sdk/src/mcp_handlers/mcp_server_handler.rs b/crates/rust-mcp-sdk/src/mcp_handlers/mcp_server_handler.rs index f1e129a..5b0fdc0 100644 --- a/crates/rust-mcp-sdk/src/mcp_handlers/mcp_server_handler.rs +++ b/crates/rust-mcp-sdk/src/mcp_handlers/mcp_server_handler.rs @@ -31,10 +31,6 @@ pub trait ServerHandler: Send + Sync + 'static { initialize_request: InitializeRequest, runtime: &dyn McpServer, ) -> std::result::Result { - runtime - .set_client_details(initialize_request.params.clone()) - .map_err(|err| RpcError::internal_error().with_message(format!("{err}")))?; - let mut server_info = runtime.server_info().to_owned(); // Provide compatibility for clients using older MCP protocol versions. @@ -42,11 +38,21 @@ pub trait ServerHandler: Send + Sync + 'static { &initialize_request.params.protocol_version, &server_info.protocol_version, ) - .map_err(|err| RpcError::internal_error().with_message(err.to_string()))? - { - server_info.protocol_version = initialize_request.params.protocol_version; + .map_err(|err| { + tracing::error!( + "Incompatible protocol version : client: {} server: {}", + &initialize_request.params.protocol_version, + &server_info.protocol_version + ); + RpcError::internal_error().with_message(err.to_string()) + })? { + server_info.protocol_version = updated_protocol_version; } + runtime + .set_client_details(initialize_request.params.clone()) + .map_err(|err| RpcError::internal_error().with_message(format!("{err}")))?; + Ok(server_info) } diff --git a/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime.rs b/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime.rs index 6679ae7..fff5acd 100644 --- a/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime.rs +++ b/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime.rs @@ -106,7 +106,14 @@ impl McpServer for ServerRuntime { // create a response to send back to the client let response: MessageFromServer = match result { Ok(success_value) => success_value.into(), - Err(error_value) => MessageFromServer::Error(error_value), + Err(error_value) => { + // Error occurred during initialization. + // A likely cause could be an unsupported protocol version. + if !self.is_initialized() { + return Err(error_value.into()); + } + MessageFromServer::Error(error_value) + } }; // send the response back with corresponding request id diff --git a/examples/hello-world-mcp-server-core/src/main.rs b/examples/hello-world-mcp-server-core/src/main.rs index e0e7de1..1bb9f72 100644 --- a/examples/hello-world-mcp-server-core/src/main.rs +++ b/examples/hello-world-mcp-server-core/src/main.rs @@ -40,5 +40,13 @@ async fn main() -> SdkResult<()> { let server = server_runtime_core::create_server(server_details, transport, handler); // STEP 5: Start the server - server.start().await + if let Err(start_error) = server.start().await { + eprintln!( + "{}", + start_error + .rpc_error_message() + .unwrap_or(&start_error.to_string()) + ); + }; + Ok(()) } diff --git a/examples/hello-world-mcp-server/src/main.rs b/examples/hello-world-mcp-server/src/main.rs index 6a22711..d946fd2 100644 --- a/examples/hello-world-mcp-server/src/main.rs +++ b/examples/hello-world-mcp-server/src/main.rs @@ -42,5 +42,13 @@ async fn main() -> SdkResult<()> { let server: ServerRuntime = server_runtime::create_server(server_details, transport, handler); // STEP 5: Start the server - server.start().await + if let Err(start_error) = server.start().await { + eprintln!( + "{}", + start_error + .rpc_error_message() + .unwrap_or(&start_error.to_string()) + ); + }; + Ok(()) } diff --git a/examples/hello-world-server-core-sse/src/handler.rs b/examples/hello-world-server-core-sse/src/handler.rs index 82b0b2b..53f884c 100644 --- a/examples/hello-world-server-core-sse/src/handler.rs +++ b/examples/hello-world-server-core-sse/src/handler.rs @@ -36,9 +36,15 @@ impl ServerHandlerCore for MyServerHandler { &initialize_request.params.protocol_version, &server_info.protocol_version, ) - .map_err(|err| RpcError::internal_error().with_message(err.to_string()))? - { - server_info.protocol_version = initialize_request.params.protocol_version; + .map_err(|err| { + tracing::error!( + "Incompatible protocol version :\nclient: {}\nserver: {}", + &initialize_request.params.protocol_version, + &server_info.protocol_version + ); + RpcError::internal_error().with_message(err.to_string()) + })? { + server_info.protocol_version = updated_protocol_version; } return Ok(server_info.into()); diff --git a/rust-toolchain.toml b/rust-toolchain.toml index 5675074..7855e6d 100644 --- a/rust-toolchain.toml +++ b/rust-toolchain.toml @@ -1,3 +1,3 @@ [toolchain] -channel = "1.87.0" +channel = "1.88.0" components = ["rustfmt", "clippy"]