Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Server scalability improvements #2752

Merged
merged 5 commits into from
Jun 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/zenml/config/server_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
DEFAULT_ZENML_SERVER_SECURE_HEADERS_REFERRER,
DEFAULT_ZENML_SERVER_SECURE_HEADERS_XFO,
DEFAULT_ZENML_SERVER_SECURE_HEADERS_XXP,
DEFAULT_ZENML_SERVER_THREAD_POOL_SIZE,
DEFAULT_ZENML_SERVER_USE_LEGACY_DASHBOARD,
ENV_ZENML_SERVER_PREFIX,
)
Expand Down Expand Up @@ -301,6 +302,8 @@ class ServerConfiguration(BaseModel):
display_updates: bool = True
auto_activate: bool = False

thread_pool_size: int = DEFAULT_ZENML_SERVER_THREAD_POOL_SIZE

_deployment_id: Optional[UUID] = None

@root_validator(pre=True)
Expand Down
1 change: 1 addition & 0 deletions src/zenml/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,7 @@ def handle_int_env_var(var: str, default: int = 0) -> int:

# Server settings
DEFAULT_ZENML_SERVER_NAME = "default"
DEFAULT_ZENML_SERVER_THREAD_POOL_SIZE = 40
DEFAULT_ZENML_JWT_TOKEN_LEEWAY = 10
DEFAULT_ZENML_JWT_TOKEN_ALGORITHM = "HS256"
DEFAULT_ZENML_AUTH_SCHEME = AuthScheme.OAUTH2_PASSWORD_BEARER
Expand Down
22 changes: 20 additions & 2 deletions src/zenml/zen_server/cloud_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from requests.adapters import HTTPAdapter, Retry

from zenml.exceptions import SubscriptionUpgradeRequiredError
from zenml.zen_server.utils import server_config

ZENML_CLOUD_RBAC_ENV_PREFIX = "ZENML_CLOUD_"

