Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions auth_backend/auth_plugins/auth_method.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,10 @@ def __init_subclass__(cls, **kwargs):

@staticmethod
@abstractmethod
async def _register(**kwargs) -> object:
async def _register(*args, **kwargs) -> object:
raise NotImplementedError()

@staticmethod
@abstractmethod
async def _login(**kwargs) -> Session:
async def _login(*args, **kwargs) -> Session:
raise NotImplementedError()
2 changes: 0 additions & 2 deletions auth_backend/auth_plugins/email.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,15 +156,13 @@ async def _add_to_db(user_inp: EmailRegister, confirmation_token: str, user: Use
)
db.session.flush()


@staticmethod
async def _change_confirmation_link(user: User, confirmation_token: str) -> None:
if user.auth_methods.confirmed.value == "true":
raise AlreadyExists(User, user.id)
else:
user.auth_methods.confirmation_token.value = confirmation_token


@staticmethod
async def _get_user_by_token_and_id(id: int, token: str) -> User:
user: User = db.session.query(User).get(id)
Expand Down
61 changes: 58 additions & 3 deletions auth_backend/models/base.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,76 @@
from __future__ import annotations
import re

from sqlalchemy import not_, Integer
from sqlalchemy.exc import NoResultFound
from sqlalchemy.ext.declarative import as_declarative, declared_attr
from sqlalchemy.orm import Session, Mapped, mapped_column, Query

from auth_backend.exceptions import ObjectNotFound


@as_declarative()
class Base:
"""Base class for all database entities"""

@classmethod
@declared_attr
def __tablename__(cls) -> str: # pylint: disable=no-self-argument
"""Generate database table name automatically.
Convert CamelCase class name to snake_case db table name.
"""
return re.sub(r"(?<!^)(?=[A-Z])", "_", cls.__name__).lower()

def __repr__(self) -> str:
def __repr__(self):
attrs = []
for c in self.__table__.columns:
attrs.append(f"{c.name}={getattr(self, c.name)}")
return "{}({})".format(self.__class__.__name__, ', '.join(attrs))
return "{}({})".format(c.__class__.__name__, ', '.join(attrs))


class BaseDbModel(Base):
__abstract__ = True
id: Mapped[int] = mapped_column(Integer, primary_key=True)

@classmethod
def create(cls, *, session: Session, **kwargs) -> BaseDbModel:
obj = cls(**kwargs)
session.add(obj)
session.flush()
return obj

@classmethod
def get_all(cls, *, with_deleted: bool = False, session: Session) -> Query:
"""Get all objects with soft deletes"""
objs = session.query(cls)
if not with_deleted and hasattr(cls, "is_deleted"):
objs = objs.filter(not_(cls.is_deleted))
return objs

@classmethod
def get(cls, id: int, *, with_deleted=False, session: Session) -> BaseDbModel:
"""Get object with soft deletes"""
objs = session.query(cls)
if not with_deleted and hasattr(cls, "is_deleted"):
objs = objs.filter(not_(cls.is_deleted))
try:
return objs.filter(cls.id == id).one()
except NoResultFound:
raise ObjectNotFound(cls, id)

@classmethod
def update(cls, id: int, *, session: Session, **kwargs) -> BaseDbModel:
obj = cls.get(id, session=session)
for k, v in kwargs.items():
setattr(obj, k, v)
session.flush()
return obj

@classmethod
def delete(cls, id: int, *, session: Session) -> None:
"""Soft delete object if possible, else hard delete"""
obj = cls.get(id, session=session)
if hasattr(obj, "is_deleted"):
obj.is_deleted = True
else:
session.delete(obj)
session.flush()
92 changes: 76 additions & 16 deletions auth_backend/models/db.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
from __future__ import annotations

import datetime
from typing import Iterator

import sqlalchemy.orm
from sqlalchemy.orm import Mapped, mapped_column, relationship
from sqlalchemy import String, Integer, ForeignKey, DateTime
from sqlalchemy import String, Integer, ForeignKey, DateTime, Boolean
from sqlalchemy.ext.hybrid import hybrid_property
from sqlalchemy.orm import Mapped, mapped_column, relationship, backref

from auth_backend.models.base import Base
from auth_backend.models.base import BaseDbModel


class ParamDict:

