From 72cca2c7c77de24822aa7d9d7e8a2c8bd38c5d3a Mon Sep 17 00:00:00 2001 From: Sacha <55644767+SoldierSacha@users.noreply.github.com> Date: Tue, 11 Nov 2025 23:48:58 -0500 Subject: [PATCH] Revert "Add branch coverage tests for OAuth metadata and refresh handling" --- tests/unit/client/test_oauth2_providers.py | 98 -------------------- tests/unit/server/auth/test_token_handler.py | 48 +--------- 2 files changed, 1 insertion(+), 145 deletions(-) diff --git a/tests/unit/client/test_oauth2_providers.py b/tests/unit/client/test_oauth2_providers.py index a0e2e1ac9..41025e509 100644 --- a/tests/unit/client/test_oauth2_providers.py +++ b/tests/unit/client/test_oauth2_providers.py @@ -829,104 +829,6 @@ async def post(self, url: str, *, data: dict[str, str], headers: dict[str, str]) assert provider.client_metadata.scope is None -@pytest.mark.anyio -async def test_token_exchange_request_token_stops_on_non_authoritative_response( - monkeypatch: pytest.MonkeyPatch, -) -> None: - storage = InMemoryStorage() - client_metadata = OAuthClientMetadata(redirect_uris=_redirect_uris(), scope="alpha") - - provider = TokenExchangeProvider( - "https://api.example.com/service", - client_metadata, - storage, - subject_token_supplier=AsyncMock(return_value="subject-token"), - ) - - metadata_responses = [ - _make_response(204), - _make_response(200, json_data=_metadata_json()), - ] - registration_response = _make_response(200, json_data=_registration_json()) - token_response = _make_response(200, json_data=_token_json("alpha")) - - class RecordingAsyncClient(DummyAsyncClient): - def __init__(self, *args: object, **kwargs: object) -> None: - super().__init__(*args, **kwargs) - self.send_calls = 0 - - async def send(self, request: httpx.Request) -> httpx.Response: - self.send_calls += 1 - return await super().send(request) - - recording_client = RecordingAsyncClient(send_responses=list(metadata_responses)) - clients = [ - recording_client, - DummyAsyncClient(send_responses=[registration_response]), - DummyAsyncClient(post_responses=[token_response]), - ] - - monkeypatch.setattr("mcp.client.auth.oauth2.httpx.AsyncClient", AsyncClientFactory(clients)) - - await provider._request_token() - - assert recording_client.send_calls == 1 - assert storage.tokens is not None - assert storage.tokens.scope == "alpha" - assert provider._metadata is None - - -@pytest.mark.anyio -async def test_token_exchange_request_token_retries_after_invalid_metadata( - monkeypatch: pytest.MonkeyPatch, -) -> None: - storage = InMemoryStorage() - client_metadata = OAuthClientMetadata(redirect_uris=_redirect_uris(), scope="alpha") - - provider = TokenExchangeProvider( - "https://api.example.com/service", - client_metadata, - storage, - subject_token_supplier=AsyncMock(return_value="subject-token"), - ) - - invalid_metadata = _metadata_json() - invalid_metadata.pop("token_endpoint") - - metadata_responses = [ - _make_response(200, json_data=invalid_metadata), - _make_response(200, json_data=_metadata_json()), - ] - registration_response = _make_response(200, json_data=_registration_json()) - token_response = _make_response(200, json_data=_token_json("alpha")) - - class RecordingAsyncClient(DummyAsyncClient): - def __init__(self, *args: object, **kwargs: object) -> None: - super().__init__(*args, **kwargs) - self.send_calls = 0 - - async def send(self, request: httpx.Request) -> httpx.Response: - self.send_calls += 1 - return await super().send(request) - - recording_client = RecordingAsyncClient(send_responses=list(metadata_responses)) - clients = [ - recording_client, - DummyAsyncClient(send_responses=[registration_response]), - DummyAsyncClient(post_responses=[token_response]), - ] - - monkeypatch.setattr("mcp.client.auth.oauth2.httpx.AsyncClient", AsyncClientFactory(clients)) - - await provider._request_token() - - assert recording_client.send_calls == 2 - assert storage.tokens is not None - assert storage.tokens.scope == "alpha" - assert provider._metadata is not None - assert str(provider._metadata.token_endpoint) == _metadata_json()["token_endpoint"] - - @pytest.mark.anyio async def test_token_exchange_request_token_raises_on_failure(monkeypatch: pytest.MonkeyPatch) -> None: storage = InMemoryStorage() diff --git a/tests/unit/server/auth/test_token_handler.py b/tests/unit/server/auth/test_token_handler.py index 2ebef7395..56e41ae94 100644 --- a/tests/unit/server/auth/test_token_handler.py +++ b/tests/unit/server/auth/test_token_handler.py @@ -3,7 +3,7 @@ import json import time from collections.abc import Mapping -from types import MethodType, SimpleNamespace +from types import SimpleNamespace from typing import Any, cast import pytest @@ -12,7 +12,6 @@ from mcp.server.auth.handlers.token import ( AuthorizationCodeRequest, ClientCredentialsRequest, - RefreshTokenRequest, TokenErrorResponse, TokenHandler, TokenSuccessResponse, @@ -288,51 +287,6 @@ async def test_handle_route_refresh_token_invalid_scope() -> None: } -@pytest.mark.anyio -async def test_handle_route_refresh_token_dispatches_to_handler( - monkeypatch: pytest.MonkeyPatch, -) -> None: - provider = RefreshTokenProvider() - client_info = OAuthClientInformationFull( - client_id="client", - grant_types=["refresh_token"], - scope="alpha", - ) - handler = TokenHandler( - provider=cast(OAuthAuthorizationServerProvider[Any, Any, Any], provider), - client_authenticator=cast(ClientAuthenticator, DummyAuthenticator(client_info)), - ) - - captured_requests: list[RefreshTokenRequest] = [] - - async def fake_handle_refresh_token( - self: TokenHandler, - client: OAuthClientInformationFull, - token_request: RefreshTokenRequest, - ) -> TokenSuccessResponse: - captured_requests.append(token_request) - return TokenSuccessResponse(root=OAuthToken(access_token="dispatched-token")) - - monkeypatch.setattr( - handler, - "_handle_refresh_token", - MethodType(fake_handle_refresh_token, handler), - ) - - request_data = { - "grant_type": "refresh_token", - "refresh_token": "refresh-token", - "client_id": "client", - "client_secret": "secret", - } - - response = await handler.handle(cast(Request, DummyRequest(request_data))) - - assert response.status_code == 200 - assert captured_requests - assert isinstance(captured_requests[0], RefreshTokenRequest) - - @pytest.mark.anyio async def test_handle_route_token_exchange_branch() -> None: provider = TokenExchangeProviderStub()