From 5db0e7e988b6ec0d23ac1fcaefc7627724858aef Mon Sep 17 00:00:00 2001 From: Sacha <55644767+SoldierSacha@users.noreply.github.com> Date: Tue, 11 Nov 2025 23:40:53 -0500 Subject: [PATCH] Test retry after invalid OAuth metadata --- tests/unit/client/test_oauth2_providers.py | 98 ++++++++++++++++++++ tests/unit/server/auth/test_token_handler.py | 48 +++++++++- 2 files changed, 145 insertions(+), 1 deletion(-) diff --git a/tests/unit/client/test_oauth2_providers.py b/tests/unit/client/test_oauth2_providers.py index 41025e509..a0e2e1ac9 100644 --- a/tests/unit/client/test_oauth2_providers.py +++ b/tests/unit/client/test_oauth2_providers.py @@ -829,6 +829,104 @@ async def post(self, url: str, *, data: dict[str, str], headers: dict[str, str]) assert provider.client_metadata.scope is None +@pytest.mark.anyio +async def test_token_exchange_request_token_stops_on_non_authoritative_response( + monkeypatch: pytest.MonkeyPatch, +) -> None: + storage = InMemoryStorage() + client_metadata = OAuthClientMetadata(redirect_uris=_redirect_uris(), scope="alpha") + + provider = TokenExchangeProvider( + "https://api.example.com/service", + client_metadata, + storage, + subject_token_supplier=AsyncMock(return_value="subject-token"), + ) + + metadata_responses = [ + _make_response(204), + _make_response(200, json_data=_metadata_json()), + ] + registration_response = _make_response(200, json_data=_registration_json()) + token_response = _make_response(200, json_data=_token_json("alpha")) + + class RecordingAsyncClient(DummyAsyncClient): + def __init__(self, *args: object, **kwargs: object) -> None: + super().__init__(*args, **kwargs) + self.send_calls = 0 + + async def send(self, request: httpx.Request) -> httpx.Response: + self.send_calls += 1 + return await super().send(request) + + recording_client = RecordingAsyncClient(send_responses=list(metadata_responses)) + clients = [ + recording_client, + DummyAsyncClient(send_responses=[registration_response]), + DummyAsyncClient(post_responses=[token_response]), + ] + + monkeypatch.setattr("mcp.client.auth.oauth2.httpx.AsyncClient", AsyncClientFactory(clients)) + + await provider._request_token() + + assert recording_client.send_calls == 1 + assert storage.tokens is not None + assert storage.tokens.scope == "alpha" + assert provider._metadata is None + + +@pytest.mark.anyio +async def test_token_exchange_request_token_retries_after_invalid_metadata( + monkeypatch: pytest.MonkeyPatch, +) -> None: + storage = InMemoryStorage() + client_metadata = OAuthClientMetadata(redirect_uris=_redirect_uris(), scope="alpha") + + provider = TokenExchangeProvider( + "https://api.example.com/service", + client_metadata, + storage, + subject_token_supplier=AsyncMock(return_value="subject-token"), + ) + + invalid_metadata = _metadata_json() + invalid_metadata.pop("token_endpoint") + + metadata_responses = [ + _make_response(200, json_data=invalid_metadata), + _make_response(200, json_data=_metadata_json()), + ] + registration_response = _make_response(200, json_data=_registration_json()) + token_response = _make_response(200, json_data=_token_json("alpha")) + + class RecordingAsyncClient(DummyAsyncClient): + def __init__(self, *args: object, **kwargs: object) -> None: + super().__init__(*args, **kwargs) + self.send_calls = 0 + + async def send(self, request: httpx.Request) -> httpx.Response: + self.send_calls += 1 + return await super().send(request) + + recording_client = RecordingAsyncClient(send_responses=list(metadata_responses)) + clients = [ + recording_client, + DummyAsyncClient(send_responses=[registration_response]), + DummyAsyncClient(post_responses=[token_response]), + ] + + monkeypatch.setattr("mcp.client.auth.oauth2.httpx.AsyncClient", AsyncClientFactory(clients)) + + await provider._request_token() + + assert recording_client.send_calls == 2 + assert storage.tokens is not None + assert storage.tokens.scope == "alpha" + assert provider._metadata is not None + assert str(provider._metadata.token_endpoint) == _metadata_json()["token_endpoint"] + + @pytest.mark.anyio async def test_token_exchange_request_token_raises_on_failure(monkeypatch: pytest.MonkeyPatch) -> None: storage = InMemoryStorage() diff --git a/tests/unit/server/auth/test_token_handler.py b/tests/unit/server/auth/test_token_handler.py index 56e41ae94..2ebef7395 100644 --- a/tests/unit/server/auth/test_token_handler.py +++ b/tests/unit/server/auth/test_token_handler.py @@ -3,7 +3,7 @@ import json import time from collections.abc import Mapping -from types import SimpleNamespace +from types import MethodType, SimpleNamespace from typing import Any, cast import pytest @@ -12,6 +12,7 @@ from mcp.server.auth.handlers.token import ( AuthorizationCodeRequest, ClientCredentialsRequest, + RefreshTokenRequest, TokenErrorResponse, TokenHandler, TokenSuccessResponse, @@ -287,6 +288,51 @@ async def test_handle_route_refresh_token_invalid_scope() -> None: } +@pytest.mark.anyio +async def test_handle_route_refresh_token_dispatches_to_handler( + monkeypatch: pytest.MonkeyPatch, +) -> None: + provider = RefreshTokenProvider() + client_info = OAuthClientInformationFull( + client_id="client", + grant_types=["refresh_token"], + scope="alpha", + ) + handler = TokenHandler( + provider=cast(OAuthAuthorizationServerProvider[Any, Any, Any], provider), + client_authenticator=cast(ClientAuthenticator, DummyAuthenticator(client_info)), + ) + + captured_requests: list[RefreshTokenRequest] = [] + + async def fake_handle_refresh_token( + self: TokenHandler, + client: OAuthClientInformationFull, + token_request: RefreshTokenRequest, + ) -> TokenSuccessResponse: + captured_requests.append(token_request) + return TokenSuccessResponse(root=OAuthToken(access_token="dispatched-token")) + + monkeypatch.setattr( + handler, + "_handle_refresh_token", + MethodType(fake_handle_refresh_token, handler), + ) + + request_data = { + "grant_type": "refresh_token", + "refresh_token": "refresh-token", + "client_id": "client", + "client_secret": "secret", + } + + response = await handler.handle(cast(Request, DummyRequest(request_data))) + + assert response.status_code == 200 + assert captured_requests + assert isinstance(captured_requests[0], RefreshTokenRequest) + + @pytest.mark.anyio async def test_handle_route_token_exchange_branch() -> None: provider = TokenExchangeProviderStub()