Skip to content

Commit

Permalink
feat: change .update method to also allow dictionaries (#130)
Browse files Browse the repository at this point in the history
* Allow users to send a dict instead of UserAttributes model

* Add tests

* Check Python version before trying to import TypedDict

* Remove `TypedDict` from the main `typing` import

* Change format

* Change format of types

* 'Refactored by Sourcery'

Co-authored-by: odiseo0 <pedro.esserweb@gmail.com>
Co-authored-by: Sourcery AI <>
  • Loading branch information
sourcery-ai[bot] and odiseo0 committed Jul 13, 2022
1 parent 76fad56 commit df3f69e
Show file tree
Hide file tree
Showing 5 changed files with 83 additions and 12 deletions.
17 changes: 13 additions & 4 deletions gotrue/_async/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
Subscription,
User,
UserAttributes,
UserAttributesDict,
)
from .api import AsyncGoTrueAPI
from .storage import AsyncMemoryStorage, AsyncSupportedStorage
Expand Down Expand Up @@ -301,13 +302,15 @@ async def refresh_session(self) -> Session:
raise ValueError("Not logged in.")
return await self._call_refresh_token()

async def update(self, *, attributes: UserAttributes) -> User:
async def update(
self, *, attributes: Union[UserAttributesDict, UserAttributes]
) -> User:
"""Updates user data, if there is a logged in user.
Parameters
----------
attributes : UserAttributes
The attributes to update.
attributes : UserAttributesDict | UserAttributes
Attributes to update, could be: email, password, email_change_token, data
Returns
-------
Expand All @@ -321,9 +324,15 @@ async def update(self, *, attributes: UserAttributes) -> User:
"""
if not self.current_session:
raise ValueError("Not logged in.")

if isinstance(attributes, dict):
attributes_to_update = UserAttributes(**attributes)
else:
attributes_to_update = attributes

response = await self.api.update_user(
jwt=self.current_session.access_token,
attributes=attributes,
attributes=attributes_to_update,
)
self.current_session.user = response
await self._save_session(session=self.current_session)
Expand Down
11 changes: 9 additions & 2 deletions gotrue/_sync/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
Subscription,
User,
UserAttributes,
UserAttributesDict,
)
from .api import SyncGoTrueAPI
from .storage import SyncMemoryStorage, SyncSupportedStorage
Expand Down Expand Up @@ -299,7 +300,7 @@ def refresh_session(self) -> Session:
raise ValueError("Not logged in.")
return self._call_refresh_token()

def update(self, *, attributes: UserAttributes) -> User:
def update(self, *, attributes: Union[UserAttributesDict, UserAttributes]) -> User:
"""Updates user data, if there is a logged in user.
Parameters
Expand All @@ -319,9 +320,15 @@ def update(self, *, attributes: UserAttributes) -> User:
"""
if not self.current_session:
raise ValueError("Not logged in.")

if isinstance(attributes, dict):
attributes_to_update = UserAttributes(**attributes)
else:
attributes_to_update = attributes

response = self.api.update_user(
jwt=self.current_session.access_token,
attributes=attributes,
attributes=attributes_to_update,
)
self.current_session.user = response
self._save_session(session=self.current_session)
Expand Down
15 changes: 15 additions & 0 deletions gotrue/types.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,17 @@
from __future__ import annotations

import sys
from datetime import datetime
from enum import Enum
from time import time
from typing import Any, Callable, Dict, List, Optional, Type, TypeVar, Union
from uuid import UUID

if sys.version_info >= (3, 8):
from typing import TypedDict
else:
from typing_extensions import TypedDict

from httpx import Response
from pydantic import BaseModel, root_validator

Expand Down Expand Up @@ -150,3 +156,12 @@ class LinkType(str, Enum):
magiclink = "magiclink"
recovery = "recovery"
invite = "invite"


class UserAttributesDict(TypedDict):
"""Dict version of `UserAttributes`"""

email: Optional[str]
password: Optional[str]
email_change_token: Optional[str]
data: Optional[Any]
26 changes: 23 additions & 3 deletions tests/_async/test_client_with_auto_confirm_enabled.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,25 @@ async def test_update_user(client: AsyncGoTrueClient):
assert False, str(e)


@pytest.mark.asyncio
@pytest.mark.depends(on=[test_sign_in.__name__])
async def test_update_user_dict(client: AsyncGoTrueClient):
try:
await client.init_recover()
response = await client.update(attributes={"data": {"hello": "world"}})
assert isinstance(response, User)
assert response.id
assert response.email == email
assert response.email_confirmed_at
assert response.last_sign_in_at
assert response.created_at
assert response.updated_at
assert response.user_metadata
assert response.user_metadata.get("hello") == "world"
except Exception as e:
assert False, str(e)


@pytest.mark.asyncio
@pytest.mark.depends(on=[test_update_user.__name__])
async def test_get_user_after_update(client: AsyncGoTrueClient):
Expand Down Expand Up @@ -319,10 +338,10 @@ async def test_get_update_user_after_sign_out(client: AsyncGoTrueClient):
@pytest.mark.depends(on=[test_get_user_after_sign_out.__name__])
async def test_sign_in_with_the_wrong_password(client: AsyncGoTrueClient):
try:
await client.sign_in(email=email, password=password + "2")
await client.sign_in(email=email, password=f"{password}2")
assert False
except APIError:
assert True
pass
except Exception as e:
assert False, str(e)

Expand Down Expand Up @@ -401,8 +420,9 @@ async def test_get_session_from_url_errors(client: AsyncGoTrueClient):
error_description = fake.email()
try:
await client.get_session_from_url(
url=dummy_url + f"?error_description={error_description}"
url=f"{dummy_url}?error_description={error_description}"
)

assert False
except APIError as e:
assert e.code == 400
Expand Down
26 changes: 23 additions & 3 deletions tests/_sync/test_client_with_auto_confirm_enabled.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,25 @@ def test_update_user(client: SyncGoTrueClient):
assert False, str(e)


@pytest.mark.asyncio
@pytest.mark.depends(on=[test_sign_in.__name__])
def test_update_user(client: SyncGoTrueClient):
try:
client.init_recover()
response = client.update(attributes={"data": {"hello": "world"}})
assert isinstance(response, User)
assert response.id
assert response.email == email
assert response.email_confirmed_at
assert response.last_sign_in_at
assert response.created_at
assert response.updated_at
assert response.user_metadata
assert response.user_metadata.get("hello") == "world"
except Exception as e:
assert False, str(e)


@pytest.mark.asyncio
@pytest.mark.depends(on=[test_update_user.__name__])
def test_get_user_after_update(client: SyncGoTrueClient):
Expand Down Expand Up @@ -313,10 +332,10 @@ def test_get_update_user_after_sign_out(client: SyncGoTrueClient):
@pytest.mark.depends(on=[test_get_user_after_sign_out.__name__])
def test_sign_in_with_the_wrong_password(client: SyncGoTrueClient):
try:
client.sign_in(email=email, password=password + "2")
client.sign_in(email=email, password=f"{password}2")
assert False
except APIError:
assert True
pass
except Exception as e:
assert False, str(e)

Expand Down Expand Up @@ -395,8 +414,9 @@ def test_get_session_from_url_errors(client: SyncGoTrueClient):
error_description = fake.email()
try:
client.get_session_from_url(
url=dummy_url + f"?error_description={error_description}"
url=f"{dummy_url}?error_description={error_description}"
)

assert False
except APIError as e:
assert e.code == 400
Expand Down

0 comments on commit df3f69e

Please sign in to comment.