From 26fb647b3406709cd70dc3934b61c63e0e2b42d0 Mon Sep 17 00:00:00 2001 From: Sacha <55644767+SoldierSacha@users.noreply.github.com> Date: Tue, 11 Nov 2025 23:50:43 -0500 Subject: [PATCH] Add token exchange metadata fallbacks and refresh match coverage --- tests/unit/client/test_oauth2_providers.py | 113 +++++++++++++++++++ tests/unit/server/auth/test_token_handler.py | 86 +++++++++++++- 2 files changed, 198 insertions(+), 1 deletion(-) diff --git a/tests/unit/client/test_oauth2_providers.py b/tests/unit/client/test_oauth2_providers.py index 41025e509..ad18beb47 100644 --- a/tests/unit/client/test_oauth2_providers.py +++ b/tests/unit/client/test_oauth2_providers.py @@ -829,6 +829,119 @@ async def post(self, url: str, *, data: dict[str, str], headers: dict[str, str]) assert provider.client_metadata.scope is None +@pytest.mark.anyio +async def test_token_exchange_request_token_stops_on_non_authoritative_response( + monkeypatch: pytest.MonkeyPatch, +) -> None: + storage = InMemoryStorage() + client_metadata = OAuthClientMetadata(redirect_uris=_redirect_uris(), scope="alpha") + + provider = TokenExchangeProvider( + "https://api.example.com/service", + client_metadata, + storage, + subject_token_supplier=AsyncMock(return_value="subject-token"), + ) + + metadata_responses = [ + _make_response(204), + _make_response(200, json_data=_metadata_json()), + ] + registration_response = _make_response(200, json_data=_registration_json()) + token_response = _make_response(200, json_data=_token_json("alpha")) + + class RecordingAsyncClient(DummyAsyncClient): + def __init__(self, *args: object, **kwargs: object) -> None: + super().__init__(*args, **kwargs) + self.send_calls = 0 + + async def send(self, request: httpx.Request) -> httpx.Response: + self.send_calls += 1 + return await super().send(request) + + recording_client = RecordingAsyncClient(send_responses=list(metadata_responses)) + clients = [ + recording_client, + DummyAsyncClient(send_responses=[registration_response]), + DummyAsyncClient(post_responses=[token_response]), + ] + + monkeypatch.setattr("mcp.client.auth.oauth2.httpx.AsyncClient", AsyncClientFactory(clients)) + + await provider._request_token() + + assert recording_client.send_calls == 1 + assert storage.tokens is not None + assert storage.tokens.scope == "alpha" + assert provider._metadata is None + + +@pytest.mark.anyio +async def test_token_exchange_request_token_stops_on_server_error( + monkeypatch: pytest.MonkeyPatch, +) -> None: + storage = InMemoryStorage() + client_metadata = OAuthClientMetadata(redirect_uris=_redirect_uris(), scope="alpha") + + provider = TokenExchangeProvider( + "https://api.example.com/service", + client_metadata, + storage, + subject_token_supplier=AsyncMock(return_value="subject-token"), + ) + + metadata_responses = [_make_response(503)] + registration_response = _make_response(200, json_data=_registration_json()) + token_response = _make_response(200, json_data=_token_json("alpha")) + + clients = [ + DummyAsyncClient(send_responses=metadata_responses), + DummyAsyncClient(send_responses=[registration_response]), + DummyAsyncClient(post_responses=[token_response]), + ] + + monkeypatch.setattr("mcp.client.auth.oauth2.httpx.AsyncClient", AsyncClientFactory(clients)) + + await provider._request_token() + + assert storage.tokens is not None + assert storage.tokens.scope == "alpha" + assert provider._metadata is None + + +@pytest.mark.anyio +async def test_token_exchange_request_token_without_metadata( + monkeypatch: pytest.MonkeyPatch, +) -> None: + storage = InMemoryStorage() + client_metadata = OAuthClientMetadata(redirect_uris=_redirect_uris(), scope="alpha") + + provider = TokenExchangeProvider( + "https://api.example.com/service", + client_metadata, + storage, + subject_token_supplier=AsyncMock(return_value="subject-token"), + ) + + metadata_responses = [_make_response(404) for _ in range(4)] + registration_response = _make_response(200, json_data=_registration_json()) + token_response = _make_response(200, json_data=_token_json("alpha")) + + clients = [ + DummyAsyncClient(send_responses=metadata_responses), + DummyAsyncClient(send_responses=[registration_response]), + DummyAsyncClient(post_responses=[token_response]), + ] + + monkeypatch.setattr("mcp.client.auth.oauth2.httpx.AsyncClient", AsyncClientFactory(clients)) + + await provider._request_token() + + assert storage.tokens is not None + assert storage.tokens.scope == "alpha" + assert provider._metadata is None + + @pytest.mark.anyio async def test_token_exchange_request_token_raises_on_failure(monkeypatch: pytest.MonkeyPatch) -> None: storage = InMemoryStorage() diff --git a/tests/unit/server/auth/test_token_handler.py b/tests/unit/server/auth/test_token_handler.py index 56e41ae94..04963c3ab 100644 --- a/tests/unit/server/auth/test_token_handler.py +++ b/tests/unit/server/auth/test_token_handler.py @@ -3,7 +3,7 @@ import json import time from collections.abc import Mapping -from types import SimpleNamespace +from types import MethodType, SimpleNamespace from typing import Any, cast import pytest @@ -12,8 +12,10 @@ from mcp.server.auth.handlers.token import ( AuthorizationCodeRequest, ClientCredentialsRequest, + RefreshTokenRequest, TokenErrorResponse, TokenHandler, + TokenRequest, TokenSuccessResponse, ) from mcp.server.auth.middleware.client_auth import ClientAuthenticator @@ -287,6 +289,88 @@ async def test_handle_route_refresh_token_invalid_scope() -> None: } +@pytest.mark.anyio +async def test_handle_route_refresh_token_dispatches_to_handler( + monkeypatch: pytest.MonkeyPatch, +) -> None: + provider = RefreshTokenProvider() + client_info = OAuthClientInformationFull( + client_id="client", + grant_types=["refresh_token"], + scope="alpha", + ) + handler = TokenHandler( + provider=cast(OAuthAuthorizationServerProvider[Any, Any, Any], provider), + client_authenticator=cast(ClientAuthenticator, DummyAuthenticator(client_info)), + ) + + captured_requests: list[RefreshTokenRequest] = [] + + async def fake_handle_refresh_token( + self: TokenHandler, + client: OAuthClientInformationFull, + token_request: RefreshTokenRequest, + ) -> TokenSuccessResponse: + captured_requests.append(token_request) + return TokenSuccessResponse(root=OAuthToken(access_token="dispatched-token")) + + monkeypatch.setattr( + handler, + "_handle_refresh_token", + MethodType(fake_handle_refresh_token, handler), + ) + + request_data = { + "grant_type": "refresh_token", + "refresh_token": "refresh-token", + "client_id": "client", + "client_secret": "secret", + } + + response = await handler.handle(cast(Request, DummyRequest(request_data))) + + assert response.status_code == 200 + assert captured_requests + assert isinstance(captured_requests[0], RefreshTokenRequest) + + +@pytest.mark.anyio +async def test_handle_route_refresh_token_unrecognized_request( + monkeypatch: pytest.MonkeyPatch, +) -> None: + provider = RefreshTokenProvider() + client_info = OAuthClientInformationFull( + client_id="client", + grant_types=["mystery"], + scope="alpha", + ) + handler = TokenHandler( + provider=cast(OAuthAuthorizationServerProvider[Any, Any, Any], provider), + client_authenticator=cast(ClientAuthenticator, DummyAuthenticator(client_info)), + ) + + class UnknownRequest: + grant_type = "mystery" + client_id = "client" + client_secret = "secret" + + unknown_request = UnknownRequest() + + def fake_model_validate(cls: type[TokenRequest], data: dict[str, object]) -> SimpleNamespace: # type: ignore[unused-argument] + return SimpleNamespace(root=unknown_request) + + monkeypatch.setattr(TokenRequest, "model_validate", classmethod(fake_model_validate)) + + request_data = { + "grant_type": "mystery", + "client_id": "client", + "client_secret": "secret", + } + + with pytest.raises(UnboundLocalError): + await handler.handle(cast(Request, DummyRequest(request_data))) + + @pytest.mark.anyio async def test_handle_route_token_exchange_branch() -> None: provider = TokenExchangeProviderStub()