Skip to content

Commit

Permalink
Upgrading SQLModel to the latest version (#2452)
Browse files Browse the repository at this point in the history
* upgrade sqlmodel, remove fastapi-utils

* remove fastapi-utils as a requirement

* small query bug

* new migration scriptqq

* adding some comments

* fixing docstring

* change in schemas instead of migration

* fixing the tag updates

* formatting

* latest sqlmodel

* fixing versions

* formatting and linting

* new version for ge and formatting and linting

* new versions

* zen store fix

* fixing the spaces

* checking

* sorting requirements

* half half install hack

* remove bc

* fixing the sql model bugs

* added an extra comment on dependencies

* formatting examples

* Auto-update of Starter template

* Auto-update of NLP template

---------

Co-authored-by: Alex Strick van Linschoten <strickvl@users.noreply.github.com>
Co-authored-by: Alex Strick van Linschoten <stricksubscriptions@fastmail.fm>
Co-authored-by: GitHub Actions <actions@github.com>
  • Loading branch information
4 people committed Mar 6, 2024
1 parent 8e13b42 commit e93c687
Show file tree
Hide file tree
Showing 16 changed files with 85 additions and 55 deletions.
6 changes: 2 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -60,16 +60,15 @@ python = ">=3.8,<3.12"
python-dateutil = "^2.8.1"
pyyaml = ">=6.0.1"
rich = { extras = ["jupyter"], version = ">=12.0.0" }
sqlalchemy_utils = "0.38.3"
sqlmodel = "0.0.8"
sqlalchemy_utils = "0.41.1"
sqlmodel = ">=0.0.9, <=0.0.16"
importlib_metadata = { version = "<=7.0.0", python = "<3.10" }

# Optional dependencies for the ZenServer
fastapi = { version = ">=0.75,<0.100", optional = true }
uvicorn = { extras = ["standard"], version = ">=0.17.5", optional = true }
python-multipart = { version = "~0.0.5", optional = true }
pyjwt = { extras = ["crypto"], version = "2.7.*", optional = true }
fastapi-utils = { version = "~0.2.1", optional = true }
orjson = { version = "~3.8.3", optional = true }
Jinja2 = { version = "*", optional = true }
ipinfo = { version = ">=4.4.3", optional = true }
Expand Down Expand Up @@ -441,7 +440,6 @@ module = [
"bentoml.*",
"multipart.*",
"jose.*",
"fastapi_utils.*",
"sqlalchemy_utils.*",
"sky.*",
"copier.*",
Expand Down
11 changes: 9 additions & 2 deletions src/zenml/integrations/airflow/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
orchestrator. You can enable it by registering the Airflow orchestrator with
the CLI tool, then bootstrap using the ``zenml orchestrator up`` command.
"""
from typing import List, Optional, Type
from typing import List, Type

from zenml.integrations.constants import AIRFLOW
from zenml.integrations.integration import Integration
Expand All @@ -32,7 +32,14 @@ class AirflowIntegration(Integration):
NAME = AIRFLOW
# remove pendulum version requirement once Airflow supports
# pendulum>-3.0.0
REQUIREMENTS = ["apache-airflow~=2.4.0", "pendulum<3.0.0"]
REQUIREMENTS = [
"apache-airflow~=2.4.0",
"pendulum<3.0.0",
# We need to add this as an extra dependency to manually downgrade
# SQLModel. Otherwise, the initial installation of ZenML installs
# a higher version SQLModel and a version mismatch is created.
"sqlmodel>=0.0.9,<=0.0.16",
]

@classmethod
def flavors(cls) -> List[Type[Flavor]]:
Expand Down
8 changes: 7 additions & 1 deletion src/zenml/integrations/evidently/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,13 @@ class EvidentlyIntegration(Integration):
"""[Evidently](https://github.com/evidentlyai/evidently) integration for ZenML."""

NAME = EVIDENTLY
REQUIREMENTS = ["evidently>0.2.6,<0.4.5"] # supports pyyaml 6
REQUIREMENTS = [
"evidently>0.2.6,<0.4.5", # supports pyyaml 6
# We need to add this as an extra dependency to manually downgrade
# SQLModel. Otherwise, the initial installation of ZenML installs
# a higher version SQLModel and a version mismatch is created.
"sqlmodel>=0.0.9,<=0.0.16"
]

@classmethod
def flavors(cls) -> List[Type[Flavor]]:
Expand Down
4 changes: 4 additions & 0 deletions src/zenml/integrations/great_expectations/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,10 @@ class GreatExpectationsIntegration(Integration):
"great-expectations>=0.15.0,<=0.15.47",
# typing_extensions 4.6.0 and above doesn't work with GE
"typing_extensions<4.6.0",
# We need to add this as an extra dependency to manually downgrade
# SQLModel. Otherwise, the initial installation of ZenML installs
# a higher version SQLModel and a version mismatch is created.
"sqlmodel>=0.0.9,<=0.0.16",
]

@staticmethod
Expand Down
1 change: 0 additions & 1 deletion src/zenml/zen_server/deploy/local/local_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,6 @@ def check_local_server_dependencies() -> None:
try:
# Make sure the ZenML Server dependencies are installed
import fastapi # noqa
import fastapi_utils # noqa
import jwt # noqa
import multipart # noqa
import uvicorn # noqa
Expand Down
6 changes: 3 additions & 3 deletions src/zenml/zen_stores/schemas/artifact_schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ class ArtifactVersionSchema(BaseSchema, table=True):
# Fields
version: str
version_number: Optional[int]
type: ArtifactType
type: str
uri: str = Field(sa_column=Column(TEXT, nullable=False))
materializer: str = Field(sa_column=Column(TEXT, nullable=False))
data_type: str = Field(sa_column=Column(TEXT, nullable=False))
Expand Down Expand Up @@ -277,7 +277,7 @@ def from_request(
artifact_store_id=artifact_version_request.artifact_store_id,
workspace_id=artifact_version_request.workspace,
user_id=artifact_version_request.user,
type=artifact_version_request.type,
type=artifact_version_request.type.value,
uri=artifact_version_request.uri,
materializer=artifact_version_request.materializer.json(),
data_type=artifact_version_request.data_type.json(),
Expand Down Expand Up @@ -328,7 +328,7 @@ def to_model(
version=self.version_number or self.version,
user=self.user.to_model() if self.user else None,
uri=self.uri,
type=self.type,
type=ArtifactType(self.type),
materializer=materializer,
data_type=data_type,
created=self.created,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ class ArtifactVisualizationSchema(BaseSchema, table=True):
__tablename__ = "artifact_visualization"

# Fields
type: VisualizationType
type: str
uri: str = Field(sa_column=Column(TEXT, nullable=False))

# Foreign Keys
Expand Down Expand Up @@ -71,7 +71,7 @@ def from_model(
The `ArtifactVisualizationSchema`.
"""
return cls(
type=artifact_visualization_request.type,
type=artifact_visualization_request.type.value,
uri=artifact_visualization_request.uri,
artifact_version_id=artifact_version_id,
)
Expand All @@ -95,7 +95,7 @@ def to_model(
The `Visualization`.
"""
body = ArtifactVisualizationResponseBody(
type=self.type,
type=VisualizationType(self.type),
uri=self.uri,
created=self.created,
updated=self.updated,
Expand Down
6 changes: 4 additions & 2 deletions src/zenml/zen_stores/schemas/component_schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ class StackComponentSchema(NamedSchema, table=True):

__tablename__ = "stack_component"

type: StackComponentType
type: str
flavor: str
configuration: bytes
labels: Optional[bytes]
Expand Down Expand Up @@ -127,6 +127,8 @@ def update(
self.labels = base64.b64encode(
json.dumps(component_update.labels).encode("utf-8")
)
elif field == "type":
self.type = component_update.type.value
else:
setattr(self, field, value)

Expand All @@ -151,7 +153,7 @@ def to_model(
A `ComponentModel`
"""
body = ComponentResponseBody(
type=self.type,
type=StackComponentType(self.type),
flavor=self.flavor,
user=self.user.to_model() if self.user else None,
created=self.created,
Expand Down
10 changes: 5 additions & 5 deletions src/zenml/zen_stores/schemas/device_schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ class OAuthDeviceSchema(BaseSchema, table=True):
client_id: UUID
user_code: str
device_code: str
status: OAuthDeviceStatus
status: str
failed_auth_attempts: int = 0
expires: Optional[datetime] = None
last_login: Optional[datetime] = None
Expand Down Expand Up @@ -121,7 +121,7 @@ def from_request(
client_id=request.client_id,
user_code=hashed_user_code,
device_code=hashed_device_code,
status=OAuthDeviceStatus.PENDING,
status=OAuthDeviceStatus.PENDING.value,
failed_auth_attempts=0,
expires=now + timedelta(seconds=request.expires_in),
os=request.os,
Expand Down Expand Up @@ -153,9 +153,9 @@ def update(self, device_update: OAuthDeviceUpdate) -> "OAuthDeviceSchema":
setattr(self, field, value)

if device_update.locked is True:
self.status = OAuthDeviceStatus.LOCKED
self.status = OAuthDeviceStatus.LOCKED.value
elif device_update.locked is False:
self.status = OAuthDeviceStatus.ACTIVE
self.status = OAuthDeviceStatus.ACTIVE.value

self.updated = datetime.utcnow()
return self
Expand Down Expand Up @@ -233,7 +233,7 @@ def to_model(
client_id=self.client_id,
expires=self.expires,
trusted_device=self.trusted_device,
status=self.status,
status=OAuthDeviceStatus(self.status),
os=self.os,
ip_address=self.ip_address,
hostname=self.hostname,
Expand Down
6 changes: 4 additions & 2 deletions src/zenml/zen_stores/schemas/flavor_schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ class FlavorSchema(NamedSchema, table=True):

__tablename__ = "flavor"

type: StackComponentType
type: str
source: str
config_schema: str = Field(sa_column=Column(TEXT, nullable=False))
integration: Optional[str] = Field(default="")
Expand Down Expand Up @@ -98,6 +98,8 @@ def update(self, flavor_update: "FlavorUpdate") -> "FlavorSchema":
).items():
if field == "config_schema":
setattr(self, field, json.dumps(value))
elif field == "type":
setattr(self, field, value.value)
else:
setattr(self, field, value)

Expand All @@ -123,7 +125,7 @@ def to_model(
"""
body = FlavorResponseBody(
user=self.user.to_model() if self.user else None,
type=self.type,
type=StackComponentType(self.type),
integration=self.integration,
logo_url=self.logo_url,
created=self.created,
Expand Down
10 changes: 5 additions & 5 deletions src/zenml/zen_stores/schemas/pipeline_run_schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ class PipelineRunSchema(NamedSchema, table=True):
orchestrator_run_id: Optional[str] = Field(nullable=True)
start_time: Optional[datetime] = Field(nullable=True)
end_time: Optional[datetime] = Field(nullable=True, default=None)
status: ExecutionStatus = Field(nullable=False)
status: str = Field(nullable=False)
orchestrator_environment: Optional[str] = Field(
sa_column=Column(TEXT, nullable=True)
)
Expand Down Expand Up @@ -203,7 +203,7 @@ def from_request(
orchestrator_run_id=request.orchestrator_run_id,
orchestrator_environment=orchestrator_environment,
start_time=request.start_time,
status=request.status,
status=request.status.value,
pipeline_id=request.pipeline,
deployment_id=request.deployment,
trigger_execution_id=request.trigger_execution_id,
Expand Down Expand Up @@ -277,7 +277,7 @@ def to_model(

body = PipelineRunResponseBody(
user=self.user.to_model() if self.user else None,
status=self.status,
status=ExecutionStatus(self.status),
stack=stack,
pipeline=pipeline,
build=build,
Expand Down Expand Up @@ -322,7 +322,7 @@ def update(self, run_update: "PipelineRunUpdate") -> "PipelineRunSchema":
The updated `PipelineRunSchema`.
"""
if run_update.status:
self.status = run_update.status
self.status = run_update.status.value
self.end_time = run_update.end_time

self.updated = datetime.utcnow()
Expand Down Expand Up @@ -367,7 +367,7 @@ def update_placeholder(

self.orchestrator_run_id = request.orchestrator_run_id
self.orchestrator_environment = orchestrator_environment
self.status = request.status
self.status = request.status.value

self.updated = datetime.utcnow()

Expand Down
4 changes: 2 additions & 2 deletions src/zenml/zen_stores/schemas/run_metadata_schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ class RunMetadataSchema(BaseSchema, table=True):

key: str
value: str = Field(sa_column=Column(TEXT, nullable=False))
type: MetadataTypeEnum
type: str

def to_model(
self,
Expand All @@ -134,7 +134,7 @@ def to_model(
created=self.created,
updated=self.updated,
value=json.loads(self.value),
type=self.type,
type=MetadataTypeEnum(self.type),
)
metadata = None
if include_metadata:
Expand Down
11 changes: 7 additions & 4 deletions src/zenml/zen_stores/schemas/secret_schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ class SecretSchema(NamedSchema, table=True):

__tablename__ = "secret"

scope: SecretScope
scope: str

values: Optional[bytes] = Field(sa_column=Column(TEXT, nullable=True))

Expand Down Expand Up @@ -177,7 +177,7 @@ def from_request(
assert secret.user is not None, "User must be set for secret creation."
return cls(
name=secret.name,
scope=secret.scope,
scope=secret.scope.value,
workspace_id=secret.workspace,
user_id=secret.user,
# Don't store secret values implicitly in the secret. The
Expand All @@ -204,7 +204,10 @@ def update(
for field, value in secret_update.dict(
exclude_unset=True, exclude={"workspace", "user", "values"}
).items():
setattr(self, field, value)
if field == "scope":
setattr(self, field, value.value)
else:
setattr(self, field, value)

self.updated = datetime.utcnow()
return self
Expand Down Expand Up @@ -239,7 +242,7 @@ def to_model(
user=self.user.to_model() if self.user else None,
created=self.created,
updated=self.updated,
scope=self.scope,
scope=SecretScope(self.scope),
)
return SecretResponse(
id=self.id,
Expand Down
14 changes: 6 additions & 8 deletions src/zenml/zen_stores/schemas/step_run_schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,6 @@
from zenml.enums import (
ExecutionStatus,
MetadataResourceTypes,
StepRunInputArtifactType,
StepRunOutputArtifactType,
)
from zenml.models import (
StepRunRequest,
Expand Down Expand Up @@ -60,7 +58,7 @@ class StepRunSchema(NamedSchema, table=True):
# Fields
start_time: Optional[datetime] = Field(nullable=True)
end_time: Optional[datetime] = Field(nullable=True)
status: ExecutionStatus = Field(nullable=False)
status: str = Field(nullable=False)

docstring: Optional[str] = Field(sa_column=Column(TEXT, nullable=True))
cache_key: Optional[str] = Field(nullable=True)
Expand Down Expand Up @@ -165,7 +163,7 @@ def from_request(cls, request: StepRunRequest) -> "StepRunSchema":
user_id=request.user,
start_time=request.start_time,
end_time=request.end_time,
status=request.status,
status=request.status.value,
original_step_run_id=request.original_step_run_id,
pipeline_run_id=request.pipeline_run_id,
deployment_id=request.deployment,
Expand Down Expand Up @@ -225,7 +223,7 @@ def to_model(

body = StepRunResponseBody(
user=self.user.to_model() if self.user else None,
status=self.status,
status=ExecutionStatus(self.status),
inputs=input_artifacts,
outputs=output_artifacts,
created=self.created,
Expand Down Expand Up @@ -270,7 +268,7 @@ def update(self, step_update: "StepRunUpdate") -> "StepRunSchema":
exclude_unset=True, exclude_none=True
).items():
if key == "status":
self.status = value
self.status = value.value
if key == "end_time":
self.end_time = value

Expand Down Expand Up @@ -312,7 +310,7 @@ class StepRunInputArtifactSchema(SQLModel, table=True):

# Fields
name: str = Field(nullable=False, primary_key=True)
type: StepRunInputArtifactType
type: str

# Foreign keys
step_id: UUID = build_foreign_key_field(
Expand Down Expand Up @@ -348,7 +346,7 @@ class StepRunOutputArtifactSchema(SQLModel, table=True):

# Fields
name: str
type: StepRunOutputArtifactType
type: str

# Foreign keys
step_id: UUID = build_foreign_key_field(
Expand Down
Loading

0 comments on commit e93c687

Please sign in to comment.