Skip to content

Commit

Permalink
Updating the update_model decorator (#2136)
Browse files Browse the repository at this point in the history
* new update models

* dont update the user of the stack

* minor fixes

* Auto-update of E2E template

* minor fixes

* Auto-update of E2E template

* fixing some tests

* revert sql zen store cahnges

* Auto-update of E2E template

* added the missing ignore

* fixes based on failing integration tests

---------

Co-authored-by: GitHub Actions <actions@github.com>
Co-authored-by: Alex Strick van Linschoten <strickvl@users.noreply.github.com>
Co-authored-by: Andrei Vishniakov <31008759+avishniakov@users.noreply.github.com>
  • Loading branch information
4 people committed Feb 2, 2024
1 parent 16233b4 commit 67031c8
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 14 deletions.
8 changes: 1 addition & 7 deletions src/zenml/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1160,8 +1160,6 @@ def update_stack(

# Create the update model
update_model = StackUpdate( # type: ignore[call-arg]
workspace=self.active_workspace.id,
user=self.active_user.id,
stack_spec_path=stack_spec_file,
)

Expand Down Expand Up @@ -1665,8 +1663,6 @@ def update_stack_component(
)

update_model = ComponentUpdate( # type: ignore[call-arg]
workspace=self.active_workspace.id,
user=self.active_user.id,
component_spec_path=component_spec_path,
)

Expand Down Expand Up @@ -4296,15 +4292,13 @@ def update_service_connector(
elif expiration_seconds is None:
expiration_seconds = connector_model.expiration_seconds

connector_update = ServiceConnectorUpdate(
connector_update = ServiceConnectorUpdate( # type: ignore[call-arg]
name=name or connector_model.name,
connector_type=connector.connector_type,
description=description or connector_model.description,
auth_method=auth_method or connector_model.auth_method,
expires_skew_tolerance=expires_skew_tolerance,
expiration_seconds=expiration_seconds,
user=self.active_user.id,
workspace=self.active_workspace.id,
)
# Validate and configure the resources
if configuration is not None:
Expand Down
12 changes: 11 additions & 1 deletion src/zenml/models/v2/base/update.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@

from typing import TYPE_CHECKING, Type, TypeVar

from pydantic.config import Extra

if TYPE_CHECKING:
from zenml.models.v2.base.base import BaseRequest

Expand All @@ -33,8 +35,16 @@ def update_model(_cls: Type["T"]) -> Type["T"]:
Returns:
The decorated class.
"""
for _, value in _cls.__fields__.items():
if "workspace" in _cls.__fields__:
_cls.__fields__.pop("workspace")

if "user" in _cls.__fields__:
_cls.__fields__.pop("user")

for key, value in _cls.__fields__.items():
value.required = False
value.allow_none = True

_cls.__config__.extra = Extra.forbid

return _cls
8 changes: 5 additions & 3 deletions src/zenml/stack/flavor_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,9 +122,11 @@ def register_builtin_flavors(self, store: BaseZenStore) -> None:
if len(existing_flavor) == 0:
store.create_flavor(flavor_request_model)
else:
flavor_update_model = FlavorUpdate.parse_obj(
flavor_request_model
)
flavor_dict = flavor_request_model.dict()
flavor_dict.pop("workspace")
flavor_dict.pop("user")

flavor_update_model = FlavorUpdate.parse_obj(flavor_dict)
store.update_flavor(
existing_flavor[0].id, flavor_update_model
)
Expand Down
14 changes: 11 additions & 3 deletions src/zenml/zen_stores/sql_zen_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -5948,7 +5948,11 @@ def create_stack(self, stack: StackRequest) -> StackResponse:
The registered stack.
"""
with Session(self.engine) as session:
self._fail_if_stack_with_name_exists(stack=stack, session=session)
self._fail_if_stack_with_name_exists(
stack=stack,
workspace_id=stack.workspace,
session=session,
)

# Get the Schemas of all components mentioned
component_ids = (
Expand Down Expand Up @@ -6070,7 +6074,9 @@ def update_stack(
if stack_update.name:
if existing_stack.name != stack_update.name:
self._fail_if_stack_with_name_exists(
stack=stack_update, session=session
stack=stack_update,
session=session,
workspace_id=existing_stack.workspace_id,
)

components = []
Expand Down Expand Up @@ -6139,12 +6145,14 @@ def count_stacks(self, filter_model: Optional[StackFilter]) -> int:
def _fail_if_stack_with_name_exists(
self,
stack: StackRequest,
workspace_id: UUID,
session: Session,
) -> None:
"""Raise an exception if a stack with same name exists.
Args:
stack: The Stack
workspace_id: The ID of the workspace
session: The Session
Returns:
Expand All @@ -6156,7 +6164,7 @@ def _fail_if_stack_with_name_exists(
existing_domain_stack = session.exec(
select(StackSchema)
.where(StackSchema.name == stack.name)
.where(StackSchema.workspace_id == stack.workspace)
.where(StackSchema.workspace_id == workspace_id)
).first()
if existing_domain_stack is not None:
workspace = self._get_workspace_schema(
Expand Down

0 comments on commit 67031c8

Please sign in to comment.