diff --git a/docs/book/deploying-zenml/zenml-self-hosted/deploy-with-docker.md b/docs/book/deploying-zenml/zenml-self-hosted/deploy-with-docker.md index cabd1b97f24..07148bc2eb7 100644 --- a/docs/book/deploying-zenml/zenml-self-hosted/deploy-with-docker.md +++ b/docs/book/deploying-zenml/zenml-self-hosted/deploy-with-docker.md @@ -27,7 +27,7 @@ The following environment variables can be passed to the container: * **ZENML\_DEFAULT\_PROJECT\_NAME**: The name of the default project created by the server on the first deployment, during database initialization. Defaults to `default`. * **ZENML\_DEFAULT\_USER\_NAME**: The name of the default admin user account created by the server on the first deployment, during database initialization. Defaults to `default`. * **ZENML\_DEFAULT\_USER\_PASSWORD**: The password to use for the default admin user account. Defaults to an empty password value, if not set. -* **ZENML\_STORE\_URL**: This URL should point to an SQLite database file _mounted in the container_, or to a MySQL-compatible database service _reachable from the container_. It takes one of these forms: +* **ZENML\_STORE\_URL**: This URL should point to an SQLite database file _mounted in the container_, or to a MySQL-compatible database service _reachable from the container_. It takes one of these forms: ``` sqlite:////path/to/zenml.db @@ -43,6 +43,7 @@ The following environment variables can be passed to the container: * **ZENML\_STORE\_SSL\_KEY**: This can be set to a client SSL private key required to connect to the MySQL database service. Only valid when `ZENML_STORE_URL` points to a MySQL database that uses SSL-secured connections and requires client SSL certificates. The variable can be set either to the path where the certificate file is mounted inside the container or to the certificate contents themselves. This variable also requires `ZENML_STORE_SSL_CERT` to be set. * **ZENML\_STORE\_SSL\_VERIFY\_SERVER\_CERT**: This boolean variable controls whether the SSL certificate in use by the MySQL server is verified. Only valid when `ZENML_STORE_URL` points to a MySQL database that uses SSL-secured connections. Defaults to `False`. * **ZENML\_LOGGING\_VERBOSITY**: Use this variable to control the verbosity of logs inside the container. It can be set to one of the following values: `NOTSET`, `ERROR`, `WARN`, `INFO` (default), `DEBUG` or `CRITICAL`. +* **ZENML\_STORE\_BACKUP\_STRATEGY**: This variable controls the database backup strategy used by the ZenML server. See the [Database backup and recovery](#database-backup-and-recovery) section for more details about this feature and other related environment variables. Defaults to `in-memory`. If none of the `ZENML_STORE_*` variables are set, the container will default to creating and using an SQLite database file stored at `/zenml/.zenconfig/local_stores/default_zen_store/zenml.db` inside the container. The `/zenml/.zenconfig/local_stores` base path where the default SQLite database is located can optionally be overridden by setting the `ZENML_LOCAL_STORES_PATH` environment variable to point to a different path (e.g. a persistent volume or directory that is mounted from the host). @@ -425,6 +426,47 @@ Tearing down the installation is as simple as running: docker-compose -p zenml down ``` + +## Database backup and recovery + +An automated database backup and recovery feature is enabled by default for all Docker deployments. 
The ZenML server will automatically back up the database in-memory before every database schema migration and restore it if the migration fails. + +{% hint style="info" %} +The database backup automatically created by the ZenML server is only temporary and only used as an immediate recovery in case of database migration failures. It is not meant to be used as a long-term backup solution. If you need to back up your database for long-term storage, you should use a dedicated backup solution. +{% endhint %} + +Several database backup strategies are supported, depending on where and how the backup is stored. The strategy can be configured by means of the `ZENML_STORE_BACKUP_STRATEGY` environment variable: + +* `disabled` - no backup is performed +* `in-memory` - the database schema and data are stored in memory. This is the fastest backup strategy, but the backup is not persisted across container restarts, so no manual intervention is possible in case the automatic DB recovery fails after a failed DB migration. Adequate memory resources should be allocated to the ZenML server container when using this backup strategy with larger databases. This is the default backup strategy. +* `database` - the database is copied to a backup database in the same database server. This requires the `ZENML_STORE_BACKUP_DATABASE` environment variable to be set to the name of the backup database. This backup strategy is only supported for MySQL compatible databases and the user specified in the database URL must have permissions to manage (create, drop, and modify) the backup database in addition to the main database. +* `dump-file` - the database schema and data are dumped to a filesystem location inside the ZenML server container. This location can be customized by means of the `ZENML_STORE_BACKUP_DIRECTORY` environment variable. When this strategy is configured, users should mount a host directory in the container and point the `ZENML_STORE_BACKUP_DIRECTORY` variable to where it's mounted inside the container. If a host directory is not mounted, the dump file will be stored in the container's filesystem and will be lost when the container is removed. + +The following additional rules are applied concerning the creation and lifetime of the backup: + +* a backup is not attempted if the database doesn't need to undergo a migration (e.g. when the ZenML server is upgraded to a new version that doesn't require a database schema change or if the ZenML version doesn't change at all). +* a backup file or database is created before every database migration attempt (i.e. when the container starts). If a backup already exists (i.e. persisted in a mounted host directory or backup database), it is overwritten. +* the persistent backup file or database is cleaned up after the migration is completed successfully or if the database doesn't need to undergo a migration. This includes backups created by previous failed migration attempts. +* the persistent backup file or database is NOT cleaned up after a failed migration. This allows the user to manually inspect and/or apply the backup if the automatic recovery fails. 
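+
+The backup and restore logic described above is also exposed through two hidden CLI commands, `zenml backup-database` and `zenml restore-database`, introduced in this change. As a rough, non-authoritative sketch (not a stable public API), a manual backup can be driven through the same store interface the CLI command uses, assuming the client is configured with a direct SQL store connection; the dump file path below is a placeholder:
+
+```python
+# Hedged sketch mirroring the hidden `zenml backup-database` CLI command.
+from zenml.config.global_config import GlobalConfiguration
+from zenml.enums import DatabaseBackupStrategy, StoreType
+from zenml.zen_stores.base_zen_store import BaseZenStore
+from zenml.zen_stores.sql_zen_store import SqlZenStore
+
+store_config = (
+    GlobalConfiguration().store or GlobalConfiguration().get_default_store()
+)
+if store_config.type == StoreType.SQL:
+    store = BaseZenStore.create_store(
+        store_config, skip_default_registrations=True, skip_migrations=True
+    )
+    assert isinstance(store, SqlZenStore)
+    msg, _location = store.backup_database(
+        strategy=DatabaseBackupStrategy.DUMP_FILE,
+        location="/tmp/zenml-db-backup.json",  # placeholder path
+        overwrite=True,
+    )
+    print(f"Database was backed up to {msg}.")
+```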
+ +The following example shows how to deploy the ZenML server to use a mounted host directory to persist the database backup file during a database migration: + +```shell +mkdir mysql-data + +docker run --name mysql -d -p 3306:3306 -e MYSQL_ROOT_PASSWORD=password \ + --mount type=bind,source=$PWD/mysql-data,target=/var/lib/mysql \ + mysql:8.0 + +docker run -it -d -p 8080:8080 --name zenml \ + --add-host host.docker.internal:host-gateway \ + --mount type=bind,source=$PWD/mysql-data,target=/db-dump \ + --env ZENML_STORE_URL=mysql://root:password@host.docker.internal/zenml \ + --env ZENML_STORE_BACKUP_STRATEGY=dump-file \ + --env ZENML_STORE_BACKUP_DIRECTORY=/db-dump \ + zenmldocker/zenml-server +``` + ## Troubleshooting You can check the logs of the container to verify if the server is up and, depending on where you have deployed it, you can also access the dashboard at a `localhost` port (if running locally) or through some other service that exposes your container to the internet. diff --git a/docs/book/deploying-zenml/zenml-self-hosted/deploy-with-helm.md b/docs/book/deploying-zenml/zenml-self-hosted/deploy-with-helm.md index a52ce0a5f52..a00c33e2483 100644 --- a/docs/book/deploying-zenml/zenml-self-hosted/deploy-with-helm.md +++ b/docs/book/deploying-zenml/zenml-self-hosted/deploy-with-helm.md @@ -669,5 +669,47 @@ To configure a backup secrets store in the Helm chart, use the same approach and aws_secret_access_key: ``` +### Database backup and recovery + +An automated database backup and recovery feature is enabled by default for all Helm deployments. The ZenML server will automatically back up the database before every upgrade and restore it if the upgrade fails in a way that affects the database. + +{% hint style="info" %} +The database backup automatically created by the ZenML server is only temporary and only used as an immediate recovery in case of database migration failures. It is not meant to be used as a long-term backup solution. If you need to back up your database for long-term storage, you should use a dedicated backup solution. +{% endhint %} + +Several database backup strategies are supported, depending on where and how the backup is stored. The strategy can be configured by means of the `zenml.database.backupStrategy` Helm value: + +* `disabled` - no backup is performed +* `in-memory` - the database schema and data are stored in memory. This is the fastest backup strategy, but the backup is not persisted across pod restarts, so no manual intervention is possible in case the automatic DB recovery fails after a failed DB migration. Adequate memory resources should be allocated to the ZenML server pod when using this backup strategy with larger databases. This is the default backup strategy. +* `database` - the database is copied to a backup database in the same database server. This requires the `backupDatabase` option to be set to the name of the backup database. This backup strategy is only supported for MySQL compatible databases and the user specified in the database URL must have permissions to manage (create, drop, and modify) the backup database in addition to the main database. +* `dump-file` - the database schema and data are dumped to a file local to the database initialization and upgrade job. Users may optionally configure a persistent volume where the dump file will be stored by setting the `backupPVStorageSize` and optionally the `backupPVStorageClass` options. 
If a persistent volume is not configured, the dump file will be stored in an emptyDir volume, which is not persisted. If configured, the user is responsible for deleting the resulting PVC when uninstalling the Helm release. + +> **NOTE:** You should also set the `podSecurityContext.fsGroup` option if you are using a persistent volume to store the dump file. + +The following additional rules are applied concerning the creation and lifetime of the backup: + +* a backup is not attempted if the database doesn't need to undergo a migration (e.g. when the ZenML server is upgraded to a new version that doesn't require a database schema change or if the ZenML version doesn't change at all). +* a backup file or database is created before every database migration attempt (i.e. during every Helm upgrade). If a backup already exists (i.e. persisted in a persistent volume or backup database), it is overwritten. +* the persistent backup file or database is cleaned up after the migration is completed successfully or if the database doesn't need to undergo a migration. This includes backups created by previous failed migration attempts. +* the persistent backup file or database is NOT cleaned up after a failed migration. This allows the user to manually inspect and/or apply the backup if the automatic recovery fails. + +The following example shows how to configure the ZenML server to use a persistent volume to store the database dump file: + +```yaml + zenml: + + # ... + + database: + url: "mysql://admin:password@my.database.org:3306/zenml" + + # Configure the database backup strategy + backupStrategy: dump-file + backupPVStorageSize: 1Gi + +podSecurityContext: + fsGroup: 1000 # if you're using a PVC for backup, this should necessarily be set. +``` +
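+
+If the automatic recovery fails, the dump file persisted on the volume can still be applied by hand. The following is a hedged sketch of the counterpart to the hidden `zenml restore-database` CLI command added in this change, run from an environment that can reach both the database and the backup volume; the dump file path is a placeholder:
+
+```python
+# Hedged sketch mirroring the hidden `zenml restore-database` CLI command.
+from zenml.config.global_config import GlobalConfiguration
+from zenml.enums import DatabaseBackupStrategy, StoreType
+from zenml.zen_stores.base_zen_store import BaseZenStore
+from zenml.zen_stores.sql_zen_store import SqlZenStore
+
+store_config = (
+    GlobalConfiguration().store or GlobalConfiguration().get_default_store()
+)
+if store_config.type == StoreType.SQL:
+    store = BaseZenStore.create_store(
+        store_config, skip_default_registrations=True, skip_migrations=True
+    )
+    assert isinstance(store, SqlZenStore)
+    store.restore_database(
+        strategy=DatabaseBackupStrategy.DUMP_FILE,
+        location="/backups/zenml-backup.json",  # placeholder path
+    )
+```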
diff --git a/scripts/format.sh b/scripts/format.sh index 5da17a625c0..8aa9d917ab3 100755 --- a/scripts/format.sh +++ b/scripts/format.sh @@ -1,4 +1,4 @@ -#!/bin/sh -e +#!/usr/bin/env bash set -x # Initialize default source directories diff --git a/src/zenml/cli/base.py b/src/zenml/cli/base.py index c05c47bd83e..e03e89169d9 100644 --- a/src/zenml/cli/base.py +++ b/src/zenml/cli/base.py @@ -36,7 +36,7 @@ ENV_ZENML_ENABLE_REPO_INIT_WARNINGS, REPOSITORY_DIRECTORY_NAME, ) -from zenml.enums import AnalyticsEventSource, StoreType +from zenml.enums import AnalyticsEventSource, DatabaseBackupStrategy, StoreType from zenml.environment import Environment, get_environment from zenml.exceptions import GitNotFoundError, InitializationException from zenml.integrations.registry import integration_registry @@ -659,3 +659,135 @@ def migrate_database(skip_default_registrations: bool = False) -> None: cli_utils.warning( "Unable to migrate database while connected to a ZenML server." ) + + +@cli.command("backup-database", help="Create a database backup.", hidden=True) +@click.option( + "--strategy", + "-s", + help="Custom backup strategy to use. Defaults to whatever is configured " + "in the store config.", + type=click.Choice(choices=DatabaseBackupStrategy.values()), + required=False, + default=None, +) +@click.option( + "--location", + default=None, + help="Custom location to store the backup. Defaults to whatever is " + "configured in the store config. Depending on the strategy, this can be " + "a local path or a database name.", + type=str, +) +@click.option( + "--overwrite", + "-o", + is_flag=True, + default=False, + help="Overwrite the existing backup.", + type=bool, +) +def backup_database( + strategy: Optional[str] = None, + location: Optional[str] = None, + overwrite: bool = False, +) -> None: + """Backup the ZenML database. + + Args: + strategy: Custom backup strategy to use. Defaults to whatever is + configured in the store config. + location: Custom location to store the backup. Defaults to whatever is + configured in the store config. Depending on the strategy, this can + be a local path or a database name. + overwrite: Whether to overwrite the existing backup. + """ + from zenml.zen_stores.base_zen_store import BaseZenStore + from zenml.zen_stores.sql_zen_store import SqlZenStore + + store_config = ( + GlobalConfiguration().store + or GlobalConfiguration().get_default_store() + ) + if store_config.type == StoreType.SQL: + store = BaseZenStore.create_store( + store_config, skip_default_registrations=True, skip_migrations=True + ) + assert isinstance(store, SqlZenStore) + msg, location = store.backup_database( + strategy=DatabaseBackupStrategy(strategy) if strategy else None, + location=location, + overwrite=overwrite, + ) + cli_utils.declare(f"Database was backed up to {msg}.") + else: + cli_utils.warning( + "Cannot backup database while connected to a ZenML server." + ) + + +@cli.command( + "restore-database", help="Restore the database from a backup.", hidden=True +) +@click.option( + "--strategy", + "-s", + help="Custom backup strategy to use. Defaults to whatever is configured " + "in the store config.", + type=click.Choice(choices=DatabaseBackupStrategy.values()), + required=False, + default=None, +) +@click.option( + "--location", + default=None, + help="Custom location where the backup is stored. Defaults to whatever is " + "configured in the store config. 
Depending on the strategy, this can be " + "a local path or a database name.", + type=str, +) +@click.option( + "--cleanup", + "-c", + is_flag=True, + default=False, + help="Cleanup the backup after restoring.", + type=bool, +) +def restore_database( + strategy: Optional[str] = None, + location: Optional[str] = None, + cleanup: bool = False, +) -> None: + """Restore the ZenML database. + + Args: + strategy: Custom backup strategy to use. Defaults to whatever is + configured in the store config. + location: Custom location where the backup is stored. Defaults to + whatever is configured in the store config. Depending on the + strategy, this can be a local path or a database name. + cleanup: Whether to cleanup the backup after restoring. + """ + from zenml.zen_stores.base_zen_store import BaseZenStore + from zenml.zen_stores.sql_zen_store import SqlZenStore + + store_config = ( + GlobalConfiguration().store + or GlobalConfiguration().get_default_store() + ) + if store_config.type == StoreType.SQL: + store = BaseZenStore.create_store( + store_config, skip_default_registrations=True, skip_migrations=True + ) + assert isinstance(store, SqlZenStore) + store.restore_database( + strategy=DatabaseBackupStrategy(strategy) if strategy else None, + location=location, + cleanup=cleanup, + ) + cli_utils.declare("Database restore finished.") + else: + cli_utils.warning( + "Cannot restore database while connected to a ZenML server." + ) diff --git a/src/zenml/constants.py b/src/zenml/constants.py index 77b82835a91..f904b9d480f 100644 --- a/src/zenml/constants.py +++ b/src/zenml/constants.py @@ -139,6 +139,9 @@ def handle_int_env_var(var: str, default: int = 0) -> int: # Default store directory subpath: DEFAULT_STORE_DIRECTORY_NAME = "default_zen_store" +# SQL Store backup directory subpath: +SQL_STORE_BACKUP_DIRECTORY_NAME = "database_backup" + DEFAULT_USERNAME = "default" DEFAULT_PASSWORD = "" DEFAULT_WORKSPACE_NAME = "default" diff --git a/src/zenml/enums.py b/src/zenml/enums.py index d2657a4104b..b46e06d4e4c 100644 --- a/src/zenml/enums.py +++ b/src/zenml/enums.py @@ -346,3 +346,16 @@ class MetadataResourceTypes(StrEnum): STEP_RUN = "step_run" ARTIFACT_VERSION = "artifact_version" MODEL_VERSION = "model_version" + + +class DatabaseBackupStrategy(StrEnum): + """All available database backup strategies.""" + + # Backup disabled + DISABLED = "disabled" + # In-memory backup + IN_MEMORY = "in-memory" + # Dump the database to a file + DUMP_FILE = "dump-file" + # Create a backup of the database in the remote database service + DATABASE = "database" diff --git a/src/zenml/zen_server/deploy/helm/templates/server-db-job.yaml b/src/zenml/zen_server/deploy/helm/templates/server-db-job.yaml index 48a8879aa37..cfc6f02f8f0 100644 --- a/src/zenml/zen_server/deploy/helm/templates/server-db-job.yaml +++ b/src/zenml/zen_server/deploy/helm/templates/server-db-job.yaml @@ -10,7 +10,7 @@ metadata: "helm.sh/hook-weight": "-1" "helm.sh/hook-delete-policy": before-hook-creation{{ if not .Values.zenml.debug }},hook-succeeded{{ end }} spec: - backoffLimit: 2 + backoffLimit: 0 template: metadata: annotations: @@ -32,6 +32,20 @@ spec: {{- end }} securityContext: {{- toYaml .Values.podSecurityContext | nindent 8 }} + + {{- if eq .Values.zenml.database.backupStrategy "dump-file" }} + volumes: + # define a volume that will hold a backup of the database + - name: db-backup + # if a storage PVC is configured, then use it + {{- if .Values.zenml.database.backupPVStorageSize }} + persistentVolumeClaim: + claimName: {{ include 
"zenml.fullname" . }}-db-backup + {{- else }} + # otherwise, use an emptyDir + emptyDir: {} + {{- end }} + {{- end }} restartPolicy: Never containers: - name: {{ .Chart.Name }}-db-migration @@ -41,6 +55,11 @@ spec: imagePullPolicy: {{ .Values.zenml.image.pullPolicy }} args: ["migrate-database"] command: ['zenml'] + {{- if eq .Values.zenml.database.backupStrategy "dump-file" }} + volumeMounts: + - name: db-backup + mountPath: /backups + {{- end }} env: {{- if .Values.zenml.debug }} - name: ZENML_LOGGING_VERBOSITY @@ -56,6 +75,17 @@ spec: value: sql - name: ZENML_STORE_SSL_VERIFY_SERVER_CERT value: {{ .Values.zenml.database.sslVerifyServerCert | default "false" | quote }} + {{- if .Values.zenml.database.backupStrategy }} + - name: ZENML_STORE_BACKUP_STRATEGY + value: {{ .Values.zenml.database.backupStrategy | quote }} + {{- if eq .Values.zenml.database.backupStrategy "database" }} + - name: ZENML_STORE_BACKUP_DATABASE + value: {{ .Values.zenml.database.backupDatabase | quote }} + {{- else if eq .Values.zenml.database.backupStrategy "dump-file" }} + - name: ZENML_STORE_BACKUP_DIRECTORY + value: /backups + {{- end }} + {{- end }} {{- range $k, $v := include "zenml.serverEnvVariables" . | fromYaml }} - name: {{ $k }} value: {{ $v | quote }} diff --git a/src/zenml/zen_server/deploy/helm/templates/server-db-pvc.yaml b/src/zenml/zen_server/deploy/helm/templates/server-db-pvc.yaml new file mode 100644 index 00000000000..b9618e7a876 --- /dev/null +++ b/src/zenml/zen_server/deploy/helm/templates/server-db-pvc.yaml @@ -0,0 +1,25 @@ +{{- if and (eq .Values.zenml.database.backupStrategy "dump-file") .Values.zenml.database.backupPVStorageSize }} +{{- $pvc_name := printf "%s-db-backup" (include "zenml.fullname" .) -}} +{{- $pvc := (lookup "v1" "PersistentVolumeClaim" .Release.Namespace $pvc_name) }} +{{- if not $pvc }} +apiVersion: v1 +kind: PersistentVolumeClaim +metadata: + name: {{ $pvc_name }} + labels: + {{- include "zenml.labels" . | nindent 4 }} + annotations: + "helm.sh/hook": "pre-install,pre-upgrade" + "helm.sh/hook-weight": "-1" + "helm.sh/hook-delete-policy": before-hook-creation +spec: + accessModes: + - ReadWriteOnce + resources: + requests: + storage: {{ .Values.zenml.database.backupPVStorageSize }} + {{- with .Values.zenml.database.backupPVStorageClass }} + storageClassName: {{ . }} + {{- end }} +{{- end }} +{{- end }} \ No newline at end of file diff --git a/src/zenml/zen_server/deploy/helm/values.yaml b/src/zenml/zen_server/deploy/helm/values.yaml index 66be69ea146..800dcc6a25b 100644 --- a/src/zenml/zen_server/deploy/helm/values.yaml +++ b/src/zenml/zen_server/deploy/helm/values.yaml @@ -188,13 +188,51 @@ zenml: # used, which will not be persisted across pod restarts. # NOTE: the certificate files need to be copied in the helm chart folder and # the paths configured here need to be relative to the root of the helm chart. - database: {} + database: # url: "mysql://admin:password@zenml-mysql:3306/database" # sslCa: /path/to/ca.pem # sslCert: /path/to/client-cert.pem # sslKey: /path/to/client-key.pem # sslVerifyServerCert: True + # ZenML supports backing up the database before DB migrations are performed + # and restoring it in case of a DB migration failure. For more information, + # see the following documentation: + # https://docs.zenml.io/deploying-zenml/zenml-self-hosted/deploy-with-helm#database-backup-and-recovery + # + # Several backup strategies are supported: + # + # disabled - no backup is performed + # in-memory - the database schema and data are stored in memory. 
This is + # the fastest backup strategy, but the backup is not persisted + # across pod restarts, so no manual intervention is possible + # in case the automatic DB recovery fails after a failed DB + # migration. Adequate memory resources should be allocated to + # the ZenML server pod when using this backup strategy with + # large databases. + # This is the default backup strategy. + # dump-file - the database schema and data are dumped to a local file. + # Users may optionally configure a persistent volume where + # the dump file will be stored by setting the + # `backupPVStorageSize` and optionally the + # `backupPVStorageClass` options. If a + # persistent volume is not configured, the dump file will be + # stored in an emptyDir volume, which is not persisted. + # NOTE: you should set the podSecurityContext.fsGroup option + # if you are using a persistent volume to store the dump file. + # database - the database is copied to a backup database in the same + # database server. This requires the `backupDatabase` + # option to be set to the name of the backup database. + # This backup strategy is only supported for MySQL + # compatible databases and the user specified in the + # database URL must have permissions to manage (create, drop, and + # modify) the backup database in addition to the main + # database. + backupStrategy: in-memory + # backupPVStorageClass: standard + # backupPVStorageSize: 1Gi + # backupDatabase: "zenml_backup" + # Secrets store settings. This is used to store centralized secrets. secretsStore: @@ -787,7 +825,7 @@ serviceAccount: podAnnotations: {} podSecurityContext: {} - # fsGroup: 2000 + # fsGroup: 1000 # if you're using a PVC for backup, this should necessarily be set. securityContext: runAsNonRoot: true diff --git a/src/zenml/zen_stores/migrations/alembic.py b/src/zenml/zen_stores/migrations/alembic.py index e2c2286a325..ace35d3e989 100644 --- a/src/zenml/zen_stores/migrations/alembic.py +++ b/src/zenml/zen_stores/migrations/alembic.py @@ -156,6 +156,27 @@ def run_migrations( with self.environment_context.begin_transaction(): self.environment_context.run_migrations() + def head_revisions(self) -> List[str]: + """Get the head database revisions. + + Returns: + List of head revisions. + """ + head_revisions: List[str] = [] + + def do_get_head_rev(rev: _RevIdType, context: Any) -> List[Any]: + nonlocal head_revisions + + for r in self.script_directory.get_heads(): + if r is None: + continue + head_revisions.append(r) + return [] + + self.run_migrations(do_get_head_rev) + + return head_revisions + def current_revisions(self) -> List[str]: """Get the current database revisions. diff --git a/src/zenml/zen_stores/migrations/utils.py b/src/zenml/zen_stores/migrations/utils.py new file mode 100644 index 00000000000..f1300946ee5 --- /dev/null +++ b/src/zenml/zen_stores/migrations/utils.py @@ -0,0 +1,653 @@ +# Copyright (c) ZenML GmbH 2022. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. 
+"""ZenML database migration, backup and recovery utilities.""" + +import json +import os +import shutil +from typing import ( + Any, + Callable, + Dict, + Generator, + List, + Optional, + cast, +) + +import pymysql +from pydantic import BaseModel +from pydantic.json import pydantic_encoder +from sqlalchemy import MetaData, func, text +from sqlalchemy.engine import URL, Engine +from sqlalchemy.exc import ( + OperationalError, +) +from sqlalchemy.schema import CreateTable +from sqlmodel import ( + create_engine, + select, +) + +from zenml.logger import get_logger + +logger = get_logger(__name__) + + +class MigrationUtils(BaseModel): + """Utilities for database migration, backup and recovery.""" + + url: URL + connect_args: Dict[str, Any] + engine_args: Dict[str, Any] + + _engine: Optional[Engine] = None + _master_engine: Optional[Engine] = None + + def create_engine(self, database: Optional[str] = None) -> Engine: + """Get the SQLAlchemy engine for a database. + + Args: + database: The name of the database. If not set, a master engine + will be returned. + + Returns: + The SQLAlchemy engine. + """ + url = self.url._replace(database=database) + return create_engine( + url=url, + connect_args=self.connect_args, + **self.engine_args, + ) + + @property + def engine(self) -> Engine: + """The SQLAlchemy engine. + + Returns: + The SQLAlchemy engine. + """ + if self._engine is None: + self._engine = self.create_engine(database=self.url.database) + return self._engine + + @property + def master_engine(self) -> Engine: + """The SQLAlchemy engine for the master database. + + Returns: + The SQLAlchemy engine for the master database. + """ + if self._master_engine is None: + self._master_engine = self.create_engine() + return self._master_engine + + @classmethod + def is_mysql_missing_database_error(cls, error: OperationalError) -> bool: + """Checks if the given error is due to a missing database. + + Args: + error: The error to check. + + Returns: + If the error because the MySQL database doesn't exist. + """ + from pymysql.constants.ER import BAD_DB_ERROR + + if not isinstance(error.orig, pymysql.err.OperationalError): + return False + + error_code = cast(int, error.orig.args[0]) + return error_code == BAD_DB_ERROR + + def database_exists( + self, + database: Optional[str] = None, + ) -> bool: + """Check if a database exists. + + Args: + database: The name of the database to check. If not set, the + database name from the configuration will be used. + + Returns: + Whether the database exists. + + Raises: + OperationalError: If connecting to the database failed. + """ + database = database or self.url.database + + engine = self.create_engine(database=database) + try: + engine.connect() + except OperationalError as e: + if self.is_mysql_missing_database_error(e): + return False + else: + logger.exception( + f"Failed to connect to mysql database `{database}`.", + ) + raise + else: + return True + + def drop_database( + self, + database: Optional[str] = None, + ) -> None: + """Drops a mysql database. + + Args: + database: The name of the database to drop. If not set, the + database name from the configuration will be used. + """ + database = database or self.url.database + with self.master_engine.connect() as conn: + # drop the database if it exists + logger.info(f"Dropping database '{database}'") + conn.execute(text(f"DROP DATABASE IF EXISTS `{database}`")) + + def create_database( + self, + database: Optional[str] = None, + drop: bool = False, + ) -> None: + """Creates a mysql database. 
+ + Args: + database: The name of the database to create. If not set, the + database name from the configuration will be used. + drop: Whether to drop the database if it already exists. + """ + database = database or self.url.database + if drop: + self.drop_database(database=database) + + with self.master_engine.connect() as conn: + logger.info(f"Creating database '{database}'") + conn.execute(text(f"CREATE DATABASE IF NOT EXISTS `{database}`")) + + def backup_database_to_storage( + self, store_db_info: Callable[[Dict[str, Any]], None] + ) -> None: + """Backup the database to a storage location. + + Backup the database to an abstract storage location. The storage + location is specified by a function that is called repeatedly to + store the database information. The function is called with a single + argument, which is a dictionary containing either the table schema or + table data. The dictionary contains the following keys: + + * `table`: The name of the table. + * `create_stmt`: The table creation statement. + * `data`: A list of rows in the table. + + Args: + store_db_info: The function to call to store the database + information. + """ + metadata = MetaData() + metadata.reflect(bind=self.engine) + with self.engine.connect() as conn: + for table in metadata.sorted_tables: + # 1. extract the table creation statements + + create_table_construct = CreateTable(table) + create_table_stmt = str(create_table_construct).strip() + for column in create_table_construct.columns: + # enclosing all column names in backticks. This is because + # some column names are reserved keywords in MySQL. For + # example, keys and values. So, instead of tracking all + # keywords, we just enclose all column names in backticks. + # enclose the first word in the column definition in + # backticks + words = str(column).split() + words[0] = f"`{words[0]}`" + create_table_stmt = create_table_stmt.replace( + f"\n\t{str(column)}", " ".join(words) + ) + # if any double quotes are used for column names, replace them + # with backticks + create_table_stmt = create_table_stmt.replace('"', "") + ";" + + # Store the table schema + store_db_info( + dict(table=table.name, create_stmt=create_table_stmt) + ) + + # 2. extract the table data in batches + + # If the table has a `created` column, we use it to sort + # the rows in the table starting with the oldest rows. + # This is to ensure that the rows are inserted in the + # correct order, since some tables have inner foreign key + # constraints. + if "created" in table.columns: + order_by = table.columns["created"] + else: + order_by = None + + # Fetch the number of rows in the table + row_count = conn.scalar( + select([func.count("*")]).select_from(table) + ) + + # Fetch the data from the table in batches + batch_size = 50 + for i in range(0, row_count, batch_size): + rows = conn.execute( + table.select() + .order_by(order_by) + .limit(batch_size) + .offset(i) + ).fetchall() + + store_db_info( + dict( + table=table.name, + data=[row._asdict() for row in rows], + ), + ) + + def restore_database_from_storage( + self, load_db_info: Callable[[], Generator[Dict[str, Any], None, None]] + ) -> None: + """Restore the database from a backup storage location. + + Restores the database from an abstract storage location. The storage + location is specified by a function that is called repeatedly to + load the database information from the external storage chunk by chunk. + The function must yield a dictionary containing either the table schema + or table data. 
The dictionary contains the following keys: + + * `table`: The name of the table. + * `create_stmt`: The table creation statement. + * `data`: A list of rows in the table. + + The function must return `None` when there is no more data to load. + + Args: + load_db_info: The function to call to load the database + information. + """ + # Drop and re-create the primary database + self.create_database( + drop=True, + ) + + metadata = MetaData(bind=self.engine) + + with self.engine.begin() as connection: + # read the DB information one JSON object at a time + for table_dump in load_db_info(): + table_name = table_dump["table"] + if "create_stmt" in table_dump: + # execute the table creation statement + connection.execute(text(table_dump["create_stmt"])) + # Reload the database metadata after creating the table + metadata.reflect() + + if "data" in table_dump: + # insert the data into the database + table = metadata.tables[table_name] + for row in table_dump["data"]: + # Convert column values to the correct type + for column in table.columns: + # Blob columns are stored as binary strings + if column.type.python_type == bytes and isinstance( + row[column.name], str + ): + # Convert the string to bytes + row[column.name] = bytes( + row[column.name], "utf-8" + ) + + # Insert the rows into the table + connection.execute( + table.insert().values(table_dump["data"]) + ) + + def backup_database_to_file(self, dump_file: str) -> None: + """Backup the database to a file. + + This method dumps the entire database into a JSON file. Instead of + using a SQL dump, we use a proprietary JSON dump because: + + * it is (mostly) not dependent on the SQL dialect or database version + * it is safer with respect to SQL injection attacks + * it is easier to read and debug + + The JSON file contains a list of JSON objects instead of a single JSON + object, because it allows for buffered reading and writing of the file + and thus reduces the memory footprint. Each JSON object can contain + either schema or data information about a single table. For tables with + a large amount of data, the data is split into multiple JSON objects + with the first object always containing the schema. + + The format of the dump is as depicted in the following example: + + ```json + { + "table": "table1", + "create_stmt": "CREATE TABLE table1 (id INTEGER NOT NULL, " + "name VARCHAR(255), PRIMARY KEY (id))" + } + { + "table": "table1", + "data": [ + { + "id": 1, + "name": "foo" + }, + { + "id": 1, + "name": "bar" + }, + ... + ] + } + { + "table": "table1", + "data": [ + { + "id": 101, + "name": "fee" + }, + { + "id": 102, + "name": "bee" + }, + ... + ] + } + ``` + + Args: + dump_file: The path to the dump file. + """ + # create the directory if it does not exist + dump_path = os.path.dirname(os.path.abspath(dump_file)) + if not os.path.exists(dump_path): + os.makedirs(dump_path) + + if self.url.drivername == "sqlite": + # For a sqlite database, we can just make a copy of the database + # file + assert self.url.database is not None + shutil.copyfile( + self.url.database, + dump_file, + ) + return + + with open(dump_file, "w") as f: + + def json_dump(obj: Dict[str, Any]) -> None: + """Dump a JSON object to the dump file. + + Args: + obj: The JSON object to dump. + """ + # Write the data to the JSON file. Use an encoder that + # can handle datetime, Decimal and other types. 
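+                # NOTE: `indent=4` also ensures that each top-level closing
+                # brace lands on a line of its own, which the line-based
+                # reader in `restore_database_from_file` relies on
+                # (`chunk.rstrip() == "}"`) to detect object boundaries.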
+ json.dump( + obj, + f, + indent=4, + default=pydantic_encoder, + ) + f.write("\n") + + # Call the generic backup method with a function that dumps the + # JSON objects to the dump file + self.backup_database_to_storage(json_dump) + + logger.debug(f"Database backed up to {dump_file}") + + def restore_database_from_file(self, dump_file: str) -> None: + """Restore the database from a backup dump file. + + See the documentation of the `backup_database_to_file` method for + details on the format of the dump file. + + Args: + dump_file: The path to the dump file. + + Raises: + RuntimeError: If the database cannot be restored successfully. + """ + if not os.path.exists(dump_file): + raise RuntimeError( + f"Database backup file '{dump_file}' does not " + f"exist or is not accessible." + ) + + if self.url.drivername == "sqlite": + # For a sqlite database, we just overwrite the database file + # with the backup file + assert self.url.database is not None + shutil.copyfile( + dump_file, + self.url.database, + ) + return + + # read the DB dump file one JSON object at a time + with open(dump_file, "r") as f: + + def json_load() -> Generator[Dict[str, Any], None, None]: + """Generator that loads the JSON objects in the dump file. + + Yields: + The loaded JSON objects. + """ + buffer = "" + while True: + chunk = f.readline() + if not chunk: + break + buffer += chunk + if chunk.rstrip() == "}": + yield json.loads(buffer) + buffer = "" + + # Call the generic restore method with a function that loads the + # JSON objects from the dump file + self.restore_database_from_storage(json_load) + + logger.info(f"Database successfully restored from '{dump_file}'") + + def backup_database_to_memory(self) -> List[Dict[str, Any]]: + """Backup the database in memory. + + Returns: + The in-memory representation of the database backup. + + Raises: + RuntimeError: If the database cannot be backed up successfully. + """ + if self.url.drivername == "sqlite": + # For a sqlite database, this is not supported. + raise RuntimeError( + "In-memory backup is not supported for sqlite databases." + ) + + db_dump: List[Dict[str, Any]] = [] + + def store_in_mem(obj: Dict[str, Any]) -> None: + """Store a JSON object in the in-memory database backup. + + Args: + obj: The JSON object to store. + """ + db_dump.append(obj) + + # Call the generic backup method with a function that stores the + # JSON objects in the in-memory database backup + self.backup_database_to_storage(store_in_mem) + + logger.debug("Database backed up in memory") + + return db_dump + + def restore_database_from_memory( + self, db_dump: List[Dict[str, Any]] + ) -> None: + """Restore the database from an in-memory backup. + + Args: + db_dump: The in-memory database backup to restore from generated + by the `backup_database_to_memory` method. + + Raises: + RuntimeError: If the database cannot be restored successfully. + """ + if self.url.drivername == "sqlite": + # For a sqlite database, this is not supported. + raise RuntimeError( + "In-memory backup is not supported for sqlite databases." + ) + + def load_from_mem() -> Generator[Dict[str, Any], None, None]: + """Generator that loads the JSON objects from the in-memory backup. + + Yields: + The loaded JSON objects. 
+ """ + for obj in db_dump: + yield obj + + # Call the generic restore method with a function that loads the + # JSON objects from the in-memory database backup + self.restore_database_from_storage(load_from_mem) + + logger.info("Database successfully restored from memory") + + @classmethod + def _copy_database(cls, src_engine: Engine, dst_engine: Engine) -> None: + """Copy the database from one engine to another. + + This method assumes that the destination database exists and is empty. + + Args: + src_engine: The source SQLAlchemy engine. + dst_engine: The destination SQLAlchemy engine. + """ + src_metadata = MetaData(bind=src_engine) + src_metadata.reflect() + + dst_metadata = MetaData(bind=dst_engine) + dst_metadata.reflect() + + # @event.listens_for(src_metadata, "column_reflect") + # def generalize_datatypes(inspector, tablename, column_dict): + # column_dict["type"] = column_dict["type"].as_generic(allow_nulltype=True) + + # Create all tables in the target database + for table in src_metadata.sorted_tables: + table.create(bind=dst_engine) + + # Refresh target metadata after creating the tables + dst_metadata.clear() + dst_metadata.reflect() + + # Copy all data from the source database to the destination database + with src_engine.begin() as src_conn: + with dst_engine.begin() as dst_conn: + for src_table in src_metadata.sorted_tables: + dst_table = dst_metadata.tables[src_table.name] + insert = dst_table.insert() + # If the table has a `created` column, we use it to sort + # the rows in the table starting with the oldest rows. + # This is to ensure that the rows are inserted in the + # correct order, since some tables have inner foreign key + # constraints. + if "created" in src_table.columns: + order_by = src_table.columns["created"] + else: + order_by = None + + row_count = src_conn.scalar( + select([func.count("*")]).select_from(src_table) + ) + + # Copy rows in batches + batch_size = 50 + for i in range(0, row_count, batch_size): + rows = src_conn.execute( + src_table.select() + .order_by(order_by) + .limit(batch_size) + .offset(i) + ).fetchall() + + dst_conn.execute( + insert, [row._asdict() for row in rows] + ) + + def backup_database_to_db(self, backup_db_name: str) -> None: + """Backup the database to a backup database. + + Args: + backup_db_name: Backup database name to backup to. + """ + # Re-create the backup database + self.create_database( + database=backup_db_name, + drop=True, + ) + + backup_engine = self.create_engine(database=backup_db_name) + + self._copy_database(self.engine, backup_engine) + + logger.debug( + f"Database backed up to the `{backup_db_name}` backup database." + ) + + def restore_database_from_db(self, backup_db_name: str) -> None: + """Restore the database from the backup database. + + Args: + backup_db_name: Backup database name to restore from. + + Raises: + RuntimeError: If the backup database does not exist. + """ + if not self.database_exists(database=backup_db_name): + raise RuntimeError( + f"Backup database `{backup_db_name}` does not exist." + ) + + backup_engine = self.create_engine(database=backup_db_name) + + # Drop and re-create the primary database + self.create_database( + drop=True, + ) + + self._copy_database(backup_engine, self.engine) + + logger.debug( + f"Database restored from the `{backup_db_name}` " + "backup database." 
+ ) + + class Config: + """Pydantic configuration class.""" + + # all attributes with leading underscore are private + underscore_attrs_are_private = True diff --git a/src/zenml/zen_stores/sql_zen_store.py b/src/zenml/zen_stores/sql_zen_store.py index f40c5c0b1e9..b2df002a9e8 100644 --- a/src/zenml/zen_stores/sql_zen_store.py +++ b/src/zenml/zen_stores/sql_zen_store.py @@ -39,15 +39,13 @@ ) from uuid import UUID -import pymysql from pydantic import SecretStr, root_validator, validator -from sqlalchemy import asc, desc, func, text +from sqlalchemy import asc, desc, func from sqlalchemy.engine import URL, Engine, make_url from sqlalchemy.exc import ( ArgumentError, IntegrityError, NoResultFound, - OperationalError, ) from sqlalchemy.orm import noload from sqlmodel import ( @@ -75,10 +73,12 @@ ENV_ZENML_DEFAULT_USER_NAME, ENV_ZENML_DEFAULT_USER_PASSWORD, ENV_ZENML_DISABLE_DATABASE_MIGRATION, + SQL_STORE_BACKUP_DIRECTORY_NAME, TEXT_FIELD_MAX_LENGTH, ) from zenml.enums import ( AuthScheme, + DatabaseBackupStrategy, ExecutionStatus, LoggingLevels, ModelStages, @@ -234,9 +234,9 @@ BaseZenStore, ) from zenml.zen_stores.migrations.alembic import ( - ZENML_ALEMBIC_START_REVISION, Alembic, ) +from zenml.zen_stores.migrations.utils import MigrationUtils from zenml.zen_stores.schemas import ( APIKeySchema, ArtifactSchema, @@ -296,24 +296,6 @@ ZENML_SQLITE_DB_FILENAME = "zenml.db" -def _is_mysql_missing_database_error(error: OperationalError) -> bool: - """Checks if the given error is due to a missing database. - - Args: - error: The error to check. - - Returns: - If the error because the MySQL database doesn't exist. - """ - from pymysql.constants.ER import BAD_DB_ERROR - - if not isinstance(error.orig, pymysql.err.OperationalError): - return False - - error_code = cast(int, error.orig.args[0]) - return error_code == BAD_DB_ERROR - - class SQLDatabaseDriver(StrEnum): """SQL database drivers supported by the SQL ZenML store.""" @@ -372,6 +354,14 @@ class SqlZenStoreConfiguration(StoreConfiguration): max_overflow: int = 20 pool_pre_ping: bool = True + backup_strategy: DatabaseBackupStrategy = DatabaseBackupStrategy.IN_MEMORY + # database backup directory + backup_directory: str = os.path.join( + GlobalConfiguration().config_directory, + SQL_STORE_BACKUP_DIRECTORY_NAME, + ) + backup_database: Optional[str] = None + @validator("secrets_store") def validate_secrets_store( cls, secrets_store: Optional[SecretsStoreConfiguration] @@ -416,6 +406,33 @@ def _remove_grpc_attributes(cls, values: Dict[str, Any]) -> Dict[str, Any]: return values + @root_validator + def _validate_backup_strategy( + cls, values: Dict[str, Any] + ) -> Dict[str, Any]: + """Validate the backup strategy. + + Args: + values: All model attribute values. + + Returns: + The model attribute values. + + Raises: + ValueError: If the backup database name is not set when the backup + database is requested. + """ + backup_strategy = values.get("backup_strategy") + if backup_strategy == DatabaseBackupStrategy.DATABASE and ( + not values.get("backup_database") + ): + raise ValueError( + "The `backup_database` attribute must also be set if the " + "backup strategy is set to use a backup database." + ) + + return values + @root_validator def _validate_url(cls, values: Dict[str, Any]) -> Dict[str, Any]: """Validate the SQL URL. 
@@ -628,13 +645,18 @@ def copy_configuration( return config - def get_sqlmodel_config( + def get_sqlalchemy_config( self, - ) -> Tuple[str, Dict[str, Any], Dict[str, Any]]: - """Get the SQLModel engine configuration for the SQL ZenML store. + database: Optional[str] = None, + ) -> Tuple[URL, Dict[str, Any], Dict[str, Any]]: + """Get the SQLAlchemy engine configuration for the SQL ZenML store. + + Args: + database: Custom database name to use. If not set, the database name + from the configuration will be used. Returns: - The URL and connection arguments for the SQLModel engine. + The URL and connection arguments for the SQLAlchemy engine. Raises: NotImplementedError: If the SQL driver is not supported. @@ -656,6 +678,9 @@ def get_sqlmodel_config( assert self.password is not None assert sql_url.host is not None + if not database: + database = self.database + engine_args = { "pool_size": self.pool_size, "max_overflow": self.max_overflow, @@ -666,7 +691,7 @@ def get_sqlmodel_config( drivername="mysql+pymysql", username=self.username, password=self.password, - database=self.database, + database=database, ) sqlalchemy_ssl_args: Dict[str, Any] = {} @@ -691,7 +716,7 @@ def get_sqlmodel_config( f"SQL driver `{sql_url.drivername}` is not supported." ) - return str(sql_url), sqlalchemy_connect_args, engine_args + return sql_url, sqlalchemy_connect_args, engine_args class Config: """Pydantic configuration class.""" @@ -721,6 +746,7 @@ class SqlZenStore(BaseZenStore): CONFIG_TYPE: ClassVar[Type[StoreConfiguration]] = SqlZenStoreConfiguration _engine: Optional[Engine] = None + _migration_utils: Optional[MigrationUtils] = None _alembic: Optional[Alembic] = None _secrets_store: Optional[BaseSecretsStore] = None _backup_secrets_store: Optional[BaseSecretsStore] = None @@ -766,6 +792,20 @@ def engine(self) -> Engine: raise ValueError("Store not initialized") return self._engine + @property + def migration_utils(self) -> MigrationUtils: + """The migration utils. + + Returns: + The migration utils. + + Raises: + ValueError: If the store is not initialized. + """ + if not self._migration_utils: + raise ValueError("Store not initialized") + return self._migration_utils + @property def alembic(self) -> Alembic: """The Alembic wrapper. @@ -914,17 +954,18 @@ def filter_and_paginate( # -------------------------------- def _initialize(self) -> None: - """Initialize the SQL store. - - Raises: - OperationalError: If connecting to the database failed. - """ + """Initialize the SQL store.""" logger.debug("Initializing SqlZenStore at %s", self.config.url) - url, connect_args, engine_args = self.config.get_sqlmodel_config() + url, connect_args, engine_args = self.config.get_sqlalchemy_config() self._engine = create_engine( url=url, connect_args=connect_args, **engine_args ) + self._migration_utils = MigrationUtils( + url=url, + connect_args=connect_args, + engine_args=engine_args, + ) # SQLite: As long as the parent directory exists, SQLAlchemy will # automatically create the database. 
@@ -943,24 +984,11 @@ def _initialize(self) -> None: self.config.driver == SQLDatabaseDriver.MYSQL and self.config.database ): - try: - self._engine.connect() - except OperationalError as e: - logger.debug( - "Failed to connect to mysql database `%s`.", - self._engine.url.database, - ) - - if _is_mysql_missing_database_error(e): - self._create_mysql_database( - url=self._engine.url, - connect_args=connect_args, - engine_args=engine_args, - ) - else: - raise + if not self.migration_utils.database_exists(): + self.migration_utils.create_database() self._alembic = Alembic(self.engine) + if ( not self.skip_migrations and ENV_ZENML_DISABLE_DATABASE_MIGRATION not in os.environ @@ -1014,34 +1042,219 @@ def _initialize_database(self) -> None: if config.auth_scheme != AuthScheme.EXTERNAL: self._get_or_create_default_user() - def _create_mysql_database( + def _get_db_backup_file_path(self) -> str: + """Get the path to the database backup file. + + Returns: + The path to the configured database backup file. + """ + if self.config.driver == SQLDatabaseDriver.SQLITE: + return os.path.join( + self.config.backup_directory, + # Add the -backup suffix to the database filename + ZENML_SQLITE_DB_FILENAME[:-3] + "-backup.db", + ) + + # For a MySQL database, we need to dump the database to a JSON + # file + return os.path.join( + self.config.backup_directory, + f"{self.engine.url.database}-backup.json", + ) + + def backup_database( self, - url: URL, - connect_args: Dict[str, Any], - engine_args: Dict[str, Any], + strategy: Optional[DatabaseBackupStrategy] = None, + location: Optional[str] = None, + overwrite: bool = False, + ) -> Tuple[str, Any]: + """Backup the database. + + Args: + strategy: Custom backup strategy to use. If not set, the backup + strategy from the store configuration will be used. + location: Custom target location to backup the database to. If not + set, the configured backup location will be used. Depending on + the backup strategy, this can be a file path or a database name. + overwrite: Whether to overwrite an existing backup if it exists. + If set to False, the existing backup will be reused. + + Returns: + The location where the database was backed up to and an accompanying + user-friendly message that describes the backup location, or None + if no backup was created (i.e. because the backup already exists). + + Raises: + ValueError: If the backup database name is not set when the backup + database is requested or if the backup strategy is invalid. + """ + strategy = strategy or self.config.backup_strategy + + if ( + strategy == DatabaseBackupStrategy.DUMP_FILE + or self.config.driver == SQLDatabaseDriver.SQLITE + ): + dump_file = location or self._get_db_backup_file_path() + + if not overwrite and os.path.isfile(dump_file): + logger.warning( + f"A previous backup file already exists at '{dump_file}'. " + "Reusing the existing backup." + ) + else: + self.migration_utils.backup_database_to_file( + dump_file=dump_file + ) + return f"the '{dump_file}' backup file", dump_file + elif strategy == DatabaseBackupStrategy.DATABASE: + backup_db_name = location or self.config.backup_database + if not backup_db_name: + raise ValueError( + "The backup database name must be set in the store " + "configuration to use the backup database strategy." + ) + + if not overwrite and self.migration_utils.database_exists( + backup_db_name + ): + logger.warning( + "A previous backup database already exists at " + f"'{backup_db_name}'. Reusing the existing backup." 
+ ) + else: + self.migration_utils.backup_database_to_db( + backup_db_name=backup_db_name + ) + return f"the '{backup_db_name}' backup database", backup_db_name + elif strategy == DatabaseBackupStrategy.IN_MEMORY: + return ( + "memory", + self.migration_utils.backup_database_to_memory(), + ) + + else: + raise ValueError(f"Invalid backup strategy: {strategy}.") + + def restore_database( + self, + strategy: Optional[DatabaseBackupStrategy] = None, + location: Optional[Any] = None, + cleanup: bool = False, ) -> None: - """Creates a mysql database. + """Restore the database. Args: - url: The URL of the database to create. - connect_args: Connect arguments for the SQLAlchemy engine. - engine_args: Additional initialization arguments for the SQLAlchemy - engine + strategy: Custom backup strategy to use. If not set, the backup + strategy from the store configuration will be used. + location: Custom target location to restore the database from. If + not set, the configured backup location will be used. Depending + on the backup strategy, this can be a file path, a database + name or an in-memory database representation. + cleanup: Whether to cleanup the backup after restoring the database. + + Raises: + ValueError: If the backup database name is not set when the backup + database is requested or if the backup strategy is invalid. """ - logger.info("Trying to create database %s.", url.database) - master_url = url._replace(database=None) - master_engine = create_engine( - url=master_url, connect_args=connect_args, **engine_args - ) - query = f"CREATE DATABASE IF NOT EXISTS {self.config.database}" - try: - connection = master_engine.connect() - connection.execute(text(query)) - finally: - connection.close() + strategy = strategy or self.config.backup_strategy + + if ( + strategy == DatabaseBackupStrategy.DUMP_FILE + or self.config.driver == SQLDatabaseDriver.SQLITE + ): + dump_file = location or self._get_db_backup_file_path() + self.migration_utils.restore_database_from_file( + dump_file=dump_file + ) + elif strategy == DatabaseBackupStrategy.DATABASE: + backup_db_name = location or self.config.backup_database + if not backup_db_name: + raise ValueError( + "The backup database name must be set in the store " + "configuration to use the backup database strategy." + ) + + self.migration_utils.restore_database_from_db( + backup_db_name=backup_db_name + ) + elif strategy == DatabaseBackupStrategy.IN_MEMORY: + if location is None or not isinstance(location, list): + raise ValueError( + "The in-memory database representation must be provided " + "to restore the database from an in-memory backup." + ) + self.migration_utils.restore_database_from_memory(db_dump=location) + + else: + raise ValueError(f"Invalid backup strategy: {strategy}.") + + if cleanup: + self.cleanup_database_backup() + + def cleanup_database_backup( + self, + strategy: Optional[DatabaseBackupStrategy] = None, + location: Optional[Any] = None, + ) -> None: + """Delete the database backup. + + Args: + strategy: Custom backup strategy to use. If not set, the backup + strategy from the store configuration will be used. + location: Custom target location to delete the database backup + from. If not set, the configured backup location will be used. + Depending on the backup strategy, this can be a file path or a + database name. + + Raises: + ValueError: If the backup database name is not set when the backup + database is requested. 
+ """ + strategy = strategy or self.config.backup_strategy + + if ( + strategy == DatabaseBackupStrategy.DUMP_FILE + or self.config.driver == SQLDatabaseDriver.SQLITE + ): + dump_file = location or self._get_db_backup_file_path() + if dump_file is not None and os.path.isfile(dump_file): + try: + os.remove(dump_file) + except OSError: + logger.warning( + f"Failed to cleanup database dump file " + f"{dump_file}." + ) + else: + logger.info( + f"Successfully cleaned up database dump file " + f"{dump_file}." + ) + elif strategy == DatabaseBackupStrategy.DATABASE: + backup_db_name = location or self.config.backup_database + + if not backup_db_name: + raise ValueError( + "The backup database name must be set in the store " + "configuration to use the backup database strategy." + ) + if self.migration_utils.database_exists(backup_db_name): + # Drop the backup database + self.migration_utils.drop_database( + database=backup_db_name, + ) + logger.info( + f"Successfully cleaned up backup database " + f"{backup_db_name}." + ) def migrate_database(self) -> None: - """Migrate the database to the head as defined by the python package.""" + """Migrate the database to the head as defined by the python package. + + Raises: + RuntimeError: If the database exists and is not empty but has never + been migrated with alembic before. + """ alembic_logger = logging.getLogger("alembic") # remove all existing handlers @@ -1060,57 +1273,140 @@ def migrate_database(self) -> None: # We need to account for 3 distinct cases here: # 1. the database is completely empty (not initialized) - # 2. the database is not empty, but has never been migrated with alembic + # 2. the database is not empty and has been migrated with alembic before + # 3. the database is not empty, but has never been migrated with alembic # before (i.e. was created with SQLModel back when alembic wasn't - # used) - # 3. the database is not empty and has been migrated with alembic before - revisions = self.alembic.current_revisions() - if len(revisions) >= 1: - if len(revisions) > 1: + # used). We don't support this direct upgrade case anymore. + current_revisions = self.alembic.current_revisions() + head_revisions = self.alembic.head_revisions() + if len(current_revisions) >= 1: + # Case 2: the database has been migrated with alembic before. Just + # upgrade to the latest revision. + if len(current_revisions) > 1: logger.warning( "The ZenML database has more than one migration head " "revision. This is not expected and might indicate a " "database migration problem. Please raise an issue on " "GitHub if you encounter this." ) - # Case 3: the database has been migrated with alembic before. Just - # upgrade to the latest revision. - self.alembic.upgrade() - else: - if self.alembic.db_is_empty(): - # Case 1: the database is empty. We can just create the - # tables from scratch with from SQLModel. After tables are - # created we put an alembic revision to latest and populate - # identity table with needed info. 
- logger.info("Creating database tables") - with self.engine.begin() as conn: - conn.run_callable( - SQLModel.metadata.create_all # type: ignore[arg-type] - ) - with Session(self.engine) as session: - session.add( - IdentitySchema( - id=str(GlobalConfiguration().user_id).replace( - "-", "" - ) + + logger.debug("Current revisions: %s", current_revisions) + logger.debug("Head revisions: %s", head_revisions) + + # If the current revision and head revision don't match, a database + # migration that changes the database structure or contents may + # actually be performed, in which case we enable the backup + # functionality. We only enable the backup functionality if the + # database will actually be changed, to avoid the overhead for + # unnecessary backups. + backup_enabled = ( + self.config.backup_strategy != DatabaseBackupStrategy.DISABLED + and set(current_revisions) != set(head_revisions) + ) + backup_location: Optional[Any] = None + backup_location_msg: Optional[str] = None + + if backup_enabled: + try: + logger.info("Backing up the database before migration.") + ( + backup_location_msg, + backup_location, + ) = self.backup_database(overwrite=True) + except Exception as e: + raise RuntimeError( + f"Failed to backup the database: {str(e)}. " + "Please check the logs for more details." + "If you would like to disable the database backup " + "functionality, set the `backup_strategy` attribute " + "of the store configuration to `disabled`." + ) from e + else: + if backup_location is not None: + logger.info( + "Database successfully backed up to " + f"{backup_location_msg}. If something goes wrong " + "with the upgrade, ZenML will attempt to restore " + "the database from this backup automatically." ) + + try: + self.alembic.upgrade() + except Exception as e: + if backup_enabled and backup_location: + logger.exception( + "Failed to migrate the database. Attempting to restore " + f"the database from {backup_location_msg}." ) - session.commit() - self.alembic.stamp("head") + try: + self.restore_database(location=backup_location) + except Exception: + logger.exception( + "Failed to restore the database from " + f"{backup_location_msg}. Please " + "check the logs for more details. You might need " + "to restore the database manually." + ) + else: + raise RuntimeError( + "The database migration failed, but the database " + "was successfully restored from the backup. " + "You can safely retry the upgrade or revert to " + "the previous version of ZenML. Please check the " + "logs for more details." + ) from e + raise RuntimeError( + f"The database migration failed: {str(e)}" + ) from e + else: - # Case 2: the database is not empty, but has never been - # migrated with alembic before. We need to create the alembic - # version table, initialize it with the first revision where we - # introduced alembic and then upgrade to the latest revision. - self.alembic.stamp(ZENML_ALEMBIC_START_REVISION) - self.alembic.upgrade() + # We always remove the backup after a successful upgrade, + # not just to avoid cluttering the disk, but also to avoid + # reusing an outdated database from the backup in case of + # future upgrade failures. + try: + self.cleanup_database_backup() + except Exception: + logger.exception("Failed to cleanup the database backup.") + + elif self.alembic.db_is_empty(): + # Case 1: the database is empty. We can just create the + # tables from scratch with from SQLModel. After tables are + # created we put an alembic revision to latest and populate + # identity table with needed info. 
+ logger.info("Creating database tables") + with self.engine.begin() as conn: + conn.run_callable( + SQLModel.metadata.create_all # type: ignore[arg-type] + ) + with Session(self.engine) as session: + session.add( + IdentitySchema( + id=str(GlobalConfiguration().user_id).replace("-", "") + ) + ) + session.commit() + self.alembic.stamp("head") + else: + # Case 3: the database is not empty, but has never been + # migrated with alembic before. We don't support this direct + # upgrade case anymore. The user needs to run a two-step + # upgrade. + raise RuntimeError( + "The ZenML database has never been migrated with alembic " + "before. This can happen if you are performing a direct " + "upgrade from a really old version of ZenML. This direct " + "upgrade path is not supported anymore. Please upgrade " + "your ZenML installation first to 0.54.0 or an earlier " + "version and then to the latest version." + ) # If an alembic migration took place, all non-custom flavors are purged # and the FlavorRegistry recreates all in-built and integration # flavors in the db. revisions_afterwards = self.alembic.current_revisions() - if revisions != revisions_afterwards: + if current_revisions != revisions_afterwards: self._sync_flavors() def _sync_flavors(self) -> None: