Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 17 additions & 6 deletions bazel/requirements/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,14 @@ _GENERATE_TOOL = ":parse_and_generate_requirements"

_GENERATE_COMMAND = "$(location " + _GENERATE_TOOL + ") $(location " + _SRC_REQUIREMENT_FILE + ") --schema $(location " + _SCHEMA_FILE + ") {options} > $@"

_TEMPLATE_FOLDER_PATH = "//bazel/requirements/templates"

_AUTOGEN_HEADERS = """# DO NOT EDIT!
# Generate by running 'bazel run //bazel/requirements:sync_requirements'
"""

# "---" is a document start marker, which is legit but optional (https://yaml.org/spec/1.1/#c-document-start). This
# is needed for conda meta.yaml to bypass some bug from conda side.
_YAML_START_DOCUMENT_MARKER = "---"

_GENERATED_REQUIREMENTS_FILES = {
"requirements_txt": {
"cmd": "--mode dev_version --format text",
Expand Down Expand Up @@ -77,7 +79,7 @@ _GENERATED_REQUIREMENTS_FILES = {
"{generated}.body".format(generated = value["generated"]),
],
outs = [value["generated"]],
cmd = "(echo -e \""+ _AUTOGEN_HEADERS +"\" ; cat $(location :{generated}.body) ) > $@".format(
cmd = "(echo -e \"" + _AUTOGEN_HEADERS + "\" ; cat $(location :{generated}.body) ) > $@".format(
generated = value["generated"],
),
tools = [_GENERATE_TOOL],
Expand All @@ -99,15 +101,24 @@ genrule(
)

yq(
name = "gen_conda_meta",
name = "gen_conda_meta_body_format",
srcs = [
":meta.body.yaml",
"{template_folder}:meta.tpl.yaml".format(template_folder = _TEMPLATE_FOLDER_PATH),
"//bazel/requirements/templates:meta.tpl.yaml",
],
outs = ["meta.yaml"],
outs = ["meta.body.formatted.yaml"],
expression = ". as $item ireduce ({}; . * $item ) | sort_keys(..)",
)

genrule(
name = "gen_conda_meta",
srcs = [
":meta.body.formatted.yaml",
],
outs = ["meta.yaml"],
cmd = "(echo -e \"" + _AUTOGEN_HEADERS + "\" ; echo \"" + _YAML_START_DOCUMENT_MARKER + "\"; cat $(location :meta.body.formatted.yaml) ) > $@",
)

# Create a test target for each file that Bazel should
# write to the source tree.
[
Expand Down
6 changes: 0 additions & 6 deletions bazel/requirements/templates/meta.tpl.yaml
Original file line number Diff line number Diff line change
@@ -1,9 +1,3 @@
# DO NOT EDIT!
# Generated by //bazel/requirements:gen_conda_meta
# To update, run:
# bazel run //bazel/requirements:sync_requirements
#

package:
name: snowflake-ml-python

Expand Down
7 changes: 3 additions & 4 deletions ci/conda_recipe/meta.yaml
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
# DO NOT EDIT!
# Generated by //bazel/requirements:gen_conda_meta
# To update, run:
# bazel run //bazel/requirements:sync_requirements
#
# Generate by running 'bazel run //bazel/requirements:sync_requirements'

---
about:
description: |
Snowflake ML client Library is used for interacting with Snowflake to build machine learning solutions.
Expand Down
2 changes: 1 addition & 1 deletion ci/get_excluded_tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# The missing dependency could happen when a new operator is being developed, but not yet released.

set -o pipefail
set -eu
set -u

echo "Running "$0

Expand Down
19 changes: 6 additions & 13 deletions snowflake/ml/modeling/impute/simple_imputer.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
from snowflake.snowpark import functions as F, types as T
from snowflake.snowpark._internal import utils as snowpark_utils

_SUBPROJECT = "Impute"

STRATEGY_TO_STATE_DICT = {
"constant": None,
"mean": _utils.NumericStatistics.MEAN,
Expand Down Expand Up @@ -194,10 +196,7 @@ def check_type_consistency(col_types: Dict[str, T.DataType]) -> None:

return input_col_datatypes

@telemetry.send_api_usage_telemetry(
project=base.PROJECT,
subproject=base.SUBPROJECT,
)
@telemetry.send_api_usage_telemetry(project=base.PROJECT, subproject=_SUBPROJECT)
def fit(self, dataset: snowpark.DataFrame) -> "SimpleImputer":
"""
Compute values to impute for the dataset according to the strategy.
Expand All @@ -214,7 +213,7 @@ def fit(self, dataset: snowpark.DataFrame) -> "SimpleImputer":
input_col_datatypes = self._get_dataset_input_col_datatypes(dataset)

self.statistics_: Dict[str, Any] = {}
statement_params = telemetry.get_statement_params(base.PROJECT, base.SUBPROJECT, self.__class__.__name__)
statement_params = telemetry.get_statement_params(base.PROJECT, _SUBPROJECT, self.__class__.__name__)

if self.strategy == "constant":
if self.fill_value is None:
Expand Down Expand Up @@ -274,14 +273,8 @@ def fit(self, dataset: snowpark.DataFrame) -> "SimpleImputer":
self._is_fitted = True
return self

@telemetry.send_api_usage_telemetry(
project=base.PROJECT,
subproject=base.SUBPROJECT,
)
@telemetry.add_stmt_params_to_df(
project=base.PROJECT,
subproject=base.SUBPROJECT,
)
@telemetry.send_api_usage_telemetry(project=base.PROJECT, subproject=_SUBPROJECT)
@telemetry.add_stmt_params_to_df(project=base.PROJECT, subproject=_SUBPROJECT)
def transform(self, dataset: Union[snowpark.DataFrame, pd.DataFrame]) -> Union[snowpark.DataFrame, pd.DataFrame]:
"""
Transform the input dataset by imputing the computed statistics in the input columns.
Expand Down
1 change: 1 addition & 0 deletions snowflake/ml/modeling/metrics/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ py_library(
"precision_recall_fscore_support.py",
"precision_score.py",
"regression.py",
"roc_curve.py",
],
deps = [
":init",
Expand Down
2 changes: 2 additions & 0 deletions snowflake/ml/modeling/metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from .covariance import covariance
from .precision_recall_fscore_support import precision_recall_fscore_support
from .precision_score import precision_score
from .roc_curve import roc_curve

__all__ = [
"accuracy_score",
Expand All @@ -12,4 +13,5 @@
"covariance",
"precision_recall_fscore_support",
"precision_score",
"roc_curve",
]
Original file line number Diff line number Diff line change
Expand Up @@ -115,10 +115,18 @@ def precision_recall_fscore_support(

session = df._session
assert session is not None
query = df.queries["queries"][-1]
sproc_name = f"precision_recall_fscore_support_{snowpark_utils.generate_random_alphanumeric()}"
statement_params = telemetry.get_statement_params(_PROJECT, _SUBPROJECT)

cols = []
if isinstance(y_true_col_names, str):
cols = [y_true_col_names, y_pred_col_names]
elif isinstance(y_true_col_names, list):
cols = y_true_col_names + y_pred_col_names # type:ignore[assignment, operator]
if sample_weight_col_name:
cols.append(sample_weight_col_name)
query = df[cols].queries["queries"][-1]

@F.sproc( # type: ignore[misc]
session=session,
name=sproc_name,
Expand Down
94 changes: 94 additions & 0 deletions snowflake/ml/modeling/metrics/roc_curve.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
from typing import Optional, Tuple, Union

import cloudpickle
import numpy.typing as npt
from sklearn import metrics

from snowflake import snowpark
from snowflake.ml._internal import telemetry
from snowflake.snowpark import functions as F
from snowflake.snowpark._internal import utils as snowpark_utils

# Telemetry labels: passed as the project / subproject arguments to the
# telemetry helpers used by the metric function in this module.
_PROJECT = "ModelDevelopment"
_SUBPROJECT = "Metrics"


@telemetry.send_api_usage_telemetry(project=_PROJECT, subproject=_SUBPROJECT)
def roc_curve(
    *,
    df: snowpark.DataFrame,
    y_true_col_name: str,
    y_score_col_name: str,
    pos_label: Optional[Union[str, int]] = None,
    sample_weight_col_name: Optional[str] = None,
    drop_intermediate: bool = True,
) -> Tuple[npt.ArrayLike, npt.ArrayLike, npt.ArrayLike]:
    """
    Compute the Receiver Operating Characteristic (ROC) curve for a binary classifier.

    The computation is delegated to ``sklearn.metrics.roc_curve``, executed inside a
    stored procedure registered on the session that owns ``df``. Only the label, score
    and (optional) sample-weight columns are selected for the procedure's query.

    Note: this implementation is restricted to the binary classification task.

    Args:
        df: Input dataframe.
        y_true_col_name: Name of the column holding the true binary labels.
            If the labels are neither {-1, 1} nor {0, 1}, ``pos_label`` should be
            given explicitly.
        y_score_col_name: Name of the column holding the target scores. These may be
            probability estimates of the positive class, confidence values, or a
            non-thresholded measure of decisions (as returned by
            "decision_function" on some classifiers).
        pos_label: Label of the positive class. Forwarded unchanged to
            ``sklearn.metrics.roc_curve``; when ``None``, that function assumes 1 for
            labels in {-1, 1} or {0, 1} and raises an error otherwise.
        sample_weight_col_name: Name of the column holding sample weights. When
            empty or ``None``, no sample weights are used.
        drop_intermediate: Whether to drop suboptimal thresholds that would not
            appear on a plotted ROC curve, which yields a lighter curve.

    Returns:
        A ``(fpr, tpr, thresholds)`` tuple as produced by ``sklearn.metrics.roc_curve``:
        fpr: Increasing false positive rates; element ``i`` is the false positive
            rate of predictions with score >= ``thresholds[i]``.
        tpr: Increasing true positive rates; element ``i`` is the true positive
            rate of predictions with score >= ``thresholds[i]``.
        thresholds: Decreasing thresholds on the decision function used to compute
            ``fpr`` and ``tpr``.
    """
    session = df._session
    assert session is not None
    sproc_name = f"roc_curve_{snowpark_utils.generate_random_alphanumeric()}"
    statement_params = telemetry.get_statement_params(_PROJECT, _SUBPROJECT)

    # Restrict the query to the columns the metric actually needs.
    weight_cols = [sample_weight_col_name] if sample_weight_col_name else []
    selected_cols = [y_true_col_name, y_score_col_name, *weight_cols]
    query = df[selected_cols].queries["queries"][-1]

    @F.sproc(  # type: ignore[misc]
        session=session,
        name=sproc_name,
        replace=True,
        packages=["cloudpickle", "scikit-learn", "snowflake-snowpark-python"],
        statement_params=statement_params,
    )
    def roc_curve_sproc(session: snowpark.Session) -> bytes:
        # Re-run the captured SQL inside the procedure and materialize it as pandas.
        # Only plain values (strings, flags) are referenced from the enclosing scope.
        pandas_df = session.sql(query).to_pandas(statement_params=statement_params)
        labels = pandas_df[y_true_col_name]
        scores = pandas_df[y_score_col_name]
        if sample_weight_col_name:
            weights = pandas_df[sample_weight_col_name]
        else:
            weights = None

        fpr, tpr, thresholds = metrics.roc_curve(
            labels,
            scores,
            pos_label=pos_label,
            sample_weight=weights,
            drop_intermediate=drop_intermediate,
        )

        # The three arrays are shipped back to the caller as one pickled tuple.
        return cloudpickle.dumps((fpr, tpr, thresholds))  # type: ignore[no-any-return]

    # Invoke the procedure and unpickle its (fpr, tpr, thresholds) payload.
    result: Tuple[npt.ArrayLike, npt.ArrayLike, npt.ArrayLike] = cloudpickle.loads(session.call(sproc_name))
    return result
12 changes: 7 additions & 5 deletions tests/integ/snowflake/ml/modeling/framework/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ class DataType(Enum):


def gen_fuzz_data(
rows: int, types: List[DataType], low: int = MIN_INT, high: int = MAX_INT
rows: int, types: List[DataType], low: Union[int, List[int]] = MIN_INT, high: Union[int, List[int]] = MAX_INT
) -> Tuple[List[Any], List[str]]:
"""
Generate random data based on input column types and row count.
Expand All @@ -153,8 +153,8 @@ def gen_fuzz_data(
Args:
rows: num of rows to generate
types: type per column
low: lower bound of the output interval (inclusive)
high: upper bound of the output interval (exclusive)
low: lower bound(s) of the output interval (inclusive)
high: upper bound(s) of the output interval (exclusive)

Returns:
A tuple of generated data and column names
Expand All @@ -166,10 +166,12 @@ def gen_fuzz_data(
names = ["ID"]

for idx, t in enumerate(types):
_low = low if isinstance(low, int) else low[idx]
_high = high if isinstance(high, int) else high[idx]
if t == DataType.INTEGER:
data.append(np.random.randint(low, high, rows))
data.append(np.random.randint(_low, _high, rows))
elif t == DataType.FLOAT:
data.append(np.random.uniform(low, high, rows))
data.append(np.random.uniform(_low, _high, rows))
else:
raise ValueError(f"Unsupported data type {t}")
names.append(f"COL_{idx}")
Expand Down
27 changes: 22 additions & 5 deletions tests/integ/snowflake/ml/modeling/metrics/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@ load("//bazel:py_rules.bzl", "py_test")

package(default_visibility = ["//visibility:public"])

SHARD_COUNT = 3
TIMEOUT = "long" # 900s

py_test(
name = "test_r2_score",
srcs = ["test_r2_score.py"],
Expand All @@ -23,7 +26,7 @@ py_test(

py_test(
name = "test_confusion_matrix",
timeout = "long",
timeout = TIMEOUT,
srcs = ["test_confusion_matrix.py"],
deps = [
"//snowflake/ml/modeling/metrics",
Expand All @@ -34,7 +37,7 @@ py_test(

py_test(
name = "test_correlation",
timeout = "long",
timeout = TIMEOUT,
srcs = ["test_correlation.py"],
deps = [
"//snowflake/ml/modeling/metrics",
Expand All @@ -44,7 +47,7 @@ py_test(

py_test(
name = "test_covariance",
timeout = "long",
timeout = TIMEOUT,
srcs = ["test_covariance.py"],
deps = [
"//snowflake/ml/modeling/metrics",
Expand All @@ -54,8 +57,9 @@ py_test(

py_test(
name = "test_precision_recall_fscore_support",
timeout = "long",
timeout = TIMEOUT,
srcs = ["test_precision_recall_fscore_support.py"],
shard_count = SHARD_COUNT,
deps = [
"//snowflake/ml/modeling/metrics",
"//snowflake/ml/utils:connection_params",
Expand All @@ -65,8 +69,21 @@ py_test(

py_test(
name = "test_precision_score",
timeout = "long",
timeout = TIMEOUT,
srcs = ["test_precision_score.py"],
shard_count = SHARD_COUNT,
deps = [
"//snowflake/ml/modeling/metrics",
"//snowflake/ml/utils:connection_params",
"//tests/integ/snowflake/ml/modeling/framework:utils",
],
)

py_test(
name = "test_roc_curve",
timeout = TIMEOUT,
srcs = ["test_roc_curve.py"],
shard_count = SHARD_COUNT,
deps = [
"//snowflake/ml/modeling/metrics",
"//snowflake/ml/utils:connection_params",
Expand Down
Loading