diff --git a/.github/labeler.yml b/.github/labeler.yml index ae83bb87..45caf4e3 100644 --- a/.github/labeler.yml +++ b/.github/labeler.yml @@ -13,44 +13,50 @@ 'module: database': - src/app/db/* +'module: schemas': +- src/app/schemas/* + +'module: models': +- src/app/models/* + 'module: services': - src/app/services/* 'route: accesses': -- src/app/api/routes/accesses.py +- src/app/api/endpoints/accesses.py - src/app/api/crud/accesses.py 'route: alerts': -- src/app/api/routes/alerts.py +- src/app/api/endpoints/alerts.py - src/app/api/crud/alerts.py 'route: devices': -- src/app/api/routes/devices.py +- src/app/api/endpoints/devices.py 'route: events': -- src/app/api/routes/events.py +- src/app/api/endpoints/events.py 'route: groups': -- src/app/api/routes/groups.py +- src/app/api/endpoints/groups.py - src/app/api/crud/groups.py 'route: installations': -- src/app/api/routes/installations.py +- src/app/api/endpoints/installations.py 'route: login': -- src/app/api/routes/login.py +- src/app/api/endpoints/login.py 'route: media': -- src/app/api/routes/media.py +- src/app/api/endpoints/media.py 'route: sites': -- src/app/api/routes/sites.py +- src/app/api/endpoints/sites.py 'route: users': -- src/app/api/routes/users.py +- src/app/api/endpoints/users.py 'route: webhooks': -- src/app/api/routes/webhooks.py +- src/app/api/endpoints/webhooks.py - src/app/api/crud/webhooks.py 'topic: build': diff --git a/README.md b/README.md index bc869f77..c56687a2 100644 --- a/README.md +++ b/README.md @@ -2,10 +2,10 @@

- CI Status + CI Status - Documentation Status + Documentation Status Test coverage percentage diff --git a/docker-compose-dev.yml b/docker-compose-dev.yml index 66e1dade..ffe9851e 100644 --- a/docker-compose-dev.yml +++ b/docker-compose-dev.yml @@ -27,13 +27,6 @@ services: - S3_REGION=${S3_REGION} depends_on: - db - proxy: - build: nginx - ports: - - 80:80 - - 443:443 - depends_on: - - backend db: image: postgres:12.1-alpine volumes: diff --git a/docker-compose.yml b/docker-compose.yml index 03126c73..e620c2c8 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -36,13 +36,6 @@ services: - POSTGRES_USER=dummy_pg_user - POSTGRES_PASSWORD=dummy_pg_pwd - POSTGRES_DB=dummy_pg_db - proxy: - build: nginx - ports: - - 80:80 - - 443:443 - depends_on: - - backend volumes: postgres_data: diff --git a/nginx/Dockerfile b/nginx/Dockerfile deleted file mode 100644 index 17b5f688..00000000 --- a/nginx/Dockerfile +++ /dev/null @@ -1,3 +0,0 @@ -FROM nginx:latest - -COPY nginx.conf /etc/nginx/nginx.conf diff --git a/nginx/nginx.conf b/nginx/nginx.conf deleted file mode 100644 index 04ace48a..00000000 --- a/nginx/nginx.conf +++ /dev/null @@ -1,55 +0,0 @@ - -worker_processes 1; - -events { - worker_connections 1024; # increase if you have lots of clients - accept_mutex off; # set to 'on' if nginx worker_processes > 1 -} - -http { - include mime.types; - # fallback in case we can't determine a type - default_type application/octet-stream; - access_log /var/log/nginx/access.log combined; - sendfile on; - - upstream app_server { - # fail_timeout=0 means we always retry an upstream even if it failed - # to return a good HTTP response - - # for a TCP configuration - server backend:8080 fail_timeout=0; - } - - server { - # if no Host match, close the connection to prevent host spoofing - listen 80 default_server; - return 444; - } - - server { - # use 'listen 80 deferred;' for Linux - # use 'listen 80 accept_filter=httpready;' for FreeBSD - client_max_body_size 4G; - - # set the correct host(s) for your site - server_name api.pyronear.org; - - keepalive_timeout 5; - - location / { - # checks for static file, if not found proxy to app - try_files $uri @proxy_to_app; - } - - location @proxy_to_app { - proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; - proxy_set_header X-Forwarded-Proto $scheme; - proxy_set_header Host $http_host; - # we don't want nginx trying to do something clever with - # redirects, we set the Host: header above already. - proxy_redirect off; - proxy_pass http://app_server; - } - } -} diff --git a/src/app/api/crud/accesses.py b/src/app/api/crud/accesses.py index 67a690ef..275b7379 100644 --- a/src/app/api/crud/accesses.py +++ b/src/app/api/crud/accesses.py @@ -10,7 +10,7 @@ from app.api import security from app.api.crud import base -from app.api.schemas import ( +from app.schemas import ( AccessCreation, AccessRead, Cred, diff --git a/src/app/api/crud/alerts.py b/src/app/api/crud/alerts.py index f6e59caa..4f86186d 100644 --- a/src/app/api/crud/alerts.py +++ b/src/app/api/crud/alerts.py @@ -10,9 +10,9 @@ import app.config as cfg from app.api import crud -from app.api.routes.events import create_event -from app.api.schemas import AlertIn, AlertOut, EventIn +from app.api.endpoints.events import create_event from app.db import alerts +from app.schemas import AlertIn, AlertOut, EventIn async def resolve_previous_alert(device_id: int) -> Optional[AlertOut]: diff --git a/src/app/api/crud/authorizations.py b/src/app/api/crud/authorizations.py index 2fe875fe..d5c10b4c 100644 --- a/src/app/api/crud/authorizations.py +++ b/src/app/api/crud/authorizations.py @@ -8,7 +8,7 @@ from app.api import crud from app.db import accesses -from app.db.models import AccessType +from app.models import AccessType async def is_in_same_group(table: Table, entry_id: int, group_id: int) -> bool: diff --git a/src/app/api/deps.py b/src/app/api/deps.py index 381050f8..443a3d80 100644 --- a/src/app/api/deps.py +++ b/src/app/api/deps.py @@ -10,9 +10,10 @@ import app.config as cfg from app.api import crud -from app.api.schemas import AccessRead, DeviceOut, TokenPayload, UserRead from app.db import accesses, devices, users -from app.db.models import AccessType +from app.db.session import SessionLocal +from app.models import AccessType +from app.schemas import AccessRead, DeviceOut, TokenPayload, UserRead # Scope definition oauth2_scheme = OAuth2PasswordBearer( @@ -25,6 +26,14 @@ ) +def get_db(): + db = SessionLocal() # noqa: F405 + try: + yield db + finally: + db.close() + + async def get_current_access(security_scopes: SecurityScopes, token: str = Depends(oauth2_scheme)) -> AccessRead: """Dependency to use as fastapi.security.Security with scopes. diff --git a/src/app/api/routes/accesses.py b/src/app/api/endpoints/accesses.py similarity index 92% rename from src/app/api/routes/accesses.py rename to src/app/api/endpoints/accesses.py index d26d77bf..0d861b91 100644 --- a/src/app/api/routes/accesses.py +++ b/src/app/api/endpoints/accesses.py @@ -9,9 +9,9 @@ from app.api import crud from app.api.deps import get_current_access -from app.api.schemas.accesses import AccessRead from app.db import accesses -from app.db.models import AccessType +from app.models import AccessType +from app.schemas.accesses import AccessRead router = APIRouter() diff --git a/src/app/api/routes/alerts.py b/src/app/api/endpoints/alerts.py similarity index 86% rename from src/app/api/routes/alerts.py rename to src/app/api/endpoints/alerts.py index 71714125..605e0cf6 100644 --- a/src/app/api/routes/alerts.py +++ b/src/app/api/endpoints/alerts.py @@ -12,11 +12,11 @@ from app.api import crud from app.api.crud.authorizations import check_group_read, is_admin_access from app.api.crud.groups import get_entity_group_id -from app.api.deps import get_current_access, get_current_device +from app.api.deps import get_current_access, get_current_device, get_db from app.api.external import post_request -from app.api.schemas import AlertBase, AlertIn, AlertOut, DeviceOut -from app.db import alerts, events, get_session, media, models -from app.db.models import AccessType +from app.db import alerts, events, media +from app.models import Access, AccessType, Alert, Device, Event +from app.schemas import AlertBase, AlertIn, AlertOut, DeviceOut router = APIRouter() @@ -95,7 +95,7 @@ async def get_alert( @router.get("/", response_model=List[AlertOut], summary="Get the list of all alerts") async def fetch_alerts( - requester=Security(get_current_access, scopes=[AccessType.admin, AccessType.user]), session=Depends(get_session) + requester=Security(get_current_access, scopes=[AccessType.admin, AccessType.user]), session=Depends(get_db) ): """ Retrieves the list of all alerts and their information @@ -104,11 +104,7 @@ async def fetch_alerts( return await crud.fetch_all(alerts) else: retrieved_alerts = ( - session.query(models.Alerts) - .join(models.Devices) - .join(models.Accesses) - .filter(models.Accesses.group_id == requester.group_id) - .all() + session.query(Alert).join(Device).join(Access).filter(Access.group_id == requester.group_id).all() ) retrieved_alerts = [x.__dict__ for x in retrieved_alerts] return retrieved_alerts @@ -124,7 +120,7 @@ async def delete_alert(alert_id: int = Path(..., gt=0), _=Security(get_current_a @router.get("/ongoing", response_model=List[AlertOut], summary="Get the list of ongoing alerts") async def fetch_ongoing_alerts( - requester=Security(get_current_access, scopes=[AccessType.admin, AccessType.user]), session=Depends(get_session) + requester=Security(get_current_access, scopes=[AccessType.admin, AccessType.user]), session=Depends(get_db) ): """ Retrieves the list of ongoing alerts and their information @@ -138,12 +134,12 @@ async def fetch_ongoing_alerts( return (await crud.base.database.fetch_all(query=query.limit(50)))[::-1] else: retrieved_alerts = ( - session.query(models.Alerts) - .join(models.Events) - .filter(models.Events.end_ts.is_(None)) - .join(models.Devices) - .join(models.Accesses) - .filter(models.Accesses.group_id == requester.group_id) + session.query(Alert) + .join(Event) + .filter(Event.end_ts.is_(None)) + .join(Device) + .join(Access) + .filter(Access.group_id == requester.group_id) ) retrieved_alerts = [x.__dict__ for x in retrieved_alerts.all()] return retrieved_alerts diff --git a/src/app/api/routes/devices.py b/src/app/api/endpoints/devices.py similarity index 94% rename from src/app/api/routes/devices.py rename to src/app/api/endpoints/devices.py index 07b038ab..98ccf471 100644 --- a/src/app/api/routes/devices.py +++ b/src/app/api/endpoints/devices.py @@ -11,8 +11,10 @@ from app.api import crud from app.api.crud.authorizations import is_admin_access from app.api.crud.groups import get_entity_group_id -from app.api.deps import get_current_access, get_current_device, get_current_user -from app.api.schemas import ( +from app.api.deps import get_current_access, get_current_device, get_current_user, get_db +from app.db import accesses, devices, users +from app.models import Access, AccessType, Device +from app.schemas import ( AdminDeviceAuth, Cred, DeviceAuth, @@ -24,8 +26,6 @@ SoftwareHash, UserRead, ) -from app.db import accesses, devices, get_session, models, users -from app.db.models import AccessType router = APIRouter() @@ -77,7 +77,7 @@ async def get_my_device(me: DeviceOut = Security(get_current_device, scopes=["de @router.get("/", response_model=List[DeviceOut], summary="Get the list of all devices") async def fetch_devices( - requester=Security(get_current_access, scopes=[AccessType.admin, AccessType.user]), session=Depends(get_session) + requester=Security(get_current_access, scopes=[AccessType.admin, AccessType.user]), session=Depends(get_db) ): """ Retrieves the list of all devices and their information @@ -85,12 +85,7 @@ async def fetch_devices( if await is_admin_access(requester.id): return await crud.fetch_all(devices) else: - retrieved_devices = ( - session.query(models.Devices) - .join(models.Accesses) - .filter(models.Accesses.group_id == requester.group_id) - .all() - ) + retrieved_devices = session.query(Device).join(Access).filter(Access.group_id == requester.group_id).all() retrieved_devices = [x.__dict__ for x in retrieved_devices] return retrieved_devices diff --git a/src/app/api/routes/events.py b/src/app/api/endpoints/events.py similarity index 80% rename from src/app/api/routes/events.py rename to src/app/api/endpoints/events.py index b2199144..d5b8b4a0 100644 --- a/src/app/api/routes/events.py +++ b/src/app/api/endpoints/events.py @@ -11,10 +11,10 @@ from app.api import crud from app.api.crud.authorizations import check_group_read, check_group_update, is_admin_access from app.api.crud.groups import get_entity_group_id -from app.api.deps import get_current_access -from app.api.schemas import Acknowledgement, AcknowledgementOut, EventIn, EventOut, EventUpdate -from app.db import events, get_session, models -from app.db.models import AccessType +from app.api.deps import get_current_access, get_db +from app.db import events +from app.models import Access, AccessType, Alert, Device, Event +from app.schemas import Acknowledgement, AcknowledgementOut, EventIn, EventOut, EventUpdate router = APIRouter() @@ -43,7 +43,7 @@ async def get_event( @router.get("/", response_model=List[EventOut], summary="Get the list of all events") async def fetch_events( - requester=Security(get_current_access, scopes=[AccessType.admin, AccessType.user]), session=Depends(get_session) + requester=Security(get_current_access, scopes=[AccessType.admin, AccessType.user]), session=Depends(get_db) ): """ Retrieves the list of all events and their information @@ -52,11 +52,7 @@ async def fetch_events( return await crud.fetch_all(events) else: retrieved_events = ( - session.query(models.Events) - .join(models.Alerts) - .join(models.Devices) - .join(models.Accesses) - .filter(models.Accesses.group_id == requester.group_id) + session.query(Event).join(Alert).join(Device).join(Access).filter(Access.group_id == requester.group_id) ) retrieved_events = [x.__dict__ for x in retrieved_events.all()] return retrieved_events @@ -64,7 +60,7 @@ async def fetch_events( @router.get("/past", response_model=List[EventOut], summary="Get the list of all past events") async def fetch_past_events( - requester=Security(get_current_access, scopes=[AccessType.admin, AccessType.user]), session=Depends(get_session) + requester=Security(get_current_access, scopes=[AccessType.admin, AccessType.user]), session=Depends(get_db) ): """ Retrieves the list of all events and their information @@ -73,11 +69,11 @@ async def fetch_past_events( return await crud.fetch_all(events, exclusions={"end_ts": None}) else: retrieved_events = ( - session.query(models.Events) - .join(models.Alerts) - .join(models.Devices) - .join(models.Accesses) - .filter(and_(models.Accesses.group_id == requester.group_id, models.Events.end_ts.isnot(None))) + session.query(Event) + .join(Alert) + .join(Device) + .join(Access) + .filter(and_(Access.group_id == requester.group_id, Event.end_ts.isnot(None))) ) retrieved_events = [x.__dict__ for x in retrieved_events.all()] return retrieved_events @@ -113,7 +109,7 @@ async def acknowledge_event( async def delete_event( event_id: int = Path(..., gt=0), _=Security(get_current_access, scopes=[AccessType.admin]), - session=Depends(get_session), + session=Depends(get_db), ): """ Based on a event_id, deletes the specified event @@ -125,7 +121,7 @@ async def delete_event( "/unacknowledged", response_model=List[EventOut], summary="Get the list of events that haven't been acknowledged" ) async def fetch_unacknowledged_events( - requester=Security(get_current_access, scopes=[AccessType.admin, AccessType.user]), session=Depends(get_session) + requester=Security(get_current_access, scopes=[AccessType.admin, AccessType.user]), session=Depends(get_db) ): """ Retrieves the list of non confirmed alerts and their information @@ -134,11 +130,11 @@ async def fetch_unacknowledged_events( return await crud.fetch_all(events, {"is_acknowledged": False}) else: retrieved_events = ( - session.query(models.Events) - .join(models.Alerts) - .join(models.Devices) - .join(models.Accesses) - .filter(and_(models.Accesses.group_id == requester.group_id, models.Events.is_acknowledged.is_(False))) + session.query(Event) + .join(Alert) + .join(Device) + .join(Access) + .filter(and_(Access.group_id == requester.group_id, Event.is_acknowledged.is_(False))) ) retrieved_events = [x.__dict__ for x in retrieved_events.all()] return retrieved_events diff --git a/src/app/api/routes/groups.py b/src/app/api/endpoints/groups.py similarity index 96% rename from src/app/api/routes/groups.py rename to src/app/api/endpoints/groups.py index d86a4ddc..54dcf394 100644 --- a/src/app/api/routes/groups.py +++ b/src/app/api/endpoints/groups.py @@ -9,9 +9,9 @@ from app.api import crud from app.api.deps import get_current_access -from app.api.schemas import GroupIn, GroupOut from app.db import groups -from app.db.models import AccessType +from app.models import AccessType +from app.schemas import GroupIn, GroupOut router = APIRouter() diff --git a/src/app/api/routes/installations.py b/src/app/api/endpoints/installations.py similarity index 83% rename from src/app/api/routes/installations.py rename to src/app/api/endpoints/installations.py index 2e8b5f01..14d143d1 100644 --- a/src/app/api/routes/installations.py +++ b/src/app/api/endpoints/installations.py @@ -12,10 +12,10 @@ from app.api import crud from app.api.crud.authorizations import check_group_read, check_group_update, is_admin_access from app.api.crud.groups import get_entity_group_id -from app.api.deps import get_current_access -from app.api.schemas import InstallationIn, InstallationOut, InstallationUpdate -from app.db import get_session, installations, models -from app.db.models import AccessType +from app.api.deps import get_current_access, get_db +from app.db import installations +from app.models import AccessType, Installation, Site +from app.schemas import InstallationIn, InstallationOut, InstallationUpdate router = APIRouter() @@ -49,7 +49,7 @@ async def get_installation( @router.get("/", response_model=List[InstallationOut], summary="Get the list of all installations") async def fetch_installations( - requester=Security(get_current_access, scopes=[AccessType.admin, AccessType.user]), session=Depends(get_session) + requester=Security(get_current_access, scopes=[AccessType.admin, AccessType.user]), session=Depends(get_db) ): """ Retrieves the list of all installations and their information @@ -58,10 +58,7 @@ async def fetch_installations( return await crud.fetch_all(installations) else: retrieved_installations = ( - session.query(models.Installations) - .join(models.Sites) - .filter(models.Sites.group_id == requester.group_id) - .all() + session.query(Installation).join(Site).filter(Site.group_id == requester.group_id).all() ) retrieved_installations = [x.__dict__ for x in retrieved_installations] return retrieved_installations @@ -97,7 +94,7 @@ async def delete_installation( async def get_active_devices_on_site( site_id: int = Path(..., gt=0), requester=Security(get_current_access, scopes=[AccessType.admin, AccessType.user]), - session=Depends(get_session), + session=Depends(get_db), ): """ Based on a site_id, retrieves the list of all the related devices and their information @@ -105,20 +102,20 @@ async def get_active_devices_on_site( current_ts = datetime.utcnow() query = ( - session.query(models.Installations) - .join(models.Sites) + session.query(Installation) + .join(Site) .filter( and_( - models.Sites.id == site_id, - models.Installations.start_ts <= current_ts, - or_(models.Installations.end_ts.is_(None), models.Installations.end_ts >= current_ts), + Site.id == site_id, + Installation.start_ts <= current_ts, + or_(Installation.end_ts.is_(None), Installation.end_ts >= current_ts), ) ) ) if not await is_admin_access(requester.id): # Restrict on the group_id of the requester - query = query.filter(models.Sites.group_id == requester.group_id) + query = query.filter(Site.group_id == requester.group_id) retrieved_device_ids = [x.__dict__["device_id"] for x in query.all()] return retrieved_device_ids diff --git a/src/app/api/routes/login.py b/src/app/api/endpoints/login.py similarity index 97% rename from src/app/api/routes/login.py rename to src/app/api/endpoints/login.py index e705a856..aa58b6c8 100644 --- a/src/app/api/routes/login.py +++ b/src/app/api/endpoints/login.py @@ -10,8 +10,8 @@ from app import config as cfg from app.api import crud, security -from app.api.schemas import Token from app.db import accesses +from app.schemas import Token router = APIRouter() diff --git a/src/app/api/routes/media.py b/src/app/api/endpoints/media.py similarity index 94% rename from src/app/api/routes/media.py rename to src/app/api/endpoints/media.py index 4ceddd76..0e6bbde1 100644 --- a/src/app/api/routes/media.py +++ b/src/app/api/endpoints/media.py @@ -13,11 +13,11 @@ from app.api import crud from app.api.crud.authorizations import check_group_read, is_admin_access from app.api.crud.groups import get_entity_group_id -from app.api.deps import get_current_access, get_current_device, get_current_user -from app.api.schemas import BaseMedia, DeviceOut, MediaCreation, MediaIn, MediaOut, MediaUrl +from app.api.deps import get_current_access, get_current_device, get_current_user, get_db from app.api.security import hash_content_file -from app.db import get_session, media, models -from app.db.models import AccessType +from app.db import media +from app.models import Access, AccessType, Device, Media +from app.schemas import BaseMedia, DeviceOut, MediaCreation, MediaIn, MediaOut, MediaUrl from app.services import bucket_service, resolve_bucket_key router = APIRouter() @@ -86,7 +86,7 @@ async def get_media( @router.get("/", response_model=List[MediaOut], summary="Get the list of all media") async def fetch_media( - requester=Security(get_current_access, scopes=[AccessType.admin, AccessType.user]), session=Depends(get_session) + requester=Security(get_current_access, scopes=[AccessType.admin, AccessType.user]), session=Depends(get_db) ): """ Retrieves the list of all media and their information @@ -95,11 +95,7 @@ async def fetch_media( return await crud.fetch_all(media) else: retrieved_media = ( - session.query(models.Media) - .join(models.Devices) - .join(models.Accesses) - .filter(models.Accesses.group_id == requester.group_id) - .all() + session.query(Media).join(Device).join(Access).filter(Access.group_id == requester.group_id).all() ) retrieved_media = [x.__dict__ for x in retrieved_media] return retrieved_media diff --git a/src/app/api/routes/sites.py b/src/app/api/endpoints/sites.py similarity index 93% rename from src/app/api/routes/sites.py rename to src/app/api/endpoints/sites.py index e68880d2..100e7a11 100644 --- a/src/app/api/routes/sites.py +++ b/src/app/api/endpoints/sites.py @@ -10,10 +10,10 @@ from app.api import crud from app.api.crud.authorizations import check_group_read, check_group_update, is_admin_access from app.api.crud.groups import get_entity_group_id -from app.api.deps import get_current_access -from app.api.schemas import SiteBase, SiteIn, SiteOut, SiteUpdate -from app.db import SiteType, get_session, sites -from app.db.models import AccessType +from app.api.deps import get_current_access, get_db +from app.db import sites +from app.models import AccessType, SiteType +from app.schemas import SiteBase, SiteIn, SiteOut, SiteUpdate router = APIRouter() @@ -63,7 +63,7 @@ async def get_site( @router.get("/", response_model=List[SiteOut], summary="Get the list of all sites in your group") async def fetch_sites( - requester=Security(get_current_access, scopes=[AccessType.admin, AccessType.user]), session=Depends(get_session) + requester=Security(get_current_access, scopes=[AccessType.admin, AccessType.user]), session=Depends(get_db) ): """ Retrieves the list of all sites and their information diff --git a/src/app/api/routes/users.py b/src/app/api/endpoints/users.py similarity index 89% rename from src/app/api/routes/users.py rename to src/app/api/endpoints/users.py index 08ed7aed..c54a97d7 100644 --- a/src/app/api/routes/users.py +++ b/src/app/api/endpoints/users.py @@ -9,10 +9,10 @@ from app.api import crud from app.api.crud.authorizations import is_admin_access -from app.api.deps import get_current_access, get_current_user -from app.api.schemas import Cred, Login, UserAuth, UserCreation, UserRead -from app.db import accesses, get_session, models, users -from app.db.models import AccessType +from app.api.deps import get_current_access, get_current_user, get_db +from app.db import accesses, users +from app.models import Access, AccessType, User +from app.schemas import Cred, Login, UserAuth, UserCreation, UserRead router = APIRouter() @@ -68,7 +68,7 @@ async def get_user(user_id: int = Path(..., gt=0), _=Security(get_current_user, @router.get("/", response_model=List[UserRead], summary="Get the list of all users") async def fetch_users( - requester=Security(get_current_access, scopes=[AccessType.admin, AccessType.user]), session=Depends(get_session) + requester=Security(get_current_access, scopes=[AccessType.admin, AccessType.user]), session=Depends(get_db) ): """ Retrieves the list of all users and their information @@ -76,12 +76,7 @@ async def fetch_users( if await is_admin_access(requester.id): return await crud.fetch_all(users) else: - retrieved_users = ( - session.query(models.Users) - .join(models.Accesses) - .filter(models.Accesses.group_id == requester.group_id) - .all() - ) + retrieved_users = session.query(User).join(Access).filter(Access.group_id == requester.group_id).all() retrieved_users = [x.__dict__ for x in retrieved_users] return retrieved_users diff --git a/src/app/api/routes/webhooks.py b/src/app/api/endpoints/webhooks.py similarity index 97% rename from src/app/api/routes/webhooks.py rename to src/app/api/endpoints/webhooks.py index 27cc1f17..1de46233 100644 --- a/src/app/api/routes/webhooks.py +++ b/src/app/api/endpoints/webhooks.py @@ -9,8 +9,8 @@ from app.api import crud from app.api.deps import get_current_access -from app.api.schemas import WebhookIn, WebhookOut from app.db import webhooks +from app.schemas import WebhookIn, WebhookOut router = APIRouter() diff --git a/src/app/db/__init__.py b/src/app/db/__init__.py index dcfa3427..e501448d 100644 --- a/src/app/db/__init__.py +++ b/src/app/db/__init__.py @@ -1,13 +1,4 @@ from .init_db import * -from .models import * +from .base_class import * from .session import * from .tables import * - - -# Dependency -def get_session(): - db = SessionLocal() # noqa: F405 - try: - yield db - finally: - db.close() diff --git a/src/app/db/base_class.py b/src/app/db/base_class.py new file mode 100644 index 00000000..a6e6b81c --- /dev/null +++ b/src/app/db/base_class.py @@ -0,0 +1,10 @@ +# Copyright (C) 2020-2023, Pyronear. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. + +from sqlalchemy.ext.declarative import declarative_base + +__all__ = ["Base"] + +Base = declarative_base() diff --git a/src/app/db/init_db.py b/src/app/db/init_db.py index c6ca3380..833d9953 100644 --- a/src/app/db/init_db.py +++ b/src/app/db/init_db.py @@ -5,10 +5,10 @@ from app import config as cfg from app.api import crud -from app.api.schemas import AccessCreation, GroupIn, UserCreation from app.api.security import hash_password +from app.models import AccessType +from app.schemas import AccessCreation, GroupIn, UserCreation -from .models import AccessType from .tables import accesses, groups, users __all__ = ["init_db"] diff --git a/src/app/db/models.py b/src/app/db/models.py deleted file mode 100644 index b15134de..00000000 --- a/src/app/db/models.py +++ /dev/null @@ -1,224 +0,0 @@ -# Copyright (C) 2021-2023, Pyronear. - -# This program is licensed under the Apache License 2.0. -# See LICENSE or go to for full license details. - -import enum - -from sqlalchemy import Boolean, Column, DateTime, Enum, Float, ForeignKey, Integer, String -from sqlalchemy.orm import relationship -from sqlalchemy.sql import func - -from .session import Base - -__all__ = ["AccessType", "EventType", "MediaType", "SiteType"] - - -class Users(Base): - __tablename__ = "users" - - id = Column(Integer, primary_key=True) - login = Column(String(50), unique=True) - access_id = Column(Integer, ForeignKey("accesses.id", ondelete="CASCADE"), unique=True) - created_at = Column(DateTime, default=func.now()) - - access = relationship("Accesses", uselist=False, back_populates="user") - device = relationship("Devices", uselist=False, back_populates="owner") - - def __repr__(self): - return f"" - - -class AccessType(str, enum.Enum): - user: str = "user" - admin: str = "admin" - device: str = "device" - - -class Accesses(Base): - __tablename__ = "accesses" - - id = Column(Integer, primary_key=True) - login = Column(String(50), unique=True, index=True) # index for fast lookup - hashed_password = Column(String(70), nullable=False) - scope = Column(Enum(AccessType), default=AccessType.user, nullable=False) - group_id = Column(Integer, ForeignKey("groups.id", ondelete="CASCADE"), nullable=False) - - user = relationship("Users", uselist=False, back_populates="access") - device = relationship("Devices", uselist=False, back_populates="access") - group = relationship("Groups", uselist=False, back_populates="accesses") - - def __repr__(self): - return f"" - - -class Groups(Base): - __tablename__ = "groups" - - id = Column(Integer, primary_key=True) - name = Column(String(50), unique=True) - - accesses = relationship("Accesses", back_populates="group") - sites = relationship("Sites", back_populates="group") - - def __repr__(self): - return f"" - - -class SiteType(str, enum.Enum): - tower: str = "tower" - station: str = "station" - no_alert: str = "no_alert" - - -class Sites(Base): - __tablename__ = "sites" - - id = Column(Integer, primary_key=True) - name = Column(String(50)) - group_id = Column(Integer, ForeignKey("groups.id", ondelete="CASCADE"), nullable=False) - lat = Column(Float(4, asdecimal=True)) - lon = Column(Float(4, asdecimal=True)) - country = Column(String(5), nullable=False) - geocode = Column(String(10), nullable=False) - type = Column(Enum(SiteType), default=SiteType.tower) - created_at = Column(DateTime, default=func.now()) - - installations = relationship("Installations", back_populates="site") - group = relationship("Groups", uselist=False, back_populates="sites") - - def __repr__(self): - return ( - f"" - ) - - -class EventType(str, enum.Enum): - wildfire: str = "wildfire" - - -class Events(Base): - __tablename__ = "events" - - id = Column(Integer, primary_key=True) - lat = Column(Float(4, asdecimal=True)) - lon = Column(Float(4, asdecimal=True)) - type = Column(Enum(EventType), default=EventType.wildfire) - start_ts = Column(DateTime, default=func.now()) - end_ts = Column(DateTime, default=None, nullable=True) - is_acknowledged = Column(Boolean, default=False) - created_at = Column(DateTime, default=func.now()) - - alerts = relationship("Alerts", back_populates="event") - - def __repr__(self): - return ( - f"" - ) - - -# Linked tables -class Devices(Base): - __tablename__ = "devices" - - id = Column(Integer, primary_key=True) - login = Column(String(50), unique=True) - owner_id = Column(Integer, ForeignKey("users.id")) - access_id = Column(Integer, ForeignKey("accesses.id", ondelete="CASCADE"), unique=True) - specs = Column(String(50)) - software_hash = Column(String(16), default=None, nullable=True) - angle_of_view = Column(Float(2, asdecimal=True)) - elevation = Column(Float(1, asdecimal=True), default=None, nullable=True) - lat = Column(Float(4, asdecimal=True), default=None, nullable=True) - lon = Column(Float(4, asdecimal=True), default=None, nullable=True) - azimuth = Column(Float(1, asdecimal=True), default=None, nullable=True) - pitch = Column(Float(1, asdecimal=True), default=None, nullable=True) - last_ping = Column(DateTime, default=None, nullable=True) - created_at = Column(DateTime, default=func.now()) - - access = relationship("Accesses", uselist=False, back_populates="device") - owner = relationship("Users", uselist=False, back_populates="device") - media = relationship("Media", back_populates="device") - alerts = relationship("Alerts", back_populates="device") - installation = relationship("Installations", back_populates="device") - - def __repr__(self): - return ( - f"" - ) - - -class MediaType(str, enum.Enum): - image: str = "image" - video: str = "video" - - -class Media(Base): - __tablename__ = "media" - - id = Column(Integer, primary_key=True) - device_id = Column(Integer, ForeignKey("devices.id")) - bucket_key = Column(String(100), nullable=True) - type = Column(Enum(MediaType), default=MediaType.image) - created_at = Column(DateTime, default=func.now()) - - device = relationship("Devices", uselist=False, back_populates="media") - alerts = relationship("Alerts", back_populates="media") - - def __repr__(self): - return f"" - - -class Installations(Base): - __tablename__ = "installations" - - id = Column(Integer, primary_key=True) - device_id = Column(Integer, ForeignKey("devices.id")) - site_id = Column(Integer, ForeignKey("sites.id")) - start_ts = Column(DateTime, nullable=False) - end_ts = Column(DateTime, default=None, nullable=True) - is_trustworthy = Column(Boolean, default=True) - created_at = Column(DateTime, default=func.now()) - - device = relationship("Devices", back_populates="installation") - site = relationship("Sites", back_populates="installations") - - def __repr__(self): - return ( - f"" - ) - - -class Alerts(Base): - __tablename__ = "alerts" - - id = Column(Integer, primary_key=True) - device_id = Column(Integer, ForeignKey("devices.id"), nullable=False) - event_id = Column(Integer, ForeignKey("events.id", onupdate="CASCADE", ondelete="CASCADE"), nullable=True) - media_id = Column(Integer, ForeignKey("media.id"), nullable=False) - azimuth = Column(Float(4, asdecimal=True)) - lat = Column(Float(4, asdecimal=True)) - lon = Column(Float(4, asdecimal=True)) - created_at = Column(DateTime, default=func.now()) - - device = relationship("Devices", back_populates="alerts") - event = relationship("Events", back_populates="alerts") - media = relationship("Media", back_populates="alerts") - - def __repr__(self): - return f"" - - -class Webhooks(Base): - __tablename__ = "webhooks" - - id = Column(Integer, primary_key=True) - callback = Column(String(50), nullable=False) - url = Column(String(100), nullable=False) - - def __repr__(self): - return f"" diff --git a/src/app/db/session.py b/src/app/db/session.py index 6a80fd30..75296eeb 100644 --- a/src/app/db/session.py +++ b/src/app/db/session.py @@ -5,15 +5,13 @@ from databases import Database from sqlalchemy import create_engine -from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import sessionmaker from app import config as cfg -__all__ = ["Base", "SessionLocal", "database", "engine"] +__all__ = ["SessionLocal", "database", "engine"] engine = create_engine(cfg.DATABASE_URL) -database = Database(cfg.DATABASE_URL) SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) -Base = declarative_base() +database = Database(cfg.DATABASE_URL) diff --git a/src/app/db/tables.py b/src/app/db/tables.py index 9388b3cb..40415e64 100644 --- a/src/app/db/tables.py +++ b/src/app/db/tables.py @@ -4,8 +4,9 @@ # See LICENSE or go to for full license details. -from .models import Accesses, Alerts, Devices, Events, Groups, Installations, Media, Sites, Users, Webhooks -from .session import Base +from app.models import Access, Alert, Device, Event, Group, Installation, Media, Site, User, Webhook + +from .base_class import Base __all__ = [ "metadata", @@ -21,15 +22,15 @@ "webhooks", ] -users = Users.__table__ -accesses = Accesses.__table__ -groups = Groups.__table__ -sites = Sites.__table__ -events = Events.__table__ -devices = Devices.__table__ +users = User.__table__ +accesses = Access.__table__ +groups = Group.__table__ +sites = Site.__table__ +events = Event.__table__ +devices = Device.__table__ media = Media.__table__ -installations = Installations.__table__ -alerts = Alerts.__table__ -webhooks = Webhooks.__table__ +installations = Installation.__table__ +alerts = Alert.__table__ +webhooks = Webhook.__table__ metadata = Base.metadata diff --git a/src/app/main.py b/src/app/main.py index f05f0292..664fa406 100644 --- a/src/app/main.py +++ b/src/app/main.py @@ -12,7 +12,7 @@ from sentry_sdk.integrations.asgi import SentryAsgiMiddleware from app import config as cfg -from app.api.routes import ( +from app.api.endpoints import ( accesses, alerts, devices, diff --git a/src/app/models/__init__.py b/src/app/models/__init__.py new file mode 100644 index 00000000..eafc065d --- /dev/null +++ b/src/app/models/__init__.py @@ -0,0 +1,10 @@ +from .access import * +from .alert import * +from .device import * +from .event import * +from .group import * +from .installation import * +from .media import * +from .site import * +from .user import * +from .webhook import * diff --git a/src/app/models/access.py b/src/app/models/access.py new file mode 100644 index 00000000..9804e319 --- /dev/null +++ b/src/app/models/access.py @@ -0,0 +1,36 @@ +# Copyright (C) 2021-2023, Pyronear. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. + +import enum + +from sqlalchemy import Column, Enum, ForeignKey, Integer, String +from sqlalchemy.orm import relationship + +from app.db.base_class import Base + +__all__ = ["AccessType", "Access"] + + +class AccessType(str, enum.Enum): + user: str = "user" + admin: str = "admin" + device: str = "device" + + +class Access(Base): + __tablename__ = "accesses" + + id = Column(Integer, primary_key=True) + login = Column(String(50), unique=True, index=True) # index for fast lookup + hashed_password = Column(String(70), nullable=False) + scope = Column(Enum(AccessType), default=AccessType.user, nullable=False) + group_id = Column(Integer, ForeignKey("groups.id", ondelete="CASCADE"), nullable=False) + + user = relationship("User", uselist=False, back_populates="access") + device = relationship("Device", uselist=False, back_populates="access") + group = relationship("Group", uselist=False, back_populates="accesses") + + def __repr__(self): + return f"" diff --git a/src/app/models/alert.py b/src/app/models/alert.py new file mode 100644 index 00000000..316f0aa4 --- /dev/null +++ b/src/app/models/alert.py @@ -0,0 +1,33 @@ +# Copyright (C) 2021-2023, Pyronear. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. + + +from sqlalchemy import Column, DateTime, Float, ForeignKey, Integer +from sqlalchemy.orm import relationship +from sqlalchemy.sql import func + +from app.db.base_class import Base + +__all__ = ["Alert"] + + +class Alert(Base): + __tablename__ = "alerts" + + id = Column(Integer, primary_key=True) + device_id = Column(Integer, ForeignKey("devices.id"), nullable=False) + event_id = Column(Integer, ForeignKey("events.id", onupdate="CASCADE", ondelete="CASCADE"), nullable=True) + media_id = Column(Integer, ForeignKey("media.id"), nullable=False) + azimuth = Column(Float(4, asdecimal=True)) + lat = Column(Float(4, asdecimal=True)) + lon = Column(Float(4, asdecimal=True)) + created_at = Column(DateTime, default=func.now()) + + device = relationship("Device", back_populates="alerts") + event = relationship("Event", back_populates="alerts") + media = relationship("Media", back_populates="alerts") + + def __repr__(self): + return f"" diff --git a/src/app/models/device.py b/src/app/models/device.py new file mode 100644 index 00000000..17aafa9b --- /dev/null +++ b/src/app/models/device.py @@ -0,0 +1,44 @@ +# Copyright (C) 2021-2023, Pyronear. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. + + +from sqlalchemy import Column, DateTime, Float, ForeignKey, Integer, String +from sqlalchemy.orm import relationship +from sqlalchemy.sql import func + +from app.db.base_class import Base + +__all__ = ["Device"] + + +class Device(Base): + __tablename__ = "devices" + + id = Column(Integer, primary_key=True) + login = Column(String(50), unique=True) + owner_id = Column(Integer, ForeignKey("users.id")) + access_id = Column(Integer, ForeignKey("accesses.id", ondelete="CASCADE"), unique=True) + specs = Column(String(50)) + software_hash = Column(String(16), default=None, nullable=True) + angle_of_view = Column(Float(2, asdecimal=True)) + elevation = Column(Float(1, asdecimal=True), default=None, nullable=True) + lat = Column(Float(4, asdecimal=True), default=None, nullable=True) + lon = Column(Float(4, asdecimal=True), default=None, nullable=True) + azimuth = Column(Float(1, asdecimal=True), default=None, nullable=True) + pitch = Column(Float(1, asdecimal=True), default=None, nullable=True) + last_ping = Column(DateTime, default=None, nullable=True) + created_at = Column(DateTime, default=func.now()) + + access = relationship("Access", uselist=False, back_populates="device") + owner = relationship("User", uselist=False, back_populates="device") + media = relationship("Media", back_populates="device") + alerts = relationship("Alert", back_populates="device") + installation = relationship("Installation", back_populates="device") + + def __repr__(self): + return ( + f"" + ) diff --git a/src/app/models/event.py b/src/app/models/event.py new file mode 100644 index 00000000..d61f352b --- /dev/null +++ b/src/app/models/event.py @@ -0,0 +1,39 @@ +# Copyright (C) 2021-2023, Pyronear. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. + +import enum + +from sqlalchemy import Boolean, Column, DateTime, Enum, Float, Integer +from sqlalchemy.orm import relationship +from sqlalchemy.sql import func + +from app.db.base_class import Base + +__all__ = ["EventType", "Event"] + + +class EventType(str, enum.Enum): + wildfire: str = "wildfire" + + +class Event(Base): + __tablename__ = "events" + + id = Column(Integer, primary_key=True) + lat = Column(Float(4, asdecimal=True)) + lon = Column(Float(4, asdecimal=True)) + type = Column(Enum(EventType), default=EventType.wildfire) + start_ts = Column(DateTime, default=func.now()) + end_ts = Column(DateTime, default=None, nullable=True) + is_acknowledged = Column(Boolean, default=False) + created_at = Column(DateTime, default=func.now()) + + alerts = relationship("Alert", back_populates="event") + + def __repr__(self): + return ( + f"" + ) diff --git a/src/app/models/group.py b/src/app/models/group.py new file mode 100644 index 00000000..71bf8d5e --- /dev/null +++ b/src/app/models/group.py @@ -0,0 +1,25 @@ +# Copyright (C) 2021-2023, Pyronear. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. + + +from sqlalchemy import Column, Integer, String +from sqlalchemy.orm import relationship + +from app.db.base_class import Base + +__all__ = ["Group"] + + +class Group(Base): + __tablename__ = "groups" + + id = Column(Integer, primary_key=True) + name = Column(String(50), unique=True) + + accesses = relationship("Access", back_populates="group") + sites = relationship("Site", back_populates="group") + + def __repr__(self): + return f"" diff --git a/src/app/models/installation.py b/src/app/models/installation.py new file mode 100644 index 00000000..366ade21 --- /dev/null +++ b/src/app/models/installation.py @@ -0,0 +1,34 @@ +# Copyright (C) 2021-2023, Pyronear. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. + + +from sqlalchemy import Boolean, Column, DateTime, ForeignKey, Integer +from sqlalchemy.orm import relationship +from sqlalchemy.sql import func + +from app.db.base_class import Base + +__all__ = ["Installation"] + + +class Installation(Base): + __tablename__ = "installations" + + id = Column(Integer, primary_key=True) + device_id = Column(Integer, ForeignKey("devices.id")) + site_id = Column(Integer, ForeignKey("sites.id")) + start_ts = Column(DateTime, nullable=False) + end_ts = Column(DateTime, default=None, nullable=True) + is_trustworthy = Column(Boolean, default=True) + created_at = Column(DateTime, default=func.now()) + + device = relationship("Device", back_populates="installation") + site = relationship("Site", back_populates="installations") + + def __repr__(self): + return ( + f"" + ) diff --git a/src/app/models/media.py b/src/app/models/media.py new file mode 100644 index 00000000..bcb251ea --- /dev/null +++ b/src/app/models/media.py @@ -0,0 +1,35 @@ +# Copyright (C) 2021-2023, Pyronear. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. + +import enum + +from sqlalchemy import Column, DateTime, Enum, ForeignKey, Integer, String +from sqlalchemy.orm import relationship +from sqlalchemy.sql import func + +from app.db.base_class import Base + +__all__ = ["MediaType", "Media"] + + +class MediaType(str, enum.Enum): + image: str = "image" + video: str = "video" + + +class Media(Base): + __tablename__ = "media" + + id = Column(Integer, primary_key=True) + device_id = Column(Integer, ForeignKey("devices.id")) + bucket_key = Column(String(100), nullable=True) + type = Column(Enum(MediaType), default=MediaType.image) + created_at = Column(DateTime, default=func.now()) + + device = relationship("Device", uselist=False, back_populates="media") + alerts = relationship("Alert", back_populates="media") + + def __repr__(self): + return f"" diff --git a/src/app/models/site.py b/src/app/models/site.py new file mode 100644 index 00000000..74fb2446 --- /dev/null +++ b/src/app/models/site.py @@ -0,0 +1,43 @@ +# Copyright (C) 2021-2023, Pyronear. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. + +import enum + +from sqlalchemy import Column, DateTime, Enum, Float, ForeignKey, Integer, String +from sqlalchemy.orm import relationship +from sqlalchemy.sql import func + +from app.db.base_class import Base + +__all__ = ["SiteType", "Site"] + + +class SiteType(str, enum.Enum): + tower: str = "tower" + station: str = "station" + no_alert: str = "no_alert" + + +class Site(Base): + __tablename__ = "sites" + + id = Column(Integer, primary_key=True) + name = Column(String(50)) + group_id = Column(Integer, ForeignKey("groups.id", ondelete="CASCADE"), nullable=False) + lat = Column(Float(4, asdecimal=True)) + lon = Column(Float(4, asdecimal=True)) + country = Column(String(5), nullable=False) + geocode = Column(String(10), nullable=False) + type = Column(Enum(SiteType), default=SiteType.tower) + created_at = Column(DateTime, default=func.now()) + + installations = relationship("Installation", back_populates="site") + group = relationship("Group", uselist=False, back_populates="sites") + + def __repr__(self): + return ( + f"" + ) diff --git a/src/app/models/user.py b/src/app/models/user.py new file mode 100644 index 00000000..6b058067 --- /dev/null +++ b/src/app/models/user.py @@ -0,0 +1,28 @@ +# Copyright (C) 2021-2023, Pyronear. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. + + +from sqlalchemy import Column, DateTime, ForeignKey, Integer, String +from sqlalchemy.orm import relationship +from sqlalchemy.sql import func + +from app.db.base_class import Base + +__all__ = ["User"] + + +class User(Base): + __tablename__ = "users" + + id = Column(Integer, primary_key=True) + login = Column(String(50), unique=True) + access_id = Column(Integer, ForeignKey("accesses.id", ondelete="CASCADE"), unique=True) + created_at = Column(DateTime, default=func.now()) + + access = relationship("Access", uselist=False, back_populates="user") + device = relationship("Device", uselist=False, back_populates="owner") + + def __repr__(self): + return f"" diff --git a/src/app/models/webhook.py b/src/app/models/webhook.py new file mode 100644 index 00000000..c269ebd6 --- /dev/null +++ b/src/app/models/webhook.py @@ -0,0 +1,22 @@ +# Copyright (C) 2021-2023, Pyronear. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. + + +from sqlalchemy import Column, Integer, String + +from app.db.base_class import Base + +__all__ = ["Webhook"] + + +class Webhook(Base): + __tablename__ = "webhooks" + + id = Column(Integer, primary_key=True) + callback = Column(String(50), nullable=False) + url = Column(String(100), nullable=False) + + def __repr__(self): + return f"" diff --git a/src/app/api/schemas/__init__.py b/src/app/schemas/__init__.py similarity index 100% rename from src/app/api/schemas/__init__.py rename to src/app/schemas/__init__.py diff --git a/src/app/api/schemas/accesses.py b/src/app/schemas/accesses.py similarity index 93% rename from src/app/api/schemas/accesses.py rename to src/app/schemas/accesses.py index dc7d888c..05983e31 100644 --- a/src/app/api/schemas/accesses.py +++ b/src/app/schemas/accesses.py @@ -5,7 +5,7 @@ from pydantic import Field -from app.db.models import AccessType +from app.models import AccessType from .base import CredHash, Login, _GroupId, _Id diff --git a/src/app/api/schemas/alerts.py b/src/app/schemas/alerts.py similarity index 100% rename from src/app/api/schemas/alerts.py rename to src/app/schemas/alerts.py diff --git a/src/app/api/schemas/base.py b/src/app/schemas/base.py similarity index 100% rename from src/app/api/schemas/base.py rename to src/app/schemas/base.py diff --git a/src/app/api/schemas/devices.py b/src/app/schemas/devices.py similarity index 98% rename from src/app/api/schemas/devices.py rename to src/app/schemas/devices.py index 2fb573b8..b3b83d41 100644 --- a/src/app/api/schemas/devices.py +++ b/src/app/schemas/devices.py @@ -8,7 +8,7 @@ from pydantic import BaseModel, Field -from app.db.models import AccessType +from app.models import AccessType from .base import Cred, DefaultPosition, Login, _CreatedAt, _GroupId, _Id diff --git a/src/app/api/schemas/events.py b/src/app/schemas/events.py similarity index 97% rename from src/app/api/schemas/events.py rename to src/app/schemas/events.py index 939deb35..8417cfac 100644 --- a/src/app/api/schemas/events.py +++ b/src/app/schemas/events.py @@ -8,7 +8,7 @@ from pydantic import BaseModel, Field, validator -from app.db.models import EventType +from app.models import EventType from .base import _CreatedAt, _FlatLocation, _Id diff --git a/src/app/api/schemas/groups.py b/src/app/schemas/groups.py similarity index 100% rename from src/app/api/schemas/groups.py rename to src/app/schemas/groups.py diff --git a/src/app/api/schemas/installations.py b/src/app/schemas/installations.py similarity index 100% rename from src/app/api/schemas/installations.py rename to src/app/schemas/installations.py diff --git a/src/app/api/schemas/login.py b/src/app/schemas/login.py similarity index 95% rename from src/app/api/schemas/login.py rename to src/app/schemas/login.py index 889eddc2..e5ba4aa9 100644 --- a/src/app/api/schemas/login.py +++ b/src/app/schemas/login.py @@ -7,7 +7,7 @@ from pydantic import BaseModel, Field -from app.db.models import AccessType +from app.models import AccessType __all__ = ["Token", "TokenPayload"] diff --git a/src/app/api/schemas/media.py b/src/app/schemas/media.py similarity index 95% rename from src/app/api/schemas/media.py rename to src/app/schemas/media.py index b096dc88..d2ba486f 100644 --- a/src/app/api/schemas/media.py +++ b/src/app/schemas/media.py @@ -6,7 +6,7 @@ from pydantic import BaseModel, Field -from app.db.models import MediaType +from app.models import MediaType from .base import _CreatedAt, _Id diff --git a/src/app/api/schemas/sites.py b/src/app/schemas/sites.py similarity index 96% rename from src/app/api/schemas/sites.py rename to src/app/schemas/sites.py index 05754357..bc2368f2 100644 --- a/src/app/api/schemas/sites.py +++ b/src/app/schemas/sites.py @@ -5,7 +5,7 @@ from pydantic import Field -from app.db.models import SiteType +from app.models import SiteType from .base import _CreatedAt, _FlatLocation, _Id diff --git a/src/app/api/schemas/users.py b/src/app/schemas/users.py similarity index 94% rename from src/app/api/schemas/users.py rename to src/app/schemas/users.py index 915c1386..1e818e65 100644 --- a/src/app/api/schemas/users.py +++ b/src/app/schemas/users.py @@ -6,7 +6,7 @@ from pydantic import Field -from app.db.models import AccessType +from app.models import AccessType from .base import Cred, Login, _CreatedAt, _GroupId, _Id diff --git a/src/app/api/schemas/webhooks.py b/src/app/schemas/webhooks.py similarity index 100% rename from src/app/api/schemas/webhooks.py rename to src/app/schemas/webhooks.py diff --git a/src/tests/routes/test_alerts.py b/src/tests/routes/test_alerts.py index 340e3a4d..52d3d62a 100644 --- a/src/tests/routes/test_alerts.py +++ b/src/tests/routes/test_alerts.py @@ -5,7 +5,7 @@ import pytest_asyncio from app import db -from app.api import crud +from app.api import crud, deps from tests.db_utils import TestSessionLocal, fill_table, get_entry from tests.utils import parse_time, ts_to_string, update_only_datetime @@ -155,7 +155,7 @@ @pytest_asyncio.fixture(scope="function") async def init_test_db(monkeypatch, test_db): monkeypatch.setattr(crud.base, "database", test_db) - monkeypatch.setattr(db, "SessionLocal", TestSessionLocal) + monkeypatch.setattr(deps, "SessionLocal", TestSessionLocal) await fill_table(test_db, db.groups, GROUP_TABLE) await fill_table(test_db, db.accesses, ACCESS_TABLE) await fill_table(test_db, db.users, USER_TABLE_FOR_DB) diff --git a/src/tests/routes/test_devices.py b/src/tests/routes/test_devices.py index c1aa9193..e126d9be 100644 --- a/src/tests/routes/test_devices.py +++ b/src/tests/routes/test_devices.py @@ -5,7 +5,7 @@ import pytest_asyncio from app import db -from app.api import crud, security +from app.api import crud, deps, security from tests.db_utils import TestSessionLocal, fill_table, get_entry from tests.utils import parse_time, update_only_datetime @@ -85,7 +85,7 @@ async def init_test_db(monkeypatch, test_db): monkeypatch.setattr(security, "hash_password", pytest.mock_hash_password) monkeypatch.setattr(crud.base, "database", test_db) - monkeypatch.setattr(db, "SessionLocal", TestSessionLocal) + monkeypatch.setattr(deps, "SessionLocal", TestSessionLocal) await fill_table(test_db, db.groups, GROUP_TABLE) await fill_table(test_db, db.accesses, ACCESS_TABLE) await fill_table(test_db, db.users, USER_TABLE_FOR_DB) diff --git a/src/tests/routes/test_events.py b/src/tests/routes/test_events.py index 41b082c7..0573b0a7 100644 --- a/src/tests/routes/test_events.py +++ b/src/tests/routes/test_events.py @@ -5,7 +5,7 @@ import pytest_asyncio from app import db -from app.api import crud +from app.api import crud, deps from tests.db_utils import TestSessionLocal, fill_table, get_entry from tests.utils import ts_to_string, update_only_datetime @@ -153,7 +153,7 @@ @pytest_asyncio.fixture(scope="function") async def init_test_db(monkeypatch, test_db): monkeypatch.setattr(crud.base, "database", test_db) - monkeypatch.setattr(db, "SessionLocal", TestSessionLocal) + monkeypatch.setattr(deps, "SessionLocal", TestSessionLocal) await fill_table(test_db, db.groups, GROUP_TABLE) await fill_table(test_db, db.accesses, ACCESS_TABLE) await fill_table(test_db, db.users, USER_TABLE_FOR_DB) diff --git a/src/tests/routes/test_groups.py b/src/tests/routes/test_groups.py index 638b0973..9e24f1d9 100644 --- a/src/tests/routes/test_groups.py +++ b/src/tests/routes/test_groups.py @@ -4,7 +4,7 @@ import pytest_asyncio from app import db -from app.api import crud +from app.api import crud, deps from tests.db_utils import TestSessionLocal, fill_table, get_entry from tests.utils import update_only_datetime @@ -70,7 +70,7 @@ @pytest_asyncio.fixture(scope="function") async def init_test_db(monkeypatch, test_db): monkeypatch.setattr(crud.base, "database", test_db) - monkeypatch.setattr(db, "SessionLocal", TestSessionLocal) + monkeypatch.setattr(deps, "SessionLocal", TestSessionLocal) await fill_table(test_db, db.groups, GROUP_TABLE_FOR_DB) await fill_table(test_db, db.accesses, ACCESS_TABLE) await fill_table(test_db, db.users, USER_TABLE_FOR_DB) diff --git a/src/tests/routes/test_installations.py b/src/tests/routes/test_installations.py index cfa00444..c4693a21 100644 --- a/src/tests/routes/test_installations.py +++ b/src/tests/routes/test_installations.py @@ -5,7 +5,7 @@ import pytest_asyncio from app import db -from app.api import crud +from app.api import crud, deps from tests.db_utils import TestSessionLocal, fill_table, get_entry from tests.utils import parse_time, update_only_datetime @@ -115,7 +115,7 @@ @pytest_asyncio.fixture(scope="function") async def init_test_db(monkeypatch, test_db): monkeypatch.setattr(crud.base, "database", test_db) - monkeypatch.setattr(db, "SessionLocal", TestSessionLocal) + monkeypatch.setattr(deps, "SessionLocal", TestSessionLocal) await fill_table(test_db, db.groups, GROUP_TABLE) await fill_table(test_db, db.accesses, ACCESS_TABLE) await fill_table(test_db, db.users, USER_TABLE_FOR_DB) diff --git a/src/tests/routes/test_media.py b/src/tests/routes/test_media.py index 7153ba3a..27bb1427 100644 --- a/src/tests/routes/test_media.py +++ b/src/tests/routes/test_media.py @@ -8,7 +8,7 @@ import requests from app import db -from app.api import crud +from app.api import crud, deps from app.api.security import hash_content_file from app.services import bucket_service from tests.db_utils import TestSessionLocal, fill_table, get_entry @@ -79,7 +79,7 @@ @pytest_asyncio.fixture(scope="function") async def init_test_db(monkeypatch, test_db): monkeypatch.setattr(crud.base, "database", test_db) - monkeypatch.setattr(db, "SessionLocal", TestSessionLocal) + monkeypatch.setattr(deps, "SessionLocal", TestSessionLocal) await fill_table(test_db, db.groups, GROUP_TABLE) await fill_table(test_db, db.accesses, ACCESS_TABLE) await fill_table(test_db, db.users, USER_TABLE_FOR_DB) diff --git a/src/tests/routes/test_sites.py b/src/tests/routes/test_sites.py index 2f7ec99d..7e3fd2c3 100644 --- a/src/tests/routes/test_sites.py +++ b/src/tests/routes/test_sites.py @@ -5,7 +5,7 @@ import pytest_asyncio from app import db -from app.api import crud +from app.api import crud, deps from tests.db_utils import TestSessionLocal, fill_table, get_entry from tests.utils import update_only_datetime @@ -60,7 +60,7 @@ def compare_entries(ref, test): @pytest_asyncio.fixture(scope="function") async def init_test_db(monkeypatch, test_db): monkeypatch.setattr(crud.base, "database", test_db) - monkeypatch.setattr(db, "SessionLocal", TestSessionLocal) + monkeypatch.setattr(deps, "SessionLocal", TestSessionLocal) await fill_table(test_db, db.groups, GROUP_TABLE) await fill_table(test_db, db.accesses, ACCESS_TABLE) await fill_table(test_db, db.sites, SITE_TABLE_FOR_DB) diff --git a/src/tests/routes/test_users.py b/src/tests/routes/test_users.py index a6982a9a..40291707 100644 --- a/src/tests/routes/test_users.py +++ b/src/tests/routes/test_users.py @@ -5,7 +5,7 @@ import pytest_asyncio from app import db -from app.api import crud, security +from app.api import crud, deps, security from tests.db_utils import TestSessionLocal, fill_table, get_entry from tests.utils import parse_time, update_only_datetime @@ -32,7 +32,7 @@ @pytest_asyncio.fixture(scope="function") async def init_test_db(monkeypatch, test_db): monkeypatch.setattr(crud.base, "database", test_db) - monkeypatch.setattr(db, "SessionLocal", TestSessionLocal) + monkeypatch.setattr(deps, "SessionLocal", TestSessionLocal) await fill_table(test_db, db.groups, GROUP_TABLE) await fill_table(test_db, db.accesses, ACCESS_TABLE) await fill_table(test_db, db.users, USER_TABLE_FOR_DB) diff --git a/src/tests/routes/test_webhooks.py b/src/tests/routes/test_webhooks.py index 5fca77b9..39764bb9 100644 --- a/src/tests/routes/test_webhooks.py +++ b/src/tests/routes/test_webhooks.py @@ -4,7 +4,7 @@ import pytest_asyncio from app import db -from app.api import crud +from app.api import crud, deps from tests.db_utils import TestSessionLocal, fill_table, get_entry from tests.utils import update_only_datetime @@ -30,7 +30,7 @@ @pytest_asyncio.fixture(scope="function") async def init_test_db(monkeypatch, test_db): monkeypatch.setattr(crud.base, "database", test_db) - monkeypatch.setattr(db, "SessionLocal", TestSessionLocal) + monkeypatch.setattr(deps, "SessionLocal", TestSessionLocal) await fill_table(test_db, db.groups, GROUP_TABLE) await fill_table(test_db, db.accesses, ACCESS_TABLE) await fill_table(test_db, db.webhooks, WEBHOOK_TABLE_FOR_DB) diff --git a/src/tests/test_deps.py b/src/tests/test_deps.py index 4c1afa7b..f7dfc3d7 100644 --- a/src/tests/test_deps.py +++ b/src/tests/test_deps.py @@ -5,7 +5,7 @@ from app import db from app.api import crud, deps, security -from app.api.schemas import AccessRead, DeviceOut, UserRead +from app.schemas import AccessRead, DeviceOut, UserRead from tests.db_utils import fill_table from tests.utils import update_only_datetime