-
-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy path__init__.py
122 lines (100 loc) · 4.73 KB
/
__init__.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
"""FastAPI Users database adapter for ormar."""
from typing import Any, List, Optional, Type, cast
import ormar
from fastapi_users.db.base import BaseUserDatabase
from fastapi_users.models import UD, BaseOAuthAccount
from ormar.exceptions import NoMatch
from pydantic import UUID4
# Version string of this adapter package.
__version__ = "1.0.0"
class OrmarBaseUserModel(ormar.Model):
    # Abstract ormar base model declaring the core user columns expected by
    # OrmarUserDatabase. It is meant to be subclassed by a concrete user model
    # (presumably one that also supplies Meta.database / Meta.metadata — confirm
    # against the application's model definitions).
    class Meta:
        tablename = "users"
        # Marked abstract so it serves only as a base for concrete models.
        abstract = True

    # UUID primary key; uuid_format="string" selects ormar's string format for it.
    id = ormar.UUID(primary_key=True, uuid_format="string")
    # Unique, indexed login identifier. Note: OrmarUserDatabase.get_by_email looks
    # it up case-insensitively (email__iexact), while the unique constraint itself
    # is enforced by the database with whatever collation it uses.
    email = ormar.String(index=True, unique=True, nullable=False, max_length=255)
    # Password hash supplied by the caller; nothing in this module hashes it.
    hashed_password = ormar.String(nullable=False, max_length=255)
    # Account status flags with their defaults for newly created users.
    is_active = ormar.Boolean(default=True, nullable=False)
    is_superuser = ormar.Boolean(default=False, nullable=False)
    is_verified = ormar.Boolean(default=False, nullable=False)
class OrmarBaseOAuthAccountModel(ormar.Model):
    # Abstract ormar base model declaring the OAuth-account columns.
    # OrmarUserDatabase builds instances with ``user=<user model>`` and queries the
    # user's ``oauth_accounts`` relation, so a concrete subclass is expected to add
    # a ``user`` ForeignKey whose related name is ``oauth_accounts``.
    class Meta:
        tablename = "oauth_accounts"
        # Marked abstract so it serves only as a base for concrete models.
        abstract = True

    # UUID primary key; uuid_format="string" selects ormar's string format for it.
    id = ormar.UUID(primary_key=True, uuid_format="string")
    # Name of the OAuth provider/client this account belongs to.
    oauth_name = ormar.String(nullable=False, max_length=255)
    # Access and refresh tokens are opaque, provider-issued strings that
    # routinely exceed 255 characters (e.g. JWT-based tokens). A 255-wide
    # VARCHAR makes strict databases reject them, or makes non-strict ones
    # silently truncate them. 1024 matches the width used by the fastapi-users
    # SQLAlchemy adapter. Existing databases need a migration to widen these
    # two columns.
    access_token = ormar.String(nullable=False, max_length=1024)
    # Optional token expiry as an integer (presumably a Unix timestamp — confirm
    # against the OAuth client that fills it).
    expires_at = ormar.Integer(nullable=True)
    refresh_token = ormar.String(nullable=True, max_length=1024)
    # Provider-side account identifier; indexed because
    # OrmarUserDatabase.get_by_oauth_account filters on it.
    account_id = ormar.String(index=True, nullable=False, max_length=255)
    account_email = ormar.String(nullable=False, max_length=255)
