Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 24 additions & 20 deletions src/mcp/client/auth/oauth2.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,12 +215,15 @@ def _get_discovery_urls(self, server_url: str | None = None) -> list[str]:
def _create_oauth_metadata_request(self, url: str) -> httpx.Request:
return httpx.Request("GET", url, headers={MCP_PROTOCOL_VERSION: LATEST_PROTOCOL_VERSION})

async def _handle_oauth_metadata_response(self, response: httpx.Response) -> None:
content = await response.aread()
metadata = OAuthMetadata.model_validate_json(content)
self._metadata = metadata
if self.client_metadata.scope is None and metadata.scopes_supported is not None:
self.client_metadata.scope = " ".join(metadata.scopes_supported)
async def _handle_oauth_metadata_response(
self, response: httpx.Response
) -> tuple[bool, OAuthMetadata | None]:
ok, metadata = await handle_auth_metadata_response(response)
if metadata:
self._metadata = metadata
if self.client_metadata.scope is None and metadata.scopes_supported is not None:
self.client_metadata.scope = " ".join(metadata.scopes_supported)
return ok, metadata

def _create_registration_request(self, metadata: OAuthMetadata | None = None) -> httpx.Request | None:
context = getattr(self, "context", None)
Expand Down Expand Up @@ -348,15 +351,25 @@ async def _handle_protected_resource_response(self, response: httpx.Response) ->
)
return False

async def _handle_oauth_metadata_response(self, response: httpx.Response) -> tuple[bool, OAuthMetadata | None]:
ok, asm = await handle_auth_metadata_response(response)
async def _handle_oauth_metadata_response(
self, response: httpx.Response
) -> tuple[bool, OAuthMetadata | None]:
ok, asm = await super()._handle_oauth_metadata_response(response)
if asm:
self.context.oauth_metadata = asm
self._metadata = asm
if self.context.client_metadata.scope is None and asm.scopes_supported is not None:
self.context.client_metadata.scope = " ".join(asm.scopes_supported)
return ok, asm

def _select_scopes(self, scope_header: str | None) -> None:
"""Select scopes based on discovery data and WWW-Authenticate header."""

self.context.client_metadata.scope = get_client_metadata_scopes(
scope_header,
self.context.protected_resource_metadata,
self.context.oauth_metadata,
)

async def _perform_authorization(self) -> httpx.Request:
"""Perform the authorization flow."""
auth_code, code_verifier = await self._perform_authorization_code_grant()
Expand Down Expand Up @@ -588,12 +601,7 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
oauth_metadata_request = self._create_oauth_metadata_request(url)
oauth_metadata_response = yield oauth_metadata_request

result = await self._handle_oauth_metadata_response(oauth_metadata_response)
if isinstance(result, tuple):
ok, asm = result
else:
ok = bool(result) if result is not None else True
asm = self.context.oauth_metadata or self._metadata
ok, asm = await self._handle_oauth_metadata_response(oauth_metadata_response)

if not ok:
break
Expand All @@ -608,11 +616,7 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
self._metadata = authorization_metadata

# Step 3: Apply scope selection strategy
self.context.client_metadata.scope = get_client_metadata_scopes(
www_auth_scope,
self.context.protected_resource_metadata,
self.context.oauth_metadata,
)
self._select_scopes(www_auth_scope)

# Step 4: Register client if needed
if not self.context.client_info:
Expand Down