Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 0 additions & 10 deletions auth_backend/__main__.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,6 @@
import logging

import uvicorn

from auth_backend.routes import app

if __name__ == '__main__':

logging.basicConfig(
filename=f'logger_{__name__}.log',
level=logging.INFO,
format='%(asctime)s.%(msecs)03d %(levelname)s %(module)s - %(funcName)s: %(message)s',
datefmt='%Y-%m-%d %H:%M:%S',
)

uvicorn.run(app)
204 changes: 125 additions & 79 deletions auth_backend/auth_plugins/email.py

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion auth_backend/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def create(cls, *, session: Session, **kwargs) -> BaseDbModel:
return obj

@classmethod
def get_all(cls, *, with_deleted: bool = False, session: Session) -> Query:
def query(cls, *, with_deleted: bool = False, session: Session) -> Query:
"""Get all objects with soft deletes"""
objs = session.query(cls)
if not with_deleted and hasattr(cls, "is_deleted"):
Expand Down
22 changes: 11 additions & 11 deletions auth_backend/routes/groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@
from fastapi_sqlalchemy import db

from auth_backend.exceptions import ObjectNotFound, AlreadyExists
from auth_backend.models.db import Group as DbGroup
from .models.models import Group, GroupPost, GroupsGet, GroupPatch, GroupGet
from ..base import ResponseModel
from ..utils.security import UnionAuth
from auth_backend.models.db import Group as DbGroup, UserSession
from auth_backend.routes.models.models import Group, GroupPost, GroupsGet, GroupPatch, GroupGet
from auth_backend.base import ResponseModel
from auth_backend.utils.security import UnionAuth

auth = UnionAuth()

Expand All @@ -25,20 +25,20 @@ async def get_group(id: int, info: list[Literal["child"]] = Query(default=[])) -


@groups.post("", response_model=Group)
async def create_group(group_inp: GroupPost, _: dict[str, str] = Depends(auth)) -> Group:
async def create_group(group_inp: GroupPost, _: UserSession = Depends(auth)) -> Group:
if group_inp.parent_id and not db.session.query(DbGroup).get(group_inp.parent_id):
raise ObjectNotFound(Group, group_inp.parent_id)
if DbGroup.get_all(session=db.session).filter(DbGroup.name == group_inp.name).one_or_none():
if DbGroup.query(session=db.session).filter(DbGroup.name == group_inp.name).one_or_none():
raise HTTPException(status_code=409, detail=ResponseModel(status="Error", message="Name already exists").json())
group = DbGroup.create(session=db.session, **group_inp.dict())
db.session.commit()
return Group.from_orm(group)


