diff --git a/src/mcp/client/auth/oauth2.py b/src/mcp/client/auth/oauth2.py index 8410058b7..a3fa9e030 100644 --- a/src/mcp/client/auth/oauth2.py +++ b/src/mcp/client/auth/oauth2.py @@ -215,12 +215,15 @@ def _get_discovery_urls(self, server_url: str | None = None) -> list[str]: def _create_oauth_metadata_request(self, url: str) -> httpx.Request: return httpx.Request("GET", url, headers={MCP_PROTOCOL_VERSION: LATEST_PROTOCOL_VERSION}) - async def _handle_oauth_metadata_response(self, response: httpx.Response) -> None: - content = await response.aread() - metadata = OAuthMetadata.model_validate_json(content) - self._metadata = metadata - if self.client_metadata.scope is None and metadata.scopes_supported is not None: - self.client_metadata.scope = " ".join(metadata.scopes_supported) + async def _handle_oauth_metadata_response( + self, response: httpx.Response + ) -> tuple[bool, OAuthMetadata | None]: + ok, metadata = await handle_auth_metadata_response(response) + if metadata: + self._metadata = metadata + if self.client_metadata.scope is None and metadata.scopes_supported is not None: + self.client_metadata.scope = " ".join(metadata.scopes_supported) + return ok, metadata def _create_registration_request(self, metadata: OAuthMetadata | None = None) -> httpx.Request | None: context = getattr(self, "context", None) @@ -348,15 +351,25 @@ async def _handle_protected_resource_response(self, response: httpx.Response) -> ) return False - async def _handle_oauth_metadata_response(self, response: httpx.Response) -> tuple[bool, OAuthMetadata | None]: - ok, asm = await handle_auth_metadata_response(response) + async def _handle_oauth_metadata_response( + self, response: httpx.Response + ) -> tuple[bool, OAuthMetadata | None]: + ok, asm = await super()._handle_oauth_metadata_response(response) if asm: self.context.oauth_metadata = asm - self._metadata = asm if self.context.client_metadata.scope is None and asm.scopes_supported is not None: self.context.client_metadata.scope = " ".join(asm.scopes_supported) return ok, asm + def _select_scopes(self, scope_header: str | None) -> None: + """Select scopes based on discovery data and WWW-Authenticate header.""" + + self.context.client_metadata.scope = get_client_metadata_scopes( + scope_header, + self.context.protected_resource_metadata, + self.context.oauth_metadata, + ) + async def _perform_authorization(self) -> httpx.Request: """Perform the authorization flow.""" auth_code, code_verifier = await self._perform_authorization_code_grant() @@ -588,12 +601,7 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx. oauth_metadata_request = self._create_oauth_metadata_request(url) oauth_metadata_response = yield oauth_metadata_request - result = await self._handle_oauth_metadata_response(oauth_metadata_response) - if isinstance(result, tuple): - ok, asm = result - else: - ok = bool(result) if result is not None else True - asm = self.context.oauth_metadata or self._metadata + ok, asm = await self._handle_oauth_metadata_response(oauth_metadata_response) if not ok: break @@ -608,11 +616,7 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx. self._metadata = authorization_metadata # Step 3: Apply scope selection strategy - self.context.client_metadata.scope = get_client_metadata_scopes( - www_auth_scope, - self.context.protected_resource_metadata, - self.context.oauth_metadata, - ) + self._select_scopes(www_auth_scope) # Step 4: Register client if needed if not self.context.client_info: