diff --git a/docs/source/api/api.rst b/docs/source/api/api.rst index 2f2a673d..fcdb4fb9 100644 --- a/docs/source/api/api.rst +++ b/docs/source/api/api.rst @@ -1,5 +1,5 @@ API ====== -.. autoclass:: gotrue._async.api.AsyncGoTrueAPI +.. autoclass:: gotrue._async.gotrue_admin_api.AsycnGoTrueAdminAPI :inherited-members: diff --git a/gotrue/__init__.py b/gotrue/__init__.py index 6db087d8..1c050251 100644 --- a/gotrue/__init__.py +++ b/gotrue/__init__.py @@ -2,13 +2,13 @@ __version__ = "0.5.4" -from ._async.api import AsyncGoTrueAPI from ._async.client import AsyncGoTrueClient +from ._async.gotrue_admin_api import AsyncGoTrueAdminAPI from ._async.storage import AsyncMemoryStorage, AsyncSupportedStorage -from ._sync.api import SyncGoTrueAPI from ._sync.client import SyncGoTrueClient +from ._sync.gotrue_admin_api import SyncGoTrueAdminAPI from ._sync.storage import SyncMemoryStorage, SyncSupportedStorage from .types import * Client = SyncGoTrueClient -GoTrueAPI = SyncGoTrueAPI +GoTrueAPI = SyncGoTrueAdminAPI diff --git a/gotrue/_async/api.py b/gotrue/_async/api.py deleted file mode 100644 index 77fbf84a..00000000 --- a/gotrue/_async/api.py +++ /dev/null @@ -1,642 +0,0 @@ -from __future__ import annotations - -from typing import Any, Dict, List, Optional, Union - -from pydantic import parse_obj_as - -from ..exceptions import APIError -from ..helpers import check_response, encode_uri_component -from ..http_clients import AsyncClient -from ..types import ( - CookieOptions, - LinkType, - Provider, - Session, - User, - UserAttributes, - determine_session_or_user_model_from_response, -) - - -class AsyncGoTrueAPI: - def __init__( - self, - *, - url: str, - headers: Dict[str, str], - cookie_options: CookieOptions, - http_client: Optional[AsyncClient] = None, - ) -> None: - """Initialise API class.""" - self.url = url - self.headers = headers - self.cookie_options = cookie_options - self.http_client = http_client or AsyncClient() - - async def __aenter__(self) -> AsyncGoTrueAPI: - return self - - async def __aexit__(self, exc_t, exc_v, exc_tb) -> None: - await self.close() - - async def close(self) -> None: - await self.http_client.aclose() - - async def create_user(self, *, attributes: UserAttributes) -> User: - """Creates a new user. - - This function should only be called on a server. - Never expose your `service_role` key in the browser. - - Parameters - ---------- - attributes: UserAttributes - The data you want to create the user with. - - Returns - ------- - response : User - The created user - - Raises - ------ - APIError - If an error occurs. - """ - headers = self.headers - data = attributes.dict() - url = f"{self.url}/admin/users" - response = await self.http_client.post(url, json=data, headers=headers) - return User.parse_response(response) - - async def list_users(self) -> List[User]: - """Get a list of users. - - This function should only be called on a server. - Never expose your `service_role` key in the browser. - - Returns - ------- - response : List[User] - A list of users - - Raises - ------ - APIError - If an error occurs. - """ - headers = self.headers - url = f"{self.url}/admin/users" - response = await self.http_client.get(url, headers=headers) - check_response(response) - users = response.json().get("users") - if users is None: - raise APIError("No users found in response", 400) - if not isinstance(users, list): - raise APIError("Expected a list of users", 400) - return parse_obj_as(List[User], users) - - async def sign_up_with_email( - self, - *, - email: str, - password: str, - redirect_to: Optional[str] = None, - data: Optional[Dict[str, Any]] = None, - ) -> Union[Session, User]: - """Creates a new user using their email address. - - Parameters - ---------- - email : str - The email address of the user. - password : str - The password of the user. - redirect_to : Optional[str] - A URL or mobile address to send the user to after they are confirmed. - data : Optional[Dict[str, Any]] - Optional user metadata. - - Returns - ------- - response : Union[Session, User] - A logged-in session if the server has "autoconfirm" ON - A user if the server has "autoconfirm" OFF - - Raises - ------ - APIError - If an error occurs. - """ - headers = self.headers - query_string = "" - if redirect_to: - redirect_to_encoded = encode_uri_component(redirect_to) - query_string = f"?redirect_to={redirect_to_encoded}" - data = {"email": email, "password": password, "data": data} - url = f"{self.url}/signup{query_string}" - response = await self.http_client.post(url, json=data, headers=headers) - SessionOrUserModel = determine_session_or_user_model_from_response(response) - return SessionOrUserModel.parse_response(response) - - async def sign_in_with_email( - self, - *, - email: str, - password: str, - redirect_to: Optional[str] = None, - ) -> Session: - """Logs in an existing user using their email address. - - Parameters - ---------- - email : str - The email address of the user. - password : str - The password of the user. - redirect_to : Optional[str] - A URL or mobile address to send the user to after they are confirmed. - - Returns - ------- - response : Session - A logged-in session - - Raises - ------ - APIError - If an error occurs. - """ - - headers = self.headers - query_string = "?grant_type=password" - if redirect_to: - redirect_to_encoded = encode_uri_component(redirect_to) - query_string += f"&redirect_to={redirect_to_encoded}" - data = {"email": email, "password": password} - url = f"{self.url}/token{query_string}" - response = await self.http_client.post(url, json=data, headers=headers) - return Session.parse_response(response) - - async def sign_up_with_phone( - self, - *, - phone: str, - password: str, - data: Optional[Dict[str, Any]] = None, - ) -> Union[Session, User]: - """Signs up a new user using their phone number and a password. - - Parameters - ---------- - phone : str - The phone number of the user. - password : str - The password of the user. - data : Optional[Dict[str, Any]] - Optional user metadata. - - Returns - ------- - response : Union[Session, User] - A logged-in session if the server has "autoconfirm" ON - A user if the server has "autoconfirm" OFF - - Raises - ------ - APIError - If an error occurs. - """ - headers = self.headers - data = {"phone": phone, "password": password, "data": data} - url = f"{self.url}/signup" - response = await self.http_client.post(url, json=data, headers=headers) - SessionOrUserModel = determine_session_or_user_model_from_response(response) - return SessionOrUserModel.parse_response(response) - - async def sign_in_with_phone( - self, - *, - phone: str, - password: str, - ) -> Session: - """Logs in an existing user using their phone number and password. - - Parameters - ---------- - phone : str - The phone number of the user. - password : str - The password of the user. - - Returns - ------- - response : Session - A logged-in session - - Raises - ------ - APIError - If an error occurs. - """ - data = {"phone": phone, "password": password} - url = f"{self.url}/token?grant_type=password" - headers = self.headers - response = await self.http_client.post(url, json=data, headers=headers) - return Session.parse_response(response) - - async def send_magic_link_email( - self, - *, - email: str, - create_user: bool, - redirect_to: Optional[str] = None, - ) -> None: - """Sends a magic login link to an email address. - - Parameters - ---------- - email : str - The email address of the user. - redirect_to : Optional[str] - A URL or mobile address to send the user to after they are confirmed. - - Raises - ------ - APIError - If an error occurs. - """ - headers = self.headers - query_string = "" - if redirect_to: - redirect_to_encoded = encode_uri_component(redirect_to) - query_string = f"?redirect_to={redirect_to_encoded}" - data = {"email": email, "create_user": create_user} - url = f"{self.url}/magiclink{query_string}" - response = await self.http_client.post(url, json=data, headers=headers) - return check_response(response) - - async def send_mobile_otp(self, *, phone: str, create_user: bool) -> None: - """Sends a mobile OTP via SMS. Will register the account if it doesn't already exist - - Parameters - ---------- - phone : str - The user's phone number WITH international prefix - - Raises - ------ - APIError - If an error occurs. - """ - headers = self.headers - data = {"phone": phone, "create_user": create_user} - url = f"{self.url}/otp" - response = await self.http_client.post(url, json=data, headers=headers) - return check_response(response) - - async def verify_mobile_otp( - self, - *, - phone: str, - token: str, - redirect_to: Optional[str] = None, - ) -> Union[Session, User]: - """Send User supplied Mobile OTP to be verified - - Parameters - ---------- - phone : str - The user's phone number WITH international prefix - token : str - Token that user was sent to their mobile phone - redirect_to : Optional[str] - A URL or mobile address to send the user to after they are confirmed. - - Returns - ------- - response : Union[Session, User] - A logged-in session if the server has "autoconfirm" ON - A user if the server has "autoconfirm" OFF - - Raises - ------ - APIError - If an error occurs. - """ - headers = self.headers - data = { - "phone": phone, - "token": token, - "type": "sms", - } - if redirect_to: - redirect_to_encoded = encode_uri_component(redirect_to) - data["redirect_to"] = redirect_to_encoded - url = f"{self.url}/verify" - response = await self.http_client.post(url, json=data, headers=headers) - SessionOrUserModel = determine_session_or_user_model_from_response(response) - return SessionOrUserModel.parse_response(response) - - async def invite_user_by_email( - self, - *, - email: str, - redirect_to: Optional[str] = None, - data: Optional[Dict[str, Any]] = None, - ) -> User: - """Sends an invite link to an email address. - - Parameters - ---------- - email : str - The email address of the user. - redirect_to : Optional[str] - A URL or mobile address to send the user to after they are confirmed. - data : Optional[Dict[str, Any]] - Optional user metadata. - - Returns - ------- - response : User - A user - - Raises - ------ - APIError - If an error occurs. - """ - headers = self.headers - query_string = "" - if redirect_to: - redirect_to_encoded = encode_uri_component(redirect_to) - query_string = f"?redirect_to={redirect_to_encoded}" - data = {"email": email, "data": data} - url = f"{self.url}/invite{query_string}" - response = await self.http_client.post(url, json=data, headers=headers) - return User.parse_response(response) - - async def reset_password_for_email( - self, - *, - email: str, - redirect_to: Optional[str] = None, - ) -> None: - """Sends a reset request to an email address. - - Parameters - ---------- - email : str - The email address of the user. - redirect_to : Optional[str] - A URL or mobile address to send the user to after they are confirmed. - - Raises - ------ - APIError - If an error occurs. - """ - headers = self.headers - query_string = "" - if redirect_to: - redirect_to_encoded = encode_uri_component(redirect_to) - query_string = f"?redirect_to={redirect_to_encoded}" - data = {"email": email} - url = f"{self.url}/recover{query_string}" - response = await self.http_client.post(url, json=data, headers=headers) - return check_response(response) - - def _create_request_headers(self, *, jwt: str) -> Dict[str, str]: - """Create temporary object. - - Create a temporary object with all configured headers and adds the - Authorization token to be used on request methods. - - Parameters - ---------- - jwt : str - A valid, logged-in JWT. - - Returns - ------- - headers : dict of str - The headers required for a successful request statement with the - supabase backend. - """ - headers = {**self.headers, "Authorization": f"Bearer {jwt}"} - return headers - - async def sign_out(self, *, jwt: str) -> None: - """Removes a logged-in session. - - Parameters - ---------- - jwt : str - A valid, logged-in JWT. - """ - headers = self._create_request_headers(jwt=jwt) - url = f"{self.url}/logout" - await self.http_client.post(url, headers=headers) - - async def get_url_for_provider( - self, - *, - provider: Provider, - redirect_to: Optional[str] = None, - scopes: Optional[str] = None, - ) -> str: - """Generates the relevant login URL for a third-party provider. - - Parameters - ---------- - provider : Provider - One of the providers supported by GoTrue. - redirect_to : Optional[str] - A URL or mobile address to send the user to after they are confirmed. - scopes : Optional[str] - A space-separated list of scopes granted to the OAuth application. - - Returns - ------- - url : str - The URL to redirect the user to. - - Raises - ------ - APIError - If an error occurs. - """ - url_params = [f"provider={encode_uri_component(provider)}"] - if redirect_to: - redirect_to_encoded = encode_uri_component(redirect_to) - url_params.append(f"redirect_to={redirect_to_encoded}") - if scopes: - url_params.append(f"scopes={encode_uri_component(scopes)}") - return f"{self.url}/authorize?{'&'.join(url_params)}" - - async def get_user(self, *, jwt: str) -> User: - """Gets the user details. - - Parameters - ---------- - jwt : str - A valid, logged-in JWT. - - Returns - ------- - response : User - A user - - Raises - ------ - APIError - If an error occurs. - """ - headers = self._create_request_headers(jwt=jwt) - url = f"{self.url}/user" - response = await self.http_client.get(url, headers=headers) - return User.parse_response(response) - - async def update_user( - self, - *, - jwt: str, - attributes: UserAttributes, - ) -> User: - """ - Updates the user data. - - Parameters - ---------- - jwt : str - A valid, logged-in JWT. - attributes : UserAttributes - The data you want to update. - - Returns - ------- - response : User - A user - - Raises - ------ - APIError - If an error occurs. - """ - headers = self._create_request_headers(jwt=jwt) - data = attributes.dict() - url = f"{self.url}/user" - response = await self.http_client.put(url, json=data, headers=headers) - return User.parse_response(response) - - async def delete_user(self, *, uid: str, jwt: str) -> None: - """Delete a user. Requires a `service_role` key. - - This function should only be called on a server. - Never expose your `service_role` key in the browser. - - Parameters - ---------- - uid : str - The user uid you want to remove. - jwt : str - A valid, logged-in JWT. - - Returns - ------- - response : User - A user - - Raises - ------ - APIError - If an error occurs. - """ - headers = self._create_request_headers(jwt=jwt) - url = f"{self.url}/admin/users/{uid}" - response = await self.http_client.delete(url, headers=headers) - return check_response(response) - - async def refresh_access_token(self, *, refresh_token: str) -> Session: - """Generates a new JWT. - - Parameters - ---------- - refresh_token : str - A valid refresh token that was returned on login. - - Returns - ------- - response : Session - A session - - Raises - ------ - APIError - If an error occurs. - """ - data = {"refresh_token": refresh_token} - url = f"{self.url}/token?grant_type=refresh_token" - headers = self.headers - response = await self.http_client.post(url, json=data, headers=headers) - return Session.parse_response(response) - - async def generate_link( - self, - *, - type: LinkType, - email: str, - password: Optional[str] = None, - redirect_to: Optional[str] = None, - data: Optional[Dict[str, Any]] = None, - ) -> Union[Session, User]: - """ - Generates links to be sent via email or other. - - Parameters - ---------- - type : LinkType - The link type ("signup" or "magiclink" or "recovery" or "invite"). - email : str - The user's email. - password : Optional[str] - User password. For signup only. - redirect_to : Optional[str] - The link type ("signup" or "magiclink" or "recovery" or "invite"). - data : Optional[Dict[str, Any]] - Optional user metadata. For signup only. - - Returns - ------- - response : Union[Session, User] - A logged-in session if the server has "autoconfirm" ON - A user if the server has "autoconfirm" OFF - - Raises - ------ - APIError - If an error occurs. - """ - headers = self.headers - data = { - "type": type, - "email": email, - "data": data, - } - if password: - data["password"] = password - if redirect_to: - redirect_to_encoded = encode_uri_component(redirect_to) - data["redirect_to"] = redirect_to_encoded - url = f"{self.url}/admin/generate_link" - response = await self.http_client.post(url, json=data, headers=headers) - SessionOrUserModel = determine_session_or_user_model_from_response(response) - return SessionOrUserModel.parse_response(response) - - async def set_auth_cookie(self, *, req, res): - """Stub for parity with JS api.""" - raise NotImplementedError("set_auth_cookie not implemented.") - - async def get_user_by_cookie(self, *, req): - """Stub for parity with JS api.""" - raise NotImplementedError("get_user_by_cookie not implemented.") diff --git a/gotrue/_async/client.py b/gotrue/_async/client.py index c72fe99e..8204b940 100644 --- a/gotrue/_async/client.py +++ b/gotrue/_async/client.py @@ -20,7 +20,7 @@ UserAttributes, UserAttributesDict, ) -from .api import AsyncGoTrueAPI +from .gotrue_admin_api import AsycnGoTrueAdminAPI from .storage import AsyncMemoryStorage, AsyncSupportedStorage @@ -34,7 +34,7 @@ def __init__( persist_session: bool = True, local_storage: AsyncSupportedStorage = AsyncMemoryStorage(), cookie_options: CookieOptions = CookieOptions.parse_obj(COOKIE_OPTIONS), - api: Optional[AsyncGoTrueAPI] = None, + api: Optional[AsycnGoTrueAdminAPI] = None, replace_default_headers: bool = False, ) -> None: """Create a new client @@ -72,7 +72,7 @@ def __init__( "headers": {**empty_or_default_headers, **headers}, "cookie_options": cookie_options, } - self.api = api or AsyncGoTrueAPI(**args) + self.api = api or AsyncGoTrueAdminAPI(**args) async def __aenter__(self) -> AsyncGoTrueClient: return self @@ -463,6 +463,9 @@ async def get_session_from_url( self._notify_all_subscribers(event=AuthChangeEvent.PASSWORD_RECOVERY) return session + async def get_session(self) -> None: + return None + async def sign_out(self) -> None: """Log the user out.""" access_token: Optional[str] = None diff --git a/gotrue/_async/gotrue_admin_api.py b/gotrue/_async/gotrue_admin_api.py new file mode 100644 index 00000000..7236bd97 --- /dev/null +++ b/gotrue/_async/gotrue_admin_api.py @@ -0,0 +1,147 @@ +from __future__ import annotations + +from typing import Dict, List, Union + +from ..helpers import parse_link_response, parse_user_response +from ..http_clients import AsyncClient +from ..types import ( + AdminUserAttributes, + GenerateLinkParams, + GenerateLinkResponse, + Options, + User, + UserResponse, +) +from .gotrue_base_api import AsyncGoTrueBaseAPI + + +class AsyncGoTrueAdminAPI(AsyncGoTrueBaseAPI): + def __init__( + self, + *, + url: str = "", + headers: Dict[str, str] = {}, + http_client: Union[AsyncClient, None] = None, + ) -> None: + AsyncGoTrueBaseAPI.__init__( + self, + url=url, + headers=headers, + http_client=http_client, + ) + + async def sign_out(self, jwt: str) -> None: + """ + Removes a logged-in session. + """ + return await self._request( + "POST", + "logout", + jwt=jwt, + no_resolve_json=True, + ) + + async def invite_user_by_email( + self, + email: str, + options: Options = {}, + ) -> UserResponse: + """ + Sends an invite link to an email address. + """ + return await self._request( + "POST", + "invite", + body={"email": email, "data": options.get("data")}, + redirect_to=options.get("redirect_to"), + xform=parse_user_response, + ) + + async def generate_link(self, params: GenerateLinkParams) -> GenerateLinkResponse: + """ + Generates email links and OTPs to be sent via a custom email provider. + """ + return await self._request( + "POST", + "admin/generate_link", + body={ + "type": params.get("type"), + "email": params.get("email"), + "password": params.get("password"), + "new_email": params.get("new_email"), + "data": params.get("options", {}).get("data"), + }, + redirect_to=params.get("options", {}).get("redirect_to"), + xform=parse_link_response, + ) + + # User Admin API + + async def create_user(self, attributes: AdminUserAttributes) -> UserResponse: + """ + Creates a new user. + + This function should only be called on a server. + Never expose your `service_role` key in the browser. + """ + return await self._request( + "POST", + "admin/users", + body=attributes, + xform=parse_user_response, + ) + + async def list_users(self) -> List[User]: + """ + Get a list of users. + + This function should only be called on a server. + Never expose your `service_role` key in the browser. + """ + return await self._request( + "GET", + "admin/users", + xform=lambda data: [User.parse_obj(user) for user in data["users"]] + if "users" in data + else [], + ) + + async def get_user_by_id(self, uid: str) -> UserResponse: + """ + Get user by id. + + This function should only be called on a server. + Never expose your `service_role` key in the browser. + """ + return await self._request( + "GET", + f"admin/users/{uid}", + xform=parse_user_response, + ) + + async def update_user_by_id( + self, + uid: str, + attributes: AdminUserAttributes, + ) -> UserResponse: + """ + Updates the user data. + + This function should only be called on a server. + Never expose your `service_role` key in the browser. + """ + return await self._request( + "PUT", + f"admin/users/{uid}", + body=attributes, + xform=parse_user_response, + ) + + async def delete_user(self, id: str) -> None: + """ + Delete a user. Requires a `service_role` key. + + This function should only be called on a server. + Never expose your `service_role` key in the browser. + """ + return await self._request("DELETE", f"admin/users/{id}") diff --git a/gotrue/_async/gotrue_base_api.py b/gotrue/_async/gotrue_base_api.py new file mode 100644 index 00000000..8d7b4697 --- /dev/null +++ b/gotrue/_async/gotrue_base_api.py @@ -0,0 +1,120 @@ +#!/usr/bin/env python3 + +from __future__ import annotations + +from typing import Any, Callable, Dict, TypeVar, Union, overload + +from httpx import Response +from pydantic import BaseModel +from typing_extensions import Literal, Self + +from ..helpers import handle_exception +from ..http_clients import AsyncClient + +T = TypeVar("T") + + +class AsyncGoTrueBaseAPI: + def __init__( + self, + *, + url: str, + headers: Dict[str, str], + http_client: Union[AsyncClient, None], + ): + self._url = url + self._headers = headers + self._http_client = http_client or AsyncClient() + + async def __aenter__(self) -> Self: + return self + + async def __aexit__(self, exc_t, exc_v, exc_tb) -> None: + await self.close() + + async def close(self) -> None: + await self._http_client.aclose() + + @overload + async def _request( + self, + method: Literal["GET", "OPTIONS", "HEAD", "POST", "PUT", "PATCH", "DELETE"], + path: str, + *, + jwt: Union[str, None] = None, + redirect_to: Union[str, None] = None, + headers: Union[Dict[str, str], None] = None, + query: Union[Dict[str, str], None] = None, + body: Union[Any, None] = None, + no_resolve_json: Literal[False] = False, + xform: Callable[[Any], T], + ) -> T: + ... # pragma: no cover + + @overload + async def _request( + self, + method: Literal["GET", "OPTIONS", "HEAD", "POST", "PUT", "PATCH", "DELETE"], + path: str, + *, + jwt: Union[str, None] = None, + redirect_to: Union[str, None] = None, + headers: Union[Dict[str, str], None] = None, + query: Union[Dict[str, str], None] = None, + body: Union[Any, None] = None, + no_resolve_json: Literal[True], + xform: Callable[[Response], T], + ) -> T: + ... # pragma: no cover + + @overload + async def _request( + self, + method: Literal["GET", "OPTIONS", "HEAD", "POST", "PUT", "PATCH", "DELETE"], + path: str, + *, + jwt: Union[str, None] = None, + redirect_to: Union[str, None] = None, + headers: Union[Dict[str, str], None] = None, + query: Union[Dict[str, str], None] = None, + body: Union[Any, None] = None, + no_resolve_json: bool = False, + ) -> None: + ... # pragma: no cover + + async def _request( + self, + method: Literal["GET", "OPTIONS", "HEAD", "POST", "PUT", "PATCH", "DELETE"], + path: str, + *, + jwt: Union[str, None] = None, + redirect_to: Union[str, None] = None, + headers: Union[Dict[str, str], None] = None, + query: Union[Dict[str, str], None] = None, + body: Union[Any, None] = None, + no_resolve_json: bool = False, + xform: Union[Callable[[Any], T], None] = None, + ) -> Union[T, None]: + url = f"{self._url}/{path}" + headers = {**self._headers, **(headers or {})} + if "Content-Type" not in headers: + headers["Content-Type"] = "application/json;charset=UTF-8" + if jwt: + headers["Authorization"] = f"Bearer {jwt}" + query = query or {} + if redirect_to: + query["redirect_to"] = redirect_to + try: + response = await self._http_client.request( + method, + url, + headers=headers, + params=query, + json=body.dict() if isinstance(body, BaseModel) else body, + ) + response.raise_for_status() + result = response if no_resolve_json else response.json() + if xform: + return xform(result) + except Exception as e: + raise handle_exception(e) diff --git a/gotrue/_sync/api.py b/gotrue/_sync/api.py deleted file mode 100644 index 6ab024f9..00000000 --- a/gotrue/_sync/api.py +++ /dev/null @@ -1,642 +0,0 @@ -from __future__ import annotations - -from typing import Any, Dict, List, Optional, Union - -from pydantic import parse_obj_as - -from ..exceptions import APIError -from ..helpers import check_response, encode_uri_component -from ..http_clients import SyncClient -from ..types import ( - CookieOptions, - LinkType, - Provider, - Session, - User, - UserAttributes, - determine_session_or_user_model_from_response, -) - - -class SyncGoTrueAPI: - def __init__( - self, - *, - url: str, - headers: Dict[str, str], - cookie_options: CookieOptions, - http_client: Optional[SyncClient] = None, - ) -> None: - """Initialise API class.""" - self.url = url - self.headers = headers - self.cookie_options = cookie_options - self.http_client = http_client or SyncClient() - - def __enter__(self) -> SyncGoTrueAPI: - return self - - def __exit__(self, exc_t, exc_v, exc_tb) -> None: - self.close() - - def close(self) -> None: - self.http_client.aclose() - - def create_user(self, *, attributes: UserAttributes) -> User: - """Creates a new user. - - This function should only be called on a server. - Never expose your `service_role` key in the browser. - - Parameters - ---------- - attributes: UserAttributes - The data you want to create the user with. - - Returns - ------- - response : User - The created user - - Raises - ------ - error : APIError - If an error occurs - """ - headers = self.headers - data = attributes.dict() - url = f"{self.url}/admin/users" - response = self.http_client.post(url, json=data, headers=headers) - return User.parse_response(response) - - def list_users(self) -> List[User]: - """Get a list of users. - - This function should only be called on a server. - Never expose your `service_role` key in the browser. - - Returns - ------- - response : List[User] - A list of users - - Raises - ------ - error : APIError - If an error occurs - """ - headers = self.headers - url = f"{self.url}/admin/users" - response = self.http_client.get(url, headers=headers) - check_response(response) - users = response.json().get("users") - if users is None: - raise APIError("No users found in response", 400) - if not isinstance(users, list): - raise APIError("Expected a list of users", 400) - return parse_obj_as(List[User], users) - - def sign_up_with_email( - self, - *, - email: str, - password: str, - redirect_to: Optional[str] = None, - data: Optional[Dict[str, Any]] = None, - ) -> Union[Session, User]: - """Creates a new user using their email address. - - Parameters - ---------- - email : str - The email address of the user. - password : str - The password of the user. - redirect_to : Optional[str] - A URL or mobile address to send the user to after they are confirmed. - data : Optional[Dict[str, Any]] - Optional user metadata. - - Returns - ------- - response : Union[Session, User] - A logged-in session if the server has "autoconfirm" ON - A user if the server has "autoconfirm" OFF - - Raises - ------ - error : APIError - If an error occurs - """ - headers = self.headers - query_string = "" - if redirect_to: - redirect_to_encoded = encode_uri_component(redirect_to) - query_string = f"?redirect_to={redirect_to_encoded}" - data = {"email": email, "password": password, "data": data} - url = f"{self.url}/signup{query_string}" - response = self.http_client.post(url, json=data, headers=headers) - SessionOrUserModel = determine_session_or_user_model_from_response(response) - return SessionOrUserModel.parse_response(response) - - def sign_in_with_email( - self, - *, - email: str, - password: str, - redirect_to: Optional[str] = None, - ) -> Session: - """Logs in an existing user using their email address. - - Parameters - ---------- - email : str - The email address of the user. - password : str - The password of the user. - redirect_to : Optional[str] - A URL or mobile address to send the user to after they are confirmed. - - Returns - ------- - response : Session - A logged-in session - - Raises - ------ - error : APIError - If an error occurs - """ - headers = self.headers - query_string = "?grant_type=password" - if redirect_to: - redirect_to_encoded = encode_uri_component(redirect_to) - query_string += f"&redirect_to={redirect_to_encoded}" - data = {"email": email, "password": password} - url = f"{self.url}/token{query_string}" - response = self.http_client.post(url, json=data, headers=headers) - return Session.parse_response(response) - - def sign_up_with_phone( - self, - *, - phone: str, - password: str, - data: Optional[Dict[str, Any]] = None, - ) -> Union[Session, User]: - """Signs up a new user using their phone number and a password. - - Parameters - ---------- - phone : str - The phone number of the user. - password : str - The password of the user. - data : Optional[Dict[str, Any]] - Optional user metadata. - - Returns - ------- - response : Union[Session, User] - A logged-in session if the server has "autoconfirm" ON - A user if the server has "autoconfirm" OFF - - Raises - ------ - error : APIError - If an error occurs - """ - headers = self.headers - data = {"phone": phone, "password": password, "data": data} - url = f"{self.url}/signup" - response = self.http_client.post(url, json=data, headers=headers) - SessionOrUserModel = determine_session_or_user_model_from_response(response) - return SessionOrUserModel.parse_response(response) - - def sign_in_with_phone( - self, - *, - phone: str, - password: str, - ) -> Session: - """Logs in an existing user using their phone number and password. - - Parameters - ---------- - phone : str - The phone number of the user. - password : str - The password of the user. - - Returns - ------- - response : Session - A logged-in session - - Raises - ------ - error : APIError - If an error occurs - """ - data = {"phone": phone, "password": password} - url = f"{self.url}/token?grant_type=password" - headers = self.headers - response = self.http_client.post(url, json=data, headers=headers) - return Session.parse_response(response) - - def send_magic_link_email( - self, - *, - email: str, - create_user: bool, - redirect_to: Optional[str] = None, - ) -> None: - """Sends a magic login link to an email address. - - Parameters - ---------- - email : str - The email address of the user. - redirect_to : Optional[str] - A URL or mobile address to send the user to after they are confirmed. - - Raises - ------ - error : APIError - If an error occurs - """ - headers = self.headers - query_string = "" - if redirect_to: - redirect_to_encoded = encode_uri_component(redirect_to) - query_string = f"?redirect_to={redirect_to_encoded}" - data = {"email": email, "create_user": create_user} - url = f"{self.url}/magiclink{query_string}" - response = self.http_client.post(url, json=data, headers=headers) - return check_response(response) - - def send_mobile_otp(self, *, phone: str, create_user: bool) -> None: - """Sends a mobile OTP via SMS. Will register the account if it doesn't already exist - - Parameters - ---------- - phone : str - The user's phone number WITH international prefix - - Raises - ------ - error : APIError - If an error occurs - """ - headers = self.headers - data = {"phone": phone, "create_user": create_user} - url = f"{self.url}/otp" - response = self.http_client.post(url, json=data, headers=headers) - return check_response(response) - - def verify_mobile_otp( - self, - *, - phone: str, - token: str, - redirect_to: Optional[str] = None, - ) -> Union[Session, User]: - """Send User supplied Mobile OTP to be verified - - Parameters - ---------- - phone : str - The user's phone number WITH international prefix - token : str - Token that user was sent to their mobile phone - redirect_to : Optional[str] - A URL or mobile address to send the user to after they are confirmed. - - Returns - ------- - response : Union[Session, User] - A logged-in session if the server has "autoconfirm" ON - A user if the server has "autoconfirm" OFF - - Raises - ------ - error : APIError - If an error occurs - """ - headers = self.headers - data = { - "phone": phone, - "token": token, - "type": "sms", - } - if redirect_to: - redirect_to_encoded = encode_uri_component(redirect_to) - data["redirect_to"] = redirect_to_encoded - url = f"{self.url}/verify" - response = self.http_client.post(url, json=data, headers=headers) - SessionOrUserModel = determine_session_or_user_model_from_response(response) - return SessionOrUserModel.parse_response(response) - - def invite_user_by_email( - self, - *, - email: str, - redirect_to: Optional[str] = None, - data: Optional[Dict[str, Any]] = None, - ) -> User: - """Sends an invite link to an email address. - - Parameters - ---------- - email : str - The email address of the user. - redirect_to : Optional[str] - A URL or mobile address to send the user to after they are confirmed. - data : Optional[Dict[str, Any]] - Optional user metadata. - - Returns - ------- - response : User - A user - - Raises - ------ - error : APIError - If an error occurs - """ - headers = self.headers - query_string = "" - if redirect_to: - redirect_to_encoded = encode_uri_component(redirect_to) - query_string = f"?redirect_to={redirect_to_encoded}" - data = {"email": email, "data": data} - url = f"{self.url}/invite{query_string}" - response = self.http_client.post(url, json=data, headers=headers) - return User.parse_response(response) - - def reset_password_for_email( - self, - *, - email: str, - redirect_to: Optional[str] = None, - ) -> None: - """Sends a reset request to an email address. - - Parameters - ---------- - email : str - The email address of the user. - redirect_to : Optional[str] - A URL or mobile address to send the user to after they are confirmed. - - Raises - ------ - error : APIError - If an error occurs - """ - headers = self.headers - query_string = "" - if redirect_to: - redirect_to_encoded = encode_uri_component(redirect_to) - query_string = f"?redirect_to={redirect_to_encoded}" - data = {"email": email} - url = f"{self.url}/recover{query_string}" - response = self.http_client.post(url, json=data, headers=headers) - return check_response(response) - - def _create_request_headers(self, *, jwt: str) -> Dict[str, str]: - """Create temporary object. - - Create a temporary object with all configured headers and adds the - Authorization token to be used on request methods. - - Parameters - ---------- - jwt : str - A valid, logged-in JWT. - - Returns - ------- - headers : dict of str - The headers required for a successful request statement with the - supabase backend. - """ - headers = {**self.headers} - headers["Authorization"] = f"Bearer {jwt}" - return headers - - def sign_out(self, *, jwt: str) -> None: - """Removes a logged-in session. - - Parameters - ---------- - jwt : str - A valid, logged-in JWT. - """ - headers = self._create_request_headers(jwt=jwt) - url = f"{self.url}/logout" - self.http_client.post(url, headers=headers) - - def get_url_for_provider( - self, - *, - provider: Provider, - redirect_to: Optional[str] = None, - scopes: Optional[str] = None, - ) -> str: - """Generates the relevant login URL for a third-party provider. - - Parameters - ---------- - provider : Provider - One of the providers supported by GoTrue. - redirect_to : Optional[str] - A URL or mobile address to send the user to after they are confirmed. - scopes : Optional[str] - A space-separated list of scopes granted to the OAuth application. - - Returns - ------- - url : str - The URL to redirect the user to. - - Raises - ------ - error : APIError - If an error occurs - """ - url_params = [f"provider={encode_uri_component(provider)}"] - if redirect_to: - redirect_to_encoded = encode_uri_component(redirect_to) - url_params.append(f"redirect_to={redirect_to_encoded}") - if scopes: - url_params.append(f"scopes={encode_uri_component(scopes)}") - return f"{self.url}/authorize?{'&'.join(url_params)}" - - def get_user(self, *, jwt: str) -> User: - """Gets the user details. - - Parameters - ---------- - jwt : str - A valid, logged-in JWT. - - Returns - ------- - response : User - A user - - Raises - ------ - error : APIError - If an error occurs - """ - headers = self._create_request_headers(jwt=jwt) - url = f"{self.url}/user" - response = self.http_client.get(url, headers=headers) - return User.parse_response(response) - - def update_user( - self, - *, - jwt: str, - attributes: UserAttributes, - ) -> User: - """ - Updates the user data. - - Parameters - ---------- - jwt : str - A valid, logged-in JWT. - attributes : UserAttributes - The data you want to update. - - Returns - ------- - response : User - A user - - Raises - ------ - error : APIError - If an error occurs - """ - headers = self._create_request_headers(jwt=jwt) - data = attributes.dict() - url = f"{self.url}/user" - response = self.http_client.put(url, json=data, headers=headers) - return User.parse_response(response) - - def delete_user(self, *, uid: str, jwt: str) -> None: - """Delete a user. Requires a `service_role` key. - - This function should only be called on a server. - Never expose your `service_role` key in the browser. - - Parameters - ---------- - uid : str - The user uid you want to remove. - jwt : str - A valid, logged-in JWT. - - Returns - ------- - response : User - A user - - Raises - ------ - error : APIError - If an error occurs - """ - headers = self._create_request_headers(jwt=jwt) - url = f"{self.url}/admin/users/{uid}" - response = self.http_client.delete(url, headers=headers) - return check_response(response) - - def refresh_access_token(self, *, refresh_token: str) -> Session: - """Generates a new JWT. - - Parameters - ---------- - refresh_token : str - A valid refresh token that was returned on login. - - Returns - ------- - response : Session - A session - - Raises - ------ - error : APIError - If an error occurs - """ - data = {"refresh_token": refresh_token} - url = f"{self.url}/token?grant_type=refresh_token" - headers = self.headers - response = self.http_client.post(url, json=data, headers=headers) - return Session.parse_response(response) - - def generate_link( - self, - *, - type: LinkType, - email: str, - password: Optional[str] = None, - redirect_to: Optional[str] = None, - data: Optional[Dict[str, Any]] = None, - ) -> Union[Session, User]: - """ - Generates links to be sent via email or other. - - Parameters - ---------- - type : LinkType - The link type ("signup" or "magiclink" or "recovery" or "invite"). - email : str - The user's email. - password : Optional[str] - User password. For signup only. - redirect_to : Optional[str] - The link type ("signup" or "magiclink" or "recovery" or "invite"). - data : Optional[Dict[str, Any]] - Optional user metadata. For signup only. - - Returns - ------- - response : Union[Session, User] - A logged-in session if the server has "autoconfirm" ON - A user if the server has "autoconfirm" OFF - - Raises - ------ - error : APIError - If an error occurs - """ - headers = self.headers - data = { - "type": type, - "email": email, - "data": data, - } - if password: - data["password"] = password - if redirect_to: - redirect_to_encoded = encode_uri_component(redirect_to) - data["redirect_to"] = redirect_to_encoded - url = f"{self.url}/admin/generate_link" - response = self.http_client.post(url, json=data, headers=headers) - SessionOrUserModel = determine_session_or_user_model_from_response(response) - return SessionOrUserModel.parse_response(response) - - def set_auth_cookie(self, *, req, res): - """Stub for parity with JS api.""" - raise NotImplementedError("set_auth_cookie not implemented.") - - def get_user_by_cookie(self, *, req): - """Stub for parity with JS api.""" - raise NotImplementedError("get_user_by_cookie not implemented.") diff --git a/gotrue/_sync/client.py b/gotrue/_sync/client.py index f7f20455..6e91936d 100644 --- a/gotrue/_sync/client.py +++ b/gotrue/_sync/client.py @@ -20,7 +20,7 @@ UserAttributes, UserAttributesDict, ) -from .api import SyncGoTrueAPI +from .gotrue_admin_api import AsycnGoTrueAdminAPI from .storage import SyncMemoryStorage, SyncSupportedStorage @@ -34,7 +34,7 @@ def __init__( persist_session: bool = True, local_storage: SyncSupportedStorage = SyncMemoryStorage(), cookie_options: CookieOptions = CookieOptions.parse_obj(COOKIE_OPTIONS), - api: Optional[SyncGoTrueAPI] = None, + api: Optional[AsycnGoTrueAdminAPI] = None, replace_default_headers: bool = False, ) -> None: """Create a new client @@ -72,7 +72,7 @@ def __init__( "headers": {**empty_or_default_headers, **headers}, "cookie_options": cookie_options, } - self.api = api or SyncGoTrueAPI(**args) + self.api = api or SyncGoTrueAdminAPI(**args) def __enter__(self) -> SyncGoTrueClient: return self @@ -121,8 +121,8 @@ def sign_up( Raises ------ - error : APIError - If an error occurs + APIError + If an error occurs. """ self._remove_session() @@ -202,8 +202,8 @@ def sign_in( Raises ------ - error : APIError - If an error occurs + APIError + If an error occurs. """ self._remove_session() if email: @@ -268,8 +268,8 @@ def verify_otp( Raises ------ - error : APIError - If an error occurs + APIError + If an error occurs. """ self._remove_session() response = self.api.verify_mobile_otp( @@ -315,8 +315,8 @@ def update(self, *, attributes: Union[UserAttributesDict, UserAttributes]) -> Us Raises ------ - error : APIError - If an error occurs + APIError + If an error occurs. """ if not self.current_session: raise ValueError("Not logged in.") @@ -350,8 +350,8 @@ def set_session(self, *, refresh_token: str) -> Session: Raises ------ - error : APIError - If an error occurs + APIError + If an error occurs. """ response = self.api.refresh_access_token(refresh_token=refresh_token) self._save_session(session=response) @@ -374,8 +374,8 @@ def set_auth(self, *, access_token: str) -> Session: Raises ------ - error : APIError - If an error occurs + APIError + If an error occurs. """ session = Session( access_token=access_token, @@ -416,8 +416,8 @@ def get_session_from_url( Raises ------ - error : APIError - If an error occurs + APIError + If an error occurs. """ data = urlparse(url) query = parse_qs(data.query) @@ -459,6 +459,9 @@ def get_session_from_url( self._notify_all_subscribers(event=AuthChangeEvent.PASSWORD_RECOVERY) return session + def get_session(self) -> None: + return None + def sign_out(self) -> None: """Log the user out.""" access_token: Optional[str] = None @@ -492,8 +495,8 @@ def on_auth_state_change( Raises ------ - error : APIError - If an error occurs + APIError + If an error occurs. """ unique_id = uuid4() subscription = Subscription( diff --git a/gotrue/_sync/gotrue_admin_api.py b/gotrue/_sync/gotrue_admin_api.py new file mode 100644 index 00000000..4b934176 --- /dev/null +++ b/gotrue/_sync/gotrue_admin_api.py @@ -0,0 +1,147 @@ +from __future__ import annotations + +from typing import Dict, List, Union + +from ..helpers import parse_link_response, parse_user_response +from ..http_clients import SyncClient +from ..types import ( + AdminUserAttributes, + GenerateLinkParams, + GenerateLinkResponse, + Options, + User, + UserResponse, +) +from .gotrue_base_api import SyncGoTrueBaseAPI + + +class SyncGoTrueAdminAPI(SyncGoTrueBaseAPI): + def __init__( + self, + *, + url: str = "", + headers: Dict[str, str] = {}, + http_client: Union[SyncClient, None] = None, + ) -> None: + SyncGoTrueBaseAPI.__init__( + self, + url=url, + headers=headers, + http_client=http_client, + ) + + def sign_out(self, jwt: str) -> None: + """ + Removes a logged-in session. + """ + return self._request( + "POST", + "logout", + jwt=jwt, + no_resolve_json=True, + ) + + def invite_user_by_email( + self, + email: str, + options: Options = {}, + ) -> UserResponse: + """ + Sends an invite link to an email address. + """ + return self._request( + "POST", + "invite", + body={"email": email, "data": options.get("data")}, + redirect_to=options.get("redirect_to"), + xform=parse_user_response, + ) + + def generate_link(self, params: GenerateLinkParams) -> GenerateLinkResponse: + """ + Generates email links and OTPs to be sent via a custom email provider. + """ + return self._request( + "POST", + "admin/generate_link", + body={ + "type": params.get("type"), + "email": params.get("email"), + "password": params.get("password"), + "new_email": params.get("new_email"), + "data": params.get("options", {}).get("data"), + }, + redirect_to=params.get("options", {}).get("redirect_to"), + xform=parse_link_response, + ) + + # User Admin API + + def create_user(self, attributes: AdminUserAttributes) -> UserResponse: + """ + Creates a new user. + + This function should only be called on a server. + Never expose your `service_role` key in the browser. + """ + return self._request( + "POST", + "admin/users", + body=attributes, + xform=parse_user_response, + ) + + def list_users(self) -> List[User]: + """ + Get a list of users. + + This function should only be called on a server. + Never expose your `service_role` key in the browser. + """ + return self._request( + "GET", + "admin/users", + xform=lambda data: [User.parse_obj(user) for user in data["users"]] + if "users" in data + else [], + ) + + def get_user_by_id(self, uid: str) -> UserResponse: + """ + Get user by id. + + This function should only be called on a server. + Never expose your `service_role` key in the browser. + """ + return self._request( + "GET", + f"admin/users/{uid}", + xform=parse_user_response, + ) + + def update_user_by_id( + self, + uid: str, + attributes: AdminUserAttributes, + ) -> UserResponse: + """ + Updates the user data. + + This function should only be called on a server. + Never expose your `service_role` key in the browser. + """ + return self._request( + "PUT", + f"admin/users/{uid}", + body=attributes, + xform=parse_user_response, + ) + + def delete_user(self, id: str) -> None: + """ + Delete a user. Requires a `service_role` key. + + This function should only be called on a server. + Never expose your `service_role` key in the browser. + """ + return self._request("DELETE", f"admin/users/{id}") diff --git a/gotrue/_sync/gotrue_base_api.py b/gotrue/_sync/gotrue_base_api.py new file mode 100644 index 00000000..81701304 --- /dev/null +++ b/gotrue/_sync/gotrue_base_api.py @@ -0,0 +1,120 @@ +#!/usr/bin/env python3 + +from __future__ import annotations + +from typing import Any, Callable, Dict, TypeVar, Union, overload + +from httpx import Response +from pydantic import BaseModel +from typing_extensions import Literal, Self + +from ..helpers import handle_exception +from ..http_clients import SyncClient + +T = TypeVar("T") + + +class SyncGoTrueBaseAPI: + def __init__( + self, + *, + url: str, + headers: Dict[str, str], + http_client: Union[SyncClient, None], + ): + self._url = url + self._headers = headers + self._http_client = http_client or SyncClient() + + def __enter__(self) -> Self: + return self + + def __exit__(self, exc_t, exc_v, exc_tb) -> None: + self.close() + + def close(self) -> None: + self._http_client.aclose() + + @overload + def _request( + self, + method: Literal["GET", "OPTIONS", "HEAD", "POST", "PUT", "PATCH", "DELETE"], + path: str, + *, + jwt: Union[str, None] = None, + redirect_to: Union[str, None] = None, + headers: Union[Dict[str, str], None] = None, + query: Union[Dict[str, str], None] = None, + body: Union[Any, None] = None, + no_resolve_json: Literal[False] = False, + xform: Callable[[Any], T], + ) -> T: + ... # pragma: no cover + + @overload + def _request( + self, + method: Literal["GET", "OPTIONS", "HEAD", "POST", "PUT", "PATCH", "DELETE"], + path: str, + *, + jwt: Union[str, None] = None, + redirect_to: Union[str, None] = None, + headers: Union[Dict[str, str], None] = None, + query: Union[Dict[str, str], None] = None, + body: Union[Any, None] = None, + no_resolve_json: Literal[True], + xform: Callable[[Response], T], + ) -> T: + ... # pragma: no cover + + @overload + def _request( + self, + method: Literal["GET", "OPTIONS", "HEAD", "POST", "PUT", "PATCH", "DELETE"], + path: str, + *, + jwt: Union[str, None] = None, + redirect_to: Union[str, None] = None, + headers: Union[Dict[str, str], None] = None, + query: Union[Dict[str, str], None] = None, + body: Union[Any, None] = None, + no_resolve_json: bool = False, + ) -> None: + ... # pragma: no cover + + def _request( + self, + method: Literal["GET", "OPTIONS", "HEAD", "POST", "PUT", "PATCH", "DELETE"], + path: str, + *, + jwt: Union[str, None] = None, + redirect_to: Union[str, None] = None, + headers: Union[Dict[str, str], None] = None, + query: Union[Dict[str, str], None] = None, + body: Union[Any, None] = None, + no_resolve_json: bool = False, + xform: Union[Callable[[Any], T], None] = None, + ) -> Union[T, None]: + url = f"{self._url}/{path}" + headers = {**self._headers, **(headers or {})} + if "Content-Type" not in headers: + headers["Content-Type"] = "application/json;charset=UTF-8" + if jwt: + headers["Authorization"] = f"Bearer {jwt}" + query = query or {} + if redirect_to: + query["redirect_to"] = redirect_to + try: + response = self._http_client.request( + method, + url, + headers=headers, + params=query, + json=body.dict() if isinstance(body, BaseModel) else body, + ) + response.raise_for_status() + result = response if no_resolve_json else response.json() + if xform: + return xform(result) + except Exception as e: + raise handle_exception(e) diff --git a/gotrue/_sync/gotrue_client.py b/gotrue/_sync/gotrue_client.py new file mode 100644 index 00000000..43e41079 --- /dev/null +++ b/gotrue/_sync/gotrue_client.py @@ -0,0 +1,835 @@ +from __future__ import annotations + +from json import loads +from time import time +from typing import Callable, Dict, List, Tuple, Union +from urllib.parse import parse_qs, quote, urlencode, urlparse +from uuid import uuid4 + +from ..constants import ( + DEFAULT_HEADERS, + EXPIRY_MARGIN, + GOTRUE_URL, + MAX_RETRIES, + RETRY_INTERVAL, + STORAGE_KEY, +) +from ..errors import ( + AuthImplicitGrantRedirectError, + AuthInvalidCredentialsError, + AuthRetryableError, + AuthSessionMissingError, +) +from ..helpers import decode_jwt_payload, parse_auth_response, parse_user_response +from ..http_clients import SyncClient +from ..timer import Timer +from ..types import ( + AuthChangeEvent, + AuthenticatorAssuranceLevels, + AuthMFAChallengeResponse, + AuthMFAEnrollResponse, + AuthMFAGetAuthenticatorAssuranceLevelResponse, + AuthMFAListFactorsResponse, + AuthMFAUnenrollResponse, + AuthMFAVerifyResponse, + AuthResponse, + DecodedJWTDict, + MFAChallengeAndVerifyParams, + MFAChallengeParams, + MFAEnrollParams, + MFAUnenrollParams, + MFAVerifyParams, + OAuthResponse, + Options, + Provider, + Session, + SignInWithOAuthCredentials, + SignInWithPasswordCredentials, + SignInWithPasswordlessCredentials, + SignUpWithPasswordCredentials, + Subscription, + UserAttributes, + UserResponse, + VerifyOtpParams, +) +from .gotrue_admin_api import SyncGoTrueAdminAPI +from .gotrue_base_api import SyncGoTrueBaseAPI +from .gotrue_mfa_api import SyncGoTrueMFAAPI +from .storage import SyncMemoryStorage, SyncSupportedStorage + + +class SyncGoTrueClient(SyncGoTrueBaseAPI): + def __init__( + self, + *, + url: Union[str, None] = None, + headers: Union[Dict[str, str], None] = None, + storage_key: Union[str, None] = None, + auto_refresh_token: bool = True, + persist_session: bool = True, + storage: Union[SyncSupportedStorage, None] = None, + http_client: Union[SyncClient, None] = None, + ) -> None: + SyncGoTrueBaseAPI.__init__( + self, + url=url or GOTRUE_URL, + headers=headers or DEFAULT_HEADERS, + http_client=http_client, + ) + self._storage_key = storage_key or STORAGE_KEY + self._auto_refresh_token = auto_refresh_token + self._persist_session = persist_session + self._storage = storage or SyncMemoryStorage() + self._in_memory_session: Union[Session, None] = None + self._refresh_token_timer: Union[Timer, None] = None + self._network_retries = 0 + self._state_change_emitters: Dict[str, Subscription] = {} + + self.admin = SyncGoTrueAdminAPI( + url=self._url, + headers=self._headers, + http_client=self._http_client, + ) + self.mfa = SyncGoTrueMFAAPI() + self.mfa.challenge = self._challenge + self.mfa.challenge_and_verify = self._challenge_and_verify + self.mfa.enroll = self._enroll + self.mfa.get_authenticator_assurance_level = ( + self._get_authenticator_assurance_level + ) + self.mfa.list_factors = self._list_factors + self.mfa.unenroll = self._unenroll + self.mfa.verify = self._verify + + # Initializations + + def initialize(self, *, url: Union[str, None] = None) -> None: + if url and self._is_implicit_grant_flow(url): + self.initialize_from_url(url) + else: + self.initialize_from_storage() + + def initialize_from_storage(self) -> None: + return self._recover_and_refresh() + + def initialize_from_url(self, url: str) -> None: + try: + if self._is_implicit_grant_flow(url): + session, redirect_type = self._get_session_from_url(url) + self._save_session(session) + self._notify_all_subscribers("SIGNED_IN", session) + if redirect_type == "recovery": + self._notify_all_subscribers("PASSWORD_RECOVERY", session) + except Exception as e: + self._remove_session() + raise e + + # Public methods + + def sign_up( + self, + credentials: SignUpWithPasswordCredentials, + ) -> AuthResponse: + """ + Creates a new user. + """ + self._remove_session() + email = credentials.get("email") + phone = credentials.get("phone") + password = credentials.get("password") + options = credentials.get("options", {}) + redirect_to = options.get("redirect_to") + data = options.get("data") or {} + captcha_token = options.get("captcha_token") + if email: + response = self._request( + "POST", + "signup", + body={ + "email": email, + "password": password, + "data": data, + "gotrue_meta_security": { + "captcha_token": captcha_token, + }, + }, + redirect_to=redirect_to, + xform=parse_auth_response, + ) + elif phone: + response = self._request( + "POST", + "signup", + body={ + "phone": phone, + "password": password, + "data": data, + "gotrue_meta_security": { + "captcha_token": captcha_token, + }, + }, + xform=parse_auth_response, + ) + else: + raise AuthInvalidCredentialsError( + "You must provide either an email or phone number and a password" + ) + if response.session: + self._save_session(response.session) + self._notify_all_subscribers("SIGNED_IN", response.session) + return response + + def sign_in_with_password( + self, + credentials: SignInWithPasswordCredentials, + ) -> AuthResponse: + """ + Log in an existing user with an email or phone and password. + """ + self._remove_session() + email = credentials.get("email") + phone = credentials.get("phone") + password = credentials.get("password") + options = credentials.get("options", {}) + data = options.get("data") or {} + captcha_token = options.get("captcha_token") + if email: + response = self._request( + "POST", + "token", + body={ + "email": email, + "password": password, + "data": data, + "gotrue_meta_security": { + "captcha_token": captcha_token, + }, + }, + query={ + "grant_type": "password", + }, + xform=parse_auth_response, + ) + elif phone: + response = self._request( + "POST", + "token", + body={ + "phone": phone, + "password": password, + "data": data, + "gotrue_meta_security": { + "captcha_token": captcha_token, + }, + }, + query={ + "grant_type": "password", + }, + xform=parse_auth_response, + ) + else: + raise AuthInvalidCredentialsError( + "You must provide either an email or phone number and a password" + ) + if response.session: + self._save_session(response.session) + self._notify_all_subscribers("SIGNED_IN", response.session) + return response + + def sign_in_with_oauth( + self, + credentials: SignInWithOAuthCredentials, + ) -> OAuthResponse: + """ + Log in an existing user via a third-party provider. + """ + self._remove_session() + provider = credentials.get("provider") + options = credentials.get("options", {}) + redirect_to = options.get("redirect_to") + scopes = options.get("scopes") + params = options.get("query_params", {}) + if redirect_to: + params["redirect_to"] = redirect_to + if scopes: + params["scopes"] = scopes + url = self._get_url_for_provider(provider, params) + return OAuthResponse(provider=provider, url=url) + + def sign_in_with_otp( + self, + credentials: SignInWithPasswordlessCredentials, + ) -> AuthResponse: + """ + Log in a user using magiclink or a one-time password (OTP). + + If the `{{ .ConfirmationURL }}` variable is specified in + the email template, a magiclink will be sent. + + If the `{{ .Token }}` variable is specified in the email + template, an OTP will be sent. + + If you're using phone sign-ins, only an OTP will be sent. + You won't be able to send a magiclink for phone sign-ins. + """ + self._remove_session() + email = credentials.get("email") + phone = credentials.get("phone") + options = credentials.get("options", {}) + email_redirect_to = options.get("email_redirect_to") + should_create_user = options.get("create_user", True) + data = options.get("data") + captcha_token = options.get("captcha_token") + if email: + return self._request( + "POST", + "otp", + body={ + "email": email, + "data": data, + "create_user": should_create_user, + "gotrue_meta_security": { + "captcha_token": captcha_token, + }, + }, + redirect_to=email_redirect_to, + xform=parse_auth_response, + ) + if phone: + return self._request( + "POST", + "otp", + body={ + "phone": phone, + "data": data, + "create_user": should_create_user, + "gotrue_meta_security": { + "captcha_token": captcha_token, + }, + }, + xform=parse_auth_response, + ) + raise AuthInvalidCredentialsError( + "You must provide either an email or phone number" + ) + + def verify_otp(self, params: VerifyOtpParams) -> AuthResponse: + """ + Log in a user given a User supplied OTP received via mobile. + """ + self._remove_session() + response = self._request( + "POST", + "verify", + body={ + "gotrue_meta_security": { + "captcha_token": params.get("options", {}).get("captcha_token"), + }, + **params, + }, + redirect_to=params.get("options", {}).get("redirect_to"), + xform=parse_auth_response, + ) + if response.session: + self._save_session(response.session) + self._notify_all_subscribers("SIGNED_IN", response.session) + return response + + def get_session(self) -> Union[Session, None]: + """ + Returns the session, refreshing it if necessary. + + The session returned can be null if the session is not detected which + can happen in the event a user is not signed-in or has logged out. + """ + current_session: Union[Session, None] = None + if self._persist_session: + maybe_session = self._storage.get_item(self._storage_key) + current_session = self._get_valid_session(maybe_session) + if not current_session: + self._remove_session() + else: + current_session = self._in_memory_session + if not current_session: + return None + time_now = round(time()) + has_expired = ( + current_session.expires_at <= time_now + EXPIRY_MARGIN + if current_session.expires_at + else False + ) + return ( + self._call_refresh_token(current_session.refresh_token) + if has_expired + else current_session + ) + + def get_user(self, jwt: Union[str, None] = None) -> UserResponse: + """ + Gets the current user details if there is an existing session. + + Takes in an optional access token `jwt`. If no `jwt` is provided, + `get_user()` will attempt to get the `jwt` from the current session. + """ + if not jwt: + if session := self.get_session(): + jwt = session.access_token + return self._request("GET", "user", jwt=jwt, xform=parse_user_response) + + def update_user(self, attributes: UserAttributes) -> UserResponse: + """ + Updates user data, if there is a logged in user. + """ + session = self.get_session() + if not session: + raise AuthSessionMissingError() + response = self._request( + "PUT", + "user", + body=attributes, + jwt=session.access_token, + xform=parse_user_response, + ) + session.user = response.user + self._save_session(session) + self._notify_all_subscribers("USER_UPDATED", session) + return response + + def set_session(self, access_token: str, refresh_token: str) -> AuthResponse: + """ + Sets the session data from the current session. If the current session + is expired, `set_session` will take care of refreshing it to obtain a + new session. + + If the refresh token in the current session is invalid and the current + session has expired, an error will be thrown. + + If the current session does not contain at `expires_at` field, + `set_session` will use the exp claim defined in the access token. + + The current session that minimally contains an access token, + refresh token and a user. + """ + time_now = round(time()) + expires_at = time_now + has_expired = True + session: Union[Session, None] = None + if access_token and access_token.split(".")[1]: + payload = self._decode_jwt(access_token) + if exp := payload.get("exp"): + expires_at = int(exp) + has_expired = expires_at <= time_now + if has_expired: + if not refresh_token: + raise AuthSessionMissingError() + response = self._refresh_access_token(refresh_token) + if not response.session: + return AuthResponse() + session = response.session + else: + response = self.get_user(access_token) + session = Session( + access_token=access_token, + refresh_token=refresh_token, + user=response.user, + token_type="bearer", + expires_in=expires_at - time_now, + expires_at=expires_at, + ) + self._save_session(session) + self._notify_all_subscribers("TOKEN_REFRESHED", session) + return AuthResponse(session=session, user=response.user) + + def refresh_session(self, refresh_token: Union[str, None] = None) -> AuthResponse: + """ + Returns a new session, regardless of expiry status. + + Takes in an optional current session. If not passed in, then refreshSession() + will attempt to retrieve it from getSession(). If the current session's + refresh token is invalid, an error will be thrown. + """ + if not refresh_token: + session = self.get_session() + if session: + refresh_token = session.refresh_token + if not refresh_token: + raise AuthSessionMissingError() + session = self._call_refresh_token(refresh_token) + return AuthResponse(session=session, user=session.user) + + def sign_out(self) -> None: + """ + Inside a browser context, `sign_out` will remove the logged in user from the + browser session and log them out - removing all items from localstorage and + then trigger a `"SIGNED_OUT"` event. + + For server-side management, you can revoke all refresh tokens for a user by + passing a user's JWT through to `api.sign_out`. + + There is no way to revoke a user's access token jwt until it expires. + It is recommended to set a shorter expiry on the jwt for this reason. + """ + session = self.get_session() + access_token = session.access_token if session else None + if access_token: + self.admin.sign_out(access_token) + self._remove_session() + self._notify_all_subscribers("SIGNED_OUT", None) + + def on_auth_state_change( + self, + callback: Callable[[AuthChangeEvent, Union[Session, None]], None], + ) -> Subscription: + """ + Receive a notification every time an auth event happens. + """ + unique_id = str(uuid4()) + + def _unsubscribe() -> None: + self._state_change_emitters.pop(unique_id) + + subscription = Subscription( + id=unique_id, + callback=callback, + unsubscribe=_unsubscribe, + ) + self._state_change_emitters[unique_id] = subscription + return subscription + + def reset_password_email( + self, + email: str, + options: Options = {}, + ) -> None: + """ + Sends a password reset request to an email address. + """ + self._request( + "POST", + "recover", + body={ + "email": email, + "gotrue_meta_security": { + "captcha_token": options.get("captcha_token"), + }, + }, + redirect_to=options.get("redirect_to"), + ) + + # MFA methods + + def _enroll(self, params: MFAEnrollParams) -> AuthMFAEnrollResponse: + session = self.get_session() + if not session: + raise AuthSessionMissingError() + response = self._request( + "POST", + "factors", + body=params, + jwt=session.access_token, + xform=AuthMFAEnrollResponse.parse_obj, + ) + if response.totp.qr_code: + response.totp.qr_code = f"data:image/svg+xml;utf-8,{response.totp.qr_code}" + return response + + def _challenge(self, params: MFAChallengeParams) -> AuthMFAChallengeResponse: + if session := self.get_session(): + return self._request( + "POST", + f"factors/{params.get('factor_id')}/challenge", + jwt=session.access_token, + xform=AuthMFAChallengeResponse.parse_obj, + ) + else: + raise AuthSessionMissingError() + + def _challenge_and_verify( + self, + params: MFAChallengeAndVerifyParams, + ) -> AuthMFAVerifyResponse: + response = self._challenge( + { + "factor_id": params.get("factor_id"), + } + ) + return self._verify( + { + "factor_id": params.get("factor_id"), + "challenge_id": response.id, + "code": params.get("code"), + } + ) + + def _verify(self, params: MFAVerifyParams) -> AuthMFAVerifyResponse: + session = self.get_session() + if not session: + raise AuthSessionMissingError() + response = self._request( + "POST", + f"factors/{params.get('factor_id')}/verify", + body=params, + jwt=session.access_token, + xform=AuthMFAVerifyResponse.parse_obj, + ) + session = Session.parse_obj(response.dict()) + self._save_session(session) + self._notify_all_subscribers("MFA_CHALLENGE_VERIFIED", session) + return response + + def _unenroll(self, params: MFAUnenrollParams) -> AuthMFAUnenrollResponse: + if session := self.get_session(): + return self._request( + "DELETE", + f"factors/{params.get('factor_id')}", + jwt=session.access_token, + xform=AuthMFAUnenrollResponse.parse_obj, + ) + else: + raise AuthSessionMissingError() + + def _list_factors(self) -> AuthMFAListFactorsResponse: + response = self.get_user() + all = response.user.factors or [] + totp = [f for f in all if f.factor_type == "totp" and f.status == "verified"] + return AuthMFAListFactorsResponse(all=all, totp=totp) + + def _get_authenticator_assurance_level( + self, + ) -> AuthMFAGetAuthenticatorAssuranceLevelResponse: + session = self.get_session() + if not session: + return AuthMFAGetAuthenticatorAssuranceLevelResponse( + current_level=None, + next_level=None, + current_authentication_methods=[], + ) + payload = self._decode_jwt(session.access_token) + current_level: Union[AuthenticatorAssuranceLevels, None] = None + if payload.get("aal"): + current_level = payload.get("aal") + verified_factors = [ + f for f in session.user.factors or [] if f.status == "verified" + ] + next_level = "aal2" if verified_factors else current_level + current_authentication_methods = payload.get("amr") or [] + return AuthMFAGetAuthenticatorAssuranceLevelResponse( + current_level=current_level, + next_level=next_level, + current_authentication_methods=current_authentication_methods, + ) + + # Private methods + + def _remove_session(self) -> None: + if self._persist_session: + self._storage.remove_item(self._storage_key) + else: + self._in_memory_session = None + if self._refresh_token_timer: + self._refresh_token_timer.cancel() + self._refresh_token_timer = None + + def _get_session_from_url( + self, + url: str, + ) -> Tuple[Session, Union[str, None]]: + if not self._is_implicit_grant_flow(url): + raise AuthImplicitGrantRedirectError("Not a valid implicit grant flow url.") + result = urlparse(url) + params = parse_qs(result.query) + if error_description := self._get_param(params, "error_description"): + error_code = self._get_param(params, "error_code") + error = self._get_param(params, "error") + if not error_code: + raise AuthImplicitGrantRedirectError("No error_code detected.") + if not error: + raise AuthImplicitGrantRedirectError("No error detected.") + raise AuthImplicitGrantRedirectError( + error_description, + {"code": error_code, "error": error}, + ) + provider_token = self._get_param(params, "provider_token") + provider_refresh_token = self._get_param(params, "provider_refresh_token") + access_token = self._get_param(params, "access_token") + if not access_token: + raise AuthImplicitGrantRedirectError("No access_token detected.") + expires_in = self._get_param(params, "expires_in") + if not expires_in: + raise AuthImplicitGrantRedirectError("No expires_in detected.") + refresh_token = self._get_param(params, "refresh_token") + if not refresh_token: + raise AuthImplicitGrantRedirectError("No refresh_token detected.") + token_type = self._get_param(params, "token_type") + if not token_type: + raise AuthImplicitGrantRedirectError("No token_type detected.") + time_now = round(time()) + expires_at = time_now + int(expires_in) + user = self.get_user(access_token) + session = Session( + provider_token=provider_token, + provider_refresh_token=provider_refresh_token, + access_token=access_token, + expires_in=int(expires_in), + expires_at=expires_at, + refresh_token=refresh_token, + token_type=token_type, + user=user.user, + ) + redirect_type = self._get_param(params, "type") + return session, redirect_type + + def _recover_and_refresh(self) -> None: + raw_session = self._storage.get_item(self._storage_key) + current_session = self._get_valid_session(raw_session) + if not current_session: + if raw_session: + self._remove_session() + return + time_now = round(time()) + expires_at = current_session.expires_at + if expires_at and expires_at < time_now + EXPIRY_MARGIN: + refresh_token = current_session.refresh_token + if self._auto_refresh_token and refresh_token: + self._network_retries += 1 + try: + self._call_refresh_token(refresh_token) + self._network_retries = 0 + except Exception as e: + if ( + isinstance(e, AuthRetryableError) + and self._network_retries < MAX_RETRIES + ): + if self._refresh_token_timer: + self._refresh_token_timer.cancel() + self._refresh_token_timer = Timer( + (RETRY_INTERVAL ** (self._network_retries * 100)), + self._recover_and_refresh, + ) + self._refresh_token_timer.start() + return + self._remove_session() + return + if self._persist_session: + self._save_session(current_session) + self._notify_all_subscribers("SIGNED_IN", current_session) + + def _call_refresh_token(self, refresh_token: str) -> Session: + if not refresh_token: + raise AuthSessionMissingError() + response = self._refresh_access_token(refresh_token) + if not response.session: + raise AuthSessionMissingError() + self._save_session(response.session) + self._notify_all_subscribers("TOKEN_REFRESHED", response.session) + return response.session + + def _refresh_access_token(self, refresh_token: str) -> AuthResponse: + return self._request( + "POST", + "token", + query={"grant_type": "refresh_token"}, + body={"refresh_token": refresh_token}, + xform=parse_auth_response, + ) + + def _save_session(self, session: Session) -> None: + if not self._persist_session: + self._in_memory_session = session + if expire_at := session.expires_at: + time_now = round(time()) + expire_in = expire_at - time_now + refresh_duration_before_expires = ( + EXPIRY_MARGIN if expire_in > EXPIRY_MARGIN else 0.5 + ) + value = (expire_in - refresh_duration_before_expires) * 1000 + self._start_auto_refresh_token(value) + if self._persist_session and session.expires_at: + self._storage.set_item(self._storage_key, session.json()) + + def _start_auto_refresh_token(self, value: float) -> None: + if self._refresh_token_timer: + self._refresh_token_timer.cancel() + self._refresh_token_timer = None + if value <= 0 or not self._auto_refresh_token: + return + + def refresh_token_function(): + self._network_retries += 1 + try: + session = self.get_session() + if session: + self._call_refresh_token(session.refresh_token) + self._network_retries = 0 + except Exception as e: + if ( + isinstance(e, AuthRetryableError) + and self._network_retries < MAX_RETRIES + ): + self._start_auto_refresh_token( + RETRY_INTERVAL ** (self._network_retries * 100) + ) + + self._refresh_token_timer = Timer(value, refresh_token_function) + self._refresh_token_timer.start() + + def _notify_all_subscribers( + self, + event: AuthChangeEvent, + session: Union[Session, None], + ) -> None: + for subscription in self._state_change_emitters.values(): + subscription.callback(event, session) + + def _get_valid_session( + self, + raw_session: Union[str, None], + ) -> Union[Session, None]: + if not raw_session: + return None + data = loads(raw_session) + if not data: + return None + if not data.get("access_token"): + return None + if not data.get("refresh_token"): + return None + if not data.get("expires_at"): + return None + try: + expires_at = int(data["expires_at"]) + data["expires_at"] = expires_at + except ValueError: + return None + try: + return Session.parse_obj(data) + except Exception: + return None + + def _get_param( + self, + query_params: Dict[str, List[str]], + name: str, + ) -> Union[str, None]: + return query_params[name][0] if name in query_params else None + + def _is_implicit_grant_flow(self, url: str) -> bool: + result = urlparse(url) + params = parse_qs(result.query) + return "access_token" in params or "error_description" in params + + def _get_url_for_provider( + self, + provider: Provider, + params: Dict[str, str], + ) -> str: + params = {k: quote(v) for k, v in params.items()} + params["provider"] = quote(provider) + query = urlencode(params) + return f"{self._url}/authorize?{query}" + + def _decode_jwt(self, jwt: str) -> DecodedJWTDict: + """ + Decodes a JWT (without performing any validation). + """ + return decode_jwt_payload(jwt) diff --git a/gotrue/errors.py b/gotrue/errors.py new file mode 100644 index 00000000..742d5d44 --- /dev/null +++ b/gotrue/errors.py @@ -0,0 +1,115 @@ +from __future__ import annotations + +from typing import Union + +from typing_extensions import TypedDict + + +class AuthError(Exception): + def __init__(self, message: str) -> None: + Exception.__init__(self, message) + self.message = message + self.name = "AuthError" + + +class AuthApiErrorDict(TypedDict): + name: str + message: str + status: int + + +class AuthApiError(AuthError): + def __init__(self, message: str, status: int) -> None: + AuthError.__init__(self, message) + self.name = "AuthApiError" + self.status = status + + def to_dict(self) -> AuthApiErrorDict: + return { + "name": self.name, + "message": self.message, + "status": self.status, + } + + +class AuthUnknownError(AuthError): + def __init__(self, message: str, original_error: Exception) -> None: + AuthError.__init__(self, message) + self.name = "AuthUnknownError" + self.original_error = original_error + + +class CustomAuthError(AuthError): + def __init__(self, message: str, name: str, status: int) -> None: + AuthError.__init__(self, message) + self.name = name + self.status = status + + def to_dict(self) -> AuthApiErrorDict: + return { + "name": self.name, + "message": self.message, + "status": self.status, + } + + +class AuthSessionMissingError(CustomAuthError): + def __init__(self) -> None: + CustomAuthError.__init__( + self, + "Auth session missing!", + "AuthSessionMissingError", + 400, + ) + + +class AuthInvalidCredentialsError(CustomAuthError): + def __init__(self, message: str) -> None: + CustomAuthError.__init__( + self, + message, + "AuthInvalidCredentialsError", + 400, + ) + + +class AuthImplicitGrantRedirectErrorDetails(TypedDict): + error: str + code: str + + +class AuthImplicitGrantRedirectErrorDict(AuthApiErrorDict): + details: Union[AuthImplicitGrantRedirectErrorDetails, None] + + +class AuthImplicitGrantRedirectError(CustomAuthError): + def __init__( + self, + message: str, + details: Union[AuthImplicitGrantRedirectErrorDetails, None] = None, + ) -> None: + CustomAuthError.__init__( + self, + message, + "AuthImplicitGrantRedirectError", + 500, + ) + self.details = details + + def to_dict(self) -> AuthImplicitGrantRedirectErrorDict: + return { + "name": self.name, + "message": self.message, + "status": self.status, + "details": self.details, + } + + +class AuthRetryableError(CustomAuthError): + def __init__(self, message: str, status: int) -> None: + CustomAuthError.__init__( + self, + message, + "AuthRetryableError", + status, + ) diff --git a/gotrue/helpers.py b/gotrue/helpers.py index 069b79b7..6aa73025 100644 --- a/gotrue/helpers.py +++ b/gotrue/helpers.py @@ -1,18 +1,86 @@ from __future__ import annotations -from urllib.parse import quote +from base64 import b64decode +from json import loads +from typing import Any, Union, cast -from httpx import HTTPError, Response +from httpx import HTTPStatusError -from .exceptions import APIError +from .errors import AuthApiError, AuthError, AuthRetryableError, AuthUnknownError +from .types import ( + AuthResponse, + GenerateLinkProperties, + GenerateLinkResponse, + Session, + User, + UserResponse, +) -def encode_uri_component(uri: str) -> str: - return quote(uri.encode("utf-8")) +def parse_auth_response(data: Any) -> AuthResponse: + session: Union[Session, None] = None + if ( + "access_token" in data + and "refresh_token" in data + and "expires_in" in data + and data["access_token"] + and data["refresh_token"] + and data["expires_in"] + ): + session = Session.parse_obj(data) + user = User.parse_obj(data["user"]) if "user" in data else User.parse_obj(data) + return AuthResponse(session=session, user=user) -def check_response(response: Response) -> None: +def parse_link_response(data: Any) -> GenerateLinkResponse: + properties = GenerateLinkProperties( + action_link=data.get("action_link"), + email_otp=data.get("email_otp"), + hashed_token=data.get("hashed_token"), + redirect_to=data.get("redirect_to"), + verification_type=data.get("verification_type"), + ) + user = User.parse_obj({k: v for k, v in data.items() if k not in properties.dict()}) + return GenerateLinkResponse(properties=properties, user=user) + + +def parse_user_response(data: Any) -> UserResponse: + if "user" not in data: + data = {"user": data} + return UserResponse.parse_obj(data) + + +def get_error_message(error: Any) -> str: + props = ["msg", "message", "error_description", "error"] + filter = ( + lambda prop: prop in error if isinstance(error, dict) else hasattr(error, prop) + ) + return next((error[prop] for prop in props if filter(prop)), str(error)) + + +def looks_like_http_status_error(exception: Exception) -> bool: + return isinstance(exception, HTTPStatusError) + + +def handle_exception(exception: Exception) -> AuthError: + if not looks_like_http_status_error(exception): + return AuthRetryableError(get_error_message(exception), 0) + error = cast(HTTPStatusError, exception) try: - response.raise_for_status() - except HTTPError: - raise APIError.from_dict(response.json()) + network_error_codes = [502, 503, 504] + if error.response.status_code in network_error_codes: + return AuthRetryableError( + get_error_message(error), error.response.status_code + ) + json = error.response.json() + return AuthApiError(get_error_message(json), error.response.status_code or 500) + except Exception as e: + return AuthUnknownError(get_error_message(error), e) + + +def decode_jwt_payload(token: str) -> Any: + parts = token.split(".") + if len(parts) != 3: + raise ValueError("JWT is not valid: not a JWT structure") + base64Url = parts[1] + return loads(b64decode(base64Url).decode("utf-8")) diff --git a/tests/_async/test_api_with_auto_confirm_disabled.py b/tests/_async/test_api_with_auto_confirm_disabled.py index 3cd458dd..9564688c 100644 --- a/tests/_async/test_api_with_auto_confirm_disabled.py +++ b/tests/_async/test_api_with_auto_confirm_disabled.py @@ -3,7 +3,7 @@ import pytest from faker import Faker -from gotrue import AsyncGoTrueAPI +from gotrue import AsycnGoTrueAdminAPI from gotrue.constants import COOKIE_OPTIONS from gotrue.types import CookieOptions, LinkType, User @@ -12,8 +12,8 @@ @pytest.fixture(name="api") -async def create_api() -> AsyncIterable[AsyncGoTrueAPI]: - async with AsyncGoTrueAPI( +async def create_api() -> AsyncIterable[AsycnGoTrueAdminAPI]: + async with AsycnGoTrueAdminAPI( url=GOTRUE_URL, headers={"Authorization": f"Bearer {TOKEN}"}, cookie_options=CookieOptions.parse_obj(COOKIE_OPTIONS), @@ -27,7 +27,7 @@ async def create_api() -> AsyncIterable[AsyncGoTrueAPI]: password = fake.password() -async def test_sign_up_with_email_and_password(api: AsyncGoTrueAPI): +async def test_sign_up_with_email_and_password(api: AsycnGoTrueAdminAPI): try: response = await api.sign_up_with_email( email=email, @@ -44,7 +44,7 @@ async def test_sign_up_with_email_and_password(api: AsyncGoTrueAPI): password2 = fake.password() -async def test_generate_sign_up_link(api: AsyncGoTrueAPI): +async def test_generate_sign_up_link(api: AsycnGoTrueAdminAPI): try: response = await api.generate_link( type=LinkType.signup, @@ -61,7 +61,7 @@ async def test_generate_sign_up_link(api: AsyncGoTrueAPI): email3 = f"api_generate_link_signup_{fake.email().lower()}" -async def test_generate_magic_link(api: AsyncGoTrueAPI): +async def test_generate_magic_link(api: AsycnGoTrueAdminAPI): try: response = await api.generate_link( type=LinkType.magiclink, @@ -73,7 +73,7 @@ async def test_generate_magic_link(api: AsyncGoTrueAPI): assert False, str(e) -async def test_generate_invite_link(api: AsyncGoTrueAPI): +async def test_generate_invite_link(api: AsycnGoTrueAdminAPI): try: response = await api.generate_link( type=LinkType.invite, @@ -86,7 +86,7 @@ async def test_generate_invite_link(api: AsyncGoTrueAPI): @pytest.mark.depends(on=[test_sign_up_with_email_and_password.__name__]) -async def test_generate_recovery_link(api: AsyncGoTrueAPI): +async def test_generate_recovery_link(api: AsycnGoTrueAdminAPI): try: response = await api.generate_link( type=LinkType.recovery, diff --git a/tests/_async/test_api_with_auto_confirm_enabled.py b/tests/_async/test_api_with_auto_confirm_enabled.py index 13ff0411..3ff74d93 100644 --- a/tests/_async/test_api_with_auto_confirm_enabled.py +++ b/tests/_async/test_api_with_auto_confirm_enabled.py @@ -3,7 +3,6 @@ import pytest from faker import Faker -from gotrue import AsyncGoTrueAPI from gotrue.constants import COOKIE_OPTIONS from gotrue.types import CookieOptions, Session, User @@ -12,8 +11,8 @@ @pytest.fixture(name="api") -async def create_api() -> AsyncIterable[AsyncGoTrueAPI]: - async with AsyncGoTrueAPI( +async def create_api() -> AsyncIterable[AsyncGoTrueAdminAPI]: + async with AsyncGoTrueAdminAPI( url=GOTRUE_URL, headers={"Authorization": f"Bearer {TOKEN}"}, cookie_options=CookieOptions.parse_obj(COOKIE_OPTIONS), @@ -28,7 +27,7 @@ async def create_api() -> AsyncIterable[AsyncGoTrueAPI]: valid_session: Optional[Session] = None -async def test_sign_up_with_email(api: AsyncGoTrueAPI): +async def test_sign_up_with_email(api: AsyncGoTrueAdminAPI): global valid_session try: response = await api.sign_up_with_email( @@ -43,7 +42,7 @@ async def test_sign_up_with_email(api: AsyncGoTrueAPI): @pytest.mark.depends(on=[test_sign_up_with_email.__name__]) -async def test_get_user(api: AsyncGoTrueAPI): +async def test_get_user(api: AsyncGoTrueAdminAPI): try: jwt = valid_session.access_token if valid_session else "" response = await api.get_user(jwt=jwt) @@ -53,7 +52,7 @@ async def test_get_user(api: AsyncGoTrueAPI): @pytest.mark.depends(on=[test_get_user.__name__]) -async def test_delete_user(api: AsyncGoTrueAPI): +async def test_delete_user(api: AsyncGoTrueAdminAPI): try: jwt = valid_session.access_token if valid_session else "" user = await api.get_user(jwt=jwt) diff --git a/tests/_async/test_client_with_sign_ups_disabled.py b/tests/_async/test_client_with_sign_ups_disabled.py index cae4bbd8..51cef276 100644 --- a/tests/_async/test_client_with_sign_ups_disabled.py +++ b/tests/_async/test_client_with_sign_ups_disabled.py @@ -3,7 +3,7 @@ import pytest from faker import Faker -from gotrue import AsyncGoTrueAPI, AsyncGoTrueClient +from gotrue import AsycnGoTrueAdminAPI, AsyncGoTrueClient from gotrue.constants import COOKIE_OPTIONS, DEFAULT_HEADERS from gotrue.exceptions import APIError from gotrue.types import CookieOptions, LinkType, User, UserAttributes @@ -13,8 +13,8 @@ @pytest.fixture(name="auth_admin") -async def create_auth_admin() -> AsyncIterable[AsyncGoTrueAPI]: - async with AsyncGoTrueAPI( +async def create_auth_admin() -> AsyncIterable[AsycnGoTrueAdminAPI]: + async with AsycnGoTrueAdminAPI( url=GOTRUE_URL, headers={"Authorization": f"Bearer {AUTH_ADMIN_TOKEN}"}, cookie_options=CookieOptions.parse_obj(COOKIE_OPTIONS), @@ -53,7 +53,7 @@ async def test_sign_up(client: AsyncGoTrueClient): async def test_generate_link_should_be_able_to_generate_multiple_links( - auth_admin: AsyncGoTrueAPI, + auth_admin: AsycnGoTrueAdminAPI, ): try: response = await auth_admin.generate_link( @@ -103,7 +103,7 @@ async def test_generate_link_should_be_able_to_generate_multiple_links( email2 = fake.email().lower() -async def test_create_user(auth_admin: AsyncGoTrueAPI): +async def test_create_user(auth_admin: AsycnGoTrueAdminAPI): try: attributes = UserAttributes(email=email2) response = await auth_admin.create_user(attributes=attributes) diff --git a/tests/_async/test_gotrue_admin_api.py b/tests/_async/test_gotrue_admin_api.py new file mode 100644 index 00000000..a70380cf --- /dev/null +++ b/tests/_async/test_gotrue_admin_api.py @@ -0,0 +1,273 @@ +from gotrue.errors import AuthError + +from .clients import ( + auth_client_with_session, + client_api_auto_confirm_disabled_client, + client_api_auto_confirm_off_signups_enabled_client, + service_role_api_client, +) +from .utils import ( + create_new_user_with_email, + mock_app_metadata, + mock_user_credentials, + mock_user_metadata, + mock_verification_otp, +) + + +async def test_create_user_should_create_a_new_user(): + credentials = mock_user_credentials() + response = await create_new_user_with_email(email=credentials.get("email")) + assert response.email == credentials.get("email") + + +async def test_create_user_with_user_metadata(): + user_metadata = mock_user_metadata() + credentials = mock_user_credentials() + response = await service_role_api_client().create_user( + { + "email": credentials.get("email"), + "password": credentials.get("password"), + "user_metadata": user_metadata, + } + ) + assert response.user.email == credentials.get("email") + assert response.user.user_metadata == user_metadata + assert "profile_image" in response.user.user_metadata + + +async def test_create_user_with_app_metadata(): + app_metadata = mock_app_metadata() + credentials = mock_user_credentials() + response = await service_role_api_client().create_user( + { + "email": credentials.get("email"), + "password": credentials.get("password"), + "app_metadata": app_metadata, + } + ) + assert response.user.email == credentials.get("email") + assert "provider" in response.user.app_metadata + assert "providers" in response.user.app_metadata + + +async def test_create_user_with_user_and_app_metadata(): + user_metadata = mock_user_metadata() + app_metadata = mock_app_metadata() + credentials = mock_user_credentials() + response = await service_role_api_client().create_user( + { + "email": credentials.get("email"), + "password": credentials.get("password"), + "user_metadata": user_metadata, + "app_metadata": app_metadata, + } + ) + assert response.user.email == credentials.get("email") + assert "profile_image" in response.user.user_metadata + assert "provider" in response.user.app_metadata + assert "providers" in response.user.app_metadata + + +async def test_list_users_should_return_registered_users(): + credentials = mock_user_credentials() + await create_new_user_with_email(email=credentials.get("email")) + users = await service_role_api_client().list_users() + assert users + emails = [user.email for user in users] + assert emails + assert credentials.get("email") in emails + + +async def test_get_user_fetches_a_user_by_their_access_token(): + credentials = mock_user_credentials() + auth_client_with_session_current_user = auth_client_with_session() + response = await auth_client_with_session_current_user.sign_up( + { + "email": credentials.get("email"), + "password": credentials.get("password"), + } + ) + assert response.session + response = await auth_client_with_session_current_user.get_user() + assert response.user.email == credentials.get("email") + + +async def test_get_user_by_id_should_a_registered_user_given_its_user_identifier(): + credentials = mock_user_credentials() + user = await create_new_user_with_email(email=credentials.get("email")) + assert user.id + response = await service_role_api_client().get_user_by_id(user.id) + assert response.user.email == credentials.get("email") + + +async def test_modify_email_using_update_user_by_id(): + credentials = mock_user_credentials() + user = await create_new_user_with_email(email=credentials.get("email")) + response = await service_role_api_client().update_user_by_id( + user.id, + { + "email": f"new_{user.email}", + }, + ) + assert response.user.email == f"new_{user.email}" + + +async def test_modify_user_metadata_using_update_user_by_id(): + credentials = mock_user_credentials() + user = await create_new_user_with_email(email=credentials.get("email")) + user_metadata = {"favorite_color": "yellow"} + response = await service_role_api_client().update_user_by_id( + user.id, + { + "user_metadata": user_metadata, + }, + ) + assert response.user.email == user.email + assert response.user.user_metadata == user_metadata + + +async def test_modify_app_metadata_using_update_user_by_id(): + credentials = mock_user_credentials() + user = await create_new_user_with_email(email=credentials.get("email")) + app_metadata = {"roles": ["admin", "publisher"]} + response = await service_role_api_client().update_user_by_id( + user.id, + { + "app_metadata": app_metadata, + }, + ) + assert response.user.email == user.email + assert "roles" in response.user.app_metadata + + +async def test_modify_confirm_email_using_update_user_by_id(): + credentials = mock_user_credentials() + response = await client_api_auto_confirm_off_signups_enabled_client().sign_up( + { + "email": credentials.get("email"), + "password": credentials.get("password"), + } + ) + assert response.user + assert not response.user.email_confirmed_at + response = await service_role_api_client().update_user_by_id( + response.user.id, + { + "email_confirm": True, + }, + ) + assert response.user.email_confirmed_at + + +async def test_delete_user_should_be_able_delete_an_existing_user(): + credentials = mock_user_credentials() + user = await create_new_user_with_email(email=credentials.get("email")) + await service_role_api_client().delete_user(user.id) + users = await service_role_api_client().list_users() + emails = [user.email for user in users] + assert credentials.get("email") not in emails + + +async def test_generate_link_supports_sign_up_with_generate_confirmation_signup_link(): + credentials = mock_user_credentials() + redirect_to = "http://localhost:9999/welcome" + user_metadata = {"status": "alpha"} + response = await service_role_api_client().generate_link( + { + "type": "signup", + "email": credentials.get("email"), + "password": credentials.get("password"), + "options": { + "data": user_metadata, + "redirect_to": redirect_to, + }, + }, + ) + assert response.user.user_metadata == user_metadata + + +async def test_generate_link_supports_updating_emails_with_generate_email_change_links(): # noqa: E501 + credentials = mock_user_credentials() + user = await create_new_user_with_email(email=credentials.get("email")) + assert user.email + assert user.email == credentials.get("email") + credentials = mock_user_credentials() + redirect_to = "http://localhost:9999/welcome" + response = await service_role_api_client().generate_link( + { + "type": "email_change_current", + "email": user.email, + "new_email": credentials.get("email"), + "options": { + "redirect_to": redirect_to, + }, + }, + ) + assert response.user.new_email == credentials.get("email") + + +async def test_invite_user_by_email_creates_a_new_user_with_an_invited_at_timestamp(): + credentials = mock_user_credentials() + redirect_to = "http://localhost:9999/welcome" + user_metadata = {"status": "alpha"} + response = await service_role_api_client().invite_user_by_email( + credentials.get("email"), + { + "data": user_metadata, + "redirect_to": redirect_to, + }, + ) + assert response.user.invited_at + + +async def test_sign_out_with_an_valid_access_token(): + credentials = mock_user_credentials() + response = await auth_client_with_session().sign_up( + { + "email": credentials.get("email"), + "password": credentials.get("password"), + }, + ) + assert response.session + response = await service_role_api_client().sign_out(response.session.access_token) + + +async def test_sign_out_with_an_invalid_access_token(): + try: + await service_role_api_client().sign_out("this-is-a-bad-token") + assert False + except AuthError: + pass + + +async def test_verify_otp_with_non_existent_phone_number(): + credentials = mock_user_credentials() + otp = mock_verification_otp() + try: + await client_api_auto_confirm_disabled_client().verify_otp( + { + "phone": credentials.get("phone"), + "token": otp, + "type": "sms", + }, + ) + assert False + except AuthError as e: + assert e.message == "User not found" + + +async def test_verify_otp_with_invalid_phone_number(): + credentials = mock_user_credentials() + otp = mock_verification_otp() + try: + await client_api_auto_confirm_disabled_client().verify_otp( + { + "phone": f"{credentials.get('phone')}-invalid", + "token": otp, + "type": "sms", + }, + ) + assert False + except AuthError as e: + assert e.message == "Invalid phone number format" diff --git a/tests/_sync/test_api_with_auto_confirm_disabled.py b/tests/_sync/test_api_with_auto_confirm_disabled.py index b87f489c..3fd66f6c 100644 --- a/tests/_sync/test_api_with_auto_confirm_disabled.py +++ b/tests/_sync/test_api_with_auto_confirm_disabled.py @@ -3,7 +3,7 @@ import pytest from faker import Faker -from gotrue import SyncGoTrueAPI +from gotrue import AsycnGoTrueAdminAPI from gotrue.constants import COOKIE_OPTIONS from gotrue.types import CookieOptions, LinkType, User @@ -12,8 +12,8 @@ @pytest.fixture(name="api") -def create_api() -> Iterable[SyncGoTrueAPI]: - with SyncGoTrueAPI( +def create_api() -> Iterable[AsycnGoTrueAdminAPI]: + with AsycnGoTrueAdminAPI( url=GOTRUE_URL, headers={"Authorization": f"Bearer {TOKEN}"}, cookie_options=CookieOptions.parse_obj(COOKIE_OPTIONS), @@ -27,7 +27,7 @@ def create_api() -> Iterable[SyncGoTrueAPI]: password = fake.password() -def test_sign_up_with_email_and_password(api: SyncGoTrueAPI): +def test_sign_up_with_email_and_password(api: AsycnGoTrueAdminAPI): try: response = api.sign_up_with_email( email=email, @@ -44,7 +44,7 @@ def test_sign_up_with_email_and_password(api: SyncGoTrueAPI): password2 = fake.password() -def test_generate_sign_up_link(api: SyncGoTrueAPI): +def test_generate_sign_up_link(api: AsycnGoTrueAdminAPI): try: response = api.generate_link( type=LinkType.signup, @@ -61,7 +61,7 @@ def test_generate_sign_up_link(api: SyncGoTrueAPI): email3 = f"api_generate_link_signup_{fake.email().lower()}" -def test_generate_magic_link(api: SyncGoTrueAPI): +def test_generate_magic_link(api: AsycnGoTrueAdminAPI): try: response = api.generate_link( type=LinkType.magiclink, @@ -73,7 +73,7 @@ def test_generate_magic_link(api: SyncGoTrueAPI): assert False, str(e) -def test_generate_invite_link(api: SyncGoTrueAPI): +def test_generate_invite_link(api: AsycnGoTrueAdminAPI): try: response = api.generate_link( type=LinkType.invite, @@ -86,7 +86,7 @@ def test_generate_invite_link(api: SyncGoTrueAPI): @pytest.mark.depends(on=[test_sign_up_with_email_and_password.__name__]) -def test_generate_recovery_link(api: SyncGoTrueAPI): +def test_generate_recovery_link(api: AsycnGoTrueAdminAPI): try: response = api.generate_link( type=LinkType.recovery, diff --git a/tests/_sync/test_api_with_auto_confirm_enabled.py b/tests/_sync/test_api_with_auto_confirm_enabled.py index 578646a8..d3830a2b 100644 --- a/tests/_sync/test_api_with_auto_confirm_enabled.py +++ b/tests/_sync/test_api_with_auto_confirm_enabled.py @@ -3,7 +3,6 @@ import pytest from faker import Faker -from gotrue import SyncGoTrueAPI from gotrue.constants import COOKIE_OPTIONS from gotrue.types import CookieOptions, Session, User @@ -12,8 +11,8 @@ @pytest.fixture(name="api") -def create_api() -> Iterable[SyncGoTrueAPI]: - with SyncGoTrueAPI( +def create_api() -> Iterable[SyncGoTrueAdminAPI]: + with SyncGoTrueAdminAPI( url=GOTRUE_URL, headers={"Authorization": f"Bearer {TOKEN}"}, cookie_options=CookieOptions.parse_obj(COOKIE_OPTIONS), @@ -28,7 +27,7 @@ def create_api() -> Iterable[SyncGoTrueAPI]: valid_session: Optional[Session] = None -def test_sign_up_with_email(api: SyncGoTrueAPI): +def test_sign_up_with_email(api: SyncGoTrueAdminAPI): global valid_session try: response = api.sign_up_with_email( @@ -43,7 +42,7 @@ def test_sign_up_with_email(api: SyncGoTrueAPI): @pytest.mark.depends(on=[test_sign_up_with_email.__name__]) -def test_get_user(api: SyncGoTrueAPI): +def test_get_user(api: SyncGoTrueAdminAPI): try: jwt = valid_session.access_token if valid_session else "" response = api.get_user(jwt=jwt) @@ -53,7 +52,7 @@ def test_get_user(api: SyncGoTrueAPI): @pytest.mark.depends(on=[test_get_user.__name__]) -def test_delete_user(api: SyncGoTrueAPI): +def test_delete_user(api: SyncGoTrueAdminAPI): try: jwt = valid_session.access_token if valid_session else "" user = api.get_user(jwt=jwt) diff --git a/tests/_sync/test_client_with_sign_ups_disabled.py b/tests/_sync/test_client_with_sign_ups_disabled.py index 0e0cb9c8..9ff48c0b 100644 --- a/tests/_sync/test_client_with_sign_ups_disabled.py +++ b/tests/_sync/test_client_with_sign_ups_disabled.py @@ -3,7 +3,7 @@ import pytest from faker import Faker -from gotrue import SyncGoTrueAPI, SyncGoTrueClient +from gotrue import AsycnGoTrueAdminAPI, SyncGoTrueClient from gotrue.constants import COOKIE_OPTIONS, DEFAULT_HEADERS from gotrue.exceptions import APIError from gotrue.types import CookieOptions, LinkType, User, UserAttributes @@ -13,8 +13,8 @@ @pytest.fixture(name="auth_admin") -def create_auth_admin() -> Iterable[SyncGoTrueAPI]: - with SyncGoTrueAPI( +def create_auth_admin() -> Iterable[AsycnGoTrueAdminAPI]: + with AsycnGoTrueAdminAPI( url=GOTRUE_URL, headers={"Authorization": f"Bearer {AUTH_ADMIN_TOKEN}"}, cookie_options=CookieOptions.parse_obj(COOKIE_OPTIONS), @@ -53,7 +53,7 @@ def test_sign_up(client: SyncGoTrueClient): def test_generate_link_should_be_able_to_generate_multiple_links( - auth_admin: SyncGoTrueAPI, + auth_admin: AsycnGoTrueAdminAPI, ): try: response = auth_admin.generate_link( @@ -103,7 +103,7 @@ def test_generate_link_should_be_able_to_generate_multiple_links( email2 = fake.email().lower() -def test_create_user(auth_admin: SyncGoTrueAPI): +def test_create_user(auth_admin: AsycnGoTrueAdminAPI): try: attributes = UserAttributes(email=email2) response = auth_admin.create_user(attributes=attributes) diff --git a/tests/_sync/test_gotrue_admin_api.py b/tests/_sync/test_gotrue_admin_api.py new file mode 100644 index 00000000..34cd9580 --- /dev/null +++ b/tests/_sync/test_gotrue_admin_api.py @@ -0,0 +1,273 @@ +from gotrue.errors import AuthError + +from .clients import ( + auth_client_with_session, + client_api_auto_confirm_disabled_client, + client_api_auto_confirm_off_signups_enabled_client, + service_role_api_client, +) +from .utils import ( + create_new_user_with_email, + mock_app_metadata, + mock_user_credentials, + mock_user_metadata, + mock_verification_otp, +) + + +def test_create_user_should_create_a_new_user(): + credentials = mock_user_credentials() + response = create_new_user_with_email(email=credentials.get("email")) + assert response.email == credentials.get("email") + + +def test_create_user_with_user_metadata(): + user_metadata = mock_user_metadata() + credentials = mock_user_credentials() + response = service_role_api_client().create_user( + { + "email": credentials.get("email"), + "password": credentials.get("password"), + "user_metadata": user_metadata, + } + ) + assert response.user.email == credentials.get("email") + assert response.user.user_metadata == user_metadata + assert "profile_image" in response.user.user_metadata + + +def test_create_user_with_app_metadata(): + app_metadata = mock_app_metadata() + credentials = mock_user_credentials() + response = service_role_api_client().create_user( + { + "email": credentials.get("email"), + "password": credentials.get("password"), + "app_metadata": app_metadata, + } + ) + assert response.user.email == credentials.get("email") + assert "provider" in response.user.app_metadata + assert "providers" in response.user.app_metadata + + +def test_create_user_with_user_and_app_metadata(): + user_metadata = mock_user_metadata() + app_metadata = mock_app_metadata() + credentials = mock_user_credentials() + response = service_role_api_client().create_user( + { + "email": credentials.get("email"), + "password": credentials.get("password"), + "user_metadata": user_metadata, + "app_metadata": app_metadata, + } + ) + assert response.user.email == credentials.get("email") + assert "profile_image" in response.user.user_metadata + assert "provider" in response.user.app_metadata + assert "providers" in response.user.app_metadata + + +def test_list_users_should_return_registered_users(): + credentials = mock_user_credentials() + create_new_user_with_email(email=credentials.get("email")) + users = service_role_api_client().list_users() + assert users + emails = [user.email for user in users] + assert emails + assert credentials.get("email") in emails + + +def test_get_user_fetches_a_user_by_their_access_token(): + credentials = mock_user_credentials() + auth_client_with_session_current_user = auth_client_with_session() + response = auth_client_with_session_current_user.sign_up( + { + "email": credentials.get("email"), + "password": credentials.get("password"), + } + ) + assert response.session + response = auth_client_with_session_current_user.get_user() + assert response.user.email == credentials.get("email") + + +def test_get_user_by_id_should_a_registered_user_given_its_user_identifier(): + credentials = mock_user_credentials() + user = create_new_user_with_email(email=credentials.get("email")) + assert user.id + response = service_role_api_client().get_user_by_id(user.id) + assert response.user.email == credentials.get("email") + + +def test_modify_email_using_update_user_by_id(): + credentials = mock_user_credentials() + user = create_new_user_with_email(email=credentials.get("email")) + response = service_role_api_client().update_user_by_id( + user.id, + { + "email": f"new_{user.email}", + }, + ) + assert response.user.email == f"new_{user.email}" + + +def test_modify_user_metadata_using_update_user_by_id(): + credentials = mock_user_credentials() + user = create_new_user_with_email(email=credentials.get("email")) + user_metadata = {"favorite_color": "yellow"} + response = service_role_api_client().update_user_by_id( + user.id, + { + "user_metadata": user_metadata, + }, + ) + assert response.user.email == user.email + assert response.user.user_metadata == user_metadata + + +def test_modify_app_metadata_using_update_user_by_id(): + credentials = mock_user_credentials() + user = create_new_user_with_email(email=credentials.get("email")) + app_metadata = {"roles": ["admin", "publisher"]} + response = service_role_api_client().update_user_by_id( + user.id, + { + "app_metadata": app_metadata, + }, + ) + assert response.user.email == user.email + assert "roles" in response.user.app_metadata + + +def test_modify_confirm_email_using_update_user_by_id(): + credentials = mock_user_credentials() + response = client_api_auto_confirm_off_signups_enabled_client().sign_up( + { + "email": credentials.get("email"), + "password": credentials.get("password"), + } + ) + assert response.user + assert not response.user.email_confirmed_at + response = service_role_api_client().update_user_by_id( + response.user.id, + { + "email_confirm": True, + }, + ) + assert response.user.email_confirmed_at + + +def test_delete_user_should_be_able_delete_an_existing_user(): + credentials = mock_user_credentials() + user = create_new_user_with_email(email=credentials.get("email")) + service_role_api_client().delete_user(user.id) + users = service_role_api_client().list_users() + emails = [user.email for user in users] + assert credentials.get("email") not in emails + + +def test_generate_link_supports_sign_up_with_generate_confirmation_signup_link(): + credentials = mock_user_credentials() + redirect_to = "http://localhost:9999/welcome" + user_metadata = {"status": "alpha"} + response = service_role_api_client().generate_link( + { + "type": "signup", + "email": credentials.get("email"), + "password": credentials.get("password"), + "options": { + "data": user_metadata, + "redirect_to": redirect_to, + }, + }, + ) + assert response.user.user_metadata == user_metadata + + +def test_generate_link_supports_updating_emails_with_generate_email_change_links(): # noqa: E501 + credentials = mock_user_credentials() + user = create_new_user_with_email(email=credentials.get("email")) + assert user.email + assert user.email == credentials.get("email") + credentials = mock_user_credentials() + redirect_to = "http://localhost:9999/welcome" + response = service_role_api_client().generate_link( + { + "type": "email_change_current", + "email": user.email, + "new_email": credentials.get("email"), + "options": { + "redirect_to": redirect_to, + }, + }, + ) + assert response.user.new_email == credentials.get("email") + + +def test_invite_user_by_email_creates_a_new_user_with_an_invited_at_timestamp(): + credentials = mock_user_credentials() + redirect_to = "http://localhost:9999/welcome" + user_metadata = {"status": "alpha"} + response = service_role_api_client().invite_user_by_email( + credentials.get("email"), + { + "data": user_metadata, + "redirect_to": redirect_to, + }, + ) + assert response.user.invited_at + + +def test_sign_out_with_an_valid_access_token(): + credentials = mock_user_credentials() + response = auth_client_with_session().sign_up( + { + "email": credentials.get("email"), + "password": credentials.get("password"), + }, + ) + assert response.session + response = service_role_api_client().sign_out(response.session.access_token) + + +def test_sign_out_with_an_invalid_access_token(): + try: + service_role_api_client().sign_out("this-is-a-bad-token") + assert False + except AuthError: + pass + + +def test_verify_otp_with_non_existent_phone_number(): + credentials = mock_user_credentials() + otp = mock_verification_otp() + try: + client_api_auto_confirm_disabled_client().verify_otp( + { + "phone": credentials.get("phone"), + "token": otp, + "type": "sms", + }, + ) + assert False + except AuthError as e: + assert e.message == "User not found" + + +def test_verify_otp_with_invalid_phone_number(): + credentials = mock_user_credentials() + otp = mock_verification_otp() + try: + client_api_auto_confirm_disabled_client().verify_otp( + { + "phone": f"{credentials.get('phone')}-invalid", + "token": otp, + "type": "sms", + }, + ) + assert False + except AuthError as e: + assert e.message == "Invalid phone number format"