Skip to content

Commit

Permalink
✨ Add base class to simplify CRUD (#23)
Browse files Browse the repository at this point in the history
  • Loading branch information
ebreton authored and tiangolo committed Jan 19, 2020
1 parent 1c975c7 commit ab46165
Show file tree
Hide file tree
Showing 33 changed files with 321 additions and 282 deletions.
2 changes: 1 addition & 1 deletion test.sh
Expand Up @@ -9,6 +9,6 @@ cookiecutter --config-file ./testing-config.yml --no-input -f ./

cd ./testing-project

bash ./scripts/test.sh
bash ./scripts/test.sh "$@"

cd ../
4 changes: 2 additions & 2 deletions {{cookiecutter.project_slug}}/README.md
Expand Up @@ -55,7 +55,7 @@ If your Docker is not running in `localhost` (the URLs above wouldn't work) chec

Open your editor at `./backend/app/` (instead of the project root: `./`), so that you see an `./app/` directory with your code inside. That way, your editor will be able to find all the imports, etc.

Modify or add SQLAlchemy models in `./backend/app/app/db_models/`, Pydantic models in `./backend/app/app/models/`, API endpoints in `./backend/app/app/api/`, CRUD (Create, Read, Update, Delete) utils in `./backend/app/app/crud/`. The easiest might be to copy the ones for Items (models, endpoints, and CRUD utils) and update them to your needs.
Modify or add SQLAlchemy models in `./backend/app/app/models/`, Pydantic schemas in `./backend/app/app/schemas/`, API endpoints in `./backend/app/app/api/`, CRUD (Create, Read, Update, Delete) utils in `./backend/app/app/crud/`. The easiest might be to copy the ones for Items (models, endpoints, and CRUD utils) and update them to your needs.

Add and modify tasks to the Celery worker in `./backend/app/app/worker.py`.

Expand Down Expand Up @@ -205,7 +205,7 @@ Make sure you create a "revision" of your models and that you "upgrade" your dat
docker-compose exec backend bash
```

* If you created a new model in `./backend/app/app/db_models/`, make sure to import it in `./backend/app/app/db/base.py`, that Python module (`base.py`) that imports all the models will be used by Alembic.
* If you created a new model in `./backend/app/app/models/`, make sure to import it in `./backend/app/app/db/base.py`, that Python module (`base.py`) that imports all the models will be used by Alembic.

* After changing a model (for example, adding a column), inside the container, create a revision, e.g.:

Expand Down
Expand Up @@ -6,8 +6,8 @@
from app import crud
from app.api.utils.db import get_db
from app.api.utils.security import get_current_active_user
from app.db_models.user import User as DBUser
from app.models.item import Item, ItemCreate, ItemUpdate
from app.models.user import User as DBUser
from app.schemas.item import Item, ItemCreate, ItemUpdate

router = APIRouter()

Expand Down Expand Up @@ -41,7 +41,9 @@ def create_item(
"""
Create new item.
"""
item = crud.item.create(db_session=db, item_in=item_in, owner_id=current_user.id)
item = crud.item.create_with_owner(
db_session=db, obj_in=item_in, owner_id=current_user.id
)
return item


Expand All @@ -61,7 +63,7 @@ def update_item(
raise HTTPException(status_code=404, detail="Item not found")
if not crud.user.is_superuser(current_user) and (item.owner_id != current_user.id):
raise HTTPException(status_code=400, detail="Not enough permissions")
item = crud.item.update(db_session=db, item=item, item_in=item_in)
item = crud.item.update(db_session=db, db_obj=item, obj_in=item_in)
return item


Expand Down
Expand Up @@ -10,10 +10,10 @@
from app.core import config
from app.core.jwt import create_access_token
from app.core.security import get_password_hash
from app.db_models.user import User as DBUser
from app.models.msg import Msg
from app.models.token import Token
from app.models.user import User
from app.models.user import User as DBUser
from app.schemas.msg import Msg
from app.schemas.token import Token
from app.schemas.user import User
from app.utils import (
generate_password_reset_token,
send_reset_password_email,
Expand Down
Expand Up @@ -2,15 +2,15 @@

from fastapi import APIRouter, Body, Depends, HTTPException
from fastapi.encoders import jsonable_encoder
from pydantic.types import EmailStr
from pydantic.networks import EmailStr
from sqlalchemy.orm import Session

from app import crud
from app.api.utils.db import get_db
from app.api.utils.security import get_current_active_superuser, get_current_active_user
from app.core import config
from app.db_models.user import User as DBUser
from app.models.user import User, UserCreate, UserInDB, UserUpdate
from app.models.user import User as DBUser
from app.schemas.user import User, UserCreate, UserUpdate
from app.utils import send_new_account_email

router = APIRouter()
Expand Down Expand Up @@ -46,7 +46,7 @@ def create_user(
status_code=400,
detail="The user with this username already exists in the system.",
)
user = crud.user.create(db, user_in=user_in)
user = crud.user.create(db, obj_in=user_in)
if config.EMAILS_ENABLED and user_in.email:
send_new_account_email(
email_to=user_in.email, username=user_in.email, password=user_in.password
Expand Down Expand Up @@ -74,7 +74,7 @@ def update_user_me(
user_in.full_name = full_name
if email is not None:
user_in.email = email
user = crud.user.update(db, user=current_user, user_in=user_in)
user = crud.user.update(db, db_obj=current_user, obj_in=user_in)
return user


Expand Down Expand Up @@ -103,7 +103,7 @@ def create_user_open(
if not config.USERS_OPEN_REGISTRATION:
raise HTTPException(
status_code=403,
detail="Open user resgistration is forbidden on this server",
detail="Open user registration is forbidden on this server",
)
user = crud.user.get_by_email(db, email=email)
if user:
Expand All @@ -112,7 +112,7 @@ def create_user_open(
detail="The user with this username already exists in the system",
)
user_in = UserCreate(password=password, email=email, full_name=full_name)
user = crud.user.create(db, user_in=user_in)
user = crud.user.create(db, obj_in=user_in)
return user


Expand All @@ -125,7 +125,7 @@ def read_user_by_id(
"""
Get a specific user by id.
"""
user = crud.user.get(db, user_id=user_id)
user = crud.user.get(db, id=user_id)
if user == current_user:
return user
if not crud.user.is_superuser(current_user):
Expand All @@ -141,16 +141,16 @@ def update_user(
db: Session = Depends(get_db),
user_id: int,
user_in: UserUpdate,
current_user: UserInDB = Depends(get_current_active_superuser),
current_user: DBUser = Depends(get_current_active_superuser),
):
"""
Update a user.
"""
user = crud.user.get(db, user_id=user_id)
user = crud.user.get(db, id=user_id)
if not user:
raise HTTPException(
status_code=404,
detail="The user with this username does not exist in the system",
)
user = crud.user.update(db, user=user, user_in=user_in)
user = crud.user.update(db, db_obj=user, obj_in=user_in)
return user
@@ -1,18 +1,19 @@
from fastapi import APIRouter, Depends
from pydantic.types import EmailStr
from pydantic.networks import EmailStr

from app.api.utils.security import get_current_active_superuser
from app.core.celery_app import celery_app
from app.models.msg import Msg
from app.models.user import UserInDB
from app.schemas.msg import Msg
from app.schemas.user import User
from app.models.user import User as DBUser
from app.utils import send_test_email

router = APIRouter()


@router.post("/test-celery/", response_model=Msg, status_code=201)
def test_celery(
msg: Msg, current_user: UserInDB = Depends(get_current_active_superuser)
msg: Msg, current_user: DBUser = Depends(get_current_active_superuser)
):
"""
Test Celery worker.
Expand All @@ -23,7 +24,7 @@ def test_celery(

@router.post("/test-email/", response_model=Msg, status_code=201)
def test_email(
email_to: EmailStr, current_user: UserInDB = Depends(get_current_active_superuser)
email_to: EmailStr, current_user: DBUser = Depends(get_current_active_superuser)
):
"""
Test emails.
Expand Down
Expand Up @@ -9,8 +9,8 @@
from app.api.utils.db import get_db
from app.core import config
from app.core.jwt import ALGORITHM
from app.db_models.user import User
from app.models.token import TokenPayload
from app.models.user import User
from app.schemas.token import TokenPayload

reusable_oauth2 = OAuth2PasswordBearer(tokenUrl="/api/v1/login/access-token")

Expand All @@ -25,7 +25,7 @@ def get_current_user(
raise HTTPException(
status_code=HTTP_403_FORBIDDEN, detail="Could not validate credentials"
)
user = crud.user.get(db, user_id=token_data.user_id)
user = crud.user.get(db, id=token_data.user_id)
if not user:
raise HTTPException(status_code=404, detail="User not found")
return user
Expand Down
11 changes: 10 additions & 1 deletion {{cookiecutter.project_slug}}/backend/app/app/crud/__init__.py
@@ -1 +1,10 @@
from . import item, user
from .crud_user import user
from .crud_item import item

# For a new basic set of CRUD operations you could just do

# from .base import CRUDBase
# from app.models.item import Item
# from app.schemas.item import ItemCreate, ItemUpdate

# item = CRUDBase[Item, ItemCreate, ItemUpdate](Item)
57 changes: 57 additions & 0 deletions {{cookiecutter.project_slug}}/backend/app/app/crud/base.py
@@ -0,0 +1,57 @@
from typing import List, Optional, Generic, TypeVar, Type

from fastapi.encoders import jsonable_encoder
from pydantic import BaseModel
from sqlalchemy.orm import Session

from app.db.base_class import Base

ModelType = TypeVar("ModelType", bound=Base)
CreateSchemaType = TypeVar("CreateSchemaType", bound=BaseModel)
UpdateSchemaType = TypeVar("UpdateSchemaType", bound=BaseModel)


class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
def __init__(self, model: Type[ModelType]):
"""
CRUD object with default methods to Create, Read, Update, Delete (CRUD).
**Parameters**
* `model`: A SQLAlchemy model class
* `schema`: A Pydantic model (schema) class
"""
self.model = model

def get(self, db_session: Session, id: int) -> Optional[ModelType]:
return db_session.query(self.model).filter(self.model.id == id).first()

def get_multi(self, db_session: Session, *, skip=0, limit=100) -> List[ModelType]:
return db_session.query(self.model).offset(skip).limit(limit).all()

def create(self, db_session: Session, *, obj_in: CreateSchemaType) -> ModelType:
obj_in_data = jsonable_encoder(obj_in)
db_obj = self.model(**obj_in_data)
db_session.add(db_obj)
db_session.commit()
db_session.refresh(db_obj)
return db_obj

def update(
self, db_session: Session, *, db_obj: ModelType, obj_in: UpdateSchemaType
) -> ModelType:
obj_data = jsonable_encoder(db_obj)
update_data = obj_in.dict(skip_defaults=True)
for field in obj_data:
if field in update_data:
setattr(db_obj, field, update_data[field])
db_session.add(db_obj)
db_session.commit()
db_session.refresh(db_obj)
return db_obj

def remove(self, db_session: Session, *, id: int) -> ModelType:
obj = db_session.query(self.model).get(id)
db_session.delete(obj)
db_session.commit()
return obj
34 changes: 34 additions & 0 deletions {{cookiecutter.project_slug}}/backend/app/app/crud/crud_item.py
@@ -0,0 +1,34 @@
from typing import List

from fastapi.encoders import jsonable_encoder
from sqlalchemy.orm import Session

from app.models.item import Item
from app.schemas.item import ItemCreate, ItemUpdate
from app.crud.base import CRUDBase


class CRUDItem(CRUDBase[Item, ItemCreate, ItemUpdate]):
def create_with_owner(
self, db_session: Session, *, obj_in: ItemCreate, owner_id: int
) -> Item:
obj_in_data = jsonable_encoder(obj_in)
db_obj = self.model(**obj_in_data, owner_id=owner_id)
db_session.add(db_obj)
db_session.commit()
db_session.refresh(db_obj)
return db_obj

def get_multi_by_owner(
self, db_session: Session, *, owner_id: int, skip=0, limit=100
) -> List[Item]:
return (
db_session.query(self.model)
.filter(Item.owner_id == owner_id)
.offset(skip)
.limit(limit)
.all()
)


item = CRUDItem(Item)
44 changes: 44 additions & 0 deletions {{cookiecutter.project_slug}}/backend/app/app/crud/crud_user.py
@@ -0,0 +1,44 @@
from typing import Optional

from sqlalchemy.orm import Session

from app.models.user import User
from app.schemas.user import UserCreate, UserUpdate
from app.core.security import verify_password, get_password_hash
from app.crud.base import CRUDBase


class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
def get_by_email(self, db_session: Session, *, email: str) -> Optional[User]:
return db_session.query(User).filter(User.email == email).first()

def create(self, db_session: Session, *, obj_in: UserCreate) -> User:
db_obj = User(
email=obj_in.email,
hashed_password=get_password_hash(obj_in.password),
full_name=obj_in.full_name,
is_superuser=obj_in.is_superuser,
)
db_session.add(db_obj)
db_session.commit()
db_session.refresh(db_obj)
return db_obj

def authenticate(
self, db_session: Session, *, email: str, password: str
) -> Optional[User]:
user = self.get_by_email(db_session, email=email)
if not user:
return None
if not verify_password(password, user.hashed_password):
return None
return user

def is_active(self, user: User) -> bool:
return user.is_active

def is_superuser(self, user: User) -> bool:
return user.is_superuser


user = CRUDUser(User)

0 comments on commit ab46165

Please sign in to comment.