Skip to content

Commit

Permalink
server: implement more advanced CORS handling
Browse files Browse the repository at this point in the history
The goal is to allow external API callers to be able to do so in a cross-origin scenario *but* only with the `Authorization` header and a proper token: cookie authentication should be blocked for them and allowed only for our own allow-listed origins.
  • Loading branch information
frankie567 committed Jun 25, 2024
1 parent 02492ce commit fee6d6a
Show file tree
Hide file tree
Showing 2 changed files with 105 additions and 9 deletions.
36 changes: 27 additions & 9 deletions server/polar/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,13 @@
import structlog
from fastapi import FastAPI
from fastapi.routing import APIRoute
from starlette.middleware.cors import CORSMiddleware

from polar import receivers, worker # noqa
from polar.api import router
from polar.config import settings
from polar.exception_handlers import add_exception_handlers
from polar.health.endpoints import router as health_router
from polar.kit.cors import CORSConfig, CORSMatcherMiddleware, Scope
from polar.kit.db.postgres import (
AsyncEngine,
AsyncSessionMaker,
Expand Down Expand Up @@ -48,16 +48,34 @@


def configure_cors(app: FastAPI) -> None:
if not settings.CORS_ORIGINS:
return

app.add_middleware(
CORSMiddleware,
allow_origins=[str(origin) for origin in settings.CORS_ORIGINS],
allow_credentials=True,
configs: list[CORSConfig] = []

# Polar frontend CORS configuration
if settings.CORS_ORIGINS:

def polar_frontend_matcher(origin: str, scope: Scope) -> bool:
return origin in settings.CORS_ORIGINS

polar_frontend_config = CORSConfig(
polar_frontend_matcher,
allow_origins=[str(origin) for origin in settings.CORS_ORIGINS],
allow_credentials=True, # Cookies are allowed, but only there!
allow_methods=["*"],
allow_headers=["*"],
)
configs.append(polar_frontend_config)

# External API calls CORS configuration
api_config = CORSConfig(
lambda origin, scope: True,
allow_origins=["*"],
allow_credentials=False, # No cookies allowed
allow_methods=["*"],
allow_headers=["*"],
allow_headers=["Authorization"], # Allow Authorization header to pass tokens
)
configs.append(api_config)

app.add_middleware(CORSMatcherMiddleware, configs=configs)


def generate_unique_openapi_id(route: APIRoute) -> str:
Expand Down
78 changes: 78 additions & 0 deletions server/polar/kit/cors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
import dataclasses
from collections.abc import Sequence
from typing import Protocol

from starlette.datastructures import Headers
from starlette.middleware.cors import CORSMiddleware
from starlette.types import ASGIApp, Receive, Scope, Send


class CORSMatcher(Protocol):
def __call__(self, origin: str, scope: Scope) -> bool: ...


@dataclasses.dataclass
class CORSConfig:
matcher: CORSMatcher
allow_origins: Sequence[str] = ()
allow_methods: Sequence[str] = ("GET",)
allow_headers: Sequence[str] = ()
allow_credentials: bool = False
allow_origin_regex: str | None = None
expose_headers: Sequence[str] = ()
max_age: int = 600

def get_middleware(self, app: ASGIApp) -> CORSMiddleware:
return CORSMiddleware(
app=app,
allow_origins=self.allow_origins,
allow_methods=self.allow_methods,
allow_headers=self.allow_headers,
allow_credentials=self.allow_credentials,
allow_origin_regex=self.allow_origin_regex,
expose_headers=self.expose_headers,
max_age=self.max_age,
)


class CORSMatcherMiddleware:
def __init__(self, app: ASGIApp, *, configs: Sequence[CORSConfig]) -> None:
self.app = app
self.config_middlewares = tuple(
(config, config.get_middleware(app)) for config in configs
)

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if scope["type"] != "http": # pragma: no cover
await self.app(scope, receive, send)
return

method = scope["method"]
headers = Headers(scope=scope)
origin = headers.get("origin")

if origin is None:
await self.app(scope, receive, send)
return

middleware = self._get_config_middleware(origin, scope)
if middleware is None:
await self.app(scope, receive, send)
return

if method == "OPTIONS" and "access-control-request-method" in headers:
response = middleware.preflight_response(request_headers=headers)
await response(scope, receive, send)
return
await middleware.simple_response(scope, receive, send, request_headers=headers)

def _get_config_middleware(
self, origin: str, scope: Scope
) -> CORSMiddleware | None:
for config, middleware in self.config_middlewares:
if config.matcher(origin, scope):
return middleware
return None


__all__ = ["CORSConfig", "CORSMatcherMiddleware", "Scope"]

0 comments on commit fee6d6a

Please sign in to comment.