# Type hints
email: AuthMethod
hashed_password: AuthMethod
Expand All @@ -32,12 +32,28 @@ def __new__(cls, methods: list[AuthMethod], *args, **kwargs):
return obj


class User(Base):

id: Mapped[int] = mapped_column(Integer, primary_key=True)
class User(BaseDbModel):
is_deleted: Mapped[bool] = mapped_column(Boolean, default=False)

_auth_methods: Mapped[list[AuthMethod]] = relationship(
"AuthMethod",
foreign_keys="AuthMethod.user_id",
primaryjoin="and_(User.id==AuthMethod.user_id, not_(AuthMethod.is_deleted))",
)
sessions: Mapped[list[UserSession]] = relationship(
"UserSession", foreign_keys="UserSession.user_id", back_populates="user"
)
groups: Mapped[list[Group]] = relationship(
"Group",
secondary="user_group",
back_populates="users",
primaryjoin="and_(User.id==UserGroup.user_id, not_(UserGroup.is_deleted))",
secondaryjoin="and_(Group.id==UserGroup.group_id, not_(Group.is_deleted))",
)

_auth_methods: Mapped[list["AuthMethod"]] = relationship("AuthMethod", foreign_keys="AuthMethod.user_id")
sessions: Mapped[list["UserSession"]] = relationship("UserSession", foreign_keys="UserSession.user_id")
@hybrid_property
def active_sessions(self) -> list:
return [row for row in self.sessions if not row.expired]

@hybrid_property
def auth_methods(self) -> ParamDict:
Expand All @@ -49,24 +65,68 @@ def auth_methods(self) -> ParamDict:
return ParamDict.__new__(ParamDict, self._auth_methods)


class AuthMethod(Base):
class Group(BaseDbModel):
id: Mapped[int] = mapped_column(Integer, primary_key=True)
name: Mapped[str] = mapped_column(String, unique=True, nullable=False)
parent_id: Mapped[int] = mapped_column(Integer, ForeignKey("group.id"), nullable=True)
create_ts: Mapped[datetime.datetime] = mapped_column(DateTime, default=datetime.datetime.utcnow)
is_deleted: Mapped[bool] = mapped_column(Boolean, default=False)

child: Mapped[list[Group]] = relationship(
"Group",
backref=backref("parent", remote_side=[id]),
primaryjoin="and_(Group.id==Group.parent_id, not_(Group.is_deleted))",
)
users: Mapped[list[User]] = relationship(
"User",
secondary="user_group",
back_populates="groups",
primaryjoin="and_(Group.id==UserGroup.group_id, not_(UserGroup.is_deleted))",
secondaryjoin="and_(User.id==UserGroup.user_id, not_(User.is_deleted))",
)

@hybrid_property
def parents(self) -> Iterator[Group]:
parent = self
while parent := parent.parent:
yield parent


class UserGroup(BaseDbModel):
user_id: Mapped[int] = mapped_column(Integer, ForeignKey("user.id"))
group_id: Mapped[int] = mapped_column(Integer, ForeignKey("group.id"))
is_deleted: Mapped[bool] = mapped_column(Boolean, default=False)


class AuthMethod(BaseDbModel):
user_id: Mapped[int] = mapped_column(Integer, ForeignKey("user.id"))
auth_method: Mapped[str] = mapped_column(String)
param: Mapped[str] = mapped_column(String)
value: Mapped[str] = mapped_column(String)
is_deleted: Mapped[bool] = mapped_column(Boolean, default=False)

user: Mapped["User"] = relationship("User", foreign_keys=[user_id], back_populates="_auth_methods")
user: Mapped[User] = relationship(
"User",
foreign_keys=[user_id],
back_populates="_auth_methods",
primaryjoin="and_(AuthMethod.user_id==User.id, not_(User.is_deleted))",
)


class UserSession(Base):
id: Mapped[int] = mapped_column(Integer, primary_key=True)
class UserSession(BaseDbModel):
user_id: Mapped[int] = mapped_column(Integer, sqlalchemy.ForeignKey("user.id"))
expires: Mapped[datetime.datetime] = mapped_column(DateTime, default=datetime.datetime.utcnow() + datetime.timedelta(days=7))
expires: Mapped[datetime.datetime] = mapped_column(
DateTime, default=datetime.datetime.utcnow() + datetime.timedelta(days=7)
)
token: Mapped[str] = mapped_column(String, unique=True)

