From 219b71f4746f2ade6059e6e2c9a7ac12edb32e88 Mon Sep 17 00:00:00 2001 From: Sacha <55644767+SoldierSacha@users.noreply.github.com> Date: Thu, 13 Nov 2025 21:10:20 -0500 Subject: [PATCH] Fix OAuth discovery fallbacks --- src/mcp/client/auth/oauth2.py | 99 ++++++++++++++++++++--------------- tests/client/test_auth.py | 4 +- 2 files changed, 60 insertions(+), 43 deletions(-) diff --git a/src/mcp/client/auth/oauth2.py b/src/mcp/client/auth/oauth2.py index a43c113b2..46d021549 100644 --- a/src/mcp/client/auth/oauth2.py +++ b/src/mcp/client/auth/oauth2.py @@ -251,7 +251,9 @@ def _create_registration_request(self, metadata: OAuthMetadata | None = None) -> headers={"Content-Type": "application/json"}, ) - async def _handle_registration_response(self, response: httpx.Response) -> None: + async def _handle_registration_response( + self, response: httpx.Response + ) -> OAuthClientInformationFull: if response.status_code not in (200, 201): await response.aread() raise OAuthRegistrationError(f"Registration failed: {response.status_code} {response.text}") @@ -259,6 +261,10 @@ async def _handle_registration_response(self, response: httpx.Response) -> None: client_info = OAuthClientInformationFull.model_validate_json(content) self._client_info = client_info await self.storage.set_client_info(client_info) + context = getattr(self, "context", None) + if context is not None: + context.client_info = client_info + return client_info def _apply_client_auth( self, @@ -315,6 +321,18 @@ def __init__( ) self._initialized = False + def _build_protected_resource_discovery_urls(self, resource_metadata_url: str | None) -> list[str]: + """Build the list of PRM discovery URLs with legacy fallbacks.""" + return build_protected_resource_metadata_discovery_urls( + resource_metadata_url, self.context.server_url + ) + + def _get_discovery_urls(self, server_url: str | None = None) -> list[str]: + """Build OAuth authorization server discovery URLs with legacy fallbacks.""" + return build_oauth_authorization_server_metadata_discovery_urls( + server_url, self.context.server_url + ) + async def _handle_protected_resource_response(self, response: httpx.Response) -> bool: """ Handle protected resource metadata discovery response. @@ -324,28 +342,30 @@ async def _handle_protected_resource_response(self, response: httpx.Response) -> Returns: True if metadata was successfully discovered, False if we should try next URL """ - if response.status_code == 200: - try: - content = await response.aread() - metadata = ProtectedResourceMetadata.model_validate_json(content) - self.context.protected_resource_metadata = metadata - if metadata.authorization_servers: # pragma: no branch - self.context.auth_server_url = str(metadata.authorization_servers[0]) - return True - - except ValidationError: # pragma: no cover - # Invalid metadata - try next URL - logger.warning(f"Invalid protected resource metadata at {response.request.url}") - return False - elif response.status_code == 404: # pragma: no cover - # Not found - try next URL in fallback chain - logger.debug(f"Protected resource metadata not found at {response.request.url}, trying next URL") - return False - else: - # Other error - fail immediately - raise OAuthFlowError( - f"Protected Resource Metadata request failed: {response.status_code}" - ) # pragma: no cover + metadata = await handle_protected_resource_response(response) + if metadata: + self.context.protected_resource_metadata = metadata + if metadata.authorization_servers: # pragma: no branch + self.context.auth_server_url = str(metadata.authorization_servers[0]) + return True + + logger.debug( + "Protected resource metadata discovery failed with status %s at %s", + response.status_code, + response.request.url, + ) + return False + + async def _handle_oauth_metadata_response( + self, response: httpx.Response + ) -> tuple[bool, OAuthMetadata | None]: + ok, asm = await handle_auth_metadata_response(response) + if asm: + self.context.oauth_metadata = asm + self._metadata = asm + if self.context.client_metadata.scope is None and asm.scopes_supported is not None: + self.context.client_metadata.scope = " ".join(asm.scopes_supported) + return ok, asm async def _perform_authorization(self) -> httpx.Request: """Perform the authorization flow.""" @@ -560,34 +580,33 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx. self._metadata = None # Step 1: Discover protected resource metadata (SEP-985 with fallback support) - prm_discovery_urls = build_protected_resource_metadata_discovery_urls( - www_auth_resource_metadata_url, self.context.server_url + prm_discovery_urls = self._build_protected_resource_discovery_urls( + www_auth_resource_metadata_url ) for url in prm_discovery_urls: # pragma: no branch - discovery_request = create_oauth_metadata_request(url) + discovery_request = self._create_oauth_metadata_request(url) discovery_response = yield discovery_request - prm = await handle_protected_resource_response(discovery_response) - if prm: - self.context.protected_resource_metadata = prm - if prm.authorization_servers: # pragma: no branch - self.context.auth_server_url = str(prm.authorization_servers[0]) + handled = await self._handle_protected_resource_response(discovery_response) + if handled: break - logger.debug(f"Protected resource metadata discovery failed: {url}") - # Step 2: Discover OAuth Authorization Server Metadata (OASM) (with fallback for legacy servers) - asm_discovery_urls = build_oauth_authorization_server_metadata_discovery_urls( - self.context.auth_server_url, self.context.server_url - ) + asm_discovery_urls = self._get_discovery_urls(self.context.auth_server_url) authorization_metadata: OAuthMetadata | None = None for url in asm_discovery_urls: # pragma: no branch - oauth_metadata_request = create_oauth_metadata_request(url) + oauth_metadata_request = self._create_oauth_metadata_request(url) oauth_metadata_response = yield oauth_metadata_request - ok, asm = await handle_auth_metadata_response(oauth_metadata_response) + result = await self._handle_oauth_metadata_response(oauth_metadata_response) + if isinstance(result, tuple): + ok, asm = result + else: + ok = bool(result) if result is not None else True + asm = self.context.oauth_metadata or self._metadata + if not ok: break if asm: @@ -615,9 +634,7 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx. self.context.get_authorization_base_url(self.context.server_url), ) registration_response = yield registration_request - client_information = await handle_registration_response(registration_response) - self.context.client_info = client_information - await self.context.storage.set_client_info(client_information) + await self._handle_registration_response(registration_response) # Step 5: Perform authorization and complete token exchange token_response = yield await self._perform_authorization() diff --git a/tests/client/test_auth.py b/tests/client/test_auth.py index bee725a37..df9dcba8d 100644 --- a/tests/client/test_auth.py +++ b/tests/client/test_auth.py @@ -1364,7 +1364,7 @@ async def callback_handler() -> tuple[str, str | None]: ) # Mock authorization - provider._perform_authorization_code_grant = mock.AsyncMock( + provider._perform_authorization_code_grant = AsyncMock( return_value=("test_auth_code", "test_code_verifier") ) @@ -1470,7 +1470,7 @@ async def callback_handler() -> tuple[str, str | None]: request=oauth_metadata_request, ) - provider._perform_authorization_code_grant = mock.AsyncMock( + provider._perform_authorization_code_grant = AsyncMock( return_value=("test_auth_code", "test_code_verifier") )