Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ Create a `.env` file in the root with the following values inside it (each varia
- `OAUTH2_CLIENT_ID`: Client ID of Discord OAuth2 Application (see prerequisites).
- `OAUTH2_CLIENT_SECRET`: Client Secret of Discord OAuth2 Application (see prerequisites).
- `ALLOWED_URL`: Allowed origin for CORS middleware.
- `PRODUCTION`: Set to False if running on localhost. Defaults to true.
Comment thread
HassanAbouelela marked this conversation as resolved.

#### Running
To start using the application, simply run `docker-compose up` in the repository root. You'll be able to access the application by visiting http://localhost:8000/
Expand Down
21 changes: 16 additions & 5 deletions backend/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,21 @@

from backend import constants
from backend.authentication import JWTAuthenticationBackend
from backend.route_manager import create_route_map
from backend.middleware import DatabaseMiddleware, ProtectedDocsMiddleware
from backend.route_manager import create_route_map
from backend.validation import api

ORIGINS = [
r"(https://[^.?#]*--pydis-forms\.netlify\.app)", # Netlify Previews
r"(https?://[^.?#]*.forms-frontend.pages.dev)", # Cloudflare Previews
]

if not constants.PRODUCTION:
Comment thread
HassanAbouelela marked this conversation as resolved.
# Allow all hosts on non-production deployments
ORIGINS.append(r"(.*)")

ALLOW_ORIGIN_REGEX = "|".join(ORIGINS)

sentry_sdk.init(
dsn=constants.FORMS_BACKEND_DSN,
send_default_pii=True,
Expand All @@ -20,13 +31,13 @@
middleware = [
Middleware(
CORSMiddleware,
# TODO: Convert this into a RegEx that works for prod, netlify & previews
allow_origins=["*"],
allow_origins=["https://forms.pythondiscord.com"],
allow_origin_regex=ALLOW_ORIGIN_REGEX,
allow_headers=[
"Authorization",
"Content-Type"
],
allow_methods=["*"]
allow_methods=["*"],
allow_credentials=True
),
Middleware(DatabaseMiddleware),
Middleware(AuthenticationMiddleware, backend=JWTAuthenticationBackend()),
Expand Down
37 changes: 26 additions & 11 deletions backend/authentication/backend.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import jwt
import typing as t

import jwt
from starlette import authentication
from starlette.requests import Request

Expand All @@ -13,18 +13,18 @@ class JWTAuthenticationBackend(authentication.AuthenticationBackend):
"""Custom Starlette authentication backend for JWT."""

@staticmethod
def get_token_from_header(header: str) -> str:
"""Parse JWT token from header value."""
def get_token_from_cookie(cookie: str) -> str:
"""Parse JWT token from cookie."""
try:
prefix, token = header.split()
prefix, token = cookie.split()
except ValueError:
raise authentication.AuthenticationError(
"Unable to split prefix and token from Authorization header."
"Unable to split prefix and token from authorization cookie."
)

if prefix.upper() != "JWT":
raise authentication.AuthenticationError(
f"Invalid Authorization header prefix '{prefix}'."
f"Invalid authorization cookie prefix '{prefix}'."
)

return token
Expand All @@ -33,11 +33,11 @@ async def authenticate(
self, request: Request
) -> t.Optional[tuple[authentication.AuthCredentials, authentication.BaseUser]]:
"""Handles JWT authentication process."""
if "Authorization" not in request.headers:
cookie = request.cookies.get("token")
if not cookie:
return None

auth = request.headers["Authorization"]
token = self.get_token_from_header(auth)
token = self.get_token_from_cookie(cookie)

try:
payload = jwt.decode(token, constants.SECRET_KEY, algorithms=["HS256"])
Expand All @@ -46,7 +46,22 @@ async def authenticate(

scopes = ["authenticated"]

if payload.get("admin") is True:
if not payload.get("token"):
raise authentication.AuthenticationError("Token is missing from JWT.")
if not payload.get("refresh"):
raise authentication.AuthenticationError(
"Refresh token is missing from JWT."
)

try:
user_details = payload.get("user_details")
if not user_details or not user_details.get("id"):
raise authentication.AuthenticationError("Improper user details.")
except Exception:
raise authentication.AuthenticationError("Could not parse user details.")

user = User(token, user_details)
if await user.fetch_admin_status(request):
scopes.append("admin")

return authentication.AuthCredentials(scopes), User(token, payload)
return authentication.AuthCredentials(scopes), user
26 changes: 26 additions & 0 deletions backend/authentication/user.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
import typing as t

import jwt
from starlette.authentication import BaseUser
from starlette.requests import Request

from backend.constants import SECRET_KEY
from backend.discord import fetch_user_details


class User(BaseUser):
Expand All @@ -9,6 +14,7 @@ class User(BaseUser):
def __init__(self, token: str, payload: dict[str, t.Any]) -> None:
self.token = token
self.payload = payload
self.admin = False

@property
def is_authenticated(self) -> bool:
Expand All @@ -23,3 +29,23 @@ def display_name(self) -> str:
@property
def discord_mention(self) -> str:
return f"<@{self.payload['id']}>"

@property
def decoded_token(self) -> dict[str, any]:
return jwt.decode(self.token, SECRET_KEY, algorithms=["HS256"])

async def fetch_admin_status(self, request: Request) -> bool:
self.admin = await request.state.db.admins.find_one(
{"_id": self.payload["id"]}
) is not None

return self.admin

async def refresh_data(self) -> None:
"""Fetches user data from discord, and updates the instance."""
self.payload = await fetch_user_details(self.decoded_token.get("token"))

updated_info = self.decoded_token
updated_info["user_details"] = self.payload

self.token = jwt.encode(updated_info, SECRET_KEY, algorithm="HS256")
2 changes: 2 additions & 0 deletions backend/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
MONGO_DATABASE = os.getenv("MONGO_DATABASE", "pydis_forms")
SNEKBOX_URL = os.getenv("SNEKBOX_URL", "http://snekbox.default.svc.cluster.local/eval")

PRODUCTION = os.getenv("PRODUCTION", "True").lower() != "false"

OAUTH2_CLIENT_ID = os.getenv("OAUTH2_CLIENT_ID")
OAUTH2_CLIENT_SECRET = os.getenv("OAUTH2_CLIENT_SECRET")
OAUTH2_REDIRECT_URI = os.getenv(
Expand Down
15 changes: 10 additions & 5 deletions backend/discord.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,27 @@
import httpx

from backend.constants import (
OAUTH2_CLIENT_ID, OAUTH2_CLIENT_SECRET, OAUTH2_REDIRECT_URI
OAUTH2_CLIENT_ID, OAUTH2_CLIENT_SECRET
)

API_BASE_URL = "https://discord.com/api/v8"


async def fetch_bearer_token(access_code: str) -> dict:
async def fetch_bearer_token(code: str, redirect: str, *, refresh: bool) -> dict:
Comment thread
HassanAbouelela marked this conversation as resolved.
async with httpx.AsyncClient() as client:
data = {
"client_id": OAUTH2_CLIENT_ID,
"client_secret": OAUTH2_CLIENT_SECRET,
"grant_type": "authorization_code",
"code": access_code,
"redirect_uri": OAUTH2_REDIRECT_URI
"redirect_uri": f"{redirect}/callback"
}

if refresh:
data["grant_type"] = "refresh_token"
data["refresh_token"] = code
else:
data["grant_type"] = "authorization_code"
data["code"] = code

r = await client.post(f"{API_BASE_URL}/oauth2/token", headers={
"Content-Type": "application/x-www-form-urlencoded"
}, data=data)
Expand Down
87 changes: 74 additions & 13 deletions backend/routes/auth/authorize.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,26 +2,71 @@
Use a token received from the Discord OAuth2 system to fetch user information.
"""

import datetime
from typing import Union

import httpx
import jwt
from pydantic.fields import Field
from pydantic.main import BaseModel
from spectree.response import Response
from starlette.authentication import requires
from starlette.requests import Request
from starlette.responses import JSONResponse

from backend import constants
from backend.authentication.user import User
from backend.constants import SECRET_KEY
from backend.route import Route
from backend.discord import fetch_bearer_token, fetch_user_details
from backend.route import Route
from backend.validation import ErrorMessage, api

AUTH_FAILURE = JSONResponse({"error": "auth_failure"}, status_code=400)


class AuthorizeRequest(BaseModel):
token: str = Field(description="The access token received from Discord.")


class AuthorizeResponse(BaseModel):
token: str = Field(description="A JWT token containing the user information")
username: str = Field("Discord display name.")
expiry: str = Field("ISO formatted timestamp of expiry.")


async def process_token(bearer_token: dict) -> Union[AuthorizeResponse, AUTH_FAILURE]:
"""Post a bearer token to Discord, and return a JWT and username."""
interaction_start = datetime.datetime.now()

try:
user_details = await fetch_user_details(bearer_token["access_token"])
except httpx.HTTPStatusError:
AUTH_FAILURE.delete_cookie("token")
return AUTH_FAILURE

max_age = datetime.timedelta(seconds=int(bearer_token["expires_in"]))
token_expiry = interaction_start + max_age

data = {
"token": bearer_token["access_token"],
"refresh": bearer_token["refresh_token"],
"user_details": user_details,
"expiry": token_expiry.isoformat()
}

token = jwt.encode(data, SECRET_KEY, algorithm="HS256")
user = User(token, user_details)

response = JSONResponse({
"username": user.display_name,
"expiry": token_expiry.isoformat()
})

response.set_cookie(
"token", f"JWT {token}",
secure=constants.PRODUCTION, httponly=True, samesite="strict",
max_age=bearer_token["expires_in"]
)
return response


class AuthorizeRoute(Route):
Expand All @@ -40,19 +85,35 @@ class AuthorizeRoute(Route):
async def post(self, request: Request) -> JSONResponse:
"""Generate an authorization token."""
data = await request.json()

try:
bearer_token = await fetch_bearer_token(data["token"])
user_details = await fetch_user_details(bearer_token["access_token"])
url = request.headers.get("origin")
bearer_token = await fetch_bearer_token(data["token"], url, refresh=False)
except httpx.HTTPStatusError:
return JSONResponse({
"error": "auth_failure"
}, status_code=400)
return AUTH_FAILURE

return await process_token(bearer_token)

user_details["admin"] = await request.state.db.admins.find_one(
{"_id": user_details["id"]}
) is not None

token = jwt.encode(user_details, SECRET_KEY, algorithm="HS256")
class TokenRefreshRoute(Route):
"""
Use the refresh code from a JWT to get a new token and generate a new JWT token.
"""

name = "refresh"
path = "/refresh"

@requires(["authenticated"])
@api.validate(
resp=Response(HTTP_200=AuthorizeResponse, HTTP_400=ErrorMessage),
tags=["auth"]
)
async def post(self, request: Request) -> JSONResponse:
"""Refresh an authorization token."""
try:
token = request.user.decoded_token.get("refresh")
url = request.headers.get("origin")
bearer_token = await fetch_bearer_token(token, url, refresh=True)
except httpx.HTTPStatusError:
return AUTH_FAILURE

return JSONResponse({"token": token})
return await process_token(bearer_token)
2 changes: 1 addition & 1 deletion backend/routes/forms/form.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class SingleForm(Route):
@api.validate(resp=Response(HTTP_200=Form, HTTP_404=ErrorMessage), tags=["forms"])
async def get(self, request: Request) -> JSONResponse:
"""Returns single form information by ID."""
admin = request.user.payload["admin"] if request.user.is_authenticated else False
admin = request.user.admin if request.user.is_authenticated else False

filters = {
"_id": request.path_params["form_id"]
Expand Down
Loading