diff --git a/.bazelrc b/.bazelrc index 5dd8686c..0493a227 100644 --- a/.bazelrc +++ b/.bazelrc @@ -31,13 +31,13 @@ build:py3.12 --repo_env=BAZEL_CONDA_PYTHON_VERSION=3.12 build:build --config=_build # Config to sync files -run:pre_build --config=_build --config=py3.9 +run:pre_build --config=_build --config=py3.10 # Config to run type check -build:typecheck --aspects @rules_mypy//:mypy.bzl%mypy_aspect --output_groups=mypy --config=_all --config=py3.9 +build:typecheck --aspects @rules_mypy//:mypy.bzl%mypy_aspect --output_groups=mypy --config=_all --config=py3.10 # Config to build the doc -build:docs --config=_all --config=py3.9 +build:docs --config=_all --config=py3.10 # Public the extended setting diff --git a/CHANGELOG.md b/CHANGELOG.md index db3e8d38..6d631c11 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,15 @@ # Release History +## 1.14.0 + +### Bug Fixes + +### Behavior Changes + +### New Features + +* ML Job: Added the `imports` argument to the job submission APIs; `additional_payloads` is now **deprecated** in its favor. + ## 1.13.0 ### Bug Fixes diff --git a/bazel/environments/fetch_conda_env_config.bzl b/bazel/environments/fetch_conda_env_config.bzl index bc453220..4a4324fc 100755 --- a/bazel/environments/fetch_conda_env_config.bzl +++ b/bazel/environments/fetch_conda_env_config.bzl @@ -3,7 +3,7 @@ load("//bazel/platforms:optional_dependency_groups.bzl", "OPTIONAL_DEPENDENCY_GR def _fetch_conda_env_config_impl(rctx): # read the particular environment variable we are interested in env_name = rctx.os.environ.get("BAZEL_CONDA_ENV_NAME", "core").lower() - python_ver = rctx.os.environ.get("BAZEL_CONDA_PYTHON_VERSION", "3.9").lower() + python_ver = rctx.os.environ.get("BAZEL_CONDA_PYTHON_VERSION", "3.10").lower() # necessary to create empty BUILD file for this rule # which will be located somewhere in the Bazel build files diff --git a/bazel/requirements/templates/bazelrc.tpl b/bazel/requirements/templates/bazelrc.tpl index 503e1d73..329a9284 100644 --- a/bazel/requirements/templates/bazelrc.tpl +++ b/bazel/requirements/templates/bazelrc.tpl @@ -28,13 +28,13 @@ build:py3.12 --repo_env=BAZEL_CONDA_PYTHON_VERSION=3.12 build:build --config=_build # Config to sync files -run:pre_build --config=_build --config=py3.9 +run:pre_build --config=_build --config=py3.10 # Config to run type check -build:typecheck --aspects @rules_mypy//:mypy.bzl%mypy_aspect --output_groups=mypy --config=_all --config=py3.9 +build:typecheck --aspects @rules_mypy//:mypy.bzl%mypy_aspect --output_groups=mypy --config=_all --config=py3.10 # Config to build the doc -build:docs --config=_all --config=py3.9 +build:docs --config=_all --config=py3.10 # Public the extended setting diff --git a/ci/build_and_run_tests.sh b/ci/build_and_run_tests.sh index 25939880..5f5d2821 100755 --- a/ci/build_and_run_tests.sh +++ b/ci/build_and_run_tests.sh @@ -42,7 +42,7 @@ WITH_SNOWPARK=false WITH_SPCS_IMAGE=false RUN_GRYPE=false MODE="continuous_run" -PYTHON_VERSION=3.9 +PYTHON_VERSION=3.10 PYTHON_ENABLE_SCRIPT="bin/activate" SNOWML_DIR="snowml" SNOWPARK_DIR="snowpark-python" diff --git a/ci/conda_recipe/meta.yaml b/ci/conda_recipe/meta.yaml index 8025a9f0..78c1a9d6 100644 --- a/ci/conda_recipe/meta.yaml +++ b/ci/conda_recipe/meta.yaml @@ -17,7 +17,7 @@ build: noarch: python package: name: snowflake-ml-python - version: 1.13.0 + version: 1.14.0 requirements: build: - python diff --git a/snowflake/ml/jobs/_utils/constants.py b/snowflake/ml/jobs/_utils/constants.py index a3cc5fba..1e9ebe42 100644 --- a/snowflake/ml/jobs/_utils/constants.py +++
b/snowflake/ml/jobs/_utils/constants.py @@ -25,7 +25,7 @@ DEFAULT_IMAGE_REPO = "/snowflake/images/snowflake_images" DEFAULT_IMAGE_CPU = "st_plat/runtime/x86/runtime_image/snowbooks" DEFAULT_IMAGE_GPU = "st_plat/runtime/x86/generic_gpu/runtime_image/snowbooks" -DEFAULT_IMAGE_TAG = "1.6.2" +DEFAULT_IMAGE_TAG = "1.8.0" DEFAULT_ENTRYPOINT_PATH = "func.py" # Percent of container memory to allocate for /dev/shm volume diff --git a/snowflake/ml/jobs/_utils/scripts/mljob_launcher.py b/snowflake/ml/jobs/_utils/scripts/mljob_launcher.py index e9986e87..7e802b86 100644 --- a/snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +++ b/snowflake/ml/jobs/_utils/scripts/mljob_launcher.py @@ -234,12 +234,6 @@ def run_script(script_path: str, *script_args: Any, main_func: Optional[str] = N if payload_dir and payload_dir not in sys.path: sys.path.insert(0, payload_dir) - # Create a Snowpark session before running the script - # Session can be retrieved from using snowflake.snowpark.context.get_active_session() - config = SnowflakeLoginOptions() - config["client_session_keep_alive"] = "True" - session = Session.builder.configs(config).create() # noqa: F841 - try: if main_func: @@ -266,7 +260,6 @@ def run_script(script_path: str, *script_args: Any, main_func: Optional[str] = N finally: # Restore original sys.argv sys.argv = original_argv - session.close() def main(script_path: str, *script_args: Any, script_main_func: Optional[str] = None) -> ExecutionResult: @@ -297,6 +290,12 @@ def main(script_path: str, *script_args: Any, script_main_func: Optional[str] = except ModuleNotFoundError: warnings.warn("Ray is not installed, skipping Ray initialization", ImportWarning, stacklevel=1) + # Create a Snowpark session before starting + # Session can be retrieved using snowflake.snowpark.context.get_active_session() + config = SnowflakeLoginOptions() + config["client_session_keep_alive"] = "True" + session = Session.builder.configs(config).create() # noqa: F841 + try: # Wait for minimum required instances if specified min_instances_str = os.environ.get(MIN_INSTANCES_ENV_VAR) or "1" @@ -352,6 +351,9 @@ def main(script_path: str, *script_args: Any, script_main_func: Optional[str] = f"Failed to serialize JSON result to {result_json_path}: {json_exc}", RuntimeWarning, stacklevel=1 ) + # Close the session after serializing the result + session.close() + if __name__ == "__main__": # Parse command line arguments diff --git a/snowflake/ml/jobs/job.py b/snowflake/ml/jobs/job.py index cfea8a16..3188ac7e 100644 --- a/snowflake/ml/jobs/job.py +++ b/snowflake/ml/jobs/job.py @@ -83,6 +83,8 @@ def _service_spec(self) -> dict[str, Any]: def _container_spec(self) -> dict[str, Any]: """Get the job's main container spec.""" containers = self._service_spec["spec"]["containers"] + if len(containers) == 1: + return cast(dict[str, Any], containers[0]) try: container_spec = next(c for c in containers if c["name"] == constants.DEFAULT_CONTAINER_NAME) except StopIteration: @@ -163,7 +165,7 @@ def get_logs( Returns: The job's execution logs.
""" - logs = _get_logs(self._session, self.id, limit, instance_id, verbose) + logs = _get_logs(self._session, self.id, limit, instance_id, self._container_spec["name"], verbose) assert isinstance(logs, str) # mypy if as_list: return logs.splitlines() @@ -281,7 +283,12 @@ def _get_service_spec(session: snowpark.Session, job_id: str) -> dict[str, Any]: @telemetry.send_api_usage_telemetry(project=_PROJECT, func_params_to_log=["job_id", "limit", "instance_id"]) def _get_logs( - session: snowpark.Session, job_id: str, limit: int = -1, instance_id: Optional[int] = None, verbose: bool = True + session: snowpark.Session, + job_id: str, + limit: int = -1, + instance_id: Optional[int] = None, + container_name: str = constants.DEFAULT_CONTAINER_NAME, + verbose: bool = True, ) -> str: """ Retrieve the job's execution logs. @@ -291,6 +298,7 @@ def _get_logs( limit: The maximum number of lines to return. Negative values are treated as no limit. session: The Snowpark session to use. If none specified, uses active session. instance_id: Optional instance ID to get logs from a specific instance. + container_name: The container name to get logs from a specific container. verbose: Whether to return the full log or just the portion between START and END messages. Returns: @@ -311,7 +319,7 @@ def _get_logs( params: list[Any] = [ job_id, 0 if instance_id is None else instance_id, - constants.DEFAULT_CONTAINER_NAME, + container_name, ] if limit > 0: params.append(limit) @@ -337,7 +345,7 @@ def _get_logs( job_id, limit=limit, instance_id=instance_id if instance_id else 0, - container_name=constants.DEFAULT_CONTAINER_NAME, + container_name=container_name, ) full_log = os.linesep.join(row[0] for row in logs) diff --git a/snowflake/ml/jobs/jobs_test.py b/snowflake/ml/jobs/jobs_test.py index a3900ece..b1ee96df 100644 --- a/snowflake/ml/jobs/jobs_test.py +++ b/snowflake/ml/jobs/jobs_test.py @@ -10,6 +10,13 @@ from snowflake.snowpark import exceptions as sp_exceptions from snowflake.snowpark.row import Row +SERVICE_SPEC = """ +spec: + containers: + - name: main + image: test-image +""" + class JobTest(parameterized.TestCase): @parameterized.named_parameters( # type: ignore[misc] @@ -83,7 +90,7 @@ def test_get_logs_negative(self) -> None: def sql_side_effect(session: snowpark.Session, query_str: str, *args: Any, **kwargs: Any) -> Any: if query_str.startswith("DESCRIBE SERVICE IDENTIFIER"): - return [Row(target_instances=2)] + return [Row(target_instances=2, spec=SERVICE_SPEC)] else: raise sp_exceptions.SnowparkSQLException("Waiting to start, Container Status: PENDING") @@ -97,7 +104,7 @@ def test_get_logs_from_event_table(self) -> None: def sql_side_effect(session: snowpark.Session, query_str: str, *args: Any, **kwargs: Any) -> Any: if query_str.startswith("DESCRIBE SERVICE IDENTIFIER"): return [ - Row(target_instances=2), + Row(target_instances=2, spec=SERVICE_SPEC), ] elif query_str.startswith("SELECT VALUE FROM "): return [ diff --git a/snowflake/ml/jobs/manager.py b/snowflake/ml/jobs/manager.py index 4d064d31..f9205736 100644 --- a/snowflake/ml/jobs/manager.py +++ b/snowflake/ml/jobs/manager.py @@ -232,6 +232,7 @@ def submit_file( enable_metrics (bool): Whether to enable metrics publishing for the job. query_warehouse (str): The query warehouse to use. Defaults to session warehouse. spec_overrides (dict): A dictionary of overrides for the service spec. + imports (list[Union[tuple[str, str], tuple[str]]]): A list of additional payloads used in the job. Returns: An object representing the submitted job. 
@@ -286,6 +287,7 @@ def submit_directory( enable_metrics (bool): Whether to enable metrics publishing for the job. query_warehouse (str): The query warehouse to use. Defaults to session warehouse. spec_overrides (dict): A dictionary of overrides for the service spec. + imports (list[Union[tuple[str, str], tuple[str]]]): A list of additional payloads used in the job. Returns: An object representing the submitted job. @@ -341,6 +343,7 @@ def submit_from_stage( enable_metrics (bool): Whether to enable metrics publishing for the job. query_warehouse (str): The query warehouse to use. Defaults to session warehouse. spec_overrides (dict): A dictionary of overrides for the service spec. + imports (list[Union[tuple[str, str], tuple[str]]]): A list of additional payloads used in the job. Returns: An object representing the submitted job. @@ -404,6 +407,8 @@ def _submit_job( "num_instances", # deprecated "target_instances", "min_instances", + "enable_metrics", + "query_warehouse", ], ) def _submit_job( @@ -447,6 +452,13 @@ def _submit_job( ) target_instances = max(target_instances, kwargs.pop("num_instances")) + imports = None + if "additional_payloads" in kwargs: + logger.warning( + "'additional_payloads' is deprecated and will be removed in a future release. Use 'imports' instead." + ) + imports = kwargs.pop("additional_payloads") + # Use kwargs for less common optional parameters database = kwargs.pop("database", None) schema = kwargs.pop("schema", None) @@ -457,10 +469,7 @@ def _submit_job( spec_overrides = kwargs.pop("spec_overrides", None) enable_metrics = kwargs.pop("enable_metrics", True) query_warehouse = kwargs.pop("query_warehouse", session.get_current_warehouse()) - additional_payloads = kwargs.pop("additional_payloads", None) - - if additional_payloads: - logger.warning("'additional_payloads' is in private preview since 1.9.1. Do not use it in production.") + imports = kwargs.pop("imports", None) or imports # Warn if there are unknown kwargs if kwargs: @@ -492,7 +501,7 @@ def _submit_job( try: # Upload payload uploaded_payload = payload_utils.JobPayload( - source, entrypoint=entrypoint, pip_requirements=pip_requirements, additional_payloads=additional_payloads + source, entrypoint=entrypoint, pip_requirements=pip_requirements, additional_payloads=imports ).upload(session, stage_path) except snowpark.exceptions.SnowparkSQLException as e: if e.sql_error_code == 90106: raise RuntimeError( ) raise + # FIXME: Temporary patches, remove this after v1 is deprecated + if target_instances > 1: + default_spec_overrides = { + "spec": { + "endpoints": [ + {"name": "ray-dashboard-endpoint", "port": 12003, "protocol": "TCP"}, + ] + }, + } + if spec_overrides: + spec_overrides = spec_utils.merge_patch( + default_spec_overrides, spec_overrides, display_name="spec_overrides" + ) + else: + spec_overrides = default_spec_overrides + if feature_flags.FeatureFlags.USE_SUBMIT_JOB_V2.is_enabled(): # Add default env vars (extracted from spec_utils.generate_service_spec) combined_env_vars = {**uploaded_payload.env_vars, **(env_vars or {})} @@ -668,8 +693,10 @@ def _ensure_session(session: Optional[snowpark.Session]) -> snowpark.Session: session = session or get_active_session() except snowpark.exceptions.SnowparkSessionException as e: if "More than one active session" in e.message: - raise RuntimeError("Please specify the session as a parameter in API call") + raise RuntimeError( + "More than one active session found.
Please specify the session explicitly as a parameter" ) from None if "No default Session is found" in e.message: - raise RuntimeError("Please create a session before API call") + raise RuntimeError("No active session found. Please create a session") from None raise return session diff --git a/snowflake/ml/lineage/lineage_node.py b/snowflake/ml/lineage/lineage_node.py index 219760a8..184d3398 100644 --- a/snowflake/ml/lineage/lineage_node.py +++ b/snowflake/ml/lineage/lineage_node.py @@ -83,7 +83,6 @@ def _load_from_lineage_node(session: snowpark.Session, name: str, version: str) raise NotImplementedError() @telemetry.send_api_usage_telemetry(project=_PROJECT) - @snowpark._internal.utils.private_preview(version="1.5.3") def lineage( self, direction: Literal["upstream", "downstream"] = "downstream", diff --git a/snowflake/ml/lineage/notebooks/ML Lineage Workflows.ipynb b/snowflake/ml/lineage/notebooks/ML Lineage Workflows.ipynb index 4d50b555..95b929cc 100644 --- a/snowflake/ml/lineage/notebooks/ML Lineage Workflows.ipynb +++ b/snowflake/ml/lineage/notebooks/ML Lineage Workflows.ipynb @@ -1774,7 +1774,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3 (ipykernel)", + "display_name": "py38_env", "language": "python", "name": "python3" }, diff --git a/snowflake/ml/model/_client/model/model_version_impl.py b/snowflake/ml/model/_client/model/model_version_impl.py index c4d05e32..59d023cd 100644 --- a/snowflake/ml/model/_client/model/model_version_impl.py +++ b/snowflake/ml/model/_client/model/model_version_impl.py @@ -788,7 +788,7 @@ def _enrich_inference_engine_args( inference_engine_args: service_ops.InferenceEngineArgs, gpu_requests: Optional[Union[str, int]] = None, ) -> Optional[service_ops.InferenceEngineArgs]: - """Enrich inference engine args with model path and tensor parallelism settings. + """Enrich inference engine args with tensor parallelism settings.
Args: inference_engine_args: The original inference engine args @@ -803,21 +803,6 @@ def _enrich_inference_engine_args( if inference_engine_args.inference_engine_args_override is None: inference_engine_args.inference_engine_args_override = [] - # Get model stage path and strip off "snow://" prefix - model_stage_path = self._model_ops.get_model_version_stage_path( - database_name=None, - schema_name=None, - model_name=self._model_name, - version_name=self._version_name, - ) - - # Strip "snow://" prefix - if model_stage_path.startswith("snow://"): - model_stage_path = model_stage_path.replace("snow://", "", 1) - - # Always overwrite the model key by appending - inference_engine_args.inference_engine_args_override.append(f"--model={model_stage_path}") - gpu_count = None # Set tensor-parallelism if gpu_requests is specified diff --git a/snowflake/ml/model/_client/model/model_version_impl_test.py b/snowflake/ml/model/_client/model/model_version_impl_test.py index f41093f0..09df54e9 100644 --- a/snowflake/ml/model/_client/model/model_version_impl_test.py +++ b/snowflake/ml/model/_client/model/model_version_impl_test.py @@ -1018,11 +1018,6 @@ def test_create_service_with_experimental_options(self) -> None: mock.patch.object( self.m_mv, "_check_huggingface_text_generation_model" ) as mock_check_huggingface_text_generation_model, - mock.patch.object( - self.m_mv._model_ops, - "get_model_version_stage_path", - return_value="snow://model/DB.SCHEMA.MODEL/versions/v1/", - ), mock.patch.object( self.m_mv._model_ops, "_fetch_model_spec", @@ -1065,7 +1060,6 @@ def test_create_service_with_experimental_options(self) -> None: expected_args = [ "--max_tokens=1000", "--temperature=0.8", - "--model=model/DB.SCHEMA.MODEL/versions/v1/", "--tensor-parallel-size=4", ] @@ -1347,118 +1341,103 @@ def test_get_inference_engine_args(self) -> None: def test_enrich_inference_engine_args(self) -> None: """Test _enrich_inference_engine_args method with various inputs.""" - # Mock get_model_version_stage_path to return a predictable path - with mock.patch.object( - self.m_mv._model_ops, - "get_model_version_stage_path", - return_value="snow://model/TEMP.test.MODEL/versions/v1/", - ) as mock_get_path: - # Test with args=None and no GPU - enriched = self.m_mv._enrich_inference_engine_args( - service_ops.InferenceEngineArgs( - inference_engine=inference_engine.InferenceEngine.VLLM, - inference_engine_args_override=None, - ), - gpu_requests=None, - ) - assert enriched is not None - self.assertEqual(enriched.inference_engine, inference_engine.InferenceEngine.VLLM) - self.assertEqual(enriched.inference_engine_args_override, ["--model=model/TEMP.test.MODEL/versions/v1/"]) - mock_get_path.assert_called_with( - database_name=None, - schema_name=None, - model_name=sql_identifier.SqlIdentifier("MODEL"), - version_name=sql_identifier.SqlIdentifier("v1", case_sensitive=True), - ) + # Test with args=None and no GPU + enriched = self.m_mv._enrich_inference_engine_args( + service_ops.InferenceEngineArgs( + inference_engine=inference_engine.InferenceEngine.VLLM, + inference_engine_args_override=None, + ), + gpu_requests=None, + ) + assert enriched is not None + self.assertEqual(enriched.inference_engine, inference_engine.InferenceEngine.VLLM) + self.assertEqual(enriched.inference_engine_args_override, []) + + # Test with empty args and GPU count + enriched = self.m_mv._enrich_inference_engine_args( + service_ops.InferenceEngineArgs( + inference_engine=inference_engine.InferenceEngine.VLLM, + inference_engine_args_override=None, + ), + 
gpu_requests=2, + ) + assert enriched is not None + self.assertEqual(enriched.inference_engine, inference_engine.InferenceEngine.VLLM) + self.assertEqual(enriched.inference_engine_args_override, ["--tensor-parallel-size=2"]) + + # Test with args and string GPU count + original_args = ["--max_tokens=100", "--temperature=0.7"] + enriched = self.m_mv._enrich_inference_engine_args( + service_ops.InferenceEngineArgs( + inference_engine=inference_engine.InferenceEngine.VLLM, + inference_engine_args_override=original_args, + ), + gpu_requests="4", + ) + assert enriched is not None + self.assertEqual( + enriched, + service_ops.InferenceEngineArgs( + inference_engine=inference_engine.InferenceEngine.VLLM, + inference_engine_args_override=[ + "--max_tokens=100", + "--temperature=0.7", + "--tensor-parallel-size=4", + ], + ), + ) + + # Test that existing args are preserved unchanged when no GPU is requested + enriched = self.m_mv._enrich_inference_engine_args( + service_ops.InferenceEngineArgs( + inference_engine=inference_engine.InferenceEngine.VLLM, + inference_engine_args_override=[ + "--max_tokens=100", + "--temperature=0.7", + ], + ), + ) + self.assertEqual( + enriched, + service_ops.InferenceEngineArgs( + inference_engine=inference_engine.InferenceEngine.VLLM, + inference_engine_args_override=[ + "--max_tokens=100", + "--temperature=0.7", + ], + ), + ) - # Test with empty args and GPU count - enriched = self.m_mv._enrich_inference_engine_args( + # Test with invalid GPU string + with self.assertRaises(ValueError): + self.m_mv._enrich_inference_engine_args( service_ops.InferenceEngineArgs( inference_engine=inference_engine.InferenceEngine.VLLM, inference_engine_args_override=None, ), - gpu_requests=2, + gpu_requests="invalid", ) - assert enriched is not None - self.assertEqual(enriched.inference_engine, inference_engine.InferenceEngine.VLLM) - # Test with args and string GPU count - original_args = ["--max_tokens=100", "--temperature=0.7"] - enriched = self.m_mv._enrich_inference_engine_args( + # Test with zero GPU (should raise ValueError) + with self.assertRaises(ValueError): + self.m_mv._enrich_inference_engine_args( service_ops.InferenceEngineArgs( inference_engine=inference_engine.InferenceEngine.VLLM, - inference_engine_args_override=original_args, - ), - gpu_requests="4", - ) - assert enriched is not None - self.assertEqual( - enriched, - service_ops.InferenceEngineArgs( - inference_engine=inference_engine.InferenceEngine.VLLM, - inference_engine_args_override=[ - "--max_tokens=100", - "--temperature=0.7", - "--model=model/TEMP.test.MODEL/versions/v1/", - "--tensor-parallel-size=4", - ], + inference_engine_args_override=None, ), + gpu_requests=0, ) - # Test overwriting existing model key with new model key by appending to the list - enriched = self.m_mv._enrich_inference_engine_args( - service_ops.InferenceEngineArgs( - inference_engine=inference_engine.InferenceEngine.VLLM, - inference_engine_args_override=[ - "--max_tokens=100", - "--temperature=0.7", - "--model=old/path", - ], - ), - ) - self.assertEqual( - enriched, + # Test with negative GPU (should raise ValueError) + with self.assertRaises(ValueError): + self.m_mv._enrich_inference_engine_args( service_ops.InferenceEngineArgs( inference_engine=inference_engine.InferenceEngine.VLLM, - inference_engine_args_override=[ - "--max_tokens=100", - "--temperature=0.7", - "--model=old/path", - "--model=model/TEMP.test.MODEL/versions/v1/", - ], + inference_engine_args_override=None, ), + gpu_requests=-1, ) - #
Test with invalid GPU string - with self.assertRaises(ValueError): - self.m_mv._enrich_inference_engine_args( - service_ops.InferenceEngineArgs( - inference_engine=inference_engine.InferenceEngine.VLLM, - inference_engine_args_override=None, - ), - gpu_requests="invalid", - ) - - # Test with zero GPU (should not set tensor-parallel-size) - with self.assertRaises(ValueError): - self.m_mv._enrich_inference_engine_args( - service_ops.InferenceEngineArgs( - inference_engine=inference_engine.InferenceEngine.VLLM, - inference_engine_args_override=None, - ), - gpu_requests=0, - ) - - # Test with negative GPU (should not set tensor-parallel-size) - with self.assertRaises(ValueError): - self.m_mv._enrich_inference_engine_args( - service_ops.InferenceEngineArgs( - inference_engine=inference_engine.InferenceEngine.VLLM, - inference_engine_args_override=None, - ), - gpu_requests=-1, - ) - def test_check_huggingface_text_generation_model(self) -> None: """Test _check_huggingface_text_generation_model method.""" # Test successful case - HuggingFace text-generation model diff --git a/snowflake/ml/model/_client/ops/model_ops.py b/snowflake/ml/model/_client/ops/model_ops.py index 4b3c4a9c..67bc13cf 100644 --- a/snowflake/ml/model/_client/ops/model_ops.py +++ b/snowflake/ml/model/_client/ops/model_ops.py @@ -47,7 +47,8 @@ class ServiceInfo(TypedDict): class ModelOperator: INFERENCE_SERVICE_ENDPOINT_NAME = "inference" INGRESS_ENDPOINT_URL_SUFFIX = "snowflakecomputing.app" - PRIVATELINK_INGRESS_ENDPOINT_URL_SUBSTRING = "privatelink.snowflakecomputing" + # app-service-privatelink might not contain "snowflakecomputing" in the url - using the minimum required substring + PRIVATELINK_INGRESS_ENDPOINT_URL_SUBSTRING = "privatelink.snowflake" def __init__( self, @@ -631,7 +632,13 @@ def _extract_and_validate_ingress_url(self, res_row: "row.Row") -> Optional[str] def _extract_and_validate_privatelink_url(self, res_row: "row.Row") -> Optional[str]: """Extract and validate privatelink ingress URL from endpoint row.""" - url_value = res_row[self._service_client.MODEL_INFERENCE_SERVICE_ENDPOINT_PRIVATELINK_INGRESS_URL_COL_NAME] + # Check if the privatelink_ingress_url column exists + col_name = self._service_client.MODEL_INFERENCE_SERVICE_ENDPOINT_PRIVATELINK_INGRESS_URL_COL_NAME + if col_name not in res_row: + # Column doesn't exist in query result for non-Business Critical accounts + return None + + url_value = res_row[col_name] if url_value is None: return None url_str = str(url_value) diff --git a/snowflake/ml/model/_client/ops/model_ops_test.py b/snowflake/ml/model/_client/ops/model_ops_test.py index 392540bb..76bc1a4f 100644 --- a/snowflake/ml/model/_client/ops/model_ops_test.py +++ b/snowflake/ml/model/_client/ops/model_ops_test.py @@ -468,12 +468,13 @@ def test_unset_tag(self) -> None: def test_show_services_1(self) -> None: m_services_list_res = [Row(inference_services='["a.b.c", "d.e.f"]')] + # Row objects with privatelink_ingress_url field for Business Critical accounts m_endpoints_list_res_0 = [Row(name="inference", ingress_url="Waiting", privatelink_ingress_url=None)] m_endpoints_list_res_1 = [ Row( name="inference", ingress_url="foo.snowflakecomputing.app", - privatelink_ingress_url="privatelink.foo.privatelink.snowflakecomputing.com", + privatelink_ingress_url="bar.privatelink.snowflakecomputing.com", ) ] m_statuses_0 = [ @@ -512,6 +513,7 @@ def test_show_services_1(self) -> None: ], ) as mock_show_endpoints: + # Test with regular connection - should display the ingress_url with 
mock.patch.object(self.m_ops._session.connection, "host", "account.snowflakecomputing.com"): res = self.m_ops.show_services( database_name=sql_identifier.SqlIdentifier("TEMP"), @@ -528,6 +530,7 @@ def test_show_services_1(self) -> None: ], ) + # Test with privatelink connection - should display the privatelink_ingress_url with mock.patch.object( self.m_ops._session.connection, "host", "account.privatelink.snowflakecomputing.com" ): @@ -545,7 +548,7 @@ def test_show_services_1(self) -> None: { "name": "d.e.f", "status": "RUNNING", - "inference_endpoint": "privatelink.foo.privatelink.snowflakecomputing.com", + "inference_endpoint": "bar.privatelink.snowflakecomputing.com", }, ], ) @@ -617,6 +620,8 @@ def test_show_services_2(self) -> None: version_name=sql_identifier.SqlIdentifier("v1", case_sensitive=True), statement_params=self.m_statement_params, ) + # Test with regular connection + # Inference endpoint will be None as both ingress_url and privatelink_ingress_url are None self.assertListEqual( res, [ @@ -649,12 +654,12 @@ def test_show_services_3(self) -> None: Row( name="inference", ingress_url="foo.snowflakecomputing.app", - privatelink_ingress_url="privatelink.foo.privatelink.snowflakecomputing.com", + privatelink_ingress_url="foo.privatelink.snowflakecomputing.com", ), Row( name="another", ingress_url="bar.snowflakecomputing.app", - privatelink_ingress_url="privatelink.bar.privatelink.snowflakecomputing.com", + privatelink_ingress_url="bar.privatelink.snowflakecomputing.com", ), ] m_statuses = [ @@ -681,6 +686,8 @@ def test_show_services_3(self) -> None: version_name=sql_identifier.SqlIdentifier("v1", case_sensitive=True), statement_params=self.m_statement_params, ) + # Test with regular connection + # Inference endpoint will take the ingress_url value only if the name field contains "inference" self.assertListEqual( res, [ @@ -713,7 +720,7 @@ def test_show_services_4(self) -> None: Row( name="custom", ingress_url="foo.snowflakecomputing.app", - privatelink_ingress_url="privatelink.foo.privatelink.snowflakecomputing.com", + privatelink_ingress_url="foo.privatelink.snowflakecomputing.com", ) ] m_statuses = [ @@ -740,6 +747,8 @@ def test_show_services_4(self) -> None: version_name=sql_identifier.SqlIdentifier("v1", case_sensitive=True), statement_params=self.m_statement_params, ) + # Test with regular connection + # Inference endpoint will be None as the name field does not contain "inference" self.assertListEqual( res, [ @@ -766,6 +775,125 @@ def test_show_services_4(self) -> None: statement_params=self.m_statement_params, ) + def test_show_services_5(self) -> None: + """Test show_services for non-Business Critical accounts where privatelink_ingress_url column doesn't exist.""" + m_services_list_res = [Row(inference_services='["a.b.c", "d.e.f"]')] + # Row objects without privatelink_ingress_url field + m_endpoints_list_res_0 = [Row(name="inference", ingress_url="bar.snowflakecomputing.app")] + m_endpoints_list_res_1 = [Row(name="inference", ingress_url="foo.snowflakecomputing.app")] + m_statuses_0 = [ + service_sql.ServiceStatusInfo( + service_status=service_sql.ServiceStatus.PENDING, + instance_id=0, + instance_status=service_sql.InstanceStatus.PENDING, + container_status=service_sql.ContainerStatus.PENDING, + message=None, + ) + ] + m_statuses_1 = [ + service_sql.ServiceStatusInfo( + service_status=service_sql.ServiceStatus.RUNNING, + instance_id=1, + instance_status=service_sql.InstanceStatus.READY, + container_status=service_sql.ContainerStatus.READY, + message=None, + ) + ] + + 
with mock.patch.object( + self.m_ops._model_client, "show_versions", return_value=m_services_list_res + ) as mock_show_versions, mock.patch.object( + self.m_ops._service_client, + "get_service_container_statuses", + side_effect=[m_statuses_0, m_statuses_1, m_statuses_0, m_statuses_1], + ) as mock_get_service_container_statuses, mock.patch.object( + self.m_ops._service_client, + "show_endpoints", + side_effect=[ + m_endpoints_list_res_0, + m_endpoints_list_res_1, + m_endpoints_list_res_0, + m_endpoints_list_res_1, + ], + ) as mock_show_endpoints: + + # Test with regular connection + with mock.patch.object(self.m_ops._session.connection, "host", "account.snowflakecomputing.com"): + res = self.m_ops.show_services( + database_name=sql_identifier.SqlIdentifier("TEMP"), + schema_name=sql_identifier.SqlIdentifier("test", case_sensitive=True), + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("v1", case_sensitive=True), + statement_params=self.m_statement_params, + ) + self.assertListEqual( + res, + [ + {"name": "a.b.c", "status": "PENDING", "inference_endpoint": "bar.snowflakecomputing.app"}, + {"name": "d.e.f", "status": "RUNNING", "inference_endpoint": "foo.snowflakecomputing.app"}, + ], + ) + + # Test with privatelink connection - should still use ingress_url since privatelink column doesn't exist + with mock.patch.object( + self.m_ops._session.connection, "host", "account.privatelink.snowflakecomputing.com" + ): + res = self.m_ops.show_services( + database_name=sql_identifier.SqlIdentifier("TEMP"), + schema_name=sql_identifier.SqlIdentifier("test", case_sensitive=True), + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("v1", case_sensitive=True), + statement_params=self.m_statement_params, + ) + self.assertListEqual( + res, + [ + {"name": "a.b.c", "status": "PENDING", "inference_endpoint": "bar.snowflakecomputing.app"}, + {"name": "d.e.f", "status": "RUNNING", "inference_endpoint": "foo.snowflakecomputing.app"}, + ], + ) + + expected_show_versions_call = mock.call( + database_name=sql_identifier.SqlIdentifier("TEMP"), + schema_name=sql_identifier.SqlIdentifier("test", case_sensitive=True), + model_name=sql_identifier.SqlIdentifier("MODEL"), + version_name=sql_identifier.SqlIdentifier("v1", case_sensitive=True), + statement_params=self.m_statement_params, + ) + mock_show_versions.assert_has_calls([expected_show_versions_call, expected_show_versions_call]) + + expected_status_calls = [ + mock.call( + database_name=sql_identifier.SqlIdentifier("a"), + schema_name=sql_identifier.SqlIdentifier("b"), + service_name=sql_identifier.SqlIdentifier("c"), + statement_params=self.m_statement_params, + ), + mock.call( + database_name=sql_identifier.SqlIdentifier("d"), + schema_name=sql_identifier.SqlIdentifier("e"), + service_name=sql_identifier.SqlIdentifier("f"), + statement_params=self.m_statement_params, + ), + ] + mock_get_service_container_statuses.assert_has_calls(expected_status_calls * 2) + + expected_endpoint_calls = [ + mock.call( + database_name=sql_identifier.SqlIdentifier("a"), + schema_name=sql_identifier.SqlIdentifier("b"), + service_name=sql_identifier.SqlIdentifier("c"), + statement_params=self.m_statement_params, + ), + mock.call( + database_name=sql_identifier.SqlIdentifier("d"), + schema_name=sql_identifier.SqlIdentifier("e"), + service_name=sql_identifier.SqlIdentifier("f"), + statement_params=self.m_statement_params, + ), + ] + mock_show_endpoints.assert_has_calls(expected_endpoint_calls * 2) 
+ def test_show_services_pre_bcr(self) -> None: m_list_res = [Row(comment="mycomment")] with mock.patch.object( diff --git a/snowflake/ml/model/_client/ops/service_ops.py b/snowflake/ml/model/_client/ops/service_ops.py index 8576ff3b..98c91361 100644 --- a/snowflake/ml/model/_client/ops/service_ops.py +++ b/snowflake/ml/model/_client/ops/service_ops.py @@ -323,17 +323,20 @@ def create_service( statement_params=statement_params, ) - # stream service logs in a thread - model_build_service_name = sql_identifier.SqlIdentifier( - self._get_service_id_from_deployment_step(query_id, DeploymentStep.MODEL_BUILD) - ) - model_build_service = ServiceLogInfo( - database_name=service_database_name, - schema_name=service_schema_name, - service_name=model_build_service_name, - deployment_step=DeploymentStep.MODEL_BUILD, - log_color=service_logger.LogColor.GREEN, - ) + model_build_service: Optional[ServiceLogInfo] = None + if is_enable_image_build: + # stream service logs in a thread + model_build_service_name = sql_identifier.SqlIdentifier( + self._get_service_id_from_deployment_step(query_id, DeploymentStep.MODEL_BUILD) + ) + model_build_service = ServiceLogInfo( + database_name=service_database_name, + schema_name=service_schema_name, + service_name=model_build_service_name, + deployment_step=DeploymentStep.MODEL_BUILD, + log_color=service_logger.LogColor.GREEN, + ) + model_inference_service = ServiceLogInfo( database_name=service_database_name, schema_name=service_schema_name, @@ -375,7 +378,7 @@ def create_service( progress_status.increment() # Poll for model build to start if not using existing service - if not model_inference_service_exists: + if not model_inference_service_exists and model_build_service: self._wait_for_service_status( model_build_service_name, service_sql.ServiceStatus.RUNNING, @@ -390,7 +393,7 @@ def create_service( progress_status.increment() # Poll for model build completion - if not model_inference_service_exists: + if not model_inference_service_exists and model_build_service: self._wait_for_service_status( model_build_service_name, service_sql.ServiceStatus.DONE, @@ -454,7 +457,7 @@ def _start_service_log_streaming( self, async_job: snowpark.AsyncJob, model_logger_service: Optional[ServiceLogInfo], - model_build_service: ServiceLogInfo, + model_build_service: Optional[ServiceLogInfo], model_inference_service: ServiceLogInfo, model_inference_service_exists: bool, force_rebuild: bool, @@ -483,7 +486,7 @@ def _fetch_log_and_update_meta( self, force_rebuild: bool, service_log_meta: ServiceLogMetadata, - model_build_service: ServiceLogInfo, + model_build_service: Optional[ServiceLogInfo], model_inference_service: ServiceLogInfo, operation_id: str, statement_params: Optional[dict[str, Any]] = None, @@ -599,13 +602,24 @@ def _fetch_log_and_update_meta( # check if model logger service is done # and transition the service log metadata to the model image build service if service.deployment_step == DeploymentStep.MODEL_LOGGING: - service_log_meta.transition_service_log_metadata( - model_build_service, - f"Model Logger service {service.display_service_name} complete.", - is_model_build_service_done=False, - is_model_logger_service_done=service_log_meta.is_model_logger_service_done, - operation_id=operation_id, - ) + if model_build_service: + # building the inference image, transition to the model build service + service_log_meta.transition_service_log_metadata( + model_build_service, + f"Model Logger service {service.display_service_name} complete.", + is_model_build_service_done=False, 
+ is_model_logger_service_done=service_log_meta.is_model_logger_service_done, + operation_id=operation_id, + ) + else: + # no model build service, transition to the model inference service + service_log_meta.transition_service_log_metadata( + model_inference_service, + f"Model Logger service {service.display_service_name} complete.", + is_model_build_service_done=True, + is_model_logger_service_done=service_log_meta.is_model_logger_service_done, + operation_id=operation_id, + ) # check if model build service is done # and transition the service log metadata to the model inference service elif service.deployment_step == DeploymentStep.MODEL_BUILD: @@ -616,6 +630,8 @@ def _fetch_log_and_update_meta( is_model_logger_service_done=service_log_meta.is_model_logger_service_done, operation_id=operation_id, ) + elif service.deployment_step == DeploymentStep.MODEL_INFERENCE: + module_logger.info(f"Inference service {service.display_service_name} is deployed.") else: module_logger.warning(f"Service {service.display_service_name} is done, but not transitioning.") @@ -623,7 +639,7 @@ def _stream_service_logs( self, async_job: snowpark.AsyncJob, model_logger_service: Optional[ServiceLogInfo], - model_build_service: ServiceLogInfo, + model_build_service: Optional[ServiceLogInfo], model_inference_service: ServiceLogInfo, model_inference_service_exists: bool, force_rebuild: bool, @@ -632,14 +648,23 @@ def _stream_service_logs( ) -> None: """Stream service logs while the async job is running.""" - model_build_service_logger = service_logger.get_logger( # BuildJobName - model_build_service.display_service_name, - model_build_service.log_color, - operation_id=operation_id, - ) - if model_logger_service: - model_logger_service_logger = service_logger.get_logger( # ModelLoggerName - model_logger_service.display_service_name, + if model_build_service: + model_build_service_logger = service_logger.get_logger( + model_build_service.display_service_name, # BuildJobName + model_build_service.log_color, + operation_id=operation_id, + ) + service_log_meta = ServiceLogMetadata( + service_logger=model_build_service_logger, + service=model_build_service, + service_status=None, + is_model_build_service_done=False, + is_model_logger_service_done=True, + log_offset=0, + ) + elif model_logger_service: + model_logger_service_logger = service_logger.get_logger( + model_logger_service.display_service_name, # ModelLoggerName model_logger_service.log_color, operation_id=operation_id, ) @@ -653,12 +678,17 @@ def _stream_service_logs( log_offset=0, ) else: + model_inference_service_logger = service_logger.get_logger( + model_inference_service.display_service_name, # ModelInferenceName + model_inference_service.log_color, + operation_id=operation_id, + ) service_log_meta = ServiceLogMetadata( - service_logger=model_build_service_logger, - service=model_build_service, + service_logger=model_inference_service_logger, + service=model_inference_service, service_status=None, is_model_build_service_done=False, - is_model_logger_service_done=True, + is_model_logger_service_done=False, log_offset=0, ) diff --git a/snowflake/ml/model/_client/ops/service_ops_test.py b/snowflake/ml/model/_client/ops/service_ops_test.py index 12067f6f..9fddc785 100644 --- a/snowflake/ml/model/_client/ops/service_ops_test.py +++ b/snowflake/ml/model/_client/ops/service_ops_test.py @@ -774,7 +774,6 @@ def test_create_service_custom_inference_engine(self) -> None: # Define test inference engine kwargs test_inference_engine_args = [ - 
"--model=model/DB.SCHEMA.MODEL/versions/VERSION/", "--tensor-parallel-size=2", "--max_tokens=1000", "--temperature=0.8", @@ -938,7 +937,6 @@ def test_create_service_with_inference_engine_and_no_image_build(self) -> None: # Define test inference engine kwargs test_inference_engine_args = [ - "--model=model/DB.SCHEMA.MODEL/versions/VERSION/", "--tensor-parallel-size=2", "--max_tokens=1000", "--temperature=0.8", diff --git a/snowflake/ml/model/_client/service/model_deployment_spec_test.py b/snowflake/ml/model/_client/service/model_deployment_spec_test.py index 03f7404a..a4d6e7c8 100644 --- a/snowflake/ml/model/_client/service/model_deployment_spec_test.py +++ b/snowflake/ml/model/_client/service/model_deployment_spec_test.py @@ -612,7 +612,6 @@ def test_experimental_options_minimal(self) -> None: inference_engine=inference_engine.InferenceEngine.VLLM, inference_engine_args=[ "--some_vllm_arg=0.8", - "--model=model", "--tensor_parallel_size=2", ], ) @@ -640,7 +639,6 @@ def test_experimental_options_minimal(self) -> None: "inference_engine_name": "vllm", "inference_engine_args": [ "--some_vllm_arg=0.8", - "--model=model", "--tensor_parallel_size=2", ], }, diff --git a/snowflake/ml/model/_client/sql/service.py b/snowflake/ml/model/_client/sql/service.py index 33e6fec2..bbbd1457 100644 --- a/snowflake/ml/model/_client/sql/service.py +++ b/snowflake/ml/model/_client/sql/service.py @@ -256,9 +256,6 @@ def show_endpoints( ) .has_column(ServiceSQLClient.MODEL_INFERENCE_SERVICE_ENDPOINT_NAME_COL_NAME, allow_empty=True) .has_column(ServiceSQLClient.MODEL_INFERENCE_SERVICE_ENDPOINT_INGRESS_URL_COL_NAME, allow_empty=True) - .has_column( - ServiceSQLClient.MODEL_INFERENCE_SERVICE_ENDPOINT_PRIVATELINK_INGRESS_URL_COL_NAME, allow_empty=True - ) ) return res.validate() diff --git a/snowflake/ml/model/_client/sql/service_test.py b/snowflake/ml/model/_client/sql/service_test.py index 6e2b0d3f..a857d6dc 100644 --- a/snowflake/ml/model/_client/sql/service_test.py +++ b/snowflake/ml/model/_client/sql/service_test.py @@ -433,6 +433,43 @@ def test_show_endpoints(self) -> None: statement_params=m_statement_params, ) + def test_show_endpoints_non_business_critical(self) -> None: + """Test show_endpoints for non-Business Critical accounts where privatelink_ingress_url column doesn't exist.""" + m_statement_params = {"test": "1"} + # Row without privatelink_ingress_url field + m_df = mock_data_frame.MockDataFrame( + collect_result=[ + Row( + name="inference", + ingress_url="foo.snowflakecomputing.app", + ) + ], + collect_statement_params=m_statement_params, + ) + self.m_session.add_mock_sql( + """SHOW ENDPOINTS IN SERVICE TEMP."test".MYSERVICE""", + copy.deepcopy(m_df), + ) + c_session = cast(Session, self.m_session) + + # This should work without errors even though privatelink_ingress_url column is missing + result = service_sql.ServiceSQLClient( + c_session, + database_name=sql_identifier.SqlIdentifier("TEMP"), + schema_name=sql_identifier.SqlIdentifier("test", case_sensitive=True), + ).show_endpoints( + database_name=None, + schema_name=None, + service_name=sql_identifier.SqlIdentifier("MYSERVICE"), + statement_params=m_statement_params, + ) + + # Verify the result contains the row with only name and ingress_url + self.assertEqual(len(result), 1) + self.assertEqual(result[0]["name"], "inference") + self.assertEqual(result[0]["ingress_url"], "foo.snowflakecomputing.app") + self.assertNotIn("privatelink_ingress_url", result[0]) + if __name__ == "__main__": absltest.main() diff --git 
a/snowflake/ml/model/_packager/model_runtime/model_runtime_test.py b/snowflake/ml/model/_packager/model_runtime/model_runtime_test.py index 5cdf6380..f301c594 100644 --- a/snowflake/ml/model/_packager/model_runtime/model_runtime_test.py +++ b/snowflake/ml/model/_packager/model_runtime/model_runtime_test.py @@ -331,7 +331,7 @@ def test_model_runtime_gpu(self) -> None: dependencies = yaml.safe_load(f) self.assertContainsSubset( - ["python==3.9.*", "pytorch", "snowflake-ml-python==1.0.0", "nvidia::cuda==11.7.*"], + ["python==3.10.*", "pytorch", "snowflake-ml-python==1.0.0", "nvidia::cuda==11.7.*"], dependencies["dependencies"], ) diff --git a/snowflake/ml/monitoring/_client/model_monitor_sql_client.py b/snowflake/ml/monitoring/_client/model_monitor_sql_client.py index 3707ae12..4d636791 100644 --- a/snowflake/ml/monitoring/_client/model_monitor_sql_client.py +++ b/snowflake/ml/monitoring/_client/model_monitor_sql_client.py @@ -30,8 +30,8 @@ def supported_target_properties(self) -> frozenset[str]: _OPERATION_SUPPORTED_PROPS: dict[MonitorOperation, frozenset[str]] = { MonitorOperation.SUSPEND: frozenset(), MonitorOperation.RESUME: frozenset(), - MonitorOperation.ADD: frozenset({"SEGMENT_COLUMN"}), - MonitorOperation.DROP: frozenset({"SEGMENT_COLUMN"}), + MonitorOperation.ADD: frozenset({"SEGMENT_COLUMN", "CUSTOM_METRIC_COLUMN"}), + MonitorOperation.DROP: frozenset({"SEGMENT_COLUMN", "CUSTOM_METRIC_COLUMN"}), } @@ -91,6 +91,7 @@ def create_model_monitor( baseline_schema: Optional[sql_identifier.SqlIdentifier] = None, baseline: Optional[sql_identifier.SqlIdentifier] = None, segment_columns: Optional[list[sql_identifier.SqlIdentifier]] = None, + custom_metric_columns: Optional[list[sql_identifier.SqlIdentifier]] = None, statement_params: Optional[dict[str, Any]] = None, ) -> None: baseline_sql = "" @@ -101,6 +102,10 @@ def create_model_monitor( if segment_columns: segment_columns_sql = f"SEGMENT_COLUMNS={_build_sql_list_from_columns(segment_columns)}" + custom_metric_columns_sql = "" + if custom_metric_columns: + custom_metric_columns_sql = f"CUSTOM_METRIC_COLUMNS={_build_sql_list_from_columns(custom_metric_columns)}" + query_result_checker.SqlResultValidator( self._sql_client._session, f""" @@ -120,6 +125,7 @@ def create_model_monitor( REFRESH_INTERVAL='{refresh_interval}' AGGREGATION_WINDOW='{aggregation_window}' {segment_columns_sql} + {custom_metric_columns_sql} {baseline_sql}""", statement_params=statement_params, ).has_column("status").has_dimensions(1, 1).validate() @@ -210,6 +216,7 @@ def _validate_columns_exist_in_source( actual_class_columns: list[sql_identifier.SqlIdentifier], id_columns: list[sql_identifier.SqlIdentifier], segment_columns: Optional[list[sql_identifier.SqlIdentifier]] = None, + custom_metric_columns: Optional[list[sql_identifier.SqlIdentifier]] = None, ) -> None: """Ensures all columns exist in the source table. @@ -222,12 +229,14 @@ def _validate_columns_exist_in_source( actual_class_columns: List of actual class column names. id_columns: List of id column names. segment_columns: List of segment column names. + custom_metric_columns: List of custom metric column names. Raises: ValueError: If any of the columns do not exist in the source. 
""" segment_columns = [] if segment_columns is None else segment_columns + custom_metric_columns = [] if custom_metric_columns is None else custom_metric_columns if timestamp_column not in source_column_schema: raise ValueError(f"Timestamp column {timestamp_column} does not exist in source.") @@ -248,6 +257,9 @@ def _validate_columns_exist_in_source( if not all([column_name in source_column_schema for column_name in segment_columns]): raise ValueError(f"Segment column(s): {segment_columns} do not exist in source.") + if not all([column_name in source_column_schema for column_name in custom_metric_columns]): + raise ValueError(f"Custom Metric column(s): {custom_metric_columns} do not exist in source.") + def validate_source( self, *, @@ -261,6 +273,7 @@ def validate_source( actual_class_columns: list[sql_identifier.SqlIdentifier], id_columns: list[sql_identifier.SqlIdentifier], segment_columns: Optional[list[sql_identifier.SqlIdentifier]] = None, + custom_metric_columns: Optional[list[sql_identifier.SqlIdentifier]] = None, ) -> None: source_database = source_database or self._database_name @@ -281,6 +294,7 @@ def validate_source( actual_class_columns=actual_class_columns, id_columns=id_columns, segment_columns=segment_columns, + custom_metric_columns=custom_metric_columns, ) def _alter_monitor( @@ -299,7 +313,7 @@ def _alter_monitor( if target_property not in supported_target_properties: raise ValueError( - f"Only {', '.join(supported_target_properties)} supported as target property " + f"Only {', '.join(sorted(supported_target_properties))} supported as target property " f"for {operation.name} operation" ) @@ -366,3 +380,33 @@ def drop_segment_column( target_value=segment_column, statement_params=statement_params, ) + + def add_custom_metric_column( + self, + monitor_name: sql_identifier.SqlIdentifier, + custom_metric_column: sql_identifier.SqlIdentifier, + statement_params: Optional[dict[str, Any]] = None, + ) -> None: + """Add a custom metric column to the Model Monitor""" + self._alter_monitor( + operation=MonitorOperation.ADD, + monitor_name=monitor_name, + target_property="CUSTOM_METRIC_COLUMN", + target_value=custom_metric_column, + statement_params=statement_params, + ) + + def drop_custom_metric_column( + self, + monitor_name: sql_identifier.SqlIdentifier, + custom_metric_column: sql_identifier.SqlIdentifier, + statement_params: Optional[dict[str, Any]] = None, + ) -> None: + """Drop a custom metric column from the Model Monitor""" + self._alter_monitor( + operation=MonitorOperation.DROP, + monitor_name=monitor_name, + target_property="CUSTOM_METRIC_COLUMN", + target_value=custom_metric_column, + statement_params=statement_params, + ) diff --git a/snowflake/ml/monitoring/_client/model_monitor_sql_client_test.py b/snowflake/ml/monitoring/_client/model_monitor_sql_client_test.py index ac49a1a2..28620c3a 100644 --- a/snowflake/ml/monitoring/_client/model_monitor_sql_client_test.py +++ b/snowflake/ml/monitoring/_client/model_monitor_sql_client_test.py @@ -250,6 +250,34 @@ def test_drop_segment_column(self) -> None: ) self.m_session.finalize() + def test_add_custom_metric_column(self) -> None: + test_custom_metric_column = sql_identifier.SqlIdentifier("CUSTOM_METRIC") + self.m_session.add_mock_sql( + ( + f"""ALTER MODEL MONITOR {self.test_db_name}.{self.test_schema_name}.{self.test_monitor_name} """ + f"""ADD CUSTOM_METRIC_COLUMN={test_custom_metric_column}""" + ), + result=mock_data_frame.MockDataFrame([Row(status="Success")]), + ) + self.monitor_sql_client.add_custom_metric_column( + 
monitor_name=self.test_monitor_name, custom_metric_column=test_custom_metric_column + ) + self.m_session.finalize() + + def test_drop_custom_metric_column(self) -> None: + test_custom_metric_column = sql_identifier.SqlIdentifier("CUSTOM_METRIC") + self.m_session.add_mock_sql( + ( + f"""ALTER MODEL MONITOR {self.test_db_name}.{self.test_schema_name}.{self.test_monitor_name} """ + f"""DROP CUSTOM_METRIC_COLUMN={test_custom_metric_column}""" + ), + result=mock_data_frame.MockDataFrame([Row(status="Success")]), + ) + self.monitor_sql_client.drop_custom_metric_column( + monitor_name=self.test_monitor_name, custom_metric_column=test_custom_metric_column + ) + self.m_session.finalize() + def test_alter_monitor_validation(self) -> None: """Test validation logic in _alter_monitor method""" # Test missing target property and value for ADD operation @@ -271,7 +299,9 @@ def test_alter_monitor_validation(self) -> None: ) # Test invalid target property for ADD operation - with self.assertRaisesRegex(ValueError, "Only SEGMENT_COLUMN supported as target property for ADD operation"): + with self.assertRaisesRegex( + ValueError, "Only CUSTOM_METRIC_COLUMN, SEGMENT_COLUMN supported as target property for ADD operation" + ): self.monitor_sql_client._alter_monitor( operation=MonitorOperation.ADD, monitor_name=self.test_monitor_name, @@ -280,7 +310,9 @@ def test_alter_monitor_validation(self) -> None: ) # Test invalid target property for DROP operation - with self.assertRaisesRegex(ValueError, "Only SEGMENT_COLUMN supported as target property for DROP operation"): + with self.assertRaisesRegex( + ValueError, "Only CUSTOM_METRIC_COLUMN, SEGMENT_COLUMN supported as target property for DROP operation" + ): self.monitor_sql_client._alter_monitor( operation=MonitorOperation.DROP, monitor_name=self.test_monitor_name, @@ -378,6 +410,55 @@ def test_create_model_monitor_with_segment_columns(self) -> None: ) self.m_session.finalize() + def test_create_model_monitor_with_custom_metric_columns(self) -> None: + """Test creating model monitor with custom metric columns""" + self.m_session.add_mock_sql( + f""" + CREATE MODEL MONITOR {self.test_db_name}.{self.test_schema_name}.{self.test_monitor_name} + WITH + MODEL={self.test_db_name}.{self.test_schema_name}.{self.test_model_name} + VERSION='{self.test_model_version_name}' + FUNCTION='predict' + WAREHOUSE='{self.test_wh_name}' + SOURCE={self.test_db_name}.{self.test_schema_name}.{self.test_source_table_name} + ID_COLUMNS=('ID') + PREDICTION_SCORE_COLUMNS=('PREDICTION') + PREDICTION_CLASS_COLUMNS=() + ACTUAL_SCORE_COLUMNS=('LABEL') + ACTUAL_CLASS_COLUMNS=() + TIMESTAMP_COLUMN='TIMESTAMP' + REFRESH_INTERVAL='1 hour' + AGGREGATION_WINDOW='1 day' + CUSTOM_METRIC_COLUMNS=('CUSTOM_METRIC') + """, + result=mock_data_frame.MockDataFrame([Row(status="Success")]), + ) + + self.monitor_sql_client.create_model_monitor( + monitor_database=self.test_db_name, + monitor_schema=self.test_schema_name, + monitor_name=self.test_monitor_name, + source_database=self.test_db_name, + source_schema=self.test_schema_name, + source=self.test_source_table_name, + model_database=self.test_db_name, + model_schema=self.test_schema_name, + model_name=self.test_model_name, + version_name=self.test_model_version_name, + function_name="predict", + warehouse_name=self.test_wh_name, + timestamp_column=self.test_timestamp_column, + id_columns=[self.test_id_column_name], + prediction_score_columns=[self.test_prediction_column_name], + prediction_class_columns=[], + 
actual_score_columns=[self.test_label_column_name], + actual_class_columns=[], + refresh_interval="1 hour", + aggregation_window="1 day", + custom_metric_columns=[sql_identifier.SqlIdentifier("CUSTOM_METRIC")], + ) + self.m_session.finalize() + def test_create_model_monitor_without_segment_columns(self) -> None: """Test creating model monitor without segment columns (empty list)""" self.m_session.add_mock_sql( diff --git a/snowflake/ml/monitoring/_manager/model_monitor_manager.py b/snowflake/ml/monitoring/_manager/model_monitor_manager.py index bda3e1cf..bc60b842 100644 --- a/snowflake/ml/monitoring/_manager/model_monitor_manager.py +++ b/snowflake/ml/monitoring/_manager/model_monitor_manager.py @@ -109,6 +109,7 @@ def add_monitor( actual_score_columns = self._build_column_list_from_input(source_config.actual_score_columns) actual_class_columns = self._build_column_list_from_input(source_config.actual_class_columns) segment_columns = self._build_column_list_from_input(source_config.segment_columns) + custom_metric_columns = self._build_column_list_from_input(source_config.custom_metric_columns) id_columns = [sql_identifier.SqlIdentifier(column_name) for column_name in source_config.id_columns] ts_column = sql_identifier.SqlIdentifier(source_config.timestamp_column) @@ -125,6 +126,7 @@ def add_monitor( actual_class_columns=actual_class_columns, id_columns=id_columns, segment_columns=segment_columns, + custom_metric_columns=custom_metric_columns, ) self._model_monitor_client.create_model_monitor( @@ -147,6 +149,7 @@ def add_monitor( actual_score_columns=actual_score_columns, actual_class_columns=actual_class_columns, segment_columns=segment_columns, + custom_metric_columns=custom_metric_columns, refresh_interval=model_monitor_config.refresh_interval, aggregation_window=model_monitor_config.aggregation_window, baseline_database=baseline_database_name_id, diff --git a/snowflake/ml/monitoring/_manager/model_monitor_manager_test.py b/snowflake/ml/monitoring/_manager/model_monitor_manager_test.py index cac71b83..12c998ba 100644 --- a/snowflake/ml/monitoring/_manager/model_monitor_manager_test.py +++ b/snowflake/ml/monitoring/_manager/model_monitor_manager_test.py @@ -182,6 +182,7 @@ def test_add_monitor(self) -> None: actual_class_columns=[], id_columns=["ID"], segment_columns=[], + custom_metric_columns=[], ) mock_get_model_task.assert_called_once() mock_create_model_monitor.assert_called_once_with( @@ -204,6 +205,7 @@ def test_add_monitor(self) -> None: actual_score_columns=["LABEL"], actual_class_columns=[], segment_columns=[], + custom_metric_columns=[], refresh_interval="1 hour", aggregation_window="1 day", baseline_database=None, @@ -249,6 +251,7 @@ def test_add_monitor_fully_qualified_monitor_name(self) -> None: actual_score_columns=["LABEL"], actual_class_columns=[], segment_columns=[], + custom_metric_columns=[], refresh_interval="1 hour", aggregation_window="1 day", baseline_database=None, @@ -312,6 +315,7 @@ def test_add_monitor_objects_in_different_schemas(self) -> None: actual_score_columns=["LABEL"], actual_class_columns=[], segment_columns=[], + custom_metric_columns=[], refresh_interval="1 hour", aggregation_window="1 day", baseline_database=sql_identifier.SqlIdentifier("BASELINE_DB"), @@ -351,6 +355,7 @@ def test_add_monitor_with_segment_columns_happy_path(self) -> None: actual_class_columns=[], id_columns=["ID"], segment_columns=["CUSTOMER_SEGMENT", "REGION"], + custom_metric_columns=[], ) mock_get_model_task.assert_called_once() @@ -375,6 +380,72 @@ def 
diff --git a/snowflake/ml/monitoring/_manager/model_monitor_manager_test.py b/snowflake/ml/monitoring/_manager/model_monitor_manager_test.py
index cac71b83..12c998ba 100644
--- a/snowflake/ml/monitoring/_manager/model_monitor_manager_test.py
+++ b/snowflake/ml/monitoring/_manager/model_monitor_manager_test.py
@@ -182,6 +182,7 @@ def test_add_monitor(self) -> None:
                 actual_class_columns=[],
                 id_columns=["ID"],
                 segment_columns=[],
+                custom_metric_columns=[],
             )
             mock_get_model_task.assert_called_once()
             mock_create_model_monitor.assert_called_once_with(
@@ -204,6 +205,7 @@ def test_add_monitor(self) -> None:
                 actual_score_columns=["LABEL"],
                 actual_class_columns=[],
                 segment_columns=[],
+                custom_metric_columns=[],
                 refresh_interval="1 hour",
                 aggregation_window="1 day",
                 baseline_database=None,
@@ -249,6 +251,7 @@ def test_add_monitor_fully_qualified_monitor_name(self) -> None:
                 actual_score_columns=["LABEL"],
                 actual_class_columns=[],
                 segment_columns=[],
+                custom_metric_columns=[],
                 refresh_interval="1 hour",
                 aggregation_window="1 day",
                 baseline_database=None,
@@ -312,6 +315,7 @@ def test_add_monitor_objects_in_different_schemas(self) -> None:
                 actual_score_columns=["LABEL"],
                 actual_class_columns=[],
                 segment_columns=[],
+                custom_metric_columns=[],
                 refresh_interval="1 hour",
                 aggregation_window="1 day",
                 baseline_database=sql_identifier.SqlIdentifier("BASELINE_DB"),
@@ -351,6 +355,7 @@ def test_add_monitor_with_segment_columns_happy_path(self) -> None:
                 actual_class_columns=[],
                 id_columns=["ID"],
                 segment_columns=["CUSTOMER_SEGMENT", "REGION"],
+                custom_metric_columns=[],
             )

             mock_get_model_task.assert_called_once()
@@ -375,6 +380,72 @@ def test_add_monitor_with_segment_columns_happy_path(self) -> None:
                 actual_score_columns=["LABEL"],
                 actual_class_columns=[],
                 segment_columns=["CUSTOMER_SEGMENT", "REGION"],
+                custom_metric_columns=[],
+                refresh_interval="1 hour",
+                aggregation_window="1 day",
+                baseline_database=None,
+                baseline_schema=None,
+                baseline=None,
+                statement_params=None,
+            )
+
+    def test_add_monitor_with_custom_metric_columns(self) -> None:
+        """Test that custom_metric_columns are correctly passed through when provided."""
+        source_config = model_monitor_config.ModelMonitorSourceConfig(
+            prediction_score_columns=["PREDICTION"],
+            actual_score_columns=["LABEL"],
+            id_columns=["ID"],
+            timestamp_column="TS",
+            source=self.test_source_table_name,
+            custom_metric_columns=["CUSTOM_METRIC_1", "CUSTOM_METRIC_2"],
+        )
+        with mock.patch.object(
+            self.mm._model_monitor_client, "validate_source"
+        ) as mock_validate_source, mock.patch.object(
+            self.mv, "get_model_task", return_value=type_hints.Task.TABULAR_REGRESSION
+        ) as mock_get_model_task, mock.patch.object(
+            self.mm._model_monitor_client, "create_model_monitor", return_value=None
+        ) as mock_create_model_monitor:
+            self.mm.add_monitor("TEST", source_config, self.test_monitor_config)
+
+            # Verify validate_source was called with custom_metric_columns
+            mock_validate_source.assert_called_once_with(
+                source_database=None,
+                source_schema=None,
+                source=self.test_source_table_name,
+                timestamp_column="TS",
+                prediction_score_columns=["PREDICTION"],
+                prediction_class_columns=[],
+                actual_score_columns=["LABEL"],
+                actual_class_columns=[],
+                id_columns=["ID"],
+                segment_columns=[],
+                custom_metric_columns=["CUSTOM_METRIC_1", "CUSTOM_METRIC_2"],
+            )
+            mock_get_model_task.assert_called_once()
+
+            # Verify create_model_monitor was called with custom_metric_columns
+            mock_create_model_monitor.assert_called_once_with(
+                monitor_database=None,
+                monitor_schema=None,
+                monitor_name=sql_identifier.SqlIdentifier("TEST"),
+                source_database=None,
+                source_schema=None,
+                source=sql_identifier.SqlIdentifier(self.test_source_table_name),
+                model_database=sql_identifier.SqlIdentifier("MODEL_DB"),
+                model_schema=sql_identifier.SqlIdentifier("MODEL_SCHEMA"),
+                model_name=self.test_model,
+                version_name=sql_identifier.SqlIdentifier(self.test_model_version),
+                function_name="predict",
+                warehouse_name=sql_identifier.SqlIdentifier(self.test_warehouse),
+                timestamp_column="TS",
+                id_columns=["ID"],
+                prediction_score_columns=["PREDICTION"],
+                prediction_class_columns=[],
+                actual_score_columns=["LABEL"],
+                actual_class_columns=[],
+                segment_columns=[],
+                custom_metric_columns=["CUSTOM_METRIC_1", "CUSTOM_METRIC_2"],
                 refresh_interval="1 hour",
                 aggregation_window="1 day",
                 baseline_database=None,
@@ -417,6 +488,44 @@ def test_add_monitor_with_segment_columns_validation_failure(self) -> None:
                 actual_class_columns=[],
                 id_columns=["ID"],
                 segment_columns=["NONEXISTENT_COLUMN"],
+                custom_metric_columns=[],
+            )
+
+    def test_add_monitor_with_custom_metric_columns_validation_failure(self) -> None:
+        """Test that add_monitor fails when custom_metric_columns don't exist in source."""
+        source_config = model_monitor_config.ModelMonitorSourceConfig(
+            prediction_score_columns=["PREDICTION"],
+            actual_score_columns=["LABEL"],
+            id_columns=["ID"],
+            timestamp_column="TS",
+            source=self.test_source_table_name,
+            custom_metric_columns=["NONEXISTENT_COLUMN"],
+        )
+
+        with mock.patch.object(
+            self.mm._model_monitor_client,
+            "validate_source",
+            side_effect=ValueError("Custom metric column(s): ['NONEXISTENT_COLUMN'] do not exist in source."),
+        ) as mock_validate_source, mock.patch.object(
+            self.mv, "get_model_task", return_value=type_hints.Task.TABULAR_REGRESSION
+        ):
+            with self.assertRaisesRegex(
+                ValueError, "Custom metric column\\(s\\): \\['NONEXISTENT_COLUMN'\\] do not exist in source\\."
+            ):
+                self.mm.add_monitor("TEST", source_config, self.test_monitor_config)
+
+            mock_validate_source.assert_called_once_with(
+                source_database=None,
+                source_schema=None,
+                source=self.test_source_table_name,
+                timestamp_column="TS",
+                prediction_score_columns=["PREDICTION"],
+                prediction_class_columns=[],
+                actual_score_columns=["LABEL"],
+                actual_class_columns=[],
+                id_columns=["ID"],
+                segment_columns=[],
+                custom_metric_columns=["NONEXISTENT_COLUMN"],
             )
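Note: validate_source is mocked throughout these tests, so the patch fixes its error contract rather than showing its implementation. The custom-metric existence check presumably reduces to something like the following sketch; the function name and inputs are illustrative assumptions.

    def check_custom_metric_columns_exist(source_columns: set, custom_metric_columns: list) -> None:
        # Reproduces only the error contract implied by the mocked side_effect above.
        missing = [col for col in custom_metric_columns if col not in source_columns]
        if missing:
            raise ValueError(f"Custom metric column(s): {missing} do not exist in source.")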
diff --git a/snowflake/ml/monitoring/entities/model_monitor_config.py b/snowflake/ml/monitoring/entities/model_monitor_config.py
index 131f858e..4d393831 100644
--- a/snowflake/ml/monitoring/entities/model_monitor_config.py
+++ b/snowflake/ml/monitoring/entities/model_monitor_config.py
@@ -36,6 +36,9 @@ class ModelMonitorSourceConfig:
     segment_columns: Optional[list[str]] = None
     """List of columns in the source containing segment information for grouped monitoring."""

+    custom_metric_columns: Optional[list[str]] = None
+    """List of columns in the source containing custom metrics."""
+

 @dataclass
 class ModelMonitorConfig:
diff --git a/snowflake/ml/monitoring/model_monitor.py b/snowflake/ml/monitoring/model_monitor.py
index 738f58aa..6745d0ae 100644
--- a/snowflake/ml/monitoring/model_monitor.py
+++ b/snowflake/ml/monitoring/model_monitor.py
@@ -72,3 +72,33 @@ def drop_segment_column(self, segment_column: str) -> None:
         )
         segment_column_id = sql_identifier.SqlIdentifier(segment_column)
         self._model_monitor_client.drop_segment_column(self.name, segment_column_id, statement_params=statement_params)
+
+    @telemetry.send_api_usage_telemetry(
+        project=telemetry.TelemetryProject.MLOPS.value,
+        subproject=telemetry.TelemetrySubProject.MONITORING.value,
+    )
+    def add_custom_metric_column(self, custom_metric_column: str) -> None:
+        """Add a custom metric column to the Model Monitor"""
+        statement_params = telemetry.get_statement_params(
+            telemetry.TelemetryProject.MLOPS.value,
+            telemetry.TelemetrySubProject.MONITORING.value,
+        )
+        custom_metric_column_identifier = sql_identifier.SqlIdentifier(custom_metric_column)
+        self._model_monitor_client.add_custom_metric_column(
+            self.name, custom_metric_column_identifier, statement_params=statement_params
+        )
+
+    @telemetry.send_api_usage_telemetry(
+        project=telemetry.TelemetryProject.MLOPS.value,
+        subproject=telemetry.TelemetrySubProject.MONITORING.value,
+    )
+    def drop_custom_metric_column(self, custom_metric_column: str) -> None:
+        """Drop a custom metric column from the Model Monitor"""
+        statement_params = telemetry.get_statement_params(
+            telemetry.TelemetryProject.MLOPS.value,
+            telemetry.TelemetrySubProject.MONITORING.value,
+        )
+        custom_metric_column_identifier = sql_identifier.SqlIdentifier(custom_metric_column)
+        self._model_monitor_client.drop_custom_metric_column(
+            self.name, custom_metric_column_identifier, statement_params=statement_params
+        )
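Note: the two ModelMonitor methods above are the public surface of this feature, and day-to-day use is a pair of one-line calls. In the sketch below, `monitor` is assumed to come from Registry.get_monitor or Registry.add_monitor, and the helper name is illustrative.

    from snowflake.ml.monitoring import model_monitor

    def swap_custom_metric(monitor: model_monitor.ModelMonitor, old_column: str, new_column: str) -> None:
        # Each call issues an ALTER MODEL MONITOR statement under the hood,
        # per the thin wrappers added above.
        monitor.add_custom_metric_column(new_column)
        monitor.drop_custom_metric_column(old_column)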
         self.assertEqual(call_args[0][1].identifier(), test_segment_column)  # segment_column as SqlIdentifier
         self.assertIsNotNone(call_args[1]["statement_params"])  # statement_params

+    def test_add_custom_metric_column(self) -> None:
+        test_custom_metric_column = "CUSTOM_METRIC"
+        with mock.patch.object(
+            self.model_monitor._model_monitor_client, "add_custom_metric_column"
+        ) as mock_add_custom_metric:
+            self.model_monitor.add_custom_metric_column(test_custom_metric_column)
+            mock_add_custom_metric.assert_called_once_with(
+                self.test_monitor_name, test_custom_metric_column, statement_params=mock.ANY
+            )
+            call_args = mock_add_custom_metric.call_args
+            self.assertEqual(call_args[0][0], self.test_monitor_name)
+            self.assertEqual(call_args[0][1].identifier(), test_custom_metric_column)
+            self.assertIsNotNone(call_args[1]["statement_params"])
+
+    def test_drop_custom_metric_column(self) -> None:
+        test_custom_metric_column = "CUSTOM_METRIC"
+        with mock.patch.object(
+            self.model_monitor._model_monitor_client, "drop_custom_metric_column"
+        ) as mock_drop_custom_metric:
+            self.model_monitor.drop_custom_metric_column(test_custom_metric_column)
+            mock_drop_custom_metric.assert_called_once_with(
+                self.test_monitor_name, test_custom_metric_column, statement_params=mock.ANY
+            )
+            call_args = mock_drop_custom_metric.call_args
+            self.assertEqual(call_args[0][0], self.test_monitor_name)
+            self.assertEqual(call_args[0][1].identifier(), test_custom_metric_column)
+            self.assertIsNotNone(call_args[1]["statement_params"])
+

 if __name__ == "__main__":
     absltest.main()
diff --git a/snowflake/ml/version.py b/snowflake/ml/version.py
index 71b854e2..6a428499 100644
--- a/snowflake/ml/version.py
+++ b/snowflake/ml/version.py
@@ -1,2 +1,2 @@
 # This is parsed by regex in conda recipe meta file. Make sure not to break it.
-VERSION = "1.13.0" +VERSION = "1.14.0" diff --git a/tests/integ/snowflake/ml/jobs/jobs_integ_test.py b/tests/integ/snowflake/ml/jobs/jobs_integ_test.py index 2ea66705..ca923a6e 100644 --- a/tests/integ/snowflake/ml/jobs/jobs_integ_test.py +++ b/tests/integ/snowflake/ml/jobs/jobs_integ_test.py @@ -556,7 +556,6 @@ def runtime_func() -> None: "Decorator test only works for Python 3.10 to pickle compatibility", ) def test_job_data_connector(self) -> None: - from snowflake.ml._internal.utils import mixins from snowflake.ml.data import data_connector from snowflake.ml.data._internal import arrow_ingestor @@ -564,9 +563,8 @@ def test_job_data_connector(self) -> None: @jobs.remote(self.compute_pool, stage_name="payload_stage", session=self.session) def runtime_func(dc: data_connector.DataConnector) -> data_connector.DataConnector: - # TODO(SNOW-2182155): Enable this once headless backend receives updated SnowML with unpickle support - # assert "Ray" in type(dc._ingestor).__name__, type(dc._ingestor).__qualname__ - assert len(dc.to_pandas()) == num_rows, len(dc.to_pandas()) + if len(dc.to_pandas()) != num_rows: + raise RuntimeError(f"Unexpected number of rows: expected {num_rows}, actual {len(dc.to_pandas())}") return dc df = self.session.sql( @@ -575,16 +573,7 @@ def runtime_func(dc: data_connector.DataConnector) -> data_connector.DataConnect dc = data_connector.DataConnector.from_dataframe(df) self.assertIsInstance(dc._ingestor, arrow_ingestor.ArrowIngestor) - # TODO(SNOW-2182155): Remove this once headless backend receives updated SnowML with unpickle support - # Register key modules to be picklable by value to avoid version desync in this test - cp.register_pickle_by_value(mixins) - cp.register_pickle_by_value(arrow_ingestor) - try: - job = runtime_func(dc) - finally: - cp.unregister_pickle_by_value(mixins) - cp.unregister_pickle_by_value(arrow_ingestor) - + job = runtime_func(dc) self.assertEqual(job.wait(), "DONE", job.get_logs()) dc_unpickled = job.result() self.assertIsInstance(dc_unpickled, data_connector.DataConnector) diff --git a/tests/integ/snowflake/ml/monitoring/model_monitor_integ_test.py b/tests/integ/snowflake/ml/monitoring/model_monitor_integ_test.py index 8f1fc59f..83af209d 100644 --- a/tests/integ/snowflake/ml/monitoring/model_monitor_integ_test.py +++ b/tests/integ/snowflake/ml/monitoring/model_monitor_integ_test.py @@ -31,7 +31,11 @@ def _get_monitor_segment_columns(self, monitor_name: str) -> list[str]: return columns_json.get("segment_columns", []) def _create_test_table( - self, fully_qualified_table_name: str, id_column_type: str = "STRING", segment_columns: list = None + self, + fully_qualified_table_name: str, + id_column_type: str = "STRING", + segment_columns: list = None, + custom_metric_columns: list = None, ) -> None: """Create a test table with optional segment columns for testing.""" @@ -42,10 +46,15 @@ def _create_test_table( if segment_columns: segment_columns_def = ", " + ", ".join([f"{col} STRING" for col in segment_columns]) + # Build the custom metric columns part of the table definition + custom_metric_columns_def = "" + if custom_metric_columns: + custom_metric_columns_def = ", " + ", ".join([f"{col} FLOAT" for col in custom_metric_columns]) + self._session.sql( f"""CREATE OR REPLACE TABLE {fully_qualified_table_name} (label FLOAT, prediction FLOAT, - {s}, id {id_column_type}, timestamp TIMESTAMP_NTZ{segment_columns_def})""" + {s}, id {id_column_type}, timestamp TIMESTAMP_NTZ{segment_columns_def}{custom_metric_columns_def})""" ).collect() # Needed to 
diff --git a/tests/integ/snowflake/ml/monitoring/model_monitor_integ_test.py b/tests/integ/snowflake/ml/monitoring/model_monitor_integ_test.py
index 8f1fc59f..83af209d 100644
--- a/tests/integ/snowflake/ml/monitoring/model_monitor_integ_test.py
+++ b/tests/integ/snowflake/ml/monitoring/model_monitor_integ_test.py
@@ -31,7 +31,11 @@ def _get_monitor_segment_columns(self, monitor_name: str) -> list[str]:
         return columns_json.get("segment_columns", [])

     def _create_test_table(
-        self, fully_qualified_table_name: str, id_column_type: str = "STRING", segment_columns: list = None
+        self,
+        fully_qualified_table_name: str,
+        id_column_type: str = "STRING",
+        segment_columns: list = None,
+        custom_metric_columns: list = None,
     ) -> None:
         """Create a test table with optional segment columns for testing."""
@@ -42,10 +46,15 @@ def _create_test_table(
         if segment_columns:
             segment_columns_def = ", " + ", ".join([f"{col} STRING" for col in segment_columns])

+        # Build the custom metric columns part of the table definition
+        custom_metric_columns_def = ""
+        if custom_metric_columns:
+            custom_metric_columns_def = ", " + ", ".join([f"{col} FLOAT" for col in custom_metric_columns])
+
         self._session.sql(
             f"""CREATE OR REPLACE TABLE {fully_qualified_table_name}
                 (label FLOAT, prediction FLOAT,
-                {s}, id {id_column_type}, timestamp TIMESTAMP_NTZ{segment_columns_def})"""
+                {s}, id {id_column_type}, timestamp TIMESTAMP_NTZ{segment_columns_def}{custom_metric_columns_def})"""
         ).collect()

         # Needed to create DT against this table
@@ -58,10 +67,18 @@ def _create_test_table(
             segment_columns_insert = ", " + ", ".join(segment_columns)
             segment_values_insert = ", " + ", ".join([f"'{col}_value'" for col in segment_columns])

+        # Build the custom metric columns part of the INSERT statement
+        custom_metric_columns_insert = ""
+        custom_metric_values_insert = ""
+        if custom_metric_columns:
+            custom_metric_columns_insert = ", " + ", ".join(custom_metric_columns)
+            custom_metric_values_insert = ", " + ", ".join(["1.23" for _ in custom_metric_columns])
+
         self._session.sql(
             f"""INSERT INTO {fully_qualified_table_name}
-                (label, prediction, {", ".join(INPUT_FEATURE_COLUMNS_NAMES)}, id, timestamp{segment_columns_insert})
-                VALUES (1, 1, {", ".join(["1"] * 64)}, '1', CURRENT_TIMESTAMP(){segment_values_insert})"""
+                (label, prediction, {", ".join(INPUT_FEATURE_COLUMNS_NAMES)}, id, timestamp{segment_columns_insert}
+                {custom_metric_columns_insert}) VALUES (1, 1, {", ".join(["1"] * 64)}, '1', CURRENT_TIMESTAMP()
+                {segment_values_insert}{custom_metric_values_insert})"""
         ).collect()

     @classmethod
@@ -108,7 +125,12 @@ def _add_sample_model_version(self, model_name: str, version_name: str) -> model
         )

     def _add_sample_monitor(
-        self, monitor_name: str, source: str, model_version: model_version_impl.ModelVersion, segment_columns=None
+        self,
+        monitor_name: str,
+        source: str,
+        model_version: model_version_impl.ModelVersion,
+        segment_columns=None,
+        custom_metric_columns=None,
     ) -> model_monitor.ModelMonitor:
         return self.registry.add_monitor(
             name=monitor_name,
@@ -119,6 +141,7 @@ def _add_sample_monitor(
                 id_columns=["id"],
                 timestamp_column="timestamp",
                 segment_columns=segment_columns,
+                custom_metric_columns=custom_metric_columns,
             ),
             model_monitor_config=model_monitor_config.ModelMonitorConfig(
                 model_version=model_version,
@@ -318,6 +341,45 @@ def test_create_monitor_with_segment_columns_happy_path(self):
         monitor_names = [m["name"] for m in monitors]
         self.assertIn(monitor_name.upper(), monitor_names)

+    def test_create_monitor_with_custom_metric_columns(self):
+        """Test creating a monitor with valid custom_metric_columns."""
+
+        self._session.sql("ALTER SESSION SET ENABLE_MODEL_MONITOR_CUSTOM_METRICS = TRUE").collect()
+
+        source_table_name = "source_table_with_custom_metrics"
+        model_name = "model_with_custom_metrics"
+        monitor_name = "monitor_with_custom_metrics"
+
+        # Create table with custom metric columns
+        self._create_test_table(
+            f"{self._db_name}.{self._schema_name}.{source_table_name}", custom_metric_columns=["custom_metric"]
+        )
+
+        # Create model version
+        mv = self._add_sample_model_version(model_name=model_name, version_name="V1")
+
+        # Create monitor with custom metric columns - this should succeed
+        monitor = self._add_sample_monitor(
+            monitor_name=monitor_name,
+            source=source_table_name,
+            model_version=mv,
+            custom_metric_columns=["custom_metric"],
+        )
+
+        # Verify monitor was created successfully
+        self.assertEqual(monitor.name, monitor_name.upper())
+
+        # Verify it appears in the list of monitors
+        monitors = self.registry.show_model_monitors()
+        monitor_names = [m["name"] for m in monitors]
+        self.assertIn(monitor_name.upper(), monitor_names)
+
+        # Verify describe monitor shows custom metric column
+        describe_result = self._session.sql(
+            f"DESCRIBE MODEL MONITOR {self._db_name}.{self._schema_name}.{monitor_name}"
+        ).collect()
+        self.assertIn("CUSTOM_METRIC", describe_result[0]["columns"])
+
     def test_create_monitor_with_segment_columns_missing_in_source(self):
         """Test creating a monitor with invalid segment_columns should fail."""
@@ -350,6 +412,55 @@ def test_create_monitor_with_segment_columns_missing_in_source(self):
         monitor_names = [m["name"] for m in monitors]
         self.assertNotIn(monitor_name.upper(), monitor_names)

+    def test_add_drop_custom_metric_columns(self):
+        """Test adding and dropping custom metric columns."""
+
+        self._session.sql("ALTER SESSION SET ENABLE_MODEL_MONITOR_CUSTOM_METRICS = TRUE").collect()
+        source_table_name = "source_table_custom_metrics"
+        model_name = "model_custom_metrics"
+        monitor_name = "monitor_custom_metrics"
+
+        # Create table with initial custom metric columns
+        self._create_test_table(
+            f"{self._db_name}.{self._schema_name}.{source_table_name}", custom_metric_columns=["initial_metric"]
+        )
+
+        # Create model version
+        mv = self._add_sample_model_version(model_name=model_name, version_name="V1")
+
+        # Create monitor with initial custom metric columns
+        monitor = self._add_sample_monitor(
+            monitor_name=monitor_name,
+            source=source_table_name,
+            model_version=mv,
+            custom_metric_columns=["initial_metric"],
+        )
+
+        # Verify monitor was created successfully with initial custom metric
+        describe_result = self._session.sql(
+            f"DESCRIBE MODEL MONITOR {self._db_name}.{self._schema_name}.{monitor_name}"
+        ).collect()
+        self.assertIn("INITIAL_METRIC", describe_result[0]["columns"])
+
+        # Add new custom metric columns
+        monitor.add_custom_metric_column("input_feature_1")
+
+        # Verify new custom metric columns were added
+        describe_result = self._session.sql(
+            f"DESCRIBE MODEL MONITOR {self._db_name}.{self._schema_name}.{monitor_name}"
+        ).collect()
+        self.assertIn("INPUT_FEATURE_1", json.loads(describe_result[0]["columns"])["custom_metric_columns"])
+
+        # Drop a custom metric column
+        monitor.drop_custom_metric_column("initial_metric")
+
+        # Verify the custom metric column was dropped
+        describe_result = self._session.sql(
+            f"DESCRIBE MODEL MONITOR {self._db_name}.{self._schema_name}.{monitor_name}"
+        ).collect()
+        self.assertNotIn("INITIAL_METRIC", json.loads(describe_result[0]["columns"])["custom_metric_columns"])
+        self.assertIn("INITIAL_METRIC", json.loads(describe_result[0]["columns"])["numerical_columns"])
+

 if __name__ == "__main__":
     absltest.main()
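Note: both custom-metric integration tests first opt in via ALTER SESSION SET ENABLE_MODEL_MONITOR_CUSTOM_METRICS = TRUE, then read the monitor's column assignments back out of DESCRIBE MODEL MONITOR, whose single result row carries a "columns" field encoded as a JSON string. A helper in the same spirit (the function name and return shape are illustrative, not part of the patch):

    import json

    from snowflake.snowpark import Session

    def get_custom_metric_columns(session: Session, fully_qualified_monitor_name: str) -> list:
        # DESCRIBE MODEL MONITOR returns one row; its "columns" field is a JSON string
        # with keys such as "custom_metric_columns", "segment_columns", and
        # "numerical_columns", exactly as the assertions above read them.
        row = session.sql(f"DESCRIBE MODEL MONITOR {fully_qualified_monitor_name}").collect()[0]
        return json.loads(row["columns"]).get("custom_metric_columns", [])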
diff --git a/third_party/rules_conda/env.bzl b/third_party/rules_conda/env.bzl
index bf653e84..2a973c48 100644
--- a/third_party/rules_conda/env.bzl
+++ b/third_party/rules_conda/env.bzl
@@ -156,7 +156,7 @@ conda_create_rule = repository_rule(
         "python_version": attr.string(
             mandatory = True,
            doc = "The Python version to use when creating the environment.",
-            values = ["3.9", "3.10", "3.11", "3.12"],
+            values = ["3.10", "3.11", "3.12"],
         ),
         "quiet": attr.bool(
             default = True,