diff --git a/gotrue/_async/client.py b/gotrue/_async/client.py index 4a1b6c5f..6073cc71 100644 --- a/gotrue/_async/client.py +++ b/gotrue/_async/client.py @@ -18,6 +18,7 @@ Subscription, User, UserAttributes, + UserAttributesDict, ) from .api import AsyncGoTrueAPI from .storage import AsyncMemoryStorage, AsyncSupportedStorage @@ -301,13 +302,15 @@ async def refresh_session(self) -> Session: raise ValueError("Not logged in.") return await self._call_refresh_token() - async def update(self, *, attributes: UserAttributes) -> User: + async def update( + self, *, attributes: Union[UserAttributesDict, UserAttributes] + ) -> User: """Updates user data, if there is a logged in user. Parameters ---------- - attributes : UserAttributes - The attributes to update. + attributes : UserAttributesDict | UserAttributes + Attributes to update, could be: email, password, email_change_token, data Returns ------- @@ -321,9 +324,15 @@ async def update(self, *, attributes: UserAttributes) -> User: """ if not self.current_session: raise ValueError("Not logged in.") + + if isinstance(attributes, dict): + attributes_to_update = UserAttributes(**attributes) + else: + attributes_to_update = attributes + response = await self.api.update_user( jwt=self.current_session.access_token, - attributes=attributes, + attributes=attributes_to_update, ) self.current_session.user = response await self._save_session(session=self.current_session) diff --git a/gotrue/_sync/client.py b/gotrue/_sync/client.py index 7d8d9ca4..09f8f1d1 100644 --- a/gotrue/_sync/client.py +++ b/gotrue/_sync/client.py @@ -18,6 +18,7 @@ Subscription, User, UserAttributes, + UserAttributesDict, ) from .api import SyncGoTrueAPI from .storage import SyncMemoryStorage, SyncSupportedStorage @@ -299,7 +300,7 @@ def refresh_session(self) -> Session: raise ValueError("Not logged in.") return self._call_refresh_token() - def update(self, *, attributes: UserAttributes) -> User: + def update(self, *, attributes: Union[UserAttributesDict, UserAttributes]) -> User: """Updates user data, if there is a logged in user. Parameters @@ -319,9 +320,15 @@ def update(self, *, attributes: UserAttributes) -> User: """ if not self.current_session: raise ValueError("Not logged in.") + + if isinstance(attributes, dict): + attributes_to_update = UserAttributes(**attributes) + else: + attributes_to_update = attributes + response = self.api.update_user( jwt=self.current_session.access_token, - attributes=attributes, + attributes=attributes_to_update, ) self.current_session.user = response self._save_session(session=self.current_session) diff --git a/gotrue/types.py b/gotrue/types.py index 72b54cc7..63680b41 100644 --- a/gotrue/types.py +++ b/gotrue/types.py @@ -1,11 +1,17 @@ from __future__ import annotations +import sys from datetime import datetime from enum import Enum from time import time from typing import Any, Callable, Dict, List, Optional, Type, TypeVar, Union from uuid import UUID +if sys.version_info >= (3, 8): + from typing import TypedDict +else: + from typing_extensions import TypedDict + from httpx import Response from pydantic import BaseModel, root_validator @@ -150,3 +156,12 @@ class LinkType(str, Enum): magiclink = "magiclink" recovery = "recovery" invite = "invite" + + +class UserAttributesDict(TypedDict): + """Dict version of `UserAttributes`""" + + email: Optional[str] + password: Optional[str] + email_change_token: Optional[str] + data: Optional[Any] diff --git a/tests/_async/test_client_with_auto_confirm_enabled.py b/tests/_async/test_client_with_auto_confirm_enabled.py index 3c41e2d0..b1b3eeb4 100644 --- a/tests/_async/test_client_with_auto_confirm_enabled.py +++ b/tests/_async/test_client_with_auto_confirm_enabled.py @@ -259,6 +259,25 @@ async def test_update_user(client: AsyncGoTrueClient): assert False, str(e) +@pytest.mark.asyncio +@pytest.mark.depends(on=[test_sign_in.__name__]) +async def test_update_user_dict(client: AsyncGoTrueClient): + try: + await client.init_recover() + response = await client.update(attributes={"data": {"hello": "world"}}) + assert isinstance(response, User) + assert response.id + assert response.email == email + assert response.email_confirmed_at + assert response.last_sign_in_at + assert response.created_at + assert response.updated_at + assert response.user_metadata + assert response.user_metadata.get("hello") == "world" + except Exception as e: + assert False, str(e) + + @pytest.mark.asyncio @pytest.mark.depends(on=[test_update_user.__name__]) async def test_get_user_after_update(client: AsyncGoTrueClient): diff --git a/tests/_sync/test_client_with_auto_confirm_enabled.py b/tests/_sync/test_client_with_auto_confirm_enabled.py index 65edf3c0..8848e9c4 100644 --- a/tests/_sync/test_client_with_auto_confirm_enabled.py +++ b/tests/_sync/test_client_with_auto_confirm_enabled.py @@ -253,6 +253,25 @@ def test_update_user(client: SyncGoTrueClient): assert False, str(e) +@pytest.mark.asyncio +@pytest.mark.depends(on=[test_sign_in.__name__]) +def test_update_user(client: SyncGoTrueClient): + try: + client.init_recover() + response = client.update(attributes={"data": {"hello": "world"}}) + assert isinstance(response, User) + assert response.id + assert response.email == email + assert response.email_confirmed_at + assert response.last_sign_in_at + assert response.created_at + assert response.updated_at + assert response.user_metadata + assert response.user_metadata.get("hello") == "world" + except Exception as e: + assert False, str(e) + + @pytest.mark.asyncio @pytest.mark.depends(on=[test_update_user.__name__]) def test_get_user_after_update(client: SyncGoTrueClient):