diff --git a/ci/conda_recipe/meta.yaml b/ci/conda_recipe/meta.yaml index 7bc9658f..722c0384 100644 --- a/ci/conda_recipe/meta.yaml +++ b/ci/conda_recipe/meta.yaml @@ -20,14 +20,22 @@ requirements: run: - python - absl-py>=0.15,<2 + - anyio>=3.5.0,<4 - fsspec>=2022.11,<=2023.1 - numpy>=1.23,<1.24 - pyyaml>=6.0,<7 - scipy>=1.9,<2 - - scikit-learn==1.2.1 - snowflake-connector-python - - snowflake-snowpark-python>=1.0.0,<=1.3 + - snowflake-snowpark-python>=1.3.0,<=2 - sqlparse>=0.4,<1 + + # TODO(snandamuri): Versions of these packages must be exactly the same between the user's workspace and + # the snowpark sandbox. Generic definitions like scikit-learn>=1.1.0,<2 won't work because the snowflake conda channel + # only has a few allowlisted versions of scikit-learn available, so we must force users to use scikit-learn + # versions that are available in the snowflake conda channel. Since there is no way to specify an allow list of + # versions in the requirements file, we are pinning the versions here. + - joblib>=1.0.0,<=1.1.1 + - scikit-learn==1.2.1 - xgboost==1.7.3 about: home: https://github.com/snowflakedb/snowflake-ml-python diff --git a/codegen/sklearn_wrapper_generator.py b/codegen/sklearn_wrapper_generator.py index 803744e9..8979b9a3 100644 --- a/codegen/sklearn_wrapper_generator.py +++ b/codegen/sklearn_wrapper_generator.py @@ -369,17 +369,11 @@ class WrapperGeneratorBase: original_class_name INFERRED Class name for the given scikit-learn estimator. - estimator_class_name GENERATED Name for the new estimator class. - transformer_class_name GENERATED [TODO] Name for the new transformer - class. module_name INFERRED Name of the module that the given class is contained in. estimator_imports GENERATED Imports needed for the estimator / fit() call. fit_sproc_imports GENERATED Imports needed for the fit sproc call. - transform_function_name INFERRED Name for the transformer function. This - will be one of "transform" or - "predict()" depending on the class. ------------------------------------------------------------------------------------ SIGNATURES AND ARGUMENTS ------------------------------------------------------------------------------------ @@ -444,9 +438,6 @@ def __init__(self, module_name: str, class_object: Tuple[str, type]) -> None: # Naming of the class. self.original_class_name = "" - self.estimator_class_name = "" - self.transformer_class_name = "" - self.transform_function_name = "" # The signature and argument passing the __init__ functions.
self.original_init_signature = inspect.Signature() @@ -456,18 +447,21 @@ def __init__(self, module_name: str, class_object: Tuple[str, type]) -> None: self.sklearn_init_args_dict = "" self.estimator_init_member_args = "" + # Doc strings self.original_class_docstring = "" self.estimator_class_docstring = "" self.transformer_class_docstring = "" - - self.estimator_imports = "" - self.estimator_imports_list: List[str] = [] - self.original_fit_docstring = "" self.fit_docstring = "" self.original_transform_docstring = "" self.transform_docstring = "" + # Import strings + self.estimator_imports = "" + self.estimator_imports_list: List[str] = [] + self.additional_import_statements = "" + + # Test strings self.test_dataset_func = "" self.test_estimator_input_args = "" self.test_estimator_input_args_list: List[str] = [] @@ -475,14 +469,10 @@ def __init__(self, module_name: str, class_object: Tuple[str, type]) -> None: self.test_estimator_imports = "" self.test_estimator_imports_list: List[str] = [] - self.additional_import_statements = "" - + # Dependencies self.predict_udf_deps = "" self.fit_sproc_deps = "" - # TODO(amauser): Make fit a no-op if there is no internal state - # TODO(amauser): handling sparse input and output (LabelBinarizer) - def _format_default_value(self, default_value: Any) -> str: if isinstance(default_value, str): return f'"{default_value}"' @@ -561,26 +551,13 @@ def split_long_lines(line: str) -> str: self.estimator_class_docstring = class_docstring def _populate_class_names(self) -> None: - # TODO(snandamuri): All the 3 fields have exact same value. Do we really need these - # 3 separate fields? self.original_class_name = self.class_object[0] - self.estimator_class_name = self.original_class_name - self.transformer_class_name = self.estimator_class_name - self.test_class_name = f"{self.original_class_name}Test" def _populate_function_names_and_signatures(self) -> None: for member in inspect.getmembers(self.class_object[1]): if member[0] == "__init__": self.original_init_signature = inspect.signature(member[1]) - elif member[0] == "predict" or member[0] == "transform": - if self.transform_function_name != "": - print("ERROR: Class has both transform() and predict() methods.") - # TODO(snandamuri): Add support for both transform() and predict() methods in estimators. - # For now, resolve to predict() method when both predict() and transform() are available. - self.transform_function_name = "predict" - else: - self.transform_function_name = member[0] signature_lines = [] sklearn_init_lines = [] @@ -642,6 +619,7 @@ def _populate_function_names_and_signatures(self) -> None: self.estimator_init_member_args = "\n ".join(init_member_args) self.estimator_args_transform_calls = "\n ".join(arg_transform_calls) + # TODO(snandamuri): Implement type inference for classifiers. 
self.udf_datatype = "float" if self._from_data_py or self._is_regressor else "" def _populate_file_paths(self) -> None: @@ -825,7 +803,7 @@ def generate(self) -> "SklearnWrapperGenerator": self.test_estimator_input_args_list.extend(["min_samples_leaf=1", "max_leaf_nodes=100"]) self.fit_sproc_deps = self.predict_udf_deps = ( - "f'numpy=={np.__version__}', f'pandas=={pd.__version__}', f'scikit-learn=={sklearn.__version__}'," + "f'numpy=={np.__version__}', f'pandas=={pd.__version__}', f'scikit-learn=={sklearn.__version__}', " "f'xgboost=={xgboost.__version__}', f'joblib=={joblib.__version__}'" ) self._construct_string_from_lists() @@ -842,7 +820,7 @@ def generate(self) -> "XGBoostWrapperGenerator": self.test_estimator_input_args_list.extend(["random_state=0", "subsample=1.0", "colsample_bynode=1.0"]) self.fit_sproc_imports = "import xgboost" self.fit_sproc_deps = self.predict_udf_deps = ( - "f'numpy=={np.__version__}', f'pandas=={pd.__version__}', f'xgboost=={xgboost.__version__}'," + "f'numpy=={np.__version__}', f'pandas=={pd.__version__}', f'xgboost=={xgboost.__version__}', " "f'joblib=={joblib.__version__}'" ) self._construct_string_from_lists() @@ -859,7 +837,7 @@ def generate(self) -> "LightGBMWrapperGenerator": self.test_estimator_input_args_list.extend(["random_state=0"]) self.fit_sproc_imports = "import lightgbm" self.fit_sproc_deps = self.predict_udf_deps = ( - "f'numpy=={np.__version__}', f'pandas=={pd.__version__}', f'lightgbm=={lightgbm.__version__}'," + "f'numpy=={np.__version__}', f'pandas=={pd.__version__}', f'lightgbm=={lightgbm.__version__}', " "f'joblib=={joblib.__version__}'" ) self._construct_string_from_lists() diff --git a/codegen/sklearn_wrapper_template.py_template b/codegen/sklearn_wrapper_template.py_template index 46f1225e..f883b9cd 100644 --- a/codegen/sklearn_wrapper_template.py_template +++ b/codegen/sklearn_wrapper_template.py_template @@ -7,7 +7,6 @@ from typing import Iterable, Optional, Union, List, Any, Dict, Callable from uuid import uuid4 import joblib -import json import pandas as pd import numpy as np {transform.estimator_imports} @@ -20,7 +19,6 @@ from snowflake.ml._internal.utils import pkg_version_utils, identifier from snowflake.ml._internal.utils.temp_file_utils import cleanup_temp_files, get_temp_file_path from snowflake.snowpark import DataFrame, Session from snowflake.snowpark.functions import pandas_udf, sproc -from snowflake.snowpark.session import _get_active_session from snowflake.snowpark.types import PandasSeries _PROJECT = "ModelDevelopment" @@ -87,12 +85,12 @@ def _validate_sklearn_args(args: Dict[str, Any], klass: type) -> Dict[str, Any]: result = {{}} signature = inspect.signature(klass.__init__) for k, v in args.items(): - if k not in signature.parameters.keys(): # Arg is not supported. + if k not in signature.parameters.keys(): # Arg is not supported. if ( - v[2] # Arg doesn't have default value in the signature. + v[2] # Arg doesn't have default value in the signature. or ( v[0] != v[1] # Value is not same as default. 
- and not (isinstance(v[0], float) and np.isnan(v[0]) and np.isnan(v[1]))) # both are not NANs + and not (isinstance(v[0], float) and np.isnan(v[0]) and np.isnan(v[1]))) # both are not NANs ): raise RuntimeError(f"Arg {{k}} is not supported by current version of SKLearn/XGBoost.") else: @@ -100,14 +98,14 @@ def _validate_sklearn_args(args: Dict[str, Any], klass: type) -> Dict[str, Any]: return result -class {transform.estimator_class_name}(BaseEstimator, BaseTransformer): +class {transform.original_class_name}(BaseEstimator, BaseTransformer): r"""{transform.estimator_class_docstring} """ def __init__( {transform.estimator_init_signature} ) -> None: - super().__init__(custom_states=None) + super().__init__() self.id = str(uuid4()).replace("-", "_").upper() {transform.estimator_args_transform_calls} init_args = {transform.sklearn_init_args_dict} @@ -122,21 +120,18 @@ class {transform.estimator_class_name}(BaseEstimator, BaseTransformer): def _infer_input_output_cols(self, dataset: Union[DataFrame, pd.DataFrame]) -> None: """ - Infer `self.input_cols` and `self.output_cols` if they are not set explicitly. + Infer `self.input_cols` and `self.output_cols` if they are not explicitly set. Args: dataset: Input dataset. """ if not self.input_cols: - non_input_cols = [] - if self.label_cols: - non_input_cols.extend(self.label_cols) - if self.sample_weight_col: - non_input_cols.extended(self.sample_weight_col) - - cols = [c for c in dataset.columns if c not in non_input_cols] + cols = [ + c for c in dataset.columns + if c not in self.get_label_cols() and c != self.sample_weight_col + ] self.set_input_cols(input_cols=cols) - + if not self.output_cols: cols = [identifier.concat_names(ids=['OUTPUT_', c]) for c in self.label_cols] self.set_output_cols(output_cols=cols) @@ -146,7 +141,7 @@ class {transform.estimator_class_name}(BaseEstimator, BaseTransformer): subproject=_SUBPROJECT, custom_tags=dict([("autogen", True)]), ) - def fit(self, dataset: Union[DataFrame, pd.DataFrame]) -> "{transform.transformer_class_name}": + def fit(self, dataset: Union[DataFrame, pd.DataFrame]) -> "{transform.original_class_name}": """{transform.fit_docstring} Args: @@ -176,9 +171,9 @@ class {transform.estimator_class_name}(BaseEstimator, BaseTransformer): # Specify input columns so column pruing will be enforced selected_cols = ( - self.input_cols + self.label_cols + [self.sample_weight_col] - if self.sample_weight_col is not None - else [] + self.input_cols + + self.label_cols + + [self.sample_weight_col] if self.sample_weight_col is not None else [] ) if len(selected_cols) > 0: dataset = dataset.select(selected_cols) @@ -221,6 +216,7 @@ class {transform.estimator_class_name}(BaseEstimator, BaseTransformer): api_calls=[sproc], custom_tags=dict([("autogen", True)]), ) + @sproc( is_permanent=False, name=fit_sproc_name, @@ -279,7 +275,7 @@ class {transform.estimator_class_name}(BaseEstimator, BaseTransformer): # Note: you can add something like + "|" + str(df) to the return string # to pass debug information to the caller. 
return str(os.path.basename(joblib_dump_files[0])) - + # Call fit sproc statement_params = telemetry.get_function_usage_statement_params( project=_PROJECT, @@ -369,6 +365,7 @@ class {transform.estimator_class_name}(BaseEstimator, BaseTransformer): api_calls=[pandas_udf], custom_tags=dict([("autogen", True)]), ) + @pandas_udf( is_permanent=False, name=batch_inference_udf_name, @@ -407,8 +404,8 @@ class {transform.estimator_class_name}(BaseEstimator, BaseTransformer): input_df.columns = unquoted_input_cols # Replace the quoted columns identifier with unquoted column ids. transformed_numpy_array = getattr(estimator, inference_method)(input_df) if ( - isinstance(transformed_numpy_array, list) - and len(transformed_numpy_array) > 0 + isinstance(transformed_numpy_array, list) + and len(transformed_numpy_array) > 0 and isinstance(transformed_numpy_array[0], np.ndarray) ): # In case of multioutput estimators, predict_proba(), decision_function(), etc., functions return @@ -422,7 +419,7 @@ class {transform.estimator_class_name}(BaseEstimator, BaseTransformer): transformed_numpy_array = np.hstack(transformed_numpy_array) if ( - len(transformed_numpy_array.shape) > 1 + len(transformed_numpy_array.shape) > 1 and transformed_numpy_array.shape[1] != len(expected_output_cols_list) ): # HeterogeneousEnsemble's transfrom method produce results with variying shapes @@ -446,7 +443,7 @@ class {transform.estimator_class_name}(BaseEstimator, BaseTransformer): pass_through_columns = self._get_pass_through_columns(dataset) # Run Transform query_from_df = str(dataset.queries["queries"][0]) - + outer_select_list = pass_through_columns[:] inner_select_list = pass_through_columns[:] @@ -494,22 +491,23 @@ class {transform.estimator_class_name}(BaseEstimator, BaseTransformer): dataset[self.input_cols] ) if ( - isinstance(transformed_numpy_array, list) - and len(transformed_numpy_array) > 0 - and isinstance(transformed_numpy_array[0], np.ndarray) - ): - # In case of multioutput estimators, predict_proba(), decision_function(), etc., functions return - # a list of ndarrays. We need to concatenate them. - - # First compute output column names - if(len(output_cols) == len(transformed_numpy_array)): - actual_output_cols = [] - for idx, np_arr in enumerate(transformed_numpy_array): - for i in range(1 if len(np_arr.shape) <= 1 else np_arr.shape[1]): - actual_output_cols.append(f"{{output_cols[idx]}}_{{i}}") - output_cols = actual_output_cols - # Concatenate np arrays - transformed_numpy_array = np.concatenate(transformed_numpy_array, axis=1) + isinstance(transformed_numpy_array, list) + and len(transformed_numpy_array) > 0 + and isinstance(transformed_numpy_array[0], np.ndarray) + ): + # In case of multioutput estimators, predict_proba(), decision_function(), etc., functions return + # a list of ndarrays. We need to concatenate them. 
+ + # First compute output column names + if len(output_cols) == len(transformed_numpy_array): + actual_output_cols = [] + for idx, np_arr in enumerate(transformed_numpy_array): + for i in range(1 if len(np_arr.shape) <= 1 else np_arr.shape[1]): + actual_output_cols.append(f"{{output_cols[idx]}}_{{i}}") + output_cols = actual_output_cols + + # Concatenate np arrays + transformed_numpy_array = np.concatenate(transformed_numpy_array, axis=1) if len(transformed_numpy_array.shape) == 3: # VotingClassifier will return results of shape (n_classifiers, n_samples, n_classes) @@ -523,8 +521,10 @@ class {transform.estimator_class_name}(BaseEstimator, BaseTransformer): shape = transformed_numpy_array.shape if shape[1] != len(output_cols): if len(output_cols) != 1: - raise TypeError("expected_output_cols_list must be same length as transformed array or " - "should be of length 1 or should be of length number of label columns") + raise TypeError( + "expected_output_cols_list must be same length as transformed array or " + "should be of length 1 or should be of length number of label columns" + ) actual_output_cols = [] for i in range(shape[1]): actual_output_cols.append(f"{{output_cols[0]}}_{{i}}") @@ -589,9 +589,9 @@ class {transform.estimator_class_name}(BaseEstimator, BaseTransformer): """ if isinstance(dataset, DataFrame): expected_dtype = "{transform.udf_datatype}" - if {transform._is_heterogeneous_ensemble}: # is child of _BaseHeterogeneousEnsemble - # transform() method of HeterogeneousEnsemble estimators return responses of varying - # shapes from (n_samples, n_estimators) to (n_samples, n_estimators * n_classes) (and everything in between) + if {transform._is_heterogeneous_ensemble}: # is child of _BaseHeterogeneousEnsemble + # transform() method of HeterogeneousEnsemble estimators return responses of varying shapes + # from (n_samples, n_estimators) to (n_samples, n_estimators * n_classes) (and everything in between) # based on init param values. We will convert that to pandas dataframe of shape (n_samples, 1) with # each row containing a list of values. expected_dtype = "ARRAY" @@ -764,7 +764,6 @@ class {transform.estimator_class_name}(BaseEstimator, BaseTransformer): return output_df - @available_if(_original_estimator_has_callable("score")) def score(self, dataset: Union[DataFrame, pd.DataFrame]) -> float: """{transform.fit_docstring} @@ -786,7 +785,6 @@ class {transform.estimator_class_name}(BaseEstimator, BaseTransformer): "Supported dataset types: snowpark.DataFrame, pandas.DataFrame." 
) - def _score_sklearn(self, dataset: pd.DataFrame) -> float: argspec = inspect.getfullargspec(self._sklearn_object.score) if "X" in argspec.args: @@ -806,7 +804,6 @@ class {transform.estimator_class_name}(BaseEstimator, BaseTransformer): score = self._sklearn_object.score(**args) return score - def _score_snowpark(self, dataset: DataFrame) -> float: # Specify input columns so column pruing will be enforced selected_cols = ( @@ -902,7 +899,7 @@ class {transform.estimator_class_name}(BaseEstimator, BaseTransformer): result = estimator.score(**args) return result - + # Call score sproc statement_params = telemetry.get_function_usage_statement_params( project=_PROJECT, diff --git a/codegen/transformer_autogen_test_template.py_template b/codegen/transformer_autogen_test_template.py_template index 61c27693..b751a859 100644 --- a/codegen/transformer_autogen_test_template.py_template +++ b/codegen/transformer_autogen_test_template.py_template @@ -57,13 +57,10 @@ class {transform.test_class_name}(TestCase): # Normalize column names input_df_pandas.columns = [inflection.parameterize(c, "_").upper() for c in input_df_pandas.columns] - input_cols = [c for c in input_df_pandas.columns if not c.startswith("TARGET")] - if {transform._is_single_col_input}: - input_cols = [input_cols[0]] - label_col = [c for c in input_df_pandas.columns if c.startswith("TARGET")] + if add_sample_weight_col: random.seed(0) - input_df_pandas["sample_weight"] = np.array([random.randint(0, 100) for _ in range(input_df_pandas.shape[0])]) + input_df_pandas["SAMPLE_WEIGHT"] = np.array([random.randint(0, 100) for _ in range(input_df_pandas.shape[0])]) # Predict UDF processes and returns data in random order. # Add INDEX column so that output can be sorted by that column @@ -72,13 +69,12 @@ class {transform.test_class_name}(TestCase): if {transform._is_positive_value_input}: input_df_pandas = input_df_pandas.abs() - # Normalize column names - input_df_pandas.columns = [inflection.parameterize(c, "_").upper() for c in input_df_pandas.columns] - input_cols = [ c for c in input_df_pandas.columns if not c.startswith("TARGET") and not c.startswith("SAMPLE_WEIGHT") and not c.startswith("INDEX") ] + if {transform._is_single_col_input}: + input_cols = [input_cols[0]] label_col = [c for c in input_df_pandas.columns if c.startswith("TARGET")] return (input_df_pandas, input_cols, label_col) @@ -174,12 +170,7 @@ class {transform.test_class_name}(TestCase): num_example = sklearn_numpy_arr.shape[0] assert num_diffs < 0.1 * num_example else: - if not np.allclose(actual_arr, sklearn_numpy_arr, rtol=1.e-1, atol=1.e-2): - has_diff = ~np.isclose(actual_arr, sklearn_numpy_arr, rtol=1.e-1, atol=1.e-2) - print(f"Num differences: {{has_diff.sum()}}") - print(f"Actual values: {{actual_arr.take(has_diff.nonzero())}}") - print(f"SK values: {{sklearn_numpy_arr.take(has_diff.nonzero())}}") - raise AssertionError(f"Results didn't match for {{m}}") + np.testing.assert_allclose(actual_arr, sklearn_numpy_arr, rtol=1.e-1, atol=1.e-2) if {transform._is_classifier}: expected_methods = ["predict_proba", "predict_log_proba", "decision_function"] @@ -205,34 +196,19 @@ class {transform.test_class_name}(TestCase): # ndarrays as output. We need to concatenate them to compare with snowflake output. 
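# (Illustrative aside, not part of the template: the numpy shape handling that the comparison below relies on.
#  Multioutput estimators return one (n_samples, n_classes_i) array per output; concatenating them column-wise
#  yields a single 2-D array that np.testing.assert_allclose can check with the loose rtol/atol used in this test.
#  The toy arrays here are made up for the example.)
#
#      import numpy as np
#      per_output = [np.array([[0.9, 0.1], [0.2, 0.8]]), np.array([[0.6, 0.4], [0.3, 0.7]])]
#      combined = np.concatenate(per_output, axis=1)      # shape (2, 4)
#      np.testing.assert_allclose(combined, combined + 1e-3, rtol=1.e-1, atol=1.e-2)  # passes within tolerance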
sklearn_inference_result = np.concatenate(sklearn_inference_result, axis=1) - if not np.allclose(actual_inference_result, sklearn_inference_result, rtol=1.e-1, atol=1.e-2): - has_diff = ~np.isclose(actual_inference_result, sklearn_inference_result, rtol=1.e-1, atol=1.e-2) - print(f"Num differences: {{has_diff.sum()}}") - print(f"Actual values: {{actual_inference_result.take(has_diff.nonzero())}}") - print(f"SK values: {{sklearn_inference_result.take(has_diff.nonzero())}}") - raise AssertionError(f"Results didn't match for {{m}}") + np.testing.assert_allclose( + actual_inference_result, sklearn_inference_result, rtol=1.e-1, atol=1.e-2) if callable(getattr(sklearn_reg, "score", None)) and callable(getattr(reg, "score", None)): + score_argspec = inspect.getfullargspec(sklearn_reg.score) # Some classes that has sample_weight argument in fit() but not in score(). - if use_weighted_dataset is True: - no_sample_weight_for_score = ['KernelDensity', 'RANSACRegressor'] - for c in inspect.getmro({transform.original_class_name}): - if c.__name__ in no_sample_weight_for_score: - del args['sample_weight'] - input_df_pandas = input_df_pandas.drop(['sample_weight', 'SAMPLE_WEIGHT'], axis=1, errors='ignore') + if use_weighted_dataset is True and 'sample_weight' not in score_argspec.args: + del args['sample_weight'] + input_df_pandas = input_df_pandas.drop(['sample_weight', 'SAMPLE_WEIGHT'], axis=1, errors='ignore') # Some classes have different arg name in score: X -> X_test - arg_name_is_x_test = [ - 'GraphicalLassoCV', - 'ShrunkCovariance', - 'LedoitWolf', - 'MinCovDet', - 'EmpiricalCovariance', - 'GraphicalLasso', - 'OAS'] - for c in inspect.getmro({transform.original_class_name}): - if c.__name__ in arg_name_is_x_test: - args['X_test'] = args.pop('X') + if "X_test" in score_argspec.args: + args['X_test'] = args.pop('X') if inference_with_udf: actual_score = getattr(reg, "score")(dataset=input_df) @@ -249,13 +225,7 @@ class {transform.test_class_name}(TestCase): sklearn_score = getattr(sklearn_reg, "score")(**args) - if not np.allclose(actual_score, sklearn_score, rtol=1.e-1, atol=1.e-2): - has_diff = ~np.isclose(actual_score, sklearn_score, rtol=1.e-1, atol=1.e-2) - print(f"Num differences: {{has_diff.sum()}}") - print(f"Actual values: {{actual_score.take(has_diff.nonzero())}}") - print(f"SK values: {{sklearn_score.take(has_diff.nonzero())}}") - raise AssertionError(f"Results didn't match for {{m}}") - + np.testing.assert_allclose(actual_score, sklearn_score, rtol=1.e-1, atol=1.e-2) def test_fit_with_sproc_infer_with_udf_non_weighted_datasets(self): diff --git a/snowflake/ml/preprocessing/k_bins_discretizer.py b/snowflake/ml/preprocessing/k_bins_discretizer.py index c4ef2623..f2467acb 100644 --- a/snowflake/ml/preprocessing/k_bins_discretizer.py +++ b/snowflake/ml/preprocessing/k_bins_discretizer.py @@ -16,6 +16,7 @@ from snowflake import snowpark from snowflake.ml.framework import base from snowflake.snowpark import functions as F, types as T +from snowflake.snowpark._internal import utils as snowpark_utils # constants used to validate the compatibility of the kwargs passed to the sklearn # transformer with the sklearn version @@ -46,9 +47,47 @@ def decimal_to_float(data: npt.NDArray[np.generic]) -> npt.NDArray[np.float32]: return np.array([float(x) for x in data]) -# TODO(tbao): add doc string +# TODO(tbao): suport kmeans with snowpark if needed # TODO(tbao): add telemetry class KBinsDiscretizer(base.BaseEstimator, base.BaseTransformer): + """ + Bin continuous data into intervals. 
+ + Args: + n_bins: int or array-like of shape (n_features,), default=5 + The number of bins to produce. Raises ValueError if n_bins < 2. + + encode: {'onehot', 'onehot-dense', 'ordinal'}, default='onehot' + Method used to encode the transformed result. + + - 'onehot': Encode the transformed result with one-hot encoding and return a sparse representation. + - 'onehot-dense': Encode the transformed result with one-hot encoding and return separate column for + each encoded value. + - 'ordinal': Return the bin identifier encoded as an integer value. + + strategy: {'uniform', 'quantile'}, default='quantile' + Strategy used to define the widths of the bins. + + - 'uniform': All bins in each feature have identical widths. + - 'quantile': All bins in each feature have the same number of points. + + input_cols: str or Iterable [column_name], default=None + Single or multiple input columns. + + output_cols: str or Iterable [column_name], default=None + Single or multiple output columns. + + drop_input_cols: boolean, default=False + Remove input columns from output if set True. + + Attributes: + bin_edges_: ndarray of ndarray of shape (n_features,) + The edges of each bin. Contain arrays of varying shapes (n_bins_, ) + + n_bins_: ndarray of shape (n_features,), dtype=np.int_ + Number of bins per feature. + """ + def __init__( self, *, @@ -87,6 +126,18 @@ def _reset(self) -> None: self.n_bins_: Optional[npt.NDArray[np.int32]] = None def fit(self, dataset: Union[snowpark.DataFrame, pd.DataFrame]) -> KBinsDiscretizer: + """ + Fit KBinsDiscretizer with dataset. + + Args: + dataset: Input dataset. + + Returns: + Fitted self instance. + + Raises: + TypeError: If the input dataset is neither a pandas nor Snowpark DataFrame. + """ self._reset() self._enforce_params() super()._check_input_cols() @@ -107,6 +158,21 @@ def fit(self, dataset: Union[snowpark.DataFrame, pd.DataFrame]) -> KBinsDiscreti def transform( self, dataset: Union[snowpark.DataFrame, pd.DataFrame] ) -> Union[snowpark.DataFrame, pd.DataFrame, sparse.csr_matrix]: + """ + Discretize the data. + + Args: + dataset: Input dataset. + + Returns: + Discretized output data based on input type. + - If input is snowpark DataFrame, returns snowpark DataFrame + - If input is a pd.DataFrame and 'self.encdoe=onehot', returns 'csr_matrix' + - If input is a pd.DataFrame and 'self.encode in ['ordinal', 'onehot-dense']', returns 'pd.DataFrame' + + Raises: + TypeError: If the input dataset is neither a pandas nor Snowpark DataFrame. + """ self.enforce_fit() super()._check_input_cols() super()._check_output_cols() @@ -132,6 +198,13 @@ def _fit_snowpark(self, dataset: snowpark.DataFrame) -> None: raise NotImplementedError("kmeans not supported yet") def _handle_quantile(self, dataset: snowpark.DataFrame) -> None: + """ + Compute bins with percentile values of the feature. + All bins in each feature will have the same number of points. + + Args: + dataset: Input dataset. + """ # 1. Collect percentiles for each feature column # NB: if SQL compilation ever take too long on wide schema, consider applying optimization mentioned in # https://docs.google.com/document/d/1cilfCCtKYv6HvHqaqdZxfHAvQ0gg-t1AM8KYCQtJiLE/edit @@ -156,6 +229,13 @@ def _handle_quantile(self, dataset: snowpark.DataFrame) -> None: self.n_bins_[i] = len(self.bin_edges_[i]) - 1 def _handle_uniform(self, dataset: snowpark.DataFrame) -> None: + """ + Compute bins with min and max value of the feature. + All bins in each feature will have identical widths. + + Args: + dataset: Input dataset. + """ # 1. 
Collect min and max for each feature column agg_queries = list( chain.from_iterable( @@ -221,10 +301,14 @@ def _handle_ordinal(self, dataset: snowpark.DataFrame) -> snowpark.DataFrame: Returns: Output dataset with ordinal encoding. """ + # NB: the reason we need to generate a random UDF name each time is because the UDF registration + # is centralized per database, so if there are multiple sessions with same UDF name, there might be + # a conflict and some parties could fail to fetch the UDF. + udf_name = f"vec_bucketize_{snowpark_utils.generate_random_alphanumeric()}" # 1. Register vec_bucketize UDF @F.pandas_udf( # type: ignore[arg-type, misc] - name="vec_bucketize", + name=udf_name, replace=True, packages=["numpy"], session=dataset._session, @@ -242,7 +326,7 @@ def vec_bucketize(x: T.PandasSeries[float], boarders: T.PandasSeries[List[float] dataset = dataset.select( *dataset.columns, F.call_udf( - "vec_bucketize", F.col(input_col), F.array_construct(*boarders) # type: ignore[arg-type] + f"{udf_name}", F.col(input_col), F.array_construct(*boarders) # type: ignore[arg-type] ).alias(output_col), ) return dataset @@ -258,9 +342,10 @@ def _handle_onehot(self, dataset: snowpark.DataFrame) -> snowpark.DataFrame: Returns: Output dataset in sparse representation. """ + udf_name = f"vec_bucketize_sparse_{snowpark_utils.generate_random_alphanumeric()}" @F.pandas_udf( # type: ignore[arg-type, misc] - name="vec_bucketize_sparse_output", + name=udf_name, replace=True, packages=["numpy"], session=dataset._session, @@ -282,15 +367,27 @@ def vec_bucketize_sparse_output( boarders = [F.lit(float(x)) for x in self.bin_edges_[idx]] # type: ignore[arg-type, index] dataset = dataset.select( *dataset.columns, - F.call_udf( - "vec_bucketize_sparse_output", F.col(input_col), F.array_construct(*boarders) # type: ignore - ).alias(output_col), + F.call_udf(f"{udf_name}", F.col(input_col), F.array_construct(*boarders)).alias( # type: ignore + output_col + ), ) return dataset def _handle_onehot_dense(self, dataset: snowpark.DataFrame) -> snowpark.DataFrame: + """ + Transform dataset with bucketization and output as onehot dense representation: + Each category will be reprensented in its own output column. + + Args: + dataset: Input dataset. + + Returns: + Output dataset in dense representation. 
+ """ + udf_name = f"vec_bucketize_dense_{snowpark_utils.generate_random_alphanumeric()}" + @F.pandas_udf( # type: ignore[arg-type, misc] - name="vec_bucketize_dense_output", + name=udf_name, replace=True, packages=["numpy"], session=dataset._session, @@ -313,9 +410,9 @@ def vec_bucketize_dense_output( boarders = [F.lit(float(x)) for x in self.bin_edges_[idx]] # type: ignore[arg-type, index] dataset = dataset.select( *dataset.columns, - F.call_udf( - "vec_bucketize_dense_output", F.col(input_col), F.array_construct(*boarders) # type: ignore - ).alias(output_col), + F.call_udf(f"{udf_name}", F.col(input_col), F.array_construct(*boarders)).alias( # type: ignore + output_col + ), ) dataset = dataset.with_columns( [f"{output_col}_{i}" for i in range(len(boarders) - 1)], diff --git a/snowflake/ml/registry/model_registry.py b/snowflake/ml/registry/model_registry.py index 27fb9f50..2aa7ee30 100644 --- a/snowflake/ml/registry/model_registry.py +++ b/snowflake/ml/registry/model_registry.py @@ -359,7 +359,7 @@ def _insert_metadata_entry(self, *, id: str, attribute: str, value: Any) -> List return self._insert_table_entry(table=self._fully_qualified_metadata_table_name(), columns=columns) - def _prepare_model_stage(self, *, id: str) -> str: + def _prepare_model_stage(self, *, model_name: str, model_version: str) -> str: """Create a stage in the model registry for storing the model with the given id. Creating a permanent stage here since we do not have a way to swtich a stage from temporary to permanent. @@ -368,7 +368,8 @@ def _prepare_model_stage(self, *, id: str) -> str: operation is complete. Args: - id: Identifier string of the model intended to be stored in the stage. + model_name: Model Name string. + model_version: Model Version string. Returns: Name of the stage that was created. @@ -378,8 +379,10 @@ def _prepare_model_stage(self, *, id: str) -> str: """ schema = self._fully_qualified_schema_name() + stage_name = f"{model_name}_{model_version}".replace("-", "_").upper() + # Replacing dashes and uppercasing the model_stage_name to avoid having to quote the the stage name. - model_stage_name = "SNOWML_MODEL_{safe_id}".format(safe_id=id.replace("-", "_").upper()) + model_stage_name = f"SNOWML_MODEL_{stage_name}" fully_qualified_model_stage_name = f"{schema}.{model_stage_name}" statement_params = self._get_statement_params(inspect.currentframe()) @@ -405,7 +408,7 @@ def _prepare_model_stage(self, *, id: str) -> str: def _list_selected_models( self, *, id: Optional[str] = None, model_name: Optional[str] = None, model_version: Optional[str] = None - ) -> Any: + ) -> snowpark.DataFrame: """Retrieve the Snowpark dataframe of models matching the specified ID or (name and version). Args: @@ -433,7 +436,43 @@ def _list_selected_models( snowpark.Column("VERSION") == model_version ) - return filtered_models + return cast(snowpark.DataFrame, filtered_models) + + def _validate_exact_one_result( + self, selected_model: snowpark.DataFrame, model_identifier: str + ) -> List[snowpark.Row]: + """Validate the filtered model has exactly one result. + + Args: + selected_model: A snowpark dataframe representing the models that are filtered out. + model_identifier: A string which is used to filter the model. + + Returns: + A snowpark row which contains the metadata of the filtered model + + Raises: + KeyError: The target model doesn't exist. + DataError: The target model is not unique. 
+ """ + statement_params = self._get_statement_params(inspect.currentframe()) + model_info = None + try: + model_info = ( + query_result_checker.ResultValidator(result=selected_model.collect(statement_params=statement_params)) + .has_dimensions(expected_rows=1) + .validate() + ) + except connector.DataError: + if model_info is None or len(model_info) == 0: + raise KeyError(f"The model {model_identifier} does not exist in the current registry.") + else: + raise connector.DataError( + formatting.unwrap( + f"""There are {len(model_info)} models {model_identifier}. This might indicate a problem with + the integrity of the model registry data.""" + ) + ) + return model_info def _get_metadata_attribute( self, @@ -442,7 +481,7 @@ def _get_metadata_attribute( model_name: Optional[str] = None, model_version: Optional[str] = None, ) -> Any: - """Get the value of the given metadata attribute for target model with (model name + model version) or id. + """Get the value of the given metadata attribute for target model with given (model name + model version) or id. Args: attribute: Name of the attribute to get. @@ -452,21 +491,11 @@ def _get_metadata_attribute( Returns: The value of the attribute that was requested. Can be None if the attribute is not set. - - Raises: - DataError: The given model identifier points to more than one models. """ - statement_params = self._get_statement_params(inspect.currentframe()) selected_models = self._list_selected_models(id=id, model_name=model_name, model_version=model_version) - result = selected_models.select(attribute).collect(statement_params=statement_params) - - if len(result) > 1: - identifier = f"id {id}" if id else f"{model_name}/{model_version}" - raise connector.DataError(f"Model {identifier} existed {len(result)} times. It should only exist once.") - elif len(result) == 1 and attribute in result[0]: - return result[0][attribute] - else: - return None + identifier = f"id {id}" if id else f"{model_name}/{model_version}" + model_info = self._validate_exact_one_result(selected_models, identifier) + return model_info[0][attribute] def _set_metadata_attribute( self, @@ -477,7 +506,7 @@ def _set_metadata_attribute( model_version: Optional[str] = None, enable_model_presence_check: bool = True, ) -> None: - """Set the value of the given metadata attribute for model id. + """Set the value of the given metadata attribute for targat model with given (model name + model version) or id. Args: attribute: Name of the attribute to set. @@ -489,23 +518,20 @@ def _set_metadata_attribute( before setting the metadata attribute. False by default meaning that by default we will check. Raises: - DataError: The requested model id could not be found or is ambiguous. + DataError: Failed to set the meatdata attribute. + KeyError: The target model doesn't exist """ - statement_params = self._get_statement_params(inspect.currentframe()) selected_models = self._list_selected_models(id=id, model_name=model_name, model_version=model_version) - number_of_entries_filtered = selected_models.count(statement_params=statement_params) - identifier = f"id {id}" if id else f"{model_name}/{model_version}" - if enable_model_presence_check and number_of_entries_filtered == 0: - raise connector.DataError(f"Model {identifier} was not found in the registry.") - elif number_of_entries_filtered > 1: - raise connector.DataError( - f"Model {identifier} existed {number_of_entries_filtered} times. It should only exist once." 
- ) + try: + model_info = self._validate_exact_one_result(selected_models, identifier) + except KeyError as e: + # If the target model doesn't exist, raise the error only if enable_model_presence_check is True. + if enable_model_presence_check: + raise e if not id: - res = selected_models.select("ID").collect(statement_params=statement_params) - id = res[0]["ID"] + id = model_info[0]["ID"] assert id is not None try: @@ -569,7 +595,7 @@ def list_models(self) -> snowpark.DataFrame: project=_TELEMETRY_PROJECT, subproject=_TELEMETRY_SUBPROJECT, ) - def _get_model_id(self, *, model_name: Optional[str], model_version: Optional[str]) -> str: + def _get_model_id(self, *, model_name: str, model_version: str) -> str: """Get ID of the model with the given (model name + model version). Args: @@ -594,58 +620,53 @@ def _get_model_id(self, *, model_name: Optional[str], model_version: Optional[st def set_tag( self, name: str, - value: str, - id: Optional[str] = None, - model_name: Optional[str] = None, - model_version: Optional[str] = None, + model_name: str, + model_version: str, + value: Optional[str] = None, ) -> None: - """Set model tag to with value. + """Set model tag to the model with value. If the model tag already exists, the tag value will be overwritten. Args: name: Desired tag name. + model_name: Model Name string. + model_version: Model Version string. value: (optional) New tag value. If no value is given the value of the tag will be set to None. - id: Model ID string. Required if either model name or model version is None. - model_name: Model Name string. Required if id is None. - model_version: Model Version string. Required if version is None. """ # This method uses a read-modify-write pattern for setting tags. # TODO(amauser): Investigate the use of transactions to avoid race conditions. - model_tags = self.get_tags(id=id, model_name=model_name, model_version=model_version) + model_tags = self.get_tags(model_name=model_name, model_version=model_version) model_tags[name] = value self._set_metadata_attribute( - _METADATA_ATTRIBUTE_TAGS, model_tags, id=id, model_name=model_name, model_version=model_version + _METADATA_ATTRIBUTE_TAGS, model_tags, model_name=model_name, model_version=model_version ) @telemetry.send_api_usage_telemetry( project=_TELEMETRY_PROJECT, subproject=_TELEMETRY_SUBPROJECT, ) - def remove_tag( - self, name: str, id: Optional[str] = None, model_name: Optional[str] = None, model_version: Optional[str] = None - ) -> None: + def remove_tag(self, name: str, model_name: str, model_version: str) -> None: """Remove target model tag. Args: name: Desired tag name. - id: Model ID string. Required if either model name or model version is None. - model_name: Model Name string. Required if id is None. - model_version: Model Version string. Required if version is None. + model_name: Model Name string. + model_version: Model Version string. Raises: DataError: If the model does not have the requested tag. """ # This method uses a read-modify-write pattern for updating tags. - model_tags = self.get_tags(id=id, model_name=model_name, model_version=model_version) + model_tags = self.get_tags(model_name=model_name, model_version=model_version) try: del model_tags[name] except KeyError: raise connector.DataError(f"Model id {id} has not tag named {name}. 
Full list of tags: {model_tags}") self._set_metadata_attribute( - _METADATA_ATTRIBUTE_TAGS, model_tags, id=id, model_name=model_name, model_version=model_version + _METADATA_ATTRIBUTE_TAGS, model_tags, model_name=model_name, model_version=model_version ) @telemetry.send_api_usage_telemetry( @@ -655,10 +676,9 @@ def remove_tag( def has_tag( self, name: str, + model_name: str, + model_version: str, value: Optional[str] = None, - id: Optional[str] = None, - model_name: Optional[str] = None, - model_version: Optional[str] = None, ) -> bool: """Check if a model has a tag with the given name and value. @@ -666,53 +686,48 @@ def has_tag( Args: name: Desired tag name. - value: (optional) Tag value to check. If not value is given, only the presence of the tag will be - checked. - id: Model ID string. Required if either model name or model version is None. - model_name: Model Name string. Required if id is None. - model_version: Model Version string. Required if version is None. + model_name: Model Name string. + model_version: Model Version string. + value: (optional) Tag value to check. If not value is given, only the presence of the tag will be checked. Returns: True if the tag or tag and value combination is present for the model with the given id, False otherwise. """ - tags = self.get_tags(id=id, model_name=model_name, model_version=model_version) - return name in tags and tags[name] == str(value) + tags = self.get_tags(model_name=model_name, model_version=model_version) + has_tag = name in tags + if value is None: + return has_tag + return has_tag and tags[name] == str(value) @telemetry.send_api_usage_telemetry( project=_TELEMETRY_PROJECT, subproject=_TELEMETRY_SUBPROJECT, ) - def get_tag_value( - self, name: str, id: Optional[str] = None, model_name: Optional[str] = None, model_version: Optional[str] = None - ) -> Optional[str]: + def get_tag_value(self, name: str, model_name: str, model_version: str) -> Any: """Return the value of the tag for the model. The returned value can be None. If the tag does not exist, KeyError will be raised. Args: name: Desired tag name. - id: Model ID string. Required if either model name or model version is None. - model_name: Model Name string. Required if id is None. - model_version: Model Version string. Required if version is None. + model_name: Model Name string. + model_version: Model Version string. Returns: Value string of the tag or None, if no value is set for the tag. """ - return self.get_tags(id=id, model_name=model_name, model_version=model_version)[name] + return self.get_tags(model_name=model_name, model_version=model_version)[name] @telemetry.send_api_usage_telemetry( project=_TELEMETRY_PROJECT, subproject=_TELEMETRY_SUBPROJECT, ) - def get_tags( - self, id: Optional[str] = None, model_name: Optional[str] = None, model_version: Optional[str] = None - ) -> Dict[str, str]: - """Get all tags and values stored for the given (model name + model version) or model id. + def get_tags(self, model_name: str = None, model_version: str = None) -> Dict[str, Any]: + """Get all tags and values stored for the target model. Args: - id: Model ID string. Required if either model name or model version is None. - model_name: Model Name string. Required if id is None. - model_version: Model Version string. Required if version is None. + model_name: Model Name string. + model_version: Model Version string. Returns: String-to-string dictionary containing all tags and values. The resulting dictionary can be empty. 
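A quick usage sketch of the tag helpers above, now that a model is addressed by (model_name, model_version) rather than by id. The method signatures mirror the definitions in this file; the ModelRegistry constructor arguments and the pre-existing Snowpark `session` are assumptions made only for this example.

from snowflake.ml.registry import model_registry

registry = model_registry.ModelRegistry(session=session)  # assumes an existing snowflake.snowpark.Session
registry.set_tag("stage", model_name="MY_MODEL", model_version="1.0.0", value="production")
assert registry.has_tag("stage", model_name="MY_MODEL", model_version="1.0.0", value="production")
print(registry.get_tags(model_name="MY_MODEL", model_version="1.0.0"))  # {'stage': 'production'}
registry.remove_tag("stage", model_name="MY_MODEL", model_version="1.0.0")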
@@ -720,11 +735,11 @@ def get_tags( # Snowpark snowpark.dataframes returns dictionary objects as strings. We need to convert it back to a dictionary # here. result = self._get_metadata_attribute( - _METADATA_ATTRIBUTE_TAGS, id=id, model_name=model_name, model_version=model_version + _METADATA_ATTRIBUTE_TAGS, model_name=model_name, model_version=model_version ) if result: - ret: Dict[str, str] = json.loads(result) + ret: Dict[str, Optional[str]] = json.loads(result) return ret else: return dict() @@ -733,21 +748,18 @@ def get_tags( project=_TELEMETRY_PROJECT, subproject=_TELEMETRY_SUBPROJECT, ) - def get_model_description( - self, id: Optional[str] = None, model_name: Optional[str] = None, model_version: Optional[str] = None - ) -> Optional[str]: - """Get the description of the model with the given (model name + model version) or id. + def get_model_description(self, model_name: str, model_version: str) -> Optional[str]: + """Get the description of the model. Args: - id: Model ID string. Required if either model name or model version is None. - model_name: Model Name string. Required if id is None. - model_version: Model Version string. Required if version is None. + model_name: Model Name string. + model_version: Model Version string. Returns: Descrption of the model or None. """ result = self._get_metadata_attribute( - _METADATA_ATTRIBUTE_DESCRIPTION, id=id, model_name=model_name, model_version=model_version + _METADATA_ATTRIBUTE_DESCRIPTION, model_name=model_name, model_version=model_version ) return None if result is None else str(result) @@ -759,20 +771,18 @@ def set_model_description( self, *, description: str, - id: Optional[str] = None, - model_name: Optional[str] = None, - model_version: Optional[str] = None, + model_name: str, + model_version: str, ) -> None: - """Set the description of the model with the given id. + """Set the description of the model. Args: description: Desired new model description. - id: Model ID string. Required if either model name or model version is None. - model_name: Model Name string. Required if id is None. - model_version: Model Version string. Required if version is None. + model_name: Model Name string. + model_version: Model Version string. """ self._set_metadata_attribute( - _METADATA_ATTRIBUTE_DESCRIPTION, description, id=id, model_name=model_name, model_version=model_version + _METADATA_ATTRIBUTE_DESCRIPTION, description, model_name=model_name, model_version=model_version ) @telemetry.send_api_usage_telemetry( @@ -808,24 +818,21 @@ def get_history(self) -> snowpark.DataFrame: ) def get_model_history( self, - id: Optional[str] = None, - model_name: Optional[str] = None, - model_version: Optional[str] = None, + model_name: str, + model_version: str, ) -> snowpark.DataFrame: """Return a dataframe with the history of operations performed on the desired model. The returned dataframe is order by time and can be filtered further. Args: - id: Id of the model to retrieve the history for. Required if either model name or model version is None. - model_name: Model Name string. Required if id is None. - model_version: Model Version string. Required if version is None. + model_name: Model Name string. + model_version: Model Version string. Returns: snowpark.DataFrame with the history of the model. 
""" - if not id: - id = self._get_model_id(model_name=model_name, model_version=model_version) + id = self._get_model_id(model_name=model_name, model_version=model_version) return cast(snowpark.DataFrame, self.get_history().filter(snowpark.Column("MODEL_ID") == id)) @telemetry.send_api_usage_telemetry( @@ -836,9 +843,8 @@ def set_metric( self, name: str, value: object, - id: Optional[str] = None, - model_name: Optional[str] = None, - model_version: Optional[str] = None, + model_name: str, + model_version: str, ) -> None: """Set scalar model metric to value. @@ -847,16 +853,15 @@ def set_metric( Args: name: Desired metric name. value: New metric value. - id: Model ID string. Required if either model name or model version is None. - model_name: Model Name string. Required if id is None. - model_version: Model Version string. Required if version is None. + model_name: Model Name string. + model_version: Model Version string. """ # This method uses a read-modify-write pattern for setting tags. # TODO(amauser): Investigate the use of transactions to avoid race conditions. - model_metrics = self.get_metrics(id=id, model_name=model_name, model_version=model_version) + model_metrics = self.get_metrics(model_name=model_name, model_version=model_version) model_metrics[name] = value self._set_metadata_attribute( - _METADATA_ATTRIBUTE_METRICS, model_metrics, id=id, model_name=model_name, model_version=model_version + _METADATA_ATTRIBUTE_METRICS, model_metrics, model_name=model_name, model_version=model_version ) @telemetry.send_api_usage_telemetry( @@ -866,91 +871,80 @@ def set_metric( def remove_metric( self, name: str, - id: Optional[str] = None, - model_name: Optional[str] = None, - model_version: Optional[str] = None, + model_name: str, + model_version: str, ) -> None: """Remove a specific metric entry from the model. Args: name: Desired tag name. - id: Model ID string. Required if either model name or model version is None. - model_name: Model Name string. Required if id is None. - model_version: Model Version string. Required if version is None. + model_name: Model Name string. + model_version: Model Version string. Raises: DataError: If the model does not have the requested metric. """ # This method uses a read-modify-write pattern for updating tags. - model_metrics = self.get_metrics(id=id, model_name=model_name, model_version=model_version) + model_metrics = self.get_metrics(model_name=model_name, model_version=model_version) try: del model_metrics[name] except KeyError: raise connector.DataError( - f"Model id {id} has no metric named {name}. Full list of metrics: {model_metrics}" + f"Model {model_name}/{model_version} has no metric named {name}. Full list of metrics: {model_metrics}" ) self._set_metadata_attribute( - _METADATA_ATTRIBUTE_METRICS, model_metrics, id=id, model_name=model_name, model_version=model_version + _METADATA_ATTRIBUTE_METRICS, model_metrics, model_name=model_name, model_version=model_version ) @telemetry.send_api_usage_telemetry( project=_TELEMETRY_PROJECT, subproject=_TELEMETRY_SUBPROJECT, ) - def has_metric( - self, name: str, id: Optional[str] = None, model_name: Optional[str] = None, model_version: Optional[str] = None - ) -> bool: + def has_metric(self, name: str, model_name: str, model_version: str) -> bool: """Check if a model has a metric with the given name. Args: name: Desired metric name. - id: Model ID string. Required if either model name or model version is None. - model_name: Model Name string. Required if id is None. - model_version: Model Version string. 
Required if version is None. + model_name: Model Name string. + model_version: Model Version string. Returns: True if the metric is present for the model with the given id, False otherwise. """ - metrics = self.get_metrics(id=id, model_name=model_name, model_version=model_version) + metrics = self.get_metrics(model_name=model_name, model_version=model_version) return name in metrics @telemetry.send_api_usage_telemetry( project=_TELEMETRY_PROJECT, subproject=_TELEMETRY_SUBPROJECT, ) - def get_metric_value( - self, name: str, id: Optional[str] = None, model_name: Optional[str] = None, model_version: Optional[str] = None - ) -> Optional[object]: + def get_metric_value(self, name: str, model_name: str, model_version: str) -> Optional[object]: """Return the value of the given metric for the model. The returned value can be None. If the metric does not exist, KeyError will be raised. Args: name: Desired tag name. - id: Model ID string. Required if either model name or model version is None. - model_name: Model Name string. Required if id is None. - model_version: Model Version string. Required if version is None. + model_name: Model Name string. + model_version: Model Version string. Returns: Value of the metric. Can be None if the metric was set to None. """ - return self.get_metrics(id=id, model_name=model_name, model_version=model_version)[name] + return self.get_metrics(model_name=model_name, model_version=model_version)[name] @telemetry.send_api_usage_telemetry( project=_TELEMETRY_PROJECT, subproject=_TELEMETRY_SUBPROJECT, ) - def get_metrics( - self, id: Optional[str] = None, model_name: Optional[str] = None, model_version: Optional[str] = None - ) -> Dict[str, object]: - """Get all metrics and values stored for the given (model name + model version) or model id. + def get_metrics(self, model_name: str, model_version: str) -> Dict[str, object]: + """Get all metrics and values stored for the given model. Args: - id: Model ID string. Required if either model name or model version is None. - model_name: Model Name string. Required if id is None. - model_version: Model Version string. Required if version is None. + model_name: Model Name string. + model_version: Model Version string. Returns: String-to-float dictionary containing all metrics and values. The resulting dictionary can be empty. @@ -958,7 +952,7 @@ def get_metrics( # Snowpark snowpark.dataframes returns dictionary objects as strings. We need to convert it back to a dictionary # here. result = self._get_metadata_attribute( - _METADATA_ATTRIBUTE_METRICS, id=id, model_name=model_name, model_version=model_version + _METADATA_ATTRIBUTE_METRICS, model_name=model_name, model_version=model_version ) if result: @@ -1072,7 +1066,6 @@ def log_model( def register_model( self, *, - id: str, type: str, uri: str, name: str, @@ -1081,15 +1074,13 @@ def register_model( output_spec: Optional[Dict[str, str]] = None, description: Optional[str] = None, tags: Optional[Dict[str, str]] = None, - ) -> bool: + ) -> str: """Register a new model in the ModelRegistry. This operation will only create the metadata and not handle any model artifacts. A URI is expected to be given that points the the actual model artifact. Args: - id: Unique identifier to be used for the model. This is required to be unique within the registry and - uniqueness will be verified. The model id is immutable once set. type: Type of the model. Only a subset of types are supported natively. uri: Resource identifier pointing to the model artifact. 
There are no restrictions on the URI format, however only a limited set of URI schemes is supported natively. @@ -1108,15 +1099,19 @@ def register_model( after model registration. Returns: - True if the operation was successful. + The model id string, which is unique identifier to be used for the model. None will be returned if the + operation failed. Raises: DataError: The given model already exists. + DatabaseError: Unable to register the model properties into table. """ # TODO(Zhe SNOW-813224): Remove input_spec and output_spec. Use signature instead. # Create registry entry. + id = self._get_new_unique_identifier() + new_model: Dict[Any, Any] = {} new_model["ID"] = id new_model["NAME"] = name @@ -1129,19 +1124,23 @@ def register_model( new_model["CREATION_ENVIRONMENT_SPEC"] = {"python": ".".join(map(str, sys.version_info[:3]))} new_model["URI"] = uri - existing_model_nums = self._list_selected_models(id=id, model_name=name, model_version=version).count() + existing_model_nums = self._list_selected_models(model_name=name, model_version=version).count() if existing_model_nums: raise connector.DataError(f"Model {name}/{version} already exists. Unable to register the model.") if self._insert_registry_entry(id=id, name=name, version=version, properties=new_model): - self._set_metadata_attribute(id=id, attribute=_METADATA_ATTRIBUTE_REGISTRATION, value=new_model) + self._set_metadata_attribute( + model_name=name, model_version=version, attribute=_METADATA_ATTRIBUTE_REGISTRATION, value=new_model + ) if description: - self.set_model_description(id=id, description=description) + self.set_model_description(model_name=name, model_version=version, description=description) if tags: - self._set_metadata_attribute(id=id, attribute=_METADATA_ATTRIBUTE_TAGS, value=tags) - return True + self._set_metadata_attribute( + _METADATA_ATTRIBUTE_TAGS, value=tags, model_name=name, model_version=version + ) + return id else: - return False + raise connector.DatabaseError("Failed to insert the model properties to the registry table.") @telemetry.send_api_usage_telemetry( project=_TELEMETRY_PROJECT, @@ -1175,10 +1174,9 @@ def log_model_path( Returns: String of the auto-generate unique model identifier. """ - id = self._get_new_unique_identifier() # Copy model from local disk to remote stage. - fully_qualified_model_stage_name = self._prepare_model_stage(id=id) + fully_qualified_model_stage_name = self._prepare_model_stage(model_name=name, model_version=version) # Check if directory or file and adapt accordingly. # TODO: Unify and explicit about compression for both file and directory. @@ -1196,8 +1194,7 @@ def log_model_path( overwrite=True, is_in_udf=True, ) - self.register_model( - id=id, + id = self.register_model( type=type, uri=uri.get_uri_from_snowflake_stage_path(fully_qualified_model_stage_name), name=name if name else fully_qualified_model_stage_name, @@ -1208,6 +1205,14 @@ def log_model_path( return id + def _get_fully_qualified_stage_name_from_uri(self, model_uri: str) -> Optional[str]: + raw_stage_name = uri.get_snowflake_stage_path_from_uri(model_uri) + if not raw_stage_name: + return None + model_stage_name = raw_stage_name.split(".")[-1] + qualified_stage_path = f"{self._fully_qualified_schema_name()}.{model_stage_name}" + return qualified_stage_path + def _get_model_path( self, id: Optional[str] = None, model_name: Optional[str] = None, model_version: Optional[str] = None ) -> str: @@ -1218,37 +1223,30 @@ def _get_model_path( model_name: Model Name string. Required if id is None. 
model_version: Model Version string. Required if id is None. + Returns: + str: Stage path for the model. + Raises: DataError: When the model cannot be found or not be restored. NotImplementedError: For models that span multiple files. - - Returns: - str: Stage path for the model. """ statement_params = self._get_statement_params(inspect.currentframe()) selected_models = self._list_selected_models(id=id, model_name=model_name, model_version=model_version) - model_uri_result = selected_models.select("ID", "URI").collect(statement_params=statement_params) - - table_name = self._fully_qualified_registry_table_name() - if len(model_uri_result) == 0: - raise connector.DataError(f"Model with id {id} not found in ModelRegistry {table_name}.") - - if len(model_uri_result) > 1: - raise connector.DataError( - f"Model with id {id} exist multiple ({len(model_uri_result)}) times in ModelRegistry " "{table_name}." - ) - - model_uri = model_uri_result[0].URI + identifier = f"id {id}" if id else f"{model_name}/{model_version}" + model_info = self._validate_exact_one_result(selected_models, identifier) + if not id: + id = model_info[0]["ID"] + model_uri = model_info[0]["URI"] if not uri.is_snowflake_stage_uri(model_uri): raise connector.DataError( f"Artifacts with URI scheme {uri.get_uri_scheme(model_uri)} are currently not supported." ) - model_stage_name = uri.get_snowflake_stage_path_from_uri(model_uri) + model_stage_path = self._get_fully_qualified_stage_name_from_uri(model_uri=model_uri) # Currently we assume only the model is on the stage. - model_file_list = self._session.sql(f"LIST @{model_stage_name}").collect(statement_params=statement_params) + model_file_list = self._session.sql(f"LIST @{model_stage_path}").collect(statement_params=statement_params) if len(model_file_list) == 0: raise connector.DataError(f"No files in model artifact for id {id} located at {model_uri}.") if len(model_file_list) > 1: @@ -1259,20 +1257,17 @@ def _get_model_path( project=_TELEMETRY_PROJECT, subproject=_TELEMETRY_SUBPROJECT, ) - def load_model( - self, *, id: Optional[str] = None, model_name: Optional[str] = None, model_version: Optional[str] = None - ) -> Any: - """Loads the model with the given (model_name + model_version) or `id` from the registry into memory. + def load_model(self, *, model_name: str, model_version: str) -> Any: + """Loads the model with the given (model_name + model_version) from the registry into memory. Args: - id: Model identifier. Required if either model name or model version is None. - model_name: Model Name string. Required if id is None. - model_version: Model Version string. Required if version is None. + model_name: Model Name string. + model_version: Model Version string. Returns: Restored model object. """ - remote_model_path = self._get_model_path(id=id, model_name=model_name, model_version=model_version) + remote_model_path = self._get_model_path(model_name=model_name, model_version=model_version) restored_model = None with tempfile.TemporaryDirectory() as local_model_directory: self._session.file.get(remote_model_path, local_model_directory) @@ -1306,9 +1301,8 @@ def deploy( *, deployment_name: str, target_method: str, - id: Optional[str] = None, - model_name: Optional[str] = None, - model_version: Optional[str] = None, + model_name: str, + model_version: str, options: Optional[model_types.DeployOptions] = None, ) -> None: """Deploy the model with the the given deployment name. @@ -1316,9 +1310,8 @@ def deploy( Args: deployment_name: name of the generated UDF. 
target_method: The method name to use in deployment. - id: Id of the model to deploy. Required if either model name or model version is None. - model_name: Model Name string. Required if id is None. - model_version: Model Version string. Required if version is None. + model_name: Model Name string. + model_version: Model Version string. options: Optional options for model deployment. Defaults to None. Raises: @@ -1327,7 +1320,7 @@ def deploy( if options is None: options = {} - remote_model_path = self._get_model_path(id=id, model_name=model_name, model_version=model_version) + remote_model_path = self._get_model_path(model_name=model_name, model_version=model_version) with tempfile.TemporaryDirectory() as local_model_directory: self._session.file.get(remote_model_path, local_model_directory) is_native_model_format = False @@ -1377,51 +1370,29 @@ def predict(self, deployment_name: str, data: Any) -> "pd.DataFrame": def delete_model( self, + model_name: str, + model_version: str, delete_artifact: bool = True, - id: Optional[str] = None, - model_name: Optional[str] = None, - model_version: Optional[str] = None, ) -> None: """Delete model with the given ID from the registry. The history of the model will still be preserved. Args: + model_name: Model Name string. + model_version: Model Version string. delete_artifact: If True, the underlying model artifact will also be deleted, not just the entry in the registry table. - id: Id of the model to delete. Required if either model name or model version is None. - model_name: Model Name string. Required if id is None. - model_version: Model Version string. Required if version is None. - - - Raises: - KeyError: Model with the given ID does not exist in the registry. """ # Check that a model with the given ID exists and there is only one of them. # TODO(amauser): The following sequence should be a transaction. Transactions currently cannot contain DDL # statements. model_info = None - try: - selected_models = self._list_selected_models(id=id, model_name=model_name, model_version=model_version) - model_info = ( - query_result_checker.ResultValidator(result=selected_models.collect()) - .has_dimensions(expected_rows=1) - .validate() - ) - except connector.DataError: - identifier = f"with id {id}" if id else f"named {model_name}/{model_version}" - if model_info is None or len(model_info) == 0: - raise KeyError(f"The model {identifier} does not exist in the current registry.") - else: - raise KeyError( - formatting.unwrap( - f"""There are {len(model_info)} models {identifier}. This might indicate a problem with - the integrity of the model registry data.""" - ) - ) - if not id: - id = model_info[0]["ID"] + selected_models = self._list_selected_models(model_name=model_name, model_version=model_version) + identifier = f"{model_name}/{model_version}" + model_info = self._validate_exact_one_result(selected_models, identifier) + id = model_info[0]["ID"] model_uri = model_info[0]["URI"] # Step 1/3: Delete the registry entry. @@ -1432,8 +1403,8 @@ def delete_model( # Step 2/3: Delete the artifact (if desired). if delete_artifact: if uri.is_snowflake_stage_uri(model_uri): - stage_name = uri.get_snowflake_stage_path_from_uri(model_uri) - query_result_checker.SqlResultValidator(self._session, f"DROP STAGE {stage_name}").has_value_match( + stage_path = self._get_fully_qualified_stage_name_from_uri(model_uri) + query_result_checker.SqlResultValidator(self._session, f"DROP STAGE {stage_path}").has_value_match( row_idx=0, col_idx=0, expected_value="successfully dropped." 
).validate() @@ -1527,12 +1498,14 @@ def __init__( self, *, registry: ModelRegistry, + model_name: str, + model_version: str, id: Optional[str] = None, - model_name: Optional[str] = None, - model_version: Optional[str] = None, ) -> None: self._registry = registry self._id = id if id else registry._get_model_id(model_name=model_name, model_version=model_version) + self._model_name = model_name + self._model_version = model_version # Wrap all functions of the ModelRegistry that have an "id" parameter and bind that parameter # the the "_id" member of this class. @@ -1541,7 +1514,11 @@ def __init__( return for name, obj in self._registry.__class__.__dict__.items(): - if not inspect.isfunction(obj) or "id" not in inspect.signature(obj).parameters: + if ( + not inspect.isfunction(obj) + or "model_name" not in inspect.signature(obj).parameters + or "model_version" not in inspect.signature(obj).parameters + ): continue # Ensure that we are not silently overwriting existing functions. @@ -1551,13 +1528,14 @@ def __init__( old_sig = inspect.signature(obj) removed_none_type = map( lambda x: x.replace(annotation=str(x.annotation)), - filter(lambda p: p.name not in ["id"], old_sig.parameters.values()), + filter(lambda p: p.name not in ["model_name", "model_version"], old_sig.parameters.values()), ) new_sig = old_sig.replace( parameters=list(removed_none_type), return_annotation=str(old_sig.return_annotation) ) arguments = ", ".join( - ["id=self._id"] + ["model_name=self._model_name"] + + ["model_version=self._model_version"] + [ "{p.name}={p.name}".format(p=p) for p in filter( @@ -1566,9 +1544,7 @@ def __init__( ) ] ) - docstring = self._remove_arg_from_docstring("id", obj.__doc__) - if docstring and "model_name" in docstring: - docstring = self._remove_arg_from_docstring("model_name", docstring) + docstring = self._remove_arg_from_docstring("model_name", obj.__doc__) if docstring and "model_version" in docstring: docstring = self._remove_arg_from_docstring("model_version", docstring) exec( diff --git a/snowflake/ml/registry/model_registry_test.py b/snowflake/ml/registry/model_registry_test.py index e55f1bf3..20d9c54b 100644 --- a/snowflake/ml/registry/model_registry_test.py +++ b/snowflake/ml/registry/model_registry_test.py @@ -35,6 +35,8 @@ def setUp(self) -> None: self._session = mock_session.MockSession(conn=None, test_case=self) self.event_id = "fedcba9876543210fedcba9876543210" self.model_id = "0123456789abcdef0123456789abcdef" + self.model_name = "name" + self.model_version = "abc" self.datetime = datetime.datetime(2022, 11, 4, 17, 1, 30, 153000) def tearDown(self) -> None: @@ -219,12 +221,30 @@ def setup_create_views_call(self) -> None: result=mock_data_frame.MockDataFrame([snowpark.Row(status="View MODELS_VIEW successfully created.")]), ) + def template_test_get_attribute( + self, collection_res: List[snowpark.Row], use_id: bool = False + ) -> mock_data_frame.MockDataFrame: + expected_df = self.setup_list_model_call() + expected_df.add_operation("filter") + if not use_id: + expected_df.add_operation("filter") + expected_df.add_collect_result(collection_res) + return expected_df + def template_test_set_attribute( - self, attribute_name: str, attribute_value: Union[str, Dict[Any, Any]], result_num_inserted: int = 1 + self, + attribute_name: str, + attribute_value: Union[str, Dict[Any, Any]], + result_num_inserted: int = 1, + use_id: bool = False, ) -> None: expected_df = self.setup_list_model_call() expected_df.add_operation("filter") - expected_df.add_count_result(1) + if not use_id: + 
expected_df.add_operation("filter") + expected_df.add_collect_result( + [snowpark.Row(ID=self.model_id, NAME="name", VERSION="abc", URI="sfc://model_stage")] + ) self._session.add_operation("get_current_role", result="current_role") @@ -337,20 +357,24 @@ def test_set_model_description(self) -> None: "_get_new_unique_identifier", return_value=self.event_id, ): - model_registry.set_model_description(id=self.model_id, description="new_description") + model_registry.set_model_description( + model_name=self.model_name, model_version=self.model_version, description="new_description" + ) def test_get_model_description(self) -> None: """Test that we can get the description of an existing model from the registry.""" model_registry = self.get_model_registry() - expected_df = self.setup_list_model_call() - expected_df.add_operation(operation="filter") - expected_df.add_operation( - operation="select", - args=("DESCRIPTION",), - result=mock_data_frame.MockDataFrame([snowpark.Row(ID=self.model_id, DESCRIPTION="model_description")]), + self.template_test_get_attribute( + [ + snowpark.Row( + ID=self.model_id, NAME=self.model_name, VERSION=self.model_version, DESCRIPTION="model_description" + ) + ] ) - model_description = model_registry.get_model_description(id=self.model_id) + model_description = model_registry.get_model_description( + model_name=self.model_name, model_version=self.model_version + ) self.assertEqual(model_description, "model_description") def test_get_history(self) -> None: @@ -392,6 +416,9 @@ def test_get_history(self) -> None: def test_get_model_history(self) -> None: """Test that we can retrieve the history for a specific model.""" model_registry = self.get_model_registry() + self.template_test_get_attribute( + [snowpark.Row(ID=self.model_id, NAME=self.model_name, VERSION=self.model_version)] + ) expected_collect_result = [ snowpark.Row( EVENT_TIMESTAMP="ts", @@ -424,17 +451,16 @@ def test_get_model_history(self) -> None: expected_df.add_operation(operation="filter", check_args=False, check_kwargs=False) expected_df.add_collect_result(expected_collect_result) - self.assertEqual(model_registry.get_model_history(id=self.model_id).collect(), expected_collect_result) + self.assertEqual( + model_registry.get_model_history(model_name=self.model_name, model_version=self.model_version).collect(), + expected_collect_result, + ) def test_set_metric_no_existing(self) -> None: """Test that we can set a metric for an existing model that does not yet have any metrics set.""" model_registry = self.get_model_registry() - expected_df = self.setup_list_model_call() - expected_df.add_operation(operation="filter") - expected_df.add_operation( - operation="select", - args=("METRICS",), - result=mock_data_frame.MockDataFrame([snowpark.Row(ID=self.model_id, METRICS=None)]), + self.template_test_get_attribute( + [snowpark.Row(ID=self.model_id, NAME=self.model_name, VERSION=self.model_version, METRICS=None)] ) self.template_test_set_attribute("METRICS", {"voight-kampff": 0.9}) @@ -444,17 +470,19 @@ def test_set_metric_no_existing(self) -> None: "_get_new_unique_identifier", return_value=self.event_id, ): - model_registry.set_metric(id=self.model_id, name="voight-kampff", value=0.9) + model_registry.set_metric( + model_name=self.model_name, model_version=self.model_version, name="voight-kampff", value=0.9 + ) def test_set_metric_with_existing(self) -> None: """Test that we can set a metric for an existing model that already has metrics.""" model_registry = self.get_model_registry() - expected_df = 
self.setup_list_model_call() - expected_df.add_operation(operation="filter") - expected_df.add_operation( - operation="select", - args=("METRICS",), - result=mock_data_frame.MockDataFrame([snowpark.Row(ID=self.model_id, METRICS='{"human-factor": 1.1}')]), + self.template_test_get_attribute( + [ + snowpark.Row( + ID=self.model_id, NAME=self.model_name, VERSION=self.model_version, METRICS='{"human-factor": 1.1}' + ) + ] ) self.template_test_set_attribute("METRICS", {"human-factor": 1.1, "voight-kampff": 0.9}) @@ -464,35 +492,42 @@ def test_set_metric_with_existing(self) -> None: "_get_new_unique_identifier", return_value=self.event_id, ): - model_registry.set_metric(id=self.model_id, name="voight-kampff", value=0.9) + model_registry.set_metric( + model_name=self.model_name, model_version=self.model_version, name="voight-kampff", value=0.9 + ) def test_get_metrics(self) -> None: """Test that we can get the metrics for an existing model.""" metrics_dict = {"human-factor": 1.1, "voight-kampff": 0.9} model_registry = self.get_model_registry() - expected_df = self.setup_list_model_call() - expected_df.add_operation(operation="filter") - expected_df.add_operation( - operation="select", - args=("METRICS",), - result=mock_data_frame.MockDataFrame([snowpark.Row(ID=self.model_id, METRICS=json.dumps(metrics_dict))]), + self.template_test_get_attribute( + [ + snowpark.Row( + ID=self.model_id, NAME=self.model_name, VERSION=self.model_version, METRICS=json.dumps(metrics_dict) + ) + ] + ) + self.assertEqual( + model_registry.get_metrics(model_name=self.model_name, model_version=self.model_version), metrics_dict ) - - self.assertEqual(model_registry.get_metrics(id=self.model_id), metrics_dict) def test_get_metric_value(self) -> None: """Test that we can get a single metric value for an existing model.""" metrics_dict = {"human-factor": 1.1, "voight-kampff": 0.9} model_registry = self.get_model_registry() - expected_df = self.setup_list_model_call() - expected_df.add_operation(operation="filter") - expected_df.add_operation( - operation="select", - args=("METRICS",), - result=mock_data_frame.MockDataFrame([snowpark.Row(ID=self.model_id, METRICS=json.dumps(metrics_dict))]), + self.template_test_get_attribute( + [ + snowpark.Row( + ID=self.model_id, NAME=self.model_name, VERSION=self.model_version, METRICS=json.dumps(metrics_dict) + ) + ] + ) + self.assertEqual( + model_registry.get_metric_value( + model_name=self.model_name, model_version=self.model_version, name="human-factor" + ), + 1.1, ) - - self.assertEqual(model_registry.get_metric_value(id=self.model_id, name="human-factor"), 1.1) def test_private_insert_registry_entry(self) -> None: model_registry = self.get_model_registry() @@ -513,7 +548,7 @@ def test_register_model(self) -> None: self._session.add_operation("get_current_role", result="current_role") mock_df = self.setup_list_model_call() - mock_df.add_operation("filter", result=mock_data_frame.MockDataFrame([])) + mock_df.add_operation("filter").add_operation("filter", result=mock_data_frame.MockDataFrame([])) mock_df.add_count_result(0) self.add_session_mock_sql( @@ -526,12 +561,13 @@ def test_register_model(self) -> None: # Mock calls to variable values: python version and internal _set_metadata_attribute. 
with absltest.mock.patch.object(model_registry, "_set_metadata_attribute", return_value=True): - with absltest.mock.patch( - "model_registry.sys.version_info", new_callable=absltest.mock.PropertyMock(return_value=(3, 8, 13)) - ): - model_registry.register_model( - id="id", uri="uri", type="type", name="name", version="abc", tags={"tag_name": "tag_value"} - ) + with absltest.mock.patch.object(model_registry, "_get_new_unique_identifier", return_value="id"): + with absltest.mock.patch( + "model_registry.sys.version_info", new_callable=absltest.mock.PropertyMock(return_value=(3, 8, 13)) + ): + model_registry.register_model( + uri="uri", type="type", name="name", version="abc", tags={"tag_name": "tag_value"} + ) def test_register_model_no_tags(self) -> None: """Test registering a model without giving a tag.""" @@ -540,7 +576,7 @@ def test_register_model_no_tags(self) -> None: self._session.add_operation("get_current_role", result="current_role") mock_df = self.setup_list_model_call() - mock_df.add_operation("filter", result=mock_data_frame.MockDataFrame([])) + mock_df.add_operation("filter").add_operation("filter", result=mock_data_frame.MockDataFrame([])) mock_df.add_count_result(0) self.add_session_mock_sql( @@ -572,23 +608,20 @@ def test_register_model_no_tags(self) -> None: with absltest.mock.patch.object( model_registry, "_get_new_unique_identifier", - return_value=self.event_id, + side_effect=[self.model_id, self.event_id], ): with absltest.mock.patch( "model_registry.sys.version_info", new_callable=absltest.mock.PropertyMock(return_value=(3, 8, 13)) ): - model_registry.register_model(id=self.model_id, uri="uri", type="type", name="name", version="abc") + model_registry.register_model(uri="uri", type="type", name="name", version="abc") def test_get_tags(self) -> None: """Test that get_tags is working correctly with various types.""" model_registry = self.get_model_registry() - self.setup_list_model_call().add_operation(operation="filter").add_operation( - operation="select", - args=("TAGS",), - result=mock_data_frame.MockDataFrame( - [ - snowpark.Row( - TAGS=""" + self.template_test_get_attribute( + [ + snowpark.Row( + TAGS=""" { "top_level": "string", "nested": { @@ -615,11 +648,12 @@ def test_get_tags(self) -> None: ] } }""", - ) - ], - ), + ) + ] ) - model_registry.get_tags(id=self.model_id) + tags = model_registry.get_tags(model_name=self.model_name, model_version=self.model_version) + self.assertEqual(tags["top_level"], "string") + self.assertEqual(tags["nested"]["float"], 0.9) def test_register_model_with_description(self) -> None: """Test registering a model with a description.""" @@ -628,7 +662,7 @@ def test_register_model_with_description(self) -> None: self._session.add_operation("get_current_role", result="current_role") mock_df = self.setup_list_model_call() - mock_df.add_operation("filter", result=mock_data_frame.MockDataFrame([])) + mock_df.add_operation("filter").add_operation("filter", result=mock_data_frame.MockDataFrame([])) mock_df.add_count_result(0) self.add_session_mock_sql( @@ -665,13 +699,12 @@ def test_register_model_with_description(self) -> None: with absltest.mock.patch.object( model_registry, "_get_new_unique_identifier", - return_value=self.event_id, + side_effect=[self.model_id, self.event_id, self.event_id], ): with absltest.mock.patch( "model_registry.sys.version_info", new_callable=absltest.mock.PropertyMock(return_value=(3, 8, 13)) ): model_registry.register_model( - id=self.model_id, uri="uri", type="type", version="abc", @@ -687,10 +720,14 @@ def 
test_log_model_path_file(self) -> None: """ model_registry = self.get_model_registry() + model_name = "name" + model_version = "abc" + expected_stage_postfix = f"{model_name}_{model_version}".upper() + self.add_session_mock_sql( - query=f"""CREATE OR REPLACE STAGE "{_DATABASE_NAME}"."{_SCHEMA_NAME}".SNOWML_MODEL_{self.model_id}""", + query=f'CREATE OR REPLACE STAGE "{_DATABASE_NAME}"."{_SCHEMA_NAME}".SNOWML_MODEL_{expected_stage_postfix}', result=mock_data_frame.MockDataFrame( - [snowpark.Row(**{"status": f"Stage area SNOWML_MODEL_{self.model_id.upper()} successfully created."})] + [snowpark.Row(**{"status": f"Stage area SNOWML_MODEL_{expected_stage_postfix} successfully created."})] ), ) @@ -698,7 +735,7 @@ def test_log_model_path_file(self) -> None: mock_sp_file_operation = absltest.mock.Mock() self._session.__setattr__("file", mock_sp_file_operation) - expected_stage_path = f'"{_DATABASE_NAME}"."{_SCHEMA_NAME}".SNOWML_MODEL_{self.model_id.upper()}/data' + expected_stage_path = f'"{_DATABASE_NAME}"."{_SCHEMA_NAME}".SNOWML_MODEL_{expected_stage_postfix}/data' with absltest.mock.patch("model_registry.os.path.isfile", return_value=True) as mock_isfile: with absltest.mock.patch.object( @@ -712,17 +749,16 @@ def test_log_model_path_file(self) -> None: return_value=True, ): model_registry.log_model_path( - path="path", type="type", name="name", version="abc", description="description" + path="path", type="type", name=model_name, version=model_version, description="description" ) mock_isfile.assert_called_once_with("path") mock_sp_file_operation.put.assert_called_with("path", expected_stage_path) assert isinstance(model_registry.register_model, absltest.mock.Mock) model_registry.register_model.assert_called_with( - id=self.model_id, type="type", - uri=f"sfc:MODEL_REGISTRY.PUBLIC.SNOWML_MODEL_{self.model_id.upper()}", - name="name", - version="abc", + uri=f"sfc:MODEL_REGISTRY.PUBLIC.SNOWML_MODEL_{expected_stage_postfix}", + name=model_name, + version=model_version, description="description", tags=None, ) @@ -730,15 +766,17 @@ def test_log_model_path_file(self) -> None: def test_delete_model_with_artifact(self) -> None: """Test deleting a model and artifact from the registry.""" model_registry = self.get_model_registry() - self.setup_list_model_call().add_operation(operation="filter").add_collect_result( - [snowpark.Row(URI="sfc://model_stage")], + self.setup_list_model_call().add_operation(operation="filter").add_operation( + operation="filter" + ).add_collect_result( + [snowpark.Row(ID=self.model_id, NAME=self.model_name, VERSION=self.model_version, URI="sfc://model_stage")], ) self.add_session_mock_sql( query=f"""DELETE FROM "MODEL_REGISTRY"."PUBLIC"."MODELS" WHERE ID='{self.model_id}'""", result=mock_data_frame.MockDataFrame([snowpark.Row(**{"number of rows deleted": 1})]), ) self.add_session_mock_sql( - query="DROP STAGE model_stage", + query='DROP STAGE "MODEL_REGISTRY"."PUBLIC".model_stage', result=mock_data_frame.MockDataFrame([snowpark.Row(**{"status": "'model_stage' successfully dropped."})]), ) self.template_test_set_attribute( @@ -747,6 +785,7 @@ def test_delete_model_with_artifact(self) -> None: "URI": "sfc://model_stage", "delete_artifact": True, }, + use_id=True, ) with absltest.mock.patch.object( @@ -754,7 +793,7 @@ def test_delete_model_with_artifact(self) -> None: "_get_new_unique_identifier", return_value=self.event_id, ): - model_registry.delete_model(id=self.model_id) + model_registry.delete_model(model_name="name", model_version="abc") if __name__ == "__main__": diff --git 
a/snowflake/ml/registry/notebooks/Model Registry Demo.ipynb b/snowflake/ml/registry/notebooks/Model Registry Demo.ipynb index 3a4b2f83..4066a5ca 100644 --- a/snowflake/ml/registry/notebooks/Model Registry Demo.ipynb +++ b/snowflake/ml/registry/notebooks/Model Registry Demo.ipynb @@ -51,7 +51,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "Adding /Users/amauser/src/snowml to system path\n" + "Adding /Users/zzhu/workspace/sfml/snowml to system path\n" ] } ], @@ -144,7 +144,16 @@ "execution_count": 4, "id": "75282f6d", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/opt/miniconda3/envs/py38/lib/python3.8/site-packages/snowflake/connector/options.py:96: UserWarning: You have an incompatible version of 'pyarrow' installed (10.0.1), please install a version that adheres to: 'pyarrow<8.1.0,>=8.0.0; extra == \"pandas\"'\n", + " warn_incompatible_dep(\n" + ] + } + ], "source": [ "from snowflake.ml.utils.connection_params import SnowflakeLoginOptions\n", "from snowflake.snowpark import Session, Column, functions\n", @@ -173,23 +182,14 @@ "execution_count": 5, "id": "5d37ad34", "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "" - ] - }, - "execution_count": 5, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "import importlib\n", "from snowflake.ml.registry import model_registry\n", "# Force re-loading model_registry in case we updated the package during the runtime of this notebook.\n", - "importlib.reload(model_registry)" + "importlib.reload(model_registry)\n", + "\n", + "registry_name = \"model_registry_zzhu\"" ] }, { @@ -202,13 +202,13 @@ "name": "stderr", "output_type": "stream", "text": [ - "WARNING:absl:The database MODEL_REGISTRY already exists. Skipping creation.\n" + "WARNING:absl:The database model_registry_zzhu already exists. Skipping creation.\n" ] } ], "source": [ "# Create a new model registry. 
This will be a no-op if the registry already exists.\n", - "create_result = model_registry.create_model_registry(session)" + "create_result = model_registry.create_model_registry(session, registry_name)" ] }, { @@ -218,7 +218,7 @@ "metadata": {}, "outputs": [], "source": [ - "registry = model_registry.ModelRegistry(session=session)" + "registry = model_registry.ModelRegistry(session=session, name=registry_name)" ] }, { @@ -262,22 +262,30 @@ "execution_count": 8, "id": "9d8ad06e", "metadata": {}, + "outputs": [], + "source": [ + "# A name and model tags can be added to the model at registration time.\n", + "model_id = registry.log_model(model=clf, name=\"my_model\", version=\"103\", tags={\n", + " \"stage\": \"testing\", \"classifier_type\": \"svm.SVC\", \"svc_gamma\": svc_gamma, \"svc_C\": svc_C}, sample_input_data=train_features)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "b463bad9", + "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Registered new model: 89f1499eb70011ed9cd3e289b4f89202\n" + "Registered new model: 8aa8fac2f03911edb94aacde48001122\n" ] } ], "source": [ - "# A name and model tags can be added to the model at registration time.\n", - "model_id = registry.log_model(model=clf, name=\"my_model\", tags={\n", - " \"stage\": \"testing\", \"classifier_type\": \"svm.SVC\", \"svc_gamma\": svc_gamma, \"svc_C\": svc_C})\n", - "\n", "# The object API can be used to reference a model after creation.\n", - "model = model_registry.ModelReference(registry=registry, id=model_id)\n", + "model = model_registry.ModelReference(registry=registry, id=model_id, model_name=\"my_model\", model_version=\"103\")\n", "print(\"Registered new model:\", model_id)" ] }, @@ -299,7 +307,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 10, "id": "c2b0cdbd", "metadata": {}, "outputs": [ @@ -322,18 +330,69 @@ "# Simple scalar metrics.\n", "\n", "# Relational API\n", - "registry.set_metric(id=model_id, name=\"test_accuracy\", value=test_accuracy)\n", + "registry.set_metric(model_name=\"my_model\", model_version=\"103\", name=\"test_accuracy\", value=test_accuracy)\n", + "\n", "# Object API\n", "model.set_metric(name=\"num_training_examples\", value=num_training_examples)\n", "\n", "# Hierarchical metric.\n", - "registry.set_metric(id=model_id, name=\"dataset_test\", value={\"accuracy\": test_accuracy})\n", + "registry.set_metric(model_name=\"my_model\", model_version=\"103\", name=\"dataset_test\", value={\"accuracy\": test_accuracy})\n", "\n", "# Multivalent metric:\n", "test_confusion_matrix = metrics.confusion_matrix(test_labels, prediction)\n", "print(\"Confusion matrix:\", test_confusion_matrix)\n", "\n", - "registry.set_metric(id=model_id, name=\"confusion_matrix\", value=test_confusion_matrix)" + "registry.set_metric(model_name=\"my_model\", model_version=\"103\", name=\"confusion_matrix\", value=test_confusion_matrix)" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "45b81834", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'confusion_matrix': [[90, 0], [3, 7]],\n", + " 'dataset_test': {'accuracy': 0.97},\n", + " 'num_training_examples': 10,\n", + " 'test_accuracy': 0.97}" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Relational API\n", + "registry.get_metrics(model_name=\"my_model\", model_version=\"103\")" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "9a2627c5", + "metadata": 
{}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'confusion_matrix': [[90, 0], [3, 7]],\n", + " 'dataset_test': {'accuracy': 0.97},\n", + " 'num_training_examples': 10,\n", + " 'test_accuracy': 0.97}" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Object API\n", + "model.get_metrics()" ] }, { @@ -354,7 +413,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 14, "id": "dc82b541", "metadata": {}, "outputs": [ @@ -362,27 +421,27 @@ "name": "stdout", "output_type": "stream", "text": [ - "---------------------------------------------------------------------------------\n", - "|\"NAME\" |\"TAGS\" |\"METRICS\" |\n", - "---------------------------------------------------------------------------------\n", - "|\"my_model\" |{ |{ |\n", - "| | \"classifier_type\": \"svm.SVC\", | \"confusion_matrix\": [ |\n", - "| | \"stage\": \"testing\", | [ |\n", - "| | \"svc_C\": 10, | 90, |\n", - "| | \"svc_gamma\": 0.001 | 0 |\n", - "| |} | ], |\n", - "| | | [ |\n", - "| | | 3, |\n", - "| | | 7 |\n", - "| | | ] |\n", - "| | | ], |\n", - "| | | \"dataset_test\": { |\n", - "| | | \"accuracy\": 0.97 |\n", - "| | | }, |\n", - "| | | \"num_training_examples\": 10, |\n", - "| | | \"test_accuracy\": 0.97 |\n", - "| | |} |\n", - "---------------------------------------------------------------------------------\n", + "-------------------------------------------------------------------------------------------\n", + "|\"NAME\" |\"VERSION\" |\"TAGS\" |\"METRICS\" |\n", + "-------------------------------------------------------------------------------------------\n", + "|my_model |103 |{ |{ |\n", + "| | | \"classifier_type\": \"svm.SVC\", | \"confusion_matrix\": [ |\n", + "| | | \"stage\": \"testing\", | [ |\n", + "| | | \"svc_C\": 10, | 90, |\n", + "| | | \"svc_gamma\": 0.001 | 0 |\n", + "| | |} | ], |\n", + "| | | | [ |\n", + "| | | | 3, |\n", + "| | | | 7 |\n", + "| | | | ] |\n", + "| | | | ], |\n", + "| | | | \"dataset_test\": { |\n", + "| | | | \"accuracy\": 0.97 |\n", + "| | | | }, |\n", + "| | | | \"num_training_examples\": 10, |\n", + "| | | | \"test_accuracy\": 0.97 |\n", + "| | | |} |\n", + "-------------------------------------------------------------------------------------------\n", "\n" ] } @@ -390,23 +449,25 @@ "source": [ "model_list = registry.list_models()\n", "\n", - "model_list.filter(model_list[\"ID\"] == model_id).select(\"NAME\",\"TAGS\",\"METRICS\").show()" + "model_list.filter(model_list[\"ID\"] == model_id).select(\"NAME\",\"VERSION\",\"TAGS\",\"METRICS\").show()" ] }, { + "attachments": {}, "cell_type": "markdown", "id": "5706004c", "metadata": {}, "source": [ - "## Metadata: Tags and Name" + "## Metadata: Tags and Descriptions" ] }, { + "attachments": {}, "cell_type": "markdown", "id": "05cee94f", "metadata": {}, "source": [ - "Similar to how we changed metrics in the example above, we can also edit tags and names of models both with the relational API and with the object API." + "Similar to how we changed metrics in the example above, we can also edit tags and descriptions of models both with the relational API and with the object API." 
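The relational and object calls in the next cells are interchangeable: ModelReference now wraps every registry method that accepts both model_name and model_version, binding that pair once at construction time. A minimal sketch of the equivalence, reusing the my_model / version 103 model registered above (the tag name and value are illustrative):

# Relational API: address the model explicitly on every call.
registry.set_tag(name="stage", value="production", model_name="my_model", model_version="103")

# Object API: bind (model_name, model_version) once, then call the same method without them.
model = model_registry.ModelReference(registry=registry, model_name="my_model", model_version="103")
model.set_tag("stage", "production")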
] }, { @@ -419,7 +480,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 15, "id": "f80f78da", "metadata": {}, "outputs": [ @@ -431,29 +492,24 @@ "Added tag: {'classifier_type': 'svm.SVC', 'minor_version': '23', 'stage': 'testing', 'svc_C': 10, 'svc_gamma': 0.001}\n", "Removed tag {'classifier_type': 'svm.SVC', 'stage': 'testing', 'svc_C': 10, 'svc_gamma': 0.001}\n", "Updated tag: {'classifier_type': 'svm.SVC', 'stage': 'production', 'svc_C': 10, 'svc_gamma': 0.001}\n", - "Old name: \"my_model\"\n", - "New name: \"target_digit_6\"\n" + "Added description: \"My model is better than talkgpt-5!\"\n" ] } ], "source": [ - "print(\"Old tags:\", registry.get_tags(model_id))\n", - "\n", - "registry.set_tag(model_id, \"minor_version\", \"23\")\n", - "print(\"Added tag:\", registry.get_tags(model_id))\n", + "print(\"Old tags:\", registry.get_tags(model_name=\"my_model\", model_version=\"103\",))\n", "\n", - "registry.remove_tag(model_id, \"minor_version\")\n", - "print(\"Removed tag\", registry.get_tags(model_id))\n", - "registry.set_tag(model_id, \"stage\", \"production\")\n", - "print(\"Updated tag:\", registry.get_tags(model_id))\n", + "registry.set_tag(name=\"minor_version\", value=\"23\", model_name=\"my_model\", model_version=\"103\",)\n", + "print(\"Added tag:\", registry.get_tags(model_name=\"my_model\", model_version=\"103\",))\n", "\n", - "# Rename Model\n", - "print(\"Old name:\", registry.get_model_name(model_id))\n", + "registry.remove_tag(name=\"minor_version\", model_name=\"my_model\", model_version=\"103\",)\n", + "print(\"Removed tag\", registry.get_tags(model_name=\"my_model\", model_version=\"103\",))\n", "\n", - "new_model_name = f\"target_digit_{target_digit}\"\n", - "registry.set_model_name(id=model_id, name=new_model_name)\n", + "registry.set_tag(name=\"stage\", value=\"production\", model_name=\"my_model\", model_version=\"103\",)\n", + "print(\"Updated tag:\", registry.get_tags(model_name=\"my_model\", model_version=\"103\",))\n", "\n", - "print(\"New name:\", registry.get_model_name(model_id))" + "registry.set_model_description(description=\"My model is better than talkgpt-5!\", model_name=\"my_model\", model_version=\"103\",)\n", + "print(\"Added description:\", registry.get_model_description(model_name=\"my_model\", model_version=\"103\",))" ] }, { @@ -466,7 +522,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 16, "id": "7905d9c9", "metadata": {}, "outputs": [ @@ -478,8 +534,7 @@ "Added tag: {'classifier_type': 'svm.SVC', 'minor_version': '23', 'stage': 'production', 'svc_C': 10, 'svc_gamma': 0.001}\n", "Removed tag {'classifier_type': 'svm.SVC', 'stage': 'production', 'svc_C': 10, 'svc_gamma': 0.001}\n", "Updated tag: {'classifier_type': 'svm.SVC', 'stage': 'production', 'svc_C': 10, 'svc_gamma': 0.001}\n", - "Old name: \"target_digit_6\"\n", - "New name: \"target_digit_6\"\n" + "New description: \"My model is better than speakgpt-6!\"\n" ] } ], @@ -491,16 +546,12 @@ "\n", "model.remove_tag(\"minor_version\")\n", "print(\"Removed tag\", model.get_tags())\n", + "\n", "model.set_tag(\"stage\", \"production\")\n", "print(\"Updated tag:\", model.get_tags())\n", "\n", - "# Rename Model\n", - "print(\"Old name:\", model.get_model_name())\n", - "\n", - "new_model_name = f\"target_digit_{target_digit}\"\n", - "model.set_model_name(name=new_model_name)\n", - "\n", - "print(\"New name:\", model.get_model_name())" + "model.set_model_description(description=\"My model is better than speakgpt-6!\")\n", + "print(\"New 
description:\", model.get_model_description())" ] }, { @@ -521,7 +572,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 18, "id": "eef6965d", "metadata": { "scrolled": true @@ -531,143 +582,34 @@ "name": "stdout", "output_type": "stream", "text": [ - "---------------------------------------------------------------------------------------------------------------------------------------------\n", - "|\"ID\" |\"NAME\" |\"CREATION_TIME\" |\"TAGS\" |\n", - "---------------------------------------------------------------------------------------------------------------------------------------------\n", - "|89f1499eb70011ed9cd3e289b4f89202 |\"target_digit_6\" |2023-02-27 16:40:47.108000-08:00 |{ |\n", - "| | | | \"classifier_type\": \"svm.SVC\", |\n", - "| | | | \"stage\": \"production\", |\n", - "| | | | \"svc_C\": 10, |\n", - "| | | | \"svc_gamma\": 0.001 |\n", - "| | | |} |\n", - "|4c94d3d4b6fd11eda2d4e289b4f89202 |\"target_digit_6\" |2023-02-27 16:17:36.381000-08:00 |{ |\n", - "| | | | \"classifier_type\": \"svm.SVC\", |\n", - "| | | | \"stage\": \"production\", |\n", - "| | | | \"svc_C\": 10, |\n", - "| | | | \"svc_gamma\": 0.001 |\n", - "| | | |} |\n", - "|6a1b8e98b2f711eda1c2e289b4f89202 |\"target_digit_6\" |2023-02-22 13:25:24.538000-08:00 |{ |\n", - "| | | | \"classifier_type\": \"svm.SVC\", |\n", - "| | | | \"stage\": \"production\", |\n", - "| | | | \"svc_C\": 10, |\n", - "| | | | \"svc_gamma\": 0.001 |\n", - "| | | |} |\n", - "|1ff312fab2f711edaccee289b4f89202 |\"target_digit_6\" |2023-02-22 13:23:20.365000-08:00 |{ |\n", - "| | | | \"classifier_type\": \"svm.SVC\", |\n", - "| | | | \"stage\": \"production\", |\n", - "| | | | \"svc_C\": 10, |\n", - "| | | | \"svc_gamma\": 0.001 |\n", - "| | | |} |\n", - "|1cf19188b16411ed8846e289b4f89202 |\"uci-bank-marketing\" |2023-02-20 13:18:27.178000-08:00 |{ |\n", - "| | | | \"classifier\": \"GradientBoostingRegressor\", |\n", - "| | | | \"params\": { |\n", - "| | | | \"alpha\": 0.9, |\n", - "| | | | \"ccp_alpha\": 0, |\n", - "| | | | \"criterion\": \"friedman_mse\", |\n", - "| | | | \"learning_rate\": 0.1, |\n", - "| | | | \"loss\": \"squared_error\", |\n", - "| | | | \"max_depth\": 3, |\n", - "| | | | \"min_impurity_decrease\": 0, |\n", - "| | | | \"min_samples_leaf\": 1, |\n", - "| | | | \"min_samples_split\": 2, |\n", - "| | | | \"min_weight_fraction_leaf\": 0, |\n", - "| | | | \"n_estimators\": 100, |\n", - "| | | | \"subsample\": 1, |\n", - "| | | | \"tol\": 0.0001, |\n", - "| | | | \"validation_fraction\": 0.1, |\n", - "| | | | \"verbose\": 0, |\n", - "| | | | \"warm_start\": false |\n", - "| | | | }, |\n", - "| | | | \"stage\": \"experimental\" |\n", - "| | | |} |\n", - "|0c0e951eb16411ed8846e289b4f89202 |\"uci-bank-marketing\" |2023-02-20 13:18:06.855000-08:00 |{ |\n", - "| | | | \"classifier\": \"RandomForestRegressor\", |\n", - "| | | | \"params\": { |\n", - "| | | | \"bootstrap\": true, |\n", - "| | | | \"ccp_alpha\": 0, |\n", - "| | | | \"criterion\": \"squared_error\", |\n", - "| | | | \"max_features\": 1, |\n", - "| | | | \"min_impurity_decrease\": 0, |\n", - "| | | | \"min_samples_leaf\": 1, |\n", - "| | | | \"min_samples_split\": 2, |\n", - "| | | | \"min_weight_fraction_leaf\": 0, |\n", - "| | | | \"n_estimators\": 100, |\n", - "| | | | \"oob_score\": false, |\n", - "| | | | \"verbose\": 0, |\n", - "| | | | \"warm_start\": false |\n", - "| | | | }, |\n", - "| | | | \"stage\": \"experimental\" |\n", - "| | | |} |\n", - "|ff52e938b16311ed8846e289b4f89202 |\"uci-bank-marketing\" |2023-02-20 
13:17:38.185000-08:00 |{ |\n", - "| | | | \"classifier\": \"RandomForestRegressor\", |\n", - "| | | | \"params\": { |\n", - "| | | | \"bootstrap\": true, |\n", - "| | | | \"ccp_alpha\": 0, |\n", - "| | | | \"criterion\": \"squared_error\", |\n", - "| | | | \"max_features\": 1, |\n", - "| | | | \"min_impurity_decrease\": 0, |\n", - "| | | | \"min_samples_leaf\": 1, |\n", - "| | | | \"min_samples_split\": 2, |\n", - "| | | | \"min_weight_fraction_leaf\": 0, |\n", - "| | | | \"n_estimators\": 10, |\n", - "| | | | \"oob_score\": false, |\n", - "| | | | \"verbose\": 0, |\n", - "| | | | \"warm_start\": false |\n", - "| | | | }, |\n", - "| | | | \"stage\": \"experimental\" |\n", - "| | | |} |\n", - "|f268515eb16311ed8846e289b4f89202 |\"uci-bank-marketing\" |2023-02-20 13:17:15.883000-08:00 |{ |\n", - "| | | | \"classifier\": \"RandomForestRegressor\", |\n", - "| | | | \"params\": { |\n", - "| | | | \"bootstrap\": true, |\n", - "| | | | \"ccp_alpha\": 0, |\n", - "| | | | \"criterion\": \"squared_error\", |\n", - "| | | | \"max_features\": 1, |\n", - "| | | | \"min_impurity_decrease\": 0, |\n", - "| | | | \"min_samples_leaf\": 1, |\n", - "| | | | \"min_samples_split\": 2, |\n", - "| | | | \"min_weight_fraction_leaf\": 0, |\n", - "| | | | \"n_estimators\": 1, |\n", - "| | | | \"oob_score\": false, |\n", - "| | | | \"verbose\": 0, |\n", - "| | | | \"warm_start\": false |\n", - "| | | | }, |\n", - "| | | | \"stage\": \"experimental\" |\n", - "| | | |} |\n", - "|e15897cab16311ed8846e289b4f89202 |\"uci-bank-marketing\" |2023-02-20 13:16:47.052000-08:00 |{ |\n", - "| | | | \"classifier\": \"LogisticRegression\", |\n", - "| | | | \"params\": { |\n", - "| | | | \"C\": 1, |\n", - "| | | | \"dual\": false, |\n", - "| | | | \"fit_intercept\": true, |\n", - "| | | | \"intercept_scaling\": 1, |\n", - "| | | | \"max_iter\": 10000, |\n", - "| | | | \"multi_class\": \"auto\", |\n", - "| | | | \"penalty\": \"l2\", |\n", - "| | | | \"solver\": \"lbfgs\", |\n", - "| | | | \"tol\": 0.0001, |\n", - "| | | | \"verbose\": 0, |\n", - "| | | | \"warm_start\": false |\n", - "| | | | }, |\n", - "| | | | \"stage\": \"experimental\" |\n", - "| | | |} |\n", - "|d1e675e6b16311ed8846e289b4f89202 |\"uci-bank-marketing\" |2023-02-20 13:16:21.607000-08:00 |{ |\n", - "| | | | \"classifier\": \"LinearRegression\", |\n", - "| | | | \"params\": { |\n", - "| | | | \"copy_X\": true, |\n", - "| | | | \"fit_intercept\": true, |\n", - "| | | | \"normalize\": \"deprecated\", |\n", - "| | | | \"positive\": false |\n", - "| | | | }, |\n", - "| | | | \"stage\": \"experimental\" |\n", - "| | | |} |\n", - "---------------------------------------------------------------------------------------------------------------------------------------------\n", + "--------------------------------------------------------------------------------------------------------------------------------\n", + "|\"ID\" |\"NAME\" |\"VERSION\" |\"CREATION_TIME\" |\"TAGS\" |\n", + "--------------------------------------------------------------------------------------------------------------------------------\n", + "|8aa8fac2f03911edb94aacde48001122 |my_model |103 |2023-05-11 13:22:26.904000-07:00 |{ |\n", + "| | | | | \"classifier_type\": \"svm.SVC\", |\n", + "| | | | | \"stage\": \"production\", |\n", + "| | | | | \"svc_C\": 10, |\n", + "| | | | | \"svc_gamma\": 0.001 |\n", + "| | | | |} |\n", + "|f3df1616f03411edbc35acde48001122 |my_model |102 |2023-05-11 12:49:35.043000-07:00 |{ |\n", + "| | | | | \"classifier_type\": \"svm.SVC\", |\n", + "| | | | | \"stage\": \"testing\", 
|\n", + "| | | | | \"svc_C\": 10, |\n", + "| | | | | \"svc_gamma\": 0.001 |\n", + "| | | | |} |\n", + "|13ab6b68f02411ed862bacde48001122 |my_model |100 |2023-05-11 10:48:47.905000-07:00 |{ |\n", + "| | | | | \"classifier_type\": \"svm.SVC\", |\n", + "| | | | | \"stage\": \"testing\", |\n", + "| | | | | \"svc_C\": 10, |\n", + "| | | | | \"svc_gamma\": 0.001 |\n", + "| | | | |} |\n", + "--------------------------------------------------------------------------------------------------------------------------------\n", "\n" ] } ], "source": [ - "model_list.select(\"ID\",\"NAME\",\"CREATION_TIME\",\"TAGS\").order_by(\"CREATION_TIME\", ascending=False).show()" + "model_list.select(\"ID\",\"NAME\",\"VERSION\",\"CREATION_TIME\",\"TAGS\").order_by(\"CREATION_TIME\", ascending=False).show(3)" ] }, { @@ -688,7 +630,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 36, "id": "6df2eafc", "metadata": { "scrolled": false @@ -698,187 +640,40 @@ "name": "stdout", "output_type": "stream", "text": [ - "--------------------------------------------------------------------------------------------------------------------------\n", - "|\"ID\" |\"NAME\" |\"TAGS\" |\"METRICS\" |\n", - "--------------------------------------------------------------------------------------------------------------------------\n", - "|3df8e89099ef11edbd00e289b4f89203 |\"target_digit_6\" |{ |{ |\n", - "| | | \"classifier_type\": \"svm.SVC\", | \"confusion_matrix\": [ |\n", - "| | | \"stage\": \"production\", | [ |\n", - "| | | \"svc_C\": 100, | 88, |\n", - "| | | \"svc_gamma\": 0.0001 | 2 |\n", - "| | |} | ], |\n", - "| | | | [ |\n", - "| | | | 0, |\n", - "| | | | 10 |\n", - "| | | | ] |\n", - "| | | | ], |\n", - "| | | | \"dataset_test\": { |\n", - "| | | | \"accuracy\": 0.98 |\n", - "| | | | }, |\n", - "| | | | \"num_training_examples\": 10, |\n", - "| | | | \"test_accuracy\": 0.98 |\n", - "| | | |} |\n", - "|e6bfa41499ef11edbd00e289b4f89203 |\"target_digit_6\" |{ |{ |\n", - "| | | \"classifier_type\": \"svm.SVC\", | \"confusion_matrix\": [ |\n", - "| | | \"stage\": \"production\", | [ |\n", - "| | | \"svc_C\": 10, | 90, |\n", - "| | | \"svc_gamma\": 0.001 | 0 |\n", - "| | |} | ], |\n", - "| | | | [ |\n", - "| | | | 3, |\n", - "| | | | 7 |\n", - "| | | | ] |\n", - "| | | | ], |\n", - "| | | | \"dataset_test\": { |\n", - "| | | | \"accuracy\": 0.97 |\n", - "| | | | }, |\n", - "| | | | \"num_training_examples\": 10, |\n", - "| | | | \"test_accuracy\": 0.97 |\n", - "| | | |} |\n", - "|a149adb899ee11edbd00e289b4f89203 |\"target_digit_6\" |{ |{ |\n", - "| | | \"classifier_type\": \"svm.SVC\", | \"confusion_matrix\": [ |\n", - "| | | \"stage\": \"production\", | [ |\n", - "| | | \"svc_C\": 100, | 90, |\n", - "| | | \"svc_gamma\": 0.001 | 0 |\n", - "| | |} | ], |\n", - "| | | | [ |\n", - "| | | | 3, |\n", - "| | | | 7 |\n", - "| | | | ] |\n", - "| | | | ], |\n", - "| | | | \"dataset_test\": { |\n", - "| | | | \"accuracy\": 0.97 |\n", - "| | | | }, |\n", - "| | | | \"num_training_examples\": 10, |\n", - "| | | | \"test_accuracy\": 0.97 |\n", - "| | | |} |\n", - "|6a1b8e98b2f711eda1c2e289b4f89202 |\"target_digit_6\" |{ |{ |\n", - "| | | \"classifier_type\": \"svm.SVC\", | \"confusion_matrix\": [ |\n", - "| | | \"stage\": \"production\", | [ |\n", - "| | | \"svc_C\": 10, | 90, |\n", - "| | | \"svc_gamma\": 0.001 | 0 |\n", - "| | |} | ], |\n", - "| | | | [ |\n", - "| | | | 3, |\n", - "| | | | 7 |\n", - "| | | | ] |\n", - "| | | | ], |\n", - "| | | | \"dataset_test\": { |\n", - "| | | | \"accuracy\": 0.97 |\n", - "| | | 
| }, |\n", - "| | | | \"num_training_examples\": 10, |\n", - "| | | | \"test_accuracy\": 0.97 |\n", - "| | | |} |\n", - "|89f1499eb70011ed9cd3e289b4f89202 |\"target_digit_6\" |{ |{ |\n", - "| | | \"classifier_type\": \"svm.SVC\", | \"confusion_matrix\": [ |\n", - "| | | \"stage\": \"production\", | [ |\n", - "| | | \"svc_C\": 10, | 90, |\n", - "| | | \"svc_gamma\": 0.001 | 0 |\n", - "| | |} | ], |\n", - "| | | | [ |\n", - "| | | | 3, |\n", - "| | | | 7 |\n", - "| | | | ] |\n", - "| | | | ], |\n", - "| | | | \"dataset_test\": { |\n", - "| | | | \"accuracy\": 0.97 |\n", - "| | | | }, |\n", - "| | | | \"num_training_examples\": 10, |\n", - "| | | | \"test_accuracy\": 0.97 |\n", - "| | | |} |\n", - "|98c88dc29ba011ed8b7a001c423403fe |\"target_digit_6\" |{ |{ |\n", - "| | | \"classifier_type\": \"svm.SVC\", | \"confusion_matrix\": [ |\n", - "| | | \"stage\": \"production\", | [ |\n", - "| | | \"svc_C\": 10, | 90, |\n", - "| | | \"svc_gamma\": 0.001 | 0 |\n", - "| | |} | ], |\n", - "| | | | [ |\n", - "| | | | 3, |\n", - "| | | | 7 |\n", - "| | | | ] |\n", - "| | | | ], |\n", - "| | | | \"dataset_test\": { |\n", - "| | | | \"accuracy\": 0.97 |\n", - "| | | | }, |\n", - "| | | | \"num_training_examples\": 10, |\n", - "| | | | \"test_accuracy\": 0.97 |\n", - "| | | |} |\n", - "|1ff312fab2f711edaccee289b4f89202 |\"target_digit_6\" |{ |{ |\n", - "| | | \"classifier_type\": \"svm.SVC\", | \"confusion_matrix\": [ |\n", - "| | | \"stage\": \"production\", | [ |\n", - "| | | \"svc_C\": 10, | 90, |\n", - "| | | \"svc_gamma\": 0.001 | 0 |\n", - "| | |} | ], |\n", - "| | | | [ |\n", - "| | | | 3, |\n", - "| | | | 7 |\n", - "| | | | ] |\n", - "| | | | ], |\n", - "| | | | \"dataset_test\": { |\n", - "| | | | \"accuracy\": 0.97 |\n", - "| | | | }, |\n", - "| | | | \"num_training_examples\": 10, |\n", - "| | | | \"test_accuracy\": 0.97 |\n", - "| | | |} |\n", - "|4c94d3d4b6fd11eda2d4e289b4f89202 |\"target_digit_6\" |{ |{ |\n", - "| | | \"classifier_type\": \"svm.SVC\", | \"confusion_matrix\": [ |\n", - "| | | \"stage\": \"production\", | [ |\n", - "| | | \"svc_C\": 10, | 90, |\n", - "| | | \"svc_gamma\": 0.001 | 0 |\n", - "| | |} | ], |\n", - "| | | | [ |\n", - "| | | | 3, |\n", - "| | | | 7 |\n", - "| | | | ] |\n", - "| | | | ], |\n", - "| | | | \"dataset_test\": { |\n", - "| | | | \"accuracy\": 0.97 |\n", - "| | | | }, |\n", - "| | | | \"num_training_examples\": 10, |\n", - "| | | | \"test_accuracy\": 0.97 |\n", - "| | | |} |\n", - "|0fa110349abd11ed9a9de289b4f89203 |\"target_digit_6\" |{ |{ |\n", - "| | | \"classifier_type\": \"svm.SVC\", | \"confusion_matrix\": [ |\n", - "| | | \"stage\": \"production\", | [ |\n", - "| | | \"svc_C\": 10, | 90, |\n", - "| | | \"svc_gamma\": 0.001 | 0 |\n", - "| | |} | ], |\n", - "| | | | [ |\n", - "| | | | 3, |\n", - "| | | | 7 |\n", - "| | | | ] |\n", - "| | | | ], |\n", - "| | | | \"dataset_test\": { |\n", - "| | | | \"accuracy\": 0.97 |\n", - "| | | | }, |\n", - "| | | | \"num_training_examples\": 10, |\n", - "| | | | \"test_accuracy\": 0.97 |\n", - "| | | |} |\n", - "|11415e4a99ef11edbd00e289b4f89203 |\"target_digit_6\" |{ |{ |\n", - "| | | \"classifier_type\": \"svm.SVC\", | \"confusion_matrix\": [ |\n", - "| | | \"stage\": \"production\", | [ |\n", - "| | | \"svc_C\": 100, | 90, |\n", - "| | | \"svc_gamma\": 0.01 | 0 |\n", - "| | |} | ], |\n", - "| | | | [ |\n", - "| | | | 10, |\n", - "| | | | 0 |\n", - "| | | | ] |\n", - "| | | | ], |\n", - "| | | | \"dataset_test\": { |\n", - "| | | | \"accuracy\": 0.9 |\n", - "| | | | }, |\n", - "| | | | \"num_training_examples\": 
10, |\n", - "| | | | \"test_accuracy\": 0.9 |\n", - "| | | |} |\n", - "--------------------------------------------------------------------------------------------------------------------------\n", + "------------------------------------------------------------------------------------------------------------------------------\n", + "|\"ID\" |\"NAME\" |\"VERSION\" |\"TAGS\" |\"METRICS\" |\n", + "------------------------------------------------------------------------------------------------------------------------------\n", + "|ae1a8938efc811edb049acde48001122 |my_model |1 |{ |{ |\n", + "| | | | \"classifier_type\": \"svm.SVC\", | \"confusion_matrix\": [ |\n", + "| | | | \"stage\": \"production\", | [ |\n", + "| | | | \"svc_C\": 10, | 90, |\n", + "| | | | \"svc_gamma\": 0.001 | 0 |\n", + "| | | |} | ], |\n", + "| | | | | [ |\n", + "| | | | | 3, |\n", + "| | | | | 7 |\n", + "| | | | | ] |\n", + "| | | | | ], |\n", + "| | | | | \"dataset_test\": { |\n", + "| | | | | \"accuracy\": 0.97 |\n", + "| | | | | }, |\n", + "| | | | | \"num_training_examples\": 10, |\n", + "| | | | | \"test_accuracy\": 0.97 |\n", + "| | | | |} |\n", + "|2c3779e0efcc11edb049acde48001122 |my_model |2 |{ |NULL |\n", + "| | | | \"classifier_type\": \"svm.SVC\", | |\n", + "| | | | \"stage\": \"testing\", | |\n", + "| | | | \"svc_C\": 10, | |\n", + "| | | | \"svc_gamma\": 0.001 | |\n", + "| | | |} | |\n", + "------------------------------------------------------------------------------------------------------------------------------\n", "\n" ] } ], "source": [ - "model_list.select(\"ID\",\"NAME\",\"TAGS\",\"METRICS\").filter(\n", - " Column(\"NAME\") == new_model_name).order_by(Column(\"METRICS\")[\"test_accuracy\"], ascending=False \n", + "model_list.select(\"ID\",\"NAME\",\"VERSION\",\"TAGS\",\"METRICS\").filter(\n", + " Column(\"NAME\") == \"my_model\").order_by(Column(\"METRICS\")[\"test_accuracy\"], ascending=False \n", ").show() " ] }, @@ -908,7 +703,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 19, "id": "aed50394", "metadata": { "scrolled": false @@ -921,38 +716,39 @@ "-----------------------------------------------------------------------------------------------------------------------------------\n", "|\"EVENT_TIMESTAMP\" |\"ROLE\" |\"ATTRIBUTE_NAME\" |\"VALUE[ATTRIBUTE_NAME]\" |\n", "-----------------------------------------------------------------------------------------------------------------------------------\n", - "|2023-02-27 16:40:50.433000-08:00 |\"ENG_ML_MODELING_RL\" |REGISTRATION |{ |\n", + "|2023-05-11 13:22:28.276000-07:00 |\"ENG_ML_MODELING_RL\" |REGISTRATION |{ |\n", "| | | | \"CREATION_ENVIRONMENT_SPEC\": { |\n", "| | | | \"python\": \"3.8.16\" |\n", "| | | | }, |\n", "| | | | \"CREATION_ROLE\": \"\\\"ENG_ML_MODELING_RL\\\"\", |\n", - "| | | | \"CREATION_TIME\": \"2023-02-27 16:40:50.433 -08... |\n", - "| | | | \"ID\": \"89f1499eb70011ed9cd3e289b4f89202\", |\n", - "| | | | \"TYPE\": \"SVC\", |\n", - "| | | | \"URI\": \"sfc:MODEL_REGISTRY.PUBLIC.SNOWML_MODE... |\n", + "| | | | \"CREATION_TIME\": \"2023-05-11 13:22:28.276 -07... |\n", + "| | | | \"ID\": \"8aa8fac2f03911edb94aacde48001122\", |\n", + "| | | | \"NAME\": \"my_model\", |\n", + "| | | | \"TYPE\": \"snowflake_native\", |\n", + "| | | | \"URI\": \"sfc:model_registry_zzhu.PUBLIC.SNOWML... 
|\n", + "| | | | \"VERSION\": \"103\" |\n", "| | | |} |\n", - "|2023-02-27 16:40:52.919000-08:00 |\"ENG_ML_MODELING_RL\" |NAME |\"my_model\" |\n", - "|2023-02-27 16:40:54.661000-08:00 |\"ENG_ML_MODELING_RL\" |TAGS |{ |\n", + "|2023-05-11 13:22:30.424000-07:00 |\"ENG_ML_MODELING_RL\" |TAGS |{ |\n", "| | | | \"classifier_type\": \"svm.SVC\", |\n", "| | | | \"stage\": \"testing\", |\n", "| | | | \"svc_C\": 10, |\n", "| | | | \"svc_gamma\": 0.001 |\n", "| | | |} |\n", - "|2023-02-27 16:40:57.991000-08:00 |\"ENG_ML_MODELING_RL\" |METRICS |{ |\n", + "|2023-05-11 13:22:55.070000-07:00 |\"ENG_ML_MODELING_RL\" |METRICS |{ |\n", "| | | | \"test_accuracy\": 0.97 |\n", "| | | |} |\n", - "|2023-02-27 16:41:00.957000-08:00 |\"ENG_ML_MODELING_RL\" |METRICS |{ |\n", + "|2023-05-11 13:22:57.459000-07:00 |\"ENG_ML_MODELING_RL\" |METRICS |{ |\n", "| | | | \"num_training_examples\": 10, |\n", "| | | | \"test_accuracy\": 0.97 |\n", "| | | |} |\n", - "|2023-02-27 16:41:03.460000-08:00 |\"ENG_ML_MODELING_RL\" |METRICS |{ |\n", + "|2023-05-11 13:22:58.826000-07:00 |\"ENG_ML_MODELING_RL\" |METRICS |{ |\n", "| | | | \"dataset_test\": { |\n", "| | | | \"accuracy\": 0.97 |\n", "| | | | }, |\n", "| | | | \"num_training_examples\": 10, |\n", "| | | | \"test_accuracy\": 0.97 |\n", "| | | |} |\n", - "|2023-02-27 16:41:06.240000-08:00 |\"ENG_ML_MODELING_RL\" |METRICS |{ |\n", + "|2023-05-11 13:23:00.636000-07:00 |\"ENG_ML_MODELING_RL\" |METRICS |{ |\n", "| | | | \"confusion_matrix\": [ |\n", "| | | | [ |\n", "| | | | 90, |\n", @@ -969,32 +765,33 @@ "| | | | \"num_training_examples\": 10, |\n", "| | | | \"test_accuracy\": 0.97 |\n", "| | | |} |\n", - "|2023-02-27 16:41:10.150000-08:00 |\"ENG_ML_MODELING_RL\" |TAGS |{ |\n", + "|2023-05-11 13:24:28.065000-07:00 |\"ENG_ML_MODELING_RL\" |TAGS |{ |\n", "| | | | \"classifier_type\": \"svm.SVC\", |\n", "| | | | \"minor_version\": \"23\", |\n", "| | | | \"stage\": \"testing\", |\n", "| | | | \"svc_C\": 10, |\n", "| | | | \"svc_gamma\": 0.001 |\n", "| | | |} |\n", - "|2023-02-27 16:41:13.219000-08:00 |\"ENG_ML_MODELING_RL\" |TAGS |{ |\n", + "|2023-05-11 13:24:30.872000-07:00 |\"ENG_ML_MODELING_RL\" |TAGS |{ |\n", "| | | | \"classifier_type\": \"svm.SVC\", |\n", "| | | | \"stage\": \"testing\", |\n", "| | | | \"svc_C\": 10, |\n", "| | | | \"svc_gamma\": 0.001 |\n", "| | | |} |\n", - "|2023-02-27 16:41:16.383000-08:00 |\"ENG_ML_MODELING_RL\" |TAGS |{ |\n", + "|2023-05-11 13:24:33.108000-07:00 |\"ENG_ML_MODELING_RL\" |TAGS |{ |\n", "| | | | \"classifier_type\": \"svm.SVC\", |\n", "| | | | \"stage\": \"production\", |\n", "| | | | \"svc_C\": 10, |\n", "| | | | \"svc_gamma\": 0.001 |\n", "| | | |} |\n", + "|2023-05-11 13:24:35.150000-07:00 |\"ENG_ML_MODELING_RL\" |DESCRIPTION |\"My model is better than talkgpt-5!\" |\n", "-----------------------------------------------------------------------------------------------------------------------------------\n", "\n" ] } ], "source": [ - "registry.get_model_history(id=model_id).select(\"EVENT_TIMESTAMP\", \"ROLE\", \"ATTRIBUTE_NAME\",\"VALUE[ATTRIBUTE_NAME]\").show()" + "registry.get_model_history(model_name=\"my_model\", model_version=\"103\").select(\"EVENT_TIMESTAMP\", \"ROLE\", \"ATTRIBUTE_NAME\",\"VALUE[ATTRIBUTE_NAME]\").show()" ] }, { @@ -1007,7 +804,7 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 21, "id": "21f7e0b5", "metadata": { "scrolled": false @@ -1020,38 +817,39 @@ "-----------------------------------------------------------------------------------------------------------------------------------\n", 
"|\"EVENT_TIMESTAMP\" |\"ROLE\" |\"ATTRIBUTE_NAME\" |\"VALUE[ATTRIBUTE_NAME]\" |\n", "-----------------------------------------------------------------------------------------------------------------------------------\n", - "|2023-02-27 16:40:50.433000-08:00 |\"ENG_ML_MODELING_RL\" |REGISTRATION |{ |\n", + "|2023-05-11 13:22:28.276000-07:00 |\"ENG_ML_MODELING_RL\" |REGISTRATION |{ |\n", "| | | | \"CREATION_ENVIRONMENT_SPEC\": { |\n", "| | | | \"python\": \"3.8.16\" |\n", "| | | | }, |\n", "| | | | \"CREATION_ROLE\": \"\\\"ENG_ML_MODELING_RL\\\"\", |\n", - "| | | | \"CREATION_TIME\": \"2023-02-27 16:40:50.433 -08... |\n", - "| | | | \"ID\": \"89f1499eb70011ed9cd3e289b4f89202\", |\n", - "| | | | \"TYPE\": \"SVC\", |\n", - "| | | | \"URI\": \"sfc:MODEL_REGISTRY.PUBLIC.SNOWML_MODE... |\n", + "| | | | \"CREATION_TIME\": \"2023-05-11 13:22:28.276 -07... |\n", + "| | | | \"ID\": \"8aa8fac2f03911edb94aacde48001122\", |\n", + "| | | | \"NAME\": \"my_model\", |\n", + "| | | | \"TYPE\": \"snowflake_native\", |\n", + "| | | | \"URI\": \"sfc:model_registry_zzhu.PUBLIC.SNOWML... |\n", + "| | | | \"VERSION\": \"103\" |\n", "| | | |} |\n", - "|2023-02-27 16:40:52.919000-08:00 |\"ENG_ML_MODELING_RL\" |NAME |\"my_model\" |\n", - "|2023-02-27 16:40:54.661000-08:00 |\"ENG_ML_MODELING_RL\" |TAGS |{ |\n", + "|2023-05-11 13:22:30.424000-07:00 |\"ENG_ML_MODELING_RL\" |TAGS |{ |\n", "| | | | \"classifier_type\": \"svm.SVC\", |\n", "| | | | \"stage\": \"testing\", |\n", "| | | | \"svc_C\": 10, |\n", "| | | | \"svc_gamma\": 0.001 |\n", "| | | |} |\n", - "|2023-02-27 16:40:57.991000-08:00 |\"ENG_ML_MODELING_RL\" |METRICS |{ |\n", + "|2023-05-11 13:22:55.070000-07:00 |\"ENG_ML_MODELING_RL\" |METRICS |{ |\n", "| | | | \"test_accuracy\": 0.97 |\n", "| | | |} |\n", - "|2023-02-27 16:41:00.957000-08:00 |\"ENG_ML_MODELING_RL\" |METRICS |{ |\n", + "|2023-05-11 13:22:57.459000-07:00 |\"ENG_ML_MODELING_RL\" |METRICS |{ |\n", "| | | | \"num_training_examples\": 10, |\n", "| | | | \"test_accuracy\": 0.97 |\n", "| | | |} |\n", - "|2023-02-27 16:41:03.460000-08:00 |\"ENG_ML_MODELING_RL\" |METRICS |{ |\n", + "|2023-05-11 13:22:58.826000-07:00 |\"ENG_ML_MODELING_RL\" |METRICS |{ |\n", "| | | | \"dataset_test\": { |\n", "| | | | \"accuracy\": 0.97 |\n", "| | | | }, |\n", "| | | | \"num_training_examples\": 10, |\n", "| | | | \"test_accuracy\": 0.97 |\n", "| | | |} |\n", - "|2023-02-27 16:41:06.240000-08:00 |\"ENG_ML_MODELING_RL\" |METRICS |{ |\n", + "|2023-05-11 13:23:00.636000-07:00 |\"ENG_ML_MODELING_RL\" |METRICS |{ |\n", "| | | | \"confusion_matrix\": [ |\n", "| | | | [ |\n", "| | | | 90, |\n", @@ -1068,25 +866,26 @@ "| | | | \"num_training_examples\": 10, |\n", "| | | | \"test_accuracy\": 0.97 |\n", "| | | |} |\n", - "|2023-02-27 16:41:10.150000-08:00 |\"ENG_ML_MODELING_RL\" |TAGS |{ |\n", + "|2023-05-11 13:24:28.065000-07:00 |\"ENG_ML_MODELING_RL\" |TAGS |{ |\n", "| | | | \"classifier_type\": \"svm.SVC\", |\n", "| | | | \"minor_version\": \"23\", |\n", "| | | | \"stage\": \"testing\", |\n", "| | | | \"svc_C\": 10, |\n", "| | | | \"svc_gamma\": 0.001 |\n", "| | | |} |\n", - "|2023-02-27 16:41:13.219000-08:00 |\"ENG_ML_MODELING_RL\" |TAGS |{ |\n", + "|2023-05-11 13:24:30.872000-07:00 |\"ENG_ML_MODELING_RL\" |TAGS |{ |\n", "| | | | \"classifier_type\": \"svm.SVC\", |\n", "| | | | \"stage\": \"testing\", |\n", "| | | | \"svc_C\": 10, |\n", "| | | | \"svc_gamma\": 0.001 |\n", "| | | |} |\n", - "|2023-02-27 16:41:16.383000-08:00 |\"ENG_ML_MODELING_RL\" |TAGS |{ |\n", + "|2023-05-11 13:24:33.108000-07:00 |\"ENG_ML_MODELING_RL\" |TAGS 
|{ |\n", "| | | | \"classifier_type\": \"svm.SVC\", |\n", "| | | | \"stage\": \"production\", |\n", "| | | | \"svc_C\": 10, |\n", "| | | | \"svc_gamma\": 0.001 |\n", "| | | |} |\n", + "|2023-05-11 13:24:35.150000-07:00 |\"ENG_ML_MODELING_RL\" |DESCRIPTION |\"My model is better than talkgpt-5!\" |\n", "-----------------------------------------------------------------------------------------------------------------------------------\n", "\n" ] @@ -1122,7 +921,7 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 22, "id": "cc0512e1", "metadata": {}, "outputs": [ @@ -1136,9 +935,9 @@ } ], "source": [ - "registry = model_registry.ModelRegistry(session=session)\n", + "registry = model_registry.ModelRegistry(session=session, name=registry_name)\n", "\n", - "restored_clf = registry.load_model(id=model_id)\n", + "restored_clf = registry.load_model(model_name=\"my_model\", model_version=\"103\")\n", "\n", "restored_prediction = restored_clf.predict(test_features)\n", "\n", @@ -1156,7 +955,7 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 23, "id": "2796f2e0", "metadata": {}, "outputs": [ @@ -1170,8 +969,8 @@ } ], "source": [ - "registry = model_registry.ModelRegistry(session=session)\n", - "model = model_registry.ModelReference(registry=registry, id=model_id)\n", + "registry = model_registry.ModelRegistry(session=session, name=registry_name)\n", + "model = model_registry.ModelReference(registry=registry, model_name=\"my_model\", model_version=\"103\")\n", "restored_clf = model.load_model()\n", "\n", "restored_prediction = restored_clf.predict(test_features)\n", diff --git a/snowflake/ml/registry/ui/render_pandas_df.py b/snowflake/ml/registry/ui/render_pandas_df.py index 5c76bc88..232ad230 100644 --- a/snowflake/ml/registry/ui/render_pandas_df.py +++ b/snowflake/ml/registry/ui/render_pandas_df.py @@ -141,8 +141,9 @@ def render(self) -> RenderedContent: model_list = model_list.filter(model_list["ID"] == model_id) transpose = True content.header = "Details of model ID: " + model_id + model_info = model_list.collect()[0] history = ( - self._registry.get_model_history(id=model_id) + self._registry.get_model_history(model_name=model_info["NAME"], model_version=model_info["VERSION"]) .select(_HISTORY_COLUMNS) .order_by("EVENT_TIMESTAMP", ascending=False) ) diff --git a/tests/integ/snowflake/ml/preprocessing/test_k_bins_discretizer.py b/tests/integ/snowflake/ml/preprocessing/test_k_bins_discretizer.py index 9857066e..50dc0b15 100644 --- a/tests/integ/snowflake/ml/preprocessing/test_k_bins_discretizer.py +++ b/tests/integ/snowflake/ml/preprocessing/test_k_bins_discretizer.py @@ -22,6 +22,9 @@ def setUp(self) -> None: self._session = Session.builder.configs(SnowflakeLoginOptions()).create() self._strategies = ["quantile", "uniform"] + def tearDown(self) -> None: + self._session.close() + # TODO(tbao): just some small dev functions for fast iteration, remove later # def test_dummy(self) -> None: # import pandas as pd
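
Taken together, the hunks above complete the migration from id-based model lookups to a (model_name, model_version) pair: the notebook, the registry UI renderer, and the preprocessing test now all address a model by its name and version rather than an opaque id, and the k-bins-discretizer test additionally closes its Snowpark session in tearDown. The following is a minimal sketch of the new calling convention, not an authoritative reference: it simply mirrors the calls visible in the notebook diff, assumes an open Snowpark `session` and an existing registry named `registry_name`, reuses the illustrative "my_model"/"103" values from the notebook, and infers the import path from the repository layout.

    # Sketch only: mirrors the registry calls shown in the notebook diff above.
    # Assumptions: `session` is an open snowflake.snowpark.Session and
    # `registry_name` names a model registry that has already been created.
    from snowflake.ml.registry import model_registry

    registry = model_registry.ModelRegistry(session=session, name=registry_name)

    # Models are now addressed by (model_name, model_version) instead of an id.
    restored_clf = registry.load_model(model_name="my_model", model_version="103")
    restored_prediction = restored_clf.predict(test_features)

    # Model history is keyed the same way.
    registry.get_model_history(model_name="my_model", model_version="103") \
        .select("EVENT_TIMESTAMP", "ROLE", "ATTRIBUTE_NAME", "VALUE[ATTRIBUTE_NAME]") \
        .show()

    # ModelReference takes the same name/version pair.
    model = model_registry.ModelReference(
        registry=registry, model_name="my_model", model_version="103"
    )
    restored_clf = model.load_model()

The render_pandas_df.py change follows the same pattern internally: it first collects the selected model's row to obtain its NAME and VERSION, then passes those to get_model_history instead of the old id argument.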