Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 30 additions & 91 deletions src/mcp/client/auth/oauth2.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
handle_token_response_scopes,
)
from mcp.client.streamable_http import MCP_PROTOCOL_VERSION
from mcp.types import LATEST_PROTOCOL_VERSION
from mcp.shared.auth import (
OAuthClientInformationFull,
OAuthClientMetadata,
Expand Down Expand Up @@ -341,35 +342,11 @@ async def _handle_protected_resource_response(self, response: httpx.Response) ->
logger.debug(f"Protected resource metadata not found at {response.request.url}, trying next URL")
return False
else:
#<<<<<<< main
# Priority 3: Omit scope parameter
self.context.client_metadata.scope = None

# Discovery and registration helpers provided by BaseOAuthProvider
#=======
# Other error - fail immediately
raise OAuthFlowError(
f"Protected Resource Metadata request failed: {response.status_code}"
) # pragma: no cover

async def _register_client(self) -> httpx.Request | None:
"""Build registration request or skip if already registered."""
if self.context.client_info:
return None

if self.context.oauth_metadata and self.context.oauth_metadata.registration_endpoint:
registration_url = str(self.context.oauth_metadata.registration_endpoint) # pragma: no cover
else:
auth_base_url = self.context.get_authorization_base_url(self.context.server_url)
registration_url = urljoin(auth_base_url, "/register")

registration_data = self.context.client_metadata.model_dump(by_alias=True, mode="json", exclude_none=True)

return httpx.Request(
"POST", registration_url, json=registration_data, headers={"Content-Type": "application/json"}
)
#>>>>>>> main

async def _perform_authorization(self) -> httpx.Request:
"""Perform the authorization flow."""
auth_code, code_verifier = await self._perform_authorization_code_grant()
Expand Down Expand Up @@ -473,21 +450,10 @@ async def _exchange_token_authorization_code(

async def _handle_token_response(self, response: httpx.Response) -> None:
"""Handle token exchange response."""
#<<<<<<< main
if response.status_code != 200: # pragma: no cover
body = response.content or await response.aread()
body = body.decode("utf-8")
raise OAuthTokenError(f"Token exchange failed ({response.status_code}): {body}")

try:
content = response.content or await response.aread()
token_response = OAuthToken.model_validate_json(content)
#=======
if response.status_code != 200:
body = await response.aread() # pragma: no cover
body_text = body.decode("utf-8") # pragma: no cover
raise OAuthTokenError(f"Token exchange failed ({response.status_code}): {body_text}") # pragma: no cover
#>>>>>>> main

# Parse and validate response with scope validation
token_response = await handle_token_response_scopes(response)
Expand Down Expand Up @@ -557,14 +523,6 @@ def _add_auth_header(self, request: httpx.Request) -> None:
if self.context.current_tokens and self.context.current_tokens.access_token: # pragma: no branch
request.headers["Authorization"] = f"Bearer {self.context.current_tokens.access_token}"

#<<<<<<< main
#=======
async def _handle_oauth_metadata_response(self, response: httpx.Response) -> None:
content = await response.aread()
metadata = OAuthMetadata.model_validate_json(content)
self.context.oauth_metadata = metadata

#>>>>>>> main
async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.Request, httpx.Response]:
"""HTTPX auth flow integration."""
async with self.context.lock:
Expand Down Expand Up @@ -593,6 +551,13 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
try:
# OAuth flow must be inline due to generator constraints
www_auth_resource_metadata_url = extract_resource_metadata_from_www_auth(response)
www_auth_scope = extract_scope_from_www_auth(response)

# Reset discovery context before attempting new discovery sequence
self.context.protected_resource_metadata = None
self.context.auth_server_url = None
self.context.oauth_metadata = None
self._metadata = None

# Step 1: Discover protected resource metadata (SEP-985 with fallback support)
prm_discovery_urls = build_protected_resource_metadata_discovery_urls(
Expand All @@ -601,84 +566,58 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.

for url in prm_discovery_urls: # pragma: no branch
discovery_request = create_oauth_metadata_request(url)
discovery_response = yield discovery_request

discovery_response = yield discovery_request # sending request

#<<<<<<< main
# Step 3: Discover OAuth metadata (with fallback for legacy servers)
discovery_urls = self._get_discovery_urls(self.context.auth_server_url or self.context.server_url)
for url in discovery_urls:
oauth_metadata_request = self._create_oauth_metadata_request(url)
oauth_metadata_response = yield oauth_metadata_request

if oauth_metadata_response.status_code == 200:
try:
await self._handle_oauth_metadata_response(oauth_metadata_response)
self.context.oauth_metadata = self._metadata
break
except ValidationError: # pragma: no cover
continue
elif oauth_metadata_response.status_code < 400 or oauth_metadata_response.status_code >= 500:
break # Non-4XX error, stop trying

# Step 4: Register client if needed
registration_request = self._create_registration_request(self._metadata)
if registration_request:
registration_response = yield registration_request
await self._handle_registration_response(registration_response)
self.context.client_info = self._client_info
#=======
prm = await handle_protected_resource_response(discovery_response)
if prm:
self.context.protected_resource_metadata = prm

# todo: try all authorization_servers to find the OASM
assert (
len(prm.authorization_servers) > 0
) # this is always true as authorization_servers has a min length of 1

self.context.auth_server_url = str(prm.authorization_servers[0])
if prm.authorization_servers: # pragma: no branch
self.context.auth_server_url = str(prm.authorization_servers[0])
break
else:
logger.debug(f"Protected resource metadata discovery failed: {url}")

logger.debug(f"Protected resource metadata discovery failed: {url}")

# Step 2: Discover OAuth Authorization Server Metadata (OASM) (with fallback for legacy servers)
asm_discovery_urls = build_oauth_authorization_server_metadata_discovery_urls(
self.context.auth_server_url, self.context.server_url
)

# Step 2: Discover OAuth Authorization Server Metadata (OASM) (with fallback for legacy servers)
for url in asm_discovery_urls: # pragma: no cover
authorization_metadata: OAuthMetadata | None = None
for url in asm_discovery_urls: # pragma: no branch
oauth_metadata_request = create_oauth_metadata_request(url)
oauth_metadata_response = yield oauth_metadata_request

ok, asm = await handle_auth_metadata_response(oauth_metadata_response)
if not ok:
break
if ok and asm:
self.context.oauth_metadata = asm
if asm:
authorization_metadata = asm
break
else:
logger.debug(f"OAuth metadata discovery failed: {url}")

logger.debug(f"OAuth metadata discovery failed: {url}")

if authorization_metadata:
self.context.oauth_metadata = authorization_metadata
self._metadata = authorization_metadata

# Step 3: Apply scope selection strategy
self.context.client_metadata.scope = get_client_metadata_scopes(
www_auth_resource_metadata_url,
www_auth_scope,
self.context.protected_resource_metadata,
self.context.oauth_metadata,
)

# Step 4: Register client if needed
registration_request = create_client_registration_request(
self.context.oauth_metadata,
self.context.client_metadata,
self.context.get_authorization_base_url(self.context.server_url),
)
if not self.context.client_info:
registration_request = create_client_registration_request(
self.context.oauth_metadata,
self.context.client_metadata,
self.context.get_authorization_base_url(self.context.server_url),
)
registration_response = yield registration_request
client_information = await handle_registration_response(registration_response)
self.context.client_info = client_information
await self.context.storage.set_client_info(client_information)
#>>>>>>> main

# Step 5: Perform authorization and complete token exchange
token_response = yield await self._perform_authorization()
Expand Down
33 changes: 7 additions & 26 deletions tests/client/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,22 +13,12 @@
from inline_snapshot import Is, snapshot
from pydantic import AnyHttpUrl, AnyUrl

#<<<<<<< main
from mcp.client.auth import (
ClientCredentialsProvider,
OAuthClientProvider,
PKCEParameters,
TokenExchangeProvider,
)
from mcp.shared.auth import (
OAuthClientInformationFull,
OAuthClientMetadata,
OAuthMetadata,
OAuthToken,
ProtectedResourceMetadata,
)
#=======
from mcp.client.auth import OAuthClientProvider, PKCEParameters
from mcp.client.auth.utils import (
build_oauth_authorization_server_metadata_discovery_urls,
build_protected_resource_metadata_discovery_urls,
Expand All @@ -39,8 +29,13 @@
get_client_metadata_scopes,
handle_registration_response,
)
from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthToken, ProtectedResourceMetadata
#>>>>>>> main
from mcp.shared.auth import (
OAuthClientInformationFull,
OAuthClientMetadata,
OAuthMetadata,
OAuthToken,
ProtectedResourceMetadata,
)


class MockTokenStorage:
Expand Down Expand Up @@ -556,23 +551,9 @@ async def test_oauth_discovery_fallback_conditions(self, oauth_provider: OAuthCl
return_value=("test_auth_code", "test_code_verifier")
)

#<<<<<<< main
# Next request should fall back to legacy behavior: register then obtain token
registration_request = await auth_flow.asend(oauth_metadata_response_3)
assert str(registration_request.url) == "https://api.example.com/register"
assert registration_request.method == "POST"

registration_response = httpx.Response(
200,
content=b'{"client_id":"c","redirect_uris":["http://localhost:3030/callback"]}',
request=registration_request,
)
token_request = await auth_flow.asend(registration_response)
#=======
# All path-based URLs failed, flow continues with default endpoints
# Next request should be token exchange using MCP server base URL (fallback when OAuth metadata not found)
token_request = await auth_flow.asend(oauth_metadata_response_3)
#>>>>>>> main
assert str(token_request.url) == "https://api.example.com/token"
assert token_request.method == "POST"

Expand Down