user: Mapped["User"] = relationship("User", foreign_keys=[user_id], back_populates="sessions")
user: Mapped[User] = relationship(
"User",
foreign_keys=[user_id],
back_populates="sessions",
primaryjoin="and_(UserSession.user_id==User.id, not_(User.is_deleted))",
)

@hybrid_property
def expired(self):
def expired(self) -> bool:
return self.expires <= datetime.datetime.utcnow()
8 changes: 5 additions & 3 deletions auth_backend/routes/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,15 @@
from auth_backend.auth_plugins.auth_method import AUTH_METHODS
from auth_backend.settings import get_settings
from .user_session import logout_router
from .user_groups import user_groups
from .groups import groups

settings = get_settings()

app = FastAPI()


app.add_middleware(
DBSessionMiddleware, db_url=settings.DB_DSN, engine_args={"pool_pre_ping": True}
)
app.add_middleware(DBSessionMiddleware, db_url=settings.DB_DSN, engine_args={"pool_pre_ping": True})

app.add_middleware(
CORSMiddleware,
Expand All @@ -24,6 +24,8 @@
)

app.include_router(logout_router)
app.include_router(user_groups)
app.include_router(groups)
if not settings.ENABLED_AUTH_METHODS:
for method in AUTH_METHODS.values():
app.include_router(router := method().router, prefix=router.prefix, tags=[method.get_name()])
Expand Down
68 changes: 68 additions & 0 deletions auth_backend/routes/groups.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
from typing import Literal

from fastapi import APIRouter, HTTPException, Depends, Query
from fastapi_sqlalchemy import db

from auth_backend.exceptions import ObjectNotFound, AlreadyExists
from auth_backend.models.db import Group as DbGroup
from .models.models import Group, GroupPost, GroupsGet, GroupPatch, GroupGet
from ..base import ResponseModel
from ..utils.security import UnionAuth

auth = UnionAuth()

groups = APIRouter(prefix="/group", tags=["Groups"])


@groups.get("/{id}", response_model=GroupGet, response_model_exclude_unset=True)
async def get_group(id: int, info: list[Literal["child"]] = Query(default=[])) -> dict[str, str | int]:
group = DbGroup.get(id, session=db.session)
result = {}
result = result | Group.from_orm(group).dict()
if "child" in info:
result = result | {"child": group.child}
return GroupGet(**result).dict(exclude_unset=True)


@groups.post("", response_model=Group)
async def create_group(group_inp: GroupPost, _: dict[str, str] = Depends(auth)) -> Group:
if group_inp.parent_id and not db.session.query(DbGroup).get(group_inp.parent_id):
raise ObjectNotFound(Group, group_inp.parent_id)
if DbGroup.get_all(session=db.session).filter(DbGroup.name == group_inp.name).one_or_none():
raise HTTPException(status_code=409, detail=ResponseModel(status="Error", message="Name already exists").json())
group = DbGroup.create(session=db.session, **group_inp.dict())
db.session.commit()
return Group.from_orm(group)


@groups.patch("/{id}", response_model=Group)
async def patch_group(id: int, group_inp: GroupPatch, _: dict[str, str] = Depends(auth)) -> Group:
if (
exists_check := DbGroup.get_all(session=db.session)
.filter(DbGroup.name == group_inp.name, DbGroup.id != id)
.one_or_none()
):
raise AlreadyExists(Group, exists_check.id)
group = DbGroup.get(id, session=db.session)
if group_inp.parent_id in (row.id for row in group.child):
raise HTTPException(status_code=400, detail=ResponseModel(status="Error", message="Cycle detected").json())
patched = DbGroup.update(id, session=db.session, **group_inp.dict(exclude_unset=True))
db.session.commit()
return Group.from_orm(patched)


@groups.delete("/{id}", response_model=None)
async def delete_group(id: int, _: dict[str, str] = Depends(auth)) -> None:
group: DbGroup = DbGroup.get(id, session=db.session)
if child := group.child:
for children in child:
children.parent = group.parent
db.session.flush()
DbGroup.delete(id, session=db.session)
db.session.commit()
return None


@groups.get("", response_model=GroupsGet)
async def get_groups() -> GroupsGet:
return GroupsGet(items=DbGroup.get_all(session=db.session).all())
Loading