Expand Down Expand Up @@ -99,7 +100,7 @@ def _get(
raise SubscriptionUpgradeRequiredError(response.json())
else:
raise RuntimeError(
f"Failed with the following error {response.json()}"
f"Failed with the following error {response} {response.text}"
)

return response
Expand Down Expand Up @@ -154,12 +155,29 @@ def session(self) -> requests.Session:
A requests session with the authentication token.
"""
if self._session is None:
# Set up the session's connection pool size to match the server's
# thread pool size. This allows the server to cache one connection
# per thread, which means we can keep connections open for longer
# and avoid the overhead of setting up a new connection for each
# request.
conn_pool_size = server_config().thread_pool_size

self._session = requests.Session()
token = self._fetch_auth_token()
self._session.headers.update({"Authorization": "Bearer " + token})

retries = Retry(total=5, backoff_factor=0.1)
self._session.mount("https://", HTTPAdapter(max_retries=retries))
self._session.mount(
"https://",
HTTPAdapter(
max_retries=retries,
# A single cached connection pool is sufficient because we
# only communicate with one remote server (the control
# plane)
pool_connections=1,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For my understanding, is it costly to have more cached hosts? I understand we need only one but want to know what impact this would have.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is no impact in this case. We can set this to 10 or 100; there will still be only one connection pool, because we only need to connect to one host. This is covered in more detail here, if you want the gruesome details: https://stackoverflow.com/questions/34837026/whats-the-meaning-of-pool-connections-in-requests-adapters-httpadapter

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Really well explained answer; had this in bookmarks for a while ;) thanks!

pool_maxsize=conn_pool_size,
),
)

return self._session

Expand Down
3 changes: 3 additions & 0 deletions src/zenml/zen_server/deploy/helm/templates/_environment.tpl
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ Returns:
{{- define "zenml.serverConfigurationAttrs" -}}
auth_scheme: {{ .ZenML.authType | default .ZenML.auth.authType | quote }}
deployment_type: {{ .ZenML.deploymentType | default "kubernetes" }}
{{- if .ZenML.threadPoolSize }}
thread_pool_size: {{ .ZenML.threadPoolSize | quote }}
{{- end }}
{{- if .ZenML.auth.jwtTokenAlgorithm }}
jwt_token_algorithm: {{ .ZenML.auth.jwtTokenAlgorithm | quote }}
{{- end }}
Expand Down
6 changes: 6 additions & 0 deletions src/zenml/zen_server/deploy/helm/templates/server-secret.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,12 @@ data:
{{- if .Values.zenml.database.sslKey }}
ZENML_STORE_SSL_KEY: {{ .Files.Get .Values.zenml.database.sslKey | b64enc }}
{{- end }}
{{- /*
Optional database connection pool settings. These are normally supplied as
YAML integers in values.yaml, but b64enc only accepts strings, so each value
is converted with toString before being base64-encoded.
*/ -}}
{{- if .Values.zenml.database.poolSize }}
ZENML_STORE_POOL_SIZE: {{ .Values.zenml.database.poolSize | toString | b64enc | quote }}
{{- end }}
{{- if .Values.zenml.database.maxOverflow }}
ZENML_STORE_MAX_OVERFLOW: {{ .Values.zenml.database.maxOverflow | toString | b64enc | quote }}
{{- end }}
{{- end }}
{{- range $k, $v := include "zenml.secretsStoreSecretEnvVariables" . | fromYaml}}
{{ $k }}: {{ $v | b64enc | quote }}
Expand Down
21 changes: 21 additions & 0 deletions src/zenml/zen_server/deploy/helm/values.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,17 @@ zenml:

replicaCount: 1

# The number of ZenML server worker threads to use. This controls the number
# of concurrent requests that each ZenML server pod/replica can handle at a
# time. If not specified, the default value is 40.
#
# NOTE: this value should be coordinated with the `zenml.database.poolSize`
# and `zenml.database.maxOverflow` values to ensure that the ZenML server
# workers do not block on database connections (i.e. the sum of the pool size
# and max overflow should be greater than or equal to the thread pool size).
#
# threadPoolSize: 40

image:
repository: zenmldocker/zenml-server
pullPolicy: Always
Expand Down Expand Up @@ -198,6 +209,16 @@ zenml:
# sslKey: /path/to/client-key.pem
# sslVerifyServerCert: True

# Connection pool settings (only relevant for MySQL databases).
#
# NOTE: these values should be coordinated with the `zenml.threadPoolSize`
# to ensure that the ZenML server workers do not block on database
# connections (i.e. the sum of the pool size and max overflow should be
# greater than or equal to the thread pool size).
#
# poolSize: 20
# maxOverflow: 20

# ZenML supports backing up the database before DB migrations are performed
# and restoring it in case of a DB migration failure. For more information,
# see the following documentation:
Expand Down
26 changes: 18 additions & 8 deletions src/zenml/zen_server/zen_server_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,21 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied. See the License for the specific language governing
# permissions and limitations under the License.
"""Zen Server API."""
"""Zen Server API.

To run this file locally, execute:

```
uvicorn zenml.zen_server.zen_server_api:app --reload
```
"""

import os
from asyncio.log import logger
from genericpath import isfile
from typing import Any, List

from anyio import to_thread
from fastapi import FastAPI, HTTPException, Request
from fastapi.exceptions import RequestValidationError
from fastapi.responses import ORJSONResponse
Expand Down Expand Up @@ -73,6 +81,11 @@
server_config,
)

# Set the maximum number of worker threads by resizing anyio's default thread
# limiter to the `thread_pool_size` value from the server configuration.
# NOTE(review): this assignment runs at module import time. anyio presumably
# resolves the default limiter relative to the currently running event loop —
# confirm this module is always imported from inside the server's event loop,
# or consider moving the assignment into a startup hook.
to_thread.current_default_thread_limiter().total_tokens = (
server_config().thread_pool_size
)

if server_config().use_legacy_dashboard:
DASHBOARD_DIRECTORY = "dashboard_legacy"
else:
Expand Down Expand Up @@ -218,7 +231,7 @@ def initialize() -> None:
# Basic Health Endpoint
@app.head(HEALTH, include_in_schema=False)
@app.get(HEALTH)
def health() -> str:
async def health() -> str:
"""Get health status of the server.

Returns:
Expand All @@ -231,7 +244,7 @@ def health() -> str:


@app.get("/", include_in_schema=False)
def dashboard(request: Request) -> Any:
async def dashboard(request: Request) -> Any:
"""Dashboard endpoint.

Args:
Expand All @@ -250,9 +263,6 @@ def dashboard(request: Request) -> Any:
return templates.TemplateResponse("index.html", {"request": request})


# to run this file locally, execute:
# uvicorn zenml.zen_server.zen_server_api:app --reload

app.include_router(artifact_endpoint.artifact_router)
app.include_router(artifact_version_endpoints.artifact_version_router)
app.include_router(auth_endpoints.router)
Expand Down Expand Up @@ -325,7 +335,7 @@ def get_root_static_files() -> List[str]:
@app.get(
API + "/{invalid_api_path:path}", status_code=404, include_in_schema=False
)
def invalid_api(invalid_api_path: str) -> None:
async def invalid_api(invalid_api_path: str) -> None:
"""Invalid API endpoint.

All API endpoints that are not defined in the API routers will be
Expand All @@ -342,7 +352,7 @@ def invalid_api(invalid_api_path: str) -> None:


@app.get("/{file_path:path}", include_in_schema=False)
def catch_all(request: Request, file_path: str) -> Any:
async def catch_all(request: Request, file_path: str) -> Any:
"""Dashboard endpoint.

Args:
Expand Down
Loading