-
First Check
Commit to Help
Example Codefrom datetime import datetime, timedelta, timezone
from typing import Annotated
import jwt # type: ignore
from fastapi import Depends, FastAPI, HTTPException, status, APIRouter
from sqlmodel import Session, SQLModel, Field, select
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from jwt.exceptions import InvalidTokenError # type: ignore
from passlib.context import CryptContext # type: ignore
from pydantic import BaseModel
from api.database import get_session
# to get a string like this run:
# openssl rand -hex 32
SECRET_KEY = "09d25e094faa6ca2556c818166b7a9563b93f7099f6f0f4caa6cf63b88e8d3e7"
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 30
class Token(BaseModel):
access_token: str
token_type: str
class TokenData(BaseModel):
username: str | None
class User(BaseModel):
username: str
email: str | None = None
full_name: str | None = None
disabled: bool | None = None
class UserInDB(User):
hashed_password: str
# SQLModel for user db
class AuthUserBase(SQLModel):
UserName: str = Field(max_length=20)
DisplayName: str = Field(max_length=50)
Enabled: bool
class AuthUser(AuthUserBase, table=True):
UserId: int | None = Field(default=None, primary_key=True)
Password: str = Field(max_length=60)
class AuthUserCreate(AuthUserBase):
pass
class AuthUserRead(AuthUserBase):
Password: str
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
app = FastAPI()
def verify_password(plain_password, hashed_password):
return pwd_context.verify(plain_password, hashed_password)
def get_password_hash(password):
return pwd_context.hash(password)
def get_user(db, username: str):
user: AuthUser = db.exec(select(AuthUser).where(AuthUser.UserName==username)).one_or_none()
if user:
return user
def authenticate_user(db, username: str, password: str):
user = get_user(db, username)
if not user:
return False
if not verify_password(password, user.Password):
return False
return user
def create_access_token(data: dict, expires_delta: timedelta | None = None):
to_encode = data.copy()
if expires_delta:
expire = datetime.now(timezone.utc) + expires_delta
else:
expire = datetime.now(timezone.utc) + timedelta(minutes=15)
to_encode.update({"exp": expire})
encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
return encoded_jwt
async def get_current_user(token: Annotated[str, Depends(oauth2_scheme)]):
credentials_exception = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Could not validate credentials",
headers={"WWW-Authenticate": "Bearer"},
)
try:
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
username: str = payload.get("sub")
if username is None:
raise credentials_exception
token_data = TokenData(username=username)
except InvalidTokenError:
raise credentials_exception
user = get_user(real_db, username=token_data.username)
if user is None:
raise credentials_exception
return user
async def get_current_active_user(
current_user: Annotated[User, Depends(get_current_user)],
):
if current_user.disabled:
raise HTTPException(status_code=400, detail="Inactive user")
return current_user
router = APIRouter()
@router.post("/token")
async def login_for_access_token(
form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
db: Session = Depends(get_session)
) -> Token:
user = authenticate_user(db, form_data.username, form_data.password)
if not user:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Incorrect username or password",
headers={"WWW-Authenticate": "Bearer"},
)
access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
access_token = create_access_token(
data={"sub": user.UserName}, expires_delta=access_token_expires
)
return Token(access_token=access_token, token_type="bearer")
@router.get("/users/me/", response_model=User)
async def read_users_me(
current_user: Annotated[User, Depends(get_current_active_user)],
):
return current_user
@router.get("/users/me/items/")
async def read_own_items(
current_user: Annotated[User, Depends(get_current_active_user)],
):
return [{"item_id": "Foo", "owner": current_user.username}] DescriptionWorking through the tutorial for OAuth2 and am having a difficult time ripping out the fake_users_db and plumbing in my own database. I have it working for /token endpoint because I can pass the db Session through to authenticate_user and get_user. I cannot figure out how to do the same for the two other endpoints. Operating SystemLinux Operating System DetailsNo response FastAPI Version0.92.0 Pydantic Version1.10.4 Python Version3.11.2 Additional ContextTutorial: https://fastapi.tiangolo.com/tutorial/security/oauth2-jwt/ |
Beta Was this translation helpful? Give feedback.
Replies: 1 comment 1 reply
-
Problem AnalysisThe core issue is how to pass the database session in FastAPI dependencies so that it can be used in other endpoints. For example, in the SolutionWe need to modify the
Try the following codefrom typing import Annotated
from jwt.exceptions import InvalidTokenError
from fastapi import Depends, FastAPI, HTTPException, status, APIRouter
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from sqlmodel import Session, select
from pydantic import BaseModel
from api.database import get_session
SECRET_KEY = "09d25e094faa6ca2556c818166b7a9563b93f7099f6f0f4caa6cf63b88e8d3e7"
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 30
class Token(BaseModel):
access_token: str
token_type: str
class TokenData(BaseModel):
username: str | None
class User(BaseModel):
username: str
email: str | None = None
full_name: str | None = None
disabled: bool | None = None
class UserInDB(User):
hashed_password: str
class AuthUserBase(SQLModel):
UserName: str = Field(max_length=20)
DisplayName: str = Field(max_length=50)
Enabled: bool
class AuthUser(AuthUserBase, table=True):
UserId: int | None = Field(default=None, primary_key=True)
Password: str = Field(max_length=60)
class AuthUserCreate(AuthUserBase):
pass
class AuthUserRead(AuthUserBase):
Password: str
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
app = FastAPI()
def verify_password(plain_password, hashed_password):
return pwd_context.verify(plain_password, hashed_password)
def get_password_hash(password):
return pwd_context.hash(password)
def get_user(db: Session, username: str):
user: AuthUser = db.exec(select(AuthUser).where(AuthUser.UserName == username)).one_or_none()
return user
def authenticate_user(db: Session, username: str, password: str):
user = get_user(db, username)
if not user:
return False
if not verify_password(password, user.Password):
return False
return user
def create_access_token(data: dict, expires_delta: timedelta | None = None):
to_encode = data.copy()
if expires_delta:
expire = datetime.now(timezone.utc) + expires_delta
else:
expire = datetime.now(timezone.utc) + timedelta(minutes=15)
to_encode.update({"exp": expire})
encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
return encoded_jwt
async def get_current_user(token: Annotated[str, Depends(oauth2_scheme)], db: Session = Depends(get_session)):
credentials_exception = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Could not validate credentials",
headers={"WWW-Authenticate": "Bearer"},
)
try:
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
username: str = payload.get("sub")
if username is None:
raise credentials_exception
token_data = TokenData(username=username)
except InvalidTokenError:
raise credentials_exception
user = get_user(db, username=token_data.username)
if user is None:
raise credentials_exception
return user
async def get_current_active_user(current_user: Annotated[User, Depends(get_current_user)]):
if current_user.disabled:
raise HTTPException(status_code=400, detail="Inactive user")
return current_user
router = APIRouter()
@router.post("/token")
async def login_for_access_token(
form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
db: Session = Depends(get_session)
) -> Token:
user = authenticate_user(db, form_data.username, form_data.password)
if not user:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Incorrect username or password",
headers={"WWW-Authenticate": "Bearer"},
)
access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
access_token = create_access_token(
data={"sub": user.UserName}, expires_delta=access_token_expires
)
return Token(access_token=access_token, token_type="bearer")
@router.get("/users/me/", response_model=User)
async def read_users_me(
current_user: Annotated[User, Depends(get_current_active_user)],
):
return current_user
@router.get("/users/me/items/")
async def read_own_items(
current_user: Annotated[User, Depends(get_current_active_user)],
):
return [{"item_id": "Foo", "owner": current_user.username}]
app.include_router(router) |
Beta Was this translation helpful? Give feedback.
Problem Analysis
The core issue is how to pass the database session in FastAPI dependencies so that it can be used in other endpoints. For example, in the
get_current_user
function, thedb
session needs to be passed.Solution
We need to modify the
get_current_user
and other related functions so that they can receive the database session as a dependency. The specific steps are as follows:Define a dependency to get the database session:
Ensure you have a dependency function to get the database session, such as
get_session
.Modify the
get_current_user
function:Add the database session as a parameter and use FastAPI's
Depends
dependency injection mechanism to get the session.Modify o…