diff --git a/app/songs/handlers.py b/app/songs/handlers.py index ae6cc4b..85e044c 100644 --- a/app/songs/handlers.py +++ b/app/songs/handlers.py @@ -1,4 +1,4 @@ -from fastapi import APIRouter, Depends +from fastapi import APIRouter, Depends, Query from fastapi.security import OAuth2PasswordBearer from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession @@ -6,21 +6,31 @@ from ..db import get_db_session from ..redis import r +from ..utils.pagination import paginate from .models import City, Song, Tag -from .schemas import CityCreate, CityRead, SongCreate, SongRead, TagCreate, TagRead +from .schemas import ( + CityCreate, + CityRead, + SongCreate, + TagCreate, + TagRead, +) router = APIRouter() oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") -@router.get("/songs", response_model=list[SongRead]) -async def get_songs(session: AsyncSession = Depends(get_db_session)): - result = await session.scalars( - select(Song).options(selectinload(Song.tags), selectinload(Song.city)) - ) - songs = result.all() - return songs +@router.get("/songs") +async def get_songs( + page: int = Query(1, ge=1), + per_page: int = Query(5, le=100), + session: AsyncSession = Depends(get_db_session), +): + query = select(Song).options(selectinload(Song.tags), selectinload(Song.city)) + items, pagination = await paginate(session, query, page, per_page) + + return {"meta": pagination, "data": items} @router.post("/songs") diff --git a/app/songs/schemas.py b/app/songs/schemas.py index f919816..2ea790b 100644 --- a/app/songs/schemas.py +++ b/app/songs/schemas.py @@ -50,3 +50,8 @@ class SongRead(BaseModel): # class Config: # from_attributes = True + + +class PaginatedSong(BaseModel): + data: list[SongRead] + meta: dict[str, int | None] diff --git a/app/utils/pagination.py b/app/utils/pagination.py new file mode 100644 index 0000000..71f8a61 --- /dev/null +++ b/app/utils/pagination.py @@ -0,0 +1,43 @@ +from fastapi import HTTPException +from sqlalchemy import func, select +from sqlalchemy.ext.asyncio import AsyncSession + + +async def paginate( + session: AsyncSession, + query, + page: int = 1, + per_page: int = 10, +): + """ + Paginate a SQLModel query. + :param session: The SQLAlchemy session instance. + :param query: The SQLModel query to paginate. + :param page: The page number to retrieve. + :param per_page: The number of items per page. + :return: Tuple containing the paginated data and pagination information. + """ + total_items = await session.execute(select(func.count()).select_from(query)) + count = total_items.scalar() + + if count == 0: + return [], {"total_pages": 0, "current_page": 0, "next_page": None} + + total_pages = (count - 1) // per_page + 1 + + if page > total_pages: + raise HTTPException(status_code=404, detail="Page not found") + + offset = (page - 1) * per_page + items = await session.execute(query.offset(offset).limit(per_page)) + result = items.scalars().all() + + next_page = page + 1 if page < total_pages else None + + pagination = { + "current_page": page, + "total_pages": total_pages, + "next_page": next_page, + } + + return result, pagination