From 0a27464e1784df1952b352f3ca99fbf685b5bc86 Mon Sep 17 00:00:00 2001 From: odiseo0 Date: Tue, 12 Jul 2022 13:20:52 -0400 Subject: [PATCH 1/7] Allow users to send a dict instead of UserAttributes model --- gotrue/_async/client.py | 15 +++++++++++---- gotrue/_sync/client.py | 11 +++++++++-- gotrue/types.py | 11 ++++++++++- 3 files changed, 30 insertions(+), 7 deletions(-) diff --git a/gotrue/_async/client.py b/gotrue/_async/client.py index 4a1b6c5f..0298ba15 100644 --- a/gotrue/_async/client.py +++ b/gotrue/_async/client.py @@ -18,6 +18,7 @@ Subscription, User, UserAttributes, + UserAttributesDict ) from .api import AsyncGoTrueAPI from .storage import AsyncMemoryStorage, AsyncSupportedStorage @@ -301,13 +302,13 @@ async def refresh_session(self) -> Session: raise ValueError("Not logged in.") return await self._call_refresh_token() - async def update(self, *, attributes: UserAttributes) -> User: + async def update(self, *, attributes: Union[UserAttributesDict, UserAttributes]) -> User: """Updates user data, if there is a logged in user. Parameters ---------- - attributes : UserAttributes - The attributes to update. + attributes : UserAttributesDict | UserAttributes + Attributes to update, could be: email, password, email_change_token, data Returns ------- @@ -321,9 +322,15 @@ async def update(self, *, attributes: UserAttributes) -> User: """ if not self.current_session: raise ValueError("Not logged in.") + + if isinstance(attributes, dict): + attributes_to_update = UserAttributes(**attributes) + else: + attributes_to_update = attributes + response = await self.api.update_user( jwt=self.current_session.access_token, - attributes=attributes, + attributes=attributes_to_update, ) self.current_session.user = response await self._save_session(session=self.current_session) diff --git a/gotrue/_sync/client.py b/gotrue/_sync/client.py index 7d8d9ca4..99514cc9 100644 --- a/gotrue/_sync/client.py +++ b/gotrue/_sync/client.py @@ -18,6 +18,7 @@ Subscription, User, UserAttributes, + UserAttributesDict ) from .api import SyncGoTrueAPI from .storage import SyncMemoryStorage, SyncSupportedStorage @@ -299,7 +300,7 @@ def refresh_session(self) -> Session: raise ValueError("Not logged in.") return self._call_refresh_token() - def update(self, *, attributes: UserAttributes) -> User: + def update(self, *, attributes: Union[UserAttributesDict, UserAttributes]) -> User: """Updates user data, if there is a logged in user. Parameters @@ -319,9 +320,15 @@ def update(self, *, attributes: UserAttributes) -> User: """ if not self.current_session: raise ValueError("Not logged in.") + + if isinstance(attributes, dict): + attributes_to_update = UserAttributes(**attributes) + else: + attributes_to_update = attributes + response = self.api.update_user( jwt=self.current_session.access_token, - attributes=attributes, + attributes=attributes_to_update, ) self.current_session.user = response self._save_session(session=self.current_session) diff --git a/gotrue/types.py b/gotrue/types.py index 72b54cc7..2e426978 100644 --- a/gotrue/types.py +++ b/gotrue/types.py @@ -3,7 +3,7 @@ from datetime import datetime from enum import Enum from time import time -from typing import Any, Callable, Dict, List, Optional, Type, TypeVar, Union +from typing import Any, Callable, Dict, List, Optional, Type, TypeVar, Union, TypedDict from uuid import UUID from httpx import Response @@ -11,6 +11,7 @@ from gotrue.helpers import check_response + T = TypeVar("T", bound=BaseModel) @@ -150,3 +151,11 @@ class LinkType(str, Enum): magiclink = "magiclink" recovery = "recovery" invite = "invite" + + +class UserAttributesDict(TypedDict): + """Dict version of `UserAttributes`""" + email: Optional[str] + password: Optional[str] + email_change_token: Optional[str] + data: Optional[Any] From a32364ae108002982f2fd7a36b6922e12d9cb488 Mon Sep 17 00:00:00 2001 From: odiseo0 Date: Tue, 12 Jul 2022 13:20:59 -0400 Subject: [PATCH 2/7] Add tests --- .../test_client_with_auto_confirm_enabled.py | 21 +++++++++++++++++++ .../test_client_with_auto_confirm_enabled.py | 19 +++++++++++++++++ 2 files changed, 40 insertions(+) diff --git a/tests/_async/test_client_with_auto_confirm_enabled.py b/tests/_async/test_client_with_auto_confirm_enabled.py index 3c41e2d0..7e9f9495 100644 --- a/tests/_async/test_client_with_auto_confirm_enabled.py +++ b/tests/_async/test_client_with_auto_confirm_enabled.py @@ -259,6 +259,27 @@ async def test_update_user(client: AsyncGoTrueClient): assert False, str(e) +@pytest.mark.asyncio +@pytest.mark.depends(on=[test_sign_in.__name__]) +async def test_update_user_dict(client: AsyncGoTrueClient): + try: + await client.init_recover() + response = await client.update( + attributes={"data": {"hello":"world"}} + ) + assert isinstance(response, User) + assert response.id + assert response.email == email + assert response.email_confirmed_at + assert response.last_sign_in_at + assert response.created_at + assert response.updated_at + assert response.user_metadata + assert response.user_metadata.get("hello") == "world" + except Exception as e: + assert False, str(e) + + @pytest.mark.asyncio @pytest.mark.depends(on=[test_update_user.__name__]) async def test_get_user_after_update(client: AsyncGoTrueClient): diff --git a/tests/_sync/test_client_with_auto_confirm_enabled.py b/tests/_sync/test_client_with_auto_confirm_enabled.py index 65edf3c0..6291ad4a 100644 --- a/tests/_sync/test_client_with_auto_confirm_enabled.py +++ b/tests/_sync/test_client_with_auto_confirm_enabled.py @@ -253,6 +253,25 @@ def test_update_user(client: SyncGoTrueClient): assert False, str(e) +@pytest.mark.asyncio +@pytest.mark.depends(on=[test_sign_in.__name__]) +def test_update_user(client: SyncGoTrueClient): + try: + client.init_recover() + response = client.update(attributes={"data":{"hello": "world"}}) + assert isinstance(response, User) + assert response.id + assert response.email == email + assert response.email_confirmed_at + assert response.last_sign_in_at + assert response.created_at + assert response.updated_at + assert response.user_metadata + assert response.user_metadata.get("hello") == "world" + except Exception as e: + assert False, str(e) + + @pytest.mark.asyncio @pytest.mark.depends(on=[test_update_user.__name__]) def test_get_user_after_update(client: SyncGoTrueClient): From 6d3833c725c6a2a76c57ae521f5f116677a38e99 Mon Sep 17 00:00:00 2001 From: odiseo0 Date: Wed, 13 Jul 2022 10:54:37 -0400 Subject: [PATCH 3/7] Check Python version before trying to import TypedDict --- gotrue/types.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/gotrue/types.py b/gotrue/types.py index 2e426978..fd4f6660 100644 --- a/gotrue/types.py +++ b/gotrue/types.py @@ -1,11 +1,17 @@ from __future__ import annotations +import sys from datetime import datetime from enum import Enum from time import time from typing import Any, Callable, Dict, List, Optional, Type, TypeVar, Union, TypedDict from uuid import UUID +if sys.version_info >= (3, 8): + from typing import TypedDict +else: + from typing_extensions import TypedDict + from httpx import Response from pydantic import BaseModel, root_validator From 8ec45760d1fbef0a5c984f0a13177023d1e447ae Mon Sep 17 00:00:00 2001 From: odiseo0 Date: Wed, 13 Jul 2022 11:00:06 -0400 Subject: [PATCH 4/7] Remove `TypedDict` from the main `typing` import --- gotrue/types.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gotrue/types.py b/gotrue/types.py index fd4f6660..21d0fe7e 100644 --- a/gotrue/types.py +++ b/gotrue/types.py @@ -4,7 +4,7 @@ from datetime import datetime from enum import Enum from time import time -from typing import Any, Callable, Dict, List, Optional, Type, TypeVar, Union, TypedDict +from typing import Any, Callable, Dict, List, Optional, Type, TypeVar, Union from uuid import UUID if sys.version_info >= (3, 8): From 94ef8bfe569b99efaea252b4c7d1ae9b404c84d2 Mon Sep 17 00:00:00 2001 From: odiseo0 Date: Wed, 13 Jul 2022 12:02:23 -0400 Subject: [PATCH 5/7] Change format --- gotrue/_async/client.py | 6 ++++-- gotrue/_sync/client.py | 2 +- tests/_async/test_client_with_auto_confirm_enabled.py | 4 +--- tests/_sync/test_client_with_auto_confirm_enabled.py | 2 +- 4 files changed, 7 insertions(+), 7 deletions(-) diff --git a/gotrue/_async/client.py b/gotrue/_async/client.py index 0298ba15..6073cc71 100644 --- a/gotrue/_async/client.py +++ b/gotrue/_async/client.py @@ -18,7 +18,7 @@ Subscription, User, UserAttributes, - UserAttributesDict + UserAttributesDict, ) from .api import AsyncGoTrueAPI from .storage import AsyncMemoryStorage, AsyncSupportedStorage @@ -302,7 +302,9 @@ async def refresh_session(self) -> Session: raise ValueError("Not logged in.") return await self._call_refresh_token() - async def update(self, *, attributes: Union[UserAttributesDict, UserAttributes]) -> User: + async def update( + self, *, attributes: Union[UserAttributesDict, UserAttributes] + ) -> User: """Updates user data, if there is a logged in user. Parameters diff --git a/gotrue/_sync/client.py b/gotrue/_sync/client.py index 99514cc9..09f8f1d1 100644 --- a/gotrue/_sync/client.py +++ b/gotrue/_sync/client.py @@ -18,7 +18,7 @@ Subscription, User, UserAttributes, - UserAttributesDict + UserAttributesDict, ) from .api import SyncGoTrueAPI from .storage import SyncMemoryStorage, SyncSupportedStorage diff --git a/tests/_async/test_client_with_auto_confirm_enabled.py b/tests/_async/test_client_with_auto_confirm_enabled.py index 7e9f9495..b1b3eeb4 100644 --- a/tests/_async/test_client_with_auto_confirm_enabled.py +++ b/tests/_async/test_client_with_auto_confirm_enabled.py @@ -264,9 +264,7 @@ async def test_update_user(client: AsyncGoTrueClient): async def test_update_user_dict(client: AsyncGoTrueClient): try: await client.init_recover() - response = await client.update( - attributes={"data": {"hello":"world"}} - ) + response = await client.update(attributes={"data": {"hello": "world"}}) assert isinstance(response, User) assert response.id assert response.email == email diff --git a/tests/_sync/test_client_with_auto_confirm_enabled.py b/tests/_sync/test_client_with_auto_confirm_enabled.py index 6291ad4a..8848e9c4 100644 --- a/tests/_sync/test_client_with_auto_confirm_enabled.py +++ b/tests/_sync/test_client_with_auto_confirm_enabled.py @@ -258,7 +258,7 @@ def test_update_user(client: SyncGoTrueClient): def test_update_user(client: SyncGoTrueClient): try: client.init_recover() - response = client.update(attributes={"data":{"hello": "world"}}) + response = client.update(attributes={"data": {"hello": "world"}}) assert isinstance(response, User) assert response.id assert response.email == email From 8e41782f6e35b29d65da3db531213bdaab1e3df5 Mon Sep 17 00:00:00 2001 From: odiseo0 Date: Wed, 13 Jul 2022 12:13:36 -0400 Subject: [PATCH 6/7] Change format of types --- gotrue/types.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gotrue/types.py b/gotrue/types.py index 21d0fe7e..63680b41 100644 --- a/gotrue/types.py +++ b/gotrue/types.py @@ -17,7 +17,6 @@ from gotrue.helpers import check_response - T = TypeVar("T", bound=BaseModel) @@ -161,6 +160,7 @@ class LinkType(str, Enum): class UserAttributesDict(TypedDict): """Dict version of `UserAttributes`""" + email: Optional[str] password: Optional[str] email_change_token: Optional[str] From 0d6ed3cd2bda72278951bbda42855c299e2e754c Mon Sep 17 00:00:00 2001 From: Sourcery AI <> Date: Wed, 13 Jul 2022 16:14:08 +0000 Subject: [PATCH 7/7] 'Refactored by Sourcery' --- tests/_async/test_client_with_auto_confirm_enabled.py | 7 ++++--- tests/_sync/test_client_with_auto_confirm_enabled.py | 7 ++++--- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/tests/_async/test_client_with_auto_confirm_enabled.py b/tests/_async/test_client_with_auto_confirm_enabled.py index b1b3eeb4..b82ab552 100644 --- a/tests/_async/test_client_with_auto_confirm_enabled.py +++ b/tests/_async/test_client_with_auto_confirm_enabled.py @@ -338,10 +338,10 @@ async def test_get_update_user_after_sign_out(client: AsyncGoTrueClient): @pytest.mark.depends(on=[test_get_user_after_sign_out.__name__]) async def test_sign_in_with_the_wrong_password(client: AsyncGoTrueClient): try: - await client.sign_in(email=email, password=password + "2") + await client.sign_in(email=email, password=f"{password}2") assert False except APIError: - assert True + pass except Exception as e: assert False, str(e) @@ -420,8 +420,9 @@ async def test_get_session_from_url_errors(client: AsyncGoTrueClient): error_description = fake.email() try: await client.get_session_from_url( - url=dummy_url + f"?error_description={error_description}" + url=f"{dummy_url}?error_description={error_description}" ) + assert False except APIError as e: assert e.code == 400 diff --git a/tests/_sync/test_client_with_auto_confirm_enabled.py b/tests/_sync/test_client_with_auto_confirm_enabled.py index 8848e9c4..ce4ba110 100644 --- a/tests/_sync/test_client_with_auto_confirm_enabled.py +++ b/tests/_sync/test_client_with_auto_confirm_enabled.py @@ -332,10 +332,10 @@ def test_get_update_user_after_sign_out(client: SyncGoTrueClient): @pytest.mark.depends(on=[test_get_user_after_sign_out.__name__]) def test_sign_in_with_the_wrong_password(client: SyncGoTrueClient): try: - client.sign_in(email=email, password=password + "2") + client.sign_in(email=email, password=f"{password}2") assert False except APIError: - assert True + pass except Exception as e: assert False, str(e) @@ -414,8 +414,9 @@ def test_get_session_from_url_errors(client: SyncGoTrueClient): error_description = fake.email() try: client.get_session_from_url( - url=dummy_url + f"?error_description={error_description}" + url=f"{dummy_url}?error_description={error_description}" ) + assert False except APIError as e: assert e.code == 400