@groups.patch("/{id}", response_model=Group)
async def patch_group(id: int, group_inp: GroupPatch, _: dict[str, str] = Depends(auth)) -> Group:
async def patch_group(id: int, group_inp: GroupPatch, _: UserSession = Depends(auth)) -> Group:
if (
exists_check := DbGroup.get_all(session=db.session)
exists_check := DbGroup.query(session=db.session)
.filter(DbGroup.name == group_inp.name, DbGroup.id != id)
.one_or_none()
):
Expand All @@ -52,11 +52,11 @@ async def patch_group(id: int, group_inp: GroupPatch, _: dict[str, str] = Depend


@groups.delete("/{id}", response_model=None)
async def delete_group(id: int, _: dict[str, str] = Depends(auth)) -> None:
async def delete_group(id: int, _: UserSession = Depends(auth)) -> None:
group: DbGroup = DbGroup.get(id, session=db.session)
if child := group.child:
for children in child:
children.parent = group.parent
children.parent_id = group.parent_id
db.session.flush()
DbGroup.delete(id, session=db.session)
db.session.commit()
Expand All @@ -65,4 +65,4 @@ async def delete_group(id: int, _: dict[str, str] = Depends(auth)) -> None:

@groups.get("", response_model=GroupsGet)
async def get_groups() -> GroupsGet:
return GroupsGet(items=DbGroup.get_all(session=db.session).all())
return GroupsGet(items=DbGroup.query(session=db.session).all())
1 change: 0 additions & 1 deletion auth_backend/routes/models/models.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from __future__ import annotations
import datetime

from pydantic import Field

Expand Down
6 changes: 3 additions & 3 deletions auth_backend/routes/user_groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
from starlette.exceptions import HTTPException

from auth_backend.models.db import Group, UserGroup
from .models.models import UserGroupGet, GroupUserListGet, UserGroupPost
from ..base import ResponseModel
from ..utils.security import UnionAuth
from auth_backend.routes.models.models import UserGroupGet, GroupUserListGet, UserGroupPost
from auth_backend.base import ResponseModel
from auth_backend.utils.security import UnionAuth

auth = UnionAuth()
user_groups = APIRouter(prefix="/group/{id}/user", tags=["User Groups"])
Expand Down
28 changes: 11 additions & 17 deletions auth_backend/routes/user_session.py
Original file line number Diff line number Diff line change
@@ -1,43 +1,37 @@
from datetime import datetime
from typing import Literal, Union
from typing import Literal

from fastapi import APIRouter, Header, HTTPException, Query
from fastapi import APIRouter, Query, Depends
from fastapi_sqlalchemy import db
from starlette.responses import JSONResponse

from auth_backend.base import ResponseModel
from auth_backend.exceptions import AuthFailed
from auth_backend.exceptions import SessionExpired
from auth_backend.models.db import UserSession, Group
from .models.models import UserGroups, UserIndirectGroups, UserInfo, UserGet
from auth_backend.routes.models.models import UserGroups, UserIndirectGroups, UserInfo, UserGet
from auth_backend.utils.security import UnionAuth

auth = UnionAuth()

logout_router = APIRouter(prefix="", tags=["Logout"])


@logout_router.post("/logout", response_model=str)
async def logout(token: str = Header(min_length=1)) -> JSONResponse:
session = db.session.query(UserSession).filter(UserSession.token == token).one_or_none()
if not session:
raise AuthFailed(error="Session not found")
async def logout(session: UserSession = Depends(auth)) -> JSONResponse:
if session.expired:
raise SessionExpired(session.token)
session.expires = datetime.utcnow()
db.session.commit()
return JSONResponse(status_code=200, content=ResponseModel(status="Success", message="Logout successful").json())


@logout_router.post("/me", response_model_exclude_unset=True, response_model=UserGet)
@logout_router.get("/me", response_model_exclude_unset=True, response_model=UserGet)
async def me(
token: str = Header(min_length=1), info: list[Literal["groups", "indirect_groups", ""]] = Query(default=[])
session: UserSession = Depends(auth), info: list[Literal["groups", "indirect_groups", ""]] = Query(default=[])
) -> dict[str, str | int]:
if not token:
raise HTTPException(status_code=400, detail=ResponseModel(status="Error", message="Header missing").json())
session: UserSession = db.session.query(UserSession).filter(UserSession.token == token).one_or_none()
if not session:
raise HTTPException(status_code=404, detail=ResponseModel(status="Error", message="Session not found").json())
if session.expired:
raise SessionExpired(token)
result = {}
raise SessionExpired(str(session.token))
result: dict[str, str | int] = {}
result = result | UserInfo(id=session.user_id, email=session.user.auth_methods.email.value).dict()
if "groups" in info:
result = result | UserGroups(groups=session.user.groups).dict()
Expand Down
2 changes: 1 addition & 1 deletion auth_backend/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ class Settings(BaseSettings):

EMAIL: str | None
APPLICATION_HOST: str = "localhost"
EMAIL_PASS: str = None
EMAIL_PASS: str | None
SMTP_HOST: str = 'smtp.gmail.com'
SMTP_PORT: int = 587
ENABLED_AUTH_METHODS: list[str] | None
Expand Down
12 changes: 5 additions & 7 deletions auth_backend/utils/security.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,29 +11,27 @@
class UnionAuth(SecurityBase):
model = APIKey.construct(in_=APIKeyIn.header, name="Authorization")
scheme_name = "token"
auth_url: str

def __init__(self, auth_url: str = "", auto_error=True) -> None:
def __init__(self, auto_error=True) -> None:
super().__init__()
self.auto_error = auto_error
self.auth_url = auth_url

def _except(self):
if self.auto_error:
raise HTTPException(status_code=HTTP_403_FORBIDDEN, detail="Not authenticated")
else:
return {}
return None

async def __call__(
self,
request: Request,
) -> dict[str, str | int]:
) -> UserSession:
token = request.headers.get("Authorization")
if not token:
return self._except()
user_session: UserSession = (
UserSession.get_all(session=db.session).filter(UserSession.token == token).one_or_none()
UserSession.query(session=db.session).filter(UserSession.token == token).one_or_none()
)
if not user_session:
self._except()
return {"id": user_session.user_id, "email": user_session.user.auth_methods.email.value}
return user_session
69 changes: 52 additions & 17 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,64 @@
import datetime
from unittest.mock import Mock
from unittest.mock import patch

import pytest
from fastapi.testclient import TestClient
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from starlette import status

