Skip to content

Commit

Permalink
Add security headers to the ZenML server
Browse files Browse the repository at this point in the history
  • Loading branch information
stefannica committed Apr 3, 2024
1 parent 39d3ca8 commit cad2728
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 0 deletions.
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ fastapi-utils = { version = "~0.2.1", optional = true }
orjson = { version = "~3.8.3", optional = true }
Jinja2 = { version = "*", optional = true }
ipinfo = { version = ">=4.4.3", optional = true }
secure = { version = "~0.3.0", optional = true }

# Optional dependencies for project templates
copier = { version = ">=8.1.0", optional = true }
Expand Down Expand Up @@ -180,6 +181,7 @@ server = [
"orjson",
"Jinja2",
"ipinfo",
"secure",
]
templates = ["copier", "jinja2-time", "ruff"]
terraform = ["python-terraform"]
Expand Down
24 changes: 24 additions & 0 deletions src/zenml/zen_server/zen_server_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from asyncio.log import logger
from typing import Any, List

import secure
from fastapi import FastAPI, HTTPException, Request
from fastapi.exceptions import RequestValidationError
from fastapi.responses import ORJSONResponse
Expand Down Expand Up @@ -120,6 +121,29 @@ def validation_exception_handler(
allow_headers=["*"],
)

secure_headers = secure.Secure(
# TODO: Add a nonce to the CSP header when ZenML supports it
# (see https://content-security-policy.com/examples/allow-inline-script/)
# csp=secure.ContentSecurityPolicy(),
permissions=secure.PermissionsPolicy()
)


@app.middleware("http")
async def set_secure_headers(request: Request, call_next: Any) -> Any:
"""Middleware to set secure headers.
Args:
request: The incoming request.
call_next: The next function to be called.
Returns:
The response with secure headers set.
"""
response = await call_next(request)
secure_headers.framework.fastapi(response)
return response


@app.middleware("http")
async def infer_source_context(request: Request, call_next: Any) -> Any:
Expand Down

0 comments on commit cad2728

Please sign in to comment.