Skip to content

Commit

Permalink
Use status code constants
Browse files Browse the repository at this point in the history
  • Loading branch information
hluk committed Sep 15, 2023
1 parent fa16595 commit e50acda
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 30 deletions.
26 changes: 15 additions & 11 deletions product_listings_manager/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import logging

import gssapi
from fastapi import HTTPException
from fastapi import HTTPException, status

logger = logging.getLogger(__name__)

Expand All @@ -25,7 +25,8 @@ def process_gssapi_request(token):
if not sc.complete:
logger.error("Multiple GSSAPI round trips not supported")
raise HTTPException(
status_code=403, detail="Attempted multiple GSSAPI round trips"
status_code=status.HTTP_403_FORBIDDEN,
detail="Attempted multiple GSSAPI round trips",
)

logger.debug("Completed GSSAPI negotiation")
Expand All @@ -39,31 +40,33 @@ def process_gssapi_request(token):
stage,
e.gen_message(),
)
raise HTTPException(status_code=403, detail="Authentication failed")
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Authentication failed",
)


def get_user(request):
return get_user_by_method(request, "Kerberos")


def get_user_by_method(request, auth_method):
if "Authorization" not in request.headers:
raise HTTPException(
status_code=401, headers={"WWW-Authenticate": "Negotiate"}
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Could not validate credentials",
headers={"WWW-Authenticate": "Negotiate"},
)

header = request.headers.get("Authorization")
scheme, *rest = header.strip().split(maxsplit=1)

if scheme != "Negotiate":
raise HTTPException(
status_code=401,
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Unsupported authentication scheme; supported is Negotiate",
)

if not rest or not rest[0]:
raise HTTPException(
status_code=401, detail="Missing authentication token"
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Missing authentication token",
)

token = rest[0]
Expand All @@ -72,7 +75,8 @@ def get_user_by_method(request, auth_method):
user, token = process_gssapi_request(base64.b64decode(token))
except binascii.Error:
raise HTTPException(
status_code=401, detail="Invalid authentication token"
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid authentication token",
)

token = base64.b64encode(token).decode("utf-8")
Expand Down
8 changes: 5 additions & 3 deletions product_listings_manager/authorization.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from dataclasses import dataclass

import ldap
from fastapi import HTTPException
from fastapi import HTTPException, status

log = logging.getLogger(__name__)

Expand Down Expand Up @@ -37,10 +37,12 @@ def get_user_groups(
except ldap.SERVER_DOWN:
log.exception("The LDAP server is unreachable")
raise HTTPException(
status_code=502, detail="The LDAP server is unreachable"
status_code=status.HTTP_502_BAD_GATEWAY,
detail="The LDAP server is unreachable",
)
except ldap.LDAPError:
log.exception("Unexpected LDAP connection error")
raise HTTPException(
status_code=502, detail="Unexpected LDAP connection error"
status_code=status.HTTP_502_BAD_GATEWAY,
detail="Unexpected LDAP connection error",
)
5 changes: 3 additions & 2 deletions product_listings_manager/db_queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import logging
from typing import Any

from fastapi import HTTPException
from fastapi import HTTPException, status
from sqlalchemy import text
from sqlalchemy.exc import ResourceClosedError, SQLAlchemyError

Expand All @@ -22,7 +22,8 @@ def execute_queries(db, queries: list[SqlQuery]) -> list[list[Any]]:
except SQLAlchemyError as e:
logger.warning("Failed DB query for user %s", e)
raise HTTPException(
status_code=400, detail=f"DB query failed: {e}"
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"DB query failed: {e}",
)

db.commit()
Expand Down
49 changes: 35 additions & 14 deletions product_listings_manager/rest_api_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import os
from functools import lru_cache

from fastapi import APIRouter, Depends, HTTPException, Request
from fastapi import APIRouter, Depends, HTTPException, Request, status
from sqlalchemy import text
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import Session
Expand Down Expand Up @@ -35,7 +35,7 @@ def ldap_config() -> LdapConfig:

if not ldap_host or not ldap_searches:
raise HTTPException(
status_code=500,
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Server configuration LDAP_HOST and LDAP_SEARCHES is required.",
)

Expand Down Expand Up @@ -104,21 +104,27 @@ def health(db: Session = Depends(get_db)):
except Exception as e:
logger.error("Failed to parse permissions configuration: %s", e)
raise HTTPException(
status_code=503,
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail=f"Failed to parse permissions configuration: {e}",
)

try:
db.execute(text("SELECT 1"))
except SQLAlchemyError as e:
logger.warning("DB health check failed: %s", e)
raise HTTPException(status_code=503, detail=f"DB Error: {e}")
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail=f"DB Error: {e}",
)

try:
products.get_koji_session().getAPIVersion()
except Exception as e:
logger.warning("Koji health check failed: %s", e)
raise HTTPException(status_code=503, detail=f"Koji Error: {e}")
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail=f"Koji Error: {e}",
)

return Message(message=HEALTH_OK_MESSAGE)

Expand All @@ -138,12 +144,16 @@ def product_info(label: str, request: Request, db: Session = Depends(get_db)):
try:
versions, variants = products.get_product_info(db, label)
except products.ProductListingsNotFoundError as ex:
raise HTTPException(status_code=404, detail=str(ex))
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail=str(ex)
)
except Exception as ex:
utils.log_remote_call_error(
request, "API call get_product_info() failed", label
)
raise HTTPException(status_code=500, detail=str(ex))
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(ex)
)
return [versions, variants]


Expand All @@ -155,7 +165,9 @@ def product_labels(request: Request, db: Session = Depends(get_db)):
utils.log_remote_call_error(
request, "API call get_product_labels() failed"
)
raise HTTPException(status_code=500, detail=str(ex))
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(ex)
)


@router.get("/product-listings/{label}/{build_info}")
Expand All @@ -168,15 +180,19 @@ def product_listings(
try:
return products.get_product_listings(db, label, build_info)
except products.ProductListingsNotFoundError as ex:
raise HTTPException(status_code=404, detail=str(ex))
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail=str(ex)
)
except Exception as ex:
utils.log_remote_call_error(
request,
"API call get_product_listings() failed",
label,
build_info,
)
raise HTTPException(status_code=500, detail=str(ex))
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(ex)
)


@router.get("/module-product-listings/{label}/{module_build_nvr}")
Expand All @@ -191,15 +207,19 @@ def module_product_listings(
db, label, module_build_nvr
)
except products.ProductListingsNotFoundError as ex:
raise HTTPException(status_code=404, detail=str(ex))
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail=str(ex)
)
except Exception as ex:
utils.log_remote_call_error(
request,
"API call get_module_product_listings() failed",
label,
module_build_nvr,
)
raise HTTPException(status_code=500, detail=str(ex))
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(ex)
)


@router.get("/permissions")
Expand Down Expand Up @@ -227,7 +247,8 @@ async def dbquery(
"""
if not query_or_queries:
raise HTTPException(
status_code=422, detail="Queries must not be empty"
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail="Queries must not be empty",
)

ldap_config_ = ldap_config()
Expand All @@ -247,7 +268,7 @@ async def dbquery(
if not has_permission(user, queries, permissions(), ldap_config_):
logger.warning("Unauthorized access for user %s", user)
raise HTTPException(
status_code=401,
status_code=status.HTTP_401_UNAUTHORIZED,
detail=f"User {user} is not authorized to use this query",
)

Expand Down

0 comments on commit e50acda

Please sign in to comment.