Skip to content

Commit 6f70e18

Browse files
authored
fix: issue 116 - custom_streamable_http_endpoint (#117)
* Issue #116: Fix custom_streamable_http_endpoint * Adding some basic unit tests for HyperServerOptions * Adding tempfile to properly test SSL validation in HyperServerOptions
1 parent 2688e1e commit 6f70e18

File tree

3 files changed

+219
-2
lines changed

3 files changed

+219
-2
lines changed

Cargo.lock

Lines changed: 50 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

crates/rust-mcp-sdk/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ reqwest = { workspace = true, default-features = false, features = [
4747
"cookies",
4848
"multipart",
4949
] }
50+
tempfile = "3.23.0"
5051
tracing-subscriber = { workspace = true, features = [
5152
"env-filter",
5253
"std",

crates/rust-mcp-sdk/src/hyper_servers/server.rs

Lines changed: 168 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,7 @@ impl HyperServerOptions {
200200
}
201201

202202
pub fn streamable_http_endpoint(&self) -> &str {
203-
self.custom_messages_endpoint
203+
self.custom_streamable_http_endpoint
204204
.as_deref()
205205
.unwrap_or(DEFAULT_STREAMABLE_HTTP_ENDPOINT)
206206
}
@@ -490,3 +490,170 @@ async fn shutdown_signal(handle: Handle, state: Arc<McpAppState>) {
490490
// Trigger graceful shutdown with a timeout
491491
handle.graceful_shutdown(Some(Duration::from_secs(GRACEFUL_SHUTDOWN_TMEOUT_SECS)));
492492
}
493+
494+
#[cfg(test)]
495+
mod tests {
496+
use super::*;
497+
498+
use tempfile::NamedTempFile;
499+
500+
#[test]
501+
fn test_server_options_base_url_custom() {
502+
let options = HyperServerOptions {
503+
host: String::from("127.0.0.1"),
504+
port: 8081,
505+
enable_ssl: true,
506+
..Default::default()
507+
};
508+
assert_eq!(options.base_url(), "https://127.0.0.1:8081");
509+
}
510+
511+
#[test]
512+
fn test_server_options_streamable_http_custom() {
513+
let options = HyperServerOptions {
514+
custom_streamable_http_endpoint: Some(String::from("/abcd/mcp")),
515+
host: String::from("127.0.0.1"),
516+
port: 8081,
517+
enable_ssl: true,
518+
..Default::default()
519+
};
520+
assert_eq!(
521+
options.streamable_http_url(),
522+
"https://127.0.0.1:8081/abcd/mcp"
523+
);
524+
assert_eq!(options.streamable_http_endpoint(), "/abcd/mcp");
525+
}
526+
527+
#[test]
528+
fn test_server_options_sse_custom() {
529+
let options = HyperServerOptions {
530+
custom_sse_endpoint: Some(String::from("/abcd/sse")),
531+
host: String::from("127.0.0.1"),
532+
port: 8081,
533+
enable_ssl: true,
534+
..Default::default()
535+
};
536+
assert_eq!(options.sse_url(), "https://127.0.0.1:8081/abcd/sse");
537+
assert_eq!(options.sse_endpoint(), "/abcd/sse");
538+
}
539+
540+
#[test]
541+
fn test_server_options_sse_messages_custom() {
542+
let options = HyperServerOptions {
543+
custom_messages_endpoint: Some(String::from("/abcd/messages")),
544+
..Default::default()
545+
};
546+
assert_eq!(
547+
options.sse_message_url(),
548+
"http://127.0.0.1:8080/abcd/messages"
549+
);
550+
assert_eq!(options.sse_messages_endpoint(), "/abcd/messages");
551+
}
552+
553+
#[test]
554+
fn test_server_options_needs_dns_protection() {
555+
let options = HyperServerOptions::default();
556+
557+
// should be false by default
558+
assert!(!options.needs_dns_protection());
559+
560+
// should still be false unless allowed_hosts or allowed_origins are also provided
561+
let options = HyperServerOptions {
562+
dns_rebinding_protection: true,
563+
..Default::default()
564+
};
565+
assert!(!options.needs_dns_protection());
566+
567+
// should be true when dns_rebinding_protection is true and allowed_hosts is provided
568+
let options = HyperServerOptions {
569+
dns_rebinding_protection: true,
570+
allowed_hosts: Some(vec![String::from("127.0.0.1")]),
571+
..Default::default()
572+
};
573+
assert!(options.needs_dns_protection());
574+
575+
// should be true when dns_rebinding_protection is true and allowed_origins is provided
576+
let options = HyperServerOptions {
577+
dns_rebinding_protection: true,
578+
allowed_origins: Some(vec![String::from("http://127.0.0.1:8080")]),
579+
..Default::default()
580+
};
581+
assert!(options.needs_dns_protection());
582+
}
583+
584+
#[test]
585+
fn test_server_options_validate() {
586+
let options = HyperServerOptions::default();
587+
assert!(options.validate().is_ok());
588+
589+
// with ssl enabled but no cert or key provided, validate should fail
590+
let options = HyperServerOptions {
591+
enable_ssl: true,
592+
..Default::default()
593+
};
594+
assert!(options.validate().is_err());
595+
596+
// with ssl enabled and invalid cert/key paths, validate should fail
597+
let options = HyperServerOptions {
598+
enable_ssl: true,
599+
ssl_cert_path: Some(String::from("/invalid/path/to/cert.pem")),
600+
ssl_key_path: Some(String::from("/invalid/path/to/key.pem")),
601+
..Default::default()
602+
};
603+
assert!(options.validate().is_err());
604+
605+
// with ssl enabled and valid cert/key paths, validate should succeed
606+
let cert_file =
607+
NamedTempFile::with_suffix(".pem").expect("Expected to create test cert file");
608+
let ssl_cert_path = cert_file
609+
.path()
610+
.to_str()
611+
.expect("Expected to get cert path")
612+
.to_string();
613+
let key_file =
614+
NamedTempFile::with_suffix(".pem").expect("Expected to create test key file");
615+
let ssl_key_path = key_file
616+
.path()
617+
.to_str()
618+
.expect("Expected to get key path")
619+
.to_string();
620+
621+
let options = HyperServerOptions {
622+
enable_ssl: true,
623+
ssl_cert_path: Some(ssl_cert_path),
624+
ssl_key_path: Some(ssl_key_path),
625+
..Default::default()
626+
};
627+
assert!(options.validate().is_ok());
628+
}
629+
630+
#[tokio::test]
631+
async fn test_server_options_resolve_server_address() {
632+
let options = HyperServerOptions::default();
633+
assert!(options.resolve_server_address().await.is_ok());
634+
635+
// valid host should still work
636+
let options = HyperServerOptions {
637+
host: String::from("8.6.7.5"),
638+
port: 309,
639+
..Default::default()
640+
};
641+
assert!(options.resolve_server_address().await.is_ok());
642+
643+
// valid host (prepended with http://) should still work
644+
let options = HyperServerOptions {
645+
host: String::from("http://8.6.7.5"),
646+
port: 309,
647+
..Default::default()
648+
};
649+
assert!(options.resolve_server_address().await.is_ok());
650+
651+
// invalid host should raise an error
652+
let options = HyperServerOptions {
653+
host: String::from("invalid-host"),
654+
port: 309,
655+
..Default::default()
656+
};
657+
assert!(options.resolve_server_address().await.is_err());
658+
}
659+
}

0 commit comments

Comments
 (0)