class OrmarUserDatabase(BaseUserDatabase[UD]):
    """
    Database adapter for ormar.

    :param user_db_model: Pydantic model of a DB representation of a user.
    :param model: ormar ORM model.
    :param oauth_account_model: Optional ormar ORM model of a OAuth account.
    :param select_related: Optional list of relationship names to retrieve with User queries.
    """

    model: Type[OrmarBaseUserModel]
    oauth_account_model: Optional[Type[OrmarBaseOAuthAccountModel]]
    select_related: Optional[List[str]]

    def __init__(
        self,
        user_db_model: Type[UD],
        model: Type[OrmarBaseUserModel],
        oauth_account_model: Optional[Type[OrmarBaseOAuthAccountModel]] = None,
        select_related: Optional[List[str]] = None,
    ):
        super().__init__(user_db_model)
        self.model = model
        self.oauth_account_model = oauth_account_model
        self.select_related = select_related

    async def get(self, id: UUID4) -> Optional[UD]:
        """Return the user whose primary key is ``id``, or ``None`` if absent."""
        return await self._get_user(id=id)

    async def get_by_email(self, email: str) -> Optional[UD]:
        """Return the user whose email matches ``email`` case-insensitively, or ``None``."""
        return await self._get_user(email__iexact=email)

    async def get_by_oauth_account(self, oauth: str, account_id: str) -> Optional[UD]:
        """
        Return the user owning the OAuth account ``(oauth, account_id)``, or ``None``.

        :param oauth: OAuth provider/client name (matched against ``oauth_name``).
        :param account_id: Provider-side account identifier.
        """
        try:
            db_user = await self._get_db_user(
                oauth_accounts__oauth_name=oauth, oauth_accounts__account_id=account_id
            )
        except NoMatch:
            return None
        # The filter above runs against the same JOIN that hydrates
        # ``oauth_accounts``. So it may also restrict which related rows are
        # loaded, leaving only the matching account on the result. Callers
        # typically modify that list and pass the user to ``update``, which
        # clears and recreates every account. Reload by id so that *all* of the
        # user's accounts are returned.
        return await self._get_user(id=db_user.id)

    async def create(self, user: UD) -> UD:
        """
        Insert ``user`` and, when an OAuth model is configured, its OAuth accounts.

        :return: The freshly stored user, re-read from the database.
        """
        oauth_accounts = getattr(user, "oauth_accounts", [])
        model = await self.model(**user.dict(exclude={"oauth_accounts"})).save()
        await model.save_related()
        if oauth_accounts and self.oauth_account_model:
            await self._create_oauth_models(model=model, oauth_accounts=oauth_accounts)
        user_db = await self._get_user(id=user.id)
        return cast(UD, user_db)

    async def update(self, user: UD) -> UD:
        """
        Persist the fields of ``user`` and synchronise its OAuth accounts.

        When an OAuth model is configured and ``user`` carries an
        ``oauth_accounts`` attribute, the stored accounts are replaced by exactly
        that list (an empty list removes them all).

        :raises NoMatch: if no user with ``user.id`` exists.
        :return: The updated user, re-read from the database.
        """
        # None (rather than []) distinguishes "user model has no OAuth accounts"
        # from "user explicitly has zero OAuth accounts".
        oauth_accounts = getattr(user, "oauth_accounts", None)
        model = await self._get_db_user(id=user.id)
        await model.update(**user.dict(exclude={"oauth_accounts"}))
        if oauth_accounts is not None and self.oauth_account_model:
            # Always clear, so that unlinking the last account is persisted too.
            await model.oauth_accounts.clear(keep_reversed=False)
            if oauth_accounts:
                await self._create_oauth_models(
                    model=model, oauth_accounts=oauth_accounts
                )
        user_db = await self._get_user(id=user.id)
        return cast(UD, user_db)

    async def delete(self, user: UD) -> None:
        """Delete the user row whose primary key is ``user.id``."""
        await self.model.objects.delete(id=user.id)

    async def _create_oauth_models(
        self, model: OrmarBaseUserModel, oauth_accounts: List[BaseOAuthAccount]
    ):
        """Bulk-insert ``oauth_accounts`` as OAuth rows linked to ``model``."""
        if self.oauth_account_model:
            oauth_accounts_db = [
                self.oauth_account_model(user=model, **oacc.dict())
                for oacc in oauth_accounts
            ]
            await self.oauth_account_model.objects.bulk_create(oauth_accounts_db)

    async def _get_db_user(self, **kwargs: Any) -> OrmarBaseUserModel:
        """
        Fetch exactly one ormar user matching the ``kwargs`` filters.

        The ``oauth_accounts`` relation is loaded when an OAuth model is
        configured, as is every relation named in ``select_related``.

        :raises NoMatch: if no user matches.
        """
        query = self.model.objects.filter(**kwargs)
        if self.oauth_account_model is not None:
            query = query.select_related("oauth_accounts")
        if self.select_related is not None:
            for relation in self.select_related:
                query = query.select_related(relation)
        return cast(OrmarBaseUserModel, await query.get())

    async def _get_user(self, **kwargs: Any) -> Optional[UD]:
        """Like ``_get_db_user`` but return a ``user_db_model`` instance, or ``None``."""
        try:
            user = await self._get_db_user(**kwargs)
        except NoMatch:
            return None
        return self.user_db_model(**user.dict())