diff --git a/gotrue/_async/client.py b/gotrue/_async/client.py index 4a1b6c5f..6073cc71 100644 --- a/gotrue/_async/client.py +++ b/gotrue/_async/client.py @@ -18,6 +18,7 @@ Subscription, User, UserAttributes, + UserAttributesDict, ) from .api import AsyncGoTrueAPI from .storage import AsyncMemoryStorage, AsyncSupportedStorage @@ -301,13 +302,15 @@ async def refresh_session(self) -> Session: raise ValueError("Not logged in.") return await self._call_refresh_token() - async def update(self, *, attributes: UserAttributes) -> User: + async def update( + self, *, attributes: Union[UserAttributesDict, UserAttributes] + ) -> User: """Updates user data, if there is a logged in user. Parameters ---------- - attributes : UserAttributes - The attributes to update. + attributes : UserAttributesDict | UserAttributes + Attributes to update, could be: email, password, email_change_token, data Returns ------- @@ -321,9 +324,15 @@ async def update(self, *, attributes: UserAttributes) -> User: """ if not self.current_session: raise ValueError("Not logged in.") + + if isinstance(attributes, dict): + attributes_to_update = UserAttributes(**attributes) + else: + attributes_to_update = attributes + response = await self.api.update_user( jwt=self.current_session.access_token, - attributes=attributes, + attributes=attributes_to_update, ) self.current_session.user = response await self._save_session(session=self.current_session) diff --git a/gotrue/_sync/client.py b/gotrue/_sync/client.py index 7d8d9ca4..09f8f1d1 100644 --- a/gotrue/_sync/client.py +++ b/gotrue/_sync/client.py @@ -18,6 +18,7 @@ Subscription, User, UserAttributes, + UserAttributesDict, ) from .api import SyncGoTrueAPI from .storage import SyncMemoryStorage, SyncSupportedStorage @@ -299,7 +300,7 @@ def refresh_session(self) -> Session: raise ValueError("Not logged in.") return self._call_refresh_token() - def update(self, *, attributes: UserAttributes) -> User: + def update(self, *, attributes: Union[UserAttributesDict, UserAttributes]) -> User: """Updates user data, if there is a logged in user. Parameters @@ -319,9 +320,15 @@ def update(self, *, attributes: UserAttributes) -> User: """ if not self.current_session: raise ValueError("Not logged in.") + + if isinstance(attributes, dict): + attributes_to_update = UserAttributes(**attributes) + else: + attributes_to_update = attributes + response = self.api.update_user( jwt=self.current_session.access_token, - attributes=attributes, + attributes=attributes_to_update, ) self.current_session.user = response self._save_session(session=self.current_session) diff --git a/gotrue/types.py b/gotrue/types.py index 72b54cc7..63680b41 100644 --- a/gotrue/types.py +++ b/gotrue/types.py @@ -1,11 +1,17 @@ from __future__ import annotations +import sys from datetime import datetime from enum import Enum from time import time from typing import Any, Callable, Dict, List, Optional, Type, TypeVar, Union from uuid import UUID +if sys.version_info >= (3, 8): + from typing import TypedDict +else: + from typing_extensions import TypedDict + from httpx import Response from pydantic import BaseModel, root_validator @@ -150,3 +156,12 @@ class LinkType(str, Enum): magiclink = "magiclink" recovery = "recovery" invite = "invite" + + +class UserAttributesDict(TypedDict): + """Dict version of `UserAttributes`""" + + email: Optional[str] + password: Optional[str] + email_change_token: Optional[str] + data: Optional[Any] diff --git a/tests/_async/test_client_with_auto_confirm_enabled.py b/tests/_async/test_client_with_auto_confirm_enabled.py index 3c41e2d0..b82ab552 100644 --- a/tests/_async/test_client_with_auto_confirm_enabled.py +++ b/tests/_async/test_client_with_auto_confirm_enabled.py @@ -259,6 +259,25 @@ async def test_update_user(client: AsyncGoTrueClient): assert False, str(e) +@pytest.mark.asyncio +@pytest.mark.depends(on=[test_sign_in.__name__]) +async def test_update_user_dict(client: AsyncGoTrueClient): + try: + await client.init_recover() + response = await client.update(attributes={"data": {"hello": "world"}}) + assert isinstance(response, User) + assert response.id + assert response.email == email + assert response.email_confirmed_at + assert response.last_sign_in_at + assert response.created_at + assert response.updated_at + assert response.user_metadata + assert response.user_metadata.get("hello") == "world" + except Exception as e: + assert False, str(e) + + @pytest.mark.asyncio @pytest.mark.depends(on=[test_update_user.__name__]) async def test_get_user_after_update(client: AsyncGoTrueClient): @@ -319,10 +338,10 @@ async def test_get_update_user_after_sign_out(client: AsyncGoTrueClient): @pytest.mark.depends(on=[test_get_user_after_sign_out.__name__]) async def test_sign_in_with_the_wrong_password(client: AsyncGoTrueClient): try: - await client.sign_in(email=email, password=password + "2") + await client.sign_in(email=email, password=f"{password}2") assert False except APIError: - assert True + pass except Exception as e: assert False, str(e) @@ -401,8 +420,9 @@ async def test_get_session_from_url_errors(client: AsyncGoTrueClient): error_description = fake.email() try: await client.get_session_from_url( - url=dummy_url + f"?error_description={error_description}" + url=f"{dummy_url}?error_description={error_description}" ) + assert False except APIError as e: assert e.code == 400 diff --git a/tests/_sync/test_client_with_auto_confirm_enabled.py b/tests/_sync/test_client_with_auto_confirm_enabled.py index 65edf3c0..ce4ba110 100644 --- a/tests/_sync/test_client_with_auto_confirm_enabled.py +++ b/tests/_sync/test_client_with_auto_confirm_enabled.py @@ -253,6 +253,25 @@ def test_update_user(client: SyncGoTrueClient): assert False, str(e) +@pytest.mark.asyncio +@pytest.mark.depends(on=[test_sign_in.__name__]) +def test_update_user(client: SyncGoTrueClient): + try: + client.init_recover() + response = client.update(attributes={"data": {"hello": "world"}}) + assert isinstance(response, User) + assert response.id + assert response.email == email + assert response.email_confirmed_at + assert response.last_sign_in_at + assert response.created_at + assert response.updated_at + assert response.user_metadata + assert response.user_metadata.get("hello") == "world" + except Exception as e: + assert False, str(e) + + @pytest.mark.asyncio @pytest.mark.depends(on=[test_update_user.__name__]) def test_get_user_after_update(client: SyncGoTrueClient): @@ -313,10 +332,10 @@ def test_get_update_user_after_sign_out(client: SyncGoTrueClient): @pytest.mark.depends(on=[test_get_user_after_sign_out.__name__]) def test_sign_in_with_the_wrong_password(client: SyncGoTrueClient): try: - client.sign_in(email=email, password=password + "2") + client.sign_in(email=email, password=f"{password}2") assert False except APIError: - assert True + pass except Exception as e: assert False, str(e) @@ -395,8 +414,9 @@ def test_get_session_from_url_errors(client: SyncGoTrueClient): error_description = fake.email() try: client.get_session_from_url( - url=dummy_url + f"?error_description={error_description}" + url=f"{dummy_url}?error_description={error_description}" ) + assert False except APIError as e: assert e.code == 400