diff --git a/CHANGELOG.md b/CHANGELOG.md index 70047125..7efe6954 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,11 +1,31 @@ # Release History -## 1.10.0 +## 1.11.0 ### Bug Fixes +* ML Job: Fix `Error: Unable to retrieve head IP address` if not all instances start within the timeout. +* ML Job: Fix `TypeError: SnowflakeCursor.execute() got an unexpected keyword argument '_force_qmark_paramstyle'` + when running inside Stored Procedures. + +### Behavior Changes + +### New Features + +* `ModelVersion.create_service()`: Made `image_repo` argument optional. By + default it will use a default image repo, which is + being rolled out in server version 9.22+. +* Experiment Tracking (PrPr): Automatically log the model, metrics, and parameters while training Keras models with + `snowflake.ml.experiment.callback.keras.SnowflakeKerasCallback`. + +## 1.10.0 + ### Behavior Changes +* Experiment Tracking (PrPr): The import paths for the auto-logging callbacks have changed to + `snowflake.ml.experiment.callback.xgboost.SnowflakeXgboostCallback` and + `snowflake.ml.experiment.callback.lightgbm.SnowflakeLightgbmCallback`. + ### New Features * Registry: add progress bars for `ModelVersion.create_service` and `ModelVersion.log_model`. @@ -26,13 +46,13 @@ ```python from snowflake.ml.experiment import ExperimentTracking +from snowflake.ml.experiment.callback import SnowflakeXgboostCallback, SnowflakeLightgbmCallback exp = ExperimentTracking(session=sp_session, database_name="ML", schema_name="PUBLIC") exp.set_experiment("MY_EXPERIMENT") # XGBoost -from snowflake.ml.experiment.callback.xgboost import SnowflakeXgboostCallback callback = SnowflakeXgboostCallback( exp, log_model=True, log_metrics=True, log_params=True, model_name="model_name", model_signature=sig ) @@ -41,7 +61,6 @@ with exp.start_run(): model.fit(X, y, eval_set=[(X_test, y_test)]) # LightGBM -from snowflake.ml.experiment.callback.lightgbm import SnowflakeLightgbmCallback callback = SnowflakeLightgbmCallback( exp, log_model=True, log_metrics=True, log_params=True, model_name="model_name", model_signature=sig ) diff --git a/bazel/environments/conda-env-all.yml b/bazel/environments/conda-env-all.yml index 29550c73..9a2283f0 100755 --- a/bazel/environments/conda-env-all.yml +++ b/bazel/environments/conda-env-all.yml @@ -12,7 +12,7 @@ dependencies: - anyio==4.2.0 - boto3==1.34.69 - cachetools==5.3.3 - - catboost==1.2.0 + - catboost==1.2.8 - cloudpickle==2.2.1 - coverage==7.2.2 - cryptography==41.0.3 diff --git a/bazel/environments/conda-env-ml.yml b/bazel/environments/conda-env-ml.yml index 49c31565..07960d02 100755 --- a/bazel/environments/conda-env-ml.yml +++ b/bazel/environments/conda-env-ml.yml @@ -12,7 +12,7 @@ dependencies: - anyio==4.2.0 - boto3==1.34.69 - cachetools==5.3.3 - - catboost==1.2.0 + - catboost==1.2.8 - cloudpickle==2.2.1 - coverage==7.2.2 - cryptography==41.0.3 diff --git a/bazel/environments/requirements_ml.txt b/bazel/environments/requirements_ml.txt index b22932e0..ed639434 100755 --- a/bazel/environments/requirements_ml.txt +++ b/bazel/environments/requirements_ml.txt @@ -8,7 +8,7 @@ anyio==4.2.0 boto3==1.34.69 build==0.10.0 cachetools==5.3.3 -catboost==1.2.0 +catboost==1.2.8 cloudpickle==2.2.1 coverage==7.2.2 cryptography==41.0.3 diff --git a/ci/RunBazelAction.sh b/ci/RunBazelAction.sh index ae3a32e9..7e692b8a 100755 --- a/ci/RunBazelAction.sh +++ b/ci/RunBazelAction.sh @@ -1,7 +1,7 @@ #!/bin/bash # DESCRIPTION: Utility Shell script to run bazel action for snowml repository # -# RunBazelAction.sh [-b ] [-m merge_gate|continuous_run|quarantined|local_unittest|local_all] [-t ] [-c ] [--tags ] +# RunBazelAction.sh [-b ] [-m merge_gate|continuous_run|quarantined|local_unittest|local_all] [-t ] [-c ] [--tags ] [--with-spcs-image] # # Args: # action: bazel action, choose from test and coverage @@ -18,6 +18,7 @@ # -c: specify the path to the coverage report dat file. # -e: specify the environment, used to determine. # --tags: specify bazel test tag filters (e.g., "feature:jobs,feature:data") +# --with-spcs-image: use spcs image for testing. # set -o pipefail @@ -40,6 +41,7 @@ help() { echo "" echo "Options:" echo " --tags Specify bazel tag filters (comma-separated)" + echo " --with-spcs-image Use spcs image for testing." echo "" echo "Examples:" echo " ${PROG} test --tags 'feature:jobs'" @@ -109,7 +111,7 @@ fi action_env=() if [[ "${WITH_SPCS_IMAGE}" = true ]]; then - export SKIP_GRYPE=true + export RUN_GRYPE=false source model_container_services_deployment/ci/build_and_push_images.sh action_env=("--action_env=BUILDER_IMAGE_PATH=${BUILDER_IMAGE_PATH}" "--action_env=BASE_CPU_IMAGE_PATH=${BASE_CPU_IMAGE_PATH}" "--action_env=BASE_GPU_IMAGE_PATH=${BASE_GPU_IMAGE_PATH}" "--action_env=IMAGE_BUILD_SIDECAR_CPU_PATH=${IMAGE_BUILD_SIDECAR_CPU_PATH}" "--action_env=IMAGE_BUILD_SIDECAR_GPU_PATH=${IMAGE_BUILD_SIDECAR_GPU_PATH}" "--action_env=PROXY_IMAGE_PATH=${PROXY_IMAGE_PATH}" "--action_env=VLLM_IMAGE_PATH=${VLLM_IMAGE_PATH}") fi diff --git a/ci/build_and_run_tests.sh b/ci/build_and_run_tests.sh index 3c0c0f7a..25939880 100755 --- a/ci/build_and_run_tests.sh +++ b/ci/build_and_run_tests.sh @@ -1,7 +1,7 @@ #!/bin/bash # Usage -# build_and_run_tests.sh [-b ] [--env pip|conda] [--mode merge_gate|continuous_run] [--with-snowpark] [--with-spcs-image] [--report ] +# build_and_run_tests.sh [-b ] [--env pip|conda] [--mode merge_gate|continuous_run] [--with-snowpark] [--with-spcs-image] [--run-grype] [--report ] # # Args # workspace: path to the workspace, SnowML code should be in snowml directory. @@ -15,6 +15,7 @@ # quarantined: run all quarantined tests. # with-snowpark: Build and test with snowpark in snowpark-python directory in the workspace. # with-spcs-image: Build and test with spcs-image in spcs-image directory in the workspace. +# run-grype: Run grype security scanning on SPCS images. Only valid with --with-spcs-image. # snowflake-env: The environment of the snowflake, use to determine the test quarantine list # report: Path to xml test report # @@ -30,7 +31,7 @@ PROG=$0 help() { local exit_code=$1 - echo "Usage: ${PROG} [-b ] [--env pip|conda] [--mode merge_gate|continuous_run|quarantined] [--with-snowpark] [--with-spcs-image] [--snowflake-env ] [--report ]" + echo "Usage: ${PROG} [-b ] [--env pip|conda] [--mode merge_gate|continuous_run|quarantined] [--with-snowpark] [--with-spcs-image] [--run-grype] [--snowflake-env ] [--report ]" exit "${exit_code}" } @@ -39,6 +40,7 @@ BAZEL="bazel" ENV="pip" WITH_SNOWPARK=false WITH_SPCS_IMAGE=false +RUN_GRYPE=false MODE="continuous_run" PYTHON_VERSION=3.9 PYTHON_ENABLE_SCRIPT="bin/activate" @@ -91,6 +93,9 @@ while (($#)); do --with-spcs-image) WITH_SPCS_IMAGE=true ;; + --run-grype) + RUN_GRYPE=true + ;; -h | --help) help 0 ;; @@ -101,6 +106,12 @@ while (($#)); do shift done +# Validate flag combinations +if [ "${RUN_GRYPE}" = true ] && [ "${WITH_SPCS_IMAGE}" = false ]; then + echo "Error: --run-grype flag requires --with-spcs-image to be set" + help 1 +fi + echo "Running build_and_run_tests with PYTHON_VERSION ${PYTHON_VERSION}" EXT="" @@ -180,6 +191,25 @@ trap 'rm -rf "${TEMP_BIN}"' EXIT # Install micromamba _MICROMAMBA_BIN="micromamba${EXT}" if [ "${ENV}" = "conda" ]; then + CONDA="/mnt/jenkins/home/jenkins/miniforge3/condabin/conda" + + # Check if miniforge is already installed + if [ -x "${CONDA}" ]; then + echo "Miniforge exists at ${CONDA}." + else + echo "Downloading miniforge ..." + curl -L -O "https://github.com/conda-forge/miniforge/releases/latest/download/Miniforge3-$(uname)-$(uname -m).sh" + + echo "Installing miniforge ..." + /bin/bash "Miniforge3-$(uname)-$(uname -m).sh" -b -u + fi + + echo "Using ${CONDA} ..." + + echo "Installing conda-build ..." + ${CONDA} install conda-build --yes + + echo "Installing micromamba ..." if ! command -v "${_MICROMAMBA_BIN}" &>/dev/null; then curl -Lsv "https://github.com/mamba-org/micromamba-releases/releases/latest/download/micromamba-${MICROMAMBA_PLATFORM}-${MICROMAMBA_ARCH}" -o "${TEMP_BIN}/micromamba${EXT}" && chmod +x "${TEMP_BIN}/micromamba${EXT}" _MICROMAMBA_BIN="${TEMP_BIN}/micromamba${EXT}" @@ -264,30 +294,31 @@ if [ "${ENV}" = "pip" ]; then cp "$("${BAZEL}" "${BAZEL_ADDITIONAL_STARTUP_FLAGS[@]+"${BAZEL_ADDITIONAL_STARTUP_FLAGS[@]}"}" info bazel-bin)/dist/snowflake_ml_python-${VERSION}-py3-none-any.whl" "${WORKSPACE}" popd else - # Clean conda cache - conda clean --all --force-pkgs-dirs -y + echo "Cleaning conda cache ..." + ${CONDA} clean --all --force-pkgs-dirs -y - # Clean conda build workspace + echo "Cleaning conda build workspace ..." rm -rf "${WORKSPACE}/conda-bld" - # Build Snowpark + echo "Building snowpark-python conda package ..." if [ "${WITH_SNOWPARK}" = true ]; then pushd ${SNOWPARK_DIR} - conda build recipe/ --python=${PYTHON_VERSION} --numpy=1.16 --croot "${WORKSPACE}/conda-bld" + ${CONDA} build recipe/ --python=${PYTHON_VERSION} --numpy=1.16 --croot "${WORKSPACE}/conda-bld" popd fi - # Build SnowML pushd ${SNOWML_DIR} - # Build conda package - conda build -c conda-forge --override-channels --prefix-length 50 --python=${PYTHON_VERSION} --croot "${WORKSPACE}/conda-bld" ci/conda_recipe - conda build purge + + echo "Building snowflake-ml-python conda package ..." + ${CONDA} build -c conda-forge --override-channels --prefix-length 50 --python=${PYTHON_VERSION} --croot "${WORKSPACE}/conda-bld" ci/conda_recipe + ${CONDA} build purge popd fi if [[ "${WITH_SPCS_IMAGE}" = true ]]; then pushd ${SNOWML_DIR} - # Build SPCS Image + echo "Building SPCS Image ..." + export RUN_GRYPE source model_container_services_deployment/ci/build_and_push_images.sh popd fi @@ -361,7 +392,7 @@ for i in "${!groups[@]}"; do COMMON_PYTEST_FLAG+=(-m "not conda_incompatible") fi # Create local conda channel - conda index "${WORKSPACE}/conda-bld" + ${CONDA} index "${WORKSPACE}/conda-bld" # Clean conda cache "${_MICROMAMBA_BIN}" clean --all --force-pkgs-dirs -y @@ -384,7 +415,7 @@ for i in "${!groups[@]}"; do # Run integration tests set +e - TEST_SRCDIR="${TEMP_TEST_DIR}" conda run -p ./testenv --no-capture-output python -m pytest "${COMMON_PYTEST_FLAG[@]}" tests/integ/ + TEST_SRCDIR="${TEMP_TEST_DIR}" ${CONDA} run -p ./testenv --no-capture-output python -m pytest "${COMMON_PYTEST_FLAG[@]}" tests/integ/ group_exit_codes[$i]=$? set -e diff --git a/ci/conda_recipe/meta.yaml b/ci/conda_recipe/meta.yaml index b694e06e..85561798 100644 --- a/ci/conda_recipe/meta.yaml +++ b/ci/conda_recipe/meta.yaml @@ -17,7 +17,7 @@ build: noarch: python package: name: snowflake-ml-python - version: 1.10.0 + version: 1.11.0 requirements: build: - python diff --git a/ci/targets/quarantine/prod3.txt b/ci/targets/quarantine/prod3.txt index a67dc3f2..96a5a714 100644 --- a/ci/targets/quarantine/prod3.txt +++ b/ci/targets/quarantine/prod3.txt @@ -1,7 +1,7 @@ //tests/integ/snowflake/ml/extra_tests:xgboost_external_memory_training_test //tests/integ/snowflake/ml/extra_tests:pipeline_with_ohe_and_xgbr_test -//tests/integ/snowflake/ml/lineage:lineage_integ_test //tests/integ/snowflake/ml/modeling/manifold:spectral_embedding_test //tests/integ/snowflake/ml/modeling/linear_model:logistic_regression_test //tests/integ/snowflake/ml/registry/services:registry_huggingface_pipeline_model_deployment_test //tests/integ/snowflake/ml/registry/services:registry_sentence_transformers_model_deployment_test +//tests/integ/snowflake/ml/jobs:jobs_integ_test diff --git a/requirements.yml b/requirements.yml index 417f829c..7ef04729 100644 --- a/requirements.yml +++ b/requirements.yml @@ -84,7 +84,7 @@ - name: boto3 dev_version: 1.34.69 - name: catboost - dev_version: 1.2.0 + dev_version: 1.2.8 version_requirements: '>=1.2.0, <2' requirements_extra_tags: - catboost diff --git a/snowflake/ml/experiment/callback/BUILD.bazel b/snowflake/ml/experiment/callback/BUILD.bazel index cb0b320d..91d77fc1 100644 --- a/snowflake/ml/experiment/callback/BUILD.bazel +++ b/snowflake/ml/experiment/callback/BUILD.bazel @@ -11,6 +11,29 @@ py_library( ], ) +py_library( + name = "keras", + srcs = ["keras.py"], + deps = [ + "//snowflake/ml/experiment:experiment_tracking", + "//snowflake/ml/experiment:utils", + "//snowflake/ml/model:model_signature", + ], +) + +py_test( + name = "keras_test", + srcs = ["test/keras_test.py"], + optional_dependencies = [ + "keras", + ], + tags = ["feature:observability"], + deps = [ + ":keras", + ":test_base", + ], +) + py_library( name = "lightgbm", srcs = ["lightgbm.py"], @@ -56,6 +79,7 @@ py_test( py_library( name = "callback", deps = [ + ":keras", ":lightgbm", ":xgboost", ], diff --git a/snowflake/ml/experiment/callback/keras.py b/snowflake/ml/experiment/callback/keras.py new file mode 100644 index 00000000..a0657d41 --- /dev/null +++ b/snowflake/ml/experiment/callback/keras.py @@ -0,0 +1,63 @@ +import json +from typing import TYPE_CHECKING, Any, Optional +from warnings import warn + +import keras + +from snowflake.ml.experiment import utils + +if TYPE_CHECKING: + from snowflake.ml.experiment.experiment_tracking import ExperimentTracking + from snowflake.ml.model.model_signature import ModelSignature + + +class SnowflakeKerasCallback(keras.callbacks.Callback): + def __init__( + self, + experiment_tracking: "ExperimentTracking", + log_model: bool = True, + log_metrics: bool = True, + log_params: bool = True, + log_every_n_epochs: int = 1, + model_name: Optional[str] = None, + model_signature: Optional["ModelSignature"] = None, + ) -> None: + self._experiment_tracking = experiment_tracking + self.log_model = log_model + self.log_metrics = log_metrics + self.log_params = log_params + if log_every_n_epochs < 1: + raise ValueError("`log_every_n_epochs` must be positive.") + self.log_every_n_epochs = log_every_n_epochs + self.model_name = model_name + self.model_signature = model_signature + + def on_train_begin(self, logs: Optional[dict[str, Any]] = None) -> None: + if self.log_params: + params = json.loads(self.model.to_json()) + self._experiment_tracking.log_params(utils.flatten_nested_params(params)) + + def on_epoch_end(self, epoch: int, logs: Optional[dict[str, Any]] = None) -> None: + if self.log_metrics and logs and epoch % self.log_every_n_epochs == 0: + for key, value in logs.items(): + try: + value = float(value) + except Exception: + pass + else: + self._experiment_tracking.log_metric(key=key, value=value, step=epoch) + + def on_train_end(self, logs: Optional[dict[str, Any]] = None) -> None: + if self.log_model: + if not self.model_signature: + warn( + "Model will not be logged because model signature is missing. " + "To autolog the model, please specify `model_signature` when constructing SnowflakeKerasCallback." + ) + return + model_name = self.model_name or self._experiment_tracking._get_or_set_experiment().name + "_model" + self._experiment_tracking.log_model( # type: ignore[call-arg] + model=self.model, + model_name=model_name, + signatures={"predict": self.model_signature}, + ) diff --git a/snowflake/ml/experiment/callback/lightgbm.py b/snowflake/ml/experiment/callback/lightgbm.py index ea725de1..2ee500ac 100644 --- a/snowflake/ml/experiment/callback/lightgbm.py +++ b/snowflake/ml/experiment/callback/lightgbm.py @@ -15,6 +15,7 @@ def __init__( log_model: bool = True, log_metrics: bool = True, log_params: bool = True, + log_every_n_epochs: int = 1, model_name: Optional[str] = None, model_signature: Optional["ModelSignature"] = None, ) -> None: @@ -22,6 +23,9 @@ def __init__( self.log_model = log_model self.log_metrics = log_metrics self.log_params = log_params + if log_every_n_epochs < 1: + raise ValueError("`log_every_n_epochs` must be positive.") + self.log_every_n_epochs = log_every_n_epochs self.model_name = model_name self.model_signature = model_signature @@ -32,7 +36,7 @@ def __call__(self, env: lgb.callback.CallbackEnv) -> None: if env.iteration == env.begin_iteration: # Log params only at the first iteration self._experiment_tracking.log_params(env.params) - if self.log_metrics: + if self.log_metrics and env.iteration % self.log_every_n_epochs == 0: super().__call__(env) for dataset_name, metrics in self.eval_result.items(): for metric_name, log in metrics.items(): diff --git a/snowflake/ml/experiment/callback/test/base.py b/snowflake/ml/experiment/callback/test/base.py index 8caa5408..b5ea0837 100644 --- a/snowflake/ml/experiment/callback/test/base.py +++ b/snowflake/ml/experiment/callback/test/base.py @@ -1,3 +1,4 @@ +import math from typing import Any, Optional from unittest.mock import ANY, MagicMock @@ -17,7 +18,7 @@ def setUp(self) -> None: # Create training data and parameters self.X = np.array([[1, 2], [3, 4]]) self.y = np.array([0, 1]) - self.num_steps = 2 + self.num_steps = 3 self.model_signature = ModelSignature( inputs=[ FeatureSpec(name="feature1", dtype=DataType.FLOAT), @@ -32,17 +33,22 @@ def _train_model(self, model_class: type[Any], callback: Any) -> None: def _get_callback(self, **kwargs: Any) -> Any: pass - def _log_metrics(self, model_class: type[Any]) -> None: + def _log_metrics(self, model_class: type[Any], log_every_n_epochs: int) -> None: """Test that metrics are autologged.""" callback = self._get_callback( experiment_tracking=self.experiment_tracking, log_model=False, log_metrics=True, log_params=False, + log_every_n_epochs=log_every_n_epochs, ) self._train_model(model_class=model_class, callback=callback) - self.assertEqual(self.experiment_tracking.log_metric.call_count, self.num_steps) + # Expected call count is rounded up to the next integer because we always log at epoch 0. + expected_call_count = math.ceil(self.num_steps / log_every_n_epochs) + self.assertEqual(self.experiment_tracking.log_metric.call_count, expected_call_count) + for epoch in range(0, self.num_steps, log_every_n_epochs): + self.experiment_tracking.log_metric.assert_any_call(key=ANY, value=ANY, step=epoch) def _log_model(self, model_class: type[Any], model_name: Optional[str]) -> None: """Test that model is autologged.""" diff --git a/snowflake/ml/experiment/callback/test/keras_test.py b/snowflake/ml/experiment/callback/test/keras_test.py new file mode 100644 index 00000000..3f3e0d0d --- /dev/null +++ b/snowflake/ml/experiment/callback/test/keras_test.py @@ -0,0 +1,37 @@ +from typing import Any, Optional + +import keras +from absl.testing import absltest, parameterized + +from snowflake.ml.experiment.callback.keras import SnowflakeKerasCallback +from snowflake.ml.experiment.callback.test.base import SnowflakeCallbackTest + + +class SnowflakeKerasCallbackTest(SnowflakeCallbackTest, parameterized.TestCase): + def _train_model( + self, + model_class: type[keras.Model], + callback: SnowflakeKerasCallback, + ) -> None: + model = model_class() + model.add(keras.layers.Dense(1)) + model.compile(loss="mean_squared_error") + model.fit(self.X, self.y, epochs=self.num_steps, callbacks=[callback]) + + def _get_callback(self, **kwargs: Any) -> SnowflakeKerasCallback: + return SnowflakeKerasCallback(**kwargs) + + @parameterized.parameters(1, 2) # type: ignore[misc] + def test_log_metrics(self, log_every_n_epochs: int) -> None: + super()._log_metrics(keras.Sequential, log_every_n_epochs=log_every_n_epochs) + + @parameterized.parameters(None, "custom_model_name") # type: ignore[misc] + def test_log_model(self, model_name: Optional[str] = None) -> None: + super()._log_model(keras.Sequential, model_name) + + def test_log_param(self) -> None: + super()._log_param(keras.Sequential) + + +if __name__ == "__main__": + absltest.main() diff --git a/snowflake/ml/experiment/callback/test/lightgbm_test.py b/snowflake/ml/experiment/callback/test/lightgbm_test.py index 38c06296..b9e2a5b2 100644 --- a/snowflake/ml/experiment/callback/test/lightgbm_test.py +++ b/snowflake/ml/experiment/callback/test/lightgbm_test.py @@ -31,9 +31,9 @@ def _train_model( def _get_callback(self, **kwargs: Any) -> SnowflakeLightgbmCallback: return SnowflakeLightgbmCallback(**kwargs) - @parameterized.parameters(*supported_model_classes) # type: ignore[misc] - def test_log_metrics(self, model_class: type[ModelClass]) -> None: - super()._log_metrics(model_class) + @parameterized.product(model_class=supported_model_classes, log_every_n_epochs=[1, 2]) # type: ignore[misc] + def test_log_metrics(self, model_class: type[ModelClass], log_every_n_epochs: int) -> None: + super()._log_metrics(model_class, log_every_n_epochs=log_every_n_epochs) @parameterized.product( model_class=supported_model_classes, diff --git a/snowflake/ml/experiment/callback/test/xgboost_test.py b/snowflake/ml/experiment/callback/test/xgboost_test.py index fea56c50..8da0fe7d 100644 --- a/snowflake/ml/experiment/callback/test/xgboost_test.py +++ b/snowflake/ml/experiment/callback/test/xgboost_test.py @@ -29,9 +29,9 @@ def _train_model( def _get_callback(self, **kwargs: Any) -> SnowflakeXgboostCallback: return SnowflakeXgboostCallback(**kwargs) - @parameterized.parameters(*supported_model_classes) # type: ignore[misc] - def test_log_metrics(self, model_class: type[ModelClass]) -> None: - super()._log_metrics(model_class) + @parameterized.product(model_class=supported_model_classes, log_every_n_epochs=[1, 2]) # type: ignore[misc] + def test_log_metrics(self, model_class: type[ModelClass], log_every_n_epochs: int) -> None: + super()._log_metrics(model_class, log_every_n_epochs=log_every_n_epochs) @parameterized.product( model_class=supported_model_classes, diff --git a/snowflake/ml/experiment/callback/xgboost.py b/snowflake/ml/experiment/callback/xgboost.py index 5bfaed1d..c9ce67af 100644 --- a/snowflake/ml/experiment/callback/xgboost.py +++ b/snowflake/ml/experiment/callback/xgboost.py @@ -18,6 +18,7 @@ def __init__( log_model: bool = True, log_metrics: bool = True, log_params: bool = True, + log_every_n_epochs: int = 1, model_name: Optional[str] = None, model_signature: Optional["ModelSignature"] = None, ) -> None: @@ -25,6 +26,9 @@ def __init__( self.log_model = log_model self.log_metrics = log_metrics self.log_params = log_params + if log_every_n_epochs < 1: + raise ValueError("`log_every_n_epochs` must be positive.") + self.log_every_n_epochs = log_every_n_epochs self.model_name = model_name self.model_signature = model_signature @@ -36,7 +40,7 @@ def before_training(self, model: xgb.Booster) -> xgb.Booster: return model def after_iteration(self, model: Any, epoch: int, evals_log: dict[str, dict[str, Any]]) -> bool: - if self.log_metrics: + if self.log_metrics and epoch % self.log_every_n_epochs == 0: for dataset_name, metrics in evals_log.items(): for metric_name, log in metrics.items(): metric_key = dataset_name + ":" + metric_name diff --git a/snowflake/ml/jobs/_utils/BUILD.bazel b/snowflake/ml/jobs/_utils/BUILD.bazel index ef9f2fbd..0579e870 100644 --- a/snowflake/ml/jobs/_utils/BUILD.bazel +++ b/snowflake/ml/jobs/_utils/BUILD.bazel @@ -34,6 +34,7 @@ py_library( deps = [ ":constants", ":query_helper", + ":runtime_env_utils", ":types", "//snowflake/ml/_internal/utils:snowflake_env", ], @@ -62,13 +63,26 @@ py_library( ], ) +py_test( + name = "stage_utils_test", + srcs = ["stage_utils_test.py"], + tags = ["feature:jobs"], + deps = [ + ":payload_utils", + ], +) + py_library( name = "payload_utils", - srcs = ["payload_utils.py"], + srcs = [ + "__init__.py", + "payload_utils.py", + ], deps = [ ":constants", ":function_payload_utils", ":payload_scripts", + ":query_helper", ":stage_utils", ":types", ], @@ -83,21 +97,10 @@ py_test( tags = ["feature:jobs"], deps = [ ":payload_utils", - ":query_helper", - ":stage_utils", ":test_file_helper", ], ) -py_test( - name = "stage_utils_test", - srcs = ["stage_utils_test.py"], - tags = ["feature:jobs"], - deps = [ - ":stage_utils", - ], -) - py_library( name = "query_helper", srcs = ["query_helper.py"], @@ -113,6 +116,11 @@ py_library( srcs = ["function_payload_utils.py"], ) +py_library( + name = "runtime_env_utils", + srcs = ["runtime_env_utils.py"], +) + py_test( name = "interop_utils_test", srcs = ["interop_utils_test.py"], @@ -137,7 +145,9 @@ py_test( py_library( name = "job_utils", - srcs = [], + srcs = [ + "__init__.py", + ], deps = [ ":interop_utils", ":payload_utils", diff --git a/snowflake/ml/jobs/_utils/__init__.py b/snowflake/ml/jobs/_utils/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/snowflake/ml/jobs/_utils/constants.py b/snowflake/ml/jobs/_utils/constants.py index d611fe23..3708ddbb 100644 --- a/snowflake/ml/jobs/_utils/constants.py +++ b/snowflake/ml/jobs/_utils/constants.py @@ -28,7 +28,7 @@ DEFAULT_IMAGE_REPO = "/snowflake/images/snowflake_images" DEFAULT_IMAGE_CPU = "st_plat/runtime/x86/runtime_image/snowbooks" DEFAULT_IMAGE_GPU = "st_plat/runtime/x86/generic_gpu/runtime_image/snowbooks" -DEFAULT_IMAGE_TAG = "1.5.0" +DEFAULT_IMAGE_TAG = "1.6.2" DEFAULT_ENTRYPOINT_PATH = "func.py" # Percent of container memory to allocate for /dev/shm volume @@ -98,3 +98,6 @@ SnowflakeCloudType.AWS: AWS_INSTANCE_FAMILIES, SnowflakeCloudType.AZURE: AZURE_INSTANCE_FAMILIES, } + +# runtime version environment variable +ENABLE_IMAGE_VERSION_ENV_VAR = "MLRS_ENABLE_RUNTIME_VERSIONS" diff --git a/snowflake/ml/jobs/_utils/mljob_launcher_test.py b/snowflake/ml/jobs/_utils/mljob_launcher_test.py index b2d4cf76..7d56969e 100644 --- a/snowflake/ml/jobs/_utils/mljob_launcher_test.py +++ b/snowflake/ml/jobs/_utils/mljob_launcher_test.py @@ -296,7 +296,7 @@ def test_wait_for_instances_timeout(self, mock_sleep: mock.MagicMock, mock_commo with self.assertRaises(TimeoutError) as cm: mljob_launcher.wait_for_instances(2, 5, timeout=5, check_interval=10) - self.assertIn("Timed out after 5s waiting for 2 instances", str(cm.exception)) + self.assertIn("Timed out after 6s waiting for 2 instances", str(cm.exception)) self.assertIn("only 1 available", str(cm.exception)) diff --git a/snowflake/ml/jobs/_utils/payload_utils.py b/snowflake/ml/jobs/_utils/payload_utils.py index 2cf56a34..559c186b 100644 --- a/snowflake/ml/jobs/_utils/payload_utils.py +++ b/snowflake/ml/jobs/_utils/payload_utils.py @@ -1,4 +1,5 @@ import functools +import importlib import inspect import io import itertools @@ -7,6 +8,7 @@ import pickle import sys import textwrap +from importlib.abc import Traversable from pathlib import Path, PurePath from typing import Any, Callable, Optional, Union, cast, get_args, get_origin @@ -262,11 +264,24 @@ def upload_payloads(session: snowpark.Session, stage_path: PurePath, *payload_sp # Manually traverse the directory and upload each file, since Snowflake PUT # can't handle directories. Reduce the number of PUT operations by using # wildcard patterns to batch upload files with the same extension. - for path in { - p.parent.joinpath(f"*{p.suffix}") if p.suffix else p - for p in source_path.resolve().rglob("*") - if p.is_file() - }: + upload_path_patterns = set() + for p in source_path.resolve().rglob("*"): + if p.is_dir(): + continue + if p.name.startswith("."): + # Hidden files: use .* pattern for batch upload + if p.suffix: + upload_path_patterns.add(p.parent.joinpath(f".*{p.suffix}")) + else: + upload_path_patterns.add(p.parent.joinpath(".*")) + else: + # Regular files: use * pattern for batch upload + if p.suffix: + upload_path_patterns.add(p.parent.joinpath(f"*{p.suffix}")) + else: + upload_path_patterns.add(p) + + for path in upload_path_patterns: session.file.put( str(path), payload_stage_path.joinpath(path.parent.relative_to(source_path)).as_posix(), @@ -282,6 +297,27 @@ def upload_payloads(session: snowpark.Session, stage_path: PurePath, *payload_sp ) +def upload_system_resources(session: snowpark.Session, stage_path: PurePath) -> None: + resource_ref = importlib.resources.files(__package__).joinpath("scripts") + + def upload_dir(ref: Traversable, relative_path: str = "") -> None: + for item in ref.iterdir(): + current_path = Path(relative_path) / item.name if relative_path else Path(item.name) + if item.is_dir(): + # Recursively process subdirectories + upload_dir(item, str(current_path)) + elif item.is_file(): + content = item.read_bytes() + session.file.put_stream( + io.BytesIO(content), + stage_path.joinpath(current_path).as_posix(), + auto_compress=False, + overwrite=True, + ) + + upload_dir(resource_ref) + + def resolve_source( source: Union[types.PayloadPath, Callable[..., Any]] ) -> Union[types.PayloadPath, Callable[..., Any]]: @@ -497,15 +533,7 @@ def upload(self, session: snowpark.Session, stage_path: Union[str, PurePath]) -> overwrite=False, # FIXME ) - scripts_dir = Path(__file__).parent.joinpath("scripts") - for script_file in scripts_dir.glob("*"): - if script_file.is_file(): - session.file.put( - script_file.as_posix(), - system_stage_path.as_posix(), - overwrite=True, - auto_compress=False, - ) + upload_system_resources(session, system_stage_path) python_entrypoint: list[Union[str, PurePath]] = [ PurePath(f"{constants.SYSTEM_MOUNT_PATH}/mljob_launcher.py"), PurePath(f"{constants.APP_MOUNT_PATH}/{entrypoint.file_path.relative_to(source).as_posix()}"), diff --git a/snowflake/ml/jobs/_utils/payload_utils_test.py b/snowflake/ml/jobs/_utils/payload_utils_test.py index 89dd0c4b..2bb5d086 100644 --- a/snowflake/ml/jobs/_utils/payload_utils_test.py +++ b/snowflake/ml/jobs/_utils/payload_utils_test.py @@ -100,6 +100,7 @@ class PayloadUtilsTests(parameterized.TestCase): ("@test_stage/main.py", None, "@test_stage/main.py"), ("@test_stage/main.py", "@test_stage/main.py", "@test_stage/main.py"), ("@test_stage/src/dir", "@test_stage/src/dir/dir1/main.py", "@test_stage/src/dir/dir1/main.py"), + ("@test_stage/src/dir/", "dir1/main.py", "@test_stage/src/dir/dir1/main.py"), ("snow://headless/abc/versions/v9.8.7/main.py", None, "snow://headless/abc/versions/v9.8.7/main.py"), ( "snow://headless/abc/versions/v9.8.7/main.py", @@ -116,6 +117,11 @@ class PayloadUtilsTests(parameterized.TestCase): "snow://headless/abc/versions/v9.8.7/src/main.py", "snow://headless/abc/versions/v9.8.7/src/main.py", ), + ( + "snow://headless/abc/versions/v9.8.7/src", + "main.py", + "snow://headless/abc/versions/v9.8.7/src/main.py", + ), ) def test_payload_validate(self, source: str, entrypoint: Optional[str], expected_entrypoint: str) -> None: with pushd(resolve_path("")): diff --git a/snowflake/ml/jobs/_utils/query_helper.py b/snowflake/ml/jobs/_utils/query_helper.py index e63f9987..9fbec7f0 100644 --- a/snowflake/ml/jobs/_utils/query_helper.py +++ b/snowflake/ml/jobs/_utils/query_helper.py @@ -4,6 +4,7 @@ from snowflake.snowpark import Row from snowflake.snowpark._internal import utils from snowflake.snowpark._internal.analyzer import snowflake_plan +from snowflake.snowpark._internal.utils import is_in_stored_procedure def result_set_to_rows(session: snowpark.Session, result: dict[str, Any]) -> list[Row]: @@ -14,7 +15,10 @@ def result_set_to_rows(session: snowpark.Session, result: dict[str, Any]) -> lis @snowflake_plan.SnowflakePlan.Decorator.wrap_exception # type: ignore[misc] def run_query(session: snowpark.Session, query_text: str, params: Optional[Sequence[Any]] = None) -> list[Row]: - result = session._conn.run_query(query=query_text, params=params, _force_qmark_paramstyle=True) + kwargs: dict[str, Any] = {"query": query_text, "params": params} + if not is_in_stored_procedure(): # type: ignore[no-untyped-call] + kwargs["_force_qmark_paramstyle"] = True + result = session._conn.run_query(**kwargs) if not isinstance(result, dict) or "data" not in result: raise ValueError(f"Unprocessable result: {result}") return result_set_to_rows(session, result) diff --git a/snowflake/ml/jobs/_utils/runtime_env_utils.py b/snowflake/ml/jobs/_utils/runtime_env_utils.py new file mode 100644 index 00000000..47b97046 --- /dev/null +++ b/snowflake/ml/jobs/_utils/runtime_env_utils.py @@ -0,0 +1,63 @@ +from typing import Any, Optional, Union + +from packaging.version import Version +from pydantic import BaseModel, Field, RootModel, field_validator + + +class SpcsContainerRuntime(BaseModel): + python_version: Version = Field(alias="pythonVersion") + hardware_type: str = Field(alias="hardwareType") + runtime_container_image: str = Field(alias="runtimeContainerImage") + + @field_validator("python_version", mode="before") + @classmethod + def validate_python_version(cls, v: Union[str, Version]) -> Version: + if isinstance(v, Version): + return v + try: + return Version(v) + except Exception: + raise ValueError(f"Invalid Python version format: {v}") + + class Config: + frozen = True + extra = "allow" + arbitrary_types_allowed = True + + +class RuntimeEnvironmentEntry(BaseModel): + spcs_container_runtime: Optional[SpcsContainerRuntime] = Field(alias="spcsContainerRuntime", default=None) + + class Config: + extra = "allow" + frozen = True + + +class RuntimeEnvironmentsDict(RootModel[dict[str, RuntimeEnvironmentEntry]]): + @field_validator("root", mode="before") + @classmethod + def _filter_to_dict_entries(cls, data: Any) -> dict[str, dict[str, Any]]: + """ + Pre-validation hook: keep only those items at the root level + whose values are dicts. Non-dict values will be dropped. + + Args: + data: The input data to filter, expected to be a dictionary. + + Returns: + A dictionary containing only the key-value pairs where values are dictionaries. + + Raises: + ValueError: If input data is not a dictionary. + """ + # If the entire root is not a dict, raise error immediately + if not isinstance(data, dict): + raise ValueError(f"Expected dictionary data, but got {type(data).__name__}: {data}") + + # Filter out any key whose value is not a dict + return {key: value for key, value in data.items() if isinstance(value, dict)} + + def get_spcs_container_runtimes(self) -> list[SpcsContainerRuntime]: + return [ + entry.spcs_container_runtime for entry in self.root.values() if entry.spcs_container_runtime is not None + ] diff --git a/snowflake/ml/jobs/_utils/scripts/get_instance_ip.py b/snowflake/ml/jobs/_utils/scripts/get_instance_ip.py index 664265fe..695e4294 100644 --- a/snowflake/ml/jobs/_utils/scripts/get_instance_ip.py +++ b/snowflake/ml/jobs/_utils/scripts/get_instance_ip.py @@ -47,8 +47,8 @@ def get_first_instance(service_name: str) -> Optional[tuple[str, str, str]]: if not result: return None - # Sort by start_time first, then by instance_id - sorted_instances = sorted(result, key=lambda x: (x["start_time"], int(x["instance_id"]))) + # Sort by start_time first, then by instance_id. If start_time is null/empty, it will be sorted to the end. + sorted_instances = sorted(result, key=lambda x: (not bool(x["start_time"]), x["start_time"], int(x["instance_id"]))) head_instance = sorted_instances[0] if not head_instance["instance_id"] or not head_instance["ip_address"]: return None diff --git a/snowflake/ml/jobs/_utils/scripts/mljob_launcher.py b/snowflake/ml/jobs/_utils/scripts/mljob_launcher.py index efcd0f5c..37cec35c 100644 --- a/snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +++ b/snowflake/ml/jobs/_utils/scripts/mljob_launcher.py @@ -173,10 +173,10 @@ def wait_for_instances( start_time = time.time() current_interval = max(min(1, check_interval), 0.1) # Default 1s, minimum 0.1s - logger.debug( + logger.info( "Waiting for instances to be ready " - "(min_instances={}, target_instances={}, timeout={}s, max_check_interval={}s)".format( - min_instances, target_instances, timeout, check_interval + "(min_instances={}, target_instances={}, min_wait_time={}s, timeout={}s, max_check_interval={}s)".format( + min_instances, target_instances, min_wait_time, timeout, check_interval ) ) diff --git a/snowflake/ml/jobs/_utils/spec_utils.py b/snowflake/ml/jobs/_utils/spec_utils.py index b6d44dd4..d045ba20 100644 --- a/snowflake/ml/jobs/_utils/spec_utils.py +++ b/snowflake/ml/jobs/_utils/spec_utils.py @@ -1,12 +1,14 @@ import logging import os +import sys from math import ceil from pathlib import PurePath -from typing import Any, Optional, Union +from typing import Any, Literal, Optional, Union from snowflake import snowpark from snowflake.ml._internal.utils import snowflake_env from snowflake.ml.jobs._utils import constants, query_helper, types +from snowflake.ml.jobs._utils.runtime_env_utils import RuntimeEnvironmentsDict def _get_node_resources(session: snowpark.Session, compute_pool: str) -> types.ComputeResources: @@ -28,22 +30,53 @@ def _get_node_resources(session: snowpark.Session, compute_pool: str) -> types.C ) +def _get_runtime_image(session: snowpark.Session, target_hardware: Literal["CPU", "GPU"]) -> Optional[str]: + rows = query_helper.run_query(session, "CALL SYSTEM$NOTEBOOKS_FIND_LABELED_RUNTIMES()") + if not rows: + return None + try: + runtime_envs = RuntimeEnvironmentsDict.model_validate_json(rows[0][0]) + spcs_container_runtimes = runtime_envs.get_spcs_container_runtimes() + except Exception as e: + logging.warning(f"Failed to parse runtime image name from {rows[0][0]}, error: {e}") + return None + + selected_runtime = next( + ( + runtime + for runtime in spcs_container_runtimes + if ( + runtime.hardware_type.lower() == target_hardware.lower() + and runtime.python_version.major == sys.version_info.major + and runtime.python_version.minor == sys.version_info.minor + ) + ), + None, + ) + return selected_runtime.runtime_container_image if selected_runtime else None + + def _get_image_spec(session: snowpark.Session, compute_pool: str) -> types.ImageSpec: # Retrieve compute pool node resources resources = _get_node_resources(session, compute_pool=compute_pool) # Use MLRuntime image - image_repo = constants.DEFAULT_IMAGE_REPO - image_name = constants.DEFAULT_IMAGE_GPU if resources.gpu > 0 else constants.DEFAULT_IMAGE_CPU - image_tag = _get_runtime_image_tag() + hardware = "GPU" if resources.gpu > 0 else "CPU" + container_image = None + if os.environ.get(constants.ENABLE_IMAGE_VERSION_ENV_VAR, "").lower() == "true": + container_image = _get_runtime_image(session, hardware) # type: ignore[arg-type] + + if not container_image: + image_repo = constants.DEFAULT_IMAGE_REPO + image_name = constants.DEFAULT_IMAGE_GPU if resources.gpu > 0 else constants.DEFAULT_IMAGE_CPU + image_tag = _get_runtime_image_tag() + container_image = f"{image_repo}/{image_name}:{image_tag}" # TODO: Should each instance consume the entire pod? return types.ImageSpec( - repo=image_repo, - image_name=image_name, - image_tag=image_tag, resource_requests=resources, resource_limits=resources, + container_image=container_image, ) @@ -220,7 +253,7 @@ def generate_service_spec( "containers": [ { "name": constants.DEFAULT_CONTAINER_NAME, - "image": image_spec.full_name, + "image": image_spec.container_image, "command": ["/usr/local/bin/_entrypoint.sh"], "args": [ (stage_mount.joinpath(v).as_posix() if isinstance(v, PurePath) else v) for v in payload.entrypoint diff --git a/snowflake/ml/jobs/_utils/spec_utils_test.py b/snowflake/ml/jobs/_utils/spec_utils_test.py index 163aa70b..69752aa2 100644 --- a/snowflake/ml/jobs/_utils/spec_utils_test.py +++ b/snowflake/ml/jobs/_utils/spec_utils_test.py @@ -1,3 +1,4 @@ +import json import os from pathlib import Path, PurePath from typing import Any, Optional @@ -8,6 +9,7 @@ from snowflake.ml.jobs._utils import constants, spec_utils, types from snowflake.ml.jobs._utils.test_file_helper import TestAsset +from snowflake.snowpark.row import Row def _get_dict_difference(expected: dict[str, Any], actual: dict[str, Any], prefix: str = "") -> str: @@ -316,9 +318,7 @@ def test_prepare_spec( ) -> None: with mock.patch("snowflake.ml.jobs._utils.spec_utils._get_image_spec") as mock_get_image_spec: mock_get_image_spec.return_value = types.ImageSpec( - repo="dummy_repo", - image_name="dummy_image", - image_tag="latest", + container_image="dummy_repo/dummy_image:latest", resource_requests=resources, resource_limits=resources, ) @@ -342,9 +342,7 @@ def test_prepare_spec_with_metrics(self) -> None: entrypoint = Path("src/main.py") with mock.patch("snowflake.ml.jobs._utils.spec_utils._get_image_spec") as mock_get_image_spec: mock_get_image_spec.return_value = types.ImageSpec( - repo="dummy_repo", - image_name="dummy_image", - image_tag="latest", + container_image="dummy_repo/dummy_image:latest", resource_requests=resources, resource_limits=resources, ) @@ -386,6 +384,398 @@ def test_get_runtime_image_tag(self, name: str, env_vars: dict[str, str], expect result = spec_utils._get_runtime_image_tag() self.assertEqual(result, expected) + @parameterized.named_parameters( # type: ignore[misc] + { + "testcase_name": "basic_spcs_runtime", + "query_result": [ + Row( + RESULT=json.dumps( + { + "runtime:spcs": { + "spcsContainerRuntime": { + "pythonVersion": "3.11.10", + "hardwareType": "CPU", + "runtimeContainerImage": ( + "/snowflake/images/snowflake_images/st_plat/runtime/" + "x86/runtime_image/snowbooks:1.5.0" + ), + } + } + } + ) + ) + ], + "expected": "/snowflake/images/snowflake_images/st_plat/runtime/x86/runtime_image/snowbooks:1.5.0", + }, + { + "testcase_name": "spcs_runtime_with_extra_fields", + "query_result": [ + Row( + RESULT=json.dumps( + { + "runtime:spcs": { + "spcsContainerRuntime": { + "pythonVersion": "3.11.11", + "hardwareType": "CPU", + "runtimeContainerImage": ( + "/snowflake/images/snowflake_images/st_plat/runtime/" + "x86/runtime_image/snowbooks:1.5.0" + ), + "randomKey1": "randomValue1", + "extraField": "test123", + }, + "randomOuterKey": "outerValue", + "extraOuterField": 42, + }, + "random_key": "random_value", + } + ) + ) + ], + "expected": "/snowflake/images/snowflake_images/st_plat/runtime/x86/runtime_image/snowbooks:1.5.0", + }, + { + "testcase_name": "gpu_hardware_no_match", + "query_result": [ + Row( + RESULT=json.dumps( + { + "runtime:spcs": { + "spcsContainerRuntime": { + "pythonVersion": "3.11.0", + "hardwareType": "GPU", + "runtimeContainerImage": ( + "/snowflake/images/snowflake_images/st_plat/runtime/" + "x86/runtime_image/snowbooks:1.5.0" + ), + } + } + } + ) + ) + ], + "expected": None, + }, + { + "testcase_name": "image_without_tag", + "query_result": [ + Row( + RESULT=json.dumps( + { + "runtime:spcs": { + "spcsContainerRuntime": { + "pythonVersion": "3.11.19", + "hardwareType": "CPU", + "runtimeContainerImage": ( + "/snowflake/images/snowflake_images/st_plat/runtime/x86/runtime_image/snowbooks" + ), + } + } + } + ) + ) + ], + "expected": "/snowflake/images/snowflake_images/st_plat/runtime/x86/runtime_image/snowbooks", + }, + { + "testcase_name": "mismatched_python_version", + "query_result": [ + Row( + RESULT=json.dumps( + { + "runtime:spcs": { + "spcsContainerImage": { + "pythonVersion": "3.10.19", + "hardwareType": "CPU", + "runtimeContainerImage": ( + "/snowflake/images/snowflake_images/st_plat/runtime/" + "x86/runtime_image/snowbooks:1.5.0" + ), + } + } + } + ) + ) + ], + "expected": None, + }, + { + "testcase_name": "multiple_same_python_version", + "query_result": [ + Row( + RESULT=json.dumps( + { + "MLJOB-RUNTIME-A:spcs": { + "createdOn": "2025-01-15T10:30:45.123Z", + "description": "First ML Job Runtime with Python 3.11", + "id": "nre-3.11-runtime-a", + "title": "ML Job Runtime A", + "spcsContainerRuntime": { + "pythonVersion": "3.11.5", + "hardwareType": "CPU", + "runtimeContainerImage": ( + "/snowflake/images/snowflake_images/st_plat/runtime/" + "x86/runtime_image/snowbooks:1.5.0" + ), + }, + }, + "MLJOB-RUNTIME-B:spcs": { + "createdOn": "2025-01-15T11:30:45.123Z", + "description": "Second ML Job Runtime with Python 3.11", + "id": "nre-3.11-runtime-b", + "title": "ML Job Runtime B", + "spcsContainerRuntime": { + "pythonVersion": "3.11.5", + "hardwareType": "CPU", + "runtimeContainerImage": ( + "/snowflake/images/snowflake_images/st_plat/runtime/" + "x86/runtime_image/snowbooks:1.6.0" + ), + }, + }, + } + ) + ) + ], + "expected": "/snowflake/images/snowflake_images/st_plat/runtime/x86/runtime_image/snowbooks:1.5.0", + }, + { + "testcase_name": "gpu_and_cpu_runtimes", + "query_result": [ + Row( + RESULT=json.dumps( + { + "MLJOB-GPU-RUNTIME:spcs": { + "createdOn": "2025-01-15T10:30:45.123Z", + "description": "GPU ML Job Runtime with Python 3.11", + "id": "nre-3.11-gpu-runtime", + "title": "ML Job GPU Runtime", + "spcsContainerRuntime": { + "pythonVersion": "3.11.5", + "hardwareType": "GPU", + "runtimeContainerImage": ( + "/snowflake/images/snowflake_images/st_plat/runtime/" + "x86/gpu_image/snowbooks:1.5.0" + ), + }, + }, + "MLJOB-CPU-RUNTIME:spcs": { + "createdOn": "2025-01-15T11:30:45.123Z", + "description": "CPU ML Job Runtime with Python 3.11", + "id": "nre-3.11-cpu-runtime", + "title": "ML Job CPU Runtime", + "spcsContainerRuntime": { + "pythonVersion": "3.11.5", + "hardwareType": "CPU", + "runtimeContainerImage": ( + "/snowflake/images/snowflake_images/st_plat/runtime/" + "x86/runtime_image/snowbooks:1.5.0" + ), + }, + }, + } + ) + ) + ], + "expected": "/snowflake/images/snowflake_images/st_plat/runtime/x86/runtime_image/snowbooks:1.5.0", + }, + { + "testcase_name": "runtime_without_spcs_suffix", + "query_result": [ + Row( + RESULT=json.dumps( + { + "MLJOB-RUNTIME-NO-SUFFIX": { + "createdOn": "2025-01-15T10:30:45.123Z", + "description": "Runtime without spcs suffix", + "id": "nre-3.11-no-suffix", + "title": "Runtime No Suffix", + "spcsContainerRuntime": { + "pythonVersion": "3.11.5", + "hardwareType": "CPU", + "runtimeContainerImage": ( + "/snowflake/images/snowflake_images/st_plat/runtime/" + "x86/runtime_image/snowbooks:1.5.0" + ), + }, + }, + } + ) + ) + ], + "expected": "/snowflake/images/snowflake_images/st_plat/runtime/x86/runtime_image/snowbooks:1.5.0", + }, + { + "testcase_name": "python_version_without_patch", + "query_result": [ + Row( + RESULT=json.dumps( + { + "MLJOB-RUNTIME-3.11:spcs": { + "createdOn": "2025-01-15T10:30:45.123Z", + "description": "ML Job Runtime with Python 3.11", + "id": "nre-3.11-runtime", + "title": "ML Job Runtime 3.11", + "spcsContainerRuntime": { + "pythonVersion": "3.11", + "hardwareType": "CPU", + "runtimeContainerImage": ( + "/snowflake/images/snowflake_images/st_plat/runtime/" + "x86/runtime_image/snowbooks:1.5.0" + ), + }, + } + } + ) + ) + ], + "expected": "/snowflake/images/snowflake_images/st_plat/runtime/x86/runtime_image/snowbooks:1.5.0", + }, + { + "testcase_name": "mixed_warehouse_and_spcs", + "query_result": [ + Row( + RESULT=json.dumps( + { + "WH-RUNTIME-2.0:warehouse": { + "createdOn": "2025-04-11T20:03:52.569Z", + "description": "includes Python 3.11", + "id": "nre-3.10-2.0", + "title": "Snowflake Warehouse Runtime 2.0", + "warehouseRuntime": {"pythonEnvironmentId": "3.11.2"}, + }, + "MLJOB-CORRECT:spcs": { + "createdOn": "2025-07-28T21:49:39.685Z", + "description": "Correct Python 3.11 runtime", + "id": "nre-3.11-mljob-test", + "title": "Correct Runtime Test", + "spcsContainerRuntime": { + "hardwareType": "CPU", + "pythonVersion": "3.11.2", + "runtimeContainerImage": ( + "/snowflake/images/snowflake_images/st_plat/runtime/" + "x86/runtime_image/snowbooks:1.5.0" + ), + }, + }, + } + ) + ) + ], + "expected": "/snowflake/images/snowflake_images/st_plat/runtime/x86/runtime_image/snowbooks:1.5.0", + }, + { + "testcase_name": "only_warehouse_runtimes", + "query_result": [ + Row( + RESULT=json.dumps( + { + "WH-RUNTIME-2.0:warehouse": { + "createdOn": "2025-04-11T20:03:52.569Z", + "description": "includes Python 3.11", + "id": "nre-3.11-2.0", + "title": "Snowflake Warehouse Runtime 2.0", + "warehouseRuntime": {"pythonEnvironmentId": "3.11-2.0"}, + }, + "WH-RUNTIME-1.0:warehouse": { + "createdOn": "2025-04-11T20:03:51.363Z", + "description": "includes Python 3.9", + "id": "nre-3.9-1.0", + "title": "Snowflake Warehouse Runtime 1.0", + "warehouseRuntime": {"pythonEnvironmentId": "3.9-1.0"}, + }, + } + ) + ) + ], + "expected": None, + }, + { + "testcase_name": "invalid_python_version_format", + "query_result": [ + Row( + RESULT=json.dumps( + { + "INVALID-BAD-PYVER:spcs": { + "createdOn": "2025-01-15T10:30:45.123Z", + "description": "Invalid runtime with bad pythonVersion", + "id": "nre-invalid-1", + "title": "Invalid Runtime 1", + "spcsContainerRuntime": { + "pythonVersion": "invalid.version.format", + "hardwareType": "CPU", + "runtimeContainerImage": "/snowflake/images/image:1.5.0", + }, + } + } + ) + ) + ], + "expected": None, + }, + { + "testcase_name": "missing_python_version", + "query_result": [ + Row( + RESULT=json.dumps( + { + "INVALID-MISSING-PYVER:spcs": { + "createdOn": "2025-01-15T10:30:45.123Z", + "description": "Invalid runtime missing pythonVersion", + "id": "nre-invalid-2", + "title": "Invalid Runtime 2", + "spcsContainerRuntime": { + "hardwareType": "CPU", + "runtimeContainerImage": "/snowflake/images/image:1.5.0", + }, + } + } + ) + ) + ], + "expected": None, + }, + { + "testcase_name": "missing_hardware_type", + "query_result": [ + Row( + RESULT=json.dumps( + { + "INVALID-MISSING-HW:spcs": { + "createdOn": "2025-01-15T10:30:45.123Z", + "description": "Invalid runtime missing hardwareType", + "id": "nre-invalid-3", + "title": "Invalid Runtime 3", + "spcsContainerRuntime": { + "pythonVersion": "3.11.5", + "runtimeContainerImage": "/snowflake/images/image:1.5.0", + }, + } + } + ) + ) + ], + "expected": None, + }, + ) + def test_get_runtime_image( + self, + query_result: list[Row], + expected: Optional[str], + ) -> None: + """Test _get_runtime_image function core scenarios.""" + with mock.patch("snowflake.ml.jobs._utils.query_helper.run_query") as mock_query, mock.patch( + "sys.version_info", new=mock.Mock(major=3, minor=11) + ): + + mock_query.return_value = query_result + + result = spec_utils._get_runtime_image( + mock.Mock(), + "CPU", + ) + self.assertEqual(expected, result) + if __name__ == "__main__": absltest.main() diff --git a/snowflake/ml/jobs/_utils/stage_utils.py b/snowflake/ml/jobs/_utils/stage_utils.py index 195d629b..c8e8168a 100644 --- a/snowflake/ml/jobs/_utils/stage_utils.py +++ b/snowflake/ml/jobs/_utils/stage_utils.py @@ -121,15 +121,28 @@ def __fspath__(self) -> str: return self._compose_path(self._path) def joinpath(self, *args: Union[str, PathLike[str]]) -> "StagePath": + """ + Joins the given path arguments to the current path, + mimicking the behavior of pathlib.Path.joinpath. + If the argument is a stage path (i.e., an absolute path), + it overrides the current path and is returned as the final path. + If the argument is a normal path, it is joined with the current relative path + using self._path.joinpath(arg). + + Args: + *args: Path components to join. + + Returns: + A new StagePath with the joined path. + + Raises: + NotImplementedError: the argument is a stage path. + """ path = self for arg in args: - path = path._make_child(arg) + if isinstance(arg, StagePath): + raise NotImplementedError + else: + # the arg might be an absolute path, so we need to remove the leading '/' + path = StagePath(f"{path.root}/{path._path.joinpath(arg).as_posix().lstrip('/')}") return path - - def _make_child(self, path: Union[str, PathLike[str]]) -> "StagePath": - stage_path = path if isinstance(path, StagePath) else StagePath(os.fspath(path)) - if self.root == stage_path.root: - child_path = self._path.joinpath(stage_path._path) - return StagePath(self._compose_path(child_path)) - else: - return stage_path diff --git a/snowflake/ml/jobs/_utils/stage_utils_test.py b/snowflake/ml/jobs/_utils/stage_utils_test.py index f79c4d90..64fd9793 100644 --- a/snowflake/ml/jobs/_utils/stage_utils_test.py +++ b/snowflake/ml/jobs/_utils/stage_utils_test.py @@ -1,6 +1,6 @@ from absl.testing import absltest, parameterized -from snowflake.ml.jobs._utils import stage_utils +from snowflake.ml.jobs._utils import payload_utils """ StagePath inherits the PurePosixPath @@ -21,8 +21,8 @@ class StageUtilsTests(parameterized.TestCase): ("snow://headless/abc/versions/v9.8.7///dirs", "snow://headless/abc/versions/v9.8.7/main.py", False), ) def test_is_relative_to(self, path1: str, path2: str, expected: bool) -> None: - stagePath1 = stage_utils.StagePath(path1) - stagePath2 = stage_utils.StagePath(path2) + stagePath1 = payload_utils.resolve_path(path1) + stagePath2 = payload_utils.resolve_path(path2) self.assertEqual(stagePath2.is_relative_to(stagePath1), expected) @parameterized.parameters( # type: ignore[misc] @@ -32,7 +32,7 @@ def test_is_relative_to(self, path1: str, path2: str, expected: bool) -> None: ("@test_stage/", "@test_stage"), ) def test_root(self, path: str, expected: str) -> None: - self.assertEqual(stage_utils.StagePath(path).root, expected) + self.assertEqual(payload_utils.resolve_path(path).root, expected) @parameterized.parameters( # type: ignore[misc] ("@test_stage/src/", "@test_stage/src"), @@ -41,7 +41,7 @@ def test_root(self, path: str, expected: str) -> None: ("snow://headless/abc/versions/v9.8.7", "snow://headless/abc/versions/v9.8.7"), ) def test_absolute(self, path: str, expected_path: str) -> None: - self.assertEqual(stage_utils.StagePath(path).absolute().as_posix(), expected_path) + self.assertEqual(payload_utils.resolve_path(path).absolute().as_posix(), expected_path) @parameterized.parameters( # type: ignore[misc] ("@test_stage/src/", "@test_stage"), @@ -52,7 +52,7 @@ def test_absolute(self, path: str, expected_path: str) -> None: ("snow://headless/abc/versions/v9.8.7/dirs/main.py", "snow://headless/abc/versions/v9.8.7/dirs"), ) def test_parent(self, path: str, expected_path: str) -> None: - self.assertEqual(stage_utils.StagePath(path).parent.as_posix(), expected_path) + self.assertEqual(payload_utils.resolve_path(path).parent.as_posix(), expected_path) @parameterized.parameters( # type: ignore[misc] ("@test_stage/src/", "@test_stage/src"), @@ -65,7 +65,7 @@ def test_parent(self, path: str, expected_path: str) -> None: ("snow://headless/abc/versions/v9.8.7/dirs/dir", "snow://headless/abc/versions/v9.8.7/dirs/dir"), ) def test_posix(self, path: str, expected_path: str) -> None: - self.assertEqual(stage_utils.StagePath(path).as_posix(), expected_path) + self.assertEqual(payload_utils.resolve_path(path).as_posix(), expected_path) @parameterized.parameters( # type: ignore[misc] ("@test_stage/src/", "@test_stage/src", True), @@ -78,29 +78,22 @@ def test_posix(self, path: str, expected_path: str) -> None: ), ) def test_equal(self, path1: str, path2: str, expected_result: bool) -> None: - self.assertEqual(stage_utils.StagePath(path1) == stage_utils.StagePath(path2), expected_result) + self.assertEqual(payload_utils.resolve_path(path1) == payload_utils.resolve_path(path2), expected_result) @parameterized.parameters( # type: ignore[misc] - ("@test_stage/src/", ("@test_stage/src",), "@test_stage/src/src"), - ("@test_stage/dir1", ("@test_stage/dir2",), "@test_stage/dir1/dir2"), - ( - "snow://headless/abc/versions/v9.8.7/src/", - ("snow://headless/abc/versions/v9.8.7/src",), - "snow://headless/abc/versions/v9.8.7/src/src", - ), + ("@test_stage/dir1", ("dir2",), "@test_stage/dir1/dir2"), ( "snow://headless/abc/versions/v9.8.7/dir1", - ("snow://headless/abc/versions/v9.8.7/dir2",), + ("dir2",), "snow://headless/abc/versions/v9.8.7/dir1/dir2", ), - ("@test_stage/src/", ("@test_stage/dir1", "@test_stage/dir2"), "@test_stage/src/dir1/dir2"), - ("@test_stage/src/", ("@test_stage1/dir1", "@test_stage1/dir2"), "@test_stage1/dir1/dir2"), + ("@test_stage/src/", ("dir1", "/dir2"), "@test_stage/dir2"), ) def test_joinpath(self, path1: str, paths: tuple[str], expected_path: str) -> None: - stagePath1 = stage_utils.StagePath(path1) + stagePath1 = payload_utils.resolve_path(path1) stagePaths = [] for path in paths: - stagePaths.append(stage_utils.StagePath(path)) + stagePaths.append(payload_utils.resolve_path(path)) self.assertEqual(stagePath1.joinpath(*tuple(stagePaths)).as_posix(), expected_path) diff --git a/snowflake/ml/jobs/_utils/types.py b/snowflake/ml/jobs/_utils/types.py index 8b41c465..51501a42 100644 --- a/snowflake/ml/jobs/_utils/types.py +++ b/snowflake/ml/jobs/_utils/types.py @@ -30,6 +30,10 @@ def suffix(self) -> str: def parent(self) -> "PayloadPath": ... + @property + def root(self) -> str: + ... + def exists(self) -> bool: ... @@ -98,12 +102,6 @@ class ComputeResources: @dataclass(frozen=True) class ImageSpec: - repo: str - image_name: str - image_tag: str resource_requests: ComputeResources resource_limits: ComputeResources - - @property - def full_name(self) -> str: - return f"{self.repo}/{self.image_name}:{self.image_tag}" + container_image: str diff --git a/snowflake/ml/jobs/job.py b/snowflake/ml/jobs/job.py index bf481e83..ae045299 100644 --- a/snowflake/ml/jobs/job.py +++ b/snowflake/ml/jobs/job.py @@ -199,7 +199,7 @@ def wait(self, timeout: float = -1) -> types.JOB_STATUS: elapsed = time.monotonic() - start_time if elapsed >= timeout >= 0: raise TimeoutError(f"Job {self.name} did not complete within {timeout} seconds") - elif status == "PENDING" and not warning_shown and elapsed >= 2: # Only show warning after 2s + elif status == "PENDING" and not warning_shown and elapsed >= 5: # Only show warning after 5s pool_info = _get_compute_pool_info(self._session, self._compute_pool) if (pool_info.max_nodes - pool_info.active_nodes) < self.min_instances: logger.warning( diff --git a/snowflake/ml/jobs/manager.py b/snowflake/ml/jobs/manager.py index 262ca69c..598f8ba4 100644 --- a/snowflake/ml/jobs/manager.py +++ b/snowflake/ml/jobs/manager.py @@ -426,7 +426,6 @@ def _submit_job( Raises: ValueError: If database or schema value(s) are invalid - SnowparkSQLException: If there is an error submitting the job. """ session = session or get_active_session() @@ -504,18 +503,7 @@ def _submit_job( query_text, params = _generate_submission_query( spec, external_access_integrations, query_warehouse, target_instances, session, compute_pool, job_id ) - try: - _ = query_helper.run_query(session, query_text, params=params) - except SnowparkSQLException as e: - if "Invalid spec: unknown option 'resourceManagement' for 'spec'." in e.message: - logger.warning("Dropping 'resourceManagement' from spec because control policy is not enabled.") - spec["spec"].pop("resourceManagement", None) - query_text, params = _generate_submission_query( - spec, external_access_integrations, query_warehouse, target_instances, session, compute_pool, job_id - ) - _ = query_helper.run_query(session, query_text, params=params) - else: - raise + _ = query_helper.run_query(session, query_text, params=params) return get_job(job_id, session=session) diff --git a/snowflake/ml/model/BUILD.bazel b/snowflake/ml/model/BUILD.bazel index a71f7e96..4e66ae3b 100644 --- a/snowflake/ml/model/BUILD.bazel +++ b/snowflake/ml/model/BUILD.bazel @@ -62,6 +62,12 @@ py_library( ], ) +py_library( + name = "inference_engine", + srcs = ["inference_engine.py"], + deps = [], +) + py_library( name = "model", srcs = ["__init__.py"], diff --git a/snowflake/ml/model/_client/model/model_version_impl.py b/snowflake/ml/model/_client/model/model_version_impl.py index 2c2859e5..80057dbe 100644 --- a/snowflake/ml/model/_client/model/model_version_impl.py +++ b/snowflake/ml/model/_client/model/model_version_impl.py @@ -707,6 +707,128 @@ def _load_from_lineage_node(session: Session, name: str, version: str) -> "Model version_name=sql_identifier.SqlIdentifier(version), ) + def _get_inference_engine_args( + self, experimental_options: Optional[dict[str, Any]] + ) -> Optional[service_ops.InferenceEngineArgs]: + + if not experimental_options: + return None + + if "inference_engine" not in experimental_options: + raise ValueError("inference_engine is required in experimental_options") + + return service_ops.InferenceEngineArgs( + inference_engine=experimental_options["inference_engine"], + inference_engine_args_override=experimental_options.get("inference_engine_args_override"), + ) + + def _enrich_inference_engine_args( + self, + inference_engine_args: service_ops.InferenceEngineArgs, + gpu_requests: Optional[Union[str, int]] = None, + ) -> Optional[service_ops.InferenceEngineArgs]: + """Enrich inference engine args with model path and tensor parallelism settings. + + Args: + inference_engine_args: The original inference engine args + gpu_requests: The number of GPUs requested + + Returns: + Enriched inference engine args + + Raises: + ValueError: Invalid gpu_requests + """ + if inference_engine_args.inference_engine_args_override is None: + inference_engine_args.inference_engine_args_override = [] + + # Get model stage path and strip off "snow://" prefix + model_stage_path = self._model_ops.get_model_version_stage_path( + database_name=None, + schema_name=None, + model_name=self._model_name, + version_name=self._version_name, + ) + + # Strip "snow://" prefix + if model_stage_path.startswith("snow://"): + model_stage_path = model_stage_path.replace("snow://", "", 1) + + # Always overwrite the model key by appending + inference_engine_args.inference_engine_args_override.append(f"--model={model_stage_path}") + + gpu_count = None + + # Set tensor-parallelism if gpu_requests is specified + if gpu_requests is not None: + # assert gpu_requests is a string or an integer before casting to int + if isinstance(gpu_requests, str) or isinstance(gpu_requests, int): + try: + gpu_count = int(gpu_requests) + except ValueError: + raise ValueError(f"Invalid gpu_requests: {gpu_requests}") + + if gpu_count is not None: + if gpu_count > 0: + inference_engine_args.inference_engine_args_override.append(f"--tensor-parallel-size={gpu_count}") + else: + raise ValueError(f"Invalid gpu_requests: {gpu_requests}") + + return inference_engine_args + + def _check_huggingface_text_generation_model( + self, + statement_params: Optional[dict[str, Any]] = None, + ) -> None: + """Check if the model is a HuggingFace pipeline with text-generation task. + + Args: + statement_params: Optional dictionary of statement parameters to include + in the SQL command to fetch model spec. + + Raises: + ValueError: If the model is not a HuggingFace text-generation model. + """ + # Fetch model spec + model_spec = self._model_ops._fetch_model_spec( + database_name=None, + schema_name=None, + model_name=self._model_name, + version_name=self._version_name, + statement_params=statement_params, + ) + + # Check if model_type is huggingface_pipeline + model_type = model_spec.get("model_type") + if model_type != "huggingface_pipeline": + raise ValueError( + f"Inference engine is only supported for HuggingFace text-generation models. " + f"Found model_type: {model_type}" + ) + + # Check if model supports text-generation task + # There should only be one model in the list because we don't support multiple models in a single model spec + models = model_spec.get("models", {}) + is_text_generation = False + found_tasks: list[str] = [] + + # As long as the model supports text-generation task, we can use it + for _, model_info in models.items(): + options = model_info.get("options", {}) + task = options.get("task") + if task: + found_tasks.append(str(task)) + if task == "text-generation": + is_text_generation = True + break + + if not is_text_generation: + tasks_str = ", ".join(found_tasks) + found_tasks_str = ( + f"Found task(s): {tasks_str} in model spec." if found_tasks else "No task found in model spec." + ) + raise ValueError(f"Inference engine is only supported for task 'text-generation'. {found_tasks_str}") + @overload def create_service( self, @@ -714,7 +836,7 @@ def create_service( service_name: str, image_build_compute_pool: Optional[str] = None, service_compute_pool: str, - image_repo: str, + image_repo: Optional[str] = None, ingress_enabled: bool = False, max_instances: int = 1, cpu_requests: Optional[str] = None, @@ -725,6 +847,7 @@ def create_service( force_rebuild: bool = False, build_external_access_integration: Optional[str] = None, block: bool = True, + experimental_options: Optional[dict[str, Any]] = None, ) -> Union[str, async_job.AsyncJob]: """Create an inference service with the given spec. @@ -735,7 +858,8 @@ def create_service( the service compute pool if None. service_compute_pool: The name of the compute pool used to run the inference service. image_repo: The name of the image repository, can be fully qualified. If not fully qualified, the database - or schema of the model will be used. + or schema of the model will be used. This can be None, in that case a default hidden image repository + will be used. ingress_enabled: If true, creates an service endpoint associated with the service. User must have BIND SERVICE ENDPOINT privilege on the account. max_instances: The maximum number of inference service instances to run. The same value it set to @@ -756,6 +880,10 @@ def create_service( block: A bool value indicating whether this function will wait until the service is available. When it is ``False``, this function executes the underlying service creation asynchronously and returns an :class:`AsyncJob`. + experimental_options: Experimental options for the service creation with custom inference engine. + Currently, only `inference_engine` and `inference_engine_args_override` are supported. + `inference_engine` is the name of the inference engine to use. + `inference_engine_args_override` is a list of string arguments to pass to the inference engine. """ ... @@ -766,7 +894,7 @@ def create_service( service_name: str, image_build_compute_pool: Optional[str] = None, service_compute_pool: str, - image_repo: str, + image_repo: Optional[str] = None, ingress_enabled: bool = False, max_instances: int = 1, cpu_requests: Optional[str] = None, @@ -777,6 +905,7 @@ def create_service( force_rebuild: bool = False, build_external_access_integrations: Optional[list[str]] = None, block: bool = True, + experimental_options: Optional[dict[str, Any]] = None, ) -> Union[str, async_job.AsyncJob]: """Create an inference service with the given spec. @@ -787,7 +916,8 @@ def create_service( the service compute pool if None. service_compute_pool: The name of the compute pool used to run the inference service. image_repo: The name of the image repository, can be fully qualified. If not fully qualified, the database - or schema of the model will be used. + or schema of the model will be used. This can be None, in that case a default hidden image repository + will be used. ingress_enabled: If true, creates an service endpoint associated with the service. User must have BIND SERVICE ENDPOINT privilege on the account. max_instances: The maximum number of inference service instances to run. The same value it set to @@ -808,6 +938,10 @@ def create_service( block: A bool value indicating whether this function will wait until the service is available. When it is ``False``, this function executes the underlying service creation asynchronously and returns an :class:`AsyncJob`. + experimental_options: Experimental options for the service creation with custom inference engine. + Currently, only `inference_engine` and `inference_engine_args_override` are supported. + `inference_engine` is the name of the inference engine to use. + `inference_engine_args_override` is a list of string arguments to pass to the inference engine. """ ... @@ -832,7 +966,7 @@ def create_service( service_name: str, image_build_compute_pool: Optional[str] = None, service_compute_pool: str, - image_repo: str, + image_repo: Optional[str] = None, ingress_enabled: bool = False, max_instances: int = 1, cpu_requests: Optional[str] = None, @@ -844,6 +978,7 @@ def create_service( build_external_access_integration: Optional[str] = None, build_external_access_integrations: Optional[list[str]] = None, block: bool = True, + experimental_options: Optional[dict[str, Any]] = None, ) -> Union[str, async_job.AsyncJob]: """Create an inference service with the given spec. @@ -854,7 +989,8 @@ def create_service( the service compute pool if None. service_compute_pool: The name of the compute pool used to run the inference service. image_repo: The name of the image repository, can be fully qualified. If not fully qualified, the database - or schema of the model will be used. + or schema of the model will be used. This can be None, in that case a default hidden image repository + will be used. ingress_enabled: If true, creates an service endpoint associated with the service. User must have BIND SERVICE ENDPOINT privilege on the account. max_instances: The maximum number of inference service instances to run. The same value it set to @@ -877,6 +1013,11 @@ def create_service( block: A bool value indicating whether this function will wait until the service is available. When it is False, this function executes the underlying service creation asynchronously and returns an AsyncJob. + experimental_options: Experimental options for the service creation with custom inference engine. + Currently, only `inference_engine` and `inference_engine_args_override` are supported. + `inference_engine` is the name of the inference engine to use. + `inference_engine_args_override` is a list of string arguments to pass to the inference engine. + Raises: ValueError: Illegal external access integration arguments. @@ -885,6 +1026,9 @@ def create_service( Returns: If `block=True`, return result information about service creation from server. Otherwise, return the service creation AsyncJob. + + Raises: + ValueError: Illegal external access integration arguments. """ statement_params = telemetry.get_statement_params( project=_TELEMETRY_PROJECT, @@ -906,7 +1050,18 @@ def create_service( build_external_access_integrations = [build_external_access_integration] service_db_id, service_schema_id, service_id = sql_identifier.parse_fully_qualified_name(service_name) - image_repo_db_id, image_repo_schema_id, image_repo_id = sql_identifier.parse_fully_qualified_name(image_repo) + + # Check if model is HuggingFace text-generation before doing inference engine checks + if experimental_options: + self._check_huggingface_text_generation_model(statement_params) + + inference_engine_args: Optional[service_ops.InferenceEngineArgs] = self._get_inference_engine_args( + experimental_options + ) + + # Enrich inference engine args if inference engine is specified + if inference_engine_args is not None: + inference_engine_args = self._enrich_inference_engine_args(inference_engine_args, gpu_requests) from snowflake.ml.model import event_handler from snowflake.snowpark import exceptions @@ -929,7 +1084,7 @@ def create_service( else sql_identifier.SqlIdentifier(service_compute_pool) ), service_compute_pool_name=sql_identifier.SqlIdentifier(service_compute_pool), - image_repo=image_repo, + image_repo_name=image_repo, ingress_enabled=ingress_enabled, max_instances=max_instances, cpu_requests=cpu_requests, @@ -946,6 +1101,7 @@ def create_service( block=block, statement_params=statement_params, progress_status=status, + inference_engine_args=inference_engine_args, ) status.update(label="Model service created successfully", state="complete", expanded=False) return result @@ -1039,7 +1195,7 @@ def _run_job( *, job_name: str, compute_pool: str, - image_repo: str, + image_repo: Optional[str] = None, output_table_name: str, function_name: Optional[str] = None, cpu_requests: Optional[str] = None, @@ -1074,7 +1230,7 @@ def _run_job( job_name=job_id, compute_pool_name=sql_identifier.SqlIdentifier(compute_pool), warehouse_name=sql_identifier.SqlIdentifier(warehouse), - image_repo=image_repo, + image_repo_name=image_repo, output_table_database_name=output_table_db_id, output_table_schema_name=output_table_schema_id, output_table_name=output_table_id, diff --git a/snowflake/ml/model/_client/model/model_version_impl_test.py b/snowflake/ml/model/_client/model/model_version_impl_test.py index 1ead2422..2d9dc178 100644 --- a/snowflake/ml/model/_client/model/model_version_impl_test.py +++ b/snowflake/ml/model/_client/model/model_version_impl_test.py @@ -1,7 +1,7 @@ import os import pathlib import tempfile -from typing import cast +from typing import Any, cast from unittest import mock import pandas as pd @@ -9,7 +9,7 @@ from snowflake.ml._internal import platform_capabilities as pc from snowflake.ml._internal.utils import sql_identifier -from snowflake.ml.model import model_signature, task, type_hints +from snowflake.ml.model import inference_engine, model_signature, task, type_hints from snowflake.ml.model._client.model import model_version_impl from snowflake.ml.model._client.ops import metadata_ops, model_ops, service_ops from snowflake.ml.model._model_composer import model_composer @@ -793,7 +793,7 @@ def test_create_service(self) -> None: service_name=sql_identifier.SqlIdentifier("SERVICE"), image_build_compute_pool_name=sql_identifier.SqlIdentifier("IMAGE_BUILD_COMPUTE_POOL"), service_compute_pool_name=sql_identifier.SqlIdentifier("SERVICE_COMPUTE_POOL"), - image_repo="IMAGE_REPO", + image_repo_name="IMAGE_REPO", ingress_enabled=False, max_instances=3, cpu_requests="CPU", @@ -806,6 +806,7 @@ def test_create_service(self) -> None: block=True, statement_params=mock.ANY, progress_status=mock_progress_status, + inference_engine_args=None, ) def test_create_service_same_pool(self) -> None: @@ -841,7 +842,7 @@ def test_create_service_same_pool(self) -> None: service_name=sql_identifier.SqlIdentifier("SERVICE"), image_build_compute_pool_name=sql_identifier.SqlIdentifier("SERVICE_COMPUTE_POOL"), service_compute_pool_name=sql_identifier.SqlIdentifier("SERVICE_COMPUTE_POOL"), - image_repo="IMAGE_REPO", + image_repo_name="IMAGE_REPO", ingress_enabled=False, max_instances=3, cpu_requests="CPU", @@ -854,6 +855,7 @@ def test_create_service_same_pool(self) -> None: block=True, statement_params=mock.ANY, progress_status=mock_progress_status, + inference_engine_args=None, ) def test_create_service_no_eai(self) -> None: @@ -889,7 +891,7 @@ def test_create_service_no_eai(self) -> None: service_name=sql_identifier.SqlIdentifier("SERVICE"), image_build_compute_pool_name=sql_identifier.SqlIdentifier("IMAGE_BUILD_COMPUTE_POOL"), service_compute_pool_name=sql_identifier.SqlIdentifier("SERVICE_COMPUTE_POOL"), - image_repo="IMAGE_REPO", + image_repo_name="IMAGE_REPO", ingress_enabled=False, max_instances=3, cpu_requests="CPU", @@ -902,6 +904,7 @@ def test_create_service_no_eai(self) -> None: block=True, statement_params=mock.ANY, progress_status=mock_progress_status, + inference_engine_args=None, ) def test_create_service_async_job(self) -> None: @@ -938,7 +941,7 @@ def test_create_service_async_job(self) -> None: service_name=sql_identifier.SqlIdentifier("SERVICE"), image_build_compute_pool_name=sql_identifier.SqlIdentifier("IMAGE_BUILD_COMPUTE_POOL"), service_compute_pool_name=sql_identifier.SqlIdentifier("SERVICE_COMPUTE_POOL"), - image_repo="IMAGE_REPO", + image_repo_name="IMAGE_REPO", ingress_enabled=False, max_instances=3, cpu_requests="CPU", @@ -951,6 +954,7 @@ def test_create_service_async_job(self) -> None: block=False, statement_params=mock.ANY, progress_status=mock_progress_status, + inference_engine_args=None, ) def test_list_services(self) -> None: @@ -1006,6 +1010,86 @@ def test_delete_service(self) -> None: statement_params=mock.ANY, ) + def test_create_service_with_experimental_options(self) -> None: + """Test create_service with experimental options for inference engine.""" + with ( + mock.patch.object(self.m_mv._service_ops, "create_service") as mock_create_service, + mock.patch.object( + self.m_mv, "_check_huggingface_text_generation_model" + ) as mock_check_huggingface_text_generation_model, + mock.patch.object( + self.m_mv._model_ops, + "get_model_version_stage_path", + return_value="snow://model/DB.SCHEMA.MODEL/versions/v1/", + ), + mock.patch.object( + self.m_mv._model_ops, + "_fetch_model_spec", + return_value={ + "model_type": "huggingface_pipeline", + "models": { + "model1": { + "model_type": "huggingface_pipeline", + "options": {"task": "text-generation"}, + } + }, + }, + ), + ): + # Test with inference engine and GPU + self.m_mv.create_service( + service_name="SERVICE", + service_compute_pool="SERVICE_COMPUTE_POOL", + image_repo="IMAGE_REPO", + gpu_requests="4", + experimental_options={ + "inference_engine": inference_engine.InferenceEngine.VLLM, + "inference_engine_args_override": ["--max_tokens=1000", "--temperature=0.8"], + }, + ) + + # This check should happen when experimental_options is provided + mock_check_huggingface_text_generation_model.assert_called_once() + + # Verify that the enriched kwargs were passed to create_service + mock_create_service.assert_called_once() + call_args = mock_create_service.call_args + + # Check that inference_engine is passed correctly + self.assertEqual( + call_args.kwargs["inference_engine_args"].inference_engine, inference_engine.InferenceEngine.VLLM + ) + + # Check that inference_engine_args is enriched correctly + expected_args = [ + "--max_tokens=1000", + "--temperature=0.8", + "--model=model/DB.SCHEMA.MODEL/versions/v1/", + "--tensor-parallel-size=4", + ] + + self.assertEqual( + call_args.kwargs["inference_engine_args"].inference_engine_args_override, + expected_args, + ) + + def test_create_service_without_experimental_options(self) -> None: + """Test create_service without experimental options to ensure existing behavior is preserved.""" + with mock.patch.object(self.m_mv._service_ops, "create_service") as mock_create_service: + # Test without experimental_options + self.m_mv.create_service( + service_name="SERVICE", + service_compute_pool="SERVICE_COMPUTE_POOL", + image_repo="IMAGE_REPO", + gpu_requests="2", + ) + + # Verify that None is passed for inference engine parameters + mock_create_service.assert_called_once() + call_args = mock_create_service.call_args + + self.assertIsNone(call_args.kwargs["inference_engine_args"]) + def test_run_job(self) -> None: m_df = mock_data_frame.MockDataFrame() m_methods = [ @@ -1095,7 +1179,7 @@ def test_run_job(self) -> None: job_name=sql_identifier.SqlIdentifier("TEST_JOB"), compute_pool_name=sql_identifier.SqlIdentifier("TEST_COMPUTE_POOL"), warehouse_name="TEST_WAREHOUSE", - image_repo="TEST_IMAGE_REPO", + image_repo_name="TEST_IMAGE_REPO", output_table_database_name=None, output_table_schema_name=None, output_table_name=sql_identifier.SqlIdentifier("TEST_OUTPUT_TABLE"), @@ -1142,7 +1226,7 @@ def test_run_job(self) -> None: job_name=sql_identifier.SqlIdentifier("TEST_JOB"), compute_pool_name=sql_identifier.SqlIdentifier("TEST_COMPUTE_POOL"), warehouse_name=sql_identifier.SqlIdentifier("TEST_WAREHOUSE"), - image_repo="DB.SCHEMA.TEST_IMAGE_REPO", + image_repo_name="DB.SCHEMA.TEST_IMAGE_REPO", output_table_database_name=sql_identifier.SqlIdentifier("DB"), output_table_schema_name=sql_identifier.SqlIdentifier("SCHEMA"), output_table_name=sql_identifier.SqlIdentifier("TEST_OUTPUT_TABLE"), @@ -1367,6 +1451,243 @@ def test_repr_html_happy_path_function_details(self) -> None: self.assertIn(" None: + # Test with None + inference_engine_args = self.m_mv._get_inference_engine_args(None) + self.assertIsNone(inference_engine_args) + + # Test with empty experimental_options + inference_engine_args = self.m_mv._get_inference_engine_args({}) + self.assertIsNone(inference_engine_args) + + # Test with experimental_options missing inference_engine key + with self.assertRaises(ValueError) as cm: + self.m_mv._get_inference_engine_args({"other_key": "value"}) + self.assertEqual(str(cm.exception), "inference_engine is required in experimental_options") + + # Test with only inference_engine (no args_override) + experimental_options: dict[str, Any] = {"inference_engine": inference_engine.InferenceEngine.VLLM} + inference_engine_args = self.m_mv._get_inference_engine_args(experimental_options) + assert inference_engine_args is not None + self.assertEqual(inference_engine_args.inference_engine, inference_engine.InferenceEngine.VLLM) + self.assertIsNone(inference_engine_args.inference_engine_args_override) + + # Test with inference_engine and args_override + experimental_options = { + "inference_engine": inference_engine.InferenceEngine.VLLM, + "inference_engine_args_override": ["--max_tokens=100", "--temperature=0.7"], + } + inference_engine_args = self.m_mv._get_inference_engine_args(experimental_options) + assert inference_engine_args is not None + self.assertEqual(inference_engine_args.inference_engine, inference_engine.InferenceEngine.VLLM) + self.assertEqual( + inference_engine_args.inference_engine_args_override, ["--max_tokens=100", "--temperature=0.7"] + ) + + # Test with inference_engine and empty args_override + experimental_options = { + "inference_engine": inference_engine.InferenceEngine.VLLM, + "inference_engine_args_override": [], + } + inference_engine_args = self.m_mv._get_inference_engine_args(experimental_options) + assert inference_engine_args is not None + self.assertEqual(inference_engine_args.inference_engine, inference_engine.InferenceEngine.VLLM) + self.assertEqual(inference_engine_args.inference_engine_args_override, []) + + def test_enrich_inference_engine_args(self) -> None: + """Test _enrich_inference_engine_args method with various inputs.""" + # Mock get_model_version_stage_path to return a predictable path + with mock.patch.object( + self.m_mv._model_ops, + "get_model_version_stage_path", + return_value="snow://model/TEMP.test.MODEL/versions/v1/", + ) as mock_get_path: + # Test with args=None and no GPU + enriched = self.m_mv._enrich_inference_engine_args( + service_ops.InferenceEngineArgs( + inference_engine=inference_engine.InferenceEngine.VLLM, + inference_engine_args_override=None, + ), + gpu_requests=None, + ) + assert enriched is not None + self.assertEqual(enriched.inference_engine, inference_engine.InferenceEngine.VLLM) + self.assertEqual(enriched.inference_engine_args_override, ["--model=model/TEMP.test.MODEL/versions/v1/"]) + mock_get_path.assert_called_with( + database_name=None, + schema_name=None, + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("v1", case_sensitive=True), + ) + + # Test with empty args and GPU count + enriched = self.m_mv._enrich_inference_engine_args( + service_ops.InferenceEngineArgs( + inference_engine=inference_engine.InferenceEngine.VLLM, + inference_engine_args_override=None, + ), + gpu_requests=2, + ) + assert enriched is not None + self.assertEqual(enriched.inference_engine, inference_engine.InferenceEngine.VLLM) + + # Test with args and string GPU count + original_args = ["--max_tokens=100", "--temperature=0.7"] + enriched = self.m_mv._enrich_inference_engine_args( + service_ops.InferenceEngineArgs( + inference_engine=inference_engine.InferenceEngine.VLLM, + inference_engine_args_override=original_args, + ), + gpu_requests="4", + ) + assert enriched is not None + self.assertEqual( + enriched, + service_ops.InferenceEngineArgs( + inference_engine=inference_engine.InferenceEngine.VLLM, + inference_engine_args_override=[ + "--max_tokens=100", + "--temperature=0.7", + "--model=model/TEMP.test.MODEL/versions/v1/", + "--tensor-parallel-size=4", + ], + ), + ) + + # Test overwriting existing model key with new model key by appending to the list + enriched = self.m_mv._enrich_inference_engine_args( + service_ops.InferenceEngineArgs( + inference_engine=inference_engine.InferenceEngine.VLLM, + inference_engine_args_override=[ + "--max_tokens=100", + "--temperature=0.7", + "--model=old/path", + ], + ), + ) + self.assertEqual( + enriched, + service_ops.InferenceEngineArgs( + inference_engine=inference_engine.InferenceEngine.VLLM, + inference_engine_args_override=[ + "--max_tokens=100", + "--temperature=0.7", + "--model=old/path", + "--model=model/TEMP.test.MODEL/versions/v1/", + ], + ), + ) + + # Test with invalid GPU string + with self.assertRaises(ValueError): + self.m_mv._enrich_inference_engine_args( + service_ops.InferenceEngineArgs( + inference_engine=inference_engine.InferenceEngine.VLLM, + inference_engine_args_override=None, + ), + gpu_requests="invalid", + ) + + # Test with zero GPU (should not set tensor-parallel-size) + with self.assertRaises(ValueError): + self.m_mv._enrich_inference_engine_args( + service_ops.InferenceEngineArgs( + inference_engine=inference_engine.InferenceEngine.VLLM, + inference_engine_args_override=None, + ), + gpu_requests=0, + ) + + # Test with negative GPU (should not set tensor-parallel-size) + with self.assertRaises(ValueError): + self.m_mv._enrich_inference_engine_args( + service_ops.InferenceEngineArgs( + inference_engine=inference_engine.InferenceEngine.VLLM, + inference_engine_args_override=None, + ), + gpu_requests=-1, + ) + + def test_check_huggingface_text_generation_model(self) -> None: + """Test _check_huggingface_text_generation_model method.""" + # Test successful case - HuggingFace text-generation model + with mock.patch.object( + self.m_mv._model_ops, + "_fetch_model_spec", + return_value={ + "model_type": "huggingface_pipeline", + "models": { + "model1": { + "model_type": "huggingface_pipeline", + "options": {"task": "text-generation"}, + } + }, + }, + ) as mock_fetch: + # Should not raise any exception + self.m_mv._check_huggingface_text_generation_model() + mock_fetch.assert_called_once_with( + database_name=None, + schema_name=None, + model_name=self.m_mv._model_name, + version_name=self.m_mv._version_name, + statement_params=None, + ) + + # Test failure case - not a HuggingFace model + with mock.patch.object( + self.m_mv._model_ops, + "_fetch_model_spec", + return_value={ + "model_type": "sklearn", + "models": {"model1": {"model_type": "sklearn"}}, + }, + ): + with self.assertRaises(ValueError) as cm: + self.m_mv._check_huggingface_text_generation_model() + self.assertIn( + "Inference engine is only supported for HuggingFace text-generation models", str(cm.exception) + ) + self.assertIn("Found model_type: sklearn", str(cm.exception)) + + # Test failure case - HuggingFace model but wrong task + with mock.patch.object( + self.m_mv._model_ops, + "_fetch_model_spec", + return_value={ + "model_type": "huggingface_pipeline", + "models": { + "model1": { + "model_type": "huggingface_pipeline", + "options": {"task": "image-classification"}, + } + }, + }, + ): + with self.assertRaises(ValueError) as cm: + self.m_mv._check_huggingface_text_generation_model() + self.assertIn("Inference engine is only supported for task 'text-generation'", str(cm.exception)) + self.assertIn("Found task(s): image-classification", str(cm.exception)) + + # Test failure case - HuggingFace model with no task + with mock.patch.object( + self.m_mv._model_ops, + "_fetch_model_spec", + return_value={ + "model_type": "huggingface_pipeline", + "models": { + "model1": { + "model_type": "huggingface_pipeline", + "options": {}, + } + }, + }, + ): + with self.assertRaises(ValueError) as cm: + self.m_mv._check_huggingface_text_generation_model() + self.assertIn("Inference engine is only supported for task 'text-generation'", str(cm.exception)) + self.assertIn("No task found in model spec.", str(cm.exception)) + if __name__ == "__main__": absltest.main() diff --git a/snowflake/ml/model/_client/ops/service_ops.py b/snowflake/ml/model/_client/ops/service_ops.py index 9fa6f4fc..8bdac82f 100644 --- a/snowflake/ml/model/_client/ops/service_ops.py +++ b/snowflake/ml/model/_client/ops/service_ops.py @@ -12,7 +12,11 @@ from snowflake import snowpark from snowflake.ml._internal import file_utils, platform_capabilities as pc from snowflake.ml._internal.utils import identifier, service_logger, sql_identifier -from snowflake.ml.model import model_signature, type_hints +from snowflake.ml.model import ( + inference_engine as inference_engine_module, + model_signature, + type_hints, +) from snowflake.ml.model._client.service import model_deployment_spec from snowflake.ml.model._client.sql import service as service_sql, stage as stage_sql from snowflake.ml.model._signatures import snowpark_handler @@ -131,6 +135,12 @@ class HFModelArgs: warehouse: Optional[str] = None +@dataclasses.dataclass +class InferenceEngineArgs: + inference_engine: inference_engine_module.InferenceEngine + inference_engine_args_override: Optional[list[str]] = None + + class ServiceOperator: """Service operator for container services logic.""" @@ -180,7 +190,7 @@ def create_service( service_name: sql_identifier.SqlIdentifier, image_build_compute_pool_name: sql_identifier.SqlIdentifier, service_compute_pool_name: sql_identifier.SqlIdentifier, - image_repo: str, + image_repo_name: Optional[str], ingress_enabled: bool, max_instances: int, cpu_requests: Optional[str], @@ -195,6 +205,8 @@ def create_service( statement_params: Optional[dict[str, Any]] = None, # hf model hf_model_args: Optional[HFModelArgs] = None, + # inference engine model + inference_engine_args: Optional[InferenceEngineArgs] = None, ) -> Union[str, async_job.AsyncJob]: # Generate operation ID for this deployment @@ -205,15 +217,14 @@ def create_service( schema_name = schema_name or self._schema_name # Fall back to the model's database and schema if not provided then to the registry's database and schema - service_database_name = service_database_name or database_name or self._database_name - service_schema_name = service_schema_name or schema_name or self._schema_name + service_database_name = service_database_name or database_name + service_schema_name = service_schema_name or schema_name - # Parse image repo - image_repo_database_name, image_repo_schema_name, image_repo_name = sql_identifier.parse_fully_qualified_name( - image_repo - ) - image_repo_database_name = image_repo_database_name or database_name or self._database_name - image_repo_schema_name = image_repo_schema_name or schema_name or self._schema_name + image_repo_fqn = ServiceOperator._get_image_repo_fqn(image_repo_name, database_name, schema_name) + + # There may be more conditions to enable image build in the future + # For now, we only enable image build if inference engine is not specified + is_enable_image_build = inference_engine_args is None # Step 1: Preparing deployment artifacts progress_status.update("preparing deployment artifacts...") @@ -230,14 +241,15 @@ def create_service( model_name=model_name, version_name=version_name, ) - self._model_deployment_spec.add_image_build_spec( - image_build_compute_pool_name=image_build_compute_pool_name, - image_repo_database_name=image_repo_database_name, - image_repo_schema_name=image_repo_schema_name, - image_repo_name=image_repo_name, - force_rebuild=force_rebuild, - external_access_integrations=build_external_access_integrations, - ) + + if is_enable_image_build: + self._model_deployment_spec.add_image_build_spec( + image_build_compute_pool_name=image_build_compute_pool_name, + fully_qualified_image_repo_name=image_repo_fqn, + force_rebuild=force_rebuild, + external_access_integrations=build_external_access_integrations, + ) + self._model_deployment_spec.add_service_spec( service_database_name=service_database_name, service_schema_name=service_schema_name, @@ -266,6 +278,13 @@ def create_service( warehouse=hf_model_args.warehouse, **(hf_model_args.hf_model_kwargs if hf_model_args.hf_model_kwargs else {}), ) + + if inference_engine_args: + self._model_deployment_spec.add_inference_engine_spec( + inference_engine=inference_engine_args.inference_engine, + inference_engine_args=inference_engine_args.inference_engine_args_override, + ) + spec_yaml_str_or_path = self._model_deployment_spec.save() # Step 2: Uploading deployment artifacts @@ -412,6 +431,29 @@ def create_service( return async_job + @staticmethod + def _get_image_repo_fqn( + image_repo_name: Optional[str], + database_name: sql_identifier.SqlIdentifier, + schema_name: sql_identifier.SqlIdentifier, + ) -> Optional[str]: + """Get the fully qualified name of the image repository.""" + if image_repo_name is None or image_repo_name.strip() == "": + return None + # Parse image repo + ( + image_repo_database_name, + image_repo_schema_name, + image_repo_name, + ) = sql_identifier.parse_fully_qualified_name(image_repo_name) + image_repo_database_name = image_repo_database_name or database_name + image_repo_schema_name = image_repo_schema_name or schema_name + return identifier.get_schema_level_object_identifier( + db=image_repo_database_name.identifier(), + schema=image_repo_schema_name.identifier(), + object_name=image_repo_name.identifier(), + ) + def _start_service_log_streaming( self, async_job: snowpark.AsyncJob, @@ -838,7 +880,7 @@ def invoke_job_method( job_name: sql_identifier.SqlIdentifier, compute_pool_name: sql_identifier.SqlIdentifier, warehouse_name: sql_identifier.SqlIdentifier, - image_repo: str, + image_repo_name: Optional[str], output_table_database_name: Optional[sql_identifier.SqlIdentifier], output_table_schema_name: Optional[sql_identifier.SqlIdentifier], output_table_name: sql_identifier.SqlIdentifier, @@ -859,12 +901,7 @@ def invoke_job_method( job_database_name = job_database_name or database_name or self._database_name job_schema_name = job_schema_name or schema_name or self._schema_name - # Parse image repo - image_repo_database_name, image_repo_schema_name, image_repo_name = sql_identifier.parse_fully_qualified_name( - image_repo - ) - image_repo_database_name = image_repo_database_name or database_name or self._database_name - image_repo_schema_name = image_repo_schema_name or schema_name or self._schema_name + image_repo_fqn = self._get_image_repo_fqn(image_repo_name, database_name, schema_name) input_table_database_name = job_database_name input_table_schema_name = job_schema_name @@ -948,9 +985,7 @@ def invoke_job_method( self._model_deployment_spec.add_image_build_spec( image_build_compute_pool_name=compute_pool_name, - image_repo_database_name=image_repo_database_name, - image_repo_schema_name=image_repo_schema_name, - image_repo_name=image_repo_name, + fully_qualified_image_repo_name=image_repo_fqn, force_rebuild=force_rebuild, external_access_integrations=build_external_access_integrations, ) diff --git a/snowflake/ml/model/_client/ops/service_ops_test.py b/snowflake/ml/model/_client/ops/service_ops_test.py index 509dc09e..3e11a521 100644 --- a/snowflake/ml/model/_client/ops/service_ops_test.py +++ b/snowflake/ml/model/_client/ops/service_ops_test.py @@ -10,7 +10,7 @@ from snowflake import snowpark from snowflake.ml._internal import file_utils, platform_capabilities from snowflake.ml._internal.utils import identifier, sql_identifier -from snowflake.ml.model import model_signature +from snowflake.ml.model import inference_engine, model_signature from snowflake.ml.model._client.ops import service_ops from snowflake.ml.model._client.sql import service as service_sql from snowflake.ml.model._signatures import snowpark_handler @@ -159,6 +159,10 @@ def test_create_service(self, huggingface_args: dict[str, Any]) -> None: "_wait_for_service_status", return_value=None, ), + mock.patch.object( + self.m_ops._model_deployment_spec, + "add_inference_engine_spec", + ) as mock_add_inference_engine_spec, ): self.m_ops.create_service( database_name=sql_identifier.SqlIdentifier("DB"), @@ -170,7 +174,7 @@ def test_create_service(self, huggingface_args: dict[str, Any]) -> None: service_name=sql_identifier.SqlIdentifier("MYSERVICE"), image_build_compute_pool_name=sql_identifier.SqlIdentifier("IMAGE_BUILD_COMPUTE_POOL"), service_compute_pool_name=sql_identifier.SqlIdentifier("SERVICE_COMPUTE_POOL"), - image_repo="IMAGE_REPO_DB.IMAGE_REPO_SCHEMA.IMAGE_REPO", + image_repo_name="IMAGE_REPO_DB.IMAGE_REPO_SCHEMA.IMAGE_REPO", ingress_enabled=True, max_instances=1, cpu_requests="1", @@ -184,6 +188,7 @@ def test_create_service(self, huggingface_args: dict[str, Any]) -> None: statement_params=self.m_statement_params, hf_model_args=service_ops.HFModelArgs(**huggingface_args) if huggingface_args else None, progress_status=create_mock_progress_status(), + inference_engine_args=None, ) mock_create_stage.assert_called_once_with( database_name=sql_identifier.SqlIdentifier("DB"), @@ -212,9 +217,7 @@ def test_create_service(self, huggingface_args: dict[str, Any]) -> None: ) mock_add_image_build_spec.assert_called_once_with( image_build_compute_pool_name=sql_identifier.SqlIdentifier("IMAGE_BUILD_COMPUTE_POOL"), - image_repo_database_name=sql_identifier.SqlIdentifier("IMAGE_REPO_DB"), - image_repo_schema_name=sql_identifier.SqlIdentifier("IMAGE_REPO_SCHEMA"), - image_repo_name=sql_identifier.SqlIdentifier("IMAGE_REPO"), + fully_qualified_image_repo_name="IMAGE_REPO_DB.IMAGE_REPO_SCHEMA.IMAGE_REPO", force_rebuild=True, external_access_integrations=[sql_identifier.SqlIdentifier("EXTERNAL_ACCESS_INTEGRATION")], ) @@ -248,6 +251,9 @@ def test_create_service(self, huggingface_args: dict[str, Any]) -> None: statement_params=self.m_statement_params, ) + # by default, no inference engine spec is added + mock_add_inference_engine_spec.assert_not_called() + @parameterized.parameters( # type: ignore[misc] {"huggingface_args": {}}, { @@ -331,7 +337,7 @@ def test_create_service_model_db_and_schema(self, huggingface_args: dict[str, An service_name=sql_identifier.SqlIdentifier("MYSERVICE"), image_build_compute_pool_name=sql_identifier.SqlIdentifier("IMAGE_BUILD_COMPUTE_POOL"), service_compute_pool_name=sql_identifier.SqlIdentifier("SERVICE_COMPUTE_POOL"), - image_repo="IMAGE_REPO", + image_repo_name="IMAGE_REPO", ingress_enabled=True, max_instances=1, cpu_requests="1", @@ -345,6 +351,7 @@ def test_create_service_model_db_and_schema(self, huggingface_args: dict[str, An statement_params=self.m_statement_params, hf_model_args=service_ops.HFModelArgs(**huggingface_args) if huggingface_args else None, progress_status=create_mock_progress_status(), + inference_engine_args=None, ) mock_create_stage.assert_called_once_with( database_name=sql_identifier.SqlIdentifier("DB"), @@ -374,10 +381,8 @@ def test_create_service_model_db_and_schema(self, huggingface_args: dict[str, An ) mock_add_image_build_spec.assert_called_once_with( image_build_compute_pool_name=sql_identifier.SqlIdentifier("IMAGE_BUILD_COMPUTE_POOL"), - image_repo_name=sql_identifier.SqlIdentifier("IMAGE_REPO"), + fully_qualified_image_repo_name="DB.SCHEMA.IMAGE_REPO", force_rebuild=True, - image_repo_database_name=sql_identifier.SqlIdentifier("DB"), - image_repo_schema_name=sql_identifier.SqlIdentifier("SCHEMA"), external_access_integrations=[sql_identifier.SqlIdentifier("EXTERNAL_ACCESS_INTEGRATION")], ) if huggingface_args: @@ -491,7 +496,7 @@ def test_create_service_default_db_and_schema(self, huggingface_args: dict[str, service_name=sql_identifier.SqlIdentifier("MYSERVICE"), image_build_compute_pool_name=sql_identifier.SqlIdentifier("IMAGE_BUILD_COMPUTE_POOL"), service_compute_pool_name=sql_identifier.SqlIdentifier("SERVICE_COMPUTE_POOL"), - image_repo="IMAGE_REPO", + image_repo_name="IMAGE_REPO", ingress_enabled=True, max_instances=1, cpu_requests="1", @@ -505,6 +510,7 @@ def test_create_service_default_db_and_schema(self, huggingface_args: dict[str, statement_params=self.m_statement_params, hf_model_args=service_ops.HFModelArgs(**huggingface_args) if huggingface_args else None, progress_status=create_mock_progress_status(), + inference_engine_args=None, ) mock_create_stage.assert_called_once_with( database_name=sql_identifier.SqlIdentifier("TEMP"), @@ -533,9 +539,7 @@ def test_create_service_default_db_and_schema(self, huggingface_args: dict[str, ) mock_add_image_build_spec.assert_called_once_with( image_build_compute_pool_name=sql_identifier.SqlIdentifier("IMAGE_BUILD_COMPUTE_POOL"), - image_repo_database_name=sql_identifier.SqlIdentifier("TEMP"), - image_repo_schema_name=sql_identifier.SqlIdentifier("test", case_sensitive=True), - image_repo_name=sql_identifier.SqlIdentifier("IMAGE_REPO"), + fully_qualified_image_repo_name='TEMP."test".IMAGE_REPO', force_rebuild=True, external_access_integrations=[sql_identifier.SqlIdentifier("EXTERNAL_ACCESS_INTEGRATION")], ) @@ -638,7 +642,7 @@ def test_create_service_async_job(self, huggingface_args: dict[str, Any]) -> Non service_name=sql_identifier.SqlIdentifier("MYSERVICE"), image_build_compute_pool_name=sql_identifier.SqlIdentifier("IMAGE_BUILD_COMPUTE_POOL"), service_compute_pool_name=sql_identifier.SqlIdentifier("SERVICE_COMPUTE_POOL"), - image_repo="IMAGE_REPO_DB.IMAGE_REPO_SCHEMA.IMAGE_REPO", + image_repo_name="IMAGE_REPO_DB.IMAGE_REPO_SCHEMA.IMAGE_REPO", ingress_enabled=True, max_instances=1, cpu_requests="1", @@ -652,6 +656,7 @@ def test_create_service_async_job(self, huggingface_args: dict[str, Any]) -> Non statement_params=self.m_statement_params, hf_model_args=service_ops.HFModelArgs(**huggingface_args) if huggingface_args else None, progress_status=create_mock_progress_status(), + inference_engine_args=None, ) self.assertIsInstance(res, snowpark.AsyncJob) @@ -788,7 +793,7 @@ def test_invoke_job_method( job_name=sql_identifier.SqlIdentifier("JOB"), compute_pool_name=sql_identifier.SqlIdentifier("COMPUTE_POOL"), warehouse_name=sql_identifier.SqlIdentifier("WAREHOUSE"), - image_repo=image_repo_fqn, + image_repo_name=image_repo_fqn, output_table_database_name=output_table_database_name[0], output_table_schema_name=output_table_schema_name[0], output_table_name=sql_identifier.SqlIdentifier("OUTPUT_TABLE"), @@ -833,11 +838,12 @@ def test_invoke_job_method( output_table_schema_name=output_table_schema_name[1], output_table_name=sql_identifier.SqlIdentifier("OUTPUT_TABLE"), ) + image_repo_fqn = identifier.get_schema_level_object_identifier( + image_repo_database_name[1], image_repo_schema_name[1], "IMAGE_REPO" + ) mock_add_image_build_spec.assert_called_once_with( - image_repo_database_name=image_repo_database_name[1], - image_repo_schema_name=image_repo_schema_name[1], image_build_compute_pool_name=sql_identifier.SqlIdentifier("COMPUTE_POOL"), - image_repo_name=sql_identifier.SqlIdentifier("IMAGE_REPO"), + fully_qualified_image_repo_name=image_repo_fqn, force_rebuild=True, external_access_integrations=[sql_identifier.SqlIdentifier("EAI")], ) @@ -925,7 +931,7 @@ def test_create_service_uses_operation_id_for_logging(self) -> None: service_name=sql_identifier.SqlIdentifier("MYSERVICE"), image_build_compute_pool_name=sql_identifier.SqlIdentifier("IMAGE_BUILD_COMPUTE_POOL"), service_compute_pool_name=sql_identifier.SqlIdentifier("SERVICE_COMPUTE_POOL"), - image_repo="IMAGE_REPO_DB.IMAGE_REPO_SCHEMA.IMAGE_REPO", + image_repo_name="IMAGE_REPO_DB.IMAGE_REPO_SCHEMA.IMAGE_REPO", ingress_enabled=True, max_instances=1, cpu_requests="1", @@ -939,6 +945,7 @@ def test_create_service_uses_operation_id_for_logging(self) -> None: statement_params=self.m_statement_params, hf_model_args=None, progress_status=create_mock_progress_status(), + inference_engine_args=None, ) def test_get_model_build_service_name(self) -> None: @@ -963,6 +970,335 @@ def test_get_model_build_service_name(self) -> None: expected, ) + def test_create_service_custom_inference_engine(self) -> None: + """Test create_service with custom inference engine parameters.""" + self._add_snowflake_version_check_mock_operations(self.m_session) + m_statuses = [ + service_sql.ServiceStatusInfo( + service_status=service_sql.ServiceStatus.PENDING, + instance_id=0, + instance_status=service_sql.InstanceStatus.PENDING, + container_status=service_sql.ContainerStatus.PENDING, + message=None, + ) + ] + + # Define test inference engine kwargs + test_inference_engine_args = [ + "--model=model/DB.SCHEMA.MODEL/versions/VERSION/", + "--tensor-parallel-size=2", + "--max_tokens=1000", + "--temperature=0.8", + ] + + with ( + mock.patch.object( + self.m_ops._stage_client, + "create_tmp_stage", + ) as mock_create_stage, + mock.patch.object( + snowpark_utils, "random_name_for_temp_object", return_value="SNOWPARK_TEMP_STAGE_ABCDEF0123" + ), + mock.patch.object( + self.m_ops._model_deployment_spec, + "save", + ) as mock_save, + mock.patch.object( + self.m_ops._model_deployment_spec, + "add_model_spec", + ) as mock_add_model_spec, + mock.patch.object( + self.m_ops._model_deployment_spec, + "add_service_spec", + ) as mock_add_service_spec, + mock.patch.object( + self.m_ops._model_deployment_spec, + "add_image_build_spec", + ) as mock_add_image_build_spec, + mock.patch.object( + self.m_ops._model_deployment_spec, + "add_inference_engine_spec", + ) as mock_add_inference_engine_spec, + mock.patch.object( + file_utils, "upload_directory_to_stage", return_value=None + ) as mock_upload_directory_to_stage, + mock.patch.object( + self.m_ops._service_client, + "deploy_model", + return_value=(str(uuid.uuid4()), self._create_mock_async_job()), + ) as mock_deploy_model, + mock.patch.object( + self.m_ops._service_client, + "get_service_container_statuses", + return_value=m_statuses, + ) as mock_get_service_container_statuses, + mock.patch.object( + self.m_ops._service_client, + "get_service_logs", + return_value="", + ), + mock.patch.object( + self.m_ops, + "_wait_for_service_status", + return_value=None, + ), + ): + # Call create_service with inference engine parameters + self.m_ops.create_service( + database_name=sql_identifier.SqlIdentifier("DB"), + schema_name=sql_identifier.SqlIdentifier("SCHEMA"), + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("VERSION"), + service_database_name=sql_identifier.SqlIdentifier("SERVICE_DB"), + service_schema_name=sql_identifier.SqlIdentifier("SERVICE_SCHEMA"), + service_name=sql_identifier.SqlIdentifier("MYSERVICE"), + image_build_compute_pool_name=sql_identifier.SqlIdentifier("IMAGE_BUILD_COMPUTE_POOL"), + service_compute_pool_name=sql_identifier.SqlIdentifier("SERVICE_COMPUTE_POOL"), + image_repo_name="IMAGE_REPO_DB.IMAGE_REPO_SCHEMA.IMAGE_REPO", + ingress_enabled=True, + max_instances=1, + cpu_requests="1", + memory_requests="6GiB", + gpu_requests="2", # This should match tensor-parallel-size + num_workers=1, + max_batch_rows=1024, + force_rebuild=True, + build_external_access_integrations=[sql_identifier.SqlIdentifier("EXTERNAL_ACCESS_INTEGRATION")], + block=True, + statement_params=self.m_statement_params, + inference_engine_args=service_ops.InferenceEngineArgs( + inference_engine=inference_engine.InferenceEngine.VLLM, + inference_engine_args_override=test_inference_engine_args, + ), + progress_status=create_mock_progress_status(), + ) + + # Verify all the standard method calls + mock_create_stage.assert_called_once_with( + database_name=sql_identifier.SqlIdentifier("DB"), + schema_name=sql_identifier.SqlIdentifier("SCHEMA"), + stage_name=sql_identifier.SqlIdentifier("SNOWPARK_TEMP_STAGE_ABCDEF0123"), + statement_params=self.m_statement_params, + ) + mock_add_model_spec.assert_called_once_with( + database_name=sql_identifier.SqlIdentifier("DB"), + schema_name=sql_identifier.SqlIdentifier("SCHEMA"), + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("VERSION"), + ) + mock_add_service_spec.assert_called_once_with( + service_database_name=sql_identifier.SqlIdentifier("SERVICE_DB"), + service_schema_name=sql_identifier.SqlIdentifier("SERVICE_SCHEMA"), + service_name=sql_identifier.SqlIdentifier("MYSERVICE"), + inference_compute_pool_name=sql_identifier.SqlIdentifier("SERVICE_COMPUTE_POOL"), + ingress_enabled=True, + max_instances=1, + cpu="1", + memory="6GiB", + gpu="2", + num_workers=1, + max_batch_rows=1024, + ) + + # This is the key assertion - verify add_inference_engine_spec was called + mock_add_inference_engine_spec.assert_called_once_with( + inference_engine=inference_engine.InferenceEngine.VLLM, inference_engine_args=test_inference_engine_args + ) + + mock_add_image_build_spec.assert_not_called() + mock_save.assert_called_once() + + mock_upload_directory_to_stage.assert_called_once_with( + self.c_session, + local_path=self.m_ops._model_deployment_spec.workspace_path, + stage_path=pathlib.PurePosixPath( + self.m_ops._stage_client.fully_qualified_object_name( + sql_identifier.SqlIdentifier("DB"), + sql_identifier.SqlIdentifier("SCHEMA"), + sql_identifier.SqlIdentifier("SNOWPARK_TEMP_STAGE_ABCDEF0123"), + ) + ), + statement_params=self.m_statement_params, + ) + mock_deploy_model.assert_called_once_with( + stage_path="DB.SCHEMA.SNOWPARK_TEMP_STAGE_ABCDEF0123", + model_deployment_spec_file_rel_path=self.m_ops._model_deployment_spec.DEPLOY_SPEC_FILE_REL_PATH, + model_deployment_spec_yaml_str=None, + statement_params=self.m_statement_params, + ) + mock_get_service_container_statuses.assert_called_once_with( + database_name=sql_identifier.SqlIdentifier("SERVICE_DB"), + schema_name=sql_identifier.SqlIdentifier("SERVICE_SCHEMA"), + service_name=sql_identifier.SqlIdentifier("MYSERVICE"), + include_message=False, + statement_params=self.m_statement_params, + ) + + def test_create_service_with_inference_engine_and_no_image_build(self) -> None: + """Test create_service with custom inference engine parameters and no image build.""" + self._add_snowflake_version_check_mock_operations(self.m_session) + m_statuses = [ + service_sql.ServiceStatusInfo( + service_status=service_sql.ServiceStatus.PENDING, + instance_id=0, + instance_status=service_sql.InstanceStatus.PENDING, + container_status=service_sql.ContainerStatus.PENDING, + message=None, + ) + ] + + # Define test inference engine kwargs + test_inference_engine_args = [ + "--model=model/DB.SCHEMA.MODEL/versions/VERSION/", + "--tensor-parallel-size=2", + "--max_tokens=1000", + "--temperature=0.8", + ] + + with ( + mock.patch.object( + self.m_ops._stage_client, + "create_tmp_stage", + ) as mock_create_stage, + mock.patch.object( + snowpark_utils, "random_name_for_temp_object", return_value="SNOWPARK_TEMP_STAGE_ABCDEF0123" + ), + mock.patch.object( + self.m_ops._model_deployment_spec, + "save", + ) as mock_save, + mock.patch.object( + self.m_ops._model_deployment_spec, + "add_model_spec", + ) as mock_add_model_spec, + mock.patch.object( + self.m_ops._model_deployment_spec, + "add_service_spec", + ) as mock_add_service_spec, + mock.patch.object( + self.m_ops._model_deployment_spec, + "add_image_build_spec", + ) as mock_add_image_build_spec, + mock.patch.object( + self.m_ops._model_deployment_spec, + "add_inference_engine_spec", + ) as mock_add_inference_engine_spec, + mock.patch.object( + file_utils, "upload_directory_to_stage", return_value=None + ) as mock_upload_directory_to_stage, + mock.patch.object( + self.m_ops._service_client, + "deploy_model", + return_value=(str(uuid.uuid4()), mock.MagicMock(spec=snowpark.AsyncJob)), + ) as mock_deploy_model, + mock.patch.object( + self.m_ops._service_client, + "get_service_container_statuses", + return_value=m_statuses, + ) as mock_get_service_container_statuses, + mock.patch.object( + self.m_ops._service_client, + "get_service_logs", + return_value="", # Return empty logs to prevent SQL calls + ), + mock.patch.object( + self.m_ops, + "_wait_for_service_status", + return_value=None, + ), + ): + # Call create_service with inference engine parameters + self.m_ops.create_service( + database_name=sql_identifier.SqlIdentifier("DB"), + schema_name=sql_identifier.SqlIdentifier("SCHEMA"), + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("VERSION"), + service_database_name=sql_identifier.SqlIdentifier("SERVICE_DB"), + service_schema_name=sql_identifier.SqlIdentifier("SERVICE_SCHEMA"), + service_name=sql_identifier.SqlIdentifier("MYSERVICE"), + image_build_compute_pool_name=sql_identifier.SqlIdentifier("IMAGE_BUILD_COMPUTE_POOL"), + service_compute_pool_name=sql_identifier.SqlIdentifier("SERVICE_COMPUTE_POOL"), + image_repo_name="IMAGE_REPO_DB.IMAGE_REPO_SCHEMA.IMAGE_REPO", + ingress_enabled=True, + max_instances=1, + cpu_requests="1", + memory_requests="6GiB", + gpu_requests="2", # This should match tensor-parallel-size + num_workers=1, + max_batch_rows=1024, + force_rebuild=True, + build_external_access_integrations=[sql_identifier.SqlIdentifier("EXTERNAL_ACCESS_INTEGRATION")], + block=True, + statement_params=self.m_statement_params, + inference_engine_args=service_ops.InferenceEngineArgs( + inference_engine=inference_engine.InferenceEngine.VLLM, + inference_engine_args_override=test_inference_engine_args, + ), + progress_status=create_mock_progress_status(), + ) + + # Verify all the standard method calls + mock_create_stage.assert_called_once_with( + database_name=sql_identifier.SqlIdentifier("DB"), + schema_name=sql_identifier.SqlIdentifier("SCHEMA"), + stage_name=sql_identifier.SqlIdentifier("SNOWPARK_TEMP_STAGE_ABCDEF0123"), + statement_params=self.m_statement_params, + ) + mock_add_model_spec.assert_called_once_with( + database_name=sql_identifier.SqlIdentifier("DB"), + schema_name=sql_identifier.SqlIdentifier("SCHEMA"), + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("VERSION"), + ) + mock_add_service_spec.assert_called_once_with( + service_database_name=sql_identifier.SqlIdentifier("SERVICE_DB"), + service_schema_name=sql_identifier.SqlIdentifier("SERVICE_SCHEMA"), + service_name=sql_identifier.SqlIdentifier("MYSERVICE"), + inference_compute_pool_name=sql_identifier.SqlIdentifier("SERVICE_COMPUTE_POOL"), + ingress_enabled=True, + max_instances=1, + cpu="1", + memory="6GiB", + gpu="2", + num_workers=1, + max_batch_rows=1024, + ) + + # key assertions -- image build is not called and inference engine model is called + # when inference engine is specified + mock_add_image_build_spec.assert_not_called() + mock_add_inference_engine_spec.assert_called_once_with( + inference_engine=inference_engine.InferenceEngine.VLLM, + inference_engine_args=test_inference_engine_args, + ) + + mock_save.assert_called_once() + mock_upload_directory_to_stage.assert_called_once_with( + self.c_session, + local_path=self.m_ops._model_deployment_spec.workspace_path, + stage_path=pathlib.PurePosixPath( + self.m_ops._stage_client.fully_qualified_object_name( + sql_identifier.SqlIdentifier("DB"), + sql_identifier.SqlIdentifier("SCHEMA"), + sql_identifier.SqlIdentifier("SNOWPARK_TEMP_STAGE_ABCDEF0123"), + ) + ), + statement_params=self.m_statement_params, + ) + mock_deploy_model.assert_called_once_with( + stage_path="DB.SCHEMA.SNOWPARK_TEMP_STAGE_ABCDEF0123", + model_deployment_spec_file_rel_path=self.m_ops._model_deployment_spec.DEPLOY_SPEC_FILE_REL_PATH, + model_deployment_spec_yaml_str=None, + statement_params=self.m_statement_params, + ) + mock_get_service_container_statuses.assert_called_once_with( + database_name=sql_identifier.SqlIdentifier("SERVICE_DB"), + schema_name=sql_identifier.SqlIdentifier("SERVICE_SCHEMA"), + service_name=sql_identifier.SqlIdentifier("MYSERVICE"), + include_message=False, + statement_params=self.m_statement_params, + ) + if __name__ == "__main__": absltest.main() diff --git a/snowflake/ml/model/_client/service/BUILD.bazel b/snowflake/ml/model/_client/service/BUILD.bazel index 2b35cb9e..d226565c 100644 --- a/snowflake/ml/model/_client/service/BUILD.bazel +++ b/snowflake/ml/model/_client/service/BUILD.bazel @@ -17,6 +17,9 @@ py_library( srcs = ["model_deployment_spec.py"], deps = [ ":model_deployment_spec_schema", + "//snowflake/ml/_internal/utils:identifier", + "//snowflake/ml/_internal/utils:sql_identifier", + "//snowflake/ml/model:inference_engine", ], ) @@ -27,5 +30,6 @@ py_test( deps = [ ":model_deployment_spec", "//snowflake/ml/_internal/utils:sql_identifier", + "//snowflake/ml/model:inference_engine", ], ) diff --git a/snowflake/ml/model/_client/service/model_deployment_spec.py b/snowflake/ml/model/_client/service/model_deployment_spec.py index 8d7ef994..0eb26eda 100644 --- a/snowflake/ml/model/_client/service/model_deployment_spec.py +++ b/snowflake/ml/model/_client/service/model_deployment_spec.py @@ -1,10 +1,12 @@ import json import pathlib +import warnings from typing import Any, Optional, Union import yaml from snowflake.ml._internal.utils import identifier, sql_identifier +from snowflake.ml.model import inference_engine as inference_engine_module from snowflake.ml.model._client.service import model_deployment_spec_schema @@ -24,6 +26,8 @@ def __init__(self, workspace_path: Optional[pathlib.Path] = None) -> None: self._service: Optional[model_deployment_spec_schema.Service] = None self._job: Optional[model_deployment_spec_schema.Job] = None self._model_loggings: Optional[list[model_deployment_spec_schema.ModelLogging]] = None + # this is referring to custom inference engine spec (vllm, sglang, etc) + self._inference_engine_spec: Optional[model_deployment_spec_schema.InferenceEngineSpec] = None self._inference_spec: dict[str, Any] = {} # Common inference spec for service/job self.database: Optional[sql_identifier.SqlIdentifier] = None @@ -71,10 +75,8 @@ def add_model_spec( def add_image_build_spec( self, - image_build_compute_pool_name: sql_identifier.SqlIdentifier, - image_repo_name: sql_identifier.SqlIdentifier, - image_repo_database_name: Optional[sql_identifier.SqlIdentifier] = None, - image_repo_schema_name: Optional[sql_identifier.SqlIdentifier] = None, + image_build_compute_pool_name: Optional[sql_identifier.SqlIdentifier] = None, + fully_qualified_image_repo_name: Optional[str] = None, force_rebuild: bool = False, external_access_integrations: Optional[list[sql_identifier.SqlIdentifier]] = None, ) -> "ModelDeploymentSpec": @@ -82,33 +84,29 @@ def add_image_build_spec( Args: image_build_compute_pool_name: Compute pool for image building. - image_repo_name: Name of the image repository. - image_repo_database_name: Database name for the image repository. - image_repo_schema_name: Schema name for the image repository. + fully_qualified_image_repo_name: Fully qualified name of the image repository. force_rebuild: Whether to force rebuilding the image. external_access_integrations: List of external access integrations. Returns: Self for chaining. """ - saved_image_repo_database = image_repo_database_name or self.database - saved_image_repo_schema = image_repo_schema_name or self.schema - assert saved_image_repo_database is not None - assert saved_image_repo_schema is not None - fq_image_repo_name = identifier.get_schema_level_object_identifier( - db=saved_image_repo_database.identifier(), - schema=saved_image_repo_schema.identifier(), - object_name=image_repo_name.identifier(), - ) - - self._image_build = model_deployment_spec_schema.ImageBuild( - compute_pool=image_build_compute_pool_name.identifier(), - image_repo=fq_image_repo_name, - force_rebuild=force_rebuild, - external_access_integrations=( - [eai.identifier() for eai in external_access_integrations] if external_access_integrations else None - ), - ) + if ( + image_build_compute_pool_name is not None + or fully_qualified_image_repo_name is not None + or force_rebuild is True + or external_access_integrations is not None + ): + self._image_build = model_deployment_spec_schema.ImageBuild( + compute_pool=( + None if image_build_compute_pool_name is None else image_build_compute_pool_name.identifier() + ), + image_repo=fully_qualified_image_repo_name, + force_rebuild=force_rebuild, + external_access_integrations=( + [eai.identifier() for eai in external_access_integrations] if external_access_integrations else None + ), + ) return self def _add_inference_spec( @@ -363,6 +361,86 @@ def add_hf_logger_spec( self._model_loggings.append(model_logging) return self + def add_inference_engine_spec( + self, + inference_engine: inference_engine_module.InferenceEngine, + inference_engine_args: Optional[list[str]] = None, + ) -> "ModelDeploymentSpec": + """Add inference engine specification. This must be called after self.add_service_spec(). + + Args: + inference_engine: Inference engine. + inference_engine_args: Inference engine arguments. + + Returns: + Self for chaining. + + Raises: + ValueError: If inference engine specification is called before add_service_spec(). + ValueError: If the argument does not have a '--' prefix. + """ + # TODO: needs to eventually support job deployment spec + if self._service is None: + raise ValueError("Inference engine specification must be called after add_service_spec().") + + if inference_engine_args is None: + inference_engine_args = [] + + # Validate inference engine + if inference_engine == inference_engine_module.InferenceEngine.VLLM: + # Block list for VLLM args that should not be user-configurable + # make this a set for faster lookup + block_list = { + "--host", + "--port", + "--allowed-headers", + "--api-key", + "--lora-modules", + "--prompt-adapter", + "--ssl-keyfile", + "--ssl-certfile", + "--ssl-ca-certs", + "--enable-ssl-refresh", + "--ssl-cert-reqs", + "--root-path", + "--middleware", + "--disable-frontend-multiprocessing", + "--enable-request-id-headers", + "--enable-auto-tool-choice", + "--tool-call-parser", + "--tool-parser-plugin", + "--log-config-file", + } + + filtered_args = [] + for arg in inference_engine_args: + # Check if the argument has a '--' prefix + if not arg.startswith("--"): + raise ValueError( + f"""The argument {arg} is not allowed for configuration in Snowflake ML's + {inference_engine.value} inference engine. Maybe you forgot to add '--' prefix?""", + ) + + # Filter out blocked args and warn user + if arg.split("=")[0] in block_list: + warnings.warn( + f"""The argument {arg} is not allowed for configuration in Snowflake ML's + {inference_engine.value} inference engine. It will be ignored.""", + UserWarning, + stacklevel=2, + ) + else: + filtered_args.append(arg) + + inference_engine_args = filtered_args + + self._service.inference_engine_spec = model_deployment_spec_schema.InferenceEngineSpec( + # convert to string to be saved in the deployment spec + inference_engine_name=inference_engine.value, + inference_engine_args=inference_engine_args, + ) + return self + def save(self) -> str: """Constructs the final deployment spec from added components and saves it. @@ -377,8 +455,6 @@ def save(self) -> str: # Validations if not self._models: raise ValueError("Model specification is required. Call add_model_spec().") - if not self._image_build: - raise ValueError("Image build specification is required. Call add_image_build_spec().") if not self._service and not self._job: raise ValueError( "Either service or job specification is required. Call add_service_spec() or add_job_spec()." diff --git a/snowflake/ml/model/_client/service/model_deployment_spec_schema.py b/snowflake/ml/model/_client/service/model_deployment_spec_schema.py index 4aa896a2..f7fdf8f9 100644 --- a/snowflake/ml/model/_client/service/model_deployment_spec_schema.py +++ b/snowflake/ml/model/_client/service/model_deployment_spec_schema.py @@ -10,10 +10,15 @@ class Model(BaseModel): version: str +class InferenceEngineSpec(BaseModel): + inference_engine_name: str + inference_engine_args: Optional[list[str]] = None + + class ImageBuild(BaseModel): - compute_pool: str - image_repo: str - force_rebuild: bool + compute_pool: Optional[str] = None + image_repo: Optional[str] = None + force_rebuild: Optional[bool] = None external_access_integrations: Optional[list[str]] = None @@ -27,6 +32,7 @@ class Service(BaseModel): gpu: Optional[str] = None num_workers: Optional[int] = None max_batch_rows: Optional[int] = None + inference_engine_spec: Optional[InferenceEngineSpec] = None class Job(BaseModel): @@ -68,13 +74,13 @@ class ModelLogging(BaseModel): class ModelServiceDeploymentSpec(BaseModel): models: list[Model] - image_build: ImageBuild + image_build: Optional[ImageBuild] = None service: Service model_loggings: Optional[list[ModelLogging]] = None class ModelJobDeploymentSpec(BaseModel): models: list[Model] - image_build: ImageBuild + image_build: Optional[ImageBuild] = None job: Job model_loggings: Optional[list[ModelLogging]] = None diff --git a/snowflake/ml/model/_client/service/model_deployment_spec_test.py b/snowflake/ml/model/_client/service/model_deployment_spec_test.py index 0351851a..5dd762f5 100644 --- a/snowflake/ml/model/_client/service/model_deployment_spec_test.py +++ b/snowflake/ml/model/_client/service/model_deployment_spec_test.py @@ -5,6 +5,7 @@ from absl.testing import absltest, parameterized from snowflake.ml._internal.utils import sql_identifier +from snowflake.ml.model import inference_engine from snowflake.ml.model._client.service import model_deployment_spec @@ -20,7 +21,7 @@ def test_minimal(self) -> None: ) mds.add_image_build_spec( image_build_compute_pool_name=sql_identifier.SqlIdentifier("image_build_compute_pool"), - image_repo_name=sql_identifier.SqlIdentifier("image_repo"), + fully_qualified_image_repo_name="DB.SCHEMA.IMAGE_REPO", ) mds.add_service_spec( service_name=sql_identifier.SqlIdentifier("service"), @@ -61,7 +62,7 @@ def test_minimal_inline_yaml(self) -> None: ) mds.add_image_build_spec( image_build_compute_pool_name=sql_identifier.SqlIdentifier("image_build_compute_pool"), - image_repo_name=sql_identifier.SqlIdentifier("image_repo"), + fully_qualified_image_repo_name="DB.SCHEMA.IMAGE_REPO", ) mds.add_service_spec( service_name=sql_identifier.SqlIdentifier("service"), @@ -104,7 +105,7 @@ def test_minimal_case_sensitive(self) -> None: image_build_compute_pool_name=sql_identifier.SqlIdentifier( "image_build_compute_pool", case_sensitive=True ), - image_repo_name=sql_identifier.SqlIdentifier("image_repo", case_sensitive=True), + fully_qualified_image_repo_name='"db"."schema"."image_repo"', ) mds.add_service_spec( service_name=sql_identifier.SqlIdentifier("service", case_sensitive=True), @@ -161,9 +162,7 @@ def test_full( ) mds.add_image_build_spec( image_build_compute_pool_name=sql_identifier.SqlIdentifier("image_build_compute_pool"), - image_repo_database_name=sql_identifier.SqlIdentifier("image_repo_db"), - image_repo_schema_name=sql_identifier.SqlIdentifier("image_repo_schema"), - image_repo_name=sql_identifier.SqlIdentifier("image_repo"), + fully_qualified_image_repo_name="IMAGE_REPO_DB.IMAGE_REPO_SCHEMA.IMAGE_REPO", force_rebuild=force_rebuild, external_access_integrations=[sql_identifier.SqlIdentifier("external_access_integration")], ) @@ -221,7 +220,7 @@ def test_no_eai(self) -> None: ) mds.add_image_build_spec( image_build_compute_pool_name=sql_identifier.SqlIdentifier("image_build_compute_pool"), - image_repo_name=sql_identifier.SqlIdentifier("image_repo"), + fully_qualified_image_repo_name="DB.SCHEMA.IMAGE_REPO", external_access_integrations=None, # Explicitly None ) mds.add_service_spec( @@ -266,9 +265,7 @@ def test_job(self) -> None: ) mds.add_image_build_spec( image_build_compute_pool_name=sql_identifier.SqlIdentifier("image_build_compute_pool"), - image_repo_database_name=sql_identifier.SqlIdentifier("image_repo_db"), - image_repo_schema_name=sql_identifier.SqlIdentifier("image_repo_schema"), - image_repo_name=sql_identifier.SqlIdentifier("image_repo"), + fully_qualified_image_repo_name="IMAGE_REPO_DB.IMAGE_REPO_SCHEMA.IMAGE_REPO", force_rebuild=True, external_access_integrations=[sql_identifier.SqlIdentifier("external_access_integration")], ) @@ -334,7 +331,7 @@ def test_hf_config(self) -> None: ) mds.add_image_build_spec( image_build_compute_pool_name=sql_identifier.SqlIdentifier("image_build_compute_pool"), - image_repo_name=sql_identifier.SqlIdentifier("image_repo"), + fully_qualified_image_repo_name="DB.SCHEMA.IMAGE_REPO", ) mds.add_service_spec( service_name=sql_identifier.SqlIdentifier("service"), @@ -409,7 +406,7 @@ def test_hf_config_without_hf_model_name_raises(self) -> None: ) mds.add_image_build_spec( image_build_compute_pool_name=sql_identifier.SqlIdentifier("image_build_compute_pool"), - image_repo_name=sql_identifier.SqlIdentifier("image_repo"), + fully_qualified_image_repo_name="DB.SCHEMA.IMAGE_REPO", ) mds.add_service_spec( service_name=sql_identifier.SqlIdentifier("service"), @@ -430,9 +427,7 @@ def test_missing_model_spec_raises(self) -> None: mds = model_deployment_spec.ModelDeploymentSpec() mds.add_image_build_spec( image_build_compute_pool_name=sql_identifier.SqlIdentifier("pool"), - image_repo_name=sql_identifier.SqlIdentifier("repo"), - image_repo_database_name=sql_identifier.SqlIdentifier("db"), - image_repo_schema_name=sql_identifier.SqlIdentifier("schema"), + fully_qualified_image_repo_name="db.schema.repo", ) mds.add_service_spec( service_name=sql_identifier.SqlIdentifier("service"), @@ -444,19 +439,36 @@ def test_missing_model_spec_raises(self) -> None: mds.save() def test_missing_image_build_spec_raises(self) -> None: - mds = model_deployment_spec.ModelDeploymentSpec() - mds.add_model_spec( - database_name=sql_identifier.SqlIdentifier("db"), - schema_name=sql_identifier.SqlIdentifier("schema"), - model_name=sql_identifier.SqlIdentifier("model"), - version_name=sql_identifier.SqlIdentifier("version"), - ) - mds.add_service_spec( - service_name=sql_identifier.SqlIdentifier("service"), - inference_compute_pool_name=sql_identifier.SqlIdentifier("pool"), - ) - with self.assertRaisesRegex(ValueError, "Image build specification is required"): - mds.save() + with tempfile.TemporaryDirectory() as tmpdir: + mds = model_deployment_spec.ModelDeploymentSpec(workspace_path=pathlib.Path(tmpdir)) + mds.add_model_spec( + database_name=sql_identifier.SqlIdentifier("db"), + schema_name=sql_identifier.SqlIdentifier("schema"), + model_name=sql_identifier.SqlIdentifier("model"), + version_name=sql_identifier.SqlIdentifier("version"), + ) + mds.add_service_spec( + service_name=sql_identifier.SqlIdentifier("service"), + inference_compute_pool_name=sql_identifier.SqlIdentifier("pool"), + ) + file_path_str = mds.save() + + assert mds.workspace_path + file_path = pathlib.Path(file_path_str) + with file_path.open("r", encoding="utf-8") as f: + result = yaml.safe_load(f) + self.assertDictEqual( + result, + { + "models": [{"name": "DB.SCHEMA.MODEL", "version": "VERSION"}], + "service": { + "name": "DB.SCHEMA.SERVICE", + "compute_pool": "POOL", + "ingress_enabled": True, + "max_instances": 1, + }, + }, + ) def test_missing_service_or_job_spec_raises(self) -> None: mds = model_deployment_spec.ModelDeploymentSpec() @@ -468,7 +480,7 @@ def test_missing_service_or_job_spec_raises(self) -> None: ) mds.add_image_build_spec( image_build_compute_pool_name=sql_identifier.SqlIdentifier("pool"), - image_repo_name=sql_identifier.SqlIdentifier("repo"), + fully_qualified_image_repo_name=sql_identifier.SqlIdentifier("repo"), ) with self.assertRaisesRegex(ValueError, "Either service or job specification is required"): mds.save() @@ -529,6 +541,200 @@ def test_clear_config(self) -> None: mds.clear() self.assertLen(mds._models, 0) + def test_image_build_spec_minimal_params(self) -> None: + """Test add_image_build_spec with only required parameter and all optional parameters as None/default.""" + with tempfile.TemporaryDirectory() as tmpdir: + mds = model_deployment_spec.ModelDeploymentSpec(workspace_path=pathlib.Path(tmpdir)) + mds.add_model_spec( + database_name=sql_identifier.SqlIdentifier("db"), + schema_name=sql_identifier.SqlIdentifier("schema"), + model_name=sql_identifier.SqlIdentifier("model"), + version_name=sql_identifier.SqlIdentifier("version"), + ) + mds.add_image_build_spec( + image_build_compute_pool_name=None, # Explicitly None + fully_qualified_image_repo_name=None, # Explicitly None + force_rebuild=False, # Default value + external_access_integrations=None, # Explicitly None + ) + mds.add_service_spec( + service_name=sql_identifier.SqlIdentifier("service"), + inference_compute_pool_name=sql_identifier.SqlIdentifier("service_compute_pool"), + ingress_enabled=True, + max_instances=1, + ) + file_path_str = mds.save() + + assert mds.workspace_path + file_path = pathlib.Path(file_path_str) + with file_path.open("r", encoding="utf-8") as f: + result = yaml.safe_load(f) + self.assertDictEqual( + result, + { + "models": [{"name": "DB.SCHEMA.MODEL", "version": "VERSION"}], + "service": { + "name": "DB.SCHEMA.SERVICE", + "compute_pool": "SERVICE_COMPUTE_POOL", + "ingress_enabled": True, + "max_instances": 1, + }, + }, + ) + + def test_experimental_options_minimal(self) -> None: + with tempfile.TemporaryDirectory() as tmpdir: + mds = model_deployment_spec.ModelDeploymentSpec(workspace_path=pathlib.Path(tmpdir)) + mds.add_model_spec( + database_name=sql_identifier.SqlIdentifier("db"), + schema_name=sql_identifier.SqlIdentifier("schema"), + model_name=sql_identifier.SqlIdentifier("model"), + version_name=sql_identifier.SqlIdentifier("version"), + ) + mds.add_image_build_spec( + image_build_compute_pool_name=sql_identifier.SqlIdentifier("pool"), + fully_qualified_image_repo_name="DB.SCHEMA.REPO", + ) + mds.add_service_spec( + service_name=sql_identifier.SqlIdentifier("service"), + inference_compute_pool_name=sql_identifier.SqlIdentifier("pool"), + ) + mds.add_inference_engine_spec( + inference_engine=inference_engine.InferenceEngine.VLLM, + inference_engine_args=[ + "--some_vllm_arg=0.8", + "--model=model", + "--tensor_parallel_size=2", + ], + ) + file_path_str = mds.save() + + assert mds.workspace_path + file_path = pathlib.Path(file_path_str) + with file_path.open("r", encoding="utf-8") as f: + result = yaml.safe_load(f) + self.assertDictEqual( + result, + { + "models": [{"name": "DB.SCHEMA.MODEL", "version": "VERSION"}], + "image_build": { + "compute_pool": "POOL", + "force_rebuild": False, + "image_repo": "DB.SCHEMA.REPO", + }, + "service": { + "name": "DB.SCHEMA.SERVICE", + "compute_pool": "POOL", + "ingress_enabled": True, + "max_instances": 1, + "inference_engine_spec": { + "inference_engine_name": "vllm", + "inference_engine_args": [ + "--some_vllm_arg=0.8", + "--model=model", + "--tensor_parallel_size=2", + ], + }, + }, + }, + ) + mds.clear() + + def test_experimental_options_minimal_with_blocklist_args(self) -> None: + with tempfile.TemporaryDirectory() as tmpdir: + mds = model_deployment_spec.ModelDeploymentSpec(workspace_path=pathlib.Path(tmpdir)) + mds.add_model_spec( + database_name=sql_identifier.SqlIdentifier("db"), + schema_name=sql_identifier.SqlIdentifier("schema"), + model_name=sql_identifier.SqlIdentifier("model"), + version_name=sql_identifier.SqlIdentifier("version"), + ) + mds.add_image_build_spec( + image_build_compute_pool_name=sql_identifier.SqlIdentifier("pool"), + fully_qualified_image_repo_name="DB.SCHEMA.REPO", + ) + mds.add_service_spec( + service_name=sql_identifier.SqlIdentifier("service"), + inference_compute_pool_name=sql_identifier.SqlIdentifier("pool"), + ) + mds.add_inference_engine_spec( + inference_engine=inference_engine.InferenceEngine.VLLM, + inference_engine_args=[ + "--some_vllm_arg=0.8", + "--host=host", + "--port=8000", + ], + ) + file_path_str = mds.save() + + assert mds.workspace_path + file_path = pathlib.Path(file_path_str) + with file_path.open("r", encoding="utf-8") as f: + result = yaml.safe_load(f) + self.assertDictEqual( + result, + { + "models": [{"name": "DB.SCHEMA.MODEL", "version": "VERSION"}], + "image_build": { + "compute_pool": "POOL", + "force_rebuild": False, + "image_repo": "DB.SCHEMA.REPO", + }, + "service": { + "name": "DB.SCHEMA.SERVICE", + "compute_pool": "POOL", + "ingress_enabled": True, + "max_instances": 1, + "inference_engine_spec": { + "inference_engine_name": "vllm", + "inference_engine_args": ["--some_vllm_arg=0.8"], + }, + }, + }, + ) + mds.clear() + + def test_skip_image_build_with_inference_engine(self) -> None: + with tempfile.TemporaryDirectory() as tmpdir: + mds = model_deployment_spec.ModelDeploymentSpec(workspace_path=pathlib.Path(tmpdir)) + mds.add_model_spec( + database_name=sql_identifier.SqlIdentifier("db"), + schema_name=sql_identifier.SqlIdentifier("schema"), + model_name=sql_identifier.SqlIdentifier("model"), + version_name=sql_identifier.SqlIdentifier("version"), + ) + mds.add_service_spec( + service_name=sql_identifier.SqlIdentifier("service"), + inference_compute_pool_name=sql_identifier.SqlIdentifier("pool"), + ) + mds.add_inference_engine_spec( + inference_engine=inference_engine.InferenceEngine.VLLM, + inference_engine_args=["--some_vllm_arg=0.8", "--host=host", "--port=8000"], + ) + file_path_str = mds.save() + + assert mds.workspace_path + file_path = pathlib.Path(file_path_str) + with file_path.open("r", encoding="utf-8") as f: + result = yaml.safe_load(f) + self.assertDictEqual( + result, + { + "models": [{"name": "DB.SCHEMA.MODEL", "version": "VERSION"}], + "service": { + "name": "DB.SCHEMA.SERVICE", + "compute_pool": "POOL", + "ingress_enabled": True, + "max_instances": 1, + "inference_engine_spec": { + "inference_engine_name": "vllm", + "inference_engine_args": ["--some_vllm_arg=0.8"], + }, + }, + }, + ) + mds.clear() + if __name__ == "__main__": absltest.main() diff --git a/snowflake/ml/model/_model_composer/model_composer.py b/snowflake/ml/model/_model_composer/model_composer.py index 65b0e7a4..3ece7e25 100644 --- a/snowflake/ml/model/_model_composer/model_composer.py +++ b/snowflake/ml/model/_model_composer/model_composer.py @@ -1,17 +1,12 @@ import pathlib import tempfile import uuid -import warnings from types import ModuleType from typing import TYPE_CHECKING, Any, Optional, Union from urllib import parse -from absl import logging -from packaging import requirements - from snowflake import snowpark -from snowflake.ml import version as snowml_version -from snowflake.ml._internal import env as snowml_env, env_utils, file_utils +from snowflake.ml._internal import file_utils from snowflake.ml._internal.lineage import lineage_utils from snowflake.ml.data import data_source from snowflake.ml.model import model_signature, type_hints as model_types @@ -19,7 +14,6 @@ from snowflake.ml.model._packager import model_packager from snowflake.ml.model._packager.model_meta import model_meta from snowflake.snowpark import Session -from snowflake.snowpark._internal import utils as snowpark_utils if TYPE_CHECKING: from snowflake.ml.experiment._experiment_info import ExperimentInfo @@ -142,73 +136,10 @@ def save( experiment_info: Optional["ExperimentInfo"] = None, options: Optional[model_types.ModelSaveOption] = None, ) -> model_meta.ModelMetadata: - # set enable_explainability=False if the model is not runnable in WH or the target platforms include SPCS - conda_dep_dict = env_utils.validate_conda_dependency_string_list( - conda_dependencies if conda_dependencies else [] - ) - - enable_explainability = None - - if options: - enable_explainability = options.get("enable_explainability", None) - - # skip everything if user said False explicitly - if enable_explainability is None or enable_explainability is True: - is_warehouse_runnable = ( - not conda_dep_dict - or all( - chan == env_utils.DEFAULT_CHANNEL_NAME or chan == env_utils.SNOWFLAKE_CONDA_CHANNEL_URL - for chan in conda_dep_dict - ) - ) and (not pip_requirements) - - only_spcs = ( - target_platforms - and len(target_platforms) == 1 - and model_types.TargetPlatform.SNOWPARK_CONTAINER_SERVICES in target_platforms - ) - if only_spcs or (not is_warehouse_runnable): - # if only SPCS and user asked for explainability we fail - if enable_explainability is True: - raise ValueError( - "`enable_explainability` cannot be set to True when the model is not runnable in WH " - "or the target platforms include SPCS." - ) - elif not options: # explicitly set flag to false in these cases if not specified - options = model_types.BaseModelSaveOption() - options["enable_explainability"] = False - elif ( - target_platforms - and len(target_platforms) > 1 - and model_types.TargetPlatform.SNOWPARK_CONTAINER_SERVICES in target_platforms - ): # if both then only available for WH - if enable_explainability is True: - warnings.warn( - ("Explain function will only be available for model deployed to warehouse."), - category=UserWarning, - stacklevel=2, - ) if not options: options = model_types.BaseModelSaveOption() - if not snowpark_utils.is_in_stored_procedure() and target_platforms != [ # type: ignore[no-untyped-call] - model_types.TargetPlatform.SNOWPARK_CONTAINER_SERVICES # no information schema check for SPCS-only models - ]: - snowml_matched_versions = env_utils.get_matched_package_versions_in_information_schema( - self.session, - reqs=[requirements.Requirement(f"{env_utils.SNOWPARK_ML_PKG_NAME}=={snowml_version.VERSION}")], - python_version=python_version or snowml_env.PYTHON_VERSION, - statement_params=self._statement_params, - ).get(env_utils.SNOWPARK_ML_PKG_NAME, []) - - if len(snowml_matched_versions) < 1 and options.get("embed_local_ml_library", False) is False: - logging.info( - f"Local snowflake-ml-python library has version {snowml_version.VERSION}," - " which is not available in the Snowflake server, embedding local ML library automatically." - ) - options["embed_local_ml_library"] = True - model_metadata: model_meta.ModelMetadata = self.packager.save( name=name, model=model, diff --git a/snowflake/ml/model/_model_composer/model_composer_test.py b/snowflake/ml/model/_model_composer/model_composer_test.py index 054bf6f0..990c9a0d 100644 --- a/snowflake/ml/model/_model_composer/model_composer_test.py +++ b/snowflake/ml/model/_model_composer/model_composer_test.py @@ -131,134 +131,6 @@ def test_save_interface(self, params: dict[str, Any]) -> None: mock_save.assert_called_once() mock_manifest_save.assert_called_once() - @parameterized.parameters( # type: ignore[misc] - {"disable_explainability": True, "target_platforms": [model_types.TargetPlatform.SNOWPARK_CONTAINER_SERVICES]}, - { - "disable_explainability": False, - "target_platforms": [ - model_types.TargetPlatform.SNOWPARK_CONTAINER_SERVICES, - model_types.TargetPlatform.WAREHOUSE, - ], - }, - {"disable_explainability": False, "target_platforms": []}, - { - "disable_explainability": True, - "conda_dependencies": ["python-package1==1.0.0", "conda-forge::python-package2==1.1.0"], - }, - { - "disable_explainability": False, - "conda_dependencies": [ - "python-package1==1.0.0", - "https://repo.anaconda.com/pkgs/snowflake::python-package2", - ], - }, - {"disable_explainability": True, "pip_requirements": ["python-package==1.0.0"]}, - {"disable_explainability": False, "pip_requirements": None}, - ) - def test_save_enable_explainability(self, disable_explainability: bool, **kwargs: Any) -> None: - m_session = mock_session.MockSession(conn=None, test_case=self) - c_session = cast(Session, m_session) - - stage_path = '@"db"."schema"."stage"' - - mock_pk = mock.MagicMock() - mock_pk.meta = mock.MagicMock() - mock_pk.meta.signatures = mock.MagicMock() - m = model_composer.ModelComposer(session=c_session, stage_path=stage_path) - - with open(os.path.join(m._packager_workspace_path, "model.yaml"), "w", encoding="utf-8") as f: - f.write("") - m.packager = mock_pk - - with mock.patch.object(m.packager, "save", return_value=mock_pk.meta) as mock_save, mock.patch.object( - m.manifest, "save" - ), mock.patch.object(file_utils, "upload_directory_to_stage", return_value=None), mock.patch.object( - env_utils, - "get_matched_package_versions_in_information_schema", - return_value={env_utils.SNOWPARK_ML_PKG_NAME: []}, - ): - m.save( - name="model1", - model=linear_model.LinearRegression(), - **kwargs, - ) - mock_save.assert_called_once() - _, called_kwargs = mock_save.call_args - self.assertIn("options", called_kwargs) - if ( - disable_explainability - ): # set to false if the model is not runnable in WH or the target platforms is only SPCS - self.assertEqual(called_kwargs["options"], called_kwargs["options"] | {"enable_explainability": False}) - else: - # else options should be empty since user did not pass anything - # and explainability does not need to be explicitly disabled - self.assertNotIn("enable_explainability", called_kwargs["options"]) - - if disable_explainability: - with self.assertRaisesRegex( - ValueError, - "`enable_explainability` cannot be set to True when the model is not runnable in WH " - "or the target platforms include SPCS.", - ), mock.patch.object(m.packager, "save", return_value=mock_pk.meta), mock.patch.object( - m.manifest, "save" - ), mock.patch.object( - file_utils, "upload_directory_to_stage", return_value=None - ), mock.patch.object( - file_utils, "copytree", return_value="/model" - ), mock.patch.object( - env_utils, - "get_matched_package_versions_in_information_schema", - return_value={env_utils.SNOWPARK_ML_PKG_NAME: []}, - ): - m.save( - name="model1", - model=linear_model.LinearRegression(), - options={"enable_explainability": True}, - **kwargs, - ) - - @parameterized.parameters( # type: ignore[misc] - {"target_platforms": [model_types.TargetPlatform.SNOWPARK_CONTAINER_SERVICES]}, - { - "target_platforms": [ - model_types.TargetPlatform.SNOWPARK_CONTAINER_SERVICES, - model_types.TargetPlatform.WAREHOUSE, - ] - }, - ) - def test_save_information_schema_check(self, target_platforms: list[model_types.TargetPlatform]) -> None: - m_session = mock_session.MockSession(conn=None, test_case=self) - c_session = cast(Session, m_session) - - stage_path = '@"db"."schema"."stage"' - - mock_pk = mock.MagicMock() - mock_pk.meta = mock.MagicMock() - mock_pk.meta.signatures = mock.MagicMock() - m = model_composer.ModelComposer(session=c_session, stage_path=stage_path) - - with open(os.path.join(m._packager_workspace_path, "model.yaml"), "w", encoding="utf-8") as f: - f.write("") - m.packager = mock_pk - - with mock.patch.object(m.packager, "save", return_value=mock_pk.meta) as mock_save, mock.patch.object( - m.manifest, "save" - ), mock.patch.object(file_utils, "upload_directory_to_stage", return_value=None), mock.patch.object( - env_utils, - "get_matched_package_versions_in_information_schema", - return_value={env_utils.SNOWPARK_ML_PKG_NAME: []}, - ) as mock_get_matched_package_versions_in_information_schema: - m.save( - name="model1", - model=linear_model.LinearRegression(), - target_platforms=target_platforms, - ) - mock_save.assert_called_once() - if target_platforms == [model_types.TargetPlatform.SNOWPARK_CONTAINER_SERVICES]: - mock_get_matched_package_versions_in_information_schema.assert_not_called() - else: - mock_get_matched_package_versions_in_information_schema.assert_called_once() - def test_load(self) -> None: m_options = model_types.PyTorchLoadOptions(use_gpu=False) with mock.patch.object(model_packager.ModelPackager, "load") as mock_load: diff --git a/snowflake/ml/model/_model_composer/model_manifest/BUILD.bazel b/snowflake/ml/model/_model_composer/model_manifest/BUILD.bazel index a68d8f84..555f93db 100644 --- a/snowflake/ml/model/_model_composer/model_manifest/BUILD.bazel +++ b/snowflake/ml/model/_model_composer/model_manifest/BUILD.bazel @@ -52,7 +52,6 @@ py_test( deps = [ ":model_manifest", "//snowflake/ml/_internal:env_utils", - "//snowflake/ml/_internal/exceptions", "//snowflake/ml/model:model_signature", "//snowflake/ml/model:type_hints", "//snowflake/ml/model/_packager/model_meta", diff --git a/snowflake/ml/model/_model_composer/model_manifest/model_manifest.py b/snowflake/ml/model/_model_composer/model_manifest/model_manifest.py index 4809a1c9..92439810 100644 --- a/snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +++ b/snowflake/ml/model/_model_composer/model_manifest/model_manifest.py @@ -1,13 +1,11 @@ import collections import logging import pathlib -import warnings from typing import TYPE_CHECKING, Optional, cast import yaml from snowflake.ml._internal import env_utils -from snowflake.ml._internal.exceptions import error_codes, exceptions from snowflake.ml.data import data_source from snowflake.ml.model import type_hints from snowflake.ml.model._model_composer.model_manifest import model_manifest_schema @@ -55,47 +53,8 @@ def save( experiment_info: Optional["ExperimentInfo"] = None, target_platforms: Optional[list[type_hints.TargetPlatform]] = None, ) -> None: - if options is None: - options = {} - - has_pip_requirements = len(model_meta.env.pip_requirements) > 0 - only_spcs = ( - target_platforms - and len(target_platforms) == 1 - and target_platforms[0] == type_hints.TargetPlatform.SNOWPARK_CONTAINER_SERVICES - ) - - if "relax_version" not in options: - if has_pip_requirements or only_spcs: - logger.info( - "Setting `relax_version=False` as this model will run in Snowpark Container Services " - "or in Warehouse with a specified artifact_repository_map where exact version " - " specifications will be honored." - ) - relax_version = False - else: - warnings.warn( - ( - "`relax_version` is not set and therefore defaulted to True. Dependency version constraints" - " relaxed from ==x.y.z to >=x.y, <(x+1). To use specific dependency versions for compatibility," - " reproducibility, etc., set `options={'relax_version': False}` when logging the model." - ), - category=UserWarning, - stacklevel=2, - ) - relax_version = True - options["relax_version"] = relax_version - else: - relax_version = options.get("relax_version", True) - if relax_version and (has_pip_requirements or only_spcs): - raise exceptions.SnowflakeMLException( - error_code=error_codes.INVALID_ARGUMENT, - original_exception=ValueError( - "Setting `relax_version=True` is only allowed for models to be run in Warehouse with " - "Snowflake Conda Channel dependencies. It cannot be used with pip requirements or when " - "targeting only Snowpark Container Services." - ), - ) + assert options is not None, "ModelParameterReconciler should have set options with relax_version" + relax_version = options["relax_version"] runtime_to_use = model_runtime.ModelRuntime( name=self._DEFAULT_RUNTIME_NAME, diff --git a/snowflake/ml/model/_model_composer/model_manifest/model_manifest_test.py b/snowflake/ml/model/_model_composer/model_manifest/model_manifest_test.py index 0bcc1bf7..f987bde1 100644 --- a/snowflake/ml/model/_model_composer/model_manifest/model_manifest_test.py +++ b/snowflake/ml/model/_model_composer/model_manifest/model_manifest_test.py @@ -8,7 +8,6 @@ from packaging import requirements from snowflake.ml._internal import env_utils -from snowflake.ml._internal.exceptions import error_codes, exceptions from snowflake.ml.model import model_signature, type_hints from snowflake.ml.model._model_composer.model_manifest import model_manifest from snowflake.ml.model._packager.model_meta import ( @@ -118,8 +117,7 @@ def test_model_manifest_1(self) -> None: ) as meta: meta.models["model1"] = _DUMMY_BLOB - with self.assertWarnsRegex(UserWarning, "`relax_version` is not set and therefore defaulted to True."): - mm.save(meta, pathlib.PurePosixPath("model")) + mm.save(meta, pathlib.PurePosixPath("model"), options={"relax_version": True}) with open(pathlib.Path(workspace, "runtimes", "python_runtime", "env", "conda.yml"), encoding="utf-8") as f: self.assertDictEqual( yaml.safe_load(f), @@ -150,36 +148,6 @@ def test_model_manifest_1(self) -> None: f.read(), ) - def test_model_manifest_1_relax_version(self) -> None: - with tempfile.TemporaryDirectory() as workspace, tempfile.TemporaryDirectory() as tmpdir: - mm = model_manifest.ModelManifest(pathlib.Path(workspace)) - with model_meta.create_model_metadata( - model_dir_path=tmpdir, - name="model1", - model_type="custom", - signatures=_DUMMY_SIG, - python_version="3.8", - embed_local_ml_library=False, - ) as meta: - meta.models["model1"] = _DUMMY_BLOB - - mm.save( - meta, - pathlib.PurePosixPath("model"), - options=type_hints.BaseModelSaveOption( - relax_version=False, - ), - ) - with open(pathlib.Path(workspace, "runtimes", "python_runtime", "env", "conda.yml"), encoding="utf-8") as f: - self.assertDictEqual( - yaml.safe_load(f), - { - "channels": [env_utils.SNOWFLAKE_CONDA_CHANNEL_URL, "nodefaults"], - "dependencies": ["python==3.8.*"] + _PACKAGING_REQUIREMENTS_TARGET_WITH_SNOWML, - "name": "snow-env", - }, - ) - def test_model_manifest_2(self) -> None: with tempfile.TemporaryDirectory() as workspace, tempfile.TemporaryDirectory() as tmpdir: mm = model_manifest.ModelManifest(pathlib.Path(workspace)) @@ -231,37 +199,6 @@ def test_model_manifest_2(self) -> None: f.read(), ) - def test_model_manifest_2_relax_version(self) -> None: - with tempfile.TemporaryDirectory() as workspace, tempfile.TemporaryDirectory() as tmpdir: - mm = model_manifest.ModelManifest(pathlib.Path(workspace)) - with model_meta.create_model_metadata( - model_dir_path=tmpdir, - name="model1", - model_type="custom", - signatures={"__call__": _DUMMY_SIG["predict"]}, - python_version="3.8", - embed_local_ml_library=True, - ) as meta: - meta.models["model1"] = _DUMMY_BLOB - - mm.save( - meta, - pathlib.PurePosixPath("model"), - options=type_hints.BaseModelSaveOption( - method_options={"__call__": type_hints.ModelMethodSaveOptions(max_batch_size=10)}, - relax_version=True, - ), - ) - with open(pathlib.Path(workspace, "runtimes", "python_runtime", "env", "conda.yml"), encoding="utf-8") as f: - self.assertDictEqual( - yaml.safe_load(f), - { - "channels": [env_utils.SNOWFLAKE_CONDA_CHANNEL_URL, "nodefaults"], - "dependencies": ["python==3.8.*"] + _PACKAGING_REQUIREMENTS_TARGET_WITHOUT_SNOWML_RELAXED, - "name": "snow-env", - }, - ) - def test_model_manifest_mix(self) -> None: with tempfile.TemporaryDirectory() as workspace, tempfile.TemporaryDirectory() as tmpdir: mm = model_manifest.ModelManifest(pathlib.Path(workspace)) @@ -282,7 +219,8 @@ def test_model_manifest_mix(self) -> None: method_options={ "predict": type_hints.ModelMethodSaveOptions(case_sensitive=True), "__call__": type_hints.ModelMethodSaveOptions(max_batch_size=10), - } + }, + relax_version=True, ), ) with open(os.path.join(workspace, "MANIFEST.yml"), encoding="utf-8") as f: @@ -341,6 +279,7 @@ def test_model_manifest_bad(self) -> None: mm.save( meta, pathlib.PurePosixPath("model"), + options={"relax_version": True}, ) def test_model_manifest_table_function(self) -> None: @@ -360,7 +299,8 @@ def test_model_manifest_table_function(self) -> None: meta, pathlib.PurePosixPath("model"), options=type_hints.BaseModelSaveOption( - method_options={"predict": type_hints.ModelMethodSaveOptions(function_type="TABLE_FUNCTION")} + method_options={"predict": type_hints.ModelMethodSaveOptions(function_type="TABLE_FUNCTION")}, + relax_version=True, ), ) with open(os.path.join(workspace, "MANIFEST.yml"), encoding="utf-8") as f: @@ -402,7 +342,8 @@ def test_model_manifest_partitioned_function(self) -> None: meta, pathlib.PurePosixPath("model"), options=type_hints.BaseModelSaveOption( - method_options={"predict": type_hints.ModelMethodSaveOptions(function_type="TABLE_FUNCTION")} + method_options={"predict": type_hints.ModelMethodSaveOptions(function_type="TABLE_FUNCTION")}, + relax_version=True, ), ) with open(os.path.join(workspace, "MANIFEST.yml"), encoding="utf-8") as f: @@ -440,9 +381,8 @@ def test_model_manifest_pip(self) -> None: ) as meta: meta.models["model1"] = _DUMMY_BLOB - options: type_hints.ModelSaveOption = dict() + options: type_hints.ModelSaveOption = {"relax_version": False} mm.save(meta, pathlib.PurePosixPath("model"), options=options) - self.assertFalse(options.get("relax_version", True)) with open(os.path.join(workspace, "MANIFEST.yml"), encoding="utf-8") as f: self.assertEqual( @@ -465,33 +405,6 @@ def test_model_manifest_pip(self) -> None: f.read(), ) - def test_model_manifest_pip_relax_version(self) -> None: - with tempfile.TemporaryDirectory() as workspace, tempfile.TemporaryDirectory() as tmpdir: - mm = model_manifest.ModelManifest(pathlib.Path(workspace)) - with model_meta.create_model_metadata( - model_dir_path=tmpdir, - name="model1", - model_type="custom", - signatures={"predict": _DUMMY_SIG["predict"]}, - pip_requirements=["xgboost==1.2.3"], - python_version="3.8", - embed_local_ml_library=True, - ) as meta: - meta.models["model1"] = _DUMMY_BLOB - - with self.assertRaises(exceptions.SnowflakeMLException) as cm: - mm.save( - meta, - pathlib.PurePosixPath("model"), - options=type_hints.BaseModelSaveOption(relax_version=True), - ) - self.assertEqual(cm.exception.error_code, error_codes.INVALID_ARGUMENT) - self.assertIn( - "Setting `relax_version=True` is only allowed for models to be run in Warehouse with " - "Snowflake Conda Channel dependencies", - str(cm.exception), - ) - def test_model_manifest_target_platforms(self) -> None: with tempfile.TemporaryDirectory() as workspace, tempfile.TemporaryDirectory() as tmpdir: mm = model_manifest.ModelManifest(pathlib.Path(workspace)) @@ -506,7 +419,12 @@ def test_model_manifest_target_platforms(self) -> None: ) as meta: meta.models["model1"] = _DUMMY_BLOB - mm.save(meta, pathlib.PurePosixPath("model"), target_platforms=[type_hints.TargetPlatform.WAREHOUSE]) + mm.save( + meta, + pathlib.PurePosixPath("model"), + options={"relax_version": True}, + target_platforms=[type_hints.TargetPlatform.WAREHOUSE], + ) with open(os.path.join(workspace, "MANIFEST.yml"), encoding="utf-8") as f: self.assertEqual( ( @@ -528,33 +446,6 @@ def test_model_manifest_target_platforms(self) -> None: f.read(), ) - def test_model_manifest_target_platforms_relax_version(self) -> None: - with tempfile.TemporaryDirectory() as workspace, tempfile.TemporaryDirectory() as tmpdir: - mm = model_manifest.ModelManifest(pathlib.Path(workspace)) - with model_meta.create_model_metadata( - model_dir_path=tmpdir, - name="model1", - model_type="custom", - signatures={"predict": _DUMMY_SIG["predict"]}, - python_version="3.8", - embed_local_ml_library=True, - ) as meta: - meta.models["model1"] = _DUMMY_BLOB - - with self.assertRaises(exceptions.SnowflakeMLException) as cm: - mm.save( - meta, - pathlib.PurePosixPath("model"), - options=type_hints.BaseModelSaveOption(relax_version=True), - target_platforms=[type_hints.TargetPlatform.SNOWPARK_CONTAINER_SERVICES], - ) - self.assertEqual(cm.exception.error_code, error_codes.INVALID_ARGUMENT) - self.assertIn( - "Setting `relax_version=True` is only allowed for models to be run in Warehouse with " - "Snowflake Conda Channel dependencies", - str(cm.exception), - ) - def test_model_manifest_artifact_repo_map(self) -> None: with tempfile.TemporaryDirectory() as workspace, tempfile.TemporaryDirectory() as tmpdir: mm = model_manifest.ModelManifest(pathlib.Path(workspace)) @@ -570,7 +461,7 @@ def test_model_manifest_artifact_repo_map(self) -> None: ) as meta: meta.models["model1"] = _DUMMY_BLOB - mm.save(meta, pathlib.PurePosixPath("model")) + mm.save(meta, pathlib.PurePosixPath("model"), options={"relax_version": True}) with open(os.path.join(workspace, "MANIFEST.yml"), encoding="utf-8") as f: self.assertEqual( ( @@ -597,7 +488,7 @@ def test_model_manifest_resource_constraint(self) -> None: ) as meta: meta.models["model1"] = _DUMMY_BLOB - mm.save(meta, pathlib.PurePosixPath("model")) + mm.save(meta, pathlib.PurePosixPath("model"), options={"relax_version": True}) with open(os.path.join(workspace, "MANIFEST.yml"), encoding="utf-8") as f: self.assertEqual( ( @@ -634,6 +525,7 @@ def test_model_manifest_user_files(self) -> None: mm.save( meta, pathlib.PurePosixPath("model"), + options={"relax_version": True}, user_files=user_files, ) with open(os.path.join(workspace, "MANIFEST.yml"), encoding="utf-8") as f: diff --git a/snowflake/ml/model/inference_engine.py b/snowflake/ml/model/inference_engine.py new file mode 100644 index 00000000..6ffdf812 --- /dev/null +++ b/snowflake/ml/model/inference_engine.py @@ -0,0 +1,5 @@ +import enum + + +class InferenceEngine(enum.Enum): + VLLM = "vllm" diff --git a/snowflake/ml/model/models/huggingface_pipeline.py b/snowflake/ml/model/models/huggingface_pipeline.py index 4124b216..6e6aec91 100644 --- a/snowflake/ml/model/models/huggingface_pipeline.py +++ b/snowflake/ml/model/models/huggingface_pipeline.py @@ -258,7 +258,7 @@ def create_service( # model_version_impl.create_service parameters service_name: str, service_compute_pool: str, - image_repo: str, + image_repo: Optional[str] = None, image_build_compute_pool: Optional[str] = None, ingress_enabled: bool = False, max_instances: int = 1, @@ -282,7 +282,8 @@ def create_service( comment: Comment for the model. Defaults to None. service_name: The name of the service to create. service_compute_pool: The compute pool for the service. - image_repo: The name of the image repository. + image_repo: The name of the image repository. This can be None, in that case a default hidden image + repository will be used. image_build_compute_pool: The name of the compute pool used to build the model inference image. It uses the service compute pool if None. ingress_enabled: Whether ingress is enabled. Defaults to False. @@ -356,7 +357,7 @@ def create_service( else sql_identifier.SqlIdentifier(service_compute_pool) ), service_compute_pool_name=sql_identifier.SqlIdentifier(service_compute_pool), - image_repo=image_repo, + image_repo_name=image_repo, ingress_enabled=ingress_enabled, max_instances=max_instances, cpu_requests=cpu_requests, diff --git a/snowflake/ml/registry/_manager/BUILD.bazel b/snowflake/ml/registry/_manager/BUILD.bazel index e99604fd..a44f80ee 100644 --- a/snowflake/ml/registry/_manager/BUILD.bazel +++ b/snowflake/ml/registry/_manager/BUILD.bazel @@ -12,8 +12,14 @@ py_library( "model_parameter_reconciler.py", ], deps = [ + "//snowflake/ml:version", + "//snowflake/ml/_internal:env", + "//snowflake/ml/_internal:env_utils", + "//snowflake/ml/_internal/exceptions", "//snowflake/ml/_internal/utils:sql_identifier", + "//snowflake/ml/model:target_platform", "//snowflake/ml/model:type_hints", + "//snowflake/ml/model/_model_composer/model_manifest:model_manifest_schema", ], ) @@ -24,7 +30,6 @@ py_library( ], deps = [ ":model_parameter_reconciler", - "//snowflake/ml/_internal:env", "//snowflake/ml/_internal:platform_capabilities", "//snowflake/ml/_internal/human_readable_id:hrid_generator", "//snowflake/ml/_internal/utils:sql_identifier", @@ -35,7 +40,6 @@ py_library( "//snowflake/ml/model/_client/ops:model_ops", "//snowflake/ml/model/_client/ops:service_ops", "//snowflake/ml/model/_model_composer:model_composer", - "//snowflake/ml/model/_model_composer/model_manifest:model_manifest_schema", ], ) @@ -49,6 +53,7 @@ py_test( ":model_parameter_reconciler", "//snowflake/ml/_internal/utils:sql_identifier", "//snowflake/ml/model:type_hints", + "//snowflake/ml/test_utils:mock_session", ], ) diff --git a/snowflake/ml/registry/_manager/model_manager.py b/snowflake/ml/registry/_manager/model_manager.py index 9df86dcf..8456ccaf 100644 --- a/snowflake/ml/registry/_manager/model_manager.py +++ b/snowflake/ml/registry/_manager/model_manager.py @@ -4,15 +4,14 @@ import pandas as pd from absl.logging import logging -from snowflake.ml._internal import env, platform_capabilities, telemetry +from snowflake.ml._internal import platform_capabilities, telemetry from snowflake.ml._internal.exceptions import error_codes, exceptions from snowflake.ml._internal.human_readable_id import hrid_generator from snowflake.ml._internal.utils import sql_identifier -from snowflake.ml.model import model_signature, target_platform, task, type_hints +from snowflake.ml.model import model_signature, task, type_hints from snowflake.ml.model._client.model import model_impl, model_version_impl from snowflake.ml.model._client.ops import metadata_ops, model_ops, service_ops from snowflake.ml.model._model_composer import model_composer -from snowflake.ml.model._model_composer.model_manifest import model_manifest_schema from snowflake.ml.model._packager.model_meta import model_meta from snowflake.ml.registry._manager import model_parameter_reconciler from snowflake.snowpark import exceptions as snowpark_exceptions, session @@ -221,37 +220,8 @@ def _log_model( statement_params=statement_params, ) - platforms = None - # User specified target platforms are defaulted to None and will not show up in the generated manifest. - if target_platforms: - # Convert any string target platforms to TargetPlatform objects - platforms = [type_hints.TargetPlatform(platform) for platform in target_platforms] - else: - # Default the target platform to warehouse if not specified and any table function exists - if options and ( - options.get("function_type") == model_manifest_schema.ModelMethodFunctionTypes.TABLE_FUNCTION.value - or ( - any( - opt.get("function_type") == "TABLE_FUNCTION" - for opt in options.get("method_options", {}).values() - ) - ) - ): - logger.info( - "Logging a partitioned model with a table function without specifying `target_platforms`. " - 'Default to `target_platforms=["WAREHOUSE"]`.' - ) - platforms = [target_platform.TargetPlatform.WAREHOUSE] - - # Default the target platform to SPCS if not specified when running in ML runtime - if not platforms and env.IN_ML_RUNTIME: - logger.info( - "Logging the model on Container Runtime for ML without specifying `target_platforms`. " - 'Default to `target_platforms=["SNOWPARK_CONTAINER_SERVICES"]`.' - ) - platforms = [target_platform.TargetPlatform.SNOWPARK_CONTAINER_SERVICES] - reconciler = model_parameter_reconciler.ModelParameterReconciler( + session=self._model_ops._session, database_name=self._database_name, schema_name=self._schema_name, conda_dependencies=conda_dependencies, @@ -259,6 +229,8 @@ def _log_model( target_platforms=target_platforms, artifact_repository_map=artifact_repository_map, options=options, + python_version=python_version, + statement_params=statement_params, ) model_params = reconciler.reconcile() @@ -293,12 +265,12 @@ def _log_model( pip_requirements=pip_requirements, artifact_repository_map=artifact_repository_map, resource_constraint=resource_constraint, - target_platforms=platforms, + target_platforms=model_params.target_platforms, python_version=python_version, user_files=user_files, code_paths=code_paths, ext_modules=ext_modules, - options=options, + options=model_params.options, task=task, experiment_info=experiment_info, ) diff --git a/snowflake/ml/registry/_manager/model_manager_test.py b/snowflake/ml/registry/_manager/model_manager_test.py index bdcbaa2d..0de07069 100644 --- a/snowflake/ml/registry/_manager/model_manager_test.py +++ b/snowflake/ml/registry/_manager/model_manager_test.py @@ -4,7 +4,7 @@ import pandas as pd from absl.testing import absltest, parameterized -from snowflake.ml._internal import platform_capabilities, telemetry +from snowflake.ml._internal import env_utils, platform_capabilities, telemetry from snowflake.ml._internal.utils import sql_identifier from snowflake.ml.model import target_platform, task, type_hints from snowflake.ml.model._client.model import model_impl, model_version_impl @@ -198,6 +198,11 @@ def test_log_model_minimal(self, is_live_commit_enabled: bool = False) -> None: self.m_r._hrid_generator, "generate", return_value=(1, "angry_yeti_1") ) as mock_hrid_generate, mock.patch.object(model_version_impl.ModelVersion, "_get_functions", return_value=[]), + mock.patch.object( + env_utils, + "get_matched_package_versions_in_information_schema", + return_value={env_utils.SNOWPARK_ML_PKG_NAME: []}, + ), platform_capabilities.PlatformCapabilities.mock_features( {platform_capabilities.LIVE_COMMIT_PARAMETER: True} ), @@ -234,7 +239,7 @@ def test_log_model_minimal(self, is_live_commit_enabled: bool = False) -> None: user_files=None, code_paths=None, ext_modules=None, - options=None, + options={"embed_local_ml_library": True, "relax_version": True}, task=task.Task.UNKNOWN, experiment_info=None, ) @@ -279,6 +284,11 @@ def test_log_model_1(self, is_live_commit_enabled: bool = False) -> None: mock.patch.object(model_composer.ModelComposer, "save", return_value=m_model_metadata) as mock_save, mock.patch.object(self.m_r._model_ops, "create_from_stage") as mock_create_from_stage, mock.patch.object(model_version_impl.ModelVersion, "_get_functions", return_value=[]), + mock.patch.object( + env_utils, + "get_matched_package_versions_in_information_schema", + return_value={env_utils.SNOWPARK_ML_PKG_NAME: []}, + ), platform_capabilities.PlatformCapabilities.mock_features( {platform_capabilities.LIVE_COMMIT_PARAMETER: True} ), @@ -317,7 +327,7 @@ def test_log_model_1(self, is_live_commit_enabled: bool = False) -> None: user_files=None, code_paths=None, ext_modules=None, - options=None, + options={"embed_local_ml_library": True, "relax_version": True}, task=task.Task.UNKNOWN, experiment_info=None, ) @@ -340,7 +350,7 @@ def test_log_model_2(self, is_live_commit_enabled: bool = False) -> None: m_model = mock.MagicMock() m_pip_requirements = mock.MagicMock() m_signatures = mock.MagicMock() - m_options = mock.MagicMock() + m_options = type_hints.BaseModelSaveOption(enable_explainability=False) m_stage_path = "@TEMP.TEST.MODEL/V1" m_model_metadata = mock.MagicMock() m_model_metadata.telemetry_metadata = mock.MagicMock(return_value=self.model_md_telemetry) @@ -353,6 +363,11 @@ def test_log_model_2(self, is_live_commit_enabled: bool = False) -> None: mock.patch.object(model_composer.ModelComposer, "save", return_value=m_model_metadata) as mock_save, mock.patch.object(self.m_r._model_ops, "create_from_stage") as mock_create_from_stage, mock.patch.object(model_version_impl.ModelVersion, "_get_functions", return_value=[]), + mock.patch.object( + env_utils, + "get_matched_package_versions_in_information_schema", + return_value={env_utils.SNOWPARK_ML_PKG_NAME: []}, + ), platform_capabilities.PlatformCapabilities.mock_features( {platform_capabilities.LIVE_COMMIT_PARAMETER: True} ), @@ -386,7 +401,7 @@ def test_log_model_2(self, is_live_commit_enabled: bool = False) -> None: user_files=None, code_paths=None, ext_modules=None, - options=m_options, + options={"enable_explainability": False, "embed_local_ml_library": True, "relax_version": False}, task=task.Task.UNKNOWN, experiment_info=None, ) @@ -425,6 +440,11 @@ def test_log_model_3(self, is_live_commit_enabled: bool = False) -> None: mock.patch.object(model_composer.ModelComposer, "save", return_value=m_model_metadata) as mock_save, mock.patch.object(self.m_r._model_ops, "create_from_stage") as mock_create_from_stage, mock.patch.object(model_version_impl.ModelVersion, "_get_functions", return_value=[]), + mock.patch.object( + env_utils, + "get_matched_package_versions_in_information_schema", + return_value={env_utils.SNOWPARK_ML_PKG_NAME: []}, + ), platform_capabilities.PlatformCapabilities.mock_features( {platform_capabilities.LIVE_COMMIT_PARAMETER: True} ), @@ -458,7 +478,7 @@ def test_log_model_3(self, is_live_commit_enabled: bool = False) -> None: user_files=None, code_paths=m_code_paths, ext_modules=m_ext_modules, - options=None, + options={"embed_local_ml_library": True, "relax_version": True}, task=task.Task.UNKNOWN, experiment_info=None, ) @@ -496,6 +516,11 @@ def test_log_model_4(self, is_live_commit_enabled: bool = False) -> None: mock.patch.object(ModelOperator, "set_comment") as mock_set_comment, mock.patch.object(self.m_r._model_ops._metadata_ops, "save") as mock_metadata_save, mock.patch.object(model_version_impl.ModelVersion, "_get_functions", return_value=[]), + mock.patch.object( + env_utils, + "get_matched_package_versions_in_information_schema", + return_value={env_utils.SNOWPARK_ML_PKG_NAME: []}, + ), platform_capabilities.PlatformCapabilities.mock_features( {platform_capabilities.LIVE_COMMIT_PARAMETER: True} ), @@ -528,7 +553,7 @@ def test_log_model_4(self, is_live_commit_enabled: bool = False) -> None: user_files=None, code_paths=None, ext_modules=None, - options=None, + options={"embed_local_ml_library": True, "relax_version": True}, task=task.Task.UNKNOWN, experiment_info=None, ) @@ -640,6 +665,11 @@ def test_log_model_target_platforms( mock.patch.object(model_composer.ModelComposer, "save", return_value=m_model_metadata) as mock_save, mock.patch.object(self.m_r._model_ops, "create_from_stage"), mock.patch.object(model_version_impl.ModelVersion, "_get_functions", return_value=[]), + mock.patch.object( + env_utils, + "get_matched_package_versions_in_information_schema", + return_value={env_utils.SNOWPARK_ML_PKG_NAME: []}, + ), platform_capabilities.PlatformCapabilities.mock_features( {platform_capabilities.LIVE_COMMIT_PARAMETER: True} ), @@ -666,7 +696,10 @@ def test_log_model_target_platforms( user_files=None, code_paths=None, ext_modules=None, - options=None, + options={"enable_explainability": False, "relax_version": False} + if target_platforms == ["SNOWPARK_CONTAINER_SERVICES"] + or target_platforms == [target_platform.TargetPlatform.SNOWPARK_CONTAINER_SERVICES] + else {"embed_local_ml_library": True, "relax_version": True}, task=task.Task.UNKNOWN, experiment_info=None, ) @@ -695,6 +728,11 @@ def test_log_model_target_platform_constant( mock.patch.object(model_composer.ModelComposer, "save", return_value=m_model_metadata) as mock_save, mock.patch.object(self.m_r._model_ops, "create_from_stage"), mock.patch.object(model_version_impl.ModelVersion, "_get_functions", return_value=[]), + mock.patch.object( + env_utils, + "get_matched_package_versions_in_information_schema", + return_value={env_utils.SNOWPARK_ML_PKG_NAME: []}, + ), platform_capabilities.PlatformCapabilities.mock_features( {platform_capabilities.LIVE_COMMIT_PARAMETER: True} ), @@ -721,7 +759,9 @@ def test_log_model_target_platform_constant( user_files=None, code_paths=None, ext_modules=None, - options=None, + options={"enable_explainability": False, "relax_version": False} + if target_platform_constant == [target_platform.TargetPlatform.SNOWPARK_CONTAINER_SERVICES] + else {"embed_local_ml_library": True, "relax_version": True}, task=task.Task.UNKNOWN, experiment_info=None, ) @@ -746,6 +786,11 @@ def test_log_model_fully_qualified(self, is_live_commit_enabled: bool = False) - mock.patch.object(ModelOperator, "set_comment") as mock_set_comment, mock.patch.object(self.m_r._model_ops._metadata_ops, "save") as mock_metadata_save, mock.patch.object(model_version_impl.ModelVersion, "_get_functions", return_value=[]), + mock.patch.object( + env_utils, + "get_matched_package_versions_in_information_schema", + return_value={env_utils.SNOWPARK_ML_PKG_NAME: []}, + ), platform_capabilities.PlatformCapabilities.mock_features( {platform_capabilities.LIVE_COMMIT_PARAMETER: True} ), @@ -778,7 +823,7 @@ def test_log_model_fully_qualified(self, is_live_commit_enabled: bool = False) - user_files=None, code_paths=None, ext_modules=None, - options=None, + options={"embed_local_ml_library": True, "relax_version": True}, task=task.Task.UNKNOWN, experiment_info=None, ) @@ -888,6 +933,11 @@ def validate_existence_side_effect(**kwargs: Any) -> bool: self.m_r._hrid_generator, "generate", side_effect=[(1, "angry_yeti_1"), (2, "angry_yeti_2")] ) as mock_hrid_generate, mock.patch.object(model_version_impl.ModelVersion, "_get_functions", return_value=[]), + mock.patch.object( + env_utils, + "get_matched_package_versions_in_information_schema", + return_value={env_utils.SNOWPARK_ML_PKG_NAME: []}, + ), platform_capabilities.PlatformCapabilities.mock_features( {platform_capabilities.LIVE_COMMIT_PARAMETER: True} ), @@ -918,6 +968,11 @@ def test_log_model_in_ml_runtime(self, is_live_commit_enabled: bool = False) -> mock.patch.object(model_composer.ModelComposer, "save", return_value=m_model_metadata) as mock_save, mock.patch.object(self.m_r._model_ops, "create_from_stage"), mock.patch.object(model_version_impl.ModelVersion, "_get_functions", return_value=[]), + mock.patch.object( + env_utils, + "get_matched_package_versions_in_information_schema", + return_value={env_utils.SNOWPARK_ML_PKG_NAME: []}, + ), platform_capabilities.PlatformCapabilities.mock_features( {platform_capabilities.LIVE_COMMIT_PARAMETER: True} ), @@ -943,7 +998,7 @@ def test_log_model_in_ml_runtime(self, is_live_commit_enabled: bool = False) -> user_files=None, code_paths=None, ext_modules=None, - options=None, + options={"enable_explainability": False, "relax_version": False}, task=task.Task.UNKNOWN, experiment_info=None, ) @@ -975,6 +1030,11 @@ def test_log_model_table_function( mock.patch.object(model_composer.ModelComposer, "save", return_value=m_model_metadata) as mock_save, mock.patch.object(self.m_r._model_ops, "create_from_stage"), mock.patch.object(model_version_impl.ModelVersion, "_get_functions", return_value=[]), + mock.patch.object( + env_utils, + "get_matched_package_versions_in_information_schema", + return_value={env_utils.SNOWPARK_ML_PKG_NAME: []}, + ), platform_capabilities.PlatformCapabilities.mock_features( {platform_capabilities.LIVE_COMMIT_PARAMETER: True} ), @@ -987,6 +1047,9 @@ def test_log_model_table_function( statement_params=self.base_statement_params, progress_status=create_mock_progress_status(), ) + expected_options = options.copy() + expected_options["embed_local_ml_library"] = True + expected_options["relax_version"] = True mock_save.assert_called_once_with( name="MODEL", model=m_model, @@ -1001,7 +1064,7 @@ def test_log_model_table_function( user_files=None, code_paths=None, ext_modules=None, - options=options, + options=expected_options, task=task.Task.UNKNOWN, experiment_info=None, ) @@ -1043,6 +1106,11 @@ def test_artifact_repository(self) -> None: mock.patch.object(ModelOperator, "set_comment"), mock.patch.object(self.m_r._model_ops._metadata_ops, "save"), mock.patch.object(model_version_impl.ModelVersion, "_get_functions", return_value=[]), + mock.patch.object( + env_utils, + "get_matched_package_versions_in_information_schema", + return_value={env_utils.SNOWPARK_ML_PKG_NAME: []}, + ), platform_capabilities.PlatformCapabilities.mock_features( {platform_capabilities.LIVE_COMMIT_PARAMETER: False} ), @@ -1071,7 +1139,7 @@ def test_artifact_repository(self) -> None: user_files=None, code_paths=None, ext_modules=None, - options=None, + options={"embed_local_ml_library": True, "relax_version": True}, task=task.Task.UNKNOWN, experiment_info=None, ) @@ -1100,7 +1168,7 @@ def test_artifact_repository(self) -> None: user_files=None, code_paths=None, ext_modules=None, - options=None, + options={"embed_local_ml_library": True, "relax_version": True}, task=task.Task.UNKNOWN, experiment_info=None, ) @@ -1129,7 +1197,7 @@ def test_artifact_repository(self) -> None: user_files=None, code_paths=None, ext_modules=None, - options=None, + options={"embed_local_ml_library": True, "relax_version": True}, task=task.Task.UNKNOWN, experiment_info=None, ) @@ -1147,6 +1215,11 @@ def test_resource_constraint(self) -> None: mock.patch.object(ModelOperator, "set_comment"), mock.patch.object(self.m_r._model_ops._metadata_ops, "save"), mock.patch.object(model_version_impl.ModelVersion, "_get_functions", return_value=[]), + mock.patch.object( + env_utils, + "get_matched_package_versions_in_information_schema", + return_value={env_utils.SNOWPARK_ML_PKG_NAME: []}, + ), platform_capabilities.PlatformCapabilities.mock_features( {platform_capabilities.LIVE_COMMIT_PARAMETER: False} ), @@ -1175,7 +1248,7 @@ def test_resource_constraint(self) -> None: user_files=None, code_paths=None, ext_modules=None, - options=None, + options={"embed_local_ml_library": True, "relax_version": True}, task=task.Task.UNKNOWN, experiment_info=None, ) diff --git a/snowflake/ml/registry/_manager/model_parameter_reconciler.py b/snowflake/ml/registry/_manager/model_parameter_reconciler.py index 0c190f53..3b36d044 100644 --- a/snowflake/ml/registry/_manager/model_parameter_reconciler.py +++ b/snowflake/ml/registry/_manager/model_parameter_reconciler.py @@ -1,9 +1,20 @@ import warnings from dataclasses import dataclass -from typing import Optional +from typing import Any, Optional +from absl.logging import logging +from packaging import requirements + +from snowflake.ml import version as snowml_version +from snowflake.ml._internal import env, env as snowml_env, env_utils +from snowflake.ml._internal.exceptions import error_codes, exceptions from snowflake.ml._internal.utils import sql_identifier -from snowflake.ml.model import type_hints as model_types +from snowflake.ml.model import target_platform, type_hints as model_types +from snowflake.ml.model._model_composer.model_manifest import model_manifest_schema +from snowflake.snowpark import Session +from snowflake.snowpark._internal import utils as snowpark_utils + +logger = logging.getLogger(__name__) @dataclass @@ -12,7 +23,7 @@ class ReconciledParameters: conda_dependencies: Optional[list[str]] = None pip_requirements: Optional[list[str]] = None - target_platforms: Optional[list[model_types.SupportedTargetPlatformType]] = None + target_platforms: Optional[list[model_types.TargetPlatform]] = None artifact_repository_map: Optional[dict[str, str]] = None options: Optional[model_types.ModelSaveOption] = None save_location: Optional[str] = None @@ -23,6 +34,7 @@ class ModelParameterReconciler: def __init__( self, + session: Session, database_name: sql_identifier.SqlIdentifier, schema_name: sql_identifier.SqlIdentifier, conda_dependencies: Optional[list[str]] = None, @@ -30,7 +42,10 @@ def __init__( target_platforms: Optional[list[model_types.SupportedTargetPlatformType]] = None, artifact_repository_map: Optional[dict[str, str]] = None, options: Optional[model_types.ModelSaveOption] = None, + python_version: Optional[str] = None, + statement_params: Optional[dict[str, str]] = None, ) -> None: + self._session = session self._database_name = database_name self._schema_name = schema_name self._conda_dependencies = conda_dependencies @@ -38,20 +53,27 @@ def __init__( self._target_platforms = target_platforms self._artifact_repository_map = artifact_repository_map self._options = options + self._python_version = python_version + self._statement_params = statement_params def reconcile(self) -> ReconciledParameters: """Perform all parameter reconciliation and return clean parameters.""" + reconciled_artifact_repository_map = self._reconcile_artifact_repository_map() reconciled_save_location = self._extract_save_location() self._validate_pip_requirements_warehouse_compatibility(reconciled_artifact_repository_map) + reconciled_target_platforms = self._reconcile_target_platforms() + reconciled_options = self._reconcile_explainability_options(reconciled_target_platforms) + reconciled_options = self._reconcile_relax_version(reconciled_options, reconciled_target_platforms) + return ReconciledParameters( conda_dependencies=self._conda_dependencies, pip_requirements=self._pip_requirements, - target_platforms=self._target_platforms, + target_platforms=reconciled_target_platforms, artifact_repository_map=reconciled_artifact_repository_map, - options=self._options, + options=reconciled_options, save_location=reconciled_save_location, ) @@ -82,6 +104,45 @@ def _extract_save_location(self) -> Optional[str]: return None + def _reconcile_target_platforms(self) -> Optional[list[model_types.TargetPlatform]]: + """Reconcile target platforms with proper defaulting logic.""" + # User specified target platforms are defaulted to None and will not show up in the generated manifest. + if self._target_platforms: + # Convert any string target platforms to TargetPlatform objects + return [model_types.TargetPlatform(platform) for platform in self._target_platforms] + + # Default the target platform to warehouse if not specified and any table function exists + if self._has_table_function(): + logger.info( + "Logging a partitioned model with a table function without specifying `target_platforms`. " + 'Default to `target_platforms=["WAREHOUSE"]`.' + ) + return [target_platform.TargetPlatform.WAREHOUSE] + + # Default the target platform to SPCS if not specified when running in ML runtime + if env.IN_ML_RUNTIME: + logger.info( + "Logging the model on Container Runtime for ML without specifying `target_platforms`. " + 'Default to `target_platforms=["SNOWPARK_CONTAINER_SERVICES"]`.' + ) + return [target_platform.TargetPlatform.SNOWPARK_CONTAINER_SERVICES] + + return None + + def _has_table_function(self) -> bool: + """Check if any table function exists in options.""" + if self._options is None: + return False + + if self._options.get("function_type") == model_manifest_schema.ModelMethodFunctionTypes.TABLE_FUNCTION.value: + return True + + for opt in self._options.get("method_options", {}).values(): + if opt.get("function_type") == model_manifest_schema.ModelMethodFunctionTypes.TABLE_FUNCTION.value: + return True + + return False + def _validate_pip_requirements_warehouse_compatibility( self, artifact_repository_map: Optional[dict[str, str]] ) -> None: @@ -103,3 +164,131 @@ def _targets_warehouse(target_platforms: Optional[list[model_types.SupportedTarg or model_types.TargetPlatform.WAREHOUSE in target_platforms or "WAREHOUSE" in target_platforms ) + + def _reconcile_explainability_options( + self, target_platforms: Optional[list[model_types.TargetPlatform]] + ) -> model_types.ModelSaveOption: + """Reconcile explainability settings and embed_local_ml_library based on warehouse runnability.""" + options = self._options.copy() if self._options else model_types.BaseModelSaveOption() + + conda_dep_dict = env_utils.validate_conda_dependency_string_list(self._conda_dependencies or []) + + enable_explainability = options.get("enable_explainability", None) + + # Handle case where user explicitly disabled explainability + if enable_explainability is False: + return self._handle_embed_local_ml_library(options, target_platforms) + + target_platform_set = set(target_platforms) if target_platforms else set() + + is_warehouse_runnable = self._is_warehouse_runnable(conda_dep_dict) + only_spcs = target_platform_set == set(target_platform.SNOWPARK_CONTAINER_SERVICES_ONLY) + has_both_platforms = target_platform_set == set(target_platform.BOTH_WAREHOUSE_AND_SNOWPARK_CONTAINER_SERVICES) + + # Handle case where user explicitly requested explainability + if enable_explainability: + if only_spcs or not is_warehouse_runnable: + raise ValueError( + "`enable_explainability` cannot be set to True when the model is not runnable in WH " + "or the target platforms include SPCS." + ) + elif has_both_platforms: + warnings.warn( + ("Explain function will only be available for model deployed to warehouse."), + category=UserWarning, + stacklevel=2, + ) + + # Handle case where explainability is not specified (None) - set default behavior + if enable_explainability is None: + if only_spcs or not is_warehouse_runnable: + options["enable_explainability"] = False + + return self._handle_embed_local_ml_library(options, target_platforms) + + def _handle_embed_local_ml_library( + self, options: model_types.ModelSaveOption, target_platforms: Optional[list[model_types.TargetPlatform]] + ) -> model_types.ModelSaveOption: + """Handle embed_local_ml_library logic.""" + if not snowpark_utils.is_in_stored_procedure() and target_platforms != [ # type: ignore[no-untyped-call] + model_types.TargetPlatform.SNOWPARK_CONTAINER_SERVICES # no information schema check for SPCS-only models + ]: + snowml_matched_versions = env_utils.get_matched_package_versions_in_information_schema( + self._session, + reqs=[requirements.Requirement(f"{env_utils.SNOWPARK_ML_PKG_NAME}=={snowml_version.VERSION}")], + python_version=self._python_version or snowml_env.PYTHON_VERSION, + statement_params=self._statement_params, + ).get(env_utils.SNOWPARK_ML_PKG_NAME, []) + + if len(snowml_matched_versions) < 1 and not options.get("embed_local_ml_library", False): + logging.info( + f"Local snowflake-ml-python library has version {snowml_version.VERSION}," + " which is not available in the Snowflake server, embedding local ML library automatically." + ) + options["embed_local_ml_library"] = True + + return options + + def _is_warehouse_runnable(self, conda_dep_dict: dict[str, list[Any]]) -> bool: + """Check if model can run in warehouse based on conda channels and pip requirements.""" + # If pip requirements are present but no artifact repository map, model cannot run in warehouse + if self._pip_requirements and not self._artifact_repository_map: + return False + + # If no conda dependencies, model can run in warehouse + if not conda_dep_dict: + return True + + # Check if all conda channels are warehouse-compatible + warehouse_compatible_channels = {env_utils.DEFAULT_CHANNEL_NAME, env_utils.SNOWFLAKE_CONDA_CHANNEL_URL} + for channel in conda_dep_dict: + if channel not in warehouse_compatible_channels: + return False + + return True + + def _reconcile_relax_version( + self, + options: model_types.ModelSaveOption, + target_platforms: Optional[list[model_types.TargetPlatform]], + ) -> model_types.ModelSaveOption: + """Reconcile relax_version setting based on pip requirements and target platforms.""" + target_platform_set = set(target_platforms) if target_platforms else set() + has_pip_requirements = bool(self._pip_requirements) + only_spcs = target_platform_set == set(target_platform.SNOWPARK_CONTAINER_SERVICES_ONLY) + + if "relax_version" not in options: + if has_pip_requirements or only_spcs: + logger.info( + "Setting `relax_version=False` as this model will run in Snowpark Container Services " + "or in Warehouse with a specified artifact_repository_map where exact version " + " specifications will be honored." + ) + relax_version = False + else: + warnings.warn( + ( + "`relax_version` is not set and therefore defaulted to True. Dependency version constraints" + " relaxed from ==x.y.z to >=x.y, <(x+1). To use specific dependency versions for compatibility," + " reproducibility, etc., set `options={'relax_version': False}` when logging the model." + ), + category=UserWarning, + stacklevel=2, + ) + relax_version = True + options["relax_version"] = relax_version + return options + + # Handle case where relax_version is already set + relax_version = options["relax_version"] + if relax_version and (has_pip_requirements or only_spcs): + raise exceptions.SnowflakeMLException( + error_code=error_codes.INVALID_ARGUMENT, + original_exception=ValueError( + "Setting `relax_version=True` is only allowed for models to be run in Warehouse with " + "Snowflake Conda Channel dependencies. It cannot be used with pip requirements or when " + "targeting only Snowpark Container Services." + ), + ) + + return options diff --git a/snowflake/ml/registry/_manager/model_parameter_reconciler_test.py b/snowflake/ml/registry/_manager/model_parameter_reconciler_test.py index 7635de63..b4846b36 100644 --- a/snowflake/ml/registry/_manager/model_parameter_reconciler_test.py +++ b/snowflake/ml/registry/_manager/model_parameter_reconciler_test.py @@ -1,14 +1,19 @@ import warnings -from typing import Any +from typing import Any, cast +from unittest import mock -from absl.testing import absltest +from absl.testing import absltest, parameterized +from snowflake.ml._internal import env_utils +from snowflake.ml._internal.exceptions import error_codes, exceptions from snowflake.ml._internal.utils import sql_identifier from snowflake.ml.model import type_hints as model_types from snowflake.ml.registry._manager import model_parameter_reconciler +from snowflake.ml.test_utils import mock_session +from snowflake.snowpark import Session -class ModelParameterReconcilerTest(absltest.TestCase): +class ModelParameterReconcilerTest(parameterized.TestCase): """Test cases for ModelParameterReconciler functionality.""" def setUp(self) -> None: @@ -18,7 +23,9 @@ def setUp(self) -> None: def _create_reconciler(self, **kwargs: Any) -> model_parameter_reconciler.ModelParameterReconciler: """Helper to create reconciler with default context.""" + mock_session = kwargs.get("session", mock.MagicMock()) return model_parameter_reconciler.ModelParameterReconciler( + session=mock_session, database_name=kwargs.get("database_name", self.database_name), schema_name=kwargs.get("schema_name", self.schema_name), conda_dependencies=kwargs.get("conda_dependencies"), @@ -26,6 +33,8 @@ def _create_reconciler(self, **kwargs: Any) -> model_parameter_reconciler.ModelP target_platforms=kwargs.get("target_platforms"), artifact_repository_map=kwargs.get("artifact_repository_map"), options=kwargs.get("options"), + python_version=kwargs.get("python_version"), + statement_params=kwargs.get("statement_params"), ) def test_artifact_repository_map_none(self) -> None: @@ -62,21 +71,6 @@ def test_save_location_missing_from_options(self) -> None: result = reconciler.reconcile() self.assertIsNone(result.save_location) - def test_parameter_passthrough(self) -> None: - """Test that other parameters are passed through unchanged.""" - conda_deps = ["numpy==1.21.0"] - pip_reqs = ["pandas>=1.3.0"] - target_platforms = [model_types.TargetPlatform.WAREHOUSE] - - reconciler = self._create_reconciler( - conda_dependencies=conda_deps, pip_requirements=pip_reqs, target_platforms=target_platforms - ) - result = reconciler.reconcile() - - self.assertEqual(result.conda_dependencies, conda_deps) - self.assertEqual(result.pip_requirements, pip_reqs) - self.assertEqual(result.target_platforms, target_platforms) - def test_targets_warehouse(self) -> None: """Test _targets_warehouse method with various target platform configurations.""" self.assertTrue(model_parameter_reconciler.ModelParameterReconciler._targets_warehouse(None)) @@ -155,43 +149,414 @@ def test_pip_requirements_warehouse_warnings(self) -> None: reconciler.reconcile() def test_pip_requirements_no_warnings(self) -> None: - """Test scenarios where warnings are not raised from pip_requirements.""" + """Test scenarios where pip_requirements warnings are not raised.""" reconciler = self._create_reconciler( pip_requirements=["pandas>=1.3.0"], artifact_repository_map={"pip": "my_repo"}, target_platforms=None, ) - with warnings.catch_warnings(): - warnings.simplefilter("error") + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") reconciler.reconcile() + pip_warnings = [ + warning for warning in w if "Models logged specifying `pip_requirements`" in str(warning.message) + ] + self.assertEqual(len(pip_warnings), 0) reconciler = self._create_reconciler( pip_requirements=None, artifact_repository_map=None, target_platforms=None, ) - with warnings.catch_warnings(): - warnings.simplefilter("error") + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") reconciler.reconcile() + pip_warnings = [ + warning for warning in w if "Models logged specifying `pip_requirements`" in str(warning.message) + ] + self.assertEqual(len(pip_warnings), 0) reconciler = self._create_reconciler( pip_requirements=["pandas>=1.3.0"], artifact_repository_map=None, target_platforms=[model_types.TargetPlatform.SNOWPARK_CONTAINER_SERVICES], ) - with warnings.catch_warnings(): - warnings.simplefilter("error") + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") reconciler.reconcile() + pip_warnings = [ + warning for warning in w if "Models logged specifying `pip_requirements`" in str(warning.message) + ] + self.assertEqual(len(pip_warnings), 0) reconciler = self._create_reconciler( pip_requirements=["pandas>=1.3.0"], artifact_repository_map={"pip": "my_repo"}, target_platforms=[model_types.TargetPlatform.WAREHOUSE], ) + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + reconciler.reconcile() + pip_warnings = [ + warning for warning in w if "Models logged specifying `pip_requirements`" in str(warning.message) + ] + self.assertEqual(len(pip_warnings), 0) + + def test_has_table_function(self) -> None: + """Test _has_table_function method with various options configurations.""" + reconciler = self._create_reconciler(options=None) + self.assertFalse(reconciler._has_table_function()) + + reconciler = self._create_reconciler(options={}) + self.assertFalse(reconciler._has_table_function()) + + reconciler = self._create_reconciler(options={"function_type": "TABLE_FUNCTION"}) + self.assertTrue(reconciler._has_table_function()) + + reconciler = self._create_reconciler( + options={ + "method_options": { + "predict": {"function_type": "FUNCTION"}, + "predict_proba": {"function_type": "TABLE_FUNCTION"}, + } + } + ) + self.assertTrue(reconciler._has_table_function()) + + reconciler = self._create_reconciler( + options={ + "method_options": { + "predict": {"function_type": "FUNCTION"}, + "predict_proba": {"function_type": "FUNCTION"}, + } + } + ) + self.assertFalse(reconciler._has_table_function()) + + def test_reconcile_target_platforms_user_specified(self) -> None: + """Test _reconcile_target_platforms with user-specified platforms.""" + reconciler = self._create_reconciler(target_platforms=["WAREHOUSE"]) + result = reconciler.reconcile() + self.assertEqual(result.target_platforms, [model_types.TargetPlatform.WAREHOUSE]) + + reconciler = self._create_reconciler(target_platforms=[model_types.TargetPlatform.SNOWPARK_CONTAINER_SERVICES]) + result = reconciler.reconcile() + self.assertEqual(result.target_platforms, [model_types.TargetPlatform.SNOWPARK_CONTAINER_SERVICES]) + + reconciler = self._create_reconciler( + target_platforms=["WAREHOUSE", model_types.TargetPlatform.SNOWPARK_CONTAINER_SERVICES] + ) + result = reconciler.reconcile() + self.assertEqual( + result.target_platforms, + [model_types.TargetPlatform.WAREHOUSE, model_types.TargetPlatform.SNOWPARK_CONTAINER_SERVICES], + ) + + def test_reconcile_target_platforms_defaults(self) -> None: + """Test _reconcile_target_platforms default behavior in various scenarios.""" + with mock.patch("snowflake.ml._internal.env.IN_ML_RUNTIME", True): + reconciler = self._create_reconciler(target_platforms=None, options=None) + result = reconciler.reconcile() + self.assertEqual(result.target_platforms, [model_types.TargetPlatform.SNOWPARK_CONTAINER_SERVICES]) + + with mock.patch("snowflake.ml._internal.env.IN_ML_RUNTIME", False): + reconciler = self._create_reconciler(target_platforms=None, options={"function_type": "TABLE_FUNCTION"}) + result = reconciler.reconcile() + self.assertEqual(result.target_platforms, [model_types.TargetPlatform.WAREHOUSE]) + + with mock.patch("snowflake.ml._internal.env.IN_ML_RUNTIME", False): + reconciler = self._create_reconciler(target_platforms=None, options=None) + result = reconciler.reconcile() + self.assertIsNone(result.target_platforms) + + def test_is_warehouse_runnable(self) -> None: + """Test _is_warehouse_runnable logic using conda channels and pip requirements.""" + reconciler = self._create_reconciler() + self.assertTrue(reconciler._is_warehouse_runnable({})) + self.assertTrue(reconciler._is_warehouse_runnable({"": ["pkg1"]})) + self.assertTrue(reconciler._is_warehouse_runnable({"https://repo.anaconda.com/pkgs/snowflake": ["pkg1"]})) + + self.assertFalse(reconciler._is_warehouse_runnable({"conda-forge": ["pkg1"]})) + + reconciler = self._create_reconciler(pip_requirements=["numpy"]) + self.assertFalse(reconciler._is_warehouse_runnable({})) + + def test_explainability_validation(self) -> None: + """Test explainability validation logic for different platform configurations.""" + + reconciler = self._create_reconciler(target_platforms=[model_types.TargetPlatform.SNOWPARK_CONTAINER_SERVICES]) + result = reconciler.reconcile() + assert result.options is not None + self.assertEqual(result.options["enable_explainability"], False) + + reconciler = self._create_reconciler( + target_platforms=[ + model_types.TargetPlatform.WAREHOUSE, + model_types.TargetPlatform.SNOWPARK_CONTAINER_SERVICES, + ], + options={"enable_explainability": True}, + ) + with self.assertWarnsRegex( + UserWarning, "Explain function will only be available for model deployed to warehouse" + ): + result = reconciler.reconcile() + assert result.options is not None + self.assertEqual(result.options["enable_explainability"], True) + + reconciler = self._create_reconciler( + conda_dependencies=["conda-forge::python-package==1.0.0"], + target_platforms=[model_types.TargetPlatform.WAREHOUSE], + ) + result = reconciler.reconcile() + assert result.options is not None + self.assertEqual(result.options["enable_explainability"], False) + + reconciler = self._create_reconciler( + conda_dependencies=["conda-forge::python-package==1.0.0"], + target_platforms=[model_types.TargetPlatform.WAREHOUSE], + options={"enable_explainability": True}, + ) + with self.assertRaisesRegex(ValueError, "`enable_explainability` cannot be set to True.*not runnable in WH"): + reconciler.reconcile() + + def test_embed_local_ml_library_logic(self) -> None: + """Test embed_local_ml_library auto-setting logic.""" + with mock.patch.object(env_utils, "get_matched_package_versions_in_information_schema") as mock_get_versions: + + mock_get_versions.return_value = {} + + reconciler = self._create_reconciler( + target_platforms=[model_types.TargetPlatform.WAREHOUSE], options={"embed_local_ml_library": False} + ) + result = reconciler.reconcile() + assert result.options is not None + self.assertTrue(result.options["embed_local_ml_library"]) + + mock_get_versions.return_value = {"snowflake-ml-python": ["1.0.0"]} + reconciler = self._create_reconciler( + target_platforms=[model_types.TargetPlatform.WAREHOUSE], options={"embed_local_ml_library": False} + ) + result = reconciler.reconcile() + assert result.options is not None + self.assertFalse(result.options["embed_local_ml_library"]) + + reconciler = self._create_reconciler( + target_platforms=[model_types.TargetPlatform.SNOWPARK_CONTAINER_SERVICES], + options={"embed_local_ml_library": False}, + ) + result = reconciler.reconcile() + assert result.options is not None + self.assertFalse(result.options["embed_local_ml_library"]) + + def test_relax_version_logic(self) -> None: + """Test relax_version auto-setting and validation logic.""" + + reconciler = self._create_reconciler(pip_requirements=["xgboost==1.2.3"]) + with mock.patch.object(model_parameter_reconciler.logger, "info") as mock_info: + result = reconciler.reconcile() + assert result.options is not None + self.assertFalse(result.options["relax_version"]) + mock_info.assert_called_with( + "Setting `relax_version=False` as this model will run in Snowpark Container Services " + "or in Warehouse with a specified artifact_repository_map where exact version " + " specifications will be honored." + ) + + reconciler = self._create_reconciler(target_platforms=[model_types.TargetPlatform.SNOWPARK_CONTAINER_SERVICES]) + with mock.patch.object(model_parameter_reconciler.logger, "info") as mock_info: + result = reconciler.reconcile() + assert result.options is not None + self.assertFalse(result.options["relax_version"]) + mock_info.assert_called_with( + "Setting `relax_version=False` as this model will run in Snowpark Container Services " + "or in Warehouse with a specified artifact_repository_map where exact version " + " specifications will be honored." + ) + + reconciler = self._create_reconciler() + with self.assertWarnsRegex(UserWarning, "`relax_version` is not set and therefore defaulted to True"): + result = reconciler.reconcile() + assert result.options is not None + self.assertTrue(result.options["relax_version"]) + + reconciler = self._create_reconciler(pip_requirements=["xgboost==1.2.3"], options=None) + with mock.patch.object(model_parameter_reconciler.logger, "info") as mock_info: + result = reconciler.reconcile() + assert result.options is not None + self.assertFalse(result.options["relax_version"]) + mock_info.assert_called_with( + "Setting `relax_version=False` as this model will run in Snowpark Container Services " + "or in Warehouse with a specified artifact_repository_map where exact version " + " specifications will be honored." + ) + + reconciler = self._create_reconciler(options={"relax_version": False}) + result = reconciler.reconcile() + assert result.options is not None + self.assertFalse(result.options["relax_version"]) + + reconciler = self._create_reconciler(options={"relax_version": True}) with warnings.catch_warnings(): warnings.simplefilter("error") + result = reconciler.reconcile() + assert result.options is not None + self.assertTrue(result.options["relax_version"]) + + reconciler = self._create_reconciler(pip_requirements=["xgboost==1.2.3"], options={"relax_version": True}) + with self.assertRaises(exceptions.SnowflakeMLException) as cm: reconciler.reconcile() + self.assertEqual(cm.exception.error_code, error_codes.INVALID_ARGUMENT) + self.assertIn( + "Setting `relax_version=True` is only allowed for models to be run in Warehouse with " + "Snowflake Conda Channel dependencies", + str(cm.exception), + ) + + reconciler = self._create_reconciler( + target_platforms=[model_types.TargetPlatform.SNOWPARK_CONTAINER_SERVICES], options={"relax_version": True} + ) + with self.assertRaises(exceptions.SnowflakeMLException) as cm: + reconciler.reconcile() + self.assertEqual(cm.exception.error_code, error_codes.INVALID_ARGUMENT) + self.assertIn( + "Setting `relax_version=True` is only allowed for models to be run in Warehouse with " + "Snowflake Conda Channel dependencies", + str(cm.exception), + ) + + @parameterized.parameters( # type: ignore[misc] + {"disable_explainability": True, "target_platforms": [model_types.TargetPlatform.SNOWPARK_CONTAINER_SERVICES]}, + { + "disable_explainability": False, + "target_platforms": [ + model_types.TargetPlatform.SNOWPARK_CONTAINER_SERVICES, + model_types.TargetPlatform.WAREHOUSE, + ], + }, + {"disable_explainability": False, "target_platforms": []}, + { + "disable_explainability": True, + "conda_dependencies": ["python-package1==1.0.0", "conda-forge::python-package2==1.1.0"], + }, + { + "disable_explainability": False, + "conda_dependencies": [ + "python-package1==1.0.0", + "https://repo.anaconda.com/pkgs/snowflake::python-package2", + ], + }, + {"disable_explainability": True, "pip_requirements": ["python-package==1.0.0"]}, + {"disable_explainability": False, "pip_requirements": None}, + ) + def test_explainability_parameter_reconciliation(self, disable_explainability: bool, **kwargs: Any) -> None: + """Test explainability parameter reconciliation matching original model_composer test structure.""" + m_session = mock_session.MockSession(conn=None, test_case=self) + c_session = cast(Session, m_session) + + reconciler = model_parameter_reconciler.ModelParameterReconciler( + session=c_session, + database_name=self.database_name, + schema_name=self.schema_name, + conda_dependencies=kwargs.get("conda_dependencies"), + pip_requirements=kwargs.get("pip_requirements"), + target_platforms=kwargs.get("target_platforms"), + options=None, + ) + + if kwargs.get("conda_dependencies") == ["python-package1==1.0.0", "conda-forge::python-package2==1.1.0"]: + mock_conda_result = {"conda-forge": ["python-package2"]} + else: + mock_conda_result = {} + + with mock.patch.object( + env_utils, "validate_conda_dependency_string_list", return_value=mock_conda_result + ) as mock_validate_conda: + with mock.patch.object( + env_utils, + "get_matched_package_versions_in_information_schema", + return_value={env_utils.SNOWPARK_ML_PKG_NAME: []}, + ) as mock_get_versions: + + result = reconciler.reconcile() + + mock_validate_conda.assert_called_once() + if kwargs.get("target_platforms") != [model_types.TargetPlatform.SNOWPARK_CONTAINER_SERVICES]: + mock_get_versions.assert_called_once() + else: + mock_get_versions.assert_not_called() + + if disable_explainability: + assert result.options is not None + self.assertEqual(result.options["enable_explainability"], False) + else: + assert result.options is not None + self.assertNotIn("enable_explainability", result.options) + + if disable_explainability: + reconciler = model_parameter_reconciler.ModelParameterReconciler( + session=c_session, + database_name=self.database_name, + schema_name=self.schema_name, + conda_dependencies=kwargs.get("conda_dependencies"), + pip_requirements=kwargs.get("pip_requirements"), + target_platforms=kwargs.get("target_platforms"), + options={"enable_explainability": True}, + ) + + with mock.patch.object(env_utils, "validate_conda_dependency_string_list", return_value=mock_conda_result): + with mock.patch.object( + env_utils, + "get_matched_package_versions_in_information_schema", + return_value={env_utils.SNOWPARK_ML_PKG_NAME: []}, + ): + + with self.assertRaisesRegex( + ValueError, + "`enable_explainability` cannot be set to True when the model is not runnable in WH " + "or the target platforms include SPCS.", + ): + reconciler.reconcile() + + @parameterized.parameters( # type: ignore[misc] + {"target_platforms": [model_types.TargetPlatform.SNOWPARK_CONTAINER_SERVICES]}, + { + "target_platforms": [ + model_types.TargetPlatform.SNOWPARK_CONTAINER_SERVICES, + model_types.TargetPlatform.WAREHOUSE, + ] + }, + ) + def test_embed_ml_library_information_schema_check( + self, target_platforms: list[model_types.TargetPlatform] + ) -> None: + """Test embed_local_ml_library information schema check matching original model_composer test structure.""" + m_session = mock_session.MockSession(conn=None, test_case=self) + c_session = cast(Session, m_session) + + reconciler = model_parameter_reconciler.ModelParameterReconciler( + session=c_session, + database_name=self.database_name, + schema_name=self.schema_name, + conda_dependencies=None, + pip_requirements=None, + target_platforms=cast(list[model_types.SupportedTargetPlatformType], target_platforms), + options=None, + ) + + with mock.patch.object(env_utils, "validate_conda_dependency_string_list", return_value={}): + with mock.patch.object( + env_utils, + "get_matched_package_versions_in_information_schema", + return_value={env_utils.SNOWPARK_ML_PKG_NAME: []}, + ) as mock_get_versions: + + reconciler.reconcile() + + if target_platforms == [model_types.TargetPlatform.SNOWPARK_CONTAINER_SERVICES]: + mock_get_versions.assert_not_called() + else: + mock_get_versions.assert_called_once() if __name__ == "__main__": diff --git a/snowflake/ml/version.py b/snowflake/ml/version.py index 597c4a10..4d5a4ad5 100644 --- a/snowflake/ml/version.py +++ b/snowflake/ml/version.py @@ -1,2 +1,2 @@ # This is parsed by regex in conda recipe meta file. Make sure not to break it. -VERSION = "1.10.0" +VERSION = "1.11.0" diff --git a/tests/integ/snowflake/ml/experiment/BUILD.bazel b/tests/integ/snowflake/ml/experiment/BUILD.bazel index 71744462..cf58a1a0 100644 --- a/tests/integ/snowflake/ml/experiment/BUILD.bazel +++ b/tests/integ/snowflake/ml/experiment/BUILD.bazel @@ -25,6 +25,20 @@ py_library( ], ) +py_test( + name = "autolog_keras_integ_test", + timeout = "long", + srcs = ["autolog_keras_integ_test.py"], + optional_dependencies = [ + "keras", + ], + tags = ["feature:observability"], + deps = [ + ":autolog_integ_test_base", + "//snowflake/ml/experiment/callback:keras", + ], +) + py_test( name = "autolog_lightgbm_integ_test", timeout = "long", diff --git a/tests/integ/snowflake/ml/experiment/autolog_integ_test_base.py b/tests/integ/snowflake/ml/experiment/autolog_integ_test_base.py index 43360f04..9a135946 100644 --- a/tests/integ/snowflake/ml/experiment/autolog_integ_test_base.py +++ b/tests/integ/snowflake/ml/experiment/autolog_integ_test_base.py @@ -41,6 +41,13 @@ def setUp(self) -> None: self.X = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}) self.y = [0, 1, 0] self.num_steps = 5 + self.model_sig = model_signature.ModelSignature( + inputs=[ + model_signature.FeatureSpec(name="a", dtype=model_signature.DataType.FLOAT), + model_signature.FeatureSpec(name="b", dtype=model_signature.DataType.FLOAT), + ], + outputs=[model_signature.FeatureSpec(name="target", dtype=model_signature.DataType.INT8)], + ) def tearDown(self) -> None: self._db_manager.drop_database(self._db_name) @@ -49,26 +56,21 @@ def tearDown(self) -> None: def _train_model(self, model_class: type[Any], callback: Any) -> None: pass - def _test_autolog(self, model_class: type[Any], callback_class: type[Any], metric_name: str) -> None: + def _test_autolog( + self, model_class: type[Any], callback_class: type[Any], metric_name: str, log_every_n_epochs: int + ) -> None: """Test that autologging works.""" experiment_name = "TEST_EXPERIMENT_AUTOLOG" model_name = "TEST_AUTOLOG_MODEL" - model_sig = model_signature.ModelSignature( - inputs=[ - model_signature.FeatureSpec(name="a", dtype=model_signature.DataType.FLOAT), - model_signature.FeatureSpec(name="b", dtype=model_signature.DataType.FLOAT), - ], - outputs=[model_signature.FeatureSpec(name="target", dtype=model_signature.DataType.INT8)], - ) - callback = callback_class( self.exp, log_model=True, log_metrics=True, log_params=True, + log_every_n_epochs=log_every_n_epochs, model_name=model_name, - model_signature=model_sig, + model_signature=self.model_sig, ) self.exp.set_experiment(experiment_name=experiment_name) self._train_model(model_class, callback) @@ -82,13 +84,18 @@ def _test_autolog(self, model_class: type[Any], callback_class: type[Any], metri # Parse and verify metadata metadata = json.loads(runs[0]["metadata"]) self.assertIn("metrics", metadata) - metric_dict = {f"{m['name']}_step_{m['step']}": m for m in metadata["metrics"]} - # Verify that the expected metric was logged for each step - for step in range(self.num_steps): - self.assertIn(f"{metric_name}_step_{step}", metric_dict) + metric_set = {f"{m['name']}_step_{m['step']}" for m in metadata["metrics"]} + # Verify that the specified metric was logged at all expected epochs + for epoch in range(0, self.num_steps, log_every_n_epochs): + self.assertIn(f"{metric_name}_step_{epoch}", metric_set) + # Verify that no metrics were logged at epochs that are not multiples of `log_every_n_epochs` + for metric in metadata["metrics"]: + self.assertIn(metric["step"], range(0, self.num_steps, log_every_n_epochs)) + # Verify that params were logged self.assertIn("parameters", metadata) self.assertGreater(len(metadata["parameters"]), 0) + # Verify that the model was logged models = self._session.sql(f"SHOW MODELS LIKE '{model_name}'").collect() self.assertEqual(len(models), 1) diff --git a/tests/integ/snowflake/ml/experiment/autolog_keras_integ_test.py b/tests/integ/snowflake/ml/experiment/autolog_keras_integ_test.py new file mode 100644 index 00000000..63df9622 --- /dev/null +++ b/tests/integ/snowflake/ml/experiment/autolog_keras_integ_test.py @@ -0,0 +1,39 @@ +import keras +import numpy as np +from absl.testing import absltest, parameterized + +from snowflake.ml.experiment.callback.keras import SnowflakeKerasCallback +from tests.integ.snowflake.ml.experiment.autolog_integ_test_base import ( + AutologIntegrationTest, +) + + +class AutologKerasIntegrationTest(AutologIntegrationTest, parameterized.TestCase): + def _train_model( + self, + model_class: type[keras.Model], + callback: SnowflakeKerasCallback, + ) -> None: + model = model_class() + model.add(keras.layers.Dense(1)) + model.compile(loss="mean_squared_error", metrics=["mean_absolute_error"]) + model.fit(self.X.values, np.array(self.y), epochs=self.num_steps, callbacks=[callback]) + + @parameterized.parameters( + (keras.Sequential, "loss", 1), + (keras.Sequential, "loss", 2), + (keras.Sequential, "mean_absolute_error", 1), + (keras.Sequential, "mean_absolute_error", 3), + ) # type: ignore[misc] + def test_autolog(self, model_class: type[keras.Model], metric_name: str, log_every_n_epochs: int) -> None: + """Test that autologging works for Keras models.""" + self._test_autolog( + model_class=model_class, + callback_class=SnowflakeKerasCallback, + metric_name=metric_name, + log_every_n_epochs=log_every_n_epochs, + ) + + +if __name__ == "__main__": + absltest.main() diff --git a/tests/integ/snowflake/ml/experiment/autolog_lightgbm_integ_test.py b/tests/integ/snowflake/ml/experiment/autolog_lightgbm_integ_test.py index 22299c16..1051e044 100644 --- a/tests/integ/snowflake/ml/experiment/autolog_lightgbm_integ_test.py +++ b/tests/integ/snowflake/ml/experiment/autolog_lightgbm_integ_test.py @@ -25,16 +25,20 @@ def _train_model( raise ValueError(f"Unsupported model class: {model_class}") @parameterized.parameters( - (lgb.LGBMClassifier, "training:binary_logloss"), - (lgb.LGBMRegressor, "training:l2"), - (lgb.Booster, "training:l2"), + (lgb.LGBMClassifier, "training:binary_logloss", 1), + (lgb.LGBMRegressor, "training:l2", 2), + (lgb.Booster, "training:l2", 1), + (lgb.Booster, "training:l2", 3), ) # type: ignore[misc] - def test_autolog(self, model_class: type[Union[lgb.LGBMModel, lgb.Booster]], metric_name: str) -> None: + def test_autolog( + self, model_class: type[Union[lgb.LGBMModel, lgb.Booster]], metric_name: str, log_every_n_epochs: int + ) -> None: """Test that autologging works for LightGBM models.""" self._test_autolog( model_class=model_class, callback_class=SnowflakeLightgbmCallback, metric_name=metric_name, + log_every_n_epochs=log_every_n_epochs, ) diff --git a/tests/integ/snowflake/ml/experiment/autolog_xgboost_integ_test.py b/tests/integ/snowflake/ml/experiment/autolog_xgboost_integ_test.py index fe146e55..55a9919a 100644 --- a/tests/integ/snowflake/ml/experiment/autolog_xgboost_integ_test.py +++ b/tests/integ/snowflake/ml/experiment/autolog_xgboost_integ_test.py @@ -25,16 +25,20 @@ def _train_model( raise ValueError(f"Unsupported model class: {model_class}") @parameterized.parameters( - (xgb.XGBClassifier, "validation_0:logloss"), - (xgb.XGBRegressor, "validation_0:rmse"), - (xgb.Booster, "train:rmse"), + (xgb.XGBClassifier, "validation_0:logloss", 1), + (xgb.XGBRegressor, "validation_0:rmse", 2), + (xgb.Booster, "train:rmse", 1), + (xgb.Booster, "train:rmse", 3), ) # type: ignore[misc] - def test_autolog(self, model_class: type[Union[xgb.XGBModel, xgb.Booster]], metric_name: str) -> None: + def test_autolog( + self, model_class: type[Union[xgb.XGBModel, xgb.Booster]], metric_name: str, log_every_n_epochs: int + ) -> None: """Test that autologging works for XGBoost models.""" self._test_autolog( model_class=model_class, callback_class=SnowflakeXgboostCallback, metric_name=metric_name, + log_every_n_epochs=log_every_n_epochs, ) diff --git a/tests/integ/snowflake/ml/jobs/BUILD.bazel b/tests/integ/snowflake/ml/jobs/BUILD.bazel index 72aa9aad..0c22be9a 100644 --- a/tests/integ/snowflake/ml/jobs/BUILD.bazel +++ b/tests/integ/snowflake/ml/jobs/BUILD.bazel @@ -71,6 +71,8 @@ py_test( ":test_file_helper", "//snowflake/ml/data", "//snowflake/ml/jobs", + "//snowflake/ml/jobs/_utils:payload_utils", + "//snowflake/ml/jobs/_utils:types", "//tests/integ/snowflake/ml/test_utils:db_manager", "//tests/integ/snowflake/ml/test_utils:test_env_utils", ], diff --git a/tests/integ/snowflake/ml/jobs/jobs_integ_test.py b/tests/integ/snowflake/ml/jobs/jobs_integ_test.py index 0de7ccfa..d1e4abf6 100644 --- a/tests/integ/snowflake/ml/jobs/jobs_integ_test.py +++ b/tests/integ/snowflake/ml/jobs/jobs_integ_test.py @@ -3,6 +3,7 @@ import os import pathlib import re +import sys import tempfile import textwrap import time @@ -19,9 +20,15 @@ from snowflake.ml._internal import env from snowflake.ml._internal.utils import identifier from snowflake.ml.jobs import job as jd -from snowflake.ml.jobs._utils import constants, query_helper +from snowflake.ml.jobs._utils import ( + constants, + payload_utils, + query_helper, + spec_utils, + types, +) from snowflake.ml.utils import sql_client -from snowflake.snowpark import exceptions as sp_exceptions +from snowflake.snowpark import exceptions as sp_exceptions, functions as F from tests.integ.snowflake.ml.jobs import test_constants from tests.integ.snowflake.ml.jobs.test_file_helper import TestAsset from tests.integ.snowflake.ml.test_utils import db_manager, test_env_utils @@ -506,6 +513,33 @@ def __str__(self) -> str: self.assertEqual(loaded_job.status, "DONE") self.assertDictEqual(loaded_job.result(), job_result) + # TODO(SNOW-1911482): Enable test for Python 3.11+ + @absltest.skipIf( + version.Version(env.PYTHON_VERSION) >= version.Version("3.11"), + "Decorator test only works for Python 3.10 and below due to pickle compatibility", + ) # type: ignore[misc] + def test_job_execution_in_stored_procedure(self) -> None: + jobs_import_src = os.path.dirname(jobs.__file__) + + @jobs.remote(self.compute_pool, stage_name="payload_stage") + def job_fn() -> None: + print("Hello from remote function!") + + @F.sproc( + session=self.session, + packages=["snowflake-snowpark-python", "snowflake-ml-python"], + imports=[ + (jobs_import_src, "snowflake.ml.jobs"), + ], + ) + def job_sproc(session: snowpark.Session) -> None: + job = job_fn() + assert job.wait() == "DONE", f"Job {job.id} failed. Logs:\n{job.get_logs()}" + return job.get_logs() + + result = job_sproc() + self.assertEqual("Hello from remote function!", result) + # TODO(SNOW-1911482): Enable test for Python 3.11+ @absltest.skipIf( # type: ignore[misc] version.Version(env.PYTHON_VERSION) >= version.Version("3.11"), @@ -653,11 +687,13 @@ def compute_heavy(n): ray.init(address="auto", ignore_reinit_error=True) hosts = [compute_heavy.remote(50_000) for _ in range(10)] unique_hosts = set(ray.get(hosts)) - assert len(unique_hosts) >= 2, f"Expected at least 2 unique hosts, get: {unique_hosts}" + assert ( + len(unique_hosts) >= 2 + ), f"Expected at least 2 unique hosts, get: {unique_hosts}, hosts: {ray.get(hosts)}" print("test succeeded") job = self._submit_func_as_file(ray_workload, target_instances=2, min_instances=2) - self.assertEqual(job.wait(), "DONE", job.get_logs(verbose=True)) + self.assertEqual(job.wait(), "DONE", f"job {job.id} logs: {job.get_logs(verbose=True)}") self.assertTrue("test succeeded" in job.get_logs()) def test_multinode_job_wait_for_instances(self) -> None: @@ -955,23 +991,16 @@ def test_submit_job_from_stage( f"CREATE {'TEMPORARY' if temporary else ''} STAGE {stage_name} ENCRYPTION = (TYPE = {repr(encryption)});" ).collect() upload_files = TestAsset("src") - for path in { - p.parent.joinpath(f"*{p.suffix}") if p.suffix else p - for p in upload_files.path.resolve().rglob("*") - if p.is_file() - }: - self.session.file.put( - str(path), - pathlib.Path(stage_name).joinpath(path.parent.relative_to(upload_files.path)).as_posix(), - overwrite=True, - auto_compress=False, - ) - + payload_utils.upload_payloads( + self.session, pathlib.PurePath(stage_name), types.PayloadSpec(upload_files.path, None) + ) test_cases = [ - (f"@{stage_name}/", f"@{stage_name}/subdir/sub_main.py"), - (f"@{stage_name}/subdir", f"@{stage_name}/subdir/sub_main.py"), + (f"@{stage_name}/", f"@{stage_name}/subdir/sub_main.py", "DONE"), + (f"@{stage_name}/subdir", f"@{stage_name}/subdir/sub_main.py", "DONE"), + (f"@{stage_name}/subdir", "sub_main.py", "DONE"), + (f"@{stage_name}/subdir", "non_exist_file.py", "FAILED"), ] - for source, entrypoint in test_cases: + for source, entrypoint, expected_status in test_cases: with self.subTest(source=source, entrypoint=entrypoint): job = jobs.submit_from_stage( source=source, @@ -982,7 +1011,7 @@ def test_submit_job_from_stage( session=self.session, ) - self.assertEqual(job.wait(), "DONE", job.get_logs()) + self.assertEqual(job.wait(), expected_status, job.get_logs()) @parameterized.parameters( [ @@ -1090,11 +1119,6 @@ def greet(): def test_cancel_job(self) -> None: """Test cancelling a long running job.""" - try: - self.session.sql("ALTER SESSION SET SNOWSERVICES_ENABLE_SPCS_JOB_CANCELLATION = TRUE").collect() - self.session.sql("ALTER SESSION SET ENABLE_ENTITY_FACADE_SYSTEM_FUNCTIONS = TRUE").collect() - except sp_exceptions.SnowparkSQLException: - self.skipTest("Unable to control the SPCS job cancellation parameter. Skipping test.") def long_running_function() -> None: import time @@ -1104,8 +1128,12 @@ def long_running_function() -> None: job = self._submit_func_as_file(long_running_function) self.assertIn(job.status, ["PENDING", "RUNNING"]) job.cancel() - final_status = job.wait(timeout=20) - self.assertEqual(final_status, "CANCELLED") + try: + job.wait(timeout=20) + except TimeoutError: + print("Job did not cancel within timeout", job.status, job.get_logs()) + finally: + self.assertIn(job.status, ["CANCELLED", "CANCELLING"]) def test_cancel_nonexistent_job(self) -> None: """Test cancelling a job that doesn't exist.""" @@ -1125,23 +1153,20 @@ def test_multinode_job_orders(self) -> None: """Test that the job orders are correct for a multinode job.""" job = self._submit_func_as_file(dummy_function, target_instances=2) self.assertEqual(job.wait(), "DONE", job.get_logs()) - if "resourceManagement" in job._service_spec["spec"]: - # Step 1: Show service instances in service - rows = query_helper.run_query( - self.session, "SHOW SERVICE INSTANCES IN SERVICE IDENTIFIER(?)", params=[job.id] - ) - self.assertEqual(len(rows), 2, "Expected 2 service instances for target_instances=2") + # Step 1: Show service instances in service + rows = query_helper.run_query(self.session, "SHOW SERVICE INSTANCES IN SERVICE IDENTIFIER(?)", params=[job.id]) + self.assertEqual(len(rows), 2, "Expected 2 service instances for target_instances=2") - # Step 2: Sort them by start-time - sorted_instances = sorted(rows, key=lambda x: (x["start_time"], int(x["instance_id"]))) + # Step 2: Sort them by start-time + sorted_instances = sorted(rows, key=lambda x: (x["start_time"], int(x["instance_id"]))) - # Step 3: Check instance with id 0 starts first - first_instance = sorted_instances[0] - self.assertEqual( - int(first_instance["instance_id"]), - 0, - f"Expected instance 0 to start first, but instance {first_instance['instance_id']} started first", - ) + # Step 3: Check instance with id 0 starts first + first_instance = sorted_instances[0] + self.assertEqual( + int(first_instance["instance_id"]), + 0, + f"Expected instance 0 to start first, but instance {first_instance['instance_id']} started first", + ) @parameterized.parameters( # type: ignore[misc] ("src", "src/entry.py", [(TestAsset("src/subdir/utils").path.as_posix(), "src.subdir.utils")]), @@ -1174,17 +1199,9 @@ def test_submit_with_additional_payloads_stage(self) -> None: stage_path = f"{self.session.get_session_stage()}/{str(uuid4())}" upload_files = TestAsset("src") - for path in { - p.parent.joinpath(f"*{p.suffix}") if p.suffix else p - for p in upload_files.path.resolve().rglob("*") - if p.is_file() - }: - self.session.file.put( - str(path), - pathlib.Path(stage_path).joinpath(path.parent.relative_to(upload_files.path)).as_posix(), - overwrite=True, - auto_compress=False, - ) + payload_utils.upload_payloads( + self.session, pathlib.PurePath(stage_path), types.PayloadSpec(upload_files.path, None) + ) test_cases = [ (f"{stage_path}/", f"{stage_path}/entry.py", [(f"{stage_path}/subdir/utils", "src.subdir.utils")]), @@ -1224,11 +1241,54 @@ def test_requirements_non_overwrite(self) -> None: entrypoint="main.py", stage_name="payload_stage", external_access_integrations=pypi_eais, + session=self.session, ) self.assertEqual(job.wait(), "DONE", job.get_logs()) self.assertIn("Numpy version: 1.23", job.get_logs()) self.assertIn(f"Cloudpickle version: {version.parse(cp.__version__).major}.", job.get_logs()) + def test_submit_with_hidden_files(self) -> None: + job = jobs.submit_directory( + TestAsset("src/subdir6").path, + self.compute_pool, + entrypoint="main.py", + stage_name="payload_stage", + session=self.session, + ) + self.assertEqual(job.wait(), "DONE", job.get_logs()) + self.assertIn("This is a secret message stored in a hidden YAML file", job.get_logs()) + self.assertIn("This is the content of a hidden file with no extension", job.get_logs()) + + def test_job_with_different_python_version(self) -> None: + target_version = f"{sys.version_info.major}.{sys.version_info.minor}" + resources = spec_utils._get_node_resources(self.session, self.compute_pool) + hardware = "GPU" if resources.gpu > 0 else "CPU" + try: + expected_runtime_image = spec_utils._get_runtime_image(self.session, hardware) + except Exception: + expected_runtime_image = None + + with mock.patch.dict(os.environ, {constants.ENABLE_IMAGE_VERSION_ENV_VAR: "True"}): + job = jobs.submit_file( + TestAsset("src/check_python.py").path, + self.compute_pool, + stage_name="payload_stage", + session=self.session, + ) + self.assertEqual(job.wait(), "DONE", job.get_logs()) + if expected_runtime_image: + self.assertIn( + target_version, + job.get_logs(), + f"Expected Python {target_version} when matching runtime available: {expected_runtime_image}", + ) + else: + self.assertIn( + "3.10", + job.get_logs(), + "Expected fallback to default Python version when no matching runtime available", + ) + if __name__ == "__main__": absltest.main() diff --git a/tests/integ/snowflake/ml/jobs/payload_utils_integ_test.py b/tests/integ/snowflake/ml/jobs/payload_utils_integ_test.py index 4acf45ec..341429c0 100644 --- a/tests/integ/snowflake/ml/jobs/payload_utils_integ_test.py +++ b/tests/integ/snowflake/ml/jobs/payload_utils_integ_test.py @@ -4,7 +4,7 @@ from absl.testing import absltest, parameterized -from snowflake.ml.jobs._utils import constants, payload_utils +from snowflake.ml.jobs._utils import constants, payload_utils, types from tests.integ.snowflake.ml.jobs.test_file_helper import TestAsset from tests.integ.snowflake.ml.test_utils import db_manager, test_env_utils @@ -83,17 +83,17 @@ def test_upload_payload_negative( (TestAsset("src/main.py"), TestAsset("src/main.py"), "/mnt/job_stage/app/main.py", 1), (TestAsset("src/main.py"), None, "/mnt/job_stage/app/main.py", 1), # Entrypoint as relative path inside payload directory - (TestAsset("src"), TestAsset("main.py", resolve_path=False), "/mnt/job_stage/app/main.py", 16), + (TestAsset("src"), TestAsset("main.py", resolve_path=False), "/mnt/job_stage/app/main.py", 20), ( TestAsset("src"), TestAsset("subdir/sub_main.py", resolve_path=False), "/mnt/job_stage/app/subdir/sub_main.py", - 16, + 20, ), (TestAsset("src/subdir"), TestAsset("sub_main.py", resolve_path=False), "/mnt/job_stage/app/sub_main.py", 2), # Entrypoint as absolute path - (TestAsset("src"), TestAsset("src/main.py"), "/mnt/job_stage/app/main.py", 16), - (TestAsset("src"), TestAsset("src/subdir/sub_main.py"), "/mnt/job_stage/app/subdir/sub_main.py", 16), + (TestAsset("src"), TestAsset("src/main.py"), "/mnt/job_stage/app/main.py", 20), + (TestAsset("src"), TestAsset("src/subdir/sub_main.py"), "/mnt/job_stage/app/subdir/sub_main.py", 20), (TestAsset("src/subdir"), TestAsset("src/subdir/sub_main.py"), "/mnt/job_stage/app/sub_main.py", 2), # Function as payload (function_with_pos_arg, pathlib.Path("function_payload.py"), "/mnt/job_stage/app/function_payload.py", 1), @@ -135,14 +135,14 @@ def test_upload_payload( 1, ), (TestAsset("src/main.py"), f"@{_TEST_STAGE}/main.py", None, "/mnt/job_stage/app/main.py", 1), - (TestAsset("src"), f"@{_TEST_STAGE}/main.py", None, "/mnt/job_stage/app/main.py", 16), - (TestAsset("src"), f"@{_TEST_STAGE}/", f"@{_TEST_STAGE}/main.py", "/mnt/job_stage/app/main.py", 16), + (TestAsset("src"), f"@{_TEST_STAGE}/main.py", None, "/mnt/job_stage/app/main.py", 20), + (TestAsset("src"), f"@{_TEST_STAGE}/", f"@{_TEST_STAGE}/main.py", "/mnt/job_stage/app/main.py", 20), ( TestAsset("src"), f"@{_TEST_STAGE}/", f"@{_TEST_STAGE}/subdir/sub_main.py", "/mnt/job_stage/app/subdir/sub_main.py", - 16, + 20, ), ( TestAsset("src"), @@ -161,25 +161,9 @@ def test_copy_payload_positive( expected_file_count: int, ) -> None: stage_path = f"{self.session.get_session_stage()}/{str(uuid4())}" - if upload_files.path.is_dir(): - for path in { - p.parent.joinpath(f"*{p.suffix}") if p.suffix else p - for p in upload_files.path.resolve().rglob("*") - if p.is_file() - }: - self.session.file.put( - str(path), - pathlib.Path(_TEST_STAGE).joinpath(path.parent.relative_to(upload_files.path)).as_posix(), - overwrite=True, - auto_compress=False, - ) - else: - self.session.file.put( - str(upload_files.path.resolve()), - f"{_TEST_STAGE}", - overwrite=True, - auto_compress=False, - ) + payload_utils.upload_payloads( + self.session, pathlib.Path(_TEST_STAGE), types.PayloadSpec(upload_files.path, None) + ) payload = payload_utils.JobPayload( source=source, entrypoint=entrypoint, @@ -204,7 +188,7 @@ def test_copy_payload_positive( TestAsset("src/third.py"), [(TestAsset("src/subdir/utils").path.as_posix(), "remote.src.utils")], "/mnt/job_stage/app/third.py", - 17, + 21, ), ) def test_upload_payload_additional_packages_local( @@ -236,18 +220,9 @@ def test_upload_payload_additional_packages_local( def test_upload_payload_additional_packages_stage(self) -> None: upload_files = TestAsset("src") - - for path in { - p.parent.joinpath(f"*{p.suffix}") if p.suffix else p - for p in upload_files.path.resolve().rglob("*") - if p.is_file() - }: - self.session.file.put( - str(path), - pathlib.Path(_TEST_STAGE).joinpath(path.parent.relative_to(upload_files.path)).as_posix(), - overwrite=True, - auto_compress=False, - ) + payload_utils.upload_payloads( + self.session, pathlib.Path(_TEST_STAGE), types.PayloadSpec(upload_files.path, None) + ) test_cases = [ ( @@ -255,7 +230,7 @@ def test_upload_payload_additional_packages_stage(self) -> None: f"@{_TEST_STAGE}/subdir/sub_main.py", [(f"@{_TEST_STAGE}/subdir/utils", "remote.subdir.utils")], "/mnt/job_stage/app/subdir/sub_main.py", - 17, + 21, ), ( f"@{_TEST_STAGE}/subdir", diff --git a/tests/integ/snowflake/ml/jobs/test_files/src/check_python.py b/tests/integ/snowflake/ml/jobs/test_files/src/check_python.py new file mode 100644 index 00000000..04292727 --- /dev/null +++ b/tests/integ/snowflake/ml/jobs/test_files/src/check_python.py @@ -0,0 +1,3 @@ +import sys + +print(f"Python version: {sys.version}") diff --git a/tests/integ/snowflake/ml/jobs/test_files/src/subdir6/.config.yaml b/tests/integ/snowflake/ml/jobs/test_files/src/subdir6/.config.yaml new file mode 100644 index 00000000..64ffc996 --- /dev/null +++ b/tests/integ/snowflake/ml/jobs/test_files/src/subdir6/.config.yaml @@ -0,0 +1,2 @@ +--- +secret_message: This is a secret message stored in a hidden YAML file. diff --git a/tests/integ/snowflake/ml/jobs/test_files/src/subdir6/.no_ext b/tests/integ/snowflake/ml/jobs/test_files/src/subdir6/.no_ext new file mode 100644 index 00000000..f2ef01ab --- /dev/null +++ b/tests/integ/snowflake/ml/jobs/test_files/src/subdir6/.no_ext @@ -0,0 +1 @@ +This is the content of a hidden file with no extension. diff --git a/tests/integ/snowflake/ml/jobs/test_files/src/subdir6/main.py b/tests/integ/snowflake/ml/jobs/test_files/src/subdir6/main.py new file mode 100644 index 00000000..f939dd42 --- /dev/null +++ b/tests/integ/snowflake/ml/jobs/test_files/src/subdir6/main.py @@ -0,0 +1,9 @@ +import yaml + +with open(".config.yaml", encoding="utf-8") as f: + config = yaml.safe_load(f) + +print(config["secret_message"]) +with open(".no_ext", encoding="utf-8") as f: + content = f.read() +print(f"{content}") diff --git a/tests/integ/snowflake/ml/registry/services/registry_model_deployment_test_base.py b/tests/integ/snowflake/ml/registry/services/registry_model_deployment_test_base.py index 78bbfa9e..a21745f9 100644 --- a/tests/integ/snowflake/ml/registry/services/registry_model_deployment_test_base.py +++ b/tests/integ/snowflake/ml/registry/services/registry_model_deployment_test_base.py @@ -123,11 +123,12 @@ def _deploy_model_with_image_override( version_name=mv._version_name, ) + image_repo_fqn = identifier.get_schema_level_object_identifier( + database_name_id.identifier(), schema_name_id.identifier(), image_repo_name.identifier() + ) mv._service_ops._model_deployment_spec.add_image_build_spec( - image_repo_database_name=database_name_id, - image_repo_schema_name=schema_name_id, - image_repo_name=image_repo_name, image_build_compute_pool_name=build_compute_pool, + fully_qualified_image_repo_name=image_repo_fqn, force_rebuild=force_rebuild, external_access_integrations=None, )