import auth_backend.auth_plugins.email
from auth_backend.models import AuthMethod, User
from auth_backend.models.db import Group, UserSession, UserGroup
from auth_backend.routes.base import app
from auth_backend.settings import get_settings
import auth_backend.utils.security


@pytest.fixture(scope="session")
@pytest.fixture
def client():
auth_backend.auth_plugins.email.send_confirmation_email = Mock(return_value=None)
auth_backend.auth_plugins.email.send_change_password_confirmation = Mock(return_value=None)
auth_backend.auth_plugins.email.send_changes_password_notification = Mock(return_value=None)
auth_backend.auth_plugins.email.send_reset_email = Mock(return_value=None)
auth_backend.utils.security.UnionAuth.__call__ = Mock(return_value={"id": 0, "email": ""})
patcher1 = patch("auth_backend.auth_plugins.email.send_confirmation_email")
patcher2 = patch("auth_backend.auth_plugins.email.send_change_password_confirmation")
patcher3 = patch("auth_backend.auth_plugins.email.send_changes_password_notification")
patcher4 = patch("auth_backend.auth_plugins.email.send_reset_email")
patcher5 = patch("auth_backend.utils.security.UnionAuth.__call__")
patcher1.start()
patcher2.start()
patcher3.start()
patcher4.start()
patcher5.start()
patcher1.return_value = None
patcher2.return_value = None
patcher3.return_value = None
patcher4.return_value = None
patcher5.return_value = {"id": 0, "email": None}
client = TestClient(app)
yield client
patcher1.stop()
patcher2.stop()
patcher3.stop()
patcher4.stop()
patcher5.stop()


@pytest.fixture
def client_auth():
patcher1 = patch("auth_backend.auth_plugins.email.send_confirmation_email")
patcher2 = patch("auth_backend.auth_plugins.email.send_change_password_confirmation")
patcher3 = patch("auth_backend.auth_plugins.email.send_changes_password_notification")
patcher4 = patch("auth_backend.auth_plugins.email.send_reset_email")
patcher1.start()
patcher2.start()
patcher3.start()
patcher4.start()
patcher1.return_value = None
patcher2.return_value = None
patcher3.return_value = None
patcher4.return_value = None
client = TestClient(app)
yield client
patcher1.stop()
patcher2.stop()
patcher3.stop()
patcher4.stop()


@pytest.fixture(scope='session')
Expand All @@ -35,10 +70,10 @@ def dbsession():


@pytest.fixture()
def user_id(client: TestClient, dbsession):
def user_id(client_auth: TestClient, dbsession):
time = datetime.datetime.utcnow()
body = {"email": f"user{time}@example.com", "password": "string"}
client.post("/email/registration", json=body)
client_auth.post("/email/registration", json=body)
db_user: AuthMethod = (
dbsession.query(AuthMethod).filter(AuthMethod.value == body['email'], AuthMethod.param == 'email').one()
)
Expand All @@ -54,15 +89,15 @@ def user_id(client: TestClient, dbsession):


@pytest.fixture()
def user(client: TestClient, dbsession):
def user(client_auth: TestClient, dbsession):
url = "/email/login"
time = datetime.datetime.utcnow()
body = {"email": f"user{time}@example.com", "password": "string"}
client.post("/email/registration", json=body)
client_auth.post("/email/registration", json=body)
db_user: AuthMethod = (
dbsession.query(AuthMethod).filter(AuthMethod.value == body['email'], AuthMethod.param == 'email').one()
)
response = client.post(url, json=body)
response = client_auth.post(url, json=body)
assert response.status_code == status.HTTP_401_UNAUTHORIZED
token = (
dbsession.query(AuthMethod)
Expand All @@ -73,9 +108,9 @@ def user(client: TestClient, dbsession):
)
.one()
)
response = client.get(f"/email/approve?token={token.value}")
response = client_auth.get(f"/email/approve?token={token.value}")
assert response.status_code == status.HTTP_200_OK
response = client.post(url, json=body)
response = client_auth.post(url, json=body)
assert response.status_code == status.HTTP_200_OK
yield {"user_id": db_user.user_id, "body": body, "login_json": response.json()}
session = dbsession.query(UserSession).filter(UserSession.user_id == db_user.user_id).all()
Expand All @@ -88,7 +123,7 @@ def user(client: TestClient, dbsession):
dbsession.commit()


@pytest.fixture(scope="module")
@pytest.fixture
def parent_id(client, dbsession):
time = datetime.datetime.utcnow()
body = {"name": f"group{time}", "parent_id": None}
Expand Down
Loading