Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ impl SnowflakeIdGenerator {
.expect("invalid system time!")
.as_millis() as u64;

now - *SHORTER_EPOCH
now.saturating_sub(*SHORTER_EPOCH)
}

fn next_id(&self) -> u64 {
Expand Down
6 changes: 5 additions & 1 deletion crates/rust-mcp-sdk/src/hyper_servers/routes/sse_routes.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use crate::hyper_servers::error::TransportServerResult;
use crate::mcp_http::{McpAppState, McpHttpHandler};
use axum::{extract::State, response::IntoResponse, routing::get, Extension, Router};
use http::{HeaderMap, Method, Uri};
use std::sync::Arc;

#[derive(Clone)]
Expand Down Expand Up @@ -35,13 +36,16 @@ pub fn routes(sse_endpoint: &str, sse_message_endpoint: &str) -> Router<Arc<McpA
/// # Returns
/// * `TransportServerResult<impl IntoResponse>` - The SSE response stream or an error
pub async fn handle_sse(
headers: HeaderMap,
uri: Uri,
Extension(sse_message_endpoint): Extension<SseMessageEndpoint>,
Extension(http_handler): Extension<Arc<McpHttpHandler>>,
State(state): State<Arc<McpAppState>>,
) -> TransportServerResult<impl IntoResponse> {
let SseMessageEndpoint(sse_message_endpoint) = sse_message_endpoint;
let request = McpHttpHandler::create_request(Method::GET, uri, headers, None);
let generic_response = http_handler
.handle_sse_connection(state.clone(), Some(&sse_message_endpoint))
.handle_sse_connection(request, state.clone(), Some(&sse_message_endpoint))
.await?;
let (parts, body) = generic_response.into_parts();
let resp = axum::response::Response::from_parts(parts, axum::body::Body::new(body));
Expand Down
21 changes: 16 additions & 5 deletions crates/rust-mcp-sdk/src/hyper_servers/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@ use crate::{
error::SdkResult,
id_generator::{FastIdGenerator, UuidGenerator},
mcp_http::{
utils::{
http_utils::{
DEFAULT_MESSAGES_ENDPOINT, DEFAULT_SSE_ENDPOINT, DEFAULT_STREAMABLE_HTTP_ENDPOINT,
},
middleware::dns_rebind_protector::DnsRebindProtector,
McpAppState, McpHttpHandler,
},
mcp_server::hyper_runtime::HyperRuntime,
Expand Down Expand Up @@ -203,6 +204,11 @@ impl HyperServerOptions {
.as_deref()
.unwrap_or(DEFAULT_STREAMABLE_HTTP_ENDPOINT)
}

pub fn needs_dns_protection(&self) -> bool {
self.dns_rebinding_protection
&& (self.allowed_hosts.is_some() || self.allowed_origins.is_some())
}
}

/// Default implementation for HyperServerOptions
Expand Down Expand Up @@ -270,13 +276,18 @@ impl HyperServer {
ping_interval: server_options.ping_interval,
transport_options: Arc::clone(&server_options.transport_options),
enable_json_response: server_options.enable_json_response.unwrap_or(false),
allowed_hosts: server_options.allowed_hosts.take(),
allowed_origins: server_options.allowed_origins.take(),
dns_rebinding_protection: server_options.dns_rebinding_protection,
event_store: server_options.event_store.as_ref().map(Arc::clone),
});

let http_handler = McpHttpHandler::new(); //TODO: add auth handlers
let mut http_handler = McpHttpHandler::new();

if server_options.needs_dns_protection() {
http_handler.add_middleware(DnsRebindProtector::new(
server_options.allowed_hosts.take(),
server_options.allowed_origins.take(),
));
}

let app = app_routes(Arc::clone(&state), &server_options, http_handler);
Self {
app,
Expand Down
2 changes: 1 addition & 1 deletion crates/rust-mcp-sdk/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ pub mod error;
mod hyper_servers;
mod mcp_handlers;
#[cfg(feature = "hyper-server")]
pub(crate) mod mcp_http;
pub mod mcp_http;
mod mcp_macros;
mod mcp_runtimes;
mod mcp_traits;
Expand Down
13 changes: 6 additions & 7 deletions crates/rust-mcp-sdk/src/mcp_http.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
mod app_state;
pub(crate) mod http_utils;
mod mcp_http_handler;
pub(crate) mod mcp_http_utils;

mod mcp_http_middleware; //TODO:
pub mod middleware;
mod types;

pub use app_state::*;
pub use http_utils::*;
pub use mcp_http_handler::*;
pub use mcp_http_middleware::Middleware;
pub use types::*;

pub(crate) mod utils {
pub use super::mcp_http_utils::*;
}
pub use middleware::Middleware;
16 changes: 0 additions & 16 deletions crates/rust-mcp-sdk/src/mcp_http/app_state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,23 +19,7 @@ pub struct McpAppState {
pub ping_interval: Duration,
pub transport_options: Arc<TransportOptions>,
pub enable_json_response: bool,
/// List of allowed host header values for DNS rebinding protection.
/// If not specified, host validation is disabled.
pub allowed_hosts: Option<Vec<String>>,
/// List of allowed origin header values for DNS rebinding protection.
/// If not specified, origin validation is disabled.
pub allowed_origins: Option<Vec<String>>,
/// Enable DNS rebinding protection (requires allowedHosts and/or allowedOrigins to be configured).
/// Default is false for backwards compatibility.
pub dns_rebinding_protection: bool,
/// Event store for resumability support
/// If provided, resumability will be enabled, allowing clients to reconnect and resume messages
pub event_store: Option<Arc<dyn EventStore>>,
}

impl McpAppState {
pub fn needs_dns_protection(&self) -> bool {
self.dns_rebinding_protection
&& (self.allowed_hosts.is_some() || self.allowed_origins.is_some())
}
}
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use crate::mcp_http::types::GenericBody;
use crate::schema::schema_utils::{ClientMessage, SdkError};
use crate::{
error::SdkResult,
Expand All @@ -11,10 +12,10 @@ use crate::{
use axum::http::HeaderValue;
use bytes::Bytes;
use futures::stream;
use http::header::{ACCEPT, CONNECTION, CONTENT_TYPE, HOST, ORIGIN};
use http::header::{ACCEPT, CONNECTION, CONTENT_TYPE};
use http_body::Frame;
use http_body_util::StreamBody;
use http_body_util::{combinators::BoxBody, BodyExt, Full};
use http_body_util::{BodyExt, Full};
use hyper::{HeaderMap, StatusCode};
use rust_mcp_transport::{
EventId, McpDispatch, SessionId, SseEvent, SseTransport, StreamId, ID_SEPARATOR,
Expand All @@ -32,8 +33,6 @@ pub(crate) const DEFAULT_MESSAGES_ENDPOINT: &str = "/messages";
pub(crate) const DEFAULT_STREAMABLE_HTTP_ENDPOINT: &str = "/mcp";
const DUPLEX_BUFFER_SIZE: usize = 8192;

pub type GenericBody = BoxBody<Bytes, TransportServerError>;

/// Creates an empty HTTP response body.
///
/// This function constructs a `GenericBody` containing an empty `Bytes` buffer,
Expand All @@ -45,6 +44,20 @@ pub fn empty_response() -> GenericBody {
.boxed()
}

pub fn build_response(
status_code: StatusCode,
payload: String,
) -> Result<http::Response<GenericBody>, TransportServerError> {
let body = Full::new(Bytes::from(payload))
.map_err(|err| TransportServerError::HttpError(err.to_string()))
.boxed();

http::Response::builder()
.status(status_code)
.body(body)
.map_err(|err| TransportServerError::HttpError(err.to_string()))
}

/// Creates an initial SSE event that returns the messages endpoint
///
/// Constructs an SSE event containing the messages endpoint URL with the session ID.
Expand Down Expand Up @@ -251,7 +264,7 @@ fn is_result(json_str: &str) -> Result<bool, serde_json::Error> {
}
}

pub async fn create_standalone_stream(
pub(crate) async fn create_standalone_stream(
session_id: SessionId,
last_event_id: Option<EventId>,
state: Arc<McpAppState>,
Expand Down Expand Up @@ -287,7 +300,7 @@ pub async fn create_standalone_stream(
Ok(response)
}

pub async fn start_new_session(
pub(crate) async fn start_new_session(
state: Arc<McpAppState>,
payload: &str,
) -> TransportServerResult<http::Response<GenericBody>> {
Expand Down Expand Up @@ -421,7 +434,7 @@ async fn single_shot_stream(
}
}

pub async fn process_incoming_message_return(
pub(crate) async fn process_incoming_message_return(
session_id: SessionId,
state: Arc<McpAppState>,
payload: &str,
Expand All @@ -446,7 +459,7 @@ pub async fn process_incoming_message_return(
}
}

pub async fn process_incoming_message(
pub(crate) async fn process_incoming_message(
session_id: SessionId,
state: Arc<McpAppState>,
payload: &str,
Expand Down Expand Up @@ -499,11 +512,11 @@ pub async fn process_incoming_message(
}
}

pub fn is_empty_sse_message(sse_payload: &str) -> bool {
pub(crate) fn is_empty_sse_message(sse_payload: &str) -> bool {
sse_payload.is_empty() || sse_payload.trim() == ":"
}

pub async fn delete_session(
pub(crate) async fn delete_session(
session_id: SessionId,
state: Arc<McpAppState>,
) -> TransportServerResult<http::Response<GenericBody>> {
Expand All @@ -529,7 +542,7 @@ pub async fn delete_session(
}
}

pub fn acceptable_content_type(headers: &HeaderMap) -> bool {
pub(crate) fn acceptable_content_type(headers: &HeaderMap) -> bool {
let accept_header = headers
.get("content-type")
.and_then(|val| val.to_str().ok())
Expand All @@ -539,7 +552,7 @@ pub fn acceptable_content_type(headers: &HeaderMap) -> bool {
.any(|val| val.trim().starts_with("application/json"))
}

pub fn validate_mcp_protocol_version_header(headers: &HeaderMap) -> SdkResult<()> {
pub(crate) fn validate_mcp_protocol_version_header(headers: &HeaderMap) -> SdkResult<()> {
let protocol_version_header = headers
.get(MCP_PROTOCOL_VERSION_HEADER)
.and_then(|val| val.to_str().ok())
Expand All @@ -553,7 +566,7 @@ pub fn validate_mcp_protocol_version_header(headers: &HeaderMap) -> SdkResult<()
validate_mcp_protocol_version(protocol_version_header)
}

pub fn accepts_event_stream(headers: &HeaderMap) -> bool {
pub(crate) fn accepts_event_stream(headers: &HeaderMap) -> bool {
let accept_header = headers
.get(ACCEPT)
.and_then(|val| val.to_str().ok())
Expand All @@ -564,7 +577,7 @@ pub fn accepts_event_stream(headers: &HeaderMap) -> bool {
.any(|val| val.trim().starts_with("text/event-stream"))
}

pub fn valid_streaming_http_accept_header(headers: &HeaderMap) -> bool {
pub(crate) fn valid_streaming_http_accept_header(headers: &HeaderMap) -> bool {
let accept_header = headers
.get(ACCEPT)
.and_then(|val| val.to_str().ok())
Expand Down Expand Up @@ -593,53 +606,6 @@ pub fn error_response(
.map_err(|err| TransportServerError::HttpError(err.to_string()))
}

// Protect against DNS rebinding attacks by validating Host and Origin headers.
pub(crate) async fn protect_dns_rebinding(
headers: &http::HeaderMap,
state: Arc<McpAppState>,
) -> Result<(), SdkError> {
if !state.needs_dns_protection() {
// If protection is not needed, pass the request to the next handler
return Ok(());
}

if let Some(allowed_hosts) = state.allowed_hosts.as_ref() {
if !allowed_hosts.is_empty() {
let Some(host) = headers.get(HOST).and_then(|h| h.to_str().ok()) else {
return Err(SdkError::bad_request().with_message("Invalid Host header: [unknown] "));
};

if !allowed_hosts
.iter()
.any(|allowed| allowed.eq_ignore_ascii_case(host))
{
return Err(SdkError::bad_request()
.with_message(format!("Invalid Host header: \"{host}\" ").as_str()));
}
}
}

if let Some(allowed_origins) = state.allowed_origins.as_ref() {
if !allowed_origins.is_empty() {
let Some(origin) = headers.get(ORIGIN).and_then(|h| h.to_str().ok()) else {
return Err(
SdkError::bad_request().with_message("Invalid Origin header: [unknown] ")
);
};

if !allowed_origins
.iter()
.any(|allowed| allowed.eq_ignore_ascii_case(origin))
{
return Err(SdkError::bad_request()
.with_message(format!("Invalid Origin header: \"{origin}\" ").as_str()));
}
}
}

Ok(())
}

/// Extracts the value of a query parameter from an HTTP request by key.
///
/// This function parses the query string from the request URI and searches
Expand All @@ -653,7 +619,7 @@ pub(crate) async fn protect_dns_rebinding(
/// * `Some(String)` containing the value of the query parameter if found.
/// * `None` if the query string is missing or the key is not present.
///
pub fn query_param(request: &http::Request<&str>, key: &str) -> Option<String> {
pub(crate) fn query_param(request: &http::Request<&str>, key: &str) -> Option<String> {
request.uri().query().and_then(|query| {
for pair in query.split('&') {
let mut split = pair.splitn(2, '=');
Expand Down
Loading