diff --git a/docs/mocked_libs.json b/docs/mocked_libs.json index f7f6f283f1..3135440e9f 100644 --- a/docs/mocked_libs.json +++ b/docs/mocked_libs.json @@ -71,6 +71,7 @@ "google.api_core", "google.api_core.exceptions", "google.auth", + "google.auth.aws", "google.auth._default", "google.auth.exceptions", "google.auth.transport", diff --git a/pyproject.toml b/pyproject.toml index 0f9bbe758f..5a0b0a3374 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -60,6 +60,7 @@ python = ">=3.8,<3.12" python-dateutil = "^2.8.1" pyyaml = ">=6.0.1" rich = { extras = ["jupyter"], version = ">=12.0.0" } +secure = "~0.3.0" sqlalchemy_utils = "0.38.3" sqlmodel = "0.0.8" importlib_metadata = { version = "<=7.0.0", python = "<3.10" } @@ -73,7 +74,6 @@ fastapi-utils = { version = "~0.2.1", optional = true } orjson = { version = "~3.10.0", optional = true } Jinja2 = { version = "*", optional = true } ipinfo = { version = ">=4.4.3", optional = true } -secure = { version = "~0.3.0", optional = true } # Optional dependencies for project templates copier = { version = ">=8.1.0", optional = true } @@ -182,7 +182,6 @@ server = [ "orjson", "Jinja2", "ipinfo", - "secure", ] templates = ["copier", "jinja2-time", "ruff", "pyyaml-include"] terraform = ["python-terraform"] diff --git a/src/zenml/constants.py b/src/zenml/constants.py index 112f7c6013..21f6cfe9ad 100644 --- a/src/zenml/constants.py +++ b/src/zenml/constants.py @@ -244,10 +244,10 @@ def handle_int_env_var(var: str, default: int = 0) -> int: DEFAULT_ZENML_SERVER_SECURE_HEADERS_CONTENT = "nosniff" DEFAULT_ZENML_SERVER_SECURE_HEADERS_CSP = ( "default-src 'none'; " - "script-src 'self' 'unsafe-inline' 'unsafe-eval' https://cdn.jsdelivr.net; " + "script-src 'self' 'unsafe-inline' 'unsafe-eval'; " "connect-src 'self' https://sdkdocs.zenml.io https://hubapi.zenml.io; " - "img-src 'self' data: https://public-flavor-logos.s3.eu-central-1.amazonaws.com https://fastapi.tiangolo.com; " - "style-src 'self' 'unsafe-inline' https://cdn.jsdelivr.net; " + "img-src 'self' data: https://public-flavor-logos.s3.eu-central-1.amazonaws.com; " + "style-src 'self' 'unsafe-inline'; " "base-uri 'self'; " "form-action 'self'; " "font-src 'self';" diff --git a/src/zenml/zen_server/utils.py b/src/zenml/zen_server/utils.py index f98039b225..2953621a1e 100644 --- a/src/zenml/zen_server/utils.py +++ b/src/zenml/zen_server/utils.py @@ -17,7 +17,6 @@ import os from functools import wraps from typing import ( - TYPE_CHECKING, Any, Callable, Optional, @@ -28,6 +27,7 @@ ) from urllib.parse import urlparse +import secure from pydantic import BaseModel, ValidationError from zenml.config.global_config import GlobalConfiguration @@ -53,9 +53,6 @@ from zenml.zen_server.rbac.rbac_interface import RBACInterface from zenml.zen_stores.sql_zen_store import SqlZenStore -if TYPE_CHECKING: - import secure - logger = get_logger(__name__) _zen_store: Optional["SqlZenStore"] = None @@ -63,7 +60,7 @@ _feature_gate: Optional[FeatureGateInterface] = None _workload_manager: Optional[WorkloadManagerInterface] = None _plugin_flavor_registry: Optional[PluginFlavorRegistry] = None -_secure_headers: Optional["secure.Secure"] = None +_secure_headers: Optional[secure.Secure] = None def zen_store() -> "SqlZenStore": @@ -219,7 +216,7 @@ def initialize_zen_store() -> None: _zen_store = zen_store_ -def secure_headers() -> "secure.Secure": +def secure_headers() -> secure.Secure: """Return the secure headers component. Returns: @@ -236,8 +233,6 @@ def secure_headers() -> "secure.Secure": def initialize_secure_headers() -> None: """Initialize the secure headers component.""" - import secure - global _secure_headers config = server_config() diff --git a/src/zenml/zen_server/zen_server_api.py b/src/zenml/zen_server/zen_server_api.py index 892dbc37e7..6995a1499b 100644 --- a/src/zenml/zen_server/zen_server_api.py +++ b/src/zenml/zen_server/zen_server_api.py @@ -134,6 +134,12 @@ async def set_secure_headers(request: Request, call_next: Any) -> Any: Returns: The response with secure headers set. """ + # If the request is for the openAPI docs, don't set secure headers + if request.url.path.startswith("/docs") or request.url.path.startswith( + "/redoc" + ): + return await call_next(request) + response = await call_next(request) secure_headers().framework.fastapi(response) return response