Skip to content
This repository was archived by the owner on Sep 8, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 13 additions & 4 deletions gotrue/_async/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
Subscription,
User,
UserAttributes,
UserAttributesDict,
)
from .api import AsyncGoTrueAPI
from .storage import AsyncMemoryStorage, AsyncSupportedStorage
Expand Down Expand Up @@ -301,13 +302,15 @@ async def refresh_session(self) -> Session:
raise ValueError("Not logged in.")
return await self._call_refresh_token()

async def update(self, *, attributes: UserAttributes) -> User:
async def update(
self, *, attributes: Union[UserAttributesDict, UserAttributes]
) -> User:
"""Updates user data, if there is a logged in user.
Parameters
----------
attributes : UserAttributes
The attributes to update.
attributes : UserAttributesDict | UserAttributes
Attributes to update, could be: email, password, email_change_token, data
Returns
-------
Expand All @@ -321,9 +324,15 @@ async def update(self, *, attributes: UserAttributes) -> User:
"""
if not self.current_session:
raise ValueError("Not logged in.")

if isinstance(attributes, dict):
attributes_to_update = UserAttributes(**attributes)
else:
attributes_to_update = attributes

response = await self.api.update_user(
jwt=self.current_session.access_token,
attributes=attributes,
attributes=attributes_to_update,
)
self.current_session.user = response
await self._save_session(session=self.current_session)
Expand Down
11 changes: 9 additions & 2 deletions gotrue/_sync/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
Subscription,
User,
UserAttributes,
UserAttributesDict,
)
from .api import SyncGoTrueAPI
from .storage import SyncMemoryStorage, SyncSupportedStorage
Expand Down Expand Up @@ -299,7 +300,7 @@ def refresh_session(self) -> Session:
raise ValueError("Not logged in.")
return self._call_refresh_token()

def update(self, *, attributes: UserAttributes) -> User:
def update(self, *, attributes: Union[UserAttributesDict, UserAttributes]) -> User:
"""Updates user data, if there is a logged in user.
Parameters
Expand All @@ -319,9 +320,15 @@ def update(self, *, attributes: UserAttributes) -> User:
"""
if not self.current_session:
raise ValueError("Not logged in.")

if isinstance(attributes, dict):
attributes_to_update = UserAttributes(**attributes)
else:
attributes_to_update = attributes

response = self.api.update_user(
jwt=self.current_session.access_token,
attributes=attributes,
attributes=attributes_to_update,
)
self.current_session.user = response
self._save_session(session=self.current_session)
Expand Down
15 changes: 15 additions & 0 deletions gotrue/types.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,17 @@
from __future__ import annotations

import sys
from datetime import datetime
from enum import Enum
from time import time
from typing import Any, Callable, Dict, List, Optional, Type, TypeVar, Union
from uuid import UUID

if sys.version_info >= (3, 8):
from typing import TypedDict
else:
from typing_extensions import TypedDict

from httpx import Response
from pydantic import BaseModel, root_validator

Expand Down Expand Up @@ -150,3 +156,12 @@ class LinkType(str, Enum):
magiclink = "magiclink"
recovery = "recovery"
invite = "invite"


class UserAttributesDict(TypedDict):
"""Dict version of `UserAttributes`"""

email: Optional[str]
password: Optional[str]
email_change_token: Optional[str]
data: Optional[Any]
26 changes: 23 additions & 3 deletions tests/_async/test_client_with_auto_confirm_enabled.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,25 @@ async def test_update_user(client: AsyncGoTrueClient):
assert False, str(e)


@pytest.mark.asyncio
@pytest.mark.depends(on=[test_sign_in.__name__])
async def test_update_user_dict(client: AsyncGoTrueClient):
try:
await client.init_recover()
response = await client.update(attributes={"data": {"hello": "world"}})
assert isinstance(response, User)
assert response.id
assert response.email == email
assert response.email_confirmed_at
assert response.last_sign_in_at
assert response.created_at
assert response.updated_at
assert response.user_metadata
assert response.user_metadata.get("hello") == "world"
except Exception as e:
assert False, str(e)


@pytest.mark.asyncio
@pytest.mark.depends(on=[test_update_user.__name__])
async def test_get_user_after_update(client: AsyncGoTrueClient):
Expand Down Expand Up @@ -319,10 +338,10 @@ async def test_get_update_user_after_sign_out(client: AsyncGoTrueClient):
@pytest.mark.depends(on=[test_get_user_after_sign_out.__name__])
async def test_sign_in_with_the_wrong_password(client: AsyncGoTrueClient):
try:
await client.sign_in(email=email, password=password + "2")
await client.sign_in(email=email, password=f"{password}2")
assert False
except APIError:
assert True
pass
except Exception as e:
assert False, str(e)

Expand Down Expand Up @@ -401,8 +420,9 @@ async def test_get_session_from_url_errors(client: AsyncGoTrueClient):
error_description = fake.email()
try:
await client.get_session_from_url(
url=dummy_url + f"?error_description={error_description}"
url=f"{dummy_url}?error_description={error_description}"
)

assert False
except APIError as e:
assert e.code == 400
Expand Down
26 changes: 23 additions & 3 deletions tests/_sync/test_client_with_auto_confirm_enabled.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,25 @@ def test_update_user(client: SyncGoTrueClient):
assert False, str(e)


@pytest.mark.asyncio
@pytest.mark.depends(on=[test_sign_in.__name__])
def test_update_user(client: SyncGoTrueClient):
try:
client.init_recover()
response = client.update(attributes={"data": {"hello": "world"}})
assert isinstance(response, User)
assert response.id
assert response.email == email
assert response.email_confirmed_at
assert response.last_sign_in_at
assert response.created_at
assert response.updated_at
assert response.user_metadata
assert response.user_metadata.get("hello") == "world"
except Exception as e:
assert False, str(e)


@pytest.mark.asyncio
@pytest.mark.depends(on=[test_update_user.__name__])
def test_get_user_after_update(client: SyncGoTrueClient):
Expand Down Expand Up @@ -313,10 +332,10 @@ def test_get_update_user_after_sign_out(client: SyncGoTrueClient):
@pytest.mark.depends(on=[test_get_user_after_sign_out.__name__])
def test_sign_in_with_the_wrong_password(client: SyncGoTrueClient):
try:
client.sign_in(email=email, password=password + "2")
client.sign_in(email=email, password=f"{password}2")
assert False
except APIError:
assert True
pass
except Exception as e:
assert False, str(e)

Expand Down Expand Up @@ -395,8 +414,9 @@ def test_get_session_from_url_errors(client: SyncGoTrueClient):
error_description = fake.email()
try:
client.get_session_from_url(
url=dummy_url + f"?error_description={error_description}"
url=f"{dummy_url}?error_description={error_description}"
)

assert False
except APIError as e:
assert e.code == 400
Expand Down