Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions ci/conda_recipe/meta.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,22 @@ requirements:
run:
- python
- absl-py>=0.15,<2
- anyio>=3.5.0,<4
- fsspec>=2022.11,<=2023.1
- numpy>=1.23,<1.24
- pyyaml>=6.0,<7
- scipy>=1.9,<2
- scikit-learn==1.2.1
- snowflake-connector-python
- snowflake-snowpark-python>=1.0.0,<=1.3
- snowflake-snowpark-python>=1.3.0,<=2
- sqlparse>=0.4,<1

# TODO(snandamuri): Versions of these packages must be exactly the same between the user's workspace and the
# snowpark sandbox. Generic definitions like scikit-learn>=1.1.0,<2 won't work because the snowflake conda channel
# only has a few allowlisted versions of scikit-learn available, so we must force users to use scikit-learn
# versions that are available in the snowflake conda channel. Since there is no way to specify an allowlist of
# versions in the requirements file, we are pinning the versions here.
- joblib>=1.0.0,<=1.1.1
- scikit-learn==1.2.1
- xgboost==1.7.3
about:
home: https://github.com/snowflakedb/snowflake-ml-python
Expand Down
46 changes: 12 additions & 34 deletions codegen/sklearn_wrapper_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,17 +369,11 @@ class WrapperGeneratorBase:

original_class_name INFERRED Class name for the given scikit-learn
estimator.
estimator_class_name GENERATED Name for the new estimator class.
transformer_class_name GENERATED [TODO] Name for the new transformer
class.
module_name INFERRED Name of the module that the given class
is contained in.
estimator_imports GENERATED Imports needed for the estimator / fit()
call.
fit_sproc_imports GENERATED Imports needed for the fit sproc call.
transform_function_name INFERRED Name for the transformer function. This
will be one of "transform" or
"predict" depending on the class.
------------------------------------------------------------------------------------
SIGNATURES AND ARGUMENTS
------------------------------------------------------------------------------------
Expand Down Expand Up @@ -444,9 +438,6 @@ def __init__(self, module_name: str, class_object: Tuple[str, type]) -> None:

# Naming of the class.
self.original_class_name = ""
self.estimator_class_name = ""
self.transformer_class_name = ""
self.transform_function_name = ""

# The signature and argument passing the __init__ functions.
self.original_init_signature = inspect.Signature()
Expand All @@ -456,33 +447,32 @@ def __init__(self, module_name: str, class_object: Tuple[str, type]) -> None:
self.sklearn_init_args_dict = ""
self.estimator_init_member_args = ""

# Doc strings
self.original_class_docstring = ""
self.estimator_class_docstring = ""
self.transformer_class_docstring = ""

self.estimator_imports = ""
self.estimator_imports_list: List[str] = []

self.original_fit_docstring = ""
self.fit_docstring = ""
self.original_transform_docstring = ""
self.transform_docstring = ""

# Import strings
self.estimator_imports = ""
self.estimator_imports_list: List[str] = []
self.additional_import_statements = ""

# Test strings
self.test_dataset_func = ""
self.test_estimator_input_args = ""
self.test_estimator_input_args_list: List[str] = []
self.test_class_name = ""
self.test_estimator_imports = ""
self.test_estimator_imports_list: List[str] = []

self.additional_import_statements = ""

# Dependencies
self.predict_udf_deps = ""
self.fit_sproc_deps = ""

# TODO(amauser): Make fit a no-op if there is no internal state
# TODO(amauser): handling sparse input and output (LabelBinarizer)

def _format_default_value(self, default_value: Any) -> str:
if isinstance(default_value, str):
return f'"{default_value}"'
Expand Down Expand Up @@ -561,26 +551,13 @@ def split_long_lines(line: str) -> str:
self.estimator_class_docstring = class_docstring

def _populate_class_names(self) -> None:
# TODO(snandamuri): All 3 fields have the exact same value. Do we really need
# 3 separate fields?
self.original_class_name = self.class_object[0]
self.estimator_class_name = self.original_class_name
self.transformer_class_name = self.estimator_class_name

self.test_class_name = f"{self.original_class_name}Test"

def _populate_function_names_and_signatures(self) -> None:
for member in inspect.getmembers(self.class_object[1]):
if member[0] == "__init__":
self.original_init_signature = inspect.signature(member[1])
elif member[0] == "predict" or member[0] == "transform":
if self.transform_function_name != "":
print("ERROR: Class has both transform() and predict() methods.")
# TODO(snandamuri): Add support for both transform() and predict() methods in estimators.
# For now, resolve to predict() method when both predict() and transform() are available.
self.transform_function_name = "predict"
else:
self.transform_function_name = member[0]

signature_lines = []
sklearn_init_lines = []
Expand Down Expand Up @@ -642,6 +619,7 @@ def _populate_function_names_and_signatures(self) -> None:
self.estimator_init_member_args = "\n ".join(init_member_args)
self.estimator_args_transform_calls = "\n ".join(arg_transform_calls)

# TODO(snandamuri): Implement type inference for classifiers.
self.udf_datatype = "float" if self._from_data_py or self._is_regressor else ""

def _populate_file_paths(self) -> None:
Expand Down Expand Up @@ -825,7 +803,7 @@ def generate(self) -> "SklearnWrapperGenerator":
self.test_estimator_input_args_list.extend(["min_samples_leaf=1", "max_leaf_nodes=100"])

self.fit_sproc_deps = self.predict_udf_deps = (
"f'numpy=={np.__version__}', f'pandas=={pd.__version__}', f'scikit-learn=={sklearn.__version__}',"
"f'numpy=={np.__version__}', f'pandas=={pd.__version__}', f'scikit-learn=={sklearn.__version__}', "
"f'xgboost=={xgboost.__version__}', f'joblib=={joblib.__version__}'"
)
self._construct_string_from_lists()
Expand All @@ -842,7 +820,7 @@ def generate(self) -> "XGBoostWrapperGenerator":
self.test_estimator_input_args_list.extend(["random_state=0", "subsample=1.0", "colsample_bynode=1.0"])
self.fit_sproc_imports = "import xgboost"
self.fit_sproc_deps = self.predict_udf_deps = (
"f'numpy=={np.__version__}', f'pandas=={pd.__version__}', f'xgboost=={xgboost.__version__}',"
"f'numpy=={np.__version__}', f'pandas=={pd.__version__}', f'xgboost=={xgboost.__version__}', "
"f'joblib=={joblib.__version__}'"
)
self._construct_string_from_lists()
Expand All @@ -859,7 +837,7 @@ def generate(self) -> "LightGBMWrapperGenerator":
self.test_estimator_input_args_list.extend(["random_state=0"])
self.fit_sproc_imports = "import lightgbm"
self.fit_sproc_deps = self.predict_udf_deps = (
"f'numpy=={np.__version__}', f'pandas=={pd.__version__}', f'lightgbm=={lightgbm.__version__}',"
"f'numpy=={np.__version__}', f'pandas=={pd.__version__}', f'lightgbm=={lightgbm.__version__}', "
"f'joblib=={joblib.__version__}'"
)
self._construct_string_from_lists()
Expand Down
Loading