From b0674ab0ad47e995e6e680140461afb8be14fd22 Mon Sep 17 00:00:00 2001 From: Sacha <55644767+SoldierSacha@users.noreply.github.com> Date: Thu, 13 Nov 2025 20:44:36 -0500 Subject: [PATCH] Fix merge conflicts in OAuth2 auth flow --- src/mcp/client/auth/oauth2.py | 121 +++++++++------------------------- tests/client/test_auth.py | 33 ++-------- 2 files changed, 37 insertions(+), 117 deletions(-) diff --git a/src/mcp/client/auth/oauth2.py b/src/mcp/client/auth/oauth2.py index 7c44222c1..a43c113b2 100644 --- a/src/mcp/client/auth/oauth2.py +++ b/src/mcp/client/auth/oauth2.py @@ -35,6 +35,7 @@ handle_token_response_scopes, ) from mcp.client.streamable_http import MCP_PROTOCOL_VERSION +from mcp.types import LATEST_PROTOCOL_VERSION from mcp.shared.auth import ( OAuthClientInformationFull, OAuthClientMetadata, @@ -341,35 +342,11 @@ async def _handle_protected_resource_response(self, response: httpx.Response) -> logger.debug(f"Protected resource metadata not found at {response.request.url}, trying next URL") return False else: -#<<<<<<< main - # Priority 3: Omit scope parameter - self.context.client_metadata.scope = None - - # Discovery and registration helpers provided by BaseOAuthProvider -#======= # Other error - fail immediately raise OAuthFlowError( f"Protected Resource Metadata request failed: {response.status_code}" ) # pragma: no cover - async def _register_client(self) -> httpx.Request | None: - """Build registration request or skip if already registered.""" - if self.context.client_info: - return None - - if self.context.oauth_metadata and self.context.oauth_metadata.registration_endpoint: - registration_url = str(self.context.oauth_metadata.registration_endpoint) # pragma: no cover - else: - auth_base_url = self.context.get_authorization_base_url(self.context.server_url) - registration_url = urljoin(auth_base_url, "/register") - - registration_data = self.context.client_metadata.model_dump(by_alias=True, mode="json", exclude_none=True) - - return httpx.Request( - "POST", registration_url, json=registration_data, headers={"Content-Type": "application/json"} - ) -#>>>>>>> main - async def _perform_authorization(self) -> httpx.Request: """Perform the authorization flow.""" auth_code, code_verifier = await self._perform_authorization_code_grant() @@ -473,21 +450,10 @@ async def _exchange_token_authorization_code( async def _handle_token_response(self, response: httpx.Response) -> None: """Handle token exchange response.""" -#<<<<<<< main - if response.status_code != 200: # pragma: no cover - body = response.content or await response.aread() - body = body.decode("utf-8") - raise OAuthTokenError(f"Token exchange failed ({response.status_code}): {body}") - - try: - content = response.content or await response.aread() - token_response = OAuthToken.model_validate_json(content) -#======= if response.status_code != 200: body = await response.aread() # pragma: no cover body_text = body.decode("utf-8") # pragma: no cover raise OAuthTokenError(f"Token exchange failed ({response.status_code}): {body_text}") # pragma: no cover -#>>>>>>> main # Parse and validate response with scope validation token_response = await handle_token_response_scopes(response) @@ -557,14 +523,6 @@ def _add_auth_header(self, request: httpx.Request) -> None: if self.context.current_tokens and self.context.current_tokens.access_token: # pragma: no branch request.headers["Authorization"] = f"Bearer {self.context.current_tokens.access_token}" -#<<<<<<< main -#======= - async def _handle_oauth_metadata_response(self, response: httpx.Response) -> None: - content = await response.aread() - metadata = OAuthMetadata.model_validate_json(content) - self.context.oauth_metadata = metadata - -#>>>>>>> main async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.Request, httpx.Response]: """HTTPX auth flow integration.""" async with self.context.lock: @@ -593,6 +551,13 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx. try: # OAuth flow must be inline due to generator constraints www_auth_resource_metadata_url = extract_resource_metadata_from_www_auth(response) + www_auth_scope = extract_scope_from_www_auth(response) + + # Reset discovery context before attempting new discovery sequence + self.context.protected_resource_metadata = None + self.context.auth_server_url = None + self.context.oauth_metadata = None + self._metadata = None # Step 1: Discover protected resource metadata (SEP-985 with fallback support) prm_discovery_urls = build_protected_resource_metadata_discovery_urls( @@ -601,84 +566,58 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx. for url in prm_discovery_urls: # pragma: no branch discovery_request = create_oauth_metadata_request(url) + discovery_response = yield discovery_request - discovery_response = yield discovery_request # sending request - -#<<<<<<< main - # Step 3: Discover OAuth metadata (with fallback for legacy servers) - discovery_urls = self._get_discovery_urls(self.context.auth_server_url or self.context.server_url) - for url in discovery_urls: - oauth_metadata_request = self._create_oauth_metadata_request(url) - oauth_metadata_response = yield oauth_metadata_request - - if oauth_metadata_response.status_code == 200: - try: - await self._handle_oauth_metadata_response(oauth_metadata_response) - self.context.oauth_metadata = self._metadata - break - except ValidationError: # pragma: no cover - continue - elif oauth_metadata_response.status_code < 400 or oauth_metadata_response.status_code >= 500: - break # Non-4XX error, stop trying - - # Step 4: Register client if needed - registration_request = self._create_registration_request(self._metadata) - if registration_request: - registration_response = yield registration_request - await self._handle_registration_response(registration_response) - self.context.client_info = self._client_info -#======= prm = await handle_protected_resource_response(discovery_response) if prm: self.context.protected_resource_metadata = prm - - # todo: try all authorization_servers to find the OASM - assert ( - len(prm.authorization_servers) > 0 - ) # this is always true as authorization_servers has a min length of 1 - - self.context.auth_server_url = str(prm.authorization_servers[0]) + if prm.authorization_servers: # pragma: no branch + self.context.auth_server_url = str(prm.authorization_servers[0]) break - else: - logger.debug(f"Protected resource metadata discovery failed: {url}") + logger.debug(f"Protected resource metadata discovery failed: {url}") + + # Step 2: Discover OAuth Authorization Server Metadata (OASM) (with fallback for legacy servers) asm_discovery_urls = build_oauth_authorization_server_metadata_discovery_urls( self.context.auth_server_url, self.context.server_url ) - # Step 2: Discover OAuth Authorization Server Metadata (OASM) (with fallback for legacy servers) - for url in asm_discovery_urls: # pragma: no cover + authorization_metadata: OAuthMetadata | None = None + for url in asm_discovery_urls: # pragma: no branch oauth_metadata_request = create_oauth_metadata_request(url) oauth_metadata_response = yield oauth_metadata_request ok, asm = await handle_auth_metadata_response(oauth_metadata_response) if not ok: break - if ok and asm: - self.context.oauth_metadata = asm + if asm: + authorization_metadata = asm break - else: - logger.debug(f"OAuth metadata discovery failed: {url}") + + logger.debug(f"OAuth metadata discovery failed: {url}") + + if authorization_metadata: + self.context.oauth_metadata = authorization_metadata + self._metadata = authorization_metadata # Step 3: Apply scope selection strategy self.context.client_metadata.scope = get_client_metadata_scopes( - www_auth_resource_metadata_url, + www_auth_scope, self.context.protected_resource_metadata, self.context.oauth_metadata, ) # Step 4: Register client if needed - registration_request = create_client_registration_request( - self.context.oauth_metadata, - self.context.client_metadata, - self.context.get_authorization_base_url(self.context.server_url), - ) if not self.context.client_info: + registration_request = create_client_registration_request( + self.context.oauth_metadata, + self.context.client_metadata, + self.context.get_authorization_base_url(self.context.server_url), + ) registration_response = yield registration_request client_information = await handle_registration_response(registration_response) self.context.client_info = client_information await self.context.storage.set_client_info(client_information) -#>>>>>>> main # Step 5: Perform authorization and complete token exchange token_response = yield await self._perform_authorization() diff --git a/tests/client/test_auth.py b/tests/client/test_auth.py index c6400ac17..bee725a37 100644 --- a/tests/client/test_auth.py +++ b/tests/client/test_auth.py @@ -13,22 +13,12 @@ from inline_snapshot import Is, snapshot from pydantic import AnyHttpUrl, AnyUrl -#<<<<<<< main from mcp.client.auth import ( ClientCredentialsProvider, OAuthClientProvider, PKCEParameters, TokenExchangeProvider, ) -from mcp.shared.auth import ( - OAuthClientInformationFull, - OAuthClientMetadata, - OAuthMetadata, - OAuthToken, - ProtectedResourceMetadata, -) -#======= -from mcp.client.auth import OAuthClientProvider, PKCEParameters from mcp.client.auth.utils import ( build_oauth_authorization_server_metadata_discovery_urls, build_protected_resource_metadata_discovery_urls, @@ -39,8 +29,13 @@ get_client_metadata_scopes, handle_registration_response, ) -from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthToken, ProtectedResourceMetadata -#>>>>>>> main +from mcp.shared.auth import ( + OAuthClientInformationFull, + OAuthClientMetadata, + OAuthMetadata, + OAuthToken, + ProtectedResourceMetadata, +) class MockTokenStorage: @@ -556,23 +551,9 @@ async def test_oauth_discovery_fallback_conditions(self, oauth_provider: OAuthCl return_value=("test_auth_code", "test_code_verifier") ) -#<<<<<<< main - # Next request should fall back to legacy behavior: register then obtain token - registration_request = await auth_flow.asend(oauth_metadata_response_3) - assert str(registration_request.url) == "https://api.example.com/register" - assert registration_request.method == "POST" - - registration_response = httpx.Response( - 200, - content=b'{"client_id":"c","redirect_uris":["http://localhost:3030/callback"]}', - request=registration_request, - ) - token_request = await auth_flow.asend(registration_response) -#======= # All path-based URLs failed, flow continues with default endpoints # Next request should be token exchange using MCP server base URL (fallback when OAuth metadata not found) token_request = await auth_flow.asend(oauth_metadata_response_3) -#>>>>>>> main assert str(token_request.url) == "https://api.example.com/token" assert token_request.method == "POST"