diff --git a/CHANGELOG.md b/CHANGELOG.md index 911183d6..53efc88c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,26 @@ # Release History +## 1.17.0 + +### Bug Fixes + +* ML Job: Added support for retrieving details of deleted jobs, including status, compute pool, and target instances. + +### Behavior Changes + +### New Features + +* Support xgboost 3.x. +* ML Job: Overhauled the `MLJob.result()` API with broader cross-version + compatibility and support for additional data types, namely: + * Pandas DataFrames + * PyArrow Tables + * NumPy arrays + * NOTE: Requires `snowflake-ml-python>=1.17.0` to be installed inside remote container environment. +* ML Job: Enabled job submission v2 by default + * Jobs submitted using v2 will automatically use the latest Container Runtime image + * v1 behavior can be restored by setting environment variable `MLRS_USE_SUBMIT_JOB_V2` to `false` + ## 1.16.0 ### Bug Fixes @@ -34,9 +55,10 @@ options = { "function_type": "TABLE_FUNCTION", "volatility": Volatility.VOLATILE, }, + }, } -```` +``` ## 1.15.0 (09-29-2025) diff --git a/bazel/BUILD.bazel b/bazel/BUILD.bazel index 4d70d4a1..768091f9 100644 --- a/bazel/BUILD.bazel +++ b/bazel/BUILD.bazel @@ -28,14 +28,6 @@ py_binary( srcs_version = "PY3", ) -py_binary( - name = "add_pytest_marks", - srcs = ["add_pytest_marks.py"], - main = "add_pytest_marks.py", - python_version = "PY3", - srcs_version = "PY3", -) - # Package group for common targets in the repo. 
package_group( name = "snowml_public_common", diff --git a/bazel/add_pytest_marks.py b/bazel/add_pytest_marks.py deleted file mode 100644 index 7aac0f03..00000000 --- a/bazel/add_pytest_marks.py +++ /dev/null @@ -1,203 +0,0 @@ -#!/usr/bin/env python3 - -import argparse -import os -import re -import subprocess -import sys -from typing import Match, Optional - - -def get_targets_with_cquery(target_pattern: str, bazel_path: str = "bazel") -> list[tuple[str, str, list[str]]]: - """ - Get py_test targets with their feature areas and source files using bazel query. - - Args: - target_pattern: bazel target pattern to search (e.g., "//tests/integ/...") - bazel_path: path to bazel executable - - Returns: - List of tuples with the following elements: - - target_name (str): full bazel target name (e.g., "//tests/integ:my_test") - - feature_area (str): feature area extracted from tags (e.g., "jobs", "model_registry") - - source_files (list[str]): list of python source file paths for the target - """ - targets_info = [] - - try: - # Query for py_test targets - cmd = [bazel_path, "query", f'kind("py_test", {target_pattern})', "--output=label"] - result = subprocess.run(cmd, capture_output=True, text=True, check=True) - targets = [line.strip() for line in result.stdout.splitlines() if line.strip()] - - for target in targets: - feature_area = get_feature_area(target, bazel_path) - if feature_area: - source_files = get_source_files(target, bazel_path) - if source_files: - targets_info.append((target, feature_area, source_files)) - - return targets_info - - except subprocess.CalledProcessError: - return [] - except Exception: - return [] - - -def get_feature_area(target: str, bazel_path: str) -> Optional[str]: - """Extract feature area from target.""" - try: - cmd = [bazel_path, "query", f"attr('tags', 'feature:.*', {target})", "--output=build"] - result = subprocess.run(cmd, capture_output=True, text=True, check=True) - - # Look for feature: tags in the build output - for line in 
result.stdout.splitlines(): - if "feature:" in line: - match = re.search(r'feature:([^"\']+)', line) - if match: - return match.group(1) - - return None - - except subprocess.CalledProcessError: - return None - - -def get_source_files(target: str, bazel_path: str = "bazel") -> list[str]: - """Get source files for target.""" - try: - cmd = [bazel_path, "query", f"attr('srcs', '.*', {target})", "--output=build"] - result = subprocess.run(cmd, capture_output=True, text=True, check=True) - - source_files = [] - target_dir = target.replace("//", "").split(":")[0] - - # Extract .py files from srcs attribute - for line in result.stdout.splitlines(): - if "srcs = [" in line or '.py"' in line: - py_files = re.findall(r'"([^"]*\.py)"', line) - for py_file in py_files: - if not py_file.startswith("//"): - # Relative path, add target directory - source_files.append(f"{target_dir}/{py_file}") - else: - # Absolute path, convert to relative - source_files.append(py_file.replace("//", "").replace(":", "/")) - - return source_files - - except subprocess.CalledProcessError: - return [] - - -def process_test_file(source_file: str, feature_area: str) -> None: - """ - Process a python test file to add pytest feature area marks. 
- * adds "import pytest" if not already present - * adds @pytest.mark.feature_area_ decorators to test classes or functions - - Args: - source_file: path to the python test file to process - feature_area: feature area name (e.g., "jobs", "model_registry") used in pytest marks - """ - if not os.path.exists(source_file): - return - - try: - with open(source_file, encoding="utf-8") as f: - content = f.read() - - original_content = content - - # Add pytest import at the top with other imports if not present - if "import pytest" not in content: - lines = content.split("\n") - # Find the first import line to insert pytest import there - insert_idx = 0 - for i, line in enumerate(lines): - stripped = line.strip() - if stripped.startswith("from ") or stripped.startswith("import "): - insert_idx = i - break - elif stripped and not stripped.startswith("#"): - # If we hit non-comment, non-import content, insert at current position - insert_idx = i - break - - lines.insert(insert_idx, "import pytest") - content = "\n".join(lines) - - # Add pytest marks using - mark = f"@pytest.mark.feature_area_{feature_area}" - - # Skip if mark is already present - if mark in content: - return - - # Mark test classes - match classes that start with "Test" OR end with "Test" - class_pattern = r"(\n+)(\s*)(class (?:Test\w*|\w*Test).*?:)$" - class_marked = False - - def add_class_mark(match: Match[str]) -> str: - nonlocal class_marked - class_marked = True - preceding_newlines = match.group(1) - indent = match.group(2) - definition = match.group(3) - # Preserve original spacing - add mark right before class definition - return f"{preceding_newlines}{indent}{mark}\n{indent}{definition}" - - new_content = re.sub(class_pattern, add_class_mark, content, flags=re.MULTILINE) - if new_content != content: - content = new_content - - # Only mark test functions if no test class was marked - if not class_marked: - func_pattern = r"^(\s*)(def test_\w*.*?:)$" - - def add_func_mark(match: Match[str]) -> str: - 
indent = match.group(1) - definition = match.group(2) - return f"{indent}{mark}\n{indent}{definition}" - - new_content = re.sub(func_pattern, add_func_mark, content, flags=re.MULTILINE) - if new_content != content: - content = new_content - - # Only write if changes were made - if content != original_content: - with open(source_file, "w", encoding="utf-8") as f: - f.write(content) - - except Exception: - pass - - -def main() -> int: - parser = argparse.ArgumentParser( - description="Add pytest marks to test files based on Bazel py_test target feature tags" - ) - parser.add_argument("--targets", default="//...", help="Bazel target pattern to process (default: //...)") - parser.add_argument("--bazel-path", default="bazel", help="Path to bazel executable") - - args = parser.parse_args() - - try: - targets_info = get_targets_with_cquery(args.targets, args.bazel_path) - - if not targets_info: - return 0 - - for _target_name, feature_area, source_files in targets_info: - for source_file in source_files: - process_test_file(source_file, feature_area) - - return 0 - - except Exception: - return 1 - - -if __name__ == "__main__": - sys.exit(main()) diff --git a/ci/build_and_run_tests.sh b/ci/build_and_run_tests.sh index d5c2cdaa..40d708dc 100755 --- a/ci/build_and_run_tests.sh +++ b/ci/build_and_run_tests.sh @@ -236,7 +236,7 @@ pushd ${SNOWML_DIR} VERSION=$(grep -oE "VERSION = \"[0-9]+\\.[0-9]+\\.[0-9]+.*\"" snowflake/ml/version.py| cut -d'"' -f2) echo "Extracted Package Version from code: ${VERSION}" -# Generate and copy auto-gen tests with pytest marks. +# Generate and copy auto-gen tests. "${BAZEL}" "${BAZEL_ADDITIONAL_STARTUP_FLAGS[@]+"${BAZEL_ADDITIONAL_STARTUP_FLAGS[@]}"}" build --config=build "${BAZEL_ADDITIONAL_BUILD_FLAGS[@]+"${BAZEL_ADDITIONAL_BUILD_FLAGS[@]}"}" //tests/integ/... 
# Rsync cannot work well with path that has drive letter in Windows, @@ -244,14 +244,8 @@ echo "Extracted Package Version from code: ${VERSION}" rsync -av --exclude '*.runfiles_manifest' --exclude '*.runfiles/**' "bazel-bin/tests" . -if [[ -n "${FEATURE_AREAS}" ]]; then - # Add pytest marks to all test files based on their feature: tags - echo "Adding pytest marks to test files..." - ${PYTHON_EXECUTABLE} bazel/add_pytest_marks.py --targets //tests/integ/... --bazel-path "${BAZEL}" -fi - # Read environments from optional_dependency_groups.bzl -groups=() +groups=("core") while IFS= read -r line; do groups+=("$line") done < <(python3 -c ' @@ -265,26 +259,35 @@ with open("bazel/platforms/optional_dependency_groups.bzl", "r") as f: print(group) ') -groups+=("core") - for i in "${!groups[@]}"; do group="${groups[$i]}" # Compare test required dependencies with wheel pkg dependencies and exclude tests if necessary EXCLUDE_TESTS=$(mktemp "${TEMP_TEST_DIR}/exclude_tests_${group}_XXXXX") - ./ci/get_excluded_tests.sh -f "${EXCLUDE_TESTS}" -m "${MODE}" -b "${BAZEL}" -e "${SF_ENV}" -g "${group}" + + # Add feature area filtering if FEATURE_AREAS is set + if [[ -n "${FEATURE_AREAS}" ]]; then + echo "Applying feature area filter: ${FEATURE_AREAS}" + ./ci/get_excluded_tests.sh -f "${EXCLUDE_TESTS}" -m "${MODE}" -b "${BAZEL}" -e "${SF_ENV}" -g "${group}" -a "${FEATURE_AREAS}" + else + ./ci/get_excluded_tests.sh -f "${EXCLUDE_TESTS}" -m "${MODE}" -b "${BAZEL}" -e "${SF_ENV}" -g "${group}" + fi # Copy tests into temp directory pushd "${TEMP_TEST_DIR}" - rsync -av --exclude-from "${EXCLUDE_TESTS}" "../${SNOWML_DIR}/tests" "${group}" + + # Copy from snowml root so exclude patterns can use "tests/integ/..." 
format + rsync -av --exclude-from "${EXCLUDE_TESTS}" --include="tests/***" --exclude="*" "../${SNOWML_DIR}/" "${group}/" popd done "${BAZEL}" "${BAZEL_ADDITIONAL_STARTUP_FLAGS[@]+"${BAZEL_ADDITIONAL_STARTUP_FLAGS[@]}"}" clean --expunge popd + # Build snowml package if [ "${ENV}" = "pip" ]; then + echo "Building snowml package: pip" # Clean build workspace rm -f "${WORKSPACE}"/*.whl @@ -340,6 +343,7 @@ if [[ "${WITH_SPCS_IMAGE}" = true ]]; then fi # Start testing +echo "Starting testing..." pushd "${TEMP_TEST_DIR}" # Set up common pytest flag @@ -350,24 +354,6 @@ COMMON_PYTEST_FLAG+=(--log-cli-level=INFO) COMMON_PYTEST_FLAG+=(-n logical) COMMON_PYTEST_FLAG+=(--timeout=3600) -# Add feature area filtering if specified -if [[ -n "${FEATURE_AREAS}" ]]; then - # Convert comma-separated list to pytest mark expression - # e.g., "jobs,core" becomes "feature_area_jobs or feature_area_core" - IFS=',' read -ra AREAS <<< "${FEATURE_AREAS}" - MARK_EXPR="" - for i in "${!AREAS[@]}"; do - area="${AREAS[$i]}" - if [[ $i -eq 0 ]]; then - MARK_EXPR="feature_area_${area}" - else - MARK_EXPR="${MARK_EXPR} or feature_area_${area}" - fi - done - COMMON_PYTEST_FLAG+=(-m "${MARK_EXPR}") - echo "Running tests with feature area filter: ${MARK_EXPR}" -fi - group_exit_codes=() group_coverage_report_files=() @@ -385,6 +371,7 @@ for i in "${!groups[@]}"; do pushd "${group}" if [ "${ENV}" = "pip" ]; then + echo "Testing with pip environment: ${group}" if [ "${WITH_SPCS_IMAGE}" = true ]; then COMMON_PYTEST_FLAG+=(-m "spcs_deployment_image and not pip_incompatible") else @@ -414,6 +401,7 @@ for i in "${!groups[@]}"; do python -m pip list # Run the tests + echo "Running tests with pytest flags: ${COMMON_PYTEST_FLAG[*]}" set +e TEST_SRCDIR="${TEMP_TEST_DIR}" python -m pytest "${COMMON_PYTEST_FLAG[@]}" tests/integ/ group_exit_codes[$i]=$? 
@@ -447,6 +435,7 @@ for i in "${!groups[@]}"; do "${_MICROMAMBA_BIN}" list -p ./testenv # Run integration tests + echo "Running tests with pytest flags: ${COMMON_PYTEST_FLAG[*]}" set +e TEST_SRCDIR="${TEMP_TEST_DIR}" ${CONDA} run -p ./testenv --no-capture-output python -m pytest "${COMMON_PYTEST_FLAG[@]}" tests/integ/ group_exit_codes[$i]=$? @@ -510,7 +499,9 @@ fi # Check all group exit codes for exit_code in "${group_exit_codes[@]}"; do - if [[ (${MODE} = "merge_gate" || ${MODE} = "quarantined" || ${WITH_SPCS_IMAGE} = "true" ) && ${exit_code} -eq 5 ]]; then + # Allow exit code 5 (no tests found) for all modes, as this is expected + # when an optional dependency group has no tests + if [[ ${exit_code} -eq 5 ]]; then continue fi if [[ ${exit_code} -ne 0 ]]; then diff --git a/ci/conda_recipe/meta.yaml b/ci/conda_recipe/meta.yaml index f48663dc..e4bc25d7 100644 --- a/ci/conda_recipe/meta.yaml +++ b/ci/conda_recipe/meta.yaml @@ -17,7 +17,7 @@ build: noarch: python package: name: snowflake-ml-python - version: 1.16.0 + version: 1.17.0 requirements: build: - python @@ -51,7 +51,7 @@ requirements: - sqlparse>=0.4,<1 - tqdm<5 - typing-extensions>=4.1.0,<5 - - xgboost>=1.7.3,<3 + - xgboost<4 - python>=3.9,<3.13 run_constrained: - altair>=5,<6 diff --git a/ci/get_excluded_tests.sh b/ci/get_excluded_tests.sh index e1b39553..bb00fff8 100755 --- a/ci/get_excluded_tests.sh +++ b/ci/get_excluded_tests.sh @@ -1,7 +1,7 @@ #!/bin/bash # Usage -# exclude_tests.sh [-b ] [-f ] [- merge_gate|continuous_run|release] +# exclude_tests.sh [-b ] [-f ] [- merge_gate|continuous_run|release] [-a ] # # Flags # -b: specify path to bazel @@ -17,6 +17,7 @@ # -g: specify the group (default: "core") Test group could be found in bazel/platforms/optional_dependency_groups.bzl. # `core` group is the default group that includes all tests that does not have a group specified. # `all` group includes all tests. 
+# -a: specify comma-separated list of feature areas to INCLUDE (exclude all others) (e.g., "core,modeling,data") set -o pipefail set -u @@ -25,7 +26,7 @@ PROG=$0 help() { local exit_code=$1 - echo "Usage: ${PROG} [-b ] [-f ] [-m merge_gate|continuous_run|quarantined] [-e ] [-g ]" + echo "Usage: ${PROG} [-b ] [-f ] [-m merge_gate|continuous_run|quarantined] [-e ] [-g ] [-a ]" exit "${exit_code}" } @@ -36,8 +37,9 @@ output_path="/tmp/files_to_exclude" mode="continuous_run" SF_ENV="prod3" group="core" +feature_areas="" -while getopts "b:f:m:e:g:h" opt; do +while getopts "b:f:m:e:g:a:h" opt; do case "${opt}" in b) bazel=${OPTARG} @@ -57,6 +59,9 @@ while getopts "b:f:m:e:g:h" opt; do g) group=${OPTARG} ;; + a) + feature_areas=${OPTARG} + ;; h) help 0 ;; @@ -158,6 +163,40 @@ if [[ $group != "all" ]]; then sort -u "${targets_to_exclude_file}.tmp" "${incompatible_targets_file}" | uniq -u >"${targets_to_exclude_file}" fi +# Handle feature area exclusions if specified +if [[ -n "${feature_areas}" ]]; then + feature_area_query_file="${working_dir}/feature_area_query" + + # Convert comma-separated feature areas to bazel query format + IFS=',' read -ra AREAS <<< "${feature_areas}" + include_conditions="" + + for area in "${AREAS[@]}"; do + # Trim whitespace + area=$(echo "${area}" | xargs) + if [[ -z "${include_conditions}" ]]; then + include_conditions="attr(tags, \"feature:${area}\", //tests/...)" + else + include_conditions="${include_conditions} + attr(tags, \"feature:${area}\", //tests/...)" + fi + done + + # Create bazel query to find py_test targets NOT in the specified feature areas + # This excludes everything that doesn't have the specified feature tags + cat >"${feature_area_query_file}" <>"${targets_to_exclude_file}" + echo "Feature area filtering: added $(echo "${feature_area_test_targets}" | wc -l) non-${feature_areas} test exclusions" + fi + +fi + excluded_test_source_rule_file=${working_dir}/excluded_test_source_rule # -- Begin of Query Rules Heredoc -- @@ 
-174,6 +213,19 @@ ${bazel} query --query_file="${excluded_test_source_rule_file}" \ awk -F// '{print $2}' | sed -e 's/:/\//g' >"${output_path}" +# Special handling for modeling tests: exclude all modeling feature area tests if feature areas specified and "modeling" not included +if [[ -n "${feature_areas}" && ",${feature_areas}," != *",modeling,"* ]]; then + modeling_query_file="${working_dir}/modeling_query" + cat >"${modeling_query_file}" <>"${output_path}" + fi +fi + # This is for modeling model tests that are automatically generated and not part of the build. if [[ -n "${incompatible_targets}" ]]; then echo "${incompatible_targets}" | sed 's|^//||' | sed 's|:|/|g' | sed 's|$|.py|' >>"${output_path}" @@ -189,6 +241,8 @@ grep ':' "ci/targets/quarantine/${SF_ENV}.txt" | \ fi echo "Tests getting excluded:" +# Sort and deduplicate the exclusion file +sort -u "${output_path}" -o "${output_path}" cat "${output_path}" echo "Done running ${PROG}" diff --git a/ci/targets/quarantine/prod3.txt b/ci/targets/quarantine/prod3.txt index 5382574b..30f2a9ff 100644 --- a/ci/targets/quarantine/prod3.txt +++ b/ci/targets/quarantine/prod3.txt @@ -5,6 +5,7 @@ //tests/integ/snowflake/ml/modeling/linear_model:logistic_regression_test //tests/integ/snowflake/ml/registry/services:registry_huggingface_pipeline_model_deployment_test //tests/integ/snowflake/ml/registry/services:registry_custom_model_batch_inference_test +//tests/integ/snowflake/ml/registry/services:registry_batch_inference_case_sensitivity_test //tests/integ/snowflake/ml/registry/services:registry_sentence_transformers_batch_inference_test //tests/integ/snowflake/ml/registry/services:registry_lightgbm_batch_inference_test //tests/integ/snowflake/ml/registry/services:registry_sklearn_batch_inference_test diff --git a/requirements.yml b/requirements.yml index 874e5740..cf86bc60 100644 --- a/requirements.yml +++ b/requirements.yml @@ -292,7 +292,7 @@ version_requirements: '>=4.1.0,<5' - name: xgboost dev_version: 2.1.4 - 
version_requirements: '>=1.7.3,<3' + version_requirements: <4 tags: - build_essential - name: werkzeug diff --git a/snowflake/ml/_internal/human_readable_id/adjectives.txt b/snowflake/ml/_internal/human_readable_id/adjectives.txt index be9f0f1e..3c18adf2 100644 --- a/snowflake/ml/_internal/human_readable_id/adjectives.txt +++ b/snowflake/ml/_internal/human_readable_id/adjectives.txt @@ -1,3 +1,4 @@ +aerial afraid ancient angry @@ -26,7 +27,6 @@ dull empty evil fast -fat fluffy foolish fresh @@ -57,10 +57,10 @@ lovely lucky massive mean +metallic mighty modern moody -nasty neat nervous new @@ -85,7 +85,6 @@ rotten rude selfish serious -shaggy sharp short shy @@ -96,14 +95,15 @@ slippery smart smooth soft +solid sour spicy splendid spotty +squishy stale strange strong -stupid sweet swift tall @@ -116,7 +116,6 @@ tidy tiny tough tricky -ugly warm weak wet @@ -124,5 +123,6 @@ wicked wise witty wonderful +wooden yellow young diff --git a/snowflake/ml/_internal/human_readable_id/animals.txt b/snowflake/ml/_internal/human_readable_id/animals.txt index efe28601..249b14f3 100644 --- a/snowflake/ml/_internal/human_readable_id/animals.txt +++ b/snowflake/ml/_internal/human_readable_id/animals.txt @@ -1,10 +1,9 @@ anaconda ant -ape -baboon badger bat bear +beetle bird bobcat bulldog @@ -73,7 +72,6 @@ lobster mayfly mamba mole -monkey moose moth mouse @@ -114,6 +112,7 @@ swan termite tiger treefrog +tuna turkey turtle vampirebat @@ -126,3 +125,4 @@ worm yak yeti zebra +zebrafish diff --git a/snowflake/ml/jobs/BUILD.bazel b/snowflake/ml/jobs/BUILD.bazel index 42cd79a7..b5ba770b 100644 --- a/snowflake/ml/jobs/BUILD.bazel +++ b/snowflake/ml/jobs/BUILD.bazel @@ -15,6 +15,7 @@ py_library( srcs = ["job.py"], deps = [ "//snowflake/ml/_internal:telemetry", + "//snowflake/ml/jobs/_interop:interop", "//snowflake/ml/jobs/_utils:job_utils", ], ) diff --git a/snowflake/ml/jobs/__init__.py b/snowflake/ml/jobs/__init__.py index 17a9d5b5..ac91976d 100644 --- a/snowflake/ml/jobs/__init__.py +++ 
b/snowflake/ml/jobs/__init__.py @@ -1,3 +1,4 @@ +from snowflake.ml.jobs._interop.exception_utils import install_exception_display_hooks from snowflake.ml.jobs._utils.types import JOB_STATUS from snowflake.ml.jobs.decorators import remote from snowflake.ml.jobs.job import MLJob @@ -10,6 +11,9 @@ submit_from_stage, ) +# Initialize exception display hooks for remote job error handling +install_exception_display_hooks() + __all__ = [ "remote", "submit_file", diff --git a/snowflake/ml/jobs/_interop/BUILD.bazel b/snowflake/ml/jobs/_interop/BUILD.bazel new file mode 100644 index 00000000..f579a2c0 --- /dev/null +++ b/snowflake/ml/jobs/_interop/BUILD.bazel @@ -0,0 +1,106 @@ +load("//bazel:py_rules.bzl", "py_library", "py_test") + +package(default_visibility = ["//visibility:public"]) + +py_library( + name = "dto_schema", + srcs = ["dto_schema.py"], +) + +py_library( + name = "results", + srcs = ["results.py"], +) + +py_library( + name = "data_utils", + srcs = ["data_utils.py"], + deps = [ + ":dto_schema", + ], +) + +py_library( + name = "protocols", + srcs = ["protocols.py"], + deps = [ + ":data_utils", + ":dto_schema", + ], +) + +py_library( + name = "exception_utils", + srcs = ["exception_utils.py"], +) + +py_library( + name = "legacy", + srcs = ["legacy.py"], + deps = [ + ":exception_utils", + ":results", + ], +) + +py_library( + name = "utils", + srcs = ["utils.py"], + deps = [ + ":data_utils", + ":dto_schema", + ":exception_utils", + ":legacy", + ":protocols", + ":results", + ], +) + +py_test( + name = "protocols_test", + srcs = ["protocols_test.py"], + tags = ["feature:jobs"], + deps = [ + ":protocols", + ], +) + +py_test( + name = "exception_utils_test", + srcs = ["exception_utils_test.py"], + tags = ["feature:jobs"], + deps = [ + ":exception_utils", + ], +) + +py_test( + name = "utils_test", + srcs = ["utils_test.py"], + tags = ["feature:jobs"], + deps = [ + ":utils", + ], +) + +py_test( + name = "legacy_test", + srcs = ["legacy_test.py"], + tags = 
["feature:jobs"], + deps = [ + ":exception_utils", + ":legacy", + ], +) + +py_library( + name = "interop", + srcs = [ + "__init__.py", + ], + deps = [ + ":dto_schema", + ":protocols", + ":utils", + ], +) diff --git a/snowflake/ml/jobs/_interop/__init__.py b/snowflake/ml/jobs/_interop/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/snowflake/ml/jobs/_interop/data_utils.py b/snowflake/ml/jobs/_interop/data_utils.py new file mode 100644 index 00000000..984dfae5 --- /dev/null +++ b/snowflake/ml/jobs/_interop/data_utils.py @@ -0,0 +1,124 @@ +import io +import json +from typing import Any, Literal, Optional, Protocol, Union, cast, overload + +from snowflake import snowpark +from snowflake.ml.jobs._interop import dto_schema + + +class StageFileWriter(io.IOBase): + """ + A context manager IOBase implementation that proxies writes to an internal BytesIO + and uploads to Snowflake stage on close. + """ + + def __init__(self, session: snowpark.Session, path: str) -> None: + self._session = session + self._path = path + self._buffer = io.BytesIO() + self._closed = False + self._exception_occurred = False + + def write(self, data: Union[bytes, bytearray]) -> int: + """Write data to the internal buffer.""" + if self._closed: + raise ValueError("I/O operation on closed file") + return self._buffer.write(data) + + def close(self, write_contents: bool = True) -> None: + """Close the file and upload the buffer contents to the stage.""" + if not self._closed: + # Only upload if buffer has content and no exception occurred + if write_contents and self._buffer.tell() > 0: + self._buffer.seek(0) + self._session.file.put_stream(self._buffer, self._path) + self._buffer.close() + self._closed = True + + def __enter__(self) -> "StageFileWriter": + return self + + def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: + exception_occurred = exc_type is not None + self.close(write_contents=not exception_occurred) + + @property + def closed(self) -> bool: 
+ return self._closed + + def writable(self) -> bool: + return not self._closed + + def readable(self) -> bool: + return False + + def seekable(self) -> bool: + return not self._closed + + +def _is_stage_path(path: str) -> bool: + return path.startswith("@") or path.startswith("snow://") + + +def open_stream(path: str, mode: str = "rb", session: Optional[snowpark.Session] = None) -> io.IOBase: + if _is_stage_path(path): + if session is None: + raise ValueError("Session is required when opening a stage path") + if "r" in mode: + stream: io.IOBase = session.file.get_stream(path) # type: ignore[assignment] + return stream + elif "w" in mode: + return StageFileWriter(session, path) + else: + raise ValueError(f"Unsupported mode '{mode}' for stage path") + else: + result: io.IOBase = open(path, mode) # type: ignore[assignment] + return result + + +class DtoCodec(Protocol): + @overload + @staticmethod + def decode(stream: io.IOBase, as_dict: Literal[True]) -> dict[str, Any]: + ... + + @overload + @staticmethod + def decode(stream: io.IOBase, as_dict: Literal[False] = False) -> dto_schema.ResultDTO: + ... + + @staticmethod + def decode(stream: io.IOBase, as_dict: bool = False) -> Union[dto_schema.ResultDTO, dict[str, Any]]: + pass + + @staticmethod + def encode(dto: dto_schema.ResultDTO) -> bytes: + pass + + +class JsonDtoCodec(DtoCodec): + @overload + @staticmethod + def decode(stream: io.IOBase, as_dict: Literal[True]) -> dict[str, Any]: + ... + + @overload + @staticmethod + def decode(stream: io.IOBase, as_dict: Literal[False] = False) -> dto_schema.ResultDTO: + ... 
+ + @staticmethod + def decode(stream: io.IOBase, as_dict: bool = False) -> Union[dto_schema.ResultDTO, dict[str, Any]]: + data = cast(dict[str, Any], json.load(stream)) + if as_dict: + return data + return dto_schema.ResultDTO.model_validate(data) + + @staticmethod + def encode(dto: dto_schema.ResultDTO) -> bytes: + # Temporarily extract the value to avoid accidentally applying model_dump() on it + result_value = dto.value + dto.value = None # Clear value to avoid serializing it in the model_dump + result_dict = dto.model_dump() + result_dict["value"] = result_value # Put back the value + return json.dumps(result_dict).encode("utf-8") diff --git a/snowflake/ml/jobs/_interop/dto_schema.py b/snowflake/ml/jobs/_interop/dto_schema.py new file mode 100644 index 00000000..c4dbee2f --- /dev/null +++ b/snowflake/ml/jobs/_interop/dto_schema.py @@ -0,0 +1,95 @@ +from typing import Any, Optional, Union + +from pydantic import BaseModel, model_validator +from typing_extensions import NotRequired, TypedDict + + +class BinaryManifest(TypedDict): + """ + Binary data manifest schema. + Contains one of: path, bytes, or base64 for the serialized data. + """ + + path: NotRequired[str] # Path to file + bytes: NotRequired[bytes] # In-line byte string (not supported with JSON codec) + base64: NotRequired[str] # Base64 encoded string + + +class ParquetManifest(TypedDict): + """Protocol manifest schema for parquet files.""" + + paths: list[str] # File paths + + +# Union type for all manifest types, including catch-all dict[str, Any] for backward compatibility +PayloadManifest = Union[BinaryManifest, ParquetManifest, dict[str, Any]] + + +class ProtocolInfo(BaseModel): + """ + The protocol used to serialize the result and the manifest of the result. 
+ """ + + name: str + version: Optional[str] = None + metadata: Optional[dict[str, str]] = None + manifest: Optional[PayloadManifest] = None + + def __str__(self) -> str: + result = self.name + if self.version: + result += f"-{self.version}" + return result + + def with_manifest(self, manifest: PayloadManifest) -> "ProtocolInfo": + """ + Return a new ProtocolInfo object with the manifest. + """ + return ProtocolInfo( + name=self.name, + version=self.version, + metadata=self.metadata, + manifest=manifest, + ) + + +class ResultMetadata(BaseModel): + """ + The metadata of a result. + """ + + type: str + repr: str + + +class ExceptionMetadata(ResultMetadata): + message: str + traceback: str + + +class ResultDTO(BaseModel): + """ + A JSON representation of an execution result. + + Args: + success: Whether the execution was successful. + value: The value of the execution or the exception if the execution failed. + protocol: The protocol used to serialize the result. + metadata: The metadata of the result. 
+ """ + + success: bool + value: Optional[Any] = None + protocol: Optional[ProtocolInfo] = None + metadata: Optional[Union[ResultMetadata, ExceptionMetadata]] = None + serialize_error: Optional[str] = None + + @model_validator(mode="before") + @classmethod + def validate_fields(cls, data: Any) -> Any: + """Ensure at least one of value, protocol, or metadata keys is specified.""" + if isinstance(data, dict): + required_fields = {"value", "protocol", "metadata"} + if not any(field in data for field in required_fields): + raise ValueError("At least one of 'value', 'protocol', or 'metadata' must be specified") + return data diff --git a/snowflake/ml/jobs/_utils/interop_utils.py b/snowflake/ml/jobs/_interop/exception_utils.py similarity index 63% rename from snowflake/ml/jobs/_utils/interop_utils.py rename to snowflake/ml/jobs/_interop/exception_utils.py index 688c2907..fc35f9e4 100644 --- a/snowflake/ml/jobs/_utils/interop_utils.py +++ b/snowflake/ml/jobs/_interop/exception_utils.py @@ -1,19 +1,12 @@ import builtins import functools import importlib -import json -import os -import pickle import re import sys import traceback from collections import namedtuple -from dataclasses import dataclass from types import TracebackType -from typing import Any, Callable, Optional, Union, cast - -from snowflake import snowpark -from snowflake.snowpark import exceptions as sp_exceptions +from typing import Any, Callable, Optional, cast _TRACEBACK_ENTRY_PATTERN = re.compile( r'File "(?P[^"]+)", line (?P\d+), in (?P[^\n]+)(?:\n(?!^\s*File)^\s*(?P[^\n]+))?\n', @@ -21,175 +14,46 @@ ) _REMOTE_ERROR_ATTR_NAME = "_remote_error" -RemoteError = namedtuple("RemoteError", ["exc_type", "exc_msg", "exc_tb"]) - - -@dataclass(frozen=True) -class ExecutionResult: - result: Any = None - exception: Optional[BaseException] = None - - @property - def success(self) -> bool: - return self.exception is None - - def to_dict(self) -> dict[str, Any]: - """Return the serializable dictionary.""" - if 
isinstance(self.exception, BaseException): - exc_type = type(self.exception) - return { - "success": False, - "exc_type": f"{exc_type.__module__}.{exc_type.__name__}", - "exc_value": self.exception, - "exc_tb": "".join(traceback.format_tb(self.exception.__traceback__)), - } - return { - "success": True, - "result_type": type(self.result).__qualname__, - "result": self.result, - } - - @classmethod - def from_dict(cls, result_dict: dict[str, Any]) -> "ExecutionResult": - if not isinstance(result_dict.get("success"), bool): - raise ValueError("Invalid result dictionary") - - if result_dict["success"]: - # Load successful result - return cls(result=result_dict.get("result")) - - # Load exception - exc_type = result_dict.get("exc_type", "RuntimeError") - exc_value = result_dict.get("exc_value", "Unknown error") - exc_tb = result_dict.get("exc_tb", "") - return cls(exception=load_exception(exc_type, exc_value, exc_tb)) - - -def fetch_result(session: snowpark.Session, result_path: str) -> ExecutionResult: - """ - Fetch the serialized result from the specified path. +RemoteErrorInfo = namedtuple("RemoteErrorInfo", ["exc_type", "exc_msg", "exc_tb"]) - Args: - session: Snowpark Session to use for file operations. - result_path: The path to the serialized result file. - Returns: - A dictionary containing the execution result if available, None otherwise. +class RemoteError(RuntimeError): + """Base exception for errors from remote execution environment which could not be reconstructed locally.""" - Raises: - RuntimeError: If both pickle and JSON result retrieval fail. 
- """ - try: - # TODO: Check if file exists - with session.file.get_stream(result_path) as result_stream: - return ExecutionResult.from_dict(pickle.load(result_stream)) - except ( - sp_exceptions.SnowparkSQLException, - pickle.UnpicklingError, - TypeError, - ImportError, - AttributeError, - MemoryError, - ) as pickle_error: - # Fall back to JSON result if loading pickled result fails for any reason - try: - result_json_path = os.path.splitext(result_path)[0] + ".json" - with session.file.get_stream(result_json_path) as result_stream: - return ExecutionResult.from_dict(json.load(result_stream)) - except Exception as json_error: - # Both pickle and JSON failed - provide helpful error message - raise RuntimeError(_fetch_result_error_message(pickle_error, result_path, json_error)) from pickle_error - - -def _fetch_result_error_message(error: Exception, result_path: str, json_error: Optional[Exception] = None) -> str: - """Create helpful error messages for common result retrieval failures.""" - - # Package import issues - if isinstance(error, ImportError): - return f"Failed to retrieve job result: Package not installed in your local environment. Error: {str(error)}" - - # Package versions differ between runtime and local environment - if isinstance(error, AttributeError): - return f"Failed to retrieve job result: Package version mismatch. Error: {str(error)}" - - # Serialization issues - if isinstance(error, TypeError): - return f"Failed to retrieve job result: Non-serializable objects were returned. 
Error: {str(error)}" - - # Python version pickling incompatibility - if isinstance(error, pickle.UnpicklingError) and "protocol" in str(error).lower(): - # TODO: Update this once we support different Python versions - client_version = f"Python {sys.version_info.major}.{sys.version_info.minor}" - runtime_version = "Python 3.10" - return ( - f"Failed to retrieve job result: Python version mismatch - job ran on {runtime_version}, " - f"local environment using Python {client_version}. Error: {str(error)}" - ) - # File access issues - if isinstance(error, sp_exceptions.SnowparkSQLException): - if "not found" in str(error).lower() or "does not exist" in str(error).lower(): - return ( - f"Failed to retrieve job result: No result file found. Check job.get_logs() for execution " - f"errors. Error: {str(error)}" - ) - else: - return f"Failed to retrieve job result: Cannot access result file. Error: {str(error)}" - - if isinstance(error, MemoryError): - return f"Failed to retrieve job result: Result too large for memory. Error: {str(error)}" - - # Generic fallback - base_message = f"Failed to retrieve job result: {str(error)}" - if json_error: - base_message += f" (JSON fallback also failed: {str(json_error)})" - return base_message - - -def load_exception(exc_type_name: str, exc_value: Union[Exception, str], exc_tb: str) -> Exception: - """ - Create an exception with a string-formatted traceback. - - When this exception is raised and not caught, it will display the original traceback. - When caught, it behaves like a regular exception without showing the traceback. - - Args: - exc_type_name: Name of the exception type (e.g., 'ValueError', 'RuntimeError') - exc_value: The deserialized exception value or exception string (i.e. 
message) - exc_tb: String representation of the traceback +def build_exception(type_str: str, message: str, traceback: str, original_repr: Optional[str] = None) -> BaseException: + """Build an exception from metadata, attaching remote error info.""" + if not original_repr: + original_repr = f"{type_str}('{message}')" + try: + ex = reconstruct_exception(type_str=type_str, message=message) + except Exception as e: + # Fallback to a generic error type if reconstruction fails + ex = RemoteError(original_repr) + ex.__cause__ = e + return attach_remote_error_info(ex, type_str, message, traceback) - Returns: - An exception object with the original traceback information - # noqa: DAR401 - """ - if isinstance(exc_value, Exception): - exception = exc_value - else: - # Try to load the original exception type if possible - try: - # First check built-in exceptions - exc_type = getattr(builtins, exc_type_name, None) - if exc_type is None and "." in exc_type_name: - # Try to import from module path if it's a qualified name - module_path, class_name = exc_type_name.rsplit(".", 1) - module = importlib.import_module(module_path) - exc_type = getattr(module, class_name) - if exc_type is None or not issubclass(exc_type, Exception): - raise TypeError(f"{exc_type_name} is not a known exception type") - # Create the exception instance - exception = exc_type(exc_value) - except (ImportError, AttributeError, TypeError): - # Fall back to a generic exception - exception = RuntimeError( - f"Exception deserialization failed, original exception: {exc_type_name}: {exc_value}" - ) +def reconstruct_exception(type_str: str, message: str) -> BaseException: + """Best effort reconstruction of an exception from metadata.""" + try: + type_split = type_str.rsplit(".", 1) + if len(type_split) == 1: + module = builtins + else: + module = importlib.import_module(type_split[0]) + exc_type = getattr(module, type_split[-1]) + except (ImportError, AttributeError): + raise ModuleNotFoundError( + f"Unrecognized 
exception type '{type_str}', likely due to a missing or unavailable package" + ) from None - # Attach the traceback information to the exception - return _attach_remote_error_info(exception, exc_type_name, str(exc_value), exc_tb) + if not issubclass(exc_type, BaseException): + raise TypeError(f"Imported type {type_str} is not a known exception type, possibly due to a name conflict") + return cast(BaseException, exc_type(message)) -def _attach_remote_error_info(ex: Exception, exc_type: str, exc_msg: str, traceback_str: str) -> Exception: +def attach_remote_error_info(ex: BaseException, exc_type: str, exc_msg: str, traceback_str: str) -> BaseException: """ Attach a string-formatted traceback to an exception. @@ -207,11 +71,11 @@ def _attach_remote_error_info(ex: Exception, exc_type: str, exc_msg: str, traceb """ # Store the traceback information exc_type = exc_type.rsplit(".", 1)[-1] # Remove module path - setattr(ex, _REMOTE_ERROR_ATTR_NAME, RemoteError(exc_type=exc_type, exc_msg=exc_msg, exc_tb=traceback_str)) + setattr(ex, _REMOTE_ERROR_ATTR_NAME, RemoteErrorInfo(exc_type=exc_type, exc_msg=exc_msg, exc_tb=traceback_str)) return ex -def _retrieve_remote_error_info(ex: Optional[BaseException]) -> Optional[RemoteError]: +def retrieve_remote_error_info(ex: Optional[BaseException]) -> Optional[RemoteErrorInfo]: """ Retrieve the string-formatted traceback from an exception if it exists. @@ -285,7 +149,7 @@ def _install_sys_excepthook() -> None: sys.excepthook is the global hook that Python calls when an unhandled exception occurs. By default it prints the exception type, message and traceback to stderr. - We override sys.excepthook to intercept exceptions that contain our special RemoteError + We override sys.excepthook to intercept exceptions that contain our special RemoteErrorInfo attribute. These exceptions come from deserialized remote execution results and contain the original traceback information from where they occurred. 
@@ -327,7 +191,7 @@ def custom_excepthook( "\nDuring handling of the above exception, another exception occurred:\n", file=sys.stderr ) - if (remote_err := _retrieve_remote_error_info(exc_value)) and isinstance(remote_err, RemoteError): + if (remote_err := retrieve_remote_error_info(exc_value)) and isinstance(remote_err, RemoteErrorInfo): # Display stored traceback for deserialized exceptions print("Traceback (from remote execution):", file=sys.stderr) # noqa: T201 print(remote_err.exc_tb, end="", file=sys.stderr) # noqa: T201 @@ -408,7 +272,7 @@ def custom_format_exception_as_a_whole( tb_offset: Optional[int], **kwargs: Any, ) -> list[list[str]]: - if (remote_err := _retrieve_remote_error_info(evalue)) and isinstance(remote_err, RemoteError): + if (remote_err := retrieve_remote_error_info(evalue)) and isinstance(remote_err, RemoteErrorInfo): # Implementation forked from IPython.core.ultratb.VerboseTB.format_exception_as_a_whole head = self.prepare_header(remote_err.exc_type, long_version=False).replace( "(most recent call last)", @@ -448,7 +312,7 @@ def structured_traceback( tb_offset: Optional[int] = None, **kwargs: Any, ) -> list[str]: - if (remote_err := _retrieve_remote_error_info(evalue)) and isinstance(remote_err, RemoteError): + if (remote_err := retrieve_remote_error_info(evalue)) and isinstance(remote_err, RemoteErrorInfo): tb_list = [ (m.group("filename"), m.group("lineno"), m.group("name"), m.group("line")) for m in re.finditer(_TRACEBACK_ENTRY_PATTERN, remote_err.exc_tb or "") @@ -493,9 +357,16 @@ def _uninstall_ipython_hook() -> None: def install_exception_display_hooks() -> None: - if not _install_ipython_hook(): - _install_sys_excepthook() + """Install custom exception display hooks for improved remote error reporting. + This function should be called once during package initialization to set up + enhanced error handling for remote job execution errors. 
The hooks will: -# ------ Install the custom traceback hooks by default ------ # -install_exception_display_hooks() + - Display original remote tracebacks instead of local deserialization traces + - Work in both standard Python and IPython/Jupyter environments + - Safely fall back to original behavior if errors occur + + Note: This function is idempotent and safe to call multiple times. + """ + if not _install_ipython_hook(): + _install_sys_excepthook() diff --git a/snowflake/ml/jobs/_utils/interop_utils_test.py b/snowflake/ml/jobs/_interop/exception_utils_test.py similarity index 67% rename from snowflake/ml/jobs/_utils/interop_utils_test.py rename to snowflake/ml/jobs/_interop/exception_utils_test.py index bd91a4c5..0334e105 100644 --- a/snowflake/ml/jobs/_utils/interop_utils_test.py +++ b/snowflake/ml/jobs/_interop/exception_utils_test.py @@ -1,73 +1,110 @@ import sys from types import TracebackType -from typing import Any +from typing import Any, Optional from unittest import mock -from absl.testing import absltest - -from snowflake.ml.jobs._utils import interop_utils -from snowflake.snowpark import exceptions - - -class TestInteropUtils(absltest.TestCase): - def test_load_exception_with_builtin_exception(self) -> None: - """Test loading a built-in exception type.""" - exc = interop_utils.load_exception("ValueError", "test error message", "traceback info") - self.assertIsInstance(exc, ValueError) - self.assertEqual(str(exc), "test error message") - - remote_error = interop_utils._retrieve_remote_error_info(exc) - assert remote_error is not None - self.assertEqual(remote_error.exc_type, "ValueError") - self.assertEqual(remote_error.exc_msg, "test error message") - self.assertEqual(remote_error.exc_tb, "traceback info") - - def test_load_exception_with_custom_exception_name(self) -> None: - """Test loading an exception with a custom exception name.""" - # Create a non-existent exception type name - exc = interop_utils.load_exception("NonExistentError", "custom 
error", "traceback info") - self.assertIsInstance(exc, RuntimeError) - self.assertTrue("Exception deserialization failed" in str(exc)) - - remote_error = interop_utils._retrieve_remote_error_info(exc) - assert remote_error is not None - self.assertEqual(remote_error.exc_type, "NonExistentError") - self.assertEqual(remote_error.exc_msg, "custom error") - self.assertEqual(remote_error.exc_tb, "traceback info") - - def test_load_exception_with_qualified_name(self) -> None: - """Test loading an exception with a qualified name.""" - # Use a common exception from a module - exc_type = exceptions.SnowparkClientException - exc = interop_utils.load_exception(f"{exc_type.__module__}.{exc_type.__name__}", "mock error", "traceback info") - self.assertIsInstance(exc, exceptions.SnowparkClientException) - self.assertEqual(str(exc), "mock error") - - remote_error = interop_utils._retrieve_remote_error_info(exc) - assert remote_error is not None - self.assertEqual(remote_error.exc_type, "SnowparkClientException") - self.assertEqual(remote_error.exc_msg, "mock error") - self.assertEqual(remote_error.exc_tb, "traceback info") - - def test_load_exception_with_exception_instance(self) -> None: - """Test loading with an existing exception instance.""" - original_exc = ValueError("original error") - exc = interop_utils.load_exception("ValueError", original_exc, "traceback info") - self.assertIs(exc, original_exc) # Should be the same object - - remote_error = interop_utils._retrieve_remote_error_info(exc) - assert remote_error is not None - self.assertEqual(remote_error.exc_type, "ValueError") - self.assertEqual(remote_error.exc_msg, "original error") - self.assertEqual(remote_error.exc_tb, "traceback info") +from absl.testing import absltest, parameterized + +from snowflake.ml.jobs._interop import exception_utils +from snowflake.snowpark import exceptions as sp_exceptions + + +class ComplexError(Exception): + def __init__(self, message: str, code: int) -> None: + 
super().__init__(message) + self.code = code + + def __repr__(self) -> str: + return f"ComplexError(message={self.args[0]!r}, code={self.code})" + + +class TestExceptionUtils(parameterized.TestCase): + @classmethod + def setUpClass(cls) -> None: + super().setUpClass() + + # Ensure hooks are installed for testing + exception_utils.install_exception_display_hooks() + + @parameterized.named_parameters( # type: ignore[misc] + ("value_error", "ValueError", "test error message", "traceback info", None, ValueError("test error message")), + ( + "not_implemented_error", + "NotImplementedError", + "test error message", + "traceback info", + None, + NotImplementedError("test error message"), + ), + ( + "snowpark_error", + "snowflake.snowpark.exceptions.SnowparkSQLException", + "test error message", + "traceback info", + None, + sp_exceptions.SnowparkSQLException("test error message"), + ), + ( + "not_exist_error", + "NonExistentError", + "custom error", + "traceback info", + None, + exception_utils.RemoteError("NonExistentError('custom error')"), + ), + ( + "complex_ctor", + "__main__.ComplexError", + "custom error", + "traceback info", + repr(ComplexError("Execution failed with error: custom error", 100)), + exception_utils.RemoteError(repr(ComplexError("Execution failed with error: custom error", 100))), + ), + ( + "custom_repr_NonExistentError", + "NonExistentError", + "custom error", + "traceback info", + "NonExistentError with custom repr: custom error", + exception_utils.RemoteError("NonExistentError with custom repr: custom error"), + ), + ( + "custom_repr_complex_ctor", + "__main__.ComplexError", + "custom error", + "traceback info", + "ComplexError with custom repr: custom error", + exception_utils.RemoteError("ComplexError with custom repr: custom error"), + ), + ) + def test_build_exception( + self, + exc_type: str, + exc_msg: str, + exc_tb: str, + exc_repr: Optional[str], + expected: BaseException, + ) -> None: + exc_value = exception_utils.build_exception( + 
type_str=exc_type, + message=exc_msg, + traceback=exc_tb, + original_repr=exc_repr, + ) + self.assertEqual(type(exc_value), type(expected)) + self.assertEqual(str(exc_value), str(expected)) + self.assertEqual( + exception_utils.retrieve_remote_error_info(exc_value), + exception_utils.RemoteErrorInfo(exc_type.rsplit(".", 1)[-1], exc_msg, exc_tb), + ) def test_attach_and_retrieve_traceback(self) -> None: """Test attaching and retrieving a traceback from an exception.""" exc = ValueError("test error") - interop_utils._attach_remote_error_info(exc, type(exc).__name__, str(exc), "sample traceback") + exception_utils.attach_remote_error_info(exc, type(exc).__name__, str(exc), "sample traceback") # Test retrieval - remote_error = interop_utils._retrieve_remote_error_info(exc) + remote_error = exception_utils.retrieve_remote_error_info(exc) assert remote_error is not None self.assertEqual(remote_error.exc_type, "ValueError") self.assertEqual(remote_error.exc_msg, "test error") @@ -75,7 +112,7 @@ def test_attach_and_retrieve_traceback(self) -> None: # Test retrieval on exception without traceback exc2 = RuntimeError("no traceback") - self.assertIsNone(interop_utils._retrieve_remote_error_info(exc2)) + self.assertIsNone(exception_utils.retrieve_remote_error_info(exc2)) def test_excepthook_installation(self) -> None: """Test that the excepthook is installed correctly.""" @@ -91,7 +128,7 @@ def test_uninstall_sys_excepthook(self) -> None: custom_excepthook = sys.excepthook # Uninstall our custom excepthook - interop_utils._uninstall_sys_excepthook() + exception_utils._uninstall_sys_excepthook() # Verify the original excepthook is restored self.assertEqual(sys.excepthook, original_excepthook) @@ -136,7 +173,7 @@ def test_uninstall_ipython_hook(self) -> None: IPython.core.ultratb.ListTB = mock_list_tb # Call the uninstall function - interop_utils._uninstall_ipython_hook() + exception_utils._uninstall_ipython_hook() # Verify that the original methods were restored 
self.assertEqual(mock_verbose_tb.format_exception_as_a_whole, original_format_exception) @@ -153,7 +190,7 @@ def test_revert_func_wrapper(self) -> None: uninstall_func = mock.MagicMock() # Create the wrapped function - wrapped_func = interop_utils._revert_func_wrapper(patched_func, original_func, uninstall_func) + wrapped_func = exception_utils._revert_func_wrapper(patched_func, original_func, uninstall_func) # Call the wrapped function result = wrapped_func("arg1", kwarg1="value1") @@ -181,8 +218,8 @@ def failing_custom_excepthook( # Install our test hooks sys._original_excepthook = mock_original_excepthook # type: ignore[attr-defined] - sys.excepthook = interop_utils._revert_func_wrapper( - failing_custom_excepthook, mock_original_excepthook, interop_utils._uninstall_sys_excepthook + sys.excepthook = exception_utils._revert_func_wrapper( + failing_custom_excepthook, mock_original_excepthook, exception_utils._uninstall_sys_excepthook ) # Trigger the excepthook with an exception @@ -225,13 +262,13 @@ def test_ipython_hook_fallback(self) -> None: # Setup the class mocks with original methods saved mock_verbose_tb._original_format_exception_as_a_whole = original_format_exception - mock_verbose_tb.format_exception_as_a_whole = interop_utils._revert_func_wrapper( - failing_format_exception, original_format_exception, interop_utils._uninstall_ipython_hook + mock_verbose_tb.format_exception_as_a_whole = exception_utils._revert_func_wrapper( + failing_format_exception, original_format_exception, exception_utils._uninstall_ipython_hook ) mock_list_tb._original_structured_traceback = original_structured_traceback - mock_list_tb.structured_traceback = interop_utils._revert_func_wrapper( - failing_structured_traceback, original_structured_traceback, interop_utils._uninstall_ipython_hook + mock_list_tb.structured_traceback = exception_utils._revert_func_wrapper( + failing_structured_traceback, original_structured_traceback, exception_utils._uninstall_ipython_hook ) # Assign to 
IPython mock @@ -250,8 +287,8 @@ def test_ipython_hook_fallback(self) -> None: # Reset for second test mock_verbose_tb._original_format_exception_as_a_whole = original_format_exception - mock_verbose_tb.format_exception_as_a_whole = interop_utils._revert_func_wrapper( - failing_format_exception, original_format_exception, interop_utils._uninstall_ipython_hook + mock_verbose_tb.format_exception_as_a_whole = exception_utils._revert_func_wrapper( + failing_format_exception, original_format_exception, exception_utils._uninstall_ipython_hook ) # Test ListTB structured_traceback fallback diff --git a/snowflake/ml/jobs/_interop/legacy.py b/snowflake/ml/jobs/_interop/legacy.py new file mode 100644 index 00000000..5a2d734d --- /dev/null +++ b/snowflake/ml/jobs/_interop/legacy.py @@ -0,0 +1,225 @@ +"""Legacy result serialization protocol support for ML Jobs. + +This module provides backward compatibility with the result serialization protocol used by +mljob_launcher.py prior to snowflake-ml-python>=1.17.0 + +LEGACY PROTOCOL (v1): +--------------------- +The old serialization protocol (save_mljob_result_v1 in mljob_launcher.py) worked as follows: + +1. Results were stored in an ExecutionResult dataclass with two optional fields: + - result: Any = None # For successful executions + - exception: BaseException = None # For failed executions + +2. The ExecutionResult was converted to a dictionary via to_dict(): + Success case: + {"success": True, "result_type": , "result": } + + Failure case: + {"success": False, "exc_type": ".", "exc_value": , + "exc_tb": } + +3. The dictionary was serialized TWICE for fault tolerance: + - Primary: cloudpickle to .pkl file under output/mljob_result.pkl (supports complex Python objects) + - Fallback: JSON to .json file under output/mljob_result.json (for cross-version compatibility) + +WHY THIS MODULE EXISTS: +----------------------- +Jobs submitted with client versions using the v1 protocol will write v1-format result files. 
+This module ensures that newer clients can still retrieve results from: +- Jobs submitted before the protocol change +- Jobs running in environments where snowflake.ml.jobs._interop is not available + (triggering the ImportError fallback to v1 in save_mljob_result) + +RETRIEVAL FLOW: +--------------- +fetch_result() implements the v1 retrieval logic: +1. Try to unpickle from .pkl file +2. On failure (version mismatch, missing imports, etc.), fall back to .json file +3. Convert the legacy dict format to ExecutionResult +4. Provide helpful error messages for common failure modes + +REMOVAL IMPLICATIONS: +--------------------- +Removing this module would break result retrieval for: +- Any jobs that were submitted with snowflake-ml-python<1.17.0 and are still running/completed +- Any jobs running in old runtime environments that fall back to v1 serialization + +Safe to remove when: +- All ML Runtime images have been updated to include the new _interop modules +- Sufficient time has passed that no jobs using the old protocol are still retrievable + (consider retention policies for job history/logs) +""" + +import json +import os +import pickle +import sys +import traceback +from dataclasses import dataclass +from typing import Any, Optional, Union + +from snowflake import snowpark +from snowflake.ml.jobs._interop import exception_utils, results +from snowflake.snowpark import exceptions as sp_exceptions + + +@dataclass(frozen=True) +class ExecutionResult: + result: Any = None + exception: Optional[BaseException] = None + + @property + def success(self) -> bool: + return self.exception is None + + def to_dict(self) -> dict[str, Any]: + """Return the serializable dictionary.""" + if isinstance(self.exception, BaseException): + exc_type = type(self.exception) + return { + "success": False, + "exc_type": f"{exc_type.__module__}.{exc_type.__name__}", + "exc_value": self.exception, + "exc_tb": "".join(traceback.format_tb(self.exception.__traceback__)), + } + return { + 
"success": True, + "result_type": type(self.result).__qualname__, + "result": self.result, + } + + @classmethod + def from_dict(cls, result_dict: dict[str, Any]) -> "ExecutionResult": + if not isinstance(result_dict.get("success"), bool): + raise ValueError("Invalid result dictionary") + + if result_dict["success"]: + # Load successful result + return cls(result=result_dict.get("result")) + + # Load exception + exc_type = result_dict.get("exc_type", "RuntimeError") + exc_value = result_dict.get("exc_value", "Unknown error") + exc_tb = result_dict.get("exc_tb", "") + return cls(exception=load_exception(exc_type, exc_value, exc_tb)) + + +def fetch_result( + session: snowpark.Session, result_path: str, result_json: Optional[dict[str, Any]] = None +) -> ExecutionResult: + """ + Fetch the serialized result from the specified path. + + Args: + session: Snowpark Session to use for file operations. + result_path: The path to the serialized result file. + result_json: Optional pre-loaded JSON result dictionary to use instead of fetching from file. + + Returns: + A dictionary containing the execution result if available, None otherwise. + + Raises: + RuntimeError: If both pickle and JSON result retrieval fail. 
+ """ + try: + with session.file.get_stream(result_path) as result_stream: + return ExecutionResult.from_dict(pickle.load(result_stream)) + except ( + sp_exceptions.SnowparkSQLException, + pickle.UnpicklingError, + TypeError, + ImportError, + AttributeError, + MemoryError, + ) as pickle_error: + # Fall back to JSON result if loading pickled result fails for any reason + try: + if result_json is None: + result_json_path = os.path.splitext(result_path)[0] + ".json" + with session.file.get_stream(result_json_path) as result_stream: + result_json = json.load(result_stream) + return ExecutionResult.from_dict(result_json) + except Exception as json_error: + # Both pickle and JSON failed - provide helpful error message + raise RuntimeError(_fetch_result_error_message(pickle_error, result_path, json_error)) from pickle_error + + +def _fetch_result_error_message(error: Exception, result_path: str, json_error: Optional[Exception] = None) -> str: + """Create helpful error messages for common result retrieval failures.""" + + # Package import issues + if isinstance(error, ImportError): + return f"Failed to retrieve job result: Package not installed in your local environment. Error: {str(error)}" + + # Package versions differ between runtime and local environment + if isinstance(error, AttributeError): + return f"Failed to retrieve job result: Package version mismatch. Error: {str(error)}" + + # Serialization issues + if isinstance(error, TypeError): + return f"Failed to retrieve job result: Non-serializable objects were returned. 
Error: {str(error)}" + + # Python version pickling incompatibility + if isinstance(error, pickle.UnpicklingError) and "protocol" in str(error).lower(): + client_version = f"Python {sys.version_info.major}.{sys.version_info.minor}" + runtime_version = "Python 3.10" # NOTE: This may be inaccurate, but this path isn't maintained anymore + return ( + f"Failed to retrieve job result: Python version mismatch - job ran on {runtime_version}, " + f"local environment using Python {client_version}. Error: {str(error)}" + ) + + # File access issues + if isinstance(error, sp_exceptions.SnowparkSQLException): + if "not found" in str(error).lower() or "does not exist" in str(error).lower(): + return ( + f"Failed to retrieve job result: No result file found. Check job.get_logs() for execution " + f"errors. Error: {str(error)}" + ) + else: + return f"Failed to retrieve job result: Cannot access result file. Error: {str(error)}" + + if isinstance(error, MemoryError): + return f"Failed to retrieve job result: Result too large for memory. Error: {str(error)}" + + # Generic fallback + base_message = f"Failed to retrieve job result: {str(error)}" + if json_error: + base_message += f" (JSON fallback also failed: {str(json_error)})" + return base_message + + +def load_exception(exc_type_name: str, exc_value: Union[Exception, str], exc_tb: str) -> BaseException: + """ + Create an exception with a string-formatted traceback. + + When this exception is raised and not caught, it will display the original traceback. + When caught, it behaves like a regular exception without showing the traceback. + + Args: + exc_type_name: Name of the exception type (e.g., 'ValueError', 'RuntimeError') + exc_value: The deserialized exception value or exception string (i.e. 
message) + exc_tb: String representation of the traceback + + Returns: + An exception object with the original traceback information + + # noqa: DAR401 + """ + if isinstance(exc_value, Exception): + exception = exc_value + return exception_utils.attach_remote_error_info(exception, exc_type_name, str(exc_value), exc_tb) + return exception_utils.build_exception(exc_type_name, str(exc_value), exc_tb) + + +def load_legacy_result( + session: snowpark.Session, result_path: str, result_json: Optional[dict[str, Any]] = None +) -> results.ExecutionResult: + # Load result using legacy interop + legacy_result = fetch_result(session, result_path, result_json=result_json) + + # Adapt legacy result to new result + return results.ExecutionResult( + success=legacy_result.success, + value=legacy_result.exception or legacy_result.result, + ) diff --git a/snowflake/ml/jobs/_interop/legacy_test.py b/snowflake/ml/jobs/_interop/legacy_test.py new file mode 100644 index 00000000..2b142cee --- /dev/null +++ b/snowflake/ml/jobs/_interop/legacy_test.py @@ -0,0 +1,61 @@ +from absl.testing import absltest + +from snowflake.ml.jobs._interop import exception_utils, legacy +from snowflake.snowpark import exceptions + + +class TestLegacy(absltest.TestCase): + def test_load_exception_with_builtin_exception(self) -> None: + """Test loading a built-in exception type.""" + exc = legacy.load_exception("ValueError", "test error message", "traceback info") + self.assertIsInstance(exc, ValueError) + self.assertEqual(str(exc), "test error message") + + remote_error = exception_utils.retrieve_remote_error_info(exc) + assert remote_error is not None + self.assertEqual(remote_error.exc_type, "ValueError") + self.assertEqual(remote_error.exc_msg, "test error message") + self.assertEqual(remote_error.exc_tb, "traceback info") + + def test_load_exception_with_custom_exception_name(self) -> None: + """Test loading an exception with a custom exception name.""" + # Create a non-existent exception type name + exc = 
legacy.load_exception("NonExistentError", "custom error", "traceback info") + self.assertIsInstance(exc, RuntimeError) + self.assertIn("NonExistentError", str(exc)) + + remote_error = exception_utils.retrieve_remote_error_info(exc) + assert remote_error is not None + self.assertEqual(remote_error.exc_type, "NonExistentError") + self.assertEqual(remote_error.exc_msg, "custom error") + self.assertEqual(remote_error.exc_tb, "traceback info") + + def test_load_exception_with_qualified_name(self) -> None: + """Test loading an exception with a qualified name.""" + # Use a common exception from a module + exc_type = exceptions.SnowparkClientException + exc = legacy.load_exception(f"{exc_type.__module__}.{exc_type.__name__}", "mock error", "traceback info") + self.assertIsInstance(exc, exceptions.SnowparkClientException) + self.assertEqual(str(exc), "mock error") + + remote_error = exception_utils.retrieve_remote_error_info(exc) + assert remote_error is not None + self.assertEqual(remote_error.exc_type, "SnowparkClientException") + self.assertEqual(remote_error.exc_msg, "mock error") + self.assertEqual(remote_error.exc_tb, "traceback info") + + def test_load_exception_with_exception_instance(self) -> None: + """Test loading with an existing exception instance.""" + original_exc = ValueError("original error") + exc = legacy.load_exception("ValueError", original_exc, "traceback info") + self.assertIs(exc, original_exc) # Should be the same object + + remote_error = exception_utils.retrieve_remote_error_info(exc) + assert remote_error is not None + self.assertEqual(remote_error.exc_type, "ValueError") + self.assertEqual(remote_error.exc_msg, "original error") + self.assertEqual(remote_error.exc_tb, "traceback info") + + +if __name__ == "__main__": + absltest.main() diff --git a/snowflake/ml/jobs/_interop/protocols.py b/snowflake/ml/jobs/_interop/protocols.py new file mode 100644 index 00000000..e114fece --- /dev/null +++ b/snowflake/ml/jobs/_interop/protocols.py @@ -0,0 
+1,471 @@ +import base64 +import logging +import pickle +import posixpath +import sys +from typing import Any, Callable, Optional, Protocol, Union, cast, runtime_checkable + +from snowflake import snowpark +from snowflake.ml.jobs._interop import data_utils +from snowflake.ml.jobs._interop.dto_schema import ( + BinaryManifest, + ParquetManifest, + ProtocolInfo, +) + +Condition = Union[type, tuple[type, ...], Callable[[Any], bool], None] + +logger = logging.getLogger(__name__) + + +class SerializationError(TypeError): + """Exception raised when a serialization protocol fails.""" + + +class DeserializationError(ValueError): + """Exception raised when a serialization protocol fails.""" + + +class InvalidPayloadError(DeserializationError): + """Exception raised when the payload is invalid.""" + + +class ProtocolMismatchError(DeserializationError): + """Exception raised when the protocol of the serialization protocol is incompatible.""" + + +class VersionMismatchError(ProtocolMismatchError): + """Exception raised when the version of the serialization protocol is incompatible.""" + + +class ProtocolNotFoundError(SerializationError): + """Exception raised when no suitable serialization protocol is available.""" + + +@runtime_checkable +class SerializationProtocol(Protocol): + """ + More advanced protocol which supports more flexibility in how results are saved or loaded. + Results can be saved as one or more files, or directly inline in the PayloadManifest. + If saving as files, the PayloadManifest can save arbitrary "manifest" information. 
+ """ + + @property + def supported_types(self) -> Condition: + """The types that the protocol supports.""" + + @property + def protocol_info(self) -> ProtocolInfo: + """The information about the protocol.""" + + def save(self, obj: Any, dest_dir: str, session: Optional[snowpark.Session] = None) -> ProtocolInfo: + """Save the object to the destination directory.""" + + def load( + self, + payload_info: ProtocolInfo, + session: Optional[snowpark.Session] = None, + path_transform: Optional[Callable[[str], str]] = None, + ) -> Any: + """Load the object from the source directory.""" + + +class CloudPickleProtocol(SerializationProtocol): + """ + CloudPickle serialization protocol. + Uses BinaryManifest for manifest schema. + """ + + DEFAULT_PATH = "mljob_extra.pkl" + + def __init__(self) -> None: + import cloudpickle as cp + + self._backend = cp + + def _get_compatibility_error(self, payload_info: ProtocolInfo) -> Optional[Exception]: + """Check compatibility and attempt load, raising helpful errors on failure.""" + version_error = python_error = None + + # Check cloudpickle version compatibility + if payload_info.version: + try: + from packaging import version + + payload_major, current_major = ( + version.parse(payload_info.version).major, + version.parse(self._backend.__version__).major, + ) + if payload_major != current_major: + version_error = "cloudpickle version mismatch: payload={}, current={}".format( + payload_info.version, self._backend.__version__ + ) + except Exception: + if payload_info.version != self.protocol_info.version: + version_error = "cloudpickle version mismatch: payload={}, current={}".format( + payload_info.version, self.protocol_info.version + ) + + # Check Python version compatibility + if payload_info.metadata and "python_version" in payload_info.metadata: + payload_py, current_py = ( + payload_info.metadata["python_version"], + f"{sys.version_info.major}.{sys.version_info.minor}", + ) + if payload_py != current_py: + python_error = f"Python 
version mismatch: payload={payload_py}, current={current_py}" + + if version_error or python_error: + errors = [err for err in [version_error, python_error] if err] + return VersionMismatchError(f"Load failed due to incompatibility: {'; '.join(errors)}") + return None + + @property + def supported_types(self) -> Condition: + return None # All types are supported + + @property + def protocol_info(self) -> ProtocolInfo: + return ProtocolInfo( + name="cloudpickle", + version=self._backend.__version__, + metadata={ + "python_version": f"{sys.version_info.major}.{sys.version_info.minor}", + }, + ) + + def save(self, obj: Any, dest_dir: str, session: Optional[snowpark.Session] = None) -> ProtocolInfo: + """Save the object to the destination directory.""" + result_path = posixpath.join(dest_dir, self.DEFAULT_PATH) + with data_utils.open_stream(result_path, "wb", session=session) as f: + self._backend.dump(obj, f) + manifest: BinaryManifest = {"path": result_path} + return self.protocol_info.with_manifest(manifest) + + def load( + self, + payload_info: ProtocolInfo, + session: Optional[snowpark.Session] = None, + path_transform: Optional[Callable[[str], str]] = None, + ) -> Any: + """Load the object from the source directory.""" + if payload_info.name != self.protocol_info.name: + raise ProtocolMismatchError( + f"Invalid payload protocol: expected '{self.protocol_info.name}', got '{payload_info.name}'" + ) + + payload_manifest = cast(BinaryManifest, payload_info.manifest) + try: + if payload_bytes := payload_manifest.get("bytes"): + return self._backend.loads(payload_bytes) + if payload_b64 := payload_manifest.get("base64"): + return self._backend.loads(base64.b64decode(payload_b64)) + result_path = path_transform(payload_manifest["path"]) if path_transform else payload_manifest["path"] + with data_utils.open_stream(result_path, "rb", session=session) as f: + return self._backend.load(f) + except ( + pickle.UnpicklingError, + TypeError, + AttributeError, + MemoryError, + ) 
class ArrowTableProtocol(SerializationProtocol):
    """Serialization protocol for PyArrow Tables, stored as Parquet files.

    Manifest schema: ParquetManifest (a list of Parquet file paths).
    """

    DEFAULT_PATH_PATTERN = "mljob_extra_{0}.parquet"

    def __init__(self) -> None:
        # Lazy imports keep pyarrow an optional dependency.
        import pyarrow as pa
        import pyarrow.parquet as pq

        self._pa = pa
        self._pq = pq

    @property
    def supported_types(self) -> Condition:
        # Only pyarrow.Table instances are handled by this protocol.
        return cast(type, self._pa.Table)

    @property
    def protocol_info(self) -> ProtocolInfo:
        return ProtocolInfo(name="pyarrow", version=self._pa.__version__)

    def save(self, obj: Any, dest_dir: str, session: Optional[snowpark.Session] = None) -> ProtocolInfo:
        """Save the object to the destination directory."""
        if not isinstance(obj, self._pa.Table):
            raise SerializationError(f"Expected {self._pa.Table.__name__} object, got {type(obj).__name__}")

        # TODO: Support partitioned writes for large datasets
        target_path = posixpath.join(dest_dir, self.DEFAULT_PATH_PATTERN.format(0))
        with data_utils.open_stream(target_path, "wb", session=session) as stream:
            self._pq.write_table(obj, stream)

        manifest: ParquetManifest = {"paths": [target_path]}
        return self.protocol_info.with_manifest(manifest)

    def load(
        self,
        payload_info: ProtocolInfo,
        session: Optional[snowpark.Session] = None,
        path_transform: Optional[Callable[[str], str]] = None,
    ) -> Any:
        """Load the object from the source directory."""
        if payload_info.name != self.protocol_info.name:
            raise ProtocolMismatchError(
                f"Invalid payload protocol: expected '{self.protocol_info.name}', got '{payload_info.name}'"
            )

        manifest = cast(ParquetManifest, payload_info.manifest)
        loaded = []
        for raw_path in manifest["paths"]:
            actual_path = path_transform(raw_path) if path_transform else raw_path
            with data_utils.open_stream(actual_path, "rb", session=session) as stream:
                loaded.append(self._pq.read_table(stream))
        # A multi-file payload is concatenated into a single Table.
        if len(loaded) > 1:
            return self._pa.concat_tables(loaded)
        return loaded[0]
class PandasDataFrameProtocol(SerializationProtocol):
    """Serialization protocol for pandas DataFrames, stored as Parquet files.

    Manifest schema: ParquetManifest (a list of Parquet file paths).
    """

    DEFAULT_PATH_PATTERN = "mljob_extra_{0}.parquet"

    def __init__(self) -> None:
        # Lazy import keeps pandas an optional dependency.
        import pandas as pd

        self._pd = pd

    @property
    def supported_types(self) -> Condition:
        # Only pandas.DataFrame instances are handled by this protocol.
        return cast(type, self._pd.DataFrame)

    @property
    def protocol_info(self) -> ProtocolInfo:
        return ProtocolInfo(name="pandas", version=self._pd.__version__)

    def save(self, obj: Any, dest_dir: str, session: Optional[snowpark.Session] = None) -> ProtocolInfo:
        """Save the object to the destination directory."""
        if not isinstance(obj, self._pd.DataFrame):
            raise SerializationError(f"Expected {self._pd.DataFrame.__name__} object, got {type(obj).__name__}")

        # TODO: Support partitioned writes for large datasets
        target_path = posixpath.join(dest_dir, self.DEFAULT_PATH_PATTERN.format(0))
        with data_utils.open_stream(target_path, "wb", session=session) as stream:
            obj.to_parquet(stream)

        manifest: ParquetManifest = {"paths": [target_path]}
        return self.protocol_info.with_manifest(manifest)

    def load(
        self,
        payload_info: ProtocolInfo,
        session: Optional[snowpark.Session] = None,
        path_transform: Optional[Callable[[str], str]] = None,
    ) -> Any:
        """Load the object from the source directory."""
        if payload_info.name != self.protocol_info.name:
            raise ProtocolMismatchError(
                f"Invalid payload protocol: expected '{self.protocol_info.name}', got '{payload_info.name}'"
            )

        manifest = cast(ParquetManifest, payload_info.manifest)
        frames = []
        for raw_path in manifest["paths"]:
            actual_path = path_transform(raw_path) if path_transform else raw_path
            with data_utils.open_stream(actual_path, "rb", session=session) as stream:
                frames.append(self._pd.read_parquet(stream))
        # A multi-file payload is concatenated into a single DataFrame.
        if len(frames) > 1:
            return self._pd.concat(frames)
        return frames[0]
class NumpyArrayProtocol(SerializationProtocol):
    """Serialization protocol for NumPy arrays, stored in .npy format.

    Manifest schema: BinaryManifest (a single file path).
    """

    DEFAULT_PATH_PATTERN = "mljob_extra.npy"

    def __init__(self) -> None:
        # Lazy import keeps numpy an optional dependency.
        import numpy as np

        self._np = np

    @property
    def supported_types(self) -> Condition:
        # Only numpy.ndarray instances are handled by this protocol.
        return cast(type, self._np.ndarray)

    @property
    def protocol_info(self) -> ProtocolInfo:
        return ProtocolInfo(name="numpy", version=self._np.__version__)

    def save(self, obj: Any, dest_dir: str, session: Optional[snowpark.Session] = None) -> ProtocolInfo:
        """Save the object to the destination directory."""
        if not isinstance(obj, self._np.ndarray):
            raise SerializationError(f"Expected {self._np.ndarray.__name__} object, got {type(obj).__name__}")
        target_path = posixpath.join(dest_dir, self.DEFAULT_PATH_PATTERN)
        with data_utils.open_stream(target_path, "wb", session=session) as stream:
            self._np.save(stream, obj)

        manifest: BinaryManifest = {"path": target_path}
        return self.protocol_info.with_manifest(manifest)

    def load(
        self,
        payload_info: ProtocolInfo,
        session: Optional[snowpark.Session] = None,
        path_transform: Optional[Callable[[str], str]] = None,
    ) -> Any:
        """Load the object from the source directory."""
        if payload_info.name != self.protocol_info.name:
            raise ProtocolMismatchError(
                f"Invalid payload protocol: expected '{self.protocol_info.name}', got '{payload_info.name}'"
            )

        manifest = cast(BinaryManifest, payload_info.manifest)
        source_path = manifest["path"]
        if path_transform:
            source_path = path_transform(source_path)
        with data_utils.open_stream(source_path, "rb", session=session) as stream:
            return self._np.load(stream)
name="auto", + version=None, + metadata=None, + ) + + @property + def supported_types(self) -> Condition: + return None # All types are supported + + @property + def protocol_info(self) -> ProtocolInfo: + return self._protocol_info + + def try_register_protocol( + self, + klass: type[SerializationProtocol], + *args: Any, + index: int = 0, + **kwargs: Any, + ) -> None: + """ + Try to construct and register a protocol. If the protocol cannot be constructed, + log a warning and skip registration. By default (index=0), the most recently + registered protocol takes precedence. + + Args: + klass: The class of the protocol to register. + args: The positional arguments to pass to the protocol constructor. + index: The index to register the protocol at. If -1, the protocol is registered at the end of the list. + kwargs: The keyword arguments to pass to the protocol constructor. + """ + try: + protocol = klass(*args, **kwargs) + self.register_protocol(protocol, index=index) + except Exception as e: + logger.warning(f"Failed to register protocol {klass}: {e}") + + def register_protocol( + self, + protocol: SerializationProtocol, + index: int = 0, + ) -> None: + """ + Register a protocol with a condition. By default (index=0), the most recently + registered protocol takes precedence. + + Args: + protocol: The protocol to register. + index: The index to register the protocol at. If -1, the protocol is registered at the end of the list. + + Raises: + ValueError: If the condition is invalid. + ValueError: If the index is invalid. + """ + # Validate condition + # TODO: Build lookup table of supported types to protocols (in priority order) + # for faster lookup at save/load time (instead of iterating over all protocols) + if not isinstance(protocol, SerializationProtocol): + raise ValueError(f"Invalid protocol type: {type(protocol)}. Expected SerializationProtocol.") + if index == -1: + self._protocols.append(protocol) + elif index < 0: + raise ValueError(f"Invalid index: {index}. 
Expected -1 or >= 0.") + else: + self._protocols.insert(index, protocol) + + def save(self, obj: Any, dest_dir: str, session: Optional[snowpark.Session] = None) -> ProtocolInfo: + """Save the object to the destination directory.""" + last_protocol_error = None + for protocol in self._protocols: + try: + if self._is_supported_type(obj, protocol): + logger.debug(f"Dumping object of type {type(obj)} with protocol {protocol}") + return protocol.save(obj, dest_dir, session) + except Exception as e: + logger.warning(f"Error dumping object {obj} with protocol {protocol}: {repr(e)}") + last_protocol_error = (protocol.protocol_info, e) + last_error_str = ( + f", most recent error ({last_protocol_error[0]}): {repr(last_protocol_error[1])}" + if last_protocol_error + else "" + ) + raise ProtocolNotFoundError( + f"No suitable protocol found for type {type(obj).__name__}" + f" (available: {', '.join(str(p.protocol_info) for p in self._protocols)}){last_error_str}" + ) + + def load( + self, + payload_info: ProtocolInfo, + session: Optional[snowpark.Session] = None, + path_transform: Optional[Callable[[str], str]] = None, + ) -> Any: + """Load the object from the source directory.""" + last_error = None + for protocol in self._protocols: + if protocol.protocol_info.name == payload_info.name: + try: + return protocol.load(payload_info, session, path_transform) + except Exception as e: + logger.warning(f"Error loading object with protocol {protocol}: {repr(e)}") + last_error = e + if last_error: + raise last_error + raise ProtocolNotFoundError( + f"No protocol matching {payload_info} available" + f" (available: {', '.join(str(p.protocol_info) for p in self._protocols)})" + ", possibly due to snowflake-ml-python package version mismatch" + ) + + def _is_supported_type(self, obj: Any, protocol: SerializationProtocol) -> bool: + if protocol.supported_types is None: + return True # None means all types are supported + elif isinstance(protocol.supported_types, (type, tuple)): + return 
class MyClass:
    """Simple user-defined type used to exercise arbitrary-object serialization."""

    def __init__(self, x: int, y: int) -> None:
        self.x = x
        self.y = y

    def __eq__(self, other: Any) -> bool:
        # NOTE: assumes `other` exposes .x/.y attributes; raises AttributeError otherwise.
        return bool(self.x == other.x and self.y == other.y)


class DummyProtocol(p.SerializationProtocol):
    """Minimal in-memory SerializationProtocol stub for AutoProtocol dispatch tests.

    `save` performs no I/O — it only records a manifest path — and `load`
    returns a fixed dict echoing that path.
    """

    def __init__(
        self,
        name: str,
        version: Optional[str] = None,
        metadata: Optional[dict[str, str]] = None,
        supported_types: p.Condition = None,
    ) -> None:
        self._protocol_info = ProtocolInfo(
            name=name,
            version=version,
            metadata=metadata,
        )
        self._supported_types = supported_types

    @property
    def supported_types(self) -> p.Condition:
        return self._supported_types

    @property
    def protocol_info(self) -> ProtocolInfo:
        return self._protocol_info

    def save(self, obj: Any, dest_dir: str, session: Optional[Any] = None) -> ProtocolInfo:
        # Simple dummy implementation - just return protocol info with manifest
        return self.protocol_info.with_manifest({"dummy_path": f"{dest_dir}/dummy.pkl"})

    def load(
        self,
        payload_info: ProtocolInfo,
        session: Optional[Any] = None,
        path_transform: Optional[Callable[[str], str]] = None,
    ) -> Any:
        # Simple dummy implementation - return a dummy object
        manifest_value = payload_info.manifest.get("dummy_path") if payload_info.manifest else None
        return {"loaded": True, "from": manifest_value}
class TestProtocols(parameterized.TestCase):
    # End-to-end tests for the concrete serialization protocols, operating on
    # local temp directories (no Snowflake session involved).

    @parameterized.parameters(  # type: ignore[misc]
        (p.CloudPickleProtocol(), None),  # All types supported
        (p.ArrowTableProtocol(), pa.Table),
        (p.PandasDataFrameProtocol(), pd.DataFrame),
        (p.NumpyArrayProtocol(), np.ndarray),
    )
    def test_supported_types(self, protocol: p.SerializationProtocol, expected_type: Any) -> None:
        """Test that protocols report their supported types correctly."""
        self.assertEqual(protocol.supported_types, expected_type)

    @parameterized.parameters(  # type: ignore[misc]
        # CloudPickle - supports arbitrary Python objects
        (
            p.CloudPickleProtocol(),
            MyClass(1, 2),
        ),
        # Arrow Table
        (
            p.ArrowTableProtocol(),
            pa.table({"col1": [1, 2, 3], "col2": ["a", "b", "c"]}),
            lambda test, expected, actual: (
                test.assertEqual(expected.schema, actual.schema),
                test.assertEqual(expected.to_pydict(), actual.to_pydict()),
            ),
        ),
        # Pandas DataFrame
        (
            p.PandasDataFrameProtocol(),
            pd.DataFrame({"col1": [1, 2, 3], "col2": ["a", "b", "c"]}),
            lambda test, expected, actual: pd.testing.assert_frame_equal(expected, actual),
        ),
        # Numpy Array
        (
            p.NumpyArrayProtocol(),
            np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=np.float64),
            lambda test, expected, actual: np.testing.assert_array_equal(expected, actual),
        ),
    )
    def test_serialization(
        self,
        protocol: p.SerializationProtocol,
        obj: Any,
        assertion_func: Optional[Callable[[Any, Any, Any], None]] = None,
    ) -> None:
        """Test serialization and deserialization of supported types."""
        with tempfile.TemporaryDirectory() as temp_dir:
            # Test serialization
            protocol_info = protocol.save(obj, temp_dir)
            self.assertIsInstance(protocol_info, ProtocolInfo)
            self.assertIsNotNone(protocol_info.manifest)

            # Check protocol-specific manifest structure
            assert isinstance(protocol_info.manifest, dict)
            if isinstance(protocol, p.CloudPickleProtocol):
                self.assertIn("path", protocol_info.manifest)
            elif isinstance(protocol, (p.ArrowTableProtocol, p.PandasDataFrameProtocol)):
                self.assertIn("paths", protocol_info.manifest)
                self.assertIsInstance(protocol_info.manifest["paths"], list)  # type: ignore[typeddict-item]
                self.assertGreater(len(protocol_info.manifest["paths"]), 0)  # type: ignore[typeddict-item]
            elif isinstance(protocol, p.NumpyArrayProtocol):
                self.assertIn("path", protocol_info.manifest)

            # Test deserialization
            loaded_obj = protocol.load(protocol_info)

            # Use the custom assertion function
            if assertion_func is None:
                self.assertEqual(obj, loaded_obj)
            else:
                assertion_func(self, obj, loaded_obj)

    @parameterized.parameters(  # type: ignore[misc]
        # Arrow Table with wrong type
        (
            p.ArrowTableProtocol(),
            {"not": "a table"},  # dict instead of Arrow Table
        ),
        # Pandas DataFrame with wrong type
        (
            p.PandasDataFrameProtocol(),
            [1, 2, 3],  # list instead of DataFrame
        ),
        # Numpy Array with wrong type
        (
            p.NumpyArrayProtocol(),
            "not an array",  # string instead of ndarray
        ),
    )
    def test_serialization_unsupported_types(
        self,
        protocol: p.SerializationProtocol,
        obj: Any,
    ) -> None:
        """Test that protocols raise appropriate errors for unsupported types."""
        with tempfile.TemporaryDirectory() as temp_dir:
            with self.assertRaises(p.SerializationError):
                protocol.save(obj, temp_dir)
class TestAutoProtocol(parameterized.TestCase):
    # Tests for AutoProtocol registration ordering and save-time dispatch.

    def setUp(self) -> None:
        super().setUp()
        self.sut = p.AutoProtocol()

    def test_register_protocol(self) -> None:
        proto = p.CloudPickleProtocol()
        self.sut.register_protocol(proto)
        self.assertEqual(len(self.sut._protocols), 1)
        self.assertEqual(self.sut._protocols[0], proto)

    def test_register_protocol_index(self) -> None:
        # Verifies insert-at-index semantics, including -1 meaning "append".
        proto1 = DummyProtocol("proto_1")
        proto2 = DummyProtocol("proto_2")
        proto3 = DummyProtocol("proto_3")
        proto4 = DummyProtocol("proto_4")

        self.sut.register_protocol(proto1, index=0)
        self.assertEqual(self.sut._protocols, [proto1])
        self.sut.register_protocol(proto2, index=1)
        self.assertEqual(self.sut._protocols, [proto1, proto2])
        self.sut.register_protocol(proto3, index=-1)
        self.assertEqual(self.sut._protocols, [proto1, proto2, proto3])
        self.sut.register_protocol(proto4, index=1)
        self.assertEqual(self.sut._protocols, [proto1, proto4, proto2, proto3])

    def test_register_protocol_negative(self) -> None:
        # Test invalid protocol type
        with self.assertRaises(ValueError):
            self.sut.register_protocol("not_a_protocol")  # type: ignore[arg-type]

        # Test invalid index
        proto = DummyProtocol("proto")
        with self.assertRaises(ValueError):
            self.sut.register_protocol(proto, index=-2)

    @parameterized.parameters(  # type: ignore[misc]
        ("use_proto_1", "proto_2"),
        ("some_str", "proto_1"),
        (1.0, "proto_3"),
        (1, "proto_3"),
        (MyClass(1, 2), "proto_0"),
    )
    def test_save_protocol(self, obj: Any, expected: str) -> None:
        # Each registration with index=0 pushes earlier protocols down, so the
        # effective priority order here is proto_3, proto_2, proto_1, proto_0.
        self.sut.register_protocol(DummyProtocol("proto_0", supported_types=None))
        self.sut.register_protocol(DummyProtocol("proto_1", supported_types=str))
        self.sut.register_protocol(DummyProtocol("proto_2", supported_types=lambda x: x == "use_proto_1"))
        self.sut.register_protocol(DummyProtocol("proto_3", supported_types=(float, int)))

        with tempfile.TemporaryDirectory() as temp_dir:
            proto_info = self.sut.save(obj, temp_dir)
            self.assertEqual(proto_info.name, expected)

    def test_try_register_protocol_success(self) -> None:
        initial_count = len(self.sut._protocols)
        self.sut.try_register_protocol(DummyProtocol, "test_proto")
        self.assertEqual(len(self.sut._protocols), initial_count + 1)
        self.assertEqual(self.sut._protocols[-1].protocol_info.name, "test_proto")

    def test_try_register_protocol_failure(self) -> None:
        # Create a class that will fail during construction
        class FailingProtocol(p.SerializationProtocol):
            def __init__(self) -> None:
                raise RuntimeError("Construction failed")

            @property
            def supported_types(self) -> p.Condition:
                return None

            @property
            def protocol_info(self) -> ProtocolInfo:
                return ProtocolInfo(name="failing")

            def save(self, obj: Any, dest_dir: str, session: Optional[Any] = None) -> ProtocolInfo:
                return ProtocolInfo(name="failing")

            def load(
                self,
                payload_info: ProtocolInfo,
                session: Optional[Any] = None,
                path_transform: Optional[Callable[[str], str]] = None,
            ) -> Any:
                pass

        initial_count = len(self.sut._protocols)
        # Should not raise an exception, just log a warning
        with self.assertLogs("snowflake.ml.jobs._interop.protocols", level="WARNING"):
            self.sut.try_register_protocol(FailingProtocol)
        self.assertEqual(len(self.sut._protocols), initial_count)
RuntimeError("Construction failed") + + @property + def supported_types(self) -> p.Condition: + return None + + @property + def protocol_info(self) -> ProtocolInfo: + return ProtocolInfo(name="failing") + + def save(self, obj: Any, dest_dir: str, session: Optional[Any] = None) -> ProtocolInfo: + return ProtocolInfo(name="failing") + + def load( + self, + payload_info: ProtocolInfo, + session: Optional[Any] = None, + path_transform: Optional[Callable[[str], str]] = None, + ) -> Any: + pass + + initial_count = len(self.sut._protocols) + # Should not raise an exception, just log a warning + with self.assertLogs("snowflake.ml.jobs._interop.protocols", level="WARNING"): + self.sut.try_register_protocol(FailingProtocol) + self.assertEqual(len(self.sut._protocols), initial_count) + + +if __name__ == "__main__": + absltest.main() diff --git a/snowflake/ml/jobs/_interop/results.py b/snowflake/ml/jobs/_interop/results.py new file mode 100644 index 00000000..9e918b82 --- /dev/null +++ b/snowflake/ml/jobs/_interop/results.py @@ -0,0 +1,51 @@ +from dataclasses import dataclass +from typing import Any, Optional + + +@dataclass(frozen=True) +class ExecutionResult: + """ + A result of a job execution. + + Args: + success: Whether the execution was successful. + value: The value of the execution. + """ + + success: bool + value: Any + + def get_value(self, wrap_exceptions: bool = True) -> Any: + if not self.success: + assert isinstance(self.value, BaseException), "Unexpected non-exception value for failed result" + self._raise_exception(self.value, wrap_exceptions) + return self.value + + def _raise_exception(self, exception: BaseException, wrap_exceptions: bool) -> None: + if wrap_exceptions: + raise RuntimeError(f"Job execution failed with error: {exception!r}") from exception + else: + raise exception + + +@dataclass(frozen=True) +class LoadedExecutionResult(ExecutionResult): + """ + A result of a job execution that has been loaded from a file. 
+ """ + + load_error: Optional[Exception] = None + result_metadata: Optional[dict[str, Any]] = None + + def get_value(self, wrap_exceptions: bool = True) -> Any: + if not self.success: + # Raise the original exception if available, otherwise raise the load error + ex = self.value + if not isinstance(ex, BaseException): + ex = RuntimeError(f"Unknown error {ex or ''}") + ex.__cause__ = self.load_error + self._raise_exception(ex, wrap_exceptions) + else: + if self.load_error: + raise ValueError("Job execution succeeded but result retrieval failed") from self.load_error + return self.value diff --git a/snowflake/ml/jobs/_interop/utils.py b/snowflake/ml/jobs/_interop/utils.py new file mode 100644 index 00000000..ebb7abed --- /dev/null +++ b/snowflake/ml/jobs/_interop/utils.py @@ -0,0 +1,144 @@ +import logging +import os +import traceback +from pathlib import PurePath +from typing import Any, Callable, Optional + +import pydantic + +from snowflake import snowpark +from snowflake.ml.jobs._interop import data_utils, exception_utils, legacy, protocols +from snowflake.ml.jobs._interop.dto_schema import ( + ExceptionMetadata, + ResultDTO, + ResultMetadata, +) +from snowflake.ml.jobs._interop.results import ExecutionResult, LoadedExecutionResult +from snowflake.snowpark import exceptions as sp_exceptions + +DEFAULT_CODEC = data_utils.JsonDtoCodec +DEFAULT_PROTOCOL = protocols.AutoProtocol() +DEFAULT_PROTOCOL.try_register_protocol(protocols.CloudPickleProtocol) +DEFAULT_PROTOCOL.try_register_protocol(protocols.ArrowTableProtocol) +DEFAULT_PROTOCOL.try_register_protocol(protocols.PandasDataFrameProtocol) +DEFAULT_PROTOCOL.try_register_protocol(protocols.NumpyArrayProtocol) + + +logger = logging.getLogger(__name__) + + +def save_result(result: ExecutionResult, path: str, session: Optional[snowpark.Session] = None) -> None: + """ + Save the result to a file. 
+ """ + result_dto = ResultDTO( + success=result.success, + value=result.value, + ) + + try: + # Try to encode result directly + payload = DEFAULT_CODEC.encode(result_dto) + except TypeError: + result_dto.value = None # Remove raw value to avoid serialization error + result_dto.metadata = _get_metadata(result.value) # Add metadata for client fallback on protocol mismatch + try: + path_dir = PurePath(path).parent.as_posix() + protocol_info = DEFAULT_PROTOCOL.save(result.value, path_dir, session=session) + result_dto.protocol = protocol_info + + except Exception as e: + logger.warning(f"Error dumping result value: {repr(e)}") + result_dto.serialize_error = repr(e) + + # Encode the modified result DTO + payload = DEFAULT_CODEC.encode(result_dto) + + with data_utils.open_stream(path, "wb", session=session) as stream: + stream.write(payload) + + +def load_result( + path: str, session: Optional[snowpark.Session] = None, path_transform: Optional[Callable[[str], str]] = None +) -> ExecutionResult: + """Load the result from a file on a Snowflake stage.""" + try: + with data_utils.open_stream(path, "r", session=session) as stream: + # Load the DTO as a dict for easy fallback to legacy loading if necessary + dto_dict = DEFAULT_CODEC.decode(stream, as_dict=True) + except UnicodeDecodeError: + # Path may be a legacy result file (cloudpickle) + # TODO: Re-use the stream + assert session is not None + return legacy.load_legacy_result(session, path) + + try: + dto = ResultDTO.model_validate(dto_dict) + except pydantic.ValidationError as e: + if "success" in dto_dict: + assert session is not None + if path.endswith(".json"): + path = os.path.splitext(path)[0] + ".pkl" + return legacy.load_legacy_result(session, path, result_json=dto_dict) + raise ValueError("Invalid result schema") from e + + # Try loading data from file using the protocol info + result_value = None + data_load_error = None + if dto.protocol is not None: + try: + logger.debug(f"Loading result value with protocol 
{dto.protocol}") + result_value = DEFAULT_PROTOCOL.load(dto.protocol, session=session, path_transform=path_transform) + except sp_exceptions.SnowparkSQLException: + raise # Data retrieval errors should be bubbled up + except Exception as e: + logger.debug(f"Error loading result value with protocol {dto.protocol}: {repr(e)}") + data_load_error = e + + # Wrap serialize_error in a TypeError + if dto.serialize_error: + serialize_error = TypeError("Original result serialization failed with error: " + dto.serialize_error) + if data_load_error: + data_load_error.__context__ = serialize_error + else: + data_load_error = serialize_error + + # Prepare to assemble the final result + result_value = result_value if result_value is not None else dto.value + if not dto.success and result_value is None: + # Try to reconstruct exception from metadata if available + if isinstance(dto.metadata, ExceptionMetadata): + logger.debug(f"Reconstructing exception from metadata {dto.metadata}") + result_value = exception_utils.build_exception( + type_str=dto.metadata.type, + message=dto.metadata.message, + traceback=dto.metadata.traceback, + original_repr=dto.metadata.repr, + ) + + # Generate a generic error if we still don't have a value, + # attaching the data load error if any + if result_value is None: + result_value = exception_utils.RemoteError("Unknown remote error") + result_value.__cause__ = data_load_error + + return LoadedExecutionResult( + success=dto.success, + value=result_value, + load_error=data_load_error, + ) + + +def _get_metadata(value: Any) -> ResultMetadata: + type_name = f"{type(value).__module__}.{type(value).__name__}" + if isinstance(value, BaseException): + return ExceptionMetadata( + type=type_name, + repr=repr(value), + message=str(value), + traceback="".join(traceback.format_tb(value.__traceback__)), + ) + return ResultMetadata( + type=type_name, + repr=repr(value), + ) diff --git a/snowflake/ml/jobs/_interop/utils_test.py 
@dataclass(frozen=True)
class DataClass:
    """Frozen dataclass fixture with JSON-friendly defaults."""

    int_value: int = 0
    str_value: str = "null"


class DummyClass:
    """Plain-object fixture with value-based equality and a stable repr."""

    def __init__(self, value: int = 0, label: str = "null") -> None:
        self.value = value
        self.label = label

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, DummyClass):
            return False
        return (self.value, self.label) == (other.value, other.label)

    def __repr__(self) -> str:
        return f"DummyClass(value={self.value}, label={self.label})"


class DummyException(Exception):
    """Exception fixture comparable by its message."""

    def __init__(self, message: str = "dummy error") -> None:
        super().__init__(message)
        self.message = message

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, DummyException):
            return False
        return self.message == other.message


class DummyNonserializableClass:
    """Fixture holding an unpicklable member (a lock) to trigger serialization failures."""

    def __init__(self) -> None:
        self._lock = threading.Lock()


class DummyNonserializableException(Exception):
    """Exception fixture holding an unpicklable member (a lock)."""

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        self._lock = threading.Lock()
def create_serialized_unavailable_class() -> bytes:
    """Creates a cloudpickle-serialized object from an unavailable module namespace."""
    # The fake module is created inside a *child* interpreter, so it never
    # exists in this process — deserializing the payload here must fail with
    # a module/class-not-found style error.
    code = """
import sys, cloudpickle, base64

# Create fake module and a class inside the fake module
module_name = "unavailable_test_module"
fake_mod = type(sys)(module_name)
sys.modules[module_name] = fake_mod

# Define class in fake module
class ExternalTestError(Exception):
    def __init__(self, message, value=42):
        super().__init__(message)
        self.value = value

ExternalTestError.__module__ = module_name
fake_mod.ExternalTestError = ExternalTestError

# Serialize and encode
serialized = cloudpickle.dumps(ExternalTestError("test exception", 123))
print(base64.b64encode(serialized).decode())
"""
    # check=True surfaces any failure in the child interpreter as CalledProcessError.
    result = subprocess.run([sys.executable, "-c", code], capture_output=True, text=True, check=True)
    import base64

    return base64.b64decode(result.stdout.strip())


# Create a serialized unavailable class once and re-use it
external_exception_bytes = create_serialized_unavailable_class()
@parameterized.named_parameters( # type: ignore[misc] + ("result_object", ExecutionResult(True, DummyClass()), "mljob_extra.pkl"), + ("dataclass_object", ExecutionResult(True, DataClass()), "mljob_extra.pkl"), + ("numpy_array", ExecutionResult(True, np.array([1, 2, 3])), "mljob_extra.npy"), + ( + "pandas_dataframe", + ExecutionResult(True, pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})), + "mljob_extra_0.parquet", + lambda x: x["paths"][0], + ), + ("exception_as_return_value", ExecutionResult(True, ValueError("error as result")), "mljob_extra.pkl"), + ("value_error", ExecutionResult(False, ValueError("test error")), "mljob_extra.pkl"), + ) + def test_save_result_complex( + self, + result: ExecutionResult, + expected_path: str, + path_getter: Callable[[PayloadManifest], str] = lambda x: x["path"], # type: ignore[typeddict-item] + ) -> None: + with tempfile.TemporaryDirectory() as temp_dir: + temp_file_path = f"{temp_dir}/result.json" + u.save_result(result, temp_file_path) + + with open(temp_file_path) as f: + result_json = json.load(f) + self.assertEqual(result_json["success"], result.success) + self.assertIsNone(result_json["value"], result.value) + + actual_path = path_getter(result_json["protocol"]["manifest"]) + self.assertEndsWith(actual_path, expected_path) + self.assertIsNotNone(result_json["protocol"]) + self.assertIsNotNone(result_json["metadata"]) + self.assertTrue(os.path.exists(actual_path)) + + @parameterized.named_parameters( # type: ignore[misc] + ("result", ExecutionResult(True, DummyNonserializableClass())), + ("exception", ExecutionResult(False, DummyNonserializableException())), + ) + def test_save_result_nonserializable(self, result: ExecutionResult) -> None: + with tempfile.TemporaryDirectory() as temp_dir: + temp_file_path = f"{temp_dir}/result.json" + u.save_result(result, temp_file_path) + + with open(temp_file_path) as f: + result_json = json.load(f) + self.assertEqual(result_json["success"], result.success) + 
self.assertIsNone(result_json["value"], result.value) + + # In this case, protocol should not be set + # but metadata should be set + self.assertIsNone(result_json["protocol"]) + self.assertIsNotNone(result_json["metadata"]) + + # Assert that no other files exist in temp_dir + # FIXME: Protocols currently may write partial files on error + # files_in_temp_dir = os.listdir(temp_dir) + # self.assertEqual(len(files_in_temp_dir), 1) + + @parameterized.named_parameters( # type: ignore[misc] + ("result_object", ExecutionResult(True, DummyClass()), 2), + ("value_error", ExecutionResult(False, ValueError("test error")), 2), + ("nonserializable_result", ExecutionResult(True, DummyNonserializableClass()), 1), + ("nonserializable_exception", ExecutionResult(False, DummyNonserializableException()), 1), + ) + def test_save_result_to_stage( + self, result: ExecutionResult, expected_files: int, expected_path: str = "mljob_extra.pkl" + ) -> None: + temp_file_path = "@dummy_stage/result.json" + u.save_result(result, temp_file_path, session=self.mock_session) + + self.assertEqual(self.mock_session.file.put_stream.call_count, expected_files) + self.mock_session.file.put_stream.assert_any_call(mock.ANY, temp_file_path) + if expected_files > 1: + self.mock_session.file.put_stream.assert_any_call(mock.ANY, f"@dummy_stage/{expected_path}") + + @parameterized.named_parameters( # type: ignore[misc] + dict( + testcase_name="no_result", + data={"success": True, "value": None, "protocol": None}, + expected_value=None, + ), + dict( + testcase_name="simple_result", + data={"success": True, "value": "test string", "protocol": None}, + expected_value="test string", + ), + dict( + testcase_name="cloudpickled_result", + data={ + "success": True, + "value": None, + "protocol": { + "name": "cloudpickle", + "version": cp.__version__, + "manifest": {"path": "@dummy/secondary"}, + }, + }, + secondary_data=cp.dumps(DummyClass(42, "loaded label")), + expected_value=DummyClass(42, "loaded label"), + ), + 
dict( + testcase_name="cloudpickled_result_dataclass", + data={ + "success": True, + "value": None, + "protocol": { + "name": "cloudpickle", + "version": cp.__version__, + "manifest": {"path": "@dummy/secondary"}, + }, + }, + secondary_data=cp.dumps(DataClass(int_value=42, str_value="loaded string")), + expected_value=DataClass(int_value=42, str_value="loaded string"), + ), + # TODO: Add pandas DataFrame or numpy array test case + dict( + testcase_name="nonserializable_result", + data={ + "success": True, + "value": None, + "protocol": None, + "metadata": {"type": "__main__.DummyNonserializableClass", "repr": "..."}, + "serialize_error": "TypeError(\"cannot pickle '_thread.lock' object\")", + }, + expected_error=ValueError("Job execution succeeded but result retrieval failed"), + expected_cause=TypeError("Original result serialization failed"), + ), + dict( + testcase_name="nonserializable_result_wrapped_exception", + data={ + "success": True, + "value": None, + "protocol": None, + "metadata": {"type": "__main__.DummyNonserializableClass", "repr": "..."}, + "serialize_error": "TypeError(\"cannot pickle '_thread.lock' object\")", + }, + wrap_exceptions=True, # Shouldn't make a difference for success case + expected_error=ValueError("Job execution succeeded but result retrieval failed"), + expected_cause=TypeError("Original result serialization failed"), + ), + dict( + testcase_name="exception_as_result", + data={ + "success": True, + "value": None, + "protocol": { + "name": "cloudpickle", + "version": cp.__version__, + "manifest": {"path": "@dummy/secondary"}, + }, + }, + secondary_data=cp.dumps(DummyException("loaded dummy error")), + expected_value=DummyException("loaded dummy error"), + ), + dict( + testcase_name="simple_exception", + data={ + "success": False, + "value": None, + "protocol": { + "name": "cloudpickle", + "version": cp.__version__, + "manifest": {"path": "@dummy/secondary"}, + }, + }, + secondary_data=cp.dumps(RuntimeError("loaded runtime error")), 
+ expected_error=RuntimeError("loaded runtime error"), + wrap_exceptions=False, + ), + dict( + testcase_name="simple_exception_wrapped", + data={ + "success": False, + "value": None, + "protocol": { + "name": "cloudpickle", + "version": cp.__version__, + "manifest": {"path": "@dummy/secondary"}, + }, + }, + secondary_data=cp.dumps(RuntimeError("loaded runtime error")), + expected_error=RuntimeError("Job execution failed"), + expected_cause=RuntimeError("loaded runtime error"), + wrap_exceptions=True, + ), + dict( + testcase_name="nonserializable_exception", + data={ + "success": False, + "value": None, + "protocol": None, + "metadata": { + "type": "SomeNonserializableClass", + "repr": "SomeNonserializableClass('...')", + "message": "...", + "traceback": "...", + }, + "serialize_error": "TypeError(\"cannot pickle '_thread.lock' object\")", + }, + expected_error=RemoteError("SomeNonserializableClass('...')"), + expected_cause=ModuleNotFoundError("Unrecognized exception"), + ), + dict( + testcase_name="nonserializable_exception_reconstructed", + data={ + "success": False, + "value": None, + "protocol": None, + "metadata": { + "type": "NotImplementedError", + "repr": "NotImplementedError('test')", + "message": "test", + "traceback": "...", + }, + "serialize_error": "TypeError(\"cannot pickle '_thread.lock' object\")", + }, + expected_error=NotImplementedError("test"), + ), + dict( + testcase_name="nonserializable_exception_no_metadata", + data={ + "success": False, + "value": None, + "protocol": None, + "metadata": None, + "serialize_error": "TypeError(\"cannot pickle '_thread.lock' object\")", + }, + expected_error=RuntimeError("Unknown remote error"), + expected_cause=TypeError("Original result serialization failed"), + ), + dict( + testcase_name="unknown_exception", + data={ + "success": False, + "value": None, + }, + expected_error=RuntimeError("Unknown remote error"), + ), + # Negative deserialization cases + dict( + testcase_name="remote_only_result", + data={ + 
"success": True, + "value": None, + "protocol": { + "name": "cloudpickle", + "version": cp.__version__, + "manifest": {"path": "@dummy/secondary"}, + }, + }, + secondary_data=external_exception_bytes, + expected_error=ValueError("Job execution succeeded but result retrieval failed"), + expected_cause=ModuleNotFoundError("No module"), + ), + dict( + testcase_name="remote_only_exception", + data={ + "success": False, + "value": None, + "protocol": { + "name": "cloudpickle", + "version": cp.__version__, + "manifest": {"path": "@dummy/secondary"}, + }, + "metadata": { + "type": "ExternalTestError", + "repr": "ExternalTestError('test exception', 123)", + "message": "test exception", + "traceback": "...", + }, + }, + secondary_data=external_exception_bytes, + expected_error=RemoteError("ExternalTestError('test exception', 123)"), + expected_cause=ModuleNotFoundError("Unrecognized exception"), + ), + dict( + testcase_name="remote_only_exception_no_metadata", + data={ + "success": False, + "value": None, + "protocol": { + "name": "cloudpickle", + "version": cp.__version__, + "manifest": {"path": "@dummy/secondary"}, + }, + }, + secondary_data=external_exception_bytes, + expected_error=RuntimeError("Unknown remote error"), + expected_cause=ModuleNotFoundError("No module"), + ), + dict( + testcase_name="incompatible_cloudpickle", + data={ + "success": True, + "value": None, + "protocol": {"name": "cloudpickle", "version": "0.0.0", "manifest": {"path": "@dummy/secondary"}}, + }, + secondary_data=b"invalid_pickle_data", + expected_error=ValueError("Job execution succeeded but result retrieval failed"), + expected_cause=ValueError("cloudpickle version"), + ), + dict( + testcase_name="incompatible_result_protocol", + data={ + "success": True, + "value": None, + "protocol": {"name": "unknown_proto", "manifest": {"path": "@dummy/secondary"}}, + }, + secondary_data=b"fake_data", + expected_error=ValueError("Job execution succeeded but result retrieval failed"), + 
expected_cause=TypeError("No protocol matching"), + ), + dict( + testcase_name="incompatible_exception_protocol", + data={ + "success": False, + "value": None, + "protocol": {"name": "unknown_proto", "manifest": {"path": "@dummy/secondary"}}, + }, + secondary_data=b"fake_data", + expected_error=RuntimeError("Unknown remote error"), + expected_cause=TypeError("No protocol matching"), + ), + dict( + testcase_name="incompatible_exception_protocol_reconstructed", + data={ + "success": False, + "value": None, + "protocol": {"name": "unknown_proto", "manifest": {"path": "@dummy/secondary"}}, + "metadata": { + "type": "NotImplementedError", + "repr": "NotImplementedError('test')", + "message": "test", + "traceback": "...", + }, + }, + secondary_data=b"fake_data", + expected_error=NotImplementedError("test"), + ), + dict( + testcase_name="legacy_result_simple", + data={"success": True, "result_type": int.__qualname__, "result": 42}, + secondary_data=cp.dumps({"success": True, "result_type": int.__qualname__, "result": 42}), + expected_value=42, + ), + dict( + testcase_name="legacy_result_complex", + data={"success": True, "result_type": int.__qualname__, "result": str(DummyClass(42, "loaded label"))}, + secondary_data=cp.dumps( + {"success": True, "result_type": int.__qualname__, "result": DummyClass(42, "loaded label")} + ), + expected_value=DummyClass(42, "loaded label"), + ), + dict( + testcase_name="legacy_result_nonserializable", + data={ + "success": True, + "result_type": DummyNonserializableClass.__qualname__, + "result": str(DummyNonserializableClass()), + }, + secondary_data=None, + expected_value=str(DummyNonserializableClass()), + ), + dict( + testcase_name="legacy_exception_simple", + data={"success": False, "exc_type": "builtins.ValueError", "exc_value": "legacy error", "exc_tb": "..."}, + secondary_data=cp.dumps( + {"success": False, "exc_type": "builtins.ValueError", "exc_value": "legacy error", "exc_tb": "..."} + ), + expected_error=ValueError("legacy 
error"), + ), + dict( + testcase_name="legacy_exception_nonserializable", + data={ + "success": False, + "exc_type": "SomeNonserializableClass", + "exc_value": "legacy error", + "exc_tb": "...", + }, + secondary_data=None, + expected_error=RemoteError("SomeNonserializableClass('legacy error')"), + expected_cause=ModuleNotFoundError("Unrecognized exception"), + ), + dict( + testcase_name="legacy_exception_nonserializable_reconstructed", + data={ + "success": False, + "exc_type": "__main__.DummyNonserializableException", + "exc_value": "legacy error", + "exc_tb": "...", + }, + secondary_data=None, + expected_error=DummyNonserializableException("legacy error"), + ), + ) + def test_load_result( + self, + data: dict[str, Any], + secondary_data: Optional[bytes] = None, + expected_value: Any = None, + expected_error: Optional[Exception] = None, + expected_cause: Optional[Exception] = None, + expected_context: Optional[Exception] = None, + wrap_exceptions: bool = False, + ) -> None: + result_path = "@dummy_stage/result.json" + data_str = json.dumps(data) # NOTE: Need to do this outside mock_get_stream to make closure work correctly + + def mock_get_stream(path: str, *args: Any, **kwargs: Any) -> io.BytesIO: + # Hacky behavior: If the input path is result_path, return the encoded JSON data + # Else, return secondary_data + # Note that path must be a stage path to trigger this mock, else it'll try to read from disk + if path == result_path: + return io.BytesIO(data_str.encode("utf-8")) + if secondary_data is None: + raise sp_exceptions.SnowparkSQLException(f"No secondary data, path: {path}") + return io.BytesIO(secondary_data) + + self.mock_session.file.get_stream.side_effect = mock_get_stream + + result = u.load_result(result_path, session=self.mock_session) + self.assertIsInstance(result, ExecutionResult) + + if expected_error is not None: + with self.assertRaisesRegex(type(expected_error), re.escape(str(expected_error))) as cm: + _ = 
result.get_value(wrap_exceptions=wrap_exceptions) + if expected_cause: + actual_cause = cm.exception.__cause__ + self.assertIsInstance(actual_cause, type(expected_cause)) + self.assertIn(str(expected_cause), str(actual_cause)) + else: + self.assertIsNone(cm.exception.__cause__) + if expected_context: + actual_context = cm.exception.__context__ + self.assertIsInstance(actual_context, type(expected_context)) + self.assertIn(str(expected_context), str(actual_context)) + else: + self.assertIsNone(cm.exception.__context__) + else: + value = result.get_value(wrap_exceptions=wrap_exceptions) + load_error = result.load_error if isinstance(result, LoadedExecutionResult) else None + self.assertEqual(value, expected_value, load_error) + self.assertEqual(type(value), type(expected_value), load_error) + + @parameterized.named_parameters( # type: ignore[misc] + dict( + testcase_name="not_json", + data="this is not json", + expected_error=json.JSONDecodeError("Expecting value", "this is not json", 0), + ), + dict( + testcase_name="malformed_json", + data="{'success': True, 'value': 1", # Missing closing brace + expected_error=json.JSONDecodeError( + "Expecting property name enclosed in double quotes", "{'success': True, 'value': 1", 1 + ), + ), + dict( + testcase_name="empty_dict", + data={}, + expected_error=ValueError("Invalid result"), + ), + dict( + testcase_name="missing_multiple_fields", + data={"value": "test"}, + expected_error=ValueError("Invalid result schema"), + ), + dict( + testcase_name="missing_success_field", + data={"value": "test", "protocol": None}, + expected_error=ValueError("Invalid result schema"), + ), + ) + def test_load_result_negative( + self, + data: Union[dict[str, Any], str], + expected_error: Exception, + secondary_data: Optional[bytes] = None, + ) -> None: + result_path = "@dummy_stage/result.json" + if isinstance(data, dict): + data = json.dumps(data) # NOTE: Need to do this outside mock_get_stream to make closure work correctly + + def 
mock_get_stream(path: str, *args: Any, **kwargs: Any) -> io.BytesIO: + # Hacky behavior: If the input path is result_path, return the encoded JSON data + # Else, return secondary_data + # Note that path must be a stage path to trigger this mock, else it'll try to read from disk + if path == result_path: + return io.BytesIO(data.encode("utf-8")) + if secondary_data is None: + raise sp_exceptions.SnowparkSQLException(f"No secondary data, path: {path}") + return io.BytesIO(secondary_data) + + self.mock_session.file.get_stream.side_effect = mock_get_stream + + with self.assertRaisesRegex(type(expected_error), re.escape(str(expected_error))): + _ = u.load_result(result_path, session=self.mock_session) + + @parameterized.parameters( # type: ignore[misc] + (None, "builtins.NoneType"), + (1, "builtins.int"), + (1.0, "builtins.float"), + (False, "builtins.bool"), + ((1, 2, 3), "builtins.tuple"), + (DummyClass(), "__main__.DummyClass"), + (ValueError("test error"), "builtins.ValueError"), + (RuntimeError("test error"), "builtins.RuntimeError"), + (sp_exceptions.SnowparkClientException("test error"), "snowflake.snowpark.exceptions.SnowparkClientException"), + ) + def test_get_metadata(self, obj: Any, expected_type: str) -> None: + m = u._get_metadata(obj) + assert isinstance(m, ResultMetadata) + self.assertEqual(m.type, expected_type) + self.assertEqual(m.repr, repr(obj)) + + @parameterized.parameters( # type: ignore[misc] + (ValueError("test error"), "builtins.ValueError"), + (RuntimeError("test error"), "builtins.RuntimeError"), + (sp_exceptions.SnowparkClientException("test error"), "snowflake.snowpark.exceptions.SnowparkClientException"), + ) + def test_get_metadata_exception(self, exception: Exception, expected_type: str) -> None: + # Raise the exception to generate a traceback + try: + raise exception + except Exception as e: + err = e + + m = u._get_metadata(err) + assert isinstance(m, ExceptionMetadata) + self.assertEqual(m.type, expected_type) + 
self.assertEqual(m.repr, repr(err)) + self.assertEqual(m.message, str(err)) + self.assertNotEmpty(m.traceback) + + +if __name__ == "__main__": + absltest.main() diff --git a/snowflake/ml/jobs/_utils/BUILD.bazel b/snowflake/ml/jobs/_utils/BUILD.bazel index 014f63da..d3fffdfc 100644 --- a/snowflake/ml/jobs/_utils/BUILD.bazel +++ b/snowflake/ml/jobs/_utils/BUILD.bazel @@ -107,11 +107,6 @@ py_library( srcs = ["query_helper.py"], ) -py_library( - name = "interop_utils", - srcs = ["interop_utils.py"], -) - py_library( name = "function_payload_utils", srcs = ["function_payload_utils.py"], @@ -128,23 +123,29 @@ py_library( ) py_test( - name = "interop_utils_test", - srcs = ["interop_utils_test.py"], + name = "mljob_launcher_test", + srcs = ["mljob_launcher_test.py"], tags = ["feature:jobs"], deps = [ - ":interop_utils", + ":constants", + ":payload_scripts", + ":test_file_helper", + "//snowflake/ml/jobs/_interop:legacy", + "//snowflake/ml/jobs/_interop:utils", + "//snowflake/ml/utils:connection_params", ], ) py_test( - name = "mljob_launcher_test", + name = "mljob_launcher_legacy_test", srcs = ["mljob_launcher_test.py"], + main = "mljob_launcher_test.py", tags = ["feature:jobs"], deps = [ ":constants", - ":interop_utils", ":payload_scripts", ":test_file_helper", + "//snowflake/ml/jobs/_interop:legacy", "//snowflake/ml/utils:connection_params", ], ) @@ -155,11 +156,11 @@ py_library( "__init__.py", ], deps = [ - ":interop_utils", ":payload_utils", ":query_helper", ":spec_utils", ":stage_utils", "//snowflake/ml/_internal/utils:mixins", + "//snowflake/ml/jobs/_interop:legacy", ], ) diff --git a/snowflake/ml/jobs/_utils/constants.py b/snowflake/ml/jobs/_utils/constants.py index 1e9ebe42..ddc1b737 100644 --- a/snowflake/ml/jobs/_utils/constants.py +++ b/snowflake/ml/jobs/_utils/constants.py @@ -12,6 +12,9 @@ RESULT_PATH_ENV_VAR = "MLRS_RESULT_PATH" MIN_INSTANCES_ENV_VAR = "MLRS_MIN_INSTANCES" TARGET_INSTANCES_ENV_VAR = "SNOWFLAKE_JOBS_COUNT" +INSTANCES_MIN_WAIT_ENV_VAR = 
"MLRS_INSTANCES_MIN_WAIT" +INSTANCES_TIMEOUT_ENV_VAR = "MLRS_INSTANCES_TIMEOUT" +INSTANCES_CHECK_INTERVAL_ENV_VAR = "MLRS_INSTANCES_CHECK_INTERVAL" RUNTIME_IMAGE_TAG_ENV_VAR = "MLRS_CONTAINER_IMAGE_TAG" # Stage mount paths @@ -19,7 +22,7 @@ APP_STAGE_SUBPATH = "app" SYSTEM_STAGE_SUBPATH = "system" OUTPUT_STAGE_SUBPATH = "output" -RESULT_PATH_DEFAULT_VALUE = f"{OUTPUT_STAGE_SUBPATH}/mljob_result.pkl" +RESULT_PATH_DEFAULT_VALUE = f"{OUTPUT_STAGE_SUBPATH}/mljob_result" # Default container image information DEFAULT_IMAGE_REPO = "/snowflake/images/snowflake_images" diff --git a/snowflake/ml/jobs/_utils/feature_flags.py b/snowflake/ml/jobs/_utils/feature_flags.py index 28c20000..214cb817 100644 --- a/snowflake/ml/jobs/_utils/feature_flags.py +++ b/snowflake/ml/jobs/_utils/feature_flags.py @@ -1,16 +1,48 @@ import os from enum import Enum +from typing import Optional + + +def parse_bool_env_value(value: Optional[str], default: bool = False) -> bool: + """Parse a boolean value from an environment variable string. + + Args: + value: The environment variable value to parse (may be None). + default: The default value to return if the value is None or unrecognized. + + Returns: + True if the value is a truthy string (true, 1, yes, on - case insensitive), + False if the value is a falsy string (false, 0, no, off - case insensitive), + or the default value if the value is None or unrecognized. + """ + if value is None: + return default + + normalized_value = value.strip().lower() + if normalized_value in ("true", "1", "yes", "on"): + return True + elif normalized_value in ("false", "0", "no", "off"): + return False + else: + # For unrecognized values, return the default + return default class FeatureFlags(Enum): USE_SUBMIT_JOB_V2 = "MLRS_USE_SUBMIT_JOB_V2" - ENABLE_IMAGE_VERSION_ENV_VAR = "MLRS_ENABLE_RUNTIME_VERSIONS" + ENABLE_RUNTIME_VERSIONS = "MLRS_ENABLE_RUNTIME_VERSIONS" + + def is_enabled(self, default: bool = False) -> bool: + """Check if the feature flag is enabled. 
- def is_enabled(self) -> bool: - return os.getenv(self.value, "false").lower() == "true" + Args: + default: The default value to return if the environment variable is not set. - def is_disabled(self) -> bool: - return not self.is_enabled() + Returns: + True if the environment variable is set to a truthy value, + False if set to a falsy value, or the default value if not set. + """ + return parse_bool_env_value(os.getenv(self.value), default) def __str__(self) -> str: return self.value diff --git a/snowflake/ml/jobs/_utils/mljob_launcher_test.py b/snowflake/ml/jobs/_utils/mljob_launcher_test.py index 7d56969e..a05b6baf 100644 --- a/snowflake/ml/jobs/_utils/mljob_launcher_test.py +++ b/snowflake/ml/jobs/_utils/mljob_launcher_test.py @@ -4,12 +4,13 @@ import sys import tempfile import time -from typing import Any, Optional +from typing import Any, Optional, cast from unittest import mock from absl.testing import absltest, parameterized -from snowflake.ml.jobs._utils import constants, interop_utils +from snowflake.ml.jobs._interop import legacy +from snowflake.ml.jobs._utils import constants from snowflake.ml.jobs._utils.scripts import mljob_launcher from snowflake.ml.jobs._utils.test_file_helper import resolve_path @@ -31,6 +32,7 @@ def setUp(self) -> None: self.error_script = os.path.join(self.test_dir, "error_script.py") self.complex_script = os.path.join(self.test_dir, "complex_result_script.py") self.argument_script = os.path.join(self.test_dir, "argument_script.py") + self.nonserializable_script = os.path.join(self.test_dir, "nonserializable_result_script.py") def tearDown(self) -> None: # Clean up @@ -40,17 +42,16 @@ def tearDown(self) -> None: def test_run_script_simple(self) -> None: # Test running a simple script - result = mljob_launcher.main(self.simple_script) - self.assertTrue(result.success) - self.assertEqual(result.result["status"], "success") - self.assertEqual(result.result["value"], 42) + with self.assertNoLogs(level="WARNING"): + result = 
mljob_launcher.main(self.simple_script) + self.assertEqual(result["status"], "success") + self.assertEqual(result["value"], 42) def test_run_script_with_function(self) -> None: # Test running a script with a specified main function result = mljob_launcher.main(self.function_script, script_main_func="main_function") - self.assertTrue(result.success) - self.assertEqual(result.result["status"], "success from function") - self.assertEqual(result.result["value"], 100) + self.assertEqual(result["status"], "success from function") + self.assertEqual(result["value"], 100) @parameterized.parameters( # type: ignore[misc] (100, {"status": "success from another function", "value": 100}), @@ -61,8 +62,7 @@ def test_run_script_with_function_and_args(self, arg_value: Optional[int], expec # Test running a script with a function that takes arguments args = [] if arg_value is None else [arg_value] result = mljob_launcher.main(self.function_script, *args, script_main_func="another_function") - self.assertTrue(result.success) - self.assertEqual(result.result, expected) + self.assertEqual(result, expected) def test_run_script_invalid_function(self) -> None: # Test error when function doesn't exist @@ -72,32 +72,30 @@ def test_run_script_invalid_function(self) -> None: def test_run_script_with_args(self) -> None: # Test running a script with arguments result = mljob_launcher.main(self.argument_script, "arg1", "arg2", "--named_arg=value") - self.assertTrue(result.success) - self.assertListEqual(result.result["args"], ["arg1", "arg2", "--named_arg=value"]) + self.assertListEqual(result["args"], ["arg1", "arg2", "--named_arg=value"]) def test_main_success(self) -> None: # Test the main function with successful execution try: - result_obj = mljob_launcher.main(self.simple_script) - self.assertTrue(result_obj.success) - self.assertEqual(result_obj.result["value"], 42) - - # Check serialized results - with open(self.result_path, "rb") as f: - pickled_result: dict[str, Any] = pickle.load(f) - 
pickled_result_obj = interop_utils.ExecutionResult.from_dict(pickled_result) - self.assertTrue(pickled_result_obj.success) - assert isinstance(pickled_result_obj.result, dict) - self.assertEqual(pickled_result_obj.result["value"], 42) + result = mljob_launcher.main(self.simple_script) + self.assertEqual(result["value"], 42) + except Exception as e: + self.fail(f"main() raised exception unexpectedly: {e}") + # Check serialized results + pickled_result = _load_result_dict(self.result_path) + pickled_result_obj = legacy.ExecutionResult.from_dict(pickled_result) + self.assertTrue(pickled_result_obj.success) + assert isinstance(pickled_result_obj.result, dict) + self.assertEqual(pickled_result_obj.result["value"], 42) + + if not pickled_result.get("_converted_from_v2"): with open(self.result_json_path) as f: json_result: dict[str, Any] = json.load(f) - json_result_obj = interop_utils.ExecutionResult.from_dict(json_result) + json_result_obj = legacy.ExecutionResult.from_dict(json_result) self.assertTrue(json_result_obj.success) assert isinstance(json_result_obj.result, dict) self.assertEqual(json_result_obj.result["value"], 42) - except Exception as e: - self.fail(f"main() raised exception unexpectedly: {e}") def test_main_error(self) -> None: # Test the main function with script that raises an error @@ -105,9 +103,8 @@ def test_main_error(self) -> None: mljob_launcher.main(self.error_script) # Check serialized error results - with open(self.result_path, "rb") as f: - pickled_result: dict[str, Any] = pickle.load(f) - pickled_result_obj = interop_utils.ExecutionResult.from_dict(pickled_result) + pickled_result = _load_result_dict(self.result_path) + pickled_result_obj = legacy.ExecutionResult.from_dict(pickled_result) self.assertFalse(pickled_result_obj.success) self.assertEqual(type(pickled_result_obj.exception), RuntimeError) self.assertIn("Test error from script", str(pickled_result_obj.exception)) @@ -116,16 +113,17 @@ def test_main_error(self) -> None: 
self.assertNotIn("mljob_launcher.py", pickled_exc_tb) self.assertNotIn("runpy", pickled_exc_tb) - with open(self.result_json_path) as f: - json_result: dict[str, Any] = json.load(f) - json_result_obj = interop_utils.ExecutionResult.from_dict(json_result) - self.assertFalse(json_result_obj.success) - self.assertEqual(type(json_result_obj.exception), RuntimeError) - self.assertIn("Test error from script", str(json_result_obj.exception)) - json_exc_tb = json_result.get("exc_tb") - self.assertIsInstance(json_exc_tb, str) - self.assertNotIn("mljob_launcher.py", json_exc_tb) - self.assertNotIn("runpy", json_exc_tb) + if not pickled_result.get("_converted_from_v2"): + with open(self.result_json_path) as f: + json_result: dict[str, Any] = json.load(f) + json_result_obj = legacy.ExecutionResult.from_dict(json_result) + self.assertFalse(json_result_obj.success) + self.assertEqual(type(json_result_obj.exception), RuntimeError) + self.assertIn("Test error from script", str(json_result_obj.exception)) + json_exc_tb = json_result.get("exc_tb") + self.assertIsInstance(json_exc_tb, str) + self.assertNotIn("mljob_launcher.py", json_exc_tb) + self.assertNotIn("runpy", json_exc_tb) def test_function_error(self) -> None: # Test error in a function @@ -133,9 +131,8 @@ def test_function_error(self) -> None: mljob_launcher.main(self.error_script, script_main_func="error_function") # Check serialized error results - with open(self.result_path, "rb") as f: - pickled_result = pickle.load(f) - pickled_result_obj = interop_utils.ExecutionResult.from_dict(pickled_result) + pickled_result = _load_result_dict(self.result_path) + pickled_result_obj = legacy.ExecutionResult.from_dict(pickled_result) self.assertFalse(pickled_result_obj.success) self.assertEqual(type(pickled_result_obj.exception), ValueError) self.assertIn("Test error from function", str(pickled_result_obj.exception)) @@ -147,27 +144,26 @@ def test_complex_result_serialization(self) -> None: # Test handling of complex, 
non-JSON-serializable results try: - result_obj = mljob_launcher.main(self.complex_script) - self.assertTrue(result_obj.success) - - # Check serialized results - pickle should handle complex objects - with open(self.result_path, "rb") as f: - pickled_result = pickle.load(f) - pickled_result_obj = interop_utils.ExecutionResult.from_dict(pickled_result) - self.assertTrue(pickled_result_obj.success) - assert isinstance(pickled_result_obj.result, dict) - self.assertIsInstance(pickled_result_obj.result["custom"], CustomObject) - - # JSON should convert non-serializable objects to strings + _ = mljob_launcher.main(self.complex_script) + except Exception as e: + self.fail(f"main() raised exception unexpectedly: {e}") + + # Check serialized results - pickle should handle complex objects + pickled_result = _load_result_dict(self.result_path) + pickled_result_obj = legacy.ExecutionResult.from_dict(pickled_result) + self.assertTrue(pickled_result_obj.success) + assert isinstance(pickled_result_obj.result, dict), pickled_result + self.assertIsInstance(pickled_result_obj.result["custom"], CustomObject) + + # JSON should convert non-serializable objects to strings + if not pickled_result.get("_converted_from_v2"): with open(self.result_json_path) as f: json_result = json.load(f) - json_result_obj = interop_utils.ExecutionResult.from_dict(json_result) + json_result_obj = legacy.ExecutionResult.from_dict(json_result) self.assertTrue(json_result_obj.success) assert isinstance(json_result_obj.result, dict) self.assertIsInstance(json_result_obj.result["custom"], str) self.assertIn("CustomObject", json_result_obj.result["custom"]) - except Exception as e: - self.fail(f"main() raised exception unexpectedly: {e}") def test_invalid_script_path(self) -> None: # Test with non-existent script path @@ -175,13 +171,11 @@ def test_invalid_script_path(self) -> None: with self.assertRaises(FileNotFoundError): mljob_launcher.main(nonexistent_path) - @absltest.mock.patch("cloudpickle.dump") # 
type: ignore[misc] - def test_result_pickling_error(self, mock_dump: absltest.mock.MagicMock) -> None: - # Test handling of pickling errors by creating an unpicklable result - # (by monkeypatching cloudpickle.dump to raise an exception) - mock_dump.side_effect = pickle.PicklingError("Mocked pickling error") - with self.assertWarns(RuntimeWarning): - mljob_launcher.main(self.simple_script) + def test_result_pickling_error(self) -> None: + with self.assertLogs(level="WARNING"): + result = mljob_launcher.main(self.nonserializable_script) + # Even with pickling error, main() should still return the result directly + self.assertEqual(str(result), "100") @mock.patch.dict("sys.modules", {"common_utils": mock.MagicMock()}) @mock.patch("common_utils.common_util") @@ -300,5 +294,35 @@ def test_wait_for_instances_timeout(self, mock_sleep: mock.MagicMock, mock_commo self.assertIn("only 1 available", str(cm.exception)) +def _load_result_dict(path: str) -> dict[str, Any]: + """Handle both v1 and v2 result formats, converting final result to v1 format.""" + with open(path, "rb") as f: + try: + return cast(dict[str, Any], pickle.load(f)) + except pickle.UnpicklingError: + f.seek(0) + result_v2_dict = json.load(f) + result_obj = None + if result_protocol := result_v2_dict.get("protocol", {}): + result_path = result_protocol["manifest"]["path"] + with open(result_path, "rb") as f2: + result_obj = pickle.load(f2) + else: + result_obj = result_v2_dict["value"] + + result_v1_dict = { + "success": result_v2_dict["success"], + "_converted_from_v2": True, + } + if result_v2_dict["success"]: + result_v1_dict["result_type"] = type(result_obj).__qualname__ + result_v1_dict["result"] = result_obj + else: + result_v1_dict["exc_type"] = type(result_obj).__qualname__ + result_v1_dict["exc_value"] = result_obj + result_v1_dict["exc_tb"] = result_v2_dict["metadata"]["traceback"] + return result_v1_dict + + if __name__ == "__main__": absltest.main() diff --git 
a/snowflake/ml/jobs/_utils/payload_utils.py b/snowflake/ml/jobs/_utils/payload_utils.py index 56559489..a9548884 100644 --- a/snowflake/ml/jobs/_utils/payload_utils.py +++ b/snowflake/ml/jobs/_utils/payload_utils.py @@ -268,7 +268,7 @@ def upload_payloads(session: snowpark.Session, stage_path: PurePath, *payload_sp # can't handle directories. Reduce the number of PUT operations by using # wildcard patterns to batch upload files with the same extension. upload_path_patterns = set() - for p in source_path.resolve().rglob("*"): + for p in source_path.rglob("*"): if p.is_dir(): continue if p.name.startswith("."): diff --git a/snowflake/ml/jobs/_utils/scripts/mljob_launcher.py b/snowflake/ml/jobs/_utils/scripts/mljob_launcher.py index 7e802b86..e5179469 100644 --- a/snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +++ b/snowflake/ml/jobs/_utils/scripts/mljob_launcher.py @@ -9,19 +9,23 @@ import sys import time import traceback -import warnings -from pathlib import Path from typing import Any, Optional -import cloudpickle - -from snowflake.ml.jobs._utils import constants -from snowflake.snowpark import Session - -try: - from snowflake.ml._internal.utils.connection_params import SnowflakeLoginOptions -except ImportError: - from snowflake.ml.utils.connection_params import SnowflakeLoginOptions +# Ensure payload directory is in sys.path for module imports before importing other modules +# This is needed to support relative imports in user scripts and to allow overriding +# modules using modules in the payload directory +# TODO: Inject the environment variable names at job submission time +STAGE_MOUNT_PATH = os.environ.get("MLRS_STAGE_MOUNT_PATH", "/mnt/job_stage") +JOB_RESULT_PATH = os.environ.get("MLRS_RESULT_PATH", "output/mljob_result.pkl") +PAYLOAD_PATH = os.environ.get("MLRS_PAYLOAD_DIR") +if PAYLOAD_PATH and not os.path.isabs(PAYLOAD_PATH): + PAYLOAD_PATH = os.path.join(STAGE_MOUNT_PATH, PAYLOAD_PATH) +if PAYLOAD_PATH and PAYLOAD_PATH not in sys.path: + 
sys.path.insert(0, PAYLOAD_PATH) + +# Imports below must come after sys.path modification to support module overrides +import snowflake.ml.jobs._utils.constants # noqa: E402 +import snowflake.snowpark # noqa: E402 # Configure logging logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") @@ -33,48 +37,74 @@ # not have the latest version of the code # Log start and end messages LOG_START_MSG = getattr( - constants, + snowflake.ml.jobs._utils.constants, "LOG_START_MSG", "--------------------------------\nML job started\n--------------------------------", ) LOG_END_MSG = getattr( - constants, + snowflake.ml.jobs._utils.constants, "LOG_END_MSG", "--------------------------------\nML job finished\n--------------------------------", ) +MIN_INSTANCES_ENV_VAR = getattr( + snowflake.ml.jobs._utils.constants, + "MIN_INSTANCES_ENV_VAR", + "MLRS_MIN_INSTANCES", +) +TARGET_INSTANCES_ENV_VAR = getattr( + snowflake.ml.jobs._utils.constants, + "TARGET_INSTANCES_ENV_VAR", + "SNOWFLAKE_JOBS_COUNT", +) +INSTANCES_MIN_WAIT_ENV_VAR = getattr( + snowflake.ml.jobs._utils.constants, + "INSTANCES_MIN_WAIT_ENV_VAR", + "MLRS_INSTANCES_MIN_WAIT", +) +INSTANCES_TIMEOUT_ENV_VAR = getattr( + snowflake.ml.jobs._utils.constants, + "INSTANCES_TIMEOUT_ENV_VAR", + "MLRS_INSTANCES_TIMEOUT", +) +INSTANCES_CHECK_INTERVAL_ENV_VAR = getattr( + snowflake.ml.jobs._utils.constants, + "INSTANCES_CHECK_INTERVAL_ENV_VAR", + "MLRS_INSTANCES_CHECK_INTERVAL", +) -# min_instances environment variable name -MIN_INSTANCES_ENV_VAR = getattr(constants, "MIN_INSTANCES_ENV_VAR", "MLRS_MIN_INSTANCES") -TARGET_INSTANCES_ENV_VAR = getattr(constants, "TARGET_INSTANCES_ENV_VAR", "SNOWFLAKE_JOBS_COUNT") - -# Fallbacks in case of SnowML version mismatch -STAGE_MOUNT_PATH_ENV_VAR = getattr(constants, "STAGE_MOUNT_PATH_ENV_VAR", "MLRS_STAGE_MOUNT_PATH") -RESULT_PATH_ENV_VAR = getattr(constants, "RESULT_PATH_ENV_VAR", "MLRS_RESULT_PATH") -PAYLOAD_DIR_ENV_VAR = getattr(constants, 
"PAYLOAD_DIR_ENV_VAR", "MLRS_PAYLOAD_DIR") # Constants for the wait_for_instances function -MIN_WAIT_TIME = float(os.getenv("MLRS_INSTANCES_MIN_WAIT") or -1) # seconds -TIMEOUT = float(os.getenv("MLRS_INSTANCES_TIMEOUT") or 720) # seconds -CHECK_INTERVAL = float(os.getenv("MLRS_INSTANCES_CHECK_INTERVAL") or 10) # seconds +MIN_INSTANCES = int(os.environ.get(MIN_INSTANCES_ENV_VAR) or "1") +TARGET_INSTANCES = int(os.environ.get(TARGET_INSTANCES_ENV_VAR) or MIN_INSTANCES) +MIN_WAIT_TIME = float(os.getenv(INSTANCES_MIN_WAIT_ENV_VAR) or -1) # seconds +TIMEOUT = float(os.getenv(INSTANCES_TIMEOUT_ENV_VAR) or 720) # seconds +CHECK_INTERVAL = float(os.getenv(INSTANCES_CHECK_INTERVAL_ENV_VAR) or 10) # seconds -STAGE_MOUNT_PATH = os.environ.get(STAGE_MOUNT_PATH_ENV_VAR, "/mnt/job_stage") -JOB_RESULT_PATH = os.environ.get(RESULT_PATH_ENV_VAR, "output/mljob_result.pkl") +def save_mljob_result_v2(value: Any, is_error: bool, path: str) -> None: + from snowflake.ml.jobs._interop import ( + results as interop_result, + utils as interop_utils, + ) + + result_obj = interop_result.ExecutionResult(success=not is_error, value=value) + interop_utils.save_result(result_obj, path) -try: - from snowflake.ml.jobs._utils.interop_utils import ExecutionResult -except ImportError: + +def save_mljob_result_v1(value: Any, is_error: bool, path: str) -> None: from dataclasses import dataclass + import cloudpickle + + # Directly in-line the ExecutionResult class since the legacy type + # instead of attempting to import the to-be-deprecated + # snowflake.ml.jobs._utils.interop module + # Eventually, this entire function will be removed in favor of v2 @dataclass(frozen=True) - class ExecutionResult: # type: ignore[no-redef] + class ExecutionResult: result: Optional[Any] = None exception: Optional[BaseException] = None - @property - def success(self) -> bool: - return self.exception is None - def to_dict(self) -> dict[str, Any]: """Return the serializable dictionary.""" if isinstance(self.exception, 
BaseException): @@ -91,14 +121,45 @@ def to_dict(self) -> dict[str, Any]: "result": self.result, } + # Create a custom JSON encoder that converts non-serializable types to strings + class SimpleJSONEncoder(json.JSONEncoder): + def default(self, obj: Any) -> Any: + try: + return super().default(obj) + except TypeError: + return f"Unserializable object: {repr(obj)}" + + result_obj = ExecutionResult(result=None if is_error else value, exception=value if is_error else None) + result_dict = result_obj.to_dict() + try: + # Serialize result using cloudpickle + result_pickle_path = path + with open(result_pickle_path, "wb") as f: + cloudpickle.dump(result_dict, f) # Pickle dictionary form for compatibility + except Exception as pkl_exc: + logger.warning(f"Failed to pickle result to {result_pickle_path}: {pkl_exc}") -# Create a custom JSON encoder that converts non-serializable types to strings -class SimpleJSONEncoder(json.JSONEncoder): - def default(self, obj: Any) -> Any: - try: - return super().default(obj) - except TypeError: - return f"Unserializable object: {repr(obj)}" + try: + # Serialize result to JSON as fallback path in case of cross version incompatibility + result_json_path = os.path.splitext(path)[0] + ".json" + with open(result_json_path, "w") as f: + json.dump(result_dict, f, indent=2, cls=SimpleJSONEncoder) + except Exception as json_exc: + logger.warning(f"Failed to serialize JSON result to {result_json_path}: {json_exc}") + + +def save_mljob_result(result_obj: Any, is_error: bool, path: str) -> None: + """Saves the result or error message to a file in the stage mount path. + + Args: + result_obj: The result object to save, either the return value or the exception. + is_error: Whether the result_obj is a raised exception. + path: The file path to save the result to. 
+ """ + try: + save_mljob_result_v2(result_obj, is_error, path) + except ImportError: + save_mljob_result_v1(result_obj, is_error, path) def wait_for_instances( @@ -225,20 +286,10 @@ def run_script(script_path: str, *script_args: Any, main_func: Optional[str] = N original_argv = sys.argv sys.argv = [script_path, *script_args] - # Ensure payload directory is in sys.path for module imports - # This is needed because mljob_launcher.py is now in /mnt/job_stage/system - # but user scripts are in the payload directory and may import from each other - payload_dir = os.environ.get(PAYLOAD_DIR_ENV_VAR) - if payload_dir and not os.path.isabs(payload_dir): - payload_dir = os.path.join(STAGE_MOUNT_PATH, payload_dir) - if payload_dir and payload_dir not in sys.path: - sys.path.insert(0, payload_dir) - try: - if main_func: # Use importlib for scripts with a main function defined - module_name = Path(script_path).stem + module_name = os.path.splitext(os.path.basename(script_path))[0] spec = importlib.util.spec_from_file_location(module_name, script_path) assert spec is not None assert spec.loader is not None @@ -262,7 +313,7 @@ def run_script(script_path: str, *script_args: Any, main_func: Optional[str] = N sys.argv = original_argv -def main(script_path: str, *script_args: Any, script_main_func: Optional[str] = None) -> ExecutionResult: +def main(script_path: str, *script_args: Any, script_main_func: Optional[str] = None) -> Any: """Executes a Python script and serializes the result to JOB_RESULT_PATH. Args: @@ -271,55 +322,53 @@ def main(script_path: str, *script_args: Any, script_main_func: Optional[str] = script_main_func (str, optional): The name of the function to call in the script (if any). Returns: - ExecutionResult: Object containing execution results. + Any: The result of the script execution. Raises: Exception: Re-raises any exception caught during script execution. """ - # Ensure the output directory exists before trying to write result files. 
- result_abs_path = ( - JOB_RESULT_PATH if os.path.isabs(JOB_RESULT_PATH) else os.path.join(STAGE_MOUNT_PATH, JOB_RESULT_PATH) - ) - output_dir = os.path.dirname(result_abs_path) - os.makedirs(output_dir, exist_ok=True) + try: + from snowflake.ml._internal.utils.connection_params import SnowflakeLoginOptions + except ImportError: + from snowflake.ml.utils.connection_params import SnowflakeLoginOptions + # Initialize Ray if available try: import ray ray.init(address="auto") except ModuleNotFoundError: - warnings.warn("Ray is not installed, skipping Ray initialization", ImportWarning, stacklevel=1) + logger.debug("Ray is not installed, skipping Ray initialization") # Create a Snowpark session before starting # Session can be retrieved from using snowflake.snowpark.context.get_active_session() config = SnowflakeLoginOptions() config["client_session_keep_alive"] = "True" - session = Session.builder.configs(config).create() # noqa: F841 + session = snowflake.snowpark.Session.builder.configs(config).create() # noqa: F841 + execution_result_is_error = False + execution_result_value = None try: - # Wait for minimum required instances if specified - min_instances_str = os.environ.get(MIN_INSTANCES_ENV_VAR) or "1" - target_instances_str = os.environ.get(TARGET_INSTANCES_ENV_VAR) or min_instances_str - if target_instances_str and int(target_instances_str) > 1: - wait_for_instances( - int(min_instances_str), - int(target_instances_str), - min_wait_time=MIN_WAIT_TIME, - timeout=TIMEOUT, - check_interval=CHECK_INTERVAL, - ) - - # Log start marker for user script execution + # Wait for minimum required instances before starting user script execution + wait_for_instances( + MIN_INSTANCES, + TARGET_INSTANCES, + min_wait_time=MIN_WAIT_TIME, + timeout=TIMEOUT, + check_interval=CHECK_INTERVAL, + ) + + # Log start marker before starting user script execution print(LOG_START_MSG) # noqa: T201 - # Run the script with the specified arguments - result = run_script(script_path, 
*script_args, main_func=script_main_func) + # Run the user script + execution_result_value = run_script(script_path, *script_args, main_func=script_main_func) # Log end marker for user script execution print(LOG_END_MSG) # noqa: T201 - result_obj = ExecutionResult(result=result) - return result_obj + return execution_result_value + except Exception as e: tb = e.__traceback__ skip_files = {__file__, runpy.__file__} @@ -328,35 +377,23 @@ def main(script_path: str, *script_args: Any, script_main_func: Optional[str] = tb = tb.tb_next cleaned_ex = copy.copy(e) # Need to create a mutable copy of exception to set __traceback__ cleaned_ex = cleaned_ex.with_traceback(tb) - result_obj = ExecutionResult(exception=cleaned_ex) + execution_result_value = cleaned_ex + execution_result_is_error = True raise finally: - result_dict = result_obj.to_dict() - try: - # Serialize result using cloudpickle - result_pickle_path = result_abs_path - with open(result_pickle_path, "wb") as f: - cloudpickle.dump(result_dict, f) # Pickle dictionary form for compatibility - except Exception as pkl_exc: - warnings.warn(f"Failed to pickle result to {result_pickle_path}: {pkl_exc}", RuntimeWarning, stacklevel=1) - - try: - # Serialize result to JSON as fallback path in case of cross version incompatibility - # TODO: Manually convert non-serializable types to strings - result_json_path = os.path.splitext(result_abs_path)[0] + ".json" - with open(result_json_path, "w") as f: - json.dump(result_dict, f, indent=2, cls=SimpleJSONEncoder) - except Exception as json_exc: - warnings.warn( - f"Failed to serialize JSON result to {result_json_path}: {json_exc}", RuntimeWarning, stacklevel=1 - ) - - # Close the session after serializing the result + # Ensure the output directory exists before trying to write result files. 
+ result_abs_path = ( + JOB_RESULT_PATH if os.path.isabs(JOB_RESULT_PATH) else os.path.join(STAGE_MOUNT_PATH, JOB_RESULT_PATH) + ) + output_dir = os.path.dirname(result_abs_path) + os.makedirs(output_dir, exist_ok=True) + + # Save the result before closing the session + save_mljob_result(execution_result_value, execution_result_is_error, result_abs_path) session.close() if __name__ == "__main__": - # Parse command line arguments parser = argparse.ArgumentParser(description="Launch a Python script and save the result") parser.add_argument("script_path", help="Path to the Python script to execute") parser.add_argument("script_args", nargs="*", help="Arguments to pass to the script") diff --git a/snowflake/ml/jobs/_utils/spec_utils.py b/snowflake/ml/jobs/_utils/spec_utils.py index 216a22f8..3a197846 100644 --- a/snowflake/ml/jobs/_utils/spec_utils.py +++ b/snowflake/ml/jobs/_utils/spec_utils.py @@ -104,7 +104,7 @@ def _get_image_spec( image_tag = runtime_environment else: container_image = runtime_environment - elif feature_flags.FeatureFlags.ENABLE_IMAGE_VERSION_ENV_VAR.is_enabled(): + elif feature_flags.FeatureFlags.ENABLE_RUNTIME_VERSIONS.is_enabled(): container_image = _get_runtime_image(session, hardware) # type: ignore[arg-type] container_image = container_image or f"{image_repo}/{image_name}:{image_tag}" @@ -266,6 +266,7 @@ def generate_service_spec( {"name": "ray-client-server-endpoint", "port": 10001, "protocol": "TCP"}, {"name": "ray-gcs-endpoint", "port": 12001, "protocol": "TCP"}, {"name": "ray-dashboard-grpc-endpoint", "port": 12002, "protocol": "TCP"}, + {"name": "ray-dashboard-endpoint", "port": 12003, "protocol": "TCP"}, {"name": "ray-object-manager-endpoint", "port": 12011, "protocol": "TCP"}, {"name": "ray-node-manager-endpoint", "port": 12012, "protocol": "TCP"}, {"name": "ray-runtime-agent-endpoint", "port": 12013, "protocol": "TCP"}, diff --git a/snowflake/ml/jobs/_utils/test_files/mljob_launcher_tests/nonserializable_result_script.py 
b/snowflake/ml/jobs/_utils/test_files/mljob_launcher_tests/nonserializable_result_script.py new file mode 100644 index 00000000..94eded89 --- /dev/null +++ b/snowflake/ml/jobs/_utils/test_files/mljob_launcher_tests/nonserializable_result_script.py @@ -0,0 +1,13 @@ +import threading + + +class NonserializableClass: + def __init__(self, value) -> None: + self.value = value + self._lock = threading.Lock() # Non-serializable attribute + + def __str__(self) -> str: + return str(self.value) + + +__return__ = NonserializableClass(100) diff --git a/snowflake/ml/jobs/_utils/types.py b/snowflake/ml/jobs/_utils/types.py index 31357879..112da206 100644 --- a/snowflake/ml/jobs/_utils/types.py +++ b/snowflake/ml/jobs/_utils/types.py @@ -11,6 +11,7 @@ "CANCELLING", "CANCELLED", "INTERNAL_ERROR", + "DELETED", ] @@ -106,3 +107,12 @@ class ImageSpec: resource_requests: ComputeResources resource_limits: ComputeResources container_image: str + + +@dataclass(frozen=True) +class ServiceInfo: + database_name: str + schema_name: str + status: str + compute_pool: str + target_instances: int diff --git a/snowflake/ml/jobs/job.py b/snowflake/ml/jobs/job.py index 3188ac7e..c1cb997e 100644 --- a/snowflake/ml/jobs/job.py +++ b/snowflake/ml/jobs/job.py @@ -12,12 +12,19 @@ from snowflake.ml._internal import telemetry from snowflake.ml._internal.utils import identifier from snowflake.ml._internal.utils.mixins import SerializableSessionMixin -from snowflake.ml.jobs._utils import constants, interop_utils, query_helper, types +from snowflake.ml.jobs._interop import results as interop_result, utils as interop_utils +from snowflake.ml.jobs._utils import ( + constants, + payload_utils, + query_helper, + stage_utils, + types, +) from snowflake.snowpark import Row, context as sp_context from snowflake.snowpark.exceptions import SnowparkSQLException _PROJECT = "MLJob" -TERMINAL_JOB_STATUSES = {"FAILED", "DONE", "CANCELLED", "INTERNAL_ERROR"} +TERMINAL_JOB_STATUSES = {"FAILED", "DONE", "CANCELLED", 
"INTERNAL_ERROR", "DELETED"} T = TypeVar("T") @@ -36,7 +43,12 @@ def __init__( self._session = session or sp_context.get_active_session() self._status: types.JOB_STATUS = "PENDING" - self._result: Optional[interop_utils.ExecutionResult] = None + self._result: Optional[interop_result.ExecutionResult] = None + + @cached_property + def _service_info(self) -> types.ServiceInfo: + """Get the job's service info.""" + return _resolve_service_info(self.id, self._session) @cached_property def name(self) -> str: @@ -44,7 +56,7 @@ def name(self) -> str: @cached_property def target_instances(self) -> int: - return _get_target_instances(self._session, self.id) + return self._service_info.target_instances @cached_property def min_instances(self) -> int: @@ -69,8 +81,7 @@ def status(self) -> types.JOB_STATUS: @cached_property def _compute_pool(self) -> str: """Get the job's compute pool name.""" - row = _get_service_info(self._session, self.id) - return cast(str, row["compute_pool"]) + return self._service_info.compute_pool @property def _service_spec(self) -> dict[str, Any]: @@ -82,7 +93,13 @@ def _service_spec(self) -> dict[str, Any]: @property def _container_spec(self) -> dict[str, Any]: """Get the job's main container spec.""" - containers = self._service_spec["spec"]["containers"] + try: + containers = self._service_spec["spec"]["containers"] + except SnowparkSQLException as e: + if e.sql_error_code == 2003: + # If the job is deleted, the service spec is not available + return {} + raise if len(containers) == 1: return cast(dict[str, Any], containers[0]) try: @@ -105,22 +122,28 @@ def _result_path(self) -> str: if result_path_str is None: raise RuntimeError(f"Job {self.name} doesn't have a result path configured") - # If result path is relative, it is relative to the stage mount path - result_path = Path(result_path_str) - if not result_path.is_absolute(): - return f"{self._stage_path}/{result_path.as_posix()}" + return self._transform_path(result_path_str) - # If result 
path is absolute, it is relative to the stage mount path + def _transform_path(self, path_str: str) -> str: + """Transform a local path within the container to a stage path.""" + path = payload_utils.resolve_path(path_str) + if isinstance(path, stage_utils.StagePath): + # Stage paths need no transformation + return path.as_posix() + if not path.is_absolute(): + # Assume relative paths are relative to stage mount path + return f"{self._stage_path}/{path.as_posix()}" + + # If result path is absolute, rebase it onto the stage mount path + # TODO: Rather than matching by name, use the longest mount path which matches volume_mounts = self._container_spec["volumeMounts"] stage_mount_str = next(v for v in volume_mounts if v.get("name") == constants.STAGE_VOLUME_NAME)["mountPath"] stage_mount = Path(stage_mount_str) try: - relative_path = result_path.relative_to(stage_mount) + relative_path = path.relative_to(stage_mount) return f"{self._stage_path}/{relative_path.as_posix()}" except ValueError: - raise ValueError( - f"Result path {result_path} is absolute, but should be relative to stage mount {stage_mount}" - ) + raise ValueError(f"Result path {path} is absolute, but should be relative to stage mount {stage_mount}") @overload def get_logs( @@ -165,7 +188,14 @@ def get_logs( Returns: The job's execution logs. 
""" - logs = _get_logs(self._session, self.id, limit, instance_id, self._container_spec["name"], verbose) + logs = _get_logs( + self._session, + self.id, + limit, + instance_id, + self._container_spec["name"] if "name" in self._container_spec else constants.DEFAULT_CONTAINER_NAME, + verbose, + ) assert isinstance(logs, str) # mypy if as_list: return logs.splitlines() @@ -218,7 +248,6 @@ def wait(self, timeout: float = -1) -> types.JOB_STATUS: delay = min(delay * 1.2, constants.JOB_POLL_MAX_DELAY_SECONDS) # Exponential backoff return self.status - @snowpark._internal.utils.private_preview(version="1.8.2") @telemetry.send_api_usage_telemetry(project=_PROJECT, func_params_to_log=["timeout"]) def result(self, timeout: float = -1) -> T: """ @@ -237,13 +266,13 @@ def result(self, timeout: float = -1) -> T: if self._result is None: self.wait(timeout) try: - self._result = interop_utils.fetch_result(self._session, self._result_path) + self._result = interop_utils.load_result( + self._result_path, session=self._session, path_transform=self._transform_path + ) except Exception as e: - raise RuntimeError(f"Failed to retrieve result for job (id={self.name})") from e + raise RuntimeError(f"Failed to retrieve result for job, error: {e!r}") from e - if self._result.success: - return cast(T, self._result.result) - raise RuntimeError(f"Job execution failed (id={self.name})") from self._result.exception + return cast(T, self._result.get_value()) @telemetry.send_api_usage_telemetry(project=_PROJECT) def cancel(self) -> None: @@ -256,22 +285,28 @@ def cancel(self) -> None: self._session.sql(f"CALL {self.id}!spcs_cancel_job()").collect() logger.debug(f"Cancellation requested for job {self.id}") except SnowparkSQLException as e: - raise RuntimeError(f"Failed to cancel job {self.id}: {e.message}") from e + raise RuntimeError(f"Failed to cancel job, error: {e!r}") from e @telemetry.send_api_usage_telemetry(project=_PROJECT, func_params_to_log=["job_id", "instance_id"]) def 
_get_status(session: snowpark.Session, job_id: str, instance_id: Optional[int] = None) -> types.JOB_STATUS: """Retrieve job or job instance execution status.""" - if instance_id is not None: - # Get specific instance status - rows = session.sql("SHOW SERVICE INSTANCES IN SERVICE IDENTIFIER(?)", params=(job_id,)).collect() - for row in rows: - if row["instance_id"] == str(instance_id): - return cast(types.JOB_STATUS, row["status"]) - raise ValueError(f"Instance {instance_id} not found in job {job_id}") - else: - row = _get_service_info(session, job_id) - return cast(types.JOB_STATUS, row["status"]) + try: + if instance_id is not None: + # Get specific instance status + rows = query_helper.run_query(session, "SHOW SERVICE INSTANCES IN SERVICE IDENTIFIER(?)", params=(job_id,)) + for row in rows: + if row["instance_id"] == str(instance_id): + return cast(types.JOB_STATUS, row["status"]) + raise ValueError(f"Instance {instance_id} not found in job {job_id}") + else: + row = _get_service_info(session, job_id) + return cast(types.JOB_STATUS, row["status"]) + except SnowparkSQLException as e: + if e.sql_error_code == 2003: + row = _get_service_info_spcs(session, job_id) + return cast(types.JOB_STATUS, row["STATUS"]) + raise @telemetry.send_api_usage_telemetry(project=_PROJECT, func_params_to_log=["job_id"]) @@ -542,8 +577,21 @@ def _get_compute_pool_info(session: snowpark.Session, compute_pool: str) -> Row: @telemetry.send_api_usage_telemetry(project=_PROJECT, func_params_to_log=["job_id"]) def _get_target_instances(session: snowpark.Session, job_id: str) -> int: - row = _get_service_info(session, job_id) - return int(row["target_instances"]) + try: + row = _get_service_info(session, job_id) + return int(row["target_instances"]) + except SnowparkSQLException as e: + if e.sql_error_code == 2003: + row = _get_service_info_spcs(session, job_id) + try: + params = json.loads(row["PARAMETERS"]) + if isinstance(params, dict): + return int(params.get("REPLICAS", 1)) + else: + 
return 1 + except (json.JSONDecodeError, ValueError): + return 1 + raise def _get_logs_spcs( @@ -581,3 +629,87 @@ def _get_logs_spcs( query.append(f" LIMIT {limit};") rows = session.sql("\n".join(query)).collect() return rows + + +def _get_service_info_spcs(session: snowpark.Session, job_id: str) -> Any: + """ + Retrieve the service info from the SPCS interface. + + Args: + session (Session): The Snowpark session to use. + job_id (str): The job ID. + + Returns: + Any: The service info. + + Raises: + SnowparkSQLException: If the job does not exist or is too old to retrieve. + """ + db, schema, name = identifier.parse_schema_level_object_identifier(job_id) + db = db or session.get_current_database() + schema = schema or session.get_current_schema() + rows = query_helper.run_query( + session, + """ + select DATABASE_NAME, SCHEMA_NAME, NAME, STATUS, COMPUTE_POOL_NAME, PARAMETERS + from table(snowflake.spcs.get_job_history()) + where database_name = ? and schema_name = ? and name = ? + """, + params=(db, schema, name), + ) + if rows: + return rows[0] + else: + raise SnowparkSQLException(f"Job {job_id} does not exist or could not be retrieved", sql_error_code=2003) + + +def _resolve_service_info(id: str, session: snowpark.Session) -> types.ServiceInfo: + try: + row = _get_service_info(session, id) + except SnowparkSQLException as e: + if e.sql_error_code == 2003: + row = _get_service_info_spcs(session, id) + else: + raise + if not row: + raise SnowparkSQLException(f"Job {id} does not exist or could not be retrieved", sql_error_code=2003) + + if "compute_pool" in row: + compute_pool = row["compute_pool"] + elif "COMPUTE_POOL_NAME" in row: + compute_pool = row["COMPUTE_POOL_NAME"] + else: + raise ValueError(f"compute_pool not found in row: {row}") + + if "status" in row: + status = row["status"] + elif "STATUS" in row: + status = row["STATUS"] + else: + raise ValueError(f"status not found in row: {row}") + # Normalize target_instances + target_instances: int + if 
"target_instances" in row and row["target_instances"] is not None: + try: + target_instances = int(row["target_instances"]) + except (ValueError, TypeError): + target_instances = 1 + elif "PARAMETERS" in row and row["PARAMETERS"]: + try: + params = json.loads(row["PARAMETERS"]) + target_instances = int(params.get("REPLICAS", 1)) if isinstance(params, dict) else 1 + except (json.JSONDecodeError, ValueError, TypeError): + target_instances = 1 + else: + target_instances = 1 + + database_name = row["database_name"] if "database_name" in row else row["DATABASE_NAME"] + schema_name = row["schema_name"] if "schema_name" in row else row["SCHEMA_NAME"] + + return types.ServiceInfo( + database_name=database_name, + schema_name=schema_name, + status=cast(types.JOB_STATUS, status), + compute_pool=cast(str, compute_pool), + target_instances=target_instances, + ) diff --git a/snowflake/ml/jobs/manager.py b/snowflake/ml/jobs/manager.py index d462034e..8898304c 100644 --- a/snowflake/ml/jobs/manager.py +++ b/snowflake/ml/jobs/manager.py @@ -21,6 +21,7 @@ spec_utils, types, ) +from snowflake.snowpark._internal import utils as sp_utils from snowflake.snowpark.context import get_active_session from snowflake.snowpark.exceptions import SnowparkSQLException from snowflake.snowpark.functions import coalesce, col, lit, when @@ -179,8 +180,10 @@ def get_job(job_id: str, session: Optional[snowpark.Session] = None) -> jb.MLJob _ = job._service_spec return job except SnowparkSQLException as e: - if "does not exist" in e.message: - raise ValueError(f"Job does not exist: {job_id}") from e + if e.sql_error_code == 2003: + job = jb.MLJob[Any](job_id, session=session) + _ = job.status + return job raise @@ -446,7 +449,7 @@ def _submit_job( Raises: ValueError: If database or schema value(s) are invalid RuntimeError: If schema is not specified in session context or job submission - snowpark.exceptions.SnowparkSQLException: if failed to upload payload + SnowparkSQLException: if failed to upload 
payload """ session = _ensure_session(session) @@ -512,49 +515,44 @@ def _submit_job( uploaded_payload = payload_utils.JobPayload( source, entrypoint=entrypoint, pip_requirements=pip_requirements, additional_payloads=imports ).upload(session, stage_path) - except snowpark.exceptions.SnowparkSQLException as e: + except SnowparkSQLException as e: if e.sql_error_code == 90106: raise RuntimeError( "Please specify a schema, either in the session context or as a parameter in the job submission" ) raise - # FIXME: Temporary patches, remove this after v1 is deprecated - if target_instances > 1: - default_spec_overrides = { - "spec": { - "endpoints": [ - {"name": "ray-dashboard-endpoint", "port": 12003, "protocol": "TCP"}, - ] - }, - } - if spec_overrides: - spec_overrides = spec_utils.merge_patch( - default_spec_overrides, spec_overrides, display_name="spec_overrides" - ) - else: - spec_overrides = default_spec_overrides - - if feature_flags.FeatureFlags.USE_SUBMIT_JOB_V2.is_enabled(): + if feature_flags.FeatureFlags.USE_SUBMIT_JOB_V2.is_enabled(default=True): # Add default env vars (extracted from spec_utils.generate_service_spec) combined_env_vars = {**uploaded_payload.env_vars, **(env_vars or {})} - return _do_submit_job_v2( - session=session, - payload=uploaded_payload, - args=args, - env_vars=combined_env_vars, - spec_overrides=spec_overrides, - compute_pool=compute_pool, - job_id=job_id, - external_access_integrations=external_access_integrations, - query_warehouse=query_warehouse, - target_instances=target_instances, - min_instances=min_instances, - enable_metrics=enable_metrics, - use_async=True, - runtime_environment=runtime_environment, - ) + try: + return _do_submit_job_v2( + session=session, + payload=uploaded_payload, + args=args, + env_vars=combined_env_vars, + spec_overrides=spec_overrides, + compute_pool=compute_pool, + job_id=job_id, + external_access_integrations=external_access_integrations, + query_warehouse=query_warehouse, + 
target_instances=target_instances, + min_instances=min_instances, + enable_metrics=enable_metrics, + use_async=True, + runtime_environment=runtime_environment, + ) + except SnowparkSQLException as e: + if not (e.sql_error_code == 90237 and sp_utils.is_in_stored_procedure()): # type: ignore[no-untyped-call] + raise + # SNOW-2390287: SYSTEM$EXECUTE_ML_JOB() is erroneously blocked in owner's rights + # stored procedures. This will be fixed in an upcoming release. + logger.warning( + "Job submission using V2 failed with error {}. Falling back to V1.".format( + str(e).split("\n", 1)[0], + ) + ) # Fall back to v1 # Generate service spec @@ -688,7 +686,7 @@ def _do_submit_job_v2( # for the image tag or full image URL, we use that directly if runtime_environment: spec_options["RUNTIME"] = runtime_environment - elif feature_flags.FeatureFlags.ENABLE_IMAGE_VERSION_ENV_VAR.is_enabled(): + elif feature_flags.FeatureFlags.ENABLE_RUNTIME_VERSIONS.is_enabled(): # when feature flag is enabled, we get the local python version and wrap it in a dict # in system function, we can know whether it is python version or image tag or full image URL through the format spec_options["RUNTIME"] = json.dumps({"pythonVersion": f"{sys.version_info.major}.{sys.version_info.minor}"}) diff --git a/snowflake/ml/model/_client/model/BUILD.bazel b/snowflake/ml/model/_client/model/BUILD.bazel index 80002f81..d500c189 100644 --- a/snowflake/ml/model/_client/model/BUILD.bazel +++ b/snowflake/ml/model/_client/model/BUILD.bazel @@ -51,6 +51,7 @@ py_library( "//snowflake/ml/model/_client/ops:model_ops", "//snowflake/ml/model/_client/ops:service_ops", "//snowflake/ml/model/_model_composer/model_manifest:model_manifest_schema", + "//snowflake/ml/model/_model_composer/model_method:utils", "//snowflake/ml/utils:html_utils", ], ) @@ -58,6 +59,7 @@ py_library( py_test( name = "model_version_impl_test", srcs = ["model_version_impl_test.py"], + data = ["sample_model_spec.yaml"], tags = ["feature:model_registry"], deps 
= [ ":batch_inference_specs", diff --git a/snowflake/ml/model/_client/model/model_version_impl.py b/snowflake/ml/model/_client/model/model_version_impl.py index 8044469b..fdbfbf66 100644 --- a/snowflake/ml/model/_client/model/model_version_impl.py +++ b/snowflake/ml/model/_client/model/model_version_impl.py @@ -19,7 +19,9 @@ from snowflake.ml.model._client.ops import metadata_ops, model_ops, service_ops from snowflake.ml.model._model_composer import model_composer from snowflake.ml.model._model_composer.model_manifest import model_manifest_schema +from snowflake.ml.model._model_composer.model_method import utils as model_method_utils from snowflake.ml.model._packager.model_handlers import snowmlmodel +from snowflake.ml.model._packager.model_meta import model_meta_schema from snowflake.snowpark import Session, async_job, dataframe _TELEMETRY_PROJECT = "MLOps" @@ -41,6 +43,7 @@ class ModelVersion(lineage_node.LineageNode): _model_name: sql_identifier.SqlIdentifier _version_name: sql_identifier.SqlIdentifier _functions: list[model_manifest_schema.ModelFunctionInfo] + _model_spec: Optional[model_meta_schema.ModelMetadataDict] def __init__(self) -> None: raise RuntimeError("ModelVersion's initializer is not meant to be used. Use `version` from model instead.") @@ -150,6 +153,7 @@ def _ref( self._model_name = model_name self._version_name = version_name self._functions = self._get_functions() + self._model_spec = None super(cls, cls).__init__( self, session=model_ops._session, @@ -437,6 +441,26 @@ def show_functions(self) -> list[model_manifest_schema.ModelFunctionInfo]: """ return self._functions + def _get_model_spec(self, statement_params: Optional[dict[str, Any]] = None) -> model_meta_schema.ModelMetadataDict: + """Fetch and cache the model spec for this model version. + + Args: + statement_params: Optional dictionary of statement parameters to include + in the SQL command to fetch the model spec. + + Returns: + The model spec as a dictionary for this model version. 
+ """ + if self._model_spec is None: + self._model_spec = self._model_ops._fetch_model_spec( + database_name=None, + schema_name=None, + model_name=self._model_name, + version_name=self._version_name, + statement_params=statement_params, + ) + return self._model_spec + @overload def run( self, @@ -531,6 +555,8 @@ def run( statement_params=statement_params, ) else: + explain_case_sensitive = self._determine_explain_case_sensitivity(target_function_info, statement_params) + return self._model_ops.invoke_method( method_name=sql_identifier.SqlIdentifier(target_function_info["name"]), method_function_type=target_function_info["target_method_function_type"], @@ -544,8 +570,20 @@ def run( partition_column=partition_column, statement_params=statement_params, is_partitioned=target_function_info["is_partitioned"], + explain_case_sensitive=explain_case_sensitive, ) + def _determine_explain_case_sensitivity( + self, + target_function_info: model_manifest_schema.ModelFunctionInfo, + statement_params: Optional[dict[str, Any]] = None, + ) -> bool: + model_spec = self._get_model_spec(statement_params) + method_options = model_spec.get("method_options", {}) + return model_method_utils.determine_explain_case_sensitive_from_method_options( + method_options, target_function_info["name"] + ) + @telemetry.send_api_usage_telemetry( project=_TELEMETRY_PROJECT, subproject=_TELEMETRY_SUBPROJECT, @@ -803,13 +841,7 @@ def _check_huggingface_text_generation_model( ValueError: If the model is not a HuggingFace text-generation model. 
""" # Fetch model spec - model_spec = self._model_ops._fetch_model_spec( - database_name=None, - schema_name=None, - model_name=self._model_name, - version_name=self._version_name, - statement_params=statement_params, - ) + model_spec = self._get_model_spec(statement_params) # Check if model_type is huggingface_pipeline model_type = model_spec.get("model_type") diff --git a/snowflake/ml/model/_client/model/model_version_impl_test.py b/snowflake/ml/model/_client/model/model_version_impl_test.py index 414f2298..2b41d9df 100644 --- a/snowflake/ml/model/_client/model/model_version_impl_test.py +++ b/snowflake/ml/model/_client/model/model_version_impl_test.py @@ -21,7 +21,7 @@ from snowflake.ml.model._model_composer.model_manifest import model_manifest_schema from snowflake.ml.test_utils import mock_data_frame, mock_session from snowflake.ml.test_utils.mock_progress import create_mock_progress_status -from snowflake.snowpark import Session +from snowflake.snowpark import Session, row _DUMMY_SIG = { "predict": model_signature.ModelSignature( @@ -286,6 +286,7 @@ def test_run(self) -> None: self.m_mv.run(m_df) with mock.patch.object(self.m_mv._model_ops, "invoke_method", return_value=m_df) as mock_invoke_method: + self._add_show_versions_mock() self.m_mv.run(m_df, function_name='"predict"') mock_invoke_method.assert_called_once_with( method_name='"predict"', @@ -300,6 +301,7 @@ def test_run(self) -> None: partition_column=None, statement_params=mock.ANY, is_partitioned=False, + explain_case_sensitive=False, ) with mock.patch.object(self.m_mv._model_ops, "invoke_method", return_value=m_df) as mock_invoke_method: @@ -317,6 +319,7 @@ def test_run(self) -> None: partition_column=None, statement_params=mock.ANY, is_partitioned=False, + explain_case_sensitive=False, ) def test_run_without_method_name(self) -> None: @@ -336,6 +339,7 @@ def test_run_without_method_name(self) -> None: self.m_mv._functions = m_methods with mock.patch.object(self.m_mv._model_ops, "invoke_method", 
return_value=m_df) as mock_invoke_method: + self._add_show_versions_mock() self.m_mv.run(m_df) mock_invoke_method.assert_called_once_with( method_name='"predict"', @@ -350,6 +354,7 @@ def test_run_without_method_name(self) -> None: partition_column=None, statement_params=mock.ANY, is_partitioned=False, + explain_case_sensitive=False, ) def test_run_strict(self) -> None: @@ -369,6 +374,7 @@ def test_run_strict(self) -> None: self.m_mv._functions = m_methods with mock.patch.object(self.m_mv._model_ops, "invoke_method", return_value=m_df) as mock_invoke_method: + self._add_show_versions_mock() self.m_mv.run(m_df, strict_input_validation=True) mock_invoke_method.assert_called_once_with( method_name='"predict"', @@ -383,6 +389,7 @@ def test_run_strict(self) -> None: partition_column=None, statement_params=mock.ANY, is_partitioned=False, + explain_case_sensitive=False, ) def test_run_table_function_method(self) -> None: @@ -410,6 +417,7 @@ def test_run_table_function_method(self) -> None: self.m_mv._functions = m_methods with mock.patch.object(self.m_mv._model_ops, "invoke_method", return_value=m_df) as mock_invoke_method: + self._add_show_versions_mock() self.m_mv.run(m_df, function_name='"predict_table"') mock_invoke_method.assert_called_once_with( method_name='"predict_table"', @@ -424,9 +432,11 @@ def test_run_table_function_method(self) -> None: partition_column=None, statement_params=mock.ANY, is_partitioned=True, + explain_case_sensitive=False, ) with mock.patch.object(self.m_mv._model_ops, "invoke_method", return_value=m_df) as mock_invoke_method: + self._add_show_versions_mock() self.m_mv.run(m_df, function_name='"predict_table"', partition_column="PARTITION_COLUMN") mock_invoke_method.assert_called_once_with( method_name='"predict_table"', @@ -441,6 +451,7 @@ def test_run_table_function_method(self) -> None: partition_column="PARTITION_COLUMN", statement_params=mock.ANY, is_partitioned=True, + explain_case_sensitive=False, ) def 
test_run_table_function_method_no_partition(self) -> None: @@ -467,7 +478,22 @@ def test_run_table_function_method_no_partition(self) -> None: ] self.m_mv._functions = m_methods - with mock.patch.object(self.m_mv._model_ops, "invoke_method", return_value=m_df) as mock_invoke_method: + with ( + mock.patch.object(self.m_mv._model_ops, "invoke_method", return_value=m_df) as mock_invoke_method, + mock.patch.object( + self.m_mv._model_ops, + "_fetch_model_spec", + return_value={ + "model_type": "huggingface_pipeline", + "models": { + "model1": { + "model_type": "huggingface_pipeline", + "options": {"task": "text-generation"}, + } + }, + }, + ), + ): self.m_mv.run(m_df, function_name='"explain_table"') mock_invoke_method.assert_called_once_with( method_name='"explain_table"', @@ -482,6 +508,7 @@ def test_run_table_function_method_no_partition(self) -> None: partition_column=None, statement_params=mock.ANY, is_partitioned=False, + explain_case_sensitive=False, ) def test_run_service(self) -> None: @@ -1468,6 +1495,9 @@ def test_check_huggingface_text_generation_model(self) -> None: statement_params=None, ) + # Reset the cached model spec + self.m_mv._model_spec = None + # Test failure case - not a HuggingFace model with mock.patch.object( self.m_mv._model_ops, @@ -1484,6 +1514,9 @@ def test_check_huggingface_text_generation_model(self) -> None: ) self.assertIn("Found model_type: sklearn", str(cm.exception)) + # Reset the cached model spec + self.m_mv._model_spec = None + # Test failure case - HuggingFace model but wrong task with mock.patch.object( self.m_mv._model_ops, @@ -1503,6 +1536,9 @@ def test_check_huggingface_text_generation_model(self) -> None: self.assertIn("Inference engine is only supported for task 'text-generation'", str(cm.exception)) self.assertIn("Found task(s): image-classification", str(cm.exception)) + # Reset the cached model spec + self.m_mv._model_spec = None + # Test failure case - HuggingFace model with no task with mock.patch.object( 
self.m_mv._model_ops, @@ -1758,6 +1794,28 @@ def test_run_batch_with_none_job_spec(self) -> None: statement_params=mock.ANY, ) + def _add_show_versions_mock(self) -> None: + current_dir = os.path.dirname(__file__) + data_file_path = os.path.join(current_dir, "sample_model_spec.yaml") + with open(data_file_path) as f: + model_spec = f.read() + model_attributes = """{ + "framework":"sklearn", + "task":"TABULAR_BINARY_CLASSIFICATION", + "client":"snowflake-ml-python 1.7.5"}""" + sql_result = [ + row.Row( + name='"v1"', + comment=None, + metadata={}, + model_spec=model_spec, + model_attributes=model_attributes, + ), + ] + self.m_session.add_mock_sql( + "SHOW VERSIONS LIKE 'v1' IN MODEL TEMP.\"test\".MODEL", result=mock_data_frame.MockDataFrame(sql_result) + ) + if __name__ == "__main__": absltest.main() diff --git a/snowflake/ml/model/_client/model/sample_model_spec.yaml b/snowflake/ml/model/_client/model/sample_model_spec.yaml new file mode 100644 index 00000000..ed17343d --- /dev/null +++ b/snowflake/ml/model/_client/model/sample_model_spec.yaml @@ -0,0 +1,144 @@ +--- +version: '2023-12-01' +method_options: + predict: + case_sensitive: true +min_snowpark_ml_version: 1.0.12 +creation_timestamp: '2025-07-24 18:31:40.087360' +env: + conda: env/conda.yml + cuda_version: + pip: env/requirements.txt + python_version: '3.9' + snowpark_ml_version: 1.7.5 +explainability: + algorithm: shap +model_type: sklearn +models: + MODEL: + artifacts: {} + function_properties: {} + handler_version: '2023-12-01' + model_type: sklearn + name: MODEL + options: {} + path: model.pkl +name: MODEL +function_properties: {} +metadata: +task: TABULAR_MULTI_CLASSIFICATION +runtimes: + cpu: + dependencies: + conda: runtimes/cpu/env/conda.yml + pip: runtimes/cpu/env/requirements.txt + imports: [] +signatures: + explain: + inputs: + - name: feature_1 + nullable: true + type: DOUBLE + - name: feature_2 + nullable: true + type: DOUBLE + - name: feature_3 + nullable: true + type: DOUBLE + - name: 
feature_4 + nullable: true + type: DOUBLE + - name: feature_5 + nullable: true + type: DOUBLE + outputs: + - name: feature_1_explanation + nullable: true + type: STRING + - name: feature_2_explanation + nullable: true + type: STRING + - name: feature_3_explanation + nullable: true + type: STRING + - name: feature_4_explanation + nullable: true + type: STRING + - name: feature_5_explanation + nullable: true + type: STRING + predict: + inputs: + - name: feature_1 + nullable: true + type: DOUBLE + - name: feature_2 + nullable: true + type: DOUBLE + - name: feature_3 + nullable: true + type: DOUBLE + - name: feature_4 + nullable: true + type: DOUBLE + - name: feature_5 + nullable: true + type: DOUBLE + outputs: + - name: output_feature_0 + nullable: false + type: INT64 + predict_log_proba: + inputs: + - name: feature_1 + nullable: true + type: DOUBLE + - name: feature_2 + nullable: true + type: DOUBLE + - name: feature_3 + nullable: true + type: DOUBLE + - name: feature_4 + nullable: true + type: DOUBLE + - name: feature_5 + nullable: true + type: DOUBLE + outputs: + - name: output_feature_0 + nullable: false + type: DOUBLE + - name: output_feature_1 + nullable: false + type: DOUBLE + - name: output_feature_2 + nullable: false + type: DOUBLE + predict_proba: + inputs: + - name: feature_1 + nullable: true + type: DOUBLE + - name: feature_2 + nullable: true + type: DOUBLE + - name: feature_3 + nullable: true + type: DOUBLE + - name: feature_4 + nullable: true + type: DOUBLE + - name: feature_5 + nullable: true + type: DOUBLE + outputs: + - name: output_feature_0 + nullable: false + type: DOUBLE + - name: output_feature_1 + nullable: false + type: DOUBLE + - name: output_feature_2 + nullable: false + type: DOUBLE diff --git a/snowflake/ml/model/_client/ops/model_ops.py b/snowflake/ml/model/_client/ops/model_ops.py index 67bc13cf..c32bf7ad 100644 --- a/snowflake/ml/model/_client/ops/model_ops.py +++ b/snowflake/ml/model/_client/ops/model_ops.py @@ -952,6 +952,7 @@ def 
invoke_method( partition_column: Optional[sql_identifier.SqlIdentifier] = None, statement_params: Optional[dict[str, str]] = None, is_partitioned: Optional[bool] = None, + explain_case_sensitive: bool = False, ) -> Union[type_hints.SupportedDataType, dataframe.DataFrame]: ... @@ -967,6 +968,7 @@ def invoke_method( service_name: sql_identifier.SqlIdentifier, strict_input_validation: bool = False, statement_params: Optional[dict[str, str]] = None, + explain_case_sensitive: bool = False, ) -> Union[type_hints.SupportedDataType, dataframe.DataFrame]: ... @@ -986,6 +988,7 @@ def invoke_method( partition_column: Optional[sql_identifier.SqlIdentifier] = None, statement_params: Optional[dict[str, str]] = None, is_partitioned: Optional[bool] = None, + explain_case_sensitive: bool = False, ) -> Union[type_hints.SupportedDataType, dataframe.DataFrame]: identifier_rule = model_signature.SnowparkIdentifierRule.INFERRED @@ -1068,6 +1071,7 @@ def invoke_method( version_name=version_name, statement_params=statement_params, is_partitioned=is_partitioned or False, + explain_case_sensitive=explain_case_sensitive, ) if keep_order: diff --git a/snowflake/ml/model/_client/ops/model_ops_test.py b/snowflake/ml/model/_client/ops/model_ops_test.py index 76bc1a4f..84fcbfa7 100644 --- a/snowflake/ml/model/_client/ops/model_ops_test.py +++ b/snowflake/ml/model/_client/ops/model_ops_test.py @@ -1601,6 +1601,7 @@ def test_invoke_method_table_function(self) -> None: version_name=sql_identifier.SqlIdentifier("V1"), statement_params=self.m_statement_params, is_partitioned=True, + explain_case_sensitive=False, ) mock_convert_from_df.assert_called_once_with( self.c_session, @@ -1621,6 +1622,7 @@ def test_invoke_method_table_function(self) -> None: version_name=sql_identifier.SqlIdentifier("V1"), statement_params=self.m_statement_params, is_partitioned=True, + explain_case_sensitive=False, ) mock_convert_to_df.assert_called_once_with( m_df, features=m_sig.outputs, 
statement_params=self.m_statement_params @@ -1654,6 +1656,7 @@ def test_invoke_method_table_function_partition_column(self) -> None: partition_column=partition_column, statement_params=self.m_statement_params, is_partitioned=True, + explain_case_sensitive=False, ) mock_convert_from_df.assert_called_once_with( self.c_session, @@ -1674,6 +1677,7 @@ def test_invoke_method_table_function_partition_column(self) -> None: version_name=sql_identifier.SqlIdentifier("V1"), statement_params=self.m_statement_params, is_partitioned=True, + explain_case_sensitive=False, ) mock_convert_to_df.assert_called_once_with( m_df, features=m_sig.outputs, statement_params=self.m_statement_params diff --git a/snowflake/ml/model/_client/sql/model_version.py b/snowflake/ml/model/_client/sql/model_version.py index d0687d3a..193ad5e1 100644 --- a/snowflake/ml/model/_client/sql/model_version.py +++ b/snowflake/ml/model/_client/sql/model_version.py @@ -438,6 +438,7 @@ def invoke_table_function_method( partition_column: Optional[sql_identifier.SqlIdentifier], statement_params: Optional[dict[str, Any]] = None, is_partitioned: bool = True, + explain_case_sensitive: bool = False, ) -> dataframe.DataFrame: with_statements = [] if len(input_df.queries["queries"]) == 1 and len(input_df.queries["post_actions"]) == 0: @@ -505,7 +506,8 @@ def invoke_table_function_method( cols_to_drop = [] for output_name, output_type, output_col_name in returns: - output_identifier = sql_identifier.SqlIdentifier(output_name).identifier() + case_sensitive = "explain" in method_name.resolved().lower() and explain_case_sensitive + output_identifier = sql_identifier.SqlIdentifier(output_name, case_sensitive=case_sensitive).identifier() if output_identifier != output_col_name: cols_to_drop.append(output_identifier) output_cols.append(F.col(output_identifier).astype(output_type)) diff --git a/snowflake/ml/model/_model_composer/model_method/BUILD.bazel b/snowflake/ml/model/_model_composer/model_method/BUILD.bazel index 
924cd9ea..e1db7dd8 100644 --- a/snowflake/ml/model/_model_composer/model_method/BUILD.bazel +++ b/snowflake/ml/model/_model_composer/model_method/BUILD.bazel @@ -21,6 +21,11 @@ py_library( ], ) +py_library( + name = "utils", + srcs = ["utils.py"], +) + py_library( name = "function_generator", srcs = ["function_generator.py"], @@ -53,6 +58,7 @@ py_library( deps = [ ":constants", ":function_generator", + ":utils", "//snowflake/ml/_internal:platform_capabilities", "//snowflake/ml/_internal/utils:sql_identifier", "//snowflake/ml/model:model_signature", diff --git a/snowflake/ml/model/_model_composer/model_method/model_method.py b/snowflake/ml/model/_model_composer/model_method/model_method.py index 2ae8cbc1..7051385b 100644 --- a/snowflake/ml/model/_model_composer/model_method/model_method.py +++ b/snowflake/ml/model/_model_composer/model_method/model_method.py @@ -11,6 +11,7 @@ from snowflake.ml.model._model_composer.model_method import ( constants, function_generator, + utils, ) from snowflake.ml.model._packager.model_meta import model_meta as model_meta_api from snowflake.ml.model.volatility import Volatility @@ -34,9 +35,13 @@ def get_model_method_options_from_options( options: type_hints.ModelSaveOption, target_method: str ) -> ModelMethodOptions: default_function_type = model_manifest_schema.ModelMethodFunctionTypes.FUNCTION.value + method_option = options.get("method_options", {}).get(target_method, {}) + case_sensitive = method_option.get("case_sensitive", False) if target_method == "explain": default_function_type = model_manifest_schema.ModelMethodFunctionTypes.TABLE_FUNCTION.value - method_option = options.get("method_options", {}).get(target_method, {}) + case_sensitive = utils.determine_explain_case_sensitive_from_method_options( + options.get("method_options", {}), target_method + ) global_function_type = options.get("function_type", default_function_type) function_type = method_option.get("function_type", global_function_type) if function_type not in 
[function_type.value for function_type in model_manifest_schema.ModelMethodFunctionTypes]: @@ -48,7 +53,7 @@ def get_model_method_options_from_options( # Only include volatility if explicitly provided in method options result: ModelMethodOptions = ModelMethodOptions( - case_sensitive=method_option.get("case_sensitive", False), + case_sensitive=case_sensitive, function_type=function_type, ) if resolved_volatility: diff --git a/snowflake/ml/model/_model_composer/model_method/utils.py b/snowflake/ml/model/_model_composer/model_method/utils.py new file mode 100644 index 00000000..46d3c8a7 --- /dev/null +++ b/snowflake/ml/model/_model_composer/model_method/utils.py @@ -0,0 +1,28 @@ +from __future__ import annotations + +from typing import Any, Mapping, Optional + + +def determine_explain_case_sensitive_from_method_options( + method_options: Mapping[str, Optional[Mapping[str, Any]]], + target_method: str, +) -> bool: + """Determine explain method case sensitivity from related predict methods. + + Args: + method_options: Mapping from method name to its options. Each option may + contain ``"case_sensitive"`` to indicate SQL identifier sensitivity. + target_method: The target method name being resolved (e.g., an ``explain_*`` + method). + + Returns: + True if the explain method should be treated as case sensitive; otherwise False. 
+ """ + if "explain" not in target_method: + return False + predict_priority_methods = ["predict_proba", "predict", "predict_log_proba"] + for src_method in predict_priority_methods: + src_opts = method_options.get(src_method) + if src_opts is not None: + return bool(src_opts.get("case_sensitive", False)) + return False diff --git a/snowflake/ml/model/_packager/model_env/model_env.py b/snowflake/ml/model/_packager/model_env/model_env.py index b99551cf..9f11dd50 100644 --- a/snowflake/ml/model/_packager/model_env/model_env.py +++ b/snowflake/ml/model/_packager/model_env/model_env.py @@ -240,14 +240,31 @@ def remove_if_present_conda(self, conda_pkgs: list[str]) -> None: self._conda_dependencies[channel].remove(spec) def generate_env_for_cuda(self) -> None: + + # Insert py-xgboost-gpu only for XGBoost versions < 3.0.0 xgboost_spec = env_utils.find_dep_spec( - self._conda_dependencies, self._pip_requirements, conda_pkg_name="xgboost", remove_spec=True + self._conda_dependencies, self._pip_requirements, conda_pkg_name="xgboost", remove_spec=False ) if xgboost_spec: - self.include_if_absent( - [ModelDependency(requirement=f"py-xgboost-gpu{xgboost_spec.specifier}", pip_name="xgboost")], - check_local_version=False, - ) + # Only handle explicitly pinned versions. Insert GPU variant iff pinned major < 3. 
+ pinned_major: Optional[int] = None + for spec in xgboost_spec.specifier: + if spec.operator in ("==", "===", ">", ">="): + try: + pinned_major = version.parse(spec.version).major + except version.InvalidVersion: + pinned_major = None + break + + if pinned_major is not None and pinned_major < 3: + xgboost_spec = env_utils.find_dep_spec( + self._conda_dependencies, self._pip_requirements, conda_pkg_name="xgboost", remove_spec=True + ) + if xgboost_spec: + self.include_if_absent( + [ModelDependency(requirement=f"py-xgboost-gpu{xgboost_spec.specifier}", pip_name="xgboost")], + check_local_version=False, + ) tf_spec = env_utils.find_dep_spec( self._conda_dependencies, self._pip_requirements, conda_pkg_name="tensorflow", remove_spec=True diff --git a/snowflake/ml/model/_packager/model_meta/model_meta.py b/snowflake/ml/model/_packager/model_meta/model_meta.py index 4633bb7e..ed74aedc 100644 --- a/snowflake/ml/model/_packager/model_meta/model_meta.py +++ b/snowflake/ml/model/_packager/model_meta/model_meta.py @@ -116,6 +116,8 @@ def create_model_metadata( if embed_local_ml_library: env.snowpark_ml_version = f"{snowml_version.VERSION}+{file_utils.hash_directory(path_to_copy)}" + # Persist full method_options + method_options: dict[str, dict[str, Any]] = kwargs.pop("method_options", {}) model_meta = ModelMetadata( name=name, env=env, @@ -124,6 +126,7 @@ def create_model_metadata( signatures=signatures, function_properties=function_properties, task=task, + method_options=method_options, ) code_dir_path = os.path.join(model_dir_path, MODEL_CODE_DIR) @@ -256,6 +259,7 @@ def __init__( original_metadata_version: Optional[str] = model_meta_schema.MODEL_METADATA_VERSION, task: model_types.Task = model_types.Task.UNKNOWN, explain_algorithm: Optional[model_meta_schema.ModelExplainAlgorithm] = None, + method_options: Optional[dict[str, dict[str, Any]]] = None, ) -> None: self.name = name self.signatures: dict[str, model_signature.ModelSignature] = dict() @@ -283,6 +287,7 @@ def 
__init__( self.task: model_types.Task = task self.explain_algorithm: Optional[model_meta_schema.ModelExplainAlgorithm] = explain_algorithm + self.method_options: dict[str, dict[str, Any]] = method_options or {} @property def min_snowpark_ml_version(self) -> str: @@ -342,6 +347,7 @@ def save(self, model_dir_path: str) -> None: else None ), "function_properties": self.function_properties, + "method_options": self.method_options, } ) with open(model_yaml_path, "w", encoding="utf-8") as out: @@ -381,6 +387,7 @@ def _validate_model_metadata(loaded_meta: Any) -> model_meta_schema.ModelMetadat task=loaded_meta.get("task", model_types.Task.UNKNOWN.value), explainability=loaded_meta.get("explainability", None), function_properties=loaded_meta.get("function_properties", {}), + method_options=loaded_meta.get("method_options", {}), ) @classmethod @@ -436,4 +443,5 @@ def load(cls, model_dir_path: str) -> "ModelMetadata": task=model_types.Task(model_dict.get("task", model_types.Task.UNKNOWN.value)), explain_algorithm=explanation_algorithm, function_properties=model_dict.get("function_properties", {}), + method_options=model_dict.get("method_options", {}), ) diff --git a/snowflake/ml/model/_packager/model_meta/model_meta_schema.py b/snowflake/ml/model/_packager/model_meta/model_meta_schema.py index ebd3b23a..7dabaaa1 100644 --- a/snowflake/ml/model/_packager/model_meta/model_meta_schema.py +++ b/snowflake/ml/model/_packager/model_meta/model_meta_schema.py @@ -125,6 +125,7 @@ class ModelMetadataDict(TypedDict): task: Required[str] explainability: NotRequired[Optional[ExplainabilityMetadataDict]] function_properties: NotRequired[dict[str, dict[str, Any]]] + method_options: NotRequired[dict[str, dict[str, Any]]] class ModelExplainAlgorithm(Enum): diff --git a/snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py b/snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py index 85cec2f8..cf476957 100755 --- 
a/snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +++ b/snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py @@ -30,5 +30,5 @@ "sqlparse>=0.4,<1", "tqdm<5", "typing-extensions>=4.1.0,<5", - "xgboost>=1.7.3,<3", + "xgboost<4", ] diff --git a/snowflake/ml/model/_packager/model_task/model_task_utils_test.py b/snowflake/ml/model/_packager/model_task/model_task_utils_test.py index 56f96beb..7689c300 100644 --- a/snowflake/ml/model/_packager/model_task/model_task_utils_test.py +++ b/snowflake/ml/model/_packager/model_task/model_task_utils_test.py @@ -55,19 +55,19 @@ def test_model_task_and_output_xgb_binary_classifier(self) -> None: def test_model_task_and_output_xgb_for_single_class(self) -> None: # without objective - classifier = xgboost.XGBClassifier() + classifier = xgboost.XGBClassifier(base_score=0.5) classifier.fit(binary_data_X, single_class_y) self._validate_model_task_and_output( classifier, type_hints.Task.TABULAR_BINARY_CLASSIFICATION, model_signature.DataType.DOUBLE ) # with binary objective - classifier = xgboost.XGBClassifier(objective="binary:logistic") + classifier = xgboost.XGBClassifier(objective="binary:logistic", base_score=0.5) classifier.fit(binary_data_X, single_class_y) self._validate_model_task_and_output( classifier, type_hints.Task.TABULAR_BINARY_CLASSIFICATION, model_signature.DataType.DOUBLE ) # with multiclass objective - params = {"objective": "multi:softmax", "num_class": 3} + params = {"objective": "multi:softmax", "num_class": 3, "base_score": 0.5} classifier = xgboost.XGBClassifier(**params) classifier.fit(binary_data_X, single_class_y) self._validate_model_task_and_output( diff --git a/snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py b/snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py index 86374337..069d7a48 100644 --- 
a/snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +++ b/snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py @@ -93,7 +93,7 @@ def __init__( cache_dir_name = tempfile.mkdtemp() super().__init__(cache_prefix=os.path.join(cache_dir_name, "cache")) - def next(self, batch_consumer_fn) -> int: # type: ignore[no-untyped-def] + def next(self, batch_consumer_fn) -> bool | int: # type: ignore[no-untyped-def] """Advance the iterator by 1 step and pass the data to XGBoost's batch_consumer_fn. This function is called by XGBoost during the construction of ``DMatrix`` @@ -101,7 +101,7 @@ def next(self, batch_consumer_fn) -> int: # type: ignore[no-untyped-def] batch_consumer_fn: batch consumer function Returns: - 0 if there is no more data, else 1. + False/0 if there is no more data, else True/1. """ while (self._df is None) or (self._df.shape[0] < self._batch_size): # Read files and append data to temp df until batch size is reached. @@ -117,7 +117,7 @@ def next(self, batch_consumer_fn) -> int: # type: ignore[no-untyped-def] if (self._df is None) or (self._df.shape[0] == 0): # No more data - return 0 + return False # Slice the temp df and save the remainder in the temp df batch_end_index = min(self._batch_size, self._df.shape[0]) @@ -133,8 +133,8 @@ def next(self, batch_consumer_fn) -> int: # type: ignore[no-untyped-def] func_args["weight"] = batch_df[self._sample_weight_col].squeeze() batch_consumer_fn(**func_args) - # Return 1 to let XGBoost know we haven't seen all the files yet. - return 1 + # Return True to let XGBoost know we haven't seen all the files yet. + return True def reset(self) -> None: """Reset the iterator to its beginning""" diff --git a/snowflake/ml/version.py b/snowflake/ml/version.py index 11936720..d236121c 100644 --- a/snowflake/ml/version.py +++ b/snowflake/ml/version.py @@ -1,2 +1,2 @@ # This is parsed by regex in conda recipe meta file. Make sure not to break it. 
-VERSION = "1.16.0" +VERSION = "1.17.0" diff --git a/tests/integ/snowflake/ml/experiment/experiment_tracking_integ_test.py b/tests/integ/snowflake/ml/experiment/experiment_tracking_integ_test.py index 2114025d..af2aa3be 100644 --- a/tests/integ/snowflake/ml/experiment/experiment_tracking_integ_test.py +++ b/tests/integ/snowflake/ml/experiment/experiment_tracking_integ_test.py @@ -111,7 +111,8 @@ def test_experiment_getstate_and_setstate_no_session(self) -> None: saved_schema = exp._session.get_current_schema() pickled = pickle.dumps(exp) - session = snowpark_session._active_sessions.pop() # Simulate having no active session + session_set = snowpark_session._active_sessions.copy() + snowpark_session._active_sessions.clear() # Simulate having no active session new_exp = pickle.loads(pickled) # Check that the session is None and the session state is populated correctly @@ -122,7 +123,7 @@ def test_experiment_getstate_and_setstate_no_session(self) -> None: self.assertEqual(new_exp._session_state.schema, saved_schema) # Restore the session and check for equality - snowpark_session._active_sessions.add(session) + snowpark_session._active_sessions.update(session_set) new_exp.set_experiment("TEST_EXPERIMENT") # set_experiment is decorated with @_restore_session self.assertIsNone(new_exp._session_state) self.assert_experiment_tracking_equality(exp, new_exp) diff --git a/tests/integ/snowflake/ml/jobs/classical_models_integ_test.py b/tests/integ/snowflake/ml/jobs/classical_models_integ_test.py index b18e7d9f..a4f380a2 100644 --- a/tests/integ/snowflake/ml/jobs/classical_models_integ_test.py +++ b/tests/integ/snowflake/ml/jobs/classical_models_integ_test.py @@ -2,14 +2,14 @@ from packaging import version from snowflake.ml._internal import env -from tests.integ.snowflake.ml.jobs.job_test_base import JobTestBase +from tests.integ.snowflake.ml.jobs.job_test_base import ModelingJobTestBase """ this integration test is only for classic models, like XGBoost and lightgbm. 
""" -class ClassicalModelTest(JobTestBase): +class ClassicalModelTest(ModelingJobTestBase): @parameterized.parameters("xgboost", "lightgbm", "sklearn") @absltest.skipIf( version.Version(env.PYTHON_VERSION) >= version.Version("3.11"), diff --git a/tests/integ/snowflake/ml/jobs/job_test_base.py b/tests/integ/snowflake/ml/jobs/job_test_base.py index ff3030e2..3bd3d69b 100644 --- a/tests/integ/snowflake/ml/jobs/job_test_base.py +++ b/tests/integ/snowflake/ml/jobs/job_test_base.py @@ -71,6 +71,8 @@ def _submit_func_as_file(self, func: Callable[[], None], **kwargs: Any) -> jobs. ) return job + +class ModelingJobTestBase(JobTestBase): def get_inference(self, model: Any, module_path: str) -> Any: return reflection_utils.run_reflected_func(module_path, _PREDICT_FUNC, model) diff --git a/tests/integ/snowflake/ml/jobs/jobs_integ_test.py b/tests/integ/snowflake/ml/jobs/jobs_integ_test.py index 52aac74b..0fc08bb0 100644 --- a/tests/integ/snowflake/ml/jobs/jobs_integ_test.py +++ b/tests/integ/snowflake/ml/jobs/jobs_integ_test.py @@ -10,6 +10,8 @@ from uuid import uuid4 import cloudpickle as cp +import numpy as np +import pandas as pd from absl.testing import absltest, parameterized from packaging import version @@ -18,6 +20,7 @@ from snowflake.ml._internal import env from snowflake.ml._internal.utils import identifier from snowflake.ml.jobs import job as jd +from snowflake.ml.jobs._interop import results as interop_result from snowflake.ml.jobs._utils import ( constants, feature_flags, @@ -149,23 +152,27 @@ def test_get_job_negative(self) -> None: ] for id in nonexistent_job_ids: with self.subTest(f"id={id}"): - with self.assertRaisesRegex(ValueError, "does not exist"): + with self.assertRaises(sp_exceptions.SnowparkSQLException) as cm: jobs.get_job(id, session=self.session) + self.assertRegex(str(cm.exception), "does not exist") + self.assertEqual(cm.exception.sql_error_code, 2003) def test_delete_job_negative(self) -> None: nonexistent_job_ids = [ 
f"{self.db}.non_existent_schema.nonexistent_job_id", f"{self.db}.{self.schema}.nonexistent_job_id", "nonexistent_job_id", - *INVALID_IDENTIFIERS, ] for id in nonexistent_job_ids: with self.subTest(f"id={id}"): job = jobs.MLJob[None](id, session=self.session) - with self.assertRaises(ValueError, msg=f"id={id}"): - jobs.delete_job(job.id, session=self.session) with self.assertRaises(sp_exceptions.SnowparkSQLException, msg=f"id={id}"): jobs.delete_job(job, session=self.session) + for id in INVALID_IDENTIFIERS: + with self.subTest(f"id={id}"): + job = jobs.MLJob[None](id, session=self.session) + with self.assertRaises(ValueError, msg=f"id={id}"): + jobs.delete_job(job.id, session=self.session) def test_get_status_negative(self) -> None: nonexistent_job_ids = [ @@ -177,7 +184,7 @@ def test_get_status_negative(self) -> None: for id in nonexistent_job_ids: with self.subTest(f"id={id}"): job = jobs.MLJob[None](id, session=self.session) - with self.assertRaises(sp_exceptions.SnowparkSQLException, msg=f"id={id}"): + with self.assertRaises((sp_exceptions.SnowparkSQLException, ValueError), msg=f"id={id}"): job.status def test_get_logs(self) -> None: @@ -232,30 +239,40 @@ def test_job_wait(self) -> None: finally: constants.JOB_POLL_MAX_DELAY_SECONDS = max_backoff - def test_job_execution(self) -> None: - payload = TestAsset("src/main.py") + @parameterized.product( # type: ignore[misc] + USE_SUBMIT_JOB_V2=[True, False], + ENABLE_RUNTIME_VERSIONS=[True, False], + ) + def test_job_execution(self, **feature_flag_kwargs: bool) -> None: + env_vars = {} + for param_name, param_value in feature_flag_kwargs.items(): + flag = getattr(feature_flags.FeatureFlags, param_name.upper()) + env_vars[flag.value] = str(param_value).lower() - # Create a job - job = jobs.submit_file( - payload.path, - self.compute_pool, - stage_name="payload_stage", - args=["foo", "--delay", "1"], - session=self.session, - ) + with mock.patch.dict(os.environ, env_vars): + payload = TestAsset("src/main.py") - # Wait 
for job to finish - self.assertEqual(job.wait(), "DONE", job.get_logs()) - self.assertEqual(job.status, "DONE") - self.assertIn("Job complete", job.get_logs()) - self.assertIsNone(job.result()) + # Create a job + job = jobs.submit_file( + payload.path, + self.compute_pool, + stage_name="payload_stage", + args=["foo", "--delay", "1"], + session=self.session, + ) - # Test loading job by ID - loaded_job = jobs.get_job(job.id, session=self.session) - self.assertEqual(loaded_job.status, "DONE") - self.assertIn("Job start", loaded_job.get_logs()) - self.assertIn("Job complete", loaded_job.get_logs()) - self.assertIsNone(loaded_job.result()) + # Wait for job to finish + self.assertEqual(job.wait(), "DONE", job.get_logs()) + self.assertEqual(job.status, "DONE") + self.assertIn("Job complete", job.get_logs()) + self.assertIsNone(job.result()) + + # Test loading job by ID + loaded_job = jobs.get_job(job.id, session=self.session) + self.assertEqual(loaded_job.status, "DONE") + self.assertIn("Job start", loaded_job.get_logs()) + self.assertIn("Job complete", loaded_job.get_logs()) + self.assertIsNone(loaded_job.result()) def test_job_execution_metrics(self) -> None: payload = TestAsset("src/main.py") @@ -325,7 +342,13 @@ def test_job_pickling(self) -> None: """Dedicated test for MLJob pickling and unpickling functionality.""" payload = TestAsset("src/main.py") - @jobs.remote(self.compute_pool, stage_name="payload_stage", session=self.session) + @jobs.remote( + self.compute_pool, + stage_name="payload_stage", + # ML Job pickling only guaranteed for matching SnowML version + imports=[(os.path.dirname(jobs.__file__), "snowflake.ml.jobs")], + session=self.session, + ) def check_job_status(job: jobs.MLJob[Any]) -> str: return job.status @@ -346,7 +369,7 @@ def check_job_status(job: jobs.MLJob[Any]) -> str: original_logs = job.get_logs() original_result = job.result() original_target_instances = job.target_instances - + cp.register_pickle_by_value(types) pickled_data = cp.dumps(job) 
self.assertIsInstance(pickled_data, bytes) self.assertGreater(len(pickled_data), 0) @@ -391,6 +414,126 @@ def check_job_status(job: jobs.MLJob[Any]) -> str: with self.assertRaisesRegex(RuntimeError, "No active Snowpark session available"): cp.loads(pickled_data) + def test_job_result_v2(self) -> None: + """Test that v2 job results can be saved and loaded correctly.""" + + def func_with_return_value() -> None: + return {"key": "value", "number": 123} + + job = self._submit_func_as_file( + func_with_return_value, + imports=[(os.path.dirname(jobs.__file__), "snowflake.ml.jobs")], + ) + self.assertEqual(job.wait(), "DONE", job.get_logs(verbose=True)) + + # Validate result + result: Any = job.result() + self.assertIsInstance(result, dict) + self.assertEqual(result.get("key"), "value") + self.assertEqual(result.get("number"), 123) + + # Ensure the result was saved with v2 (v1 results get loaded as normal ExecutionResults) + self.assertIsInstance(job._result, interop_result.LoadedExecutionResult) + + def test_job_result_v2_pandas(self) -> None: + """Test that v2 job results can save and load Pandas DataFrames correctly.""" + + def func_with_dataframe_return() -> None: + import numpy as np + import pandas as pd + + # Create a DataFrame with random data but fixed seed for reproducibility + np.random.seed(42) + data = { + "col1": np.random.randn(10), + "col2": np.random.randint(1, 100, 10), + "col3": ["value_" + str(i) for i in range(10)], + } + return pd.DataFrame(data) + + job = self._submit_func_as_file( + func_with_dataframe_return, + imports=[(os.path.dirname(jobs.__file__), "snowflake.ml.jobs")], + ) + self.assertEqual(job.wait(), "DONE", job.get_logs(verbose=True)) + + # Validate result + result: Any = job.result() + self.assertIsInstance(result, pd.DataFrame) + self.assertEqual(result.shape, (10, 3)) + + # Validate actual values by recreating the expected DataFrame + np.random.seed(42) + expected_data = { + "col1": np.random.randn(10), + "col2": np.random.randint(1, 
100, 10), + "col3": ["value_" + str(i) for i in range(10)], + } + expected_df = pd.DataFrame(expected_data) + pd.testing.assert_frame_equal(result, expected_df) + + # Ensure the result was saved with v2 (v1 results get loaded as normal ExecutionResults) + self.assertIsInstance(job._result, interop_result.LoadedExecutionResult) + + def test_job_result_v2_numpy(self) -> None: + """Test that v2 job results can save and load NumPy arrays correctly.""" + + def func_with_numpy_return() -> None: + import numpy as np + + # Create a NumPy array with random data but fixed seed for reproducibility + np.random.seed(42) + return np.random.randn(5, 3) + + job = self._submit_func_as_file( + func_with_numpy_return, + imports=[(os.path.dirname(jobs.__file__), "snowflake.ml.jobs")], + ) + self.assertEqual(job.wait(), "DONE", job.get_logs(verbose=True)) + + # Validate result + result: Any = job.result() + self.assertIsInstance(result, np.ndarray) + self.assertEqual(result.shape, (5, 3)) + + # Validate actual values by recreating the expected array + np.random.seed(42) + expected_array = np.random.randn(5, 3) + np.testing.assert_array_equal(result, expected_array) + + # Ensure the result was saved with v2 (v1 results get loaded as normal ExecutionResults) + self.assertIsInstance(job._result, interop_result.LoadedExecutionResult) + + def test_job_result_backcompat(self) -> None: + """Test that v1 job results can still be loaded correctly.""" + + def func_with_return_value() -> None: + return {"key": "value", "number": 123} + + job = self._submit_func_as_file( + func_with_return_value, + spec_overrides={ + "spec": { + "containers": [ + { + "name": constants.DEFAULT_CONTAINER_NAME, + "image": f"{constants.DEFAULT_IMAGE_REPO}/{constants.DEFAULT_IMAGE_CPU}:1.7.1", + } + ] + } + }, + ) + self.assertEqual(job.wait(), "DONE", job.get_logs(verbose=True)) + + # Validate result + result: Any = job.result() + self.assertIsInstance(result, dict) + self.assertEqual(result.get("key"), "value") + 
self.assertEqual(result.get("number"), 123) + + # Ensure the result was saved with v1 (v2 results get loaded as LoadedExecutionResults) + self.assertEqual(type(job._result), interop_result.ExecutionResult) + # TODO(SNOW-1911482): Enable test for Python 3.11+ @absltest.skipIf( version.Version(env.PYTHON_VERSION) >= version.Version("3.11"), @@ -457,9 +600,11 @@ def __str__(self) -> str: version.Version(env.PYTHON_VERSION) >= version.Version("3.11"), "Decorator test only works for Python 3.10 and below due to pickle compatibility", ) # type: ignore[misc] - def test_job_execution_in_stored_procedure(self) -> None: - jobs_import_src = os.path.dirname(jobs.__file__) - + @parameterized.parameters( # type: ignore[misc] + "owner", + "caller", + ) + def test_job_execution_in_stored_procedure(self, sproc_rights: str) -> None: @jobs.remote(self.compute_pool, stage_name="payload_stage") def job_fn() -> None: print("Hello from remote function!") @@ -467,9 +612,8 @@ def job_fn() -> None: @F.sproc( session=self.session, packages=["snowflake-snowpark-python", "snowflake-ml-python"], - imports=[ - (jobs_import_src, "snowflake.ml.jobs"), - ], + imports=[(os.path.dirname(jobs.__file__), "snowflake.ml.jobs")], + execute_as=sproc_rights, ) def job_sproc(session: snowpark.Session) -> None: job = job_fn() @@ -736,8 +880,10 @@ def test_submit_job_fully_qualified_name(self): self.assertEqual(job_schema, identifier.resolve_identifier(schema or self.schema)) if schema == temp_schema: - with self.assertRaisesRegex(ValueError, "does not exist"): + with self.assertRaisesRegex(sp_exceptions.SnowparkSQLException, "does not exist") as cm: jobs.get_job(job_name, session=self.session) + self.assertRegex(str(cm.exception), "does not exist") + self.assertEqual(cm.exception.sql_error_code, 2003) else: self.assertIsNotNone(jobs.get_job(job_name, session=self.session)) finally: @@ -1080,7 +1226,7 @@ def test_job_with_different_python_version(self) -> None: except Exception: expected_runtime_image = None - 
with mock.patch.dict(os.environ, {feature_flags.FeatureFlags.ENABLE_IMAGE_VERSION_ENV_VAR.value: "true"}): + with mock.patch.dict(os.environ, {feature_flags.FeatureFlags.ENABLE_RUNTIME_VERSIONS.value: "true"}): job = jobs.submit_file( TestAsset("src/check_python.py").path, self.compute_pool, @@ -1155,24 +1301,25 @@ def test_modin_function() -> None: ("/snowflake/images/snowflake_images/st_plat/runtime/x86/runtime_image/snowbooks:1.7.1"), ) def test_job_with_runtime_environment(self, runtime_environment: str) -> None: - job_v1 = self._submit_func_as_file(dummy_function, runtime_environment=runtime_environment) - self.assertEqual(job_v1.wait(), "DONE", job_v1.get_logs()) - self.assertIn(runtime_environment, job_v1._container_spec["image"]) + def check_runtime_version() -> None: + from snowflake.runtime._version import __version__ as mlrs_version - rows = self.session.sql("SHOW PARAMETERS LIKE 'ENABLE_EXECUTE_ML_JOB_FUNCTION'").collect() - if not rows or rows[0]["value"] == "false": - self.skipTest("ENABLE_EXECUTE_ML_JOB_FUNCTION is disabled.") + print(mlrs_version) - try: - self.session.sql("ALTER SESSION SET ENABLE_EXECUTE_ML_JOB_FUNCTION = TRUE").collect() - with mock.patch.dict(os.environ, {feature_flags.FeatureFlags.USE_SUBMIT_JOB_V2.value: "true"}): - job_v2 = self._submit_func_as_file(dummy_function, runtime_environment=runtime_environment) - self.assertEqual(job_v2.wait(), "DONE", job_v2.get_logs()) - self.assertIn(runtime_environment, job_v2._container_spec["image"]) - except sp_exceptions.SnowparkSQLException: - self.skipTest("Unable to enable required session parameters for runtime_environment. 
Skipping test.") - finally: - self.session.sql("ALTER SESSION SET ENABLE_EXECUTE_ML_JOB_FUNCTION = FALSE").collect() + job = self._submit_func_as_file(check_runtime_version, runtime_environment=runtime_environment) + self.assertEqual(job.wait(), "DONE", job.get_logs()) + self.assertIn(runtime_environment, job._container_spec["image"]) + self.assertEqual("1.7.1", job.get_logs()) + + def test_get_job_after_job_deleted(self) -> None: + job = self._submit_func_as_file(dummy_function) + job.wait() + jobs.delete_job(job.id, session=self.session) + loaded_job = jobs.get_job(job.id, session=self.session) + self.assertIsNotNone(loaded_job.status) + self.assertIsNotNone(loaded_job.get_logs()) + self.assertEqual(loaded_job.target_instances, 1) + self.assertEqual(loaded_job._compute_pool, self.compute_pool) if __name__ == "__main__": diff --git a/tests/integ/snowflake/ml/jobs/lightgbm_distributed_integ_test.py b/tests/integ/snowflake/ml/jobs/lightgbm_distributed_integ_test.py index d3d28fdf..c493d2d6 100644 --- a/tests/integ/snowflake/ml/jobs/lightgbm_distributed_integ_test.py +++ b/tests/integ/snowflake/ml/jobs/lightgbm_distributed_integ_test.py @@ -7,7 +7,7 @@ from snowflake.ml.jobs import remote from snowflake.snowpark import DataFrame from snowflake.snowpark.context import get_active_session -from tests.integ.snowflake.ml.jobs.job_test_base import JobTestBase +from tests.integ.snowflake.ml.jobs.job_test_base import ModelingJobTestBase TEST_TABLE_NAME = "MULTINODE_CPU_LIGHTGBM_TRAIN_DS" @@ -29,7 +29,7 @@ def split_dataset(snowpark_df: DataFrame) -> tuple[DataFrame, DataFrame, str, li return train_df, test_df, label_col, feature_cols -class LightGBMDistributedTest(JobTestBase): +class LightGBMDistributedTest(ModelingJobTestBase): @classmethod def setUpClass(cls) -> None: super().setUpClass() diff --git a/tests/integ/snowflake/ml/jobs/multi_node_jobs_integ_test.py b/tests/integ/snowflake/ml/jobs/multi_node_jobs_integ_test.py index b4e5c423..346cabfb 100644 --- 
a/tests/integ/snowflake/ml/jobs/multi_node_jobs_integ_test.py +++ b/tests/integ/snowflake/ml/jobs/multi_node_jobs_integ_test.py @@ -54,7 +54,7 @@ def compute_heavy(n): return socket.gethostname() ray.init(address="auto", ignore_reinit_error=True) - hosts = [compute_heavy.remote(10_000) for _ in range(10)] + hosts = [compute_heavy.remote(20) for _ in range(10)] unique_hosts = set(ray.get(hosts)) assert ( len(unique_hosts) >= 2 diff --git a/tests/integ/snowflake/ml/jobs/pytorch_integ_test.py b/tests/integ/snowflake/ml/jobs/pytorch_integ_test.py index 2a09fd35..d3547669 100644 --- a/tests/integ/snowflake/ml/jobs/pytorch_integ_test.py +++ b/tests/integ/snowflake/ml/jobs/pytorch_integ_test.py @@ -2,14 +2,14 @@ from packaging import version from snowflake.ml._internal import env -from tests.integ.snowflake.ml.jobs.job_test_base import JobTestBase +from tests.integ.snowflake.ml.jobs.job_test_base import ModelingJobTestBase """ this integration test is only for pytorch. """ -class PytorchModelTest(JobTestBase): +class PytorchModelTest(ModelingJobTestBase): @absltest.skipIf( version.Version(env.PYTHON_VERSION) >= version.Version("3.11"), "only works for Python 3.10 and below due to pickle compatibility", diff --git a/tests/integ/snowflake/ml/jobs/tensorflow_integ_test.py b/tests/integ/snowflake/ml/jobs/tensorflow_integ_test.py index 904befcb..9a98c13c 100644 --- a/tests/integ/snowflake/ml/jobs/tensorflow_integ_test.py +++ b/tests/integ/snowflake/ml/jobs/tensorflow_integ_test.py @@ -2,14 +2,14 @@ from packaging import version from snowflake.ml._internal import env -from tests.integ.snowflake.ml.jobs.job_test_base import JobTestBase +from tests.integ.snowflake.ml.jobs.job_test_base import ModelingJobTestBase """ this integration test is only for tensorflow. 
""" -class TensorflowModelTest(JobTestBase): +class TensorflowModelTest(ModelingJobTestBase): @absltest.skipIf( version.Version(env.PYTHON_VERSION) >= version.Version("3.11"), "only works for Python 3.10 and below due to pickle compatibility", diff --git a/tests/integ/snowflake/ml/registry/model/registry_sklearn_model_test.py b/tests/integ/snowflake/ml/registry/model/registry_sklearn_model_test.py index 7ccd7c6a..fbdb3d57 100644 --- a/tests/integ/snowflake/ml/registry/model/registry_sklearn_model_test.py +++ b/tests/integ/snowflake/ml/registry/model/registry_sklearn_model_test.py @@ -260,6 +260,73 @@ def test_skl_unsupported_explain( self.registry.delete_model(model_name=name) + def test_skl_pipeline_explain_case_sensitive_with_quoted_identifiers_ignore_case(self) -> None: + # Build a pipeline with OneHotEncoder to simulate transformed feature names + data = { + "Color": ["red eyes", "blue", "green", "red eyes", "blue", "green"], + "size": [1, 2, 2, 4, 3, 1], + "price": [10, 15, 20, 25, 18, 12], + "target": [0, 1, 1, 0, 1, 0], + } + df = pd.DataFrame(data) + df["Color"] = df["Color"].astype("category") + input_features = ["Color", "size", "price"] + + preprocessor = compose.ColumnTransformer( + transformers=[ + ("cat", preprocessing.OneHotEncoder(), ["Color"]), + ], + remainder="passthrough", + ) + + pipeline = SK_pipeline.Pipeline( + [ + ("preprocessor", preprocessor), + ("classifier", linear_model.LogisticRegression(max_iter=1000)), + ] + ) + + pipeline.fit(df[input_features], df["target"]) + + name = "skl_pipeline_test_quoted_identifiers_case_sensitive_explain" + version = f"ver_{self._run_id}" + + mv = self.registry.log_model( + model=pipeline, + model_name=name, + version_name=version, + sample_input_data=df[input_features], + options={ + "enable_explainability": True, + # Ensure some methods are registered as case-sensitive, including explain + "method_options": { + "predict": {"case_sensitive": True}, + "predict_proba": {"case_sensitive": True}, + }, + }, + 
) + + functions = mv._functions + find_method: Callable[[model_manifest_schema.ModelFunctionInfo], bool] = ( + lambda method: "explain" in method["name"] + ) + target_function_info = next( + filter(find_method, functions), + None, + ) + self.assertIsNotNone(target_function_info, "explain function not found") + + result = mv.run( + df[input_features], + function_name=target_function_info["name"], + strict_input_validation=False, + ) + + self.assertIsInstance(result, pd.DataFrame) + self.assertTrue(len(result) > 0, "Result should not be empty") + + self.registry.delete_model(model_name=name) + self.assertNotIn(mv.model_name, [m.name for m in self.registry.models()]) def test_skl_model_with_signature_and_sample_data(self) -> None: diff --git a/tests/integ/snowflake/ml/registry/model/registry_target_platforms_test.py b/tests/integ/snowflake/ml/registry/model/registry_target_platforms_test.py index 34bffa19..7d1cbca1 100644 --- a/tests/integ/snowflake/ml/registry/model/registry_target_platforms_test.py +++ b/tests/integ/snowflake/ml/registry/model/registry_target_platforms_test.py @@ -55,7 +55,7 @@ class TestRegistryTargetPlatformsInteg(registry_model_test_base.RegistryModelTes }, { "target_platforms": [type_hints.TargetPlatform.WAREHOUSE.value], - "pip_requirements": ["prophet"], + "pip_requirements": ["prophet", "pandas==2.1.4"], # Pin pandas version to override snowpark "conda_dependencies": None, "artifact_repository_map": True, "expect_error": False, diff --git a/tests/integ/snowflake/ml/registry/model/registry_xgboost_model_test.py b/tests/integ/snowflake/ml/registry/model/registry_xgboost_model_test.py index 549dafa1..846d65d3 100644 --- a/tests/integ/snowflake/ml/registry/model/registry_xgboost_model_test.py +++ b/tests/integ/snowflake/ml/registry/model/registry_xgboost_model_test.py @@ -162,6 +162,37 @@ def test_xgb_explain_explicitly_enabled(self) -> None: function_type_assert={"explain": model_manifest_schema.ModelMethodFunctionTypes.TABLE_FUNCTION}, ) + def 
test_xgb_explain_case_sensitive(self) -> None: + cal_data = datasets.load_breast_cancer(as_frame=True) + cal_X = cal_data.data + cal_y = cal_data.target + cal_X.columns = [inflection.parameterize(c, "_") for c in cal_X.columns] + cal_X.rename(columns={"mean_radius": '"Mean Radius"'}, inplace=True) + + cal_X_train, cal_X_test, cal_y_train, cal_y_test = model_selection.train_test_split(cal_X, cal_y) + regressor = xgboost.XGBRegressor(n_estimators=100, reg_lambda=1, gamma=0, max_depth=3) + regressor.fit(cal_X_train, cal_y_train) + expected_explanations = shap.TreeExplainer(regressor)(cal_X_test).values + self._test_registry_model( + model=regressor, + sample_input_data=cal_X_test, + prediction_assert_fns={ + '"explain"': ( + cal_X_test, + lambda res: pd.testing.assert_frame_equal( + res, + pd.DataFrame(expected_explanations, columns=res.columns), + check_dtype=False, + ), + ), + }, + options={ + "enable_explainability": True, + "method_options": {"predict": {"case_sensitive": True}, "predict_proba": {"case_sensitive": True}}, + }, + function_type_assert={"explain": model_manifest_schema.ModelMethodFunctionTypes.TABLE_FUNCTION}, + ) + def test_xgb_sp_no_explain(self) -> None: cal_data = datasets.load_breast_cancer(as_frame=True).frame cal_data.columns = [inflection.parameterize(c, "_") for c in cal_data] diff --git a/tests/integ/snowflake/ml/registry/services/BUILD.bazel b/tests/integ/snowflake/ml/registry/services/BUILD.bazel index 5a7a68d6..dcc27113 100644 --- a/tests/integ/snowflake/ml/registry/services/BUILD.bazel +++ b/tests/integ/snowflake/ml/registry/services/BUILD.bazel @@ -149,11 +149,22 @@ py_test( ], ) +py_test( + name = "registry_batch_inference_case_sensitivity_test", + timeout = "eternal", + srcs = ["registry_batch_inference_case_sensitivity_test.py"], + shard_count = 2, + tags = ["feature:model_registry"], + deps = [ + ":registry_model_deployment_test_base", + ], +) + py_test( name = "registry_custom_model_batch_inference_test", timeout = "eternal", srcs 
= ["registry_custom_model_batch_inference_test.py"], - shard_count = 2, + shard_count = 3, tags = ["feature:model_registry"], deps = [ ":registry_model_deployment_test_base", diff --git a/tests/integ/snowflake/ml/registry/services/registry_batch_inference_case_sensitivity_test.py b/tests/integ/snowflake/ml/registry/services/registry_batch_inference_case_sensitivity_test.py new file mode 100644 index 00000000..0baa5a6b --- /dev/null +++ b/tests/integ/snowflake/ml/registry/services/registry_batch_inference_case_sensitivity_test.py @@ -0,0 +1,185 @@ +import uuid + +import pandas as pd +from absl.testing import absltest + +from snowflake.ml.model import custom_model +from tests.integ.snowflake.ml.registry.services import ( + registry_model_deployment_test_base, +) + + +class TestModel(custom_model.CustomModel): + """Simple model for case sensitivity testing.""" + + def __init__(self, context: custom_model.ModelContext) -> None: + super().__init__(context) + + @custom_model.inference_api + def predict(self, input: pd.DataFrame) -> pd.DataFrame: + return pd.DataFrame({"output": [1] * len(input)}) + + +class RegistryBatchInferenceCaseSensitivityTest(registry_model_deployment_test_base.RegistryModelDeploymentTestBase): + def test_case_sensitive_1(self) -> None: + model = TestModel(custom_model.ModelContext()) + + # Model signature + sample_input_data = self.session.create_dataframe( + [[1, 2, 3, 4, 5, 6, 7], [8, 9, 10, 11, 12, 13, 14]], + schema=[ + '"feature1"', + '"Feature2"', + '"FEATURE3"', + '"feature 4"', + '"feature_5"', + '"feature-6"', + '"feature7"', + ], + ) + + # Actual input data + input_spec = self.session.create_dataframe( + [[1, 2, 3, 4, 5, 6, 7], [8, 9, 10, 11, 12, 13, 14]], + schema=[ + '"FEATURE1"', + '"FEATURE2"', + '"FEATURE3"', + '"FEATURE 4"', + '"FEATURE_5"', + '"FEATURE-6"', + '"feature7"', + ], + ) + + name = f"{str(uuid.uuid4()).replace('-', '_').upper()}" + output_stage_location = 
f"@{self._test_db}.{self._test_schema}.{self._test_stage}/{name}/output/" + + self._test_registry_batch_inference( + model=model, + sample_input_data=sample_input_data, + input_spec=input_spec, + output_stage_location=output_stage_location, + cpu_requests=None, + num_workers=1, + service_name=f"case_sensitivity_1_{name}", + replicas=1, + options={"method_options": {"predict": {"case_sensitive": True}}}, + ) + + def test_case_sensitive_2(self) -> None: + model = TestModel(custom_model.ModelContext()) + + # Model signature + sample_input_data = self.session.create_dataframe( + [[1, 2, 3, 4, 5, 6], [7, 8, 9, 10, 11, 12]], + schema=['"FEATURE1"', '"FEATURE2"', '"FEATURE3"', '"FEATURE 4"', '"FEATURE_5"', '"FEATURE-6"'], + ) + + # Actual input data + input_spec = self.session.create_dataframe( + [[1, 2, 3, 4, 5, 6], [7, 8, 9, 10, 11, 12]], + schema=['"feature1"', '"Feature2"', '"FEATURE3"', '"feature 4"', '"feature_5"', '"feature-6"'], + ) + + name = f"{str(uuid.uuid4()).replace('-', '_').upper()}" + output_stage_location = f"@{self._test_db}.{self._test_schema}.{self._test_stage}/{name}/output/" + + self._test_registry_batch_inference( + model=model, + sample_input_data=sample_input_data, + input_spec=input_spec, + output_stage_location=output_stage_location, + cpu_requests=None, + num_workers=1, + service_name=f"case_sensitivity_2_{name}", + replicas=1, + options={"method_options": {"predict": {"case_sensitive": True}}}, + ) + + def test_insensitive_model_input_signature(self) -> None: + model = TestModel(custom_model.ModelContext()) + + # Model signature + sample_input_data = self.session.create_dataframe( + [[1, 2, 3, 4, 5, 6], [7, 8, 9, 10, 11, 12]], + schema=["FEATURE1", "FEATURE2", "FEATURE3", "FEATURE_4"], + ) + + # Actual input data + input_spec = self.session.create_dataframe( + [[1, 2, 3, 4, 5, 6], [7, 8, 9, 10, 11, 12]], + schema=['"feature1"', '"Feature2"', "FEATURE3", '"feature_4"'], + ) + + name = f"{str(uuid.uuid4()).replace('-', '_').upper()}" + 
output_stage_location = f"@{self._test_db}.{self._test_schema}.{self._test_stage}/{name}/output/" + + self._test_registry_batch_inference( + model=model, + sample_input_data=sample_input_data, + input_spec=input_spec, + output_stage_location=output_stage_location, + cpu_requests=None, + num_workers=1, + service_name=f"case_sensitivity_3_{name}", + replicas=1, + options={"method_options": {"predict": {"case_sensitive": False}}}, + ) + + def test_column_reordering(self) -> None: + """Test that columns are properly reordered even with case differences.""" + model = TestModel(custom_model.ModelContext()) + + # Model signature expects specific order: feature1, feature2 (lowercase) + sample_input_data = self.session.create_dataframe([[1, 2], [3, 4]], schema=['"FEATURE1"', '"FEATURE2"']) + + # Actual input data has columns in different order and case: FEATURE2, FEATURE1 (uppercase, reversed) + input_spec = self.session.create_dataframe([[2, 1], [4, 3], [6, 5]], schema=['"feature2"', '"FEATURE1"']) + + name = f"{str(uuid.uuid4()).replace('-', '_').upper()}" + output_stage_location = f"@{self._test_db}.{self._test_schema}.{self._test_stage}/{name}/output/" + + self._test_registry_batch_inference( + model=model, + sample_input_data=sample_input_data, + input_spec=input_spec, + output_stage_location=output_stage_location, + cpu_requests=None, + num_workers=1, + service_name=f"reorder_{name}", + replicas=1, + options={"method_options": {"predict": {"case_sensitive": False}}}, + ) + + def test_extra_columns(self) -> None: + """Test case insensitive matching when input data has extra columns.""" + model = TestModel(custom_model.ModelContext()) + + # Model signature expects only two lowercase columns + sample_input_data = self.session.create_dataframe([[1, 2], [3, 4]], schema=['"feature1"', '"feature2"']) + + # Actual input data has extra columns and different case + input_spec = self.session.create_dataframe( + [[1, 2, "extra1", 10], [3, 4, "extra2", 11]], + 
schema=['"FEATURE1"', '"FEATURE2"', '"EXTRA_COL1"', '"EXTRA_COL2"'], + ) + + name = f"{str(uuid.uuid4()).replace('-', '_').upper()}" + output_stage_location = f"@{self._test_db}.{self._test_schema}.{self._test_stage}/{name}/output/" + + self._test_registry_batch_inference( + model=model, + sample_input_data=sample_input_data, + input_spec=input_spec, + output_stage_location=output_stage_location, + cpu_requests=None, + num_workers=1, + service_name=f"extra_cols_{name}", + replicas=1, + options={"method_options": {"predict": {"case_sensitive": True}}}, + ) + + +if __name__ == "__main__": + absltest.main() diff --git a/tests/integ/snowflake/ml/registry/services/registry_catboost_batch_inference_test.py b/tests/integ/snowflake/ml/registry/services/registry_catboost_batch_inference_test.py index 23cf5ccb..9ac9d264 100644 --- a/tests/integ/snowflake/ml/registry/services/registry_catboost_batch_inference_test.py +++ b/tests/integ/snowflake/ml/registry/services/registry_catboost_batch_inference_test.py @@ -2,6 +2,7 @@ import catboost import inflection +import pandas as pd from absl.testing import absltest, parameterized from sklearn import datasets, model_selection @@ -29,15 +30,21 @@ def test_catboost( classifier = catboost.CatBoostClassifier() classifier.fit(cal_X_train, cal_y_train) - cal_data_sp_df_train = self.session.create_dataframe(cal_X_train) + + # Generate expected predictions using the original model + model_output = classifier.predict(cal_X_test) + model_output_df = pd.DataFrame({"output_feature_0": model_output}) + + # Prepare input data and expected predictions using common function + input_spec, expected_predictions = self._prepare_batch_inference_data(cal_X_test, model_output_df) name = f"{str(uuid.uuid4()).replace('-', '_').upper()}" output_stage_location = f"@{self._test_db}.{self._test_schema}.{self._test_stage}/{name}/output/" self._test_registry_batch_inference( model=classifier, - sample_input_data=cal_data_sp_df_train, - 
input_spec=cal_data_sp_df_train, + sample_input_data=cal_X_test, + input_spec=input_spec, output_stage_location=output_stage_location, gpu_requests=gpu_requests, cpu_requests=cpu_requests, @@ -46,6 +53,7 @@ def test_catboost( service_name=f"batch_inference_{name}", replicas=2, function_name="predict", + expected_predictions=expected_predictions, ) diff --git a/tests/integ/snowflake/ml/registry/services/registry_custom_model_batch_inference_test.py b/tests/integ/snowflake/ml/registry/services/registry_custom_model_batch_inference_test.py index b2c0a611..add589c1 100644 --- a/tests/integ/snowflake/ml/registry/services/registry_custom_model_batch_inference_test.py +++ b/tests/integ/snowflake/ml/registry/services/registry_custom_model_batch_inference_test.py @@ -1,8 +1,7 @@ -import uuid - import pandas as pd from absl.testing import absltest, parameterized +from snowflake.ml.jobs import delete_job, get_job from snowflake.ml.model import custom_model from tests.integ.snowflake.ml.registry.services import ( registry_model_deployment_test_base, @@ -15,40 +14,127 @@ def __init__(self, context: custom_model.ModelContext) -> None: @custom_model.inference_api def predict(self, input: pd.DataFrame) -> pd.DataFrame: - return pd.DataFrame({"output": input["c1"]}) + return pd.DataFrame({"output": input["C1"]}) class TestCustomModelBatchInferenceInteg(registry_model_deployment_test_base.RegistryModelDeploymentTestBase): + def _prepare_test(self): + model = DemoModel(custom_model.ModelContext()) + num_cols = 2 + + # Create input data + input_data = [[0] * num_cols, [1] * num_cols] + input_cols = [f"C{i}" for i in range(num_cols)] + + # Create pandas DataFrame + input_pandas_df = pd.DataFrame(input_data, columns=input_cols) + + # Generate expected predictions using the original model + model_output = model.predict(input_pandas_df[input_cols]) + + # Prepare input data and expected predictions using common function + input_spec, expected_predictions = 
self._prepare_batch_inference_data(input_pandas_df, model_output) + + # Create sample input data without INDEX column for model signature + sp_df = self.session.create_dataframe(input_data, schema=input_cols) + + service_name, output_stage_location = self._prepare_service_name_and_stage_for_batch_inference() + + return model, service_name, output_stage_location, input_spec, expected_predictions, sp_df + @parameterized.parameters( # type: ignore[misc] {"num_workers": 1, "replicas": 1, "cpu_requests": None}, {"num_workers": 2, "replicas": 2, "cpu_requests": "4"}, ) - def test_end_to_end_pipeline( + def test_custom_model( self, replicas: int, cpu_requests: str, num_workers: int, ) -> None: - model = DemoModel(custom_model.ModelContext()) - num_cols = 2 + model, service_name, output_stage_location, input_spec, expected_predictions, sp_df = self._prepare_test() - sp_df = self.session.create_dataframe( - [[0] * num_cols, [1] * num_cols], schema=[f'"c{i}"' for i in range(num_cols)] + self._test_registry_batch_inference( + model=model, + sample_input_data=sp_df, + input_spec=input_spec, + output_stage_location=output_stage_location, + cpu_requests=cpu_requests, + num_workers=num_workers, + service_name=service_name, + replicas=replicas, + function_name="predict", + expected_predictions=expected_predictions, ) - name = f"{str(uuid.uuid4()).replace('-', '_').upper()}" + def test_mljob_api(self) -> None: + model, service_name, output_stage_location, input_spec, _, sp_df = self._prepare_test() - output_stage_location = f"@{self._test_db}.{self._test_schema}.{self._test_stage}/{name}/output/" + replicas = 2 - self._test_registry_batch_inference( + job = self._test_registry_batch_inference( model=model, sample_input_data=sp_df, - input_spec=sp_df, + input_spec=input_spec, output_stage_location=output_stage_location, - cpu_requests=cpu_requests, - num_workers=num_workers, - service_name=f"batch_inference_{name}", + cpu_requests=None, replicas=replicas, + 
service_name=service_name, + blocking=False, + ) + + self.assertEqual(job.id, f"{self._test_db}.{self._test_schema}.{service_name}") + self.assertEqual(job.min_instances, 1) + self.assertEqual(job.target_instances, replicas) + self.assertIn(job.status, ["PENDING", "RUNNING"]) + self.assertEqual(job.name, service_name) + + # We just wanted to make sure the log functions don't throw exceptions + job.get_logs() + job.show_logs() + + job.cancel() + job.wait() # wait until it is cancelled otherwise the job might be still pending + self.assertEqual(job.status, "CANCELLED") + + def test_mljob_job_manager(self) -> None: + model, service_name, output_stage_location, input_spec, _, sp_df = self._prepare_test() + + job = self._test_registry_batch_inference( + model=model, + sample_input_data=sp_df, + input_spec=input_spec, + output_stage_location=output_stage_location, + cpu_requests=None, + service_name=service_name, + blocking=False, + ) + + # the same job in another MLJob wrapper + job2 = get_job(job.id) + delete_job(job2) + + # the job will not be queryable any more + try: + job2.wait() + except Exception as e: + error_message = str(e) + self.assertIn("does not exist or not authorized", error_message) + + def test_default_system_compute_pool( + self, + ) -> None: + model, service_name, output_stage_location, input_spec, _, sp_df = self._prepare_test() + + self._test_registry_batch_inference( + model=model, + sample_input_data=sp_df, + input_spec=input_spec, + output_stage_location=output_stage_location, + service_name=service_name, + replicas=2, + function_name="predict", + service_compute_pool="SYSTEM_COMPUTE_POOL_CPU", + ) diff --git a/tests/integ/snowflake/ml/registry/services/registry_inference_logging_test.py b/tests/integ/snowflake/ml/registry/services/registry_inference_logging_test.py index c6b63e22..ff945e1a 100644 --- a/tests/integ/snowflake/ml/registry/services/registry_inference_logging_test.py +++ 
b/tests/integ/snowflake/ml/registry/services/registry_inference_logging_test.py @@ -178,6 +178,7 @@ def _verify_processing_path(self, system_logs: list[str], expected_path: str) -> else: self.fail(f"Unknown expected_path: {expected_path}") + @absltest.skip("Skipping test_inference_logging_batch_path") def test_inference_logging_batch_path(self): """Test inference logging for small requests that trigger batch processing.""" mv = self._deploy_simple_sklearn_model(autocapture_enabled=True) @@ -212,6 +213,7 @@ def test_inference_logging_batch_path(self): # Verify batch path was used self._verify_processing_path(system_logs, "batch") + @absltest.skip("Skipping test_inference_logging_streaming_path") def test_inference_logging_streaming_path(self): """Test inference logging for large requests that trigger streaming processing.""" mv = self._deploy_simple_sklearn_model(autocapture_enabled=True) @@ -248,6 +250,7 @@ def test_inference_logging_streaming_path(self): # Verify streaming path was used self._verify_processing_path(system_logs, "streaming") + @absltest.skip("Skipping test_inference_logging_disabled_by_default") def test_inference_logging_disabled_by_default(self): """Test that inference logging is disabled by default (no logs captured).""" # Deploy model with autocapture explicitly DISABLED diff --git a/tests/integ/snowflake/ml/registry/services/registry_keras_batch_inference_test.py b/tests/integ/snowflake/ml/registry/services/registry_keras_batch_inference_test.py index 789957ee..379c6272 100644 --- a/tests/integ/snowflake/ml/registry/services/registry_keras_batch_inference_test.py +++ b/tests/integ/snowflake/ml/registry/services/registry_keras_batch_inference_test.py @@ -3,9 +3,10 @@ import keras import numpy as np import numpy.typing as npt +import pandas as pd from absl.testing import absltest, parameterized -from snowflake.ml.model._signatures import numpy_handler, snowpark_handler +from snowflake.ml.model._signatures import numpy_handler from 
tests.integ.snowflake.ml.registry.services import ( registry_model_deployment_test_base, ) @@ -40,10 +41,13 @@ def test_keras( model, data_x, data_y = _prepare_keras_functional_model() x_df = numpy_handler.NumpyArrayHandler.convert_to_df(data_x) x_df.columns = [f"input_feature_{i}" for i in range(len(x_df.columns))] - x_df_sp = snowpark_handler.SnowparkDataFrameHandler.convert_from_df( - self.session, - x_df, - ) + + # Generate expected predictions using the original model + model_output = model.predict(data_x) + model_output_df = pd.DataFrame({"output_feature_0": model_output.flatten()}) + + # Prepare input data and expected predictions using common function + input_spec, expected_predictions = self._prepare_batch_inference_data(x_df, model_output_df) name = f"{str(uuid.uuid4()).replace('-', '_').upper()}" output_stage_location = f"@{self._test_db}.{self._test_schema}.{self._test_stage}/{name}/output/" @@ -51,7 +55,7 @@ def test_keras( self._test_registry_batch_inference( model=model, sample_input_data=x_df, - input_spec=x_df_sp, + input_spec=input_spec, output_stage_location=output_stage_location, gpu_requests=gpu_requests, cpu_requests=cpu_requests, @@ -60,6 +64,7 @@ def test_keras( service_name=f"batch_inference_{name}", replicas=2, function_name="predict", + expected_predictions=expected_predictions, ) diff --git a/tests/integ/snowflake/ml/registry/services/registry_lightgbm_batch_inference_test.py b/tests/integ/snowflake/ml/registry/services/registry_lightgbm_batch_inference_test.py index d51eabc9..58296a9d 100644 --- a/tests/integ/snowflake/ml/registry/services/registry_lightgbm_batch_inference_test.py +++ b/tests/integ/snowflake/ml/registry/services/registry_lightgbm_batch_inference_test.py @@ -2,6 +2,7 @@ import inflection import lightgbm +import pandas as pd from absl.testing import absltest, parameterized from sklearn import datasets, model_selection @@ -29,20 +30,27 @@ def test_lightgbm_batch_inference( classifier = lightgbm.LGBMClassifier() 
classifier.fit(cal_X_train, cal_y_train) - sp_df = self.session.create_dataframe(cal_X_test) + # Generate expected predictions using the original model + model_output = classifier.predict(cal_X_test) + model_output_df = pd.DataFrame({"output_feature_0": model_output}) + + # Prepare input data and expected predictions using common function + input_spec, expected_predictions = self._prepare_batch_inference_data(cal_X_test, model_output_df) + name = f"{str(uuid.uuid4()).replace('-', '_').upper()}" output_stage_location = f"@{self._test_db}.{self._test_schema}.{self._test_stage}/{name}/output/" self._test_registry_batch_inference( model=classifier, - sample_input_data=sp_df, - input_spec=sp_df, + sample_input_data=cal_X_test, + input_spec=input_spec, output_stage_location=output_stage_location, cpu_requests=cpu_requests, num_workers=2, service_name=f"batch_inference_{name}", replicas=replicas, function_name="predict", + expected_predictions=expected_predictions, ) diff --git a/tests/integ/snowflake/ml/registry/services/registry_model_deployment_test_base.py b/tests/integ/snowflake/ml/registry/services/registry_model_deployment_test_base.py index e8f35f06..10767844 100644 --- a/tests/integ/snowflake/ml/registry/services/registry_model_deployment_test_base.py +++ b/tests/integ/snowflake/ml/registry/services/registry_model_deployment_test_base.py @@ -18,6 +18,7 @@ from cryptography.hazmat.primitives import serialization from snowflake import snowpark +from snowflake.ml import jobs from snowflake.ml._internal import file_utils, platform_capabilities as pc from snowflake.ml._internal.utils import identifier, jwt_generator, sql_identifier from snowflake.ml.model import ( @@ -45,6 +46,7 @@ class RegistryModelDeploymentTestBase(common_test_base.CommonTestBase): _TEST_CPU_COMPUTE_POOL = "REGTEST_INFERENCE_CPU_POOL" _TEST_GPU_COMPUTE_POOL = "REGTEST_INFERENCE_GPU_POOL" _TEST_SPCS_WH = "REGTEST_ML_SMALL" + _INDEX_COL = "INDEX" BUILDER_IMAGE_PATH = os.getenv("BUILDER_IMAGE_PATH", 
None) BASE_CPU_IMAGE_PATH = os.getenv("BASE_CPU_IMAGE_PATH", None) @@ -396,7 +398,9 @@ def _test_registry_batch_inference( memory_requests: Optional[str] = None, use_default_repo: bool = False, function_name: Optional[str] = None, - ) -> ModelVersion: + expected_predictions: Optional[pd.DataFrame] = None, + blocking: bool = True, + ) -> jobs.MLJob[Any]: conda_dependencies = [ test_env_utils.get_latest_package_version_spec_in_server(self.session, "snowflake-snowpark-python") ] @@ -428,6 +432,8 @@ def _test_registry_batch_inference( num_workers=num_workers, replicas=replicas, function_name=function_name, + expected_predictions=expected_predictions, + blocking=blocking, ) def _deploy_batch_inference( @@ -442,7 +448,9 @@ def _deploy_batch_inference( num_workers: Optional[int] = None, replicas: int = 1, function_name: Optional[str] = None, - ) -> ModelVersion: + expected_predictions: Optional[pd.DataFrame] = None, + blocking: bool = True, + ) -> jobs.MLJob[Any]: if self.BUILDER_IMAGE_PATH and self.BASE_CPU_IMAGE_PATH and self.BASE_GPU_IMAGE_PATH: with_image_override = True elif not self.BUILDER_IMAGE_PATH and not self.BASE_CPU_IMAGE_PATH and not self.BASE_GPU_IMAGE_PATH: @@ -488,7 +496,10 @@ def _deploy_batch_inference( function_name=function_name, ), ) - job.wait() + if blocking: + job.wait() + else: + return job self.assertEqual(job.status, "DONE") @@ -504,8 +515,75 @@ def _deploy_batch_inference( f"Output row count ({df.count()}) does not match input row count ({input_spec.count()})", ) + # Compare expected and actual output if provided + if expected_predictions is not None: + # Convert Snowpark DataFrame to pandas for comparison + actual_output = df.to_pandas() + + # Sort both dataframes by the index column for consistent comparison + self.assertTrue(self._INDEX_COL in expected_predictions.columns) + self.assertTrue(self._INDEX_COL in actual_output.columns) + expected_predictions = expected_predictions.sort_values(self._INDEX_COL).reset_index(drop=True) + 
actual_output = actual_output.sort_values(self._INDEX_COL).reset_index(drop=True) + + # Order columns consistently + expected_columns = sorted(expected_predictions.columns) + actual_columns = sorted(actual_output.columns) + + # Ensure both dataframes have the same columns + self.assertEqual( + set(expected_columns), + set(actual_columns), + f"Expected columns {expected_columns} do not match actual columns {actual_columns}", + ) + + # Reorder columns to match + actual_output = actual_output[expected_columns] + + # Compare the dataframes + pd.testing.assert_frame_equal( + expected_predictions, + actual_output, + check_dtype=False, + check_exact=False, + rtol=1e-3, + atol=1e-6, + ) + return mv + def _prepare_batch_inference_data( + self, + input_pandas_df: pd.DataFrame, + model_output: pd.DataFrame, + ) -> tuple[snowpark.DataFrame, pd.DataFrame]: + """Prepare input data with an index column and expected predictions. + + Args: + input_pandas_df: Input data as pandas DataFrame + model_output: Model predictions as pandas DataFrame + + Returns: + Tuple of (input_spec, expected_predictions) + """ + # Create input data with an index column for deterministic ordering + input_with_index = input_pandas_df.copy() + input_with_index[self._INDEX_COL] = range(len(input_pandas_df)) + + # Convert to Snowpark DataFrame + input_spec = self.session.create_dataframe(input_with_index) + + # Generate expected predictions by concatenating input data with model output + # Reset both indices to ensure proper alignment + expected_predictions = input_with_index.reset_index(drop=True) + model_output_reset = model_output.reset_index(drop=True) + expected_predictions = pd.concat([expected_predictions, model_output_reset], axis=1) + + # Sort columns to match the actual output order + expected_predictions = expected_predictions.reindex(columns=sorted(expected_predictions.columns)) + + return input_spec, expected_predictions + @staticmethod def retry_if_result_status_retriable(result: 
requests.Response) -> bool: if result.status_code in [ @@ -621,3 +699,20 @@ def _single_inference_request( timeout=60, # 60 second timeout since ingrrss will timeout after 60 seconds. # This will help in case the service itself is not reachable. ) + + def _prepare_service_name_and_stage_for_batch_inference(self) -> tuple[str, str]: + """Prepare batch inference setup by generating unique identifiers and output stage location. + + Creates a unique name based on UUID and constructs the corresponding output stage + location path for batch inference operations. + + Returns: + tuple[str, str]: A tuple containing: + - service_name: Unique identifier with underscores (replacing hyphens from UUID) + - output_stage_location: Full stage path for batch inference output files + """ + name = f"{str(uuid.uuid4()).replace('-', '_').upper()}" + service_name = f"BATCH_INFERENCE_{name}" + output_stage_location = f"@{self._test_db}.{self._test_schema}.{self._test_stage}/{service_name}/output/" + + return service_name, output_stage_location diff --git a/tests/integ/snowflake/ml/registry/services/registry_pytorch_batch_inference_test.py b/tests/integ/snowflake/ml/registry/services/registry_pytorch_batch_inference_test.py index 26cdd13e..f23ed1d1 100644 --- a/tests/integ/snowflake/ml/registry/services/registry_pytorch_batch_inference_test.py +++ b/tests/integ/snowflake/ml/registry/services/registry_pytorch_batch_inference_test.py @@ -1,9 +1,10 @@ import uuid +import pandas as pd import torch from absl.testing import absltest, parameterized -from snowflake.ml.model._signatures import pytorch_handler, snowpark_handler +from snowflake.ml.model._signatures import pytorch_handler from tests.integ.snowflake.ml.registry.services import ( registry_model_deployment_test_base, ) @@ -24,10 +25,13 @@ def test_pt( model, data_x, data_y = model_factory.ModelFactory.prepare_torch_model(torch.float64) x_df = pytorch_handler.PyTorchTensorHandler.convert_to_df(data_x, ensure_serializable=False) 
x_df.columns = [f"col_{i}" for i in range(data_x.shape[1])] - sp_df = snowpark_handler.SnowparkDataFrameHandler.convert_from_df( - self.session, - x_df, - ) + + # Generate expected predictions using the original model + model_output = model.forward(data_x) + model_output_df = pd.DataFrame({"output_feature_0": model_output.detach().numpy().flatten()}) + + # Prepare input data and expected predictions using common function + input_spec, expected_predictions = self._prepare_batch_inference_data(x_df, model_output_df) name = f"{str(uuid.uuid4()).replace('-', '_').upper()}" output_stage_location = f"@{self._test_db}.{self._test_schema}.{self._test_stage}/{name}/output/" @@ -35,7 +39,7 @@ def test_pt( self._test_registry_batch_inference( model=model, sample_input_data=x_df, - input_spec=sp_df, + input_spec=input_spec, output_stage_location=output_stage_location, gpu_requests=gpu_requests, cpu_requests=cpu_requests, @@ -43,6 +47,8 @@ def test_pt( num_workers=1, service_name=f"batch_inference_{name}", replicas=2, + function_name="forward", + expected_predictions=expected_predictions, ) diff --git a/tests/integ/snowflake/ml/registry/services/registry_sklearn_batch_inference_test.py b/tests/integ/snowflake/ml/registry/services/registry_sklearn_batch_inference_test.py index 37ce99d1..0ead54a9 100644 --- a/tests/integ/snowflake/ml/registry/services/registry_sklearn_batch_inference_test.py +++ b/tests/integ/snowflake/ml/registry/services/registry_sklearn_batch_inference_test.py @@ -19,7 +19,14 @@ def test_sklearn(self, pip_requirements: Optional[list[str]]) -> None: # Convert numpy array to pandas DataFrame for create_dataframe iris_df = pd.DataFrame(iris_X, columns=[f"input_feature_{i}" for i in range(iris_X.shape[1])]) - sp_df = self.session.create_dataframe(iris_df) + + # Generate expected predictions using the original model + model_output = svc.predict(iris_X) + model_output_df = pd.DataFrame({"output_feature_0": model_output}) + + # Prepare input data and expected 
predictions using common function + input_spec, expected_predictions = self._prepare_batch_inference_data(iris_df, model_output_df) + name = f"{str(uuid.uuid4()).replace('-', '_').upper()}" output_stage_location = f"@{self._test_db}.{self._test_schema}.{self._test_stage}/{name}/output/" @@ -28,12 +35,13 @@ def test_sklearn(self, pip_requirements: Optional[list[str]]) -> None: sample_input_data=iris_X, pip_requirements=pip_requirements, options={"enable_explainability": False}, - input_spec=sp_df, + input_spec=input_spec, output_stage_location=output_stage_location, num_workers=1, service_name=f"batch_inference_{name}", replicas=1, function_name="predict", + expected_predictions=expected_predictions, ) diff --git a/tests/integ/snowflake/ml/registry/services/registry_tensorflow_batch_inference_test.py b/tests/integ/snowflake/ml/registry/services/registry_tensorflow_batch_inference_test.py index c2e7e18b..55288cfe 100644 --- a/tests/integ/snowflake/ml/registry/services/registry_tensorflow_batch_inference_test.py +++ b/tests/integ/snowflake/ml/registry/services/registry_tensorflow_batch_inference_test.py @@ -1,8 +1,9 @@ import uuid +import pandas as pd from absl.testing import absltest, parameterized -from snowflake.ml.model._signatures import snowpark_handler, tensorflow_handler +from snowflake.ml.model._signatures import tensorflow_handler from tests.integ.snowflake.ml.registry.services import ( registry_model_deployment_test_base, ) @@ -23,10 +24,13 @@ def test_tf( model, data_x = model_factory.ModelFactory.prepare_tf_model() x_df = tensorflow_handler.TensorflowTensorHandler.convert_to_df(data_x, ensure_serializable=False) x_df.columns = [f"col_{i}" for i in range(x_df.shape[1])] - sp_df = snowpark_handler.SnowparkDataFrameHandler.convert_from_df( - self.session, - x_df, - ) + + # Generate expected predictions using the original model + model_output = model(data_x) + model_output_df = pd.DataFrame({"output_feature_0": model_output.numpy().flatten()}) + + # Prepare 
input data and expected predictions using common function + input_spec, expected_predictions = self._prepare_batch_inference_data(x_df, model_output_df) name = f"{str(uuid.uuid4()).replace('-', '_').upper()}" output_stage_location = f"@{self._test_db}.{self._test_schema}.{self._test_stage}/{name}/output/" @@ -34,7 +38,7 @@ def test_tf( self._test_registry_batch_inference( model=model, sample_input_data=x_df, - input_spec=sp_df, + input_spec=input_spec, output_stage_location=output_stage_location, gpu_requests=gpu_requests, cpu_requests=cpu_requests, @@ -42,6 +46,7 @@ def test_tf( num_workers=1, service_name=f"batch_inference_{name}", replicas=2, + expected_predictions=expected_predictions, ) diff --git a/tests/integ/snowflake/ml/registry/services/registry_xgboost_batch_inference_test.py b/tests/integ/snowflake/ml/registry/services/registry_xgboost_batch_inference_test.py index 253c487e..b4bf405a 100644 --- a/tests/integ/snowflake/ml/registry/services/registry_xgboost_batch_inference_test.py +++ b/tests/integ/snowflake/ml/registry/services/registry_xgboost_batch_inference_test.py @@ -1,6 +1,7 @@ import uuid import inflection +import pandas as pd import xgboost from absl.testing import absltest, parameterized from sklearn import datasets, model_selection @@ -31,7 +32,12 @@ def test_xgb( regressor = xgboost.XGBRegressor(n_estimators=100, reg_lambda=1, gamma=0, max_depth=3) regressor.fit(cal_X_train, cal_y_train) - sp_df = self.session.create_dataframe(cal_X_test) + # Generate expected predictions using the original model + model_output = regressor.predict(cal_X_test) + model_output_df = pd.DataFrame({"output_feature_0": model_output}) + + # Prepare input data and expected predictions using common function + input_spec, expected_predictions = self._prepare_batch_inference_data(cal_X_test, model_output_df) name = f"{str(uuid.uuid4()).replace('-', '_').upper()}" output_stage_location = f"@{self._test_db}.{self._test_schema}.{self._test_stage}/{name}/output/" @@ -43,7 
+49,7 @@ def test_xgb( if gpu_requests else {"enable_explainability": False} ), - input_spec=sp_df, + input_spec=input_spec, output_stage_location=output_stage_location, gpu_requests=gpu_requests, cpu_requests=cpu_requests, @@ -51,6 +57,7 @@ def test_xgb( num_workers=1, service_name=f"batch_inference_{name}", replicas=2, + expected_predictions=expected_predictions, ) diff --git a/tests/integ/snowflake/ml/test_utils/_snowml_requirements.py b/tests/integ/snowflake/ml/test_utils/_snowml_requirements.py index 85cec2f8..cf476957 100755 --- a/tests/integ/snowflake/ml/test_utils/_snowml_requirements.py +++ b/tests/integ/snowflake/ml/test_utils/_snowml_requirements.py @@ -30,5 +30,5 @@ "sqlparse>=0.4,<1", "tqdm<5", "typing-extensions>=4.1.0,<5", - "xgboost>=1.7.3,<3", + "xgboost<4", ] diff --git a/tests/pytest.ini b/tests/pytest.ini index 5d5c312e..364a19e1 100644 --- a/tests/pytest.ini +++ b/tests/pytest.ini @@ -21,13 +21,3 @@ markers = pip_incompatible: mark a test as incompatible with pip environment. conda_incompatible: mark a test as incompatible with conda environment. spcs_deployment_image: mark a test as requiring the SPCS deployment image. - feature_area_model_registry: mark a test as belonging to the model registry feature area. - feature_area_feature_store: mark a test as belonging to the feature store feature area. - feature_area_jobs: mark a test as belonging to the jobs feature area. - feature_area_observability: mark a test as belonging to the observability feature area. - feature_area_cortex: mark a test as belonging to the cortex feature area. - feature_area_core: mark a test as belonging to the core feature area. - feature_area_modeling: mark a test as belonging to the modeling feature area. - feature_area_model_serving: mark a test as belonging to the model serving feature area. - feature_area_data: mark a test as belonging to the data feature area. - feature_area_none: mark a test as not belonging to any specific feature area.