diff --git a/setup.cfg b/setup.cfg index f8a7875..df6742c 100644 --- a/setup.cfg +++ b/setup.cfg @@ -27,7 +27,7 @@ license_files = LICENSE platforms = unix, linux, osx, win32 classifiers = Operating System :: OS Independent - Development Status :: 4 - Beta + Development Status :: 5 - Production/Stable Framework :: FastAPI Programming Language :: Python Programming Language :: Python :: 3 diff --git a/src/fastapi_oauth2/__init__.py b/src/fastapi_oauth2/__init__.py index 81a2814..5becc17 100644 --- a/src/fastapi_oauth2/__init__.py +++ b/src/fastapi_oauth2/__init__.py @@ -1 +1 @@ -__version__ = "1.0.0-beta.3" +__version__ = "1.0.0" diff --git a/src/fastapi_oauth2/core.py b/src/fastapi_oauth2/core.py index 1eb6c59..76486e2 100644 --- a/src/fastapi_oauth2/core.py +++ b/src/fastapi_oauth2/core.py @@ -54,6 +54,7 @@ class OAuth2Core: _oauth_client: Optional[WebApplicationClient] = None _authorization_endpoint: str = None _token_endpoint: str = None + _access_token: str = None _state: str = None def __init__(self, client: OAuth2Client) -> None: @@ -70,7 +71,9 @@ def __init__(self, client: OAuth2Client) -> None: @property def access_token(self) -> str: - return self._oauth_client.access_token + if not self._access_token: + self._access_token = self._oauth_client.access_token + return self._access_token def get_redirect_uri(self, request: Request) -> str: return urljoin(str(request.base_url), "/oauth2/%s/token" % self.provider) diff --git a/tests/conftest.py b/tests/conftest.py index 26fc29b..5766231 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -78,6 +78,10 @@ def auth(request: Request): ) return response + @app_router.get("/access-token") + def access_token(request: Request): + return Response(request.auth.provider.access_token) + if with_idp: @app_router.get("/oauth2/{provider}/token") async def token(request: Request, provider: str): diff --git a/tests/test_oauth2.py b/tests/test_oauth2.py index a9f9d2c..6d10669 100644 --- a/tests/test_oauth2.py +++ b/tests/test_oauth2.py @@ -62,3 +62,21 @@ async def test_oauth2_csrf_workflow(get_app): await oauth2_workflow(get_app, idp=True, ssr=False, authorize_query=aq, token_query=tq, use_header=True) except AssertionError: assert aq != tq + + +@pytest.mark.anyio +async def test_core_access_token(get_app): + async with AsyncClient(app=get_app(with_idp=True, with_ssr=True), base_url="http://test") as client: + response = await client.get("/oauth2/test/authorize") + authorization_endpoint = response.headers.get("location") + response = await client.get(authorization_endpoint) + token_url = response.headers.get("location") + query = {k: v[0] for k, v in parse_qs(urlparse(token_url).query).items()} + token_url = "%s?%s" % (token_url.split("?")[0], urlencode(query)) + await client.get(token_url) + + response = await client.get("/access-token") + assert response.content != b"" + + response = await client.get("/access-token") + assert response.content != b""