Skip to content

Commit

Permalink
Merge pull request #293 from supabase-community/or/pydantic-v1-v2-sup…
Browse files Browse the repository at this point in the history
…port

Support for pydantic v1 & v2
  • Loading branch information
J0 committed Aug 23, 2023
2 parents c2ed950 + 79cd743 commit 6765f07
Show file tree
Hide file tree
Showing 14 changed files with 140 additions and 63 deletions.
4 changes: 4 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@ jobs:
run: make run_tests
- name: Upload Coverage
uses: codecov/codecov-action@v3
- name: Run Tests with pydantic v1
run: |
pip install pydantic==1.10.12
make tests_only
publish:
needs: test
Expand Down
4 changes: 2 additions & 2 deletions gotrue/_async/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from typing import Any, Dict, List, Optional, Union

from pydantic import TypeAdapter
from pydantic import parse_obj_as

from ..exceptions import APIError
from ..helpers import check_response, encode_uri_component
Expand Down Expand Up @@ -94,7 +94,7 @@ async def list_users(self) -> List[User]:
raise APIError("No users found in response", 400)
if not isinstance(users, list):
raise APIError("Expected a list of users", 400)
return TypeAdapter(List[User]).validate_python(users)
return parse_obj_as(List[User], users)

async def sign_up_with_email(
self,
Expand Down
5 changes: 3 additions & 2 deletions gotrue/_async/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from ..constants import COOKIE_OPTIONS, DEFAULT_HEADERS, GOTRUE_URL, STORAGE_KEY
from ..exceptions import APIError
from ..helpers import model_dump, model_validate
from ..types import (
AuthChangeEvent,
CookieOptions,
Expand Down Expand Up @@ -560,7 +561,7 @@ async def _recover_common(self) -> Optional[Tuple[Session, int, int]]:
and session_raw
and isinstance(session_raw, dict)
):
session = Session.model_validate(session_raw)
session = model_validate(Session, session_raw)
expires_at = int(expires_at_raw)
time_now = round(time())
return session, expires_at, time_now
Expand Down Expand Up @@ -628,7 +629,7 @@ async def _save_session(self, *, session: Session) -> None:
await self._persist_session(session=session)

async def _persist_session(self, *, session: Session) -> None:
data = {"session": session.model_dump(), "expires_at": session.expires_at}
data = {"session": model_dump(session), "expires_at": session.expires_at}
await self.local_storage.set_item(STORAGE_KEY, dumps(data, default=str))

async def _remove_session(self) -> None:
Expand Down
9 changes: 5 additions & 4 deletions gotrue/_async/gotrue_admin_api.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from __future__ import annotations

from functools import partial
from typing import Dict, List, Union

from ..helpers import parse_link_response, parse_user_response
from ..helpers import model_validate, parse_link_response, parse_user_response
from ..http_clients import AsyncClient
from ..types import (
AdminUserAttributes,
Expand Down Expand Up @@ -109,7 +110,7 @@ async def list_users(self) -> List[User]:
return await self._request(
"GET",
"admin/users",
xform=lambda data: [User.model_validate(user) for user in data["users"]]
xform=lambda data: [model_validate(User, user) for user in data["users"]]
if "users" in data
else [],
)
Expand Down Expand Up @@ -161,7 +162,7 @@ async def _list_factors(
return await self._request(
"GET",
f"admin/users/{params.get('user_id')}/factors",
xform=AuthMFAAdminListFactorsResponse.model_validate,
xform=partial(model_validate, AuthMFAAdminListFactorsResponse),
)

async def _delete_factor(
Expand All @@ -171,5 +172,5 @@ async def _delete_factor(
return await self._request(
"DELETE",
f"admin/users/{params.get('user_id')}/factors/{params.get('factor_id')}",
xform=AuthMFAAdminDeleteFactorResponse.model_validate,
xform=partial(model_validate, AuthMFAAdminDeleteFactorResponse),
)
4 changes: 2 additions & 2 deletions gotrue/_async/gotrue_base_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from pydantic import BaseModel
from typing_extensions import Literal, Self

from ..helpers import handle_exception
from ..helpers import handle_exception, model_dump
from ..http_clients import AsyncClient

T = TypeVar("T")
Expand Down Expand Up @@ -108,7 +108,7 @@ async def _request(
url,
headers=headers,
params=query,
json=body.model_dump() if isinstance(body, BaseModel) else body,
json=model_dump(body) if isinstance(body, BaseModel) else body,
)
response.raise_for_status()
result = response if no_resolve_json else response.json()
Expand Down
24 changes: 16 additions & 8 deletions gotrue/_async/gotrue_client.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

from functools import partial
from json import loads
from time import time
from typing import Callable, Dict, List, Tuple, Union
Expand All @@ -20,7 +21,14 @@
AuthRetryableError,
AuthSessionMissingError,
)
from ..helpers import decode_jwt_payload, parse_auth_response, parse_user_response
from ..helpers import (
decode_jwt_payload,
model_dump,
model_dump_json,
model_validate,
parse_auth_response,
parse_user_response,
)
from ..http_clients import AsyncClient
from ..timer import Timer
from ..types import (
Expand Down Expand Up @@ -531,7 +539,7 @@ async def _enroll(self, params: MFAEnrollParams) -> AuthMFAEnrollResponse:
"factors",
body=params,
jwt=session.access_token,
xform=AuthMFAEnrollResponse.model_validate,
xform=partial(model_validate, AuthMFAEnrollResponse),
)
if response.totp.qr_code:
response.totp.qr_code = f"data:image/svg+xml;utf-8,{response.totp.qr_code}"
Expand All @@ -545,7 +553,7 @@ async def _challenge(self, params: MFAChallengeParams) -> AuthMFAChallengeRespon
"POST",
f"factors/{params.get('factor_id')}/challenge",
jwt=session.access_token,
xform=AuthMFAChallengeResponse.model_validate,
xform=partial(model_validate, AuthMFAChallengeResponse),
)

async def _challenge_and_verify(
Expand Down Expand Up @@ -574,9 +582,9 @@ async def _verify(self, params: MFAVerifyParams) -> AuthMFAVerifyResponse:
f"factors/{params.get('factor_id')}/verify",
body=params,
jwt=session.access_token,
xform=AuthMFAVerifyResponse.model_validate,
xform=partial(model_validate, AuthMFAVerifyResponse),
)
session = Session.model_validate(response.model_dump())
session = model_validate(Session, model_dump(response))
await self._save_session(session)
self._notify_all_subscribers("MFA_CHALLENGE_VERIFIED", session)
return response
Expand All @@ -589,7 +597,7 @@ async def _unenroll(self, params: MFAUnenrollParams) -> AuthMFAUnenrollResponse:
"DELETE",
f"factors/{params.get('factor_id')}",
jwt=session.access_token,
xform=AuthMFAUnenrollResponse.model_validate,
xform=partial(AuthMFAUnenrollResponse, model_validate),
)

async def _list_factors(self) -> AuthMFAListFactorsResponse:
Expand Down Expand Up @@ -751,7 +759,7 @@ async def _save_session(self, session: Session) -> None:
value = (expire_in - refresh_duration_before_expires) * 1000
await self._start_auto_refresh_token(value)
if self._persist_session and session.expires_at:
await self._storage.set_item(self._storage_key, session.model_dump_json())
await self._storage.set_item(self._storage_key, model_dump_json(session))

async def _start_auto_refresh_token(self, value: float) -> None:
if self._refresh_token_timer:
Expand Down Expand Up @@ -808,7 +816,7 @@ def _get_valid_session(
except ValueError:
return None
try:
return Session.model_validate(data)
return model_validate(Session, data)
except Exception:
return None

Expand Down
4 changes: 2 additions & 2 deletions gotrue/_sync/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from typing import Any, Dict, List, Optional, Union

from pydantic import TypeAdapter
from pydantic import parse_obj_as

from ..exceptions import APIError
from ..helpers import check_response, encode_uri_component
Expand Down Expand Up @@ -94,7 +94,7 @@ def list_users(self) -> List[User]:
raise APIError("No users found in response", 400)
if not isinstance(users, list):
raise APIError("Expected a list of users", 400)
return TypeAdapter(List[User]).validate_python(users)
return parse_obj_as(List[User], users)

def sign_up_with_email(
self,
Expand Down
5 changes: 3 additions & 2 deletions gotrue/_sync/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from ..constants import COOKIE_OPTIONS, DEFAULT_HEADERS, GOTRUE_URL, STORAGE_KEY
from ..exceptions import APIError
from ..helpers import model_dump, model_validate
from ..types import (
AuthChangeEvent,
CookieOptions,
Expand Down Expand Up @@ -556,7 +557,7 @@ def _recover_common(self) -> Optional[Tuple[Session, int, int]]:
and session_raw
and isinstance(session_raw, dict)
):
session = Session.model_validate(session_raw)
session = model_validate(Session, session_raw)
expires_at = int(expires_at_raw)
time_now = round(time())
return session, expires_at, time_now
Expand Down Expand Up @@ -620,7 +621,7 @@ def _save_session(self, *, session: Session) -> None:
self._persist_session(session=session)

def _persist_session(self, *, session: Session) -> None:
data = {"session": session.model_dump(), "expires_at": session.expires_at}
data = {"session": model_dump(session), "expires_at": session.expires_at}
self.local_storage.set_item(STORAGE_KEY, dumps(data, default=str))

def _remove_session(self) -> None:
Expand Down
9 changes: 5 additions & 4 deletions gotrue/_sync/gotrue_admin_api.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from __future__ import annotations

from functools import partial
from typing import Dict, List, Union

from ..helpers import parse_link_response, parse_user_response
from ..helpers import model_validate, parse_link_response, parse_user_response
from ..http_clients import SyncClient
from ..types import (
AdminUserAttributes,
Expand Down Expand Up @@ -109,7 +110,7 @@ def list_users(self) -> List[User]:
return self._request(
"GET",
"admin/users",
xform=lambda data: [User.model_validate(user) for user in data["users"]]
xform=lambda data: [model_validate(User, user) for user in data["users"]]
if "users" in data
else [],
)
Expand Down Expand Up @@ -161,7 +162,7 @@ def _list_factors(
return self._request(
"GET",
f"admin/users/{params.get('user_id')}/factors",
xform=AuthMFAAdminListFactorsResponse.model_validate,
xform=partial(model_validate, AuthMFAAdminListFactorsResponse),
)

def _delete_factor(
Expand All @@ -171,5 +172,5 @@ def _delete_factor(
return self._request(
"DELETE",
f"admin/users/{params.get('user_id')}/factors/{params.get('factor_id')}",
xform=AuthMFAAdminDeleteFactorResponse.model_validate,
xform=partial(model_validate, AuthMFAAdminDeleteFactorResponse),
)
4 changes: 2 additions & 2 deletions gotrue/_sync/gotrue_base_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from pydantic import BaseModel
from typing_extensions import Literal, Self

from ..helpers import handle_exception
from ..helpers import handle_exception, model_dump
from ..http_clients import SyncClient

T = TypeVar("T")
Expand Down Expand Up @@ -108,7 +108,7 @@ def _request(
url,
headers=headers,
params=query,
json=body.model_dump() if isinstance(body, BaseModel) else body,
json=model_dump(body) if isinstance(body, BaseModel) else body,
)
response.raise_for_status()
result = response if no_resolve_json else response.json()
Expand Down
21 changes: 14 additions & 7 deletions gotrue/_sync/gotrue_client.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

from functools import partial
from json import loads
from time import time
from typing import Callable, Dict, List, Tuple, Union
Expand All @@ -20,7 +21,13 @@
AuthRetryableError,
AuthSessionMissingError,
)
from ..helpers import decode_jwt_payload, parse_auth_response, parse_user_response
from ..helpers import (
decode_jwt_payload,
model_dump,
model_validate,
parse_auth_response,
parse_user_response,
)
from ..http_clients import SyncClient
from ..timer import Timer
from ..types import (
Expand Down Expand Up @@ -529,7 +536,7 @@ def _enroll(self, params: MFAEnrollParams) -> AuthMFAEnrollResponse:
"factors",
body=params,
jwt=session.access_token,
xform=AuthMFAEnrollResponse.model_validate,
xform=partial(model_validate, AuthMFAEnrollResponse),
)
if response.totp.qr_code:
response.totp.qr_code = f"data:image/svg+xml;utf-8,{response.totp.qr_code}"
Expand All @@ -543,7 +550,7 @@ def _challenge(self, params: MFAChallengeParams) -> AuthMFAChallengeResponse:
"POST",
f"factors/{params.get('factor_id')}/challenge",
jwt=session.access_token,
xform=AuthMFAChallengeResponse.model_validate,
xform=partial(model_validate, AuthMFAChallengeResponse),
)

def _challenge_and_verify(
Expand Down Expand Up @@ -572,9 +579,9 @@ def _verify(self, params: MFAVerifyParams) -> AuthMFAVerifyResponse:
f"factors/{params.get('factor_id')}/verify",
body=params,
jwt=session.access_token,
xform=AuthMFAVerifyResponse.model_validate,
xform=partial(model_validate, AuthMFAVerifyResponse),
)
session = Session.model_validate(response.model_dump())
session = model_validate(Session, model_dump(response))
self._save_session(session)
self._notify_all_subscribers("MFA_CHALLENGE_VERIFIED", session)
return response
Expand All @@ -587,7 +594,7 @@ def _unenroll(self, params: MFAUnenrollParams) -> AuthMFAUnenrollResponse:
"DELETE",
f"factors/{params.get('factor_id')}",
jwt=session.access_token,
xform=AuthMFAUnenrollResponse.model_validate,
xform=partial(model_validate, AuthMFAUnenrollResponse),
)

def _list_factors(self) -> AuthMFAListFactorsResponse:
Expand Down Expand Up @@ -806,7 +813,7 @@ def _get_valid_session(
except ValueError:
return None
try:
return Session.model_validate(data)
return model_validate(Session, data)
except Exception:
return None

Expand Down

0 comments on commit 6765f07

Please sign in to comment.