Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 58 additions & 41 deletions src/mcp/client/auth/oauth2.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,14 +251,20 @@ def _create_registration_request(self, metadata: OAuthMetadata | None = None) ->
headers={"Content-Type": "application/json"},
)

async def _handle_registration_response(self, response: httpx.Response) -> None:
async def _handle_registration_response(
self, response: httpx.Response
) -> OAuthClientInformationFull:
if response.status_code not in (200, 201):
await response.aread()
raise OAuthRegistrationError(f"Registration failed: {response.status_code} {response.text}")
content = await response.aread()
client_info = OAuthClientInformationFull.model_validate_json(content)
self._client_info = client_info
await self.storage.set_client_info(client_info)
context = getattr(self, "context", None)
if context is not None:
context.client_info = client_info
return client_info

def _apply_client_auth(
self,
Expand Down Expand Up @@ -315,6 +321,18 @@ def __init__(
)
self._initialized = False

def _build_protected_resource_discovery_urls(self, resource_metadata_url: str | None) -> list[str]:
"""Build the list of PRM discovery URLs with legacy fallbacks."""
return build_protected_resource_metadata_discovery_urls(
resource_metadata_url, self.context.server_url
)

def _get_discovery_urls(self, server_url: str | None = None) -> list[str]:
"""Build OAuth authorization server discovery URLs with legacy fallbacks."""
return build_oauth_authorization_server_metadata_discovery_urls(
server_url, self.context.server_url
)

async def _handle_protected_resource_response(self, response: httpx.Response) -> bool:
"""
Handle protected resource metadata discovery response.
Expand All @@ -324,28 +342,30 @@ async def _handle_protected_resource_response(self, response: httpx.Response) ->
Returns:
True if metadata was successfully discovered, False if we should try next URL
"""
if response.status_code == 200:
try:
content = await response.aread()
metadata = ProtectedResourceMetadata.model_validate_json(content)
self.context.protected_resource_metadata = metadata
if metadata.authorization_servers: # pragma: no branch
self.context.auth_server_url = str(metadata.authorization_servers[0])
return True

except ValidationError: # pragma: no cover
# Invalid metadata - try next URL
logger.warning(f"Invalid protected resource metadata at {response.request.url}")
return False
elif response.status_code == 404: # pragma: no cover
# Not found - try next URL in fallback chain
logger.debug(f"Protected resource metadata not found at {response.request.url}, trying next URL")
return False
else:
# Other error - fail immediately
raise OAuthFlowError(
f"Protected Resource Metadata request failed: {response.status_code}"
) # pragma: no cover
metadata = await handle_protected_resource_response(response)
if metadata:
self.context.protected_resource_metadata = metadata
if metadata.authorization_servers: # pragma: no branch
self.context.auth_server_url = str(metadata.authorization_servers[0])
return True

logger.debug(
"Protected resource metadata discovery failed with status %s at %s",
response.status_code,
response.request.url,
)
return False

async def _handle_oauth_metadata_response(
self, response: httpx.Response
) -> tuple[bool, OAuthMetadata | None]:
ok, asm = await handle_auth_metadata_response(response)
if asm:
self.context.oauth_metadata = asm
self._metadata = asm
if self.context.client_metadata.scope is None and asm.scopes_supported is not None:
self.context.client_metadata.scope = " ".join(asm.scopes_supported)
return ok, asm

async def _perform_authorization(self) -> httpx.Request:
"""Perform the authorization flow."""
Expand Down Expand Up @@ -560,34 +580,33 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
self._metadata = None

# Step 1: Discover protected resource metadata (SEP-985 with fallback support)
prm_discovery_urls = build_protected_resource_metadata_discovery_urls(
www_auth_resource_metadata_url, self.context.server_url
prm_discovery_urls = self._build_protected_resource_discovery_urls(
www_auth_resource_metadata_url
)

for url in prm_discovery_urls: # pragma: no branch
discovery_request = create_oauth_metadata_request(url)
discovery_request = self._create_oauth_metadata_request(url)
discovery_response = yield discovery_request

prm = await handle_protected_resource_response(discovery_response)
if prm:
self.context.protected_resource_metadata = prm
if prm.authorization_servers: # pragma: no branch
self.context.auth_server_url = str(prm.authorization_servers[0])
handled = await self._handle_protected_resource_response(discovery_response)
if handled:
break

logger.debug(f"Protected resource metadata discovery failed: {url}")

# Step 2: Discover OAuth Authorization Server Metadata (OASM) (with fallback for legacy servers)
asm_discovery_urls = build_oauth_authorization_server_metadata_discovery_urls(
self.context.auth_server_url, self.context.server_url
)
asm_discovery_urls = self._get_discovery_urls(self.context.auth_server_url)

authorization_metadata: OAuthMetadata | None = None
for url in asm_discovery_urls: # pragma: no branch
oauth_metadata_request = create_oauth_metadata_request(url)
oauth_metadata_request = self._create_oauth_metadata_request(url)
oauth_metadata_response = yield oauth_metadata_request

ok, asm = await handle_auth_metadata_response(oauth_metadata_response)
result = await self._handle_oauth_metadata_response(oauth_metadata_response)
if isinstance(result, tuple):
ok, asm = result
else:
ok = bool(result) if result is not None else True
asm = self.context.oauth_metadata or self._metadata

if not ok:
break
if asm:
Expand Down Expand Up @@ -615,9 +634,7 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
self.context.get_authorization_base_url(self.context.server_url),
)
registration_response = yield registration_request
client_information = await handle_registration_response(registration_response)
self.context.client_info = client_information
await self.context.storage.set_client_info(client_information)
await self._handle_registration_response(registration_response)

# Step 5: Perform authorization and complete token exchange
token_response = yield await self._perform_authorization()
Expand Down
4 changes: 2 additions & 2 deletions tests/client/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -1364,7 +1364,7 @@ async def callback_handler() -> tuple[str, str | None]:
)

# Mock authorization
provider._perform_authorization_code_grant = mock.AsyncMock(
provider._perform_authorization_code_grant = AsyncMock(
return_value=("test_auth_code", "test_code_verifier")
)

Expand Down Expand Up @@ -1470,7 +1470,7 @@ async def callback_handler() -> tuple[str, str | None]:
request=oauth_metadata_request,
)

provider._perform_authorization_code_grant = mock.AsyncMock(
provider._perform_authorization_code_grant = AsyncMock(
return_value=("test_auth_code", "test_code_verifier")
)

Expand Down