From bd8bfbe84ff8258b1221bd689e2d8ae5b31f9c81 Mon Sep 17 00:00:00 2001 From: "joel@joellee.org" Date: Wed, 14 Dec 2022 11:41:09 +0530 Subject: [PATCH 01/16] feat: add mfa bindings --- gotrue/_async/gotrue_admin_mfa_api.py | 32 +++++++++ gotrue/_async/gotrue_mfa_api.py | 94 +++++++++++++++++++++++++++ 2 files changed, 126 insertions(+) create mode 100644 gotrue/_async/gotrue_admin_mfa_api.py create mode 100644 gotrue/_async/gotrue_mfa_api.py diff --git a/gotrue/_async/gotrue_admin_mfa_api.py b/gotrue/_async/gotrue_admin_mfa_api.py new file mode 100644 index 00000000..ca812fcd --- /dev/null +++ b/gotrue/_async/gotrue_admin_mfa_api.py @@ -0,0 +1,32 @@ +from ..types import ( + AuthMFAAdminDeleteFactorParams, + AuthMFAAdminDeleteFactorResponse, + AuthMFAAdminListFactorsParams, + AuthMFAAdminListFactorsResponse, +) + + +class AsyncGoTrueAdminMFAAPI: + """ + Contains the full multi-factor authentication administration API. + """ + + async def list_factors( + self, + params: AuthMFAAdminListFactorsParams, + ) -> AuthMFAAdminListFactorsResponse: + """ + Lists all factors attached to a user. + """ + raise NotImplementedError() # pragma: no cover + + async def delete_factor( + self, + params: AuthMFAAdminDeleteFactorParams, + ) -> AuthMFAAdminDeleteFactorResponse: + """ + Deletes a factor on a user. This will log the user out of all active + sessions (if the deleted factor was verified). There's no need to delete + unverified factors. + """ + raise NotImplementedError() # pragma: no cover diff --git a/gotrue/_async/gotrue_mfa_api.py b/gotrue/_async/gotrue_mfa_api.py new file mode 100644 index 00000000..a30c4c73 --- /dev/null +++ b/gotrue/_async/gotrue_mfa_api.py @@ -0,0 +1,94 @@ +from ..types import ( + AuthMFAChallengeResponse, + AuthMFAEnrollResponse, + AuthMFAGetAuthenticatorAssuranceLevelResponse, + AuthMFAListFactorsResponse, + AuthMFAUnenrollResponse, + AuthMFAVerifyResponse, + MFAChallengeAndVerifyParams, + MFAChallengeParams, + MFAEnrollParams, + MFAUnenrollParams, + MFAVerifyParams, +) + + +class AsyncGoTrueMFAAPI: + """ + Contains the full multi-factor authentication API. + """ + + async def enroll(self, params: MFAEnrollParams) -> AuthMFAEnrollResponse: + """ + Starts the enrollment process for a new Multi-Factor Authentication + factor. This method creates a new factor in the 'unverified' state. + Present the QR code or secret to the user and ask them to add it to their + authenticator app. Ask the user to provide you with an authenticator code + from their app and verify it by calling challenge and then verify. + + The first successful verification of an unverified factor activates the + factor. All other sessions are logged out and the current one gets an + `aal2` authenticator level. + """ + raise NotImplementedError() # pragma: no cover + + async def challenge(self, params: MFAChallengeParams) -> AuthMFAChallengeResponse: + """ + Prepares a challenge used to verify that a user has access to a MFA + factor. Provide the challenge ID and verification code by calling `verify`. + """ + raise NotImplementedError() # pragma: no cover + + async def challenge_and_verify( + self, + params: MFAChallengeAndVerifyParams, + ) -> AuthMFAVerifyResponse: + """ + Helper method which creates a challenge and immediately uses the given code + to verify against it thereafter. The verification code is provided by the + user by entering a code seen in their authenticator app. + """ + raise NotImplementedError() # pragma: no cover + + async def verify(self, params: MFAVerifyParams) -> AuthMFAVerifyResponse: + """ + Verifies a verification code against a challenge. The verification code is + provided by the user by entering a code seen in their authenticator app. + """ + raise NotImplementedError() # pragma: no cover + + async def unenroll(self, params: MFAUnenrollParams) -> AuthMFAUnenrollResponse: + """ + Unenroll removes a MFA factor. Unverified factors can safely be ignored + and it's not necessary to unenroll them. Unenrolling a verified MFA factor + cannot be done from a session with an `aal1` authenticator level. + """ + raise NotImplementedError() # pragma: no cover + + async def list_factors(self) -> AuthMFAListFactorsResponse: + """ + Returns the list of MFA factors enabled for this user. For most use cases + you should consider using `get_authenticator_assurance_level`. + + This uses a cached version of the factors and avoids incurring a network call. + If you need to update this list, call `get_user` first. + """ + raise NotImplementedError() # pragma: no cover + + async def get_authenticator_assurance_level( + self, + ) -> AuthMFAGetAuthenticatorAssuranceLevelResponse: + """ + Returns the Authenticator Assurance Level (AAL) for the active session. + + - `aal1` (or `null`) means that the user's identity has been verified only + with a conventional login (email+password, OTP, magic link, social login, + etc.). + - `aal2` means that the user's identity has been verified both with a + conventional login and at least one MFA factor. + + Although this method returns a promise, it's fairly quick (microseconds) + and rarely uses the network. You can use this to check whether the current + user needs to be shown a screen to verify their MFA factors. + """ + raise NotImplementedError() # pragma: no cover From 34141a855df646963aa2bb72a88c27c690672e0b Mon Sep 17 00:00:00 2001 From: "joel@joellee.org" Date: Wed, 14 Dec 2022 11:42:59 +0530 Subject: [PATCH 02/16] fix: add initial mfa bindings --- gotrue/_sync/api.py | 72 ++++++++++----------- gotrue/_sync/client.py | 32 +++++----- gotrue/_sync/gotrue_admin_mfa_api.py | 32 ++++++++++ gotrue/_sync/gotrue_mfa_api.py | 94 ++++++++++++++++++++++++++++ 4 files changed, 178 insertions(+), 52 deletions(-) create mode 100644 gotrue/_sync/gotrue_admin_mfa_api.py create mode 100644 gotrue/_sync/gotrue_mfa_api.py diff --git a/gotrue/_sync/api.py b/gotrue/_sync/api.py index 6ab024f9..abbdc480 100644 --- a/gotrue/_sync/api.py +++ b/gotrue/_sync/api.py @@ -60,8 +60,8 @@ def create_user(self, *, attributes: UserAttributes) -> User: Raises ------ - error : APIError - If an error occurs + APIError + If an error occurs. """ headers = self.headers data = attributes.dict() @@ -82,8 +82,8 @@ def list_users(self) -> List[User]: Raises ------ - error : APIError - If an error occurs + APIError + If an error occurs. """ headers = self.headers url = f"{self.url}/admin/users" @@ -125,8 +125,8 @@ def sign_up_with_email( Raises ------ - error : APIError - If an error occurs + APIError + If an error occurs. """ headers = self.headers query_string = "" @@ -164,9 +164,10 @@ def sign_in_with_email( Raises ------ - error : APIError - If an error occurs + APIError + If an error occurs. """ + headers = self.headers query_string = "?grant_type=password" if redirect_to: @@ -203,8 +204,8 @@ def sign_up_with_phone( Raises ------ - error : APIError - If an error occurs + APIError + If an error occurs. """ headers = self.headers data = {"phone": phone, "password": password, "data": data} @@ -235,8 +236,8 @@ def sign_in_with_phone( Raises ------ - error : APIError - If an error occurs + APIError + If an error occurs. """ data = {"phone": phone, "password": password} url = f"{self.url}/token?grant_type=password" @@ -262,8 +263,8 @@ def send_magic_link_email( Raises ------ - error : APIError - If an error occurs + APIError + If an error occurs. """ headers = self.headers query_string = "" @@ -285,8 +286,8 @@ def send_mobile_otp(self, *, phone: str, create_user: bool) -> None: Raises ------ - error : APIError - If an error occurs + APIError + If an error occurs. """ headers = self.headers data = {"phone": phone, "create_user": create_user} @@ -320,8 +321,8 @@ def verify_mobile_otp( Raises ------ - error : APIError - If an error occurs + APIError + If an error occurs. """ headers = self.headers data = { @@ -362,8 +363,8 @@ def invite_user_by_email( Raises ------ - error : APIError - If an error occurs + APIError + If an error occurs. """ headers = self.headers query_string = "" @@ -392,8 +393,8 @@ def reset_password_for_email( Raises ------ - error : APIError - If an error occurs + APIError + If an error occurs. """ headers = self.headers query_string = "" @@ -422,8 +423,7 @@ def _create_request_headers(self, *, jwt: str) -> Dict[str, str]: The headers required for a successful request statement with the supabase backend. """ - headers = {**self.headers} - headers["Authorization"] = f"Bearer {jwt}" + headers = {**self.headers, "Authorization": f"Bearer {jwt}"} return headers def sign_out(self, *, jwt: str) -> None: @@ -463,8 +463,8 @@ def get_url_for_provider( Raises ------ - error : APIError - If an error occurs + APIError + If an error occurs. """ url_params = [f"provider={encode_uri_component(provider)}"] if redirect_to: @@ -489,8 +489,8 @@ def get_user(self, *, jwt: str) -> User: Raises ------ - error : APIError - If an error occurs + APIError + If an error occurs. """ headers = self._create_request_headers(jwt=jwt) url = f"{self.url}/user" @@ -520,8 +520,8 @@ def update_user( Raises ------ - error : APIError - If an error occurs + APIError + If an error occurs. """ headers = self._create_request_headers(jwt=jwt) data = attributes.dict() @@ -549,8 +549,8 @@ def delete_user(self, *, uid: str, jwt: str) -> None: Raises ------ - error : APIError - If an error occurs + APIError + If an error occurs. """ headers = self._create_request_headers(jwt=jwt) url = f"{self.url}/admin/users/{uid}" @@ -572,8 +572,8 @@ def refresh_access_token(self, *, refresh_token: str) -> Session: Raises ------ - error : APIError - If an error occurs + APIError + If an error occurs. """ data = {"refresh_token": refresh_token} url = f"{self.url}/token?grant_type=refresh_token" @@ -614,8 +614,8 @@ def generate_link( Raises ------ - error : APIError - If an error occurs + APIError + If an error occurs. """ headers = self.headers data = { diff --git a/gotrue/_sync/client.py b/gotrue/_sync/client.py index f7f20455..3ce6b2de 100644 --- a/gotrue/_sync/client.py +++ b/gotrue/_sync/client.py @@ -121,8 +121,8 @@ def sign_up( Raises ------ - error : APIError - If an error occurs + APIError + If an error occurs. """ self._remove_session() @@ -202,8 +202,8 @@ def sign_in( Raises ------ - error : APIError - If an error occurs + APIError + If an error occurs. """ self._remove_session() if email: @@ -268,8 +268,8 @@ def verify_otp( Raises ------ - error : APIError - If an error occurs + APIError + If an error occurs. """ self._remove_session() response = self.api.verify_mobile_otp( @@ -315,8 +315,8 @@ def update(self, *, attributes: Union[UserAttributesDict, UserAttributes]) -> Us Raises ------ - error : APIError - If an error occurs + APIError + If an error occurs. """ if not self.current_session: raise ValueError("Not logged in.") @@ -350,8 +350,8 @@ def set_session(self, *, refresh_token: str) -> Session: Raises ------ - error : APIError - If an error occurs + APIError + If an error occurs. """ response = self.api.refresh_access_token(refresh_token=refresh_token) self._save_session(session=response) @@ -374,8 +374,8 @@ def set_auth(self, *, access_token: str) -> Session: Raises ------ - error : APIError - If an error occurs + APIError + If an error occurs. """ session = Session( access_token=access_token, @@ -416,8 +416,8 @@ def get_session_from_url( Raises ------ - error : APIError - If an error occurs + APIError + If an error occurs. """ data = urlparse(url) query = parse_qs(data.query) @@ -492,8 +492,8 @@ def on_auth_state_change( Raises ------ - error : APIError - If an error occurs + APIError + If an error occurs. """ unique_id = uuid4() subscription = Subscription( diff --git a/gotrue/_sync/gotrue_admin_mfa_api.py b/gotrue/_sync/gotrue_admin_mfa_api.py new file mode 100644 index 00000000..c3fcfc8e --- /dev/null +++ b/gotrue/_sync/gotrue_admin_mfa_api.py @@ -0,0 +1,32 @@ +from ..types import ( + AuthMFAAdminDeleteFactorParams, + AuthMFAAdminDeleteFactorResponse, + AuthMFAAdminListFactorsParams, + AuthMFAAdminListFactorsResponse, +) + + +class SyncGoTrueAdminMFAAPI: + """ + Contains the full multi-factor authentication administration API. + """ + + def list_factors( + self, + params: AuthMFAAdminListFactorsParams, + ) -> AuthMFAAdminListFactorsResponse: + """ + Lists all factors attached to a user. + """ + raise NotImplementedError() # pragma: no cover + + def delete_factor( + self, + params: AuthMFAAdminDeleteFactorParams, + ) -> AuthMFAAdminDeleteFactorResponse: + """ + Deletes a factor on a user. This will log the user out of all active + sessions (if the deleted factor was verified). There's no need to delete + unverified factors. + """ + raise NotImplementedError() # pragma: no cover diff --git a/gotrue/_sync/gotrue_mfa_api.py b/gotrue/_sync/gotrue_mfa_api.py new file mode 100644 index 00000000..16bec8d5 --- /dev/null +++ b/gotrue/_sync/gotrue_mfa_api.py @@ -0,0 +1,94 @@ +from ..types import ( + AuthMFAChallengeResponse, + AuthMFAEnrollResponse, + AuthMFAGetAuthenticatorAssuranceLevelResponse, + AuthMFAListFactorsResponse, + AuthMFAUnenrollResponse, + AuthMFAVerifyResponse, + MFAChallengeAndVerifyParams, + MFAChallengeParams, + MFAEnrollParams, + MFAUnenrollParams, + MFAVerifyParams, +) + + +class SyncGoTrueMFAAPI: + """ + Contains the full multi-factor authentication API. + """ + + def enroll(self, params: MFAEnrollParams) -> AuthMFAEnrollResponse: + """ + Starts the enrollment process for a new Multi-Factor Authentication + factor. This method creates a new factor in the 'unverified' state. + Present the QR code or secret to the user and ask them to add it to their + authenticator app. Ask the user to provide you with an authenticator code + from their app and verify it by calling challenge and then verify. + + The first successful verification of an unverified factor activates the + factor. All other sessions are logged out and the current one gets an + `aal2` authenticator level. + """ + raise NotImplementedError() # pragma: no cover + + def challenge(self, params: MFAChallengeParams) -> AuthMFAChallengeResponse: + """ + Prepares a challenge used to verify that a user has access to a MFA + factor. Provide the challenge ID and verification code by calling `verify`. + """ + raise NotImplementedError() # pragma: no cover + + def challenge_and_verify( + self, + params: MFAChallengeAndVerifyParams, + ) -> AuthMFAVerifyResponse: + """ + Helper method which creates a challenge and immediately uses the given code + to verify against it thereafter. The verification code is provided by the + user by entering a code seen in their authenticator app. + """ + raise NotImplementedError() # pragma: no cover + + def verify(self, params: MFAVerifyParams) -> AuthMFAVerifyResponse: + """ + Verifies a verification code against a challenge. The verification code is + provided by the user by entering a code seen in their authenticator app. + """ + raise NotImplementedError() # pragma: no cover + + def unenroll(self, params: MFAUnenrollParams) -> AuthMFAUnenrollResponse: + """ + Unenroll removes a MFA factor. Unverified factors can safely be ignored + and it's not necessary to unenroll them. Unenrolling a verified MFA factor + cannot be done from a session with an `aal1` authenticator level. + """ + raise NotImplementedError() # pragma: no cover + + def list_factors(self) -> AuthMFAListFactorsResponse: + """ + Returns the list of MFA factors enabled for this user. For most use cases + you should consider using `get_authenticator_assurance_level`. + + This uses a cached version of the factors and avoids incurring a network call. + If you need to update this list, call `get_user` first. + """ + raise NotImplementedError() # pragma: no cover + + def get_authenticator_assurance_level( + self, + ) -> AuthMFAGetAuthenticatorAssuranceLevelResponse: + """ + Returns the Authenticator Assurance Level (AAL) for the active session. + + - `aal1` (or `null`) means that the user's identity has been verified only + with a conventional login (email+password, OTP, magic link, social login, + etc.). + - `aal2` means that the user's identity has been verified both with a + conventional login and at least one MFA factor. + + Although this method returns a promise, it's fairly quick (microseconds) + and rarely uses the network. You can use this to check whether the current + user needs to be shown a screen to verify their MFA factors. + """ + raise NotImplementedError() # pragma: no cover From 82b78018ef8f888ad8c7e3da1e97fae40ba1b60b Mon Sep 17 00:00:00 2001 From: "joel@joellee.org" Date: Wed, 14 Dec 2022 16:09:47 +0530 Subject: [PATCH 03/16] fix: add types --- gotrue/_async/gotrue_mfa_api.py | 17 +++ gotrue/types.py | 242 +++++++++++++++++++++++++++++++- 2 files changed, 257 insertions(+), 2 deletions(-) diff --git a/gotrue/_async/gotrue_mfa_api.py b/gotrue/_async/gotrue_mfa_api.py index a30c4c73..d8489142 100644 --- a/gotrue/_async/gotrue_mfa_api.py +++ b/gotrue/_async/gotrue_mfa_api.py @@ -1,3 +1,4 @@ +from ..http_clients import AsyncClient from ..types import ( AuthMFAChallengeResponse, AuthMFAEnrollResponse, @@ -18,6 +19,19 @@ class AsyncGoTrueMFAAPI: Contains the full multi-factor authentication API. """ + def __init__( + self, + *, + url: str, + headers: Dict[str, str], + cookie_options: CookieOptions, + http_client: Optional[AsyncClient] = None + ): + self.url = url + self.headers = headers + self.cookie_options = cookie_options + self.http_client = http_client or AsyncClient() + async def enroll(self, params: MFAEnrollParams) -> AuthMFAEnrollResponse: """ Starts the enrollment process for a new Multi-Factor Authentication @@ -30,6 +44,9 @@ async def enroll(self, params: MFAEnrollParams) -> AuthMFAEnrollResponse: factor. All other sessions are logged out and the current one gets an `aal2` authenticator level. """ + headers = self.headers + data = {} + raise NotImplementedError() # pragma: no cover async def challenge(self, params: MFAChallengeParams) -> AuthMFAChallengeResponse: diff --git a/gotrue/types.py b/gotrue/types.py index fdfd294d..09b27ba3 100644 --- a/gotrue/types.py +++ b/gotrue/types.py @@ -8,9 +8,9 @@ from uuid import UUID if sys.version_info >= (3, 8): - from typing import TypedDict + from typing import Literal, NotRequired, TypedDict else: - from typing_extensions import TypedDict + from typing_extensions import Literal, TypedDict, NotRequired from httpx import Response from pydantic import BaseModel, root_validator @@ -82,6 +82,37 @@ class User(BaseModelFromResponse): email_change_sent_at: Optional[datetime] = None new_phone: Optional[str] = None phone_change_sent_at: Optional[datetime] = None + factors: Union[List[Factor], None] = None + + +class Factor(BaseModel): + """ + A MFA factor. + """ + + id: str + """ + ID of the factor. + """ + friendly_name: Union[str, None] = None + """ + Friendly name of the factor, useful to disambiguate between multiple factors. + """ + factor_type: Union[Literal["totp"], str] + """ + Type of factor. Only `totp` supported with this version but may change in + future versions. + """ + status: Literal["verified", "unverified"] + """ + Factor's status. + """ + created_at: datetime + updated_at: datetime + + +class UpdatableFactorAttributes(TypedDict): + friendly_name: str class UserAttributes(BaseModelFromResponse): @@ -165,3 +196,210 @@ class UserAttributesDict(TypedDict, total=False): password: Optional[str] email_change_token: Optional[str] data: Optional[Any] + + +class MFAEnrollParams(TypedDict): + factor_type: Literal["totp"] + issuer: NotRequired[str] + friendly_name: NotRequired[str] + + +class MFAUnenrollParams(TypedDict): + factor_id: str + """ + ID of the factor being unenrolled. + """ + + +class MFAVerifyParams(TypedDict): + factor_id: str + """ + ID of the factor being verified. + """ + challenge_id: str + """ + ID of the challenge being verified. + """ + code: str + """ + Verification code provided by the user. + """ + + +class MFAChallengeParams(TypedDict): + factor_id: str + """ + ID of the factor to be challenged. + """ + + +class MFAChallengeAndVerifyParams(TypedDict): + factor_id: str + """ + ID of the factor being verified. + """ + code: str + """ + Verification code provided by the user. + """ + + +class AuthMFAVerifyResponse(BaseModel): + access_token: str + """ + New access token (JWT) after successful verification. + """ + token_type: str + """ + Type of token, typically `Bearer`. + """ + expires_in: int + """ + Number of seconds in which the access token will expire. + """ + refresh_token: str + """ + Refresh token you can use to obtain new access tokens when expired. + """ + user: User + """ + Updated user profile. + """ + + +class AuthMFAEnrollResponseTotp(BaseModel): + qr_code: str + """ + Contains a QR code encoding the authenticator URI. You can + convert it to a URL by prepending `data:image/svg+xml;utf-8,` to + the value. Avoid logging this value to the console. + """ + secret: str + """ + The TOTP secret (also encoded in the QR code). Show this secret + in a password-style field to the user, in case they are unable to + scan the QR code. Avoid logging this value to the console. + """ + uri: str + """ + The authenticator URI encoded within the QR code, should you need + to use it. Avoid loggin this value to the console. + """ + + +class AuthMFAEnrollResponse(BaseModel): + id: str + """ + ID of the factor that was just enrolled (in an unverified state). + """ + type: Literal["totp"] + """ + Type of MFA factor. Only `totp` supported for now. + """ + totp: AuthMFAEnrollResponseTotp + """ + TOTP enrollment information. + """ + + +class AuthMFAUnenrollResponse(BaseModel): + id: str + """ + ID of the factor that was successfully unenrolled. + """ + + +class AuthMFAChallengeResponse(BaseModel): + id: str + """ + ID of the newly created challenge. + """ + expires_at: int + """ + Timestamp in UNIX seconds when this challenge will no longer be usable. + """ + + +class AuthMFAListFactorsResponse(BaseModel): + all: List[Factor] + """ + All available factors (verified and unverified). + """ + totp: List[Factor] + """ + Only verified TOTP factors. (A subset of `all`.) + """ + + +AuthenticatorAssuranceLevels = Literal["aal1", "aal2"] + + +class AuthMFAGetAuthenticatorAssuranceLevelResponse(BaseModel): + current_level: Union[AuthenticatorAssuranceLevels, None] = None + """ + Current AAL level of the session. + """ + next_level: Union[AuthenticatorAssuranceLevels, None] = None + """ + Next possible AAL level for the session. If the next level is higher + than the current one, the user should go through MFA. + """ + current_authentication_methods: List[AMREntry] + """ + A list of all authentication methods attached to this session. Use + the information here to detect the last time a user verified a + factor, for example if implementing a step-up scenario. + """ + + +class AuthMFAAdminDeleteFactorResponse(BaseModel): + id: str + """ + ID of the factor that was successfully deleted. + """ + + +class AuthMFAAdminDeleteFactorParams(TypedDict): + id: str + """ + ID of the MFA factor to delete. + """ + user_id: str + """ + ID of the user whose factor is being deleted. + """ + + +class AuthMFAAdminListFactorsResponse(BaseModel): + factors: List[Factor] + """ + All factors attached to the user. + """ + + +class AuthMFAAdminListFactorsParams(TypedDict): + user_id: str + """ + ID of the user for which to list all MFA factors. + """ + + +class DecodedJWTDict(TypedDict): + exp: NotRequired[int] + aal: NotRequired[Union[AuthenticatorAssuranceLevels, None]] + amr: NotRequired[Union[List[AMREntry], None]] + + +AMREntry.update_forward_refs() +UserResponse.update_forward_refs() +Factor.update_forward_refs() +User.update_forward_refs() +AuthMFAVerifyResponse.update_forward_refs() +AuthMFAEnrollResponseTotp.update_forward_refs() +AuthMFAEnrollResponse.update_forward_refs() +AuthMFAUnenrollResponse.update_forward_refs() +AuthMFAChallengeResponse.update_forward_refs() +AuthMFAListFactorsResponse.update_forward_refs() +AuthMFAGetAuthenticatorAssuranceLevelResponse.update_forward_refs() +AuthMFAAdminDeleteFactorResponse.update_forward_refs() +AuthMFAAdminListFactorsResponse.update_forward_refs() From 04af894d4bf211b301b1b581e5216a7ea515a505 Mon Sep 17 00:00:00 2001 From: "joel@joellee.org" Date: Tue, 17 Jan 2023 16:43:06 +0800 Subject: [PATCH 04/16] feat: port over relevant bindings from `next` --- .github/workflows/ci.yml | 6 +- gotrue/_async/gotrue_client.py | 841 ++++++++++++++++++++++++++++++++ gotrue/_async/gotrue_mfa_api.py | 11 +- gotrue/_sync/gotrue_client.py | 839 +++++++++++++++++++++++++++++++ gotrue/_sync/gotrue_mfa_api.py | 26 +- 5 files changed, 1715 insertions(+), 8 deletions(-) create mode 100644 gotrue/_async/gotrue_client.py create mode 100644 gotrue/_sync/gotrue_client.py diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index abc798b7..e89cd29f 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -8,7 +8,7 @@ jobs: strategy: matrix: os: [ubuntu-latest] - python-version: [3.7, 3.8, 3.9, "3.10"] + python-version: [3.7, 3.8, 3.9, "3.10", "3.11"] runs-on: ${{ matrix.os }} steps: - name: Clone Repository @@ -18,9 +18,9 @@ jobs: with: python-version: ${{ matrix.python-version }} - name: Set up Poetry - uses: abatilo/actions-poetry@v2.2.0 + uses: abatilo/actions-poetry@v2.1.6 with: - poetry-version: 1.1.13 + poetry-version: 1.2.2 - name: Run Tests run: make run_tests - name: Upload Coverage diff --git a/gotrue/_async/gotrue_client.py b/gotrue/_async/gotrue_client.py new file mode 100644 index 00000000..97f5f14c --- /dev/null +++ b/gotrue/_async/gotrue_client.py @@ -0,0 +1,841 @@ +from __future__ import annotations + +from json import loads +from time import time +from typing import Callable, Dict, List, Tuple, Union +from urllib.parse import parse_qs, quote, urlencode, urlparse +from uuid import uuid4 + +from ..constants import ( + DEFAULT_HEADERS, + EXPIRY_MARGIN, + GOTRUE_URL, + MAX_RETRIES, + RETRY_INTERVAL, + STORAGE_KEY, +) +from ..errors import ( + AuthImplicitGrantRedirectError, + AuthInvalidCredentialsError, + AuthRetryableError, + AuthSessionMissingError, +) +from ..helpers import decode_jwt_payload, parse_auth_response, parse_user_response +from ..http_clients import AsyncClient +from ..timer import Timer +from ..types import ( + AuthChangeEvent, + AuthenticatorAssuranceLevels, + AuthMFAChallengeResponse, + AuthMFAEnrollResponse, + AuthMFAGetAuthenticatorAssuranceLevelResponse, + AuthMFAListFactorsResponse, + AuthMFAUnenrollResponse, + AuthMFAVerifyResponse, + AuthResponse, + DecodedJWTDict, + MFAChallengeAndVerifyParams, + MFAChallengeParams, + MFAEnrollParams, + MFAUnenrollParams, + MFAVerifyParams, + OAuthResponse, + Options, + Provider, + Session, + SignInWithOAuthCredentials, + SignInWithPasswordCredentials, + SignInWithPasswordlessCredentials, + SignUpWithPasswordCredentials, + Subscription, + UserAttributes, + UserResponse, + VerifyOtpParams, +) +from .gotrue_admin_api import AsyncGoTrueAdminAPI +from .gotrue_base_api import AsyncGoTrueBaseAPI +from .gotrue_mfa_api import AsyncGoTrueMFAAPI +from .storage import AsyncMemoryStorage, AsyncSupportedStorage + + +class AsyncGoTrueClient(AsyncGoTrueBaseAPI): + def __init__( + self, + *, + url: Union[str, None] = None, + headers: Union[Dict[str, str], None] = None, + storage_key: Union[str, None] = None, + auto_refresh_token: bool = True, + persist_session: bool = True, + storage: Union[AsyncSupportedStorage, None] = None, + http_client: Union[AsyncClient, None] = None, + ) -> None: + AsyncGoTrueBaseAPI.__init__( + self, + url=url or GOTRUE_URL, + headers=headers or DEFAULT_HEADERS, + http_client=http_client, + ) + self._storage_key = storage_key or STORAGE_KEY + self._auto_refresh_token = auto_refresh_token + self._persist_session = persist_session + self._storage = storage or AsyncMemoryStorage() + self._in_memory_session: Union[Session, None] = None + self._refresh_token_timer: Union[Timer, None] = None + self._network_retries = 0 + self._state_change_emitters: Dict[str, Subscription] = {} + + self.admin = AsyncGoTrueAdminAPI( + url=self._url, + headers=self._headers, + http_client=self._http_client, + ) + self.mfa = AsyncGoTrueMFAAPI() + self.mfa.challenge = self._challenge + self.mfa.challenge_and_verify = self._challenge_and_verify + self.mfa.enroll = self._enroll + self.mfa.get_authenticator_assurance_level = ( + self._get_authenticator_assurance_level + ) + self.mfa.list_factors = self._list_factors + self.mfa.unenroll = self._unenroll + self.mfa.verify = self._verify + + # Initializations + + async def initialize(self, *, url: Union[str, None] = None) -> None: + if url and self._is_implicit_grant_flow(url): + await self.initialize_from_url(url) + else: + await self.initialize_from_storage() + + async def initialize_from_storage(self) -> None: + return await self._recover_and_refresh() + + async def initialize_from_url(self, url: str) -> None: + try: + if self._is_implicit_grant_flow(url): + session, redirect_type = await self._get_session_from_url(url) + await self._save_session(session) + self._notify_all_subscribers("SIGNED_IN", session) + if redirect_type == "recovery": + self._notify_all_subscribers("PASSWORD_RECOVERY", session) + except Exception as e: + await self._remove_session() + raise e + + # Public methods + + async def sign_up( + self, + credentials: SignUpWithPasswordCredentials, + ) -> AuthResponse: + """ + Creates a new user. + """ + await self._remove_session() + email = credentials.get("email") + phone = credentials.get("phone") + password = credentials.get("password") + options = credentials.get("options", {}) + redirect_to = options.get("redirect_to") + data = options.get("data") or {} + captcha_token = options.get("captcha_token") + if email: + response = await self._request( + "POST", + "signup", + body={ + "email": email, + "password": password, + "data": data, + "gotrue_meta_security": { + "captcha_token": captcha_token, + }, + }, + redirect_to=redirect_to, + xform=parse_auth_response, + ) + elif phone: + response = await self._request( + "POST", + "signup", + body={ + "phone": phone, + "password": password, + "data": data, + "gotrue_meta_security": { + "captcha_token": captcha_token, + }, + }, + xform=parse_auth_response, + ) + else: + raise AuthInvalidCredentialsError( + "You must provide either an email or phone number and a password" + ) + if response.session: + await self._save_session(response.session) + self._notify_all_subscribers("SIGNED_IN", response.session) + return response + + async def sign_in_with_password( + self, + credentials: SignInWithPasswordCredentials, + ) -> AuthResponse: + """ + Log in an existing user with an email or phone and password. + """ + await self._remove_session() + email = credentials.get("email") + phone = credentials.get("phone") + password = credentials.get("password") + options = credentials.get("options", {}) + data = options.get("data") or {} + captcha_token = options.get("captcha_token") + if email: + response = await self._request( + "POST", + "token", + body={ + "email": email, + "password": password, + "data": data, + "gotrue_meta_security": { + "captcha_token": captcha_token, + }, + }, + query={ + "grant_type": "password", + }, + xform=parse_auth_response, + ) + elif phone: + response = await self._request( + "POST", + "token", + body={ + "phone": phone, + "password": password, + "data": data, + "gotrue_meta_security": { + "captcha_token": captcha_token, + }, + }, + query={ + "grant_type": "password", + }, + xform=parse_auth_response, + ) + else: + raise AuthInvalidCredentialsError( + "You must provide either an email or phone number and a password" + ) + if response.session: + await self._save_session(response.session) + self._notify_all_subscribers("SIGNED_IN", response.session) + return response + + async def sign_in_with_oauth( + self, + credentials: SignInWithOAuthCredentials, + ) -> OAuthResponse: + """ + Log in an existing user via a third-party provider. + """ + await self._remove_session() + provider = credentials.get("provider") + options = credentials.get("options", {}) + redirect_to = options.get("redirect_to") + scopes = options.get("scopes") + params = options.get("query_params", {}) + if redirect_to: + params["redirect_to"] = redirect_to + if scopes: + params["scopes"] = scopes + url = self._get_url_for_provider(provider, params) + return OAuthResponse(provider=provider, url=url) + + async def sign_in_with_otp( + self, + credentials: SignInWithPasswordlessCredentials, + ) -> AuthResponse: + """ + Log in a user using magiclink or a one-time password (OTP). + + If the `{{ .ConfirmationURL }}` variable is specified in + the email template, a magiclink will be sent. + + If the `{{ .Token }}` variable is specified in the email + template, an OTP will be sent. + + If you're using phone sign-ins, only an OTP will be sent. + You won't be able to send a magiclink for phone sign-ins. + """ + await self._remove_session() + email = credentials.get("email") + phone = credentials.get("phone") + options = credentials.get("options", {}) + email_redirect_to = options.get("email_redirect_to") + should_create_user = options.get("create_user", True) + data = options.get("data") + captcha_token = options.get("captcha_token") + if email: + return await self._request( + "POST", + "otp", + body={ + "email": email, + "data": data, + "create_user": should_create_user, + "gotrue_meta_security": { + "captcha_token": captcha_token, + }, + }, + redirect_to=email_redirect_to, + xform=parse_auth_response, + ) + if phone: + return await self._request( + "POST", + "otp", + body={ + "phone": phone, + "data": data, + "create_user": should_create_user, + "gotrue_meta_security": { + "captcha_token": captcha_token, + }, + }, + xform=parse_auth_response, + ) + raise AuthInvalidCredentialsError( + "You must provide either an email or phone number" + ) + + async def verify_otp(self, params: VerifyOtpParams) -> AuthResponse: + """ + Log in a user given a User supplied OTP received via mobile. + """ + await self._remove_session() + response = await self._request( + "POST", + "verify", + body={ + "gotrue_meta_security": { + "captcha_token": params.get("options", {}).get("captcha_token"), + }, + **params, + }, + redirect_to=params.get("options", {}).get("redirect_to"), + xform=parse_auth_response, + ) + if response.session: + await self._save_session(response.session) + self._notify_all_subscribers("SIGNED_IN", response.session) + return response + + async def get_session(self) -> Union[Session, None]: + """ + Returns the session, refreshing it if necessary. + + The session returned can be null if the session is not detected which + can happen in the event a user is not signed-in or has logged out. + """ + current_session: Union[Session, None] = None + if self._persist_session: + maybe_session = await self._storage.get_item(self._storage_key) + current_session = self._get_valid_session(maybe_session) + if not current_session: + await self._remove_session() + else: + current_session = self._in_memory_session + if not current_session: + return None + time_now = round(time()) + has_expired = ( + current_session.expires_at <= time_now + EXPIRY_MARGIN + if current_session.expires_at + else False + ) + return ( + await self._call_refresh_token(current_session.refresh_token) + if has_expired + else current_session + ) + + async def get_user(self, jwt: Union[str, None] = None) -> UserResponse: + """ + Gets the current user details if there is an existing session. + + Takes in an optional access token `jwt`. If no `jwt` is provided, + `get_user()` will attempt to get the `jwt` from the current session. + """ + if not jwt: + session = await self.get_session() + if session: + jwt = session.access_token + return await self._request("GET", "user", jwt=jwt, xform=parse_user_response) + + async def update_user(self, attributes: UserAttributes) -> UserResponse: + """ + Updates user data, if there is a logged in user. + """ + session = await self.get_session() + if not session: + raise AuthSessionMissingError() + response = await self._request( + "PUT", + "user", + body=attributes, + jwt=session.access_token, + xform=parse_user_response, + ) + session.user = response.user + await self._save_session(session) + self._notify_all_subscribers("USER_UPDATED", session) + return response + + async def set_session(self, access_token: str, refresh_token: str) -> AuthResponse: + """ + Sets the session data from the current session. If the current session + is expired, `set_session` will take care of refreshing it to obtain a + new session. + + If the refresh token in the current session is invalid and the current + session has expired, an error will be thrown. + + If the current session does not contain at `expires_at` field, + `set_session` will use the exp claim defined in the access token. + + The current session that minimally contains an access token, + refresh token and a user. + """ + time_now = round(time()) + expires_at = time_now + has_expired = True + session: Union[Session, None] = None + if access_token and access_token.split(".")[1]: + payload = self._decode_jwt(access_token) + exp = payload.get("exp") + if exp: + expires_at = int(exp) + has_expired = expires_at <= time_now + if has_expired: + if not refresh_token: + raise AuthSessionMissingError() + response = await self._refresh_access_token(refresh_token) + if not response.session: + return AuthResponse() + session = response.session + else: + response = await self.get_user(access_token) + session = Session( + access_token=access_token, + refresh_token=refresh_token, + user=response.user, + token_type="bearer", + expires_in=expires_at - time_now, + expires_at=expires_at, + ) + await self._save_session(session) + self._notify_all_subscribers("TOKEN_REFRESHED", session) + return AuthResponse(session=session, user=response.user) + + async def refresh_session( + self, refresh_token: Union[str, None] = None + ) -> AuthResponse: + """ + Returns a new session, regardless of expiry status. + + Takes in an optional current session. If not passed in, then refreshSession() + will attempt to retrieve it from getSession(). If the current session's + refresh token is invalid, an error will be thrown. + """ + if not refresh_token: + session = await self.get_session() + if session: + refresh_token = session.refresh_token + if not refresh_token: + raise AuthSessionMissingError() + session = await self._call_refresh_token(refresh_token) + return AuthResponse(session=session, user=session.user) + + async def sign_out(self) -> None: + """ + Inside a browser context, `sign_out` will remove the logged in user from the + browser session and log them out - removing all items from localstorage and + then trigger a `"SIGNED_OUT"` event. + + For server-side management, you can revoke all refresh tokens for a user by + passing a user's JWT through to `api.sign_out`. + + There is no way to revoke a user's access token jwt until it expires. + It is recommended to set a shorter expiry on the jwt for this reason. + """ + session = await self.get_session() + access_token = session.access_token if session else None + if access_token: + await self.admin.sign_out(access_token) + await self._remove_session() + self._notify_all_subscribers("SIGNED_OUT", None) + + async def on_auth_state_change( + self, + callback: Callable[[AuthChangeEvent, Union[Session, None]], None], + ) -> Subscription: + """ + Receive a notification every time an auth event happens. + """ + unique_id = str(uuid4()) + + def _unsubscribe() -> None: + self._state_change_emitters.pop(unique_id) + + subscription = Subscription( + id=unique_id, + callback=callback, + unsubscribe=_unsubscribe, + ) + self._state_change_emitters[unique_id] = subscription + return subscription + + async def reset_password_email( + self, + email: str, + options: Options = {}, + ) -> None: + """ + Sends a password reset request to an email address. + """ + await self._request( + "POST", + "recover", + body={ + "email": email, + "gotrue_meta_security": { + "captcha_token": options.get("captcha_token"), + }, + }, + redirect_to=options.get("redirect_to"), + ) + + # MFA methods + + async def _enroll(self, params: MFAEnrollParams) -> AuthMFAEnrollResponse: + session = await self.get_session() + if not session: + raise AuthSessionMissingError() + response = await self._request( + "POST", + "factors", + body=params, + jwt=session.access_token, + xform=AuthMFAEnrollResponse.parse_obj, + ) + if response.totp.qr_code: + response.totp.qr_code = f"data:image/svg+xml;utf-8,{response.totp.qr_code}" + return response + + async def _challenge(self, params: MFAChallengeParams) -> AuthMFAChallengeResponse: + session = await self.get_session() + if not session: + raise AuthSessionMissingError() + return await self._request( + "POST", + f"factors/{params.get('factor_id')}/challenge", + jwt=session.access_token, + xform=AuthMFAChallengeResponse.parse_obj, + ) + + async def _challenge_and_verify( + self, + params: MFAChallengeAndVerifyParams, + ) -> AuthMFAVerifyResponse: + response = await self._challenge( + { + "factor_id": params.get("factor_id"), + } + ) + return await self._verify( + { + "factor_id": params.get("factor_id"), + "challenge_id": response.id, + "code": params.get("code"), + } + ) + + async def _verify(self, params: MFAVerifyParams) -> AuthMFAVerifyResponse: + session = await self.get_session() + if not session: + raise AuthSessionMissingError() + response = await self._request( + "POST", + f"factors/{params.get('factor_id')}/verify", + body=params, + jwt=session.access_token, + xform=AuthMFAVerifyResponse.parse_obj, + ) + session = Session.parse_obj(response.dict()) + await self._save_session(session) + self._notify_all_subscribers("MFA_CHALLENGE_VERIFIED", session) + return response + + async def _unenroll(self, params: MFAUnenrollParams) -> AuthMFAUnenrollResponse: + session = await self.get_session() + if not session: + raise AuthSessionMissingError() + return await self._request( + "DELETE", + f"factors/{params.get('factor_id')}", + jwt=session.access_token, + xform=AuthMFAUnenrollResponse.parse_obj, + ) + + async def _list_factors(self) -> AuthMFAListFactorsResponse: + response = await self.get_user() + all = response.user.factors or [] + totp = [f for f in all if f.factor_type == "totp" and f.status == "verified"] + return AuthMFAListFactorsResponse(all=all, totp=totp) + + async def _get_authenticator_assurance_level( + self, + ) -> AuthMFAGetAuthenticatorAssuranceLevelResponse: + session = await self.get_session() + if not session: + return AuthMFAGetAuthenticatorAssuranceLevelResponse( + current_level=None, + next_level=None, + current_authentication_methods=[], + ) + payload = self._decode_jwt(session.access_token) + current_level: Union[AuthenticatorAssuranceLevels, None] = None + if payload.get("aal"): + current_level = payload.get("aal") + verified_factors = [ + f for f in session.user.factors or [] if f.status == "verified" + ] + next_level = "aal2" if verified_factors else current_level + current_authentication_methods = payload.get("amr") or [] + return AuthMFAGetAuthenticatorAssuranceLevelResponse( + current_level=current_level, + next_level=next_level, + current_authentication_methods=current_authentication_methods, + ) + + # Private methods + + async def _remove_session(self) -> None: + if self._persist_session: + await self._storage.remove_item(self._storage_key) + else: + self._in_memory_session = None + if self._refresh_token_timer: + self._refresh_token_timer.cancel() + self._refresh_token_timer = None + + async def _get_session_from_url( + self, + url: str, + ) -> Tuple[Session, Union[str, None]]: + if not self._is_implicit_grant_flow(url): + raise AuthImplicitGrantRedirectError("Not a valid implicit grant flow url.") + result = urlparse(url) + params = parse_qs(result.query) + error_description = self._get_param(params, "error_description") + if error_description: + error_code = self._get_param(params, "error_code") + error = self._get_param(params, "error") + if not error_code: + raise AuthImplicitGrantRedirectError("No error_code detected.") + if not error: + raise AuthImplicitGrantRedirectError("No error detected.") + raise AuthImplicitGrantRedirectError( + error_description, + {"code": error_code, "error": error}, + ) + provider_token = self._get_param(params, "provider_token") + provider_refresh_token = self._get_param(params, "provider_refresh_token") + access_token = self._get_param(params, "access_token") + if not access_token: + raise AuthImplicitGrantRedirectError("No access_token detected.") + expires_in = self._get_param(params, "expires_in") + if not expires_in: + raise AuthImplicitGrantRedirectError("No expires_in detected.") + refresh_token = self._get_param(params, "refresh_token") + if not refresh_token: + raise AuthImplicitGrantRedirectError("No refresh_token detected.") + token_type = self._get_param(params, "token_type") + if not token_type: + raise AuthImplicitGrantRedirectError("No token_type detected.") + time_now = round(time()) + expires_at = time_now + int(expires_in) + user = await self.get_user(access_token) + session = Session( + provider_token=provider_token, + provider_refresh_token=provider_refresh_token, + access_token=access_token, + expires_in=int(expires_in), + expires_at=expires_at, + refresh_token=refresh_token, + token_type=token_type, + user=user.user, + ) + redirect_type = self._get_param(params, "type") + return session, redirect_type + + async def _recover_and_refresh(self) -> None: + raw_session = await self._storage.get_item(self._storage_key) + current_session = self._get_valid_session(raw_session) + if not current_session: + if raw_session: + await self._remove_session() + return + time_now = round(time()) + expires_at = current_session.expires_at + if expires_at and expires_at < time_now + EXPIRY_MARGIN: + refresh_token = current_session.refresh_token + if self._auto_refresh_token and refresh_token: + self._network_retries += 1 + try: + await self._call_refresh_token(refresh_token) + self._network_retries = 0 + except Exception as e: + if ( + isinstance(e, AuthRetryableError) + and self._network_retries < MAX_RETRIES + ): + if self._refresh_token_timer: + self._refresh_token_timer.cancel() + self._refresh_token_timer = Timer( + (RETRY_INTERVAL ** (self._network_retries * 100)), + self._recover_and_refresh, + ) + self._refresh_token_timer.start() + return + await self._remove_session() + return + if self._persist_session: + await self._save_session(current_session) + self._notify_all_subscribers("SIGNED_IN", current_session) + + async def _call_refresh_token(self, refresh_token: str) -> Session: + if not refresh_token: + raise AuthSessionMissingError() + response = await self._refresh_access_token(refresh_token) + if not response.session: + raise AuthSessionMissingError() + await self._save_session(response.session) + self._notify_all_subscribers("TOKEN_REFRESHED", response.session) + return response.session + + async def _refresh_access_token(self, refresh_token: str) -> AuthResponse: + return await self._request( + "POST", + "token", + query={"grant_type": "refresh_token"}, + body={"refresh_token": refresh_token}, + xform=parse_auth_response, + ) + + async def _save_session(self, session: Session) -> None: + if not self._persist_session: + self._in_memory_session = session + expire_at = session.expires_at + if expire_at: + time_now = round(time()) + expire_in = expire_at - time_now + refresh_duration_before_expires = ( + EXPIRY_MARGIN if expire_in > EXPIRY_MARGIN else 0.5 + ) + value = (expire_in - refresh_duration_before_expires) * 1000 + await self._start_auto_refresh_token(value) + if self._persist_session and session.expires_at: + await self._storage.set_item(self._storage_key, session.json()) + + async def _start_auto_refresh_token(self, value: float) -> None: + if self._refresh_token_timer: + self._refresh_token_timer.cancel() + self._refresh_token_timer = None + if value <= 0 or not self._auto_refresh_token: + return + + async def refresh_token_function(): + self._network_retries += 1 + try: + session = await self.get_session() + if session: + await self._call_refresh_token(session.refresh_token) + self._network_retries = 0 + except Exception as e: + if ( + isinstance(e, AuthRetryableError) + and self._network_retries < MAX_RETRIES + ): + await self._start_auto_refresh_token( + RETRY_INTERVAL ** (self._network_retries * 100) + ) + + self._refresh_token_timer = Timer(value, refresh_token_function) + self._refresh_token_timer.start() + + def _notify_all_subscribers( + self, + event: AuthChangeEvent, + session: Union[Session, None], + ) -> None: + for subscription in self._state_change_emitters.values(): + subscription.callback(event, session) + + def _get_valid_session( + self, + raw_session: Union[str, None], + ) -> Union[Session, None]: + if not raw_session: + return None + data = loads(raw_session) + if not data: + return None + if not data.get("access_token"): + return None + if not data.get("refresh_token"): + return None + if not data.get("expires_at"): + return None + try: + expires_at = int(data["expires_at"]) + data["expires_at"] = expires_at + except ValueError: + return None + try: + return Session.parse_obj(data) + except Exception: + return None + + def _get_param( + self, + query_params: Dict[str, List[str]], + name: str, + ) -> Union[str, None]: + return query_params[name][0] if name in query_params else None + + def _is_implicit_grant_flow(self, url: str) -> bool: + result = urlparse(url) + params = parse_qs(result.query) + return "access_token" in params or "error_description" in params + + def _get_url_for_provider( + self, + provider: Provider, + params: Dict[str, str], + ) -> str: + params = {k: quote(v) for k, v in params.items()} + params["provider"] = quote(provider) + query = urlencode(params) + return f"{self._url}/authorize?{query}" + + def _decode_jwt(self, jwt: str) -> DecodedJWTDict: + """ + Decodes a JWT (without performing any validation). + """ + return decode_jwt_payload(jwt) diff --git a/gotrue/_async/gotrue_mfa_api.py b/gotrue/_async/gotrue_mfa_api.py index d8489142..1b5218f7 100644 --- a/gotrue/_async/gotrue_mfa_api.py +++ b/gotrue/_async/gotrue_mfa_api.py @@ -45,16 +45,21 @@ async def enroll(self, params: MFAEnrollParams) -> AuthMFAEnrollResponse: `aal2` authenticator level. """ headers = self.headers - data = {} + response = await self.http_client.post(url, json=params, headers=headers) - raise NotImplementedError() # pragma: no cover + return check_response(response) # pragma: no cover async def challenge(self, params: MFAChallengeParams) -> AuthMFAChallengeResponse: """ Prepares a challenge used to verify that a user has access to a MFA factor. Provide the challenge ID and verification code by calling `verify`. """ - raise NotImplementedError() # pragma: no cover + + # TODO(joel): fetch session + headers = self.headers + + response = await self.http_client.post(url, json=params, headers=headers) + return check_response(response) # pragma: no cover async def challenge_and_verify( self, diff --git a/gotrue/_sync/gotrue_client.py b/gotrue/_sync/gotrue_client.py new file mode 100644 index 00000000..be893d90 --- /dev/null +++ b/gotrue/_sync/gotrue_client.py @@ -0,0 +1,839 @@ +from __future__ import annotations + +from json import loads +from time import time +from typing import Callable, Dict, List, Tuple, Union +from urllib.parse import parse_qs, quote, urlencode, urlparse +from uuid import uuid4 + +from ..constants import ( + DEFAULT_HEADERS, + EXPIRY_MARGIN, + GOTRUE_URL, + MAX_RETRIES, + RETRY_INTERVAL, + STORAGE_KEY, +) +from ..errors import ( + AuthImplicitGrantRedirectError, + AuthInvalidCredentialsError, + AuthRetryableError, + AuthSessionMissingError, +) +from ..helpers import decode_jwt_payload, parse_auth_response, parse_user_response +from ..http_clients import SyncClient +from ..timer import Timer +from ..types import ( + AuthChangeEvent, + AuthenticatorAssuranceLevels, + AuthMFAChallengeResponse, + AuthMFAEnrollResponse, + AuthMFAGetAuthenticatorAssuranceLevelResponse, + AuthMFAListFactorsResponse, + AuthMFAUnenrollResponse, + AuthMFAVerifyResponse, + AuthResponse, + DecodedJWTDict, + MFAChallengeAndVerifyParams, + MFAChallengeParams, + MFAEnrollParams, + MFAUnenrollParams, + MFAVerifyParams, + OAuthResponse, + Options, + Provider, + Session, + SignInWithOAuthCredentials, + SignInWithPasswordCredentials, + SignInWithPasswordlessCredentials, + SignUpWithPasswordCredentials, + Subscription, + UserAttributes, + UserResponse, + VerifyOtpParams, +) +from .gotrue_admin_api import SyncGoTrueAdminAPI +from .gotrue_base_api import SyncGoTrueBaseAPI +from .gotrue_mfa_api import SyncGoTrueMFAAPI +from .storage import SyncMemoryStorage, SyncSupportedStorage + + +class SyncGoTrueClient(SyncGoTrueBaseAPI): + def __init__( + self, + *, + url: Union[str, None] = None, + headers: Union[Dict[str, str], None] = None, + storage_key: Union[str, None] = None, + auto_refresh_token: bool = True, + persist_session: bool = True, + storage: Union[SyncSupportedStorage, None] = None, + http_client: Union[SyncClient, None] = None, + ) -> None: + SyncGoTrueBaseAPI.__init__( + self, + url=url or GOTRUE_URL, + headers=headers or DEFAULT_HEADERS, + http_client=http_client, + ) + self._storage_key = storage_key or STORAGE_KEY + self._auto_refresh_token = auto_refresh_token + self._persist_session = persist_session + self._storage = storage or SyncMemoryStorage() + self._in_memory_session: Union[Session, None] = None + self._refresh_token_timer: Union[Timer, None] = None + self._network_retries = 0 + self._state_change_emitters: Dict[str, Subscription] = {} + + self.admin = SyncGoTrueAdminAPI( + url=self._url, + headers=self._headers, + http_client=self._http_client, + ) + self.mfa = SyncGoTrueMFAAPI() + self.mfa.challenge = self._challenge + self.mfa.challenge_and_verify = self._challenge_and_verify + self.mfa.enroll = self._enroll + self.mfa.get_authenticator_assurance_level = ( + self._get_authenticator_assurance_level + ) + self.mfa.list_factors = self._list_factors + self.mfa.unenroll = self._unenroll + self.mfa.verify = self._verify + + # Initializations + + def initialize(self, *, url: Union[str, None] = None) -> None: + if url and self._is_implicit_grant_flow(url): + self.initialize_from_url(url) + else: + self.initialize_from_storage() + + def initialize_from_storage(self) -> None: + return self._recover_and_refresh() + + def initialize_from_url(self, url: str) -> None: + try: + if self._is_implicit_grant_flow(url): + session, redirect_type = self._get_session_from_url(url) + self._save_session(session) + self._notify_all_subscribers("SIGNED_IN", session) + if redirect_type == "recovery": + self._notify_all_subscribers("PASSWORD_RECOVERY", session) + except Exception as e: + self._remove_session() + raise e + + # Public methods + + def sign_up( + self, + credentials: SignUpWithPasswordCredentials, + ) -> AuthResponse: + """ + Creates a new user. + """ + self._remove_session() + email = credentials.get("email") + phone = credentials.get("phone") + password = credentials.get("password") + options = credentials.get("options", {}) + redirect_to = options.get("redirect_to") + data = options.get("data") or {} + captcha_token = options.get("captcha_token") + if email: + response = self._request( + "POST", + "signup", + body={ + "email": email, + "password": password, + "data": data, + "gotrue_meta_security": { + "captcha_token": captcha_token, + }, + }, + redirect_to=redirect_to, + xform=parse_auth_response, + ) + elif phone: + response = self._request( + "POST", + "signup", + body={ + "phone": phone, + "password": password, + "data": data, + "gotrue_meta_security": { + "captcha_token": captcha_token, + }, + }, + xform=parse_auth_response, + ) + else: + raise AuthInvalidCredentialsError( + "You must provide either an email or phone number and a password" + ) + if response.session: + self._save_session(response.session) + self._notify_all_subscribers("SIGNED_IN", response.session) + return response + + def sign_in_with_password( + self, + credentials: SignInWithPasswordCredentials, + ) -> AuthResponse: + """ + Log in an existing user with an email or phone and password. + """ + self._remove_session() + email = credentials.get("email") + phone = credentials.get("phone") + password = credentials.get("password") + options = credentials.get("options", {}) + data = options.get("data") or {} + captcha_token = options.get("captcha_token") + if email: + response = self._request( + "POST", + "token", + body={ + "email": email, + "password": password, + "data": data, + "gotrue_meta_security": { + "captcha_token": captcha_token, + }, + }, + query={ + "grant_type": "password", + }, + xform=parse_auth_response, + ) + elif phone: + response = self._request( + "POST", + "token", + body={ + "phone": phone, + "password": password, + "data": data, + "gotrue_meta_security": { + "captcha_token": captcha_token, + }, + }, + query={ + "grant_type": "password", + }, + xform=parse_auth_response, + ) + else: + raise AuthInvalidCredentialsError( + "You must provide either an email or phone number and a password" + ) + if response.session: + self._save_session(response.session) + self._notify_all_subscribers("SIGNED_IN", response.session) + return response + + def sign_in_with_oauth( + self, + credentials: SignInWithOAuthCredentials, + ) -> OAuthResponse: + """ + Log in an existing user via a third-party provider. + """ + self._remove_session() + provider = credentials.get("provider") + options = credentials.get("options", {}) + redirect_to = options.get("redirect_to") + scopes = options.get("scopes") + params = options.get("query_params", {}) + if redirect_to: + params["redirect_to"] = redirect_to + if scopes: + params["scopes"] = scopes + url = self._get_url_for_provider(provider, params) + return OAuthResponse(provider=provider, url=url) + + def sign_in_with_otp( + self, + credentials: SignInWithPasswordlessCredentials, + ) -> AuthResponse: + """ + Log in a user using magiclink or a one-time password (OTP). + + If the `{{ .ConfirmationURL }}` variable is specified in + the email template, a magiclink will be sent. + + If the `{{ .Token }}` variable is specified in the email + template, an OTP will be sent. + + If you're using phone sign-ins, only an OTP will be sent. + You won't be able to send a magiclink for phone sign-ins. + """ + self._remove_session() + email = credentials.get("email") + phone = credentials.get("phone") + options = credentials.get("options", {}) + email_redirect_to = options.get("email_redirect_to") + should_create_user = options.get("create_user", True) + data = options.get("data") + captcha_token = options.get("captcha_token") + if email: + return self._request( + "POST", + "otp", + body={ + "email": email, + "data": data, + "create_user": should_create_user, + "gotrue_meta_security": { + "captcha_token": captcha_token, + }, + }, + redirect_to=email_redirect_to, + xform=parse_auth_response, + ) + if phone: + return self._request( + "POST", + "otp", + body={ + "phone": phone, + "data": data, + "create_user": should_create_user, + "gotrue_meta_security": { + "captcha_token": captcha_token, + }, + }, + xform=parse_auth_response, + ) + raise AuthInvalidCredentialsError( + "You must provide either an email or phone number" + ) + + def verify_otp(self, params: VerifyOtpParams) -> AuthResponse: + """ + Log in a user given a User supplied OTP received via mobile. + """ + self._remove_session() + response = self._request( + "POST", + "verify", + body={ + "gotrue_meta_security": { + "captcha_token": params.get("options", {}).get("captcha_token"), + }, + **params, + }, + redirect_to=params.get("options", {}).get("redirect_to"), + xform=parse_auth_response, + ) + if response.session: + self._save_session(response.session) + self._notify_all_subscribers("SIGNED_IN", response.session) + return response + + def get_session(self) -> Union[Session, None]: + """ + Returns the session, refreshing it if necessary. + + The session returned can be null if the session is not detected which + can happen in the event a user is not signed-in or has logged out. + """ + current_session: Union[Session, None] = None + if self._persist_session: + maybe_session = self._storage.get_item(self._storage_key) + current_session = self._get_valid_session(maybe_session) + if not current_session: + self._remove_session() + else: + current_session = self._in_memory_session + if not current_session: + return None + time_now = round(time()) + has_expired = ( + current_session.expires_at <= time_now + EXPIRY_MARGIN + if current_session.expires_at + else False + ) + return ( + self._call_refresh_token(current_session.refresh_token) + if has_expired + else current_session + ) + + def get_user(self, jwt: Union[str, None] = None) -> UserResponse: + """ + Gets the current user details if there is an existing session. + + Takes in an optional access token `jwt`. If no `jwt` is provided, + `get_user()` will attempt to get the `jwt` from the current session. + """ + if not jwt: + session = self.get_session() + if session: + jwt = session.access_token + return self._request("GET", "user", jwt=jwt, xform=parse_user_response) + + def update_user(self, attributes: UserAttributes) -> UserResponse: + """ + Updates user data, if there is a logged in user. + """ + session = self.get_session() + if not session: + raise AuthSessionMissingError() + response = self._request( + "PUT", + "user", + body=attributes, + jwt=session.access_token, + xform=parse_user_response, + ) + session.user = response.user + self._save_session(session) + self._notify_all_subscribers("USER_UPDATED", session) + return response + + def set_session(self, access_token: str, refresh_token: str) -> AuthResponse: + """ + Sets the session data from the current session. If the current session + is expired, `set_session` will take care of refreshing it to obtain a + new session. + + If the refresh token in the current session is invalid and the current + session has expired, an error will be thrown. + + If the current session does not contain at `expires_at` field, + `set_session` will use the exp claim defined in the access token. + + The current session that minimally contains an access token, + refresh token and a user. + """ + time_now = round(time()) + expires_at = time_now + has_expired = True + session: Union[Session, None] = None + if access_token and access_token.split(".")[1]: + payload = self._decode_jwt(access_token) + exp = payload.get("exp") + if exp: + expires_at = int(exp) + has_expired = expires_at <= time_now + if has_expired: + if not refresh_token: + raise AuthSessionMissingError() + response = self._refresh_access_token(refresh_token) + if not response.session: + return AuthResponse() + session = response.session + else: + response = self.get_user(access_token) + session = Session( + access_token=access_token, + refresh_token=refresh_token, + user=response.user, + token_type="bearer", + expires_in=expires_at - time_now, + expires_at=expires_at, + ) + self._save_session(session) + self._notify_all_subscribers("TOKEN_REFRESHED", session) + return AuthResponse(session=session, user=response.user) + + def refresh_session(self, refresh_token: Union[str, None] = None) -> AuthResponse: + """ + Returns a new session, regardless of expiry status. + + Takes in an optional current session. If not passed in, then refreshSession() + will attempt to retrieve it from getSession(). If the current session's + refresh token is invalid, an error will be thrown. + """ + if not refresh_token: + session = self.get_session() + if session: + refresh_token = session.refresh_token + if not refresh_token: + raise AuthSessionMissingError() + session = self._call_refresh_token(refresh_token) + return AuthResponse(session=session, user=session.user) + + def sign_out(self) -> None: + """ + Inside a browser context, `sign_out` will remove the logged in user from the + browser session and log them out - removing all items from localstorage and + then trigger a `"SIGNED_OUT"` event. + + For server-side management, you can revoke all refresh tokens for a user by + passing a user's JWT through to `api.sign_out`. + + There is no way to revoke a user's access token jwt until it expires. + It is recommended to set a shorter expiry on the jwt for this reason. + """ + session = self.get_session() + access_token = session.access_token if session else None + if access_token: + self.admin.sign_out(access_token) + self._remove_session() + self._notify_all_subscribers("SIGNED_OUT", None) + + def on_auth_state_change( + self, + callback: Callable[[AuthChangeEvent, Union[Session, None]], None], + ) -> Subscription: + """ + Receive a notification every time an auth event happens. + """ + unique_id = str(uuid4()) + + def _unsubscribe() -> None: + self._state_change_emitters.pop(unique_id) + + subscription = Subscription( + id=unique_id, + callback=callback, + unsubscribe=_unsubscribe, + ) + self._state_change_emitters[unique_id] = subscription + return subscription + + def reset_password_email( + self, + email: str, + options: Options = {}, + ) -> None: + """ + Sends a password reset request to an email address. + """ + self._request( + "POST", + "recover", + body={ + "email": email, + "gotrue_meta_security": { + "captcha_token": options.get("captcha_token"), + }, + }, + redirect_to=options.get("redirect_to"), + ) + + # MFA methods + + def _enroll(self, params: MFAEnrollParams) -> AuthMFAEnrollResponse: + session = self.get_session() + if not session: + raise AuthSessionMissingError() + response = self._request( + "POST", + "factors", + body=params, + jwt=session.access_token, + xform=AuthMFAEnrollResponse.parse_obj, + ) + if response.totp.qr_code: + response.totp.qr_code = f"data:image/svg+xml;utf-8,{response.totp.qr_code}" + return response + + def _challenge(self, params: MFAChallengeParams) -> AuthMFAChallengeResponse: + session = self.get_session() + if not session: + raise AuthSessionMissingError() + return self._request( + "POST", + f"factors/{params.get('factor_id')}/challenge", + jwt=session.access_token, + xform=AuthMFAChallengeResponse.parse_obj, + ) + + def _challenge_and_verify( + self, + params: MFAChallengeAndVerifyParams, + ) -> AuthMFAVerifyResponse: + response = self._challenge( + { + "factor_id": params.get("factor_id"), + } + ) + return self._verify( + { + "factor_id": params.get("factor_id"), + "challenge_id": response.id, + "code": params.get("code"), + } + ) + + def _verify(self, params: MFAVerifyParams) -> AuthMFAVerifyResponse: + session = self.get_session() + if not session: + raise AuthSessionMissingError() + response = self._request( + "POST", + f"factors/{params.get('factor_id')}/verify", + body=params, + jwt=session.access_token, + xform=AuthMFAVerifyResponse.parse_obj, + ) + session = Session.parse_obj(response.dict()) + self._save_session(session) + self._notify_all_subscribers("MFA_CHALLENGE_VERIFIED", session) + return response + + def _unenroll(self, params: MFAUnenrollParams) -> AuthMFAUnenrollResponse: + session = self.get_session() + if not session: + raise AuthSessionMissingError() + return self._request( + "DELETE", + f"factors/{params.get('factor_id')}", + jwt=session.access_token, + xform=AuthMFAUnenrollResponse.parse_obj, + ) + + def _list_factors(self) -> AuthMFAListFactorsResponse: + response = self.get_user() + all = response.user.factors or [] + totp = [f for f in all if f.factor_type == "totp" and f.status == "verified"] + return AuthMFAListFactorsResponse(all=all, totp=totp) + + def _get_authenticator_assurance_level( + self, + ) -> AuthMFAGetAuthenticatorAssuranceLevelResponse: + session = self.get_session() + if not session: + return AuthMFAGetAuthenticatorAssuranceLevelResponse( + current_level=None, + next_level=None, + current_authentication_methods=[], + ) + payload = self._decode_jwt(session.access_token) + current_level: Union[AuthenticatorAssuranceLevels, None] = None + if payload.get("aal"): + current_level = payload.get("aal") + verified_factors = [ + f for f in session.user.factors or [] if f.status == "verified" + ] + next_level = "aal2" if verified_factors else current_level + current_authentication_methods = payload.get("amr") or [] + return AuthMFAGetAuthenticatorAssuranceLevelResponse( + current_level=current_level, + next_level=next_level, + current_authentication_methods=current_authentication_methods, + ) + + # Private methods + + def _remove_session(self) -> None: + if self._persist_session: + self._storage.remove_item(self._storage_key) + else: + self._in_memory_session = None + if self._refresh_token_timer: + self._refresh_token_timer.cancel() + self._refresh_token_timer = None + + def _get_session_from_url( + self, + url: str, + ) -> Tuple[Session, Union[str, None]]: + if not self._is_implicit_grant_flow(url): + raise AuthImplicitGrantRedirectError("Not a valid implicit grant flow url.") + result = urlparse(url) + params = parse_qs(result.query) + error_description = self._get_param(params, "error_description") + if error_description: + error_code = self._get_param(params, "error_code") + error = self._get_param(params, "error") + if not error_code: + raise AuthImplicitGrantRedirectError("No error_code detected.") + if not error: + raise AuthImplicitGrantRedirectError("No error detected.") + raise AuthImplicitGrantRedirectError( + error_description, + {"code": error_code, "error": error}, + ) + provider_token = self._get_param(params, "provider_token") + provider_refresh_token = self._get_param(params, "provider_refresh_token") + access_token = self._get_param(params, "access_token") + if not access_token: + raise AuthImplicitGrantRedirectError("No access_token detected.") + expires_in = self._get_param(params, "expires_in") + if not expires_in: + raise AuthImplicitGrantRedirectError("No expires_in detected.") + refresh_token = self._get_param(params, "refresh_token") + if not refresh_token: + raise AuthImplicitGrantRedirectError("No refresh_token detected.") + token_type = self._get_param(params, "token_type") + if not token_type: + raise AuthImplicitGrantRedirectError("No token_type detected.") + time_now = round(time()) + expires_at = time_now + int(expires_in) + user = self.get_user(access_token) + session = Session( + provider_token=provider_token, + provider_refresh_token=provider_refresh_token, + access_token=access_token, + expires_in=int(expires_in), + expires_at=expires_at, + refresh_token=refresh_token, + token_type=token_type, + user=user.user, + ) + redirect_type = self._get_param(params, "type") + return session, redirect_type + + def _recover_and_refresh(self) -> None: + raw_session = self._storage.get_item(self._storage_key) + current_session = self._get_valid_session(raw_session) + if not current_session: + if raw_session: + self._remove_session() + return + time_now = round(time()) + expires_at = current_session.expires_at + if expires_at and expires_at < time_now + EXPIRY_MARGIN: + refresh_token = current_session.refresh_token + if self._auto_refresh_token and refresh_token: + self._network_retries += 1 + try: + self._call_refresh_token(refresh_token) + self._network_retries = 0 + except Exception as e: + if ( + isinstance(e, AuthRetryableError) + and self._network_retries < MAX_RETRIES + ): + if self._refresh_token_timer: + self._refresh_token_timer.cancel() + self._refresh_token_timer = Timer( + (RETRY_INTERVAL ** (self._network_retries * 100)), + self._recover_and_refresh, + ) + self._refresh_token_timer.start() + return + self._remove_session() + return + if self._persist_session: + self._save_session(current_session) + self._notify_all_subscribers("SIGNED_IN", current_session) + + def _call_refresh_token(self, refresh_token: str) -> Session: + if not refresh_token: + raise AuthSessionMissingError() + response = self._refresh_access_token(refresh_token) + if not response.session: + raise AuthSessionMissingError() + self._save_session(response.session) + self._notify_all_subscribers("TOKEN_REFRESHED", response.session) + return response.session + + def _refresh_access_token(self, refresh_token: str) -> AuthResponse: + return self._request( + "POST", + "token", + query={"grant_type": "refresh_token"}, + body={"refresh_token": refresh_token}, + xform=parse_auth_response, + ) + + def _save_session(self, session: Session) -> None: + if not self._persist_session: + self._in_memory_session = session + expire_at = session.expires_at + if expire_at: + time_now = round(time()) + expire_in = expire_at - time_now + refresh_duration_before_expires = ( + EXPIRY_MARGIN if expire_in > EXPIRY_MARGIN else 0.5 + ) + value = (expire_in - refresh_duration_before_expires) * 1000 + self._start_auto_refresh_token(value) + if self._persist_session and session.expires_at: + self._storage.set_item(self._storage_key, session.json()) + + def _start_auto_refresh_token(self, value: float) -> None: + if self._refresh_token_timer: + self._refresh_token_timer.cancel() + self._refresh_token_timer = None + if value <= 0 or not self._auto_refresh_token: + return + + def refresh_token_function(): + self._network_retries += 1 + try: + session = self.get_session() + if session: + self._call_refresh_token(session.refresh_token) + self._network_retries = 0 + except Exception as e: + if ( + isinstance(e, AuthRetryableError) + and self._network_retries < MAX_RETRIES + ): + self._start_auto_refresh_token( + RETRY_INTERVAL ** (self._network_retries * 100) + ) + + self._refresh_token_timer = Timer(value, refresh_token_function) + self._refresh_token_timer.start() + + def _notify_all_subscribers( + self, + event: AuthChangeEvent, + session: Union[Session, None], + ) -> None: + for subscription in self._state_change_emitters.values(): + subscription.callback(event, session) + + def _get_valid_session( + self, + raw_session: Union[str, None], + ) -> Union[Session, None]: + if not raw_session: + return None + data = loads(raw_session) + if not data: + return None + if not data.get("access_token"): + return None + if not data.get("refresh_token"): + return None + if not data.get("expires_at"): + return None + try: + expires_at = int(data["expires_at"]) + data["expires_at"] = expires_at + except ValueError: + return None + try: + return Session.parse_obj(data) + except Exception: + return None + + def _get_param( + self, + query_params: Dict[str, List[str]], + name: str, + ) -> Union[str, None]: + return query_params[name][0] if name in query_params else None + + def _is_implicit_grant_flow(self, url: str) -> bool: + result = urlparse(url) + params = parse_qs(result.query) + return "access_token" in params or "error_description" in params + + def _get_url_for_provider( + self, + provider: Provider, + params: Dict[str, str], + ) -> str: + params = {k: quote(v) for k, v in params.items()} + params["provider"] = quote(provider) + query = urlencode(params) + return f"{self._url}/authorize?{query}" + + def _decode_jwt(self, jwt: str) -> DecodedJWTDict: + """ + Decodes a JWT (without performing any validation). + """ + return decode_jwt_payload(jwt) diff --git a/gotrue/_sync/gotrue_mfa_api.py b/gotrue/_sync/gotrue_mfa_api.py index 16bec8d5..a429af23 100644 --- a/gotrue/_sync/gotrue_mfa_api.py +++ b/gotrue/_sync/gotrue_mfa_api.py @@ -1,3 +1,4 @@ +from ..http_clients import SyncClient from ..types import ( AuthMFAChallengeResponse, AuthMFAEnrollResponse, @@ -18,6 +19,19 @@ class SyncGoTrueMFAAPI: Contains the full multi-factor authentication API. """ + def __init__( + self, + *, + url: str, + headers: Dict[str, str], + cookie_options: CookieOptions, + http_client: Optional[SyncClient] = None + ): + self.url = url + self.headers = headers + self.cookie_options = cookie_options + self.http_client = http_client or SyncClient() + def enroll(self, params: MFAEnrollParams) -> AuthMFAEnrollResponse: """ Starts the enrollment process for a new Multi-Factor Authentication @@ -30,14 +44,22 @@ def enroll(self, params: MFAEnrollParams) -> AuthMFAEnrollResponse: factor. All other sessions are logged out and the current one gets an `aal2` authenticator level. """ - raise NotImplementedError() # pragma: no cover + headers = self.headers + response = self.http_client.post(url, json=params, headers=headers) + + return check_response(response) # pragma: no cover def challenge(self, params: MFAChallengeParams) -> AuthMFAChallengeResponse: """ Prepares a challenge used to verify that a user has access to a MFA factor. Provide the challenge ID and verification code by calling `verify`. """ - raise NotImplementedError() # pragma: no cover + + # TODO(joel): fetch session + headers = self.headers + + response = self.http_client.post(url, json=params, headers=headers) + return check_response(response) # pragma: no cover def challenge_and_verify( self, From 894e9067019ed14daaddba57def220d1c5ff990f Mon Sep 17 00:00:00 2001 From: Sourcery AI <> Date: Tue, 17 Jan 2023 08:44:40 +0000 Subject: [PATCH 05/16] 'Refactored by Sourcery' --- gotrue/_async/gotrue_client.py | 9 +++---- gotrue/_sync/gotrue_client.py | 44 ++++++++++++++++------------------ 2 files changed, 23 insertions(+), 30 deletions(-) diff --git a/gotrue/_async/gotrue_client.py b/gotrue/_async/gotrue_client.py index 97f5f14c..71a45c88 100644 --- a/gotrue/_async/gotrue_client.py +++ b/gotrue/_async/gotrue_client.py @@ -417,8 +417,7 @@ async def set_session(self, access_token: str, refresh_token: str) -> AuthRespon session: Union[Session, None] = None if access_token and access_token.split(".")[1]: payload = self._decode_jwt(access_token) - exp = payload.get("exp") - if exp: + if exp := payload.get("exp"): expires_at = int(exp) has_expired = expires_at <= time_now if has_expired: @@ -642,8 +641,7 @@ async def _get_session_from_url( raise AuthImplicitGrantRedirectError("Not a valid implicit grant flow url.") result = urlparse(url) params = parse_qs(result.query) - error_description = self._get_param(params, "error_description") - if error_description: + if error_description := self._get_param(params, "error_description"): error_code = self._get_param(params, "error_code") error = self._get_param(params, "error") if not error_code: @@ -741,8 +739,7 @@ async def _refresh_access_token(self, refresh_token: str) -> AuthResponse: async def _save_session(self, session: Session) -> None: if not self._persist_session: self._in_memory_session = session - expire_at = session.expires_at - if expire_at: + if expire_at := session.expires_at: time_now = round(time()) expire_in = expire_at - time_now refresh_duration_before_expires = ( diff --git a/gotrue/_sync/gotrue_client.py b/gotrue/_sync/gotrue_client.py index be893d90..43e41079 100644 --- a/gotrue/_sync/gotrue_client.py +++ b/gotrue/_sync/gotrue_client.py @@ -372,8 +372,7 @@ def get_user(self, jwt: Union[str, None] = None) -> UserResponse: `get_user()` will attempt to get the `jwt` from the current session. """ if not jwt: - session = self.get_session() - if session: + if session := self.get_session(): jwt = session.access_token return self._request("GET", "user", jwt=jwt, xform=parse_user_response) @@ -417,8 +416,7 @@ def set_session(self, access_token: str, refresh_token: str) -> AuthResponse: session: Union[Session, None] = None if access_token and access_token.split(".")[1]: payload = self._decode_jwt(access_token) - exp = payload.get("exp") - if exp: + if exp := payload.get("exp"): expires_at = int(exp) has_expired = expires_at <= time_now if has_expired: @@ -536,15 +534,15 @@ def _enroll(self, params: MFAEnrollParams) -> AuthMFAEnrollResponse: return response def _challenge(self, params: MFAChallengeParams) -> AuthMFAChallengeResponse: - session = self.get_session() - if not session: + if session := self.get_session(): + return self._request( + "POST", + f"factors/{params.get('factor_id')}/challenge", + jwt=session.access_token, + xform=AuthMFAChallengeResponse.parse_obj, + ) + else: raise AuthSessionMissingError() - return self._request( - "POST", - f"factors/{params.get('factor_id')}/challenge", - jwt=session.access_token, - xform=AuthMFAChallengeResponse.parse_obj, - ) def _challenge_and_verify( self, @@ -580,15 +578,15 @@ def _verify(self, params: MFAVerifyParams) -> AuthMFAVerifyResponse: return response def _unenroll(self, params: MFAUnenrollParams) -> AuthMFAUnenrollResponse: - session = self.get_session() - if not session: + if session := self.get_session(): + return self._request( + "DELETE", + f"factors/{params.get('factor_id')}", + jwt=session.access_token, + xform=AuthMFAUnenrollResponse.parse_obj, + ) + else: raise AuthSessionMissingError() - return self._request( - "DELETE", - f"factors/{params.get('factor_id')}", - jwt=session.access_token, - xform=AuthMFAUnenrollResponse.parse_obj, - ) def _list_factors(self) -> AuthMFAListFactorsResponse: response = self.get_user() @@ -640,8 +638,7 @@ def _get_session_from_url( raise AuthImplicitGrantRedirectError("Not a valid implicit grant flow url.") result = urlparse(url) params = parse_qs(result.query) - error_description = self._get_param(params, "error_description") - if error_description: + if error_description := self._get_param(params, "error_description"): error_code = self._get_param(params, "error_code") error = self._get_param(params, "error") if not error_code: @@ -739,8 +736,7 @@ def _refresh_access_token(self, refresh_token: str) -> AuthResponse: def _save_session(self, session: Session) -> None: if not self._persist_session: self._in_memory_session = session - expire_at = session.expires_at - if expire_at: + if expire_at := session.expires_at: time_now = round(time()) expire_in = expire_at - time_now refresh_duration_before_expires = ( From a458df36449d634c54bae1fadb46efe6d9075a80 Mon Sep 17 00:00:00 2001 From: "joel@joellee.org" Date: Sun, 22 Jan 2023 22:56:23 +0800 Subject: [PATCH 06/16] refactor: remove api file --- gotrue/_async/api.py | 642 ------------------------------------------- 1 file changed, 642 deletions(-) delete mode 100644 gotrue/_async/api.py diff --git a/gotrue/_async/api.py b/gotrue/_async/api.py deleted file mode 100644 index 77fbf84a..00000000 --- a/gotrue/_async/api.py +++ /dev/null @@ -1,642 +0,0 @@ -from __future__ import annotations - -from typing import Any, Dict, List, Optional, Union - -from pydantic import parse_obj_as - -from ..exceptions import APIError -from ..helpers import check_response, encode_uri_component -from ..http_clients import AsyncClient -from ..types import ( - CookieOptions, - LinkType, - Provider, - Session, - User, - UserAttributes, - determine_session_or_user_model_from_response, -) - - -class AsyncGoTrueAPI: - def __init__( - self, - *, - url: str, - headers: Dict[str, str], - cookie_options: CookieOptions, - http_client: Optional[AsyncClient] = None, - ) -> None: - """Initialise API class.""" - self.url = url - self.headers = headers - self.cookie_options = cookie_options - self.http_client = http_client or AsyncClient() - - async def __aenter__(self) -> AsyncGoTrueAPI: - return self - - async def __aexit__(self, exc_t, exc_v, exc_tb) -> None: - await self.close() - - async def close(self) -> None: - await self.http_client.aclose() - - async def create_user(self, *, attributes: UserAttributes) -> User: - """Creates a new user. - - This function should only be called on a server. - Never expose your `service_role` key in the browser. - - Parameters - ---------- - attributes: UserAttributes - The data you want to create the user with. - - Returns - ------- - response : User - The created user - - Raises - ------ - APIError - If an error occurs. - """ - headers = self.headers - data = attributes.dict() - url = f"{self.url}/admin/users" - response = await self.http_client.post(url, json=data, headers=headers) - return User.parse_response(response) - - async def list_users(self) -> List[User]: - """Get a list of users. - - This function should only be called on a server. - Never expose your `service_role` key in the browser. - - Returns - ------- - response : List[User] - A list of users - - Raises - ------ - APIError - If an error occurs. - """ - headers = self.headers - url = f"{self.url}/admin/users" - response = await self.http_client.get(url, headers=headers) - check_response(response) - users = response.json().get("users") - if users is None: - raise APIError("No users found in response", 400) - if not isinstance(users, list): - raise APIError("Expected a list of users", 400) - return parse_obj_as(List[User], users) - - async def sign_up_with_email( - self, - *, - email: str, - password: str, - redirect_to: Optional[str] = None, - data: Optional[Dict[str, Any]] = None, - ) -> Union[Session, User]: - """Creates a new user using their email address. - - Parameters - ---------- - email : str - The email address of the user. - password : str - The password of the user. - redirect_to : Optional[str] - A URL or mobile address to send the user to after they are confirmed. - data : Optional[Dict[str, Any]] - Optional user metadata. - - Returns - ------- - response : Union[Session, User] - A logged-in session if the server has "autoconfirm" ON - A user if the server has "autoconfirm" OFF - - Raises - ------ - APIError - If an error occurs. - """ - headers = self.headers - query_string = "" - if redirect_to: - redirect_to_encoded = encode_uri_component(redirect_to) - query_string = f"?redirect_to={redirect_to_encoded}" - data = {"email": email, "password": password, "data": data} - url = f"{self.url}/signup{query_string}" - response = await self.http_client.post(url, json=data, headers=headers) - SessionOrUserModel = determine_session_or_user_model_from_response(response) - return SessionOrUserModel.parse_response(response) - - async def sign_in_with_email( - self, - *, - email: str, - password: str, - redirect_to: Optional[str] = None, - ) -> Session: - """Logs in an existing user using their email address. - - Parameters - ---------- - email : str - The email address of the user. - password : str - The password of the user. - redirect_to : Optional[str] - A URL or mobile address to send the user to after they are confirmed. - - Returns - ------- - response : Session - A logged-in session - - Raises - ------ - APIError - If an error occurs. - """ - - headers = self.headers - query_string = "?grant_type=password" - if redirect_to: - redirect_to_encoded = encode_uri_component(redirect_to) - query_string += f"&redirect_to={redirect_to_encoded}" - data = {"email": email, "password": password} - url = f"{self.url}/token{query_string}" - response = await self.http_client.post(url, json=data, headers=headers) - return Session.parse_response(response) - - async def sign_up_with_phone( - self, - *, - phone: str, - password: str, - data: Optional[Dict[str, Any]] = None, - ) -> Union[Session, User]: - """Signs up a new user using their phone number and a password. - - Parameters - ---------- - phone : str - The phone number of the user. - password : str - The password of the user. - data : Optional[Dict[str, Any]] - Optional user metadata. - - Returns - ------- - response : Union[Session, User] - A logged-in session if the server has "autoconfirm" ON - A user if the server has "autoconfirm" OFF - - Raises - ------ - APIError - If an error occurs. - """ - headers = self.headers - data = {"phone": phone, "password": password, "data": data} - url = f"{self.url}/signup" - response = await self.http_client.post(url, json=data, headers=headers) - SessionOrUserModel = determine_session_or_user_model_from_response(response) - return SessionOrUserModel.parse_response(response) - - async def sign_in_with_phone( - self, - *, - phone: str, - password: str, - ) -> Session: - """Logs in an existing user using their phone number and password. - - Parameters - ---------- - phone : str - The phone number of the user. - password : str - The password of the user. - - Returns - ------- - response : Session - A logged-in session - - Raises - ------ - APIError - If an error occurs. - """ - data = {"phone": phone, "password": password} - url = f"{self.url}/token?grant_type=password" - headers = self.headers - response = await self.http_client.post(url, json=data, headers=headers) - return Session.parse_response(response) - - async def send_magic_link_email( - self, - *, - email: str, - create_user: bool, - redirect_to: Optional[str] = None, - ) -> None: - """Sends a magic login link to an email address. - - Parameters - ---------- - email : str - The email address of the user. - redirect_to : Optional[str] - A URL or mobile address to send the user to after they are confirmed. - - Raises - ------ - APIError - If an error occurs. - """ - headers = self.headers - query_string = "" - if redirect_to: - redirect_to_encoded = encode_uri_component(redirect_to) - query_string = f"?redirect_to={redirect_to_encoded}" - data = {"email": email, "create_user": create_user} - url = f"{self.url}/magiclink{query_string}" - response = await self.http_client.post(url, json=data, headers=headers) - return check_response(response) - - async def send_mobile_otp(self, *, phone: str, create_user: bool) -> None: - """Sends a mobile OTP via SMS. Will register the account if it doesn't already exist - - Parameters - ---------- - phone : str - The user's phone number WITH international prefix - - Raises - ------ - APIError - If an error occurs. - """ - headers = self.headers - data = {"phone": phone, "create_user": create_user} - url = f"{self.url}/otp" - response = await self.http_client.post(url, json=data, headers=headers) - return check_response(response) - - async def verify_mobile_otp( - self, - *, - phone: str, - token: str, - redirect_to: Optional[str] = None, - ) -> Union[Session, User]: - """Send User supplied Mobile OTP to be verified - - Parameters - ---------- - phone : str - The user's phone number WITH international prefix - token : str - Token that user was sent to their mobile phone - redirect_to : Optional[str] - A URL or mobile address to send the user to after they are confirmed. - - Returns - ------- - response : Union[Session, User] - A logged-in session if the server has "autoconfirm" ON - A user if the server has "autoconfirm" OFF - - Raises - ------ - APIError - If an error occurs. - """ - headers = self.headers - data = { - "phone": phone, - "token": token, - "type": "sms", - } - if redirect_to: - redirect_to_encoded = encode_uri_component(redirect_to) - data["redirect_to"] = redirect_to_encoded - url = f"{self.url}/verify" - response = await self.http_client.post(url, json=data, headers=headers) - SessionOrUserModel = determine_session_or_user_model_from_response(response) - return SessionOrUserModel.parse_response(response) - - async def invite_user_by_email( - self, - *, - email: str, - redirect_to: Optional[str] = None, - data: Optional[Dict[str, Any]] = None, - ) -> User: - """Sends an invite link to an email address. - - Parameters - ---------- - email : str - The email address of the user. - redirect_to : Optional[str] - A URL or mobile address to send the user to after they are confirmed. - data : Optional[Dict[str, Any]] - Optional user metadata. - - Returns - ------- - response : User - A user - - Raises - ------ - APIError - If an error occurs. - """ - headers = self.headers - query_string = "" - if redirect_to: - redirect_to_encoded = encode_uri_component(redirect_to) - query_string = f"?redirect_to={redirect_to_encoded}" - data = {"email": email, "data": data} - url = f"{self.url}/invite{query_string}" - response = await self.http_client.post(url, json=data, headers=headers) - return User.parse_response(response) - - async def reset_password_for_email( - self, - *, - email: str, - redirect_to: Optional[str] = None, - ) -> None: - """Sends a reset request to an email address. - - Parameters - ---------- - email : str - The email address of the user. - redirect_to : Optional[str] - A URL or mobile address to send the user to after they are confirmed. - - Raises - ------ - APIError - If an error occurs. - """ - headers = self.headers - query_string = "" - if redirect_to: - redirect_to_encoded = encode_uri_component(redirect_to) - query_string = f"?redirect_to={redirect_to_encoded}" - data = {"email": email} - url = f"{self.url}/recover{query_string}" - response = await self.http_client.post(url, json=data, headers=headers) - return check_response(response) - - def _create_request_headers(self, *, jwt: str) -> Dict[str, str]: - """Create temporary object. - - Create a temporary object with all configured headers and adds the - Authorization token to be used on request methods. - - Parameters - ---------- - jwt : str - A valid, logged-in JWT. - - Returns - ------- - headers : dict of str - The headers required for a successful request statement with the - supabase backend. - """ - headers = {**self.headers, "Authorization": f"Bearer {jwt}"} - return headers - - async def sign_out(self, *, jwt: str) -> None: - """Removes a logged-in session. - - Parameters - ---------- - jwt : str - A valid, logged-in JWT. - """ - headers = self._create_request_headers(jwt=jwt) - url = f"{self.url}/logout" - await self.http_client.post(url, headers=headers) - - async def get_url_for_provider( - self, - *, - provider: Provider, - redirect_to: Optional[str] = None, - scopes: Optional[str] = None, - ) -> str: - """Generates the relevant login URL for a third-party provider. - - Parameters - ---------- - provider : Provider - One of the providers supported by GoTrue. - redirect_to : Optional[str] - A URL or mobile address to send the user to after they are confirmed. - scopes : Optional[str] - A space-separated list of scopes granted to the OAuth application. - - Returns - ------- - url : str - The URL to redirect the user to. - - Raises - ------ - APIError - If an error occurs. - """ - url_params = [f"provider={encode_uri_component(provider)}"] - if redirect_to: - redirect_to_encoded = encode_uri_component(redirect_to) - url_params.append(f"redirect_to={redirect_to_encoded}") - if scopes: - url_params.append(f"scopes={encode_uri_component(scopes)}") - return f"{self.url}/authorize?{'&'.join(url_params)}" - - async def get_user(self, *, jwt: str) -> User: - """Gets the user details. - - Parameters - ---------- - jwt : str - A valid, logged-in JWT. - - Returns - ------- - response : User - A user - - Raises - ------ - APIError - If an error occurs. - """ - headers = self._create_request_headers(jwt=jwt) - url = f"{self.url}/user" - response = await self.http_client.get(url, headers=headers) - return User.parse_response(response) - - async def update_user( - self, - *, - jwt: str, - attributes: UserAttributes, - ) -> User: - """ - Updates the user data. - - Parameters - ---------- - jwt : str - A valid, logged-in JWT. - attributes : UserAttributes - The data you want to update. - - Returns - ------- - response : User - A user - - Raises - ------ - APIError - If an error occurs. - """ - headers = self._create_request_headers(jwt=jwt) - data = attributes.dict() - url = f"{self.url}/user" - response = await self.http_client.put(url, json=data, headers=headers) - return User.parse_response(response) - - async def delete_user(self, *, uid: str, jwt: str) -> None: - """Delete a user. Requires a `service_role` key. - - This function should only be called on a server. - Never expose your `service_role` key in the browser. - - Parameters - ---------- - uid : str - The user uid you want to remove. - jwt : str - A valid, logged-in JWT. - - Returns - ------- - response : User - A user - - Raises - ------ - APIError - If an error occurs. - """ - headers = self._create_request_headers(jwt=jwt) - url = f"{self.url}/admin/users/{uid}" - response = await self.http_client.delete(url, headers=headers) - return check_response(response) - - async def refresh_access_token(self, *, refresh_token: str) -> Session: - """Generates a new JWT. - - Parameters - ---------- - refresh_token : str - A valid refresh token that was returned on login. - - Returns - ------- - response : Session - A session - - Raises - ------ - APIError - If an error occurs. - """ - data = {"refresh_token": refresh_token} - url = f"{self.url}/token?grant_type=refresh_token" - headers = self.headers - response = await self.http_client.post(url, json=data, headers=headers) - return Session.parse_response(response) - - async def generate_link( - self, - *, - type: LinkType, - email: str, - password: Optional[str] = None, - redirect_to: Optional[str] = None, - data: Optional[Dict[str, Any]] = None, - ) -> Union[Session, User]: - """ - Generates links to be sent via email or other. - - Parameters - ---------- - type : LinkType - The link type ("signup" or "magiclink" or "recovery" or "invite"). - email : str - The user's email. - password : Optional[str] - User password. For signup only. - redirect_to : Optional[str] - The link type ("signup" or "magiclink" or "recovery" or "invite"). - data : Optional[Dict[str, Any]] - Optional user metadata. For signup only. - - Returns - ------- - response : Union[Session, User] - A logged-in session if the server has "autoconfirm" ON - A user if the server has "autoconfirm" OFF - - Raises - ------ - APIError - If an error occurs. - """ - headers = self.headers - data = { - "type": type, - "email": email, - "data": data, - } - if password: - data["password"] = password - if redirect_to: - redirect_to_encoded = encode_uri_component(redirect_to) - data["redirect_to"] = redirect_to_encoded - url = f"{self.url}/admin/generate_link" - response = await self.http_client.post(url, json=data, headers=headers) - SessionOrUserModel = determine_session_or_user_model_from_response(response) - return SessionOrUserModel.parse_response(response) - - async def set_auth_cookie(self, *, req, res): - """Stub for parity with JS api.""" - raise NotImplementedError("set_auth_cookie not implemented.") - - async def get_user_by_cookie(self, *, req): - """Stub for parity with JS api.""" - raise NotImplementedError("get_user_by_cookie not implemented.") From 3ffaa9eed8606237216ba3229c0ba152a1489880 Mon Sep 17 00:00:00 2001 From: "joel@joellee.org" Date: Sun, 22 Jan 2023 23:04:50 +0800 Subject: [PATCH 07/16] refactor: remove mfa related content for now --- .github/workflows/ci.yml | 2 +- gotrue/_async/gotrue_admin_api.py | 147 ++++++++++++++++++++++++++ gotrue/_async/gotrue_admin_mfa_api.py | 32 ------ gotrue/_async/gotrue_base_api.py | 120 +++++++++++++++++++++ gotrue/_async/gotrue_mfa_api.py | 116 -------------------- gotrue/_sync/gotrue_admin_api.py | 147 ++++++++++++++++++++++++++ gotrue/_sync/gotrue_base_api.py | 120 +++++++++++++++++++++ gotrue/_sync/gotrue_client.py | 35 +++--- 8 files changed, 553 insertions(+), 166 deletions(-) create mode 100644 gotrue/_async/gotrue_admin_api.py delete mode 100644 gotrue/_async/gotrue_admin_mfa_api.py create mode 100644 gotrue/_async/gotrue_base_api.py delete mode 100644 gotrue/_async/gotrue_mfa_api.py create mode 100644 gotrue/_sync/gotrue_admin_api.py create mode 100644 gotrue/_sync/gotrue_base_api.py diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 3971e0f4..8ed573e7 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -18,7 +18,7 @@ jobs: with: python-version: ${{ matrix.python-version }} - name: Set up Poetry - uses: abatilo/actions-poetry@v2.1.6 + uses: abatilo/actions-poetry@v2.2.0 with: poetry-version: 1.3.2 - name: Run Tests diff --git a/gotrue/_async/gotrue_admin_api.py b/gotrue/_async/gotrue_admin_api.py new file mode 100644 index 00000000..7236bd97 --- /dev/null +++ b/gotrue/_async/gotrue_admin_api.py @@ -0,0 +1,147 @@ +from __future__ import annotations + +from typing import Dict, List, Union + +from ..helpers import parse_link_response, parse_user_response +from ..http_clients import AsyncClient +from ..types import ( + AdminUserAttributes, + GenerateLinkParams, + GenerateLinkResponse, + Options, + User, + UserResponse, +) +from .gotrue_base_api import AsyncGoTrueBaseAPI + + +class AsyncGoTrueAdminAPI(AsyncGoTrueBaseAPI): + def __init__( + self, + *, + url: str = "", + headers: Dict[str, str] = {}, + http_client: Union[AsyncClient, None] = None, + ) -> None: + AsyncGoTrueBaseAPI.__init__( + self, + url=url, + headers=headers, + http_client=http_client, + ) + + async def sign_out(self, jwt: str) -> None: + """ + Removes a logged-in session. + """ + return await self._request( + "POST", + "logout", + jwt=jwt, + no_resolve_json=True, + ) + + async def invite_user_by_email( + self, + email: str, + options: Options = {}, + ) -> UserResponse: + """ + Sends an invite link to an email address. + """ + return await self._request( + "POST", + "invite", + body={"email": email, "data": options.get("data")}, + redirect_to=options.get("redirect_to"), + xform=parse_user_response, + ) + + async def generate_link(self, params: GenerateLinkParams) -> GenerateLinkResponse: + """ + Generates email links and OTPs to be sent via a custom email provider. + """ + return await self._request( + "POST", + "admin/generate_link", + body={ + "type": params.get("type"), + "email": params.get("email"), + "password": params.get("password"), + "new_email": params.get("new_email"), + "data": params.get("options", {}).get("data"), + }, + redirect_to=params.get("options", {}).get("redirect_to"), + xform=parse_link_response, + ) + + # User Admin API + + async def create_user(self, attributes: AdminUserAttributes) -> UserResponse: + """ + Creates a new user. + + This function should only be called on a server. + Never expose your `service_role` key in the browser. + """ + return await self._request( + "POST", + "admin/users", + body=attributes, + xform=parse_user_response, + ) + + async def list_users(self) -> List[User]: + """ + Get a list of users. + + This function should only be called on a server. + Never expose your `service_role` key in the browser. + """ + return await self._request( + "GET", + "admin/users", + xform=lambda data: [User.parse_obj(user) for user in data["users"]] + if "users" in data + else [], + ) + + async def get_user_by_id(self, uid: str) -> UserResponse: + """ + Get user by id. + + This function should only be called on a server. + Never expose your `service_role` key in the browser. + """ + return await self._request( + "GET", + f"admin/users/{uid}", + xform=parse_user_response, + ) + + async def update_user_by_id( + self, + uid: str, + attributes: AdminUserAttributes, + ) -> UserResponse: + """ + Updates the user data. + + This function should only be called on a server. + Never expose your `service_role` key in the browser. + """ + return await self._request( + "PUT", + f"admin/users/{uid}", + body=attributes, + xform=parse_user_response, + ) + + async def delete_user(self, id: str) -> None: + """ + Delete a user. Requires a `service_role` key. + + This function should only be called on a server. + Never expose your `service_role` key in the browser. + """ + return await self._request("DELETE", f"admin/users/{id}") diff --git a/gotrue/_async/gotrue_admin_mfa_api.py b/gotrue/_async/gotrue_admin_mfa_api.py deleted file mode 100644 index ca812fcd..00000000 --- a/gotrue/_async/gotrue_admin_mfa_api.py +++ /dev/null @@ -1,32 +0,0 @@ -from ..types import ( - AuthMFAAdminDeleteFactorParams, - AuthMFAAdminDeleteFactorResponse, - AuthMFAAdminListFactorsParams, - AuthMFAAdminListFactorsResponse, -) - - -class AsyncGoTrueAdminMFAAPI: - """ - Contains the full multi-factor authentication administration API. - """ - - async def list_factors( - self, - params: AuthMFAAdminListFactorsParams, - ) -> AuthMFAAdminListFactorsResponse: - """ - Lists all factors attached to a user. - """ - raise NotImplementedError() # pragma: no cover - - async def delete_factor( - self, - params: AuthMFAAdminDeleteFactorParams, - ) -> AuthMFAAdminDeleteFactorResponse: - """ - Deletes a factor on a user. This will log the user out of all active - sessions (if the deleted factor was verified). There's no need to delete - unverified factors. - """ - raise NotImplementedError() # pragma: no cover diff --git a/gotrue/_async/gotrue_base_api.py b/gotrue/_async/gotrue_base_api.py new file mode 100644 index 00000000..8d7b4697 --- /dev/null +++ b/gotrue/_async/gotrue_base_api.py @@ -0,0 +1,120 @@ +#!/usr/bin/env python3 + +from __future__ import annotations + +from typing import Any, Callable, Dict, TypeVar, Union, overload + +from httpx import Response +from pydantic import BaseModel +from typing_extensions import Literal, Self + +from ..helpers import handle_exception +from ..http_clients import AsyncClient + +T = TypeVar("T") + + +class AsyncGoTrueBaseAPI: + def __init__( + self, + *, + url: str, + headers: Dict[str, str], + http_client: Union[AsyncClient, None], + ): + self._url = url + self._headers = headers + self._http_client = http_client or AsyncClient() + + async def __aenter__(self) -> Self: + return self + + async def __aexit__(self, exc_t, exc_v, exc_tb) -> None: + await self.close() + + async def close(self) -> None: + await self._http_client.aclose() + + @overload + async def _request( + self, + method: Literal["GET", "OPTIONS", "HEAD", "POST", "PUT", "PATCH", "DELETE"], + path: str, + *, + jwt: Union[str, None] = None, + redirect_to: Union[str, None] = None, + headers: Union[Dict[str, str], None] = None, + query: Union[Dict[str, str], None] = None, + body: Union[Any, None] = None, + no_resolve_json: Literal[False] = False, + xform: Callable[[Any], T], + ) -> T: + ... # pragma: no cover + + @overload + async def _request( + self, + method: Literal["GET", "OPTIONS", "HEAD", "POST", "PUT", "PATCH", "DELETE"], + path: str, + *, + jwt: Union[str, None] = None, + redirect_to: Union[str, None] = None, + headers: Union[Dict[str, str], None] = None, + query: Union[Dict[str, str], None] = None, + body: Union[Any, None] = None, + no_resolve_json: Literal[True], + xform: Callable[[Response], T], + ) -> T: + ... # pragma: no cover + + @overload + async def _request( + self, + method: Literal["GET", "OPTIONS", "HEAD", "POST", "PUT", "PATCH", "DELETE"], + path: str, + *, + jwt: Union[str, None] = None, + redirect_to: Union[str, None] = None, + headers: Union[Dict[str, str], None] = None, + query: Union[Dict[str, str], None] = None, + body: Union[Any, None] = None, + no_resolve_json: bool = False, + ) -> None: + ... # pragma: no cover + + async def _request( + self, + method: Literal["GET", "OPTIONS", "HEAD", "POST", "PUT", "PATCH", "DELETE"], + path: str, + *, + jwt: Union[str, None] = None, + redirect_to: Union[str, None] = None, + headers: Union[Dict[str, str], None] = None, + query: Union[Dict[str, str], None] = None, + body: Union[Any, None] = None, + no_resolve_json: bool = False, + xform: Union[Callable[[Any], T], None] = None, + ) -> Union[T, None]: + url = f"{self._url}/{path}" + headers = {**self._headers, **(headers or {})} + if "Content-Type" not in headers: + headers["Content-Type"] = "application/json;charset=UTF-8" + if jwt: + headers["Authorization"] = f"Bearer {jwt}" + query = query or {} + if redirect_to: + query["redirect_to"] = redirect_to + try: + response = await self._http_client.request( + method, + url, + headers=headers, + params=query, + json=body.dict() if isinstance(body, BaseModel) else body, + ) + response.raise_for_status() + result = response if no_resolve_json else response.json() + if xform: + return xform(result) + except Exception as e: + raise handle_exception(e) diff --git a/gotrue/_async/gotrue_mfa_api.py b/gotrue/_async/gotrue_mfa_api.py deleted file mode 100644 index 1b5218f7..00000000 --- a/gotrue/_async/gotrue_mfa_api.py +++ /dev/null @@ -1,116 +0,0 @@ -from ..http_clients import AsyncClient -from ..types import ( - AuthMFAChallengeResponse, - AuthMFAEnrollResponse, - AuthMFAGetAuthenticatorAssuranceLevelResponse, - AuthMFAListFactorsResponse, - AuthMFAUnenrollResponse, - AuthMFAVerifyResponse, - MFAChallengeAndVerifyParams, - MFAChallengeParams, - MFAEnrollParams, - MFAUnenrollParams, - MFAVerifyParams, -) - - -class AsyncGoTrueMFAAPI: - """ - Contains the full multi-factor authentication API. - """ - - def __init__( - self, - *, - url: str, - headers: Dict[str, str], - cookie_options: CookieOptions, - http_client: Optional[AsyncClient] = None - ): - self.url = url - self.headers = headers - self.cookie_options = cookie_options - self.http_client = http_client or AsyncClient() - - async def enroll(self, params: MFAEnrollParams) -> AuthMFAEnrollResponse: - """ - Starts the enrollment process for a new Multi-Factor Authentication - factor. This method creates a new factor in the 'unverified' state. - Present the QR code or secret to the user and ask them to add it to their - authenticator app. Ask the user to provide you with an authenticator code - from their app and verify it by calling challenge and then verify. - - The first successful verification of an unverified factor activates the - factor. All other sessions are logged out and the current one gets an - `aal2` authenticator level. - """ - headers = self.headers - response = await self.http_client.post(url, json=params, headers=headers) - - return check_response(response) # pragma: no cover - - async def challenge(self, params: MFAChallengeParams) -> AuthMFAChallengeResponse: - """ - Prepares a challenge used to verify that a user has access to a MFA - factor. Provide the challenge ID and verification code by calling `verify`. - """ - - # TODO(joel): fetch session - headers = self.headers - - response = await self.http_client.post(url, json=params, headers=headers) - return check_response(response) # pragma: no cover - - async def challenge_and_verify( - self, - params: MFAChallengeAndVerifyParams, - ) -> AuthMFAVerifyResponse: - """ - Helper method which creates a challenge and immediately uses the given code - to verify against it thereafter. The verification code is provided by the - user by entering a code seen in their authenticator app. - """ - raise NotImplementedError() # pragma: no cover - - async def verify(self, params: MFAVerifyParams) -> AuthMFAVerifyResponse: - """ - Verifies a verification code against a challenge. The verification code is - provided by the user by entering a code seen in their authenticator app. - """ - raise NotImplementedError() # pragma: no cover - - async def unenroll(self, params: MFAUnenrollParams) -> AuthMFAUnenrollResponse: - """ - Unenroll removes a MFA factor. Unverified factors can safely be ignored - and it's not necessary to unenroll them. Unenrolling a verified MFA factor - cannot be done from a session with an `aal1` authenticator level. - """ - raise NotImplementedError() # pragma: no cover - - async def list_factors(self) -> AuthMFAListFactorsResponse: - """ - Returns the list of MFA factors enabled for this user. For most use cases - you should consider using `get_authenticator_assurance_level`. - - This uses a cached version of the factors and avoids incurring a network call. - If you need to update this list, call `get_user` first. - """ - raise NotImplementedError() # pragma: no cover - - async def get_authenticator_assurance_level( - self, - ) -> AuthMFAGetAuthenticatorAssuranceLevelResponse: - """ - Returns the Authenticator Assurance Level (AAL) for the active session. - - - `aal1` (or `null`) means that the user's identity has been verified only - with a conventional login (email+password, OTP, magic link, social login, - etc.). - - `aal2` means that the user's identity has been verified both with a - conventional login and at least one MFA factor. - - Although this method returns a promise, it's fairly quick (microseconds) - and rarely uses the network. You can use this to check whether the current - user needs to be shown a screen to verify their MFA factors. - """ - raise NotImplementedError() # pragma: no cover diff --git a/gotrue/_sync/gotrue_admin_api.py b/gotrue/_sync/gotrue_admin_api.py new file mode 100644 index 00000000..4b934176 --- /dev/null +++ b/gotrue/_sync/gotrue_admin_api.py @@ -0,0 +1,147 @@ +from __future__ import annotations + +from typing import Dict, List, Union + +from ..helpers import parse_link_response, parse_user_response +from ..http_clients import SyncClient +from ..types import ( + AdminUserAttributes, + GenerateLinkParams, + GenerateLinkResponse, + Options, + User, + UserResponse, +) +from .gotrue_base_api import SyncGoTrueBaseAPI + + +class SyncGoTrueAdminAPI(SyncGoTrueBaseAPI): + def __init__( + self, + *, + url: str = "", + headers: Dict[str, str] = {}, + http_client: Union[SyncClient, None] = None, + ) -> None: + SyncGoTrueBaseAPI.__init__( + self, + url=url, + headers=headers, + http_client=http_client, + ) + + def sign_out(self, jwt: str) -> None: + """ + Removes a logged-in session. + """ + return self._request( + "POST", + "logout", + jwt=jwt, + no_resolve_json=True, + ) + + def invite_user_by_email( + self, + email: str, + options: Options = {}, + ) -> UserResponse: + """ + Sends an invite link to an email address. + """ + return self._request( + "POST", + "invite", + body={"email": email, "data": options.get("data")}, + redirect_to=options.get("redirect_to"), + xform=parse_user_response, + ) + + def generate_link(self, params: GenerateLinkParams) -> GenerateLinkResponse: + """ + Generates email links and OTPs to be sent via a custom email provider. + """ + return self._request( + "POST", + "admin/generate_link", + body={ + "type": params.get("type"), + "email": params.get("email"), + "password": params.get("password"), + "new_email": params.get("new_email"), + "data": params.get("options", {}).get("data"), + }, + redirect_to=params.get("options", {}).get("redirect_to"), + xform=parse_link_response, + ) + + # User Admin API + + def create_user(self, attributes: AdminUserAttributes) -> UserResponse: + """ + Creates a new user. + + This function should only be called on a server. + Never expose your `service_role` key in the browser. + """ + return self._request( + "POST", + "admin/users", + body=attributes, + xform=parse_user_response, + ) + + def list_users(self) -> List[User]: + """ + Get a list of users. + + This function should only be called on a server. + Never expose your `service_role` key in the browser. + """ + return self._request( + "GET", + "admin/users", + xform=lambda data: [User.parse_obj(user) for user in data["users"]] + if "users" in data + else [], + ) + + def get_user_by_id(self, uid: str) -> UserResponse: + """ + Get user by id. + + This function should only be called on a server. + Never expose your `service_role` key in the browser. + """ + return self._request( + "GET", + f"admin/users/{uid}", + xform=parse_user_response, + ) + + def update_user_by_id( + self, + uid: str, + attributes: AdminUserAttributes, + ) -> UserResponse: + """ + Updates the user data. + + This function should only be called on a server. + Never expose your `service_role` key in the browser. + """ + return self._request( + "PUT", + f"admin/users/{uid}", + body=attributes, + xform=parse_user_response, + ) + + def delete_user(self, id: str) -> None: + """ + Delete a user. Requires a `service_role` key. + + This function should only be called on a server. + Never expose your `service_role` key in the browser. + """ + return self._request("DELETE", f"admin/users/{id}") diff --git a/gotrue/_sync/gotrue_base_api.py b/gotrue/_sync/gotrue_base_api.py new file mode 100644 index 00000000..81701304 --- /dev/null +++ b/gotrue/_sync/gotrue_base_api.py @@ -0,0 +1,120 @@ +#!/usr/bin/env python3 + +from __future__ import annotations + +from typing import Any, Callable, Dict, TypeVar, Union, overload + +from httpx import Response +from pydantic import BaseModel +from typing_extensions import Literal, Self + +from ..helpers import handle_exception +from ..http_clients import SyncClient + +T = TypeVar("T") + + +class SyncGoTrueBaseAPI: + def __init__( + self, + *, + url: str, + headers: Dict[str, str], + http_client: Union[SyncClient, None], + ): + self._url = url + self._headers = headers + self._http_client = http_client or SyncClient() + + def __enter__(self) -> Self: + return self + + def __exit__(self, exc_t, exc_v, exc_tb) -> None: + self.close() + + def close(self) -> None: + self._http_client.aclose() + + @overload + def _request( + self, + method: Literal["GET", "OPTIONS", "HEAD", "POST", "PUT", "PATCH", "DELETE"], + path: str, + *, + jwt: Union[str, None] = None, + redirect_to: Union[str, None] = None, + headers: Union[Dict[str, str], None] = None, + query: Union[Dict[str, str], None] = None, + body: Union[Any, None] = None, + no_resolve_json: Literal[False] = False, + xform: Callable[[Any], T], + ) -> T: + ... # pragma: no cover + + @overload + def _request( + self, + method: Literal["GET", "OPTIONS", "HEAD", "POST", "PUT", "PATCH", "DELETE"], + path: str, + *, + jwt: Union[str, None] = None, + redirect_to: Union[str, None] = None, + headers: Union[Dict[str, str], None] = None, + query: Union[Dict[str, str], None] = None, + body: Union[Any, None] = None, + no_resolve_json: Literal[True], + xform: Callable[[Response], T], + ) -> T: + ... # pragma: no cover + + @overload + def _request( + self, + method: Literal["GET", "OPTIONS", "HEAD", "POST", "PUT", "PATCH", "DELETE"], + path: str, + *, + jwt: Union[str, None] = None, + redirect_to: Union[str, None] = None, + headers: Union[Dict[str, str], None] = None, + query: Union[Dict[str, str], None] = None, + body: Union[Any, None] = None, + no_resolve_json: bool = False, + ) -> None: + ... # pragma: no cover + + def _request( + self, + method: Literal["GET", "OPTIONS", "HEAD", "POST", "PUT", "PATCH", "DELETE"], + path: str, + *, + jwt: Union[str, None] = None, + redirect_to: Union[str, None] = None, + headers: Union[Dict[str, str], None] = None, + query: Union[Dict[str, str], None] = None, + body: Union[Any, None] = None, + no_resolve_json: bool = False, + xform: Union[Callable[[Any], T], None] = None, + ) -> Union[T, None]: + url = f"{self._url}/{path}" + headers = {**self._headers, **(headers or {})} + if "Content-Type" not in headers: + headers["Content-Type"] = "application/json;charset=UTF-8" + if jwt: + headers["Authorization"] = f"Bearer {jwt}" + query = query or {} + if redirect_to: + query["redirect_to"] = redirect_to + try: + response = self._http_client.request( + method, + url, + headers=headers, + params=query, + json=body.dict() if isinstance(body, BaseModel) else body, + ) + response.raise_for_status() + result = response if no_resolve_json else response.json() + if xform: + return xform(result) + except Exception as e: + raise handle_exception(e) diff --git a/gotrue/_sync/gotrue_client.py b/gotrue/_sync/gotrue_client.py index 43e41079..36d9f274 100644 --- a/gotrue/_sync/gotrue_client.py +++ b/gotrue/_sync/gotrue_client.py @@ -372,7 +372,8 @@ def get_user(self, jwt: Union[str, None] = None) -> UserResponse: `get_user()` will attempt to get the `jwt` from the current session. """ if not jwt: - if session := self.get_session(): + session = self.get_session() + if session: jwt = session.access_token return self._request("GET", "user", jwt=jwt, xform=parse_user_response) @@ -534,15 +535,15 @@ def _enroll(self, params: MFAEnrollParams) -> AuthMFAEnrollResponse: return response def _challenge(self, params: MFAChallengeParams) -> AuthMFAChallengeResponse: - if session := self.get_session(): - return self._request( - "POST", - f"factors/{params.get('factor_id')}/challenge", - jwt=session.access_token, - xform=AuthMFAChallengeResponse.parse_obj, - ) - else: + session = self.get_session() + if not session: raise AuthSessionMissingError() + return self._request( + "POST", + f"factors/{params.get('factor_id')}/challenge", + jwt=session.access_token, + xform=AuthMFAChallengeResponse.parse_obj, + ) def _challenge_and_verify( self, @@ -578,15 +579,15 @@ def _verify(self, params: MFAVerifyParams) -> AuthMFAVerifyResponse: return response def _unenroll(self, params: MFAUnenrollParams) -> AuthMFAUnenrollResponse: - if session := self.get_session(): - return self._request( - "DELETE", - f"factors/{params.get('factor_id')}", - jwt=session.access_token, - xform=AuthMFAUnenrollResponse.parse_obj, - ) - else: + session = self.get_session() + if not session: raise AuthSessionMissingError() + return self._request( + "DELETE", + f"factors/{params.get('factor_id')}", + jwt=session.access_token, + xform=AuthMFAUnenrollResponse.parse_obj, + ) def _list_factors(self) -> AuthMFAListFactorsResponse: response = self.get_user() From 884789ca2b524f9c9f45944a876f812d48145e57 Mon Sep 17 00:00:00 2001 From: "joel@joellee.org" Date: Sun, 22 Jan 2023 23:08:36 +0800 Subject: [PATCH 08/16] chore: remove unused _sync files --- gotrue/_sync/api.py | 642 -------------------------- gotrue/_sync/gotrue_admin_mfa_api.py | 32 -- gotrue/_sync/gotrue_mfa_api.py | 116 ----- tests/_async/test_gotrue_admin_api.py | 273 +++++++++++ tests/_sync/test_gotrue_admin_api.py | 273 +++++++++++ 5 files changed, 546 insertions(+), 790 deletions(-) delete mode 100644 gotrue/_sync/api.py delete mode 100644 gotrue/_sync/gotrue_admin_mfa_api.py delete mode 100644 gotrue/_sync/gotrue_mfa_api.py create mode 100644 tests/_async/test_gotrue_admin_api.py create mode 100644 tests/_sync/test_gotrue_admin_api.py diff --git a/gotrue/_sync/api.py b/gotrue/_sync/api.py deleted file mode 100644 index abbdc480..00000000 --- a/gotrue/_sync/api.py +++ /dev/null @@ -1,642 +0,0 @@ -from __future__ import annotations - -from typing import Any, Dict, List, Optional, Union - -from pydantic import parse_obj_as - -from ..exceptions import APIError -from ..helpers import check_response, encode_uri_component -from ..http_clients import SyncClient -from ..types import ( - CookieOptions, - LinkType, - Provider, - Session, - User, - UserAttributes, - determine_session_or_user_model_from_response, -) - - -class SyncGoTrueAPI: - def __init__( - self, - *, - url: str, - headers: Dict[str, str], - cookie_options: CookieOptions, - http_client: Optional[SyncClient] = None, - ) -> None: - """Initialise API class.""" - self.url = url - self.headers = headers - self.cookie_options = cookie_options - self.http_client = http_client or SyncClient() - - def __enter__(self) -> SyncGoTrueAPI: - return self - - def __exit__(self, exc_t, exc_v, exc_tb) -> None: - self.close() - - def close(self) -> None: - self.http_client.aclose() - - def create_user(self, *, attributes: UserAttributes) -> User: - """Creates a new user. - - This function should only be called on a server. - Never expose your `service_role` key in the browser. - - Parameters - ---------- - attributes: UserAttributes - The data you want to create the user with. - - Returns - ------- - response : User - The created user - - Raises - ------ - APIError - If an error occurs. - """ - headers = self.headers - data = attributes.dict() - url = f"{self.url}/admin/users" - response = self.http_client.post(url, json=data, headers=headers) - return User.parse_response(response) - - def list_users(self) -> List[User]: - """Get a list of users. - - This function should only be called on a server. - Never expose your `service_role` key in the browser. - - Returns - ------- - response : List[User] - A list of users - - Raises - ------ - APIError - If an error occurs. - """ - headers = self.headers - url = f"{self.url}/admin/users" - response = self.http_client.get(url, headers=headers) - check_response(response) - users = response.json().get("users") - if users is None: - raise APIError("No users found in response", 400) - if not isinstance(users, list): - raise APIError("Expected a list of users", 400) - return parse_obj_as(List[User], users) - - def sign_up_with_email( - self, - *, - email: str, - password: str, - redirect_to: Optional[str] = None, - data: Optional[Dict[str, Any]] = None, - ) -> Union[Session, User]: - """Creates a new user using their email address. - - Parameters - ---------- - email : str - The email address of the user. - password : str - The password of the user. - redirect_to : Optional[str] - A URL or mobile address to send the user to after they are confirmed. - data : Optional[Dict[str, Any]] - Optional user metadata. - - Returns - ------- - response : Union[Session, User] - A logged-in session if the server has "autoconfirm" ON - A user if the server has "autoconfirm" OFF - - Raises - ------ - APIError - If an error occurs. - """ - headers = self.headers - query_string = "" - if redirect_to: - redirect_to_encoded = encode_uri_component(redirect_to) - query_string = f"?redirect_to={redirect_to_encoded}" - data = {"email": email, "password": password, "data": data} - url = f"{self.url}/signup{query_string}" - response = self.http_client.post(url, json=data, headers=headers) - SessionOrUserModel = determine_session_or_user_model_from_response(response) - return SessionOrUserModel.parse_response(response) - - def sign_in_with_email( - self, - *, - email: str, - password: str, - redirect_to: Optional[str] = None, - ) -> Session: - """Logs in an existing user using their email address. - - Parameters - ---------- - email : str - The email address of the user. - password : str - The password of the user. - redirect_to : Optional[str] - A URL or mobile address to send the user to after they are confirmed. - - Returns - ------- - response : Session - A logged-in session - - Raises - ------ - APIError - If an error occurs. - """ - - headers = self.headers - query_string = "?grant_type=password" - if redirect_to: - redirect_to_encoded = encode_uri_component(redirect_to) - query_string += f"&redirect_to={redirect_to_encoded}" - data = {"email": email, "password": password} - url = f"{self.url}/token{query_string}" - response = self.http_client.post(url, json=data, headers=headers) - return Session.parse_response(response) - - def sign_up_with_phone( - self, - *, - phone: str, - password: str, - data: Optional[Dict[str, Any]] = None, - ) -> Union[Session, User]: - """Signs up a new user using their phone number and a password. - - Parameters - ---------- - phone : str - The phone number of the user. - password : str - The password of the user. - data : Optional[Dict[str, Any]] - Optional user metadata. - - Returns - ------- - response : Union[Session, User] - A logged-in session if the server has "autoconfirm" ON - A user if the server has "autoconfirm" OFF - - Raises - ------ - APIError - If an error occurs. - """ - headers = self.headers - data = {"phone": phone, "password": password, "data": data} - url = f"{self.url}/signup" - response = self.http_client.post(url, json=data, headers=headers) - SessionOrUserModel = determine_session_or_user_model_from_response(response) - return SessionOrUserModel.parse_response(response) - - def sign_in_with_phone( - self, - *, - phone: str, - password: str, - ) -> Session: - """Logs in an existing user using their phone number and password. - - Parameters - ---------- - phone : str - The phone number of the user. - password : str - The password of the user. - - Returns - ------- - response : Session - A logged-in session - - Raises - ------ - APIError - If an error occurs. - """ - data = {"phone": phone, "password": password} - url = f"{self.url}/token?grant_type=password" - headers = self.headers - response = self.http_client.post(url, json=data, headers=headers) - return Session.parse_response(response) - - def send_magic_link_email( - self, - *, - email: str, - create_user: bool, - redirect_to: Optional[str] = None, - ) -> None: - """Sends a magic login link to an email address. - - Parameters - ---------- - email : str - The email address of the user. - redirect_to : Optional[str] - A URL or mobile address to send the user to after they are confirmed. - - Raises - ------ - APIError - If an error occurs. - """ - headers = self.headers - query_string = "" - if redirect_to: - redirect_to_encoded = encode_uri_component(redirect_to) - query_string = f"?redirect_to={redirect_to_encoded}" - data = {"email": email, "create_user": create_user} - url = f"{self.url}/magiclink{query_string}" - response = self.http_client.post(url, json=data, headers=headers) - return check_response(response) - - def send_mobile_otp(self, *, phone: str, create_user: bool) -> None: - """Sends a mobile OTP via SMS. Will register the account if it doesn't already exist - - Parameters - ---------- - phone : str - The user's phone number WITH international prefix - - Raises - ------ - APIError - If an error occurs. - """ - headers = self.headers - data = {"phone": phone, "create_user": create_user} - url = f"{self.url}/otp" - response = self.http_client.post(url, json=data, headers=headers) - return check_response(response) - - def verify_mobile_otp( - self, - *, - phone: str, - token: str, - redirect_to: Optional[str] = None, - ) -> Union[Session, User]: - """Send User supplied Mobile OTP to be verified - - Parameters - ---------- - phone : str - The user's phone number WITH international prefix - token : str - Token that user was sent to their mobile phone - redirect_to : Optional[str] - A URL or mobile address to send the user to after they are confirmed. - - Returns - ------- - response : Union[Session, User] - A logged-in session if the server has "autoconfirm" ON - A user if the server has "autoconfirm" OFF - - Raises - ------ - APIError - If an error occurs. - """ - headers = self.headers - data = { - "phone": phone, - "token": token, - "type": "sms", - } - if redirect_to: - redirect_to_encoded = encode_uri_component(redirect_to) - data["redirect_to"] = redirect_to_encoded - url = f"{self.url}/verify" - response = self.http_client.post(url, json=data, headers=headers) - SessionOrUserModel = determine_session_or_user_model_from_response(response) - return SessionOrUserModel.parse_response(response) - - def invite_user_by_email( - self, - *, - email: str, - redirect_to: Optional[str] = None, - data: Optional[Dict[str, Any]] = None, - ) -> User: - """Sends an invite link to an email address. - - Parameters - ---------- - email : str - The email address of the user. - redirect_to : Optional[str] - A URL or mobile address to send the user to after they are confirmed. - data : Optional[Dict[str, Any]] - Optional user metadata. - - Returns - ------- - response : User - A user - - Raises - ------ - APIError - If an error occurs. - """ - headers = self.headers - query_string = "" - if redirect_to: - redirect_to_encoded = encode_uri_component(redirect_to) - query_string = f"?redirect_to={redirect_to_encoded}" - data = {"email": email, "data": data} - url = f"{self.url}/invite{query_string}" - response = self.http_client.post(url, json=data, headers=headers) - return User.parse_response(response) - - def reset_password_for_email( - self, - *, - email: str, - redirect_to: Optional[str] = None, - ) -> None: - """Sends a reset request to an email address. - - Parameters - ---------- - email : str - The email address of the user. - redirect_to : Optional[str] - A URL or mobile address to send the user to after they are confirmed. - - Raises - ------ - APIError - If an error occurs. - """ - headers = self.headers - query_string = "" - if redirect_to: - redirect_to_encoded = encode_uri_component(redirect_to) - query_string = f"?redirect_to={redirect_to_encoded}" - data = {"email": email} - url = f"{self.url}/recover{query_string}" - response = self.http_client.post(url, json=data, headers=headers) - return check_response(response) - - def _create_request_headers(self, *, jwt: str) -> Dict[str, str]: - """Create temporary object. - - Create a temporary object with all configured headers and adds the - Authorization token to be used on request methods. - - Parameters - ---------- - jwt : str - A valid, logged-in JWT. - - Returns - ------- - headers : dict of str - The headers required for a successful request statement with the - supabase backend. - """ - headers = {**self.headers, "Authorization": f"Bearer {jwt}"} - return headers - - def sign_out(self, *, jwt: str) -> None: - """Removes a logged-in session. - - Parameters - ---------- - jwt : str - A valid, logged-in JWT. - """ - headers = self._create_request_headers(jwt=jwt) - url = f"{self.url}/logout" - self.http_client.post(url, headers=headers) - - def get_url_for_provider( - self, - *, - provider: Provider, - redirect_to: Optional[str] = None, - scopes: Optional[str] = None, - ) -> str: - """Generates the relevant login URL for a third-party provider. - - Parameters - ---------- - provider : Provider - One of the providers supported by GoTrue. - redirect_to : Optional[str] - A URL or mobile address to send the user to after they are confirmed. - scopes : Optional[str] - A space-separated list of scopes granted to the OAuth application. - - Returns - ------- - url : str - The URL to redirect the user to. - - Raises - ------ - APIError - If an error occurs. - """ - url_params = [f"provider={encode_uri_component(provider)}"] - if redirect_to: - redirect_to_encoded = encode_uri_component(redirect_to) - url_params.append(f"redirect_to={redirect_to_encoded}") - if scopes: - url_params.append(f"scopes={encode_uri_component(scopes)}") - return f"{self.url}/authorize?{'&'.join(url_params)}" - - def get_user(self, *, jwt: str) -> User: - """Gets the user details. - - Parameters - ---------- - jwt : str - A valid, logged-in JWT. - - Returns - ------- - response : User - A user - - Raises - ------ - APIError - If an error occurs. - """ - headers = self._create_request_headers(jwt=jwt) - url = f"{self.url}/user" - response = self.http_client.get(url, headers=headers) - return User.parse_response(response) - - def update_user( - self, - *, - jwt: str, - attributes: UserAttributes, - ) -> User: - """ - Updates the user data. - - Parameters - ---------- - jwt : str - A valid, logged-in JWT. - attributes : UserAttributes - The data you want to update. - - Returns - ------- - response : User - A user - - Raises - ------ - APIError - If an error occurs. - """ - headers = self._create_request_headers(jwt=jwt) - data = attributes.dict() - url = f"{self.url}/user" - response = self.http_client.put(url, json=data, headers=headers) - return User.parse_response(response) - - def delete_user(self, *, uid: str, jwt: str) -> None: - """Delete a user. Requires a `service_role` key. - - This function should only be called on a server. - Never expose your `service_role` key in the browser. - - Parameters - ---------- - uid : str - The user uid you want to remove. - jwt : str - A valid, logged-in JWT. - - Returns - ------- - response : User - A user - - Raises - ------ - APIError - If an error occurs. - """ - headers = self._create_request_headers(jwt=jwt) - url = f"{self.url}/admin/users/{uid}" - response = self.http_client.delete(url, headers=headers) - return check_response(response) - - def refresh_access_token(self, *, refresh_token: str) -> Session: - """Generates a new JWT. - - Parameters - ---------- - refresh_token : str - A valid refresh token that was returned on login. - - Returns - ------- - response : Session - A session - - Raises - ------ - APIError - If an error occurs. - """ - data = {"refresh_token": refresh_token} - url = f"{self.url}/token?grant_type=refresh_token" - headers = self.headers - response = self.http_client.post(url, json=data, headers=headers) - return Session.parse_response(response) - - def generate_link( - self, - *, - type: LinkType, - email: str, - password: Optional[str] = None, - redirect_to: Optional[str] = None, - data: Optional[Dict[str, Any]] = None, - ) -> Union[Session, User]: - """ - Generates links to be sent via email or other. - - Parameters - ---------- - type : LinkType - The link type ("signup" or "magiclink" or "recovery" or "invite"). - email : str - The user's email. - password : Optional[str] - User password. For signup only. - redirect_to : Optional[str] - The link type ("signup" or "magiclink" or "recovery" or "invite"). - data : Optional[Dict[str, Any]] - Optional user metadata. For signup only. - - Returns - ------- - response : Union[Session, User] - A logged-in session if the server has "autoconfirm" ON - A user if the server has "autoconfirm" OFF - - Raises - ------ - APIError - If an error occurs. - """ - headers = self.headers - data = { - "type": type, - "email": email, - "data": data, - } - if password: - data["password"] = password - if redirect_to: - redirect_to_encoded = encode_uri_component(redirect_to) - data["redirect_to"] = redirect_to_encoded - url = f"{self.url}/admin/generate_link" - response = self.http_client.post(url, json=data, headers=headers) - SessionOrUserModel = determine_session_or_user_model_from_response(response) - return SessionOrUserModel.parse_response(response) - - def set_auth_cookie(self, *, req, res): - """Stub for parity with JS api.""" - raise NotImplementedError("set_auth_cookie not implemented.") - - def get_user_by_cookie(self, *, req): - """Stub for parity with JS api.""" - raise NotImplementedError("get_user_by_cookie not implemented.") diff --git a/gotrue/_sync/gotrue_admin_mfa_api.py b/gotrue/_sync/gotrue_admin_mfa_api.py deleted file mode 100644 index c3fcfc8e..00000000 --- a/gotrue/_sync/gotrue_admin_mfa_api.py +++ /dev/null @@ -1,32 +0,0 @@ -from ..types import ( - AuthMFAAdminDeleteFactorParams, - AuthMFAAdminDeleteFactorResponse, - AuthMFAAdminListFactorsParams, - AuthMFAAdminListFactorsResponse, -) - - -class SyncGoTrueAdminMFAAPI: - """ - Contains the full multi-factor authentication administration API. - """ - - def list_factors( - self, - params: AuthMFAAdminListFactorsParams, - ) -> AuthMFAAdminListFactorsResponse: - """ - Lists all factors attached to a user. - """ - raise NotImplementedError() # pragma: no cover - - def delete_factor( - self, - params: AuthMFAAdminDeleteFactorParams, - ) -> AuthMFAAdminDeleteFactorResponse: - """ - Deletes a factor on a user. This will log the user out of all active - sessions (if the deleted factor was verified). There's no need to delete - unverified factors. - """ - raise NotImplementedError() # pragma: no cover diff --git a/gotrue/_sync/gotrue_mfa_api.py b/gotrue/_sync/gotrue_mfa_api.py deleted file mode 100644 index a429af23..00000000 --- a/gotrue/_sync/gotrue_mfa_api.py +++ /dev/null @@ -1,116 +0,0 @@ -from ..http_clients import SyncClient -from ..types import ( - AuthMFAChallengeResponse, - AuthMFAEnrollResponse, - AuthMFAGetAuthenticatorAssuranceLevelResponse, - AuthMFAListFactorsResponse, - AuthMFAUnenrollResponse, - AuthMFAVerifyResponse, - MFAChallengeAndVerifyParams, - MFAChallengeParams, - MFAEnrollParams, - MFAUnenrollParams, - MFAVerifyParams, -) - - -class SyncGoTrueMFAAPI: - """ - Contains the full multi-factor authentication API. - """ - - def __init__( - self, - *, - url: str, - headers: Dict[str, str], - cookie_options: CookieOptions, - http_client: Optional[SyncClient] = None - ): - self.url = url - self.headers = headers - self.cookie_options = cookie_options - self.http_client = http_client or SyncClient() - - def enroll(self, params: MFAEnrollParams) -> AuthMFAEnrollResponse: - """ - Starts the enrollment process for a new Multi-Factor Authentication - factor. This method creates a new factor in the 'unverified' state. - Present the QR code or secret to the user and ask them to add it to their - authenticator app. Ask the user to provide you with an authenticator code - from their app and verify it by calling challenge and then verify. - - The first successful verification of an unverified factor activates the - factor. All other sessions are logged out and the current one gets an - `aal2` authenticator level. - """ - headers = self.headers - response = self.http_client.post(url, json=params, headers=headers) - - return check_response(response) # pragma: no cover - - def challenge(self, params: MFAChallengeParams) -> AuthMFAChallengeResponse: - """ - Prepares a challenge used to verify that a user has access to a MFA - factor. Provide the challenge ID and verification code by calling `verify`. - """ - - # TODO(joel): fetch session - headers = self.headers - - response = self.http_client.post(url, json=params, headers=headers) - return check_response(response) # pragma: no cover - - def challenge_and_verify( - self, - params: MFAChallengeAndVerifyParams, - ) -> AuthMFAVerifyResponse: - """ - Helper method which creates a challenge and immediately uses the given code - to verify against it thereafter. The verification code is provided by the - user by entering a code seen in their authenticator app. - """ - raise NotImplementedError() # pragma: no cover - - def verify(self, params: MFAVerifyParams) -> AuthMFAVerifyResponse: - """ - Verifies a verification code against a challenge. The verification code is - provided by the user by entering a code seen in their authenticator app. - """ - raise NotImplementedError() # pragma: no cover - - def unenroll(self, params: MFAUnenrollParams) -> AuthMFAUnenrollResponse: - """ - Unenroll removes a MFA factor. Unverified factors can safely be ignored - and it's not necessary to unenroll them. Unenrolling a verified MFA factor - cannot be done from a session with an `aal1` authenticator level. - """ - raise NotImplementedError() # pragma: no cover - - def list_factors(self) -> AuthMFAListFactorsResponse: - """ - Returns the list of MFA factors enabled for this user. For most use cases - you should consider using `get_authenticator_assurance_level`. - - This uses a cached version of the factors and avoids incurring a network call. - If you need to update this list, call `get_user` first. - """ - raise NotImplementedError() # pragma: no cover - - def get_authenticator_assurance_level( - self, - ) -> AuthMFAGetAuthenticatorAssuranceLevelResponse: - """ - Returns the Authenticator Assurance Level (AAL) for the active session. - - - `aal1` (or `null`) means that the user's identity has been verified only - with a conventional login (email+password, OTP, magic link, social login, - etc.). - - `aal2` means that the user's identity has been verified both with a - conventional login and at least one MFA factor. - - Although this method returns a promise, it's fairly quick (microseconds) - and rarely uses the network. You can use this to check whether the current - user needs to be shown a screen to verify their MFA factors. - """ - raise NotImplementedError() # pragma: no cover diff --git a/tests/_async/test_gotrue_admin_api.py b/tests/_async/test_gotrue_admin_api.py new file mode 100644 index 00000000..a70380cf --- /dev/null +++ b/tests/_async/test_gotrue_admin_api.py @@ -0,0 +1,273 @@ +from gotrue.errors import AuthError + +from .clients import ( + auth_client_with_session, + client_api_auto_confirm_disabled_client, + client_api_auto_confirm_off_signups_enabled_client, + service_role_api_client, +) +from .utils import ( + create_new_user_with_email, + mock_app_metadata, + mock_user_credentials, + mock_user_metadata, + mock_verification_otp, +) + + +async def test_create_user_should_create_a_new_user(): + credentials = mock_user_credentials() + response = await create_new_user_with_email(email=credentials.get("email")) + assert response.email == credentials.get("email") + + +async def test_create_user_with_user_metadata(): + user_metadata = mock_user_metadata() + credentials = mock_user_credentials() + response = await service_role_api_client().create_user( + { + "email": credentials.get("email"), + "password": credentials.get("password"), + "user_metadata": user_metadata, + } + ) + assert response.user.email == credentials.get("email") + assert response.user.user_metadata == user_metadata + assert "profile_image" in response.user.user_metadata + + +async def test_create_user_with_app_metadata(): + app_metadata = mock_app_metadata() + credentials = mock_user_credentials() + response = await service_role_api_client().create_user( + { + "email": credentials.get("email"), + "password": credentials.get("password"), + "app_metadata": app_metadata, + } + ) + assert response.user.email == credentials.get("email") + assert "provider" in response.user.app_metadata + assert "providers" in response.user.app_metadata + + +async def test_create_user_with_user_and_app_metadata(): + user_metadata = mock_user_metadata() + app_metadata = mock_app_metadata() + credentials = mock_user_credentials() + response = await service_role_api_client().create_user( + { + "email": credentials.get("email"), + "password": credentials.get("password"), + "user_metadata": user_metadata, + "app_metadata": app_metadata, + } + ) + assert response.user.email == credentials.get("email") + assert "profile_image" in response.user.user_metadata + assert "provider" in response.user.app_metadata + assert "providers" in response.user.app_metadata + + +async def test_list_users_should_return_registered_users(): + credentials = mock_user_credentials() + await create_new_user_with_email(email=credentials.get("email")) + users = await service_role_api_client().list_users() + assert users + emails = [user.email for user in users] + assert emails + assert credentials.get("email") in emails + + +async def test_get_user_fetches_a_user_by_their_access_token(): + credentials = mock_user_credentials() + auth_client_with_session_current_user = auth_client_with_session() + response = await auth_client_with_session_current_user.sign_up( + { + "email": credentials.get("email"), + "password": credentials.get("password"), + } + ) + assert response.session + response = await auth_client_with_session_current_user.get_user() + assert response.user.email == credentials.get("email") + + +async def test_get_user_by_id_should_a_registered_user_given_its_user_identifier(): + credentials = mock_user_credentials() + user = await create_new_user_with_email(email=credentials.get("email")) + assert user.id + response = await service_role_api_client().get_user_by_id(user.id) + assert response.user.email == credentials.get("email") + + +async def test_modify_email_using_update_user_by_id(): + credentials = mock_user_credentials() + user = await create_new_user_with_email(email=credentials.get("email")) + response = await service_role_api_client().update_user_by_id( + user.id, + { + "email": f"new_{user.email}", + }, + ) + assert response.user.email == f"new_{user.email}" + + +async def test_modify_user_metadata_using_update_user_by_id(): + credentials = mock_user_credentials() + user = await create_new_user_with_email(email=credentials.get("email")) + user_metadata = {"favorite_color": "yellow"} + response = await service_role_api_client().update_user_by_id( + user.id, + { + "user_metadata": user_metadata, + }, + ) + assert response.user.email == user.email + assert response.user.user_metadata == user_metadata + + +async def test_modify_app_metadata_using_update_user_by_id(): + credentials = mock_user_credentials() + user = await create_new_user_with_email(email=credentials.get("email")) + app_metadata = {"roles": ["admin", "publisher"]} + response = await service_role_api_client().update_user_by_id( + user.id, + { + "app_metadata": app_metadata, + }, + ) + assert response.user.email == user.email + assert "roles" in response.user.app_metadata + + +async def test_modify_confirm_email_using_update_user_by_id(): + credentials = mock_user_credentials() + response = await client_api_auto_confirm_off_signups_enabled_client().sign_up( + { + "email": credentials.get("email"), + "password": credentials.get("password"), + } + ) + assert response.user + assert not response.user.email_confirmed_at + response = await service_role_api_client().update_user_by_id( + response.user.id, + { + "email_confirm": True, + }, + ) + assert response.user.email_confirmed_at + + +async def test_delete_user_should_be_able_delete_an_existing_user(): + credentials = mock_user_credentials() + user = await create_new_user_with_email(email=credentials.get("email")) + await service_role_api_client().delete_user(user.id) + users = await service_role_api_client().list_users() + emails = [user.email for user in users] + assert credentials.get("email") not in emails + + +async def test_generate_link_supports_sign_up_with_generate_confirmation_signup_link(): + credentials = mock_user_credentials() + redirect_to = "http://localhost:9999/welcome" + user_metadata = {"status": "alpha"} + response = await service_role_api_client().generate_link( + { + "type": "signup", + "email": credentials.get("email"), + "password": credentials.get("password"), + "options": { + "data": user_metadata, + "redirect_to": redirect_to, + }, + }, + ) + assert response.user.user_metadata == user_metadata + + +async def test_generate_link_supports_updating_emails_with_generate_email_change_links(): # noqa: E501 + credentials = mock_user_credentials() + user = await create_new_user_with_email(email=credentials.get("email")) + assert user.email + assert user.email == credentials.get("email") + credentials = mock_user_credentials() + redirect_to = "http://localhost:9999/welcome" + response = await service_role_api_client().generate_link( + { + "type": "email_change_current", + "email": user.email, + "new_email": credentials.get("email"), + "options": { + "redirect_to": redirect_to, + }, + }, + ) + assert response.user.new_email == credentials.get("email") + + +async def test_invite_user_by_email_creates_a_new_user_with_an_invited_at_timestamp(): + credentials = mock_user_credentials() + redirect_to = "http://localhost:9999/welcome" + user_metadata = {"status": "alpha"} + response = await service_role_api_client().invite_user_by_email( + credentials.get("email"), + { + "data": user_metadata, + "redirect_to": redirect_to, + }, + ) + assert response.user.invited_at + + +async def test_sign_out_with_an_valid_access_token(): + credentials = mock_user_credentials() + response = await auth_client_with_session().sign_up( + { + "email": credentials.get("email"), + "password": credentials.get("password"), + }, + ) + assert response.session + response = await service_role_api_client().sign_out(response.session.access_token) + + +async def test_sign_out_with_an_invalid_access_token(): + try: + await service_role_api_client().sign_out("this-is-a-bad-token") + assert False + except AuthError: + pass + + +async def test_verify_otp_with_non_existent_phone_number(): + credentials = mock_user_credentials() + otp = mock_verification_otp() + try: + await client_api_auto_confirm_disabled_client().verify_otp( + { + "phone": credentials.get("phone"), + "token": otp, + "type": "sms", + }, + ) + assert False + except AuthError as e: + assert e.message == "User not found" + + +async def test_verify_otp_with_invalid_phone_number(): + credentials = mock_user_credentials() + otp = mock_verification_otp() + try: + await client_api_auto_confirm_disabled_client().verify_otp( + { + "phone": f"{credentials.get('phone')}-invalid", + "token": otp, + "type": "sms", + }, + ) + assert False + except AuthError as e: + assert e.message == "Invalid phone number format" diff --git a/tests/_sync/test_gotrue_admin_api.py b/tests/_sync/test_gotrue_admin_api.py new file mode 100644 index 00000000..34cd9580 --- /dev/null +++ b/tests/_sync/test_gotrue_admin_api.py @@ -0,0 +1,273 @@ +from gotrue.errors import AuthError + +from .clients import ( + auth_client_with_session, + client_api_auto_confirm_disabled_client, + client_api_auto_confirm_off_signups_enabled_client, + service_role_api_client, +) +from .utils import ( + create_new_user_with_email, + mock_app_metadata, + mock_user_credentials, + mock_user_metadata, + mock_verification_otp, +) + + +def test_create_user_should_create_a_new_user(): + credentials = mock_user_credentials() + response = create_new_user_with_email(email=credentials.get("email")) + assert response.email == credentials.get("email") + + +def test_create_user_with_user_metadata(): + user_metadata = mock_user_metadata() + credentials = mock_user_credentials() + response = service_role_api_client().create_user( + { + "email": credentials.get("email"), + "password": credentials.get("password"), + "user_metadata": user_metadata, + } + ) + assert response.user.email == credentials.get("email") + assert response.user.user_metadata == user_metadata + assert "profile_image" in response.user.user_metadata + + +def test_create_user_with_app_metadata(): + app_metadata = mock_app_metadata() + credentials = mock_user_credentials() + response = service_role_api_client().create_user( + { + "email": credentials.get("email"), + "password": credentials.get("password"), + "app_metadata": app_metadata, + } + ) + assert response.user.email == credentials.get("email") + assert "provider" in response.user.app_metadata + assert "providers" in response.user.app_metadata + + +def test_create_user_with_user_and_app_metadata(): + user_metadata = mock_user_metadata() + app_metadata = mock_app_metadata() + credentials = mock_user_credentials() + response = service_role_api_client().create_user( + { + "email": credentials.get("email"), + "password": credentials.get("password"), + "user_metadata": user_metadata, + "app_metadata": app_metadata, + } + ) + assert response.user.email == credentials.get("email") + assert "profile_image" in response.user.user_metadata + assert "provider" in response.user.app_metadata + assert "providers" in response.user.app_metadata + + +def test_list_users_should_return_registered_users(): + credentials = mock_user_credentials() + create_new_user_with_email(email=credentials.get("email")) + users = service_role_api_client().list_users() + assert users + emails = [user.email for user in users] + assert emails + assert credentials.get("email") in emails + + +def test_get_user_fetches_a_user_by_their_access_token(): + credentials = mock_user_credentials() + auth_client_with_session_current_user = auth_client_with_session() + response = auth_client_with_session_current_user.sign_up( + { + "email": credentials.get("email"), + "password": credentials.get("password"), + } + ) + assert response.session + response = auth_client_with_session_current_user.get_user() + assert response.user.email == credentials.get("email") + + +def test_get_user_by_id_should_a_registered_user_given_its_user_identifier(): + credentials = mock_user_credentials() + user = create_new_user_with_email(email=credentials.get("email")) + assert user.id + response = service_role_api_client().get_user_by_id(user.id) + assert response.user.email == credentials.get("email") + + +def test_modify_email_using_update_user_by_id(): + credentials = mock_user_credentials() + user = create_new_user_with_email(email=credentials.get("email")) + response = service_role_api_client().update_user_by_id( + user.id, + { + "email": f"new_{user.email}", + }, + ) + assert response.user.email == f"new_{user.email}" + + +def test_modify_user_metadata_using_update_user_by_id(): + credentials = mock_user_credentials() + user = create_new_user_with_email(email=credentials.get("email")) + user_metadata = {"favorite_color": "yellow"} + response = service_role_api_client().update_user_by_id( + user.id, + { + "user_metadata": user_metadata, + }, + ) + assert response.user.email == user.email + assert response.user.user_metadata == user_metadata + + +def test_modify_app_metadata_using_update_user_by_id(): + credentials = mock_user_credentials() + user = create_new_user_with_email(email=credentials.get("email")) + app_metadata = {"roles": ["admin", "publisher"]} + response = service_role_api_client().update_user_by_id( + user.id, + { + "app_metadata": app_metadata, + }, + ) + assert response.user.email == user.email + assert "roles" in response.user.app_metadata + + +def test_modify_confirm_email_using_update_user_by_id(): + credentials = mock_user_credentials() + response = client_api_auto_confirm_off_signups_enabled_client().sign_up( + { + "email": credentials.get("email"), + "password": credentials.get("password"), + } + ) + assert response.user + assert not response.user.email_confirmed_at + response = service_role_api_client().update_user_by_id( + response.user.id, + { + "email_confirm": True, + }, + ) + assert response.user.email_confirmed_at + + +def test_delete_user_should_be_able_delete_an_existing_user(): + credentials = mock_user_credentials() + user = create_new_user_with_email(email=credentials.get("email")) + service_role_api_client().delete_user(user.id) + users = service_role_api_client().list_users() + emails = [user.email for user in users] + assert credentials.get("email") not in emails + + +def test_generate_link_supports_sign_up_with_generate_confirmation_signup_link(): + credentials = mock_user_credentials() + redirect_to = "http://localhost:9999/welcome" + user_metadata = {"status": "alpha"} + response = service_role_api_client().generate_link( + { + "type": "signup", + "email": credentials.get("email"), + "password": credentials.get("password"), + "options": { + "data": user_metadata, + "redirect_to": redirect_to, + }, + }, + ) + assert response.user.user_metadata == user_metadata + + +def test_generate_link_supports_updating_emails_with_generate_email_change_links(): # noqa: E501 + credentials = mock_user_credentials() + user = create_new_user_with_email(email=credentials.get("email")) + assert user.email + assert user.email == credentials.get("email") + credentials = mock_user_credentials() + redirect_to = "http://localhost:9999/welcome" + response = service_role_api_client().generate_link( + { + "type": "email_change_current", + "email": user.email, + "new_email": credentials.get("email"), + "options": { + "redirect_to": redirect_to, + }, + }, + ) + assert response.user.new_email == credentials.get("email") + + +def test_invite_user_by_email_creates_a_new_user_with_an_invited_at_timestamp(): + credentials = mock_user_credentials() + redirect_to = "http://localhost:9999/welcome" + user_metadata = {"status": "alpha"} + response = service_role_api_client().invite_user_by_email( + credentials.get("email"), + { + "data": user_metadata, + "redirect_to": redirect_to, + }, + ) + assert response.user.invited_at + + +def test_sign_out_with_an_valid_access_token(): + credentials = mock_user_credentials() + response = auth_client_with_session().sign_up( + { + "email": credentials.get("email"), + "password": credentials.get("password"), + }, + ) + assert response.session + response = service_role_api_client().sign_out(response.session.access_token) + + +def test_sign_out_with_an_invalid_access_token(): + try: + service_role_api_client().sign_out("this-is-a-bad-token") + assert False + except AuthError: + pass + + +def test_verify_otp_with_non_existent_phone_number(): + credentials = mock_user_credentials() + otp = mock_verification_otp() + try: + client_api_auto_confirm_disabled_client().verify_otp( + { + "phone": credentials.get("phone"), + "token": otp, + "type": "sms", + }, + ) + assert False + except AuthError as e: + assert e.message == "User not found" + + +def test_verify_otp_with_invalid_phone_number(): + credentials = mock_user_credentials() + otp = mock_verification_otp() + try: + client_api_auto_confirm_disabled_client().verify_otp( + { + "phone": f"{credentials.get('phone')}-invalid", + "token": otp, + "type": "sms", + }, + ) + assert False + except AuthError as e: + assert e.message == "Invalid phone number format" From 60d3f7a3924bdd73b3bd73f31da182c7f5e0131d Mon Sep 17 00:00:00 2001 From: Sourcery AI <> Date: Sun, 22 Jan 2023 15:09:08 +0000 Subject: [PATCH 09/16] 'Refactored by Sourcery' --- gotrue/_sync/gotrue_client.py | 35 +++++++++++++++++------------------ 1 file changed, 17 insertions(+), 18 deletions(-) diff --git a/gotrue/_sync/gotrue_client.py b/gotrue/_sync/gotrue_client.py index 36d9f274..43e41079 100644 --- a/gotrue/_sync/gotrue_client.py +++ b/gotrue/_sync/gotrue_client.py @@ -372,8 +372,7 @@ def get_user(self, jwt: Union[str, None] = None) -> UserResponse: `get_user()` will attempt to get the `jwt` from the current session. """ if not jwt: - session = self.get_session() - if session: + if session := self.get_session(): jwt = session.access_token return self._request("GET", "user", jwt=jwt, xform=parse_user_response) @@ -535,15 +534,15 @@ def _enroll(self, params: MFAEnrollParams) -> AuthMFAEnrollResponse: return response def _challenge(self, params: MFAChallengeParams) -> AuthMFAChallengeResponse: - session = self.get_session() - if not session: + if session := self.get_session(): + return self._request( + "POST", + f"factors/{params.get('factor_id')}/challenge", + jwt=session.access_token, + xform=AuthMFAChallengeResponse.parse_obj, + ) + else: raise AuthSessionMissingError() - return self._request( - "POST", - f"factors/{params.get('factor_id')}/challenge", - jwt=session.access_token, - xform=AuthMFAChallengeResponse.parse_obj, - ) def _challenge_and_verify( self, @@ -579,15 +578,15 @@ def _verify(self, params: MFAVerifyParams) -> AuthMFAVerifyResponse: return response def _unenroll(self, params: MFAUnenrollParams) -> AuthMFAUnenrollResponse: - session = self.get_session() - if not session: + if session := self.get_session(): + return self._request( + "DELETE", + f"factors/{params.get('factor_id')}", + jwt=session.access_token, + xform=AuthMFAUnenrollResponse.parse_obj, + ) + else: raise AuthSessionMissingError() - return self._request( - "DELETE", - f"factors/{params.get('factor_id')}", - jwt=session.access_token, - xform=AuthMFAUnenrollResponse.parse_obj, - ) def _list_factors(self) -> AuthMFAListFactorsResponse: response = self.get_user() From cb6481593570bba3d2c5b40911696426ab1ae306 Mon Sep 17 00:00:00 2001 From: "joel@joellee.org" Date: Sun, 22 Jan 2023 23:13:45 +0800 Subject: [PATCH 10/16] refactor remove gotrue_client --- gotrue/_async/gotrue_client.py | 838 --------------------------------- 1 file changed, 838 deletions(-) delete mode 100644 gotrue/_async/gotrue_client.py diff --git a/gotrue/_async/gotrue_client.py b/gotrue/_async/gotrue_client.py deleted file mode 100644 index 71a45c88..00000000 --- a/gotrue/_async/gotrue_client.py +++ /dev/null @@ -1,838 +0,0 @@ -from __future__ import annotations - -from json import loads -from time import time -from typing import Callable, Dict, List, Tuple, Union -from urllib.parse import parse_qs, quote, urlencode, urlparse -from uuid import uuid4 - -from ..constants import ( - DEFAULT_HEADERS, - EXPIRY_MARGIN, - GOTRUE_URL, - MAX_RETRIES, - RETRY_INTERVAL, - STORAGE_KEY, -) -from ..errors import ( - AuthImplicitGrantRedirectError, - AuthInvalidCredentialsError, - AuthRetryableError, - AuthSessionMissingError, -) -from ..helpers import decode_jwt_payload, parse_auth_response, parse_user_response -from ..http_clients import AsyncClient -from ..timer import Timer -from ..types import ( - AuthChangeEvent, - AuthenticatorAssuranceLevels, - AuthMFAChallengeResponse, - AuthMFAEnrollResponse, - AuthMFAGetAuthenticatorAssuranceLevelResponse, - AuthMFAListFactorsResponse, - AuthMFAUnenrollResponse, - AuthMFAVerifyResponse, - AuthResponse, - DecodedJWTDict, - MFAChallengeAndVerifyParams, - MFAChallengeParams, - MFAEnrollParams, - MFAUnenrollParams, - MFAVerifyParams, - OAuthResponse, - Options, - Provider, - Session, - SignInWithOAuthCredentials, - SignInWithPasswordCredentials, - SignInWithPasswordlessCredentials, - SignUpWithPasswordCredentials, - Subscription, - UserAttributes, - UserResponse, - VerifyOtpParams, -) -from .gotrue_admin_api import AsyncGoTrueAdminAPI -from .gotrue_base_api import AsyncGoTrueBaseAPI -from .gotrue_mfa_api import AsyncGoTrueMFAAPI -from .storage import AsyncMemoryStorage, AsyncSupportedStorage - - -class AsyncGoTrueClient(AsyncGoTrueBaseAPI): - def __init__( - self, - *, - url: Union[str, None] = None, - headers: Union[Dict[str, str], None] = None, - storage_key: Union[str, None] = None, - auto_refresh_token: bool = True, - persist_session: bool = True, - storage: Union[AsyncSupportedStorage, None] = None, - http_client: Union[AsyncClient, None] = None, - ) -> None: - AsyncGoTrueBaseAPI.__init__( - self, - url=url or GOTRUE_URL, - headers=headers or DEFAULT_HEADERS, - http_client=http_client, - ) - self._storage_key = storage_key or STORAGE_KEY - self._auto_refresh_token = auto_refresh_token - self._persist_session = persist_session - self._storage = storage or AsyncMemoryStorage() - self._in_memory_session: Union[Session, None] = None - self._refresh_token_timer: Union[Timer, None] = None - self._network_retries = 0 - self._state_change_emitters: Dict[str, Subscription] = {} - - self.admin = AsyncGoTrueAdminAPI( - url=self._url, - headers=self._headers, - http_client=self._http_client, - ) - self.mfa = AsyncGoTrueMFAAPI() - self.mfa.challenge = self._challenge - self.mfa.challenge_and_verify = self._challenge_and_verify - self.mfa.enroll = self._enroll - self.mfa.get_authenticator_assurance_level = ( - self._get_authenticator_assurance_level - ) - self.mfa.list_factors = self._list_factors - self.mfa.unenroll = self._unenroll - self.mfa.verify = self._verify - - # Initializations - - async def initialize(self, *, url: Union[str, None] = None) -> None: - if url and self._is_implicit_grant_flow(url): - await self.initialize_from_url(url) - else: - await self.initialize_from_storage() - - async def initialize_from_storage(self) -> None: - return await self._recover_and_refresh() - - async def initialize_from_url(self, url: str) -> None: - try: - if self._is_implicit_grant_flow(url): - session, redirect_type = await self._get_session_from_url(url) - await self._save_session(session) - self._notify_all_subscribers("SIGNED_IN", session) - if redirect_type == "recovery": - self._notify_all_subscribers("PASSWORD_RECOVERY", session) - except Exception as e: - await self._remove_session() - raise e - - # Public methods - - async def sign_up( - self, - credentials: SignUpWithPasswordCredentials, - ) -> AuthResponse: - """ - Creates a new user. - """ - await self._remove_session() - email = credentials.get("email") - phone = credentials.get("phone") - password = credentials.get("password") - options = credentials.get("options", {}) - redirect_to = options.get("redirect_to") - data = options.get("data") or {} - captcha_token = options.get("captcha_token") - if email: - response = await self._request( - "POST", - "signup", - body={ - "email": email, - "password": password, - "data": data, - "gotrue_meta_security": { - "captcha_token": captcha_token, - }, - }, - redirect_to=redirect_to, - xform=parse_auth_response, - ) - elif phone: - response = await self._request( - "POST", - "signup", - body={ - "phone": phone, - "password": password, - "data": data, - "gotrue_meta_security": { - "captcha_token": captcha_token, - }, - }, - xform=parse_auth_response, - ) - else: - raise AuthInvalidCredentialsError( - "You must provide either an email or phone number and a password" - ) - if response.session: - await self._save_session(response.session) - self._notify_all_subscribers("SIGNED_IN", response.session) - return response - - async def sign_in_with_password( - self, - credentials: SignInWithPasswordCredentials, - ) -> AuthResponse: - """ - Log in an existing user with an email or phone and password. - """ - await self._remove_session() - email = credentials.get("email") - phone = credentials.get("phone") - password = credentials.get("password") - options = credentials.get("options", {}) - data = options.get("data") or {} - captcha_token = options.get("captcha_token") - if email: - response = await self._request( - "POST", - "token", - body={ - "email": email, - "password": password, - "data": data, - "gotrue_meta_security": { - "captcha_token": captcha_token, - }, - }, - query={ - "grant_type": "password", - }, - xform=parse_auth_response, - ) - elif phone: - response = await self._request( - "POST", - "token", - body={ - "phone": phone, - "password": password, - "data": data, - "gotrue_meta_security": { - "captcha_token": captcha_token, - }, - }, - query={ - "grant_type": "password", - }, - xform=parse_auth_response, - ) - else: - raise AuthInvalidCredentialsError( - "You must provide either an email or phone number and a password" - ) - if response.session: - await self._save_session(response.session) - self._notify_all_subscribers("SIGNED_IN", response.session) - return response - - async def sign_in_with_oauth( - self, - credentials: SignInWithOAuthCredentials, - ) -> OAuthResponse: - """ - Log in an existing user via a third-party provider. - """ - await self._remove_session() - provider = credentials.get("provider") - options = credentials.get("options", {}) - redirect_to = options.get("redirect_to") - scopes = options.get("scopes") - params = options.get("query_params", {}) - if redirect_to: - params["redirect_to"] = redirect_to - if scopes: - params["scopes"] = scopes - url = self._get_url_for_provider(provider, params) - return OAuthResponse(provider=provider, url=url) - - async def sign_in_with_otp( - self, - credentials: SignInWithPasswordlessCredentials, - ) -> AuthResponse: - """ - Log in a user using magiclink or a one-time password (OTP). - - If the `{{ .ConfirmationURL }}` variable is specified in - the email template, a magiclink will be sent. - - If the `{{ .Token }}` variable is specified in the email - template, an OTP will be sent. - - If you're using phone sign-ins, only an OTP will be sent. - You won't be able to send a magiclink for phone sign-ins. - """ - await self._remove_session() - email = credentials.get("email") - phone = credentials.get("phone") - options = credentials.get("options", {}) - email_redirect_to = options.get("email_redirect_to") - should_create_user = options.get("create_user", True) - data = options.get("data") - captcha_token = options.get("captcha_token") - if email: - return await self._request( - "POST", - "otp", - body={ - "email": email, - "data": data, - "create_user": should_create_user, - "gotrue_meta_security": { - "captcha_token": captcha_token, - }, - }, - redirect_to=email_redirect_to, - xform=parse_auth_response, - ) - if phone: - return await self._request( - "POST", - "otp", - body={ - "phone": phone, - "data": data, - "create_user": should_create_user, - "gotrue_meta_security": { - "captcha_token": captcha_token, - }, - }, - xform=parse_auth_response, - ) - raise AuthInvalidCredentialsError( - "You must provide either an email or phone number" - ) - - async def verify_otp(self, params: VerifyOtpParams) -> AuthResponse: - """ - Log in a user given a User supplied OTP received via mobile. - """ - await self._remove_session() - response = await self._request( - "POST", - "verify", - body={ - "gotrue_meta_security": { - "captcha_token": params.get("options", {}).get("captcha_token"), - }, - **params, - }, - redirect_to=params.get("options", {}).get("redirect_to"), - xform=parse_auth_response, - ) - if response.session: - await self._save_session(response.session) - self._notify_all_subscribers("SIGNED_IN", response.session) - return response - - async def get_session(self) -> Union[Session, None]: - """ - Returns the session, refreshing it if necessary. - - The session returned can be null if the session is not detected which - can happen in the event a user is not signed-in or has logged out. - """ - current_session: Union[Session, None] = None - if self._persist_session: - maybe_session = await self._storage.get_item(self._storage_key) - current_session = self._get_valid_session(maybe_session) - if not current_session: - await self._remove_session() - else: - current_session = self._in_memory_session - if not current_session: - return None - time_now = round(time()) - has_expired = ( - current_session.expires_at <= time_now + EXPIRY_MARGIN - if current_session.expires_at - else False - ) - return ( - await self._call_refresh_token(current_session.refresh_token) - if has_expired - else current_session - ) - - async def get_user(self, jwt: Union[str, None] = None) -> UserResponse: - """ - Gets the current user details if there is an existing session. - - Takes in an optional access token `jwt`. If no `jwt` is provided, - `get_user()` will attempt to get the `jwt` from the current session. - """ - if not jwt: - session = await self.get_session() - if session: - jwt = session.access_token - return await self._request("GET", "user", jwt=jwt, xform=parse_user_response) - - async def update_user(self, attributes: UserAttributes) -> UserResponse: - """ - Updates user data, if there is a logged in user. - """ - session = await self.get_session() - if not session: - raise AuthSessionMissingError() - response = await self._request( - "PUT", - "user", - body=attributes, - jwt=session.access_token, - xform=parse_user_response, - ) - session.user = response.user - await self._save_session(session) - self._notify_all_subscribers("USER_UPDATED", session) - return response - - async def set_session(self, access_token: str, refresh_token: str) -> AuthResponse: - """ - Sets the session data from the current session. If the current session - is expired, `set_session` will take care of refreshing it to obtain a - new session. - - If the refresh token in the current session is invalid and the current - session has expired, an error will be thrown. - - If the current session does not contain at `expires_at` field, - `set_session` will use the exp claim defined in the access token. - - The current session that minimally contains an access token, - refresh token and a user. - """ - time_now = round(time()) - expires_at = time_now - has_expired = True - session: Union[Session, None] = None - if access_token and access_token.split(".")[1]: - payload = self._decode_jwt(access_token) - if exp := payload.get("exp"): - expires_at = int(exp) - has_expired = expires_at <= time_now - if has_expired: - if not refresh_token: - raise AuthSessionMissingError() - response = await self._refresh_access_token(refresh_token) - if not response.session: - return AuthResponse() - session = response.session - else: - response = await self.get_user(access_token) - session = Session( - access_token=access_token, - refresh_token=refresh_token, - user=response.user, - token_type="bearer", - expires_in=expires_at - time_now, - expires_at=expires_at, - ) - await self._save_session(session) - self._notify_all_subscribers("TOKEN_REFRESHED", session) - return AuthResponse(session=session, user=response.user) - - async def refresh_session( - self, refresh_token: Union[str, None] = None - ) -> AuthResponse: - """ - Returns a new session, regardless of expiry status. - - Takes in an optional current session. If not passed in, then refreshSession() - will attempt to retrieve it from getSession(). If the current session's - refresh token is invalid, an error will be thrown. - """ - if not refresh_token: - session = await self.get_session() - if session: - refresh_token = session.refresh_token - if not refresh_token: - raise AuthSessionMissingError() - session = await self._call_refresh_token(refresh_token) - return AuthResponse(session=session, user=session.user) - - async def sign_out(self) -> None: - """ - Inside a browser context, `sign_out` will remove the logged in user from the - browser session and log them out - removing all items from localstorage and - then trigger a `"SIGNED_OUT"` event. - - For server-side management, you can revoke all refresh tokens for a user by - passing a user's JWT through to `api.sign_out`. - - There is no way to revoke a user's access token jwt until it expires. - It is recommended to set a shorter expiry on the jwt for this reason. - """ - session = await self.get_session() - access_token = session.access_token if session else None - if access_token: - await self.admin.sign_out(access_token) - await self._remove_session() - self._notify_all_subscribers("SIGNED_OUT", None) - - async def on_auth_state_change( - self, - callback: Callable[[AuthChangeEvent, Union[Session, None]], None], - ) -> Subscription: - """ - Receive a notification every time an auth event happens. - """ - unique_id = str(uuid4()) - - def _unsubscribe() -> None: - self._state_change_emitters.pop(unique_id) - - subscription = Subscription( - id=unique_id, - callback=callback, - unsubscribe=_unsubscribe, - ) - self._state_change_emitters[unique_id] = subscription - return subscription - - async def reset_password_email( - self, - email: str, - options: Options = {}, - ) -> None: - """ - Sends a password reset request to an email address. - """ - await self._request( - "POST", - "recover", - body={ - "email": email, - "gotrue_meta_security": { - "captcha_token": options.get("captcha_token"), - }, - }, - redirect_to=options.get("redirect_to"), - ) - - # MFA methods - - async def _enroll(self, params: MFAEnrollParams) -> AuthMFAEnrollResponse: - session = await self.get_session() - if not session: - raise AuthSessionMissingError() - response = await self._request( - "POST", - "factors", - body=params, - jwt=session.access_token, - xform=AuthMFAEnrollResponse.parse_obj, - ) - if response.totp.qr_code: - response.totp.qr_code = f"data:image/svg+xml;utf-8,{response.totp.qr_code}" - return response - - async def _challenge(self, params: MFAChallengeParams) -> AuthMFAChallengeResponse: - session = await self.get_session() - if not session: - raise AuthSessionMissingError() - return await self._request( - "POST", - f"factors/{params.get('factor_id')}/challenge", - jwt=session.access_token, - xform=AuthMFAChallengeResponse.parse_obj, - ) - - async def _challenge_and_verify( - self, - params: MFAChallengeAndVerifyParams, - ) -> AuthMFAVerifyResponse: - response = await self._challenge( - { - "factor_id": params.get("factor_id"), - } - ) - return await self._verify( - { - "factor_id": params.get("factor_id"), - "challenge_id": response.id, - "code": params.get("code"), - } - ) - - async def _verify(self, params: MFAVerifyParams) -> AuthMFAVerifyResponse: - session = await self.get_session() - if not session: - raise AuthSessionMissingError() - response = await self._request( - "POST", - f"factors/{params.get('factor_id')}/verify", - body=params, - jwt=session.access_token, - xform=AuthMFAVerifyResponse.parse_obj, - ) - session = Session.parse_obj(response.dict()) - await self._save_session(session) - self._notify_all_subscribers("MFA_CHALLENGE_VERIFIED", session) - return response - - async def _unenroll(self, params: MFAUnenrollParams) -> AuthMFAUnenrollResponse: - session = await self.get_session() - if not session: - raise AuthSessionMissingError() - return await self._request( - "DELETE", - f"factors/{params.get('factor_id')}", - jwt=session.access_token, - xform=AuthMFAUnenrollResponse.parse_obj, - ) - - async def _list_factors(self) -> AuthMFAListFactorsResponse: - response = await self.get_user() - all = response.user.factors or [] - totp = [f for f in all if f.factor_type == "totp" and f.status == "verified"] - return AuthMFAListFactorsResponse(all=all, totp=totp) - - async def _get_authenticator_assurance_level( - self, - ) -> AuthMFAGetAuthenticatorAssuranceLevelResponse: - session = await self.get_session() - if not session: - return AuthMFAGetAuthenticatorAssuranceLevelResponse( - current_level=None, - next_level=None, - current_authentication_methods=[], - ) - payload = self._decode_jwt(session.access_token) - current_level: Union[AuthenticatorAssuranceLevels, None] = None - if payload.get("aal"): - current_level = payload.get("aal") - verified_factors = [ - f for f in session.user.factors or [] if f.status == "verified" - ] - next_level = "aal2" if verified_factors else current_level - current_authentication_methods = payload.get("amr") or [] - return AuthMFAGetAuthenticatorAssuranceLevelResponse( - current_level=current_level, - next_level=next_level, - current_authentication_methods=current_authentication_methods, - ) - - # Private methods - - async def _remove_session(self) -> None: - if self._persist_session: - await self._storage.remove_item(self._storage_key) - else: - self._in_memory_session = None - if self._refresh_token_timer: - self._refresh_token_timer.cancel() - self._refresh_token_timer = None - - async def _get_session_from_url( - self, - url: str, - ) -> Tuple[Session, Union[str, None]]: - if not self._is_implicit_grant_flow(url): - raise AuthImplicitGrantRedirectError("Not a valid implicit grant flow url.") - result = urlparse(url) - params = parse_qs(result.query) - if error_description := self._get_param(params, "error_description"): - error_code = self._get_param(params, "error_code") - error = self._get_param(params, "error") - if not error_code: - raise AuthImplicitGrantRedirectError("No error_code detected.") - if not error: - raise AuthImplicitGrantRedirectError("No error detected.") - raise AuthImplicitGrantRedirectError( - error_description, - {"code": error_code, "error": error}, - ) - provider_token = self._get_param(params, "provider_token") - provider_refresh_token = self._get_param(params, "provider_refresh_token") - access_token = self._get_param(params, "access_token") - if not access_token: - raise AuthImplicitGrantRedirectError("No access_token detected.") - expires_in = self._get_param(params, "expires_in") - if not expires_in: - raise AuthImplicitGrantRedirectError("No expires_in detected.") - refresh_token = self._get_param(params, "refresh_token") - if not refresh_token: - raise AuthImplicitGrantRedirectError("No refresh_token detected.") - token_type = self._get_param(params, "token_type") - if not token_type: - raise AuthImplicitGrantRedirectError("No token_type detected.") - time_now = round(time()) - expires_at = time_now + int(expires_in) - user = await self.get_user(access_token) - session = Session( - provider_token=provider_token, - provider_refresh_token=provider_refresh_token, - access_token=access_token, - expires_in=int(expires_in), - expires_at=expires_at, - refresh_token=refresh_token, - token_type=token_type, - user=user.user, - ) - redirect_type = self._get_param(params, "type") - return session, redirect_type - - async def _recover_and_refresh(self) -> None: - raw_session = await self._storage.get_item(self._storage_key) - current_session = self._get_valid_session(raw_session) - if not current_session: - if raw_session: - await self._remove_session() - return - time_now = round(time()) - expires_at = current_session.expires_at - if expires_at and expires_at < time_now + EXPIRY_MARGIN: - refresh_token = current_session.refresh_token - if self._auto_refresh_token and refresh_token: - self._network_retries += 1 - try: - await self._call_refresh_token(refresh_token) - self._network_retries = 0 - except Exception as e: - if ( - isinstance(e, AuthRetryableError) - and self._network_retries < MAX_RETRIES - ): - if self._refresh_token_timer: - self._refresh_token_timer.cancel() - self._refresh_token_timer = Timer( - (RETRY_INTERVAL ** (self._network_retries * 100)), - self._recover_and_refresh, - ) - self._refresh_token_timer.start() - return - await self._remove_session() - return - if self._persist_session: - await self._save_session(current_session) - self._notify_all_subscribers("SIGNED_IN", current_session) - - async def _call_refresh_token(self, refresh_token: str) -> Session: - if not refresh_token: - raise AuthSessionMissingError() - response = await self._refresh_access_token(refresh_token) - if not response.session: - raise AuthSessionMissingError() - await self._save_session(response.session) - self._notify_all_subscribers("TOKEN_REFRESHED", response.session) - return response.session - - async def _refresh_access_token(self, refresh_token: str) -> AuthResponse: - return await self._request( - "POST", - "token", - query={"grant_type": "refresh_token"}, - body={"refresh_token": refresh_token}, - xform=parse_auth_response, - ) - - async def _save_session(self, session: Session) -> None: - if not self._persist_session: - self._in_memory_session = session - if expire_at := session.expires_at: - time_now = round(time()) - expire_in = expire_at - time_now - refresh_duration_before_expires = ( - EXPIRY_MARGIN if expire_in > EXPIRY_MARGIN else 0.5 - ) - value = (expire_in - refresh_duration_before_expires) * 1000 - await self._start_auto_refresh_token(value) - if self._persist_session and session.expires_at: - await self._storage.set_item(self._storage_key, session.json()) - - async def _start_auto_refresh_token(self, value: float) -> None: - if self._refresh_token_timer: - self._refresh_token_timer.cancel() - self._refresh_token_timer = None - if value <= 0 or not self._auto_refresh_token: - return - - async def refresh_token_function(): - self._network_retries += 1 - try: - session = await self.get_session() - if session: - await self._call_refresh_token(session.refresh_token) - self._network_retries = 0 - except Exception as e: - if ( - isinstance(e, AuthRetryableError) - and self._network_retries < MAX_RETRIES - ): - await self._start_auto_refresh_token( - RETRY_INTERVAL ** (self._network_retries * 100) - ) - - self._refresh_token_timer = Timer(value, refresh_token_function) - self._refresh_token_timer.start() - - def _notify_all_subscribers( - self, - event: AuthChangeEvent, - session: Union[Session, None], - ) -> None: - for subscription in self._state_change_emitters.values(): - subscription.callback(event, session) - - def _get_valid_session( - self, - raw_session: Union[str, None], - ) -> Union[Session, None]: - if not raw_session: - return None - data = loads(raw_session) - if not data: - return None - if not data.get("access_token"): - return None - if not data.get("refresh_token"): - return None - if not data.get("expires_at"): - return None - try: - expires_at = int(data["expires_at"]) - data["expires_at"] = expires_at - except ValueError: - return None - try: - return Session.parse_obj(data) - except Exception: - return None - - def _get_param( - self, - query_params: Dict[str, List[str]], - name: str, - ) -> Union[str, None]: - return query_params[name][0] if name in query_params else None - - def _is_implicit_grant_flow(self, url: str) -> bool: - result = urlparse(url) - params = parse_qs(result.query) - return "access_token" in params or "error_description" in params - - def _get_url_for_provider( - self, - provider: Provider, - params: Dict[str, str], - ) -> str: - params = {k: quote(v) for k, v in params.items()} - params["provider"] = quote(provider) - query = urlencode(params) - return f"{self._url}/authorize?{query}" - - def _decode_jwt(self, jwt: str) -> DecodedJWTDict: - """ - Decodes a JWT (without performing any validation). - """ - return decode_jwt_payload(jwt) From 91b0597e77efe4e5b3ad11c18188f6121fe75d20 Mon Sep 17 00:00:00 2001 From: "joel@joellee.org" Date: Sun, 22 Jan 2023 23:20:23 +0800 Subject: [PATCH 11/16] chore: patch module level import --- gotrue/__init__.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/gotrue/__init__.py b/gotrue/__init__.py index 6db087d8..1c050251 100644 --- a/gotrue/__init__.py +++ b/gotrue/__init__.py @@ -2,13 +2,13 @@ __version__ = "0.5.4" -from ._async.api import AsyncGoTrueAPI from ._async.client import AsyncGoTrueClient +from ._async.gotrue_admin_api import AsyncGoTrueAdminAPI from ._async.storage import AsyncMemoryStorage, AsyncSupportedStorage -from ._sync.api import SyncGoTrueAPI from ._sync.client import SyncGoTrueClient +from ._sync.gotrue_admin_api import SyncGoTrueAdminAPI from ._sync.storage import SyncMemoryStorage, SyncSupportedStorage from .types import * Client = SyncGoTrueClient -GoTrueAPI = SyncGoTrueAPI +GoTrueAPI = SyncGoTrueAdminAPI From 80907eb31ea7441a19fbb6b1e2006dd949d3649c Mon Sep 17 00:00:00 2001 From: "joel@joellee.org" Date: Sun, 22 Jan 2023 23:26:57 +0800 Subject: [PATCH 12/16] chore: build sync --- docs/source/api/api.rst | 2 +- gotrue/_async/client.py | 9 ++++++--- gotrue/_sync/client.py | 9 ++++++--- .../test_api_with_auto_confirm_disabled.py | 16 ++++++++-------- .../_async/test_api_with_auto_confirm_enabled.py | 11 +++++------ .../_async/test_client_with_sign_ups_disabled.py | 10 +++++----- .../_sync/test_api_with_auto_confirm_disabled.py | 16 ++++++++-------- .../_sync/test_api_with_auto_confirm_enabled.py | 11 +++++------ .../_sync/test_client_with_sign_ups_disabled.py | 10 +++++----- 9 files changed, 49 insertions(+), 45 deletions(-) diff --git a/docs/source/api/api.rst b/docs/source/api/api.rst index 2f2a673d..e1329b64 100644 --- a/docs/source/api/api.rst +++ b/docs/source/api/api.rst @@ -1,5 +1,5 @@ API ====== -.. autoclass:: gotrue._async.api.AsyncGoTrueAPI +.. autoclass:: gotrue._async.api.AsycnGoTrueAdminAPI :inherited-members: diff --git a/gotrue/_async/client.py b/gotrue/_async/client.py index c72fe99e..91fb204e 100644 --- a/gotrue/_async/client.py +++ b/gotrue/_async/client.py @@ -20,7 +20,7 @@ UserAttributes, UserAttributesDict, ) -from .api import AsyncGoTrueAPI +from .api import AsycnGoTrueAdminAPI from .storage import AsyncMemoryStorage, AsyncSupportedStorage @@ -34,7 +34,7 @@ def __init__( persist_session: bool = True, local_storage: AsyncSupportedStorage = AsyncMemoryStorage(), cookie_options: CookieOptions = CookieOptions.parse_obj(COOKIE_OPTIONS), - api: Optional[AsyncGoTrueAPI] = None, + api: Optional[AsycnGoTrueAdminAPI] = None, replace_default_headers: bool = False, ) -> None: """Create a new client @@ -72,7 +72,7 @@ def __init__( "headers": {**empty_or_default_headers, **headers}, "cookie_options": cookie_options, } - self.api = api or AsyncGoTrueAPI(**args) + self.api = api or AsyncGoTrueAdminAPI(**args) async def __aenter__(self) -> AsyncGoTrueClient: return self @@ -463,6 +463,9 @@ async def get_session_from_url( self._notify_all_subscribers(event=AuthChangeEvent.PASSWORD_RECOVERY) return session + async def get_session(self) -> None: + return None + async def sign_out(self) -> None: """Log the user out.""" access_token: Optional[str] = None diff --git a/gotrue/_sync/client.py b/gotrue/_sync/client.py index 3ce6b2de..4eaa1a0f 100644 --- a/gotrue/_sync/client.py +++ b/gotrue/_sync/client.py @@ -20,7 +20,7 @@ UserAttributes, UserAttributesDict, ) -from .api import SyncGoTrueAPI +from .api import AsycnGoTrueAdminAPI from .storage import SyncMemoryStorage, SyncSupportedStorage @@ -34,7 +34,7 @@ def __init__( persist_session: bool = True, local_storage: SyncSupportedStorage = SyncMemoryStorage(), cookie_options: CookieOptions = CookieOptions.parse_obj(COOKIE_OPTIONS), - api: Optional[SyncGoTrueAPI] = None, + api: Optional[AsycnGoTrueAdminAPI] = None, replace_default_headers: bool = False, ) -> None: """Create a new client @@ -72,7 +72,7 @@ def __init__( "headers": {**empty_or_default_headers, **headers}, "cookie_options": cookie_options, } - self.api = api or SyncGoTrueAPI(**args) + self.api = api or SyncGoTrueAdminAPI(**args) def __enter__(self) -> SyncGoTrueClient: return self @@ -459,6 +459,9 @@ def get_session_from_url( self._notify_all_subscribers(event=AuthChangeEvent.PASSWORD_RECOVERY) return session + def get_session(self) -> None: + return None + def sign_out(self) -> None: """Log the user out.""" access_token: Optional[str] = None diff --git a/tests/_async/test_api_with_auto_confirm_disabled.py b/tests/_async/test_api_with_auto_confirm_disabled.py index 3cd458dd..9564688c 100644 --- a/tests/_async/test_api_with_auto_confirm_disabled.py +++ b/tests/_async/test_api_with_auto_confirm_disabled.py @@ -3,7 +3,7 @@ import pytest from faker import Faker -from gotrue import AsyncGoTrueAPI +from gotrue import AsycnGoTrueAdminAPI from gotrue.constants import COOKIE_OPTIONS from gotrue.types import CookieOptions, LinkType, User @@ -12,8 +12,8 @@ @pytest.fixture(name="api") -async def create_api() -> AsyncIterable[AsyncGoTrueAPI]: - async with AsyncGoTrueAPI( +async def create_api() -> AsyncIterable[AsycnGoTrueAdminAPI]: + async with AsycnGoTrueAdminAPI( url=GOTRUE_URL, headers={"Authorization": f"Bearer {TOKEN}"}, cookie_options=CookieOptions.parse_obj(COOKIE_OPTIONS), @@ -27,7 +27,7 @@ async def create_api() -> AsyncIterable[AsyncGoTrueAPI]: password = fake.password() -async def test_sign_up_with_email_and_password(api: AsyncGoTrueAPI): +async def test_sign_up_with_email_and_password(api: AsycnGoTrueAdminAPI): try: response = await api.sign_up_with_email( email=email, @@ -44,7 +44,7 @@ async def test_sign_up_with_email_and_password(api: AsyncGoTrueAPI): password2 = fake.password() -async def test_generate_sign_up_link(api: AsyncGoTrueAPI): +async def test_generate_sign_up_link(api: AsycnGoTrueAdminAPI): try: response = await api.generate_link( type=LinkType.signup, @@ -61,7 +61,7 @@ async def test_generate_sign_up_link(api: AsyncGoTrueAPI): email3 = f"api_generate_link_signup_{fake.email().lower()}" -async def test_generate_magic_link(api: AsyncGoTrueAPI): +async def test_generate_magic_link(api: AsycnGoTrueAdminAPI): try: response = await api.generate_link( type=LinkType.magiclink, @@ -73,7 +73,7 @@ async def test_generate_magic_link(api: AsyncGoTrueAPI): assert False, str(e) -async def test_generate_invite_link(api: AsyncGoTrueAPI): +async def test_generate_invite_link(api: AsycnGoTrueAdminAPI): try: response = await api.generate_link( type=LinkType.invite, @@ -86,7 +86,7 @@ async def test_generate_invite_link(api: AsyncGoTrueAPI): @pytest.mark.depends(on=[test_sign_up_with_email_and_password.__name__]) -async def test_generate_recovery_link(api: AsyncGoTrueAPI): +async def test_generate_recovery_link(api: AsycnGoTrueAdminAPI): try: response = await api.generate_link( type=LinkType.recovery, diff --git a/tests/_async/test_api_with_auto_confirm_enabled.py b/tests/_async/test_api_with_auto_confirm_enabled.py index 13ff0411..3ff74d93 100644 --- a/tests/_async/test_api_with_auto_confirm_enabled.py +++ b/tests/_async/test_api_with_auto_confirm_enabled.py @@ -3,7 +3,6 @@ import pytest from faker import Faker -from gotrue import AsyncGoTrueAPI from gotrue.constants import COOKIE_OPTIONS from gotrue.types import CookieOptions, Session, User @@ -12,8 +11,8 @@ @pytest.fixture(name="api") -async def create_api() -> AsyncIterable[AsyncGoTrueAPI]: - async with AsyncGoTrueAPI( +async def create_api() -> AsyncIterable[AsyncGoTrueAdminAPI]: + async with AsyncGoTrueAdminAPI( url=GOTRUE_URL, headers={"Authorization": f"Bearer {TOKEN}"}, cookie_options=CookieOptions.parse_obj(COOKIE_OPTIONS), @@ -28,7 +27,7 @@ async def create_api() -> AsyncIterable[AsyncGoTrueAPI]: valid_session: Optional[Session] = None -async def test_sign_up_with_email(api: AsyncGoTrueAPI): +async def test_sign_up_with_email(api: AsyncGoTrueAdminAPI): global valid_session try: response = await api.sign_up_with_email( @@ -43,7 +42,7 @@ async def test_sign_up_with_email(api: AsyncGoTrueAPI): @pytest.mark.depends(on=[test_sign_up_with_email.__name__]) -async def test_get_user(api: AsyncGoTrueAPI): +async def test_get_user(api: AsyncGoTrueAdminAPI): try: jwt = valid_session.access_token if valid_session else "" response = await api.get_user(jwt=jwt) @@ -53,7 +52,7 @@ async def test_get_user(api: AsyncGoTrueAPI): @pytest.mark.depends(on=[test_get_user.__name__]) -async def test_delete_user(api: AsyncGoTrueAPI): +async def test_delete_user(api: AsyncGoTrueAdminAPI): try: jwt = valid_session.access_token if valid_session else "" user = await api.get_user(jwt=jwt) diff --git a/tests/_async/test_client_with_sign_ups_disabled.py b/tests/_async/test_client_with_sign_ups_disabled.py index cae4bbd8..51cef276 100644 --- a/tests/_async/test_client_with_sign_ups_disabled.py +++ b/tests/_async/test_client_with_sign_ups_disabled.py @@ -3,7 +3,7 @@ import pytest from faker import Faker -from gotrue import AsyncGoTrueAPI, AsyncGoTrueClient +from gotrue import AsycnGoTrueAdminAPI, AsyncGoTrueClient from gotrue.constants import COOKIE_OPTIONS, DEFAULT_HEADERS from gotrue.exceptions import APIError from gotrue.types import CookieOptions, LinkType, User, UserAttributes @@ -13,8 +13,8 @@ @pytest.fixture(name="auth_admin") -async def create_auth_admin() -> AsyncIterable[AsyncGoTrueAPI]: - async with AsyncGoTrueAPI( +async def create_auth_admin() -> AsyncIterable[AsycnGoTrueAdminAPI]: + async with AsycnGoTrueAdminAPI( url=GOTRUE_URL, headers={"Authorization": f"Bearer {AUTH_ADMIN_TOKEN}"}, cookie_options=CookieOptions.parse_obj(COOKIE_OPTIONS), @@ -53,7 +53,7 @@ async def test_sign_up(client: AsyncGoTrueClient): async def test_generate_link_should_be_able_to_generate_multiple_links( - auth_admin: AsyncGoTrueAPI, + auth_admin: AsycnGoTrueAdminAPI, ): try: response = await auth_admin.generate_link( @@ -103,7 +103,7 @@ async def test_generate_link_should_be_able_to_generate_multiple_links( email2 = fake.email().lower() -async def test_create_user(auth_admin: AsyncGoTrueAPI): +async def test_create_user(auth_admin: AsycnGoTrueAdminAPI): try: attributes = UserAttributes(email=email2) response = await auth_admin.create_user(attributes=attributes) diff --git a/tests/_sync/test_api_with_auto_confirm_disabled.py b/tests/_sync/test_api_with_auto_confirm_disabled.py index b87f489c..3fd66f6c 100644 --- a/tests/_sync/test_api_with_auto_confirm_disabled.py +++ b/tests/_sync/test_api_with_auto_confirm_disabled.py @@ -3,7 +3,7 @@ import pytest from faker import Faker -from gotrue import SyncGoTrueAPI +from gotrue import AsycnGoTrueAdminAPI from gotrue.constants import COOKIE_OPTIONS from gotrue.types import CookieOptions, LinkType, User @@ -12,8 +12,8 @@ @pytest.fixture(name="api") -def create_api() -> Iterable[SyncGoTrueAPI]: - with SyncGoTrueAPI( +def create_api() -> Iterable[AsycnGoTrueAdminAPI]: + with AsycnGoTrueAdminAPI( url=GOTRUE_URL, headers={"Authorization": f"Bearer {TOKEN}"}, cookie_options=CookieOptions.parse_obj(COOKIE_OPTIONS), @@ -27,7 +27,7 @@ def create_api() -> Iterable[SyncGoTrueAPI]: password = fake.password() -def test_sign_up_with_email_and_password(api: SyncGoTrueAPI): +def test_sign_up_with_email_and_password(api: AsycnGoTrueAdminAPI): try: response = api.sign_up_with_email( email=email, @@ -44,7 +44,7 @@ def test_sign_up_with_email_and_password(api: SyncGoTrueAPI): password2 = fake.password() -def test_generate_sign_up_link(api: SyncGoTrueAPI): +def test_generate_sign_up_link(api: AsycnGoTrueAdminAPI): try: response = api.generate_link( type=LinkType.signup, @@ -61,7 +61,7 @@ def test_generate_sign_up_link(api: SyncGoTrueAPI): email3 = f"api_generate_link_signup_{fake.email().lower()}" -def test_generate_magic_link(api: SyncGoTrueAPI): +def test_generate_magic_link(api: AsycnGoTrueAdminAPI): try: response = api.generate_link( type=LinkType.magiclink, @@ -73,7 +73,7 @@ def test_generate_magic_link(api: SyncGoTrueAPI): assert False, str(e) -def test_generate_invite_link(api: SyncGoTrueAPI): +def test_generate_invite_link(api: AsycnGoTrueAdminAPI): try: response = api.generate_link( type=LinkType.invite, @@ -86,7 +86,7 @@ def test_generate_invite_link(api: SyncGoTrueAPI): @pytest.mark.depends(on=[test_sign_up_with_email_and_password.__name__]) -def test_generate_recovery_link(api: SyncGoTrueAPI): +def test_generate_recovery_link(api: AsycnGoTrueAdminAPI): try: response = api.generate_link( type=LinkType.recovery, diff --git a/tests/_sync/test_api_with_auto_confirm_enabled.py b/tests/_sync/test_api_with_auto_confirm_enabled.py index 578646a8..d3830a2b 100644 --- a/tests/_sync/test_api_with_auto_confirm_enabled.py +++ b/tests/_sync/test_api_with_auto_confirm_enabled.py @@ -3,7 +3,6 @@ import pytest from faker import Faker -from gotrue import SyncGoTrueAPI from gotrue.constants import COOKIE_OPTIONS from gotrue.types import CookieOptions, Session, User @@ -12,8 +11,8 @@ @pytest.fixture(name="api") -def create_api() -> Iterable[SyncGoTrueAPI]: - with SyncGoTrueAPI( +def create_api() -> Iterable[SyncGoTrueAdminAPI]: + with SyncGoTrueAdminAPI( url=GOTRUE_URL, headers={"Authorization": f"Bearer {TOKEN}"}, cookie_options=CookieOptions.parse_obj(COOKIE_OPTIONS), @@ -28,7 +27,7 @@ def create_api() -> Iterable[SyncGoTrueAPI]: valid_session: Optional[Session] = None -def test_sign_up_with_email(api: SyncGoTrueAPI): +def test_sign_up_with_email(api: SyncGoTrueAdminAPI): global valid_session try: response = api.sign_up_with_email( @@ -43,7 +42,7 @@ def test_sign_up_with_email(api: SyncGoTrueAPI): @pytest.mark.depends(on=[test_sign_up_with_email.__name__]) -def test_get_user(api: SyncGoTrueAPI): +def test_get_user(api: SyncGoTrueAdminAPI): try: jwt = valid_session.access_token if valid_session else "" response = api.get_user(jwt=jwt) @@ -53,7 +52,7 @@ def test_get_user(api: SyncGoTrueAPI): @pytest.mark.depends(on=[test_get_user.__name__]) -def test_delete_user(api: SyncGoTrueAPI): +def test_delete_user(api: SyncGoTrueAdminAPI): try: jwt = valid_session.access_token if valid_session else "" user = api.get_user(jwt=jwt) diff --git a/tests/_sync/test_client_with_sign_ups_disabled.py b/tests/_sync/test_client_with_sign_ups_disabled.py index 0e0cb9c8..9ff48c0b 100644 --- a/tests/_sync/test_client_with_sign_ups_disabled.py +++ b/tests/_sync/test_client_with_sign_ups_disabled.py @@ -3,7 +3,7 @@ import pytest from faker import Faker -from gotrue import SyncGoTrueAPI, SyncGoTrueClient +from gotrue import AsycnGoTrueAdminAPI, SyncGoTrueClient from gotrue.constants import COOKIE_OPTIONS, DEFAULT_HEADERS from gotrue.exceptions import APIError from gotrue.types import CookieOptions, LinkType, User, UserAttributes @@ -13,8 +13,8 @@ @pytest.fixture(name="auth_admin") -def create_auth_admin() -> Iterable[SyncGoTrueAPI]: - with SyncGoTrueAPI( +def create_auth_admin() -> Iterable[AsycnGoTrueAdminAPI]: + with AsycnGoTrueAdminAPI( url=GOTRUE_URL, headers={"Authorization": f"Bearer {AUTH_ADMIN_TOKEN}"}, cookie_options=CookieOptions.parse_obj(COOKIE_OPTIONS), @@ -53,7 +53,7 @@ def test_sign_up(client: SyncGoTrueClient): def test_generate_link_should_be_able_to_generate_multiple_links( - auth_admin: SyncGoTrueAPI, + auth_admin: AsycnGoTrueAdminAPI, ): try: response = auth_admin.generate_link( @@ -103,7 +103,7 @@ def test_generate_link_should_be_able_to_generate_multiple_links( email2 = fake.email().lower() -def test_create_user(auth_admin: SyncGoTrueAPI): +def test_create_user(auth_admin: AsycnGoTrueAdminAPI): try: attributes = UserAttributes(email=email2) response = auth_admin.create_user(attributes=attributes) From 4cacf393d317eaee4deb0ff59a476f46df450526 Mon Sep 17 00:00:00 2001 From: "joel@joellee.org" Date: Sun, 22 Jan 2023 23:31:45 +0800 Subject: [PATCH 13/16] chore:rever types to original state --- gotrue/types.py | 242 +----------------------------------------------- 1 file changed, 2 insertions(+), 240 deletions(-) diff --git a/gotrue/types.py b/gotrue/types.py index 09b27ba3..fdfd294d 100644 --- a/gotrue/types.py +++ b/gotrue/types.py @@ -8,9 +8,9 @@ from uuid import UUID if sys.version_info >= (3, 8): - from typing import Literal, NotRequired, TypedDict + from typing import TypedDict else: - from typing_extensions import Literal, TypedDict, NotRequired + from typing_extensions import TypedDict from httpx import Response from pydantic import BaseModel, root_validator @@ -82,37 +82,6 @@ class User(BaseModelFromResponse): email_change_sent_at: Optional[datetime] = None new_phone: Optional[str] = None phone_change_sent_at: Optional[datetime] = None - factors: Union[List[Factor], None] = None - - -class Factor(BaseModel): - """ - A MFA factor. - """ - - id: str - """ - ID of the factor. - """ - friendly_name: Union[str, None] = None - """ - Friendly name of the factor, useful to disambiguate between multiple factors. - """ - factor_type: Union[Literal["totp"], str] - """ - Type of factor. Only `totp` supported with this version but may change in - future versions. - """ - status: Literal["verified", "unverified"] - """ - Factor's status. - """ - created_at: datetime - updated_at: datetime - - -class UpdatableFactorAttributes(TypedDict): - friendly_name: str class UserAttributes(BaseModelFromResponse): @@ -196,210 +165,3 @@ class UserAttributesDict(TypedDict, total=False): password: Optional[str] email_change_token: Optional[str] data: Optional[Any] - - -class MFAEnrollParams(TypedDict): - factor_type: Literal["totp"] - issuer: NotRequired[str] - friendly_name: NotRequired[str] - - -class MFAUnenrollParams(TypedDict): - factor_id: str - """ - ID of the factor being unenrolled. - """ - - -class MFAVerifyParams(TypedDict): - factor_id: str - """ - ID of the factor being verified. - """ - challenge_id: str - """ - ID of the challenge being verified. - """ - code: str - """ - Verification code provided by the user. - """ - - -class MFAChallengeParams(TypedDict): - factor_id: str - """ - ID of the factor to be challenged. - """ - - -class MFAChallengeAndVerifyParams(TypedDict): - factor_id: str - """ - ID of the factor being verified. - """ - code: str - """ - Verification code provided by the user. - """ - - -class AuthMFAVerifyResponse(BaseModel): - access_token: str - """ - New access token (JWT) after successful verification. - """ - token_type: str - """ - Type of token, typically `Bearer`. - """ - expires_in: int - """ - Number of seconds in which the access token will expire. - """ - refresh_token: str - """ - Refresh token you can use to obtain new access tokens when expired. - """ - user: User - """ - Updated user profile. - """ - - -class AuthMFAEnrollResponseTotp(BaseModel): - qr_code: str - """ - Contains a QR code encoding the authenticator URI. You can - convert it to a URL by prepending `data:image/svg+xml;utf-8,` to - the value. Avoid logging this value to the console. - """ - secret: str - """ - The TOTP secret (also encoded in the QR code). Show this secret - in a password-style field to the user, in case they are unable to - scan the QR code. Avoid logging this value to the console. - """ - uri: str - """ - The authenticator URI encoded within the QR code, should you need - to use it. Avoid loggin this value to the console. - """ - - -class AuthMFAEnrollResponse(BaseModel): - id: str - """ - ID of the factor that was just enrolled (in an unverified state). - """ - type: Literal["totp"] - """ - Type of MFA factor. Only `totp` supported for now. - """ - totp: AuthMFAEnrollResponseTotp - """ - TOTP enrollment information. - """ - - -class AuthMFAUnenrollResponse(BaseModel): - id: str - """ - ID of the factor that was successfully unenrolled. - """ - - -class AuthMFAChallengeResponse(BaseModel): - id: str - """ - ID of the newly created challenge. - """ - expires_at: int - """ - Timestamp in UNIX seconds when this challenge will no longer be usable. - """ - - -class AuthMFAListFactorsResponse(BaseModel): - all: List[Factor] - """ - All available factors (verified and unverified). - """ - totp: List[Factor] - """ - Only verified TOTP factors. (A subset of `all`.) - """ - - -AuthenticatorAssuranceLevels = Literal["aal1", "aal2"] - - -class AuthMFAGetAuthenticatorAssuranceLevelResponse(BaseModel): - current_level: Union[AuthenticatorAssuranceLevels, None] = None - """ - Current AAL level of the session. - """ - next_level: Union[AuthenticatorAssuranceLevels, None] = None - """ - Next possible AAL level for the session. If the next level is higher - than the current one, the user should go through MFA. - """ - current_authentication_methods: List[AMREntry] - """ - A list of all authentication methods attached to this session. Use - the information here to detect the last time a user verified a - factor, for example if implementing a step-up scenario. - """ - - -class AuthMFAAdminDeleteFactorResponse(BaseModel): - id: str - """ - ID of the factor that was successfully deleted. - """ - - -class AuthMFAAdminDeleteFactorParams(TypedDict): - id: str - """ - ID of the MFA factor to delete. - """ - user_id: str - """ - ID of the user whose factor is being deleted. - """ - - -class AuthMFAAdminListFactorsResponse(BaseModel): - factors: List[Factor] - """ - All factors attached to the user. - """ - - -class AuthMFAAdminListFactorsParams(TypedDict): - user_id: str - """ - ID of the user for which to list all MFA factors. - """ - - -class DecodedJWTDict(TypedDict): - exp: NotRequired[int] - aal: NotRequired[Union[AuthenticatorAssuranceLevels, None]] - amr: NotRequired[Union[List[AMREntry], None]] - - -AMREntry.update_forward_refs() -UserResponse.update_forward_refs() -Factor.update_forward_refs() -User.update_forward_refs() -AuthMFAVerifyResponse.update_forward_refs() -AuthMFAEnrollResponseTotp.update_forward_refs() -AuthMFAEnrollResponse.update_forward_refs() -AuthMFAUnenrollResponse.update_forward_refs() -AuthMFAChallengeResponse.update_forward_refs() -AuthMFAListFactorsResponse.update_forward_refs() -AuthMFAGetAuthenticatorAssuranceLevelResponse.update_forward_refs() -AuthMFAAdminDeleteFactorResponse.update_forward_refs() -AuthMFAAdminListFactorsResponse.update_forward_refs() From 5621602bd3efc4c0bb0db172db7b0347493f58c2 Mon Sep 17 00:00:00 2001 From: "joel@joellee.org" Date: Sun, 22 Jan 2023 23:37:22 +0800 Subject: [PATCH 14/16] chore: update naming --- docs/source/api/api.rst | 2 +- gotrue/_async/client.py | 2 +- gotrue/_sync/client.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/source/api/api.rst b/docs/source/api/api.rst index e1329b64..fcdb4fb9 100644 --- a/docs/source/api/api.rst +++ b/docs/source/api/api.rst @@ -1,5 +1,5 @@ API ====== -.. autoclass:: gotrue._async.api.AsycnGoTrueAdminAPI +.. autoclass:: gotrue._async.gotrue_admin_api.AsycnGoTrueAdminAPI :inherited-members: diff --git a/gotrue/_async/client.py b/gotrue/_async/client.py index 91fb204e..8204b940 100644 --- a/gotrue/_async/client.py +++ b/gotrue/_async/client.py @@ -20,7 +20,7 @@ UserAttributes, UserAttributesDict, ) -from .api import AsycnGoTrueAdminAPI +from .gotrue_admin_api import AsycnGoTrueAdminAPI from .storage import AsyncMemoryStorage, AsyncSupportedStorage diff --git a/gotrue/_sync/client.py b/gotrue/_sync/client.py index 4eaa1a0f..6e91936d 100644 --- a/gotrue/_sync/client.py +++ b/gotrue/_sync/client.py @@ -20,7 +20,7 @@ UserAttributes, UserAttributesDict, ) -from .api import AsycnGoTrueAdminAPI +from .gotrue_admin_api import AsycnGoTrueAdminAPI from .storage import SyncMemoryStorage, SyncSupportedStorage From 130e493fbaa5fe0b86f80a5d025aef0b5252ee25 Mon Sep 17 00:00:00 2001 From: "joel@joellee.org" Date: Sun, 22 Jan 2023 23:43:05 +0800 Subject: [PATCH 15/16] chore: update helpers.py --- gotrue/helpers.py | 86 ++++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 77 insertions(+), 9 deletions(-) diff --git a/gotrue/helpers.py b/gotrue/helpers.py index 069b79b7..6aa73025 100644 --- a/gotrue/helpers.py +++ b/gotrue/helpers.py @@ -1,18 +1,86 @@ from __future__ import annotations -from urllib.parse import quote +from base64 import b64decode +from json import loads +from typing import Any, Union, cast -from httpx import HTTPError, Response +from httpx import HTTPStatusError -from .exceptions import APIError +from .errors import AuthApiError, AuthError, AuthRetryableError, AuthUnknownError +from .types import ( + AuthResponse, + GenerateLinkProperties, + GenerateLinkResponse, + Session, + User, + UserResponse, +) -def encode_uri_component(uri: str) -> str: - return quote(uri.encode("utf-8")) +def parse_auth_response(data: Any) -> AuthResponse: + session: Union[Session, None] = None + if ( + "access_token" in data + and "refresh_token" in data + and "expires_in" in data + and data["access_token"] + and data["refresh_token"] + and data["expires_in"] + ): + session = Session.parse_obj(data) + user = User.parse_obj(data["user"]) if "user" in data else User.parse_obj(data) + return AuthResponse(session=session, user=user) -def check_response(response: Response) -> None: +def parse_link_response(data: Any) -> GenerateLinkResponse: + properties = GenerateLinkProperties( + action_link=data.get("action_link"), + email_otp=data.get("email_otp"), + hashed_token=data.get("hashed_token"), + redirect_to=data.get("redirect_to"), + verification_type=data.get("verification_type"), + ) + user = User.parse_obj({k: v for k, v in data.items() if k not in properties.dict()}) + return GenerateLinkResponse(properties=properties, user=user) + + +def parse_user_response(data: Any) -> UserResponse: + if "user" not in data: + data = {"user": data} + return UserResponse.parse_obj(data) + + +def get_error_message(error: Any) -> str: + props = ["msg", "message", "error_description", "error"] + filter = ( + lambda prop: prop in error if isinstance(error, dict) else hasattr(error, prop) + ) + return next((error[prop] for prop in props if filter(prop)), str(error)) + + +def looks_like_http_status_error(exception: Exception) -> bool: + return isinstance(exception, HTTPStatusError) + + +def handle_exception(exception: Exception) -> AuthError: + if not looks_like_http_status_error(exception): + return AuthRetryableError(get_error_message(exception), 0) + error = cast(HTTPStatusError, exception) try: - response.raise_for_status() - except HTTPError: - raise APIError.from_dict(response.json()) + network_error_codes = [502, 503, 504] + if error.response.status_code in network_error_codes: + return AuthRetryableError( + get_error_message(error), error.response.status_code + ) + json = error.response.json() + return AuthApiError(get_error_message(json), error.response.status_code or 500) + except Exception as e: + return AuthUnknownError(get_error_message(error), e) + + +def decode_jwt_payload(token: str) -> Any: + parts = token.split(".") + if len(parts) != 3: + raise ValueError("JWT is not valid: not a JWT structure") + base64Url = parts[1] + return loads(b64decode(base64Url).decode("utf-8")) From a0ebbac032b289fe35602f5113837ed85dcd1bb4 Mon Sep 17 00:00:00 2001 From: "joel@joellee.org" Date: Sun, 22 Jan 2023 23:48:12 +0800 Subject: [PATCH 16/16] chore: add errors file --- gotrue/errors.py | 115 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 115 insertions(+) create mode 100644 gotrue/errors.py diff --git a/gotrue/errors.py b/gotrue/errors.py new file mode 100644 index 00000000..742d5d44 --- /dev/null +++ b/gotrue/errors.py @@ -0,0 +1,115 @@ +from __future__ import annotations + +from typing import Union + +from typing_extensions import TypedDict + + +class AuthError(Exception): + def __init__(self, message: str) -> None: + Exception.__init__(self, message) + self.message = message + self.name = "AuthError" + + +class AuthApiErrorDict(TypedDict): + name: str + message: str + status: int + + +class AuthApiError(AuthError): + def __init__(self, message: str, status: int) -> None: + AuthError.__init__(self, message) + self.name = "AuthApiError" + self.status = status + + def to_dict(self) -> AuthApiErrorDict: + return { + "name": self.name, + "message": self.message, + "status": self.status, + } + + +class AuthUnknownError(AuthError): + def __init__(self, message: str, original_error: Exception) -> None: + AuthError.__init__(self, message) + self.name = "AuthUnknownError" + self.original_error = original_error + + +class CustomAuthError(AuthError): + def __init__(self, message: str, name: str, status: int) -> None: + AuthError.__init__(self, message) + self.name = name + self.status = status + + def to_dict(self) -> AuthApiErrorDict: + return { + "name": self.name, + "message": self.message, + "status": self.status, + } + + +class AuthSessionMissingError(CustomAuthError): + def __init__(self) -> None: + CustomAuthError.__init__( + self, + "Auth session missing!", + "AuthSessionMissingError", + 400, + ) + + +class AuthInvalidCredentialsError(CustomAuthError): + def __init__(self, message: str) -> None: + CustomAuthError.__init__( + self, + message, + "AuthInvalidCredentialsError", + 400, + ) + + +class AuthImplicitGrantRedirectErrorDetails(TypedDict): + error: str + code: str + + +class AuthImplicitGrantRedirectErrorDict(AuthApiErrorDict): + details: Union[AuthImplicitGrantRedirectErrorDetails, None] + + +class AuthImplicitGrantRedirectError(CustomAuthError): + def __init__( + self, + message: str, + details: Union[AuthImplicitGrantRedirectErrorDetails, None] = None, + ) -> None: + CustomAuthError.__init__( + self, + message, + "AuthImplicitGrantRedirectError", + 500, + ) + self.details = details + + def to_dict(self) -> AuthImplicitGrantRedirectErrorDict: + return { + "name": self.name, + "message": self.message, + "status": self.status, + "details": self.details, + } + + +class AuthRetryableError(CustomAuthError): + def __init__(self, message: str, status: int) -> None: + CustomAuthError.__init__( + self, + message, + "AuthRetryableError", + status, + )