Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 0 additions & 98 deletions tests/unit/client/test_oauth2_providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -829,104 +829,6 @@ async def post(self, url: str, *, data: dict[str, str], headers: dict[str, str])
assert provider.client_metadata.scope is None


@pytest.mark.anyio
async def test_token_exchange_request_token_stops_on_non_authoritative_response(
monkeypatch: pytest.MonkeyPatch,
) -> None:
storage = InMemoryStorage()
client_metadata = OAuthClientMetadata(redirect_uris=_redirect_uris(), scope="alpha")

provider = TokenExchangeProvider(
"https://api.example.com/service",
client_metadata,
storage,
subject_token_supplier=AsyncMock(return_value="subject-token"),
)

metadata_responses = [
_make_response(204),
_make_response(200, json_data=_metadata_json()),
]
registration_response = _make_response(200, json_data=_registration_json())
token_response = _make_response(200, json_data=_token_json("alpha"))

class RecordingAsyncClient(DummyAsyncClient):
def __init__(self, *args: object, **kwargs: object) -> None:
super().__init__(*args, **kwargs)
self.send_calls = 0

async def send(self, request: httpx.Request) -> httpx.Response:
self.send_calls += 1
return await super().send(request)

recording_client = RecordingAsyncClient(send_responses=list(metadata_responses))
clients = [
recording_client,
DummyAsyncClient(send_responses=[registration_response]),
DummyAsyncClient(post_responses=[token_response]),
]

monkeypatch.setattr("mcp.client.auth.oauth2.httpx.AsyncClient", AsyncClientFactory(clients))

await provider._request_token()

assert recording_client.send_calls == 1
assert storage.tokens is not None
assert storage.tokens.scope == "alpha"
assert provider._metadata is None


@pytest.mark.anyio
async def test_token_exchange_request_token_retries_after_invalid_metadata(
monkeypatch: pytest.MonkeyPatch,
) -> None:
storage = InMemoryStorage()
client_metadata = OAuthClientMetadata(redirect_uris=_redirect_uris(), scope="alpha")

provider = TokenExchangeProvider(
"https://api.example.com/service",
client_metadata,
storage,
subject_token_supplier=AsyncMock(return_value="subject-token"),
)

invalid_metadata = _metadata_json()
invalid_metadata.pop("token_endpoint")

metadata_responses = [
_make_response(200, json_data=invalid_metadata),
_make_response(200, json_data=_metadata_json()),
]
registration_response = _make_response(200, json_data=_registration_json())
token_response = _make_response(200, json_data=_token_json("alpha"))

class RecordingAsyncClient(DummyAsyncClient):
def __init__(self, *args: object, **kwargs: object) -> None:
super().__init__(*args, **kwargs)
self.send_calls = 0

async def send(self, request: httpx.Request) -> httpx.Response:
self.send_calls += 1
return await super().send(request)

recording_client = RecordingAsyncClient(send_responses=list(metadata_responses))
clients = [
recording_client,
DummyAsyncClient(send_responses=[registration_response]),
DummyAsyncClient(post_responses=[token_response]),
]

monkeypatch.setattr("mcp.client.auth.oauth2.httpx.AsyncClient", AsyncClientFactory(clients))

await provider._request_token()

assert recording_client.send_calls == 2
assert storage.tokens is not None
assert storage.tokens.scope == "alpha"
assert provider._metadata is not None
assert str(provider._metadata.token_endpoint) == _metadata_json()["token_endpoint"]


@pytest.mark.anyio
async def test_token_exchange_request_token_raises_on_failure(monkeypatch: pytest.MonkeyPatch) -> None:
storage = InMemoryStorage()
Expand Down
48 changes: 1 addition & 47 deletions tests/unit/server/auth/test_token_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import json
import time
from collections.abc import Mapping
from types import MethodType, SimpleNamespace
from types import SimpleNamespace
from typing import Any, cast

import pytest
Expand All @@ -12,7 +12,6 @@
from mcp.server.auth.handlers.token import (
AuthorizationCodeRequest,
ClientCredentialsRequest,
RefreshTokenRequest,
TokenErrorResponse,
TokenHandler,
TokenSuccessResponse,
Expand Down Expand Up @@ -288,51 +287,6 @@ async def test_handle_route_refresh_token_invalid_scope() -> None:
}


@pytest.mark.anyio
async def test_handle_route_refresh_token_dispatches_to_handler(
monkeypatch: pytest.MonkeyPatch,
) -> None:
provider = RefreshTokenProvider()
client_info = OAuthClientInformationFull(
client_id="client",
grant_types=["refresh_token"],
scope="alpha",
)
handler = TokenHandler(
provider=cast(OAuthAuthorizationServerProvider[Any, Any, Any], provider),
client_authenticator=cast(ClientAuthenticator, DummyAuthenticator(client_info)),
)

captured_requests: list[RefreshTokenRequest] = []

async def fake_handle_refresh_token(
self: TokenHandler,
client: OAuthClientInformationFull,
token_request: RefreshTokenRequest,
) -> TokenSuccessResponse:
captured_requests.append(token_request)
return TokenSuccessResponse(root=OAuthToken(access_token="dispatched-token"))

monkeypatch.setattr(
handler,
"_handle_refresh_token",
MethodType(fake_handle_refresh_token, handler),
)

request_data = {
"grant_type": "refresh_token",
"refresh_token": "refresh-token",
"client_id": "client",
"client_secret": "secret",
}

response = await handler.handle(cast(Request, DummyRequest(request_data)))

assert response.status_code == 200
assert captured_requests
assert isinstance(captured_requests[0], RefreshTokenRequest)


@pytest.mark.anyio
async def test_handle_route_token_exchange_branch() -> None:
provider = TokenExchangeProviderStub()
Expand Down