Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Disable CSP headers for the openAPI docs pages and fix API docs building #2598

Merged
merged 4 commits into from
Apr 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/mocked_libs.json
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@
"google.api_core",
"google.api_core.exceptions",
"google.auth",
"google.auth.aws",
"google.auth._default",
"google.auth.exceptions",
"google.auth.transport",
Expand Down
3 changes: 1 addition & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ python = ">=3.8,<3.12"
python-dateutil = "^2.8.1"
pyyaml = ">=6.0.1"
rich = { extras = ["jupyter"], version = ">=12.0.0" }
secure = "~0.3.0"
sqlalchemy_utils = "0.38.3"
sqlmodel = "0.0.8"
importlib_metadata = { version = "<=7.0.0", python = "<3.10" }
Expand All @@ -73,7 +74,6 @@ fastapi-utils = { version = "~0.2.1", optional = true }
orjson = { version = "~3.10.0", optional = true }
Jinja2 = { version = "*", optional = true }
ipinfo = { version = ">=4.4.3", optional = true }
secure = { version = "~0.3.0", optional = true }

# Optional dependencies for project templates
copier = { version = ">=8.1.0", optional = true }
Expand Down Expand Up @@ -182,7 +182,6 @@ server = [
"orjson",
"Jinja2",
"ipinfo",
"secure",
]
templates = ["copier", "jinja2-time", "ruff", "pyyaml-include"]
terraform = ["python-terraform"]
Expand Down
6 changes: 3 additions & 3 deletions src/zenml/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,10 +244,10 @@ def handle_int_env_var(var: str, default: int = 0) -> int:
DEFAULT_ZENML_SERVER_SECURE_HEADERS_CONTENT = "nosniff"
DEFAULT_ZENML_SERVER_SECURE_HEADERS_CSP = (
"default-src 'none'; "
"script-src 'self' 'unsafe-inline' 'unsafe-eval' https://cdn.jsdelivr.net; "
"script-src 'self' 'unsafe-inline' 'unsafe-eval'; "
"connect-src 'self' https://sdkdocs.zenml.io https://hubapi.zenml.io; "
"img-src 'self' data: https://public-flavor-logos.s3.eu-central-1.amazonaws.com https://fastapi.tiangolo.com; "
"style-src 'self' 'unsafe-inline' https://cdn.jsdelivr.net; "
"img-src 'self' data: https://public-flavor-logos.s3.eu-central-1.amazonaws.com; "
"style-src 'self' 'unsafe-inline'; "
"base-uri 'self'; "
"form-action 'self'; "
"font-src 'self';"
Expand Down
11 changes: 3 additions & 8 deletions src/zenml/zen_server/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
import os
from functools import wraps
from typing import (
TYPE_CHECKING,
Any,
Callable,
Optional,
Expand All @@ -28,6 +27,7 @@
)
from urllib.parse import urlparse

import secure
from pydantic import BaseModel, ValidationError

from zenml.config.global_config import GlobalConfiguration
Expand All @@ -53,17 +53,14 @@
from zenml.zen_server.rbac.rbac_interface import RBACInterface
from zenml.zen_stores.sql_zen_store import SqlZenStore

if TYPE_CHECKING:
import secure

logger = get_logger(__name__)

_zen_store: Optional["SqlZenStore"] = None
_rbac: Optional[RBACInterface] = None
_feature_gate: Optional[FeatureGateInterface] = None
_workload_manager: Optional[WorkloadManagerInterface] = None
_plugin_flavor_registry: Optional[PluginFlavorRegistry] = None
_secure_headers: Optional["secure.Secure"] = None
_secure_headers: Optional[secure.Secure] = None


def zen_store() -> "SqlZenStore":
Expand Down Expand Up @@ -219,7 +216,7 @@ def initialize_zen_store() -> None:
_zen_store = zen_store_


def secure_headers() -> "secure.Secure":
def secure_headers() -> secure.Secure:
"""Return the secure headers component.

Returns:
Expand All @@ -236,8 +233,6 @@ def secure_headers() -> "secure.Secure":

def initialize_secure_headers() -> None:
"""Initialize the secure headers component."""
import secure

global _secure_headers

config = server_config()
Expand Down
6 changes: 6 additions & 0 deletions src/zenml/zen_server/zen_server_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,12 @@ async def set_secure_headers(request: Request, call_next: Any) -> Any:
Returns:
The response with secure headers set.
"""
# If the request is for the openAPI docs, don't set secure headers
if request.url.path.startswith("/docs") or request.url.path.startswith(
"/redoc"
):
return await call_next(request)

response = await call_next(request)
secure_headers().framework.fastapi(response)
return response
Expand Down
Loading