Skip to content

Commit

Permalink
feat: add ability to get original userinfo response instead of OpenID (
Browse files Browse the repository at this point in the history
…#148)

* feat: add ability to get original userinfo response instead of OpenID

* chore: make typehinting compatible with python 3.8

* test: fix typing in tests

* test(cov): allow 1 % drop in coverage
  • Loading branch information
tomasvotava committed Apr 3, 2024
1 parent 0b2a0b4 commit 6a311f4
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 10 deletions.
4 changes: 2 additions & 2 deletions codecov.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,12 @@ coverage:
patch:
default:
target: auto
threshold: null
threshold: 1%
if_not_found: success
only_pulls: true
project:
default:
target: auto
threshold: null
threshold: 1%
if_not_found: success
only_pulls: true
70 changes: 63 additions & 7 deletions fastapi_sso/sso/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import sys
import warnings
from types import TracebackType
from typing import Any, Dict, List, Optional, Type, Union
from typing import Any, Dict, List, Literal, Optional, Type, Union, overload

import httpx
import pydantic
Expand Down Expand Up @@ -296,14 +296,37 @@ async def get_login_redirect(
response.set_cookie("pkce_code_verifier", str(self._pkce_code_verifier))
return response

@overload
async def verify_and_process(
self,
request: Request,
*,
params: Optional[Dict[str, Any]] = None,
headers: Optional[Dict[str, Any]] = None,
redirect_uri: Optional[str] = None,
) -> Optional[OpenID]:
convert_response: Literal[True] = True,
) -> Optional[OpenID]: ...

@overload
async def verify_and_process(
self,
request: Request,
*,
params: Optional[Dict[str, Any]] = None,
headers: Optional[Dict[str, Any]] = None,
redirect_uri: Optional[str] = None,
convert_response: Literal[False],
) -> Optional[Dict[str, Any]]: ...

async def verify_and_process(
self,
request: Request,
*,
params: Optional[Dict[str, Any]] = None,
headers: Optional[Dict[str, Any]] = None,
redirect_uri: Optional[str] = None,
convert_response: Union[Literal[True], Literal[False]] = True,
) -> Union[Optional[OpenID], Optional[Dict[str, Any]]]:
"""
Processes the login given a FastAPI (Starlette) Request object. This should be used for the /callback path.
Expand All @@ -312,12 +335,14 @@ async def verify_and_process(
params (Optional[Dict[str, Any]]): Additional query parameters to pass to the provider.
headers (Optional[Dict[str, Any]]): Additional headers to pass to the provider.
redirect_uri (Optional[str]): Overrides the `redirect_uri` specified on this instance.
convert_response (bool): If True, userinfo response is converted to OpenID object.
Raises:
SSOLoginError: If the 'code' parameter is not found in the callback request.
Returns:
Optional[OpenID]: User information in OpenID format if the login was successful.
Optional[OpenID]: User information as OpenID instance (if convert_response == True)
Optional[Dict[str, Any]]: The original JSON response from the API.
"""
headers = headers or {}
code = request.query_params.get("code")
Expand All @@ -338,6 +363,7 @@ async def verify_and_process(
additional_headers=headers,
redirect_uri=redirect_uri,
pkce_code_verifier=pkce_code_verifier,
convert_response=convert_response,
)

def __enter__(self) -> "SSOBase":
Expand All @@ -363,6 +389,7 @@ def __exit__(
def _extra_query_params(self) -> Dict:
return {}

@overload
async def process_login(
self,
code: str,
Expand All @@ -372,7 +399,33 @@ async def process_login(
additional_headers: Optional[Dict[str, Any]] = None,
redirect_uri: Optional[str] = None,
pkce_code_verifier: Optional[str] = None,
) -> Optional[OpenID]:
convert_response: Literal[True] = True,
) -> Optional[OpenID]: ...

@overload
async def process_login(
self,
code: str,
request: Request,
*,
params: Optional[Dict[str, Any]] = None,
additional_headers: Optional[Dict[str, Any]] = None,
redirect_uri: Optional[str] = None,
pkce_code_verifier: Optional[str] = None,
convert_response: Literal[False],
) -> Optional[Dict[str, Any]]: ...

async def process_login(
self,
code: str,
request: Request,
*,
params: Optional[Dict[str, Any]] = None,
additional_headers: Optional[Dict[str, Any]] = None,
redirect_uri: Optional[str] = None,
pkce_code_verifier: Optional[str] = None,
convert_response: Union[Literal[True], Literal[False]] = True,
) -> Union[Optional[OpenID], Optional[Dict[str, Any]]]:
"""
Processes login from the callback endpoint to verify the user and request user info endpoint.
It's a lower-level method, typically, you should use `verify_and_process` instead.
Expand All @@ -384,12 +437,14 @@ async def process_login(
additional_headers (Optional[Dict[str, Any]]): Additional headers to be added to all requests.
redirect_uri (Optional[str]): Overrides the `redirect_uri` specified on this instance.
pkce_code_verifier (Optional[str]): A PKCE code verifier sent to the server to verify the login request.
convert_response (bool): If True, userinfo response is converted to OpenID object.
Raises:
ReusedOauthClientWarning: If the SSO object is reused, which is not safe and caused security issues.
Returns:
Optional[OpenID]: User information in OpenID format if the login was successful.
Optional[OpenID]: User information in OpenID format if the login was successful (convert_response == True).
Optional[Dict[str, Any]]: Original userinfo API endpoint response.
"""
# pylint: disable=too-many-locals
if self._oauth_client is not None: # pragma: no cover
Expand Down Expand Up @@ -447,5 +502,6 @@ async def process_login(
session.headers.update(headers)
response = await session.get(uri)
content = response.json()

return await self.openid_from_response(content, session)
if convert_response:
return await self.openid_from_response(content, session)
return content
2 changes: 1 addition & 1 deletion tests/test_openid_responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from fastapi_sso.sso.facebook import FacebookSSO
from fastapi_sso.sso.yandex import YandexSSO

sso_test_cases: Tuple[Type[SSOBase], Tuple[Dict[str, Any], OpenID]] = (
sso_test_cases: Tuple[Tuple[Type[SSOBase], Dict[str, Any], OpenID], ...] = (
(
TwitterSSO,
{"data": {"id": "test", "username": "TestUser1234", "name": "Test User"}},
Expand Down

0 comments on commit 6a311f4

Please sign in to comment.