diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 8badef4..eb07df2 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -69,6 +69,7 @@ jobs: uses: actions-rust-lang/audit@v1 with: token: ${{ secrets.GITHUB_TOKEN }} + ignore: RUSTSEC-2023-0071 if: matrix.os == 'ubuntu-latest' # Run audit only on Linux # install nodejs that is required for tests diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 91033a8..89a7ac1 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -24,7 +24,7 @@ Anyone can participate at any stage, whether it's discussing, triaging, or revie ### **Filing a Bug Report** -When reporting a bug, use the provided issue template and fill in as many details as possible. Don’t worry if you can’t answer everything—just provide what you can. +When reporting a bug, use the provided issue template and fill in as many details as possible. Don’t worry if you can’t answer everything-just provide what you can. ### **Fixing Issues** diff --git a/Cargo.lock b/Cargo.lock index 3c18cf2..638b1ed 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -11,6 +11,15 @@ dependencies = [ "memchr", ] +[[package]] +name = "android_system_properties" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "819e7219dbd41043ac279b19830f2efc897156490d7fd6ea916720117ee66311" +dependencies = [ + "libc", +] + [[package]] name = "anyhow" version = "1.0.100" @@ -40,10 +49,21 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "81953c529336010edd6d8e358f886d9581267795c61b19475b71314bffa46d35" dependencies = [ "concurrent-queue", - "event-listener", + "event-listener 2.5.3", "futures-core", ] +[[package]] +name = "async-lock" +version = "3.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5fd03604047cee9b6ce9de9f70c6cd540a0520c813cbd49bae61f33ab80ed1dc" +dependencies = [ + "event-listener 5.4.1", + "event-listener-strategy", + "pin-project-lite", +] + [[package]] name = "async-trait" version = "0.1.89" @@ -69,21 +89,21 @@ checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8" [[package]] name = "aws-lc-rs" -version = "1.14.1" +version = "1.15.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "879b6c89592deb404ba4dc0ae6b58ffd1795c78991cbb5b8bc441c48a070440d" +checksum = "6b5ce75405893cd713f9ab8e297d8e438f624dde7d706108285f7e17a25a180f" dependencies = [ "aws-lc-sys", + "untrusted 0.7.1", "zeroize", ] [[package]] name = "aws-lc-sys" -version = "0.32.3" +version = "0.34.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "107a4e9d9cab9963e04e84bb8dee0e25f2a987f9a8bad5ed054abd439caa8f8c" +checksum = "179c3777a8b5e70e90ea426114ffc565b2c1a9f82f6c4a0c5a34aa6ef5e781b6" dependencies = [ - "bindgen", "cc", "cmake", "dunce", @@ -92,9 +112,9 @@ dependencies = [ [[package]] name = "axum" -version = "0.8.6" +version = "0.8.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8a18ed336352031311f4e0b4dd2ff392d4fbb370777c9d18d7fc9d7359f73871" +checksum = "5b098575ebe77cb6d14fc7f32749631a6e44edbef6b796f89b020e99ba20d425" dependencies = [ "axum-core", "bytes", @@ -103,7 +123,7 @@ dependencies = [ "http 1.3.1", "http-body 1.0.1", "http-body-util", - "hyper 1.7.0", + "hyper 1.8.1", "hyper-util", "itoa", "matchit", @@ -144,16 +164,16 @@ dependencies = [ [[package]] name = "axum-server" -version = "0.7.2" +version = "0.7.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "495c05f60d6df0093e8fb6e74aa5846a0ad06abaf96d76166283720bf740f8ab" +checksum = "c1ab4a3ec9ea8a657c72d99a03a824af695bd0fb5ec639ccbd9cd3543b41a5f9" dependencies = [ "arc-swap", "bytes", "fs-err", "http 1.3.1", "http-body 1.0.1", - "hyper 1.7.0", + "hyper 1.8.1", "hyper-util", "pin-project-lite", "rustls", @@ -183,24 +203,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6" [[package]] -name = "bindgen" -version = "0.72.1" +name = "base64ct" +version = "1.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "993776b509cfb49c750f11b8f07a46fa23e0a1386ffc01fb1e7d343efc387895" -dependencies = [ - "bitflags", - "cexpr", - "clang-sys", - "itertools", - "log", - "prettyplease", - "proc-macro2", - "quote", - "regex", - "rustc-hash", - "shlex", - "syn", -] +checksum = "55248b47b0caf0546f7988906588779981c43bb1bc9d0c44087278f80cdb44ba" [[package]] name = "bitflags" @@ -208,6 +214,15 @@ version = "2.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "812e12b5285cc515a9c72a5c1d3b6d46a19dac5acfef5265968c166106e31dd3" +[[package]] +name = "block-buffer" +version = "0.10.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3078c7629b62d3f0439517fa394996acacc5cbc91c5a20d8c658e77abd503a71" +dependencies = [ + "generic-array", +] + [[package]] name = "bumpalo" version = "3.19.0" @@ -216,15 +231,15 @@ checksum = "46c5e41b57b8bba42a04676d81cb89e9ee8e859a1a66f80a5a72e1cb76b34d43" [[package]] name = "bytes" -version = "1.10.1" +version = "1.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d71b6127be86fdcfddb610f7182ac57211d4b18a3e9c82eb2d17662f2227ad6a" +checksum = "b35204fbdc0b3f4446b89fc1ac2cf84a8a68971995d0bf2e925ec7cd960f9cb3" [[package]] name = "cc" -version = "1.2.45" +version = "1.2.47" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "35900b6c8d709fb1d854671ae27aeaa9eec2f8b01b364e1619a40da3e6fe2afe" +checksum = "cd405d82c84ff7f35739f175f67d8b9fb7687a0e84ccdc78bd3568839827cf07" dependencies = [ "find-msvc-tools", "jobserver", @@ -232,15 +247,6 @@ dependencies = [ "shlex", ] -[[package]] -name = "cexpr" -version = "0.6.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6fac387a98bb7c37292057cffc56d62ecb629900026402633ae9160df93a8766" -dependencies = [ - "nom", -] - [[package]] name = "cfg-if" version = "1.0.4" @@ -254,14 +260,17 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "613afe47fcd5fac7ccf1db93babcb082c5994d996f20b8b159f2ad1658eb5724" [[package]] -name = "clang-sys" -version = "1.8.1" +name = "chrono" +version = "0.4.42" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0b023947811758c97c59bf9d1c188fd619ad4718dcaa767947df1cadb14f39f4" +checksum = "145052bdd345b87320e369255277e3fb5152762ad123a901ef5c262dd38fe8d2" dependencies = [ - "glob", - "libc", - "libloading", + "iana-time-zone", + "js-sys", + "num-traits", + "serde", + "wasm-bindgen", + "windows-link", ] [[package]] @@ -291,6 +300,12 @@ dependencies = [ "crossbeam-utils", ] +[[package]] +name = "const-oid" +version = "0.9.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c2459377285ad874054d797f3ccebf984978aa39129f6eafde5cdc8315b612f8" + [[package]] name = "cookie" version = "0.18.1" @@ -320,12 +335,47 @@ dependencies = [ "url", ] +[[package]] +name = "core-foundation" +version = "0.9.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "91e195e091a93c46f7102ec7818a2aa394e1e1771c3ab4825963fa03e45afb8f" +dependencies = [ + "core-foundation-sys", + "libc", +] + +[[package]] +name = "core-foundation-sys" +version = "0.8.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b" + +[[package]] +name = "cpufeatures" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "59ed5838eebb26a2bb2e58f6d5b5316989ae9d08bab10e0e6d103e656d1b0280" +dependencies = [ + "libc", +] + [[package]] name = "crossbeam-utils" version = "0.8.21" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d0a5c400df2834b80a4c3327b3aad3a4c4cd4de0629063962b03235697506a28" +[[package]] +name = "crypto-common" +version = "0.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "78c8292055d1c1df0cce5d180393dc8cce0abec0a7102adb6c7b1eef6016d60a" +dependencies = [ + "generic-array", + "typenum", +] + [[package]] name = "deadpool" version = "0.9.5" @@ -345,6 +395,17 @@ version = "0.1.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "092966b41edc516079bdf31ec78a2e0588d1d0c08f78b91d8307215928642b2b" +[[package]] +name = "der" +version = "0.7.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e7c1832837b905bbfb5101e07cc24c8deddf52f93225eee6ead5f4d63d53ddcb" +dependencies = [ + "const-oid", + "pem-rfc7468", + "zeroize", +] + [[package]] name = "deranged" version = "0.5.5" @@ -354,6 +415,17 @@ dependencies = [ "powerfmt", ] +[[package]] +name = "digest" +version = "0.10.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292" +dependencies = [ + "block-buffer", + "const-oid", + "crypto-common", +] + [[package]] name = "displaydoc" version = "0.2.5" @@ -381,10 +453,13 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "92773504d58c093f6de2459af4af33faa518c13451eb8f2b5698ed3d36e7c813" [[package]] -name = "either" -version = "1.15.0" +name = "encoding_rs" +version = "0.8.35" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719" +checksum = "75030f3c4f45dafd7586dd6780965a8c7e8e285a5ecb86713e63a79c5b2766f3" +dependencies = [ + "cfg-if", +] [[package]] name = "equivalent" @@ -408,6 +483,27 @@ version = "2.5.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0206175f82b8d6bf6652ff7d71a1e27fd2e4efde587fd368662814d6ec1d9ce0" +[[package]] +name = "event-listener" +version = "5.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e13b66accf52311f30a0db42147dadea9850cb48cd070028831ae5f5d4b856ab" +dependencies = [ + "concurrent-queue", + "parking", + "pin-project-lite", +] + +[[package]] +name = "event-listener-strategy" +version = "0.5.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8be9f3dfaaffdae2972880079a491a1a8bb7cbed0b8dd7a347f668b4150a3b93" +dependencies = [ + "event-listener 5.4.1", + "pin-project-lite", +] + [[package]] name = "fastrand" version = "1.9.0" @@ -425,9 +521,9 @@ checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be" [[package]] name = "find-msvc-tools" -version = "0.1.4" +version = "0.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "52051878f80a721bb68ebfbc930e07b65ba72f2da88968ea5c06fd6ca3d3a127" +checksum = "3a3076410a55c90011c298b04d0cfa770b00fa04e1e3c97d3f6c9de105a03844" [[package]] name = "fnv" @@ -435,6 +531,21 @@ version = "1.0.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" +[[package]] +name = "foreign-types" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f6f339eb8adc052cd2ca78910fda869aefa38d22d5cb648e6485e4d3fc06f3b1" +dependencies = [ + "foreign-types-shared", +] + +[[package]] +name = "foreign-types-shared" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "00b0228411908ca8685dba7fc2cdd70ec9990a6e753e89b6ac91a84c40fbaf4b" + [[package]] name = "form_urlencoded" version = "1.2.2" @@ -446,9 +557,9 @@ dependencies = [ [[package]] name = "fs-err" -version = "3.1.3" +version = "3.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6ad492b2cf1d89d568a43508ab24f98501fe03f2f31c01e1d0fe7366a71745d2" +checksum = "62d91fd049c123429b018c47887d3f75a265540dd3c30ba9cb7bae9197edb03a" dependencies = [ "autocfg", "tokio", @@ -570,6 +681,16 @@ dependencies = [ "slab", ] +[[package]] +name = "generic-array" +version = "0.14.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85649ca51fd72272d7821adaf274ad91c288277713d9c18820d8499a7ff69e9a" +dependencies = [ + "typenum", + "version_check", +] + [[package]] name = "getrandom" version = "0.1.16" @@ -608,12 +729,6 @@ dependencies = [ "wasm-bindgen", ] -[[package]] -name = "glob" -version = "0.3.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0cc23270f6e1808e30a928bdc84dea0b9b4136a8bc82338574f23baf47bbd280" - [[package]] name = "h2" version = "0.3.27" @@ -654,9 +769,9 @@ dependencies = [ [[package]] name = "hashbrown" -version = "0.16.0" +version = "0.16.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5419bdc4f6a9207fbeba6d11b604d481addf78ecd10c11ad51e76c2f6482748d" +checksum = "841d1cc9bed7f9236f321df977030373f4a4163ae1a7dbfe1a51a2c1a51d9100" [[package]] name = "hello-world-mcp-server-stdio" @@ -833,9 +948,9 @@ dependencies = [ [[package]] name = "hyper" -version = "1.7.0" +version = "1.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eb3aa54a13a0dfe7fbe3a59e0c76093041720fdc77b110cc0fc260fafb4dc51e" +checksum = "2ab2d4f250c3d7b1c9fcdff1cece94ea4e2dfbec68614f7b87cb205f24ca9d11" dependencies = [ "atomic-waker", "bytes", @@ -861,7 +976,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e3c93eb611681b207e1fe55d5a71ecf91572ec8a6705cdb6857f7d8d5242cf58" dependencies = [ "http 1.3.1", - "hyper 1.7.0", + "hyper 1.8.1", "hyper-util", "rustls", "rustls-pki-types", @@ -871,11 +986,27 @@ dependencies = [ "webpki-roots", ] +[[package]] +name = "hyper-tls" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "70206fc6890eaca9fde8a0bf71caa2ddfc9fe045ac9e5c70df101a7dbde866e0" +dependencies = [ + "bytes", + "http-body-util", + "hyper 1.8.1", + "hyper-util", + "native-tls", + "tokio", + "tokio-native-tls", + "tower-service", +] + [[package]] name = "hyper-util" -version = "0.1.17" +version = "0.1.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3c6995591a8f1380fcb4ba966a252a4b29188d51d2b89e3a252f5305be65aea8" +checksum = "52e9a2a24dc5c6821e71a7030e1e14b7b632acac55c40e9d2e082c621261bb56" dependencies = [ "base64 0.22.1", "bytes", @@ -884,15 +1015,41 @@ dependencies = [ "futures-util", "http 1.3.1", "http-body 1.0.1", - "hyper 1.7.0", + "hyper 1.8.1", "ipnet", "libc", "percent-encoding", "pin-project-lite", "socket2 0.6.1", + "system-configuration", "tokio", "tower-service", "tracing", + "windows-registry", +] + +[[package]] +name = "iana-time-zone" +version = "0.1.64" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "33e57f83510bb73707521ebaffa789ec8caf86f9657cad665b092b581d40e9fb" +dependencies = [ + "android_system_properties", + "core-foundation-sys", + "iana-time-zone-haiku", + "js-sys", + "log", + "wasm-bindgen", + "windows-core", +] + +[[package]] +name = "iana-time-zone-haiku" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f31827a206f56af32e590ba56d5d2d085f558508192593743f16b2306495269f" +dependencies = [ + "cc", ] [[package]] @@ -999,9 +1156,9 @@ dependencies = [ [[package]] name = "indexmap" -version = "2.12.0" +version = "2.12.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6717a8d2a5a929a1a2eb43a12812498ed141a0bcfb7e8f7844fbdbe4303bba9f" +checksum = "0ad4bb2b565bca0645f4d68c5c9af97fba094e9791da685bf83cb5f3ce74acf2" dependencies = [ "equivalent", "hashbrown", @@ -1038,15 +1195,6 @@ dependencies = [ "serde", ] -[[package]] -name = "itertools" -version = "0.13.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "413ee7dfc52ee1a4949ceeb7dbc8a33f2d6c088194d9f922fb8318faf1f01186" -dependencies = [ - "either", -] - [[package]] name = "itoa" version = "1.0.15" @@ -1073,11 +1221,31 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "jsonwebtoken" +version = "10.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c76e1c7d7df3e34443b3621b459b066a7b79644f059fc8b2db7070c825fd417e" +dependencies = [ + "aws-lc-rs", + "base64 0.22.1", + "getrandom 0.2.16", + "js-sys", + "pem", + "serde", + "serde_json", + "signature", + "simple_asn1", +] + [[package]] name = "lazy_static" version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe" +dependencies = [ + "spin", +] [[package]] name = "libc" @@ -1085,16 +1253,6 @@ version = "0.2.177" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2874a2af47a2325c2001a6e6fad9b16a53b802102b528163885171cf92b15976" -[[package]] -name = "libloading" -version = "0.8.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d7c4b02199fee7c5d21a5ae7d8cfa79a6ef5bb2fc834d6e9058e89c825efdc55" -dependencies = [ - "cfg-if", - "windows-link", -] - [[package]] name = "libm" version = "0.2.15" @@ -1177,12 +1335,6 @@ dependencies = [ "unicase", ] -[[package]] -name = "minimal-lexical" -version = "0.2.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a" - [[package]] name = "mio" version = "1.1.0" @@ -1204,13 +1356,20 @@ dependencies = [ ] [[package]] -name = "nom" -version = "7.1.3" +name = "native-tls" +version = "0.2.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d273983c5a657a70a3e8f2a01329822f3b8c8172b73826411a55751e404a0a4a" +checksum = "87de3442987e9dbec73158d5c715e7ad9072fda936bb03d19d7fa10e00520f0e" dependencies = [ - "memchr", - "minimal-lexical", + "libc", + "log", + "openssl", + "openssl-probe", + "openssl-sys", + "schannel", + "security-framework", + "security-framework-sys", + "tempfile", ] [[package]] @@ -1222,12 +1381,58 @@ dependencies = [ "windows-sys 0.61.2", ] +[[package]] +name = "num-bigint" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a5e44f723f1133c9deac646763579fdb3ac745e418f2a7af9cd0c431da1f20b9" +dependencies = [ + "num-integer", + "num-traits", +] + +[[package]] +name = "num-bigint-dig" +version = "0.8.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e661dda6640fad38e827a6d4a310ff4763082116fe217f279885c97f511bb0b7" +dependencies = [ + "lazy_static", + "libm", + "num-integer", + "num-iter", + "num-traits", + "rand 0.8.5", + "smallvec", + "zeroize", +] + [[package]] name = "num-conv" version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "51d515d32fb182ee37cda2ccdcb92950d6a3c2893aa280e540671c2cd0f3b1d9" +[[package]] +name = "num-integer" +version = "0.1.46" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f" +dependencies = [ + "num-traits", +] + +[[package]] +name = "num-iter" +version = "0.1.45" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1429034a0490724d0075ebb2bc9e875d6503c3cf69e235a8941aa757d83ef5bf" +dependencies = [ + "autocfg", + "num-integer", + "num-traits", +] + [[package]] name = "num-traits" version = "0.2.19" @@ -1248,12 +1453,84 @@ dependencies = [ "libc", ] +[[package]] +name = "oauth2-test-server" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1bb78cf155f91eba1d99533e49aafc31f5e7e42b9964d2c0c8470d6641accb54" +dependencies = [ + "axum", + "base64 0.21.7", + "chrono", + "colored", + "futures", + "http 1.3.1", + "jsonwebtoken", + "once_cell", + "rand 0.8.5", + "reqwest", + "rsa", + "serde", + "serde_json", + "sha2", + "tokio", + "tower-http 0.5.2", + "tracing", + "tracing-subscriber", + "url", + "uuid", +] + [[package]] name = "once_cell" version = "1.21.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d" +[[package]] +name = "openssl" +version = "0.10.75" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "08838db121398ad17ab8531ce9de97b244589089e290a384c900cb9ff7434328" +dependencies = [ + "bitflags", + "cfg-if", + "foreign-types", + "libc", + "once_cell", + "openssl-macros", + "openssl-sys", +] + +[[package]] +name = "openssl-macros" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "openssl-probe" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d05e27ee213611ffe7d6348b942e8f942b37114c00cc03cec254295a4a17852e" + +[[package]] +name = "openssl-sys" +version = "0.9.111" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "82cab2d520aa75e3c58898289429321eb788c3106963d0dc886ec7a5f4adc321" +dependencies = [ + "cc", + "libc", + "pkg-config", + "vcpkg", +] + [[package]] name = "parking" version = "2.2.1" @@ -1283,6 +1560,25 @@ dependencies = [ "windows-link", ] +[[package]] +name = "pem" +version = "3.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d30c53c26bc5b31a98cd02d20f25a7c8567146caf63ed593a9d87b2775291be" +dependencies = [ + "base64 0.22.1", + "serde_core", +] + +[[package]] +name = "pem-rfc7468" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "88b39c9bfcfc231068454382784bb460aae594343fb030d46e9f50a645418412" +dependencies = [ + "base64ct", +] + [[package]] name = "percent-encoding" version = "2.3.2" @@ -1301,6 +1597,33 @@ version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" +[[package]] +name = "pkcs1" +version = "0.7.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c8ffb9f10fa047879315e6625af03c164b16962a5368d724ed16323b68ace47f" +dependencies = [ + "der", + "pkcs8", + "spki", +] + +[[package]] +name = "pkcs8" +version = "0.10.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f950b2377845cebe5cf8b5165cb3cc1a5e0fa5cfa3e1f7f55707d8fd82e0a7b7" +dependencies = [ + "der", + "spki", +] + +[[package]] +name = "pkg-config" +version = "0.3.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7edddbd0b52d732b21ad9a5fab5c704c14cd949e5e9a1ec5929a24fded1b904c" + [[package]] name = "potential_utf" version = "0.1.4" @@ -1325,16 +1648,6 @@ dependencies = [ "zerocopy", ] -[[package]] -name = "prettyplease" -version = "0.2.37" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "479ca8adacdd7ce8f1fb39ce9ecccbfe93a3f1344b3d0d97f20bc0196208f62b" -dependencies = [ - "proc-macro2", - "syn", -] - [[package]] name = "proc-macro2" version = "1.0.103" @@ -1588,17 +1901,22 @@ dependencies = [ "bytes", "cookie", "cookie_store", + "encoding_rs", "futures-core", "futures-util", + "h2 0.4.12", "http 1.3.1", "http-body 1.0.1", "http-body-util", - "hyper 1.7.0", + "hyper 1.8.1", "hyper-rustls", + "hyper-tls", "hyper-util", "js-sys", "log", + "mime", "mime_guess", + "native-tls", "percent-encoding", "pin-project-lite", "quinn", @@ -1609,10 +1927,11 @@ dependencies = [ "serde_urlencoded", "sync_wrapper", "tokio", + "tokio-native-tls", "tokio-rustls", "tokio-util", "tower", - "tower-http", + "tower-http 0.6.6", "tower-service", "url", "wasm-bindgen", @@ -1638,20 +1957,54 @@ dependencies = [ "cfg-if", "getrandom 0.2.16", "libc", - "untrusted", + "untrusted 0.9.0", "windows-sys 0.52.0", ] +[[package]] +name = "rsa" +version = "0.9.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "40a0376c50d0358279d9d643e4bf7b7be212f1f4ff1da9070a7b54d22ef75c88" +dependencies = [ + "const-oid", + "digest", + "num-bigint-dig", + "num-integer", + "num-traits", + "pkcs1", + "pkcs8", + "rand_core 0.6.4", + "signature", + "spki", + "subtle", + "zeroize", +] + [[package]] name = "rust-mcp-extra" version = "0.1.3" dependencies = [ + "async-lock", + "async-trait", "base64 0.22.1", + "bytes", + "http 1.3.1", + "http-body 1.0.1", + "http-body-util", "nanoid", + "oauth2-test-server", "once_cell", "rand 0.9.2", "rand_distr", + "reqwest", "rust-mcp-sdk", + "serde", + "serde_json", + "tokio", + "tracing", + "tracing-subscriber", + "url", ] [[package]] @@ -1689,7 +2042,8 @@ dependencies = [ "http 1.3.1", "http-body 1.0.1", "http-body-util", - "hyper 1.7.0", + "hyper 1.8.1", + "jsonwebtoken", "reqwest", "rust-mcp-macros", "rust-mcp-schema", @@ -1702,6 +2056,7 @@ dependencies = [ "tokio-stream", "tracing", "tracing-subscriber", + "url", "uuid", "wiremock", ] @@ -1786,7 +2141,7 @@ dependencies = [ "aws-lc-rs", "ring", "rustls-pki-types", - "untrusted", + "untrusted 0.9.0", ] [[package]] @@ -1801,12 +2156,44 @@ version = "1.0.20" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "28d3b2b1366ec20994f1fd18c3c594f05c5dd4bc44d8bb0c1c632c8d6829481f" +[[package]] +name = "schannel" +version = "0.1.28" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "891d81b926048e76efe18581bf793546b4c0eaf8448d72be8de2bbee5fd166e1" +dependencies = [ + "windows-sys 0.61.2", +] + [[package]] name = "scopeguard" version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" +[[package]] +name = "security-framework" +version = "2.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "897b2245f0b511c87893af39b033e5ca9cce68824c4d7e7630b5a1d339658d02" +dependencies = [ + "bitflags", + "core-foundation", + "core-foundation-sys", + "libc", + "security-framework-sys", +] + +[[package]] +name = "security-framework-sys" +version = "2.15.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cc1f0cbffaac4852523ce30d8bd3c5cdc873501d96ff467ca09b6767bb8cd5c0" +dependencies = [ + "core-foundation-sys", + "libc", +] + [[package]] name = "serde" version = "1.0.228" @@ -1884,6 +2271,32 @@ dependencies = [ "serde", ] +[[package]] +name = "server-oauth-remote" +version = "0.1.34" +dependencies = [ + "async-trait", + "futures", + "rust-mcp-extra", + "rust-mcp-sdk", + "serde", + "serde_json", + "tokio", + "tracing", + "tracing-subscriber", +] + +[[package]] +name = "sha2" +version = "0.10.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a7507d819769d01a365ab707794a4084392c824f54a7a6a7862f8c3d0892b283" +dependencies = [ + "cfg-if", + "cpufeatures", + "digest", +] + [[package]] name = "sharded-slab" version = "0.1.7" @@ -1901,13 +2314,23 @@ checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" [[package]] name = "signal-hook-registry" -version = "1.4.6" +version = "1.4.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b2a4719bff48cee6b39d12c020eeb490953ad2443b7055bd0b21fca26bd8c28b" +checksum = "7664a098b8e616bdfcc2dc0e9ac44eb231eedf41db4e9fe95d8d32ec728dedad" dependencies = [ "libc", ] +[[package]] +name = "signature" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "77549399552de45a898a580c1b41d445bf730df867cc44e6c0233bbc4b8329de" +dependencies = [ + "digest", + "rand_core 0.6.4", +] + [[package]] name = "simple-mcp-client-sse" version = "0.1.26" @@ -2000,6 +2423,18 @@ dependencies = [ "tracing-subscriber", ] +[[package]] +name = "simple_asn1" +version = "0.6.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "297f631f50729c8c99b84667867963997ec0b50f32b2a7dbcab828ef0541e8bb" +dependencies = [ + "num-bigint", + "num-traits", + "thiserror 2.0.17", + "time", +] + [[package]] name = "slab" version = "0.4.11" @@ -2032,6 +2467,22 @@ dependencies = [ "windows-sys 0.60.2", ] +[[package]] +name = "spin" +version = "0.9.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67" + +[[package]] +name = "spki" +version = "0.7.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d91ed6c858b01f942cd56b37a94b3e0a1798290327d1236e4d9cf4eaca44d29d" +dependencies = [ + "base64ct", + "der", +] + [[package]] name = "stable_deref_trait" version = "1.2.1" @@ -2046,9 +2497,9 @@ checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" [[package]] name = "syn" -version = "2.0.109" +version = "2.0.111" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2f17c7e013e88258aa9543dcbe81aca68a667a9ac37cd69c9fbc07858bfe0e2f" +checksum = "390cc9a294ab71bdb1aa2e99d13be9c753cd2d7bd6560c77118597410c4d2e87" dependencies = [ "proc-macro2", "quote", @@ -2075,6 +2526,27 @@ dependencies = [ "syn", ] +[[package]] +name = "system-configuration" +version = "0.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3c879d448e9d986b661742763247d3693ed13609438cf3d006f51f5368a5ba6b" +dependencies = [ + "bitflags", + "core-foundation", + "system-configuration-sys", +] + +[[package]] +name = "system-configuration-sys" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e1d1b10ced5ca923a1fcb8d03e96b8d3268065d724548c0211415ff6ac6bac4" +dependencies = [ + "core-foundation-sys", + "libc", +] + [[package]] name = "tempfile" version = "3.23.0" @@ -2221,6 +2693,16 @@ dependencies = [ "syn", ] +[[package]] +name = "tokio-native-tls" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bbae76ab933c85776efabc971569dd6119c580d8f5d448769dec1764bf796ef2" +dependencies = [ + "native-tls", + "tokio", +] + [[package]] name = "tokio-rustls" version = "0.26.4" @@ -2271,6 +2753,23 @@ dependencies = [ "tracing", ] +[[package]] +name = "tower-http" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e9cd434a998747dd2c4276bc96ee2e0c7a2eadf3cae88e52be55a05fa9053f5" +dependencies = [ + "bitflags", + "bytes", + "http 1.3.1", + "http-body 1.0.1", + "http-body-util", + "pin-project-lite", + "tower-layer", + "tower-service", + "tracing", +] + [[package]] name = "tower-http" version = "0.6.6" @@ -2369,6 +2868,12 @@ version = "0.2.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b" +[[package]] +name = "typenum" +version = "1.19.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "562d481066bde0658276a35467c4af00bdc6ee726305698a55b86e61d7ad82bb" + [[package]] name = "unicase" version = "2.8.1" @@ -2381,6 +2886,12 @@ version = "1.0.22" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9312f7c4f6ff9069b165498234ce8be658059c6728633667c526e27dc2cf1df5" +[[package]] +name = "untrusted" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a156c684c91ea7d62626509bce3cb4e1d9ed5c4d978f7b4352658f96a4c26b4a" + [[package]] name = "untrusted" version = "0.9.0" @@ -2422,6 +2933,12 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ba73ea9cf16a25df0c8caa16c51acb937d5712a8429db78a3ee29d5dcacd3a65" +[[package]] +name = "vcpkg" +version = "0.2.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426" + [[package]] name = "version_check" version = "0.9.5" @@ -2564,12 +3081,76 @@ dependencies = [ "rustls-pki-types", ] +[[package]] +name = "windows-core" +version = "0.62.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b8e83a14d34d0623b51dce9581199302a221863196a1dde71a7663a4c2be9deb" +dependencies = [ + "windows-implement", + "windows-interface", + "windows-link", + "windows-result", + "windows-strings", +] + +[[package]] +name = "windows-implement" +version = "0.60.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "053e2e040ab57b9dc951b72c264860db7eb3b0200ba345b4e4c3b14f67855ddf" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "windows-interface" +version = "0.59.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f316c4a2570ba26bbec722032c4099d8c8bc095efccdc15688708623367e358" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "windows-link" version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f0805222e57f7521d6a62e36fa9163bc891acd422f971defe97d64e70d0a4fe5" +[[package]] +name = "windows-registry" +version = "0.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "02752bf7fbdcce7f2a27a742f798510f3e5ad88dbe84871e5168e2120c3d5720" +dependencies = [ + "windows-link", + "windows-result", + "windows-strings", +] + +[[package]] +name = "windows-result" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7781fa89eaf60850ac3d2da7af8e5242a5ea78d1a11c49bf2910bb5a73853eb5" +dependencies = [ + "windows-link", +] + +[[package]] +name = "windows-strings" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7837d08f69c77cf6b07689544538e017c1bfcf57e34b4c0ff58e6c2cd3b37091" +dependencies = [ + "windows-link", +] + [[package]] name = "windows-sys" version = "0.52.0" @@ -2794,18 +3375,18 @@ dependencies = [ [[package]] name = "zerocopy" -version = "0.8.27" +version = "0.8.28" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0894878a5fa3edfd6da3f88c4805f4c8558e2b996227a3d864f47fe11e38282c" +checksum = "43fa6694ed34d6e57407afbccdeecfa268c470a7d2a5b0cf49ce9fcc345afb90" dependencies = [ "zerocopy-derive", ] [[package]] name = "zerocopy-derive" -version = "0.8.27" +version = "0.8.28" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "88d2b8d9c68ad2b9e4340d7832716a4d21a22a1154777ad56ea55c51a9cf3831" +checksum = "c640b22cd9817fae95be82f0d2f90b11f7605f6c319d16705c459b27ac2cbc26" dependencies = [ "proc-macro2", "quote", diff --git a/Cargo.toml b/Cargo.toml index faba18a..9695021 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,6 +15,7 @@ members = [ "examples/simple-mcp-client-sse-core", "examples/simple-mcp-client-streamable-http", "examples/simple-mcp-client-streamable-http-core", + "examples/auth/server-oauth-remote" ] @@ -60,6 +61,10 @@ reqwest = { version = "0.12", default-features = false, features = [ "multipart", ] } bytes = "1.10" +url = {version="2.5"} +http = { version ="1.3" } +http-body-util = { version ="0.1" } +http-body = { version ="1.0" } # [workspace.dependencies.windows] diff --git a/Makefile.toml b/Makefile.toml index 7362412..173e5c1 100644 --- a/Makefile.toml +++ b/Makefile.toml @@ -41,8 +41,13 @@ command = "cargo" args = ["test", "--doc", "-p", "rust-mcp-macros"] +[tasks.doc-strict] +command = "cargo" +args = ["doc", "--no-deps"] +env = { RUSTDOCFLAGS = "-D warnings" } + [tasks.check] -dependencies = ["fmt", "clippy", "test", "doc-test"] +dependencies = ["fmt", "clippy", "test", "doc-strict", "doc-test"] [tasks.clippy-fix] command = "cargo" diff --git a/README.md b/README.md index 2c70c3e..d92d964 100644 --- a/README.md +++ b/README.md @@ -21,37 +21,35 @@ Leveraging the [rust-mcp-schema](https://github.com/rust-mcp-stack/rust-mcp-sche **rust-mcp-sdk** supports all three official versions of the MCP protocol. By default, it uses the **2025-06-18** version, but earlier versions can be enabled via Cargo features. - - -This project supports following transports: -- **Stdio** (Standard Input/Output) -- **Streamable HTTP** -- **SSE** (Server-Sent Events) - - 🚀 The **rust-mcp-sdk** includes a lightweight [Axum](https://github.com/tokio-rs/axum) based server that handles all core functionality seamlessly. Switching between `stdio` and `Streamable HTTP` is straightforward, requiring minimal code changes. The server is designed to efficiently handle multiple concurrent client connections and offers built-in support for SSL. -**MCP Streamable HTTP Support** -- ✅ Streamable HTTP Support for MCP Servers +**Features** +- ✅ Stdio, SSE and Streamable HTTP Support +- ✅ Supports multiple MCP protocol versions - ✅ DNS Rebinding Protection - ✅ Batch Messages - ✅ Streaming & non-streaming JSON response -- ✅ Streamable HTTP Support for MCP Clients - ✅ Resumability -- ⬜ Oauth Authentication +- ✅ OAuth Authentication for MCP Servers + - ✅ [Remote Oauth Provider](crates/rust-mcp-sdk/src/auth/auth_provider/remote_auth_provider.rs) (for any provider with DCR support) + - ✅ **Keycloak** Provider (via [rust-mcp-extra](crates/rust-mcp-extra/README.md#keycloak)) + - ✅ **WorkOS** Authkit Provider (via [rust-mcp-extra](crates/rust-mcp-extra/README.md#workos-authkit)) + - ✅ **Scalekit** Authkit Provider (via [rust-mcp-extra](crates/rust-mcp-extra/README.md#scalekit)) +- ⬜ OAuth Authentication for MCP Clients **⚠️** Project is currently under development and should be used at your own risk. ## Table of Contents +- [Getting Started](#getting-started) - [Usage Examples](#usage-examples) - [MCP Server (stdio)](#mcp-server-stdio) - [MCP Server (Streamable HTTP)](#mcp-server-streamable-http) - [MCP Client (stdio)](#mcp-client-stdio) - - [MCP Client (Streamable HTTP)](#mcp-client_streamable-http)) + - [MCP Client (Streamable HTTP)](#mcp-client-streamable-http) - [MCP Client (sse)](#mcp-client-sse) +- [Authentication](#authentication) - [Macros](#macros) -- [Getting Started](#getting-started) - [HyperServerOptions](#hyperserveroptions) - [Security Considerations](#security-considerations) - [Cargo features](#cargo-features) @@ -68,6 +66,12 @@ This project supports following transports: - [Development](#development) - [License](#license) + +## Getting Started + +If you are looking for a step-by-step tutorial on how to get started with `rust-mcp-sdk` , please see : [Getting Started MCP Server](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/doc/getting-started-mcp-server.md) + + ## Usage Examples ### MCP Server (stdio) @@ -387,6 +391,26 @@ Creating an MCP client using the `rust-mcp-sdk` with the SSE transport is almost 👉 see [examples/simple-mcp-client-sse](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/simple-mcp-client-sse) for a complete working example. +## Authentication +MCP server can verify tokens issued by other systems, integrate with external identity providers, or manage the entire authentication process itself. Each option offers a different balance of simplicity, security, and control. + + ### RemoteAuthProvider + [RemoteAuthProvider](src/mcp_http/auth/auth_provider/remote_auth_provider.rs) RemoteAuthProvider enables authentication with identity providers that support Dynamic Client Registration (DCR) such as KeyCloak and WorkOS AuthKit, letting MCP clients auto-register and obtain credentials without manual setup. + +👉 See the [server-oauth-remote](examples/auth/server-oauth-remote) example for how to use RemoteAuthProvider with a DCR-capable remote provider. + +👉 [rust-mcp-extra](https://crates.io/crates/rust-mcp-extra) also offers drop-in auth providers for common identity platforms, working seamlessly with rust-mcp-sdk: + - [Keycloack auth example](crates/rust-mcp-extra/README.md#keycloak) + - [WorkOS autn example](crates/rust-mcp-extra/README.md#workos-authkit) + + + + ### OAuthProxy + OAuthProxy enables authentication with OAuth providers that don’t support Dynamic Client Registration (DCR).It accepts any client registration request, handles the DCR on your server side and then uses your pre-registered app credentials upstream.The proxy also forwards callbacks, allowing dynamic redirect URIs to work with providers that require fixed ones. + +> ⚠️ OAuthProxy support is still in development, please use RemoteAuthProvider for now. + + ## Macros [rust-mcp-sdk](https://github.com/rust-mcp-stack/rust-mcp-sdk) includes several helpful macros that simplify common tasks when building MCP servers and clients. For example, they can automatically generate tool specifications and tool schemas right from your structs, or assist with elicitation requests and responses making them completely type safe. @@ -495,10 +519,6 @@ let user_info = UserInfo::from_content_map(result.content)?; 💻 For mre info please see : - https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/crates/rust-mcp-macros -## Getting Started - -If you are looking for a step-by-step tutorial on how to get started with `rust-mcp-sdk` , please see : [Getting Started MCP Server](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/doc/getting-started-mcp-server.md) - ## HyperServerOptions HyperServer is a lightweight Axum-based server that streamlines MCP servers by supporting **Streamable HTTP** and **SSE** transports. It supports simultaneous client connections, internal session management, and includes built-in security features like DNS rebinding protection and more. @@ -588,6 +608,9 @@ pub struct HyperServerOptions { /// Optional custom path for the MCP messages endpoint for sse (default: `/messages`) /// Applicable only if sse_support is true pub custom_messages_endpoint: Option, + + /// Optional authentication provider for protecting MCP server. + pub auth: Option>, } ``` @@ -625,7 +648,7 @@ The `rust-mcp-sdk` crate provides several features that can be enabled or disabl - `2025_03_26` : Activates MCP Protocol version 2025-03-26 - `2024_11_05` : Activates MCP Protocol version 2024-11-05 -> Note: MCP protocol versions are mutually exclusive—only one can be active at any given time. +> Note: MCP protocol versions are mutually exclusive-only one can be active at any given time. ### Default Features diff --git a/assets/examples/mcp-remote-oauth.gif b/assets/examples/mcp-remote-oauth.gif new file mode 100644 index 0000000..87eb37f Binary files /dev/null and b/assets/examples/mcp-remote-oauth.gif differ diff --git a/crates/rust-mcp-extra/Cargo.toml b/crates/rust-mcp-extra/Cargo.toml index b2dfc09..ca6acdf 100644 --- a/crates/rust-mcp-extra/Cargo.toml +++ b/crates/rust-mcp-extra/Cargo.toml @@ -13,21 +13,43 @@ rust-version = { workspace = true } exclude = ["assets/", "tests/"] [dependencies] -rust-mcp-sdk = { version = "0.7.3" , path = "../rust-mcp-sdk", default-features = false, features=["server","2025_06_18"] } - +rust-mcp-sdk = { version = "0.7.3" , path = "../rust-mcp-sdk", default-features = false, features=["server","2025_06_18","auth","hyper-server","macros"] } base64 = {workspace = true, optional=true} +url= {workspace = true, optional=true} nanoid = {version="0.4", optional=true} once_cell = {version="1.2", optional=true} rand = {version="0.9.2", features = ["std", "alloc"] , optional=true} rand_distr = {version="0.5.1", optional=true} +reqwest = { workspace = true, default-features = false, features = [ + "rustls-tls", + "json", + "cookies", +], optional = true } +async-lock = {version="3.4.1", optional=true} +async-trait = {workspace = true, optional=true} +serde = { workspace = true, optional=true } +serde_json = { workspace = true, optional=true } +http = {workspace=true, optional=true } +http-body-util = { workspace = true, optional = true } +http-body = { workspace = true, optional = true } +bytes = {workspace=true, optional=true } +tracing = { workspace = true } + + +[dev-dependencies] +oauth2-test-server = "0.1" +tokio={ workspace = true} +tracing-subscriber = { workspace = true, features = ["env-filter"] } [features] -default = ["nano_id","snowflake_id","random_62_id","time_64_id"] +default = ["auth","nano_id","snowflake_id","random_62_id","time_64_id"] nano_id = ["nanoid"] snowflake_id = ["once_cell"] random_62_id = ["rand","rand_distr"] time_64_id = ["base64"] +auth=["url","reqwest","async-lock","async-trait","rust-mcp-sdk/auth","rust-mcp-sdk/hyper-server" +,"rust-mcp-sdk/sse", "serde","serde_json","http","bytes","http-body","http-body-util"] [lints] workspace = true diff --git a/crates/rust-mcp-extra/README.md b/crates/rust-mcp-extra/README.md index 2adf1d4..aeaad82 100644 --- a/crates/rust-mcp-extra/README.md +++ b/crates/rust-mcp-extra/README.md @@ -1,10 +1,120 @@ # rust-mcp-extra -**A companion crate to [`rust-mcp-sdk`](https://github.com/rust-mcp-stack/rust-mcp-sdk) providing additional implementations for core traits like `IdGenerator`, `SessionStore` and `EventStore`.** +A companion crate to [`rust-mcp-sdk`](https://github.com/rust-mcp-stack/rust-mcp-sdk) providing additional implementations for core traits like `AuthProvider`, `IdGenerator`, `SessionStore` and `EventStore`. + +## 📖 Table of Contents +- **[Authentication Providers](#-authentication-providers)** + - [Keycloak](#keycloak) + - [WorkOs Authkit](#workos-authkit) + - [Scalekit](#scalekit) +- **[ID Generators](#-id-generators)** + - [NanoIdGenerator](#nanoidgenerator) + - [TimeBase64Generator](#timebase64generator) + - [RandomBase62Generator](#randombase62generator) + - [SnowflakeIdGenerator](#snowflakeidgenerator) +- **[Session Stores](#-session-stores)** + - 🔜 Coming Soon +- **[Event Stores](#-event-stores)** + - 🔜 Coming Soon + ----- +## 🔐 Authentication Providers +A collection of authentication providers that integrate seamlessly with the [rust-mcp-sdk](https://github.com/rust-mcp-stack/rust-mcp-sdk). +These providers offer a ready-to-use integration with common identity systems, so developers don’t have to build an AuthProvider implementation for each provider themselves. + + +### **Keycloak** +A full OAuth2/OpenID Connect provider integration for [Keycloak](https://www.keycloak.org) backed identity systems. +Useful for enterprise environments or self-hosted identity setups. + +- Example usage: + +```rs +let auth_provider = KeycloakAuthProvider::new(KeycloakAuthOptions { + keycloak_base_url: "http://localhost:8080/realms/master".to_string(), + mcp_server_url: "http://localhost:3000".to_string(), + resource_name: Some("Keycloak Oauth Test MCP Server".to_string()), + required_scopes: None, + client_id: "keycloak-client-id".to_string(), + client_secret: "keycloak-client-secret".to_string(), + token_verifier: None, + resource_documentation: None, +})?; +``` + +Before running the [example](../../examples/keycloak-auth.rs), ensure you have a Keycloak instance properly configured. +Follow the official MCP authorization tutorial for Keycloak setup:[Keycloak Setup Guide](https://modelcontextprotocol.io/docs/tutorials/security/authorization#keycloak-setup) + +By default, the example assumes Keycloak is running at `http://localhost:8080`. + +configure a confidential client in Keycloak and provide credentials as environment variables: + +```sh +export AUTH_SERVER=http://localhost:8080/realms/master +export CLIENT_ID=your-confidential-client-id +export CLIENT_SECRET=your-client-secret +cargo run -p rust-mcp-extra --example keycloak-auth +``` + + +### **WorkOS AuthKit** +An OAuth provider implementation for [WorkOS Authkit](https://workos.com). + +- Example usage: + +```rs +let auth_provider = WorkOsAuthProvider::new(WorkOSAuthOptions { + authkit_domain: "https://stalwart-opera-85-staging.authkit.app".to_string(), + mcp_server_url: "http://127.0.0.1:3000/mcp".to_string(), + required_scopes: Some(vec!["openid", "profile"]), + resource_name: Some("Workos Oauth Test MCP Server".to_string()), + resource_documentation: None, + token_verifier: None, + })?; +``` + +Before running the [example](../../examples/workos-auth.rs), make sure you enabled DCR (Dynamic Client Registration) in your WorkOS Authkit dashboard. + +Set the `AUTH_SERVER` environment variable and start the example: + +``` +export AUTH_SERVER=https://stalwart-opera-85-staging.authkit.app +cargo run -p rust-mcp-extra --example workos-auth +``` + + + +### **Scalekit** +An OAuth provider implementation for [Scalekit](https://www.scalekit.com). + +- Example usage: + +```rs +let auth_provider = ScalekitAuthProvider::new(ScalekitAuthOptions { + mcp_server_url: "http://127.0.0.1:3000/mcp".to_string(), + required_scopes: Some(vec!["profile"]), + token_verifier: None, + resource_name: Some("Scalekit Oauth Test MCP Server".to_string()), + resource_documentation: None, + environment_url: "yourapp.scalekit.dev".to_string(), + resource_id: "res_your-resource_id".to_string(), +}) +.await?; +``` + +Set the `ENVIRONMENT_URL` and `RESOURCE_ID` environment variable and start the example: + +``` +export ENVIRONMENT_URL=yourapp.scalekit.dev +export RESOURCE_ID=res_your-resource_id +cargo run -p rust-mcp-extra --example scalekit-auth +``` + + + ## 🔢 ID Generators -Various implementations of the IdGenerator trait (from [rust-mcp-sdk]) for generating unique identifiers. +Various implementations of the IdGenerator trait (from [rust-mcp-sdk](https://github.com/rust-mcp-stack/rust-mcp-sdk)) for generating unique identifiers. | **🧩 All ID generators in this crate can be used as `SessionId` generators in [rust-mcp-sdk](https://github.com/rust-mcp-stack/rust-mcp-sdk)).** diff --git a/crates/rust-mcp-extra/examples/common/handler.rs b/crates/rust-mcp-extra/examples/common/handler.rs new file mode 100644 index 0000000..5fc0714 --- /dev/null +++ b/crates/rust-mcp-extra/examples/common/handler.rs @@ -0,0 +1,47 @@ +use std::sync::Arc; + +use async_trait::async_trait; +use rust_mcp_sdk::{ + mcp_server::ServerHandler, + schema::{ + schema_utils::CallToolError, CallToolRequest, CallToolResult, ListToolsRequest, + ListToolsResult, RpcError, + }, + McpServer, +}; + +use crate::common::tool::ShowAuthInfo; + +pub struct McpServerHandler; +#[async_trait] +impl ServerHandler for McpServerHandler { + // Handle ListToolsRequest, return list of available tools as ListToolsResult + async fn handle_list_tools_request( + &self, + _request: ListToolsRequest, + _runtime: Arc, + ) -> std::result::Result { + Ok(ListToolsResult { + meta: None, + next_cursor: None, + tools: vec![ShowAuthInfo::tool()], + }) + } + + /// Handles incoming CallToolRequest and processes it using the appropriate tool. + async fn handle_call_tool_request( + &self, + request: CallToolRequest, + runtime: Arc, + ) -> std::result::Result { + if request.params.name.eq(&ShowAuthInfo::tool_name()) { + let tool = ShowAuthInfo::default(); + tool.call_tool(runtime.auth_info_cloned().await) + } else { + Err(CallToolError::from_message(format!( + "Tool \"{}\" does not exists or inactive!", + request.params.name, + ))) + } + } +} diff --git a/crates/rust-mcp-extra/examples/common/mod.rs b/crates/rust-mcp-extra/examples/common/mod.rs new file mode 100644 index 0000000..b271812 --- /dev/null +++ b/crates/rust-mcp-extra/examples/common/mod.rs @@ -0,0 +1,3 @@ +pub mod handler; +pub mod tool; +pub mod utils; diff --git a/crates/rust-mcp-extra/examples/common/tool.rs b/crates/rust-mcp-extra/examples/common/tool.rs new file mode 100644 index 0000000..1d395d5 --- /dev/null +++ b/crates/rust-mcp-extra/examples/common/tool.rs @@ -0,0 +1,25 @@ +use rust_mcp_sdk::{ + auth::AuthInfo, + macros::{mcp_tool, JsonSchema}, + schema::{schema_utils::CallToolError, CallToolResult, TextContent}, +}; + +//*******************************// +// Show Authentication Info // +//*******************************// +#[mcp_tool( + name = "show_auth_info", + description = "Shows current user authentication info in json format" +)] +#[derive(Debug, ::serde::Deserialize, ::serde::Serialize, JsonSchema, Default)] +pub struct ShowAuthInfo {} +impl ShowAuthInfo { + pub fn call_tool(&self, auth_info: Option) -> Result { + let auth_info_json = serde_json::to_string_pretty(&auth_info).map_err(|err| { + CallToolError::from_message(format!("Undable to display auth info as string :{err}")) + })?; + Ok(CallToolResult::text_content(vec![TextContent::from( + auth_info_json, + )])) + } +} diff --git a/crates/rust-mcp-extra/examples/common/utils.rs b/crates/rust-mcp-extra/examples/common/utils.rs new file mode 100644 index 0000000..6889b56 --- /dev/null +++ b/crates/rust-mcp-extra/examples/common/utils.rs @@ -0,0 +1,31 @@ +use rust_mcp_sdk::schema::{ + Implementation, InitializeResult, ServerCapabilities, ServerCapabilitiesTools, + LATEST_PROTOCOL_VERSION, +}; +use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; + +pub fn create_server_info(server_name: &str) -> InitializeResult { + InitializeResult { + server_info: Implementation { + name: server_name.to_string(), + version: "0.1.0".to_string(), + title: Some(server_name.to_string()), + }, + capabilities: ServerCapabilities { + tools: Some(ServerCapabilitiesTools { list_changed: None }), + ..Default::default() + }, + meta: None, + instructions: None, + protocol_version: LATEST_PROTOCOL_VERSION.to_string(), + } +} + +pub fn enable_tracing() { + tracing_subscriber::registry() + .with( + tracing_subscriber::EnvFilter::try_from_default_env().unwrap_or_else(|_| "info".into()), + ) + .with(tracing_subscriber::fmt::layer()) + .init(); +} diff --git a/crates/rust-mcp-extra/examples/keycloak-auth.rs b/crates/rust-mcp-extra/examples/keycloak-auth.rs new file mode 100644 index 0000000..b4b191b --- /dev/null +++ b/crates/rust-mcp-extra/examples/keycloak-auth.rs @@ -0,0 +1,47 @@ +mod common; +use crate::common::{ + handler::McpServerHandler, + utils::{create_server_info, enable_tracing}, +}; +use rust_mcp_extra::auth_provider::keycloak::{KeycloakAuthOptions, KeycloakAuthProvider}; +use rust_mcp_sdk::{ + error::SdkResult, + mcp_server::{hyper_server, HyperServerOptions}, +}; +use std::{env, sync::Arc}; + +#[tokio::main] +async fn main() -> SdkResult<()> { + enable_tracing(); + let server_details = create_server_info("Keycloak Oauth Test MCP Server"); + + let handler = McpServerHandler {}; + + let auth_provider = KeycloakAuthProvider::new(KeycloakAuthOptions { + keycloak_base_url: env::var("AUTH_SERVER") + .unwrap_or("http://localhost:8080/realms/master".to_string()), + mcp_server_url: "http://localhost:3000".to_string(), + resource_name: Some("Keycloak Oauth Test MCP Server".to_string()), + required_scopes: Some(vec!["mcp:tools"]), + client_id: env::var("CLIENT_ID").ok(), + client_secret: env::var("CLIENT_SECRET").ok(), + token_verifier: None, + resource_documentation: None, + })?; + + let server = hyper_server::create_server( + server_details, + handler, + HyperServerOptions { + host: "localhost".to_string(), + port: 3000, + custom_streamable_http_endpoint: Some("/".to_string()), + auth: Some(Arc::new(auth_provider)), // enable authentication + sse_support: false, + ..Default::default() + }, + ); + + server.start().await?; + Ok(()) +} diff --git a/crates/rust-mcp-extra/examples/scalekit-auth.rs b/crates/rust-mcp-extra/examples/scalekit-auth.rs new file mode 100644 index 0000000..8fd625f --- /dev/null +++ b/crates/rust-mcp-extra/examples/scalekit-auth.rs @@ -0,0 +1,47 @@ +mod common; +use crate::common::{ + handler::McpServerHandler, + utils::{create_server_info, enable_tracing}, +}; +use rust_mcp_extra::auth_provider::scalekit::{ScalekitAuthOptions, ScalekitAuthProvider}; +use rust_mcp_sdk::{ + error::SdkResult, + mcp_server::{hyper_server, HyperServerOptions}, +}; +use std::{env, sync::Arc}; + +#[tokio::main] +async fn main() -> SdkResult<()> { + enable_tracing(); + let server_details = create_server_info("Scalekit Oauth Test MCP Server"); + + let handler = McpServerHandler {}; + + let auth_provider = ScalekitAuthProvider::new(ScalekitAuthOptions { + mcp_server_url: "http://127.0.0.1:3000/mcp".to_string(), + required_scopes: Some(vec!["profile"]), + token_verifier: None, + resource_name: Some("Scalekit Oauth Test MCP Server".to_string()), + resource_documentation: None, + environment_url: env::var("ENVIRONMENT_URL") + .expect("Please set 'ENVIRONMENT_URL' evnrionment variable and try again."), + resource_id: env::var("RESOURCE_ID") + .expect("Please set 'RESOURCE_ID' evnrionment variable and try again."), + }) + .await?; + + let server = hyper_server::create_server( + server_details, + handler, + HyperServerOptions { + host: "127.0.0.1".to_string(), + port: 3000, + auth: Some(Arc::new(auth_provider)), // enable authentication + sse_support: false, + ..Default::default() + }, + ); + + server.start().await?; + Ok(()) +} diff --git a/crates/rust-mcp-extra/examples/workos-auth.rs b/crates/rust-mcp-extra/examples/workos-auth.rs new file mode 100644 index 0000000..01d980b --- /dev/null +++ b/crates/rust-mcp-extra/examples/workos-auth.rs @@ -0,0 +1,44 @@ +mod common; +use crate::common::{ + handler::McpServerHandler, + utils::{create_server_info, enable_tracing}, +}; +use rust_mcp_extra::auth_provider::work_os::{WorkOSAuthOptions, WorkOsAuthProvider}; +use rust_mcp_sdk::{ + error::SdkResult, + mcp_server::{hyper_server, HyperServerOptions}, +}; +use std::{env, sync::Arc}; + +#[tokio::main] +async fn main() -> SdkResult<()> { + enable_tracing(); + let server_details = create_server_info("Workos Oauth Test MCP Server"); + + let handler = McpServerHandler {}; + + let auth_provider = WorkOsAuthProvider::new(WorkOSAuthOptions { + authkit_domain: env::var("AUTH_SERVER") + .unwrap_or("https://stalwart-opera-85-staging.authkit.app".to_string()), + mcp_server_url: "http://127.0.0.1:3000/mcp".to_string(), + required_scopes: Some(vec!["openid", "profile"]), + resource_name: Some("Workos Oauth Test MCP Server".to_string()), + resource_documentation: None, + token_verifier: None, + })?; + + let server = hyper_server::create_server( + server_details, + handler, + HyperServerOptions { + host: "127.0.0.1".to_string(), + port: 3000, + auth: Some(Arc::new(auth_provider)), // enable authentication + sse_support: false, + ..Default::default() + }, + ); + + server.start().await?; + Ok(()) +} diff --git a/crates/rust-mcp-extra/src/auth_provider.rs b/crates/rust-mcp-extra/src/auth_provider.rs new file mode 100644 index 0000000..ee826e9 --- /dev/null +++ b/crates/rust-mcp-extra/src/auth_provider.rs @@ -0,0 +1,3 @@ +pub mod keycloak; +pub mod scalekit; +pub mod work_os; diff --git a/crates/rust-mcp-extra/src/auth_provider/keycloak.rs b/crates/rust-mcp-extra/src/auth_provider/keycloak.rs new file mode 100644 index 0000000..1dfa8d4 --- /dev/null +++ b/crates/rust-mcp-extra/src/auth_provider/keycloak.rs @@ -0,0 +1,290 @@ +use crate::token_verifier::{ + GenericOauthTokenVerifier, TokenVerifierOptions, VerificationStrategies, +}; +use async_trait::async_trait; +use bytes::Bytes; +use http::{header::CONTENT_TYPE, StatusCode}; +use http_body_util::{BodyExt, Full}; +use rust_mcp_sdk::{ + auth::{ + create_discovery_endpoints, AuthInfo, AuthMetadataBuilder, AuthProvider, + AuthenticationError, AuthorizationServerMetadata, OauthEndpoint, + OauthProtectedResourceMetadata, OauthTokenVerifier, + }, + error::McpSdkError, + mcp_http::{middleware::CorsMiddleware, GenericBody, GenericBodyExt, Middleware}, + mcp_server::{ + error::{TransportServerError, TransportServerResult}, + join_url, McpAppState, + }, +}; +use std::{collections::HashMap, sync::Arc}; + +static SCOPES_SUPPORTED: &[&str] = &[ + "openid", + "acr", + "basic", + "web-origins", + "email", + "mcp:tools", + "address", + "profile", + "phone", + "roles", + "microprofile-jwt", + "service_account", + "offline_access", + "organization", +]; + +/// Configuration options for the Keycloak OAuth provider. +pub struct KeycloakAuthOptions<'a> { + /// Base URL of the Keycloak server (e.g. `https://keycloak.example.com`) + pub keycloak_base_url: String, + /// Public base URL of this MCP server (used for discovery endpoints) + pub mcp_server_url: String, + /// Scopes that must be present in the access token + pub required_scopes: Option>, + /// Client ID for confidential client (required for token introspection) + pub client_id: Option, + /// Client secret for confidential client (required for token introspection) + pub client_secret: Option, + /// Optional custom token verifier + pub token_verifier: Option>, + /// Human-readable name of the protected resource (optional, shown in discovery) + pub resource_name: Option, + /// Documentation URL for this resource (optional) + pub resource_documentation: Option, +} + +/// Keycloak integration implementing `AuthProvider` for MCP servers. +/// +/// This provider makes your MCP server compatible with clients that expect standard +/// OAuth2/OpenID Connect discovery endpoints (authorization server metadata and +/// protected resource metadata) when using Keycloak as the identity provider. +/// +/// It supports multiple token verification strategies with the following precedence: +/// +/// 1. JWKs-based verification (always enabled) – validates JWT signature, issuer, expiry, etc. +/// 2. Token Introspection (if client_id + client_secret provided) – active validation against Keycloak +/// 3. UserInfo endpoint validation (fallback when `openid` scope is required but no introspection credentials) +/// +pub struct KeycloakAuthProvider { + auth_server_meta: AuthorizationServerMetadata, + protected_resource_meta: OauthProtectedResourceMetadata, + endpoint_map: HashMap, + protected_resource_metadata_url: String, + token_verifier: Box, +} + +impl KeycloakAuthProvider { + /// Creates a new KeycloakAuthProvider instance. + /// + /// This method configures OAuth2/OpenID Connect discovery metadata and selects + /// the best available token verification strategy: + /// + /// ### Verification Strategy Priority & Security Considerations + /// + /// | Strategy | When Used | Security Level | Notes | + /// |------------------|---------------------------------------------------|----------------|-------| + /// | JWKs (local) | Always | High | Validates signature, `iss`, `exp`, `nbf`, etc. No network call. | + /// | Introspection | When `client_id` + `client_secret` are provided | Highest | Active validation with Keycloak. Detects revoked/expired tokens immediately. Recommended for production. | + /// | UserInfo | Fallback when `openid` scope is required but no introspection credentials | Medium | Validates token by calling `/userinfo`. Less secure than introspection (some IdPs accept invalid tokens). | + /// + /// Warning: If neither introspection nor `openid` scope is configured, only local JWT validation occurs. + /// This means revoked tokens may still be accepted until they expire. + /// + /// Recommendation: Always provide `client_id` and `client_secret` in production for full revocation support. + /// + pub fn new(mut options: KeycloakAuthOptions) -> Result { + let (endpoint_map, protected_resource_metadata_url) = + create_discovery_endpoints(&options.mcp_server_url)?; + + let required_scopes = options.required_scopes.take(); + let scopes_supported = required_scopes.clone().unwrap_or(SCOPES_SUPPORTED.to_vec()); + + let mut builder = AuthMetadataBuilder::new(&options.mcp_server_url) + .issuer(&options.keycloak_base_url) + .authorization_servers(vec![&options.keycloak_base_url]) + .authorization_endpoint("/protocol/openid-connect/auth") + .introspection_endpoint("/protocol/openid-connect/token/introspect") + .registration_endpoint("/clients-registrations/openid-connect") + .token_endpoint("/protocol/openid-connect/token") + .revocation_endpoint("/protocol/openid-connect/revoke") + .jwks_uri("/protocol/openid-connect/certs") + .scopes_supported(scopes_supported); + + let has_openid_scope = + matches!(required_scopes.as_ref(), Some(scopes) if scopes.contains(&"openid")); + + if let Some(scopes) = required_scopes { + builder = builder.reqquired_scopes(scopes) + } + if let Some(resource_name) = options.resource_name.as_ref() { + builder = builder.resource_name(resource_name) + } + if let Some(resource_documentation) = options.resource_documentation.as_ref() { + builder = builder.service_documentation(resource_documentation) + } + + let (auth_server_meta, protected_resource_meta) = builder.build()?; + + let Some(jwks_uri) = auth_server_meta.jwks_uri.as_ref().map(|s| s.to_string()) else { + return Err(McpSdkError::Internal { + description: "jwks_uri is not defined!".to_string(), + }); + }; + + let mut strategies = Vec::with_capacity(2); + strategies.push(VerificationStrategies::JWKs { jwks_uri }); + + if let (Some(client_id), Some(client_secret), Some(introspection_uri)) = ( + options.client_id.take(), + options.client_secret.take(), + auth_server_meta.introspection_endpoint.as_ref(), + ) { + strategies.push(VerificationStrategies::Introspection { + introspection_uri: introspection_uri.to_string(), + client_id, + client_secret, + use_basic_auth: true, + extra_params: Some(vec![("token_type_hint", "access_token")]), + }); + } else if has_openid_scope { + let userinfo_uri = join_url( + &auth_server_meta.issuer, + "/protocol/openid-connect/userinfo", + ) + .map_err(|err| McpSdkError::Internal { + description: format!("invalid userinfo url :{err}"), + })? + .to_string(); + + strategies.push(VerificationStrategies::UserInfo { userinfo_uri }) + } else { + tracing::warn!("Keycloak token verification is missing both Introspection and UserInfo strategies. Please provide client_id and client_secret, or ensure openid is included as a required scope.") + }; + + let token_verifier: Box = match options.token_verifier { + Some(verifier) => verifier, + None => Box::new(GenericOauthTokenVerifier::new(TokenVerifierOptions { + strategies, + validate_audience: None, + validate_issuer: Some(options.keycloak_base_url.clone()), + cache_capacity: None, + })?), + }; + + Ok(Self { + endpoint_map, + protected_resource_metadata_url, + token_verifier, + auth_server_meta, + protected_resource_meta, + }) + } + + /// Helper to build JSON response for authorization server metadata with CORS. + fn handle_authorization_server_metadata( + response_str: String, + ) -> TransportServerResult> { + let body = Full::new(Bytes::from(response_str)) + .map_err(|err| TransportServerError::HttpError(err.to_string())) + .boxed(); + http::Response::builder() + .status(StatusCode::OK) + .header(CONTENT_TYPE, "application/json") + .body(body) + .map_err(|err| TransportServerError::HttpError(err.to_string())) + } + + /// Helper to build JSON response for protected resource metadata with permissive CORS. + fn handle_protected_resource_metadata( + response_str: String, + ) -> TransportServerResult> { + use http_body_util::BodyExt; + + let body = Full::new(Bytes::from(response_str)) + .map_err(|err| TransportServerError::HttpError(err.to_string())) + .boxed(); + http::Response::builder() + .status(StatusCode::OK) + .header(CONTENT_TYPE, "application/json") + .body(body) + .map_err(|err| TransportServerError::HttpError(err.to_string())) + } +} + +#[async_trait] +impl AuthProvider for KeycloakAuthProvider { + /// Returns the map of supported OAuth discovery endpoints. + fn auth_endpoints(&self) -> Option<&HashMap> { + Some(&self.endpoint_map) + } + + /// Handles incoming requests to OAuth metadata endpoints. + async fn handle_request( + &self, + request: http::Request<&str>, + state: Arc, + ) -> Result, TransportServerError> { + let Some(endpoint) = self.endpoint_type(&request) else { + return http::Response::builder() + .status(StatusCode::NOT_FOUND) + .body(GenericBody::empty()) + .map_err(|err| TransportServerError::HttpError(err.to_string())); + }; + + // return early if method is not allowed + if let Some(response) = self.validate_allowed_methods(endpoint, request.method()) { + return Ok(response); + } + + match endpoint { + OauthEndpoint::AuthorizationServerMetadata => { + let json_payload = serde_json::to_string(&self.auth_server_meta) + .map_err(|err| TransportServerError::HttpError(err.to_string()))?; + let cors = &CorsMiddleware::default(); + cors.handle( + request, + state, + Box::new(move |_req, _state| { + Box::pin( + async move { Self::handle_authorization_server_metadata(json_payload) }, + ) + }), + ) + .await + } + OauthEndpoint::ProtectedResourceMetadata => { + let json_payload = serde_json::to_string(&self.protected_resource_meta) + .map_err(|err| TransportServerError::HttpError(err.to_string()))?; + + let cors = &CorsMiddleware::default(); + cors.handle( + request, + state, + Box::new(move |_req, _state| { + Box::pin( + async move { Self::handle_protected_resource_metadata(json_payload) }, + ) + }), + ) + .await + } + _ => Ok(GenericBody::create_404_response()), + } + } + + /// Verifies an access token using JWKs and optional UserInfo validation. + /// + /// Returns authenticated `AuthInfo` on success. + async fn verify_token(&self, access_token: String) -> Result { + self.token_verifier.verify_token(access_token).await + } + + /// Returns the full URL to the protected resource metadata document. + fn protected_resource_metadata_url(&self) -> Option<&str> { + Some(self.protected_resource_metadata_url.as_str()) + } +} diff --git a/crates/rust-mcp-extra/src/auth_provider/scalekit.rs b/crates/rust-mcp-extra/src/auth_provider/scalekit.rs new file mode 100644 index 0000000..b3b1ffb --- /dev/null +++ b/crates/rust-mcp-extra/src/auth_provider/scalekit.rs @@ -0,0 +1,273 @@ +use crate::token_verifier::{ + GenericOauthTokenVerifier, TokenVerifierOptions, VerificationStrategies, +}; +use async_trait::async_trait; +use bytes::Bytes; +use http::{header::CONTENT_TYPE, StatusCode}; +use http_body_util::{BodyExt, Full}; +use rust_mcp_sdk::{ + auth::{ + create_discovery_endpoints, AuthInfo, AuthMetadataBuilder, AuthProvider, + AuthenticationError, AuthorizationServerMetadata, OauthEndpoint, + OauthProtectedResourceMetadata, OauthTokenVerifier, + }, + error::McpSdkError, + mcp_http::{middleware::CorsMiddleware, GenericBody, GenericBodyExt, Middleware}, + mcp_server::{ + error::{TransportServerError, TransportServerResult}, + join_url, McpAppState, + }, +}; +use std::{collections::HashMap, sync::Arc, vec}; +use url::Url; + +/// Configuration options for the [`ScalekitAuthProvider`]. +/// +/// These values come from the Scalekit dashboard and MCP server configuration. +pub struct ScalekitAuthOptions<'a> { + /// Base Scalekit environment URL. + /// This value can be found in the Scalekit dashboard, located in the Settings section + /// + /// If protocol is missing (no `http://` or `https://`), `https://` is automatically added. + pub environment_url: String, + /// This value can be found in the Scalekit dashboard, located in MCp Servers + pub resource_id: String, + /// Public-facing MCP server base URL. + pub mcp_server_url: String, + /// Optional list of required OAuth scopes for this resource. + pub required_scopes: Option>, + /// Human-readable resource name for documentation/metadata. + pub resource_name: Option, + /// Human-readable resource documentation URL or content identifier. + pub resource_documentation: Option, + /// Optional custom token verifier. + /// If omitted, a default JWK-based [`GenericOauthTokenVerifier`] is created. + pub token_verifier: Option>, +} + +/// MCP OAuth provider implementation for Scalekit. +pub struct ScalekitAuthProvider { + auth_server_meta: AuthorizationServerMetadata, + protected_resource_meta: OauthProtectedResourceMetadata, + endpoint_map: HashMap, + protected_resource_metadata_url: String, + token_verifier: Box, +} + +impl ScalekitAuthProvider { + /// Creates a new [`ScalekitAuthProvider`] from configuration options. + /// + /// This method: + /// - Normalizes the environment URL protocol + /// - Builds OAuth discovery URLs + /// - Pulls authorization server metadata + /// - Builds protected resource metadata + /// - Instantiates a JWK-based token verifier if no custom verifier is provided + /// + /// # Errors + /// Returns [`McpSdkError`] if: + /// - URLs are invalid + /// - Metadata discovery fails + /// - JWK verifier initialization fails + pub async fn new<'a>(mut options: ScalekitAuthOptions<'a>) -> Result { + // Normalize environment URL and add https:// if needed + let environment_url = if options.environment_url.starts_with("http://") + || options.environment_url.starts_with("https://") + { + &options.environment_url + } else { + &format!("https://{}", options.environment_url) + }; + + let issuer = Url::parse(environment_url).map_err(|err| McpSdkError::Internal { + description: format!("invalid userinfo url :{err}"), + })?; + + // Build discovery document URL for this resource + let discovery_url = join_url( + &issuer, + &format!( + "/.well-known/oauth-authorization-server/resources/{}", + options.resource_id + ), + ) + .map_err(|err| McpSdkError::Internal { + description: format!("invalid userinfo url :{err}"), + })?; + + let (endpoint_map, protected_resource_metadata_url) = + create_discovery_endpoints(&options.mcp_server_url)?; + + let required_scopes: Vec = options + .required_scopes + .take() + .unwrap_or_default() + .iter() + .map(|s| s.to_string()) + .collect(); + + let mut builder = AuthMetadataBuilder::from_discovery_url( + discovery_url.as_str(), + options.mcp_server_url, + required_scopes.clone(), + ) + .await + .unwrap(); + + if let Some(resource_name) = options.resource_name.as_ref() { + builder = builder.resource_name(resource_name) + } + + if let Some(resource_documentation) = options.resource_documentation.as_ref() { + builder = builder.service_documentation(resource_documentation) + } + + let authorization_servers = + join_url(&issuer, &format!("/resources/{}", options.resource_id)) + .map_err(|err| McpSdkError::Internal { + description: format!("invalid userinfo url :{err}"), + })? + .to_string(); + + builder = builder.authorization_servers(vec![&authorization_servers]); + + if !required_scopes.is_empty() { + builder = builder.reqquired_scopes(required_scopes) + } + if let Some(resource_name) = options.resource_name.as_ref() { + builder = builder.resource_name(resource_name) + } + if let Some(resource_documentation) = options.resource_documentation.as_ref() { + builder = builder.service_documentation(resource_documentation) + } + + let (auth_server_meta, protected_resource_meta) = builder.build()?; + + let Some(jwks_uri) = auth_server_meta.jwks_uri.as_ref().map(|s| s.to_string()) else { + return Err(McpSdkError::Internal { + description: "jwks_uri is not defined!".to_string(), + }); + }; + + let token_verifier: Box = match options.token_verifier { + Some(verifier) => verifier, + None => Box::new(GenericOauthTokenVerifier::new(TokenVerifierOptions { + strategies: vec![VerificationStrategies::JWKs { jwks_uri }], + validate_audience: None, + validate_issuer: Some(issuer.to_string().trim_end_matches("/").to_string()), + cache_capacity: None, + })?), + }; + + Ok(Self { + endpoint_map, + protected_resource_metadata_url, + token_verifier, + auth_server_meta, + protected_resource_meta, + }) + } + + /// Helper to build JSON response for authorization server metadata with CORS. + fn handle_authorization_server_metadata( + response_str: String, + ) -> TransportServerResult> { + let body = Full::new(Bytes::from(response_str)) + .map_err(|err| TransportServerError::HttpError(err.to_string())) + .boxed(); + http::Response::builder() + .status(StatusCode::OK) + .header(CONTENT_TYPE, "application/json") + .body(body) + .map_err(|err| TransportServerError::HttpError(err.to_string())) + } + + /// Helper to build JSON response for protected resource metadata with permissive CORS. + fn handle_protected_resource_metadata( + response_str: String, + ) -> TransportServerResult> { + use http_body_util::BodyExt; + + let body = Full::new(Bytes::from(response_str)) + .map_err(|err| TransportServerError::HttpError(err.to_string())) + .boxed(); + http::Response::builder() + .status(StatusCode::OK) + .header(CONTENT_TYPE, "application/json") + .body(body) + .map_err(|err| TransportServerError::HttpError(err.to_string())) + } +} + +#[async_trait] +impl AuthProvider for ScalekitAuthProvider { + /// Returns the map of supported OAuth discovery endpoints. + fn auth_endpoints(&self) -> Option<&HashMap> { + Some(&self.endpoint_map) + } + + /// Handles incoming requests to OAuth metadata endpoints. + async fn handle_request( + &self, + request: http::Request<&str>, + state: Arc, + ) -> Result, TransportServerError> { + let Some(endpoint) = self.endpoint_type(&request) else { + return http::Response::builder() + .status(StatusCode::NOT_FOUND) + .body(GenericBody::empty()) + .map_err(|err| TransportServerError::HttpError(err.to_string())); + }; + + // return early if method is not allowed + if let Some(response) = self.validate_allowed_methods(endpoint, request.method()) { + return Ok(response); + } + + match endpoint { + OauthEndpoint::AuthorizationServerMetadata => { + let json_payload = serde_json::to_string(&self.auth_server_meta) + .map_err(|err| TransportServerError::HttpError(err.to_string()))?; + let cors = &CorsMiddleware::default(); + cors.handle( + request, + state, + Box::new(move |_req, _state| { + Box::pin( + async move { Self::handle_authorization_server_metadata(json_payload) }, + ) + }), + ) + .await + } + OauthEndpoint::ProtectedResourceMetadata => { + let json_payload = serde_json::to_string(&self.protected_resource_meta) + .map_err(|err| TransportServerError::HttpError(err.to_string()))?; + let cors = &CorsMiddleware::default(); + cors.handle( + request, + state, + Box::new(move |_req, _state| { + Box::pin( + async move { Self::handle_protected_resource_metadata(json_payload) }, + ) + }), + ) + .await + } + _ => Ok(GenericBody::create_404_response()), + } + } + + /// Verifies an access token using JWKs and optional UserInfo validation. + /// + /// Returns authenticated `AuthInfo` on success. + async fn verify_token(&self, access_token: String) -> Result { + self.token_verifier.verify_token(access_token).await + } + + /// Returns the full URL to the protected resource metadata document. + fn protected_resource_metadata_url(&self) -> Option<&str> { + Some(self.protected_resource_metadata_url.as_str()) + } +} diff --git a/crates/rust-mcp-extra/src/auth_provider/work_os.rs b/crates/rust-mcp-extra/src/auth_provider/work_os.rs new file mode 100644 index 0000000..cd7f8fe --- /dev/null +++ b/crates/rust-mcp-extra/src/auth_provider/work_os.rs @@ -0,0 +1,272 @@ +//! # WorkOS AuthKit OAuth2 Provider for MCP Servers +//! +//! This module implements an OAuth2 specifically designed to integrate +//! [WorkOS AuthKit](https://workos.com/docs/authkit) as the identity +//! provider (IdP) in an MCP (Model Context Protocol) server ecosystem. +//! +//! It enables your MCP server to: +//! - Expose standard OAuth2/.well-known endpoints +//! - Serve authorization server metadata (`/.well-known/oauth-authorization-server`) +//! - Serve protected resource metadata (custom per MCP) +//! - Verify incoming access tokens using JWKs + UserInfo endpoint validation +//! +//! ## Features +//! +//! - Zero-downtime token verification with cached JWKs +//! - Automatic construction of OAuth2 discovery documents +//! - Built-in CORS support for metadata endpoints +//! - Pluggable into `rust-mcp-sdk`'s authentication system via the `AuthProvider` trait +//! +//! ## Example +//! +//! ```rust,ignore +//! +//! let auth_provider = WorkOsAuthProvider::new(WorkOSAuthOptions { +//! // Your AuthKit app domain (found in WorkOS dashboard) +//! authkit_domain: "https://your-app.authkit.app".to_string(), +//! // Base URL of your MCP server (used to build protected resource metadata URL) +//! mcp_server_url: "http://localhost:3000/mcp".to_string(), +//! })?; +//! +//! // Register in your MCP server +//! let server = hyper_server::create_server( +//! server_details, +//! handler, +//! HyperServerOptions { +//! host: "localhost".to_string(), +//! port: 3000, +//! auth: Some(Arc::new(auth_provider)), +//! ..Default::default() +//! }); +//! ``` +use crate::token_verifier::{ + GenericOauthTokenVerifier, TokenVerifierOptions, VerificationStrategies, +}; +use async_trait::async_trait; +use bytes::Bytes; +use http::{header::CONTENT_TYPE, StatusCode}; +use http_body_util::{BodyExt, Full}; +use rust_mcp_sdk::{ + auth::{ + create_discovery_endpoints, AuthInfo, AuthMetadataBuilder, AuthProvider, + AuthenticationError, AuthorizationServerMetadata, OauthEndpoint, + OauthProtectedResourceMetadata, OauthTokenVerifier, + }, + error::McpSdkError, + mcp_http::{middleware::CorsMiddleware, GenericBody, GenericBodyExt, Middleware}, + mcp_server::{ + error::{TransportServerError, TransportServerResult}, + join_url, McpAppState, + }, +}; +use std::{collections::HashMap, sync::Arc, vec}; + +static SCOPES_SUPPORTED: &[&str] = &["email", "offline_access", "openid", "profile"]; + +/// Configuration options for the WorkOS AuthKit OAuth provider. +pub struct WorkOSAuthOptions<'a> { + pub authkit_domain: String, + pub mcp_server_url: String, + pub required_scopes: Option>, + pub token_verifier: Option>, + pub resource_name: Option, + pub resource_documentation: Option, +} + +/// WorkOS AuthKit integration implementing `AuthProvider` for MCP servers. +/// +/// This provider makes your MCP server compatible with clients that expect standard +/// OAuth2 authorization server and protected resource discovery endpoints when using +/// WorkOS AuthKit as the identity provider. +pub struct WorkOsAuthProvider { + auth_server_meta: AuthorizationServerMetadata, + protected_resource_meta: OauthProtectedResourceMetadata, + endpoint_map: HashMap, + protected_resource_metadata_url: String, + token_verifier: Box, +} + +impl WorkOsAuthProvider { + /// Creates a new `WorkOsAuthProvider` instance. + /// + /// This performs: + /// - Validation and parsing of URLs + /// - Construction of OAuth2 metadata documents + /// - Setup of token verification using JWKs and UserInfo endpoint + /// + /// /// # Example + /// + /// ```rust,ignore + /// use rust_mcp_extra::auth_provider::work_os::{WorkOSAuthOptions, WorkOsAuthProvider}; + /// + /// let auth_provider = WorkOsAuthProvider::new(WorkOSAuthOptions { + /// authkit_domain: "https://your-app.authkit.app".to_string(), + /// mcp_server_url: "http://localhost:3000/mcp".to_string(), + /// })?; + /// + pub fn new(mut options: WorkOSAuthOptions) -> Result { + let (endpoint_map, protected_resource_metadata_url) = + create_discovery_endpoints(&options.mcp_server_url)?; + + let required_scopes = options.required_scopes.take(); + let scopes_supported = required_scopes.clone().unwrap_or(SCOPES_SUPPORTED.to_vec()); + + let mut builder = AuthMetadataBuilder::new(&options.mcp_server_url) + .issuer(&options.authkit_domain) + .authorization_servers(vec![&options.authkit_domain]) + .authorization_endpoint("/oauth2/authorize") + .introspection_endpoint("/oauth2/introspection") + .registration_endpoint("/oauth2/register") + .token_endpoint("/oauth2/token") + .jwks_uri("/oauth2/jwks") + .scopes_supported(scopes_supported); + + if let Some(scopes) = required_scopes { + builder = builder.reqquired_scopes(scopes) + } + if let Some(resource_name) = options.resource_name.as_ref() { + builder = builder.resource_name(resource_name) + } + if let Some(resource_documentation) = options.resource_documentation.as_ref() { + builder = builder.service_documentation(resource_documentation) + } + + let (auth_server_meta, protected_resource_meta) = builder.build()?; + + let Some(jwks_uri) = auth_server_meta.jwks_uri.as_ref().map(|s| s.to_string()) else { + return Err(McpSdkError::Internal { + description: "jwks_uri is not defined!".to_string(), + }); + }; + + let userinfo_uri = join_url(&auth_server_meta.issuer, "oauth2/userinfo") + .map_err(|err| McpSdkError::Internal { + description: format!("invalid userinfo url :{err}"), + })? + .to_string(); + + let token_verifier: Box = match options.token_verifier { + Some(verifier) => verifier, + None => Box::new(GenericOauthTokenVerifier::new(TokenVerifierOptions { + strategies: vec![ + VerificationStrategies::JWKs { jwks_uri }, + VerificationStrategies::UserInfo { userinfo_uri }, + ], + validate_audience: None, + validate_issuer: Some(options.authkit_domain.clone()), + cache_capacity: None, + })?), + }; + + Ok(Self { + endpoint_map, + protected_resource_metadata_url, + token_verifier, + auth_server_meta, + protected_resource_meta, + }) + } + + /// Helper to build JSON response for authorization server metadata with CORS. + fn handle_authorization_server_metadata( + response_str: String, + ) -> TransportServerResult> { + let body = Full::new(Bytes::from(response_str)) + .map_err(|err| TransportServerError::HttpError(err.to_string())) + .boxed(); + http::Response::builder() + .status(StatusCode::OK) + .header(CONTENT_TYPE, "application/json") + .body(body) + .map_err(|err| TransportServerError::HttpError(err.to_string())) + } + + /// Helper to build JSON response for protected resource metadata with permissive CORS. + fn handle_protected_resource_metadata( + response_str: String, + ) -> TransportServerResult> { + use http_body_util::BodyExt; + + let body = Full::new(Bytes::from(response_str)) + .map_err(|err| TransportServerError::HttpError(err.to_string())) + .boxed(); + http::Response::builder() + .status(StatusCode::OK) + .header(CONTENT_TYPE, "application/json") + .body(body) + .map_err(|err| TransportServerError::HttpError(err.to_string())) + } +} + +#[async_trait] +impl AuthProvider for WorkOsAuthProvider { + /// Returns the map of supported OAuth discovery endpoints. + fn auth_endpoints(&self) -> Option<&HashMap> { + Some(&self.endpoint_map) + } + + /// Handles incoming requests to OAuth metadata endpoints. + async fn handle_request( + &self, + request: http::Request<&str>, + state: Arc, + ) -> Result, TransportServerError> { + let Some(endpoint) = self.endpoint_type(&request) else { + return http::Response::builder() + .status(StatusCode::NOT_FOUND) + .body(GenericBody::empty()) + .map_err(|err| TransportServerError::HttpError(err.to_string())); + }; + + // return early if method is not allowed + if let Some(response) = self.validate_allowed_methods(endpoint, request.method()) { + return Ok(response); + } + + match endpoint { + OauthEndpoint::AuthorizationServerMetadata => { + let json_payload = serde_json::to_string(&self.auth_server_meta) + .map_err(|err| TransportServerError::HttpError(err.to_string()))?; + let cors = &CorsMiddleware::default(); + cors.handle( + request, + state, + Box::new(move |_req, _state| { + Box::pin( + async move { Self::handle_authorization_server_metadata(json_payload) }, + ) + }), + ) + .await + } + OauthEndpoint::ProtectedResourceMetadata => { + let json_payload = serde_json::to_string(&self.protected_resource_meta) + .map_err(|err| TransportServerError::HttpError(err.to_string()))?; + let cors = &CorsMiddleware::default(); + cors.handle( + request, + state, + Box::new(move |_req, _state| { + Box::pin( + async move { Self::handle_protected_resource_metadata(json_payload) }, + ) + }), + ) + .await + } + _ => Ok(GenericBody::create_404_response()), + } + } + + /// Verifies an access token using JWKs and optional UserInfo validation. + /// + /// Returns authenticated `AuthInfo` on success. + async fn verify_token(&self, access_token: String) -> Result { + self.token_verifier.verify_token(access_token).await + } + + /// Returns the full URL to the protected resource metadata document. + fn protected_resource_metadata_url(&self) -> Option<&str> { + Some(self.protected_resource_metadata_url.as_str()) + } +} diff --git a/crates/rust-mcp-extra/src/id_generator/snow_flake_id_generator.rs b/crates/rust-mcp-extra/src/id_generator/snow_flake_id_generator.rs index 6f378b1..3bea626 100644 --- a/crates/rust-mcp-extra/src/id_generator/snow_flake_id_generator.rs +++ b/crates/rust-mcp-extra/src/id_generator/snow_flake_id_generator.rs @@ -65,10 +65,10 @@ impl SnowflakeIdGenerator { let last_ts = self.last_timestamp.load(Ordering::Relaxed); let sequence = if timestamp == last_ts { - // same millisecond — increment sequence + // same millisecond - increment sequence let seq = self.sequence.fetch_add(1, Ordering::Relaxed) & 0xFFF; // 12 bits if seq == 0 { - // Sequence overflow — wait for next ms + // Sequence overflow - wait for next ms while timestamp <= last_ts { timestamp = self.current_timestamp(); } diff --git a/crates/rust-mcp-extra/src/lib.rs b/crates/rust-mcp-extra/src/lib.rs index f4a3d0e..d67bfc9 100644 --- a/crates/rust-mcp-extra/src/lib.rs +++ b/crates/rust-mcp-extra/src/lib.rs @@ -1,5 +1,9 @@ +#[cfg(feature = "auth")] +pub mod auth_provider; pub mod http_adaptors; pub mod id_generator; pub mod sqlite; +#[cfg(feature = "auth")] +pub mod token_verifier; pub use rust_mcp_sdk::id_generator::IdGenerator; diff --git a/crates/rust-mcp-extra/src/token_verifier.rs b/crates/rust-mcp-extra/src/token_verifier.rs new file mode 100644 index 0000000..e3b013f --- /dev/null +++ b/crates/rust-mcp-extra/src/token_verifier.rs @@ -0,0 +1,5 @@ +mod generic_token_verifier; +mod jwt_cache; + +pub use generic_token_verifier::*; +pub use jwt_cache::*; diff --git a/crates/rust-mcp-extra/src/token_verifier/generic_token_verifier.rs b/crates/rust-mcp-extra/src/token_verifier/generic_token_verifier.rs new file mode 100644 index 0000000..579807e --- /dev/null +++ b/crates/rust-mcp-extra/src/token_verifier/generic_token_verifier.rs @@ -0,0 +1,869 @@ +use crate::token_verifier::jwt_cache::JwtCache; +use async_lock::RwLock; +use async_trait::async_trait; +use reqwest::{header::AUTHORIZATION, StatusCode}; +use rust_mcp_sdk::{ + auth::{ + decode_token_header, Audience, AuthInfo, AuthenticationError, IntrospectionResponse, + JsonWebKeySet, OauthTokenVerifier, + }, + mcp_http::error_message_from_response, +}; +use serde_json::Value; +use std::{ + collections::HashMap, + time::{Duration, SystemTime}, +}; +use url::Url; + +const JWKS_REFRESH_TIME: Duration = Duration::from_secs(24 * 60 * 60); // re-fetch jwks every 24 hours +const REMOTE_VERIFICATION_INTERVAL: Duration = Duration::from_secs(15 * 60); // 15 minutes +const JWT_CACHE_CAPACITY: usize = 1000; + +struct JwksCache { + last_updated: Option, + jwks: JsonWebKeySet, +} + +/// Supported OAuth token verification strategies. +/// +/// Each variant represents a different method for validating access tokens, +/// depending on what the authorization server exposes or what your application +/// requires. +pub enum VerificationStrategies { + /// Verifies tokens by calling the authorization server's introspection + /// endpoint, as defined in RFC 7662. + /// + /// This method allows the resource server to validate opaque or JWT tokens + /// by sending them to the introspection URI along with its client credentials. + Introspection { + /// The OAuth introspection endpoint. + introspection_uri: String, + /// Client identifier used to authenticate the introspection request. + client_id: String, + /// Client secret used to authenticate the introspection request. + client_secret: String, + /// Indicates whether the OAuth2 client should use HTTP Basic Authentication when + ///calling the token introspection endpoint. + /// if false: client_id and client_secret will be sent in the POST body instead of using Basic Authentication + use_basic_auth: bool, + /// Optional key-value pairs to include as additional parameters in the + /// body of the token introspection request. + /// Example : ("token_type_hint", "access_token") + extra_params: Option>, + }, + /// Verifies JWT access tokens using the authorization server’s JSON Web Key + /// Set (JWKS) endpoint. + /// + /// This strategy allows fully offline signature validation after retrieving + /// the key set, making it efficient for high-throughput services. + JWKs { + /// The JWKS endpoint URL used to retrieve signing keys. + jwks_uri: String, + }, + /// Verifies tokens by querying the OpenID Connect UserInfo endpoint. + /// + /// This strategy is typically used when token validity is tied to the user's + /// profile information or when the resource server relies on OIDC user data + /// for validation. + UserInfo { userinfo_uri: String }, +} + +/// Options for configuring a token verifier. +/// +/// `TokenVerifierOptions` allows specifying one or more strategies for verifying +/// OAuth access tokens. Multiple strategies can be provided; the verifier will +/// attempt them in order until one succeeds or all fail. +pub struct TokenVerifierOptions { + /// The list of token verification strategies to use. + /// Each strategy defines a different method for validating tokens, such as + /// introspection, JWKS signature validation, or querying the UserInfo endpoint. + /// For optimal performance, it is recommended to include JWKS alongside either introspection or UserInfo. + pub strategies: Vec, + /// Optional audience value to validate against the token's `aud` claim. + pub validate_audience: Option, + /// Optional issuer value to validate against the token's `iss` claim. + pub validate_issuer: Option, + /// Optional capacity for the internal cache, used to reduce unnecessary requests during verification. + pub cache_capacity: Option, +} + +#[derive(Default, Debug)] +struct StrategiesOptions { + pub introspection_uri: Option, + pub introspection_basic_auth: bool, + pub introspect_extra_params: Option>, + pub client_id: Option, + pub client_secret: Option, + pub jwks_uri: Option, + pub userinfo_uri: Option, +} + +impl TokenVerifierOptions { + fn unpack(&mut self) -> Result<(StrategiesOptions, bool), AuthenticationError> { + let mut result = StrategiesOptions::default(); + + let mut has_jwks = false; + let mut has_other = false; + + for strategy in self.strategies.drain(0..) { + match strategy { + VerificationStrategies::Introspection { + introspection_uri, + client_id, + client_secret, + use_basic_auth, + extra_params, + } => { + result.introspection_uri = + Some(Url::parse(&introspection_uri).map_err(|err| { + AuthenticationError::ParsingError(format!( + "Invalid introspection uri: {err}", + )) + })?); + result.client_id = Some(client_id); + result.client_secret = Some(client_secret); + result.introspection_basic_auth = use_basic_auth; + result.introspect_extra_params = extra_params; + has_other = true; + } + VerificationStrategies::JWKs { jwks_uri } => { + result.jwks_uri = Some(Url::parse(&jwks_uri).map_err(|err| { + AuthenticationError::ParsingError(format!("Invalid jwks uri: {err}")) + })?); + has_jwks = true; + } + VerificationStrategies::UserInfo { userinfo_uri } => { + result.userinfo_uri = Some(Url::parse(&userinfo_uri).map_err(|err| { + AuthenticationError::ParsingError(format!("Invalid userinfo uri: {err}")) + })?); + has_other = true; + } + } + } + + Ok((result, has_jwks && has_other)) + } +} + +pub struct GenericOauthTokenVerifier { + /// Optional audience value to validate against the token's `aud` claim. + validate_audience: Option, + /// Optional issuer value to validate against the token's `iss` claim. + validate_issuer: Option, + jwt_cache: Option>, + json_web_key_set: RwLock>, + introspection_uri: Option, + introspection_basic_auth: bool, + introspect_extra_params: Option>, + client_id: Option, + client_secret: Option, + jwks_uri: Option, + userinfo_uri: Option, +} + +impl GenericOauthTokenVerifier { + pub fn new(mut options: TokenVerifierOptions) -> Result { + let (strategy_options, chachable) = options.unpack()?; + + let validate_audience = options.validate_audience.take(); + + let validate_issuer = options + .validate_issuer + .map(|iss| iss.trim_end_matches('/').to_string()); + + // we only need to cache if both jwks and introspection are supported + let jwt_cache = if chachable { + Some(RwLock::new(JwtCache::new( + REMOTE_VERIFICATION_INTERVAL, + options.cache_capacity.unwrap_or(JWT_CACHE_CAPACITY), + ))) + } else { + None + }; + + Ok(Self { + validate_issuer, + validate_audience, + jwt_cache, + json_web_key_set: RwLock::new(None), + introspection_uri: strategy_options.introspection_uri, + introspection_basic_auth: strategy_options.introspection_basic_auth, + introspect_extra_params: strategy_options.introspect_extra_params, + client_id: strategy_options.client_id, + client_secret: strategy_options.client_secret, + jwks_uri: strategy_options.jwks_uri, + userinfo_uri: strategy_options.userinfo_uri, + }) + } + + async fn verify_user_info( + &self, + token: &str, + token_unique_id: Option<&str>, + user_info_endpoint: &Url, + ) -> Result { + // use token_unique_id or get from token header + let token_unique_id = match token_unique_id { + Some(id) => id.to_owned(), + None => { + let header = decode_token_header(token)?; + header.kid.unwrap_or(token.to_string()).to_owned() + } + }; + + let client = reqwest::Client::new(); + println!(">>> user_info_endpoint {:?} ", user_info_endpoint.as_str()); + + let response = client + .get(user_info_endpoint.to_owned()) + .header(AUTHORIZATION, format!("Bearer {token}")) + .send() + .await + .map_err(|err| AuthenticationError::Jwks(err.to_string()))?; + + let status_code = response.status(); + + if !response.status().is_success() { + return Err(AuthenticationError::TokenVerificationFailed { + description: error_message_from_response(response, "Unauthorized!").await, + status_code: Some(status_code.as_u16()), + }); + } + + let json: Value = response.json().await.unwrap(); + + let extra = match json { + Value::Object(map) => Some(map), + _ => None, + }; + + let auth_info: AuthInfo = AuthInfo { + token_unique_id, + client_id: None, + user_id: None, + scopes: None, + expires_at: None, + audience: None, + extra, + }; + + Ok(auth_info) + } + + async fn verify_introspection( + &self, + token: &str, + introspection_endpoint: &Url, + ) -> Result { + let client = reqwest::Client::new(); + + // Form data body + let mut form = HashMap::new(); + form.insert("token", token); + + if !self.introspection_basic_auth { + if let Some(client_id) = self.client_id.as_ref() { + form.insert("client_id", client_id); + }; + if let Some(client_secret) = self.client_secret.as_ref() { + form.insert("client_secret", client_secret); + }; + } + + if let Some(extra_params) = self.introspect_extra_params.as_ref() { + extra_params.iter().for_each(|(key, value)| { + form.insert(key, value); + }); + } + + let mut request = client.post(introspection_endpoint.to_owned()).form(&form); + if self.introspection_basic_auth { + request = request.basic_auth( + self.client_id.clone().unwrap_or_default(), + self.client_secret.clone(), + ); + } + + let response = request + .send() + .await + .map_err(|err| AuthenticationError::Jwks(err.to_string()))?; + + let status_code = response.status(); + if !response.status().is_success() { + let description = response.text().await.unwrap_or("Unauthorized!".to_string()); + return Err(AuthenticationError::TokenVerificationFailed { + description, + status_code: Some(status_code.as_u16()), + }); + } + + let introspect_response: IntrospectionResponse = response + .json() + .await + .map_err(|err| AuthenticationError::Jwks(err.to_string()))?; + + if !introspect_response.active { + return Err(AuthenticationError::InactiveToken); + } + + if let Some(validate_audience) = self.validate_audience.as_ref() { + let Some(token_audience) = introspect_response.audience.as_ref() else { + return Err(AuthenticationError::InvalidToken { + description: "Audience attribute (aud) is missing.", + }); + }; + + if token_audience != validate_audience { + return Err(AuthenticationError::TokenVerificationFailed { description: + format!("None of the provided audiences are allowed. Expected ${validate_audience}, got: ${token_audience}") + , status_code: Some(StatusCode::UNAUTHORIZED.as_u16()) + }); + } + } + + if let Some(validate_issuer) = self.validate_issuer.as_ref() { + let Some(token_issuer) = introspect_response.issuer.as_ref() else { + return Err(AuthenticationError::InvalidToken { + description: "Issuer (iss) is missing.", + }); + }; + + if token_issuer != validate_issuer { + return Err(AuthenticationError::TokenVerificationFailed { + description: format!( + "Issuer is not allowed. Expected ${validate_issuer}, got: ${token_issuer}" + ), + status_code: Some(StatusCode::UNAUTHORIZED.as_u16()), + }); + } + } + + AuthInfo::from_introspection_response(token.to_owned(), introspect_response, None) + } + + async fn populate_jwks(&self, jwks_uri: &Url) -> Result<(), AuthenticationError> { + let response = reqwest::get(jwks_uri.to_owned()) + .await + .map_err(|err| AuthenticationError::Jwks(err.to_string()))?; + let jwks: JsonWebKeySet = response + .json() + .await + .map_err(|err| AuthenticationError::Jwks(err.to_string()))?; + let mut guard = self.json_web_key_set.write().await; + *guard = Some(JwksCache { + last_updated: Some(SystemTime::now()), + jwks, + }); + Ok(()) + } + + async fn verify_jwks(&self, token: &str, jwks: &Url) -> Result { + // read-modify-write pattern + { + let guard = self.json_web_key_set.read().await; + if let Some(cache) = guard.as_ref() { + if let Some(last_updated) = cache.last_updated { + if SystemTime::now() + .duration_since(last_updated) + .unwrap_or(Duration::from_secs(0)) + < JWKS_REFRESH_TIME + { + let token_info = cache.jwks.verify( + token.to_string(), + self.validate_audience.as_ref(), + self.validate_issuer.as_ref(), + )?; + + return AuthInfo::from_token_data(token.to_owned(), token_info, None); + } + } + } + } + + // Refresh JWKS if cache is invalid or missing + self.populate_jwks(jwks).await?; + + // Proceed with verification + let guard = self.json_web_key_set.read().await; + if let Some(cache) = guard.as_ref() { + let token_info = cache.jwks.verify( + token.to_string(), + self.validate_audience.as_ref(), + self.validate_issuer.as_ref(), + )?; + + AuthInfo::from_token_data(token.to_owned(), token_info, None) + } else { + Err(AuthenticationError::Jwks( + "Failed to retrieve or parse JWKS".to_string(), + )) + } + } +} + +#[async_trait] +impl OauthTokenVerifier for GenericOauthTokenVerifier { + async fn verify_token(&self, access_token: String) -> Result { + // perform local jwks verification if supported + if let Some(jwks_endpoint) = self.jwks_uri.as_ref() { + let mut auth_info = self.verify_jwks(&access_token, jwks_endpoint).await?; + + // perform remote verification only if it is supported and jwt is stale + if let Some(jwt_cache) = self.jwt_cache.as_ref() { + // return auth_info if it is recent + if jwt_cache.read().await.is_recent(&auth_info.token_unique_id) { + return Ok(auth_info); + } + + // introspection validation if introspection_uri is provided + if let Some(introspection_endpoint) = self.introspection_uri.as_ref() { + let fresh_auth_info = self + .verify_introspection(&access_token, introspection_endpoint) + .await?; + jwt_cache + .write() + .await + .record(fresh_auth_info.token_unique_id.to_owned()); + return Ok(fresh_auth_info); + } + + // call userInfo endpoint only if introspect strategy is not used + if let Some(user_info_endpoint) = self.userinfo_uri.as_ref() { + let fresh_auth_info = self + .verify_user_info( + &access_token, + Some(&auth_info.token_unique_id), + user_info_endpoint, + ) + .await?; + + auth_info.extra = fresh_auth_info.extra; + jwt_cache + .write() + .await + .record(auth_info.token_unique_id.to_owned()); + return Ok(auth_info); + } + } + + return Ok(auth_info); + } + + // use introspection if jwks is not supported, no caching + if let Some(introspection_endpoint) = self.introspection_uri.as_ref() { + let auth_info = self + .verify_introspection(&access_token, introspection_endpoint) + .await?; + return Ok(auth_info); + } + + // use userInfo endpoint if introspect strategy is not used + if let Some(user_info_endpoint) = self.userinfo_uri.as_ref() { + let auth_info = self + .verify_user_info(&access_token, None, user_info_endpoint) + .await?; + return Ok(auth_info); + } + + Err(AuthenticationError::InvalidToken { + description: "Invalid token verification strategy!", + }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use oauth2_test_server::{OAuthTestServer, OauthEndpoints}; + use rust_mcp_sdk::auth::*; + use serde_json::json; + + async fn token_verifier( + strategies: Vec, + endpoints: &OauthEndpoints, + audience: Option, + ) -> GenericOauthTokenVerifier { + let auth_metadata = AuthMetadataBuilder::new("http://127.0.0.1:3000/mcp") + .issuer(&endpoints.oauth_server) + .authorization_servers(vec![&endpoints.oauth_server]) + .authorization_endpoint(&endpoints.authorize) + .token_endpoint(&endpoints.token) + .scopes_supported(vec!["openid".to_string()]) + .introspection_endpoint(&endpoints.introspect) + .jwks_uri(&endpoints.jwks) + .resource_name("MCP Demo Server".to_string()) + .build() + .unwrap(); + let meta = &auth_metadata.0; + + let token_verifier = GenericOauthTokenVerifier::new(TokenVerifierOptions { + validate_audience: audience, + validate_issuer: Some(meta.issuer.to_string()), + strategies, + cache_capacity: None, + }) + .unwrap(); + token_verifier + } + + #[tokio::test] + async fn test_jwks_strategy() { + let server = OAuthTestServer::start().await; + + let client = server.register_client( + json!({ "scope": "openid", "redirect_uris":["http://localhost:8080/callback"]}), + ); + + let verifier = token_verifier( + vec![VerificationStrategies::JWKs { + jwks_uri: server.endpoints.jwks.clone(), + }], + &server.endpoints, + Some(Audience::Single(client.client_id.clone())), + ) + .await; + + let token = server.generate_jwt(&client, server.jwt_options().user_id("rustmcp").build()); + + let auth_info = verifier.verify_token(token).await.unwrap(); + assert_eq!( + auth_info.audience.as_ref().unwrap().to_string(), + client.client_id + ); + assert_eq!( + auth_info.client_id.as_ref().unwrap().to_string(), + client.client_id + ); + assert_eq!(auth_info.user_id.as_ref().unwrap(), "rustmcp"); + let scopes = auth_info.scopes.as_ref().unwrap(); + assert_eq!(scopes.as_slice(), ["openid"]); + } + + #[tokio::test] + async fn test_userinfo_strategy() { + let server = OAuthTestServer::start().await; + + let client = server.register_client( + json!({ "scope": "openid", "redirect_uris":["http://localhost:8080/callback"]}), + ); + + let verifier = token_verifier( + vec![VerificationStrategies::UserInfo { + userinfo_uri: server.endpoints.userinfo.clone(), + }], + &server.endpoints, + None, + ) + .await; + + let token = server.generate_token(&client, server.jwt_options().user_id("rustmcp").build()); + + let auth_info = verifier.verify_token(token.access_token).await.unwrap(); + + assert!(auth_info.audience.is_none()); + assert_eq!( + auth_info + .extra + .unwrap() + .get("sub") + .unwrap() + .as_str() + .unwrap(), + "rustmcp" + ); + } + + #[tokio::test] + async fn test_introspect_strategy() { + let server = OAuthTestServer::start().await; + + let client = server.register_client( + json!({ "scope": "openid", "redirect_uris":["http://localhost:8080/callback"]}), + ); + + let verifier = token_verifier( + vec![VerificationStrategies::Introspection { + introspection_uri: server.endpoints.introspect.clone(), + client_id: client.client_id.clone(), + client_secret: client.client_secret.as_ref().unwrap().clone(), + use_basic_auth: true, + extra_params: None, + }], + &server.endpoints, + None, + ) + .await; + + let token = server.generate_token(&client, server.jwt_options().user_id("rustmcp").build()); + let auth_info = verifier.verify_token(token.access_token).await.unwrap(); + + assert_eq!( + auth_info.audience.as_ref().unwrap().to_string(), + client.client_id + ); + assert_eq!( + auth_info.client_id.as_ref().unwrap().to_string(), + client.client_id + ); + assert_eq!(auth_info.user_id.as_ref().unwrap(), "rustmcp"); + let scopes = auth_info.scopes.as_ref().unwrap(); + assert_eq!(scopes.as_slice(), ["openid"]); + } + + #[tokio::test] + async fn test_introspect_strategy_with_client_secret_post() { + let server = OAuthTestServer::start().await; + + let client = server.register_client( + json!({ "scope": "openid profile", "redirect_uris":["http://localhost:8080/cb"]}), + ); + + let verifier = token_verifier( + vec![VerificationStrategies::Introspection { + introspection_uri: server.endpoints.introspect.clone(), + client_id: client.client_id.clone(), + client_secret: client.client_secret.as_ref().unwrap().clone(), + use_basic_auth: false, // <--- POST body instead of Basic Auth + extra_params: None, + }], + &server.endpoints, + Some(Audience::Single(client.client_id.clone())), + ) + .await; + + let token = server.generate_token(&client, server.jwt_options().user_id("alice").build()); + + let auth_info = verifier.verify_token(token.access_token).await.unwrap(); + + assert_eq!(auth_info.user_id.as_ref().unwrap(), "alice"); + assert!(auth_info.scopes.unwrap().contains(&"profile".to_string())); + assert_eq!( + auth_info.audience.as_ref().unwrap().to_string(), + client.client_id + ); + } + + #[tokio::test] + async fn test_introspect_rejects_inactive_token() { + let server = OAuthTestServer::start().await; + let client = server + .register_client(json!({ "scope": "openid", "redirect_uris": ["http://localhost"] })); + + let verifier = token_verifier( + vec![VerificationStrategies::Introspection { + introspection_uri: server.endpoints.introspect.clone(), + client_id: client.client_id.clone(), + client_secret: client.client_secret.as_ref().unwrap().clone(), + use_basic_auth: true, + extra_params: None, + }], + &server.endpoints, + None, + ) + .await; + + let token_response = + server.generate_token(&client, server.jwt_options().user_id("bob").build()); + server + .revoke_token(&client, &token_response.access_token) + .await; + + let result = verifier.verify_token(token_response.access_token).await; + assert!(matches!(result, Err(AuthenticationError::InactiveToken))); + } + + #[tokio::test] + async fn test_expired_token_rejected_by_jwks_and_introspection() { + let server = OAuthTestServer::start().await; + let client = server.register_client( + json!({ "scope": "openid email", "redirect_uris": ["http://localhost"] }), + ); + + // Use both strategies → expect rejection on expiration alone + let verifier = token_verifier( + vec![ + VerificationStrategies::JWKs { + jwks_uri: server.endpoints.jwks.clone(), + }, + VerificationStrategies::Introspection { + introspection_uri: server.endpoints.introspect.clone(), + client_id: client.client_id.clone(), + client_secret: client.client_secret.as_ref().unwrap().clone(), + use_basic_auth: true, + extra_params: None, + }, + ], + &server.endpoints, + Some(Audience::Single(client.client_id.clone())), + ) + .await; + + // Generate short-lived token + let short_lived = server + .jwt_options() + .user_id("charlie") + .expires_in(1) + .build(); + let token = server.generate_token(&client, short_lived); + + // Wait for expiry + tokio::time::sleep(tokio::time::Duration::from_millis(1500)).await; + + // JWKS should reject immediately (exp validation) + // But since fallback is enabled, it hits introspection → active: false → error + let err1 = verifier + .verify_token(token.access_token.clone()) + .await + .unwrap_err(); + assert!(matches!(err1, AuthenticationError::InactiveToken)); + + // Now revoke it (expired + revoked) → still InactiveToken (no special handling needed) + server.revoke_token(&client, &token.access_token).await; + let err2 = verifier.verify_token(token.access_token).await.unwrap_err(); + assert!(matches!(err2, AuthenticationError::InactiveToken)); + } + + #[tokio::test] + async fn test_jwks_and_introspection_cache_works() { + let server = OAuthTestServer::start().await; + let client = server + .register_client(json!({ "scope": "openid", "redirect_uris": ["http://localhost"] })); + + let verifier = token_verifier( + vec![ + VerificationStrategies::JWKs { + jwks_uri: server.endpoints.jwks.clone(), + }, + VerificationStrategies::Introspection { + introspection_uri: server.endpoints.introspect.clone(), + client_id: client.client_id.clone(), + client_secret: client.client_secret.as_ref().unwrap().clone(), + use_basic_auth: true, + extra_params: None, + }, + ], + &server.endpoints, + None, + ) + .await; + + let token = server.generate_token(&client, server.jwt_options().user_id("dave").build()); + + // First call → goes through full flow + let info1 = verifier + .verify_token(token.access_token.clone()) + .await + .unwrap(); + + // Second call → should hit cache (no network) + let info2 = verifier + .verify_token(token.access_token.clone()) + .await + .unwrap(); + + assert_eq!(info1.user_id, info2.user_id); + assert_eq!(info1.token_unique_id, info2.token_unique_id); + } + + #[tokio::test] + async fn test_audience_validation_rejects_wrong_aud() { + let server = OAuthTestServer::start().await; + let client = server + .register_client(json!({ "scope": "openid", "redirect_uris": ["http://localhost"] })); + + let verifier = token_verifier( + vec![VerificationStrategies::Introspection { + introspection_uri: server.endpoints.introspect.clone(), + client_id: client.client_id.clone(), + client_secret: client.client_secret.as_ref().unwrap().clone(), + use_basic_auth: true, + extra_params: None, + }], + &server.endpoints, + Some(Audience::Single("wrong-client-id-999".to_string())), + ) + .await; + + let token = server.generate_token(&client, server.jwt_options().user_id("eve").build()); + + let err = verifier.verify_token(token.access_token).await.unwrap_err(); + assert!(matches!( + err, + AuthenticationError::TokenVerificationFailed { .. } + )); + } + + #[tokio::test] + async fn test_issuer_validation_rejects_wrong_iss() { + let server = OAuthTestServer::start().await; + let client = server + .register_client(json!({ "scope": "openid", "redirect_uris": ["http://localhost"] })); + + let _verifier = token_verifier( + vec![VerificationStrategies::JWKs { + jwks_uri: server.endpoints.jwks.clone(), + }], + &server.endpoints, + None, + ) + .await; + + // Force wrong expected issuer + let wrong_verifier = GenericOauthTokenVerifier::new(TokenVerifierOptions { + strategies: vec![VerificationStrategies::JWKs { + jwks_uri: server.endpoints.jwks.clone(), + }], + validate_audience: None, + validate_issuer: Some("https://wrong-issuer.example.com".to_string()), + cache_capacity: None, + }) + .unwrap(); + + let token = server.generate_token(&client, server.jwt_options().user_id("frank").build()); + + let err = wrong_verifier + .verify_token(token.access_token) + .await + .unwrap_err(); + assert!(matches!( + err, + AuthenticationError::TokenVerificationFailed { .. } + )); + } + + #[tokio::test] + async fn test_userinfo_enriches_jwt_claims() { + let server = OAuthTestServer::start().await; + let client = server.register_client( + json!({ "scope": "openid profile email", "redirect_uris": ["http://localhost"] }), + ); + + let verifier = token_verifier( + vec![ + VerificationStrategies::JWKs { + jwks_uri: server.endpoints.jwks.clone(), + }, + VerificationStrategies::UserInfo { + userinfo_uri: server.endpoints.userinfo.clone(), + }, + ], + &server.endpoints, + None, + ) + .await; + + let token = server.generate_token(&client, server.jwt_options().user_id("grace").build()); + + let auth_info = verifier.verify_token(token.access_token).await.unwrap(); + + let extra = auth_info.extra.unwrap(); + assert_eq!( + extra.get("email").unwrap().as_str().unwrap(), + "test@example.com" + ); + assert_eq!(extra.get("name").unwrap().as_str().unwrap(), "Test User"); + assert!(extra.get("picture").is_some()); + } +} diff --git a/crates/rust-mcp-extra/src/token_verifier/jwt_cache.rs b/crates/rust-mcp-extra/src/token_verifier/jwt_cache.rs new file mode 100644 index 0000000..f4c5259 --- /dev/null +++ b/crates/rust-mcp-extra/src/token_verifier/jwt_cache.rs @@ -0,0 +1,67 @@ +use std::collections::{HashMap, VecDeque}; +use std::time::{Duration, Instant}; + +/// JWT introspection cache with TTL and max capacity +pub struct JwtCache { + map: HashMap, // Key -> last introspection time + order: VecDeque, // Keys in insertion order + remote_verification_interval: Duration, + capacity: usize, +} + +impl JwtCache { + /// Create a new cache with given TTL and capacity + pub fn new(remote_verification_interval: Duration, capacity: usize) -> Self { + Self { + map: HashMap::with_capacity(capacity), + order: VecDeque::with_capacity(capacity), + remote_verification_interval, + capacity, + } + } + + pub fn is_recent(&self, key: &str) -> bool { + self.map + .get(key) + .is_some_and(|t| t.elapsed() <= self.remote_verification_interval) + } + + /// Record , updates timestamp or adds new entry + pub fn record(&mut self, key: String) { + // Remove expired entries first + self.remove_expired(); + + if self.map.contains_key(&key) { + // Update timestamp (no promotion in order) + self.map.insert(key.clone(), Instant::now()); + } else { + // Evict oldest if over capacity + if self.map.len() >= self.capacity { + if let Some(oldest) = self.order.pop_front() { + self.map.remove(&oldest); + } + } + self.map.insert(key.clone(), Instant::now()); + self.order.push_back(key); + } + } + + /// Remove expired entries + pub fn remove_expired(&mut self) { + let now = Instant::now(); + let mut expired = Vec::new(); + + for key in &self.order { + if let Some(&last) = self.map.get(key).as_ref() { + if now.duration_since(last.to_owned()) > self.remote_verification_interval { + expired.push(key.clone()); + } + } + } + + for key in expired { + self.map.remove(&key); + self.order.retain(|k| *k != key); + } + } +} diff --git a/crates/rust-mcp-macros/tests/common/common.rs b/crates/rust-mcp-macros/tests/common/common.rs index d6bae2e..1133d64 100644 --- a/crates/rust-mcp-macros/tests/common/common.rs +++ b/crates/rust-mcp-macros/tests/common/common.rs @@ -1,7 +1,6 @@ -use std::str::FromStr; - use rust_mcp_macros::JsonSchema; use rust_mcp_schema::RpcError; +use std::str::FromStr; #[derive(::serde::Deserialize, ::serde::Serialize, Clone, Debug, JsonSchema)] /// Represents a text replacement operation. diff --git a/crates/rust-mcp-sdk/Cargo.toml b/crates/rust-mcp-sdk/Cargo.toml index 4f4238a..cbff768 100644 --- a/crates/rust-mcp-sdk/Cargo.toml +++ b/crates/rust-mcp-sdk/Cargo.toml @@ -34,30 +34,33 @@ bytes.workspace = true # rustls = { workspace = true, optional = true } hyper = { version = "1.6.0", optional = true } -http = { version ="1.3", optional = true } -http-body-util = { version ="0.1", optional = true } -http-body = { version ="1.0", optional = true } - -[dev-dependencies] -wiremock = "0.5" +http = { workspace = true, optional = true } +http-body-util = { workspace = true, optional = true } +http-body = { workspace = true, optional = true } +url = {workspace = true, optional=true} +jsonwebtoken = {version="10.1", optional=true, features=["aws_lc_rs"]} reqwest = { workspace = true, default-features = false, features = [ "stream", "rustls-tls", "json", "cookies", "multipart", -] } +], optional = true } + +[dev-dependencies] +wiremock = "0.5" tempfile = "3.23.0" tracing-subscriber = { workspace = true, features = [ "env-filter", "std", "fmt", -] } +]} [features] default = [ "client", "server", + "auth", "macros", "stdio", "sse", @@ -70,6 +73,7 @@ default = [ sse = ["rust-mcp-transport/sse","http","http-body","http-body-util"] streamable-http = ["rust-mcp-transport/streamable-http","http","http-body","http-body-util"] stdio = ["rust-mcp-transport/stdio"] +auth=["url","jsonwebtoken/aws_lc_rs","reqwest"] server = [] # Server feature client = [] # Client feature diff --git a/crates/rust-mcp-sdk/README.md b/crates/rust-mcp-sdk/README.md index 2c70c3e..d92d964 100644 --- a/crates/rust-mcp-sdk/README.md +++ b/crates/rust-mcp-sdk/README.md @@ -21,37 +21,35 @@ Leveraging the [rust-mcp-schema](https://github.com/rust-mcp-stack/rust-mcp-sche **rust-mcp-sdk** supports all three official versions of the MCP protocol. By default, it uses the **2025-06-18** version, but earlier versions can be enabled via Cargo features. - - -This project supports following transports: -- **Stdio** (Standard Input/Output) -- **Streamable HTTP** -- **SSE** (Server-Sent Events) - - 🚀 The **rust-mcp-sdk** includes a lightweight [Axum](https://github.com/tokio-rs/axum) based server that handles all core functionality seamlessly. Switching between `stdio` and `Streamable HTTP` is straightforward, requiring minimal code changes. The server is designed to efficiently handle multiple concurrent client connections and offers built-in support for SSL. -**MCP Streamable HTTP Support** -- ✅ Streamable HTTP Support for MCP Servers +**Features** +- ✅ Stdio, SSE and Streamable HTTP Support +- ✅ Supports multiple MCP protocol versions - ✅ DNS Rebinding Protection - ✅ Batch Messages - ✅ Streaming & non-streaming JSON response -- ✅ Streamable HTTP Support for MCP Clients - ✅ Resumability -- ⬜ Oauth Authentication +- ✅ OAuth Authentication for MCP Servers + - ✅ [Remote Oauth Provider](crates/rust-mcp-sdk/src/auth/auth_provider/remote_auth_provider.rs) (for any provider with DCR support) + - ✅ **Keycloak** Provider (via [rust-mcp-extra](crates/rust-mcp-extra/README.md#keycloak)) + - ✅ **WorkOS** Authkit Provider (via [rust-mcp-extra](crates/rust-mcp-extra/README.md#workos-authkit)) + - ✅ **Scalekit** Authkit Provider (via [rust-mcp-extra](crates/rust-mcp-extra/README.md#scalekit)) +- ⬜ OAuth Authentication for MCP Clients **⚠️** Project is currently under development and should be used at your own risk. ## Table of Contents +- [Getting Started](#getting-started) - [Usage Examples](#usage-examples) - [MCP Server (stdio)](#mcp-server-stdio) - [MCP Server (Streamable HTTP)](#mcp-server-streamable-http) - [MCP Client (stdio)](#mcp-client-stdio) - - [MCP Client (Streamable HTTP)](#mcp-client_streamable-http)) + - [MCP Client (Streamable HTTP)](#mcp-client-streamable-http) - [MCP Client (sse)](#mcp-client-sse) +- [Authentication](#authentication) - [Macros](#macros) -- [Getting Started](#getting-started) - [HyperServerOptions](#hyperserveroptions) - [Security Considerations](#security-considerations) - [Cargo features](#cargo-features) @@ -68,6 +66,12 @@ This project supports following transports: - [Development](#development) - [License](#license) + +## Getting Started + +If you are looking for a step-by-step tutorial on how to get started with `rust-mcp-sdk` , please see : [Getting Started MCP Server](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/doc/getting-started-mcp-server.md) + + ## Usage Examples ### MCP Server (stdio) @@ -387,6 +391,26 @@ Creating an MCP client using the `rust-mcp-sdk` with the SSE transport is almost 👉 see [examples/simple-mcp-client-sse](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/simple-mcp-client-sse) for a complete working example. +## Authentication +MCP server can verify tokens issued by other systems, integrate with external identity providers, or manage the entire authentication process itself. Each option offers a different balance of simplicity, security, and control. + + ### RemoteAuthProvider + [RemoteAuthProvider](src/mcp_http/auth/auth_provider/remote_auth_provider.rs) RemoteAuthProvider enables authentication with identity providers that support Dynamic Client Registration (DCR) such as KeyCloak and WorkOS AuthKit, letting MCP clients auto-register and obtain credentials without manual setup. + +👉 See the [server-oauth-remote](examples/auth/server-oauth-remote) example for how to use RemoteAuthProvider with a DCR-capable remote provider. + +👉 [rust-mcp-extra](https://crates.io/crates/rust-mcp-extra) also offers drop-in auth providers for common identity platforms, working seamlessly with rust-mcp-sdk: + - [Keycloack auth example](crates/rust-mcp-extra/README.md#keycloak) + - [WorkOS autn example](crates/rust-mcp-extra/README.md#workos-authkit) + + + + ### OAuthProxy + OAuthProxy enables authentication with OAuth providers that don’t support Dynamic Client Registration (DCR).It accepts any client registration request, handles the DCR on your server side and then uses your pre-registered app credentials upstream.The proxy also forwards callbacks, allowing dynamic redirect URIs to work with providers that require fixed ones. + +> ⚠️ OAuthProxy support is still in development, please use RemoteAuthProvider for now. + + ## Macros [rust-mcp-sdk](https://github.com/rust-mcp-stack/rust-mcp-sdk) includes several helpful macros that simplify common tasks when building MCP servers and clients. For example, they can automatically generate tool specifications and tool schemas right from your structs, or assist with elicitation requests and responses making them completely type safe. @@ -495,10 +519,6 @@ let user_info = UserInfo::from_content_map(result.content)?; 💻 For mre info please see : - https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/crates/rust-mcp-macros -## Getting Started - -If you are looking for a step-by-step tutorial on how to get started with `rust-mcp-sdk` , please see : [Getting Started MCP Server](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/doc/getting-started-mcp-server.md) - ## HyperServerOptions HyperServer is a lightweight Axum-based server that streamlines MCP servers by supporting **Streamable HTTP** and **SSE** transports. It supports simultaneous client connections, internal session management, and includes built-in security features like DNS rebinding protection and more. @@ -588,6 +608,9 @@ pub struct HyperServerOptions { /// Optional custom path for the MCP messages endpoint for sse (default: `/messages`) /// Applicable only if sse_support is true pub custom_messages_endpoint: Option, + + /// Optional authentication provider for protecting MCP server. + pub auth: Option>, } ``` @@ -625,7 +648,7 @@ The `rust-mcp-sdk` crate provides several features that can be enabled or disabl - `2025_03_26` : Activates MCP Protocol version 2025-03-26 - `2024_11_05` : Activates MCP Protocol version 2024-11-05 -> Note: MCP protocol versions are mutually exclusive—only one can be active at any given time. +> Note: MCP protocol versions are mutually exclusive-only one can be active at any given time. ### Default Features diff --git a/crates/rust-mcp-sdk/src/auth.rs b/crates/rust-mcp-sdk/src/auth.rs new file mode 100644 index 0000000..61651ef --- /dev/null +++ b/crates/rust-mcp-sdk/src/auth.rs @@ -0,0 +1,24 @@ +mod auth_info; + +#[cfg(feature = "auth")] +mod auth_provider; +#[cfg(feature = "auth")] +mod error; +#[cfg(feature = "auth")] +mod metadata; +mod spec; +#[cfg(feature = "auth")] +mod token_verifier; + +pub use auth_info::AuthInfo; +#[cfg(feature = "auth")] +pub use auth_provider::*; +#[cfg(feature = "auth")] +pub use error::*; +#[cfg(feature = "auth")] +pub use metadata::*; +pub use spec::Audience; +#[cfg(feature = "auth")] +pub use spec::*; +#[cfg(feature = "auth")] +pub use token_verifier::*; diff --git a/crates/rust-mcp-sdk/src/auth/auth_info.rs b/crates/rust-mcp-sdk/src/auth/auth_info.rs new file mode 100644 index 0000000..01a4136 --- /dev/null +++ b/crates/rust-mcp-sdk/src/auth/auth_info.rs @@ -0,0 +1,105 @@ +#[cfg(feature = "auth")] +use crate::auth::{AuthClaims, AuthenticationError, IntrospectionResponse}; +use crate::{auth::Audience, utils::unix_timestamp_to_systemtime}; +#[cfg(feature = "auth")] +use jsonwebtoken::TokenData; +use serde::{Deserialize, Serialize}; +use serde_json::Map; +use std::time::SystemTime; + +/// Information about a validated access token, provided to request handlers. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct AuthInfo { + /// Contains a unique id for jwt + /// use jti claim if available, otherwise use token or a reliable hash of token + pub token_unique_id: String, + + /// The client ID associated with this token. + #[serde(skip_serializing_if = "std::option::Option::is_none")] + pub client_id: Option, + + /// Optional user identifier for the token + #[serde(skip_serializing_if = "std::option::Option::is_none")] + pub user_id: Option, + + /// Scopes associated with this token. + #[serde(skip_serializing_if = "std::option::Option::is_none")] + pub scopes: Option>, + + /// When the token expires (in seconds since epoch). + /// This field is optional, as the token may not have an expiration time. + #[serde(skip_serializing_if = "std::option::Option::is_none")] + pub expires_at: Option, + + /// The RFC 8707 resource server identifier for which this token is valid. + /// If set, this MUST match the MCP server's resource identifier (minus hash fragment). + #[serde(skip_serializing_if = "std::option::Option::is_none")] + pub audience: Option, + + /// Additional data associated with the token. + /// This field can be used to attach any extra data to the auth info. + #[serde(flatten, skip_serializing_if = "std::option::Option::is_none")] + pub extra: Option>, +} + +#[cfg(feature = "auth")] +impl AuthInfo { + pub fn from_token_data( + token: String, + token_data: TokenData, + extra: Option>, + ) -> Result { + let client_id = token_data.claims.authorized_party.or(token_data + .claims + .client_id + .or(token_data.claims.application_id)); + + let scopes = token_data + .claims + .scope + .map(|c| c.split(" ").map(|s| s.to_string()).collect::>()); + + let expires_at = token_data + .claims + .expiration + .map(|v| unix_timestamp_to_systemtime(v as u64)); + + let token_unique_id = token_data.claims.jwt_id.unwrap_or(token); + + Ok(AuthInfo { + token_unique_id, + client_id, + scopes, + user_id: token_data.claims.subject, + expires_at, + audience: token_data.claims.audience, + extra, + }) + } + + pub fn from_introspection_response( + token: String, + data: IntrospectionResponse, + extra: Option>, + ) -> Result { + let scopes = data + .scope + .map(|c| c.split(" ").map(|s| s.to_string()).collect::>()); + + let expires_at = data + .expiration + .map(|v| unix_timestamp_to_systemtime(v as u64)); + + let token_unique_id = data.jwt_id.unwrap_or(token); + + Ok(AuthInfo { + token_unique_id, + client_id: data.client_id, + user_id: data.subject, + scopes, + expires_at, + audience: data.audience, + extra, + }) + } +} diff --git a/crates/rust-mcp-sdk/src/auth/auth_provider.rs b/crates/rust-mcp-sdk/src/auth/auth_provider.rs new file mode 100644 index 0000000..0f7ab20 --- /dev/null +++ b/crates/rust-mcp-sdk/src/auth/auth_provider.rs @@ -0,0 +1,98 @@ +mod remote_auth_provider; +use crate::auth::OauthEndpoint; +use crate::auth::{AuthInfo, AuthenticationError}; +use crate::mcp_http::{GenericBody, GenericBodyExt, McpAppState}; +use crate::mcp_server::error::TransportServerError; +use async_trait::async_trait; +use http::Method; +pub use remote_auth_provider::*; +use std::collections::HashMap; +use std::sync::Arc; + +#[async_trait] +pub trait AuthProvider: Send + Sync { + async fn verify_token(&self, access_token: String) -> Result; + + /// Returns an optional list of scopes required to access this resource. + /// If this function returns `Some(scopes)`, the authenticated user’s token + /// must include **all** of the listed scopes. + /// If any are missing, the request will be rejected with a `403 Forbidden` response. + fn required_scopes(&self) -> Option<&Vec> { + None + } + + /// Returns the configured OAuth endpoints for this provider. + /// + /// - Key: endpoint path as a string (e.g., "/oauth/token") + /// - Value: corresponding `OauthEndpoint` configuration + /// + /// Returns `None` if no endpoints are configured. + fn auth_endpoints(&self) -> Option<&HashMap>; + + /// Handles an incoming HTTP request for this authentication provider. + /// + /// This is the main entry point for processing OAuth requests, + /// such as token issuance, authorization code exchange, or revocation. + async fn handle_request( + &self, + request: http::Request<&str>, + state: Arc, + ) -> Result, TransportServerError>; + + /// Returns the `OauthEndpoint` associated with the given request path. + /// + /// This method looks up the request URI path in the endpoints returned by `auth_endpoints()`. + /// + /// ⚠️ Note: + /// - If your token and revocation endpoints share the same URL path (valid in some implementations), + /// you may want to override this method to correctly distinguish the request type + /// (e.g., based on request parameters like `grant_type` vs `token`). + fn endpoint_type(&self, request: &http::Request<&str>) -> Option<&OauthEndpoint> { + let endpoints = self.auth_endpoints()?; + endpoints.get(request.uri().path()) + } + + /// Returns the absolute URL of this resource's OAuth 2.0 Protected Resource Metadata document. + /// + /// This corresponds to the `resource_metadata` parameter defined in + /// [RFC 9531 - OAuth 2.0 Protected Resource Metadata](https://datatracker.ietf.org/doc/html/rfc9531). + /// + /// The returned URL is an **absolute** URL (including scheme and host), for example: + /// `https://api.example.com/.well-known/oauth-protected-resource`. + /// + fn protected_resource_metadata_url(&self) -> Option<&str>; + + fn validate_allowed_methods( + &self, + endpoint: &OauthEndpoint, + method: &Method, + ) -> Option> { + let allowed_methods = match endpoint { + OauthEndpoint::AuthorizationEndpoint => { + vec![Method::GET, Method::HEAD, Method::OPTIONS] + } + OauthEndpoint::TokenEndpoint => vec![Method::POST, Method::OPTIONS], + OauthEndpoint::RegistrationEndpoint => vec![ + Method::POST, + Method::GET, + Method::PUT, + Method::PATCH, + Method::DELETE, + Method::OPTIONS, + ], + OauthEndpoint::RevocationEndpoint => vec![Method::POST, Method::OPTIONS], + OauthEndpoint::IntrospectionEndpoint => vec![Method::POST, Method::OPTIONS], + OauthEndpoint::AuthorizationServerMetadata => { + vec![Method::GET, Method::HEAD, Method::OPTIONS] + } + OauthEndpoint::ProtectedResourceMetadata => { + vec![Method::GET, Method::HEAD, Method::OPTIONS] + } + }; + + if !allowed_methods.contains(method) { + return Some(GenericBody::create_405_response(method, &allowed_methods)); + } + None + } +} diff --git a/crates/rust-mcp-sdk/src/auth/auth_provider/remote_auth_provider.rs b/crates/rust-mcp-sdk/src/auth/auth_provider/remote_auth_provider.rs new file mode 100644 index 0000000..ce4fa56 --- /dev/null +++ b/crates/rust-mcp-sdk/src/auth/auth_provider/remote_auth_provider.rs @@ -0,0 +1,187 @@ +use crate::{ + auth::{ + create_protected_resource_metadata_url, AuthInfo, AuthProvider, AuthenticationError, + AuthorizationServerMetadata, OauthEndpoint, OauthProtectedResourceMetadata, + OauthTokenVerifier, WELL_KNOWN_OAUTH_AUTHORIZATION_SERVER, + }, + mcp_http::{ + middleware::CorsMiddleware, url_base, GenericBody, GenericBodyExt, McpAppState, Middleware, + }, + mcp_server::error::{TransportServerError, TransportServerResult}, +}; +use async_trait::async_trait; +use bytes::Bytes; +use http::{header::CONTENT_TYPE, StatusCode}; +use http_body_util::{BodyExt, Full}; +use reqwest::Client; +use std::{collections::HashMap, sync::Arc}; + +/// Represents a **Remote OAuth authentication provider** integrated with the MCP server. +/// This struct defines how the MCP server interacts with an external identity provider +/// that supports **Dynamic Client Registration (DCR)**. +/// The [`RemoteAuthProvider`] enables enterprise-grade authentication by leveraging +/// external OAuth infrastructure, while maintaining secure token verification and +/// identity validation within the MCP server. +pub struct RemoteAuthProvider { + auth_server_meta: AuthorizationServerMetadata, + protected_resource_meta: OauthProtectedResourceMetadata, + token_verifier: Box, + endpoint_map: HashMap, + required_scopes: Option>, + protected_resource_metadata_url: String, +} + +impl RemoteAuthProvider { + pub fn new( + auth_server_meta: AuthorizationServerMetadata, + protected_resource_meta: OauthProtectedResourceMetadata, + token_verifier: Box, + required_scopes: Option>, + ) -> Self { + let mut endpoint_map = HashMap::new(); + endpoint_map.insert( + WELL_KNOWN_OAUTH_AUTHORIZATION_SERVER.to_string(), + OauthEndpoint::AuthorizationServerMetadata, + ); + + let resource_url = &protected_resource_meta.resource; + let relative_url = create_protected_resource_metadata_url(resource_url.path()); + let base_url = url_base(resource_url); + let protected_resource_metadata_url = + format!("{}{relative_url}", base_url.trim_end_matches('/')); + + endpoint_map.insert(relative_url, OauthEndpoint::ProtectedResourceMetadata); + + Self { + auth_server_meta, + protected_resource_meta, + token_verifier, + endpoint_map, + required_scopes, + protected_resource_metadata_url, + } + } + + pub async fn with_remote_metadata_url( + authorization_server_metadata_url: &str, + protected_resource_meta: OauthProtectedResourceMetadata, + token_verifier: Box, + required_scopes: Option>, + ) -> Result { + let client = Client::new(); + + let auth_server_meta = client + .get(authorization_server_metadata_url) + .send() + .await? + .json::() + .await?; + + Ok(Self::new( + auth_server_meta, + protected_resource_meta, + token_verifier, + required_scopes, + )) + } + + fn handle_authorization_server_metadata( + response_str: String, + ) -> TransportServerResult> { + let body = Full::new(Bytes::from(response_str)) + .map_err(|err| TransportServerError::HttpError(err.to_string())) + .boxed(); + http::Response::builder() + .status(StatusCode::OK) + .header(CONTENT_TYPE, "application/json") + .body(body) + .map_err(|err| TransportServerError::HttpError(err.to_string())) + } + + fn handle_protected_resource_metadata( + response_str: String, + ) -> TransportServerResult> { + use http_body_util::BodyExt; + + let body = Full::new(Bytes::from(response_str)) + .map_err(|err| TransportServerError::HttpError(err.to_string())) + .boxed(); + http::Response::builder() + .status(StatusCode::OK) + .header(CONTENT_TYPE, "application/json") + .body(body) + .map_err(|err| TransportServerError::HttpError(err.to_string())) + } +} + +#[async_trait] +impl AuthProvider for RemoteAuthProvider { + fn protected_resource_metadata_url(&self) -> Option<&str> { + Some(self.protected_resource_metadata_url.as_str()) + } + + async fn verify_token(&self, access_token: String) -> Result { + self.token_verifier.verify_token(access_token).await + } + + fn required_scopes(&self) -> Option<&Vec> { + self.required_scopes.as_ref() + } + + async fn handle_request( + &self, + request: http::Request<&str>, + state: Arc, + ) -> Result, TransportServerError> { + let Some(endpoint) = self.endpoint_type(&request) else { + return http::Response::builder() + .status(StatusCode::NOT_FOUND) + .body(GenericBody::empty()) + .map_err(|err| TransportServerError::HttpError(err.to_string())); + }; + + // return early if method is not allowed + if let Some(response) = self.validate_allowed_methods(endpoint, request.method()) { + return Ok(response); + } + + match endpoint { + OauthEndpoint::AuthorizationServerMetadata => { + let json_payload = serde_json::to_string(&self.auth_server_meta) + .map_err(|err| TransportServerError::HttpError(err.to_string()))?; + let cors = &CorsMiddleware::default(); + cors.handle( + request, + state, + Box::new(move |_req, _state| { + Box::pin( + async move { Self::handle_authorization_server_metadata(json_payload) }, + ) + }), + ) + .await + } + OauthEndpoint::ProtectedResourceMetadata => { + let json_payload = serde_json::to_string(&self.protected_resource_meta) + .map_err(|err| TransportServerError::HttpError(err.to_string()))?; + + let cors = &CorsMiddleware::default(); + cors.handle( + request, + state, + Box::new(move |_req, _state| { + Box::pin( + async move { Self::handle_protected_resource_metadata(json_payload) }, + ) + }), + ) + .await + } + _ => Ok(GenericBody::create_404_response()), + } + } + + fn auth_endpoints(&self) -> Option<&HashMap> { + Some(&self.endpoint_map) + } +} diff --git a/crates/rust-mcp-sdk/src/auth/error.rs b/crates/rust-mcp-sdk/src/auth/error.rs new file mode 100644 index 0000000..3ac0a87 --- /dev/null +++ b/crates/rust-mcp-sdk/src/auth/error.rs @@ -0,0 +1,62 @@ +use serde::Serialize; +use serde_json::{json, Value}; +use thiserror::Error; + +#[derive(Debug, Error, Clone, Serialize)] +#[serde(tag = "error", rename_all = "snake_case")] +pub enum AuthenticationError { + #[error("No token verification endpoint available in metadata.")] + NoIntrospectionEndpoint, + + #[error("failed to retrieve JWKS from the authorization server : {0}")] + Jwks(String), + + #[error("{description}")] + InvalidToken { description: &'static str }, + + #[error("Inactive Token")] + InactiveToken, + + #[error("Resource indicator (aud) missing.")] + AudiencesAttributeMissing, + + #[error( + "Insufficient scope: you do not have the necessary permissions to perform this action." + )] + InsufficientScope, + + #[error("None of the provided audiences are allowed. Expected ${expected}, got: ${received}")] + AudienceNotAllowed { expected: String, received: String }, + + #[error("Invalid or expired token: {0}")] + InvalidOrExpiredToken(String), + + #[error("{description}")] + TokenVerificationFailed { + description: String, + status_code: Option, + }, + + #[error("{description}")] + ServerError { description: String }, + + #[error("{0}")] + ParsingError(String), + + #[error("{0}")] + NotFound(String), +} + +impl AuthenticationError { + pub fn as_json_value(&self) -> Value { + let serialized = serde_json::to_value(self).unwrap_or(Value::Null); + let error_name = serialized + .get("error") + .and_then(|v| v.as_str()) + .unwrap_or("unknown_error"); + json!({ + "error": error_name, + "error_description": self.to_string() + }) + } +} diff --git a/crates/rust-mcp-sdk/src/auth/metadata.rs b/crates/rust-mcp-sdk/src/auth/metadata.rs new file mode 100644 index 0000000..bd7ff49 --- /dev/null +++ b/crates/rust-mcp-sdk/src/auth/metadata.rs @@ -0,0 +1,685 @@ +use std::borrow::Cow; + +use crate::{ + auth::{AuthorizationServerMetadata, OauthProtectedResourceMetadata}, + error::McpSdkError, + utils::join_url, +}; +use reqwest::Client; +use serde::{Deserialize, Serialize}; +use serde_json::Value; +use thiserror::Error; +use url::Url; + +pub const WELL_KNOWN_OAUTH_AUTHORIZATION_SERVER: &str = "/.well-known/oauth-authorization-server"; +pub const OAUTH_PROTECTED_RESOURCE_BASE: &str = "/.well-known/oauth-protected-resource"; + +#[allow(unused)] +#[derive(Hash, Eq, PartialEq, Clone)] +pub enum OauthEndpoint { + AuthorizationEndpoint, + TokenEndpoint, + RegistrationEndpoint, + RevocationEndpoint, + IntrospectionEndpoint, + AuthorizationServerMetadata, + ProtectedResourceMetadata, +} + +#[derive(Debug, Error)] +pub enum AuthMetadateError { + #[error("Url Parse Error: {0}")] + Transport(#[from] url::ParseError), +} + +pub struct AuthMetadataEndpoints { + pub protected_resource_endpoint: String, + pub authorization_server_endpoint: String, +} + +// Builder struct to construct both OAuthMetadata and OAuthProtectedResourceMetadata + +#[derive(Default)] +pub struct AuthMetadataBuilder<'a> { + // OAuthMetadata-specific fields + issuer: Option>, + authorization_endpoint: Option>, + token_endpoint: Option>, + registration_endpoint: Option>, + revocation_endpoint: Option>, + introspection_endpoint: Option>, + scopes_supported: Option>>, + + response_types_supported: Option>>, + response_modes_supported: Option>>, + grant_types_supported: Option>>, + token_endpoint_auth_methods_supported: Option>>, + token_endpoint_auth_signing_alg_values_supported: Option>>, + revocation_endpoint_auth_signing_alg_values_supported: Option>>, + revocation_endpoint_auth_methods_supported: Option>>, + introspection_endpoint_auth_methods_supported: Option>>, + introspection_endpoint_auth_signing_alg_values_supported: Option>>, + code_challenge_methods_supported: Option>>, + service_documentation: Option>, + + // OAuthProtectedResourceMetadata-specific fields + resource: Option>, + authorization_servers: Option>>, + required_scopes: Option>>, + + jwks_uri: Option>, + bearer_methods_supported: Option>>, + resource_signing_alg_values_supported: Option>>, + resource_name: Option>, + resource_documentation: Option>, + resource_policy_uri: Option>, + resource_tos_uri: Option>, + tls_client_certificate_bound_access_tokens: Option, + authorization_details_types_supported: Option>>, + dpop_signing_alg_values_supported: Option>>, + dpop_bound_access_tokens_required: Option, + + // none-standard + userinfo_endpoint: Option>, +} + +// Result struct to hold both metadata types +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct OauthMetadata { + authorization_server_metadata: AuthorizationServerMetadata, + protected_resource_metadata: OauthProtectedResourceMetadata, +} + +impl OauthMetadata { + pub fn protected_resource_metadata(&self) -> &OauthProtectedResourceMetadata { + &self.protected_resource_metadata + } + + pub fn authorization_server_metadata(&self) -> &AuthorizationServerMetadata { + &self.authorization_server_metadata + } + + pub fn endpoints(&self) -> AuthMetadataEndpoints { + AuthMetadataEndpoints { + authorization_server_endpoint: WELL_KNOWN_OAUTH_AUTHORIZATION_SERVER.to_string(), + protected_resource_endpoint: format!( + "{OAUTH_PROTECTED_RESOURCE_BASE}{}", + match self.protected_resource_metadata.resource.path() { + "/" => "", + other => other, + } + ), + } + } +} + +impl<'a> AuthMetadataBuilder<'a> { + fn with_defaults(protected_resource: &'a str) -> Self { + Self { + response_types_supported: Some(vec!["code".into()]), + code_challenge_methods_supported: Some(vec!["S256".into()]), + token_endpoint_auth_methods_supported: Some(vec!["client_secret_post".into()]), + grant_types_supported: Some(vec!["authorization_code".into(), "refresh_token".into()]), + resource: Some(protected_resource.into()), + ..Default::default() + } + } + + /// Creates a new instance of the builder for the given protected resource. + /// The `protected_resource` parameter must specify the full URL of the MCP server. + pub fn new(protected_resource_url: &'a str) -> Self { + Self::with_defaults(protected_resource_url) + } + + pub async fn from_discovery_url( + discovery_url: &str, + protected_resource: S, + required_scopes: Vec, + ) -> Result + where + S: Into>, + { + let client = Client::new(); + let json: Value = client + .get(discovery_url) + .send() + .await + .map_err(|e| McpSdkError::Internal { + description: format!( + "Failed to fetch discovery document : \"{discovery_url}\": {e}" + ), + })? + .error_for_status() + .map_err(|e| McpSdkError::Internal { + description: format!("Discovery endpoint returned error: {e}"), + })? + .json() + .await + .map_err(|e| McpSdkError::Internal { + description: format!("Failed to parse JSON from discovery document: {e}"), + })?; + + // Helper to extract string field safely + let get_str = |key: &str| { + json.get(key) + .and_then(|v| v.as_str()) + .map(|s| Cow::::Owned(s.to_string())) + }; + // Helper for optional array of strings + let get_str_array = |key: &str| { + json.get(key).and_then(|v| v.as_array()).map(|arr| { + arr.iter() + .filter_map(|item| item.as_str()) + .filter(|v| !v.is_empty()) + .map(|s| Cow::::Owned(s.to_string())) + .collect::>() + }) + }; + + let issuer = get_str("issuer").ok_or_else(|| McpSdkError::Internal { + description: "Missing 'issuer' in discovery document".to_string(), + })?; + + Ok(Self { + issuer: Some(issuer.clone()), + authorization_endpoint: get_str("authorization_endpoint"), + scopes_supported: get_str_array("scopes_supported"), + required_scopes: Some(required_scopes.into_iter().map(|s| s.into()).collect()), + token_endpoint: get_str("token_endpoint"), + jwks_uri: get_str("jwks_uri"), + + userinfo_endpoint: get_str("userinfo_endpoint"), + + registration_endpoint: get_str("registration_endpoint"), + revocation_endpoint: get_str("revocation_endpoint"), + introspection_endpoint: get_str("introspection_endpoint"), + response_types_supported: get_str_array("response_types_supported"), + response_modes_supported: get_str_array("response_modes_supported"), + grant_types_supported: get_str_array("grant_types_supported"), + token_endpoint_auth_methods_supported: get_str_array( + "token_endpoint_auth_methods_supported", + ), + token_endpoint_auth_signing_alg_values_supported: get_str_array( + "token_endpoint_auth_signing_alg_values_supported", + ), + revocation_endpoint_auth_signing_alg_values_supported: get_str_array( + "revocation_endpoint_auth_signing_alg_values_supported", + ), + revocation_endpoint_auth_methods_supported: get_str_array( + "revocation_endpoint_auth_methods_supported", + ), + introspection_endpoint_auth_methods_supported: get_str_array( + "introspection_endpoint_auth_methods_supported", + ), + introspection_endpoint_auth_signing_alg_values_supported: get_str_array( + "introspection_endpoint_auth_signing_alg_values_supported", + ), + code_challenge_methods_supported: get_str_array("code_challenge_methods_supported"), + service_documentation: get_str("service_documentation"), + resource: Some(protected_resource.into()), + authorization_servers: Some(vec![issuer]), + bearer_methods_supported: None, + resource_signing_alg_values_supported: None, + resource_name: None, + resource_documentation: None, + resource_policy_uri: None, + resource_tos_uri: None, + tls_client_certificate_bound_access_tokens: None, + authorization_details_types_supported: None, + dpop_signing_alg_values_supported: None, + dpop_bound_access_tokens_required: None, + }) + } + + fn parse_url_field( + field_name: &str, + value: Option, + base_url: Option<&Url>, + ) -> Result + where + S: Into>, + { + let value = value + .ok_or(McpSdkError::Internal { + description: format!("Error: '{field_name}' is missing."), + })? + .into(); + + let url = if value.contains("://") { + // Absolute URL + Url::parse(&value) + } else if let Some(base_url) = base_url { + // Relative URL, join with base_url + join_url(base_url, &value) + } else { + // No base_url provided, try to parse as absolute URL anyway + Url::parse(&value) + }; + + url.map_err(|e| McpSdkError::Internal { + description: format!("Error: '{field_name}' is not a valid URL: {e}"), + }) + } + + fn parse_optional_url_field( + field_name: &str, + value: Option, + base_url: Option<&Url>, + ) -> Result, McpSdkError> + where + S: Into>, + { + value + .map(|v| { + let value = v.into(); + if value.contains("://") { + // Absolute URL + Url::parse(&value) + } else if let Some(base_url) = base_url { + // Relative URL, join with base_url + join_url(base_url, &value) + } else { + // No base_url provided, try to parse as absolute URL anyway + Url::parse(&value) + } + }) + .transpose() + .map_err(|e| McpSdkError::Internal { + description: format!("Error: '{field_name}' is not a valid URL: {e}"), + }) + } + + pub fn scopes_supported(mut self, scopes: Vec) -> Self + where + S: Into>, + { + self.scopes_supported = Some(scopes.into_iter().map(|s| s.into()).collect()); + self + } + + // OAuthMetadata setters + pub fn issuer(mut self, issuer: S) -> Self + where + S: Into>, + { + self.issuer = Some(issuer.into()); + self + } + + pub fn service_documentation(mut self, url: S) -> Self + where + S: Into>, + { + self.service_documentation = Some(url.into()); + self + } + + pub fn authorization_endpoint(mut self, url: S) -> Self + where + S: Into>, + { + self.authorization_endpoint = Some(url.into()); + self + } + + pub fn token_endpoint(mut self, url: S) -> Self + where + S: Into>, + { + self.token_endpoint = Some(url.into()); + self + } + + pub fn response_types_supported(mut self, types: Vec) -> Self + where + S: Into>, + { + self.response_types_supported = Some(types.into_iter().map(|s| s.into()).collect()); + self + } + + pub fn response_modes_supported(mut self, modes: Vec) -> Self + where + S: Into>, + { + self.response_modes_supported = Some(modes.into_iter().map(|s| s.into()).collect()); + self + } + + pub fn registration_endpoint(mut self, url: &'a str) -> Self { + self.registration_endpoint = Some(url.into()); + self + } + + pub fn userinfo_endpoint(mut self, url: &'a str) -> Self { + self.userinfo_endpoint = Some(url.into()); + self + } + + pub fn grant_types_supported(mut self, types: Vec) -> Self + where + S: Into>, + { + self.grant_types_supported = Some(types.into_iter().map(|s| s.into()).collect()); + self + } + + pub fn token_endpoint_auth_methods_supported(mut self, methods: Vec) -> Self + where + S: Into>, + { + self.token_endpoint_auth_methods_supported = + Some(methods.into_iter().map(|s| s.into()).collect()); + self + } + + pub fn token_endpoint_auth_signing_alg_values_supported(mut self, algs: Vec) -> Self + where + S: Into>, + { + self.token_endpoint_auth_signing_alg_values_supported = + Some(algs.into_iter().map(|s| s.into()).collect()); + self + } + + pub fn revocation_endpoint(mut self, url: &'a str) -> Self { + self.revocation_endpoint = Some(url.into()); + self + } + + pub fn revocation_endpoint_auth_methods_supported(mut self, methods: Vec) -> Self + where + S: Into>, + { + self.revocation_endpoint_auth_methods_supported = + Some(methods.into_iter().map(|s| s.into()).collect()); + self + } + + pub fn revocation_endpoint_auth_signing_alg_values_supported(mut self, algs: Vec) -> Self + where + S: Into>, + { + self.revocation_endpoint_auth_signing_alg_values_supported = + Some(algs.into_iter().map(|s| s.into()).collect()); + self + } + + pub fn introspection_endpoint(mut self, endpoint: &'a str) -> Self { + self.introspection_endpoint = Some(endpoint.into()); + self + } + + pub fn introspection_endpoint_auth_methods_supported(mut self, methods: Vec) -> Self + where + S: Into>, + { + self.introspection_endpoint_auth_methods_supported = + Some(methods.into_iter().map(|s| s.into()).collect()); + self + } + + pub fn introspection_endpoint_auth_signing_alg_values_supported( + mut self, + algs: Vec, + ) -> Self + where + S: Into>, + { + self.introspection_endpoint_auth_signing_alg_values_supported = + Some(algs.into_iter().map(|s| s.into()).collect()); + self + } + + pub fn code_challenge_methods_supported(mut self, methods: Vec) -> Self + where + S: Into>, + { + self.code_challenge_methods_supported = + Some(methods.into_iter().map(|s| s.into()).collect()); + self + } + + // OAuthProtectedResourceMetadata setters + pub fn resource(mut self, url: &'a str) -> Self { + self.resource = Some(url.into()); + self + } + + pub fn authorization_servers(mut self, servers: Vec<&'a str>) -> Self { + self.authorization_servers = Some(servers.into_iter().map(|s| s.into()).collect()); + self + } + + pub fn reqquired_scopes(mut self, scopes: Vec) -> Self + where + S: Into>, + { + self.required_scopes = Some(scopes.into_iter().map(|s| s.into()).collect()); + self + } + + pub fn resource_documentation(mut self, doc: String) -> Self + where + S: Into>, + { + self.resource_documentation = Some(doc.into()); + self + } + + pub fn jwks_uri(mut self, url: &'a str) -> Self { + self.jwks_uri = Some(url.into()); + self + } + + pub fn bearer_methods_supported(mut self, methods: Vec) -> Self + where + S: Into>, + { + self.bearer_methods_supported = Some(methods.into_iter().map(|s| s.into()).collect()); + self + } + + pub fn resource_signing_alg_values_supported(mut self, algs: Vec) -> Self + where + S: Into>, + { + self.resource_signing_alg_values_supported = + Some(algs.into_iter().map(|s| s.into()).collect()); + self + } + + pub fn resource_name(mut self, name: S) -> Self + where + S: Into>, + { + self.resource_name = Some(name.into()); + self + } + + pub fn resource_policy_uri(mut self, url: &'a str) -> Self { + self.resource_policy_uri = Some(url.into()); + self + } + + pub fn resource_tos_uri(mut self, url: &'a str) -> Self { + self.resource_tos_uri = Some(url.into()); + self + } + + pub fn tls_client_certificate_bound_access_tokens(mut self, value: bool) -> Self { + self.tls_client_certificate_bound_access_tokens = Some(value); + self + } + + pub fn authorization_details_types_supported(mut self, types: Vec) -> Self + where + S: Into>, + { + self.authorization_details_types_supported = + Some(types.into_iter().map(|s| s.into()).collect()); + self + } + + pub fn dpop_signing_alg_values_supported(mut self, algs: Vec) -> Self + where + S: Into>, + { + self.dpop_signing_alg_values_supported = Some(algs.into_iter().map(|s| s.into()).collect()); + self + } + + pub fn dpop_bound_access_tokens_required(mut self, value: bool) -> Self { + self.dpop_bound_access_tokens_required = Some(value); + self + } + + // Build method to construct OauthMetadata + pub fn build( + self, + ) -> Result<(AuthorizationServerMetadata, OauthProtectedResourceMetadata), McpSdkError> { + let issuer = Self::parse_url_field("issuer", self.issuer, None)?; + + let authorization_endpoint = Self::parse_url_field( + "authorization_endpoint", + self.authorization_endpoint, + Some(&issuer), + )?; + + let token_endpoint = + Self::parse_url_field("token_endpoint", self.token_endpoint, Some(&issuer))?; + + let registration_endpoint = Self::parse_optional_url_field( + "registration_endpoint", + self.registration_endpoint, + Some(&issuer), + )?; + + let revocation_endpoint = Self::parse_optional_url_field( + "revocation_endpoint", + self.revocation_endpoint, + Some(&issuer), + )?; + + let introspection_endpoint = Self::parse_optional_url_field( + "introspection_endpoint", + self.introspection_endpoint, + Some(&issuer), + )?; + + let service_documentation = Self::parse_optional_url_field( + "service_documentation", + self.service_documentation, + None, + )?; + + let jwks_uri = Self::parse_optional_url_field("jwks_uri", self.jwks_uri, Some(&issuer))?; + + let authorization_server_metadata = AuthorizationServerMetadata { + issuer, + authorization_endpoint, + token_endpoint, + registration_endpoint, + service_documentation, + revocation_endpoint, + introspection_endpoint, + userinfo_endpoint: self.userinfo_endpoint.map(|v| v.into()), + response_types_supported: self + .response_types_supported + .unwrap_or_default() + .into_iter() // iterate over Cow<'a, str> + .map(|c| c.into_owned()) + .collect(), + response_modes_supported: self + .response_modes_supported + .map(|v| v.into_iter().map(|c| c.into_owned()).collect()), + scopes_supported: self + .scopes_supported + .map(|v| v.into_iter().map(|c| c.into_owned()).collect()), + grant_types_supported: self + .grant_types_supported + .map(|v| v.into_iter().map(|c| c.into_owned()).collect()), + token_endpoint_auth_methods_supported: self + .token_endpoint_auth_methods_supported + .map(|v| v.into_iter().map(|c| c.into_owned()).collect()), + token_endpoint_auth_signing_alg_values_supported: self + .token_endpoint_auth_signing_alg_values_supported + .map(|v| v.into_iter().map(|c| c.into_owned()).collect()), + revocation_endpoint_auth_signing_alg_values_supported: self + .revocation_endpoint_auth_signing_alg_values_supported + .map(|v| v.into_iter().map(|c| c.into_owned()).collect()), + revocation_endpoint_auth_methods_supported: self + .revocation_endpoint_auth_methods_supported + .map(|v| v.into_iter().map(|c| c.into_owned()).collect()), + introspection_endpoint_auth_methods_supported: self + .introspection_endpoint_auth_methods_supported + .map(|v| v.into_iter().map(|c| c.into_owned()).collect()), + introspection_endpoint_auth_signing_alg_values_supported: self + .introspection_endpoint_auth_signing_alg_values_supported + .map(|v| v.into_iter().map(|c| c.into_owned()).collect()), + code_challenge_methods_supported: self + .code_challenge_methods_supported + .map(|v| v.into_iter().map(|c| c.into_owned()).collect()), + jwks_uri: jwks_uri.clone(), + }; + + let resource = Self::parse_url_field("resource", self.resource, None)?; + let resource_policy_uri = + Self::parse_optional_url_field("resource_policy_uri", self.resource_policy_uri, None)?; + let resource_tos_uri = + Self::parse_optional_url_field("resource_tos_uri", self.resource_tos_uri, None)?; + + // Validate mandatory authorization_servers + let authorization_servers = + self.authorization_servers + .ok_or_else(|| McpSdkError::Internal { + description: "Error: 'authorization_servers' is missing".to_string(), + })?; + if authorization_servers.is_empty() { + return Err(McpSdkError::Internal { + description: "Error: 'authorization_servers' must contain at least one URL" + .to_string(), + }); + } + let authorization_servers = authorization_servers + .iter() + .map(|url| { + Url::parse(url).map_err(|err| McpSdkError::Internal { + description: format!( + "Error: 'authorization_servers' contains invalid URL '{url}': {err}", + ), + }) + }) + .collect::, _>>()?; + + let protected_resource_metadata = OauthProtectedResourceMetadata { + resource, + authorization_servers, + jwks_uri, + scopes_supported: self + .required_scopes + .map(|v| v.into_iter().map(|c| c.into_owned()).collect()), + bearer_methods_supported: self + .bearer_methods_supported + .map(|v| v.into_iter().map(|c| c.into_owned()).collect()), + resource_signing_alg_values_supported: self + .resource_signing_alg_values_supported + .map(|v| v.into_iter().map(|c| c.into_owned()).collect()), + resource_name: self.resource_name.map(|s| s.into()), + resource_documentation: self.resource_documentation.map(|s| s.into()), + resource_policy_uri, + resource_tos_uri, + tls_client_certificate_bound_access_tokens: self + .tls_client_certificate_bound_access_tokens, + authorization_details_types_supported: self + .authorization_details_types_supported + .map(|v| v.into_iter().map(|c| c.into_owned()).collect()), + dpop_signing_alg_values_supported: self + .dpop_signing_alg_values_supported + .map(|v| v.into_iter().map(|c| c.into_owned()).collect()), + dpop_bound_access_tokens_required: self.dpop_bound_access_tokens_required, + }; + + Ok((authorization_server_metadata, protected_resource_metadata)) + } +} diff --git a/crates/rust-mcp-sdk/src/auth/spec.rs b/crates/rust-mcp-sdk/src/auth/spec.rs new file mode 100644 index 0000000..1bbd809 --- /dev/null +++ b/crates/rust-mcp-sdk/src/auth/spec.rs @@ -0,0 +1,15 @@ +mod audience; +#[cfg(feature = "auth")] +mod claims; +#[cfg(feature = "auth")] +mod discovery; +#[cfg(feature = "auth")] +mod jwk; + +pub use audience::*; +#[cfg(feature = "auth")] +pub use claims::*; +#[cfg(feature = "auth")] +pub use discovery::*; +#[cfg(feature = "auth")] +pub use jwk::*; diff --git a/crates/rust-mcp-sdk/src/auth/spec/audience.rs b/crates/rust-mcp-sdk/src/auth/spec/audience.rs new file mode 100644 index 0000000..229c27d --- /dev/null +++ b/crates/rust-mcp-sdk/src/auth/spec/audience.rs @@ -0,0 +1,111 @@ +use core::fmt; + +use serde::{Deserialize, Deserializer, Serialize, Serializer}; +use serde_json::Value; + +/// Represents the audience claim, which can be a single string or a list of strings. +#[derive(Debug, Clone)] +pub enum Audience { + Single(String), + Multiple(Vec), +} + +impl Audience { + /// Converts the audience to a `Vec` for uniform access. + pub fn to_vec(&self) -> Vec { + match self { + Audience::Single(s) => vec![s.clone()], + Audience::Multiple(v) => v.clone(), + } + } +} + +impl fmt::Display for Audience { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + Audience::Single(s) => write!(f, "{s}"), + Audience::Multiple(v) => { + let formatted = v.join(", "); + write!(f, "{formatted}") + } + } + } +} + +impl PartialEq for Audience { + fn eq(&self, other: &Self) -> bool { + self.to_vec() == other.to_vec() + } +} + +impl Eq for Audience {} + +impl Serialize for Audience { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + match self { + // Serialize a single string directly as a JSON string + Audience::Single(s) => serializer.serialize_str(s), + // Serialize multiple strings as a JSON array + Audience::Multiple(v) => serializer.collect_seq(v), + } + } +} + +impl<'de> Deserialize<'de> for Audience { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + // Use a Value to handle both string and array cases + let value = Value::deserialize(deserializer)?; + match value { + Value::String(s) => Ok(Audience::Single(s)), + Value::Array(arr) => { + let strings = arr + .into_iter() + .map(|v| match v { + Value::String(s) => Ok(s), + _ => Err(serde::de::Error::custom( + "audience array must contain strings", + )), + }) + .collect::, D::Error>>()?; + Ok(Audience::Multiple(strings)) + } + _ => Err(serde::de::Error::custom( + "audience must be a string or an array of strings", + )), + } + } +} + +// Allow converting from &str +impl From<&str> for Audience { + fn from(s: &str) -> Self { + Audience::Single(s.to_string()) + } +} + +// Allow converting from String +impl From for Audience { + fn from(s: String) -> Self { + Audience::Single(s) + } +} + +// Allow converting from Vec +impl From> for Audience { + fn from(v: Vec) -> Self { + Audience::Multiple(v) + } +} + +// Allow converting from Vec<&str> for convenience +impl From> for Audience { + fn from(v: Vec<&str>) -> Self { + Audience::Multiple(v.into_iter().map(|s| s.to_string()).collect()) + } +} diff --git a/crates/rust-mcp-sdk/src/auth/spec/claims.rs b/crates/rust-mcp-sdk/src/auth/spec/claims.rs new file mode 100644 index 0000000..814bf95 --- /dev/null +++ b/crates/rust-mcp-sdk/src/auth/spec/claims.rs @@ -0,0 +1,283 @@ +use super::Audience; +use serde::{Deserialize, Serialize}; + +/// Represents a structured address for the OIDC address claim. +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct Address { + /// Full mailing address, formatted for display or use. + #[serde(skip_serializing_if = "Option::is_none")] + pub formatted: Option, + /// Street address component (e.g., house number and street name). + #[serde(skip_serializing_if = "Option::is_none")] + pub street_address: Option, + /// City or locality component. + #[serde(skip_serializing_if = "Option::is_none")] + pub locality: Option, + /// State, province, or region component. + #[serde(skip_serializing_if = "Option::is_none")] + pub region: Option, + /// ZIP or postal code component. + #[serde(skip_serializing_if = "Option::is_none")] + pub postal_code: Option, + /// Country name component. + #[serde(skip_serializing_if = "Option::is_none")] + pub country: Option, +} + +/// Represents a combined set of JWT, OAuth 2.0, OIDC, and provider-specific claims. +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct AuthClaims { + // Standard JWT Claims (RFC 7519) + /// Issuer - Identifies the authorization server that issued the token (JWT: iss). + #[serde(rename = "iss", skip_serializing_if = "Option::is_none")] + pub issuer: Option, + + /// Subject - Unique identifier for the user or client (JWT: sub). + #[serde(rename = "sub", skip_serializing_if = "Option::is_none")] + pub subject: Option, + + /// Audience - Identifies the intended recipients, can be a string or array (JWT: aud). + #[serde(rename = "aud", skip_serializing_if = "Option::is_none")] + pub audience: Option, + + /// Expiration Time - Unix timestamp when the token expires (JWT: exp). + #[serde(rename = "exp", skip_serializing_if = "Option::is_none")] + pub expiration: Option, + + /// Not Before - Unix timestamp when the token becomes valid (JWT: nbf). + #[serde(rename = "nbf", skip_serializing_if = "Option::is_none")] + pub not_before: Option, + + /// Issued At - Unix timestamp when the token was issued (JWT: iat). + #[serde(rename = "iat", skip_serializing_if = "Option::is_none")] + pub issued_at: Option, + + /// JWT ID - Unique identifier for the token to prevent reuse (JWT: jti). + #[serde(rename = "jti", skip_serializing_if = "Option::is_none")] + pub jwt_id: Option, + + // OAuth 2.0 Access Token Claims (RFC 9068) + /// Scope - Space-separated list of scopes authorized for the token. + #[serde(rename = "scope", skip_serializing_if = "Option::is_none")] + pub scope: Option, + + /// Client ID - ID of the OAuth client that obtained the token. + #[serde(rename = "client_id", skip_serializing_if = "Option::is_none")] + pub client_id: Option, + + /// Confirmation - Provides key binding info (e.g., cnf.jkt for PoP tokens). + #[serde(rename = "cnf", skip_serializing_if = "Option::is_none")] + pub confirmation: Option, + + /// Authentication Time - Unix timestamp when the user was authenticated. + #[serde(rename = "auth_time", skip_serializing_if = "Option::is_none")] + pub auth_time: Option, + + /// Authorized Party - The party to which the token was issued. + #[serde(rename = "azp", skip_serializing_if = "Option::is_none")] + pub authorized_party: Option, + + /// Actor - Used for delegated authorization (on behalf of another party). + #[serde(rename = "act", skip_serializing_if = "Option::is_none")] + pub actor: Option, + + /// Session ID - Links the token to a specific user session (for logout, etc.). + #[serde(rename = "sid", skip_serializing_if = "Option::is_none")] + pub session_id: Option, + + // OpenID Connect Standard Claims (OIDC Core 1.0) + /// User's full name. + #[serde(rename = "name", skip_serializing_if = "Option::is_none")] + pub name: Option, + + /// User's first name. + #[serde(rename = "given_name", skip_serializing_if = "Option::is_none")] + pub given_name: Option, + + /// User's last name. + #[serde(rename = "family_name", skip_serializing_if = "Option::is_none")] + pub family_name: Option, + + /// User's middle name. + #[serde(rename = "middle_name", skip_serializing_if = "Option::is_none")] + pub middle_name: Option, + + /// Casual name of the user. + #[serde(rename = "nickname", skip_serializing_if = "Option::is_none")] + pub nickname: Option, + + /// Preferred username (often login name). + #[serde(rename = "preferred_username", skip_serializing_if = "Option::is_none")] + pub preferred_username: Option, + + /// URL of the user's profile page. + #[serde(rename = "profile", skip_serializing_if = "Option::is_none")] + pub profile: Option, + + /// URL of the user's profile picture. + #[serde(rename = "picture", skip_serializing_if = "Option::is_none")] + pub picture: Option, + + /// URL of the user's website. + #[serde(rename = "website", skip_serializing_if = "Option::is_none")] + pub website: Option, + + /// User's email address. + #[serde(rename = "email", skip_serializing_if = "Option::is_none")] + pub email: Option, + + /// Whether the email has been verified. + #[serde(rename = "email_verified", skip_serializing_if = "Option::is_none")] + pub email_verified: Option, + + /// User's gender. + #[serde(rename = "gender", skip_serializing_if = "Option::is_none")] + pub gender: Option, + + /// User's date of birth (e.g., "YYYY-MM-DD"). + #[serde(rename = "birthdate", skip_serializing_if = "Option::is_none")] + pub birthdate: Option, + + /// User's time zone (e.g., "America/New_York"). + #[serde(rename = "zoneinfo", skip_serializing_if = "Option::is_none")] + pub zoneinfo: Option, + + /// User's locale (e.g., "en-US"). + #[serde(rename = "locale", skip_serializing_if = "Option::is_none")] + pub locale: Option, + + /// User's phone number. + #[serde(rename = "phone_number", skip_serializing_if = "Option::is_none")] + pub phone_number: Option, + + /// Whether the phone number has been verified. + #[serde( + rename = "phone_number_verified", + skip_serializing_if = "Option::is_none" + )] + pub phone_number_verified: Option, + + /// User's structured address. + #[serde(rename = "address", skip_serializing_if = "Option::is_none")] + pub address: Option
, + + /// Last time the user's information was updated (Unix timestamp). + #[serde(rename = "updated_at", skip_serializing_if = "Option::is_none")] + pub updated_at: Option, + + // Microsoft Entra ID (Azure AD) Provider-Specific Claims + /// Object ID of the user or service principal (Entra ID). + #[serde(rename = "oid", skip_serializing_if = "Option::is_none")] + pub object_id: Option, + + /// Tenant ID (directory ID) (Entra ID). + #[serde(rename = "tid", skip_serializing_if = "Option::is_none")] + pub tenant_id: Option, + + /// User Principal Name (login, e.g., user@domain) (Entra ID). + #[serde(rename = "upn", skip_serializing_if = "Option::is_none")] + pub user_principal_name: Option, + + /// Assigned roles (Entra ID). + #[serde(rename = "roles", skip_serializing_if = "Option::is_none")] + pub roles: Option>, + + /// Azure AD groups (GUIDs) (Entra ID). + #[serde(rename = "groups", skip_serializing_if = "Option::is_none")] + pub groups: Option>, + + /// Application ID (same as client_id) (Entra ID). + #[serde(rename = "appid", skip_serializing_if = "Option::is_none")] + pub application_id: Option, + + /// Unique name (e.g., user@domain) (Entra ID). + #[serde(rename = "unique_name", skip_serializing_if = "Option::is_none")] + pub unique_name: Option, + + /// Token version (e.g., "1.0" or "2.0") (Entra ID). + #[serde(rename = "ver", skip_serializing_if = "Option::is_none")] + pub version: Option, +} + +/// Represents an OAuth 2.0 Token Introspection response as per RFC 7662. +/// +/// This struct captures the response from an OAuth 2.0 introspection endpoint, +/// providing details about the validity and metadata of an access or refresh token. +/// All fields are optional except `active`, as per the specification, to handle +/// cases where the token is inactive or certain metadata is not provided. +/// +/// # Example JSON +/// ```json +/// { +/// "active": true, +/// "scope": "read write", +/// "client_id": "client123", +/// "username": "john_doe", +/// "token_type": "access_token", +/// "exp": 1697054400, +/// "iat": 1697050800, +/// "nbf": 1697050800, +/// "sub": "user123", +/// "aud": ["resource_server_1", "resource_server_2"], +/// "iss": "https://auth.example.com", +/// "jti": "abc123" +/// } +/// ``` +#[derive(Debug, Serialize, Deserialize, Clone)] +#[serde(rename_all = "snake_case")] +pub struct IntrospectionResponse { + /// Indicates whether the token is active (valid, not expired, etc.). + /// This field is required by the OAuth 2.0 introspection specification. + pub active: bool, + + /// Space-separated list of scopes granted to the token. + /// Optional, as the token may have no scopes or be inactive. + #[serde(default)] + pub scope: Option, + + /// Identifier of the client that requested the token. + /// Optional, as it may not be provided for inactive tokens. + #[serde(default)] + pub client_id: Option, + + /// Username of the resource owner associated with the token, if applicable. + /// Optional, as it may not apply to all token types or be absent for inactive tokens. + #[serde(default)] + pub username: Option, + + /// Type of the token, typically "access_token" or "refresh_token". + /// Optional, as it may not be provided for inactive tokens. + #[serde(default)] + pub token_type: Option, + + /// Expiration Time - Unix timestamp when the token expires (JWT: exp). + #[serde(rename = "exp", skip_serializing_if = "Option::is_none")] + pub expiration: Option, + + /// Issued At - Unix timestamp when the token was issued (JWT: iat). + #[serde(rename = "iat", skip_serializing_if = "Option::is_none")] + pub issued_at: Option, + + /// Not Before - Unix timestamp when the token becomes valid (JWT: nbf). + #[serde(rename = "nbf", skip_serializing_if = "Option::is_none")] + pub not_before: Option, + + /// Subject identifier, often the user ID associated with the token. + /// Optional, as it may not be provided for inactive tokens. + #[serde(rename = "sub", skip_serializing_if = "Option::is_none")] + pub subject: Option, + + /// Audience(s) the token is intended for, which can be a single string or an array of strings. + /// Optional, as it may not be provided for inactive tokens. + #[serde(rename = "aud", skip_serializing_if = "Option::is_none")] + pub audience: Option, + + /// Issuer identifier, typically the URI of the authorization server. + /// Optional, as it may not be provided for inactive tokens. + #[serde(rename = "iss", skip_serializing_if = "Option::is_none")] + pub issuer: Option, + + /// JWT ID - Unique identifier for the token to prevent reuse (JWT: jti). + #[serde(rename = "jti", skip_serializing_if = "Option::is_none")] + pub jwt_id: Option, +} diff --git a/crates/rust-mcp-sdk/src/auth/spec/discovery.rs b/crates/rust-mcp-sdk/src/auth/spec/discovery.rs new file mode 100644 index 0000000..181a37c --- /dev/null +++ b/crates/rust-mcp-sdk/src/auth/spec/discovery.rs @@ -0,0 +1,315 @@ +use crate::{ + auth::{OauthEndpoint, OAUTH_PROTECTED_RESOURCE_BASE, WELL_KNOWN_OAUTH_AUTHORIZATION_SERVER}, + error::McpSdkError, + mcp_http::url_base, +}; +use reqwest::Client; +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; +use url::Url; + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct AuthorizationServerMetadata { + /// The base URL of the authorization server (e.g., "http://localhost:8080/realms/master/"). + pub issuer: Url, + + /// URL to which the client redirects the user for authorization. + pub authorization_endpoint: Url, + + /// URL to exchange authorization codes for tokens or refresh tokens. + pub token_endpoint: Url, + + /// URL of the authorization server's JWK Set `JWK` document + #[serde(default, skip_serializing_if = "::std::option::Option::is_none")] + pub jwks_uri: Option, + + /// Endpoint where clients can register dynamically. + #[serde(default, skip_serializing_if = "::std::option::Option::is_none")] + pub registration_endpoint: Option, + + /// List of supported OAuth scopes (e.g., "openid", "profile", "email", mcp:tools) + #[serde(default, skip_serializing_if = "::std::option::Option::is_none")] + pub scopes_supported: Option>, + + /// Response Types. Required by spec. If missing, default is empty vec. + /// Examples: "code", "token", "id_token" + #[serde(default, skip_serializing_if = "::std::vec::Vec::is_empty")] + pub response_types_supported: Vec, + + /// Response Modes. Indicates how the authorization response is returned. + /// Examples: "query", "fragment", "form_post" + #[serde(default, skip_serializing_if = "::std::option::Option::is_none")] + pub response_modes_supported: Option>, + + // ui_locales_supported + // op_policy_uri + // op_tos_uri + /// List of supported Grant Types + /// Examples: "authorization_code", "client_credentials", "refresh_token" + #[serde(default, skip_serializing_if = "::std::option::Option::is_none")] + pub grant_types_supported: Option>, + + /// Methods like "client_secret_basic", "client_secret_post" + #[serde(default, skip_serializing_if = "::std::option::Option::is_none")] + pub token_endpoint_auth_methods_supported: Option>, + + /// Signing algorithms for client authentication (e.g., "RS256") + #[serde(default, skip_serializing_if = "::std::option::Option::is_none")] + pub token_endpoint_auth_signing_alg_values_supported: Option>, + + /// Link to human-readable docs for developers. + /// + #[serde(default, skip_serializing_if = "::std::option::Option::is_none")] + pub service_documentation: Option, + + /// OAuth 2.0 Token Revocation endpoint. + #[serde(default, skip_serializing_if = "::std::option::Option::is_none")] + pub revocation_endpoint: Option, + + /// Similar to token endpoint, but for revocation-specific auth. + #[serde(default, skip_serializing_if = "::std::option::Option::is_none")] + pub revocation_endpoint_auth_signing_alg_values_supported: Option>, + + /// Tells the client which authentication methods are supported when accessing the token revocation endpoint. + /// These are standardized methods from RFC 6749 (OAuth 2.0) + /// Common values: "client_secret_basic", "client_secret_post", "private_key_jwt" + /// `client_secret_basic` – client credentials sent in HTTP Basic Auth. + /// `client_secret_post` – client credentials sent in the POST body. + /// `private_key_jwt` – client authenticates using a signed JWT. + #[serde(default, skip_serializing_if = "::std::option::Option::is_none")] + pub revocation_endpoint_auth_methods_supported: Option>, + + /// URL to validate tokens and get their metadata. + #[serde(default, skip_serializing_if = "::std::option::Option::is_none")] + pub introspection_endpoint: Option, + + /// Auth methods for accessing introspection. + #[serde(default, skip_serializing_if = "::std::option::Option::is_none")] + pub introspection_endpoint_auth_methods_supported: Option>, + + /// Algorithms for accessing introspection. + #[serde(default, skip_serializing_if = "::std::option::Option::is_none")] + pub introspection_endpoint_auth_signing_alg_values_supported: Option>, + + /// Methods supported for PKCE (Proof Key for Code Exchange). + /// Common values: "plain", "S256" + #[serde(default, skip_serializing_if = "::std::option::Option::is_none")] + pub code_challenge_methods_supported: Option>, + + #[serde(default, skip_serializing_if = "::std::option::Option::is_none")] + pub userinfo_endpoint: Option, +} + +impl AuthorizationServerMetadata { + /// Creates a new `AuthorizationServerMetadata` instance with the minimal required fields. + /// According to the OAuth 2.0 Authorization Server Metadata Metadata specification (RFC 8414), + /// the following fields are **required** for a valid metadata document: + /// - `issuer` + /// - `authorization_endpoint` + /// - `token_endpoint` + /// + /// All other fields are initialized with their default values (typically `None` or empty collections). + /// + pub fn new( + issuer: &str, + authorization_endpoint: &str, + token_endpoint: &str, + ) -> Result { + let issuer = Url::parse(issuer)?; + let authorization_endpoint = Url::parse(authorization_endpoint)?; + let token_endpoint = Url::parse(token_endpoint)?; + + Ok(Self { + issuer, + authorization_endpoint, + token_endpoint, + jwks_uri: Default::default(), + registration_endpoint: Default::default(), + scopes_supported: Default::default(), + response_types_supported: Default::default(), + response_modes_supported: Default::default(), + grant_types_supported: Default::default(), + token_endpoint_auth_methods_supported: Default::default(), + token_endpoint_auth_signing_alg_values_supported: Default::default(), + service_documentation: Default::default(), + revocation_endpoint: Default::default(), + revocation_endpoint_auth_signing_alg_values_supported: Default::default(), + revocation_endpoint_auth_methods_supported: Default::default(), + introspection_endpoint: Default::default(), + introspection_endpoint_auth_methods_supported: Default::default(), + introspection_endpoint_auth_signing_alg_values_supported: Default::default(), + code_challenge_methods_supported: Default::default(), + userinfo_endpoint: Default::default(), + }) + } + + /// Fetches authorization server metadata from a remote `.well-known/openid-configuration` + /// or OAuth 2.0 Authorization Server Metadata endpoint. + /// + /// This performs an HTTP GET request and deserializes the response directly into + /// `AuthorizationServerMetadata`. The endpoint must return a JSON document conforming + /// to RFC 8414 (OAuth 2.0 Authorization Server Metadata) or OpenID Connect Discovery 1.0. + /// + pub async fn from_discovery_url(discovery_url: &str) -> Result { + let client = Client::new(); + let metadata = client + .get(discovery_url) + .send() + .await + .map_err(|err| McpSdkError::Internal { + description: err.to_string(), + })? + .json::() + .await + .map_err(|err| McpSdkError::Internal { + description: err.to_string(), + })?; + Ok(metadata) + } +} + +/// represents metadata about a protected resource in the OAuth 2.0 ecosystem. +/// It allows clients and authorization servers to discover how to interact with a protected resource (like an MCP endpoint), +/// including security requirements and supported features. +/// +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct OauthProtectedResourceMetadata { + /// The base identifier of the protected resource (e.g., an MCP server's URI). + /// This is the only required field. + pub resource: Url, + + /// List of authorization servers that can issue access tokens for this resource. + /// Allows dynamic trust discovery. + #[serde(default, skip_serializing_if = "::std::vec::Vec::is_empty")] + pub authorization_servers: Vec, + + /// URL where the resource exposes its public keys (JWKS) to verify signed tokens. + /// Typically used to verify JWT access tokens. + /// Example: `https://example.com/.well-known/jwks.json` + #[serde(default, skip_serializing_if = "::std::option::Option::is_none")] + pub jwks_uri: Option, + + /// OAuth scopes the resource supports (e.g., "mcp:tool", "read", "write", "admin"). + /// Helps clients know what they can request for access. + #[serde(default, skip_serializing_if = "::std::option::Option::is_none")] + pub scopes_supported: Option>, + + /// Methods accepted for presenting Bearer tokens: + /// `authorization_header` (typical) + /// `form_post` + /// `uri_query` + #[serde(default, skip_serializing_if = "::std::option::Option::is_none")] + pub bearer_methods_supported: Option>, + + /// Supported signing algorithms for access tokens (if tokens are JWTs). + /// Example: ["RS256", "ES256"] + #[serde(default, skip_serializing_if = "::std::option::Option::is_none")] + pub resource_signing_alg_values_supported: Option>, + + /// A human-readable name for the resource. + /// Useful for UIs, logs, or developer documentation. + #[serde(default, skip_serializing_if = "::std::option::Option::is_none")] + pub resource_name: Option, + + /// URL to developer docs describing the resource and how to use it. + #[serde(default, skip_serializing_if = "::std::option::Option::is_none")] + pub resource_documentation: Option, + + /// URL to the resource's access policy or terms (e.g., rules on who can access what). + #[serde(default, skip_serializing_if = "::std::option::Option::is_none")] + pub resource_policy_uri: Option, + + /// URL to terms of service applicable to this resource. + #[serde(default, skip_serializing_if = "::std::option::Option::is_none")] + pub resource_tos_uri: Option, + + /// If true, access tokens must be bound to a client TLS certificate. + /// Used in mutual TLS scenarios for additional security. + #[serde(default, skip_serializing_if = "::std::option::Option::is_none")] + pub tls_client_certificate_bound_access_tokens: Option, + + ///Lists structured authorization types supported (used with Rich Authorization Requests (RAR) + /// Example: ["payment_initiation", "account_information"] + #[serde(default, skip_serializing_if = "::std::option::Option::is_none")] + pub authorization_details_types_supported: Option>, + + /// Supported algorithms for DPoP (Demonstration of Proof-of-Possession) tokens. + /// Example: ["ES256", "RS256"] + #[serde(default, skip_serializing_if = "::std::option::Option::is_none")] + pub dpop_signing_alg_values_supported: Option>, + + /// If true, the resource requires access tokens to be DPoP-bound. + /// Enhances security by tying tokens to a specific client and key. + #[serde(default, skip_serializing_if = "::std::option::Option::is_none")] + pub dpop_bound_access_tokens_required: Option, +} + +impl OauthProtectedResourceMetadata { + /// Creates a new `OAuthProtectedResourceMetadata` instance with only the + /// minimal required fields populated. + /// + /// The `resource` and each entry in `authorization_servers` must be valid URLs. + /// All other metadata fields are initialized to their defaults. + /// To provide optional or extended metadata, assign those fields after creation or construct the struct directly. + pub fn new( + resource: S, + authorization_servers: Vec, + scopes_supported: Option>, + ) -> Result + where + S: AsRef, + { + let resource = Url::parse(resource.as_ref())?; + let authorization_servers: Vec<_> = authorization_servers + .iter() + .map(|s| Url::parse(s.as_ref())) + .collect::>()?; + + Ok(Self { + resource, + authorization_servers, + jwks_uri: Default::default(), + scopes_supported, + bearer_methods_supported: Default::default(), + resource_signing_alg_values_supported: Default::default(), + resource_name: Default::default(), + resource_documentation: Default::default(), + resource_policy_uri: Default::default(), + resource_tos_uri: Default::default(), + tls_client_certificate_bound_access_tokens: Default::default(), + authorization_details_types_supported: Default::default(), + dpop_signing_alg_values_supported: Default::default(), + dpop_bound_access_tokens_required: Default::default(), + }) + } +} + +pub fn create_protected_resource_metadata_url(path: &str) -> String { + format!( + "{OAUTH_PROTECTED_RESOURCE_BASE}{}", + if path == "/" { "" } else { path } + ) +} + +pub fn create_discovery_endpoints( + mcp_server_url: &str, +) -> Result<(HashMap, String), McpSdkError> { + let mut endpoint_map = HashMap::new(); + endpoint_map.insert( + WELL_KNOWN_OAUTH_AUTHORIZATION_SERVER.to_string(), + OauthEndpoint::AuthorizationServerMetadata, + ); + + let resource_url = Url::parse(mcp_server_url).map_err(|err| McpSdkError::Internal { + description: err.to_string(), + })?; + + let relative_url = create_protected_resource_metadata_url(resource_url.path()); + let base_url = url_base(&resource_url); + let protected_resource_metadata_url = + format!("{}{relative_url}", base_url.trim_end_matches('/')); + + endpoint_map.insert(relative_url, OauthEndpoint::ProtectedResourceMetadata); + + Ok((endpoint_map, protected_resource_metadata_url)) +} diff --git a/crates/rust-mcp-sdk/src/auth/spec/jwk.rs b/crates/rust-mcp-sdk/src/auth/spec/jwk.rs new file mode 100644 index 0000000..0af9124 --- /dev/null +++ b/crates/rust-mcp-sdk/src/auth/spec/jwk.rs @@ -0,0 +1,94 @@ +use crate::auth::{Audience, AuthClaims, AuthenticationError}; +use http::StatusCode; +use jsonwebtoken::{decode, decode_header, jwk::Jwk, DecodingKey, TokenData, Validation}; +use serde::{Deserialize, Serialize}; + +/// A JSON Web Key Set (JWKS) containing a list of JSON Web Keys. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct JsonWebKeySet { + /// List of JSON Web Keys. + pub keys: Vec, +} + +pub fn decode_token_header(token: &str) -> Result { + let header = + decode_header(token).map_err(|err| AuthenticationError::TokenVerificationFailed { + description: err.to_string(), + status_code: Some(StatusCode::UNAUTHORIZED.as_u16()), + })?; + Ok(header) +} + +impl JsonWebKeySet { + pub fn verify( + &self, + token: String, + validate_audience: Option<&Audience>, + validate_issuer: Option<&String>, + ) -> Result, AuthenticationError> { + let header = decode_token_header(&token)?; + + let kid = header.kid.ok_or(AuthenticationError::InvalidToken { + description: "Missing kid in token header", + })?; + + let jwk = self + .keys + .iter() + .find(|key| key.common.key_id == Some(kid.clone())) + .ok_or(AuthenticationError::InvalidToken { + description: "No matching key found in JWKS", + })?; + + let decoding_key = DecodingKey::from_jwk(jwk).map_err(|err| { + AuthenticationError::TokenVerificationFailed { + description: err.to_string(), + status_code: None, + } + })?; + + let mut validation = Validation::new(header.alg); + + let mut required_claims = vec![]; + if let Some(validate_audience) = validate_audience { + let vec_audience = match validate_audience { + Audience::Single(aud) => &vec![aud.to_owned()], + Audience::Multiple(auds) => auds, + }; + validation.set_audience(vec_audience); + required_claims.push("aud"); + } else { + validation.validate_aud = false; + } + + if let Some(validate_issuer) = validate_issuer { + validation.set_issuer(&[validate_issuer]); + required_claims.push("iss"); + } + if !required_claims.is_empty() { + validation.set_required_spec_claims(&required_claims); + } + + let token_data = + decode::(token, &decoding_key, &validation).map_err(|err| { + match err.kind() { + jsonwebtoken::errors::ErrorKind::InvalidToken => { + AuthenticationError::InvalidToken { + description: "Invalid token", + } + } + jsonwebtoken::errors::ErrorKind::ExpiredSignature => { + AuthenticationError::InvalidToken { + description: "Expired token", + } + } + _ => AuthenticationError::TokenVerificationFailed { + description: err.to_string(), + status_code: Some(StatusCode::BAD_REQUEST.as_u16()), + }, + } + })?; + + Ok(token_data) + } +} diff --git a/crates/rust-mcp-sdk/src/auth/token_verifier.rs b/crates/rust-mcp-sdk/src/auth/token_verifier.rs new file mode 100644 index 0000000..f6f9c67 --- /dev/null +++ b/crates/rust-mcp-sdk/src/auth/token_verifier.rs @@ -0,0 +1,7 @@ +use super::{AuthInfo, AuthenticationError}; +use async_trait::async_trait; + +#[async_trait] +pub trait OauthTokenVerifier: Send + Sync { + async fn verify_token(&self, access_token: String) -> Result; +} diff --git a/crates/rust-mcp-sdk/src/error.rs b/crates/rust-mcp-sdk/src/error.rs index 3879526..9a99336 100644 --- a/crates/rust-mcp-sdk/src/error.rs +++ b/crates/rust-mcp-sdk/src/error.rs @@ -1,5 +1,6 @@ +#[cfg(feature = "auth")] +use crate::auth::AuthenticationError; use crate::schema::{ParseProtocolVersionError, RpcError}; - use rust_mcp_transport::error::TransportError; use thiserror::Error; use tokio::task::JoinError; @@ -27,11 +28,18 @@ pub enum McpSdkError { #[error("{0}")] HyperServer(#[from] TransportServerError), + #[cfg(feature = "auth")] + #[error("{0}")] + AuthenticationError(#[from] AuthenticationError), + #[error("{0}")] SdkError(#[from] crate::schema::schema_utils::SdkError), #[error("Protocol error: {kind}")] Protocol { kind: ProtocolErrorKind }, + + #[error("Server error: {description}")] + Internal { description: String }, } // Sub-enum for protocol-related errors diff --git a/crates/rust-mcp-sdk/src/hyper_servers/error.rs b/crates/rust-mcp-sdk/src/hyper_servers/error.rs index dd55d8f..bc0955a 100644 --- a/crates/rust-mcp-sdk/src/hyper_servers/error.rs +++ b/crates/rust-mcp-sdk/src/hyper_servers/error.rs @@ -3,6 +3,9 @@ use std::net::AddrParseError; use axum::{http::StatusCode, response::IntoResponse}; use thiserror::Error; +#[cfg(feature = "auth")] +use crate::auth::AuthenticationError; + pub type TransportServerResult = core::result::Result; #[derive(Debug, Error, Clone)] @@ -25,6 +28,9 @@ pub enum TransportServerError { SslCertError(String), #[error("{0}")] TransportError(String), + #[cfg(feature = "auth")] + #[error("{0}")] + AuthenticationError(#[from] AuthenticationError), } impl IntoResponse for TransportServerError { diff --git a/crates/rust-mcp-sdk/src/hyper_servers/routes.rs b/crates/rust-mcp-sdk/src/hyper_servers/routes.rs index fcaa290..8724839 100644 --- a/crates/rust-mcp-sdk/src/hyper_servers/routes.rs +++ b/crates/rust-mcp-sdk/src/hyper_servers/routes.rs @@ -1,3 +1,5 @@ +#[cfg(feature = "auth")] +pub mod auth_routes; pub mod fallback_routes; pub mod messages_routes; #[cfg(feature = "sse")] @@ -5,7 +7,8 @@ pub mod sse_routes; pub mod streamable_http_routes; use super::HyperServerOptions; -use crate::mcp_http::{McpAppState, McpHttpHandler}; +use crate::mcp_http::McpAppState; +use crate::mcp_http::McpHttpHandler; use axum::{Extension, Router}; use std::sync::Arc; @@ -20,33 +23,41 @@ use std::sync::Arc; /// /// # Returns /// * `Router` - An Axum router configured with all application routes and state +/// pub fn app_routes( state: Arc, server_options: &HyperServerOptions, http_handler: McpHttpHandler, ) -> Router { - let router: Router = Router::new() - .merge(streamable_http_routes::routes( + let http_handler = Arc::new(http_handler); + + let router = { + let mut router = Router::new(); + + #[cfg(feature = "auth")] + { + router = router.merge(auth_routes::routes(http_handler.clone())); + } + + router = router.merge(streamable_http_routes::routes( server_options.streamable_http_endpoint(), - )) - .merge({ - let mut r = Router::new(); - #[cfg(feature = "sse")] - if server_options.sse_support { - r = r - .merge(sse_routes::routes( - server_options.sse_endpoint(), - server_options.sse_messages_endpoint(), - )) - .merge(messages_routes::routes( - server_options.sse_messages_endpoint(), - )) - } - r - }) - .with_state(state) - .merge(fallback_routes::routes()) - .layer(Extension(Arc::new(http_handler))); + )); + + #[cfg(feature = "sse")] + { + router = router + .merge(sse_routes::routes( + server_options.sse_endpoint(), + server_options.sse_messages_endpoint(), + )) + .merge(messages_routes::routes( + server_options.sse_messages_endpoint(), + )); + } + + router = router.merge(fallback_routes::routes()); + router.with_state(state).layer(Extension(http_handler)) + }; router } diff --git a/crates/rust-mcp-sdk/src/hyper_servers/routes/auth_routes.rs b/crates/rust-mcp-sdk/src/hyper_servers/routes/auth_routes.rs new file mode 100644 index 0000000..e905177 --- /dev/null +++ b/crates/rust-mcp-sdk/src/hyper_servers/routes/auth_routes.rs @@ -0,0 +1,33 @@ +use crate::hyper_servers::error::TransportServerResult; +use crate::mcp_http::{McpAppState, McpHttpHandler}; +use axum::routing::any; +use axum::Extension; +use axum::{extract::State, response::IntoResponse, Router}; +use http::{HeaderMap, Method, Uri}; +use std::sync::Arc; + +pub fn routes(mcp_handler: Arc) -> Router> { + let endpoints: Vec<&String> = mcp_handler.oauth_endppoints().unwrap_or_default(); + + endpoints + .into_iter() + .fold(Router::new(), |router, endpoint| { + router.route(endpoint, any(handle_auth_request)) + }) +} + +#[cfg(feature = "auth")] +pub async fn handle_auth_request( + method: Method, + uri: Uri, + headers: HeaderMap, + State(state): State>, + Extension(http_handler): Extension>, + payload: String, +) -> TransportServerResult { + let request = McpHttpHandler::create_request(method, uri, headers, Some(payload.as_str())); + let generic_res = http_handler.handle_auth_requests(request, state).await?; + let (parts, body) = generic_res.into_parts(); + let resp = axum::response::Response::from_parts(parts, axum::body::Body::new(body)); + Ok(resp) +} diff --git a/crates/rust-mcp-sdk/src/hyper_servers/routes/fallback_routes.rs b/crates/rust-mcp-sdk/src/hyper_servers/routes/fallback_routes.rs index 971ed43..5f70fb2 100644 --- a/crates/rust-mcp-sdk/src/hyper_servers/routes/fallback_routes.rs +++ b/crates/rust-mcp-sdk/src/hyper_servers/routes/fallback_routes.rs @@ -1,13 +1,16 @@ +use crate::mcp_http::McpAppState; use axum::{ http::{StatusCode, Uri}, + response::IntoResponse, Router, }; +use std::sync::Arc; -pub fn routes() -> Router { +pub fn routes() -> Router> { Router::new().fallback(not_found) } -pub async fn not_found(uri: Uri) -> (StatusCode, String) { +pub async fn not_found(uri: Uri) -> impl IntoResponse { ( StatusCode::NOT_FOUND, format!("The requested uri does not exist:\r\nuri: {uri}"), diff --git a/crates/rust-mcp-sdk/src/hyper_servers/server.rs b/crates/rust-mcp-sdk/src/hyper_servers/server.rs index 0b30da6..a3eb655 100644 --- a/crates/rust-mcp-sdk/src/hyper_servers/server.rs +++ b/crates/rust-mcp-sdk/src/hyper_servers/server.rs @@ -1,3 +1,11 @@ +use super::{ + error::{TransportServerError, TransportServerResult}, + routes::app_routes, +}; +#[cfg(feature = "auth")] +use crate::auth::AuthProvider; +#[cfg(feature = "auth")] +use crate::mcp_http::middleware::AuthMiddleware; use crate::{ error::SdkResult, id_generator::{FastIdGenerator, UuidGenerator}, @@ -5,16 +13,19 @@ use crate::{ http_utils::{ DEFAULT_MESSAGES_ENDPOINT, DEFAULT_SSE_ENDPOINT, DEFAULT_STREAMABLE_HTTP_ENDPOINT, }, - middleware::dns_rebind_protector::DnsRebindProtector, + middleware::DnsRebindProtector, McpAppState, McpHttpHandler, }, mcp_server::hyper_runtime::HyperRuntime, - mcp_traits::{mcp_handler::McpServerHandler, IdGenerator}, + mcp_traits::{IdGenerator, McpServerHandler}, session_store::InMemorySessionStore, }; +use crate::{mcp_http::Middleware, schema::InitializeResult}; +use axum::Router; #[cfg(feature = "ssl")] use axum_server::tls_rustls::RustlsConfig; use axum_server::Handle; +use rust_mcp_transport::{event_store::EventStore, SessionId, TransportOptions}; use std::{ net::{SocketAddr, ToSocketAddrs}, path::Path, @@ -23,14 +34,6 @@ use std::{ }; use tokio::signal; -use super::{ - error::{TransportServerError, TransportServerResult}, - routes::app_routes, -}; -use crate::schema::InitializeResult; -use axum::Router; -use rust_mcp_transport::{event_store::EventStore, SessionId, TransportOptions}; - // Default client ping interval (12 seconds) const DEFAULT_CLIENT_PING_INTERVAL: Duration = Duration::from_secs(12); const GRACEFUL_SHUTDOWN_TMEOUT_SECS: u64 = 5; @@ -99,6 +102,10 @@ pub struct HyperServerOptions { /// Optional custom path for the MCP messages endpoint for sse (default: `/messages`) /// Applicable only if sse_support is true pub custom_messages_endpoint: Option, + + /// Optional authentication provider for protecting MCP server. + #[cfg(feature = "auth")] + pub auth: Option>, } impl HyperServerOptions { @@ -235,6 +242,8 @@ impl Default for HyperServerOptions { allowed_origins: None, dns_rebinding_protection: false, event_store: None, + #[cfg(feature = "auth")] + auth: None, } } } @@ -279,16 +288,32 @@ impl HyperServer { event_store: server_options.event_store.as_ref().map(Arc::clone), }); - let mut http_handler = McpHttpHandler::new(); - + // populate middlewares + let mut middlewares: Vec> = vec![]; if server_options.needs_dns_protection() { - http_handler.add_middleware(DnsRebindProtector::new( + //dns pritection middleware + middlewares.push(Arc::new(DnsRebindProtector::new( server_options.allowed_hosts.take(), server_options.allowed_origins.take(), - )); + ))); } + let http_handler = { + #[cfg(feature = "auth")] + { + let auth_provider = server_options.auth.take(); + // add auth middleware if there is a auth_provider + if let Some(auth_provider) = auth_provider.as_ref() { + middlewares.push(Arc::new(AuthMiddleware::new(auth_provider.clone()))) + } + McpHttpHandler::new(auth_provider, middlewares) + } + #[cfg(not(feature = "auth"))] + McpHttpHandler::new(middlewares) + }; + let app = app_routes(Arc::clone(&state), &server_options, http_handler); + Self { app, state, diff --git a/crates/rust-mcp-sdk/src/lib.rs b/crates/rust-mcp-sdk/src/lib.rs index 0d668a0..1d6476a 100644 --- a/crates/rust-mcp-sdk/src/lib.rs +++ b/crates/rust-mcp-sdk/src/lib.rs @@ -2,7 +2,8 @@ pub mod error; #[cfg(feature = "hyper-server")] mod hyper_servers; mod mcp_handlers; -#[cfg(feature = "hyper-server")] + +#[cfg(any(feature = "hyper-server", feature = "auth"))] pub mod mcp_http; mod mcp_macros; mod mcp_runtimes; @@ -74,24 +75,18 @@ pub mod mcp_server { pub use super::mcp_runtimes::server_runtime::mcp_server_runtime_core as server_runtime_core; pub use super::mcp_runtimes::server_runtime::ServerRuntime; - #[cfg(feature = "hyper-server")] - pub use super::hyper_servers::hyper_server; - #[cfg(feature = "hyper-server")] - pub use super::hyper_servers::hyper_server_core; #[cfg(feature = "hyper-server")] pub use super::hyper_servers::*; pub use super::utils::enforce_compatible_protocol_version; + #[cfg(feature = "auth")] + pub use super::utils::join_url; #[cfg(feature = "hyper-server")] pub use super::mcp_http::{McpAppState, McpHttpHandler}; } -#[cfg(feature = "client")] -pub use mcp_traits::mcp_client::*; - -#[cfg(feature = "server")] -pub use mcp_traits::mcp_server::*; - +pub mod auth; +pub use mcp_traits::*; pub use rust_mcp_transport::error::*; pub use rust_mcp_transport::*; diff --git a/crates/rust-mcp-sdk/src/mcp_handlers/mcp_client_handler.rs b/crates/rust-mcp-sdk/src/mcp_handlers/mcp_client_handler.rs index c6fb208..e78db9a 100644 --- a/crates/rust-mcp-sdk/src/mcp_handlers/mcp_client_handler.rs +++ b/crates/rust-mcp-sdk/src/mcp_handlers/mcp_client_handler.rs @@ -10,7 +10,7 @@ use crate::schema::{ElicitRequest, ElicitResult}; use async_trait::async_trait; use serde_json::Value; -use crate::mcp_traits::mcp_client::McpClient; +use crate::mcp_traits::McpClient; /// Defines the `ClientHandler` trait for handling Model Context Protocol (MCP) operations on a client. /// This trait provides default implementations for request and notification handlers in an MCP client, diff --git a/crates/rust-mcp-sdk/src/mcp_handlers/mcp_client_handler_core.rs b/crates/rust-mcp-sdk/src/mcp_handlers/mcp_client_handler_core.rs index a0afdf1..59444b0 100644 --- a/crates/rust-mcp-sdk/src/mcp_handlers/mcp_client_handler_core.rs +++ b/crates/rust-mcp-sdk/src/mcp_handlers/mcp_client_handler_core.rs @@ -2,7 +2,7 @@ use crate::schema::schema_utils::*; use crate::schema::*; use async_trait::async_trait; -use crate::mcp_traits::mcp_client::McpClient; +use crate::mcp_traits::McpClient; /// Defines the `ClientHandlerCore` trait for handling Model Context Protocol (MCP) client operations. /// Unlike `ClientHandler`, this trait offers no default implementations, providing full control over MCP message handling diff --git a/crates/rust-mcp-sdk/src/mcp_handlers/mcp_server_handler.rs b/crates/rust-mcp-sdk/src/mcp_handlers/mcp_server_handler.rs index 9f8c9e3..0a51967 100644 --- a/crates/rust-mcp-sdk/src/mcp_handlers/mcp_server_handler.rs +++ b/crates/rust-mcp-sdk/src/mcp_handlers/mcp_server_handler.rs @@ -1,13 +1,13 @@ use crate::{ mcp_server::server_runtime::ServerRuntimeInternalHandler, - mcp_traits::mcp_handler::McpServerHandler, + mcp_traits::McpServerHandler, schema::{schema_utils::CallToolError, *}, }; use async_trait::async_trait; use serde_json::Value; use std::sync::Arc; -use crate::{mcp_traits::mcp_server::McpServer, utils::enforce_compatible_protocol_version}; +use crate::{mcp_traits::McpServer, utils::enforce_compatible_protocol_version}; /// Defines the `ServerHandler` trait for handling Model Context Protocol (MCP) operations on a server. /// This trait provides default implementations for request and notification handlers in an MCP server, diff --git a/crates/rust-mcp-sdk/src/mcp_handlers/mcp_server_handler_core.rs b/crates/rust-mcp-sdk/src/mcp_handlers/mcp_server_handler_core.rs index 9275da7..c89e403 100644 --- a/crates/rust-mcp-sdk/src/mcp_handlers/mcp_server_handler_core.rs +++ b/crates/rust-mcp-sdk/src/mcp_handlers/mcp_server_handler_core.rs @@ -1,4 +1,4 @@ -use crate::mcp_traits::mcp_server::McpServer; +use crate::mcp_traits::McpServer; use crate::schema::schema_utils::*; use crate::schema::*; use async_trait::async_trait; diff --git a/crates/rust-mcp-sdk/src/mcp_http.rs b/crates/rust-mcp-sdk/src/mcp_http.rs index 17c8236..5c995f9 100644 --- a/crates/rust-mcp-sdk/src/mcp_http.rs +++ b/crates/rust-mcp-sdk/src/mcp_http.rs @@ -1,12 +1,14 @@ mod app_state; pub(crate) mod http_utils; mod mcp_http_handler; + pub mod middleware; mod types; pub use app_state::*; pub use http_utils::*; pub use mcp_http_handler::*; + pub use types::*; pub use middleware::Middleware; diff --git a/crates/rust-mcp-sdk/src/mcp_http/app_state.rs b/crates/rust-mcp-sdk/src/mcp_http/app_state.rs index b068612..edb1c94 100644 --- a/crates/rust-mcp-sdk/src/mcp_http/app_state.rs +++ b/crates/rust-mcp-sdk/src/mcp_http/app_state.rs @@ -1,4 +1,4 @@ -use crate::mcp_traits::mcp_handler::McpServerHandler; +use crate::mcp_traits::McpServerHandler; use crate::session_store::SessionStore; use crate::{id_generator::FastIdGenerator, mcp_traits::IdGenerator, schema::InitializeResult}; use rust_mcp_transport::event_store::EventStore; diff --git a/crates/rust-mcp-sdk/src/mcp_http/http_utils.rs b/crates/rust-mcp-sdk/src/mcp_http/http_utils.rs index 29cff4f..52509d3 100644 --- a/crates/rust-mcp-sdk/src/mcp_http/http_utils.rs +++ b/crates/rust-mcp-sdk/src/mcp_http/http_utils.rs @@ -1,12 +1,14 @@ +use crate::auth::AuthInfo; use crate::mcp_http::types::GenericBody; use crate::schema::schema_utils::{ClientMessage, SdkError}; +use crate::McpServer; use crate::{ error::SdkResult, hyper_servers::error::{TransportServerError, TransportServerResult}, mcp_http::McpAppState, mcp_runtimes::server_runtime::DEFAULT_STREAM_ID, mcp_server::{server_runtime, ServerRuntime}, - mcp_traits::{mcp_handler::McpServerHandler, IdGenerator}, + mcp_traits::{IdGenerator, McpServerHandler}, utils::validate_mcp_protocol_version, }; use axum::http::HeaderValue; @@ -14,13 +16,13 @@ use bytes::Bytes; use futures::stream; use http::header::{ACCEPT, CONNECTION, CONTENT_TYPE}; use http_body::Frame; -use http_body_util::StreamBody; -use http_body_util::{BodyExt, Full}; +use http_body_util::{BodyExt, Full, StreamBody}; use hyper::{HeaderMap, StatusCode}; use rust_mcp_transport::{ EventId, McpDispatch, SessionId, SseEvent, SseTransport, StreamId, ID_SEPARATOR, MCP_PROTOCOL_VERSION_HEADER, MCP_SESSION_ID_HEADER, }; +use serde_json::{Map, Value}; use std::sync::Arc; use tokio::io::{duplex, AsyncBufReadExt, BufReader}; use tokio_stream::StreamExt; @@ -33,31 +35,6 @@ pub(crate) const DEFAULT_MESSAGES_ENDPOINT: &str = "/messages"; pub(crate) const DEFAULT_STREAMABLE_HTTP_ENDPOINT: &str = "/mcp"; const DUPLEX_BUFFER_SIZE: usize = 8192; -/// Creates an empty HTTP response body. -/// -/// This function constructs a `GenericBody` containing an empty `Bytes` buffer, -/// The body is wrapped in a `BoxBody` to ensure type erasure and compatibility -/// with the HTTP framework. -pub fn empty_response() -> GenericBody { - Full::new(Bytes::new()) - .map_err(|err| TransportServerError::HttpError(err.to_string())) - .boxed() -} - -pub fn build_response( - status_code: StatusCode, - payload: String, -) -> Result, TransportServerError> { - let body = Full::new(Bytes::from(payload)) - .map_err(|err| TransportServerError::HttpError(err.to_string())) - .boxed(); - - http::Response::builder() - .status(status_code) - .body(body) - .map_err(|err| TransportServerError::HttpError(err.to_string())) -} - /// Creates an initial SSE event that returns the messages endpoint /// /// Constructs an SSE event containing the messages endpoint URL with the session ID. @@ -74,6 +51,96 @@ fn initial_sse_event(endpoint: &str) -> Result { .as_bytes()) } +#[cfg(feature = "auth")] +pub fn url_base(url: &url::Url) -> String { + format!("{}://{}", url.scheme(), url.host_str().unwrap_or_default()) +} + +/// Remove the `Bearer` prefix from a `WWW-Authenticate` or `Authorization` header. +/// +/// This function performs a **case-insensitive** check for the `Bearer` +/// authentication scheme. If present, the prefix is removed and the +/// remaining parameter string is returned trimmed. +fn strip_bearer_prefix(header: &str) -> &str { + let lower = header.to_lowercase(); + if lower.starts_with("bearer ") { + header[7..].trim() + } else if lower == "bearer" { + "" + } else { + header.trim() + } +} + +/// Parse a `WWW-Authenticate` header with Bearer-style key/value parameters +/// into a JSON object (`serde_json::Map`). +#[cfg(feature = "auth")] +pub fn parse_www_authenticate(header: &str) -> Option> { + let params_str = strip_bearer_prefix(header); + + let mut result: Option> = None; + + for part in params_str.split(',') { + let part = part.trim(); + + if let Some((key, value)) = part.split_once('=') { + let cleaned = value.trim().trim_matches('"'); + + // Create the map only when first key=value is found + let map = result.get_or_insert_with(Map::new); + map.insert(key.to_string(), Value::String(cleaned.to_string())); + } + } + + result +} + +/// Extract the most meaningful error message from an HTTP response. +/// This is useful for handling OAuth2 / OpenID Connect Bearer errors +/// +/// Extraction order: +/// 1. If the `WWW-Authenticate` header exists and contains a Bearer error: +/// - Return `error_description` if present +/// - Else return `error` if present +/// - Else join all string values in the header +/// 2. If no usable info is found in the header: +/// - Return the response body text +/// - If body cannot be read, return `default_message` +#[cfg(feature = "auth")] +pub async fn error_message_from_response( + response: reqwest::Response, + default_message: &str, +) -> String { + if let Some(www_authenticate) = response + .headers() + .get(http::header::WWW_AUTHENTICATE) + .and_then(|v| v.to_str().ok()) + { + if let Some(map) = parse_www_authenticate(www_authenticate) { + if let Some(Value::String(s)) = map.get("error_description") { + return s.clone(); + } + if let Some(Value::String(s)) = map.get("error") { + return s.clone(); + } + + // Fallback: join all string values + let values: Vec<&str> = map + .values() + .filter_map(|v| match v { + Value::String(s) => Some(s.as_str()), + _ => None, + }) + .collect(); + if !values.is_empty() { + return values.join(", "); + } + } + } + + response.text().await.unwrap_or(default_message.to_owned()) +} + async fn create_sse_stream( runtime: Arc, session_id: SessionId, @@ -268,11 +335,14 @@ pub(crate) async fn create_standalone_stream( session_id: SessionId, last_event_id: Option, state: Arc, + auth_info: Option, ) -> TransportServerResult> { let runtime = state.session_store.get(&session_id).await.ok_or( TransportServerError::SessionIdInvalid(session_id.to_string()), )?; + runtime.update_auth_info(auth_info).await; + if runtime.stream_id_exists(DEFAULT_STREAM_ID).await { let error = SdkError::bad_request().with_message("Only one SSE stream is allowed per session"); @@ -303,6 +373,7 @@ pub(crate) async fn create_standalone_stream( pub(crate) async fn start_new_session( state: Arc, payload: &str, + auth_info: Option, ) -> TransportServerResult> { let session_id: SessionId = state.id_generator.generate(); @@ -312,6 +383,7 @@ pub(crate) async fn start_new_session( Arc::clone(&state.server_details), h, session_id.to_owned(), + auth_info, ); tracing::info!("a new client joined : {}", &session_id); @@ -438,9 +510,11 @@ pub(crate) async fn process_incoming_message_return( session_id: SessionId, state: Arc, payload: &str, + auth_info: Option, ) -> TransportServerResult> { match state.session_store.get(&session_id).await { Some(runtime) => { + runtime.update_auth_info(auth_info).await; single_shot_stream( runtime.clone(), session_id, @@ -463,9 +537,11 @@ pub(crate) async fn process_incoming_message( session_id: SessionId, state: Arc, payload: &str, + auth_info: Option, ) -> TransportServerResult> { match state.session_store.get(&session_id).await { Some(runtime) => { + runtime.update_auth_info(auth_info).await; // when receiving a result in a streamable_http server, that means it was sent by the standalone sse transport // it should be processed by the same transport , therefore no need to call create_sse_stream let Ok(is_result) = is_result(payload) else { @@ -637,6 +713,7 @@ pub(crate) fn query_param(request: &http::Request<&str>, key: &str) -> Option, sse_message_endpoint: Option<&str>, + auth_info: Option, ) -> TransportServerResult> { let session_id: SessionId = state.id_generator.generate(); @@ -669,6 +746,7 @@ pub(crate) async fn handle_sse_connection( Arc::clone(&state.server_details), h, session_id.to_owned(), + auth_info, ); state diff --git a/crates/rust-mcp-sdk/src/mcp_http/mcp_http_handler.rs b/crates/rust-mcp-sdk/src/mcp_http/mcp_http_handler.rs index cb17689..75ffcc3 100644 --- a/crates/rust-mcp-sdk/src/mcp_http/mcp_http_handler.rs +++ b/crates/rust-mcp-sdk/src/mcp_http/mcp_http_handler.rs @@ -1,10 +1,14 @@ #[cfg(feature = "sse")] +use super::http_utils::handle_sse_connection; use super::http_utils::{ - accepts_event_stream, empty_response, error_response, handle_sse_connection, query_param, - validate_mcp_protocol_version_header, + accepts_event_stream, error_response, query_param, validate_mcp_protocol_version_header, }; use super::types::GenericBody; +use crate::auth::AuthInfo; +#[cfg(feature = "auth")] +use crate::auth::AuthProvider; use crate::mcp_http::{middleware::compose, BoxFutureResponse, Middleware, RequestHandler}; +use crate::mcp_http::{GenericBodyExt, RequestExt}; use crate::mcp_runtimes::server_runtime::DEFAULT_STREAM_ID; use crate::mcp_server::error::TransportServerError; use crate::schema::schema_utils::SdkError; @@ -32,11 +36,14 @@ use std::sync::Arc; /// ```ignore /// let handle = with_middlewares!(self, Self::internal_handle_sse_message); /// handle +/// +/// // OR +/// let handler = with_middlewares!(self, Self::internal_handle_sse_message, extra_middlewares1, extra_middlewares2); /// ``` #[macro_export] macro_rules! with_middlewares { ($self:ident, $handler:path) => {{ - let final_handler: RequestHandler = std::sync::Arc::new( + let final_handler: RequestHandler = Box::new( move |req: http::Request<&str>, state: std::sync::Arc| -> BoxFutureResponse<'_> { @@ -45,25 +52,43 @@ macro_rules! with_middlewares { ); $crate::mcp_http::middleware::compose(&$self.middlewares, final_handler) }}; + + // Handler + extra middleware(s) + ($self:ident, $handler:path, $($extra:expr),+ $(,)?) => {{ + let final_handler: RequestHandler = Box::new( + move |req: http::Request<&str>, + state: std::sync::Arc| + -> BoxFutureResponse<'_> { + Box::pin(async move { $handler(req, state).await }) + }, + ); + + // Chain $self.middlewares with any extra middleware iterators + let all = $self.middlewares.iter() + $(.chain($extra.iter()))+; + + $crate::mcp_http::middleware::compose(all, final_handler) + }}; } #[derive(Clone)] pub struct McpHttpHandler { + #[cfg(feature = "auth")] + auth: Option>, middlewares: Vec>, } -impl Default for McpHttpHandler { - fn default() -> Self { - Self::new() +impl McpHttpHandler { + #[cfg(feature = "auth")] + pub fn new(auth: Option>, middlewares: Vec>) -> Self { + McpHttpHandler { auth, middlewares } } -} -impl McpHttpHandler { - pub fn new() -> Self { - McpHttpHandler { - middlewares: vec![], - } + #[cfg(not(feature = "auth"))] + pub fn new(middlewares: Vec>) -> Self { + McpHttpHandler { middlewares } } + pub fn add_middleware(&mut self, middleware: M) { let m: Arc = Arc::new(middleware); self.middlewares.push(m); @@ -92,6 +117,42 @@ impl McpHttpHandler { } } +// auth related methods +#[cfg(feature = "auth")] +impl McpHttpHandler { + pub fn oauth_endppoints(&self) -> Option> { + self.auth + .as_ref() + .and_then(|a| a.auth_endpoints().map(|e| e.keys().collect::>())) + } + + pub async fn handle_auth_requests( + &self, + request: http::Request<&str>, + state: Arc, + ) -> TransportServerResult> { + let Some(auth_provider) = self.auth.as_ref() else { + return Err(TransportServerError::HttpError( + "Authentication is not supported by this server.".to_string(), + )); + }; + + let auth_provider = auth_provider.clone(); + let final_handler: RequestHandler = Box::new(move |req, state| { + Box::pin(async move { + use futures::TryFutureExt; + auth_provider + .handle_request(req, state) + .map_err(|e| e) + .await + }) + }); + + let handle = compose(&[], final_handler); + handle(request, state).await + } +} + impl McpHttpHandler { /// Handles an MCP connection using the SSE (Server-Sent Events) transport. /// @@ -112,10 +173,16 @@ impl McpHttpHandler { state: Arc, sse_message_endpoint: Option<&str>, ) -> TransportServerResult> { - let sse_endpoint = Arc::from(sse_message_endpoint.map(|s| s.to_string())); - let final_handler: RequestHandler = Arc::new(move |_req, state| { - let sse_endpoint = sse_endpoint.clone(); - Box::pin(async move { handle_sse_connection(state, sse_endpoint.as_deref()).await }) + use crate::auth::AuthInfo; + use crate::mcp_http::RequestExt; + + let (request, auth_info) = request.take::(); + + let sse_endpoint = sse_message_endpoint.map(|s| s.to_string()); + let final_handler: RequestHandler = Box::new(move |_req, state| { + Box::pin(async move { + handle_sse_connection(state, sse_endpoint.as_deref(), auth_info).await + }) }); let handle = compose(&self.middlewares, final_handler); handle(request, state).await @@ -205,7 +272,7 @@ impl McpHttpHandler { http::Response::builder() .status(StatusCode::ACCEPTED) - .body(empty_response()) + .body(GenericBody::empty()) .map_err(|err| TransportServerError::HttpError(err.to_string())) } @@ -213,10 +280,13 @@ impl McpHttpHandler { request: http::Request<&str>, state: Arc, ) -> TransportServerResult> { + let (request, auth_info) = request.take::(); + let method = request.method(); + let response = match method { - &http::Method::GET => return Self::handle_http_get(request, state).await, - &http::Method::POST => return Self::handle_http_post(request, state).await, + &http::Method::GET => return Self::handle_http_get(request, state, auth_info).await, + &http::Method::POST => return Self::handle_http_post(request, state, auth_info).await, &http::Method::DELETE => return Self::handle_http_delete(request, state).await, other => { let error = SdkError::bad_request().with_message(&format!( @@ -233,6 +303,7 @@ impl McpHttpHandler { async fn handle_http_post( request: http::Request<&str>, state: Arc, + auth_info: Option, ) -> TransportServerResult> { let headers = request.headers(); @@ -265,14 +336,14 @@ impl McpHttpHandler { // has session-id => write to the existing stream Some(id) => { if state.enable_json_response { - process_incoming_message_return(id, state, payload).await + process_incoming_message_return(id, state, payload, auth_info).await } else { - process_incoming_message(id, state, payload).await + process_incoming_message(id, state, payload, auth_info).await } } None => match valid_initialize_method(payload) { Ok(_) => { - return start_new_session(state, payload).await; + return start_new_session(state, payload, auth_info).await; } Err(McpSdkError::SdkError(error)) => error_response(StatusCode::BAD_REQUEST, error), Err(error) => { @@ -289,6 +360,7 @@ impl McpHttpHandler { async fn handle_http_get( request: http::Request<&str>, state: Arc, + auth_info: Option, ) -> TransportServerResult> { let headers = request.headers(); @@ -316,7 +388,8 @@ impl McpHttpHandler { let response = match session_id { Some(session_id) => { - let res = create_standalone_stream(session_id, last_event_id, state).await; + let res = + create_standalone_stream(session_id, last_event_id, state, auth_info).await; res } None => { diff --git a/crates/rust-mcp-sdk/src/mcp_http/middleware.rs b/crates/rust-mcp-sdk/src/mcp_http/middleware.rs index bd7d2ad..c8637e0 100644 --- a/crates/rust-mcp-sdk/src/mcp_http/middleware.rs +++ b/crates/rust-mcp-sdk/src/mcp_http/middleware.rs @@ -1,10 +1,16 @@ -pub mod cors_middleware; -pub(crate) mod dns_rebind_protector; +#[cfg(feature = "auth")] +mod auth_middleware; +mod cors_middleware; +mod dns_rebind_protector; pub mod logging_middleware; -use super::types::{BoxFutureResponse, GenericBody, RequestHandler}; +use super::types::{GenericBody, RequestHandler}; use crate::mcp_http::{McpAppState, MiddlewareNext}; use crate::mcp_server::error::TransportServerResult; +#[cfg(feature = "auth")] +pub(crate) use auth_middleware::*; +pub use cors_middleware::*; +pub(crate) use dns_rebind_protector::*; use http::{Request, Response}; use std::sync::Arc; @@ -19,21 +25,24 @@ pub trait Middleware: Send + Sync + 'static { } /// Build the final handler by folding the middlewares **in reverse**. -pub fn compose( - middlewares: &Vec>, - final_handler: RequestHandler, -) -> RequestHandler { +/// Each middleware and handler is consumed exactly once. +pub fn compose<'a, I>(middlewares: I, final_handler: RequestHandler) -> RequestHandler +where + I: IntoIterator>, + I::IntoIter: DoubleEndedIterator, +{ + // Start with the final handler let mut handler = final_handler; - for mw in middlewares.iter().rev() { - let mw = mw.clone(); - let next = handler.clone(); + // Fold middlewares in reverse order + for mw in middlewares.into_iter().rev() { + let mw = Arc::clone(mw); + let next = handler; - handler = Arc::new(move |req: Request<&str>, state: Arc| { - let mw = mw.clone(); - let next = next.clone(); - - Box::pin(async move { mw.handle(req, state, next).await }) as BoxFutureResponse<'_> + // Each loop iteration consumes `next` and returns a new boxed FnOnce + handler = Box::new(move |req: Request<&str>, state: Arc| { + let mw = Arc::clone(&mw); + Box::pin(async move { mw.handle(req, state, next).await }) }); } @@ -242,7 +251,7 @@ mod tests { /// Final handler – returns a fixed response fn final_handler(body: &'static str, status: StatusCode) -> RequestHandler { - Arc::new(move |_req, _| { + Box::new(move |_req, _| { let resp = Response::builder() .status(status) .body(GenericBody::from_string(body.to_string())) @@ -261,6 +270,7 @@ mod tests { let mw3 = Arc::new(TestMiddleware::new(3)); let middlewares: Vec> = vec![mw1.clone(), mw2.clone(), mw3.clone()]; + let handler = final_handler("final", StatusCode::OK); let composed = compose(&middlewares, handler); @@ -440,7 +450,7 @@ mod tests { let mw2 = Arc::new(TestMiddleware::new(2)); let middlewares: Vec> = vec![mw1.clone(), mw2.clone()]; - let handler: RequestHandler = Arc::new(move |req, _| { + let handler: RequestHandler = Box::new(move |req, _| { let body = req.into_body().to_string(); Box::pin(async move { Ok(Response::builder() diff --git a/crates/rust-mcp-sdk/src/mcp_http/middleware/auth_middleware.rs b/crates/rust-mcp-sdk/src/mcp_http/middleware/auth_middleware.rs new file mode 100644 index 0000000..f6de197 --- /dev/null +++ b/crates/rust-mcp-sdk/src/mcp_http/middleware/auth_middleware.rs @@ -0,0 +1,531 @@ +use crate::{ + auth::{AuthInfo, AuthProvider, AuthenticationError}, + mcp_http::{types::GenericBody, GenericBodyExt, McpAppState, Middleware, MiddlewareNext}, + mcp_server::error::TransportServerResult, +}; +use async_trait::async_trait; +use http::{ + header::{AUTHORIZATION, WWW_AUTHENTICATE}, + HeaderMap, HeaderValue, Request, Response, StatusCode, +}; +use std::{sync::Arc, time::SystemTime}; + +pub struct AuthMiddleware { + auth_provider: Arc, +} + +impl AuthMiddleware { + pub fn new(auth_provider: Arc) -> Self { + Self { auth_provider } + } + + async fn validate( + &self, + headers: &HeaderMap, + ) -> Result { + let Some(auth_token) = headers + .get(AUTHORIZATION) + .map(|v| v.to_str().ok().unwrap_or_default()) + else { + return Err(AuthenticationError::InvalidToken { + description: "Missing access token in Authorization header", + }); + }; + + let token = auth_token.trim(); + let parts: Vec<&str> = token.splitn(2, ' ').collect(); + + if parts.len() != 2 || !parts[0].eq_ignore_ascii_case("bearer") { + return Err(AuthenticationError::InvalidToken { + description: "Invalid Authorization header format, expected 'Bearer TOKEN'", + }); + } + + let bearer_token = parts[1].trim(); + + let auth_info = self + .auth_provider + .verify_token(bearer_token.to_string()) + .await?; + + match auth_info.expires_at { + Some(expires_at) => { + if SystemTime::now() >= expires_at { + return Err(AuthenticationError::InvalidToken { + description: "Token has expired", + }); + } + } + None => { + return Err(AuthenticationError::InvalidToken { + description: "Token has no expiration time", + }) + } + } + + if let Some(required_scopes) = self.auth_provider.required_scopes() { + if let Some(user_scopes) = auth_info.scopes.as_ref() { + if !required_scopes + .iter() + .all(|scope| user_scopes.contains(scope)) + { + return Err(AuthenticationError::InsufficientScope); + } + } + } + + Ok(auth_info) + } + + fn create_www_auth_value(&self, error_code: &str, error: AuthenticationError) -> String { + if let Some(resource_metadata) = self.auth_provider.protected_resource_metadata_url() { + format!( + r#"Bearer error="{error_code}", error_description="{error}", resource_metadata="{resource_metadata}""#, + ) + } else { + format!(r#"Bearer error="{error_code}", error_description="{error}""#,) + } + } + + fn error_response(&self, error: AuthenticationError) -> Response { + let as_json = error.as_json_value(); + let error_code = as_json + .get("error") + .unwrap_or_default() + .as_str() + .unwrap_or("unknown"); + + let (status_code, www_auth_value) = match error { + AuthenticationError::InactiveToken + | AuthenticationError::InvalidToken { description: _ } => ( + StatusCode::UNAUTHORIZED, + Some(self.create_www_auth_value(error_code, error)), + ), + AuthenticationError::InsufficientScope => ( + StatusCode::FORBIDDEN, + Some(self.create_www_auth_value(error_code, error)), + ), + AuthenticationError::TokenVerificationFailed { + description: _, + status_code, + } => { + if status_code.is_some_and(|s| s == StatusCode::FORBIDDEN) { + ( + StatusCode::FORBIDDEN, + Some(self.create_www_auth_value(error_code, error)), + ) + } else { + ( + status_code + .and_then(|v| StatusCode::from_u16(v).ok()) + .unwrap_or(StatusCode::BAD_REQUEST), + None, + ) + } + } + _ => (StatusCode::BAD_REQUEST, None), + }; + + let mut response = GenericBody::from_value(&as_json).into_json_response(status_code, None); + + if let Some(www_auth_value) = www_auth_value { + let Ok(www_auth_header_value) = HeaderValue::from_str(&www_auth_value) else { + return GenericBody::from_string("Unsupported WWW_AUTHENTICATE value".to_string()) + .into_response(StatusCode::INTERNAL_SERVER_ERROR, None); + }; + response + .headers_mut() + .append(WWW_AUTHENTICATE, www_auth_header_value); + } + + response + } +} + +#[async_trait] +impl Middleware for AuthMiddleware { + async fn handle<'req>( + &self, + mut req: Request<&'req str>, + state: Arc, + next: MiddlewareNext<'req>, + ) -> TransportServerResult> { + let auth_info = match self.validate(req.headers()).await { + Ok(auth_info) => auth_info, + Err(err) => { + return Ok(self.error_response(err)); + } + }; + req.extensions_mut().insert(auth_info); + next(req, state).await + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::auth::AuthMetadataBuilder; + use crate::schema::{Implementation, InitializeResult, ProtocolVersion, ServerCapabilities}; + use crate::{ + auth::{OauthTokenVerifier, RemoteAuthProvider}, + error::SdkResult, + id_generator::{FastIdGenerator, UuidGenerator}, + mcp_server::{ServerHandler, ToMcpServerHandler}, + session_store::InMemorySessionStore, + }; + use crate::{mcp_http::GenericBodyExt, mcp_server::error::TransportServerError}; + use bytes::Bytes; + use http_body_util::combinators::BoxBody; + use http_body_util::BodyExt; + use std::time::Duration; + + pub struct TestTokenVerifier {} + + impl TestTokenVerifier { + pub fn new() -> Self { + Self {} + } + } + + pub(crate) async fn body_to_string( + body: BoxBody, + ) -> Result { + let bytes = body.collect().await?.to_bytes(); + Ok(String::from_utf8_lossy(&bytes).into_owned()) + } + + #[async_trait] + impl OauthTokenVerifier for TestTokenVerifier { + async fn verify_token( + &self, + access_token: String, + ) -> Result { + let info = match access_token.as_str() { + "valid-token" => AuthInfo { + token_unique_id: "valid-token".to_string(), + client_id: Some("client-id".to_string()), + user_id: None, + scopes: Some(vec!["read".to_string(), "write".to_string()]), + expires_at: Some(SystemTime::now() + Duration::from_secs(90)), + audience: None, + extra: None, + }, + "expired-token" => AuthInfo { + token_unique_id: "expired-token".to_string(), + client_id: Some("client-id".to_string()), + user_id: None, + scopes: Some(vec!["read".to_string(), "write".to_string()]), + expires_at: Some(SystemTime::now() - Duration::from_secs(90)), // 90 seconds in the past + audience: None, + extra: None, + }, + + "no-expiration-token" => AuthInfo { + token_unique_id: "no-expiration-token".to_string(), + client_id: Some("client-id".to_string()), + scopes: Some(vec!["read".to_string(), "write".to_string()]), + user_id: None, + expires_at: None, + audience: None, + extra: None, + }, + "insufficient-scope" => AuthInfo { + token_unique_id: "insufficient-scope".to_string(), + client_id: Some("client-id".to_string()), + scopes: Some(vec!["read".to_string()]), + user_id: None, + expires_at: Some(SystemTime::now() + Duration::from_secs(90)), + audience: None, + extra: None, + }, + _ => return Err(AuthenticationError::NotFound("Bad token".to_string())), + }; + + Ok(info) + } + } + + pub fn create_oauth_provider() -> SdkResult { + let auth_metadata = AuthMetadataBuilder::new("http://127.0.0.1:3000/mcp") + .issuer("http://localhost:8090") + .authorization_servers(vec!["http://localhost:8090"]) + .scopes_supported(vec![ + "mcp:tools".to_string(), + "read".to_string(), + "write".to_string(), + ]) + .introspection_endpoint("/introspect") + .authorization_endpoint("/authorize") + .token_endpoint("/token") + .resource_name("MCP Demo Server".to_string()) + .build() + .unwrap(); + + let token_verifier = TestTokenVerifier::new(); + + Ok(RemoteAuthProvider::new( + auth_metadata.0, + auth_metadata.1, + Box::new(token_verifier), + Some(vec!["read".to_string(), "write".to_string()]), + )) + } + struct TestHandler; + impl ServerHandler for TestHandler {} + fn app_state() -> Arc { + let handler = TestHandler {}; + + Arc::new(McpAppState { + session_store: Arc::new(InMemorySessionStore::new()), + id_generator: Arc::new(UuidGenerator {}), + stream_id_gen: Arc::new(FastIdGenerator::new(Some("s_"))), + server_details: Arc::new(InitializeResult { + capabilities: ServerCapabilities { + ..Default::default() + }, + instructions: None, + meta: None, + protocol_version: ProtocolVersion::V2025_06_18.to_string(), + server_info: Implementation { + name: "server".to_string(), + title: None, + version: "0.1.0".to_string(), + }, + }), + handler: handler.to_mcp_server_handler(), + ping_interval: Duration::from_secs(15), + transport_options: Arc::new(rust_mcp_transport::TransportOptions::default()), + enable_json_response: false, + event_store: None, + }) + } + + #[tokio::test] + //should call next when token is valid + async fn test_call_next_when_token_is_valid() { + let provider = create_oauth_provider().unwrap(); + let middleware = AuthMiddleware::new(Arc::new(provider)); + + let req = Request::builder() + .header(AUTHORIZATION, "Bearer valid-token") + .body("") + .unwrap(); + let res = middleware + .handle( + req, + app_state(), + Box::new(move |_req, _state| { + let resp = Response::builder() + .status(StatusCode::OK) + .body(GenericBody::from_string("reached".to_string())) + .unwrap(); + Box::pin(async { Ok(resp) }) + }), + ) + .await + .unwrap(); + let (parts, body) = res.into_parts(); + assert_eq!(body_to_string(body).await.unwrap(), "reached"); + assert_eq!(parts.status, StatusCode::OK) + } + + #[tokio::test] + //should reject expired tokens + async fn should_reject_expired_tokens() { + let provider = create_oauth_provider().unwrap(); + let middleware = AuthMiddleware::new(Arc::new(provider)); + + let req = Request::builder() + .header(AUTHORIZATION, "Bearer expired-token") + .body("") + .unwrap(); + let res = middleware + .handle( + req, + app_state(), + Box::new(move |_req, _state| { + let resp = Response::builder() + .status(StatusCode::OK) + .body(GenericBody::from_string("reached".to_string())) + .unwrap(); + Box::pin(async { Ok(resp) }) + }), + ) + .await + .unwrap(); + let (parts, body) = res.into_parts(); + + let body_string = body_to_string(body).await.unwrap(); + assert!(body_string.contains("Token has expired")); + assert!(body_string.contains("invalid_token")); + assert_eq!(parts.status, StatusCode::UNAUTHORIZED); + let header_value = parts + .headers + .get(WWW_AUTHENTICATE) + .unwrap() + .to_str() + .unwrap(); + assert!(header_value.contains(r#"Bearer error="invalid_token""#)) + } + + //should reject tokens with no expiration time + #[tokio::test] + async fn should_reject_tokens_with_no_expiration_time() { + let provider = create_oauth_provider().unwrap(); + let middleware = AuthMiddleware::new(Arc::new(provider)); + + let req = Request::builder() + .header(AUTHORIZATION, "Bearer no-expiration-token") + .body("") + .unwrap(); + let res = middleware + .handle( + req, + app_state(), + Box::new(move |_req, _state| { + let resp = Response::builder() + .status(StatusCode::OK) + .body(GenericBody::from_string("reached".to_string())) + .unwrap(); + Box::pin(async { Ok(resp) }) + }), + ) + .await + .unwrap(); + let (parts, body) = res.into_parts(); + + assert_eq!(parts.status, StatusCode::UNAUTHORIZED); + + let body_string = body_to_string(body).await.unwrap(); + assert!(body_string.contains("invalid_token")); + assert!(body_string.contains("Token has no expiration time")); + let header_value = parts + .headers + .get(WWW_AUTHENTICATE) + .unwrap() + .to_str() + .unwrap(); + assert!(header_value.contains(r#"Bearer error="invalid_token""#)) + } + + // should require specific scopes when configured + #[tokio::test] + async fn should_require_specific_scopes_when_configured() { + let provider = create_oauth_provider().unwrap(); + let middleware = AuthMiddleware::new(Arc::new(provider)); + + let req = Request::builder() + .header(AUTHORIZATION, "Bearer insufficient-scope") + .body("") + .unwrap(); + let res = middleware + .handle( + req, + app_state(), + Box::new(move |_req, _state| { + let resp = Response::builder() + .status(StatusCode::OK) + .body(GenericBody::from_string("reached".to_string())) + .unwrap(); + Box::pin(async { Ok(resp) }) + }), + ) + .await + .unwrap(); + let (parts, body) = res.into_parts(); + + assert_eq!(parts.status, StatusCode::FORBIDDEN); + + let body_string = body_to_string(body).await.unwrap(); + assert!(body_string.contains("insufficient_scope")); + assert!(body_string.contains("Insufficient scope")); + let header_value = parts + .headers + .get(WWW_AUTHENTICATE) + .unwrap() + .to_str() + .unwrap(); + assert!(header_value.contains(r#"Bearer error="insufficient_scope""#)) + } + + // should return 401 when no Authorization header is present + #[tokio::test] + async fn should_return_401_when_no_authorization_header_is_present() { + let provider = create_oauth_provider().unwrap(); + let middleware = AuthMiddleware::new(Arc::new(provider)); + + let req = Request::builder().body("").unwrap(); + let res = middleware + .handle( + req, + app_state(), + Box::new(move |_req, _state| { + let resp = Response::builder() + .status(StatusCode::OK) + .body(GenericBody::from_string("reached".to_string())) + .unwrap(); + Box::pin(async { Ok(resp) }) + }), + ) + .await + .unwrap(); + let (parts, body) = res.into_parts(); + + assert_eq!(parts.status, StatusCode::UNAUTHORIZED); + + let body_string = body_to_string(body).await.unwrap(); + assert!(body_string.contains("invalid_token")); + assert!(body_string.contains("Missing access token in Authorization header")); + let header_value = parts + .headers + .get(WWW_AUTHENTICATE) + .unwrap() + .to_str() + .unwrap(); + assert!(header_value.contains(r#"Bearer error="invalid_token""#)) + } + //should return 401 when Authorization header format is invalid + #[tokio::test] + async fn should_return_401_when_authorization_header_format_is_invalid() { + let provider = create_oauth_provider().unwrap(); + let middleware = AuthMiddleware::new(Arc::new(provider)); + + let req = Request::builder() + .header(AUTHORIZATION, "INVALID") + .body("") + .unwrap(); + let res = middleware + .handle( + req, + app_state(), + Box::new(move |_req, _state| { + let resp = Response::builder() + .status(StatusCode::OK) + .body(GenericBody::from_string("reached".to_string())) + .unwrap(); + Box::pin(async { Ok(resp) }) + }), + ) + .await + .unwrap(); + let (parts, body) = res.into_parts(); + + assert_eq!(parts.status, StatusCode::UNAUTHORIZED); + + let body_string = body_to_string(body).await.unwrap(); + assert!(body_string.contains("invalid_token")); + assert!(body_string.contains("Bearer TOKEN")); + let header_value = parts + .headers + .get(WWW_AUTHENTICATE) + .unwrap() + .to_str() + .unwrap(); + + assert!(header_value.contains(r#"Bearer error="invalid_token""#)); + + assert!(header_value.contains( + r#"resource_metadata="http://127.0.0.1/.well-known/oauth-protected-resource/mcp"# + )); + } +} diff --git a/crates/rust-mcp-sdk/src/mcp_http/middleware/cors_middleware.rs b/crates/rust-mcp-sdk/src/mcp_http/middleware/cors_middleware.rs index 2f2608d..08bbba1 100644 --- a/crates/rust-mcp-sdk/src/mcp_http/middleware/cors_middleware.rs +++ b/crates/rust-mcp-sdk/src/mcp_http/middleware/cors_middleware.rs @@ -11,11 +11,7 @@ //! - `Access-Control-Expose-Headers` support use crate::{ - mcp_http::{ - http_utils::{build_response, empty_response}, - types::GenericBody, - McpAppState, Middleware, MiddlewareNext, - }, + mcp_http::{types::GenericBody, GenericBodyExt, McpAppState, Middleware, MiddlewareNext}, mcp_server::error::TransportServerResult, }; use http::{ @@ -27,6 +23,7 @@ use http::{ }, Method, Request, Response, StatusCode, }; +use rust_mcp_transport::MCP_SESSION_ID_HEADER; use std::{collections::HashSet, sync::Arc}; /// Configuration for CORS behavior. @@ -45,7 +42,7 @@ pub struct CorsConfig { /// Whether to allow credentials (cookies, HTTP auth, etc). /// - /// **Important**: When `true`, `allow_origins` cannot be `Any` — browsers reject `*`. + /// **Important**: When `true`, `allow_origins` cannot be `Any` - browsers reject `*`. pub allow_credentials: bool, /// How long (in seconds) the preflight response can be cached. @@ -60,7 +57,11 @@ impl Default for CorsConfig { Self { allow_origins: AllowOrigins::Any, allow_methods: vec![Method::GET, Method::POST, Method::OPTIONS], - allow_headers: vec![header::CONTENT_TYPE, header::AUTHORIZATION], + allow_headers: vec![ + header::CONTENT_TYPE, + header::AUTHORIZATION, + HeaderName::from_static(MCP_SESSION_ID_HEADER), + ], allow_credentials: false, max_age: Some(86_400), // 24 hours expose_headers: vec![], @@ -87,7 +88,7 @@ pub enum AllowOrigins { /// /// Handles both **preflight** (`OPTIONS`) and **actual** requests, /// adding appropriate CORS headers and rejecting invalid origins/methods/headers. -#[derive(Clone)] +#[derive(Clone, Default)] pub struct CorsMiddleware { config: Arc, } @@ -100,7 +101,7 @@ impl CorsMiddleware { } } - /// Create a permissive CORS config — useful for public APIs or local dev. + /// Create a permissive CORS config - useful for public APIs or local dev. /// /// Allows all common methods, credentials, and common headers. pub fn permissive() -> Self { @@ -158,7 +159,7 @@ impl CorsMiddleware { let allowed_origin = self.resolve_allowed_origin(origin); let mut resp = Response::builder() .status(StatusCode::NO_CONTENT) - .body(empty_response()) + .body(GenericBody::empty()) .expect("preflight response is static"); let headers = resp.headers_mut(); @@ -293,36 +294,38 @@ impl Middleware for CorsMiddleware { let origin = match origin { Some(o) => o, None => { - // Some tools send preflight without Origin — allow if Any + // Some tools send preflight without Origin - allow if Any if matches!(self.config.allow_origins, AllowOrigins::Any) && !self.config.allow_credentials { return Ok(self.preflight_response("*")); } else { - let response = build_response( + return Ok(GenericBody::build_response( StatusCode::BAD_REQUEST, "CORS origin missing in preflight".to_string(), - ); - return response; + None, + )); } } }; // Validate origin if self.resolve_allowed_origin(&origin).is_none() { - let response = - build_response(StatusCode::FORBIDDEN, "CORS origin not allowed".to_string()); - return response; + return Ok(GenericBody::build_response( + StatusCode::FORBIDDEN, + "CORS origin not allowed".to_string(), + None, + )); } // Validate method if let Some(m) = requested_method { if !self.config.allow_methods.contains(&m) { - let response = build_response( + return Ok(GenericBody::build_response( StatusCode::METHOD_NOT_ALLOWED, "CORS method not allowed".to_string(), - ); - return response; + None, + )); } } @@ -335,14 +338,14 @@ impl Middleware for CorsMiddleware { .collect::>(); if !requested_headers.is_subset(&allowed) { - let response = build_response( + return Ok(GenericBody::build_response( StatusCode::BAD_REQUEST, "CORS header not allowed".to_string(), - ); - return response; + None, + )); } - // All good — return preflight + // All good - return preflight return Ok(self.preflight_response(&origin)); } @@ -404,7 +407,7 @@ mod tests { } fn make_handler<'req>(status: StatusCode, body: &'static str) -> MiddlewareNext<'req> { - Arc::new(move |_, _| { + Box::new(move |_, _| { let resp = Response::builder() .status(status) .body(GenericBody::from_string(body.to_string())) diff --git a/crates/rust-mcp-sdk/src/mcp_http/middleware/dns_rebind_protector.rs b/crates/rust-mcp-sdk/src/mcp_http/middleware/dns_rebind_protector.rs index 77e7013..4cdd7ef 100644 --- a/crates/rust-mcp-sdk/src/mcp_http/middleware/dns_rebind_protector.rs +++ b/crates/rust-mcp-sdk/src/mcp_http/middleware/dns_rebind_protector.rs @@ -15,9 +15,7 @@ //! - If allowlist is `None` or empty → that check is skipped use crate::{ - mcp_http::{ - error_response, middleware::BoxFutureResponse, types::GenericBody, McpAppState, Middleware, - }, + mcp_http::{error_response, types::GenericBody, McpAppState, Middleware, MiddlewareNext}, mcp_server::error::TransportServerResult, schema::schema_utils::SdkError, }; @@ -73,9 +71,7 @@ impl Middleware for DnsRebindProtector { &self, req: Request<&'req str>, state: Arc, - next: Arc< - dyn Fn(Request<&'req str>, Arc) -> BoxFutureResponse<'req> + Send + Sync, - >, + next: MiddlewareNext<'req>, ) -> TransportServerResult> { if let Err(error) = self.protect_dns_rebinding(req.headers()).await { return error_response(StatusCode::FORBIDDEN, error); diff --git a/crates/rust-mcp-sdk/src/mcp_http/middleware/logging_middleware.rs b/crates/rust-mcp-sdk/src/mcp_http/middleware/logging_middleware.rs index 49f2e52..957e04d 100644 --- a/crates/rust-mcp-sdk/src/mcp_http/middleware/logging_middleware.rs +++ b/crates/rust-mcp-sdk/src/mcp_http/middleware/logging_middleware.rs @@ -5,7 +5,7 @@ //! responses. In a real-world application, you might extend this to //! include structured logging, tracing, timing, or error reporting. use crate::{ - mcp_http::{middleware::BoxFutureResponse, types::GenericBody, McpAppState, Middleware}, + mcp_http::{types::GenericBody, McpAppState, Middleware, MiddlewareNext}, mcp_server::error::TransportServerResult, }; use async_trait::async_trait; @@ -24,9 +24,7 @@ impl Middleware for LoggingMiddleware { &self, req: Request<&'req str>, state: Arc, - next: Arc< - dyn Fn(Request<&'req str>, Arc) -> BoxFutureResponse<'req> + Send + Sync, - >, + next: MiddlewareNext<'req>, ) -> TransportServerResult> { println!("➡️ Logging request: {}", req.uri()); let res = next(req, state).await?; diff --git a/crates/rust-mcp-sdk/src/mcp_http/types.rs b/crates/rust-mcp-sdk/src/mcp_http/types.rs index 59645d2..b6db3dc 100644 --- a/crates/rust-mcp-sdk/src/mcp_http/types.rs +++ b/crates/rust-mcp-sdk/src/mcp_http/types.rs @@ -3,14 +3,44 @@ use crate::{ mcp_server::error::{TransportServerError, TransportServerResult}, }; use bytes::Bytes; -use http::{Request, Response}; +use futures::future::BoxFuture; +use http::{ + header::{ALLOW, CONTENT_TYPE}, + HeaderMap, HeaderName, HeaderValue, Method, Request, Response, StatusCode, +}; use http_body_util::{combinators::BoxBody, BodyExt, Full}; -use std::{future::Future, pin::Pin, sync::Arc}; +use serde_json::Value; +use std::sync::Arc; pub type GenericBody = BoxBody; pub trait GenericBodyExt { fn from_string(s: String) -> Self; + fn from_value(value: &Value) -> Self; + fn empty() -> Self; + fn build_response( + status_code: StatusCode, + payload: String, + headers: Option, + ) -> http::Response; + fn into_response( + self, + status_code: StatusCode, + headers: Option, + ) -> http::Response; + + fn into_json_response( + self, + status_code: StatusCode, + headers: Option, + ) -> http::Response; + + fn create_404_response() -> http::Response; + + fn create_405_response( + method: &Method, + allowed_methods: &[Method], + ) -> http::Response; } impl GenericBodyExt for GenericBody { @@ -19,23 +49,128 @@ impl GenericBodyExt for GenericBody { .map_err(|err| TransportServerError::HttpError(err.to_string())) .boxed() } + + fn from_value(value: &Value) -> Self { + let bytes = match serde_json::to_vec(value) { + Ok(vec) => Bytes::from(vec), + Err(_) => Bytes::from_static(b"{\"error\":\"internal_error\"}"), + }; + Full::new(bytes) + .map_err(|err| TransportServerError::HttpError(err.to_string())) + .boxed() + } + + fn empty() -> Self { + Full::new(Bytes::new()) + .map_err(|err| TransportServerError::HttpError(err.to_string())) + .boxed() + } + + fn build_response( + status_code: StatusCode, + payload: String, + headers: Option, + ) -> http::Response { + let body = Self::from_string(payload); + body.into_response(status_code, headers) + } + + fn into_json_response( + self, + status_code: StatusCode, + headers: Option, + ) -> http::Response { + let mut headers = headers.unwrap_or_default(); + headers.append(CONTENT_TYPE, HeaderValue::from_static("application/json")); + self.into_response(status_code, Some(headers)) + } + + fn into_response( + self, + status_code: StatusCode, + headers: Option, + ) -> http::Response { + let mut resp = http::Response::new(self); + *resp.status_mut() = status_code; + + if let Some(mut headers) = headers { + let mut current_name: Option = None; + for (name_opt, value) in headers.drain() { + if let Some(name) = name_opt { + current_name = Some(name); + } + let name = current_name.as_ref().unwrap(); + resp.headers_mut().append(name.clone(), value); + } + } + if !resp.headers().contains_key(CONTENT_TYPE) { + resp.headers_mut() + .append(CONTENT_TYPE, HeaderValue::from_static("text/plain")); + } + resp + } + + fn create_404_response() -> http::Response { + Self::empty().into_response(StatusCode::NOT_FOUND, None) + } + + fn create_405_response( + method: &Method, + allowed_methods: &[Method], + ) -> http::Response { + let allow_header_value = HeaderValue::from_str( + allowed_methods + .iter() + .map(|m| m.as_str()) + .collect::>() + .join(", ") + .as_str(), + ) + .unwrap_or(HeaderValue::from_static("unknown")); + let mut response = Self::from_string(format!( + "The method {method} is not allowed for this endpoint" + )) + .into_response(StatusCode::METHOD_NOT_ALLOWED, None); + response.headers_mut().append(ALLOW, allow_header_value); + response + } } -pub type BoxFutureResponse<'req> = - Pin>> + Send + 'req>>; +pub trait RequestExt { + fn insert(&mut self, val: T); + fn get(&self) -> Option<&T>; + fn take(self) -> (Self, Option) + where + Self: Sized; +} + +impl RequestExt for http::Request<&str> { + fn insert(&mut self, val: T) { + self.extensions_mut().insert(val); + } + + fn get(&self) -> Option<&T> { + self.extensions().get::() + } + + fn take(mut self) -> (Self, Option) { + let exts = self.extensions_mut(); + let val = exts.remove::(); + (self, val) + } +} -// Define a short alias for your handler function type. -/// A handler function that processes an HTTP request and shared state, -/// returning an async response future. -pub type RequestHandlerFn = - dyn for<'req> Fn(Request<&'req str>, Arc) -> BoxFutureResponse<'req> + Send + Sync; +pub type BoxFutureResponse<'req> = BoxFuture<'req, TransportServerResult>>; +// pub type BoxFutureResponse<'req> = +// Pin>> + Send + 'req>>; -/// A shared, reference-counted request handler. -pub type RequestHandler = Arc; +// Handler function type (can only be called once) +pub type RequestHandlerFnOnce = + dyn for<'req> FnOnce(Request<&'req str>, Arc) -> BoxFutureResponse<'req> + Send; -// pub type RequestHandler = Arc< -// dyn for<'req> FnOnce(Request<&'req str>) -> BoxFutureResponse<'req> + Send + Sync -// >; +// RequestHandler cannot be Arc<...> anymore because FnOnce isn’t clonable +pub type RequestHandler = Box; +// Middleware "next" closure type - can only be called once pub type MiddlewareNext<'req> = - Arc, Arc) -> BoxFutureResponse<'req> + Send + Sync>; + Box, Arc) -> BoxFutureResponse<'req> + Send>; diff --git a/crates/rust-mcp-sdk/src/mcp_runtimes/client_runtime.rs b/crates/rust-mcp-sdk/src/mcp_runtimes/client_runtime.rs index 2093dc3..70a18d2 100644 --- a/crates/rust-mcp-sdk/src/mcp_runtimes/client_runtime.rs +++ b/crates/rust-mcp-sdk/src/mcp_runtimes/client_runtime.rs @@ -2,9 +2,7 @@ pub mod mcp_client_runtime; pub mod mcp_client_runtime_core; use crate::error::{McpSdkError, SdkResult}; use crate::id_generator::FastIdGenerator; -use crate::mcp_traits::mcp_client::McpClient; -use crate::mcp_traits::mcp_handler::McpClientHandler; -use crate::mcp_traits::IdGenerator; +use crate::mcp_traits::{IdGenerator, McpClient, McpClientHandler}; use crate::utils::ensure_server_protocole_compatibility; use crate::{ mcp_traits::{RequestIdGen, RequestIdGenNumeric}, diff --git a/crates/rust-mcp-sdk/src/mcp_runtimes/client_runtime/mcp_client_runtime.rs b/crates/rust-mcp-sdk/src/mcp_runtimes/client_runtime/mcp_client_runtime.rs index 43a7079..06964ed 100644 --- a/crates/rust-mcp-sdk/src/mcp_runtimes/client_runtime/mcp_client_runtime.rs +++ b/crates/rust-mcp-sdk/src/mcp_runtimes/client_runtime/mcp_client_runtime.rs @@ -13,10 +13,7 @@ use async_trait::async_trait; use rust_mcp_transport::StreamableTransportOptions; use rust_mcp_transport::TransportDispatcher; -use crate::{ - error::SdkResult, mcp_client::ClientHandler, mcp_traits::mcp_handler::McpClientHandler, - McpClient, -}; +use crate::{error::SdkResult, mcp_client::ClientHandler, mcp_traits::McpClientHandler, McpClient}; use super::ClientRuntime; diff --git a/crates/rust-mcp-sdk/src/mcp_runtimes/client_runtime/mcp_client_runtime_core.rs b/crates/rust-mcp-sdk/src/mcp_runtimes/client_runtime/mcp_client_runtime_core.rs index 884de9d..21c0a4a 100644 --- a/crates/rust-mcp-sdk/src/mcp_runtimes/client_runtime/mcp_client_runtime_core.rs +++ b/crates/rust-mcp-sdk/src/mcp_runtimes/client_runtime/mcp_client_runtime_core.rs @@ -9,7 +9,7 @@ use crate::schema::{ use crate::{ error::SdkResult, mcp_handlers::mcp_client_handler_core::ClientHandlerCore, - mcp_traits::{mcp_client::McpClient, mcp_handler::McpClientHandler}, + mcp_traits::{McpClient, McpClientHandler}, }; use async_trait::async_trait; #[cfg(feature = "streamable-http")] diff --git a/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime.rs b/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime.rs index 5502cee..a429bae 100644 --- a/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime.rs +++ b/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime.rs @@ -1,9 +1,8 @@ pub mod mcp_server_runtime; pub mod mcp_server_runtime_core; +use crate::auth::AuthInfo; use crate::error::SdkResult; -use crate::mcp_traits::mcp_handler::McpServerHandler; -use crate::mcp_traits::mcp_server::McpServer; -use crate::mcp_traits::{RequestIdGen, RequestIdGenNumeric}; +use crate::mcp_traits::{McpServer, McpServerHandler, RequestIdGen, RequestIdGenNumeric}; use crate::schema::{ schema_utils::{ ClientMessage, ClientMessages, FromMessage, MessageFromServer, SdkError, ServerMessage, @@ -23,8 +22,7 @@ use std::panic; use std::sync::Arc; use std::time::Duration; use tokio::io::AsyncWriteExt; - -use tokio::sync::{mpsc, oneshot, watch}; +use tokio::sync::{mpsc, oneshot, watch, RwLock, RwLockReadGuard}; pub const DEFAULT_STREAM_ID: &str = "STANDALONE-STREAM"; const TASK_CHANNEL_CAPACITY: usize = 500; @@ -52,6 +50,7 @@ pub struct ServerRuntime { request_id_gen: Box, client_details_tx: watch::Sender>, client_details_rx: watch::Receiver>, + auth_info: tokio::sync::RwLock>, } #[async_trait] @@ -67,6 +66,30 @@ impl McpServer for ServerRuntime { }) } + async fn update_auth_info(&self, new_auth_info: Option) { + let should_update = { + let current = self.auth_info.read().await; + match (&*current, &new_auth_info) { + (None, Some(_)) => true, + (Some(old), Some(new)) => old.token_unique_id != new.token_unique_id, + (Some(_), None) => true, + (None, None) => false, + } + }; + + if should_update { + *self.auth_info.write().await = new_auth_info; + } + } + + async fn auth_info(&self) -> RwLockReadGuard<'_, Option> { + self.auth_info.read().await + } + async fn auth_info_cloned(&self) -> Option { + let guard = self.auth_info.read().await; + guard.clone() + } + async fn wait_for_initialization(&self) { loop { if self.client_details_rx.borrow().is_some() { @@ -548,7 +571,10 @@ impl ServerRuntime { server_details: Arc, handler: Arc, session_id: SessionId, + auth_info: Option, ) -> Arc { + use tokio::sync::RwLock; + let (client_details_tx, client_details_rx) = watch::channel::>(None); Arc::new(Self { @@ -559,6 +585,7 @@ impl ServerRuntime { client_details_tx, client_details_rx, request_id_gen: Box::new(RequestIdGenNumeric::new(None)), + auth_info: RwLock::new(auth_info), }) } @@ -586,6 +613,7 @@ impl ServerRuntime { client_details_tx, client_details_rx, request_id_gen: Box::new(RequestIdGenNumeric::new(None)), + auth_info: RwLock::new(None), }) } } diff --git a/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime/mcp_server_runtime.rs b/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime/mcp_server_runtime.rs index 62fd31f..c4eeb81 100644 --- a/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime/mcp_server_runtime.rs +++ b/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime/mcp_server_runtime.rs @@ -1,5 +1,7 @@ use std::sync::Arc; +#[cfg(feature = "hyper-server")] +use crate::auth::AuthInfo; use crate::schema::{ schema_utils::{ self, CallToolError, ClientMessage, ClientMessages, MessageFromServer, @@ -18,7 +20,7 @@ use rust_mcp_transport::SessionId; use crate::{ error::SdkResult, mcp_handlers::mcp_server_handler::ServerHandler, - mcp_traits::{mcp_handler::McpServerHandler, mcp_server::McpServer}, + mcp_traits::{McpServer, McpServerHandler}, }; /// Creates a new MCP server runtime with the specified configuration. @@ -62,8 +64,9 @@ pub(crate) fn create_server_instance( server_details: Arc, handler: Arc, session_id: SessionId, + auth_info: Option, ) -> Arc { - ServerRuntime::new_instance(server_details, handler, session_id) + ServerRuntime::new_instance(server_details, handler, session_id, auth_info) } pub(crate) struct ServerRuntimeInternalHandler { diff --git a/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime/mcp_server_runtime_core.rs b/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime/mcp_server_runtime_core.rs index 110b20b..c617cea 100644 --- a/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime/mcp_server_runtime_core.rs +++ b/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime/mcp_server_runtime_core.rs @@ -1,8 +1,7 @@ use super::ServerRuntime; use crate::error::SdkResult; use crate::mcp_handlers::mcp_server_handler_core::ServerHandlerCore; -use crate::mcp_traits::mcp_handler::McpServerHandler; -use crate::mcp_traits::mcp_server::McpServer; +use crate::mcp_traits::{McpServer, McpServerHandler}; use crate::schema::schema_utils::{ self, ClientMessage, MessageFromServer, NotificationFromClient, RequestFromClient, ResultFromServer, ServerMessage, diff --git a/crates/rust-mcp-sdk/src/mcp_traits.rs b/crates/rust-mcp-sdk/src/mcp_traits.rs index b66ba93..f265829 100644 --- a/crates/rust-mcp-sdk/src/mcp_traits.rs +++ b/crates/rust-mcp-sdk/src/mcp_traits.rs @@ -1,10 +1,15 @@ pub(super) mod id_generator; #[cfg(feature = "client")] -pub mod mcp_client; -pub mod mcp_handler; +mod mcp_client; +mod mcp_handler; #[cfg(feature = "server")] -pub mod mcp_server; +mod mcp_server; mod request_id_gen; pub use id_generator::*; +#[cfg(feature = "client")] +pub use mcp_client::*; +pub use mcp_handler::*; +#[cfg(feature = "server")] +pub use mcp_server::*; pub use request_id_gen::*; diff --git a/crates/rust-mcp-sdk/src/mcp_traits/mcp_server.rs b/crates/rust-mcp-sdk/src/mcp_traits/mcp_server.rs index da087d1..43e04f1 100644 --- a/crates/rust-mcp-sdk/src/mcp_traits/mcp_server.rs +++ b/crates/rust-mcp-sdk/src/mcp_traits/mcp_server.rs @@ -1,3 +1,6 @@ +use crate::auth::AuthInfo; +use crate::{error::SdkResult, utils::format_assertion_message}; + use crate::schema::{ schema_utils::{ ClientMessage, McpMessage, MessageFromServer, NotificationFromServer, RequestFromServer, @@ -14,10 +17,10 @@ use crate::schema::{ ResourceUpdatedNotification, ResourceUpdatedNotificationParams, RpcError, ServerCapabilities, SetLevelRequest, ToolListChangedNotification, ToolListChangedNotificationParams, }; -use crate::{error::SdkResult, utils::format_assertion_message}; use async_trait::async_trait; use rust_mcp_transport::SessionId; use std::{sync::Arc, time::Duration}; +use tokio::sync::RwLockReadGuard; //TODO: support options , such as enforceStrictCapabilities #[async_trait] @@ -27,6 +30,10 @@ pub trait McpServer: Sync + Send { fn server_info(&self) -> &InitializeResult; fn client_info(&self) -> Option; + async fn auth_info(&self) -> RwLockReadGuard<'_, Option>; + async fn auth_info_cloned(&self) -> Option; + async fn update_auth_info(&self, auth_info: Option); + async fn wait_for_initialization(&self); async fn send( diff --git a/crates/rust-mcp-sdk/src/utils.rs b/crates/rust-mcp-sdk/src/utils.rs index 2d80f1e..c63010d 100644 --- a/crates/rust-mcp-sdk/src/utils.rs +++ b/crates/rust-mcp-sdk/src/utils.rs @@ -2,6 +2,9 @@ use crate::error::{McpSdkError, ProtocolErrorKind, SdkResult}; use crate::schema::schema_utils::{ClientMessages, SdkError}; use crate::schema::ProtocolVersion; use std::cmp::Ordering; +use std::time::{Duration, SystemTime, UNIX_EPOCH}; +#[cfg(feature = "auth")] +use url::Url; /// A guard type that automatically aborts a Tokio task when dropped. /// @@ -42,6 +45,11 @@ pub fn format_assertion_message(entity: &str, capability: &str, method_name: &st format!("{entity} does not support {capability} (required for {method_name})") } +// Function to convert Unix timestamp to SystemTime +pub fn unix_timestamp_to_systemtime(timestamp: u64) -> SystemTime { + UNIX_EPOCH + Duration::from_secs(timestamp) +} + /// Checks if the client and server protocol versions are compatible by ensuring they are equal. /// /// This function compares the provided client and server protocol versions. If they are equal, @@ -233,6 +241,30 @@ pub fn valid_initialize_method(json_str: &str) -> SdkResult<()> { Ok(()) } +#[cfg(feature = "auth")] +pub fn join_url(base: &Url, segment: &str) -> Result { + // Fast early check - Url must be absolute + if base.cannot_be_a_base() { + return Err(url::ParseError::RelativeUrlWithoutBase); + } + + // We have to clone - there is no way around this when taking &Url + let mut url = base.clone(); + + // This is the official, safe, and correct way + url.path_segments_mut() + .map_err(|_| url::ParseError::RelativeUrlWithoutBase)? + .pop_if_empty() // makes it act like a directory + .extend( + segment + .trim_start_matches('/') + .split('/') + .filter(|s| !s.is_empty()), + ); + + Ok(url) +} + #[cfg(test)] mod tests { use super::*; @@ -250,4 +282,36 @@ mod tests { ); assert_eq!(remove_query_and_hash("/"), "/"); } + + #[test] + fn test_join_url() { + let expect = "http://example.com/api/user/userinfo"; + let result = join_url( + &Url::parse("http://example.com/api").unwrap(), + "/user/userinfo", + ) + .unwrap(); + assert_eq!(result.to_string(), expect); + + let result = join_url( + &Url::parse("http://example.com/api").unwrap(), + "user/userinfo", + ) + .unwrap(); + assert_eq!(result.to_string(), expect); + + let result = join_url( + &Url::parse("http://example.com/api/").unwrap(), + "/user/userinfo", + ) + .unwrap(); + assert_eq!(result.to_string(), expect); + + let result = join_url( + &Url::parse("http://example.com/api/").unwrap(), + "user/userinfo", + ) + .unwrap(); + assert_eq!(result.to_string(), expect); + } } diff --git a/crates/rust-mcp-sdk/tests/common/common.rs b/crates/rust-mcp-sdk/tests/common/common.rs index d6b45f7..8e61704 100644 --- a/crates/rust-mcp-sdk/tests/common/common.rs +++ b/crates/rust-mcp-sdk/tests/common/common.rs @@ -8,6 +8,7 @@ use rust_mcp_macros::{mcp_tool, JsonSchema}; use rust_mcp_schema::ProtocolVersion; use rust_mcp_sdk::mcp_client::ClientHandler; +use rust_mcp_sdk::auth::{AuthInfo, AuthenticationError, OauthTokenVerifier}; use rust_mcp_sdk::schema::{ClientCapabilities, Implementation, InitializeRequestParams}; use std::collections::HashMap; use std::process; @@ -142,6 +143,31 @@ pub async fn send_get_request( client.get(url).headers(headers).send().await } +pub async fn send_option_request( + base_url: &str, + extra_headers: Option>, +) -> Result { + let client = Client::new(); + let url = Url::parse(base_url).expect("Invalid URL"); + + let mut headers = reqwest::header::HeaderMap::new(); + + if let Some(extra) = extra_headers { + for (key, value) in extra { + headers.insert( + reqwest::header::HeaderName::from_bytes(key.as_bytes()).unwrap(), + value.parse().unwrap(), + ); + } + } + + client + .request(reqwest::Method::OPTIONS, url) + .headers(headers) + .send() + .await +} + use futures::stream::Stream; // stream: &mut impl Stream>, @@ -260,11 +286,10 @@ impl Xorshift { // Generate the next random u64 using Xorshift fn next_u64(&mut self) -> u64 { let mut x = self.state; - x ^= x << 13; - x ^= x >> 7; - x ^= x << 17; - self.state = x; - x + self.state = x.wrapping_add(0x9E3779B97F4A7C15); + x = (x ^ (x >> 30)).wrapping_mul(0xBF58476D1CE4E5B9); + x = (x ^ (x >> 27)).wrapping_mul(0x94D049BB133111EB); + x ^ (x >> 31) } // Generate a random u16 within a range [min, max] @@ -345,6 +370,35 @@ pub mod sample_tools { } } + //******************// + // AuthInfo Tool // + //******************// + #[mcp_tool( + name = "display_auth_info", + description = "Displays auth_info if user is authenticated", + idempotent_hint = false, + destructive_hint = false, + open_world_hint = false, + read_only_hint = false + )] + #[derive(Debug, ::serde::Deserialize, ::serde::Serialize, JsonSchema)] + pub struct DisplayAuthInfo {} + use rust_mcp_sdk::auth::AuthInfo; + impl DisplayAuthInfo { + pub fn call_tool( + &self, + auth_info: Option, + ) -> Result { + let message = format!("{}", serde_json::to_string(&auth_info).unwrap()); + #[cfg(feature = "2025_06_18")] + return Ok(CallToolResult::text_content(vec![ + rust_mcp_sdk::schema::TextContent::from(message), + ])); + #[cfg(not(feature = "2025_06_18"))] + return Ok(CallToolResult::text_content(message, None)); + } + } + //******************// // SayGoodbyeTool // //******************// @@ -420,7 +474,6 @@ pub async fn wiremock_request(mock_server: &MockServer, index: usize) -> Request pub async fn debug_wiremock(mock_server: &MockServer) { let requests = mock_server.received_requests().await.unwrap(); let len = requests.len(); - println!(">>> {len} request(s) received <<<"); for (index, request) in requests.iter().enumerate() { println!("\n--- #{index} of {len} ---"); @@ -460,3 +513,32 @@ pub async fn wait_for_n_requests( .await .unwrap(); } + +pub struct TestTokenVerifier { + token_map: HashMap, +} + +impl TestTokenVerifier { + pub fn new(token_map: HashMap) -> Self { + Self { token_map } + } +} + +#[async_trait] +impl OauthTokenVerifier for TestTokenVerifier { + async fn verify_token(&self, access_token: String) -> Result { + let info = self.token_map.get(&access_token); + + let Some(info) = info else { + return Err(AuthenticationError::InactiveToken); + }; + + if info.expires_at.unwrap() < SystemTime::now() { + return Err(AuthenticationError::InvalidOrExpiredToken( + "expired".to_string(), + )); + } + + Ok(info.clone()) + } +} diff --git a/crates/rust-mcp-sdk/tests/common/mock_server.rs b/crates/rust-mcp-sdk/tests/common/mock_server.rs index f5b533a..59c9b16 100644 --- a/crates/rust-mcp-sdk/tests/common/mock_server.rs +++ b/crates/rust-mcp-sdk/tests/common/mock_server.rs @@ -330,7 +330,6 @@ impl MockServerHandle { let requests = self.get_history().await; let len = requests.len(); - println!("\n>>> {len} request(s) received <<<"); for (index, (request, response)) in requests.iter().enumerate() { println!( diff --git a/crates/rust-mcp-sdk/tests/common/test_server.rs b/crates/rust-mcp-sdk/tests/common/test_server.rs index d64244b..9c8e6ee 100644 --- a/crates/rust-mcp-sdk/tests/common/test_server.rs +++ b/crates/rust-mcp-sdk/tests/common/test_server.rs @@ -1,6 +1,6 @@ #[cfg(feature = "hyper-server")] pub mod test_server_common { - use crate::common::sample_tools::SayHelloTool; + use crate::common::sample_tools::{DisplayAuthInfo, SayHelloTool}; use async_trait::async_trait; use rust_mcp_schema::schema_utils::CallToolError; use rust_mcp_schema::{ @@ -95,20 +95,27 @@ pub mod test_server_common { runtime .assert_server_request_capabilities(request.method()) .map_err(CallToolError::new)?; - if request.params.name != "say_hello" { - Ok( - CallToolError::unknown_tool(format!("Unknown tool: {}", request.params.name)) - .into(), - ) - } else { - let tool = SayHelloTool { - name: request.params.arguments.unwrap()["name"] - .as_str() - .unwrap() - .to_string(), - }; - - Ok(tool.call_tool().unwrap()) + + match request.params.name.as_str() { + "say_hello" => { + let tool = SayHelloTool { + name: request.params.arguments.unwrap()["name"] + .as_str() + .unwrap() + .to_string(), + }; + + Ok(tool.call_tool().unwrap()) + } + "display_auth_info" => { + let tool = DisplayAuthInfo {}; + Ok(tool.call_tool(runtime.auth_info_cloned().await).unwrap()) + } + _ => Ok(CallToolError::unknown_tool(format!( + "Unknown tool: {}", + request.params.name + )) + .into()), } } } diff --git a/crates/rust-mcp-sdk/tests/test_streamable_http_server.rs b/crates/rust-mcp-sdk/tests/test_streamable_http_server.rs index 3592c97..43f162d 100644 --- a/crates/rust-mcp-sdk/tests/test_streamable_http_server.rs +++ b/crates/rust-mcp-sdk/tests/test_streamable_http_server.rs @@ -1,5 +1,12 @@ -use std::{collections::HashMap, error::Error, sync::Arc, time::Duration, vec}; - +use crate::common::{ + random_port, read_sse_event, read_sse_event_from_stream, send_delete_request, send_get_request, + send_option_request, send_post_request, + test_server_common::{ + create_start_server, initialize_request, LaunchedServer, TestIdGenerator, + }, + TestTokenVerifier, +}; +use http::header::{ACCEPT, ACCESS_CONTROL_ALLOW_ORIGIN, AUTHORIZATION, CONTENT_TYPE}; use hyper::StatusCode; use rust_mcp_schema::{ schema_utils::{ @@ -12,36 +19,89 @@ use rust_mcp_schema::{ LoggingMessageNotificationParams, RequestId, RootsListChangedNotification, ServerNotification, ServerRequest, ServerResult, }; -use rust_mcp_sdk::{event_store::InMemoryEventStore, mcp_server::HyperServerOptions}; +use rust_mcp_sdk::{ + auth::{AuthInfo, AuthMetadataBuilder, AuthProvider, RemoteAuthProvider}, + event_store::InMemoryEventStore, + mcp_server::HyperServerOptions, +}; use serde_json::{json, Map, Value}; - -use crate::common::{ - random_port, read_sse_event, read_sse_event_from_stream, send_delete_request, send_get_request, - send_post_request, - test_server_common::{ - create_start_server, initialize_request, LaunchedServer, TestIdGenerator, - }, +use std::{ + collections::HashMap, + error::Error, + sync::Arc, + time::{Duration, SystemTime}, + vec, }; +use url::Url; #[path = "common/common.rs"] pub mod common; const ONE_MILLISECOND: Option = Some(Duration::from_millis(1)); +pub const VALID_ACCESS_TOKEN: &str = "valid-access-token"; async fn initialize_server( enable_json_response: Option, + auth_token_map: Option>, ) -> Result<(LaunchedServer, String), Box> { let json_rpc_message: ClientJsonrpcRequest = ClientJsonrpcRequest::new(RequestId::Integer(0), initialize_request().into()); + let port = random_port(); + + let auth = auth_token_map.and_then(|token_map| { + let mut token_map: HashMap = token_map + .into_iter() + .map(|(k, v)| (k.to_string(), v)) + .collect(); + + token_map.insert( + VALID_ACCESS_TOKEN.to_string(), + AuthInfo { + token_unique_id: VALID_ACCESS_TOKEN.to_string(), + client_id: Some("valid-client-id".to_string()), + scopes: Some(vec!["mcp".to_string(), "mcp:tools".to_string()]), + expires_at: Some(SystemTime::now() + Duration::from_secs(90)), + audience: None, + extra: None, + user_id: None, + }, + ); + + let (auth_server_meta, protected_resource_met) = + AuthMetadataBuilder::new("http://127.0.0.1:3000/mcp") + .issuer("http://localhost:3030") + .authorization_servers(vec!["http://localhost:3030"]) + .authorization_endpoint("/authorize") + .token_endpoint("/token") + .scopes_supported(vec!["mcp:tools".to_string()]) + .introspection_endpoint("/introspect") + .resource_name("MCP Test Server".to_string()) + .build() + .unwrap(); + + let token_verifier = TestTokenVerifier::new(token_map); + + Some(RemoteAuthProvider::new( + auth_server_meta, + protected_resource_met, + Box::new(token_verifier), + None, + )) + }); + + let oauth_metadata_provider = auth.map(|v| -> Arc { Arc::new(v) }); + let server_options = HyperServerOptions { - port: random_port(), + port, session_id_generator: Some(Arc::new(TestIdGenerator::new(vec![ "AAA-BBB-CCC".to_string() ]))), enable_json_response, ping_interval: Duration::from_secs(1), event_store: Some(Arc::new(InMemoryEventStore::default())), + auth: oauth_metadata_provider, + ..Default::default() }; @@ -52,7 +112,14 @@ async fn initialize_server( &server.streamable_url, &serde_json::to_string(&json_rpc_message).unwrap(), None, - None, + Some(HashMap::from([ + ( + AUTHORIZATION.as_str(), + format!("Bearer {VALID_ACCESS_TOKEN}").as_str(), + ), + (CONTENT_TYPE.as_str(), "application/json"), + (ACCEPT.as_str(), "application/json, text/event-stream"), + ])), ) .await .expect("Request failed"); @@ -154,7 +221,7 @@ async fn should_reject_batch_initialize_request() { // should handle post requests via sse response correctly #[tokio::test] async fn should_handle_post_requests_via_sse_response_correctly() { - let (server, session_id) = initialize_server(None).await.unwrap(); + let (server, session_id) = initialize_server(None, None).await.unwrap(); let json_rpc_message: ClientJsonrpcRequest = ClientJsonrpcRequest::new(RequestId::Integer(1), ListToolsRequest::new(None).into()); @@ -196,7 +263,7 @@ async fn should_handle_post_requests_via_sse_response_correctly() { // should call a tool and return the result #[tokio::test] async fn should_call_a_tool_and_return_the_result() { - let (server, session_id) = initialize_server(None).await.unwrap(); + let (server, session_id) = initialize_server(None, None).await.unwrap(); let mut map = Map::new(); map.insert("name".to_string(), Value::String("Ali".to_string())); @@ -244,7 +311,7 @@ async fn should_call_a_tool_and_return_the_result() { // should reject requests without a valid session ID #[tokio::test] async fn should_reject_requests_without_a_valid_session_id() { - let (server, _session_id) = initialize_server(None).await.unwrap(); + let (server, _session_id) = initialize_server(None, None).await.unwrap(); let json_rpc_message: ClientJsonrpcRequest = ClientJsonrpcRequest::new(RequestId::Integer(1), ListToolsRequest::new(None).into()); @@ -269,7 +336,7 @@ async fn should_reject_requests_without_a_valid_session_id() { // should reject invalid session ID #[tokio::test] async fn should_reject_invalid_session_id() { - let (server, _session_id) = initialize_server(None).await.unwrap(); + let (server, _session_id) = initialize_server(None, None).await.unwrap(); let json_rpc_message: ClientJsonrpcRequest = ClientJsonrpcRequest::new(RequestId::Integer(1), ListToolsRequest::new(None).into()); @@ -315,7 +382,7 @@ async fn get_standalone_stream( // should establish standalone SSE stream and receive server-initiated messages #[tokio::test] async fn should_establish_standalone_stream_and_receive_server_messages() { - let (server, session_id) = initialize_server(None).await.unwrap(); + let (server, session_id) = initialize_server(None, None).await.unwrap(); let response = get_standalone_stream(&server.streamable_url, &session_id, None).await; assert_eq!(response.status(), StatusCode::OK); @@ -377,7 +444,7 @@ async fn should_establish_standalone_stream_and_receive_server_messages() { // should establish standalone SSE stream and receive server-initiated requests #[tokio::test] async fn should_establish_standalone_stream_and_receive_server_requests() { - let (server, session_id) = initialize_server(None).await.unwrap(); + let (server, session_id) = initialize_server(None, None).await.unwrap(); let response = get_standalone_stream(&server.streamable_url, &session_id, None).await; assert_eq!(response.status(), StatusCode::OK); @@ -462,7 +529,7 @@ async fn should_establish_standalone_stream_and_receive_server_requests() { // should not close GET SSE stream after sending multiple server notifications #[tokio::test] async fn should_not_close_get_sse_stream() { - let (server, session_id) = initialize_server(None).await.unwrap(); + let (server, session_id) = initialize_server(None, None).await.unwrap(); let response = get_standalone_stream(&server.streamable_url, &session_id, None).await; assert_eq!(response.status(), StatusCode::OK); @@ -533,7 +600,7 @@ async fn should_not_close_get_sse_stream() { //should reject second SSE stream for the same session #[tokio::test] async fn should_reject_second_sse_stream_for_the_same_session() { - let (server, session_id) = initialize_server(None).await.unwrap(); + let (server, session_id) = initialize_server(None, None).await.unwrap(); let response = get_standalone_stream(&server.streamable_url, &session_id, None).await; assert_eq!(response.status(), StatusCode::OK); @@ -550,7 +617,7 @@ async fn should_reject_second_sse_stream_for_the_same_session() { // should reject GET requests without Accept: text/event-stream header #[tokio::test] async fn should_reject_get_requests() { - let (server, session_id) = initialize_server(None).await.unwrap(); + let (server, session_id) = initialize_server(None, None).await.unwrap(); let mut headers = HashMap::new(); headers.insert("Accept", "application/json"); @@ -573,7 +640,7 @@ async fn should_reject_get_requests() { // should reject POST requests without proper Accept header #[tokio::test] async fn reject_post_requests_without_accept_header() { - let (server, session_id) = initialize_server(None).await.unwrap(); + let (server, session_id) = initialize_server(None, None).await.unwrap(); let json_rpc_message: ClientJsonrpcRequest = ClientJsonrpcRequest::new(RequestId::Integer(1), ListToolsRequest::new(None).into()); @@ -607,7 +674,7 @@ async fn reject_post_requests_without_accept_header() { //should reject unsupported Content-Type #[tokio::test] async fn should_reject_unsupported_content_type() { - let (server, session_id) = initialize_server(None).await.unwrap(); + let (server, session_id) = initialize_server(None, None).await.unwrap(); let json_rpc_message: ClientJsonrpcRequest = ClientJsonrpcRequest::new(RequestId::Integer(1), ListToolsRequest::new(None).into()); @@ -642,7 +709,7 @@ async fn should_reject_unsupported_content_type() { // should handle JSON-RPC batch notification messages with 202 response #[tokio::test] async fn should_handle_batch_notification_messages_with_202_response() { - let (server, session_id) = initialize_server(None).await.unwrap(); + let (server, session_id) = initialize_server(None, None).await.unwrap(); let batch_notification = ClientMessages::Batch(vec![ ClientMessage::from_message(RootsListChangedNotification::new(None), None).unwrap(), @@ -663,7 +730,7 @@ async fn should_handle_batch_notification_messages_with_202_response() { // should properly handle invalid JSON data #[tokio::test] async fn should_properly_handle_invalid_json_data() { - let (server, session_id) = initialize_server(None).await.unwrap(); + let (server, session_id) = initialize_server(None, None).await.unwrap(); let response = send_post_request( &server.streamable_url, @@ -684,7 +751,7 @@ async fn should_properly_handle_invalid_json_data() { // should send response messages to the connection that sent the request #[tokio::test] async fn should_send_response_messages_to_the_connection_that_sent_the_request() { - let (server, session_id) = initialize_server(None).await.unwrap(); + let (server, session_id) = initialize_server(None, None).await.unwrap(); let json_rpc_message_1: ClientJsonrpcRequest = ClientJsonrpcRequest::new(RequestId::Integer(1), ListToolsRequest::new(None).into()); @@ -764,7 +831,7 @@ async fn should_send_response_messages_to_the_connection_that_sent_the_request() // should properly handle DELETE requests and close session #[tokio::test] async fn should_properly_handle_delete_requests_and_close_session() { - let (server, session_id) = initialize_server(None).await.unwrap(); + let (server, session_id) = initialize_server(None, None).await.unwrap(); let mut headers = HashMap::new(); headers.insert("Content-Type", "text/plain"); @@ -785,7 +852,7 @@ async fn should_properly_handle_delete_requests_and_close_session() { // should reject DELETE requests with invalid session ID #[tokio::test] async fn should_reject_delete_requests_with_invalid_session_id() { - let (server, _session_id) = initialize_server(None).await.unwrap(); + let (server, _session_id) = initialize_server(None, None).await.unwrap(); let mut headers = HashMap::new(); headers.insert("Content-Type", "text/plain"); @@ -819,7 +886,7 @@ async fn should_reject_delete_requests_with_invalid_session_id() { // should accept requests without protocol version header #[tokio::test] async fn should_accept_requests_without_protocol_version_header() { - let (server, session_id) = initialize_server(None).await.unwrap(); + let (server, session_id) = initialize_server(None, None).await.unwrap(); let mut headers = HashMap::new(); headers.insert("Content-Type", "application/json"); @@ -846,7 +913,7 @@ async fn should_accept_requests_without_protocol_version_header() { // should reject requests with unsupported protocol version #[tokio::test] async fn should_reject_requests_with_unsupported_protocol_version() { - let (server, session_id) = initialize_server(None).await.unwrap(); + let (server, session_id) = initialize_server(None, None).await.unwrap(); let mut headers = HashMap::new(); headers.insert("Content-Type", "application/json"); @@ -878,7 +945,7 @@ async fn should_reject_requests_with_unsupported_protocol_version() { // should handle protocol version validation for get requests #[tokio::test] async fn should_handle_protocol_version_validation_for_get_requests() { - let (server, session_id) = initialize_server(None).await.unwrap(); + let (server, session_id) = initialize_server(None, None).await.unwrap(); let mut headers = HashMap::new(); headers.insert("Content-Type", "application/json"); @@ -903,7 +970,7 @@ async fn should_handle_protocol_version_validation_for_get_requests() { // should handle protocol version validation for DELETE requests #[tokio::test] async fn should_handle_protocol_version_validation_for_delete_requests() { - let (server, session_id) = initialize_server(None).await.unwrap(); + let (server, session_id) = initialize_server(None, None).await.unwrap(); let mut headers = HashMap::new(); headers.insert("Content-Type", "application/json"); @@ -931,7 +998,7 @@ async fn should_handle_protocol_version_validation_for_delete_requests() { // should return JSON response for a single request #[tokio::test] async fn should_return_json_response_for_a_single_request() { - let (server, session_id) = initialize_server(Some(true)).await.unwrap(); + let (server, session_id) = initialize_server(Some(true), None).await.unwrap(); let json_rpc_message: ClientJsonrpcRequest = ClientJsonrpcRequest::new(RequestId::Integer(1), ListToolsRequest::new(None).into()); @@ -975,7 +1042,7 @@ async fn should_return_json_response_for_a_single_request() { // should return JSON response for batch requests #[tokio::test] async fn should_return_json_response_for_a_batch_request() { - let (server, session_id) = initialize_server(Some(true)).await.unwrap(); + let (server, session_id) = initialize_server(Some(true), None).await.unwrap(); let json_rpc_message_1: ClientJsonrpcRequest = ClientJsonrpcRequest::new( RequestId::String("req_1".to_string()), @@ -1053,7 +1120,7 @@ async fn should_return_json_response_for_a_batch_request() { // should handle batch request messages with SSE stream for responses #[tokio::test] async fn should_handle_batch_request_messages_with_sse_stream_for_responses() { - let (server, session_id) = initialize_server(None).await.unwrap(); + let (server, session_id) = initialize_server(None, None).await.unwrap(); let json_rpc_message_1: ClientJsonrpcRequest = ClientJsonrpcRequest::new( RequestId::String("req_1".to_string()), @@ -1131,18 +1198,18 @@ async fn should_accept_requests_with_allowed_host_headers() { ClientJsonrpcRequest::new(RequestId::Integer(0), initialize_request().into()); let server_options = HyperServerOptions { - port: 8090, + port: 9090, session_id_generator: Some(Arc::new(TestIdGenerator::new(vec![ "AAA-BBB-CCC".to_string() ]))), - allowed_hosts: Some(vec!["127.0.0.1:8090".to_string()]), + allowed_hosts: Some(vec!["127.0.0.1:9090".to_string()]), dns_rebinding_protection: true, ..Default::default() }; let server = create_start_server(server_options).await; - tokio::time::sleep(Duration::from_millis(250)).await; + tokio::time::sleep(Duration::from_millis(350)).await; let response = send_post_request( &server.streamable_url, &serde_json::to_string(&json_rpc_message).unwrap(), @@ -1372,7 +1439,7 @@ async fn should_skip_all_validations_when_false() { #[tokio::test] async fn should_store_and_include_event_ids_in_server_sse_messages() { common::init_tracing(); - let (server, session_id) = initialize_server(Some(true)).await.unwrap(); + let (server, session_id) = initialize_server(Some(true), None).await.unwrap(); let response = get_standalone_stream(&server.streamable_url, &session_id, None).await; assert_eq!(response.status(), StatusCode::OK); @@ -1447,7 +1514,7 @@ async fn should_store_and_include_event_ids_in_server_sse_messages() { #[tokio::test] async fn should_store_and_replay_mcp_server_tool_notifications() { common::init_tracing(); - let (server, session_id) = initialize_server(Some(true)).await.unwrap(); + let (server, session_id) = initialize_server(Some(true), None).await.unwrap(); let response = get_standalone_stream(&server.streamable_url, &session_id, None).await; assert_eq!(response.status(), StatusCode::OK); @@ -1517,6 +1584,130 @@ async fn should_store_and_replay_mcp_server_tool_notifications() { assert_eq!(notification1.params.data.as_str().unwrap(), "notification2"); } +#[tokio::test] +async fn metadata_requires_get_method() { + common::init_tracing(); + let auth_map = HashMap::new(); + let (server, _session_id) = initialize_server(Some(true), Some(auth_map)).await.unwrap(); + + let mut headers = HashMap::new(); + headers.insert("Content-Type", "application/json"); + headers.insert("Accept", "application/json, text/event-stream"); + + let url = Url::parse(&server.streamable_url).unwrap(); + let url = url.join("/.well-known/oauth-authorization-server").unwrap(); + let response = send_post_request(&url.to_string(), "", None, Some(headers)) + .await + .unwrap(); + assert_eq!(response.status(), StatusCode::METHOD_NOT_ALLOWED); +} + +#[tokio::test] +async fn should_return_the_metadata_object_with_cors() { + common::init_tracing(); + let auth_map = HashMap::new(); + let (server, _session_id) = initialize_server(Some(true), Some(auth_map)).await.unwrap(); + + let mut headers = HashMap::new(); + headers.insert("Content-Type", "application/json"); + headers.insert("Accept", "application/json, text/event-stream"); + headers.insert("Origin", "https://example.com"); + + let url = Url::parse(&server.streamable_url).unwrap(); + let url = url.join("/.well-known/oauth-authorization-server").unwrap(); + let response = send_get_request(&url.to_string(), Some(headers)) + .await + .unwrap(); + + let allow_origin = response.headers().get(ACCESS_CONTROL_ALLOW_ORIGIN).unwrap(); + assert_eq!(allow_origin, "*"); + + let metadata = response.json::().await.unwrap(); + let issuer = metadata.get("issuer").unwrap().as_str().unwrap(); + assert_eq!(issuer, "http://localhost:3030/"); +} + +#[tokio::test] +// supports OPTIONS preflight requests +async fn should_support_options_preflight_requests() { + common::init_tracing(); + let auth_map = HashMap::new(); + let (server, _session_id) = initialize_server(Some(true), Some(auth_map)).await.unwrap(); + + let mut headers = HashMap::new(); + headers.insert("Content-Type", "application/json"); + headers.insert("Access-Control-Request-Method", "GET"); + headers.insert("Origin", "https://example.com"); + + let url = Url::parse(&server.streamable_url).unwrap(); + let url = url.join("/.well-known/oauth-authorization-server").unwrap(); + let response = send_option_request(&url.to_string(), Some(headers)) + .await + .unwrap(); + + assert_eq!(response.status(), StatusCode::NO_CONTENT); + let allow_origin = response.headers().get(ACCESS_CONTROL_ALLOW_ORIGIN).unwrap(); + assert_eq!(allow_origin, "*"); +} + +#[tokio::test] +// should call a tool with authInfo when authenticated +async fn should_call_a_tool_with_auth_info_when_authenticated() { + common::init_tracing(); + let auth_map = HashMap::new(); + let (server, session_id) = initialize_server(Some(false), Some(auth_map)) + .await + .unwrap(); + + let json_rpc_message: ClientJsonrpcRequest = ClientJsonrpcRequest::new( + RequestId::Integer(1), + CallToolRequest::new(CallToolRequestParams { + arguments: None, + name: "display_auth_info".to_string(), + }) + .into(), + ); + + let response = send_post_request( + &server.streamable_url, + &serde_json::to_string(&json_rpc_message).unwrap(), + Some(&session_id), + Some(HashMap::from([ + ( + AUTHORIZATION.as_str(), + format!("Bearer {VALID_ACCESS_TOKEN}").as_str(), + ), + (CONTENT_TYPE.as_str(), "application/json"), + (ACCEPT.as_str(), "application/json, text/event-stream"), + ])), + ) + .await + .expect("Request failed"); + + assert_eq!(response.status(), StatusCode::OK); + + let events = read_sse_event(response, 1).await.unwrap(); + let message: ServerJsonrpcResponse = serde_json::from_str(&events[0].2).unwrap(); + + assert!(matches!(message.id, RequestId::Integer(1))); + + let ResultFromServer::ServerResult(ServerResult::CallToolResult(result)) = message.result + else { + panic!("invalid CallToolResult") + }; + + let response_json: Value = + serde_json::from_str(&result.content[0].as_text_content().unwrap().text).unwrap(); + + assert_eq!(response_json["client_id"], "valid-client-id"); + assert_eq!(response_json["token_unique_id"], "valid-access-token"); + + assert!(response_json["scopes"] + .as_array() + .unwrap() + .iter() + .all(|s| ["mcp", "mcp:tools"].contains(&s.as_str().unwrap())),); +} // should return 400 error for invalid JSON-RPC messages // should keep stream open after sending server notifications // NA: should reject second initialization request @@ -1525,8 +1716,7 @@ async fn should_store_and_replay_mcp_server_tool_notifications() { // should reject requests to uninitialized server // should accept requests with matching protocol version // should accept when protocol version differs from negotiated version -// should call a tool with authInfo -// should calls tool without authInfo when it is optional + // should accept pre-parsed request body // should handle pre-parsed batch messages // should prefer pre-parsed body over request body diff --git a/crates/rust-mcp-transport/src/constants.rs b/crates/rust-mcp-transport/src/constants.rs index 6ae0342..05ebb5d 100644 --- a/crates/rust-mcp-transport/src/constants.rs +++ b/crates/rust-mcp-transport/src/constants.rs @@ -1,3 +1,3 @@ -pub const MCP_SESSION_ID_HEADER: &str = "Mcp-Session-Id"; -pub const MCP_PROTOCOL_VERSION_HEADER: &str = "Mcp-Protocol-Version"; +pub const MCP_SESSION_ID_HEADER: &str = "mcp-session-id"; +pub const MCP_PROTOCOL_VERSION_HEADER: &str = "mcp-protocol-version"; pub const MCP_LAST_EVENT_ID_HEADER: &str = "last-event-id"; diff --git a/crates/rust-mcp-transport/src/utils/sse_parser.rs b/crates/rust-mcp-transport/src/utils/sse_parser.rs index 3074e9f..8ed5a07 100644 --- a/crates/rust-mcp-transport/src/utils/sse_parser.rs +++ b/crates/rust-mcp-transport/src/utils/sse_parser.rs @@ -46,7 +46,7 @@ impl SseParser { pub fn process_new_chunk(&mut self, bytes: Bytes) -> Vec { self.buffer.extend_from_slice(&bytes); - // Collect complete lines (ending in \n)—keep ALL lines, including empty ones for \n\n detection + // Collect complete lines (ending in \n)-keep ALL lines, including empty ones for \n\n detection let mut lines = Vec::new(); while let Some(pos) = self.buffer.iter().position(|&b| b == b'\n') { let line = self.buffer.split_to(pos + 1).freeze(); diff --git a/examples/auth/server-oauth-remote/Cargo.toml b/examples/auth/server-oauth-remote/Cargo.toml new file mode 100644 index 0000000..f9d7f2b --- /dev/null +++ b/examples/auth/server-oauth-remote/Cargo.toml @@ -0,0 +1,31 @@ +[package] +name = "server-oauth-remote" +version = "0.1.34" +edition = "2021" +publish = false +license = "MIT" + + +[dependencies] +rust-mcp-sdk = { workspace = true, default-features = false, features = [ + "server", + "macros", + "streamable-http", + "sse", + "auth", + "hyper-server", + "2025_06_18", +] } +rust-mcp-extra={ workspace = true, features=["auth"]} + +tokio = { workspace = true } +serde = { workspace = true } +serde_json = { workspace = true } +async-trait = { workspace = true } +futures = { workspace = true } +tracing = { workspace = true } +tracing-subscriber = { workspace = true, features = ["env-filter"] } + + +[lints] +workspace = true diff --git a/examples/auth/server-oauth-remote/README.md b/examples/auth/server-oauth-remote/README.md new file mode 100644 index 0000000..e8cf59b --- /dev/null +++ b/examples/auth/server-oauth-remote/README.md @@ -0,0 +1,63 @@ +# MCP Server - Remote Oauth (Keycloack) + + +A minimal, MCP server example that demonstrates **OAuth 2.0 / OpenID Connect authentication** using the `RemoteAuthProvider` from [rust-mcp-sdk](https://github.com/rust-mcp-stack/rust-mcp-sdk). + +It features: +- Full OAuth 2.0 protection via bearer tokens +- Remote authentication metadata discovery +- Token verification using both JWKs and token introspection +- A single tool: `show_auth_info` - returns the authenticated user's claims and scopes in pretty-printed JSON + +## Overview + +**RemoteAuthProvider** can be used with any OpenID Connect provider that supports Dynamic Client Registration (DCR), but in this example, it is configured to point to a local [Keycloak](https://www.keycloak.org) instance. + +👉 For more information on how to start and configure your local Keycloak server, please refer to the **keycloak-setup** section of the following blog post: https://modelcontextprotocol.io/docs/tutorials/security/authorization#keycloak-setup + + +## Running the Example + + +### Step 1: +Clone the repo: +```bash +git clone git@github.com:rust-mcp-stack/rust-mcp-sdk.git +cd rust-mcp-sdk +``` + +### Step 2: +Make sure you have a Keycloak server running and configured as described in this [blog post](https://modelcontextprotocol.io/docs/tutorials/security/authorization#keycloak-setup) + +> 💡 _You can update the configuration in `create_oauth_provider()` function to connect to any other OAuth provider with DCR support or in case your keycloak configuration is different._ + +### Step 3: +Set the `OAUTH_CLIENT_ID` and `OAUTH_CLIENT_SECRET` environment variables with the values from your keycloak server dashboard: + +``` +export OAUTH_CLIENT_ID=test-server OAUTH_CLIENT_SECRET=XYZ +``` + + +### Step 3: +start the project + +```bash +cargo run -p server-oauth-remote +``` + +You will see: + +```sh +• Streamable HTTP Server is available at http://[::1]:3000/ +``` + +You can test it with [MCP Inspector](https://modelcontextprotocol.io/docs/tools/inspector), or alternatively, use it with any MCP client you prefer. + +```bash +npx -y @modelcontextprotocol/inspector@latest +``` + +Here you can see it in action : + +![mcp-server-remote-oauth](../../assets/examples/mcp-remote-oauth.gif) diff --git a/examples/auth/server-oauth-remote/src/handler.rs b/examples/auth/server-oauth-remote/src/handler.rs new file mode 100644 index 0000000..d137080 --- /dev/null +++ b/examples/auth/server-oauth-remote/src/handler.rs @@ -0,0 +1,71 @@ +use async_trait::async_trait; +use rust_mcp_sdk::auth::AuthInfo; +use rust_mcp_sdk::macros::{mcp_tool, JsonSchema}; +use rust_mcp_sdk::schema::TextContent; +use rust_mcp_sdk::schema::{ + schema_utils::CallToolError, CallToolRequest, CallToolResult, ListToolsRequest, + ListToolsResult, RpcError, +}; +use rust_mcp_sdk::{mcp_server::ServerHandler, McpServer}; +use std::sync::Arc; +use std::vec; + +//*******************************// +// Show Authentication Info // +//*******************************// +#[mcp_tool( + name = "show_auth_info", + description = "Shows current user authentication info in json format" +)] +#[derive(Debug, ::serde::Deserialize, ::serde::Serialize, JsonSchema, Default)] +pub struct ShowAuthInfo {} +impl ShowAuthInfo { + pub fn call_tool(&self, auth_info: Option) -> Result { + let auth_info_json = serde_json::to_string_pretty(&auth_info).map_err(|err| { + CallToolError::from_message(format!("Undable to display auth info as string :{err}")) + })?; + Ok(CallToolResult::text_content(vec![TextContent::from( + auth_info_json, + )])) + } +} + +// Custom Handler to handle MCP Messages +pub struct McpServerHandler; + +// To check out a list of all the methods in the trait that you can override, take a look at +// https://github.com/rust-mcp-stack/rust-mcp-sdk/blob/main/crates/rust-mcp-sdk/src/mcp_handlers/mcp_server_handler.rs + +#[async_trait] +#[allow(unused)] +impl ServerHandler for McpServerHandler { + // Handle ListToolsRequest, return list of available tools as ListToolsResult + async fn handle_list_tools_request( + &self, + request: ListToolsRequest, + runtime: Arc, + ) -> std::result::Result { + Ok(ListToolsResult { + meta: None, + next_cursor: None, + tools: vec![ShowAuthInfo::tool()], + }) + } + + /// Handles incoming CallToolRequest and processes it using the appropriate tool. + async fn handle_call_tool_request( + &self, + request: CallToolRequest, + runtime: Arc, + ) -> std::result::Result { + if request.params.name.eq(&ShowAuthInfo::tool_name()) { + let tool = ShowAuthInfo::default(); + tool.call_tool(runtime.auth_info_cloned().await) + } else { + Err(CallToolError::from_message(format!( + "Tool \"{}\" does not exists or inactive!", + request.params.name, + ))) + } + } +} diff --git a/examples/auth/server-oauth-remote/src/main.rs b/examples/auth/server-oauth-remote/src/main.rs new file mode 100644 index 0000000..e1d442e --- /dev/null +++ b/examples/auth/server-oauth-remote/src/main.rs @@ -0,0 +1,132 @@ +mod handler; + +use crate::handler::McpServerHandler; +use rust_mcp_extra::token_verifier::{ + GenericOauthTokenVerifier, TokenVerifierOptions, VerificationStrategies, +}; +use rust_mcp_sdk::auth::{AuthMetadataBuilder, RemoteAuthProvider}; +use rust_mcp_sdk::error::SdkResult; +use rust_mcp_sdk::event_store::InMemoryEventStore; +use rust_mcp_sdk::mcp_server::{hyper_server, HyperServerOptions}; +use rust_mcp_sdk::schema::{ + Implementation, InitializeResult, ServerCapabilities, ServerCapabilitiesTools, + LATEST_PROTOCOL_VERSION, +}; +use std::env; +use std::sync::Arc; +use std::time::Duration; +use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; + +// this function creates and setup a RemoteAuthProvider , pointing to a local KeyCloak server +// please refer to the keycloak-setup section of the following blog post for +// detailed instructions on how to setup a KeyCloak server for this : +// https://modelcontextprotocol.io/docs/tutorials/security/authorization#keycloak-setup +pub async fn create_oauth_provider() -> SdkResult { + // build metadata from a oauth discovery url : .well-known/openid-configuration + let (auth_server_meta, protected_resource_meta) = AuthMetadataBuilder::from_discovery_url( + "http://localhost:8080/realms/master/.well-known/openid-configuration", + "http://localhost:3000", //mcp server url + vec!["mcp:tools", "phone"], + ) + .await? + .resource_name("MCP Server with Remote Oauth") + .build()?; + + // Alternatively, build metadata manually: + // let (auth_server_meta, protected_resource_meta) = + // AuthMetadataBuilder::new("http://localhost:3000") + // .issuer("http://localhost:8080/realms/master") + // .authorization_endpoint("/protocol/openid-connect/auth") + // .token_endpoint("/protocol/openid-connect/token") + // .jwks_uri("/protocol/openid-connect/certs") + // .introspection_endpoint("/protocol/openid-connect/token/introspect") + // .authorization_servers(vec!["http://localhost:8080/realms/master"]) + // .scopes_supported(vec!["mcp:tools", "phone"]) + // .resource_name("MCP Server with Remote Oauth") + // .build()?; + + // create a token verifier with Jwks and Introspection strategies + // GenericOauthTokenVerifier is used from rust-mcp-extra crate + // you can implement yours by implementing the OauthTokenVerifier trait + let token_verifier = GenericOauthTokenVerifier::new(TokenVerifierOptions { + validate_audience: None, + validate_issuer: Some(auth_server_meta.issuer.to_string()), + strategies: vec![ + VerificationStrategies::JWKs { + jwks_uri: auth_server_meta.jwks_uri.as_ref().unwrap().to_string(), + }, + VerificationStrategies::Introspection { + introspection_uri: auth_server_meta + .introspection_endpoint + .as_ref() + .unwrap() + .to_string(), + client_id: env::var("OAUTH_CLIENT_ID") + .expect("Please set the 'OAUTH_CLIENT_ID' environment variable!"), + client_secret: env::var("OAUTH_CLIENT_SECRET") + .expect("Please set the 'OAUTH_CLIENT_SECRET' environment variable!"), + use_basic_auth: true, + extra_params: None, + }, + ], + cache_capacity: Some(15), + }) + .unwrap(); + + Ok(RemoteAuthProvider::new( + auth_server_meta, + protected_resource_meta, + Box::new(token_verifier), + Some(vec!["mcp:tools".to_string()]), + )) +} + +#[tokio::main] +async fn main() -> SdkResult<()> { + // initialize tracing + tracing_subscriber::registry() + .with( + tracing_subscriber::EnvFilter::try_from_default_env().unwrap_or_else(|_| "info".into()), + ) + .with(tracing_subscriber::fmt::layer()) + .init(); + + let server_details = InitializeResult { + // server name and version + server_info: Implementation { + name: "Remote Oauth Test MCP Server".to_string(), + version: "0.1.0".to_string(), + title: Some("Remote Oauth Test MCP Server".to_string()), + }, + capabilities: ServerCapabilities { + // indicates that server support mcp tools + tools: Some(ServerCapabilitiesTools { list_changed: None }), + ..Default::default() // Using default values for other fields + }, + meta: None, + instructions: Some("server instructions...".to_string()), + protocol_version: LATEST_PROTOCOL_VERSION.to_string(), + }; + + let handler = McpServerHandler {}; + + let oauth_metadata_provider = create_oauth_provider().await?; + + let server = hyper_server::create_server( + server_details, + handler, + HyperServerOptions { + host: "localhost".to_string(), + port: 3000, + custom_streamable_http_endpoint: Some("/".to_string()), + ping_interval: Duration::from_secs(5), + event_store: Some(Arc::new(InMemoryEventStore::default())), // enable resumability + auth: Some(Arc::new(oauth_metadata_provider)), // enable authentication + ..Default::default() + }, + ); + + server.start().await?; + + Ok(()) +}