Skip to content

Commit

Permalink
Interact with feature gate (#2492)
Browse files Browse the repository at this point in the history
* Interact with feature gate

* Properly handle entitlement violations

* Apply suggestions from code review

Co-authored-by: Barış Can Durak <36421093+bcdurak@users.noreply.github.com>
Co-authored-by: Safoine El Khabich <34200873+safoinme@users.noreply.github.com>

* Auto-update of Starter template

* Applied code reviews

* reformatted

* Reformatted

* Disable feature_gate when no source specified.

* Auto-update of Starter template

* Auto-update of E2E template

* Auto-update of NLP template

* Handle corrupted or empty global configuration file (#2508)

* Handle corrupted or empty global configuration file

* Auto-update of Starter template

---------

Co-authored-by: GitHub Actions <actions@github.com>

* Linted

* Add admin users notion (#2494)

* add admin users to OSS

* skip missing methods

* increase readability

* doc string

* lint

* lint

* missing arg

* add some edge-cases

* wip commit to carve out clean_client changes

* revert irrelevant changes

* revert irrelevant changes

* rework tests to run on rest

* Apply suggestions from code review

Co-authored-by: Alex Strick van Linschoten <strickvl@users.noreply.github.com>
Co-authored-by: Stefan Nica <stefan@zenml.io>

* polish test cases

* fix branching

* admin user mgmt CLI/Client

* close activation vulnerability

* revert rbac changes

* verify admin permissions in endpoints

* add `is_admin` to external users

* only reg users will be migrated as admins

* default is always admin

* extend tests

* lint

* default `is_admin` None

* Auto-update of Starter template

* review suggestions

* review suggestions

* calm down linter

* Update src/zenml/cli/user_management.py

Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>

* Apply suggestions from code review

Co-authored-by: Alex Strick van Linschoten <strickvl@users.noreply.github.com>

* review suggestion

---------

Co-authored-by: Alex Strick van Linschoten <strickvl@users.noreply.github.com>
Co-authored-by: Stefan Nica <stefan@zenml.io>
Co-authored-by: GitHub Actions <actions@github.com>
Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>

* remove dashboard from gitignore (#2517)

* Colima / Homebrew fix (#2512)

* attempt fix

* Auto-update of Starter template

* colima qemu fix trial

* remove qemu

* logs

* logs better

* testing brew workaround

* try second possible fix for python gha

* actually apply the fix

* try the second possible solution for unbreaking python

* make the CI whole again

* linting

* fix python 3.11 on mac (test)

* one more attempt

* formatting

* different fix

* restore the CI to full glory (fixed now!)

---------

Co-authored-by: GitHub Actions <actions@github.com>

* remove extra env var assignment (#2518)

* Allow installing packages using UV (#2510)

* Allow installing packages using UV

* Auto-update of Starter template

* actually make it work

* Auto-update of Starter template

---------

Co-authored-by: GitHub Actions <actions@github.com>

* Additional fields for track events (#2507)

* additional fields for track events

* formatting

* Auto-update of Starter template

* adding a few recommendations

* formatting

* Auto-update of Starter template

---------

Co-authored-by: GitHub Actions <actions@github.com>
Co-authored-by: Alex Strick van Linschoten <strickvl@users.noreply.github.com>

* Auto-update of Starter template

* Auto-update of NLP template

* Auto-update of E2E template

* Update src/zenml/zen_server/exceptions.py

Co-authored-by: Stefan Nica <stefan@zenml.io>

* Update src/zenml/zen_server/cloud_utils.py

Co-authored-by: Stefan Nica <stefan@zenml.io>

* Applied code review.

* Properly reformatted

* Reformatted

* Fixed test

* Fixed docstring

* Model deletion works now, fixed error message

* Show correct error message when creating models that exceed subscription limit

* Send resource id

* Auto-update of LLM Finetuning template

* Fix error

* Limit pipeline namespaces

* Remove billing url

* Linted

* Potential fix

---------

Co-authored-by: Barış Can Durak <36421093+bcdurak@users.noreply.github.com>
Co-authored-by: Safoine El Khabich <34200873+safoinme@users.noreply.github.com>
Co-authored-by: GitHub Actions <actions@github.com>
Co-authored-by: Stefan Nica <stefan@zenml.io>
Co-authored-by: Alex Strick van Linschoten <strickvl@users.noreply.github.com>
Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
Co-authored-by: Jayesh Sharma <wjayesh@outlook.com>
Co-authored-by: Michael Schuster <schustmi@users.noreply.github.com>
Co-authored-by: Michael Schuster <michael.schuster.ffb@googlemail.com>
  • Loading branch information
10 people authored Mar 26, 2024
1 parent 76fc719 commit 5406fa7
Show file tree
Hide file tree
Showing 19 changed files with 710 additions and 217 deletions.
19 changes: 17 additions & 2 deletions src/zenml/config/server_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,13 +87,13 @@ class ServerConfiguration(BaseModel):
construct the OAuth 2.0 device authorization endpoint. If not set,
a partial URL is returned to the client which is used to construct
the full URL based on the server's root URL path.
device_expiration: The time in minutes that an OAuth 2.0 device is
device_expiration_minutes: The time in minutes that an OAuth 2.0 device is
allowed to be used to authenticate with the ZenML server. If not
set or if `jwt_token_expire_minutes` is not set, the devices are
allowed to be used indefinitely. This controls the expiration time
of the JWT tokens issued to clients after they have authenticated
with the ZenML server using an OAuth 2.0 device.
trusted_device_expiration: The time in minutes that a trusted OAuth 2.0
trusted_device_expiration_minutes: The time in minutes that a trusted OAuth 2.0
device is allowed to be used to authenticate with the ZenML server.
If not set or if `jwt_token_expire_minutes` is not set, the devices
are allowed to be used indefinitely. This controls the expiration
Expand All @@ -116,6 +116,11 @@ class ServerConfiguration(BaseModel):
the RBAC interface defined by
`zenml.zen_server.rbac_interface.RBACInterface`. If not specified,
RBAC will not be enabled for this server.
feature_gate_implementation_source: Source pointing to a class
implementing the feature gate interface defined by
`zenml.zen_server.feature_gate.feature_gate_interface.FeatureGateInterface`.
If not specified, feature usage will not be gated/tracked for this
server.
workload_manager_implementation_source: Source pointing to a class
implementing the workload management interface.
pipeline_run_auth_window: The default time window in minutes for which
Expand Down Expand Up @@ -156,6 +161,7 @@ class ServerConfiguration(BaseModel):
external_server_id: Optional[UUID] = None

rbac_implementation_source: Optional[str] = None
feature_gate_implementation_source: Optional[str] = None
workload_manager_implementation_source: Optional[str] = None
pipeline_run_auth_window: int = (
DEFAULT_ZENML_SERVER_PIPELINE_RUN_AUTH_WINDOW
Expand Down Expand Up @@ -244,6 +250,15 @@ def rbac_enabled(self) -> bool:
"""
return self.rbac_implementation_source is not None

@property
def feature_gate_enabled(self) -> bool:
"""Whether feature gating is enabled on the server or not.
Returns:
Whether feature gating is enabled on the server or not.
"""
return self.feature_gate_implementation_source is not None

@property
def workload_manager_enabled(self) -> bool:
"""Whether workload management is enabled on the server or not.
Expand Down
64 changes: 64 additions & 0 deletions src/zenml/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,61 @@
# permissions and limitations under the License.
"""ZenML constants."""

import json
import logging
import os
from typing import Any, List, Optional, Type, TypeVar

from zenml.enums import AuthScheme

T = TypeVar("T")


def handle_json_env_var(
var: str,
expected_type: Type[T],
default: Optional[List[str]] = None,
) -> Any:
"""Converts a json env var into a Python object.
Args:
var: The environment variable to convert.
default: The default value to return if the env var is not set.
expected_type: The type of the expected Python object.
Returns:
The converted list value.
Raises:
TypeError: In case the value of the environment variable is not of a
valid type.
"""
# this needs to be here to avoid mutable defaults
if default is None:
default = []

value = os.getenv(var)
if value:
try:
loaded_value = json.loads(value)
# check if loaded value is of correct type
if expected_type is None or isinstance(
loaded_value, expected_type
):
return loaded_value
else:
raise TypeError # if not correct type, raise TypeError
except (TypeError, json.JSONDecodeError):
# Use raw logging to avoid cyclic dependency
logging.warning(
f"Environment Variable {var} could not be loaded, into type "
f"{expected_type}, defaulting to: {default}."
)
return default
else:
return default


def handle_bool_env_var(var: str, default: bool = False) -> bool:
"""Converts normal env var to boolean.
Expand Down Expand Up @@ -100,6 +151,9 @@ def handle_int_env_var(var: str, default: int = 0) -> int:
ENV_ZENML_SERVER_PREFIX = "ZENML_SERVER_"
ENV_ZENML_SERVER_DEPLOYMENT_TYPE = f"{ENV_ZENML_SERVER_PREFIX}DEPLOYMENT_TYPE"
ENV_ZENML_SERVER_AUTH_SCHEME = f"{ENV_ZENML_SERVER_PREFIX}AUTH_SCHEME"
ENV_ZENML_SERVER_REPORTABLE_RESOURCES = (
f"{ENV_ZENML_SERVER_PREFIX}REPORTABLE_RESOURCES"
)

# Logging variables
IS_DEBUG_ENV: bool = handle_bool_env_var(ENV_ZENML_DEBUG, default=False)
Expand Down Expand Up @@ -181,6 +235,16 @@ def handle_int_env_var(var: str, default: int = 0) -> int:
DEFAULT_ZENML_SERVER_LOGIN_RATE_LIMIT_MINUTE = 5
DEFAULT_ZENML_SERVER_LOGIN_RATE_LIMIT_DAY = 1000

# Configurations to decide which resources report their usage and check for
# entitlement in the case of a cloud deployment. Expected Format is this:
# ENV_ZENML_REPORTABLE_RESOURCES='["Foo", "bar"]'
REPORTABLE_RESOURCES: List[str] = handle_json_env_var(
ENV_ZENML_SERVER_REPORTABLE_RESOURCES,
expected_type=list,
default=["pipeline_run", "model"],
)
REQUIRES_CUSTOM_RESOURCE_REPORTING = ["pipeline"]

# API Endpoint paths:
ACTIVATE = "/activate"
ACTIONS = "/action-flavors"
Expand Down
4 changes: 4 additions & 0 deletions src/zenml/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,10 @@ class InputResolutionError(ZenMLBaseException):
"""Raised when step input resolving failed."""


class SubscriptionUpgradeRequiredError(ZenMLBaseException):
"""Raised when user tries to perform an action outside their current subscription tier."""


class HydrationError(ZenMLBaseException):
"""Raised when the model hydration failed."""

Expand Down
4 changes: 1 addition & 3 deletions src/zenml/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -549,12 +549,10 @@ def _get_or_create_model(self) -> "ModelResponse":
)
logger.info(f"New model `{self.name}` was created implicitly.")
except EntityExistsError:
# this is backup logic, if model was created somehow in between get and create calls
pass
finally:
model = zenml_client.zen_store.get_model(
model_name_or_id=self.name
)

self._model_id = model.id
return model

Expand Down
201 changes: 201 additions & 0 deletions src/zenml/zen_server/cloud_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,201 @@
"""Utils concerning anything concerning the cloud control plane backend."""

import os
from typing import Any, Dict, Optional

import requests
from pydantic import BaseModel, validator
from requests.adapters import HTTPAdapter, Retry

from zenml.exceptions import SubscriptionUpgradeRequiredError

ZENML_CLOUD_RBAC_ENV_PREFIX = "ZENML_CLOUD_"


class ZenMLCloudConfiguration(BaseModel):
"""ZenML Cloud RBAC configuration."""

api_url: str

oauth2_client_id: str
oauth2_client_secret: str
oauth2_audience: str
auth0_domain: str

@validator("api_url")
def _strip_trailing_slashes_url(cls, url: str) -> str:
"""Strip any trailing slashes on the API URL.
Args:
url: The API URL.
Returns:
The API URL with potential trailing slashes removed.
"""
return url.rstrip("/")

@classmethod
def from_environment(cls) -> "ZenMLCloudConfiguration":
"""Get the RBAC configuration from environment variables.
Returns:
The RBAC configuration.
"""
env_config: Dict[str, Any] = {}
for k, v in os.environ.items():
if v == "":
continue
if k.startswith(ZENML_CLOUD_RBAC_ENV_PREFIX):
env_config[k[len(ZENML_CLOUD_RBAC_ENV_PREFIX) :].lower()] = v

return ZenMLCloudConfiguration(**env_config)

class Config:
"""Pydantic configuration class."""

# Allow extra attributes from configs of previous ZenML versions to
# permit downgrading
extra = "allow"


class ZenMLCloudSession:
"""Class to use for communication between server and control plane."""

def __init__(self) -> None:
"""Initialize the RBAC component."""
self._config = ZenMLCloudConfiguration.from_environment()
self._session: Optional[requests.Session] = None

def _get(
self, endpoint: str, params: Optional[Dict[str, Any]]
) -> requests.Response:
"""Send a GET request using the active session.
Args:
endpoint: The endpoint to send the request to. This will be appended
to the base URL.
params: Parameters to include in the request.
Raises:
RuntimeError: If the request failed.
SubscriptionUpgradeRequiredError: In case the current subscription
tier is insufficient for the attempted operation.
Returns:
The response.
"""
url = self._config.api_url + endpoint

response = self.session.get(url=url, params=params, timeout=7)
if response.status_code == 401:
# Refresh the auth token and try again
self._clear_session()
response = self.session.get(url=url, params=params, timeout=7)

try:
response.raise_for_status()
except requests.HTTPError:
if response.status_code == 402:
raise SubscriptionUpgradeRequiredError(response.json())
else:
raise RuntimeError(
f"Failed with the following error {response.json()}"
)

return response

def _post(
self,
endpoint: str,
params: Optional[Dict[str, Any]] = None,
data: Optional[Dict[str, Any]] = None,
) -> requests.Response:
"""Send a POST request using the active session.
Args:
endpoint: The endpoint to send the request to. This will be appended
to the base URL.
params: Parameters to include in the request.
data: Data to include in the request.
Raises:
RuntimeError: If the request failed.
Returns:
The response.
"""
url = self._config.api_url + endpoint

response = self.session.post(
url=url, params=params, json=data, timeout=7
)
if response.status_code == 401:
# Refresh the auth token and try again
self._clear_session()
response = self.session.post(
url=url, params=params, json=data, timeout=7
)

try:
response.raise_for_status()
except requests.HTTPError as e:
raise RuntimeError(
f"Failed while trying to contact the central zenml cloud "
f"service: {e}"
)

return response

@property
def session(self) -> requests.Session:
"""Authenticate to the ZenML Cloud API.
Returns:
A requests session with the authentication token.
"""
if self._session is None:
self._session = requests.Session()
token = self._fetch_auth_token()
self._session.headers.update({"Authorization": "Bearer " + token})

retries = Retry(total=5, backoff_factor=0.1)
self._session.mount("https://", HTTPAdapter(max_retries=retries))

return self._session

def _clear_session(self) -> None:
"""Clear the authentication session."""
self._session = None

def _fetch_auth_token(self) -> str:
"""Fetch an auth token for the Cloud API from auth0.
Raises:
RuntimeError: If the auth token can't be fetched.
Returns:
Auth token.
"""
# Get an auth token from auth0
auth0_url = f"https://{self._config.auth0_domain}/oauth/token"
headers = {"content-type": "application/x-www-form-urlencoded"}
payload = {
"client_id": self._config.oauth2_client_id,
"client_secret": self._config.oauth2_client_secret,
"audience": self._config.oauth2_audience,
"grant_type": "client_credentials",
}
try:
response = requests.post(
auth0_url, headers=headers, data=payload, timeout=7
)
response.raise_for_status()
except Exception as e:
raise RuntimeError(f"Error fetching auth token from auth0: {e}")

access_token = response.json().get("access_token", "")

if not access_token or not isinstance(access_token, str):
raise RuntimeError("Could not fetch auth token from auth0.")

return str(access_token)
3 changes: 3 additions & 0 deletions src/zenml/zen_server/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
SecretExistsError,
StackComponentExistsError,
StackExistsError,
SubscriptionUpgradeRequiredError,
ValidationError,
ZenKeyError,
)
Expand Down Expand Up @@ -77,6 +78,8 @@ class ErrorModel(BaseModel):
(IllegalOperationError, 403),
# 401 Unauthorized
(AuthorizationException, 401),
# 402 Payment required
(SubscriptionUpgradeRequiredError, 402),
# 404 Not Found
(DoesNotExistException, 404),
(ZenKeyError, 404),
Expand Down
13 changes: 13 additions & 0 deletions src/zenml/zen_server/feature_gate/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Copyright (c) ZenML GmbH 2024. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at:
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied. See the License for the specific language governing
# permissions and limitations under the License.
Loading

0 comments on commit 5406fa7

Please sign in to comment.