Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 29 additions & 51 deletions src/mcp/client/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -549,18 +549,6 @@ def _add_auth_header(self, request: httpx.Request) -> None:
"""Add authorization header to request if we have valid tokens."""
if self.context.current_tokens and self.context.current_tokens.access_token:
request.headers["Authorization"] = f"Bearer {self.context.current_tokens.access_token}"

#<<<<<<< main
#=======
def _create_oauth_metadata_request(self, url: str) -> httpx.Request:
return httpx.Request("GET", url, headers={MCP_PROTOCOL_VERSION: LATEST_PROTOCOL_VERSION})

async def _handle_oauth_metadata_response(self, response: httpx.Response) -> None:
content = await response.aread()
metadata = OAuthMetadata.model_validate_json(content)
self.context.oauth_metadata = metadata

#>>>>>>> main
async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.Request, httpx.Response]:
"""HTTPX auth flow integration."""
async with self.context.lock:
Expand Down Expand Up @@ -593,16 +581,13 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
discovery_response = yield discovery_request
await self._handle_protected_resource_response(discovery_response)

#<<<<<<< main
# Step 2: Discover OAuth metadata (with fallback for legacy servers)
discovery_urls = self._get_discovery_urls(self.context.auth_server_url or self.context.server_url)
#=======
# Step 2: Apply scope selection strategy
self._select_scopes(response)

# Step 3: Discover OAuth metadata (with fallback for legacy servers)
discovery_urls = self._get_discovery_urls()
#>>>>>>> main
discovery_urls = self._get_discovery_urls(
self.context.auth_server_url or self.context.server_url
)
for url in discovery_urls:
oauth_metadata_request = self._create_oauth_metadata_request(url)
oauth_metadata_response = yield oauth_metadata_request
Expand All @@ -617,13 +602,8 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
elif oauth_metadata_response.status_code < 400 or oauth_metadata_response.status_code >= 500:
break # Non-4XX error, stop trying

#<<<<<<< main
# Step 3: Register client if needed
registration_request = self._create_registration_request(self._metadata)
#=======
# Step 4: Register client if needed
registration_request = await self._register_client()
#>>>>>>> main
registration_request = self._create_registration_request(self._metadata)
if registration_request:
registration_response = yield registration_request
await self._handle_registration_response(registration_response)
Expand All @@ -643,7 +623,31 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
# Retry with new tokens
self._add_auth_header(request)
yield request
#<<<<<<< main

elif response.status_code == 403:
# Step 1: Extract error field from WWW-Authenticate header
error = self._extract_field_from_www_auth(response, "error")

# Step 2: Check if we need to step-up authorization
if error == "insufficient_scope":
try:
# Step 2a: Update the required scopes
self._select_scopes(response)

# Step 2b: Perform (re-)authorization
auth_code, code_verifier = await self._perform_authorization()

# Step 2c: Exchange authorization code for tokens
token_request = await self._exchange_token(auth_code, code_verifier)
token_response = yield token_request
await self._handle_token_response(token_response)
except Exception:
logger.exception("OAuth flow error")
raise

# Retry with new tokens
self._add_auth_header(request)
yield request


class ClientCredentialsProvider(BaseOAuthProvider):
Expand Down Expand Up @@ -919,29 +923,3 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
response = yield request
if response.status_code == 401:
self._current_tokens = None
#=======
elif response.status_code == 403:
# Step 1: Extract error field from WWW-Authenticate header
error = self._extract_field_from_www_auth(response, "error")

# Step 2: Check if we need to step-up authorization
if error == "insufficient_scope":
try:
# Step 2a: Update the required scopes
self._select_scopes(response)

# Step 2b: Perform (re-)authorization
auth_code, code_verifier = await self._perform_authorization()

# Step 2c: Exchange authorization code for tokens
token_request = await self._exchange_token(auth_code, code_verifier)
token_response = yield token_request
await self._handle_token_response(token_response)
except Exception:
logger.exception("OAuth flow error")
raise

# Retry with new tokens
self._add_auth_header(request)
yield request
#>>>>>>> main
28 changes: 16 additions & 12 deletions tests/client/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,6 @@ async def callback_handler() -> tuple[str, str | None]:


@pytest.fixture
#<<<<<<< main
def client_credentials_metadata():
return OAuthClientMetadata(
redirect_uris=[AnyHttpUrl("http://localhost:3000/callback")],
Expand All @@ -103,7 +102,10 @@ def client_credentials_metadata():
response_types=["code"],
scope="read write",
token_endpoint_auth_method="client_secret_post",
#=======
)


@pytest.fixture
def prm_metadata_response():
"""PRM metadata response with scopes."""
return httpx.Response(
Expand All @@ -113,12 +115,10 @@ def prm_metadata_response():
b'"authorization_servers": ["https://auth.example.com"], '
b'"scopes_supported": ["resource:read", "resource:write"]}'
),
#>>>>>>> main
)


@pytest.fixture
#<<<<<<< main
def oauth_metadata():
return OAuthMetadata(
issuer=AnyHttpUrl("https://auth.example.com"),
Expand All @@ -129,7 +129,10 @@ def oauth_metadata():
response_types_supported=["code"],
grant_types_supported=["authorization_code", "refresh_token", "client_credentials"],
code_challenge_methods_supported=["S256"],
#=======
)


@pytest.fixture
def prm_metadata_without_scopes_response():
"""PRM metadata response without scopes."""
return httpx.Response(
Expand All @@ -139,12 +142,10 @@ def prm_metadata_without_scopes_response():
b'"authorization_servers": ["https://auth.example.com"], '
b'"scopes_supported": null}'
),
#>>>>>>> main
)


@pytest.fixture
#<<<<<<< main
def oauth_client_info():
return OAuthClientInformationFull(
client_id="test_client_id",
Expand All @@ -154,19 +155,20 @@ def oauth_client_info():
grant_types=["authorization_code", "refresh_token"],
response_types=["code"],
scope="read write",
#=======
)


@pytest.fixture
def init_response_with_www_auth_scope():
"""Initial 401 response with WWW-Authenticate header containing scope."""
return httpx.Response(
401,
headers={"WWW-Authenticate": 'Bearer scope="special:scope from:www-authenticate"'},
request=httpx.Request("GET", "https://api.example.com/test"),
#>>>>>>> main
)


@pytest.fixture
#<<<<<<< main
def oauth_token():
return OAuthToken(
access_token="test_access_token",
Expand Down Expand Up @@ -197,14 +199,16 @@ async def token_exchange_provider(
client_metadata=client_credentials_metadata,
storage=mock_storage,
subject_token_supplier=lambda: asyncio.sleep(0, result="user_token"),
#=======
)


@pytest.fixture
def init_response_without_www_auth_scope():
"""Initial 401 response without WWW-Authenticate scope."""
return httpx.Response(
401,
headers={},
request=httpx.Request("GET", "https://api.example.com/test"),
#>>>>>>> main
)


Expand Down