Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

server: implement more advanced CORS handling #3530

Merged
merged 1 commit into from
Jun 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 27 additions & 9 deletions server/polar/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,13 @@
import structlog
from fastapi import FastAPI
from fastapi.routing import APIRoute
from starlette.middleware.cors import CORSMiddleware

from polar import receivers, worker # noqa
from polar.api import router
from polar.config import settings
from polar.exception_handlers import add_exception_handlers
from polar.health.endpoints import router as health_router
from polar.kit.cors import CORSConfig, CORSMatcherMiddleware, Scope
from polar.kit.db.postgres import (
AsyncEngine,
AsyncSessionMaker,
Expand Down Expand Up @@ -48,16 +48,34 @@


def configure_cors(app: FastAPI) -> None:
if not settings.CORS_ORIGINS:
return

app.add_middleware(
CORSMiddleware,
allow_origins=[str(origin) for origin in settings.CORS_ORIGINS],
allow_credentials=True,
configs: list[CORSConfig] = []

# Polar frontend CORS configuration
if settings.CORS_ORIGINS:

def polar_frontend_matcher(origin: str, scope: Scope) -> bool:
return origin in settings.CORS_ORIGINS

polar_frontend_config = CORSConfig(
polar_frontend_matcher,
allow_origins=[str(origin) for origin in settings.CORS_ORIGINS],
allow_credentials=True, # Cookies are allowed, but only there!
allow_methods=["*"],
allow_headers=["*"],
)
configs.append(polar_frontend_config)

# External API calls CORS configuration
api_config = CORSConfig(
lambda origin, scope: True,
allow_origins=["*"],
allow_credentials=False, # No cookies allowed
allow_methods=["*"],
allow_headers=["*"],
allow_headers=["Authorization"], # Allow Authorization header to pass tokens
)
configs.append(api_config)

app.add_middleware(CORSMatcherMiddleware, configs=configs)


def generate_unique_openapi_id(route: APIRoute) -> str:
Expand Down
78 changes: 78 additions & 0 deletions server/polar/kit/cors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
import dataclasses
from collections.abc import Sequence
from typing import Protocol

from starlette.datastructures import Headers
from starlette.middleware.cors import CORSMiddleware
from starlette.types import ASGIApp, Receive, Scope, Send


class CORSMatcher(Protocol):
def __call__(self, origin: str, scope: Scope) -> bool: ...


@dataclasses.dataclass
class CORSConfig:
matcher: CORSMatcher
allow_origins: Sequence[str] = ()
allow_methods: Sequence[str] = ("GET",)
allow_headers: Sequence[str] = ()
allow_credentials: bool = False
allow_origin_regex: str | None = None
expose_headers: Sequence[str] = ()
max_age: int = 600

def get_middleware(self, app: ASGIApp) -> CORSMiddleware:
return CORSMiddleware(
app=app,
allow_origins=self.allow_origins,
allow_methods=self.allow_methods,
allow_headers=self.allow_headers,
allow_credentials=self.allow_credentials,
allow_origin_regex=self.allow_origin_regex,
expose_headers=self.expose_headers,
max_age=self.max_age,
)


class CORSMatcherMiddleware:
def __init__(self, app: ASGIApp, *, configs: Sequence[CORSConfig]) -> None:
self.app = app
self.config_middlewares = tuple(
(config, config.get_middleware(app)) for config in configs
)

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if scope["type"] != "http": # pragma: no cover
await self.app(scope, receive, send)
return

method = scope["method"]
headers = Headers(scope=scope)
origin = headers.get("origin")

if origin is None:
await self.app(scope, receive, send)
return

middleware = self._get_config_middleware(origin, scope)
if middleware is None:
await self.app(scope, receive, send)
return

if method == "OPTIONS" and "access-control-request-method" in headers:
response = middleware.preflight_response(request_headers=headers)
await response(scope, receive, send)
return
await middleware.simple_response(scope, receive, send, request_headers=headers)

def _get_config_middleware(
self, origin: str, scope: Scope
) -> CORSMiddleware | None:
for config, middleware in self.config_middlewares:
if config.matcher(origin, scope):
return middleware
return None


__all__ = ["CORSConfig", "CORSMatcherMiddleware", "Scope"]
Loading